diff --git a/common/mock/openai.go b/common/mock/openai.go
index d347f5b33..74bd0f6b1 100644
--- a/common/mock/openai.go
+++ b/common/mock/openai.go
@@ -16,12 +16,15 @@ package mock
import (
"bytes"
+ "crypto/md5"
"encoding/json"
"fmt"
- "github.com/emicklei/go-restful/v3"
- "github.com/sashabaranov/go-openai"
"net"
"net/http"
+
+ "github.com/emicklei/go-restful/v3"
+ "github.com/samber/lo"
+ "github.com/sashabaranov/go-openai"
)
type OpenAIServer struct {
@@ -29,8 +32,6 @@ type OpenAIServer struct {
httpServer *http.Server
authToken string
ready chan struct{}
-
- mockEmbeddings []float32
}
func NewOpenAIServer() *OpenAIServer {
@@ -81,10 +82,6 @@ func (s *OpenAIServer) Close() error {
return s.httpServer.Close()
}
-func (s *OpenAIServer) Embeddings(embeddings []float32) {
- s.mockEmbeddings = embeddings
-}
-
func (s *OpenAIServer) chatCompletion(req *restful.Request, resp *restful.Response) {
var r openai.ChatCompletionRequest
err := req.ReadEntity(&r)
@@ -92,8 +89,11 @@ func (s *OpenAIServer) chatCompletion(req *restful.Request, resp *restful.Respon
_ = resp.WriteError(http.StatusBadRequest, err)
return
}
+ content := r.Messages[0].Content
+ if r.Model == "deepseek-r1" {
+ content = "\u003cthink\u003eTo be or not to be, that is the question.\u003c/think\u003e" + content
+ }
if r.Stream {
- content := r.Messages[0].Content
for i := 0; i < len(content); i += 8 {
buf := bytes.NewBuffer(nil)
buf.WriteString("data: ")
@@ -112,7 +112,7 @@ func (s *OpenAIServer) chatCompletion(req *restful.Request, resp *restful.Respon
_ = resp.WriteEntity(openai.ChatCompletionResponse{
Choices: []openai.ChatCompletionChoice{{
Message: openai.ChatCompletionMessage{
- Content: r.Messages[0].Content,
+ Content: content,
},
}},
})
@@ -120,15 +120,35 @@ func (s *OpenAIServer) chatCompletion(req *restful.Request, resp *restful.Respon
}
func (s *OpenAIServer) embeddings(req *restful.Request, resp *restful.Response) {
+ // parse request
var r openai.EmbeddingRequest
err := req.ReadEntity(&r)
if err != nil {
_ = resp.WriteError(http.StatusBadRequest, err)
return
}
+ input, ok := r.Input.(string)
+ if !ok {
+ _ = resp.WriteError(http.StatusBadRequest, fmt.Errorf("invalid input type"))
+ return
+ }
+
+ // write response
_ = resp.WriteEntity(openai.EmbeddingResponse{
Data: []openai.Embedding{{
- Embedding: s.mockEmbeddings,
+ Embedding: Hash(input),
}},
})
}
+
+// Hash derives a deterministic mock embedding from input: the 16 bytes of
+// the MD5 digest, each converted to a float32 in the range [0, 255]. The
+// same input always produces the same vector, which lets tests predict the
+// embeddings the mock server returns.
+func Hash(input string) []float32 {
+	// MD5 is used only for determinism here, not for security.
+	digest := md5.Sum([]byte(input))
+	return lo.Map(digest[:], func(b byte, _ int) float32 {
+		return float32(b)
+	})
+}
diff --git a/common/mock/openai_test.go b/common/mock/openai_test.go
index 450d2c6f6..592d1ac7a 100644
--- a/common/mock/openai_test.go
+++ b/common/mock/openai_test.go
@@ -16,12 +16,13 @@ package mock
import (
"context"
- "github.com/juju/errors"
- "github.com/sashabaranov/go-openai"
- "github.com/stretchr/testify/suite"
"io"
"strings"
"testing"
+
+ "github.com/juju/errors"
+ "github.com/sashabaranov/go-openai"
+ "github.com/stretchr/testify/suite"
)
type OpenAITestSuite struct {
@@ -97,7 +98,6 @@ func (suite *OpenAITestSuite) TestChatCompletionStream() {
}
func (suite *OpenAITestSuite) TestEmbeddings() {
- suite.server.Embeddings([]float32{1, 2, 3})
resp, err := suite.client.CreateEmbeddings(
context.Background(),
openai.EmbeddingRequest{
@@ -106,7 +106,7 @@ func (suite *OpenAITestSuite) TestEmbeddings() {
},
)
suite.NoError(err)
- suite.Equal([]float32{1, 2, 3}, resp.Data[0].Embedding)
+ suite.Equal([]float32{139, 26, 153, 83, 196, 97, 18, 150, 168, 39, 171, 248, 196, 120, 4, 215}, resp.Data[0].Embedding)
}
func TestOpenAITestSuite(t *testing.T) {
diff --git a/config/config.go b/config/config.go
index 76f718c20..2f524f2cd 100644
--- a/config/config.go
+++ b/config/config.go
@@ -150,7 +150,7 @@ type NeighborsConfig struct {
type ItemToItemConfig struct {
Name string `mapstructure:"name" json:"name"`
- Type string `mapstructure:"type" json:"type" validate:"oneof=embedding tags users llm"`
+ Type string `mapstructure:"type" json:"type" validate:"oneof=embedding tags users chat"`
Column string `mapstructure:"column" json:"column" validate:"item_expr"`
Prompt string `mapstructure:"prompt" json:"prompt"`
}
diff --git a/go.mod b/go.mod
index 5d232fe0c..3cb59abb0 100644
--- a/go.mod
+++ b/go.mod
@@ -22,7 +22,7 @@ require (
github.com/go-viper/mapstructure/v2 v2.2.1
github.com/google/uuid v1.6.0
github.com/gorilla/securecookie v1.1.1
- github.com/gorse-io/dashboard v0.0.0-20250209091713-a70341e78d48
+ github.com/gorse-io/dashboard v0.0.0-20250214134211-90d95a512041
github.com/gorse-io/gorse-go v0.5.0-alpha.1
github.com/haxii/go-swagger-ui v0.0.0-20210203093335-a63a6bbde946
github.com/jaswdr/faker v1.16.0
@@ -110,7 +110,6 @@ require (
github.com/golang/protobuf v1.5.3 // indirect
github.com/golang/snappy v0.0.4 // indirect
github.com/google/flatbuffers v2.0.6+incompatible // indirect
- github.com/google/pprof v0.0.0-20240827171923-fa2c70bbbfe5 // indirect
github.com/grpc-ecosystem/grpc-gateway/v2 v2.22.0 // indirect
github.com/hashicorp/go-version v1.6.0 // indirect
github.com/hashicorp/hcl v1.0.0 // indirect
@@ -152,6 +151,7 @@ require (
github.com/sagikazarmark/locafero v0.4.0 // indirect
github.com/sagikazarmark/slog-shim v0.1.0 // indirect
github.com/shopspring/decimal v1.3.1 // indirect
+ github.com/sirupsen/logrus v1.9.3 // indirect
github.com/sourcegraph/conc v0.3.0 // indirect
github.com/spf13/afero v1.11.0 // indirect
github.com/spf13/cast v1.6.0 // indirect
diff --git a/go.sum b/go.sum
index 0c136e0ed..594cd32ab 100644
--- a/go.sum
+++ b/go.sum
@@ -35,6 +35,8 @@ gioui.org v0.0.0-20210308172011-57750fc8a0a6/go.mod h1:RSH6KIUZ0p2xy5zHDxgAM4zum
git.sr.ht/~sbinet/gg v0.3.1/go.mod h1:KGYtlADtqsqANL9ueOFkWymvzUvLMQllU5Ixo+8v3pc=
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo=
+github.com/MakeNowJust/heredoc v1.0.0 h1:cXCdzVdstXyiTqTvfqk9SDHpKNjxuom+DOlyEeQ4pzQ=
+github.com/MakeNowJust/heredoc v1.0.0/go.mod h1:mG5amYoWBHf8vpLOuehzbGGw0EHxpZZ6lCpQ4fNJ8LE=
github.com/Masterminds/semver/v3 v3.1.1 h1:hLg3sBzpNErnxhQtUy/mmLR2I9foDujNK030IGemrRc=
github.com/Masterminds/semver/v3 v3.1.1/go.mod h1:VPu/7SZ7ePZ3QOrcuXROw5FAcLl4a0cBrbBpGY/8hQs=
github.com/PuerkitoBio/purell v1.1.1/go.mod h1:c11w/QuzBsJSee3cPx9rAFu61PvFxuPbtSwDGJws/X0=
@@ -211,6 +213,8 @@ github.com/go-resty/resty/v2 v2.16.3/go.mod h1:hkJtXbA2iKHzJheXYvQ8snQES5ZLGKMwQ
github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE=
github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg=
github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY=
+github.com/go-task/slim-sprig/v3 v3.0.0 h1:sUs3vkvUymDpBKi3qH1YSqBQk9+9D/8M2mN1vB6EwHI=
+github.com/go-task/slim-sprig/v3 v3.0.0/go.mod h1:W848ghGpv3Qj3dhTPRyJypKRiqCdHZiAzKg9hl15HA8=
github.com/go-viper/mapstructure/v2 v2.2.1 h1:ZAaOCxANMuZx5RCeg0mBdEZk7DZasvvZIxtHqx8aGss=
github.com/go-viper/mapstructure/v2 v2.2.1/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM=
github.com/gofrs/uuid v4.0.0+incompatible h1:1SD/1F5pU8p29ybwgQSwpQk+mwdRrXCYuPhW6m+TnJw=
@@ -304,8 +308,8 @@ github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyC
github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4=
github.com/gorse-io/clickhouse v0.3.3-0.20220715124633-688011a495bb h1:z/oOWE+Vy0PLcwIulZmIug4FtmvE3dJ1YOGprLeHwwY=
github.com/gorse-io/clickhouse v0.3.3-0.20220715124633-688011a495bb/go.mod h1:iILWzbul8U+gsf4kqbheF2QzBmdvVp63mloGGK8emDI=
-github.com/gorse-io/dashboard v0.0.0-20250209091713-a70341e78d48 h1:kfCK07ae/+NvxlcPqh0SpaXxkDlceqSmamsX7t/E4+w=
-github.com/gorse-io/dashboard v0.0.0-20250209091713-a70341e78d48/go.mod h1:lv2bu311bjIJeRfY+6hiIaw20M6fLxT4ma9Ye+bpwGY=
+github.com/gorse-io/dashboard v0.0.0-20250214134211-90d95a512041 h1:Uzv+3PZKKatcS8O7G6v3gUXvu9dfHWqOz8EmkUbGlaw=
+github.com/gorse-io/dashboard v0.0.0-20250214134211-90d95a512041/go.mod h1:q98umjWRGQ3Qi1BPksKsy0nvxwKb8V+W4d4XRegxKYg=
github.com/gorse-io/gorgonia v0.0.0-20230817132253-6dd1dbf95849 h1:Hwywr6NxzYeZYn35KwOsw7j8ZiMT60TBzpbn1MbEido=
github.com/gorse-io/gorgonia v0.0.0-20230817132253-6dd1dbf95849/go.mod h1:TtVGAt7ENNmgBnC0JA68CAjIDCEtcqaRHvnkAWJ/Fu0=
github.com/gorse-io/gorse-go v0.5.0-alpha.1 h1:QBWKGAbSKNAWnieXVIdQiE0lLGvKXfFFAFPOQEkPW/E=
@@ -488,7 +492,13 @@ github.com/montanaflynn/stats v0.7.1/go.mod h1:etXPPgVO6n31NxCd9KQUMvCM+ve0ruNzt
github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U=
github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U=
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno=
+github.com/nikolalohinski/gonja/v2 v2.3.3 h1:5cTcmz0i/DwJl67US8Rvnb4OkBXB5V5OWd5IIAPPkXw=
github.com/nikolalohinski/gonja/v2 v2.3.3/go.mod h1:8KC3RlefxnOaY5P4rH5erdwV0/owS83U615cSnDLYFs=
+github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE=
+github.com/onsi/ginkgo/v2 v2.20.1 h1:YlVIbqct+ZmnEph770q9Q7NVAz4wwIiVNahee6JyUzo=
+github.com/onsi/ginkgo/v2 v2.20.1/go.mod h1:lG9ey2Z29hR41WMVthyJBGUBcBhGOtoPF2VFMvBXFCI=
+github.com/onsi/gomega v1.35.1 h1:Cwbd75ZBPxFSuZ6T+rN/WCb/gOc6YgFBXLlZLhC7Ds4=
+github.com/onsi/gomega v1.35.1/go.mod h1:PvZbdDc8J6XJEpDK4HCuRBm8a6Fzp9/DmhC9C7yFlog=
github.com/openzipkin/zipkin-go v0.4.1 h1:kNd/ST2yLLWhaWrkgchya40TJabe8Hioj9udfPcEO5A=
github.com/openzipkin/zipkin-go v0.4.1/go.mod h1:qY0VqDSN1pOBN94dBc6w2GJlWLiovAyg7Qt6/I9HecM=
github.com/orcaman/concurrent-map v1.0.0 h1:I/2A2XPCb4IuQWcQhBhSwGfiuybl/J0ev9HDbW65HOY=
@@ -585,6 +595,8 @@ github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPx
github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMBDgk/93Q=
github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE=
github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88=
+github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
+github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo=
github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0=
github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8=
@@ -904,6 +916,7 @@ golang.org/x/sys v0.0.0-20220114195835-da31bd327af9/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20220330033206-e17cdc41300f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA=
diff --git a/logics/item_to_item.go b/logics/item_to_item.go
index 0801a25ae..857952e95 100644
--- a/logics/item_to_item.go
+++ b/logics/item_to_item.go
@@ -15,15 +15,21 @@
package logics
import (
+ "context"
"errors"
+ "fmt"
"sort"
+ "strings"
"time"
"github.com/chewxy/math32"
mapset "github.com/deckarep/golang-set/v2"
"github.com/expr-lang/expr"
"github.com/expr-lang/expr/vm"
+ "github.com/nikolalohinski/gonja/v2"
+ "github.com/nikolalohinski/gonja/v2/exec"
"github.com/samber/lo"
+ "github.com/sashabaranov/go-openai"
"github.com/zhenghaoz/gorse/base/floats"
"github.com/zhenghaoz/gorse/base/log"
"github.com/zhenghaoz/gorse/common/ann"
@@ -35,14 +41,15 @@ import (
)
type ItemToItemOptions struct {
- TagsIDF []float32
- UsersIDF []float32
+ TagsIDF []float32
+ UsersIDF []float32
+ OpenAIConfig config.OpenAIConfig
}
type ItemToItem interface {
Items() []*data.Item
Push(item *data.Item, feedback []dataset.ID)
- PopAll(callback func(itemId string, score []cache.Score))
+ PopAll(i int) []cache.Score
}
func NewItemToItem(cfg config.ItemToItemConfig, n int, timestamp time.Time, opts *ItemToItemOptions) (ItemToItem, error) {
@@ -64,6 +71,11 @@ func NewItemToItem(cfg config.ItemToItemConfig, n int, timestamp time.Time, opts
return nil, errors.New("tags and users IDF are required for auto item-to-item")
}
return newAutoItemToItem(cfg, n, timestamp, opts.TagsIDF, opts.UsersIDF)
+ case "chat":
+ if opts == nil || opts.OpenAIConfig.BaseURL == "" || opts.OpenAIConfig.AuthToken == "" {
+ return nil, errors.New("OpenAI config is required for chat item-to-item")
+ }
+ return newChatItemToItem(cfg, n, timestamp, opts.OpenAIConfig)
default:
return nil, errors.New("invalid item-to-item type")
}
@@ -82,22 +94,20 @@ func (b *baseItemToItem[T]) Items() []*data.Item {
return b.items
}
-func (b *baseItemToItem[T]) PopAll(callback func(itemId string, score []cache.Score)) {
- for index, item := range b.items {
- scores, err := b.index.SearchIndex(index, b.n+1, true)
- if err != nil {
- log.Logger().Error("failed to search index", zap.Error(err))
- return
- }
- callback(item.ItemId, lo.Map(scores, func(v lo.Tuple2[int, float32], _ int) cache.Score {
- return cache.Score{
- Id: b.items[v.A].ItemId,
- Categories: b.items[v.A].Categories,
- Score: -float64(v.B),
- Timestamp: b.timestamp,
- }
- }))
+func (b *baseItemToItem[T]) PopAll(i int) []cache.Score {
+ scores, err := b.index.SearchIndex(i, b.n+1, true)
+ if err != nil {
+ log.Logger().Error("failed to search index", zap.Error(err))
+ return nil
}
+ return lo.Map(scores, func(v lo.Tuple2[int, float32], _ int) cache.Score {
+ return cache.Score{
+ Id: b.items[v.A].ItemId,
+ Categories: b.items[v.A].Categories,
+ Score: -float64(v.B),
+ Timestamp: b.timestamp,
+ }
+ })
}
type embeddingItemToItem struct {
@@ -340,3 +350,91 @@ func flatten(o any, tSet mapset.Set[dataset.ID]) {
}
}
}
+
+type chatItemToItem struct { // item-to-item recommender that queries the embedding index with an LLM reply
+ *embeddingItemToItem // supplies the items, index, n and timestamp that PopAll reads
+ template *exec.Template // gonja template parsed from the configured prompt, rendered once per item
+ client *openai.Client // OpenAI-compatible client used for both chat completion and embeddings
+ chatModel string // model name sent with chat completion requests
+ embeddingModel string // model name sent with embedding requests
+}
+
+// newChatItemToItem builds a chat-based item-to-item recommender on top of an
+// embedding recommender; cfg.Prompt is parsed once as a gonja template.
+func newChatItemToItem(cfg config.ItemToItemConfig, n int, timestamp time.Time, openaiConfig config.OpenAIConfig) (*chatItemToItem, error) {
+	base, err := newEmbeddingItemToItem(cfg, n, timestamp)
+	if err != nil {
+		return nil, err
+	}
+	prompt, err := gonja.FromString(cfg.Prompt)
+	if err != nil {
+		return nil, err
+	}
+	// Point the OpenAI client at the configured endpoint.
+	clientCfg := openai.DefaultConfig(openaiConfig.AuthToken)
+	clientCfg.BaseURL = openaiConfig.BaseURL
+	return &chatItemToItem{
+		embeddingItemToItem: base,
+		template:            prompt,
+		client:              openai.NewClientWithConfig(clientCfg),
+		chatModel:           openaiConfig.ChatCompletionModel,
+		embeddingModel:      openaiConfig.EmbeddingsModel,
+	}, nil
+}
+
+// PopAll returns the item-to-item scores for the i-th pushed item, or nil on
+// failure (failures are logged). The prompt template is rendered for the item
+// and sent to the chat model; the reply, post-processed by stripThink, is
+// embedded and the resulting vector is used to search the embedding index.
+func (g *chatItemToItem) PopAll(i int) []cache.Score {
+	var buf strings.Builder
+	ctx := exec.NewContext(map[string]any{"item": g.items[i]})
+	if err := g.template.Execute(&buf, ctx); err != nil {
+		log.Logger().Error("failed to execute template", zap.Error(err))
+		return nil
+	}
+	resp, err := g.client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{
+		Model:    g.chatModel,
+		Messages: []openai.ChatCompletionMessage{{Role: openai.ChatMessageRoleUser, Content: buf.String()}},
+	})
+	if err == nil && len(resp.Choices) == 0 {
+		// An empty reply would make the Choices[0] access below panic.
+		err = fmt.Errorf("no choices returned for item %q", g.items[i].ItemId)
+	}
+	if err != nil {
+		log.Logger().Error("failed to chat completion", zap.Error(err))
+		return nil
+	}
+	message := stripThink(resp.Choices[0].Message.Content)
+	resp2, err := g.client.CreateEmbeddings(context.Background(), openai.EmbeddingRequest{
+		Input: message,
+		Model: openai.EmbeddingModel(g.embeddingModel),
+	})
+	if err == nil && len(resp2.Data) == 0 {
+		err = fmt.Errorf("no embeddings returned for item %q", g.items[i].ItemId)
+	}
+	if err != nil {
+		log.Logger().Error("failed to create embeddings", zap.Error(err))
+		return nil
+	}
+	scores := g.index.SearchVector(resp2.Data[0].Embedding, g.n+1, true)
+	return lo.Map(scores, func(v lo.Tuple2[int, float32], _ int) cache.Score {
+		return cache.Score{
+			Id:         g.items[v.A].ItemId,
+			Categories: g.items[v.A].Categories,
+			Score:      -float64(v.B),
+			Timestamp:  g.timestamp,
+		}
+	})
+}
+
+// stripThink removes a leading reasoning block (opening think tag through
+// the first closing think tag) from s; other strings are returned unchanged.
+// The tags are spelled with \u003c escapes so markup strippers cannot erase them.
+func stripThink(s string) string {
+	const openTag, closeTag = "\u003cthink\u003e", "\u003c/think\u003e"
+	if end := strings.Index(s, closeTag); strings.HasPrefix(s, openTag) && end != -1 {
+		return s[end+len(closeTag):]
+	}
+	return s
+}
diff --git a/logics/item_to_item_test.go b/logics/item_to_item_test.go
index 70a6864ce..8a8ba351a 100644
--- a/logics/item_to_item_test.go
+++ b/logics/item_to_item_test.go
@@ -20,9 +20,10 @@ import (
"time"
"github.com/stretchr/testify/suite"
+ "github.com/zhenghaoz/gorse/base/floats"
+ "github.com/zhenghaoz/gorse/common/mock"
"github.com/zhenghaoz/gorse/config"
"github.com/zhenghaoz/gorse/dataset"
- "github.com/zhenghaoz/gorse/storage/cache"
"github.com/zhenghaoz/gorse/storage/data"
)
@@ -97,12 +98,7 @@ func (suite *ItemToItemTestSuite) TestEmbedding() {
}, nil)
}
- var scores []cache.Score
- item2item.PopAll(func(itemId string, score []cache.Score) {
- if itemId == "0" {
- scores = score
- }
- })
+ scores := item2item.PopAll(0)
suite.Len(scores, 10)
for i := 1; i <= 10; i++ {
suite.Equal(strconv.Itoa(i), scores[i-1].Id)
@@ -131,12 +127,7 @@ func (suite *ItemToItemTestSuite) TestTags() {
}, nil)
}
- var scores []cache.Score
- item2item.PopAll(func(itemId string, score []cache.Score) {
- if itemId == "0" {
- scores = score
- }
- })
+ scores := item2item.PopAll(0)
suite.Len(scores, 10)
for i := 1; i <= 10; i++ {
suite.Equal(strconv.Itoa(i), scores[i-1].Id)
@@ -160,12 +151,7 @@ func (suite *ItemToItemTestSuite) TestUsers() {
item2item.Push(&data.Item{ItemId: strconv.Itoa(i)}, feedback)
}
- var scores []cache.Score
- item2item.PopAll(func(itemId string, score []cache.Score) {
- if itemId == "0" {
- scores = score
- }
- })
+ scores := item2item.PopAll(0)
suite.Len(scores, 10)
for i := 1; i <= 10; i++ {
suite.Equal(strconv.Itoa(i), scores[i-1].Id)
@@ -198,24 +184,57 @@ func (suite *ItemToItemTestSuite) TestAuto() {
item2item.Push(item, feedback)
}
- var scores0, scores1 []cache.Score
- item2item.PopAll(func(itemId string, score []cache.Score) {
- if itemId == "0" {
- scores0 = score
- } else if itemId == "1" {
- scores1 = score
- }
- })
+ scores0 := item2item.PopAll(0)
suite.Len(scores0, 10)
for i := 1; i <= 10; i++ {
suite.Equal(strconv.Itoa(i*2), scores0[i-1].Id)
}
+ scores1 := item2item.PopAll(1)
suite.Len(scores1, 10)
for i := 1; i <= 10; i++ {
suite.Equal(strconv.Itoa(i*2+1), scores1[i-1].Id)
}
}
+// TestChat runs the chat recommender against the mock OpenAI server and
+// expects the query for item 0 to rank items 1..10, in order.
+func (suite *ItemToItemTestSuite) TestChat() {
+	server := mock.NewOpenAIServer()
+	go func() {
+		_ = server.Start()
+	}()
+	server.Ready()
+	defer server.Close()
+
+	openaiConfig := config.OpenAIConfig{
+		BaseURL:             server.BaseURL(),
+		AuthToken:           server.AuthToken(),
+		ChatCompletionModel: "deepseek-r1",
+		EmbeddingsModel:     "text-similarity-ada-001",
+	}
+	recommender, err := newChatItemToItem(config.ItemToItemConfig{
+		Column: "item.Labels.embeddings",
+		Prompt: "Please generate similar items for {{ item.Labels.title }}.",
+	}, 10, time.Now(), openaiConfig)
+	suite.NoError(err)
+
+	for id := 0; id < 100; id++ {
+		// Each item's vector is the hash of item 0's rendered prompt, shifted by the item index.
+		vector := mock.Hash("Please generate similar items for item_0.")
+		floats.AddConst(vector, float32(id))
+		labels := map[string]any{
+			"title":      "item_" + strconv.Itoa(id),
+			"embeddings": vector,
+		}
+		recommender.Push(&data.Item{ItemId: strconv.Itoa(id), Labels: labels}, nil)
+	}
+	scores := recommender.PopAll(0)
+	suite.Len(scores, 10)
+	for rank := 0; rank < 10; rank++ {
+		suite.Equal(strconv.Itoa(rank+1), scores[rank].Id)
+	}
+}
+
func TestItemToItem(t *testing.T) {
suite.Run(t, new(ItemToItemTestSuite))
}
diff --git a/logics/user_to_user.go b/logics/user_to_user.go
index 423b1037e..a23fcf5bc 100644
--- a/logics/user_to_user.go
+++ b/logics/user_to_user.go
@@ -43,7 +43,7 @@ type UserToUserOptions struct {
type UserToUser interface {
Users() []*data.User
Push(user *data.User, feedback []dataset.ID)
- PopAll(callback func(userId string, score []cache.Score))
+ PopAll(i int) []cache.Score
}
func NewUserToUser(cfg UserToUserConfig, n int, timestamp time.Time, opts *UserToUserOptions) (UserToUser, error) {
@@ -82,21 +82,19 @@ func (b *baseUserToUser[T]) Users() []*data.User {
return b.users
}
-func (b *baseUserToUser[T]) PopAll(callback func(userId string, score []cache.Score)) {
- for index, user := range b.users {
- scores, err := b.index.SearchIndex(index, b.n+1, true)
- if err != nil {
- log.Logger().Error("failed to search index", zap.Error(err))
- return
- }
- callback(user.UserId, lo.Map(scores, func(v lo.Tuple2[int, float32], _ int) cache.Score {
- return cache.Score{
- Id: b.users[v.A].UserId,
- Score: -float64(v.B),
- Timestamp: b.timestamp,
- }
- }))
+func (b *baseUserToUser[T]) PopAll(i int) []cache.Score {
+ scores, err := b.index.SearchIndex(i, b.n+1, true)
+ if err != nil {
+ log.Logger().Error("failed to search index", zap.Error(err))
+ return nil
}
+ return lo.Map(scores, func(v lo.Tuple2[int, float32], _ int) cache.Score {
+ return cache.Score{
+ Id: b.users[v.A].UserId,
+ Score: -float64(v.B),
+ Timestamp: b.timestamp,
+ }
+ })
}
type embeddingUserToUser struct {
diff --git a/logics/user_to_user_test.go b/logics/user_to_user_test.go
index 16101b283..fc61ed933 100644
--- a/logics/user_to_user_test.go
+++ b/logics/user_to_user_test.go
@@ -21,7 +21,6 @@ import (
"github.com/stretchr/testify/suite"
"github.com/zhenghaoz/gorse/dataset"
- "github.com/zhenghaoz/gorse/storage/cache"
"github.com/zhenghaoz/gorse/storage/data"
)
@@ -45,12 +44,7 @@ func (suite *UserToUserTestSuite) TestEmbedding() {
}, nil)
}
- var scores []cache.Score
- user2user.PopAll(func(userId string, score []cache.Score) {
- if userId == "0" {
- scores = score
- }
- })
+ scores := user2user.PopAll(0)
suite.Len(scores, 10)
for i := 1; i <= 10; i++ {
suite.Equal(strconv.Itoa(i), scores[i-1].Id)
@@ -79,12 +73,7 @@ func (suite *UserToUserTestSuite) TestTags() {
}, nil)
}
- var scores []cache.Score
- user2user.PopAll(func(userId string, score []cache.Score) {
- if userId == "0" {
- scores = score
- }
- })
+ scores := user2user.PopAll(0)
suite.Len(scores, 10)
for i := 1; i <= 10; i++ {
suite.Equal(strconv.Itoa(i), scores[i-1].Id)
@@ -108,12 +97,7 @@ func (suite *UserToUserTestSuite) TestItems() {
user2user.Push(&data.User{UserId: strconv.Itoa(i)}, feedback)
}
- var scores []cache.Score
- user2user.PopAll(func(userId string, score []cache.Score) {
- if userId == "0" {
- scores = score
- }
- })
+ scores := user2user.PopAll(0)
suite.Len(scores, 10)
for i := 1; i <= 10; i++ {
suite.Equal(strconv.Itoa(i), scores[i-1].Id)
@@ -146,15 +130,9 @@ func (suite *UserToUserTestSuite) TestAuto() {
user2user.Push(user, feedback)
}
- var scores0, scores1 []cache.Score
- user2user.PopAll(func(userId string, score []cache.Score) {
- if userId == "0" {
- scores0 = score
- } else if userId == "1" {
- scores1 = score
- }
- })
+ scores0 := user2user.PopAll(0)
suite.Len(scores0, 10)
+ scores1 := user2user.PopAll(1)
suite.Len(scores1, 10)
}
diff --git a/master/rest.go b/master/rest.go
index 7bd1299fd..ff23af014 100644
--- a/master/rest.go
+++ b/master/rest.go
@@ -430,6 +430,11 @@ func (m *Master) handleUserInfo(request *restful.Request, response *restful.Resp
server.Ok(response, UserInfo{
Name: m.Config.Master.DashboardUserName,
})
+ } else {
+ response.Header().Set("Content-Type", "application/json")
+ if _, err := response.Write([]byte("null")); err != nil {
+ log.ResponseLogger(response).Error("failed to write response", zap.Error(err))
+ }
}
}
diff --git a/master/tasks.go b/master/tasks.go
index fb0083870..b73ce668f 100644
--- a/master/tasks.go
+++ b/master/tasks.go
@@ -1012,8 +1012,9 @@ func (m *Master) updateItemToItem(dataset *dataset.Dataset) error {
itemToItemRecommenders := make([]logics.ItemToItem, 0, len(itemToItemConfigs))
for _, cfg := range itemToItemConfigs {
recommender, err := logics.NewItemToItem(cfg, m.Config.Recommend.CacheSize, dataset.GetTimestamp(), &logics.ItemToItemOptions{
- TagsIDF: dataset.GetItemColumnValuesIDF(),
- UsersIDF: dataset.GetUserIDF(),
+ TagsIDF: dataset.GetItemColumnValuesIDF(),
+ UsersIDF: dataset.GetUserIDF(),
+ OpenAIConfig: m.Config.OpenAI,
})
if err != nil {
return errors.Trace(err)
@@ -1033,31 +1034,35 @@ func (m *Master) updateItemToItem(dataset *dataset.Dataset) error {
// Save item-to-item recommendations to cache
for i, recommender := range itemToItemRecommenders {
- recommender.PopAll(func(itemId string, score []cache.Score) {
+ for j, item := range recommender.Items() {
itemToItemConfig := itemToItemConfigs[i]
- if m.needUpdateItemToItem(itemId, itemToItemConfigs[i]) {
+ if m.needUpdateItemToItem(item.ItemId, itemToItemConfig) {
+ score := recommender.PopAll(j)
+ if score == nil {
+ continue
+ }
log.Logger().Debug("update item-to-item recommendation",
- zap.String("item_id", itemId),
+ zap.String("item_id", item.ItemId),
zap.String("name", itemToItemConfig.Name),
zap.Int("n_recommendations", len(score)))
// Save item-to-item recommendation to cache
- if err := m.CacheClient.AddScores(ctx, cache.ItemToItem, cache.Key(itemToItemConfig.Name, itemId), score); err != nil {
+ if err := m.CacheClient.AddScores(ctx, cache.ItemToItem, cache.Key(itemToItemConfig.Name, item.ItemId), score); err != nil {
log.Logger().Error("failed to save item-to-item recommendation to cache",
- zap.String("item_id", itemId), zap.Error(err))
- return
+ zap.String("item_id", item.ItemId), zap.Error(err))
+ continue
}
// Save item-to-item digest and last update time to cache
if err := m.CacheClient.Set(ctx,
- cache.String(cache.Key(cache.ItemToItemDigest, itemToItemConfig.Name, itemId), itemToItemConfig.Hash()),
- cache.Time(cache.Key(cache.ItemToItemUpdateTime, itemToItemConfig.Name, itemId), time.Now()),
+ cache.String(cache.Key(cache.ItemToItemDigest, itemToItemConfig.Name, item.ItemId), itemToItemConfig.Hash()),
+ cache.Time(cache.Key(cache.ItemToItemUpdateTime, itemToItemConfig.Name, item.ItemId), time.Now()),
); err != nil {
log.Logger().Error("failed to save item-to-item digest to cache",
- zap.String("item_id", itemId), zap.Error(err))
- return
+ zap.String("item_id", item.ItemId), zap.Error(err))
+ continue
}
}
span.Add(1)
- })
+ }
}
return nil
}
@@ -1132,26 +1137,31 @@ func (m *Master) updateUserToUser(dataset *dataset.Dataset) error {
}
// Save user-to-user recommendations to cache
- userToUserRecommender.PopAll(func(userId string, score []cache.Score) {
- if m.needUpdateUserToUser(userId) {
+ for j, user := range userToUserRecommender.Users() {
+ if m.needUpdateUserToUser(user.UserId) {
+ score := userToUserRecommender.PopAll(j)
+ if score == nil {
+ continue
+ }
log.Logger().Debug("update user neighbors",
- zap.String("user_id", userId),
+ zap.String("user_id", user.UserId),
zap.Int("n_recommendations", len(score)))
// Save user-to-user recommendations to cache
- if err := m.CacheClient.AddScores(ctx, cache.UserToUser, cache.Key(cache.Neighbors, userId), score); err != nil {
- log.Logger().Error("failed to save user neighbors to cache", zap.String("user_id", userId), zap.Error(err))
- return
+ if err := m.CacheClient.AddScores(ctx, cache.UserToUser, cache.Key(cache.Neighbors, user.UserId), score); err != nil {
+ log.Logger().Error("failed to save user neighbors to cache", zap.String("user_id", user.UserId), zap.Error(err))
+ continue
}
// Save user-to-user digest and last update time to cache
if err := m.CacheClient.Set(ctx,
- cache.String(cache.Key(cache.UserToUserDigest, cache.Key(cache.Neighbors, userId)), m.Config.UserNeighborDigest()),
- cache.Time(cache.Key(cache.UserToUserUpdateTime, cache.Key(cache.Neighbors, userId)), time.Now()),
+ cache.String(cache.Key(cache.UserToUserDigest, cache.Key(cache.Neighbors, user.UserId)), m.Config.UserNeighborDigest()),
+ cache.Time(cache.Key(cache.UserToUserUpdateTime, cache.Key(cache.Neighbors, user.UserId)), time.Now()),
); err != nil {
- log.Logger().Error("failed to save user neighbors digest to cache", zap.String("user_id", userId), zap.Error(err))
- return
+ log.Logger().Error("failed to save user neighbors digest to cache", zap.String("user_id", user.UserId), zap.Error(err))
+ continue
}
}
- })
+ span.Add(1)
+ }
return nil
}