From 2c291bf4b1690b7dcbd26454634c7566d234b767 Mon Sep 17 00:00:00 2001 From: OlegZv <19969581+OlegZv@users.noreply.github.com> Date: Sat, 21 Dec 2024 23:48:30 -0500 Subject: [PATCH 1/8] Rough draft of the contribution --- .../instrumentation/redis/__init__.py | 327 +++++++++++++----- .../tests/test_redis.py | 125 ++++++- 2 files changed, 346 insertions(+), 106 deletions(-) diff --git a/instrumentation/opentelemetry-instrumentation-redis/src/opentelemetry/instrumentation/redis/__init__.py b/instrumentation/opentelemetry-instrumentation-redis/src/opentelemetry/instrumentation/redis/__init__.py index 8a3096ad41..bc23902925 100644 --- a/instrumentation/opentelemetry-instrumentation-redis/src/opentelemetry/instrumentation/redis/__init__.py +++ b/instrumentation/opentelemetry-instrumentation-redis/src/opentelemetry/instrumentation/redis/__init__.py @@ -184,6 +184,62 @@ def _build_span_name( return name +def _add_create_attributes(span: Span, args: tuple[Any, ...]): + _set_span_attribute_if_value( + span, "redis.create_index.index", _value_or_none(args, 1) + ) + # According to: https://github.com/redis/redis-py/blob/master/redis/commands/search/commands.py#L155 schema is last argument for execute command + try: + schema_index = args.index("SCHEMA") + except ValueError: + return + schema = args[schema_index:] + field_attribute = "" + # Schema in format: + # [first_field_name, first_field_type, first_field_some_attribute1, first_field_some_attribute2, second_field_name, ...] + field_attribute = "".join( + f"Field(name: {schema[index - 1]}, type: {schema[index]});" + for index in range(1, len(schema)) + if schema[index] in _FIELD_TYPES + ) + _set_span_attribute_if_value( + span, + "redis.create_index.fields", + field_attribute, + ) + + +def _add_search_attributes(span: Span, response, args): + _set_span_attribute_if_value( + span, "redis.search.index", _value_or_none(args, 1) + ) + _set_span_attribute_if_value( + span, "redis.search.query", _value_or_none(args, 2) + ) + # Parse response from search + # https://redis.io/docs/latest/commands/ft.search/ + # Response in format: + # [number_of_returned_documents, index_of_first_returned_doc, first_doc(as a list), index_of_second_returned_doc, second_doc(as a list) ...] + # Returned documents in array format: + # [first_field_name, first_field_value, second_field_name, second_field_value ...] + number_of_returned_documents = _value_or_none(response, 0) + _set_span_attribute_if_value( + span, "redis.search.total", number_of_returned_documents + ) + if "NOCONTENT" in args or not number_of_returned_documents: + return + for document_number in range(number_of_returned_documents): + document_index = _value_or_none(response, 1 + 2 * document_number) + if document_index: + document = response[2 + 2 * document_number] + for attribute_name_index in range(0, len(document), 2): + _set_span_attribute_if_value( + span, + f"redis.search.xdoc_{document_index}.{document[attribute_name_index]}", + document[attribute_name_index + 1], + ) + + def _build_span_meta_data_for_pipeline( instance: PipelineInstance | AsyncPipelineInstance, ) -> tuple[list[Any], str, str]: @@ -214,11 +270,10 @@ def _build_span_meta_data_for_pipeline( return command_stack, resource, span_name -# pylint: disable=R0915 -def _instrument( - tracer: Tracer, - request_hook: _RequestHookT | None = None, - response_hook: _ResponseHookT | None = None, +def _traced_execute_factory( + tracer, + request_hook: _RequestHookT = None, + response_hook: _ResponseHookT = None, ): def _traced_execute_command( func: Callable[..., R], @@ -247,6 +302,14 @@ def _traced_execute_command( response_hook(span, instance, response) return response + return _traced_execute_command + + +def _traced_execute_pipeline_factory( + tracer, + request_hook: _RequestHookT = None, + response_hook: _ResponseHookT = None, +): def _traced_execute_pipeline( func: Callable[..., R], instance: PipelineInstance, @@ -284,90 +347,14 @@ def _traced_execute_pipeline( return response - def _add_create_attributes(span: Span, args: tuple[Any, ...]): - _set_span_attribute_if_value( - span, "redis.create_index.index", _value_or_none(args, 1) - ) - # According to: https://github.com/redis/redis-py/blob/master/redis/commands/search/commands.py#L155 schema is last argument for execute command - try: - schema_index = args.index("SCHEMA") - except ValueError: - return - schema = args[schema_index:] - field_attribute = "" - # Schema in format: - # [first_field_name, first_field_type, first_field_some_attribute1, first_field_some_attribute2, second_field_name, ...] - field_attribute = "".join( - f"Field(name: {schema[index - 1]}, type: {schema[index]});" - for index in range(1, len(schema)) - if schema[index] in _FIELD_TYPES - ) - _set_span_attribute_if_value( - span, - "redis.create_index.fields", - field_attribute, - ) - - def _add_search_attributes(span: Span, response, args): - _set_span_attribute_if_value( - span, "redis.search.index", _value_or_none(args, 1) - ) - _set_span_attribute_if_value( - span, "redis.search.query", _value_or_none(args, 2) - ) - # Parse response from search - # https://redis.io/docs/latest/commands/ft.search/ - # Response in format: - # [number_of_returned_documents, index_of_first_returned_doc, first_doc(as a list), index_of_second_returned_doc, second_doc(as a list) ...] - # Returned documents in array format: - # [first_field_name, first_field_value, second_field_name, second_field_value ...] - number_of_returned_documents = _value_or_none(response, 0) - _set_span_attribute_if_value( - span, "redis.search.total", number_of_returned_documents - ) - if "NOCONTENT" in args or not number_of_returned_documents: - return - for document_number in range(number_of_returned_documents): - document_index = _value_or_none(response, 1 + 2 * document_number) - if document_index: - document = response[2 + 2 * document_number] - for attribute_name_index in range(0, len(document), 2): - _set_span_attribute_if_value( - span, - f"redis.search.xdoc_{document_index}.{document[attribute_name_index]}", - document[attribute_name_index + 1], - ) + return _traced_execute_pipeline - pipeline_class = ( - "BasePipeline" if redis.VERSION < (3, 0, 0) else "Pipeline" - ) - redis_class = "StrictRedis" if redis.VERSION < (3, 0, 0) else "Redis" - - wrap_function_wrapper( - "redis", f"{redis_class}.execute_command", _traced_execute_command - ) - wrap_function_wrapper( - "redis.client", - f"{pipeline_class}.execute", - _traced_execute_pipeline, - ) - wrap_function_wrapper( - "redis.client", - f"{pipeline_class}.immediate_execute_command", - _traced_execute_command, - ) - if redis.VERSION >= _REDIS_CLUSTER_VERSION: - wrap_function_wrapper( - "redis.cluster", - "RedisCluster.execute_command", - _traced_execute_command, - ) - wrap_function_wrapper( - "redis.cluster", - "ClusterPipeline.execute", - _traced_execute_pipeline, - ) +def _async_traced_execute_factory( + tracer, + request_hook: _RequestHookT = None, + response_hook: _ResponseHookT = None, +): async def _async_traced_execute_command( func: Callable[..., Awaitable[R]], instance: AsyncRedisInstance, @@ -391,6 +378,14 @@ async def _async_traced_execute_command( response_hook(span, instance, response) return response + return _async_traced_execute_command + + +def _async_traced_execute_pipeline_factory( + tracer, + request_hook: _RequestHookT = None, + response_hook: _ResponseHookT = None, +): async def _async_traced_execute_pipeline( func: Callable[..., Awaitable[R]], instance: AsyncPipelineInstance, @@ -430,6 +425,57 @@ async def _async_traced_execute_pipeline( return response + return _async_traced_execute_pipeline + + +# pylint: disable=R0915 +def _instrument( + tracer: Tracer, + request_hook: _RequestHookT | None = None, + response_hook: _ResponseHookT | None = None, +): + _traced_execute_command = _traced_execute_factory( + tracer, request_hook, response_hook + ) + _traced_execute_pipeline = _traced_execute_pipeline_factory( + tracer, request_hook, response_hook + ) + pipeline_class = ( + "BasePipeline" if redis.VERSION < (3, 0, 0) else "Pipeline" + ) + redis_class = "StrictRedis" if redis.VERSION < (3, 0, 0) else "Redis" + + wrap_function_wrapper( + "redis", f"{redis_class}.execute_command", _traced_execute_command + ) + wrap_function_wrapper( + "redis.client", + f"{pipeline_class}.execute", + _traced_execute_pipeline, + ) + wrap_function_wrapper( + "redis.client", + f"{pipeline_class}.immediate_execute_command", + _traced_execute_command, + ) + if redis.VERSION >= _REDIS_CLUSTER_VERSION: + wrap_function_wrapper( + "redis.cluster", + "RedisCluster.execute_command", + _traced_execute_command, + ) + wrap_function_wrapper( + "redis.cluster", + "ClusterPipeline.execute", + _traced_execute_pipeline, + ) + + _async_traced_execute_command = _async_traced_execute_factory( + tracer, request_hook, response_hook + ) + _async_traced_execute_pipeline = _async_traced_execute_pipeline_factory( + tracer, request_hook, response_hook + ) if redis.VERSION >= _REDIS_ASYNCIO_VERSION: wrap_function_wrapper( "redis.asyncio", @@ -459,6 +505,94 @@ async def _async_traced_execute_pipeline( ) +def _instrument_client( + client, + tracer, + request_hook: _RequestHookT = None, + response_hook: _ResponseHookT = None, +): + # first, handle async clients + _async_traced_execute_command = _async_traced_execute_factory( + tracer, request_hook, response_hook + ) + _async_traced_execute_pipeline = _async_traced_execute_pipeline_factory( + tracer, request_hook, response_hook + ) + + def _async_pipeline_wrapper(func, instance, args, kwargs): + result = func(*args, **kwargs) + wrap_function_wrapper( + result, "execute", _async_traced_execute_pipeline + ) + wrap_function_wrapper( + result, "immediate_execute_command", _async_traced_execute_command + ) + return result + + if redis.VERSION >= _REDIS_ASYNCIO_VERSION: + client_type = ( + redis.asyncio.StrictRedis + if redis.VERSION < (3, 0, 0) + else redis.asyncio.Redis + ) + + if isinstance(client, client_type): + wrap_function_wrapper( + client, "execute_command", _async_traced_execute_command + ) + wrap_function_wrapper(client, "pipeline", _async_pipeline_wrapper) + return + + def _async_cluster_pipeline_wrapper(func, instance, args, kwargs): + result = func(*args, **kwargs) + wrap_function_wrapper( + result, "execute", _async_traced_execute_pipeline + ) + return result + + # handle + if redis.VERSION >= _REDIS_ASYNCIO_CLUSTER_VERSION and isinstance( + client, redis.asyncio.RedisCluster + ): + wrap_function_wrapper( + client, "execute_command", _async_traced_execute_command + ) + wrap_function_wrapper( + client, "pipeline", _async_cluster_pipeline_wrapper + ) + return + # for redis.client.Redis, redis.Cluster and v3.0.0 redis.client.StrictRedis + # the wrappers are the same + # client_type = ( + # redis.client.StrictRedis if redis.VERSION < (3, 0, 0) else redis.client.Redis + # ) + _traced_execute_command = _traced_execute_factory( + tracer, request_hook, response_hook + ) + _traced_execute_pipeline = _traced_execute_pipeline_factory( + tracer, request_hook, response_hook + ) + + def _pipeline_wrapper(func, instance, args, kwargs): + result = func(*args, **kwargs) + wrap_function_wrapper(result, "execute", _traced_execute_pipeline) + wrap_function_wrapper( + result, "immediate_execute_command", _traced_execute_command + ) + return result + + wrap_function_wrapper( + client, + "execute_command", + _traced_execute_command, + ) + wrap_function_wrapper( + client, + "pipeline", + _pipeline_wrapper, + ) + + class RedisInstrumentor(BaseInstrumentor): """An instrumentor for Redis. @@ -483,11 +617,20 @@ def _instrument(self, **kwargs: Any): tracer_provider=tracer_provider, schema_url="https://opentelemetry.io/schemas/1.11.0", ) - _instrument( - tracer, - request_hook=kwargs.get("request_hook"), - response_hook=kwargs.get("response_hook"), - ) + redis_client = kwargs.get("client") + if redis_client: + _instrument_client( + redis_client, + tracer, + request_hook=kwargs.get("request_hook"), + response_hook=kwargs.get("response_hook"), + ) + else: + _instrument( + tracer, + request_hook=kwargs.get("request_hook"), + response_hook=kwargs.get("response_hook"), + ) def _uninstrument(self, **kwargs: Any): if redis.VERSION < (3, 0, 0): diff --git a/instrumentation/opentelemetry-instrumentation-redis/tests/test_redis.py b/instrumentation/opentelemetry-instrumentation-redis/tests/test_redis.py index c436589adb..0ece1a9c3b 100644 --- a/instrumentation/opentelemetry-instrumentation-redis/tests/test_redis.py +++ b/instrumentation/opentelemetry-instrumentation-redis/tests/test_redis.py @@ -19,7 +19,6 @@ import pytest import redis import redis.asyncio -from fakeredis.aioredis import FakeRedis from redis.exceptions import ConnectionError as redis_ConnectionError from redis.exceptions import WatchError @@ -360,7 +359,7 @@ def test_response_error(self): def test_watch_error_sync(self): def redis_operations(): with pytest.raises(WatchError): - redis_client = fakeredis.FakeStrictRedis() + redis_client = redis.Redis() pipe = redis_client.pipeline(transaction=True) pipe.watch("a") redis_client.set("a", "bad") # This will cause the WatchError @@ -389,25 +388,123 @@ def redis_operations(): class TestRedisAsync(TestBase, IsolatedAsyncioTestCase): def setUp(self): super().setUp() - RedisInstrumentor().instrument(tracer_provider=self.tracer_provider) + self.instrumentor = RedisInstrumentor() + self.client = redis.asyncio.Redis() def tearDown(self): super().tearDown() RedisInstrumentor().uninstrument() + async def _redis_operations(self, client): + with pytest.raises(WatchError): + async with client.pipeline(transaction=False) as pipe: + await pipe.watch("a") + await client.set("a", "bad") + pipe.multi() + await pipe.set("a", "1") + await pipe.execute() + @pytest.mark.asyncio async def test_watch_error_async(self): - async def redis_operations(): - with pytest.raises(WatchError): - redis_client = FakeRedis() - async with redis_client.pipeline(transaction=False) as pipe: - await pipe.watch("a") - await redis_client.set("a", "bad") - pipe.multi() - await pipe.set("a", "1") - await pipe.execute() - - await redis_operations() + self.instrumentor.instrument(tracer_provider=self.tracer_provider) + redis_client = redis.asyncio.Redis() + + await self._redis_operations(redis_client) + + spans = self.memory_exporter.get_finished_spans() + + # there should be 3 tests, we start watch operation and have 2 set operation on same key + self.assertEqual(len(spans), 3) + + self.assertEqual(spans[0].attributes.get("db.statement"), "WATCH ?") + self.assertEqual(spans[0].kind, SpanKind.CLIENT) + self.assertEqual(spans[0].status.status_code, trace.StatusCode.UNSET) + + for span in spans[1:]: + self.assertEqual(span.attributes.get("db.statement"), "SET ? ?") + self.assertEqual(span.kind, SpanKind.CLIENT) + self.assertEqual(span.status.status_code, trace.StatusCode.UNSET) + + @pytest.mark.asyncio + async def test_watch_error_async_only_client(self): + self.instrumentor.instrument( + tracer_provider=self.tracer_provider, client=self.client + ) + redis_client = redis.asyncio.Redis() + await self._redis_operations(redis_client) + + spans = self.memory_exporter.get_finished_spans() + + # there should be 3 tests, we start watch operation and have 2 set operation on same key + self.assertEqual(len(spans), 0) + + # now with the instrumented client we should get proper spans + await self._redis_operations(self.client) + + spans = self.memory_exporter.get_finished_spans() + + # there should be 3 tests, we start watch operation and have 2 set operation on same key + self.assertEqual(len(spans), 3) + + self.assertEqual(spans[0].attributes.get("db.statement"), "WATCH ?") + self.assertEqual(spans[0].kind, SpanKind.CLIENT) + self.assertEqual(spans[0].status.status_code, trace.StatusCode.UNSET) + + for span in spans[1:]: + self.assertEqual(span.attributes.get("db.statement"), "SET ? ?") + self.assertEqual(span.kind, SpanKind.CLIENT) + self.assertEqual(span.status.status_code, trace.StatusCode.UNSET) + + +class TestRedisInstance(TestBase): + def setUp(self): + super().setUp() + self.client = redis.Redis() + RedisInstrumentor().instrument( + tracer_provider=self.tracer_provider, client=self.client + ) + + def tearDown(self): + super().tearDown() + print("SHOULD TEARDOWN") + + def test_only_client_instrumented(self): + redis_client = redis.Redis() + + with mock.patch.object(redis_client, "connection"): + redis_client.get("key") + + spans = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans), 0) + + # now use the test client + with mock.patch.object(self.client, "connection"): + self.client.get("key") + spans = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans), 1) + span = spans[0] + self.assertEqual(span.name, "GET") + self.assertEqual(span.kind, SpanKind.CLIENT) + + @staticmethod + def redis_operations(client): + with pytest.raises(WatchError): + pipe = client.pipeline(transaction=True) + pipe.watch("a") + client.set("a", "bad") # This will cause the WatchError + pipe.multi() + pipe.set("a", "1") + pipe.execute() + + def test_watch_error_sync_only_client(self): + redis_client = redis.Redis() + + self.redis_operations(redis_client) + + spans = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans), 0) + + self.redis_operations(self.client) spans = self.memory_exporter.get_finished_spans() From c33f2b548237750e29ed5aadb50717c55681d493 Mon Sep 17 00:00:00 2001 From: OlegZv <19969581+OlegZv@users.noreply.github.com> Date: Tue, 26 Nov 2024 22:49:02 -0500 Subject: [PATCH 2/8] Add connection instrumentation based on the existing pattern. add a helper function assert_span_count to simplify tests add unit tests for pipeline hooks --- .../instrumentation/redis/__init__.py | 162 ++++++++++-------- .../tests/test_redis.py | 122 ++++++++++--- 2 files changed, 188 insertions(+), 96 deletions(-) diff --git a/instrumentation/opentelemetry-instrumentation-redis/src/opentelemetry/instrumentation/redis/__init__.py b/instrumentation/opentelemetry-instrumentation-redis/src/opentelemetry/instrumentation/redis/__init__.py index bc23902925..e0afd3f57a 100644 --- a/instrumentation/opentelemetry-instrumentation-redis/src/opentelemetry/instrumentation/redis/__init__.py +++ b/instrumentation/opentelemetry-instrumentation-redis/src/opentelemetry/instrumentation/redis/__init__.py @@ -93,6 +93,7 @@ def response_hook(span, instance, response): from __future__ import annotations +import logging from typing import TYPE_CHECKING, Any, Callable, Collection import redis @@ -146,17 +147,26 @@ def response_hook(span, instance, response): _DEFAULT_SERVICE = "redis" - +_logger = logging.getLogger(__name__) _REDIS_ASYNCIO_VERSION = (4, 2, 0) -if redis.VERSION >= _REDIS_ASYNCIO_VERSION: - import redis.asyncio - _REDIS_CLUSTER_VERSION = (4, 1, 0) _REDIS_ASYNCIO_CLUSTER_VERSION = (4, 3, 2) _FIELD_TYPES = ["NUMERIC", "TEXT", "GEO", "TAG", "VECTOR"] +_CLIENT_ASYNCIO_SUPPORT = redis.VERSION >= _REDIS_ASYNCIO_VERSION +_CLIENT_ASYNCIO_CLUSTER_SUPPORT = ( + redis.VERSION >= _REDIS_ASYNCIO_CLUSTER_VERSION +) +_CLIENT_CLUSTER_SUPPORT = redis.VERSION >= _REDIS_CLUSTER_VERSION +_CLIENT_BEFORE_3_0_0 = redis.VERSION < (3, 0, 0) + +if _CLIENT_ASYNCIO_SUPPORT: + import redis.asyncio + +INSTRUMENTATION_ATTR = "_is_instrumented_by_opentelemetry" + def _set_connection_attributes( span: Span, conn: RedisInstance | AsyncRedisInstance @@ -440,10 +450,8 @@ def _instrument( _traced_execute_pipeline = _traced_execute_pipeline_factory( tracer, request_hook, response_hook ) - pipeline_class = ( - "BasePipeline" if redis.VERSION < (3, 0, 0) else "Pipeline" - ) - redis_class = "StrictRedis" if redis.VERSION < (3, 0, 0) else "Redis" + pipeline_class = "BasePipeline" if _CLIENT_BEFORE_3_0_0 else "Pipeline" + redis_class = "StrictRedis" if _CLIENT_BEFORE_3_0_0 else "Redis" wrap_function_wrapper( "redis", f"{redis_class}.execute_command", _traced_execute_command @@ -505,68 +513,55 @@ def _instrument( ) -def _instrument_client( +def _instrument_connection( client, tracer, request_hook: _RequestHookT = None, response_hook: _ResponseHookT = None, ): - # first, handle async clients - _async_traced_execute_command = _async_traced_execute_factory( + # first, handle async clients and cluster clients + _async_traced_execute = _async_traced_execute_factory( tracer, request_hook, response_hook ) _async_traced_execute_pipeline = _async_traced_execute_pipeline_factory( tracer, request_hook, response_hook ) - def _async_pipeline_wrapper(func, instance, args, kwargs): - result = func(*args, **kwargs) - wrap_function_wrapper( - result, "execute", _async_traced_execute_pipeline - ) - wrap_function_wrapper( - result, "immediate_execute_command", _async_traced_execute_command - ) - return result - - if redis.VERSION >= _REDIS_ASYNCIO_VERSION: - client_type = ( - redis.asyncio.StrictRedis - if redis.VERSION < (3, 0, 0) - else redis.asyncio.Redis - ) + if _CLIENT_ASYNCIO_SUPPORT and isinstance(client, redis.asyncio.Redis): - if isinstance(client, client_type): + def _async_pipeline_wrapper(func, instance, args, kwargs): + result = func(*args, **kwargs) wrap_function_wrapper( - client, "execute_command", _async_traced_execute_command + result, "execute", _async_traced_execute_pipeline ) - wrap_function_wrapper(client, "pipeline", _async_pipeline_wrapper) - return + wrap_function_wrapper( + result, "immediate_execute_command", _async_traced_execute + ) + return result - def _async_cluster_pipeline_wrapper(func, instance, args, kwargs): - result = func(*args, **kwargs) - wrap_function_wrapper( - result, "execute", _async_traced_execute_pipeline - ) - return result + wrap_function_wrapper(client, "execute_command", _async_traced_execute) + wrap_function_wrapper(client, "pipeline", _async_pipeline_wrapper) + return - # handle - if redis.VERSION >= _REDIS_ASYNCIO_CLUSTER_VERSION and isinstance( + if _CLIENT_ASYNCIO_CLUSTER_SUPPORT and isinstance( client, redis.asyncio.RedisCluster ): - wrap_function_wrapper( - client, "execute_command", _async_traced_execute_command - ) + + def _async_cluster_pipeline_wrapper(func, instance, args, kwargs): + result = func(*args, **kwargs) + wrap_function_wrapper( + result, "execute", _async_traced_execute_pipeline + ) + return result + + wrap_function_wrapper(client, "execute_command", _async_traced_execute) wrap_function_wrapper( client, "pipeline", _async_cluster_pipeline_wrapper ) return # for redis.client.Redis, redis.Cluster and v3.0.0 redis.client.StrictRedis # the wrappers are the same - # client_type = ( - # redis.client.StrictRedis if redis.VERSION < (3, 0, 0) else redis.client.Redis - # ) - _traced_execute_command = _traced_execute_factory( + _traced_execute = _traced_execute_factory( tracer, request_hook, response_hook ) _traced_execute_pipeline = _traced_execute_pipeline_factory( @@ -577,14 +572,14 @@ def _pipeline_wrapper(func, instance, args, kwargs): result = func(*args, **kwargs) wrap_function_wrapper(result, "execute", _traced_execute_pipeline) wrap_function_wrapper( - result, "immediate_execute_command", _traced_execute_command + result, "immediate_execute_command", _traced_execute ) return result wrap_function_wrapper( client, "execute_command", - _traced_execute_command, + _traced_execute, ) wrap_function_wrapper( client, @@ -599,6 +594,16 @@ class RedisInstrumentor(BaseInstrumentor): See `BaseInstrumentor` """ + @staticmethod + def _get_tracer(**kwargs): + tracer_provider = kwargs.get("tracer_provider") + return trace.get_tracer( + __name__, + __version__, + tracer_provider=tracer_provider, + schema_url="https://opentelemetry.io/schemas/1.11.0", + ) + def instrumentation_dependencies(self) -> Collection[str]: return _instruments @@ -610,30 +615,14 @@ def _instrument(self, **kwargs: Any): ``tracer_provider``: a TracerProvider, defaults to global. ``response_hook``: An optional callback which is invoked right before the span is finished processing a response. """ - tracer_provider = kwargs.get("tracer_provider") - tracer = trace.get_tracer( - __name__, - __version__, - tracer_provider=tracer_provider, - schema_url="https://opentelemetry.io/schemas/1.11.0", + _instrument( + self._get_tracer(**kwargs), + request_hook=kwargs.get("request_hook"), + response_hook=kwargs.get("response_hook"), ) - redis_client = kwargs.get("client") - if redis_client: - _instrument_client( - redis_client, - tracer, - request_hook=kwargs.get("request_hook"), - response_hook=kwargs.get("response_hook"), - ) - else: - _instrument( - tracer, - request_hook=kwargs.get("request_hook"), - response_hook=kwargs.get("response_hook"), - ) def _uninstrument(self, **kwargs: Any): - if redis.VERSION < (3, 0, 0): + if _CLIENT_BEFORE_3_0_0: unwrap(redis.StrictRedis, "execute_command") unwrap(redis.StrictRedis, "pipeline") unwrap(redis.Redis, "pipeline") @@ -661,3 +650,38 @@ def _uninstrument(self, **kwargs: Any): if redis.VERSION >= _REDIS_ASYNCIO_CLUSTER_VERSION: unwrap(redis.asyncio.cluster.RedisCluster, "execute_command") unwrap(redis.asyncio.cluster.ClusterPipeline, "execute") + + @staticmethod + def instrument_connection( + client, tracer_provider: None, request_hook=None, response_hook=None + ): + if not hasattr(client, INSTRUMENTATION_ATTR): + setattr(client, INSTRUMENTATION_ATTR, False) + if not getattr(client, INSTRUMENTATION_ATTR): + _instrument_connection( + client, + RedisInstrumentor._get_tracer(tracer_provider=tracer_provider), + request_hook=request_hook, + response_hook=response_hook, + ) + setattr(client, INSTRUMENTATION_ATTR, True) + else: + _logger.warning( + "Attempting to instrument Redis connection while already instrumented" + ) + + @staticmethod + def uninstrument_connection(client): + if getattr(client, INSTRUMENTATION_ATTR): + # for all clients we need to unwrap execute_command and pipeline functions + unwrap(client, "execute_command") + # pipeline was creating a pipeline and wrapping the functions of the + # created instance. any pipeline created before un-instrumenting will + # remain instrumented (pipelines should usually have a short span) + unwrap(client, "pipeline") + pass + else: + _logger.warning( + "Attempting to un-instrument Redis connection that wasn't instrumented" + ) + return diff --git a/instrumentation/opentelemetry-instrumentation-redis/tests/test_redis.py b/instrumentation/opentelemetry-instrumentation-redis/tests/test_redis.py index 0ece1a9c3b..dd8b49e179 100644 --- a/instrumentation/opentelemetry-instrumentation-redis/tests/test_redis.py +++ b/instrumentation/opentelemetry-instrumentation-redis/tests/test_redis.py @@ -36,7 +36,7 @@ class TestRedis(TestBase): def setUp(self): super().setUp() - RedisInstrumentor().instrument(tracer_provider=self.tracer_provider) + RedisInstrumentor().instrument(_provider=self.tracer_provider) def tearDown(self): super().tearDown() @@ -391,11 +391,7 @@ def setUp(self): self.instrumentor = RedisInstrumentor() self.client = redis.asyncio.Redis() - def tearDown(self): - super().tearDown() - RedisInstrumentor().uninstrument() - - async def _redis_operations(self, client): + async def _redis_pipeline_operations(self, client): with pytest.raises(WatchError): async with client.pipeline(transaction=False) as pipe: await pipe.watch("a") @@ -406,32 +402,45 @@ async def _redis_operations(self, client): @pytest.mark.asyncio async def test_watch_error_async(self): - self.instrumentor.instrument(tracer_provider=self.tracer_provider) - redis_client = redis.asyncio.Redis() + # this tests also ensures the response_hook is called + response_attr = "my.response.attribute" + count = 0 - await self._redis_operations(redis_client) + def response_hook(span, conn, args): + nonlocal count + if span and span.is_recording(): + span.set_attribute(response_attr, count) + count += 1 - spans = self.memory_exporter.get_finished_spans() + self.instrumentor.instrument( + tracer_provider=self.tracer_provider, response_hook=response_hook + ) + redis_client = redis.asyncio.Redis() + + await self._redis_pipeline_operations(redis_client) # there should be 3 tests, we start watch operation and have 2 set operation on same key - self.assertEqual(len(spans), 3) + spans = self.assert_span_count(3) self.assertEqual(spans[0].attributes.get("db.statement"), "WATCH ?") self.assertEqual(spans[0].kind, SpanKind.CLIENT) self.assertEqual(spans[0].status.status_code, trace.StatusCode.UNSET) + self.assertEqual(spans[0].attributes.get(response_attr), 0) - for span in spans[1:]: + for span_index, span in enumerate(spans[1:], 1): self.assertEqual(span.attributes.get("db.statement"), "SET ? ?") self.assertEqual(span.kind, SpanKind.CLIENT) self.assertEqual(span.status.status_code, trace.StatusCode.UNSET) + self.assertEqual(span.attributes.get(response_attr), span_index) + RedisInstrumentor().uninstrument() @pytest.mark.asyncio async def test_watch_error_async_only_client(self): - self.instrumentor.instrument( + self.instrumentor.instrument_connection( tracer_provider=self.tracer_provider, client=self.client ) redis_client = redis.asyncio.Redis() - await self._redis_operations(redis_client) + await self._redis_pipeline_operations(redis_client) spans = self.memory_exporter.get_finished_spans() @@ -439,7 +448,7 @@ async def test_watch_error_async_only_client(self): self.assertEqual(len(spans), 0) # now with the instrumented client we should get proper spans - await self._redis_operations(self.client) + await self._redis_pipeline_operations(self.client) spans = self.memory_exporter.get_finished_spans() @@ -454,19 +463,83 @@ async def test_watch_error_async_only_client(self): self.assertEqual(span.attributes.get("db.statement"), "SET ? ?") self.assertEqual(span.kind, SpanKind.CLIENT) self.assertEqual(span.status.status_code, trace.StatusCode.UNSET) + RedisInstrumentor().uninstrument_connection(self.client) + + @pytest.mark.asyncio + async def test_request_response_hooks(self): + request_attr = "my.request.attribute" + response_attr = "my.response.attribute" + + def request_hook(span, conn, args, kwargs): + if span and span.is_recording(): + span.set_attribute(request_attr, args[0]) + + def response_hook(span, conn, args): + if span and span.is_recording(): + span.set_attribute(response_attr, args) + + self.instrumentor.instrument( + tracer_provider=self.tracer_provider, + request_hook=request_hook, + response_hook=response_hook, + ) + await self.client.set("key", "value") + + spans = self.assert_span_count(1) + + span = spans[0] + self.assertEqual(span.attributes.get(request_attr), "SET") + self.assertEqual(span.attributes.get(response_attr), True) + self.instrumentor.uninstrument() + + @pytest.mark.asyncio + async def test_request_response_hooks_connection_only(self): + request_attr = "my.request.attribute" + response_attr = "my.response.attribute" + + def request_hook(span, conn, args, kwargs): + if span and span.is_recording(): + span.set_attribute(request_attr, args[0]) + + def response_hook(span, conn, args): + if span and span.is_recording(): + span.set_attribute(response_attr, args) + + self.instrumentor.instrument_connection( + client=self.client, + tracer_provider=self.tracer_provider, + request_hook=request_hook, + response_hook=response_hook, + ) + await self.client.set("key", "value") + + spans = self.assert_span_count(1) + + span = spans[0] + self.assertEqual(span.attributes.get(request_attr), "SET") + self.assertEqual(span.attributes.get(response_attr), True) + # fresh client should not record any spans + fresh_client = redis.asyncio.Redis() + self.memory_exporter.clear() + await fresh_client.set("key", "value") + self.assert_span_count(0) + self.instrumentor.uninstrument_connection(self.client) + # after un-instrumenting the query should not be recorder + await self.client.set("key", "value") + spans = self.assert_span_count(0) class TestRedisInstance(TestBase): def setUp(self): super().setUp() self.client = redis.Redis() - RedisInstrumentor().instrument( - tracer_provider=self.tracer_provider, client=self.client + RedisInstrumentor().instrument_connection( + client=self.client, tracer_provider=self.tracer_provider ) def tearDown(self): super().tearDown() - print("SHOULD TEARDOWN") + RedisInstrumentor().uninstrument_connection(self.client) def test_only_client_instrumented(self): redis_client = redis.Redis() @@ -474,14 +547,12 @@ def test_only_client_instrumented(self): with mock.patch.object(redis_client, "connection"): redis_client.get("key") - spans = self.memory_exporter.get_finished_spans() - self.assertEqual(len(spans), 0) + spans = self.assert_span_count(0) # now use the test client with mock.patch.object(self.client, "connection"): self.client.get("key") - spans = self.memory_exporter.get_finished_spans() - self.assertEqual(len(spans), 1) + spans = self.assert_span_count(1) span = spans[0] self.assertEqual(span.name, "GET") self.assertEqual(span.kind, SpanKind.CLIENT) @@ -501,15 +572,12 @@ def test_watch_error_sync_only_client(self): self.redis_operations(redis_client) - spans = self.memory_exporter.get_finished_spans() - self.assertEqual(len(spans), 0) + self.assert_span_count(0) self.redis_operations(self.client) - spans = self.memory_exporter.get_finished_spans() - # there should be 3 tests, we start watch operation and have 2 set operation on same key - self.assertEqual(len(spans), 3) + spans = self.assert_span_count(3) self.assertEqual(spans[0].attributes.get("db.statement"), "WATCH ?") self.assertEqual(spans[0].kind, SpanKind.CLIENT) From 1f9950730c023faba82dc06228fab5daa0bdfa08 Mon Sep 17 00:00:00 2001 From: OlegZv <19969581+OlegZv@users.noreply.github.com> Date: Sun, 22 Dec 2024 00:01:38 -0500 Subject: [PATCH 3/8] fix tests to use fake redis --- .../tests/test_redis.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/instrumentation/opentelemetry-instrumentation-redis/tests/test_redis.py b/instrumentation/opentelemetry-instrumentation-redis/tests/test_redis.py index dd8b49e179..23ad23bd89 100644 --- a/instrumentation/opentelemetry-instrumentation-redis/tests/test_redis.py +++ b/instrumentation/opentelemetry-instrumentation-redis/tests/test_redis.py @@ -19,6 +19,7 @@ import pytest import redis import redis.asyncio +from fakeredis.aioredis import FakeRedis from redis.exceptions import ConnectionError as redis_ConnectionError from redis.exceptions import WatchError @@ -359,7 +360,7 @@ def test_response_error(self): def test_watch_error_sync(self): def redis_operations(): with pytest.raises(WatchError): - redis_client = redis.Redis() + redis_client = fakeredis.FakeStrictRedis() pipe = redis_client.pipeline(transaction=True) pipe.watch("a") redis_client.set("a", "bad") # This will cause the WatchError @@ -389,7 +390,7 @@ class TestRedisAsync(TestBase, IsolatedAsyncioTestCase): def setUp(self): super().setUp() self.instrumentor = RedisInstrumentor() - self.client = redis.asyncio.Redis() + self.client = FakeRedis() async def _redis_pipeline_operations(self, client): with pytest.raises(WatchError): @@ -415,8 +416,7 @@ def response_hook(span, conn, args): self.instrumentor.instrument( tracer_provider=self.tracer_provider, response_hook=response_hook ) - redis_client = redis.asyncio.Redis() - + redis_client = FakeRedis() await self._redis_pipeline_operations(redis_client) # there should be 3 tests, we start watch operation and have 2 set operation on same key @@ -439,7 +439,7 @@ async def test_watch_error_async_only_client(self): self.instrumentor.instrument_connection( tracer_provider=self.tracer_provider, client=self.client ) - redis_client = redis.asyncio.Redis() + redis_client = FakeRedis() await self._redis_pipeline_operations(redis_client) spans = self.memory_exporter.get_finished_spans() @@ -519,7 +519,7 @@ def response_hook(span, conn, args): self.assertEqual(span.attributes.get(request_attr), "SET") self.assertEqual(span.attributes.get(response_attr), True) # fresh client should not record any spans - fresh_client = redis.asyncio.Redis() + fresh_client = FakeRedis() self.memory_exporter.clear() await fresh_client.set("key", "value") self.assert_span_count(0) @@ -532,7 +532,7 @@ def response_hook(span, conn, args): class TestRedisInstance(TestBase): def setUp(self): super().setUp() - self.client = redis.Redis() + self.client = fakeredis.FakeStrictRedis() RedisInstrumentor().instrument_connection( client=self.client, tracer_provider=self.tracer_provider ) @@ -568,7 +568,7 @@ def redis_operations(client): pipe.execute() def test_watch_error_sync_only_client(self): - redis_client = redis.Redis() + redis_client = fakeredis.FakeStrictRedis() self.redis_operations(redis_client) From edca781fb2c4ed8ca91d56565df4df8f44c2097e Mon Sep 17 00:00:00 2001 From: OlegZv <19969581+OlegZv@users.noreply.github.com> Date: Sun, 22 Dec 2024 00:03:14 -0500 Subject: [PATCH 4/8] replace the redis version checks with defines --- .../opentelemetry/instrumentation/redis/__init__.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/instrumentation/opentelemetry-instrumentation-redis/src/opentelemetry/instrumentation/redis/__init__.py b/instrumentation/opentelemetry-instrumentation-redis/src/opentelemetry/instrumentation/redis/__init__.py index e0afd3f57a..bba9aa4ee6 100644 --- a/instrumentation/opentelemetry-instrumentation-redis/src/opentelemetry/instrumentation/redis/__init__.py +++ b/instrumentation/opentelemetry-instrumentation-redis/src/opentelemetry/instrumentation/redis/__init__.py @@ -466,7 +466,7 @@ def _instrument( f"{pipeline_class}.immediate_execute_command", _traced_execute_command, ) - if redis.VERSION >= _REDIS_CLUSTER_VERSION: + if _CLIENT_CLUSTER_SUPPORT: wrap_function_wrapper( "redis.cluster", "RedisCluster.execute_command", @@ -484,7 +484,7 @@ def _instrument( _async_traced_execute_pipeline = _async_traced_execute_pipeline_factory( tracer, request_hook, response_hook ) - if redis.VERSION >= _REDIS_ASYNCIO_VERSION: + if _CLIENT_ASYNCIO_SUPPORT: wrap_function_wrapper( "redis.asyncio", f"{redis_class}.execute_command", @@ -500,7 +500,7 @@ def _instrument( f"{pipeline_class}.immediate_execute_command", _async_traced_execute_command, ) - if redis.VERSION >= _REDIS_ASYNCIO_CLUSTER_VERSION: + if _CLIENT_ASYNCIO_CLUSTER_SUPPORT: wrap_function_wrapper( "redis.asyncio.cluster", "RedisCluster.execute_command", @@ -639,15 +639,15 @@ def _uninstrument(self, **kwargs: Any): unwrap(redis.Redis, "pipeline") unwrap(redis.client.Pipeline, "execute") unwrap(redis.client.Pipeline, "immediate_execute_command") - if redis.VERSION >= _REDIS_CLUSTER_VERSION: + if _CLIENT_CLUSTER_SUPPORT: unwrap(redis.cluster.RedisCluster, "execute_command") unwrap(redis.cluster.ClusterPipeline, "execute") - if redis.VERSION >= _REDIS_ASYNCIO_VERSION: + if _CLIENT_ASYNCIO_SUPPORT: unwrap(redis.asyncio.Redis, "execute_command") unwrap(redis.asyncio.Redis, "pipeline") unwrap(redis.asyncio.client.Pipeline, "execute") unwrap(redis.asyncio.client.Pipeline, "immediate_execute_command") - if redis.VERSION >= _REDIS_ASYNCIO_CLUSTER_VERSION: + if _CLIENT_ASYNCIO_CLUSTER_SUPPORT: unwrap(redis.asyncio.cluster.RedisCluster, "execute_command") unwrap(redis.asyncio.cluster.ClusterPipeline, "execute") From be3ca10460855646955ba2c9c11ab7e07d4fc06b Mon Sep 17 00:00:00 2001 From: OlegZv <19969581+OlegZv@users.noreply.github.com> Date: Sun, 22 Dec 2024 00:03:39 -0500 Subject: [PATCH 5/8] Adjust comment and fix one test --- .../src/opentelemetry/instrumentation/redis/__init__.py | 4 ++-- .../opentelemetry-instrumentation-redis/tests/test_redis.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/instrumentation/opentelemetry-instrumentation-redis/src/opentelemetry/instrumentation/redis/__init__.py b/instrumentation/opentelemetry-instrumentation-redis/src/opentelemetry/instrumentation/redis/__init__.py index bba9aa4ee6..069d74e15a 100644 --- a/instrumentation/opentelemetry-instrumentation-redis/src/opentelemetry/instrumentation/redis/__init__.py +++ b/instrumentation/opentelemetry-instrumentation-redis/src/opentelemetry/instrumentation/redis/__init__.py @@ -675,8 +675,8 @@ def uninstrument_connection(client): if getattr(client, INSTRUMENTATION_ATTR): # for all clients we need to unwrap execute_command and pipeline functions unwrap(client, "execute_command") - # pipeline was creating a pipeline and wrapping the functions of the - # created instance. any pipeline created before un-instrumenting will + # the method was creating a pipeline and wrapping the functions of the + # created instance. any pipelines created before un-instrumenting will # remain instrumented (pipelines should usually have a short span) unwrap(client, "pipeline") pass diff --git a/instrumentation/opentelemetry-instrumentation-redis/tests/test_redis.py b/instrumentation/opentelemetry-instrumentation-redis/tests/test_redis.py index 23ad23bd89..da2834706d 100644 --- a/instrumentation/opentelemetry-instrumentation-redis/tests/test_redis.py +++ b/instrumentation/opentelemetry-instrumentation-redis/tests/test_redis.py @@ -37,7 +37,7 @@ class TestRedis(TestBase): def setUp(self): super().setUp() - RedisInstrumentor().instrument(_provider=self.tracer_provider) + RedisInstrumentor().instrument(tracer_provider=self.tracer_provider) def tearDown(self): super().tearDown() From 0918232e345ea96d0af81906721322d4d0b5b654 Mon Sep 17 00:00:00 2001 From: OlegZv <19969581+OlegZv@users.noreply.github.com> Date: Sun, 22 Dec 2024 00:12:02 -0500 Subject: [PATCH 6/8] Update documentation with the client method --- docs/conf.py | 1 + docs/instrumentation/redis/redis.rst | 9 +- .../README.rst | 4 +- .../instrumentation/redis/__init__.py | 199 +++++++++++++----- .../tests/test_redis.py | 12 +- 5 files changed, 165 insertions(+), 60 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index 8233fccb15..d9c326262b 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -122,6 +122,7 @@ "https://opentelemetry-python.readthedocs.io/en/latest/", None, ), + "redis": ("https://redis-py.readthedocs.io/en/latest/", None), } # http://www.sphinx-doc.org/en/master/config.html#confval-nitpicky diff --git a/docs/instrumentation/redis/redis.rst b/docs/instrumentation/redis/redis.rst index 4e21bce24b..32bdec19e0 100644 --- a/docs/instrumentation/redis/redis.rst +++ b/docs/instrumentation/redis/redis.rst @@ -1,7 +1,10 @@ -OpenTelemetry Redis Instrumentation -=================================== +.. include:: ../../../instrumentation/opentelemetry-instrumentation-redis/README.rst + :end-before: References + +Usage +----- .. automodule:: opentelemetry.instrumentation.redis :members: :undoc-members: - :show-inheritance: + :show-inheritance: \ No newline at end of file diff --git a/instrumentation/opentelemetry-instrumentation-httpx/README.rst b/instrumentation/opentelemetry-instrumentation-httpx/README.rst index cc465dd615..e23ba7e404 100644 --- a/instrumentation/opentelemetry-instrumentation-httpx/README.rst +++ b/instrumentation/opentelemetry-instrumentation-httpx/README.rst @@ -43,9 +43,11 @@ Instrumenting single clients **************************** If you only want to instrument requests for specific client instances, you can -use the `instrument_client` method. +use the `instrument_client`_ method. +.. _instrument_client: #opentelemetry.instrumentation.httpx.HTTPXClientInstrumentor.instrument_client + .. code-block:: python import httpx diff --git a/instrumentation/opentelemetry-instrumentation-redis/src/opentelemetry/instrumentation/redis/__init__.py b/instrumentation/opentelemetry-instrumentation-redis/src/opentelemetry/instrumentation/redis/__init__.py index 069d74e15a..bb24d9246c 100644 --- a/instrumentation/opentelemetry-instrumentation-redis/src/opentelemetry/instrumentation/redis/__init__.py +++ b/instrumentation/opentelemetry-instrumentation-redis/src/opentelemetry/instrumentation/redis/__init__.py @@ -15,15 +15,14 @@ """ Instrument `redis`_ to report Redis queries. -There are two options for instrumenting code. The first option is to use the -``opentelemetry-instrument`` executable which will automatically -instrument your Redis client. The second is to programmatically enable -instrumentation via the following code: - .. _redis: https://pypi.org/project/redis/ -Usage ------ + +Instrument All Clients +---------------------- + +The easiest way to instrument all redis client instances is by +``RedisInstrumentor().instrument()``: .. code:: python @@ -38,7 +37,7 @@ client = redis.StrictRedis(host="localhost", port=6379) client.get("my-key") -Async Redis clients (i.e. redis.asyncio.Redis) are also instrumented in the same way: +Async Redis clients (i.e. ``redis.asyncio.Redis``) are also instrumented in the same way: .. code:: python @@ -54,19 +53,44 @@ async def redis_get(): client = redis.asyncio.Redis(host="localhost", port=6379) await client.get("my-key") -The `instrument` method accepts the following keyword args: +.. note:: + Calling the ``instrument`` method will instrument the client classes, so any client + created after the ``instrument`` call will be instrumented. To instrument only a + single client, use :func:`RedisInstrumentor.instrument_client` method. + +Instrument Single Client +------------------------ -tracer_provider (TracerProvider) - an optional tracer provider +The :func:`RedisInstrumentor.instrument_client` can instrument a connection instance. This is useful when there are multiple clients with a different redis database index. +Or, you might have a different connection pool used for an application function you +don't want instrumented. -request_hook (Callable) - a function with extra user-defined logic to be performed before performing the request -this function signature is: def request_hook(span: Span, instance: redis.connection.Connection, args, kwargs) -> None +.. code:: python + + from opentelemetry.instrumentation.redis import RedisInstrumentor + import redis + + instrumented_client = redis.Redis() + not_instrumented_client = redis.Redis() + + # Instrument redis + RedisInstrumentor.instrument_client(client=instrumented_client) + + # This will report a span with the default settings + instrumented_client.get("my-key") -response_hook (Callable) - a function with extra user-defined logic to be performed after performing the request -this function signature is: def response_hook(span: Span, instance: redis.connection.Connection, response) -> None + # This will not have a span + not_instrumented_client.get("my-key") -for example: +.. warning:: + All client instances created after calling ``RedisInstrumentor().instrument`` will + be instrumented. To avoid instrumenting all clients, use + :func:`RedisInstrumentor.instrument_client` . -.. code: python +Request/Response Hooks +---------------------- + +.. code:: python from opentelemetry.instrumentation.redis import RedisInstrumentor import redis @@ -86,7 +110,6 @@ def response_hook(span, instance, response): client = redis.StrictRedis(host="localhost", port=6379) client.get("my-key") - API --- """ @@ -111,7 +134,13 @@ def response_hook(span, instance, response): from opentelemetry.instrumentation.redis.version import __version__ from opentelemetry.instrumentation.utils import unwrap from opentelemetry.semconv.trace import SpanAttributes -from opentelemetry.trace import Span, StatusCode, Tracer +from opentelemetry.trace import ( + Span, + StatusCode, + Tracer, + TracerProvider, + get_tracer, +) if TYPE_CHECKING: from typing import Awaitable, TypeVar @@ -122,10 +151,10 @@ def response_hook(span, instance, response): import redis.cluster import redis.connection - _RequestHookT = Callable[ + RequestHook = Callable[ [Span, redis.connection.Connection, list[Any], dict[str, Any]], None ] - _ResponseHookT = Callable[[Span, redis.connection.Connection, Any], None] + ResponseHook = Callable[[Span, redis.connection.Connection, Any], None] AsyncPipelineInstance = TypeVar( "AsyncPipelineInstance", @@ -148,6 +177,7 @@ def response_hook(span, instance, response): _DEFAULT_SERVICE = "redis" _logger = logging.getLogger(__name__) +assert hasattr(redis, "VERSION") _REDIS_ASYNCIO_VERSION = (4, 2, 0) _REDIS_CLUSTER_VERSION = (4, 1, 0) @@ -281,9 +311,9 @@ def _build_span_meta_data_for_pipeline( def _traced_execute_factory( - tracer, - request_hook: _RequestHookT = None, - response_hook: _ResponseHookT = None, + tracer: Tracer, + request_hook: RequestHook | None = None, + response_hook: ResponseHook | None = None, ): def _traced_execute_command( func: Callable[..., R], @@ -316,9 +346,9 @@ def _traced_execute_command( def _traced_execute_pipeline_factory( - tracer, - request_hook: _RequestHookT = None, - response_hook: _ResponseHookT = None, + tracer: Tracer, + request_hook: RequestHook | None = None, + response_hook: ResponseHook | None = None, ): def _traced_execute_pipeline( func: Callable[..., R], @@ -361,9 +391,9 @@ def _traced_execute_pipeline( def _async_traced_execute_factory( - tracer, - request_hook: _RequestHookT = None, - response_hook: _ResponseHookT = None, + tracer: Tracer, + request_hook: RequestHook | None = None, + response_hook: ResponseHook | None = None, ): async def _async_traced_execute_command( func: Callable[..., Awaitable[R]], @@ -392,9 +422,9 @@ async def _async_traced_execute_command( def _async_traced_execute_pipeline_factory( - tracer, - request_hook: _RequestHookT = None, - response_hook: _ResponseHookT = None, + tracer: Tracer, + request_hook: RequestHook | None = None, + response_hook: ResponseHook | None = None, ): async def _async_traced_execute_pipeline( func: Callable[..., Awaitable[R]], @@ -441,8 +471,8 @@ async def _async_traced_execute_pipeline( # pylint: disable=R0915 def _instrument( tracer: Tracer, - request_hook: _RequestHookT | None = None, - response_hook: _ResponseHookT | None = None, + request_hook: RequestHook | None = None, + response_hook: ResponseHook | None = None, ): _traced_execute_command = _traced_execute_factory( tracer, request_hook, response_hook @@ -513,11 +543,11 @@ def _instrument( ) -def _instrument_connection( +def _instrument_client( client, - tracer, - request_hook: _RequestHookT = None, - response_hook: _ResponseHookT = None, + tracer: Tracer, + request_hook: RequestHook | None = None, + response_hook: ResponseHook | None = None, ): # first, handle async clients and cluster clients _async_traced_execute = _async_traced_execute_factory( @@ -589,23 +619,48 @@ def _pipeline_wrapper(func, instance, args, kwargs): class RedisInstrumentor(BaseInstrumentor): - """An instrumentor for Redis. - - See `BaseInstrumentor` - """ - @staticmethod def _get_tracer(**kwargs): tracer_provider = kwargs.get("tracer_provider") - return trace.get_tracer( + return get_tracer( __name__, __version__, tracer_provider=tracer_provider, schema_url="https://opentelemetry.io/schemas/1.11.0", ) - def instrumentation_dependencies(self) -> Collection[str]: - return _instruments + def instrument( + self, + tracer_provider: TracerProvider | None = None, + request_hook: RequestHook | None = None, + response_hook: ResponseHook | None = None, + **kwargs, + ): + """Instruments all Redis/StrictRedis/RedisCluster and async client instances. + + Args: + tracer_provider: A TracerProvider, defaults to global. + request_hook: + a function with extra user-defined logic to run before performing the request. + + The ``args`` is a tuple, where items are + command arguments. For example ``client.set("mykey", "value", ex=5)`` would + have ``args`` as ``('SET', 'mykey', 'value', 'EX', 5)``. + + The ``kwargs`` represents occasional ``options`` passed by redis. For example, + if you use ``client.set("mykey", "value", get=True)``, the ``kwargs`` would be + ``{'get': True}``. + response_hook: + a function with extra user-defined logic to run after the request is complete. + + The ``args`` represents the response. + """ + super().instrument( + tracer_provider=tracer_provider, + request_hook=request_hook, + response_hook=response_hook, + **kwargs, + ) def _instrument(self, **kwargs: Any): """Instruments the redis module @@ -652,13 +707,42 @@ def _uninstrument(self, **kwargs: Any): unwrap(redis.asyncio.cluster.ClusterPipeline, "execute") @staticmethod - def instrument_connection( - client, tracer_provider: None, request_hook=None, response_hook=None + def instrument_client( + client: redis.StrictRedis + | redis.Redis + | redis.asyncio.Redis + | redis.cluster.RedisCluster + | redis.asyncio.cluster.RedisCluster, + tracer_provider: TracerProvider | None = None, + request_hook: RequestHook | None = None, + response_hook: ResponseHook | None = None, ): + """Instrument the provided Redis Client. The client can be sync or async. + Cluster client is also supported. + + Args: + client: The redis client. + tracer_provider: A TracerProvider, defaults to global. + request_hook: a function with extra user-defined logic to run before + performing the request. + + The ``args`` is a tuple, where items are + command arguments. For example ``client.set("mykey", "value", ex=5)`` would + have ``args`` as ``('SET', 'mykey', 'value', 'EX', 5)``. + + The ``kwargs`` represents occasional ``options`` passed by redis. For example, + if you use ``client.set("mykey", "value", get=True)``, the ``kwargs`` would be + ``{'get': True}``. + + response_hook: a function with extra user-defined logic to run after + the request is complete. + + The ``args`` represents the response. + """ if not hasattr(client, INSTRUMENTATION_ATTR): setattr(client, INSTRUMENTATION_ATTR, False) if not getattr(client, INSTRUMENTATION_ATTR): - _instrument_connection( + _instrument_client( client, RedisInstrumentor._get_tracer(tracer_provider=tracer_provider), request_hook=request_hook, @@ -671,7 +755,18 @@ def instrument_connection( ) @staticmethod - def uninstrument_connection(client): + def uninstrument_client( + client: redis.StrictRedis + | redis.Redis + | redis.asyncio.Redis + | redis.cluster.RedisCluster + | redis.asyncio.cluster.RedisCluster, + ): + """Disables instrumentation for the given client instance + + Args: + client: The redis client + """ if getattr(client, INSTRUMENTATION_ATTR): # for all clients we need to unwrap execute_command and pipeline functions unwrap(client, "execute_command") @@ -685,3 +780,7 @@ def uninstrument_connection(client): "Attempting to un-instrument Redis connection that wasn't instrumented" ) return + + def instrumentation_dependencies(self) -> Collection[str]: + """Return a list of python packages with versions that the will be instrumented.""" + return _instruments diff --git a/instrumentation/opentelemetry-instrumentation-redis/tests/test_redis.py b/instrumentation/opentelemetry-instrumentation-redis/tests/test_redis.py index da2834706d..098b44737e 100644 --- a/instrumentation/opentelemetry-instrumentation-redis/tests/test_redis.py +++ b/instrumentation/opentelemetry-instrumentation-redis/tests/test_redis.py @@ -436,7 +436,7 @@ def response_hook(span, conn, args): @pytest.mark.asyncio async def test_watch_error_async_only_client(self): - self.instrumentor.instrument_connection( + self.instrumentor.instrument_client( tracer_provider=self.tracer_provider, client=self.client ) redis_client = FakeRedis() @@ -463,7 +463,7 @@ async def test_watch_error_async_only_client(self): self.assertEqual(span.attributes.get("db.statement"), "SET ? ?") self.assertEqual(span.kind, SpanKind.CLIENT) self.assertEqual(span.status.status_code, trace.StatusCode.UNSET) - RedisInstrumentor().uninstrument_connection(self.client) + RedisInstrumentor().uninstrument_client(self.client) @pytest.mark.asyncio async def test_request_response_hooks(self): @@ -505,7 +505,7 @@ def response_hook(span, conn, args): if span and span.is_recording(): span.set_attribute(response_attr, args) - self.instrumentor.instrument_connection( + self.instrumentor.instrument_client( client=self.client, tracer_provider=self.tracer_provider, request_hook=request_hook, @@ -523,7 +523,7 @@ def response_hook(span, conn, args): self.memory_exporter.clear() await fresh_client.set("key", "value") self.assert_span_count(0) - self.instrumentor.uninstrument_connection(self.client) + self.instrumentor.uninstrument_client(self.client) # after un-instrumenting the query should not be recorder await self.client.set("key", "value") spans = self.assert_span_count(0) @@ -533,13 +533,13 @@ class TestRedisInstance(TestBase): def setUp(self): super().setUp() self.client = fakeredis.FakeStrictRedis() - RedisInstrumentor().instrument_connection( + RedisInstrumentor().instrument_client( client=self.client, tracer_provider=self.tracer_provider ) def tearDown(self): super().tearDown() - RedisInstrumentor().uninstrument_connection(self.client) + RedisInstrumentor().uninstrument_client(self.client) def test_only_client_instrumented(self): redis_client = redis.Redis() From 843e708232b95ff5ad3c03f2c16d078c99cff482 Mon Sep 17 00:00:00 2001 From: OlegZv <19969581+OlegZv@users.noreply.github.com> Date: Sun, 22 Dec 2024 00:46:59 -0500 Subject: [PATCH 7/8] Update the changelog --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index a243091b1d..deed3084f5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ([#3100](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/3100)) - Add support to database stability opt-in in `_semconv` utilities and add tests ([#3111](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/3111)) +- `opentelemetry-instrumentation-redis` Add support for redis client-specific instrumentation. + ([#3143](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/3143)) ### Fixed From c169b85bfeaa7aa0459be18e20ce3af190b96836 Mon Sep 17 00:00:00 2001 From: OlegZv <19969581+OlegZv@users.noreply.github.com> Date: Sun, 22 Dec 2024 01:09:44 -0500 Subject: [PATCH 8/8] Update the HTTPX readme to point to proper class method --- .../opentelemetry-instrumentation-httpx/README.rst | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/instrumentation/opentelemetry-instrumentation-httpx/README.rst b/instrumentation/opentelemetry-instrumentation-httpx/README.rst index e23ba7e404..70825b6d38 100644 --- a/instrumentation/opentelemetry-instrumentation-httpx/README.rst +++ b/instrumentation/opentelemetry-instrumentation-httpx/README.rst @@ -43,11 +43,9 @@ Instrumenting single clients **************************** If you only want to instrument requests for specific client instances, you can -use the `instrument_client`_ method. +use the `HTTPXClientInstrumentor.instrument_client` method. -.. _instrument_client: #opentelemetry.instrumentation.httpx.HTTPXClientInstrumentor.instrument_client - .. code-block:: python import httpx