
Commit db843e9
Refactor
Signed-off-by: hansrajr <[email protected]>
Hansrajr committed Nov 25, 2024
1 parent f0ea598 commit db843e9
Showing 1 changed file with 63 additions and 84 deletions.
147 changes: 63 additions & 84 deletions src/monocle_apptrace/wrap_common.py
@@ -32,51 +32,7 @@
     "haystack": "workflow.haystack"
 }
 
-def get_embedding_model_for_vectorstore(instance):
-    # Handle Langchain or other frameworks where vectorstore exists
-    if hasattr(instance, 'vectorstore'):
-        vectorstore_dict = instance.vectorstore.__dict__
-
-        # Use inspect to check if the embedding function is from Sagemaker
-        if 'embedding_func' in vectorstore_dict:
-            embedding_func = vectorstore_dict['embedding_func']
-            class_name = embedding_func.__class__.__name__
-            file_location = inspect.getfile(embedding_func.__class__)
-
-            # Check if the class is SagemakerEndpointEmbeddings
-            if class_name == 'SagemakerEndpointEmbeddings' and 'langchain_community' in file_location:
-                # Set embedding_model as endpoint_name if it's Sagemaker
-                if hasattr(embedding_func, 'endpoint_name'):
-                    return embedding_func.endpoint_name
-
-        # Default to the regular embedding model if not Sagemaker
-        return instance.vectorstore.embeddings.model
-
-    # Handle llama_index where _embed_model is present
-    if hasattr(instance, '_embed_model') and hasattr(instance._embed_model, 'model_name'):
-        return instance._embed_model.model_name
-
-    # Fallback if no specific model is found
-    return "Unknown Embedding Model"
-
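For orientation, the helper being deleted here duck-types its way from a retriever to an embedding-model name. A minimal sketch of its fallback path, using stand-in classes invented for illustration:

# Stand-ins invented for illustration; only the attribute shapes matter.
class FakeEmbeddings:
    model = "text-embedding-ada-002"  # illustrative model name

class FakeVectorStore:
    embeddings = FakeEmbeddings()

class FakeRetriever:
    vectorstore = FakeVectorStore()

# There is no 'embedding_func' in the vectorstore's __dict__, so the helper
# falls through to vectorstore.embeddings.model.
assert get_embedding_model_for_vectorstore(FakeRetriever()) == "text-embedding-ada-002"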

-framework_vector_store_mapping = {
-    'langchain_core.retrievers': lambda instance: {
-        'provider': type(instance.vectorstore).__name__,
-        'embedding_model': get_embedding_model_for_vectorstore(instance),
-        'type': VECTOR_STORE,
-    },
-    'llama_index.core.indices.base_retriever': lambda instance: {
-        'provider': type(instance._vector_store).__name__,
-        'embedding_model': get_embedding_model_for_vectorstore(instance),
-        'type': VECTOR_STORE,
-    },
-    'haystack.components.retrievers.in_memory': lambda instance: {
-        'provider': instance.__dict__.get("document_store").__class__.__name__,
-        'embedding_model': get_embedding_model(),
-        'type': VECTOR_STORE,
-    },
-}
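This table keys metadata builders by the module that defines the retriever's class. A hedged sketch of the lookup pattern (the `retriever` variable and the exact call site are assumptions, not code from this file):

# Pick the builder by the retriever's defining module; unknown frameworks
# simply yield no vector-store metadata.
module_name = type(retriever).__module__   # e.g. 'langchain_core.retrievers'
builder = framework_vector_store_mapping.get(module_name)
if builder is not None:
    vector_store_info = builder(retriever)
    # -> {'provider': ..., 'embedding_model': ..., 'type': VECTOR_STORE}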

 def get_embedding_model_haystack(instance):
     try:
         if hasattr(instance, 'get_component'):
@@ -434,49 +390,72 @@ def update_span_from_llm_response(response, span: Span, instance, args):
     token_usage = None
 
 
+def extract_messages(args):
+    """Extract system and user messages"""
+    system_message, user_message = "", ""
+
+    if args and isinstance(args, tuple) and len(args) > 0:
+        if hasattr(args[0], "messages") and isinstance(args[0].messages, list):
+            for msg in args[0].messages:
+                if hasattr(msg, 'content') and hasattr(msg, 'type'):
+                    if msg.type == "system":
+                        system_message = msg.content
+                    elif msg.type in ["user", "human"]:
+                        user_message = msg.content
+        elif isinstance(args[0], list):
+            for msg in args[0]:
+                if hasattr(msg, 'content') and hasattr(msg, 'role'):
+                    if msg.role == "system":
+                        system_message = msg.content
+                    elif msg.role in ["user", "human"]:
+                        user_message = extract_query_from_content(msg.content)
+    return system_message, user_message
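A minimal, self-contained way to exercise the new extract_messages with a LangChain-style message list; SimpleNamespace stands in for real message objects:

from types import SimpleNamespace

# args arrives as the wrapped call's positional-argument tuple, so the
# message list is wrapped in a one-element tuple here.
msgs = [
    SimpleNamespace(role="system", content="You are terse."),
    SimpleNamespace(role="user", content="What does Monocle trace?"),
]
system_message, user_message = extract_messages((msgs,))
# system_message -> "You are terse."; the user message comes back through
# extract_query_from_content().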


+def extract_assistant_message(response):
+    if isinstance(response, str):
+        return response
+    if hasattr(response, "content"):
+        return response.content
+    if hasattr(response, "message") and hasattr(response.message, "content"):
+        return response.message.content
+    return ""
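The helper accepts the three response shapes below; the objects are stand-ins for illustration:

from types import SimpleNamespace

extract_assistant_message("plain text")                          # -> "plain text"
extract_assistant_message(SimpleNamespace(content="hi"))         # -> "hi"
extract_assistant_message(
    SimpleNamespace(message=SimpleNamespace(content="hi")))      # -> "hi"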


+def inference_span_haystack(span, args, response):
+    args_input = get_attribute(DATA_INPUT_KEY)
+    span.add_event(name="data.input", attributes={"input": args_input})
+    output = (response['replies'][0].content if hasattr(response['replies'][0], 'content') else response['replies'][0])
+    span.add_event(name="data.output", attributes={"response": output})
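This handler normalizes Haystack replies, which may be plain strings or ChatMessage-like objects. A standalone sketch of just that normalization (the reply objects are stand-ins):

from types import SimpleNamespace

for reply in ["a plain string reply", SimpleNamespace(content="a chat reply")]:
    output = reply.content if hasattr(reply, 'content') else reply
    print(output)  # both branches yield the reply text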


+def inference_span_llama_index(span, args, response):
+    system_message, user_message = extract_messages(args)
+    span.add_event(name="data.input", attributes={"system": system_message, "user": user_message})
+    assistant_message = extract_assistant_message(response)
+    if assistant_message:
+        span.add_event(name="data.output", attributes={"assistant": assistant_message})


+span_handlers = {
+    'haystack.components.generators.openai.OpenAIGenerator': inference_span_haystack,
+    'haystack_integrations.components.generators.mistral.chat.chat_generator.MistralChatGenerator': inference_span_haystack,
+    'llamaindex.openai': inference_span_llama_index,
+}
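The table turns per-framework branching into data: supporting a new generator means registering one function rather than editing the update path. A hypothetical registration (the span name and handler are invented for illustration):

# Hypothetical: a new framework plugs in by mapping its span name to a
# handler with the same (span, args, response) contract.
def inference_span_example(span, args, response):
    span.add_event(name="data.output", attributes={"response": str(response)})

span_handlers['example.framework.Generator'] = inference_span_example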


 def update_events_for_inference_span(response, span, args):
     if getattr(span, "attributes", {}).get("span.type") == "inference":
-        system_message = ""
-        user_message = ""
-        if span.name in ['haystack.components.generators.openai.OpenAIGenerator', 'haystack_integrations.components.generators.mistral.chat.chat_generator.MistralChatGenerator']:
-            args_input = get_attribute(DATA_INPUT_KEY)
-            span.add_event(name="data.input", attributes={"input": args_input}, )
-            span.add_event(name="data.output", attributes={"response": response['replies'][0].content if hasattr(response['replies'][0], 'content') else response['replies'][0]}, )
-        if args and isinstance(args, tuple) and len(args) > 0:
-            if hasattr(args[0], "messages") and isinstance(args[0].messages, list):
-                for msg in args[0].messages:
-                    if hasattr(msg, 'content') and hasattr(msg, 'type'):
-                        if msg.type == "system":
-                            system_message = msg.content
-                        elif msg.type in ["user", "human"]:
-                            user_message = msg.content
-            if isinstance(args[0], list):
-                for msg in args[0]:
-                    if hasattr(msg, 'content') and hasattr(msg, 'role'):
-                        if msg.role == "system":
-                            system_message = msg.content
-                        elif msg.role in ["user", "human"]:
-                            user_message = extract_query_from_content(msg.content)
-        if span.name == 'llamaindex.openai':
-            if args and isinstance(args, tuple):
-                chat_messages = args[0]
-                if isinstance(chat_messages, list):
-                    for msg in chat_messages:
-                        if hasattr(msg, "content") and hasattr(msg, "role"):
-                            if msg.role == "system":
-                                system_message = msg.content
-                            elif msg.role in ["user", "human"]:
-                                user_message = extract_query_from_content(msg.content)
-        if system_message:
-            span.add_event(name="data.input", attributes={"system": system_message, "user": user_message, }, )
-        else:
-            span.add_event(name="data.input", attributes={"input": user_message, }, )
-
-        assistant_message = ""
-        if isinstance(response, str):
-            assistant_message = response
-        if hasattr(response, "content") or hasattr(response, "message"):
-            assistant_message = getattr(response, "content", "") or response.message.content
-        if assistant_message:
-            span.add_event(name="data.output", attributes={"assistant" if system_message else "response": assistant_message}, )
+        handler = span_handlers.get(span.name)
+        if handler:
+            handler(span, args, response)
+        else:
+            system_message, user_message = extract_messages(args)
+            assistant_message = extract_assistant_message(response)
+            if system_message:
+                span.add_event(name="data.input", attributes={"system": system_message, "user": user_message})
+            else:
+                span.add_event(name="data.input", attributes={"input": user_message})
+            if assistant_message:
+                span.add_event(name="data.output", attributes={"assistant" if system_message else "response": assistant_message})

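End to end, the fallback path can be exercised with a stand-in span; the real code receives an OpenTelemetry Span, so RecordingSpan and its span name are invented for illustration:

class RecordingSpan:
    name = "some.unregistered.generator"      # no entry in span_handlers
    attributes = {"span.type": "inference"}

    def add_event(self, name, attributes):
        print(name, attributes)

# No handler matches, so extract_messages/extract_assistant_message run:
update_events_for_inference_span("fine, thanks", RecordingSpan(), None)
# data.input {'input': ''}
# data.output {'response': 'fine, thanks'}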
