From dcd3e4c73750cab1e0ad276e1b22f45801536965 Mon Sep 17 00:00:00 2001
From: Aaron Abbott <aaronabbott@google.com>
Date: Wed, 15 Jan 2025 23:15:20 +0000
Subject: [PATCH] Add common gen AI utils into opentelemetry-instrumentation

---
 CHANGELOG.md                                  |  3 +
 .../instrumentation/genai_utils.py            | 53 +++++++++++
 .../tests/test_genai_utils.py                 | 91 +++++++++++++++++++
 3 files changed, 147 insertions(+)
 create mode 100644 opentelemetry-instrumentation/src/opentelemetry/instrumentation/genai_utils.py
 create mode 100644 opentelemetry-instrumentation/tests/test_genai_utils.py

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 6e40e73270..4a00dd4cbd 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -11,6 +11,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ## Unreleased
 
+- `opentelemetry-instrumentation` Add common gen AI utils into opentelemetry-instrumentation
+  ([#3188](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/3188))
+
 ### Added
 
 - `opentelemetry-instrumentation-confluent-kafka` Add support for confluent-kafka <=2.7.0
diff --git a/opentelemetry-instrumentation/src/opentelemetry/instrumentation/genai_utils.py b/opentelemetry-instrumentation/src/opentelemetry/instrumentation/genai_utils.py
new file mode 100644
index 0000000000..44ba12da30
--- /dev/null
+++ b/opentelemetry-instrumentation/src/opentelemetry/instrumentation/genai_utils.py
@@ -0,0 +1,53 @@
+# Copyright The OpenTelemetry Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from os import environ
+from typing import Mapping
+
+from opentelemetry.semconv._incubating.attributes import (
+    gen_ai_attributes as GenAIAttributes,
+)
+from opentelemetry.semconv.attributes import (
+    error_attributes as ErrorAttributes,
+)
+from opentelemetry.trace import Span
+from opentelemetry.trace.status import Status, StatusCode
+from opentelemetry.util.types import AttributeValue
+
+OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT = (
+    "OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT"
+)
+
+
+def is_content_enabled() -> bool:
+    capture_content = environ.get(
+        OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT, "false"
+    )
+
+    return capture_content.lower() == "true"
+
+
+def get_span_name(span_attributes: Mapping[str, AttributeValue]) -> str:
+    name = span_attributes.get(GenAIAttributes.GEN_AI_OPERATION_NAME, "")
+    model = span_attributes.get(GenAIAttributes.GEN_AI_REQUEST_MODEL, "")
+    return f"{name} {model}"
+
+
+def handle_span_exception(span: Span, error: Exception) -> None:
+    span.set_status(Status(StatusCode.ERROR, str(error)))
+    if span.is_recording():
+        span.set_attribute(
+            ErrorAttributes.ERROR_TYPE, type(error).__qualname__
+        )
+    span.end()
diff --git a/opentelemetry-instrumentation/tests/test_genai_utils.py b/opentelemetry-instrumentation/tests/test_genai_utils.py
new file mode 100644
index 0000000000..b61d4987e2
--- /dev/null
+++ b/opentelemetry-instrumentation/tests/test_genai_utils.py
@@ -0,0 +1,91 @@
+# Copyright The OpenTelemetry Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from unittest.mock import patch
+
+from opentelemetry.instrumentation.genai_utils import (
+    OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT,
+    get_span_name,
+    handle_span_exception,
+    is_content_enabled,
+)
+from opentelemetry.sdk.trace import ReadableSpan
+from opentelemetry.test.test_base import TestBase
+from opentelemetry.trace.status import StatusCode
+
+
+class MyTestException(Exception):
+    pass
+
+
+class TestGenaiUtils(TestBase):
+    @patch.dict(
+        "os.environ",
+        {},
+    )
+    def test_is_content_enabled_default(self):
+        self.assertFalse(is_content_enabled())
+
+    def test_is_content_enabled_true(self):
+        for env_value in "true", "TRUE", "True", "tRue":
+            with patch.dict(
+                "os.environ",
+                {
+                    OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT: env_value
+                },
+            ):
+                self.assertTrue(is_content_enabled())
+
+    def test_is_content_enabled_false(self):
+        for env_value in "false", "FALSE", "False", "fAlse":
+            with patch.dict(
+                "os.environ",
+                {
+                    OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT: env_value
+                },
+            ):
+                self.assertFalse(is_content_enabled())
+
+    def test_get_span_name(self):
+        span_attributes = {
+            "gen_ai.operation.name": "chat",
+            "gen_ai.request.model": "mymodel",
+        }
+        self.assertEqual(get_span_name(span_attributes), "chat mymodel")
+
+        span_attributes = {
+            "gen_ai.operation.name": "chat",
+        }
+        self.assertEqual(get_span_name(span_attributes), "chat ")
+
+        span_attributes = {
+            "gen_ai.request.model": "mymodel",
+        }
+        self.assertEqual(get_span_name(span_attributes), " mymodel")
+
+        span_attributes = {}
+        self.assertEqual(get_span_name(span_attributes), " ")
+
+    def test_handle_span_exception(self):
+        tracer = self.tracer_provider.get_tracer("test_handle_span_exception")
+        with tracer.start_as_current_span("foo") as span:
+            handle_span_exception(span, MyTestException())
+
+        self.assertEqual(len(self.get_finished_spans()), 1)
+        finished_span: ReadableSpan = self.get_finished_spans()[0]
+        self.assertEqual(finished_span.name, "foo")
+        self.assertIs(finished_span.status.status_code, StatusCode.ERROR)
+        self.assertEqual(
+            finished_span.attributes["error.type"], "MyTestException"
+        )