add semantic chunking
shreyaspimpalgaonkar committed Oct 18, 2024
1 parent c9be2c5 commit ddd551c
Showing 9 changed files with 255 additions and 6 deletions.
4 changes: 2 additions & 2 deletions py/core/base/providers/ingestion.py
@@ -1,15 +1,15 @@
import logging
from abc import ABC
from enum import Enum

from .base import Provider, ProviderConfig

from shared.abstractions.ingestion import ChunkEnrichmentSettings
logger = logging.getLogger()


class IngestionConfig(ProviderConfig):
provider: str = "r2r"
excluded_parsers: list[str] = ["mp4"]
chunk_enrichment_settings: ChunkEnrichmentSettings = ChunkEnrichmentSettings()
extra_parsers: dict[str, str] = {}

@property
11 changes: 10 additions & 1 deletion py/core/configs/full.toml
@@ -6,8 +6,17 @@ new_after_n_chars = 512
max_characters = 1_024
combine_under_n_chars = 128
overlap = 256

[ingestion.extra_parsers]
pdf = "zerox"
pdf = "zerox"

[ingestion.chunk_enrichment_settings]
strategies = ["semantic", "neighborhood"]
forward_chunks = 3
backward_chunks = 3
semantic_neighbors = 10
semantic_similarity_threshold = 0.7
generation_config = { model = "azure/gpt-4o-mini" }

[orchestration]
provider = "hatchet"
14 changes: 14 additions & 0 deletions py/core/main/orchestration/hatchet/ingestion_workflow.py
@@ -141,6 +141,20 @@ async def parse(self, context: Context) -> dict:
"is_update"
)


# add contextual chunks
await self.ingestion_service.chunk_enrichment(
document_id=document_info.id,
)

# the original chunks and their vectors are deleted and replaced inside chunk_enrichment

await self.ingestion_service.finalize_ingestion(
document_info, is_update=is_update
)
103 changes: 103 additions & 0 deletions py/core/main/services/ingestion_service.py
@@ -1,5 +1,6 @@
import json
import logging
import uuid
from datetime import datetime
from typing import Any, AsyncGenerator, Optional, Sequence, Union
from uuid import UUID
@@ -25,6 +26,8 @@
VectorTableName,
)

from shared.abstractions.ingestion import ChunkEnrichmentStrategy, ChunkEnrichmentSettings
from core.base import IngestionConfig
from ..abstractions import R2RAgents, R2RPipelines, R2RPipes, R2RProviders
from ..config import R2RConfig
from .base import Service
@@ -344,7 +347,96 @@ async def ingest_chunks_ingress(
)

return document_info

async def chunk_enrichment(self, document_id: UUID) -> int:
# enrich every chunk of the document directly; no pipe is involved

chunk_enrichment_settings = self.providers.ingestion.config.chunk_enrichment_settings

# get all document_chunks
document_chunks = self.providers.database.vector.get_document_chunks(
document_id=document_id,
)['results']

new_vector_entries = []
document_chunks_dict = {chunk['extraction_id']: chunk for chunk in document_chunks}

for chunk_idx, chunk in enumerate(document_chunks):

# get chunks in context
context_chunk_ids = []
for enrichment_strategy in chunk_enrichment_settings.strategies:
if enrichment_strategy == ChunkEnrichmentStrategy.NEIGHBORHOOD:
for prev in range(1, chunk_enrichment_settings.backward_chunks + 1):
if chunk_idx - prev >= 0:
context_chunk_ids.append(document_chunks[chunk_idx - prev]['extraction_id'])
for nxt in range(1, chunk_enrichment_settings.forward_chunks + 1):
if chunk_idx + nxt < len(document_chunks):
context_chunk_ids.append(document_chunks[chunk_idx + nxt]['extraction_id'])

elif enrichment_strategy == ChunkEnrichmentStrategy.SEMANTIC:
semantic_neighbors = self.providers.database.vector.get_semantic_neighbors(
document_id=document_id,
chunk_id=chunk['extraction_id'],
limit=chunk_enrichment_settings.semantic_neighbors,
similarity_threshold=chunk_enrichment_settings.semantic_similarity_threshold,
)
for neighbor in semantic_neighbors:
context_chunk_ids.append(neighbor['extraction_id'])

context_chunk_ids = set(context_chunk_ids)

context_chunk_texts = []
for context_chunk_id in context_chunk_ids:
context_chunk_texts.append(document_chunks_dict[context_chunk_id]['text'])

# enrich the chunk: render the prompt and call the LLM directly (no pipe),
# then embed and store the result below
updated_chunk_text = (
(
await self.providers.llm.aget_completion(
messages=await self.providers.prompt._aget_message_payload(
task_prompt_name='chunk_enrichment',
task_inputs={
"context_chunks": (
"\n".join(context_chunk_texts)
),
"chunk": chunk['text'],
},
),
generation_config=chunk_enrichment_settings.generation_config,
)
)
.choices[0]
.message.content
)

vector_entry_new = VectorEntry(
extraction_id=uuid.uuid5(uuid.NAMESPACE_DNS, str(chunk['extraction_id'])),
vector=await self.providers.embedding.async_get_embedding(updated_chunk_text),
document_id=document_id,
user_id=chunk['user_id'],
collection_ids=chunk['collection_ids'],
text=updated_chunk_text,
metadata=chunk['metadata'],
)

new_vector_entries.append(vector_entry_new)

# delete the original chunks for this document

await self.providers.database.vector.delete(
filters={
"document_id": document_id,
},
)

# store the enriched chunk entries (embeddings were computed above)
await self.providers.database.vector.upsert_entries(new_vector_entries)

return len(new_vector_entries)
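As a quick illustration of the neighborhood strategy used above, here is a standalone sketch (hypothetical helper, not part of the service): with forward_chunks = backward_chunks = 3, the context of chunk i is chunks i-3..i-1 and i+1..i+3, clipped to the document bounds.

def neighborhood_ids(chunks: list[dict], idx: int, backward: int = 3, forward: int = 3) -> list:
    """Collect extraction_ids of the chunks surrounding index idx."""
    ids = []
    for prev in range(1, backward + 1):
        if idx - prev >= 0:
            ids.append(chunks[idx - prev]["extraction_id"])
    for nxt in range(1, forward + 1):
        if idx + nxt < len(chunks):
            ids.append(chunks[idx + nxt]["extraction_id"])
    return ids

# e.g. for a 10-chunk document and idx = 1, this returns the ids of chunks 0, 2, 3 and 4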

class IngestionServiceAdapter:
@staticmethod
@@ -357,6 +449,17 @@ def _parse_user_data(user_data) -> UserResponse:
f"Invalid user data format: {user_data}"
) from e
return UserResponse.from_dict(user_data)

@staticmethod
def _parse_chunk_enrichment_settings(chunk_enrichment_settings: dict) -> ChunkEnrichmentSettings:
if isinstance(chunk_enrichment_settings, str):
try:
chunk_enrichment_settings = json.loads(chunk_enrichment_settings)
except json.JSONDecodeError as e:
raise ValueError(
f"Invalid chunk enrichment settings format: {chunk_enrichment_settings}"
) from e
return ChunkEnrichmentSettings.from_dict(chunk_enrichment_settings)
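A minimal usage sketch for the adapter method above, assuming the settings arrive either as a dict or as a JSON string (as they would from an orchestration payload), and that from_dict accepts plain string strategy values as the str-based enum suggests:

raw = '{"strategies": ["semantic"], "semantic_neighbors": 5}'
settings = IngestionServiceAdapter._parse_chunk_enrichment_settings(raw)
assert settings.semantic_neighbors == 5

settings = IngestionServiceAdapter._parse_chunk_enrichment_settings(
    {"strategies": ["neighborhood"], "forward_chunks": 2}
)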

@staticmethod
def parse_ingest_file_input(data: dict) -> dict:
30 changes: 30 additions & 0 deletions py/core/pipes/ingestion/chunk_enrichment_pipe.py
@@ -0,0 +1,30 @@
# import asyncio
# import logging
# from typing import Any, AsyncGenerator, Optional, Union

# from core.base import (
# AsyncPipe,
# PipeType,
# R2RLoggingProvider,
# )



# class ChunkEnrichmentPipe(AsyncPipe):
# """
# Enriches chunks using a specified embedding model.
# """

# class Input(AsyncPipe.Input):
# message: list[DocumentExtraction]


# def __init__(self, config: AsyncPipe.PipeConfig, type: PipeType = PipeType.INGESTOR, pipe_logger: Optional[R2RLoggingProvider] = None):
# super().__init__(config, type, pipe_logger)

# async def run(self, input: Input, state: Optional[AsyncState] = None, run_manager: Optional[RunManager] = None) -> AsyncGenerator[DocumentExtraction, None]:
# pass




43 changes: 43 additions & 0 deletions py/core/providers/database/vector.py
Expand Up @@ -24,6 +24,7 @@
VectorQuantizationType,
VectorTableName,
)
from uuid import UUID

from .vecs import Client, Collection, create_client

@@ -546,6 +547,46 @@ def get_document_chunks(

return {"results": chunks, "total_entries": total}


def get_semantic_neighbors(
self,
document_id: UUID,
chunk_id: UUID,
limit: int = 10,
similarity_threshold: float = 0.7,
) -> list[dict[str, Any]]:
if self.collection is None:
raise ValueError("Collection is not initialized.")

table_name = self.collection.table.name
query = text(
f"""
WITH target_vector AS (
SELECT vec FROM {self.project_name}."{table_name}"
WHERE document_id = :document_id AND extraction_id = :chunk_id
)
SELECT t.*, (t.vec <=> tv.vec) AS similarity
FROM {self.project_name}."{table_name}" t, target_vector tv
WHERE (t.vec <=> tv.vec) < :similarity_threshold
AND t.document_id != :document_id
ORDER BY similarity ASC
LIMIT :limit
"""
)

with self.vx.Session() as sess:
results = sess.execute(
query,
{
"document_id": document_id,
"chunk_id": chunk_id,
"similarity_threshold": similarity_threshold,
"limit": limit
}
).fetchall()

return [dict(r) for r in results]
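Note on the query above: pgvector's <=> operator returns cosine distance, so the column aliased as similarity is really a distance (0 means identical direction, lower is closer), which is why results are ordered ascending and filtered with < :similarity_threshold. A minimal usage sketch under that reading (the provider and chunk variables are hypothetical):

neighbors = provider.get_semantic_neighbors(
    document_id=document_id,          # UUID of the document being enriched
    chunk_id=chunk["extraction_id"],  # UUID of the chunk whose neighbors we want
    limit=10,
    similarity_threshold=0.7,         # keep rows whose cosine distance is below 0.7
)
for n in neighbors:
    print(n["extraction_id"], n["similarity"])  # "similarity" is the cosine distance here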

def close(self) -> None:
if self.vx:
with self.vx.Session() as sess:
@@ -554,3 +595,5 @@ def close(self) -> None:
sess.bind.dispose() # type: ignore

logger.info("Closed PGVectorDB connection.")


21 changes: 21 additions & 0 deletions py/core/providers/prompts/defaults/chunk_enrichment.yaml
@@ -0,0 +1,21 @@
chunk_enrichment:
template: >
## Task:
You are given a chunk of text. Your task is to enrich it with relevant context drawn from the surrounding chunks provided below.
Please make sure that the additional context you provide is relevant to the chunk.
## Context Chunks:
{context_chunks}
## Chunk:
{chunk}
Note that:
- Make the enriched chunk precise and useful
## Response:
input_types:
chunk: str
context_chunks: str
7 changes: 4 additions & 3 deletions py/r2r.toml
@@ -55,6 +55,7 @@ chunking_strategy = "recursive"
chunk_size = 1_024
chunk_overlap = 512
excluded_parsers = ["mp4"]

[ingestion.extra_parsers]
pdf = "zerox"

@@ -70,17 +71,17 @@ batch_size = 256
fragment_merge_count = 4 # number of fragments to merge into a single extraction
max_knowledge_triples = 100
max_description_input_length = 1024
generation_config = { model = "openai/gpt-4o-mini" } # and other params, model used for triplet extraction
generation_config = { model = "azure/gpt-4o-mini" } # and other params, model used for triplet extraction

[kg.kg_enrichment_settings]
community_reports_prompt = "graphrag_community_reports"
generation_config = { model = "openai/gpt-4o-mini" } # and other params, model used for node description and graph clustering
generation_config = { model = "azure/gpt-4o-mini" } # and other params, model used for node description and graph clustering
leiden_params = {}

[kg.kg_search_settings]
map_system_prompt = "graphrag_map_system"
reduce_system_prompt = "graphrag_reduce_system"
generation_config = { model = "openai/gpt-4o-mini" }
generation_config = { model = "azure/gpt-4o-mini" }

[logging]
provider = "r2r"
28 changes: 28 additions & 0 deletions py/shared/abstractions/ingestion.py
@@ -0,0 +1,28 @@
# Abstractions for ingestion

from enum import Enum

from pydantic import Field

from .base import R2RSerializable
from .llm import GenerationConfig


class ChunkEnrichmentStrategy(str, Enum):
SEMANTIC = "semantic"
NEIGHBORHOOD = "neighborhood"

def __str__(self) -> str:
return self.value

class ChunkEnrichmentSettings(R2RSerializable):
"""
Settings for chunk enrichment.
"""

strategies: list[ChunkEnrichmentStrategy] = Field(default=[], description="The strategies to use for chunk enrichment. Union of chunks obtained from each strategy is used as context.")
forward_chunks: int = Field(default=3, description="The number of chunks to include after the current chunk")
backward_chunks: int = Field(default=3, description="The number of chunks to include before the current chunk")
semantic_neighbors: int = Field(default=10, description="The number of semantic neighbors to include")
semantic_similarity_threshold: float = Field(default=0.7, description="The similarity threshold for semantic neighbors")
generation_config: GenerationConfig = Field(default=GenerationConfig(), description="The generation config to use for chunk enrichment")
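A small sketch of why the str subclassing on the enum matters: strategy values arriving from TOML or JSON as plain strings coerce to, and compare equal with, the enum members (standard Python str-enum behavior).

from shared.abstractions.ingestion import ChunkEnrichmentStrategy

# plain strings from config files map onto the enum members
assert ChunkEnrichmentStrategy("semantic") is ChunkEnrichmentStrategy.SEMANTIC
assert ChunkEnrichmentStrategy.SEMANTIC == "semantic"
assert str(ChunkEnrichmentStrategy.NEIGHBORHOOD) == "neighborhood"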
