diff --git a/CHANGELOG.md b/CHANGELOG.md index a243091b1d..deed3084f5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ([#3100](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/3100)) - Add support to database stability opt-in in `_semconv` utilities and add tests ([#3111](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/3111)) +- `opentelemetry-instrumentation-redis` Add support for redis client-specific instrumentation. + ([#3143](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/3143)) ### Fixed diff --git a/docs/conf.py b/docs/conf.py index 8233fccb15..d9c326262b 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -122,6 +122,7 @@ "https://opentelemetry-python.readthedocs.io/en/latest/", None, ), + "redis": ("https://redis-py.readthedocs.io/en/latest/", None), } # http://www.sphinx-doc.org/en/master/config.html#confval-nitpicky diff --git a/docs/instrumentation/redis/redis.rst b/docs/instrumentation/redis/redis.rst index 4e21bce24b..32bdec19e0 100644 --- a/docs/instrumentation/redis/redis.rst +++ b/docs/instrumentation/redis/redis.rst @@ -1,7 +1,10 @@ -OpenTelemetry Redis Instrumentation -=================================== +.. include:: ../../../instrumentation/opentelemetry-instrumentation-redis/README.rst + :end-before: References + +Usage +----- .. automodule:: opentelemetry.instrumentation.redis :members: :undoc-members: - :show-inheritance: + :show-inheritance: \ No newline at end of file diff --git a/instrumentation/opentelemetry-instrumentation-httpx/README.rst b/instrumentation/opentelemetry-instrumentation-httpx/README.rst index cc465dd615..70825b6d38 100644 --- a/instrumentation/opentelemetry-instrumentation-httpx/README.rst +++ b/instrumentation/opentelemetry-instrumentation-httpx/README.rst @@ -43,7 +43,7 @@ Instrumenting single clients **************************** If you only want to instrument requests for specific client instances, you can -use the `instrument_client` method. +use the `HTTPXClientInstrumentor.instrument_client` method. .. code-block:: python diff --git a/instrumentation/opentelemetry-instrumentation-redis/src/opentelemetry/instrumentation/redis/__init__.py b/instrumentation/opentelemetry-instrumentation-redis/src/opentelemetry/instrumentation/redis/__init__.py index 8a3096ad41..bb24d9246c 100644 --- a/instrumentation/opentelemetry-instrumentation-redis/src/opentelemetry/instrumentation/redis/__init__.py +++ b/instrumentation/opentelemetry-instrumentation-redis/src/opentelemetry/instrumentation/redis/__init__.py @@ -15,15 +15,14 @@ """ Instrument `redis`_ to report Redis queries. -There are two options for instrumenting code. The first option is to use the -``opentelemetry-instrument`` executable which will automatically -instrument your Redis client. The second is to programmatically enable -instrumentation via the following code: - .. _redis: https://pypi.org/project/redis/ -Usage ------ + +Instrument All Clients +---------------------- + +The easiest way to instrument all redis client instances is by +``RedisInstrumentor().instrument()``: .. code:: python @@ -38,7 +37,7 @@ client = redis.StrictRedis(host="localhost", port=6379) client.get("my-key") -Async Redis clients (i.e. redis.asyncio.Redis) are also instrumented in the same way: +Async Redis clients (i.e. ``redis.asyncio.Redis``) are also instrumented in the same way: .. 
code:: python @@ -54,19 +53,44 @@ async def redis_get(): client = redis.asyncio.Redis(host="localhost", port=6379) await client.get("my-key") -The `instrument` method accepts the following keyword args: +.. note:: + Calling the ``instrument`` method will instrument the client classes, so any client + created after the ``instrument`` call will be instrumented. To instrument only a + single client, use the :func:`RedisInstrumentor.instrument_client` method. + +Instrument Single Client +------------------------ + +The :func:`RedisInstrumentor.instrument_client` method can instrument a single client instance. This is useful when there are multiple clients with different redis database indexes, +or when a separate connection pool is used for an application function you +don't want instrumented. + +.. code:: python + + from opentelemetry.instrumentation.redis import RedisInstrumentor + import redis + + instrumented_client = redis.Redis() + not_instrumented_client = redis.Redis() + + # Instrument redis + RedisInstrumentor.instrument_client(client=instrumented_client) -tracer_provider (TracerProvider) - an optional tracer provider + # This will report a span with the default settings + instrumented_client.get("my-key") -request_hook (Callable) - a function with extra user-defined logic to be performed before performing the request -this function signature is: def request_hook(span: Span, instance: redis.connection.Connection, args, kwargs) -> None + # This will not have a span + not_instrumented_client.get("my-key") -response_hook (Callable) - a function with extra user-defined logic to be performed after performing the request -this function signature is: def response_hook(span: Span, instance: redis.connection.Connection, response) -> None +.. warning:: + All client instances created after calling ``RedisInstrumentor().instrument`` will + be instrumented. To avoid instrumenting all clients, use + :func:`RedisInstrumentor.instrument_client`. -for example: +Request/Response Hooks +---------------------- -.. code: python +.. 
code:: python from opentelemetry.instrumentation.redis import RedisInstrumentor import redis @@ -86,13 +110,13 @@ def response_hook(span, instance, response): client = redis.StrictRedis(host="localhost", port=6379) client.get("my-key") - API --- """ from __future__ import annotations +import logging from typing import TYPE_CHECKING, Any, Callable, Collection import redis @@ -110,7 +134,13 @@ def response_hook(span, instance, response): from opentelemetry.instrumentation.redis.version import __version__ from opentelemetry.instrumentation.utils import unwrap from opentelemetry.semconv.trace import SpanAttributes -from opentelemetry.trace import Span, StatusCode, Tracer +from opentelemetry.trace import ( + Span, + StatusCode, + Tracer, + TracerProvider, + get_tracer, +) if TYPE_CHECKING: from typing import Awaitable, TypeVar @@ -121,10 +151,10 @@ def response_hook(span, instance, response): import redis.cluster import redis.connection - _RequestHookT = Callable[ + RequestHook = Callable[ [Span, redis.connection.Connection, list[Any], dict[str, Any]], None ] - _ResponseHookT = Callable[[Span, redis.connection.Connection, Any], None] + ResponseHook = Callable[[Span, redis.connection.Connection, Any], None] AsyncPipelineInstance = TypeVar( "AsyncPipelineInstance", @@ -146,17 +176,27 @@ def response_hook(span, instance, response): _DEFAULT_SERVICE = "redis" - +_logger = logging.getLogger(__name__) +assert hasattr(redis, "VERSION") _REDIS_ASYNCIO_VERSION = (4, 2, 0) -if redis.VERSION >= _REDIS_ASYNCIO_VERSION: - import redis.asyncio - _REDIS_CLUSTER_VERSION = (4, 1, 0) _REDIS_ASYNCIO_CLUSTER_VERSION = (4, 3, 2) _FIELD_TYPES = ["NUMERIC", "TEXT", "GEO", "TAG", "VECTOR"] +_CLIENT_ASYNCIO_SUPPORT = redis.VERSION >= _REDIS_ASYNCIO_VERSION +_CLIENT_ASYNCIO_CLUSTER_SUPPORT = ( + redis.VERSION >= _REDIS_ASYNCIO_CLUSTER_VERSION +) +_CLIENT_CLUSTER_SUPPORT = redis.VERSION >= _REDIS_CLUSTER_VERSION +_CLIENT_BEFORE_3_0_0 = redis.VERSION < (3, 0, 0) + +if _CLIENT_ASYNCIO_SUPPORT: + import redis.asyncio + +INSTRUMENTATION_ATTR = "_is_instrumented_by_opentelemetry" + def _set_connection_attributes( span: Span, conn: RedisInstance | AsyncRedisInstance @@ -184,6 +224,62 @@ def _build_span_name( return name +def _add_create_attributes(span: Span, args: tuple[Any, ...]): + _set_span_attribute_if_value( + span, "redis.create_index.index", _value_or_none(args, 1) + ) + # According to: https://github.com/redis/redis-py/blob/master/redis/commands/search/commands.py#L155 schema is last argument for execute command + try: + schema_index = args.index("SCHEMA") + except ValueError: + return + schema = args[schema_index:] + field_attribute = "" + # Schema in format: + # [first_field_name, first_field_type, first_field_some_attribute1, first_field_some_attribute2, second_field_name, ...] 
+ field_attribute = "".join( + f"Field(name: {schema[index - 1]}, type: {schema[index]});" + for index in range(1, len(schema)) + if schema[index] in _FIELD_TYPES + ) + _set_span_attribute_if_value( + span, + "redis.create_index.fields", + field_attribute, + ) + + +def _add_search_attributes(span: Span, response, args): + _set_span_attribute_if_value( + span, "redis.search.index", _value_or_none(args, 1) + ) + _set_span_attribute_if_value( + span, "redis.search.query", _value_or_none(args, 2) + ) + # Parse response from search + # https://redis.io/docs/latest/commands/ft.search/ + # Response in format: + # [number_of_returned_documents, index_of_first_returned_doc, first_doc(as a list), index_of_second_returned_doc, second_doc(as a list) ...] + # Returned documents in array format: + # [first_field_name, first_field_value, second_field_name, second_field_value ...] + number_of_returned_documents = _value_or_none(response, 0) + _set_span_attribute_if_value( + span, "redis.search.total", number_of_returned_documents + ) + if "NOCONTENT" in args or not number_of_returned_documents: + return + for document_number in range(number_of_returned_documents): + document_index = _value_or_none(response, 1 + 2 * document_number) + if document_index: + document = response[2 + 2 * document_number] + for attribute_name_index in range(0, len(document), 2): + _set_span_attribute_if_value( + span, + f"redis.search.xdoc_{document_index}.{document[attribute_name_index]}", + document[attribute_name_index + 1], + ) + + def _build_span_meta_data_for_pipeline( instance: PipelineInstance | AsyncPipelineInstance, ) -> tuple[list[Any], str, str]: @@ -214,11 +310,10 @@ def _build_span_meta_data_for_pipeline( return command_stack, resource, span_name -# pylint: disable=R0915 -def _instrument( +def _traced_execute_factory( tracer: Tracer, - request_hook: _RequestHookT | None = None, - response_hook: _ResponseHookT | None = None, + request_hook: RequestHook | None = None, + response_hook: ResponseHook | None = None, ): def _traced_execute_command( func: Callable[..., R], @@ -247,6 +342,14 @@ def _traced_execute_command( response_hook(span, instance, response) return response + return _traced_execute_command + + +def _traced_execute_pipeline_factory( + tracer: Tracer, + request_hook: RequestHook | None = None, + response_hook: ResponseHook | None = None, +): def _traced_execute_pipeline( func: Callable[..., R], instance: PipelineInstance, @@ -284,90 +387,14 @@ def _traced_execute_pipeline( return response - def _add_create_attributes(span: Span, args: tuple[Any, ...]): - _set_span_attribute_if_value( - span, "redis.create_index.index", _value_or_none(args, 1) - ) - # According to: https://github.com/redis/redis-py/blob/master/redis/commands/search/commands.py#L155 schema is last argument for execute command - try: - schema_index = args.index("SCHEMA") - except ValueError: - return - schema = args[schema_index:] - field_attribute = "" - # Schema in format: - # [first_field_name, first_field_type, first_field_some_attribute1, first_field_some_attribute2, second_field_name, ...] 
- field_attribute = "".join( - f"Field(name: {schema[index - 1]}, type: {schema[index]});" - for index in range(1, len(schema)) - if schema[index] in _FIELD_TYPES - ) - _set_span_attribute_if_value( - span, - "redis.create_index.fields", - field_attribute, - ) + return _traced_execute_pipeline - def _add_search_attributes(span: Span, response, args): - _set_span_attribute_if_value( - span, "redis.search.index", _value_or_none(args, 1) - ) - _set_span_attribute_if_value( - span, "redis.search.query", _value_or_none(args, 2) - ) - # Parse response from search - # https://redis.io/docs/latest/commands/ft.search/ - # Response in format: - # [number_of_returned_documents, index_of_first_returned_doc, first_doc(as a list), index_of_second_returned_doc, second_doc(as a list) ...] - # Returned documents in array format: - # [first_field_name, first_field_value, second_field_name, second_field_value ...] - number_of_returned_documents = _value_or_none(response, 0) - _set_span_attribute_if_value( - span, "redis.search.total", number_of_returned_documents - ) - if "NOCONTENT" in args or not number_of_returned_documents: - return - for document_number in range(number_of_returned_documents): - document_index = _value_or_none(response, 1 + 2 * document_number) - if document_index: - document = response[2 + 2 * document_number] - for attribute_name_index in range(0, len(document), 2): - _set_span_attribute_if_value( - span, - f"redis.search.xdoc_{document_index}.{document[attribute_name_index]}", - document[attribute_name_index + 1], - ) - - pipeline_class = ( - "BasePipeline" if redis.VERSION < (3, 0, 0) else "Pipeline" - ) - redis_class = "StrictRedis" if redis.VERSION < (3, 0, 0) else "Redis" - - wrap_function_wrapper( - "redis", f"{redis_class}.execute_command", _traced_execute_command - ) - wrap_function_wrapper( - "redis.client", - f"{pipeline_class}.execute", - _traced_execute_pipeline, - ) - wrap_function_wrapper( - "redis.client", - f"{pipeline_class}.immediate_execute_command", - _traced_execute_command, - ) - if redis.VERSION >= _REDIS_CLUSTER_VERSION: - wrap_function_wrapper( - "redis.cluster", - "RedisCluster.execute_command", - _traced_execute_command, - ) - wrap_function_wrapper( - "redis.cluster", - "ClusterPipeline.execute", - _traced_execute_pipeline, - ) +def _async_traced_execute_factory( + tracer: Tracer, + request_hook: RequestHook | None = None, + response_hook: ResponseHook | None = None, +): async def _async_traced_execute_command( func: Callable[..., Awaitable[R]], instance: AsyncRedisInstance, @@ -391,6 +418,14 @@ async def _async_traced_execute_command( response_hook(span, instance, response) return response + return _async_traced_execute_command + + +def _async_traced_execute_pipeline_factory( + tracer: Tracer, + request_hook: RequestHook | None = None, + response_hook: ResponseHook | None = None, +): async def _async_traced_execute_pipeline( func: Callable[..., Awaitable[R]], instance: AsyncPipelineInstance, @@ -430,7 +465,56 @@ async def _async_traced_execute_pipeline( return response - if redis.VERSION >= _REDIS_ASYNCIO_VERSION: + return _async_traced_execute_pipeline + + +# pylint: disable=R0915 +def _instrument( + tracer: Tracer, + request_hook: RequestHook | None = None, + response_hook: ResponseHook | None = None, +): + _traced_execute_command = _traced_execute_factory( + tracer, request_hook, response_hook + ) + _traced_execute_pipeline = _traced_execute_pipeline_factory( + tracer, request_hook, response_hook + ) + pipeline_class = "BasePipeline" if 
_CLIENT_BEFORE_3_0_0 else "Pipeline" + redis_class = "StrictRedis" if _CLIENT_BEFORE_3_0_0 else "Redis" + + wrap_function_wrapper( + "redis", f"{redis_class}.execute_command", _traced_execute_command + ) + wrap_function_wrapper( + "redis.client", + f"{pipeline_class}.execute", + _traced_execute_pipeline, + ) + wrap_function_wrapper( + "redis.client", + f"{pipeline_class}.immediate_execute_command", + _traced_execute_command, + ) + if _CLIENT_CLUSTER_SUPPORT: + wrap_function_wrapper( + "redis.cluster", + "RedisCluster.execute_command", + _traced_execute_command, + ) + wrap_function_wrapper( + "redis.cluster", + "ClusterPipeline.execute", + _traced_execute_pipeline, + ) + + _async_traced_execute_command = _async_traced_execute_factory( + tracer, request_hook, response_hook + ) + _async_traced_execute_pipeline = _async_traced_execute_pipeline_factory( + tracer, request_hook, response_hook + ) + if _CLIENT_ASYNCIO_SUPPORT: wrap_function_wrapper( "redis.asyncio", f"{redis_class}.execute_command", @@ -446,7 +530,7 @@ async def _async_traced_execute_pipeline( f"{pipeline_class}.immediate_execute_command", _async_traced_execute_command, ) - if redis.VERSION >= _REDIS_ASYNCIO_CLUSTER_VERSION: + if _CLIENT_ASYNCIO_CLUSTER_SUPPORT: wrap_function_wrapper( "redis.asyncio.cluster", "RedisCluster.execute_command", @@ -459,14 +543,124 @@ async def _async_traced_execute_pipeline( ) +def _instrument_client( + client, + tracer: Tracer, + request_hook: RequestHook | None = None, + response_hook: ResponseHook | None = None, +): + # first, handle async clients and cluster clients + _async_traced_execute = _async_traced_execute_factory( + tracer, request_hook, response_hook + ) + _async_traced_execute_pipeline = _async_traced_execute_pipeline_factory( + tracer, request_hook, response_hook + ) + + if _CLIENT_ASYNCIO_SUPPORT and isinstance(client, redis.asyncio.Redis): + + def _async_pipeline_wrapper(func, instance, args, kwargs): + result = func(*args, **kwargs) + wrap_function_wrapper( + result, "execute", _async_traced_execute_pipeline + ) + wrap_function_wrapper( + result, "immediate_execute_command", _async_traced_execute + ) + return result + + wrap_function_wrapper(client, "execute_command", _async_traced_execute) + wrap_function_wrapper(client, "pipeline", _async_pipeline_wrapper) + return + + if _CLIENT_ASYNCIO_CLUSTER_SUPPORT and isinstance( + client, redis.asyncio.RedisCluster + ): + + def _async_cluster_pipeline_wrapper(func, instance, args, kwargs): + result = func(*args, **kwargs) + wrap_function_wrapper( + result, "execute", _async_traced_execute_pipeline + ) + return result + + wrap_function_wrapper(client, "execute_command", _async_traced_execute) + wrap_function_wrapper( + client, "pipeline", _async_cluster_pipeline_wrapper + ) + return + # for redis.client.Redis, redis.Cluster and v3.0.0 redis.client.StrictRedis + # the wrappers are the same + _traced_execute = _traced_execute_factory( + tracer, request_hook, response_hook + ) + _traced_execute_pipeline = _traced_execute_pipeline_factory( + tracer, request_hook, response_hook + ) + + def _pipeline_wrapper(func, instance, args, kwargs): + result = func(*args, **kwargs) + wrap_function_wrapper(result, "execute", _traced_execute_pipeline) + wrap_function_wrapper( + result, "immediate_execute_command", _traced_execute + ) + return result + + wrap_function_wrapper( + client, + "execute_command", + _traced_execute, + ) + wrap_function_wrapper( + client, + "pipeline", + _pipeline_wrapper, + ) + + class RedisInstrumentor(BaseInstrumentor): - """An 
instrumentor for Redis. + @staticmethod + def _get_tracer(**kwargs): + tracer_provider = kwargs.get("tracer_provider") + return get_tracer( + __name__, + __version__, + tracer_provider=tracer_provider, + schema_url="https://opentelemetry.io/schemas/1.11.0", + ) - See `BaseInstrumentor` - """ + def instrument( + self, + tracer_provider: TracerProvider | None = None, + request_hook: RequestHook | None = None, + response_hook: ResponseHook | None = None, + **kwargs, + ): + """Instruments all Redis/StrictRedis/RedisCluster and async client instances. - def instrumentation_dependencies(self) -> Collection[str]: - return _instruments + Args: + tracer_provider: A TracerProvider, defaults to global. + request_hook: + a function with extra user-defined logic to run before performing the request. + + The ``args`` is a tuple, where items are + command arguments. For example ``client.set("mykey", "value", ex=5)`` would + have ``args`` as ``('SET', 'mykey', 'value', 'EX', 5)``. + + The ``kwargs`` represents occasional ``options`` passed by redis. For example, + if you use ``client.set("mykey", "value", get=True)``, the ``kwargs`` would be + ``{'get': True}``. + response_hook: + a function with extra user-defined logic to run after the request is complete. + + The ``args`` represents the response. + """ + super().instrument( + tracer_provider=tracer_provider, + request_hook=request_hook, + response_hook=response_hook, + **kwargs, + ) def _instrument(self, **kwargs: Any): """Instruments the redis module @@ -476,21 +670,14 @@ def _instrument(self, **kwargs: Any): ``tracer_provider``: a TracerProvider, defaults to global. ``response_hook``: An optional callback which is invoked right before the span is finished processing a response. """ - tracer_provider = kwargs.get("tracer_provider") - tracer = trace.get_tracer( - __name__, - __version__, - tracer_provider=tracer_provider, - schema_url="https://opentelemetry.io/schemas/1.11.0", - ) _instrument( - tracer, + self._get_tracer(**kwargs), request_hook=kwargs.get("request_hook"), response_hook=kwargs.get("response_hook"), ) def _uninstrument(self, **kwargs: Any): - if redis.VERSION < (3, 0, 0): + if _CLIENT_BEFORE_3_0_0: unwrap(redis.StrictRedis, "execute_command") unwrap(redis.StrictRedis, "pipeline") unwrap(redis.Redis, "pipeline") @@ -507,14 +694,93 @@ def _uninstrument(self, **kwargs: Any): unwrap(redis.Redis, "pipeline") unwrap(redis.client.Pipeline, "execute") unwrap(redis.client.Pipeline, "immediate_execute_command") - if redis.VERSION >= _REDIS_CLUSTER_VERSION: + if _CLIENT_CLUSTER_SUPPORT: unwrap(redis.cluster.RedisCluster, "execute_command") unwrap(redis.cluster.ClusterPipeline, "execute") - if redis.VERSION >= _REDIS_ASYNCIO_VERSION: + if _CLIENT_ASYNCIO_SUPPORT: unwrap(redis.asyncio.Redis, "execute_command") unwrap(redis.asyncio.Redis, "pipeline") unwrap(redis.asyncio.client.Pipeline, "execute") unwrap(redis.asyncio.client.Pipeline, "immediate_execute_command") - if redis.VERSION >= _REDIS_ASYNCIO_CLUSTER_VERSION: + if _CLIENT_ASYNCIO_CLUSTER_SUPPORT: unwrap(redis.asyncio.cluster.RedisCluster, "execute_command") unwrap(redis.asyncio.cluster.ClusterPipeline, "execute") + + @staticmethod + def instrument_client( + client: redis.StrictRedis + | redis.Redis + | redis.asyncio.Redis + | redis.cluster.RedisCluster + | redis.asyncio.cluster.RedisCluster, + tracer_provider: TracerProvider | None = None, + request_hook: RequestHook | None = None, + response_hook: ResponseHook | None = None, + ): + """Instrument the provided Redis Client. 
The client can be sync or async. + Cluster client is also supported. + + Args: + client: The redis client. + tracer_provider: A TracerProvider, defaults to global. + request_hook: a function with extra user-defined logic to run before + performing the request. + + The ``args`` is a tuple, where items are + command arguments. For example ``client.set("mykey", "value", ex=5)`` would + have ``args`` as ``('SET', 'mykey', 'value', 'EX', 5)``. + + The ``kwargs`` represents occasional ``options`` passed by redis. For example, + if you use ``client.set("mykey", "value", get=True)``, the ``kwargs`` would be + ``{'get': True}``. + + response_hook: a function with extra user-defined logic to run after + the request is complete. + + The ``args`` represents the response. + """ + if not hasattr(client, INSTRUMENTATION_ATTR): + setattr(client, INSTRUMENTATION_ATTR, False) + if not getattr(client, INSTRUMENTATION_ATTR): + _instrument_client( + client, + RedisInstrumentor._get_tracer(tracer_provider=tracer_provider), + request_hook=request_hook, + response_hook=response_hook, + ) + setattr(client, INSTRUMENTATION_ATTR, True) + else: + _logger.warning( + "Attempting to instrument Redis connection while already instrumented" + ) + + @staticmethod + def uninstrument_client( + client: redis.StrictRedis + | redis.Redis + | redis.asyncio.Redis + | redis.cluster.RedisCluster + | redis.asyncio.cluster.RedisCluster, + ): + """Disables instrumentation for the given client instance + + Args: + client: The redis client + """ + if getattr(client, INSTRUMENTATION_ATTR): + # for all clients we need to unwrap execute_command and pipeline functions + unwrap(client, "execute_command") + # the pipeline method was wrapped to instrument the pipelines it creates; + # any pipelines created before un-instrumenting will + # remain instrumented (pipelines are usually short-lived) + unwrap(client, "pipeline") + else: + _logger.warning( + "Attempting to un-instrument Redis connection that wasn't instrumented" + ) + + def instrumentation_dependencies(self) -> Collection[str]: + """Return a list of python packages with versions that will be instrumented.""" + return _instruments diff --git a/instrumentation/opentelemetry-instrumentation-redis/tests/test_redis.py b/instrumentation/opentelemetry-instrumentation-redis/tests/test_redis.py index c436589adb..098b44737e 100644 --- a/instrumentation/opentelemetry-instrumentation-redis/tests/test_redis.py +++ b/instrumentation/opentelemetry-instrumentation-redis/tests/test_redis.py @@ -389,25 +389,66 @@ def redis_operations(): class TestRedisAsync(TestBase, IsolatedAsyncioTestCase): def setUp(self): super().setUp() - RedisInstrumentor().instrument(tracer_provider=self.tracer_provider) + self.instrumentor = RedisInstrumentor() + self.client = FakeRedis() + + async def _redis_pipeline_operations(self, client): + with pytest.raises(WatchError): + async with client.pipeline(transaction=False) as pipe: + await pipe.watch("a") + await client.set("a", "bad") + pipe.multi() + await pipe.set("a", "1") + await pipe.execute() - def tearDown(self): - super().tearDown() + @pytest.mark.asyncio + async def test_watch_error_async(self): + # this test also ensures the response_hook is called + response_attr = "my.response.attribute" + count = 0 + + def response_hook(span, conn, args): + nonlocal count + if span and span.is_recording(): + span.set_attribute(response_attr, count) + count += 1 + + self.instrumentor.instrument( + tracer_provider=self.tracer_provider, 
response_hook=response_hook + ) + redis_client = FakeRedis() + await self._redis_pipeline_operations(redis_client) + + # there should be 3 spans: we start a watch operation and have 2 set operations on the same key + spans = self.assert_span_count(3) + + self.assertEqual(spans[0].attributes.get("db.statement"), "WATCH ?") + self.assertEqual(spans[0].kind, SpanKind.CLIENT) + self.assertEqual(spans[0].status.status_code, trace.StatusCode.UNSET) + self.assertEqual(spans[0].attributes.get(response_attr), 0) + + for span_index, span in enumerate(spans[1:], 1): + self.assertEqual(span.attributes.get("db.statement"), "SET ? ?") + self.assertEqual(span.kind, SpanKind.CLIENT) + self.assertEqual(span.status.status_code, trace.StatusCode.UNSET) + self.assertEqual(span.attributes.get(response_attr), span_index) RedisInstrumentor().uninstrument() @pytest.mark.asyncio - async def test_watch_error_async(self): - async def redis_operations(): - with pytest.raises(WatchError): - redis_client = FakeRedis() - async with redis_client.pipeline(transaction=False) as pipe: - await pipe.watch("a") - await redis_client.set("a", "bad") - pipe.multi() - await pipe.set("a", "1") - pipe.execute() + async def test_watch_error_async_only_client(self): + self.instrumentor.instrument_client( + tracer_provider=self.tracer_provider, client=self.client + ) + redis_client = FakeRedis() + await self._redis_pipeline_operations(redis_client) + + spans = self.memory_exporter.get_finished_spans() + + # the fresh client is not instrumented, so no spans should be reported + self.assertEqual(len(spans), 0) - await redis_operations() + # now with the instrumented client we should get proper spans + await self._redis_pipeline_operations(self.client) spans = self.memory_exporter.get_finished_spans() @@ -418,6 +459,130 @@ async def redis_operations(): self.assertEqual(spans[0].kind, SpanKind.CLIENT) self.assertEqual(spans[0].status.status_code, trace.StatusCode.UNSET) + for span in spans[1:]: + self.assertEqual(span.attributes.get("db.statement"), "SET ? 
?") + self.assertEqual(span.kind, SpanKind.CLIENT) + self.assertEqual(span.status.status_code, trace.StatusCode.UNSET) + RedisInstrumentor().uninstrument_client(self.client) + + @pytest.mark.asyncio + async def test_request_response_hooks(self): + request_attr = "my.request.attribute" + response_attr = "my.response.attribute" + + def request_hook(span, conn, args, kwargs): + if span and span.is_recording(): + span.set_attribute(request_attr, args[0]) + + def response_hook(span, conn, args): + if span and span.is_recording(): + span.set_attribute(response_attr, args) + + self.instrumentor.instrument( + tracer_provider=self.tracer_provider, + request_hook=request_hook, + response_hook=response_hook, + ) + await self.client.set("key", "value") + + spans = self.assert_span_count(1) + + span = spans[0] + self.assertEqual(span.attributes.get(request_attr), "SET") + self.assertEqual(span.attributes.get(response_attr), True) + self.instrumentor.uninstrument() + + @pytest.mark.asyncio + async def test_request_response_hooks_connection_only(self): + request_attr = "my.request.attribute" + response_attr = "my.response.attribute" + + def request_hook(span, conn, args, kwargs): + if span and span.is_recording(): + span.set_attribute(request_attr, args[0]) + + def response_hook(span, conn, args): + if span and span.is_recording(): + span.set_attribute(response_attr, args) + + self.instrumentor.instrument_client( + client=self.client, + tracer_provider=self.tracer_provider, + request_hook=request_hook, + response_hook=response_hook, + ) + await self.client.set("key", "value") + + spans = self.assert_span_count(1) + + span = spans[0] + self.assertEqual(span.attributes.get(request_attr), "SET") + self.assertEqual(span.attributes.get(response_attr), True) + # fresh client should not record any spans + fresh_client = FakeRedis() + self.memory_exporter.clear() + await fresh_client.set("key", "value") + self.assert_span_count(0) + self.instrumentor.uninstrument_client(self.client) + # after un-instrumenting the query should not be recorder + await self.client.set("key", "value") + spans = self.assert_span_count(0) + + +class TestRedisInstance(TestBase): + def setUp(self): + super().setUp() + self.client = fakeredis.FakeStrictRedis() + RedisInstrumentor().instrument_client( + client=self.client, tracer_provider=self.tracer_provider + ) + + def tearDown(self): + super().tearDown() + RedisInstrumentor().uninstrument_client(self.client) + + def test_only_client_instrumented(self): + redis_client = redis.Redis() + + with mock.patch.object(redis_client, "connection"): + redis_client.get("key") + + spans = self.assert_span_count(0) + + # now use the test client + with mock.patch.object(self.client, "connection"): + self.client.get("key") + spans = self.assert_span_count(1) + span = spans[0] + self.assertEqual(span.name, "GET") + self.assertEqual(span.kind, SpanKind.CLIENT) + + @staticmethod + def redis_operations(client): + with pytest.raises(WatchError): + pipe = client.pipeline(transaction=True) + pipe.watch("a") + client.set("a", "bad") # This will cause the WatchError + pipe.multi() + pipe.set("a", "1") + pipe.execute() + + def test_watch_error_sync_only_client(self): + redis_client = fakeredis.FakeStrictRedis() + + self.redis_operations(redis_client) + + self.assert_span_count(0) + + self.redis_operations(self.client) + + # there should be 3 tests, we start watch operation and have 2 set operation on same key + spans = self.assert_span_count(3) + + self.assertEqual(spans[0].attributes.get("db.statement"), "WATCH 
?") + self.assertEqual(spans[0].kind, SpanKind.CLIENT) + self.assertEqual(spans[0].status.status_code, trace.StatusCode.UNSET) + for span in spans[1:]: self.assertEqual(span.attributes.get("db.statement"), "SET ? ?") self.assertEqual(span.kind, SpanKind.CLIENT)