Skip to content

Commit

Permalink
google-genai: added logic for method get_num_tokens() (#16205)
Browse files Browse the repository at this point in the history
<!-- Thank you for contributing to LangChain!

Please title your PR "partners: google-genai",

Replace this entire comment with:
- **Description:** : added logic for method get_num_tokens() for
ChatGoogleGenerativeAI , GoogleGenerativeAI,
  - **Issue:** : #16204,
  - **Dependencies:** : None,
  - **Twitter handle:** @Aditya_Rane

---------

Co-authored-by: [email protected] <[email protected]>
Co-authored-by: Leonid Kuligin <[email protected]>
  • Loading branch information
3 people authored Jan 25, 2024
1 parent 0785432 commit 9dd7cbb
Show file tree
Hide file tree
Showing 6 changed files with 90 additions and 74 deletions.
71 changes: 24 additions & 47 deletions libs/partners/google-genai/langchain_google_genai/chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
SystemMessage,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
from langchain_core.pydantic_v1 import SecretStr, root_validator
from langchain_core.utils import get_from_dict_or_env
from tenacity import (
before_sleep_log,
Expand All @@ -53,6 +53,7 @@
)

from langchain_google_genai._common import GoogleGenerativeAIError
from langchain_google_genai.llms import GoogleModelFamily, _BaseGoogleGenerativeAI

IMAGE_TYPES: Tuple = ()
try:
Expand Down Expand Up @@ -417,7 +418,7 @@ def _response_to_result(
return ChatResult(generations=generations, llm_output=llm_output)


class ChatGoogleGenerativeAI(BaseChatModel):
class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
"""`Google Generative AI` Chat models API.
To use, you must have either:
Expand All @@ -435,53 +436,13 @@ class ChatGoogleGenerativeAI(BaseChatModel):
"""

model: str = Field(
...,
description="""The name of the model to use.
Supported examples:
- gemini-pro""",
)
max_output_tokens: int = Field(default=None, description="Max output tokens")

client: Any #: :meta private:
google_api_key: Optional[SecretStr] = None
temperature: Optional[float] = None
"""Run inference with this temperature. Must by in the closed
interval [0.0, 1.0]."""
top_k: Optional[int] = None
"""Decode using top-k sampling: consider the set of top_k most probable tokens.
Must be positive."""
top_p: Optional[float] = None
"""The maximum cumulative probability of tokens to consider when sampling.
The model uses combined Top-k and nucleus sampling.
Tokens are sorted based on their assigned probabilities so
that only the most likely tokens are considered. Top-k
sampling directly limits the maximum number of tokens to
consider, while Nucleus sampling limits number of tokens
based on the cumulative probability.
Note: The default value varies by model, see the
`Model.top_p` attribute of the `Model` returned the
`genai.get_model` function.
"""
n: int = Field(default=1, alias="candidate_count")
"""Number of chat completions to generate for each prompt. Note that the API may
not return the full n completions if duplicates are generated."""

convert_system_message_to_human: bool = False
"""Whether to merge any leading SystemMessage into the following HumanMessage.
Gemini does not support system messages; any unsupported messages will
raise an error."""
client_options: Optional[Dict] = Field(
None,
description="Client options to pass to the Google API client.",
)
transport: Optional[str] = Field(
None,
description="A string, one of: [`rest`, `grpc`, `grpc_asyncio`].",
)

class Config:
allow_population_by_field_name = True
Expand All @@ -494,10 +455,6 @@ def lc_secrets(self) -> Dict[str, str]:
def _llm_type(self) -> str:
return "chat-google-generative-ai"

@property
def _is_geminiai(self) -> bool:
return self.model is not None and "gemini" in self.model

@classmethod
def is_lc_serializable(self) -> bool:
return True
Expand Down Expand Up @@ -658,3 +615,23 @@ def _prepare_chat(
message = history.pop()
chat = self.client.start_chat(history=history)
return params, chat, message

def get_num_tokens(self, text: str) -> int:
"""Get the number of tokens present in the text.
Useful for checking if an input will fit in a model's context window.
Args:
text: The string input to tokenize.
Returns:
The integer number of tokens in the text.
"""
if self._model_family == GoogleModelFamily.GEMINI:
result = self.client.count_tokens(text)
token_count = result.total_tokens
else:
result = self.client.count_text_tokens(model=self.model, prompt=text)
token_count = result["token_count"]

return token_count
68 changes: 43 additions & 25 deletions libs/partners/google-genai/langchain_google_genai/llms.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from enum import Enum, auto
from typing import Any, Callable, Dict, Iterator, List, Optional, Union

import google.api_core
Expand All @@ -15,6 +16,19 @@
from langchain_core.utils import get_from_dict_or_env


class GoogleModelFamily(str, Enum):
GEMINI = auto()
PALM = auto()

@classmethod
def _missing_(cls, value: Any) -> Optional["GoogleModelFamily"]:
if "gemini" in value.lower():
return GoogleModelFamily.GEMINI
elif "text-bison" in value.lower():
return GoogleModelFamily.PALM
return None


def _create_retry_decorator(
llm: BaseLLM,
*,
Expand Down Expand Up @@ -75,10 +89,6 @@ def _completion_with_retry(
)


def _is_gemini_model(model_name: str) -> bool:
return "gemini" in model_name


def _strip_erroneous_leading_spaces(text: str) -> str:
"""Strip erroneous leading spaces from text.
Expand All @@ -92,17 +102,9 @@ def _strip_erroneous_leading_spaces(text: str) -> str:
return text


class GoogleGenerativeAI(BaseLLM, BaseModel):
"""Google GenerativeAI models.
Example:
.. code-block:: python
from langchain_google_genai import GoogleGenerativeAI
llm = GoogleGenerativeAI(model="gemini-pro")
"""
class _BaseGoogleGenerativeAI(BaseModel):
"""Base class for Google Generative AI LLMs"""

client: Any #: :meta private:
model: str = Field(
...,
description="""The name of the model to use.
Expand Down Expand Up @@ -141,15 +143,27 @@ class GoogleGenerativeAI(BaseLLM, BaseModel):
description="A string, one of: [`rest`, `grpc`, `grpc_asyncio`].",
)

@property
def is_gemini(self) -> bool:
"""Returns whether a model is belongs to a Gemini family or not."""
return _is_gemini_model(self.model)

@property
def lc_secrets(self) -> Dict[str, str]:
return {"google_api_key": "GOOGLE_API_KEY"}

@property
def _model_family(self) -> str:
return GoogleModelFamily(self.model)


class GoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseLLM):
"""Google GenerativeAI models.
Example:
.. code-block:: python
from langchain_google_genai import GoogleGenerativeAI
llm = GoogleGenerativeAI(model="gemini-pro")
"""

client: Any #: :meta private:

@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validates params and passes them to google-generativeai package."""
Expand All @@ -167,7 +181,7 @@ def validate_environment(cls, values: Dict) -> Dict:
client_options=values.get("client_options"),
)

if _is_gemini_model(model_name):
if GoogleModelFamily(model_name) == GoogleModelFamily.GEMINI:
values["client"] = genai.GenerativeModel(model_name=model_name)
else:
values["client"] = genai
Expand Down Expand Up @@ -203,7 +217,7 @@ def _generate(
"candidate_count": self.n,
}
for prompt in prompts:
if self.is_gemini:
if self._model_family == GoogleModelFamily.GEMINI:
res = _completion_with_retry(
self,
prompt=prompt,
Expand Down Expand Up @@ -279,7 +293,11 @@ def get_num_tokens(self, text: str) -> int:
Returns:
The integer number of tokens in the text.
"""
if self.is_gemini:
raise ValueError("Counting tokens is not yet supported!")
result = self.client.count_text_tokens(model=self.model, prompt=text)
return result["token_count"]
if self._model_family == GoogleModelFamily.GEMINI:
result = self.client.count_tokens(text)
token_count = result.total_tokens
else:
result = self.client.count_text_tokens(model=self.model, prompt=text)
token_count = result["token_count"]

return token_count
5 changes: 3 additions & 2 deletions libs/partners/google-genai/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -186,3 +186,9 @@ def test_chat_google_genai_system_message() -> None:
response = model([system_message, message1, message2, message3])
assert isinstance(response, AIMessage)
assert isinstance(response.content, str)


def test_generativeai_get_num_tokens_gemini() -> None:
llm = ChatGoogleGenerativeAI(temperature=0, model="gemini-pro")
output = llm.get_num_tokens("How are you?")
assert output == 4
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,9 @@ def test_generativeai_stream() -> None:
llm = GoogleGenerativeAI(temperature=0, model="gemini-pro")
outputs = list(llm.stream("Please say foo:"))
assert isinstance(outputs[0], str)


def test_generativeai_get_num_tokens_gemini() -> None:
llm = GoogleGenerativeAI(temperature=0, model="gemini-pro")
output = llm.get_num_tokens("How are you?")
assert output == 4
8 changes: 8 additions & 0 deletions libs/partners/google-genai/tests/unit_tests/test_llms.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from langchain_google_genai.llms import GoogleModelFamily


def test_model_family() -> None:
model = GoogleModelFamily("gemini-pro")
assert model == GoogleModelFamily.GEMINI
model = GoogleModelFamily("gemini-ultra")
assert model == GoogleModelFamily.GEMINI

0 comments on commit 9dd7cbb

Please sign in to comment.