Skip to content

Commit

Permalink
explicit insert and update statements
Browse files Browse the repository at this point in the history
  • Loading branch information
anakin87 committed Jan 23, 2024
1 parent 5c3d1ec commit 685462d
Showing 1 changed file with 21 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,22 @@
meta JSON)
"""

COLUMNS = [el.split()[0] for el in CREATE_TABLE_STATEMENT.splitlines()[2:-1]]
# SQL template for inserting one document row.
# `{table_name}` is not a str.format field: write_documents wraps this string in
# SQL(...) and fills it with .format(table_name=Identifier(...)), so the table
# name is composed as an identifier rather than interpolated as text.
# The `%(name)s` tokens are named query placeholders; their values are supplied
# per document when write_documents passes the statement to executemany.
# NOTE(review): the column list is now hard-coded here — presumably it must be
# kept in sync with the columns declared in CREATE_TABLE_STATEMENT; confirm.
INSERT_STATEMENT = """
INSERT INTO {table_name}
(id, embedding, content, dataframe, blob_data, blob_meta, blob_mime_type, meta)
VALUES (%(id)s, %(embedding)s, %(content)s, %(dataframe)s, %(blob_data)s, %(blob_meta)s, %(blob_mime_type)s, %(meta)s)
"""

# Upsert clause appended to the insert statement by write_documents when the
# policy is DuplicatePolicy.OVERWRITE.
# On an `id` conflict, every column except `id` is overwritten with the value
# from the row that was proposed for insertion (EXCLUDED.<column>).
# The literal starts and ends with a newline, so it can be concatenated after
# the insert statement and before " RETURNING id" without extra separators.
# NOTE(review): the SET list mirrors the non-id columns of INSERT_STATEMENT —
# presumably the two must be updated together when a column is added; confirm.
UPDATE_STATEMENT = """
ON CONFLICT (id) DO UPDATE SET
embedding = EXCLUDED.embedding,
content = EXCLUDED.content,
dataframe = EXCLUDED.dataframe,
blob_data = EXCLUDED.blob_data,
blob_meta = EXCLUDED.blob_meta,
blob_mime_type = EXCLUDED.blob_mime_type,
meta = EXCLUDED.meta
"""

VECTOR_FUNCTION_TO_POSTGRESQL_OPS = {
"cosine_distance": "vector_cosine_ops",
Expand Down Expand Up @@ -276,26 +291,17 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D

db_documents = self._from_haystack_to_pg_documents(documents)

columns_str = "(" + ", ".join(col for col in COLUMNS) + ")"
values_placeholder_str = "VALUES (" + ", ".join(f"%({col})s" for col in COLUMNS) + ")"

insert_statement = SQL("INSERT INTO {table_name} " + columns_str + " " + values_placeholder_str).format(
table_name=Identifier(self.table_name)
)
sql_insert = SQL(INSERT_STATEMENT).format(table_name=Identifier(self.table_name))

if policy == DuplicatePolicy.OVERWRITE:
update_statement = SQL(
"ON CONFLICT (id) DO UPDATE SET "
+ ", ".join(f"{col} = EXCLUDED.{col}" for col in COLUMNS if col != "id")
)
insert_statement += update_statement
sql_insert += SQL(UPDATE_STATEMENT)
elif policy == DuplicatePolicy.SKIP:
insert_statement += SQL("ON CONFLICT DO NOTHING")
sql_insert += SQL("ON CONFLICT DO NOTHING")

insert_statement += SQL(" RETURNING id")
sql_insert += SQL(" RETURNING id")

try:
self._cursor.executemany(insert_statement, db_documents, returning=True)
self._cursor.executemany(sql_insert, db_documents, returning=True)
except IntegrityError as ie:
self._connection.rollback()
raise DuplicateDocumentError from ie
Expand Down

0 comments on commit 685462d

Please sign in to comment.