-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add example for superstream demo (#14)
- Loading branch information
Brian Amadio
authored
Mar 3, 2021
1 parent
1b25682
commit c6ed170
Showing
7 changed files
with
288 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
FROM golang:1.16 AS build | ||
|
||
ARG service_name | ||
|
||
WORKDIR /go/src/app | ||
COPY . . | ||
|
||
RUN go mod download | ||
RUN CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -a -installsuffix cgo -o /app ./${service_name} | ||
|
||
FROM scratch | ||
COPY --from=build /app /app | ||
EXPOSE ${PORT} | ||
|
||
ENTRYPOINT ["/app"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
# Contextual multi-armed bandit | ||
With microservice reward source | ||
|
||
## Starting the services: | ||
|
||
`docker compose build && docker compose up` | ||
|
||
## Services | ||
|
||
### Bandit | ||
|
||
The bandit service is a stateless app that uses Thompson sampling to select an arm for a contextual multi-armed bandit. | ||
It depends on the reward service, which provides reward estimates for each arm depending on context. | ||
|
||
The bandit service does not use the context directly, but just passes it to the reward service. | ||
The reward service is responsible for validating the context. | ||
|
||
For example: | ||
|
||
`curl -XPOST localhost:1338/randomize -d '{"unit": "visitor_id:12345", "context": {"campaign_id": 1}}'` | ||
|
||
The bandit service will use the context value as the body of a POST request to the reward service. | ||
|
||
### Reward | ||
|
||
The reward service is a stateful service that provides reward estimates given a context. | ||
In this basic example the rewards are hard-coded, but a real reward service would be connected to a DB. | ||
|
||
You can query the reward service directly with: | ||
|
||
`curl -i -XPOST localhost:1337/rewards -d '{"campaign_id": 1}'` | ||
|
||
The reward service returns an error if the context is invalid, otherwise it returns the reward estimate for each arm. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
package main | ||
|
||
import ( | ||
"encoding/json" | ||
"errors" | ||
"io" | ||
"log" | ||
"net/http" | ||
"os" | ||
"time" | ||
|
||
"github.com/gorilla/handlers" | ||
"github.com/gorilla/mux" | ||
"github.com/stitchfix/mab" | ||
"github.com/stitchfix/mab/numint" | ||
) | ||
|
||
type randomizeRequest struct { | ||
Unit string `json:"unit"` | ||
Context json.RawMessage `json:"context"` | ||
} | ||
|
||
var bandit mab.Bandit | ||
|
||
func handler(w http.ResponseWriter, r *http.Request) { | ||
var req randomizeRequest | ||
|
||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil { | ||
http.Error(w, err.Error(), http.StatusBadRequest) | ||
return | ||
} | ||
|
||
result, err := bandit.SelectArm(r.Context(), req.Unit, req.Context) | ||
if err != nil { | ||
var non200 *mab.ErrRewardNon2XX | ||
if errors.As(err, &non200) { | ||
http.Error(w, err.Error(), err.(*mab.ErrRewardNon2XX).StatusCode) | ||
return | ||
} | ||
http.Error(w, err.Error(), http.StatusInternalServerError) | ||
return | ||
} | ||
|
||
_ = json.NewEncoder(w).Encode(result) | ||
} | ||
|
||
func main() { | ||
cli := &http.Client{Timeout: time.Second} | ||
url := "http://reward-service/rewards" | ||
parser := mab.ParseFunc(mab.BetaFromJSON) | ||
marshaler := mab.MarshalFunc(json.Marshal) | ||
|
||
bandit = mab.Bandit{ | ||
RewardSource: mab.NewHTTPSource(cli, url, parser, mab.WithContextMarshaler(marshaler)), | ||
Strategy: mab.NewThompson(numint.NewQuadrature()), | ||
Sampler: mab.NewSha1Sampler(), | ||
} | ||
|
||
r := mux.NewRouter() | ||
r.HandleFunc("/randomize", handler).Methods("POST") | ||
r.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { | ||
w.WriteHeader(http.StatusOK) | ||
io.WriteString(w, `{"alive": true}`) | ||
}) | ||
|
||
server := &http.Server{ | ||
Handler: handlers.LoggingHandler(os.Stdout, r), | ||
Addr: "0.0.0.0:80", | ||
} | ||
|
||
log.Fatal(server.ListenAndServe()) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
version: "3" | ||
|
||
services: | ||
reward-service: | ||
build: | ||
args: | ||
service_name: reward | ||
environment: | ||
- PORT=1337 | ||
ports: | ||
- "1337:80" | ||
bandit-service: | ||
build: | ||
args: | ||
service_name: bandit | ||
environment: | ||
- PORT=1338 | ||
ports: | ||
- "1338:80" | ||
depends_on: | ||
- reward-service |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
module github.com/stitchfix/mab/examples/superstream_demo | ||
|
||
go 1.14 | ||
|
||
require ( | ||
github.com/gorilla/handlers v1.5.1 | ||
github.com/gorilla/mux v1.8.0 | ||
github.com/stitchfix/mab v0.1.1 | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU= | ||
github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= | ||
github.com/ajstarks/svgo v0.0.0-20180226025133-644b8db467af/go.mod h1:K08gAheRH3/J6wwsYMMT4xOr94bZjxIelGM0+d/wbFw= | ||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= | ||
github.com/felixge/httpsnoop v1.0.1 h1:lvB5Jl89CsZtGIWuTcDM1E/vkVs49/Ml7JJe07l8SPQ= | ||
github.com/felixge/httpsnoop v1.0.1/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= | ||
github.com/fogleman/gg v1.2.1-0.20190220221249-0403632d5b90/go.mod h1:R/bRT+9gY/C5z7JzPU0zXsXHKM4/ayA+zqcVNZzPa1k= | ||
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= | ||
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0/go.mod h1:E/TSTwGwJL78qG/PmXZO1EjYhfJinVAhrmmHX6Z8B9k= | ||
github.com/gorilla/handlers v1.5.1 h1:9lRY6j8DEeeBT10CvO9hGW0gmky0BprnvDI5vfhUHH4= | ||
github.com/gorilla/handlers v1.5.1/go.mod h1:t8XrUpc4KVXb7HGyJ4/cEnwQiaxrX/hz1Zv/4g96P1Q= | ||
github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI= | ||
github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So= | ||
github.com/jung-kurt/gofpdf v1.0.3-0.20190309125859-24315acbbda5/go.mod h1:7Id9E/uU8ce6rXgefFLlgrJj/GYY22cpxn+r32jIOes= | ||
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= | ||
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= | ||
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= | ||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= | ||
github.com/stitchfix/mab v0.1.1 h1:UMijkS857AyLd8VDqvwm+OBlGm3Pni70CTzq51Hs/Vw= | ||
github.com/stitchfix/mab v0.1.1/go.mod h1:8XNtsDrZu9hDAixMR4Up33NUNalw3UZ5urBC4LYhYuY= | ||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= | ||
github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= | ||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= | ||
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= | ||
golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= | ||
golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= | ||
golang.org/x/exp v0.0.0-20190125153040-c74c464bbbf2/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= | ||
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= | ||
golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6 h1:QE6XYQK6naiK1EPAe1g/ILLxN5RBoH5xkJk3CqlMI/Y= | ||
golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6/go.mod h1:3jZMyOhIsHpP37uCMkUooju7aAi5cS1Q23tOzKc+0MU= | ||
golang.org/x/image v0.0.0-20180708004352-c73c2afc3b81/go.mod h1:ux5Hcp/YLpHSI86hEcLt0YII63i6oz57MZXIpbrjZUs= | ||
golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js= | ||
golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= | ||
golang.org/x/mobile v0.0.0-20190719004257-d2bd2a29d028/go.mod h1:E/iHnbuqvinMTCcRqshq8CkpyQDoeVncDDYHnLhea+o= | ||
golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= | ||
golang.org/x/mod v0.1.1-0.20191107180719-034126e5016b/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= | ||
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= | ||
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= | ||
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= | ||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= | ||
golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= | ||
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= | ||
golang.org/x/sys v0.0.0-20191001151750-bb3f8db39f24/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= | ||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= | ||
golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= | ||
golang.org/x/tools v0.0.0-20190206041539-40960b6deb8e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= | ||
golang.org/x/tools v0.0.0-20200207183749-b753a1ba74fa/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= | ||
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= | ||
gonum.org/v1/gonum v0.0.0-20180816165407-929014505bf4/go.mod h1:Y+Yx5eoAFn32cQvJDxZx5Dpnq+c3wtXuadVZAcxbbBo= | ||
gonum.org/v1/gonum v0.8.2 h1:CCXrcPKiGGotvnN6jfUsKk4rRqm7q09/YbKb5xCEvtM= | ||
gonum.org/v1/gonum v0.8.2/go.mod h1:oe/vMfY3deqTw+1EZJhuvEW2iwGF1bW9wwu7XCu0+v0= | ||
gonum.org/v1/netlib v0.0.0-20190313105609-8cb42192e0e0 h1:OE9mWmgKkjJyEmDAAtGMPjXu+YNeGvK9VTSHY6+Qihc= | ||
gonum.org/v1/netlib v0.0.0-20190313105609-8cb42192e0e0/go.mod h1:wa6Ws7BG/ESfp6dHfk7C6KdzKA7wR7u/rKwOGE66zvw= | ||
gonum.org/v1/plot v0.0.0-20190515093506-e2840ee46a6b/go.mod h1:Wt8AAjI+ypCyYX3nZBvf6cAIx93T+c/OS2HFAYskSZc= | ||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= | ||
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= | ||
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= | ||
rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
package main | ||
|
||
import ( | ||
"encoding/json" | ||
"fmt" | ||
"io" | ||
"log" | ||
"net/http" | ||
"os" | ||
|
||
"github.com/gorilla/handlers" | ||
"github.com/gorilla/mux" | ||
) | ||
|
||
// In a real system, rewards are stored in a DB, but for purposes of the demo we'll just hard-code some example values | ||
var campaignRewards map[int][]struct{ Alpha, Beta float64 } | ||
|
||
func init() { | ||
campaignRewards = make(map[int][]struct{ Alpha, Beta float64 }) | ||
|
||
campaignRewards[1] = []struct{ Alpha, Beta float64 }{ | ||
{10, 125}, | ||
{4, 130}, | ||
{16, 80}, | ||
{25, 99}, | ||
} | ||
|
||
campaignRewards[2] = []struct{ Alpha, Beta float64 }{ | ||
{25, 125}, | ||
{5, 50}, | ||
{7, 90}, | ||
{13, 200}, | ||
} | ||
} | ||
|
||
// This function handles incoming post requests to the /rewards endpoint | ||
func handler(w http.ResponseWriter, r *http.Request) { | ||
|
||
// The request body must contain a JSON object with at least a "campaign_id" key and and integer value | ||
var req struct { | ||
CampaignID *int `json:"campaign_id"` | ||
} | ||
|
||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil { | ||
http.Error(w, err.Error(), http.StatusBadRequest) | ||
return | ||
} | ||
|
||
if req.CampaignID == nil { | ||
http.Error(w, "missing required key \"campaign_id\"", http.StatusBadRequest) | ||
return | ||
} | ||
|
||
// get the context-dependent reward estimates | ||
rewards, ok := campaignRewards[*req.CampaignID] | ||
if !ok { | ||
http.Error(w, fmt.Sprintf("no rewards for campaign ID %d", req.CampaignID), http.StatusBadRequest) | ||
return | ||
} | ||
|
||
// send a JSON-encoded response | ||
w.Header().Set("Content-Type", "application/json") | ||
_ = json.NewEncoder(w).Encode(rewards) | ||
} | ||
|
||
func main() { | ||
r := mux.NewRouter() | ||
r.HandleFunc("/rewards", handler).Methods("POST") | ||
r.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { | ||
w.WriteHeader(http.StatusOK) | ||
io.WriteString(w, `{"alive": true}`) | ||
}) | ||
|
||
server := &http.Server{ | ||
Handler: handlers.LoggingHandler(os.Stdout, r), | ||
Addr: "0.0.0.0:80", | ||
} | ||
|
||
log.Fatal(server.ListenAndServe()) | ||
} |