Skip to content

Commit

Permalink
Added Vertex AI LLM class (#141)
Browse files Browse the repository at this point in the history
* Added Vertex AI LLM class

* Updated docstrings

* Updated unit test workflow

* Removed duplicate poetry install for pr workflow

* Removed --no-root from poetry install in pr.yaml

* Fixed typo

* Updated docs

* Updated CHANGELOG for previous PR

* Updated CHANGELOG

* Fixed typo
  • Loading branch information
alexthomas93 authored Sep 19, 2024
1 parent bf65ddd commit d1cef28
Show file tree
Hide file tree
Showing 9 changed files with 182 additions and 14 deletions.
4 changes: 1 addition & 3 deletions .github/workflows/pr.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,7 @@ jobs:
key: venv-${{ runner.os }}-${{ matrix.python-version }}-${{ hashFiles('**/poetry.lock') }}
- name: Install dependencies
if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true'
run: poetry install --no-interaction --no-root
- name: Install root project
run: poetry install --no-interaction
run: poetry install --no-interaction --extras external_clients
- name: Check format and linting
run: |
poetry run ruff check --select I .
Expand Down
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@
- Fix bug in `Text2CypherRetriever` using `custom_prompt` arg where the `search` method would not inject the `query_text` content.
- Add feature to include kwargs in `Text2CypherRetriever.search()` that will be injected into a custom prompt, if provided.
- Add validation to `custom_prompt` parameter of `Text2CypherRetriever` to ensure that `query_text` placeholder exists in prompt.
- Introduced a fixed size text splitter component for splitting text into specified fixed size chunks with overlap. Updated examples and tests to utilize this new component.
- Introduced Vertex AI LLM class for integrating Vertex AI models.
- Added unit tests for the Vertex AI LLM class.

### Fixed
- Resolved import issue with the Vertex AI Embeddings class.

### Changed
- Moved the Embedder class to the neo4j_graphrag.embeddings directory for better organization alongside other custom embedders.
Expand Down
9 changes: 7 additions & 2 deletions docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -174,11 +174,16 @@ LLMInterface


OpenAILLM
---------
=========

.. autoclass:: neo4j_graphrag.llm.OpenAILLM
:members:

VertexAILLM
===========

.. autoclass:: neo4j_graphrag.llm.vertexai.VertexAILLM
:members:

PromptTemplate
==============
Expand Down Expand Up @@ -389,4 +394,4 @@ PipelineStatusUpdateError
=========================

.. autoclass:: neo4j_graphrag.experimental.pipeline.exceptions.PipelineStatusUpdateError
:show-inheritance:
:show-inheritance:
12 changes: 7 additions & 5 deletions src/neo4j_graphrag/embeddings/vertexai.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import Any
Expand All @@ -22,10 +21,8 @@
try:
from vertexai.language_models import TextEmbeddingInput, TextEmbeddingModel
except ImportError:
raise ImportError(
"Could not import Vertex AI python client. "
"Please install it with `pip install google-cloud-aiplatform`."
)
TextEmbeddingInput = None
TextEmbeddingModel = None


class VertexAIEmbeddings(Embedder):
Expand All @@ -38,6 +35,11 @@ class VertexAIEmbeddings(Embedder):
"""

def __init__(self, model: str = "text-embedding-004") -> None:
if TextEmbeddingInput is None or TextEmbeddingInput is None:
raise ImportError(
"Could not import Vertex AI Python client. "
"Please install it with `pip install google-cloud-aiplatform`."
)
self.vertexai_model = TextEmbeddingModel.from_pretrained(model)

def embed_query(
Expand Down
2 changes: 1 addition & 1 deletion src/neo4j_graphrag/llm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,4 @@
from .openai_llm import OpenAILLM
from .types import LLMResponse

__all__ = ["LLMResponse", "LLMInterface", "OpenAILLM"]
__all__ = ["LLMResponse", "LLMInterface", "OpenAILLM", "VertexAILLM"]
8 changes: 7 additions & 1 deletion src/neo4j_graphrag/llm/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,13 @@


class LLMInterface(ABC):
"""Interface for large language models."""
"""Interface for large language models.
Args:
model_name (str): The name of the language model.
model_params (Optional[dict], optional): Additional parameters passed to the model when text is sent to it. Defaults to None.
**kwargs (Any): Arguments passed to the model when for the class is initialised. Defaults to None.
"""

def __init__(
self,
Expand Down
3 changes: 1 addition & 2 deletions src/neo4j_graphrag/llm/openai_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,7 @@ def __init__(
Args:
model_name (str):
model_params (str): Parameters like temperature and such that will be
passed to the model
model_params (str): Parameters like temperature that will be passed to the model when text is sent to it
kwargs: All other parameters will be passed to the openai.OpenAI init.
"""
Expand Down
98 changes: 98 additions & 0 deletions src/neo4j_graphrag/llm/vertexai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
# Neo4j Sweden AB [https://neo4j.com]
# #
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# #
# https://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

from typing import Any, Optional

from neo4j_graphrag.exceptions import LLMGenerationError
from neo4j_graphrag.llm.base import LLMInterface
from neo4j_graphrag.llm.types import LLMResponse

try:
from vertexai.generative_models import GenerativeModel, ResponseValidationError
except ImportError:
GenerativeModel = None
ResponseValidationError = None


class VertexAILLM(LLMInterface):
"""Interface for large language models on Vertex AI
Args:
model_name (str, optional): Name of the LLM to use. Defaults to "gemini-1.5-flash-001".
model_params (Optional[dict], optional): Additional parameters passed to the model when text is sent to it. Defaults to None.
**kwargs (Any): Arguments passed to the model when for the class is initialised. Defaults to None.
Raises:
LLMGenerationError: If there's an error generating the response from the model.
Example:
.. code-block:: python
from neo4j_graphrag.llm import VertexAILLM
from vertexai.generative_models import GenerationConfig
generation_config = GenerationConfig(temperature=0.0)
llm = VertexAILLM(
model_name="gemini-1.5-flash-001", generation_config=generation_config
)
llm.invoke("Who is the mother of Paul Atreides?")
"""

def __init__(
self,
model_name: str = "gemini-1.5-flash-001",
model_params: Optional[dict[str, Any]] = None,
**kwargs: Any,
):
if GenerativeModel is None or ResponseValidationError is None:
raise ImportError(
"Could not import Vertex AI Python client. "
"Please install it with `pip install google-cloud-aiplatform`."
)
super().__init__(model_name, model_params)
self.model = GenerativeModel(model_name=model_name, **kwargs)

def invoke(self, input: str) -> LLMResponse:
"""Sends text to the LLM and returns a response.
Args:
input (str): The text to send to the LLM.
Returns:
LLMResponse: The response from the LLM.
"""
try:
response = self.model.generate_content(input, **self.model_params)
return LLMResponse(content=response.text)
except ResponseValidationError as e:
raise LLMGenerationError(e)

async def ainvoke(self, input: str) -> LLMResponse:
"""Asynchronously sends text to the LLM and returns a response.
Args:
input (str): The text to send to the LLM.
Returns:
LLMResponse: The response from the LLM.
"""
try:
response = await self.model.generate_content_async(
input, **self.model_params
)
return LLMResponse(content=response.text)
except ResponseValidationError as e:
raise LLMGenerationError(e)
54 changes: 54 additions & 0 deletions tests/unit/llm/test_vertexai_llm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# Neo4j Sweden AB [https://neo4j.com]
# #
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# #
# https://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

from unittest.mock import AsyncMock, MagicMock, Mock, patch

import pytest
from neo4j_graphrag.llm.vertexai import VertexAILLM


@patch("neo4j_graphrag.llm.vertexai.GenerativeModel", None)
def test_vertexai_llm_missing_dependency() -> None:
with pytest.raises(ImportError):
VertexAILLM(model_name="gemini-1.5-flash-001")


@patch("neo4j_graphrag.llm.vertexai.GenerativeModel")
def test_invoke_happy_path(GenerativeModelMock: MagicMock) -> None:
mock_response = Mock()
mock_response.text = "Return text"
mock_model = GenerativeModelMock.return_value
mock_model.generate_content.return_value = mock_response
model_params = {"temperature": 0.5}
llm = VertexAILLM("gemini-1.5-flash-001", model_params)
input_text = "may thy knife chip and shatter"
response = llm.invoke(input_text)
assert response.content == "Return text"
llm.model.generate_content.assert_called_once_with(input_text, **model_params)


@pytest.mark.asyncio
@patch("neo4j_graphrag.llm.vertexai.GenerativeModel")
async def test_ainvoke_happy_path(GenerativeModelMock: MagicMock) -> None:
mock_response = AsyncMock()
mock_response.text = "Return text"
mock_model = GenerativeModelMock.return_value
mock_model.generate_content_async = AsyncMock(return_value=mock_response)
model_params = {"temperature": 0.5}
llm = VertexAILLM("gemini-1.5-flash-001", model_params)
input_text = "may thy knife chip and shatter"
response = await llm.ainvoke(input_text)
assert response.content == "Return text"
llm.model.generate_content_async.assert_called_once_with(input_text, **model_params)

0 comments on commit d1cef28

Please sign in to comment.