Skip to content

Commit

Permalink
Removes #type: ignore comments (#19)
Browse files Browse the repository at this point in the history
* Removes unneeded type ignore comments

* Removed more unneeded type ignore comments

* Updated CHANGELOG
  • Loading branch information
alexthomas93 authored Dec 9, 2024
1 parent 2392837 commit 199d500
Show file tree
Hide file tree
Showing 7 changed files with 74 additions and 50 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
### Fixed

- Removed deprecated LLMChain from GraphCypherQAChain to resolve instantiation issues with the use_function_response parameter.
- Removed unnecessary # type: ignore comments, improving type safety and code clarity.

## 0.1.1

Expand Down
58 changes: 40 additions & 18 deletions libs/neo4j/langchain_neo4j/chains/graph_qa/cypher.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,8 +164,8 @@ class GraphCypherQAChain(Chain):
"""

graph: GraphStore = Field(exclude=True)
cypher_generation_chain: Runnable
qa_chain: Runnable
cypher_generation_chain: Runnable[Dict[str, Any], str]
qa_chain: Runnable[Dict[str, Any], str]
graph_schema: str
input_key: str = "query" #: :meta private:
output_key: str = "result" #: :meta private:
Expand Down Expand Up @@ -239,7 +239,7 @@ def from_llm(
qa_prompt: Optional[BasePromptTemplate] = None,
cypher_prompt: Optional[BasePromptTemplate] = None,
cypher_llm: Optional[BaseLanguageModel] = None,
qa_llm: Optional[Union[BaseLanguageModel, Any]] = None,
qa_llm: Optional[BaseLanguageModel] = None,
exclude_types: List[str] = [],
include_types: List[str] = [],
validate_cypher: bool = False,
Expand All @@ -250,16 +250,28 @@ def from_llm(
**kwargs: Any,
) -> GraphCypherQAChain:
"""Initialize from LLM."""
# Ensure at least one LLM is provided
if llm is None and qa_llm is None and cypher_llm is None:
raise ValueError("At least one LLM must be provided")

if not cypher_llm and not llm:
raise ValueError("Either `llm` or `cypher_llm` parameters must be provided")
if not qa_llm and not llm:
raise ValueError("Either `llm` or `qa_llm` parameters must be provided")
if cypher_llm and qa_llm and llm:
# Prevent all three LLMs from being provided simultaneously
if llm is not None and qa_llm is not None and cypher_llm is not None:
raise ValueError(
"You can specify up to two of 'cypher_llm', 'qa_llm'"
", and 'llm', but not all three simultaneously."
)

# Assign default LLMs if specific ones are not provided
if llm is not None:
qa_llm = qa_llm or llm
cypher_llm = cypher_llm or llm
else:
# If llm is None, both qa_llm and cypher_llm must be provided
if qa_llm is None or cypher_llm is None:
raise ValueError(
"If `llm` is not provided, both `qa_llm` and `cypher_llm` must be "
"provided."
)
if cypher_prompt:
if cypher_llm_kwargs:
raise ValueError(
Expand All @@ -271,6 +283,11 @@ def from_llm(
cypher_prompt = cypher_llm_kwargs.pop(
"prompt", CYPHER_GENERATION_PROMPT
)
if not isinstance(cypher_prompt, BasePromptTemplate):
raise ValueError(
"The cypher_llm_kwargs `prompt` must inherit from "
"BasePromptTemplate"
)
else:
cypher_prompt = CYPHER_GENERATION_PROMPT
if qa_prompt:
Expand All @@ -282,33 +299,38 @@ def from_llm(
else:
if qa_llm_kwargs:
qa_prompt = qa_llm_kwargs.pop("prompt", CYPHER_QA_PROMPT)
if not isinstance(qa_prompt, BasePromptTemplate):
raise ValueError(
"The qa_llm_kwargs `prompt` must inherit from "
"BasePromptTemplate"
)
else:
qa_prompt = CYPHER_QA_PROMPT
use_qa_llm_kwargs = qa_llm_kwargs if qa_llm_kwargs is not None else {}
use_cypher_llm_kwargs = (
cypher_llm_kwargs if cypher_llm_kwargs is not None else {}
)

qa_llm = qa_llm or llm
if use_function_response:
try:
qa_llm.bind_tools({}) # type: ignore[union-attr]
if hasattr(qa_llm, "bind_tools"):
qa_llm.bind_tools({})
else:
raise AttributeError
response_prompt = ChatPromptTemplate.from_messages(
[
SystemMessage(content=function_response_system),
HumanMessagePromptTemplate.from_template("{question}"),
MessagesPlaceholder(variable_name="function_response"),
]
)
qa_chain = response_prompt | qa_llm | StrOutputParser() # type: ignore
qa_chain = response_prompt | qa_llm | StrOutputParser()
except (NotImplementedError, AttributeError):
raise ValueError("Provided LLM does not support native tools/functions")
else:
qa_chain = qa_prompt | qa_llm.bind(**use_qa_llm_kwargs) | StrOutputParser() # type: ignore

cypher_llm = cypher_llm or llm
qa_chain = qa_prompt | qa_llm.bind(**use_qa_llm_kwargs) | StrOutputParser()
cypher_generation_chain = (
cypher_prompt | cypher_llm.bind(**use_cypher_llm_kwargs) | StrOutputParser() # type: ignore
cypher_prompt | cypher_llm.bind(**use_cypher_llm_kwargs) | StrOutputParser()
)

if exclude_types and include_types:
Expand Down Expand Up @@ -379,6 +401,7 @@ def _call(
else:
context = []

final_result: Union[List[Dict[str, Any]], str]
if self.return_direct:
final_result = context
else:
Expand All @@ -390,15 +413,14 @@ def _call(
intermediate_steps.append({"context": context})
if self.use_function_response:
function_response = get_function_response(question, context)
final_result = self.qa_chain.invoke( # type: ignore
final_result = self.qa_chain.invoke(
{"question": question, "function_response": function_response},
)
else:
result = self.qa_chain.invoke( # type: ignore
final_result = self.qa_chain.invoke(
{"question": question, "context": context},
callbacks=callbacks,
)
final_result = result # type: ignore

chain_result: Dict[str, Any] = {self.output_key: final_result}
if self.return_intermediate_steps:
Expand Down
12 changes: 7 additions & 5 deletions libs/neo4j/langchain_neo4j/graphs/neo4j_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,21 +461,23 @@ def query(
or e.code
== "Neo.DatabaseError.Transaction.TransactionStartFailed"
)
and "in an implicit transaction" in e.message # type: ignore
and e.message is not None
and "in an implicit transaction" in e.message
)
or ( # isPeriodicCommitError
e.code == "Neo.ClientError.Statement.SemanticError"
and e.message is not None
and (
"in an open transaction is not possible" in e.message # type: ignore
or "tried to execute in an explicit transaction" in e.message # type: ignore
"in an open transaction is not possible" in e.message
or "tried to execute in an explicit transaction" in e.message
)
)
):
raise
# fallback to allow implicit transactions
with self._driver.session(database=self._database) as session:
data = session.run(Query(text=query, timeout=self.timeout), params) # type: ignore
json_data = [r.data() for r in data]
result = session.run(Query(text=query, timeout=self.timeout), params)
json_data = [r.data() for r in result]
if self.sanitize:
json_data = [value_sanitize(el) for el in json_data]
return json_data
Expand Down
12 changes: 7 additions & 5 deletions libs/neo4j/langchain_neo4j/vectorstores/neo4j_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -655,21 +655,23 @@ def query(
or e.code
== "Neo.DatabaseError.Transaction.TransactionStartFailed"
)
and "in an implicit transaction" in e.message # type: ignore
and e.message is not None
and "in an implicit transaction" in e.message
)
or ( # isPeriodicCommitError
e.code == "Neo.ClientError.Statement.SemanticError"
and e.message is not None
and (
"in an open transaction is not possible" in e.message # type: ignore
or "tried to execute in an explicit transaction" in e.message # type: ignore
"in an open transaction is not possible" in e.message
or "tried to execute in an explicit transaction" in e.message
)
)
):
raise
# Fallback to allow implicit transactions
with self._driver.session(database=self._database) as session:
data = session.run(Query(text=query), params) # type: ignore
return [r.data() for r in data]
result = session.run(Query(text=query), params)
return [r.data() for r in result]

def verify_version(self) -> None:
"""
Expand Down
3 changes: 2 additions & 1 deletion libs/neo4j/tests/integration_tests/graphs/test_neo4j.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,9 @@ def test_neo4j_timeout() -> None:
try:
graph.query("UNWIND range(0,100000,1) AS i MERGE (:Foo {id:i})")
except Exception as e:
assert hasattr(e, "code")
assert (
e.code # type: ignore[attr-defined]
e.code
== "Neo.ClientError.Transaction.TransactionTimedOutClientConfiguration"
)

Expand Down
32 changes: 14 additions & 18 deletions libs/neo4j/tests/unit_tests/graphs/test_neo4j_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@
from langchain_neo4j.graphs.neo4j_graph import Neo4jGraph, value_sanitize


def test_value_sanitize_with_small_list(): # type: ignore[no-untyped-def]
def test_value_sanitize_with_small_list() -> None:
small_list = list(range(15)) # list size > LIST_LIMIT
input_dict = {"key1": "value1", "small_list": small_list}
expected_output = {"key1": "value1", "small_list": small_list}
assert value_sanitize(input_dict) == expected_output


def test_value_sanitize_with_oversized_list(): # type: ignore[no-untyped-def]
def test_value_sanitize_with_oversized_list() -> None:
oversized_list = list(range(150)) # list size > LIST_LIMIT
input_dict = {"key1": "value1", "oversized_list": oversized_list}
expected_output = {
Expand All @@ -22,21 +22,21 @@ def test_value_sanitize_with_oversized_list(): # type: ignore[no-untyped-def]
assert value_sanitize(input_dict) == expected_output


def test_value_sanitize_with_nested_oversized_list(): # type: ignore[no-untyped-def]
def test_value_sanitize_with_nested_oversized_list() -> None:
oversized_list = list(range(150)) # list size > LIST_LIMIT
input_dict = {"key1": "value1", "oversized_list": {"key": oversized_list}}
expected_output = {"key1": "value1", "oversized_list": {}}
assert value_sanitize(input_dict) == expected_output


def test_value_sanitize_with_dict_in_list(): # type: ignore[no-untyped-def]
def test_value_sanitize_with_dict_in_list() -> None:
oversized_list = list(range(150)) # list size > LIST_LIMIT
input_dict = {"key1": "value1", "oversized_list": [1, 2, {"key": oversized_list}]}
expected_output = {"key1": "value1", "oversized_list": [1, 2, {}]}
assert value_sanitize(input_dict) == expected_output


def test_value_sanitize_with_dict_in_nested_list(): # type: ignore[no-untyped-def]
def test_value_sanitize_with_dict_in_nested_list() -> None:
input_dict = {
"key1": "value1",
"deeply_nested_lists": [[[[{"final_nested_key": list(range(200))}]]]],
Expand All @@ -45,9 +45,9 @@ def test_value_sanitize_with_dict_in_nested_list(): # type: ignore[no-untyped-d
assert value_sanitize(input_dict) == expected_output


def test_driver_state_management(): # type: ignore[no-untyped-def]
def test_driver_state_management() -> None:
"""Comprehensive test for driver state management."""
with patch("neo4j.GraphDatabase.driver") as mock_driver:
with patch("neo4j.GraphDatabase.driver", autospec=True) as mock_driver:
# Setup mock driver
mock_driver_instance = MagicMock()
mock_driver.return_value = mock_driver_instance
Expand All @@ -60,7 +60,7 @@ def test_driver_state_management(): # type: ignore[no-untyped-def]

# Store original driver
original_driver = graph._driver
original_driver.close = MagicMock()
assert isinstance(original_driver.close, MagicMock)

# Test initial state
assert hasattr(graph, "_driver")
Expand All @@ -84,9 +84,9 @@ def test_driver_state_management(): # type: ignore[no-untyped-def]
graph.refresh_schema()


def test_close_method_removes_driver(): # type: ignore[no-untyped-def]
def test_close_method_removes_driver() -> None:
"""Test that close method removes the _driver attribute."""
with patch("neo4j.GraphDatabase.driver") as mock_driver:
with patch("neo4j.GraphDatabase.driver", autospec=True) as mock_driver:
# Configure mock to return a mock driver
mock_driver_instance = MagicMock()
mock_driver.return_value = mock_driver_instance
Expand All @@ -103,9 +103,7 @@ def test_close_method_removes_driver(): # type: ignore[no-untyped-def]

# Store a reference to the original driver
original_driver = graph._driver

# Ensure driver's close method can be mocked
original_driver.close = MagicMock()
assert isinstance(original_driver.close, MagicMock)

# Call close method
graph.close()
Expand All @@ -120,9 +118,9 @@ def test_close_method_removes_driver(): # type: ignore[no-untyped-def]
graph.close() # Should not raise any exception


def test_multiple_close_calls_safe(): # type: ignore[no-untyped-def]
def test_multiple_close_calls_safe() -> None:
"""Test that multiple close calls do not raise errors."""
with patch("neo4j.GraphDatabase.driver") as mock_driver:
with patch("neo4j.GraphDatabase.driver", autospec=True) as mock_driver:
# Configure mock to return a mock driver
mock_driver_instance = MagicMock()
mock_driver.return_value = mock_driver_instance
Expand All @@ -139,9 +137,7 @@ def test_multiple_close_calls_safe(): # type: ignore[no-untyped-def]

# Store a reference to the original driver
original_driver = graph._driver

# Mock the driver's close method
original_driver.close = MagicMock()
assert isinstance(original_driver.close, MagicMock)

# First close
graph.close()
Expand Down
6 changes: 3 additions & 3 deletions libs/neo4j/tests/unit_tests/vectorstores/test_neo4j.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ def test_get_search_index_query_invalid_search_type() -> None:

with pytest.raises(ValueError) as exc_info:
_get_search_index_query(
search_type=invalid_search_type, # type: ignore
search_type=invalid_search_type, # type: ignore[arg-type]
index_type=IndexType.NODE,
)

Expand Down Expand Up @@ -356,7 +356,7 @@ def test_check_if_not_null_with_none_value() -> None:

def test_handle_field_filter_invalid_field_type() -> None:
with pytest.raises(ValueError) as exc_info:
_handle_field_filter(field=123, value="some_value") # type: ignore
_handle_field_filter(field=123, value="some_value") # type: ignore[arg-type]
assert "field should be a string" in str(exc_info.value)


Expand Down Expand Up @@ -535,7 +535,7 @@ def test_neo4jvector_invalid_distance_strategy() -> None:
url="bolt://localhost:7687",
username="neo4j",
password="password",
distance_strategy="INVALID_STRATEGY", # type: ignore
distance_strategy="INVALID_STRATEGY", # type: ignore[arg-type]
)
assert "distance_strategy must be either 'EUCLIDEAN_DISTANCE' or 'COSINE'" in str(
exc_info.value
Expand Down

0 comments on commit 199d500

Please sign in to comment.