Skip to content

Commit

Permalink
Add trace id propagation for constant trace id and from request (#111)
Browse files Browse the repository at this point in the history
Signed-off-by: hansrajr <[email protected]>
  • Loading branch information
Hansrajr authored Dec 24, 2024
1 parent d5ab8d1 commit a0ece17
Show file tree
Hide file tree
Showing 6 changed files with 223 additions and 20 deletions.
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,9 @@ dev = [
'requests-aws4auth==1.2.3',
'opensearch-haystack==1.2.0',
'langchainhub==0.1.21',
'chromadb==0.4.22'
'chromadb==0.4.22',
'flask',
'opentelemetry-instrumentation-flask'
]

azure = [
Expand Down
75 changes: 70 additions & 5 deletions src/monocle_apptrace/instrumentation/common/instrumentor.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
import logging
from typing import Collection, Dict, List, Union

import random
import uuid
from opentelemetry import trace
from opentelemetry.context import attach, get_value, set_value
from opentelemetry.context import attach, get_value, set_value, get_current, detach
from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
from opentelemetry.instrumentation.utils import unwrap
from opentelemetry.trace import SpanContext
from opentelemetry.sdk.trace import TracerProvider, Span, id_generator
from opentelemetry.sdk.resources import SERVICE_NAME, Resource
from opentelemetry.sdk.trace import Span, TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor, SpanProcessor
from opentelemetry.trace import get_tracer
from wrapt import wrap_function_wrapper

from opentelemetry.trace.propagation import set_span_in_context
from monocle_apptrace.exporters.monocle_exporters import get_monocle_exporter
from monocle_apptrace.instrumentation.common.span_handler import SpanHandler
from monocle_apptrace.instrumentation.common.wrapper_method import (
Expand All @@ -24,6 +27,10 @@

_instruments = ()

# Tracer provider received by MonocleInstrumentor._instrument(). It is kept at
# module level so propagate_trace_id() can request a tracer from the same
# provider. It is None until _instrument() assigns it.
monocle_tracer_provider: Union[TracerProvider, None] = None

# Instrumenting-module name passed to get_tracer() for Monocle tracers.
MONOCLE_INSTRUMENTOR = "monocle_apptrace"

class MonocleInstrumentor(BaseInstrumentor):
workflow_name: str = ""
user_wrapper_methods: list[Union[dict,WrapperMethod]] = []
Expand All @@ -47,7 +54,9 @@ def instrumentation_dependencies(self) -> Collection[str]:

def _instrument(self, **kwargs):
tracer_provider: TracerProvider = kwargs.get("tracer_provider")
tracer = get_tracer(instrumenting_module_name="monocle_apptrace", tracer_provider=tracer_provider)
global monocle_tracer_provider
monocle_tracer_provider = tracer_provider
tracer = get_tracer(instrumenting_module_name=MONOCLE_INSTRUMENTOR, tracer_provider=tracer_provider)

final_method_list = []
if self.union_with_default_methods is True:
Expand Down Expand Up @@ -136,4 +145,60 @@ def on_processor_start(span: Span, parent_context):
)

# Store `properties` in the OpenTelemetry context under SESSION_PROPERTIES_KEY
# and make the resulting context current via attach(). The token returned by
# attach() is discarded, so this function gives callers no handle to detach.
# NOTE(review): the attach line appears twice below; this looks like the
# removed/added pair of a rendered diff rather than two intended calls - confirm.
def set_context_properties(properties: dict) -> None:
attach(set_value(SESSION_PROPERTIES_KEY, properties))
attach(set_value(SESSION_PROPERTIES_KEY, properties))


def propagate_trace_id(traceId = "", use_trace_context = False):
    """Start a placeholder parent span and attach it as the current context.

    A short-lived span named "parent_placeholder_span" is started, stored in a
    new context together with its span id under the "root_span_id" key, and
    that context is attached. The placeholder span is ended immediately.

    Args:
        traceId: Optional trace id as a 32-digit hex string (anything accepted
            by ``uuid.UUID``), with or without a leading "0x". When it is
            valid, the tracer's id generator is temporarily replaced by a
            FixedIdGenerator for that id while the placeholder span is started.
        use_trace_context: When True, the current OpenTelemetry context is
            passed as the parent context of the placeholder span; otherwise no
            explicit context is passed.

    Returns:
        The token returned by ``attach`` (pass it to stop_propagate_trace_id),
        or None if propagation failed.
    """
    try:
        # Remove only the literal "0x" prefix. str.lstrip("0x") strips a *set*
        # of characters, so it would also eat leading zeros of the id itself
        # ("0x0af7..." -> "af7...") and make a valid id fail validation.
        if traceId and traceId.startswith("0x"):
            traceId = traceId[2:]
        tracer = get_tracer(
            instrumenting_module_name=MONOCLE_INSTRUMENTOR,
            tracer_provider=monocle_tracer_provider,
        )
        initial_id_generator = tracer.id_generator
        _parent_span_context = get_current() if use_trace_context else None
        try:
            if traceId and is_valid_trace_id_uuid(traceId):
                tracer.id_generator = FixedIdGenerator(uuid.UUID(traceId).int)
            span = tracer.start_span(
                name="parent_placeholder_span", context=_parent_span_context
            )
        finally:
            # Always restore the original generator, even if start_span raised,
            # so a failure cannot leave the fixed trace id installed.
            tracer.id_generator = initial_id_generator

        updated_span_context = set_span_in_context(span=span, context=_parent_span_context)
        updated_span_context = set_value(
            "root_span_id", span.get_span_context().span_id, updated_span_context
        )
        token = attach(updated_span_context)

        span.end()
        return token
    except Exception:
        # Best effort: tracing problems must not break the instrumented app.
        logger.warning("Failed to propagate trace id")
        return None


def propagate_trace_id_from_traceparent():
    """Propagate the trace id using the current OpenTelemetry context as parent.

    Returns:
        Whatever propagate_trace_id returns: the context token to hand to
        stop_propagate_trace_id, or None if propagation failed. Returning it
        lets the caller detach the context that was attached.
    """
    return propagate_trace_id(use_trace_context=True)


def stop_propagate_trace_id(token) -> None:
    """Detach the context that a propagate_trace_id call attached.

    Args:
        token: The token returned by propagate_trace_id. ``None`` (returned
            when propagation failed) is accepted and treated as a no-op.
    """
    if token is None:
        # Nothing was attached, so there is nothing to detach.
        return
    try:
        detach(token)
    except Exception:
        # Best effort: a failed detach must not break the instrumented app.
        logger.warning("Failed to stop propagating trace id")


def is_valid_trace_id_uuid(traceId: str) -> bool:
    """Return True if *traceId* can be parsed by ``uuid.UUID``.

    Args:
        traceId: Candidate trace id, e.g. a 32-digit hex string.

    Returns:
        True when ``uuid.UUID(traceId)`` succeeds, False for malformed strings
        and for non-string input such as ``None``.
    """
    try:
        uuid.UUID(traceId)
    except (ValueError, TypeError, AttributeError):
        # ValueError: malformed hex string; TypeError/AttributeError: the
        # argument is None or not a string.
        return False
    return True


class FixedIdGenerator(id_generator.IdGenerator):
    """Id generator that always returns one fixed trace id.

    Span ids are still random, so each span started with this generator gets
    its own id while sharing the configured trace id.
    """

    def __init__(self, trace_id: int) -> None:
        """Remember *trace_id*, the integer returned by generate_trace_id()."""
        self.trace_id = trace_id

    def generate_span_id(self) -> int:
        """Return a random, non-zero 64-bit span id."""
        span_id = random.getrandbits(64)
        # 0 is OpenTelemetry's invalid span id; redraw in the (very unlikely)
        # case that the random draw produces it.
        while span_id == 0:
            span_id = random.getrandbits(64)
        return span_id

    def generate_trace_id(self) -> int:
        """Return the fixed trace id given at construction time."""
        return self.trace_id
8 changes: 6 additions & 2 deletions src/monocle_apptrace/instrumentation/common/span_handler.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging
import os
from importlib.metadata import version

from opentelemetry.context import get_current
from opentelemetry.context import get_value
from opentelemetry.sdk.trace import Span

Expand Down Expand Up @@ -147,4 +147,8 @@ def get_workflow_name(self, span: Span) -> str:
return None

def __is_root_span(self, curr_span: Span) -> bool:
    """Tell whether *curr_span* should be treated as a root span.

    A span counts as root when it has no parent, or when its parent's span id
    equals the "root_span_id" value stored in the current context.

    Returns:
        True for a root span; False otherwise, including when *curr_span* is
        None, has no ``parent`` attribute, or the lookup raises.
    """
    try:
        if curr_span is not None and hasattr(curr_span, "parent"):
            parent = curr_span.parent
            return parent is None or get_current().get("root_span_id") == parent.span_id
    except Exception as e:
        logger.warning(f"Error finding root span: {e}")
    # Explicit boolean instead of an implicit None on the fall-through paths.
    return False
17 changes: 7 additions & 10 deletions src/monocle_apptrace/instrumentation/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,25 +41,22 @@ def with_tracer_wrapper(func):

# Binds tracer, handler and to_wrap, and returns a
# wrapper(wrapped, instance, args, kwargs) callable that forwards all of them
# to `func`.
# NOTE(review): this span comes from a rendered diff, so removed and added
# lines appear together (e.g. both `is_invalid_span` and `is_span`, and the
# token attach/detach handling) - confirm against the actual file.
def _with_tracer(tracer, handler, to_wrap):
def wrapper(wrapped, instance, args, kwargs):
token = None
try:
# get and log the parent span context if injected by the application
# This is useful for debugging and tracing of Azure functions
_parent_span_context = get_current()
if _parent_span_context is not None and _parent_span_context.get(_SPAN_KEY, None):
parent_span: Span = _parent_span_context.get(_SPAN_KEY, None)
is_invalid_span = isinstance(parent_span, NonRecordingSpan)
if is_invalid_span:
token = attach(context={})
is_span = isinstance(parent_span, NonRecordingSpan)
if is_span:
logger.debug(
f"Parent span is found with trace id {hex(parent_span.get_span_context().trace_id)}")
except Exception as e:
# Errors while inspecting the parent context are only logged; the wrapped
# call below still runs.
logger.error("Exception in attaching parent context: %s", e)

# Delegate to the decorated function, passing the bound tracer objects
# followed by the arguments of the intercepted call.
val = func(tracer, handler, to_wrap, wrapped, instance, args, kwargs)
# Detach the token if it was set
if token:
try:
detach(token=token)
except Exception as e:
logger.error("Exception in detaching parent context: %s", e)
return val

return wrapper

return _with_tracer
Expand Down
4 changes: 2 additions & 2 deletions tests/integration/test_haystack_opensearch_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,8 @@ def test_haystack_opensearch_sample(setup):
span_attributes = span.attributes
if "span.type" in span_attributes and span_attributes["span.type"] == "retrieval":
# Assertions for all retrieval attributes
assert span_attributes["entity.1.name"] == "OpenSearchVectorSearch"
assert span_attributes["entity.1.type"] == "vectorstore.OpenSearchVectorSearch"
assert span_attributes["entity.1.name"] == "OpenSearchDocumentStore"
assert span_attributes["entity.1.type"] == "vectorstore.OpenSearchDocumentStore"
assert "entity.1.deployment" in span_attributes
assert span_attributes["entity.2.name"] == "sentence-transformers/all-mpnet-base-v2"
assert span_attributes["entity.2.type"] == "model.embedding.sentence-transformers/all-mpnet-base-v2"
Expand Down
135 changes: 135 additions & 0 deletions tests/unit/flask_instrumented_app_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
import unittest
import threading
import time
from flask import Flask
from opentelemetry.instrumentation.flask import FlaskInstrumentor
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
from monocle_apptrace.instrumentation.common.instrumentor import setup_monocle_telemetry
from monocle_apptrace.instrumentation.common.wrapper_method import WrapperMethod
from common.dummy_class import DummyClass, dummy_wrapper
from common.custom_exporter import CustomConsoleSpanExporter
from opentelemetry import trace

class TestValidateResponseMultithreaded(unittest.TestCase):
    """Check trace-id propagation from a ``traceparent`` header under threads.

    Several threads send the same request, carrying one ``traceparent``
    header, to an instrumented Flask app. Every captured span is expected to
    carry the header's trace id, and the "GET /" spans are expected to have
    the header's parent id as their parent span id.
    """

    def setUp(self):
        """Create the instrumented Flask app and set up Monocle telemetry."""
        self.app = Flask(__name__)
        FlaskInstrumentor().instrument_app(app=self.app, enable_commenter=True, commenter_options={})

        @self.app.route("/")
        def hello_world():
            # Calling the wrapped dummy method emits the "langchain.workflow" span.
            dummy_class_1 = DummyClass()
            dummy_class_1.dummy_method()
            tracer = trace.get_tracer(__name__)
            with tracer.start_as_current_span("child_span"):
                return "<p>Hello, World!</p>"

        # Set up telemetry
        self.capturing_exporter = CustomConsoleSpanExporter()
        span_processor = SimpleSpanProcessor(self.capturing_exporter)
        self.instrumentor = setup_monocle_telemetry(
            workflow_name="test_1",
            span_processors=[span_processor],
            wrapper_methods=[
                WrapperMethod(
                    package="common.dummy_class",
                    object_name="DummyClass",
                    method="dummy_method",
                    span_name="langchain.workflow",
                    wrapper_method=dummy_wrapper,
                )
            ],
        )

    def tearDown(self) -> None:
        """Uninstrument Monocle; failures are printed, not raised."""
        try:
            if self.instrumentor is not None:
                self.instrumentor.uninstrument()
        except Exception as e:
            print("Uninstrument failed:", e)
        return super().tearDown()

    @staticmethod
    def send_request(app, headers, captured_responses):
        """Send one GET / through the Flask test client and record the response."""
        client = app.test_client()
        response = client.get("/", headers=headers)
        captured_responses.append(response)
        # NOTE(review): the reason for this pause is not evident from this file;
        # presumably it gives span export time to finish - confirm.
        time.sleep(3)

    def test_validate_response_multithreaded(self):
        """Concurrent requests keep the incoming trace id on all their spans."""
        trace_id = "0af7651916cd43dd8448eb211c80319c"
        parent_id = "b7ad6b7169203331"
        traceparent = f"00-{trace_id}-{parent_id}-01"
        headers = {"traceparent": traceparent}

        num_threads = 3
        captured_responses = []

        # Create threads to send requests
        threads = [
            threading.Thread(target=self.send_request, args=(self.app, headers, captured_responses))
            for _ in range(num_threads)
        ]
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()

        # Validate the responses
        for response in captured_responses:
            self.assertEqual(response.status_code, 200)
            self.assertEqual(response.data.decode("utf-8"), "<p>Hello, World!</p>")
            self.assertEqual(response.request.headers.get("traceparent"), traceparent)

        # Validate the spans
        spans = self.capturing_exporter.captured_spans
        self.assertGreater(len(spans), 0, "No spans were captured.")

        root_spans = [span for span in spans if span.name == "GET /"]
        child_spans = [span for span in spans if span.name == "langchain.workflow"]
        grand_child_spans = [span for span in spans if span.name == "child_span"]

        for root_span in root_spans:
            # Fail with a clear assertion instead of AttributeError on None.
            self.assertIsNotNone(root_span.parent, "Request span has no parent from traceparent.")
            root_trace_id = f"{root_span.context.trace_id:032x}"
            root_parent_id = f"{root_span.parent.span_id:016x}"
            self.assertEqual(root_trace_id, trace_id)
            self.assertEqual(root_parent_id, parent_id)

        for child_span in child_spans:
            self.assertIsNotNone(child_span.parent, "Workflow span has no parent.")
            child_trace_id = f"{child_span.context.trace_id:032x}"
            child_parent_id = f"{child_span.parent.span_id:016x}"
            matching_root_span = next(
                (root for root in root_spans if f"{root.context.span_id:016x}" == child_parent_id),
                None,
            )
            self.assertIsNotNone(
                matching_root_span, "No request span is the parent of the workflow span."
            )
            matching_root_span_id = f"{matching_root_span.context.span_id:016x}"
            self.assertEqual(child_parent_id, matching_root_span_id)
            self.assertEqual(child_trace_id, trace_id)

        for grandchild_span in grand_child_spans:
            self.assertIsNotNone(grandchild_span.parent, "child_span has no parent.")
            grandchild_trace_id = f"{grandchild_span.context.trace_id:032x}"
            grandchild_parent_id = f"{grandchild_span.parent.span_id:016x}"

            # As in the original check, match on a workflow span that shares
            # the same parent span id as this span.
            matching_child_span = next(
                (child for child in child_spans if f"{child.parent.span_id:016x}" == grandchild_parent_id),
                None,
            )
            self.assertIsNotNone(
                matching_child_span, "No workflow span shares a parent with child_span."
            )
            matching_child_span_id = f"{matching_child_span.parent.span_id:016x}"
            self.assertEqual(grandchild_parent_id, matching_child_span_id)
            self.assertEqual(grandchild_trace_id, trace_id, "Trace IDs do not match.")

        self.assertEqual(len(root_spans), num_threads)
        self.assertEqual(len(child_spans), num_threads)
        self.assertEqual(len(grand_child_spans), num_threads)

# Run the tests in this module when the file is executed directly.
if __name__ == "__main__":
unittest.main()

0 comments on commit a0ece17

Please sign in to comment.