From d1cef28a614ecc8f9532f29611c11232b2a33054 Mon Sep 17 00:00:00 2001
From: Alex Thomas
Date: Thu, 19 Sep 2024 16:16:04 +0100
Subject: [PATCH] Added Vertex AI LLM class (#141)

* Added Vertex AI LLM class

* Updated docstrings

* Updated unit test workflow

* Removed duplicate poetry install for pr workflow

* Removed --no-root from poetry install in pr.yaml

* Fixed typo

* Updated docs

* Updated CHANGELOG for previous PR

* Updated CHANGELOG

* Fixed typo
---
 .github/workflows/pr.yaml                 |  4 +-
 CHANGELOG.md                              |  6 ++
 docs/source/api.rst                       |  9 ++-
 src/neo4j_graphrag/embeddings/vertexai.py | 12 +--
 src/neo4j_graphrag/llm/__init__.py        |  3 ++-
 src/neo4j_graphrag/llm/base.py            |  8 +-
 src/neo4j_graphrag/llm/openai_llm.py      |  3 +-
 src/neo4j_graphrag/llm/vertexai.py        | 98 +++++++++++++++++++++++
 tests/unit/llm/test_vertexai_llm.py       | 54 +++++++++++++
 9 files changed, 183 insertions(+), 14 deletions(-)
 create mode 100644 src/neo4j_graphrag/llm/vertexai.py
 create mode 100644 tests/unit/llm/test_vertexai_llm.py

diff --git a/.github/workflows/pr.yaml b/.github/workflows/pr.yaml
index c93dd98a..3bdc15e3 100644
--- a/.github/workflows/pr.yaml
+++ b/.github/workflows/pr.yaml
@@ -32,9 +32,7 @@ jobs:
           key: venv-${{ runner.os }}-${{ matrix.python-version }}-${{ hashFiles('**/poetry.lock') }}
       - name: Install dependencies
         if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true'
-        run: poetry install --no-interaction --no-root
-      - name: Install root project
-        run: poetry install --no-interaction
+        run: poetry install --no-interaction --extras external_clients
       - name: Check format and linting
         run: |
           poetry run ruff check --select I .
diff --git a/CHANGELOG.md b/CHANGELOG.md
index d458e9df..f8470645 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -9,6 +9,12 @@
 - Fix bug in `Text2CypherRetriever` using `custom_prompt` arg where the `search` method would not inject the `query_text` content.
 - Add feature to include kwargs in `Text2CypherRetriever.search()` that will be injected into a custom prompt, if provided.
 - Add validation to `custom_prompt` parameter of `Text2CypherRetriever` to ensure that `query_text` placeholder exists in prompt.
+- Introduced a fixed-size text splitter component for splitting text into fixed-size chunks with overlap. Updated examples and tests to use this new component.
+- Introduced Vertex AI LLM class for integrating Vertex AI models.
+- Added unit tests for the Vertex AI LLM class.
+
+### Fixed
+- Resolved import issue with the Vertex AI Embeddings class.
 
 ### Changed
 - Moved the Embedder class to the neo4j_graphrag.embeddings directory for better organization alongside other custom embedders.
diff --git a/docs/source/api.rst b/docs/source/api.rst
index c475ddb5..b5597436 100644
--- a/docs/source/api.rst
+++ b/docs/source/api.rst
@@ -174,11 +174,16 @@ LLMInterface
 
 OpenAILLM
----------
+=========
 
 .. autoclass:: neo4j_graphrag.llm.OpenAILLM
     :members:
 
+VertexAILLM
+===========
+
+.. autoclass:: neo4j_graphrag.llm.vertexai.VertexAILLM
+    :members:
 
 PromptTemplate
 ==============
 
@@ -389,4 +394,4 @@ PipelineStatusUpdateError
 =========================
 
 .. autoclass:: neo4j_graphrag.experimental.pipeline.exceptions.PipelineStatusUpdateError
-    :show-inheritance:
\ No newline at end of file
+    :show-inheritance:
diff --git a/src/neo4j_graphrag/embeddings/vertexai.py b/src/neo4j_graphrag/embeddings/vertexai.py
index 3454f42e..dae547e7 100644
--- a/src/neo4j_graphrag/embeddings/vertexai.py
+++ b/src/neo4j_graphrag/embeddings/vertexai.py
@@ -12,7 +12,6 @@
 #  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 #  See the License for the specific language governing permissions and
 #  limitations under the License.
-
 from __future__ import annotations
 
 from typing import Any
@@ -22,10 +21,8 @@
 try:
     from vertexai.language_models import TextEmbeddingInput, TextEmbeddingModel
 except ImportError:
-    raise ImportError(
-        "Could not import Vertex AI python client. "
-        "Please install it with `pip install google-cloud-aiplatform`."
-    )
+    TextEmbeddingInput = None
+    TextEmbeddingModel = None
 
 
 class VertexAIEmbeddings(Embedder):
@@ -38,6 +35,11 @@ class VertexAIEmbeddings(Embedder):
     """
 
    def __init__(self, model: str = "text-embedding-004") -> None:
+        if TextEmbeddingInput is None or TextEmbeddingModel is None:
+            raise ImportError(
+                "Could not import Vertex AI Python client. "
+                "Please install it with `pip install google-cloud-aiplatform`."
+            )
         self.vertexai_model = TextEmbeddingModel.from_pretrained(model)
 
     def embed_query(
diff --git a/src/neo4j_graphrag/llm/__init__.py b/src/neo4j_graphrag/llm/__init__.py
index 8c2baad4..f2e6c749 100644
--- a/src/neo4j_graphrag/llm/__init__.py
+++ b/src/neo4j_graphrag/llm/__init__.py
@@ -16,4 +16,5 @@
 from .openai_llm import OpenAILLM
 from .types import LLMResponse
+from .vertexai import VertexAILLM
 
-__all__ = ["LLMResponse", "LLMInterface", "OpenAILLM"]
+__all__ = ["LLMResponse", "LLMInterface", "OpenAILLM", "VertexAILLM"]
diff --git a/src/neo4j_graphrag/llm/base.py b/src/neo4j_graphrag/llm/base.py
index b9a639b7..3d98423b 100644
--- a/src/neo4j_graphrag/llm/base.py
+++ b/src/neo4j_graphrag/llm/base.py
@@ -21,7 +21,13 @@
 
 
 class LLMInterface(ABC):
-    """Interface for large language models."""
+    """Interface for large language models.
+
+    Args:
+        model_name (str): The name of the language model.
+        model_params (Optional[dict], optional): Additional parameters passed to the model when text is sent to it. Defaults to None.
+        **kwargs (Any): Arguments passed to the model when the class is initialised.
+    """
 
     def __init__(
         self,
diff --git a/src/neo4j_graphrag/llm/openai_llm.py b/src/neo4j_graphrag/llm/openai_llm.py
index 334277b4..2c622592 100644
--- a/src/neo4j_graphrag/llm/openai_llm.py
+++ b/src/neo4j_graphrag/llm/openai_llm.py
@@ -37,8 +37,7 @@ def __init__(
 
         Args:
             model_name (str):
-            model_params (str): Parameters like temperature and such that will be
-                passed to the model
+            model_params (dict): Parameters like temperature that will be passed to the model when text is sent to it.
             kwargs: All other parameters will be passed to the openai.OpenAI init.
         """
 
diff --git a/src/neo4j_graphrag/llm/vertexai.py b/src/neo4j_graphrag/llm/vertexai.py
new file mode 100644
index 00000000..b0820e62
--- /dev/null
+++ b/src/neo4j_graphrag/llm/vertexai.py
@@ -0,0 +1,98 @@
+#  Neo4j Sweden AB [https://neo4j.com]
+#  #
+#  Licensed under the Apache License, Version 2.0 (the "License");
+#  you may not use this file except in compliance with the License.
+#  You may obtain a copy of the License at
+#  #
+#      https://www.apache.org/licenses/LICENSE-2.0
+#  #
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+#  limitations under the License.
+from __future__ import annotations
+
+from typing import Any, Optional
+
+from neo4j_graphrag.exceptions import LLMGenerationError
+from neo4j_graphrag.llm.base import LLMInterface
+from neo4j_graphrag.llm.types import LLMResponse
+
+try:
+    from vertexai.generative_models import GenerativeModel, ResponseValidationError
+except ImportError:
+    GenerativeModel = None
+    ResponseValidationError = None
+
+
+class VertexAILLM(LLMInterface):
+    """Interface for large language models on Vertex AI.
+
+    Args:
+        model_name (str, optional): Name of the LLM to use. Defaults to "gemini-1.5-flash-001".
+        model_params (Optional[dict], optional): Additional parameters passed to the model when text is sent to it. Defaults to None.
+        **kwargs (Any): Arguments passed to the model when the class is initialised.
+
+    Raises:
+        LLMGenerationError: If there's an error generating the response from the model.
+
+    Example:
+
+    .. code-block:: python
+
+        from neo4j_graphrag.llm import VertexAILLM
+        from vertexai.generative_models import GenerationConfig
+
+        generation_config = GenerationConfig(temperature=0.0)
+        llm = VertexAILLM(
+            model_name="gemini-1.5-flash-001", generation_config=generation_config
+        )
+        llm.invoke("Who is the mother of Paul Atreides?")
+    """
+
+    def __init__(
+        self,
+        model_name: str = "gemini-1.5-flash-001",
+        model_params: Optional[dict[str, Any]] = None,
+        **kwargs: Any,
+    ):
+        if GenerativeModel is None or ResponseValidationError is None:
+            raise ImportError(
+                "Could not import Vertex AI Python client. "
+                "Please install it with `pip install google-cloud-aiplatform`."
+            )
+        super().__init__(model_name, model_params)
+        self.model = GenerativeModel(model_name=model_name, **kwargs)
+
+    def invoke(self, input: str) -> LLMResponse:
+        """Sends text to the LLM and returns a response.
+
+        Args:
+            input (str): The text to send to the LLM.
+
+        Returns:
+            LLMResponse: The response from the LLM.
+        """
+        try:
+            response = self.model.generate_content(input, **self.model_params)
+            return LLMResponse(content=response.text)
+        except ResponseValidationError as e:
+            raise LLMGenerationError(e)
+
+    async def ainvoke(self, input: str) -> LLMResponse:
+        """Asynchronously sends text to the LLM and returns a response.
+
+        Args:
+            input (str): The text to send to the LLM.
+
+        Returns:
+            LLMResponse: The response from the LLM.
+        """
+        try:
+            response = await self.model.generate_content_async(
+                input, **self.model_params
+            )
+            return LLMResponse(content=response.text)
+        except ResponseValidationError as e:
+            raise LLMGenerationError(e)
diff --git a/tests/unit/llm/test_vertexai_llm.py b/tests/unit/llm/test_vertexai_llm.py
new file mode 100644
index 00000000..a32d6855
--- /dev/null
+++ b/tests/unit/llm/test_vertexai_llm.py
@@ -0,0 +1,54 @@
+#  Neo4j Sweden AB [https://neo4j.com]
+#  #
+#  Licensed under the Apache License, Version 2.0 (the "License");
+#  you may not use this file except in compliance with the License.
+#  You may obtain a copy of the License at
+#  #
+#      https://www.apache.org/licenses/LICENSE-2.0
+#  #
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+#  limitations under the License.
+from __future__ import annotations
+
+from unittest.mock import AsyncMock, MagicMock, Mock, patch
+
+import pytest
+from neo4j_graphrag.llm.vertexai import VertexAILLM
+
+
+@patch("neo4j_graphrag.llm.vertexai.GenerativeModel", None)
+def test_vertexai_llm_missing_dependency() -> None:
+    with pytest.raises(ImportError):
+        VertexAILLM(model_name="gemini-1.5-flash-001")
+
+
+@patch("neo4j_graphrag.llm.vertexai.GenerativeModel")
+def test_invoke_happy_path(GenerativeModelMock: MagicMock) -> None:
+    mock_response = Mock()
+    mock_response.text = "Return text"
+    mock_model = GenerativeModelMock.return_value
+    mock_model.generate_content.return_value = mock_response
+    model_params = {"temperature": 0.5}
+    llm = VertexAILLM("gemini-1.5-flash-001", model_params)
+    input_text = "may thy knife chip and shatter"
+    response = llm.invoke(input_text)
+    assert response.content == "Return text"
+    llm.model.generate_content.assert_called_once_with(input_text, **model_params)
+
+
+@pytest.mark.asyncio
+@patch("neo4j_graphrag.llm.vertexai.GenerativeModel")
+async def test_ainvoke_happy_path(GenerativeModelMock: MagicMock) -> None:
+    mock_response = AsyncMock()
+    mock_response.text = "Return text"
+    mock_model = GenerativeModelMock.return_value
+    mock_model.generate_content_async = AsyncMock(return_value=mock_response)
+    model_params = {"temperature": 0.5}
+    llm = VertexAILLM("gemini-1.5-flash-001", model_params)
+    input_text = "may thy knife chip and shatter"
+    response = await llm.ainvoke(input_text)
+    assert response.content == "Return text"
+    llm.model.generate_content_async.assert_called_once_with(input_text, **model_params)
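
A minimal end-to-end sketch of the new class, for anyone reviewing this patch who wants to try it locally. It is illustrative rather than part of the change: it assumes `google-cloud-aiplatform` is installed and GCP credentials are configured, and the `vertexai.init()` project and location values are placeholders.

    import asyncio

    import vertexai
    from vertexai.generative_models import GenerationConfig

    from neo4j_graphrag.llm import VertexAILLM

    # Placeholder GCP settings -- substitute your own project and region.
    vertexai.init(project="my-gcp-project", location="us-central1")

    # model_params is unpacked into every generate_content() call, so a
    # generation_config set here applies to each invocation.
    llm = VertexAILLM(
        model_name="gemini-1.5-flash-001",
        model_params={"generation_config": GenerationConfig(temperature=0.0)},
    )

    # Synchronous call, as in the class docstring example.
    print(llm.invoke("Who is the mother of Paul Atreides?").content)

    # Asynchronous counterpart added by this patch.
    print(asyncio.run(llm.ainvoke("Who is the father of Paul Atreides?")).content)

Note that constructor kwargs are forwarded to `GenerativeModel(...)` instead, so a `generation_config` can equally be bound at construction time, as the class docstring example does.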