Skip to content

Commit

Permalink
Refactor (#99)
Browse files Browse the repository at this point in the history
* chore: refactor for compute
  • Loading branch information
BobTheBuidler authored Nov 9, 2023
1 parent 1a37bd2 commit 089f42a
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 19 deletions.
9 changes: 5 additions & 4 deletions a_sync/_kwargs.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@

from typing import Optional

from a_sync import _flags, exceptions


def get_flag_name(kwargs: dict) -> str:
def get_flag_name(kwargs: dict) -> Optional[str]:
present_flags = [flag for flag in _flags.VIABLE_FLAGS if flag in kwargs]
if len(present_flags) == 0:
raise exceptions.NoFlagsFound('kwargs', kwargs.keys())
return None
if len(present_flags) != 1:
raise exceptions.TooManyFlags('kwargs', present_flags)
return present_flags[0]

def is_sync(kwargs: dict, pop_flag: bool = False) -> bool:
flag = get_flag_name(kwargs)
def is_sync(flag: str, kwargs: dict, pop_flag: bool = False) -> bool:
flag_value = kwargs.pop(flag) if pop_flag else kwargs[flag]
return _flags.negate_if_necessary(flag, flag_value)
19 changes: 10 additions & 9 deletions a_sync/abstract.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@

import abc
import logging
from typing import Union

from a_sync import _flags, _kwargs, exceptions, modifiers
from a_sync._meta import ASyncMeta
from a_sync._typing import *

from a_sync.exceptions import NoFlagsFound

logger = logging.getLogger(__name__)

Expand All @@ -32,19 +31,21 @@ def __should_await_from_instance(self) -> bool:

def __should_await_from_kwargs(self, kwargs: dict) -> bool:
"""You can override this if you want."""
return _kwargs.is_sync(kwargs, pop_flag=True)
if flag := _kwargs.get_flag_name(kwargs):
return _kwargs.is_sync(flag, kwargs, pop_flag=True)
else:
raise NoFlagsFound("kwargs", kwargs.keys())

@classmethod
def __a_sync_instance_will_be_sync__(cls, args: tuple, kwargs: dict) -> bool:
"""You can override this if you want."""
try:
logger.debug("checking `%s.%s.__init__` signature against provided kwargs to determine a_sync mode for the new instance", cls.__module__, cls.__name__)
sync = _kwargs.is_sync(kwargs)
logger.debug("checking `%s.%s.__init__` signature against provided kwargs to determine a_sync mode for the new instance", cls.__module__, cls.__name__)
if flag := _kwargs.get_flag_name(kwargs):
sync = _kwargs.is_sync(flag, kwargs)
logger.debug("kwargs indicate the new instance created with args %s %s is %ssynchronous", args, kwargs, 'a' if sync is False else '')
return sync
except exceptions.NoFlagsFound:
logger.debug("No valid flags found in kwargs, checking class definition for defined default")
return cls.__a_sync_default_mode__() # type: ignore [return-value]
logger.debug("No valid flags found in kwargs, checking class definition for defined default")
return cls.__a_sync_default_mode__() # type: ignore [return-value]

######################################
# Concrete Methods (non-overridable) #
Expand Down
6 changes: 3 additions & 3 deletions a_sync/modified.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,10 @@ def _async_def(self) -> bool:
return asyncio.iscoroutinefunction(self.__wrapped__)

def _run_sync(self, kwargs: dict):
try:
if flag := _kwargs.get_flag_name(kwargs):
# If a flag was specified in the kwargs, we will defer to it.
return _kwargs.is_sync(kwargs, pop_flag=True)
except exceptions.NoFlagsFound:
return _kwargs.is_sync(flag, kwargs, pop_flag=True)
else:
# No flag specified in the kwargs, we will defer to 'default'.
return self._sync_default

Expand Down
5 changes: 2 additions & 3 deletions a_sync/utils/iterators.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,10 @@ def done_callback(t: asyncio.Task) -> None:
get_task = asyncio.create_task(coro=queue.get(), name=str(queue))
_chain_future(get_task, next_fut)
for item in (await next_fut, *queue.get_nowait(-1)):
if not isinstance(item, _Done):
yield item
else:
if isinstance(item, _Done):
task.cancel()
return
yield item

if e := task.exception():
get_task.cancel()
Expand Down

0 comments on commit 089f42a

Please sign in to comment.