Skip to content

Commit

Permalink
Added Vertex AI LLM class
Browse files Browse the repository at this point in the history
  • Loading branch information
alexthomas93 committed Sep 18, 2024
1 parent fc7d319 commit cedfbf5
Show file tree
Hide file tree
Showing 3 changed files with 142 additions and 5 deletions.
12 changes: 7 additions & 5 deletions src/neo4j_graphrag/embeddings/vertexai.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import Any
Expand All @@ -22,10 +21,8 @@
try:
from vertexai.language_models import TextEmbeddingInput, TextEmbeddingModel
except ImportError:
raise ImportError(
"Could not import Vertex AI python client. "
"Please install it with `pip install google-cloud-aiplatform`."
)
TextEmbeddingInput = None
TextEmbeddingModel = None


class VertexAIEmbeddings(Embedder):
Expand All @@ -38,6 +35,11 @@ class VertexAIEmbeddings(Embedder):
"""

def __init__(self, model: str = "text-embedding-004") -> None:
if TextEmbeddingInput is None or TextEmbeddingInput is None:
raise ImportError(
"Could not import Vertex AI python client. "
"Please install it with `pip install google-cloud-aiplatform`."
)
self.vertexai_model = TextEmbeddingModel.from_pretrained(model)

def embed_query(
Expand Down
81 changes: 81 additions & 0 deletions src/neo4j_graphrag/llm/vertexai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# Neo4j Sweden AB [https://neo4j.com]
# #
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# #
# https://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

from typing import Any, Optional

from neo4j_graphrag.exceptions import LLMGenerationError
from neo4j_graphrag.llm.base import LLMInterface
from neo4j_graphrag.llm.types import LLMResponse

try:
from vertexai.generative_models import GenerativeModel, ResponseValidationError
except ImportError:
GenerativeModel = None
ResponseValidationError = None


class VertexAILLM(LLMInterface):
"""Interface for large language models on Vertex AI
Args:
model_name (str, optional): Name of the LLM to use. Defaults to "gemini-1.5-flash-001".
model_params (Optional[Dict[str, Any]], optional): Parameters for passed to the LLM's invoke and ainvoke functions.
"""

def __init__(
self,
model_name: str = "gemini-1.5-flash-001",
model_params: Optional[dict[str, Any]] = None,
**kwargs: Any,
):
if GenerativeModel is None or ResponseValidationError is None:
raise ImportError(
"Could not import Vertex AI python client. "
"Please install it with `pip install google-cloud-aiplatform`."
)
super().__init__(model_name, model_params)
self.model = GenerativeModel(model_name=model_name, **kwargs)

def invoke(self, input: str) -> LLMResponse:
"""Sends text to the LLM and returns a response.
Args:
input (str): The text to send to the LLM.
Returns:
LLMResponse: The response from the LLM.
"""
try:
response = self.model.generate_content(input, **self.model_params)
return LLMResponse(content=response.text)
except ResponseValidationError as e:
raise LLMGenerationError(e)

async def ainvoke(self, input: str) -> LLMResponse:
"""Asynchronously sends text to the LLM and returns a response.
Args:
input (str): The text to send to the LLM.
Returns:
LLMResponse: The response from the LLM.
"""
try:
response = await self.model.generate_content_async(
input, **self.model_params
)
return LLMResponse(content=response.text)
except ResponseValidationError as e:
raise LLMGenerationError(e)
54 changes: 54 additions & 0 deletions tests/unit/llm/test_vertexai_llm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# Neo4j Sweden AB [https://neo4j.com]
# #
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# #
# https://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

from unittest.mock import AsyncMock, MagicMock, Mock, patch

import pytest
from neo4j_graphrag.llm.vertexai import VertexAILLM


@patch("neo4j_graphrag.llm.vertexai.GenerativeModel", None)
def test_vertexai_llm_missing_dependency() -> None:
with pytest.raises(ImportError):
VertexAILLM(model_name="gemini-1.5-flash-001")


@patch("neo4j_graphrag.llm.vertexai.GenerativeModel")
def test_invoke_happy_path(GenerativeModelMock: MagicMock) -> None:
mock_response = Mock()
mock_response.text = "Return text"
mock_model = GenerativeModelMock.return_value
mock_model.generate_content.return_value = mock_response
model_params = {"temperature": 0.5}
llm = VertexAILLM("gemini-1.5-flash-001", model_params)
input_text = "may thy knife chip and shatter"
response = llm.invoke(input_text)
assert response.content == "Return text"
llm.model.generate_content.assert_called_once_with(input_text, **model_params)


@pytest.mark.asyncio
@patch("neo4j_graphrag.llm.vertexai.GenerativeModel")
async def test_ainvoke_happy_path(GenerativeModelMock: MagicMock) -> None:
mock_response = AsyncMock()
mock_response.text = "Return text"
mock_model = GenerativeModelMock.return_value
mock_model.generate_content_async = AsyncMock(return_value=mock_response)
model_params = {"temperature": 0.5}
llm = VertexAILLM("gemini-1.5-flash-001", model_params)
input_text = "may thy knife chip and shatter"
response = await llm.ainvoke(input_text)
assert response.content == "Return text"
llm.model.generate_content_async.assert_called_once_with(input_text, **model_params)

0 comments on commit cedfbf5

Please sign in to comment.