diff --git a/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py b/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py index ef20346ab..392e03a87 100644 --- a/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py +++ b/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py @@ -32,7 +32,22 @@ meta JSON) """ -COLUMNS = [el.split()[0] for el in CREATE_TABLE_STATEMENT.splitlines()[2:-1]] +INSERT_STATEMENT = """ +INSERT INTO {table_name} +(id, embedding, content, dataframe, blob_data, blob_meta, blob_mime_type, meta) +VALUES (%(id)s, %(embedding)s, %(content)s, %(dataframe)s, %(blob_data)s, %(blob_meta)s, %(blob_mime_type)s, %(meta)s) +""" + +UPDATE_STATEMENT = """ +ON CONFLICT (id) DO UPDATE SET +embedding = EXCLUDED.embedding, +content = EXCLUDED.content, +dataframe = EXCLUDED.dataframe, +blob_data = EXCLUDED.blob_data, +blob_meta = EXCLUDED.blob_meta, +blob_mime_type = EXCLUDED.blob_mime_type, +meta = EXCLUDED.meta +""" VECTOR_FUNCTION_TO_POSTGRESQL_OPS = { "cosine_distance": "vector_cosine_ops", @@ -276,26 +291,17 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D db_documents = self._from_haystack_to_pg_documents(documents) - columns_str = "(" + ", ".join(col for col in COLUMNS) + ")" - values_placeholder_str = "VALUES (" + ", ".join(f"%({col})s" for col in COLUMNS) + ")" - - insert_statement = SQL("INSERT INTO {table_name} " + columns_str + " " + values_placeholder_str).format( - table_name=Identifier(self.table_name) - ) + sql_insert = SQL(INSERT_STATEMENT).format(table_name=Identifier(self.table_name)) if policy == DuplicatePolicy.OVERWRITE: - update_statement = SQL( - "ON CONFLICT (id) DO UPDATE SET " - + ", ".join(f"{col} = EXCLUDED.{col}" for col in COLUMNS if col != "id") - ) - insert_statement += update_statement + sql_insert += SQL(UPDATE_STATEMENT) elif policy == DuplicatePolicy.SKIP: - insert_statement += SQL("ON CONFLICT DO NOTHING") + sql_insert += SQL("ON CONFLICT DO NOTHING") - insert_statement += SQL(" RETURNING id") + sql_insert += SQL(" RETURNING id") try: - self._cursor.executemany(insert_statement, db_documents, returning=True) + self._cursor.executemany(sql_insert, db_documents, returning=True) except IntegrityError as ie: self._connection.rollback() raise DuplicateDocumentError from ie