diff --git a/libs/vertexai/langchain_google_vertexai/__init__.py b/libs/vertexai/langchain_google_vertexai/__init__.py index ef27c7e2..bd057448 100644 --- a/libs/vertexai/langchain_google_vertexai/__init__.py +++ b/libs/vertexai/langchain_google_vertexai/__init__.py @@ -31,6 +31,7 @@ ) from langchain_google_vertexai.llms import VertexAI from langchain_google_vertexai.model_garden import VertexAIModelGarden +from langchain_google_vertexai.model_garden_maas import get_vertex_maas_model from langchain_google_vertexai.utils import create_context_cache from langchain_google_vertexai.vectorstores import ( DataStoreDocumentStorage, @@ -81,4 +82,5 @@ "VertexPairWiseStringEvaluator", "VertexStringEvaluator", "create_context_cache", + "get_vertex_maas_model", ] diff --git a/libs/vertexai/langchain_google_vertexai/model_garden_maas/__init__.py b/libs/vertexai/langchain_google_vertexai/model_garden_maas/__init__.py index e69de29b..3d586626 100644 --- a/libs/vertexai/langchain_google_vertexai/model_garden_maas/__init__.py +++ b/libs/vertexai/langchain_google_vertexai/model_garden_maas/__init__.py @@ -0,0 +1,24 @@ +from langchain_google_vertexai.model_garden_maas.llama import VertexModelGardenLlama + +_MISTRAL_MODELS = [ + "mistral-nemo@2407", + "mistral-large@2407", +] +_LLAMA_MODELS = ["meta/llama3-405b-instruct-maas"] +_MAAS_MODELS = _MISTRAL_MODELS + _LLAMA_MODELS + + +def get_vertex_maas_model(model_name, **kwargs): + """Return a corresponding Vertex MaaS instance. + + A factory method based on model's name. + """ + if model_name not in _MAAS_MODELS: + raise ValueError(f"model name {model_name} is not supported!") + if model_name in _MISTRAL_MODELS: + from langchain_google_vertexai.model_garden_maas.mistral import ( # noqa: F401 + VertexModelGardenMistral, + ) + + return VertexModelGardenMistral(model=model_name, **kwargs) + return VertexModelGardenLlama(model=model_name, **kwargs) diff --git a/libs/vertexai/langchain_google_vertexai/model_garden_maas/_base.py b/libs/vertexai/langchain_google_vertexai/model_garden_maas/_base.py index 134a8619..030b1c69 100644 --- a/libs/vertexai/langchain_google_vertexai/model_garden_maas/_base.py +++ b/libs/vertexai/langchain_google_vertexai/model_garden_maas/_base.py @@ -10,11 +10,11 @@ Union, ) -import httpx # type: ignore[unused-ignore, import-not-found] +import httpx from google import auth from google.auth.credentials import Credentials from google.auth.transport import requests as auth_requests -from httpx_sse import ( # type: ignore[import-not-found] +from httpx_sse import ( EventSource, aconnect_sse, connect_sse, @@ -99,6 +99,7 @@ class _BaseVertexMaasModelGarden(_VertexAIBase): append_tools_to_system_message: bool = False "Whether to append tools to the system message or not." 
model_family: Optional[VertexMaaSModelFamily] = None + timeout: int = 120 class Config: """Configuration for this pydantic object.""" @@ -106,6 +107,28 @@ class Config: allow_population_by_field_name = True arbitrary_types_allowed = True + def __init__(self, **kwargs): + super().__init__(**kwargs) + token = _get_token(credentials=self.credentials) + endpoint = self.get_url() + headers = { + "Content-Type": "application/json", + "Accept": "application/json", + "Authorization": f"Bearer {token}", + "x-goog-api-client": self._library_version, + "user_agent": self._user_agent, + } + self.client = httpx.Client( + base_url=endpoint, + headers=headers, + timeout=self.timeout, + ) + self.async_client = httpx.AsyncClient( + base_url=endpoint, + headers=headers, + timeout=self.timeout, + ) + @root_validator(pre=True) def validate_environment_model_garden(cls, values: Dict) -> Dict: """Validate that the python package exists in environment.""" @@ -132,7 +155,7 @@ def _get_url_part(self, stream: bool = False) -> str: ":streamRawPredict" ) return f"publishers/mistralai/models/{self.full_model_name}:rawPredict" - return "openapi/chat/completions" + return "endpoints/openapi/chat/completions" def get_url(self) -> str: if self.model_family == VertexMaaSModelFamily.LLAMA: @@ -173,12 +196,17 @@ async def _completion_with_retry(**kwargs: Any) -> Any: kwargs["stream"] = False stream = kwargs["stream"] if stream: + # Llama and Mistral expect different "Content-Type" for streaming + headers = {"Accept": "text/event-stream"} + if headers_content_type := kwargs.pop("headers_content_type", None): + headers["Content-Type"] = headers_content_type + event_source = aconnect_sse( llm.async_client, "POST", llm._get_url_part(stream=True), json=kwargs, - headers={"Accept": "text/event-stream"}, + headers=headers, ) return _aiter_sse(event_source) else: @@ -197,6 +225,10 @@ def completion_with_retry(llm: _BaseVertexMaasModelGarden, **kwargs): kwargs = llm._enrich_params(kwargs) if stream: + # Llama and Mistral expect different "Content-Type" for streaming + headers = {"Accept": "text/event-stream"} + if headers_content_type := kwargs.pop("headers_content_type", None): + headers["Content-Type"] = headers_content_type def iter_sse(): with connect_sse( @@ -204,7 +236,7 @@ def iter_sse(): "POST", llm._get_url_part(stream=True), json=kwargs, - headers={"Accept": "text/event-stream"}, + headers=headers, ) as event_source: _raise_on_error(event_source.response) for event in event_source.iter_sse(): diff --git a/libs/vertexai/langchain_google_vertexai/model_garden_maas/llama.py b/libs/vertexai/langchain_google_vertexai/model_garden_maas/llama.py new file mode 100644 index 00000000..d80c4dd7 --- /dev/null +++ b/libs/vertexai/langchain_google_vertexai/model_garden_maas/llama.py @@ -0,0 +1,487 @@ +from __future__ import annotations + +import json +import uuid +from typing import ( + Any, + AsyncIterator, + Callable, + Dict, + Iterator, + List, + Literal, + Optional, + Sequence, + Type, + Union, + cast, + overload, +) + +from langchain_core.callbacks.manager import ( + AsyncCallbackManagerForLLMRun, + CallbackManagerForLLMRun, +) +from langchain_core.language_models import LanguageModelInput +from langchain_core.language_models.chat_models import ( + BaseChatModel, + agenerate_from_stream, + generate_from_stream, +) +from langchain_core.messages import ( + AIMessage, + AIMessageChunk, + BaseMessage, + HumanMessage, + SystemMessage, + ToolMessage, +) +from langchain_core.messages.tool import tool_call as create_tool_call +from 
langchain_core.messages.tool import tool_call_chunk +from langchain_core.outputs import ( + ChatGeneration, + ChatGenerationChunk, + ChatResult, +) +from langchain_core.runnables import Runnable +from langchain_core.tools import BaseTool +from langchain_core.utils.function_calling import ( + convert_to_openai_function, +) + +from langchain_google_vertexai.model_garden_maas._base import ( + _BaseVertexMaasModelGarden, + acompletion_with_retry, + completion_with_retry, +) + + +@overload +def _parse_response_candidate_llama( + response_candidate: Dict[str, str], streaming: Literal[False] = False +) -> AIMessage: + ... + + +@overload +def _parse_response_candidate_llama( + response_candidate: Dict[str, str], streaming: Literal[True] +) -> AIMessageChunk: + ... + + +def _parse_response_candidate_llama( + response_candidate: Dict[str, str], streaming: bool = False +) -> AIMessage: + content = response_candidate["content"] + role = response_candidate["role"] + if role != "assistant": + raise ValueError(f"Role in response is {role}, expected 'assistant'!") + tool_calls = [] + tool_call_chunks = [] + + response_json = None + try: + response_json = json.loads(response_candidate["content"]) + except ValueError: + pass + if response_json and "name" in response_json: + function_name = response_json["name"] + function_args = response_json.get("parameters", None) + if streaming: + tool_call_chunks.append( + tool_call_chunk( + name=function_name, args=function_args, id=str(uuid.uuid4()) + ) + ) + else: + tool_calls.append( + create_tool_call( + name=function_name, args=function_args, id=str(uuid.uuid4()) + ) + ) + content = "" + + if streaming: + return AIMessageChunk( + content=content, + tool_call_chunks=tool_call_chunks, + ) + + return AIMessage( + content=content, + tool_calls=tool_calls, + ) + + +class VertexModelGardenLlama(_BaseVertexMaasModelGarden, BaseChatModel): # type: ignore[misc] + """Integration for Llama 3.1 on Google Cloud Vertex AI Model-as-a-Service. + + For more information, see: + https://cloud.google.com/blog/products/ai-machine-learning/llama-3-1-on-vertex-ai + + Setup: + You need to enable a corresponding MaaS model (Google Cloud UI console -> + Vertex AI -> Model Garden -> search for a model you need and click enable) + + You must have the langchain-google-vertexai Python package installed + .. code-block:: bash + + pip install -U langchain-google-vertexai + + And either: + - Have credentials configured for your environment + (gcloud, workload identity, etc...) + - Store the path to a service account JSON file as the + GOOGLE_APPLICATION_CREDENTIALS environment variable + + This codebase uses the google.auth library which first looks for the application + credentials variable mentioned above, and then looks for system-level auth. + + For more information, see: + https://cloud.google.com/docs/authentication/application-default-credentials#GAC + and https://googleapis.dev/python/google-auth/latest/reference/google.auth.html#module-google.auth. + + Key init args — completion params: + model: str + Name of VertexMaaS model to use ("meta/llama3-405b-instruct-maas") + append_tools_to_system_message: bool + Whether to append tools to a system message + + + Key init args — client params: + credentials: Optional[google.auth.credentials.Credentials] + The default custom credentials to use when making API calls. If not + provided, credentials will be ascertained from the environment. + project: Optional[str] + The default GCP project to use when making Vertex API calls. 
+ location: str = "us-central1" + The default location to use when making API calls. + + See full list of supported init args and their descriptions in the params section. + + Instantiate: + .. code-block:: python + + from langchain_google_vertexai.model_garden_maas.llama import VertexModelGardenLlama + + llm = VertexModelGardenLlama( + model="meta/llama3-405b-instruct-maas", + # other params... + ) + + Invoke: + .. code-block:: python + + messages = [ + ("system", "You are a helpful translator. Translate the user sentence to French."), + ("human", "I love programming."), + ] + llm.invoke(messages) + + .. code-block:: python + + AIMessage(content="J'adore programmer. \n", id='run-925ce305-2268-44c4-875f-dde9128520ad-0') + + Stream: + .. code-block:: python + + for chunk in llm.stream(messages): + print(chunk) + + .. code-block:: python + + AIMessageChunk(content='J', id='run-9df01d73-84d9-42db-9d6b-b1466a019e89') + AIMessageChunk(content="'adore programmer. \n", id='run-9df01d73-84d9-42db-9d6b-b1466a019e89') + AIMessageChunk(content='', id='run-9df01d73-84d9-42db-9d6b-b1466a019e89') + + .. code-block:: python + + stream = llm.stream(messages) + full = next(stream) + for chunk in stream: + full += chunk + full + + .. code-block:: python + + AIMessageChunk(content="J'adore programmer. \n", id='run-b7f7492c-4cb5-42d0-8fc3-dce9b293b0fb') + + """ # noqa: E501 + + def _convert_messages( + self, messages: List[BaseMessage], tools: Optional[List[BaseTool]] = None + ) -> List[Dict[str, Any]]: + converted_messages: List[Dict[str, Any]] = [] + if tools and not self.append_tools_to_system_message: + raise ValueError( + "If providing tools, either format the system message yourself or " + "set append_tools_to_system_message to True!" + ) + elif tools: + tools_str = "\n".join( + [json.dumps(convert_to_openai_function(t)) for t in tools] + ) + formatted_system_message = ( + "You are an assistant with access to the following tools:\n\n" + f"{tools_str}\n\n" + "If you decide to use a tool, please respond with a JSON for a " + "function call with its proper arguments that best answers the " + "given prompt.\nRespond in the format " + '{"name": function name, "parameters": dictionary ' + "of argument name and its value}. Do not use variables.\n" + "Do not provide any additional comments when calling a tool.\n" + "Do not mention tools to the user when preparing the final answer."
+ ) + message = messages[0] + if not isinstance(message, SystemMessage): + converted_messages.append( + {"role": "system", "content": formatted_system_message} + ) + else: + converted_messages.append( + { + "role": "system", + "content": str(message.content) + + "\n" + + formatted_system_message, + } + ) + + for i, message in enumerate(messages): + if tools and isinstance(message, SystemMessage) and i == 0: + continue + if isinstance(message, AIMessage): + converted_messages.append( + {"role": "assistant", "content": message.content} + ) + elif isinstance(message, HumanMessage): + converted_messages.append({"role": "user", "content": message.content}) + elif isinstance(message, SystemMessage): + converted_messages.append( + {"role": "system", "content": message.content} + ) + elif isinstance(message, ToolMessage): + # we also need to format a previous message if we got a tool result + prev_message = messages[i - 1] + if not isinstance(prev_message, AIMessage): + raise ValueError("ToolMessage should follow AIMessage only!") + _ = converted_messages[-1].pop("content", None) + tool_calls = [] + for tool_call in prev_message.tool_calls: + tool_calls.append( + { + "type": "function", + "id": tool_call["id"], + "function": { + "name": tool_call["name"], + "arguments": json.dumps(tool_call.get("args", {})), + }, + } + ) + converted_messages[-1]["tool_calls"] = tool_calls + if len(tool_calls) > 1: + raise ValueError( + "Only a single function call per turn is supported!" + ) + converted_messages.append( + { + "role": "tool", + "name": message.name, + "content": message.content, + "tool_call_id": message.tool_call_id, + } + ) + else: + raise ValueError(f"Message type {type(message)} is not yet supported!") + return converted_messages + + def _generate( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + stream: Optional[bool] = None, + *, + tools: Optional[List[BaseTool]] = None, + **kwargs: Any, + ) -> ChatResult: + """Generate the next turn in the conversation. + + Args: + messages: The history of the conversation as a list of messages. + stop: The list of stop words (optional). + run_manager: The CallbackManager for the LLM run; not used at the moment. + stream: Whether to use the streaming endpoint. + + Returns: + The ChatResult that contains outputs generated by the model. + + Raises: + ValueError: if messages cannot be converted into the request format expected by the model.
+ """ + if stream is True: + return generate_from_stream( + self._stream( + messages, + stop=stop, + run_manager=run_manager, + tools=tools, + **kwargs, + ) + ) + + converted_messages = self._convert_messages(messages, tools=tools) + + response = completion_with_retry(self, messages=converted_messages, **kwargs) + return self._create_chat_result(response) + + async def _agenerate( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + stream: Optional[bool] = None, + *, + tools: Optional[List[BaseTool]] = None, + **kwargs: Any, + ) -> ChatResult: + if stream: + stream_iter = self._astream( + messages=messages, stop=stop, run_manager=run_manager, tools=tools, **kwargs + ) + return await agenerate_from_stream(stream_iter) + + converted_messages = self._convert_messages(messages, tools=tools) + response = await acompletion_with_retry( + self, messages=converted_messages, run_manager=run_manager, **kwargs + ) + return self._create_chat_result(response) + + def _create_chat_result(self, response: Dict) -> ChatResult: + generations = [] + token_usage = response.get("usage", {}) + for candidate in response["choices"]: + finish_reason = candidate.get("finish_reason") + message = _parse_response_candidate_llama(candidate["message"]) + if token_usage and isinstance(message, AIMessage): + message.usage_metadata = { + "input_tokens": token_usage.get("prompt_tokens", 0), + "output_tokens": token_usage.get("completion_tokens", 0), + "total_tokens": token_usage.get("total_tokens", 0), + } + gen = ChatGeneration( + message=message, + generation_info={"finish_reason": finish_reason}, + ) + generations.append(gen) + + llm_output = {"token_usage": token_usage, "model": self.model_name} + return ChatResult(generations=generations, llm_output=llm_output) + + @property + def _llm_type(self) -> str: + """Return type of chat model.""" + return "vertexai_model_garden_maas_llama" + + def _parse_chunk(self, chunk: Dict) -> AIMessageChunk: + chunk_delta = chunk["choices"][0]["delta"] + content = chunk_delta.get("content", "") + if chunk_delta.get("role") != "assistant": + raise ValueError(f"Got chunk with non-assistant role: {chunk_delta}") + additional_kwargs = {} + if raw_tool_calls := chunk_delta.get("tool_calls"): + additional_kwargs["tool_calls"] = raw_tool_calls + try: + tool_call_chunks = [] + for raw_tool_call in raw_tool_calls: + if not raw_tool_call.get("index") and not raw_tool_call.get("id"): + tool_call_id = str(uuid.uuid4()) + else: + tool_call_id = raw_tool_call.get("id") + tool_call_chunks.append( + tool_call_chunk( + name=raw_tool_call["function"].get("name"), + args=raw_tool_call["function"].get("arguments"), + id=tool_call_id, + index=raw_tool_call.get("index"), + ) + ) + except KeyError: + pass + else: + tool_call_chunks = [] + if token_usage := chunk.get("usage"): + usage_metadata = { + "input_tokens": token_usage.get("prompt_tokens", 0), + "output_tokens": token_usage.get("completion_tokens", 0), + "total_tokens": token_usage.get("total_tokens", 0), + } + else: + usage_metadata = None + return AIMessageChunk( + content=content, + additional_kwargs=additional_kwargs, + tool_call_chunks=tool_call_chunks, + usage_metadata=usage_metadata, # type: ignore[arg-type] + ) + + def _stream( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + *, + tools: Optional[List[BaseTool]] = None, + **kwargs: Any, + ) -> Iterator[ChatGenerationChunk]: +
converted_messages = self._convert_messages(messages, tools=tools) + params = {**kwargs, "stream": True, "headers_content_type": "text/event-stream"} + + for chunk in completion_with_retry( + self, messages=converted_messages, run_manager=run_manager, **params + ): + if len(chunk["choices"]) == 0: + continue + message = self._parse_chunk(chunk) + gen_chunk = ChatGenerationChunk(message=message) + if run_manager: + run_manager.on_llm_new_token( + token=cast(str, message.content), chunk=gen_chunk + ) + yield gen_chunk + + async def _astream( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + *, + tools: Optional[List[BaseTool]] = None, + **kwargs: Any, + ) -> AsyncIterator[ChatGenerationChunk]: + converted_messages = self._convert_messages(messages, tools=tools) + params = {**kwargs, "stream": True, "headers_content_type": "text/event-stream"} + + async for chunk in await acompletion_with_retry( + self, messages=converted_messages, run_manager=run_manager, **params + ): + if len(chunk["choices"]) == 0: + continue + message = self._parse_chunk(chunk) + gen_chunk = ChatGenerationChunk(message=message) + if run_manager: + await run_manager.on_llm_new_token( + token=cast(str, message.content), chunk=gen_chunk + ) + yield gen_chunk + + def bind_tools( + self, + tools: Sequence[Union[Dict[str, Any], Type, Callable, BaseTool]], + **kwargs: Any, + ) -> Runnable[LanguageModelInput, BaseMessage]: + """Bind tool-like objects to this chat model.""" + formatted_tools = [convert_to_openai_function(tool) for tool in tools] + return super().bind(tools=formatted_tools, **kwargs) diff --git a/libs/vertexai/langchain_google_vertexai/model_garden_maas/mistral.py b/libs/vertexai/langchain_google_vertexai/model_garden_maas/mistral.py index 7f1626aa..8975864f 100644 --- a/libs/vertexai/langchain_google_vertexai/model_garden_maas/mistral.py +++ b/libs/vertexai/langchain_google_vertexai/model_garden_maas/mistral.py @@ -1,6 +1,5 @@ from typing import Any, Optional -import httpx # type: ignore[unused-ignore, import-not-found] from langchain_core.callbacks import ( CallbackManagerForLLMRun, ) @@ -10,7 +9,6 @@ from langchain_google_vertexai.model_garden_maas._base import ( _BaseVertexMaasModelGarden, - _get_token, acompletion_with_retry, completion_with_retry, ) @@ -19,33 +17,6 @@ class VertexModelGardenMistral(_BaseVertexMaasModelGarden, chat_models.ChatMistralAI): # type: ignore[unused-ignore, misc] - def __init__(self, **kwargs): - super().__init__(**kwargs) - token = _get_token(credentials=self.credentials) - self.endpoint = self.get_url() - self.client = httpx.Client( - base_url=self.endpoint, - headers={ - "Content-Type": "application/json", - "Accept": "application/json", - "Authorization": f"Bearer {token}", - "x-goog-api-client": self._library_version, - "user_agent": self._user_agent, - }, - timeout=self.timeout, - ) - self.async_client = httpx.AsyncClient( - base_url=self.endpoint, - headers={ - "Content-Type": "application/json", - "Accept": "application/json", - "Authorization": f"Bearer {token}", - "x-goog-api-client": self._library_version, - "user_agent": self._user_agent, - }, - timeout=self.timeout, - ) - def completion_with_retry( self, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any ) -> Any: diff --git a/libs/vertexai/poetry.lock b/libs/vertexai/poetry.lock index c6f6629f..5a63e1ec 100644 --- a/libs/vertexai/poetry.lock +++ b/libs/vertexai/poetry.lock @@ -2834,4 +2834,4 @@ mistral = 
["langchain-mistralai"] [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<4.0" -content-hash = "a7c0a3c26c217f9a9b30080a50e275c248ed82d1b10cb1f4a808e294d8c74276" +content-hash = "8240d23b8fc671345e140d2720298d549a40e05da220a901ee8b72a1b51bcb25" diff --git a/libs/vertexai/pyproject.toml b/libs/vertexai/pyproject.toml index 968cd6de..44bd06fd 100644 --- a/libs/vertexai/pyproject.toml +++ b/libs/vertexai/pyproject.toml @@ -18,6 +18,8 @@ google-cloud-storage = "^2.17.0" # optional dependencies anthropic = { extras = ["vertexai"], version = ">=0.30.0,<1", optional = true } langchain-mistralai = { version = ">=0.1.12,<1", optional = true } +httpx = "^0.27.0" +httpx-sse = "^0.4.0" [tool.poetry.group.test] optional = true @@ -46,6 +48,8 @@ langchain-standard-tests = {git = "https://github.com/langchain-ai/langchain.git ignore-words-list = "rouge" + + [tool.poetry.group.codespell] optional = true diff --git a/libs/vertexai/tests/integration_tests/test_maas.py b/libs/vertexai/tests/integration_tests/test_maas.py index 78b5c016..d5ecb2d9 100644 --- a/libs/vertexai/tests/integration_tests/test_maas.py +++ b/libs/vertexai/tests/integration_tests/test_maas.py @@ -11,20 +11,19 @@ ) from langchain_core.tools import tool -from langchain_google_vertexai.model_garden_maas.mistral import ( - VertexModelGardenMistral, +from langchain_google_vertexai.model_garden_maas import ( + _LLAMA_MODELS, + _MISTRAL_MODELS, + get_vertex_maas_model, ) -model_names = [ - "mistral-nemo@2407", - "mistral-large@2407", -] +model_names = _LLAMA_MODELS + _MISTRAL_MODELS @pytest.mark.extended @pytest.mark.parametrize("model_name", model_names) def test_generate(model_name: str) -> None: - llm = VertexModelGardenMistral(model=model_name, location="us-central1") + llm = get_vertex_maas_model(model_name=model_name, location="us-central1") output = llm.invoke("What is the meaning of life?") assert isinstance(output, AIMessage) print(output) @@ -33,7 +32,7 @@ def test_generate(model_name: str) -> None: @pytest.mark.extended @pytest.mark.parametrize("model_name", model_names) async def test_agenerate(model_name: str) -> None: - llm = VertexModelGardenMistral(model=model_name, location="us-central1") + llm = get_vertex_maas_model(model_name=model_name, location="us-central1") output = await llm.ainvoke("What is the meaning of life?") assert isinstance(output, AIMessage) print(output) @@ -42,7 +41,7 @@ async def test_agenerate(model_name: str) -> None: @pytest.mark.extended @pytest.mark.parametrize("model_name", model_names) def test_stream(model_name: str) -> None: - llm = VertexModelGardenMistral(model=model_name, location="us-central1") + llm = get_vertex_maas_model(model_name=model_name, location="us-central1") output = llm.stream("What is the meaning of life?") for chunk in output: assert isinstance(chunk, AIMessageChunk) @@ -51,7 +50,7 @@ def test_stream(model_name: str) -> None: @pytest.mark.extended @pytest.mark.parametrize("model_name", model_names) async def test_astream(model_name: str) -> None: - llm = VertexModelGardenMistral(model=model_name, location="us-central1") + llm = get_vertex_maas_model(model_name=model_name, location="us-central1") output = llm.astream("What is the meaning of life?") async for chunk in output: assert isinstance(chunk, AIMessageChunk) @@ -72,7 +71,11 @@ def search( tools = [search] - llm = VertexModelGardenMistral(model=model_name, location="us-central1") + llm = get_vertex_maas_model( + model_name=model_name, + location="us-central1", + append_tools_to_system_message=True, + ) 
llm_with_search = llm.bind_tools( tools=tools, ) @@ -86,7 +89,7 @@ def search( assert isinstance(response, AIMessage) tool_calls = response.tool_calls - assert len(tool_calls) == 1 + assert len(tool_calls) > 0 tool_response = search("sparrow") tool_messages: List[BaseMessage] = [] @@ -103,5 +106,6 @@ def search( result = llm_with_search.invoke([request, response] + tool_messages) assert isinstance(result, AIMessage) - assert "brown" in result.content + if model_name in _MISTRAL_MODELS: + assert "brown" in result.content assert len(result.tool_calls) == 0 diff --git a/libs/vertexai/tests/unit_tests/test_imports.py b/libs/vertexai/tests/unit_tests/test_imports.py index a0ea17f5..998c4463 100644 --- a/libs/vertexai/tests/unit_tests/test_imports.py +++ b/libs/vertexai/tests/unit_tests/test_imports.py @@ -34,6 +34,7 @@ "VertexPairWiseStringEvaluator", "VertexStringEvaluator", "create_context_cache", + "get_vertex_maas_model", ] diff --git a/libs/vertexai/tests/unit_tests/test_maas.py b/libs/vertexai/tests/unit_tests/test_maas.py new file mode 100644 index 00000000..2b27b277 --- /dev/null +++ b/libs/vertexai/tests/unit_tests/test_maas.py @@ -0,0 +1,143 @@ +import json +from typing import Any, Dict +from unittest.mock import ANY, MagicMock, patch + +from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage +from langchain_core.tools import tool +from langchain_core.utils.function_calling import convert_to_openai_function + +from langchain_google_vertexai.model_garden_maas import get_vertex_maas_model +from langchain_google_vertexai.model_garden_maas.llama import ( + _parse_response_candidate_llama, +) + + +@patch("langchain_google_vertexai.model_garden_maas._base.auth") +def test_llama_init(mock_auth: Any) -> None: + mock_credentials = MagicMock() + mock_credentials.token.return_value = "test-token" + mock_auth.default.return_value = (mock_credentials, None) + llm = get_vertex_maas_model( + model_name="meta/llama3-405b-instruct-maas", + location="moon-dark", + project="test-project", + ) + assert llm._llm_type == "vertexai_model_garden_maas_llama" + assert llm.model_name == "meta/llama3-405b-instruct-maas" + + assert ( + llm.get_url() + == "https://moon-dark-aiplatform.googleapis.com/v1beta1/projects/test-project/locations/moon-dark" + ) + assert llm._get_url_part() == "endpoints/openapi/chat/completions" + assert llm._get_url_part(stream=True) == "endpoints/openapi/chat/completions" + mock_credentials.refresh.assert_called_once() + + +@patch("langchain_google_vertexai.model_garden_maas._base.auth") +def test_parse_history(mock_auth: Any) -> None: + llm = get_vertex_maas_model( + model_name="meta/llama3-405b-instruct-maas", + location="us-central1", + project="test-project", + ) + history = [ + SystemMessage(content="You're a helpful assistant"), + HumanMessage(content="What is the capital of Great Britain?"), + AIMessage(content="London is a capital of Great Britain"), + ] + parsed_history = llm._convert_messages(history) + expected_parsed_history = [ + {"role": "system", "content": "You're a helpful assistant"}, + {"role": "user", "content": "What is the capital of Great Britain?"}, + {"role": "assistant", "content": "London is a capital of Great Britain"}, + ] + assert parsed_history == expected_parsed_history + + +@patch("langchain_google_vertexai.model_garden_maas._base.auth") +def test_parse_history_llama_tools(mock_auth: Any) -> None: + @tool + def get_weather(city: str) -> float: + """Get the current weather and temperature for a given city.""" + return 23.0 
+ + schema = convert_to_openai_function(get_weather) + + llm = get_vertex_maas_model( + model_name="meta/llama3-405b-instruct-maas", + location="us-central1", + project="test-project", + append_tools_to_system_message=True, + ) + history = [ + SystemMessage(content="You're a helpful assistant."), + HumanMessage(content="What is the weather in Munich?"), + ] + parsed_history = llm._convert_messages(history, tools=[get_weather]) + expected_parsed_history = [ + {"role": "system", "content": ANY}, + {"role": "user", "content": "What is the weather in Munich?"}, + ] + assert parsed_history == expected_parsed_history + assert json.dumps(schema) in parsed_history[0]["content"] + + history += [ + AIMessage( + content="", + tool_calls=[ + { + "name": "get_weather", + "args": {"city": "Munich"}, + "id": "1", + "type": "tool_call", + } + ], + ), + ToolMessage(content="32", name="get_weather", tool_call_id="1"), + ] + parsed_history = llm._convert_messages(history, tools=[get_weather]) + expected_parsed_history = [ + {"role": "system", "content": ANY}, + {"role": "user", "content": "What is the weather in Munich?"}, + { + "role": "assistant", + "tool_calls": [ + { + "type": "function", + "id": "1", + "function": { + "name": "get_weather", + "arguments": '{"city": "Munich"}', + }, + } + ], + }, + {"role": "tool", "name": "get_weather", "content": "32", "tool_call_id": "1"}, + ] + assert parsed_history == expected_parsed_history + + +def test_parse_response(): + candidate: Dict[str, str] = { + "content": "London is the capital of Great Britain", + "role": "assistant", + } + assert _parse_response_candidate_llama(candidate) == AIMessage( + content="London is the capital of Great Britain" + ) + candidate = { + "content": ('{"name": "test_tool", "parameters": {"arg1": "test", "arg2": 2}}'), + "role": "assistant", + } + parsed = _parse_response_candidate_llama(candidate) + assert isinstance(parsed, AIMessage) + assert parsed.content == "" + assert parsed.tool_calls == [ + { + "name": "test_tool", + "args": {"arg1": "test", "arg2": 2}, + "id": ANY, + "type": "tool_call", + } + ]
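Reviewer note: a minimal usage sketch of the `get_vertex_maas_model` factory introduced in this PR, assuming a GCP project with the corresponding MaaS models enabled and application-default credentials configured. The project ID, prompts, and the `search` tool below are placeholders rather than part of the change; model names and keyword arguments are taken from `_MAAS_MODELS` and the integration tests above.

```python
# Illustrative sketch only; "my-project" is a placeholder project ID.
from langchain_core.messages import HumanMessage, ToolMessage
from langchain_core.tools import tool

from langchain_google_vertexai import get_vertex_maas_model

# "meta/llama3-405b-instruct-maas" is routed to VertexModelGardenLlama;
# Mistral names (e.g. "mistral-large@2407") go to VertexModelGardenMistral.
llm = get_vertex_maas_model(
    model_name="meta/llama3-405b-instruct-maas",
    project="my-project",
    location="us-central1",
    # Required before binding tools to the Llama model, since tool schemas
    # are injected into the system message rather than sent as a request field.
    append_tools_to_system_message=True,
)

print(llm.invoke("What is the meaning of life?").content)

# Streaming goes through the SSE branch added in _base.py.
for chunk in llm.stream("What is the meaning of life?"):
    print(chunk.content, end="", flush=True)


# Tool-calling round trip, mirroring tests/integration_tests/test_maas.py.
@tool
def search(question: str) -> str:
    """Get additional information about a bird."""
    return "The sparrow is a small brown bird."


llm_with_search = llm.bind_tools(tools=[search])
request = HumanMessage(content="What color is a sparrow?")
response = llm_with_search.invoke([request])
tool_messages = [
    ToolMessage(
        content=search.invoke(call["args"]),
        name=call["name"],
        tool_call_id=call["id"],
    )
    for call in response.tool_calls
]
final = llm_with_search.invoke([request, response] + tool_messages)
print(final.content)
```

Unsupported model names raise a ValueError in `get_vertex_maas_model`, so the factory fails fast rather than constructing a client for an endpoint that does not exist.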