Skip to content

Commit

Permalink
Refactor document-chunk handling: remove the `database.handle` indirection (call provider methods directly), split `assign_document_to_collection` into explicit `_relational` and `_vector` variants, and rename `vector_db_dimension` to `dimension`.
Browse files Browse the repository at this point in the history
  • Loading branch information
emrgnt-cmplxty committed Oct 18, 2024
1 parent 681d276 commit 8bca097
Show file tree
Hide file tree
Showing 21 changed files with 253 additions and 289 deletions.
3 changes: 0 additions & 3 deletions py/core/base/providers/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,6 @@ class DatabaseProvider(Provider):
def __init__(self, config: DatabaseConfig):
logger.info(f"Initializing DatabaseProvider with config {config}.")

self.handle: Any = (
None # TODO - Type this properly, we later use it as a PostgresHandle
)
super().__init__(config)

@abstractmethod
Expand Down
1 change: 1 addition & 0 deletions py/core/main/api/management_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,7 @@ async def document_chunks_app(
"total_entries": document_chunks["total_entries"]
}


@self.router.get("/collections_overview")
@self.base_endpoint
async def collections_overview_app(
Expand Down
4 changes: 2 additions & 2 deletions py/core/main/assembly/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ async def create_database_provider(
"Embedding config must have a base dimension to initialize database."
)

vector_db_dimension = self.config.embedding.base_dimension
dimension = self.config.embedding.base_dimension
quantization_type = (
self.config.embedding.quantization_settings.quantization_type
)
Expand All @@ -156,7 +156,7 @@ async def create_database_provider(

database_provider = PostgresDBProvider(
db_config,
vector_db_dimension,
dimension,
crypto_provider=crypto_provider,
quantization_type=quantization_type,
)
Expand Down
29 changes: 14 additions & 15 deletions py/core/main/orchestration/hatchet/ingestion_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,14 +150,13 @@ async def parse(self, context: Context) -> dict:
status=IngestionStatus.SUCCESS,
)

collection_id = await service.providers.database.handle.assign_document_to_collection(
# TODO: Move logic onto the `management service`
collection_id=generate_default_user_collection_id(str(document_info.user_id))
await service.providers.database.assign_document_to_collection_relational(
document_id=document_info.id,
collection_id=generate_default_user_collection_id(
document_info.user_id
),
collection_id=collection_id,
)

service.providers.database.handle.assign_document_to_collection(
await service.providers.database.assign_document_to_collection_vector(
document_id=document_info.id, collection_id=collection_id
)

Expand Down Expand Up @@ -189,7 +188,7 @@ async def on_failure(self, context: Context) -> None:

try:
documents_overview = (
await self.ingestion_service.providers.database.handle.get_documents_overview(
await self.ingestion_service.providers.database.get_documents_overview(
filter_document_ids=[document_id]
)
)["results"]
Expand Down Expand Up @@ -248,7 +247,7 @@ async def update_files(self, context: Context) -> None:
)

documents_overview = (
await self.ingestion_service.providers.database.handle.get_documents_overview(
await self.ingestion_service.providers.database.get_documents_overview(
filter_document_ids=document_ids,
filter_user_ids=None if user.is_superuser else [user.id],
)
Expand Down Expand Up @@ -400,13 +399,13 @@ async def finalize(self, context: Context) -> dict:
)

try:
collection_id = await self.ingestion_service.providers.database.handle.assign_document_to_collection(
# TODO - Move logic onto the `management service`
collection_id = generate_default_user_collection_id(document_info.user_id)
await self.ingestion_service.providers.database.assign_document_to_collection_relational(
document_id=document_info.id,
collection_id=generate_default_user_collection_id(
document_info.user_id
),
collection_id=collection_id,
)
self.ingestion_service.providers.database.handle.assign_document_to_collection(
await self.ingestion_service.providers.database.assign_document_to_collection_vector(
document_id=document_info.id, collection_id=collection_id
)
except Exception as e:
Expand All @@ -432,7 +431,7 @@ async def on_failure(self, context: Context) -> None:

try:
documents_overview = (
await self.ingestion_service.providers.database.handle.get_documents_overview(
await self.ingestion_service.providers.database.get_documents_overview(
filter_document_ids=[document_id]
)
)["results"]
Expand Down Expand Up @@ -474,7 +473,7 @@ async def create_vector_index(self, context: Context) -> dict:
)
)

self.ingestion_service.providers.database.handle.create_index(
self.ingestion_service.providers.database.create_index(
**parsed_data
)

Expand Down
2 changes: 1 addition & 1 deletion py/core/main/orchestration/hatchet/kg_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ async def on_failure(self, context: Context) -> None:
return

try:
await self.kg_service.providers.database.handle.set_workflow_status(
await self.kg_service.providers.database.set_workflow_status(
id=uuid.UUID(document_id),
status_type="kg_extraction_status",
status=KGExtractionStatus.FAILED,
Expand Down
26 changes: 13 additions & 13 deletions py/core/main/orchestration/simple/ingestion_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,14 +67,14 @@ async def ingest_files(input_data):
)

try:
collection_id = await service.providers.database.handle.assign_document_to_collection(
# TODO - Move logic onto management service
collection_id=generate_default_user_collection_id(str(document_info.user_id))
await service.providers.database.assign_document_to_collection_relational(
document_id=document_info.id,
collection_id=generate_default_user_collection_id(
str(document_info.user_id)
),
collection_id=collection_id,
)
service.providers.database.handle.assign_document_to_collection(
document_id=document_info.id, collection_id=collection_id
await service.providers.database.assign_document_to_collection_vector(
document_info.id, collection_id
)
except Exception as e:
logger.error(
Expand Down Expand Up @@ -125,7 +125,7 @@ async def update_files(input_data):
)

documents_overview = (
await service.providers.database.handle.get_documents_overview(
await service.providers.database.get_documents_overview(
filter_document_ids=document_ids,
filter_user_ids=None if user.is_superuser else [user.id],
)
Expand Down Expand Up @@ -227,13 +227,13 @@ async def ingest_chunks(input_data):
)

try:
collection_id = await service.providers.database.handle.assign_document_to_collection(
# TODO - Move logic onto management service
collection_id=generate_default_user_collection_id(str(document_info.user_id))
await service.providers.database.assign_document_to_collection_relational(
document_id=document_info.id,
collection_id=generate_default_user_collection_id(
str(document_info.user_id)
),
collection_id=collection_id,
)
service.providers.database.handle.assign_document_to_collection(
await service.providers.database.assign_document_to_collection_vector(
document_id=document_info.id, collection_id=collection_id
)
except Exception as e:
Expand Down Expand Up @@ -262,7 +262,7 @@ async def create_vector_index(input_data):
)
)

service.providers.database.handle.create_index(**parsed_data)
service.providers.database.create_index(**parsed_data)

except Exception as e:
raise R2RException(
Expand Down
22 changes: 11 additions & 11 deletions py/core/main/services/auth_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,22 +45,22 @@ async def verify_email(
status_code=400, message="Email verification is not required"
)

user_id = await self.providers.database.handle.get_user_id_by_verification_code(
user_id = await self.providers.database.get_user_id_by_verification_code(
verification_code
)
if not user_id:
raise R2RException(
status_code=400, message="Invalid or expired verification code"
)

user = await self.providers.database.handle.get_user_by_id(user_id)
user = await self.providers.database.get_user_by_id(user_id)
if not user or user.email != email:
raise R2RException(
status_code=400, message="Invalid or expired verification code"
)

await self.providers.database.handle.mark_user_as_verified(user_id)
await self.providers.database.handle.remove_verification_code(
await self.providers.database.mark_user_as_verified(user_id)
await self.providers.database.remove_verification_code(
verification_code
)
return {"message": f"User account {user_id} verified successfully."}
Expand All @@ -72,7 +72,7 @@ async def login(self, email: str, password: str) -> dict[str, Token]:
@telemetry_event("GetCurrentUser")
async def user(self, token: str) -> UserResponse:
token_data = await self.providers.auth.decode_token(token)
user = await self.providers.database.handle.get_user_by_email(
user = await self.providers.database.get_user_by_email(
token_data.email
)
if user is None:
Expand Down Expand Up @@ -124,7 +124,7 @@ async def update_user(
profile_picture: Optional[str] = None,
) -> UserResponse:
user: UserResponse = (
await self.providers.database.handle.get_user_by_id(str(user_id))
await self.providers.database.get_user_by_id(str(user_id))
)
if not user:
raise R2RException(status_code=404, message="User not found")
Expand All @@ -138,7 +138,7 @@ async def update_user(
user.bio = bio
if profile_picture is not None:
user.profile_picture = profile_picture
return await self.providers.database.handle.update_user(user)
return await self.providers.database.update_user(user)

@telemetry_event("DeleteUserAccount")
async def delete_user(
Expand All @@ -148,7 +148,7 @@ async def delete_user(
delete_vector_data: bool = False,
is_superuser: bool = False,
) -> dict[str, str]:
user = await self.providers.database.handle.get_user_by_id(user_id)
user = await self.providers.database.get_user_by_id(user_id)
if not user:
raise R2RException(status_code=404, message="User not found")
if not (
Expand All @@ -158,9 +158,9 @@ async def delete_user(
)
):
raise R2RException(status_code=400, message="Incorrect password")
await self.providers.database.handle.delete_user_relational(user_id)
await self.providers.database.delete_user_relational(user_id)
if delete_vector_data:
self.providers.database.handle.delete_user_vector(user_id)
self.providers.database.delete_user_vector(user_id)

return {"message": f"User account {user_id} deleted successfully."}

Expand All @@ -170,6 +170,6 @@ async def clean_expired_blacklisted_tokens(
max_age_hours: int = 7 * 24,
current_time: Optional[datetime] = None,
):
await self.providers.database.handle.clean_expired_blacklisted_tokens(
await self.providers.database.clean_expired_blacklisted_tokens(
max_age_hours, current_time
)
12 changes: 6 additions & 6 deletions py/core/main/services/ingestion_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ async def ingest_file_ingress(
)

existing_document_info = (
await self.providers.database.handle.get_documents_overview(
await self.providers.database.get_documents_overview(
filter_user_ids=[user.id],
filter_document_ids=[document_id],
)
Expand All @@ -120,7 +120,7 @@ async def ingest_file_ingress(
message=f"Document {document_id} was already ingested and is not in a failed state.",
)

await self.providers.database.handle.upsert_documents_overview(
await self.providers.database.upsert_documents_overview(
document_info
)

Expand Down Expand Up @@ -256,7 +256,7 @@ async def finalize_ingestion(
is_update: bool = False,
) -> None:
if is_update:
self.providers.database.handle.delete(
self.providers.database.delete(
filters={
"$and": [
{"document_id": {"$eq": document_info.id}},
Expand Down Expand Up @@ -284,7 +284,7 @@ async def update_document_status(

async def _update_document_status_in_db(self, document_info: DocumentInfo):
try:
await self.providers.database.handle.upsert_documents_overview(
await self.providers.database.upsert_documents_overview(
document_info
)
except Exception as e:
Expand Down Expand Up @@ -325,7 +325,7 @@ async def ingest_chunks_ingress(
)

existing_document_info = (
await self.providers.database.handle.get_documents_overview(
await self.providers.database.get_documents_overview(
filter_user_ids=[user.id],
filter_document_ids=[document_id],
)
Expand All @@ -339,7 +339,7 @@ async def ingest_chunks_ingress(
message=f"Document {document_id} was already ingested and is not in a failed state.",
)

await self.providers.database.handle.upsert_documents_overview(
await self.providers.database.upsert_documents_overview(
document_info
)

Expand Down
8 changes: 4 additions & 4 deletions py/core/main/services/kg_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ async def kg_triples_extraction(
f"KGService: Processing document {document_id} for KG extraction"
)

await self.providers.database.handle.set_workflow_status(
await self.providers.database.set_workflow_status(
id=document_id,
status_type="kg_extraction_status",
status=KGExtractionStatus.PROCESSING,
Expand Down Expand Up @@ -101,7 +101,7 @@ async def kg_triples_extraction(

except Exception as e:
logger.error(f"KGService: Error in kg_extraction: {e}")
await self.providers.database.handle.set_workflow_status(
await self.providers.database.set_workflow_status(
id=document_id,
status_type="kg_extraction_status",
status=KGExtractionStatus.FAILED,
Expand All @@ -128,7 +128,7 @@ async def get_document_ids_for_create_graph(
]

document_ids = (
await self.providers.database.handle.get_document_ids_by_status(
await self.providers.database.get_document_ids_by_status(
status_type="kg_extraction_status",
status=document_status_filter,
collection_id=collection_id,
Expand Down Expand Up @@ -195,7 +195,7 @@ async def kg_entity_description(
f"KGService: Completed kg_entity_description for batch {i+1}/{num_batches} for document {document_id}"
)

await self.providers.database.handle.set_workflow_status(
await self.providers.database.set_workflow_status(
id=document_id,
status_type="kg_extraction_status",
status=KGExtractionStatus.SUCCESS,
Expand Down
Loading

0 comments on commit 8bca097

Please sign in to comment.