diff --git a/Makefile b/Makefile index f0214347..8173a850 100644 --- a/Makefile +++ b/Makefile @@ -7,6 +7,7 @@ DOCKER_IMAGE=aiolibs/kafka:$(SCALA_VERSION)_$(KAFKA_VERSION) DIFF_BRANCH=origin/master FORMATTED_AREAS=\ aiokafka/codec.py \ + aiokafka/conn.py \ aiokafka/coordinator/ \ aiokafka/errors.py \ aiokafka/helpers.py \ @@ -15,6 +16,7 @@ FORMATTED_AREAS=\ aiokafka/protocol/ \ aiokafka/record/ \ tests/test_codec.py \ + tests/test_conn.py \ tests/test_helpers.py \ tests/test_protocol.py \ tests/test_protocol_object_conversion.py \ diff --git a/aiokafka/abc.py b/aiokafka/abc.py index abb9f216..f9264574 100644 --- a/aiokafka/abc.py +++ b/aiokafka/abc.py @@ -1,4 +1,5 @@ import abc +from typing import Dict class ConsumerRebalanceListener(abc.ABC): @@ -103,7 +104,7 @@ class AbstractTokenProvider(abc.ABC): """ @abc.abstractmethod - async def token(self): + async def token(self) -> str: """ An async callback returning a :class:`str` ID/Access Token to be sent to the Kafka client. In case where a synchronous callback is needed, @@ -122,7 +123,7 @@ def _token(self): # The actual synchronous token callback. """ - def extensions(self): + def extensions(self) -> Dict[str, str]: """ This is an OPTIONAL method that may be implemented. diff --git a/aiokafka/conn.py b/aiokafka/conn.py index a2402b72..f133eb1f 100644 --- a/aiokafka/conn.py +++ b/aiokafka/conn.py @@ -1,6 +1,9 @@ +from __future__ import annotations + import asyncio import base64 import collections +import enum import functools import hashlib import hmac @@ -8,6 +11,7 @@ import logging import random import socket +import ssl import struct import sys import time @@ -15,16 +19,42 @@ import uuid import warnings import weakref +from typing import ( + Any, + Awaitable, + Callable, + Coroutine, + Dict, + Generator, + Iterable, + List, + Literal, + NamedTuple, + Optional, + Sequence, + Tuple, + Type, + TypeVar, + Union, + cast, + overload, +) import async_timeout +from typing_extensions import Buffer import aiokafka.errors as Errors from aiokafka.abc import AbstractTokenProvider from aiokafka.protocol.admin import ( ApiVersionRequest, SaslAuthenticateRequest, + SaslAuthenticateResponse_v0, + SaslAuthenticateResponse_v1, SaslHandShakeRequest, + SaslHandShakeResponse_v0, + SaslHandShakeResponse_v1, ) +from aiokafka.protocol.api import Request, Response from aiokafka.protocol.commit import ( GroupCoordinatorResponse_v0 as GroupCoordinatorResponse, ) @@ -33,7 +63,10 @@ try: import gssapi except ImportError: - gssapi = None + gssapi = None # type: ignore[assignment] + +RequestT = TypeVar("RequestT", bound=Request) +ResponseT = TypeVar("ResponseT", bound=Response) __all__ = ["AIOKafkaConnection", "create_conn"] @@ -45,7 +78,19 @@ SASL_QOP_AUTH = 1 -class CloseReason: +class Packet(NamedTuple): + correlation_id: int + request: Request[Response] + fut: asyncio.Future[Response] + + +class SaslPacket(NamedTuple): + correlation_id: None + request: None + fut: asyncio.Future[bytes] + + +class CloseReason(enum.IntEnum): CONNECTION_BROKEN = 0 CONNECTION_TIMEOUT = 1 OUT_OF_SYNC = 2 @@ -55,17 +100,18 @@ class CloseReason: class VersionInfo: - def __init__(self, versions): + def __init__(self, versions: Dict[int, Tuple[int, int]]) -> None: self._versions = versions - def pick_best(self, request_versions): - api_key = request_versions[0].API_KEY + def pick_best(self, request_versions: Sequence[Type[RequestT]]) -> Type[RequestT]: + api_key = cast(int, request_versions[0].API_KEY) if api_key not in self._versions: return request_versions[0] min_version, max_version = self._versions[api_key] for req_klass in reversed(request_versions): - if min_version <= req_klass.API_VERSION <= max_version: + req_api_version = cast(int, req_klass.API_VERSION) + if min_version <= req_api_version <= max_version: return req_klass raise Errors.KafkaError( @@ -75,24 +121,30 @@ def pick_best(self, request_versions): async def create_conn( - host, - port, + host: str, + port: int, *, - client_id="aiokafka", - request_timeout_ms=40000, - api_version=(0, 8, 2), - ssl_context=None, - security_protocol="PLAINTEXT", - max_idle_ms=None, - on_close=None, - sasl_mechanism=None, - sasl_plain_username=None, - sasl_plain_password=None, - sasl_kerberos_service_name="kafka", - sasl_kerberos_domain_name=None, - sasl_oauth_token_provider=None, - version_hint=None, -): + client_id: str = "aiokafka", + request_timeout_ms: float = 40000, + api_version: Union[Tuple[int, int], Tuple[int, int, int]] = (0, 8, 2), + ssl_context: Optional[ssl.SSLContext] = None, + security_protocol: Literal[ + "PLAINTEXT", "SASL_PLAINTEXT", "SSL", "SASL_SSL" + ] = "PLAINTEXT", + max_idle_ms: Optional[float] = None, + on_close: Optional[ + Callable[[AIOKafkaConnection, Optional[CloseReason]], None] + ] = None, + sasl_mechanism: Optional[ + Literal["PLAIN", "GSSAPI", "SCRAM-SHA-256", "SCRAM-SHA-512", "OAUTHBEARER"] + ] = None, + sasl_plain_username: Optional[str] = None, + sasl_plain_password: Optional[str] = None, + sasl_kerberos_service_name: str = "kafka", + sasl_kerberos_domain_name: Optional[str] = None, + sasl_oauth_token_provider: Optional[AbstractTokenProvider] = None, + version_hint: Optional[Union[Tuple[int, int], Tuple[int, int, int]]] = None, +) -> AIOKafkaConnection: conn = AIOKafkaConnection( host, port, @@ -116,11 +168,17 @@ async def create_conn( class AIOKafkaProtocol(asyncio.StreamReaderProtocol): - def __init__(self, closed_fut, *args, loop, **kw): + def __init__( + self, + closed_fut: asyncio.Future[None], + *args: Any, + loop: asyncio.AbstractEventLoop, + **kw: Any, + ) -> None: self._closed_fut = closed_fut - super().__init__(*args, loop=loop, **kw) + super().__init__(*args, loop=loop, **kw) # type: ignore[misc] - def connection_lost(self, exc): + def connection_lost(self, exc: Optional[Exception]) -> None: super().connection_lost(exc) if not self._closed_fut.cancelled(): self._closed_fut.set_result(None) @@ -129,29 +187,37 @@ def connection_lost(self, exc): class AIOKafkaConnection: """Class for manage connection to Kafka node""" - _reader = None # For __del__ to work properly, just in case - _source_traceback = None + _reader: Optional[asyncio.StreamReader] = ( + None # For __del__ to work properly, just in case + ) + _source_traceback: Optional[traceback.StackSummary] = None def __init__( self, - host, - port, + host: str, + port: int, *, - client_id="aiokafka", - request_timeout_ms=40000, - api_version=(0, 8, 2), - ssl_context=None, - security_protocol="PLAINTEXT", - max_idle_ms=None, - on_close=None, - sasl_mechanism=None, - sasl_plain_password=None, - sasl_plain_username=None, - sasl_kerberos_service_name="kafka", - sasl_kerberos_domain_name=None, - sasl_oauth_token_provider=None, - version_hint=None, - ): + client_id: str = "aiokafka", + request_timeout_ms: float = 40000, + api_version: Union[Tuple[int, int], Tuple[int, int, int]] = (0, 8, 2), + ssl_context: Optional[ssl.SSLContext] = None, + security_protocol: Literal[ + "PLAINTEXT", "SASL_PLAINTEXT", "SSL", "SASL_SSL" + ] = "PLAINTEXT", + max_idle_ms: Optional[float] = None, + on_close: Optional[ + Callable[[AIOKafkaConnection, Optional[CloseReason]], None] + ] = None, + sasl_mechanism: Optional[ + Literal["PLAIN", "GSSAPI", "SCRAM-SHA-256", "SCRAM-SHA-512", "OAUTHBEARER"] + ] = None, + sasl_plain_password: Optional[str] = None, + sasl_plain_username: Optional[str] = None, + sasl_kerberos_service_name: str = "kafka", + sasl_kerberos_domain_name: Optional[str] = None, + sasl_oauth_token_provider: Optional[AbstractTokenProvider] = None, + version_hint: Optional[Union[Tuple[int, int], Tuple[int, int, int]]] = None, + ) -> None: loop = get_running_loop() if sasl_mechanism == "GSSAPI": @@ -188,17 +254,21 @@ def __init__( self._version_hint = version_hint self._version_info = VersionInfo({}) - self._reader = self._writer = self._protocol = None + self._reader: Optional[asyncio.StreamReader] = None + self._writer: Optional[asyncio.StreamWriter] = None + self._protocol: Optional[AIOKafkaProtocol] = None # Even on small size seems to be a bit faster than list. # ~2x on size of 2 in Python3.6 - self._requests = collections.deque() - self._read_task = None - self._correlation_id = 0 - self._closed_fut = None + self._requests: collections.deque[Union[Packet, SaslPacket]] = ( + collections.deque() + ) + self._read_task: Optional[asyncio.Task[None]] = None + self._correlation_id: int = 0 + self._closed_fut: Optional[asyncio.Future[None]] = None self._max_idle_ms = max_idle_ms self._last_action = time.monotonic() - self._idle_handle = None + self._idle_handle: Optional[asyncio.Handle] = None self._on_close_cb = on_close @@ -207,7 +277,7 @@ def __init__( # Warn and try to close. We can close synchronously, so will attempt # that - def __del__(self, _warnings=warnings): + def __del__(self, _warnings=warnings) -> None: # type: ignore[no-untyped-def] if self.connected(): _warnings.warn( f"Unclosed AIOKafkaConnection {self!r}", @@ -230,7 +300,7 @@ def __del__(self, _warnings=warnings): context["source_traceback"] = self._source_traceback self._loop.call_exception_handler(context) - async def connect(self): + async def connect(self) -> Tuple[asyncio.StreamReader, asyncio.StreamWriter]: loop = self._loop self._closed_fut = create_future() if self._security_protocol in ["PLAINTEXT", "SASL_PLAINTEXT"]: @@ -268,10 +338,10 @@ async def connect(self): return reader, writer - async def _do_version_lookup(self): + async def _do_version_lookup(self) -> None: version_req = ApiVersionRequest[0]() response = await self.send(version_req) - versions = {} + versions: Dict[int, Tuple[int, int]] = {} for api_key, min_version, max_version in response.api_versions: assert min_version <= max_version, ( f"{min_version} should be less than" @@ -280,9 +350,10 @@ async def _do_version_lookup(self): versions[api_key] = (min_version, max_version) self._version_info = VersionInfo(versions) - async def _do_sasl_handshake(self): + async def _do_sasl_handshake(self) -> None: # NOTE: We will only fallback to v0.9 gssapi scheme if user explicitly # stated, that api_version is "0.9" + exc: Errors.KafkaError if self._version_hint and self._version_hint < (0, 10): handshake_klass = None assert self._sasl_mechanism == "GSSAPI", "Only GSSAPI supported for v0.9" @@ -291,6 +362,9 @@ async def _do_sasl_handshake(self): sasl_handshake = handshake_klass(self._sasl_mechanism) response = await self.send(sasl_handshake) + response = cast( + Union[SaslHandShakeResponse_v0, SaslHandShakeResponse_v1], response + ) error_type = Errors.for_code(response.error_code) if error_type is not Errors.NoError: error = error_type(self) @@ -318,6 +392,7 @@ async def _do_sasl_handshake(self): ): log.warning("Sending username and password in the clear") + authenticator: BaseSaslAuthenticator if self._sasl_mechanism == "GSSAPI": authenticator = self.authenticator_gssapi() elif self._sasl_mechanism.startswith("SCRAM-SHA-"): @@ -332,7 +407,7 @@ async def _do_sasl_handshake(self): else: auth_klass = None - auth_bytes = None + auth_bytes: Optional[bytes] = None expect_response = True while True: @@ -350,6 +425,10 @@ async def _do_sasl_handshake(self): else: req = auth_klass(payload) resp = await self.send(req) + resp = cast( + Union[SaslAuthenticateResponse_v0, SaslAuthenticateResponse_v1], + resp, + ) error_type = Errors.for_code(resp.error_code) if error_type is not Errors.NoError: exc = error_type(resp.error_message) @@ -368,41 +447,51 @@ async def _do_sasl_handshake(self): self._sasl_mechanism, ) - def authenticator_plain(self): + def authenticator_plain(self) -> SaslPlainAuthenticator: + assert self._sasl_plain_password is not None + assert self._sasl_plain_username is not None return SaslPlainAuthenticator( loop=self._loop, sasl_plain_password=self._sasl_plain_password, sasl_plain_username=self._sasl_plain_username, ) - def authenticator_gssapi(self): + def authenticator_gssapi(self) -> SaslGSSAPIAuthenticator: return SaslGSSAPIAuthenticator( loop=self._loop, principal=self.sasl_principal, ) - def authenticator_scram(self): + def authenticator_scram(self) -> ScramAuthenticator: + assert self._sasl_plain_password is not None + assert self._sasl_plain_username is not None + assert self._sasl_mechanism in ("SCRAM-SHA-256", "SCRAM-SHA-512") return ScramAuthenticator( loop=self._loop, sasl_plain_password=self._sasl_plain_password, sasl_plain_username=self._sasl_plain_username, - sasl_mechanism=self._sasl_mechanism, + sasl_mechanism=self._sasl_mechanism, # type: ignore[arg-type] ) - def authenticator_oauth(self): + def authenticator_oauth(self) -> OAuthAuthenticator: + assert self._sasl_oauth_token_provider is not None return OAuthAuthenticator( sasl_oauth_token_provider=self._sasl_oauth_token_provider, ) @property - def sasl_principal(self): + def sasl_principal(self) -> str: service = self._sasl_kerberos_service_name domain = self._sasl_kerberos_domain_name or self.host return f"{service}@{domain}" @classmethod - def _on_read_task_error(cls, self_ref, read_task): + def _on_read_task_error( + cls, + self_ref: weakref.ReferenceType[AIOKafkaConnection], + read_task: asyncio.Task[None], + ) -> None: # We don't want to react to cancelled errors if read_task.cancelled(): return @@ -418,12 +507,13 @@ def _on_read_task_error(cls, self_ref, read_task): self.close(reason=CloseReason.CONNECTION_BROKEN, exc=exc) @staticmethod - def _idle_check(self_ref): + def _idle_check(self_ref: weakref.ReferenceType[AIOKafkaConnection]) -> None: self = self_ref() if self is None: return idle_for = time.monotonic() - self._last_action + assert self._max_idle_ms is not None timeout = self._max_idle_ms / 1000 # If we have any pending requests, we are assumed to be not idle. # it's up to `request_timeout_ms` to break those. @@ -440,18 +530,31 @@ def _idle_check(self_ref): wake_up_in, self._idle_check, self_ref ) - def __repr__(self): + def __repr__(self) -> str: return f"" @property - def host(self): + def host(self) -> str: return self._host @property - def port(self): + def port(self) -> int: return self._port - def send(self, request, expect_response=True): + @overload + def send(self, request: Request[ResponseT]) -> Coroutine[None, None, ResponseT]: ... + @overload + def send( + self, request: Request[ResponseT], expect_response: Literal[False] + ) -> Coroutine[None, None, None]: ... + @overload + def send( + self, request: Request[ResponseT], expect_response: Literal[True] + ) -> Coroutine[None, None, ResponseT]: ... + + def send( + self, request: Request[ResponseT], expect_response: bool = True + ) -> Union[Coroutine[None, None, ResponseT], Coroutine[None, None, None]]: if self._writer is None: raise Errors.KafkaConnectionError( f"No connection to broker at {self._host}:{self._port}" @@ -477,11 +580,28 @@ def send(self, request, expect_response=True): return self._writer.drain() fut = self._loop.create_future() self._requests.append( - (correlation_id, request, fut), + Packet(correlation_id, request, fut), ) return wait_for(fut, self._request_timeout) - def _send_sasl_token(self, payload, expect_response=True): + @overload + def _send_sasl_token( + self, payload: bytes, expect_response: Literal[False] + ) -> Coroutine[None, None, None]: ... + @overload + def _send_sasl_token( + self, payload: bytes, expect_response: Literal[True] + ) -> Coroutine[None, None, bytes]: ... + @overload + def _send_sasl_token(self, payload: bytes) -> Coroutine[None, None, bytes]: ... + @overload + def _send_sasl_token( + self, payload: bytes, expect_response: bool + ) -> Union[Coroutine[None, None, None], Coroutine[None, None, bytes]]: ... + + def _send_sasl_token( + self, payload: bytes, expect_response: bool = True + ) -> Union[Coroutine[None, None, None], Coroutine[None, None, bytes]]: if self._writer is None: raise Errors.KafkaConnectionError( f"No connection to broker at {self._host}:{self._port}" @@ -499,17 +619,21 @@ def _send_sasl_token(self, payload, expect_response=True): return self._writer.drain() fut = self._loop.create_future() - self._requests.append((None, None, fut)) + self._requests.append(SaslPacket(None, None, fut)) return wait_for(fut, self._request_timeout) - def connected(self): + def connected(self) -> bool: return bool(self._reader is not None and not self._reader.at_eof()) - def close(self, reason=None, exc=None): + def close( + self, reason: Optional[CloseReason] = None, exc: Optional[Exception] = None + ) -> Optional[asyncio.Future[None]]: log.debug("Closing connection at %s:%s", self._host, self._port) if self._reader is not None: + assert self._writer is not None self._writer.close() self._writer = self._reader = None + assert self._read_task is not None if not self._read_task.done(): self._read_task.cancel() self._read_task = None @@ -533,7 +657,7 @@ def close(self, reason=None, exc=None): # a future in case we need to wait on it. return self._closed_fut - def _create_reader_task(self): + def _create_reader_task(self) -> asyncio.Task[None]: self_ref = weakref.ref(self) read_task = create_task(self._read(self_ref)) read_task.add_done_callback( @@ -542,7 +666,7 @@ def _create_reader_task(self): return read_task @staticmethod - async def _read(self_ref): + async def _read(self_ref: weakref.ReferenceType[AIOKafkaConnection]) -> None: # XXX: I know that it become a bit more ugly once cyclic references # were removed, but it's needed to allow connections to properly # release resources if leaked. @@ -552,6 +676,7 @@ async def _read(self_ref): return reader = self._reader del self + assert reader is not None while True: resp = await reader.readexactly(4) @@ -565,15 +690,16 @@ async def _read(self_ref): self._handle_frame(resp) del self - def _handle_frame(self, resp): - correlation_id, request, fut = self._requests[0] + def _handle_frame(self, resp: bytes) -> None: + packet = self._requests[0] - if correlation_id is None: # Is a SASL packet, just pass it though - if not fut.done(): - fut.set_result(resp) + if isinstance(packet, SaslPacket): # Is a SASL packet, just pass it though + if not packet.fut.done(): + packet.fut.set_result(resp) else: - resp = io.BytesIO(resp) - response_header = request.parse_response_header(resp) + correlation_id, request, fut = packet + resp_io = io.BytesIO(resp) + response_header = request.parse_response_header(resp_io) resp_type = request.RESPONSE_TYPE if ( @@ -600,7 +726,7 @@ def _handle_frame(self, resp): return if not fut.done(): - response = resp_type.decode(resp) + response = resp_type.decode(resp_io) log.debug("%s Response %d: %s", self, correlation_id, response) fut.set_result(response) @@ -611,23 +737,30 @@ def _handle_frame(self, resp): # this future. self._requests.popleft() - def _next_correlation_id(self): + def _next_correlation_id(self) -> int: self._correlation_id = (self._correlation_id + 1) % 2**31 return self._correlation_id class BaseSaslAuthenticator: - def step(self, payload): + # FIXME: move to __init__? + _loop: asyncio.AbstractEventLoop + _authenticator: Generator[Tuple[bytes, bool], bytes, None] + + def step(self, payload: Optional[bytes]) -> Awaitable[Optional[Tuple[bytes, bool]]]: return self._loop.run_in_executor(None, self._step, payload) - def _step(self, payload): + def _step(self, payload: Optional[bytes]) -> Optional[Tuple[bytes, bool]]: """Process next token in sequence and return with: ``None`` if it was the last needed exchange ``tuple`` tuple with new token and a boolean whether it requires an answer token """ try: - data = self._authenticator.send(payload) + if payload is None: + data = next(self._authenticator) + else: + data = self._authenticator.send(payload) except StopIteration: return None else: @@ -635,13 +768,19 @@ def _step(self, payload): class SaslPlainAuthenticator(BaseSaslAuthenticator): - def __init__(self, *, loop, sasl_plain_password, sasl_plain_username): + def __init__( + self, + *, + loop: asyncio.AbstractEventLoop, + sasl_plain_password: str, + sasl_plain_username: str, + ) -> None: self._loop = loop self._sasl_plain_username = sasl_plain_username self._sasl_plain_password = sasl_plain_password self._authenticator = self.authenticator_plain() - def authenticator_plain(self): + def authenticator_plain(self) -> Generator[Tuple[bytes, bool], bytes, None]: """Automaton to authenticate with SASL tokens""" # Send PLAIN credentials per RFC-4616 data = "\0".join( @@ -658,12 +797,12 @@ def authenticator_plain(self): class SaslGSSAPIAuthenticator(BaseSaslAuthenticator): - def __init__(self, *, loop, principal): + def __init__(self, *, loop: asyncio.AbstractEventLoop, principal: str) -> None: self._loop = loop self._principal = principal self._authenticator = self.authenticator_gssapi() - def authenticator_gssapi(self): + def authenticator_gssapi(self) -> Generator[Tuple[bytes, bool], bytes, None]: name = gssapi.Name( self._principal, name_type=gssapi.NameType.hostbased_service, @@ -679,6 +818,7 @@ def authenticator_gssapi(self): server_token = yield client_token, True + assert server_token is not None msg = client_ctx.unwrap(server_token).message qop = struct.pack("b", SASL_QOP_AUTH & msg[0]) @@ -697,33 +837,33 @@ class ScramAuthenticator(BaseSaslAuthenticator): def __init__( self, *, - loop, - sasl_plain_password, - sasl_plain_username, - sasl_mechanism, - ): + loop: asyncio.AbstractEventLoop, + sasl_plain_password: str, + sasl_plain_username: str, + sasl_mechanism: Literal["SCRAM-SHA-256", "SCRAM-SHA-512"], + ) -> None: self._loop = loop self._nonce = str(uuid.uuid4()).replace("-", "") self._auth_message = "" - self._salted_password = None + self._salted_password: Optional[bytes] = None self._sasl_plain_username = sasl_plain_username self._sasl_plain_password = sasl_plain_password.encode("utf-8") self._hashfunc = self.MECHANISMS[sasl_mechanism] self._hashname = "".join(sasl_mechanism.lower().split("-")[1:3]) - self._stored_key = None - self._client_key = None - self._client_signature = None - self._client_proof = None - self._server_key = None - self._server_signature = None + self._stored_key: Optional[bytes] = None + self._client_key: Optional[bytes] = None + self._client_signature: Optional[bytes] = None + self._client_proof: Optional[bytes] = None + self._server_key: Optional[bytes] = None + self._server_signature: Optional[bytes] = None self._authenticator = self.authenticator_scram() - def first_message(self): + def first_message(self) -> str: client_first_bare = f"n={self._sasl_plain_username},r={self._nonce}" self._auth_message += client_first_bare return "n,," + client_first_bare - def process_server_first_message(self, server_first): + def process_server_first_message(self, server_first: str) -> None: self._auth_message += "," + server_first params = dict(pair.split("=", 1) for pair in server_first.split(",")) server_nonce = params["r"] @@ -734,8 +874,10 @@ def process_server_first_message(self, server_first): salt = base64.b64decode(params["s"].encode("utf-8")) iterations = int(params["i"]) - self.create_salted_password(salt, iterations) + self._salted_password = hashlib.pbkdf2_hmac( + self._hashname, self._sasl_plain_password, salt, iterations + ) self._client_key = self.hmac(self._salted_password, b"Client Key") self._stored_key = self._hashfunc(self._client_key).digest() self._client_signature = self.hmac( @@ -749,16 +891,17 @@ def process_server_first_message(self, server_first): self._server_key, self._auth_message.encode("utf-8") ) - def final_message(self): + def final_message(self) -> str: + assert self._client_proof is not None client_proof = base64.b64encode(self._client_proof).decode("utf-8") return f"c=biws,r={self._nonce},p={client_proof}" - def process_server_final_message(self, server_final): + def process_server_final_message(self, server_final: str) -> None: params = dict(pair.split("=", 1) for pair in server_final.split(",")) if self._server_signature != base64.b64decode(params["v"].encode("utf-8")): raise ValueError("Server sent wrong signature!") - def authenticator_scram(self): + def authenticator_scram(self) -> Generator[Tuple[bytes, bool], bytes, None]: client_first = self.first_message().encode("utf-8") server_first = yield client_first, True self.process_server_first_message(server_first.decode("utf-8")) @@ -766,25 +909,22 @@ def authenticator_scram(self): server_final = yield client_final, True self.process_server_final_message(server_final.decode("utf-8")) - def hmac(self, key, msg): + def hmac(self, key: bytes, msg: Buffer) -> bytes: return hmac.new(key, msg, digestmod=self._hashfunc).digest() - def create_salted_password(self, salt, iterations): - self._salted_password = hashlib.pbkdf2_hmac( - self._hashname, self._sasl_plain_password, salt, iterations - ) - @staticmethod - def _xor_bytes(left, right): + def _xor_bytes(left: Iterable[int], right: Iterable[int]) -> bytes: return bytes(lb ^ rb for lb, rb in zip(left, right)) class OAuthAuthenticator(BaseSaslAuthenticator): - def __init__(self, *, sasl_oauth_token_provider): + def __init__(self, *, sasl_oauth_token_provider: AbstractTokenProvider) -> None: self._sasl_oauth_token_provider = sasl_oauth_token_provider self._token_sent = False - async def step(self, payload): + async def step( + self, payload: Optional[bytes] + ) -> Optional[Tuple[bytes, Literal[True]]]: if self._token_sent: return None token = await self._sasl_oauth_token_provider.token() @@ -795,10 +935,10 @@ async def step(self, payload): True, ) - def _build_oauth_client_request(self, token, token_extensions): + def _build_oauth_client_request(self, token: str, token_extensions: str) -> str: return f"n,,\x01auth=Bearer {token}{token_extensions}\x01\x01" - def _token_extensions(self): + def _token_extensions(self) -> str: """ Return a string representation of the OPTIONAL key-value pairs that can be sent with an OAUTHBEARER initial request. @@ -815,7 +955,7 @@ def _token_extensions(self): return "" -def _address_family(address): +def _address_family(address: str) -> socket.AddressFamily: """ Attempt to determine the family of an address (or hostname) @@ -834,7 +974,7 @@ def _address_family(address): return socket.AF_UNSPEC -def get_ip_port_afi(host_and_port_str): +def get_ip_port_afi(host_and_port_str: str) -> Tuple[str, int, socket.AddressFamily]: """ Parse the IP and port from a string in the format of: @@ -879,14 +1019,16 @@ def get_ip_port_afi(host_and_port_str): pass else: return host_and_port_str, DEFAULT_KAFKA_PORT, socket.AF_INET6 - host, port = host_and_port_str.rsplit(":", 1) - port = int(port) + host, port_str = host_and_port_str.rsplit(":", 1) + port = int(port_str) af = _address_family(host) return host, port, af -def collect_hosts(hosts, randomize=True): +def collect_hosts( + hosts: Union[str, Iterable[str]], randomize: bool = True +) -> List[Tuple[str, int, socket.AddressFamily]]: """ Collects a comma-separated set of hosts (host:port) and optionally randomize the returned list. diff --git a/aiokafka/protocol/admin.py b/aiokafka/protocol/admin.py index 2f374286..410ff47c 100644 --- a/aiokafka/protocol/admin.py +++ b/aiokafka/protocol/admin.py @@ -1,4 +1,4 @@ -from typing import Dict, Iterable, Optional, Tuple +from typing import Dict, Iterable, Optional, Sequence, Tuple from .api import Request, Response from .types import ( @@ -29,6 +29,9 @@ class ApiVersionResponse_v0(Response): ), ) + error_code: int + api_versions: Sequence[Tuple[int, int, int]] + class ApiVersionResponse_v1(Response): API_KEY = 18 @@ -42,39 +45,47 @@ class ApiVersionResponse_v1(Response): ("throttle_time_ms", Int32), ) + error_code: int + api_versions: Sequence[Tuple[int, int, int]] + throttle_time_ms: int + class ApiVersionResponse_v2(Response): API_KEY = 18 API_VERSION = 2 SCHEMA = ApiVersionResponse_v1.SCHEMA + error_code: int + api_versions: Sequence[Tuple[int, int, int]] + throttle_time_ms: int -class ApiVersionRequest_v0(Request): + +class ApiVersionRequest_v0(Request[ApiVersionResponse_v0]): API_KEY = 18 API_VERSION = 0 RESPONSE_TYPE = ApiVersionResponse_v0 SCHEMA = Schema() -class ApiVersionRequest_v1(Request): +class ApiVersionRequest_v1(Request[ApiVersionResponse_v1]): API_KEY = 18 API_VERSION = 1 RESPONSE_TYPE = ApiVersionResponse_v1 SCHEMA = ApiVersionRequest_v0.SCHEMA -class ApiVersionRequest_v2(Request): +class ApiVersionRequest_v2(Request[ApiVersionResponse_v1]): API_KEY = 18 API_VERSION = 2 - RESPONSE_TYPE = ApiVersionResponse_v1 + RESPONSE_TYPE = ApiVersionResponse_v1 # TODO: Why v1? SCHEMA = ApiVersionRequest_v0.SCHEMA -ApiVersionRequest = [ +ApiVersionRequest = ( ApiVersionRequest_v0, ApiVersionRequest_v1, ApiVersionRequest_v2, -] +) ApiVersionResponse = [ ApiVersionResponse_v0, ApiVersionResponse_v1, @@ -488,29 +499,35 @@ class SaslHandShakeResponse_v0(Response): ("error_code", Int16), ("enabled_mechanisms", Array(String("utf-8"))) ) + error_code: int + enabled_mechanisms: Sequence[str] + class SaslHandShakeResponse_v1(Response): API_KEY = 17 API_VERSION = 1 SCHEMA = SaslHandShakeResponse_v0.SCHEMA + error_code: int + enabled_mechanisms: Sequence[str] -class SaslHandShakeRequest_v0(Request): + +class SaslHandShakeRequest_v0(Request[SaslHandShakeResponse_v0]): API_KEY = 17 API_VERSION = 0 RESPONSE_TYPE = SaslHandShakeResponse_v0 SCHEMA = Schema(("mechanism", String("utf-8"))) -class SaslHandShakeRequest_v1(Request): +class SaslHandShakeRequest_v1(Request[SaslHandShakeResponse_v1]): API_KEY = 17 API_VERSION = 1 RESPONSE_TYPE = SaslHandShakeResponse_v1 SCHEMA = SaslHandShakeRequest_v0.SCHEMA -SaslHandShakeRequest = [SaslHandShakeRequest_v0, SaslHandShakeRequest_v1] -SaslHandShakeResponse = [SaslHandShakeResponse_v0, SaslHandShakeResponse_v1] +SaslHandShakeRequest = (SaslHandShakeRequest_v0, SaslHandShakeRequest_v1) +SaslHandShakeResponse = (SaslHandShakeResponse_v0, SaslHandShakeResponse_v1) class DescribeAclsResponse_v0(Response): @@ -992,6 +1009,10 @@ class SaslAuthenticateResponse_v0(Response): ("sasl_auth_bytes", Bytes), ) + error_code: int + error_message: str + sasl_auth_bytes: bytes + class SaslAuthenticateResponse_v1(Response): API_KEY = 36 @@ -1003,29 +1024,34 @@ class SaslAuthenticateResponse_v1(Response): ("session_lifetime_ms", Int64), ) + error_code: int + error_message: str + sasl_auth_bytes: bytes + session_lifetime_ms: int -class SaslAuthenticateRequest_v0(Request): + +class SaslAuthenticateRequest_v0(Request[SaslAuthenticateResponse_v0]): API_KEY = 36 API_VERSION = 0 RESPONSE_TYPE = SaslAuthenticateResponse_v0 SCHEMA = Schema(("sasl_auth_bytes", Bytes)) -class SaslAuthenticateRequest_v1(Request): +class SaslAuthenticateRequest_v1(Request[SaslAuthenticateResponse_v1]): API_KEY = 36 API_VERSION = 1 RESPONSE_TYPE = SaslAuthenticateResponse_v1 SCHEMA = SaslAuthenticateRequest_v0.SCHEMA -SaslAuthenticateRequest = [ +SaslAuthenticateRequest = ( SaslAuthenticateRequest_v0, SaslAuthenticateRequest_v1, -] -SaslAuthenticateResponse = [ +) +SaslAuthenticateResponse = ( SaslAuthenticateResponse_v0, SaslAuthenticateResponse_v1, -] +) class CreatePartitionsResponse_v0(Response): diff --git a/aiokafka/protocol/api.py b/aiokafka/protocol/api.py index 1e6ee3b6..c6c5a4ba 100644 --- a/aiokafka/protocol/api.py +++ b/aiokafka/protocol/api.py @@ -2,11 +2,17 @@ import abc from io import BytesIO -from typing import Any, ClassVar, Dict, Optional, Type, Union +from typing import Any, ClassVar, Dict, Generic, Optional, Type, Union + +from typing_extensions import TypeVar from .struct import Struct from .types import Array, Int16, Int32, Schema, String, TaggedFields +ResponseT_co = TypeVar( + "ResponseT_co", bound="Response", default="Response", covariant=True +) + class RequestHeader_v0(Struct): SCHEMA = Schema( @@ -17,7 +23,10 @@ class RequestHeader_v0(Struct): ) def __init__( - self, request: Request, correlation_id: int = 0, client_id: str = "aiokafka" + self, + request: Request[Any], + correlation_id: int = 0, + client_id: str = "aiokafka", ) -> None: super().__init__( request.API_KEY, request.API_VERSION, correlation_id, client_id @@ -36,7 +45,7 @@ class RequestHeader_v1(Struct): def __init__( self, - request: Request, + request: Request[Any], correlation_id: int = 0, client_id: str = "aiokafka", tags: Optional[Dict[int, bytes]] = None, @@ -51,6 +60,8 @@ class ResponseHeader_v0(Struct): ("correlation_id", Int32), ) + correlation_id: int + class ResponseHeader_v1(Struct): SCHEMA = Schema( @@ -58,8 +69,10 @@ class ResponseHeader_v1(Struct): ("tags", TaggedFields), ) + correlation_id: int + -class Request(Struct, metaclass=abc.ABCMeta): +class Request(Struct, Generic[ResponseT_co], metaclass=abc.ABCMeta): FLEXIBLE_VERSION: ClassVar[bool] = False @property @@ -74,7 +87,7 @@ def API_VERSION(self) -> int: @property @abc.abstractmethod - def RESPONSE_TYPE(self) -> Type[Response]: + def RESPONSE_TYPE(self) -> Type[ResponseT_co]: """The Response class associated with the api request""" @property diff --git a/aiokafka/protocol/commit.py b/aiokafka/protocol/commit.py index b0fda8c3..8305d360 100644 --- a/aiokafka/protocol/commit.py +++ b/aiokafka/protocol/commit.py @@ -275,6 +275,11 @@ class GroupCoordinatorResponse_v0(Response): ("port", Int32), ) + error_code: int + coordinator_id: int + host: str + port: int + class GroupCoordinatorResponse_v1(Response): API_KEY = 10 @@ -288,8 +293,15 @@ class GroupCoordinatorResponse_v1(Response): ("port", Int32), ) + throttle_time_ms: int + error_code: int + error_message: str + coordinator_id: int + host: str + port: int + -class GroupCoordinatorRequest_v0(Request): +class GroupCoordinatorRequest_v0(Request[GroupCoordinatorResponse_v0]): API_KEY = 10 API_VERSION = 0 RESPONSE_TYPE = GroupCoordinatorResponse_v0 @@ -298,7 +310,7 @@ class GroupCoordinatorRequest_v0(Request): ) -class GroupCoordinatorRequest_v1(Request): +class GroupCoordinatorRequest_v1(Request[GroupCoordinatorResponse_v1]): API_KEY = 10 API_VERSION = 1 RESPONSE_TYPE = GroupCoordinatorResponse_v1 diff --git a/tests/_testutil.py b/tests/_testutil.py index 67cd2f75..049afbc3 100644 --- a/tests/_testutil.py +++ b/tests/_testutil.py @@ -13,6 +13,7 @@ from concurrent import futures from contextlib import contextmanager from functools import wraps +from typing import List from unittest.mock import Mock import pytest @@ -352,6 +353,20 @@ def kdestroy(self): class KafkaIntegrationTestCase(unittest.TestCase): topic = None + # from setup_test_class fixture + loop: asyncio.AbstractEventLoop + kafka_host: str + kafka_port: int + kafka_ssl_port: int + kafka_sasl_plain_port: int + kafka_sasl_ssl_port: int + ssl_folder: pathlib.Path + acl_manager: ACLManager + kerberos_utils: KerberosUtils + kafka_config: KafkaConfig + hosts: List[str] + kafka_version: str + @contextmanager def silence_loop_exception_handler(self): if hasattr(self.loop, "get_exception_handler"): diff --git a/tests/conftest.py b/tests/conftest.py index d582386c..6231d64c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import asyncio import gc import logging @@ -7,6 +9,7 @@ import sys import uuid from dataclasses import dataclass +from typing import Generator import pytest @@ -21,7 +24,7 @@ ) from aiokafka.util import NO_EXTENSIONS -from ._testutil import wait_kafka +from ._testutil import ACLManager, KafkaConfig, KerberosUtils, wait_kafka if not NO_EXTENSIONS: assert ( @@ -67,23 +70,23 @@ def docker(request): @pytest.fixture(scope="class") -def acl_manager(kafka_server, request): +def acl_manager( + kafka_server: KafkaServer, request: pytest.FixtureRequest +) -> ACLManager: image = request.config.getoption("--docker-image") tag = image.split(":")[-1].replace("_", "-") - from ._testutil import ACLManager - manager = ACLManager(kafka_server.container, tag) return manager @pytest.fixture(scope="class") -def kafka_config(kafka_server, request): +def kafka_config( + kafka_server: KafkaServer, request: pytest.FixtureRequest +) -> KafkaConfig: image = request.config.getoption("--docker-image") tag = image.split(":")[-1].replace("_", "-") - from ._testutil import KafkaConfig - manager = KafkaConfig(kafka_server.container, tag) return manager @@ -91,9 +94,7 @@ def kafka_config(kafka_server, request): if sys.platform != "win32": @pytest.fixture(scope="class") - def kerberos_utils(kafka_server): - from ._testutil import KerberosUtils - + def kerberos_utils(kafka_server: KafkaServer) -> KerberosUtils: utils = KerberosUtils(kafka_server.container) utils.create_keytab() return utils @@ -124,7 +125,9 @@ def kafka_image(): @pytest.fixture(scope="session") -def ssl_folder(docker_ip_address, docker, kafka_image): +def ssl_folder( + docker_ip_address: str, docker: libdocker.DockerClient, kafka_image: str +) -> pathlib.Path: ssl_dir = pathlib.Path("tests/ssl_cert") if ssl_dir.is_dir(): # Skip generating certificates when they already exist. Remove @@ -171,7 +174,7 @@ def ssl_folder(docker_ip_address, docker, kafka_image): @pytest.fixture(scope="session") -def docker_ip_address(): +def docker_ip_address() -> str: """Returns IP address of the docker daemon service.""" return "127.0.0.1" @@ -210,7 +213,7 @@ def hosts(self): @pytest.fixture(scope="session") def kafka_server( kafka_image, docker, docker_ip_address, unused_port, session_id, ssl_folder - ): + ) -> Generator[KafkaServer, None, None]: kafka_host = docker_ip_address kafka_port = unused_port() kafka_ssl_port = unused_port() @@ -316,8 +319,14 @@ def setup_test_class_serverless(request, loop): @pytest.fixture(scope="class") def setup_test_class( - request, loop, kafka_server, ssl_folder, acl_manager, kerberos_utils, kafka_config -): + request: pytest.FixtureRequest, + loop: asyncio.AbstractEventLoop, + kafka_server: KafkaServer, + ssl_folder: pathlib.Path, + acl_manager: ACLManager, + kerberos_utils: KerberosUtils, + kafka_config: KafkaConfig, +) -> None: request.cls.loop = loop request.cls.kafka_host = kafka_server.host request.cls.kafka_port = kafka_server.port diff --git a/tests/test_conn.py b/tests/test_conn.py index f0f4a075..11cbb3f3 100644 --- a/tests/test_conn.py +++ b/tests/test_conn.py @@ -1,13 +1,14 @@ import asyncio import gc import struct -from typing import Any +from typing import Any, List, NoReturn, Type from unittest import mock import pytest -from aiokafka.conn import AIOKafkaConnection, VersionInfo, create_conn +from aiokafka.conn import AIOKafkaConnection, SaslPacket, VersionInfo, create_conn from aiokafka.errors import ( + BrokerResponseError, CorrelationIdError, IllegalSaslStateError, KafkaConnectionError, @@ -22,6 +23,7 @@ SaslHandShakeRequest, SaslHandShakeResponse, ) +from aiokafka.protocol.api import Request, Response from aiokafka.protocol.commit import ( GroupCoordinatorRequest_v0 as GroupCoordinatorRequest, ) @@ -40,7 +42,7 @@ @pytest.mark.usefixtures("setup_test_class") class ConnIntegrationTest(KafkaIntegrationTestCase): @run_until_complete - async def test_ctor(self): + async def test_ctor(self) -> None: conn = AIOKafkaConnection("localhost", 1234) self.assertEqual("localhost", conn.host) self.assertEqual(1234, conn.port) @@ -49,7 +51,7 @@ async def test_ctor(self): self.assertIsNone(conn._writer) @run_until_complete - async def test_global_loop_for_create_conn(self): + async def test_global_loop_for_create_conn(self) -> None: loop = get_running_loop() host, port = self.kafka_host, self.kafka_port conn = await create_conn(host, port) @@ -60,7 +62,7 @@ async def test_global_loop_for_create_conn(self): conn.close() @run_until_complete - async def test_conn_warn_unclosed(self): + async def test_conn_warn_unclosed(self) -> None: host, port = self.kafka_host, self.kafka_port conn = await create_conn(host, port, max_idle_ms=100000) @@ -70,7 +72,7 @@ async def test_conn_warn_unclosed(self): gc.collect() @run_until_complete - async def test_basic_connection_load_meta(self): + async def test_basic_connection_load_meta(self) -> None: host, port = self.kafka_host, self.kafka_port conn = await create_conn(host, port) @@ -81,7 +83,7 @@ async def test_basic_connection_load_meta(self): self.assertIsInstance(response, MetadataResponse) @run_until_complete - async def test_connections_max_idle_ms(self): + async def test_connections_max_idle_ms(self) -> None: host, port = self.kafka_host, self.kafka_port conn = await create_conn(host, port, max_idle_ms=200) self.assertEqual(conn.connected(), True) @@ -94,10 +96,11 @@ async def test_connections_max_idle_ms(self): self.assertEqual(conn.connected(), True) # It shouldn't break if we have a long running call either + assert conn._reader is not None readexactly = conn._reader.readexactly with mock.patch.object(conn._reader, "readexactly") as mocked: - async def long_read(n): + async def long_read(n: int) -> bytes: await asyncio.sleep(0.2) return await readexactly(n) @@ -109,7 +112,7 @@ async def long_read(n): self.assertEqual(conn.connected(), False) @run_until_complete - async def test_send_without_response(self): + async def test_send_without_response(self) -> None: """Imitate producer without acknowledge, in this case client produces messages and kafka does not send response, and we make sure that futures do not stuck in queue forever""" @@ -137,7 +140,7 @@ async def test_send_without_response(self): conn.close() @run_until_complete - async def test_send_to_closed(self): + async def test_send_to_closed(self) -> None: host, port = self.kafka_host, self.kafka_port conn = AIOKafkaConnection(host=host, port=port) request = MetadataRequest([]) @@ -151,7 +154,7 @@ async def test_send_to_closed(self): await conn.send(request) @run_until_complete - async def test_invalid_correlation_id(self): + async def test_invalid_correlation_id(self) -> None: host, port = self.kafka_host, self.kafka_port request = MetadataRequest([]) @@ -163,14 +166,14 @@ async def test_invalid_correlation_id(self): reader = mock.MagicMock() int32 = struct.Struct(">i") resp = MetadataResponse(brokers=[], topics=[]) - resp = resp.encode() - resp = int32.pack(999) + resp # set invalid correlation id + resp_bytes = resp.encode() + resp_bytes = int32.pack(999) + resp_bytes # set invalid correlation id - async def first_resp(*args: Any, **kw: Any): - return int32.pack(len(resp)) + async def first_resp(*args: Any, **kw: Any) -> bytes: + return int32.pack(len(resp_bytes)) - async def second_resp(*args: Any, **kw: Any): - return resp + async def second_resp(*args: Any, **kw: Any) -> bytes: + return resp_bytes reader.readexactly.side_effect = [first_resp(), second_resp()] writer = mock.MagicMock() @@ -184,7 +187,7 @@ async def second_resp(*args: Any, **kw: Any): await conn.send(request) @run_until_complete - async def test_correlation_id_on_group_coordinator_req(self): + async def test_correlation_id_on_group_coordinator_req(self) -> None: host, port = self.kafka_host, self.kafka_port request = GroupCoordinatorRequest(consumer_group="test") @@ -198,14 +201,14 @@ async def test_correlation_id_on_group_coordinator_req(self): resp = GroupCoordinatorResponse( error_code=0, coordinator_id=22, host="127.0.0.1", port=3333 ) - resp = resp.encode() - resp = int32.pack(0) + resp # set correlation id to 0 + resp_bytes = resp.encode() + resp_bytes = int32.pack(0) + resp_bytes # set correlation id to 0 - async def first_resp(*args: Any, **kw: Any): - return int32.pack(len(resp)) + async def first_resp(*args: Any, **kw: Any) -> bytes: + return int32.pack(len(resp_bytes)) - async def second_resp(*args: Any, **kw: Any): - return resp + async def second_resp(*args: Any, **kw: Any) -> bytes: + return resp_bytes reader.readexactly.side_effect = [first_resp(), second_resp()] writer = mock.MagicMock() @@ -223,10 +226,10 @@ async def second_resp(*args: Any, **kw: Any): self.assertEqual(response.port, 3333) @run_until_complete - async def test_osserror_in_reader_task(self): + async def test_osserror_in_reader_task(self) -> None: host, port = self.kafka_host, self.kafka_port - async def invoke_osserror(*a, **kw): + async def invoke_osserror(*a: Any, **kw: Any) -> NoReturn: await asyncio.sleep(0.1) raise OSError("test oserror") @@ -249,28 +252,28 @@ async def invoke_osserror(*a, **kw): self.assertEqual(conn.connected(), False) @run_until_complete - async def test_close_disconnects_connection(self): + async def test_close_disconnects_connection(self) -> None: host, port = self.kafka_host, self.kafka_port conn = await create_conn(host, port) self.assertTrue(conn.connected()) conn.close() self.assertFalse(conn.connected()) - def test_connection_version_info(self): + def test_connection_version_info(self) -> None: # All version supported - version_info = VersionInfo({SaslHandShakeRequest[0].API_KEY: [0, 1]}) + version_info = VersionInfo({SaslHandShakeRequest[0].API_KEY: (0, 1)}) self.assertEqual( version_info.pick_best(SaslHandShakeRequest), SaslHandShakeRequest[1] ) # Broker only supports the lesser version - version_info = VersionInfo({SaslHandShakeRequest[0].API_KEY: [0, 0]}) + version_info = VersionInfo({SaslHandShakeRequest[0].API_KEY: (0, 0)}) self.assertEqual( version_info.pick_best(SaslHandShakeRequest), SaslHandShakeRequest[0] ) # We don't support any version compatible with the broker - version_info = VersionInfo({SaslHandShakeRequest[0].API_KEY: [2, 3]}) + version_info = VersionInfo({SaslHandShakeRequest[0].API_KEY: (2, 3)}) with self.assertRaises(KafkaError): self.assertEqual( version_info.pick_best(SaslHandShakeRequest), SaslHandShakeRequest[1] @@ -283,7 +286,7 @@ def test_connection_version_info(self): ) @run_until_complete - async def test__do_sasl_handshake_v0(self): + async def test__do_sasl_handshake_v0(self) -> None: host, port = self.kafka_host, self.kafka_port # setup connection with mocked send and send_bytes @@ -294,22 +297,22 @@ async def test__do_sasl_handshake_v0(self): sasl_plain_username="admin", sasl_plain_password="123", ) - conn.close = close_mock = mock.MagicMock() + conn.close = close_mock = mock.MagicMock() # type: ignore[method-assign] supported_mechanisms = ["PLAIN"] - error_class = NoError + error_class: Type[BrokerResponseError] = NoError - async def mock_send(request, expect_response=True): + async def mock_send(request: Request, expect_response: bool = True) -> Response: return SaslHandShakeResponse[0]( error_code=error_class.errno, enabled_mechanisms=supported_mechanisms ) - async def mock_sasl_send(payload, expect_response): + async def mock_sasl_send(payload: bytes, expect_response: bool) -> bytes: return b"" - conn.send = mock.Mock(side_effect=mock_send) - conn._send_sasl_token = mock.Mock(side_effect=mock_sasl_send) - conn._version_info = VersionInfo({SaslHandShakeRequest[0].API_KEY: [0, 0]}) + conn.send = mock.Mock(side_effect=mock_send) # type: ignore[method-assign] + conn._send_sasl_token = mock.Mock(side_effect=mock_sasl_send) # type: ignore[method-assign] + conn._version_info = VersionInfo({SaslHandShakeRequest[0].API_KEY: (0, 0)}) await conn._do_sasl_handshake() @@ -326,7 +329,7 @@ async def mock_sasl_send(payload, expect_response): self.assertTrue(close_mock.call_count) @run_until_complete - async def test__do_sasl_handshake_v1(self): + async def test__do_sasl_handshake_v1(self) -> None: host, port = self.kafka_host, self.kafka_port # setup connection with mocked send and send_bytes @@ -338,13 +341,13 @@ async def test__do_sasl_handshake_v1(self): sasl_plain_password="123", security_protocol="SASL_PLAINTEXT", ) - conn.close = close_mock = mock.MagicMock() + conn.close = close_mock = mock.MagicMock() # type: ignore[method-assign] supported_mechanisms = ["PLAIN"] - error_class = NoError - auth_error_class = NoError + error_class: Type[BrokerResponseError] = NoError + auth_error_class: Type[BrokerResponseError] = NoError - async def mock_send(request, expect_response=True): + async def mock_send(request: Request, expect_response: bool = True) -> Response: if request.API_KEY == SaslHandShakeRequest[0].API_KEY: assert request.API_VERSION == 1 return SaslHandShakeResponse[1]( @@ -359,8 +362,8 @@ async def mock_send(request, expect_response=True): sasl_auth_bytes=b"", ) - conn.send = mock.Mock(side_effect=mock_send) - conn._version_info = VersionInfo({SaslHandShakeRequest[0].API_KEY: [0, 1]}) + conn.send = mock.Mock(side_effect=mock_send) # type: ignore[method-assign] + conn._version_info = VersionInfo({SaslHandShakeRequest[0].API_KEY: (0, 1)}) await conn._do_sasl_handshake() @@ -386,16 +389,16 @@ async def mock_send(request, expect_response=True): self.assertTrue(close_mock.call_count) @run_until_complete - async def test__send_sasl_token(self): + async def test__send_sasl_token(self) -> None: # Before Kafka 1.0.0 SASL was performed on the wire without # KAFKA_HEADER in the protocol. So we needed another private # function to send `raw` data with only length prefixed # setup connection with mocked transport and protocol conn = AIOKafkaConnection(host="", port=9999) - conn.close = mock.MagicMock() + conn.close = mock.MagicMock() # type: ignore[method-assign] conn._writer = mock.MagicMock() - out_buffer = [] + out_buffer: List[bytes] = [] conn._writer.write = mock.Mock(side_effect=out_buffer.append) conn._reader = mock.MagicMock() self.assertEqual(len(conn._requests), 0) @@ -407,20 +410,21 @@ async def test__send_sasl_token(self): out_buffer.clear() # Resolve the request - conn._requests[0][2].set_result(None) + assert isinstance(conn._requests[0], SaslPacket) + conn._requests[0][2].set_result(b"") conn._requests.clear() await fut # Broken pipe error conn._writer.write.side_effect = OSError with self.assertRaises(KafkaConnectionError): - conn._send_sasl_token(b"Super data") + await conn._send_sasl_token(b"Super data") self.assertEqual(out_buffer, []) self.assertEqual(len(conn._requests), 0) self.assertEqual(conn.close.call_count, 1) conn._writer = None with self.assertRaises(KafkaConnectionError): - conn._send_sasl_token(b"Super data") + await conn._send_sasl_token(b"Super data") # We don't need to close 2ce self.assertEqual(conn.close.call_count, 1)