Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added flag enums support #220

Merged
merged 29 commits into from
Jan 26, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
0637ee8
added flag enums support
daler-sz Jan 20, 2024
b23f95f
add allow_single_value, allow_duplicates, allow_compound, by_exact_va…
daler-sz Jan 20, 2024
04e91c2
fix formatting
daler-sz Jan 20, 2024
1ce09cd
Merge remote-tracking branch 'origin/develop' into flag-enums-support
daler-sz Jan 20, 2024
6275bad
resolve merge conflict
daler-sz Jan 20, 2024
23e540d
fix type hint in flag loader
daler-sz Jan 21, 2024
d1fbc15
renamed exception MultipleBadVariant
daler-sz Jan 21, 2024
4f76d94
rename extract_non_compound_cases_from_flag function & inline it
daler-sz Jan 21, 2024
3977a1f
remove partial return in flag loader provider
daler-sz Jan 21, 2024
9807899
fix typehits
daler-sz Jan 21, 2024
372c56b
fix typehints and remove ignores in _get_loader_process_data of flag …
daler-sz Jan 21, 2024
7a80b37
add name mapping for flags
daler-sz Jan 21, 2024
9b3c6b7
flag tests refactoring
daler-sz Jan 22, 2024
2f515d3
rename EnumMappingGenerator
daler-sz Jan 22, 2024
95b53bd
facades for flag provider
daler-sz Jan 22, 2024
b791461
now input_value is the last field of MultipleBadVariant
daler-sz Jan 23, 2024
ca2f4fe
EnumMappingGenerators refactor
daler-sz Jan 23, 2024
a0fdafc
make convert_snake_style work not only with snake_cas
daler-sz Jan 23, 2024
8ae1701
flag facades refactoring
daler-sz Jan 23, 2024
117f3bc
flag provider refactoring
daler-sz Jan 23, 2024
0d77734
add FlagByExactValueProvider
daler-sz Jan 24, 2024
c3fe0bb
name style tests refactoring
daler-sz Jan 24, 2024
36784e3
Merge remote-tracking branch 'origin/develop' into flag-enums-support
daler-sz Jan 25, 2024
53411e0
fix type in TypeLoadError
daler-sz Jan 25, 2024
05354a1
remove flag alias support
daler-sz Jan 25, 2024
47b1421
refactoring && tests improve
daler-sz Jan 26, 2024
2f29f96
flag docs & changelog
daler-sz Jan 26, 2024
f66b269
small changes
daler-sz Jan 26, 2024
f32cd58
small changes
daler-sz Jan 26, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 95 additions & 11 deletions src/adaptix/_internal/morphing/enum_provider.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
import math
from abc import ABC
from enum import Enum, EnumMeta, Flag
from typing import Any, Mapping, Optional, Type
from functools import partial
from typing import Any, Iterable, Mapping, Optional, Type, Union

from ..common import Dumper, Loader, TypeHint
from ..morphing.provider_template import DumperProvider, LoaderProvider
from ..provider.essential import CannotProvide, Mediator
from ..provider.essential import Mediator
from ..provider.loc_stack_filtering import DirectMediator, LastLocMapChecker
from ..provider.provider_template import for_predicate
from ..provider.request_cls import LocMap, TypeHintLoc, get_type_from_request
from ..type_tools import normalize_type
from .load_error import BadVariantError, MsgError
from .load_error import BadVariantError, MsgError, MultipleBadVariantError, TypeLoadError, ValueLoadError
from .request_cls import DumperRequest, LoaderRequest


Expand All @@ -31,19 +33,15 @@ def _enum_name_dumper(data):
return data.name


def _enum_name_loader(enum, name):
return enum[name]


class EnumNameProvider(BaseEnumProvider):
"""This provider represents enum members to the outside world by their name"""

def _provide_loader(self, mediator: Mediator, request: LoaderRequest) -> Loader:
enum = get_type_from_request(request)

if issubclass(enum, Flag):
raise CannotProvide(
"Flag subclasses is not supported yet",
is_terminal=True,
is_demonstrative=True
)

variants = [case.name for case in enum]

def enum_loader(data):
Expand Down Expand Up @@ -106,6 +104,10 @@ def _enum_exact_value_dumper(data):
return data.value


def _enum_exact_value_loader(enum, value):
return enum(value)


class EnumExactValueProvider(BaseEnumProvider):
"""This provider represents enum members to the outside world
by their value without any processing
Expand Down Expand Up @@ -155,3 +157,85 @@ def _get_exact_value_to_member(self, enum: Type[Enum]) -> Optional[Mapping[Any,

def _provide_dumper(self, mediator: Mediator, request: DumperRequest) -> Dumper:
return _enum_exact_value_dumper


def _extract_common_cases_from_flag(enum: Type[Flag]) -> Iterable[Flag]:
cases = enum.__members__.values()
result = []
for case in cases:
if not math.log2(case.value) % 1:
result.append(case)
return result


class FlagProvider(BaseEnumProvider):
def __init__(
self,
allow_single_value: bool = False,
allow_duplicates: bool = True,
allow_compound: bool = True,
by_exact_value: bool = False
):
self._allow_single_value = allow_single_value
self._allow_duplicates = allow_duplicates
self._allow_compound = allow_compound
self._by_exact_value = by_exact_value

self._loader = _enum_exact_value_loader if by_exact_value else _enum_name_loader
self._dumper = _enum_exact_value_dumper if by_exact_value else _enum_name_dumper

def _get_cases(self, enum: Type[Flag]):
if self._allow_compound:
return enum.__members__.values()
return _extract_common_cases_from_flag(enum)

def _flag_loader(self, data: Union[int, Iterable[Union[int, str]]], enum: Type[Flag]) -> Flag:
if isinstance(data, (str, int)):
if not self._allow_single_value:
raise TypeLoadError(
expected_type=Iterable[str],
input_value=data,
)
process_data = [data]
else:
process_data = list(data)

if not self._allow_duplicates:
if len(process_data) != len(set(process_data)):
raise ValueLoadError(
msg=f"Duplicates in {enum} loader are not allowed",
input_value=process_data
)

variants = [self._dumper(case) for case in self._get_cases(enum)]
bad_variants = []
result = enum(0)
for item in process_data:
if item not in variants:
bad_variants.append(item)
continue
result = result | self._loader(enum, item)

if bad_variants:
raise MultipleBadVariantError(
allowed_values=variants,
input_values=process_data,
invalid_values=bad_variants
)

return result

def _provide_loader(self, mediator: Mediator, request: LoaderRequest) -> Loader:
enum = get_type_from_request(request)
return partial(self._flag_loader, enum=enum)

def _provide_dumper(self, mediator: Mediator, request: DumperRequest) -> Dumper:
enum = get_type_from_request(request)

def flag_dumper(value: Flag) -> Iterable[str]:
cases = self._get_cases(enum)
if value in cases:
return [self._dumper(value)]
return [self._dumper(case) for case in cases if case in value]

return flag_dumper
3 changes: 2 additions & 1 deletion src/adaptix/_internal/morphing/facade/retort.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
)
from ..constant_length_tuple_provider import ConstantLengthTupleProvider
from ..dict_provider import DictProvider
from ..enum_provider import EnumExactValueProvider
from ..enum_provider import EnumExactValueProvider, FlagProvider
from ..generic_provider import (
LiteralProvider,
NewTypeUnwrappingProvider,
Expand Down Expand Up @@ -80,6 +80,7 @@ class FilledRetort(OperatingRetort, ABC):
SecondsTimedeltaProvider(),

EnumExactValueProvider(), # it has higher priority than scalar types for Enum with mixins
FlagProvider(),

INT_LOADER_PROVIDER,
as_is_dumper(int),
Expand Down
8 changes: 8 additions & 0 deletions src/adaptix/_internal/morphing/load_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,14 @@ class BadVariantError(LoadError):
input_value: Any


@custom_exception
@dataclass(eq=False)
class MultipleBadVariantError(LoadError):
allowed_values: Iterable[Any]
input_values: Iterable[Any]
invalid_values: Iterable[Any]


@custom_exception
@dataclass(eq=False)
class DatetimeFormatMismatch(LoadError):
Expand Down
191 changes: 189 additions & 2 deletions tests/unit/morphing/test_enum_provider.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from enum import Enum, IntEnum
from enum import Enum, Flag, IntEnum, auto
from typing import Iterable

import pytest
from tests_helpers import TestRetort, raises_exc

from adaptix import dumper, enum_by_value, loader
from adaptix._internal.morphing.enum_provider import EnumExactValueProvider, EnumNameProvider
from adaptix._internal.morphing.enum_provider import EnumExactValueProvider, EnumNameProvider, FlagProvider
from adaptix._internal.morphing.load_error import MultipleBadVariantError, TypeLoadError, ValueLoadError
from adaptix.load_error import BadVariantError, MsgError


Expand All @@ -24,6 +26,14 @@ def _missing_(cls, value: object) -> 'MyEnumWithMissingHook':
raise ValueError


class FlagEnum(Flag):
V1 = auto()
V2 = auto()
V3 = auto()
V5 = V2 | V3
V6 = V1 | V2 | V3


def test_name_provider(strict_coercion, debug_trail):
retort = TestRetort(
strict_coercion=strict_coercion,
Expand Down Expand Up @@ -152,3 +162,180 @@ def test_value_provider(strict_coercion, debug_trail):
enum_dumper = retort.get_dumper(MyEnum)

assert enum_dumper(MyEnum.V1) == "PREFIX 1"


def test_flag_enum_loader(strict_coercion, debug_trail):
retort = TestRetort(
strict_coercion=strict_coercion,
debug_trail=debug_trail,
recipe=[
FlagProvider(),
]
)

loader = retort.get_loader(FlagEnum)
assert loader(["V1"]) == FlagEnum.V1
assert loader(["V1", "V2", "V3"]) == FlagEnum.V6
assert loader(["V6"]) == FlagEnum.V6
assert loader(["V2", "V3"]) == FlagEnum.V5
assert loader(["V1", "V2"]) == FlagEnum.V1 | FlagEnum.V2
assert loader(["V1", "V1"]) == FlagEnum.V1

variants = ["V1", "V2", "V3", "V5", "V6"]
raises_exc(
MultipleBadVariantError(
allowed_values=variants,
input_values=["V7", "V8", "V6"],
invalid_values=["V7", "V8"]
),
lambda: loader(["V7", "V8", "V6"])
)
raises_exc(
TypeLoadError(
expected_type=Iterable[str],
input_value="V1"
),
lambda: loader("V1")
)


def test_flag_enum_dumper(strict_coercion, debug_trail):
retort = TestRetort(
strict_coercion=strict_coercion,
debug_trail=debug_trail,
recipe=[
FlagProvider(),
]
)

dumper = retort.get_dumper(FlagEnum)
assert dumper(FlagEnum.V1) == ["V1"]
assert dumper(FlagEnum.V1 | FlagEnum.V2 | FlagEnum.V3) == ["V6"]
assert dumper(FlagEnum.V1 | FlagEnum.V2) == ["V1", "V2"]
assert dumper(FlagEnum.V1 & FlagEnum.V2) == []
assert dumper(FlagEnum.V2 & FlagEnum.V5) == ["V2"]
assert dumper(~FlagEnum.V2) == ["V1", "V3"]


def test_flag_enum_loader_by_exact_value(strict_coercion, debug_trail):
retort = TestRetort(
strict_coercion=strict_coercion,
debug_trail=debug_trail,
recipe=[
FlagProvider(by_exact_value=True),
]
)

loader = retort.get_loader(FlagEnum)
assert loader([1]) == FlagEnum.V1
assert loader([1, 2, 4]) == FlagEnum.V6
assert loader([7]) == FlagEnum.V6
assert loader([2, 4]) == FlagEnum.V5
assert loader([1, 2]) == FlagEnum.V1 | FlagEnum.V2
variants = [1, 2, 4, 6, 7]
raises_exc(
MultipleBadVariantError(
allowed_values=variants,
input_values=[3, 5, 7],
invalid_values=[3, 5]
),
lambda: loader([3, 5, 7])
)
raises_exc(
TypeLoadError(
expected_type=Iterable[str],
input_value=1
),
lambda: loader(1)
)


def test_flag_enum_dumper_by_exact_value(strict_coercion, debug_trail):
retort = TestRetort(
strict_coercion=strict_coercion,
debug_trail=debug_trail,
recipe=[
FlagProvider(by_exact_value=True),
]
)

dumper = retort.get_dumper(FlagEnum)
assert dumper(FlagEnum.V1) == [1]
assert dumper(FlagEnum.V1 | FlagEnum.V2 | FlagEnum.V3) == [7]
assert dumper(FlagEnum.V1 | FlagEnum.V2) == [1, 2]
assert dumper(FlagEnum.V1 & FlagEnum.V2) == []
assert dumper(FlagEnum.V2 & FlagEnum.V5) == [2]
assert dumper(~FlagEnum.V2) == [1, 4]


def test_flag_enum_loader_with_disallowed_compounds(strict_coercion, debug_trail):
retort = TestRetort(
strict_coercion=strict_coercion,
debug_trail=debug_trail,
recipe=[
FlagProvider(allow_compound=False),
]
)

loader = retort.get_loader(FlagEnum)
assert loader(["V1"]) == FlagEnum.V1
assert loader(["V1", "V2", "V3"]) == FlagEnum.V6
assert loader(["V1", "V2"]) == FlagEnum.V1 | FlagEnum.V2

variants = ["V1", "V2", "V3"]
raises_exc(
MultipleBadVariantError(
allowed_values=variants,
input_values=["V1", "V6"],
invalid_values=["V6"]
),
lambda: loader(["V1", "V6"])
)


def test_flag_enum_dumper_with_disallowed_compounds(strict_coercion, debug_trail):
retort = TestRetort(
strict_coercion=strict_coercion,
debug_trail=debug_trail,
recipe=[
FlagProvider(allow_compound=False),
]
)

dumper = retort.get_dumper(FlagEnum)
assert dumper(FlagEnum.V1) == ["V1"]
assert dumper(FlagEnum.V1 | FlagEnum.V2 | FlagEnum.V3) == ["V1", "V2", "V3"]
assert dumper(FlagEnum.V1 | FlagEnum.V2) == ["V1", "V2"]
assert dumper(FlagEnum.V1 & FlagEnum.V2) == []
assert dumper(FlagEnum.V2 & FlagEnum.V5) == ["V2"]
assert dumper(~FlagEnum.V2) == ["V1", "V3"]


def test_flag_enum_loader_with_allowed_single_value(strict_coercion, debug_trail):
retort = TestRetort(
strict_coercion=strict_coercion,
debug_trail=debug_trail,
recipe=[
FlagProvider(allow_single_value=True),
]
)

loader = retort.get_loader(FlagEnum)
assert loader("V1") == FlagEnum.V1
assert loader("V5") == FlagEnum.V5


def test_flag_enum_loader_with_disallowed_duplicates(strict_coercion, debug_trail):
retort = TestRetort(
strict_coercion=strict_coercion,
debug_trail=debug_trail,
recipe=[
FlagProvider(allow_duplicates=False),
]
)

loader = retort.get_loader(FlagEnum)
raises_exc(
ValueLoadError(f"Duplicates in {FlagEnum} loader are not allowed", ["V1", "V1"]),
lambda: loader(["V1", "V1"])
)