Skip to content

Commit

Permalink
Fixed query parameter bug in GraphRAG class
Browse files Browse the repository at this point in the history
  • Loading branch information
alexthomas93 committed Aug 22, 2024
1 parent 6e544df commit 070b46e
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 7 deletions.
14 changes: 7 additions & 7 deletions src/neo4j_genai/generation/graphrag.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,13 +85,13 @@ def search(
DeprecationWarning,
stacklevel=2,
)
elif isinstance(query, str):
warnings.warn(
"'query' is deprecated and will be removed in a future version, please use 'query_text' instead.",
DeprecationWarning,
stacklevel=2,
)
query_text = query
elif isinstance(query, str):
warnings.warn(
"'query' is deprecated and will be removed in a future version, please use 'query_text' instead.",
DeprecationWarning,
stacklevel=2,
)
query_text = query

validated_data = RagSearchModel(
query_text=query_text,
Expand Down
20 changes: 20 additions & 0 deletions tests/unit/test_graphrag.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest.mock import MagicMock
from warnings import catch_warnings

import pytest
from neo4j_genai.exceptions import RagInitializationError, SearchValidationError
Expand All @@ -21,6 +22,7 @@
from neo4j_genai.generation.types import RagResultModel
from neo4j_genai.llm import LLMResponse
from neo4j_genai.types import RetrieverResult, RetrieverResultItem
from pydantic import ValidationError


def test_graphrag_prompt_template() -> None:
Expand Down Expand Up @@ -99,3 +101,21 @@ def test_graphrag_search_error(retriever_mock: MagicMock, llm: MagicMock) -> Non
with pytest.raises(SearchValidationError) as excinfo:
rag.search(10) # type: ignore
assert "Input should be a valid string" in str(excinfo)


def test_graphrag_search_query_deprecation_warning(
retriever_mock: MagicMock, llm: MagicMock
) -> None:
with catch_warnings(record=True) as warn_list:
rag = GraphRAG(
retriever=retriever_mock,
llm=llm,
)
with pytest.raises(ValidationError):
rag.search(query="Some query text")

assert len(warn_list) == 1
assert (
str(warn_list[0].message)
== "'query' is deprecated and will be removed in a future version, please use 'query_text' instead."
)

0 comments on commit 070b46e

Please sign in to comment.