Commit
76afe2a (1 parent: b1fc134): 7 changed files with 187 additions and 83 deletions.
Alembic migration (revision 9f1f30b182ae):

@@ -1,24 +1,27 @@
 """3.2.16
 Revision ID: 9f1f30b182ae
 Revises:
 Create Date: 2024-10-21 23:04:34.550418
 """

 from typing import Sequence, Union

 from alembic import op
 import sqlalchemy as sa


 # revision identifiers, used by Alembic
-revision: str = '9f1f30b182ae'
+revision: str = "9f1f30b182ae"
 down_revision: Union[str, None] = None
 branch_labels: Union[str, Sequence[str], None] = None
 depends_on: Union[str, Sequence[str], None] = None


 def upgrade() -> None:
     pass


 def downgrade() -> None:
     pass
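Both upgrade() and downgrade() in this revision are no-ops, so 9f1f30b182ae acts as an empty baseline for the 3.2.16 release. As a rough sketch (not part of this commit) of how such a revision is applied through Alembic's Python command API, assuming a standard alembic.ini in the working directory:

    from alembic import command
    from alembic.config import Config

    # Assumed config path; substitute the project's actual Alembic config.
    cfg = Config("alembic.ini")
    command.upgrade(cfg, "head")      # applies 9f1f30b182ae (a no-op here, since upgrade() is pass)
    # command.downgrade(cfg, "base")  # would revert it (also a no-op)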
IngestionService chunk-enrichment test fixtures and tests:

@@ -3,93 +3,135 @@
 from datetime import datetime
 from shared.api.models.auth.responses import UserResponse
 from core.base import RawChunk, DocumentType, IngestionStatus, VectorEntry
-from shared.abstractions.ingestion import ChunkEnrichmentStrategy, ChunkEnrichmentSettings
+from shared.abstractions.ingestion import (
+    ChunkEnrichmentStrategy,
+    ChunkEnrichmentSettings,
+)
 import subprocess
-from core.main.services.ingestion_service import IngestionService, IngestionConfig
+from core.main.services.ingestion_service import (
+    IngestionService,
+    IngestionConfig,
+)
 from core.main.abstractions import R2RProviders
 from core.providers.orchestration import SimpleOrchestrationProvider
 from core.providers.ingestion import R2RIngestionConfig, R2RIngestionProvider

 from core.base import Vector, VectorType
 import random


 @pytest.fixture
 def sample_document_id():
-    return UUID('12345678-1234-5678-1234-567812345678')
+    return UUID("12345678-1234-5678-1234-567812345678")


 @pytest.fixture
 def sample_user():
     return UserResponse(
-        id=UUID('87654321-8765-4321-8765-432187654321'),
+        id=UUID("87654321-8765-4321-8765-432187654321"),
         email="[email protected]",
-        is_superuser=True
+        is_superuser=True,
     )


 @pytest.fixture
 def collection_ids():
-    return [UUID('12345678-1234-5678-1234-567812345678')]
+    return [UUID("12345678-1234-5678-1234-567812345678")]


 @pytest.fixture
 def extraction_ids():
-    return [UUID('fce959df-46a2-4983-aa8b-dd1f93777e02'), UUID('9a85269c-84cd-4dff-bf21-7bd09974f668'), UUID('4b1199b2-2b96-4198-9ded-954c900a23dd')]
+    return [
+        UUID("fce959df-46a2-4983-aa8b-dd1f93777e02"),
+        UUID("9a85269c-84cd-4dff-bf21-7bd09974f668"),
+        UUID("4b1199b2-2b96-4198-9ded-954c900a23dd"),
+    ]


 @pytest.fixture
-def sample_chunks(sample_document_id, sample_user, collection_ids, extraction_ids):
+def sample_chunks(
+    sample_document_id, sample_user, collection_ids, extraction_ids
+):
     return [
         VectorEntry(
             extraction_id=extraction_ids[0],
             document_id=sample_document_id,
             user_id=sample_user.id,
             collection_ids=collection_ids,
-            vector=Vector(data=[random.random() for _ in range(128)], type=VectorType.FIXED, length=128),
+            vector=Vector(
+                data=[random.random() for _ in range(128)],
+                type=VectorType.FIXED,
+                length=128,
+            ),
             text="This is the first chunk of text.",
-            metadata={"chunk_order": 0}
+            metadata={"chunk_order": 0},
         ),
         VectorEntry(
             extraction_id=extraction_ids[1],
             document_id=sample_document_id,
             user_id=sample_user.id,
             collection_ids=collection_ids,
-            vector=Vector(data=[random.random() for _ in range(128)], type=VectorType.FIXED, length=128),
+            vector=Vector(
+                data=[random.random() for _ in range(128)],
+                type=VectorType.FIXED,
+                length=128,
+            ),
             text="This is the second chunk with different content.",
-            metadata={"chunk_order": 1}
+            metadata={"chunk_order": 1},
         ),
         VectorEntry(
             extraction_id=extraction_ids[2],
             document_id=sample_document_id,
             user_id=sample_user.id,
             collection_ids=collection_ids,
-            vector=Vector(data=[random.random() for _ in range(128)], type=VectorType.FIXED, length=128),
+            vector=Vector(
+                data=[random.random() for _ in range(128)],
+                type=VectorType.FIXED,
+                length=128,
+            ),
             text="And this is the third chunk with more information.",
-            metadata={"chunk_order": 2}
-        )
+            metadata={"chunk_order": 2},
+        ),
     ]


 @pytest.fixture
 def enrichment_settings():
     return ChunkEnrichmentSettings(
         enable_chunk_enrichment=True,
         strategies=[
             ChunkEnrichmentStrategy.NEIGHBORHOOD,
-            ChunkEnrichmentStrategy.SEMANTIC
+            ChunkEnrichmentStrategy.SEMANTIC,
         ],
         backward_chunks=1,
         forward_chunks=1,
         semantic_neighbors=2,
-        semantic_similarity_threshold=0.7
+        semantic_similarity_threshold=0.7,
     )


 @pytest.fixture
 def r2r_ingestion_provider(app_config):
     return R2RIngestionProvider(R2RIngestionConfig(app=app_config))


 @pytest.fixture
 def orchestration_provider(orchestration_config):
     return SimpleOrchestrationProvider(orchestration_config)


 @pytest.fixture
-def r2r_providers(r2r_ingestion_provider, r2r_prompt_provider, postgres_kg_provider, postgres_db_provider, litellm_provider_128, postgres_file_provider, r2r_auth_provider, litellm_completion_provider, orchestration_provider):
+def r2r_providers(
+    r2r_ingestion_provider,
+    r2r_prompt_provider,
+    postgres_kg_provider,
+    postgres_db_provider,
+    litellm_provider_128,
+    postgres_file_provider,
+    r2r_auth_provider,
+    litellm_completion_provider,
+    orchestration_provider,
+):
     return R2RProviders(
         ingestion=r2r_ingestion_provider,
         prompt=r2r_prompt_provider,
@@ -102,34 +144,40 @@ def r2r_providers
         orchestration=orchestration_provider,
     )


 @pytest.fixture
 def ingestion_config(app_config, enrichment_settings):
-    return IngestionConfig(app=app_config, chunk_enrichment_settings=enrichment_settings)
+    return IngestionConfig(
+        app=app_config, chunk_enrichment_settings=enrichment_settings
+    )


 @pytest.fixture
 async def ingestion_service(r2r_providers, ingestion_config):
     # You'll need to mock your dependencies here
     service = IngestionService(
         providers=r2r_providers,
         config=ingestion_config,
-        pipes = [],
-        pipelines = [],
-        agents = [],
-        run_manager = None,
-        logging_connection = None
+        pipes=[],
+        pipelines=[],
+        agents=[],
+        run_manager=None,
+        logging_connection=None,
     )
     return service


-async def test_chunk_enrichment_basic(sample_chunks, ingestion_service, sample_document_id, sample_user):
+async def test_chunk_enrichment_basic(
+    sample_chunks, ingestion_service, sample_document_id, sample_user
+):
     # Test basic chunk enrichment functionality

     # ingest chunks ingress. Just add document info to the table
     await ingestion_service.ingest_chunks_ingress(
         document_id=sample_document_id,
         chunks=sample_chunks,
         metadata={},
-        user=sample_user
+        user=sample_user,
     )

     # upsert entries
@@ -139,13 +187,22 @@ async def test_chunk_enrichment_basic
     await ingestion_service.chunk_enrichment(sample_document_id)

     # document chunks
-    document_chunks = await ingestion_service.providers.database.get_document_chunks(sample_document_id)
+    document_chunks = (
+        await ingestion_service.providers.database.get_document_chunks(
+            sample_document_id
+        )
+    )

     assert len(document_chunks["results"]) == len(sample_chunks)

     for document_chunk in document_chunks["results"]:
-        assert document_chunk["metadata"]["chunk_enrichment_status"] == "success"
-        assert document_chunk["metadata"]["original_text"] == sample_chunks[document_chunk["metadata"]["chunk_order"]].text
+        assert (
+            document_chunk["metadata"]["chunk_enrichment_status"] == "success"
+        )
+        assert (
+            document_chunk["metadata"]["original_text"]
+            == sample_chunks[document_chunk["metadata"]["chunk_order"]].text
+        )


 # Other tests
@@ -170,4 +227,4 @@ async def test_chunk_enrichment_basic
 # Creates 200 RawChunks ("Chunk number {0-199}"), ingests and enriches them all to verify concurrent processing handles large batch correctly

 # test_vector_storage:
 # Ingests chunks, enriches them, then verifies get_document_vectors() returns vectors with correct structure including vector data and extraction_id fields
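The trailing comments describe tests that are only sketched, not implemented in this commit. As a rough illustration of the test_vector_storage idea, a sketch only: the response shape of get_document_vectors() and the "vector"/"extraction_id" key names are assumptions, not confirmed by this diff, and any intermediate upsert steps used in test_chunk_enrichment_basic are omitted here.

    # Hypothetical sketch of test_vector_storage; field names and response
    # structure are assumed for illustration, mirroring test_chunk_enrichment_basic.
    async def test_vector_storage(
        sample_chunks, ingestion_service, sample_document_id, sample_user
    ):
        await ingestion_service.ingest_chunks_ingress(
            document_id=sample_document_id,
            chunks=sample_chunks,
            metadata={},
            user=sample_user,
        )
        await ingestion_service.chunk_enrichment(sample_document_id)

        vectors = await ingestion_service.providers.database.get_document_vectors(
            sample_document_id
        )

        assert len(vectors["results"]) == len(sample_chunks)
        for entry in vectors["results"]:
            # each stored vector should expose its data and its extraction_id
            assert entry["vector"] is not None
            assert entry["extraction_id"] is not None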