From c7d75a2b7f25a755bc3f9688671a4e8ad198eacd Mon Sep 17 00:00:00 2001 From: findalexli Date: Tue, 5 Nov 2024 12:18:09 -0800 Subject: [PATCH] add tiktoken token count method --- .../langchain_community/chat_models/reka.py | 11 +++++----- .../tests/unit_tests/chat_models/test_reka.py | 20 +++++++++++++++++++ 2 files changed, 25 insertions(+), 6 deletions(-) diff --git a/libs/community/langchain_community/chat_models/reka.py b/libs/community/langchain_community/chat_models/reka.py index a1a3a1a8c35b6..d8c4b3abb7184 100644 --- a/libs/community/langchain_community/chat_models/reka.py +++ b/libs/community/langchain_community/chat_models/reka.py @@ -13,6 +13,7 @@ Union, ) +import tiktoken from langchain_core.callbacks import ( AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun, @@ -153,9 +154,9 @@ class ChatReka(BaseChatModel): default_request_timeout: Optional[float] = None max_retries: int = 2 reka_api_key: Optional[str] = None - count_tokens: Optional[Callable[[str], int]] = None model_kwargs: Dict[str, Any] = Field(default_factory=dict) model_config = ConfigDict(extra="forbid") + _tiktoken_encoder = None @model_validator(mode="before") @classmethod @@ -330,11 +331,9 @@ async def _agenerate( def get_num_tokens(self, text: str) -> int: """Calculate number of tokens.""" - if self.count_tokens is None: - raise NotImplementedError( - "get_num_tokens() is not implemented for Reka models." - ) - return self.count_tokens(text) + if self._tiktoken_encoder is None: + self._tiktoken_encoder = tiktoken.get_encoding("cl100k_base") + return len(self._tiktoken_encoder.encode(text)) def bind_tools( self, diff --git a/libs/community/tests/unit_tests/chat_models/test_reka.py b/libs/community/tests/unit_tests/chat_models/test_reka.py index 642690e0d117a..00818f8bd33c4 100644 --- a/libs/community/tests/unit_tests/chat_models/test_reka.py +++ b/libs/community/tests/unit_tests/chat_models/test_reka.py @@ -4,6 +4,7 @@ from unittest.mock import MagicMock, patch import pytest +import tiktoken from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage from pydantic import ValidationError @@ -302,3 +303,22 @@ def test_multiple_system_messages_error() -> None: with pytest.raises(ValueError, match="Multiple system messages are not supported."): convert_to_reka_messages(messages) + + +@pytest.mark.requires("reka") +def test_get_num_tokens() -> None: + """Test that token counting works correctly.""" + llm = ChatReka() + + # Test basic text + text = "Hello, world!" + expected_tokens = len(tiktoken.get_encoding("cl100k_base").encode(text)) + assert llm.get_num_tokens(text) == expected_tokens + + # Test empty string + assert llm.get_num_tokens("") == 0 + + # Test longer text with special characters + complex_text = "Hello 🌍! This is a test of the token counting" + expected_tokens = len(tiktoken.get_encoding("cl100k_base").encode(complex_text)) + assert llm.get_num_tokens(complex_text) == expected_tokens