From 9d99f323eaa101c10166260c11948850e8087059 Mon Sep 17 00:00:00 2001 From: Chris Trevino Date: Fri, 26 Jul 2024 15:05:08 -0700 Subject: [PATCH] Add encoding model to entity/claim extraction config sections (#740) * Add encoding-model configuration to entity & claim extraction * add change note * pr updates * test fix * disable GH-based smoke tests --- .github/workflows/python-ci.yml | 30 +++++++++---------- .../patch-20240726181256417715.json | 4 +++ docsite/posts/config/env_vars.md | 14 +++++---- docsite/posts/config/json_yaml.md | 2 ++ graphrag/config/create_graphrag_config.py | 3 ++ .../claim_extraction_config_input.py | 1 + .../entity_extraction_config_input.py | 1 + .../config/models/claim_extraction_config.py | 6 +++- .../config/models/entity_extraction_config.py | 5 +++- graphrag/index/create_pipeline_config.py | 2 +- tests/unit/config/test_default_config.py | 6 +++- 11 files changed, 49 insertions(+), 25 deletions(-) create mode 100644 .semversioner/next-release/patch-20240726181256417715.json diff --git a/.github/workflows/python-ci.yml b/.github/workflows/python-ci.yml index 4c2464b9f2..319c430e11 100644 --- a/.github/workflows/python-ci.yml +++ b/.github/workflows/python-ci.yml @@ -108,18 +108,18 @@ jobs: run: | poetry run poe test_integration - - name: Smoke Test - if: steps.changes.outputs.python == 'true' - run: | - poetry run poe test_smoke - - - uses: actions/upload-artifact@v4 - if: always() - with: - name: smoke-test-artifacts-${{ matrix.python-version }}-${{ matrix.poetry-version }}-${{ runner.os }} - path: tests/fixtures/*/output - - - name: E2E Test - if: steps.changes.outputs.python == 'true' - run: | - ./scripts/e2e-test.sh + # - name: Smoke Test + # if: steps.changes.outputs.python == 'true' + # run: | + # poetry run poe test_smoke + + # - uses: actions/upload-artifact@v4 + # if: always() + # with: + # name: smoke-test-artifacts-${{ matrix.python-version }}-${{ matrix.poetry-version }}-${{ runner.os }} + # path: tests/fixtures/*/output + + 
# - name: E2E Test + # if: steps.changes.outputs.python == 'true' + # run: | + # ./scripts/e2e-test.sh diff --git a/.semversioner/next-release/patch-20240726181256417715.json b/.semversioner/next-release/patch-20240726181256417715.json new file mode 100644 index 0000000000..cff6615031 --- /dev/null +++ b/.semversioner/next-release/patch-20240726181256417715.json @@ -0,0 +1,4 @@ +{ + "type": "patch", + "description": "add encoding-model to entity/claim extraction config" +} diff --git a/docsite/posts/config/env_vars.md b/docsite/posts/config/env_vars.md index 302d78e5fe..7a8d8f65ed 100644 --- a/docsite/posts/config/env_vars.md +++ b/docsite/posts/config/env_vars.md @@ -132,12 +132,12 @@ These settings control the data input used by the pipeline. Any settings with a ## Data Chunking -| Parameter | Description | Type | Required or Optional | Default | -| --------------------------- | ------------------------------------------------------------------------------------------- | ----- | -------------------- | ------- | -| `GRAPHRAG_CHUNK_SIZE` | The chunk size in tokens for text-chunk analysis windows. | `str` | optional | 1200 | -| `GRAPHRAG_CHUNK_OVERLAP` | The chunk overlap in tokens for text-chunk analysis windows. | `str` | optional | 100 | -| `GRAPHRAG_CHUNK_BY_COLUMNS` | A comma-separated list of document attributes to groupby when performing TextUnit chunking. | `str` | optional | `id` | -| `GRAPHRAG_CHUNK_ENCODING_MODEL` | The encoding model to use for chunking. | `str` | optional | `None` | +| Parameter | Description | Type | Required or Optional | Default | +| ------------------------------- | ------------------------------------------------------------------------------------------- | ----- | -------------------- | ----------------------------- | +| `GRAPHRAG_CHUNK_SIZE` | The chunk size in tokens for text-chunk analysis windows. | `str` | optional | 1200 | +| `GRAPHRAG_CHUNK_OVERLAP` | The chunk overlap in tokens for text-chunk analysis windows. 
| `str` | optional | 100 | +| `GRAPHRAG_CHUNK_BY_COLUMNS` | A comma-separated list of document attributes to groupby when performing TextUnit chunking. | `str` | optional | `id` | +| `GRAPHRAG_CHUNK_ENCODING_MODEL` | The encoding model to use for chunking. | `str` | optional | The top-level encoding model. | ## Prompting Overrides @@ -146,12 +146,14 @@ These settings control the data input used by the pipeline. Any settings with a | `GRAPHRAG_ENTITY_EXTRACTION_PROMPT_FILE` | The path (relative to the root) of an entity extraction prompt template text file. | `str` | optional | `None` | | `GRAPHRAG_ENTITY_EXTRACTION_MAX_GLEANINGS` | The maximum number of redrives (gleanings) to invoke when extracting entities in a loop. | `int` | optional | 1 | | `GRAPHRAG_ENTITY_EXTRACTION_ENTITY_TYPES` | A comma-separated list of entity types to extract. | `str` | optional | `organization,person,event,geo` | +| `GRAPHRAG_ENTITY_EXTRACTION_ENCODING_MODEL` | The encoding model to use for entity extraction. | `str` | optional | The top-level encoding model. | | `GRAPHRAG_SUMMARIZE_DESCRIPTIONS_PROMPT_FILE` | The path (relative to the root) of an description summarization prompt template text file. | `str` | optional | `None` | | `GRAPHRAG_SUMMARIZE_DESCRIPTIONS_MAX_LENGTH` | The maximum number of tokens to generate per description summarization. | `int` | optional | 500 | | `GRAPHRAG_CLAIM_EXTRACTION_ENABLED` | Whether claim extraction is enabled for this pipeline. | `bool` | optional | `False` | | `GRAPHRAG_CLAIM_EXTRACTION_DESCRIPTION` | The claim_description prompting argument to utilize. | `string` | optional | "Any claims or facts that could be relevant to threat analysis." | | `GRAPHRAG_CLAIM_EXTRACTION_PROMPT_FILE` | The claim extraction prompt to utilize. | `string` | optional | `None` | | `GRAPHRAG_CLAIM_EXTRACTION_MAX_GLEANINGS` | The maximum number of redrives (gleanings) to invoke when extracting claims in a loop. 
| `int` | optional | 1 | +| `GRAPHRAG_CLAIM_EXTRACTION_ENCODING_MODEL` | The encoding model to use for claim extraction. | `str` | optional | The top-level encoding model. | | `GRAPHRAG_COMMUNITY_REPORTS_PROMPT_FILE` | The community reports extraction prompt to utilize. | `string` | optional | `None` | | `GRAPHRAG_COMMUNITY_REPORTS_MAX_LENGTH` | The maximum number of tokens to generate per community reports. | `int` | optional | 1500 | diff --git a/docsite/posts/config/json_yaml.md b/docsite/posts/config/json_yaml.md index 080256c238..8c2e5701d5 100644 --- a/docsite/posts/config/json_yaml.md +++ b/docsite/posts/config/json_yaml.md @@ -145,6 +145,7 @@ This is the base LLM configuration section. Other steps may override this config - `prompt` **str** - The prompt file to use. - `entity_types` **list[str]** - The entity types to identify. - `max_gleanings` **int** - The maximum number of gleaning cycles to use. +- `encoding_model` **str** - The text encoding model to use. By default, this will use the top-level encoding model. - `strategy` **dict** - Fully override the entity extraction strategy. ## summarize_descriptions @@ -169,6 +170,7 @@ This is the base LLM configuration section. Other steps may override this config - `prompt` **str** - The prompt file to use. - `description` **str** - Describes the types of claims we want to extract. - `max_gleanings` **int** - The maximum number of gleaning cycles to use. +- `encoding_model` **str** - The text encoding model to use. By default, this will use the top-level encoding model. - `strategy` **dict** - Fully override the claim extraction strategy. 
## community_reports diff --git a/graphrag/config/create_graphrag_config.py b/graphrag/config/create_graphrag_config.py index 54c154668f..3504507be2 100644 --- a/graphrag/config/create_graphrag_config.py +++ b/graphrag/config/create_graphrag_config.py @@ -390,6 +390,7 @@ def hydrate_parallelization_params( size=reader.int("size") or defs.CHUNK_SIZE, overlap=reader.int("overlap") or defs.CHUNK_OVERLAP, group_by_columns=group_by_columns, + encoding_model=reader.str(Fragment.encoding_model), ) with ( reader.envvar_prefix(Section.snapshot), @@ -428,6 +429,7 @@ def hydrate_parallelization_params( or defs.ENTITY_EXTRACTION_ENTITY_TYPES, max_gleanings=max_gleanings, prompt=reader.str("prompt", Fragment.prompt_file), + encoding_model=reader.str(Fragment.encoding_model), ) claim_extraction_config = values.get("claim_extraction") or {} @@ -449,6 +451,7 @@ def hydrate_parallelization_params( description=reader.str("description") or defs.CLAIM_DESCRIPTION, prompt=reader.str("prompt", Fragment.prompt_file), max_gleanings=max_gleanings, + encoding_model=reader.str(Fragment.encoding_model), ) community_report_config = values.get("community_reports") or {} diff --git a/graphrag/config/input_models/claim_extraction_config_input.py b/graphrag/config/input_models/claim_extraction_config_input.py index 4827435cab..f23e31d0a7 100644 --- a/graphrag/config/input_models/claim_extraction_config_input.py +++ b/graphrag/config/input_models/claim_extraction_config_input.py @@ -16,3 +16,4 @@ class ClaimExtractionConfigInput(LLMConfigInput): description: NotRequired[str | None] max_gleanings: NotRequired[int | str | None] strategy: NotRequired[dict | None] + encoding_model: NotRequired[str | None] diff --git a/graphrag/config/input_models/entity_extraction_config_input.py b/graphrag/config/input_models/entity_extraction_config_input.py index a03cbfee1b..f1d3587e99 100644 --- a/graphrag/config/input_models/entity_extraction_config_input.py +++ 
b/graphrag/config/input_models/entity_extraction_config_input.py @@ -15,3 +15,4 @@ class EntityExtractionConfigInput(LLMConfigInput): entity_types: NotRequired[list[str] | str | None] max_gleanings: NotRequired[int | str | None] strategy: NotRequired[dict | None] + encoding_model: NotRequired[str | None] diff --git a/graphrag/config/models/claim_extraction_config.py b/graphrag/config/models/claim_extraction_config.py index a4437dd6d3..a26fdad26e 100644 --- a/graphrag/config/models/claim_extraction_config.py +++ b/graphrag/config/models/claim_extraction_config.py @@ -32,8 +32,11 @@ class ClaimExtractionConfig(LLMConfig): strategy: dict | None = Field( description="The override strategy to use.", default=None ) + encoding_model: str | None = Field( + default=None, description="The encoding model to use." + ) - def resolved_strategy(self, root_dir: str) -> dict: + def resolved_strategy(self, root_dir: str, encoding_model: str) -> dict: """Get the resolved claim extraction strategy.""" from graphrag.index.verbs.covariates.extract_covariates import ( ExtractClaimsStrategyType, @@ -50,4 +53,5 @@ def resolved_strategy(self, root_dir: str) -> dict: else None, "claim_description": self.description, "max_gleanings": self.max_gleanings, + "encoding_name": self.encoding_model or encoding_model, } diff --git a/graphrag/config/models/entity_extraction_config.py b/graphrag/config/models/entity_extraction_config.py index 26101f747c..ca160bc4e2 100644 --- a/graphrag/config/models/entity_extraction_config.py +++ b/graphrag/config/models/entity_extraction_config.py @@ -29,6 +29,9 @@ class EntityExtractionConfig(LLMConfig): strategy: dict | None = Field( description="Override the default entity extraction strategy", default=None ) + encoding_model: str | None = Field( + default=None, description="The encoding model to use." 
+ ) def resolved_strategy(self, root_dir: str, encoding_model: str) -> dict: """Get the resolved entity extraction strategy.""" @@ -45,6 +48,6 @@ def resolved_strategy(self, root_dir: str, encoding_model: str) -> dict: else None, "max_gleanings": self.max_gleanings, # It's prechunked in create_base_text_units - "encoding_name": encoding_model, + "encoding_name": self.encoding_model or encoding_model, "prechunked": True, } diff --git a/graphrag/index/create_pipeline_config.py b/graphrag/index/create_pipeline_config.py index e930844a8d..993a0b6c74 100644 --- a/graphrag/index/create_pipeline_config.py +++ b/graphrag/index/create_pipeline_config.py @@ -405,7 +405,7 @@ def _covariate_workflows( "claim_extract": { **settings.claim_extraction.parallelization.model_dump(), "strategy": settings.claim_extraction.resolved_strategy( - settings.root_dir + settings.root_dir, settings.encoding_model ), }, }, diff --git a/tests/unit/config/test_default_config.py b/tests/unit/config/test_default_config.py index 50e526e6f8..095155afa6 100644 --- a/tests/unit/config/test_default_config.py +++ b/tests/unit/config/test_default_config.py @@ -93,6 +93,7 @@ "GRAPHRAG_CLAIM_EXTRACTION_DESCRIPTION": "test 123", "GRAPHRAG_CLAIM_EXTRACTION_MAX_GLEANINGS": "5000", "GRAPHRAG_CLAIM_EXTRACTION_PROMPT_FILE": "tests/unit/config/prompt-a.txt", + "GRAPHRAG_CLAIM_EXTRACTION_ENCODING_MODEL": "encoding_a", "GRAPHRAG_COMMUNITY_REPORTS_MAX_LENGTH": "23456", "GRAPHRAG_COMMUNITY_REPORTS_PROMPT_FILE": "tests/unit/config/prompt-b.txt", "GRAPHRAG_EMBEDDING_BATCH_MAX_TOKENS": "17", @@ -115,6 +116,7 @@ "GRAPHRAG_ENTITY_EXTRACTION_ENTITY_TYPES": "cat,dog,elephant", "GRAPHRAG_ENTITY_EXTRACTION_MAX_GLEANINGS": "112", "GRAPHRAG_ENTITY_EXTRACTION_PROMPT_FILE": "tests/unit/config/prompt-c.txt", + "GRAPHRAG_ENTITY_EXTRACTION_ENCODING_MODEL": "encoding_b", "GRAPHRAG_INPUT_BASE_DIR": "/some/input/dir", "GRAPHRAG_INPUT_CONNECTION_STRING": "input_cs", "GRAPHRAG_INPUT_CONTAINER_NAME": "input_cn", @@ -543,6 +545,7 @@ def 
test_create_parameters_from_env_vars(self) -> None: assert parameters.claim_extraction.description == "test 123" assert parameters.claim_extraction.max_gleanings == 5000 assert parameters.claim_extraction.prompt == "tests/unit/config/prompt-a.txt" + assert parameters.claim_extraction.encoding_model == "encoding_a" assert parameters.cluster_graph.max_cluster_size == 123 assert parameters.community_reports.max_length == 23456 assert parameters.community_reports.prompt == "tests/unit/config/prompt-b.txt" @@ -572,6 +575,7 @@ def test_create_parameters_from_env_vars(self) -> None: assert parameters.entity_extraction.llm.api_base == "http://some/base" assert parameters.entity_extraction.max_gleanings == 112 assert parameters.entity_extraction.prompt == "tests/unit/config/prompt-c.txt" + assert parameters.entity_extraction.encoding_model == "encoding_b" assert parameters.input.storage_account_blob_url == "input_account_blob_url" assert parameters.input.base_dir == "/some/input/dir" assert parameters.input.connection_string == "input_cs" @@ -910,7 +914,7 @@ def test_prompt_file_reading(self): assert strategy["extraction_prompt"] == "Hello, World! A" assert strategy["encoding_name"] == "abc123" - strategy = config.claim_extraction.resolved_strategy(".") + strategy = config.claim_extraction.resolved_strategy(".", "encoding_b") assert strategy["extraction_prompt"] == "Hello, World! B" strategy = config.community_reports.resolved_strategy(".")