diff --git a/weightedrand.go b/weightedrand.go index 03648a7..b8fd40d 100644 --- a/weightedrand.go +++ b/weightedrand.go @@ -37,6 +37,8 @@ type Chooser[T any, W integer] struct { data []Choice[T, W] totals []uint64 max uint64 + + customRand *rand.Rand } // NewChooser initializes a new Chooser for picking from the provided choices. @@ -64,7 +66,13 @@ func NewChooser[T any, W integer](choices ...Choice[T, W]) (*Chooser[T, W], erro return nil, errNoValidChoices } - return &Chooser[T, W]{data: choices, totals: totals, max: runningTotal}, nil + return &Chooser[T, W]{data: choices, totals: totals, max: runningTotal, customRand: nil}, nil +} + +// SetRand applies an optional custom randomness source r for the Chooser. If +// set to nil nil, global rand will be used. +func (c *Chooser[T, W]) SetRand(r *rand.Rand) { + c.customRand = r } // Possible errors returned by NewChooser, preventing the creation of a Chooser @@ -82,9 +90,17 @@ var ( // Pick returns a single weighted random Choice.Item from the Chooser. // -// Utilizes global rand as the source of randomness. Safe for concurrent usage. +// Utilizes global rand as the source of randomness by default, which is safe +// for concurrent usage. If a custom rand source was set with SetRand, that +// source will be used instead. func (c Chooser[T, W]) Pick() T { - r := rand.Uint64N(c.max) + 1 + var r uint64 + if c.customRand == nil { + r = rand.Uint64N(c.max) + 1 + } else { + r = c.customRand.Uint64N(c.max) + 1 + } + i, _ := slices.BinarySearch(c.totals, r) return c.data[i].Item }