diff --git a/src/monocle_apptrace/utils.py b/src/monocle_apptrace/utils.py index ab386a7..ad2d9f0 100644 --- a/src/monocle_apptrace/utils.py +++ b/src/monocle_apptrace/utils.py @@ -6,6 +6,7 @@ from opentelemetry.context import attach, set_value, get_value from monocle_apptrace.constants import azure_service_map, aws_service_map from json.decoder import JSONDecodeError +logger = logging.getLogger(__name__) embedding_model_context = {} @@ -161,4 +162,11 @@ def get_attribute(key: str) -> str: Returns: The value associated with the given key. """ - return get_value(key) \ No newline at end of file + return get_value(key) + +def get_workflow_name(span: Span) -> str: + try: + return get_value("workflow_name") or span.resource.attributes.get("service.name") + except Exception as e: + logger.exception(f"Error getting workflow name: {e}") + return None \ No newline at end of file diff --git a/src/monocle_apptrace/wrap_common.py b/src/monocle_apptrace/wrap_common.py index bbfe2ee..b8f0438 100644 --- a/src/monocle_apptrace/wrap_common.py +++ b/src/monocle_apptrace/wrap_common.py @@ -4,7 +4,7 @@ import inspect from urllib.parse import urlparse from opentelemetry.trace import Span, Tracer -from monocle_apptrace.utils import resolve_from_alias, update_span_with_infra_name, with_tracer_wrapper, get_embedding_model, get_attribute +from monocle_apptrace.utils import resolve_from_alias, update_span_with_infra_name, with_tracer_wrapper, get_embedding_model, get_attribute, get_workflow_name from monocle_apptrace.utils import set_attribute from opentelemetry.context import get_value, attach, set_value logger = logging.getLogger(__name__) @@ -105,7 +105,7 @@ def process_span(to_wrap, span, instance, args): # Check if the output_processor is a valid JSON (in Python, that means it's a dictionary) span_index = 1 if is_root_span(span): - workflow_name = get_value("workflow_name") + workflow_name = get_workflow_name(span) if workflow_name: span.set_attribute(f"entity.{span_index}.name", workflow_name) # workflow type diff --git a/tests/langchain_custom_output_processor_test.py b/tests/langchain_custom_output_processor_test.py index 53b9bbb..f68af9b 100644 --- a/tests/langchain_custom_output_processor_test.py +++ b/tests/langchain_custom_output_processor_test.py @@ -18,6 +18,7 @@ from langchain_community.vectorstores import faiss from langchain_core.messages.ai import AIMessage from langchain_core.runnables import RunnablePassthrough +from monocle_apptrace.wrap_common import WORKFLOW_TYPE_MAP from monocle_apptrace.constants import ( AZURE_APP_SERVICE_ENV_NAME, AZURE_APP_SERVICE_NAME, @@ -175,6 +176,9 @@ def test_llm_chain(self, test_name, test_input_infra, llm_type, mock_post): This can be used to do more asserts''' dataBodyStr = mock_post.call_args.kwargs['data'] dataJson = json.loads(dataBodyStr) # more asserts can be added on individual fields + root_attributes = [x for x in dataJson["batch"] if x["parent_id"] == "None"][0]["attributes"] + assert root_attributes["entity.1.name"] == app_name + assert root_attributes["entity.1.type"] == WORKFLOW_TYPE_MAP['langchain'] if llm_type == "FakeListLLM": llm_vector_store_retriever_span = [x for x in dataJson["batch"] if 'langchain.task.VectorStoreRetriever' in x["name"]][0] inference_span = [x for x in dataJson["batch"] if 'langchain.task.FakeListLLM' in x["name"]][0]