diff --git a/README.md b/README.md index c73dbf9..25d55b6 100644 --- a/README.md +++ b/README.md @@ -43,8 +43,9 @@ go get -u github.com/stitchfix/mab ### Bandit -A `Bandit` consists of three components: a `RewardSource`, a `Strategy` and a `Sampler`. Users can provide their own -implementations of each component, or use the Mab implementations. +A `Bandit` consists of three components: a `RewardSource`, a `Strategy` and a `Sampler`. +Mab provides implementations of each of these, but you are encouraged to implement your own as well! +Each component is defined by a single-method interface, making it relatively simple to fully customize a Mab bandit. Example: diff --git a/numint/README.md b/numint/README.md new file mode 100644 index 0000000..8b3ab1f --- /dev/null +++ b/numint/README.md @@ -0,0 +1,43 @@ +# numint +One-dimensional numerical quadrature + +## Description + +Numint is a package for one-dimensional numerical quadrature. +It provides several Newton-Cotes and Gauss-Legendre quadrature rules and an algorithm for successive approximations. + +Numint was developed for use with the [mab](http://github.com/stitchfix/mab) Thompson sampling multi-armed bandit strategy, +but works as a standalone library for numerical integration. + +Numint can be extended by implementing the `Rule` and/or `SubDivider` interfaces. +## Installation + +```go +go get -u github.com/stitchfix/mab/numint +``` + +## Usage + +```go +package main +import ( + "fmt" + + "github.com/stitchfix/mab/numint" +) + +func main() { + q := numint.NewQuadrature(numint.WithAbsTol(1E-6)) + res, _ := q.Integrate(math.Cos, 0, 1) + fmt.Println(res) +} +``` + +## Documentation + +More detailed reference docs can be found on [pkg.go.dev](https://pkg.go.dev/github.com/stitchfix/mab/numint) + +## License + +Mab and Numint are licensed under the Apache 2.0 license. See the LICENSE file for terms and conditions for use, reproduction, and +distribution. 
\ No newline at end of file diff --git a/numint/gauss_legendre.go b/numint/gauss_legendre.go index 0af3858..9d430b8 100644 --- a/numint/gauss_legendre.go +++ b/numint/gauss_legendre.go @@ -1,5 +1,7 @@ package numint +// GaussLegendre returns a GaussLegendreRule of the specified degree. +// The rules are based on hard-coded constants, and are implemented up to n=12. func GaussLegendre(degree int) *GaussLegendreRule { if degree > len(weightValues) || degree < 1 { @@ -27,11 +29,14 @@ func GaussLegendre(degree int) *GaussLegendreRule { } } +// GaussLegendreRule provides Weights and Points functions for Gauss-Legendre quadrature rules. type GaussLegendreRule struct { abscissae, weightCoeffs []float64 weights, points []float64 } +// Weights returns the quadrature weights to use for the interval [a, b]. +// The number of weights returned depends on the degree of the rule. func (g *GaussLegendreRule) Weights(a float64, b float64) []float64 { for i := range g.weightCoeffs { @@ -41,6 +46,8 @@ func (g *GaussLegendreRule) Weights(a float64, b float64) []float64 { return g.weights } +// Points returns the quadrature sampling points to use for the interval [a, b]. +// The number of points returned depends on the degree of the rule. func (g GaussLegendreRule) Points(a float64, b float64) []float64 { for i := range g.abscissae { diff --git a/numint/newton_cotes.go b/numint/newton_cotes.go index 4b21599..6d81d23 100644 --- a/numint/newton_cotes.go +++ b/numint/newton_cotes.go @@ -1,5 +1,8 @@ package numint +// NewtonCotesOpen returns an open NewtonCotesRule of the specified degree. +// Open Newton-Cotes rules do not include the endpoints of the interval in the sampling points. +// NewtonCotesOpen is implemented for rules up to degree 7. 
func NewtonCotesOpen(degree int) NewtonCotesRule { var coeffs []float64 switch degree { @@ -20,6 +23,9 @@ func NewtonCotesOpen(degree int) NewtonCotesRule { return NewtonCotesRule{coeffs: coeffs, open: true} } +// NewtonCotesClosed returns a closed NewtonCotesRule of the specified degree. +// Closed Newton-Cotes rules include the endpoints of the interval in the sampling points. +// NewtonCotesClosed is implemented for rules up to degree 5. func NewtonCotesClosed(degree int) NewtonCotesRule { var coeffs []float64 switch degree { @@ -38,11 +44,14 @@ func NewtonCotesClosed(degree int) NewtonCotesRule { return NewtonCotesRule{coeffs: coeffs, open: false} } +// NewtonCotesRule provides Weights and Points functions for Newton-Cotes quadrature rules. type NewtonCotesRule struct { coeffs []float64 open bool } +// Weights returns the quadrature weights to use for the interval [a, b]. +// The number of weights returned depends on the degree and openness of the rule. func (n *NewtonCotesRule) Weights(a float64, b float64) []float64 { weights := make([]float64, len(n.coeffs)) for i := range n.coeffs { @@ -51,6 +60,8 @@ func (n *NewtonCotesRule) Weights(a float64, b float64) []float64 { return weights } +// Points returns the quadrature sampling points to use for the interval [a, b]. +// The number of points returned depends on the degree and openness of the rule. 
func (n *NewtonCotesRule) Points(a float64, b float64) []float64 { if n.degree() <= 0 { return []float64{} diff --git a/numint/numint_test/quadrature_test.go b/numint/numint_test/quadrature_test.go new file mode 100644 index 0000000..77e6219 --- /dev/null +++ b/numint/numint_test/quadrature_test.go @@ -0,0 +1,70 @@ +package numint_test + +import ( + "math" + "testing" + + "github.com/stitchfix/mab" + "github.com/stitchfix/mab/numint" +) + +func TestQuadrature_Integrate(t *testing.T) { + tests := []struct { + f func(float64) float64 + a, b, expected float64 + }{ + { + func(x float64) float64 { return x }, + 0, 0, + 0, + }, + { + func(x float64) float64 { return x }, + 0, 1, + 0.5, + }, + { + func(x float64) float64 { return 1.0 / (1 + x*x) }, + 0, 1, + 0.785398, + }, + { + mab.Beta(10, 20).Prob, + 0, 1, + 1, + }, + { + mab.Normal(10, 20).Prob, + -700, 900, + 1, + }, + { + math.Asinh, + -.5, 1, + 0.344588, + }, + { + func(x float64) float64 { return x * math.Cos(x*x) }, + 1, 5, + -0.486911, + }, + } + tol := 1E-6 + q := numint.NewQuadrature(numint.WithAbsTol(tol)) + + for _, test := range tests { + t.Run("", func(t *testing.T) { + actual, err := q.Integrate(test.f, test.a, test.b) + if err != nil { + t.Fatal(err) + } + if !closeEnough(test.expected, actual, tol) { + t.Errorf("actual not %f, got=%f", test.expected, actual) + } + }) + } +} + +func closeEnough(a, b, tol float64) bool { + return math.Abs(a-b) < tol +} diff --git a/numint/options.go b/numint/options.go index 045ac2d..370a7c4 100644 --- a/numint/options.go +++ b/numint/options.go @@ -2,14 +2,17 @@ package numint import "math" +// Option is a function that can be passed to NewQuadrature to override the default settings. type Option func(*Quadrature) +// WithMaxIter sets the max iterations to m func WithMaxIter(m int) Option { return func(q *Quadrature) { q.maxIter = m } } +// WithAbsTol sets the absolute tolerance convergence criteria to absTol and sets the relative tolerance to be ignored. 
func WithAbsTol(absTol float64) Option { return func(q *Quadrature) { q.tol = tolerance{ @@ -19,6 +22,7 @@ func WithAbsTol(absTol float64) Option { } } +// WithRelTol sets the relative tolerance convergence criteria to relTol and sets the absolute tolerance to be ignored. func WithRelTol(relTol float64) Option { return func(q *Quadrature) { q.tol = tolerance{ @@ -28,6 +32,8 @@ func WithRelTol(relTol float64) Option { } } +// WithAbsAndRelTol sets both the absolute and relative tolerances so that the absolute and relative differences +// between successive iterations must both meet their thresholds for convergence. func WithAbsAndRelTol(absTol float64, relTol float64) Option { return func(q *Quadrature) { q.tol = tolerance{ @@ -37,12 +43,14 @@ func WithAbsAndRelTol(absTol float64, relTol float64) Option { } } +// WithRule sets the rule that should be used for each iteration of numerical quadrature. func WithRule(rule Rule) Option { return func(q *Quadrature) { q.rule = rule } } +// WithSubDivider sets the subdivider that should be used to compute the set of sub-intervals for each iteration. func WithSubDivider(s SubDivider) Option { return func(q *Quadrature) { q.subDivider = s diff --git a/numint/quadrature.go b/numint/quadrature.go index 35f9a21..3e98cfa 100644 --- a/numint/quadrature.go +++ b/numint/quadrature.go @@ -1,7 +1,7 @@ +// Package numint provides rules and methods for one-dimensional numerical quadrature. package numint import ( - "errors" "fmt" "math" ) @@ -18,6 +18,12 @@ var defaultRule = GaussLegendre(defaultDegree) var defaultSubdivider = EquallySpaced(defaultSubIntervals) var defaultTolerance = tolerance{defaultRelTol, defaultAbsTol} +// NewQuadrature returns a pointer to a new Quadrature with any Option arguments applied. +// For example: +// q := NewQuadrature() +// Returns a Quadrature with all default settings. 
+// The default settings can be overridden with Option functions: +// q := NewQuadrature(WithRule(GaussLegendre(4)), WithMaxIter(10), WithRelTol(0.01)) func NewQuadrature(opts ...Option) *Quadrature { quad := Quadrature{ rule: defaultRule, @@ -31,22 +37,33 @@ func NewQuadrature(opts ...Option) *Quadrature { return &quad } +// Rule is an interface that provides Weights and sampling Points to be used during numerical quadrature. type Rule interface { Weights(a float64, b float64) []float64 Points(a float64, b float64) []float64 } +// SubDivider determines how to sub-divide intervals for each iteration of numerical quadrature. +// It takes a slice of intervals and returns a flat slice containing all the sub-divided intervals. +// For example, using the EquallySpaced subdivider to divide the intervals [0, 1], [1, 2] into two equally-spaced intervals each: +// sub := EquallySpaced(2) +// result := sub.SubDivide([]Interval{{0, 1}, {1, 2}}) +// Results in the sub-intervals: +// []Interval{{0, 0.5}, {0.5, 1}, {1, 1.5}, {1.5, 2}} type SubDivider interface { SubDivide([]Interval) []Interval } -type Integrand func(float64) float64 +type integrand func(float64) float64 +// Interval represents a finite interval between A and B, where B > A. type Interval struct { A float64 B float64 } +// Quadrature contains the rule, subdivider, tolerance, and max iterations for numerical quadrature. +// These fields can all be specified using NewQuadrature with the corresponding option functions. type Quadrature struct { rule Rule tol tolerance @@ -54,17 +71,29 @@ type Quadrature struct { maxIter int } +// Integrate computes an estimate of the integral of f from a to b. +// It works by first getting the Points and Weights from the Rule for the interval [a, b] +// then computing the sum of w_i * f(x_i) where w_i are the weights and x_i are the points. 
+// The next step is to subdivide the original interval [a, b] using the SubDivider, +// then compute the same estimate summed over the sub-intervals. +// This process is repeated until the difference between successive iterations satisfies the specified tolerance, +// or until maxIter is reached. +// If absolute tolerance is set using WithAbsTol, only absolute tolerance is checked. +// If relative tolerance is set using WithRelTol, only relative tolerance is checked. +// If both absolute and relative tolerances are set using WithAbsAndRelTol, then the absolute difference and the relative difference must *both* be within their respective tolerances for the algorithm to converge. +// If the max iteration threshold is reached without reaching the specified tolerance, Integrate returns the final result and an error. +// The max iteration threshold can be specified using WithMaxIter as an argument to NewQuadrature. func (q Quadrature) Integrate(f func(float64) float64, a float64, b float64) (float64, error) { if a == b { return 0, nil } if !q.canConverge() { - return math.NaN(), errors.New("integral cannot converge. check tolerance") + return math.NaN(), fmt.Errorf("integral cannot converge. 
check tolerance") } return q.iterativeComposite(f, Interval{a, b}) } -func (q Quadrature) iterativeComposite(f Integrand, interval Interval) (float64, error) { +func (q Quadrature) iterativeComposite(f integrand, interval Interval) (float64, error) { intervals := []Interval{interval} @@ -78,13 +107,13 @@ func (q Quadrature) iterativeComposite(f Integrand, interval Interval) (float64, prevResult := result result, err = q.compositeEstimate(f, intervals) if err != nil { - return math.NaN(), fmt.Errorf(err.Error()) + return result, err } if q.hasConverged(result, prevResult) { return result, nil } } - return math.NaN(), errors.New("failed to converge") + return result, fmt.Errorf("failed to converge") } func (q Quadrature) canConverge() bool { @@ -98,7 +127,7 @@ func (q Quadrature) hasConverged(result, prevResult float64) bool { return relErr <= q.tol.relative && absErr <= q.tol.absolute } -func (q Quadrature) compositeEstimate(f Integrand, intervals []Interval) (float64, error) { +func (q Quadrature) compositeEstimate(f integrand, intervals []Interval) (float64, error) { total := 0.0 for i := range intervals { result, err := q.singleEstimate(f, intervals[i]) @@ -110,7 +139,7 @@ func (q Quadrature) compositeEstimate(f Integrand, intervals []Interval) (float6 return total, nil } -func (q Quadrature) singleEstimate(f Integrand, interval Interval) (float64, error) { +func (q Quadrature) singleEstimate(f integrand, interval Interval) (float64, error) { x := q.rule.Points(interval.A, interval.B) w := q.rule.Weights(interval.A, interval.B) diff --git a/numint/subdividers.go b/numint/subdividers.go index d7d4991..2c3f20d 100644 --- a/numint/subdividers.go +++ b/numint/subdividers.go @@ -1,5 +1,6 @@ package numint +// EquallySpaced returns a SubDivider that divides intervals into a set number of equally-spaced sub-intervals. 
func EquallySpaced(nSubIntervals int) SubDivider { return equallySpaced{ nSubIntervals: nSubIntervals, @@ -10,6 +11,7 @@ type equallySpaced struct { nSubIntervals int } +// SubDivide divides each interval into an equal number of sub-intervals and returns the slice of sub-intervals as a flat slice. func (e equallySpaced) SubDivide(intervals []Interval) []Interval { result := make([]Interval, 0)