Skip to content

Commit

Permalink
Simplify relay parsing
Browse files Browse the repository at this point in the history
  • Loading branch information
cffls committed Feb 26, 2024
1 parent c845908 commit e6c1093
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 92 deletions.
97 changes: 47 additions & 50 deletions pycardano/pool_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from typing import List, Optional, Type, Union

from pycardano.crypto.bech32 import bech32_decode
from pycardano.exception import DeserializeException
from pycardano.hash import (
PoolKeyHash,
PoolMetadataHash,
Expand All @@ -21,9 +22,20 @@
ArrayCBORSerializable,
CBORSerializable,
limit_primitive_type,
list_hook,
)

__all__ = [
"PoolId",
"PoolMetadata",
"PoolParams",
"Relay",
"SingleHostAddr",
"SingleHostName",
"MultiHostName",
"RelayCBORSerializer",
"is_bech32_cardano_pool_id",
]


def is_bech32_cardano_pool_id(pool_id: str) -> bool:
"""Check if a string is a valid Cardano stake pool ID in bech32 format."""
Expand Down Expand Up @@ -158,11 +170,14 @@ def to_primitive(self) -> list:
def from_primitive(
cls: Type[SingleHostAddr], values: Union[list, tuple]
) -> SingleHostAddr:
return cls(
port=values[1],
ipv4=values[2],
ipv6=values[3],
)
if values[0] == 0:
return cls(
port=values[1],
ipv4=values[2],
ipv6=values[3],
)
else:
raise DeserializeException(f"Invalid SingleHostAddr type {values[0]}")


@dataclass(repr=False)
Expand All @@ -180,10 +195,13 @@ def __post_init__(self):
def from_primitive(
cls: Type[SingleHostName], values: Union[list, tuple]
) -> SingleHostName:
return cls(
port=values[1],
dns_name=values[2],
)
if values[0] == 1:
return cls(
port=values[1],
dns_name=values[2],
)
else:
raise DeserializeException(f"Invalid SingleHostName type {values[0]}")


@dataclass(repr=False)
Expand All @@ -200,9 +218,12 @@ def __post_init__(self):
def from_primitive(
cls: Type[MultiHostName], values: Union[list, tuple]
) -> MultiHostName:
return cls(
dns_name=values[1],
)
if values[0] == 2:
return cls(
dns_name=values[1],
)
else:
raise DeserializeException(f"Invalid MultiHostName type {values[0]}")


Relay = Union[SingleHostAddr, SingleHostName, MultiHostName]
Expand All @@ -214,37 +235,17 @@ class PoolMetadata(ArrayCBORSerializable):
pool_metadata_hash: PoolMetadataHash


@dataclass(repr=False)
class FractionSerializer(CBORSerializable, Fraction, ABC):
@classmethod
@limit_primitive_type(Fraction, str, list)
def from_primitive(
cls: Type[Fraction], fraction: Union[Fraction, str, list]
) -> Fraction:
if isinstance(fraction, Fraction):
return Fraction(int(fraction.numerator), int(fraction.denominator))
elif isinstance(fraction, str):
numerator, denominator = fraction.split("/")
return Fraction(int(numerator), int(denominator))
elif isinstance(fraction, list):
numerator, denominator = fraction[1]
return Fraction(int(numerator), int(denominator))


@dataclass(repr=False)
class RelayCBORSerializer(ArrayCBORSerializable):
@classmethod
@limit_primitive_type(list)
def from_primitive(
cls: Type[RelayCBORSerializer], values: Union[list, tuple]
) -> Relay | None:
if values[0] == 0:
return SingleHostAddr.from_primitive(values)
elif values[0] == 1:
return SingleHostName.from_primitive(values)
elif values[0] == 2:
return MultiHostName.from_primitive(values)
return None
def fraction_parser(fraction: Union[Fraction, str, list]) -> Fraction:
if isinstance(fraction, Fraction):
return Fraction(int(fraction.numerator), int(fraction.denominator))
elif isinstance(fraction, str):
numerator, denominator = fraction.split("/")
return Fraction(int(numerator), int(denominator))
elif isinstance(fraction, list):
numerator, denominator = fraction[1]
return Fraction(int(numerator), int(denominator))
else:
raise ValueError(f"Invalid fraction type {fraction}")


@dataclass(repr=False)
Expand All @@ -253,13 +254,9 @@ class PoolParams(ArrayCBORSerializable):
vrf_keyhash: VrfKeyHash
pledge: int
cost: int
margin: Fraction = field(
metadata={"object_hook": FractionSerializer.from_primitive}
)
margin: Fraction = field(metadata={"object_hook": fraction_parser})
reward_account: RewardAccountHash
pool_owners: List[VerificationKeyHash]
relays: Optional[List[Relay]] = field(
metadata={"object_hook": list_hook(RelayCBORSerializer)},
)
relays: Optional[List[Relay]] = None
pool_metadata: Optional[PoolMetadata] = None
id: Optional[PoolId] = field(default=None, metadata={"optional": True})
10 changes: 8 additions & 2 deletions test/pycardano/backend/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,10 @@ def genesis_file(genesis_json):

yield genesis_file_path

genesis_file_path.unlink()
try:
genesis_file_path.unlink()
except FileNotFoundError:
pass


@pytest.fixture(scope="session")
Expand Down Expand Up @@ -190,4 +193,7 @@ def config_file():

yield config_file_path

config_file_path.unlink()
try:
config_file_path.unlink()
except FileNotFoundError:
pass
46 changes: 6 additions & 40 deletions test/pycardano/test_pool_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,13 @@
VrfKeyHash,
)
from pycardano.pool_params import ( # Fraction,
FractionSerializer,
MultiHostName,
PoolId,
PoolMetadata,
PoolParams,
RelayCBORSerializer,
SingleHostAddr,
SingleHostName,
fraction_parser,
is_bech32_cardano_pool_id,
)

Expand Down Expand Up @@ -208,30 +207,12 @@ def test_pool_metadata(url, pool_metadata_hash):
)
def test_fraction_serializer(input_value):
# Act
result = FractionSerializer.from_primitive(input_value)
result = fraction_parser(input_value)

# Assert
assert isinstance(result, Fraction)


@pytest.mark.parametrize(
"test_id, input_value, expected_output",
[
("HP-1", [0, 3001, "10.20.30.40", "::1"], SingleHostAddr),
("HP-2", [1, 3001, "example.com"], SingleHostName),
("HP-3", [2, "example.com"], MultiHostName),
],
)
def test_relay_cbor_serializer(test_id, input_value, expected_output):
# Act
result = RelayCBORSerializer.from_primitive(input_value)

# Assert
assert isinstance(
result, expected_output
), f"Test {test_id} failed: {result} != {expected_output}"


@pytest.mark.parametrize(
"operator, vrf_keyhash, pledge, cost, margin, reward_account, pool_owners, relays, pool_metadata",
[
Expand All @@ -245,7 +226,7 @@ def test_relay_cbor_serializer(test_id, input_value, expected_output):
b"1" * REWARD_ACCOUNT_HASH_SIZE,
[b"1" * VERIFICATION_KEY_HASH_SIZE],
[
[0, 3001, "10.20.30.40", None],
[0, 3001, SingleHostAddr.ipv4_to_bytes("10.20.30.40"), None],
[1, 3001, "example.com"],
[2, "example.com"],
],
Expand Down Expand Up @@ -286,26 +267,11 @@ def test_pool_params(
vrf_keyhash,
pledge,
cost,
FractionSerializer.from_primitive(margin),
fraction_parser(margin),
reward_account,
pool_owners,
[RelayCBORSerializer.from_primitive(x).to_primitive() for x in relays],
relays,
pool_metadata,
]

# Act
pool_params = PoolParams.from_primitive(primitive_values)

# Assert
assert isinstance(pool_params, PoolParams)
assert pool_params.operator == PoolKeyHash(operator)
assert pool_params.vrf_keyhash == VrfKeyHash(vrf_keyhash)
assert pool_params.pledge == pledge
assert pool_params.cost == cost
assert pool_params.margin == Fraction(margin)
assert pool_params.reward_account == RewardAccountHash(reward_account)
assert pool_params.pool_owners == [VerificationKeyHash(x) for x in pool_owners]
assert pool_params.relays == [RelayCBORSerializer.from_primitive(x) for x in relays]
assert pool_params.pool_metadata == PoolMetadata.from_primitive(pool_metadata)

assert pool_params.to_primitive() == primitive_out
assert PoolParams.from_primitive(primitive_values).to_primitive() == primitive_out

0 comments on commit e6c1093

Please sign in to comment.