Skip to content

Commit

Permalink
added event generation for non-async llm and retriever
Browse files Browse the repository at this point in the history
Signed-off-by: Ravi Anne <[email protected]>
  • Loading branch information
Ravi Anne committed Jul 18, 2024
1 parent a12d8cd commit 325061b
Show file tree
Hide file tree
Showing 4 changed files with 125 additions and 94 deletions.
59 changes: 36 additions & 23 deletions src/monocle_apptrace/wrap_common.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,21 @@

# Copyright (C) Okahu Inc 2023-2024. All rights reserved

import logging
import os
from opentelemetry.trace import Tracer, Span

from monocle_apptrace.utils import with_tracer_wrapper, resolve_from_alias
from opentelemetry.trace import Span, Tracer

from monocle_apptrace.utils import resolve_from_alias, with_tracer_wrapper

logger = logging.getLogger(__name__)
WORKFLOW_TYPE_KEY = "workflow_type"
CONTEXT_INPUT_KEY = "workflow_context_input"
CONTEXT_OUTPUT_KEY = "workflow_context_output"
PROMPT_INPUT_KEY = "workflow_input"
PROMPT_OUTPUT_KEY = "workflow_output"
CONTEXT_INPUT_KEY = "context_input"
CONTEXT_OUTPUT_KEY = "context_output"
PROMPT_INPUT_KEY = "input"
PROMPT_OUTPUT_KEY = "output"
QUERY = "question"
RESPONSE = "response"
TAGS = "tags"
CONTEXT_PROPERTIES_KEY = "workflow_context_properties"


Expand All @@ -28,19 +32,24 @@ def task_wrapper(tracer: Tracer, to_wrap, wrapped, instance, args, kwargs):
# Some Langchain objects are wrapped elsewhere, so we ignore them here
if instance.__class__.__name__ in ("AgentExecutor"):
return wrapped(*args, **kwargs)

if hasattr(instance, "name") and instance.name:
name = f"{to_wrap.get('span_name')}.{instance.name.lower()}"
elif to_wrap.get("span_name"):
name = to_wrap.get("span_name")
else:
name = f"langchain.task.{instance.__class__.__name__}"
kind = to_wrap.get("kind")

with tracer.start_as_current_span(name) as span:
if is_root_span(span):
update_span_with_prompt_input(to_wrap=to_wrap, wrapped_args=args, span=span)

#capture the tags attribute of the instance if present, else ignore
try:
span.set_attribute(TAGS, getattr(instance, TAGS))
except AttributeError:
pass
update_span_with_context_input(to_wrap=to_wrap, wrapped_args=args, span=span)
return_value = wrapped(*args, **kwargs)
update_span_with_context_output(to_wrap=to_wrap, return_value=return_value, span=span)
Expand Down Expand Up @@ -91,18 +100,18 @@ async def allm_wrapper(tracer, to_wrap, wrapped, instance, args, kwargs):
kind = to_wrap.get("kind")
with tracer.start_as_current_span(name) as span:
update_llm_endpoint(curr_span= span, instance=instance)

return_value = await wrapped(*args, **kwargs)

return return_value

@with_tracer_wrapper
def llm_wrapper(tracer: Tracer, to_wrap, wrapped, instance, args, kwargs):

# Some Langchain objects are wrapped elsewhere, so we ignore them here
if instance.__class__.__name__ in ("AgentExecutor"):
return wrapped(*args, **kwargs)

if callable(to_wrap.get("span_name_getter")):
name = to_wrap.get("span_name_getter")(instance)

Expand All @@ -117,7 +126,7 @@ def llm_wrapper(tracer: Tracer, to_wrap, wrapped, instance, args, kwargs):
update_llm_endpoint(curr_span= span, instance=instance)

return_value = wrapped(*args, **kwargs)
update_span_from_llm_response(response = return_value, span = span)
update_span_from_llm_response(response = return_value, span = span)

return return_value

Expand All @@ -133,7 +142,8 @@ def update_llm_endpoint(curr_span: Span, instance):
model_name = resolve_from_alias(instance.__dict__ , ["model","model_name"])
curr_span.set_attribute("openai_model_name", model_name)
# handling AzureOpenAI deployment
deployment_name = resolve_from_alias(instance.__dict__ , [ "engine", "azure_deployment", "deployment_name", "deployment_id", "deployment"])
deployment_name = resolve_from_alias(instance.__dict__ , [ "engine", "azure_deployment",
"deployment_name", "deployment_id", "deployment"])
curr_span.set_attribute("az_openai_deployment", deployment_name)
# handling the inference endpoint
inference_ep = resolve_from_alias(instance.__dict__,["azure_endpoint","api_base"])
Expand All @@ -156,9 +166,8 @@ def update_span_from_llm_response(response, span: Span):
span.set_attribute("completion_tokens", token_usage.get("completion_tokens"))
span.set_attribute("prompt_tokens", token_usage.get("prompt_tokens"))
span.set_attribute("total_tokens", token_usage.get("total_tokens"))

# extract token usage from llamaindex openai
if (response is not None and hasattr(response, "raw")):
if(response is not None and hasattr(response, "raw")):
if response.raw is not None:
token_usage = response.raw.get("usage")
if token_usage is not None:
def update_span_with_context_input(to_wrap, wrapped_args, span: Span):
    """Attach the retriever's input query to *span* as a CONTEXT_INPUT_KEY event.

    Args:
        to_wrap: Wrapper configuration dict; its 'package' entry identifies
            which framework's retriever was wrapped.
        wrapped_args: Positional arguments of the wrapped retriever call.
            For langchain the query is ``wrapped_args[0]`` itself; for
            llama_index it is ``wrapped_args[0].query_str``.
        span: Active OpenTelemetry span the event is recorded on.
    """
    package_name: str = to_wrap.get('package')
    # langchain_core retrievers receive the query string directly as arg 0
    if "langchain_core.retrievers" in package_name:
        input_arg_text = wrapped_args[0]
        span.add_event(CONTEXT_INPUT_KEY, {QUERY: input_arg_text})
    # llama_index retrievers receive a query object exposing .query_str
    # NOTE(review): attribute access inferred from this code only — confirm
    # against llama_index's QueryBundle API.
    if "llama_index.core.indices.base_retriever" in package_name:
        input_arg_text = wrapped_args[0].query_str
        span.add_event(CONTEXT_INPUT_KEY, {QUERY: input_arg_text})

def update_span_with_context_output(to_wrap, return_value, span: Span):
    """Attach the retrieved context to *span* as a CONTEXT_OUTPUT_KEY event.

    Args:
        to_wrap: Wrapper configuration dict; its 'package' entry identifies
            the wrapped framework. Only llama_index retrievers are handled.
        return_value: Result of the wrapped retriever call; a sequence of
            nodes whose first element exposes ``.text``.
        span: Active OpenTelemetry span the event is recorded on.
    """
    package_name: str = to_wrap.get('package')
    if "llama_index.core.indices.base_retriever" in package_name:
        # Only the first retrieved node's text is captured in the event.
        output_arg_text = return_value[0].text
        span.add_event(CONTEXT_OUTPUT_KEY, {RESPONSE: output_arg_text})

def update_span_with_prompt_input(to_wrap, wrapped_args, span: Span):
    """Attach the workflow's input to *span* as a PROMPT_INPUT_KEY event.

    A dict input is used directly as the event's attribute mapping; any
    other input is wrapped under the QUERY attribute key.

    Args:
        to_wrap: Wrapper configuration dict (unused here; kept for a
            signature consistent with the other update_span_* helpers).
        wrapped_args: Positional arguments of the wrapped call; the input
            is ``wrapped_args[0]``.
        span: Active OpenTelemetry span the event is recorded on.
    """
    input_arg_text = wrapped_args[0]
    if isinstance(input_arg_text, dict):
        span.add_event(PROMPT_INPUT_KEY, input_arg_text)
    else:
        span.add_event(PROMPT_INPUT_KEY, {QUERY: input_arg_text})

def update_span_with_prompt_output(to_wrap, wrapped_args, span: Span):
    """Attach the workflow's output to *span* as a PROMPT_OUTPUT_KEY event.

    Plain string outputs are recorded directly under RESPONSE; llama_index
    query-engine results expose their text via ``.response``.

    Args:
        to_wrap: Wrapper configuration dict; its 'package' entry identifies
            the wrapped framework.
        wrapped_args: The wrapped call's output value (a str, or a
            llama_index response object).
        span: Active OpenTelemetry span the event is recorded on.
    """
    package_name: str = to_wrap.get('package')
    if isinstance(wrapped_args, str):
        span.add_event(PROMPT_OUTPUT_KEY, {RESPONSE: wrapped_args})
    if "llama_index.core.base.base_query_engine" in package_name:
        span.add_event(PROMPT_OUTPUT_KEY, {RESPONSE: wrapped_args.response})
3 changes: 2 additions & 1 deletion tests/http_span_exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@
import logging
import os
from typing import Optional, Sequence

import requests
from opentelemetry.sdk.trace import ReadableSpan
from opentelemetry.sdk.trace.export import SpanExporter, SpanExportResult

REQUESTS_SUCCESS_STATUS_CODES = (200, 202)
REQUESTS_SUCCESS_STATUS_CODES = (200, 202, 201)

logger = logging.getLogger(__name__)

Expand Down
108 changes: 55 additions & 53 deletions tests/langchain_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import requests
from dummy_class import DummyClass
from embeddings_wrapper import HuggingFaceEmbeddings
from http_span_exporter import HttpSpanExporter
from langchain.llms.fake import FakeListLLM
from langchain.prompts import PromptTemplate
from langchain.schema import StrOutputParser
Expand All @@ -25,16 +26,15 @@
CONTEXT_PROPERTIES_KEY,
PROMPT_INPUT_KEY,
PROMPT_OUTPUT_KEY,
QUERY,
RESPONSE,
update_span_from_llm_response,
)
from monocle_apptrace.wrapper import WrapperMethod
from opentelemetry import trace
from opentelemetry.instrumentation.utils import unwrap
from opentelemetry.sdk.resources import SERVICE_NAME, Resource
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor, SpanProcessor, ConsoleSpanExporter

from http_span_exporter import HttpSpanExporter
from opentelemetry.sdk.trace.export import BatchSpanProcessor, ConsoleSpanExporter

logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
Expand Down Expand Up @@ -74,9 +74,8 @@ def __createChain(self):

traceProvider.add_span_processor(monocleProcessor)
trace.set_tracer_provider(traceProvider)
instrumentor = MonocleInstrumentor()
instrumentor.instrument()
self.instrumentor = instrumentor
self.instrumentor = MonocleInstrumentor()
self.instrumentor.instrument()
self.processor = monocleProcessor
responses=[self.ragText]
llm = FakeListLLM(responses=responses)
Expand All @@ -99,55 +98,66 @@ def __createChain(self):
def setUp(self):
    """Point the HTTP exporter at a local endpoint with a dummy API key."""
    os.environ.update({
        "HTTP_API_KEY": "key1",
        "HTTP_INGESTION_ENDPOINT": "https://localhost:3000/api/v1/traces",
    })


def tearDown(self) -> None:
    """Uninstrument after each test so later tests start from a clean state.

    Failures during uninstrument are reported but deliberately swallowed so
    they cannot mask the outcome of the test itself.
    """
    print("cleaning up with teardown")
    try:
        self.instrumentor.uninstrument()
    except Exception:
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
        # still propagate; uninstrument may legitimately fail if
        # instrument() never ran for this test.
        print("teardown errors")

    return super().tearDown()

def to_json(self, obj):
    """Serialize *obj* to pretty-printed JSON, falling back to __dict__
    for values the json module cannot encode natively."""
    def _fallback(unserializable):
        return unserializable.__dict__

    return json.dumps(obj, indent=4, default=_fallback)

@patch.object(requests.Session, 'post')
def test_llm_chain(self, mock_post):
    """End-to-end check that invoking the chain exports the expected spans.

    The HTTP exporter's POST is mocked; after running the chain the test
    asserts on the exported batch: prompt input/output events on the root
    span, the propagated context property, and that span/trace ids are
    emitted without a '0x' prefix. Instrumentation is torn down in the
    ``finally`` so a failing assert cannot leak instrumentation into other
    tests.
    """
    try:
        context_key = "context_key_1"
        context_value = "context_value_1"
        set_context_properties({context_key: context_value})

        self.chain = self.__createChain()
        mock_post.return_value.status_code = 201
        mock_post.return_value.json.return_value = 'mock response'

        query = "what is latte"
        response = self.chain.invoke(query, config={})
        assert response == self.ragText
        # Give the batch span processor time to flush to the (mocked) endpoint.
        time.sleep(5)
        mock_post.assert_called_with(
            url='https://localhost:3000/api/v1/traces',
            data=ANY,
            timeout=ANY
        )

        # mock_post.call_args gives the parameters used to make the post
        # call; these can be used for further asserts.
        dataBodyStr = mock_post.call_args.kwargs['data']
        dataJson = json.loads(dataBodyStr)  # more asserts can be added on individual fields
        assert len(dataJson['batch']) == 7

        root_span = [x for x in dataJson["batch"] if x["parent_id"] == "None"][0]
        root_span_attributes = root_span["attributes"]
        root_span_events = root_span["events"]

        def get_event_attributes(events, key):
            # First event with the given name; IndexError if it was not emitted.
            return [event['attributes'] for event in events if event['name'] == key][0]

        input_event_attributes = get_event_attributes(root_span_events, PROMPT_INPUT_KEY)
        output_event_attributes = get_event_attributes(root_span_events, PROMPT_OUTPUT_KEY)

        assert input_event_attributes[QUERY] == query
        assert output_event_attributes[RESPONSE] == TestHandler.ragText
        assert root_span_attributes[f"{CONTEXT_PROPERTIES_KEY}.{context_key}"] == context_value

        for spanObject in dataJson['batch']:
            assert not spanObject["context"]["span_id"].startswith("0x")
            assert not spanObject["context"]["trace_id"].startswith("0x")
    finally:
        try:
            if self.instrumentor is not None:
                self.instrumentor.uninstrument()
        except Exception as e:
            print("Uninstrument failed:", e)
def test_custom_methods(self):
app_name = "test"
wrap_method = MagicMock(return_value=3)
Expand All @@ -172,9 +182,7 @@ def test_custom_methods(self):

def test_llm_response(self):
trace.set_tracer_provider(TracerProvider())

tracer = trace.get_tracer(__name__)

span = tracer.start_span("foo", start_time=0)

message = AIMessage(
Expand All @@ -188,12 +196,6 @@ def test_llm_response(self):
assert span.attributes.get("prompt_tokens") == 584
assert span.attributes.get("total_tokens") == 642







if __name__ == '__main__':
unittest.main()

Expand Down
Loading

0 comments on commit 325061b

Please sign in to comment.