Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Community children #1704

Merged
merged 13 commits into from
Feb 14, 2025
4 changes: 4 additions & 0 deletions .semversioner/next-release/major-20250213175726371530.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"type": "major",
"description": "Add children to communities to avoid re-compute."
}
36 changes: 30 additions & 6 deletions docs/examples_notebooks/index_migration_to_v2.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
"execution_count": 41,
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -25,17 +25,17 @@
},
{
"cell_type": "code",
"execution_count": 42,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# This is the directory that has your settings.yaml\n",
"PROJECT_DIRECTORY = \"<your project directory\""
"PROJECT_DIRECTORY = \"<your project directory>\""
]
},
{
"cell_type": "code",
"execution_count": 43,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -54,7 +54,7 @@
},
{
"cell_type": "code",
"execution_count": 44,
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -65,7 +65,7 @@
},
{
"cell_type": "code",
"execution_count": 45,
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -96,6 +96,30 @@
" final_nodes.loc[:, [\"id\", \"degree\", \"x\", \"y\"]].groupby(\"id\").first().reset_index()\n",
")\n",
"final_entities = final_entities.merge(graph_props, on=\"id\", how=\"left\")\n",
"# we're also persisting the frequency column\n",
"final_entities[\"frequency\"] = final_entities[\"text_unit_ids\"].count()\n",
"\n",
"\n",
"# we added children to communities to eliminate query-time reconstruction\n",
"parent_grouped = final_communities.groupby(\"parent\").agg(\n",
" children=(\"community\", \"unique\")\n",
")\n",
"final_communities = final_communities.merge(\n",
" parent_grouped,\n",
" left_on=\"community\",\n",
" right_on=\"parent\",\n",
" how=\"left\",\n",
")\n",
"\n",
"# add children to the reports as well\n",
"final_community_reports = final_community_reports.merge(\n",
" parent_grouped,\n",
" left_on=\"community\",\n",
" right_on=\"parent\",\n",
" how=\"left\",\n",
")\n",
"\n",
"# copy children into the reports as well\n",
"\n",
"# we renamed all the output files for better clarity now that we don't have workflow naming constraints from DataShaper\n",
"await write_table_to_storage(final_documents, \"documents\", storage)\n",
Expand Down
21 changes: 19 additions & 2 deletions graphrag/index/flows/create_communities.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
"""All the steps to transform final communities."""

from datetime import datetime, timezone
from typing import cast
from uuid import uuid4

import numpy as np
import pandas as pd

from graphrag.index.operations.cluster_graph import cluster_graph
Expand Down Expand Up @@ -92,7 +94,21 @@ def create_communities(
str
)
final_communities["parent"] = final_communities["parent"].astype(int)

# collect the children so we have a tree going both ways
parent_grouped = cast(
"pd.DataFrame",
final_communities.groupby("parent").agg(children=("community", "unique")),
)
final_communities = final_communities.merge(
parent_grouped,
left_on="community",
right_on="parent",
how="left",
)
# replace NaN children with empty list
final_communities["children"] = final_communities["children"].apply(
lambda x: x if isinstance(x, np.ndarray) else [] # type: ignore
)
# add fields for incremental update tracking
final_communities["period"] = datetime.now(timezone.utc).date().isoformat()
final_communities["size"] = final_communities.loc[:, "entity_ids"].apply(len)
Expand All @@ -103,8 +119,9 @@ def create_communities(
"id",
"human_readable_id",
"community",
"parent",
"level",
"parent",
"children",
"title",
"entity_ids",
"relationship_ids",
Expand Down
1 change: 1 addition & 0 deletions graphrag/index/flows/create_community_reports.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ async def create_community_reports(

community_reports = await summarize_communities(
nodes,
communities,
local_contexts,
build_level_context,
callbacks,
Expand Down
1 change: 1 addition & 0 deletions graphrag/index/flows/create_community_reports_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ async def create_community_reports_text(

community_reports = await summarize_communities(
nodes,
communities,
local_contexts,
build_level_context,
callbacks,
Expand Down
7 changes: 4 additions & 3 deletions graphrag/index/operations/finalize_community_reports.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@ def finalize_community_reports(
communities: pd.DataFrame,
) -> pd.DataFrame:
"""All the steps to transform final community reports."""
# Merge with communities to add size and period
# Merge with communities to add shared fields
community_reports = reports.merge(
communities.loc[:, ["community", "parent", "size", "period"]],
communities.loc[:, ["community", "parent", "children", "size", "period"]],
on="community",
how="left",
copy=False,
Expand All @@ -31,8 +31,9 @@ def finalize_community_reports(
"id",
"human_readable_id",
"community",
"parent",
"level",
"parent",
"children",
"title",
"summary",
"full_content",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
)
from graphrag.index.operations.summarize_communities.utils import (
get_levels,
restore_community_hierarchy,
)
from graphrag.index.run.derive_from_rows import derive_from_rows
from graphrag.logger.progress import progress_ticker
Expand All @@ -30,6 +29,7 @@

async def summarize_communities(
nodes: pd.DataFrame,
communities: pd.DataFrame,
local_contexts,
level_context_builder: Callable,
callbacks: WorkflowCallbacks,
Expand All @@ -49,7 +49,12 @@ async def summarize_communities(
if strategy_config.get("llm") and strategy_config["llm"]["max_retries"] == -1:
strategy_config["llm"]["max_retries"] = len(nodes)

community_hierarchy = restore_community_hierarchy(nodes)
community_hierarchy = (
communities.explode("children")
.rename({"children": "sub_community"}, axis=1)
.loc[:, ["community", "level", "sub_community"]]
).dropna()

levels = get_levels(nodes)

level_contexts = []
Expand Down
47 changes: 0 additions & 47 deletions graphrag/index/operations/summarize_communities/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@

"""A module containing community report generation utilities."""

from itertools import pairwise

import pandas as pd

import graphrag.model.schemas as schemas
Expand All @@ -17,48 +15,3 @@ def get_levels(
levels = df[level_column].dropna().unique()
levels = [int(lvl) for lvl in levels if lvl != -1]
return sorted(levels, reverse=True)


def restore_community_hierarchy(
    input: pd.DataFrame,
    name_column: str = schemas.TITLE,
    community_column: str = schemas.COMMUNITY_ID,
    level_column: str = schemas.COMMUNITY_LEVEL,
) -> pd.DataFrame:
    """Restore the community hierarchy from the node data.

    Reconstructs parent/child community links by comparing node membership:
    a community at the next (finer) level is considered a sub-community of a
    community at the current level when its node-title set is a subset of the
    parent's node-title set.

    Parameters
    ----------
    input:
        Node-level DataFrame with one row per (node, community, level).
    name_column:
        Column holding the node title used for membership comparison.
    community_column:
        Column holding the community id.
    level_column:
        Column holding the hierarchy level (smaller = coarser).

    Returns
    -------
    pd.DataFrame
        One row per (parent community, sub-community) pair with columns:
        community id, level, sub-community id, and sub-community size.
        Empty if fewer than two levels are present.
    """
    # Group by community and level, aggregating node titles into a set
    # (sets enable the subset test below).
    community_df = (
        input.groupby([community_column, level_column])[name_column]
        .apply(set)
        .reset_index()
    )

    # Build a nested dictionary: level -> {community id -> set of node titles}
    community_levels = {
        level: group.set_index(community_column)[name_column].to_dict()
        for level, group in community_df.groupby(level_column)
    }

    # get unique levels, sorted in ascending order (coarsest first)
    levels = sorted(community_levels.keys())  # type: ignore
    community_hierarchy = []

    # Iterate through adjacent levels; pairwise yields (coarser, finer) pairs.
    for current_level, next_level in pairwise(levels):
        current_communities = community_levels[current_level]
        next_communities = community_levels[next_level]

        # Find sub-communities via set containment. NOTE: this is an
        # O(m * n) comparison per level pair (all pairs of communities),
        # which is why the indexing pipeline now persists children instead.
        for curr_comm, curr_entities in current_communities.items():
            for next_comm, next_entities in next_communities.items():
                if next_entities.issubset(curr_entities):
                    community_hierarchy.append({
                        community_column: curr_comm,
                        schemas.COMMUNITY_LEVEL: current_level,
                        schemas.SUB_COMMUNITY: next_comm,
                        schemas.SUB_COMMUNITY_SIZE: len(next_entities),
                    })

    return pd.DataFrame(
        community_hierarchy,
    )
19 changes: 12 additions & 7 deletions graphrag/model/community.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,15 @@
class Community(Named):
"""A protocol for a community in the system."""

level: str = ""
level: str
"""Community level."""

parent: str
"""Community ID of the parent node of this community."""

children: list[str]
"""List of community IDs of the child nodes of this community."""

entity_ids: list[str] | None = None
"""List of entity IDs related to the community (optional)."""

Expand All @@ -25,9 +31,6 @@ class Community(Named):
covariate_ids: dict[str, list[str]] | None = None
"""Dictionary of different types of covariates related to the community (optional), e.g. claims"""

sub_community_ids: list[str] | None = None
"""List of community IDs of the child nodes of this community (optional)."""

attributes: dict[str, Any] | None = None
"""A dictionary of additional attributes associated with the community (optional). To be included in the search prompt."""

Expand All @@ -48,7 +51,8 @@ def from_dict(
entities_key: str = "entity_ids",
relationships_key: str = "relationship_ids",
covariates_key: str = "covariate_ids",
sub_communities_key: str = "sub_community_ids",
parent_key: str = "parent",
children_key: str = "children",
attributes_key: str = "attributes",
size_key: str = "size",
period_key: str = "period",
Expand All @@ -57,12 +61,13 @@ def from_dict(
return Community(
id=d[id_key],
title=d[title_key],
short_id=d.get(short_id_key),
level=d[level_key],
parent=d[parent_key],
children=d[children_key],
short_id=d.get(short_id_key),
entity_ids=d.get(entities_key),
relationship_ids=d.get(relationships_key),
covariate_ids=d.get(covariates_key),
sub_community_ids=d.get(sub_communities_key),
attributes=d.get(attributes_key),
size=d.get(size_key),
period=d.get(period_key),
Expand Down
1 change: 0 additions & 1 deletion graphrag/model/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@

# COMMUNITY HIERARCHY TABLE SCHEMA
SUB_COMMUNITY = "sub_community"
SUB_COMMUNITY_SIZE = "sub_community_size"
COMMUNITY_LEVEL = "level"

# COMMUNITY CONTEXT TABLE SCHEMA
Expand Down
32 changes: 8 additions & 24 deletions graphrag/query/context_builder/dynamic_community_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,23 +56,7 @@ def __init__(
self.llm_kwargs = llm_kwargs

self.reports = {report.community_id: report for report in community_reports}
# mapping from community to sub communities
self.node2children = {
community.short_id: (
[]
if community.sub_community_ids is None
else [str(x) for x in community.sub_community_ids]
)
for community in communities
if community.short_id is not None
}

# mapping from community to parent community
self.node2parent: dict[str, str] = {
sub_community: community
for community, sub_communities in self.node2children.items()
for sub_community in sub_communities
}
self.communities = {community.short_id: community for community in communities}

# mapping from level to communities
self.levels: dict[str, list[str]] = {}
Expand Down Expand Up @@ -140,18 +124,18 @@ async def select(self, query: str) -> tuple[list[CommunityReport], dict[str, Any
relevant_communities.add(community)
# find children nodes of the current node and append them to the queue
# TODO check why some sub_communities are NOT in report_df
if community in self.node2children:
for sub_community in self.node2children[community]:
if sub_community in self.reports:
communities_to_rate.append(sub_community)
if community in self.communities:
for child in self.communities[community].children:
if child in self.reports:
communities_to_rate.append(child)
else:
log.debug(
"dynamic community selection: cannot find community %s in reports",
sub_community,
child,
)
# remove parent node if the current node is deemed relevant
if not self.keep_parent and community in self.node2parent:
relevant_communities.discard(self.node2parent[community])
if not self.keep_parent and community in self.communities:
relevant_communities.discard(self.communities[community].parent)
queue = communities_to_rate
level += 1
if (
Expand Down
Loading
Loading