Skip to content

Commit

Permalink
Optimize sampling from IndexedSeq
Browse files Browse the repository at this point in the history
Indexed sequences allow us to skip over items without examining each
one.
  • Loading branch information
marcusb committed Dec 31, 2024
1 parent 431d5b7 commit 7d1cfdf
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ class ReservoirSamplingBenchmark {
// Benchmarks the reservoir sampler on an integer range of `state.collectionSize` elements,
// drawing `state.samples` of them. The input keeps its static Range type, so an
// IndexedSeq-specific `apply` overload is eligible when the aggregator provides one.
def timeAlgorithmL(state: BenchmarkState, bh: Blackhole): Unit = {
  val sampler = new ReservoirSamplingToListAggregator[Int](state.samples)
  val input = 0 until state.collectionSize
  // Hand the result to the blackhole so it counts as used.
  bh.consume(sampler(input))
}

// Same workload as the indexed benchmark, but the input is statically typed as a plain Seq
// so that overload resolution selects the generic (non-indexed) `apply`.
@Benchmark
def timeAlgorithmLSeq(state: BenchmarkState, bh: Blackhole): Unit = {
  val sampler = new ReservoirSamplingToListAggregator[Int](state.samples)
  // A type ascription widens the static type at compile time; no runtime cast is involved.
  val input: Seq[Int] = 0 until state.collectionSize
  bh.consume(sampler(input))
}

// Benchmarks the sampler returned by `prioQueueSampler` on the same integer range used by
// the other benchmarks in this class.
// NOTE(review): "Qeueue" is misspelled; the name is left untouched because renaming would
// change the method's public name.
@Benchmark
def timePriorityQeueue(state: BenchmarkState, bh: Blackhole): Unit = {
  val sampler = prioQueueSampler(state.samples)
  val input = 0 until state.collectionSize
  bh.consume(sampler(input))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import scala.util.Random
* the element type
*/
sealed class Reservoir[T](val capacity: Int) {
var reservoir: mutable.Buffer[T] = mutable.Buffer()
var reservoir: mutable.ArrayBuffer[T] = new mutable.ArrayBuffer

// When the reservoir is full, w is the threshold for accepting an element into the reservoir, and
// the following invariant holds: The maximum score of the elements in the reservoir is w,
Expand Down Expand Up @@ -52,6 +52,13 @@ sealed class Reservoir[T](val capacity: Int) {
}
}

// Draws how many items to skip before the next one is accepted. That count follows a
// geometric distribution whose success probability is w / prior. For a plain append into
// a single reservoir the prior is 1; when merging, it is the threshold of the reservoir
// the items are pulled from, and w < prior is required in that case.
private def nextAcceptTime(rng: Random, prior: Double = 1.0): Int = {
  val exponentialDraw = rng.self.nextExponential
  val successProbability = w / prior
  // Scale the exponential draw by log(1 - p) and truncate to an Int.
  (-exponentialDraw / Math.log1p(-successProbability)).toInt
}

/**
* Add multiple elements to the reservoir.
* @param xs
Expand All @@ -64,26 +71,55 @@ sealed class Reservoir[T](val capacity: Int) {
* @return
* this reservoir
*/
def append(xs: TraversableOnce[T], rng: Random, prior: Double = 1): Reservoir[T] = {
// The number of items to skip before accepting the next item is geometrically distributed
// with probability of success w / prior. The prior will be 1 when adding to a single reservoir,
// but when merging reservoirs it will be the threshold of the reservoir being pulled from,
// and in this case we require that w < prior.
def nextAcceptTime = (-rng.self.nextExponential / Math.log1p(-w / prior)).toInt

var skip = if (isFull) nextAcceptTime else 0
def append(xs: TraversableOnce[T], rng: Random): Reservoir[T] = {
var skip = if (isFull) nextAcceptTime(rng) else 0
for (x <- xs) {
if (!isFull) {
// keep adding while reservoir is not full
accept(x, rng)
if (isFull) {
skip = nextAcceptTime
skip = nextAcceptTime(rng)
}
} else if (skip > 0) {
skip -= 1
} else {
accept(x, rng)
skip = nextAcceptTime
skip = nextAcceptTime(rng)
}
}
this
}

/**
* Add multiple elements to the reservoir. This overload is optimized for indexed sequences, where we can
* skip over multiple indexes without accessing the elements.
*
* @param xs
* the elements to add
* @param rng
* the random source
* @param prior
* the threshold of the elements being added, such that the added element's value is distributed as
* <pre>U[0, prior]</pre>
* @return
* this reservoir
*/
def append(xs: IndexedSeq[T], rng: Random, prior: Double): Reservoir[T] = {
var i = xs.size.min(capacity - size)
for (j <- 0 until i) {
accept(xs(j), rng)
}
assert(isFull)

val end = xs.size
i -= 1
while (i >= 0 && i < end) {
i += 1 + nextAcceptTime(rng, prior)
// the addition can overflow, in which case i < 0
if (i >= 0 && i < end) {
// element enters the reservoir
reservoir(rng.nextInt(capacity)) = xs(i)
w *= Math.pow(rng.nextDouble, kInv)
}
}
this
Expand Down Expand Up @@ -147,7 +183,7 @@ class ReservoirMonoid[T](implicit val randomSupplier: () => Random) extends Mono
s2.reservoir(i) = s2.reservoir.head
s1.append(s2.reservoir.drop(1), rng, s2.w)
} else {
s1.append(s2.reservoir, rng)
s1.append(s2.reservoir, rng, 1.0)
}
}
}
Expand All @@ -157,6 +193,10 @@ class ReservoirMonoid[T](implicit val randomSupplier: () => Random) extends Mono
* reservoir is mutable, it is a good idea to copy the result to an immutable view before using it, as is done
* by [[ReservoirSamplingToListAggregator]].
*
* The aggregator defines operations for [[IndexedSeq]]s that allow for more efficient aggregation. However,
* care must be taken with methods such as [[composePrepare()]], which return a regular [[MonoidAggregator]]
* that loses this optimized behavior.
*
* @param k
* the number of elements to sample
* @param randomSupplier
Expand All @@ -172,6 +212,7 @@ abstract class ReservoirSamplingAggregator[T, +C](k: Int)(implicit val randomSup
// Turns a single element into a reservoir by delegating to the monoid's `build`
// with the sample size k.
override def prepare(x: T): Reservoir[T] = {
  monoid.build(k, x)
}

// Samples from an arbitrary traversable: aggregate everything into a reservoir, then
// convert that reservoir to the result type with `present`.
override def apply(xs: TraversableOnce[T]): C = {
  val sampled = agg(xs)
  present(sampled)
}
// Overload selected when the argument is statically an IndexedSeq; it routes through the
// IndexedSeq-specific aggregation before presenting the result.
def apply(xs: IndexedSeq[T]): C = {
  val sampled = agg(xs)
  present(sampled)
}

// Returns None for empty input; otherwise wraps the sampled result in Some.
override def applyOption(inputs: TraversableOnce[T]): Option[C] = {
  if (inputs.isEmpty) {
    None
  } else {
    val sample = apply(inputs)
    Some(sample)
  }
}
Expand All @@ -180,11 +221,16 @@ abstract class ReservoirSamplingAggregator[T, +C](k: Int)(implicit val randomSup

// Adds all elements of a generic traversable to the given reservoir, using a random source
// obtained from `randomSupplier`.
override def appendAll(r: Reservoir[T], xs: TraversableOnce[T]): Reservoir[T] = {
  val rng = randomSupplier()
  r.append(xs, rng)
}
// IndexedSeq variant: adds the elements through the reservoir's indexed `append`, passing
// a prior of 1.0.
def appendAll(r: Reservoir[T], xs: IndexedSeq[T]): Reservoir[T] = {
  val rng = randomSupplier()
  r.append(xs, rng, 1.0)
}

// Builds a reservoir from scratch: start from the monoid's zero for sample size k and
// append every element of the traversable.
override def appendAll(xs: TraversableOnce[T]): Reservoir[T] =
  appendAll(monoid.zero(k), xs)
// IndexedSeq variant: start from the monoid's zero for sample size k and append through
// the indexed overload, which is chosen because xs is statically an IndexedSeq.
def appendAll(xs: IndexedSeq[T]): Reservoir[T] =
  appendAll(monoid.zero(k), xs)

// Aggregates a generic traversable into a fresh reservoir obtained from the monoid's zero.
private def agg(xs: TraversableOnce[T]): Reservoir[T] = {
  val empty = monoid.zero(k)
  appendAll(empty, xs)
}
// Aggregates an indexed sequence into a fresh reservoir; the static IndexedSeq type of xs
// selects the indexed `appendAll` overload.
private def agg(xs: IndexedSeq[T]): Reservoir[T] = {
  val empty = monoid.zero(k)
  appendAll(empty, xs)
}
}

class ReservoirSamplingToListAggregator[T](k: Int)(implicit randomSupplier: () => Random)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package com.twitter.algebird.mutable

import com.twitter.algebird.{Aggregator, CheckProperties, Preparer}
import com.twitter.algebird.RandomSamplingLaws._
import com.twitter.algebird.scalacheck.Distribution.{forAllSampled, uniform}
import com.twitter.algebird.{Aggregator, CheckProperties, Preparer}
import org.scalacheck.Gen

import scala.util.Random

Expand All @@ -23,4 +25,11 @@ class ReservoirSamplingTest extends CheckProperties {
// Runs the `randomSamplingDistributions` check (presumably brought in by the
// RandomSamplingLaws import -- confirm) against the priority-queue-based sampler.
property("reservoir sampling with priority queue works") {
randomSamplingDistributions(prioQueueSampler)
}

// Exercises the generic (non-indexed) sampling path: the range is widened to a plain Seq so
// that overload resolution does not pick an IndexedSeq-specific `apply`. The first sampled
// element is checked against a uniform distribution over n values, for sample sizes k
// drawn from 1 to 20.
property("sampling from non-indexed Seq") {
  val n = 100
  // A type ascription performs the upcast at compile time; no runtime cast is needed.
  val nonIndexed: Seq[Int] = 0 until n
  "sampleList" |: forAllSampled(10000, Gen.choose(1, 20))(_ => uniform(n)) { k =>
    new ReservoirSamplingToListAggregator[Int](k).apply(nonIndexed).head
  }
}
}

0 comments on commit 7d1cfdf

Please sign in to comment.