diff --git a/polytope_server/broker/broker.py b/polytope_server/broker/broker.py index c242b1c..392a460 100644 --- a/polytope_server/broker/broker.py +++ b/polytope_server/broker/broker.py @@ -23,7 +23,7 @@ from ..common import collection, queue, request_store from ..common.request import Status - +from ..common.observability.otel import restore_trace_context, create_new_span_consumer, create_new_span_producer, update_trace_context class Broker: def __init__(self, config): @@ -75,8 +75,11 @@ def check_requests(self): if self.check_limits(active_requests, wr): assert wr.status == Status.WAITING - active_requests.add(wr) - self.enqueue(wr) + # Restore the trace context for this request + extracted_ctx = restore_trace_context(wr) + with create_new_span_consumer("Enqueue request", request_id=wr.id, parent_context=extracted_ctx): + active_requests.add(wr) + self.enqueue(wr) if self.queue.count() >= self.max_queue_size: logging.info("Queue is full") @@ -133,11 +136,14 @@ def enqueue(self, request): logging.info("Queuing request", extra={"request_id": request.id}) try: - # Must update request_store before queue, worker checks request status immediately - request.set_status(Status.QUEUED) - self.request_store.update_request(request) - msg = queue.Message(request.serialize()) - self.queue.enqueue(msg) + with create_new_span_producer("Updating request", request_id=request.id): + # Must update request_store before queue, worker checks request status immediately + request.set_status(Status.QUEUED) + # Updating context for trace ctx propagation with the new span as parent + update_trace_context(request) + self.request_store.update_request(request) + msg = queue.Message(request.serialize()) + self.queue.enqueue(msg) except Exception as e: # If we fail to call this, the request will be stuck (POLY-21) logging.info( diff --git a/polytope_server/common/observability/otel.py b/polytope_server/common/observability/otel.py new file mode 100644 index 0000000..bb7c608 --- /dev/null +++ b/polytope_server/common/observability/otel.py @@ -0,0 +1,120 @@ +from opentelemetry import trace +from opentelemetry.propagate import inject, extract +from opentelemetry.trace import SpanKind, Status, StatusCode +from contextlib import contextmanager +import logging +from typing import Optional, Generator + +def add_or_update_trace_context(request, update: bool = False) -> None: + """ + Injects the current trace context into the request's OpenTelemetry trace context attribute. + + Args: + request: The request object to update. + update: Whether this is an update operation (default: False). + """ + carrier = {} + inject(carrier) + + if not hasattr(request, "otel_trace"): + request.otel_trace = {} + + request.otel_trace['carrier'] = carrier + + action = "Updated" if update else "Added" + logging.debug(f"[OTEL] {action} trace context with carrier: {carrier}") + + # Optionally set additional attributes on the current span + current_span = trace.get_current_span() + current_span.set_attribute("polytope.request.id", request.id) + +def add_trace_context(request) -> None: + """Adds a new trace context to the request.""" + add_or_update_trace_context(request, update=False) + +def update_trace_context(request) -> None: + """Updates the trace context in the request.""" + add_or_update_trace_context(request, update=True) + +def restore_trace_context(request) -> Optional[trace.Span]: + """ + Restores the trace context from the request. + + Args: + request: The request object containing the trace context. + + Returns: + The restored context, or None if not available. + """ + if not hasattr(request, 'otel_trace') or 'carrier' not in request.otel_trace: + logging.debug("[OTEL] No trace context found to restore.") + return None + + carrier = request.otel_trace['carrier'] + logging.debug(f"[OTEL] Restoring context from carrier: {carrier}") + extracted_context = extract(carrier) + + current_span = trace.get_current_span() + current_span.set_attribute("polytope.request.id", request.id) + + return extracted_context + +@contextmanager +def create_new_span( + span_name: str, + request_id: Optional[str] = None, + parent_context: Optional[trace.SpanContext] = None, + kind: SpanKind = SpanKind.SERVER, +) -> Generator[trace.Span, None, None]: + """ + Creates a new span with the specified attributes. + + Args: + span_name: Name of the span. + request_id: Optional request ID to associate with the span. + parent_context: Optional parent span context. + kind: The kind of span to create (default: SERVER). + role: Optional role to set as a span attribute. + + Yields: + The created span. + """ + tracer = trace.get_tracer(__name__) + attributes = {"polytope.request.id": request_id} if request_id else {} + + with tracer.start_as_current_span(span_name, context=parent_context, kind=kind, attributes=attributes) as span: + logging.debug(f"[OTEL] Created new span: {span_name}, parent: {parent_context}") + yield span + +@contextmanager +def create_new_span_internal(span_name: str, request_id: Optional[str] = None, parent_context: Optional[trace.SpanContext] = None) -> Generator[trace.Span, None, None]: + """Creates an internal span.""" + yield from create_new_span(span_name, request_id, parent_context, kind=SpanKind.INTERNAL) + +# Forcing span kind Server because of AWS representation +@contextmanager +def create_new_span_producer(span_name: str, request_id: Optional[str] = None, parent_context: Optional[trace.SpanContext] = None) -> Generator[trace.Span, None, None]: + """Creates a producer span.""" + yield from create_new_span(span_name, request_id, parent_context, kind=SpanKind.SERVER) + +# Forcing span kind Server because of AWS representation +@contextmanager +def create_new_span_consumer(span_name: str, request_id: Optional[str] = None, parent_context: Optional[trace.SpanContext] = None) -> Generator[trace.Span, None, None]: + """Creates a consumer span.""" + yield from create_new_span(span_name, request_id, parent_context, kind=SpanKind.SERVER) + +@contextmanager +def create_new_span_server(span_name: str, request_id: Optional[str] = None, parent_context: Optional[trace.SpanContext] = None) -> Generator[trace.Span, None, None]: + """Creates a server span.""" + yield from create_new_span(span_name, request_id, parent_context, kind=SpanKind.SERVER) + +def set_span_error(span: trace.Span, exception: Exception) -> None: + """ + Marks a span as having an error. + + Args: + span: The span to mark as an error. + exception: The exception to log. + """ + span.set_status(Status(StatusCode.ERROR, str(exception))) + logging.error(f"[OTEL] Span error set with exception: {exception}") diff --git a/polytope_server/common/request.py b/polytope_server/common/request.py index f2fdfb5..788dee2 100644 --- a/polytope_server/common/request.py +++ b/polytope_server/common/request.py @@ -24,7 +24,7 @@ import uuid from .user import User - +from .observability.otel import add_trace_context class Status(enum.Enum): WAITING = "waiting" @@ -57,6 +57,7 @@ class Request: "user_request", "content_length", "content_type", + "otel_trace", ] def __init__(self, from_dict=None, **kwargs): @@ -75,6 +76,9 @@ def __init__(self, from_dict=None, **kwargs): self.content_length = None self.content_type = "application/octet-stream" + # Adding context for OpenTelemetry in asynchronous processing + add_trace_context(self) + if from_dict: self.deserialize(from_dict) diff --git a/polytope_server/worker/worker.py b/polytope_server/worker/worker.py index 9a21d50..f262e7b 100644 --- a/polytope_server/worker/worker.py +++ b/polytope_server/worker/worker.py @@ -33,7 +33,7 @@ from ..common import request_store, staging from ..common.metric import WorkerInfo, WorkerStatusChange from ..common.request import Status - +from ..common.observability.otel import restore_trace_context, create_new_span_consumer, create_new_span_internal, set_span_error class Worker: """The worker: @@ -146,37 +146,41 @@ def run(self): id = self.queue_msg.body["id"] self.request = self.request_store.get_request(id) - # This occurs when a request has been revoked while it was on the queue - if self.request is None: - logging.info( - "Request no longer exists, ignoring", - extra={"request_id": id}, - ) - self.update_status("idle") - self.queue.ack(self.queue_msg) - - # Occurs if a request crashed a worker and the message gets requeued (status will be PROCESSING) - # We do not want to try this request again - elif self.request.status != Status.QUEUED: - logging.info( - "Request has unexpected status {}, setting to failed".format(self.request.status), - extra={"request_id": id}, - ) - self.request.set_status(Status.FAILED) - self.request_store.update_request(self.request) - self.update_status("idle") - self.queue.ack(self.queue_msg) - - # OK, process the request - else: - logging.info( - "Popped request from the queue, beginning worker thread.", - extra={"request_id": id}, - ) - self.request.set_status(Status.PROCESSING) - self.update_status("processing", request_id=self.request.id) - self.request_store.update_request(self.request) - self.future = self.thread_pool.submit(self.process_request, (self.request)) + # Restoring request ctx + extracted_ctx = restore_trace_context(self.request) + # Create a new span for enqueueing the message + with create_new_span_consumer("worker_processing_request", request_id=self.request.id, parent_context=extracted_ctx): + # This occurs when a request has been revoked while it was on the queue + if self.request is None: + logging.info( + "Request no longer exists, ignoring", + extra={"request_id": id}, + ) + self.update_status("idle") + self.queue.ack(self.queue_msg) + + # Occurs if a request crashed a worker and the message gets requeued (status will be PROCESSING) + # We do not want to try this request again + elif self.request.status != Status.QUEUED: + logging.info( + "Request has unexpected status {}, setting to failed".format(self.request.status), + extra={"request_id": id}, + ) + self.request.set_status(Status.FAILED) + self.request_store.update_request(self.request) + self.update_status("idle") + self.queue.ack(self.queue_msg) + + # OK, process the request + else: + logging.info( + "Popped request from the queue, beginning worker thread.", + extra={"request_id": id}, + ) + self.request.set_status(Status.PROCESSING) + self.update_status("processing", request_id=self.request.id) + self.request_store.update_request(self.request) + self.future = self.thread_pool.submit(self.process_request, (self.request)) else: self.update_status("idle") @@ -211,81 +215,105 @@ def run(self): def process_request(self, request): """Entrypoint for the worker thread.""" - id = request.id - collection = self.collections[request.collection] - - logging.info( - "Processing request on collection {}".format(collection.name), - extra={"request_id": id}, - ) - logging.info("Request is: {}".format(request.serialize())) - - input_data = self.fetch_input_data(request.url) + # Creating a new internal span for the full process_request block + with create_new_span_internal("Processing request", request.id) as process_span: + id = request.id + collection = self.collections[request.collection] - # Dispatch to listed datasources for this collection until we find one that handles the request - datasource = None - for ds in collection.datasources(): logging.info( - "Processing request using datasource {}".format(ds.get_type()), + "Processing request on collection {}".format(collection.name), extra={"request_id": id}, ) - if ds.dispatch(request, input_data): - datasource = ds - break + logging.info("Request is: {}".format(request.serialize())) - # Clean up - try: - # delete input data if it exists in staging (input data can come from external URLs too) - if input_data is not None: - if self.staging.query(id): - self.staging.delete(id) - - # upload result data - if datasource is not None: - request.url = self.staging.create(id, datasource.result(request), datasource.mime_type()) - - except Exception as e: - request.user_message += f"Failed to finalize request: [{str(type(e))}] {str(e)}" - logging.info(request.user_message, extra={"request_id": id}) - logging.exception("Failed to finalize request", extra={"request_id": id, "exception": str(e)}) - raise + input_data = self.fetch_input_data(request.url) - # Guarantee destruction of the datasource - finally: - if datasource is not None: - datasource.destroy(request) + # Dispatch to listed datasources for this collection until we find one that handles the request + datasource = None + for ds in collection.datasources(): + logging.info( + "Processing request using datasource {}".format(ds.get_type()), + extra={"request_id": id}, + ) - if datasource is None: - request.user_message += "Failed to process request." - logging.info(request.user_message, extra={"request_id": id}) - raise Exception("Failed to process request.") - else: - request.user_message += "Success" + # Creating new internal span for the datasource dispatch (calculate roundtrip time) + with create_new_span_internal("DataSource_{}".format(ds.get_type())) as span_ds: + if ds.dispatch(request, input_data): + datasource = ds + span_ds.set_attribute("polytope.datasource", ds.get_type()) + break - return + # Clean up + try: + # Creating new internal span for finalizing the request process + with create_new_span_internal("Finalizing request", request_id=id): + # delete input data if it exists in staging (input data can come from external URLs too) + if input_data is not None: + # New span for deleting time block + with create_new_span_internal("Deleting input data", request_id=id): + if self.staging.query(id): + self.staging.delete(id) + + # upload result data + if datasource is not None: + # Span for uploading data + with create_new_span_internal("Uploading result data", request_id=id): + request.url = self.staging.create(id, datasource.result(request), datasource.mime_type()) + # Getting key (name + ext) from url + object_id = id + ("." + request.url.split("/")[-1].split(".")[-1]) + # Getting data size in bytes + _, size = self.staging.stat(object_id) + process_span.set_attributes({ + "polytope.request.url": request.url, + "polytope.request.size": size + }) + + except Exception as e: + request.user_message += f"Failed to finalize request: [{str(type(e))}] {str(e)}" + logging.info(request.user_message, extra={"request_id": id}) + logging.exception("Failed to finalize request", extra={"request_id": id, "exception": str(e)}) + set_span_error(process_span, e) + raise + + # Guarantee destruction of the datasource + finally: + if datasource is not None: + datasource.destroy(request) + + if datasource is None: + request.user_message += "Failed to process request." + logging.info(request.user_message, extra={"request_id": id}) + set_span_error(process_span, Exception(request.user_message)) + raise Exception("Failed to process request.") + else: + request.user_message += "Success" + + return def fetch_input_data(self, url): """Downloads input data from external URL or staging""" - if url != "": - try: - response = requests.get(url, proxies=self.proxies) - response.raise_for_status() - except ( - requests.exceptions.ConnectionError, - requests.exceptions.HTTPError, - ): - logging.info("Retrying requests.get without proxies after failure") - response = requests.get(url) - response.raise_for_status() - - if response.status_code == 200: - logging.info("Downloaded data of size {} from {}".format(sys.getsizeof(response._content), url)) - return response._content - else: - raise Exception( - "Could not download data from {}, got {} : {}".format(url, response.status_code, response._content) - ) - return None + with create_new_span_internal("Fetching input data") as fetch_input_span: + if url != "": + try: + response = requests.get(url, proxies=self.proxies) + response.raise_for_status() + except ( + requests.exceptions.ConnectionError, + requests.exceptions.HTTPError, + ): + logging.info("Retrying requests.get without proxies after failure") + response = requests.get(url) + response.raise_for_status() + + if response.status_code == 200: + logging.info("Downloaded data of size {} from {}".format(sys.getsizeof(response._content), url)) + return response._content + else: + error_message = "Could not download data from {}, got {} : {}".format(url, response.status_code, response._content) + set_span_error(fetch_input_span, Exception(error_message)) + raise Exception(error_message) + + return None def on_request_complete(self, request): """Called when the future exits cleanly""" diff --git a/requirements.txt b/requirements.txt index 7dd2ed6..f34305a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -33,3 +33,4 @@ PyYAML==6.0.2 redis==5.0.8 requests==2.32.3 Werkzeug==3.0.4 +opentelemetry-api==1.27.0 diff --git a/tests/unit/test_opentelemtry.py b/tests/unit/test_opentelemtry.py new file mode 100644 index 0000000..19452af --- /dev/null +++ b/tests/unit/test_opentelemtry.py @@ -0,0 +1,106 @@ +import unittest +from polytope_server.common.observability.otel import ( + add_trace_context, + update_trace_context, + restore_trace_context, + create_new_span, + set_span_error, +) +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry import trace +from opentelemetry.sdk.trace.export import SimpleSpanProcessor +from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter +from opentelemetry.trace import get_current_span, StatusCode + +class MockRequest: + def __init__(self, request_id): + self.id = request_id + self.otel_trace = {} + +class TestOpenTelemetryUtils(unittest.TestCase): + + @classmethod + def setUpClass(cls): + # Set up a real TracerProvider with an in-memory exporter + cls.span_exporter = InMemorySpanExporter() + tracer_provider = TracerProvider() + tracer_provider.add_span_processor(SimpleSpanProcessor(cls.span_exporter)) + # Set the TracerProvider globally + trace.set_tracer_provider(tracer_provider) + cls.tracer = tracer_provider.get_tracer(__name__) + + def setUp(self): + # Clear the exported spans before each test + self.span_exporter.clear() + + def test_add_trace_context(self): + with self.tracer.start_as_current_span("test_span"): + mock_request = MockRequest("test_id") + add_trace_context(mock_request) + + # Ensure a carrier was injected into the request's trace context + self.assertIn("carrier", mock_request.otel_trace) + self.assertIn("traceparent", mock_request.otel_trace["carrier"]) + + # Ensure the span has the correct attribute + current_span = get_current_span() + self.assertIn("polytope.request.id", current_span.attributes) + self.assertEqual(current_span.attributes["polytope.request.id"], "test_id") + + def test_update_trace_context(self): + with self.tracer.start_as_current_span("test_span"): + mock_request = MockRequest("test_id") + update_trace_context(mock_request) + + # Ensure the carrier was updated + self.assertIn("carrier", mock_request.otel_trace) + self.assertIn("traceparent", mock_request.otel_trace["carrier"]) + + # Validate the span attributes + current_span = get_current_span() + self.assertIn("polytope.request.id", current_span.attributes) + self.assertEqual(current_span.attributes["polytope.request.id"], "test_id") + + def test_restore_trace_context(self): + with self.tracer.start_as_current_span("test_span"): + # Add a carrier to the request + mock_request = MockRequest("test_id") + mock_request.otel_trace["carrier"] = {"traceparent": "00-1234567890abcdef1234567890abcdef-1234567890abcdef-01"} + + restored_context = restore_trace_context(mock_request) + self.assertIsNotNone(restored_context) + + # Ensure the current span has the correct attribute + current_span = get_current_span() + self.assertIn("polytope.request.id", current_span.attributes) + self.assertEqual(current_span.attributes["polytope.request.id"], "test_id") + + def test_create_new_span(self): + with create_new_span("test_span", request_id="test_id") as span: + self.assertTrue(span.is_recording()) + self.assertIn("polytope.request.id", span.attributes) + self.assertEqual(span.attributes["polytope.request.id"], "test_id") + + # Verify the span was recorded correctly + spans = self.span_exporter.get_finished_spans() + self.assertEqual(len(spans), 1) + recorded_span = spans[0] + self.assertEqual(recorded_span.name, "test_span") + self.assertIn("polytope.request.id", recorded_span.attributes) + self.assertEqual(recorded_span.attributes["polytope.request.id"], "test_id") + + def test_set_span_error(self): + with self.tracer.start_as_current_span("test_span") as span: + exception = Exception("Test exception") + set_span_error(span, exception) + + # Verify the span status + self.assertEqual(span.status.status_code, StatusCode.ERROR) + self.assertEqual(span.status.description, "Test exception") + + # Verify the span was recorded with error + spans = self.span_exporter.get_finished_spans() + self.assertEqual(len(spans), 1) + recorded_span = spans[0] + self.assertEqual(recorded_span.status.status_code, StatusCode.ERROR) + self.assertEqual(recorded_span.status.description, "Test exception")