Skip to content

Commit

Permalink
langchain-google-vertexai: perserving grounding metadata (#16309)
Browse files Browse the repository at this point in the history
Revival of #14549 that
closes #14548.
  • Loading branch information
jamesbraza authored Jan 25, 2024
1 parent adc0084 commit 0785432
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 35 deletions.
54 changes: 29 additions & 25 deletions libs/partners/google-vertexai/langchain_google_vertexai/_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
"""Utilities to init Vertex AI."""

import dataclasses
from importlib import metadata
from typing import Any, Callable, Dict, Optional, Union

Expand All @@ -10,7 +12,13 @@
CallbackManagerForLLMRun,
)
from langchain_core.language_models.llms import create_base_retry_decorator
from vertexai.preview.generative_models import Image # type: ignore
from vertexai.generative_models._generative_models import ( # type: ignore[import-untyped]
Candidate,
)
from vertexai.language_models import ( # type: ignore[import-untyped]
TextGenerationResponse,
)
from vertexai.preview.generative_models import Image # type: ignore[import-untyped]


def create_retry_decorator(
Expand Down Expand Up @@ -88,27 +96,23 @@ def is_gemini_model(model_name: str) -> bool:
return model_name is not None and "gemini" in model_name


def get_generation_info(candidate: Any, is_gemini: bool) -> Optional[Dict[str, Any]]:
try:
if is_gemini:
# https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/gemini#response_body
return {
"is_blocked": any(
[rating.blocked for rating in candidate.safety_ratings]
),
"safety_ratings": [
{
"category": rating.category.name,
"probability_label": rating.probability.name,
}
for rating in candidate.safety_ratings
],
}
else:
# https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/text-chat#response_body
return {
"is_blocked": candidate.is_blocked,
"safety_attributes": candidate.safety_attributes,
}
except Exception:
return None
def get_generation_info(
candidate: Union[TextGenerationResponse, Candidate], is_gemini: bool
) -> Dict[str, Any]:
if is_gemini:
# https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/gemini#response_body
return {
"is_blocked": any([rating.blocked for rating in candidate.safety_ratings]),
"safety_ratings": [
{
"category": rating.category.name,
"probability_label": rating.probability.name,
}
for rating in candidate.safety_ratings
],
"citation_metadata": candidate.citation_metadata,
}
# https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/text-chat#response_body
candidate_dc = dataclasses.asdict(candidate)
candidate_dc.pop("text")
return {k: v for k, v in candidate_dc.items() if not k.startswith("_")}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""Test ChatGoogleVertexAI chat model."""
from typing import cast
from typing import Optional, cast

import pytest
from langchain_core.messages import (
Expand All @@ -16,7 +16,7 @@


@pytest.mark.parametrize("model_name", model_names_to_test)
def test_initialization(model_name: str) -> None:
def test_initialization(model_name: Optional[str]) -> None:
"""Test chat model initialization."""
if model_name:
model = ChatVertexAI(model_name=model_name)
Expand All @@ -30,7 +30,7 @@ def test_initialization(model_name: str) -> None:


@pytest.mark.parametrize("model_name", model_names_to_test)
def test_vertexai_single_call(model_name: str) -> None:
def test_vertexai_single_call(model_name: Optional[str]) -> None:
if model_name:
model = ChatVertexAI(model_name=model_name)
else:
Expand Down Expand Up @@ -164,7 +164,7 @@ def test_vertexai_single_call_with_examples() -> None:


@pytest.mark.parametrize("model_name", model_names_to_test)
def test_vertexai_single_call_with_history(model_name: str) -> None:
def test_vertexai_single_call_with_history(model_name: Optional[str]) -> None:
if model_name:
model = ChatVertexAI(model_name=model_name)
else:
Expand Down Expand Up @@ -203,7 +203,7 @@ def test_chat_vertexai_gemini_system_message_error(model_name: str) -> None:


@pytest.mark.parametrize("model_name", model_names_to_test)
def test_chat_vertexai_system_message(model_name: str) -> None:
def test_chat_vertexai_system_message(model_name: Optional[str]) -> None:
if model_name:
model = ChatVertexAI(
model_name=model_name, convert_system_message_to_human=True
Expand Down
34 changes: 29 additions & 5 deletions libs/partners/google-vertexai/tests/unit_tests/test_chat_models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Test chat model integration."""
from typing import Any, Dict, Optional

from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional
from unittest.mock import MagicMock, Mock, patch

import pytest
Expand Down Expand Up @@ -45,6 +47,13 @@ def test_parse_examples_failes_wrong_sequence() -> None:
)


@dataclass
class StubTextChatResponse:
"""Stub text-chat response from VertexAI for testing."""

text: str


@pytest.mark.parametrize("stop", [None, "stop1"])
def test_vertexai_args_passed(stop: Optional[str]) -> None:
response_text = "Goodbye"
Expand All @@ -59,7 +68,7 @@ def test_vertexai_args_passed(stop: Optional[str]) -> None:
# Mock the library to ensure the args are passed correctly
with patch("vertexai._model_garden._model_garden_models._from_pretrained") as mg:
mock_response = MagicMock()
mock_response.candidates = [Mock(text=response_text)]
mock_response.candidates = [StubTextChatResponse(text=response_text)]
mock_chat = MagicMock()
mock_send_message = MagicMock(return_value=mock_response)
mock_chat.send_message = mock_send_message
Expand Down Expand Up @@ -136,7 +145,7 @@ def test_default_params_palm() -> None:

with patch("vertexai._model_garden._model_garden_models._from_pretrained") as mg:
mock_response = MagicMock()
mock_response.candidates = [Mock(text="Goodbye")]
mock_response.candidates = [StubTextChatResponse(text="Goodbye")]
mock_chat = MagicMock()
mock_send_message = MagicMock(return_value=mock_response)
mock_chat.send_message = mock_send_message
Expand All @@ -159,13 +168,28 @@ def test_default_params_palm() -> None:
)


@dataclass
class StubGeminiResponse:
"""Stub gemini response from VertexAI for testing."""

text: str
content: Any
citation_metadata: Any
safety_ratings: List[Any] = field(default_factory=list)


def test_default_params_gemini() -> None:
user_prompt = "Hello"

with patch("langchain_google_vertexai.chat_models.GenerativeModel") as gm:
mock_response = MagicMock()
content = Mock(parts=[Mock(function_call=None)])
mock_response.candidates = [Mock(text="Goodbye", content=content)]
mock_response.candidates = [
StubGeminiResponse(
text="Goodbye",
content=Mock(parts=[Mock(function_call=None)]),
citation_metadata=Mock(),
)
]
mock_chat = MagicMock()
mock_send_message = MagicMock(return_value=mock_response)
mock_chat.send_message = mock_send_message
Expand Down

0 comments on commit 0785432

Please sign in to comment.