Skip to content

Commit

Permalink
fix typehints
Browse files Browse the repository at this point in the history
  • Loading branch information
daler-sz committed Jan 21, 2024
1 parent 3977a1f commit 6d14c75
Showing 1 changed file with 11 additions and 7 deletions.
18 changes: 11 additions & 7 deletions src/adaptix/_internal/morphing/enum_provider.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import math
from abc import ABC
from enum import Enum, EnumMeta, Flag
from typing import Any, Iterable, Mapping, Optional, Type, Union
from typing import Any, Iterable, Mapping, Optional, Sequence, Type, TypeVar, Union

from ..common import Dumper, Loader, TypeHint
from ..morphing.provider_template import DumperProvider, LoaderProvider
Expand All @@ -13,6 +13,8 @@
from .load_error import BadVariantError, MsgError, MultipleBadVariant, TypeLoadError, ValueLoadError
from .request_cls import DumperRequest, LoaderRequest

FlagT = TypeVar("FlagT", bound=Flag)


class AnyEnumLSC(LastLocMapChecker):
def _check_location(self, mediator: DirectMediator, loc: TypeHintLoc) -> bool:
Expand Down Expand Up @@ -158,7 +160,7 @@ def _provide_dumper(self, mediator: Mediator, request: DumperRequest) -> Dumper:
return _enum_exact_value_dumper


def _extract_non_compound_cases_from_flag(enum: Type[Flag]) -> Iterable[Flag]:
def _extract_non_compound_cases_from_flag(enum: Type[FlagT]) -> Iterable[FlagT]:
return [case for case in enum.__members__.values() if not math.log2(case.value) % 1]


Expand All @@ -178,18 +180,20 @@ def __init__(
self._loader = _enum_exact_value_loader if by_exact_value else _enum_name_loader
self._dumper = _enum_exact_value_dumper if by_exact_value else _enum_name_dumper

def _get_cases(self, enum: Type[Flag]):
def _get_cases(self, enum: Type[FlagT]) -> Iterable[FlagT]:
if self._allow_compound:
return enum.__members__.values()
return _extract_non_compound_cases_from_flag(enum)

def _get_loader_process_data(self, data: Union[int, Iterable[int], Iterable[str]], enum: Type[Flag]):
def _get_loader_process_data(
self, data: Union[int, Iterable[int], Iterable[str]], enum: Type[Flag]
) -> Union[Sequence[int], Sequence[str]]:
if isinstance(data, (str, int)):
if not self._allow_single_value:
raise TypeLoadError(expected_type=Iterable[str], input_value=data)
process_data = [data]
process_data: Union[Sequence[int], Sequence[str]] = [list] # type: ignore[assignment]
else:
process_data = list(data)
process_data = list(data) # type: ignore[assignment]

if not self._allow_duplicates:
if len(process_data) != len(set(process_data)):
Expand Down Expand Up @@ -229,7 +233,7 @@ def _flag_loader(data: Union[int, Iterable[int], Iterable[str]]) -> Flag:
def _provide_dumper(self, mediator: Mediator, request: DumperRequest) -> Dumper:
enum = get_type_from_request(request)

def flag_dumper(value: Flag) -> Iterable[str]:
def flag_dumper(value: Flag) -> Union[Iterable[int], Iterable[str]]:
cases = self._get_cases(enum)
if value in cases:
return [self._dumper(value)]
Expand Down

0 comments on commit 6d14c75

Please sign in to comment.