feat: GA Context Cache Python SDK #4870

Open · wants to merge 1 commit into base: main
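
This PR promotes the context-caching surface out of vertexai.preview.caching into the GA vertexai.caching namespace, and moves GenerativeModel.from_cached_content onto the GA generative_models class. A minimal sketch of the intended GA usage, pieced together from the diffs below; the project, location, and prompt are placeholders, and the model and instruction strings are the ones used in the unit tests:

import vertexai
from vertexai.caching import CachedContent  # GA path (was vertexai.preview.caching)
from vertexai.generative_models import GenerativeModel

vertexai.init(project="my-project", location="us-central1")  # placeholder project/location

# Create a cache; its contents become the prefix of every request made through it.
cache = CachedContent.create(
    model_name="model-name",  # placeholder model name, as in the unit tests
    system_instruction="Please answer my questions with cool",
)

# Bind a model to the cache and query it.
model = GenerativeModel.from_cached_content(cached_content=cache)
response = model.generate_content("What is in the cached context?")  # placeholder prompt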
1 change: 1 addition & 0 deletions google/cloud/aiplatform/compat/types/__init__.py
@@ -118,6 +118,7 @@
     annotation_spec as annotation_spec_v1,
     artifact as artifact_v1,
     batch_prediction_job as batch_prediction_job_v1,
+    cached_content as cached_content_v1,
     completion_stats as completion_stats_v1,
     context as context_v1,
     custom_job as custom_job_v1,
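
The compat layer aliases versioned GAPIC type modules under stable names, so this one-line addition makes the v1 CachedContent proto reachable the same way as the other types. A small sketch of consuming the alias, assuming the package's usual compat re-export pattern; the display name is a placeholder:

from google.cloud.aiplatform.compat.types import cached_content_v1

# cached_content_v1 is the v1 proto module; CachedContent is its message class.
cc = cached_content_v1.CachedContent(display_name="my-cache")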
18 changes: 9 additions & 9 deletions tests/unit/vertexai/test_caching.py
@@ -22,7 +22,7 @@
 import json
 import mock
 import pytest
-from vertexai.preview import caching
+from vertexai.caching import _caching
 from google.cloud.aiplatform import initializer
 import vertexai
 from google.cloud.aiplatform_v1beta1.types.cached_content import (
@@ -141,7 +141,7 @@ def list_cached_contents(self, request):

 @pytest.mark.usefixtures("google_auth_mock")
 class TestCaching:
-    """Unit tests for caching.CachedContent."""
+    """Unit tests for _caching.CachedContent."""

     def setup_method(self):
         vertexai.init(
@@ -156,7 +156,7 @@ def test_constructor_with_full_resource_name(self, mock_get_cached_content):
         full_resource_name = (
             "projects/123/locations/europe-west1/cachedContents/contents-id"
         )
-        cache = caching.CachedContent(
+        cache = _caching.CachedContent(
             cached_content_name=full_resource_name,
         )

@@ -166,7 +166,7 @@
     def test_constructor_with_only_content_id(self, mock_get_cached_content):
         partial_resource_name = "contents-id"

-        cache = caching.CachedContent(
+        cache = _caching.CachedContent(
             cached_content_name=partial_resource_name,
         )

@@ -179,7 +179,7 @@
     def test_get_with_content_id(self, mock_get_cached_content):
         partial_resource_name = "contents-id"

-        cache = caching.CachedContent.get(
+        cache = _caching.CachedContent.get(
             cached_content_name=partial_resource_name,
         )

@@ -192,7 +192,7 @@
     def test_create_with_real_payload(
         self, mock_create_cached_content, mock_get_cached_content
     ):
-        cache = caching.CachedContent.create(
+        cache = _caching.CachedContent.create(
             model_name="model-name",
             system_instruction=GapicContent(
                 role="system", parts=[GapicPart(text="system instruction")]
@@ -219,7 +219,7 @@
     def test_create_with_real_payload_and_wrapped_type(
         self, mock_create_cached_content, mock_get_cached_content
     ):
-        cache = caching.CachedContent.create(
+        cache = _caching.CachedContent.create(
             model_name="model-name",
             system_instruction="Please answer my questions with cool",
             tools=[],
@@ -239,15 +239,15 @@
         assert cache.display_name == _TEST_DISPLAY_NAME

     def test_list(self, mock_list_cached_contents):
-        cached_contents = caching.CachedContent.list()
+        cached_contents = _caching.CachedContent.list()
         for i, cached_content in enumerate(cached_contents):
             assert cached_content.name == f"cached_content{i + 1}_from_list_request"
             assert cached_content.model_name == f"model-name{i + 1}"

     def test_print_a_cached_content(
         self, mock_create_cached_content, mock_get_cached_content
     ):
-        cached_content = caching.CachedContent.create(
+        cached_content = _caching.CachedContent.create(
             model_name="model-name",
             system_instruction="Please answer my questions with cool",
             tools=[],
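
These tests pin down the lookup surface: a bare ID is expanded into a full resource name using the initialized project and location, while a full resource name is used as-is. A usage sketch of that behavior; the ID is a placeholder and vertexai.init() is assumed to have been called:

from vertexai.caching import CachedContent

# Bare ID: expanded to projects/<project>/locations/<location>/cachedContents/contents-id.
cache = CachedContent.get(cached_content_name="contents-id")

# Enumerate existing caches; name and model_name are the fields asserted on in test_list.
for cc in CachedContent.list():
    print(cc.name, cc.model_name)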
14 changes: 7 additions & 7 deletions tests/unit/vertexai/test_generative_models.py
@@ -46,7 +46,7 @@
     gen_ai_cache_service,
 )
 from vertexai.generative_models import _function_calling_utils
-from vertexai.preview import caching
+from vertexai.caching import _caching


 _TEST_PROJECT = "test-project"
@@ -655,11 +655,11 @@ def test_generative_model_from_cached_content(
         project_location_prefix = (
             f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/"
         )
-        cached_content = caching.CachedContent(
+        cached_content = _caching.CachedContent(
             "cached-content-id-in-from-cached-content-test"
         )

-        model = preview_generative_models.GenerativeModel.from_cached_content(
+        model = generative_models.GenerativeModel.from_cached_content(
             cached_content=cached_content
         )

@@ -690,7 +690,7 @@ def test_generative_model_from_cached_content_with_resource_name(
             f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/"
         )

-        model = preview_generative_models.GenerativeModel.from_cached_content(
+        model = generative_models.GenerativeModel.from_cached_content(
             cached_content="cached-content-id-in-from-cached-content-test"
         )

@@ -848,7 +848,7 @@ def test_generate_content(
         assert response5.text

     @mock.patch.object(
-        target=prediction_service.PredictionServiceClient,
+        target=prediction_service_v1.PredictionServiceClient,
         attribute="generate_content",
         new=lambda self, request: gapic_prediction_service_types.GenerateContentResponse(
             candidates=[
@@ -870,11 +870,11 @@ def test_generate_content_with_cached_content(
         self,
         mock_get_cached_content_fixture,
     ):
-        cached_content = caching.CachedContent(
+        cached_content = _caching.CachedContent(
             "cached-content-id-in-from-cached-content-test"
         )

-        model = preview_generative_models.GenerativeModel.from_cached_content(
+        model = generative_models.GenerativeModel.from_cached_content(
             cached_content=cached_content
         )

25 changes: 25 additions & 0 deletions vertexai/caching/__init__.py
@@ -0,0 +1,25 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+"""Classes for working with context caching."""
+
+# We just want to re-export certain classes
+# pylint: disable=g-multiple-import,g-importing-member
+from vertexai.caching._caching import (
+    CachedContent,
+)
+
+__all__ = [
+    "CachedContent",
+]
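
A quick sketch of what this re-export gives callers, assuming the package layout in this PR: the public name and the private implementation are the same object.

from vertexai.caching import CachedContent
from vertexai.caching import _caching

assert CachedContent is _caching.CachedContent  # public alias of the private class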
78 changes: 39 additions & 39 deletions vertexai/generative_models/_generative_models.py
@@ -1161,6 +1161,45 @@ def start_chat(
             response_validation=response_validation,
         )

+    @classmethod
+    def from_cached_content(
+        cls,
+        cached_content: Union[str, "caching.CachedContent"],
+        *,
+        generation_config: Optional[GenerationConfigType] = None,
+        safety_settings: Optional[SafetySettingsType] = None,
+    ) -> "_GenerativeModel":
+        """Creates a model from cached content.
+
+        Creates a model instance from an existing cached content. The cached
+        content becomes the prefix of the request contents.
+
+        Args:
+            cached_content: The cached content resource name or object.
+            generation_config: The generation config to use for this model.
+            safety_settings: The safety settings to use for this model.
+
+        Returns:
+            A model instance that uses the cached content as the prefix of
+            all its requests.
+        """
+        if isinstance(cached_content, str):
+            from vertexai.caching import _caching
+
+            cached_content = _caching.CachedContent.get(cached_content)
+        model_name = cached_content.model_name
+        model = cls(
+            model_name=model_name,
+            generation_config=generation_config,
+            safety_settings=safety_settings,
+            tools=None,
+            tool_config=None,
+            system_instruction=None,
+        )
+        model._cached_content = cached_content
+
+        return model
+

 _SUCCESSFUL_FINISH_REASONS = [
     gapic_content_types.Candidate.FinishReason.STOP,
@@ -3515,42 +3554,3 @@ def start_chat(
             response_validation=response_validation,
             responder=responder,
         )
-
-    @classmethod
-    def from_cached_content(
-        cls,
-        cached_content: Union[str, "caching.CachedContent"],
-        *,
-        generation_config: Optional[GenerationConfigType] = None,
-        safety_settings: Optional[SafetySettingsType] = None,
-    ) -> "_GenerativeModel":
-        """Creates a model from cached content.
-
-        Creates a model instance with an existing cached content. The cached
-        content becomes the prefix of the requesting contents.
-
-        Args:
-            cached_content: The cached content resource name or object.
-            generation_config: The generation config to use for this model.
-            safety_settings: The safety settings to use for this model.
-
-        Returns:
-            A model instance with the cached content wtih cached content as
-            prefix of all its requests.
-        """
-        if isinstance(cached_content, str):
-            from vertexai.preview import caching
-
-            cached_content = caching.CachedContent.get(cached_content)
-        model_name = cached_content.model_name
-        model = cls(
-            model_name=model_name,
-            generation_config=generation_config,
-            safety_settings=safety_settings,
-            tools=None,
-            tool_config=None,
-            system_instruction=None,
-        )
-        model._cached_content = cached_content
-
-        return model
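
Because from_cached_content accepts either a CachedContent object or a resource name, the string form first resolves the cache (and its model name) via CachedContent.get before constructing the model. A sketch of the string path; the resource name is a placeholder and vertexai.init() is assumed to have been called:

from vertexai.generative_models import GenerativeModel

model = GenerativeModel.from_cached_content(
    cached_content="projects/123/locations/us-central1/cachedContents/contents-id",
)
response = model.generate_content("Answer using the cached context.")  # placeholder prompt
print(response.text)

Note that tools, tool_config, and system_instruction are passed as None when the model is constructed: they travel with the cache itself (CachedContent.create accepts system_instruction and tools), so the model does not set them per-request.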