From 86f88bff2c70f4364c39f3acccfb36f5a7c5fde5 Mon Sep 17 00:00:00 2001 From: Estelle Scifo Date: Mon, 7 Oct 2024 17:24:45 +0200 Subject: [PATCH] Fix imports when some optional dependencies are not installed (#168) * Fix sentence-transformer embedding import * Fix import when openAI is not installed * Update changelog * Fix for mypy * ruff --- CHANGELOG.md | 3 + src/neo4j_graphrag/embeddings/openai.py | 34 ++++++---- .../embeddings/sentence_transformers.py | 18 ++--- src/neo4j_graphrag/llm/openai_llm.py | 65 ++++++++++++++----- tests/unit/embeddings/test_openai_embedder.py | 8 +-- .../embeddings/test_sentence_transformers.py | 21 +++--- tests/unit/llm/test_openai_llm.py | 12 ++-- 7 files changed, 106 insertions(+), 55 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index cb7d74f4..59ad9fcc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,9 @@ ## Next +### Fixed +- Fix a bug where `openai` Python client and `numpy` were required to import any embedder or LLM. + ## 1.0.0a1 ## 1.0.0a0 diff --git a/src/neo4j_graphrag/embeddings/openai.py b/src/neo4j_graphrag/embeddings/openai.py index 6907059c..be10619d 100644 --- a/src/neo4j_graphrag/embeddings/openai.py +++ b/src/neo4j_graphrag/embeddings/openai.py @@ -15,7 +15,8 @@ from __future__ import annotations -from typing import Any, Type +import abc +from typing import Any from neo4j_graphrag.embeddings.base import Embedder @@ -25,26 +26,29 @@ openai = None # type: ignore -class OpenAIEmbeddings(Embedder): +class BaseOpenAIEmbeddings(Embedder, abc.ABC): + def __init__(self, model: str = "text-embedding-ada-002", **kwargs: Any) -> None: + if openai is None: + raise ImportError( + "Could not import openai python client. " + "Please install it with `pip install openai`." + ) + self.model = model + + +class OpenAIEmbeddings(BaseOpenAIEmbeddings): """ OpenAI embeddings class. This class uses the OpenAI python client to generate embeddings for text data. Args: model (str): The name of the OpenAI embedding model to use. Defaults to "text-embedding-ada-002". + kwargs: All other parameters will be passed to the openai.OpenAI init. """ - client_class: Type[openai.OpenAI] = openai.OpenAI - def __init__(self, model: str = "text-embedding-ada-002", **kwargs: Any) -> None: - if openai is None: - raise ImportError( - "Could not import openai python client. " - "Please install it with `pip install openai`." - ) - - self.openai_model = self.client_class(**kwargs) - self.model = model + super().__init__(model, **kwargs) + self.openai_client = openai.OpenAI(**kwargs) def embed_query(self, text: str, **kwargs: Any) -> list[float]: """ @@ -54,11 +58,13 @@ def embed_query(self, text: str, **kwargs: Any) -> list[float]: text (str): The text to generate an embedding for. **kwargs (Any): Additional arguments to pass to the OpenAI embedding generation function. """ - response = self.openai_model.embeddings.create( + response = self.openai_client.embeddings.create( input=text, model=self.model, **kwargs ) return response.data[0].embedding class AzureOpenAIEmbeddings(OpenAIEmbeddings): - client_class: Type[openai.OpenAI] = openai.AzureOpenAI + def __init__(self, model: str = "text-embedding-ada-002", **kwargs: Any) -> None: + super().__init__(model, **kwargs) + self.openai_client = openai.AzureOpenAI(**kwargs) diff --git a/src/neo4j_graphrag/embeddings/sentence_transformers.py b/src/neo4j_graphrag/embeddings/sentence_transformers.py index cfdbeaa5..9e5c2da1 100644 --- a/src/neo4j_graphrag/embeddings/sentence_transformers.py +++ b/src/neo4j_graphrag/embeddings/sentence_transformers.py @@ -15,8 +15,13 @@ from typing import Any -import numpy as np -import torch +try: + import numpy as np + import sentence_transformers + import torch +except ImportError: + sentence_transformers = None # type: ignore + from neo4j_graphrag.embeddings.base import Embedder @@ -25,15 +30,12 @@ class SentenceTransformerEmbeddings(Embedder): def __init__( self, model: str = "all-MiniLM-L6-v2", *args: Any, **kwargs: Any ) -> None: - try: - from sentence_transformers import SentenceTransformer - except ImportError as e: + if sentence_transformers is None: raise ImportError( "Could not import sentence_transformers python package. " "Please install it with `pip install sentence-transformers`." - ) from e - - self.model = SentenceTransformer(model, *args, **kwargs) + ) + self.model = sentence_transformers.SentenceTransformer(model, *args, **kwargs) def embed_query(self, text: str) -> Any: result = self.model.encode([text]) diff --git a/src/neo4j_graphrag/llm/openai_llm.py b/src/neo4j_graphrag/llm/openai_llm.py index e29813d4..50d33524 100644 --- a/src/neo4j_graphrag/llm/openai_llm.py +++ b/src/neo4j_graphrag/llm/openai_llm.py @@ -14,7 +14,8 @@ # limitations under the License. from __future__ import annotations -from typing import Any, Optional, Type +import abc +from typing import Any, Optional from ..exceptions import LLMGenerationError from .base import LLMInterface @@ -26,32 +27,30 @@ openai = None # type: ignore -class OpenAILLM(LLMInterface): - client_class: Type[openai.OpenAI] = openai.OpenAI - async_client_class: Type[openai.AsyncOpenAI] = openai.AsyncOpenAI +class BaseOpenAILLM(LLMInterface, abc.ABC): + client: Any + async_client: Any def __init__( self, model_name: str, model_params: Optional[dict[str, Any]] = None, - **kwargs: Any, ): """ + Base class for OpenAI LLM. + + Makes sure the openai Python client is installed during init. Args: model_name (str): model_params (str): Parameters like temperature that will be passed to the model when text is sent to it - kwargs: All other parameters will be passed to the openai.OpenAI init. - """ if openai is None: raise ImportError( "Could not import openai Python client. " "Please install it with `pip install openai`." ) - super().__init__(model_name, model_params, **kwargs) - self.client = self.client_class(**kwargs) - self.async_client = self.async_client_class(**kwargs) + super().__init__(model_name, model_params) def get_messages( self, @@ -76,7 +75,7 @@ def invoke(self, input: str) -> LLMResponse: """ try: response = self.client.chat.completions.create( - messages=self.get_messages(input), # type: ignore + messages=self.get_messages(input), model=self.model_name, **self.model_params, ) @@ -100,7 +99,7 @@ async def ainvoke(self, input: str) -> LLMResponse: """ try: response = await self.async_client.chat.completions.create( - messages=self.get_messages(input), # type: ignore + messages=self.get_messages(input), model=self.model_name, **self.model_params, ) @@ -110,6 +109,42 @@ async def ainvoke(self, input: str) -> LLMResponse: raise LLMGenerationError(e) -class AzureOpenAILLM(OpenAILLM): - client_class: Type[openai.OpenAI] = openai.AzureOpenAI - async_client_class: Type[openai.AsyncOpenAI] = openai.AsyncAzureOpenAI +class OpenAILLM(BaseOpenAILLM): + def __init__( + self, + model_name: str, + model_params: Optional[dict[str, Any]] = None, + **kwargs: Any, + ): + """OpenAI LLM + + Wrapper for the openai Python client LLM. + + Args: + model_name (str): + model_params (str): Parameters like temperature that will be passed to the model when text is sent to it + kwargs: All other parameters will be passed to the openai.OpenAI init. + """ + super().__init__(model_name, model_params) + self.client = openai.OpenAI(**kwargs) + self.async_client = openai.AsyncOpenAI(**kwargs) + + +class AzureOpenAILLM(BaseOpenAILLM): + def __init__( + self, + model_name: str, + model_params: Optional[dict[str, Any]] = None, + **kwargs: Any, + ): + """Azure OpenAI LLM. Use this class when using an OpenAI model + hosted on Microsoft Azure. + + Args: + model_name (str): + model_params (str): Parameters like temperature that will be passed to the model when text is sent to it + kwargs: All other parameters will be passed to the openai.OpenAI init. + """ + super().__init__(model_name, model_params) + self.client = openai.AzureOpenAI(**kwargs) + self.async_client = openai.AsyncAzureOpenAI(**kwargs) diff --git a/tests/unit/embeddings/test_openai_embedder.py b/tests/unit/embeddings/test_openai_embedder.py index 2b96e936..fbe259ce 100644 --- a/tests/unit/embeddings/test_openai_embedder.py +++ b/tests/unit/embeddings/test_openai_embedder.py @@ -27,9 +27,9 @@ def test_openai_embedder_missing_dependency() -> None: OpenAIEmbeddings() -@patch("neo4j_graphrag.embeddings.openai.OpenAIEmbeddings.client_class") +@patch("neo4j_graphrag.embeddings.openai.openai") def test_openai_embedder_happy_path(mock_openai: Mock) -> None: - mock_openai.return_value.embeddings.create.return_value = MagicMock( + mock_openai.OpenAI.return_value.embeddings.create.return_value = MagicMock( data=[MagicMock(embedding=[1.0, 2.0])], ) embedder = OpenAIEmbeddings(api_key="my key") @@ -44,9 +44,9 @@ def test_azure_openai_embedder_missing_dependency() -> None: AzureOpenAIEmbeddings() -@patch("neo4j_graphrag.embeddings.openai.AzureOpenAIEmbeddings.client_class") +@patch("neo4j_graphrag.embeddings.openai.openai") def test_azure_openai_embedder_happy_path(mock_openai: Mock) -> None: - mock_openai.return_value.embeddings.create.return_value = MagicMock( + mock_openai.AzureOpenAI.return_value.embeddings.create.return_value = MagicMock( data=[MagicMock(embedding=[1.0, 2.0])], ) embedder = AzureOpenAIEmbeddings( diff --git a/tests/unit/embeddings/test_sentence_transformers.py b/tests/unit/embeddings/test_sentence_transformers.py index 07dc363e..fd1798aa 100644 --- a/tests/unit/embeddings/test_sentence_transformers.py +++ b/tests/unit/embeddings/test_sentence_transformers.py @@ -8,34 +8,37 @@ ) -@patch("sentence_transformers.SentenceTransformer") +@patch("neo4j_graphrag.embeddings.sentence_transformers.sentence_transformers") def test_initialization(MockSentenceTransformer: MagicMock) -> None: instance = SentenceTransformerEmbeddings() - MockSentenceTransformer.assert_called_with("all-MiniLM-L6-v2") + MockSentenceTransformer.SentenceTransformer.assert_called_with("all-MiniLM-L6-v2") assert isinstance(instance, Embedder) -@patch("sentence_transformers.SentenceTransformer") +@patch("neo4j_graphrag.embeddings.sentence_transformers.sentence_transformers") def test_initialization_with_custom_model(MockSentenceTransformer: MagicMock) -> None: custom_model = "distilbert-base-nli-stsb-mean-tokens" SentenceTransformerEmbeddings(model=custom_model) - MockSentenceTransformer.assert_called_with(custom_model) + MockSentenceTransformer.SentenceTransformer.assert_called_with(custom_model) -@patch("sentence_transformers.SentenceTransformer") +@patch("neo4j_graphrag.embeddings.sentence_transformers.sentence_transformers") def test_embed_query(MockSentenceTransformer: MagicMock) -> None: - mock_model = MockSentenceTransformer.return_value + mock_model = MockSentenceTransformer.SentenceTransformer.return_value mock_model.encode.return_value = np.array([[0.1, 0.2, 0.3]]) instance = SentenceTransformerEmbeddings() result = instance.embed_query("test query") mock_model.encode.assert_called_with(["test query"]) - assert result == [0.1, 0.2, 0.3] assert isinstance(result, list) + assert result == [0.1, 0.2, 0.3] -@patch("sentence_transformers.SentenceTransformer", side_effect=ImportError) -def test_import_error(MockSentenceTransformer: MagicMock) -> None: +@patch( + "neo4j_graphrag.embeddings.sentence_transformers.sentence_transformers", + None, +) +def test_import_error() -> None: with pytest.raises(ImportError): SentenceTransformerEmbeddings() diff --git a/tests/unit/llm/test_openai_llm.py b/tests/unit/llm/test_openai_llm.py index 04bee3ca..3cf0e7d0 100644 --- a/tests/unit/llm/test_openai_llm.py +++ b/tests/unit/llm/test_openai_llm.py @@ -25,9 +25,9 @@ def test_openai_llm_missing_dependency() -> None: OpenAILLM(model_name="gpt-4o") -@patch("neo4j_graphrag.llm.openai_llm.OpenAILLM.client_class") +@patch("neo4j_graphrag.llm.openai_llm.openai") def test_openai_llm_happy_path(mock_openai: Mock) -> None: - mock_openai.return_value.chat.completions.create.return_value = MagicMock( + mock_openai.OpenAI.return_value.chat.completions.create.return_value = MagicMock( choices=[MagicMock(message=MagicMock(content="openai chat response"))], ) llm = OpenAILLM(api_key="my key", model_name="gpt") @@ -42,10 +42,12 @@ def test_azure_openai_llm_missing_dependency() -> None: AzureOpenAILLM(model_name="gpt-4o") -@patch("neo4j_graphrag.llm.openai_llm.AzureOpenAILLM.client_class") +@patch("neo4j_graphrag.llm.openai_llm.openai") def test_azure_openai_llm_happy_path(mock_openai: Mock) -> None: - mock_openai.return_value.chat.completions.create.return_value = MagicMock( - choices=[MagicMock(message=MagicMock(content="openai chat response"))], + mock_openai.AzureOpenAI.return_value.chat.completions.create.return_value = ( + MagicMock( + choices=[MagicMock(message=MagicMock(content="openai chat response"))], + ) ) llm = AzureOpenAILLM( model_name="gpt",