Skip to content

Commit

Permalink
Numint docs (#9)
Browse files Browse the repository at this point in the history
  • Loading branch information
btamadio authored Feb 26, 2021
1 parent a8d718b commit df61949
Show file tree
Hide file tree
Showing 8 changed files with 181 additions and 10 deletions.
5 changes: 3 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,9 @@ go get -u github.com/stitchfix/mab

### Bandit

A `Bandit` consists of three components: a `RewardSource`, a `Strategy` and a `Sampler`. Users can provide their own
implementations of each component, or use the Mab implementations.
A `Bandit` consists of three components: a `RewardSource`, a `Strategy` and a `Sampler`.
Mab provides implementations of each of these, but you are encouraged to implement your own as well!
Each component is defined by single-method interface, making it relatively simple to fully customize a Mab bandit.

Example:

Expand Down
43 changes: 43 additions & 0 deletions numint/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# numint
One-dimensional numerical quadrature

## Description

Numint is a package for one-dimensional numerical quadrature.
It provides several Newton-Cotes and Gauss Legendre quadrature rules and an algorithm for successive approximations.

Numint was developed for use with the [mab](http://github.com/stitchfix/mab) Thompson sampling multi-armed bandit strategy,
but works as a standalone library for numerical integration.

Numint can be extended by implementing the `Rule` and/or `Subdivider` interfaces.
## Installation

```go
go get -u github.com/stitchfix/mab/numint
```

## Usage

```go
package main
import (
"fmt"

"github.com/stitchfix/mab/numint"
)

func main() {
q := numint.NewQuadrature(numint.WithAbsTol(1E-6))
res, _ := q.Integrate(math.Cos, 0, 1)
fmt.Println(res)
}
```

## Documentation

More detailed refence docs can be found on [pkg.go.dev](https://pkg.go.dev/github.com/stitchfix/mab/numint)

## License

Mab and Numint are licensed under the Apache 2.0 license. See the LICENSE file for terms and conditions for use, reproduction, and
distribution.
7 changes: 7 additions & 0 deletions numint/gauss_legendre.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package numint

// GaussLegendre returns a GaussLegendre rule of specified degree.
// The rules are based off of hard-coded constants, and are implemented up to n=12.
func GaussLegendre(degree int) *GaussLegendreRule {

if degree > len(weightValues) || degree < 1 {
Expand Down Expand Up @@ -27,11 +29,14 @@ func GaussLegendre(degree int) *GaussLegendreRule {
}
}

// GaussLegendreRule provides Weights and Points functions for Gauss Legendre quadrature rules.
type GaussLegendreRule struct {
abscissae, weightCoeffs []float64
weights, points []float64
}

// Weights returns the quadrature weights to use for the interval [a, b].
// The number of points returned depends on the degree of the rule.
func (g *GaussLegendreRule) Weights(a float64, b float64) []float64 {

for i := range g.weightCoeffs {
Expand All @@ -41,6 +46,8 @@ func (g *GaussLegendreRule) Weights(a float64, b float64) []float64 {
return g.weights
}

// Points returns the quadrature sampling points to use for the interval [a, b].
// The number of points returned depends on the degree of the rule.
func (g GaussLegendreRule) Points(a float64, b float64) []float64 {

for i := range g.abscissae {
Expand Down
11 changes: 11 additions & 0 deletions numint/newton_cotes.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
package numint

// NewtonCotesOpen returns an open NewtonCotesRule of the specified degree.
// Open Newton-Cotes rules do not include the endpoints of the interval in the sampling points.
// NewtonCotesOpen is implemented for rules up to degree 7.
func NewtonCotesOpen(degree int) NewtonCotesRule {
var coeffs []float64
switch degree {
Expand All @@ -20,6 +23,9 @@ func NewtonCotesOpen(degree int) NewtonCotesRule {
return NewtonCotesRule{coeffs: coeffs, open: true}
}

// NewtonCotesClosed returns a closed NewtonCotesRule of the specified degree.
// Closed Newton-Cotes rules include the endpoints of the interval in the sampling points.
// NewtonCotesClosed is implemented for rules up to degree 5.
func NewtonCotesClosed(degree int) NewtonCotesRule {
var coeffs []float64
switch degree {
Expand All @@ -38,11 +44,14 @@ func NewtonCotesClosed(degree int) NewtonCotesRule {
return NewtonCotesRule{coeffs: coeffs, open: false}
}

// NewtonCotesRule provides Weights and Points functions for Newton-Codes quadrature rules.
type NewtonCotesRule struct {
coeffs []float64
open bool
}

// Weights returns the quadrature weights to use for the interval [a, b].
// The number of points returned depends on the degree and openness of the rule.
func (n *NewtonCotesRule) Weights(a float64, b float64) []float64 {
weights := make([]float64, len(n.coeffs))
for i := range n.coeffs {
Expand All @@ -51,6 +60,8 @@ func (n *NewtonCotesRule) Weights(a float64, b float64) []float64 {
return weights
}

// Points returns the quadrature sampling points to use for the interval [a, b].
// The number of points returned depends on the degree and openness of the rule.
func (n *NewtonCotesRule) Points(a float64, b float64) []float64 {
if n.degree() <= 0 {
return []float64{}
Expand Down
70 changes: 70 additions & 0 deletions numint/numint_test/quadrature_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
package numint_test

import (
"math"
"testing"

"github.com/stitchfix/mab"
"github.com/stitchfix/mab/numint"
)

func TestQuadrature_Integrate(t *testing.T) {
tests := []struct {
f func(float64) float64
a, b, expected float64
}{
{
func(x float64) float64 { return x },
0, 0,
0,
},
{
func(x float64) float64 { return x },
0, 1,
0.5,
},
{
func(x float64) float64 { return 1.0 / (1 + x*x) },
0, 1,
0.785398,
},
{
mab.Beta(10, 20).Prob,
0, 1,
1,
},
{
mab.Normal(10, 20).Prob,
-700, 900,
1,
},
{
math.Asinh,
-.5, 1,
0.344588,
},
{
func(x float64) float64 { return x * math.Cos(x*x) },
1, 5,
-0.486911,
},
}
tol := 1E-6
q := numint.NewQuadrature(numint.WithAbsTol(tol))

for _, test := range tests {
t.Run("", func(t *testing.T) {
actual, err := q.Integrate(test.f, test.a, test.b)
if err != nil {
t.Fatal(err)
}
if !closeEnough(test.expected, actual, tol) {
t.Errorf("actual not %f, got=%f", test.expected, actual)
}
})
}
}

func closeEnough(a, b, tol float64) bool {
return math.Abs(a-b) < tol
}
8 changes: 8 additions & 0 deletions numint/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,17 @@ package numint

import "math"

// Option is a function that can be passed to NewQuadrature to override the default settings.
type Option func(*Quadrature)

// WithMaxIter sets the max iterations to m
func WithMaxIter(m int) Option {
return func(q *Quadrature) {
q.maxIter = m
}
}

// WithAbsTol sets the absolute tolerance convergence criteria to absTol and sets the relative tolerance to be ignored.
func WithAbsTol(absTol float64) Option {
return func(q *Quadrature) {
q.tol = tolerance{
Expand All @@ -19,6 +22,7 @@ func WithAbsTol(absTol float64) Option {
}
}

// WithRelTol sets the relative tolerance convergence criteria to relTol and sets the absolute tolerance to be ignored.
func WithRelTol(relTol float64) Option {
return func(q *Quadrature) {
q.tol = tolerance{
Expand All @@ -28,6 +32,8 @@ func WithRelTol(relTol float64) Option {
}
}

// WithAbsAndRelTol sets both the absolute and relative tolerances so that the absolute difference and relative differences
// between successive iterations must both meet a threshold for convergence.
func WithAbsAndRelTol(absTol float64, relTol float64) Option {
return func(q *Quadrature) {
q.tol = tolerance{
Expand All @@ -37,12 +43,14 @@ func WithAbsAndRelTol(absTol float64, relTol float64) Option {
}
}

// WithRule sets the rule that should be used for each iteration of numerical quadrature.
func WithRule(rule Rule) Option {
return func(q *Quadrature) {
q.rule = rule
}
}

// WithSubDivider sets the subdivider that should be used to compute the set of sub-intervals for each iteration.
func WithSubDivider(s SubDivider) Option {
return func(q *Quadrature) {
q.subDivider = s
Expand Down
45 changes: 37 additions & 8 deletions numint/quadrature.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// Package numint provides rules and methods for one-dimensional numerical quadrature
package numint

import (
"errors"
"fmt"
"math"
)
Expand All @@ -18,6 +18,12 @@ var defaultRule = GaussLegendre(defaultDegree)
var defaultSubdivider = EquallySpaced(defaultSubIntervals)
var defaultTolerance = tolerance{defaultRelTol, defaultAbsTol}

// NewQuadrature returns a pointer to new Quadrature with any Option arguments applied.
// For example:
// q := NewQuadrature()
// Returns a Quadrature with all default settings.
// The default settings can be overridden with Option functions:
// q := NewQuadrature(WithRule(GaussLegendre(4), WithMaxIter(10), WithRelTol(0.01))
func NewQuadrature(opts ...Option) *Quadrature {
quad := Quadrature{
rule: defaultRule,
Expand All @@ -31,40 +37,63 @@ func NewQuadrature(opts ...Option) *Quadrature {
return &quad
}

// Rule is an interface that provides Weights and sampling Points to be used during numerical quadrature.
type Rule interface {
Weights(a float64, b float64) []float64
Points(a float64, b float64) []float64
}

// SubDivider determines how to sub-divide intervals for each iteration of numerical quadrature.
// It takes a slice of intervals and returns the a flat slice containing all the sub-divided intervals.
// For example, using the EquallySpaced subdivider to divide the intervals [0, 1], [1, 2] into two equally-spaced intervals each:
// sub := EquallySpaced(2)
// result := sub.SubDivide([]Interval{{0, 1}, {1, 2}})
// Results in the sub-intervals:
// []Interval{{0, 0.5}, {0.5, 1}, {1, 1.5}, {1.5, 2}}
type SubDivider interface {
SubDivide([]Interval) []Interval
}

type Integrand func(float64) float64
type integrand func(float64) float64

// Interval represents a finite interval between A and B, where B > A.
type Interval struct {
A float64
B float64
}

// Quadrature contains the rule, subdivider, tolerance, and max iterations for numerical quadrature.
// These fields can all be specified using NewQuadrature with the corresponding option functions.
type Quadrature struct {
rule Rule
tol tolerance
subDivider SubDivider
maxIter int
}

// Integrate computes an estimate of the integral of f from a to b.
// It works by first getting the Points and Weights from the Rule for the interval [a, b]
// then computing the sum of w_i * f(x_i) where w_i are the weights and p_i are the points.
// The next step is to subdivide the original interval [a, b] using the SubDivider,
// then compute the same estimate summed over the sub-intervals.
// This process is repeated until the absolute difference between successive iterations is less than the specified tolerance,
// or until maxIter is reached.
// If absolute tolerance is set using WithAbsTol, only absolute tolerance is checked.
// If relative tolerance is set using WithRelTol, only relative tolerance is checked.
// If both absolute and relative tolerances are set using WithAbsAndRelTol, then the absolute difference must be less than *both* tolerances for the algorithm to converge.
// If the max iteration threshold is reached without reaching the specified tolerance, Integrate returns the final result and an error.
// The max iteration threshold can be specified using WithMaxIter as an argument to NewQuadrature.
func (q Quadrature) Integrate(f func(float64) float64, a float64, b float64) (float64, error) {
if a == b {
return 0, nil
}
if !q.canConverge() {
return math.NaN(), errors.New("integral cannot converge. check tolerance")
return math.NaN(), fmt.Errorf("integral cannot converge. check tolerance")
}
return q.iterativeComposite(f, Interval{a, b})
}

func (q Quadrature) iterativeComposite(f Integrand, interval Interval) (float64, error) {
func (q Quadrature) iterativeComposite(f integrand, interval Interval) (float64, error) {

intervals := []Interval{interval}

Expand All @@ -78,13 +107,13 @@ func (q Quadrature) iterativeComposite(f Integrand, interval Interval) (float64,
prevResult := result
result, err = q.compositeEstimate(f, intervals)
if err != nil {
return math.NaN(), fmt.Errorf(err.Error())
return result, err
}
if q.hasConverged(result, prevResult) {
return result, nil
}
}
return math.NaN(), errors.New("failed to converge")
return result, fmt.Errorf("failed to converge")
}

func (q Quadrature) canConverge() bool {
Expand All @@ -98,7 +127,7 @@ func (q Quadrature) hasConverged(result, prevResult float64) bool {
return relErr <= q.tol.relative && absErr <= q.tol.absolute
}

func (q Quadrature) compositeEstimate(f Integrand, intervals []Interval) (float64, error) {
func (q Quadrature) compositeEstimate(f integrand, intervals []Interval) (float64, error) {
total := 0.0
for i := range intervals {
result, err := q.singleEstimate(f, intervals[i])
Expand All @@ -110,7 +139,7 @@ func (q Quadrature) compositeEstimate(f Integrand, intervals []Interval) (float6
return total, nil
}

func (q Quadrature) singleEstimate(f Integrand, interval Interval) (float64, error) {
func (q Quadrature) singleEstimate(f integrand, interval Interval) (float64, error) {

x := q.rule.Points(interval.A, interval.B)
w := q.rule.Weights(interval.A, interval.B)
Expand Down
2 changes: 2 additions & 0 deletions numint/subdividers.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package numint

// EquallySpaced returns an SubDivider that divides intervals into a set number of equally-spaced sub-intervals.
func EquallySpaced(nSubIntervals int) SubDivider {
return equallySpaced{
nSubIntervals: nSubIntervals,
Expand All @@ -10,6 +11,7 @@ type equallySpaced struct {
nSubIntervals int
}

// SubDivide divides each interval into an equal number of sub-intervals and returns the slice of sub-intervals as a flat slice.
func (e equallySpaced) SubDivide(intervals []Interval) []Interval {
result := make([]Interval, 0)

Expand Down

0 comments on commit df61949

Please sign in to comment.