diff --git a/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py b/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py
index c7458813e..f85412622 100644
--- a/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py
+++ b/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py
@@ -20,16 +20,21 @@
 
 logger = logging.getLogger(__name__)
 
-TABLE_DEFINITION = [
-    ("id", "VARCHAR(128)", "PRIMARY KEY"),
-    ("embedding", "VECTOR({embedding_dimension})", ""),
-    ("content", "TEXT", ""),
-    ("dataframe", "JSON", ""),
-    ("blob_data", "BYTEA", ""),
-    ("blob_meta", "JSON", ""),
-    ("blob_mime_type", "VARCHAR(255)", ""),
-    ("meta", "JSON", ""),
-]
+CREATE_TABLE_STATEMENT = """
+CREATE TABLE IF NOT EXISTS {table_name} (
+id VARCHAR(128) PRIMARY KEY,
+embedding VECTOR({embedding_dimension}),
+content TEXT,
+dataframe JSON,
+blob_data BYTEA,
+blob_meta JSON,
+blob_mime_type VARCHAR(255),
+meta JSON)
+"""
+
+# Column names in declaration order: after the "CREATE TABLE ... (" header, every
+# line of the statement (including the last one, "meta JSON)") starts with a column name.
+COLUMNS = [el.split()[0] for el in CREATE_TABLE_STATEMENT.strip().splitlines()[1:]]
 
 SIMILARITY_FUNCTION_TO_POSTGRESQL_OPS = {
     "cosine_distance": "vector_cosine_ops",
@@ -49,9 +54,7 @@ def __init__(
         connection_string: str,
         table_name: str = "haystack_documents",
         embedding_dimension: int = 768,
-        embedding_similarity_function: Literal[
-            "cosine_distance", "inner_product", "l2_distance"
-        ] = "cosine_distance",
+        embedding_similarity_function: Literal["cosine_distance", "inner_product", "l2_distance"] = "cosine_distance",
         recreate_table: bool = False,
         search_strategy: Literal["exact_nearest_neighbor", "hnsw"] = "exact_nearest_neighbor",
         hnsw_recreate_index_if_exists: bool = False,
@@ -154,9 +157,7 @@ def _create_table_if_not_exists(self):
         Creates the table to store Haystack documents if it doesn't exist yet.
         """
 
-        table_structure_str = ", ".join(f"{name} {dtype} {constraint}" for name, dtype, constraint in TABLE_DEFINITION)
-
-        create_sql = SQL("CREATE TABLE IF NOT EXISTS {table_name} (" + table_structure_str + ")").format(
+        create_sql = SQL(CREATE_TABLE_STATEMENT).format(
             table_name=Identifier(self.table_name), embedding_dimension=SQLLiteral(self.embedding_dimension)
         )
 
@@ -276,8 +277,8 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D
 
         db_documents = self._from_haystack_to_pg_documents(documents)
 
-        columns_str = "(" + ", ".join(col for col, *_ in TABLE_DEFINITION) + ")"
-        values_placeholder_str = "VALUES (" + ", ".join(f"%({col})s" for col, *_ in TABLE_DEFINITION) + ")"
+        columns_str = "(" + ", ".join(COLUMNS) + ")"
+        values_placeholder_str = "VALUES (" + ", ".join(f"%({col})s" for col in COLUMNS) + ")"
 
         insert_statement = SQL("INSERT INTO {table_name} " + columns_str + " " + values_placeholder_str).format(
             table_name=Identifier(self.table_name)
@@ -286,7 +287,7 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D
         if policy == DuplicatePolicy.OVERWRITE:
             update_statement = SQL(
                 "ON CONFLICT (id) DO UPDATE SET "
-                + ", ".join(f"{col} = EXCLUDED.{col}" for col, *_ in TABLE_DEFINITION if col != "id")
+                + ", ".join(f"{col} = EXCLUDED.{col}" for col in COLUMNS if col != "id")
             )
             insert_statement += update_statement
         elif policy == DuplicatePolicy.SKIP: