Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added flag enums support #220

Merged
merged 29 commits into from
Jan 26, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
0637ee8
added flag enums support
daler-sz Jan 20, 2024
b23f95f
add allow_single_value, allow_duplicates, allow_compound, by_exact_va…
daler-sz Jan 20, 2024
04e91c2
fix formatting
daler-sz Jan 20, 2024
1ce09cd
Merge remote-tracking branch 'origin/develop' into flag-enums-support
daler-sz Jan 20, 2024
6275bad
resolve merge conflict
daler-sz Jan 20, 2024
23e540d
fix type hint in flag loader
daler-sz Jan 21, 2024
d1fbc15
renamed exception MultipleBadVariant
daler-sz Jan 21, 2024
4f76d94
rename extract_non_compound_cases_from_flag function & inline it
daler-sz Jan 21, 2024
3977a1f
remove partial return in flag loader provider
daler-sz Jan 21, 2024
9807899
fix typehits
daler-sz Jan 21, 2024
372c56b
fix typehints and remove ignores in _get_loader_process_data of flag …
daler-sz Jan 21, 2024
7a80b37
add name mapping for flags
daler-sz Jan 21, 2024
9b3c6b7
flag tests refactoring
daler-sz Jan 22, 2024
2f515d3
rename EnumMappingGenerator
daler-sz Jan 22, 2024
95b53bd
facades for flag provider
daler-sz Jan 22, 2024
b791461
now input_value is the last field of MultipleBadVariant
daler-sz Jan 23, 2024
ca2f4fe
EnumMappingGenerators refactor
daler-sz Jan 23, 2024
a0fdafc
make convert_snake_style work not only with snake_cas
daler-sz Jan 23, 2024
8ae1701
flag facades refactoring
daler-sz Jan 23, 2024
117f3bc
flag provider refactoring
daler-sz Jan 23, 2024
0d77734
add FlagByExactValueProvider
daler-sz Jan 24, 2024
c3fe0bb
name style tests refactoring
daler-sz Jan 24, 2024
36784e3
Merge remote-tracking branch 'origin/develop' into flag-enums-support
daler-sz Jan 25, 2024
53411e0
fix type in TypeLoadError
daler-sz Jan 25, 2024
05354a1
remove flag alias support
daler-sz Jan 25, 2024
47b1421
refactoring && tests improve
daler-sz Jan 26, 2024
2f29f96
flag docs & changelog
daler-sz Jan 26, 2024
f66b269
small changes
daler-sz Jan 26, 2024
f32cd58
small changes
daler-sz Jan 26, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/adaptix/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
enum_by_exact_value,
enum_by_name,
enum_by_value,
flag_by_list_using_exact_value,
flag_by_exact_value,
flag_by_list_using_name,
loader,
name_mapping,
Expand Down Expand Up @@ -60,10 +60,10 @@
'enum_by_exact_value',
'enum_by_name',
'enum_by_value',
'flag_by_exact_value',
'flag_by_list_using_name',
'name_mapping',
'AdornedRetort',
'flag_by_list_using_name',
'flag_by_list_using_exact_value',
'FilledRetort',
'Retort',
'TypedDictAt38Warning',
Expand Down
226 changes: 134 additions & 92 deletions src/adaptix/_internal/morphing/enum_provider.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import math
from abc import ABC, abstractmethod
from enum import Enum, EnumMeta, Flag
from typing import Any, Hashable, Iterable, Mapping, Optional, Sequence, Type, TypeVar, Union, final

from typing_extensions import overload
from functools import reduce
from operator import or_
from typing import Any, Hashable, Iterable, List, Mapping, MutableMapping, Optional, Sequence, Type, TypeVar, Union

from ..common import Dumper, Loader, TypeHint
from ..morphing.provider_template import DumperProvider, LoaderProvider
Expand All @@ -12,8 +12,8 @@
from ..provider.loc_stack_filtering import DirectMediator, LastLocMapChecker
from ..provider.provider_template import for_predicate
from ..provider.request_cls import LocMap, TypeHintLoc, get_type_from_request
from ..type_tools import normalize_type
from .load_error import BadVariantError, MsgError, MultipleBadVariant, TypeLoadError, ValueLoadError
from ..type_tools import is_subclass_soft, normalize_type
from .load_error import BadVariantError, DuplicatedValues, MsgError, MultipleBadVariant, OutOfRange, TypeLoadError
from .request_cls import DumperRequest, LoaderRequest

EnumT = TypeVar("EnumT", bound=Enum)
Expand All @@ -22,50 +22,54 @@

class BaseEnumMappingGenerator(ABC):
@abstractmethod
def _generate_mapping(self, cases: Iterable[EnumT]) -> Mapping[EnumT, Hashable]:
def generate_for_dumping(self, cases: Mapping[str, EnumT]) -> Mapping[EnumT, Hashable]:
...

@final
def generate_for_dumping(self, cases: Iterable[EnumT]) -> Mapping[EnumT, Hashable]:
return self._generate_mapping(cases)

@final
def generate_for_loading(self, cases: Iterable[EnumT]) -> Mapping[Hashable, EnumT]:
return {
mapping_result: case
for case, mapping_result in self._generate_mapping(cases).items()
}
@abstractmethod
def generate_for_loading(self, cases: Mapping[str, EnumT]) -> Mapping[str, Hashable]:
...


class NameEnumMappingGenerator(BaseEnumMappingGenerator):
class ByNameEnumMappingGenerator(BaseEnumMappingGenerator):
def __init__(
self, name_style: Optional[NameStyle] = None,
self,
name_style: Optional[NameStyle] = None,
map: Optional[Mapping[Union[str, Enum], str]] = None # noqa: A002
):
self._name_style = name_style
self._map = map if map is not None else {}

def _generate_mapping(self, cases: Iterable[EnumT]) -> Mapping[EnumT, str]:
def generate_for_dumping(self, cases: Mapping[str, EnumT]) -> Mapping[EnumT, str]:
result = {}

for case in cases:
if not (case in self._map or case.name in self._map):
if self._name_style:
result[case] = convert_snake_style(case.name, self._name_style)
else:
result[case] = case.name
continue
try:
result[case] = self._map[case]
except KeyError:
result[case] = self._map[case.name]
for case in cases.values():
if case in self._map:
mapped = self._map[case]
elif case.name in self._map:
mapped = self._map[case.name]
elif self._name_style:
mapped = convert_snake_style(case.name, self._name_style)
else:
mapped = case.name
result[case] = mapped

return result

def generate_for_loading(self, cases: Mapping[str, EnumT]) -> Mapping[str, EnumT]:
result: MutableMapping[str, EnumT] = {}

for name, case in cases.items():
if case in self._map and case not in result.values():
mapped = self._map[case]
elif name in self._map:
mapped = self._map[name]
elif self._name_style:
mapped = convert_snake_style(name, self._name_style)
else:
mapped = name
result[mapped] = case

class ExactValueEnumMappingGenerator(BaseEnumMappingGenerator):
def _generate_mapping(self, cases: Iterable[EnumT]) -> Mapping[EnumT, Hashable]:
return {case: case.value for case in cases}
return result


class AnyEnumLSC(LastLocMapChecker):
Expand All @@ -77,6 +81,15 @@ def _check_location(self, mediator: DirectMediator, loc: TypeHintLoc) -> bool:
return isinstance(norm.origin, EnumMeta)


class FlagEnumLSC(LastLocMapChecker):
def _check_location(self, mediator: DirectMediator, loc: TypeHintLoc) -> bool:
try:
norm = normalize_type(loc.type)
except ValueError:
return False
return is_subclass_soft(norm.origin, Flag)


@for_predicate(AnyEnumLSC())
class BaseEnumProvider(LoaderProvider, DumperProvider, ABC):
pass
Expand All @@ -86,10 +99,6 @@ def _enum_name_dumper(data):
return data.name


def _enum_name_loader(enum, name):
return enum[name]


class EnumNameProvider(BaseEnumProvider):
"""This provider represents enum members to the outside world by their name"""

Expand Down Expand Up @@ -157,10 +166,6 @@ def _enum_exact_value_dumper(data):
return data.value


def _enum_exact_value_loader(enum, value):
return enum(value)


class EnumExactValueProvider(BaseEnumProvider):
"""This provider represents enum members to the outside world
by their value without any processing
Expand Down Expand Up @@ -212,11 +217,48 @@ def _provide_dumper(self, mediator: Mediator, request: DumperRequest) -> Dumper:
return _enum_exact_value_dumper


def _extract_non_compound_cases_from_flag(enum: Type[FlagT]) -> Iterable[FlagT]:
return [case for case in enum.__members__.values() if not math.log2(case.value) % 1]
@for_predicate(FlagEnumLSC())
class FlagByExactValueProvider(BaseEnumProvider):
def _provide_loader(self, mediator: Mediator, request: LoaderRequest) -> Loader:
enum = get_type_from_request(request)
flag_mask = reduce(or_, enum.__members__.values()).value

if flag_mask < 0:
raise CannotProvide(
"Cannot create a loader for flag with negative values",
is_terminal=True,
is_demonstrative=True,
)

all_bits = 2 ** flag_mask.bit_length() - 1
if all_bits != flag_mask:
raise CannotProvide(
"Cannot create a loader for flag with skipped bits",
is_terminal=True,
is_demonstrative=True,
)

def flag_loader(data):
if type(data) is not int: # pylint: disable=unidiomatic-typecheck
raise TypeLoadError(int, data)

if not 0 <= data <= flag_mask:
raise OutOfRange(0, flag_mask, data)

return enum(data)

return flag_loader

def _provide_dumper(self, mediator: Mediator, request: DumperRequest) -> Dumper:
return _enum_exact_value_dumper


def _extract_non_compound_cases_from_flag(enum: Type[FlagT]) -> Mapping[str, FlagT]:
return {name: case for name, case in enum.__members__.items() if not math.log2(case.value) % 1}


class FlagProvider(BaseEnumProvider):
@for_predicate(FlagEnumLSC())
class FlagByListProvider(BaseEnumProvider):
def __init__(
self,
mapping_generator: BaseEnumMappingGenerator,
Expand All @@ -229,78 +271,78 @@ def __init__(
self._allow_duplicates = allow_duplicates
self._allow_compound = allow_compound

def _get_cases(self, enum: Type[FlagT]) -> Iterable[FlagT]:
def _get_cases(self, enum: Type[FlagT]) -> Mapping[str, FlagT]:
if self._allow_compound:
return enum.__members__.values()
return enum.__members__
return _extract_non_compound_cases_from_flag(enum)

@overload
def _get_loader_process_data(self, data: Union[int, Iterable[int]], enum: Type[Flag]) -> Sequence[int]:
...

@overload
def _get_loader_process_data(self, data: Iterable[str], enum: Type[Flag]) -> Sequence[str]:
...

def _get_loader_process_data(self, data, enum):
if isinstance(data, (str, int)):
if not self._allow_single_value:
raise TypeLoadError(expected_type=Iterable[str], input_value=data)
process_data = [data]
else:
process_data = list(data)

if not self._allow_duplicates:
if len(process_data) != len(set(process_data)):
raise ValueLoadError(
msg=f"Duplicates in {enum} loader are not allowed",
input_value=process_data
)

return process_data

def _provide_loader(self, mediator: Mediator, request: LoaderRequest) -> Loader:
def _provide_loader(self, mediator: Mediator, request: LoaderRequest) -> Loader: # noqa: CCR001
enum = get_type_from_request(request)

if not issubclass(enum, Flag):
raise CannotProvide
allow_single_value = self._allow_single_value
allow_duplicates = self._allow_duplicates

cases = self._get_cases(enum)
mapping = self._mapping_generator.generate_for_loading(cases)
variants = list(mapping.keys())
zero_case = enum(0)

def flag_loader(data) -> Flag:
if isinstance(data, Iterable) and type(data) is not str: # pylint: disable=unidiomatic-typecheck
process_data = tuple(data)
else:
if not allow_single_value:
raise TypeLoadError(
expected_type=Union[Iterable[str], Iterable[int]],
input_value=data
)
process_data = (data,)

if not allow_duplicates:
if len(process_data) != len(set(process_data)):
raise DuplicatedValues(data)

def _flag_loader(data: Union[int, Iterable[int], Iterable[str]]) -> Flag:
process_data = self._get_loader_process_data(data, enum)
cases = self._get_cases(enum)
mapping = self._mapping_generator.generate_for_loading(cases)
variants = list(mapping.keys())
bad_variants = []
result = enum(0)
result = zero_case
for item in process_data:
if item not in variants:
bad_variants.append(item)
continue
result = result | mapping[item]
result |= mapping[item]

if bad_variants:
raise MultipleBadVariant(
allowed_values=variants,
input_value=process_data,
invalid_values=bad_variants
invalid_values=bad_variants,
input_value=data,
)

return result

return _flag_loader
return flag_loader

def _provide_dumper(self, mediator: Mediator, request: DumperRequest) -> Dumper:
enum = get_type_from_request(request)

if not issubclass(enum, Flag):
raise CannotProvide
cases = self._get_cases(enum)
need_to_reverse = self._allow_compound and cases != _extract_non_compound_cases_from_flag(enum)

def flag_dumper(value: Flag) -> Iterable[Hashable]:
cases = self._get_cases(enum)
mapping = self._mapping_generator.generate_for_dumping(cases)
mapping = self._mapping_generator.generate_for_dumping(cases)

if value in cases:
return [mapping[value]]
return [mapping[case] for case in cases if case in value]
if need_to_reverse:
cases_sequence = tuple(reversed(cases.values()))
else:
cases_sequence = tuple(cases.values())

zero_case = enum(0)

def flag_dumper(value: Flag) -> Sequence[Hashable]:
result: List[Hashable] = []
cases_sum = zero_case
for case in cases_sequence:
if case in value and case not in cases_sum:
cases_sum |= case
result.append(mapping[case])
return list(reversed(result)) if need_to_reverse else result

return flag_dumper
Loading