Commit

add tests
shreyaspimpalgaonkar committed Oct 22, 2024
1 parent f31bbc6 commit 0bdd504
Showing 8 changed files with 366 additions and 17 deletions.
10 changes: 0 additions & 10 deletions py/core/configs/full.toml
@@ -10,16 +10,6 @@ overlap = 256
 [ingestion.extra_parsers]
 pdf = "zerox"
 
-# turned off by default
-# [ingestion.chunk_enrichment_settings]
-# enable_chunk_enrichment = true
-# strategies = ["semantic", "neighborhood"]
-# forward_chunks = 3
-# backward_chunks = 3
-# semantic_neighbors = 10
-# semantic_similarity_threshold = 0.7
-# generation_config = { model = "openai/gpt-4o-mini" }
-
 [orchestration]
 provider = "hatchet"
 kg_creation_concurrency_limit = 32
14 changes: 9 additions & 5 deletions py/core/main/orchestration/hatchet/ingestion_workflow.py
@@ -167,12 +167,16 @@ async def parse(self, context: Context) -> dict:
             status=IngestionStatus.SUCCESS,
         )
 
-        # add contextual chunks
-        if getattr(
+        # add contextual chunks
+        logger.info(f"service.providers.ingestion.config: {service.providers.ingestion.config}")
+
+        chunk_enrichment_settings = getattr(
             service.providers.ingestion.config,
-            "enable_chunk_enrichment",
-            False,
-        ):
+            "chunk_enrichment_settings",
+            None,
+        )
+
+        if chunk_enrichment_settings and getattr(chunk_enrichment_settings, "enable_chunk_enrichment", False):
 
             logger.info("Enriching document with contextual chunks")
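The rewritten guard has to tolerate two distinct cases: a config that predates chunk enrichment and has no `chunk_enrichment_settings` attribute at all, and one where the settings object exists but the flag is left at its default of off. A minimal sketch of that double-`getattr` pattern, using `SimpleNamespace` stand-ins rather than the real `IngestionConfig`:

```python
# Sketch of the double-getattr guard above; the SimpleNamespace configs are
# stand-ins for illustration, not R2R's actual config classes.
from types import SimpleNamespace

def enrichment_enabled(config) -> bool:
    # Older configs may lack the attribute entirely -> default to None.
    settings = getattr(config, "chunk_enrichment_settings", None)
    # Settings may also exist with the feature flag off (the default).
    return bool(settings) and getattr(settings, "enable_chunk_enrichment", False)

legacy_config = SimpleNamespace()  # no enrichment settings at all
opted_in_config = SimpleNamespace(
    chunk_enrichment_settings=SimpleNamespace(enable_chunk_enrichment=True)
)
assert not enrichment_enabled(legacy_config)
assert enrichment_enabled(opted_in_config)
```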
1 change: 0 additions & 1 deletion py/core/providers/database/vector.py
@@ -933,7 +933,6 @@ async def get_semantic_neighbors(
             ORDER BY similarity ASC
             LIMIT $4
         """
-
         results = await self.connection_manager.fetch_query(
             query,
             (str(document_id), str(chunk_id), similarity_threshold, limit),
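For context, the `$1..$4` placeholders in the query bind positionally to the tuple passed to `fetch_query`. A small runnable sketch of that binding, with a stub standing in for the real asyncpg-backed connection manager (the stub is an assumption, not R2R's API):

```python
# Positional binding sketch: $1..$4 map to params[0]..params[3] in order.
# StubConnectionManager is a stand-in for the real connection manager.
import asyncio

class StubConnectionManager:
    async def fetch_query(self, query: str, params: tuple):
        # asyncpg would execute: await conn.fetch(query, *params)
        return [dict(zip(("$1", "$2", "$3", "$4"), params))]

async def main():
    cm = StubConnectionManager()
    rows = await cm.fetch_query(
        "SELECT ... ORDER BY similarity ASC LIMIT $4",
        ("document-uuid", "chunk-uuid", 0.7, 10),
    )
    print(rows)  # [{'$1': 'document-uuid', '$2': 'chunk-uuid', '$3': 0.7, '$4': 10}]

asyncio.run(main())
```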
9 changes: 9 additions & 0 deletions py/r2r.toml
@@ -57,6 +57,15 @@ chunk_size = 1_024
chunk_overlap = 512
excluded_parsers = ["mp4"]

[ingestion.chunk_enrichment_settings]
enable_chunk_enrichment = true # off by default; enabled here
strategies = ["semantic", "neighborhood"]
forward_chunks = 3
backward_chunks = 3
semantic_neighbors = 10
semantic_similarity_threshold = 0.7
generation_config = { model = "openai/gpt-4o-mini" }

[ingestion.extra_parsers]
pdf = "zerox"

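The keys in the new `[ingestion.chunk_enrichment_settings]` table mirror the fields of `ChunkEnrichmentSettings`. A hedged sketch of loading such a table, using a stand-in dataclass rather than the actual `shared.abstractions.ingestion.ChunkEnrichmentSettings`:

```python
# Sketch: loading [ingestion.chunk_enrichment_settings] into a settings object.
# The dataclass is a stand-in; field names and defaults mirror the TOML keys.
import tomllib  # Python 3.11+
from dataclasses import dataclass, field

@dataclass
class ChunkEnrichmentSettingsSketch:
    enable_chunk_enrichment: bool = False  # feature is off unless opted in
    strategies: list[str] = field(default_factory=list)
    forward_chunks: int = 3
    backward_chunks: int = 3
    semantic_neighbors: int = 10
    semantic_similarity_threshold: float = 0.7
    generation_config: dict = field(default_factory=dict)

raw = tomllib.loads("""
[ingestion.chunk_enrichment_settings]
enable_chunk_enrichment = true
strategies = ["semantic", "neighborhood"]
forward_chunks = 3
backward_chunks = 3
semantic_neighbors = 10
semantic_similarity_threshold = 0.7
generation_config = { model = "openai/gpt-4o-mini" }
""")["ingestion"]["chunk_enrichment_settings"]

settings = ChunkEnrichmentSettingsSketch(**raw)
assert settings.enable_chunk_enrichment and settings.semantic_neighbors == 10
```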
17 changes: 17 additions & 0 deletions py/tests/conftest.py
@@ -28,6 +28,7 @@
    KGEnrichmentStatus,
    KGExtractionStatus,
)
from core.base import OrchestrationConfig
from core.providers import (
    BCryptProvider,
    LiteCompletionProvider,
@@ -179,6 +180,19 @@ def litellm_provider(app_config):
    return LiteLLMEmbeddingProvider(config)



# Embeddings
@pytest.fixture
def litellm_provider_128(app_config):
    config = EmbeddingConfig(
        provider="litellm",
        base_model="text-embedding-3-small",
        base_dimension=128,
        app=app_config,
    )
    return LiteLLMEmbeddingProvider(config)


# File Provider
@pytest.fixture(scope="function")
def file_config(app_config):
@@ -275,6 +289,9 @@ async def postgres_kg_provider(
def prompt_config(app_config):
    return PromptConfig(provider="r2r", app=app_config)

@pytest.fixture(scope="function")
def orchestration_config(app_config):
    return OrchestrationConfig(provider="simple", app=app_config)

@pytest.fixture(scope="function")
async def r2r_prompt_provider(prompt_config, temporary_postgres_db_provider):
173 changes: 173 additions & 0 deletions py/tests/core/providers/ingestion/test_contextual_embedding.py
@@ -0,0 +1,173 @@
import pytest
from uuid import UUID
from datetime import datetime
from shared.api.models.auth.responses import UserResponse
from core.base import RawChunk, DocumentType, IngestionStatus, VectorEntry
from shared.abstractions.ingestion import ChunkEnrichmentStrategy, ChunkEnrichmentSettings
import subprocess
from core.main.services.ingestion_service import IngestionService, IngestionConfig
from core.main.abstractions import R2RProviders
from core.providers.orchestration import SimpleOrchestrationProvider
from core.providers.ingestion import R2RIngestionConfig, R2RIngestionProvider

from core.base import Vector, VectorType
import random

@pytest.fixture
def sample_document_id():
    return UUID('12345678-1234-5678-1234-567812345678')

@pytest.fixture
def sample_user():
    return UserResponse(
        id=UUID('87654321-8765-4321-8765-432187654321'),
        email="[email protected]",
        is_superuser=True,
    )

@pytest.fixture
def collection_ids():
    return [UUID('12345678-1234-5678-1234-567812345678')]

@pytest.fixture
def extraction_ids():
    return [
        UUID('fce959df-46a2-4983-aa8b-dd1f93777e02'),
        UUID('9a85269c-84cd-4dff-bf21-7bd09974f668'),
        UUID('4b1199b2-2b96-4198-9ded-954c900a23dd'),
    ]

@pytest.fixture
def sample_chunks(sample_document_id, sample_user, collection_ids, extraction_ids):
    return [
        VectorEntry(
            extraction_id=extraction_ids[0],
            document_id=sample_document_id,
            user_id=sample_user.id,
            collection_ids=collection_ids,
            vector=Vector(data=[random.random() for _ in range(128)], type=VectorType.FIXED, length=128),
            text="This is the first chunk of text.",
            metadata={"chunk_order": 0}
        ),
        VectorEntry(
            extraction_id=extraction_ids[1],
            document_id=sample_document_id,
            user_id=sample_user.id,
            collection_ids=collection_ids,
            vector=Vector(data=[random.random() for _ in range(128)], type=VectorType.FIXED, length=128),
            text="This is the second chunk with different content.",
            metadata={"chunk_order": 1}
        ),
        VectorEntry(
            extraction_id=extraction_ids[2],
            document_id=sample_document_id,
            user_id=sample_user.id,
            collection_ids=collection_ids,
            vector=Vector(data=[random.random() for _ in range(128)], type=VectorType.FIXED, length=128),
            text="And this is the third chunk with more information.",
            metadata={"chunk_order": 2}
        )
    ]

@pytest.fixture
def enrichment_settings():
    return ChunkEnrichmentSettings(
        enable_chunk_enrichment=True,
        strategies=[
            ChunkEnrichmentStrategy.NEIGHBORHOOD,
            ChunkEnrichmentStrategy.SEMANTIC,
        ],
        backward_chunks=1,
        forward_chunks=1,
        semantic_neighbors=2,
        semantic_similarity_threshold=0.7,
    )


@pytest.fixture
def r2r_ingestion_provider(app_config):
    return R2RIngestionProvider(R2RIngestionConfig(app=app_config))

@pytest.fixture
def orchestration_provider(orchestration_config):
    return SimpleOrchestrationProvider(orchestration_config)

@pytest.fixture
def r2r_providers(
    r2r_ingestion_provider,
    r2r_prompt_provider,
    postgres_kg_provider,
    postgres_db_provider,
    litellm_provider_128,
    postgres_file_provider,
    r2r_auth_provider,
    litellm_completion_provider,
    orchestration_provider,
):
    return R2RProviders(
        ingestion=r2r_ingestion_provider,
        prompt=r2r_prompt_provider,
        kg=postgres_kg_provider,
        database=postgres_db_provider,
        embedding=litellm_provider_128,
        file=postgres_file_provider,
        auth=r2r_auth_provider,
        llm=litellm_completion_provider,
        orchestration=orchestration_provider,
    )

@pytest.fixture
def ingestion_config(app_config, enrichment_settings):
    return IngestionConfig(app=app_config, chunk_enrichment_settings=enrichment_settings)

@pytest.fixture
async def ingestion_service(r2r_providers, ingestion_config):
    # You'll need to mock your dependencies here
    service = IngestionService(
        providers=r2r_providers,
        config=ingestion_config,
        pipes=[],
        pipelines=[],
        agents=[],
        run_manager=None,
        logging_connection=None,
    )
    return service


async def test_chunk_enrichment_basic(sample_chunks, ingestion_service, sample_document_id, sample_user):
    # Test basic chunk enrichment functionality

    # ingest chunks ingress: just adds the document info to the table
    await ingestion_service.ingest_chunks_ingress(
        document_id=sample_document_id,
        chunks=sample_chunks,
        metadata={},
        user=sample_user,
    )

    # upsert entries
    await ingestion_service.providers.database.upsert_entries(sample_chunks)

    # enrich chunks
    await ingestion_service.chunk_enrichment(sample_document_id)

    # fetch document chunks
    document_chunks = await ingestion_service.providers.database.get_document_chunks(sample_document_id)

    assert len(document_chunks["results"]) == len(sample_chunks)

    for document_chunk in document_chunks["results"]:
        assert document_chunk["metadata"]["chunk_enrichment_status"] == "success"
        assert (
            document_chunk["metadata"]["original_text"]
            == sample_chunks[document_chunk["metadata"]["chunk_order"]].text
        )


# Other tests
# TODO: Implement in services/test_ingestion_service.py

# test_enriched_chunk_content:
# Ingests chunks, enriches them, then verifies each chunk in DB has metadata containing both 'original_text' and 'chunk_enrichment_status' (success/failed)

# test_neighborhood_strategy:
# Tests _get_enriched_chunk_text() on middle chunk (idx 1) with NEIGHBORHOOD strategy to verify it incorporates text from chunks before/after it

# test_semantic_strategy:
# Sets ChunkEnrichmentStrategy.SEMANTIC, ingests chunks, then enriches them using semantic similarity to find and incorporate related chunks' content

# test_error_handling:
# Attempts chunk_enrichment() with non-existent UUID('00000000-0000-0000-0000-000000000000') to verify proper exception handling

# test_empty_chunks:
# Attempts to ingest_chunks_ingress() with empty chunks list to verify it raises appropriate exception rather than processing empty data

# test_concurrent_processing:
# Creates 200 RawChunks ("Chunk number {0-199}"), ingests and enriches them all to verify concurrent processing handles large batch correctly

# test_vector_storage:
# Ingests chunks, enriches them, then verifies get_document_vectors() returns vectors with correct structure including vector data and extraction_id fields
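
A hedged sketch of two of the planned tests above, assuming `chunk_enrichment` raises for a document that was never ingested and `ingest_chunks_ingress` rejects an empty chunk list; both behaviors are assumptions, not verified by this commit:

```python
# Sketches of test_error_handling and test_empty_chunks from the TODO list.
# The broad pytest.raises(Exception) is deliberate: the concrete exception
# types are assumptions until the service's error contract is pinned down.
async def test_error_handling(ingestion_service):
    missing_id = UUID("00000000-0000-0000-0000-000000000000")
    with pytest.raises(Exception):
        # Enriching a document that was never ingested should fail loudly.
        await ingestion_service.chunk_enrichment(missing_id)


async def test_empty_chunks(ingestion_service, sample_document_id, sample_user):
    with pytest.raises(Exception):
        # An empty chunk list should be rejected, not silently processed.
        await ingestion_service.ingest_chunks_ingress(
            document_id=sample_document_id,
            chunks=[],
            metadata={},
            user=sample_user,
        )
```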
4 changes: 3 additions & 1 deletion py/tests/core/providers/kg/test_kg_logic.py
@@ -403,7 +403,9 @@ async def test_get_community_details(
     await postgres_kg_provider.add_community_report(community_report_list[0])
 
     community_level, entities, triples = (
-        await postgres_kg_provider.get_community_details(community_number=1)
+        await postgres_kg_provider.get_community_details(
+            community_number=1, collection_id=collection_id
+        )
     )
 
     assert community_level == 0
Expand Down