Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Chore/remove iterrows #1708

Merged
merged 6 commits into from
Feb 13, 2025
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .semversioner/next-release/patch-20250213222251109897.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"type": "patch",
"description": "Optimize data iteration by removing some iterrows from code"
}
2 changes: 1 addition & 1 deletion graphrag/index/operations/create_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,6 @@ def create_graph(

if nodes is not None:
nodes.set_index(node_id, inplace=True)
graph.add_nodes_from((n, dict(d)) for n, d in nodes.iterrows())
graph.add_nodes_from(nodes.to_dict("index").items())

return graph
85 changes: 0 additions & 85 deletions graphrag/index/operations/snapshot_rows.py

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -89,12 +89,12 @@ async def get_summarized(

node_futures = [
do_summarize_descriptions(
str(row[1]["title"]),
sorted(set(row[1]["description"])),
str(row.title), # type: ignore
sorted(set(row.description)), # type: ignore
ticker,
semaphore,
)
for row in nodes.iterrows()
for row in nodes.itertuples(index=False)
]

node_results = await asyncio.gather(*node_futures)
Expand All @@ -109,12 +109,12 @@ async def get_summarized(

edge_futures = [
do_summarize_descriptions(
(str(row[1]["source"]), str(row[1]["target"])),
sorted(set(row[1]["description"])),
(str(row.source), str(row.target)), # type: ignore
sorted(set(row.description)), # type: ignore
ticker,
semaphore,
)
for row in edges.iterrows()
for row in edges.itertuples(index=False)
]

edge_results = await asyncio.gather(*edge_futures)
Expand Down
9 changes: 6 additions & 3 deletions graphrag/index/update/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,11 +119,12 @@ async def _run_entity_summarization(

# Prepare tasks for async summarization where needed
async def process_row(row):
description = row["description"]
# Accessing attributes directly from the named tuple.
description = row.description
if isinstance(description, list) and len(description) > 1:
# Run entity summarization asynchronously
result = await run_entity_summarization(
row["title"],
row.title,
description,
callbacks,
cache,
Expand All @@ -134,7 +135,9 @@ async def process_row(row):
return description[0] if isinstance(description, list) else description

# Create a list of async tasks for summarization
tasks = [process_row(row) for _, row in entities_df.iterrows()]
tasks = [
process_row(row) for row in entities_df.itertuples(index=False, name="Entity")
]
results = await asyncio.gather(*tasks)

# Update the 'description' column in the DataFrame
Expand Down
104 changes: 62 additions & 42 deletions graphrag/query/input/loaders/dfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,16 @@
)


def _prepare_records(df: pd.DataFrame) -> list[dict]:
    """
    Convert the DataFrame to a list of row dictionaries tagged with their index label.

    Each record maps column name -> value for one row and additionally carries
    that row's index label under the 'Index' key. The 'Index' key always holds
    the index label, taking precedence over any data column of the same name.

    Args:
        df: The DataFrame to convert. It is not modified.

    Returns:
        One dictionary per row, in row order; an empty list for an empty frame.
    """
    records = df.to_dict("records")
    # Attach the label directly rather than via reset_index().rename(...):
    # reset_index only emits a column called "index" for an unnamed,
    # single-level index with no existing "index" column. A named index or a
    # MultiIndex would leave records without an 'Index' key, and a data column
    # called "index" would be relabelled by mistake.
    for label, record in zip(df.index, records):
        record["Index"] = label
    return records


def read_entities(
df: pd.DataFrame,
id_col: str = "id",
Expand All @@ -35,12 +45,14 @@ def read_entities(
rank_col: str | None = "degree",
attributes_cols: list[str] | None = None,
) -> list[Entity]:
"""Read entities from a dataframe."""
entities = []
for idx, row in df.iterrows():
entity = Entity(
"""Read entities from a dataframe using pre-converted records."""
records = _prepare_records(df)
return [
Entity(
id=to_str(row, id_col),
short_id=to_optional_str(row, short_id_col) if short_id_col else str(idx),
short_id=to_optional_str(row, short_id_col)
if short_id_col
else str(row["Index"]),
title=to_str(row, title_col),
type=to_optional_str(row, type_col),
description=to_optional_str(row, description_col),
Expand All @@ -57,8 +69,8 @@ def read_entities(
else None
),
)
entities.append(entity)
return entities
for row in records
]


def read_relationships(
Expand All @@ -74,12 +86,14 @@ def read_relationships(
text_unit_ids_col: str | None = "text_unit_ids",
attributes_cols: list[str] | None = None,
) -> list[Relationship]:
"""Read relationships from a dataframe."""
relationships = []
for idx, row in df.iterrows():
rel = Relationship(
"""Read relationships from a dataframe using pre-converted records."""
records = _prepare_records(df)
return [
Relationship(
id=to_str(row, id_col),
short_id=to_optional_str(row, short_id_col) if short_id_col else str(idx),
short_id=to_optional_str(row, short_id_col)
if short_id_col
else str(row["Index"]),
source=to_str(row, source_col),
target=to_str(row, target_col),
description=to_optional_str(row, description_col),
Expand All @@ -95,8 +109,8 @@ def read_relationships(
else None
),
)
relationships.append(rel)
return relationships
for row in records
]


def read_covariates(
Expand All @@ -108,12 +122,14 @@ def read_covariates(
text_unit_ids_col: str | None = "text_unit_ids",
attributes_cols: list[str] | None = None,
) -> list[Covariate]:
"""Read covariates from a dataframe."""
covariates = []
for idx, row in df.iterrows():
cov = Covariate(
"""Read covariates from a dataframe using pre-converted records."""
records = _prepare_records(df)
return [
Covariate(
id=to_str(row, id_col),
short_id=to_optional_str(row, short_id_col) if short_id_col else str(idx),
short_id=to_optional_str(row, short_id_col)
if short_id_col
else str(row["Index"]),
subject_id=to_str(row, subject_col),
covariate_type=(
to_str(row, covariate_type_col) if covariate_type_col else "claim"
Expand All @@ -125,8 +141,8 @@ def read_covariates(
else None
),
)
covariates.append(cov)
return covariates
for row in records
]


def read_communities(
Expand All @@ -141,12 +157,14 @@ def read_communities(
sub_communities_col: str | None = "sub_community_ids",
attributes_cols: list[str] | None = None,
) -> list[Community]:
"""Read communities from a dataframe."""
communities = []
for idx, row in df.iterrows():
comm = Community(
"""Read communities from a dataframe using pre-converted records."""
records = _prepare_records(df)
return [
Community(
id=to_str(row, id_col),
short_id=to_optional_str(row, short_id_col) if short_id_col else str(idx),
short_id=to_optional_str(row, short_id_col)
if short_id_col
else str(row["Index"]),
title=to_str(row, title_col),
level=to_str(row, level_col),
entity_ids=to_optional_list(row, entities_col, item_type=str),
Expand All @@ -161,8 +179,8 @@ def read_communities(
else None
),
)
communities.append(comm)
return communities
for row in records
]


def read_community_reports(
Expand All @@ -177,12 +195,14 @@ def read_community_reports(
content_embedding_col: str | None = "full_content_embedding",
attributes_cols: list[str] | None = None,
) -> list[CommunityReport]:
"""Read community reports from a dataframe."""
reports = []
for idx, row in df.iterrows():
report = CommunityReport(
"""Read community reports from a dataframe using pre-converted records."""
records = _prepare_records(df)
return [
CommunityReport(
id=to_str(row, id_col),
short_id=to_optional_str(row, short_id_col) if short_id_col else str(idx),
short_id=to_optional_str(row, short_id_col)
if short_id_col
else str(row["Index"]),
title=to_str(row, title_col),
community_id=to_str(row, community_col),
summary=to_str(row, summary_col),
Expand All @@ -197,8 +217,8 @@ def read_community_reports(
else None
),
)
reports.append(report)
return reports
for row in records
]


def read_text_units(
Expand All @@ -212,12 +232,12 @@ def read_text_units(
document_ids_col: str | None = "document_ids",
attributes_cols: list[str] | None = None,
) -> list[TextUnit]:
"""Read text units from a dataframe."""
text_units = []
for idx, row in df.iterrows():
chunk = TextUnit(
"""Read text units from a dataframe using pre-converted records."""
records = _prepare_records(df)
return [
TextUnit(
id=to_str(row, id_col),
short_id=str(idx),
short_id=str(row["Index"]),
text=to_str(row, text_col),
entity_ids=to_optional_list(row, entities_col, item_type=str),
relationship_ids=to_optional_list(row, relationships_col, item_type=str),
Expand All @@ -232,5 +252,5 @@ def read_text_units(
else None
),
)
text_units.append(chunk)
return text_units
for row in records
]
Loading
Loading