Skip to content

Commit

Permalink
Update demo 2 (#17)
Browse files Browse the repository at this point in the history
* update README

* move adapter code into separate file
  • Loading branch information
Brian Amadio authored Mar 16, 2021
1 parent 323274c commit 9ee13de
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 75 deletions.
19 changes: 11 additions & 8 deletions examples/superstream_demo/README.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Contextual multi-armed bandit

With microservice reward source

## Starting the services:
Expand All @@ -12,22 +13,24 @@ With microservice reward source
The bandit service is a stateless app that uses Thompson sampling to select an arm for a contextual multi-armed bandit.
It depends on the reward service, which provides reward estimates for each arm depending on context.

The bandit service does not use the context directly, but just passes it to the reward service.
The reward service is responsible for validating the context.
The bandit service does not use the context directly, but just passes it to the reward service. The reward service is
responsible for validating the context.

For example:

`curl -XPOST localhost:1338/randomize -d '{"unit": "visitor_id:12345", "context": {"campaign_id": 1}}'`
`curl -XPOST localhost:1338/select_arm -d '{"unit": "visitor_id:12345", "context": {"source_id": 1}}'`

The bandit service will use the context value as the body of a POST request to the reward service.
The bandit service will pass the value under the "context" key as a top-level JSON object in the request to the reward
service.

### Reward

The reward service is a stateful service that provides reward estimates given a context.
In this basic example the rewards are hard-coded, but a real reward service would be connected to a DB.
The reward service is a stateful service that provides reward estimates given a context. In this basic example the
rewards are hard-coded, but a real reward service would be connected to a DB.

You can query the reward service directly with:

`curl -i -XPOST localhost:1337/rewards -d '{"campaign_id": 1}'`
`curl -i -XPOST localhost:1337/rewards -d '{"source_id": 1}'`

The reward service returns an error if the context is invalid, otherwise it returns the reward estimate for each arm.
The reward service returns an error if the context is invalid or there are no reward estimates for the given context;
otherwise, it returns the reward estimate for each arm.
72 changes: 72 additions & 0 deletions examples/superstream_demo/bandit/adapter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
package main

import (
"encoding/json"
"errors"
"fmt"
"io"
"net/http"

"github.com/stitchfix/mab"
)

// adapter wraps a mab.Bandit so it can be served over HTTP: its methods decode
// the request, call the bandit, and write the response or an error.
type adapter struct {
// bandit is the bandit that selectArm asks to select an arm.
bandit mab.Bandit
}

// selectArmRequest is the JSON request body that decodeRequestBody decodes.
type selectArmRequest struct {
// Unit is read from the "unit" key and passed to the bandit's SelectArm call.
Unit string `json:"unit"`
// Context holds the raw, undecoded JSON found under the "context" key; it is
// passed to the bandit's SelectArm call as-is.
Context json.RawMessage `json:"context"`
}

// selectArmResponse is the JSON body that selectArm writes on success. Each
// field is copied from the same-named field of the bandit's SelectArm result.
type selectArmResponse struct {
// Rewards is presumably one reward distribution per arm — confirm against mab.
Rewards []mab.Dist `json:"rewards"`
// Probs is presumably the per-arm selection probabilities — confirm against mab.
Probs []float64 `json:"probs"`
// Arm is presumably the index of the selected arm — confirm against mab.
Arm int `json:"arm"`
}

// selectArm handles a select-arm request: it decodes the JSON request body,
// asks the bandit to select an arm, and writes the selection as JSON.
//
// Responses:
//   - 400 if the request body is empty or cannot be decoded
//   - the status chosen by writeError if the bandit returns an error
//   - 500 if the result cannot be encoded as JSON
//   - 200 with a selectArmResponse body otherwise
func (a adapter) selectArm(w http.ResponseWriter, r *http.Request) {
	defer r.Body.Close()

	var req selectArmRequest
	if err := a.decodeRequestBody(r.Body, &req); err != nil {
		http.Error(w, err.Error(), http.StatusBadRequest)
		return
	}

	result, err := a.bandit.SelectArm(r.Context(), req.Unit, req.Context)
	if err != nil {
		a.writeError(w, err)
		return
	}

	// Marshal before touching the response so that an encoding failure (for
	// example a NaN probability, which JSON cannot represent) is reported as a
	// 500 instead of being dropped and leaving the client with an empty 200.
	body, err := json.Marshal(selectArmResponse{
		Rewards: result.Rewards,
		Probs:   result.Probs,
		Arm:     result.Arm,
	})
	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}

	// Without an explicit Content-Type, net/http sniffs the body and labels
	// the JSON payload as text/plain.
	w.Header().Set("Content-Type", "application/json")

	// The trailing newline matches what json.Encoder.Encode emitted before.
	// A write error means the client has gone away, so there is nobody left
	// to report it to.
	_, _ = w.Write(append(body, '\n'))
}

// decodeRequestBody reads one JSON value from b into req.
//
// An empty body is reported as a "request body empty" error; any other
// decoding failure is returned unchanged. It returns nil on success.
func (a adapter) decodeRequestBody(b io.Reader, req *selectArmRequest) error {
	decodeErr := json.NewDecoder(b).Decode(req)

	switch {
	case decodeErr == io.EOF:
		// The decoder hit end-of-input before finding any JSON value.
		return fmt.Errorf("request body empty")
	case decodeErr != nil:
		return decodeErr
	default:
		return nil
	}
}

// writeError reports err to the client as a plain-text HTTP error.
//
// If err is, or wraps, a *mab.ErrRewardNon2XX, that error's StatusCode is used
// as the response status; any other error is reported as a 500.
func (a adapter) writeError(w http.ResponseWriter, err error) {
	var non2XX *mab.ErrRewardNon2XX
	if errors.As(err, &non2XX) {
		// Take the status from the value errors.As extracted. Type-asserting
		// err itself panics whenever the reward error is wrapped in another
		// error, which is exactly the case errors.As is there to handle.
		http.Error(w, err.Error(), non2XX.StatusCode)
		return
	}
	http.Error(w, err.Error(), http.StatusInternalServerError)
}
54 changes: 1 addition & 53 deletions examples/superstream_demo/bandit/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,6 @@ package main

import (
"encoding/json"
"errors"
"fmt"
"io"
"log"
"net/http"
"os"
Expand All @@ -29,7 +26,7 @@ func main() {
}

r := mux.NewRouter()
r.HandleFunc("/select_arm", handler{bandit}.selectArm).Methods("POST")
r.HandleFunc("/select_arm", adapter{bandit}.selectArm).Methods("POST")

server := &http.Server{
Handler: handlers.LoggingHandler(os.Stdout, r),
Expand All @@ -38,52 +35,3 @@ func main() {

log.Fatal(server.ListenAndServe())
}

// handler wraps a mab.Bandit so it can be served over HTTP: its methods decode
// the request, call the bandit, and write the response or an error.
type handler struct {
// bandit is the bandit that selectArm asks to select an arm.
bandit mab.Bandit
}

// selectArmRequest is the JSON request body that decodeRequestBody decodes.
type selectArmRequest struct {
// Unit is read from the "unit" key and passed to the bandit's SelectArm call.
Unit string `json:"unit"`
// Context holds the raw, undecoded JSON found under the "context" key; it is
// passed to the bandit's SelectArm call as-is.
Context json.RawMessage `json:"context"`
}

// selectArm handles a select-arm request: it decodes the JSON request body,
// asks the bandit to select an arm, and writes the bandit's result as JSON.
//
// Responses:
//   - 400 if the request body is empty or cannot be decoded
//   - the status chosen by writeError if the bandit returns an error
//   - 500 if the result cannot be encoded as JSON
//   - 200 with the JSON-encoded result otherwise
func (h handler) selectArm(w http.ResponseWriter, r *http.Request) {
	defer r.Body.Close()

	var req selectArmRequest
	if err := h.decodeRequestBody(r.Body, &req); err != nil {
		http.Error(w, err.Error(), http.StatusBadRequest)
		return
	}

	result, err := h.bandit.SelectArm(r.Context(), req.Unit, req.Context)
	if err != nil {
		h.writeError(w, err)
		return
	}

	// Marshal before touching the response so that an encoding failure is
	// reported as a 500 instead of being dropped and leaving the client with
	// an empty 200.
	body, err := json.Marshal(result)
	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}

	// Without an explicit Content-Type, net/http sniffs the body and labels
	// the JSON payload as text/plain.
	w.Header().Set("Content-Type", "application/json")

	// The trailing newline matches what json.Encoder.Encode emitted before.
	// A write error means the client has gone away, so there is nobody left
	// to report it to.
	_, _ = w.Write(append(body, '\n'))
}

// decodeRequestBody reads one JSON value from b into req.
//
// An empty body is reported as a "request body empty" error; any other
// decoding failure is returned unchanged. It returns nil on success.
func (h handler) decodeRequestBody(b io.Reader, req *selectArmRequest) error {
	decodeErr := json.NewDecoder(b).Decode(req)

	switch {
	case decodeErr == io.EOF:
		// The decoder hit end-of-input before finding any JSON value.
		return fmt.Errorf("request body empty")
	case decodeErr != nil:
		return decodeErr
	default:
		return nil
	}
}

// writeError reports err to the client as a plain-text HTTP error.
//
// If err is, or wraps, a *mab.ErrRewardNon2XX, that error's StatusCode is used
// as the response status; any other error is reported as a 500.
func (h handler) writeError(w http.ResponseWriter, err error) {
	var non2XX *mab.ErrRewardNon2XX
	if errors.As(err, &non2XX) {
		// Take the status from the value errors.As extracted. Type-asserting
		// err itself panics whenever the reward error is wrapped in another
		// error, which is exactly the case errors.As is there to handle.
		http.Error(w, err.Error(), non2XX.StatusCode)
		return
	}
	http.Error(w, err.Error(), http.StatusInternalServerError)
}
35 changes: 21 additions & 14 deletions examples/superstream_demo/reward/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,32 +13,39 @@ import (
)

// In a real system, rewards are stored in a DB, but for purposes of the demo we'll just hard-code some example values
var campaignRewards map[int][]struct{ Alpha, Beta float64 }
var rewards map[int][]struct{ Alpha, Beta float64 }

func init() {
campaignRewards = make(map[int][]struct{ Alpha, Beta float64 })
rewards = make(map[int][]struct{ Alpha, Beta float64 })

campaignRewards[1] = []struct{ Alpha, Beta float64 }{
rewards[0] = []struct{ Alpha, Beta float64 }{
{10, 125},
{34, 130},
{26, 95},
{25, 99},
}

campaignRewards[2] = []struct{ Alpha, Beta float64 }{
{25, 125},
{10, 50},
{7, 35},
{57, 200},
rewards[1] = []struct{ Alpha, Beta float64 }{
{10, 125},
{34, 130},
{26, 95},
{25, 99},
}

rewards[2] = []struct{ Alpha, Beta float64 }{
{50, 250},
{20, 105},
{20, 75},
{110, 399},
}
}

// This function handles incoming post requests to the /rewards endpoint
func handler(w http.ResponseWriter, r *http.Request) {

// The request body must contain a JSON object with at least a "campaign_id" key with an integer value
// The request body must contain a JSON object with at least a "source_id" key with an integer value
var req struct {
CampaignID *int `json:"campaign_id"`
SourceID *int `json:"source_id"`
}

if err := json.NewDecoder(r.Body).Decode(&req); err == io.EOF {
Expand All @@ -49,15 +56,15 @@ func handler(w http.ResponseWriter, r *http.Request) {
return
}

if req.CampaignID == nil {
http.Error(w, "missing required key \"campaign_id\"", http.StatusBadRequest)
if req.SourceID == nil {
http.Error(w, "missing required key \"source_id\"", http.StatusBadRequest)
return
}

// get the context-dependent reward estimates
rewards, ok := campaignRewards[*req.CampaignID]
rewards, ok := rewards[*req.SourceID]
if !ok {
http.Error(w, fmt.Sprintf("no rewards for campaign ID %d", *req.CampaignID), http.StatusBadRequest)
http.Error(w, fmt.Sprintf("no rewards for source ID %d", *req.SourceID), http.StatusBadRequest)
return
}

Expand Down

0 comments on commit 9ee13de

Please sign in to comment.