Skip to content

Commit

Permalink
Update demo (#16)
Browse files Browse the repository at this point in the history
* dockerfile cache dependencies

* reward service handle JSON errors correctly

* refactor bandit error handling, update rewards
  • Loading branch information
Brian Amadio authored Mar 16, 2021
1 parent c6ed170 commit 323274c
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 44 deletions.
5 changes: 3 additions & 2 deletions examples/superstream_demo/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@ FROM golang:1.16 AS build
ARG service_name

WORKDIR /go/src/app
COPY . .

COPY go.* .
RUN go mod download

COPY . .
RUN CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -a -installsuffix cgo -o /app ./${service_name}

FROM scratch
Expand Down
85 changes: 51 additions & 34 deletions examples/superstream_demo/bandit/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package main
import (
"encoding/json"
"errors"
"fmt"
"io"
"log"
"net/http"
Expand All @@ -15,58 +16,74 @@ import (
"github.com/stitchfix/mab/numint"
)

type randomizeRequest struct {
func main() {
cli := &http.Client{Timeout: time.Second}
url := "http://reward-service/rewards"
parser := mab.ParseFunc(mab.BetaFromJSON)
marshaler := mab.MarshalFunc(json.Marshal)

bandit := mab.Bandit{
RewardSource: mab.NewHTTPSource(cli, url, parser, mab.WithContextMarshaler(marshaler)),
Strategy: mab.NewThompson(numint.NewQuadrature()),
Sampler: mab.NewSha1Sampler(),
}

r := mux.NewRouter()
r.HandleFunc("/select_arm", handler{bandit}.selectArm).Methods("POST")

server := &http.Server{
Handler: handlers.LoggingHandler(os.Stdout, r),
Addr: "0.0.0.0:80",
}

log.Fatal(server.ListenAndServe())
}

type handler struct {
bandit mab.Bandit
}

type selectArmRequest struct {
Unit string `json:"unit"`
Context json.RawMessage `json:"context"`
}

var bandit mab.Bandit
func (h handler) selectArm(w http.ResponseWriter, r *http.Request) {

defer r.Body.Close()

func handler(w http.ResponseWriter, r *http.Request) {
var req randomizeRequest
var req selectArmRequest

if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
if err := h.decodeRequestBody(r.Body, &req); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}

result, err := bandit.SelectArm(r.Context(), req.Unit, req.Context)
result, err := h.bandit.SelectArm(r.Context(), req.Unit, req.Context)

if err != nil {
var non200 *mab.ErrRewardNon2XX
if errors.As(err, &non200) {
http.Error(w, err.Error(), err.(*mab.ErrRewardNon2XX).StatusCode)
return
}
http.Error(w, err.Error(), http.StatusInternalServerError)
h.writeError(w, err)
return
}

_ = json.NewEncoder(w).Encode(result)
}

func main() {
cli := &http.Client{Timeout: time.Second}
url := "http://reward-service/rewards"
parser := mab.ParseFunc(mab.BetaFromJSON)
marshaler := mab.MarshalFunc(json.Marshal)

bandit = mab.Bandit{
RewardSource: mab.NewHTTPSource(cli, url, parser, mab.WithContextMarshaler(marshaler)),
Strategy: mab.NewThompson(numint.NewQuadrature()),
Sampler: mab.NewSha1Sampler(),
func (h handler) decodeRequestBody(b io.Reader, req *selectArmRequest) error {
if err := json.NewDecoder(b).Decode(req); err == io.EOF {
return fmt.Errorf("request body empty")
} else if err != nil {
return err
}
return nil
}

r := mux.NewRouter()
r.HandleFunc("/randomize", handler).Methods("POST")
r.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
io.WriteString(w, `{"alive": true}`)
})

server := &http.Server{
Handler: handlers.LoggingHandler(os.Stdout, r),
Addr: "0.0.0.0:80",
func (h handler) writeError(w http.ResponseWriter, err error) {
var non200 *mab.ErrRewardNon2XX
if errors.As(err, &non200) {
http.Error(w, err.Error(), err.(*mab.ErrRewardNon2XX).StatusCode)
return
}

log.Fatal(server.ListenAndServe())
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
19 changes: 11 additions & 8 deletions examples/superstream_demo/reward/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,28 +20,31 @@ func init() {

campaignRewards[1] = []struct{ Alpha, Beta float64 }{
{10, 125},
{4, 130},
{16, 80},
{34, 130},
{26, 95},
{25, 99},
}

campaignRewards[2] = []struct{ Alpha, Beta float64 }{
{25, 125},
{5, 50},
{7, 90},
{13, 200},
{10, 50},
{7, 35},
{57, 200},
}
}

// This function handles incoming post requests to the /rewards endpoint
func handler(w http.ResponseWriter, r *http.Request) {

// The request body must contain a JSON object with at least a "campaign_id" key and and integer value
var req struct {
CampaignID *int `json:"campaign_id"`
}

if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
if err := json.NewDecoder(r.Body).Decode(&req); err == io.EOF {
http.Error(w, "request body empty", http.StatusBadRequest)
return
} else if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
Expand All @@ -54,7 +57,7 @@ func handler(w http.ResponseWriter, r *http.Request) {
// get the context-dependent reward estimates
rewards, ok := campaignRewards[*req.CampaignID]
if !ok {
http.Error(w, fmt.Sprintf("no rewards for campaign ID %d", req.CampaignID), http.StatusBadRequest)
http.Error(w, fmt.Sprintf("no rewards for campaign ID %d", *req.CampaignID), http.StatusBadRequest)
return
}

Expand Down

0 comments on commit 323274c

Please sign in to comment.