From fe42fbd038bb7dcd3a928bdbdb7c6bad073f0af6 Mon Sep 17 00:00:00 2001 From: Shreyas Pimpalgaonkar Date: Sun, 20 Oct 2024 15:58:13 -0700 Subject: [PATCH] Contextual Chunk Enrichment (#1433) * add semantic chunking * working * precommit * pre-commits --- py/core/base/providers/ingestion.py | 5 +- py/core/configs/full.toml | 11 +- .../hatchet/ingestion_workflow.py | 9 + py/core/main/services/ingestion_service.py | 182 ++++++++++++++++++ .../pipes/ingestion/chunk_enrichment_pipe.py | 25 +++ py/core/providers/database/vector.py | 50 +++++ .../prompts/defaults/chunk_enrichment.yaml | 21 ++ py/r2r.toml | 7 +- py/shared/abstractions/ingestion.py | 46 +++++ 9 files changed, 351 insertions(+), 5 deletions(-) create mode 100644 py/core/pipes/ingestion/chunk_enrichment_pipe.py create mode 100644 py/core/providers/prompts/defaults/chunk_enrichment.yaml create mode 100644 py/shared/abstractions/ingestion.py diff --git a/py/core/base/providers/ingestion.py b/py/core/base/providers/ingestion.py index 6ff0be1d4..1bb7be3f3 100644 --- a/py/core/base/providers/ingestion.py +++ b/py/core/base/providers/ingestion.py @@ -1,8 +1,8 @@ import logging from abc import ABC from enum import Enum - from .base import Provider, ProviderConfig +from shared.abstractions.ingestion import ChunkEnrichmentSettings logger = logging.getLogger() @@ -10,6 +10,9 @@ class IngestionConfig(ProviderConfig): provider: str = "r2r" excluded_parsers: list[str] = ["mp4"] + chunk_enrichment_settings: ChunkEnrichmentSettings = ( + ChunkEnrichmentSettings() + ) extra_parsers: dict[str, str] = {} @property diff --git a/py/core/configs/full.toml b/py/core/configs/full.toml index 43b6641e7..b30d841be 100644 --- a/py/core/configs/full.toml +++ b/py/core/configs/full.toml @@ -6,8 +6,17 @@ new_after_n_chars = 512 max_characters = 1_024 combine_under_n_chars = 128 overlap = 256 + [ingestion.extra_parsers] - pdf = "zerox" + pdf = "zerox" + + [ingestion.chunk_enrichment_settings] + strategies = ["semantic", "neighborhood"] + 
forward_chunks = 3 + backward_chunks = 3 + semantic_neighbors = 10 + semantic_similarity_threshold = 0.7 + generation_config = { model = "azure/gpt-4o-mini" } [orchestration] provider = "hatchet" diff --git a/py/core/main/orchestration/hatchet/ingestion_workflow.py b/py/core/main/orchestration/hatchet/ingestion_workflow.py index 84b230581..2f9c765c3 100644 --- a/py/core/main/orchestration/hatchet/ingestion_workflow.py +++ b/py/core/main/orchestration/hatchet/ingestion_workflow.py @@ -141,6 +141,15 @@ async def parse(self, context: Context) -> dict: "is_update" ) + # add contextual chunks + await self.ingestion_service.chunk_enrichment( + document_id=document_info.id, + ) + + # delete original chunks + + # delete original chunk vectors + await self.ingestion_service.finalize_ingestion( document_info, is_update=is_update ) diff --git a/py/core/main/services/ingestion_service.py b/py/core/main/services/ingestion_service.py index bd93b9ddf..d1d2dbf3c 100644 --- a/py/core/main/services/ingestion_service.py +++ b/py/core/main/services/ingestion_service.py @@ -1,6 +1,8 @@ import json import logging +import uuid from datetime import datetime +import asyncio from typing import Any, AsyncGenerator, Optional, Sequence, Union from uuid import UUID @@ -14,7 +16,9 @@ R2RLoggingProvider, RawChunk, RunManager, + Vector, VectorEntry, + VectorType, decrement_version, ) from core.base.api.models import UserResponse @@ -25,6 +29,11 @@ VectorTableName, ) +from shared.abstractions.ingestion import ( + ChunkEnrichmentStrategy, + ChunkEnrichmentSettings, +) +from core.base import IngestionConfig from ..abstractions import R2RAgents, R2RPipelines, R2RPipes, R2RProviders from ..config import R2RConfig from .base import Service @@ -345,6 +354,164 @@ async def ingest_chunks_ingress( return document_info + async def _get_enriched_chunk_text( + self, + chunk_idx: int, + chunk: dict, + document_id: UUID, + chunk_enrichment_settings: ChunkEnrichmentSettings, + document_chunks: list[dict], + 
document_chunks_dict: dict, + ) -> VectorEntry: + # get chunks in context + + context_chunk_ids = [] + for enrichment_strategy in chunk_enrichment_settings.strategies: + if enrichment_strategy == ChunkEnrichmentStrategy.NEIGHBORHOOD: + for prev in range( + 1, chunk_enrichment_settings.backward_chunks + 1 + ): + if chunk_idx - prev >= 0: + context_chunk_ids.append( + document_chunks[chunk_idx - prev]["extraction_id"] + ) + for next in range( + 1, chunk_enrichment_settings.forward_chunks + 1 + ): + if chunk_idx + next < len(document_chunks): + context_chunk_ids.append( + document_chunks[chunk_idx + next]["extraction_id"] + ) + + elif enrichment_strategy == ChunkEnrichmentStrategy.SEMANTIC: + semantic_neighbors = self.providers.database.vector.get_semantic_neighbors( + document_id=document_id, + chunk_id=chunk["extraction_id"], + limit=chunk_enrichment_settings.semantic_neighbors, + similarity_threshold=chunk_enrichment_settings.semantic_similarity_threshold, + ) + + for neighbor in semantic_neighbors: + context_chunk_ids.append(neighbor["extraction_id"]) + + context_chunk_ids = list(set(context_chunk_ids)) + + context_chunk_texts = [] + for context_chunk_id in context_chunk_ids: + context_chunk_texts.append( + document_chunks_dict[context_chunk_id]["text"] + ) + + # enrich chunk + # get prompt and call LLM on it. Then finally embed and store it. + # don't call a pipe here. 
+ # just call the LLM directly + try: + updated_chunk_text = ( + ( + await self.providers.llm.aget_completion( + messages=await self.providers.prompt._get_message_payload( + task_prompt_name="chunk_enrichment", + task_inputs={ + "context_chunks": ( + "\n".join(context_chunk_texts) + ), + "chunk": chunk["text"], + }, + ), + generation_config=chunk_enrichment_settings.generation_config, + ) + ) + .choices[0] + .message.content + ) + + except Exception as e: + updated_chunk_text = chunk["text"] + chunk["metadata"]["chunk_enrichment_status"] = "failed" + else: + if not updated_chunk_text: + updated_chunk_text = chunk["text"] + chunk["metadata"]["chunk_enrichment_status"] = "failed" + else: + chunk["metadata"]["chunk_enrichment_status"] = "success" + + data = await self.providers.embedding.async_get_embedding( + updated_chunk_text or chunk["text"] + ) + + chunk["metadata"]["original_text"] = chunk["text"] + + vector_entry_new = VectorEntry( + extraction_id=uuid.uuid5( + uuid.NAMESPACE_DNS, str(chunk["extraction_id"]) + ), + vector=Vector(data=data, type=VectorType.FIXED, length=len(data)), + document_id=document_id, + user_id=chunk["user_id"], + collection_ids=chunk["collection_ids"], + text=updated_chunk_text or chunk["text"], + metadata=chunk["metadata"], + ) + + return vector_entry_new + + async def chunk_enrichment(self, document_id: UUID) -> int: + # just call the pipe on every chunk of the document + + # TODO: Why is the config not recognized as an ingestionconfig but as a providerconfig? 
+ chunk_enrichment_settings = ( + self.providers.ingestion.config.chunk_enrichment_settings # type: ignore + ) + # get all document_chunks + document_chunks = self.providers.database.vector.get_document_chunks( + document_id=document_id, + )["results"] + + new_vector_entries = [] + document_chunks_dict = { + chunk["extraction_id"]: chunk for chunk in document_chunks + } + + tasks = [] + total_completed = 0 + for chunk_idx, chunk in enumerate(document_chunks): + tasks.append( + self._get_enriched_chunk_text( + chunk_idx, + chunk, + document_id, + chunk_enrichment_settings, + document_chunks, + document_chunks_dict, + ) + ) + + if len(tasks) == 128: + new_vector_entries.extend(await asyncio.gather(*tasks)) + total_completed += 128 + logger.info( + f"Completed {total_completed} out of {len(document_chunks)} chunks for document {document_id}" + ) + tasks = [] + + new_vector_entries.extend(await asyncio.gather(*tasks)) + logger.info( + f"Completed enrichment of {len(document_chunks)} chunks for document {document_id}" + ) + + # delete old chunks from vector db + self.providers.database.vector.delete( + filters={ + "document_id": document_id, + }, + ) + + # embed and store the enriched chunk + self.providers.database.vector.upsert_entries(new_vector_entries) + + return len(new_vector_entries) + class IngestionServiceAdapter: @staticmethod @@ -358,6 +525,21 @@ def _parse_user_data(user_data) -> UserResponse: ) from e return UserResponse.from_dict(user_data) + @staticmethod + def _parse_chunk_enrichment_settings( + chunk_enrichment_settings: dict, + ) -> ChunkEnrichmentSettings: + if isinstance(chunk_enrichment_settings, str): + try: + chunk_enrichment_settings = json.loads( + chunk_enrichment_settings + ) + except json.JSONDecodeError as e: + raise ValueError( + f"Invalid chunk enrichment settings format: {chunk_enrichment_settings}" + ) from e + return ChunkEnrichmentSettings.from_dict(chunk_enrichment_settings) + @staticmethod def parse_ingest_file_input(data: dict) -> 
dict: return { diff --git a/py/core/pipes/ingestion/chunk_enrichment_pipe.py b/py/core/pipes/ingestion/chunk_enrichment_pipe.py new file mode 100644 index 000000000..15a77be8b --- /dev/null +++ b/py/core/pipes/ingestion/chunk_enrichment_pipe.py @@ -0,0 +1,25 @@ +# import asyncio +# import logging +# from typing import Any, AsyncGenerator, Optional, Union + +# from core.base import ( +# AsyncPipe, +# PipeType, +# R2RLoggingProvider, +# ) + + +# class ChunkEnrichmentPipe(AsyncPipe): +# """ +# Enriches chunks using a specified embedding model. +# """ + +# class Input(AsyncPipe.Input): +# message: list[DocumentExtraction] + + +# def __init__(self, config: AsyncPipe.PipeConfig, type: PipeType = PipeType.INGESTOR, pipe_logger: Optional[R2RLoggingProvider] = None): +# super().__init__(config, type, pipe_logger) + +# async def run(self, input: Input, state: Optional[AsyncState] = None, run_manager: Optional[RunManager] = None) -> AsyncGenerator[DocumentExtraction, None]: +# pass diff --git a/py/core/providers/database/vector.py b/py/core/providers/database/vector.py index 0496bd579..fe59d20bb 100644 --- a/py/core/providers/database/vector.py +++ b/py/core/providers/database/vector.py @@ -24,6 +24,7 @@ VectorQuantizationType, VectorTableName, ) +from uuid import UUID from .vecs import Client, Collection, create_client @@ -546,6 +547,55 @@ def get_document_chunks( return {"results": chunks, "total_entries": total} + def get_semantic_neighbors( + self, + document_id: UUID, + chunk_id: UUID, + limit: int = 10, + similarity_threshold: float = 0.5, + ) -> list[dict[str, Any]]: + if self.collection is None: + raise ValueError("Collection is not initialized.") + + table_name = self.collection.table.name + query = text( + f""" + WITH target_vector AS ( + SELECT vec FROM {self.project_name}."{table_name}" + WHERE document_id = :document_id AND extraction_id = :chunk_id + ) + SELECT t.extraction_id, t.text, t.metadata, t.document_id, (t.vec <=> tv.vec) AS similarity + FROM 
{self.project_name}."{table_name}" t, target_vector tv + WHERE (t.vec <=> tv.vec) >= :similarity_threshold + AND t.document_id = :document_id + AND t.extraction_id != :chunk_id + ORDER BY similarity ASC + LIMIT :limit + """ + ) + + with self.vx.Session() as sess: + results = sess.execute( + query, + { + "document_id": document_id, + "chunk_id": chunk_id, + "similarity_threshold": similarity_threshold, + "limit": limit, + }, + ).fetchall() + + return [ + { + "extraction_id": r[0], + "text": r[1], + "metadata": r[2], + "document_id": r[3], + "similarity": r[4], + } + for r in results + ] + def close(self) -> None: if self.vx: with self.vx.Session() as sess: diff --git a/py/core/providers/prompts/defaults/chunk_enrichment.yaml b/py/core/providers/prompts/defaults/chunk_enrichment.yaml new file mode 100644 index 000000000..b932235c9 --- /dev/null +++ b/py/core/providers/prompts/defaults/chunk_enrichment.yaml @@ -0,0 +1,21 @@ +chunk_enrichment: + template: > + ## Task: + + You are given a chunk of text. Your task is to enrich it with additional context from additional chunks that form the context of the chunk. + + Please make sure that the additional context you provide is relevant to the chunk. 
+ + ## Context Chunks: + {context_chunks} + + ## Chunk: + {chunk} + + Note that: + - You will make the chunk extremely precise and useful + + ## Response: + input_types: + chunk: str + context_chunks: str diff --git a/py/r2r.toml b/py/r2r.toml index 55e31642b..6029657ab 100644 --- a/py/r2r.toml +++ b/py/r2r.toml @@ -55,6 +55,7 @@ chunking_strategy = "recursive" chunk_size = 1_024 chunk_overlap = 512 excluded_parsers = ["mp4"] + [ingestion.extra_parsers] pdf = "zerox" @@ -70,17 +71,17 @@ batch_size = 256 fragment_merge_count = 4 # number of fragments to merge into a single extraction max_knowledge_triples = 100 max_description_input_length = 1024 - generation_config = { model = "openai/gpt-4o-mini" } # and other params, model used for triplet extraction + generation_config = { model = "azure/gpt-4o-mini" } # and other params, model used for triplet extraction [kg.kg_enrichment_settings] community_reports_prompt = "graphrag_community_reports" - generation_config = { model = "openai/gpt-4o-mini" } # and other params, model used for node description and graph clustering + generation_config = { model = "azure/gpt-4o-mini" } # and other params, model used for node description and graph clustering leiden_params = {} [kg.kg_search_settings] map_system_prompt = "graphrag_map_system" reduce_system_prompt = "graphrag_reduce_system" - generation_config = { model = "openai/gpt-4o-mini" } + generation_config = { model = "azure/gpt-4o-mini" } [logging] provider = "r2r" diff --git a/py/shared/abstractions/ingestion.py b/py/shared/abstractions/ingestion.py new file mode 100644 index 000000000..9fb646246 --- /dev/null +++ b/py/shared/abstractions/ingestion.py @@ -0,0 +1,46 @@ +# Abstractions for ingestion + +from enum import Enum + +from pydantic import Field + +from .base import R2RSerializable +from .llm import GenerationConfig + + +class ChunkEnrichmentStrategy(str, Enum): + SEMANTIC = "semantic" + NEIGHBORHOOD = "neighborhood" + + def __str__(self) -> str: + return self.value + 
class ChunkEnrichmentSettings(R2RSerializable):
    """
    Settings that control how a chunk is enriched with surrounding context.

    During enrichment, context chunks are gathered for every chunk of a
    document according to ``strategies``; the union of the chunks returned
    by each strategy is handed to the LLM together with the chunk itself.

    Attributes:
        strategies: Enrichment strategies to apply. An empty list means no
            context chunks are collected.
        forward_chunks: How many chunks *following* the current chunk (higher
            index in the document) the neighborhood strategy includes.
        backward_chunks: How many chunks *preceding* the current chunk (lower
            index in the document) the neighborhood strategy includes.
        semantic_neighbors: Maximum number of semantic neighbors requested
            from the vector store by the semantic strategy.
        semantic_similarity_threshold: Threshold passed to the vector store
            when selecting semantic neighbors.
        generation_config: Generation config used for the enrichment LLM call.
    """

    strategies: list[ChunkEnrichmentStrategy] = Field(
        default=[],
        description="The strategies to use for chunk enrichment. Union of chunks obtained from each strategy is used as context.",
    )
    # "forward" means chunks at chunk_idx + 1 .. chunk_idx + forward_chunks,
    # i.e. the chunks that come AFTER the current one in the document.
    forward_chunks: int = Field(
        default=3,
        description="The number of chunks to include after the current chunk",
    )
    # "backward" means chunks at chunk_idx - 1 .. chunk_idx - backward_chunks,
    # i.e. the chunks that come BEFORE the current one in the document.
    backward_chunks: int = Field(
        default=3,
        description="The number of chunks to include before the current chunk",
    )
    semantic_neighbors: int = Field(
        default=10, description="The number of semantic neighbors to include"
    )
    semantic_similarity_threshold: float = Field(
        default=0.7,
        description="The similarity threshold for semantic neighbors",
    )
    generation_config: GenerationConfig = Field(
        default=GenerationConfig(),
        description="The generation config to use for chunk enrichment",
    )