From 0d58d169472596452dd9a99552822397a5b69621 Mon Sep 17 00:00:00 2001 From: Jakub Nyckowski Date: Thu, 9 May 2024 16:37:59 -0400 Subject: [PATCH] [v15] Remove usage of tiktoken (#41391) * Remove usage of tiktoken (#41374) * Remove usage of tiktoken The token counting mechanism in tokencount.go has been simplified by removing the usage of the tokenizer and instead estimating token count based on character length. * Update tests * Go build fix * Revert package.json changes --- go.mod | 2 -- go.sum | 4 ---- lib/ai/chat_test.go | 6 ++--- lib/ai/tokens/tokencount.go | 38 +++++++++----------------------- lib/ai/tokens/tokencount_test.go | 27 +++++++++++------------ lib/web/command.go | 6 ++++- 6 files changed, 31 insertions(+), 52 deletions(-) diff --git a/go.mod b/go.mod index 09c45be75d732..06cc6b0390b08 100644 --- a/go.mod +++ b/go.mod @@ -173,7 +173,6 @@ require ( github.com/spf13/cobra v1.8.0 github.com/spiffe/go-spiffe/v2 v2.2.0 github.com/stretchr/testify v1.9.0 - github.com/tiktoken-go/tokenizer v0.1.0 github.com/ucarion/urlpath v0.0.0-20200424170820-7ccc79b76bbb github.com/vulcand/predicate v1.2.0 // replaced github.com/xanzy/go-gitlab v0.103.0 @@ -298,7 +297,6 @@ require ( github.com/di-wu/parser v0.2.2 // indirect github.com/di-wu/xsd-datetime v1.0.0 // indirect github.com/digitorus/timestamp v0.0.0-20231217203849-220c5c2851b7 // indirect - github.com/dlclark/regexp2 v1.9.0 // indirect github.com/dmarkham/enumer v1.5.9 // indirect github.com/docker/cli v25.0.1+incompatible // indirect github.com/docker/distribution v2.8.3+incompatible // indirect diff --git a/go.sum b/go.sum index 0902eecd8ecb9..365d1bb2a4bba 100644 --- a/go.sum +++ b/go.sum @@ -471,8 +471,6 @@ github.com/distribution/distribution/v3 v3.0.0-20221208165359-362910506bc2 h1:aB github.com/distribution/distribution/v3 v3.0.0-20221208165359-362910506bc2/go.mod h1:WHNsWjnIn2V1LYOrME7e8KxSeKunYHsxEm4am0BUtcI= github.com/distribution/reference v0.5.0 h1:/FUIFXtfc/x2gpa5/VGfiGLuOIdYa1t65IKK2OFGvA0= github.com/distribution/reference v0.5.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E= -github.com/dlclark/regexp2 v1.9.0 h1:pTK/l/3qYIKaRXuHnEnIf7Y5NxfRPfpb7dis6/gdlVI= -github.com/dlclark/regexp2 v1.9.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= github.com/dmarkham/enumer v1.5.9 h1:NM/1ma/AUNieHZg74w67GkHFBNB15muOt3sj486QVZk= github.com/dmarkham/enumer v1.5.9/go.mod h1:e4VILe2b1nYK3JKJpRmNdl5xbDQvELc6tQ8b+GsGk6E= github.com/dnaeon/go-vcr v1.1.0/go.mod h1:M7tiix8f0r6mKKJ3Yq/kqU1OYf3MnfmBWVbPx/yU9ko= @@ -1504,8 +1502,6 @@ github.com/thales-e-security/pool v0.0.2/go.mod h1:qtpMm2+thHtqhLzTwgDBj/OuNnMpu github.com/theupdateframework/go-tuf v0.7.0 h1:CqbQFrWo1ae3/I0UCblSbczevCCbS31Qvs5LdxRWqRI= github.com/theupdateframework/go-tuf v0.7.0/go.mod h1:uEB7WSY+7ZIugK6R1hiBMBjQftaFzn7ZCDJcp1tCUug= github.com/tidwall/pretty v1.0.0/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhVysOjyk= -github.com/tiktoken-go/tokenizer v0.1.0 h1:c1fXriHSR/NmhMDTwUDLGiNhHwTV+ElABGvqhCWLRvY= -github.com/tiktoken-go/tokenizer v0.1.0/go.mod h1:7SZW3pZUKWLJRilTvWCa86TOVIiiJhYj3FQ5V3alWcg= github.com/titanous/rocacheck v0.0.0-20171023193734-afe73141d399 h1:e/5i7d4oYZ+C1wj2THlRK+oAhjeS/TRQwMfkIuet3w0= github.com/titanous/rocacheck v0.0.0-20171023193734-afe73141d399/go.mod h1:LdwHTNJT99C5fTAzDz0ud328OgXz+gierycbcIx2fRs= github.com/tjfoc/gmsm v1.4.1 h1:aMe1GlZb+0bLjn+cKTPEvvn9oUEBlJitaZiiBwsbgho= diff --git a/lib/ai/chat_test.go b/lib/ai/chat_test.go index 5d5c1d14ca323..bf0b1bf5394fa 100644 --- a/lib/ai/chat_test.go +++ b/lib/ai/chat_test.go @@ -58,7 +58,7 @@ func TestChat_PromptTokens(t *testing.T) { Content: "Hello", }, }, - want: 721, + want: 850, }, { name: "system and user messages", @@ -72,7 +72,7 @@ func TestChat_PromptTokens(t *testing.T) { Content: "Hi LLM.", }, }, - want: 729, + want: 855, }, { name: "tokenize our prompt", @@ -86,7 +86,7 @@ func TestChat_PromptTokens(t *testing.T) { Content: "Show me free disk space on localhost node.", }, }, - want: 932, + want: 1114, }, } diff --git a/lib/ai/tokens/tokencount.go b/lib/ai/tokens/tokencount.go index 28a444ea41010..8550d2aae366c 100644 --- a/lib/ai/tokens/tokencount.go +++ b/lib/ai/tokens/tokencount.go @@ -23,11 +23,8 @@ import ( "github.com/gravitational/trace" "github.com/sashabaranov/go-openai" - "github.com/tiktoken-go/tokenizer/codec" ) -var defaultTokenizer = codec.NewCl100kBase() - // TokenCount holds TokenCounters for both Prompt and Completion tokens. // As the agent performs multiple calls to the model, each call creates its own // prompt and completion TokenCounter. @@ -115,12 +112,9 @@ func (tc *StaticTokenCounter) TokenCount() int { func NewPromptTokenCounter(prompt []openai.ChatCompletionMessage) (*StaticTokenCounter, error) { var promptCount int for _, message := range prompt { - promptTokens, _, err := defaultTokenizer.Encode(message.Content) - if err != nil { - return nil, trace.Wrap(err) - } + promptTokens := countTokens(message.Content) - promptCount = promptCount + perMessage + perRole + len(promptTokens) + promptCount = promptCount + perMessage + perRole + promptTokens } tc := StaticTokenCounter(promptCount) @@ -130,12 +124,8 @@ func NewPromptTokenCounter(prompt []openai.ChatCompletionMessage) (*StaticTokenC // NewSynchronousTokenCounter takes the completion request output and // computes how many tokens were used by the model to generate this result. func NewSynchronousTokenCounter(completion string) (*StaticTokenCounter, error) { - completionTokens, _, err := defaultTokenizer.Encode(completion) - if err != nil { - return nil, trace.Wrap(err) - } - - completionCount := perRequest + len(completionTokens) + completionTokens := countTokens(completion) + completionCount := perRequest + completionTokens tc := StaticTokenCounter(completionCount) return &tc, nil @@ -188,25 +178,17 @@ func (tc *AsynchronousTokenCounter) Add() error { // the content has been streamed yet. Streamed content can be added a posteriori // with Add(). Once all the content is streamed, Finish() must be called. func NewAsynchronousTokenCounter(completionStart string) (*AsynchronousTokenCounter, error) { - completionTokens, _, err := defaultTokenizer.Encode(completionStart) - if err != nil { - return nil, trace.Wrap(err) - } + completionTokens := countTokens(completionStart) return &AsynchronousTokenCounter{ - count: len(completionTokens), + count: completionTokens, mutex: sync.Mutex{}, finished: false, }, nil } -// CountTokens is a helper that calls tc.CountAll() on a TokenCount pointer, -// but also return 0, 0 when receiving a nil pointer. This makes token counting -// less awkward in cases where we don't know whether a completion happened or -// not. -func CountTokens(tc *TokenCount) (int, int) { - if tc != nil { - return tc.CountAll() - } - return 0, 0 +// countTokens returns an estimated number of tokens in the text. +func countTokens(text string) int { + // Rough estimations that each token is around 4 characters. + return len(text) / 4 } diff --git a/lib/ai/tokens/tokencount_test.go b/lib/ai/tokens/tokencount_test.go index 3289549b7c9d5..dd69bcf979a23 100644 --- a/lib/ai/tokens/tokencount_test.go +++ b/lib/ai/tokens/tokencount_test.go @@ -25,11 +25,8 @@ import ( ) const ( - testCompletionStart = "This is the beginning of the response." - testCompletionEnd = "And this is the end." - testCompletionStartTokens = 8 // 1 token per word + 1 for the dot - testCompletionEndTokens = 6 // 1 token per word + 1 for the dot - testCompletionTokens = testCompletionStartTokens + testCompletionEndTokens + testCompletionStart = "This is the beginning of the response." + testCompletionEnd = "And this is the end." ) // This test checks that Add() properly appends content in the completion @@ -43,23 +40,24 @@ func TestAsynchronousTokenCounter_TokenCount(t *testing.T) { expectedTokens int }{ { - name: "empty count", + name: "empty count", + expectedTokens: 3, }, { name: "only completion start", completionStart: testCompletionStart, - expectedTokens: testCompletionStartTokens, + expectedTokens: 12, }, { name: "only completion add", completionEnd: testCompletionEnd, - expectedTokens: testCompletionEndTokens, + expectedTokens: 8, }, { name: "completion start and end", completionStart: testCompletionStart, completionEnd: testCompletionEnd, - expectedTokens: testCompletionTokens, + expectedTokens: 17, }, } for _, tt := range tests { @@ -69,15 +67,15 @@ func TestAsynchronousTokenCounter_TokenCount(t *testing.T) { // Test setup tc, err := NewAsynchronousTokenCounter(tt.completionStart) require.NoError(t, err) - tokens, _, err := defaultTokenizer.Encode(tt.completionEnd) - require.NoError(t, err) - for range tokens { + tokens := countTokens(tt.completionEnd) + + for i := 0; i < tokens; i++ { require.NoError(t, tc.Add()) } // Doing the real test: asserting the count is right count := tc.TokenCount() - require.Equal(t, tt.expectedTokens+perRequest, count) + require.Equal(t, tt.expectedTokens, count) }) } } @@ -90,7 +88,8 @@ func TestAsynchronousTokenCounter_Finished(t *testing.T) { require.NoError(t, tc.Add()) // We read from the counter - tc.TokenCount() + count := tc.TokenCount() + require.Equal(t, 13, count) // Adding new tokens should be impossible require.Error(t, tc.Add()) diff --git a/lib/web/command.go b/lib/web/command.go index 6dd752ea87287..c7f0a503f9eaf 100644 --- a/lib/web/command.go +++ b/lib/web/command.go @@ -298,7 +298,11 @@ func (h *Handler) executeCommand( } } - prompt, completion := tokens.CountTokens(tokenCount) + prompt, completion := 0, 0 + if tokenCount != nil { + prompt = tokenCount.Prompt.CountAll() + completion = tokenCount.Completion.CountAll() + } usageEventReq := &clientproto.SubmitUsageEventRequest{ Event: &usageeventsv1.UsageEventOneOf{