Skip to content

Commit

Permalink
Removes unneeded # type: ignore comments (neo4j#198)
Browse files Browse the repository at this point in the history
* Fixes mypy issue for execute_query function

* Updated type ignore comments in langchain_compatiblity.py
  • Loading branch information
alexthomas93 authored Oct 22, 2024
1 parent fd9cd18 commit fd276af
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 18 deletions.
4 changes: 2 additions & 2 deletions examples/customize/answer/langchain_compatiblity.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,14 @@
driver,
index_name=INDEX,
retrieval_query="WITH node, score RETURN node.title as title, node.plot as plot",
embedder=embedder, # type: ignore
embedder=embedder, # type: ignore[arg-type, unused-ignore]
)

llm = ChatOpenAI(model="gpt-4o", temperature=0)

rag = GraphRAG(
retriever=retriever,
llm=llm, # type: ignore
llm=llm, # type: ignore[arg-type, unused-ignore]
)

result = rag.search("Tell me more about Avatar movies")
Expand Down
8 changes: 4 additions & 4 deletions src/neo4j_graphrag/experimental/components/resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class EntityResolver(Component, abc.ABC):
"""Entity resolution base class
Args:
driver (neo4j.driver): The Neo4j driver to connect to the database.
driver (neo4j.Driver): The Neo4j driver to connect to the database.
filter_query (Optional[str]): Cypher query to select the entities to resolve. By default, all nodes with __Entity__ label are used
"""

Expand All @@ -47,7 +47,7 @@ class SinglePropertyExactMatchResolver(EntityResolver):
"""Resolve entities with same label and exact same property (default is "name").
Args:
driver (neo4j.driver): The Neo4j driver to connect to the database.
driver (neo4j.Driver): The Neo4j driver to connect to the database.
filter_query (Optional[str]): To reduce the resolution scope, add a Cypher WHERE clause.
resolve_property (str): The property that will be compared (default: "name"). If values match exactly, entities are merged.
neo4j_database (Optional[str]): The name of the Neo4j database to write to. Defaults to 'neo4j' if not provided.
Expand Down Expand Up @@ -94,7 +94,7 @@ async def run(self) -> ResolutionStats:
if self.filter_query:
match_query += self.filter_query
stat_query = f"{match_query} RETURN count(entity) as c"
records = await execute_query(
records, _, _ = await execute_query(
self.driver,
stat_query,
database_=self.database,
Expand Down Expand Up @@ -130,7 +130,7 @@ async def run(self) -> ResolutionStats:
"YIELD node "
"RETURN count(node) as c "
)
records = await execute_query(
records, _, _ = await execute_query(
self.driver,
merge_nodes_query,
database_=self.database,
Expand Down
2 changes: 1 addition & 1 deletion src/neo4j_graphrag/retrievers/text2cypher.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class Text2CypherRetriever(Retriever):
then retrieves records from a Neo4j database using the generated Cypher query
Args:
driver (neo4j.driver): The Neo4j Python driver.
driver (neo4j.Driver): The Neo4j Python driver.
llm (neo4j_graphrag.generation.llm.LLMInterface): LLM object to generate the Cypher query.
neo4j_schema (Optional[str]): Neo4j schema used to generate the Cypher query.
examples (Optional[list[str], optional): Optional user input/query pairs for the LLM to use as examples.
Expand Down
19 changes: 8 additions & 11 deletions src/neo4j_graphrag/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@
# limitations under the License.
from __future__ import annotations

import inspect
from typing import Any, Optional, Union
from typing import Any, Optional

import neo4j

Expand All @@ -28,12 +27,10 @@ def validate_search_query_input(


async def execute_query(
driver: Union[neo4j.Driver, neo4j.AsyncDriver], query: str, **kwargs: Any
) -> list[neo4j.Record]:
if inspect.iscoroutinefunction(driver.execute_query):
records, _, _ = await driver.execute_query(query, **kwargs)
return records # type: ignore[no-any-return]
# ignoring type because mypy complains about coroutine
# but we're sure at this stage we do not have a coroutine anymore
records, _, _ = driver.execute_query(query, **kwargs) # type: ignore[misc]
return records # type: ignore[no-any-return]
driver: neo4j.Driver | neo4j.AsyncDriver, query: str, **kwargs: Any
) -> Any:
if isinstance(driver, neo4j.AsyncDriver):
result = await driver.execute_query(query, **kwargs)
else:
result = driver.execute_query(query, **kwargs)
return result

0 comments on commit fd276af

Please sign in to comment.