diff --git a/src/monocle_apptrace/metamodel/maps/langchain_methods.json b/src/monocle_apptrace/metamodel/maps/langchain_methods.json
index 377cb62..9618d9e 100644
--- a/src/monocle_apptrace/metamodel/maps/langchain_methods.json
+++ b/src/monocle_apptrace/metamodel/maps/langchain_methods.json
@@ -27,21 +27,24 @@
"object": "BaseChatModel",
"method": "ainvoke",
"wrapper_package": "wrap_common",
- "wrapper_method": "allm_wrapper"
+ "wrapper_method": "allm_wrapper",
+ "output_processor": ["metamodel/maps/attributes/inference/langchain_entities.json"]
},
{
"package": "langchain_core.language_models.llms",
"object": "LLM",
"method": "_generate",
"wrapper_package": "wrap_common",
- "wrapper_method": "llm_wrapper"
+ "wrapper_method": "llm_wrapper",
+ "output_processor": ["metamodel/maps/attributes/inference/langchain_entities.json"]
},
{
"package": "langchain_core.language_models.llms",
"object": "LLM",
"method": "_agenerate",
"wrapper_package": "wrap_common",
- "wrapper_method": "llm_wrapper"
+ "wrapper_method": "allm_wrapper",
+ "output_processor": ["metamodel/maps/attributes/inference/langchain_entities.json"]
},
{
"package": "langchain_core.retrievers",
@@ -57,7 +60,8 @@
"object": "BaseRetriever",
"method": "ainvoke",
"wrapper_package": "wrap_common",
- "wrapper_method": "atask_wrapper"
+ "wrapper_method": "atask_wrapper",
+ "output_processor": ["metamodel/maps/attributes/retrieval/langchain_entities.json"]
},
{
"package": "langchain.schema",
@@ -106,4 +110,4 @@
"wrapper_method": "atask_wrapper"
}
]
-}
+}
\ No newline at end of file
diff --git a/src/monocle_apptrace/metamodel/maps/llamaindex_methods.json b/src/monocle_apptrace/metamodel/maps/llamaindex_methods.json
index 81f6e52..a7292ea 100644
--- a/src/monocle_apptrace/metamodel/maps/llamaindex_methods.json
+++ b/src/monocle_apptrace/metamodel/maps/llamaindex_methods.json
@@ -15,7 +15,8 @@
"method": "aretrieve",
"span_name": "llamaindex.retrieve",
"wrapper_package": "wrap_common",
- "wrapper_method": "atask_wrapper"
+ "wrapper_method": "atask_wrapper",
+ "output_processor": ["metamodel/maps/attributes/retrieval/llamaindex_entities.json"]
},
{
"package": "llama_index.core.base.base_query_engine",
@@ -39,7 +40,8 @@
"method": "chat",
"span_name": "llamaindex.llmchat",
"wrapper_package": "wrap_common",
- "wrapper_method": "task_wrapper"
+ "wrapper_method": "task_wrapper",
+ "output_processor": ["metamodel/maps/attributes/inference/llamaindex_entities.json"]
},
{
"package": "llama_index.core.llms.custom",
@@ -47,7 +49,8 @@
"method": "achat",
"span_name": "llamaindex.llmchat",
"wrapper_package": "wrap_common",
- "wrapper_method": "atask_wrapper"
+ "wrapper_method": "atask_wrapper",
+ "output_processor": ["metamodel/maps/attributes/inference/llamaindex_entities.json"]
},
{
"package": "llama_index.llms.openai.base",
@@ -64,7 +67,8 @@
"method": "achat",
"span_name": "llamaindex.openai",
"wrapper_package": "wrap_common",
- "wrapper_method": "allm_wrapper"
+ "wrapper_method": "allm_wrapper",
+ "output_processor": ["metamodel/maps/attributes/inference/llamaindex_entities.json"]
}
]
-}
+}
\ No newline at end of file
diff --git a/src/monocle_apptrace/utils.py b/src/monocle_apptrace/utils.py
index 624580e..ab386a7 100644
--- a/src/monocle_apptrace/utils.py
+++ b/src/monocle_apptrace/utils.py
@@ -59,7 +59,10 @@ def load_output_processor(wrapper_method, attributes_config_base_path):
logger.info(f'Output processor file path is: {output_processor_file_path}')
if isinstance(output_processor_file_path, str) and output_processor_file_path: # Combined condition
- absolute_file_path = os.path.join(attributes_config_base_path, output_processor_file_path)
+ if not attributes_config_base_path:
+ absolute_file_path = os.path.abspath(output_processor_file_path)
+ else:
+ absolute_file_path = os.path.join(attributes_config_base_path, output_processor_file_path)
logger.info(f'Absolute file path is: {absolute_file_path}')
try:
@@ -107,7 +110,7 @@ def process_wrapper_method_config(
wrapper_method["span_name_getter"] = get_wrapper_method(
wrapper_method["span_name_getter_package"],
wrapper_method["span_name_getter_method"])
- if "output_processor" in wrapper_method:
+ if "output_processor" in wrapper_method and wrapper_method["output_processor"]:
load_output_processor(wrapper_method, attributes_config_base_path)
def get_wrapper_method(package_name: str, method_name: str):
@@ -158,4 +161,4 @@ def get_attribute(key: str) -> str:
Returns:
The value associated with the given key.
"""
- return get_value(key)
+ return get_value(key)
\ No newline at end of file
diff --git a/src/monocle_apptrace/wrap_common.py b/src/monocle_apptrace/wrap_common.py
index d30fa40..690cad6 100644
--- a/src/monocle_apptrace/wrap_common.py
+++ b/src/monocle_apptrace/wrap_common.py
@@ -76,6 +76,7 @@ def get_embedding_model_for_vectorstore(instance):
},
}
+
@with_tracer_wrapper
def task_wrapper(tracer: Tracer, to_wrap, wrapped, instance, args, kwargs):
"""Instruments and calls every function defined in TO_WRAP."""
@@ -93,40 +94,46 @@ def task_wrapper(tracer: Tracer, to_wrap, wrapped, instance, args, kwargs):
with tracer.start_as_current_span(name) as span:
if "output_processor" in to_wrap:
- process_span(to_wrap["output_processor"],span,instance,args)
+ process_span(to_wrap["output_processor"], span, instance, args)
pre_task_processing(to_wrap, instance, args, span)
return_value = wrapped(*args, **kwargs)
post_task_processing(to_wrap, span, return_value)
return return_value
-def process_span(output_processor,span,instance,args):
- # Check if the output_processor is a valid JSON (in Python, that means it's a dictionary)
- if isinstance(output_processor, dict) and len(output_processor)>0:
- if 'type' in output_processor:
- span.set_attribute("span.type", output_processor['type'])
- else:
- logger.warning("type of span not found or incorrect written in entity json")
- count=0
- if 'attributes' in output_processor:
- count = len(output_processor["attributes"])
- span.set_attribute("entity.count", count)
- span_index = 1
- for processors in output_processor["attributes"]:
- for processor in processors:
- if 'attribute' in processor and 'accessor' in processor:
- attribute_name = f"entity.{span_index}.{processor['attribute']}"
- result = eval(processor['accessor'])(instance, args)
- span.set_attribute(attribute_name, result)
- else:
- logger.warning("attribute or accessor not found or incorrect written in entity json")
- span_index += 1
- else:
- logger.warning("attributes not found or incorrect written in entity json")
- span.set_attribute("span.count", count)
+def process_span(output_processor, span, instance, args):
+ # Check if the output_processor is a valid JSON (in Python, that means it's a dictionary)
+ if isinstance(output_processor, dict) and len(output_processor) > 0:
+ if 'type' in output_processor:
+ span.set_attribute("span.type", output_processor['type'])
+ else:
+ logger.warning("type of span not found or incorrect written in entity json")
+ count = 0
+ if 'attributes' in output_processor:
+ count = len(output_processor["attributes"])
+ span.set_attribute("entity.count", count)
+ span_index = 1
+ for processors in output_processor["attributes"]:
+ for processor in processors:
+ if 'attribute' in processor and 'accessor' in processor:
+ attribute_name = f"entity.{span_index}.{processor['attribute']}"
+ try:
+ result = eval(processor['accessor'])(instance, args)
+                                if result is not None and isinstance(result, (str, int, float, bool)):
+ span.set_attribute(attribute_name, result)
+                            except Exception as e:
+                                logger.warning("error evaluating accessor for attribute '%s': %s", processor.get('attribute'), e)
+
+ else:
+ logger.warning("attribute or accessor not found or incorrect written in entity json")
+ span_index += 1
else:
- logger.warning("empty or entities json is not in correct format")
+ logger.warning("attributes not found or incorrect written in entity json")
+ span.set_attribute("span.count", count)
+
+ else:
+ logger.warning("empty or entities json is not in correct format")
def post_task_processing(to_wrap, span, return_value):
@@ -138,6 +145,7 @@ def post_task_processing(to_wrap, span, return_value):
except:
logger.exception("exception in post_task_processing")
+
def pre_task_processing(to_wrap, instance, args, span):
try:
if is_root_span(span):
@@ -150,7 +158,7 @@ def pre_task_processing(to_wrap, instance, args, span):
update_span_with_context_input(to_wrap=to_wrap, wrapped_args=args, span=span)
except:
logger.exception("exception in pre_task_processing")
-
+
@with_tracer_wrapper
async def atask_wrapper(tracer, to_wrap, wrapped, instance, args, kwargs):
@@ -167,6 +175,8 @@ async def atask_wrapper(tracer, to_wrap, wrapped, instance, args, kwargs):
else:
name = f"langchain.task.{instance.__class__.__name__}"
with tracer.start_as_current_span(name) as span:
+ if "output_processor" in to_wrap:
+ process_span(to_wrap["output_processor"], span, instance, args)
pre_task_processing(to_wrap, instance, args, span)
return_value = await wrapped(*args, **kwargs)
post_task_processing(to_wrap, span, return_value)
@@ -190,11 +200,18 @@ async def allm_wrapper(tracer, to_wrap, wrapped, instance, args, kwargs):
else:
name = f"langchain.task.{instance.__class__.__name__}"
with tracer.start_as_current_span(name) as span:
- update_llm_endpoint(curr_span=span, instance=instance)
+ if 'haystack.components.retrievers' in to_wrap['package'] and 'haystack.retriever' in span.name:
+ input_arg_text = get_attribute(DATA_INPUT_KEY)
+ span.add_event(DATA_INPUT_KEY, {QUERY: input_arg_text})
+ provider_name = set_provider_name(instance)
+ instance_args = {"provider_name": provider_name}
+ if 'output_processor' in to_wrap:
+ process_span(to_wrap['output_processor'], span, instance, instance_args)
return_value = await wrapped(*args, **kwargs)
-
- update_span_from_llm_response(response = return_value, span = span, instance=instance)
+ if 'haystack.components.retrievers' in to_wrap['package'] and 'haystack.retriever' in span.name:
+ update_span_with_context_output(to_wrap=to_wrap, return_value=return_value, span=span)
+ update_span_from_llm_response(response=return_value, span=span, instance=instance)
return return_value
@@ -227,7 +244,7 @@ def llm_wrapper(tracer: Tracer, to_wrap, wrapped, instance, args, kwargs):
return_value = wrapped(*args, **kwargs)
if 'haystack.components.retrievers' in to_wrap['package'] and 'haystack.retriever' in span.name:
update_span_with_context_output(to_wrap=to_wrap, return_value=return_value, span=span)
- update_span_from_llm_response(response = return_value, span = span,instance=instance)
+ update_span_from_llm_response(response=return_value, span=span, instance=instance)
return return_value
@@ -261,6 +278,7 @@ def update_llm_endpoint(curr_span: Span, instance):
inference_endpoint=inference_ep
)
+
def set_provider_name(instance):
provider_url = ""
@@ -293,7 +311,8 @@ def get_input_from_args(chain_args):
return chain_args[0]
return ""
-def update_span_from_llm_response(response, span: Span ,instance):
+
+def update_span_from_llm_response(response, span: Span, instance):
# extract token uasge from langchain openai
if (response is not None and hasattr(response, "response_metadata")):
response_metadata = response.response_metadata
@@ -326,6 +345,7 @@ def update_span_from_llm_response(response, span: Span ,instance):
except AttributeError:
token_usage = None
+
def update_workflow_type(to_wrap, span: Span):
package_name = to_wrap.get('package')
@@ -346,6 +366,7 @@ def update_span_with_context_input(to_wrap, wrapped_args, span: Span):
if input_arg_text:
span.add_event(DATA_INPUT_KEY, {QUERY: input_arg_text})
+
def update_span_with_context_output(to_wrap, return_value, span: Span):
package_name: str = to_wrap.get('package')
output_arg_text = ""
@@ -362,6 +383,7 @@ def update_span_with_context_output(to_wrap, return_value, span: Span):
if output_arg_text:
span.add_event(DATA_OUTPUT_KEY, {RESPONSE: output_arg_text})
+
def update_span_with_prompt_input(to_wrap, wrapped_args, span: Span):
input_arg_text = wrapped_args[0]
@@ -370,6 +392,7 @@ def update_span_with_prompt_input(to_wrap, wrapped_args, span: Span):
else:
span.add_event(PROMPT_INPUT_KEY, {QUERY: input_arg_text})
+
def update_span_with_prompt_output(to_wrap, wrapped_args, span: Span):
package_name: str = to_wrap.get('package')
if isinstance(wrapped_args, str):
@@ -377,5 +400,4 @@ def update_span_with_prompt_output(to_wrap, wrapped_args, span: Span):
if isinstance(wrapped_args, dict):
span.add_event(PROMPT_OUTPUT_KEY, wrapped_args)
if "llama_index.core.base.base_query_engine" in package_name:
- span.add_event(PROMPT_OUTPUT_KEY, {RESPONSE:wrapped_args.response})
-
\ No newline at end of file
+ span.add_event(PROMPT_OUTPUT_KEY, {RESPONSE: wrapped_args.response})
diff --git a/tests/entities.json b/tests/entities.json
new file mode 100644
index 0000000..5fc8ef4
--- /dev/null
+++ b/tests/entities.json
@@ -0,0 +1,17 @@
+{
+ "type": "retrieval",
+ "attributes": [
+ [
+ {
+ "_comment": "vector store name and type",
+ "attribute": "name",
+ "accessor": "lambda instance,args: type(instance.vectorstore).__name__"
+ },
+ {
+ "attribute": "type",
+ "accessor": "lambda instance,args: 'vectorstore.'+type(instance.vectorstore).__name__"
+ }
+ ]
+
+ ]
+}
\ No newline at end of file
diff --git a/tests/langchain_async_processor_test.py b/tests/langchain_async_processor_test.py
new file mode 100644
index 0000000..14c909f
--- /dev/null
+++ b/tests/langchain_async_processor_test.py
@@ -0,0 +1,167 @@
+from unittest import IsolatedAsyncioTestCase
+import unittest
+
+import json
+import logging
+import os
+import time
+
+
+from unittest.mock import ANY, MagicMock, patch
+
+
+import requests
+from dummy_class import DummyClass
+from embeddings_wrapper import HuggingFaceEmbeddings
+from http_span_exporter import HttpSpanExporter
+from langchain.prompts import PromptTemplate
+from langchain.schema import StrOutputParser
+from langchain_community.vectorstores import faiss
+from langchain_core.messages.ai import AIMessage
+from langchain_core.runnables import RunnablePassthrough
+from monocle_apptrace.instrumentor import (
+ MonocleInstrumentor,
+ set_context_properties,
+ setup_monocle_telemetry,
+)
+from monocle_apptrace.wrap_common import (
+ SESSION_PROPERTIES_KEY,
+ PROMPT_INPUT_KEY,
+ PROMPT_OUTPUT_KEY,
+ QUERY,
+ RESPONSE,
+ update_span_from_llm_response,
+)
+from monocle_apptrace.wrapper import WrapperMethod
+from opentelemetry import trace
+from opentelemetry.sdk.resources import SERVICE_NAME, Resource
+from opentelemetry.sdk.trace import TracerProvider
+from opentelemetry.sdk.trace.export import BatchSpanProcessor, ConsoleSpanExporter
+from monocle_apptrace.wrap_common import allm_wrapper
+from fake_list_llm import FakeListLLM
+
+logger = logging.getLogger()
+logger.setLevel(logging.DEBUG)
+fileHandler = logging.FileHandler('traces.txt', 'w')
+formatter = logging.Formatter('%(asctime)s [%(levelname)s] %(name)s: %(message)s')
+fileHandler.setFormatter(formatter)
+logger.addHandler(fileHandler)
+events = []
+
+
+class Test(IsolatedAsyncioTestCase):
+ prompt = PromptTemplate.from_template(
+ """
+ [INST] You are an assistant for question-answering tasks. Use the following pieces of retrieved context
+ to answer the question. If you don't know the answer, just say that you don't know. Use three sentences
+ maximum and keep the answer concise. [/INST]
+ [INST] Question: {question}
+ Context: {context}
+ Answer: [/INST]
+ """
+ )
+ ragText = """A latte is a coffee drink that consists of espresso, milk, and foam.\
+ It is served in a large cup or tall glass and has more milk compared to other espresso-based drinks.\
+ Latte art can be created on the surface of the drink using the milk."""
+
+ def __format_docs(self, docs):
+ return "\n\n ".join(doc.page_content for doc in docs)
+
+ def __createChain(self):
+
+ resource = Resource(attributes={
+ SERVICE_NAME: "coffee_rag_fake"
+ })
+ traceProvider = TracerProvider(resource=resource)
+ exporter = ConsoleSpanExporter()
+ monocleProcessor = BatchSpanProcessor(exporter)
+
+ traceProvider.add_span_processor(monocleProcessor)
+ trace.set_tracer_provider(traceProvider)
+ self.instrumentor = MonocleInstrumentor()
+ self.instrumentor.instrument()
+ self.processor = monocleProcessor
+ responses = [self.ragText]
+ llm = FakeListLLM(responses=responses)
+ llm.api_base = "https://example.com/"
+ embeddings = HuggingFaceEmbeddings(model_id="multi-qa-mpnet-base-dot-v1")
+ my_path = os.path.abspath(os.path.dirname(__file__))
+ model_path = os.path.join(my_path, "./vector_data/coffee_embeddings")
+ vectorstore = faiss.FAISS.load_local(model_path, embeddings, allow_dangerous_deserialization=True)
+
+ retriever = vectorstore.as_retriever()
+
+ rag_chain = (
+ {"context": retriever | self.__format_docs, "question": RunnablePassthrough()}
+ | self.prompt
+ | llm
+ | StrOutputParser()
+ )
+ return rag_chain
+
+ def setUp(self):
+ events.append("setUp")
+
+ async def asyncSetUp(self):
+ os.environ["HTTP_API_KEY"] = "key1"
+ os.environ["HTTP_INGESTION_ENDPOINT"] = "https://localhost:3000/api/v1/traces"
+
+ @patch.object(requests.Session, 'post')
+ async def test_response(self, mock_post):
+ app_name = "test"
+ wrap_method = MagicMock(return_value=3)
+ setup_monocle_telemetry(
+ workflow_name=app_name,
+ span_processors=[
+ BatchSpanProcessor(HttpSpanExporter("https://localhost:3000/api/v1/traces"))
+ ],
+ wrapper_methods=[
+ WrapperMethod(
+ package="langchain.chat_models.base",
+ object_name="BaseChatModel",
+ method="ainvoke",
+ wrapper=allm_wrapper
+ )
+
+ ])
+ try:
+ context_key = "context_key_1"
+ context_value = "context_value_1"
+ set_context_properties({context_key: context_value})
+
+ self.chain = self.__createChain()
+ mock_post.return_value.status_code = 201
+ mock_post.return_value.json.return_value = 'mock response'
+
+ query = "what is latte"
+ response = await self.chain.ainvoke(query, config={})
+ assert response == self.ragText
+ time.sleep(5)
+ mock_post.assert_called_with(
+ url='https://localhost:3000/api/v1/traces',
+ data=ANY,
+ timeout=ANY
+ )
+
+ '''mock_post.call_args gives the parameters used to make post call.
+ This can be used to do more asserts'''
+ dataBodyStr = mock_post.call_args.kwargs['data']
+ dataJson = json.loads(dataBodyStr) # more asserts can be added on individual fields
+
+ llm_span = [x for x in dataJson["batch"] if "FakeListLLM" in x["name"]][0]
+
+ assert llm_span["attributes"]["span.type"] == "inference"
+ assert llm_span["attributes"]["entity.1.provider_name"] == "example.com"
+ assert llm_span["attributes"]["entity.1.type"] == "inference.azure_oai"
+ assert llm_span["attributes"]["entity.1.inference_endpoint"] == "https://example.com/"
+
+ finally:
+ try:
+ if (self.instrumentor is not None):
+ self.instrumentor.uninstrument()
+ except Exception as e:
+ print("Uninstrument failed:", e)
+
+
+if __name__ == "__main__":
+ unittest.main()
\ No newline at end of file
diff --git a/tests/langchain_custom_output_processor_test.py b/tests/langchain_custom_output_processor_test.py
new file mode 100644
index 0000000..7f980d9
--- /dev/null
+++ b/tests/langchain_custom_output_processor_test.py
@@ -0,0 +1,190 @@
+
+import json
+import logging
+import os
+import time
+
+import unittest
+from unittest.mock import ANY, MagicMock, patch
+
+import pytest
+import requests
+from dummy_class import DummyClass
+from embeddings_wrapper import HuggingFaceEmbeddings
+from http_span_exporter import HttpSpanExporter
+from langchain.prompts import PromptTemplate
+from langchain.schema import StrOutputParser
+from langchain_community.vectorstores import faiss
+from langchain_core.messages.ai import AIMessage
+from langchain_core.runnables import RunnablePassthrough
+from monocle_apptrace.constants import (
+ AZURE_APP_SERVICE_ENV_NAME,
+ AZURE_APP_SERVICE_NAME,
+ AZURE_FUNCTION_NAME,
+ AZURE_FUNCTION_WORKER_ENV_NAME,
+ AZURE_ML_ENDPOINT_ENV_NAME,
+ AZURE_ML_SERVICE_NAME,
+ AWS_LAMBDA_ENV_NAME,
+ AWS_LAMBDA_SERVICE_NAME
+)
+from monocle_apptrace.instrumentor import (
+ MonocleInstrumentor,
+ set_context_properties,
+ setup_monocle_telemetry,
+)
+from monocle_apptrace.wrap_common import (
+ SESSION_PROPERTIES_KEY,
+ INFRA_SERVICE_KEY,
+ PROMPT_INPUT_KEY,
+ PROMPT_OUTPUT_KEY,
+ QUERY,
+ RESPONSE,
+ update_span_from_llm_response,
+)
+from monocle_apptrace.wrapper import WrapperMethod
+from opentelemetry import trace
+from opentelemetry.sdk.resources import SERVICE_NAME, Resource
+from opentelemetry.sdk.trace import TracerProvider
+from opentelemetry.sdk.trace.export import BatchSpanProcessor, ConsoleSpanExporter
+
+from fake_list_llm import FakeListLLM
+from parameterized import parameterized
+
+from monocle_apptrace.wrap_common import task_wrapper
+
+logger = logging.getLogger()
+logger.setLevel(logging.DEBUG)
+fileHandler = logging.FileHandler('traces.txt' ,'w')
+formatter = logging.Formatter('%(asctime)s [%(levelname)s] %(name)s: %(message)s')
+fileHandler.setFormatter(formatter)
+logger.addHandler(fileHandler)
+
+class TestHandler(unittest.TestCase):
+
+ prompt = PromptTemplate.from_template(
+ """
+ [INST] You are an assistant for question-answering tasks. Use the following pieces of retrieved context
+ to answer the question. If you don't know the answer, just say that you don't know. Use three sentences
+ maximum and keep the answer concise. [/INST]
+ [INST] Question: {question}
+ Context: {context}
+ Answer: [/INST]
+ """
+ )
+ ragText = """A latte is a coffee drink that consists of espresso, milk, and foam.\
+ It is served in a large cup or tall glass and has more milk compared to other espresso-based drinks.\
+ Latte art can be created on the surface of the drink using the milk."""
+
+ def __format_docs(self, docs):
+ return "\n\n ".join(doc.page_content for doc in docs)
+
+ def __createChain(self):
+
+ resource = Resource(attributes={
+ SERVICE_NAME: "coffee_rag_fake"
+ })
+ traceProvider = TracerProvider(resource=resource)
+ exporter = ConsoleSpanExporter()
+ monocleProcessor = BatchSpanProcessor(exporter)
+
+ traceProvider.add_span_processor(monocleProcessor)
+ trace.set_tracer_provider(traceProvider)
+ self.instrumentor = MonocleInstrumentor()
+ self.instrumentor.instrument()
+ self.processor = monocleProcessor
+ responses =[self.ragText]
+ llm = FakeListLLM(responses=responses)
+ llm.api_base = "https://example.com/"
+ embeddings = HuggingFaceEmbeddings(model_id = "multi-qa-mpnet-base-dot-v1")
+ my_path = os.path.abspath(os.path.dirname(__file__))
+ model_path = os.path.join(my_path, "./vector_data/coffee_embeddings")
+ vectorstore = faiss.FAISS.load_local(model_path, embeddings, allow_dangerous_deserialization = True)
+
+ retriever = vectorstore.as_retriever()
+
+ rag_chain = (
+ {"context": retriever| self.__format_docs, "question": RunnablePassthrough()}
+ | self.prompt
+ | llm
+ | StrOutputParser()
+ )
+ return rag_chain
+
+ def setUp(self):
+ os.environ["HTTP_API_KEY"] = "key1"
+ os.environ["HTTP_INGESTION_ENDPOINT"] = "https://localhost:3000/api/v1/traces"
+
+
+ def tearDown(self) -> None:
+ return super().tearDown()
+
+ @parameterized.expand([
+ ("1", AZURE_ML_ENDPOINT_ENV_NAME, AZURE_ML_SERVICE_NAME),
+ ("2", AZURE_FUNCTION_WORKER_ENV_NAME, AZURE_FUNCTION_NAME),
+ ("3", AZURE_APP_SERVICE_ENV_NAME, AZURE_APP_SERVICE_NAME),
+ ("4", AWS_LAMBDA_ENV_NAME, AWS_LAMBDA_SERVICE_NAME),
+ ])
+ @patch.object(requests.Session, 'post')
+ def test_llm_chain(self, test_name, test_input_infra, test_output_infra, mock_post):
+ app_name = "test"
+ wrap_method = MagicMock(return_value=3)
+ setup_monocle_telemetry(
+ workflow_name=app_name,
+ span_processors=[
+ BatchSpanProcessor(HttpSpanExporter("https://localhost:3000/api/v1/traces"))
+ ],
+ wrapper_methods=[
+ WrapperMethod(
+ package="langchain_core.retrievers",
+ object_name="BaseRetriever",
+ method="invoke",
+ wrapper=task_wrapper,
+ output_processor=["entities.json"]
+ ),
+
+ ])
+ try:
+
+ os.environ[test_input_infra] = "1"
+ context_key = "context_key_1"
+ context_value = "context_value_1"
+ set_context_properties({context_key: context_value})
+
+ self.chain = self.__createChain()
+ mock_post.return_value.status_code = 201
+ mock_post.return_value.json.return_value = 'mock response'
+
+ query = "what is latte"
+ response = self.chain.invoke(query, config={})
+ assert response == self.ragText
+ time.sleep(5)
+ mock_post.assert_called_with(
+ url = 'https://localhost:3000/api/v1/traces',
+ data=ANY,
+ timeout=ANY
+ )
+
+ '''mock_post.call_args gives the parameters used to make post call.
+ This can be used to do more asserts'''
+ dataBodyStr = mock_post.call_args.kwargs['data']
+ dataJson = json.loads(dataBodyStr) # more asserts can be added on individual fields
+
+ llm_vector_store_retriever_span = [x for x in dataJson["batch"] if 'langchain.task.VectorStoreRetriever' in x["name"]][0]
+
+ assert llm_vector_store_retriever_span["attributes"]["span.type"] == "retrieval"
+ assert llm_vector_store_retriever_span["attributes"]["entity.1.name"] == "FAISS"
+ assert llm_vector_store_retriever_span["attributes"]["entity.1.type"] == "vectorstore.FAISS"
+
+ finally:
+ os.environ.pop(test_input_infra)
+ try:
+ if(self.instrumentor is not None):
+ self.instrumentor.uninstrument()
+ except Exception as e:
+ print("Uninstrument failed:", e)
+
+
+if __name__ == '__main__':
+ unittest.main()
+
+