Skip to content

Commit

Permalink
working
Browse files Browse the repository at this point in the history
  • Loading branch information
shreyaspimpalgaonkar committed Oct 20, 2024
1 parent 1be75de commit 40cbc09
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 17 deletions.
1 change: 1 addition & 0 deletions py/core/base/providers/kg.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ async def get_entities(
offset: int,
limit: int,
entity_ids: list[str] | None = None,
entity_names: list[str] | None = None,
entity_table_name: str = "entity_embedding",
) -> dict:
"""Abstract method to get entities."""
Expand Down
6 changes: 4 additions & 2 deletions py/core/pipes/kg/deduplication.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ async def kg_named_entity_deduplication(

entities = (
await self.kg_provider.get_entities(
collection_id=collection_id, offset=0, limit=entity_count
collection_id=collection_id, offset=0, limit=-1
)
)["entities"]

Expand Down Expand Up @@ -116,7 +116,9 @@ async def kg_named_entity_deduplication(
f"KGEntityDeduplicationPipe: Upserting {len(deduplicated_entities_list)} deduplicated entities for collection {collection_id}"
)
await self.kg_provider.add_entities(
deduplicated_entities_list, table_name="entity_deduplicated"
deduplicated_entities_list,
table_name="entity_deduplicated",
# conflict_columns=["name", "collection_id"],
)

return {
Expand Down
31 changes: 21 additions & 10 deletions py/core/pipes/kg/deduplication_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ async def _merge_entity_descriptions_llm_prompt(
entity_name: str,
entity_descriptions: list[str],
generation_config: GenerationConfig,
) -> str:
) -> Entity:

# find the index until the length is less than 1024
index = 0
Expand Down Expand Up @@ -89,11 +89,11 @@ async def _merge_entity_descriptions(
entity_name: str,
entity_descriptions: list[str],
generation_config: GenerationConfig,
) -> str:
) -> Entity:

# TODO: Expose this as a hyperparameter
if len(entity_descriptions) <= 5:
return "\n".join(entity_descriptions)
return Entity(name=entity_name, description="\n".join(entity_descriptions))
else:
return await self._merge_entity_descriptions_llm_prompt(
entity_name, entity_descriptions, generation_config
Expand All @@ -103,18 +103,20 @@ async def _prepare_and_upsert_entities(
self, entities_batch: list[dict], collection_id: str
) -> list[dict]:

embeddings = await self.embedding_provider.get_embeddings(
[description["description"] for description in entities_batch]
embeddings = await self.embedding_provider.async_get_embeddings(
[entity.description for entity in entities_batch]
)

for i, entity in enumerate(entities_batch):
entity.description_embedding = embeddings[i]
entity.collection_id = collection_id
entity.extraction_ids = []
entity.document_ids = []

result = await self.kg_provider.add_entities(
entities_batch,
entity_table_name="entity_deduplicated",
conflict_columns=["name", "collection_id"],
table_name="entity_deduplicated",
# conflict_columns=["name", "collection_id"],
)

logger.info(
Expand Down Expand Up @@ -156,6 +158,9 @@ async def _run_logic(
)
)["entities"]


logger.info(f"Entities: {entities}")

logger.info(
f"Retrieved {len(entities)} entities for collection {collection_id}"
)
Expand All @@ -168,15 +173,19 @@ async def _run_logic(

entity_descriptions = (
await self.kg_provider.get_entities(
collection_id,
collection_id,
offset,
limit,
-1,
entity_names=entity_names,
entity_table_name="entity_embedding",
)
)["entities"]

logger.info(f"Entity descriptions: {entity_descriptions}")
entity_descriptions_names = [entity.name for entity in entity_descriptions]

logger.info(
f"Retrieved {entity_descriptions_names} entity descriptions names for collection {collection_id}"
)

logger.info(
f"Retrieved {len(entity_descriptions)} entity descriptions for collection {collection_id}"
Expand All @@ -194,6 +203,8 @@ async def _run_logic(
f"Merging entity descriptions for collection {collection_id}"
)

logger.info(f"Entity descriptions dict: {entity_descriptions_dict}")

tasks = []
for entity in entities:
tasks.append(
Expand Down
18 changes: 13 additions & 5 deletions py/core/providers/kg/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,8 @@ async def create_tables(
document_ids UUID[] NOT NULL,
collection_id UUID NOT NULL,
description_embedding {vector_column_str},
attributes JSONB
attributes JSONB,
UNIQUE (name, collection_id, attributes)
);"""

await self.execute_query(query)
Expand Down Expand Up @@ -221,6 +222,8 @@ async def _add_objects(
{on_conflict_query}
"""

logger.info(f"Query: {QUERY}")

logger.info(f"Upserting {len(objects)} objects into {table_name}")

# Filter out null values for each object
Expand Down Expand Up @@ -1069,7 +1072,7 @@ async def get_entities(
self,
collection_id: UUID,
offset: int = 0,
limit: int = 100,
limit: int = -1,
entity_ids: Optional[List[str]] = None,
entity_names: Optional[List[str]] = None,
entity_table_name: str = "entity_embedding",
Expand All @@ -1085,7 +1088,12 @@ async def get_entities(
conditions.append(f"name = ANY(${len(params) + 1})")
params.append(entity_names)

params.extend([offset, limit])
if limit != -1:
params.extend([offset, limit])
offset_limit_clause = f"OFFSET ${len(params) - 1} LIMIT ${len(params)}"
else:
params.append(offset)
offset_limit_clause = f"OFFSET ${len(params)}"

if entity_table_name == "entity_deduplicated":
# entity deduplicated table has document_ids, not document_id.
Expand All @@ -1096,7 +1104,7 @@ async def get_entities(
WHERE collection_id = $1
{" AND " + " AND ".join(conditions) if conditions else ""}
ORDER BY id
OFFSET ${len(params) - 1} LIMIT ${len(params)}
{offset_limit_clause}
"""
else:
query = f"""
Expand All @@ -1108,7 +1116,7 @@ async def get_entities(
)
{" AND " + " AND ".join(conditions) if conditions else ""}
ORDER BY id
OFFSET ${len(params) - 1} LIMIT ${len(params)}
{offset_limit_clause}
"""

results = await self.fetch_query(query, params)
Expand Down
1 change: 1 addition & 0 deletions py/shared/abstractions/vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ class VectorTableName(str, Enum):

CHUNKS = "chunks"
ENTITIES = "entity_embedding"
ENTITY_DEDUPLICATED = "entity_deduplicated"
# TODO: Add support for triples
# TRIPLES = "triple_raw"
COMMUNITIES = "community_report"
Expand Down

0 comments on commit 40cbc09

Please sign in to comment.