From 140a057a7bb6da533254f47120b014f46204f95d Mon Sep 17 00:00:00 2001 From: Estelle Scifo Date: Thu, 12 Dec 2024 15:25:19 +0100 Subject: [PATCH] Add OllamaLLM and OllamaEmbeddings classes (#231) * Add OllamaLLM and OllamaEmbeddings classes using the ollama python client * Try removing import * :( * Add tests + reformat import in ollama embeddings for consistency with all other imports * Fix after merge --- CHANGELOG.md | 1 + README.md | 1 + docs/source/api.rst | 12 +++ docs/source/index.rst | 1 + docs/source/user_guide_rag.rst | 34 ++------- .../customize/embeddings/ollama_embeddings.py | 11 +-- examples/customize/llms/ollama_llm.py | 11 ++- poetry.lock | 55 +++++++++----- pyproject.toml | 2 + src/neo4j_graphrag/embeddings/__init__.py | 2 + src/neo4j_graphrag/embeddings/mistral.py | 3 +- src/neo4j_graphrag/embeddings/ollama.py | 65 ++++++++++++++++ src/neo4j_graphrag/llm/__init__.py | 2 + src/neo4j_graphrag/llm/ollama_llm.py | 76 +++++++++++++++++++ tests/unit/embeddings/test_ollama_embedder.py | 35 +++++++++ tests/unit/llm/test_ollama_llm.py | 46 +++++++++++ 16 files changed, 293 insertions(+), 64 deletions(-) create mode 100644 src/neo4j_graphrag/embeddings/ollama.py create mode 100644 src/neo4j_graphrag/llm/ollama_llm.py create mode 100644 tests/unit/embeddings/test_ollama_embedder.py create mode 100644 tests/unit/llm/test_ollama_llm.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 2ac264f0..45ebbeb2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,7 @@ - Integrated json-repair package to handle and repair invalid JSON generated by LLMs. - Introduced InvalidJSONError exception for handling cases where JSON repair fails. - Ability to create a Pipeline or SimpleKGPipeline from a config file. See [the example](examples/build_graph/from_config_files/simple_kg_pipeline_from_config_file.py). +- Added `OllamaLLM` and `OllamaEmbeddings` classes to make Ollama support more explicit. Implementations using the `OpenAILLM` and `OpenAIEmbeddings` classes will still work. ## Changed - Updated LLM prompts to include stricter instructions for generating valid JSON. diff --git a/README.md b/README.md index 9fb37a0e..74a2d6dd 100644 --- a/README.md +++ b/README.md @@ -31,6 +31,7 @@ This package has some optional features that can be enabled using the extra dependencies described below: - LLM providers (at least one is required for RAG and KG Builder Pipeline): + - **ollama**: LLMs from Ollama - **openai**: LLMs from OpenAI (including AzureOpenAI) - **google**: LLMs from Vertex AI - **cohere**: LLMs from Cohere diff --git a/docs/source/api.rst b/docs/source/api.rst index 7d46b3f0..8e52f9a4 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -239,6 +239,12 @@ AzureOpenAIEmbeddings .. autoclass:: neo4j_graphrag.embeddings.openai.AzureOpenAIEmbeddings :members: +OllamaEmbeddings +================ + +.. autoclass:: neo4j_graphrag.embeddings.ollama.OllamaEmbeddings + :members: + VertexAIEmbeddings ================== @@ -286,6 +292,12 @@ AzureOpenAILLM :members: :undoc-members: get_messages, client_class, async_client_class +OllamaLLM +--------- + +.. 
autoclass:: neo4j_graphrag.llm.ollama_llm.OllamaLLM
+    :members:
+
 VertexAILLM
 -----------
diff --git a/docs/source/index.rst b/docs/source/index.rst
index 9d6f9a3d..f48bf281 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -88,6 +88,7 @@ Extra dependencies can be installed with:
 List of extra dependencies:
 
 - LLM providers (at least one is required for RAG and KG Builder Pipeline):
+    - **ollama**: LLMs from Ollama
     - **openai**: LLMs from OpenAI (including AzureOpenAI)
     - **google**: LLMs from Vertex AI
     - **cohere**: LLMs from Cohere
diff --git a/docs/source/user_guide_rag.rst b/docs/source/user_guide_rag.rst
index 2b2b528b..8074eef7 100644
--- a/docs/source/user_guide_rag.rst
+++ b/docs/source/user_guide_rag.rst
@@ -218,14 +218,13 @@
 See :ref:`coherellm`.
 
 Using a Local Model via Ollama
 -------------------------------
 
-Similarly to the official OpenAI Python client, the `OpenAILLM` can be
-used with Ollama. Assuming Ollama is running on the default address `127.0.0.1:11434`,
+Assuming Ollama is running on the default address `127.0.0.1:11434`,
 it can be queried using the following:
 
 .. code:: python
 
-    from neo4j_graphrag.llm import OpenAILLM
-    llm = OpenAILLM(api_key="ollama", base_url="http://127.0.0.1:11434/v1", model_name="orca-mini")
+    from neo4j_graphrag.llm import OllamaLLM
+    llm = OllamaLLM(model_name="orca-mini")
     llm.invoke("say something")
@@ -428,6 +427,7 @@ Currently, this package supports the following embedders:
 - :ref:`mistralaiembeddings`
 - :ref:`cohereembeddings`
 - :ref:`azureopenaiembeddings`
+- :ref:`ollamaembeddings`
 
 The `OpenAIEmbeddings` was illustrated previously. Here is how to use the `SentenceTransformerEmbeddings`:
@@ -438,31 +438,7 @@ The `OpenAIEmbeddings` was illustrated previously. Here is how to use the `Sente
 
     embedder = SentenceTransformerEmbeddings(model="all-MiniLM-L6-v2")  # Note: this is the default model
 
-If another embedder is desired, a custom embedder can be created. For example, consider
-the following implementation of an embedder that wraps the `OllamaEmbedding` model from LlamaIndex:
-
-.. code:: python
-
-    from llama_index.embeddings.ollama import OllamaEmbedding
-    from neo4j_graphrag.embeddings.base import Embedder
-
-    class OllamaEmbedder(Embedder):
-        def __init__(self, ollama_embedding):
-            self.embedder = ollama_embedding
-
-        def embed_query(self, text: str) -> list[float]:
-            embedding = self.embedder.get_text_embedding_batch(
-                [text], show_progress=True
-            )
-            return embedding[0]
-
-    ollama_embedding = OllamaEmbedding(
-        model_name="llama3",
-        base_url="http://localhost:11434",
-        ollama_additional_kwargs={"mirostat": 0},
-    )
-    embedder = OllamaEmbedder(ollama_embedding)
-    vector = embedder.embed_query("some text")
+If another embedder is desired, a custom embedder can be created using the `Embedder` interface.
 
 
 Other Vector Retriever Configuration
diff --git a/examples/customize/embeddings/ollama_embeddings.py b/examples/customize/embeddings/ollama_embeddings.py
index 2a5a0046..243cf4b6 100644
--- a/examples/customize/embeddings/ollama_embeddings.py
+++ b/examples/customize/embeddings/ollama_embeddings.py
@@ -1,15 +1,10 @@
 """This example demonstrate how to embed a text into a vector
-using OpenAI models and API.
+using a local model served by Ollama.
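+
+A local Ollama server must be running (by default on localhost:11434) and
+the chosen model must already be pulled, e.g. with `ollama pull <model_name>`.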
""" -from neo4j_graphrag.embeddings import OpenAIEmbeddings +from neo4j_graphrag.embeddings import OllamaEmbeddings -# not used but needs to be provided -api_key = "ollama" - -embeder = OpenAIEmbeddings( - base_url="http://localhost:11434/v1", - api_key=api_key, +embeder = OllamaEmbeddings( model="", ) res = embeder.embed_query("my question") diff --git a/examples/customize/llms/ollama_llm.py b/examples/customize/llms/ollama_llm.py index 56148042..f815d4db 100644 --- a/examples/customize/llms/ollama_llm.py +++ b/examples/customize/llms/ollama_llm.py @@ -1,12 +1,11 @@ -from neo4j_graphrag.llm import LLMResponse, OpenAILLM +"""This example demonstrate how to invoke an LLM using a local model +served by Ollama. +""" -# not used but needs to be provided -api_key = "ollama" +from neo4j_graphrag.llm import LLMResponse, OllamaLLM -llm = OpenAILLM( - base_url="http://localhost:11434/v1", +llm = OllamaLLM( model_name="", - api_key=api_key, ) res: LLMResponse = llm.invoke("What is the additive color model?") print(res.content) diff --git a/poetry.lock b/poetry.lock index a662e717..21f8bf66 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1026,13 +1026,13 @@ grpcio-gcp = ["grpcio-gcp (>=0.2.2,<1.0.dev0)"] [[package]] name = "google-auth" -version = "2.36.0" +version = "2.37.0" description = "Google Authentication Library" optional = true python-versions = ">=3.7" files = [ - {file = "google_auth-2.36.0-py2.py3-none-any.whl", hash = "sha256:51a15d47028b66fd36e5c64a82d2d57480075bccc7da37cde257fc94177a61fb"}, - {file = "google_auth-2.36.0.tar.gz", hash = "sha256:545e9618f2df0bcbb7dcbc45a546485b1212624716975a1ea5ae8149ce769ab1"}, + {file = "google_auth-2.37.0-py2.py3-none-any.whl", hash = "sha256:42664f18290a6be591be5329a96fe30184be1a1badb7292a7f686a9659de9ca0"}, + {file = "google_auth-2.37.0.tar.gz", hash = "sha256:0054623abf1f9c83492c63d3f47e77f0a544caa3d40b2d98e099a611c2dd5d00"}, ] [package.dependencies] @@ -1043,6 +1043,7 @@ rsa = ">=3.1.4,<5" [package.extras] aiohttp = ["aiohttp (>=3.6.2,<4.0.0.dev0)", "requests (>=2.20.0,<3.0.0.dev0)"] enterprise-cert = ["cryptography", "pyopenssl"] +pyjwt = ["cryptography (>=38.0.3)", "pyjwt (>=2.0)"] pyopenssl = ["cryptography (>=38.0.3)", "pyopenssl (>=20.0.0)"] reauth = ["pyu2f (>=0.1.5)"] requests = ["requests (>=2.20.0,<3.0.0.dev0)"] @@ -1591,13 +1592,13 @@ trio = ["trio (>=0.22.0,<1.0)"] [[package]] name = "httpx" -version = "0.27.0" +version = "0.27.2" description = "The next generation HTTP client." optional = true python-versions = ">=3.8" files = [ - {file = "httpx-0.27.0-py3-none-any.whl", hash = "sha256:71d5465162c13681bff01ad59b2cc68dd838ea1f10e51574bac27103f00c91a5"}, - {file = "httpx-0.27.0.tar.gz", hash = "sha256:a0cb88a46f32dc874e04ee956e4c2764aba2aa228f650b06788ba6bda2962ab5"}, + {file = "httpx-0.27.2-py3-none-any.whl", hash = "sha256:7bb2708e112d8fdd7829cd4243970f0c223274051cb35ee80c03301ee29a3df0"}, + {file = "httpx-0.27.2.tar.gz", hash = "sha256:f7c2be1d2f3c3c3160d441802406b206c2b76f5947b11115e6df10c6c65e66c2"}, ] [package.dependencies] @@ -1613,6 +1614,7 @@ brotli = ["brotli", "brotlicffi"] cli = ["click (==8.*)", "pygments (==2.*)", "rich (>=10,<14)"] http2 = ["h2 (>=3,<5)"] socks = ["socksio (==1.*)"] +zstd = ["zstandard (>=0.18.0)"] [[package]] name = "httpx-sse" @@ -1976,13 +1978,13 @@ langchain-core = ">=0.3.15,<0.4.0" [[package]] name = "langsmith" -version = "0.2.2" +version = "0.2.3" description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform." 
optional = true python-versions = "<4.0,>=3.9" files = [ - {file = "langsmith-0.2.2-py3-none-any.whl", hash = "sha256:4786d7dcdbc25e43d4a1bf70bbe12938a9eb2364feec8f6fc4d967162519b367"}, - {file = "langsmith-0.2.2.tar.gz", hash = "sha256:6f515ee41ae80968a7d552be1154414ccde57a0a534c960c8c3cd1835734095f"}, + {file = "langsmith-0.2.3-py3-none-any.whl", hash = "sha256:4958b6e918f57fedba6ddc55b8534d1e06478bb44c779aa73713ce898ca6ae87"}, + {file = "langsmith-0.2.3.tar.gz", hash = "sha256:54c231b07fdeb0f8472925074a0ec0ed2cb654a0437d63c6ccf76a9da635900d"}, ] [package.dependencies] @@ -2864,6 +2866,21 @@ files = [ {file = "nvidia_nvtx_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:641dccaaa1139f3ffb0d3164b4b84f9d253397e38246a4f2f36728b48566d485"}, ] +[[package]] +name = "ollama" +version = "0.4.4" +description = "The official Python client for Ollama." +optional = true +python-versions = "<4.0,>=3.8" +files = [ + {file = "ollama-0.4.4-py3-none-any.whl", hash = "sha256:0f466e845e2205a1cbf5a2fef4640027b90beaa3b06c574426d8b6b17fd6e139"}, + {file = "ollama-0.4.4.tar.gz", hash = "sha256:e1db064273c739babc2dde9ea84029c4a43415354741b6c50939ddd3dd0f7ffb"}, +] + +[package.dependencies] +httpx = ">=0.27.0,<0.28.0" +pydantic = ">=2.9.0,<3.0.0" + [[package]] name = "openai" version = "1.57.2" @@ -5053,23 +5070,22 @@ test = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "coverage-enable-subprocess [[package]] name = "weaviate-client" -version = "4.9.6" +version = "4.10.1" description = "A python native Weaviate client" optional = true python-versions = ">=3.9" files = [ - {file = "weaviate_client-4.9.6-py3-none-any.whl", hash = "sha256:1d3b551939c0f7314f25e417cbcf4cf34e7adf942627993eef36ae6b4a044673"}, - {file = "weaviate_client-4.9.6.tar.gz", hash = "sha256:56d67c40fc94b0d53e81e0aa4477baaebbf3646fbec26551df66e396a72adcb6"}, + {file = "weaviate_client-4.10.1-py3-none-any.whl", hash = "sha256:eedd693b2bbd02279109f9a37d2c33492bb460e480f4c5a4297530d4ad78469e"}, + {file = "weaviate_client-4.10.1.tar.gz", hash = "sha256:7a1513d0ce7d918100d255224da52fcfe3d0834a2d22069853ab95a2a17f102c"}, ] [package.dependencies] authlib = ">=1.2.1,<1.3.2" -grpcio = ">=1.57.0,<2.0.0" -grpcio-health-checking = ">=1.57.0,<2.0.0" -grpcio-tools = ">=1.57.0,<2.0.0" -httpx = ">=0.25.0,<=0.27.0" -pydantic = ">=2.5.0,<3.0.0" -requests = ">=2.30.0,<3.0.0" +grpcio = ">=1.66.2,<2.0.0" +grpcio-health-checking = ">=1.66.2,<2.0.0" +grpcio-tools = ">=1.66.2,<2.0.0" +httpx = ">=0.26.0,<0.29.0" +pydantic = ">=2.8.0,<3.0.0" validators = "0.34.0" [[package]] @@ -5269,6 +5285,7 @@ experimental = ["fsspec", "langchain-text-splitters", "llama-index", "pygraphviz google = ["google-cloud-aiplatform"] kg-creation-tools = ["pygraphviz", "pygraphviz"] mistralai = ["mistralai"] +ollama = ["ollama"] openai = ["openai"] pinecone = ["pinecone-client"] qdrant = ["qdrant-client"] @@ -5278,4 +5295,4 @@ weaviate = ["weaviate-client"] [metadata] lock-version = "2.0" python-versions = "^3.9.0" -content-hash = "5f729b5f7f31021258d04fcf26e2310f685f1b97113e888ba346df5c7393d4e4" +content-hash = "c2ab4d7e2b6f8dfa668cdab458068d085425115cbb08c201d3eb3be32c47325b" diff --git a/pyproject.toml b/pyproject.toml index 604fde99..0cba753d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,6 +50,7 @@ anthropic = { version = "^0.36.0", optional = true} sentence-transformers = {version = "^3.0.0", optional = true } json-repair = "^0.30.2" types-pyyaml = "^6.0.12.20240917" +ollama = {version = "^0.4.4", optional = true} [tool.poetry.group.dev.dependencies] urllib3 = "<2" @@ -69,6 
+70,7 @@ pinecone = ["pinecone-client"]
 google = ["google-cloud-aiplatform"]
 cohere = ["cohere"]
 anthropic = ["anthropic"]
+ollama = ["ollama"]
 openai = ["openai"]
 mistralai = ["mistralai"]
 qdrant = ["qdrant-client"]
diff --git a/src/neo4j_graphrag/embeddings/__init__.py b/src/neo4j_graphrag/embeddings/__init__.py
index 6c30951d..9398eefe 100644
--- a/src/neo4j_graphrag/embeddings/__init__.py
+++ b/src/neo4j_graphrag/embeddings/__init__.py
@@ -15,6 +15,7 @@
 from .base import Embedder
 from .cohere import CohereEmbeddings
 from .mistral import MistralAIEmbeddings
+from .ollama import OllamaEmbeddings
 from .openai import AzureOpenAIEmbeddings, OpenAIEmbeddings
 from .sentence_transformers import SentenceTransformerEmbeddings
 from .vertexai import VertexAIEmbeddings
@@ -22,6 +23,7 @@
 __all__ = [
     "Embedder",
     "SentenceTransformerEmbeddings",
+    "OllamaEmbeddings",
     "OpenAIEmbeddings",
     "AzureOpenAIEmbeddings",
     "VertexAIEmbeddings",
diff --git a/src/neo4j_graphrag/embeddings/mistral.py b/src/neo4j_graphrag/embeddings/mistral.py
index 8c352263..09943019 100644
--- a/src/neo4j_graphrag/embeddings/mistral.py
+++ b/src/neo4j_graphrag/embeddings/mistral.py
@@ -57,8 +57,7 @@ def embed_query(self, text: str, **kwargs: Any) -> list[float]:
             **kwargs (Any): Additional keyword arguments to pass to the Mistral AI client.
         """
         embeddings_batch_response = self.mistral_client.embeddings.create(
-            model=self.model,
-            inputs=[text],
+            model=self.model, inputs=[text], **kwargs
         )
         if embeddings_batch_response is None or not embeddings_batch_response.data:
             raise EmbeddingsGenerationError("Failed to retrieve embeddings.")
diff --git a/src/neo4j_graphrag/embeddings/ollama.py b/src/neo4j_graphrag/embeddings/ollama.py
new file mode 100644
index 00000000..a3d804a0
--- /dev/null
+++ b/src/neo4j_graphrag/embeddings/ollama.py
@@ -0,0 +1,65 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [https://neo4j.com]
+# #
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# #
+#     https://www.apache.org/licenses/LICENSE-2.0
+# #
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from typing import Any
+
+from neo4j_graphrag.embeddings.base import Embedder
+from neo4j_graphrag.exceptions import EmbeddingsGenerationError
+
+
+class OllamaEmbeddings(Embedder):
+    """
+    Ollama embeddings class.
+    This class uses the ollama Python client to generate vector embeddings for text data.
+
+    Args:
+        model (str): The name of the Ollama embedding model to use.
+        **kwargs (Any): Additional keyword arguments passed to the underlying
+            ollama.Client (for instance, a custom host).
+    """
+
+    def __init__(self, model: str, **kwargs: Any) -> None:
+        try:
+            import ollama
+        except ImportError:
+            raise ImportError(
+                "Could not import ollama Python client. "
+                "Please install it with `pip install ollama`."
+            )
+        self.model = model
+        self.client = ollama.Client(**kwargs)
+
+    def embed_query(self, text: str, **kwargs: Any) -> list[float]:
+        """
+        Generate embeddings for a given query using an Ollama text embedding model.
+
+        Args:
+            text (str): The text to generate an embedding for.
+            **kwargs (Any): Additional keyword arguments to pass to the Ollama client.
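+
+        Returns:
+            list[float]: The embedding vector for the given text.
+
+        Example (the model name is only illustrative; any embedding model
+        already pulled into the local Ollama instance works):
+
+        .. code-block:: python
+
+            embedder = OllamaEmbeddings(model="mxbai-embed-large")
+            vector = embedder.embed_query("my question")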
+        """
+        embeddings_response = self.client.embed(
+            model=self.model,
+            input=text,
+            **kwargs,
+        )
+
+        if embeddings_response is None or embeddings_response.embeddings is None:
+            raise EmbeddingsGenerationError("Failed to retrieve embeddings.")
+
+        # the embed endpoint returns one embedding per input text: unwrap the
+        # first (and only) vector for the single input above
+        embeddings = embeddings_response.embeddings
+        if not isinstance(embeddings, (list, tuple)) or not embeddings:
+            raise EmbeddingsGenerationError("Embedding is not a list of floats.")
+
+        return list(embeddings[0])
diff --git a/src/neo4j_graphrag/llm/__init__.py b/src/neo4j_graphrag/llm/__init__.py
index a8fbdd61..a9ece5cc 100644
--- a/src/neo4j_graphrag/llm/__init__.py
+++ b/src/neo4j_graphrag/llm/__init__.py
@@ -16,6 +16,7 @@
 from .base import LLMInterface
 from .cohere_llm import CohereLLM
 from .mistralai_llm import MistralAILLM
+from .ollama_llm import OllamaLLM
 from .openai_llm import AzureOpenAILLM, OpenAILLM
 from .types import LLMResponse
 from .vertexai_llm import VertexAILLM
@@ -25,6 +26,7 @@
     "CohereLLM",
     "LLMResponse",
     "LLMInterface",
+    "OllamaLLM",
     "OpenAILLM",
     "VertexAILLM",
     "AzureOpenAILLM",
diff --git a/src/neo4j_graphrag/llm/ollama_llm.py b/src/neo4j_graphrag/llm/ollama_llm.py
new file mode 100644
index 00000000..d82ccd25
--- /dev/null
+++ b/src/neo4j_graphrag/llm/ollama_llm.py
@@ -0,0 +1,76 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [https://neo4j.com]
+# #
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# #
+#     https://www.apache.org/licenses/LICENSE-2.0
+# #
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Any, Optional
+
+from neo4j_graphrag.exceptions import LLMGenerationError
+
+from .base import LLMInterface
+from .types import LLMResponse
+
+
+class OllamaLLM(LLMInterface):
+    """LLM interface for models served locally by Ollama, using the official
+    ollama Python client.
+
+    Args:
+        model_name (str): Name of a model available in the local Ollama instance.
+        model_params (Optional[dict[str, Any]]): Parameters for the model,
+            forwarded to the chat endpoint as Ollama options (e.g. temperature).
+        **kwargs (Any): Additional keyword arguments passed to ollama.Client
+            and ollama.AsyncClient (e.g. host).
+    """
+
+    def __init__(
+        self,
+        model_name: str,
+        model_params: Optional[dict[str, Any]] = None,
+        **kwargs: Any,
+    ):
+        try:
+            import ollama
+        except ImportError:
+            raise ImportError(
+                "Could not import ollama Python client. "
+                "Please install it with `pip install ollama`."
+            )
+        super().__init__(model_name, model_params, **kwargs)
+        self.ollama = ollama
+        self.client = ollama.Client(
+            **kwargs,
+        )
+        self.async_client = ollama.AsyncClient(
+            **kwargs,
+        )
+
+    def invoke(self, input: str) -> LLMResponse:
+        try:
+            response = self.client.chat(
+                model=self.model_name,
+                messages=[
+                    {
+                        "role": "user",
+                        "content": input,
+                    },
+                ],
+                # forward model_params as Ollama options so they reach the model
+                options=self.model_params,
+            )
+            content = response.message.content or ""
+            return LLMResponse(content=content)
+        except self.ollama.ResponseError as e:
+            raise LLMGenerationError(e)
+
+    async def ainvoke(self, input: str) -> LLMResponse:
+        try:
+            response = await self.async_client.chat(
+                model=self.model_name,
+                messages=[
+                    {
+                        "role": "user",
+                        "content": input,
+                    },
+                ],
+                options=self.model_params,
+            )
+            content = response.message.content or ""
+            return LLMResponse(content=content)
+        except self.ollama.ResponseError as e:
+            raise LLMGenerationError(e)
diff --git a/tests/unit/embeddings/test_ollama_embedder.py b/tests/unit/embeddings/test_ollama_embedder.py
new file mode 100644
index 00000000..dded2b87
--- /dev/null
+++ b/tests/unit/embeddings/test_ollama_embedder.py
@@ -0,0 +1,35 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [https://neo4j.com]
+# #
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# #
+#     https://www.apache.org/licenses/LICENSE-2.0
+# #
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from unittest.mock import MagicMock, Mock, patch
+
+import pytest
+from neo4j_graphrag.embeddings.ollama import OllamaEmbeddings
+
+
+@patch("builtins.__import__", side_effect=ImportError)
+def test_ollama_embedder_missing_dependency(mock_import: Mock) -> None:
+    with pytest.raises(ImportError):
+        OllamaEmbeddings(model="test")
+
+
+@patch("builtins.__import__")
+def test_ollama_embedder_happy_path(mock_import: Mock) -> None:
+    # the embed endpoint answers with one embedding per input text, hence the
+    # nested list in the mocked response
+    mock_import.return_value.Client.return_value.embed.return_value = MagicMock(
+        embeddings=[[1.0, 2.0]],
+    )
+    embedder = OllamaEmbeddings(model="test")
+    res = embedder.embed_query("my text")
+    assert isinstance(res, list)
+    assert res == [1.0, 2.0]
diff --git a/tests/unit/llm/test_ollama_llm.py b/tests/unit/llm/test_ollama_llm.py
new file mode 100644
index 00000000..f7be308b
--- /dev/null
+++ b/tests/unit/llm/test_ollama_llm.py
@@ -0,0 +1,46 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [https://neo4j.com]
+# #
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# #
+#     https://www.apache.org/licenses/LICENSE-2.0
+# #
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
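+
+# These tests patch builtins.__import__ so that OllamaLLM receives a mock in
+# place of the ollama package; only ResponseError is taken from the real
+# client below, so that the error-handling contract stays realistic.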
+from unittest.mock import MagicMock, Mock, patch + +import ollama +import pytest +from neo4j_graphrag.llm import LLMResponse +from neo4j_graphrag.llm.ollama_llm import OllamaLLM + + +def get_mock_ollama() -> MagicMock: + mock = MagicMock() + mock.ResponseError = ollama.ResponseError + return mock + + +@patch("builtins.__import__", side_effect=ImportError) +def test_ollama_llm_missing_dependency(mock_import: Mock) -> None: + with pytest.raises(ImportError): + OllamaLLM(model_name="gpt-4o") + + +@patch("builtins.__import__") +def test_ollama_llm_happy_path(mock_import: Mock) -> None: + mock_ollama = get_mock_ollama() + mock_import.return_value = mock_ollama + mock_ollama.Client.return_value.chat.return_value = MagicMock( + message=MagicMock(content="ollama chat response"), + ) + llm = OllamaLLM(model_name="gpt") + + res = llm.invoke("my text") + assert isinstance(res, LLMResponse) + assert res.content == "ollama chat response"
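+
+
+# A sketch of the matching async-path test. It assumes pytest-asyncio (used by
+# this repository's other async LLM tests); a plain async function stands in
+# for AsyncClient.chat so that the mocked call is awaitable.
+@pytest.mark.asyncio
+@patch("builtins.__import__")
+async def test_ollama_llm_happy_path_async(mock_import: Mock) -> None:
+    mock_ollama = get_mock_ollama()
+    mock_import.return_value = mock_ollama
+
+    async def mock_chat(*args: object, **kwargs: object) -> MagicMock:
+        return MagicMock(message=MagicMock(content="ollama chat response"))
+
+    mock_ollama.AsyncClient.return_value.chat = mock_chat
+    llm = OllamaLLM(model_name="gpt")
+
+    res = await llm.ainvoke("my text")
+    assert isinstance(res, LLMResponse)
+    assert res.content == "ollama chat response"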