From 11cf22dc0418a24ed1a0e5c2ed019b59b387c493 Mon Sep 17 00:00:00 2001 From: lspataroG <167472995+lspataroG@users.noreply.github.com> Date: Thu, 12 Dec 2024 13:50:00 +0100 Subject: [PATCH] fixed BQ vector search batch_search (#629) --- .../bq_storage_vectorstores/bigquery.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/libs/community/langchain_google_community/bq_storage_vectorstores/bigquery.py b/libs/community/langchain_google_community/bq_storage_vectorstores/bigquery.py index b28a1da8..fa41473b 100644 --- a/libs/community/langchain_google_community/bq_storage_vectorstores/bigquery.py +++ b/libs/community/langchain_google_community/bq_storage_vectorstores/bigquery.py @@ -301,7 +301,7 @@ def _create_search_query( if table_to_query is not None: embeddings_query = f""" with embeddings as ( - SELECT {self.embedding_field}, ROW_NUMBER() OVER() as row_num + SELECT {self.embedding_field}, row_num from `{table_to_query}` )""" @@ -390,6 +390,7 @@ def _create_temp_bq_table( df = pd.DataFrame([]) df[self.embedding_field] = embeddings + df["row_num"] = list(range(len(df))) table_id = ( f"{self.project_id}." f"{self.temp_dataset_name}." @@ -397,7 +398,8 @@ def _create_temp_bq_table( ) schema = [ - bigquery.SchemaField(self.embedding_field, "FLOAT64", mode="REPEATED") + bigquery.SchemaField(self.embedding_field, "FLOAT64", mode="REPEATED"), + bigquery.SchemaField("row_num", "INT64"), ] table_ref = bigquery.Table(table_id, schema=schema) table = self._bq_client.create_table(table_ref) @@ -483,7 +485,7 @@ def batch_search( ) if queries is not None: - embeddings = self.embedding.embed_documents(queries) + embeddings = [self.embedding.embed_query(query) for query in queries] if embeddings is None: raise ValueError("Could not obtain embeddings - value is None.")