Contextual Chunk Enrichment (#1433)
* add semantic chunking

* working

* precommit

* pre-commits
shreyaspimpalgaonkar authored Oct 20, 2024
1 parent 44e6fd8 commit fe42fbd
Showing 9 changed files with 351 additions and 5 deletions.
5 changes: 4 additions & 1 deletion py/core/base/providers/ingestion.py
@@ -1,15 +1,18 @@
import logging
from abc import ABC
from enum import Enum

from .base import Provider, ProviderConfig
from shared.abstractions.ingestion import ChunkEnrichmentSettings

logger = logging.getLogger()


class IngestionConfig(ProviderConfig):
provider: str = "r2r"
excluded_parsers: list[str] = ["mp4"]
chunk_enrichment_settings: ChunkEnrichmentSettings = (
ChunkEnrichmentSettings()
)
extra_parsers: dict[str, str] = {}

@property
11 changes: 10 additions & 1 deletion py/core/configs/full.toml
@@ -6,8 +6,17 @@ new_after_n_chars = 512
max_characters = 1_024
combine_under_n_chars = 128
overlap = 256

[ingestion.extra_parsers]
pdf = "zerox"
pdf = "zerox"

[ingestion.chunk_enrichment_settings]
strategies = ["semantic", "neighborhood"]
forward_chunks = 3
backward_chunks = 3
semantic_neighbors = 10
semantic_similarity_threshold = 0.7
generation_config = { model = "azure/gpt-4o-mini" }

[orchestration]
provider = "hatchet"
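For reference, a plausible sketch of the ChunkEnrichmentSettings model these keys deserialize into. The real definition lives in shared/abstractions/ingestion.py and is not part of this diff, so the field types and defaults below are assumptions mirroring the sample config above:

from enum import Enum
from pydantic import BaseModel

class ChunkEnrichmentStrategy(str, Enum):
    SEMANTIC = "semantic"
    NEIGHBORHOOD = "neighborhood"

class ChunkEnrichmentSettings(BaseModel):
    # defaults assumed to match the sample config above
    strategies: list[ChunkEnrichmentStrategy] = []
    forward_chunks: int = 3
    backward_chunks: int = 3
    semantic_neighbors: int = 10
    semantic_similarity_threshold: float = 0.7
    # likely a GenerationConfig model in R2R; a dict keeps this sketch self-contained
    generation_config: dict = {}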
9 changes: 9 additions & 0 deletions py/core/main/orchestration/hatchet/ingestion_workflow.py
@@ -141,6 +141,15 @@ async def parse(self, context: Context) -> dict:
"is_update"
)

# add contextual chunks
await self.ingestion_service.chunk_enrichment(
document_id=document_info.id,
)

# the original chunk vectors are deleted and replaced inside
# chunk_enrichment, which upserts the enriched entries

await self.ingestion_service.finalize_ingestion(
document_info, is_update=is_update
)
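A possible hardening, sketched here as a hypothetical guard (the diff as written calls chunk_enrichment unconditionally; this snippet is not part of the commit):

# skip enrichment when no strategies are configured
settings = self.ingestion_service.providers.ingestion.config.chunk_enrichment_settings
if settings.strategies:
    await self.ingestion_service.chunk_enrichment(document_id=document_info.id)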
182 changes: 182 additions & 0 deletions py/core/main/services/ingestion_service.py
@@ -1,6 +1,8 @@
import json
import logging
import uuid
from datetime import datetime
import asyncio
from typing import Any, AsyncGenerator, Optional, Sequence, Union
from uuid import UUID

@@ -14,7 +16,9 @@
R2RLoggingProvider,
RawChunk,
RunManager,
Vector,
VectorEntry,
VectorType,
decrement_version,
)
from core.base.api.models import UserResponse
@@ -25,6 +29,11 @@
VectorTableName,
)

from shared.abstractions.ingestion import (
ChunkEnrichmentStrategy,
ChunkEnrichmentSettings,
)
from core.base import IngestionConfig
from ..abstractions import R2RAgents, R2RPipelines, R2RPipes, R2RProviders
from ..config import R2RConfig
from .base import Service
@@ -345,6 +354,164 @@ async def ingest_chunks_ingress(

return document_info

async def _get_enriched_chunk_text(
self,
chunk_idx: int,
chunk: dict,
document_id: UUID,
chunk_enrichment_settings: ChunkEnrichmentSettings,
document_chunks: list[dict],
document_chunks_dict: dict,
) -> VectorEntry:
# get chunks in context

context_chunk_ids = []
for enrichment_strategy in chunk_enrichment_settings.strategies:
if enrichment_strategy == ChunkEnrichmentStrategy.NEIGHBORHOOD:
for prev in range(
1, chunk_enrichment_settings.backward_chunks + 1
):
if chunk_idx - prev >= 0:
context_chunk_ids.append(
document_chunks[chunk_idx - prev]["extraction_id"]
)
for nxt in range(
1, chunk_enrichment_settings.forward_chunks + 1
):
if chunk_idx + nxt < len(document_chunks):
context_chunk_ids.append(
document_chunks[chunk_idx + nxt]["extraction_id"]
)
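# e.g. with backward_chunks = 3 and forward_chunks = 3, chunk_idx 5 in a
# 10-chunk document collects the extraction_ids of chunks 2-4 and 6-8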

elif enrichment_strategy == ChunkEnrichmentStrategy.SEMANTIC:
semantic_neighbors = self.providers.database.vector.get_semantic_neighbors(
document_id=document_id,
chunk_id=chunk["extraction_id"],
limit=chunk_enrichment_settings.semantic_neighbors,
similarity_threshold=chunk_enrichment_settings.semantic_similarity_threshold,
)

for neighbor in semantic_neighbors:
context_chunk_ids.append(neighbor["extraction_id"])

context_chunk_ids = list(set(context_chunk_ids))

context_chunk_texts = []
for context_chunk_id in context_chunk_ids:
context_chunk_texts.append(
document_chunks_dict[context_chunk_id]["text"]
)

# enrich chunk
# get prompt and call LLM on it. Then finally embed and store it.
# don't call a pipe here.
# just call the LLM directly
try:
updated_chunk_text = (
(
await self.providers.llm.aget_completion(
messages=await self.providers.prompt._get_message_payload(
task_prompt_name="chunk_enrichment",
task_inputs={
"context_chunks": (
"\n".join(context_chunk_texts)
),
"chunk": chunk["text"],
},
),
generation_config=chunk_enrichment_settings.generation_config,
)
)
.choices[0]
.message.content
)

except Exception as e:
# fall back to the original text so a single failed LLM call
# does not abort ingestion of the whole document
logger.warning(
f"Chunk enrichment failed for chunk {chunk['extraction_id']}: {e}"
)
updated_chunk_text = chunk["text"]
chunk["metadata"]["chunk_enrichment_status"] = "failed"
else:
if not updated_chunk_text:
updated_chunk_text = chunk["text"]
chunk["metadata"]["chunk_enrichment_status"] = "failed"
else:
chunk["metadata"]["chunk_enrichment_status"] = "success"

data = await self.providers.embedding.async_get_embedding(
updated_chunk_text or chunk["text"]
)

chunk["metadata"]["original_text"] = chunk["text"]

vector_entry_new = VectorEntry(
extraction_id=uuid.uuid5(
uuid.NAMESPACE_DNS, str(chunk["extraction_id"])
),
vector=Vector(data=data, type=VectorType.FIXED, length=len(data)),
document_id=document_id,
user_id=chunk["user_id"],
collection_ids=chunk["collection_ids"],
text=updated_chunk_text or chunk["text"],
metadata=chunk["metadata"],
)

return vector_entry_new
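# Note on the uuid5 call above: deriving the enriched entry's extraction_id
# deterministically from the source chunk's id means re-running enrichment
# regenerates the same id and upserts over the prior row instead of
# accumulating duplicates:
#
#   >>> import uuid
#   >>> src = uuid.UUID(int=1)
#   >>> uuid.uuid5(uuid.NAMESPACE_DNS, str(src)) == uuid.uuid5(uuid.NAMESPACE_DNS, str(src))
#   True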

async def chunk_enrichment(self, document_id: UUID) -> int:
# just call the pipe on every chunk of the document

# TODO: Why is the config typed as a ProviderConfig rather than an IngestionConfig here?
chunk_enrichment_settings = (
self.providers.ingestion.config.chunk_enrichment_settings # type: ignore
)
# get all document_chunks
document_chunks = self.providers.database.vector.get_document_chunks(
document_id=document_id,
)["results"]

new_vector_entries = []
document_chunks_dict = {
chunk["extraction_id"]: chunk for chunk in document_chunks
}

tasks = []
total_completed = 0
for chunk_idx, chunk in enumerate(document_chunks):
tasks.append(
self._get_enriched_chunk_text(
chunk_idx,
chunk,
document_id,
chunk_enrichment_settings,
document_chunks,
document_chunks_dict,
)
)

if len(tasks) == 128:
new_vector_entries.extend(await asyncio.gather(*tasks))
total_completed += 128
logger.info(
f"Completed {total_completed} out of {len(document_chunks)} chunks for document {document_id}"
)
tasks = []

new_vector_entries.extend(await asyncio.gather(*tasks))
logger.info(
f"Completed enrichment of {len(document_chunks)} chunks for document {document_id}"
)

# delete old chunks from vector db
self.providers.database.vector.delete(
filters={
"document_id": document_id,
},
)

# embed and store the enriched chunk
self.providers.database.vector.upsert_entries(new_vector_entries)

return len(new_vector_entries)
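# The loop above bounds concurrency by gathering enrichment coroutines in
# batches of 128, so at most 128 LLM calls are in flight at once. A minimal
# standalone sketch of the same pattern (hypothetical helper, not part of
# this diff):
#
#   import asyncio
#
#   async def gather_in_batches(coros, batch_size=128):
#       results = []
#       for i in range(0, len(coros), batch_size):
#           results.extend(await asyncio.gather(*coros[i:i + batch_size]))
#       return results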


class IngestionServiceAdapter:
@staticmethod
@@ -358,6 +525,21 @@ def _parse_user_data(user_data) -> UserResponse:
) from e
return UserResponse.from_dict(user_data)

@staticmethod
def _parse_chunk_enrichment_settings(
chunk_enrichment_settings: dict,
) -> ChunkEnrichmentSettings:
if isinstance(chunk_enrichment_settings, str):
try:
chunk_enrichment_settings = json.loads(
chunk_enrichment_settings
)
except json.JSONDecodeError as e:
raise ValueError(
f"Invalid chunk enrichment settings format: {chunk_enrichment_settings}"
) from e
return ChunkEnrichmentSettings.from_dict(chunk_enrichment_settings)
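# e.g. a JSON string and the equivalent dict parse to the same settings:
#   _parse_chunk_enrichment_settings('{"forward_chunks": 3}')
#   _parse_chunk_enrichment_settings({"forward_chunks": 3})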

@staticmethod
def parse_ingest_file_input(data: dict) -> dict:
return {
25 changes: 25 additions & 0 deletions py/core/pipes/ingestion/chunk_enrichment_pipe.py
@@ -0,0 +1,25 @@
# import asyncio
# import logging
# from typing import Any, AsyncGenerator, Optional, Union

# from core.base import (
# AsyncPipe,
# PipeType,
# R2RLoggingProvider,
# )


# class ChunkEnrichmentPipe(AsyncPipe):
# """
# Enriches chunks using a specified embedding model.
# """

# class Input(AsyncPipe.Input):
# message: list[DocumentExtraction]


# def __init__(self, config: AsyncPipe.PipeConfig, type: PipeType = PipeType.INGESTOR, pipe_logger: Optional[R2RLoggingProvider] = None):
# super().__init__(config, type, pipe_logger)

# async def run(self, input: Input, state: Optional[AsyncState] = None, run_manager: Optional[RunManager] = None) -> AsyncGenerator[DocumentExtraction, None]:
# pass
50 changes: 50 additions & 0 deletions py/core/providers/database/vector.py
Expand Up @@ -24,6 +24,7 @@
VectorQuantizationType,
VectorTableName,
)
from uuid import UUID

from .vecs import Client, Collection, create_client

@@ -546,6 +547,55 @@ def get_document_chunks(

return {"results": chunks, "total_entries": total}

def get_semantic_neighbors(
self,
document_id: UUID,
chunk_id: UUID,
limit: int = 10,
similarity_threshold: float = 0.5,
) -> list[dict[str, Any]]:
if self.collection is None:
raise ValueError("Collection is not initialized.")

table_name = self.collection.table.name
query = text(
f"""
WITH target_vector AS (
SELECT vec FROM {self.project_name}."{table_name}"
WHERE document_id = :document_id AND extraction_id = :chunk_id
)
-- <=> returns cosine distance, so convert to similarity = 1 - distance
SELECT t.extraction_id, t.text, t.metadata, t.document_id,
(1 - (t.vec <=> tv.vec)) AS similarity
FROM {self.project_name}."{table_name}" t, target_vector tv
WHERE (1 - (t.vec <=> tv.vec)) >= :similarity_threshold
AND t.document_id = :document_id
AND t.extraction_id != :chunk_id
ORDER BY similarity DESC
LIMIT :limit
"""
)

with self.vx.Session() as sess:
results = sess.execute(
query,
{
"document_id": document_id,
"chunk_id": chunk_id,
"similarity_threshold": similarity_threshold,
"limit": limit,
},
).fetchall()

return [
{
"extraction_id": r[0],
"text": r[1],
"metadata": r[2],
"document_id": r[3],
"similarity": r[4],
}
for r in results
]
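# Note: pgvector's <=> operator returns cosine distance, so the query above
# converts it to similarity as 1 - distance: identical vectors score 1.0,
# orthogonal vectors 0.0.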

def close(self) -> None:
if self.vx:
with self.vx.Session() as sess:
21 changes: 21 additions & 0 deletions py/core/providers/prompts/defaults/chunk_enrichment.yaml
@@ -0,0 +1,21 @@
chunk_enrichment:
  template: >
    ## Task:

    You are given a chunk of text along with context chunks drawn from the
    same document. Enrich the chunk with relevant information from the
    context chunks, making the enriched chunk precise and useful.

    ## Context Chunks:

    {context_chunks}

    ## Chunk:

    {chunk}

    ## Response:
  input_types:
    chunk: str
    context_chunks: str
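# At runtime the prompt provider substitutes the declared input_types into
# the template; roughly (a sketch, not the provider's actual code):
#   template.format(context_chunks="\n".join(context_texts), chunk=chunk_text)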