Skip to content

Commit

Permalink
Merge main
Browse files Browse the repository at this point in the history
  • Loading branch information
NolanTrem committed Nov 13, 2024
1 parent f47af4c commit 1858873
Show file tree
Hide file tree
Showing 15 changed files with 52 additions and 63 deletions.
2 changes: 1 addition & 1 deletion py/core/base/providers/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -1401,7 +1401,7 @@ async def search_documents(
query_text: str,
settings: SearchSettings,
query_embedding: Optional[list[float]] = None,
) -> list[DocumentInfo]:
) -> list[DocumentResponse]:
return await self.document_handler.search_documents(
query_text, query_embedding, settings
)
Expand Down
9 changes: 4 additions & 5 deletions py/core/main/api/v2/management_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,16 @@
WrappedAnalyticsResponse,
WrappedAppSettingsResponse,
WrappedBooleanResponse,
WrappedChunksResponse,
WrappedCollectionResponse,
WrappedCollectionsResponse,
WrappedConversationResponse,
WrappedConversationsResponse,
WrappedDeleteResponse,
WrappedChunksResponse,
WrappedGenericMessageResponse,
WrappedBooleanResponse,
WrappedDocumentsResponse,
WrappedPromptsResponse,
WrappedGenericMessageResponse,
WrappedLogResponse,
WrappedPromptsResponse,
WrappedServerStatsResponse,
WrappedUsersResponse,
)
Expand Down Expand Up @@ -335,7 +334,7 @@ async def documents_overview_app(
document_ids: list[str] = Query([]),
offset: int = Query(0, ge=0),
limit: int = Query(
1_000,
100,
ge=-1,
description="Number of items to return. Use -1 to return all items.",
),
Expand Down
8 changes: 4 additions & 4 deletions py/core/main/api/v3/chunks_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
RunType,
UnprocessedChunk,
UpdateChunk,
VectorSearchSettings,
SearchSettings,
ChunkResponse,
)
from core.base.api.models import (
Expand Down Expand Up @@ -53,7 +53,7 @@ def __init__(
def _select_filters(
self,
auth_user: Any,
search_settings: VectorSearchSettings | KGSearchSettings,
search_settings: SearchSettings | KGSearchSettings,
) -> dict[str, Any]:
selected_collections = {
str(cid) for cid in set(search_settings.selected_collection_ids)
Expand Down Expand Up @@ -302,8 +302,8 @@ async def create_chunks(
@self.base_endpoint
async def search_chunks(
query: str = Body(...),
vector_search_settings: VectorSearchSettings = Body(
default_factory=VectorSearchSettings,
vector_search_settings: SearchSettings = Body(
default_factory=SearchSettings,
),
auth_user=Depends(self.providers.auth.auth_wrapper),
) -> WrappedVectorSearchResponse: # type: ignore
Expand Down
16 changes: 8 additions & 8 deletions py/core/main/api/v3/retrieval_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
KGSearchSettings,
Message,
R2RException,
VectorSearchSettings,
SearchSettings,
)
from core.base.api.models import (
WrappedCompletionResponse,
Expand Down Expand Up @@ -46,7 +46,7 @@ def _register_workflows(self):
def _select_filters(
self,
auth_user: Any,
search_settings: VectorSearchSettings | KGSearchSettings,
search_settings: SearchSettings | KGSearchSettings,
) -> dict[str, Any]:
selected_collections = {
str(cid) for cid in set(search_settings.selected_collection_ids)
Expand Down Expand Up @@ -174,8 +174,8 @@ async def search_app(
...,
description="Search query to find relevant documents",
),
vector_search_settings: VectorSearchSettings = Body(
default_factory=VectorSearchSettings,
vector_search_settings: SearchSettings = Body(
default_factory=SearchSettings,
description="Settings for vector-based search",
),
kg_search_settings: KGSearchSettings = Body(
Expand Down Expand Up @@ -287,8 +287,8 @@ async def search_app(
@self.base_endpoint
async def rag_app(
query: str = Body(...),
vector_search_settings: VectorSearchSettings = Body(
default_factory=VectorSearchSettings,
vector_search_settings: SearchSettings = Body(
default_factory=SearchSettings,
description="Settings for vector-based search",
),
kg_search_settings: KGSearchSettings = Body(
Expand Down Expand Up @@ -434,8 +434,8 @@ async def agent_app(
deprecated=True,
description="List of messages (deprecated, use message instead)",
),
vector_search_settings: VectorSearchSettings = Body(
default_factory=VectorSearchSettings,
vector_search_settings: SearchSettings = Body(
default_factory=SearchSettings,
description="Settings for vector-based search",
),
kg_search_settings: KGSearchSettings = Body(
Expand Down
1 change: 0 additions & 1 deletion py/core/main/app.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from core.base import R2RException
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.openapi.utils import get_openapi
Expand Down
3 changes: 1 addition & 2 deletions py/core/main/orchestration/simple/ingestion_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@
from fastapi import HTTPException
from litellm import AuthenticationError

from fastapi import HTTPException
from core.base import R2RException, DocumentChunk, increment_version
from core.base import DocumentChunk, R2RException, increment_version
from core.utils import (
generate_default_user_collection_id,
generate_extraction_id,
Expand Down
7 changes: 4 additions & 3 deletions py/core/main/services/ingestion_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import logging
import uuid
from datetime import datetime
from typing import Any, AsyncGenerator, Optional, Sequence, Union
from typing import Any, AsyncGenerator, Optional, Sequence
from uuid import UUID

from fastapi import HTTPException
Expand Down Expand Up @@ -234,7 +234,7 @@ async def parse_file(

async def augment_document_info(
self,
document_info: DocumentInfo,
document_info: DocumentResponse,
chunked_documents: list[dict],
) -> None:
if not self.config.ingestion.skip_document_summary:
Expand All @@ -253,6 +253,7 @@ async def augment_document_info(
task_prompt_name=self.config.ingestion.document_summary_task_prompt,
task_inputs={"document": document},
)
# FIXME: Why are we hardcoding the model here?
response = await self.providers.llm.aget_completion(
messages=messages,
generation_config=GenerationConfig(model="openai/gpt-4o-mini"),
Expand Down Expand Up @@ -286,7 +287,7 @@ async def embed_document(

async def store_embeddings(
self,
embeddings: Sequence[Union[dict, VectorEntry]],
embeddings: Sequence[dict | VectorEntry],
) -> AsyncGenerator[str, None]:
vector_entries = [
(
Expand Down
4 changes: 2 additions & 2 deletions py/core/main/services/retrieval_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from core import R2RStreamingRAGAgent
from core.base import (
DocumentInfo,
DocumentResponse,
EmbeddingPurpose,
GenerationConfig,
KGSearchSettings,
Expand Down Expand Up @@ -128,7 +128,7 @@ async def search_documents(
query: str,
settings: SearchSettings,
query_embedding: Optional[list[float]] = None,
) -> list[DocumentInfo]:
) -> list[DocumentResponse]:

return await self.providers.database.search_documents(
query_text=query,
Expand Down
3 changes: 2 additions & 1 deletion py/core/pipes/kg/triples_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,8 @@ async def _run_logic( # type: ignore

# sort the extractions according to the chunk_order field in metadata, in ascending order
extractions = sorted(
extractions, key=lambda x: x.metadata.get("chunk_order", float('inf'))
extractions,
key=lambda x: x.metadata.get("chunk_order", float("inf")),
)

# group these extractions into groups of extraction_merge_count
Expand Down
2 changes: 1 addition & 1 deletion py/core/providers/auth/r2r_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@
from core.base import (
AuthConfig,
AuthProvider,
CollectionResponse,
CryptoProvider,
DatabaseProvider,
EmailProvider,
R2RException,
Token,
TokenData,
CollectionResponse,
)
from core.base.api.models import UserResponse

Expand Down
6 changes: 2 additions & 4 deletions py/core/providers/database/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from typing import Any, Optional
from uuid import UUID, uuid4

from fastapi import HTTPException
from asyncpg.exceptions import UniqueViolationError
from fastapi import HTTPException

from core.base import (
CollectionHandler,
Expand All @@ -19,9 +19,7 @@
IngestionStatus,
)
from core.base.api.models import CollectionResponse
from core.utils import (
generate_default_user_collection_id,
)
from core.utils import generate_default_user_collection_id

from .base import PostgresConnectionManager

Expand Down
4 changes: 2 additions & 2 deletions py/core/providers/database/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -776,7 +776,7 @@ async def hybrid_document_search(
search_settings.hybrid_search_settings.full_text_weight
)

for doc_id, scores in doc_scores.items():
for scores in doc_scores.values():
semantic_score = 1 / (rrf_k + scores["semantic_rank"])
full_text_score = 1 / (rrf_k + scores["full_text_rank"])

Expand Down Expand Up @@ -851,7 +851,7 @@ async def search_documents(

# TODO - Remove copy pasta, consolidate
def _build_filters(
self, filters: dict, parameters: list[Union[str, int, bytes]]
self, filters: dict, parameters: list[str | int | bytes]
) -> str:

def parse_condition(key: str, value: Any) -> str: # type: ignore
Expand Down
6 changes: 3 additions & 3 deletions py/core/providers/database/vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,13 @@
IndexArgsIVFFlat,
IndexMeasure,
IndexMethod,
R2RException,
SearchSettings,
VectorEntry,
VectorHandler,
VectorQuantizationType,
VectorSearchResult,
VectorTableName,
R2RException,
)

from .base import PostgresConnectionManager
Expand Down Expand Up @@ -1335,15 +1335,15 @@ async def list_chunks(
async def search_documents(
self,
query_text: str,
settings: DocumentSearchSettings,
settings: SearchSettings,
) -> list[dict[str, Any]]:
"""
Search for documents based on their metadata fields and/or body text.
Joins with document_info table to get complete document metadata.
Args:
query_text (str): The search query text
settings (DocumentSearchSettings): Search settings including search preferences and filters
settings (SearchSettings): Search settings including search preferences and filters
Returns:
list[dict[str, Any]]: List of documents with their search scores and complete metadata
Expand Down
24 changes: 11 additions & 13 deletions py/sdk/v2/mixins/retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,10 @@
from typing import AsyncGenerator, Optional

from ...models import (
DocumentSearchSettings,
GenerationConfig,
KGSearchSettings,
Message,
RAGResponse,
SearchResponse,
SearchSettings,
)

Expand All @@ -18,8 +16,8 @@ class RetrievalMixins:
async def search_documents(
self,
query: str,
settings: Optional[dict | DocumentSearchSettings] = None,
) -> SearchResponse:
settings: Optional[dict] = None,
):
"""
Conduct a vector and/or KG search.
Expand All @@ -43,9 +41,9 @@ async def search_documents(
async def search(
self,
query: str,
vector_search_settings: Optional[Union[dict, SearchSettings]] = None,
kg_search_settings: Optional[Union[dict, KGSearchSettings]] = None,
) -> CombinedSearchResponse:
vector_search_settings: Optional[dict | SearchSettings] = None,
kg_search_settings: Optional[dict | KGSearchSettings] = None,
):
"""
Conduct a vector and/or KG search.
Expand Down Expand Up @@ -73,8 +71,8 @@ async def search(

async def completion(
self,
messages: list[Union[dict, Message]],
generation_config: Optional[Union[dict, GenerationConfig]] = None,
messages: list[dict | Message],
generation_config: Optional[dict | GenerationConfig] = None,
):
cast_messages: list[Message] = [
Message(**msg) if isinstance(msg, dict) else msg
Expand All @@ -94,12 +92,12 @@ async def completion(
async def rag(
self,
query: str,
rag_generation_config: Optional[Union[dict, GenerationConfig]] = None,
vector_search_settings: Optional[Union[dict, SearchSettings]] = None,
kg_search_settings: Optional[Union[dict, KGSearchSettings]] = None,
rag_generation_config: Optional[dict | GenerationConfig] = None,
vector_search_settings: Optional[dict | SearchSettings] = None,
kg_search_settings: Optional[dict | KGSearchSettings] = None,
task_prompt_override: Optional[str] = None,
include_title_if_available: Optional[bool] = False,
) -> Union[RAGResponse, AsyncGenerator[RAGResponse, None]]:
) -> RAGResponse | AsyncGenerator[RAGResponse, None]:
"""
Conducts a Retrieval Augmented Generation (RAG) search with the given query.
Expand Down
Loading

0 comments on commit 1858873

Please sign in to comment.