
moving inference events part into metamodel
Signed-off-by: hansrajr <[email protected]>
Hansrajr committed Nov 27, 2024
1 parent db843e9 commit bb7e64d
Showing 8 changed files with 68 additions and 119 deletions.
@@ -31,5 +31,18 @@
         "accessor": "lambda arguments: 'model.llm.'+resolve_from_alias(arguments['instance'].__dict__, ['model', 'model_name'])"
       }
-  ]
+  ],
+  "events": [
+    {
+      "attributes": {
+        "system": "lambda arguments: extract_messages(arguments)[0]",
+        "user": "lambda arguments: extract_messages(arguments)[1]"
+      }
+    },
+    {
+      "attributes": {
+        "assistant": "lambda response: extract_assistant_message(response)"
+      }
+    }
+  ]
 }
@@ -31,5 +31,18 @@
         "accessor": "lambda arguments: 'model.llm.'+resolve_from_alias(arguments['instance'].__dict__, ['model', 'model_name'])"
       }
-  ]
+  ],
+  "events": [
+    {
+      "attributes": {
+        "system": "lambda arguments: extract_messages(arguments)[0]",
+        "user": "lambda arguments: extract_messages(arguments)[1]"
+      }
+    },
+    {
+      "attributes": {
+        "assistant": "lambda response: extract_assistant_message(response)"
+      }
+    }
+  ]
 }
@@ -31,5 +31,18 @@
         "accessor": "lambda arguments: 'model.llm.'+resolve_from_alias(arguments['instance'].__dict__, ['model', 'model_name'])"
      }
-  ]
+  ],
+  "events": [
+    {
+      "attributes": {
+        "system": "lambda arguments: extract_messages(arguments)[0]",
+        "user": "lambda arguments: extract_messages(arguments)[1]"
+      }
+    },
+    {
+      "attributes": {
+        "assistant": "lambda response: extract_assistant_message(response)"
+      }
+    }
+  ]
 }
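
All three entity metamodel files gain the same declarative "events" block: the first entry maps the prompt (system and user messages) onto the span's input event, the second maps the completion (assistant message) onto the output event. As a rough sketch, with illustrative attribute values rather than output captured from a real run, a span exported after this change carries events shaped like this:

# Sketch of the span events produced from the "events" entries above.
# Event names come from DATA_INPUT_KEY / DATA_OUTPUT_KEY in wrap_common.py;
# the attribute values here are invented for illustration.
expected_events = [
    {
        "name": "data.input",
        "attributes": {
            "system": "You are a helpful assistant.",  # extract_messages(arguments)[0]
            "user": "What is a latte?",                # extract_messages(arguments)[1]
        },
    },
    {
        "name": "data.output",
        "attributes": {
            "assistant": "Latte is a coffee drink...",  # extract_assistant_message(response)
        },
    },
]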
71 changes: 27 additions & 44 deletions src/monocle_apptrace/wrap_common.py
@@ -8,7 +8,8 @@
 from opentelemetry.sdk.trace import Span
 from monocle_apptrace.utils import resolve_from_alias, with_tracer_wrapper, get_embedding_model, get_attribute, get_workflow_name, set_embedding_model, set_app_hosting_identifier_attribute
 from monocle_apptrace.utils import set_attribute
-from monocle_apptrace.utils import get_fully_qualified_class_name, flatten_dict, get_nested_value
+from monocle_apptrace.utils import get_fully_qualified_class_name, get_nested_value
+
 logger = logging.getLogger(__name__)
 WORKFLOW_TYPE_KEY = "workflow_type"
 DATA_INPUT_KEY = "data.input"
@@ -129,7 +130,20 @@ def process_span(to_wrap, span, instance, args, kwargs, return_value):
             else:
                 logger.warning("attributes not found or incorrectly written in entity JSON")
             span.set_attribute("span.count", count)
 
+            if 'events' in output_processor:
+                events = output_processor['events']
+                for event in events:
+                    event_attributes = {}
+                    for key, accessor in event["attributes"].items():
+                        accessor_function = eval(accessor)
+                        if "arguments" in accessor:
+                            event_attributes[key] = accessor_function(args)
+                        elif "response" in accessor:
+                            event_attributes[key] = accessor_function(return_value)
+                    if event_attributes.get('user'):
+                        span.add_event(name=DATA_INPUT_KEY, attributes=event_attributes)
+                    else:
+                        span.add_event(name=DATA_OUTPUT_KEY, attributes=event_attributes)
         else:
             logger.warning("entities JSON is empty or not in the correct format")
 
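The loop above is the whole mechanism: each accessor string from the metamodel JSON is eval'ed into a lambda, fed either the wrapped call's arguments or its return value depending on the parameter name in the accessor, and the collected attributes become one span event (data.input when a user attribute is present, data.output otherwise). A minimal runnable sketch of that flow, with extract_messages stubbed out (the real one parses framework-specific arguments):

# Stub standing in for the real extract_messages in wrap_common.py.
def extract_messages(arguments):
    return "system prompt", "user prompt"

event_entry = {
    "attributes": {
        "system": "lambda arguments: extract_messages(arguments)[0]",
        "user": "lambda arguments: extract_messages(arguments)[1]",
    }
}

args = ("framework-specific call arguments",)
event_attributes = {}
for key, accessor in event_entry["attributes"].items():
    # The accessor strings ship with the package's own metamodel JSON; eval
    # on untrusted input would be unsafe, but here the source is fixed.
    accessor_function = eval(accessor)
    if "arguments" in accessor:
        event_attributes[key] = accessor_function(args)

print(event_attributes)  # {'system': 'system prompt', 'user': 'user prompt'}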
@@ -288,7 +302,7 @@ def update_llm_endpoint(curr_span: Span, instance):
 def get_provider_name(instance):
     provider_url = ""
     inference_endpoint = ""
-    parsed_provider_url = None
+    parsed_provider_url = ""
     try:
         base_url = getattr(instance.client._client, "base_url", None)
         if base_url:
@@ -350,7 +364,6 @@ def get_input_from_args(chain_args):
 
 
 def update_span_from_llm_response(response, span: Span, instance, args):
-    update_events_for_inference_span(response=response, span=span, args=args)
     if (response is not None and isinstance(response, dict) and "meta" in response) or (
             response is not None and hasattr(response, "response_metadata")):
         token_usage = None
@@ -393,7 +406,10 @@ def update_span_from_llm_response(response, span: Span, instance, args):
 def extract_messages(args):
     """Extract system and user messages"""
     system_message, user_message = "", ""
-
+    args_input = get_attribute(DATA_INPUT_KEY)
+    if args_input:
+        user_message = args_input
+        return system_message, user_message
     if args and isinstance(args, tuple) and len(args) > 0:
         if hasattr(args[0], "messages") and isinstance(args[0].messages, list):
             for msg in args[0].messages:
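
The new early return means that when an earlier step has already stashed the raw prompt under DATA_INPUT_KEY (the removed inference_span_haystack read the same attribute), that value is reported as the user message and message parsing is skipped. A small sketch of the contract, with get_attribute stubbed as a dict lookup (the real one reads tracing context):

DATA_INPUT_KEY = "data.input"
_stored = {DATA_INPUT_KEY: "What is a latte?"}  # hypothetical stashed input

def get_attribute(key):
    # Stub standing in for monocle_apptrace.utils.get_attribute.
    return _stored.get(key)

def extract_messages(args):
    system_message, user_message = "", ""
    args_input = get_attribute(DATA_INPUT_KEY)
    if args_input:
        user_message = args_input
        return system_message, user_message
    # ...otherwise fall through to parsing args[0].messages as in the diff...
    return system_message, user_message

print(extract_messages(()))  # ('', 'What is a latte?')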
@@ -419,45 +435,12 @@ def extract_assistant_message(response):
         return response.content
     if hasattr(response, "message") and hasattr(response.message, "content"):
         return response.message.content
-    return ""
-
-
-def inference_span_haystack(span, args, response):
-    args_input = get_attribute(DATA_INPUT_KEY)
-    span.add_event(name="data.input", attributes={"input": args_input})
-    output = (response['replies'][0].content if hasattr(response['replies'][0], 'content') else response['replies'][0])
-    span.add_event(name="data.output", attributes={"response": output})
-
-
-def inference_span_llama_index(span, args, response):
-    system_message, user_message = extract_messages(args)
-    span.add_event(name="data.input", attributes={"system": system_message, "user": user_message})
-    assistant_message = extract_assistant_message(response)
-    if assistant_message:
-        span.add_event(name="data.output", attributes={"assistant": assistant_message})
-
-
-span_handlers = {
-    'haystack.components.generators.openai.OpenAIGenerator': inference_span_haystack,
-    'haystack_integrations.components.generators.mistral.chat.chat_generator.MistralChatGenerator': inference_span_haystack,
-    'llamaindex.openai': inference_span_llama_index,
-}
-
-
-def update_events_for_inference_span(response, span, args):
-    if getattr(span, "attributes", {}).get("span.type") == "inference":
-        handler = span_handlers.get(span.name)
-        if handler:
-            handler(span, args, response)
-        else:
-            system_message, user_message = extract_messages(args)
-            assistant_message = extract_assistant_message(response)
-            if system_message:
-                span.add_event(name="data.input", attributes={"system": system_message, "user": user_message},)
-            else:
-                span.add_event(name="data.input", attributes={"input": user_message})
-            if assistant_message:
-                span.add_event(name="data.output", attributes={"assistant" if system_message else "response": assistant_message},)
+    if "replies" in response:
+        if hasattr(response['replies'][0], 'content'):
+            return response['replies'][0].content
+        return response['replies'][0]
+    return ""
 
 
 def extract_query_from_content(content):
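
With the per-framework handlers and the span_handlers table gone, extract_assistant_message is now the single place that normalizes a completion: LangChain-style responses expose .content, LlamaIndex-style chat responses expose .message.content, and Haystack returns a dict with a "replies" list. A self-contained sketch of that contract against toy inputs (the real inputs are framework response objects):

from types import SimpleNamespace

def extract_assistant_message(response):
    # Mirrors the consolidated function above, exercised with toy objects.
    if hasattr(response, "content"):
        return response.content
    if hasattr(response, "message") and hasattr(response.message, "content"):
        return response.message.content
    if "replies" in response:
        if hasattr(response["replies"][0], "content"):
            return response["replies"][0].content
        return response["replies"][0]
    return ""

print(extract_assistant_message(SimpleNamespace(content="hi")))  # LangChain-style -> hi
print(extract_assistant_message(SimpleNamespace(message=SimpleNamespace(content="hello"))))  # LlamaIndex-style -> hello
print(extract_assistant_message({"replies": ["plain reply"]}))  # Haystack-style -> plain reply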
@@ -537,7 +520,7 @@ def update_span_with_prompt_output(to_wrap, wrapped_args, span: Span):
         if isinstance(resp, list) and hasattr(resp[0], 'content'):
             span.add_event(PROMPT_OUTPUT_KEY, {RESPONSE: resp[0].content})
         else:
-            span.add_event(PROMPT_OUTPUT_KEY, {RESPONSE: resp})
+            span.add_event(PROMPT_OUTPUT_KEY, {RESPONSE: resp[0]})
     elif isinstance(wrapped_args, str):
         span.add_event(PROMPT_OUTPUT_KEY, {RESPONSE: wrapped_args})
     elif isinstance(wrapped_args, dict):
2 changes: 1 addition & 1 deletion tests/haystack_metamodel_unittest.py
@@ -113,7 +113,7 @@ def test_haystack(self, mock_post):
                     model_name_found = True
             for event in span['events']:
                 if event['name'] == "data.input":
-                    assert event['attributes']['input'] == message
+                    assert event['attributes']['user'] == message
                     input_event = True
 
 
1 change: 1 addition & 0 deletions tests/langchain_custom_output_processor_test.py
@@ -195,6 +195,7 @@ def test_llm_chain(self, test_name, test_input_infra, llm_type, mock_post):
             assert llm_azure_openai_span["attributes"]["entity.1.deployment"] == os.environ.get("AZURE_OPENAI_API_DEPLOYMENT")
             assert llm_azure_openai_span["attributes"]["entity.1.inference_endpoint"] == "https://example.com/"
             assert llm_azure_openai_span["attributes"]["entity.2.type"] == "model.llm.gpt-3.5-turbo-0125"
+            assert "Latte is a coffee drink" in llm_azure_openai_span["events"][1]['attributes']['assistant']
 
         finally:
             os.environ.pop(test_input_infra)
37 changes: 0 additions & 37 deletions tests/langchain_test.py
@@ -44,7 +44,6 @@
     QUERY,
     RESPONSE,
     update_span_from_llm_response,
-    update_events_for_inference_span,
 )
 from monocle_apptrace.wrapper import WrapperMethod
 from opentelemetry import trace
@@ -238,42 +237,6 @@ def test_llm_response(self):
             event_found = True
         assert event_found, "META_DATA event with token usage was not found"
 
-    def test_update_events_for_inference_span(self):
-        span = MagicMock()
-        span.attributes = {"span.type": "inference"}
-        span.events = []
-
-        def add_event(name, attributes):
-            span.events.append({"name": name, "attributes": attributes})
-
-        span.add_event.side_effect = add_event
-
-        mock_message_system = MagicMock()
-        mock_message_system.type = "system"
-        mock_message_system.content = "System message"
-
-        mock_message_user = MagicMock()
-        mock_message_user.type = "user"
-        mock_message_user.content = "User input message"
-
-        args = (MagicMock(messages=[mock_message_system, mock_message_user]),)
-        response = MagicMock()
-        response.content = "Assistant response message"
-
-        update_events_for_inference_span(response, span, args)
-
-        event_names = [event["name"] for event in span.events]
-        self.assertIn("data.input", event_names, "data.input event not found")
-        self.assertIn("data.output", event_names, "data.output event not found")
-
-        # Validate attributes for "data.input"
-        input_event = next(event for event in span.events if event["name"] == "data.input")
-        self.assertEqual(input_event["attributes"]["system"], "System message")
-        self.assertEqual(input_event["attributes"]["user"], "User input message")
-
-        # Validate attributes for "data.output"
-        output_event = next(event for event in span.events if event["name"] == "data.output")
-        self.assertEqual(output_event["attributes"]["assistant"], "Assistant response message")
 
 
 if __name__ == '__main__':
37 changes: 0 additions & 37 deletions tests/llama_index_test.py
@@ -26,7 +26,6 @@
     QUERY,
     RESPONSE,
     llm_wrapper,
-    update_events_for_inference_span,
 )
 from monocle_apptrace.wrapper import WrapperMethod
 from opentelemetry.sdk.trace.export import BatchSpanProcessor, ConsoleSpanExporter
@@ -147,42 +146,6 @@ def get_event_attributes(events, key):
         assert type_found
         assert vectorstore_provider
 
-        span = MagicMock()
-        span.attributes = {"span.type": "inference"}
-        span.name = "llamaindex.openai"
-        span.events = []
-
-        def add_event(name, attributes):
-            span.events.append({"name": name, "attributes": attributes})
-
-        span.add_event.side_effect = add_event
-        mock_message_system = MagicMock()
-        mock_message_system.role = "system"
-        mock_message_system.content = "System message"
-
-        mock_message_user = MagicMock()
-        mock_message_user.role = "user"
-        mock_message_user.content = "Query:User input message, Answer:"
-
-        args = ([mock_message_system, mock_message_user],)
-        response = MagicMock()
-        response.content = "Assistant response message"
-
-        update_events_for_inference_span(response, span, args)
-        span.add_event.assert_any_call(
-            name="data.input",
-            attributes={
-                "system": "System message",
-                "user": "User input message,",
-            }
-        )
-
-        span.add_event.assert_any_call(
-            name="data.output",
-            attributes={
-                "assistant": "Assistant response message"
-            }
-        )
 
 
 if __name__ == '__main__':
