Commit f31bbc6
shreyaspimpalgaonkar committed Oct 22, 2024
1 parent 6b1de83 commit f31bbc6
Showing 15 changed files with 149 additions and 57 deletions.
7 changes: 7 additions & 0 deletions py/core/base/providers/kg.py
@@ -235,6 +235,13 @@ async def get_creation_estimate(self, *args: Any, **kwargs: Any) -> Any:
"""Abstract method to get the creation estimate."""
pass

@abstractmethod
async def get_deduplication_estimate(
self, *args: Any, **kwargs: Any
) -> Any:
"""Abstract method to get the deduplication estimate."""
pass

@abstractmethod
async def get_enrichment_estimate(self, *args: Any, **kwargs: Any) -> Any:
"""Abstract method to get the enrichment estimate."""
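The new abstract method obliges every KG provider to expose a deduplication estimate. A minimal sketch of a conforming implementation (the class name and return value are illustrative, not from this commit; the Postgres provider further down returns a KGDeduplicationEstimationResponse):

    class MyKGProvider(KGProvider):
        async def get_deduplication_estimate(
            self, *args: Any, **kwargs: Any
        ) -> Any:
            # Compute and return whatever estimate object the backend
            # supports; shape is provider-specific.
            return {"num_entities": 0}
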
1 change: 1 addition & 0 deletions py/core/configs/full.toml
@@ -12,6 +12,7 @@ overlap = 256

# turned off by default
# [ingestion.chunk_enrichment_settings]
# enable_chunk_enrichment = true
# strategies = ["semantic", "neighborhood"]
# forward_chunks = 3
# backward_chunks = 3
19 changes: 11 additions & 8 deletions py/core/main/api/kg_router.py
@@ -355,14 +355,12 @@ async def deduplicate_entities(
if not run_type:
run_type = KGRunType.ESTIMATE


if run_type == KGRunType.ESTIMATE:
return await self.service.get_deduplication_estimate(
collection_id, server_deduplication_settings
)

server_deduplication_settings = (
self.service.providers.kg.config.kg_entity_deduplication_settings.dict()
self.service.providers.kg.config.kg_entity_deduplication_settings
)

logger.info(
f"Server deduplication settings: {server_deduplication_settings}"
)

if deduplication_settings:
@@ -375,10 +373,15 @@ async def deduplicate_entities(
)
logger.info(f"Input data: {server_deduplication_settings}")

if run_type == KGRunType.ESTIMATE:
return await self.service.get_deduplication_estimate(
collection_id, server_deduplication_settings
)

workflow_input = {
"collection_id": str(collection_id),
"run_type": run_type,
"kg_entity_deduplication_settings": server_deduplication_settings,
"kg_entity_deduplication_settings": server_deduplication_settings.model_dump_json(),
"user": auth_user.json(),
}

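Two behavioral changes in this hunk: the ESTIMATE early-return now runs after the server defaults are merged with any request-supplied overrides, and the merged settings travel to the workflow as JSON via model_dump_json() rather than as a dict. A simplified sketch of that merge-then-serialize pattern (the model below is a stand-in, not the real KGEntityDeduplicationSettings class):

    from pydantic import BaseModel

    class DedupSettings(BaseModel):  # stand-in for KGEntityDeduplicationSettings
        kg_entity_deduplication_type: str = "by_name"
        max_description_input_length: int = 65536

    server_settings = DedupSettings()                  # server defaults
    overrides = {"max_description_input_length": 32768}
    for field, value in overrides.items():             # apply request overrides
        setattr(server_settings, field, value)

    # Pydantic v2 serialization keeps the orchestration payload JSON-safe.
    workflow_input = {
        "kg_entity_deduplication_settings": server_settings.model_dump_json(),
    }
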
38 changes: 29 additions & 9 deletions py/core/main/orchestration/hatchet/ingestion_workflow.py
@@ -141,15 +141,6 @@ async def parse(self, context: Context) -> dict:
"is_update"
)

# add contextual chunks
await self.ingestion_service.chunk_enrichment(
document_id=document_info.id,
)

# delete original chunks

# delete original chunk vectors

await self.ingestion_service.finalize_ingestion(
document_info, is_update=is_update
)
@@ -171,10 +162,39 @@ async def parse(self, context: Context) -> dict:
document_id=document_info.id, collection_id=collection_id
)

await self.ingestion_service.update_document_status(
document_info,
status=IngestionStatus.SUCCESS,
)

# add contextual chunks
if getattr(
service.providers.ingestion.config,
"enable_chunk_enrichment",
False,
):

logger.info("Enriching document with contextual chunks")

await self.ingestion_service.update_document_status(
document_info,
status=IngestionStatus.ENRICHING,
)

await self.ingestion_service.chunk_enrichment(
document_id=document_info.id,
)

await self.ingestion_service.update_document_status(
document_info,
status=IngestionStatus.ENRICHED,
)

return {
"status": "Successfully finalized ingestion",
"document_info": document_info.to_dict(),
}

except AuthenticationError as e:
raise R2RException(
status_code=401,
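Chunk enrichment is now opt-in and runs only after the document has already been marked SUCCESS; the getattr(..., False) guard means configs that predate the enable_chunk_enrichment field simply skip enrichment instead of raising. A minimal sketch of the same guard, using a toy config class assumed for illustration:

    class IngestionConfig:          # an older config without the new field
        pass

    config = IngestionConfig()
    # A missing attribute falls back to False rather than AttributeError,
    # so enrichment stays off unless explicitly enabled.
    if getattr(config, "enable_chunk_enrichment", False):
        print("enriching document with contextual chunks")
    else:
        print("skipping chunk enrichment")
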
38 changes: 35 additions & 3 deletions py/core/main/orchestration/hatchet/kg_workflow.py
@@ -33,6 +33,25 @@ def get_input_data_dict(input_data):
)
if key == "kg_enrichment_settings":
input_data[key] = json.loads(value)
input_data[key]["generation_config"] = GenerationConfig(
**input_data[key]["generation_config"]
)

if key == "kg_entity_deduplication_settings":
input_data[key] = json.loads(value)

if isinstance(input_data[key]["generation_config"], str):
input_data[key]["generation_config"] = json.loads(
input_data[key]["generation_config"]
)

input_data[key]["generation_config"] = GenerationConfig(
**input_data[key]["generation_config"]
)

logger.info(
f"KG Entity Deduplication Settings: {input_data[key]}"
)

if key == "generation_config":
input_data[key] = GenerationConfig(**input_data[key])
@@ -246,6 +265,12 @@ async def kg_entity_deduplication_setup(
)
)[0]["num_entities"]

input_data["kg_entity_deduplication_settings"][
"generation_config"
] = input_data["kg_entity_deduplication_settings"][
"generation_config"
].model_dump_json()

# run 100 entities in one workflow
total_workflows = math.ceil(number_of_distinct_entities / 100)
workflows = []
@@ -259,9 +284,11 @@
"collection_id": collection_id,
"offset": offset,
"limit": 100,
"kg_entity_deduplication_settings": input_data[
"kg_entity_deduplication_settings"
],
"kg_entity_deduplication_settings": json.dumps(
input_data[
"kg_entity_deduplication_settings"
]
),
}
},
key=f"{i}/{total_workflows}_entity_deduplication_part",
@@ -284,6 +311,11 @@ def __init__(self, kg_service: KgService):
async def kg_entity_deduplication_summary(
self, context: Context
) -> dict:

logger.info(
f"Running KG Entity Deduplication Summary for input data: {context.workflow_input()['request']}"
)

input_data = get_input_data_dict(
context.workflow_input()["request"]
)
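The isinstance check exists because generation_config can arrive doubly JSON-encoded: the setup step dumps it to a string (model_dump_json() above), and the whole settings dict is then serialized again for the child workflow's request. A self-contained round-trip showing why a second json.loads is sometimes needed (values illustrative):

    import json

    # Setup step: generation_config is serialized on its own, then the
    # whole settings dict is serialized again for the child workflow.
    settings = {"generation_config": json.dumps({"model": "openai/gpt-4o-mini"})}
    wire = json.dumps(settings)

    # Consumer step (get_input_data_dict): one json.loads recovers the dict,
    # but generation_config may still be a string and needs a second pass.
    inner = json.loads(wire)
    if isinstance(inner["generation_config"], str):
        inner["generation_config"] = json.loads(inner["generation_config"])

    print(inner["generation_config"]["model"])  # openai/gpt-4o-mini
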
17 changes: 15 additions & 2 deletions py/core/main/services/ingestion_service.py
@@ -397,9 +397,17 @@ async def _get_enriched_chunk_text(
context_chunk_texts = []
for context_chunk_id in context_chunk_ids:
context_chunk_texts.append(
document_chunks_dict[context_chunk_id]["text"]
(
document_chunks_dict[context_chunk_id]["text"],
document_chunks_dict[context_chunk_id]["metadata"][
"chunk_order"
],
)
)

# sort by chunk_order
context_chunk_texts.sort(key=lambda x: x[1])

# enrich chunk
# get prompt and call LLM on it. Then finally embed and store it.
# don't call a pipe here.
@@ -412,7 +420,12 @@ async def _get_enriched_chunk_text(
task_prompt_name="chunk_enrichment",
task_inputs={
"context_chunks": (
"\n".join(context_chunk_texts)
"\n".join(
[
text
for text, _ in context_chunk_texts
]
)
),
"chunk": chunk["text"],
},
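Context chunks gathered by the enrichment strategies can come back out of document order, so each text is now paired with its chunk_order and sorted before being joined into the prompt. The effect, on toy data:

    context_chunk_texts = [("third chunk", 2), ("first chunk", 0), ("second chunk", 1)]
    context_chunk_texts.sort(key=lambda pair: pair[1])   # sort by chunk_order
    context = "\n".join(text for text, _ in context_chunk_texts)
    print(context)
    # first chunk
    # second chunk
    # third chunk
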
3 changes: 1 addition & 2 deletions py/core/main/services/kg_service.py
@@ -369,7 +369,6 @@ async def get_communities(
community_numbers,
)


@telemetry_event("get_deduplication_estimate")
async def get_deduplication_estimate(
self,
@@ -435,4 +434,4 @@ async def kg_entity_deduplication_summary(
run_manager=self.run_manager,
)

return await _collect_results(deduplication_summary_results)
return await _collect_results(deduplication_summary_results)
2 changes: 1 addition & 1 deletion py/core/pipes/kg/deduplication.py
@@ -152,4 +152,4 @@ async def _run_logic(
else:
raise NotImplementedError(
f"KGEntityDeduplicationPipe: Deduplication type {kg_entity_deduplication_type} not implemented"
)
)
4 changes: 3 additions & 1 deletion py/core/providers/database/vector.py
@@ -532,7 +532,9 @@ async def create_index(
"""

if table_name == VectorTableName.RAW_CHUNKS:
table_name_str = f"{self.project_name}.{VectorTableName.RAW_CHUNKS}"
table_name_str = (
f"{self.project_name}.{VectorTableName.RAW_CHUNKS}"
)
col_name = "vec"
elif table_name == VectorTableName.ENTITIES_DOCUMENT:
table_name_str = (
47 changes: 28 additions & 19 deletions py/core/providers/kg/postgres.py
@@ -17,7 +17,11 @@
KGProvider,
Triple,
)
from shared.abstractions import KGCreationSettings, KGEnrichmentSettings, KGEntityDeduplicationSettings
from shared.abstractions import (
KGCreationSettings,
KGEnrichmentSettings,
KGEntityDeduplicationSettings,
)
from shared.abstractions.vector import VectorQuantizationType
from shared.api.models.kg.responses import (
KGCreationEstimationResponse,
@@ -229,10 +233,6 @@ async def _add_objects(
{on_conflict_query}
"""

logger.info(f"Query: {QUERY}")

logger.info(f"Upserting {len(objects)} objects into {table_name}")

# Filter out null values for each object
params = [
tuple(
@@ -242,7 +242,6 @@
)
for obj in objects
]
logger.info(f"Upserting {len(params)} params into {table_name}")

return await self.execute_many(QUERY, params) # type: ignore

@@ -277,8 +276,6 @@ async def add_entities(
)
cleaned_entities.append(entity_dict)

logger.info(f"Upserting {len(entities)} entities into {table_name}")

return await self._add_objects(
cleaned_entities, table_name, conflict_columns
)
@@ -1341,11 +1338,15 @@ async def update_entity_descriptions(self, entities: list[Entity]):

await self.execute_many(query, inputs) # type: ignore

async def get_deduplication_estimate(self, collection_id: UUID, kg_deduplication_settings: KGEntityDeduplicationSettings):
async def get_deduplication_estimate(
self,
collection_id: UUID,
kg_deduplication_settings: KGEntityDeduplicationSettings,
):
# number of documents in collection
query = f"""
SELECT name, count(name)
FROM {self._get_table_name("entity_document")}
FROM {self._get_table_name("entity_embedding")}
WHERE document_id = ANY(
SELECT document_id FROM {self._get_table_name("document_info")}
WHERE $1 = ANY(collection_ids)
@@ -1356,10 +1357,10 @@ async def get_deduplication_estimate(self, collection_id: UUID, kg_deduplication
entities = await self.fetch_query(query, [collection_id])
num_entities = len(entities)

estimated_llm_calls = num_entities
estimated_llm_calls = (num_entities, num_entities)
estimated_total_in_out_tokens_in_millions = (
estimated_llm_calls * 1000 / 1000000,
estimated_llm_calls * 5000 / 1000000,
estimated_llm_calls[0] * 1000 / 1000000,
estimated_llm_calls[1] * 5000 / 1000000,
)
estimated_cost_in_usd = (
estimated_total_in_out_tokens_in_millions[0]
@@ -1378,10 +1379,18 @@
)

return KGDeduplicationEstimationResponse(
message="Ran Deduplication Estimate (not the actual run). Note that these are estimated ranges, actual values may vary. To run the Deduplication process, run `deduplicate-entities` with `--run` in the cli, or `run_type=\"run\"` in the client.",
message='Ran Deduplication Estimate (not the actual run). Note that these are estimated ranges, actual values may vary. To run the Deduplication process, run `deduplicate-entities` with `--run` in the cli, or `run_type="run"` in the client.',
num_entities=num_entities,
estimated_llm_calls=self._get_str_estimation_output(estimated_llm_calls),
estimated_total_in_out_tokens_in_millions=self._get_str_estimation_output(estimated_total_in_out_tokens_in_millions),
estimated_cost_in_usd=self._get_str_estimation_output(estimated_cost_in_usd),
estimated_total_time_in_minutes=self._get_str_estimation_output(estimated_total_time_in_minutes),
)
estimated_llm_calls=self._get_str_estimation_output(
estimated_llm_calls
),
estimated_total_in_out_tokens_in_millions=self._get_str_estimation_output(
estimated_total_in_out_tokens_in_millions
),
estimated_cost_in_usd=self._get_str_estimation_output(
estimated_cost_in_usd
),
estimated_total_time_in_minutes=self._get_str_estimation_output(
estimated_total_time_in_minutes
),
)
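estimated_llm_calls is now a (low, high) tuple, and the token estimate assumes roughly 1,000 tokens per call at the low end and 5,000 at the high end. A worked example of the range arithmetic with an illustrative entity count:

    num_entities = 2000                                   # illustrative
    estimated_llm_calls = (num_entities, num_entities)    # (low, high)
    estimated_total_in_out_tokens_in_millions = (
        estimated_llm_calls[0] * 1000 / 1_000_000,        # 2.0 M tokens
        estimated_llm_calls[1] * 5000 / 1_000_000,        # 10.0 M tokens
    )
    # The cost and time figures (elided in the hunk above) are then
    # derived from this token range.
    print(estimated_total_in_out_tokens_in_millions)      # (2.0, 10.0)
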
2 changes: 1 addition & 1 deletion py/r2r.toml
@@ -74,7 +74,7 @@ batch_size = 256
max_description_input_length = 65536
generation_config = { model = "openai/gpt-4o-mini" } # and other params, model used for triplet extraction

[kg.kg_deduplication_settings]
[kg.kg_entity_deduplication_settings]
kg_entity_deduplication_type = "by_name"
kg_entity_deduplication_prompt = "graphrag_entity_deduplication"
max_description_input_length = 65536
2 changes: 2 additions & 0 deletions py/shared/abstractions/document.py
@@ -115,6 +115,8 @@ class IngestionStatus(str, Enum):
CHUNKING = "chunking"
EMBEDDING = "embedding"
STORING = "storing"
ENRICHING = "enriching"
ENRICHED = "enriched"

FAILED = "failed"
SUCCESS = "success"
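Among the states visible in this hunk (earlier members above the hunk are omitted), the enriched path in the Hatchet workflow above orders as follows; note that SUCCESS is set before enrichment begins, so an enrichment failure does not undo a successful ingestion:

    from shared.abstractions.document import IngestionStatus

    lifecycle = [
        IngestionStatus.CHUNKING,
        IngestionStatus.EMBEDDING,
        IngestionStatus.STORING,
        IngestionStatus.SUCCESS,
        IngestionStatus.ENRICHING,  # only when enable_chunk_enrichment = true
        IngestionStatus.ENRICHED,
    ]
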
8 changes: 6 additions & 2 deletions py/shared/abstractions/ingestion.py
@@ -21,17 +21,21 @@ class ChunkEnrichmentSettings(R2RSerializable):
Settings for chunk enrichment.
"""

enable_chunk_enrichment: bool = Field(
default=False,
description="Whether to enable chunk enrichment or not",
)
strategies: list[ChunkEnrichmentStrategy] = Field(
default=[],
description="The strategies to use for chunk enrichment. Union of chunks obtained from each strategy is used as context.",
)
forward_chunks: int = Field(
default=3,
description="The number of chunks to include before the current chunk",
description="The number after the current chunk to include in the LLM context while enriching",
)
backward_chunks: int = Field(
default=3,
description="The number of chunks to include after the current chunk",
description="The number of chunks before the current chunk in the LLM context while enriching",
)
semantic_neighbors: int = Field(
default=10, description="The number of semantic neighbors to include"
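Putting the new flag together with the strategies commented out in full.toml, a settings object might be constructed as below. This is a sketch: the enum member names are assumed from the "semantic"/"neighborhood" strings and are not verified against this commit.

    from shared.abstractions.ingestion import (
        ChunkEnrichmentSettings,
        ChunkEnrichmentStrategy,
    )

    settings = ChunkEnrichmentSettings(
        enable_chunk_enrichment=True,
        strategies=[
            ChunkEnrichmentStrategy.SEMANTIC,      # assumed member name
            ChunkEnrichmentStrategy.NEIGHBORHOOD,  # assumed member name
        ],
        forward_chunks=3,   # chunks after the current one in the LLM context
        backward_chunks=3,  # chunks before the current one
        semantic_neighbors=10,
    )
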
1 change: 0 additions & 1 deletion py/shared/abstractions/vector.py
@@ -104,7 +104,6 @@ class VectorTableName(str, Enum):
This enum represents the different tables where we store vectors.
"""


RAW_CHUNKS = "vector"
ENTITIES_DOCUMENT = "entity_embedding"
ENTITIES_COLLECTION = "entity_collection"