diff --git a/py/core/main/orchestration/hatchet/ingestion_workflow.py b/py/core/main/orchestration/hatchet/ingestion_workflow.py
index 5a03d74bd..dd03b68c4 100644
--- a/py/core/main/orchestration/hatchet/ingestion_workflow.py
+++ b/py/core/main/orchestration/hatchet/ingestion_workflow.py
@@ -168,7 +168,9 @@ async def parse(self, context: Context) -> dict:
             )
 
             # add contextual chunk
-            logger.info(f"service.providers.ingestion.config: {service.providers.ingestion.config}")
+            logger.info(
+                f"service.providers.ingestion.config: {service.providers.ingestion.config}"
+            )
 
             chunk_enrichment_settings = getattr(
                 service.providers.ingestion.config,
@@ -176,7 +178,9 @@ async def parse(self, context: Context) -> dict:
                 None,
             )
 
-            if chunk_enrichment_settings and getattr(chunk_enrichment_settings, "enable_chunk_enrichment", False):
+            if chunk_enrichment_settings and getattr(
+                chunk_enrichment_settings, "enable_chunk_enrichment", False
+            ):
 
                 logger.info("Enriching document with contextual chunks")
 
diff --git a/py/env.py b/py/env.py
index 4ec88d3b3..f38aedc73 100644
--- a/py/env.py
+++ b/py/env.py
@@ -16,6 +16,7 @@
 if config.config_file_name is not None:
     fileConfig(config.config_file_name)
 
+
 def run_migrations_offline() -> None:
     """Run migrations in 'offline' mode."""
     url = config.get_main_option("sqlalchemy.url")
@@ -29,11 +30,13 @@ def run_migrations_offline() -> None:
     with context.begin_transaction():
         context.run_migrations()
 
+
 def do_run_migrations(connection: Connection) -> None:
     context.configure(connection=connection, target_metadata=target_metadata)
     with context.begin_transaction():
         context.run_migrations()
 
+
 async def run_async_migrations() -> None:
     connectable = async_engine_from_config(
         config.get_section(config.config_ini_section, {}),
@@ -46,10 +49,12 @@ async def run_async_migrations() -> None:
 
     await connectable.dispose()
 
+
 def run_migrations_online() -> None:
     """Run migrations in 'online' mode."""
     asyncio.run(run_async_migrations())
 
+
 if context.is_offline_mode():
     run_migrations_offline()
 else:
diff --git a/py/migrations/env.py b/py/migrations/env.py
index 4ec88d3b3..f38aedc73 100644
--- a/py/migrations/env.py
+++ b/py/migrations/env.py
@@ -16,6 +16,7 @@
 if config.config_file_name is not None:
     fileConfig(config.config_file_name)
 
+
 def run_migrations_offline() -> None:
     """Run migrations in 'offline' mode."""
     url = config.get_main_option("sqlalchemy.url")
@@ -29,11 +30,13 @@ def run_migrations_offline() -> None:
     with context.begin_transaction():
         context.run_migrations()
 
+
 def do_run_migrations(connection: Connection) -> None:
     context.configure(connection=connection, target_metadata=target_metadata)
     with context.begin_transaction():
         context.run_migrations()
 
+
 async def run_async_migrations() -> None:
     connectable = async_engine_from_config(
         config.get_section(config.config_ini_section, {}),
@@ -46,10 +49,12 @@ async def run_async_migrations() -> None:
 
     await connectable.dispose()
 
+
 def run_migrations_online() -> None:
     """Run migrations in 'online' mode."""
     asyncio.run(run_async_migrations())
 
+
 if context.is_offline_mode():
     run_migrations_offline()
 else:
diff --git a/py/migrations/versions/9f1f30b182ae_3_2_16.py b/py/migrations/versions/9f1f30b182ae_3_2_16.py
index fb48be112..9d683f567 100644
--- a/py/migrations/versions/9f1f30b182ae_3_2_16.py
+++ b/py/migrations/versions/9f1f30b182ae_3_2_16.py
@@ -1,10 +1,11 @@
 """3.2.16
 
 Revision ID: 9f1f30b182ae
-Revises: 
+Revises:
 Create Date: 2024-10-21 23:04:34.550418
 
 """
+
 from typing import Sequence, Union
 
 from alembic import op
@@ -12,13 +13,15 @@
 
 # revision identifiers, used by Alembic
-revision: str = '9f1f30b182ae'
+revision: str = "9f1f30b182ae"
 down_revision: Union[str, None] = None
 branch_labels: Union[str, Sequence[str], None] = None
 depends_on: Union[str, Sequence[str], None] = None
 
+
 def upgrade() -> None:
     pass
 
+
 def downgrade() -> None:
     pass
 
diff --git a/py/tests/conftest.py b/py/tests/conftest.py
index 0c535dd51..68f9c5fd7 100644
--- a/py/tests/conftest.py
+++ b/py/tests/conftest.py
@@ -27,8 +27,8 @@
     IngestionStatus,
     KGEnrichmentStatus,
     KGExtractionStatus,
+    OrchestrationConfig,
 )
-from core.base import OrchestrationConfig
 from core.providers import (
     BCryptProvider,
     LiteCompletionProvider,
@@ -180,7 +180,6 @@ def litellm_provider(app_config):
     return LiteLLMEmbeddingProvider(config)
 
 
-
 # Embeddings
 @pytest.fixture
 def litellm_provider_128(app_config):
@@ -289,10 +288,12 @@ async def postgres_kg_provider(
 def prompt_config(app_config):
     return PromptConfig(provider="r2r", app=app_config)
 
+
 @pytest.fixture(scope="function")
 def orchestration_config(app_config):
     return OrchestrationConfig(provider="simple", app=app_config)
 
+
 @pytest.fixture(scope="function")
 async def r2r_prompt_provider(prompt_config, temporary_postgres_db_provider):
     prompt_provider = R2RPromptProvider(
diff --git a/py/tests/core/providers/ingestion/test_contextual_embedding.py b/py/tests/core/providers/ingestion/test_contextual_embedding.py
index 27a2179e9..713679654 100644
--- a/py/tests/core/providers/ingestion/test_contextual_embedding.py
+++ b/py/tests/core/providers/ingestion/test_contextual_embedding.py
@@ -3,9 +3,15 @@
 from datetime import datetime
 from shared.api.models.auth.responses import UserResponse
 from core.base import RawChunk, DocumentType, IngestionStatus, VectorEntry
-from shared.abstractions.ingestion import ChunkEnrichmentStrategy, ChunkEnrichmentSettings
+from shared.abstractions.ingestion import (
+    ChunkEnrichmentStrategy,
+    ChunkEnrichmentSettings,
+)
 import subprocess
-from core.main.services.ingestion_service import IngestionService, IngestionConfig
+from core.main.services.ingestion_service import (
+    IngestionService,
+    IngestionConfig,
+)
 from core.main.abstractions import R2RProviders
 from core.providers.orchestration import SimpleOrchestrationProvider
 from core.providers.ingestion import R2RIngestionConfig, R2RIngestionProvider
@@ -13,70 +19,94 @@
 from core.base import Vector, VectorType
 import random
 
+
 @pytest.fixture
 def sample_document_id():
-    return UUID('12345678-1234-5678-1234-567812345678')
+    return UUID("12345678-1234-5678-1234-567812345678")
+
 
 @pytest.fixture
 def sample_user():
     return UserResponse(
-        id=UUID('87654321-8765-4321-8765-432187654321'),
+        id=UUID("87654321-8765-4321-8765-432187654321"),
         email="test@example.com",
-        is_superuser=True
+        is_superuser=True,
     )
 
+
 @pytest.fixture
 def collection_ids():
-    return [UUID('12345678-1234-5678-1234-567812345678')]
+    return [UUID("12345678-1234-5678-1234-567812345678")]
+
 
 @pytest.fixture
 def extraction_ids():
-    return [UUID('fce959df-46a2-4983-aa8b-dd1f93777e02'), UUID('9a85269c-84cd-4dff-bf21-7bd09974f668'), UUID('4b1199b2-2b96-4198-9ded-954c900a23dd')]
+    return [
+        UUID("fce959df-46a2-4983-aa8b-dd1f93777e02"),
+        UUID("9a85269c-84cd-4dff-bf21-7bd09974f668"),
+        UUID("4b1199b2-2b96-4198-9ded-954c900a23dd"),
+    ]
+
 
 @pytest.fixture
-def sample_chunks(sample_document_id, sample_user, collection_ids, extraction_ids):
+def sample_chunks(
+    sample_document_id, sample_user, collection_ids, extraction_ids
+):
     return [
         VectorEntry(
             extraction_id=extraction_ids[0],
             document_id=sample_document_id,
             user_id=sample_user.id,
             collection_ids=collection_ids,
-            vector=Vector(data=[random.random() for _ in range(128)], type=VectorType.FIXED, length=128),
+            vector=Vector(
+                data=[random.random() for _ in range(128)],
+                type=VectorType.FIXED,
+                length=128,
+            ),
             text="This is the first chunk of text.",
-            metadata={"chunk_order": 0}
+            metadata={"chunk_order": 0},
         ),
         VectorEntry(
             extraction_id=extraction_ids[1],
             document_id=sample_document_id,
             user_id=sample_user.id,
             collection_ids=collection_ids,
-            vector=Vector(data=[random.random() for _ in range(128)], type=VectorType.FIXED, length=128),
+            vector=Vector(
+                data=[random.random() for _ in range(128)],
+                type=VectorType.FIXED,
+                length=128,
+            ),
             text="This is the second chunk with different content.",
-            metadata={"chunk_order": 1}
+            metadata={"chunk_order": 1},
         ),
         VectorEntry(
             extraction_id=extraction_ids[2],
             document_id=sample_document_id,
             user_id=sample_user.id,
             collection_ids=collection_ids,
-            vector=Vector(data=[random.random() for _ in range(128)], type=VectorType.FIXED, length=128),
+            vector=Vector(
+                data=[random.random() for _ in range(128)],
+                type=VectorType.FIXED,
+                length=128,
+            ),
             text="And this is the third chunk with more information.",
-            metadata={"chunk_order": 2}
-        )
+            metadata={"chunk_order": 2},
+        ),
     ]
 
+
 @pytest.fixture
 def enrichment_settings():
     return ChunkEnrichmentSettings(
         enable_chunk_enrichment=True,
         strategies=[
             ChunkEnrichmentStrategy.NEIGHBORHOOD,
-            ChunkEnrichmentStrategy.SEMANTIC
+            ChunkEnrichmentStrategy.SEMANTIC,
         ],
         backward_chunks=1,
         forward_chunks=1,
         semantic_neighbors=2,
-        semantic_similarity_threshold=0.7
+        semantic_similarity_threshold=0.7,
     )
 
@@ -84,12 +114,24 @@ def enrichment_settings():
 def r2r_ingestion_provider(app_config):
     return R2RIngestionProvider(R2RIngestionConfig(app=app_config))
 
+
 @pytest.fixture
 def orchestration_provider(orchestration_config):
     return SimpleOrchestrationProvider(orchestration_config)
 
+
 @pytest.fixture
-def r2r_providers(r2r_ingestion_provider, r2r_prompt_provider, postgres_kg_provider, postgres_db_provider, litellm_provider_128, postgres_file_provider, r2r_auth_provider, litellm_completion_provider, orchestration_provider):
+def r2r_providers(
+    r2r_ingestion_provider,
+    r2r_prompt_provider,
+    postgres_kg_provider,
+    postgres_db_provider,
+    litellm_provider_128,
+    postgres_file_provider,
+    r2r_auth_provider,
+    litellm_completion_provider,
+    orchestration_provider,
+):
     return R2RProviders(
         ingestion=r2r_ingestion_provider,
         prompt=r2r_prompt_provider,
@@ -102,9 +144,13 @@ def r2r_providers(r2r_ingestion_provider, r2r_prompt_provider, postgres_kg_provi
         orchestration=orchestration_provider,
     )
 
+
 @pytest.fixture
 def ingestion_config(app_config, enrichment_settings):
-    return IngestionConfig(app=app_config, chunk_enrichment_settings=enrichment_settings)
+    return IngestionConfig(
+        app=app_config, chunk_enrichment_settings=enrichment_settings
+    )
+
 
 @pytest.fixture
 async def ingestion_service(r2r_providers, ingestion_config):
@@ -112,24 +158,26 @@ async def ingestion_service(r2r_providers, ingestion_config):
     service = IngestionService(
         providers=r2r_providers,
         config=ingestion_config,
-        pipes = [],
-        pipelines = [],
-        agents = [],
-        run_manager = None,
-        logging_connection = None
+        pipes=[],
+        pipelines=[],
+        agents=[],
+        run_manager=None,
+        logging_connection=None,
     )
     return service
 
 
-async def test_chunk_enrichment_basic(sample_chunks, ingestion_service, sample_document_id, sample_user):
+async def test_chunk_enrichment_basic(
+    sample_chunks, ingestion_service, sample_document_id, sample_user
+):
     # Test basic chunk enrichment functionality
 
-    # ingest chunks ingress. Just add document info to the table 
+    # ingest chunks ingress. Just add document info to the table
     await ingestion_service.ingest_chunks_ingress(
         document_id=sample_document_id,
         chunks=sample_chunks,
         metadata={},
-        user=sample_user
+        user=sample_user,
     )
 
     # upsert entries
@@ -139,13 +187,22 @@ async def test_chunk_enrichment_basic(sample_chunks, ingestion_service, sample_d
     await ingestion_service.chunk_enrichment(sample_document_id)
 
     # document chunks
-    document_chunks = await ingestion_service.providers.database.get_document_chunks(sample_document_id)
+    document_chunks = (
+        await ingestion_service.providers.database.get_document_chunks(
+            sample_document_id
+        )
+    )
 
     assert len(document_chunks["results"]) == len(sample_chunks)
 
     for document_chunk in document_chunks["results"]:
-        assert document_chunk["metadata"]["chunk_enrichment_status"] == "success"
-        assert document_chunk["metadata"]["original_text"] == sample_chunks[document_chunk["metadata"]["chunk_order"]].text
+        assert (
+            document_chunk["metadata"]["chunk_enrichment_status"] == "success"
+        )
+        assert (
+            document_chunk["metadata"]["original_text"]
+            == sample_chunks[document_chunk["metadata"]["chunk_order"]].text
+        )
 
 
 # Other tests
@@ -170,4 +227,4 @@ async def test_chunk_enrichment_basic(sample_chunks, ingestion_service, sample_d
 # Creates 200 RawChunks ("Chunk number {0-199}"), ingests and enriches them all to verify concurrent processing handles large batch correctly
 
 # test_vector_storage:
-# Ingests chunks, enriches them, then verifies get_document_vectors() returns vectors with correct structure including vector data and extraction_id fields
\ No newline at end of file
+# Ingests chunks, enriches them, then verifies get_document_vectors() returns vectors with correct structure including vector data and extraction_id fields
diff --git a/py/tests/core/services/test_ingestion_service.py b/py/tests/core/services/test_ingestion_service.py
index c24a034aa..b6be7e360 100644
--- a/py/tests/core/services/test_ingestion_service.py
+++ b/py/tests/core/services/test_ingestion_service.py
@@ -1,128 +1,150 @@
-import pytest
 from uuid import UUID
+
+import pytest
+
 from core.base import RawChunk
 from core.main.services.ingestion_service import IngestionService
 
+
 @pytest.fixture
 def sample_document_id():
-    return UUID('12345678-1234-5678-1234-567812345678')
+    return UUID("12345678-1234-5678-1234-567812345678")
+
 
 @pytest.fixture
 def sample_chunks():
     return [
         RawChunk(
             text="This is the first chunk of text.",
-            metadata={"chunk_order": 1}
+            metadata={"chunk_order": 1},
         ),
         RawChunk(
             text="This is the second chunk with different content.",
-            metadata={"chunk_order": 2}
+            metadata={"chunk_order": 2},
         ),
         RawChunk(
             text="And this is the third chunk with more information.",
-            metadata={"chunk_order": 3}
-        )
+            metadata={"chunk_order": 3},
+        ),
     ]
 
-async def test_ingest_chunks_ingress_success(ingestion_service, sample_document_id, sample_chunks):
+
+async def test_ingest_chunks_ingress_success(
+    ingestion_service, sample_document_id, sample_chunks
+):
     """Test successful ingestion of chunks"""
     result = await ingestion_service.ingest_chunks_ingress(
         document_id=sample_document_id,
         chunks=sample_chunks,
         metadata={"title": "Test Document"},
-        user_id="test_user"
+        user_id="test_user",
     )
-    
+
     assert result is not None
     # Add assertions based on your expected return type
 
-async def test_ingest_chunks_ingress_empty_chunks(ingestion_service, sample_document_id):
+
+async def test_ingest_chunks_ingress_empty_chunks(
+    ingestion_service, sample_document_id
+):
     """Test handling of empty chunks list"""
     with pytest.raises(ValueError):
         await ingestion_service.ingest_chunks_ingress(
             document_id=sample_document_id,
             chunks=[],
             metadata={},
-            user_id="test_user"
+            user_id="test_user",
         )
 
-async def test_ingest_chunks_ingress_invalid_metadata(ingestion_service, sample_document_id, sample_chunks):
+
+async def test_ingest_chunks_ingress_invalid_metadata(
+    ingestion_service, sample_document_id, sample_chunks
+):
     """Test handling of invalid metadata"""
     with pytest.raises(TypeError):
         await ingestion_service.ingest_chunks_ingress(
             document_id=sample_document_id,
             chunks=sample_chunks,
             metadata=None,  # Invalid metadata
-            user_id="test_user"
+            user_id="test_user",
         )
 
-async def test_ingest_chunks_ingress_large_document(ingestion_service, sample_document_id):
+
+async def test_ingest_chunks_ingress_large_document(
+    ingestion_service, sample_document_id
+):
    """Test ingestion of a large number of chunks"""
     large_chunks = [
-        RawChunk(
-            text=f"Chunk number {i}",
-            metadata={"chunk_order": i}
-        )
+        RawChunk(text=f"Chunk number {i}", metadata={"chunk_order": i})
         for i in range(1000)
     ]
-    
+
     result = await ingestion_service.ingest_chunks_ingress(
         document_id=sample_document_id,
         chunks=large_chunks,
         metadata={"title": "Large Document"},
-        user_id="test_user"
+        user_id="test_user",
     )
-    
+
     assert result is not None
     # Add assertions for large document handling
 
-async def test_ingest_chunks_ingress_duplicate_chunk_orders(ingestion_service, sample_document_id):
+
+async def test_ingest_chunks_ingress_duplicate_chunk_orders(
+    ingestion_service, sample_document_id
+):
     """Test handling of chunks with duplicate chunk orders"""
     duplicate_chunks = [
-        RawChunk(
-            text="First chunk",
-            metadata={"chunk_order": 1}
-        ),
+        RawChunk(text="First chunk", metadata={"chunk_order": 1}),
         RawChunk(
             text="Second chunk",
-            metadata={"chunk_order": 1}  # Duplicate chunk_order
-        )
+            metadata={"chunk_order": 1},  # Duplicate chunk_order
+        ),
     ]
-    
+
     with pytest.raises(ValueError):
         await ingestion_service.ingest_chunks_ingress(
             document_id=sample_document_id,
             chunks=duplicate_chunks,
             metadata={},
-            user_id="test_user"
+            user_id="test_user",
         )
 
-async def test_ingest_chunks_ingress_invalid_user(ingestion_service, sample_document_id, sample_chunks):
+
+async def test_ingest_chunks_ingress_invalid_user(
+    ingestion_service, sample_document_id, sample_chunks
+):
     """Test handling of invalid user ID"""
     with pytest.raises(ValueError):
         await ingestion_service.ingest_chunks_ingress(
             document_id=sample_document_id,
             chunks=sample_chunks,
             metadata={},
-            user_id=""  # Invalid user ID
+            user_id="",  # Invalid user ID
         )
 
-async def test_ingest_chunks_ingress_metadata_validation(ingestion_service, sample_document_id, sample_chunks):
+
+async def test_ingest_chunks_ingress_metadata_validation(
+    ingestion_service, sample_document_id, sample_chunks
+):
     """Test metadata validation"""
     test_cases = [
         ({"title": "Valid title"}, True),
         ({"title": ""}, False),
         ({"invalid_key": "value"}, False),
-        ({}, True),  # Empty metadata might be valid depending on your requirements
+        (
+            {},
+            True,
+        ),  # Empty metadata might be valid depending on your requirements
     ]
-    
+
     for metadata, should_succeed in test_cases:
         if should_succeed:
             result = await ingestion_service.ingest_chunks_ingress(
                 document_id=sample_document_id,
                 chunks=sample_chunks,
                 metadata=metadata,
-                user_id="test_user"
+                user_id="test_user",
             )
             assert result is not None
         else:
@@ -131,25 +153,32 @@ async def test_ingest_chunks_ingress_metadata_validation(ingestion_service, samp
                 document_id=sample_document_id,
                 chunks=sample_chunks,
                 metadata=metadata,
-                user_id="test_user"
+                user_id="test_user",
             )
 
-async def test_ingest_chunks_ingress_concurrent_requests(ingestion_service, sample_chunks):
+
+async def test_ingest_chunks_ingress_concurrent_requests(
+    ingestion_service, sample_chunks
+):
     """Test handling of concurrent ingestion requests"""
     import asyncio
-    
-    document_ids = [UUID('12345678-1234-5678-1234-56781234567' + str(i)) for i in range(5)]
-    
+
+    document_ids = [
+        UUID("12345678-1234-5678-1234-56781234567" + str(i)) for i in range(5)
+    ]
+
     async def ingest_document(doc_id):
         return await ingestion_service.ingest_chunks_ingress(
             document_id=doc_id,
             chunks=sample_chunks,
             metadata={"title": f"Document {doc_id}"},
-            user_id="test_user"
+            user_id="test_user",
         )
-    
-    results = await asyncio.gather(*[ingest_document(doc_id) for doc_id in document_ids])
-    
+
+    results = await asyncio.gather(
+        *[ingest_document(doc_id) for doc_id in document_ids]
+    )
+
     assert len(results) == len(document_ids)
     for result in results:
-        assert result is not None
\ No newline at end of file
+        assert result is not None
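
A rough sketch of the large-batch test described in the trailing comments of test_contextual_embedding.py, reusing that file's ingestion_service, sample_document_id, and sample_user fixtures. This is illustrative only and not part of the diff: the 200-chunk count comes from the comment, the call sequence mirrors test_chunk_enrichment_basic, and the entry-upsert step between ingress and enrichment is elided in the diff, so it is only marked with a placeholder comment here.

async def test_chunk_enrichment_large_batch(
    ingestion_service, sample_document_id, sample_user
):
    # 200 chunks, "Chunk number {0-199}", as described in the trailing comment.
    chunks = [
        RawChunk(text=f"Chunk number {i}", metadata={"chunk_order": i})
        for i in range(200)
    ]

    # Register the document, mirroring test_chunk_enrichment_basic.
    await ingestion_service.ingest_chunks_ingress(
        document_id=sample_document_id,
        chunks=chunks,
        metadata={},
        user=sample_user,
    )

    # (The entry-upsert step used in test_chunk_enrichment_basic would go
    # here; its exact call is not shown in the diff above.)

    # Enrich every chunk for the document.
    await ingestion_service.chunk_enrichment(sample_document_id)

    # All 200 chunks should come back marked as enriched.
    document_chunks = (
        await ingestion_service.providers.database.get_document_chunks(
            sample_document_id
        )
    )
    assert len(document_chunks["results"]) == len(chunks)
    for chunk in document_chunks["results"]:
        assert chunk["metadata"]["chunk_enrichment_status"] == "success"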