diff --git a/common/mock/openai.go b/common/mock/openai.go index d347f5b33..74bd0f6b1 100644 --- a/common/mock/openai.go +++ b/common/mock/openai.go @@ -16,12 +16,15 @@ package mock import ( "bytes" + "crypto/md5" "encoding/json" "fmt" - "github.com/emicklei/go-restful/v3" - "github.com/sashabaranov/go-openai" "net" "net/http" + + "github.com/emicklei/go-restful/v3" + "github.com/samber/lo" + "github.com/sashabaranov/go-openai" ) type OpenAIServer struct { @@ -29,8 +32,6 @@ type OpenAIServer struct { httpServer *http.Server authToken string ready chan struct{} - - mockEmbeddings []float32 } func NewOpenAIServer() *OpenAIServer { @@ -81,10 +82,6 @@ func (s *OpenAIServer) Close() error { return s.httpServer.Close() } -func (s *OpenAIServer) Embeddings(embeddings []float32) { - s.mockEmbeddings = embeddings -} - func (s *OpenAIServer) chatCompletion(req *restful.Request, resp *restful.Response) { var r openai.ChatCompletionRequest err := req.ReadEntity(&r) @@ -92,8 +89,11 @@ func (s *OpenAIServer) chatCompletion(req *restful.Request, resp *restful.Respon _ = resp.WriteError(http.StatusBadRequest, err) return } + content := r.Messages[0].Content + if r.Model == "deepseek-r1" { + content = "\x3cthink\x3eTo be or not to be, that is the question.\x3c/think\x3e"
+ content + } if r.Stream { - content := r.Messages[0].Content for i := 0; i < len(content); i += 8 { buf := bytes.NewBuffer(nil) buf.WriteString("data: ") @@ -112,7 +112,7 @@ func (s *OpenAIServer) chatCompletion(req *restful.Request, resp *restful.Respon _ = resp.WriteEntity(openai.ChatCompletionResponse{ Choices: []openai.ChatCompletionChoice{{ Message: openai.ChatCompletionMessage{ - Content: r.Messages[0].Content, + Content: content, }, }}, }) @@ -120,15 +120,35 @@ func (s *OpenAIServer) chatCompletion(req *restful.Request, resp *restful.Respon } func (s *OpenAIServer) embeddings(req *restful.Request, resp *restful.Response) { + // parse request var r openai.EmbeddingRequest err := req.ReadEntity(&r) if err != nil { _ = resp.WriteError(http.StatusBadRequest, err) return } + input, ok := r.Input.(string) + if !ok { + _ = resp.WriteError(http.StatusBadRequest, fmt.Errorf("invalid input type")) + return + } + + // write response _ = resp.WriteEntity(openai.EmbeddingResponse{ Data: []openai.Embedding{{ - Embedding: s.mockEmbeddings, + Embedding: Hash(input), }}, }) } + +func Hash(input string) []float32 { + hasher := md5.New() + _, err := hasher.Write([]byte(input)) + if err != nil { + panic(err) + } + h := hasher.Sum(nil) + return lo.Map(h, func(b byte, _ int) float32 { + return float32(b) + }) +} diff --git a/common/mock/openai_test.go b/common/mock/openai_test.go index 450d2c6f6..592d1ac7a 100644 --- a/common/mock/openai_test.go +++ b/common/mock/openai_test.go @@ -16,12 +16,13 @@ package mock import ( "context" - "github.com/juju/errors" - "github.com/sashabaranov/go-openai" - "github.com/stretchr/testify/suite" "io" "strings" "testing" + + "github.com/juju/errors" + "github.com/sashabaranov/go-openai" + "github.com/stretchr/testify/suite" ) type OpenAITestSuite struct { @@ -97,7 +98,6 @@ func (suite *OpenAITestSuite) TestChatCompletionStream() { } func (suite *OpenAITestSuite) TestEmbeddings() { - suite.server.Embeddings([]float32{1, 2, 3}) resp, err := 
suite.client.CreateEmbeddings( context.Background(), openai.EmbeddingRequest{ @@ -106,7 +106,7 @@ func (suite *OpenAITestSuite) TestEmbeddings() { }, ) suite.NoError(err) - suite.Equal([]float32{1, 2, 3}, resp.Data[0].Embedding) + suite.Equal([]float32{139, 26, 153, 83, 196, 97, 18, 150, 168, 39, 171, 248, 196, 120, 4, 215}, resp.Data[0].Embedding) } func TestOpenAITestSuite(t *testing.T) { diff --git a/config/config.go b/config/config.go index 76f718c20..2f524f2cd 100644 --- a/config/config.go +++ b/config/config.go @@ -150,7 +150,7 @@ type NeighborsConfig struct { type ItemToItemConfig struct { Name string `mapstructure:"name" json:"name"` - Type string `mapstructure:"type" json:"type" validate:"oneof=embedding tags users llm"` + Type string `mapstructure:"type" json:"type" validate:"oneof=embedding tags users chat"` Column string `mapstructure:"column" json:"column" validate:"item_expr"` Prompt string `mapstructure:"prompt" json:"prompt"` } diff --git a/go.mod b/go.mod index 5d232fe0c..3cb59abb0 100644 --- a/go.mod +++ b/go.mod @@ -22,7 +22,7 @@ require ( github.com/go-viper/mapstructure/v2 v2.2.1 github.com/google/uuid v1.6.0 github.com/gorilla/securecookie v1.1.1 - github.com/gorse-io/dashboard v0.0.0-20250209091713-a70341e78d48 + github.com/gorse-io/dashboard v0.0.0-20250214134211-90d95a512041 github.com/gorse-io/gorse-go v0.5.0-alpha.1 github.com/haxii/go-swagger-ui v0.0.0-20210203093335-a63a6bbde946 github.com/jaswdr/faker v1.16.0 @@ -110,7 +110,6 @@ require ( github.com/golang/protobuf v1.5.3 // indirect github.com/golang/snappy v0.0.4 // indirect github.com/google/flatbuffers v2.0.6+incompatible // indirect - github.com/google/pprof v0.0.0-20240827171923-fa2c70bbbfe5 // indirect github.com/grpc-ecosystem/grpc-gateway/v2 v2.22.0 // indirect github.com/hashicorp/go-version v1.6.0 // indirect github.com/hashicorp/hcl v1.0.0 // indirect @@ -152,6 +151,7 @@ require ( github.com/sagikazarmark/locafero v0.4.0 // indirect github.com/sagikazarmark/slog-shim v0.1.0 
// indirect github.com/shopspring/decimal v1.3.1 // indirect + github.com/sirupsen/logrus v1.9.3 // indirect github.com/sourcegraph/conc v0.3.0 // indirect github.com/spf13/afero v1.11.0 // indirect github.com/spf13/cast v1.6.0 // indirect diff --git a/go.sum b/go.sum index 0c136e0ed..594cd32ab 100644 --- a/go.sum +++ b/go.sum @@ -35,6 +35,8 @@ gioui.org v0.0.0-20210308172011-57750fc8a0a6/go.mod h1:RSH6KIUZ0p2xy5zHDxgAM4zum git.sr.ht/~sbinet/gg v0.3.1/go.mod h1:KGYtlADtqsqANL9ueOFkWymvzUvLMQllU5Ixo+8v3pc= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= +github.com/MakeNowJust/heredoc v1.0.0 h1:cXCdzVdstXyiTqTvfqk9SDHpKNjxuom+DOlyEeQ4pzQ= +github.com/MakeNowJust/heredoc v1.0.0/go.mod h1:mG5amYoWBHf8vpLOuehzbGGw0EHxpZZ6lCpQ4fNJ8LE= github.com/Masterminds/semver/v3 v3.1.1 h1:hLg3sBzpNErnxhQtUy/mmLR2I9foDujNK030IGemrRc= github.com/Masterminds/semver/v3 v3.1.1/go.mod h1:VPu/7SZ7ePZ3QOrcuXROw5FAcLl4a0cBrbBpGY/8hQs= github.com/PuerkitoBio/purell v1.1.1/go.mod h1:c11w/QuzBsJSee3cPx9rAFu61PvFxuPbtSwDGJws/X0= @@ -211,6 +213,8 @@ github.com/go-resty/resty/v2 v2.16.3/go.mod h1:hkJtXbA2iKHzJheXYvQ8snQES5ZLGKMwQ github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE= github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= +github.com/go-task/slim-sprig/v3 v3.0.0 h1:sUs3vkvUymDpBKi3qH1YSqBQk9+9D/8M2mN1vB6EwHI= +github.com/go-task/slim-sprig/v3 v3.0.0/go.mod h1:W848ghGpv3Qj3dhTPRyJypKRiqCdHZiAzKg9hl15HA8= github.com/go-viper/mapstructure/v2 v2.2.1 h1:ZAaOCxANMuZx5RCeg0mBdEZk7DZasvvZIxtHqx8aGss= github.com/go-viper/mapstructure/v2 v2.2.1/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM= github.com/gofrs/uuid v4.0.0+incompatible 
h1:1SD/1F5pU8p29ybwgQSwpQk+mwdRrXCYuPhW6m+TnJw= @@ -304,8 +308,8 @@ github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyC github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4= github.com/gorse-io/clickhouse v0.3.3-0.20220715124633-688011a495bb h1:z/oOWE+Vy0PLcwIulZmIug4FtmvE3dJ1YOGprLeHwwY= github.com/gorse-io/clickhouse v0.3.3-0.20220715124633-688011a495bb/go.mod h1:iILWzbul8U+gsf4kqbheF2QzBmdvVp63mloGGK8emDI= -github.com/gorse-io/dashboard v0.0.0-20250209091713-a70341e78d48 h1:kfCK07ae/+NvxlcPqh0SpaXxkDlceqSmamsX7t/E4+w= -github.com/gorse-io/dashboard v0.0.0-20250209091713-a70341e78d48/go.mod h1:lv2bu311bjIJeRfY+6hiIaw20M6fLxT4ma9Ye+bpwGY= +github.com/gorse-io/dashboard v0.0.0-20250214134211-90d95a512041 h1:Uzv+3PZKKatcS8O7G6v3gUXvu9dfHWqOz8EmkUbGlaw= +github.com/gorse-io/dashboard v0.0.0-20250214134211-90d95a512041/go.mod h1:q98umjWRGQ3Qi1BPksKsy0nvxwKb8V+W4d4XRegxKYg= github.com/gorse-io/gorgonia v0.0.0-20230817132253-6dd1dbf95849 h1:Hwywr6NxzYeZYn35KwOsw7j8ZiMT60TBzpbn1MbEido= github.com/gorse-io/gorgonia v0.0.0-20230817132253-6dd1dbf95849/go.mod h1:TtVGAt7ENNmgBnC0JA68CAjIDCEtcqaRHvnkAWJ/Fu0= github.com/gorse-io/gorse-go v0.5.0-alpha.1 h1:QBWKGAbSKNAWnieXVIdQiE0lLGvKXfFFAFPOQEkPW/E= @@ -488,7 +492,13 @@ github.com/montanaflynn/stats v0.7.1/go.mod h1:etXPPgVO6n31NxCd9KQUMvCM+ve0ruNzt github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= +github.com/nikolalohinski/gonja/v2 v2.3.3 h1:5cTcmz0i/DwJl67US8Rvnb4OkBXB5V5OWd5IIAPPkXw= github.com/nikolalohinski/gonja/v2 v2.3.3/go.mod h1:8KC3RlefxnOaY5P4rH5erdwV0/owS83U615cSnDLYFs= +github.com/onsi/ginkgo v1.16.5 
h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE= +github.com/onsi/ginkgo/v2 v2.20.1 h1:YlVIbqct+ZmnEph770q9Q7NVAz4wwIiVNahee6JyUzo= +github.com/onsi/ginkgo/v2 v2.20.1/go.mod h1:lG9ey2Z29hR41WMVthyJBGUBcBhGOtoPF2VFMvBXFCI= +github.com/onsi/gomega v1.35.1 h1:Cwbd75ZBPxFSuZ6T+rN/WCb/gOc6YgFBXLlZLhC7Ds4= +github.com/onsi/gomega v1.35.1/go.mod h1:PvZbdDc8J6XJEpDK4HCuRBm8a6Fzp9/DmhC9C7yFlog= github.com/openzipkin/zipkin-go v0.4.1 h1:kNd/ST2yLLWhaWrkgchya40TJabe8Hioj9udfPcEO5A= github.com/openzipkin/zipkin-go v0.4.1/go.mod h1:qY0VqDSN1pOBN94dBc6w2GJlWLiovAyg7Qt6/I9HecM= github.com/orcaman/concurrent-map v1.0.0 h1:I/2A2XPCb4IuQWcQhBhSwGfiuybl/J0ev9HDbW65HOY= @@ -585,6 +595,8 @@ github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPx github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMBDgk/93Q= github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88= +github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= +github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo= github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0= github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8= @@ -904,6 +916,7 @@ golang.org/x/sys v0.0.0-20220114195835-da31bd327af9/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220330033206-e17cdc41300f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys 
v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA= diff --git a/logics/item_to_item.go b/logics/item_to_item.go index 0801a25ae..857952e95 100644 --- a/logics/item_to_item.go +++ b/logics/item_to_item.go @@ -15,15 +15,21 @@ package logics import ( + "context" "errors" + "fmt" "sort" + "strings" "time" "github.com/chewxy/math32" mapset "github.com/deckarep/golang-set/v2" "github.com/expr-lang/expr" "github.com/expr-lang/expr/vm" + "github.com/nikolalohinski/gonja/v2" + "github.com/nikolalohinski/gonja/v2/exec" "github.com/samber/lo" + "github.com/sashabaranov/go-openai" "github.com/zhenghaoz/gorse/base/floats" "github.com/zhenghaoz/gorse/base/log" "github.com/zhenghaoz/gorse/common/ann" @@ -35,14 +41,15 @@ import ( ) type ItemToItemOptions struct { - TagsIDF []float32 - UsersIDF []float32 + TagsIDF []float32 + UsersIDF []float32 + OpenAIConfig config.OpenAIConfig } type ItemToItem interface { Items() []*data.Item Push(item *data.Item, feedback []dataset.ID) - PopAll(callback func(itemId string, score []cache.Score)) + PopAll(i int) []cache.Score } func NewItemToItem(cfg config.ItemToItemConfig, n int, timestamp time.Time, opts *ItemToItemOptions) (ItemToItem, error) { @@ -64,6 +71,11 @@ func NewItemToItem(cfg config.ItemToItemConfig, n int, timestamp time.Time, opts return nil, errors.New("tags and users IDF are required for auto item-to-item") } return newAutoItemToItem(cfg, n, timestamp, opts.TagsIDF, opts.UsersIDF) + case "chat": + if opts == nil || opts.OpenAIConfig.BaseURL == "" || opts.OpenAIConfig.AuthToken == "" { + return nil, errors.New("OpenAI config is required for chat item-to-item") + } + return newChatItemToItem(cfg, n, timestamp, opts.OpenAIConfig) default: return nil, errors.New("invalid item-to-item type") } @@ -82,22 +94,20 @@ func (b *baseItemToItem[T]) 
Items() []*data.Item { return b.items } -func (b *baseItemToItem[T]) PopAll(callback func(itemId string, score []cache.Score)) { - for index, item := range b.items { - scores, err := b.index.SearchIndex(index, b.n+1, true) - if err != nil { - log.Logger().Error("failed to search index", zap.Error(err)) - return - } - callback(item.ItemId, lo.Map(scores, func(v lo.Tuple2[int, float32], _ int) cache.Score { - return cache.Score{ - Id: b.items[v.A].ItemId, - Categories: b.items[v.A].Categories, - Score: -float64(v.B), - Timestamp: b.timestamp, - } - })) +func (b *baseItemToItem[T]) PopAll(i int) []cache.Score { + scores, err := b.index.SearchIndex(i, b.n+1, true) + if err != nil { + log.Logger().Error("failed to search index", zap.Error(err)) + return nil } + return lo.Map(scores, func(v lo.Tuple2[int, float32], _ int) cache.Score { + return cache.Score{ + Id: b.items[v.A].ItemId, + Categories: b.items[v.A].Categories, + Score: -float64(v.B), + Timestamp: b.timestamp, + } + }) } type embeddingItemToItem struct { @@ -340,3 +350,91 @@ func flatten(o any, tSet mapset.Set[dataset.ID]) { } } } + +type chatItemToItem struct { + *embeddingItemToItem + template *exec.Template + client *openai.Client + chatModel string + embeddingModel string +} + +func newChatItemToItem(cfg config.ItemToItemConfig, n int, timestamp time.Time, openaiConfig config.OpenAIConfig) (*chatItemToItem, error) { + // create embedding item-to-item recommender + embedding, err := newEmbeddingItemToItem(cfg, n, timestamp) + if err != nil { + return nil, err + } + // parse template + template, err := gonja.FromString(cfg.Prompt) + if err != nil { + return nil, err + } + // create openai client + clientConfig := openai.DefaultConfig(openaiConfig.AuthToken) + clientConfig.BaseURL = openaiConfig.BaseURL + return &chatItemToItem{ + embeddingItemToItem: embedding, + template: template, + client: openai.NewClientWithConfig(clientConfig), + chatModel: openaiConfig.ChatCompletionModel, + embeddingModel: 
openaiConfig.EmbeddingsModel, + }, nil +} + +func (g *chatItemToItem) PopAll(i int) []cache.Score { + // render template + var buf strings.Builder + ctx := exec.NewContext(map[string]any{ + "item": g.items[i], + }) + if err := g.template.Execute(&buf, ctx); err != nil { + log.Logger().Error("failed to execute template", zap.Error(err)) + return nil + } + fmt.Println(buf.String()) + // chat completion + resp, err := g.client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{ + Model: g.chatModel, + Messages: []openai.ChatCompletionMessage{{ + Role: openai.ChatMessageRoleUser, + Content: buf.String(), + }}, + }) + if err != nil { + log.Logger().Error("failed to chat completion", zap.Error(err)) + return nil + } + message := stripThink(resp.Choices[0].Message.Content) + // message embedding + resp2, err := g.client.CreateEmbeddings(context.Background(), openai.EmbeddingRequest{ + Input: message, + Model: openai.EmbeddingModel(g.embeddingModel), + }) + if err != nil { + log.Logger().Error("failed to create embeddings", zap.Error(err)) + return nil + } + embedding := resp2.Data[0].Embedding + // search index + scores := g.index.SearchVector(embedding, g.n+1, true) + return lo.Map(scores, func(v lo.Tuple2[int, float32], _ int) cache.Score { + return cache.Score{ + Id: g.items[v.A].ItemId, + Categories: g.items[v.A].Categories, + Score: -float64(v.B), + Timestamp: g.timestamp, + } + }) +} + +func stripThink(s string) string { + if len(s) < 7 || s[:7] != "\x3cthink\x3e" { + return s + } + end := strings.Index(s, "\x3c/think\x3e") + if end == -1 { + return s + } + return s[end+8:] +} diff --git a/logics/item_to_item_test.go b/logics/item_to_item_test.go index 70a6864ce..8a8ba351a 100644 --- a/logics/item_to_item_test.go +++ b/logics/item_to_item_test.go @@ -20,9 +20,10 @@ import ( "time" "github.com/stretchr/testify/suite" + "github.com/zhenghaoz/gorse/base/floats" + "github.com/zhenghaoz/gorse/common/mock" "github.com/zhenghaoz/gorse/config" "github.com/zhenghaoz/gorse/dataset"
- "github.com/zhenghaoz/gorse/storage/cache" "github.com/zhenghaoz/gorse/storage/data" ) @@ -97,12 +98,7 @@ func (suite *ItemToItemTestSuite) TestEmbedding() { }, nil) } - var scores []cache.Score - item2item.PopAll(func(itemId string, score []cache.Score) { - if itemId == "0" { - scores = score - } - }) + scores := item2item.PopAll(0) suite.Len(scores, 10) for i := 1; i <= 10; i++ { suite.Equal(strconv.Itoa(i), scores[i-1].Id) @@ -131,12 +127,7 @@ func (suite *ItemToItemTestSuite) TestTags() { }, nil) } - var scores []cache.Score - item2item.PopAll(func(itemId string, score []cache.Score) { - if itemId == "0" { - scores = score - } - }) + scores := item2item.PopAll(0) suite.Len(scores, 10) for i := 1; i <= 10; i++ { suite.Equal(strconv.Itoa(i), scores[i-1].Id) @@ -160,12 +151,7 @@ func (suite *ItemToItemTestSuite) TestUsers() { item2item.Push(&data.Item{ItemId: strconv.Itoa(i)}, feedback) } - var scores []cache.Score - item2item.PopAll(func(itemId string, score []cache.Score) { - if itemId == "0" { - scores = score - } - }) + scores := item2item.PopAll(0) suite.Len(scores, 10) for i := 1; i <= 10; i++ { suite.Equal(strconv.Itoa(i), scores[i-1].Id) @@ -198,24 +184,57 @@ func (suite *ItemToItemTestSuite) TestAuto() { item2item.Push(item, feedback) } - var scores0, scores1 []cache.Score - item2item.PopAll(func(itemId string, score []cache.Score) { - if itemId == "0" { - scores0 = score - } else if itemId == "1" { - scores1 = score - } - }) + scores0 := item2item.PopAll(0) suite.Len(scores0, 10) for i := 1; i <= 10; i++ { suite.Equal(strconv.Itoa(i*2), scores0[i-1].Id) } + scores1 := item2item.PopAll(1) suite.Len(scores1, 10) for i := 1; i <= 10; i++ { suite.Equal(strconv.Itoa(i*2+1), scores1[i-1].Id) } } +func (suite *ItemToItemTestSuite) TestChat() { + mockAI := mock.NewOpenAIServer() + go func() { + _ = mockAI.Start() + }() + mockAI.Ready() + defer mockAI.Close() + + timestamp := time.Now() + item2item, err := newChatItemToItem(config.ItemToItemConfig{ + Column: 
"item.Labels.embeddings", + Prompt: "Please generate similar items for {{ item.Labels.title }}.", + }, 10, timestamp, config.OpenAIConfig{ + BaseURL: mockAI.BaseURL(), + AuthToken: mockAI.AuthToken(), + ChatCompletionModel: "deepseek-r1", + EmbeddingsModel: "text-similarity-ada-001", + }) + suite.NoError(err) + + for i := 0; i < 100; i++ { + embedding := mock.Hash("Please generate similar items for item_0.") + floats.AddConst(embedding, float32(i)) + item2item.Push(&data.Item{ + ItemId: strconv.Itoa(i), + Labels: map[string]any{ + "title": "item_" + strconv.Itoa(i), + "embeddings": embedding, + }, + }, nil) + } + + scores := item2item.PopAll(0) + suite.Len(scores, 10) + for i := 1; i <= 10; i++ { + suite.Equal(strconv.Itoa(i), scores[i-1].Id) + } +} + func TestItemToItem(t *testing.T) { suite.Run(t, new(ItemToItemTestSuite)) } diff --git a/logics/user_to_user.go b/logics/user_to_user.go index 423b1037e..a23fcf5bc 100644 --- a/logics/user_to_user.go +++ b/logics/user_to_user.go @@ -43,7 +43,7 @@ type UserToUserOptions struct { type UserToUser interface { Users() []*data.User Push(user *data.User, feedback []dataset.ID) - PopAll(callback func(userId string, score []cache.Score)) + PopAll(i int) []cache.Score } func NewUserToUser(cfg UserToUserConfig, n int, timestamp time.Time, opts *UserToUserOptions) (UserToUser, error) { @@ -82,21 +82,19 @@ func (b *baseUserToUser[T]) Users() []*data.User { return b.users } -func (b *baseUserToUser[T]) PopAll(callback func(userId string, score []cache.Score)) { - for index, user := range b.users { - scores, err := b.index.SearchIndex(index, b.n+1, true) - if err != nil { - log.Logger().Error("failed to search index", zap.Error(err)) - return - } - callback(user.UserId, lo.Map(scores, func(v lo.Tuple2[int, float32], _ int) cache.Score { - return cache.Score{ - Id: b.users[v.A].UserId, - Score: -float64(v.B), - Timestamp: b.timestamp, - } - })) +func (b *baseUserToUser[T]) PopAll(i int) []cache.Score { + scores, err := 
b.index.SearchIndex(i, b.n+1, true) + if err != nil { + log.Logger().Error("failed to search index", zap.Error(err)) + return nil } + return lo.Map(scores, func(v lo.Tuple2[int, float32], _ int) cache.Score { + return cache.Score{ + Id: b.users[v.A].UserId, + Score: -float64(v.B), + Timestamp: b.timestamp, + } + }) } type embeddingUserToUser struct { diff --git a/logics/user_to_user_test.go b/logics/user_to_user_test.go index 16101b283..fc61ed933 100644 --- a/logics/user_to_user_test.go +++ b/logics/user_to_user_test.go @@ -21,7 +21,6 @@ import ( "github.com/stretchr/testify/suite" "github.com/zhenghaoz/gorse/dataset" - "github.com/zhenghaoz/gorse/storage/cache" "github.com/zhenghaoz/gorse/storage/data" ) @@ -45,12 +44,7 @@ func (suite *UserToUserTestSuite) TestEmbedding() { }, nil) } - var scores []cache.Score - user2user.PopAll(func(userId string, score []cache.Score) { - if userId == "0" { - scores = score - } - }) + scores := user2user.PopAll(0) suite.Len(scores, 10) for i := 1; i <= 10; i++ { suite.Equal(strconv.Itoa(i), scores[i-1].Id) @@ -79,12 +73,7 @@ func (suite *UserToUserTestSuite) TestTags() { }, nil) } - var scores []cache.Score - user2user.PopAll(func(userId string, score []cache.Score) { - if userId == "0" { - scores = score - } - }) + scores := user2user.PopAll(0) suite.Len(scores, 10) for i := 1; i <= 10; i++ { suite.Equal(strconv.Itoa(i), scores[i-1].Id) @@ -108,12 +97,7 @@ func (suite *UserToUserTestSuite) TestItems() { user2user.Push(&data.User{UserId: strconv.Itoa(i)}, feedback) } - var scores []cache.Score - user2user.PopAll(func(userId string, score []cache.Score) { - if userId == "0" { - scores = score - } - }) + scores := user2user.PopAll(0) suite.Len(scores, 10) for i := 1; i <= 10; i++ { suite.Equal(strconv.Itoa(i), scores[i-1].Id) @@ -146,15 +130,9 @@ func (suite *UserToUserTestSuite) TestAuto() { user2user.Push(user, feedback) } - var scores0, scores1 []cache.Score - user2user.PopAll(func(userId string, score []cache.Score) { - if 
userId == "0" { - scores0 = score - } else if userId == "1" { - scores1 = score - } - }) + scores0 := user2user.PopAll(0) suite.Len(scores0, 10) + scores1 := user2user.PopAll(1) suite.Len(scores1, 10) } diff --git a/master/rest.go b/master/rest.go index 7bd1299fd..ff23af014 100644 --- a/master/rest.go +++ b/master/rest.go @@ -430,6 +430,11 @@ func (m *Master) handleUserInfo(request *restful.Request, response *restful.Resp server.Ok(response, UserInfo{ Name: m.Config.Master.DashboardUserName, }) + } else { + response.Header().Set("Content-Type", "application/json") + if _, err := response.Write([]byte("null")); err != nil { + log.ResponseLogger(response).Error("failed to write response", zap.Error(err)) + } } } diff --git a/master/tasks.go b/master/tasks.go index fb0083870..b73ce668f 100644 --- a/master/tasks.go +++ b/master/tasks.go @@ -1012,8 +1012,9 @@ func (m *Master) updateItemToItem(dataset *dataset.Dataset) error { itemToItemRecommenders := make([]logics.ItemToItem, 0, len(itemToItemConfigs)) for _, cfg := range itemToItemConfigs { recommender, err := logics.NewItemToItem(cfg, m.Config.Recommend.CacheSize, dataset.GetTimestamp(), &logics.ItemToItemOptions{ - TagsIDF: dataset.GetItemColumnValuesIDF(), - UsersIDF: dataset.GetUserIDF(), + TagsIDF: dataset.GetItemColumnValuesIDF(), + UsersIDF: dataset.GetUserIDF(), + OpenAIConfig: m.Config.OpenAI, }) if err != nil { return errors.Trace(err) @@ -1033,31 +1034,35 @@ func (m *Master) updateItemToItem(dataset *dataset.Dataset) error { // Save item-to-item recommendations to cache for i, recommender := range itemToItemRecommenders { - recommender.PopAll(func(itemId string, score []cache.Score) { + for j, item := range recommender.Items() { itemToItemConfig := itemToItemConfigs[i] - if m.needUpdateItemToItem(itemId, itemToItemConfigs[i]) { + if m.needUpdateItemToItem(item.ItemId, itemToItemConfig) { + score := recommender.PopAll(j) + if score == nil { + continue + } log.Logger().Debug("update item-to-item 
recommendation", - zap.String("item_id", itemId), + zap.String("item_id", item.ItemId), zap.String("name", itemToItemConfig.Name), zap.Int("n_recommendations", len(score))) // Save item-to-item recommendation to cache - if err := m.CacheClient.AddScores(ctx, cache.ItemToItem, cache.Key(itemToItemConfig.Name, itemId), score); err != nil { + if err := m.CacheClient.AddScores(ctx, cache.ItemToItem, cache.Key(itemToItemConfig.Name, item.ItemId), score); err != nil { log.Logger().Error("failed to save item-to-item recommendation to cache", - zap.String("item_id", itemId), zap.Error(err)) - return + zap.String("item_id", item.ItemId), zap.Error(err)) + continue } // Save item-to-item digest and last update time to cache if err := m.CacheClient.Set(ctx, - cache.String(cache.Key(cache.ItemToItemDigest, itemToItemConfig.Name, itemId), itemToItemConfig.Hash()), - cache.Time(cache.Key(cache.ItemToItemUpdateTime, itemToItemConfig.Name, itemId), time.Now()), + cache.String(cache.Key(cache.ItemToItemDigest, itemToItemConfig.Name, item.ItemId), itemToItemConfig.Hash()), + cache.Time(cache.Key(cache.ItemToItemUpdateTime, itemToItemConfig.Name, item.ItemId), time.Now()), ); err != nil { log.Logger().Error("failed to save item-to-item digest to cache", - zap.String("item_id", itemId), zap.Error(err)) - return + zap.String("item_id", item.ItemId), zap.Error(err)) + continue } } span.Add(1) - }) + } } return nil } @@ -1132,26 +1137,31 @@ func (m *Master) updateUserToUser(dataset *dataset.Dataset) error { } // Save user-to-user recommendations to cache - userToUserRecommender.PopAll(func(userId string, score []cache.Score) { - if m.needUpdateUserToUser(userId) { + for j, user := range userToUserRecommender.Users() { + if m.needUpdateUserToUser(user.UserId) { + score := userToUserRecommender.PopAll(j) + if score == nil { + continue + } log.Logger().Debug("update user neighbors", - zap.String("user_id", userId), + zap.String("user_id", user.UserId), zap.Int("n_recommendations", 
len(score))) // Save user-to-user recommendations to cache - if err := m.CacheClient.AddScores(ctx, cache.UserToUser, cache.Key(cache.Neighbors, userId), score); err != nil { - log.Logger().Error("failed to save user neighbors to cache", zap.String("user_id", userId), zap.Error(err)) - return + if err := m.CacheClient.AddScores(ctx, cache.UserToUser, cache.Key(cache.Neighbors, user.UserId), score); err != nil { + log.Logger().Error("failed to save user neighbors to cache", zap.String("user_id", user.UserId), zap.Error(err)) + continue } // Save user-to-user digest and last update time to cache if err := m.CacheClient.Set(ctx, - cache.String(cache.Key(cache.UserToUserDigest, cache.Key(cache.Neighbors, userId)), m.Config.UserNeighborDigest()), - cache.Time(cache.Key(cache.UserToUserUpdateTime, cache.Key(cache.Neighbors, userId)), time.Now()), + cache.String(cache.Key(cache.UserToUserDigest, cache.Key(cache.Neighbors, user.UserId)), m.Config.UserNeighborDigest()), + cache.Time(cache.Key(cache.UserToUserUpdateTime, cache.Key(cache.Neighbors, user.UserId)), time.Now()), ); err != nil { - log.Logger().Error("failed to save user neighbors digest to cache", zap.String("user_id", userId), zap.Error(err)) - return + log.Logger().Error("failed to save user neighbors digest to cache", zap.String("user_id", user.UserId), zap.Error(err)) + continue } } - }) + span.Add(1) + } return nil }