diff --git a/kairon/shared/llm/processor.py b/kairon/shared/llm/processor.py
index 07388c1a2..d14754f04 100644
--- a/kairon/shared/llm/processor.py
+++ b/kairon/shared/llm/processor.py
@@ -61,7 +61,7 @@ async def train(self, user, *args, **kwargs) -> Dict:
         await self.__delete_collections()
         count = 0
         processor = CognitionDataProcessor()
-        batch_size = 100
+        batch_size = 50

         collections_data = CognitionData.objects(bot=self.bot)
         collection_groups = {}
@@ -320,17 +320,17 @@ async def __collection_hybrid_query__(self, collection_name: Text, embeddings: D
                 {
                     "query": embeddings.get("dense", []),
                     "using": "dense",
-                    "limit": limit
+                    "limit": limit * 2
                 },
                 {
                     "query": embeddings.get("rerank", []),
                     "using": "rerank",
-                    "limit": limit
+                    "limit": limit * 2
                 },
                 {
                     "query": embeddings.get("sparse", {}),
                     "using": "sparse",
-                    "limit": limit
+                    "limit": limit * 2
                 }
             ],
             "query": {"fusion": "rrf"},
diff --git a/kairon/shared/vector_embeddings/db/qdrant.py b/kairon/shared/vector_embeddings/db/qdrant.py
index b4b92ece3..6f05f1032 100644
--- a/kairon/shared/vector_embeddings/db/qdrant.py
+++ b/kairon/shared/vector_embeddings/db/qdrant.py
@@ -37,7 +37,24 @@ async def perform_operation(self, data: Dict, user: str, **kwargs):
         user_msg = data.get(DbActionOperationType.embedding_search)
         if user_msg and isinstance(user_msg, str):
             vector = await self.__get_embedding(user_msg, user, **kwargs)
-            request['query'] = vector
+            request['prefetch'] = [
+                {
+                    "query": vector.get("dense", []),
+                    "using": "dense",
+                    "limit": 20
+                },
+                {
+                    "query": vector.get("rerank", []),
+                    "using": "rerank",
+                    "limit": 20
+                },
+                {
+                    "query": vector.get("sparse", {}),
+                    "using": "sparse",
+                    "limit": 20
+                }
+            ]
+            request.update({"query": {"fusion": "rrf"}})

         if DbActionOperationType.payload_search in data:
             payload = data.get(DbActionOperationType.payload_search)
diff --git a/tests/integration_test/action_service_test.py b/tests/integration_test/action_service_test.py
index a874df12a..e5d0b42e8 100644
--- a/tests/integration_test/action_service_test.py
+++ b/tests/integration_test/action_service_test.py
@@ -4137,8 +4137,32 @@ def test_vectordb_action_execution_embedding_search_from_value(mock_get_embeddin
         url=http_url,
         body=resp_msg,
         status=200,
-        match=[responses.matchers.json_params_matcher({'query': embeddings,
-                                                       'with_payload': True, 'limit': 10})],
+        match=[
+            responses.matchers.json_params_matcher(
+                {
+                    'prefetch': [
+                        {
+                            "query": embeddings.get("dense", []),
+                            "using": "dense",
+                            "limit": 20
+                        },
+                        {
+                            "query": embeddings.get("rerank", []),
+                            "using": "rerank",
+                            "limit": 20
+                        },
+                        {
+                            "query": embeddings.get("sparse", {}),
+                            "using": "sparse",
+                            "limit": 20
+                        }
+                    ],
+                    'query': {"fusion": "rrf"},
+                    'with_payload': True,
+                    'limit': 10
+                }
+            )
+        ],
     )

     request_object = {
@@ -4397,8 +4421,32 @@ def test_vectordb_action_execution_embedding_search_from_slot(mock_get_embedding
         url=http_url,
         body=resp_msg,
         status=200,
-        match=[responses.matchers.json_params_matcher({'query': embeddings,
-                                                       'with_payload': True, 'limit': 10})],
+        match=[
+            responses.matchers.json_params_matcher(
+                {
+                    'prefetch': [
+                        {
+                            "query": embeddings.get("dense", []),
+                            "using": "dense",
+                            "limit": 20
+                        },
+                        {
+                            "query": embeddings.get("rerank", []),
+                            "using": "rerank",
+                            "limit": 20
+                        },
+                        {
+                            "query": embeddings.get("sparse", {}),
+                            "using": "sparse",
+                            "limit": 20
+                        }
+                    ],
+                    'query': {"fusion": "rrf"},
+                    'with_payload': True,
+                    'limit': 10
+                }
+            )
+        ],
     )

     request_object = {
@@ -4509,8 +4557,32 @@ def test_vectordb_action_execution_embedding_search_no_response_dispatch(mock_ge
         url=http_url,
         body=resp_msg,
         status=200,
-        match=[responses.matchers.json_params_matcher({'query': embeddings,
-                                                       'with_payload': True, 'limit': 10})],
+        match=[
+            responses.matchers.json_params_matcher(
+                {
+                    'prefetch': [
+                        {
+                            "query": embeddings.get("dense", []),
+                            "using": "dense",
+                            "limit": 20
+                        },
+                        {
+                            "query": embeddings.get("rerank", []),
+                            "using": "rerank",
+                            "limit": 20
+                        },
+                        {
+                            "query": embeddings.get("sparse", {}),
+                            "using": "sparse",
+                            "limit": 20
+                        }
+                    ],
+                    'query': {"fusion": "rrf"},
+                    'with_payload': True,
+                    'limit': 10
+                }
+            )
+        ],
     )

     request_object = {
@@ -14235,10 +14307,34 @@ def test_vectordb_action_execution_embedding_payload_search(mock_get_embedding):
         url=http_url,
         body=resp_msg,
         status=200,
-        match=[responses.matchers.json_params_matcher({'with_payload': True,
-                                                       'limit': 10,
-                                                       'query': embeddings,
-                                                       **payload}, strict_match=False)],
+        match=[
+            responses.matchers.json_params_matcher(
+                {
+                    'prefetch': [
+                        {
+                            "query": embeddings.get("dense", []),
+                            "using": "dense",
+                            "limit": 20
+                        },
+                        {
+                            "query": embeddings.get("rerank", []),
+                            "using": "rerank",
+                            "limit": 20
+                        },
+                        {
+                            "query": embeddings.get("sparse", {}),
+                            "using": "sparse",
+                            "limit": 20
+                        }
+                    ],
+                    'query': {"fusion": "rrf"},
+                    'with_payload': True,
+                    'limit': 10,
+                    **payload
+                },
+                strict_match=False
+            )
+        ],
     )

     request_object = {
diff --git a/tests/unit_test/llm_test.py b/tests/unit_test/llm_test.py
index 93db21a10..288d94565 100644
--- a/tests/unit_test/llm_test.py
+++ b/tests/unit_test/llm_test.py
@@ -769,17 +769,17 @@ async def test_gpt3_faq_embedding_predict(self, mock_get_embedding, aioresponses
                 {
                     "query": embeddings.get("dense", []),
                     "using": "dense",
-                    "limit": 10
+                    "limit": 20
                 },
                 {
                     "query": embeddings.get("rerank", []),
                     "using": "rerank",
-                    "limit": 10
+                    "limit": 20
                 },
                 {
                     "query": embeddings.get("sparse", {}),
                     "using": "sparse",
-                    "limit": 10
+                    "limit": 20
                 }
             ],
             "query": {"fusion": "rrf"},
@@ -881,17 +881,17 @@ async def test_gpt3_faq_embedding_predict_with_default_collection(self, mock_get
                 {
                     "query": embeddings.get("dense", []),
                     "using": "dense",
-                    "limit": 10
+                    "limit": 20
                 },
                 {
                     "query": embeddings.get("rerank", []),
                     "using": "rerank",
-                    "limit": 10
+                    "limit": 20
                 },
                 {
                     "query": embeddings.get("sparse", {}),
                     "using": "sparse",
-                    "limit": 10
+                    "limit": 20
                 }
             ],
             "query": {"fusion": "rrf"},
@@ -998,17 +998,17 @@ async def test_gpt3_faq_embedding_predict_with_values(self, mock_get_embedding,
                 {
                     "query": embeddings.get("dense", []),
                     "using": "dense",
-                    "limit": 10
+                    "limit": 20
                 },
                 {
                     "query": embeddings.get("rerank", []),
                     "using": "rerank",
-                    "limit": 10
+                    "limit": 20
                 },
                 {
                     "query": embeddings.get("sparse", {}),
                     "using": "sparse",
-                    "limit": 10
+                    "limit": 20
                 }
             ],
             "query": {"fusion": "rrf"},
@@ -1112,17 +1112,17 @@ async def test_gpt3_faq_embedding_predict_with_values_and_stream(self, mock_get_
                 {
                     "query": embeddings.get("dense", []),
                     "using": "dense",
-                    "limit": 10
+                    "limit": 20
                 },
                 {
                     "query": embeddings.get("rerank", []),
                     "using": "rerank",
-                    "limit": 10
+                    "limit": 20
                 },
                 {
                     "query": embeddings.get("sparse", {}),
                     "using": "sparse",
-                    "limit": 10
+                    "limit": 20
                 }
             ],
             "query": {"fusion": "rrf"},
@@ -1270,17 +1270,17 @@ async def test_gpt3_faq_embedding_predict_with_values_with_instructions(self,
                 {
                     "query": embeddings.get("dense", []),
                     "using": "dense",
-                    "limit": 10
+                    "limit": 20
                 },
                 {
                     "query": embeddings.get("rerank", []),
                     "using": "rerank",
-                    "limit": 10
+                    "limit": 20
                 },
                 {
                     "query": embeddings.get("sparse", {}),
                     "using": "sparse",
-                    "limit": 10
+                    "limit": 20
                 }
             ],
             "query": {"fusion": "rrf"},
@@ -1383,17 +1383,17 @@ async def test_gpt3_faq_embedding_predict_completion_connection_error(self, mock
                 {
                     "query": embeddings.get("dense", []),
                     "using": "dense",
-                    "limit": 10
+                    "limit": 20
                 },
                 {
                     "query": embeddings.get("rerank", []),
                     "using": "rerank",
-                    "limit": 10
+                    "limit": 20
                 },
                 {
                     "query": embeddings.get("sparse", {}),
                     "using": "sparse",
-                    "limit": 10
+                    "limit": 20
                 }
             ],
             "query": {"fusion": "rrf"},
@@ -1618,17 +1618,17 @@ async def test_gpt3_faq_embedding_predict_with_previous_bot_responses(self, mock
                 {
                     "query": embeddings.get("dense", []),
                     "using": "dense",
-                    "limit": 10
+                    "limit": 20
                 },
                 {
                     "query": embeddings.get("rerank", []),
                     "using": "rerank",
-                    "limit": 10
+                    "limit": 20
                 },
                 {
                     "query": embeddings.get("sparse", {}),
                     "using": "sparse",
-                    "limit": 10
+                    "limit": 20
                 }
             ],
             "query": {"fusion": "rrf"},
@@ -1752,17 +1752,17 @@ async def test_gpt3_faq_embedding_predict_with_query_prompt(self, mock_get_embed
                 {
                     "query": embeddings.get("dense", []),
                     "using": "dense",
-                    "limit": 10
+                    "limit": 20
                 },
                 {
                     "query": embeddings.get("rerank", []),
                     "using": "rerank",
-                    "limit": 10
+                    "limit": 20
                 },
                 {
                     "query": embeddings.get("sparse", {}),
                     "using": "sparse",
-                    "limit": 10
+                    "limit": 20
                 }
             ],
             "query": {"fusion": "rrf"},
@@ -2288,9 +2288,9 @@ async def test_collection_hybrid_query_success(self, mock_request):
             headers={},
             request_body={
                 "prefetch": [
-                    {"query": embeddings.get("dense", []), "using": "dense", "limit": limit},
-                    {"query": embeddings.get("rerank", []), "using": "rerank", "limit": limit},
-                    {"query": embeddings.get("sparse", {}), "using": "sparse", "limit": limit}
+                    {"query": embeddings.get("dense", []), "using": "dense", "limit": limit * 2},
+                    {"query": embeddings.get("rerank", []), "using": "rerank", "limit": limit * 2},
+                    {"query": embeddings.get("sparse", {}), "using": "sparse", "limit": limit * 2}
                 ],
                 "query": {"fusion": "rrf"},
                 "with_payload": True,
diff --git a/tests/unit_test/vector_embeddings/qdrant_test.py b/tests/unit_test/vector_embeddings/qdrant_test.py
index ef5db079c..991291449 100644
--- a/tests/unit_test/vector_embeddings/qdrant_test.py
+++ b/tests/unit_test/vector_embeddings/qdrant_test.py
@@ -148,8 +148,27 @@ async def test_embedding_search_valid_request_body_payload(self, mock_http_reque
         mock_http_request.assert_called_once()
         called_args = mock_http_request.call_args
         called_payload = called_args.kwargs['request_body']
-        assert called_payload == {'query': embeddings,
-                                  'with_payload': True,
-                                  'limit': 10}
+        assert called_payload == {
+            'prefetch': [
+                {
+                    "query": embeddings.get("dense", []),
+                    "using": "dense",
+                    "limit": 20
+                },
+                {
+                    "query": embeddings.get("rerank", []),
+                    "using": "rerank",
+                    "limit": 20
+                },
+                {
+                    "query": embeddings.get("sparse", {}),
+                    "using": "sparse",
+                    "limit": 20
+                }
+            ],
+            'query': {"fusion": "rrf"},
+            'with_payload': True,
+            'limit': 10
+        }
         assert called_args.kwargs['http_url'] == 'http://localhost:6333/collections/5f50fd0a56v098ca10d75d2g/points/query'
         assert called_args.kwargs['request_method'] == 'POST'
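Illustrative sketch, not part of the patch: across both changed call sites, each of the dense, rerank, and sparse prefetch branches now over-fetches candidates (hardcoded 20 in qdrant.py, limit * 2 in processor.py) before reciprocal rank fusion (RRF) trims the fused list back to the caller's limit. The helper name build_hybrid_query and the PREFETCH_MULTIPLIER constant below are hypothetical, introduced only to show the shape of the Qdrant Query API body this diff constructs:

from typing import Dict


PREFETCH_MULTIPLIER = 2  # hypothetical constant; the patch uses limit * 2 or a fixed 20


def build_hybrid_query(embeddings: Dict, limit: int = 10) -> Dict:
    # Each prefetch branch fetches more candidates than the final limit so the
    # RRF stage has enough overlap between the three ranked lists to fuse well.
    prefetch_limit = limit * PREFETCH_MULTIPLIER
    return {
        "prefetch": [
            {"query": embeddings.get("dense", []), "using": "dense", "limit": prefetch_limit},
            {"query": embeddings.get("rerank", []), "using": "rerank", "limit": prefetch_limit},
            {"query": embeddings.get("sparse", {}), "using": "sparse", "limit": prefetch_limit},
        ],
        "query": {"fusion": "rrf"},  # fuse the three candidate lists by reciprocal rank
        "with_payload": True,
        "limit": limit,  # the fused output is still capped at the caller's limit
    }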