diff --git a/src/monocle_apptrace/metamodel/maps/langchain_methods.json b/src/monocle_apptrace/metamodel/maps/langchain_methods.json index 377cb62..9618d9e 100644 --- a/src/monocle_apptrace/metamodel/maps/langchain_methods.json +++ b/src/monocle_apptrace/metamodel/maps/langchain_methods.json @@ -27,21 +27,24 @@ "object": "BaseChatModel", "method": "ainvoke", "wrapper_package": "wrap_common", - "wrapper_method": "allm_wrapper" + "wrapper_method": "allm_wrapper", + "output_processor": ["metamodel/maps/attributes/inference/langchain_entities.json"] }, { "package": "langchain_core.language_models.llms", "object": "LLM", "method": "_generate", "wrapper_package": "wrap_common", - "wrapper_method": "llm_wrapper" + "wrapper_method": "llm_wrapper", + "output_processor": ["metamodel/maps/attributes/inference/langchain_entities.json"] }, { "package": "langchain_core.language_models.llms", "object": "LLM", "method": "_agenerate", "wrapper_package": "wrap_common", - "wrapper_method": "llm_wrapper" + "wrapper_method": "allm_wrapper", + "output_processor": ["metamodel/maps/attributes/inference/langchain_entities.json"] }, { "package": "langchain_core.retrievers", @@ -57,7 +60,8 @@ "object": "BaseRetriever", "method": "ainvoke", "wrapper_package": "wrap_common", - "wrapper_method": "atask_wrapper" + "wrapper_method": "atask_wrapper", + "output_processor": ["metamodel/maps/attributes/retrieval/langchain_entities.json"] }, { "package": "langchain.schema", @@ -106,4 +110,4 @@ "wrapper_method": "atask_wrapper" } ] -} +} \ No newline at end of file diff --git a/src/monocle_apptrace/metamodel/maps/llamaindex_methods.json b/src/monocle_apptrace/metamodel/maps/llamaindex_methods.json index 81f6e52..a7292ea 100644 --- a/src/monocle_apptrace/metamodel/maps/llamaindex_methods.json +++ b/src/monocle_apptrace/metamodel/maps/llamaindex_methods.json @@ -15,7 +15,8 @@ "method": "aretrieve", "span_name": "llamaindex.retrieve", "wrapper_package": "wrap_common", - "wrapper_method": "atask_wrapper" + "wrapper_method": "atask_wrapper", + "output_processor": ["metamodel/maps/attributes/retrieval/llamaindex_entities.json"] }, { "package": "llama_index.core.base.base_query_engine", @@ -39,7 +40,8 @@ "method": "chat", "span_name": "llamaindex.llmchat", "wrapper_package": "wrap_common", - "wrapper_method": "task_wrapper" + "wrapper_method": "task_wrapper", + "output_processor": ["metamodel/maps/attributes/inference/llamaindex_entities.json"] }, { "package": "llama_index.core.llms.custom", @@ -47,7 +49,8 @@ "method": "achat", "span_name": "llamaindex.llmchat", "wrapper_package": "wrap_common", - "wrapper_method": "atask_wrapper" + "wrapper_method": "atask_wrapper", + "output_processor": ["metamodel/maps/attributes/inference/llamaindex_entities.json"] }, { "package": "llama_index.llms.openai.base", @@ -64,7 +67,8 @@ "method": "achat", "span_name": "llamaindex.openai", "wrapper_package": "wrap_common", - "wrapper_method": "allm_wrapper" + "wrapper_method": "allm_wrapper", + "output_processor": ["metamodel/maps/attributes/inference/llamaindex_entities.json"] } ] -} +} \ No newline at end of file diff --git a/src/monocle_apptrace/utils.py b/src/monocle_apptrace/utils.py index 624580e..ab386a7 100644 --- a/src/monocle_apptrace/utils.py +++ b/src/monocle_apptrace/utils.py @@ -59,7 +59,10 @@ def load_output_processor(wrapper_method, attributes_config_base_path): logger.info(f'Output processor file path is: {output_processor_file_path}') if isinstance(output_processor_file_path, str) and output_processor_file_path: # Combined condition - absolute_file_path = os.path.join(attributes_config_base_path, output_processor_file_path) + if not attributes_config_base_path: + absolute_file_path = os.path.abspath(output_processor_file_path) + else: + absolute_file_path = os.path.join(attributes_config_base_path, output_processor_file_path) logger.info(f'Absolute file path is: {absolute_file_path}') try: @@ -107,7 +110,7 @@ def process_wrapper_method_config( wrapper_method["span_name_getter"] = get_wrapper_method( wrapper_method["span_name_getter_package"], wrapper_method["span_name_getter_method"]) - if "output_processor" in wrapper_method: + if "output_processor" in wrapper_method and wrapper_method["output_processor"]: load_output_processor(wrapper_method, attributes_config_base_path) def get_wrapper_method(package_name: str, method_name: str): @@ -158,4 +161,4 @@ def get_attribute(key: str) -> str: Returns: The value associated with the given key. """ - return get_value(key) + return get_value(key) \ No newline at end of file diff --git a/src/monocle_apptrace/wrap_common.py b/src/monocle_apptrace/wrap_common.py index d30fa40..690cad6 100644 --- a/src/monocle_apptrace/wrap_common.py +++ b/src/monocle_apptrace/wrap_common.py @@ -76,6 +76,7 @@ def get_embedding_model_for_vectorstore(instance): }, } + @with_tracer_wrapper def task_wrapper(tracer: Tracer, to_wrap, wrapped, instance, args, kwargs): """Instruments and calls every function defined in TO_WRAP.""" @@ -93,40 +94,46 @@ def task_wrapper(tracer: Tracer, to_wrap, wrapped, instance, args, kwargs): with tracer.start_as_current_span(name) as span: if "output_processor" in to_wrap: - process_span(to_wrap["output_processor"],span,instance,args) + process_span(to_wrap["output_processor"], span, instance, args) pre_task_processing(to_wrap, instance, args, span) return_value = wrapped(*args, **kwargs) post_task_processing(to_wrap, span, return_value) return return_value -def process_span(output_processor,span,instance,args): - # Check if the output_processor is a valid JSON (in Python, that means it's a dictionary) - if isinstance(output_processor, dict) and len(output_processor)>0: - if 'type' in output_processor: - span.set_attribute("span.type", output_processor['type']) - else: - logger.warning("type of span not found or incorrect written in entity json") - count=0 - if 'attributes' in output_processor: - count = len(output_processor["attributes"]) - span.set_attribute("entity.count", count) - span_index = 1 - for processors in output_processor["attributes"]: - for processor in processors: - if 'attribute' in processor and 'accessor' in processor: - attribute_name = f"entity.{span_index}.{processor['attribute']}" - result = eval(processor['accessor'])(instance, args) - span.set_attribute(attribute_name, result) - else: - logger.warning("attribute or accessor not found or incorrect written in entity json") - span_index += 1 - else: - logger.warning("attributes not found or incorrect written in entity json") - span.set_attribute("span.count", count) +def process_span(output_processor, span, instance, args): + # Check if the output_processor is a valid JSON (in Python, that means it's a dictionary) + if isinstance(output_processor, dict) and len(output_processor) > 0: + if 'type' in output_processor: + span.set_attribute("span.type", output_processor['type']) + else: + logger.warning("type of span not found or incorrect written in entity json") + count = 0 + if 'attributes' in output_processor: + count = len(output_processor["attributes"]) + span.set_attribute("entity.count", count) + span_index = 1 + for processors in output_processor["attributes"]: + for processor in processors: + if 'attribute' in processor and 'accessor' in processor: + attribute_name = f"entity.{span_index}.{processor['attribute']}" + try: + result = eval(processor['accessor'])(instance, args) + if result and isinstance(result, str): + span.set_attribute(attribute_name, result) + except Exception as e: + pass + + else: + logger.warning("attribute or accessor not found or incorrect written in entity json") + span_index += 1 else: - logger.warning("empty or entities json is not in correct format") + logger.warning("attributes not found or incorrect written in entity json") + span.set_attribute("span.count", count) + + else: + logger.warning("empty or entities json is not in correct format") def post_task_processing(to_wrap, span, return_value): @@ -138,6 +145,7 @@ def post_task_processing(to_wrap, span, return_value): except: logger.exception("exception in post_task_processing") + def pre_task_processing(to_wrap, instance, args, span): try: if is_root_span(span): @@ -150,7 +158,7 @@ def pre_task_processing(to_wrap, instance, args, span): update_span_with_context_input(to_wrap=to_wrap, wrapped_args=args, span=span) except: logger.exception("exception in pre_task_processing") - + @with_tracer_wrapper async def atask_wrapper(tracer, to_wrap, wrapped, instance, args, kwargs): @@ -167,6 +175,8 @@ async def atask_wrapper(tracer, to_wrap, wrapped, instance, args, kwargs): else: name = f"langchain.task.{instance.__class__.__name__}" with tracer.start_as_current_span(name) as span: + if "output_processor" in to_wrap: + process_span(to_wrap["output_processor"], span, instance, args) pre_task_processing(to_wrap, instance, args, span) return_value = await wrapped(*args, **kwargs) post_task_processing(to_wrap, span, return_value) @@ -190,11 +200,18 @@ async def allm_wrapper(tracer, to_wrap, wrapped, instance, args, kwargs): else: name = f"langchain.task.{instance.__class__.__name__}" with tracer.start_as_current_span(name) as span: - update_llm_endpoint(curr_span=span, instance=instance) + if 'haystack.components.retrievers' in to_wrap['package'] and 'haystack.retriever' in span.name: + input_arg_text = get_attribute(DATA_INPUT_KEY) + span.add_event(DATA_INPUT_KEY, {QUERY: input_arg_text}) + provider_name = set_provider_name(instance) + instance_args = {"provider_name": provider_name} + if 'output_processor' in to_wrap: + process_span(to_wrap['output_processor'], span, instance, instance_args) return_value = await wrapped(*args, **kwargs) - - update_span_from_llm_response(response = return_value, span = span, instance=instance) + if 'haystack.components.retrievers' in to_wrap['package'] and 'haystack.retriever' in span.name: + update_span_with_context_output(to_wrap=to_wrap, return_value=return_value, span=span) + update_span_from_llm_response(response=return_value, span=span, instance=instance) return return_value @@ -227,7 +244,7 @@ def llm_wrapper(tracer: Tracer, to_wrap, wrapped, instance, args, kwargs): return_value = wrapped(*args, **kwargs) if 'haystack.components.retrievers' in to_wrap['package'] and 'haystack.retriever' in span.name: update_span_with_context_output(to_wrap=to_wrap, return_value=return_value, span=span) - update_span_from_llm_response(response = return_value, span = span,instance=instance) + update_span_from_llm_response(response=return_value, span=span, instance=instance) return return_value @@ -261,6 +278,7 @@ def update_llm_endpoint(curr_span: Span, instance): inference_endpoint=inference_ep ) + def set_provider_name(instance): provider_url = "" @@ -293,7 +311,8 @@ def get_input_from_args(chain_args): return chain_args[0] return "" -def update_span_from_llm_response(response, span: Span ,instance): + +def update_span_from_llm_response(response, span: Span, instance): # extract token uasge from langchain openai if (response is not None and hasattr(response, "response_metadata")): response_metadata = response.response_metadata @@ -326,6 +345,7 @@ def update_span_from_llm_response(response, span: Span ,instance): except AttributeError: token_usage = None + def update_workflow_type(to_wrap, span: Span): package_name = to_wrap.get('package') @@ -346,6 +366,7 @@ def update_span_with_context_input(to_wrap, wrapped_args, span: Span): if input_arg_text: span.add_event(DATA_INPUT_KEY, {QUERY: input_arg_text}) + def update_span_with_context_output(to_wrap, return_value, span: Span): package_name: str = to_wrap.get('package') output_arg_text = "" @@ -362,6 +383,7 @@ def update_span_with_context_output(to_wrap, return_value, span: Span): if output_arg_text: span.add_event(DATA_OUTPUT_KEY, {RESPONSE: output_arg_text}) + def update_span_with_prompt_input(to_wrap, wrapped_args, span: Span): input_arg_text = wrapped_args[0] @@ -370,6 +392,7 @@ def update_span_with_prompt_input(to_wrap, wrapped_args, span: Span): else: span.add_event(PROMPT_INPUT_KEY, {QUERY: input_arg_text}) + def update_span_with_prompt_output(to_wrap, wrapped_args, span: Span): package_name: str = to_wrap.get('package') if isinstance(wrapped_args, str): @@ -377,5 +400,4 @@ def update_span_with_prompt_output(to_wrap, wrapped_args, span: Span): if isinstance(wrapped_args, dict): span.add_event(PROMPT_OUTPUT_KEY, wrapped_args) if "llama_index.core.base.base_query_engine" in package_name: - span.add_event(PROMPT_OUTPUT_KEY, {RESPONSE:wrapped_args.response}) - \ No newline at end of file + span.add_event(PROMPT_OUTPUT_KEY, {RESPONSE: wrapped_args.response}) diff --git a/tests/entities.json b/tests/entities.json new file mode 100644 index 0000000..5fc8ef4 --- /dev/null +++ b/tests/entities.json @@ -0,0 +1,17 @@ +{ + "type": "retrieval", + "attributes": [ + [ + { + "_comment": "vector store name and type", + "attribute": "name", + "accessor": "lambda instance,args: type(instance.vectorstore).__name__" + }, + { + "attribute": "type", + "accessor": "lambda instance,args: 'vectorstore.'+type(instance.vectorstore).__name__" + } + ] + + ] +} \ No newline at end of file diff --git a/tests/langchain_async_processor_test.py b/tests/langchain_async_processor_test.py new file mode 100644 index 0000000..14c909f --- /dev/null +++ b/tests/langchain_async_processor_test.py @@ -0,0 +1,167 @@ +from unittest import IsolatedAsyncioTestCase +import unittest + +import json +import logging +import os +import time + +import unittest +from unittest.mock import ANY, MagicMock, patch +from unittest import IsolatedAsyncioTestCase + +import requests +from dummy_class import DummyClass +from embeddings_wrapper import HuggingFaceEmbeddings +from http_span_exporter import HttpSpanExporter +from langchain.prompts import PromptTemplate +from langchain.schema import StrOutputParser +from langchain_community.vectorstores import faiss +from langchain_core.messages.ai import AIMessage +from langchain_core.runnables import RunnablePassthrough +from monocle_apptrace.instrumentor import ( + MonocleInstrumentor, + set_context_properties, + setup_monocle_telemetry, +) +from monocle_apptrace.wrap_common import ( + SESSION_PROPERTIES_KEY, + PROMPT_INPUT_KEY, + PROMPT_OUTPUT_KEY, + QUERY, + RESPONSE, + update_span_from_llm_response, +) +from monocle_apptrace.wrapper import WrapperMethod +from opentelemetry import trace +from opentelemetry.sdk.resources import SERVICE_NAME, Resource +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace.export import BatchSpanProcessor, ConsoleSpanExporter +from monocle_apptrace.wrap_common import allm_wrapper +from fake_list_llm import FakeListLLM + +logger = logging.getLogger() +logger.setLevel(logging.DEBUG) +fileHandler = logging.FileHandler('traces.txt', 'w') +formatter = logging.Formatter('%(asctime)s [%(levelname)s] %(name)s: %(message)s') +fileHandler.setFormatter(formatter) +logger.addHandler(fileHandler) +events = [] + + +class Test(IsolatedAsyncioTestCase): + prompt = PromptTemplate.from_template( + """ + [INST] You are an assistant for question-answering tasks. Use the following pieces of retrieved context + to answer the question. If you don't know the answer, just say that you don't know. Use three sentences + maximum and keep the answer concise. [/INST] + [INST] Question: {question} + Context: {context} + Answer: [/INST] + """ + ) + ragText = """A latte is a coffee drink that consists of espresso, milk, and foam.\ + It is served in a large cup or tall glass and has more milk compared to other espresso-based drinks.\ + Latte art can be created on the surface of the drink using the milk.""" + + def __format_docs(self, docs): + return "\n\n ".join(doc.page_content for doc in docs) + + def __createChain(self): + + resource = Resource(attributes={ + SERVICE_NAME: "coffee_rag_fake" + }) + traceProvider = TracerProvider(resource=resource) + exporter = ConsoleSpanExporter() + monocleProcessor = BatchSpanProcessor(exporter) + + traceProvider.add_span_processor(monocleProcessor) + trace.set_tracer_provider(traceProvider) + self.instrumentor = MonocleInstrumentor() + self.instrumentor.instrument() + self.processor = monocleProcessor + responses = [self.ragText] + llm = FakeListLLM(responses=responses) + llm.api_base = "https://example.com/" + embeddings = HuggingFaceEmbeddings(model_id="multi-qa-mpnet-base-dot-v1") + my_path = os.path.abspath(os.path.dirname(__file__)) + model_path = os.path.join(my_path, "./vector_data/coffee_embeddings") + vectorstore = faiss.FAISS.load_local(model_path, embeddings, allow_dangerous_deserialization=True) + + retriever = vectorstore.as_retriever() + + rag_chain = ( + {"context": retriever | self.__format_docs, "question": RunnablePassthrough()} + | self.prompt + | llm + | StrOutputParser() + ) + return rag_chain + + def setUp(self): + events.append("setUp") + + async def asyncSetUp(self): + os.environ["HTTP_API_KEY"] = "key1" + os.environ["HTTP_INGESTION_ENDPOINT"] = "https://localhost:3000/api/v1/traces" + + @patch.object(requests.Session, 'post') + async def test_response(self, mock_post): + app_name = "test" + wrap_method = MagicMock(return_value=3) + setup_monocle_telemetry( + workflow_name=app_name, + span_processors=[ + BatchSpanProcessor(HttpSpanExporter("https://localhost:3000/api/v1/traces")) + ], + wrapper_methods=[ + WrapperMethod( + package="langchain.chat_models.base", + object_name="BaseChatModel", + method="ainvoke", + wrapper=allm_wrapper + ) + + ]) + try: + context_key = "context_key_1" + context_value = "context_value_1" + set_context_properties({context_key: context_value}) + + self.chain = self.__createChain() + mock_post.return_value.status_code = 201 + mock_post.return_value.json.return_value = 'mock response' + + query = "what is latte" + response = await self.chain.ainvoke(query, config={}) + assert response == self.ragText + time.sleep(5) + mock_post.assert_called_with( + url='https://localhost:3000/api/v1/traces', + data=ANY, + timeout=ANY + ) + + '''mock_post.call_args gives the parameters used to make post call. + This can be used to do more asserts''' + dataBodyStr = mock_post.call_args.kwargs['data'] + dataJson = json.loads(dataBodyStr) # more asserts can be added on individual fields + + llm_span = [x for x in dataJson["batch"] if "FakeListLLM" in x["name"]][0] + + assert llm_span["attributes"]["span.type"] == "inference" + assert llm_span["attributes"]["entity.1.provider_name"] == "example.com" + assert llm_span["attributes"]["entity.1.type"] == "inference.azure_oai" + assert llm_span["attributes"]["entity.1.inference_endpoint"] == "https://example.com/" + + finally: + try: + if (self.instrumentor is not None): + self.instrumentor.uninstrument() + except Exception as e: + print("Uninstrument failed:", e) + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/tests/langchain_custom_output_processor_test.py b/tests/langchain_custom_output_processor_test.py new file mode 100644 index 0000000..7f980d9 --- /dev/null +++ b/tests/langchain_custom_output_processor_test.py @@ -0,0 +1,190 @@ + +import json +import logging +import os +import time + +import unittest +from unittest.mock import ANY, MagicMock, patch + +import pytest +import requests +from dummy_class import DummyClass +from embeddings_wrapper import HuggingFaceEmbeddings +from http_span_exporter import HttpSpanExporter +from langchain.prompts import PromptTemplate +from langchain.schema import StrOutputParser +from langchain_community.vectorstores import faiss +from langchain_core.messages.ai import AIMessage +from langchain_core.runnables import RunnablePassthrough +from monocle_apptrace.constants import ( + AZURE_APP_SERVICE_ENV_NAME, + AZURE_APP_SERVICE_NAME, + AZURE_FUNCTION_NAME, + AZURE_FUNCTION_WORKER_ENV_NAME, + AZURE_ML_ENDPOINT_ENV_NAME, + AZURE_ML_SERVICE_NAME, + AWS_LAMBDA_ENV_NAME, + AWS_LAMBDA_SERVICE_NAME +) +from monocle_apptrace.instrumentor import ( + MonocleInstrumentor, + set_context_properties, + setup_monocle_telemetry, +) +from monocle_apptrace.wrap_common import ( + SESSION_PROPERTIES_KEY, + INFRA_SERVICE_KEY, + PROMPT_INPUT_KEY, + PROMPT_OUTPUT_KEY, + QUERY, + RESPONSE, + update_span_from_llm_response, +) +from monocle_apptrace.wrapper import WrapperMethod +from opentelemetry import trace +from opentelemetry.sdk.resources import SERVICE_NAME, Resource +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace.export import BatchSpanProcessor, ConsoleSpanExporter + +from fake_list_llm import FakeListLLM +from parameterized import parameterized + +from src.monocle_apptrace.wrap_common import task_wrapper + +logger = logging.getLogger() +logger.setLevel(logging.DEBUG) +fileHandler = logging.FileHandler('traces.txt' ,'w') +formatter = logging.Formatter('%(asctime)s [%(levelname)s] %(name)s: %(message)s') +fileHandler.setFormatter(formatter) +logger.addHandler(fileHandler) + +class TestHandler(unittest.TestCase): + + prompt = PromptTemplate.from_template( + """ + [INST] You are an assistant for question-answering tasks. Use the following pieces of retrieved context + to answer the question. If you don't know the answer, just say that you don't know. Use three sentences + maximum and keep the answer concise. [/INST] + [INST] Question: {question} + Context: {context} + Answer: [/INST] + """ + ) + ragText = """A latte is a coffee drink that consists of espresso, milk, and foam.\ + It is served in a large cup or tall glass and has more milk compared to other espresso-based drinks.\ + Latte art can be created on the surface of the drink using the milk.""" + + def __format_docs(self, docs): + return "\n\n ".join(doc.page_content for doc in docs) + + def __createChain(self): + + resource = Resource(attributes={ + SERVICE_NAME: "coffee_rag_fake" + }) + traceProvider = TracerProvider(resource=resource) + exporter = ConsoleSpanExporter() + monocleProcessor = BatchSpanProcessor(exporter) + + traceProvider.add_span_processor(monocleProcessor) + trace.set_tracer_provider(traceProvider) + self.instrumentor = MonocleInstrumentor() + self.instrumentor.instrument() + self.processor = monocleProcessor + responses =[self.ragText] + llm = FakeListLLM(responses=responses) + llm.api_base = "https://example.com/" + embeddings = HuggingFaceEmbeddings(model_id = "multi-qa-mpnet-base-dot-v1") + my_path = os.path.abspath(os.path.dirname(__file__)) + model_path = os.path.join(my_path, "./vector_data/coffee_embeddings") + vectorstore = faiss.FAISS.load_local(model_path, embeddings, allow_dangerous_deserialization = True) + + retriever = vectorstore.as_retriever() + + rag_chain = ( + {"context": retriever| self.__format_docs, "question": RunnablePassthrough()} + | self.prompt + | llm + | StrOutputParser() + ) + return rag_chain + + def setUp(self): + os.environ["HTTP_API_KEY"] = "key1" + os.environ["HTTP_INGESTION_ENDPOINT"] = "https://localhost:3000/api/v1/traces" + + + def tearDown(self) -> None: + return super().tearDown() + + @parameterized.expand([ + ("1", AZURE_ML_ENDPOINT_ENV_NAME, AZURE_ML_SERVICE_NAME), + ("2", AZURE_FUNCTION_WORKER_ENV_NAME, AZURE_FUNCTION_NAME), + ("3", AZURE_APP_SERVICE_ENV_NAME, AZURE_APP_SERVICE_NAME), + ("4", AWS_LAMBDA_ENV_NAME, AWS_LAMBDA_SERVICE_NAME), + ]) + @patch.object(requests.Session, 'post') + def test_llm_chain(self, test_name, test_input_infra, test_output_infra, mock_post): + app_name = "test" + wrap_method = MagicMock(return_value=3) + setup_monocle_telemetry( + workflow_name=app_name, + span_processors=[ + BatchSpanProcessor(HttpSpanExporter("https://localhost:3000/api/v1/traces")) + ], + wrapper_methods=[ + WrapperMethod( + package="langchain_core.retrievers", + object_name="BaseRetriever", + method="invoke", + wrapper=task_wrapper, + output_processor=["entities.json"] + ), + + ]) + try: + + os.environ[test_input_infra] = "1" + context_key = "context_key_1" + context_value = "context_value_1" + set_context_properties({context_key: context_value}) + + self.chain = self.__createChain() + mock_post.return_value.status_code = 201 + mock_post.return_value.json.return_value = 'mock response' + + query = "what is latte" + response = self.chain.invoke(query, config={}) + assert response == self.ragText + time.sleep(5) + mock_post.assert_called_with( + url = 'https://localhost:3000/api/v1/traces', + data=ANY, + timeout=ANY + ) + + '''mock_post.call_args gives the parameters used to make post call. + This can be used to do more asserts''' + dataBodyStr = mock_post.call_args.kwargs['data'] + dataJson = json.loads(dataBodyStr) # more asserts can be added on individual fields + + llm_vector_store_retriever_span = [x for x in dataJson["batch"] if 'langchain.task.VectorStoreRetriever' in x["name"]][0] + + assert llm_vector_store_retriever_span["attributes"]["span.type"] == "retrieval" + assert llm_vector_store_retriever_span["attributes"]["entity.1.name"] == "FAISS" + assert llm_vector_store_retriever_span["attributes"]["entity.1.type"] == "vectorstore.FAISS" + + finally: + os.environ.pop(test_input_infra) + try: + if(self.instrumentor is not None): + self.instrumentor.uninstrument() + except Exception as e: + print("Uninstrument failed:", e) + + +if __name__ == '__main__': + unittest.main() + +