Skip to content

Commit

Permalink
fix: ASyncGenericSingleton when flag defined on class def (#103)
Browse files Browse the repository at this point in the history
  • Loading branch information
BobTheBuidler authored Jan 17, 2024
1 parent fafe58f commit 2421f0d
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 14 deletions.
55 changes: 41 additions & 14 deletions a_sync/base.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@

import inspect
import logging
from contextlib import suppress
from functools import cached_property
from typing import Optional
from typing import Any, Dict, Optional

from a_sync import _flags, exceptions
from a_sync.abstract import ASyncABC
Expand Down Expand Up @@ -47,30 +48,36 @@ def __a_sync_flag_value__(self) -> bool:

@classmethod # type: ignore [misc]
def __a_sync_default_mode__(cls) -> bool:
flag = cls.__get_a_sync_flag_name_from_signature()
flag_value = cls.__a_sync_flag_default_value_from_signature()
try:
flag = cls.__get_a_sync_flag_name_from_signature()
flag_value = cls.__a_sync_flag_default_value_from_signature()
except exceptions.NoFlagsFound:
flag = cls.__get_a_sync_flag_name_from_class_def()
flag_value = cls.__get_a_sync_flag_value_from_class_def(flag)
sync = _flags.negate_if_necessary(flag, flag_value) # type: ignore [arg-type]
logger.debug("`%s.%s` indicates default mode is %ssynchronous", cls, flag, 'a' if sync is False else '')
return sync

@classmethod
def __get_a_sync_flag_name_from_signature(cls) -> Optional[str]:
logger.debug("Searching for flags defined on %s", cls)
logger.debug("Searching for flags defined on %s.__init__", cls)
if cls.__name__ == "ASyncGenericBase":
logger.debug("There are no flags defined on the base class, this is expected. Skipping.")
return None
parameters = inspect.signature(cls.__init__).parameters
logger.debug("parameters: %s", parameters)
present_flags = [flag for flag in _flags.VIABLE_FLAGS if flag in parameters]
if len(present_flags) == 0:
logger.debug("There are too many flags defined on %s", cls)
raise exceptions.NoFlagsFound(cls, parameters.keys())
if len(present_flags) > 1:
logger.debug("There are too many flags defined on %s", cls)
raise exceptions.TooManyFlags(cls, present_flags)
flag = present_flags[0]
logger.debug("found flag %s", flag)
return flag
return cls.__parse_flag_name_from_list(parameters)

@classmethod
def __get_a_sync_flag_name_from_class_def(cls) -> Optional[str]:
logger.debug("Searching for flags defined on %s", cls)
try:
return cls.__parse_flag_name_from_list(cls.__dict__)
except exceptions.NoFlagsFound:
for base in cls.__bases__:
with suppress(exceptions.NoFlagsFound):
return cls.__parse_flag_name_from_list(base.__dict__)
raise exceptions.NoFlagsFound(cls, list(cls.__dict__.keys()))

@classmethod # type: ignore [misc]
def __a_sync_flag_default_value_from_signature(cls) -> bool:
Expand All @@ -84,3 +91,23 @@ def __a_sync_flag_default_value_from_signature(cls) -> bool:
)
logger.debug('%s defines %s, default value %s', cls, flag, flag_value)
return flag_value

@classmethod
def __get_a_sync_flag_value_from_class_def(cls, flag: str) -> Optional[bool]:
for spec in [cls, *cls.__bases__]:
flag_value = spec.__dict__.get(flag)
if flag_value is not None:
return flag_value

@classmethod
def __parse_flag_name_from_list(cls, items: Dict[str, Any]) -> Optional[str]:
present_flags = [flag for flag in _flags.VIABLE_FLAGS if flag in items]
if len(present_flags) == 0:
logger.debug("There are too many flags defined on %s", cls)
raise exceptions.NoFlagsFound(cls, items.keys())
if len(present_flags) > 1:
logger.debug("There are too many flags defined on %s", cls)
raise exceptions.TooManyFlags(cls, present_flags)
flag = present_flags[0]
logger.debug("found flag %s", flag)
return flag
20 changes: 20 additions & 0 deletions tests/test_singleton.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@

from a_sync.singleton import ASyncGenericSingleton

def test_flag_predefined():
"""We had a failure case where the subclass implementation assigned the flag value to the class and did not allow user to determine at init time"""
class Test(ASyncGenericSingleton):
sync=True
def __init__(self):
...
Test()
class TestInherit(Test):
...
TestInherit()

class Test(ASyncGenericSingleton):
sync=False
Test()
class TestInherit(Test):
...
TestInherit()

0 comments on commit 2421f0d

Please sign in to comment.