diff --git a/examples/superstream_demo/Dockerfile b/examples/superstream_demo/Dockerfile new file mode 100644 index 0000000..13aee4c --- /dev/null +++ b/examples/superstream_demo/Dockerfile @@ -0,0 +1,15 @@ +FROM golang:1.16 AS build + +ARG service_name + +WORKDIR /go/src/app +COPY . . + +RUN go mod download +RUN CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -a -installsuffix cgo -o /app ./${service_name} + +FROM scratch +COPY --from=build /app /app +EXPOSE ${PORT} + +ENTRYPOINT ["/app"] diff --git a/examples/superstream_demo/README.md b/examples/superstream_demo/README.md new file mode 100644 index 0000000..a2c25aa --- /dev/null +++ b/examples/superstream_demo/README.md @@ -0,0 +1,33 @@ +# Contextual multi-armed bandit +With microservice reward source + +## Starting the services: + +`docker compose build && docker compose up` + +## Services + +### Bandit + +The bandit service is a stateless app that uses Thompson sampling to select an arm for a contextual multi-armed bandit. +It depends on the reward service, which provides reward estimates for each arm depending on context. + +The bandit service does not use the context directly, but just passes it to the reward service. +The reward service is responsible for validating the context. + +For example: + +`curl -XPOST localhost:1338/randomize -d '{"unit": "visitor_id:12345", "context": {"campaign_id": 1}}'` + +The bandit service will use the context value as the body of a POST request to the reward service. + +### Reward + +The reward service is a stateful service that provides reward estimates given a context. +In this basic example the rewards are hard-coded, but a real reward service would be connected to a DB. + +You can query the reward service directly with: + +`curl -i -XPOST localhost:1337/rewards -d '{"campaign_id": 1}'` + +The reward service returns an error if the context is invalid, otherwise it returns the reward estimate for each arm. \ No newline at end of file diff --git a/examples/superstream_demo/bandit/main.go b/examples/superstream_demo/bandit/main.go new file mode 100644 index 0000000..a142f51 --- /dev/null +++ b/examples/superstream_demo/bandit/main.go @@ -0,0 +1,72 @@ +package main + +import ( + "encoding/json" + "errors" + "io" + "log" + "net/http" + "os" + "time" + + "github.com/gorilla/handlers" + "github.com/gorilla/mux" + "github.com/stitchfix/mab" + "github.com/stitchfix/mab/numint" +) + +type randomizeRequest struct { + Unit string `json:"unit"` + Context json.RawMessage `json:"context"` +} + +var bandit mab.Bandit + +func handler(w http.ResponseWriter, r *http.Request) { + var req randomizeRequest + + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + result, err := bandit.SelectArm(r.Context(), req.Unit, req.Context) + if err != nil { + var non200 *mab.ErrRewardNon2XX + if errors.As(err, &non200) { + http.Error(w, err.Error(), err.(*mab.ErrRewardNon2XX).StatusCode) + return + } + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + _ = json.NewEncoder(w).Encode(result) +} + +func main() { + cli := &http.Client{Timeout: time.Second} + url := "http://reward-service/rewards" + parser := mab.ParseFunc(mab.BetaFromJSON) + marshaler := mab.MarshalFunc(json.Marshal) + + bandit = mab.Bandit{ + RewardSource: mab.NewHTTPSource(cli, url, parser, mab.WithContextMarshaler(marshaler)), + Strategy: mab.NewThompson(numint.NewQuadrature()), + Sampler: mab.NewSha1Sampler(), + } + + r := mux.NewRouter() + r.HandleFunc("/randomize", handler).Methods("POST") + r.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + io.WriteString(w, `{"alive": true}`) + }) + + server := &http.Server{ + Handler: handlers.LoggingHandler(os.Stdout, r), + Addr: "0.0.0.0:80", + } + + log.Fatal(server.ListenAndServe()) +} diff --git a/examples/superstream_demo/docker-compose.yml b/examples/superstream_demo/docker-compose.yml new file mode 100644 index 0000000..3465dc4 --- /dev/null +++ b/examples/superstream_demo/docker-compose.yml @@ -0,0 +1,21 @@ +version: "3" + +services: + reward-service: + build: + args: + service_name: reward + environment: + - PORT=1337 + ports: + - "1337:80" + bandit-service: + build: + args: + service_name: bandit + environment: + - PORT=1338 + ports: + - "1338:80" + depends_on: + - reward-service diff --git a/examples/superstream_demo/go.mod b/examples/superstream_demo/go.mod new file mode 100644 index 0000000..9df5b60 --- /dev/null +++ b/examples/superstream_demo/go.mod @@ -0,0 +1,9 @@ +module github.com/stitchfix/mab/examples/superstream_demo + +go 1.14 + +require ( + github.com/gorilla/handlers v1.5.1 + github.com/gorilla/mux v1.8.0 + github.com/stitchfix/mab v0.1.1 +) diff --git a/examples/superstream_demo/go.sum b/examples/superstream_demo/go.sum new file mode 100644 index 0000000..6b97c07 --- /dev/null +++ b/examples/superstream_demo/go.sum @@ -0,0 +1,58 @@ +dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU= +github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= +github.com/ajstarks/svgo v0.0.0-20180226025133-644b8db467af/go.mod h1:K08gAheRH3/J6wwsYMMT4xOr94bZjxIelGM0+d/wbFw= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/felixge/httpsnoop v1.0.1 h1:lvB5Jl89CsZtGIWuTcDM1E/vkVs49/Ml7JJe07l8SPQ= +github.com/felixge/httpsnoop v1.0.1/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= +github.com/fogleman/gg v1.2.1-0.20190220221249-0403632d5b90/go.mod h1:R/bRT+9gY/C5z7JzPU0zXsXHKM4/ayA+zqcVNZzPa1k= +github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= +github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0/go.mod h1:E/TSTwGwJL78qG/PmXZO1EjYhfJinVAhrmmHX6Z8B9k= +github.com/gorilla/handlers v1.5.1 h1:9lRY6j8DEeeBT10CvO9hGW0gmky0BprnvDI5vfhUHH4= +github.com/gorilla/handlers v1.5.1/go.mod h1:t8XrUpc4KVXb7HGyJ4/cEnwQiaxrX/hz1Zv/4g96P1Q= +github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI= +github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So= +github.com/jung-kurt/gofpdf v1.0.3-0.20190309125859-24315acbbda5/go.mod h1:7Id9E/uU8ce6rXgefFLlgrJj/GYY22cpxn+r32jIOes= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stitchfix/mab v0.1.1 h1:UMijkS857AyLd8VDqvwm+OBlGm3Pni70CTzq51Hs/Vw= +github.com/stitchfix/mab v0.1.1/go.mod h1:8XNtsDrZu9hDAixMR4Up33NUNalw3UZ5urBC4LYhYuY= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= +golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= +golang.org/x/exp v0.0.0-20190125153040-c74c464bbbf2/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= +golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= +golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6 h1:QE6XYQK6naiK1EPAe1g/ILLxN5RBoH5xkJk3CqlMI/Y= +golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6/go.mod h1:3jZMyOhIsHpP37uCMkUooju7aAi5cS1Q23tOzKc+0MU= +golang.org/x/image v0.0.0-20180708004352-c73c2afc3b81/go.mod h1:ux5Hcp/YLpHSI86hEcLt0YII63i6oz57MZXIpbrjZUs= +golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js= +golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= +golang.org/x/mobile v0.0.0-20190719004257-d2bd2a29d028/go.mod h1:E/iHnbuqvinMTCcRqshq8CkpyQDoeVncDDYHnLhea+o= +golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= +golang.org/x/mod v0.1.1-0.20191107180719-034126e5016b/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191001151750-bb3f8db39f24/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190206041539-40960b6deb8e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20200207183749-b753a1ba74fa/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= +golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +gonum.org/v1/gonum v0.0.0-20180816165407-929014505bf4/go.mod h1:Y+Yx5eoAFn32cQvJDxZx5Dpnq+c3wtXuadVZAcxbbBo= +gonum.org/v1/gonum v0.8.2 h1:CCXrcPKiGGotvnN6jfUsKk4rRqm7q09/YbKb5xCEvtM= +gonum.org/v1/gonum v0.8.2/go.mod h1:oe/vMfY3deqTw+1EZJhuvEW2iwGF1bW9wwu7XCu0+v0= +gonum.org/v1/netlib v0.0.0-20190313105609-8cb42192e0e0 h1:OE9mWmgKkjJyEmDAAtGMPjXu+YNeGvK9VTSHY6+Qihc= +gonum.org/v1/netlib v0.0.0-20190313105609-8cb42192e0e0/go.mod h1:wa6Ws7BG/ESfp6dHfk7C6KdzKA7wR7u/rKwOGE66zvw= +gonum.org/v1/plot v0.0.0-20190515093506-e2840ee46a6b/go.mod h1:Wt8AAjI+ypCyYX3nZBvf6cAIx93T+c/OS2HFAYskSZc= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= diff --git a/examples/superstream_demo/reward/main.go b/examples/superstream_demo/reward/main.go new file mode 100644 index 0000000..1d0994f --- /dev/null +++ b/examples/superstream_demo/reward/main.go @@ -0,0 +1,80 @@ +package main + +import ( + "encoding/json" + "fmt" + "io" + "log" + "net/http" + "os" + + "github.com/gorilla/handlers" + "github.com/gorilla/mux" +) + +// In a real system, rewards are stored in a DB, but for purposes of the demo we'll just hard-code some example values +var campaignRewards map[int][]struct{ Alpha, Beta float64 } + +func init() { + campaignRewards = make(map[int][]struct{ Alpha, Beta float64 }) + + campaignRewards[1] = []struct{ Alpha, Beta float64 }{ + {10, 125}, + {4, 130}, + {16, 80}, + {25, 99}, + } + + campaignRewards[2] = []struct{ Alpha, Beta float64 }{ + {25, 125}, + {5, 50}, + {7, 90}, + {13, 200}, + } +} + +// This function handles incoming post requests to the /rewards endpoint +func handler(w http.ResponseWriter, r *http.Request) { + + // The request body must contain a JSON object with at least a "campaign_id" key and and integer value + var req struct { + CampaignID *int `json:"campaign_id"` + } + + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + if req.CampaignID == nil { + http.Error(w, "missing required key \"campaign_id\"", http.StatusBadRequest) + return + } + + // get the context-dependent reward estimates + rewards, ok := campaignRewards[*req.CampaignID] + if !ok { + http.Error(w, fmt.Sprintf("no rewards for campaign ID %d", req.CampaignID), http.StatusBadRequest) + return + } + + // send a JSON-encoded response + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(rewards) +} + +func main() { + r := mux.NewRouter() + r.HandleFunc("/rewards", handler).Methods("POST") + r.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + io.WriteString(w, `{"alive": true}`) + }) + + server := &http.Server{ + Handler: handlers.LoggingHandler(os.Stdout, r), + Addr: "0.0.0.0:80", + } + + log.Fatal(server.ListenAndServe()) +}