Skip to content

Commit

Permalink
Add typing to auxiliary classes around Pool (#1197)
Browse files Browse the repository at this point in the history
  • Loading branch information
DanielNoord authored Nov 4, 2024
1 parent a273e0e commit e9bb695
Showing 1 changed file with 55 additions and 29 deletions.
84 changes: 55 additions & 29 deletions asyncpg/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,16 @@
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0

from __future__ import annotations

import asyncio
from collections.abc import Awaitable, Callable
import functools
import inspect
import logging
import time
from types import TracebackType
from typing import Any, Optional, Type
import warnings

from . import compat
Expand All @@ -23,7 +27,14 @@

class PoolConnectionProxyMeta(type):

def __new__(mcls, name, bases, dct, *, wrap=False):
def __new__(
mcls,
name: str,
bases: tuple[Type[Any], ...],
dct: dict[str, Any],
*,
wrap: bool = False,
) -> PoolConnectionProxyMeta:
if wrap:
for attrname in dir(connection.Connection):
if attrname.startswith('_') or attrname in dct:
Expand All @@ -44,8 +55,10 @@ def __new__(mcls, name, bases, dct, *, wrap=False):
return super().__new__(mcls, name, bases, dct)

@staticmethod
def _wrap_connection_method(meth_name, iscoroutine):
def call_con_method(self, *args, **kwargs):
def _wrap_connection_method(
meth_name: str, iscoroutine: bool
) -> Callable[..., Any]:
def call_con_method(self: Any, *args: Any, **kwargs: Any) -> Any:
# This method will be owned by PoolConnectionProxy class.
if self._con is None:
raise exceptions.InterfaceError(
Expand All @@ -68,25 +81,26 @@ class PoolConnectionProxy(connection._ConnectionProxy,

__slots__ = ('_con', '_holder')

def __init__(self, holder: 'PoolConnectionHolder',
con: connection.Connection):
def __init__(
self, holder: PoolConnectionHolder, con: connection.Connection
) -> None:
self._con = con
self._holder = holder
con._set_proxy(self)

def __getattr__(self, attr):
def __getattr__(self, attr: str) -> Any:
# Proxy all unresolved attributes to the wrapped Connection object.
return getattr(self._con, attr)

def _detach(self) -> connection.Connection:
def _detach(self) -> Optional[connection.Connection]:
if self._con is None:
return

con, self._con = self._con, None
con._set_proxy(None)
return con

def __repr__(self):
def __repr__(self) -> str:
if self._con is None:
return '<{classname} [released] {id:#x}>'.format(
classname=self.__class__.__name__, id=id(self))
Expand All @@ -103,27 +117,34 @@ class PoolConnectionHolder:
'_inactive_callback', '_timeout',
'_generation')

def __init__(self, pool, *, max_queries, setup, max_inactive_time):
def __init__(
self,
pool: "Pool",
*,
max_queries: float,
setup: Optional[Callable[[PoolConnectionProxy], Awaitable[None]]],
max_inactive_time: float,
) -> None:

self._pool = pool
self._con = None
self._proxy = None
self._con: Optional[connection.Connection] = None
self._proxy: Optional[PoolConnectionProxy] = None

self._max_queries = max_queries
self._max_inactive_time = max_inactive_time
self._setup = setup
self._inactive_callback = None
self._in_use = None # type: asyncio.Future
self._timeout = None
self._generation = None
self._inactive_callback: Optional[Callable] = None
self._in_use: Optional[asyncio.Future] = None
self._timeout: Optional[float] = None
self._generation: Optional[int] = None

def is_connected(self):
def is_connected(self) -> bool:
return self._con is not None and not self._con.is_closed()

def is_idle(self):
def is_idle(self) -> bool:
return not self._in_use

async def connect(self):
async def connect(self) -> None:
if self._con is not None:
raise exceptions.InternalClientError(
'PoolConnectionHolder.connect() called while another '
Expand Down Expand Up @@ -171,7 +192,7 @@ async def acquire(self) -> PoolConnectionProxy:

return proxy

async def release(self, timeout):
async def release(self, timeout: Optional[float]) -> None:
if self._in_use is None:
raise exceptions.InternalClientError(
'PoolConnectionHolder.release() called on '
Expand Down Expand Up @@ -234,25 +255,25 @@ async def release(self, timeout):
# Rearm the connection inactivity timer.
self._setup_inactive_callback()

async def wait_until_released(self):
async def wait_until_released(self) -> None:
if self._in_use is None:
return
else:
await self._in_use

async def close(self):
async def close(self) -> None:
if self._con is not None:
# Connection.close() will call _release_on_close() to
# finish holder cleanup.
await self._con.close()

def terminate(self):
def terminate(self) -> None:
if self._con is not None:
# Connection.terminate() will call _release_on_close() to
# finish holder cleanup.
self._con.terminate()

def _setup_inactive_callback(self):
def _setup_inactive_callback(self) -> None:
if self._inactive_callback is not None:
raise exceptions.InternalClientError(
'pool connection inactivity timer already exists')
Expand All @@ -261,12 +282,12 @@ def _setup_inactive_callback(self):
self._inactive_callback = self._pool._loop.call_later(
self._max_inactive_time, self._deactivate_inactive_connection)

def _maybe_cancel_inactive_callback(self):
def _maybe_cancel_inactive_callback(self) -> None:
if self._inactive_callback is not None:
self._inactive_callback.cancel()
self._inactive_callback = None

def _deactivate_inactive_connection(self):
def _deactivate_inactive_connection(self) -> None:
if self._in_use is not None:
raise exceptions.InternalClientError(
'attempting to deactivate an acquired connection')
Expand All @@ -280,12 +301,12 @@ def _deactivate_inactive_connection(self):
# so terminate() above will not call the below.
self._release_on_close()

def _release_on_close(self):
def _release_on_close(self) -> None:
self._maybe_cancel_inactive_callback()
self._release()
self._con = None

def _release(self):
def _release(self) -> None:
"""Release this connection holder."""
if self._in_use is None:
# The holder is not checked out.
Expand Down Expand Up @@ -1012,7 +1033,7 @@ class PoolAcquireContext:

__slots__ = ('timeout', 'connection', 'done', 'pool')

def __init__(self, pool, timeout):
def __init__(self, pool: Pool, timeout: Optional[float]) -> None:
self.pool = pool
self.timeout = timeout
self.connection = None
Expand All @@ -1024,7 +1045,12 @@ async def __aenter__(self):
self.connection = await self.pool._acquire(self.timeout)
return self.connection

async def __aexit__(self, *exc):
async def __aexit__(
self,
exc_type: Optional[Type[BaseException]] = None,
exc_val: Optional[BaseException] = None,
exc_tb: Optional[TracebackType] = None,
) -> None:
self.done = True
con = self.connection
self.connection = None
Expand Down

0 comments on commit e9bb695

Please sign in to comment.