Db Action embedding search fix (#1835)
* Db Action embedding search fix

* Fixed test cases and set the training batch size to 50

* Doubled the limit of individual model queries and updated the corresponding test cases
himanshugt16 authored Feb 27, 2025
1 parent f95252a commit 3050795
Showing 5 changed files with 177 additions and 45 deletions.
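
In short: instead of sending a single embedding as the `query`, the DB action now issues Qdrant's hybrid Query API request, prefetching candidates from each embedding model (dense, rerank, sparse) with an oversampled per-model limit and fusing the lists with reciprocal rank fusion (RRF). Below is a minimal sketch of the request body the diffs converge on, assuming a dict of per-model embeddings; the helper name is illustrative and not part of the kairon codebase:

    from typing import Any, Dict

    def build_hybrid_query(embeddings: Dict[str, Any], limit: int = 10) -> Dict[str, Any]:
        # Hypothetical helper: oversample each model's prefetch (limit * 2) so
        # RRF fuses from a wider candidate pool before trimming to `limit`.
        prefetch_limit = limit * 2
        return {
            "prefetch": [
                {"query": embeddings.get("dense", []), "using": "dense", "limit": prefetch_limit},
                {"query": embeddings.get("rerank", []), "using": "rerank", "limit": prefetch_limit},
                {"query": embeddings.get("sparse", {}), "using": "sparse", "limit": prefetch_limit},
            ],
            "query": {"fusion": "rrf"},
            "with_payload": True,
            "limit": limit,
        }

Note that the DB-action path in qdrant.py hardcodes the prefetch limit to 20, which matches limit * 2 for its default search limit of 10.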
8 changes: 4 additions & 4 deletions kairon/shared/llm/processor.py
@@ -61,7 +61,7 @@ async def train(self, user, *args, **kwargs) -> Dict:
         await self.__delete_collections()
         count = 0
         processor = CognitionDataProcessor()
-        batch_size = 100
+        batch_size = 50
 
         collections_data = CognitionData.objects(bot=self.bot)
         collection_groups = {}
@@ -320,17 +320,17 @@ async def __collection_hybrid_query__(self, collection_name: Text, embeddings: D
                 {
                     "query": embeddings.get("dense", []),
                     "using": "dense",
-                    "limit": limit
+                    "limit": limit * 2
                 },
                 {
                     "query": embeddings.get("rerank", []),
                     "using": "rerank",
-                    "limit": limit
+                    "limit": limit * 2
                 },
                 {
                     "query": embeddings.get("sparse", {}),
                     "using": "sparse",
-                    "limit": limit
+                    "limit": limit * 2
                 }
             ],
             "query": {"fusion": "rrf"},
19 changes: 18 additions & 1 deletion kairon/shared/vector_embeddings/db/qdrant.py
@@ -37,7 +37,24 @@ async def perform_operation(self, data: Dict, user: str, **kwargs):
         user_msg = data.get(DbActionOperationType.embedding_search)
         if user_msg and isinstance(user_msg, str):
             vector = await self.__get_embedding(user_msg, user, **kwargs)
-            request['query'] = vector
+            request['prefetch'] = [
+                {
+                    "query": vector.get("dense", []),
+                    "using": "dense",
+                    "limit": 20
+                },
+                {
+                    "query": vector.get("rerank", []),
+                    "using": "rerank",
+                    "limit": 20
+                },
+                {
+                    "query": vector.get("sparse", {}),
+                    "using": "sparse",
+                    "limit": 20
+                }
+            ]
+            request.update({"query": {"fusion": "rrf"}})
 
         if DbActionOperationType.payload_search in data:
             payload = data.get(DbActionOperationType.payload_search)
116 changes: 106 additions & 10 deletions tests/integration_test/action_service_test.py
@@ -4137,8 +4137,32 @@ def test_vectordb_action_execution_embedding_search_from_value(mock_get_embeddin
         url=http_url,
         body=resp_msg,
         status=200,
-        match=[responses.matchers.json_params_matcher({'query': embeddings,
-                                                       'with_payload': True, 'limit': 10})],
+        match=[
+            responses.matchers.json_params_matcher(
+                {
+                    'prefetch': [
+                        {
+                            "query": embeddings.get("dense", []),
+                            "using": "dense",
+                            "limit": 20
+                        },
+                        {
+                            "query": embeddings.get("rerank", []),
+                            "using": "rerank",
+                            "limit": 20
+                        },
+                        {
+                            "query": embeddings.get("sparse", {}),
+                            "using": "sparse",
+                            "limit": 20
+                        }
+                    ],
+                    'query': {"fusion": "rrf"},
+                    'with_payload': True,
+                    'limit': 10
+                }
+            )
+        ],
     )
 
     request_object = {
@@ -4397,8 +4421,32 @@ def test_vectordb_action_execution_embedding_search_from_slot(mock_get_embedding
         url=http_url,
         body=resp_msg,
         status=200,
-        match=[responses.matchers.json_params_matcher({'query': embeddings,
-                                                       'with_payload': True, 'limit': 10})],
+        match=[
+            responses.matchers.json_params_matcher(
+                {
+                    'prefetch': [
+                        {
+                            "query": embeddings.get("dense", []),
+                            "using": "dense",
+                            "limit": 20
+                        },
+                        {
+                            "query": embeddings.get("rerank", []),
+                            "using": "rerank",
+                            "limit": 20
+                        },
+                        {
+                            "query": embeddings.get("sparse", {}),
+                            "using": "sparse",
+                            "limit": 20
+                        }
+                    ],
+                    'query': {"fusion": "rrf"},
+                    'with_payload': True,
+                    'limit': 10
+                }
+            )
+        ],
     )
 
     request_object = {
@@ -4509,8 +4557,32 @@ def test_vectordb_action_execution_embedding_search_no_response_dispatch(mock_ge
         url=http_url,
         body=resp_msg,
         status=200,
-        match=[responses.matchers.json_params_matcher({'query': embeddings,
-                                                       'with_payload': True, 'limit': 10})],
+        match=[
+            responses.matchers.json_params_matcher(
+                {
+                    'prefetch': [
+                        {
+                            "query": embeddings.get("dense", []),
+                            "using": "dense",
+                            "limit": 20
+                        },
+                        {
+                            "query": embeddings.get("rerank", []),
+                            "using": "rerank",
+                            "limit": 20
+                        },
+                        {
+                            "query": embeddings.get("sparse", {}),
+                            "using": "sparse",
+                            "limit": 20
+                        }
+                    ],
+                    'query': {"fusion": "rrf"},
+                    'with_payload': True,
+                    'limit': 10
+                }
+            )
+        ],
     )
 
     request_object = {
@@ -14235,10 +14307,34 @@ def test_vectordb_action_execution_embedding_payload_search(mock_get_embedding):
         url=http_url,
         body=resp_msg,
         status=200,
-        match=[responses.matchers.json_params_matcher({'with_payload': True,
-                                                       'limit': 10,
-                                                       'query': embeddings,
-                                                       **payload}, strict_match=False)],
+        match=[
+            responses.matchers.json_params_matcher(
+                {
+                    'prefetch': [
+                        {
+                            "query": embeddings.get("dense", []),
+                            "using": "dense",
+                            "limit": 20
+                        },
+                        {
+                            "query": embeddings.get("rerank", []),
+                            "using": "rerank",
+                            "limit": 20
+                        },
+                        {
+                            "query": embeddings.get("sparse", {}),
+                            "using": "sparse",
+                            "limit": 20
+                        }
+                    ],
+                    'query': {"fusion": "rrf"},
+                    'with_payload': True,
+                    'limit': 10,
+                    **payload
+                },
+                strict_match=False
+            )
+        ],
     )
 
     request_object = {
54 changes: 27 additions & 27 deletions tests/unit_test/llm_test.py
@@ -769,17 +769,17 @@ async def test_gpt3_faq_embedding_predict(self, mock_get_embedding, aioresponses
                 {
                     "query": embeddings.get("dense", []),
                     "using": "dense",
-                    "limit": 10
+                    "limit": 20
                 },
                 {
                     "query": embeddings.get("rerank", []),
                     "using": "rerank",
-                    "limit": 10
+                    "limit": 20
                 },
                 {
                     "query": embeddings.get("sparse", {}),
                     "using": "sparse",
-                    "limit": 10
+                    "limit": 20
                 }
             ],
             "query": {"fusion": "rrf"},
@@ -881,17 +881,17 @@ async def test_gpt3_faq_embedding_predict_with_default_collection(self, mock_get
                 {
                     "query": embeddings.get("dense", []),
                     "using": "dense",
-                    "limit": 10
+                    "limit": 20
                 },
                 {
                     "query": embeddings.get("rerank", []),
                     "using": "rerank",
-                    "limit": 10
+                    "limit": 20
                 },
                 {
                     "query": embeddings.get("sparse", {}),
                     "using": "sparse",
-                    "limit": 10
+                    "limit": 20
                 }
             ],
             "query": {"fusion": "rrf"},
@@ -998,17 +998,17 @@ async def test_gpt3_faq_embedding_predict_with_values(self, mock_get_embedding,
                 {
                     "query": embeddings.get("dense", []),
                     "using": "dense",
-                    "limit": 10
+                    "limit": 20
                 },
                 {
                     "query": embeddings.get("rerank", []),
                     "using": "rerank",
-                    "limit": 10
+                    "limit": 20
                 },
                 {
                     "query": embeddings.get("sparse", {}),
                     "using": "sparse",
-                    "limit": 10
+                    "limit": 20
                 }
             ],
             "query": {"fusion": "rrf"},
@@ -1112,17 +1112,17 @@ async def test_gpt3_faq_embedding_predict_with_values_and_stream(self, mock_get_
                 {
                     "query": embeddings.get("dense", []),
                     "using": "dense",
-                    "limit": 10
+                    "limit": 20
                 },
                 {
                     "query": embeddings.get("rerank", []),
                     "using": "rerank",
-                    "limit": 10
+                    "limit": 20
                 },
                 {
                     "query": embeddings.get("sparse", {}),
                     "using": "sparse",
-                    "limit": 10
+                    "limit": 20
                 }
             ],
             "query": {"fusion": "rrf"},
@@ -1270,17 +1270,17 @@ async def test_gpt3_faq_embedding_predict_with_values_with_instructions(self,
                 {
                     "query": embeddings.get("dense", []),
                     "using": "dense",
-                    "limit": 10
+                    "limit": 20
                 },
                 {
                     "query": embeddings.get("rerank", []),
                     "using": "rerank",
-                    "limit": 10
+                    "limit": 20
                 },
                 {
                     "query": embeddings.get("sparse", {}),
                     "using": "sparse",
-                    "limit": 10
+                    "limit": 20
                 }
             ],
             "query": {"fusion": "rrf"},
@@ -1383,17 +1383,17 @@ async def test_gpt3_faq_embedding_predict_completion_connection_error(self, mock
                 {
                     "query": embeddings.get("dense", []),
                     "using": "dense",
-                    "limit": 10
+                    "limit": 20
                 },
                 {
                     "query": embeddings.get("rerank", []),
                     "using": "rerank",
-                    "limit": 10
+                    "limit": 20
                 },
                 {
                     "query": embeddings.get("sparse", {}),
                     "using": "sparse",
-                    "limit": 10
+                    "limit": 20
                 }
             ],
             "query": {"fusion": "rrf"},
@@ -1618,17 +1618,17 @@ async def test_gpt3_faq_embedding_predict_with_previous_bot_responses(self, mock
                 {
                     "query": embeddings.get("dense", []),
                     "using": "dense",
-                    "limit": 10
+                    "limit": 20
                 },
                 {
                     "query": embeddings.get("rerank", []),
                     "using": "rerank",
-                    "limit": 10
+                    "limit": 20
                 },
                 {
                     "query": embeddings.get("sparse", {}),
                     "using": "sparse",
-                    "limit": 10
+                    "limit": 20
                 }
             ],
             "query": {"fusion": "rrf"},
@@ -1752,17 +1752,17 @@ async def test_gpt3_faq_embedding_predict_with_query_prompt(self, mock_get_embed
                 {
                     "query": embeddings.get("dense", []),
                     "using": "dense",
-                    "limit": 10
+                    "limit": 20
                 },
                 {
                     "query": embeddings.get("rerank", []),
                     "using": "rerank",
-                    "limit": 10
+                    "limit": 20
                 },
                 {
                     "query": embeddings.get("sparse", {}),
                     "using": "sparse",
-                    "limit": 10
+                    "limit": 20
                 }
             ],
             "query": {"fusion": "rrf"},
@@ -2288,9 +2288,9 @@ async def test_collection_hybrid_query_success(self, mock_request):
             headers={},
             request_body={
                 "prefetch": [
-                    {"query": embeddings.get("dense", []), "using": "dense", "limit": limit},
-                    {"query": embeddings.get("rerank", []), "using": "rerank", "limit": limit},
-                    {"query": embeddings.get("sparse", {}), "using": "sparse", "limit": limit}
+                    {"query": embeddings.get("dense", []), "using": "dense", "limit": limit * 2},
+                    {"query": embeddings.get("rerank", []), "using": "rerank", "limit": limit * 2},
+                    {"query": embeddings.get("sparse", {}), "using": "sparse", "limit": limit * 2}
                 ],
                 "query": {"fusion": "rrf"},
                 "with_payload": True,