Skip to content

Commit

Permalink
Adding assert statements in unit tests for vector store retrievers
Browse files Browse the repository at this point in the history
Signed-off-by: hansrajr <[email protected]>
  • Loading branch information
Hansrajr committed Sep 18, 2024
1 parent 8737199 commit 5bdfb65
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 10 deletions.
20 changes: 15 additions & 5 deletions tests/haystack_unitest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@


import json
import logging
import os
Expand All @@ -17,6 +16,8 @@
from monocle_apptrace.instrumentor import setup_monocle_telemetry
from monocle_apptrace.wrapper import WrapperMethod
from opentelemetry.sdk.trace.export import BatchSpanProcessor, ConsoleSpanExporter
from haystack.components.retrievers import InMemoryBM25Retriever
from haystack.document_stores.in_memory import InMemoryDocumentStore

logger = logging.getLogger(__name__)

Expand All @@ -36,10 +37,14 @@ def test_haystack(self, mock_post):
)
prompt_builder = DynamicChatPromptBuilder()
llm = OpenAIChatGenerator(api_key=Secret.from_token(api_key), model="gpt-4")
document_store = InMemoryDocumentStore()
retriever = InMemoryBM25Retriever(document_store=document_store)

pipe = Pipeline()
pipe.add_component("retriever", retriever)
pipe.add_component("prompt_builder", prompt_builder)
pipe.add_component("llm", llm)
pipe.connect("retriever", "prompt_builder.template_variables")
pipe.connect("prompt_builder.prompt", "llm.messages")
query = "OpenTelemetry"
messages = [ChatMessage.from_user("Tell me a joke about {{query}}")]
Expand All @@ -62,9 +67,9 @@ def test_haystack(self, mock_post):
This can be used to do more asserts'''
dataBodyStr = mock_post.call_args.kwargs['data']
logger.debug(dataBodyStr)
dataJson = json.loads(dataBodyStr) # more asserts can be added on individual fields
dataJson = json.loads(dataBodyStr) # more asserts can be added on individual fields

root_attributes = [x for x in dataJson["batch"] if x["parent_id"] == "None"][0]["attributes"]
root_attributes = [x for x in dataJson["batch"] if x["parent_id"] == "None"][0]["attributes"]
# assert root_attributes["workflow_input"] == query
# assert root_attributes["workflow_output"] == llm.dummy_response

Expand All @@ -78,6 +83,7 @@ def test_haystack(self, mock_post):

type_found = False
model_name_found = False
provider_found = False
assert root_attributes["workflow_input"] == query
assert root_attributes["workflow_output"] == TestHandler.ragText

Expand All @@ -88,9 +94,13 @@ def test_haystack(self, mock_post):
if span["name"] == "haystack.openai" and "model_name" in span["attributes"]:
assert span["attributes"]["model_name"] == "gpt-4"
model_name_found = True
if span["name"] == "haystack.retriever" and "type" in span["attributes"]:
assert span["attributes"]["provider_name"] == "InMemoryDocumentStore"
provider_found = True

assert type_found == True
assert model_name_found == True
assert type_found
assert model_name_found
assert provider_found

if __name__ == '__main__':
unittest.main()
9 changes: 6 additions & 3 deletions tests/langchain_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,14 +151,17 @@ def test_llm_chain(self, test_name, test_input_infra, test_output_infra, mock_po
This can be used to do more asserts'''
dataBodyStr = mock_post.call_args.kwargs['data']
dataJson = json.loads(dataBodyStr) # more asserts can be added on individual fields
# assert len(dataJson['batch']) == 7
# assert len(dataJson['batch']) == 7

root_span = [x for x in dataJson["batch"] if x["parent_id"] == "None"][0]
llm_span = [x for x in dataJson["batch"] if "FakeListLLM" in x["name"]][0]
root_span = [x for x in dataJson["batch"] if x["parent_id"] == "None"][0]
llm_span = [x for x in dataJson["batch"] if "FakeListLLM" in x["name"]][0]
llm_vector_store_retriever_span = [x for x in dataJson["batch"] if 'langchain.task.VectorStoreRetriever' in x["name"]][0]
root_span_attributes = root_span["attributes"]
root_span_events = root_span["events"]

assert llm_span["attributes"]["provider_name"] == "example.com"
assert llm_vector_store_retriever_span["attributes"]["embedding_model"] == "FAISS"
assert llm_vector_store_retriever_span["attributes"]["provider_name"] == "HuggingFaceEmbeddings"

def get_event_attributes(events, key):
return [event['attributes'] for event in events if event['name'] == key][0]
Expand Down
11 changes: 9 additions & 2 deletions tests/llama_index_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,14 +121,18 @@ def get_event_attributes(events, key):

span_names: List[str] = [span["name"] for span in dataJson['batch']]
llm_span = [x for x in dataJson["batch"] if "llamaindex.OurLLM" in x["name"]][0]
llm_retriever_span = [x for x in dataJson["batch"] if "llamaindex.query" in x["name"]][0]
for name in ["llamaindex.retrieve", "llamaindex.query", "llamaindex.OurLLM"]:
assert name in span_names
assert llm_span["attributes"]["completion_tokens"] == 1
assert llm_span["attributes"]["prompt_tokens"] == 2
assert llm_span["attributes"]["total_tokens"] == 3

assert llm_retriever_span["attributes"]["embedding_model"] == "BAAI/bge-small-en-v1.5"
assert llm_retriever_span["attributes"]["provider_name"] == "SimpleVectorStore"

type_found = False
model_name_found = False
vectorstore_provider = False

for span in dataJson["batch"]:
if span["name"] == "llamaindex.query" and "workflow_type" in span["attributes"]:
Expand All @@ -137,9 +141,12 @@ def get_event_attributes(events, key):
if span["name"] == "llamaindex.OurLLM" and "model_name" in span["attributes"]:
assert span["attributes"]["model_name"] == "custom"
model_name_found = True

if span["name"] == "llamaindex.query" and "type" in span["attributes"]:
assert span["attributes"]["provider_name"] == "SimpleVectorStore"
vectorstore_provider = True
assert type_found
assert model_name_found
assert vectorstore_provider



Expand Down

0 comments on commit 5bdfb65

Please sign in to comment.