Commit 76afe2a: up

shreyaspimpalgaonkar committed Oct 22, 2024
1 parent b1fc134 commit 76afe2a

Showing 7 changed files with 187 additions and 83 deletions.

8 changes: 6 additions & 2 deletions py/core/main/orchestration/hatchet/ingestion_workflow.py
@@ -168,15 +168,19 @@ async def parse(self, context: Context) -> dict:
        )

        # add contextual chunk
-        logger.info(f"service.providers.ingestion.config: {service.providers.ingestion.config}")
+        logger.info(
+            f"service.providers.ingestion.config: {service.providers.ingestion.config}"
+        )

        chunk_enrichment_settings = getattr(
            service.providers.ingestion.config,
            "chunk_enrichment_settings",
            None,
        )

-        if chunk_enrichment_settings and getattr(chunk_enrichment_settings, "enable_chunk_enrichment", False):
+        if chunk_enrichment_settings and getattr(
+            chunk_enrichment_settings, "enable_chunk_enrichment", False
+        ):

            logger.info("Enriching document with contextual chunks")

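A note on the rewrapped guard above: both getattr calls fall back to falsy defaults, so enrichment is skipped when the ingestion config lacks a chunk_enrichment_settings attribute entirely, not only when enrichment is explicitly disabled. A minimal, runnable sketch of the same pattern, using a hypothetical SimpleNamespace stand-in for service.providers.ingestion.config:

from types import SimpleNamespace

config = SimpleNamespace()  # stand-in config with no chunk_enrichment_settings

settings = getattr(config, "chunk_enrichment_settings", None)
if settings and getattr(settings, "enable_chunk_enrichment", False):
    print("would enrich chunks")
else:
    print("enrichment skipped")  # reached when settings are missing or disabled
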
5 changes: 5 additions & 0 deletions py/env.py
@@ -16,6 +16,7 @@
if config.config_file_name is not None:
    fileConfig(config.config_file_name)

+
def run_migrations_offline() -> None:
    """Run migrations in 'offline' mode."""
    url = config.get_main_option("sqlalchemy.url")
@@ -29,11 +30,13 @@ def run_migrations_offline() -> None:
    with context.begin_transaction():
        context.run_migrations()

+
def do_run_migrations(connection: Connection) -> None:
    context.configure(connection=connection, target_metadata=target_metadata)
    with context.begin_transaction():
        context.run_migrations()

+
async def run_async_migrations() -> None:
    connectable = async_engine_from_config(
        config.get_section(config.config_ini_section, {}),
@@ -46,10 +49,12 @@ async def run_async_migrations() -> None:

    await connectable.dispose()

+
def run_migrations_online() -> None:
    """Run migrations in 'online' mode."""
    asyncio.run(run_async_migrations())

+
if context.is_offline_mode():
    run_migrations_offline()
else:
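
For context, the helpers being spaced out here follow Alembic's standard asyncio env.py template: the online path builds an AsyncEngine from the ini section, then hands a sync connection back to Alembic's synchronous migration runner via run_sync. A condensed sketch of that pattern, assuming the usual alembic config/context objects (target_metadata is left as None here; the real env.py supplies its model metadata):

import asyncio

from alembic import context
from sqlalchemy import pool
from sqlalchemy.engine import Connection
from sqlalchemy.ext.asyncio import async_engine_from_config


def do_run_migrations(connection: Connection) -> None:
    # Alembic's runner is synchronous, so it is driven from a sync callback.
    context.configure(connection=connection, target_metadata=None)
    with context.begin_transaction():
        context.run_migrations()


async def run_async_migrations() -> None:
    # Build an AsyncEngine from the [alembic] section of alembic.ini.
    connectable = async_engine_from_config(
        context.config.get_section(context.config.config_ini_section, {}),
        prefix="sqlalchemy.",
        poolclass=pool.NullPool,
    )
    async with connectable.connect() as connection:
        await connection.run_sync(do_run_migrations)
    await connectable.dispose()


def run_migrations_online() -> None:
    # Bridge from Alembic's sync entry point into the async engine.
    asyncio.run(run_async_migrations())
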
5 changes: 5 additions & 0 deletions py/migrations/env.py
@@ -16,6 +16,7 @@
if config.config_file_name is not None:
    fileConfig(config.config_file_name)

+
def run_migrations_offline() -> None:
    """Run migrations in 'offline' mode."""
    url = config.get_main_option("sqlalchemy.url")
@@ -29,11 +30,13 @@ def run_migrations_offline() -> None:
    with context.begin_transaction():
        context.run_migrations()

+
def do_run_migrations(connection: Connection) -> None:
    context.configure(connection=connection, target_metadata=target_metadata)
    with context.begin_transaction():
        context.run_migrations()

+
async def run_async_migrations() -> None:
    connectable = async_engine_from_config(
        config.get_section(config.config_ini_section, {}),
@@ -46,10 +49,12 @@ async def run_async_migrations() -> None:

    await connectable.dispose()

+
def run_migrations_online() -> None:
    """Run migrations in 'online' mode."""
    asyncio.run(run_async_migrations())

+
if context.is_offline_mode():
    run_migrations_offline()
else:
7 changes: 5 additions & 2 deletions py/migrations/versions/9f1f30b182ae_3_2_16.py
@@ -1,24 +1,27 @@
"""3.2.16

Revision ID: 9f1f30b182ae
-Revises: 
+Revises:
Create Date: 2024-10-21 23:04:34.550418

"""

from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa

+
# revision identifiers, used by Alembic
-revision: str = '9f1f30b182ae'
+revision: str = "9f1f30b182ae"
down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None

+
def upgrade() -> None:
    pass

+
def downgrade() -> None:
    pass
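
The new revision appears to be a deliberate no-op base revision (Revises: is empty and both hooks just pass). For comparison, a populated revision would fill the same two hooks with reversible operations; a hypothetical example using Alembic's op API, with table and column names invented purely for illustration:

from alembic import op
import sqlalchemy as sa


def upgrade() -> None:
    # Hypothetical: add a nullable column so existing rows need no backfill.
    op.add_column(
        "documents", sa.Column("enrichment_status", sa.Text(), nullable=True)
    )


def downgrade() -> None:
    # Mirror the upgrade so the revision rolls back cleanly.
    op.drop_column("documents", "enrichment_status")
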
5 changes: 3 additions & 2 deletions py/tests/conftest.py
@@ -27,8 +27,8 @@
    IngestionStatus,
    KGEnrichmentStatus,
    KGExtractionStatus,
+   OrchestrationConfig,
)
-from core.base import OrchestrationConfig
from core.providers import (
    BCryptProvider,
    LiteCompletionProvider,
@@ -180,7 +180,6 @@ def litellm_provider(app_config):
    return LiteLLMEmbeddingProvider(config)


-
# Embeddings
@pytest.fixture
def litellm_provider_128(app_config):
@@ -289,10 +288,12 @@ async def postgres_kg_provider(
def prompt_config(app_config):
    return PromptConfig(provider="r2r", app=app_config)

+
@pytest.fixture(scope="function")
def orchestration_config(app_config):
    return OrchestrationConfig(provider="simple", app=app_config)

+
@pytest.fixture(scope="function")
async def r2r_prompt_provider(prompt_config, temporary_postgres_db_provider):
    prompt_provider = R2RPromptProvider(
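
The conftest change is plumbing: OrchestrationConfig moves into the grouped core.base import, and a function-scoped fixture wires it into a provider. Outside pytest, the same wiring would look roughly like this sketch (app_config is assumed to be in hand; the fixtures supply it in the real suite):

from core.base import OrchestrationConfig
from core.providers.orchestration import SimpleOrchestrationProvider

# app_config would come from R2R's app configuration, as in the fixtures above.
config = OrchestrationConfig(provider="simple", app=app_config)
provider = SimpleOrchestrationProvider(config)
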
119 changes: 88 additions & 31 deletions py/tests/core/providers/ingestion/test_contextual_embedding.py
@@ -3,93 +3,135 @@
from datetime import datetime
from shared.api.models.auth.responses import UserResponse
from core.base import RawChunk, DocumentType, IngestionStatus, VectorEntry
-from shared.abstractions.ingestion import ChunkEnrichmentStrategy, ChunkEnrichmentSettings
+from shared.abstractions.ingestion import (
+    ChunkEnrichmentStrategy,
+    ChunkEnrichmentSettings,
+)
import subprocess
-from core.main.services.ingestion_service import IngestionService, IngestionConfig
+from core.main.services.ingestion_service import (
+    IngestionService,
+    IngestionConfig,
+)
from core.main.abstractions import R2RProviders
from core.providers.orchestration import SimpleOrchestrationProvider
from core.providers.ingestion import R2RIngestionConfig, R2RIngestionProvider

from core.base import Vector, VectorType
import random

+
@pytest.fixture
def sample_document_id():
-    return UUID('12345678-1234-5678-1234-567812345678')
+    return UUID("12345678-1234-5678-1234-567812345678")

+
@pytest.fixture
def sample_user():
    return UserResponse(
-        id=UUID('87654321-8765-4321-8765-432187654321'),
+        id=UUID("87654321-8765-4321-8765-432187654321"),
        email="[email protected]",
-        is_superuser=True
+        is_superuser=True,
    )

+
@pytest.fixture
def collection_ids():
-    return [UUID('12345678-1234-5678-1234-567812345678')]
+    return [UUID("12345678-1234-5678-1234-567812345678")]

+
@pytest.fixture
def extraction_ids():
-    return [UUID('fce959df-46a2-4983-aa8b-dd1f93777e02'), UUID('9a85269c-84cd-4dff-bf21-7bd09974f668'), UUID('4b1199b2-2b96-4198-9ded-954c900a23dd')]
+    return [
+        UUID("fce959df-46a2-4983-aa8b-dd1f93777e02"),
+        UUID("9a85269c-84cd-4dff-bf21-7bd09974f668"),
+        UUID("4b1199b2-2b96-4198-9ded-954c900a23dd"),
+    ]

+
@pytest.fixture
-def sample_chunks(sample_document_id, sample_user, collection_ids, extraction_ids):
+def sample_chunks(
+    sample_document_id, sample_user, collection_ids, extraction_ids
+):
    return [
        VectorEntry(
            extraction_id=extraction_ids[0],
            document_id=sample_document_id,
            user_id=sample_user.id,
            collection_ids=collection_ids,
-            vector=Vector(data=[random.random() for _ in range(128)], type=VectorType.FIXED, length=128),
+            vector=Vector(
+                data=[random.random() for _ in range(128)],
+                type=VectorType.FIXED,
+                length=128,
+            ),
            text="This is the first chunk of text.",
-            metadata={"chunk_order": 0}
+            metadata={"chunk_order": 0},
        ),
        VectorEntry(
            extraction_id=extraction_ids[1],
            document_id=sample_document_id,
            user_id=sample_user.id,
            collection_ids=collection_ids,
-            vector=Vector(data=[random.random() for _ in range(128)], type=VectorType.FIXED, length=128),
+            vector=Vector(
+                data=[random.random() for _ in range(128)],
+                type=VectorType.FIXED,
+                length=128,
+            ),
            text="This is the second chunk with different content.",
-            metadata={"chunk_order": 1}
+            metadata={"chunk_order": 1},
        ),
        VectorEntry(
            extraction_id=extraction_ids[2],
            document_id=sample_document_id,
            user_id=sample_user.id,
            collection_ids=collection_ids,
-            vector=Vector(data=[random.random() for _ in range(128)], type=VectorType.FIXED, length=128),
+            vector=Vector(
+                data=[random.random() for _ in range(128)],
+                type=VectorType.FIXED,
+                length=128,
+            ),
            text="And this is the third chunk with more information.",
-            metadata={"chunk_order": 2}
-        )
+            metadata={"chunk_order": 2},
+        ),
    ]

+
@pytest.fixture
def enrichment_settings():
    return ChunkEnrichmentSettings(
        enable_chunk_enrichment=True,
        strategies=[
            ChunkEnrichmentStrategy.NEIGHBORHOOD,
-            ChunkEnrichmentStrategy.SEMANTIC
+            ChunkEnrichmentStrategy.SEMANTIC,
        ],
        backward_chunks=1,
        forward_chunks=1,
        semantic_neighbors=2,
-        semantic_similarity_threshold=0.7
+        semantic_similarity_threshold=0.7,
    )
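
Reading the enrichment_settings fixture above: backward_chunks=1 and forward_chunks=1 suggest each chunk gets a one-chunk window on either side under the NEIGHBORHOOD strategy, while semantic_neighbors=2 and semantic_similarity_threshold=0.7 bound the SEMANTIC strategy. A rough, runnable sketch of the windowing those values imply, as an illustration only, not R2R's actual implementation:

def neighborhood_window(chunks, idx, backward_chunks=1, forward_chunks=1):
    # Take up to `backward_chunks` before and `forward_chunks` after idx,
    # clamped at the document boundaries, excluding the chunk itself.
    start = max(0, idx - backward_chunks)
    end = min(len(chunks), idx + forward_chunks + 1)
    return [c for i, c in enumerate(chunks[start:end], start) if i != idx]


texts = ["first chunk", "second chunk", "third chunk"]
assert neighborhood_window(texts, 0) == ["second chunk"]
assert neighborhood_window(texts, 1) == ["first chunk", "third chunk"]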


@pytest.fixture
def r2r_ingestion_provider(app_config):
    return R2RIngestionProvider(R2RIngestionConfig(app=app_config))

+
@pytest.fixture
def orchestration_provider(orchestration_config):
    return SimpleOrchestrationProvider(orchestration_config)

+
@pytest.fixture
-def r2r_providers(r2r_ingestion_provider, r2r_prompt_provider, postgres_kg_provider, postgres_db_provider, litellm_provider_128, postgres_file_provider, r2r_auth_provider, litellm_completion_provider, orchestration_provider):
+def r2r_providers(
+    r2r_ingestion_provider,
+    r2r_prompt_provider,
+    postgres_kg_provider,
+    postgres_db_provider,
+    litellm_provider_128,
+    postgres_file_provider,
+    r2r_auth_provider,
+    litellm_completion_provider,
+    orchestration_provider,
+):
    return R2RProviders(
        ingestion=r2r_ingestion_provider,
        prompt=r2r_prompt_provider,
@@ -102,34 +144,40 @@ def r2r_providers(r2r_ingestion_provider, r2r_prompt_provider, postgres_kg_provi
        orchestration=orchestration_provider,
    )

+
@pytest.fixture
def ingestion_config(app_config, enrichment_settings):
-    return IngestionConfig(app=app_config, chunk_enrichment_settings=enrichment_settings)
+    return IngestionConfig(
+        app=app_config, chunk_enrichment_settings=enrichment_settings
+    )

+
@pytest.fixture
async def ingestion_service(r2r_providers, ingestion_config):
    # You'll need to mock your dependencies here
    service = IngestionService(
        providers=r2r_providers,
        config=ingestion_config,
-        pipes = [],
-        pipelines = [],
-        agents = [],
-        run_manager = None,
-        logging_connection = None
+        pipes=[],
+        pipelines=[],
+        agents=[],
+        run_manager=None,
+        logging_connection=None,
    )
    return service

-async def test_chunk_enrichment_basic(sample_chunks, ingestion_service, sample_document_id, sample_user):
+
+async def test_chunk_enrichment_basic(
+    sample_chunks, ingestion_service, sample_document_id, sample_user
+):
    # Test basic chunk enrichment functionality

-    # ingest chunks ingress. Just add document info to the table 
+    # ingest chunks ingress. Just add document info to the table
    await ingestion_service.ingest_chunks_ingress(
        document_id=sample_document_id,
        chunks=sample_chunks,
        metadata={},
-        user=sample_user
+        user=sample_user,
    )

    # upsert entries
@@ -139,13 +187,22 @@ async def test_chunk_enrichment_basic(sample_chunks, ingestion_service, sample_d
    await ingestion_service.chunk_enrichment(sample_document_id)

    # document chunks
-    document_chunks = await ingestion_service.providers.database.get_document_chunks(sample_document_id)
+    document_chunks = (
+        await ingestion_service.providers.database.get_document_chunks(
+            sample_document_id
+        )
+    )

    assert len(document_chunks["results"]) == len(sample_chunks)

    for document_chunk in document_chunks["results"]:
-        assert document_chunk["metadata"]["chunk_enrichment_status"] == "success"
-        assert document_chunk["metadata"]["original_text"] == sample_chunks[document_chunk["metadata"]["chunk_order"]].text
+        assert (
+            document_chunk["metadata"]["chunk_enrichment_status"] == "success"
+        )
+        assert (
+            document_chunk["metadata"]["original_text"]
+            == sample_chunks[document_chunk["metadata"]["chunk_order"]].text
+        )


# Other tests
@@ -170,4 +227,4 @@ async def test_chunk_enrichment_basic(
# Creates 200 RawChunks ("Chunk number {0-199}"), ingests and enriches them all to verify concurrent processing handles large batch correctly

# test_vector_storage:
-# Ingests chunks, enriches them, then verifies get_document_vectors() returns vectors with correct structure including vector data and extraction_id fields
+# Ingests chunks, enriches them, then verifies get_document_vectors() returns vectors with correct structure including vector data and extraction_id fields
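
The trailing comments outline the rest of the planned coverage. As an illustration of the last one, here is a hedged sketch of what test_vector_storage could look like, reusing only fixtures and method names that appear in this diff; the upsert helper is collapsed out of the hunks above, so that step is left as a placeholder, and the exact shape of get_document_vectors() results is an assumption:

async def test_vector_storage(
    sample_chunks, ingestion_service, sample_document_id, sample_user
):
    # Same flow as test_chunk_enrichment_basic: ingress, upsert, enrich.
    await ingestion_service.ingest_chunks_ingress(
        document_id=sample_document_id,
        chunks=sample_chunks,
        metadata={},
        user=sample_user,
    )
    # (upsert entries here; the exact call is collapsed out of this diff)
    await ingestion_service.chunk_enrichment(sample_document_id)

    vectors = await ingestion_service.providers.database.get_document_vectors(
        sample_document_id
    )
    assert len(vectors) == len(sample_chunks)
    for vector in vectors:
        # Assumed result shape: each entry carries vector data and its
        # extraction_id, per the comment above.
        assert "vector" in vector and "extraction_id" in vector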