Skip to content

Commit

Permalink
ParseReward takes []byte instead of io.ReadCloser (#13)
Browse files Browse the repository at this point in the history
  • Loading branch information
btamadio authored Mar 1, 2021
1 parent 456abc3 commit 1b25682
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 39 deletions.
45 changes: 14 additions & 31 deletions http_reward_source.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,18 +63,22 @@ func (h *HTTPSource) GetRewards(ctx context.Context, banditContext interface{})
if err != nil {
return nil, err
}
defer resp.Body.Close()

data, err := ioutil.ReadAll(resp.Body)
if err != nil {
return nil, err
}

if resp.StatusCode < 200 || resp.StatusCode > 299 {
defer resp.Body.Close()
errMsg, _ := ioutil.ReadAll(resp.Body)
return nil, &ErrRewardNon2XX{
Url: h.url,
StatusCode: resp.StatusCode,
RespBody: string(errMsg),
RespBody: string(data),
}
}

return h.parser.Parse(resp.Body)
return h.parser.Parse(data)
}

type ErrRewardNon2XX struct {
Expand All @@ -95,7 +99,7 @@ type HttpDoer interface {

// RewardParser will be called to convert the response from the reward service to a slice of distributions.
type RewardParser interface {
Parse(io.ReadCloser) ([]Dist, error)
Parse([]byte) ([]Dist, error)
}

// ContextMarshaler is called on the banditContext and the result will become the body of the request to the bandit service.
Expand All @@ -114,9 +118,9 @@ func WithContextMarshaler(m ContextMarshaler) HTTPSourceOption {
}

// ParseFunc is an adapter to allow a normal function to be used as a RewardParser
type ParseFunc func(io.ReadCloser) ([]Dist, error)
type ParseFunc func([]byte) ([]Dist, error)

func (p ParseFunc) Parse(rc io.ReadCloser) ([]Dist, error) { return p(rc) }
func (p ParseFunc) Parse(b []byte) ([]Dist, error) { return p(b) }

// MarshalFunc is an adapter to allow a normal function to be used as a ContextMarshaler
type MarshalFunc func(banditContext interface{}) ([]byte, error)
Expand All @@ -128,14 +132,7 @@ func (m MarshalFunc) Marshal(banditContext interface{}) ([]byte, error) { return
// `[{"alpha": 123, "beta": 456}, {"alpha": 3.1415, "beta": 9.999}]`
// Returns an error if alpha or beta value are missing or less than 1 for any arm.
// Any additional keys are ignored.
func BetaFromJSON(rc io.ReadCloser) ([]Dist, error) {
defer rc.Close()

data, err := ioutil.ReadAll(rc)
if err != nil {
return nil, fmt.Errorf("failed to read data: %w", err)
}

func BetaFromJSON(data []byte) ([]Dist, error) {
var resp []struct {
Alpha *float64 `json:"alpha"`
Beta *float64 `json:"beta"`
Expand Down Expand Up @@ -170,14 +167,7 @@ func BetaFromJSON(rc io.ReadCloser) ([]Dist, error) {
// `[{"mu": 123, "sigma": 456}, {"mu": 3.1415, "sigma": 9.999}]`
// Returns an error if mu or sigma value are missing or sigma is less than 0 for any arm.
// Any additional keys are ignored.
func NormalFromJSON(rc io.ReadCloser) ([]Dist, error) {
defer rc.Close()

data, err := ioutil.ReadAll(rc)
if err != nil {
return nil, fmt.Errorf("failed to read data: %w", err)
}

func NormalFromJSON(data []byte) ([]Dist, error) {
var resp []struct {
Mu *float64 `json:"mu"`
Sigma *float64 `json:"sigma"`
Expand Down Expand Up @@ -208,14 +198,7 @@ func NormalFromJSON(rc io.ReadCloser) ([]Dist, error) {
// Expects the JSON data to be in the form:
// `[{"mu": 123}, {"mu": 3.1415}]`
// Returns an error if mu value is missing for any arm. Any additional keys are ignored.
func PointFromJSON(rc io.ReadCloser) ([]Dist, error) {
defer rc.Close()

data, err := ioutil.ReadAll(rc)
if err != nil {
return nil, fmt.Errorf("failed to read data: %w", err)
}

func PointFromJSON(data []byte) ([]Dist, error) {
var resp []struct {
Mu *float64
}
Expand Down
14 changes: 6 additions & 8 deletions mab_test/http_reward_source_test.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
package mab

import (
"bytes"
"io/ioutil"
"testing"

"github.com/stitchfix/mab"
Expand Down Expand Up @@ -49,7 +47,7 @@ func TestBetaFromJSON(t *testing.T) {

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
actual, err := mab.BetaFromJSON(ioutil.NopCloser(bytes.NewReader(test.data)))
actual, err := mab.BetaFromJSON(test.data)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -96,7 +94,7 @@ func TestBetaFromJSONError(t *testing.T) {
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
_, err := mab.BetaFromJSON(ioutil.NopCloser(bytes.NewReader(test.data)))
_, err := mab.BetaFromJSON(test.data)
if err == nil {
t.Error("expected error but didn't get one")
}
Expand Down Expand Up @@ -149,7 +147,7 @@ func TestNormalFromJSON(t *testing.T) {

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
actual, err := mab.NormalFromJSON(ioutil.NopCloser(bytes.NewReader(test.data)))
actual, err := mab.NormalFromJSON(test.data)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -192,7 +190,7 @@ func TestNormalFromJSONError(t *testing.T) {
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
_, err := mab.NormalFromJSON(ioutil.NopCloser(bytes.NewReader(test.data)))
_, err := mab.NormalFromJSON(test.data)
if err == nil {
t.Error("expected error but didn't get one")
}
Expand Down Expand Up @@ -245,7 +243,7 @@ func TestPointFromJSON(t *testing.T) {

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
actual, err := mab.PointFromJSON(ioutil.NopCloser(bytes.NewReader(test.data)))
actual, err := mab.PointFromJSON(test.data)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -280,7 +278,7 @@ func TestPointFromJSONError(t *testing.T) {
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
_, err := mab.PointFromJSON(ioutil.NopCloser(bytes.NewReader(test.data)))
_, err := mab.PointFromJSON(test.data)
if err == nil {
t.Error("expected error but didn't get one")
}
Expand Down

0 comments on commit 1b25682

Please sign in to comment.