Fix imports when some optional dependencies are not installed (neo4j#168)

* Fix sentence-transformer embedding import

* Fix import when openAI is not installed

* Update changelog

* Fix for mypy

* ruff
stellasia authored Oct 7, 2024
1 parent cfeda57 commit 86f88bf
Showing 7 changed files with 106 additions and 55 deletions.
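
The source changes all apply the same optional-dependency pattern: import the extra at module level inside try/except, fall back to a None sentinel so the module itself always imports, and raise an actionable ImportError only when a class that needs the package is instantiated. A minimal sketch of the pattern (`heavy_lib` is a placeholder name, not a real dependency of this repo):

    # Optional-dependency guard: the module import always succeeds.
    try:
        import heavy_lib  # hypothetical optional extra
    except ImportError:
        heavy_lib = None  # type: ignore


    class HeavyLibClient:
        def __init__(self) -> None:
            # Defer the failure from import time to instantiation time,
            # so `from package import HeavyLibClient` works without the extra.
            if heavy_lib is None:
                raise ImportError(
                    "Could not import heavy_lib. "
                    "Please install it with `pip install heavy_lib`."
                )
            self.client = heavy_lib.Client()  # hypothetical API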
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -2,6 +2,9 @@
 
 ## Next
 
+### Fixed
+- Fix a bug where `openai` Python client and `numpy` were required to import any embedder or LLM.
+
 ## 1.0.0a1
 
 ## 1.0.0a0
34 changes: 20 additions & 14 deletions src/neo4j_graphrag/embeddings/openai.py
@@ -15,7 +15,8 @@
 
 from __future__ import annotations
 
-from typing import Any, Type
+import abc
+from typing import Any
 
 from neo4j_graphrag.embeddings.base import Embedder
 
@@ -25,26 +26,29 @@
     openai = None  # type: ignore
 
 
-class OpenAIEmbeddings(Embedder):
+class BaseOpenAIEmbeddings(Embedder, abc.ABC):
+    def __init__(self, model: str = "text-embedding-ada-002", **kwargs: Any) -> None:
+        if openai is None:
+            raise ImportError(
+                "Could not import openai python client. "
+                "Please install it with `pip install openai`."
+            )
+        self.model = model
+
+
+class OpenAIEmbeddings(BaseOpenAIEmbeddings):
     """
     OpenAI embeddings class.
     This class uses the OpenAI python client to generate embeddings for text data.
     Args:
         model (str): The name of the OpenAI embedding model to use. Defaults to "text-embedding-ada-002".
         kwargs: All other parameters will be passed to the openai.OpenAI init.
     """
 
-    client_class: Type[openai.OpenAI] = openai.OpenAI
-
     def __init__(self, model: str = "text-embedding-ada-002", **kwargs: Any) -> None:
-        if openai is None:
-            raise ImportError(
-                "Could not import openai python client. "
-                "Please install it with `pip install openai`."
-            )
-
-        self.openai_model = self.client_class(**kwargs)
-        self.model = model
+        super().__init__(model, **kwargs)
+        self.openai_client = openai.OpenAI(**kwargs)
 
     def embed_query(self, text: str, **kwargs: Any) -> list[float]:
         """
@@ -54,11 +58,13 @@ def embed_query(self, text: str, **kwargs: Any) -> list[float]:
             text (str): The text to generate an embedding for.
             **kwargs (Any): Additional arguments to pass to the OpenAI embedding generation function.
         """
-        response = self.openai_model.embeddings.create(
+        response = self.openai_client.embeddings.create(
             input=text, model=self.model, **kwargs
         )
         return response.data[0].embedding
 
 
 class AzureOpenAIEmbeddings(OpenAIEmbeddings):
-    client_class: Type[openai.OpenAI] = openai.AzureOpenAI
+    def __init__(self, model: str = "text-embedding-ada-002", **kwargs: Any) -> None:
+        super().__init__(model, **kwargs)
+        self.openai_client = openai.AzureOpenAI(**kwargs)
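
With the guard in place, importing the embedder module no longer pulls in `openai`; the dependency is only checked when an embedder is constructed. Roughly, in an environment without `openai` installed:

    # Succeeds even when the `openai` package is missing:
    from neo4j_graphrag.embeddings.openai import OpenAIEmbeddings

    try:
        OpenAIEmbeddings(api_key="...")  # the check now happens here
    except ImportError as e:
        print(e)  # "Could not import openai python client. ..."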
18 changes: 10 additions & 8 deletions src/neo4j_graphrag/embeddings/sentence_transformers.py
@@ -15,8 +15,13 @@
 
 from typing import Any
 
-import numpy as np
-import torch
+try:
+    import numpy as np
+    import sentence_transformers
+    import torch
+except ImportError:
+    sentence_transformers = None  # type: ignore
+
 
 from neo4j_graphrag.embeddings.base import Embedder
 
@@ -25,15 +30,12 @@ class SentenceTransformerEmbeddings(Embedder):
     def __init__(
         self, model: str = "all-MiniLM-L6-v2", *args: Any, **kwargs: Any
     ) -> None:
-        try:
-            from sentence_transformers import SentenceTransformer
-        except ImportError as e:
+        if sentence_transformers is None:
             raise ImportError(
                 "Could not import sentence_transformers python package. "
                 "Please install it with `pip install sentence-transformers`."
-            ) from e
-
-        self.model = SentenceTransformer(model, *args, **kwargs)
+            )
+        self.model = sentence_transformers.SentenceTransformer(model, *args, **kwargs)
 
     def embed_query(self, text: str) -> Any:
         result = self.model.encode([text])
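
The same deferred check applies here: the module now imports cleanly without the extra, and the constructor raises if `sentence_transformers` is missing. Typical usage once `pip install sentence-transformers` has been run (a sketch; all-MiniLM-L6-v2 produces 384-dimensional vectors):

    from neo4j_graphrag.embeddings.sentence_transformers import (
        SentenceTransformerEmbeddings,
    )

    embedder = SentenceTransformerEmbeddings(model="all-MiniLM-L6-v2")
    vector = embedder.embed_query("graph retrieval augmented generation")
    print(len(vector))  # 384 for all-MiniLM-L6-v2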
65 changes: 50 additions & 15 deletions src/neo4j_graphrag/llm/openai_llm.py
@@ -14,7 +14,8 @@
 # limitations under the License.
 from __future__ import annotations
 
-from typing import Any, Optional, Type
+import abc
+from typing import Any, Optional
 
 from ..exceptions import LLMGenerationError
 from .base import LLMInterface
@@ -26,32 +27,30 @@
     openai = None  # type: ignore
 
 
-class OpenAILLM(LLMInterface):
-    client_class: Type[openai.OpenAI] = openai.OpenAI
-    async_client_class: Type[openai.AsyncOpenAI] = openai.AsyncOpenAI
+class BaseOpenAILLM(LLMInterface, abc.ABC):
+    client: Any
+    async_client: Any
 
     def __init__(
         self,
         model_name: str,
         model_params: Optional[dict[str, Any]] = None,
         **kwargs: Any,
     ):
         """
-        OpenAI LLM
-        Wrapper for the openai Python client LLM.
+        Base class for OpenAI LLM.
+        Makes sure the openai Python client is installed during init.
         Args:
             model_name (str):
             model_params (str): Parameters like temperature that will be passed to the model when text is sent to it
             kwargs: All other parameters will be passed to the openai.OpenAI init.
         """
         if openai is None:
             raise ImportError(
                 "Could not import openai Python client. "
                 "Please install it with `pip install openai`."
             )
-        super().__init__(model_name, model_params, **kwargs)
-        self.client = self.client_class(**kwargs)
-        self.async_client = self.async_client_class(**kwargs)
+        super().__init__(model_name, model_params)
 
     def get_messages(
         self,
@@ -76,7 +75,7 @@ def invoke(self, input: str) -> LLMResponse:
         """
         try:
             response = self.client.chat.completions.create(
-                messages=self.get_messages(input),  # type: ignore
+                messages=self.get_messages(input),
                 model=self.model_name,
                 **self.model_params,
             )
@@ -100,7 +99,7 @@ async def ainvoke(self, input: str) -> LLMResponse:
         """
         try:
             response = await self.async_client.chat.completions.create(
-                messages=self.get_messages(input),  # type: ignore
+                messages=self.get_messages(input),
                 model=self.model_name,
                 **self.model_params,
             )
@@ -110,6 +109,42 @@ async def ainvoke(self, input: str) -> LLMResponse:
             raise LLMGenerationError(e)
 
 
-class AzureOpenAILLM(OpenAILLM):
-    client_class: Type[openai.OpenAI] = openai.AzureOpenAI
-    async_client_class: Type[openai.AsyncOpenAI] = openai.AsyncAzureOpenAI
+class OpenAILLM(BaseOpenAILLM):
+    def __init__(
+        self,
+        model_name: str,
+        model_params: Optional[dict[str, Any]] = None,
+        **kwargs: Any,
+    ):
+        """OpenAI LLM
+        Wrapper for the openai Python client LLM.
+        Args:
+            model_name (str):
+            model_params (str): Parameters like temperature that will be passed to the model when text is sent to it
+            kwargs: All other parameters will be passed to the openai.OpenAI init.
+        """
+        super().__init__(model_name, model_params)
+        self.client = openai.OpenAI(**kwargs)
+        self.async_client = openai.AsyncOpenAI(**kwargs)
+
+
+class AzureOpenAILLM(BaseOpenAILLM):
+    def __init__(
+        self,
+        model_name: str,
+        model_params: Optional[dict[str, Any]] = None,
+        **kwargs: Any,
+    ):
+        """Azure OpenAI LLM. Use this class when using an OpenAI model
+        hosted on Microsoft Azure.
+        Args:
+            model_name (str):
+            model_params (str): Parameters like temperature that will be passed to the model when text is sent to it
+            kwargs: All other parameters will be passed to the openai.OpenAI init.
+        """
+        super().__init__(model_name, model_params)
+        self.client = openai.AzureOpenAI(**kwargs)
+        self.async_client = openai.AsyncAzureOpenAI(**kwargs)
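
Since `BaseOpenAILLM` now owns the install check while `invoke`/`ainvoke` rely only on the `client`/`async_client` attributes, wiring up another OpenAI-compatible backend is just a matter of assigning those attributes in a subclass. A hypothetical example (not part of this commit) pointing at a local OpenAI-compatible server via the client's standard `base_url` parameter:

    from typing import Any, Optional

    import openai

    from neo4j_graphrag.llm.openai_llm import BaseOpenAILLM


    class LocalOpenAILLM(BaseOpenAILLM):
        """Hypothetical subclass for an OpenAI-compatible local server."""

        def __init__(
            self,
            model_name: str,
            model_params: Optional[dict[str, Any]] = None,
            base_url: str = "http://localhost:8000/v1",
            **kwargs: Any,
        ):
            # Base class validates that the openai package is importable.
            super().__init__(model_name, model_params)
            self.client = openai.OpenAI(base_url=base_url, **kwargs)
            self.async_client = openai.AsyncOpenAI(base_url=base_url, **kwargs)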
8 changes: 4 additions & 4 deletions tests/unit/embeddings/test_openai_embedder.py
@@ -27,9 +27,9 @@ def test_openai_embedder_missing_dependency() -> None:
         OpenAIEmbeddings()
 
 
-@patch("neo4j_graphrag.embeddings.openai.OpenAIEmbeddings.client_class")
+@patch("neo4j_graphrag.embeddings.openai.openai")
 def test_openai_embedder_happy_path(mock_openai: Mock) -> None:
-    mock_openai.return_value.embeddings.create.return_value = MagicMock(
+    mock_openai.OpenAI.return_value.embeddings.create.return_value = MagicMock(
         data=[MagicMock(embedding=[1.0, 2.0])],
     )
     embedder = OpenAIEmbeddings(api_key="my key")
@@ -44,9 +44,9 @@ def test_azure_openai_embedder_missing_dependency() -> None:
         AzureOpenAIEmbeddings()
 
 
-@patch("neo4j_graphrag.embeddings.openai.AzureOpenAIEmbeddings.client_class")
+@patch("neo4j_graphrag.embeddings.openai.openai")
 def test_azure_openai_embedder_happy_path(mock_openai: Mock) -> None:
-    mock_openai.return_value.embeddings.create.return_value = MagicMock(
+    mock_openai.AzureOpenAI.return_value.embeddings.create.return_value = MagicMock(
         data=[MagicMock(embedding=[1.0, 2.0])],
     )
     embedder = AzureOpenAIEmbeddings(
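
The patch target moves because `client_class` no longer exists: the tests now replace the `openai` module object where the embedder looks it up, then stub the specific constructor (`OpenAI` or `AzureOpenAI`) that the class calls. The same pattern, condensed:

    from unittest.mock import MagicMock, patch

    from neo4j_graphrag.embeddings.openai import OpenAIEmbeddings

    # Patch `openai` where it is referenced, not where it is defined.
    with patch("neo4j_graphrag.embeddings.openai.openai") as mock_openai:
        mock_openai.OpenAI.return_value.embeddings.create.return_value = MagicMock(
            data=[MagicMock(embedding=[1.0, 2.0])],
        )
        assert OpenAIEmbeddings(api_key="my key").embed_query("text") == [1.0, 2.0]

The LLM tests below make the same switch.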
21 changes: 12 additions & 9 deletions tests/unit/embeddings/test_sentence_transformers.py
@@ -8,34 +8,37 @@
 )
 
 
-@patch("sentence_transformers.SentenceTransformer")
+@patch("neo4j_graphrag.embeddings.sentence_transformers.sentence_transformers")
 def test_initialization(MockSentenceTransformer: MagicMock) -> None:
     instance = SentenceTransformerEmbeddings()
-    MockSentenceTransformer.assert_called_with("all-MiniLM-L6-v2")
+    MockSentenceTransformer.SentenceTransformer.assert_called_with("all-MiniLM-L6-v2")
     assert isinstance(instance, Embedder)
 
 
-@patch("sentence_transformers.SentenceTransformer")
+@patch("neo4j_graphrag.embeddings.sentence_transformers.sentence_transformers")
 def test_initialization_with_custom_model(MockSentenceTransformer: MagicMock) -> None:
     custom_model = "distilbert-base-nli-stsb-mean-tokens"
     SentenceTransformerEmbeddings(model=custom_model)
-    MockSentenceTransformer.assert_called_with(custom_model)
+    MockSentenceTransformer.SentenceTransformer.assert_called_with(custom_model)
 
 
-@patch("sentence_transformers.SentenceTransformer")
+@patch("neo4j_graphrag.embeddings.sentence_transformers.sentence_transformers")
 def test_embed_query(MockSentenceTransformer: MagicMock) -> None:
-    mock_model = MockSentenceTransformer.return_value
+    mock_model = MockSentenceTransformer.SentenceTransformer.return_value
     mock_model.encode.return_value = np.array([[0.1, 0.2, 0.3]])
 
     instance = SentenceTransformerEmbeddings()
     result = instance.embed_query("test query")
 
     mock_model.encode.assert_called_with(["test query"])
-    assert result == [0.1, 0.2, 0.3]
     assert isinstance(result, list)
+    assert result == [0.1, 0.2, 0.3]
 
 
-@patch("sentence_transformers.SentenceTransformer", side_effect=ImportError)
-def test_import_error(MockSentenceTransformer: MagicMock) -> None:
+@patch(
+    "neo4j_graphrag.embeddings.sentence_transformers.sentence_transformers",
+    None,
+)
+def test_import_error() -> None:
     with pytest.raises(ImportError):
         SentenceTransformerEmbeddings()
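
The signature change in `test_import_error` follows from `unittest.mock.patch` semantics: in the two-argument form `@patch(target, new)`, the supplied object (here `None`) is used as the replacement and no `MagicMock` is injected into the decorated function, so the test no longer takes a parameter. Patching the module attribute to `None` reproduces exactly the state the sentinel is in after a failed import, which is what the constructor's guard checks.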
12 changes: 7 additions & 5 deletions tests/unit/llm/test_openai_llm.py
@@ -25,9 +25,9 @@ def test_openai_llm_missing_dependency() -> None:
         OpenAILLM(model_name="gpt-4o")
 
 
-@patch("neo4j_graphrag.llm.openai_llm.OpenAILLM.client_class")
+@patch("neo4j_graphrag.llm.openai_llm.openai")
 def test_openai_llm_happy_path(mock_openai: Mock) -> None:
-    mock_openai.return_value.chat.completions.create.return_value = MagicMock(
+    mock_openai.OpenAI.return_value.chat.completions.create.return_value = MagicMock(
         choices=[MagicMock(message=MagicMock(content="openai chat response"))],
     )
     llm = OpenAILLM(api_key="my key", model_name="gpt")
@@ -42,10 +42,12 @@ def test_azure_openai_llm_missing_dependency() -> None:
         AzureOpenAILLM(model_name="gpt-4o")
 
 
-@patch("neo4j_graphrag.llm.openai_llm.AzureOpenAILLM.client_class")
+@patch("neo4j_graphrag.llm.openai_llm.openai")
 def test_azure_openai_llm_happy_path(mock_openai: Mock) -> None:
-    mock_openai.return_value.chat.completions.create.return_value = MagicMock(
-        choices=[MagicMock(message=MagicMock(content="openai chat response"))],
+    mock_openai.AzureOpenAI.return_value.chat.completions.create.return_value = (
+        MagicMock(
+            choices=[MagicMock(message=MagicMock(content="openai chat response"))],
+        )
     )
     llm = AzureOpenAILLM(
         model_name="gpt",
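
For reference, the call path these happy-path tests stub out looks like this at runtime (a sketch; it requires `openai` installed and valid credentials, and assumes `LLMResponse` exposes the generated text as `.content`):

    from neo4j_graphrag.llm.openai_llm import OpenAILLM

    llm = OpenAILLM(model_name="gpt-4o", model_params={"temperature": 0})
    response = llm.invoke("Say hello")  # wraps client.chat.completions.create
    print(response.content)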
