From ddd551c03e0b3f99d96728a60629fbf1b768d372 Mon Sep 17 00:00:00 2001 From: Shreyas Pimpalgaonkar Date: Fri, 18 Oct 2024 16:54:20 -0700 Subject: [PATCH] add semantic chunking --- py/core/base/providers/ingestion.py | 4 +- py/core/configs/full.toml | 11 +- .../hatchet/ingestion_workflow.py | 14 +++ py/core/main/services/ingestion_service.py | 103 ++++++++++++++++++ .../pipes/ingestion/chunk_enrichment_pipe.py | 30 +++++ py/core/providers/database/vector.py | 43 ++++++++ .../prompts/defaults/chunk_enrichment.yaml | 21 ++++ py/r2r.toml | 7 +- py/shared/abstractions/ingestion.py | 28 +++++ 9 files changed, 255 insertions(+), 6 deletions(-) create mode 100644 py/core/pipes/ingestion/chunk_enrichment_pipe.py create mode 100644 py/core/providers/prompts/defaults/chunk_enrichment.yaml create mode 100644 py/shared/abstractions/ingestion.py diff --git a/py/core/base/providers/ingestion.py b/py/core/base/providers/ingestion.py index 6ff0be1d4..f0d6e763a 100644 --- a/py/core/base/providers/ingestion.py +++ b/py/core/base/providers/ingestion.py @@ -1,15 +1,15 @@ import logging from abc import ABC from enum import Enum - from .base import Provider, ProviderConfig - +from shared.abstractions.ingestion import ChunkEnrichmentSettings logger = logging.getLogger() class IngestionConfig(ProviderConfig): provider: str = "r2r" excluded_parsers: list[str] = ["mp4"] + chunk_enrichment_settings: ChunkEnrichmentSettings = ChunkEnrichmentSettings() extra_parsers: dict[str, str] = {} @property diff --git a/py/core/configs/full.toml b/py/core/configs/full.toml index 43b6641e7..b30d841be 100644 --- a/py/core/configs/full.toml +++ b/py/core/configs/full.toml @@ -6,8 +6,17 @@ new_after_n_chars = 512 max_characters = 1_024 combine_under_n_chars = 128 overlap = 256 + [ingestion.extra_parsers] - pdf = "zerox" + pdf = "zerox" + + [ingestion.chunk_enrichment_settings] + strategies = ["semantic", "neighborhood"] + forward_chunks = 3 + backward_chunks = 3 + semantic_neighbors = 10 + 
semantic_similarity_threshold = 0.7 + generation_config = { model = "azure/gpt-4o-mini" } [orchestration] provider = "hatchet" diff --git a/py/core/main/orchestration/hatchet/ingestion_workflow.py b/py/core/main/orchestration/hatchet/ingestion_workflow.py index 84b230581..e138a73d9 100644 --- a/py/core/main/orchestration/hatchet/ingestion_workflow.py +++ b/py/core/main/orchestration/hatchet/ingestion_workflow.py @@ -141,6 +141,20 @@ async def parse(self, context: Context) -> dict: "is_update" ) + + # add contextual chunks + await self.ingestion_service.chunk_enrichment( + document_id = document_info.id, + ) + + + # delete original chunks + + + # delete original chunk vectors + + + await self.ingestion_service.finalize_ingestion( document_info, is_update=is_update ) diff --git a/py/core/main/services/ingestion_service.py b/py/core/main/services/ingestion_service.py index bd93b9ddf..0448c31c9 100644 --- a/py/core/main/services/ingestion_service.py +++ b/py/core/main/services/ingestion_service.py @@ -1,5 +1,6 @@ import json import logging +import uuid from datetime import datetime from typing import Any, AsyncGenerator, Optional, Sequence, Union from uuid import UUID @@ -25,6 +26,8 @@ VectorTableName, ) +from shared.abstractions.ingestion import ChunkEnrichmentStrategy, ChunkEnrichmentSettings +from core.base import IngestionConfig from ..abstractions import R2RAgents, R2RPipelines, R2RPipes, R2RProviders from ..config import R2RConfig from .base import Service @@ -344,7 +347,96 @@ async def ingest_chunks_ingress( ) return document_info + + async def chunk_enrichment(self, document_id: UUID) -> int: + """Rewrite every stored chunk of the document with LLM-added context, re-embed it, and return the number of enriched entries.""" + + chunk_enrichment_settings = self.providers.ingestion.config.chunk_enrichment_settings + + # get all document_chunks + document_chunks = self.providers.database.vector.get_document_chunks( + document_id=document_id, + )['results'] + + new_vector_entries = [] + document_chunks_dict = {chunk['extraction_id']: chunk 
for chunk in document_chunks} + + for chunk_idx, chunk in enumerate(document_chunks): + + # get chunks in context + context_chunk_ids = [] + for enrichment_strategy in chunk_enrichment_settings.strategies: + if enrichment_strategy == ChunkEnrichmentStrategy.NEIGHBORHOOD: + for prev in range(1, chunk_enrichment_settings.backward_chunks + 1): + if chunk_idx - prev >= 0: + context_chunk_ids.append(document_chunks[chunk_idx - prev]['extraction_id']) + for nxt in range(1, chunk_enrichment_settings.forward_chunks + 1): + if chunk_idx + nxt < len(document_chunks): + context_chunk_ids.append(document_chunks[chunk_idx + nxt]['extraction_id']) + + elif enrichment_strategy == ChunkEnrichmentStrategy.SEMANTIC: + semantic_neighbors = self.providers.database.vector.get_semantic_neighbors( + document_id=document_id, + chunk_id=chunk['extraction_id'], + limit=chunk_enrichment_settings.semantic_neighbors, + similarity_threshold=chunk_enrichment_settings.semantic_similarity_threshold, + ) + for neighbor in semantic_neighbors: + context_chunk_ids.append(neighbor['extraction_id']) + + context_chunk_ids = list(dict.fromkeys(context_chunk_ids)) # de-duplicate but keep a stable order for the prompt + + context_chunk_texts = [ # only this document's chunks are loaded above; skip ids that cannot be resolved + document_chunks_dict[context_chunk_id]['text'] for context_chunk_id in context_chunk_ids if context_chunk_id in document_chunks_dict + ] + + # enrich chunk + # get prompt and call LLM on it. Then finally embed and store it. + # don't call a pipe here. 
+ # just call the LLM directly + updated_chunk_text = ( + ( + await self.providers.llm.aget_completion( + messages=await self.providers.prompt._aget_message_payload( + task_prompt_name='chunk_enrichment', + task_inputs={ + "context_chunks": ( + "\n".join(context_chunk_texts) + ), + "chunk": chunk['text'], + }, + ), + generation_config=chunk_enrichment_settings.generation_config, + ) + ) + .choices[0] + .message.content + ) + + vector_entry_new = VectorEntry( + extraction_id=uuid.uuid5(uuid.NAMESPACE_DNS, str(chunk['extraction_id'])), + vector = (await self.providers.embedding.async_get_embedding(updated_chunk_text)), + document_id=document_id, + user_id=chunk['user_id'], + collection_ids=chunk['collection_ids'], + text=updated_chunk_text, + metadata=chunk['metadata'], + ) + + new_vector_entries.append(vector_entry_new) + + # drop the document's existing vectors; the enriched entries below replace them + + await self.providers.database.vector.delete( + filters = { + "document_id": document_id, + }, + ) + + # store the enriched entries (already embedded above) + await self.providers.database.vector.upsert_entries(new_vector_entries) + return len(new_vector_entries) class IngestionServiceAdapter: @staticmethod @@ -357,6 +449,17 @@ def _parse_user_data(user_data) -> UserResponse: f"Invalid user data format: {user_data}" ) from e return UserResponse.from_dict(user_data) + + @staticmethod + def _parse_chunk_enrichment_settings(chunk_enrichment_settings: Union[dict, str]) -> ChunkEnrichmentSettings: + if isinstance(chunk_enrichment_settings, str): + try: + chunk_enrichment_settings = json.loads(chunk_enrichment_settings) + except json.JSONDecodeError as e: + raise ValueError( + f"Invalid chunk enrichment settings format: {chunk_enrichment_settings}" + ) from e + return ChunkEnrichmentSettings.from_dict(chunk_enrichment_settings) @staticmethod def parse_ingest_file_input(data: dict) -> dict: diff --git a/py/core/pipes/ingestion/chunk_enrichment_pipe.py b/py/core/pipes/ingestion/chunk_enrichment_pipe.py new file mode 100644 index 
000000000..ff0a34d98 --- /dev/null +++ b/py/core/pipes/ingestion/chunk_enrichment_pipe.py @@ -0,0 +1,30 @@ +# import asyncio +# import logging +# from typing import Any, AsyncGenerator, Optional, Union + +# from core.base import ( +# AsyncPipe, +# PipeType, +# R2RLoggingProvider, +# ) + + + +# class ChunkEnrichmentPipe(AsyncPipe): +# """ +# Enriches chunks using a specified embedding model. +# """ + +# class Input(AsyncPipe.Input): +# message: list[DocumentExtraction] + + +# def __init__(self, config: AsyncPipe.PipeConfig, type: PipeType = PipeType.INGESTOR, pipe_logger: Optional[R2RLoggingProvider] = None): +# super().__init__(config, type, pipe_logger) + +# async def run(self, input: Input, state: Optional[AsyncState] = None, run_manager: Optional[RunManager] = None) -> AsyncGenerator[DocumentExtraction, None]: +# pass + + + + diff --git a/py/core/providers/database/vector.py b/py/core/providers/database/vector.py index 0496bd579..4dd612860 100644 --- a/py/core/providers/database/vector.py +++ b/py/core/providers/database/vector.py @@ -24,6 +24,7 @@ VectorQuantizationType, VectorTableName, ) +from uuid import UUID from .vecs import Client, Collection, create_client @@ -546,6 +547,46 @@ def get_document_chunks( return {"results": chunks, "total_entries": total} + # Return up to `limit` other chunks of the same document whose `<=>` value against the target chunk's vector is below `similarity_threshold`, smallest first. + def get_semantic_neighbors( + self, + document_id: UUID, + chunk_id: UUID, + limit: int = 10, + similarity_threshold: float = 0.7, + ) -> list[dict[str, Any]]: + if self.collection is None: + raise ValueError("Collection is not initialized.") + + table_name = self.collection.table.name + query = text( + f""" + WITH target_vector AS ( + SELECT vec FROM {self.project_name}."{table_name}" + WHERE document_id = :document_id AND extraction_id = :chunk_id + ) + SELECT t.*, (t.vec <=> tv.vec) AS similarity + FROM {self.project_name}."{table_name}" t, target_vector tv + WHERE (t.vec <=> tv.vec) < :similarity_threshold + AND t.document_id = :document_id AND t.extraction_id != :chunk_id + ORDER BY similarity ASC + LIMIT :limit + """ + ) + + 
with self.vx.Session() as sess: + results = sess.execute( + query, + { + "document_id": document_id, + "chunk_id": chunk_id, + "similarity_threshold": similarity_threshold, + "limit": limit + } + ).fetchall() + # Row objects are tuple-like; ._mapping is the column-keyed view that dict() can consume. + return [dict(r._mapping) for r in results] + def close(self) -> None: if self.vx: with self.vx.Session() as sess: @@ -554,3 +595,5 @@ def close(self) -> None: sess.bind.dispose() # type: ignore logger.info("Closed PGVectorDB connection.") + + diff --git a/py/core/providers/prompts/defaults/chunk_enrichment.yaml b/py/core/providers/prompts/defaults/chunk_enrichment.yaml new file mode 100644 index 000000000..91caf6981 --- /dev/null +++ b/py/core/providers/prompts/defaults/chunk_enrichment.yaml @@ -0,0 +1,21 @@ +chunk_enrichment: + template: > + ## Task: + + You are given a chunk of text. Your task is to enrich it with additional context from additional chunks that form the context of the chunk. + + Please make sure that the additional context you provide is relevant to the chunk. + + ## Context Chunks: + {context_chunks} + + ## Chunk: + {chunk} + + Note that: + - You will make the chunk extremely precise and useful + + ## Response: + input_types: + chunk: str + context_chunks: str diff --git a/py/r2r.toml b/py/r2r.toml index 55e31642b..6029657ab 100644 --- a/py/r2r.toml +++ b/py/r2r.toml @@ -55,6 +55,7 @@ chunking_strategy = "recursive" chunk_size = 1_024 chunk_overlap = 512 excluded_parsers = ["mp4"] + [ingestion.extra_parsers] pdf = "zerox" @@ -70,17 +71,17 @@ batch_size = 256 fragment_merge_count = 4 # number of fragments to merge into a single extraction max_knowledge_triples = 100 max_description_input_length = 1024 - generation_config = { model = "openai/gpt-4o-mini" } # and other params, model used for triplet extraction + generation_config = { model = "azure/gpt-4o-mini" } # and other params, model used for triplet extraction [kg.kg_enrichment_settings] community_reports_prompt = "graphrag_community_reports" - generation_config = { model = 
"openai/gpt-4o-mini" } # and other params, model used for node description and graph clustering + generation_config = { model = "azure/gpt-4o-mini" } # and other params, model used for node description and graph clustering leiden_params = {} [kg.kg_search_settings] map_system_prompt = "graphrag_map_system" reduce_system_prompt = "graphrag_reduce_system" - generation_config = { model = "openai/gpt-4o-mini" } + generation_config = { model = "azure/gpt-4o-mini" } [logging] provider = "r2r" diff --git a/py/shared/abstractions/ingestion.py b/py/shared/abstractions/ingestion.py new file mode 100644 index 000000000..62fffc4a5 --- /dev/null +++ b/py/shared/abstractions/ingestion.py @@ -0,0 +1,28 @@ +# Abstractions for ingestion + +from enum import Enum + +from pydantic import Field + +from .base import R2RSerializable +from .llm import GenerationConfig + + +class ChunkEnrichmentStrategy(str, Enum): + SEMANTIC = "semantic" + NEIGHBORHOOD = "neighborhood" + + def __str__(self) -> str: + return self.value + +class ChunkEnrichmentSettings(R2RSerializable): + """ + Settings for chunk enrichment: which strategies gather context chunks, how many to gather, and which generation config rewrites the chunk. + """ + + strategies: list[ChunkEnrichmentStrategy] = Field(default_factory=list, description="The strategies to use for chunk enrichment. Union of chunks obtained from each strategy is used as context.") + forward_chunks: int = Field(default=3, description="The number of chunks to include after the current chunk") + backward_chunks: int = Field(default=3, description="The number of chunks to include before the current chunk") + semantic_neighbors: int = Field(default=10, description="The number of semantic neighbors to include") + semantic_similarity_threshold: float = Field(default=0.7, description="The similarity threshold for semantic neighbors") + generation_config: GenerationConfig = Field(default=GenerationConfig(), description="The generation config to use for chunk enrichment")