Skip to content

Commit

Permalink
implement LLM-based recommenders (#941)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhenghaoz authored Feb 14, 2025
1 parent d5c570d commit b349ba0
Show file tree
Hide file tree
Showing 11 changed files with 273 additions and 132 deletions.
42 changes: 31 additions & 11 deletions common/mock/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,22 @@ package mock

import (
"bytes"
"crypto/md5"
"encoding/json"
"fmt"
"github.com/emicklei/go-restful/v3"
"github.com/sashabaranov/go-openai"
"net"
"net/http"

"github.com/emicklei/go-restful/v3"
"github.com/samber/lo"
"github.com/sashabaranov/go-openai"
)

type OpenAIServer struct {
listener net.Listener
httpServer *http.Server
authToken string
ready chan struct{}

mockEmbeddings []float32
}

func NewOpenAIServer() *OpenAIServer {
Expand Down Expand Up @@ -81,19 +82,18 @@ func (s *OpenAIServer) Close() error {
return s.httpServer.Close()
}

func (s *OpenAIServer) Embeddings(embeddings []float32) {
s.mockEmbeddings = embeddings
}

func (s *OpenAIServer) chatCompletion(req *restful.Request, resp *restful.Response) {
var r openai.ChatCompletionRequest
err := req.ReadEntity(&r)
if err != nil {
_ = resp.WriteError(http.StatusBadRequest, err)
return
}
content := r.Messages[0].Content
if r.Model == "deepseek-r1" {
content = "<think>To be or not to be, that is the question.</think>" + content
}
if r.Stream {
content := r.Messages[0].Content
for i := 0; i < len(content); i += 8 {
buf := bytes.NewBuffer(nil)
buf.WriteString("data: ")
Expand All @@ -112,23 +112,43 @@ func (s *OpenAIServer) chatCompletion(req *restful.Request, resp *restful.Respon
_ = resp.WriteEntity(openai.ChatCompletionResponse{
Choices: []openai.ChatCompletionChoice{{
Message: openai.ChatCompletionMessage{
Content: r.Messages[0].Content,
Content: content,
},
}},
})
}
}

func (s *OpenAIServer) embeddings(req *restful.Request, resp *restful.Response) {
// parse request
var r openai.EmbeddingRequest
err := req.ReadEntity(&r)
if err != nil {
_ = resp.WriteError(http.StatusBadRequest, err)
return
}
input, ok := r.Input.(string)
if !ok {
_ = resp.WriteError(http.StatusBadRequest, fmt.Errorf("invalid input type"))
return
}

// write response
_ = resp.WriteEntity(openai.EmbeddingResponse{
Data: []openai.Embedding{{
Embedding: s.mockEmbeddings,
Embedding: Hash(input),
}},
})
}

func Hash(input string) []float32 {
hasher := md5.New()
_, err := hasher.Write([]byte(input))
if err != nil {
panic(err)
}
h := hasher.Sum(nil)
return lo.Map(h, func(b byte, _ int) float32 {
return float32(b)
})
}
10 changes: 5 additions & 5 deletions common/mock/openai_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,13 @@ package mock

import (
"context"
"github.com/juju/errors"
"github.com/sashabaranov/go-openai"
"github.com/stretchr/testify/suite"
"io"
"strings"
"testing"

"github.com/juju/errors"
"github.com/sashabaranov/go-openai"
"github.com/stretchr/testify/suite"
)

type OpenAITestSuite struct {
Expand Down Expand Up @@ -97,7 +98,6 @@ func (suite *OpenAITestSuite) TestChatCompletionStream() {
}

func (suite *OpenAITestSuite) TestEmbeddings() {
suite.server.Embeddings([]float32{1, 2, 3})
resp, err := suite.client.CreateEmbeddings(
context.Background(),
openai.EmbeddingRequest{
Expand All @@ -106,7 +106,7 @@ func (suite *OpenAITestSuite) TestEmbeddings() {
},
)
suite.NoError(err)
suite.Equal([]float32{1, 2, 3}, resp.Data[0].Embedding)
suite.Equal([]float32{139, 26, 153, 83, 196, 97, 18, 150, 168, 39, 171, 248, 196, 120, 4, 215}, resp.Data[0].Embedding)
}

func TestOpenAITestSuite(t *testing.T) {
Expand Down
2 changes: 1 addition & 1 deletion config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ type NeighborsConfig struct {

type ItemToItemConfig struct {
Name string `mapstructure:"name" json:"name"`
Type string `mapstructure:"type" json:"type" validate:"oneof=embedding tags users llm"`
Type string `mapstructure:"type" json:"type" validate:"oneof=embedding tags users chat"`
Column string `mapstructure:"column" json:"column" validate:"item_expr"`
Prompt string `mapstructure:"prompt" json:"prompt"`
}
Expand Down
4 changes: 2 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ require (
github.com/go-viper/mapstructure/v2 v2.2.1
github.com/google/uuid v1.6.0
github.com/gorilla/securecookie v1.1.1
github.com/gorse-io/dashboard v0.0.0-20250209091713-a70341e78d48
github.com/gorse-io/dashboard v0.0.0-20250214134211-90d95a512041
github.com/gorse-io/gorse-go v0.5.0-alpha.1
github.com/haxii/go-swagger-ui v0.0.0-20210203093335-a63a6bbde946
github.com/jaswdr/faker v1.16.0
Expand Down Expand Up @@ -110,7 +110,6 @@ require (
github.com/golang/protobuf v1.5.3 // indirect
github.com/golang/snappy v0.0.4 // indirect
github.com/google/flatbuffers v2.0.6+incompatible // indirect
github.com/google/pprof v0.0.0-20240827171923-fa2c70bbbfe5 // indirect
github.com/grpc-ecosystem/grpc-gateway/v2 v2.22.0 // indirect
github.com/hashicorp/go-version v1.6.0 // indirect
github.com/hashicorp/hcl v1.0.0 // indirect
Expand Down Expand Up @@ -152,6 +151,7 @@ require (
github.com/sagikazarmark/locafero v0.4.0 // indirect
github.com/sagikazarmark/slog-shim v0.1.0 // indirect
github.com/shopspring/decimal v1.3.1 // indirect
github.com/sirupsen/logrus v1.9.3 // indirect
github.com/sourcegraph/conc v0.3.0 // indirect
github.com/spf13/afero v1.11.0 // indirect
github.com/spf13/cast v1.6.0 // indirect
Expand Down
17 changes: 15 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ gioui.org v0.0.0-20210308172011-57750fc8a0a6/go.mod h1:RSH6KIUZ0p2xy5zHDxgAM4zum
git.sr.ht/~sbinet/gg v0.3.1/go.mod h1:KGYtlADtqsqANL9ueOFkWymvzUvLMQllU5Ixo+8v3pc=
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo=
github.com/MakeNowJust/heredoc v1.0.0 h1:cXCdzVdstXyiTqTvfqk9SDHpKNjxuom+DOlyEeQ4pzQ=
github.com/MakeNowJust/heredoc v1.0.0/go.mod h1:mG5amYoWBHf8vpLOuehzbGGw0EHxpZZ6lCpQ4fNJ8LE=
github.com/Masterminds/semver/v3 v3.1.1 h1:hLg3sBzpNErnxhQtUy/mmLR2I9foDujNK030IGemrRc=
github.com/Masterminds/semver/v3 v3.1.1/go.mod h1:VPu/7SZ7ePZ3QOrcuXROw5FAcLl4a0cBrbBpGY/8hQs=
github.com/PuerkitoBio/purell v1.1.1/go.mod h1:c11w/QuzBsJSee3cPx9rAFu61PvFxuPbtSwDGJws/X0=
Expand Down Expand Up @@ -211,6 +213,8 @@ github.com/go-resty/resty/v2 v2.16.3/go.mod h1:hkJtXbA2iKHzJheXYvQ8snQES5ZLGKMwQ
github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE=
github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg=
github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY=
github.com/go-task/slim-sprig/v3 v3.0.0 h1:sUs3vkvUymDpBKi3qH1YSqBQk9+9D/8M2mN1vB6EwHI=
github.com/go-task/slim-sprig/v3 v3.0.0/go.mod h1:W848ghGpv3Qj3dhTPRyJypKRiqCdHZiAzKg9hl15HA8=
github.com/go-viper/mapstructure/v2 v2.2.1 h1:ZAaOCxANMuZx5RCeg0mBdEZk7DZasvvZIxtHqx8aGss=
github.com/go-viper/mapstructure/v2 v2.2.1/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM=
github.com/gofrs/uuid v4.0.0+incompatible h1:1SD/1F5pU8p29ybwgQSwpQk+mwdRrXCYuPhW6m+TnJw=
Expand Down Expand Up @@ -304,8 +308,8 @@ github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyC
github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4=
github.com/gorse-io/clickhouse v0.3.3-0.20220715124633-688011a495bb h1:z/oOWE+Vy0PLcwIulZmIug4FtmvE3dJ1YOGprLeHwwY=
github.com/gorse-io/clickhouse v0.3.3-0.20220715124633-688011a495bb/go.mod h1:iILWzbul8U+gsf4kqbheF2QzBmdvVp63mloGGK8emDI=
github.com/gorse-io/dashboard v0.0.0-20250209091713-a70341e78d48 h1:kfCK07ae/+NvxlcPqh0SpaXxkDlceqSmamsX7t/E4+w=
github.com/gorse-io/dashboard v0.0.0-20250209091713-a70341e78d48/go.mod h1:lv2bu311bjIJeRfY+6hiIaw20M6fLxT4ma9Ye+bpwGY=
github.com/gorse-io/dashboard v0.0.0-20250214134211-90d95a512041 h1:Uzv+3PZKKatcS8O7G6v3gUXvu9dfHWqOz8EmkUbGlaw=
github.com/gorse-io/dashboard v0.0.0-20250214134211-90d95a512041/go.mod h1:q98umjWRGQ3Qi1BPksKsy0nvxwKb8V+W4d4XRegxKYg=
github.com/gorse-io/gorgonia v0.0.0-20230817132253-6dd1dbf95849 h1:Hwywr6NxzYeZYn35KwOsw7j8ZiMT60TBzpbn1MbEido=
github.com/gorse-io/gorgonia v0.0.0-20230817132253-6dd1dbf95849/go.mod h1:TtVGAt7ENNmgBnC0JA68CAjIDCEtcqaRHvnkAWJ/Fu0=
github.com/gorse-io/gorse-go v0.5.0-alpha.1 h1:QBWKGAbSKNAWnieXVIdQiE0lLGvKXfFFAFPOQEkPW/E=
Expand Down Expand Up @@ -488,7 +492,13 @@ github.com/montanaflynn/stats v0.7.1/go.mod h1:etXPPgVO6n31NxCd9KQUMvCM+ve0ruNzt
github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U=
github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U=
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno=
github.com/nikolalohinski/gonja/v2 v2.3.3 h1:5cTcmz0i/DwJl67US8Rvnb4OkBXB5V5OWd5IIAPPkXw=
github.com/nikolalohinski/gonja/v2 v2.3.3/go.mod h1:8KC3RlefxnOaY5P4rH5erdwV0/owS83U615cSnDLYFs=
github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE=
github.com/onsi/ginkgo/v2 v2.20.1 h1:YlVIbqct+ZmnEph770q9Q7NVAz4wwIiVNahee6JyUzo=
github.com/onsi/ginkgo/v2 v2.20.1/go.mod h1:lG9ey2Z29hR41WMVthyJBGUBcBhGOtoPF2VFMvBXFCI=
github.com/onsi/gomega v1.35.1 h1:Cwbd75ZBPxFSuZ6T+rN/WCb/gOc6YgFBXLlZLhC7Ds4=
github.com/onsi/gomega v1.35.1/go.mod h1:PvZbdDc8J6XJEpDK4HCuRBm8a6Fzp9/DmhC9C7yFlog=
github.com/openzipkin/zipkin-go v0.4.1 h1:kNd/ST2yLLWhaWrkgchya40TJabe8Hioj9udfPcEO5A=
github.com/openzipkin/zipkin-go v0.4.1/go.mod h1:qY0VqDSN1pOBN94dBc6w2GJlWLiovAyg7Qt6/I9HecM=
github.com/orcaman/concurrent-map v1.0.0 h1:I/2A2XPCb4IuQWcQhBhSwGfiuybl/J0ev9HDbW65HOY=
Expand Down Expand Up @@ -585,6 +595,8 @@ github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPx
github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMBDgk/93Q=
github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE=
github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88=
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo=
github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0=
github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8=
Expand Down Expand Up @@ -904,6 +916,7 @@ golang.org/x/sys v0.0.0-20220114195835-da31bd327af9/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20220330033206-e17cdc41300f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA=
Expand Down
Loading

0 comments on commit b349ba0

Please sign in to comment.