Skip to content

Commit

Permalink
Merge pull request #220 from daler-sz/flag-enums-support
Browse files Browse the repository at this point in the history
Added flag enums support
  • Loading branch information
zhPavel authored Jan 26, 2024
2 parents 8002be8 + f32cd58 commit 7ff21d7
Show file tree
Hide file tree
Showing 11 changed files with 633 additions and 61 deletions.
4 changes: 4 additions & 0 deletions docs/changelog/fragments/197.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
Added flags support

Now adaptix has two different ways to process flags: :func:`.flag_by_exact_value` (by default)
and :func:`.flag_by_member_names`.
13 changes: 12 additions & 1 deletion docs/loading-and-dumping/specific-types-behavior.rst
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,18 @@ timedelta
Loader accepts instance of ``int``, ``float`` or ``Decimal`` representing seconds,
dumper serialize value via ``total_seconds`` method.

Enum subclasses

Flag subclasses
'''''''''''''''''''''''

Flag members by default are represented by their value. Note that flags with skipped
bits and negative values are not supported, so it is highly recommended to define flag
values via ``enum.auto()`` instead of manually specifying them.
Besides, adaptix provides another way to process flags: by list using their names.
See: :func:`.flag_by_member_names` for details.


Other Enum subclasses
'''''''''''''''''''''''

Enum members are represented by their value without any conversion.
Expand Down
4 changes: 4 additions & 0 deletions src/adaptix/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
enum_by_exact_value,
enum_by_name,
enum_by_value,
flag_by_exact_value,
flag_by_member_names,
loader,
name_mapping,
validator,
Expand Down Expand Up @@ -59,6 +61,8 @@
'enum_by_exact_value',
'enum_by_name',
'enum_by_value',
'flag_by_exact_value',
'flag_by_member_names',
'name_mapping',
'default_dict',
'AdornedRetort',
Expand Down
229 changes: 215 additions & 14 deletions src/adaptix/_internal/morphing/enum_provider.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,107 @@
from abc import ABC
import collections
import math
from abc import ABC, abstractmethod
from enum import Enum, EnumMeta, Flag
from typing import Any, Mapping, Optional, Type
from functools import reduce
from operator import or_
from typing import Any, Iterable, Mapping, Optional, Sequence, Type, TypeVar, Union, final

from ..common import Dumper, Loader, TypeHint
from ..morphing.provider_template import DumperProvider, LoaderProvider
from ..name_style import NameStyle, convert_snake_style
from ..provider.essential import CannotProvide, Mediator
from ..provider.loc_stack_filtering import DirectMediator, LastLocMapChecker
from ..provider.provider_template import for_predicate
from ..provider.request_cls import LocMap, TypeHintLoc, get_type_from_request
from ..type_tools import normalize_type
from .load_error import BadVariantError, MsgError
from ..provider.request_cls import LocMap, StrictCoercionRequest, TypeHintLoc, get_type_from_request
from ..type_tools import is_subclass_soft, normalize_type
from .load_error import (
BadVariantError,
DuplicatedValues,
ExcludedTypeLoadError,
MsgError,
MultipleBadVariant,
OutOfRange,
TypeLoadError,
)
from .request_cls import DumperRequest, LoaderRequest

EnumT = TypeVar("EnumT", bound=Enum)
FlagT = TypeVar("FlagT", bound=Flag)
CollectionsMapping = collections.abc.Mapping


class BaseEnumMappingGenerator(ABC):
@abstractmethod
def _generate_mapping(self, cases: Iterable[EnumT]) -> Mapping[EnumT, str]:
...

@final
def generate_for_dumping(self, cases: Iterable[EnumT]) -> Mapping[EnumT, str]:
return self._generate_mapping(cases)

@final
def generate_for_loading(self, cases: Iterable[EnumT]) -> Mapping[str, EnumT]:
return {
mapping_result: case
for case, mapping_result in self._generate_mapping(cases).items()
}


class ByNameEnumMappingGenerator(BaseEnumMappingGenerator):
def __init__(
self,
name_style: Optional[NameStyle] = None,
map: Optional[Mapping[Union[str, Enum], str]] = None # noqa: A002
):
self._name_style = name_style
self._map = map if map is not None else {}

def _generate_mapping(self, cases: Iterable[EnumT]) -> Mapping[EnumT, str]:
result = {}

for case in cases:
if case in self._map:
mapped = self._map[case]
elif case.name in self._map:
mapped = self._map[case.name]
elif self._name_style:
mapped = convert_snake_style(case.name, self._name_style)
else:
mapped = case.name
result[case] = mapped

return result


class AnyEnumLSC(LastLocMapChecker):
def _check_location(self, mediator: DirectMediator, loc: TypeHintLoc) -> bool:
try:
norm = normalize_type(loc.type)
except ValueError:
return False
return isinstance(norm.origin, EnumMeta)
origin = norm.origin
return isinstance(origin, EnumMeta) and not is_subclass_soft(origin, Flag)


class FlagEnumLSC(LastLocMapChecker):
def _check_location(self, mediator: DirectMediator, loc: TypeHintLoc) -> bool:
try:
norm = normalize_type(loc.type)
except ValueError:
return False
return is_subclass_soft(norm.origin, Flag)


@for_predicate(AnyEnumLSC())
class BaseEnumProvider(LoaderProvider, DumperProvider, ABC):
pass


@for_predicate(FlagEnumLSC())
class BaseFlagProvider(LoaderProvider, DumperProvider, ABC):
pass


def _enum_name_dumper(data):
return data.name

Expand All @@ -36,14 +111,6 @@ class EnumNameProvider(BaseEnumProvider):

def _provide_loader(self, mediator: Mediator, request: LoaderRequest) -> Loader:
enum = get_type_from_request(request)

if issubclass(enum, Flag):
raise CannotProvide(
"Flag subclasses is not supported yet",
is_terminal=True,
is_demonstrative=True
)

variants = [case.name for case in enum]

def enum_loader(data):
Expand Down Expand Up @@ -155,3 +222,137 @@ def _get_exact_value_to_member(self, enum: Type[Enum]) -> Optional[Mapping[Any,

def _provide_dumper(self, mediator: Mediator, request: DumperRequest) -> Dumper:
return _enum_exact_value_dumper


class FlagByExactValueProvider(BaseFlagProvider):
def _provide_loader(self, mediator: Mediator, request: LoaderRequest) -> Loader:
enum = get_type_from_request(request)
flag_mask = reduce(or_, enum.__members__.values()).value

if flag_mask < 0:
raise CannotProvide(
"Cannot create a loader for flag with negative values",
is_terminal=True,
is_demonstrative=True,
)

all_bits = 2 ** flag_mask.bit_length() - 1
if all_bits != flag_mask:
raise CannotProvide(
"Cannot create a loader for flag with skipped bits",
is_terminal=True,
is_demonstrative=True,
)

def flag_loader(data):
if type(data) is not int: # pylint: disable=unidiomatic-typecheck
raise TypeLoadError(int, data)

if data < 0 or data > flag_mask:
raise OutOfRange(0, flag_mask, data)

# data already has been validated for all edge cases
# so enum lookup cannot raise an error

return enum(data)

return flag_loader

def _provide_dumper(self, mediator: Mediator, request: DumperRequest) -> Dumper:
return _enum_exact_value_dumper


def _extract_non_compound_cases_from_flag(enum: Type[FlagT]) -> Sequence[FlagT]:
return [case for case in enum.__members__.values() if not math.log2(case.value) % 1]


class FlagByListProvider(BaseFlagProvider):
def __init__(
self,
mapping_generator: BaseEnumMappingGenerator,
allow_single_value: bool = False,
allow_duplicates: bool = True,
allow_compound: bool = True,
):
self._mapping_generator = mapping_generator
self._allow_single_value = allow_single_value
self._allow_duplicates = allow_duplicates
self._allow_compound = allow_compound

def _get_cases(self, enum: Type[FlagT]) -> Sequence[FlagT]:
if self._allow_compound:
return list(enum.__members__.values())
return _extract_non_compound_cases_from_flag(enum)

def _provide_loader(self, mediator: Mediator, request: LoaderRequest) -> Loader: # noqa: CCR001
enum = get_type_from_request(request)

strict_coercion = mediator.mandatory_provide(StrictCoercionRequest(loc_stack=request.loc_stack))
allow_single_value = self._allow_single_value
allow_duplicates = self._allow_duplicates

cases = self._get_cases(enum)
mapping = self._mapping_generator.generate_for_loading(cases)
variants = list(mapping.keys())
zero_case = enum(0)

# treat str and Iterable[str] as different types
expected_type = Union[str, Iterable[str]] if allow_single_value else Iterable[str]

def flag_loader(data) -> Flag: # noqa: CCR001
data_type = type(data)

if isinstance(data, Iterable) and data_type is not str:
if strict_coercion and isinstance(data, CollectionsMapping):
raise ExcludedTypeLoadError(expected_type, Mapping, data)
process_data = tuple(data)
else:
if not allow_single_value or data_type is not str:
raise TypeLoadError(expected_type, data)
process_data = (data,)

if not allow_duplicates:
if len(process_data) != len(set(process_data)):
raise DuplicatedValues(data)

bad_variants = []
result = zero_case
for item in process_data:
if item not in variants:
bad_variants.append(item)
continue
result |= mapping[item]

if bad_variants:
raise MultipleBadVariant(
allowed_values=variants,
invalid_values=bad_variants,
input_value=data,
)

return result

return flag_loader

def _provide_dumper(self, mediator: Mediator, request: DumperRequest) -> Dumper:
enum = get_type_from_request(request)

cases = self._get_cases(enum)
need_to_reverse = self._allow_compound and cases != _extract_non_compound_cases_from_flag(enum)
if need_to_reverse:
cases = tuple(reversed(cases))

mapping = self._mapping_generator.generate_for_dumping(cases)

zero_case = enum(0)

def flag_dumper(value: Flag) -> Sequence[str]:
result = []
cases_sum = zero_case
for case in cases:
if case in value and case not in cases_sum:
cases_sum |= case
result.append(mapping[case])
return list(reversed(result)) if need_to_reverse else result

return flag_dumper
67 changes: 66 additions & 1 deletion src/adaptix/_internal/morphing/facade/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,14 @@
from ...special_cases_optimization import as_is_stub
from ...utils import Omittable, Omitted
from ..dict_provider import DefaultDictProvider
from ..enum_provider import EnumExactValueProvider, EnumNameProvider, EnumValueProvider
from ..enum_provider import (
ByNameEnumMappingGenerator,
EnumExactValueProvider,
EnumNameProvider,
EnumValueProvider,
FlagByExactValueProvider,
FlagByListProvider,
)
from ..load_error import LoadError, ValidationError
from ..model.loader_provider import InlinedShapeModelLoaderProvider
from ..name_layout.base import ExtraIn, ExtraOut
Expand Down Expand Up @@ -350,6 +357,64 @@ def enum_by_value(first_pred: EnumPred, /, *preds: EnumPred, tp: TypeHint) -> Pr
return _wrap_enum_provider([first_pred, *preds], EnumValueProvider(tp))


def flag_by_exact_value(*preds: EnumPred) -> Provider:
"""Provider that represents flag members to the outside world by their value without any processing.
It does not support flags with skipped bits and negative values (it is recommended to use ``enum.auto()``
to define flag values instead of manually specifying them).
:param preds: Predicates specifying where the provider should be used.
The provider will be applied if any predicates meet the conditions,
if no predicates are passed, the provider will be used for all Flags.
See :ref:`predicate-system` for details.
:return: desired provider
"""
return _wrap_enum_provider(preds, FlagByExactValueProvider())


def flag_by_member_names(
*preds: EnumPred,
allow_single_value: bool = False,
allow_duplicates: bool = True,
allow_compound: bool = True,
name_style: Optional[NameStyle] = None,
map: Optional[Mapping[Union[str, Enum], str]] = None # noqa: A002
) -> Provider:
"""Provider that represents flag members to the outside world by list of their names.
Loader takes a flag members name list and returns united flag member
(given members combined by operator ``|``, namely `bitwise or`).
Dumper takes a flag member and returns a list of names of flag members, included in the given flag member.
:param preds: Predicates specifying where the provider should be used.
The provider will be applied if any predicates meet the conditions,
if no predicates are passed, the provider will be used for all Flags.
See :ref:`predicate-system` for details.
:param allow_single_value: Allows calling the loader with a single value.
If this is allowed, singlular values are treated as one element list.
:param allow_duplicates: Allows calling the loader with a list containing non-unique elements.
Unless this is allowed, loader will raise :exc:`.DuplicatedValues` in that case.
:param allow_compound: Allows the loader to accept names of compound members
(e.g. ``WHITE = RED | GREEN | BLUE``) and the dumper to return names of compound members.
If this is allowed, dumper will use compound members names to serialize value.
:param name_style: Name style for representing members to the outside world.
If it is set, the provider will automatically convert the names of all flag members to the specified convention.
:param map: Mapping for representing members to the outside world.
If it is set, the provider will use it to rename members individually;
its keys can either be member names as strings or member instances.
:return: desired provider
"""
return _wrap_enum_provider(
preds,
FlagByListProvider(
ByNameEnumMappingGenerator(name_style=name_style, map=map),
allow_single_value=allow_single_value,
allow_duplicates=allow_duplicates,
allow_compound=allow_compound,
),
)


def validator(
pred: Pred,
func: Callable[[Any], bool],
Expand Down
Loading

0 comments on commit 7ff21d7

Please sign in to comment.