Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Removes #type: ignore comments #19

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
### Fixed

- Removed deprecated LLMChain from GraphCypherQAChain to resolve instantiation issues with the use_function_response parameter.
- Removed unnecessary # type: ignore comments, improving type safety and code clarity.

## 0.1.1

Expand Down
58 changes: 40 additions & 18 deletions libs/neo4j/langchain_neo4j/chains/graph_qa/cypher.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,8 +164,8 @@ class GraphCypherQAChain(Chain):
"""

graph: GraphStore = Field(exclude=True)
cypher_generation_chain: Runnable
qa_chain: Runnable
cypher_generation_chain: Runnable[Dict[str, Any], str]
qa_chain: Runnable[Dict[str, Any], str]
graph_schema: str
input_key: str = "query" #: :meta private:
output_key: str = "result" #: :meta private:
Expand Down Expand Up @@ -239,7 +239,7 @@ def from_llm(
qa_prompt: Optional[BasePromptTemplate] = None,
cypher_prompt: Optional[BasePromptTemplate] = None,
cypher_llm: Optional[BaseLanguageModel] = None,
qa_llm: Optional[Union[BaseLanguageModel, Any]] = None,
qa_llm: Optional[BaseLanguageModel] = None,
exclude_types: List[str] = [],
include_types: List[str] = [],
validate_cypher: bool = False,
Expand All @@ -250,16 +250,28 @@ def from_llm(
**kwargs: Any,
) -> GraphCypherQAChain:
"""Initialize from LLM."""
# Ensure at least one LLM is provided
if llm is None and qa_llm is None and cypher_llm is None:
raise ValueError("At least one LLM must be provided")

if not cypher_llm and not llm:
raise ValueError("Either `llm` or `cypher_llm` parameters must be provided")
if not qa_llm and not llm:
raise ValueError("Either `llm` or `qa_llm` parameters must be provided")
if cypher_llm and qa_llm and llm:
# Prevent all three LLMs from being provided simultaneously
if llm is not None and qa_llm is not None and cypher_llm is not None:
raise ValueError(
"You can specify up to two of 'cypher_llm', 'qa_llm'"
", and 'llm', but not all three simultaneously."
)

# Assign default LLMs if specific ones are not provided
if llm is not None:
qa_llm = qa_llm or llm
cypher_llm = cypher_llm or llm
else:
# If llm is None, both qa_llm and cypher_llm must be provided
if qa_llm is None or cypher_llm is None:
raise ValueError(
"If `llm` is not provided, both `qa_llm` and `cypher_llm` must be "
"provided."
)
if cypher_prompt:
if cypher_llm_kwargs:
raise ValueError(
Expand All @@ -271,6 +283,11 @@ def from_llm(
cypher_prompt = cypher_llm_kwargs.pop(
"prompt", CYPHER_GENERATION_PROMPT
)
if not isinstance(cypher_prompt, BasePromptTemplate):
raise ValueError(
"The cypher_llm_kwargs `prompt` must inherit from "
"BasePromptTemplate"
)
else:
cypher_prompt = CYPHER_GENERATION_PROMPT
if qa_prompt:
Expand All @@ -282,33 +299,38 @@ def from_llm(
else:
if qa_llm_kwargs:
qa_prompt = qa_llm_kwargs.pop("prompt", CYPHER_QA_PROMPT)
if not isinstance(qa_prompt, BasePromptTemplate):
raise ValueError(
"The qa_llm_kwargs `prompt` must inherit from "
"BasePromptTemplate"
)
else:
qa_prompt = CYPHER_QA_PROMPT
use_qa_llm_kwargs = qa_llm_kwargs if qa_llm_kwargs is not None else {}
use_cypher_llm_kwargs = (
cypher_llm_kwargs if cypher_llm_kwargs is not None else {}
)

qa_llm = qa_llm or llm
if use_function_response:
try:
qa_llm.bind_tools({}) # type: ignore[union-attr]
if hasattr(qa_llm, "bind_tools"):
qa_llm.bind_tools({})
else:
raise AttributeError
response_prompt = ChatPromptTemplate.from_messages(
[
SystemMessage(content=function_response_system),
HumanMessagePromptTemplate.from_template("{question}"),
MessagesPlaceholder(variable_name="function_response"),
]
)
qa_chain = response_prompt | qa_llm | StrOutputParser() # type: ignore
qa_chain = response_prompt | qa_llm | StrOutputParser()
except (NotImplementedError, AttributeError):
raise ValueError("Provided LLM does not support native tools/functions")
else:
qa_chain = qa_prompt | qa_llm.bind(**use_qa_llm_kwargs) | StrOutputParser() # type: ignore

cypher_llm = cypher_llm or llm
qa_chain = qa_prompt | qa_llm.bind(**use_qa_llm_kwargs) | StrOutputParser()
cypher_generation_chain = (
cypher_prompt | cypher_llm.bind(**use_cypher_llm_kwargs) | StrOutputParser() # type: ignore
cypher_prompt | cypher_llm.bind(**use_cypher_llm_kwargs) | StrOutputParser()
)

if exclude_types and include_types:
Expand Down Expand Up @@ -379,6 +401,7 @@ def _call(
else:
context = []

final_result: Union[List[Dict[str, Any]], str]
if self.return_direct:
final_result = context
else:
Expand All @@ -390,15 +413,14 @@ def _call(
intermediate_steps.append({"context": context})
if self.use_function_response:
function_response = get_function_response(question, context)
final_result = self.qa_chain.invoke( # type: ignore
final_result = self.qa_chain.invoke(
{"question": question, "function_response": function_response},
)
else:
result = self.qa_chain.invoke( # type: ignore
final_result = self.qa_chain.invoke(
{"question": question, "context": context},
callbacks=callbacks,
)
final_result = result # type: ignore

chain_result: Dict[str, Any] = {self.output_key: final_result}
if self.return_intermediate_steps:
Expand Down
12 changes: 7 additions & 5 deletions libs/neo4j/langchain_neo4j/graphs/neo4j_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,21 +461,23 @@ def query(
or e.code
== "Neo.DatabaseError.Transaction.TransactionStartFailed"
)
and "in an implicit transaction" in e.message # type: ignore
and e.message is not None
and "in an implicit transaction" in e.message
)
or ( # isPeriodicCommitError
e.code == "Neo.ClientError.Statement.SemanticError"
and e.message is not None
and (
"in an open transaction is not possible" in e.message # type: ignore
or "tried to execute in an explicit transaction" in e.message # type: ignore
"in an open transaction is not possible" in e.message
or "tried to execute in an explicit transaction" in e.message
)
)
):
raise
# fallback to allow implicit transactions
with self._driver.session(database=self._database) as session:
data = session.run(Query(text=query, timeout=self.timeout), params) # type: ignore
json_data = [r.data() for r in data]
result = session.run(Query(text=query, timeout=self.timeout), params)
json_data = [r.data() for r in result]
if self.sanitize:
json_data = [value_sanitize(el) for el in json_data]
return json_data
Expand Down
12 changes: 7 additions & 5 deletions libs/neo4j/langchain_neo4j/vectorstores/neo4j_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -655,21 +655,23 @@ def query(
or e.code
== "Neo.DatabaseError.Transaction.TransactionStartFailed"
)
and "in an implicit transaction" in e.message # type: ignore
and e.message is not None
and "in an implicit transaction" in e.message
)
or ( # isPeriodicCommitError
e.code == "Neo.ClientError.Statement.SemanticError"
and e.message is not None
and (
"in an open transaction is not possible" in e.message # type: ignore
or "tried to execute in an explicit transaction" in e.message # type: ignore
"in an open transaction is not possible" in e.message
or "tried to execute in an explicit transaction" in e.message
)
)
):
raise
# Fallback to allow implicit transactions
with self._driver.session(database=self._database) as session:
data = session.run(Query(text=query), params) # type: ignore
return [r.data() for r in data]
result = session.run(Query(text=query), params)
return [r.data() for r in result]

def verify_version(self) -> None:
"""
Expand Down
3 changes: 2 additions & 1 deletion libs/neo4j/tests/integration_tests/graphs/test_neo4j.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,9 @@ def test_neo4j_timeout() -> None:
try:
graph.query("UNWIND range(0,100000,1) AS i MERGE (:Foo {id:i})")
except Exception as e:
assert hasattr(e, "code")
assert (
e.code # type: ignore[attr-defined]
e.code
== "Neo.ClientError.Transaction.TransactionTimedOutClientConfiguration"
)

Expand Down
32 changes: 14 additions & 18 deletions libs/neo4j/tests/unit_tests/graphs/test_neo4j_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@
from langchain_neo4j.graphs.neo4j_graph import Neo4jGraph, value_sanitize


def test_value_sanitize_with_small_list(): # type: ignore[no-untyped-def]
def test_value_sanitize_with_small_list() -> None:
small_list = list(range(15)) # list size > LIST_LIMIT
input_dict = {"key1": "value1", "small_list": small_list}
expected_output = {"key1": "value1", "small_list": small_list}
assert value_sanitize(input_dict) == expected_output


def test_value_sanitize_with_oversized_list(): # type: ignore[no-untyped-def]
def test_value_sanitize_with_oversized_list() -> None:
oversized_list = list(range(150)) # list size > LIST_LIMIT
input_dict = {"key1": "value1", "oversized_list": oversized_list}
expected_output = {
Expand All @@ -22,21 +22,21 @@ def test_value_sanitize_with_oversized_list(): # type: ignore[no-untyped-def]
assert value_sanitize(input_dict) == expected_output


def test_value_sanitize_with_nested_oversized_list(): # type: ignore[no-untyped-def]
def test_value_sanitize_with_nested_oversized_list() -> None:
oversized_list = list(range(150)) # list size > LIST_LIMIT
input_dict = {"key1": "value1", "oversized_list": {"key": oversized_list}}
expected_output = {"key1": "value1", "oversized_list": {}}
assert value_sanitize(input_dict) == expected_output


def test_value_sanitize_with_dict_in_list(): # type: ignore[no-untyped-def]
def test_value_sanitize_with_dict_in_list() -> None:
oversized_list = list(range(150)) # list size > LIST_LIMIT
input_dict = {"key1": "value1", "oversized_list": [1, 2, {"key": oversized_list}]}
expected_output = {"key1": "value1", "oversized_list": [1, 2, {}]}
assert value_sanitize(input_dict) == expected_output


def test_value_sanitize_with_dict_in_nested_list(): # type: ignore[no-untyped-def]
def test_value_sanitize_with_dict_in_nested_list() -> None:
input_dict = {
"key1": "value1",
"deeply_nested_lists": [[[[{"final_nested_key": list(range(200))}]]]],
Expand All @@ -45,9 +45,9 @@ def test_value_sanitize_with_dict_in_nested_list(): # type: ignore[no-untyped-d
assert value_sanitize(input_dict) == expected_output


def test_driver_state_management(): # type: ignore[no-untyped-def]
def test_driver_state_management() -> None:
"""Comprehensive test for driver state management."""
with patch("neo4j.GraphDatabase.driver") as mock_driver:
with patch("neo4j.GraphDatabase.driver", autospec=True) as mock_driver:
# Setup mock driver
mock_driver_instance = MagicMock()
mock_driver.return_value = mock_driver_instance
Expand All @@ -60,7 +60,7 @@ def test_driver_state_management(): # type: ignore[no-untyped-def]

# Store original driver
original_driver = graph._driver
original_driver.close = MagicMock()
assert isinstance(original_driver.close, MagicMock)

# Test initial state
assert hasattr(graph, "_driver")
Expand All @@ -84,9 +84,9 @@ def test_driver_state_management(): # type: ignore[no-untyped-def]
graph.refresh_schema()


def test_close_method_removes_driver(): # type: ignore[no-untyped-def]
def test_close_method_removes_driver() -> None:
"""Test that close method removes the _driver attribute."""
with patch("neo4j.GraphDatabase.driver") as mock_driver:
with patch("neo4j.GraphDatabase.driver", autospec=True) as mock_driver:
# Configure mock to return a mock driver
mock_driver_instance = MagicMock()
mock_driver.return_value = mock_driver_instance
Expand All @@ -103,9 +103,7 @@ def test_close_method_removes_driver(): # type: ignore[no-untyped-def]

# Store a reference to the original driver
original_driver = graph._driver

# Ensure driver's close method can be mocked
original_driver.close = MagicMock()
assert isinstance(original_driver.close, MagicMock)

# Call close method
graph.close()
Expand All @@ -120,9 +118,9 @@ def test_close_method_removes_driver(): # type: ignore[no-untyped-def]
graph.close() # Should not raise any exception


def test_multiple_close_calls_safe(): # type: ignore[no-untyped-def]
def test_multiple_close_calls_safe() -> None:
"""Test that multiple close calls do not raise errors."""
with patch("neo4j.GraphDatabase.driver") as mock_driver:
with patch("neo4j.GraphDatabase.driver", autospec=True) as mock_driver:
# Configure mock to return a mock driver
mock_driver_instance = MagicMock()
mock_driver.return_value = mock_driver_instance
Expand All @@ -139,9 +137,7 @@ def test_multiple_close_calls_safe(): # type: ignore[no-untyped-def]

# Store a reference to the original driver
original_driver = graph._driver

# Mock the driver's close method
original_driver.close = MagicMock()
assert isinstance(original_driver.close, MagicMock)

# First close
graph.close()
Expand Down
6 changes: 3 additions & 3 deletions libs/neo4j/tests/unit_tests/vectorstores/test_neo4j.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ def test_get_search_index_query_invalid_search_type() -> None:

with pytest.raises(ValueError) as exc_info:
_get_search_index_query(
search_type=invalid_search_type, # type: ignore
search_type=invalid_search_type, # type: ignore[arg-type]
index_type=IndexType.NODE,
)

Expand Down Expand Up @@ -356,7 +356,7 @@ def test_check_if_not_null_with_none_value() -> None:

def test_handle_field_filter_invalid_field_type() -> None:
with pytest.raises(ValueError) as exc_info:
_handle_field_filter(field=123, value="some_value") # type: ignore
_handle_field_filter(field=123, value="some_value") # type: ignore[arg-type]
assert "field should be a string" in str(exc_info.value)


Expand Down Expand Up @@ -535,7 +535,7 @@ def test_neo4jvector_invalid_distance_strategy() -> None:
url="bolt://localhost:7687",
username="neo4j",
password="password",
distance_strategy="INVALID_STRATEGY", # type: ignore
distance_strategy="INVALID_STRATEGY", # type: ignore[arg-type]
)
assert "distance_strategy must be either 'EUCLIDEAN_DISTANCE' or 'COSINE'" in str(
exc_info.value
Expand Down
Loading