diff --git a/libs/vertexai/langchain_google_vertexai/_base.py b/libs/vertexai/langchain_google_vertexai/_base.py
index 495a6a512..168c901b7 100644
--- a/libs/vertexai/langchain_google_vertexai/_base.py
+++ b/libs/vertexai/langchain_google_vertexai/_base.py
@@ -143,9 +143,9 @@ def _default_params(self) -> Dict[str, Any]:
         updated_params = {}
         for param_name, param_value in params.items():
             default_value = default_params.get(param_name)
-            if param_value or default_value:
+            if param_value is not None or default_value is not None:
                 updated_params[param_name] = (
-                    param_value if param_value else default_value
+                    param_value if param_value is not None else default_value
                 )
         return updated_params
diff --git a/libs/vertexai/langchain_google_vertexai/llms.py b/libs/vertexai/langchain_google_vertexai/llms.py
index 7a873a843..c616ce755 100644
--- a/libs/vertexai/langchain_google_vertexai/llms.py
+++ b/libs/vertexai/langchain_google_vertexai/llms.py
@@ -1,9 +1,7 @@
 from __future__ import annotations

-from concurrent.futures import Executor
-from typing import Any, AsyncIterator, ClassVar, Dict, Iterator, List, Optional, Union
+from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union

-import vertexai  # type: ignore[import-untyped]
 from google.cloud.aiplatform import telemetry
 from langchain_core.callbacks.manager import (
     AsyncCallbackManagerForLLMRun,
@@ -11,7 +9,7 @@
 )
 from langchain_core.language_models.llms import BaseLLM
 from langchain_core.outputs import Generation, GenerationChunk, LLMResult
-from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
+from langchain_core.pydantic_v1 import root_validator
 from vertexai.generative_models import (  # type: ignore[import-untyped]
     Candidate,
     GenerativeModel,
@@ -25,12 +23,6 @@
     TextGenerationResponse,
 )
 from vertexai.preview.language_models import (  # type: ignore[import-untyped]
-    ChatModel as PreviewChatModel,
-)
-from vertexai.preview.language_models import (
-    CodeChatModel as PreviewCodeChatModel,
-)
-from vertexai.preview.language_models import (
     CodeGenerationModel as PreviewCodeGenerationModel,
 )
 from vertexai.preview.language_models import (
@@ -38,16 +30,11 @@
 )

 from langchain_google_vertexai._base import (
-    _PALM_DEFAULT_MAX_OUTPUT_TOKENS,
-    _PALM_DEFAULT_TEMPERATURE,
-    _PALM_DEFAULT_TOP_K,
-    _PALM_DEFAULT_TOP_P,
+    _VertexAICommon,
 )
-from langchain_google_vertexai._enums import HarmBlockThreshold, HarmCategory
 from langchain_google_vertexai._utils import (
     create_retry_decorator,
     get_generation_info,
-    get_user_agent,
     is_codey_model,
     is_gemini_model,
 )
@@ -120,168 +107,6 @@ async def _acompletion_with_retry_inner(
     )


-class _VertexAIBase(BaseModel):
-    project: Optional[str] = None
-    "The default GCP project to use when making Vertex API calls."
-    location: str = "us-central1"
-    "The default location to use when making API calls."
-    request_parallelism: int = 5
-    "The amount of parallelism allowed for requests issued to VertexAI models. "
-    "Default is 5."
-    max_retries: int = 6
-    """The maximum number of retries to make when generating."""
-    task_executor: ClassVar[Optional[Executor]] = Field(default=None, exclude=True)
-    stop: Optional[List[str]] = None
-    "Optional list of stop words to use when generating."
-    model_name: Optional[str] = None
-    "Underlying model name."
-
-    @root_validator(pre=True)
-    def validate_params(cls, values: dict) -> dict:
-        if "model" in values and "model_name" not in values:
-            values["model_name"] = values.pop("model")
-        return values
-
-
-class _VertexAICommon(_VertexAIBase):
-    client: Any = None  #: :meta private:
-    client_preview: Any = None  #: :meta private:
-    model_name: str
-    "Underlying model name."
-    temperature: Optional[float] = None
-    "Sampling temperature, it controls the degree of randomness in token selection."
-    max_output_tokens: Optional[int] = None
-    "Token limit determines the maximum amount of text output from one prompt."
-    top_p: Optional[float] = None
-    "Tokens are selected from most probable to least until the sum of their "
-    "probabilities equals the top-p value. Top-p is ignored for Codey models."
-    top_k: Optional[int] = None
-    "How the model selects tokens for output, the next token is selected from "
-    "among the top-k most probable tokens. Top-k is ignored for Codey models."
-    credentials: Any = Field(default=None, exclude=True)
-    "The default custom credentials (google.auth.credentials.Credentials) to use "
-    "when making API calls. If not provided, credentials will be ascertained from "
-    "the environment."
-    n: int = 1
-    """How many completions to generate for each prompt."""
-    streaming: bool = False
-    """Whether to stream the results or not."""
-    safety_settings: Optional[Dict[HarmCategory, HarmBlockThreshold]] = None
-    """The default safety settings to use for all generations.
-
-    For example:
-
-        from langchain_google_vertexai import HarmBlockThreshold, HarmCategory
-
-        safety_settings = {
-            HarmCategory.HARM_CATEGORY_UNSPECIFIED: HarmBlockThreshold.BLOCK_NONE,
-            HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
-            HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_ONLY_HIGH,
-            HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
-            HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
-        }
-    """  # noqa: E501
-
-    @property
-    def _llm_type(self) -> str:
-        return "vertexai"
-
-    @property
-    def is_codey_model(self) -> bool:
-        return is_codey_model(self.model_name)
-
-    @property
-    def _is_gemini_model(self) -> bool:
-        return is_gemini_model(self.model_name)
-
-    @property
-    def _identifying_params(self) -> Dict[str, Any]:
-        """Gets the identifying parameters."""
-        return {**{"model_name": self.model_name}, **self._default_params}
-
-    @property
-    def _default_params(self) -> Dict[str, Any]:
-        if self._is_gemini_model:
-            default_params = {}
-        else:
-            default_params = {
-                "temperature": _PALM_DEFAULT_TEMPERATURE,
-                "max_output_tokens": _PALM_DEFAULT_MAX_OUTPUT_TOKENS,
-                "top_p": _PALM_DEFAULT_TOP_P,
-                "top_k": _PALM_DEFAULT_TOP_K,
-            }
-        params = {
-            "temperature": self.temperature,
-            "max_output_tokens": self.max_output_tokens,
-            "candidate_count": self.n,
-        }
-        if not self.is_codey_model:
-            params.update(
-                {
-                    "top_k": self.top_k,
-                    "top_p": self.top_p,
-                }
-            )
-        updated_params = {}
-        for param_name, param_value in params.items():
-            default_value = default_params.get(param_name)
-            if param_value is not None or default_value is not None:
-                updated_params[param_name] = (
-                    param_value if param_value is not None else default_value
-                )
-        return updated_params
-
-    @property
-    def _user_agent(self) -> str:
-        """Gets the User Agent."""
-        _, user_agent = get_user_agent(f"{type(self).__name__}_{self.model_name}")
-        return user_agent
-
-    @classmethod
-    def _init_vertexai(cls, values: Dict) -> None:
-        vertexai.init(
-            project=values.get("project"),
-            location=values.get("location"),
-            credentials=values.get("credentials"),
-        )
-        return None
-
-    def _prepare_params(
-        self,
-        stop: Optional[List[str]] = None,
-        stream: bool = False,
-        **kwargs: Any,
-    ) -> dict:
-        stop_sequences = stop or self.stop
-        params_mapping = {"n": "candidate_count"}
-        params = {params_mapping.get(k, k): v for k, v in kwargs.items()}
-        params = {**self._default_params, "stop_sequences": stop_sequences, **params}
-        if stream or self.streaming:
-            params.pop("candidate_count")
-        return params
-
-    def get_num_tokens(self, text: str) -> int:
-        """Get the number of tokens present in the text.
-
-        Useful for checking if an input will fit in a model's context window.
-
-        Args:
-            text: The string input to tokenize.
-
-        Returns:
-            The integer number of tokens in the text.
-        """
-        is_palm_chat_model = isinstance(
-            self.client_preview, PreviewChatModel
-        ) or isinstance(self.client_preview, PreviewCodeChatModel)
-        if is_palm_chat_model:
-            result = self.client_preview.start_chat().count_tokens(text)
-        else:
-            result = self.client_preview.count_tokens([text])
-
-        return result.total_tokens
-
-
 class VertexAI(_VertexAICommon, BaseLLM):
     """Google Vertex AI large language models."""
diff --git a/libs/vertexai/tests/unit_tests/test_llm.py b/libs/vertexai/tests/unit_tests/test_llm.py
index a0d0ac88d..72976d1bf 100644
--- a/libs/vertexai/tests/unit_tests/test_llm.py
+++ b/libs/vertexai/tests/unit_tests/test_llm.py
@@ -6,12 +6,12 @@


 def test_model_name() -> None:
-    llm = VertexAI()
+    llm = VertexAI(project="test-project")
     assert llm.model_name == "text-bison"
     for llm in [
-        VertexAI(model_name="text-bison@001"),
-        VertexAI(model="text-bison@001"),  # type: ignore[call-arg]
+        VertexAI(model_name="text-bison@001", project="test-project"),
+        VertexAI(model="text-bison@001", project="test-project"),  # type: ignore[call-arg]
     ]:
         assert llm.model_name == "text-bison@001"
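
The heart of the `_base.py` hunk is replacing truthiness tests with explicit `is not None` checks when merging caller-supplied parameters with defaults. Under the old logic, an explicitly set falsy value such as `temperature=0.0` or `top_k=0` was silently swapped for the PaLM default, and on Gemini models (which have no defaults) it was dropped from the request altogether. A minimal standalone sketch of the corrected merge, mirroring the loop in the diff (the `merge_params` helper name and the default values below are illustrative, not library API):

from typing import Any, Dict, Optional


def merge_params(
    params: Dict[str, Optional[Any]], defaults: Dict[str, Any]
) -> Dict[str, Any]:
    # Fall back to a default only when the supplied value is None,
    # so falsy-but-meaningful values like 0.0 or 0 survive the merge.
    merged: Dict[str, Any] = {}
    for name, value in params.items():
        default = defaults.get(name)
        if value is not None or default is not None:
            merged[name] = value if value is not None else default
    return merged


# PaLM-style models: an explicit temperature of 0.0 is now kept rather
# than being replaced by the (illustrative) 0.7 default.
print(merge_params({"temperature": 0.0, "top_k": None}, {"temperature": 0.7, "top_k": 40}))
# {'temperature': 0.0, 'top_k': 40}

# Gemini-style models with no defaults: 0.0 used to vanish from the
# request entirely; with `is not None` it is passed through.
print(merge_params({"temperature": 0.0, "top_k": None}, {}))
# {'temperature': 0.0}

The unit-test change follows from the refactor: `_VertexAIBase` and `_VertexAICommon` now live in `_base.py`, and the tests pass an explicit `project`, presumably so model construction no longer depends on an ambient GCP project being configured.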