Skip to content

Commit

Permalink
feat: move default instance definition to class attribute (#12)
Browse files Browse the repository at this point in the history
  • Loading branch information
ovsds authored Feb 7, 2024
1 parent d2d4a1e commit 88e51e0
Show file tree
Hide file tree
Showing 13 changed files with 88 additions and 62 deletions.
6 changes: 6 additions & 0 deletions secret_transfer/core/base/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import typing

import typing_extensions

import secret_transfer.utils.pydantic as pydantic_utils
import secret_transfer.utils.types as utils_types

Expand All @@ -14,3 +16,7 @@ def parse_init_arguments(cls, **arguments: utils_types.BaseArgumentType) -> typi

model = cls._arguments_model.model_validate(arguments)
return model.shallow_model_dump()

@classmethod
def get_default_instances(cls) -> typing.Mapping[str, typing_extensions.Self]:
return {}
5 changes: 4 additions & 1 deletion secret_transfer/core/base/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class BaseRegistry(type, typing.Generic[protocol.ResourceType], metaclass=Regist
default_class: typing.Optional[type[protocol.ResourceType]]

def __init__(
cls,
cls: type[protocol.ResourceType],
name: str,
bases: tuple[type, ...],
attrs: dict[str, typing.Any],
Expand All @@ -29,6 +29,9 @@ def __init__(
if "__default__" in attrs and attrs["__default__"]:
cls.register_default_class(cls) # pyright: ignore[reportGeneralTypeIssues]

for name, instance in cls.get_default_instances().items():
cls.register_instance(name, instance) # pyright: ignore[reportGeneralTypeIssues]

@classmethod
def register_class(cls, name: str, class_: type[protocol.ResourceType]) -> None:
if name in cls.classes:
Expand Down
11 changes: 0 additions & 11 deletions secret_transfer/default/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
BashExportDestination,
EnvDestination,
GithubCliSecretsDestination,
register_default_destination_instances,
)
from .source import (
DotEnvSource,
Expand All @@ -12,19 +11,9 @@
UserInputSource,
VaultCLIKVSource,
YCCLILockboxSource,
register_default_source_instances,
)
from .transfer import DefaultTransfer


def register_default_instances() -> None:
register_default_source_instances()
register_default_destination_instances()


register_default_instances()


__all__ = [
"BashExportDestination",
"DefaultCollection",
Expand Down
10 changes: 1 addition & 9 deletions secret_transfer/default/destination/__init__.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,11 @@
from .bash_export import (
BashExportDestination,
register_bash_nexport_destinatio_instance,
)
from .env import EnvDestination, register_env_destination_instance
from .env import EnvDestination
from .gh_cli_secrets import GithubCliSecretsDestination


def register_default_destination_instances() -> None:
register_env_destination_instance()
register_bash_nexport_destinatio_instance()


__all__ = [
"BashExportDestination",
"EnvDestination",
"GithubCliSecretsDestination",
"register_default_destination_instances",
]
10 changes: 7 additions & 3 deletions secret_transfer/default/destination/bash_export.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
import typing

import typing_extensions

import secret_transfer.core as core
import secret_transfer.utils.types as utils_types

Expand All @@ -9,6 +13,6 @@ def set(self, key: str, value: utils_types.BaseArgumentType) -> None:
def clean(self, key: str) -> None:
print(f"unset {key}")


def register_bash_nexport_destinatio_instance():
BashExportDestination().register("bash_export")
@classmethod
def get_default_instances(cls) -> typing.Mapping[str, typing_extensions.Self]:
return {"bash_export": cls()}
6 changes: 3 additions & 3 deletions secret_transfer/default/destination/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,6 @@ def clean(self, key: str) -> None:
except KeyError:
pass


def register_env_destination_instance():
EnvDestination().register("env")
@classmethod
def get_default_instances(cls) -> dict[str, "EnvDestination"]:
return {"env": cls()}
10 changes: 2 additions & 8 deletions secret_transfer/default/source/__init__.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,10 @@
from .dot_env import DotEnvSource
from .env import EnvSource, register_env_source_instance
from .env import EnvSource
from .preset import PresetSource
from .user_input import UserInputSource, register_user_input_source_instance
from .user_input import UserInputSource
from .vault_cli_kv import VaultCLIKVSource
from .yc_cli_lockbox import YCCLILockboxSource


def register_default_source_instances() -> None:
register_env_source_instance()
register_user_input_source_instance()


__all__ = [
"DotEnvSource",
"EnvSource",
Expand Down
6 changes: 3 additions & 3 deletions secret_transfer/default/source/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,6 @@ def __getitem__(self, key: str) -> utils_types.LiteralArgumentType:
except KeyError as exc:
raise self.KeyNotFoundError(f"Key {key} is not found in {self.__class__.__name__}") from exc


def register_env_source_instance():
EnvSource().register("env")
@classmethod
def get_default_instances(cls) -> dict[str, "EnvSource"]:
return {"env": cls()}
6 changes: 3 additions & 3 deletions secret_transfer/default/source/user_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,6 @@ class UserInputSource(core.AbstractSource):
def __getitem__(self, key: str) -> utils_types.LiteralArgumentType:
return getpass.getpass(prompt=f"Please provide a value for {key}: ")


def register_user_input_source_instance():
UserInputSource().register("user_input")
@classmethod
def get_default_instances(cls) -> dict[str, "UserInputSource"]:
return {"user_input": cls()}
6 changes: 6 additions & 0 deletions secret_transfer/protocol/base.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
import typing

import typing_extensions


@typing.runtime_checkable
class BaseResourceProtocol(typing.Protocol):
@classmethod
def parse_init_arguments(cls, **arguments: typing.Any) -> typing.Mapping[str, typing.Any]:
...

@classmethod
def get_default_instances(cls) -> typing.Mapping[str, typing_extensions.Self]:
...


ResourceType = typing.TypeVar("ResourceType", bound=BaseResourceProtocol)
8 changes: 4 additions & 4 deletions secret_transfer/utils/cli/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ class RunError(Exception):
def __str__(self) -> str:
return (
f"Command '{self.command}' failed with exit code {self.exit_code}\n"
f"STDOUT:\n{self.stdout}\n"
f"STDERR:\n{self.stderr}\n"
f"STDOUT: {self.stdout}\n"
f"STDERR: {self.stderr}"
)


Expand All @@ -23,8 +23,8 @@ def run(command: str, encoding: str = "utf-8") -> str:
except subprocess.CalledProcessError as exc:
raise RunError(
exit_code=exc.returncode,
stdout=exc.stdout.decode(encoding),
stderr=exc.stderr.decode(encoding),
stdout=exc.stdout.decode(encoding).strip("\n"),
stderr=exc.stderr.decode(encoding).strip("\n"),
command=command,
) from exc

Expand Down
5 changes: 3 additions & 2 deletions tests/integration/utils/cli/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ def test_run_utils_exit():
cli_utils.run(command)

assert exc.value.exit_code == 42
assert exc.value.stdout == "test\n"
assert exc.value.stderr == "test error\n"
assert exc.value.stdout == "test"
assert exc.value.stderr == "test error"
assert exc.value.command == command
assert str(exc.value) == f"Command '{command}' failed with exit code 42\n" "STDOUT: test\n" "STDERR: test error"
61 changes: 46 additions & 15 deletions tests/unit/core/base/test_registry.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import typing

import pytest
import typing_extensions

import secret_transfer.core.base as core_base

Expand All @@ -13,6 +14,12 @@ class ResourceTestProtocol(typing.Protocol):
RegistryType = type[Registry]


class ClassBase:
@classmethod
def get_default_instances(cls) -> dict[str, typing_extensions.Self]:
return {}


@pytest.fixture(name="registry1")
def fixture_registry1() -> RegistryType:
class Registry1(Registry):
Expand All @@ -30,10 +37,10 @@ class Registry2(Registry):


def test_multiple_registry_have_different_defaults(registry1: RegistryType, registry2: RegistryType):
class Class1(metaclass=registry1):
class Class1(ClassBase, metaclass=registry1):
__register__ = True

class Class2(metaclass=registry2):
class Class2(ClassBase, metaclass=registry2):
__register__ = True

assert Class1.__name__ in registry1.classes
Expand All @@ -43,40 +50,40 @@ class Class2(metaclass=registry2):


def test_explicit_class_registration(registry1: RegistryType):
class Class1(metaclass=registry1):
class Class1(ClassBase, metaclass=registry1):
__register__ = True

assert Class1.__name__ in registry1.classes
assert registry1.classes[Class1.__name__] is Class1


def test_implicit_class_registration(registry1: RegistryType):
class Class1(metaclass=registry1):
class Class1(ClassBase, metaclass=registry1):
...

assert Class1.__name__ in registry1.classes
assert registry1.classes[Class1.__name__] is Class1


def test_explicit_class_non_registration(registry1: RegistryType):
class Class1(metaclass=registry1):
class Class1(ClassBase, metaclass=registry1):
__register__ = False

assert Class1.__name__ not in registry1.classes


def test_class_name_collision(registry1: RegistryType):
class Class1(metaclass=registry1): # pyright: ignore[reportUnusedClass, reportGeneralTypeIssues]
class Class1(ClassBase, metaclass=registry1): # pyright: ignore[reportUnusedClass, reportGeneralTypeIssues]
__register__ = True

with pytest.raises(ValueError):

class Class1(metaclass=registry1): # noqa: F811 # pyright: ignore[reportUnusedClass]
class Class1(ClassBase, metaclass=registry1): # noqa: F811 # pyright: ignore[reportUnusedClass]
__register__ = True


def test_instance_registration(registry1: RegistryType):
class Class1(metaclass=registry1):
class Class1(ClassBase, metaclass=registry1):
...

name = "test_name"
Expand All @@ -89,7 +96,7 @@ class Class1(metaclass=registry1):


def test_instance_name_collision(registry1: RegistryType):
class Class1(metaclass=registry1):
class Class1(ClassBase, metaclass=registry1):
...

name = "test_name"
Expand All @@ -103,7 +110,7 @@ class Class1(metaclass=registry1):


def test_default_class_registration(registry1: RegistryType):
class Class1(metaclass=registry1):
class Class1(ClassBase, metaclass=registry1):
...

registry1.register_default_class(Class1)
Expand All @@ -112,17 +119,17 @@ class Class1(metaclass=registry1):


def test_default_class_registration_using_attribute(registry1: RegistryType):
class Class1(metaclass=registry1):
class Class1(ClassBase, metaclass=registry1):
__default__ = True

assert registry1.default_class is Class1


def test_default_class_collision(registry1: RegistryType):
class Class1(metaclass=registry1):
class Class1(ClassBase, metaclass=registry1):
...

class Class2(metaclass=registry1):
class Class2(ClassBase, metaclass=registry1):
...

registry1.register_default_class(Class1)
Expand All @@ -131,14 +138,38 @@ class Class2(metaclass=registry1):


def test_default_class_collision_force(registry1: RegistryType):
class Class1(metaclass=registry1):
class Class1(ClassBase, metaclass=registry1):
...

class Class2(metaclass=registry1):
class Class2(ClassBase, metaclass=registry1):
...

registry1.register_default_class(Class1)

registry1.register_default_class(Class2, force=True)

assert registry1.default_class is Class2


def test_default_instance_registration(registry1: RegistryType):
class Class1(metaclass=registry1):
@classmethod
def get_default_instances(cls) -> dict[str, ResourceTestProtocol]:
return {"test_name": cls()}

assert "test_name" in registry1.instances
assert isinstance(registry1.instances["test_name"], Class1)


def test_default_instance_collision(registry1: RegistryType):
class Class1(metaclass=registry1): # pyright: ignore[reportUnusedClass]
@classmethod
def get_default_instances(cls) -> dict[str, ResourceTestProtocol]:
return {"test_name": cls()}

with pytest.raises(ValueError):

class Class2(metaclass=registry1): # pyright: ignore[reportUnusedClass]
@classmethod
def get_default_instances(cls) -> dict[str, ResourceTestProtocol]:
return {"test_name": cls()}

0 comments on commit 88e51e0

Please sign in to comment.