diff --git a/libs/community/langchain_community/llms/llamafile.py b/libs/community/langchain_community/llms/llamafile.py
index 1aff521ee3300..933ed5e025877 100644
--- a/libs/community/langchain_community/llms/llamafile.py
+++ b/libs/community/langchain_community/llms/llamafile.py
@@ -139,6 +139,7 @@ def _param_fieldnames(self) -> List[str]:
             "streaming",
             "tags",
             "verbose",
+            "custom_get_token_ids",
         ]
         attrs = [
             k for k in get_pydantic_field_names(self.__class__) if k not in ignore_keys
diff --git a/libs/core/langchain_core/language_models/base.py b/libs/core/langchain_core/language_models/base.py
index 4941faea34589..a6addf8f6b0aa 100644
--- a/libs/core/langchain_core/language_models/base.py
+++ b/libs/core/langchain_core/language_models/base.py
@@ -5,6 +5,7 @@
 from typing import (
     TYPE_CHECKING,
     Any,
+    Callable,
     Dict,
     List,
     Mapping,
@@ -97,6 +98,10 @@ class BaseLanguageModel(
     """Tags to add to the run trace."""
     metadata: Optional[Dict[str, Any]] = Field(default=None, exclude=True)
     """Metadata to add to the run trace."""
+    custom_get_token_ids: Optional[Callable[[str], List[int]]] = Field(
+        default=None, exclude=True
+    )
+    """Optional encoder to use for counting tokens."""
 
     @validator("verbose", pre=True, always=True)
     def set_verbose(cls, verbose: Optional[bool]) -> bool:
@@ -310,7 +315,10 @@ def get_token_ids(self, text: str) -> List[int]:
             A list of ids corresponding to the tokens in the text, in order they occur
                 in the text.
         """
-        return _get_token_ids_default_method(text)
+        if self.custom_get_token_ids is not None:
+            return self.custom_get_token_ids(text)
+        else:
+            return _get_token_ids_default_method(text)
 
     def get_num_tokens(self, text: str) -> int:
         """Get the number of tokens present in the text.
diff --git a/libs/partners/openai/langchain_openai/llms/base.py b/libs/partners/openai/langchain_openai/llms/base.py
index f5602a816baf9..f59b2c6b6fa4d 100644
--- a/libs/partners/openai/langchain_openai/llms/base.py
+++ b/libs/partners/openai/langchain_openai/llms/base.py
@@ -521,6 +521,8 @@ def _llm_type(self) -> str:
 
     def get_token_ids(self, text: str) -> List[int]:
         """Get the token IDs using the tiktoken package."""
+        if self.custom_get_token_ids is not None:
+            return self.custom_get_token_ids(text)
         # tiktoken NOT supported for Python < 3.8
         if sys.version_info[1] < 8:
             return super().get_num_tokens(text)
diff --git a/libs/partners/openai/tests/unit_tests/llms/test_base.py b/libs/partners/openai/tests/unit_tests/llms/test_base.py
index d05bb0bfe546c..122846e2def13 100644
--- a/libs/partners/openai/tests/unit_tests/llms/test_base.py
+++ b/libs/partners/openai/tests/unit_tests/llms/test_base.py
@@ -1,4 +1,5 @@
 import os
+from typing import List
 
 import pytest
 
@@ -54,3 +55,11 @@ def mock_completion() -> dict:
 def test_get_token_ids(model: str) -> None:
     OpenAI(model=model).get_token_ids("foo")
     return
+
+
+def test_custom_token_counting() -> None:
+    def token_encoder(text: str) -> List[int]:
+        return [1, 2, 3]
+
+    llm = OpenAI(custom_get_token_ids=token_encoder)
+    assert llm.get_token_ids("foo") == [1, 2, 3]
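
Usage sketch (not part of the patch): with this change, any model inheriting from BaseLanguageModel accepts a custom_get_token_ids callable and routes token counting through it instead of the default encoder. The token_encoder below is a hypothetical whitespace-based counter chosen only for illustration; the example assumes langchain_openai is installed and OPENAI_API_KEY is set so the client can be constructed, though no API call is made for token counting.

from typing import List

from langchain_openai import OpenAI


def token_encoder(text: str) -> List[int]:
    # Toy encoder for illustration: one "token" id per whitespace-separated word.
    return list(range(len(text.split())))


llm = OpenAI(custom_get_token_ids=token_encoder)
print(llm.get_token_ids("hello world"))   # [0, 1], produced by token_encoder
print(llm.get_num_tokens("hello world"))  # 2, the length of the custom id list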