diff --git a/pycardano/key.py b/pycardano/key.py index 26a05d12..77c64b7d 100644 --- a/pycardano/key.py +++ b/pycardano/key.py @@ -318,6 +318,13 @@ def from_signing_key( ) -> StakeKeyPair: return cls(signing_key, StakeVerificationKey.from_signing_key(signing_key)) + def __eq__(self, other): + if isinstance(other, StakeKeyPair): + return ( + other.signing_key == self.signing_key + and other.verification_key == self.verification_key + ) + class StakePoolSigningKey(SigningKey): KEY_TYPE = "StakePoolSigningKey_ed25519" diff --git a/test/pycardano/test_address.py b/test/pycardano/test_address.py index d239b1cf..b9ba8ddf 100644 --- a/test/pycardano/test_address.py +++ b/test/pycardano/test_address.py @@ -1,7 +1,17 @@ -from unittest import TestCase +import pytest -from pycardano.address import Address, PointerAddress -from pycardano.exception import DeserializeException +from pycardano.address import Address, AddressType, PointerAddress +from pycardano.exception import ( + DecodingException, + DeserializeException, + InvalidAddressInputException, +) +from pycardano.hash import ( + SCRIPT_HASH_SIZE, + VERIFICATION_KEY_HASH_SIZE, + ScriptHash, + VerificationKeyHash, +) from pycardano.key import PaymentVerificationKey from pycardano.network import Network @@ -20,25 +30,184 @@ def test_payment_addr(): ) -class PointerAddressTest(TestCase): - def test_from_primitive_invalid_value(self): - with self.assertRaises(DeserializeException): - PointerAddress.from_primitive(1) +def test_to_primitive_pointer_addr(): + assert PointerAddress(1, 2, 3).to_primitive() == b"\x01\x02\x03" - with self.assertRaises(DeserializeException): - PointerAddress.from_primitive([]) - with self.assertRaises(DeserializeException): - PointerAddress.from_primitive({}) +def test_from_primitive_pointer_addr(): + assert PointerAddress.from_primitive( + b"\x01\x02\x03" + ) == PointerAddress.from_primitive(b"\x01\x02\x03") -class AddressTest(TestCase): - def test_from_primitive_invalid_value(self): - with self.assertRaises(DeserializeException): - Address.from_primitive(1) +def test_from_primitive_invalid_value_pointer_addr(): + with pytest.raises(DecodingException): + PointerAddress.decode(data=b"\x01\x02") - with self.assertRaises(DeserializeException): - Address.from_primitive([]) + with pytest.raises(DeserializeException): + PointerAddress.from_primitive(1) - with self.assertRaises(DeserializeException): - Address.from_primitive({}) + with pytest.raises(DeserializeException): + PointerAddress.from_primitive([]) + + with pytest.raises(DeserializeException): + PointerAddress.from_primitive({}) + + +def test_equality_pointer_addr(): + assert PointerAddress(1, 2, 3) == PointerAddress(1, 2, 3) + + +def test_inequality_different_values_pointer_addr(): + assert PointerAddress(1, 2, 3) != PointerAddress(4, 5, 6) + + +def test_inequality_not_pointer_addr(): + assert PointerAddress(1, 2, 3) != (1, 2, 3) + + +def test_inequality_null_pointer_addr(): + assert PointerAddress(1, 2, 3) != None + + +def test_self_equality_pointer_addr(): + assert PointerAddress(1, 2, 3) == PointerAddress(1, 2, 3) + + +def test_from_primitive_invalid_value_addr(): + with pytest.raises(DeserializeException): + Address.from_primitive(1) + + with pytest.raises(DeserializeException): + Address.from_primitive([]) + + with pytest.raises(DeserializeException): + Address.from_primitive({}) + + +def test_key_script_addr(): + address = Address( + VerificationKeyHash(b"1" * VERIFICATION_KEY_HASH_SIZE), + ScriptHash(b"1" * SCRIPT_HASH_SIZE), + ) + assert address.address_type == AddressType.KEY_SCRIPT + + +def test_script_key_addr(): + address = Address( + ScriptHash(b"1" * SCRIPT_HASH_SIZE), + VerificationKeyHash(b"1" * VERIFICATION_KEY_HASH_SIZE), + ) + assert address.address_type == AddressType.SCRIPT_KEY + + +def test_script_point_addr(): + address = Address(ScriptHash(b"1" * SCRIPT_HASH_SIZE), PointerAddress(1, 2, 3)) + assert address.address_type == AddressType.SCRIPT_POINTER + + +def test_none_script_hash_addr(): + address = Address(None, ScriptHash(b"1" * SCRIPT_HASH_SIZE)) + assert address.address_type == AddressType.NONE_SCRIPT + + +def test_invalid_combination_unhandled_types_addr(): + class UnknownType: + pass + + with pytest.raises(InvalidAddressInputException): + Address(UnknownType(), UnknownType()) + + +def test_equality_same_values_addr(): + a1 = Address( + VerificationKeyHash(b"1" * VERIFICATION_KEY_HASH_SIZE), + ScriptHash(b"1" * SCRIPT_HASH_SIZE), + ) + a2 = Address( + VerificationKeyHash(b"1" * VERIFICATION_KEY_HASH_SIZE), + ScriptHash(b"1" * SCRIPT_HASH_SIZE), + ) + assert a1 == a2 + + +def test_inequality_not_address_addr(): + a1 = Address( + VerificationKeyHash(b"1" * VERIFICATION_KEY_HASH_SIZE), + ScriptHash(b"1" * SCRIPT_HASH_SIZE), + ) + not_address = (1, 2, 3) + assert a1 != not_address + + +def test_from_primitive_address_type_key_script_addr(): + header = AddressType.KEY_SCRIPT.value << 4 + payment = b"\x01" * VERIFICATION_KEY_HASH_SIZE + staking = b"\x02" * SCRIPT_HASH_SIZE + value = bytes([header]) + payment + staking + + address = Address.from_primitive(value) + + assert isinstance(address.payment_part, VerificationKeyHash) + + assert isinstance(address.staking_part, ScriptHash) + + +def test_from_primitive_type_verification_key_hash_addr(): + header = AddressType.KEY_POINTER.value << 4 + payment = b"\x01" * VERIFICATION_KEY_HASH_SIZE + staking = b"\x01\x02\x03" + value = bytes([header]) + payment + staking + + address = Address.from_primitive(value) + + assert isinstance(address.payment_part, VerificationKeyHash) + + assert isinstance(address.staking_part, PointerAddress) + + +def test_from_primitive_staking_script_hash_addr(): + header = AddressType.SCRIPT_KEY.value << 4 + payment = b"\x01" * SCRIPT_HASH_SIZE + staking = b"\x02" * VERIFICATION_KEY_HASH_SIZE + value = bytes([header]) + payment + staking + + address = Address.from_primitive(value) + + assert isinstance(address.payment_part, ScriptHash) + + assert isinstance(address.staking_part, VerificationKeyHash) + + +def test_from_primitive_payment_script_hash_addr(): + header = AddressType.SCRIPT_POINTER.value << 4 + payment = b"\x01" * SCRIPT_HASH_SIZE + staking = b"\x01\x02\x03" + value = bytes([header]) + payment + staking + + address = Address.from_primitive(value) + + assert isinstance(address.payment_part, ScriptHash) + + +def test_from_primitive_type_none_addr(): + header = AddressType.NONE_SCRIPT.value << 4 + payment = b"\x01" * 14 + staking = b"\x02" * 14 + value = bytes([header]) + payment + staking + + address = Address.from_primitive(value) + + assert address.payment_part is None + + assert isinstance(address.staking_part, ScriptHash) + + +def test_from_primitive_invalid_type_addr(): + header = AddressType.BYRON.value << 4 + payment = b"\x01" * 14 + staking = b"\x02" * 14 + value = bytes([header]) + payment + staking + + with pytest.raises(DeserializeException): + Address.from_primitive(value) diff --git a/test/pycardano/test_key.py b/test/pycardano/test_key.py index c5f0221b..9d9384be 100644 --- a/test/pycardano/test_key.py +++ b/test/pycardano/test_key.py @@ -1,12 +1,16 @@ +import json import pathlib import tempfile +import pytest from mnemonic import Mnemonic -from pycardano import HDWallet +from pycardano import HDWallet, StakeKeyPair, StakeSigningKey, StakeVerificationKey +from pycardano.exception import InvalidKeyTypeException from pycardano.key import ( ExtendedSigningKey, ExtendedVerificationKey, + Key, PaymentExtendedSigningKey, PaymentKeyPair, PaymentSigningKey, @@ -66,6 +70,53 @@ ) +def test_invalid_key_type(): + data = json.dumps( + { + "type": "invalid_type", + "payload": "example_payload", + "description": "example_description", + } + ) + + with pytest.raises(InvalidKeyTypeException): + Key.from_json(data, validate_type=True) + + +def test_bytes_conversion(): + assert bytes(Key(b"1234")) == b"1234" + + +def test_eq_not_instance(): + assert Key(b"hello") != "1234" + + +def test_from_hdwallet_missing_xprivate_key(): + with pytest.raises(InvalidKeyTypeException): + ExtendedSigningKey(b"1234").from_hdwallet( + HDWallet( + b"root_xprivate_key", + b"root_public_key", + b"root_chain_code", + None, + b"valid_public_key", + chain_code=b"valid_chain_code", + ) + ) + + with pytest.raises(InvalidKeyTypeException): + ExtendedSigningKey(b"1234").from_hdwallet( + HDWallet( + b"root_xprivate_key", + b"root_public_key", + b"root_chain_code", + b"valid_xprivate_key", + b"valid_public_key", + None, + ) + ) + + def test_payment_key(): assert ( SK.payload @@ -123,12 +174,25 @@ def test_extended_payment_key_sign(): def test_key_pair(): + PaymentSigningKey.generate() + + +def test_payment_key_pair(): + PaymentKeyPair.generate() sk = PaymentSigningKey.generate() vk = PaymentVerificationKey.from_signing_key(sk) assert PaymentKeyPair(sk, vk) == PaymentKeyPair.from_signing_key(sk) +def test_stake_key_pair(): + StakeKeyPair.generate() + sk = StakeSigningKey.generate() + vk = StakeVerificationKey.from_signing_key(sk) + assert StakeKeyPair(sk, vk) == StakeKeyPair.from_signing_key(sk) + + def test_stake_pool_key_pair(): + StakePoolKeyPair.generate() sk = StakePoolSigningKey.generate() vk = StakePoolVerificationKey.from_signing_key(sk) assert StakePoolKeyPair(sk, vk) == StakePoolKeyPair.from_signing_key(sk) @@ -158,6 +222,13 @@ def test_key_save(): assert SK == sk +def test_key_save_invalid_address(): + with tempfile.NamedTemporaryFile() as f: + SK.save(f.name) + with pytest.raises(IOError): + VK.save(f.name) + + def test_stake_pool_key_save(): with tempfile.NamedTemporaryFile() as skf, tempfile.NamedTemporaryFile() as vkf: SPSK.save(skf.name)