Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: GA Context Cache Python SDK #4870

Merged
merged 1 commit into from
Mar 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions google/cloud/aiplatform/compat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,8 +181,7 @@
services.featurestore_online_serving_service_client_v1
)
services.featurestore_service_client = services.featurestore_service_client_v1
# TODO(b/342585299): Temporary code. Switch to v1 once v1 is available.
services.gen_ai_cache_service_client = services.gen_ai_cache_service_client_v1beta1
services.gen_ai_cache_service_client = services.gen_ai_cache_service_client_v1
services.job_service_client = services.job_service_client_v1
services.model_garden_service_client = services.model_garden_service_client_v1
services.model_service_client = services.model_service_client_v1
Expand All @@ -203,8 +202,7 @@
types.annotation_spec = types.annotation_spec_v1
types.artifact = types.artifact_v1
types.batch_prediction_job = types.batch_prediction_job_v1
# TODO(b/342585299): Temporary code. Switch to v1 once v1 is available.
types.cached_content = types.cached_content_v1beta1
types.cached_content = types.cached_content_v1
types.completion_stats = types.completion_stats_v1
types.context = types.context_v1
types.custom_job = types.custom_job_v1
Expand Down
3 changes: 3 additions & 0 deletions google/cloud/aiplatform/compat/services/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,9 @@
from google.cloud.aiplatform_v1.services.featurestore_service import (
client as featurestore_service_client_v1,
)
from google.cloud.aiplatform_v1.services.gen_ai_cache_service import (
client as gen_ai_cache_service_client_v1,
)
from google.cloud.aiplatform_v1.services.index_service import (
client as index_service_client_v1,
)
Expand Down
1 change: 1 addition & 0 deletions google/cloud/aiplatform/compat/types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@
annotation_spec as annotation_spec_v1,
artifact as artifact_v1,
batch_prediction_job as batch_prediction_job_v1,
cached_content as cached_content_v1,
completion_stats as completion_stats_v1,
context as context_v1,
custom_job as custom_job_v1,
Expand Down
4 changes: 2 additions & 2 deletions google/cloud/aiplatform/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@
feature_registry_service_client_v1,
featurestore_online_serving_service_client_v1,
featurestore_service_client_v1,
gen_ai_cache_service_client_v1,
index_service_client_v1,
index_endpoint_service_client_v1,
job_service_client_v1,
Expand Down Expand Up @@ -807,8 +808,7 @@ class GenAiCacheServiceClientWithOverride(ClientWithOverride):
_version_map = (
(
compat.V1,
# TODO(b/342585299): Temporary code. Switch to v1 once v1 is available.
gen_ai_cache_service_client_v1beta1.GenAiCacheServiceClient,
gen_ai_cache_service_client_v1.GenAiCacheServiceClient,
),
(
compat.V1BETA1,
Expand Down
20 changes: 10 additions & 10 deletions tests/unit/vertexai/test_caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import json
import mock
import pytest
from vertexai.preview import caching
from vertexai.caching import CachedContent
from google.cloud.aiplatform import initializer
import vertexai
from google.cloud.aiplatform_v1beta1.types.cached_content import (
Expand All @@ -35,7 +35,7 @@
from google.cloud.aiplatform_v1beta1.types.tool import (
ToolConfig as GapicToolConfig,
)
from google.cloud.aiplatform_v1beta1.services import (
from google.cloud.aiplatform_v1.services import (
gen_ai_cache_service,
)

Expand Down Expand Up @@ -141,7 +141,7 @@ def list_cached_contents(self, request):

@pytest.mark.usefixtures("google_auth_mock")
class TestCaching:
"""Unit tests for caching.CachedContent."""
"""Unit tests for CachedContent."""

def setup_method(self):
vertexai.init(
Expand All @@ -156,7 +156,7 @@ def test_constructor_with_full_resource_name(self, mock_get_cached_content):
full_resource_name = (
"projects/123/locations/europe-west1/cachedContents/contents-id"
)
cache = caching.CachedContent(
cache = CachedContent(
cached_content_name=full_resource_name,
)

Expand All @@ -166,7 +166,7 @@ def test_constructor_with_full_resource_name(self, mock_get_cached_content):
def test_constructor_with_only_content_id(self, mock_get_cached_content):
partial_resource_name = "contents-id"

cache = caching.CachedContent(
cache = CachedContent(
cached_content_name=partial_resource_name,
)

Expand All @@ -179,7 +179,7 @@ def test_constructor_with_only_content_id(self, mock_get_cached_content):
def test_get_with_content_id(self, mock_get_cached_content):
partial_resource_name = "contents-id"

cache = caching.CachedContent.get(
cache = CachedContent.get(
cached_content_name=partial_resource_name,
)

Expand All @@ -192,7 +192,7 @@ def test_get_with_content_id(self, mock_get_cached_content):
def test_create_with_real_payload(
self, mock_create_cached_content, mock_get_cached_content
):
cache = caching.CachedContent.create(
cache = CachedContent.create(
model_name="model-name",
system_instruction=GapicContent(
role="system", parts=[GapicPart(text="system instruction")]
Expand All @@ -219,7 +219,7 @@ def test_create_with_real_payload(
def test_create_with_real_payload_and_wrapped_type(
self, mock_create_cached_content, mock_get_cached_content
):
cache = caching.CachedContent.create(
cache = CachedContent.create(
model_name="model-name",
system_instruction="Please answer my questions with cool",
tools=[],
Expand All @@ -239,15 +239,15 @@ def test_create_with_real_payload_and_wrapped_type(
assert cache.display_name == _TEST_DISPLAY_NAME

def test_list(self, mock_list_cached_contents):
cached_contents = caching.CachedContent.list()
cached_contents = CachedContent.list()
for i, cached_content in enumerate(cached_contents):
assert cached_content.name == f"cached_content{i + 1}_from_list_request"
assert cached_content.model_name == f"model-name{i + 1}"

def test_print_a_cached_content(
self, mock_create_cached_content, mock_get_cached_content
):
cached_content = caching.CachedContent.create(
cached_content = CachedContent.create(
model_name="model-name",
system_instruction="Please answer my questions with cool",
tools=[],
Expand Down
34 changes: 19 additions & 15 deletions tests/unit/vertexai/test_generative_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,14 @@
gapic_content_types,
gapic_tool_types,
)
from google.cloud.aiplatform_v1beta1.types.cached_content import (
from google.cloud.aiplatform_v1.types.cached_content import (
CachedContent as GapicCachedContent,
)
from google.cloud.aiplatform_v1beta1.services import (
from google.cloud.aiplatform_v1.services import (
gen_ai_cache_service,
)
from vertexai.generative_models import _function_calling_utils
from vertexai.preview import caching
from vertexai.caching import CachedContent


_TEST_PROJECT = "test-project"
Expand Down Expand Up @@ -649,17 +649,19 @@ def test_generative_model_constructor_model_name(
with pytest.raises(ValueError):
generative_models.GenerativeModel("foo/bar/models/gemini-pro")

@pytest.mark.parametrize(
"generative_models",
[generative_models, preview_generative_models],
)
def test_generative_model_from_cached_content(
self, mock_get_cached_content_fixture
self, generative_models: generative_models, mock_get_cached_content_fixture
):
project_location_prefix = (
f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/"
)
cached_content = caching.CachedContent(
"cached-content-id-in-from-cached-content-test"
)
cached_content = CachedContent("cached-content-id-in-from-cached-content-test")

model = preview_generative_models.GenerativeModel.from_cached_content(
model = generative_models.GenerativeModel.from_cached_content(
cached_content=cached_content
)

Expand All @@ -683,14 +685,18 @@ def test_generative_model_from_cached_content(
== "cached-content-id-in-from-cached-content-test"
)

@pytest.mark.parametrize(
"generative_models",
[generative_models, preview_generative_models],
)
def test_generative_model_from_cached_content_with_resource_name(
self, mock_get_cached_content_fixture
self, mock_get_cached_content_fixture, generative_models: generative_models
):
project_location_prefix = (
f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/"
)

model = preview_generative_models.GenerativeModel.from_cached_content(
model = generative_models.GenerativeModel.from_cached_content(
cached_content="cached-content-id-in-from-cached-content-test"
)

Expand Down Expand Up @@ -848,7 +854,7 @@ def test_generate_content(
assert response5.text

@mock.patch.object(
target=prediction_service.PredictionServiceClient,
target=prediction_service_v1.PredictionServiceClient,
attribute="generate_content",
new=lambda self, request: gapic_prediction_service_types.GenerateContentResponse(
candidates=[
Expand All @@ -870,11 +876,9 @@ def test_generate_content_with_cached_content(
self,
mock_get_cached_content_fixture,
):
cached_content = caching.CachedContent(
"cached-content-id-in-from-cached-content-test"
)
cached_content = CachedContent("cached-content-id-in-from-cached-content-test")

model = preview_generative_models.GenerativeModel.from_cached_content(
model = generative_models.GenerativeModel.from_cached_content(
cached_content=cached_content
)

Expand Down
25 changes: 25 additions & 0 deletions vertexai/caching/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Classes for working with context caching (CachedContent) in Vertex AI."""

# We just want to re-export certain classes
# pylint: disable=g-multiple-import,g-importing-member
from vertexai.caching._caching import (
    CachedContent,
)

__all__ = [
    "CachedContent",
]
37 changes: 30 additions & 7 deletions vertexai/caching/_caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@
from google.cloud.aiplatform.compat.types import (
cached_content_v1beta1 as gca_cached_content,
)
from google.cloud.aiplatform_v1beta1.services import gen_ai_cache_service
from google.cloud.aiplatform_v1.services import (
gen_ai_cache_service as gen_ai_cache_service_v1,
)
from google.cloud.aiplatform_v1beta1.types.cached_content import (
CachedContent as GapicCachedContent,
)
Expand All @@ -36,6 +38,7 @@
GetCachedContentRequest,
UpdateCachedContentRequest,
)
from google.cloud.aiplatform_v1 import types as types_v1
from vertexai.generative_models import _generative_models
from vertexai.generative_models._generative_models import (
Content,
Expand Down Expand Up @@ -89,7 +92,7 @@ def _prepare_create_request(
if ttl and expire_time:
raise ValueError("Only one of ttl and expire_time can be set.")

request = CreateCachedContentRequest(
request_v1beta1 = CreateCachedContentRequest(
parent=f"projects/{project}/locations/{location}",
cached_content=GapicCachedContent(
model=model_name,
Expand All @@ -102,11 +105,21 @@ def _prepare_create_request(
display_name=display_name,
),
)
return request
serialized_message_v1beta1 = type(request_v1beta1).serialize(request_v1beta1)
try:
request_v1 = types_v1.CreateCachedContentRequest.deserialize(
serialized_message_v1beta1
)
except Exception as ex:
raise ValueError(
"Failed to convert CreateCachedContentRequest from v1beta1 to v1:\n"
f"{serialized_message_v1beta1}"
) from ex
return request_v1


def _prepare_get_cached_content_request(name: str) -> GetCachedContentRequest:
return GetCachedContentRequest(name=name)
return types_v1.GetCachedContentRequest(name=name)


class CachedContent(aiplatform_base._VertexAiResourceNounPlus):
Expand All @@ -122,7 +135,7 @@ class CachedContent(aiplatform_base._VertexAiResourceNounPlus):
client_class = aiplatform_utils.GenAiCacheServiceClientWithOverride

_gen_ai_cache_service_client_value: Optional[
gen_ai_cache_service.GenAiCacheServiceClient
gen_ai_cache_service_v1.GenAiCacheServiceClient
] = None

def __init__(self, cached_content_name: str):
Expand Down Expand Up @@ -253,15 +266,25 @@ def update(
update_mask.append("expire_time")

update_mask = field_mask_pb2.FieldMask(paths=update_mask)
request = UpdateCachedContentRequest(
request_v1beta1 = UpdateCachedContentRequest(
cached_content=GapicCachedContent(
name=self.resource_name,
expire_time=expire_time,
ttl=ttl,
),
update_mask=update_mask,
)
self.api_client.update_cached_content(request)
serialized_message_v1beta1 = type(request_v1beta1).serialize(request_v1beta1)
try:
request_v1 = types_v1.UpdateCachedContentRequest.deserialize(
serialized_message_v1beta1
)
except Exception as ex:
raise ValueError(
"Failed to convert UpdateCachedContentRequest from v1beta1 to v1:\n"
f"{serialized_message_v1beta1}"
) from ex
self.api_client.update_cached_content(request_v1)

@property
def expire_time(self) -> datetime.datetime:
Expand Down
Loading