Skip to content

Commit

Permalink
address feedback
Browse files (browse the repository at this point in the history)
  • Loading branch information
anakin87 committed Jan 22, 2024
1 parent ae8c7ff commit ef442c2
Showing 1 changed file with 14 additions and 10 deletions.
Original file line number | Diff line number | Diff line change
Expand Up @@ -168,7 +168,7 @@ def delete_table(self):

delete_sql = SQL("DROP TABLE IF EXISTS {table_name}").format(table_name=Identifier(self.table_name))

self._execute_sql(delete_sql, error_msg="Could not delete table in PgvectorDocumentStore")
self._execute_sql(delete_sql, error_msg=f"Could not delete table {self.table_name} in PgvectorDocumentStore")

def _handle_hnsw(self):
"""
Expand All @@ -193,7 +193,8 @@ def _handle_hnsw(self):
if index_esists and not self.hnsw_recreate_index_if_exists:
logger.warning(
"HNSW index already exists and won't be recreated. "
"If you want to recreate it, set hnsw_recreate_index=True"
"If you want to recreate it, pass 'hnsw_recreate_index_if_exists=True' to the "
"Document Store constructor"
)
return

Expand All @@ -208,7 +209,7 @@ def _create_hnsw_index(self):
"""

pg_ops = SIMILARITY_FUNCTION_TO_POSTGRESQL_OPS[self.embedding_similarity_function]
effective_hnsw_index_creation_kwargs = {
actual_hnsw_index_creation_kwargs = {
key: value
for key, value in self.hnsw_index_creation_kwargs.items()
if key in HNSW_INDEX_CREATION_VALID_KWARGS
Expand All @@ -218,12 +219,12 @@ def _create_hnsw_index(self):
index_name=Identifier(HNSW_INDEX_NAME), table_name=Identifier(self.table_name), ops=SQL(pg_ops)
)

if effective_hnsw_index_creation_kwargs:
effective_hnsw_index_creation_kwargs_str = ", ".join(
f"{key} = {value}" for key, value in effective_hnsw_index_creation_kwargs.items()
if actual_hnsw_index_creation_kwargs:
actual_hnsw_index_creation_kwargs_str = ", ".join(
f"{key} = {value}" for key, value in actual_hnsw_index_creation_kwargs.items()
)
sql_add_creation_kwargs = SQL("WITH ({creation_kwargs_str})").format(
creation_kwargs_str=SQL(effective_hnsw_index_creation_kwargs_str)
creation_kwargs_str=SQL(actual_hnsw_index_creation_kwargs_str)
)
sql_create_index += sql_add_creation_kwargs

Expand Down Expand Up @@ -275,8 +276,8 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D

db_documents = self._from_haystack_to_pg_documents(documents)

columns_str = "(" + ", ".join(col for col, *_ in COLUMNS) + ")"
values_placeholder_str = "VALUES (" + ", ".join(f"%({col})s" for col, *_ in COLUMNS) + ")"
columns_str = "(" + ", ".join(col for col in COLUMNS) + ")"
values_placeholder_str = "VALUES (" + ", ".join(f"%({col})s" for col in COLUMNS) + ")"

insert_statement = SQL("INSERT INTO {table_name} " + columns_str + " " + values_placeholder_str).format(
table_name=Identifier(self.table_name)
Expand All @@ -285,7 +286,7 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D
if policy == DuplicatePolicy.OVERWRITE:
update_statement = SQL(
"ON CONFLICT (id) DO UPDATE SET "
+ ", ".join(f"{col} = EXCLUDED.{col}" for col, *_ in COLUMNS if col != "id")
+ ", ".join(f"{col} = EXCLUDED.{col}" for col in COLUMNS if col != "id")
)
insert_statement += update_statement
elif policy == DuplicatePolicy.SKIP:
Expand Down Expand Up @@ -347,6 +348,9 @@ def _from_pg_to_haystack_documents(self, documents: List[Dict[str, Any]]) -> Lis
blob_meta = haystack_dict.pop("blob_meta")
blob_mime_type = haystack_dict.pop("blob_mime_type")

if not haystack_dict["meta"]:
haystack_dict["meta"] = {}

haystack_document = Document.from_dict(haystack_dict)

if blob_data:
Expand Down

0 comments on commit ef442c2

Please sign in to comment.