From 40cbc09009215256c0e08057a5359bf7bb09c2b0 Mon Sep 17 00:00:00 2001 From: Shreyas Pimpalgaonkar Date: Sun, 20 Oct 2024 13:57:38 -0700 Subject: [PATCH] working --- py/core/base/providers/kg.py | 1 + py/core/pipes/kg/deduplication.py | 6 +++-- py/core/pipes/kg/deduplication_summary.py | 31 +++++++++++++++-------- py/core/providers/kg/postgres.py | 18 +++++++++---- py/shared/abstractions/vector.py | 1 + 5 files changed, 40 insertions(+), 17 deletions(-) diff --git a/py/core/base/providers/kg.py b/py/core/base/providers/kg.py index 04bfc1398..d7f2aa1cc 100644 --- a/py/core/base/providers/kg.py +++ b/py/core/base/providers/kg.py @@ -111,6 +111,7 @@ async def get_entities( offset: int, limit: int, entity_ids: list[str] | None = None, + entity_names: list[str] | None = None, entity_table_name: str = "entity_embedding", ) -> dict: """Abstract method to get entities.""" diff --git a/py/core/pipes/kg/deduplication.py b/py/core/pipes/kg/deduplication.py index d7c8ecb73..79d7ba2d8 100644 --- a/py/core/pipes/kg/deduplication.py +++ b/py/core/pipes/kg/deduplication.py @@ -58,7 +58,7 @@ async def kg_named_entity_deduplication( entities = ( await self.kg_provider.get_entities( - collection_id=collection_id, offset=0, limit=entity_count + collection_id=collection_id, offset=0, limit=-1 ) )["entities"] @@ -116,7 +116,9 @@ async def kg_named_entity_deduplication( f"KGEntityDeduplicationPipe: Upserting {len(deduplicated_entities_list)} deduplicated entities for collection {collection_id}" ) await self.kg_provider.add_entities( - deduplicated_entities_list, table_name="entity_deduplicated" + deduplicated_entities_list, + table_name="entity_deduplicated", + # conflict_columns=["name", "collection_id"], ) return { diff --git a/py/core/pipes/kg/deduplication_summary.py b/py/core/pipes/kg/deduplication_summary.py index 81d5af3fd..de18413b5 100644 --- a/py/core/pipes/kg/deduplication_summary.py +++ b/py/core/pipes/kg/deduplication_summary.py @@ -48,7 +48,7 @@ async def _merge_entity_descriptions_llm_prompt( entity_name: str, entity_descriptions: list[str], generation_config: GenerationConfig, - ) -> str: + ) -> Entity: # find the index until the length is less than 1024 index = 0 @@ -89,11 +89,11 @@ async def _merge_entity_descriptions( entity_name: str, entity_descriptions: list[str], generation_config: GenerationConfig, - ) -> str: + ) -> Entity: # TODO: Expose this as a hyperparameter if len(entity_descriptions) <= 5: - return "\n".join(entity_descriptions) + return Entity(name=entity_name, description="\n".join(entity_descriptions)) else: return await self._merge_entity_descriptions_llm_prompt( entity_name, entity_descriptions, generation_config @@ -103,18 +103,20 @@ async def _prepare_and_upsert_entities( self, entities_batch: list[dict], collection_id: str ) -> list[dict]: - embeddings = await self.embedding_provider.get_embeddings( - [description["description"] for description in entities_batch] + embeddings = await self.embedding_provider.async_get_embeddings( + [entity.description for entity in entities_batch] ) for i, entity in enumerate(entities_batch): entity.description_embedding = embeddings[i] entity.collection_id = collection_id + entity.extraction_ids = [] + entity.document_ids = [] result = await self.kg_provider.add_entities( entities_batch, - entity_table_name="entity_deduplicated", - conflict_columns=["name", "collection_id"], + table_name="entity_deduplicated", + # conflict_columns=["name", "collection_id"], ) logger.info( @@ -156,6 +158,9 @@ async def _run_logic( ) )["entities"] + + logger.info(f"Entities: {entities}") + logger.info( f"Retrieved {len(entities)} entities for collection {collection_id}" ) @@ -168,15 +173,19 @@ async def _run_logic( entity_descriptions = ( await self.kg_provider.get_entities( - collection_id, + collection_id, offset, - limit, + -1, entity_names=entity_names, entity_table_name="entity_embedding", ) )["entities"] - logger.info(f"Entity descriptions: {entity_descriptions}") + entity_descriptions_names = [entity.name for entity in entity_descriptions] + + logger.info( + f"Retrieved {entity_descriptions_names} entity descriptions names for collection {collection_id}" + ) logger.info( f"Retrieved {len(entity_descriptions)} entity descriptions for collection {collection_id}" @@ -194,6 +203,8 @@ async def _run_logic( f"Merging entity descriptions for collection {collection_id}" ) + logger.info(f"Entity descriptions dict: {entity_descriptions_dict}") + tasks = [] for entity in entities: tasks.append( diff --git a/py/core/providers/kg/postgres.py b/py/core/providers/kg/postgres.py index 0b912c309..b60fcfbc7 100644 --- a/py/core/providers/kg/postgres.py +++ b/py/core/providers/kg/postgres.py @@ -149,7 +149,8 @@ async def create_tables( document_ids UUID[] NOT NULL, collection_id UUID NOT NULL, description_embedding {vector_column_str}, - attributes JSONB + attributes JSONB, + UNIQUE (name, collection_id, attributes) );""" await self.execute_query(query) @@ -221,6 +222,8 @@ async def _add_objects( {on_conflict_query} """ + logger.info(f"Query: {QUERY}") + logger.info(f"Upserting {len(objects)} objects into {table_name}") # Filter out null values for each object @@ -1069,7 +1072,7 @@ async def get_entities( self, collection_id: UUID, offset: int = 0, - limit: int = 100, + limit: int = -1, entity_ids: Optional[List[str]] = None, entity_names: Optional[List[str]] = None, entity_table_name: str = "entity_embedding", @@ -1085,7 +1088,12 @@ async def get_entities( conditions.append(f"name = ANY(${len(params) + 1})") params.append(entity_names) - params.extend([offset, limit]) + if limit != -1: + params.extend([offset, limit]) + offset_limit_clause = f"OFFSET ${len(params) - 1} LIMIT ${len(params)}" + else: + params.append(offset) + offset_limit_clause = f"OFFSET ${len(params)}" if entity_table_name == "entity_deduplicated": # entity deduplicated table has document_ids, not document_id. @@ -1096,7 +1104,7 @@ async def get_entities( WHERE collection_id = $1 {" AND " + " AND ".join(conditions) if conditions else ""} ORDER BY id - OFFSET ${len(params) - 1} LIMIT ${len(params)} + {offset_limit_clause} """ else: query = f""" @@ -1108,7 +1116,7 @@ async def get_entities( ) {" AND " + " AND ".join(conditions) if conditions else ""} ORDER BY id - OFFSET ${len(params) - 1} LIMIT ${len(params)} + {offset_limit_clause} """ results = await self.fetch_query(query, params) diff --git a/py/shared/abstractions/vector.py b/py/shared/abstractions/vector.py index 19294f15d..567776f11 100644 --- a/py/shared/abstractions/vector.py +++ b/py/shared/abstractions/vector.py @@ -109,6 +109,7 @@ class VectorTableName(str, Enum): CHUNKS = "chunks" ENTITIES = "entity_embedding" + ENTITY_DEDUPLICATED = "entity_deduplicated" # TODO: Add support for triples # TRIPLES = "triple_raw" COMMUNITIES = "community_report"