From c5b7ca9aafa763309add8e8447e782cfd302b67b Mon Sep 17 00:00:00 2001 From: Chris Kuehl Date: Fri, 1 Nov 2024 14:55:09 -0500 Subject: [PATCH] Migrate to Ruff (#984) * Replace flake8, black, reorder-python-imports with Ruff * Apply automated Ruff lint fixes * Temporarily disable E501 Will come back to this, just want to make sure Ruff isn't breaking anything first with the existing changes. * Manual lint fixes * Revert "Temporarily disable E501" This reverts commit f593bb0e95d54807f2cc0dc83c9492c979a0ef6e. * Manual E501 lint fixes * Set target-version = "py39" and re-run linter --- Makefile | 12 +- baseplate/__init__.py | 60 ++--- baseplate/clients/__init__.py | 3 +- baseplate/clients/cassandra.py | 53 ++-- baseplate/clients/kombu.py | 17 +- baseplate/clients/memcache/__init__.py | 35 +-- baseplate/clients/memcache/lib.py | 14 +- baseplate/clients/redis.py | 22 +- baseplate/clients/redis_cluster.py | 41 ++- baseplate/clients/requests.py | 42 ++-- baseplate/clients/sqlalchemy.py | 47 ++-- baseplate/clients/thrift.py | 95 +++---- baseplate/frameworks/pyramid/__init__.py | 32 +-- baseplate/frameworks/pyramid/csrf.py | 12 +- baseplate/frameworks/queue_consumer/kafka.py | 88 +++---- baseplate/frameworks/queue_consumer/kombu.py | 49 ++-- baseplate/frameworks/thrift/__init__.py | 78 +++--- baseplate/frameworks/thrift/command.py | 1 - baseplate/healthcheck/__init__.py | 9 +- baseplate/lib/__init__.py | 10 +- baseplate/lib/_requests.py | 4 +- baseplate/lib/config.py | 42 ++-- baseplate/lib/crypto.py | 2 +- baseplate/lib/datetime.py | 4 +- baseplate/lib/edgecontext.py | 6 +- baseplate/lib/events.py | 11 +- baseplate/lib/file_watcher.py | 17 +- baseplate/lib/live_data/__init__.py | 1 - baseplate/lib/live_data/writer.py | 6 +- baseplate/lib/live_data/zookeeper.py | 1 + baseplate/lib/message_queue.py | 2 +- baseplate/lib/metrics.py | 48 ++-- baseplate/lib/prometheus_metrics.py | 7 +- baseplate/lib/propagator_redditb3_http.py | 25 +- baseplate/lib/propagator_redditb3_thrift.py | 25 +- baseplate/lib/random.py | 20 +- baseplate/lib/ratelimit/__init__.py | 9 +- baseplate/lib/ratelimit/backends/memcache.py | 6 +- baseplate/lib/ratelimit/backends/redis.py | 6 +- baseplate/lib/retry.py | 4 +- baseplate/lib/secrets.py | 57 ++--- baseplate/lib/service_discovery.py | 17 +- baseplate/lib/thrift_pool.py | 19 +- baseplate/lib/tracing.py | 15 +- .../lint/db_query_string_format_plugin.py | 2 +- baseplate/lint/example_plugin.py | 24 +- baseplate/observers/logging.py | 4 +- baseplate/observers/metrics.py | 15 +- baseplate/observers/metrics_tagged.py | 38 ++- baseplate/observers/sentry.py | 24 +- baseplate/observers/timeout.py | 6 +- baseplate/observers/tracing.py | 57 ++--- baseplate/server/__init__.py | 78 +++--- baseplate/server/__main__.py | 1 - baseplate/server/einhorn.py | 1 + baseplate/server/monkey.py | 1 + baseplate/server/prometheus.py | 31 +-- baseplate/server/queue_consumer.py | 38 ++- baseplate/server/reloader.py | 11 +- baseplate/server/runtime_monitor.py | 48 ++-- baseplate/server/thrift.py | 14 +- baseplate/server/wsgi.py | 13 +- baseplate/sidecars/__init__.py | 7 +- baseplate/sidecars/event_publisher.py | 28 +-- baseplate/sidecars/live_data_watcher.py | 18 +- baseplate/sidecars/secrets_fetcher.py | 31 +-- baseplate/sidecars/trace_publisher.py | 23 +- baseplate/testing/lib/file_watcher.py | 12 +- baseplate/testing/lib/secrets.py | 7 +- docs/conf.py | 4 +- docs/pyproject.toml | 8 - docs/tutorial/chapter3/helloworld.py | 6 +- docs/tutorial/chapter4/helloworld.py | 6 +- poetry.lock | 236 +++--------------- 
pyproject.toml | 23 +- setup.cfg | 11 - tests/__init__.py | 1 - tests/integration/__init__.py | 11 +- tests/integration/cassandra_tests.py | 6 +- tests/integration/live_data/writer_tests.py | 2 - .../integration/live_data/zookeeper_tests.py | 1 - tests/integration/memcache_tests.py | 4 +- tests/integration/message_queue_tests.py | 3 +- tests/integration/otel_pyramid_tests.py | 15 +- tests/integration/otel_thrift_tests.py | 53 ++-- tests/integration/pyramid_tests.py | 16 +- tests/integration/ratelimit_tests.py | 7 +- tests/integration/redis_cluster_tests.py | 9 +- tests/integration/redis_testcase.py | 4 +- tests/integration/redis_tests.py | 18 +- tests/integration/requests_tests.py | 7 +- tests/integration/sqlalchemy_tests.py | 5 +- tests/integration/thrift_tests.py | 18 +- tests/integration/timeout_tests.py | 3 +- tests/integration/tracing_tests.py | 40 +-- tests/unit/clients/cassandra_tests.py | 17 +- tests/unit/clients/kombu_tests.py | 15 +- tests/unit/clients/memcache_tests.py | 7 +- tests/unit/clients/redis_cluster_tests.py | 32 +-- tests/unit/clients/redis_tests.py | 36 +-- tests/unit/clients/requests_tests.py | 15 +- tests/unit/clients/sqlalchemy_tests.py | 11 +- tests/unit/clients/thrift_tests.py | 22 +- tests/unit/core_tests.py | 25 +- tests/unit/frameworks/pyramid/csrf_tests.py | 4 +- .../pyramid/http_server_prom_tests.py | 18 +- .../frameworks/queue_consumer/kafka_tests.py | 24 +- .../frameworks/queue_consumer/kombu_tests.py | 90 +++---- tests/unit/frameworks/thrift_tests.py | 17 +- tests/unit/lib/config_tests.py | 1 - tests/unit/lib/crypto_tests.py | 2 - tests/unit/lib/datetime_tests.py | 17 +- tests/unit/lib/events/publisher_tests.py | 7 +- tests/unit/lib/events/queue_tests.py | 9 +- tests/unit/lib/file_watcher_tests.py | 1 - tests/unit/lib/metrics_tests.py | 5 +- tests/unit/lib/random_tests.py | 1 - tests/unit/lib/ratelimit_tests.py | 1 - tests/unit/lib/retry_tests.py | 13 +- tests/unit/lib/secrets/store_tests.py | 14 +- tests/unit/lib/secrets/vault_csi_tests.py | 34 +-- tests/unit/lib/service_discovery_tests.py | 5 +- tests/unit/lib/thrift_pool_tests.py | 12 +- tests/unit/observers/metrics_tagged_tests.py | 36 ++- tests/unit/observers/metrics_tests.py | 20 +- tests/unit/observers/sentry_tests.py | 11 +- .../unit/observers/tracing/publisher_tests.py | 4 +- tests/unit/observers/tracing_tests.py | 25 +- tests/unit/server/einhorn_tests.py | 1 - tests/unit/server/monkey_tests.py | 3 +- tests/unit/server/queue_consumer_tests.py | 17 +- tests/unit/server/server_tests.py | 2 - .../live_data_watcher_loader_tests.py | 14 +- .../unit/sidecars/live_data_watcher_tests.py | 9 +- 134 files changed, 1071 insertions(+), 1649 deletions(-) delete mode 100644 docs/pyproject.toml diff --git a/Makefile b/Makefile index aaf8c050c..a3d6648e5 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,3 @@ -REORDER_PYTHON_IMPORTS := reorder-python-imports --py3-plus --separate-from-import --separate-relative PYTHON_SOURCE = $(shell find baseplate/ tests/ -name '*.py') PYTHON_EXAMPLES = $(shell find docs/ -name '*.py') @@ -53,16 +52,13 @@ test: doctest .venv .PHONY: fmt fmt: .venv - .venv/bin/$(REORDER_PYTHON_IMPORTS) --exit-zero-even-if-changed $(PYTHON_SOURCE) - .venv/bin/black baseplate/ tests/ - .venv/bin/$(REORDER_PYTHON_IMPORTS) --application-directories /tmp --exit-zero-even-if-changed $(PYTHON_EXAMPLES) - .venv/bin/black docs/ # separate so it uses its own pyproject.toml + .venv/bin/ruff check --fix + .venv/bin/ruff format .PHONY: lint lint: .venv - .venv/bin/$(REORDER_PYTHON_IMPORTS) --diff-only 
$(PYTHON_SOURCE) - .venv/bin/black --diff --check baseplate/ tests/ - .venv/bin/flake8 baseplate tests + .venv/bin/ruff check + .venv/bin/ruff format --check PYTHONPATH=. .venv/bin/pylint baseplate/ .venv/bin/mypy baseplate/ diff --git a/baseplate/__init__.py b/baseplate/__init__.py index d3157c6d3..77d2bc56a 100644 --- a/baseplate/__init__.py +++ b/baseplate/__init__.py @@ -1,29 +1,15 @@ import logging import os import random - +from collections.abc import Iterator from contextlib import contextmanager from types import TracebackType -from typing import Any -from typing import Callable -from typing import Dict -from typing import Iterator -from typing import List -from typing import NamedTuple -from typing import Optional -from typing import Tuple -from typing import Type +from typing import Any, Callable, NamedTuple, Optional import gevent.monkey +from pkg_resources import DistributionNotFound, get_distribution -from pkg_resources import DistributionNotFound -from pkg_resources import get_distribution - -from baseplate.lib import config -from baseplate.lib import get_calling_module_name -from baseplate.lib import metrics -from baseplate.lib import UnknownCallerError - +from baseplate.lib import UnknownCallerError, config, get_calling_module_name, metrics try: __version__ = get_distribution(__name__).version @@ -51,7 +37,7 @@ def on_server_span_created(self, context: "RequestContext", server_span: "Server raise NotImplementedError -_ExcInfo = Tuple[Optional[Type[BaseException]], Optional[BaseException], Optional[TracebackType]] +_ExcInfo = tuple[Optional[type[BaseException]], Optional[BaseException], Optional[TracebackType]] class SpanObserver: @@ -157,7 +143,7 @@ def from_upstream( raise ValueError("invalid sampled value") if flags is not None: - if not 0 <= flags < 2 ** 64: + if not 0 <= flags < 2**64: raise ValueError("invalid flags value") return cls(trace_id, parent_id, span_id, sampled, flags) @@ -182,7 +168,7 @@ class RequestContext: def __init__( self, - context_config: Dict[str, Any], + context_config: dict[str, Any], prefix: Optional[str] = None, span: Optional["Span"] = None, wrapped: Optional["RequestContext"] = None, @@ -197,7 +183,7 @@ def __init__( # reference. so we fake it here and say "trust us". # # this would be much cleaner with a different API but this is where we are. - self.span: "Span" = span # type: ignore + self.span: Span = span # type: ignore def __getattr__(self, name: str) -> Any: try: @@ -279,9 +265,9 @@ def __init__(self, app_config: Optional[config.RawConfig] = None) -> None: ... 
""" - self.observers: List[BaseplateObserver] = [] + self.observers: list[BaseplateObserver] = [] self._metrics_client: Optional[metrics.Client] = None - self._context_config: Dict[str, Any] = {} + self._context_config: dict[str, Any] = {} self._app_config = app_config or {} self.service_name = self._app_config.get("baseplate.service_name") @@ -353,8 +339,10 @@ def configure_observers(self) -> None: skipped.append("metrics") if "tracing.service_name" in self._app_config: - from baseplate.observers.tracing import tracing_client_from_config - from baseplate.observers.tracing import TraceBaseplateObserver + from baseplate.observers.tracing import ( + TraceBaseplateObserver, + tracing_client_from_config, + ) tracing_client = tracing_client_from_config(self._app_config) self.register(TraceBaseplateObserver(tracing_client)) @@ -362,9 +350,11 @@ def configure_observers(self) -> None: skipped.append("tracing") if "sentry.dsn" in self._app_config or "SENTRY_DSN" in os.environ: - from baseplate.observers.sentry import init_sentry_client_from_config - from baseplate.observers.sentry import SentryBaseplateObserver - from baseplate.observers.sentry import _SentryUnhandledErrorReporter + from baseplate.observers.sentry import ( + SentryBaseplateObserver, + _SentryUnhandledErrorReporter, + init_sentry_client_from_config, + ) init_sentry_client_from_config(self._app_config) _SentryUnhandledErrorReporter.install() @@ -377,7 +367,7 @@ def configure_observers(self) -> None: "The following observers are unconfigured and won't run: %s", ", ".join(skipped) ) - def configure_context(self, context_spec: Dict[str, Any]) -> None: + def configure_context(self, context_spec: dict[str, Any]) -> None: """Add a number of objects to each request's context object. Configure and attach multiple clients to the @@ -509,8 +499,8 @@ def server_context(self, name: str) -> Iterator[RequestContext]: with self.make_server_span(context, name): yield context - def get_runtime_metric_reporters(self) -> Dict[str, Callable[[Any], None]]: - specs: List[Tuple[Optional[str], Dict[str, Any]]] = [(None, self._context_config)] + def get_runtime_metric_reporters(self) -> dict[str, Callable[[Any], None]]: + specs: list[tuple[Optional[str], dict[str, Any]]] = [(None, self._context_config)] result = {} while specs: prefix, spec = specs.pop(0) @@ -550,7 +540,7 @@ def __init__( self.context = context self.baseplate = baseplate self.component_name: Optional[str] = None - self.observers: List[SpanObserver] = [] + self.observers: list[SpanObserver] = [] def register(self, observer: SpanObserver) -> None: """Register an observer to receive events from this span.""" @@ -640,7 +630,7 @@ def __enter__(self) -> "Span": def __exit__( self, - exc_type: Optional[Type[BaseException]], + exc_type: Optional[type[BaseException]], value: Optional[BaseException], traceback: Optional[TracebackType], ) -> None: @@ -655,7 +645,7 @@ def make_child( """Return a child Span whose parent is this Span.""" raise NotImplementedError - def with_tags(self, tags: Dict[str, Any]) -> "Span": + def with_tags(self, tags: dict[str, Any]) -> "Span": """Declare a set of tags to be added to a span before starting it in the context manager. Can be used as follow: diff --git a/baseplate/clients/__init__.py b/baseplate/clients/__init__.py index 3a79d7cbe..9af2eb015 100644 --- a/baseplate/clients/__init__.py +++ b/baseplate/clients/__init__.py @@ -5,9 +5,8 @@ trace information is passed on and metrics are collected automatically. 
""" -from typing import Any -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any if TYPE_CHECKING: import baseplate.lib.metrics diff --git a/baseplate/clients/cassandra.py b/baseplate/clients/cassandra.py index 54ba014ee..49f949321 100644 --- a/baseplate/clients/cassandra.py +++ b/baseplate/clients/cassandra.py @@ -1,31 +1,30 @@ import logging import time - +from collections.abc import Mapping, Sequence from threading import Event -from typing import Any -from typing import Callable -from typing import Dict -from typing import List -from typing import Mapping -from typing import NamedTuple -from typing import Optional -from typing import Sequence -from typing import Tuple -from typing import TYPE_CHECKING -from typing import Union +from typing import ( + TYPE_CHECKING, + Any, + Callable, + NamedTuple, + Optional, + Union, +) from cassandra.auth import PlainTextAuthProvider -from cassandra.cluster import _NOT_SET # pylint: disable=no-name-in-module -from cassandra.cluster import Cluster # pylint: disable=no-name-in-module -from cassandra.cluster import ExecutionProfile # pylint: disable=no-name-in-module -from cassandra.cluster import ResponseFuture # pylint: disable=no-name-in-module -from cassandra.cluster import Session # pylint: disable=no-name-in-module -from cassandra.query import BoundStatement # pylint: disable=no-name-in-module -from cassandra.query import PreparedStatement # pylint: disable=no-name-in-module -from cassandra.query import SimpleStatement # pylint: disable=no-name-in-module -from prometheus_client import Counter -from prometheus_client import Gauge -from prometheus_client import Histogram +from cassandra.cluster import ( # pylint: disable=no-name-in-module + _NOT_SET, + Cluster, + ExecutionProfile, + ResponseFuture, + Session, +) +from cassandra.query import ( # pylint: disable=no-name-in-module + BoundStatement, + PreparedStatement, + SimpleStatement, +) +from prometheus_client import Counter, Gauge, Histogram from baseplate import Span from baseplate.clients import ContextFactory @@ -70,7 +69,7 @@ def cluster_from_config( app_config: config.RawConfig, secrets: Optional[SecretsStore] = None, prefix: str = "cassandra.", - execution_profiles: Optional[Dict[str, ExecutionProfile]] = None, + execution_profiles: Optional[dict[str, ExecutionProfile]] = None, **kwargs: Any, ) -> Cluster: """Make a Cluster from a configuration dictionary. 
@@ -171,7 +170,7 @@ def __init__( prometheus_cluster_name: Optional[str] = None, ): self.session = session - self.prepared_statements: Dict[str, PreparedStatement] = {} + self.prepared_statements: dict[str, PreparedStatement] = {} self.prometheus_client_name = prometheus_client_name self.prometheus_cluster_name = prometheus_cluster_name @@ -318,7 +317,7 @@ def _on_execute_failed(exc: BaseException, args: CassandraCallbackArgs, event: E event.set() -RowFactory = Callable[[List[str], List[Tuple]], Any] +RowFactory = Callable[[list[str], list[tuple]], Any] Query = Union[str, SimpleStatement, PreparedStatement, BoundStatement] Parameters = Union[Sequence[Any], Mapping[str, Any]] @@ -329,7 +328,7 @@ def __init__( context_name: str, server_span: Span, session: Session, - prepared_statements: Dict[str, PreparedStatement], + prepared_statements: dict[str, PreparedStatement], prometheus_client_name: Optional[str] = None, prometheus_cluster_name: Optional[str] = None, ): diff --git a/baseplate/clients/kombu.py b/baseplate/clients/kombu.py index 783cbdb73..79af4cb9f 100644 --- a/baseplate/clients/kombu.py +++ b/baseplate/clients/kombu.py @@ -1,19 +1,11 @@ import abc import time - -from typing import Any -from typing import Generic -from typing import Optional -from typing import Type -from typing import TypeVar +from typing import Any, Generic, Optional, TypeVar import kombu.serialization - -from kombu import Connection -from kombu import Exchange +from kombu import Connection, Exchange from kombu.pools import Producers -from prometheus_client import Counter -from prometheus_client import Histogram +from prometheus_client import Counter, Histogram from thrift import TSerialization from thrift.protocol.TBinaryProtocol import TBinaryProtocolAcceleratedFactory from thrift.protocol.TProtocol import TProtocolFactory @@ -25,7 +17,6 @@ from baseplate.lib.prometheus_metrics import default_latency_buckets from baseplate.lib.secrets import SecretsStore - T = TypeVar("T") amqp_producer_labels = [ @@ -140,7 +131,7 @@ class KombuThriftSerializer(KombuSerializer[T]): # pylint: disable=unsubscripta def __init__( self, - thrift_class: Type[T], + thrift_class: type[T], protocol_factory: TProtocolFactory = TBinaryProtocolAcceleratedFactory(), ): self.thrift_class = thrift_class diff --git a/baseplate/clients/memcache/__init__.py b/baseplate/clients/memcache/__init__.py index b3f2390f4..106665541 100644 --- a/baseplate/clients/memcache/__init__.py +++ b/baseplate/clients/memcache/__init__.py @@ -1,27 +1,16 @@ +from collections.abc import Iterable, Sequence from time import perf_counter -from typing import Any -from typing import Callable -from typing import Dict -from typing import Iterable -from typing import List -from typing import Optional -from typing import Sequence -from typing import Tuple -from typing import Union - -from prometheus_client import Counter -from prometheus_client import Gauge -from prometheus_client import Histogram +from typing import Any, Callable, Optional, Union + +from prometheus_client import Counter, Gauge, Histogram from pymemcache.client.base import PooledClient from baseplate import Span from baseplate.clients import ContextFactory -from baseplate.lib import config -from baseplate.lib import metrics +from baseplate.lib import config, metrics from baseplate.lib.prometheus_metrics import default_latency_buckets - -Serializer = Callable[[str, Any], Tuple[bytes, int]] +Serializer = Callable[[str, Any], tuple[bytes, int]] Deserializer = Callable[[str, bytes, int], Any] @@ -254,8 +243,8 @@ 
def set(self, key: Key, value: Any, expire: int = 0, noreply: Optional[bool] = N @_prom_instrument def set_many( - self, values: Dict[Key, Any], expire: int = 0, noreply: Optional[bool] = None - ) -> List[str]: + self, values: dict[Key, Any], expire: int = 0, noreply: Optional[bool] = None + ) -> list[str]: with self._make_span("set_many") as span: span.set_tag("key_count", len(values)) span.set_tag("keys", make_keys_str(values.keys())) @@ -312,7 +301,7 @@ def get(self, key: Key, default: Any = None) -> Any: return self.pooled_client.get(key, **kwargs) @_prom_instrument - def get_many(self, keys: Sequence[Key]) -> Dict[Key, Any]: + def get_many(self, keys: Sequence[Key]) -> dict[Key, Any]: with self._make_span("get_many") as span: span.set_tag("key_count", len(keys)) span.set_tag("keys", make_keys_str(keys)) @@ -321,13 +310,13 @@ def get_many(self, keys: Sequence[Key]) -> Dict[Key, Any]: @_prom_instrument def gets( self, key: Key, default: Optional[Any] = None, cas_default: Optional[Any] = None - ) -> Tuple[Any, Any]: + ) -> tuple[Any, Any]: with self._make_span("gets") as span: span.set_tag("key", key) return self.pooled_client.gets(key, default=default, cas_default=cas_default) @_prom_instrument - def gets_many(self, keys: Sequence[Key]) -> Dict[Key, Tuple[Any, Any]]: + def gets_many(self, keys: Sequence[Key]) -> dict[Key, tuple[Any, Any]]: with self._make_span("gets_many") as span: span.set_tag("key_count", len(keys)) span.set_tag("keys", make_keys_str(keys)) @@ -379,7 +368,7 @@ def touch(self, key: Key, expire: int = 0, noreply: Optional[bool] = None) -> bo return self.pooled_client.touch(key, expire=expire, noreply=noreply) @_prom_instrument - def stats(self, *args: str) -> Dict[str, Any]: + def stats(self, *args: str) -> dict[str, Any]: with self._make_span("stats"): return self.pooled_client.stats(*args) diff --git a/baseplate/clients/memcache/lib.py b/baseplate/clients/memcache/lib.py index e4d2e125b..48781315c 100644 --- a/baseplate/clients/memcache/lib.py +++ b/baseplate/clients/memcache/lib.py @@ -10,14 +10,12 @@ should use pickle_and_compress() and decompress_and_unpickle(). """ + import json import logging import pickle import zlib - -from typing import Any -from typing import Callable -from typing import Tuple +from typing import Any, Callable class Flags: @@ -79,7 +77,7 @@ def decompress_and_load( # pylint: disable=unused-argument def make_dump_and_compress_fn( min_compress_length: int = 0, compress_level: int = 1 -) -> Callable[[str, Any], Tuple[bytes, int]]: +) -> Callable[[str, Any], tuple[bytes, int]]: """Make a serializer. This should be paired with @@ -101,7 +99,7 @@ def make_dump_and_compress_fn( def dump_and_compress( # pylint: disable=unused-argument key: str, value: Any - ) -> Tuple[bytes, int]: + ) -> tuple[bytes, int]: """Serialize a Python object in a way compatible with decompress_and_load(). :param key: the memcached key. @@ -194,7 +192,7 @@ def decompress_and_unpickle( # pylint: disable=unused-argument def make_pickle_and_compress_fn( min_compress_length: int = 0, compress_level: int = 1 -) -> Callable[[str, Any], Tuple[bytes, int]]: +) -> Callable[[str, Any], tuple[bytes, int]]: """Make a serializer compatible with ``pylibmc`` readers. 
The resulting method is a chain of :py:func:`pickle.dumps` and ``zlib`` @@ -218,7 +216,7 @@ def make_pickle_and_compress_fn( def pickle_and_compress( # pylint: disable=unused-argument key: str, value: Any - ) -> Tuple[bytes, int]: + ) -> tuple[bytes, int]: """Serialize a Python object in a way compatible with decompress_and_unpickle(). :param key: the memcached key. diff --git a/baseplate/clients/redis.py b/baseplate/clients/redis.py index 5db3a87d2..b0b6b2426 100644 --- a/baseplate/clients/redis.py +++ b/baseplate/clients/redis.py @@ -1,8 +1,6 @@ from math import ceil from time import perf_counter -from typing import Any -from typing import Dict -from typing import Optional +from typing import Any, Optional import redis @@ -12,16 +10,11 @@ except ImportError: from redis.client import Pipeline -from prometheus_client import Counter -from prometheus_client import Gauge -from prometheus_client import Histogram +from prometheus_client import Counter, Gauge, Histogram from baseplate import Span from baseplate.clients import ContextFactory -from baseplate.lib import config -from baseplate.lib import message_queue -from baseplate.lib import metrics - +from baseplate.lib import config, message_queue, metrics from baseplate.lib.prometheus_metrics import default_latency_buckets PROM_PREFIX = "redis_client" @@ -240,9 +233,10 @@ def execute_command(self, *args: Any, **kwargs: Any) -> Any: f"{PROM_LABELS_PREFIX}_database": self.connection_pool.connection_kwargs.get("db", ""), f"{PROM_LABELS_PREFIX}_type": "standalone", } - with self.server_span.make_child(trace_name), ACTIVE_REQUESTS.labels( - **labels - ).track_inprogress(): + with ( + self.server_span.make_child(trace_name), + ACTIVE_REQUESTS.labels(**labels).track_inprogress(), + ): start_time = perf_counter() success = "true" @@ -296,7 +290,7 @@ def __init__( trace_name: str, server_span: Span, connection_pool: redis.ConnectionPool, - response_callbacks: Dict, + response_callbacks: dict, redis_client_name: str = "", **kwargs: Any, ): diff --git a/baseplate/clients/redis_cluster.py b/baseplate/clients/redis_cluster.py index cad007f20..cb2eba287 100644 --- a/baseplate/clients/redis_cluster.py +++ b/baseplate/clients/redis_cluster.py @@ -1,28 +1,24 @@ import logging import random - from datetime import timedelta from time import perf_counter -from typing import Any -from typing import Dict -from typing import List -from typing import Optional +from typing import Any, Optional import rediscluster - from redis import RedisError from rediscluster.pipeline import ClusterPipeline from baseplate import Span from baseplate.clients import ContextFactory -from baseplate.clients.redis import ACTIVE_REQUESTS -from baseplate.clients.redis import LATENCY_SECONDS -from baseplate.clients.redis import MAX_CONNECTIONS -from baseplate.clients.redis import OPEN_CONNECTIONS -from baseplate.clients.redis import PROM_LABELS_PREFIX -from baseplate.clients.redis import REQUESTS_TOTAL -from baseplate.lib import config -from baseplate.lib import metrics +from baseplate.clients.redis import ( + ACTIVE_REQUESTS, + LATENCY_SECONDS, + MAX_CONNECTIONS, + OPEN_CONNECTIONS, + PROM_LABELS_PREFIX, + REQUESTS_TOTAL, +) +from baseplate.lib import config, metrics logger = logging.getLogger(__name__) randomizer = random.SystemRandom() @@ -155,16 +151,16 @@ def should_track_key_reads(self) -> bool: def should_track_key_writes(self) -> bool: return randomizer.random() < self.track_writes_sample_rate - def increment_keys_read_counter(self, key_list: List[str], ignore_errors: bool = True) -> 
None: + def increment_keys_read_counter(self, key_list: list[str], ignore_errors: bool = True) -> None: self._increment_hot_key_counter(key_list, self.reads_sorted_set_name, ignore_errors) def increment_keys_written_counter( - self, key_list: List[str], ignore_errors: bool = True + self, key_list: list[str], ignore_errors: bool = True ) -> None: self._increment_hot_key_counter(key_list, self.writes_sorted_set_name, ignore_errors) def _increment_hot_key_counter( - self, key_list: List[str], set_name: str, ignore_errors: bool = True + self, key_list: list[str], set_name: str, ignore_errors: bool = True ) -> None: if len(key_list) == 0: return @@ -183,7 +179,7 @@ def _increment_hot_key_counter( if not ignore_errors: raise - def maybe_track_key_usage(self, args: List[str]) -> None: + def maybe_track_key_usage(self, args: list[str]) -> None: """Probabilistically track usage of the keys in this command. If we have enabled key usage tracing *and* this command is withing the @@ -216,7 +212,7 @@ def maybe_track_key_usage(self, args: List[str]) -> None: # the desired behaviour. class ClusterWithReadReplicasBlockingConnectionPool(rediscluster.ClusterBlockingConnectionPool): # pylint: disable=arguments-differ - def get_node_by_slot(self, slot: int, read_command: bool = False) -> Dict[str, Any]: + def get_node_by_slot(self, slot: int, read_command: bool = False) -> dict[str, Any]: """Get a node from the slot. If the command is a read command we'll try to return a random node. @@ -260,8 +256,9 @@ def cluster_pool_from_config( * ``timeout``: . e.g. ``200 milliseconds`` (:py:func:`~baseplate.lib.config.Timespan`). How long to wait for a connection to become available. Additionally, will set ``socket_connect_timeout`` and ``socket_timeout`` if they're not set explicitly. - * ``socket_connect_timeout``: e.g. ``200 milliseconds`` (:py:func:`~baseplate.lib.config.Timespan`) - How long to wait for sockets to connect. + * ``socket_connect_timeout``: e.g. ``200 milliseconds`` + (:py:func:`~baseplate.lib.config.Timespan`) How long to wait for sockets to + connect. * ``socket_timeout``: e.g. ``200 milliseconds`` (:py:func:`~baseplate.lib.config.Timespan`) How long to wait for socket operations. 
* ``track_key_reads_sample_rate``: If greater than zero, which percentage of requests will @@ -506,7 +503,7 @@ def __init__( trace_name: str, server_span: Span, connection_pool: rediscluster.ClusterConnectionPool, - response_callbacks: Dict, + response_callbacks: dict, hot_key_tracker: Optional[HotKeyTracker], redis_client_name: str = "", **kwargs: Any, diff --git a/baseplate/clients/requests.py b/baseplate/clients/requests.py index 53d274192..0af5e9a25 100644 --- a/baseplate/clients/requests.py +++ b/baseplate/clients/requests.py @@ -2,29 +2,18 @@ import ipaddress import sys import time +from typing import Any, Optional, Union -from typing import Any -from typing import Optional -from typing import Type -from typing import Union - -from advocate import AddrValidator -from advocate import ValidatingHTTPAdapter +from advocate import AddrValidator, ValidatingHTTPAdapter from opentelemetry.instrumentation.requests import RequestsInstrumentor -from prometheus_client import Counter -from prometheus_client import Gauge -from prometheus_client import Histogram -from requests import PreparedRequest -from requests import Request -from requests import Response -from requests import Session +from prometheus_client import Counter, Gauge, Histogram +from requests import PreparedRequest, Request, Response, Session from requests.adapters import HTTPAdapter from baseplate import Span from baseplate.clients import ContextFactory from baseplate.lib import config -from baseplate.lib.prometheus_metrics import default_latency_buckets -from baseplate.lib.prometheus_metrics import getHTTPSuccessLabel +from baseplate.lib.prometheus_metrics import default_latency_buckets, getHTTPSuccessLabel RequestsInstrumentor().instrument() @@ -252,13 +241,18 @@ def send(self, request: PreparedRequest, **kwargs: Any) -> Response: start_time = time.perf_counter() try: - with self.span.make_child(f"{self.name}.request").with_tags( - { - "http.url": request.url, - "http.method": request.method.lower() if request.method else "", - "http.slug": self.client_name if self.client_name is not None else self.name, - } - ) as span, ACTIVE_REQUESTS.labels(**active_request_label_values).track_inprogress(): + with ( + self.span.make_child(f"{self.name}.request").with_tags( + { + "http.url": request.url, + "http.method": request.method.lower() if request.method else "", + "http.slug": self.client_name + if self.client_name is not None + else self.name, + } + ) as span, + ACTIVE_REQUESTS.labels(**active_request_label_values).track_inprogress(), + ): self._add_span_context(span, request) # we cannot re-use the same session every time because sessions re-use the same @@ -342,7 +336,7 @@ class RequestsContextFactory(ContextFactory): def __init__( self, adapter: HTTPAdapter, - session_cls: Type[BaseplateSession], + session_cls: type[BaseplateSession], client_name: Optional[str] = None, ) -> None: self.adapter = adapter diff --git a/baseplate/clients/sqlalchemy.py b/baseplate/clients/sqlalchemy.py index f96088125..7778c3a96 100644 --- a/baseplate/clients/sqlalchemy.py +++ b/baseplate/clients/sqlalchemy.py @@ -2,41 +2,28 @@ import re import typing - +from collections.abc import Sequence from time import perf_counter -from typing import Any -from typing import Dict -from typing import Optional -from typing import Sequence -from typing import Tuple -from typing import Union - -from prometheus_client import Counter -from prometheus_client import Gauge -from prometheus_client import Histogram -from sqlalchemy import create_engine -from sqlalchemy import 
event -from sqlalchemy.engine import Connection -from sqlalchemy.engine import Engine -from sqlalchemy.engine import ExceptionContext +from typing import Any, Optional, Union + +from prometheus_client import Counter, Gauge, Histogram +from sqlalchemy import create_engine, event +from sqlalchemy.engine import Connection, Engine, ExceptionContext from sqlalchemy.engine.interfaces import ExecutionContext from sqlalchemy.engine.url import make_url from sqlalchemy.orm import Session from sqlalchemy.pool import QueuePool -from baseplate import _ExcInfo -from baseplate import Span -from baseplate import SpanObserver +from baseplate import Span, SpanObserver, _ExcInfo from baseplate.clients import ContextFactory -from baseplate.lib import config -from baseplate.lib import metrics +from baseplate.lib import config, metrics from baseplate.lib.prometheus_metrics import default_latency_buckets from baseplate.lib.secrets import SecretsStore def engine_from_config( app_config: config.RawConfig, - secrets: Optional[SecretsStore] = None, + secrets: SecretsStore | None = None, prefix: str = "database.", **kwargs: Any, ) -> Engine: @@ -123,20 +110,18 @@ class SQLAlchemySession(config.Parser): """ - def __init__(self, secrets: Optional[SecretsStore] = None, **kwargs: Any): + def __init__(self, secrets: SecretsStore | None = None, **kwargs: Any): self.secrets = secrets self.kwargs = kwargs - def parse( - self, key_path: str, raw_config: config.RawConfig - ) -> "SQLAlchemySessionContextFactory": + def parse(self, key_path: str, raw_config: config.RawConfig) -> SQLAlchemySessionContextFactory: engine = engine_from_config( raw_config, secrets=self.secrets, prefix=f"{key_path}.", **self.kwargs ) return SQLAlchemySessionContextFactory(engine, key_path) -Parameters = Optional[Union[Dict[str, Any], Sequence[Any]]] +Parameters = Optional[Union[dict[str, Any], Sequence[Any]]] SAFE_TRACE_ID = re.compile("^[A-Za-z0-9_-]+$") @@ -246,9 +231,9 @@ def on_before_execute( cursor: Any, statement: str, parameters: Parameters, - context: Optional[ExecutionContext], + context: ExecutionContext | None, executemany: bool, - ) -> Tuple[str, Parameters]: + ) -> tuple[str, Parameters]: """Handle the engine's before_cursor_execute event.""" labels = { "sql_client_name": self.name, @@ -284,7 +269,7 @@ def on_after_execute( cursor: Any, statement: str, parameters: Parameters, - context: Optional[ExecutionContext], + context: ExecutionContext | None, executemany: bool, ) -> None: """Handle the event which happens after successful cursor execution.""" @@ -359,5 +344,5 @@ class SQLAlchemySessionSpanObserver(SpanObserver): def __init__(self, session: Session): self.session = session - def on_finish(self, exc_info: Optional[_ExcInfo]) -> None: + def on_finish(self, exc_info: _ExcInfo | None) -> None: self.session.close() diff --git a/baseplate/clients/thrift.py b/baseplate/clients/thrift.py index c4c880f8b..729ae3c3e 100644 --- a/baseplate/clients/thrift.py +++ b/baseplate/clients/thrift.py @@ -4,39 +4,29 @@ import socket import sys import time - from collections import OrderedDict +from collections.abc import Iterator from math import ceil -from typing import Any -from typing import Callable -from typing import Iterator -from typing import Optional +from typing import Any, Callable, Optional from opentelemetry import trace from opentelemetry.propagators.composite import CompositePropagator -from opentelemetry.semconv.trace import MessageTypeValues -from opentelemetry.semconv.trace import SpanAttributes +from opentelemetry.semconv.trace 
import MessageTypeValues, SpanAttributes from opentelemetry.trace import status from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator -from prometheus_client import Counter -from prometheus_client import Gauge -from prometheus_client import Histogram +from prometheus_client import Counter, Gauge, Histogram from thrift.protocol.TProtocol import TProtocolException -from thrift.Thrift import TApplicationException -from thrift.Thrift import TException +from thrift.Thrift import TApplicationException, TException from thrift.transport.TTransport import TTransportException from baseplate import Span from baseplate.clients import ContextFactory -from baseplate.lib import config -from baseplate.lib import metrics +from baseplate.lib import config, metrics from baseplate.lib.prometheus_metrics import default_latency_buckets from baseplate.lib.propagator_redditb3_thrift import RedditB3ThriftFormat from baseplate.lib.retry import RetryPolicy -from baseplate.lib.thrift_pool import thrift_pool_from_config -from baseplate.lib.thrift_pool import ThriftConnectionPool -from baseplate.thrift.ttypes import Error -from baseplate.thrift.ttypes import ErrorCode +from baseplate.lib.thrift_pool import ThriftConnectionPool, thrift_pool_from_config +from baseplate.thrift.ttypes import Error, ErrorCode logger = logging.getLogger(__name__) @@ -251,9 +241,12 @@ def _call_thrift_method(self: Any, *args: Any, **kwargs: Any) -> Any: for time_remaining in self.retry_policy: try: - with self.pool.connection() as prot, ACTIVE_REQUESTS.labels( - thrift_method=name, thrift_client_name=self.namespace - ).track_inprogress(): + with ( + self.pool.connection() as prot, + ACTIVE_REQUESTS.labels( + thrift_method=name, thrift_client_name=self.namespace + ).track_inprogress(), + ): start_time = time.perf_counter() span = self.server_span.make_child(trace_name) @@ -275,7 +268,7 @@ def _call_thrift_method(self: Any, *args: Any, **kwargs: Any) -> Any: if otel_attributes.get(SpanAttributes.NET_PEER_IP) in ["127.0.0.1", "::1"]: otel_attributes[SpanAttributes.NET_PEER_NAME] = "localhost" logger.debug( - "Will use the following otel span attributes. [span=%s, otel_attributes=%s]", + "Will use the following otel span attributes. [span=%s, otel_attributes=%s]", # noqa: E501 span, otel_attributes, ) @@ -302,8 +295,9 @@ def _call_thrift_method(self: Any, *args: Any, **kwargs: Any) -> Any: if not min_timeout or self.pool.timeout < min_timeout: min_timeout = self.pool.timeout if min_timeout and min_timeout > 0: - # min_timeout is in float seconds, we are converting to int milliseconds - # rounding up here. + # min_timeout is in float seconds, we are + # converting to int milliseconds rounding up + # here. prot.trans.set_header( b"Deadline-Budget", str(int(ceil(min_timeout * 1000))).encode() ) @@ -324,7 +318,8 @@ def _call_thrift_method(self: Any, *args: Any, **kwargs: Any) -> Any: last_error = str(exc) if exc.inner is not None: last_error += f" ({exc.inner})" - raise # we need to raise all exceptions so that self.pool.connect() self-heals + # we need to raise all exceptions so that self.pool.connect() self-heals + raise except (TApplicationException, TProtocolException): # these are subclasses of TException but aren't ones that # should be expected in the protocol. this is an error! 
@@ -374,27 +369,33 @@ def _call_thrift_method(self: Any, *args: Any, **kwargs: Any) -> Any: exception_type = exc_info[0].__name__ current_exc: Any = exc_info[1] try: - # We want the following code to execute whenever the - # service raises an instance of Baseplate's `Error` class. - # Unfortunately, we cannot just rely on `isinstance` to do - # what we want here because some services compile - # Baseplate's thrift file on their own and import `Error` - # from that. When this is done, `isinstance` will always - # return `False` since it's technically a different class. - # To fix this, we optimistically try to access `code` on - # `current_exc` and just catch the `AttributeError` if the - # `code` attribute is not present. - # Note: if the error code was not originally defined in baseplate, or the - # name associated with the error was overriden, this cannot reflect that - # we will emit the status code in both cases - # but the status will be blank in the first case, and the baseplate name - # in the second - - # Since this exception could be of any type, we may receive exceptions - # that have a `code` property that is actually not from Baseplate's - # `Error` class. In order to reduce (but not eliminate) the possibility - # of metric explosion, we validate it against the expected type for a - # proper Error code. + # We want the following code to execute + # whenever the service raises an instance of + # Baseplate's `Error` class. Unfortunately, we + # cannot just rely on `isinstance` to do what + # we want here because some services compile + # Baseplate's thrift file on their own and + # import `Error` from that. When this is done, + # `isinstance` will always return `False` since + # it's technically a different class. To fix + # this, we optimistically try to access `code` + # on `current_exc` and just catch the + # `AttributeError` if the `code` attribute is + # not present. Note: if the error code was not + # originally defined in baseplate, or the name + # associated with the error was overridden, this + # cannot reflect that; we will emit the status + # code in both cases, but the status will be + # blank in the first case and the baseplate + # name in the second. + + # Since this exception could be of any type, we + # may receive exceptions that have a `code` + # property that is actually not from + # Baseplate's `Error` class. In order to reduce + # (but not eliminate) the possibility of metric + # explosion, we validate it against the + # expected type for a proper Error code.
if isinstance(current_exc.code, int): baseplate_status_code = str(current_exc.code) baseplate_status = ErrorCode()._VALUES_TO_NAMES.get( @@ -425,7 +426,7 @@ def _call_thrift_method(self: Any, *args: Any, **kwargs: Any) -> Any: # this only happens if we exhaust the retry policy raise TTransportException( type=TTransportException.TIMED_OUT, - message=f"retry policy exhausted while attempting {self.namespace}.{name}, last error was: {last_error}", + message=f"retry policy exhausted while attempting {self.namespace}.{name}, last error was: {last_error}", # noqa: E501 ) return _call_thrift_method diff --git a/baseplate/frameworks/pyramid/__init__.py b/baseplate/frameworks/pyramid/__init__.py index 01f2001ec..0a7a62436 100644 --- a/baseplate/frameworks/pyramid/__init__.py +++ b/baseplate/frameworks/pyramid/__init__.py @@ -2,38 +2,28 @@ import logging import sys import time - -from typing import Any -from typing import Callable -from typing import Dict -from typing import Iterable -from typing import Iterator -from typing import Mapping -from typing import Optional +from collections.abc import Iterable, Iterator, Mapping +from typing import Any, Callable, Optional import pyramid.events import pyramid.request import pyramid.tweens import webob.request - from opentelemetry import trace from opentelemetry.instrumentation.pyramid import PyramidInstrumentor -from prometheus_client import Counter -from prometheus_client import Gauge -from prometheus_client import Histogram +from prometheus_client import Counter, Gauge, Histogram from pyramid.config import Configurator from pyramid.registry import Registry from pyramid.request import Request from pyramid.response import Response -from baseplate import Baseplate -from baseplate import RequestContext -from baseplate import Span -from baseplate import TraceInfo +from baseplate import Baseplate, RequestContext, Span, TraceInfo from baseplate.lib.edgecontext import EdgeContextFactory -from baseplate.lib.prometheus_metrics import default_latency_buckets -from baseplate.lib.prometheus_metrics import default_size_buckets -from baseplate.lib.prometheus_metrics import getHTTPSuccessLabel +from baseplate.lib.prometheus_metrics import ( + default_latency_buckets, + default_size_buckets, + getHTTPSuccessLabel, +) from baseplate.thrift.ttypes import IsHealthyProbe logger = logging.getLogger(__name__) @@ -237,7 +227,7 @@ def manually_close_request_metrics(request: Request, response: Optional[Response request.reddit_tracked_endpoint = None else: logger.debug( - "Request metrics attempted to be closed but were never opened, no metrics will be tracked" + "Request metrics attempted to be closed but were never opened, no metrics will be tracked" # noqa: E501 ) @@ -323,7 +313,7 @@ class RequestFactory: def __init__(self, baseplate: Baseplate): self.baseplate = baseplate - def __call__(self, environ: Dict[str, str]) -> BaseplateRequest: + def __call__(self, environ: dict[str, str]) -> BaseplateRequest: return BaseplateRequest(environ, context_config=self.baseplate._context_config) def blank(self, path: str) -> BaseplateRequest: diff --git a/baseplate/frameworks/pyramid/csrf.py b/baseplate/frameworks/pyramid/csrf.py index 5eb6f6b6c..f32640553 100644 --- a/baseplate/frameworks/pyramid/csrf.py +++ b/baseplate/frameworks/pyramid/csrf.py @@ -1,17 +1,11 @@ import logging - from datetime import timedelta from typing import Any -from typing import Tuple from zope.interface import implementer -from baseplate.lib.crypto import make_signature -from baseplate.lib.crypto import 
SignatureError -from baseplate.lib.crypto import validate_signature -from baseplate.lib.secrets import SecretsStore -from baseplate.lib.secrets import VersionedSecret - +from baseplate.lib.crypto import SignatureError, make_signature, validate_signature +from baseplate.lib.secrets import SecretsStore, VersionedSecret logger = logging.getLogger(__name__) @@ -25,7 +19,7 @@ raise -def _make_csrf_token_payload(version: int, account_id: str) -> Tuple[str, str]: +def _make_csrf_token_payload(version: int, account_id: str) -> tuple[str, str]: version_str = str(version) payload = ".".join([version_str, account_id]) return version_str, payload diff --git a/baseplate/frameworks/queue_consumer/kafka.py b/baseplate/frameworks/queue_consumer/kafka.py index 3cedafb37..752491acf 100644 --- a/baseplate/frameworks/queue_consumer/kafka.py +++ b/baseplate/frameworks/queue_consumer/kafka.py @@ -3,33 +3,23 @@ import queue import socket import time - -from typing import Any -from typing import Callable -from typing import Dict -from typing import List -from typing import NamedTuple -from typing import Optional -from typing import Sequence -from typing import TYPE_CHECKING +from collections.abc import Sequence +from typing import TYPE_CHECKING, Any, Callable, NamedTuple, Optional import confluent_kafka - from gevent.server import StreamServer -from prometheus_client import Counter -from prometheus_client import Gauge -from prometheus_client import Histogram +from prometheus_client import Counter, Gauge, Histogram from typing_extensions import Self -from baseplate import Baseplate -from baseplate import RequestContext +from baseplate import Baseplate, RequestContext from baseplate.lib.prometheus_metrics import default_latency_buckets -from baseplate.server.queue_consumer import HealthcheckCallback -from baseplate.server.queue_consumer import make_simple_healthchecker -from baseplate.server.queue_consumer import MessageHandler -from baseplate.server.queue_consumer import PumpWorker -from baseplate.server.queue_consumer import QueueConsumerFactory - +from baseplate.server.queue_consumer import ( + HealthcheckCallback, + MessageHandler, + PumpWorker, + QueueConsumerFactory, + make_simple_healthchecker, +) if TYPE_CHECKING: WorkQueue = queue.Queue[confluent_kafka.Message] # pylint: disable=unsubscriptable-object @@ -152,9 +142,10 @@ def handle(self, message: confluent_kafka.Message) -> None: # We place the call to ``baseplate.make_server_span`` inside the # try/except block because we still want Baseplate to see and # handle the error (publish it to error reporting) - with self.baseplate.make_server_span( - context, f"{self.name}.handler" - ) as span, KAFKA_ACTIVE_MESSAGES.labels(**prom_labels._asdict()).track_inprogress(): + with ( + self.baseplate.make_server_span(context, f"{self.name}.handler") as span, + KAFKA_ACTIVE_MESSAGES.labels(**prom_labels._asdict()).track_inprogress(), + ): error = message.error() if error: prom_success = "false" @@ -267,7 +258,7 @@ def new( kafka_consume_batch_size: int = 1, message_unpack_fn: KafkaMessageDeserializer = json.loads, health_check_fn: Optional[HealthcheckCallback] = None, - kafka_config: Optional[Dict[str, Any]] = None, + kafka_config: Optional[dict[str, Any]] = None, prometheus_client_name: str = "", ) -> Self: """Return a new `_BaseKafkaQueueConsumerFactory`. 
@@ -314,7 +305,7 @@ def new( ) @classmethod - def _consumer_config(cls) -> Dict[str, Any]: + def _consumer_config(cls) -> dict[str, Any]: raise NotImplementedError @classmethod @@ -323,7 +314,7 @@ def make_kafka_consumer( bootstrap_servers: str, group_id: str, topics: Sequence[str], - kafka_config: Optional[Dict[str, Any]] = None, + kafka_config: Optional[dict[str, Any]] = None, ) -> confluent_kafka.Consumer: consumer_config = { "bootstrap.servers": bootstrap_servers, @@ -354,18 +345,18 @@ def make_kafka_consumer( for topic in topics: assert ( topic in all_topics - ), f"topic '{topic}' does not exist. maybe it's misspelled or on a different kafka cluster?" + ), f"topic '{topic}' does not exist. maybe it's misspelled or on a different kafka cluster?" # noqa: E501 # pylint: disable=unused-argument def log_assign( - consumer: confluent_kafka.Consumer, partitions: List[confluent_kafka.TopicPartition] + consumer: confluent_kafka.Consumer, partitions: list[confluent_kafka.TopicPartition] ) -> None: for topic_partition in partitions: logger.info("assigned %s/%s", topic_partition.topic, topic_partition.partition) # pylint: disable=unused-argument def log_revoke( - consumer: confluent_kafka.Consumer, partitions: List[confluent_kafka.TopicPartition] + consumer: confluent_kafka.Consumer, partitions: list[confluent_kafka.TopicPartition] ) -> None: for topic_partition in partitions: logger.info("revoked %s/%s", topic_partition.topic, topic_partition.partition) @@ -396,7 +387,9 @@ def build_health_checker(self, listener: socket.socket) -> StreamServer: class InOrderConsumerFactory(_BaseKafkaQueueConsumerFactory): - """Factory for running a :py:class:`~baseplate.server.queue_consumer.QueueConsumerServer` using Kafka. + """Factory for running a + :py:class:`~baseplate.server.queue_consumer.QueueConsumerServer` using + Kafka. The `InOrderConsumerFactory` attempts to achieve in order, exactly once message processing. @@ -406,7 +399,8 @@ class InOrderConsumerFactory(_BaseKafkaQueueConsumerFactory): that reads messages from the internal work queue, processes them with the `handler_fn`, and then commits each message's offset to the kafka consumer's internal state. - The Kafka Consumer will commit the offsets back to Kafka based on the auto.commit.interval.ms default which is 5 seconds + The Kafka Consumer will commit the offsets back to Kafka based on the + auto.commit.interval.ms default, which is 5 seconds. This one-at-a-time, in-order processing ensures that when a failure happens during processing we don't commit its offset (or the offset of any later @@ -423,17 +417,23 @@ class InOrderConsumerFactory(_BaseKafkaQueueConsumerFactory): UPDATE: The InOrderConsumerFactory can NEVER achieve in-order, exactly once message processing. - Message processing in Kafka to enable exactly once starts at the Producer enabling transactions, - and downstream consumers enabling reading exclusively from the committed offsets within a transactions. + Message processing in Kafka to enable exactly once starts at the Producer + enabling transactions, and downstream consumers enabling reading + exclusively from the committed offsets within a transaction. - Secondly, without defined keys in the messages from the producer, messages will be sent in a round robin fashion to all partitions in the topic. - This means that newer messages could be consumed before older ones if the consumer of those partitions with newer messages are faster.
+ Secondly, without defined keys in the messages from the producer, messages + will be sent in a round robin fashion to all partitions in the topic. This + means that newer messages could be consumed before older ones if the + consumer of those partitions with newer messages is faster. - Some improvements are made instead that retain the current behaviour, but don't put as much pressure on Kafka by committing every single offset. + Some improvements are made instead that retain the current behaviour, but + don't put as much pressure on Kafka by committing every single offset. Instead of committing every single message's offset back to Kafka, - the consumer now commits each offset to it's local offset store, and commits the highest seen value for each partition at a defined interval (auto.commit.interval.ms). - "enable.auto.offset.store" is set to false to give our application explicit control of when to store offsets. + the consumer now commits each offset to its local offset store, and + commits the highest seen value for each partition at a defined interval + (auto.commit.interval.ms). "enable.auto.offset.store" is set to false to + give our application explicit control of when to store offsets. """ # we need to ensure that only a single message handler worker exists (max_concurrency = 1) message_handler_count = 0 @classmethod - def _consumer_config(cls) -> Dict[str, Any]: + def _consumer_config(cls) -> dict[str, Any]: return { # The consumer sends periodic heartbeats on a separate thread to # indicate its liveness to the broker. If no heartbeats are received by @@ -494,7 +494,9 @@ def commit_offset( class FastConsumerFactory(_BaseKafkaQueueConsumerFactory): - """Factory for running a :py:class:`~baseplate.server.queue_consumer.QueueConsumerServer` using Kafka. + """Factory for running a + :py:class:`~baseplate.server.queue_consumer.QueueConsumerServer` using + Kafka. The `FastConsumerFactory` prioritizes high throughput over exactly once message processing. @@ -543,7 +545,7 @@ class FastConsumerFactory(_BaseKafkaQueueConsumerFactory): # pylint: disable=unused-argument @staticmethod def _commit_callback( - err: confluent_kafka.KafkaError, topic_partition_list: List[confluent_kafka.TopicPartition] + err: confluent_kafka.KafkaError, topic_partition_list: list[confluent_kafka.TopicPartition] ) -> None: # called after automatic commits for topic_partition in topic_partition_list: @@ -565,7 +567,7 @@ def _commit_callback( ) @classmethod - def _consumer_config(cls) -> Dict[str, Any]: + def _consumer_config(cls) -> dict[str, Any]: return { # The consumer sends periodic heartbeats on a separate thread to # indicate its liveness to the broker.
If no heartbeats are received by diff --git a/baseplate/frameworks/queue_consumer/kombu.py b/baseplate/frameworks/queue_consumer/kombu.py index 6610f5e16..57a8f8e37 100644 --- a/baseplate/frameworks/queue_consumer/kombu.py +++ b/baseplate/frameworks/queue_consumer/kombu.py @@ -2,35 +2,27 @@ import queue import socket import time - +from collections.abc import Sequence from enum import Enum -from typing import Any -from typing import Callable -from typing import Dict -from typing import NamedTuple -from typing import Optional -from typing import Sequence -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Callable, NamedTuple, Optional import kombu - from gevent.server import StreamServer from kombu.mixins import ConsumerMixin from kombu.transport.virtual import Channel -from prometheus_client import Counter -from prometheus_client import Gauge -from prometheus_client import Histogram +from prometheus_client import Counter, Gauge, Histogram -from baseplate import Baseplate -from baseplate import RequestContext +from baseplate import Baseplate, RequestContext from baseplate.clients.kombu import KombuSerializer from baseplate.lib.errors import KnownException from baseplate.lib.prometheus_metrics import default_latency_buckets -from baseplate.server.queue_consumer import HealthcheckCallback -from baseplate.server.queue_consumer import make_simple_healthchecker -from baseplate.server.queue_consumer import MessageHandler -from baseplate.server.queue_consumer import PumpWorker -from baseplate.server.queue_consumer import QueueConsumerFactory +from baseplate.server.queue_consumer import ( + HealthcheckCallback, + MessageHandler, + PumpWorker, + QueueConsumerFactory, + make_simple_healthchecker, +) class AmqpConsumerPrometheusLabels(NamedTuple): @@ -191,7 +183,7 @@ def _handle_error( if not self._is_error_recoverable(exc): message.reject() logger.exception( - "Unrecoverable error while trying to process a message. The message has been discarded." + "Unrecoverable error while trying to process a message. The message has been discarded." # noqa: E501 ) return @@ -276,11 +268,10 @@ def handle(self, message: kombu.Message) -> None: # We place the call to ``baseplate.make_server_span`` inside the # try/except block because we still want Baseplate to see and # handle the error (publish it to error reporting) - with self.baseplate.make_server_span( - context, self.name - ) as span, AMQP_ACTIVE_MESSAGES.labels( - **prometheus_labels._asdict() - ).track_inprogress(): + with ( + self.baseplate.make_server_span(context, self.name) as span, + AMQP_ACTIVE_MESSAGES.labels(**prometheus_labels._asdict()).track_inprogress(), + ): delivery_info = message.delivery_info message_body = None message_body = message.decode() @@ -317,7 +308,9 @@ def handle(self, message: kombu.Message) -> None: class KombuQueueConsumerFactory(QueueConsumerFactory): - """Factory for running a :py:class:`~baseplate.server.queue_consumer.QueueConsumerServer` using Kombu. + """Factory for running a + :py:class:`~baseplate.server.queue_consumer.QueueConsumerServer` using + Kombu. For simple cases where you just need a basic queue with all the default parameters for your message broker, you can use `KombuQueueConsumerFactory.new`. 
@@ -336,7 +329,7 @@ def __init__( error_handler_fn: Optional[ErrorHandler] = None, health_check_fn: Optional[HealthcheckCallback] = None, serializer: Optional[KombuSerializer] = None, - worker_kwargs: Optional[Dict[str, Any]] = None, + worker_kwargs: Optional[dict[str, Any]] = None, retry_mode: RetryMode = RetryMode.REQUEUE, retry_limit: Optional[int] = None, ): @@ -390,7 +383,7 @@ def new( error_handler_fn: Optional[ErrorHandler] = None, health_check_fn: Optional[HealthcheckCallback] = None, serializer: Optional[KombuSerializer] = None, - worker_kwargs: Optional[Dict[str, Any]] = None, + worker_kwargs: Optional[dict[str, Any]] = None, retry_mode: RetryMode = RetryMode.REQUEUE, retry_limit: Optional[int] = None, ) -> "KombuQueueConsumerFactory": diff --git a/baseplate/frameworks/thrift/__init__.py b/baseplate/frameworks/thrift/__init__.py index dba3600ac..ec16269ac 100644 --- a/baseplate/frameworks/thrift/__init__.py +++ b/baseplate/frameworks/thrift/__init__.py @@ -2,45 +2,29 @@ import random import sys import time - +from collections.abc import Iterator, Mapping from contextlib import contextmanager from logging import Logger -from typing import Any -from typing import Callable -from typing import Iterator -from typing import Mapping -from typing import Optional - -from form_observability import ContextAwareTracer -from form_observability import ctx +from typing import Any, Callable, Optional + +from form_observability import ContextAwareTracer, ctx from opentelemetry import trace -from opentelemetry.context import attach -from opentelemetry.context import detach +from opentelemetry.context import attach, detach from opentelemetry.propagators.composite import CompositePropagator -from opentelemetry.semconv.trace import MessageTypeValues -from opentelemetry.semconv.trace import SpanAttributes +from opentelemetry.semconv.trace import MessageTypeValues, SpanAttributes from opentelemetry.trace import Tracer from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator -from prometheus_client import Counter -from prometheus_client import Gauge -from prometheus_client import Histogram +from prometheus_client import Counter, Gauge, Histogram from requests.structures import CaseInsensitiveDict -from thrift.protocol.TProtocol import TProtocolBase -from thrift.protocol.TProtocol import TProtocolException -from thrift.Thrift import TApplicationException -from thrift.Thrift import TException -from thrift.Thrift import TProcessor +from thrift.protocol.TProtocol import TProtocolBase, TProtocolException +from thrift.Thrift import TApplicationException, TException, TProcessor from thrift.transport.TTransport import TTransportException -from baseplate import Baseplate -from baseplate import RequestContext -from baseplate import TraceInfo +from baseplate import Baseplate, RequestContext, TraceInfo from baseplate.lib.edgecontext import EdgeContextFactory from baseplate.lib.prometheus_metrics import default_latency_buckets from baseplate.lib.propagator_redditb3_thrift import RedditB3ThriftFormat -from baseplate.thrift.ttypes import Error -from baseplate.thrift.ttypes import ErrorCode - +from baseplate.thrift.ttypes import Error, ErrorCode logger = logging.getLogger(__name__) @@ -100,19 +84,19 @@ def _set_remote_context(self, request_context: RequestContext) -> Iterator[None] try: header_dict[k.decode()] = v.decode() except UnicodeDecodeError: - self.logger.info("Unable to decode header %s, ignoring." 
% k.decode()) + self.logger.info(f"Unable to decode header {k.decode()}, ignoring.") ctx = propagator.extract(header_dict) logger.debug("Extracted trace headers. [ctx=%s, header_dict=%s]", ctx, header_dict) if ctx: token = attach(ctx) - logger.debug("Attached context. [ctx=%s, token=%s]" % (ctx, token)) + logger.debug(f"Attached context. [ctx={ctx}, token={token}]") try: yield finally: detach(token) - logger.debug("Detached context. [ctx=%s, token=%s]" % (ctx, token)) + logger.debug(f"Detached context. [ctx={ctx}, token={token}]") else: yield else: @@ -162,8 +146,7 @@ def call_with_context(*args: Any, **kwargs: Any) -> Any: result = handler_fn(self.context, *args, **kwargs) except (TApplicationException, TProtocolException, TTransportException) as exc: logger.debug( - "Processing one of: TApplicationException, TProtocolException, TTransportException. [exc=%s]" - % exc + f"Processing one of: TApplicationException, TProtocolException, TTransportException. [exc={exc}]" # noqa: E501 ) # these are subclasses of TException but aren't ones that # should be expected in the protocol @@ -171,7 +154,7 @@ def call_with_context(*args: Any, **kwargs: Any) -> Any: otelspan.set_status(trace.status.Status(trace.status.StatusCode.ERROR)) raise except Error as exc: - logger.debug("Processing Error. [exc=%s]" % exc) + logger.debug(f"Processing Error. [exc={exc}]") c = ErrorCode() status = c._VALUES_TO_NAMES.get(exc.code, "") @@ -185,17 +168,17 @@ def call_with_context(*args: Any, **kwargs: Any) -> Any: span.set_tag("success", "false") # mark 5xx errors as failures since those are still "unexpected" if 500 <= exc.code < 600: - logger.debug("Processing 5xx baseplate Error. [exc=%s]" % exc) + logger.debug(f"Processing 5xx baseplate Error. [exc={exc}]") span.finish(exc_info=sys.exc_info()) otelspan.set_status(trace.status.Status(trace.status.StatusCode.ERROR)) else: - logger.debug("Processing non 5xx baseplate Error. [exc=%s]" % exc) + logger.debug(f"Processing non 5xx baseplate Error. [exc={exc}]") # Set as OK as this is an expected exception span.finish() otelspan.set_status(trace.status.Status(trace.status.StatusCode.OK)) raise except TException as exc: - logger.debug("Processing TException. [exc=%s]" % exc) + logger.debug(f"Processing TException. [exc={exc}]") span.set_tag("exception_type", type(exc).__name__) span.set_tag("success", "false") @@ -205,18 +188,18 @@ def call_with_context(*args: Any, **kwargs: Any) -> Any: otelspan.set_status(trace.status.Status(trace.status.StatusCode.OK)) raise except BaseException as exc: - logger.debug("Processing every other type of exception. [exc=%s]" % exc) + logger.debug(f"Processing every other type of exception. [exc={exc}]") # the handler crashed (or timed out)! span.finish(exc_info=sys.exc_info()) otelspan.set_status(trace.status.Status(trace.status.StatusCode.ERROR)) if self.convert_to_baseplate_error: - logger.debug("Converting exception to baseplate Error. [exc=%s]" % exc) + logger.debug(f"Converting exception to baseplate Error. [exc={exc}]") raise Error( code=ErrorCode.INTERNAL_SERVER_ERROR, message="Internal server error", ) - logger.debug("Re-raising unexpected exception. [exc=%s]" % exc) + logger.debug(f"Re-raising unexpected exception. [exc={exc}]") raise else: # a normal result @@ -252,13 +235,18 @@ def call_with_context(*args: Any, **kwargs: Any) -> Any: # To fix this, we optimistically try to access `code` on # `current_exc` and just catch the `AttributeError` if the # `code` attribute is not present. 
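
    # A stripped-down restatement of the EAFP pattern the comment above
    # describes; `error_code_of` is a hypothetical helper, not part of this
    # patch.
    def error_code_of(exc: BaseException):
        try:
            # only baseplate Error-like exceptions carry a `code` attribute
            return exc.code  # type: ignore[attr-defined]
        except AttributeError:
            return None
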
-                        # Note: if the error code was not originally defined in baseplate, or the
-                        # name associated with the error was overriden, this cannot reflect that
-                        # we will emit the status code in both cases
-                        # but the status will be blank in the first case, and the baseplate name
-                        # in the second
+                        # Note: if the error code was not originally
+                        # defined in baseplate, or the name associated
+                        # with the error was overridden, this cannot
+                        # reflect that. We will emit the status code in
+                        # both cases, but the status will be blank in
+                        # the first case and the baseplate name in the
+                        # second.
                         baseplate_status_code = current_exc.code  # type: ignore
-                        baseplate_status = ErrorCode()._VALUES_TO_NAMES.get(current_exc.code, "")  # type: ignore
+                        baseplate_status = ErrorCode()._VALUES_TO_NAMES.get(
+                            current_exc.code,  # type: ignore
+                            "",
+                        )
                 except AttributeError:
                     pass
             PROM_REQUESTS.labels(
diff --git a/baseplate/frameworks/thrift/command.py b/baseplate/frameworks/thrift/command.py
index 4b9867792..a03f9dfa3 100644
--- a/baseplate/frameworks/thrift/command.py
+++ b/baseplate/frameworks/thrift/command.py
@@ -1,7 +1,6 @@
 import glob
 import os
 import subprocess
-
 from distutils.command.build_py import build_py
 from distutils.core import Command
diff --git a/baseplate/healthcheck/__init__.py b/baseplate/healthcheck/__init__.py
index bbdf82e9a..79a677e73 100644
--- a/baseplate/healthcheck/__init__.py
+++ b/baseplate/healthcheck/__init__.py
@@ -1,4 +1,5 @@
 """Check health of a baseplate service on localhost."""
+
 import argparse
 import socket
 import sys
@@ -8,14 +9,10 @@
 import requests

 from baseplate.lib._requests import add_unix_socket_support
-from baseplate.lib.config import Endpoint
-from baseplate.lib.config import EndpointConfiguration
-from baseplate.lib.config import InternetAddress
+from baseplate.lib.config import Endpoint, EndpointConfiguration, InternetAddress
 from baseplate.lib.thrift_pool import ThriftConnectionPool
 from baseplate.thrift import BaseplateServiceV2
-from baseplate.thrift.ttypes import IsHealthyProbe
-from baseplate.thrift.ttypes import IsHealthyRequest
-
+from baseplate.thrift.ttypes import IsHealthyProbe, IsHealthyRequest

 TIMEOUT = 30  # seconds
diff --git a/baseplate/lib/__init__.py b/baseplate/lib/__init__.py
index ccd316ea9..aa4dac420 100644
--- a/baseplate/lib/__init__.py
+++ b/baseplate/lib/__init__.py
@@ -1,12 +1,8 @@
 """Internal library helpers."""
+
 import inspect
 import warnings
-
-from typing import Any
-from typing import Callable
-from typing import Generic
-from typing import Type
-from typing import TypeVar
+from typing import Any, Callable, Generic, TypeVar


 def warn_deprecated(message: str) -> None:
@@ -41,7 +37,7 @@ def __init__(self, wrapped: Callable[[Any], R]):
         self.__doc__ = wrapped.__doc__
         self.__name__ = wrapped.__name__

-    def __get__(self, instance: T, owner: Type[Any]) -> R:
+    def __get__(self, instance: T, owner: type[Any]) -> R:
         if instance is None:
             return self
         ret = self.wrapped(instance)
diff --git a/baseplate/lib/_requests.py b/baseplate/lib/_requests.py
index 3b611211a..272fbd284 100644
--- a/baseplate/lib/_requests.py
+++ b/baseplate/lib/_requests.py
@@ -3,10 +3,10 @@

 This stuff is not stable yet, so it's only for baseplate-internal use.
""" + import socket import urllib.parse - -from typing import Mapping +from collections.abc import Mapping from typing import Optional import requests.adapters diff --git a/baseplate/lib/config.py b/baseplate/lib/config.py index fa7d31760..59715730d 100644 --- a/baseplate/lib/config.py +++ b/baseplate/lib/config.py @@ -84,6 +84,7 @@ tempfile.close() """ + import base64 import datetime import functools @@ -92,19 +93,18 @@ import pwd import re import socket - -from typing import Any -from typing import Callable -from typing import Dict -from typing import Generic -from typing import IO -from typing import NamedTuple -from typing import NewType +from collections.abc import Sequence +from typing import ( + IO, + Any, + Callable, + Generic, + NamedTuple, + NewType, + TypeVar, + Union, +) from typing import Optional as OptionalType -from typing import Sequence -from typing import Set -from typing import TypeVar -from typing import Union class ConfigurationError(Exception): @@ -128,9 +128,7 @@ def Float(text: str) -> float: # noqa: D401 return float(text) -def Integer( - text: OptionalType[str] = None, base: int = 10 -) -> Union[int, Callable[[str], int]]: # noqa: D401 +def Integer(text: OptionalType[str] = None, base: int = 10) -> Union[int, Callable[[str], int]]: # noqa: D401 """An integer. To prevent mistakes, this will raise an error if the user attempts @@ -402,7 +400,8 @@ def DefaultFromEnv( The default is sourced from an environment variable with the name specified in ``default_src``. If the environment variable is not set, then the fallback will be used. - One of the following values must be provided: fallback, default_src, or the provided configuration + One of the following values must be provided: fallback, default_src, or the + provided configuration """ env = os.getenv(default_src) or "" default = Optional(item_parser, fallback)(env) @@ -453,13 +452,12 @@ def __init__(self) -> None: super().__init__() self.__dict__ = self - def __getattr__(self, name: str) -> Any: - ... + def __getattr__(self, name: str) -> Any: ... 
-ConfigSpecItem = Union["Parser", Dict[str, Any], Callable[[str], T]] -ConfigSpec = Dict[str, ConfigSpecItem] -RawConfig = Dict[str, str] +ConfigSpecItem = Union["Parser", dict[str, Any], Callable[[str], T]] +ConfigSpec = dict[str, ConfigSpecItem] +RawConfig = dict[str, str] class Parser(Generic[T]): @@ -606,7 +604,7 @@ def parse(self, key_path: str, raw_config: RawConfig) -> ConfigNamespace: matcher = re.compile("^" + root.replace(".", r"\.") + r"([^.]+)") values = ConfigNamespace() - seen_subkeys: Set[str] = set() + seen_subkeys: set[str] = set() for key in raw_config: m = matcher.search(key) if not m: diff --git a/baseplate/lib/crypto.py b/baseplate/lib/crypto.py index 2fcc47515..cb7185aea 100644 --- a/baseplate/lib/crypto.py +++ b/baseplate/lib/crypto.py @@ -29,6 +29,7 @@ """ + import base64 import binascii import datetime @@ -36,7 +37,6 @@ import hmac import struct import time - from typing import NamedTuple from baseplate.lib.secrets import VersionedSecret diff --git a/baseplate/lib/datetime.py b/baseplate/lib/datetime.py index 71a3e38a7..592abaa7e 100644 --- a/baseplate/lib/datetime.py +++ b/baseplate/lib/datetime.py @@ -1,6 +1,6 @@ """Extensions to the standard library `datetime` module.""" -from datetime import datetime -from datetime import timezone + +from datetime import datetime, timezone def datetime_to_epoch_milliseconds(dt: datetime) -> int: diff --git a/baseplate/lib/edgecontext.py b/baseplate/lib/edgecontext.py index 100813541..8008ea3ee 100644 --- a/baseplate/lib/edgecontext.py +++ b/baseplate/lib/edgecontext.py @@ -1,7 +1,5 @@ -from abc import ABC -from abc import abstractmethod -from typing import Any -from typing import Optional +from abc import ABC, abstractmethod +from typing import Any, Optional class EdgeContextFactory(ABC): diff --git a/baseplate/lib/events.py b/baseplate/lib/events.py index 838bd6fbe..89bf00b4d 100644 --- a/baseplate/lib/events.py +++ b/baseplate/lib/events.py @@ -9,12 +9,9 @@ by a separate daemon. """ -import logging -from typing import Any -from typing import Callable -from typing import Generic -from typing import TypeVar +import logging +from typing import Any, Callable, Generic, TypeVar from thrift import TSerialization from thrift.protocol.TJSONProtocol import TJSONProtocolFactory @@ -22,9 +19,7 @@ from baseplate import Span from baseplate.clients import ContextFactory from baseplate.lib import config -from baseplate.lib.message_queue import MessageQueue -from baseplate.lib.message_queue import TimedOutError - +from baseplate.lib.message_queue import MessageQueue, TimedOutError MAX_EVENT_SIZE = 102400 MAX_QUEUE_SIZE = 10000 diff --git a/baseplate/lib/file_watcher.py b/baseplate/lib/file_watcher.py index 1bc9f38d9..f4adab7e7 100644 --- a/baseplate/lib/file_watcher.py +++ b/baseplate/lib/file_watcher.py @@ -34,23 +34,14 @@ would change whenever the underlying file changes. 
""" + import logging import os import typing - -from typing import Callable -from typing import Generic -from typing import IO -from typing import NamedTuple -from typing import Optional -from typing import Tuple -from typing import Type -from typing import TypeVar -from typing import Union +from typing import IO, Callable, Generic, NamedTuple, Optional, TypeVar, Union from baseplate.lib.retry import RetryPolicy - logger = logging.getLogger(__name__) DEFAULT_FILEWATCHER_BACKOFF = 0.01 @@ -121,7 +112,7 @@ def __init__( self._path = path self._parser = parser self._mtime = 0.0 - self._data: Union[T, Type[_NOT_LOADED]] = _NOT_LOADED + self._data: Union[T, type[_NOT_LOADED]] = _NOT_LOADED self._open_options = _OpenOptions( mode="rb" if binary else "r", encoding=encoding or ("UTF-8" if not binary else None), @@ -165,7 +156,7 @@ def get_data(self) -> T: """ return self.get_data_and_mtime()[0] - def get_data_and_mtime(self) -> Tuple[T, float]: + def get_data_and_mtime(self) -> tuple[T, float]: """Return tuple of the current contents of the file and file mtime. The watcher ensures that the file is re-loaded and parsed whenever its diff --git a/baseplate/lib/live_data/__init__.py b/baseplate/lib/live_data/__init__.py index a6ed51b74..54aef898e 100644 --- a/baseplate/lib/live_data/__init__.py +++ b/baseplate/lib/live_data/__init__.py @@ -1,4 +1,3 @@ from baseplate.lib.live_data.zookeeper import zookeeper_client_from_config - __all__ = ["zookeeper_client_from_config"] diff --git a/baseplate/lib/live_data/writer.py b/baseplate/lib/live_data/writer.py index 88b0d7b45..d3459ae6c 100644 --- a/baseplate/lib/live_data/writer.py +++ b/baseplate/lib/live_data/writer.py @@ -1,21 +1,19 @@ """Write a file's contents to a node in ZooKeeper.""" + import argparse import configparser import difflib import logging import sys - from typing import BinaryIO from kazoo.client import KazooClient -from kazoo.exceptions import BadVersionError -from kazoo.exceptions import NoNodeError +from kazoo.exceptions import BadVersionError, NoNodeError from baseplate.lib.live_data.zookeeper import zookeeper_client_from_config from baseplate.lib.secrets import secrets_store_from_config from baseplate.server import EnvironmentInterpolation - logger = logging.getLogger(__name__) diff --git a/baseplate/lib/live_data/zookeeper.py b/baseplate/lib/live_data/zookeeper.py index 2f5f57352..586780df8 100644 --- a/baseplate/lib/live_data/zookeeper.py +++ b/baseplate/lib/live_data/zookeeper.py @@ -1,4 +1,5 @@ """Helpers for interacting with ZooKeeper.""" + from typing import Optional from kazoo.client import KazooClient diff --git a/baseplate/lib/message_queue.py b/baseplate/lib/message_queue.py index e521772cf..063ffddbf 100644 --- a/baseplate/lib/message_queue.py +++ b/baseplate/lib/message_queue.py @@ -1,6 +1,6 @@ """A Gevent-friendly POSIX message queue.""" -import select +import select from typing import Optional import posix_ipc diff --git a/baseplate/lib/metrics.py b/baseplate/lib/metrics.py index cceec8bab..405098f62 100644 --- a/baseplate/lib/metrics.py +++ b/baseplate/lib/metrics.py @@ -42,23 +42,17 @@ .. 
_StatsD: https://github.com/statsd/statsd """ + import collections import errno import logging import socket import time - from types import TracebackType -from typing import Any -from typing import DefaultDict -from typing import Dict -from typing import List -from typing import Optional -from typing import Type +from typing import Any, Optional from baseplate.lib import config - logger = logging.getLogger(__name__) @@ -66,7 +60,7 @@ def _metric_join(*nodes: bytes) -> bytes: return b".".join(node.strip(b".") for node in nodes if node) -def _format_tags(tags: Optional[Dict[str, Any]]) -> Optional[bytes]: +def _format_tags(tags: Optional[dict[str, Any]]) -> Optional[bytes]: if not tags: return None @@ -141,7 +135,7 @@ class BufferedTransport(Transport): def __init__(self, transport: Transport): self.transport = transport - self.buffer: List[bytes] = [] + self.buffer: list[bytes] = [] def send(self, serialized_metric: bytes) -> None: self.buffer.append(serialized_metric) @@ -156,10 +150,10 @@ def flush(self) -> None: class BaseClient: def __init__(self, transport: Transport, namespace: str): self.transport = transport - self.base_tags: Dict[str, Any] = {} + self.base_tags: dict[str, Any] = {} self.namespace = namespace.encode("ascii") - def timer(self, name: str, tags: Optional[Dict[str, Any]] = None) -> "Timer": + def timer(self, name: str, tags: Optional[dict[str, Any]] = None) -> "Timer": """Return a Timer with the given name. :param name: The name the timer should have. @@ -168,7 +162,7 @@ def timer(self, name: str, tags: Optional[Dict[str, Any]] = None) -> "Timer": timer_name = _metric_join(self.namespace, name.encode("ascii")) return Timer(self.transport, timer_name, {**self.base_tags, **(tags or {})}) - def counter(self, name: str, tags: Optional[Dict[str, Any]] = None) -> "Counter": + def counter(self, name: str, tags: Optional[dict[str, Any]] = None) -> "Counter": """Return a Counter with the given name. The sample rate is currently up to your application to enforce. @@ -179,7 +173,7 @@ def counter(self, name: str, tags: Optional[Dict[str, Any]] = None) -> "Counter" counter_name = _metric_join(self.namespace, name.encode("ascii")) return Counter(self.transport, counter_name, {**self.base_tags, **(tags or {})}) - def gauge(self, name: str, tags: Optional[Dict[str, Any]] = None) -> "Gauge": + def gauge(self, name: str, tags: Optional[dict[str, Any]] = None) -> "Gauge": """Return a Gauge with the given name. :param name: The name the gauge should have. @@ -188,7 +182,7 @@ def gauge(self, name: str, tags: Optional[Dict[str, Any]] = None) -> "Gauge": gauge_name = _metric_join(self.namespace, name.encode("ascii")) return Gauge(self.transport, gauge_name, {**self.base_tags, **(tags or {})}) - def histogram(self, name: str, tags: Optional[Dict[str, Any]] = None) -> "Histogram": + def histogram(self, name: str, tags: Optional[dict[str, Any]] = None) -> "Histogram": """Return a Histogram with the given name. :param name: The name the histogram should have. 
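
The `timer`/`counter`/`gauge`/`histogram` helpers above all take the same optional `tags` dict; for orientation, a quick usage sketch assuming this module's `metrics_client_from_config` helper (metric names and tags are invented):

    from baseplate.lib.metrics import metrics_client_from_config

    client = metrics_client_from_config(raw_config)
    with client.timer("request.latency", tags={"endpoint": "home"}):
        handle_request()  # stand-in for real work
    client.counter("request.count", tags={"endpoint": "home"}).increment()
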
@@ -228,14 +222,14 @@ def __init__(self, transport: Transport, namespace: bytes): self.transport = BufferedTransport(transport) self.namespace = namespace self.base_tags = {} - self.counters: Dict[bytes, BatchCounter] = {} + self.counters: dict[bytes, BatchCounter] = {} def __enter__(self) -> "Batch": return self def __exit__( self, - exc_type: Optional[Type[BaseException]], + exc_type: Optional[type[BaseException]], value: Optional[BaseException], traceback: Optional[TracebackType], ) -> None: @@ -256,14 +250,14 @@ def flush(self) -> None: ) logger.warning( "Metrics batch of %d bytes is too large to send, flush more often or reduce " - "amount done in this request. See https://baseplate.readthedocs.io/en/latest/guide/faq.html#what-do-i-do-about-metrics-batch-of-n-bytes-is-too-large-to-send. Top counters: %s", + "amount done in this request. See https://baseplate.readthedocs.io/en/latest/guide/faq.html#what-do-i-do-about-metrics-batch-of-n-bytes-is-too-large-to-send. Top counters: %s", # noqa: E501 exc.message_size, ", ".join(f"{c.name.decode()}={c.total:.0f}" for c in counters_by_total[:10]), ) except TransportError as exc: logger.warning("Failed to send metrics batch: %s", exc) - def counter(self, name: str, tags: Optional[Dict[str, Any]] = None) -> "Counter": + def counter(self, name: str, tags: Optional[dict[str, Any]] = None) -> "Counter": """Return a BatchCounter with the given name. The sample rate is currently up to your application to enforce. @@ -295,7 +289,7 @@ def __init__( self, transport: Transport, name: bytes, - tags: Optional[Dict[str, Any]] = None, + tags: Optional[dict[str, Any]] = None, ): self.transport = transport self.name = name @@ -343,7 +337,7 @@ def send(self, elapsed: float, sample_rate: float = 1.0) -> None: serialized = b"|".join([serialized, sampling_info]) self.transport.send(serialized) - def update_tags(self, tags: Dict) -> None: + def update_tags(self, tags: dict) -> None: assert not self.stopped self.tags.update(tags) @@ -352,7 +346,7 @@ def __enter__(self) -> None: def __exit__( self, - exc_type: Optional[Type[BaseException]], + exc_type: Optional[type[BaseException]], value: Optional[BaseException], traceback: Optional[TracebackType], ) -> None: @@ -363,7 +357,7 @@ def __exit__( class Counter: """A counter for counting events over time.""" - def __init__(self, transport: Transport, name: bytes, tags: Optional[Dict[str, Any]] = None): + def __init__(self, transport: Transport, name: bytes, tags: Optional[dict[str, Any]] = None): self.transport = transport self.name = name self.tags = tags @@ -423,9 +417,9 @@ class BatchCounter(Counter): should be applied to "counter_name". 
""" - def __init__(self, transport: Transport, name: bytes, tags: Optional[Dict[str, Any]] = None): + def __init__(self, transport: Transport, name: bytes, tags: Optional[dict[str, Any]] = None): super().__init__(transport, name) - self.packets: DefaultDict[float, float] = collections.defaultdict(float) + self.packets: collections.defaultdict[float, float] = collections.defaultdict(float) self.tags = tags def increment(self, delta: float = 1.0, sample_rate: float = 1.0) -> None: @@ -470,7 +464,7 @@ def __init__( self, transport: Transport, name: bytes, - tags: Optional[Dict[str, Any]] = None, + tags: Optional[dict[str, Any]] = None, ) -> None: self.transport = transport self.name = name @@ -505,7 +499,7 @@ def __init__( self, transport: Transport, name: bytes, - tags: Optional[Dict[str, Any]] = None, + tags: Optional[dict[str, Any]] = None, ): self.transport = transport self.name = name diff --git a/baseplate/lib/prometheus_metrics.py b/baseplate/lib/prometheus_metrics.py index d8ee37422..dd91a94f6 100644 --- a/baseplate/lib/prometheus_metrics.py +++ b/baseplate/lib/prometheus_metrics.py @@ -1,8 +1,5 @@ -from typing import Dict - from baseplate.lib import config - # default_latency_buckets creates the default bucket values for time based histogram metrics. # we want this to match the baseplate.go default_buckets # bp.go v0 ref: https://github.com/reddit/baseplate.go/blob/master/prometheusbp/metrics.go. @@ -32,7 +29,7 @@ default_size_factor = 2 default_size_count = 20 default_size_buckets = [ - default_size_start * default_size_factor ** i for i in range(default_size_count) + default_size_start * default_size_factor**i for i in range(default_size_count) ] @@ -43,7 +40,7 @@ def getHTTPSuccessLabel(httpStatusCode: int) -> str: return str(200 <= httpStatusCode < 400).lower() -def is_metrics_enabled(raw_config: Dict[str, str]) -> bool: +def is_metrics_enabled(raw_config: dict[str, str]) -> bool: cfg = config.parse_config( raw_config, { diff --git a/baseplate/lib/propagator_redditb3_http.py b/baseplate/lib/propagator_redditb3_http.py index 555310b52..cb495c78f 100644 --- a/baseplate/lib/propagator_redditb3_http.py +++ b/baseplate/lib/propagator_redditb3_http.py @@ -1,19 +1,18 @@ import logging - +from collections.abc import Iterable from re import compile as re_compile -from typing import Any -from typing import Iterable -from typing import Optional -from typing import Set +from typing import Any, Optional from opentelemetry import trace from opentelemetry.context import Context -from opentelemetry.propagators.textmap import CarrierT -from opentelemetry.propagators.textmap import default_getter -from opentelemetry.propagators.textmap import default_setter -from opentelemetry.propagators.textmap import Getter -from opentelemetry.propagators.textmap import Setter -from opentelemetry.propagators.textmap import TextMapPropagator +from opentelemetry.propagators.textmap import ( + CarrierT, + Getter, + Setter, + TextMapPropagator, + default_getter, + default_setter, +) from opentelemetry.trace import format_span_id logger = logging.getLogger(__name__) @@ -93,7 +92,7 @@ def extract( or self._id_regex.fullmatch(extracted_span_id) is None ): logger.debug( - "No valid b3 traces headers in request. Aborting. [carrier=%s, context=%s, trace_id=%s, span_id=%s]", + "No valid b3 traces headers in request. Aborting. 
[carrier=%s, context=%s, trace_id=%s, span_id=%s]", # noqa: E501 carrier, context, extracted_trace_id, @@ -157,7 +156,7 @@ def inject( setter.set(carrier, self.SAMPLED_KEY, "1" if sampled else "0") @property - def fields(self) -> Set[str]: + def fields(self) -> set[str]: return { self.TRACE_ID_KEY, self.SPAN_ID_KEY, diff --git a/baseplate/lib/propagator_redditb3_thrift.py b/baseplate/lib/propagator_redditb3_thrift.py index 4522760d5..38f33178d 100644 --- a/baseplate/lib/propagator_redditb3_thrift.py +++ b/baseplate/lib/propagator_redditb3_thrift.py @@ -1,19 +1,18 @@ import logging - +from collections.abc import Iterable from re import compile as re_compile -from typing import Any -from typing import Iterable -from typing import Optional -from typing import Set +from typing import Any, Optional from opentelemetry import trace from opentelemetry.context import Context -from opentelemetry.propagators.textmap import CarrierT -from opentelemetry.propagators.textmap import default_getter -from opentelemetry.propagators.textmap import default_setter -from opentelemetry.propagators.textmap import Getter -from opentelemetry.propagators.textmap import Setter -from opentelemetry.propagators.textmap import TextMapPropagator +from opentelemetry.propagators.textmap import ( + CarrierT, + Getter, + Setter, + TextMapPropagator, + default_getter, + default_setter, +) from opentelemetry.trace import format_span_id logger = logging.getLogger(__name__) @@ -76,7 +75,7 @@ def extract( or self._id_regex.fullmatch(extracted_span_id) is None ): logger.debug( - "No valid b3 traces headers in request. Aborting. [carrier=%s, context=%s, trace_id=%s, span_id=%s]", + "No valid b3 traces headers in request. Aborting. [carrier=%s, context=%s, trace_id=%s, span_id=%s]", # noqa: E501 carrier, context, extracted_trace_id, @@ -140,7 +139,7 @@ def inject( setter.set(carrier, self.SAMPLED_KEY, "1" if sampled else "0") @property - def fields(self) -> Set[str]: + def fields(self) -> set[str]: return { self.TRACE_ID_KEY, self.SPAN_ID_KEY, diff --git a/baseplate/lib/random.py b/baseplate/lib/random.py index 048e490be..f48babb21 100644 --- a/baseplate/lib/random.py +++ b/baseplate/lib/random.py @@ -1,16 +1,10 @@ """Extensions to the standard library `random` module.""" + import bisect import random import typing - -from typing import Callable -from typing import Generic -from typing import Iterable -from typing import List -from typing import Optional -from typing import Set -from typing import TypeVar - +from collections.abc import Iterable +from typing import Callable, Generic, Optional, TypeVar T = TypeVar("T") @@ -49,7 +43,7 @@ class WeightedLottery(Generic[T]): """ def __init__(self, items: Iterable[T], weight_key: Callable[[T], int]): - self.weights: List[int] = [] + self.weights: list[int] = [] self.items = list(items) if not self.items: raise ValueError("items must not be empty") @@ -85,8 +79,8 @@ def sample(self, sample_size: int) -> Iterable[T]: if not 0 <= sample_size < len(self.items): raise ValueError("sample size is negative or larger than the population") - already_picked: Set[int] = set() - results: List[Optional[T]] = [None] * sample_size + already_picked: set[int] = set() + results: list[Optional[T]] = [None] * sample_size # we use indexes in the set so we don't add a hashability requirement # to the items in the population. 
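
For reference while reviewing the `WeightedLottery` typing changes, a small usage sketch; the items and weights are invented, and `pick` is assumed to be the documented single-draw counterpart to `sample`:

    from baseplate.lib.random import WeightedLottery

    backends = [("a", 1), ("b", 3), ("c", 6)]
    lottery = WeightedLottery(backends, weight_key=lambda item: item[1])
    one = lottery.pick()     # single weighted draw; ("c", 6) about 60% of the time
    two = lottery.sample(2)  # two distinct weighted draws
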
@@ -96,4 +90,4 @@ def sample(self, sample_size: int) -> Iterable[T]: picked_index = self._pick_index() results[i] = self.items[picked_index] already_picked.add(picked_index) - return typing.cast(List[T], results) + return typing.cast(list[T], results) diff --git a/baseplate/lib/ratelimit/__init__.py b/baseplate/lib/ratelimit/__init__.py index 3655b7449..f26068b21 100644 --- a/baseplate/lib/ratelimit/__init__.py +++ b/baseplate/lib/ratelimit/__init__.py @@ -1,6 +1,7 @@ -from baseplate.lib.ratelimit.ratelimit import RateLimiter -from baseplate.lib.ratelimit.ratelimit import RateLimiterContextFactory -from baseplate.lib.ratelimit.ratelimit import RateLimitExceededException - +from baseplate.lib.ratelimit.ratelimit import ( + RateLimiter, + RateLimiterContextFactory, + RateLimitExceededException, +) __all__ = ["RateLimiter", "RateLimitExceededException", "RateLimiterContextFactory"] diff --git a/baseplate/lib/ratelimit/backends/memcache.py b/baseplate/lib/ratelimit/backends/memcache.py index 4c082a14f..b745c866d 100644 --- a/baseplate/lib/ratelimit/backends/memcache.py +++ b/baseplate/lib/ratelimit/backends/memcache.py @@ -2,10 +2,8 @@ from baseplate import Span from baseplate.clients import ContextFactory -from baseplate.clients.memcache import MemcacheContextFactory -from baseplate.clients.memcache import MonitoredMemcacheConnection -from baseplate.lib.ratelimit.backends import _get_current_bucket -from baseplate.lib.ratelimit.backends import RateLimitBackend +from baseplate.clients.memcache import MemcacheContextFactory, MonitoredMemcacheConnection +from baseplate.lib.ratelimit.backends import RateLimitBackend, _get_current_bucket class MemcacheRateLimitBackendContextFactory(ContextFactory): diff --git a/baseplate/lib/ratelimit/backends/redis.py b/baseplate/lib/ratelimit/backends/redis.py index 29da34cae..7b02e1c78 100644 --- a/baseplate/lib/ratelimit/backends/redis.py +++ b/baseplate/lib/ratelimit/backends/redis.py @@ -2,10 +2,8 @@ from baseplate import Span from baseplate.clients import ContextFactory -from baseplate.clients.redis import MonitoredRedisConnection -from baseplate.clients.redis import RedisContextFactory -from baseplate.lib.ratelimit.backends import _get_current_bucket -from baseplate.lib.ratelimit.backends import RateLimitBackend +from baseplate.clients.redis import MonitoredRedisConnection, RedisContextFactory +from baseplate.lib.ratelimit.backends import RateLimitBackend, _get_current_bucket class RedisRateLimitBackendContextFactory(ContextFactory): diff --git a/baseplate/lib/retry.py b/baseplate/lib/retry.py index e83f1b9ab..4f0614189 100644 --- a/baseplate/lib/retry.py +++ b/baseplate/lib/retry.py @@ -1,7 +1,7 @@ """Policies for retrying an operation safely.""" -import time -from typing import Iterator +import time +from collections.abc import Iterator from typing import Optional diff --git a/baseplate/lib/secrets.py b/baseplate/lib/secrets.py index 820fb6c3a..d026bcded 100644 --- a/baseplate/lib/secrets.py +++ b/baseplate/lib/secrets.py @@ -1,27 +1,18 @@ """Application integration with the secret fetcher daemon.""" + import base64 import binascii import json import logging import os - +from collections.abc import Iterator from pathlib import Path -from typing import Any -from typing import Dict -from typing import Iterator -from typing import NamedTuple -from typing import Optional -from typing import Protocol -from typing import Tuple +from typing import Any, NamedTuple, Optional, Protocol from baseplate import Span from baseplate.clients import ContextFactory -from 
baseplate.lib import cached_property -from baseplate.lib import config -from baseplate.lib import warn_deprecated -from baseplate.lib.file_watcher import FileWatcher -from baseplate.lib.file_watcher import WatchedFileNotAvailableError - +from baseplate.lib import cached_property, config, warn_deprecated +from baseplate.lib.file_watcher import FileWatcher, WatchedFileNotAvailableError ISO_FORMAT = "%Y-%m-%dT%H:%M:%S.%fZ" @@ -121,11 +112,10 @@ def _decode_secret(path: str, encoding: str, value: str) -> bytes: class SecretParser(Protocol): - def __call__(self, data: Dict[str, Any], secret_path: str = "") -> Dict[str, str]: - ... + def __call__(self, data: dict[str, Any], secret_path: str = "") -> dict[str, str]: ... -def parse_secrets_fetcher(data: Dict[str, Any], secret_path: str = "") -> Dict[str, str]: +def parse_secrets_fetcher(data: dict[str, Any], secret_path: str = "") -> dict[str, str]: try: return data["secrets"][secret_path] except KeyError: @@ -133,7 +123,7 @@ def parse_secrets_fetcher(data: Dict[str, Any], secret_path: str = "") -> Dict[s # pylint: disable=unused-argument -def parse_vault_csi(data: Dict[str, Any], secret_path: str = "") -> Dict[str, str]: +def parse_vault_csi(data: dict[str, Any], secret_path: str = "") -> dict[str, str]: return data["data"] @@ -159,13 +149,13 @@ def __init__( self.parser = parser or parse_secrets_fetcher self._filewatcher = FileWatcher(path, json.load, timeout=timeout, backoff=backoff) - def _get_data(self) -> Tuple[Any, float]: + def _get_data(self) -> tuple[Any, float]: try: return self._filewatcher.get_data_and_mtime() except WatchedFileNotAvailableError as exc: raise SecretsNotAvailableError(exc) - def get_raw(self, path: str) -> Dict[str, str]: + def get_raw(self, path: str) -> dict[str, str]: """Return a dictionary of key/value pairs for the given secret path. This is the raw representation of the secret in the underlying store. @@ -248,7 +238,7 @@ def get_vault_token(self) -> str: data, _ = self._get_data() return data["vault"]["token"] - def get_raw_and_mtime(self, secret_path: str) -> Tuple[Dict[str, str], float]: + def get_raw_and_mtime(self, secret_path: str) -> tuple[dict[str, str], float]: """Return raw secret and modification time. This returns the same data as :py:meth:`get_raw` as well as a UNIX @@ -262,7 +252,7 @@ def get_raw_and_mtime(self, secret_path: str) -> Tuple[Dict[str, str], float]: data, mtime = self._get_data() return self.parser(data, secret_path), mtime - def get_credentials_and_mtime(self, path: str) -> Tuple[CredentialSecret, float]: + def get_credentials_and_mtime(self, path: str) -> tuple[CredentialSecret, float]: """Return credentials secret and modification time. This returns the same data as :py:meth:`get_credentials` as well as a @@ -297,7 +287,7 @@ def get_credentials_and_mtime(self, path: str) -> Tuple[CredentialSecret, float] return CredentialSecret(**values), mtime - def get_simple_and_mtime(self, path: str) -> Tuple[bytes, float]: + def get_simple_and_mtime(self, path: str) -> tuple[bytes, float]: """Return simple secret and modification time. 
This returns the same data as :py:meth:`get_simple` as well as a UNIX @@ -321,7 +311,7 @@ def get_simple_and_mtime(self, path: str) -> Tuple[bytes, float]: encoding = secret_attributes.get("encoding", "identity") return _decode_secret(path, encoding, value), mtime - def get_versioned_and_mtime(self, path: str) -> Tuple[VersionedSecret, float]: + def get_versioned_and_mtime(self, path: str) -> tuple[VersionedSecret, float]: """Return versioned secret and modification time. This returns the same data as :py:meth:`get_versioned` as well as a @@ -371,17 +361,15 @@ def make_object_for_context(self, name: str, span: Span) -> "SecretsStore": class _CachingSecretsStore(SecretsStore): """Lazily load and cache the parsed data until the server span ends.""" - def __init__( - self, filewatcher: FileWatcher, parser: SecretParser - ): # pylint: disable=super-init-not-called + def __init__(self, filewatcher: FileWatcher, parser: SecretParser): # pylint: disable=super-init-not-called self._filewatcher = filewatcher self.parser = parser @cached_property - def _data(self) -> Tuple[Any, float]: + def _data(self) -> tuple[Any, float]: return super()._get_data() - def _get_data(self) -> Tuple[Dict, float]: + def _get_data(self) -> tuple[dict, float]: return self._data @@ -403,7 +391,7 @@ class VaultCSISecretsStore(SecretsStore): path: Path data_symlink: Path - cache: Dict[str, VaultCSIEntry] + cache: dict[str, VaultCSIEntry] def __init__( self, @@ -418,7 +406,7 @@ def __init__( raise ValueError(f"Expected {self.path} to be a directory.") if not self.data_symlink.is_dir(): raise ValueError( - f"Expected {self.data_symlink} to be a directory. Verify {self.path} is the root of the Vault CSI mount." + f"Expected {self.data_symlink} to be a directory. Verify {self.path} is the root of the Vault CSI mount." # noqa: E501 ) def get_vault_url(self) -> str: @@ -435,12 +423,12 @@ def _get_mtime(self) -> float: def _raw_secret(self, name: str) -> Any: try: - with open(self.data_symlink.joinpath(name), "r", encoding="UTF-8") as fp: + with open(self.data_symlink.joinpath(name), encoding="UTF-8") as fp: return self.parser(json.load(fp)) except FileNotFoundError as exc: raise SecretNotFoundError(name) from exc - def get_raw_and_mtime(self, secret_path: str) -> Tuple[Dict[str, str], float]: + def get_raw_and_mtime(self, secret_path: str) -> tuple[dict[str, str], float]: mtime = self._get_mtime() if cache_entry := self.cache.get(secret_path): if cache_entry.mtime == mtime: @@ -476,7 +464,8 @@ def secrets_store_from_config( to "secrets." :param backoff: retry backoff time for secrets file watcher. Defaults to None, which is mapped to DEFAULT_FILEWATCHER_BACKOFF. - :param provider: The secrets provider, acceptable values are 'vault' and 'vault_csi'. Defaults to 'vault' + :param provider: The secrets provider, acceptable values are 'vault' and + 'vault_csi'. 
Defaults to 'vault' """ assert prefix.endswith(".") diff --git a/baseplate/lib/service_discovery.py b/baseplate/lib/service_discovery.py index e8f2c9242..37a5ff82d 100644 --- a/baseplate/lib/service_discovery.py +++ b/baseplate/lib/service_discovery.py @@ -17,18 +17,13 @@ print(backend.endpoint.address) """ -import json -from typing import IO -from typing import List -from typing import NamedTuple -from typing import Optional -from typing import Sequence +import json +from collections.abc import Sequence +from typing import IO, NamedTuple, Optional -from baseplate.lib.config import Endpoint -from baseplate.lib.config import EndpointConfiguration -from baseplate.lib.file_watcher import FileWatcher -from baseplate.lib.file_watcher import WatchedFileNotAvailableError +from baseplate.lib.config import Endpoint, EndpointConfiguration +from baseplate.lib.file_watcher import FileWatcher, WatchedFileNotAvailableError from baseplate.lib.random import WeightedLottery @@ -60,7 +55,7 @@ class Backend(NamedTuple): class _Inventory(NamedTuple): - backends: List[Backend] + backends: list[Backend] lottery: Optional[WeightedLottery[Backend]] diff --git a/baseplate/lib/thrift_pool.py b/baseplate/lib/thrift_pool.py index 915a1c057..9b9600fc9 100644 --- a/baseplate/lib/thrift_pool.py +++ b/baseplate/lib/thrift_pool.py @@ -14,36 +14,29 @@ client.do_example_thing() """ + import contextlib import logging import queue import socket import time - -from typing import Any -from typing import Generator -from typing import Optional -from typing import Type -from typing import TYPE_CHECKING +from collections.abc import Generator +from typing import TYPE_CHECKING, Any, Optional from thrift.protocol import THeaderProtocol -from thrift.protocol.TProtocol import TProtocolBase -from thrift.protocol.TProtocol import TProtocolException -from thrift.protocol.TProtocol import TProtocolFactory -from thrift.Thrift import TApplicationException -from thrift.Thrift import TException +from thrift.protocol.TProtocol import TProtocolBase, TProtocolException, TProtocolFactory +from thrift.Thrift import TApplicationException, TException from thrift.transport.TSocket import TSocket from thrift.transport.TTransport import TTransportException from baseplate.lib import config from baseplate.lib.retry import RetryPolicy - logger = logging.getLogger(__name__) if TYPE_CHECKING: - ProtocolPool = Type[queue.Queue[TProtocolBase]] # pylint: disable=unsubscriptable-object + ProtocolPool = type[queue.Queue[TProtocolBase]] # pylint: disable=unsubscriptable-object else: ProtocolPool = queue.Queue diff --git a/baseplate/lib/tracing.py b/baseplate/lib/tracing.py index a929829ec..8655053f7 100644 --- a/baseplate/lib/tracing.py +++ b/baseplate/lib/tracing.py @@ -1,17 +1,11 @@ +from collections.abc import Sequence from typing import Optional -from typing import Sequence from opentelemetry.context import Context -from opentelemetry.sdk.trace.sampling import Decision -from opentelemetry.sdk.trace.sampling import Sampler -from opentelemetry.sdk.trace.sampling import SamplingResult -from opentelemetry.trace import Link -from opentelemetry.trace import SpanKind -from opentelemetry.trace import TraceState +from opentelemetry.sdk.trace.sampling import Decision, Sampler, SamplingResult +from opentelemetry.trace import Link, SpanKind, TraceState from opentelemetry.util.types import Attributes -from pyrate_limiter import Duration -from pyrate_limiter import Limiter -from pyrate_limiter import Rate +from pyrate_limiter import Duration, Limiter, Rate class 
RateLimited(Sampler):
@@ -39,7 +33,6 @@ def should_sample(
         links: Optional[Sequence[Link]] = None,
         trace_state: Optional[TraceState] = None,
     ) -> SamplingResult:
-
         res = self.sampler.should_sample(
             parent_context, trace_id, name, kind, attributes, links, trace_state
         )
diff --git a/baseplate/lint/db_query_string_format_plugin.py b/baseplate/lint/db_query_string_format_plugin.py
index 273c5bb34..78b77fb81 100644
--- a/baseplate/lint/db_query_string_format_plugin.py
+++ b/baseplate/lint/db_query_string_format_plugin.py
@@ -8,7 +8,7 @@ class NoDbQueryStringFormatChecker(BaseChecker):
     priority = -1
     msgs = {
         "W9000": (
-            "Python string formatting found in database query. Database queries should use native parameter substitution.",
+            "Python string formatting found in database query. Database queries should use native parameter substitution.",  # noqa: E501
             "database-query-string-format",
             "This allows CQL/SQL injection.",
         )
diff --git a/baseplate/lint/example_plugin.py b/baseplate/lint/example_plugin.py
index d8409debd..e38a9803b 100644
--- a/baseplate/lint/example_plugin.py
+++ b/baseplate/lint/example_plugin.py
@@ -1,7 +1,16 @@
-# Pylint documentation for writing a checker: http://pylint.pycqa.org/en/latest/how_tos/custom_checkers.html
-# This is an example of a Pylint AST checker and should not be registered to use
-# In an AST (abstract syntax tree) checker, the code will be represented as nodes of a tree
-# We will use the astroid library: https://astroid.readthedocs.io/en/latest/api/general.html to visit and leave nodes
+# Pylint documentation for writing a checker:
+# http://pylint.pycqa.org/en/latest/how_tos/custom_checkers.html
+#
+# This is an example of a Pylint AST checker and should not be registered
+# for use.
+#
+# In an AST (abstract syntax tree) checker, the code will be represented as
+# nodes of a tree.
+#
+# We will use the astroid library:
+# https://astroid.readthedocs.io/en/latest/api/general.html to visit and leave
+# nodes.
+#
 # Libraries needed for an AST checker
 from astroid import nodes
 from pylint.checkers import BaseChecker
@@ -9,9 +18,9 @@
 # Basic example of a Pylint AST (abstract syntax tree) checker
-# Checks for variables that have been reassigned in a function.
If it finds a +# reassigned variable, it will throw an error class NoReassignmentChecker(BaseChecker): - # Checker name name = "no-reassigned-variable" # Set priority to -1 @@ -19,7 +28,8 @@ class NoReassignmentChecker(BaseChecker): # Message dictionary msgs = { # message-id, consists of a letter and numbers - # Letter will be one of following letters (C=Convention, W=Warning, E=Error, F=Fatal, R=Refactoring) + # Letter will be one of following letters (C=Convention, W=Warning, + # E=Error, F=Fatal, R=Refactoring) # Numbers need to be unique and in-between 9000-9999 # Check https://baseplate.readthedocs.io/en/stable/linters/index.html#custom-checkers-list # for numbers that are already in use diff --git a/baseplate/observers/logging.py b/baseplate/observers/logging.py index 260a02e7c..aa2d145a3 100644 --- a/baseplate/observers/logging.py +++ b/baseplate/observers/logging.py @@ -1,8 +1,6 @@ import threading -from baseplate import BaseplateObserver -from baseplate import RequestContext -from baseplate import Span +from baseplate import BaseplateObserver, RequestContext, Span class LoggingBaseplateObserver(BaseplateObserver): diff --git a/baseplate/observers/metrics.py b/baseplate/observers/metrics.py index 8ceb71364..14eb45400 100644 --- a/baseplate/observers/metrics.py +++ b/baseplate/observers/metrics.py @@ -1,15 +1,8 @@ from random import random -from typing import Any -from typing import Optional - -from baseplate import _ExcInfo -from baseplate import BaseplateObserver -from baseplate import LocalSpan -from baseplate import RequestContext -from baseplate import Span -from baseplate import SpanObserver -from baseplate.lib import config -from baseplate.lib import metrics +from typing import Any, Optional + +from baseplate import BaseplateObserver, LocalSpan, RequestContext, Span, SpanObserver, _ExcInfo +from baseplate.lib import config, metrics from baseplate.observers.timeout import ServerTimeout diff --git a/baseplate/observers/metrics_tagged.py b/baseplate/observers/metrics_tagged.py index c1321c7f9..cf3090f3c 100644 --- a/baseplate/observers/metrics_tagged.py +++ b/baseplate/observers/metrics_tagged.py @@ -1,23 +1,15 @@ from random import random -from typing import Any -from typing import Dict -from typing import Optional -from typing import Set +from typing import Any, Optional -from baseplate import _ExcInfo -from baseplate import BaseplateObserver -from baseplate import LocalSpan -from baseplate import RequestContext -from baseplate import Span -from baseplate import SpanObserver -from baseplate.lib import config -from baseplate.lib import metrics +from baseplate import BaseplateObserver, LocalSpan, RequestContext, Span, SpanObserver, _ExcInfo +from baseplate.lib import config, metrics class TaggedMetricsBaseplateObserver(BaseplateObserver): """Metrics collecting observer. - This observer reports metrics to statsd in the Influx StatsD format. It does three important things: + This observer reports metrics to statsd in the Influx StatsD format. It + does three important things: * it tracks the time taken in serving each request. 
* it batches all metrics generated during a request into as few packets @@ -32,7 +24,7 @@ class TaggedMetricsBaseplateObserver(BaseplateObserver): """ - def __init__(self, client: metrics.Client, allowlist: Set[str], sample_rate: float = 1.0): + def __init__(self, client: metrics.Client, allowlist: set[str], sample_rate: float = 1.0): self.client = client self.allowlist = allowlist self.sample_rate = sample_rate @@ -88,15 +80,15 @@ def on_child_span_created(self, span: Span) -> None: class TaggedMetricsServerSpanObserver(SpanObserver): def __init__( - self, batch: metrics.Batch, server_span: Span, allowlist: Set[str], sample_rate: float = 1.0 + self, batch: metrics.Batch, server_span: Span, allowlist: set[str], sample_rate: float = 1.0 ): self.batch = batch self.span = server_span self.base_name = "baseplate.server" self.allowlist = allowlist - self.tags: Dict[str, Any] = {} + self.tags: dict[str, Any] = {} self.timer = batch.timer(f"{self.base_name}.latency") - self.counters: Dict[str, float] = {} + self.counters: dict[str, float] = {} self.sample_rate = sample_rate def on_start(self) -> None: @@ -139,15 +131,15 @@ def on_finish(self, exc_info: Optional[_ExcInfo]) -> None: class TaggedMetricsLocalSpanObserver(SpanObserver): def __init__( - self, batch: metrics.Batch, span: Span, allowlist: Set[str], sample_rate: float = 1.0 + self, batch: metrics.Batch, span: Span, allowlist: set[str], sample_rate: float = 1.0 ): self.batch = batch self.span = span - self.tags: Dict[str, Any] = {} + self.tags: dict[str, Any] = {} self.base_name = "baseplate.local" self.timer = batch.timer(f"{self.base_name}.latency") self.allowlist = allowlist - self.counters: Dict[str, float] = {} + self.counters: dict[str, float] = {} self.sample_rate = sample_rate def on_start(self) -> None: @@ -191,15 +183,15 @@ def on_finish(self, exc_info: Optional[_ExcInfo]) -> None: class TaggedMetricsClientSpanObserver(SpanObserver): def __init__( - self, batch: metrics.Batch, span: Span, allowlist: Set[str], sample_rate: float = 1.0 + self, batch: metrics.Batch, span: Span, allowlist: set[str], sample_rate: float = 1.0 ): self.batch = batch self.span = span self.base_name = "baseplate.client" - self.tags: Dict[str, Any] = {} + self.tags: dict[str, Any] = {} self.timer = batch.timer(f"{self.base_name}.latency") self.allowlist = allowlist - self.counters: Dict[str, float] = {} + self.counters: dict[str, float] = {} self.sample_rate = sample_rate def on_start(self) -> None: diff --git a/baseplate/observers/sentry.py b/baseplate/observers/sentry.py index 1f6aa07b3..6a7b95f51 100644 --- a/baseplate/observers/sentry.py +++ b/baseplate/observers/sentry.py @@ -1,22 +1,12 @@ from __future__ import annotations import logging - from types import TracebackType -from typing import Any -from typing import List -from typing import Optional -from typing import Type -from typing import TYPE_CHECKING -from typing import Union +from typing import TYPE_CHECKING, Any import sentry_sdk -from baseplate import _ExcInfo -from baseplate import BaseplateObserver -from baseplate import RequestContext -from baseplate import ServerSpanObserver -from baseplate import Span +from baseplate import BaseplateObserver, RequestContext, ServerSpanObserver, Span, _ExcInfo from baseplate.lib import config from baseplate.observers.timeout import ServerTimeout @@ -83,7 +73,7 @@ def init_sentry_client_from_config(raw_config: config.RawConfig, **kwargs: Any) kwargs.setdefault("sample_rate", cfg.sentry.sample_rate) - ignore_errors: List[Union[type, str]] = [] + 
ignore_errors: list[type | str] = [] ignore_errors.extend(ALWAYS_IGNORE_ERRORS) ignore_errors.extend(cfg.sentry.ignore_errors) kwargs.setdefault("ignore_errors", ignore_errors) @@ -124,7 +114,7 @@ def on_set_tag(self, key: str, value: Any) -> None: def on_log(self, name: str, payload: Any) -> None: self.sentry_hub.add_breadcrumb({"category": name, "message": str(payload)}) - def on_finish(self, exc_info: Optional[_ExcInfo] = None) -> None: + def on_finish(self, exc_info: _ExcInfo | None = None) -> None: if exc_info is not None: self.sentry_hub.capture_exception(error=exc_info) self.scope_manager.__exit__(None, None, None) @@ -155,9 +145,9 @@ def __init__(self, hub: GeventHub): def __call__( self, context: Any, - exc_type: Optional[Type[BaseException]], - value: Optional[BaseException], - tb: Optional[TracebackType], + exc_type: type[BaseException] | None, + value: BaseException | None, + tb: TracebackType | None, ) -> None: sentry_sdk.capture_exception((exc_type, value, tb)) diff --git a/baseplate/observers/timeout.py b/baseplate/observers/timeout.py index b009ea933..bc611285e 100644 --- a/baseplate/observers/timeout.py +++ b/baseplate/observers/timeout.py @@ -2,11 +2,7 @@ import gevent -from baseplate import _ExcInfo -from baseplate import BaseplateObserver -from baseplate import RequestContext -from baseplate import ServerSpan -from baseplate import SpanObserver +from baseplate import BaseplateObserver, RequestContext, ServerSpan, SpanObserver, _ExcInfo from baseplate.lib import config diff --git a/baseplate/observers/tracing.py b/baseplate/observers/tracing.py index a593ee1e8..7b8d69c79 100644 --- a/baseplate/observers/tracing.py +++ b/baseplate/observers/tracing.py @@ -1,4 +1,5 @@ """Components for processing Baseplate spans for service request tracing.""" + import collections import json import logging @@ -8,32 +9,17 @@ import threading import time import typing - from datetime import datetime -from typing import Any -from typing import DefaultDict -from typing import Dict -from typing import List -from typing import NamedTuple -from typing import Optional +from typing import Any, NamedTuple, Optional import requests - from requests.exceptions import RequestException -from baseplate import _ExcInfo -from baseplate import BaseplateObserver -from baseplate import LocalSpan -from baseplate import RequestContext -from baseplate import Span -from baseplate import SpanObserver -from baseplate.lib import config -from baseplate.lib import warn_deprecated -from baseplate.lib.message_queue import MessageQueue -from baseplate.lib.message_queue import TimedOutError +from baseplate import BaseplateObserver, LocalSpan, RequestContext, Span, SpanObserver, _ExcInfo +from baseplate.lib import config, warn_deprecated +from baseplate.lib.message_queue import MessageQueue, TimedOutError from baseplate.observers.timeout import ServerTimeout - if typing.TYPE_CHECKING: SpanQueue = queue.Queue["TraceSpanObserver"] # pylint: disable=unsubscriptable-object else: @@ -200,8 +186,8 @@ def __init__(self, service_name: str, hostname: str, span: Span, recorder: "Reco self.start: Optional[int] = None self.end: Optional[int] = None self.elapsed: Optional[int] = None - self.binary_annotations: List[Dict[str, Any]] = [] - self.counters: DefaultDict[str, float] = collections.defaultdict(float) + self.binary_annotations: list[dict[str, Any]] = [] + self.counters: collections.defaultdict[str, float] = collections.defaultdict(float) self.on_set_tag(ANNOTATIONS["COMPONENT"], "baseplate") super().__init__() @@ -236,10 +222,10 
@@ def on_set_tag(self, key: str, value: Any) -> None: def on_incr_tag(self, key: str, delta: float) -> None: self.counters[key] += delta - def _endpoint_info(self) -> Dict[str, str]: + def _endpoint_info(self) -> dict[str, str]: return {"serviceName": self.service_name, "ipv4": self.hostname} - def _create_time_annotation(self, annotation_type: str, timestamp: int) -> Dict[str, Any]: + def _create_time_annotation(self, annotation_type: str, timestamp: int) -> dict[str, Any]: """Create Zipkin-compatible Annotation for a span. This should be used for generating span annotations with a time component, @@ -249,7 +235,7 @@ def _create_time_annotation(self, annotation_type: str, timestamp: int) -> Dict[ def _create_binary_annotation( self, annotation_type: str, annotation_value: Any - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """Create Zipkin-compatible BinaryAnnotation for a span. This should be used for generating span annotations that @@ -267,8 +253,8 @@ def _create_binary_annotation( return {"key": annotation_type, "value": annotation_value, "endpoint": endpoint_info} def _to_span_obj( - self, annotations: List[Dict[str, Any]], binary_annotations: List[Dict[str, Any]] - ) -> Dict[str, Any]: + self, annotations: list[dict[str, Any]], binary_annotations: list[dict[str, Any]] + ) -> dict[str, Any]: span = { "traceId": self.span.trace_id, "name": self.span.name, @@ -282,7 +268,7 @@ def _to_span_obj( span["parentId"] = self.span.parent_id or 0 return span - def _serialize(self) -> Dict[str, Any]: + def _serialize(self) -> dict[str, Any]: """Serialize span information into Zipkin-accepted format.""" annotations = [] @@ -348,7 +334,7 @@ def on_child_span_created(self, span: Span) -> None: ) span.register(trace_observer) - def _serialize(self) -> Dict[str, Any]: + def _serialize(self) -> dict[str, Any]: return self._to_span_obj([], self.binary_annotations) @@ -396,9 +382,9 @@ def on_child_span_created(self, span: Span) -> None: ) span.register(trace_observer) - def _serialize(self) -> Dict[str, Any]: + def _serialize(self) -> dict[str, Any]: """Serialize span information into Zipkin-accepted format.""" - annotations: List[Dict[str, Any]] = [] + annotations: list[dict[str, Any]] = [] annotations.append( self._create_time_annotation( @@ -431,7 +417,7 @@ def __init__( self.flush_worker.daemon = True self.flush_worker.start() - def flush_func(self, spans: List[Dict[str, Any]]) -> None: + def flush_func(self, spans: list[dict[str, Any]]) -> None: raise NotImplementedError def _flush_spans(self) -> None: @@ -440,7 +426,7 @@ def _flush_spans(self) -> None: # empties while being processed before reaching 10 spans, we flush # immediately. 
while True: - spans: List[Dict[str, Any]] = [] + spans: list[dict[str, Any]] = [] try: while len(spans) < self.max_span_batch: spans.append(self.span_queue.get_nowait()._serialize()) @@ -471,7 +457,7 @@ def __init__( ): super().__init__(max_queue_size, num_workers, max_span_batch, batch_wait_interval) - def flush_func(self, spans: List[Dict[str, Any]]) -> None: + def flush_func(self, spans: list[dict[str, Any]]) -> None: """Write a set of spans to debug log.""" for span in spans: self.logger.debug("Span recording: %s", span) @@ -489,7 +475,7 @@ def __init__( ): super().__init__(max_queue_size, num_workers, max_span_batch, batch_wait_interval) - def flush_func(self, spans: List[Dict[str, Any]]) -> None: + def flush_func(self, spans: list[dict[str, Any]]) -> None: return @@ -510,14 +496,13 @@ def __init__( max_span_batch: int = 100, batch_wait_interval: float = 0.5, ): - super().__init__(max_queue_size, num_workers, max_span_batch, batch_wait_interval) adapter = requests.adapters.HTTPAdapter(pool_connections=num_conns, pool_maxsize=num_conns) self.session = requests.Session() self.session.mount("http://", adapter) self.endpoint = f"http://{endpoint}/api/v1/spans" - def flush_func(self, spans: List[Dict[str, Any]]) -> None: + def flush_func(self, spans: list[dict[str, Any]]) -> None: """Send a set of spans to remote collector.""" try: self.session.post( diff --git a/baseplate/server/__init__.py b/baseplate/server/__init__.py index 2351955d2..a0887dbd9 100644 --- a/baseplate/server/__init__.py +++ b/baseplate/server/__init__.py @@ -2,6 +2,7 @@ This command serves your application from the given configuration file. """ + from __future__ import annotations import argparse @@ -21,56 +22,43 @@ import time import traceback import warnings - +from collections.abc import Mapping, MutableMapping, Sequence from dataclasses import dataclass from datetime import datetime from enum import Enum from rlcompleter import Completer from types import FrameType -from typing import Any -from typing import Callable -from typing import Dict -from typing import List -from typing import Mapping -from typing import MutableMapping -from typing import NamedTuple -from typing import Optional -from typing import Sequence -from typing import TextIO -from typing import Tuple +from typing import ( + Any, + Callable, + NamedTuple, + TextIO, +) from gevent.server import StreamServer -from opentelemetry import propagate -from opentelemetry import trace +from opentelemetry import propagate, trace from opentelemetry.context import Context from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter from opentelemetry.instrumentation.logging import LoggingInstrumentor from opentelemetry.instrumentation.threading import ThreadingInstrumentor from opentelemetry.propagators.composite import CompositePropagator -from opentelemetry.sdk.trace import Span -from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace import Span, TracerProvider from opentelemetry.sdk.trace.export import BatchSpanProcessor -from opentelemetry.sdk.trace.sampling import DEFAULT_ON -from opentelemetry.sdk.trace.sampling import ParentBased +from opentelemetry.sdk.trace.sampling import DEFAULT_ON, ParentBased from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator from baseplate import Baseplate from baseplate.lib import warn_deprecated -from baseplate.lib.config import Endpoint -from baseplate.lib.config import EndpointConfiguration +from baseplate.lib.config import Endpoint, 
EndpointConfiguration, Timespan, parse_config from baseplate.lib.config import Optional as OptionalConfig -from baseplate.lib.config import parse_config -from baseplate.lib.config import Timespan from baseplate.lib.log_formatter import CustomJsonFormatter from baseplate.lib.prometheus_metrics import is_metrics_enabled from baseplate.lib.propagator_redditb3_http import RedditB3HTTPFormat from baseplate.lib.propagator_redditb3_thrift import RedditB3ThriftFormat from baseplate.lib.tracing import RateLimited -from baseplate.server import einhorn -from baseplate.server import reloader +from baseplate.server import einhorn, reloader from baseplate.server.net import bind_socket - logger = logging.getLogger(__name__) @@ -147,13 +135,13 @@ def before_get( class Configuration(NamedTuple): filename: str - server: Optional[Dict[str, str]] - app: Dict[str, str] + server: dict[str, str] | None + app: dict[str, str] has_logging_options: bool - shell: Optional[Dict[str, str]] + shell: dict[str, str] | None -def read_config(config_file: TextIO, server_name: Optional[str], app_name: str) -> Configuration: +def read_config(config_file: TextIO, server_name: str | None, app_name: str) -> Configuration: # we use RawConfigParser to reduce surprise caused by interpolation and so # that config.Percent works more naturally (no escaping %). parser = configparser.RawConfigParser(interpolation=EnvironmentInterpolation()) @@ -185,7 +173,7 @@ def configure_logging(config: Configuration, debug: bool) -> None: formatter: logging.Formatter if not sys.stdin.isatty(): formatter = CustomJsonFormatter( - "%(levelname)s %(message)s %(funcName)s %(lineno)d %(module)s %(name)s %(pathname)s %(process)d %(processName)s %(thread)d %(threadName)s" + "%(levelname)s %(message)s %(funcName)s %(lineno)d %(module)s %(name)s %(pathname)s %(process)d %(processName)s %(thread)d %(threadName)s" # noqa: E501 ) else: formatter = logging.Formatter("%(levelname)-8s %(message)s") @@ -209,7 +197,7 @@ def configure_logging(config: Configuration, debug: bool) -> None: class BaseplateBatchSpanProcessor(BatchSpanProcessor): def __init__( - self, otlp_exporter: OTLPSpanExporter, attributes: Optional[Dict[str, Any]] = None + self, otlp_exporter: OTLPSpanExporter, attributes: dict[str, Any] | None = None ) -> None: logger.info( "Initializing %s with global attributes=%s.", self.__class__.__name__, attributes @@ -217,7 +205,7 @@ def __init__( super().__init__(otlp_exporter) self.baseplate_global_attributes = attributes - def on_start(self, span: Span, parent_context: Optional[Context] = None) -> None: + def on_start(self, span: Span, parent_context: Context | None = None) -> None: if self.baseplate_global_attributes: span.set_attributes(self.baseplate_global_attributes) super().on_start(span, parent_context) @@ -253,7 +241,7 @@ def make_listener(endpoint: EndpointConfiguration) -> socket.socket: return bind_socket(endpoint) -def _load_factory(url: str, default_name: Optional[str] = None) -> Callable: +def _load_factory(url: str, default_name: str | None = None) -> Callable: """Load a factory function from a config file.""" module_name, sep, func_name = url.partition(":") if not sep: @@ -266,14 +254,14 @@ def _load_factory(url: str, default_name: Optional[str] = None) -> Callable: def make_server( - server_config: Dict[str, str], listener: socket.socket, app: Callable + server_config: dict[str, str], listener: socket.socket, app: Callable ) -> StreamServer: server_url = server_config["factory"] factory = _load_factory(server_url, default_name="make_server") 
return factory(server_config, listener, app) -def make_app(app_config: Dict[str, str]) -> Callable: +def make_app(app_config: dict[str, str]) -> Callable: app_url = app_config["factory"] factory = _load_factory(app_url, default_name="make_app") return factory(app_config) @@ -392,7 +380,7 @@ def load_and_run_script() -> None: entrypoint(config.app) -def _parse_baseplate_script_args() -> Tuple[argparse.Namespace, List[str]]: +def _parse_baseplate_script_args() -> tuple[argparse.Namespace, list[str]]: parser = argparse.ArgumentParser( description="Run a function with app configuration loaded.", formatter_class=argparse.RawDescriptionHelpFormatter, @@ -468,7 +456,7 @@ def load_and_run_shell() -> None: config = read_config(args.config_file, server_name=None, app_name=args.app_name) logging.basicConfig(level=logging.INFO) - env: Dict[str, Any] = {} + env: dict[str, Any] = {} env_banner = { "app": "This project's app instance", "context": "The context for this shell instance's span", @@ -513,7 +501,8 @@ def load_and_run_shell() -> None: ipython_config = Config() ipython_config.InteractiveShellApp.exec_lines = [ - # monkeypatch IPython's log-write() to enable formatted input logging, copying original code: + # monkeypatch IPython's log-write() to enable formatted input + # logging, copying original code: # https://github.com/ipython/ipython/blob/a54bf00feb5182fa821bd5457897b3b30a313436/IPython/core/logger.py#L187-L201 f""" ip = get_ipython() @@ -534,7 +523,7 @@ def log_write(self, data, kind="input", message_id="IEXC"): ip.logger.log_write = partial(log_write, ip.logger) ip.magic('logstart {console_logpath} append') ip.logger.log_write(data="Start IPython logging\\n", message_id="ISTR") - """ + """ # noqa: E501 ] ipython_config.TerminalInteractiveShell.banner2 = banner ipython_config.LoggingMagics.quiet = True @@ -569,7 +558,8 @@ def _get_shell_log_path() -> str: def _is_containerized() -> bool: - """Determine if we're running in a container based on cgroup awareness for various container runtimes.""" + """Determine if we're running in a container based on cgroup awareness for + various container runtimes.""" if os.path.exists("/.dockerenv"): return True @@ -599,7 +589,7 @@ def _has_PID1_parent() -> bool: class LoggedInteractiveConsole(code.InteractiveConsole): - def __init__(self, _locals: Dict[str, Any], logpath: str) -> None: + def __init__(self, _locals: dict[str, Any], logpath: str) -> None: code.InteractiveConsole.__init__(self, _locals) self.output_file = logpath self.pid = os.getpid() @@ -607,17 +597,17 @@ def __init__(self, _locals: Dict[str, Any], logpath: str) -> None: self.hostname = os.uname().nodename self.log_event(message="Start InteractiveConsole logging", message_id="CSTR") - def raw_input(self, prompt: Optional[str] = "") -> str: + def raw_input(self, prompt: str | None = "") -> str: data = input(prompt) self.log_event(message=data, message_id="CEXC") return data def log_event( - self, message: str, message_id: Optional[str] = "-", structured: Optional[str] = "-" + self, message: str, message_id: str | None = "-", structured: str | None = "-" ) -> None: """Generate an RFC 5424 compliant syslog format.""" timestamp = datetime.utcnow().strftime("%Y-%m-%dT%H:%M:%S.%fZ") - prompt = f"<{self.pri}>1 {timestamp} {self.hostname} baseplate-shell {self.pid} {message_id} {structured} {message}" + prompt = f"<{self.pri}>1 {timestamp} {self.hostname} baseplate-shell {self.pid} {message_id} {structured} {message}" # noqa: E501 with open(self.output_file, "w", encoding="UTF-8") as f: 
print(prompt, file=f) f.flush() diff --git a/baseplate/server/__main__.py b/baseplate/server/__main__.py index e4bbf1920..20f813a8d 100644 --- a/baseplate/server/__main__.py +++ b/baseplate/server/__main__.py @@ -1,4 +1,3 @@ from baseplate.server import load_app_and_run_server - load_app_and_run_server() diff --git a/baseplate/server/einhorn.py b/baseplate/server/einhorn.py index 32552a03b..32d0b67a2 100644 --- a/baseplate/server/einhorn.py +++ b/baseplate/server/einhorn.py @@ -1,4 +1,5 @@ """Client library for children of Einhorn.""" + import contextlib import json import os diff --git a/baseplate/server/monkey.py b/baseplate/server/monkey.py index 50b252901..9c628f60e 100644 --- a/baseplate/server/monkey.py +++ b/baseplate/server/monkey.py @@ -8,6 +8,7 @@ def patch_stdlib_queues() -> None: https://github.com/gevent/gevent/issues/1875 """ import queue + import gevent.queue monkey.patch_module(queue, gevent.queue, items=["Queue", "LifoQueue", "PriorityQueue"]) diff --git a/baseplate/server/prometheus.py b/baseplate/server/prometheus.py index 3c55ec458..285c2bf11 100644 --- a/baseplate/server/prometheus.py +++ b/baseplate/server/prometheus.py @@ -10,31 +10,32 @@ can aggregate and serve metrics for all workers. """ + import atexit import logging import os import sys - -from typing import Iterable +from collections.abc import Iterable from typing import TYPE_CHECKING -from gevent.pywsgi import LoggingLogAdapter -from gevent.pywsgi import WSGIServer -from prometheus_client import CollectorRegistry -from prometheus_client import CONTENT_TYPE_LATEST -from prometheus_client import generate_latest -from prometheus_client import multiprocess -from prometheus_client import values +from gevent.pywsgi import LoggingLogAdapter, WSGIServer +from prometheus_client import ( + CONTENT_TYPE_LATEST, + CollectorRegistry, + generate_latest, + multiprocess, + values, +) from prometheus_client.values import MultiProcessValue -from baseplate.lib.config import Endpoint -from baseplate.lib.config import EndpointConfiguration +from baseplate.lib.config import Endpoint, EndpointConfiguration from baseplate.server.net import bind_socket - if TYPE_CHECKING: - from _typeshed.wsgi import StartResponse # pylint: disable=import-error,no-name-in-module - from _typeshed.wsgi import WSGIEnvironment # pylint: disable=import-error,no-name-in-module + from _typeshed.wsgi import ( + StartResponse, # pylint: disable=import-error,no-name-in-module + WSGIEnvironment, # pylint: disable=import-error,no-name-in-module + ) logger = logging.getLogger(__name__) @@ -65,7 +66,7 @@ def export_metrics(environ: "WSGIEnvironment", start_response: "StartResponse") def start_prometheus_exporter(address: EndpointConfiguration = PROMETHEUS_EXPORTER_ADDRESS) -> None: if "PROMETHEUS_MULTIPROC_DIR" not in os.environ: logger.error( - "prometheus-client is installed but PROMETHEUS_MULTIPROC_DIR is not set to a writeable directory." + "prometheus-client is installed but PROMETHEUS_MULTIPROC_DIR is not set to a writeable directory." 
# noqa: E501 ) sys.exit(1) diff --git a/baseplate/server/queue_consumer.py b/baseplate/server/queue_consumer.py index c11cfbf07..dedaea2dc 100644 --- a/baseplate/server/queue_consumer.py +++ b/baseplate/server/queue_consumer.py @@ -8,27 +8,18 @@ import signal import socket import uuid - +from collections.abc import Sequence from threading import Thread -from typing import Any -from typing import Callable -from typing import Dict -from typing import List -from typing import Optional -from typing import Sequence -from typing import TYPE_CHECKING - -from gevent.pywsgi import LoggingLogAdapter -from gevent.pywsgi import WSGIServer +from typing import TYPE_CHECKING, Any, Callable + +from gevent.pywsgi import LoggingLogAdapter, WSGIServer from gevent.server import StreamServer import baseplate.lib.config - from baseplate.lib.retry import RetryPolicy from baseplate.observers.timeout import ServerTimeout from baseplate.server import runtime_monitor - logger = logging.getLogger(__name__) @@ -36,15 +27,15 @@ # TODO: Replace with wsgiref.types once on 3.11+ from _typeshed.wsgi import StartResponse -WSGIEnvironment = Dict[str, Any] +WSGIEnvironment = dict[str, Any] HealthcheckCallback = Callable[[WSGIEnvironment], bool] class HealthcheckApp: - def __init__(self, callback: Optional[HealthcheckCallback] = None) -> None: + def __init__(self, callback: HealthcheckCallback | None = None) -> None: self.callback = callback - def __call__(self, environ: WSGIEnvironment, start_response: StartResponse) -> List[bytes]: + def __call__(self, environ: WSGIEnvironment, start_response: StartResponse) -> list[bytes]: ok = True if self.callback: ok = self.callback(environ) @@ -57,7 +48,7 @@ def __call__(self, environ: WSGIEnvironment, start_response: StartResponse) -> L def make_simple_healthchecker( - listener: socket.socket, callback: Optional[HealthcheckCallback] = None + listener: socket.socket, callback: HealthcheckCallback | None = None ) -> WSGIServer: return WSGIServer( listener=listener, @@ -67,7 +58,8 @@ def make_simple_healthchecker( class PumpWorker(abc.ABC): - """Reads messages off of a message queue and puts them into a queue.Queue for handling by a MessageHandler. + """Reads messages off of a message queue and puts them into a queue.Queue + for handling by a MessageHandler. The QueueConsumerServer will run a single PumpWorker in its own thread. """ @@ -83,7 +75,8 @@ def run(self) -> None: @abc.abstractmethod def stop(self) -> None: - """Signal the PumpWorker that it should stop receiving new messages from its message queue.""" + """Signal the PumpWorker that it should stop receiving new messages + from its message queue.""" class MessageHandler(abc.ABC): @@ -140,7 +133,8 @@ def build_health_checker(self, listener: socket.socket) -> StreamServer: class QueueConsumer: - """Wrapper around a MessageHandler object that interfaces with the work_queue and starts/stops the handle loop. + """Wrapper around a MessageHandler object that interfaces with the + work_queue and starts/stops the handle loop. This object is used by the QueueConsumerServer to wrap a MessageHandler object before creating a worker Thread. 
This allows the MessageHandler to focus soley @@ -232,7 +226,7 @@ def new( consumer_factory: QueueConsumerFactory, listener: socket.socket, stop_timeout: datetime.timedelta, - ) -> "QueueConsumerServer": + ) -> QueueConsumerServer: """Build a new QueueConsumerServer.""" # We want to give some headroom on the queue so our handlers can grab # a new message right after they finish so we keep an extra @@ -314,7 +308,7 @@ def stop(self) -> None: def make_server( - server_config: Dict[str, str], listener: socket.socket, app: QueueConsumerFactory + server_config: dict[str, str], listener: socket.socket, app: QueueConsumerFactory ) -> QueueConsumerServer: """Make a queue consumer server for long running queue consumer apps. diff --git a/baseplate/server/reloader.py b/baseplate/server/reloader.py index c9ba1f3e6..8abd9e710 100644 --- a/baseplate/server/reloader.py +++ b/baseplate/server/reloader.py @@ -5,18 +5,15 @@ settings. """ + import logging import os import re import sys import threading import time - -from typing import Dict -from typing import Iterator +from collections.abc import Iterator, Sequence from typing import NoReturn -from typing import Sequence - logger = logging.getLogger(__name__) @@ -38,12 +35,12 @@ def _get_watched_files(extra_files: Sequence[str]) -> Iterator[str]: def _reload_when_files_change(extra_files: Sequence[str]) -> NoReturn: """Scan all watched files periodically and re-exec if anything changed.""" - initial_mtimes: Dict[str, float] = {} + initial_mtimes: dict[str, float] = {} while True: for filename in _get_watched_files(extra_files): try: current_mtime = os.path.getmtime(filename) - except os.error: + except OSError: continue initial_mtimes.setdefault(filename, current_mtime) diff --git a/baseplate/server/runtime_monitor.py b/baseplate/server/runtime_monitor.py index 76435af20..687c0e78a 100644 --- a/baseplate/server/runtime_monitor.py +++ b/baseplate/server/runtime_monitor.py @@ -6,28 +6,20 @@ import socket import threading import time - -from typing import Any -from typing import Callable -from typing import Dict -from typing import List -from typing import NoReturn -from typing import Optional +from typing import Any, Callable, NoReturn import gevent.events - from gevent.pool import Pool -from baseplate import _ExcInfo -from baseplate import Baseplate -from baseplate import BaseplateObserver -from baseplate import RequestContext -from baseplate import ServerSpan -from baseplate import ServerSpanObserver -from baseplate.lib import config -from baseplate.lib import metrics -from baseplate.lib import prometheus_metrics - +from baseplate import ( + Baseplate, + BaseplateObserver, + RequestContext, + ServerSpan, + ServerSpanObserver, + _ExcInfo, +) +from baseplate.lib import config, metrics, prometheus_metrics REPORT_INTERVAL_SECONDS = 10 MAX_REQUEST_AGE = 60 @@ -51,7 +43,7 @@ def report(self, batch: metrics.Batch) -> None: class _ActiveRequestsObserver(BaseplateObserver, _Reporter): def __init__(self) -> None: - self.live_requests: Dict[str, float] = {} + self.live_requests: dict[str, float] = {} def on_server_span_created(self, context: RequestContext, server_span: ServerSpan) -> None: observer = _ActiveRequestsServerSpanObserver(self, server_span.trace_id) @@ -78,7 +70,7 @@ def __init__(self, reporter: _ActiveRequestsObserver, trace_id: str): def on_start(self) -> None: self.reporter.live_requests[self.trace_id] = time.time() - def on_finish(self, exc_info: Optional[_ExcInfo]) -> None: + def on_finish(self, exc_info: _ExcInfo | None) -> None: 
self.reporter.live_requests.pop(self.trace_id, None) @@ -89,7 +81,7 @@ def __init__(self, max_blocking_time: int): gevent.config.max_blocking_time = max_blocking_time gevent.get_hub().start_periodic_monitoring_thread() - self.times_blocked: List[int] = [] + self.times_blocked: list[int] = [] def _on_gevent_event(self, event: Any) -> None: if isinstance(event, gevent.events.EventLoopBlocked): @@ -118,10 +110,10 @@ class _GCTimingReporter(_Reporter): def __init__(self) -> None: gc.callbacks.append(self._on_gc_event) - self.gc_durations: List[float] = [] - self.current_gc_start: Optional[float] = None + self.gc_durations: list[float] = [] + self.current_gc_start: float | None = None - def _on_gc_event(self, phase: str, _info: Dict[str, Any]) -> None: + def _on_gc_event(self, phase: str, _info: dict[str, Any]) -> None: if phase == "start": self.current_gc_start = time.time() elif phase == "stop": @@ -139,7 +131,7 @@ def report(self, batch: metrics.Batch) -> None: class _BaseplateReporter(_Reporter): - def __init__(self, reporters: Dict[str, Callable[[Any], None]]): + def __init__(self, reporters: dict[str, Callable[[Any], None]]): self.reporters = reporters def report(self, batch: metrics.Batch) -> None: @@ -188,7 +180,7 @@ def report(self, batch: metrics.Batch) -> None: # pylint: disable=unused-argume def _report_runtime_metrics_periodically( - metrics_client: metrics.Client, reporters: List[_Reporter] + metrics_client: metrics.Client, reporters: list[_Reporter] ) -> NoReturn: hostname = socket.gethostname() pid = str(os.getpid()) @@ -218,7 +210,7 @@ def _report_runtime_metrics_periodically( logger.debug("Error while sending server metrics: %s", exc) -def start(server_config: Dict[str, str], application: Any, pool: Pool) -> None: +def start(server_config: dict[str, str], application: Any, pool: Pool) -> None: baseplate: Baseplate | None = getattr(application, "baseplate", None) # As of October 1, 2022 Reddit uses Prometheus to track metrics, not Statsd # this checks to see if Prometheus metrics are enabled and uses this to determine @@ -248,7 +240,7 @@ def start(server_config: Dict[str, str], application: Any, pool: Pool) -> None: }, ) - reporters: List[_Reporter] = [] + reporters: list[_Reporter] = [] if cfg.monitoring.concurrency: reporters.append(_OpenConnectionsReporter(pool)) diff --git a/baseplate/server/thrift.py b/baseplate/server/thrift.py index bced92f8e..83c192d48 100644 --- a/baseplate/server/thrift.py +++ b/baseplate/server/thrift.py @@ -1,11 +1,7 @@ import datetime import logging import socket - -from typing import Any -from typing import Dict -from typing import Tuple -from typing import Union +from typing import Any, Union from form_observability import ctx from gevent.pool import Pool @@ -16,18 +12,16 @@ from thrift.Thrift import TProcessor from thrift.transport.THeaderTransport import THeaderClientType from thrift.transport.TSocket import TSocket -from thrift.transport.TTransport import TBufferedTransportFactory -from thrift.transport.TTransport import TTransportException +from thrift.transport.TTransport import TBufferedTransportFactory, TTransportException from baseplate.lib import config from baseplate.server import runtime_monitor - logger = logging.getLogger(__name__) tracer = trace.get_tracer(__name__) -Address = Union[Tuple[str, int], str] +Address = Union[tuple[str, int], str] # pylint: disable=too-many-public-methods @@ -89,7 +83,7 @@ def handle(self, client_socket: socket.socket, address: Address) -> None: trans.close() -def make_server(server_config: Dict[str, 
str], listener: socket.socket, app: Any) -> StreamServer: +def make_server(server_config: dict[str, str], listener: socket.socket, app: Any) -> StreamServer: # pylint: disable=maybe-no-member cfg = config.parse_config( server_config, diff --git a/baseplate/server/wsgi.py b/baseplate/server/wsgi.py index 902a0032a..448c41ad9 100644 --- a/baseplate/server/wsgi.py +++ b/baseplate/server/wsgi.py @@ -1,24 +1,19 @@ import datetime import logging import socket - from typing import Any -from typing import Dict from gevent.pool import Pool -from gevent.pywsgi import LoggingLogAdapter -from gevent.pywsgi import WSGIServer +from gevent.pywsgi import LoggingLogAdapter, WSGIServer from gevent.server import StreamServer from baseplate.lib import config -from baseplate.server import _load_factory -from baseplate.server import runtime_monitor - +from baseplate.server import _load_factory, runtime_monitor logger = logging.getLogger(__name__) -def make_server(server_config: Dict[str, str], listener: socket.socket, app: Any) -> StreamServer: +def make_server(server_config: dict[str, str], listener: socket.socket, app: Any) -> StreamServer: """Make a gevent server for WSGI apps.""" # pylint: disable=maybe-no-member cfg = config.parse_config( @@ -40,7 +35,7 @@ def make_server(server_config: Dict[str, str], listener: socket.socket, app: Any pool = Pool() log = LoggingLogAdapter(logger, level=logging.DEBUG) - kwargs: Dict[str, Any] = {} + kwargs: dict[str, Any] = {} if cfg.handler: kwargs["handler_class"] = _load_factory(cfg.handler, default_name=None) diff --git a/baseplate/sidecars/__init__.py b/baseplate/sidecars/__init__.py index 0ff05fc4f..a35700c36 100644 --- a/baseplate/sidecars/__init__.py +++ b/baseplate/sidecars/__init__.py @@ -1,8 +1,5 @@ import time - -from typing import List -from typing import NamedTuple -from typing import Optional +from typing import NamedTuple, Optional class SerializedBatch(NamedTuple): @@ -48,7 +45,7 @@ def serialize(self) -> SerializedBatch: ) def reset(self) -> None: - self._items: List[bytes] = [] + self._items: list[bytes] = [] self._size = 2 # the [] that wrap the json list diff --git a/baseplate/sidecars/event_publisher.py b/baseplate/sidecars/event_publisher.py index 874990886..f8f6ce806 100644 --- a/baseplate/sidecars/event_publisher.py +++ b/baseplate/sidecars/event_publisher.py @@ -5,28 +5,18 @@ import hashlib import hmac import logging - -from typing import Any -from typing import List -from typing import Optional +from typing import Any, Optional import requests from baseplate import __version__ as baseplate_version -from baseplate.lib import config -from baseplate.lib import metrics -from baseplate.lib.events import MAX_EVENT_SIZE -from baseplate.lib.events import MAX_QUEUE_SIZE -from baseplate.lib.message_queue import MessageQueue -from baseplate.lib.message_queue import TimedOutError +from baseplate.lib import config, metrics +from baseplate.lib.events import MAX_EVENT_SIZE, MAX_QUEUE_SIZE +from baseplate.lib.message_queue import MessageQueue, TimedOutError from baseplate.lib.metrics import metrics_client_from_config from baseplate.lib.retry import RetryPolicy from baseplate.server import EnvironmentInterpolation -from baseplate.sidecars import Batch -from baseplate.sidecars import BatchFull -from baseplate.sidecars import SerializedBatch -from baseplate.sidecars import TimeLimitedBatch - +from baseplate.sidecars import Batch, BatchFull, SerializedBatch, TimeLimitedBatch logger = logging.getLogger(__name__) @@ -83,7 +73,7 @@ def serialize(self) -> 
SerializedBatch: ) def reset(self) -> None: - self._items: List[bytes] = [] + self._items: list[bytes] = [] self._size = len(self._header) + len(self._end) @@ -105,9 +95,9 @@ def __init__(self, metrics_client: metrics.Client, cfg: Any): self.key_name = cfg.key.name self.key_secret = cfg.key.secret self.session = requests.Session() - self.session.headers[ - "User-Agent" - ] = f"baseplate.py-{self.__class__.__name__}/{baseplate_version}" + self.session.headers["User-Agent"] = ( + f"baseplate.py-{self.__class__.__name__}/{baseplate_version}" + ) def _sign_payload(self, payload: bytes) -> str: digest = hmac.new(self.key_secret, payload, hashlib.sha256).hexdigest() diff --git a/baseplate/sidecars/live_data_watcher.py b/baseplate/sidecars/live_data_watcher.py index acbefb444..234ea0469 100644 --- a/baseplate/sidecars/live_data_watcher.py +++ b/baseplate/sidecars/live_data_watcher.py @@ -1,4 +1,5 @@ """Watch nodes in ZooKeeper and sync their contents to disk on change.""" + import argparse import configparser import json @@ -7,18 +8,16 @@ import random import sys import time - from enum import Enum from pathlib import Path -from typing import Any -from typing import NoReturn -from typing import Optional +from typing import Any, NoReturn, Optional import boto3 # type: ignore - from botocore import UNSIGNED # type: ignore -from botocore.client import ClientError # type: ignore -from botocore.client import Config +from botocore.client import ( # type: ignore + ClientError, + Config, +) from botocore.exceptions import EndpointConnectionError # type: ignore from kazoo.client import KazooClient from kazoo.protocol.states import ZnodeStat @@ -28,7 +27,6 @@ from baseplate.lib.secrets import secrets_store_from_config from baseplate.server import EnvironmentInterpolation - logger = logging.getLogger(__name__) @@ -154,7 +152,7 @@ def _load_from_s3(data: bytes) -> bytes: except KeyError as e: # We require all of these keys to properly read from S3. logger.exception( - "Failed to update live config: unable to fetch content from s3: source config has invalid or missing keys: %s.", + "Failed to update live config: unable to fetch content from s3: source config has invalid or missing keys: %s.", # noqa: E501 e.args[0], ) raise LoaderException from e @@ -196,7 +194,7 @@ def _load_from_s3(data: bytes) -> bytes: raise LoaderException from error except ValueError as error: logger.exception( - "Failed to update live config: params for loading from S3 are incorrect. Received error: %s", + "Failed to update live config: params for loading from S3 are incorrect. Received error: %s", # noqa: E501 error, ) diff --git a/baseplate/sidecars/secrets_fetcher.py b/baseplate/sidecars/secrets_fetcher.py index 2eae5d546..5333c22ab 100644 --- a/baseplate/sidecars/secrets_fetcher.py +++ b/baseplate/sidecars/secrets_fetcher.py @@ -60,6 +60,7 @@ write to a new file in whatever format needed, and restart other services if necessary. """ + import argparse import configparser import datetime @@ -71,12 +72,7 @@ import time import urllib.parse import uuid - -from typing import Any -from typing import Callable -from typing import Dict -from typing import Optional -from typing import Tuple +from typing import Any, Callable, Optional import requests @@ -84,7 +80,6 @@ from baseplate.lib import config from baseplate.server import EnvironmentInterpolation - logger = logging.getLogger(__name__) @@ -96,9 +91,7 @@ different nonce, a vault operator may need to remove the instance ID from the identity whitelist. 
See https://www.vaultproject.io/docs/auth/aws.html#client-nonce -""".replace( - "\n", " " -) +""".replace("\n", " ") def fetch_instance_identity() -> str: @@ -143,7 +136,7 @@ def ttl_to_time(ttl: int) -> datetime.datetime: return datetime.datetime.utcnow() + datetime.timedelta(seconds=ttl) -Authenticator = Callable[["VaultClientFactory"], Tuple[str, datetime.datetime]] +Authenticator = Callable[["VaultClientFactory"], tuple[str, datetime.datetime]] class VaultClientFactory: @@ -155,10 +148,10 @@ def __init__(self, base_url: str, role: str, auth_type: Authenticator, mount_poi self.auth_type = auth_type self.mount_point = mount_point self.session = requests.Session() - self.session.headers[ - "User-Agent" - ] = f"baseplate.py-{self.__class__.__name__}/{baseplate_version}" - self.client: Optional["VaultClient"] = None + self.session.headers["User-Agent"] = ( + f"baseplate.py-{self.__class__.__name__}/{baseplate_version}" + ) + self.client: Optional[VaultClient] = None def _make_client(self) -> "VaultClient": """Obtain a client token from an auth backend and return a Vault client with it.""" @@ -166,7 +159,7 @@ def _make_client(self) -> "VaultClient": return VaultClient(self.session, self.base_url, client_token, lease_duration) - def _vault_kubernetes_auth(self) -> Tuple[str, datetime.datetime]: + def _vault_kubernetes_auth(self) -> tuple[str, datetime.datetime]: r"""Get a client token from Vault through the Kubernetes auth backend. This authenticates with Vault as a specified role using its @@ -208,7 +201,7 @@ def _vault_kubernetes_auth(self) -> Tuple[str, datetime.datetime]: auth = response.json()["auth"] return auth["client_token"], ttl_to_time(auth["lease_duration"]) - def _vault_aws_auth(self) -> Tuple[str, datetime.datetime]: + def _vault_aws_auth(self) -> tuple[str, datetime.datetime]: r"""Get a client token from Vault through the AWS auth backend. 
This authenticates with Vault as a specified role using its AWS @@ -256,7 +249,7 @@ def _vault_aws_auth(self) -> Tuple[str, datetime.datetime]: return auth["client_token"], ttl_to_time(auth["lease_duration"]) @staticmethod - def auth_types() -> Dict[str, Authenticator]: + def auth_types() -> dict[str, Authenticator]: """Return a dict of the supported auth types and respective methods.""" return { "aws": VaultClientFactory._vault_aws_auth, @@ -296,7 +289,7 @@ def is_about_to_expire(self) -> bool: expiration = self.token_expiration - VAULT_TOKEN_PREFETCH_TIME return expiration < datetime.datetime.utcnow() - def get_secret(self, secret_name: str) -> Tuple[Any, datetime.datetime]: + def get_secret(self, secret_name: str) -> tuple[Any, datetime.datetime]: """Get the value and expiration time of a named secret.""" logger.debug("Fetching secret %r.", secret_name) try: diff --git a/baseplate/sidecars/trace_publisher.py b/baseplate/sidecars/trace_publisher.py index 747373b33..51af7ead5 100644 --- a/baseplate/sidecars/trace_publisher.py +++ b/baseplate/sidecars/trace_publisher.py @@ -2,26 +2,18 @@ import configparser import logging import urllib.parse - from typing import Optional import requests from baseplate import __version__ as baseplate_version -from baseplate.lib import config -from baseplate.lib import metrics -from baseplate.lib.message_queue import MessageQueue -from baseplate.lib.message_queue import TimedOutError +from baseplate.lib import config, metrics +from baseplate.lib.message_queue import MessageQueue, TimedOutError from baseplate.lib.metrics import metrics_client_from_config from baseplate.lib.retry import RetryPolicy -from baseplate.observers.tracing import MAX_QUEUE_SIZE -from baseplate.observers.tracing import MAX_SPAN_SIZE +from baseplate.observers.tracing import MAX_QUEUE_SIZE, MAX_SPAN_SIZE from baseplate.server import EnvironmentInterpolation -from baseplate.sidecars import BatchFull -from baseplate.sidecars import RawJSONBatch -from baseplate.sidecars import SerializedBatch -from baseplate.sidecars import TimeLimitedBatch - +from baseplate.sidecars import BatchFull, RawJSONBatch, SerializedBatch, TimeLimitedBatch logger = logging.getLogger(__name__) @@ -58,13 +50,12 @@ def __init__( retry_limit: int = RETRY_LIMIT_DEFAULT, num_conns: int = 5, ): - adapter = requests.adapters.HTTPAdapter(pool_connections=num_conns, pool_maxsize=num_conns) parsed_url = urllib.parse.urlparse(zipkin_api_url) self.session = requests.Session() - self.session.headers[ - "User-Agent" - ] = f"baseplate.py-{self.__class__.__name__}/{baseplate_version}" + self.session.headers["User-Agent"] = ( + f"baseplate.py-{self.__class__.__name__}/{baseplate_version}" + ) self.session.mount(f"{parsed_url.scheme}://", adapter) self.endpoint = f"{zipkin_api_url}/spans" self.metrics = metrics_client diff --git a/baseplate/testing/lib/file_watcher.py b/baseplate/testing/lib/file_watcher.py index 83d10f301..2bf9b8172 100644 --- a/baseplate/testing/lib/file_watcher.py +++ b/baseplate/testing/lib/file_watcher.py @@ -1,13 +1,7 @@ import typing - -from typing import Tuple -from typing import Type from typing import Union -from baseplate.lib.file_watcher import _NOT_LOADED -from baseplate.lib.file_watcher import FileWatcher -from baseplate.lib.file_watcher import T -from baseplate.lib.file_watcher import WatchedFileNotAvailableError +from baseplate.lib.file_watcher import _NOT_LOADED, FileWatcher, T, WatchedFileNotAvailableError class FakeFileWatcher(FileWatcher): @@ -35,11 +29,11 @@ class FakeFileWatcher(FileWatcher): 
""" # pylint: disable=super-init-not-called - def __init__(self, data: Union[T, Type[_NOT_LOADED]] = _NOT_LOADED, mtime: float = 1234): + def __init__(self, data: Union[T, type[_NOT_LOADED]] = _NOT_LOADED, mtime: float = 1234): self.data = data self.mtime = mtime - def get_data_and_mtime(self) -> Tuple[T, float]: + def get_data_and_mtime(self) -> tuple[T, float]: if self.data is _NOT_LOADED: raise WatchedFileNotAvailableError("/fake-file-watcher", Exception("no value set")) return typing.cast(T, self.data), self.mtime diff --git a/baseplate/testing/lib/secrets.py b/baseplate/testing/lib/secrets.py index cc3d9cece..e393da84a 100644 --- a/baseplate/testing/lib/secrets.py +++ b/baseplate/testing/lib/secrets.py @@ -1,8 +1,5 @@ -from typing import Dict - from baseplate import Span -from baseplate.lib.secrets import parse_secrets_fetcher -from baseplate.lib.secrets import SecretsStore +from baseplate.lib.secrets import SecretsStore, parse_secrets_fetcher from baseplate.testing.lib.file_watcher import FakeFileWatcher @@ -34,7 +31,7 @@ class FakeSecretsStore(SecretsStore): """ # pylint: disable=super-init-not-called - def __init__(self, fake_secrets: Dict) -> None: + def __init__(self, fake_secrets: dict) -> None: self._filewatcher = FakeFileWatcher(fake_secrets) self.parser = parse_secrets_fetcher diff --git a/docs/conf.py b/docs/conf.py index fe8d9f141..cd7f61ca9 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -77,9 +77,7 @@ # which templates to put in the sidebar. we're just removing the relations # section from the defaults here, that's "next article" and "previous article" -html_sidebars = { - "**": ["about.html", "searchbox.html", "navigation.html"] -} +html_sidebars = {"**": ["about.html", "searchbox.html", "navigation.html"]} html_theme_options = { "description": "Reddit's Python Service Framework", diff --git a/docs/pyproject.toml b/docs/pyproject.toml deleted file mode 100644 index 8b89f3d71..000000000 --- a/docs/pyproject.toml +++ /dev/null @@ -1,8 +0,0 @@ -[tool.black] -# code blocks look funky with scroll bars and they're quite narrow in the -# alabaster theme. so we limit it to keep things sane. -line-length = 74 - -# Always use our latest supported version here since we want code snippets in -# docs to use the most up-to-date syntax. 
-target-version = ['py39'] diff --git a/docs/tutorial/chapter3/helloworld.py b/docs/tutorial/chapter3/helloworld.py index 2152754a6..f62237e05 100644 --- a/docs/tutorial/chapter3/helloworld.py +++ b/docs/tutorial/chapter3/helloworld.py @@ -1,9 +1,9 @@ -from baseplate import Baseplate -from baseplate.frameworks.pyramid import BaseplateConfigurator - from pyramid.config import Configurator from pyramid.view import view_config +from baseplate import Baseplate +from baseplate.frameworks.pyramid import BaseplateConfigurator + @view_config(route_name="hello_world", renderer="json") def hello_world(request): diff --git a/docs/tutorial/chapter4/helloworld.py b/docs/tutorial/chapter4/helloworld.py index fb59380e8..4ca3455de 100644 --- a/docs/tutorial/chapter4/helloworld.py +++ b/docs/tutorial/chapter4/helloworld.py @@ -1,10 +1,10 @@ +from pyramid.config import Configurator +from pyramid.view import view_config + from baseplate import Baseplate from baseplate.clients.sqlalchemy import SQLAlchemySession from baseplate.frameworks.pyramid import BaseplateConfigurator -from pyramid.config import Configurator -from pyramid.view import view_config - @view_config(route_name="hello_world", renderer="json") def hello_world(request): diff --git a/poetry.lock b/poetry.lock index f4ef3bb8e..4d97e03a8 100644 --- a/poetry.lock +++ b/poetry.lock @@ -92,17 +92,6 @@ typing-extensions = {version = ">=4", markers = "python_version < \"3.11\""} [package.extras] tests = ["mypy (>=0.800)", "pytest", "pytest-asyncio"] -[[package]] -name = "aspy-refactor-imports" -version = "3.0.2" -description = "Utilities for refactoring imports in python-like syntax." -optional = false -python-versions = ">=3.7" -files = [ - {file = "aspy.refactor_imports-3.0.2-py2.py3-none-any.whl", hash = "sha256:f306037682479945df61b2e6d01bf97256d68f3e704742768deef549e0d61fbb"}, - {file = "aspy.refactor_imports-3.0.2.tar.gz", hash = "sha256:3c7329cdb2613c46fcd757c8e45120efbc3d4b9db805092911eb605c19c5795c"}, -] - [[package]] name = "astroid" version = "3.3.5" @@ -152,36 +141,6 @@ charset-normalizer = ["charset-normalizer"] html5lib = ["html5lib"] lxml = ["lxml"] -[[package]] -name = "black" -version = "21.10b0" -description = "The uncompromising code formatter." -optional = false -python-versions = ">=3.6.2" -files = [ - {file = "black-21.10b0-py3-none-any.whl", hash = "sha256:6eb7448da9143ee65b856a5f3676b7dda98ad9abe0f87fce8c59291f15e82a5b"}, - {file = "black-21.10b0.tar.gz", hash = "sha256:a9952229092e325fe5f3dae56d81f639b23f7131eb840781947e4b2886030f33"}, -] - -[package.dependencies] -click = ">=7.1.2" -mypy-extensions = ">=0.4.3" -pathspec = ">=0.9.0,<1" -platformdirs = ">=2" -regex = ">=2020.1.8" -tomli = ">=0.2.6,<2.0.0" -typing-extensions = [ - {version = ">=3.10.0.0", markers = "python_version < \"3.10\""}, - {version = ">=3.10.0.0,<3.10.0.1 || >3.10.0.1", markers = "python_version >= \"3.10\""}, -] - -[package.extras] -colorama = ["colorama (>=0.4.3)"] -d = ["aiohttp (>=3.7.4)"] -jupyter = ["ipython (>=7.8.0)", "tokenize-rt (>=3.2.0)"] -python2 = ["typed-ast (>=1.4.3)"] -uvloop = ["uvloop (>=0.15.2)"] - [[package]] name = "boto3" version = "1.35.53" @@ -909,22 +868,6 @@ docs = ["furo (>=2024.8.6)", "sphinx (>=8.0.2)", "sphinx-autodoc-typehints (>=2. 
testing = ["covdefaults (>=2.3)", "coverage (>=7.6.1)", "diff-cover (>=9.2)", "pytest (>=8.3.3)", "pytest-asyncio (>=0.24)", "pytest-cov (>=5)", "pytest-mock (>=3.14)", "pytest-timeout (>=2.3.1)", "virtualenv (>=20.26.4)"] typing = ["typing-extensions (>=4.12.2)"] -[[package]] -name = "flake8" -version = "7.1.1" -description = "the modular source code checker: pep8 pyflakes and co" -optional = false -python-versions = ">=3.8.1" -files = [ - {file = "flake8-7.1.1-py2.py3-none-any.whl", hash = "sha256:597477df7860daa5aa0fdd84bf5208a043ab96b8e96ab708770ae0364dd03213"}, - {file = "flake8-7.1.1.tar.gz", hash = "sha256:049d058491e228e03e67b390f311bbf88fce2dbaa8fa673e7aea87b7198b8d38"}, -] - -[package.dependencies] -mccabe = ">=0.7.0,<0.8.0" -pycodestyle = ">=2.12.0,<2.13.0" -pyflakes = ">=3.2.0,<3.3.0" - [[package]] name = "formenergy-observability" version = "0.3.2" @@ -2249,17 +2192,6 @@ docs = ["Sphinx (>=1.7.5)", "pylons-sphinx-themes"] paste = ["Paste"] testing = ["Paste", "pytest", "pytest-cov"] -[[package]] -name = "pathspec" -version = "0.12.1" -description = "Utility library for gitignore style pattern matching of file paths." -optional = false -python-versions = ">=3.8" -files = [ - {file = "pathspec-0.12.1-py3-none-any.whl", hash = "sha256:a0d503e138a4c123b27490a4f7beda6a01c6f288df0e4a8b79c7eb0dc7b4cc08"}, - {file = "pathspec-0.12.1.tar.gz", hash = "sha256:a482d51503a1ab33b1c67a6c3813a26953dbdc71c31dacaef9a838c4e29f5712"}, -] - [[package]] name = "pendulum" version = "3.0.0" @@ -2559,17 +2491,6 @@ files = [ {file = "pyasn1-0.6.1.tar.gz", hash = "sha256:6f580d2bdd84365380830acf45550f2511469f673cb4a5ae3857a3170128b034"}, ] -[[package]] -name = "pycodestyle" -version = "2.12.1" -description = "Python style guide checker" -optional = false -python-versions = ">=3.8" -files = [ - {file = "pycodestyle-2.12.1-py2.py3-none-any.whl", hash = "sha256:46f0fb92069a7c28ab7bb558f05bfc0110dac69a0cd23c61ea0040283a9d78b3"}, - {file = "pycodestyle-2.12.1.tar.gz", hash = "sha256:6838eae08bbce4f6accd5d5572075c63626a15ee3e6f842df996bf62f6d73521"}, -] - [[package]] name = "pycparser" version = "2.22" @@ -2722,17 +2643,6 @@ snowballstemmer = ">=2.2.0" [package.extras] toml = ["tomli (>=1.2.3)"] -[[package]] -name = "pyflakes" -version = "3.2.0" -description = "passive checker of Python programs" -optional = false -python-versions = ">=3.8" -files = [ - {file = "pyflakes-3.2.0-py2.py3-none-any.whl", hash = "sha256:84b5be138a2dfbb40689ca07e2152deb896a65c3a3e24c251c5c62489568074a"}, - {file = "pyflakes-3.2.0.tar.gz", hash = "sha256:1c61603ff154621fb2a9172037d84dca3500def8c8b630657d1701f026f8af3f"}, -] - [[package]] name = "pygments" version = "2.18.0" @@ -3103,123 +3013,6 @@ redis = ">=3.0.0,<4.0.0" [package.extras] hiredis = ["hiredis (>=0.1.3)"] -[[package]] -name = "regex" -version = "2024.9.11" -description = "Alternative regular expression module, to replace re." 
-optional = false -python-versions = ">=3.8" -files = [ - {file = "regex-2024.9.11-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:1494fa8725c285a81d01dc8c06b55287a1ee5e0e382d8413adc0a9197aac6408"}, - {file = "regex-2024.9.11-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:0e12c481ad92d129c78f13a2a3662317e46ee7ef96c94fd332e1c29131875b7d"}, - {file = "regex-2024.9.11-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:16e13a7929791ac1216afde26f712802e3df7bf0360b32e4914dca3ab8baeea5"}, - {file = "regex-2024.9.11-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:46989629904bad940bbec2106528140a218b4a36bb3042d8406980be1941429c"}, - {file = "regex-2024.9.11-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a906ed5e47a0ce5f04b2c981af1c9acf9e8696066900bf03b9d7879a6f679fc8"}, - {file = "regex-2024.9.11-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e9a091b0550b3b0207784a7d6d0f1a00d1d1c8a11699c1a4d93db3fbefc3ad35"}, - {file = "regex-2024.9.11-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5ddcd9a179c0a6fa8add279a4444015acddcd7f232a49071ae57fa6e278f1f71"}, - {file = "regex-2024.9.11-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6b41e1adc61fa347662b09398e31ad446afadff932a24807d3ceb955ed865cc8"}, - {file = "regex-2024.9.11-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:ced479f601cd2f8ca1fd7b23925a7e0ad512a56d6e9476f79b8f381d9d37090a"}, - {file = "regex-2024.9.11-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:635a1d96665f84b292e401c3d62775851aedc31d4f8784117b3c68c4fcd4118d"}, - {file = "regex-2024.9.11-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:c0256beda696edcf7d97ef16b2a33a8e5a875affd6fa6567b54f7c577b30a137"}, - {file = "regex-2024.9.11-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:3ce4f1185db3fbde8ed8aa223fc9620f276c58de8b0d4f8cc86fd1360829edb6"}, - {file = "regex-2024.9.11-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:09d77559e80dcc9d24570da3745ab859a9cf91953062e4ab126ba9d5993688ca"}, - {file = "regex-2024.9.11-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:7a22ccefd4db3f12b526eccb129390942fe874a3a9fdbdd24cf55773a1faab1a"}, - {file = "regex-2024.9.11-cp310-cp310-win32.whl", hash = "sha256:f745ec09bc1b0bd15cfc73df6fa4f726dcc26bb16c23a03f9e3367d357eeedd0"}, - {file = "regex-2024.9.11-cp310-cp310-win_amd64.whl", hash = "sha256:01c2acb51f8a7d6494c8c5eafe3d8e06d76563d8a8a4643b37e9b2dd8a2ff623"}, - {file = "regex-2024.9.11-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:2cce2449e5927a0bf084d346da6cd5eb016b2beca10d0013ab50e3c226ffc0df"}, - {file = "regex-2024.9.11-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:3b37fa423beefa44919e009745ccbf353d8c981516e807995b2bd11c2c77d268"}, - {file = "regex-2024.9.11-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:64ce2799bd75039b480cc0360907c4fb2f50022f030bf9e7a8705b636e408fad"}, - {file = "regex-2024.9.11-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a4cc92bb6db56ab0c1cbd17294e14f5e9224f0cc6521167ef388332604e92679"}, - {file = "regex-2024.9.11-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d05ac6fa06959c4172eccd99a222e1fbf17b5670c4d596cb1e5cde99600674c4"}, - {file = "regex-2024.9.11-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = 
"sha256:040562757795eeea356394a7fb13076ad4f99d3c62ab0f8bdfb21f99a1f85664"}, - {file = "regex-2024.9.11-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6113c008a7780792efc80f9dfe10ba0cd043cbf8dc9a76ef757850f51b4edc50"}, - {file = "regex-2024.9.11-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8e5fb5f77c8745a60105403a774fe2c1759b71d3e7b4ca237a5e67ad066c7199"}, - {file = "regex-2024.9.11-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:54d9ff35d4515debf14bc27f1e3b38bfc453eff3220f5bce159642fa762fe5d4"}, - {file = "regex-2024.9.11-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:df5cbb1fbc74a8305b6065d4ade43b993be03dbe0f8b30032cced0d7740994bd"}, - {file = "regex-2024.9.11-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:7fb89ee5d106e4a7a51bce305ac4efb981536301895f7bdcf93ec92ae0d91c7f"}, - {file = "regex-2024.9.11-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:a738b937d512b30bf75995c0159c0ddf9eec0775c9d72ac0202076c72f24aa96"}, - {file = "regex-2024.9.11-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:e28f9faeb14b6f23ac55bfbbfd3643f5c7c18ede093977f1df249f73fd22c7b1"}, - {file = "regex-2024.9.11-cp311-cp311-win32.whl", hash = "sha256:18e707ce6c92d7282dfce370cd205098384b8ee21544e7cb29b8aab955b66fa9"}, - {file = "regex-2024.9.11-cp311-cp311-win_amd64.whl", hash = "sha256:313ea15e5ff2a8cbbad96ccef6be638393041b0a7863183c2d31e0c6116688cf"}, - {file = "regex-2024.9.11-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:b0d0a6c64fcc4ef9c69bd5b3b3626cc3776520a1637d8abaa62b9edc147a58f7"}, - {file = "regex-2024.9.11-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:49b0e06786ea663f933f3710a51e9385ce0cba0ea56b67107fd841a55d56a231"}, - {file = "regex-2024.9.11-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:5b513b6997a0b2f10e4fd3a1313568e373926e8c252bd76c960f96fd039cd28d"}, - {file = "regex-2024.9.11-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ee439691d8c23e76f9802c42a95cfeebf9d47cf4ffd06f18489122dbb0a7ad64"}, - {file = "regex-2024.9.11-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a8f877c89719d759e52783f7fe6e1c67121076b87b40542966c02de5503ace42"}, - {file = "regex-2024.9.11-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:23b30c62d0f16827f2ae9f2bb87619bc4fba2044911e2e6c2eb1af0161cdb766"}, - {file = "regex-2024.9.11-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:85ab7824093d8f10d44330fe1e6493f756f252d145323dd17ab6b48733ff6c0a"}, - {file = "regex-2024.9.11-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8dee5b4810a89447151999428fe096977346cf2f29f4d5e29609d2e19e0199c9"}, - {file = "regex-2024.9.11-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:98eeee2f2e63edae2181c886d7911ce502e1292794f4c5ee71e60e23e8d26b5d"}, - {file = "regex-2024.9.11-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:57fdd2e0b2694ce6fc2e5ccf189789c3e2962916fb38779d3e3521ff8fe7a822"}, - {file = "regex-2024.9.11-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:d552c78411f60b1fdaafd117a1fca2f02e562e309223b9d44b7de8be451ec5e0"}, - {file = "regex-2024.9.11-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:a0b2b80321c2ed3fcf0385ec9e51a12253c50f146fddb2abbb10f033fe3d049a"}, - {file = "regex-2024.9.11-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:18406efb2f5a0e57e3a5881cd9354c1512d3bb4f5c45d96d110a66114d84d23a"}, - {file = 
"regex-2024.9.11-cp312-cp312-win32.whl", hash = "sha256:e464b467f1588e2c42d26814231edecbcfe77f5ac414d92cbf4e7b55b2c2a776"}, - {file = "regex-2024.9.11-cp312-cp312-win_amd64.whl", hash = "sha256:9e8719792ca63c6b8340380352c24dcb8cd7ec49dae36e963742a275dfae6009"}, - {file = "regex-2024.9.11-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:c157bb447303070f256e084668b702073db99bbb61d44f85d811025fcf38f784"}, - {file = "regex-2024.9.11-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:4db21ece84dfeefc5d8a3863f101995de646c6cb0536952c321a2650aa202c36"}, - {file = "regex-2024.9.11-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:220e92a30b426daf23bb67a7962900ed4613589bab80382be09b48896d211e92"}, - {file = "regex-2024.9.11-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:eb1ae19e64c14c7ec1995f40bd932448713d3c73509e82d8cd7744dc00e29e86"}, - {file = "regex-2024.9.11-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f47cd43a5bfa48f86925fe26fbdd0a488ff15b62468abb5d2a1e092a4fb10e85"}, - {file = "regex-2024.9.11-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9d4a76b96f398697fe01117093613166e6aa8195d63f1b4ec3f21ab637632963"}, - {file = "regex-2024.9.11-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0ea51dcc0835eea2ea31d66456210a4e01a076d820e9039b04ae8d17ac11dee6"}, - {file = "regex-2024.9.11-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b7aaa315101c6567a9a45d2839322c51c8d6e81f67683d529512f5bcfb99c802"}, - {file = "regex-2024.9.11-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:c57d08ad67aba97af57a7263c2d9006d5c404d721c5f7542f077f109ec2a4a29"}, - {file = "regex-2024.9.11-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:f8404bf61298bb6f8224bb9176c1424548ee1181130818fcd2cbffddc768bed8"}, - {file = "regex-2024.9.11-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:dd4490a33eb909ef5078ab20f5f000087afa2a4daa27b4c072ccb3cb3050ad84"}, - {file = "regex-2024.9.11-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:eee9130eaad130649fd73e5cd92f60e55708952260ede70da64de420cdcad554"}, - {file = "regex-2024.9.11-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:6a2644a93da36c784e546de579ec1806bfd2763ef47babc1b03d765fe560c9f8"}, - {file = "regex-2024.9.11-cp313-cp313-win32.whl", hash = "sha256:e997fd30430c57138adc06bba4c7c2968fb13d101e57dd5bb9355bf8ce3fa7e8"}, - {file = "regex-2024.9.11-cp313-cp313-win_amd64.whl", hash = "sha256:042c55879cfeb21a8adacc84ea347721d3d83a159da6acdf1116859e2427c43f"}, - {file = "regex-2024.9.11-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:35f4a6f96aa6cb3f2f7247027b07b15a374f0d5b912c0001418d1d55024d5cb4"}, - {file = "regex-2024.9.11-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:55b96e7ce3a69a8449a66984c268062fbaa0d8ae437b285428e12797baefce7e"}, - {file = "regex-2024.9.11-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:cb130fccd1a37ed894824b8c046321540263013da72745d755f2d35114b81a60"}, - {file = "regex-2024.9.11-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:323c1f04be6b2968944d730e5c2091c8c89767903ecaa135203eec4565ed2b2b"}, - {file = "regex-2024.9.11-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:be1c8ed48c4c4065ecb19d882a0ce1afe0745dfad8ce48c49586b90a55f02366"}, - {file = "regex-2024.9.11-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = 
"sha256:b5b029322e6e7b94fff16cd120ab35a253236a5f99a79fb04fda7ae71ca20ae8"}, - {file = "regex-2024.9.11-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f6fff13ef6b5f29221d6904aa816c34701462956aa72a77f1f151a8ec4f56aeb"}, - {file = "regex-2024.9.11-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:587d4af3979376652010e400accc30404e6c16b7df574048ab1f581af82065e4"}, - {file = "regex-2024.9.11-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:079400a8269544b955ffa9e31f186f01d96829110a3bf79dc338e9910f794fca"}, - {file = "regex-2024.9.11-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:f9268774428ec173654985ce55fc6caf4c6d11ade0f6f914d48ef4719eb05ebb"}, - {file = "regex-2024.9.11-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:23f9985c8784e544d53fc2930fc1ac1a7319f5d5332d228437acc9f418f2f168"}, - {file = "regex-2024.9.11-cp38-cp38-musllinux_1_2_ppc64le.whl", hash = "sha256:ae2941333154baff9838e88aa71c1d84f4438189ecc6021a12c7573728b5838e"}, - {file = "regex-2024.9.11-cp38-cp38-musllinux_1_2_s390x.whl", hash = "sha256:e93f1c331ca8e86fe877a48ad64e77882c0c4da0097f2212873a69bbfea95d0c"}, - {file = "regex-2024.9.11-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:846bc79ee753acf93aef4184c040d709940c9d001029ceb7b7a52747b80ed2dd"}, - {file = "regex-2024.9.11-cp38-cp38-win32.whl", hash = "sha256:c94bb0a9f1db10a1d16c00880bdebd5f9faf267273b8f5bd1878126e0fbde771"}, - {file = "regex-2024.9.11-cp38-cp38-win_amd64.whl", hash = "sha256:2b08fce89fbd45664d3df6ad93e554b6c16933ffa9d55cb7e01182baaf971508"}, - {file = "regex-2024.9.11-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:07f45f287469039ffc2c53caf6803cd506eb5f5f637f1d4acb37a738f71dd066"}, - {file = "regex-2024.9.11-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:4838e24ee015101d9f901988001038f7f0d90dc0c3b115541a1365fb439add62"}, - {file = "regex-2024.9.11-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:6edd623bae6a737f10ce853ea076f56f507fd7726bee96a41ee3d68d347e4d16"}, - {file = "regex-2024.9.11-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c69ada171c2d0e97a4b5aa78fbb835e0ffbb6b13fc5da968c09811346564f0d3"}, - {file = "regex-2024.9.11-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:02087ea0a03b4af1ed6ebab2c54d7118127fee8d71b26398e8e4b05b78963199"}, - {file = "regex-2024.9.11-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:69dee6a020693d12a3cf892aba4808fe168d2a4cef368eb9bf74f5398bfd4ee8"}, - {file = "regex-2024.9.11-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:297f54910247508e6e5cae669f2bc308985c60540a4edd1c77203ef19bfa63ca"}, - {file = "regex-2024.9.11-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ecea58b43a67b1b79805f1a0255730edaf5191ecef84dbc4cc85eb30bc8b63b9"}, - {file = "regex-2024.9.11-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:eab4bb380f15e189d1313195b062a6aa908f5bd687a0ceccd47c8211e9cf0d4a"}, - {file = "regex-2024.9.11-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:0cbff728659ce4bbf4c30b2a1be040faafaa9eca6ecde40aaff86f7889f4ab39"}, - {file = "regex-2024.9.11-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:54c4a097b8bc5bb0dfc83ae498061d53ad7b5762e00f4adaa23bee22b012e6ba"}, - {file = "regex-2024.9.11-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = 
"sha256:73d6d2f64f4d894c96626a75578b0bf7d9e56dcda8c3d037a2118fdfe9b1c664"}, - {file = "regex-2024.9.11-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:e53b5fbab5d675aec9f0c501274c467c0f9a5d23696cfc94247e1fb56501ed89"}, - {file = "regex-2024.9.11-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:0ffbcf9221e04502fc35e54d1ce9567541979c3fdfb93d2c554f0ca583a19b35"}, - {file = "regex-2024.9.11-cp39-cp39-win32.whl", hash = "sha256:e4c22e1ac1f1ec1e09f72e6c44d8f2244173db7eb9629cc3a346a8d7ccc31142"}, - {file = "regex-2024.9.11-cp39-cp39-win_amd64.whl", hash = "sha256:faa3c142464efec496967359ca99696c896c591c56c53506bac1ad465f66e919"}, - {file = "regex-2024.9.11.tar.gz", hash = "sha256:6c188c307e8433bcb63dc1915022deb553b4203a70722fc542c363bf120a01fd"}, -] - -[[package]] -name = "reorder-python-imports" -version = "2.4.0" -description = "Tool for reordering python imports" -optional = false -python-versions = ">=3.6.1" -files = [ - {file = "reorder_python_imports-2.4.0-py2.py3-none-any.whl", hash = "sha256:995a2a93684af31837f30cf2bcddce2e7eb17f0d2d69c9905da103baf8cec42b"}, - {file = "reorder_python_imports-2.4.0.tar.gz", hash = "sha256:9a9e7774d66e9b410b619f934e8206a63dce5be26bd894f5006eb764bba6a26d"}, -] - -[package.dependencies] -"aspy.refactor-imports" = ">=2.1.0" - [[package]] name = "requests" version = "2.32.3" @@ -3279,6 +3072,33 @@ typing-extensions = {version = ">=4.0.0,<5.0", markers = "python_version < \"3.1 [package.extras] jupyter = ["ipywidgets (>=7.5.1,<9)"] +[[package]] +name = "ruff" +version = "0.7.2" +description = "An extremely fast Python linter and code formatter, written in Rust." +optional = false +python-versions = ">=3.7" +files = [ + {file = "ruff-0.7.2-py3-none-linux_armv6l.whl", hash = "sha256:b73f873b5f52092e63ed540adefc3c36f1f803790ecf2590e1df8bf0a9f72cb8"}, + {file = "ruff-0.7.2-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:5b813ef26db1015953daf476202585512afd6a6862a02cde63f3bafb53d0b2d4"}, + {file = "ruff-0.7.2-py3-none-macosx_11_0_arm64.whl", hash = "sha256:853277dbd9675810c6826dad7a428d52a11760744508340e66bf46f8be9701d9"}, + {file = "ruff-0.7.2-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:21aae53ab1490a52bf4e3bf520c10ce120987b047c494cacf4edad0ba0888da2"}, + {file = "ruff-0.7.2-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ccc7e0fc6e0cb3168443eeadb6445285abaae75142ee22b2b72c27d790ab60ba"}, + {file = "ruff-0.7.2-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fd77877a4e43b3a98e5ef4715ba3862105e299af0c48942cc6d51ba3d97dc859"}, + {file = "ruff-0.7.2-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:e00163fb897d35523c70d71a46fbaa43bf7bf9af0f4534c53ea5b96b2e03397b"}, + {file = "ruff-0.7.2-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f3c54b538633482dc342e9b634d91168fe8cc56b30a4b4f99287f4e339103e88"}, + {file = "ruff-0.7.2-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7b792468e9804a204be221b14257566669d1db5c00d6bb335996e5cd7004ba80"}, + {file = "ruff-0.7.2-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dba53ed84ac19ae4bfb4ea4bf0172550a2285fa27fbb13e3746f04c80f7fa088"}, + {file = "ruff-0.7.2-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:b19fafe261bf741bca2764c14cbb4ee1819b67adb63ebc2db6401dcd652e3748"}, + {file = "ruff-0.7.2-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:28bd8220f4d8f79d590db9e2f6a0674f75ddbc3847277dd44ac1f8d30684b828"}, + {file = 
"ruff-0.7.2-py3-none-musllinux_1_2_i686.whl", hash = "sha256:9fd67094e77efbea932e62b5d2483006154794040abb3a5072e659096415ae1e"}, + {file = "ruff-0.7.2-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:576305393998b7bd6c46018f8104ea3a9cb3fa7908c21d8580e3274a3b04b691"}, + {file = "ruff-0.7.2-py3-none-win32.whl", hash = "sha256:fa993cfc9f0ff11187e82de874dfc3611df80852540331bc85c75809c93253a8"}, + {file = "ruff-0.7.2-py3-none-win_amd64.whl", hash = "sha256:dd8800cbe0254e06b8fec585e97554047fb82c894973f7ff18558eee33d1cb88"}, + {file = "ruff-0.7.2-py3-none-win_arm64.whl", hash = "sha256:bb8368cd45bba3f57bb29cbb8d64b4a33f8415d0149d2655c5c8539452ce7760"}, + {file = "ruff-0.7.2.tar.gz", hash = "sha256:2b14e77293380e475b4e3a7a368e14549288ed2931fce259a6f99978669e844f"}, +] + [[package]] name = "s3transfer" version = "0.10.3" @@ -4259,4 +4079,4 @@ zookeeper = ["kazoo"] [metadata] lock-version = "2.0" python-versions = ">=3.9,<4.0" -content-hash = "c46e7c1f81b12dc4f9c188833f8c7860f99bc4e6631a866685df2fd75bfe1e80" +content-hash = "9c41fa9374f3163923ac7a9f244e197f4356adfd703cd1fe573d33739a1a2c37" diff --git a/pyproject.toml b/pyproject.toml index 9aa840527..c881d9988 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -70,12 +70,7 @@ psycopg2 = ["psycopg2", "psycogreen"] zookeeper = ["kazoo"] [tool.poetry.group.dev.dependencies] -black = "21.10b0" -# TODO: This can be removed once we upgrade to a newer Black version. -# https://github.com/psf/black/issues/2964 -click = "<8.1.0" fakeredis = "*" -flake8 = ">=7.0.0" lxml = "*" moto = "*" mypy = "*" @@ -85,7 +80,6 @@ pylint = "*" pytest = "7.4.4" pytest-cov = "*" pytz = "*" -reorder-python-imports = "2.4.0" sphinx = "*" sphinx-autodoc-typehints = "*" types-redis = "*" @@ -94,6 +88,7 @@ types-setuptools = "*" webtest = "*" parameterized = "^0.9.0" opentelemetry-test-utils = "^0.47b0" +ruff = "*" [tool.poetry.scripts] @@ -106,10 +101,18 @@ baseplate-tshell = { reference = "bin/baseplate-tshell", type = "file" } [tool.poetry.plugins."distutils.commands"] build_thrift = "baseplate.frameworks.thrift.command:BuildThriftCommand" +[tool.ruff] +target-version = "py39" +line-length = 100 +extend-exclude = ["baseplate/thrift", "tests/integration/test_thrift"] + +[tool.ruff.lint] +extend-select = [ + "I", # isort + "UP", # pyupgrade + "E501", # line length +] + [build-system] requires = ["poetry-core"] build-backend = "poetry.core.masonry.api" - -[tool.black] -line-length = 100 -target-version = ['py39'] diff --git a/setup.cfg b/setup.cfg index b186ef051..c19609ce2 100644 --- a/setup.cfg +++ b/setup.cfg @@ -27,17 +27,6 @@ exclude_lines = # mypy-only branches aren't live code if TYPE_CHECKING: -[flake8] -max-line-length = 100 -ignore = W503, E203, E501, D100, D101, D102, D103, D104, D105, D106, D107 -per-file-ignores = - baseplate/sidecars/*.py: E402, C0413 -exclude = - baseplate/thrift/ - tests/integration/test_thrift/ - build/ - .eggs/ - [mypy] python_version = 3.9 # https://opentelemetry.io/docs/instrumentation/python/mypy/ diff --git a/tests/__init__.py b/tests/__init__.py index 40e155d9f..b496a2091 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,6 +1,5 @@ from contextlib import nullcontext as does_not_raise - __all__ = [ "does_not_raise", ] diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py index 576c4e131..e612b3a9a 100644 --- a/tests/integration/__init__.py +++ b/tests/integration/__init__.py @@ -2,8 +2,7 @@ import socket import unittest -from baseplate import BaseplateObserver -from baseplate import SpanObserver +from 
baseplate import BaseplateObserver, SpanObserver from baseplate.lib.config import Endpoint from baseplate.lib.edgecontext import EdgeContextFactory @@ -26,7 +25,7 @@ def get_endpoint_or_skip_container(name, default_port): sock.settimeout(0.1) sock.connect(endpoint.address) except OSError: - raise unittest.SkipTest("could not find %s server for integration tests" % name) + raise unittest.SkipTest(f"could not find {name} server for integration tests") else: sock.close() @@ -52,9 +51,9 @@ def on_set_tag(self, key, value): def assert_tag(self, key, value): assert key in self.tags, f"{key!r} not found in tags ({list(self.tags.keys())!r})" - assert self.tags[key] == value, "tag {!r}: expected value {!r} but found {!r}".format( - key, value, self.tags[key] - ) + assert ( + self.tags[key] == value + ), f"tag {key!r}: expected value {value!r} but found {self.tags[key]!r}" def on_log(self, name, payload): self.logs.append((name, payload)) diff --git a/tests/integration/cassandra_tests.py b/tests/integration/cassandra_tests.py index 5b8b4a216..aa9cae6c1 100644 --- a/tests/integration/cassandra_tests.py +++ b/tests/integration/cassandra_tests.py @@ -1,22 +1,20 @@ import time import unittest - from unittest import mock try: - from cassandra import InvalidRequest, ConsistencyLevel + from cassandra import ConsistencyLevel, InvalidRequest from cassandra.cluster import ExecutionProfile from cassandra.concurrent import execute_concurrent_with_args from cassandra.query import dict_factory, named_tuple_factory except ImportError: raise unittest.SkipTest("cassandra-driver is not installed") -from baseplate.clients.cassandra import CassandraClient from baseplate import Baseplate +from baseplate.clients.cassandra import CassandraClient from . import TestBaseplateObserver, get_endpoint_or_skip_container - cassandra_endpoint = get_endpoint_or_skip_container("cassandra", 9042) diff --git a/tests/integration/live_data/writer_tests.py b/tests/integration/live_data/writer_tests.py index 370c7e365..4374e4d38 100644 --- a/tests/integration/live_data/writer_tests.py +++ b/tests/integration/live_data/writer_tests.py @@ -1,6 +1,5 @@ import unittest import uuid - from io import BytesIO from unittest import mock @@ -18,7 +17,6 @@ from .. import get_endpoint_or_skip_container - zookeeper_endpoint = get_endpoint_or_skip_container("zookeeper", 2181) diff --git a/tests/integration/live_data/zookeeper_tests.py b/tests/integration/live_data/zookeeper_tests.py index bb96cee58..f40eaeccf 100644 --- a/tests/integration/live_data/zookeeper_tests.py +++ b/tests/integration/live_data/zookeeper_tests.py @@ -1,6 +1,5 @@ import time import unittest - from unittest import mock import gevent.socket diff --git a/tests/integration/memcache_tests.py b/tests/integration/memcache_tests.py index d2da30042..bb0486e49 100644 --- a/tests/integration/memcache_tests.py +++ b/tests/integration/memcache_tests.py @@ -1,5 +1,4 @@ import unittest - from unittest import mock try: @@ -7,12 +6,11 @@ except ImportError: raise unittest.SkipTest("pymemcache is not installed") -from baseplate.clients.memcache import MemcacheClient, MonitoredMemcacheConnection, make_keys_str from baseplate import Baseplate, LocalSpan, ServerSpan +from baseplate.clients.memcache import MemcacheClient, MonitoredMemcacheConnection, make_keys_str from . 
import TestBaseplateObserver, get_endpoint_or_skip_container - memcached_endpoint = get_endpoint_or_skip_container("memcached", 11211) diff --git a/tests/integration/message_queue_tests.py b/tests/integration/message_queue_tests.py index 59938c71f..4c431200a 100644 --- a/tests/integration/message_queue_tests.py +++ b/tests/integration/message_queue_tests.py @@ -4,8 +4,7 @@ import posix_ipc -from baseplate.lib.message_queue import MessageQueue -from baseplate.lib.message_queue import TimedOutError +from baseplate.lib.message_queue import MessageQueue, TimedOutError class TestMessageQueueCreation(unittest.TestCase): diff --git a/tests/integration/otel_pyramid_tests.py b/tests/integration/otel_pyramid_tests.py index 54194cee6..0485db3c6 100644 --- a/tests/integration/otel_pyramid_tests.py +++ b/tests/integration/otel_pyramid_tests.py @@ -1,9 +1,7 @@ import unittest - from unittest import mock -from opentelemetry import propagate -from opentelemetry import trace +from opentelemetry import propagate, trace from opentelemetry.propagators.composite import CompositePropagator from opentelemetry.test.test_base import TestBase from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator @@ -14,19 +12,20 @@ from . import FakeEdgeContextFactory - propagate.set_global_textmap( CompositePropagator([RedditB3HTTPFormat(), TraceContextTextMapPropagator()]) ) try: import webtest - - from baseplate.frameworks.pyramid import BaseplateConfigurator - from baseplate.frameworks.pyramid import ServerSpanInitialized - from baseplate.frameworks.pyramid import StaticTrustHandler from pyramid.config import Configurator from pyramid.httpexceptions import HTTPInternalServerError + + from baseplate.frameworks.pyramid import ( + BaseplateConfigurator, + ServerSpanInitialized, + StaticTrustHandler, + ) except ImportError: raise unittest.SkipTest("pyramid/webtest is not installed") diff --git a/tests/integration/otel_thrift_tests.py b/tests/integration/otel_thrift_tests.py index e90f1b636..201ff7696 100644 --- a/tests/integration/otel_thrift_tests.py +++ b/tests/integration/otel_thrift_tests.py @@ -1,43 +1,32 @@ import contextlib import logging import unittest - from importlib import reload import gevent.monkey import pytest - -from opentelemetry import propagate -from opentelemetry import trace +from opentelemetry import propagate, trace from opentelemetry.propagators.composite import CompositePropagator -from opentelemetry.semconv.trace import MessageTypeValues -from opentelemetry.semconv.trace import SpanAttributes +from opentelemetry.semconv.trace import MessageTypeValues, SpanAttributes from opentelemetry.test.test_base import TestBase from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator from parameterized import parameterized from thrift.protocol.TProtocol import TProtocolException -from thrift.Thrift import TApplicationException -from thrift.Thrift import TException +from thrift.Thrift import TApplicationException, TException from thrift.transport.TTransport import TTransportException -from baseplate import Baseplate -from baseplate import TraceInfo +from baseplate import Baseplate, TraceInfo from baseplate.clients.thrift import ThriftClient from baseplate.frameworks.thrift import baseplateify_processor from baseplate.lib import config from baseplate.lib.propagator_redditb3_http import RedditB3HTTPFormat from baseplate.lib.propagator_redditb3_thrift import RedditB3ThriftFormat from baseplate.lib.thrift_pool import ThriftConnectionPool -from 
baseplate.observers.timeout import ServerTimeout -from baseplate.observers.timeout import TimeoutBaseplateObserver +from baseplate.observers.timeout import ServerTimeout, TimeoutBaseplateObserver from baseplate.server import make_listener from baseplate.server.thrift import make_server -from baseplate.thrift import BaseplateService -from baseplate.thrift import BaseplateServiceV2 -from baseplate.thrift.ttypes import Error -from baseplate.thrift.ttypes import ErrorCode -from baseplate.thrift.ttypes import IsHealthyProbe -from baseplate.thrift.ttypes import IsHealthyRequest +from baseplate.thrift import BaseplateService, BaseplateServiceV2 +from baseplate.thrift.ttypes import Error, ErrorCode, IsHealthyProbe, IsHealthyRequest from . import FakeEdgeContextFactory from .test_thrift import TestService @@ -287,11 +276,11 @@ def example(self, context): with raw_thrift_client(server.endpoint, TestService) as client: transport = client._oprot.trans transport.set_header( - "Trace".encode("utf-8"), - "100985939111033328018442752961257817910".encode("utf-8"), + b"Trace", + b"100985939111033328018442752961257817910", ) - transport.set_header("Span".encode("utf-8"), "67667974448284344".encode("utf-8")) - transport.set_header("Sampled".encode("utf-8"), "1".encode("utf-8")) + transport.set_header(b"Span", b"67667974448284344") + transport.set_header(b"Sampled", b"1") client.example() finished_spans = self.get_finished_spans() @@ -316,9 +305,9 @@ def example(self, context): with serve_thrift(handler, TestService) as server: with raw_thrift_client(server.endpoint, TestService) as client: transport = client._oprot.trans - transport.set_header("Trace".encode("utf-8"), "2365116317615059789".encode("utf-8")) - transport.set_header("Span".encode("utf-8"), "11655119394564249508".encode("utf-8")) - transport.set_header("Sampled".encode("utf-8"), "1".encode("utf-8")) + transport.set_header(b"Trace", b"2365116317615059789") + transport.set_header(b"Span", b"11655119394564249508") + transport.set_header(b"Sampled", b"1") client.example() finished_spans = self.get_finished_spans() @@ -351,10 +340,10 @@ def example(self, context): transport = client._oprot.trans transport.set_header(b"traceparent", traceparent.encode()) # should get discarded - transport.set_header("Trace".encode("utf-8"), "20d294c28becf34d".encode("utf-8")) + transport.set_header(b"Trace", b"20d294c28becf34d") # should get discarded - transport.set_header("Span".encode("utf-8"), "a1bf4d567fc497a4".encode("utf-8")) - transport.set_header("Sampled".encode("utf-8"), "1".encode("utf-8")) + transport.set_header(b"Span", b"a1bf4d567fc497a4") + transport.set_header(b"Sampled", b"1") client_result = client.example() finished_spans = self.get_finished_spans() @@ -573,8 +562,9 @@ def example(self, context): with serve_thrift(handler, TestService, convert_to_baseplate_error=False) as server: with raw_thrift_client(server.endpoint, TestService) as client: - # although we set `convert_to_baseplate_error` to `False`, this still gets "converted" - # in the ``TestService`` interface. But the point is it's not just ``Error`` + # although we set `convert_to_baseplate_error` to `False`, this + # still gets "converted" in the ``TestService`` interface. But + # the point is it's not just ``Error`` with self.assertRaises(TApplicationException): client.example() @@ -682,8 +672,7 @@ def test_exception_handling_trace_status( some exceptions, or when the status on baseplate Error is a 5xx). 
""" logger.debug( - "exc=%s, convert=%s, expectation=%s, otel_exception=%s, otel_status=%s" - % (exc, convert, expectation, otel_exception, otel_status) + f"exc={exc}, convert={convert}, expectation={expectation}, otel_exception={otel_exception}, otel_status={otel_status}", # noqa: E501 ) class Handler(TestService.Iface): diff --git a/tests/integration/pyramid_tests.py b/tests/integration/pyramid_tests.py index 55dbe870a..8211bcbe5 100644 --- a/tests/integration/pyramid_tests.py +++ b/tests/integration/pyramid_tests.py @@ -1,26 +1,24 @@ import base64 import unittest - from unittest import mock from opentelemetry.test.test_base import TestBase from pyramid.response import Response -from baseplate import Baseplate -from baseplate import BaseplateObserver -from baseplate import ServerSpanObserver +from baseplate import Baseplate, BaseplateObserver, ServerSpanObserver from . import FakeEdgeContextFactory - try: import webtest - - from baseplate.frameworks.pyramid import BaseplateConfigurator - from baseplate.frameworks.pyramid import ServerSpanInitialized - from baseplate.frameworks.pyramid import StaticTrustHandler from pyramid.config import Configurator from pyramid.httpexceptions import HTTPInternalServerError + + from baseplate.frameworks.pyramid import ( + BaseplateConfigurator, + ServerSpanInitialized, + StaticTrustHandler, + ) except ImportError: raise unittest.SkipTest("pyramid/webtest is not installed") diff --git a/tests/integration/ratelimit_tests.py b/tests/integration/ratelimit_tests.py index 3193f0d35..ef5d5b8b3 100644 --- a/tests/integration/ratelimit_tests.py +++ b/tests/integration/ratelimit_tests.py @@ -1,26 +1,25 @@ import unittest - from time import sleep from uuid import uuid4 from baseplate import Baseplate -from baseplate.lib.ratelimit import RateLimiterContextFactory -from baseplate.lib.ratelimit import RateLimitExceededException +from baseplate.lib.ratelimit import RateLimiterContextFactory, RateLimitExceededException try: from pymemcache.client.base import PooledClient + from baseplate.lib.ratelimit.backends.memcache import MemcacheRateLimitBackendContextFactory except ImportError: raise unittest.SkipTest("pymemcache is not installed") try: from redis import ConnectionPool + from baseplate.lib.ratelimit.backends.redis import RedisRateLimitBackendContextFactory except ImportError: raise unittest.SkipTest("redis-py is not installed") from . import TestBaseplateObserver, get_endpoint_or_skip_container - redis_endpoint = get_endpoint_or_skip_container("redis", 6379) memcached_endpoint = get_endpoint_or_skip_container("memcached", 11211) diff --git a/tests/integration/redis_cluster_tests.py b/tests/integration/redis_cluster_tests.py index 504d23c7e..a1582266e 100644 --- a/tests/integration/redis_cluster_tests.py +++ b/tests/integration/redis_cluster_tests.py @@ -7,18 +7,13 @@ from prometheus_client import REGISTRY +from baseplate.clients.redis import ACTIVE_REQUESTS, LATENCY_SECONDS, REQUESTS_TOTAL +from baseplate.clients.redis_cluster import ClusterRedisClient, cluster_pool_from_config from baseplate.lib.config import ConfigurationError -from baseplate.clients.redis_cluster import cluster_pool_from_config - -from baseplate.clients.redis_cluster import ClusterRedisClient -from baseplate.clients.redis import REQUESTS_TOTAL -from baseplate.clients.redis import LATENCY_SECONDS -from baseplate.clients.redis import ACTIVE_REQUESTS from . 
import get_endpoint_or_skip_container from .redis_testcase import RedisIntegrationTestCase, redis_cluster_url - redis_endpoint = get_endpoint_or_skip_container("redis-cluster-node", 7000) diff --git a/tests/integration/redis_testcase.py b/tests/integration/redis_testcase.py index b72cbf2c9..6fe87c43e 100644 --- a/tests/integration/redis_testcase.py +++ b/tests/integration/redis_testcase.py @@ -3,11 +3,9 @@ import redis import baseplate.clients.redis as baseplate_redis - from baseplate import Baseplate -from . import get_endpoint_or_skip_container -from . import TestBaseplateObserver +from . import TestBaseplateObserver, get_endpoint_or_skip_container redis_url = f'redis://{get_endpoint_or_skip_container("redis", 6379)}' redis_cluster_url = f'redis://{get_endpoint_or_skip_container("redis-cluster-node", 7000)}' diff --git a/tests/integration/redis_tests.py b/tests/integration/redis_tests.py index 5454eaad7..e2d437478 100644 --- a/tests/integration/redis_tests.py +++ b/tests/integration/redis_tests.py @@ -6,18 +6,20 @@ except ImportError: raise unittest.SkipTest("redis-py is not installed") -from baseplate.clients.redis import ACTIVE_REQUESTS -from baseplate.clients.redis import REQUESTS_TOTAL -from baseplate.clients.redis import LATENCY_SECONDS -from baseplate.clients.redis import RedisClient +from prometheus_client import REGISTRY + +from baseplate.clients.redis import ( + ACTIVE_REQUESTS, + LATENCY_SECONDS, + REQUESTS_TOTAL, + MessageQueue, + RedisClient, +) +from baseplate.lib.message_queue import TimedOutError from . import get_endpoint_or_skip_container from .redis_testcase import RedisIntegrationTestCase, redis_url -from baseplate.clients.redis import MessageQueue -from baseplate.lib.message_queue import TimedOutError -from prometheus_client import REGISTRY - redis_endpoint = get_endpoint_or_skip_container("redis", 6379) diff --git a/tests/integration/requests_tests.py b/tests/integration/requests_tests.py index 7499a8e30..49fc1f046 100644 --- a/tests/integration/requests_tests.py +++ b/tests/integration/requests_tests.py @@ -4,15 +4,12 @@ import gevent import pytest import requests - from pyramid.config import Configurator from pyramid.httpexceptions import HTTPNoContent from baseplate import Baseplate -from baseplate.clients.requests import ExternalRequestsClient -from baseplate.clients.requests import InternalRequestsClient -from baseplate.frameworks.pyramid import BaseplateConfigurator -from baseplate.frameworks.pyramid import StaticTrustHandler +from baseplate.clients.requests import ExternalRequestsClient, InternalRequestsClient +from baseplate.frameworks.pyramid import BaseplateConfigurator, StaticTrustHandler from baseplate.lib import config from baseplate.server import make_listener from baseplate.server.wsgi import make_server diff --git a/tests/integration/sqlalchemy_tests.py b/tests/integration/sqlalchemy_tests.py index 617a772c3..3f3ddd75f 100644 --- a/tests/integration/sqlalchemy_tests.py +++ b/tests/integration/sqlalchemy_tests.py @@ -8,17 +8,16 @@ except ImportError: raise unittest.SkipTest("sqlalchemy is not installed") +from baseplate import Baseplate from baseplate.clients.sqlalchemy import ( - engine_from_config, SQLAlchemyEngineContextFactory, SQLAlchemySession, SQLAlchemySessionContextFactory, + engine_from_config, ) -from baseplate import Baseplate from . 
import TestBaseplateObserver - Base = declarative_base() diff --git a/tests/integration/thrift_tests.py b/tests/integration/thrift_tests.py index c36a4e335..a51ba1109 100644 --- a/tests/integration/thrift_tests.py +++ b/tests/integration/thrift_tests.py @@ -2,32 +2,22 @@ import logging import random import unittest - from importlib import reload from unittest import mock import gevent.monkey import pytest -from baseplate import Baseplate -from baseplate import BaseplateObserver -from baseplate import ServerSpanObserver -from baseplate import SpanObserver -from baseplate import TraceInfo +from baseplate import Baseplate, BaseplateObserver, ServerSpanObserver, SpanObserver, TraceInfo from baseplate.clients.thrift import ThriftClient from baseplate.frameworks.thrift import baseplateify_processor from baseplate.lib import config from baseplate.lib.thrift_pool import ThriftConnectionPool -from baseplate.observers.timeout import ServerTimeout -from baseplate.observers.timeout import TimeoutBaseplateObserver +from baseplate.observers.timeout import ServerTimeout, TimeoutBaseplateObserver from baseplate.server import make_listener from baseplate.server.thrift import make_server -from baseplate.thrift import BaseplateService -from baseplate.thrift import BaseplateServiceV2 -from baseplate.thrift.ttypes import Error -from baseplate.thrift.ttypes import ErrorCode -from baseplate.thrift.ttypes import IsHealthyProbe -from baseplate.thrift.ttypes import IsHealthyRequest +from baseplate.thrift import BaseplateService, BaseplateServiceV2 +from baseplate.thrift.ttypes import Error, ErrorCode, IsHealthyProbe, IsHealthyRequest from . import FakeEdgeContextFactory from .test_thrift import TestService diff --git a/tests/integration/timeout_tests.py b/tests/integration/timeout_tests.py index 6c2be1c94..dca0c06fd 100644 --- a/tests/integration/timeout_tests.py +++ b/tests/integration/timeout_tests.py @@ -2,8 +2,7 @@ import pytest from baseplate import Baseplate -from baseplate.observers.timeout import ServerTimeout -from baseplate.observers.timeout import TimeoutBaseplateObserver +from baseplate.observers.timeout import ServerTimeout, TimeoutBaseplateObserver def _create_baseplate_object(timeout: str): diff --git a/tests/integration/tracing_tests.py b/tests/integration/tracing_tests.py index 22724c876..b87642498 100644 --- a/tests/integration/tracing_tests.py +++ b/tests/integration/tracing_tests.py @@ -1,21 +1,20 @@ import unittest - from unittest import mock from baseplate import Baseplate -from baseplate.observers.tracing import make_client -from baseplate.observers.tracing import NullRecorder -from baseplate.observers.tracing import TraceBaseplateObserver -from baseplate.observers.tracing import TraceLocalSpanObserver -from baseplate.observers.tracing import TraceServerSpanObserver +from baseplate.observers.tracing import ( + NullRecorder, + TraceBaseplateObserver, + TraceLocalSpanObserver, + TraceServerSpanObserver, + make_client, +) try: import webtest - from pyramid.config import Configurator - from baseplate.frameworks.pyramid import BaseplateConfigurator - from baseplate.frameworks.pyramid import StaticTrustHandler + from baseplate.frameworks.pyramid import BaseplateConfigurator, StaticTrustHandler except ImportError: raise unittest.SkipTest("pyramid/webtest is not installed") @@ -94,14 +93,23 @@ def test_trace_on_inbound_request(self): self.assertEqual(span["parentId"], 0) def test_local_tracing_embedded(self): - with mock.patch.object( - TraceBaseplateObserver, "on_server_span_created", 
side_effect=self._register_server_mock
-        ), mock.patch.object(
-            TraceServerSpanObserver, "on_child_span_created", side_effect=self._register_local_mock
-        ), mock.patch.object(
-            TraceLocalSpanObserver, "on_child_span_created", side_effect=self._register_local_mock
+        with (
+            mock.patch.object(
+                TraceBaseplateObserver,
+                "on_server_span_created",
+                side_effect=self._register_server_mock,
+            ),
+            mock.patch.object(
+                TraceServerSpanObserver,
+                "on_child_span_created",
+                side_effect=self._register_local_mock,
+            ),
+            mock.patch.object(
+                TraceLocalSpanObserver,
+                "on_child_span_created",
+                side_effect=self._register_local_mock,
+            ),
         ):
-
             self.test_app.get("/local_test")
             # Verify that child span can be created within a local span context
             # and parent IDs are inherited accordingly.
diff --git a/tests/unit/clients/cassandra_tests.py b/tests/unit/clients/cassandra_tests.py
index 634a1aeb8..fe7538268 100644
--- a/tests/unit/clients/cassandra_tests.py
+++ b/tests/unit/clients/cassandra_tests.py
@@ -1,5 +1,4 @@
 import unittest
-
 from unittest import mock
 
 from prometheus_client import REGISTRY
@@ -11,20 +10,21 @@ except ImportError:
     raise unittest.SkipTest("cassandra-driver is not installed")
 
-import baseplate
 import logging
-from baseplate.lib.config import ConfigurationError
+
+import baseplate
 from baseplate.clients.cassandra import (
-    cluster_from_config,
+    REQUEST_ACTIVE,
+    REQUEST_TIME,
+    REQUEST_TOTAL,
     CassandraCallbackArgs,
     CassandraPrometheusLabels,
     CassandraSessionAdapter,
-    REQUEST_TIME,
-    REQUEST_ACTIVE,
-    REQUEST_TOTAL,
     _on_execute_complete,
     _on_execute_failed,
+    cluster_from_config,
 )
+from baseplate.lib.config import ConfigurationError
 from baseplate.lib.secrets import SecretsStore
 
 logger = logging.getLogger(__name__)
@@ -115,7 +115,8 @@ def test_execute_async_prom_metrics(self):
         REGISTRY.get_sample_value(
             "cassandra_client_active_requests",
             {
-                "cassandra_client_name": "test",  # client name defaults to name when not provided
+                # client name defaults to name when not provided
+                "cassandra_client_name": "test",
                 "cassandra_keyspace": "keyspace",
                 "cassandra_query_name": "",
                 "cassandra_cluster_name": "",
diff --git a/tests/unit/clients/kombu_tests.py b/tests/unit/clients/kombu_tests.py
index 9d3d234e1..10a4c525d 100644
--- a/tests/unit/clients/kombu_tests.py
+++ b/tests/unit/clients/kombu_tests.py
@@ -1,15 +1,16 @@
 from unittest import mock
 
 import pytest
-
 from prometheus_client import REGISTRY
 
-from baseplate.clients.kombu import _KombuProducer
-from baseplate.clients.kombu import AMQP_PROCESSED_TOTAL
-from baseplate.clients.kombu import AMQP_PROCESSING_TIME
-from baseplate.clients.kombu import connection_from_config
-from baseplate.clients.kombu import exchange_from_config
-from baseplate.clients.kombu import KombuThriftSerializer
+from baseplate.clients.kombu import (
+    AMQP_PROCESSED_TOTAL,
+    AMQP_PROCESSING_TIME,
+    KombuThriftSerializer,
+    _KombuProducer,
+    connection_from_config,
+    exchange_from_config,
+)
 from baseplate.lib.config import ConfigurationError
 from baseplate.testing.lib.secrets import FakeSecretsStore
diff --git a/tests/unit/clients/memcache_tests.py b/tests/unit/clients/memcache_tests.py
index b5b05c01b..9a4695160 100644
--- a/tests/unit/clients/memcache_tests.py
+++ b/tests/unit/clients/memcache_tests.py
@@ -1,6 +1,5 @@
 import builtins
 import unittest
-
 from unittest import mock
 
 try:
@@ -11,10 +10,10 @@ except ImportError:
     del pymemcache
 
 from prometheus_client import REGISTRY
 
-from baseplate.lib.config import ConfigurationError
-from baseplate.clients.memcache import pool_from_config
-from baseplate.clients.memcache import MonitoredMemcacheConnection + +from baseplate.clients.memcache import MonitoredMemcacheConnection, pool_from_config from baseplate.clients.memcache import lib as memcache_lib +from baseplate.lib.config import ConfigurationError class PrometheusInstrumentationTests(unittest.TestCase): diff --git a/tests/unit/clients/redis_cluster_tests.py b/tests/unit/clients/redis_cluster_tests.py index 0cf5df90d..9d65a981f 100644 --- a/tests/unit/clients/redis_cluster_tests.py +++ b/tests/unit/clients/redis_cluster_tests.py @@ -1,23 +1,23 @@ import os import unittest - from unittest import mock import fakeredis import pytest - from prometheus_client import REGISTRY from rediscluster.exceptions import RedisClusterException -from baseplate.clients.redis_cluster import ACTIVE_REQUESTS -from baseplate.clients.redis_cluster import cluster_pool_from_config -from baseplate.clients.redis_cluster import HotKeyTracker -from baseplate.clients.redis_cluster import LATENCY_SECONDS -from baseplate.clients.redis_cluster import MonitoredRedisClusterConnection -from baseplate.clients.redis_cluster import REQUESTS_TOTAL +from baseplate.clients.redis_cluster import ( + ACTIVE_REQUESTS, + LATENCY_SECONDS, + REQUESTS_TOTAL, + HotKeyTracker, + MonitoredRedisClusterConnection, + cluster_pool_from_config, +) -class DummyConnection(object): +class DummyConnection: description_format = "DummyConnection<>" def __init__(self, host="localhost", port=7000, socket_timeout=None, **kwargs): @@ -142,7 +142,8 @@ def test_pipeline_instrumentation(self, monitored_redis_connection, expected_lab ) as active_dec_spy_method: mock_manager.attach_mock(active_dec_spy_method, "dec") - # This KeyError is the same problem as the RedisClusterException in `test_execute_command_exc_redis_err` above + # This KeyError is the same problem as the + # RedisClusterException in `test_execute_command_exc_redis_err` above with pytest.raises(KeyError): monitored_redis_connection.pipeline("test").set("hello", 42).set( "goodbye", 23 @@ -157,10 +158,13 @@ def test_pipeline_instrumentation(self, monitored_redis_connection, expected_lab ) == 1.0 ), "Expected one 'pipeline' latency request" - assert mock_manager.mock_calls == [ - mock.call.inc(), - mock.call.dec(), - ], "Instrumentation should increment and then decrement active requests exactly once" + assert ( + mock_manager.mock_calls + == [ + mock.call.inc(), + mock.call.dec(), + ] + ), "Instrumentation should increment and then decrement active requests exactly once" # noqa: E501 print(list(REGISTRY.collect())) assert ( REGISTRY.get_sample_value(ACTIVE_REQUESTS._name, active_labels) == 0.0 diff --git a/tests/unit/clients/redis_tests.py b/tests/unit/clients/redis_tests.py index 799ddb19b..8d680c4f9 100644 --- a/tests/unit/clients/redis_tests.py +++ b/tests/unit/clients/redis_tests.py @@ -1,10 +1,8 @@ import os import unittest - from unittest import mock import pytest - from prometheus_client import REGISTRY try: @@ -15,12 +13,14 @@ del redis from redis.exceptions import ConnectionError +from baseplate.clients.redis import ( + ACTIVE_REQUESTS, + LATENCY_SECONDS, + REQUESTS_TOTAL, + MonitoredRedisConnection, + pool_from_config, +) from baseplate.lib.config import ConfigurationError -from baseplate.clients.redis import pool_from_config -from baseplate.clients.redis import ACTIVE_REQUESTS -from baseplate.clients.redis import REQUESTS_TOTAL -from baseplate.clients.redis import LATENCY_SECONDS -from baseplate.clients.redis import MonitoredRedisConnection class DummyConnection: @@ 
-156,10 +156,13 @@ def test_pipeline_instrumentation(self, monitored_redis_connection, expected_lab ) == 1.0 ), "Expected one 'pipeline' latency request" - assert mock_manager.mock_calls == [ - mock.call.inc(), - mock.call.dec(), - ], "Instrumentation should increment and then decrement active requests exactly once" + assert ( + mock_manager.mock_calls + == [ + mock.call.inc(), + mock.call.dec(), + ] + ), "Instrumentation should increment and then decrement active requests exactly once" # noqa: E501 assert ( REGISTRY.get_sample_value(ACTIVE_REQUESTS._name, active_labels) == 0.0 ), "Should have 0 (and not None) active requests" @@ -203,10 +206,13 @@ def test_pipeline_instrumentation_failing( ) == 1.0 ), "Expected one 'pipeline' latency request" - assert mock_manager.mock_calls == [ - mock.call.inc(), - mock.call.dec(), - ], "Instrumentation should increment and then decrement active requests exactly once" + assert ( + mock_manager.mock_calls + == [ + mock.call.inc(), + mock.call.dec(), + ] + ), "Instrumentation should increment and then decrement active requests exactly once" # noqa: E501 assert ( REGISTRY.get_sample_value(ACTIVE_REQUESTS._name, active_labels) == 0.0 ), "Should have 0 (and not None) active requests" diff --git a/tests/unit/clients/requests_tests.py b/tests/unit/clients/requests_tests.py index bdff96875..ceccf860b 100644 --- a/tests/unit/clients/requests_tests.py +++ b/tests/unit/clients/requests_tests.py @@ -2,16 +2,15 @@ from unittest import mock import pytest - from prometheus_client import REGISTRY -from requests import Request -from requests import Response -from requests import Session +from requests import Request, Response, Session -from baseplate.clients.requests import ACTIVE_REQUESTS -from baseplate.clients.requests import BaseplateSession -from baseplate.clients.requests import LATENCY_SECONDS -from baseplate.clients.requests import REQUESTS_TOTAL +from baseplate.clients.requests import ( + ACTIVE_REQUESTS, + LATENCY_SECONDS, + REQUESTS_TOTAL, + BaseplateSession, +) from baseplate.lib.prometheus_metrics import getHTTPSuccessLabel diff --git a/tests/unit/clients/sqlalchemy_tests.py b/tests/unit/clients/sqlalchemy_tests.py index 3d655b444..2a2176e76 100644 --- a/tests/unit/clients/sqlalchemy_tests.py +++ b/tests/unit/clients/sqlalchemy_tests.py @@ -1,5 +1,4 @@ import unittest - from unittest import mock try: @@ -7,13 +6,12 @@ except ImportError: raise unittest.SkipTest("sqlalchemy is not installed") -from baseplate.clients.sqlalchemy import engine_from_config -from baseplate.clients.sqlalchemy import SQLAlchemyEngineContextFactory -from baseplate.testing.lib.secrets import FakeSecretsStore - from prometheus_client import REGISTRY from sqlalchemy.pool import QueuePool +from baseplate.clients.sqlalchemy import SQLAlchemyEngineContextFactory, engine_from_config +from baseplate.testing.lib.secrets import FakeSecretsStore + class EngineFromConfigTests(unittest.TestCase): def setUp(self): @@ -96,7 +94,8 @@ def test_report_runtime_metrics_prom_no_queue_pool(self): self.factory.report_runtime_metrics(batch) prom_labels = {"sql_client_name": "factory_name"} - # this serves to prove that we never set these metrics / go down the code path after the isinstance check + # this serves to prove that we never set these metrics / go down the + # code path after the isinstance check self.assertEqual(REGISTRY.get_sample_value("sql_client_pool_max_size", prom_labels), None) self.assertEqual( REGISTRY.get_sample_value("sql_client_pool_client_connections", prom_labels), diff --git 
a/tests/unit/clients/thrift_tests.py b/tests/unit/clients/thrift_tests.py index 4dcd6249a..0120556bb 100644 --- a/tests/unit/clients/thrift_tests.py +++ b/tests/unit/clients/thrift_tests.py @@ -1,25 +1,23 @@ import unittest - from contextlib import nullcontext as does_not_raise from unittest import mock import pytest - from prometheus_client import REGISTRY from thrift.protocol.TProtocol import TProtocolException -from thrift.Thrift import TApplicationException -from thrift.Thrift import TException +from thrift.Thrift import TApplicationException, TException from thrift.transport.TTransport import TTransportException from baseplate.clients import thrift -from baseplate.clients.thrift import _build_thrift_proxy_method -from baseplate.clients.thrift import ACTIVE_REQUESTS -from baseplate.clients.thrift import REQUEST_LATENCY -from baseplate.clients.thrift import REQUESTS_TOTAL -from baseplate.clients.thrift import ThriftContextFactory +from baseplate.clients.thrift import ( + ACTIVE_REQUESTS, + REQUEST_LATENCY, + REQUESTS_TOTAL, + ThriftContextFactory, + _build_thrift_proxy_method, +) from baseplate.thrift import BaseplateServiceV2 -from baseplate.thrift.ttypes import Error -from baseplate.thrift.ttypes import ErrorCode +from baseplate.thrift.ttypes import Error, ErrorCode class EnumerateServiceMethodsTests(unittest.TestCase): @@ -162,7 +160,7 @@ def handle(*args, **kwargs): ) handler.client_cls.return_value = client_cls - thrift_success = str((exc is None)).lower() + thrift_success = str(exc is None).lower() prom_labels = { "thrift_method": "handle", "thrift_client_name": "test_namespace", diff --git a/tests/unit/core_tests.py b/tests/unit/core_tests.py index 6ee2ade1c..1647d4f18 100644 --- a/tests/unit/core_tests.py +++ b/tests/unit/core_tests.py @@ -1,18 +1,19 @@ import unittest - from unittest import mock -from baseplate import Baseplate -from baseplate import BaseplateObserver -from baseplate import LocalSpan -from baseplate import ParentSpanAlreadyFinishedError -from baseplate import RequestContext -from baseplate import ReusedContextObjectError -from baseplate import ServerSpan -from baseplate import ServerSpanObserver -from baseplate import Span -from baseplate import SpanObserver -from baseplate import TraceInfo +from baseplate import ( + Baseplate, + BaseplateObserver, + LocalSpan, + ParentSpanAlreadyFinishedError, + RequestContext, + ReusedContextObjectError, + ServerSpan, + ServerSpanObserver, + Span, + SpanObserver, + TraceInfo, +) from baseplate.clients import ContextFactory from baseplate.lib import config diff --git a/tests/unit/frameworks/pyramid/csrf_tests.py b/tests/unit/frameworks/pyramid/csrf_tests.py index 3f58fe8d9..b7159259b 100644 --- a/tests/unit/frameworks/pyramid/csrf_tests.py +++ b/tests/unit/frameworks/pyramid/csrf_tests.py @@ -1,15 +1,13 @@ import base64 import unittest - from unittest import mock from baseplate.lib.crypto import validate_signature from baseplate.testing.lib.secrets import FakeSecretsStore - has_csrf_policy = True try: - from baseplate.frameworks.pyramid.csrf import _make_csrf_token_payload, TokenCSRFStoragePolicy + from baseplate.frameworks.pyramid.csrf import TokenCSRFStoragePolicy, _make_csrf_token_payload except ImportError: has_csrf_policy = False diff --git a/tests/unit/frameworks/pyramid/http_server_prom_tests.py b/tests/unit/frameworks/pyramid/http_server_prom_tests.py index 553b0761e..813920dfa 100644 --- a/tests/unit/frameworks/pyramid/http_server_prom_tests.py +++ b/tests/unit/frameworks/pyramid/http_server_prom_tests.py @@ -1,20 
+1,20 @@ import types - from contextlib import nullcontext as does_not_raise from unittest import mock import pytest - from prometheus_client import REGISTRY from pyramid.response import Response -from baseplate.frameworks.pyramid import _make_baseplate_tween -from baseplate.frameworks.pyramid import ACTIVE_REQUESTS -from baseplate.frameworks.pyramid import BaseplateConfigurator -from baseplate.frameworks.pyramid import REQUEST_LATENCY -from baseplate.frameworks.pyramid import REQUEST_SIZE -from baseplate.frameworks.pyramid import REQUESTS_TOTAL -from baseplate.frameworks.pyramid import RESPONSE_SIZE +from baseplate.frameworks.pyramid import ( + ACTIVE_REQUESTS, + REQUEST_LATENCY, + REQUEST_SIZE, + REQUESTS_TOTAL, + RESPONSE_SIZE, + BaseplateConfigurator, + _make_baseplate_tween, +) class TestPyramidHttpServerIntegrationPrometheus: diff --git a/tests/unit/frameworks/queue_consumer/kafka_tests.py b/tests/unit/frameworks/queue_consumer/kafka_tests.py index 468032465..f862c4059 100644 --- a/tests/unit/frameworks/queue_consumer/kafka_tests.py +++ b/tests/unit/frameworks/queue_consumer/kafka_tests.py @@ -1,25 +1,23 @@ import socket - from queue import Queue from unittest import mock import confluent_kafka import pytest - from gevent.server import StreamServer from prometheus_client import REGISTRY -from baseplate import Baseplate -from baseplate import RequestContext -from baseplate import ServerSpan -from baseplate.frameworks.queue_consumer.kafka import FastConsumerFactory -from baseplate.frameworks.queue_consumer.kafka import InOrderConsumerFactory -from baseplate.frameworks.queue_consumer.kafka import KAFKA_ACTIVE_MESSAGES -from baseplate.frameworks.queue_consumer.kafka import KAFKA_PROCESSED_TOTAL -from baseplate.frameworks.queue_consumer.kafka import KAFKA_PROCESSING_TIME -from baseplate.frameworks.queue_consumer.kafka import KafkaConsumerPrometheusLabels -from baseplate.frameworks.queue_consumer.kafka import KafkaConsumerWorker -from baseplate.frameworks.queue_consumer.kafka import KafkaMessageHandler +from baseplate import Baseplate, RequestContext, ServerSpan +from baseplate.frameworks.queue_consumer.kafka import ( + KAFKA_ACTIVE_MESSAGES, + KAFKA_PROCESSED_TOTAL, + KAFKA_PROCESSING_TIME, + FastConsumerFactory, + InOrderConsumerFactory, + KafkaConsumerPrometheusLabels, + KafkaConsumerWorker, + KafkaMessageHandler, +) from baseplate.lib import metrics diff --git a/tests/unit/frameworks/queue_consumer/kombu_tests.py b/tests/unit/frameworks/queue_consumer/kombu_tests.py index a68e72e9e..c6a8718cf 100644 --- a/tests/unit/frameworks/queue_consumer/kombu_tests.py +++ b/tests/unit/frameworks/queue_consumer/kombu_tests.py @@ -1,33 +1,30 @@ import socket import time - from queue import Queue from unittest import mock import kombu import pytest - from gevent.server import StreamServer from prometheus_client import REGISTRY -from baseplate import Baseplate -from baseplate import RequestContext -from baseplate import ServerSpan -from baseplate.frameworks.queue_consumer.kombu import AMQP_ACTIVE_MESSAGES -from baseplate.frameworks.queue_consumer.kombu import AMQP_PROCESSED_TOTAL -from baseplate.frameworks.queue_consumer.kombu import AMQP_PROCESSING_TIME -from baseplate.frameworks.queue_consumer.kombu import AMQP_REJECTED_REASON_RETRIES -from baseplate.frameworks.queue_consumer.kombu import AMQP_REJECTED_REASON_TTL -from baseplate.frameworks.queue_consumer.kombu import AMQP_REJECTED_TOTAL -from baseplate.frameworks.queue_consumer.kombu import AMQP_REPUBLISHED_TOTAL -from 
baseplate.frameworks.queue_consumer.kombu import AmqpConsumerPrometheusLabels -from baseplate.frameworks.queue_consumer.kombu import FatalMessageHandlerError -from baseplate.frameworks.queue_consumer.kombu import KombuConsumerWorker -from baseplate.frameworks.queue_consumer.kombu import KombuMessageHandler -from baseplate.frameworks.queue_consumer.kombu import KombuQueueConsumerFactory -from baseplate.frameworks.queue_consumer.kombu import RetryMode -from baseplate.lib.errors import RecoverableException -from baseplate.lib.errors import UnrecoverableException +from baseplate import Baseplate, RequestContext, ServerSpan +from baseplate.frameworks.queue_consumer.kombu import ( + AMQP_ACTIVE_MESSAGES, + AMQP_PROCESSED_TOTAL, + AMQP_PROCESSING_TIME, + AMQP_REJECTED_REASON_RETRIES, + AMQP_REJECTED_REASON_TTL, + AMQP_REJECTED_TOTAL, + AMQP_REPUBLISHED_TOTAL, + AmqpConsumerPrometheusLabels, + FatalMessageHandlerError, + KombuConsumerWorker, + KombuMessageHandler, + KombuQueueConsumerFactory, + RetryMode, +) +from baseplate.lib.errors import RecoverableException, UnrecoverableException from .... import does_not_raise @@ -125,20 +122,14 @@ def test_handle(self, ttl_delta, handled, context, span, baseplate, name, messag message.ack.assert_not_called() message.reject.assert_called_once() - assert ( - REGISTRY.get_sample_value( - f"{AMQP_PROCESSING_TIME._name}_bucket", - {**prom_labels._asdict(), **{"amqp_success": "true", "le": "+Inf"}}, - ) - == (1 if handled else None) - ) - assert ( - REGISTRY.get_sample_value( - f"{AMQP_PROCESSED_TOTAL._name}_total", - {**prom_labels._asdict(), **{"amqp_success": "true"}}, - ) - == (1 if handled else None) - ) + assert REGISTRY.get_sample_value( + f"{AMQP_PROCESSING_TIME._name}_bucket", + {**prom_labels._asdict(), **{"amqp_success": "true", "le": "+Inf"}}, + ) == (1 if handled else None) + assert REGISTRY.get_sample_value( + f"{AMQP_PROCESSED_TOTAL._name}_total", + {**prom_labels._asdict(), **{"amqp_success": "true"}}, + ) == (1 if handled else None) assert ( REGISTRY.get_sample_value( f"{AMQP_REPUBLISHED_TOTAL._name}_total", @@ -153,13 +144,10 @@ def test_handle(self, ttl_delta, handled, context, span, baseplate, name, messag ) is None ) - assert ( - REGISTRY.get_sample_value( - f"{AMQP_REJECTED_TOTAL._name}_total", - {**prom_labels._asdict(), **{"reason_code": AMQP_REJECTED_REASON_TTL}}, - ) - == (None if handled else 1) - ) + assert REGISTRY.get_sample_value( + f"{AMQP_REJECTED_TOTAL._name}_total", + {**prom_labels._asdict(), **{"reason_code": AMQP_REJECTED_REASON_TTL}}, + ) == (None if handled else 1) assert REGISTRY.get_sample_value( f"{AMQP_ACTIVE_MESSAGES._name}", prom_labels._asdict() ) == (0 if handled else None) @@ -246,7 +234,8 @@ def handler_fn(ctx, body, msg): is None ) - # we need to assert that not only the end result is 0, but that we increased and then decreased to that value + # we need to assert that not only the end result is 0, but that + # we increased and then decreased to that value assert mock_manager.mock_calls == [mock.call.inc(), mock.call.dec()] @pytest.mark.parametrize( @@ -311,7 +300,8 @@ def handler_fn(ctx, body, msg): ) == 0 ) - # we need to assert that not only the end result is 0, but that we increased and then decreased to that value + # we need to assert that not only the end result is 0, but that + # we increased and then decreased to that value assert mock_manager.mock_calls == [mock.call.inc(), mock.call.dec()] assert ( @@ -417,13 +407,10 @@ def handler_fn(ctx, body, msg): ) == 0 ) - assert ( - REGISTRY.get_sample_value( - 
f"{AMQP_REPUBLISHED_TOTAL._name}_total", - {**prom_labels._asdict()}, - ) - == (1 if republished else None) - ) + assert REGISTRY.get_sample_value( + f"{AMQP_REPUBLISHED_TOTAL._name}_total", + {**prom_labels._asdict()}, + ) == (1 if republished else None) retry_reached_expectation = None if attempt: if attempt >= 5 or (limit and attempt >= limit): @@ -442,7 +429,8 @@ def handler_fn(ctx, body, msg): ) is None ) - # we need to assert that not only the end result is 0, but that we increased and then decreased to that value + # we need to assert that not only the end result is 0, but that + # we increased and then decreased to that value assert mock_manager.mock_calls == [mock.call.inc(), mock.call.dec()] diff --git a/tests/unit/frameworks/thrift_tests.py b/tests/unit/frameworks/thrift_tests.py index 8e2d704e6..0ec5f8599 100644 --- a/tests/unit/frameworks/thrift_tests.py +++ b/tests/unit/frameworks/thrift_tests.py @@ -2,20 +2,19 @@ from unittest import mock import pytest - from opentelemetry import trace from prometheus_client import REGISTRY from thrift.protocol.TProtocol import TProtocolException -from thrift.Thrift import TApplicationException -from thrift.Thrift import TException +from thrift.Thrift import TApplicationException, TException from thrift.transport.TTransport import TTransportException -from baseplate.frameworks.thrift import _ContextAwareHandler -from baseplate.frameworks.thrift import PROM_ACTIVE -from baseplate.frameworks.thrift import PROM_LATENCY -from baseplate.frameworks.thrift import PROM_REQUESTS -from baseplate.thrift.ttypes import Error -from baseplate.thrift.ttypes import ErrorCode +from baseplate.frameworks.thrift import ( + PROM_ACTIVE, + PROM_LATENCY, + PROM_REQUESTS, + _ContextAwareHandler, +) +from baseplate.thrift.ttypes import Error, ErrorCode class Test_ThriftServerPrometheusMetrics: diff --git a/tests/unit/lib/config_tests.py b/tests/unit/lib/config_tests.py index f0f2eb55f..44a7a1459 100644 --- a/tests/unit/lib/config_tests.py +++ b/tests/unit/lib/config_tests.py @@ -1,7 +1,6 @@ import socket import tempfile import unittest - from unittest.mock import patch from baseplate.lib import config diff --git a/tests/unit/lib/crypto_tests.py b/tests/unit/lib/crypto_tests.py index 6f333e9d8..3f9c3ad42 100644 --- a/tests/unit/lib/crypto_tests.py +++ b/tests/unit/lib/crypto_tests.py @@ -1,5 +1,4 @@ import datetime - from unittest import mock import pytest @@ -7,7 +6,6 @@ from baseplate.lib import crypto from baseplate.lib.secrets import VersionedSecret - TEST_SECRET = VersionedSecret(previous=b"one", current=b"two", next=b"three") MESSAGE = "test message" VALID_TIL_1030 = b"AQAABgQAAOMD6M5zvQU0-GK_uKvPdKH7NOeRAq5Jdlkjwq67BzLt" diff --git a/tests/unit/lib/datetime_tests.py b/tests/unit/lib/datetime_tests.py index db86b4b6a..fee99e196 100644 --- a/tests/unit/lib/datetime_tests.py +++ b/tests/unit/lib/datetime_tests.py @@ -1,16 +1,15 @@ import unittest - -from datetime import datetime -from datetime import timezone +from datetime import datetime, timezone import pytz -from baseplate.lib.datetime import datetime_to_epoch_milliseconds -from baseplate.lib.datetime import datetime_to_epoch_seconds -from baseplate.lib.datetime import epoch_milliseconds_to_datetime -from baseplate.lib.datetime import epoch_seconds_to_datetime -from baseplate.lib.datetime import get_utc_now - +from baseplate.lib.datetime import ( + datetime_to_epoch_milliseconds, + datetime_to_epoch_seconds, + epoch_milliseconds_to_datetime, + epoch_seconds_to_datetime, + get_utc_now, +) EXAMPLE_DATETIME = 
datetime.utcnow().replace(tzinfo=timezone.utc, microsecond=0) diff --git a/tests/unit/lib/events/publisher_tests.py b/tests/unit/lib/events/publisher_tests.py index fc1d84f8b..41f5c4552 100644 --- a/tests/unit/lib/events/publisher_tests.py +++ b/tests/unit/lib/events/publisher_tests.py @@ -1,13 +1,10 @@ import unittest - from unittest import mock import requests -from baseplate.lib import config -from baseplate.lib import metrics -from baseplate.sidecars import event_publisher -from baseplate.sidecars import SerializedBatch +from baseplate.lib import config, metrics +from baseplate.sidecars import SerializedBatch, event_publisher class TimeLimitedBatchTests(unittest.TestCase): diff --git a/tests/unit/lib/events/queue_tests.py b/tests/unit/lib/events/queue_tests.py index 429ace542..644b0810b 100644 --- a/tests/unit/lib/events/queue_tests.py +++ b/tests/unit/lib/events/queue_tests.py @@ -1,13 +1,8 @@ import unittest - from unittest import mock -from baseplate.lib.events import EventQueue -from baseplate.lib.events import EventQueueFullError -from baseplate.lib.events import EventTooLargeError -from baseplate.lib.events import MAX_EVENT_SIZE -from baseplate.lib.message_queue import MessageQueue -from baseplate.lib.message_queue import TimedOutError +from baseplate.lib.events import MAX_EVENT_SIZE, EventQueue, EventQueueFullError, EventTooLargeError +from baseplate.lib.message_queue import MessageQueue, TimedOutError class EventQueueTests(unittest.TestCase): diff --git a/tests/unit/lib/file_watcher_tests.py b/tests/unit/lib/file_watcher_tests.py index 630ef8799..86c691206 100644 --- a/tests/unit/lib/file_watcher_tests.py +++ b/tests/unit/lib/file_watcher_tests.py @@ -3,7 +3,6 @@ import os import tempfile import unittest - from unittest import mock from baseplate.lib import file_watcher diff --git a/tests/unit/lib/metrics_tests.py b/tests/unit/lib/metrics_tests.py index 4d8076e37..9eea30d07 100644 --- a/tests/unit/lib/metrics_tests.py +++ b/tests/unit/lib/metrics_tests.py @@ -1,11 +1,8 @@ import socket import unittest - from unittest import mock -from baseplate.lib import config -from baseplate.lib import metrics - +from baseplate.lib import config, metrics EXAMPLE_ENDPOINT = config.EndpointConfiguration(socket.AF_INET, ("127.0.0.1", 1234)) diff --git a/tests/unit/lib/random_tests.py b/tests/unit/lib/random_tests.py index 0ae9f4257..ce7d57a23 100644 --- a/tests/unit/lib/random_tests.py +++ b/tests/unit/lib/random_tests.py @@ -1,6 +1,5 @@ import collections import unittest - from unittest import mock from baseplate.lib import random diff --git a/tests/unit/lib/ratelimit_tests.py b/tests/unit/lib/ratelimit_tests.py index 9c8f10ecd..a03314745 100644 --- a/tests/unit/lib/ratelimit_tests.py +++ b/tests/unit/lib/ratelimit_tests.py @@ -1,5 +1,4 @@ import unittest - from unittest import mock from pymemcache.client.base import PooledClient diff --git a/tests/unit/lib/retry_tests.py b/tests/unit/lib/retry_tests.py index d90149682..d6e034bdb 100644 --- a/tests/unit/lib/retry_tests.py +++ b/tests/unit/lib/retry_tests.py @@ -1,13 +1,14 @@ import itertools import unittest - from unittest import mock -from baseplate.lib.retry import ExponentialBackoffRetryPolicy -from baseplate.lib.retry import IndefiniteRetryPolicy -from baseplate.lib.retry import MaximumAttemptsRetryPolicy -from baseplate.lib.retry import RetryPolicy -from baseplate.lib.retry import TimeBudgetRetryPolicy +from baseplate.lib.retry import ( + ExponentialBackoffRetryPolicy, + IndefiniteRetryPolicy, + MaximumAttemptsRetryPolicy, + RetryPolicy, 
+ TimeBudgetRetryPolicy, +) class RetryPolicyTests(unittest.TestCase): diff --git a/tests/unit/lib/secrets/store_tests.py b/tests/unit/lib/secrets/store_tests.py index a53ed4623..351dc55e5 100644 --- a/tests/unit/lib/secrets/store_tests.py +++ b/tests/unit/lib/secrets/store_tests.py @@ -1,11 +1,13 @@ import unittest -from baseplate.lib.secrets import CorruptSecretError -from baseplate.lib.secrets import CredentialSecret -from baseplate.lib.secrets import SecretNotFoundError -from baseplate.lib.secrets import secrets_store_from_config -from baseplate.lib.secrets import SecretsNotAvailableError -from baseplate.lib.secrets import SecretsStore +from baseplate.lib.secrets import ( + CorruptSecretError, + CredentialSecret, + SecretNotFoundError, + SecretsNotAvailableError, + SecretsStore, + secrets_store_from_config, +) from baseplate.testing.lib.file_watcher import FakeFileWatcher diff --git a/tests/unit/lib/secrets/vault_csi_tests.py b/tests/unit/lib/secrets/vault_csi_tests.py index a55880022..bbd01cfb0 100644 --- a/tests/unit/lib/secrets/vault_csi_tests.py +++ b/tests/unit/lib/secrets/vault_csi_tests.py @@ -5,24 +5,24 @@ import tempfile import typing import unittest - from pathlib import Path -from unittest.mock import mock_open -from unittest.mock import patch +from unittest.mock import mock_open, patch import gevent import pytest import typing_extensions -from baseplate.lib.secrets import SecretNotFoundError -from baseplate.lib.secrets import secrets_store_from_config -from baseplate.lib.secrets import SecretsStore -from baseplate.lib.secrets import VaultCSISecretsStore +from baseplate.lib.secrets import ( + SecretNotFoundError, + SecretsStore, + VaultCSISecretsStore, + secrets_store_from_config, +) -SecretType: typing_extensions.TypeAlias = typing.Dict[str, any] +SecretType: typing_extensions.TypeAlias = dict[str, any] -def write_secrets(secrets_data_path: Path, data: typing.Dict[str, SecretType]) -> None: +def write_secrets(secrets_data_path: Path, data: dict[str, SecretType]) -> None: """Write secrets to the current data directory.""" for key, value in data.items(): secret_path = secrets_data_path.joinpath(key) @@ -44,7 +44,7 @@ def write_symlinks(data_path: Path) -> None: human_path.symlink_to(csi_path.joinpath("..data/secret")) -def new_fake_csi(data: typing.Dict[str, SecretType]) -> Path: +def new_fake_csi(data: dict[str, SecretType]) -> Path: """Creates a simulated CSI directory with data and symlinks. 
    Note that this would already be configured before the pod starts."""
     csi_dir = Path(tempfile.mkdtemp())
@@ -56,7 +56,7 @@
 
 
 def simulate_secret_update(
-    csi_dir: Path, updated_data: typing.Optional[typing.Dict[str, SecretType]] = None
+    csi_dir: Path, updated_data: typing.Optional[dict[str, SecretType]] = None
 ) -> None:
     """Simulates either TTL expiry / a secret update."""
     old_data_path = csi_dir.joinpath("..data").resolve()
@@ -226,12 +226,12 @@ def test_secret_updated(self):
         expected_username = "".join(chars[:3])
         expected_password = "".join(chars[3:])
         new_secrets = EXAMPLE_UPDATED_SECRETS.copy()
-        new_secrets["secret/example-service/example-secret"]["data"][
-            "username"
-        ] = expected_username
-        new_secrets["secret/example-service/example-secret"]["data"][
-            "password"
-        ] = expected_password
+        new_secrets["secret/example-service/example-secret"]["data"]["username"] = (
+            expected_username
+        )
+        new_secrets["secret/example-service/example-secret"]["data"]["password"] = (
+            expected_password
+        )
         simulate_secret_update(
             self.csi_dir,
             updated_data=EXAMPLE_UPDATED_SECRETS,
diff --git a/tests/unit/lib/service_discovery_tests.py b/tests/unit/lib/service_discovery_tests.py
index f7ffe8637..e67f2f440 100644
--- a/tests/unit/lib/service_discovery_tests.py
+++ b/tests/unit/lib/service_discovery_tests.py
@@ -1,12 +1,9 @@
 import unittest
-
 from io import StringIO
 from unittest import mock
 
 from baseplate.lib import service_discovery
-from baseplate.lib.file_watcher import FileWatcher
-from baseplate.lib.file_watcher import WatchedFileNotAvailableError
-
+from baseplate.lib.file_watcher import FileWatcher, WatchedFileNotAvailableError
 
 
 TEST_INVENTORY_ONE = """\
diff --git a/tests/unit/lib/thrift_pool_tests.py b/tests/unit/lib/thrift_pool_tests.py
index 205a7e973..207750221 100644
--- a/tests/unit/lib/thrift_pool_tests.py
+++ b/tests/unit/lib/thrift_pool_tests.py
@@ -1,21 +1,15 @@
 import queue
 import socket
 import unittest
-
 from unittest import mock
 
-from thrift.protocol import TBinaryProtocol
-from thrift.protocol import THeaderProtocol
+from thrift.protocol import TBinaryProtocol, THeaderProtocol
 from thrift.Thrift import TException
-from thrift.transport import THeaderTransport
-from thrift.transport import TSocket
-from thrift.transport import TTransport
+from thrift.transport import THeaderTransport, TSocket, TTransport
 
-from baseplate.lib import config
-from baseplate.lib import thrift_pool
+from baseplate.lib import config, thrift_pool
 from baseplate.observers.timeout import ServerTimeout
 
-
 EXAMPLE_ENDPOINT = config.EndpointConfiguration(socket.AF_INET, ("127.0.0.1", 1234))
 
 
diff --git a/tests/unit/observers/metrics_tagged_tests.py b/tests/unit/observers/metrics_tagged_tests.py
index 3ce916602..b030725f4 100644
--- a/tests/unit/observers/metrics_tagged_tests.py
+++ b/tests/unit/observers/metrics_tagged_tests.py
@@ -1,23 +1,17 @@
 from __future__ import annotations
 
 import time
-
 from typing import Any
-from typing import Dict
-from typing import Optional
 
 import pytest
 
-from baseplate import RequestContext
-from baseplate import ServerSpan
-from baseplate import Span
-from baseplate.lib.metrics import Counter
-from baseplate.lib.metrics import Gauge
-from baseplate.lib.metrics import Histogram
-from baseplate.lib.metrics import Timer
-from baseplate.observers.metrics_tagged import TaggedMetricsClientSpanObserver
-from baseplate.observers.metrics_tagged import TaggedMetricsLocalSpanObserver
-from baseplate.observers.metrics_tagged import TaggedMetricsServerSpanObserver
+from baseplate import RequestContext, ServerSpan, Span
+from baseplate.lib.metrics import Counter, Gauge, Histogram, Timer
+from baseplate.observers.metrics_tagged import (
+    TaggedMetricsClientSpanObserver,
+    TaggedMetricsLocalSpanObserver,
+    TaggedMetricsServerSpanObserver,
+)
 
 
 class TestException(Exception):
@@ -25,12 +19,12 @@
 
 
 class FakeTimer:
-    def __init__(self, batch: FakeBatch, name: str, tags: Dict[str, Any]):
+    def __init__(self, batch: FakeBatch, name: str, tags: dict[str, Any]):
         self.batch = batch
         self.name = name
         self.tags = tags
 
-        self.start_time: Optional[float] = None
+        self.start_time: float | None = None
         self.sample_rate: float = 1.0
 
     def start(self, sample_rate: float = 1.0) -> None:
@@ -52,12 +46,12 @@ def send(self, elapsed: float, sample_rate: float = 1.0) -> None:
             {"name": self.name, "elapsed": elapsed, "sample_rate": sample_rate, "tags": self.tags}
         )
 
-    def update_tags(self, tags: Dict[str, Any]) -> None:
+    def update_tags(self, tags: dict[str, Any]) -> None:
         self.tags.update(tags)
 
 
 class FakeCounter:
-    def __init__(self, batch: FakeBatch, name: str, tags: Dict[str, Any]):
+    def __init__(self, batch: FakeBatch, name: str, tags: dict[str, Any]):
         self.batch = batch
         self.name = name
         self.tags = tags
@@ -80,16 +74,16 @@ def __init__(self):
         self.counters = []
         self.flushed = False
 
-    def timer(self, name: str, tags: Optional[Dict[str, Any]] = None) -> Timer:
+    def timer(self, name: str, tags: dict[str, Any] | None = None) -> Timer:
         return FakeTimer(self, name, tags or {})
 
-    def counter(self, name: str, tags: Optional[Dict[str, Any]] = None) -> Counter:
+    def counter(self, name: str, tags: dict[str, Any] | None = None) -> Counter:
         return FakeCounter(self, name, tags or {})
 
-    def gauge(self, name: str, tags: Optional[Dict[str, Any]] = None) -> Gauge:
+    def gauge(self, name: str, tags: dict[str, Any] | None = None) -> Gauge:
         raise NotImplementedError
 
-    def histogram(self, name: str, tags: Optional[Dict[str, Any]] = None) -> Histogram:
+    def histogram(self, name: str, tags: dict[str, Any] | None = None) -> Histogram:
         raise NotImplementedError
 
     def flush(self):
diff --git a/tests/unit/observers/metrics_tests.py b/tests/unit/observers/metrics_tests.py
index 5b2ec0914..f65914e2d 100644
--- a/tests/unit/observers/metrics_tests.py
+++ b/tests/unit/observers/metrics_tests.py
@@ -1,18 +1,14 @@
 import unittest
-
 from unittest import mock
 
-from baseplate import LocalSpan
-from baseplate import ServerSpan
-from baseplate import Span
-from baseplate.lib.metrics import Batch
-from baseplate.lib.metrics import Client
-from baseplate.lib.metrics import Counter
-from baseplate.lib.metrics import Timer
-from baseplate.observers.metrics import MetricsBaseplateObserver
-from baseplate.observers.metrics import MetricsClientSpanObserver
-from baseplate.observers.metrics import MetricsLocalSpanObserver
-from baseplate.observers.metrics import MetricsServerSpanObserver
+from baseplate import LocalSpan, ServerSpan, Span
+from baseplate.lib.metrics import Batch, Client, Counter, Timer
+from baseplate.observers.metrics import (
+    MetricsBaseplateObserver,
+    MetricsClientSpanObserver,
+    MetricsLocalSpanObserver,
+    MetricsServerSpanObserver,
+)
 
 
 class TestException(Exception):
diff --git a/tests/unit/observers/sentry_tests.py b/tests/unit/observers/sentry_tests.py
index e1f0b726b..c65bc0ff1 100644
--- a/tests/unit/observers/sentry_tests.py
+++ b/tests/unit/observers/sentry_tests.py
@@ -1,21 +1,22 @@
 from typing import Any
-from typing import Dict
 
 import gevent
 import pytest
 import sentry_sdk
 
 from baseplate import Baseplate
-from baseplate.observers.sentry import _SentryUnhandledErrorReporter
-from baseplate.observers.sentry import init_sentry_client_from_config
-from baseplate.observers.sentry import SentryBaseplateObserver
+from baseplate.observers.sentry import (
+    SentryBaseplateObserver,
+    _SentryUnhandledErrorReporter,
+    init_sentry_client_from_config,
+)
 
 
 class FakeTransport:
     def __init__(self):
         self.events = []
 
-    def __call__(self, event: Dict[str, Any]) -> None:
+    def __call__(self, event: dict[str, Any]) -> None:
         self.events.append(event)
 
 
diff --git a/tests/unit/observers/tracing/publisher_tests.py b/tests/unit/observers/tracing/publisher_tests.py
index f2f53bedc..2a01813fb 100644
--- a/tests/unit/observers/tracing/publisher_tests.py
+++ b/tests/unit/observers/tracing/publisher_tests.py
@@ -1,12 +1,10 @@
 import unittest
-
 from unittest import mock
 
 import requests
 
 from baseplate.lib import metrics
-from baseplate.sidecars import SerializedBatch
-from baseplate.sidecars import trace_publisher
+from baseplate.sidecars import SerializedBatch, trace_publisher
 
 
 class ZipkinPublisherTest(unittest.TestCase):
diff --git a/tests/unit/observers/tracing_tests.py b/tests/unit/observers/tracing_tests.py
index 37201ff18..5124821df 100644
--- a/tests/unit/observers/tracing_tests.py
+++ b/tests/unit/observers/tracing_tests.py
@@ -1,20 +1,20 @@
 import json
 import unittest
-
 from unittest import mock
 
-from baseplate import ServerSpan
-from baseplate import Span
+from baseplate import ServerSpan, Span
 from baseplate.lib.config import Endpoint
-from baseplate.observers.tracing import ANNOTATIONS
-from baseplate.observers.tracing import LoggingRecorder
-from baseplate.observers.tracing import make_client
-from baseplate.observers.tracing import NullRecorder
-from baseplate.observers.tracing import RemoteRecorder
-from baseplate.observers.tracing import TraceBaseplateObserver
-from baseplate.observers.tracing import TraceLocalSpanObserver
-from baseplate.observers.tracing import TraceServerSpanObserver
-from baseplate.observers.tracing import TraceSpanObserver
+from baseplate.observers.tracing import (
+    ANNOTATIONS,
+    LoggingRecorder,
+    NullRecorder,
+    RemoteRecorder,
+    TraceBaseplateObserver,
+    TraceLocalSpanObserver,
+    TraceServerSpanObserver,
+    TraceSpanObserver,
+    make_client,
+)
 
 
 class TraceTestBase(unittest.TestCase):
@@ -140,7 +140,6 @@ def test_component_set_on_initialization(self):
         self.assertTrue(component_set)
 
     def test_debug_span_tag_set_on_initialization(self):
-
         for annotation in self.test_debug_span_observer.binary_annotations:
             if annotation["key"] == ANNOTATIONS["DEBUG"]:
                 self.assertTrue(annotation["value"])
diff --git a/tests/unit/server/einhorn_tests.py b/tests/unit/server/einhorn_tests.py
index 5ac09b200..b746ba4de 100644
--- a/tests/unit/server/einhorn_tests.py
+++ b/tests/unit/server/einhorn_tests.py
@@ -1,6 +1,5 @@
 import socket
 import unittest
-
 from unittest import mock
 
 from baseplate.server import einhorn
diff --git a/tests/unit/server/monkey_tests.py b/tests/unit/server/monkey_tests.py
index 8833b36f3..e5012d21a 100644
--- a/tests/unit/server/monkey_tests.py
+++ b/tests/unit/server/monkey_tests.py
@@ -5,8 +5,7 @@
 import gevent.monkey
 import gevent.queue
 
-from baseplate.server.monkey import gevent_is_patched
-from baseplate.server.monkey import patch_stdlib_queues
+from baseplate.server.monkey import gevent_is_patched, patch_stdlib_queues
 
 
 class MonkeyPatchTests(unittest.TestCase):
diff --git a/tests/unit/server/queue_consumer_tests.py b/tests/unit/server/queue_consumer_tests.py
index 3096c030f..83bcec404 100644
--- a/tests/unit/server/queue_consumer_tests.py
+++ b/tests/unit/server/queue_consumer_tests.py
@@ -3,7 +3,6 @@
 import os
 import socket
 import time
-
 from queue import Empty as QueueEmpty
 from queue import Queue
 from threading import Thread
@@ -11,17 +10,17 @@
 import pytest
 import webtest
-
 from gevent.server import StreamServer
 
 from baseplate.observers.timeout import ServerTimeout
-from baseplate.server.queue_consumer import HealthcheckApp
-from baseplate.server.queue_consumer import MessageHandler
-from baseplate.server.queue_consumer import PumpWorker
-from baseplate.server.queue_consumer import QueueConsumer
-from baseplate.server.queue_consumer import QueueConsumerFactory
-from baseplate.server.queue_consumer import QueueConsumerServer
-
+from baseplate.server.queue_consumer import (
+    HealthcheckApp,
+    MessageHandler,
+    PumpWorker,
+    QueueConsumer,
+    QueueConsumerFactory,
+    QueueConsumerServer,
+)
 
 pytestmark = pytest.mark.skipif(
     "CI" not in os.environ, reason="tests takes too long to run for normal local iteration"
diff --git a/tests/unit/server/server_tests.py b/tests/unit/server/server_tests.py
index 07dc7bf98..d32f18161 100644
--- a/tests/unit/server/server_tests.py
+++ b/tests/unit/server/server_tests.py
@@ -2,7 +2,6 @@
 import socket
 import sys
 import unittest
-
 from unittest import mock
 
 import pytest
@@ -10,7 +9,6 @@
 from baseplate import server
 from baseplate.lib import config
 
-
 EXAMPLE_ENDPOINT = config.EndpointConfiguration(socket.AF_INET, ("127.0.0.1", 1234))
 
 
diff --git a/tests/unit/sidecars/live_data_watcher_loader_tests.py b/tests/unit/sidecars/live_data_watcher_loader_tests.py
index f8b034fe2..91ca0e600 100644
--- a/tests/unit/sidecars/live_data_watcher_loader_tests.py
+++ b/tests/unit/sidecars/live_data_watcher_loader_tests.py
@@ -1,21 +1,21 @@
 import io
 import json
 import os
-
 from unittest import mock
 
 import botocore.session
 import pytest
-
 from botocore.response import StreamingBody
 from botocore.stub import Stubber
 from moto import mock_aws
 
-from baseplate.sidecars.live_data_watcher import _load_from_s3
-from baseplate.sidecars.live_data_watcher import _parse_loader_type
-from baseplate.sidecars.live_data_watcher import LoaderException
-from baseplate.sidecars.live_data_watcher import LoaderType
-from baseplate.sidecars.live_data_watcher import NodeWatcher
+from baseplate.sidecars.live_data_watcher import (
+    LoaderException,
+    LoaderType,
+    NodeWatcher,
+    _load_from_s3,
+    _parse_loader_type,
+)
 
 
 @pytest.fixture()
diff --git a/tests/unit/sidecars/live_data_watcher_tests.py b/tests/unit/sidecars/live_data_watcher_tests.py
index e1e984b9b..cac614936 100644
--- a/tests/unit/sidecars/live_data_watcher_tests.py
+++ b/tests/unit/sidecars/live_data_watcher_tests.py
@@ -5,15 +5,12 @@
 import pwd
 import tempfile
 import unittest
-
 from pathlib import Path
 
 import boto3
-
 from moto import mock_aws
 
-from baseplate.sidecars.live_data_watcher import _generate_sharded_file_key
-from baseplate.sidecars.live_data_watcher import NodeWatcher
+from baseplate.sidecars.live_data_watcher import NodeWatcher, _generate_sharded_file_key
 
 
 NUM_FILE_SHARDS = 6
@@ -87,7 +84,7 @@ def test_s3_load_type_on_change_no_sharding(self):
         dest = self.output_dir.joinpath("data.txt")
         inst = NodeWatcher(str(dest), os.getuid(), os.getgid(), 777)
 
-        new_content = b'{"live_data_watcher_load_type":"S3","bucket_name":"test_bucket","file_key":"test_file_key","sse_key":"test_decryption_key","region_name":"us-east-1"}'
b'{"live_data_watcher_load_type":"S3","bucket_name":"test_bucket","file_key":"test_file_key","sse_key":"test_decryption_key","region_name":"us-east-1"}' + new_content = b'{"live_data_watcher_load_type":"S3","bucket_name":"test_bucket","file_key":"test_file_key","sse_key":"test_decryption_key","region_name":"us-east-1"}' # noqa: E501 expected_content = b'{"foo_encrypted": "bar_encrypted"}' inst.on_change(new_content, None) self.assertEqual(expected_content, dest.read_bytes()) @@ -98,7 +95,7 @@ def test_s3_load_type_on_change_sharding(self): dest = self.output_dir.joinpath("data.txt") inst = NodeWatcher(str(dest), os.getuid(), os.getgid(), 777) - new_content = b'{"live_data_watcher_load_type":"S3","bucket_name":"test_bucket","file_key":"test_file_key","sse_key":"test_decryption_key","region_name":"us-east-1", "num_file_shards": 5}' + new_content = b'{"live_data_watcher_load_type":"S3","bucket_name":"test_bucket","file_key":"test_file_key","sse_key":"test_decryption_key","region_name":"us-east-1", "num_file_shards": 5}' # noqa: E501 expected_content = b'{"foo_encrypted": "bar_encrypted"}' # For safe measure, run this 50 times. It should succeed every time.