Skip to content

Commit

Permalink
add example for superstream demo (#14)
Browse files Browse the repository at this point in the history
  • Loading branch information
Brian Amadio authored Mar 3, 2021
1 parent 1b25682 commit c6ed170
Show file tree
Hide file tree
Showing 7 changed files with 288 additions and 0 deletions.
15 changes: 15 additions & 0 deletions examples/superstream_demo/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
FROM golang:1.16 AS build

ARG service_name

WORKDIR /go/src/app
COPY . .

RUN go mod download
RUN CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -a -installsuffix cgo -o /app ./${service_name}

FROM scratch
COPY --from=build /app /app
EXPOSE ${PORT}

ENTRYPOINT ["/app"]
33 changes: 33 additions & 0 deletions examples/superstream_demo/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Contextual multi-armed bandit
With microservice reward source

## Starting the services:

`docker compose build && docker compose up`

## Services

### Bandit

The bandit service is a stateless app that uses Thompson sampling to select an arm for a contextual multi-armed bandit.
It depends on the reward service, which provides reward estimates for each arm depending on context.

The bandit service does not use the context directly, but just passes it to the reward service.
The reward service is responsible for validating the context.

For example:

`curl -XPOST localhost:1338/randomize -d '{"unit": "visitor_id:12345", "context": {"campaign_id": 1}}'`

The bandit service will use the context value as the body of a POST request to the reward service.

### Reward

The reward service is a stateful service that provides reward estimates given a context.
In this basic example the rewards are hard-coded, but a real reward service would be connected to a DB.

You can query the reward service directly with:

`curl -i -XPOST localhost:1337/rewards -d '{"campaign_id": 1}'`

The reward service returns an error if the context is invalid, otherwise it returns the reward estimate for each arm.
72 changes: 72 additions & 0 deletions examples/superstream_demo/bandit/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
package main

import (
"encoding/json"
"errors"
"io"
"log"
"net/http"
"os"
"time"

"github.com/gorilla/handlers"
"github.com/gorilla/mux"
"github.com/stitchfix/mab"
"github.com/stitchfix/mab/numint"
)

type randomizeRequest struct {
Unit string `json:"unit"`
Context json.RawMessage `json:"context"`
}

var bandit mab.Bandit

func handler(w http.ResponseWriter, r *http.Request) {
var req randomizeRequest

if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}

result, err := bandit.SelectArm(r.Context(), req.Unit, req.Context)
if err != nil {
var non200 *mab.ErrRewardNon2XX
if errors.As(err, &non200) {
http.Error(w, err.Error(), err.(*mab.ErrRewardNon2XX).StatusCode)
return
}
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}

_ = json.NewEncoder(w).Encode(result)
}

func main() {
cli := &http.Client{Timeout: time.Second}
url := "http://reward-service/rewards"
parser := mab.ParseFunc(mab.BetaFromJSON)
marshaler := mab.MarshalFunc(json.Marshal)

bandit = mab.Bandit{
RewardSource: mab.NewHTTPSource(cli, url, parser, mab.WithContextMarshaler(marshaler)),
Strategy: mab.NewThompson(numint.NewQuadrature()),
Sampler: mab.NewSha1Sampler(),
}

r := mux.NewRouter()
r.HandleFunc("/randomize", handler).Methods("POST")
r.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
io.WriteString(w, `{"alive": true}`)
})

server := &http.Server{
Handler: handlers.LoggingHandler(os.Stdout, r),
Addr: "0.0.0.0:80",
}

log.Fatal(server.ListenAndServe())
}
21 changes: 21 additions & 0 deletions examples/superstream_demo/docker-compose.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
version: "3"

services:
reward-service:
build:
args:
service_name: reward
environment:
- PORT=1337
ports:
- "1337:80"
bandit-service:
build:
args:
service_name: bandit
environment:
- PORT=1338
ports:
- "1338:80"
depends_on:
- reward-service
9 changes: 9 additions & 0 deletions examples/superstream_demo/go.mod
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
module github.com/stitchfix/mab/examples/superstream_demo

go 1.14

require (
github.com/gorilla/handlers v1.5.1
github.com/gorilla/mux v1.8.0
github.com/stitchfix/mab v0.1.1
)
58 changes: 58 additions & 0 deletions examples/superstream_demo/go.sum
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU=
github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo=
github.com/ajstarks/svgo v0.0.0-20180226025133-644b8db467af/go.mod h1:K08gAheRH3/J6wwsYMMT4xOr94bZjxIelGM0+d/wbFw=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/felixge/httpsnoop v1.0.1 h1:lvB5Jl89CsZtGIWuTcDM1E/vkVs49/Ml7JJe07l8SPQ=
github.com/felixge/httpsnoop v1.0.1/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
github.com/fogleman/gg v1.2.1-0.20190220221249-0403632d5b90/go.mod h1:R/bRT+9gY/C5z7JzPU0zXsXHKM4/ayA+zqcVNZzPa1k=
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0/go.mod h1:E/TSTwGwJL78qG/PmXZO1EjYhfJinVAhrmmHX6Z8B9k=
github.com/gorilla/handlers v1.5.1 h1:9lRY6j8DEeeBT10CvO9hGW0gmky0BprnvDI5vfhUHH4=
github.com/gorilla/handlers v1.5.1/go.mod h1:t8XrUpc4KVXb7HGyJ4/cEnwQiaxrX/hz1Zv/4g96P1Q=
github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI=
github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So=
github.com/jung-kurt/gofpdf v1.0.3-0.20190309125859-24315acbbda5/go.mod h1:7Id9E/uU8ce6rXgefFLlgrJj/GYY22cpxn+r32jIOes=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stitchfix/mab v0.1.1 h1:UMijkS857AyLd8VDqvwm+OBlGm3Pni70CTzq51Hs/Vw=
github.com/stitchfix/mab v0.1.1/go.mod h1:8XNtsDrZu9hDAixMR4Up33NUNalw3UZ5urBC4LYhYuY=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190125153040-c74c464bbbf2/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6 h1:QE6XYQK6naiK1EPAe1g/ILLxN5RBoH5xkJk3CqlMI/Y=
golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6/go.mod h1:3jZMyOhIsHpP37uCMkUooju7aAi5cS1Q23tOzKc+0MU=
golang.org/x/image v0.0.0-20180708004352-c73c2afc3b81/go.mod h1:ux5Hcp/YLpHSI86hEcLt0YII63i6oz57MZXIpbrjZUs=
golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js=
golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
golang.org/x/mobile v0.0.0-20190719004257-d2bd2a29d028/go.mod h1:E/iHnbuqvinMTCcRqshq8CkpyQDoeVncDDYHnLhea+o=
golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
golang.org/x/mod v0.1.1-0.20191107180719-034126e5016b/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191001151750-bb3f8db39f24/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190206041539-40960b6deb8e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20200207183749-b753a1ba74fa/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
gonum.org/v1/gonum v0.0.0-20180816165407-929014505bf4/go.mod h1:Y+Yx5eoAFn32cQvJDxZx5Dpnq+c3wtXuadVZAcxbbBo=
gonum.org/v1/gonum v0.8.2 h1:CCXrcPKiGGotvnN6jfUsKk4rRqm7q09/YbKb5xCEvtM=
gonum.org/v1/gonum v0.8.2/go.mod h1:oe/vMfY3deqTw+1EZJhuvEW2iwGF1bW9wwu7XCu0+v0=
gonum.org/v1/netlib v0.0.0-20190313105609-8cb42192e0e0 h1:OE9mWmgKkjJyEmDAAtGMPjXu+YNeGvK9VTSHY6+Qihc=
gonum.org/v1/netlib v0.0.0-20190313105609-8cb42192e0e0/go.mod h1:wa6Ws7BG/ESfp6dHfk7C6KdzKA7wR7u/rKwOGE66zvw=
gonum.org/v1/plot v0.0.0-20190515093506-e2840ee46a6b/go.mod h1:Wt8AAjI+ypCyYX3nZBvf6cAIx93T+c/OS2HFAYskSZc=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4=
80 changes: 80 additions & 0 deletions examples/superstream_demo/reward/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
package main

import (
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"os"

"github.com/gorilla/handlers"
"github.com/gorilla/mux"
)

// In a real system, rewards are stored in a DB, but for purposes of the demo we'll just hard-code some example values
var campaignRewards map[int][]struct{ Alpha, Beta float64 }

func init() {
campaignRewards = make(map[int][]struct{ Alpha, Beta float64 })

campaignRewards[1] = []struct{ Alpha, Beta float64 }{
{10, 125},
{4, 130},
{16, 80},
{25, 99},
}

campaignRewards[2] = []struct{ Alpha, Beta float64 }{
{25, 125},
{5, 50},
{7, 90},
{13, 200},
}
}

// This function handles incoming post requests to the /rewards endpoint
func handler(w http.ResponseWriter, r *http.Request) {

// The request body must contain a JSON object with at least a "campaign_id" key and and integer value
var req struct {
CampaignID *int `json:"campaign_id"`
}

if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}

if req.CampaignID == nil {
http.Error(w, "missing required key \"campaign_id\"", http.StatusBadRequest)
return
}

// get the context-dependent reward estimates
rewards, ok := campaignRewards[*req.CampaignID]
if !ok {
http.Error(w, fmt.Sprintf("no rewards for campaign ID %d", req.CampaignID), http.StatusBadRequest)
return
}

// send a JSON-encoded response
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(rewards)
}

func main() {
r := mux.NewRouter()
r.HandleFunc("/rewards", handler).Methods("POST")
r.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
io.WriteString(w, `{"alive": true}`)
})

server := &http.Server{
Handler: handlers.LoggingHandler(os.Stdout, r),
Addr: "0.0.0.0:80",
}

log.Fatal(server.ListenAndServe())
}

0 comments on commit c6ed170

Please sign in to comment.