From e6c1093097a7b001093dc11a99e67e0946adb561 Mon Sep 17 00:00:00 2001 From: Jerry Date: Sun, 25 Feb 2024 21:30:05 -0800 Subject: [PATCH] Simplify relay parsing --- pycardano/pool_params.py | 97 +++++++++++++++--------------- test/pycardano/backend/conftest.py | 10 ++- test/pycardano/test_pool_params.py | 46 ++------------ 3 files changed, 61 insertions(+), 92 deletions(-) diff --git a/pycardano/pool_params.py b/pycardano/pool_params.py index 94955f99..e3121ce1 100644 --- a/pycardano/pool_params.py +++ b/pycardano/pool_params.py @@ -10,6 +10,7 @@ from typing import List, Optional, Type, Union from pycardano.crypto.bech32 import bech32_decode +from pycardano.exception import DeserializeException from pycardano.hash import ( PoolKeyHash, PoolMetadataHash, @@ -21,9 +22,20 @@ ArrayCBORSerializable, CBORSerializable, limit_primitive_type, - list_hook, ) +__all__ = [ + "PoolId", + "PoolMetadata", + "PoolParams", + "Relay", + "SingleHostAddr", + "SingleHostName", + "MultiHostName", + "RelayCBORSerializer", + "is_bech32_cardano_pool_id", +] + def is_bech32_cardano_pool_id(pool_id: str) -> bool: """Check if a string is a valid Cardano stake pool ID in bech32 format.""" @@ -158,11 +170,14 @@ def to_primitive(self) -> list: def from_primitive( cls: Type[SingleHostAddr], values: Union[list, tuple] ) -> SingleHostAddr: - return cls( - port=values[1], - ipv4=values[2], - ipv6=values[3], - ) + if values[0] == 0: + return cls( + port=values[1], + ipv4=values[2], + ipv6=values[3], + ) + else: + raise DeserializeException(f"Invalid SingleHostAddr type {values[0]}") @dataclass(repr=False) @@ -180,10 +195,13 @@ def __post_init__(self): def from_primitive( cls: Type[SingleHostName], values: Union[list, tuple] ) -> SingleHostName: - return cls( - port=values[1], - dns_name=values[2], - ) + if values[0] == 1: + return cls( + port=values[1], + dns_name=values[2], + ) + else: + raise DeserializeException(f"Invalid SingleHostName type {values[0]}") @dataclass(repr=False) @@ -200,9 +218,12 @@ def __post_init__(self): def from_primitive( cls: Type[MultiHostName], values: Union[list, tuple] ) -> MultiHostName: - return cls( - dns_name=values[1], - ) + if values[0] == 2: + return cls( + dns_name=values[1], + ) + else: + raise DeserializeException(f"Invalid MultiHostName type {values[0]}") Relay = Union[SingleHostAddr, SingleHostName, MultiHostName] @@ -214,37 +235,17 @@ class PoolMetadata(ArrayCBORSerializable): pool_metadata_hash: PoolMetadataHash -@dataclass(repr=False) -class FractionSerializer(CBORSerializable, Fraction, ABC): - @classmethod - @limit_primitive_type(Fraction, str, list) - def from_primitive( - cls: Type[Fraction], fraction: Union[Fraction, str, list] - ) -> Fraction: - if isinstance(fraction, Fraction): - return Fraction(int(fraction.numerator), int(fraction.denominator)) - elif isinstance(fraction, str): - numerator, denominator = fraction.split("/") - return Fraction(int(numerator), int(denominator)) - elif isinstance(fraction, list): - numerator, denominator = fraction[1] - return Fraction(int(numerator), int(denominator)) - - -@dataclass(repr=False) -class RelayCBORSerializer(ArrayCBORSerializable): - @classmethod - @limit_primitive_type(list) - def from_primitive( - cls: Type[RelayCBORSerializer], values: Union[list, tuple] - ) -> Relay | None: - if values[0] == 0: - return SingleHostAddr.from_primitive(values) - elif values[0] == 1: - return SingleHostName.from_primitive(values) - elif values[0] == 2: - return MultiHostName.from_primitive(values) - return None +def fraction_parser(fraction: Union[Fraction, str, list]) -> Fraction: + if isinstance(fraction, Fraction): + return Fraction(int(fraction.numerator), int(fraction.denominator)) + elif isinstance(fraction, str): + numerator, denominator = fraction.split("/") + return Fraction(int(numerator), int(denominator)) + elif isinstance(fraction, list): + numerator, denominator = fraction[1] + return Fraction(int(numerator), int(denominator)) + else: + raise ValueError(f"Invalid fraction type {fraction}") @dataclass(repr=False) @@ -253,13 +254,9 @@ class PoolParams(ArrayCBORSerializable): vrf_keyhash: VrfKeyHash pledge: int cost: int - margin: Fraction = field( - metadata={"object_hook": FractionSerializer.from_primitive} - ) + margin: Fraction = field(metadata={"object_hook": fraction_parser}) reward_account: RewardAccountHash pool_owners: List[VerificationKeyHash] - relays: Optional[List[Relay]] = field( - metadata={"object_hook": list_hook(RelayCBORSerializer)}, - ) + relays: Optional[List[Relay]] = None pool_metadata: Optional[PoolMetadata] = None id: Optional[PoolId] = field(default=None, metadata={"optional": True}) diff --git a/test/pycardano/backend/conftest.py b/test/pycardano/backend/conftest.py index 23571df5..ae878e7e 100644 --- a/test/pycardano/backend/conftest.py +++ b/test/pycardano/backend/conftest.py @@ -90,7 +90,10 @@ def genesis_file(genesis_json): yield genesis_file_path - genesis_file_path.unlink() + try: + genesis_file_path.unlink() + except FileNotFoundError: + pass @pytest.fixture(scope="session") @@ -190,4 +193,7 @@ def config_file(): yield config_file_path - config_file_path.unlink() + try: + config_file_path.unlink() + except FileNotFoundError: + pass diff --git a/test/pycardano/test_pool_params.py b/test/pycardano/test_pool_params.py index 461aa970..0f038ccf 100644 --- a/test/pycardano/test_pool_params.py +++ b/test/pycardano/test_pool_params.py @@ -17,14 +17,13 @@ VrfKeyHash, ) from pycardano.pool_params import ( # Fraction, - FractionSerializer, MultiHostName, PoolId, PoolMetadata, PoolParams, - RelayCBORSerializer, SingleHostAddr, SingleHostName, + fraction_parser, is_bech32_cardano_pool_id, ) @@ -208,30 +207,12 @@ def test_pool_metadata(url, pool_metadata_hash): ) def test_fraction_serializer(input_value): # Act - result = FractionSerializer.from_primitive(input_value) + result = fraction_parser(input_value) # Assert assert isinstance(result, Fraction) -@pytest.mark.parametrize( - "test_id, input_value, expected_output", - [ - ("HP-1", [0, 3001, "10.20.30.40", "::1"], SingleHostAddr), - ("HP-2", [1, 3001, "example.com"], SingleHostName), - ("HP-3", [2, "example.com"], MultiHostName), - ], -) -def test_relay_cbor_serializer(test_id, input_value, expected_output): - # Act - result = RelayCBORSerializer.from_primitive(input_value) - - # Assert - assert isinstance( - result, expected_output - ), f"Test {test_id} failed: {result} != {expected_output}" - - @pytest.mark.parametrize( "operator, vrf_keyhash, pledge, cost, margin, reward_account, pool_owners, relays, pool_metadata", [ @@ -245,7 +226,7 @@ def test_relay_cbor_serializer(test_id, input_value, expected_output): b"1" * REWARD_ACCOUNT_HASH_SIZE, [b"1" * VERIFICATION_KEY_HASH_SIZE], [ - [0, 3001, "10.20.30.40", None], + [0, 3001, SingleHostAddr.ipv4_to_bytes("10.20.30.40"), None], [1, 3001, "example.com"], [2, "example.com"], ], @@ -286,26 +267,11 @@ def test_pool_params( vrf_keyhash, pledge, cost, - FractionSerializer.from_primitive(margin), + fraction_parser(margin), reward_account, pool_owners, - [RelayCBORSerializer.from_primitive(x).to_primitive() for x in relays], + relays, pool_metadata, ] - # Act - pool_params = PoolParams.from_primitive(primitive_values) - - # Assert - assert isinstance(pool_params, PoolParams) - assert pool_params.operator == PoolKeyHash(operator) - assert pool_params.vrf_keyhash == VrfKeyHash(vrf_keyhash) - assert pool_params.pledge == pledge - assert pool_params.cost == cost - assert pool_params.margin == Fraction(margin) - assert pool_params.reward_account == RewardAccountHash(reward_account) - assert pool_params.pool_owners == [VerificationKeyHash(x) for x in pool_owners] - assert pool_params.relays == [RelayCBORSerializer.from_primitive(x) for x in relays] - assert pool_params.pool_metadata == PoolMetadata.from_primitive(pool_metadata) - - assert pool_params.to_primitive() == primitive_out + assert PoolParams.from_primitive(primitive_values).to_primitive() == primitive_out