From 9ee13de00f64b52d3cf6518de950884343be568c Mon Sep 17 00:00:00 2001 From: Brian Amadio Date: Tue, 16 Mar 2021 14:23:18 -0700 Subject: [PATCH] Update demo 2 (#17) * update README * move adapter code into separate file --- examples/superstream_demo/README.md | 19 +++--- examples/superstream_demo/bandit/adapter.go | 72 +++++++++++++++++++++ examples/superstream_demo/bandit/main.go | 54 +--------------- examples/superstream_demo/reward/main.go | 35 ++++++---- 4 files changed, 105 insertions(+), 75 deletions(-) create mode 100644 examples/superstream_demo/bandit/adapter.go diff --git a/examples/superstream_demo/README.md b/examples/superstream_demo/README.md index a2c25aa..92c3968 100644 --- a/examples/superstream_demo/README.md +++ b/examples/superstream_demo/README.md @@ -1,4 +1,5 @@ # Contextual multi-armed bandit + With microservice reward source ## Starting the services: @@ -12,22 +13,24 @@ With microservice reward source The bandit service is a stateless app that uses Thompson sampling to select an arm for a contextual multi-armed bandit. It depends on the reward service, which provides reward estimates for each arm depending on context. -The bandit service does not use the context directly, but just passes it to the reward service. -The reward service is responsible for validating the context. +The bandit service does not use the context directly, but just passes it to the reward service. The reward service is +responsible for validating the context. For example: -`curl -XPOST localhost:1338/randomize -d '{"unit": "visitor_id:12345", "context": {"campaign_id": 1}}'` +`curl -XPOST localhost:1338/select_arm -d '{"unit": "visitor_id:12345", "context": {"source_id": 1}}'` -The bandit service will use the context value as the body of a POST request to the reward service. +The bandit service will pass the value under the "context" key as a top-level JSON object in the request to the reward +service. 
### Reward -The reward service is a stateful service that provides reward estimates given a context. -In this basic example the rewards are hard-coded, but a real reward service would be connected to a DB. +The reward service is a stateful service that provides reward estimates given a context. In this basic example the +rewards are hard-coded, but a real reward service would be connected to a DB. You can query the reward service directly with: -`curl -i -XPOST localhost:1337/rewards -d '{"campaign_id": 1}'` +`curl -i -XPOST localhost:1337/rewards -d '{"source_id": 1}'` -The reward service returns an error if the context is invalid, otherwise it returns the reward estimate for each arm. \ No newline at end of file +The reward service returns an error if the context is invalid or there are no reward estimates for the given context, +otherwise it returns the reward estimate for each arm. \ No newline at end of file diff --git a/examples/superstream_demo/bandit/adapter.go b/examples/superstream_demo/bandit/adapter.go new file mode 100644 index 0000000..88a5091 --- /dev/null +++ b/examples/superstream_demo/bandit/adapter.go @@ -0,0 +1,72 @@ +package main + +import ( + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + + "github.com/stitchfix/mab" +) + +type adapter struct { + bandit mab.Bandit +} + +type selectArmRequest struct { + Unit string `json:"unit"` + Context json.RawMessage `json:"context"` +} + +type selectArmResponse struct { + Rewards []mab.Dist `json:"rewards"` + Probs []float64 `json:"probs"` + Arm int `json:"arm"` +} + +func (a adapter) selectArm(w http.ResponseWriter, r *http.Request) { + + defer r.Body.Close() + + var req selectArmRequest + + if err := a.decodeRequestBody(r.Body, &req); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + result, err := a.bandit.SelectArm(r.Context(), req.Unit, req.Context) + + if err != nil { + a.writeError(w, err) + return + } + + resp := selectArmResponse{ + Rewards: 
result.Rewards, + Probs: result.Probs, + Arm: result.Arm, + } + + _ = json.NewEncoder(w).Encode(resp) +} + +func (a adapter) decodeRequestBody(b io.Reader, req *selectArmRequest) error { + if err := json.NewDecoder(b).Decode(req); err == io.EOF { + return fmt.Errorf("request body empty") + } else if err != nil { + return err + } + return nil +} + +func (a adapter) writeError(w http.ResponseWriter, err error) { + var non200 *mab.ErrRewardNon2XX + if errors.As(err, &non200) { + http.Error(w, err.Error(), err.(*mab.ErrRewardNon2XX).StatusCode) + return + } + http.Error(w, err.Error(), http.StatusInternalServerError) + return +} diff --git a/examples/superstream_demo/bandit/main.go b/examples/superstream_demo/bandit/main.go index e726b61..f5a084f 100644 --- a/examples/superstream_demo/bandit/main.go +++ b/examples/superstream_demo/bandit/main.go @@ -2,9 +2,6 @@ package main import ( "encoding/json" - "errors" - "fmt" - "io" "log" "net/http" "os" @@ -29,7 +26,7 @@ func main() { } r := mux.NewRouter() - r.HandleFunc("/select_arm", handler{bandit}.selectArm).Methods("POST") + r.HandleFunc("/select_arm", adapter{bandit}.selectArm).Methods("POST") server := &http.Server{ Handler: handlers.LoggingHandler(os.Stdout, r), @@ -38,52 +35,3 @@ func main() { log.Fatal(server.ListenAndServe()) } - -type handler struct { - bandit mab.Bandit -} - -type selectArmRequest struct { - Unit string `json:"unit"` - Context json.RawMessage `json:"context"` -} - -func (h handler) selectArm(w http.ResponseWriter, r *http.Request) { - - defer r.Body.Close() - - var req selectArmRequest - - if err := h.decodeRequestBody(r.Body, &req); err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - - result, err := h.bandit.SelectArm(r.Context(), req.Unit, req.Context) - - if err != nil { - h.writeError(w, err) - return - } - - _ = json.NewEncoder(w).Encode(result) -} - -func (h handler) decodeRequestBody(b io.Reader, req *selectArmRequest) error { - if err := 
json.NewDecoder(b).Decode(req); err == io.EOF { - return fmt.Errorf("request body empty") - } else if err != nil { - return err - } - return nil -} - -func (h handler) writeError(w http.ResponseWriter, err error) { - var non200 *mab.ErrRewardNon2XX - if errors.As(err, &non200) { - http.Error(w, err.Error(), err.(*mab.ErrRewardNon2XX).StatusCode) - return - } - http.Error(w, err.Error(), http.StatusInternalServerError) - return -} diff --git a/examples/superstream_demo/reward/main.go b/examples/superstream_demo/reward/main.go index 686e376..752e401 100644 --- a/examples/superstream_demo/reward/main.go +++ b/examples/superstream_demo/reward/main.go @@ -13,32 +13,39 @@ import ( ) // In a real system, rewards are stored in a DB, but for purposes of the demo we'll just hard-code some example values -var campaignRewards map[int][]struct{ Alpha, Beta float64 } +var rewards map[int][]struct{ Alpha, Beta float64 } func init() { - campaignRewards = make(map[int][]struct{ Alpha, Beta float64 }) + rewards = make(map[int][]struct{ Alpha, Beta float64 }) - campaignRewards[1] = []struct{ Alpha, Beta float64 }{ + rewards[0] = []struct{ Alpha, Beta float64 }{ {10, 125}, {34, 130}, {26, 95}, {25, 99}, } - campaignRewards[2] = []struct{ Alpha, Beta float64 }{ - {25, 125}, - {10, 50}, - {7, 35}, - {57, 200}, + rewards[1] = []struct{ Alpha, Beta float64 }{ + {10, 125}, + {34, 130}, + {26, 95}, + {25, 99}, + } + + rewards[2] = []struct{ Alpha, Beta float64 }{ + {50, 250}, + {20, 105}, + {20, 75}, + {110, 399}, } } // This function handles incoming post requests to the /rewards endpoint func handler(w http.ResponseWriter, r *http.Request) { - // The request body must contain a JSON object with at least a "campaign_id" key and and integer value + // The request body must contain a JSON object with at least a "source_id" key and an integer value var req struct { - CampaignID *int `json:"campaign_id"` + SourceID *int `json:"source_id"` } if err := json.NewDecoder(r.Body).Decode(&req); err 
== io.EOF { @@ -49,15 +56,15 @@ func handler(w http.ResponseWriter, r *http.Request) { return } - if req.CampaignID == nil { - http.Error(w, "missing required key \"campaign_id\"", http.StatusBadRequest) + if req.SourceID == nil { + http.Error(w, "missing required key \"source_id\"", http.StatusBadRequest) return } // get the context-dependent reward estimates - rewards, ok := campaignRewards[*req.CampaignID] + rewards, ok := rewards[*req.SourceID] if !ok { - http.Error(w, fmt.Sprintf("no rewards for campaign ID %d", *req.CampaignID), http.StatusBadRequest) + http.Error(w, fmt.Sprintf("no rewards for source ID %d", *req.SourceID), http.StatusBadRequest) return }