From cd42f24b22afa40fdcaa3511ea06ec45f87e87f8 Mon Sep 17 00:00:00 2001 From: Matthew Rothenberg Date: Wed, 17 Apr 2024 15:46:02 -0400 Subject: [PATCH] feat: method to set randomness source The internal mechanics of this are a bit inelegant, since unfortunately the global randomness source is not exported, necessitating these nil check methods instead. The API here needs some user feedback. I believe the majority case will want to set this once and not on a per-call basis (cf. the deprecated PickSource method in the previous version), but that needs be validated. --- weightedrand.go | 22 +++++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/weightedrand.go b/weightedrand.go index 03648a7..b8fd40d 100644 --- a/weightedrand.go +++ b/weightedrand.go @@ -37,6 +37,8 @@ type Chooser[T any, W integer] struct { data []Choice[T, W] totals []uint64 max uint64 + + customRand *rand.Rand } // NewChooser initializes a new Chooser for picking from the provided choices. @@ -64,7 +66,13 @@ func NewChooser[T any, W integer](choices ...Choice[T, W]) (*Chooser[T, W], erro return nil, errNoValidChoices } - return &Chooser[T, W]{data: choices, totals: totals, max: runningTotal}, nil + return &Chooser[T, W]{data: choices, totals: totals, max: runningTotal, customRand: nil}, nil +} + +// SetRand applies an optional custom randomness source r for the Chooser. If +// set to nil nil, global rand will be used. +func (c *Chooser[T, W]) SetRand(r *rand.Rand) { + c.customRand = r } // Possible errors returned by NewChooser, preventing the creation of a Chooser @@ -82,9 +90,17 @@ var ( // Pick returns a single weighted random Choice.Item from the Chooser. // -// Utilizes global rand as the source of randomness. Safe for concurrent usage. +// Utilizes global rand as the source of randomness by default, which is safe +// for concurrent usage. If a custom rand source was set with SetRand, that +// source will be used instead. func (c Chooser[T, W]) Pick() T { - r := rand.Uint64N(c.max) + 1 + var r uint64 + if c.customRand == nil { + r = rand.Uint64N(c.max) + 1 + } else { + r = c.customRand.Uint64N(c.max) + 1 + } + i, _ := slices.BinarySearch(c.totals, r) return c.data[i].Item }