diff --git a/libs/partners/google-genai/langchain_google_genai/chat_models.py b/libs/partners/google-genai/langchain_google_genai/chat_models.py
index 5f38693ce5a5d..37cc5f342d689 100644
--- a/libs/partners/google-genai/langchain_google_genai/chat_models.py
+++ b/libs/partners/google-genai/langchain_google_genai/chat_models.py
@@ -42,7 +42,7 @@
     SystemMessage,
 )
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
-from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
+from langchain_core.pydantic_v1 import SecretStr, root_validator
 from langchain_core.utils import get_from_dict_or_env
 from tenacity import (
     before_sleep_log,
@@ -53,6 +53,7 @@
 )

 from langchain_google_genai._common import GoogleGenerativeAIError
+from langchain_google_genai.llms import GoogleModelFamily, _BaseGoogleGenerativeAI

 IMAGE_TYPES: Tuple = ()
 try:
@@ -417,7 +418,7 @@ def _response_to_result(
     return ChatResult(generations=generations, llm_output=llm_output)


-class ChatGoogleGenerativeAI(BaseChatModel):
+class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
     """`Google Generative AI` Chat models API.

     To use, you must have either:
@@ -435,53 +436,13 @@ class ChatGoogleGenerativeAI(BaseChatModel):

     """

-    model: str = Field(
-        ...,
-        description="""The name of the model to use.
-Supported examples:
-    - gemini-pro""",
-    )
-    max_output_tokens: int = Field(default=None, description="Max output tokens")
-
     client: Any  #: :meta private:
-    google_api_key: Optional[SecretStr] = None
-    temperature: Optional[float] = None
-    """Run inference with this temperature. Must by in the closed
-    interval [0.0, 1.0]."""
-    top_k: Optional[int] = None
-    """Decode using top-k sampling: consider the set of top_k most probable tokens.
-    Must be positive."""
-    top_p: Optional[float] = None
-    """The maximum cumulative probability of tokens to consider when sampling.
-
-    The model uses combined Top-k and nucleus sampling.
-
-    Tokens are sorted based on their assigned probabilities so
-    that only the most likely tokens are considered. Top-k
-    sampling directly limits the maximum number of tokens to
-    consider, while Nucleus sampling limits number of tokens
-    based on the cumulative probability.
-
-    Note: The default value varies by model, see the
-    `Model.top_p` attribute of the `Model` returned the
-    `genai.get_model` function.
-    """
-    n: int = Field(default=1, alias="candidate_count")
-    """Number of chat completions to generate for each prompt. Note that the API may
-    not return the full n completions if duplicates are generated."""
+
     convert_system_message_to_human: bool = False
     """Whether to merge any leading SystemMessage into the following HumanMessage.

     Gemini does not support system messages; any unsupported messages will raise an
     error."""
-    client_options: Optional[Dict] = Field(
-        None,
-        description="Client options to pass to the Google API client.",
-    )
-    transport: Optional[str] = Field(
-        None,
-        description="A string, one of: [`rest`, `grpc`, `grpc_asyncio`].",
-    )

     class Config:
         allow_population_by_field_name = True
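For readers of the hunk above: the shared fields (`model`, `temperature`, `top_k`, etc.) move to the `_BaseGoogleGenerativeAI` base class, while `convert_system_message_to_human` stays on the chat class because only the chat API has to deal with a system role. A minimal usage sketch, not part of this diff; it assumes `GOOGLE_API_KEY` is set in the environment and the prompt strings are illustrative only:

```python
from langchain_core.messages import HumanMessage, SystemMessage

from langchain_google_genai import ChatGoogleGenerativeAI

# Without the flag, a leading SystemMessage raises an error because the
# Gemini API has no system role; with it, the SystemMessage is merged
# into the HumanMessage that follows it.
llm = ChatGoogleGenerativeAI(
    model="gemini-pro",
    convert_system_message_to_human=True,
)
response = llm.invoke(
    [
        SystemMessage(content="Answer in exactly one word."),
        HumanMessage(content="What color is the sky on a clear day?"),
    ]
)
print(response.content)
```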
@@ -494,10 +455,6 @@ def lc_secrets(self) -> Dict[str, str]:
     @property
     def _llm_type(self) -> str:
         return "chat-google-generative-ai"

-    @property
-    def _is_geminiai(self) -> bool:
-        return self.model is not None and "gemini" in self.model
-
     @classmethod
     def is_lc_serializable(self) -> bool:
         return True
@@ -658,3 +615,23 @@ def _prepare_chat(
         message = history.pop()
         chat = self.client.start_chat(history=history)
         return params, chat, message
+
+    def get_num_tokens(self, text: str) -> int:
+        """Get the number of tokens present in the text.
+
+        Useful for checking if an input will fit in a model's context window.
+
+        Args:
+            text: The string input to tokenize.
+
+        Returns:
+            The integer number of tokens in the text.
+        """
+        if self._model_family == GoogleModelFamily.GEMINI:
+            result = self.client.count_tokens(text)
+            token_count = result.total_tokens
+        else:
+            result = self.client.count_text_tokens(model=self.model, prompt=text)
+            token_count = result["token_count"]
+
+        return token_count
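The override added above gives the chat model the same family-dispatched token counting as the LLM class: for Gemini models the underlying `GenerativeModel` client's `count_tokens` is called and `total_tokens` is read from the result, while other models fall back to the older `count_text_tokens` call. A usage sketch mirroring the integration test added later in this diff (requires a valid API key and network access):

```python
from langchain_google_genai import ChatGoogleGenerativeAI

llm = ChatGoogleGenerativeAI(model="gemini-pro")

# For a Gemini model this goes through client.count_tokens(...) and reads
# result.total_tokens; the integration test below expects 4 tokens for
# this exact string.
num_tokens = llm.get_num_tokens("How are you?")
print(num_tokens)
```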
diff --git a/libs/partners/google-genai/langchain_google_genai/llms.py b/libs/partners/google-genai/langchain_google_genai/llms.py
index 751ca8afebeb7..bf147410ce8d6 100644
--- a/libs/partners/google-genai/langchain_google_genai/llms.py
+++ b/libs/partners/google-genai/langchain_google_genai/llms.py
@@ -1,5 +1,6 @@
 from __future__ import annotations

+from enum import Enum, auto
 from typing import Any, Callable, Dict, Iterator, List, Optional, Union

 import google.api_core
@@ -15,6 +16,19 @@
 from langchain_core.utils import get_from_dict_or_env


+class GoogleModelFamily(str, Enum):
+    GEMINI = auto()
+    PALM = auto()
+
+    @classmethod
+    def _missing_(cls, value: Any) -> Optional["GoogleModelFamily"]:
+        if "gemini" in value.lower():
+            return GoogleModelFamily.GEMINI
+        elif "text-bison" in value.lower():
+            return GoogleModelFamily.PALM
+        return None
+
+
 def _create_retry_decorator(
     llm: BaseLLM,
     *,
@@ -75,10 +89,6 @@ def _completion_with_retry(
     )


-def _is_gemini_model(model_name: str) -> bool:
-    return "gemini" in model_name
-
-
 def _strip_erroneous_leading_spaces(text: str) -> str:
     """Strip erroneous leading spaces from text.

@@ -92,17 +102,9 @@
     return text


-class GoogleGenerativeAI(BaseLLM, BaseModel):
-    """Google GenerativeAI models.
-
-    Example:
-        .. code-block:: python
-
-            from langchain_google_genai import GoogleGenerativeAI
-            llm = GoogleGenerativeAI(model="gemini-pro")
-    """
+class _BaseGoogleGenerativeAI(BaseModel):
+    """Base class for Google Generative AI LLMs"""

-    client: Any  #: :meta private:
     model: str = Field(
         ...,
         description="""The name of the model to use.
@@ -141,15 +143,27 @@ class GoogleGenerativeAI(BaseLLM, BaseModel):
         description="A string, one of: [`rest`, `grpc`, `grpc_asyncio`].",
     )

-    @property
-    def is_gemini(self) -> bool:
-        """Returns whether a model is belongs to a Gemini family or not."""
-        return _is_gemini_model(self.model)
-
     @property
     def lc_secrets(self) -> Dict[str, str]:
         return {"google_api_key": "GOOGLE_API_KEY"}

+    @property
+    def _model_family(self) -> str:
+        return GoogleModelFamily(self.model)
+
+
+class GoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseLLM):
+    """Google GenerativeAI models.
+
+    Example:
+        .. code-block:: python
+
+            from langchain_google_genai import GoogleGenerativeAI
+            llm = GoogleGenerativeAI(model="gemini-pro")
+    """
+
+    client: Any  #: :meta private:
+
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
         """Validates params and passes them to google-generativeai package."""
@@ -167,7 +181,7 @@ def validate_environment(cls, values: Dict) -> Dict:
             client_options=values.get("client_options"),
         )

-        if _is_gemini_model(model_name):
+        if GoogleModelFamily(model_name) == GoogleModelFamily.GEMINI:
             values["client"] = genai.GenerativeModel(model_name=model_name)
         else:
             values["client"] = genai
@@ -203,7 +217,7 @@ def _generate(
             "candidate_count": self.n,
         }
         for prompt in prompts:
-            if self.is_gemini:
+            if self._model_family == GoogleModelFamily.GEMINI:
                 res = _completion_with_retry(
                     self,
                     prompt=prompt,
@@ -279,7 +293,11 @@ def get_num_tokens(self, text: str) -> int:
         Returns:
             The integer number of tokens in the text.
         """
-        if self.is_gemini:
-            raise ValueError("Counting tokens is not yet supported!")
-        result = self.client.count_text_tokens(model=self.model, prompt=text)
-        return result["token_count"]
+        if self._model_family == GoogleModelFamily.GEMINI:
+            result = self.client.count_tokens(text)
+            token_count = result.total_tokens
+        else:
+            result = self.client.count_text_tokens(model=self.model, prompt=text)
+            token_count = result["token_count"]
+
+        return token_count
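`GoogleModelFamily` replaces the old `_is_gemini_model` string check. Because it subclasses both `str` and `Enum` and implements `_missing_`, constructing it from a raw model name performs substring-based family lookup rather than exact value matching. A small illustration, assuming this branch of `langchain_google_genai` is installed:

```python
from langchain_google_genai.llms import GoogleModelFamily

# A raw model name is never an exact member value, so the Enum machinery
# falls back to _missing_, which matches on the family substring.
assert GoogleModelFamily("gemini-pro") is GoogleModelFamily.GEMINI
assert GoogleModelFamily("models/gemini-1.0-pro-latest") is GoogleModelFamily.GEMINI
assert GoogleModelFamily("text-bison-001") is GoogleModelFamily.PALM
```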
diff --git a/libs/partners/google-genai/poetry.lock b/libs/partners/google-genai/poetry.lock
index 532f6108d01d0..edddadcb57d17 100644
--- a/libs/partners/google-genai/poetry.lock
+++ b/libs/partners/google-genai/poetry.lock
@@ -280,12 +280,12 @@ requests = ["requests (>=2.20.0,<3.0.0.dev0)"]

 [[package]]
 name = "google-generativeai"
-version = "0.3.1"
+version = "0.3.2"
 description = "Google Generative AI High level API client library and tools."
 optional = false
 python-versions = ">=3.9"
 files = [
-    {file = "google_generativeai-0.3.1-py3-none-any.whl", hash = "sha256:800ec6041ca537b897d7ba654f4125651c64b38506f2bfce3b464370e3333a1b"},
+    {file = "google_generativeai-0.3.2-py3-none-any.whl", hash = "sha256:8761147e6e167141932dc14a7b7af08f2310dd56668a78d206c19bb8bd85bcd7"},
 ]

 [package.dependencies]
@@ -294,6 +294,7 @@ google-api-core = "*"
 google-auth = "*"
 protobuf = "*"
 tqdm = "*"
+typing-extensions = "*"

 [package.extras]
 dev = ["Pillow", "absl-py", "black", "ipython", "nose2", "pandas", "pytype", "pyyaml"]
diff --git a/libs/partners/google-genai/tests/integration_tests/test_chat_models.py b/libs/partners/google-genai/tests/integration_tests/test_chat_models.py
index 4e9c7c764ac0d..4551a860d0fa5 100644
--- a/libs/partners/google-genai/tests/integration_tests/test_chat_models.py
+++ b/libs/partners/google-genai/tests/integration_tests/test_chat_models.py
@@ -186,3 +186,9 @@ def test_chat_google_genai_system_message() -> None:
     response = model([system_message, message1, message2, message3])
     assert isinstance(response, AIMessage)
     assert isinstance(response.content, str)
+
+
+def test_generativeai_get_num_tokens_gemini() -> None:
+    llm = ChatGoogleGenerativeAI(temperature=0, model="gemini-pro")
+    output = llm.get_num_tokens("How are you?")
+    assert output == 4
diff --git a/libs/partners/google-genai/tests/integration_tests/test_llms.py b/libs/partners/google-genai/tests/integration_tests/test_llms.py
index 09b17d8201ee1..9bdf49dda84cf 100644
--- a/libs/partners/google-genai/tests/integration_tests/test_llms.py
+++ b/libs/partners/google-genai/tests/integration_tests/test_llms.py
@@ -60,3 +60,9 @@ def test_generativeai_stream() -> None:
     llm = GoogleGenerativeAI(temperature=0, model="gemini-pro")
     outputs = list(llm.stream("Please say foo:"))
     assert isinstance(outputs[0], str)
+
+
+def test_generativeai_get_num_tokens_gemini() -> None:
+    llm = GoogleGenerativeAI(temperature=0, model="gemini-pro")
+    output = llm.get_num_tokens("How are you?")
+    assert output == 4
diff --git a/libs/partners/google-genai/tests/unit_tests/test_llms.py b/libs/partners/google-genai/tests/unit_tests/test_llms.py
new file mode 100644
index 0000000000000..9b0ae77fc6188
--- /dev/null
+++ b/libs/partners/google-genai/tests/unit_tests/test_llms.py
@@ -0,0 +1,8 @@
+from langchain_google_genai.llms import GoogleModelFamily
+
+
+def test_model_family() -> None:
+    model = GoogleModelFamily("gemini-pro")
+    assert model == GoogleModelFamily.GEMINI
+    model = GoogleModelFamily("gemini-ultra")
+    assert model == GoogleModelFamily.GEMINI
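The new unit test only exercises the Gemini branch of `_missing_`. A hypothetical extension, not part of this diff, could also pin down the PaLM branch and the failure mode for unrecognized names (when `_missing_` returns `None`, the `Enum` machinery raises `ValueError`):

```python
import pytest

from langchain_google_genai.llms import GoogleModelFamily


def test_model_family_palm_and_unknown() -> None:
    # "text-bison" names resolve to PALM via the _missing_ hook...
    assert GoogleModelFamily("text-bison@001") == GoogleModelFamily.PALM
    # ...and a name matching neither substring raises, because _missing_
    # returns None for it.
    with pytest.raises(ValueError):
        GoogleModelFamily("claude-2")
```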