From aa5d9e70dccdcc6ed77b57b8ee4b55e6d988c3ec Mon Sep 17 00:00:00 2001 From: Casey Clements Date: Tue, 11 Feb 2025 14:10:17 -0500 Subject: [PATCH] Improved treatment of defaults in API --- .../langchain_mongodb/graphrag/graph.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/libs/langchain-mongodb/langchain_mongodb/graphrag/graph.py b/libs/langchain-mongodb/langchain_mongodb/graphrag/graph.py index 413d980..9cf9c43 100644 --- a/libs/langchain-mongodb/langchain_mongodb/graphrag/graph.py +++ b/libs/langchain-mongodb/langchain_mongodb/graphrag/graph.py @@ -93,8 +93,8 @@ def __init__( self, collection: Collection, entity_extraction_model: BaseChatModel, - entity_prompt: ChatPromptTemplate = prompts.entity_prompt, - query_prompt: ChatPromptTemplate = prompts.query_prompt, + entity_prompt: ChatPromptTemplate = None, + query_prompt: ChatPromptTemplate = None, max_depth: int = 2, allowed_entity_types: List[str] = None, allowed_relationship_types: List[str] = None, @@ -108,7 +108,9 @@ def __init__( collection: Collection representing an Entity Graph. entity_extraction_model: LLM for converting documents into Graph of Entities and Relationships. entity_prompt: Prompt to fill graph store with entities following schema. + Defaults to .prompts.ENTITY_EXTRACTION_INSTRUCTIONS query_prompt: Prompt extracts entities and relationships as search starting points. + Defaults to .prompts.NAME_EXTRACTION_INSTRUCTIONS max_depth: Maximum recursion depth in graph traversal. allowed_entity_types: If provided, constrains search to these types. allowed_relationship_types: If provided, constrains search to these types. @@ -120,8 +122,12 @@ def __init__( - If "error", an exception will be raised if any document does not match the schema. """ self.entity_extraction_model = entity_extraction_model - self.entity_prompt = entity_prompt - self.query_prompt = query_prompt + self.entity_prompt = ( + prompts.entity_prompt if entity_prompt is None else entity_prompt + ) + self.query_prompt = ( + prompts.query_prompt if query_prompt is None else query_prompt + ) self.max_depth = max_depth self._schema = deepcopy(entity_schema) if allowed_entity_types: @@ -138,7 +144,6 @@ def __init__( ] = allowed_relationship_types else: self.allowed_relationship_types = [] - if validate: collection.database.command( "collMod", @@ -155,7 +160,6 @@ def __init__( 1, SystemMessagePromptTemplate.from_template(entity_examples), ) - if entity_name_examples: self.query_prompt.messages.insert( 1,