Skip to content

Commit

Permalink
Merge branch 'main' into feature/taskcontext-and-taskscope
Browse files Browse the repository at this point in the history
  • Loading branch information
achimnol authored Dec 22, 2023
2 parents d2f2465 + 75ac9e3 commit 5647a22
Show file tree
Hide file tree
Showing 13 changed files with 130 additions and 27 deletions.
4 changes: 3 additions & 1 deletion .github/workflows/default.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ jobs:
fail-fast: false
matrix:
os: [ubuntu-latest, macos-latest]
python-version: ["3.11"]
python-version:
- "3.11"
- "3.12"
steps:
- uses: actions/checkout@v3
- name: Set up Python
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ celerybeat-schedule
.envrc

# virtualenv
.venv/
venv/
ENV/

Expand Down
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@ repos:
- id: end-of-file-fixer
- id: trailing-whitespace
- repo: https://github.com/psf/black
rev: 23.7.0
rev: 23.9.1
hooks:
- id: black
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.0.285
rev: v0.0.291
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
Expand Down
1 change: 1 addition & 0 deletions changes/62.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add `aiotools.context.resetting()` as a sync/async context manager to auto-reset the given context variable
1 change: 1 addition & 0 deletions changes/63.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add type checker support - now includes py.typed in the package to indicate to type checkers like mypy that typing is supported.
1 change: 1 addition & 0 deletions changes/64.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add Python 3.12 support
5 changes: 3 additions & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ classifiers =
Programming Language :: Python
Programming Language :: Python :: 3
Programming Language :: Python :: 3.11
Programming Language :: Python :: 3.12
Topic :: Software Development
url = https://github.com/achimnol/aiotools
project_urls =
Expand All @@ -41,7 +42,7 @@ build =
twine~=4.0
towncrier~=22.12
test =
pytest~=7.2.2
pytest~=7.4.2
pytest-asyncio~=0.21
pytest-cov
pytest-mock
Expand All @@ -53,7 +54,7 @@ lint =
ruff>=0.0.285
ruff-lsp>=0.0.37
typecheck =
mypy~=1.4.1
mypy~=1.5.1
docs =
sphinx~=4.3
sphinx-rtd-theme~=1.0
Expand Down
56 changes: 49 additions & 7 deletions src/aiotools/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,19 @@

import asyncio
import contextlib
from typing import Iterable, List, Optional
from contextvars import ContextVar
from typing import (
Generic,
Iterable,
List,
Optional,
TypeVar,
)

from .types import AsyncClosable

__all__ = [
"resetting",
"AsyncContextManager",
"async_ctx_manager",
"actxmgr",
Expand All @@ -24,29 +34,61 @@
]


T = TypeVar("T")
T_AsyncClosable = TypeVar("T_AsyncClosable", bound=AsyncClosable)

AbstractAsyncContextManager = contextlib.AbstractAsyncContextManager
AsyncContextManager = contextlib._AsyncGeneratorContextManager
AsyncExitStack = contextlib.AsyncExitStack
async_ctx_manager = contextlib.asynccontextmanager
aclosing = contextlib.aclosing


class closing_async:
class resetting(Generic[T]):
"""
An extra context manager to auto-reset the given context variable.
It supports both the standard contextmanager protocol and the
async-contextmanager protocol.
.. versionadded:: 1.8.0
"""
An analogy to :func:`contextlib.closing` for objects with ``close()``
methods as async functions.

def __init__(self, ctxvar: ContextVar[T], value: T) -> None:
self._ctxvar = ctxvar
self._value = value

def __enter__(self) -> None:
self._token = self._ctxvar.set(self._value)

async def __aenter__(self) -> None:
self._token = self._ctxvar.set(self._value)

def __exit__(self, *exc_info) -> Optional[bool]:
self._ctxvar.reset(self._token)
return None

async def __aexit__(self, *exc_info) -> Optional[bool]:
self._ctxvar.reset(self._token)
return None


class closing_async(Generic[T_AsyncClosable]):
"""
An analogy to :func:`contextlib.closing` for objects defining the ``close()``
method as an async function.
.. versionadded:: 1.5.6
"""

def __init__(self, thing):
def __init__(self, thing: T_AsyncClosable) -> None:
self.thing = thing

async def __aenter__(self):
async def __aenter__(self) -> T_AsyncClosable:
return self.thing

async def __aexit__(self, *args):
async def __aexit__(self, *exc_info) -> Optional[bool]:
await self.thing.close()
return None


class AsyncContextGroup:
Expand Down
Empty file added src/aiotools/py.typed
Empty file.
26 changes: 14 additions & 12 deletions src/aiotools/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,18 +228,20 @@ def helper(*args, **kwargs):


def setup_child_watcher(loop: asyncio.AbstractEventLoop) -> None:
try:
watcher_cls = getattr(asyncio, "PidfdChildWatcher", None)
if _has_pidfd and watcher_cls:
watcher = watcher_cls()
asyncio.set_child_watcher(watcher)
else:
# Just get the default child watcher.
watcher = asyncio.get_child_watcher()
if not watcher.is_active():
watcher.attach_loop(loop)
except NotImplementedError:
pass # for uvloop
if sys.version_info < (3, 12, 0):
# see python/cpython#94597 (issue) and python/cpython#98215 (pr)
try:
watcher_cls = getattr(asyncio, "PidfdChildWatcher", None)
if _has_pidfd and watcher_cls:
watcher = watcher_cls()
asyncio.set_child_watcher(watcher)
else:
# Just get the default child watcher.
watcher = asyncio.get_child_watcher()
if not watcher.is_active():
watcher.attach_loop(loop)
except NotImplementedError:
pass # for uvloop


async def cancel_all_tasks() -> None:
Expand Down
3 changes: 1 addition & 2 deletions src/aiotools/taskgroup/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@ async def __call__(
exc_type: type[Exception],
exc_obj: Exception,
exc_tb: TracebackType,
) -> None:
...
) -> None: ...


class MultiError(ExceptionGroup):
Expand Down
8 changes: 8 additions & 0 deletions src/aiotools/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,20 @@
Awaitable,
Coroutine,
Generator,
Protocol,
TypeAlias,
TypeVar,
runtime_checkable,
)

_T = TypeVar("_T")


@runtime_checkable
class AsyncClosable(Protocol):
async def close(self) -> None: ...


# taken from the typeshed
if sys.version_info >= (3, 12):
AwaitableLike: TypeAlias = Awaitable[_T] # noqa: Y047
Expand Down
47 changes: 46 additions & 1 deletion tests/test_context.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,56 @@
import asyncio
import sys
import warnings
from contextlib import suppress
from contextvars import ContextVar

import pytest

import aiotools
from aiotools.context import AbstractAsyncContextManager
from aiotools.context import AbstractAsyncContextManager, resetting

my_variable: ContextVar[int] = ContextVar("my_variable")


def test_resetting_ctxvar():
with pytest.raises(LookupError):
my_variable.get()
with resetting(my_variable, 1):
assert my_variable.get() == 1
with resetting(my_variable, 2):
assert my_variable.get() == 2
assert my_variable.get() == 1
with pytest.raises(LookupError):
my_variable.get()

# should behave the same way even when an exception occurs
with suppress(RuntimeError):
with resetting(my_variable, 10):
assert my_variable.get() == 10
raise RuntimeError("oops")
with pytest.raises(LookupError):
my_variable.get()


@pytest.mark.asyncio
async def test_resetting_ctxvar_async():
with pytest.raises(LookupError):
my_variable.get()
async with resetting(my_variable, 1):
assert my_variable.get() == 1
async with resetting(my_variable, 2):
assert my_variable.get() == 2
assert my_variable.get() == 1
with pytest.raises(LookupError):
my_variable.get()

# should behave the same way even when an exception occurs
with suppress(RuntimeError):
async with resetting(my_variable, 10):
assert my_variable.get() == 10
raise RuntimeError("oops")
with pytest.raises(LookupError):
my_variable.get()


def test_actxmgr_types():
Expand Down

0 comments on commit 5647a22

Please sign in to comment.