From 0af920dd8e585322c87fb75d44ad0aac9ffbb4ad Mon Sep 17 00:00:00 2001 From: Shreyas Pimpalgaonkar Date: Sun, 20 Oct 2024 14:44:15 -0700 Subject: [PATCH] modify the update query --- py/core/pipes/kg/deduplication.py | 2 +- py/core/pipes/kg/deduplication_summary.py | 8 ++------ py/core/providers/kg/postgres.py | 15 +++++++++++++++ 3 files changed, 18 insertions(+), 7 deletions(-) diff --git a/py/core/pipes/kg/deduplication.py b/py/core/pipes/kg/deduplication.py index 7a85a72ca..ab0735bbb 100644 --- a/py/core/pipes/kg/deduplication.py +++ b/py/core/pipes/kg/deduplication.py @@ -116,7 +116,7 @@ async def kg_named_entity_deduplication( logger.info( f"KGEntityDeduplicationPipe: Upserting {len(deduplicated_entities_list)} deduplicated entities for collection {collection_id}" ) - await self.kg_provider.add_entities( + await self.kg_provider.add_entity_descriptions( deduplicated_entities_list, table_name="entity_deduplicated", conflict_columns=["name", "collection_id", 'attributes'], diff --git a/py/core/pipes/kg/deduplication_summary.py b/py/core/pipes/kg/deduplication_summary.py index 434297a34..5b9ff7b60 100644 --- a/py/core/pipes/kg/deduplication_summary.py +++ b/py/core/pipes/kg/deduplication_summary.py @@ -108,16 +108,12 @@ async def _prepare_and_upsert_entities( ) for i, entity in enumerate(entities_batch): - entity.description_embedding = embeddings[i] + entity.description_embedding = str(embeddings[i]) # type: ignore entity.collection_id = collection_id entity.attributes = {} - print(entities_batch) - - result = await self.kg_provider.add_entities( + result = await self.kg_provider.add_entity_descriptions( entities_batch, - table_name="entity_deduplicated", - conflict_columns=["name", "collection_id", 'attributes'], ) logger.info( diff --git a/py/core/providers/kg/postgres.py b/py/core/providers/kg/postgres.py index 55b79416a..1354f334d 100644 --- a/py/core/providers/kg/postgres.py +++ b/py/core/providers/kg/postgres.py @@ -1270,3 +1270,18 @@ async def get_triple_count( WHERE {" AND ".join(conditions)} """ return (await self.fetch_query(QUERY, params))[0]["count"] + + async def add_entity_descriptions(self, entities: list[Entity]): + + query = f""" + UPDATE {self._get_table_name("entity_deduplicated")} + SET description = $3, description_embedding = $4 + WHERE name = $1 AND collection_id = $2 + """ + + inputs = [ + (entity.name, entity.collection_id, entity.description, entity.description_embedding) + for entity in entities + ] + + await self.execute_many(query, inputs) \ No newline at end of file