diff --git a/src/adaptix/_internal/morphing/enum_provider.py b/src/adaptix/_internal/morphing/enum_provider.py index 6b8a7054..4fb1ad5e 100644 --- a/src/adaptix/_internal/morphing/enum_provider.py +++ b/src/adaptix/_internal/morphing/enum_provider.py @@ -1,7 +1,7 @@ import math from abc import ABC from enum import Enum, EnumMeta, Flag -from typing import Any, Iterable, Mapping, Optional, Type, Union +from typing import Any, Iterable, Mapping, Optional, Sequence, Type, TypeVar, Union from ..common import Dumper, Loader, TypeHint from ..morphing.provider_template import DumperProvider, LoaderProvider @@ -13,6 +13,8 @@ from .load_error import BadVariantError, MsgError, MultipleBadVariant, TypeLoadError, ValueLoadError from .request_cls import DumperRequest, LoaderRequest +FlagT = TypeVar("FlagT", bound=Flag) + class AnyEnumLSC(LastLocMapChecker): def _check_location(self, mediator: DirectMediator, loc: TypeHintLoc) -> bool: @@ -158,7 +160,7 @@ def _provide_dumper(self, mediator: Mediator, request: DumperRequest) -> Dumper: return _enum_exact_value_dumper -def _extract_non_compound_cases_from_flag(enum: Type[Flag]) -> Iterable[Flag]: +def _extract_non_compound_cases_from_flag(enum: Type[FlagT]) -> Iterable[FlagT]: return [case for case in enum.__members__.values() if not math.log2(case.value) % 1] @@ -178,18 +180,20 @@ def __init__( self._loader = _enum_exact_value_loader if by_exact_value else _enum_name_loader self._dumper = _enum_exact_value_dumper if by_exact_value else _enum_name_dumper - def _get_cases(self, enum: Type[Flag]): + def _get_cases(self, enum: Type[FlagT]) -> Iterable[FlagT]: if self._allow_compound: return enum.__members__.values() return _extract_non_compound_cases_from_flag(enum) - def _get_loader_process_data(self, data: Union[int, Iterable[int], Iterable[str]], enum: Type[Flag]): + def _get_loader_process_data( + self, data: Union[int, Iterable[int], Iterable[str]], enum: Type[Flag] + ) -> Union[Sequence[int], Sequence[str]]: if isinstance(data, (str, int)): if not self._allow_single_value: raise TypeLoadError(expected_type=Iterable[str], input_value=data) - process_data = [data] + process_data: Union[Sequence[int], Sequence[str]] = [data] # type: ignore[assignment] else: - process_data = list(data) + process_data = list(data) # type: ignore[assignment] if not self._allow_duplicates: if len(process_data) != len(set(process_data)): @@ -229,7 +233,7 @@ def _flag_loader(data: Union[int, Iterable[int], Iterable[str]]) -> Flag: def _provide_dumper(self, mediator: Mediator, request: DumperRequest) -> Dumper: enum = get_type_from_request(request) - def flag_dumper(value: Flag) -> Iterable[str]: + def flag_dumper(value: Flag) -> Union[Iterable[int], Iterable[str]]: cases = self._get_cases(enum) if value in cases: return [self._dumper(value)]