Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add type hints to psycopg #3067

Merged
merged 17 commits into from
Jan 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/nitpick-exceptions.ini
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ py-class=
callable
Consumer
confluent_kafka.Message
psycopg.Connection
psycopg.AsyncConnection
ObjectProxy

any=
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,27 +137,28 @@
---
"""

from __future__ import annotations

import logging
import typing
from typing import Collection
from typing import Any, Callable, Collection, TypeVar

import psycopg # pylint: disable=import-self
from psycopg import (
AsyncCursor as pg_async_cursor, # pylint: disable=import-self,no-name-in-module
)
from psycopg import (
Cursor as pg_cursor, # pylint: disable=no-name-in-module,import-self
)
from psycopg.sql import Composed # pylint: disable=no-name-in-module

from opentelemetry.instrumentation import dbapi
from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
from opentelemetry.instrumentation.psycopg.package import _instruments
from opentelemetry.instrumentation.psycopg.version import __version__
from opentelemetry.trace import TracerProvider

_logger = logging.getLogger(__name__)
_OTEL_CURSOR_FACTORY_KEY = "_otel_orig_cursor_factory"

ConnectionT = TypeVar(
"ConnectionT", psycopg.Connection, psycopg.AsyncConnection
)
CursorT = TypeVar("CursorT", psycopg.Cursor, psycopg.AsyncCursor)
aabmass marked this conversation as resolved.
Show resolved Hide resolved


class PsycopgInstrumentor(BaseInstrumentor):
_CONNECTION_ATTRIBUTES = {
Expand All @@ -172,7 +173,7 @@ class PsycopgInstrumentor(BaseInstrumentor):
def instrumentation_dependencies(self) -> Collection[str]:
return _instruments

def _instrument(self, **kwargs):
def _instrument(self, **kwargs: Any):
"""Integrate with PostgreSQL Psycopg library.
Psycopg: http://initd.org/psycopg/
"""
Expand Down Expand Up @@ -223,7 +224,7 @@ def _instrument(self, **kwargs):
enable_attribute_commenter=enable_attribute_commenter,
)

def _uninstrument(self, **kwargs):
def _uninstrument(self, **kwargs: Any):
""" "Disable Psycopg instrumentation"""
dbapi.unwrap_connect(psycopg, "connect") # pylint: disable=no-member
dbapi.unwrap_connect(
Expand All @@ -237,7 +238,9 @@ def _uninstrument(self, **kwargs):

# TODO(owais): check if core dbapi can do this for all dbapi implementations e.g, pymysql and mysql
@staticmethod
def instrument_connection(connection, tracer_provider=None):
def instrument_connection(
connection: ConnectionT, tracer_provider: TracerProvider | None = None
) -> ConnectionT:
"""Enable instrumentation in a psycopg connection.

Args:
Expand Down Expand Up @@ -269,7 +272,7 @@ def instrument_connection(connection, tracer_provider=None):

# TODO(owais): check if core dbapi can do this for all dbapi implementations e.g, pymysql and mysql
@staticmethod
def uninstrument_connection(connection):
def uninstrument_connection(connection: ConnectionT) -> ConnectionT:
connection.cursor_factory = getattr(
connection, _OTEL_CURSOR_FACTORY_KEY, None
)
Expand All @@ -281,9 +284,9 @@ def uninstrument_connection(connection):
class DatabaseApiIntegration(dbapi.DatabaseApiIntegration):
def wrapped_connection(
self,
connect_method: typing.Callable[..., typing.Any],
args: typing.Tuple[typing.Any, typing.Any],
kwargs: typing.Dict[typing.Any, typing.Any],
connect_method: Callable[..., Any],
args: tuple[Any, Any],
kwargs: dict[Any, Any],
):
"""Add object proxy to connection object."""
base_cursor_factory = kwargs.pop("cursor_factory", None)
Expand All @@ -299,9 +302,9 @@ def wrapped_connection(
class DatabaseApiAsyncIntegration(dbapi.DatabaseApiIntegration):
async def wrapped_connection(
self,
connect_method: typing.Callable[..., typing.Any],
args: typing.Tuple[typing.Any, typing.Any],
kwargs: typing.Dict[typing.Any, typing.Any],
connect_method: Callable[..., Any],
args: tuple[Any, Any],
kwargs: dict[Any, Any],
):
"""Add object proxy to connection object."""
base_cursor_factory = kwargs.pop("cursor_factory", None)
Expand All @@ -317,7 +320,7 @@ async def wrapped_connection(


class CursorTracer(dbapi.CursorTracer):
def get_operation_name(self, cursor, args):
def get_operation_name(self, cursor: CursorT, args: list[Any]) -> str:
if not args:
return ""

Expand All @@ -332,7 +335,7 @@ def get_operation_name(self, cursor, args):

return ""

def get_statement(self, cursor, args):
def get_statement(self, cursor: CursorT, args: list[Any]) -> str:
if not args:
return ""

Expand All @@ -342,7 +345,11 @@ def get_statement(self, cursor, args):
return statement


def _new_cursor_factory(db_api=None, base_factory=None, tracer_provider=None):
def _new_cursor_factory(
db_api: DatabaseApiIntegration | None = None,
base_factory: type[psycopg.Cursor] | None = None,
tracer_provider: TracerProvider | None = None,
):
if not db_api:
db_api = DatabaseApiIntegration(
__name__,
Expand All @@ -352,21 +359,21 @@ def _new_cursor_factory(db_api=None, base_factory=None, tracer_provider=None):
tracer_provider=tracer_provider,
)

base_factory = base_factory or pg_cursor
base_factory = base_factory or psycopg.Cursor
_cursor_tracer = CursorTracer(db_api)

class TracedCursorFactory(base_factory):
def execute(self, *args, **kwargs):
def execute(self, *args: Any, **kwargs: Any):
return _cursor_tracer.traced_execution(
self, super().execute, *args, **kwargs
)

def executemany(self, *args, **kwargs):
def executemany(self, *args: Any, **kwargs: Any):
return _cursor_tracer.traced_execution(
self, super().executemany, *args, **kwargs
)

def callproc(self, *args, **kwargs):
def callproc(self, *args: Any, **kwargs: Any):
return _cursor_tracer.traced_execution(
self, super().callproc, *args, **kwargs
)
Expand All @@ -375,7 +382,9 @@ def callproc(self, *args, **kwargs):


def _new_cursor_async_factory(
db_api=None, base_factory=None, tracer_provider=None
db_api: DatabaseApiAsyncIntegration | None = None,
base_factory: type[psycopg.AsyncCursor] | None = None,
tracer_provider: TracerProvider | None = None,
):
if not db_api:
db_api = DatabaseApiAsyncIntegration(
Expand All @@ -385,21 +394,21 @@ def _new_cursor_async_factory(
version=__version__,
tracer_provider=tracer_provider,
)
base_factory = base_factory or pg_async_cursor
base_factory = base_factory or psycopg.AsyncCursor
_cursor_tracer = CursorTracer(db_api)

class TracedCursorAsyncFactory(base_factory):
async def execute(self, *args, **kwargs):
async def execute(self, *args: Any, **kwargs: Any):
return await _cursor_tracer.traced_execution(
self, super().execute, *args, **kwargs
)

async def executemany(self, *args, **kwargs):
async def executemany(self, *args: Any, **kwargs: Any):
return await _cursor_tracer.traced_execution(
self, super().executemany, *args, **kwargs
)

async def callproc(self, *args, **kwargs):
async def callproc(self, *args: Any, **kwargs: Any):
return await _cursor_tracer.traced_execution(
self, super().callproc, *args, **kwargs
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations


_instruments = ("psycopg >= 3.1.0",)
_instruments: tuple[str, ...] = ("psycopg >= 3.1.0",)
Original file line number Diff line number Diff line change
Expand Up @@ -132,10 +132,10 @@ class PostgresqlIntegrationTestMixin:
def setUp(self):
super().setUp()
self.cursor_mock = mock.patch(
"opentelemetry.instrumentation.psycopg.pg_cursor", MockCursor
"opentelemetry.instrumentation.psycopg.psycopg.Cursor", MockCursor
)
self.cursor_async_mock = mock.patch(
"opentelemetry.instrumentation.psycopg.pg_async_cursor",
"opentelemetry.instrumentation.psycopg.psycopg.AsyncCursor",
MockAsyncCursor,
)
self.connection_mock = mock.patch("psycopg.connect", MockConnection)
Expand Down
Loading