diff --git a/examples/customize/answer/langchain_compatiblity.py b/examples/customize/answer/langchain_compatiblity.py index 9c8b0e06..d3382849 100644 --- a/examples/customize/answer/langchain_compatiblity.py +++ b/examples/customize/answer/langchain_compatiblity.py @@ -30,14 +30,14 @@ driver, index_name=INDEX, retrieval_query="WITH node, score RETURN node.title as title, node.plot as plot", - embedder=embedder, # type: ignore + embedder=embedder, # type: ignore[arg-type, unused-ignore] ) llm = ChatOpenAI(model="gpt-4o", temperature=0) rag = GraphRAG( retriever=retriever, - llm=llm, # type: ignore + llm=llm, # type: ignore[arg-type, unused-ignore] ) result = rag.search("Tell me more about Avatar movies") diff --git a/src/neo4j_graphrag/experimental/components/resolver.py b/src/neo4j_graphrag/experimental/components/resolver.py index 4e07a1d6..1dee9e10 100644 --- a/src/neo4j_graphrag/experimental/components/resolver.py +++ b/src/neo4j_graphrag/experimental/components/resolver.py @@ -26,7 +26,7 @@ class EntityResolver(Component, abc.ABC): """Entity resolution base class Args: - driver (neo4j.driver): The Neo4j driver to connect to the database. + driver (neo4j.Driver): The Neo4j driver to connect to the database. filter_query (Optional[str]): Cypher query to select the entities to resolve. By default, all nodes with __Entity__ label are used """ @@ -47,7 +47,7 @@ class SinglePropertyExactMatchResolver(EntityResolver): """Resolve entities with same label and exact same property (default is "name"). Args: - driver (neo4j.driver): The Neo4j driver to connect to the database. + driver (neo4j.Driver): The Neo4j driver to connect to the database. filter_query (Optional[str]): To reduce the resolution scope, add a Cypher WHERE clause. resolve_property (str): The property that will be compared (default: "name"). If values match exactly, entities are merged. neo4j_database (Optional[str]): The name of the Neo4j database to write to. Defaults to 'neo4j' if not provided. @@ -94,7 +94,7 @@ async def run(self) -> ResolutionStats: if self.filter_query: match_query += self.filter_query stat_query = f"{match_query} RETURN count(entity) as c" - records = await execute_query( + records, _, _ = await execute_query( self.driver, stat_query, database_=self.database, @@ -130,7 +130,7 @@ async def run(self) -> ResolutionStats: "YIELD node " "RETURN count(node) as c " ) - records = await execute_query( + records, _, _ = await execute_query( self.driver, merge_nodes_query, database_=self.database, diff --git a/src/neo4j_graphrag/retrievers/text2cypher.py b/src/neo4j_graphrag/retrievers/text2cypher.py index 84c3a7a3..2bb2aa69 100644 --- a/src/neo4j_graphrag/retrievers/text2cypher.py +++ b/src/neo4j_graphrag/retrievers/text2cypher.py @@ -51,7 +51,7 @@ class Text2CypherRetriever(Retriever): then retrieves records from a Neo4j database using the generated Cypher query Args: - driver (neo4j.driver): The Neo4j Python driver. + driver (neo4j.Driver): The Neo4j Python driver. llm (neo4j_graphrag.generation.llm.LLMInterface): LLM object to generate the Cypher query. neo4j_schema (Optional[str]): Neo4j schema used to generate the Cypher query. examples (Optional[list[str], optional): Optional user input/query pairs for the LLM to use as examples. diff --git a/src/neo4j_graphrag/utils.py b/src/neo4j_graphrag/utils.py index 5f1b322f..77109046 100644 --- a/src/neo4j_graphrag/utils.py +++ b/src/neo4j_graphrag/utils.py @@ -14,8 +14,7 @@ # limitations under the License. from __future__ import annotations -import inspect -from typing import Any, Optional, Union +from typing import Any, Optional import neo4j @@ -28,12 +27,10 @@ def validate_search_query_input( async def execute_query( - driver: Union[neo4j.Driver, neo4j.AsyncDriver], query: str, **kwargs: Any -) -> list[neo4j.Record]: - if inspect.iscoroutinefunction(driver.execute_query): - records, _, _ = await driver.execute_query(query, **kwargs) - return records # type: ignore[no-any-return] - # ignoring type because mypy complains about coroutine - # but we're sure at this stage we do not have a coroutine anymore - records, _, _ = driver.execute_query(query, **kwargs) # type: ignore[misc] - return records # type: ignore[no-any-return] + driver: neo4j.Driver | neo4j.AsyncDriver, query: str, **kwargs: Any +) -> Any: + if isinstance(driver, neo4j.AsyncDriver): + result = await driver.execute_query(query, **kwargs) + else: + result = driver.execute_query(query, **kwargs) + return result