Skip to content

Commit

Permalink
Merge branch 'patch/alternative-up' into Nolan/AltPatchSDK
Browse files: browse the repository at this point in the history
  • Loading branch information
NolanTrem authored Nov 29, 2024
2 parents 8b5dbb2 + d454e88 commit 19eca5c
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 22 deletions.
4 changes: 2 additions & 2 deletions py/core/main/api/v3/graph_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -1638,12 +1638,12 @@ async def pull(
)
if has_document:
logger.info(
f"Document {document.id} is already in graph {collection_id}, skipping"
f"Document {document.id} is already in graph {collection_id}, skipping."
)
continue
if len(entities[0]) == 0:
logger.warning(
f"Document {document.id} has no entities, extraction may not have been called"
f"Document {document.id} has no entities, extraction may not have been called, skipping."
)
continue

Expand Down
22 changes: 10 additions & 12 deletions py/core/main/orchestration/simple/kg_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,30 +53,28 @@ async def extract_triples(input_data):
offset = 0
while True:
# Fetch current batch
batch = (
await service.providers.database.collections_handler.documents_in_collection(
collection_id=collection_id,
offset=offset,
limit=batch_size,
)
)["results"]

batch = (await service.providers.database.collections_handler.documents_in_collection(
collection_id=collection_id,
offset=offset,
limit=batch_size
))["results"]

# If no documents returned, we've reached the end
if not batch:
break

# Add current batch to results
documents.extend(batch)

# Update offset for next batch
offset += batch_size

# Optional: If batch is smaller than batch_size, we've reached the end
if len(batch) < batch_size:
break

# documents = service.providers.database.collections_handler.documents_in_collection(input_data.get("collection_id"), offset=0, limit=1000)
print("extracting for documents = ", documents)
print('extracting for documents = ', documents)
document_ids = [document.id for document in documents]

logger.info(
Expand Down
24 changes: 16 additions & 8 deletions py/core/main/services/kg_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -1109,15 +1109,19 @@ async def _extract_kg(
"relation_types": "\n".join(relation_types),
},
)
print('starting a job....')

for attempt in range(retries):
try:
print('getting a response....')

response = await self.providers.llm.aget_completion(
messages,
generation_config=generation_config,
)

kg_extraction = response.choices[0].message.content
print('kg_extraction = ', kg_extraction)

if not kg_extraction:
raise R2RException(
Expand Down Expand Up @@ -1146,6 +1150,7 @@ async def parse_fn(response_str: str) -> Any:
relationships = re.findall(
relationship_pattern, response_str
)
print('found len(relationships) = ', len(relationships))

entities_arr = []
for entity in entities:
Expand All @@ -1168,6 +1173,7 @@ async def parse_fn(response_str: str) -> Any:
attributes={},
)
)
print('found len(entities) = ', len(entities))

relations_arr = []
for relationship in relationships:
Expand Down Expand Up @@ -1215,13 +1221,13 @@ async def parse_fn(response_str: str) -> Any:
if attempt < retries - 1:
await asyncio.sleep(delay)
else:
logger.error(
print(
f"Failed after retries with for chunk {chunks[0].id} of document {chunks[0].document_id}: {e}"
)
# raise e  # TODO: re-raise here so that a failure on the final retry is not silently swallowed.
# add metadata to entities and relationships

logger.info(
print(
f"KGExtractionPipe: Completed task number {task_id} of {total_tasks} for document {chunks[0].document_id}",
)

Expand All @@ -1240,11 +1246,15 @@ async def store_kg_extractions(

total_entities, total_relationships = 0, 0

print('received len(kg_extractions) = ', len(kg_extractions))
for extraction in kg_extractions:
total_entities, total_relationships = (
total_entities + len(extraction.entities),
total_relationships + len(extraction.relationships),
)
# print("extraction = ", extraction)

# total_entities, total_relationships = (
# total_entities + len(extraction.entities),
# total_relationships + len(extraction.relationships),
# )
print('storing len(extraction.entities) = ', len(extraction.entities))

for entity in extraction.entities:
await self.providers.database.graph_handler.entities.create(
Expand Down Expand Up @@ -1273,5 +1283,3 @@ async def store_kg_extractions(
metadata=relationship.metadata,
store_type="document", # type: ignore
)

return (total_entities, total_relationships)

0 comments on commit 19eca5c

Please sign in to comment.