Skip to content

Commit

Permalink
fixup! fixup! feat(cardano): add support for script addresses derivation
Browse files Browse the repository at this point in the history
  • Loading branch information
gabrielKerekes committed Oct 6, 2021
1 parent 248328f commit 0027ea9
Show file tree
Hide file tree
Showing 6 changed files with 63 additions and 73 deletions.
4 changes: 2 additions & 2 deletions core/src/all_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,8 +456,8 @@
import apps.cardano.helpers.account_path_check
apps.cardano.helpers.bech32
import apps.cardano.helpers.bech32
apps.cardano.helpers.credential_params
import apps.cardano.helpers.credential_params
apps.cardano.helpers.credential
import apps.cardano.helpers.credential
apps.cardano.helpers.hash_builder_collection
import apps.cardano.helpers.hash_builder_collection
apps.cardano.helpers.network_ids
Expand Down
6 changes: 3 additions & 3 deletions core/src/apps/cardano/get_address.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from . import seed
from .address import derive_human_readable_address, validate_address_parameters
from .helpers.credential_params import CredentialParams, should_show_address_credentials
from .helpers.credential import Credential, should_show_address_credentials
from .layout import show_cardano_address, show_credentials
from .sign_tx import validate_network_info

Expand Down Expand Up @@ -47,8 +47,8 @@ async def _display_address(
if should_show_address_credentials(address_parameters):
await show_credentials(
ctx,
CredentialParams.payment_params(address_parameters),
CredentialParams.stake_params(address_parameters),
Credential.payment_credential(address_parameters),
Credential.stake_credential(address_parameters),
)

await show_cardano_address(ctx, address_parameters, address, protocol_magic)
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
from trezor.ui.layouts import PropertyType


class CredentialParams:
class Credential:
"""
Serves mainly as a wrapper object for credential parameters (so that they don't have to be
Serves mainly as a wrapper object for credentials (so that they don't have to be
passed into functions separately) which also determines all properties that should be shown
as warnings.
Also contains functions which simplify displaying the credential.
Expand Down Expand Up @@ -50,11 +50,11 @@ def __init__(
self.pointer = pointer

@classmethod
def payment_params(
def payment_credential(
cls, address_params: CardanoAddressParametersType
) -> "CredentialParams":
) -> "Credential":
address_type = address_params.address_type
credential_params = CredentialParams(
credential = cls(
"payment",
address_type,
address_params.address_n,
Expand All @@ -71,33 +71,33 @@ def payment_params(
CardanoAddressType.BYRON,
):
if not SCHEMA_PAYMENT.match(address_params.address_n):
credential_params.is_unusual_path = True
credential.is_unusual_path = True

elif address_type in (
CardanoAddressType.BASE_SCRIPT_KEY,
CardanoAddressType.BASE_SCRIPT_SCRIPT,
CardanoAddressType.POINTER_SCRIPT,
CardanoAddressType.ENTERPRISE_SCRIPT,
):
credential_params.is_other_warning = True
credential.is_other_warning = True

elif address_type in (
CardanoAddressType.REWARD,
CardanoAddressType.REWARD_SCRIPT,
):
credential_params.is_reward = True
credential.is_reward = True

else:
raise ValueError("Invalid address type")
raise RuntimeError # we didn't cover all address types

return credential_params
return credential

@classmethod
def stake_params(
def stake_credential(
cls, address_params: CardanoAddressParametersType
) -> "CredentialParams":
) -> "Credential":
address_type = address_params.address_type
credential_params = CredentialParams(
credential = cls(
"stake",
address_type,
address_params.address_n_staking,
Expand All @@ -108,50 +108,50 @@ def stake_params(

if address_type == CardanoAddressType.BASE:
if address_params.staking_key_hash:
credential_params.is_other_warning = True
credential.is_other_warning = True
else:
if not SCHEMA_STAKING.match(address_params.address_n_staking):
credential_params.is_unusual_path = True
credential.is_unusual_path = True
if not _do_base_address_credentials_match(
address_params.address_n,
address_params.address_n_staking,
):
credential_params.is_mismatch = True
credential.is_mismatch = True

elif address_type == CardanoAddressType.BASE_SCRIPT_KEY:
if address_params.address_n_staking and not SCHEMA_STAKING.match(
address_params.address_n_staking
):
credential_params.is_unusual_path = True
credential.is_unusual_path = True

elif address_type in (
CardanoAddressType.POINTER,
CardanoAddressType.POINTER_SCRIPT,
):
credential_params.is_other_warning = True
credential.is_other_warning = True

elif address_type == CardanoAddressType.REWARD:
if not SCHEMA_STAKING.match(address_params.address_n_staking):
credential_params.is_unusual_path = True
credential.is_unusual_path = True

elif address_type in (
CardanoAddressType.BASE_KEY_SCRIPT,
CardanoAddressType.BASE_SCRIPT_SCRIPT,
CardanoAddressType.REWARD_SCRIPT,
):
credential_params.is_other_warning = True
credential.is_other_warning = True

elif address_type in (
CardanoAddressType.ENTERPRISE,
CardanoAddressType.ENTERPRISE_SCRIPT,
CardanoAddressType.BYRON,
):
credential_params.is_no_staking = True
credential.is_no_staking = True

else:
raise ValueError("Invalid address type")
raise RuntimeError # we didn't cover all address types

return credential_params
return credential

def should_warn(self) -> bool:
return any(
Expand All @@ -164,17 +164,8 @@ def should_warn(self) -> bool:
)
)

def get_credential(self) -> list[int] | bytes | CardanoBlockchainPointerType | None:
if self.path:
return self.path
elif self.key_hash:
return self.key_hash
elif self.script_hash:
return self.script_hash
elif self.pointer:
return self.pointer
else:
return None
def is_set(self) -> bool:
return any((self.path, self.key_hash, self.script_hash, self.pointer))

def get_title(self) -> str:
if self.path:
Expand Down
35 changes: 17 additions & 18 deletions core/src/apps/cardano/layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
)

from trezor.ui.layouts import PropertyType
from .helpers.credential_params import CredentialParams
from .helpers.credential import Credential


ADDRESS_TYPE_NAMES = {
Expand Down Expand Up @@ -247,62 +247,61 @@ async def confirm_sending_token(

async def show_credentials(
ctx: wire.Context,
payment_credential_params: CredentialParams,
stake_credential_params: CredentialParams,
payment_credential: Credential,
stake_credential: Credential,
is_change_output: bool = False,
) -> None:
await _show_credential(ctx, payment_credential_params, is_change_output)
await _show_credential(ctx, stake_credential_params, is_change_output)
await _show_credential(ctx, payment_credential, is_change_output)
await _show_credential(ctx, stake_credential, is_change_output)


async def _show_credential(
ctx: wire.Context,
credential_params: CredentialParams,
credential: Credential,
is_change_output: bool = False,
) -> None:
if is_change_output:
title = "Confirm transaction"
else:
title = "%s address" % ADDRESS_TYPE_NAMES[credential_params.address_type]
title = "%s address" % ADDRESS_TYPE_NAMES[credential.address_type]

props: list[PropertyType] = []

credential = credential_params.get_credential()
# Credential can be empty in case of enterprise address stake credential
# and reward address payment credential. In that case we don't want to
# show some of the "props".
if credential:
if credential.is_set():
if is_change_output:
address_usage = "Change address"
else:
address_usage = "Address"

credential_title = credential_params.get_title()
credential_title = credential.get_title()
props.append(
(
"%s %s credential is a %s:"
% (address_usage, credential_params.type_name, credential_title),
% (address_usage, credential.type_name, credential_title),
None,
)
)
props.extend(credential_params.format())
props.extend(credential.format())

if credential_params.is_unusual_path:
if credential.is_unusual_path:
props.append((None, "Path is unusual."))
if credential_params.is_mismatch:
if credential.is_mismatch:
props.append((None, "Credential doesn't match payment credential."))
if credential_params.is_reward:
if credential.is_reward:
props.append(("Address is a reward address.", None))
if credential_params.is_no_staking:
if credential.is_no_staking:
props.append(
(
"%s address - no staking rewards."
% ADDRESS_TYPE_NAMES[credential_params.address_type],
% ADDRESS_TYPE_NAMES[credential.address_type],
None,
)
)

if credential_params.should_warn():
if credential.should_warn():
icon = ui.ICON_WRONG
icon_color = ui.RED
else:
Expand Down
10 changes: 5 additions & 5 deletions core/src/apps/cardano/sign_tx.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@
protocol_magics,
)
from .helpers.account_path_check import AccountPathChecker
from .helpers.credential_params import CredentialParams, should_show_address_credentials
from .helpers.credential import Credential, should_show_address_credentials
from .helpers.hash_builder_collection import HashBuilderDict, HashBuilderList
from .helpers.paths import (
CERTIFICATE_PATH_NAME,
Expand Down Expand Up @@ -792,8 +792,8 @@ async def _show_output(

await show_credentials(
ctx,
CredentialParams.payment_params(address_parameters),
CredentialParams.stake_params(address_parameters),
Credential.payment_credential(address_parameters),
Credential.stake_credential(address_parameters),
is_change_output=True,
)

Expand Down Expand Up @@ -1044,8 +1044,8 @@ def _fail_if_strict_and_unusual(
if not safety_checks.is_strict():
return

if CredentialParams.payment_params(address_parameters).is_unusual_path:
if Credential.payment_credential(address_parameters).is_unusual_path:
raise wire.DataError("Invalid %s" % CHANGE_OUTPUT_PATH_NAME.lower())

if CredentialParams.stake_params(address_parameters).is_unusual_path:
if Credential.stake_credential(address_parameters).is_unusual_path:
raise wire.DataError("Invalid %s" % CHANGE_OUTPUT_STAKING_PATH_NAME.lower())
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from common import *

from apps.cardano.helpers.credential_params import CredentialParams
from apps.cardano.helpers.credential import Credential
from apps.common.paths import HARDENED
from trezor.enums import CardanoAddressType
from trezor.messages import CardanoAddressParametersType, CardanoBlockchainPointerType
Expand Down Expand Up @@ -239,26 +239,26 @@ def _create_flags(
]


def _get_flags(credential_params: CredentialParams) -> tuple[bool, ...]:
def _get_flags(credential: Credential) -> tuple[bool, ...]:
return (
credential_params.is_reward,
credential_params.is_no_staking,
credential_params.is_mismatch,
credential_params.is_unusual_path,
credential_params.is_other_warning,
credential.is_reward,
credential.is_no_staking,
credential.is_mismatch,
credential.is_unusual_path,
credential.is_other_warning,
)


@unittest.skipUnless(not utils.BITCOIN_ONLY, "altcoin")
class TestCardanoCredentialParams(unittest.TestCase):
def test_credential_params_flags(self):
class TestCardanoCredential(unittest.TestCase):
def test_credential_flags(self):
for (
address_parameters,
expected_payment_flags,
expected_stake_flags,
) in ADDRESS_PARAMETERS_CASES:
payment_credential = CredentialParams.payment_params(address_parameters)
stake_credential = CredentialParams.stake_params(address_parameters)
payment_credential = Credential.payment_credential(address_parameters)
stake_credential = Credential.stake_credential(address_parameters)
self.assertEqual(_get_flags(payment_credential), expected_payment_flags)
self.assertEqual(_get_flags(stake_credential), expected_stake_flags)

Expand Down

0 comments on commit 0027ea9

Please sign in to comment.