diff --git a/py/core/main/api/v3/graph_router.py b/py/core/main/api/v3/graph_router.py index 9600bf8a1..610cb00a3 100644 --- a/py/core/main/api/v3/graph_router.py +++ b/py/core/main/api/v3/graph_router.py @@ -1391,12 +1391,12 @@ async def pull( ) if has_document: logger.info( - f"Document {document.id} is already in graph {collection_id}, skipping" + f"Document {document.id} is already in graph {collection_id}, skipping." ) continue if len(entities[0]) == 0: logger.warning( - f"Document {document.id} has no entities, extraction may not have been called" + f"Document {document.id} has no entities, extraction may not have been called, skipping." ) continue diff --git a/py/core/main/orchestration/simple/kg_workflow.py b/py/core/main/orchestration/simple/kg_workflow.py index ff21a9f82..ea06a02ea 100644 --- a/py/core/main/orchestration/simple/kg_workflow.py +++ b/py/core/main/orchestration/simple/kg_workflow.py @@ -53,30 +53,28 @@ async def extract_triples(input_data): offset = 0 while True: # Fetch current batch - batch = ( - await service.providers.database.collections_handler.documents_in_collection( - collection_id=collection_id, - offset=offset, - limit=batch_size, - ) - )["results"] - + batch = (await service.providers.database.collections_handler.documents_in_collection( + collection_id=collection_id, + offset=offset, + limit=batch_size + ))["results"] + # If no documents returned, we've reached the end if not batch: break - + # Add current batch to results documents.extend(batch) - + # Update offset for next batch offset += batch_size - + # Optional: If batch is smaller than batch_size, we've reached the end if len(batch) < batch_size: break # documents = service.providers.database.collections_handler.documents_in_collection(input_data.get("collection_id"), offset=0, limit=1000) - print("extracting for documents = ", documents) + print('extracting for documents = ', documents) document_ids = [document.id for document in documents] logger.info( @@ -91,7 +89,7 @@ async def extract_triples(input_data): document_id=document_id, **input_data["kg_creation_settings"], ): - print("extraction = ", extraction) + print('found extraction w/ entities = = ', len(extraction.entities)) extractions.append(extraction) await service.store_kg_extractions(extractions) diff --git a/py/core/main/services/kg_service.py b/py/core/main/services/kg_service.py index 4f7951a00..5b217875d 100644 --- a/py/core/main/services/kg_service.py +++ b/py/core/main/services/kg_service.py @@ -1074,15 +1074,19 @@ async def _extract_kg( "relation_types": "\n".join(relation_types), }, ) + print('starting a job....') for attempt in range(retries): try: + print('getting a response....') + response = await self.providers.llm.aget_completion( messages, generation_config=generation_config, ) kg_extraction = response.choices[0].message.content + print('kg_extraction = ', kg_extraction) if not kg_extraction: raise R2RException( @@ -1111,6 +1115,7 @@ async def parse_fn(response_str: str) -> Any: relationships = re.findall( relationship_pattern, response_str ) + print('found len(relationships) = ', len(relationships)) entities_arr = [] for entity in entities: @@ -1133,6 +1138,7 @@ async def parse_fn(response_str: str) -> Any: attributes={}, ) ) + print('found len(entities) = ', len(entities)) relations_arr = [] for relationship in relationships: @@ -1180,13 +1186,13 @@ async def parse_fn(response_str: str) -> Any: if attempt < retries - 1: await asyncio.sleep(delay) else: - logger.error( + print( f"Failed after retries with for chunk {chunks[0].id} of document {chunks[0].document_id}: {e}" ) # raise e # you should raise an error. # add metadata to entities and relationships - logger.info( + print( f"KGExtractionPipe: Completed task number {task_id} of {total_tasks} for document {chunks[0].document_id}", ) @@ -1205,13 +1211,15 @@ async def store_kg_extractions( total_entities, total_relationships = 0, 0 + print('received len(kg_extractions) = ', len(kg_extractions)) for extraction in kg_extractions: - print("extraction = ", extraction) + # print("extraction = ", extraction) - total_entities, total_relationships = ( - total_entities + len(extraction.entities), - total_relationships + len(extraction.relationships), - ) + # total_entities, total_relationships = ( + # total_entities + len(extraction.entities), + # total_relationships + len(extraction.relationships), + # ) + print('storing len(extraction.entities) = ', len(extraction.entities)) if extraction.entities: await self.providers.database.graph_handler.entities.create( @@ -1221,6 +1229,4 @@ async def store_kg_extractions( if extraction.relationships: await self.providers.database.graph_handler.relationships.create( extraction.relationships, store_type="document" - ) - - return (total_entities, total_relationships) + ) \ No newline at end of file diff --git a/py/core/providers/database/graph.py b/py/core/providers/database/graph.py index 84f206f01..1f1b14d81 100644 --- a/py/core/providers/database/graph.py +++ b/py/core/providers/database/graph.py @@ -158,6 +158,7 @@ async def create( """ for value in values: + print('inserting len(values) into graph = ', len(values)) result = await self.connection_manager.fetchrow_query(QUERY, value) results.append(result["id"])