Commit
Merge pull request #11 from langchain-ai/move_vertexai
move recent changes
lkuligin authored Feb 13, 2024
2 parents 0bfce25 + 8be0f3b commit 62585bd
Showing 5 changed files with 178 additions and 19 deletions.
28 changes: 25 additions & 3 deletions libs/vertexai/langchain_google_vertexai/chat_models.py
@@ -9,6 +9,7 @@
from typing import Any, Dict, Iterator, List, Optional, Union, cast
from urllib.parse import urlparse

import proto # type: ignore[import-untyped]
import requests
from google.cloud.aiplatform_v1beta1.types.content import Part as GapicPart
from google.cloud.aiplatform_v1beta1.types.tool import FunctionCall
@@ -45,6 +46,12 @@
Image,
Part,
)
from vertexai.preview.language_models import ( # type: ignore
ChatModel as PreviewChatModel,
)
from vertexai.preview.language_models import (
CodeChatModel as PreviewCodeChatModel,
)

from langchain_google_vertexai._utils import (
get_generation_info,
@@ -272,10 +279,12 @@ def _parse_response_candidate(response_candidate: "Candidate") -> AIMessage:
first_part = response_candidate.content.parts[0]
if first_part.function_call:
function_call = {"name": first_part.function_call.name}

# dump to match other function calling llm for now
function_call_args_dict = proto.Message.to_dict(first_part.function_call)[
"args"
]
function_call["arguments"] = json.dumps(
-                {k: first_part.function_call.args[k] for k in first_part.function_call.args}
+                {k: function_call_args_dict[k] for k in function_call_args_dict}
)
additional_kwargs["function_call"] = function_call
return AIMessage(content=content, additional_kwargs=additional_kwargs)
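
The switch to proto.Message.to_dict matters because function_call.args is a
proto-plus map whose values are wrapper types that json.dumps cannot
serialize. A minimal sketch of the conversion, runnable without credentials;
the "Information" name and args payload are illustrative:

import json

import proto  # type: ignore[import-untyped]
from google.cloud.aiplatform_v1beta1.types.tool import FunctionCall

call = FunctionCall(name="Information", args={"info": [[1, 2, 3], [4, 5, 6]]})

# Indexing call.args directly yields proto container types; to_dict converts
# them recursively into plain dicts and lists that json.dumps accepts.
args_dict = proto.Message.to_dict(call)["args"]
print(json.dumps(args_dict))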
@@ -316,12 +325,20 @@ def validate_environment(cls, values: Dict) -> Dict:
values["client"] = GenerativeModel(
model_name=values["model_name"], safety_settings=safety_settings
)
values["client_preview"] = GenerativeModel(
model_name=values["model_name"], safety_settings=safety_settings
)
else:
if is_codey_model(values["model_name"]):
model_cls = CodeChatModel
model_cls_preview = PreviewCodeChatModel
else:
model_cls = ChatModel
model_cls_preview = PreviewChatModel
values["client"] = model_cls.from_pretrained(values["model_name"])
values["client_preview"] = model_cls_preview.from_pretrained(
values["model_name"]
)
return values

def _generate(
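
A hedged sketch of the pairing the validate_environment change establishes:
every GA client gets a matching preview client so preview-only APIs such as
count_tokens stay reachable. The "code" substring check stands in for
is_codey_model, whose exact implementation is an assumption here:

from vertexai.language_models import ChatModel, CodeChatModel  # type: ignore
from vertexai.preview.language_models import (  # type: ignore
    ChatModel as PreviewChatModel,
)
from vertexai.preview.language_models import (
    CodeChatModel as PreviewCodeChatModel,
)


def pick_model_classes(model_name: str) -> tuple:
    # Mirrors the branch above: codey models get the CodeChat pair, everything
    # else the plain Chat pair, keeping the GA and preview clients in sync.
    if "code" in model_name:  # assumed is_codey_model behavior
        return CodeChatModel, PreviewCodeChatModel
    return ChatModel, PreviewChatModel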
@@ -493,8 +510,13 @@ def _stream(
# set param to `functions` until core tool/function calling implemented
raw_tools = params.pop("functions") if "functions" in params else None
tools = _format_tools_to_vertex_tool(raw_tools) if raw_tools else None
safety_settings = params.pop("safety_settings", None)
responses = chat.send_message(
-            message, stream=True, generation_config=params, tools=tools
+            message,
+            stream=True,
+            generation_config=params,
+            safety_settings=safety_settings,
+            tools=tools,
)
for response in responses:
message = _parse_response_candidate(response.candidates[0])
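With safety_settings popped out of params and forwarded explicitly, a per-call
override now reaches chat.send_message during streaming instead of riding
along inside generation_config. A usage sketch, assuming application-default
credentials and that per-call kwargs flow through _prepare_params; the model
name and threshold choice are illustrative:

from vertexai.preview.generative_models import (  # type: ignore
    HarmBlockThreshold,
    HarmCategory,
)

from langchain_google_vertexai import ChatVertexAI

safety_settings = {
    HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
}
llm = ChatVertexAI(model_name="gemini-pro")
for chunk in llm.stream("Tell me a joke", safety_settings=safety_settings):
    print(chunk.content, end="")
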
41 changes: 27 additions & 14 deletions libs/vertexai/langchain_google_vertexai/llms.py
@@ -31,6 +31,12 @@
Image,
)
from vertexai.preview.language_models import ( # type: ignore[import-untyped]
ChatModel as PreviewChatModel,
)
from vertexai.preview.language_models import (
CodeChatModel as PreviewCodeChatModel,
)
from vertexai.preview.language_models import (
CodeGenerationModel as PreviewCodeGenerationModel,
)
from vertexai.preview.language_models import (
@@ -239,6 +245,27 @@ def _prepare_params(
params.pop("candidate_count")
return params

def get_num_tokens(self, text: str) -> int:
"""Get the number of tokens present in the text.
Useful for checking if an input will fit in a model's context window.
Args:
text: The string input to tokenize.
Returns:
The integer number of tokens in the text.
"""
is_palm_chat_model = isinstance(
self.client_preview, PreviewChatModel
) or isinstance(self.client_preview, PreviewCodeChatModel)
if is_palm_chat_model:
result = self.client_preview.start_chat().count_tokens(text)
else:
result = self.client_preview.count_tokens([text])

return result.total_tokens


class VertexAI(_VertexAICommon, BaseLLM):
"""Google Vertex AI large language models."""
@@ -300,20 +327,6 @@ def validate_environment(cls, values: Dict) -> Dict:
raise ValueError("Only one candidate can be generated with streaming!")
return values

def get_num_tokens(self, text: str) -> int:
"""Get the number of tokens present in the text.
Useful for checking if an input will fit in a model's context window.
Args:
text: The string input to tokenize.
Returns:
The integer number of tokens in the text.
"""
result = self.client_preview.count_tokens([text])
return result.total_tokens

def _response_to_generation(
self, response: TextGenerationResponse, *, stream: bool = False
) -> GenerationChunk:
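Net effect of the two llms.py hunks: get_num_tokens moves from VertexAI to the
shared base class, so chat models inherit it as well, and PaLM chat models,
which only expose token counting on a chat session, get the start_chat()
branch. A usage sketch, assuming default credentials; "text-bison" is an
illustrative model name:

from langchain_google_vertexai import VertexAI

llm = VertexAI(model_name="text-bison")
# Delegates to llm.client_preview.count_tokens under the hood.
print(llm.get_num_tokens("Hello"))
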
12 changes: 12 additions & 0 deletions libs/vertexai/tests/integration_tests/test_chat_models.py
@@ -225,6 +225,18 @@ def test_chat_vertexai_system_message(model_name: Optional[str]) -> None:
assert isinstance(response.content, str)


@pytest.mark.parametrize("model_name", model_names_to_test)
def test_get_num_tokens_from_messages(model_name: str) -> None:
if model_name:
model = ChatVertexAI(model_name=model_name, temperature=0.0)
else:
model = ChatVertexAI(temperature=0.0)
message = HumanMessage(content="Hello")
token = model.get_num_tokens_from_messages(messages=[message])
assert isinstance(token, int)
assert token == 3


def test_chat_vertexai_gemini_function_calling() -> None:
class MyModel(BaseModel):
name: str
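get_num_tokens_from_messages is inherited from LangChain's base chat model and
funnels into the shared get_num_tokens above; the hard-coded 3 pins the
current tokenizer's count for a single "Hello" message. A condensed equivalent
of the test, assuming default credentials:

from langchain_core.messages import HumanMessage

from langchain_google_vertexai import ChatVertexAI

model = ChatVertexAI(temperature=0.0)
num_tokens = model.get_num_tokens_from_messages(
    messages=[HumanMessage(content="Hello")]
)
assert isinstance(num_tokens, int)
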
2 changes: 0 additions & 2 deletions libs/vertexai/tests/integration_tests/test_tools.py
@@ -81,7 +81,6 @@ def test_tools() -> None:
agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)

response = agent_executor.invoke({"input": "What is 6 raised to the 0.43 power?"})
-    print(response)
assert isinstance(response, dict)
assert response["input"] == "What is 6 raised to the 0.43 power?"

@@ -106,7 +105,6 @@ def test_stream() -> None:
]
response = list(llm.stream("What is 6 raised to the 0.43 power?", functions=tools))
assert len(response) == 1
-    # for chunk in response:
assert isinstance(response[0], AIMessageChunk)
assert "function_call" in response[0].additional_kwargs

114 changes: 114 additions & 0 deletions libs/vertexai/tests/unit_tests/test_chat_models.py
@@ -1,22 +1,35 @@
"""Test chat model integration."""

import json
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional
from unittest.mock import MagicMock, Mock, patch

import pytest
from google.cloud.aiplatform_v1beta1.types import (
Content,
FunctionCall,
Part,
)
from google.cloud.aiplatform_v1beta1.types import (
content as gapic_content_types,
)
from langchain_core.messages import (
AIMessage,
HumanMessage,
SystemMessage,
)
from vertexai.language_models import ChatMessage, InputOutputTextPair # type: ignore
from vertexai.preview.generative_models import ( # type: ignore
Candidate,
)

from langchain_google_vertexai.chat_models import (
ChatVertexAI,
_parse_chat_history,
_parse_chat_history_gemini,
_parse_examples,
_parse_response_candidate,
)


@@ -202,3 +215,104 @@ def test_default_params_gemini() -> None:
message = HumanMessage(content=user_prompt)
_ = model([message])
mock_start_chat.assert_called_once_with(history=[])


@pytest.mark.parametrize(
"raw_candidate, expected",
[
(
gapic_content_types.Candidate(
content=Content(
role="model",
parts=[
Part(
function_call=FunctionCall(
name="Information",
args={"name": "Ben"},
),
)
],
)
),
{
"name": "Information",
"arguments": {"name": "Ben"},
},
),
(
gapic_content_types.Candidate(
content=Content(
role="model",
parts=[
Part(
function_call=FunctionCall(
name="Information",
args={"info": ["A", "B", "C"]},
),
)
],
)
),
{
"name": "Information",
"arguments": {"info": ["A", "B", "C"]},
},
),
(
gapic_content_types.Candidate(
content=Content(
role="model",
parts=[
Part(
function_call=FunctionCall(
name="Information",
args={
"people": [
{"name": "Joe", "age": 30},
{"name": "Martha"},
]
},
),
)
],
)
),
{
"name": "Information",
"arguments": {
"people": [
{"name": "Joe", "age": 30},
{"name": "Martha"},
]
},
},
),
(
gapic_content_types.Candidate(
content=Content(
role="model",
parts=[
Part(
function_call=FunctionCall(
name="Information",
args={"info": [[1, 2, 3], [4, 5, 6]]},
),
)
],
)
),
{
"name": "Information",
"arguments": {"info": [[1, 2, 3], [4, 5, 6]]},
},
),
],
)
def test_parse_response_candidate(raw_candidate, expected) -> None:
response_candidate = Candidate._from_gapic(raw_candidate)
result = _parse_response_candidate(response_candidate)
result_arguments = json.loads(
result.additional_kwargs["function_call"]["arguments"]
)

assert result_arguments == expected["arguments"]
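
These cases run without network access: Candidate._from_gapic wraps a
hand-built gapic candidate in the SDK type that _parse_response_candidate
expects, and each case checks that nested lists and dicts in FunctionCall.args
survive the to_dict + json.dumps round trip. A condensed single-case version,
reusing the imports at the top of this file:

candidate = Candidate._from_gapic(
    gapic_content_types.Candidate(
        content=Content(
            role="model",
            parts=[
                Part(
                    function_call=FunctionCall(
                        name="Information",
                        args={"name": "Ben"},
                    )
                )
            ],
        )
    )
)
message = _parse_response_candidate(candidate)
function_call = message.additional_kwargs["function_call"]
assert function_call["name"] == "Information"
assert json.loads(function_call["arguments"]) == {"name": "Ben"}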
