diff --git a/.semversioner/next-release/major-20250213175726371530.json b/.semversioner/next-release/major-20250213175726371530.json
new file mode 100644
index 0000000000..33db379ee4
--- /dev/null
+++ b/.semversioner/next-release/major-20250213175726371530.json
@@ -0,0 +1,4 @@
+{
+    "type": "major",
+    "description": "Add children to communities to avoid re-compute."
+}
diff --git a/docs/examples_notebooks/index_migration_to_v2.ipynb b/docs/examples_notebooks/index_migration_to_v2.ipynb
index ed7300a6b5..863b8f1d7d 100644
--- a/docs/examples_notebooks/index_migration_to_v2.ipynb
+++ b/docs/examples_notebooks/index_migration_to_v2.ipynb
@@ -2,7 +2,7 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count": 41,
+   "execution_count": 5,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -25,17 +25,17 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 42,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
     "# This is the directory that has your settings.yaml\n",
     "PROJECT_DIRECTORY = \"\""
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 43,
+   "execution_count": 2,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -54,7 +54,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 44,
+   "execution_count": 3,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -65,7 +65,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 45,
+   "execution_count": 4,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -96,6 +96,29 @@
     "    final_nodes.loc[:, [\"id\", \"degree\", \"x\", \"y\"]].groupby(\"id\").first().reset_index()\n",
     ")\n",
     "final_entities = final_entities.merge(graph_props, on=\"id\", how=\"left\")\n",
+    "# we're also persisting the frequency column\n",
+    "final_entities[\"frequency\"] = final_entities[\"text_unit_ids\"].apply(len)\n",
+    "\n",
+    "\n",
+    "# we added children to communities to eliminate query-time reconstruction\n",
+    "parent_grouped = final_communities.groupby(\"parent\").agg(\n",
+    "    children=(\"community\", \"unique\")\n",
+    ")\n",
+    "final_communities = final_communities.merge(\n",
+    "    parent_grouped,\n",
+    "    left_on=\"community\",\n",
+    "    right_on=\"parent\",\n",
+    "    how=\"left\",\n",
+    ")\n",
+    "\n",
+    "# add children to the reports as well\n",
+    "final_community_reports = final_community_reports.merge(\n",
+    "    parent_grouped,\n",
+    "    left_on=\"community\",\n",
+    "    right_on=\"parent\",\n",
+    "    how=\"left\",\n",
+    ")\n",
     "\n",
     "# we renamed all the output files for better clarity now that we don't have workflow naming constraints from DataShaper\n",
     "await write_table_to_storage(final_documents, \"documents\", storage)\n",
diff --git a/graphrag/index/flows/create_communities.py b/graphrag/index/flows/create_communities.py
index 452a081a88..cf9abc8f1a 100644
--- a/graphrag/index/flows/create_communities.py
+++ b/graphrag/index/flows/create_communities.py
@@ -4,8 +4,10 @@
 """All the steps to transform final communities."""
 
 from datetime import datetime, timezone
+from typing import cast
 from uuid import uuid4
 
+import numpy as np
 import pandas as pd
 
 from graphrag.index.operations.cluster_graph import cluster_graph
@@ -92,7 +94,21 @@ def create_communities(
         str
     )
    final_communities["parent"] = final_communities["parent"].astype(int)
-
+    # collect the children so we have a tree going both ways
+    parent_grouped = cast(
+        "pd.DataFrame",
+        final_communities.groupby("parent").agg(children=("community", "unique")),
+    )
+    final_communities = final_communities.merge(
+        parent_grouped,
+        left_on="community",
+        right_on="parent",
+        how="left",
+    )
+    # replace NaN children with empty list
+    final_communities["children"] = final_communities["children"].apply(
+        lambda x: x if isinstance(x, np.ndarray) else []  # type: ignore
+    )
     # add fields for incremental update tracking
     final_communities["period"] = datetime.now(timezone.utc).date().isoformat()
     final_communities["size"] = final_communities.loc[:, "entity_ids"].apply(len)
@@ -103,8 +119,9 @@
             "id",
             "human_readable_id",
             "community",
-            "parent",
             "level",
+            "parent",
+            "children",
             "title",
             "entity_ids",
             "relationship_ids",
diff --git a/graphrag/index/flows/create_community_reports.py b/graphrag/index/flows/create_community_reports.py
index 3365692071..42f3c23a87 100644
--- a/graphrag/index/flows/create_community_reports.py
+++ b/graphrag/index/flows/create_community_reports.py
@@ -62,6 +62,7 @@ async def create_community_reports(
     community_reports = await summarize_communities(
         nodes,
+        communities,
         local_contexts,
         build_level_context,
         callbacks,
diff --git a/graphrag/index/flows/create_community_reports_text.py b/graphrag/index/flows/create_community_reports_text.py
index 1e2fd5ca97..c568f9d247 100644
--- a/graphrag/index/flows/create_community_reports_text.py
+++ b/graphrag/index/flows/create_community_reports_text.py
@@ -53,6 +53,7 @@ async def create_community_reports_text(
     community_reports = await summarize_communities(
         nodes,
+        communities,
         local_contexts,
         build_level_context,
         callbacks,
diff --git a/graphrag/index/operations/finalize_community_reports.py b/graphrag/index/operations/finalize_community_reports.py
index 7da1d6f305..8699730fb8 100644
--- a/graphrag/index/operations/finalize_community_reports.py
+++ b/graphrag/index/operations/finalize_community_reports.py
@@ -13,9 +13,9 @@ def finalize_community_reports(
     communities: pd.DataFrame,
 ) -> pd.DataFrame:
     """All the steps to transform final community reports."""
-    # Merge with communities to add size and period
+    # Merge with communities to add shared fields
     community_reports = reports.merge(
-        communities.loc[:, ["community", "parent", "size", "period"]],
+        communities.loc[:, ["community", "parent", "children", "size", "period"]],
         on="community",
         how="left",
         copy=False,
@@ -31,8 +31,9 @@
             "id",
             "human_readable_id",
             "community",
-            "parent",
             "level",
+            "parent",
+            "children",
             "title",
             "summary",
             "full_content",
diff --git a/graphrag/index/operations/summarize_communities/summarize_communities.py b/graphrag/index/operations/summarize_communities/summarize_communities.py
index 17ffa90436..6fcd3557d8 100644
--- a/graphrag/index/operations/summarize_communities/summarize_communities.py
+++ b/graphrag/index/operations/summarize_communities/summarize_communities.py
@@ -20,7 +20,6 @@
 )
 from graphrag.index.operations.summarize_communities.utils import (
     get_levels,
-    restore_community_hierarchy,
 )
 from graphrag.index.run.derive_from_rows import derive_from_rows
 from graphrag.logger.progress import progress_ticker
@@ -30,6 +29,7 @@
 
 async def summarize_communities(
     nodes: pd.DataFrame,
+    communities: pd.DataFrame,
     local_contexts,
     level_context_builder: Callable,
     callbacks: WorkflowCallbacks,
@@ -49,7 +49,12 @@
     if strategy_config.get("llm") and strategy_config["llm"]["max_retries"] == -1:
         strategy_config["llm"]["max_retries"] = len(nodes)
 
-    community_hierarchy = restore_community_hierarchy(nodes)
+    community_hierarchy = (
+        communities.explode("children")
+        .rename({"children": "sub_community"}, axis=1)
"sub_community"]] + ).dropna() + levels = get_levels(nodes) level_contexts = [] diff --git a/graphrag/index/operations/summarize_communities/utils.py b/graphrag/index/operations/summarize_communities/utils.py index 4767cf0971..04abded174 100644 --- a/graphrag/index/operations/summarize_communities/utils.py +++ b/graphrag/index/operations/summarize_communities/utils.py @@ -3,8 +3,6 @@ """A module containing community report generation utilities.""" -from itertools import pairwise - import pandas as pd import graphrag.model.schemas as schemas @@ -17,48 +15,3 @@ def get_levels( levels = df[level_column].dropna().unique() levels = [int(lvl) for lvl in levels if lvl != -1] return sorted(levels, reverse=True) - - -def restore_community_hierarchy( - input: pd.DataFrame, - name_column: str = schemas.TITLE, - community_column: str = schemas.COMMUNITY_ID, - level_column: str = schemas.COMMUNITY_LEVEL, -) -> pd.DataFrame: - """Restore the community hierarchy from the node data.""" - # Group by community and level, aggregate names as lists - community_df = ( - input.groupby([community_column, level_column])[name_column] - .apply(set) - .reset_index() - ) - - # Build dictionary with levels as integers - community_levels = { - level: group.set_index(community_column)[name_column].to_dict() - for level, group in community_df.groupby(level_column) - } - - # get unique levels, sorted in ascending order - levels = sorted(community_levels.keys()) # type: ignore - community_hierarchy = [] - - # Iterate through adjacent levels - for current_level, next_level in pairwise(levels): - current_communities = community_levels[current_level] - next_communities = community_levels[next_level] - - # Find sub-communities - for curr_comm, curr_entities in current_communities.items(): - for next_comm, next_entities in next_communities.items(): - if next_entities.issubset(curr_entities): - community_hierarchy.append({ - community_column: curr_comm, - schemas.COMMUNITY_LEVEL: current_level, - schemas.SUB_COMMUNITY: next_comm, - schemas.SUB_COMMUNITY_SIZE: len(next_entities), - }) - - return pd.DataFrame( - community_hierarchy, - ) diff --git a/graphrag/model/community.py b/graphrag/model/community.py index 43d6c4033a..3109531bb8 100644 --- a/graphrag/model/community.py +++ b/graphrag/model/community.py @@ -13,9 +13,15 @@ class Community(Named): """A protocol for a community in the system.""" - level: str = "" + level: str """Community level.""" + parent: str + """Community ID of the parent node of this community.""" + + children: list[str] + """List of community IDs of the child nodes of this community.""" + entity_ids: list[str] | None = None """List of entity IDs related to the community (optional).""" @@ -25,9 +31,6 @@ class Community(Named): covariate_ids: dict[str, list[str]] | None = None """Dictionary of different types of covariates related to the community (optional), e.g. claims""" - sub_community_ids: list[str] | None = None - """List of community IDs of the child nodes of this community (optional).""" - attributes: dict[str, Any] | None = None """A dictionary of additional attributes associated with the community (optional). 
     attributes: dict[str, Any] | None = None
     """A dictionary of additional attributes associated with the community (optional). To be included in the search prompt."""
 
@@ -48,7 +51,8 @@ def from_dict(
         entities_key: str = "entity_ids",
         relationships_key: str = "relationship_ids",
         covariates_key: str = "covariate_ids",
-        sub_communities_key: str = "sub_community_ids",
+        parent_key: str = "parent",
+        children_key: str = "children",
         attributes_key: str = "attributes",
         size_key: str = "size",
         period_key: str = "period",
@@ -57,12 +61,13 @@
         return Community(
             id=d[id_key],
             title=d[title_key],
-            short_id=d.get(short_id_key),
             level=d[level_key],
+            parent=d[parent_key],
+            children=d[children_key],
+            short_id=d.get(short_id_key),
             entity_ids=d.get(entities_key),
             relationship_ids=d.get(relationships_key),
             covariate_ids=d.get(covariates_key),
-            sub_community_ids=d.get(sub_communities_key),
             attributes=d.get(attributes_key),
             size=d.get(size_key),
             period=d.get(period_key),
diff --git a/graphrag/model/schemas.py b/graphrag/model/schemas.py
index dd10766ac1..1505c64e61 100644
--- a/graphrag/model/schemas.py
+++ b/graphrag/model/schemas.py
@@ -29,7 +29,6 @@
 
 # COMMUNITY HIERARCHY TABLE SCHEMA
 SUB_COMMUNITY = "sub_community"
-SUB_COMMUNITY_SIZE = "sub_community_size"
 COMMUNITY_LEVEL = "level"
 
 # COMMUNITY CONTEXT TABLE SCHEMA
diff --git a/graphrag/query/context_builder/dynamic_community_selection.py b/graphrag/query/context_builder/dynamic_community_selection.py
index 2a101f6ca3..ebd98e86ae 100644
--- a/graphrag/query/context_builder/dynamic_community_selection.py
+++ b/graphrag/query/context_builder/dynamic_community_selection.py
@@ -56,23 +56,7 @@ def __init__(
         self.llm_kwargs = llm_kwargs
         self.reports = {report.community_id: report for report in community_reports}
 
-        # mapping from community to sub communities
-        self.node2children = {
-            community.short_id: (
-                []
-                if community.sub_community_ids is None
-                else [str(x) for x in community.sub_community_ids]
-            )
-            for community in communities
-            if community.short_id is not None
-        }
-
-        # mapping from community to parent community
-        self.node2parent: dict[str, str] = {
-            sub_community: community
-            for community, sub_communities in self.node2children.items()
-            for sub_community in sub_communities
-        }
+        self.communities = {community.short_id: community for community in communities}
 
         # mapping from level to communities
         self.levels: dict[str, list[str]] = {}
@@ -140,18 +124,18 @@ async def select(self, query: str) -> tuple[list[CommunityReport], dict[str, Any
                     relevant_communities.add(community)
                 # find children nodes of the current node and append them to the queue
                 # TODO check why some sub_communities are NOT in report_df
-                if community in self.node2children:
-                    for sub_community in self.node2children[community]:
-                        if sub_community in self.reports:
-                            communities_to_rate.append(sub_community)
+                if community in self.communities:
+                    for child in self.communities[community].children:
+                        if child in self.reports:
+                            communities_to_rate.append(child)
                         else:
                             log.debug(
                                 "dynamic community selection: cannot find community %s in reports",
-                                sub_community,
+                                child,
                             )
                 # remove parent node if the current node is deemed relevant
-                if not self.keep_parent and community in self.node2parent:
-                    relevant_communities.discard(self.node2parent[community])
+                if not self.keep_parent and community in self.communities:
+                    relevant_communities.discard(self.communities[community].parent)
             queue = communities_to_rate
             level += 1
             if (
diff --git a/graphrag/query/indexer_adapters.py b/graphrag/query/indexer_adapters.py
index d7929ccced..ec12cd633c 100644
--- a/graphrag/query/indexer_adapters.py
+++ b/graphrag/query/indexer_adapters.py
@@ -12,9 +12,6 @@
 import pandas as pd
 
 from graphrag.config.models.graph_rag_config import GraphRagConfig
-from graphrag.index.operations.summarize_communities.utils import (
-    restore_community_hierarchy,
-)
 from graphrag.model.community import Community
 from graphrag.model.community_report import CommunityReport
 from graphrag.model.covariate import Covariate
@@ -197,27 +194,6 @@
     ]
     nodes_df = nodes_df.loc[nodes_df.community.isin(reports_df.community.unique())]
 
-    # reconstruct the community hierarchy
-    # note that restore_community_hierarchy only return communities with sub communities
-    community_hierarchy = restore_community_hierarchy(input=nodes_df)
-
-    # small datasets can result in hierarchies that are only one deep, so the hierarchy will have no rows
-    if not community_hierarchy.empty:
-        community_hierarchy = (
-            community_hierarchy.groupby(["community"])
-            .agg({"sub_community": list})
-            .reset_index()
-            .rename(columns={"sub_community": "sub_community_ids"})
-        )
-        # add sub community IDs to community DataFrame
-        communities_df = communities_df.merge(
-            community_hierarchy, on="community", how="left"
-        )
-        # replace NaN sub community IDs with empty list
-        communities_df.sub_community_ids = communities_df.sub_community_ids.apply(
-            lambda x: x if isinstance(x, list) else []
-        )
-
     return read_communities(
         communities_df,
         id_col="id",
@@ -227,7 +203,8 @@
         entities_col=None,
         relationships_col=None,
         covariates_col=None,
-        sub_communities_col="sub_community_ids",
+        parent_col="parent",
+        children_col="children",
         attributes_cols=None,
     )
diff --git a/graphrag/query/input/loaders/dfs.py b/graphrag/query/input/loaders/dfs.py
index 6df0e7e78d..8b0ce42b53 100644
--- a/graphrag/query/input/loaders/dfs.py
+++ b/graphrag/query/input/loaders/dfs.py
@@ -12,6 +12,7 @@
 from graphrag.model.relationship import Relationship
 from graphrag.model.text_unit import TextUnit
 from graphrag.query.input.loaders.utils import (
+    to_list,
     to_optional_dict,
     to_optional_float,
     to_optional_int,
@@ -154,7 +155,8 @@ def read_communities(
     entities_col: str | None = "entity_ids",
     relationships_col: str | None = "relationship_ids",
     covariates_col: str | None = "covariate_ids",
-    sub_communities_col: str | None = "sub_community_ids",
+    parent_col: str | None = "parent",
+    children_col: str | None = "children",
     attributes_cols: list[str] | None = None,
 ) -> list[Community]:
     """Read communities from a dataframe using pre-converted records."""
@@ -172,7 +174,8 @@
             covariate_ids=to_optional_dict(
                 row, covariates_col, key_type=str, value_type=str
             ),
-            sub_community_ids=to_optional_list(row, sub_communities_col),
+            parent=to_str(row, parent_col),
+            children=to_list(row, children_col),
             attributes=(
                 {col: row.get(col) for col in attributes_cols}
                 if attributes_cols
diff --git a/tests/verbs/data/communities.parquet b/tests/verbs/data/communities.parquet
index 7ba53b2cc0..a251c45277 100644
Binary files a/tests/verbs/data/communities.parquet and b/tests/verbs/data/communities.parquet differ
diff --git a/tests/verbs/data/community_reports.parquet b/tests/verbs/data/community_reports.parquet
index f42eab2d0c..15ec9d300b 100644
Binary files a/tests/verbs/data/community_reports.parquet and b/tests/verbs/data/community_reports.parquet differ
diff --git a/tests/verbs/data/covariates.parquet b/tests/verbs/data/covariates.parquet
index d75063c36b..63d5ddbca1 100644
Binary files a/tests/verbs/data/covariates.parquet and b/tests/verbs/data/covariates.parquet differ
diff --git a/tests/verbs/data/documents.parquet b/tests/verbs/data/documents.parquet
index 6902cc5328..b8a40153f6 100644
Binary files a/tests/verbs/data/documents.parquet and b/tests/verbs/data/documents.parquet differ
diff --git a/tests/verbs/data/entities.parquet b/tests/verbs/data/entities.parquet
index 5407156216..277a05582d 100644
Binary files a/tests/verbs/data/entities.parquet and b/tests/verbs/data/entities.parquet differ
diff --git a/tests/verbs/data/relationships.parquet b/tests/verbs/data/relationships.parquet
index 9862af3c41..bc4f52b2da 100644
Binary files a/tests/verbs/data/relationships.parquet and b/tests/verbs/data/relationships.parquet differ
diff --git a/tests/verbs/data/text_units.parquet b/tests/verbs/data/text_units.parquet
index 35f1983236..75e101581f 100644
Binary files a/tests/verbs/data/text_units.parquet and b/tests/verbs/data/text_units.parquet differ
diff --git a/tests/verbs/test_create_base_text_units.py b/tests/verbs/test_create_base_text_units.py
index 7ac752a4ed..68936f8839 100644
--- a/tests/verbs/test_create_base_text_units.py
+++ b/tests/verbs/test_create_base_text_units.py
@@ -11,6 +11,7 @@
     compare_outputs,
     create_test_context,
     load_test_table,
+    update_document_metadata,
 )
@@ -43,6 +44,8 @@ async def test_create_base_text_units_metadata():
     config.input.metadata = ["title"]
     config.chunks.prepend_metadata = True
 
+    await update_document_metadata(config.input.metadata, context)
+
     await run_workflow(
         config,
         context,
@@ -65,6 +68,8 @@
     config.chunks.prepend_metadata = True
     config.chunks.chunk_size_includes_metadata = True
 
+    await update_document_metadata(config.input.metadata, context)
+
     await run_workflow(
         config,
         context,
diff --git a/tests/verbs/test_create_final_documents.py b/tests/verbs/test_create_final_documents.py
index c7f2a25f21..78477ec719 100644
--- a/tests/verbs/test_create_final_documents.py
+++ b/tests/verbs/test_create_final_documents.py
@@ -13,6 +13,7 @@
     compare_outputs,
     create_test_context,
     load_test_table,
+    update_document_metadata,
 )
@@ -37,8 +38,6 @@ async def test_create_final_documents():
 
 
 async def test_create_final_documents_with_metadata_column():
-    expected = load_test_table("documents")
-
     context = await create_test_context(
         storage=["text_units"],
     )
@@ -46,6 +45,11 @@ async def test_create_final_documents_with_metadata_column():
     config = create_graphrag_config({"models": DEFAULT_MODEL_CONFIG})
     config.input.metadata = ["title"]
 
+    # simulate the metadata construction during initial input loading
+    await update_document_metadata(config.input.metadata, context)
+
+    expected = await load_table_from_storage("documents", context.storage)
+
     await run_workflow(
         config,
         context,
@@ -54,12 +58,12 @@
     actual = await load_table_from_storage("documents", context.storage)
 
-    # we should have dropped "title" and added "attributes"
-    # our test dataframe does not have attributes, so we'll assert without it
+    # the full output has extra columns, so we compare on a shared subset
     # and separately confirm it is in the output
     compare_outputs(
-        actual, expected, columns=["id", "human_readable_id", "text", "text_unit_ids"]
+        actual, expected, columns=["id", "human_readable_id", "text", "metadata"]
     )
-    assert len(actual.columns) == 6
+    assert len(actual.columns) == 7
     assert "title" in actual.columns
+    assert "text_unit_ids" in actual.columns
     assert "metadata" in actual.columns
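
For reference, the two test modules above rely on the `update_document_metadata` helper that the next hunk adds to `tests/verbs/util.py`. Below is a minimal sketch of the transformation it applies, using a hypothetical toy frame rather than the repo's parquet fixture (the real helper round-trips the table through the pipeline run context storage):

```python
import pandas as pd

# toy stand-in for the "documents" fixture table; the real helper loads it
# with load_table_from_storage("documents", context.storage)
documents = pd.DataFrame({
    "id": ["d1", "d2"],
    "title": ["Doc One", "Doc Two"],
    "text": ["alpha", "beta"],
})

# pack the configured metadata columns into a per-row dict, as the helper does
metadata_columns = ["title"]  # i.e. config.input.metadata
documents["metadata"] = documents[metadata_columns].apply(
    lambda row: row.to_dict(), axis=1
)

print(documents["metadata"].tolist())
# [{'title': 'Doc One'}, {'title': 'Doc Two'}]
```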
diff --git a/tests/verbs/util.py b/tests/verbs/util.py
index 61711349cb..6167a4859e 100644
--- a/tests/verbs/util.py
+++ b/tests/verbs/util.py
@@ -7,7 +7,7 @@
 
 import graphrag.config.defaults as defs
 from graphrag.index.context import PipelineRunContext
 from graphrag.index.run.utils import create_run_context
-from graphrag.utils.storage import write_table_to_storage
+from graphrag.utils.storage import load_table_from_storage, write_table_to_storage
 
 pd.set_option("display.max_columns", None)
@@ -43,7 +43,6 @@ async def create_test_context(storage: list[str] | None = None) -> PipelineRunCo
     if storage:
         for name in storage:
             table = load_test_table(name)
-            # normal storage interface insists on bytes
             await write_table_to_storage(table, name, context.storage)
 
     return context
@@ -83,3 +82,12 @@ def compare_outputs(
             print("Actual:")
             print(actual[column])
             raise
+
+
+async def update_document_metadata(metadata: list[str], context: PipelineRunContext):
+    """Takes the default documents and adds the configured metadata columns for later parsing by the text units and final documents workflows."""
+    documents = await load_table_from_storage("documents", context.storage)
+    documents["metadata"] = documents[metadata].apply(lambda row: row.to_dict(), axis=1)
+    await write_table_to_storage(
+        documents, "documents", context.storage
+    )  # write to the runtime context storage only
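
Taken together, the indexing side now stores the community tree in both directions (`parent` and `children`), and the query side reads it back instead of re-deriving it. A minimal end-to-end sketch of that bookkeeping with toy data (illustrative columns only; the real frames carry ids, titles, entity lists, and more):

```python
import numpy as np
import pandas as pd

# toy hierarchy: communities 2 and 3 are children of community 0
communities = pd.DataFrame({
    "community": [0, 1, 2, 3],
    "level": [0, 0, 1, 1],
    "parent": [-1, -1, 0, 0],
})

# index time (create_communities.py): collect the children of each parent once
parent_grouped = communities.groupby("parent").agg(children=("community", "unique"))
communities = communities.merge(
    parent_grouped, left_on="community", right_on="parent", how="left"
)
# leaf communities get NaN from the left merge; normalize to empty lists
communities["children"] = communities["children"].apply(
    lambda x: x if isinstance(x, np.ndarray) else []
)

# summarize/query time (summarize_communities.py): the hierarchy is a cheap
# explode over the stored children, replacing the pairwise entity-subset
# matching that restore_community_hierarchy used to perform
community_hierarchy = (
    communities.explode("children")
    .rename({"children": "sub_community"}, axis=1)
    .loc[:, ["community", "level", "sub_community"]]
    .dropna()
)
print(community_hierarchy)  # community 0 -> sub_communities 2 and 3
```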