Skip to content

Commit

Permalink
Removes unneeded type ignore comments
Browse files Browse the repository at this point in the history
  • Loading branch information
alexthomas93 committed Dec 8, 2024
1 parent ebb8b46 commit 3fe5c1b
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 31 deletions.
62 changes: 42 additions & 20 deletions libs/neo4j/langchain_neo4j/chains/graph_qa/cypher.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,8 +164,8 @@ class GraphCypherQAChain(Chain):
"""

graph: GraphStore = Field(exclude=True)
cypher_generation_chain: Runnable
qa_chain: Runnable
cypher_generation_chain: Runnable[Dict[str, Any], str]
qa_chain: Runnable[Dict[str, Any], str]
graph_schema: str
input_key: str = "query" #: :meta private:
output_key: str = "result" #: :meta private:
Expand Down Expand Up @@ -239,7 +239,7 @@ def from_llm(
qa_prompt: Optional[BasePromptTemplate] = None,
cypher_prompt: Optional[BasePromptTemplate] = None,
cypher_llm: Optional[BaseLanguageModel] = None,
qa_llm: Optional[Union[BaseLanguageModel, Any]] = None,
qa_llm: Optional[BaseLanguageModel] = None,
exclude_types: List[str] = [],
include_types: List[str] = [],
validate_cypher: bool = False,
Expand All @@ -250,16 +250,28 @@ def from_llm(
**kwargs: Any,
) -> GraphCypherQAChain:
"""Initialize from LLM."""
# Ensure at least one LLM is provided
if llm is None and qa_llm is None and cypher_llm is None:
raise ValueError("At least one LLM must be provided")

if not cypher_llm and not llm:
raise ValueError("Either `llm` or `cypher_llm` parameters must be provided")
if not qa_llm and not llm:
raise ValueError("Either `llm` or `qa_llm` parameters must be provided")
if cypher_llm and qa_llm and llm:
# Prevent all three LLMs from being provided simultaneously
if llm is not None and qa_llm is not None and cypher_llm is not None:
raise ValueError(
"You can specify up to two of 'cypher_llm', 'qa_llm'"
", and 'llm', but not all three simultaneously."
"You can specify up to two of 'cypher_llm', 'qa_llm', and 'llm', but "
"not all three simultaneously."
)

# Assign default LLMs if specific ones are not provided
if llm is not None:
qa_llm = qa_llm or llm
cypher_llm = cypher_llm or llm
else:
# If llm is None, both qa_llm and cypher_llm must be provided
if qa_llm is None or cypher_llm is None:
raise ValueError(
"If `llm` is not provided, both `qa_llm` and `cypher_llm` must be "
"provided."
)
if cypher_prompt:
if cypher_llm_kwargs:
raise ValueError(
Expand All @@ -271,6 +283,11 @@ def from_llm(
cypher_prompt = cypher_llm_kwargs.pop(
"prompt", CYPHER_GENERATION_PROMPT
)
if not isinstance(cypher_prompt, BasePromptTemplate):
raise ValueError(
"The cypher_llm_kwargs `prompt` must inherit from "
"BasePromptTemplate"
)
else:
cypher_prompt = CYPHER_GENERATION_PROMPT
if qa_prompt:
Expand All @@ -282,33 +299,38 @@ def from_llm(
else:
if qa_llm_kwargs:
qa_prompt = qa_llm_kwargs.pop("prompt", CYPHER_QA_PROMPT)
if not isinstance(qa_prompt, BasePromptTemplate):
raise ValueError(
"The qa_llm_kwargs `prompt` must inherit from "
"BasePromptTemplate"
)
else:
qa_prompt = CYPHER_QA_PROMPT
use_qa_llm_kwargs = qa_llm_kwargs if qa_llm_kwargs is not None else {}
use_cypher_llm_kwargs = (
cypher_llm_kwargs if cypher_llm_kwargs is not None else {}
)

qa_llm = qa_llm or llm
if use_function_response:
try:
qa_llm.bind_tools({}) # type: ignore[union-attr]
if hasattr(qa_llm, "bind_tools"):
qa_llm.bind_tools({})
else:
raise AttributeError
response_prompt = ChatPromptTemplate.from_messages(
[
SystemMessage(content=function_response_system),
HumanMessagePromptTemplate.from_template("{question}"),
MessagesPlaceholder(variable_name="function_response"),
]
)
qa_chain = response_prompt | qa_llm | StrOutputParser() # type: ignore
qa_chain = response_prompt | qa_llm | StrOutputParser()
except (NotImplementedError, AttributeError):
raise ValueError("Provided LLM does not support native tools/functions")
else:
qa_chain = qa_prompt | qa_llm.bind(**use_qa_llm_kwargs) | StrOutputParser() # type: ignore

cypher_llm = cypher_llm or llm
qa_chain = qa_prompt | qa_llm.bind(**use_qa_llm_kwargs) | StrOutputParser()
cypher_generation_chain = (
cypher_prompt | cypher_llm.bind(**use_cypher_llm_kwargs) | StrOutputParser() # type: ignore
cypher_prompt | cypher_llm.bind(**use_cypher_llm_kwargs) | StrOutputParser()
)

if exclude_types and include_types:
Expand Down Expand Up @@ -379,6 +401,7 @@ def _call(
else:
context = []

final_result: Union[List[Dict[str, Any]], str]
if self.return_direct:
final_result = context
else:
Expand All @@ -390,15 +413,14 @@ def _call(
intermediate_steps.append({"context": context})
if self.use_function_response:
function_response = get_function_response(question, context)
final_result = self.qa_chain.invoke( # type: ignore
final_result = self.qa_chain.invoke(
{"question": question, "function_response": function_response},
)
else:
result = self.qa_chain.invoke( # type: ignore
final_result = self.qa_chain.invoke(
{"question": question, "context": context},
callbacks=callbacks,
)
final_result = result # type: ignore

chain_result: Dict[str, Any] = {self.output_key: final_result}
if self.return_intermediate_steps:
Expand Down
11 changes: 0 additions & 11 deletions libs/neo4j/tests/unit_tests/chains/test_graph_qa.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,11 @@

from langchain.memory import ConversationBufferMemory, ReadOnlySharedMemory
from langchain_core.messages import SystemMessage
from langchain_core.messages import SystemMessage
from langchain_core.prompts import (
ChatPromptTemplate,
HumanMessagePromptTemplate,
MessagesPlaceholder,
(
ChatPromptTemplate,
HumanMessagePromptTemplate,
MessagesPlaceholder,
PromptTemplate,
),
)

from langchain_neo4j.chains.graph_qa.cypher import (
Expand Down Expand Up @@ -76,15 +70,10 @@ def test_graph_cypher_qa_chain_prompt_selection_1() -> None:
cypher_prompt=cypher_prompt,
allow_dangerous_requests=True,
)
<<<<<<< HEAD
assert chain.qa_chain.prompt == qa_prompt # type: ignore[union-attr]
assert chain.cypher_generation_chain.prompt == cypher_prompt
=======
assert hasattr(chain.qa_chain, "first")
assert chain.qa_chain.first == qa_prompt
assert hasattr(chain.cypher_generation_chain, "first")
assert chain.cypher_generation_chain.first == cypher_prompt
>>>>>>> 63f27b1 (Added hasattr assertions to tests/unit_tests/chains/test_graph_qa.py)


def test_graph_cypher_qa_chain_prompt_selection_2() -> None:
Expand Down

0 comments on commit 3fe5c1b

Please sign in to comment.