Skip to content

Commit

Permalink
core, openai: support custom token encoders (#20762)
Browse files Browse the repository at this point in the history
  • Loading branch information
ccurme authored Apr 23, 2024
1 parent b481b73 commit 7a922f3
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 1 deletion.
1 change: 1 addition & 0 deletions libs/community/langchain_community/llms/llamafile.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ def _param_fieldnames(self) -> List[str]:
"streaming",
"tags",
"verbose",
"custom_get_token_ids",
]
attrs = [
k for k in get_pydantic_field_names(self.__class__) if k not in ignore_keys
Expand Down
10 changes: 9 additions & 1 deletion libs/core/langchain_core/language_models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
List,
Mapping,
Expand Down Expand Up @@ -97,6 +98,10 @@ class BaseLanguageModel(
"""Tags to add to the run trace."""
metadata: Optional[Dict[str, Any]] = Field(default=None, exclude=True)
"""Metadata to add to the run trace."""
custom_get_token_ids: Optional[Callable[[str], List[int]]] = Field(
default=None, exclude=True
)
"""Optional encoder to use for counting tokens."""

@validator("verbose", pre=True, always=True)
def set_verbose(cls, verbose: Optional[bool]) -> bool:
Expand Down Expand Up @@ -310,7 +315,10 @@ def get_token_ids(self, text: str) -> List[int]:
A list of ids corresponding to the tokens in the text, in order they occur
in the text.
"""
return _get_token_ids_default_method(text)
if self.custom_get_token_ids is not None:
return self.custom_get_token_ids(text)
else:
return _get_token_ids_default_method(text)

def get_num_tokens(self, text: str) -> int:
"""Get the number of tokens present in the text.
Expand Down
2 changes: 2 additions & 0 deletions libs/partners/openai/langchain_openai/llms/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,6 +521,8 @@ def _llm_type(self) -> str:

def get_token_ids(self, text: str) -> List[int]:
"""Get the token IDs using the tiktoken package."""
if self.custom_get_token_ids is not None:
return self.custom_get_token_ids(text)
# tiktoken NOT supported for Python < 3.8
if sys.version_info[1] < 8:
return super().get_num_tokens(text)
Expand Down
9 changes: 9 additions & 0 deletions libs/partners/openai/tests/unit_tests/llms/test_base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
from typing import List

import pytest

Expand Down Expand Up @@ -54,3 +55,11 @@ def mock_completion() -> dict:
def test_get_token_ids(model: str) -> None:
OpenAI(model=model).get_token_ids("foo")
return


def test_custom_token_counting() -> None:
    """Check that a user-supplied encoder is used by ``get_token_ids``."""
    fixed_ids = [1, 2, 3]

    def token_encoder(text: str) -> List[int]:
        # Stub encoder: ignores its input and always yields the same ids.
        return list(fixed_ids)

    model = OpenAI(custom_get_token_ids=token_encoder)
    token_ids = model.get_token_ids("foo")
    assert token_ids == [1, 2, 3]

0 comments on commit 7a922f3

Please sign in to comment.