Skip to content

Commit

Permalink
add tiktoken token count method
Browse files Browse the repository at this point in the history
  • Loading branch information
findalexli committed Nov 5, 2024
1 parent b59fb7f commit c7d75a2
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 6 deletions.
11 changes: 5 additions & 6 deletions libs/community/langchain_community/chat_models/reka.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
Union,
)

import tiktoken
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
Expand Down Expand Up @@ -153,9 +154,9 @@ class ChatReka(BaseChatModel):
default_request_timeout: Optional[float] = None
max_retries: int = 2
reka_api_key: Optional[str] = None
count_tokens: Optional[Callable[[str], int]] = None
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
model_config = ConfigDict(extra="forbid")
_tiktoken_encoder = None

@model_validator(mode="before")
@classmethod
Expand Down Expand Up @@ -330,11 +331,9 @@ async def _agenerate(

def get_num_tokens(self, text: str) -> int:
"""Calculate number of tokens."""
if self.count_tokens is None:
raise NotImplementedError(
"get_num_tokens() is not implemented for Reka models."
)
return self.count_tokens(text)
if self._tiktoken_encoder is None:
self._tiktoken_encoder = tiktoken.get_encoding("cl100k_base")
return len(self._tiktoken_encoder.encode(text))

def bind_tools(
self,
Expand Down
20 changes: 20 additions & 0 deletions libs/community/tests/unit_tests/chat_models/test_reka.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from unittest.mock import MagicMock, patch

import pytest
import tiktoken
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
from pydantic import ValidationError

Expand Down Expand Up @@ -302,3 +303,22 @@ def test_multiple_system_messages_error() -> None:

with pytest.raises(ValueError, match="Multiple system messages are not supported."):
convert_to_reka_messages(messages)


@pytest.mark.requires("reka")
def test_get_num_tokens() -> None:
"""Test that token counting works correctly."""
llm = ChatReka()

# Test basic text
text = "Hello, world!"
expected_tokens = len(tiktoken.get_encoding("cl100k_base").encode(text))
assert llm.get_num_tokens(text) == expected_tokens

# Test empty string
assert llm.get_num_tokens("") == 0

# Test longer text with special characters
complex_text = "Hello 🌍! This is a test of the token counting"
expected_tokens = len(tiktoken.get_encoding("cl100k_base").encode(complex_text))
assert llm.get_num_tokens(complex_text) == expected_tokens

0 comments on commit c7d75a2

Please sign in to comment.