Skip to content

Commit

Permalink
Improved treatment of defaults in API
Browse files Browse the repository at this point in the history
  • Loading branch information
caseyclements committed Feb 11, 2025
1 parent a52f040 commit aa5d9e7
Showing 1 changed file with 10 additions and 6 deletions.
16 changes: 10 additions & 6 deletions libs/langchain-mongodb/langchain_mongodb/graphrag/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,8 @@ def __init__(
self,
collection: Collection,
entity_extraction_model: BaseChatModel,
entity_prompt: ChatPromptTemplate = prompts.entity_prompt,
query_prompt: ChatPromptTemplate = prompts.query_prompt,
entity_prompt: ChatPromptTemplate = None,
query_prompt: ChatPromptTemplate = None,
max_depth: int = 2,
allowed_entity_types: List[str] = None,
allowed_relationship_types: List[str] = None,
Expand All @@ -108,7 +108,9 @@ def __init__(
collection: Collection representing an Entity Graph.
entity_extraction_model: LLM for converting documents into Graph of Entities and Relationships.
entity_prompt: Prompt to fill graph store with entities following schema.
Defaults to .prompts.ENTITY_EXTRACTION_INSTRUCTIONS
query_prompt: Prompt extracts entities and relationships as search starting points.
Defaults to .prompts.NAME_EXTRACTION_INSTRUCTIONS
max_depth: Maximum recursion depth in graph traversal.
allowed_entity_types: If provided, constrains search to these types.
allowed_relationship_types: If provided, constrains search to these types.
Expand All @@ -120,8 +122,12 @@ def __init__(
- If "error", an exception will be raised if any document does not match the schema.
"""
self.entity_extraction_model = entity_extraction_model
self.entity_prompt = entity_prompt
self.query_prompt = query_prompt
self.entity_prompt = (
prompts.entity_prompt if entity_prompt is None else entity_prompt
)
self.query_prompt = (
prompts.query_prompt if query_prompt is None else query_prompt
)
self.max_depth = max_depth
self._schema = deepcopy(entity_schema)
if allowed_entity_types:
Expand All @@ -138,7 +144,6 @@ def __init__(
] = allowed_relationship_types
else:
self.allowed_relationship_types = []

if validate:
collection.database.command(
"collMod",
Expand All @@ -155,7 +160,6 @@ def __init__(
1,
SystemMessagePromptTemplate.from_template(entity_examples),
)

if entity_name_examples:
self.query_prompt.messages.insert(
1,
Expand Down

0 comments on commit aa5d9e7

Please sign in to comment.