From bb7e64d5891a03ced9ba9b34d2edc5b8fbd0295b Mon Sep 17 00:00:00 2001
From: hansrajr
Date: Wed, 27 Nov 2024 08:51:30 +0530
Subject: [PATCH] Move inference events into the metamodel

Signed-off-by: hansrajr
---
 .../inference/haystack_entities.json          | 13 ++++
 .../inference/langchain_entities.json         | 13 ++++
 .../inference/llamaindex_entities.json        | 13 ++++
 src/monocle_apptrace/wrap_common.py           | 71 +++++++------------
 tests/haystack_metamodel_unittest.py          |  2 +-
 .../langchain_custom_output_processor_test.py |  1 +
 tests/langchain_test.py                       | 37 ----------
 tests/llama_index_test.py                     | 37 ----------
 8 files changed, 68 insertions(+), 119 deletions(-)

diff --git a/src/monocle_apptrace/metamodel/maps/attributes/inference/haystack_entities.json b/src/monocle_apptrace/metamodel/maps/attributes/inference/haystack_entities.json
index 6d1aa2a..ecf1cc7 100644
--- a/src/monocle_apptrace/metamodel/maps/attributes/inference/haystack_entities.json
+++ b/src/monocle_apptrace/metamodel/maps/attributes/inference/haystack_entities.json
@@ -31,5 +31,18 @@
         "accessor": "lambda arguments: 'model.llm.'+resolve_from_alias(arguments['instance'].__dict__, ['model', 'model_name'])"
       }
     ]
+  ],
+  "events": [
+    {
+      "attributes": {
+        "system": "lambda arguments: extract_messages(arguments)[0]",
+        "user": "lambda arguments: extract_messages(arguments)[1]"
+      }
+    },
+    {
+      "attributes": {
+        "assistant": "lambda response: extract_assistant_message(response)"
+      }
+    }
   ]
 }
diff --git a/src/monocle_apptrace/metamodel/maps/attributes/inference/langchain_entities.json b/src/monocle_apptrace/metamodel/maps/attributes/inference/langchain_entities.json
index dd9ee1a..9f33b6e 100644
--- a/src/monocle_apptrace/metamodel/maps/attributes/inference/langchain_entities.json
+++ b/src/monocle_apptrace/metamodel/maps/attributes/inference/langchain_entities.json
@@ -31,5 +31,18 @@
         "accessor": "lambda arguments: 'model.llm.'+resolve_from_alias(arguments['instance'].__dict__, ['model', 'model_name'])"
       }
     ]
+  ],
+  "events": [
+    {
+      "attributes": {
+        "system": "lambda arguments: extract_messages(arguments)[0]",
+        "user": "lambda arguments: extract_messages(arguments)[1]"
+      }
+    },
+    {
+      "attributes": {
+        "assistant": "lambda response: extract_assistant_message(response)"
+      }
+    }
   ]
 }
diff --git a/src/monocle_apptrace/metamodel/maps/attributes/inference/llamaindex_entities.json b/src/monocle_apptrace/metamodel/maps/attributes/inference/llamaindex_entities.json
index dd9ee1a..9f33b6e 100644
--- a/src/monocle_apptrace/metamodel/maps/attributes/inference/llamaindex_entities.json
+++ b/src/monocle_apptrace/metamodel/maps/attributes/inference/llamaindex_entities.json
@@ -31,5 +31,18 @@
         "accessor": "lambda arguments: 'model.llm.'+resolve_from_alias(arguments['instance'].__dict__, ['model', 'model_name'])"
       }
     ]
+  ],
+  "events": [
+    {
+      "attributes": {
+        "system": "lambda arguments: extract_messages(arguments)[0]",
+        "user": "lambda arguments: extract_messages(arguments)[1]"
+      }
+    },
+    {
+      "attributes": {
+        "assistant": "lambda response: extract_assistant_message(response)"
+      }
+    }
   ]
 }
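Note (annotation, not part of the patch): each accessor in the new "events" arrays above is a Python lambda serialized as a string; at span-processing time the wrapper evaluates the string and applies the resulting function to the wrapped call's arguments or to its return value. A minimal sketch of that contract follows; the Message class and the simplified extract_messages are hypothetical stand-ins for the framework message types and the real helper in wrap_common.py:

    # Illustrative sketch only. Message and this simplified extract_messages
    # are hypothetical stand-ins, not the real monocle_apptrace helpers.
    from dataclasses import dataclass

    @dataclass
    class Message:
        type: str      # "system" or "user"
        content: str

    def extract_messages(arguments):
        # Return [system, user] contents, mirroring the accessor indices in
        # the JSON above ([0] -> system, [1] -> user).
        system_message, user_message = "", ""
        for msg in arguments:
            if msg.type == "system":
                system_message = msg.content
            elif msg.type in ("user", "human"):
                user_message = msg.content
        return [system_message, user_message]

    # Accessor string exactly as it appears in the entity JSON above.
    accessor = "lambda arguments: extract_messages(arguments)[1]"
    get_user_message = eval(accessor)

    messages = [Message("system", "You are helpful."),
                Message("user", "What is a latte?")]
    print(get_user_message(messages))  # -> What is a latte?
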
diff --git a/src/monocle_apptrace/wrap_common.py b/src/monocle_apptrace/wrap_common.py
index 9c46067..885582a 100644
--- a/src/monocle_apptrace/wrap_common.py
+++ b/src/monocle_apptrace/wrap_common.py
@@ -8,7 +8,8 @@
 from opentelemetry.sdk.trace import Span
 from monocle_apptrace.utils import resolve_from_alias, with_tracer_wrapper, get_embedding_model, get_attribute, get_workflow_name, set_embedding_model, set_app_hosting_identifier_attribute
 from monocle_apptrace.utils import set_attribute
-from monocle_apptrace.utils import get_fully_qualified_class_name, flatten_dict, get_nested_value
+from monocle_apptrace.utils import get_fully_qualified_class_name, get_nested_value
+
 logger = logging.getLogger(__name__)
 WORKFLOW_TYPE_KEY = "workflow_type"
 DATA_INPUT_KEY = "data.input"
@@ -129,7 +130,20 @@ def process_span(to_wrap, span, instance, args, kwargs, return_value):
             else:
                 logger.warning("attributes not found or incorrect written in entity json")
             span.set_attribute("span.count", count)
-
+        if 'events' in output_processor:
+            events = output_processor['events']
+            for event in events:
+                event_attributes = {}
+                for key, accessor in event["attributes"].items():
+                    accessor_function = eval(accessor)
+                    if "arguments" in accessor:
+                        event_attributes[key] = accessor_function(args)
+                    elif "response" in accessor:
+                        event_attributes[key] = accessor_function(return_value)
+                if event_attributes.get('user'):
+                    span.add_event(name=DATA_INPUT_KEY, attributes=event_attributes)
+                else:
+                    span.add_event(name=DATA_OUTPUT_KEY, attributes=event_attributes)
     else:
         logger.warning("empty or entities json is not in correct format")
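Note (annotation, not part of the patch): the dispatch in the hunk above is textual. An accessor whose source mentions "arguments" is applied to the wrapped call's args, one mentioning "response" to its return value, and a truthy 'user' attribute selects the data.input event name, with everything else emitted as data.output. A standalone restatement under those assumptions; FakeSpan is a test double, not an OpenTelemetry API:

    # Sketch of the event-emission loop added to process_span above
    # (simplified; FakeSpan stands in for an OpenTelemetry Span).
    DATA_INPUT_KEY = "data.input"
    DATA_OUTPUT_KEY = "data.output"

    class FakeSpan:
        def __init__(self):
            self.events = []

        def add_event(self, name, attributes):
            self.events.append((name, attributes))

    def emit_events(events, span, args, return_value):
        for event in events:
            event_attributes = {}
            for key, accessor in event["attributes"].items():
                accessor_function = eval(accessor)
                if "arguments" in accessor:
                    event_attributes[key] = accessor_function(args)
                elif "response" in accessor:
                    event_attributes[key] = accessor_function(return_value)
            # A truthy 'user' attribute marks the event as input.
            name = DATA_INPUT_KEY if event_attributes.get("user") else DATA_OUTPUT_KEY
            span.add_event(name=name, attributes=event_attributes)

    span = FakeSpan()
    emit_events(
        [{"attributes": {"user": "lambda arguments: arguments[0]"}},
         {"attributes": {"assistant": "lambda response: response.upper()"}}],
        span, args=("hello",), return_value="hi there")
    print(span.events)
    # [('data.input', {'user': 'hello'}), ('data.output', {'assistant': 'HI THERE'})]
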
@@ -288,7 +302,7 @@ def update_llm_endpoint(curr_span: Span, instance):
 def get_provider_name(instance):
     provider_url = ""
     inference_endpoint = ""
-    parsed_provider_url = None
+    parsed_provider_url = ""
     try:
         base_url = getattr(instance.client._client, "base_url", None)
         if base_url:
@@ -350,7 +364,6 @@ def get_input_from_args(chain_args):
 
 
 def update_span_from_llm_response(response, span: Span, instance, args):
-    update_events_for_inference_span(response=response, span=span, args=args)
     if (response is not None and isinstance(response, dict) and "meta" in response) or (
             response is not None and hasattr(response, "response_metadata")):
         token_usage = None
@@ -393,7 +406,10 @@
 def extract_messages(args):
     """Extract system and user messages"""
     system_message, user_message = "", ""
-
+    args_input = get_attribute(DATA_INPUT_KEY)
+    if args_input:
+        user_message = args_input
+        return system_message, user_message
     if args and isinstance(args, tuple) and len(args) > 0:
         if hasattr(args[0], "messages") and isinstance(args[0].messages, list):
             for msg in args[0].messages:
@@ -419,45 +435,12 @@ def extract_assistant_message(response):
         return response.content
     if hasattr(response, "message") and hasattr(response.message, "content"):
         return response.message.content
-    return ""
-
-
-def inference_span_haystack(span, args, response):
-    args_input = get_attribute(DATA_INPUT_KEY)
-    span.add_event(name="data.input", attributes={"input": args_input})
-    output = (response['replies'][0].content if hasattr(response['replies'][0], 'content') else response['replies'][0])
-    span.add_event(name="data.output", attributes={"response": output})
-
-
-def inference_span_llama_index(span, args, response):
-    system_message, user_message = extract_messages(args)
-    span.add_event(name="data.input", attributes={"system": system_message, "user": user_message})
-    assistant_message = extract_assistant_message(response)
-    if assistant_message:
-        span.add_event(name="data.output", attributes={"assistant": assistant_message})
-
-
-span_handlers = {
-    'haystack.components.generators.openai.OpenAIGenerator': inference_span_haystack,
-    'haystack_integrations.components.generators.mistral.chat.chat_generator.MistralChatGenerator': inference_span_haystack,
-    'llamaindex.openai': inference_span_llama_index,
-}
-
-
-def update_events_for_inference_span(response, span, args):
-    if getattr(span, "attributes", {}).get("span.type") == "inference":
-        handler = span_handlers.get(span.name)
-        if handler:
-            handler(span, args, response)
+    if "replies" in response:
+        if hasattr(response['replies'][0], 'content'):
+            return response['replies'][0].content
         else:
-            system_message, user_message = extract_messages(args)
-            assistant_message = extract_assistant_message(response)
-            if system_message:
-                span.add_event(name="data.input", attributes={"system": system_message, "user": user_message}, )
-            else:
-                span.add_event(name="data.input", attributes={"input": user_message})
-            if assistant_message:
-                span.add_event(name="data.output", attributes={"assistant" if system_message else "response": assistant_message}, )
+            return response['replies'][0]
+    return ""
 
 
 def extract_query_from_content(content):
@@ -537,7 +520,7 @@ def update_span_with_prompt_output(to_wrap, wrapped_args, span: Span):
         if isinstance(resp, list) and hasattr(resp[0], 'content'):
             span.add_event(PROMPT_OUTPUT_KEY, {RESPONSE: resp[0].content})
         else:
-            span.add_event(PROMPT_OUTPUT_KEY, {RESPONSE: resp})
+            span.add_event(PROMPT_OUTPUT_KEY, {RESPONSE: resp[0]})
     elif isinstance(wrapped_args, str):
         span.add_event(PROMPT_OUTPUT_KEY, {RESPONSE: wrapped_args})
     elif isinstance(wrapped_args, dict):
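Note (annotation, not part of the patch): with the per-framework handlers deleted, the consolidated extract_assistant_message covers both attribute-style responses (LangChain, LlamaIndex) and Haystack's replies dict. A quick check of the branches added above, using stand-in response objects in place of real framework responses:

    # Stand-in check of the response shapes handled above; SimpleNamespace
    # substitutes for real LangChain/LlamaIndex/Haystack response objects,
    # and this function is a simplified copy of the branches in the hunk.
    from types import SimpleNamespace

    def extract_assistant_message(response):
        if hasattr(response, "content"):
            return response.content
        if hasattr(response, "message") and hasattr(response.message, "content"):
            return response.message.content
        if "replies" in response:
            if hasattr(response["replies"][0], "content"):
                return response["replies"][0].content
            else:
                return response["replies"][0]
        return ""

    print(extract_assistant_message(SimpleNamespace(content="via .content")))
    print(extract_assistant_message({"replies": ["via a Haystack replies list"]}))
    print(extract_assistant_message({"replies": [SimpleNamespace(content="reply object")]}))
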
diff --git a/tests/haystack_metamodel_unittest.py b/tests/haystack_metamodel_unittest.py
index d2fdd28..1254a4f 100644
--- a/tests/haystack_metamodel_unittest.py
+++ b/tests/haystack_metamodel_unittest.py
@@ -113,7 +113,7 @@ def test_haystack(self, mock_post):
                 model_name_found = True
             for event in span['events']:
                 if event['name'] == "data.input":
-                    assert event['attributes']['input'] == message
+                    assert event['attributes']['user'] == message
                     input_event = True
diff --git a/tests/langchain_custom_output_processor_test.py b/tests/langchain_custom_output_processor_test.py
index 7251329..3d440bf 100644
--- a/tests/langchain_custom_output_processor_test.py
+++ b/tests/langchain_custom_output_processor_test.py
@@ -195,6 +195,7 @@ def test_llm_chain(self, test_name, test_input_infra, llm_type, mock_post):
             assert llm_azure_openai_span["attributes"]["entity.1.deployment"] == os.environ.get("AZURE_OPENAI_API_DEPLOYMENT")
             assert llm_azure_openai_span["attributes"]["entity.1.inference_endpoint"] == "https://example.com/"
             assert llm_azure_openai_span["attributes"]["entity.2.type"] == "model.llm.gpt-3.5-turbo-0125"
+            assert "Latte is a coffee drink" in llm_azure_openai_span["events"][1]['attributes']['assistant']
         finally:
             os.environ.pop(test_input_infra)
diff --git a/tests/langchain_test.py b/tests/langchain_test.py
index 32db839..22e2ae2 100644
--- a/tests/langchain_test.py
+++ b/tests/langchain_test.py
@@ -44,7 +44,6 @@
     QUERY,
     RESPONSE,
     update_span_from_llm_response,
-    update_events_for_inference_span,
 )
 from monocle_apptrace.wrapper import WrapperMethod
 from opentelemetry import trace
@@ -238,42 +237,6 @@ def test_llm_response(self):
                 event_found = True
         assert event_found, "META_DATA event with token usage was not found"
 
-    def test_update_events_for_inference_span(self):
-        span = MagicMock()
-        span.attributes = {"span.type": "inference"}
-        span.events = []
-
-        def add_event(name, attributes):
-            span.events.append({"name": name, "attributes": attributes})
-
-        span.add_event.side_effect = add_event
-
-        mock_message_system = MagicMock()
-        mock_message_system.type = "system"
-        mock_message_system.content = "System message"
-
-        mock_message_user = MagicMock()
-        mock_message_user.type = "user"
-        mock_message_user.content = "User input message"
-
-        args = (MagicMock(messages=[mock_message_system, mock_message_user]),)
-        response = MagicMock()
-        response.content = "Assistant response message"
-
-        update_events_for_inference_span(response, span, args)
-
-        event_names = [event["name"] for event in span.events]
-        self.assertIn("data.input", event_names, "data.input event not found")
-        self.assertIn("data.output", event_names, "data.output event not found")
-
-        # Validate attributes for "data.input"
-        input_event = next(event for event in span.events if event["name"] == "data.input")
-        self.assertEqual(input_event["attributes"]["system"], "System message")
-        self.assertEqual(input_event["attributes"]["user"], "User input message")
-
-        # Validate attributes for "data.output"
-        output_event = next(event for event in span.events if event["name"] == "data.output")
-        self.assertEqual(output_event["attributes"]["assistant"], "Assistant response message")
 
 
 if __name__ == '__main__':
diff --git a/tests/llama_index_test.py b/tests/llama_index_test.py
index fdf3857..cbfe371 100644
--- a/tests/llama_index_test.py
+++ b/tests/llama_index_test.py
@@ -26,7 +26,6 @@
     QUERY,
     RESPONSE,
     llm_wrapper,
-    update_events_for_inference_span,
 )
 from monocle_apptrace.wrapper import WrapperMethod
 from opentelemetry.sdk.trace.export import BatchSpanProcessor, ConsoleSpanExporter
@@ -147,42 +146,6 @@ def get_event_attributes(events, key):
         assert type_found
         assert vectorstore_provider
 
-        span = MagicMock()
-        span.attributes = {"span.type": "inference"}
-        span.name = "llamaindex.openai"
-        span.events = []
-
-        def add_event(name, attributes):
-            span.events.append({"name": name, "attributes": attributes})
-
-        span.add_event.side_effect = add_event
-        mock_message_system = MagicMock()
-        mock_message_system.role = "system"
-        mock_message_system.content = "System message"
-
-        mock_message_user = MagicMock()
-        mock_message_user.role = "user"
-        mock_message_user.content = "Query:User input message, Answer:"
-
-        args = ([mock_message_system, mock_message_user],)
-        response = MagicMock()
-        response.content = "Assistant response message"
-
-        update_events_for_inference_span(response, span, args)
-        span.add_event.assert_any_call(
-            name="data.input",
-            attributes={
-                "system": "System message",
-                "user": "User input message,",
-            }
-        )
-
-        span.add_event.assert_any_call(
-            name="data.output",
-            attributes={
-                "assistant": "Assistant response message"
-            }
-        )
 
 
 if __name__ == '__main__':
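
Note (annotation, not part of the patch): the updated tests assert against the metamodel-driven events rather than the removed helpers. The span shape they consume is sketched below with illustrative data; the attribute names follow the tests above, while the payload text is made up:

    # Illustrative exported-span dict in the shape the updated tests read;
    # the message and answer text here are invented for the example.
    span = {
        "events": [
            {"name": "data.input",
             "attributes": {"system": "", "user": "What is a latte?"}},
            {"name": "data.output",
             "attributes": {"assistant": "Latte is a coffee drink made with espresso and milk."}},
        ]
    }

    input_event = next(e for e in span["events"] if e["name"] == "data.input")
    assert input_event["attributes"]["user"] == "What is a latte?"
    assert "Latte is a coffee drink" in span["events"][1]["attributes"]["assistant"]
    print("event assertions hold")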