From db843e90198110a2b7158fbba85910549f8d53a2 Mon Sep 17 00:00:00 2001
From: hansrajr
Date: Mon, 25 Nov 2024 15:18:20 +0530
Subject: [PATCH] Refactor inference span event extraction

Signed-off-by: hansrajr
---
 src/monocle_apptrace/wrap_common.py | 155 ++++++++++++----------------
 1 file changed, 67 insertions(+), 88 deletions(-)

diff --git a/src/monocle_apptrace/wrap_common.py b/src/monocle_apptrace/wrap_common.py
index 93257ba..9c46067 100644
--- a/src/monocle_apptrace/wrap_common.py
+++ b/src/monocle_apptrace/wrap_common.py
@@ -32,51 +32,7 @@
     "haystack": "workflow.haystack"
 }
-def get_embedding_model_for_vectorstore(instance):
-    # Handle Langchain or other frameworks where vectorstore exists
-    if hasattr(instance, 'vectorstore'):
-        vectorstore_dict = instance.vectorstore.__dict__
-
-        # Use inspect to check if the embedding function is from Sagemaker
-        if 'embedding_func' in vectorstore_dict:
-            embedding_func = vectorstore_dict['embedding_func']
-            class_name = embedding_func.__class__.__name__
-            file_location = inspect.getfile(embedding_func.__class__)
-
-            # Check if the class is SagemakerEndpointEmbeddings
-            if class_name == 'SagemakerEndpointEmbeddings' and 'langchain_community' in file_location:
-                # Set embedding_model as endpoint_name if it's Sagemaker
-                if hasattr(embedding_func, 'endpoint_name'):
-                    return embedding_func.endpoint_name
-
-        # Default to the regular embedding model if not Sagemaker
-        return instance.vectorstore.embeddings.model
-
-    # Handle llama_index where _embed_model is present
-    if hasattr(instance, '_embed_model') and hasattr(instance._embed_model, 'model_name'):
-        return instance._embed_model.model_name
-
-    # Fallback if no specific model is found
-    return "Unknown Embedding Model"
-
-
-framework_vector_store_mapping = {
-    'langchain_core.retrievers': lambda instance: {
-        'provider': type(instance.vectorstore).__name__,
-        'embedding_model': get_embedding_model_for_vectorstore(instance),
-        'type': VECTOR_STORE,
-    },
-    'llama_index.core.indices.base_retriever': lambda instance: {
-        'provider': type(instance._vector_store).__name__,
-        'embedding_model': get_embedding_model_for_vectorstore(instance),
-        'type': VECTOR_STORE,
-    },
-    'haystack.components.retrievers.in_memory': lambda instance: {
-        'provider': instance.__dict__.get("document_store").__class__.__name__,
-        'embedding_model': get_embedding_model(),
-        'type': VECTOR_STORE,
-    },
-}
+
 
 def get_embedding_model_haystack(instance):
     try:
         if hasattr(instance, 'get_component'):
@@ -434,49 +390,72 @@ def update_span_from_llm_response(response, span: Span, instance, args):
         token_usage = None
 
 
+def extract_messages(args):
+    """Extract system and user messages"""
+    system_message, user_message = "", ""
+
+    if args and isinstance(args, tuple) and len(args) > 0:
+        if hasattr(args[0], "messages") and isinstance(args[0].messages, list):
+            for msg in args[0].messages:
+                if hasattr(msg, 'content') and hasattr(msg, 'type'):
+                    if msg.type == "system":
+                        system_message = msg.content
+                    elif msg.type in ["user", "human"]:
+                        user_message = msg.content
+        elif isinstance(args[0], list):
+            for msg in args[0]:
+                if hasattr(msg, 'content') and hasattr(msg, 'role'):
+                    if msg.role == "system":
+                        system_message = msg.content
+                    elif msg.role in ["user", "human"]:
+                        user_message = extract_query_from_content(msg.content)
+    return system_message, user_message
+
+
+def extract_assistant_message(response):
+    if isinstance(response, str):
+        return response
+    if hasattr(response, "content"):
+        return response.content
+    if hasattr(response, "message") and hasattr(response.message, "content"):
"content"): + return response.message.content + return "" + + +def inference_span_haystack(span, args, response): + args_input = get_attribute(DATA_INPUT_KEY) + span.add_event(name="data.input", attributes={"input": args_input}) + output = (response['replies'][0].content if hasattr(response['replies'][0], 'content') else response['replies'][0]) + span.add_event(name="data.output", attributes={"response": output}) + + +def inference_span_llama_index(span, args, response): + system_message, user_message = extract_messages(args) + span.add_event(name="data.input", attributes={"system": system_message, "user": user_message}) + assistant_message = extract_assistant_message(response) + if assistant_message: + span.add_event(name="data.output", attributes={"assistant": assistant_message}) + + +span_handlers = { + 'haystack.components.generators.openai.OpenAIGenerator': inference_span_haystack, + 'haystack_integrations.components.generators.mistral.chat.chat_generator.MistralChatGenerator': inference_span_haystack, + 'llamaindex.openai': inference_span_llama_index, +} + + def update_events_for_inference_span(response, span, args): if getattr(span, "attributes", {}).get("span.type") == "inference": - system_message = "" - user_message = "" - if span.name in ['haystack.components.generators.openai.OpenAIGenerator', 'haystack_integrations.components.generators.mistral.chat.chat_generator.MistralChatGenerator'] : - args_input = get_attribute(DATA_INPUT_KEY) - span.add_event(name="data.input", attributes={"input": args_input}, ) - span.add_event(name="data.output", attributes={"response": response['replies'][0].content if hasattr(response['replies'][0],'content') else response['replies'][0]}, ) - if args and isinstance(args, tuple) and len(args) > 0: - if hasattr(args[0], "messages") and isinstance(args[0].messages, list): - for msg in args[0].messages: - if hasattr(msg, 'content') and hasattr(msg, 'type'): - if msg.type == "system": - system_message = msg.content - elif msg.type in ["user", "human"]: - user_message = msg.content - if isinstance(args[0], list): - for msg in args[0]: - if hasattr(msg, 'content') and hasattr(msg, 'role'): - if msg.role == "system": - system_message = msg.content - elif msg.role in ["user", "human"]: - user_message = extract_query_from_content(msg.content) - if span.name == 'llamaindex.openai': - if args and isinstance(args, tuple): - chat_messages = args[0] - if isinstance(chat_messages, list): - for msg in chat_messages: - if hasattr(msg, "content") and hasattr(msg, "role"): - if msg.role == "system": - system_message = msg.content - elif msg.role in ["user", "human"]: - user_message = extract_query_from_content(msg.content) + handler = span_handlers.get(span.name) + if handler: + handler(span, args, response) + else: + system_message, user_message = extract_messages(args) + assistant_message = extract_assistant_message(response) if system_message: - span.add_event(name="data.input", attributes={"system": system_message, "user": user_message, }, ) + span.add_event(name="data.input", attributes={"system": system_message, "user": user_message}, ) else: - span.add_event(name="data.input", attributes={"input": user_message, }, ) - - assistant_message = "" - if isinstance(response, str): - assistant_message = response - if hasattr(response, "content") or hasattr(response, "message"): - assistant_message = getattr(response, "content", "") or response.message.content + span.add_event(name="data.input", attributes={"input": user_message}) if assistant_message: 
span.add_event(name="data.output", attributes={"assistant" if system_message else "response": assistant_message}, )