Skip to content

Commit

Permalink
Merge branch 'main' into ndjson_format_for_exporters
Browse files Browse the repository at this point in the history
  • Loading branch information
Hansrajr authored Oct 22, 2024
2 parents 8704fcc + 2b5112f commit 0bb528c
Show file tree
Hide file tree
Showing 5 changed files with 279 additions and 79 deletions.
1 change: 1 addition & 0 deletions src/monocle_apptrace/instrumentor.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ def setup_monocle_telemetry(
})
span_processors = span_processors or [BatchSpanProcessor(get_monocle_exporter())]
trace_provider = TracerProvider(resource=resource)
attach(set_value("workflow_name", workflow_name))
tracer_provider_default = trace.get_tracer_provider()
provider_type = type(tracer_provider_default).__name__
is_proxy_provider = "Proxy" in provider_type
Expand Down
94 changes: 52 additions & 42 deletions src/monocle_apptrace/wrap_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from opentelemetry.trace import Span, Tracer
from monocle_apptrace.utils import resolve_from_alias, update_span_with_infra_name, with_tracer_wrapper, get_embedding_model, get_attribute
from monocle_apptrace.utils import set_attribute

from opentelemetry.context import get_value, attach, set_value
logger = logging.getLogger(__name__)
WORKFLOW_TYPE_KEY = "workflow_type"
DATA_INPUT_KEY = "data.input"
Expand Down Expand Up @@ -93,47 +93,61 @@ def task_wrapper(tracer: Tracer, to_wrap, wrapped, instance, args, kwargs):
name = f"langchain.task.{instance.__class__.__name__}"

with tracer.start_as_current_span(name) as span:
if "output_processor" in to_wrap:
process_span(to_wrap["output_processor"], span, instance, args)
process_span(to_wrap, span, instance, args)
pre_task_processing(to_wrap, instance, args, span)
return_value = wrapped(*args, **kwargs)
post_task_processing(to_wrap, span, return_value)

return return_value


def process_span(output_processor, span, instance, args):
def process_span(to_wrap, span, instance, args):
# Check if the output_processor is a valid JSON (in Python, that means it's a dictionary)
if isinstance(output_processor, dict) and len(output_processor) > 0:
if 'type' in output_processor:
span.set_attribute("span.type", output_processor['type'])
else:
logger.warning("type of span not found or incorrect written in entity json")
count = 0
if 'attributes' in output_processor:
count = len(output_processor["attributes"])
span.set_attribute("entity.count", count)
span_index = 1
for processors in output_processor["attributes"]:
for processor in processors:
if 'attribute' in processor and 'accessor' in processor:
attribute_name = f"entity.{span_index}.{processor['attribute']}"
try:
result = eval(processor['accessor'])(instance, args)
if result and isinstance(result, str):
span.set_attribute(attribute_name, result)
except Exception as e:
pass

else:
logger.warning("attribute or accessor not found or incorrect written in entity json")
span_index += 1
else:
logger.warning("attributes not found or incorrect written in entity json")
span.set_attribute("span.count", count)
span_index = 1
if is_root_span(span):
workflow_name = get_value("workflow_name")
if workflow_name:
span.set_attribute(f"entity.{span_index}.name", workflow_name)
# workflow type
package_name = to_wrap.get('package')
for (package, workflow_type) in WORKFLOW_TYPE_MAP.items():
if (package_name is not None and package in package_name):
span.set_attribute(f"entity.{span_index}.type", workflow_type)
span_index += 1
if 'output_processor' in to_wrap:
output_processor=to_wrap['output_processor']
if isinstance(output_processor, dict) and len(output_processor) > 0:
if 'type' in output_processor:
span.set_attribute("span.type", output_processor['type'])
else:
logger.warning("type of span not found or incorrect written in entity json")
count = 0
if 'attributes' in output_processor:
count = len(output_processor["attributes"])
span.set_attribute("entity.count", count)
span_index = 1
for processors in output_processor["attributes"]:
for processor in processors:
attribute = processor.get('attribute')
accessor = processor.get('accessor')

if attribute and accessor:
attribute_name = f"entity.{span_index}.{attribute}"
try:
result = eval(accessor)(instance, args)
if result and isinstance(result, str):
span.set_attribute(attribute_name, result)
except Exception as e:
logger.error(f"Error processing accessor: {e}")
else:
logger.warning(f"{' and '.join([key for key in ['attribute', 'accessor'] if not processor.get(key)])} not found or incorrect in entity JSON")
span_index += 1
else:
logger.warning("attributes not found or incorrect written in entity json")
span.set_attribute("span.count", count)

else:
logger.warning("empty or entities json is not in correct format")
else:
logger.warning("empty or entities json is not in correct format")


def post_task_processing(to_wrap, span, return_value):
Expand All @@ -149,9 +163,6 @@ def post_task_processing(to_wrap, span, return_value):
def pre_task_processing(to_wrap, instance, args, span):
try:
if is_root_span(span):
workflow_name = span.resource.attributes.get("service.name")
span.set_attribute("workflow_name", workflow_name)
update_workflow_type(to_wrap, span)
update_span_with_prompt_input(to_wrap=to_wrap, wrapped_args=args, span=span)
update_span_with_infra_name(span, INFRA_SERVICE_KEY)

Expand All @@ -175,8 +186,7 @@ async def atask_wrapper(tracer, to_wrap, wrapped, instance, args, kwargs):
else:
name = f"langchain.task.{instance.__class__.__name__}"
with tracer.start_as_current_span(name) as span:
if "output_processor" in to_wrap:
process_span(to_wrap["output_processor"], span, instance, args)
process_span(to_wrap["output_processor"], span, instance, args)
pre_task_processing(to_wrap, instance, args, span)
return_value = await wrapped(*args, **kwargs)
post_task_processing(to_wrap, span, return_value)
Expand Down Expand Up @@ -205,8 +215,8 @@ async def allm_wrapper(tracer, to_wrap, wrapped, instance, args, kwargs):
span.add_event(DATA_INPUT_KEY, {QUERY: input_arg_text})
provider_name = set_provider_name(instance)
instance_args = {"provider_name": provider_name}
if 'output_processor' in to_wrap:
process_span(to_wrap['output_processor'], span, instance, instance_args)

process_span(to_wrap, span, instance, instance_args)

return_value = await wrapped(*args, **kwargs)
if 'haystack.components.retrievers' in to_wrap['package'] and 'haystack.retriever' in span.name:
Expand Down Expand Up @@ -238,8 +248,8 @@ def llm_wrapper(tracer: Tracer, to_wrap, wrapped, instance, args, kwargs):
span.add_event(DATA_INPUT_KEY, {QUERY: input_arg_text})
provider_name = set_provider_name(instance)
instance_args = {"provider_name": provider_name}
if 'output_processor' in to_wrap:
process_span(to_wrap['output_processor'], span, instance, instance_args)

process_span(to_wrap, span, instance, instance_args)

return_value = wrapped(*args, **kwargs)
if 'haystack.components.retrievers' in to_wrap['package'] and 'haystack.retriever' in span.name:
Expand Down
2 changes: 1 addition & 1 deletion tests/langchain_custom_output_processor_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
from fake_list_llm import FakeListLLM
from parameterized import parameterized

from src.monocle_apptrace.wrap_common import task_wrapper
from monocle_apptrace.wrap_common import task_wrapper

logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
Expand Down
182 changes: 182 additions & 0 deletions tests/langchain_workflow_name_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@

import json
import logging
import os
import time

import unittest
from unittest.mock import ANY, MagicMock, patch

import pytest
import requests
from dummy_class import DummyClass
from embeddings_wrapper import HuggingFaceEmbeddings
from http_span_exporter import HttpSpanExporter
from langchain.prompts import PromptTemplate
from langchain.schema import StrOutputParser
from langchain_community.vectorstores import faiss
from langchain_core.messages.ai import AIMessage
from langchain_core.runnables import RunnablePassthrough
from monocle_apptrace.constants import (
AZURE_APP_SERVICE_ENV_NAME,
AZURE_APP_SERVICE_NAME,
AZURE_FUNCTION_NAME,
AZURE_FUNCTION_WORKER_ENV_NAME,
AZURE_ML_ENDPOINT_ENV_NAME,
AZURE_ML_SERVICE_NAME,
AWS_LAMBDA_ENV_NAME,
AWS_LAMBDA_SERVICE_NAME
)
from monocle_apptrace.instrumentor import (
MonocleInstrumentor,
set_context_properties,
setup_monocle_telemetry,
)
from monocle_apptrace.wrap_common import (
SESSION_PROPERTIES_KEY,
INFRA_SERVICE_KEY,
PROMPT_INPUT_KEY,
PROMPT_OUTPUT_KEY,
QUERY,
RESPONSE,
update_span_from_llm_response,
)
from monocle_apptrace.wrapper import WrapperMethod
from opentelemetry import trace
from opentelemetry.sdk.resources import SERVICE_NAME, Resource
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor, ConsoleSpanExporter

from fake_list_llm import FakeListLLM
from parameterized import parameterized

from monocle_apptrace.wrap_common import task_wrapper

# Route this test module's DEBUG-level output into traces.txt (truncated each run).
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
file_handler = logging.FileHandler('traces.txt', 'w')
file_handler.setFormatter(
    logging.Formatter('%(asctime)s [%(levelname)s] %(name)s: %(message)s')
)
logger.addHandler(file_handler)

class TestWorkflowEntityProperties(unittest.TestCase):
    """Integration test for the new workflow-entity span attribute format.

    Builds a fake RAG chain (FAISS retriever + FakeListLLM), runs it under
    monocle instrumentation with a mocked HTTP exporter, and asserts that the
    root span carries the workflow identity as ``entity.1.name`` /
    ``entity.1.type`` rather than the old flat ``workflow_name`` attribute.
    """

    # Prompt used by the fake RAG chain under test.
    prompt = PromptTemplate.from_template(
        """
        <s> [INST] You are an assistant for question-answering tasks. Use the following pieces of retrieved context
        to answer the question. If you don't know the answer, just say that you don't know. Use three sentences
        maximum and keep the answer concise. [/INST] </s>
        [INST] Question: {question}
        Context: {context}
        Answer: [/INST]
        """
    )
    # Canned answer returned by FakeListLLM; the test asserts the chain yields it verbatim.
    ragText = """A latte is a coffee drink that consists of espresso, milk, and foam.\
It is served in a large cup or tall glass and has more milk compared to other espresso-based drinks.\
Latte art can be created on the surface of the drink using the milk."""

    def __format_docs(self, docs):
        """Join retrieved document page contents into one context string."""
        return "\n\n ".join(doc.page_content for doc in docs)

    def __createChain(self):
        """Build and return an instrumented fake RAG chain.

        Side effects: installs a TracerProvider with a console exporter,
        instruments via MonocleInstrumentor, and stashes the instrumentor and
        processor on ``self`` so ``test_llm_chain`` can uninstrument later.
        """
        resource = Resource(attributes={
            SERVICE_NAME: "coffee_rag_fake"
        })
        traceProvider = TracerProvider(resource=resource)
        exporter = ConsoleSpanExporter()
        monocleProcessor = BatchSpanProcessor(exporter)

        traceProvider.add_span_processor(monocleProcessor)
        trace.set_tracer_provider(traceProvider)
        self.instrumentor = MonocleInstrumentor()
        self.instrumentor.instrument()
        self.processor = monocleProcessor
        responses = [self.ragText]
        llm = FakeListLLM(responses=responses)
        llm.api_base = "https://example.com/"
        embeddings = HuggingFaceEmbeddings(model_id="multi-qa-mpnet-base-dot-v1")
        my_path = os.path.abspath(os.path.dirname(__file__))
        model_path = os.path.join(my_path, "./vector_data/coffee_embeddings")
        # allow_dangerous_deserialization: the embeddings file is a trusted local fixture.
        vectorstore = faiss.FAISS.load_local(model_path, embeddings, allow_dangerous_deserialization=True)

        retriever = vectorstore.as_retriever()

        rag_chain = (
            {"context": retriever | self.__format_docs, "question": RunnablePassthrough()}
            | self.prompt
            | llm
            | StrOutputParser()
        )
        return rag_chain

    def setUp(self):
        # Fake credentials/endpoint so HttpSpanExporter can be constructed offline.
        os.environ["HTTP_API_KEY"] = "key1"
        os.environ["HTTP_INGESTION_ENDPOINT"] = "https://localhost:3000/api/v1/traces"

    def tearDown(self) -> None:
        return super().tearDown()

    @parameterized.expand([
        ("1", AZURE_ML_ENDPOINT_ENV_NAME, AZURE_ML_SERVICE_NAME),
        ("2", AZURE_FUNCTION_WORKER_ENV_NAME, AZURE_FUNCTION_NAME),
        ("3", AZURE_APP_SERVICE_ENV_NAME, AZURE_APP_SERVICE_NAME),
        ("4", AWS_LAMBDA_ENV_NAME, AWS_LAMBDA_SERVICE_NAME),
    ])
    @patch.object(requests.Session, 'post')
    def test_llm_chain(self, test_name, test_input_infra, test_output_infra, mock_post):
        """Invoke the chain under each simulated infra env var and verify the
        exported root span's entity name/type attributes.

        test_output_infra is accepted for parity with the parameterized cases
        but is not asserted on here.
        """
        app_name = "test"
        setup_monocle_telemetry(
            workflow_name=app_name,
            span_processors=[
                BatchSpanProcessor(HttpSpanExporter("https://localhost:3000/api/v1/traces"))
            ],
            wrapper_methods=[])
        try:

            os.environ[test_input_infra] = "1"
            context_key = "context_key_1"
            context_value = "context_value_1"
            set_context_properties({context_key: context_value})

            self.chain = self.__createChain()
            mock_post.return_value.status_code = 201
            mock_post.return_value.json.return_value = 'mock response'

            query = "what is latte"
            response = self.chain.invoke(query, config={})
            assert response == self.ragText
            # Give the BatchSpanProcessor time to flush spans to the (mocked) exporter.
            time.sleep(5)
            mock_post.assert_called_with(
                url='https://localhost:3000/api/v1/traces',
                data=ANY,
                timeout=ANY
            )

            # mock_post.call_args holds the parameters of the last post call;
            # decode its body to assert on individual span fields.
            dataBodyStr = mock_post.call_args.kwargs['data']
            dataJson = json.loads(dataBodyStr)

            # parent_id is serialized as the string "None" for root spans.
            root_span = [x for x in dataJson["batch"] if x["parent_id"] == "None"][0]

            # workflow_name and workflow_type in new format entity.{index}.name and entity.{index}.type
            assert root_span["attributes"]["entity.1.name"] == "test"
            assert root_span["attributes"]["entity.1.type"] == "workflow.langchain"

        finally:
            os.environ.pop(test_input_infra)
            try:
                # Best-effort cleanup so later tests start uninstrumented.
                if (self.instrumentor is not None):
                    self.instrumentor.uninstrument()
            except Exception as e:
                print("Uninstrument failed:", e)


# Allow running this test module directly (outside a pytest/unittest runner).
if __name__ == '__main__':
    unittest.main()


Loading

0 comments on commit 0bb528c

Please sign in to comment.