Skip to content

Commit

Permalink
Merge branch 'main' into register-workflows
Browse files Browse the repository at this point in the history
  • Loading branch information
natoverse authored Feb 13, 2025
2 parents 1f0c5e4 + 5ef2399 commit a64b4c3
Show file tree
Hide file tree
Showing 7 changed files with 208 additions and 326 deletions.
4 changes: 4 additions & 0 deletions .semversioner/next-release/patch-20250213222251109897.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"type": "patch",
"description": "Optimize data iteration by removing some iterrows from code"
}
2 changes: 1 addition & 1 deletion graphrag/index/operations/create_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,6 @@ def create_graph(

if nodes is not None:
nodes.set_index(node_id, inplace=True)
graph.add_nodes_from((n, dict(d)) for n, d in nodes.iterrows())
graph.add_nodes_from(nodes.to_dict("index").items())

return graph
85 changes: 0 additions & 85 deletions graphrag/index/operations/snapshot_rows.py

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -89,12 +89,12 @@ async def get_summarized(

node_futures = [
do_summarize_descriptions(
str(row[1]["title"]),
sorted(set(row[1]["description"])),
str(row.title), # type: ignore
sorted(set(row.description)), # type: ignore
ticker,
semaphore,
)
for row in nodes.iterrows()
for row in nodes.itertuples(index=False)
]

node_results = await asyncio.gather(*node_futures)
Expand All @@ -109,12 +109,12 @@ async def get_summarized(

edge_futures = [
do_summarize_descriptions(
(str(row[1]["source"]), str(row[1]["target"])),
sorted(set(row[1]["description"])),
(str(row.source), str(row.target)), # type: ignore
sorted(set(row.description)), # type: ignore
ticker,
semaphore,
)
for row in edges.iterrows()
for row in edges.itertuples(index=False)
]

edge_results = await asyncio.gather(*edge_futures)
Expand Down
9 changes: 6 additions & 3 deletions graphrag/index/update/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,11 +119,12 @@ async def _run_entity_summarization(

# Prepare tasks for async summarization where needed
async def process_row(row):
description = row["description"]
# Accessing attributes directly from the named tuple.
description = row.description
if isinstance(description, list) and len(description) > 1:
# Run entity summarization asynchronously
result = await run_entity_summarization(
row["title"],
row.title,
description,
callbacks,
cache,
Expand All @@ -134,7 +135,9 @@ async def process_row(row):
return description[0] if isinstance(description, list) else description

# Create a list of async tasks for summarization
tasks = [process_row(row) for _, row in entities_df.iterrows()]
tasks = [
process_row(row) for row in entities_df.itertuples(index=False, name="Entity")
]
results = await asyncio.gather(*tasks)

# Update the 'description' column in the DataFrame
Expand Down
104 changes: 62 additions & 42 deletions graphrag/query/input/loaders/dfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,16 @@
)


def _prepare_records(df: pd.DataFrame) -> list[dict]:
"""
Reset index and convert the DataFrame to a list of dictionaries.
We rename the reset index column to 'Index' for consistency.
"""
df_reset = df.reset_index().rename(columns={"index": "Index"})
return df_reset.to_dict("records")


def read_entities(
df: pd.DataFrame,
id_col: str = "id",
Expand All @@ -35,12 +45,14 @@ def read_entities(
rank_col: str | None = "degree",
attributes_cols: list[str] | None = None,
) -> list[Entity]:
"""Read entities from a dataframe."""
entities = []
for idx, row in df.iterrows():
entity = Entity(
"""Read entities from a dataframe using pre-converted records."""
records = _prepare_records(df)
return [
Entity(
id=to_str(row, id_col),
short_id=to_optional_str(row, short_id_col) if short_id_col else str(idx),
short_id=to_optional_str(row, short_id_col)
if short_id_col
else str(row["Index"]),
title=to_str(row, title_col),
type=to_optional_str(row, type_col),
description=to_optional_str(row, description_col),
Expand All @@ -57,8 +69,8 @@ def read_entities(
else None
),
)
entities.append(entity)
return entities
for row in records
]


def read_relationships(
Expand All @@ -74,12 +86,14 @@ def read_relationships(
text_unit_ids_col: str | None = "text_unit_ids",
attributes_cols: list[str] | None = None,
) -> list[Relationship]:
"""Read relationships from a dataframe."""
relationships = []
for idx, row in df.iterrows():
rel = Relationship(
"""Read relationships from a dataframe using pre-converted records."""
records = _prepare_records(df)
return [
Relationship(
id=to_str(row, id_col),
short_id=to_optional_str(row, short_id_col) if short_id_col else str(idx),
short_id=to_optional_str(row, short_id_col)
if short_id_col
else str(row["Index"]),
source=to_str(row, source_col),
target=to_str(row, target_col),
description=to_optional_str(row, description_col),
Expand All @@ -95,8 +109,8 @@ def read_relationships(
else None
),
)
relationships.append(rel)
return relationships
for row in records
]


def read_covariates(
Expand All @@ -108,12 +122,14 @@ def read_covariates(
text_unit_ids_col: str | None = "text_unit_ids",
attributes_cols: list[str] | None = None,
) -> list[Covariate]:
"""Read covariates from a dataframe."""
covariates = []
for idx, row in df.iterrows():
cov = Covariate(
"""Read covariates from a dataframe using pre-converted records."""
records = _prepare_records(df)
return [
Covariate(
id=to_str(row, id_col),
short_id=to_optional_str(row, short_id_col) if short_id_col else str(idx),
short_id=to_optional_str(row, short_id_col)
if short_id_col
else str(row["Index"]),
subject_id=to_str(row, subject_col),
covariate_type=(
to_str(row, covariate_type_col) if covariate_type_col else "claim"
Expand All @@ -125,8 +141,8 @@ def read_covariates(
else None
),
)
covariates.append(cov)
return covariates
for row in records
]


def read_communities(
Expand All @@ -141,12 +157,14 @@ def read_communities(
sub_communities_col: str | None = "sub_community_ids",
attributes_cols: list[str] | None = None,
) -> list[Community]:
"""Read communities from a dataframe."""
communities = []
for idx, row in df.iterrows():
comm = Community(
"""Read communities from a dataframe using pre-converted records."""
records = _prepare_records(df)
return [
Community(
id=to_str(row, id_col),
short_id=to_optional_str(row, short_id_col) if short_id_col else str(idx),
short_id=to_optional_str(row, short_id_col)
if short_id_col
else str(row["Index"]),
title=to_str(row, title_col),
level=to_str(row, level_col),
entity_ids=to_optional_list(row, entities_col, item_type=str),
Expand All @@ -161,8 +179,8 @@ def read_communities(
else None
),
)
communities.append(comm)
return communities
for row in records
]


def read_community_reports(
Expand All @@ -177,12 +195,14 @@ def read_community_reports(
content_embedding_col: str | None = "full_content_embedding",
attributes_cols: list[str] | None = None,
) -> list[CommunityReport]:
"""Read community reports from a dataframe."""
reports = []
for idx, row in df.iterrows():
report = CommunityReport(
"""Read community reports from a dataframe using pre-converted records."""
records = _prepare_records(df)
return [
CommunityReport(
id=to_str(row, id_col),
short_id=to_optional_str(row, short_id_col) if short_id_col else str(idx),
short_id=to_optional_str(row, short_id_col)
if short_id_col
else str(row["Index"]),
title=to_str(row, title_col),
community_id=to_str(row, community_col),
summary=to_str(row, summary_col),
Expand All @@ -197,8 +217,8 @@ def read_community_reports(
else None
),
)
reports.append(report)
return reports
for row in records
]


def read_text_units(
Expand All @@ -212,12 +232,12 @@ def read_text_units(
document_ids_col: str | None = "document_ids",
attributes_cols: list[str] | None = None,
) -> list[TextUnit]:
"""Read text units from a dataframe."""
text_units = []
for idx, row in df.iterrows():
chunk = TextUnit(
"""Read text units from a dataframe using pre-converted records."""
records = _prepare_records(df)
return [
TextUnit(
id=to_str(row, id_col),
short_id=str(idx),
short_id=str(row["Index"]),
text=to_str(row, text_col),
entity_ids=to_optional_list(row, entities_col, item_type=str),
relationship_ids=to_optional_list(row, relationships_col, item_type=str),
Expand All @@ -232,5 +252,5 @@ def read_text_units(
else None
),
)
text_units.append(chunk)
return text_units
for row in records
]
Loading

0 comments on commit a64b4c3

Please sign in to comment.