Merge branch 'master' into lisa/add-more-flakes
kimlisa authored May 9, 2024
2 parents 0a56c81 + ae9eae3 commit 297fb6d
Showing 5 changed files with 25 additions and 50 deletions.
2 changes: 0 additions & 2 deletions go.mod
@@ -177,7 +177,6 @@ require (
 	github.com/spf13/cobra v1.8.0
 	github.com/spiffe/go-spiffe/v2 v2.2.0
 	github.com/stretchr/testify v1.9.0
-	github.com/tiktoken-go/tokenizer v0.1.0
 	github.com/ucarion/urlpath v0.0.0-20200424170820-7ccc79b76bbb
 	github.com/vulcand/predicate v1.2.0 // replaced
 	github.com/xanzy/go-gitlab v0.103.0
@@ -308,7 +307,6 @@ require (
 	github.com/di-wu/parser v0.3.0 // indirect
 	github.com/di-wu/xsd-datetime v1.0.0 // indirect
 	github.com/digitorus/timestamp v0.0.0-20231217203849-220c5c2851b7 // indirect
-	github.com/dlclark/regexp2 v1.11.0 // indirect
 	github.com/dmarkham/enumer v1.5.9 // indirect
 	github.com/docker/cli v25.0.1+incompatible // indirect
 	github.com/docker/distribution v2.8.3+incompatible // indirect
4 changes: 0 additions & 4 deletions go.sum
@@ -1089,8 +1089,6 @@ github.com/distribution/distribution/v3 v3.0.0-20221208165359-362910506bc2 h1:aB
 github.com/distribution/distribution/v3 v3.0.0-20221208165359-362910506bc2/go.mod h1:WHNsWjnIn2V1LYOrME7e8KxSeKunYHsxEm4am0BUtcI=
 github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk=
 github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E=
-github.com/dlclark/regexp2 v1.11.0 h1:G/nrcoOa7ZXlpoa/91N3X7mM3r8eIlMBBJZvsz/mxKI=
-github.com/dlclark/regexp2 v1.11.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
 github.com/dmarkham/enumer v1.5.9 h1:NM/1ma/AUNieHZg74w67GkHFBNB15muOt3sj486QVZk=
 github.com/dmarkham/enumer v1.5.9/go.mod h1:e4VILe2b1nYK3JKJpRmNdl5xbDQvELc6tQ8b+GsGk6E=
 github.com/dnaeon/go-vcr v1.1.0/go.mod h1:M7tiix8f0r6mKKJ3Yq/kqU1OYf3MnfmBWVbPx/yU9ko=
@@ -2254,8 +2252,6 @@ github.com/thales-e-security/pool v0.0.2/go.mod h1:qtpMm2+thHtqhLzTwgDBj/OuNnMpu
 github.com/theupdateframework/go-tuf v0.7.0 h1:CqbQFrWo1ae3/I0UCblSbczevCCbS31Qvs5LdxRWqRI=
 github.com/theupdateframework/go-tuf v0.7.0/go.mod h1:uEB7WSY+7ZIugK6R1hiBMBjQftaFzn7ZCDJcp1tCUug=
 github.com/tidwall/pretty v1.0.0/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhVysOjyk=
-github.com/tiktoken-go/tokenizer v0.1.0 h1:c1fXriHSR/NmhMDTwUDLGiNhHwTV+ElABGvqhCWLRvY=
-github.com/tiktoken-go/tokenizer v0.1.0/go.mod h1:7SZW3pZUKWLJRilTvWCa86TOVIiiJhYj3FQ5V3alWcg=
 github.com/titanous/rocacheck v0.0.0-20171023193734-afe73141d399 h1:e/5i7d4oYZ+C1wj2THlRK+oAhjeS/TRQwMfkIuet3w0=
 github.com/titanous/rocacheck v0.0.0-20171023193734-afe73141d399/go.mod h1:LdwHTNJT99C5fTAzDz0ud328OgXz+gierycbcIx2fRs=
 github.com/tjfoc/gmsm v1.4.1 h1:aMe1GlZb+0bLjn+cKTPEvvn9oUEBlJitaZiiBwsbgho=
6 changes: 3 additions & 3 deletions lib/ai/chat_test.go
@@ -58,7 +58,7 @@ func TestChat_PromptTokens(t *testing.T) {
 					Content: "Hello",
 				},
 			},
-			want: 721,
+			want: 850,
 		},
 		{
 			name: "system and user messages",
@@ -72,7 +72,7 @@ func TestChat_PromptTokens(t *testing.T) {
 					Content: "Hi LLM.",
 				},
 			},
-			want: 729,
+			want: 855,
 		},
 		{
 			name: "tokenize our prompt",
@@ -86,7 +86,7 @@ func TestChat_PromptTokens(t *testing.T) {
 					Content: "Show me free disk space on localhost node.",
 				},
 			},
-			want: 932,
+			want: 1114,
 		},
 	}
 
38 changes: 10 additions & 28 deletions lib/ai/tokens/tokencount.go
@@ -23,11 +23,8 @@ import (
 
 	"github.com/gravitational/trace"
 	"github.com/sashabaranov/go-openai"
-	"github.com/tiktoken-go/tokenizer/codec"
 )
 
-var defaultTokenizer = codec.NewCl100kBase()
-
 // TokenCount holds TokenCounters for both Prompt and Completion tokens.
 // As the agent performs multiple calls to the model, each call creates its own
 // prompt and completion TokenCounter.
@@ -115,12 +112,9 @@ func (tc *StaticTokenCounter) TokenCount() int {
 func NewPromptTokenCounter(prompt []openai.ChatCompletionMessage) (*StaticTokenCounter, error) {
 	var promptCount int
 	for _, message := range prompt {
-		promptTokens, _, err := defaultTokenizer.Encode(message.Content)
-		if err != nil {
-			return nil, trace.Wrap(err)
-		}
+		promptTokens := countTokens(message.Content)
 
-		promptCount = promptCount + perMessage + perRole + len(promptTokens)
+		promptCount = promptCount + perMessage + perRole + promptTokens
 	}
 	tc := StaticTokenCounter(promptCount)
 
@@ -130,12 +124,8 @@ func NewPromptTokenCounter(prompt []openai.ChatCompletionMessage) (*StaticTokenC
 // NewSynchronousTokenCounter takes the completion request output and
 // computes how many tokens were used by the model to generate this result.
 func NewSynchronousTokenCounter(completion string) (*StaticTokenCounter, error) {
-	completionTokens, _, err := defaultTokenizer.Encode(completion)
-	if err != nil {
-		return nil, trace.Wrap(err)
-	}
-
-	completionCount := perRequest + len(completionTokens)
+	completionTokens := countTokens(completion)
+	completionCount := perRequest + completionTokens
 
 	tc := StaticTokenCounter(completionCount)
 	return &tc, nil
@@ -188,25 +178,17 @@ func (tc *AsynchronousTokenCounter) Add() error {
 // the content has been streamed yet. Streamed content can be added a posteriori
 // with Add(). Once all the content is streamed, Finish() must be called.
 func NewAsynchronousTokenCounter(completionStart string) (*AsynchronousTokenCounter, error) {
-	completionTokens, _, err := defaultTokenizer.Encode(completionStart)
-	if err != nil {
-		return nil, trace.Wrap(err)
-	}
+	completionTokens := countTokens(completionStart)
 
 	return &AsynchronousTokenCounter{
-		count:    len(completionTokens),
+		count:    completionTokens,
 		mutex:    sync.Mutex{},
 		finished: false,
 	}, nil
 }
 
-// CountTokens is a helper that calls tc.CountAll() on a TokenCount pointer,
-// but also return 0, 0 when receiving a nil pointer. This makes token counting
-// less awkward in cases where we don't know whether a completion happened or
-// not.
-func CountTokens(tc *TokenCount) (int, int) {
-	if tc != nil {
-		return tc.CountAll()
-	}
-	return 0, 0
+// countTokens returns an estimated number of tokens in the text.
+func countTokens(text string) int {
+	// Rough estimations that each token is around 4 characters.
+	return len(text) / 4
}
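
For context, a minimal standalone sketch of the replacement logic: countTokens trades exact cl100k tokenization for a rough one-token-per-four-characters heuristic. The perRequest value of 3 used here is inferred from the updated test expectations below; its definition sits outside this diff.

package main

import "fmt"

// countTokens mirrors the heuristic introduced by this commit:
// roughly one token per four characters of text.
func countTokens(text string) int {
	return len(text) / 4
}

func main() {
	// The two strings exercised by tokencount_test.go below.
	start := "This is the beginning of the response." // 38 chars -> 9 tokens
	end := "And this is the end."                     // 20 chars -> 5 tokens

	// Assumed from the updated test values; not shown in this diff.
	const perRequest = 3

	fmt.Println(countTokens(start))                                 // 9
	fmt.Println(countTokens(end))                                   // 5
	fmt.Println(countTokens(start) + countTokens(end) + perRequest) // 17
}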
25 changes: 12 additions & 13 deletions lib/ai/tokens/tokencount_test.go
@@ -25,11 +25,8 @@ import (
 )
 
 const (
-	testCompletionStart       = "This is the beginning of the response."
-	testCompletionEnd         = "And this is the end."
-	testCompletionStartTokens = 8 // 1 token per word + 1 for the dot
-	testCompletionEndTokens   = 6 // 1 token per word + 1 for the dot
-	testCompletionTokens      = testCompletionStartTokens + testCompletionEndTokens
+	testCompletionStart = "This is the beginning of the response."
+	testCompletionEnd   = "And this is the end."
 )
 
 // This test checks that Add() properly appends content in the completion
@@ -43,23 +40,24 @@ func TestAsynchronousTokenCounter_TokenCount(t *testing.T) {
 		expectedTokens  int
 	}{
 		{
-			name: "empty count",
+			name:           "empty count",
+			expectedTokens: 3,
 		},
 		{
 			name:            "only completion start",
 			completionStart: testCompletionStart,
-			expectedTokens:  testCompletionStartTokens,
+			expectedTokens:  12,
 		},
 		{
 			name:           "only completion add",
 			completionEnd:  testCompletionEnd,
-			expectedTokens: testCompletionEndTokens,
+			expectedTokens: 8,
 		},
 		{
 			name:            "completion start and end",
 			completionStart: testCompletionStart,
 			completionEnd:   testCompletionEnd,
-			expectedTokens:  testCompletionTokens,
+			expectedTokens:  17,
 		},
 	}
 	for _, tt := range tests {
@@ -69,15 +67,15 @@ func TestAsynchronousTokenCounter_TokenCount(t *testing.T) {
 			// Test setup
 			tc, err := NewAsynchronousTokenCounter(tt.completionStart)
 			require.NoError(t, err)
-			tokens, _, err := defaultTokenizer.Encode(tt.completionEnd)
-			require.NoError(t, err)
+			tokens := countTokens(tt.completionEnd)
 
 			for range tokens {
 				require.NoError(t, tc.Add())
 			}
 
 			// Doing the real test: asserting the count is right
 			count := tc.TokenCount()
-			require.Equal(t, tt.expectedTokens+perRequest, count)
+			require.Equal(t, tt.expectedTokens, count)
 		})
 	}
 }
@@ -90,7 +88,8 @@ func TestAsynchronousTokenCounter_Finished(t *testing.T) {
 	require.NoError(t, tc.Add())
 
 	// We read from the counter
-	tc.TokenCount()
+	count := tc.TokenCount()
+	require.Equal(t, 13, count)
 
 	// Adding new tokens should be impossible
 	require.Error(t, tc.Add())
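
Checking the new literals against the len/4 heuristic (again assuming perRequest is 3): testCompletionStart is 38 characters, or 9 tokens, and testCompletionEnd is 20 characters, or 5 tokens. The four TokenCount cases then come to 0 + 3 = 3, 9 + 3 = 12, 5 + 3 = 8, and 9 + 5 + 3 = 17, while the Finished test sees 9 (start) + 1 (a single Add) + 3 = 13.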
