Skip to content

Commit

Permalink
up
Browse files Browse the repository at this point in the history
  • Loading branch information
emrgnt-cmplxty committed Nov 28, 2024
1 parent db50ef1 commit d454e88
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 25 deletions.
4 changes: 2 additions & 2 deletions py/core/main/api/v3/graph_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -1391,12 +1391,12 @@ async def pull(
)
if has_document:
logger.info(
f"Document {document.id} is already in graph {collection_id}, skipping"
f"Document {document.id} is already in graph {collection_id}, skipping."
)
continue
if len(entities[0]) == 0:
logger.warning(
f"Document {document.id} has no entities, extraction may not have been called"
f"Document {document.id} has no entities, extraction may not have been called, skipping."
)
continue

Expand Down
24 changes: 11 additions & 13 deletions py/core/main/orchestration/simple/kg_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,30 +53,28 @@ async def extract_triples(input_data):
offset = 0
while True:
# Fetch current batch
batch = (
await service.providers.database.collections_handler.documents_in_collection(
collection_id=collection_id,
offset=offset,
limit=batch_size,
)
)["results"]

batch = (await service.providers.database.collections_handler.documents_in_collection(
collection_id=collection_id,
offset=offset,
limit=batch_size
))["results"]

# If no documents returned, we've reached the end
if not batch:
break

# Add current batch to results
documents.extend(batch)

# Update offset for next batch
offset += batch_size

# Optional: If batch is smaller than batch_size, we've reached the end
if len(batch) < batch_size:
break

# documents = service.providers.database.collections_handler.documents_in_collection(input_data.get("collection_id"), offset=0, limit=1000)
print("extracting for documents = ", documents)
print('extracting for documents = ', documents)
document_ids = [document.id for document in documents]

logger.info(
Expand All @@ -91,7 +89,7 @@ async def extract_triples(input_data):
document_id=document_id,
**input_data["kg_creation_settings"],
):
print("extraction = ", extraction)
print('found extraction w/ entities = = ', len(extraction.entities))
extractions.append(extraction)
await service.store_kg_extractions(extractions)

Expand Down
26 changes: 16 additions & 10 deletions py/core/main/services/kg_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -1074,15 +1074,19 @@ async def _extract_kg(
"relation_types": "\n".join(relation_types),
},
)
print('starting a job....')

for attempt in range(retries):
try:
print('getting a response....')

response = await self.providers.llm.aget_completion(
messages,
generation_config=generation_config,
)

kg_extraction = response.choices[0].message.content
print('kg_extraction = ', kg_extraction)

if not kg_extraction:
raise R2RException(
Expand Down Expand Up @@ -1111,6 +1115,7 @@ async def parse_fn(response_str: str) -> Any:
relationships = re.findall(
relationship_pattern, response_str
)
print('found len(relationships) = ', len(relationships))

entities_arr = []
for entity in entities:
Expand All @@ -1133,6 +1138,7 @@ async def parse_fn(response_str: str) -> Any:
attributes={},
)
)
print('found len(entities) = ', len(entities))

relations_arr = []
for relationship in relationships:
Expand Down Expand Up @@ -1180,13 +1186,13 @@ async def parse_fn(response_str: str) -> Any:
if attempt < retries - 1:
await asyncio.sleep(delay)
else:
logger.error(
print(
f"Failed after retries with for chunk {chunks[0].id} of document {chunks[0].document_id}: {e}"
)
# raise e # you should raise an error.
# add metadata to entities and relationships

logger.info(
print(
f"KGExtractionPipe: Completed task number {task_id} of {total_tasks} for document {chunks[0].document_id}",
)

Expand All @@ -1205,13 +1211,15 @@ async def store_kg_extractions(

total_entities, total_relationships = 0, 0

print('received len(kg_extractions) = ', len(kg_extractions))
for extraction in kg_extractions:
print("extraction = ", extraction)
# print("extraction = ", extraction)

total_entities, total_relationships = (
total_entities + len(extraction.entities),
total_relationships + len(extraction.relationships),
)
# total_entities, total_relationships = (
# total_entities + len(extraction.entities),
# total_relationships + len(extraction.relationships),
# )
print('storing len(extraction.entities) = ', len(extraction.entities))

if extraction.entities:
await self.providers.database.graph_handler.entities.create(
Expand All @@ -1221,6 +1229,4 @@ async def store_kg_extractions(
if extraction.relationships:
await self.providers.database.graph_handler.relationships.create(
extraction.relationships, store_type="document"
)

return (total_entities, total_relationships)
)
1 change: 1 addition & 0 deletions py/core/providers/database/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ async def create(
"""

for value in values:
print('inserting len(values) into graph = ', len(values))
result = await self.connection_manager.fetchrow_query(QUERY, value)
results.append(result["id"])

Expand Down

0 comments on commit d454e88

Please sign in to comment.