diff --git a/libs/neo4j/langchain_neo4j/chains/graph_qa/cypher.py b/libs/neo4j/langchain_neo4j/chains/graph_qa/cypher.py index 69ff000..3770208 100644 --- a/libs/neo4j/langchain_neo4j/chains/graph_qa/cypher.py +++ b/libs/neo4j/langchain_neo4j/chains/graph_qa/cypher.py @@ -164,8 +164,8 @@ class GraphCypherQAChain(Chain): """ graph: GraphStore = Field(exclude=True) - cypher_generation_chain: Runnable - qa_chain: Runnable + cypher_generation_chain: Runnable[Dict[str, Any], str] + qa_chain: Runnable[Dict[str, Any], str] graph_schema: str input_key: str = "query" #: :meta private: output_key: str = "result" #: :meta private: @@ -239,7 +239,7 @@ def from_llm( qa_prompt: Optional[BasePromptTemplate] = None, cypher_prompt: Optional[BasePromptTemplate] = None, cypher_llm: Optional[BaseLanguageModel] = None, - qa_llm: Optional[Union[BaseLanguageModel, Any]] = None, + qa_llm: Optional[BaseLanguageModel] = None, exclude_types: List[str] = [], include_types: List[str] = [], validate_cypher: bool = False, @@ -250,16 +250,28 @@ def from_llm( **kwargs: Any, ) -> GraphCypherQAChain: """Initialize from LLM.""" + # Ensure at least one LLM is provided + if llm is None and qa_llm is None and cypher_llm is None: + raise ValueError("At least one LLM must be provided") - if not cypher_llm and not llm: - raise ValueError("Either `llm` or `cypher_llm` parameters must be provided") - if not qa_llm and not llm: - raise ValueError("Either `llm` or `qa_llm` parameters must be provided") - if cypher_llm and qa_llm and llm: + # Prevent all three LLMs from being provided simultaneously + if llm is not None and qa_llm is not None and cypher_llm is not None: raise ValueError( - "You can specify up to two of 'cypher_llm', 'qa_llm'" - ", and 'llm', but not all three simultaneously." + "You can specify up to two of 'cypher_llm', 'qa_llm', and 'llm', but " + "not all three simultaneously." ) + + # Assign default LLMs if specific ones are not provided + if llm is not None: + qa_llm = qa_llm or llm + cypher_llm = cypher_llm or llm + else: + # If llm is None, both qa_llm and cypher_llm must be provided + if qa_llm is None or cypher_llm is None: + raise ValueError( + "If `llm` is not provided, both `qa_llm` and `cypher_llm` must be " + "provided." + ) if cypher_prompt: if cypher_llm_kwargs: raise ValueError( @@ -271,6 +283,11 @@ def from_llm( cypher_prompt = cypher_llm_kwargs.pop( "prompt", CYPHER_GENERATION_PROMPT ) + if not isinstance(cypher_prompt, BasePromptTemplate): + raise ValueError( + "The cypher_llm_kwargs `prompt` must inherit from " + "BasePromptTemplate" + ) else: cypher_prompt = CYPHER_GENERATION_PROMPT if qa_prompt: @@ -282,6 +299,11 @@ def from_llm( else: if qa_llm_kwargs: qa_prompt = qa_llm_kwargs.pop("prompt", CYPHER_QA_PROMPT) + if not isinstance(qa_prompt, BasePromptTemplate): + raise ValueError( + "The qa_llm_kwargs `prompt` must inherit from " + "BasePromptTemplate" + ) else: qa_prompt = CYPHER_QA_PROMPT use_qa_llm_kwargs = qa_llm_kwargs if qa_llm_kwargs is not None else {} @@ -289,10 +311,12 @@ def from_llm( cypher_llm_kwargs if cypher_llm_kwargs is not None else {} ) - qa_llm = qa_llm or llm if use_function_response: try: - qa_llm.bind_tools({}) # type: ignore[union-attr] + if hasattr(qa_llm, "bind_tools"): + qa_llm.bind_tools({}) + else: + raise AttributeError response_prompt = ChatPromptTemplate.from_messages( [ SystemMessage(content=function_response_system), @@ -300,15 +324,13 @@ def from_llm( MessagesPlaceholder(variable_name="function_response"), ] ) - qa_chain = response_prompt | qa_llm | StrOutputParser() # type: ignore + qa_chain = response_prompt | qa_llm | StrOutputParser() except (NotImplementedError, AttributeError): raise ValueError("Provided LLM does not support native tools/functions") else: - qa_chain = qa_prompt | qa_llm.bind(**use_qa_llm_kwargs) | StrOutputParser() # type: ignore - - cypher_llm = cypher_llm or llm + qa_chain = qa_prompt | qa_llm.bind(**use_qa_llm_kwargs) | StrOutputParser() cypher_generation_chain = ( - cypher_prompt | cypher_llm.bind(**use_cypher_llm_kwargs) | StrOutputParser() # type: ignore + cypher_prompt | cypher_llm.bind(**use_cypher_llm_kwargs) | StrOutputParser() ) if exclude_types and include_types: @@ -379,6 +401,7 @@ def _call( else: context = [] + final_result: Union[List[Dict[str, Any]], str] if self.return_direct: final_result = context else: @@ -390,15 +413,14 @@ def _call( intermediate_steps.append({"context": context}) if self.use_function_response: function_response = get_function_response(question, context) - final_result = self.qa_chain.invoke( # type: ignore + final_result = self.qa_chain.invoke( {"question": question, "function_response": function_response}, ) else: - result = self.qa_chain.invoke( # type: ignore + final_result = self.qa_chain.invoke( {"question": question, "context": context}, callbacks=callbacks, ) - final_result = result # type: ignore chain_result: Dict[str, Any] = {self.output_key: final_result} if self.return_intermediate_steps: