Skip to content

Commit

Permalink
explicit create statement
Browse files Browse the repository at this point in the history
  • Loading branch information
anakin87 committed Jan 22, 2024
1 parent f4a41e1 commit ae8c7ff
Showing 1 changed file with 18 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,19 @@

logger = logging.getLogger(__name__)

TABLE_DEFINITION = [
("id", "VARCHAR(128)", "PRIMARY KEY"),
("embedding", "VECTOR({embedding_dimension})", ""),
("content", "TEXT", ""),
("dataframe", "JSON", ""),
("blob_data", "BYTEA", ""),
("blob_meta", "JSON", ""),
("blob_mime_type", "VARCHAR(255)", ""),
("meta", "JSON", ""),
]
CREATE_TABLE_STATEMENT = """
CREATE TABLE IF NOT EXISTS {table_name} (
id VARCHAR(128) PRIMARY KEY,
embedding VECTOR({embedding_dimension}),
content TEXT,
dataframe JSON,
blob_data BYTEA,
blob_meta JSON,
blob_mime_type VARCHAR(255),
meta JSON)
"""

COLUMNS = [el.split()[0] for el in CREATE_TABLE_STATEMENT.splitlines()[2:-1]]

SIMILARITY_FUNCTION_TO_POSTGRESQL_OPS = {
"cosine_distance": "vector_cosine_ops",
Expand All @@ -49,9 +52,7 @@ def __init__(
connection_string: str,
table_name: str = "haystack_documents",
embedding_dimension: int = 768,
embedding_similarity_function: Literal[
"cosine_distance", "inner_product", "l2_distance"
] = "cosine_distance",
embedding_similarity_function: Literal["cosine_distance", "inner_product", "l2_distance"] = "cosine_distance",
recreate_table: bool = False,
search_strategy: Literal["exact_nearest_neighbor", "hnsw"] = "exact_nearest_neighbor",
hnsw_recreate_index_if_exists: bool = False,
Expand Down Expand Up @@ -154,9 +155,7 @@ def _create_table_if_not_exists(self):
Creates the table to store Haystack documents if it doesn't exist yet.
"""

table_structure_str = ", ".join(f"{name} {dtype} {constraint}" for name, dtype, constraint in TABLE_DEFINITION)

create_sql = SQL("CREATE TABLE IF NOT EXISTS {table_name} (" + table_structure_str + ")").format(
create_sql = SQL(CREATE_TABLE_STATEMENT).format(
table_name=Identifier(self.table_name), embedding_dimension=SQLLiteral(self.embedding_dimension)
)

Expand Down Expand Up @@ -276,8 +275,8 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D

db_documents = self._from_haystack_to_pg_documents(documents)

columns_str = "(" + ", ".join(col for col, *_ in TABLE_DEFINITION) + ")"
values_placeholder_str = "VALUES (" + ", ".join(f"%({col})s" for col, *_ in TABLE_DEFINITION) + ")"
columns_str = "(" + ", ".join(col for col, *_ in COLUMNS) + ")"
values_placeholder_str = "VALUES (" + ", ".join(f"%({col})s" for col, *_ in COLUMNS) + ")"

insert_statement = SQL("INSERT INTO {table_name} " + columns_str + " " + values_placeholder_str).format(
table_name=Identifier(self.table_name)
Expand All @@ -286,7 +285,7 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D
if policy == DuplicatePolicy.OVERWRITE:
update_statement = SQL(
"ON CONFLICT (id) DO UPDATE SET "
+ ", ".join(f"{col} = EXCLUDED.{col}" for col, *_ in TABLE_DEFINITION if col != "id")
+ ", ".join(f"{col} = EXCLUDED.{col}" for col, *_ in COLUMNS if col != "id")
)
insert_statement += update_statement
elif policy == DuplicatePolicy.SKIP:
Expand Down

0 comments on commit ae8c7ff

Please sign in to comment.