From 78d8845366fc2bf06292f9bf7c0e363f88ca2662 Mon Sep 17 00:00:00 2001 From: Leonid Kuligin Date: Mon, 26 Feb 2024 14:17:07 +0100 Subject: [PATCH] Added Gemma (#24) --- .../langchain_google_vertexai/__init__.py | 13 +- .../langchain_google_vertexai/_base.py | 287 ++++++++++++++++++ .../langchain_google_vertexai/chat_models.py | 6 +- .../langchain_google_vertexai/embeddings.py | 2 +- .../langchain_google_vertexai/gemma.py | 217 +++++++++++++ .../langchain_google_vertexai/llms.py | 11 +- .../langchain_google_vertexai/model_garden.py | 72 +++++ .../tests/integration_tests/test_gemma.py | 93 ++++++ .../tests/integration_tests/test_llms.py | 85 +----- .../integration_tests/test_model_garden.py | 89 ++++++ .../vertexai/tests/unit_tests/test_imports.py | 4 + 11 files changed, 785 insertions(+), 94 deletions(-) create mode 100644 libs/vertexai/langchain_google_vertexai/_base.py create mode 100644 libs/vertexai/langchain_google_vertexai/gemma.py create mode 100644 libs/vertexai/langchain_google_vertexai/model_garden.py create mode 100644 libs/vertexai/tests/integration_tests/test_gemma.py create mode 100644 libs/vertexai/tests/integration_tests/test_model_garden.py diff --git a/libs/vertexai/langchain_google_vertexai/__init__.py b/libs/vertexai/langchain_google_vertexai/__init__.py index be365bde..9f4b90ee 100644 --- a/libs/vertexai/langchain_google_vertexai/__init__.py +++ b/libs/vertexai/langchain_google_vertexai/__init__.py @@ -3,10 +3,21 @@ from langchain_google_vertexai.chat_models import ChatVertexAI from langchain_google_vertexai.embeddings import VertexAIEmbeddings from langchain_google_vertexai.functions_utils import PydanticFunctionsOutputParser -from langchain_google_vertexai.llms import VertexAI, VertexAIModelGarden +from langchain_google_vertexai.gemma import ( + GemmaChatLocalKaggle, + GemmaChatVertexAIModelGarden, + GemmaLocalKaggle, + GemmaVertexAIModelGarden, +) +from langchain_google_vertexai.llms import VertexAI +from langchain_google_vertexai.model_garden import VertexAIModelGarden __all__ = [ "ChatVertexAI", + "GemmaVertexAIModelGarden", + "GemmaChatVertexAIModelGarden", + "GemmaLocalKaggle", + "GemmaChatLocalKaggle", "VertexAIEmbeddings", "VertexAI", "VertexAIModelGarden", diff --git a/libs/vertexai/langchain_google_vertexai/_base.py b/libs/vertexai/langchain_google_vertexai/_base.py new file mode 100644 index 00000000..1b2a9846 --- /dev/null +++ b/libs/vertexai/langchain_google_vertexai/_base.py @@ -0,0 +1,287 @@ +from __future__ import annotations + +from concurrent.futures import Executor +from typing import Any, ClassVar, Dict, List, Optional + +import vertexai # type: ignore[import-untyped] +from google.api_core.client_options import ClientOptions +from google.cloud.aiplatform.gapic import ( + PredictionServiceAsyncClient, + PredictionServiceClient, +) +from google.cloud.aiplatform.models import Prediction +from google.protobuf import json_format +from google.protobuf.struct_pb2 import Value +from langchain_core.outputs import Generation, LLMResult +from langchain_core.pydantic_v1 import BaseModel, Field, root_validator +from vertexai.language_models import ( # type: ignore[import-untyped] + TextGenerationModel, +) +from vertexai.preview.language_models import ( # type: ignore[import-untyped] + ChatModel as PreviewChatModel, +) +from vertexai.preview.language_models import ( + CodeChatModel as PreviewCodeChatModel, +) + +from langchain_google_vertexai._enums import HarmBlockThreshold, HarmCategory +from langchain_google_vertexai._utils import ( + get_client_info, + is_codey_model, + is_gemini_model, +) + +_PALM_DEFAULT_MAX_OUTPUT_TOKENS = TextGenerationModel._DEFAULT_MAX_OUTPUT_TOKENS +_PALM_DEFAULT_TEMPERATURE = 0.0 +_PALM_DEFAULT_TOP_P = 0.95 +_PALM_DEFAULT_TOP_K = 40 + + +class _VertexAIBase(BaseModel): + client: Any = None #: :meta private: + project: Optional[str] = None + "The default GCP project to use when making Vertex API calls." + location: str = "us-central1" + "The default location to use when making API calls." + request_parallelism: int = 5 + "The amount of parallelism allowed for requests issued to VertexAI models. " + "Default is 5." + max_retries: int = 6 + """The maximum number of retries to make when generating.""" + task_executor: ClassVar[Optional[Executor]] = Field(default=None, exclude=True) + stop: Optional[List[str]] = None + "Optional list of stop words to use when generating." + model_name: Optional[str] = None + "Underlying model name." + + +class _VertexAICommon(_VertexAIBase): + client_preview: Any = None #: :meta private: + model_name: str + "Underlying model name." + temperature: Optional[float] = None + "Sampling temperature, it controls the degree of randomness in token selection." + max_output_tokens: Optional[int] = None + "Token limit determines the maximum amount of text output from one prompt." + top_p: Optional[float] = None + "Tokens are selected from most probable to least until the sum of their " + "probabilities equals the top-p value. Top-p is ignored for Codey models." + top_k: Optional[int] = None + "How the model selects tokens for output, the next token is selected from " + "among the top-k most probable tokens. Top-k is ignored for Codey models." + credentials: Any = Field(default=None, exclude=True) + "The default custom credentials (google.auth.credentials.Credentials) to use " + "when making API calls. If not provided, credentials will be ascertained from " + "the environment." + n: int = 1 + """How many completions to generate for each prompt.""" + streaming: bool = False + """Whether to stream the results or not.""" + safety_settings: Optional[Dict[HarmCategory, HarmBlockThreshold]] = None + """The default safety settings to use for all generations. + + For example: + + from langchain_google_vertexai import HarmBlockThreshold, HarmCategory + + safety_settings = { + HarmCategory.HARM_CATEGORY_UNSPECIFIED: HarmBlockThreshold.BLOCK_NONE, + HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, + HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_ONLY_HIGH, + HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, + HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE, + } + """ # noqa: E501 + + @property + def _llm_type(self) -> str: + return "vertexai" + + @property + def is_codey_model(self) -> bool: + return is_codey_model(self.model_name) + + @property + def _is_gemini_model(self) -> bool: + return is_gemini_model(self.model_name) + + @property + def _identifying_params(self) -> Dict[str, Any]: + """Gets the identifying parameters.""" + return {**{"model_name": self.model_name}, **self._default_params} + + @property + def _default_params(self) -> Dict[str, Any]: + if self._is_gemini_model: + default_params = {} + else: + default_params = { + "temperature": _PALM_DEFAULT_TEMPERATURE, + "max_output_tokens": _PALM_DEFAULT_MAX_OUTPUT_TOKENS, + "top_p": _PALM_DEFAULT_TOP_P, + "top_k": _PALM_DEFAULT_TOP_K, + } + params = { + "temperature": self.temperature, + "max_output_tokens": self.max_output_tokens, + "candidate_count": self.n, + } + if not self.is_codey_model: + params.update( + { + "top_k": self.top_k, + "top_p": self.top_p, + } + ) + updated_params = {} + for param_name, param_value in params.items(): + default_value = default_params.get(param_name) + if param_value or default_value: + updated_params[param_name] = ( + param_value if param_value else default_value + ) + return updated_params + + @classmethod + def _init_vertexai(cls, values: Dict) -> None: + vertexai.init( + project=values.get("project"), + location=values.get("location"), + credentials=values.get("credentials"), + ) + return None + + def _prepare_params( + self, + stop: Optional[List[str]] = None, + stream: bool = False, + **kwargs: Any, + ) -> dict: + stop_sequences = stop or self.stop + params_mapping = {"n": "candidate_count"} + params = {params_mapping.get(k, k): v for k, v in kwargs.items()} + params = {**self._default_params, "stop_sequences": stop_sequences, **params} + if stream or self.streaming: + params.pop("candidate_count") + return params + + def get_num_tokens(self, text: str) -> int: + """Get the number of tokens present in the text. + + Useful for checking if an input will fit in a model's context window. + + Args: + text: The string input to tokenize. + + Returns: + The integer number of tokens in the text. + """ + is_palm_chat_model = isinstance( + self.client_preview, PreviewChatModel + ) or isinstance(self.client_preview, PreviewCodeChatModel) + if is_palm_chat_model: + result = self.client_preview.start_chat().count_tokens(text) + else: + result = self.client_preview.count_tokens([text]) + + return result.total_tokens + + +class _BaseVertexAIModelGarden(_VertexAIBase): + """Large language models served from Vertex AI Model Garden.""" + + async_client: Any = None #: :meta private: + endpoint_id: str + "A name of an endpoint where the model has been deployed." + allowed_model_args: Optional[List[str]] = None + "Allowed optional args to be passed to the model." + prompt_arg: str = "prompt" + result_arg: Optional[str] = "generated_text" + "Set result_arg to None if output of the model is expected to be a string." + "Otherwise, if it's a dict, provided an argument that contains the result." + single_example_per_request: bool = True + "LLM endpoint currently serves only the first example in the request" + + @root_validator() + def validate_environment(cls, values: Dict) -> Dict: + """Validate that the python package exists in environment.""" + + if not values["project"]: + raise ValueError( + "A GCP project should be provided to run inference on Model Garden!" + ) + + client_options = ClientOptions( + api_endpoint=f"{values['location']}-aiplatform.googleapis.com" + ) + client_info = get_client_info(module="vertex-ai-model-garden") + values["client"] = PredictionServiceClient( + client_options=client_options, client_info=client_info + ) + values["async_client"] = PredictionServiceAsyncClient( + client_options=client_options, client_info=client_info + ) + return values + + @property + def endpoint_path(self) -> str: + return self.client.endpoint_path( + project=self.project, location=self.location, endpoint=self.endpoint_id + ) + + @property + def _llm_type(self) -> str: + return "vertexai_model_garden" + + def _prepare_request(self, prompts: List[str], **kwargs: Any) -> List["Value"]: + instances = [] + for prompt in prompts: + if self.allowed_model_args: + instance = { + k: v for k, v in kwargs.items() if k in self.allowed_model_args + } + else: + instance = {} + instance[self.prompt_arg] = prompt + instances.append(instance) + + predict_instances = [ + json_format.ParseDict(instance_dict, Value()) for instance_dict in instances + ] + return predict_instances + + def _parse_response(self, predictions: "Prediction") -> LLMResult: + generations: List[List[Generation]] = [] + for result in predictions.predictions: + if isinstance(result, str): + generations.append([Generation(text=self._parse_prediction(result))]) + else: + generations.append( + [ + Generation(text=self._parse_prediction(prediction)) + for prediction in result + ] + ) + return LLMResult(generations=generations) + + def _parse_prediction(self, prediction: Any) -> str: + if isinstance(prediction, str): + return prediction + + if self.result_arg: + try: + return prediction[self.result_arg] + except KeyError: + if isinstance(prediction, str): + error_desc = ( + "Provided non-None `result_arg` (result_arg=" + f"{self.result_arg}). But got prediction of type " + f"{type(prediction)} instead of dict. Most probably, you" + "need to set `result_arg=None` during VertexAIModelGarden " + "initialization." + ) + raise ValueError(error_desc) + else: + raise ValueError(f"{self.result_arg} key not found in prediction!") + + return prediction diff --git a/libs/vertexai/langchain_google_vertexai/chat_models.py b/libs/vertexai/langchain_google_vertexai/chat_models.py index 3f7f42d8..c7129a6d 100644 --- a/libs/vertexai/langchain_google_vertexai/chat_models.py +++ b/libs/vertexai/langchain_google_vertexai/chat_models.py @@ -49,6 +49,9 @@ CodeChatModel as PreviewCodeChatModel, ) +from langchain_google_vertexai._base import ( + _VertexAICommon, +) from langchain_google_vertexai._image_utils import ImageBytesLoader from langchain_google_vertexai._utils import ( get_generation_info, @@ -58,9 +61,6 @@ from langchain_google_vertexai.functions_utils import ( _format_tools_to_vertex_tool, ) -from langchain_google_vertexai.llms import ( - _VertexAICommon, -) logger = logging.getLogger(__name__) diff --git a/libs/vertexai/langchain_google_vertexai/embeddings.py b/libs/vertexai/langchain_google_vertexai/embeddings.py index 041c21fb..f7a9b97f 100644 --- a/libs/vertexai/langchain_google_vertexai/embeddings.py +++ b/libs/vertexai/langchain_google_vertexai/embeddings.py @@ -20,7 +20,7 @@ TextEmbeddingModel, ) -from langchain_google_vertexai.llms import _VertexAICommon +from langchain_google_vertexai._base import _VertexAICommon logger = logging.getLogger(__name__) diff --git a/libs/vertexai/langchain_google_vertexai/gemma.py b/libs/vertexai/langchain_google_vertexai/gemma.py new file mode 100644 index 00000000..5b90d0bc --- /dev/null +++ b/libs/vertexai/langchain_google_vertexai/gemma.py @@ -0,0 +1,217 @@ +import os +from typing import Any, Dict, List, Optional, cast + +from langchain_core.callbacks import ( + AsyncCallbackManagerForLLMRun, + CallbackManagerForLLMRun, +) +from langchain_core.language_models.chat_models import ( + BaseChatModel, +) +from langchain_core.language_models.llms import BaseLLM +from langchain_core.messages import ( + AIMessage, + BaseMessage, + HumanMessage, + SystemMessage, +) +from langchain_core.outputs import ( + ChatGeneration, + ChatResult, + Generation, + LLMResult, +) +from langchain_core.pydantic_v1 import BaseModel, root_validator + +from langchain_google_vertexai._base import _BaseVertexAIModelGarden +from langchain_google_vertexai.model_garden import VertexAIModelGarden + +USER_CHAT_TEMPLATE = "user\n{prompt}\n" +MODEL_CHAT_TEMPLATE = "model\n{prompt}\n" + + +def gemma_messages_to_prompt(history: List[BaseMessage]) -> str: + """Converts a list of messages to a chat prompt for Gemma.""" + messages: List[str] = [] + if len(messages) == 1: + content = cast(str, history[0].content) + if isinstance(history[0], SystemMessage): + raise ValueError("Gemma currently doesn't support system message!") + return content + for message in history: + content = cast(str, message.content) + if isinstance(message, SystemMessage): + raise ValueError("Gemma currently doesn't support system message!") + elif isinstance(message, AIMessage): + messages.append(MODEL_CHAT_TEMPLATE.format(prompt=content)) + elif isinstance(message, HumanMessage): + messages.append(USER_CHAT_TEMPLATE.format(prompt=content)) + else: + raise ValueError(f"Unexpected message with type {type(message)}") + messages.append("model\n") + return "".join(messages) + + +class _GemmaBase(BaseModel): + max_tokens: Optional[int] = None + """The maximum number of tokens to generate.""" + temperature: Optional[float] = None + """The temperature to use for sampling.""" + top_p: Optional[float] = None + """The top-p value to use for sampling.""" + top_k: Optional[int] = None + """The top-k value to use for sampling.""" + + @property + def _default_params(self) -> Dict[str, Any]: + """Get the default parameters for calling gemma.""" + params = { + "max_tokens": self.max_tokens, + "temperature": self.temperature, + "top_p": self.top_p, + "top_k": self.top_k, + } + return {k: v for k, v in params.items() if v is not None} + + def _get_params(self, **kwargs) -> Dict[str, Any]: + return {k: kwargs.get(k, v) for k, v in self._default_params.items()} + + +class GemmaVertexAIModelGarden(VertexAIModelGarden): + allowed_model_args: Optional[List[str]] = [ + "temperature", + "top_p", + "top_k", + "max_tokens", + ] + + @property + def _llm_type(self) -> str: + return "gemma_vertexai_model_garden" + + +class GemmaChatVertexAIModelGarden(_GemmaBase, _BaseVertexAIModelGarden, BaseChatModel): + allowed_model_args: Optional[List[str]] = [ + "temperature", + "top_p", + "top_k", + "max_tokens", + ] + + @property + def _llm_type(self) -> str: + return "gemma_vertexai_model_garden" + + @property + def _default_params(self) -> Dict[str, Any]: + """Get the default parameters for calling gemma.""" + params = {"max_length": self.max_tokens} + return {k: v for k, v in params.items() if v is not None} + + def _generate( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> ChatResult: + request = self._get_params(**kwargs) + request["prompt"] = gemma_messages_to_prompt(messages) + output = self.client.predict(endpoint=self.endpoint_path, instances=[request]) + generations = [ + ChatGeneration( + message=AIMessage(content=output.predictions[0]), + ) + ] + return ChatResult(generations=generations) + + async def _agenerate( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> ChatResult: + """Top Level call""" + request = self._get_params(**kwargs) + request["prompt"] = gemma_messages_to_prompt(messages) + output = await self.async_client.predict_( + endpoint=self.endpoint_path, instances=[request] + ) + generations = [ + ChatGeneration( + message=AIMessage(content=output.predictions[0]), + ) + ] + return ChatResult(generations=generations) + + +class _GemmaLocalKaggleBase(_GemmaBase): + """Local gemma model.""" + + client: Any = None #: :meta private: + keras_backend: str = "jax" + model_name: str = "gemma_2b_en" + """Gemma model name.""" + + @root_validator() + def validate_environment(cls, values: Dict) -> Dict: + """Validate that llama-cpp-python library is installed.""" + try: + os.environ["KERAS_BACKEND"] = values["keras_backend"] + from keras_nlp.models import GemmaCausalLM # type: ignore + except ImportError: + raise ImportError( + "Could not import GemmaCausalLM library. " + "Please install the GemmaCausalLM library to " + "use this model: pip install keras-nlp keras>=3 kaggle" + ) + + values["client"] = GemmaCausalLM.from_preset(values["model_name"]) + return values + + @property + def _default_params(self) -> Dict[str, Any]: + """Get the default parameters for calling gemma.""" + params = {"max_length": self.max_tokens} + return {k: v for k, v in params.items() if v is not None} + + +class GemmaLocalKaggle(_GemmaLocalKaggleBase, BaseLLM): + def _generate( + self, + prompts: List[str], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> LLMResult: + """Run the LLM on the given prompt and input.""" + params = {"max_length": self.max_tokens} if self.max_tokens else {} + results = self.client.generate(prompts, **params) + results = results if isinstance(results, str) else [results] + return LLMResult(generations=[[Generation(text=result)] for result in results]) + + @property + def _llm_type(self) -> str: + """Return type of llm.""" + return "gemma_local_kaggle" + + +class GemmaChatLocalKaggle(_GemmaLocalKaggleBase, BaseChatModel): + def _generate( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> ChatResult: + params = {"max_length": self.max_tokens} if self.max_tokens else {} + prompt = gemma_messages_to_prompt(messages) + output = self.client.generate(prompt, **params) + generation = ChatGeneration(message=AIMessage(content=output)) + return ChatResult(generations=[generation]) + + @property + def _llm_type(self) -> str: + """Return type of llm.""" + return "gemma_local_chat_kaggle" diff --git a/libs/vertexai/langchain_google_vertexai/llms.py b/libs/vertexai/langchain_google_vertexai/llms.py index c236ab3b..9257269c 100644 --- a/libs/vertexai/langchain_google_vertexai/llms.py +++ b/libs/vertexai/langchain_google_vertexai/llms.py @@ -44,6 +44,12 @@ TextGenerationModel as PreviewTextGenerationModel, ) +from langchain_google_vertexai._base import ( + _PALM_DEFAULT_MAX_OUTPUT_TOKENS, + _PALM_DEFAULT_TEMPERATURE, + _PALM_DEFAULT_TOP_K, + _PALM_DEFAULT_TOP_P, +) from langchain_google_vertexai._enums import HarmBlockThreshold, HarmCategory from langchain_google_vertexai._utils import ( create_retry_decorator, @@ -53,11 +59,6 @@ is_gemini_model, ) -_PALM_DEFAULT_MAX_OUTPUT_TOKENS = TextGenerationModel._DEFAULT_MAX_OUTPUT_TOKENS -_PALM_DEFAULT_TEMPERATURE = 0.0 -_PALM_DEFAULT_TOP_P = 0.95 -_PALM_DEFAULT_TOP_K = 40 - def _completion_with_retry( llm: VertexAI, diff --git a/libs/vertexai/langchain_google_vertexai/model_garden.py b/libs/vertexai/langchain_google_vertexai/model_garden.py new file mode 100644 index 00000000..6469fb42 --- /dev/null +++ b/libs/vertexai/langchain_google_vertexai/model_garden.py @@ -0,0 +1,72 @@ +from __future__ import annotations + +import asyncio +from typing import Any, List, Optional + +from langchain_core.callbacks.manager import ( + AsyncCallbackManagerForLLMRun, + CallbackManagerForLLMRun, +) +from langchain_core.language_models.llms import BaseLLM +from langchain_core.outputs import Generation, LLMResult + +from langchain_google_vertexai._base import _BaseVertexAIModelGarden + + +class VertexAIModelGarden(_BaseVertexAIModelGarden, BaseLLM): + """Large language models served from Vertex AI Model Garden.""" + + def _generate( + self, + prompts: List[str], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> LLMResult: + """Run the LLM on the given prompt and input.""" + instances = self._prepare_request(prompts, **kwargs) + + if self.single_example_per_request and len(instances) > 1: + results = [] + for instance in instances: + response = self.client.predict( + endpoint=self.endpoint_path, instances=[instance] + ) + results.append(self._parse_prediction(response.predictions[0])) + return LLMResult( + generations=[[Generation(text=result)] for result in results] + ) + + response = self.client.predict(endpoint=self.endpoint_path, instances=instances) + return self._parse_response(response) + + async def _agenerate( + self, + prompts: List[str], + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> LLMResult: + """Run the LLM on the given prompt and input.""" + instances = self._prepare_request(prompts, **kwargs) + if self.single_example_per_request and len(instances) > 1: + responses = [] + for instance in instances: + responses.append( + self.async_client.predict( + endpoint=self.endpoint_path, instances=[instance] + ) + ) + + responses = await asyncio.gather(*responses) + return LLMResult( + generations=[ + [Generation(text=self._parse_prediction(response.predictions[0]))] + for response in responses + ] + ) + + response = await self.async_client.predict( + endpoint=self.endpoint_path, instances=instances + ) + return self._parse_response(response) diff --git a/libs/vertexai/tests/integration_tests/test_gemma.py b/libs/vertexai/tests/integration_tests/test_gemma.py new file mode 100644 index 00000000..de2d27ec --- /dev/null +++ b/libs/vertexai/tests/integration_tests/test_gemma.py @@ -0,0 +1,93 @@ +import os + +import pytest +from langchain_core.messages import ( + AIMessage, + HumanMessage, +) + +from langchain_google_vertexai import ( + GemmaChatLocalKaggle, + GemmaChatVertexAIModelGarden, + GemmaLocalKaggle, + GemmaVertexAIModelGarden, +) + + +@pytest.mark.skip("CI testing not set up") +@pytest.mark.skip("CI testing not set up") +def test_gemma_model_garden() -> None: + """In order to run this test, you should provide endpoint names. + + Example: + export GEMMA_ENDPOINT_ID=... + export PROJECT=... + """ + endpoint_id = os.environ["GEMMA_ENDPOINT_ID"] + project = os.environ["PROJECT"] + location = "us-central1" + llm = GemmaVertexAIModelGarden( + endpoint_id=endpoint_id, + project=project, + location=location, + ) + output = llm.invoke("What is the meaning of life?") + assert isinstance(output, str) + assert len(output) > 2 + assert llm._llm_type == "gemma_vertexai_model_garden" + + +@pytest.mark.skip("CI testing not set up") +def test_gemma_chat_model_garden() -> None: + """In order to run this test, you should provide endpoint names. + + Example: + export GEMMA_ENDPOINT_ID=... + export PROJECT=... + """ + endpoint_id = os.environ["GEMMA_ENDPOINT_ID"] + project = os.environ["PROJECT"] + location = "us-central1" + llm = GemmaChatVertexAIModelGarden( + endpoint_id=endpoint_id, + project=project, + location=location, + ) + assert llm._llm_type == "gemma_vertexai_model_garden" + + text_question1, text_answer1 = "How much is 2+2?", "4" + text_question2 = "How much is 3+3?" + message1 = HumanMessage(content=text_question1) + message2 = AIMessage(content=text_answer1) + message3 = HumanMessage(content=text_question2) + output = llm.invoke([message1]) + assert isinstance(output, AIMessage) + assert len(output.content) > 2 + output = llm.invoke([message1, message2, message3]) + assert isinstance(output, AIMessage) + assert len(output.content) > 2 + + +@pytest.mark.skip("CI testing not set up") +def test_gemma_kaggle() -> None: + llm = GemmaLocalKaggle(model_name="gemma_2b_en") + output = llm.invoke("What is the meaning of life?") + assert isinstance(output, str) + print(output) + assert len(output) > 2 + + +@pytest.mark.skip("CI testing not set up") +def test_gemma_chat_kaggle() -> None: + llm = GemmaChatLocalKaggle(model_name="gemma_2b_en") + text_question1, text_answer1 = "How much is 2+2?", "4" + text_question2 = "How much is 3+3?" + message1 = HumanMessage(content=text_question1) + message2 = AIMessage(content=text_answer1) + message3 = HumanMessage(content=text_question2) + output = llm.invoke([message1]) + assert isinstance(output, AIMessage) + assert len(output.content) > 2 + output = llm.invoke([message1, message2, message3]) + assert isinstance(output, AIMessage) + assert len(output.content) > 2 diff --git a/libs/vertexai/tests/integration_tests/test_llms.py b/libs/vertexai/tests/integration_tests/test_llms.py index 6b252f42..f9525c8b 100644 --- a/libs/vertexai/tests/integration_tests/test_llms.py +++ b/libs/vertexai/tests/integration_tests/test_llms.py @@ -3,13 +3,11 @@ Your end-user credentials would be used to make the calls (make sure you've run `gcloud auth login` first). """ -import os -from typing import Optional import pytest from langchain_core.outputs import LLMResult -from langchain_google_vertexai.llms import VertexAI, VertexAIModelGarden +from langchain_google_vertexai.llms import VertexAI model_names_to_test = ["text-bison@001", "gemini-pro"] model_names_to_test_with_default = [None] + model_names_to_test @@ -119,87 +117,6 @@ async def test_astream() -> None: assert isinstance(token, str) -@pytest.mark.skip("CI testing not set up") -@pytest.mark.parametrize( - "endpoint_os_variable_name,result_arg", - [("FALCON_ENDPOINT_ID", "generated_text"), ("LLAMA_ENDPOINT_ID", None)], -) -def test_model_garden( - endpoint_os_variable_name: str, result_arg: Optional[str] -) -> None: - """In order to run this test, you should provide endpoint names. - - Example: - export FALCON_ENDPOINT_ID=... - export LLAMA_ENDPOINT_ID=... - export PROJECT=... - """ - endpoint_id = os.environ[endpoint_os_variable_name] - project = os.environ["PROJECT"] - location = "europe-west4" - llm = VertexAIModelGarden( - endpoint_id=endpoint_id, - project=project, - result_arg=result_arg, - location=location, - ) - output = llm("What is the meaning of life?") - assert isinstance(output, str) - assert llm._llm_type == "vertexai_model_garden" - - -@pytest.mark.skip("CI testing not set up") -@pytest.mark.parametrize( - "endpoint_os_variable_name,result_arg", - [("FALCON_ENDPOINT_ID", "generated_text"), ("LLAMA_ENDPOINT_ID", None)], -) -def test_model_garden_generate( - endpoint_os_variable_name: str, result_arg: Optional[str] -) -> None: - """In order to run this test, you should provide endpoint names. - - Example: - export FALCON_ENDPOINT_ID=... - export LLAMA_ENDPOINT_ID=... - export PROJECT=... - """ - endpoint_id = os.environ[endpoint_os_variable_name] - project = os.environ["PROJECT"] - location = "europe-west4" - llm = VertexAIModelGarden( - endpoint_id=endpoint_id, - project=project, - result_arg=result_arg, - location=location, - ) - output = llm.generate(["What is the meaning of life?", "How much is 2+2"]) - assert isinstance(output, LLMResult) - assert len(output.generations) == 2 - - -@pytest.mark.skip("CI testing not set up") -@pytest.mark.asyncio -@pytest.mark.parametrize( - "endpoint_os_variable_name,result_arg", - [("FALCON_ENDPOINT_ID", "generated_text"), ("LLAMA_ENDPOINT_ID", None)], -) -async def test_model_garden_agenerate( - endpoint_os_variable_name: str, result_arg: Optional[str] -) -> None: - endpoint_id = os.environ[endpoint_os_variable_name] - project = os.environ["PROJECT"] - location = "europe-west4" - llm = VertexAIModelGarden( - endpoint_id=endpoint_id, - project=project, - result_arg=result_arg, - location=location, - ) - output = await llm.agenerate(["What is the meaning of life?", "How much is 2+2"]) - assert isinstance(output, LLMResult) - assert len(output.generations) == 2 - - @pytest.mark.parametrize( "model_name", model_names_to_test, diff --git a/libs/vertexai/tests/integration_tests/test_model_garden.py b/libs/vertexai/tests/integration_tests/test_model_garden.py new file mode 100644 index 00000000..5500692e --- /dev/null +++ b/libs/vertexai/tests/integration_tests/test_model_garden.py @@ -0,0 +1,89 @@ +import os +from typing import Optional + +import pytest +from langchain_core.outputs import LLMResult + +from langchain_google_vertexai import VertexAIModelGarden + + +@pytest.mark.skip("CI testing not set up") +@pytest.mark.parametrize( + "endpoint_os_variable_name,result_arg", + [("FALCON_ENDPOINT_ID", "generated_text"), ("LLAMA_ENDPOINT_ID", None)], +) +def test_model_garden( + endpoint_os_variable_name: str, result_arg: Optional[str] +) -> None: + """In order to run this test, you should provide endpoint names. + + Example: + export FALCON_ENDPOINT_ID=... + export LLAMA_ENDPOINT_ID=... + export PROJECT=... + """ + endpoint_id = os.environ[endpoint_os_variable_name] + project = os.environ["PROJECT"] + location = "europe-west4" + llm = VertexAIModelGarden( + endpoint_id=endpoint_id, + project=project, + result_arg=result_arg, + location=location, + ) + output = llm("What is the meaning of life?") + assert isinstance(output, str) + print(output) + assert llm._llm_type == "vertexai_model_garden" + + +@pytest.mark.skip("CI testing not set up") +@pytest.mark.parametrize( + "endpoint_os_variable_name,result_arg", + [("FALCON_ENDPOINT_ID", "generated_text"), ("LLAMA_ENDPOINT_ID", None)], +) +def test_model_garden_generate( + endpoint_os_variable_name: str, result_arg: Optional[str] +) -> None: + """In order to run this test, you should provide endpoint names. + + Example: + export FALCON_ENDPOINT_ID=... + export LLAMA_ENDPOINT_ID=... + export PROJECT=... + """ + endpoint_id = os.environ[endpoint_os_variable_name] + project = os.environ["PROJECT"] + location = "europe-west4" + llm = VertexAIModelGarden( + endpoint_id=endpoint_id, + project=project, + result_arg=result_arg, + location=location, + ) + output = llm.generate(["What is the meaning of life?", "How much is 2+2"]) + assert isinstance(output, LLMResult) + assert len(output.generations) == 2 + + +@pytest.mark.skip("CI testing not set up") +@pytest.mark.asyncio +@pytest.mark.parametrize( + "endpoint_os_variable_name,result_arg", + [("FALCON_ENDPOINT_ID", "generated_text"), ("LLAMA_ENDPOINT_ID", None)], +) +async def test_model_garden_agenerate( + endpoint_os_variable_name: str, result_arg: Optional[str] +) -> None: + endpoint_id = os.environ[endpoint_os_variable_name] + project = os.environ["PROJECT"] + location = "europe-west4" + llm = VertexAIModelGarden( + endpoint_id=endpoint_id, + project=project, + result_arg=result_arg, + location=location, + ) + output = await llm.agenerate(["What is the meaning of life?", "How much is 2+2"]) + assert isinstance(output, LLMResult) + assert len(output.generations) == 2 diff --git a/libs/vertexai/tests/unit_tests/test_imports.py b/libs/vertexai/tests/unit_tests/test_imports.py index 7afa74f1..e207e7e6 100644 --- a/libs/vertexai/tests/unit_tests/test_imports.py +++ b/libs/vertexai/tests/unit_tests/test_imports.py @@ -2,6 +2,10 @@ EXPECTED_ALL = [ "ChatVertexAI", + "GemmaVertexAIModelGarden", + "GemmaChatVertexAIModelGarden", + "GemmaLocalKaggle", + "GemmaChatLocalKaggle", "VertexAIEmbeddings", "VertexAI", "VertexAIModelGarden",