Skip to content

Commit

Permalink
Add tests for address (#410)
Browse files Browse the repository at this point in the history
  • Loading branch information
tacmota authored Dec 21, 2024
1 parent dd3932d commit d7d2053
Show file tree
Hide file tree
Showing 3 changed files with 267 additions and 20 deletions.
7 changes: 7 additions & 0 deletions pycardano/key.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,13 @@ def from_signing_key(
) -> StakeKeyPair:
return cls(signing_key, StakeVerificationKey.from_signing_key(signing_key))

def __eq__(self, other):
if isinstance(other, StakeKeyPair):
return (
other.signing_key == self.signing_key
and other.verification_key == self.verification_key
)


class StakePoolSigningKey(SigningKey):
KEY_TYPE = "StakePoolSigningKey_ed25519"
Expand Down
207 changes: 188 additions & 19 deletions test/pycardano/test_address.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,17 @@
from unittest import TestCase
import pytest

from pycardano.address import Address, PointerAddress
from pycardano.exception import DeserializeException
from pycardano.address import Address, AddressType, PointerAddress
from pycardano.exception import (
DecodingException,
DeserializeException,
InvalidAddressInputException,
)
from pycardano.hash import (
SCRIPT_HASH_SIZE,
VERIFICATION_KEY_HASH_SIZE,
ScriptHash,
VerificationKeyHash,
)
from pycardano.key import PaymentVerificationKey
from pycardano.network import Network

Expand All @@ -20,25 +30,184 @@ def test_payment_addr():
)


class PointerAddressTest(TestCase):
def test_from_primitive_invalid_value(self):
with self.assertRaises(DeserializeException):
PointerAddress.from_primitive(1)
def test_to_primitive_pointer_addr():
assert PointerAddress(1, 2, 3).to_primitive() == b"\x01\x02\x03"

with self.assertRaises(DeserializeException):
PointerAddress.from_primitive([])

with self.assertRaises(DeserializeException):
PointerAddress.from_primitive({})
def test_from_primitive_pointer_addr():
assert PointerAddress.from_primitive(
b"\x01\x02\x03"
) == PointerAddress.from_primitive(b"\x01\x02\x03")


class AddressTest(TestCase):
def test_from_primitive_invalid_value(self):
with self.assertRaises(DeserializeException):
Address.from_primitive(1)
def test_from_primitive_invalid_value_pointer_addr():
with pytest.raises(DecodingException):
PointerAddress.decode(data=b"\x01\x02")

with self.assertRaises(DeserializeException):
Address.from_primitive([])
with pytest.raises(DeserializeException):
PointerAddress.from_primitive(1)

with self.assertRaises(DeserializeException):
Address.from_primitive({})
with pytest.raises(DeserializeException):
PointerAddress.from_primitive([])

with pytest.raises(DeserializeException):
PointerAddress.from_primitive({})


def test_equality_pointer_addr():
assert PointerAddress(1, 2, 3) == PointerAddress(1, 2, 3)


def test_inequality_different_values_pointer_addr():
assert PointerAddress(1, 2, 3) != PointerAddress(4, 5, 6)


def test_inequality_not_pointer_addr():
assert PointerAddress(1, 2, 3) != (1, 2, 3)


def test_inequality_null_pointer_addr():
assert PointerAddress(1, 2, 3) != None


def test_self_equality_pointer_addr():
assert PointerAddress(1, 2, 3) == PointerAddress(1, 2, 3)


def test_from_primitive_invalid_value_addr():
with pytest.raises(DeserializeException):
Address.from_primitive(1)

with pytest.raises(DeserializeException):
Address.from_primitive([])

with pytest.raises(DeserializeException):
Address.from_primitive({})


def test_key_script_addr():
address = Address(
VerificationKeyHash(b"1" * VERIFICATION_KEY_HASH_SIZE),
ScriptHash(b"1" * SCRIPT_HASH_SIZE),
)
assert address.address_type == AddressType.KEY_SCRIPT


def test_script_key_addr():
address = Address(
ScriptHash(b"1" * SCRIPT_HASH_SIZE),
VerificationKeyHash(b"1" * VERIFICATION_KEY_HASH_SIZE),
)
assert address.address_type == AddressType.SCRIPT_KEY


def test_script_point_addr():
address = Address(ScriptHash(b"1" * SCRIPT_HASH_SIZE), PointerAddress(1, 2, 3))
assert address.address_type == AddressType.SCRIPT_POINTER


def test_none_script_hash_addr():
address = Address(None, ScriptHash(b"1" * SCRIPT_HASH_SIZE))
assert address.address_type == AddressType.NONE_SCRIPT


def test_invalid_combination_unhandled_types_addr():
class UnknownType:
pass

with pytest.raises(InvalidAddressInputException):
Address(UnknownType(), UnknownType())


def test_equality_same_values_addr():
a1 = Address(
VerificationKeyHash(b"1" * VERIFICATION_KEY_HASH_SIZE),
ScriptHash(b"1" * SCRIPT_HASH_SIZE),
)
a2 = Address(
VerificationKeyHash(b"1" * VERIFICATION_KEY_HASH_SIZE),
ScriptHash(b"1" * SCRIPT_HASH_SIZE),
)
assert a1 == a2


def test_inequality_not_address_addr():
a1 = Address(
VerificationKeyHash(b"1" * VERIFICATION_KEY_HASH_SIZE),
ScriptHash(b"1" * SCRIPT_HASH_SIZE),
)
not_address = (1, 2, 3)
assert a1 != not_address


def test_from_primitive_address_type_key_script_addr():
header = AddressType.KEY_SCRIPT.value << 4
payment = b"\x01" * VERIFICATION_KEY_HASH_SIZE
staking = b"\x02" * SCRIPT_HASH_SIZE
value = bytes([header]) + payment + staking

address = Address.from_primitive(value)

assert isinstance(address.payment_part, VerificationKeyHash)

assert isinstance(address.staking_part, ScriptHash)


def test_from_primitive_type_verification_key_hash_addr():
header = AddressType.KEY_POINTER.value << 4
payment = b"\x01" * VERIFICATION_KEY_HASH_SIZE
staking = b"\x01\x02\x03"
value = bytes([header]) + payment + staking

address = Address.from_primitive(value)

assert isinstance(address.payment_part, VerificationKeyHash)

assert isinstance(address.staking_part, PointerAddress)


def test_from_primitive_staking_script_hash_addr():
header = AddressType.SCRIPT_KEY.value << 4
payment = b"\x01" * SCRIPT_HASH_SIZE
staking = b"\x02" * VERIFICATION_KEY_HASH_SIZE
value = bytes([header]) + payment + staking

address = Address.from_primitive(value)

assert isinstance(address.payment_part, ScriptHash)

assert isinstance(address.staking_part, VerificationKeyHash)


def test_from_primitive_payment_script_hash_addr():
header = AddressType.SCRIPT_POINTER.value << 4
payment = b"\x01" * SCRIPT_HASH_SIZE
staking = b"\x01\x02\x03"
value = bytes([header]) + payment + staking

address = Address.from_primitive(value)

assert isinstance(address.payment_part, ScriptHash)


def test_from_primitive_type_none_addr():
header = AddressType.NONE_SCRIPT.value << 4
payment = b"\x01" * 14
staking = b"\x02" * 14
value = bytes([header]) + payment + staking

address = Address.from_primitive(value)

assert address.payment_part is None

assert isinstance(address.staking_part, ScriptHash)


def test_from_primitive_invalid_type_addr():
header = AddressType.BYRON.value << 4
payment = b"\x01" * 14
staking = b"\x02" * 14
value = bytes([header]) + payment + staking

with pytest.raises(DeserializeException):
Address.from_primitive(value)
73 changes: 72 additions & 1 deletion test/pycardano/test_key.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
import json
import pathlib
import tempfile

import pytest
from mnemonic import Mnemonic

from pycardano import HDWallet
from pycardano import HDWallet, StakeKeyPair, StakeSigningKey, StakeVerificationKey
from pycardano.exception import InvalidKeyTypeException
from pycardano.key import (
ExtendedSigningKey,
ExtendedVerificationKey,
Key,
PaymentExtendedSigningKey,
PaymentKeyPair,
PaymentSigningKey,
Expand Down Expand Up @@ -66,6 +70,53 @@
)


def test_invalid_key_type():
data = json.dumps(
{
"type": "invalid_type",
"payload": "example_payload",
"description": "example_description",
}
)

with pytest.raises(InvalidKeyTypeException):
Key.from_json(data, validate_type=True)


def test_bytes_conversion():
assert bytes(Key(b"1234")) == b"1234"


def test_eq_not_instance():
assert Key(b"hello") != "1234"


def test_from_hdwallet_missing_xprivate_key():
with pytest.raises(InvalidKeyTypeException):
ExtendedSigningKey(b"1234").from_hdwallet(
HDWallet(
b"root_xprivate_key",
b"root_public_key",
b"root_chain_code",
None,
b"valid_public_key",
chain_code=b"valid_chain_code",
)
)

with pytest.raises(InvalidKeyTypeException):
ExtendedSigningKey(b"1234").from_hdwallet(
HDWallet(
b"root_xprivate_key",
b"root_public_key",
b"root_chain_code",
b"valid_xprivate_key",
b"valid_public_key",
None,
)
)


def test_payment_key():
assert (
SK.payload
Expand Down Expand Up @@ -123,12 +174,25 @@ def test_extended_payment_key_sign():


def test_key_pair():
PaymentSigningKey.generate()


def test_payment_key_pair():
PaymentKeyPair.generate()
sk = PaymentSigningKey.generate()
vk = PaymentVerificationKey.from_signing_key(sk)
assert PaymentKeyPair(sk, vk) == PaymentKeyPair.from_signing_key(sk)


def test_stake_key_pair():
StakeKeyPair.generate()
sk = StakeSigningKey.generate()
vk = StakeVerificationKey.from_signing_key(sk)
assert StakeKeyPair(sk, vk) == StakeKeyPair.from_signing_key(sk)


def test_stake_pool_key_pair():
StakePoolKeyPair.generate()
sk = StakePoolSigningKey.generate()
vk = StakePoolVerificationKey.from_signing_key(sk)
assert StakePoolKeyPair(sk, vk) == StakePoolKeyPair.from_signing_key(sk)
Expand Down Expand Up @@ -158,6 +222,13 @@ def test_key_save():
assert SK == sk


def test_key_save_invalid_address():
with tempfile.NamedTemporaryFile() as f:
SK.save(f.name)
with pytest.raises(IOError):
VK.save(f.name)


def test_stake_pool_key_save():
with tempfile.NamedTemporaryFile() as skf, tempfile.NamedTemporaryFile() as vkf:
SPSK.save(skf.name)
Expand Down

0 comments on commit d7d2053

Please sign in to comment.