diff --git a/go.sum b/go.sum index 6ba5ff5..52fbb18 100644 --- a/go.sum +++ b/go.sum @@ -34,6 +34,7 @@ golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzB golang.org/x/mod v0.1.1-0.20191107180719-034126e5016b/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/sync v0.0.0-20190423024810-112230192c58 h1:8gQV6CLnAEikrhgkHFbMAEhagSSnXWGV915qUMm9mrU= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= diff --git a/numint/gauss_legendre.go b/numint/gauss_legendre.go index 9d430b8..fe8c534 100644 --- a/numint/gauss_legendre.go +++ b/numint/gauss_legendre.go @@ -24,37 +24,38 @@ func GaussLegendre(degree int) *GaussLegendreRule { return &GaussLegendreRule{ abscissae: abscissae, weightCoeffs: weightCoeffs, - weights: make([]float64, len(weightCoeffs)), - points: make([]float64, len(abscissae)), } } // GaussLegendreRule provides Weights and Points functions for Gauss Legendre quadrature rules. type GaussLegendreRule struct { abscissae, weightCoeffs []float64 - weights, points []float64 } // Weights returns the quadrature weights to use for the interval [a, b]. // The number of points returned depends on the degree of the rule. func (g *GaussLegendreRule) Weights(a float64, b float64) []float64 { + weights := make([]float64, len(g.weightCoeffs)) + for i := range g.weightCoeffs { - g.weights[i] = g.weightCoeffs[i] * (b - a) / 2 + weights[i] = g.weightCoeffs[i] * (b - a) / 2 } - return g.weights + return weights } // Points returns the quadrature sampling points to use for the interval [a, b]. // The number of points returned depends on the degree of the rule. func (g GaussLegendreRule) Points(a float64, b float64) []float64 { + points := make([]float64, len(g.abscissae)) + for i := range g.abscissae { - g.points[i] = g.abscissae[i]*(b-a)/2 + (b+a)/2 + points[i] = g.abscissae[i]*(b-a)/2 + (b+a)/2 } - return g.points + return points } // source: http://www.holoborodko.com/pavel/numerical-methods/numerical-integration/ diff --git a/thompson.go b/thompson.go index 17b2101..39b5567 100644 --- a/thompson.go +++ b/thompson.go @@ -1,5 +1,9 @@ package mab +import ( + "sync" +) + func NewThompson(integrator Integrator) *Thompson { return &Thompson{ integrator: integrator, @@ -8,8 +12,6 @@ func NewThompson(integrator Integrator) *Thompson { type Thompson struct { integrator Integrator - rewards []Dist - probs []float64 } type Integrator interface { @@ -21,39 +23,63 @@ func (t *Thompson) ComputeProbs(rewards []Dist) ([]float64, error) { return []float64{}, nil } - t.rewards = rewards - return t.computeProbs() + integrals := t.integrals(rewards) + return t.integrateParallel(integrals) } -func (t *Thompson) computeProbs() ([]float64, error) { - t.probs = make([]float64, len(t.rewards)) - for arm := range t.rewards { - prob, err := t.computeProb(arm) - if err != nil { - return nil, err - } - t.probs[arm] = prob - } - return t.probs, nil +type integral struct { + integrand integrand + interval interval } -func (t *Thompson) computeProb(arm int) (float64, error) { - integrand := t.integrand(arm) - xMin, xMax := t.rewards[arm].Support() +type integrand func(float64) float64 +type interval struct{ a, b float64 } - return t.integrator.Integrate(integrand, xMin, xMax) +func (t *Thompson) integrals(rewards []Dist) []integral { + result := make([]integral, len(rewards)) + for i := range rewards { + result[i].integrand = t.integrand(rewards, i) + result[i].interval.a, result[i].interval.b = rewards[i].Support() + } + return result } -func (t *Thompson) integrand(arm int) func(float64) float64 { +func (t *Thompson) integrand(rewards []Dist, arm int) integrand { return func(x float64) float64 { - total := t.rewards[arm].Prob(x) - for j := range t.rewards { + total := rewards[arm].Prob(x) + for j := range rewards { if arm == j { continue } - total *= t.rewards[j].CDF(x) + total *= rewards[j].CDF(x) } return total } } + +func (t *Thompson) integrateParallel(integrals []integral) ([]float64, error) { + n := len(integrals) + + results := make([]float64, n) + errs := make([]error, n) + + var wg sync.WaitGroup + for i := 0; i < n; i++ { + wg.Add(1) + go func(i int, xi integral) { + results[i], errs[i] = t.integrator.Integrate(xi.integrand, xi.interval.a, xi.interval.b) + wg.Done() + }(i, integrals[i]) + } + + wg.Wait() + + for _, err := range errs { + if err != nil { + return nil, err + } + } + + return results, nil +}