Skip to content

Commit

Permalink
Feature/add document summary to ingestion (#1573)
Browse files (browse the repository at this point in the history)
* adds document summary to ingestion pipeline

* clean up implementation

* new hybrid document search

* implement hybrid document search
  • Loading branch information
emrgnt-cmplxty authored Nov 11, 2024
1 parent 1bc3cee commit c3a0273
Show file tree
Hide file tree
Showing 32 changed files with 722 additions and 311 deletions.
2 changes: 1 addition & 1 deletion py/cli/commands/retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
@click.option(
"--query", prompt="Enter your search query", help="The search query"
)
# VectorSearchSettings
# SearchSettings
@click.option(
"--use-vector-search",
is_flag=True,
Expand Down
3 changes: 1 addition & 2 deletions py/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,7 @@
"KGSearchResult",
"KGSearchSettings",
"VectorSearchResult",
"VectorSearchSettings",
"DocumentSearchSettings",
"SearchSettings",
"HybridSearchSettings",
# User abstractions
"Token",
Expand Down
4 changes: 2 additions & 2 deletions py/core/agent/rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from core.base.abstractions import (
AggregateSearchResult,
KGSearchSettings,
VectorSearchSettings,
SearchSettings,
)
from core.base.agent import AgentConfig, Tool
from core.base.providers import CompletionProvider
Expand Down Expand Up @@ -57,7 +57,7 @@ def search_tool(self) -> Tool:
async def search(
self,
query: str,
vector_search_settings: VectorSearchSettings,
vector_search_settings: SearchSettings,
kg_search_settings: KGSearchSettings,
*args,
**kwargs,
Expand Down
3 changes: 1 addition & 2 deletions py/core/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,7 @@
"KGSearchResult",
"KGSearchSettings",
"VectorSearchResult",
"VectorSearchSettings",
"DocumentSearchSettings",
"SearchSettings",
"HybridSearchSettings",
# KG abstractions
"KGCreationSettings",
Expand Down
6 changes: 2 additions & 4 deletions py/core/base/abstractions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@
from shared.abstractions.prompt import Prompt
from shared.abstractions.search import (
AggregateSearchResult,
DocumentSearchSettings,
HybridSearchSettings,
KGCommunityResult,
KGEntityResult,
Expand All @@ -61,8 +60,8 @@
KGSearchResult,
KGSearchResultType,
KGSearchSettings,
SearchSettings,
VectorSearchResult,
VectorSearchSettings,
)
from shared.abstractions.user import Token, TokenData, UserStats
from shared.abstractions.vector import (
Expand Down Expand Up @@ -130,8 +129,7 @@
"KGGlobalResult",
"KGSearchSettings",
"VectorSearchResult",
"VectorSearchSettings",
"DocumentSearchSettings",
"SearchSettings",
"HybridSearchSettings",
# KG abstractions
"KGCreationSettings",
Expand Down
41 changes: 24 additions & 17 deletions py/core/base/providers/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,19 +27,18 @@
)
from core.base.abstractions import (
DocumentInfo,
DocumentSearchSettings,
IndexArgsHNSW,
IndexArgsIVFFlat,
IndexMeasure,
IndexMethod,
KGCreationSettings,
KGEnrichmentSettings,
KGEntityDeduplicationSettings,
SearchSettings,
UserStats,
VectorEntry,
VectorQuantizationType,
VectorSearchResult,
VectorSearchSettings,
VectorTableName,
)
from core.base.api.models import (
Expand Down Expand Up @@ -256,6 +255,15 @@ async def get_document_ids_by_status(
):
pass

@abstractmethod
async def search_documents(
self,
query_text: str,
query_embedding: Optional[list[float]] = None,
search_settings: Optional[SearchSettings] = None,
) -> list[DocumentInfo]:
pass


class CollectionHandler(Handler):
@abstractmethod
Expand Down Expand Up @@ -511,28 +519,22 @@ async def upsert_entries(self, entries: list[VectorEntry]) -> None:

@abstractmethod
async def semantic_search(
self, query_vector: list[float], search_settings: VectorSearchSettings
self, query_vector: list[float], search_settings: SearchSettings
) -> list[VectorSearchResult]:
pass

@abstractmethod
async def full_text_search(
self, query_text: str, search_settings: VectorSearchSettings
self, query_text: str, search_settings: SearchSettings
) -> list[VectorSearchResult]:
pass

@abstractmethod
async def search_documents(
self, query_text: str, settings: DocumentSearchSettings
) -> list[dict]:
pass

@abstractmethod
async def hybrid_search(
self,
query_text: str,
query_vector: list[float],
search_settings: VectorSearchSettings,
search_settings: SearchSettings,
*args,
**kwargs,
) -> list[VectorSearchResult]:
Expand Down Expand Up @@ -1404,14 +1406,14 @@ async def upsert_entries(self, entries: list[VectorEntry]) -> None:
return await self.vector_handler.upsert_entries(entries)

async def semantic_search(
self, query_vector: list[float], search_settings: VectorSearchSettings
self, query_vector: list[float], search_settings: SearchSettings
) -> list[VectorSearchResult]:
return await self.vector_handler.semantic_search(
query_vector, search_settings
)

async def full_text_search(
self, query_text: str, search_settings: VectorSearchSettings
self, query_text: str, search_settings: SearchSettings
) -> list[VectorSearchResult]:
return await self.vector_handler.full_text_search(
query_text, search_settings
Expand All @@ -1421,7 +1423,7 @@ async def hybrid_search(
self,
query_text: str,
query_vector: list[float],
search_settings: VectorSearchSettings,
search_settings: SearchSettings,
*args,
**kwargs,
) -> list[VectorSearchResult]:
Expand All @@ -1430,9 +1432,14 @@ async def hybrid_search(
)

async def search_documents(
self, query_text: str, settings: DocumentSearchSettings
) -> list[dict]:
return await self.vector_handler.search_documents(query_text, settings)
self,
query_text: str,
settings: SearchSettings,
query_embedding: Optional[list[float]] = None,
) -> list[DocumentInfo]:
return await self.document_handler.search_documents(
query_text, query_embedding, settings
)

async def delete(
self, filters: dict[str, Any]
Expand Down
7 changes: 6 additions & 1 deletion py/core/base/providers/ingestion.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ class IngestionConfig(ProviderConfig):
chunk_enrichment_settings: ChunkEnrichmentSettings = (
ChunkEnrichmentSettings()
)

extra_parsers: dict[str, str] = {}

audio_transcription_model: str = "openai/whisper-1"
Expand All @@ -29,6 +28,12 @@ class IngestionConfig(ProviderConfig):
vision_pdf_prompt_name: str = "vision_pdf"
vision_pdf_model: str = "openai/gpt-4-mini"

skip_document_summary: bool = False
document_summary_system_prompt: str = "default_system"
document_summary_task_prompt: str = "default_summary"
chunks_for_document_summary: int = 128
document_summary_model: str = "openai/gpt-4o-mini"

@property
def supported_providers(self) -> list[str]:
return ["r2r", "unstructured_local", "unstructured_api"]
Expand Down
2 changes: 2 additions & 0 deletions py/core/configs/full_local_llm.toml
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ new_after_n_chars = 512
max_characters = 1_024
combine_under_n_chars = 128
overlap = 20
chunks_for_document_summary = 16
document_summary_model = "ollama/llama3.1"

[orchestration]
provider = "hatchet"
3 changes: 3 additions & 0 deletions py/core/configs/local_llm.toml
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,6 @@ vision_pdf_model = "ollama/llama3.2-vision"

[ingestion.extra_parsers]
pdf = "zerox"

chunks_for_document_summary = 16
document_summary_model = "ollama/llama3.1"
1 change: 1 addition & 0 deletions py/core/configs/r2r_azure.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,5 +40,6 @@ audio_transcription_model="azure/whisper-1"
vision_img_model = "azure/gpt-4o-mini"
vision_pdf_model = "azure/gpt-4o-mini"

document_summary_model = "azure/gpt-4o-mini"
[ingestion.chunk_enrichment_settings]
generation_config = { model = "azure/gpt-4o-mini" }
27 changes: 16 additions & 11 deletions py/core/main/api/retrieval_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,11 @@
from fastapi.responses import StreamingResponse

from core.base import (
DocumentSearchSettings,
GenerationConfig,
KGSearchSettings,
Message,
R2RException,
VectorSearchSettings,
SearchSettings,
)
from core.base.api.models import (
WrappedCompletionResponse,
Expand Down Expand Up @@ -58,7 +57,7 @@ def _register_workflows(self):
def _select_filters(
self,
auth_user: Any,
search_settings: Union[VectorSearchSettings, KGSearchSettings],
search_settings: Union[SearchSettings, KGSearchSettings],
) -> dict[str, Any]:
selected_collections = {
str(cid) for cid in set(search_settings.selected_collection_ids)
Expand Down Expand Up @@ -111,8 +110,8 @@ async def search_documents(
query: str = Body(
..., description=search_descriptions.get("query")
),
settings: DocumentSearchSettings = Body(
default_factory=DocumentSearchSettings,
settings: SearchSettings = Body(
default_factory=SearchSettings,
description="Settings for document search",
),
auth_user=Depends(self.service.providers.auth.auth_wrapper),
Expand All @@ -127,8 +126,14 @@ async def search_documents(
Allowed operators include `eq`, `neq`, `gt`, `gte`, `lt`, `lte`, `like`, `ilike`, `in`, and `nin`.
"""

query_embedding = (
await self.service.providers.embedding.async_get_embedding(
query
)
)
results = await self.service.search_documents(
query=query,
query_embedding=query_embedding,
settings=settings,
)
return results
Expand All @@ -142,8 +147,8 @@ async def search_app(
query: str = Body(
..., description=search_descriptions.get("query")
),
vector_search_settings: VectorSearchSettings = Body(
default_factory=VectorSearchSettings,
vector_search_settings: SearchSettings = Body(
default_factory=SearchSettings,
description=search_descriptions.get("vector_search_settings"),
),
kg_search_settings: KGSearchSettings = Body(
Expand Down Expand Up @@ -187,8 +192,8 @@ async def search_app(
@self.base_endpoint
async def rag_app(
query: str = Body(..., description=rag_descriptions.get("query")),
vector_search_settings: VectorSearchSettings = Body(
default_factory=VectorSearchSettings,
vector_search_settings: SearchSettings = Body(
default_factory=SearchSettings,
description=rag_descriptions.get("vector_search_settings"),
),
kg_search_settings: KGSearchSettings = Body(
Expand Down Expand Up @@ -261,8 +266,8 @@ async def agent_app(
description=agent_descriptions.get("messages"),
deprecated=True,
),
vector_search_settings: VectorSearchSettings = Body(
default_factory=VectorSearchSettings,
vector_search_settings: SearchSettings = Body(
default_factory=SearchSettings,
description=agent_descriptions.get("vector_search_settings"),
),
kg_search_settings: KGSearchSettings = Body(
Expand Down
10 changes: 9 additions & 1 deletion py/core/main/orchestration/hatchet/ingestion_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
import uuid
from typing import TYPE_CHECKING
from uuid import UUID
from fastapi import HTTPException

from fastapi import HTTPException
from hatchet_sdk import ConcurrencyLimitStrategy, Context
from litellm import AuthenticationError

Expand Down Expand Up @@ -103,6 +103,14 @@ async def parse(self, context: Context) -> dict:
# document_info_dict = context.step_output("parse")["document_info"]
# document_info = DocumentInfo(**document_info_dict)

await service.update_document_status(
document_info, status=IngestionStatus.AUGMENTING
)
await service.augment_document_info(
document_info,
[extraction.to_dict() for extraction in extractions],
)

await self.ingestion_service.update_document_status(
document_info,
status=IngestionStatus.EMBEDDING,
Expand Down
7 changes: 6 additions & 1 deletion py/core/main/orchestration/simple/ingestion_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
import logging
from uuid import UUID

from fastapi import HTTPException
from litellm import AuthenticationError

from fastapi import HTTPException
from core.base import DocumentExtraction, R2RException, increment_version
from core.utils import (
generate_default_user_collection_id,
Expand Down Expand Up @@ -44,6 +44,11 @@ async def ingest_files(input_data):
async for extraction in extractions_generator
]

await service.update_document_status(
document_info, status=IngestionStatus.AUGMENTING
)
await service.augment_document_info(document_info, extractions)

await service.update_document_status(
document_info, status=IngestionStatus.EMBEDDING
)
Expand Down
Loading

0 comments on commit c3a0273

Please sign in to comment.