Skip to content

Commit

Permalink
Adding AVLTreeDigest option with a specific Random obj
Browse files Browse the repository at this point in the history
The random element in TDigest can cause some unpredictability in certain use cases.
This commit adds a second constructor to `AVLTreeDigest`, which allows a specific random obj to be used.
If this constructor is used, then the Random object will be persisted, such that the random number generation
is consistent.

Tests have been added to verify that this option does not change the behaviour of the standard `AVLTreeDigest` constructor
  • Loading branch information
cedric-hansen committed Dec 22, 2021
1 parent 15a2de9 commit 8dd0036
Show file tree
Hide file tree
Showing 8 changed files with 180 additions and 22 deletions.
8 changes: 6 additions & 2 deletions benchmark/src/main/java/com/tdunning/Benchmark.java
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ public class Benchmark {
private Random gen = new Random();
private double[] data;

@Param({"merge", "tree"})
@Param({"merge", "tree", "seededTree"})
public String method;

@Param({"20", "50", "100", "200", "500"})
Expand All @@ -59,8 +59,12 @@ public void setup() {
}
if (method.equals("tree")) {
td = new AVLTreeDigest(compression);
} else {
} else if (method.equals("merge")){
td = new MergingDigest(500);
} else if (method.equals("seededTree")) {
td = new AVLTreeDigest(compression, gen);
} else {
throw new IllegalArgumentException("Method " + method + " is not supported");
}

// First values are very cheap to add, we are more interested in the steady state,
Expand Down
13 changes: 12 additions & 1 deletion benchmark/src/main/java/com/tdunning/TDigestBench.java
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,17 @@ TDigest create(double compression) {
return new AVLTreeDigest(compression);
}

@Override
TDigest create() {
return create(20);
}
},
SEEDED_AVL_TREE {
@Override
TDigest create(double compression) {
return new AVLTreeDigest(compression, new Random());
}

@Override
TDigest create() {
return create(20);
Expand Down Expand Up @@ -106,7 +117,7 @@ AbstractDistribution create(Random random) {
@Param({"100", "300"})
double compression;

@Param({"MERGE", "AVL_TREE"})
@Param({"MERGE", "AVL_TREE", "SEEDED_AVL_TREE"})
TDigestFactory tdigestFactory;

@Param({"NORMAL", "GAMMA"})
Expand Down
103 changes: 86 additions & 17 deletions core/src/main/java/com/tdunning/math/stats/AVLTreeDigest.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@

package com.tdunning.math.stats;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.nio.ByteBuffer;
import java.util.Collection;
import java.util.Collections;
Expand All @@ -27,12 +32,19 @@
import static com.tdunning.math.stats.IntAVLTree.NIL;

public class AVLTreeDigest extends AbstractTDigest {
final Random gen = new Random();
private final Random rng;
private final double compression;
private AVLGroupTree summary;

private long count = 0; // package private for testing

/**
* If {@link rng} should be persisted
*/
private boolean persistRandomObject;

private final static int NUM_BYTES_FOR_RANDOM_OBJECT = 104;

/**
* A histogram structure that will record a sketch of a distribution.
*
Expand All @@ -45,6 +57,16 @@ public class AVLTreeDigest extends AbstractTDigest {
public AVLTreeDigest(double compression) {
this.compression = compression;
summary = new AVLGroupTree(false);
rng = new Random();
persistRandomObject = false;
}

@SuppressWarnings("WeakerAccess")
public AVLTreeDigest(double compression, Random random) {
this.compression = compression;
summary = new AVLGroupTree(false);
rng = random;
persistRandomObject = true;
}

@Override
Expand Down Expand Up @@ -128,7 +150,7 @@ public void add(double x, int w, List<Double> data) {
// what it does is sample uniformly from all clusters that have room
if (summary.count(neighbor) + w <= k) {
n++;
if (gen.nextDouble() < 1 / n) {
if (rng.nextDouble() < 1 / n) {
closest = neighbor;
}
}
Expand Down Expand Up @@ -500,7 +522,7 @@ public double compression() {
@Override
public int byteSize() {
compress();
return 32 + summary.size() * 12;
return 36 + NUM_BYTES_FOR_RANDOM_OBJECT + summary.size() * 12;
}

/**
Expand All @@ -527,7 +549,10 @@ public void asBytes(ByteBuffer buf) {
buf.putDouble(min);
buf.putDouble(max);
buf.putDouble((float) compression());
buf.putInt(persistRandomObject ? 1 : 0);
buf.put(serializeRandomObj(rng));
buf.putInt(summary.size());

for (Centroid centroid : summary) {
buf.putDouble(centroid.mean());
}
Expand All @@ -543,6 +568,8 @@ public void asSmallBytes(ByteBuffer buf) {
buf.putDouble(min);
buf.putDouble(max);
buf.putDouble(compression());
buf.putInt(persistRandomObject ? 1 : 0);
buf.put(serializeRandomObj(rng));
buf.putInt(summary.size());

double x = 0;
Expand All @@ -567,14 +594,21 @@ public void asSmallBytes(ByteBuffer buf) {
@SuppressWarnings("WeakerAccess")
public static AVLTreeDigest fromBytes(ByteBuffer buf) {
int encoding = buf.getInt();
double min = buf.getDouble();
double max = buf.getDouble();
double compression = buf.getDouble();
boolean persistRandomObj = buf.getInt() == 0 ? false : true;
byte [] randomObjBytes = new byte[NUM_BYTES_FOR_RANDOM_OBJECT];
buf.get(randomObjBytes);
Random rand = deserializeRandomObj(randomObjBytes);
AVLTreeDigest r = persistRandomObj ?
new AVLTreeDigest(compression, rand) :
new AVLTreeDigest(compression);
r.setMinMax(min, max);
int n = buf.getInt();
double[] means = new double[n];

if (encoding == VERBOSE_ENCODING) {
double min = buf.getDouble();
double max = buf.getDouble();
double compression = buf.getDouble();
AVLTreeDigest r = new AVLTreeDigest(compression);
r.setMinMax(min, max);
int n = buf.getInt();
double[] means = new double[n];
for (int i = 0; i < n; i++) {
means[i] = buf.getDouble();
}
Expand All @@ -583,13 +617,6 @@ public static AVLTreeDigest fromBytes(ByteBuffer buf) {
}
return r;
} else if (encoding == SMALL_ENCODING) {
double min = buf.getDouble();
double max = buf.getDouble();
double compression = buf.getDouble();
AVLTreeDigest r = new AVLTreeDigest(compression);
r.setMinMax(min, max);
int n = buf.getInt();
double[] means = new double[n];
double x = 0;
for (int i = 0; i < n; i++) {
double delta = buf.getFloat();
Expand All @@ -607,4 +634,46 @@ public static AVLTreeDigest fromBytes(ByteBuffer buf) {
}
}


private byte[] serializeRandomObj(Random r) {
ByteArrayOutputStream bos = new ByteArrayOutputStream();
try {
ObjectOutputStream oos = new ObjectOutputStream(bos);
oos.writeObject(r);
oos.flush();
byte [] data = bos.toByteArray();
bos.close();
oos.close();
return data;
} catch (IOException e) {
e.printStackTrace();
throw new RuntimeException("Cannot serialize random object");
}
}

private static Random deserializeRandomObj(byte [] bytes) {
ByteArrayInputStream bais = new ByteArrayInputStream(bytes);
try {
ObjectInputStream ois = new ObjectInputStream(bais);
Random r = (Random)ois.readObject();
return r;
} catch (IOException e) {
e.printStackTrace();
throw new RuntimeException("Cannot deserialize random object");
} catch (ClassNotFoundException e) {
e.printStackTrace();
throw new RuntimeException("Unable to find Random class");
}
}

@Override
public boolean persistRandomValue() {
return persistRandomObject;
}

@Override
public Random getRandomNumberGenerator() {
return rng;
}

}
37 changes: 37 additions & 0 deletions core/src/main/java/com/tdunning/math/stats/TDigest.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import java.nio.ByteBuffer;
import java.util.Collection;
import java.util.List;
import java.util.Random;

/**
* Adaptive histogram based on something like streaming k-means crossed with Q-digest.
Expand Down Expand Up @@ -70,6 +71,22 @@ public static TDigest createAvlTreeDigest(double compression) {
return new AVLTreeDigest(compression);
}

/**
* Creates an AVLTreeDigest with a specific random seed.
*
* This behaves very similarly to the standard AVLTreeDigest, but with the added ability to start with a specific seed.
* This has uses with allowing historic tree values to remain unchanged
*
* @param compression The compression parameter. 100 is a common value for normal uses. 1000 is extremely large.
* The number of centroids retained will be a smallish (usually less than 10) multiple of this number.
* @param random The random object to user for this TDigest
* @return the AvlTreeDigest
*/
@SuppressWarnings("WeakerAccess")
public static TDigest createAvlTreeDigestWithSeed(double compression, Random random) {
return new AVLTreeDigest(compression, random);
}

/**
* Creates a TDigest of whichever type is the currently recommended type. MergingDigest is generally the best
* known implementation right now.
Expand Down Expand Up @@ -237,4 +254,24 @@ void setMinMax(double min, double max) {
this.min = min;
this.max = max;
}

/**
* In certain TDigest implementations, there are cases where a random object might be
* serialized. This flag indicates if the TDigest is persisting the random object
*
* @return true if the TDigest has a random object that will be serialized
*/
public boolean persistRandomValue() {
return false;
}

/**
* In certain TDigest implementations, there are cases where a random object might play a significant
* role. This method returns the Random instance being used to generate these numbers
*
* @return the random instance in the TDigest if one exists, null otherwise.
*/
public Random getRandomNumberGenerator() {
return null;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,15 @@ public void testMerges() throws FileNotFoundException {
for (double compression : new double[]{50, 100, 200, 400}) {
MergingDigest digest1 = new MergingDigest(compression);
AVLTreeDigest digest2 = new AVLTreeDigest(compression);
AVLTreeDigest digest3 = new AVLTreeDigest(compression, new Random());
List<Double> data = new ArrayList<>();
Random gen = new Random();
for (int i = 0; i < n; i++) {
double x = gen.nextDouble();
data.add(x);
digest1.add(x);
digest2.add(x);
digest3.add(x);
}
Collections.sort(data);
List<Double> counts = new ArrayList<>();
Expand All @@ -73,6 +75,7 @@ public void testMerges() throws FileNotFoundException {
}
sizes.printf("%s, %d, %d, %.0f, %d\n", "merge", counts.size(), digest1.centroids().size(), compression, n);
sizes.printf("%s, %d, %d, %.0f, %d\n", "tree", counts.size(), digest2.centroids().size(), compression, n);
sizes.printf("%s, %d, %d, %.0f, %d\n", "tree with seed", counts.size(), digest3.centroids().size(), compression, n);
sizes.printf("%s, %d, %d, %.0f, %d\n", "ideal", counts.size(), counts.size(), compression, n);
soFar = 0;
for (Double count : counts) {
Expand All @@ -92,6 +95,12 @@ public void testMerges() throws FileNotFoundException {
soFar += c.count();
}
assertEquals(n, soFar, 0);
soFar = 0;
for (Centroid c : digest3.centroids()) {
out.printf("%s, %.0f, %d, %.3f, %d\n", "tree", compression, n, (soFar + c.count() / 2) / n, c.count());
soFar += c.count();
}
assertEquals(n, soFar, 0);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,21 @@
* Serializability is important, for example, if we want to use t-digests with Spark.
*/
public class TDigestSerializationTest {
private final static double COMPRESSION = 100.0;

@Test
public void testMergingDigest() throws IOException {
assertSerializesAndDeserializes(new MergingDigest(100));
assertSerializesAndDeserializes(new MergingDigest(COMPRESSION));
}

@Test
public void testAVLTreeDigest() throws IOException {
assertSerializesAndDeserializes(new AVLTreeDigest(100));
assertSerializesAndDeserializes(new AVLTreeDigest(COMPRESSION));
}

@Test
public void testAVLTreeDigestWithSeed() throws IOException {
assertSerializesAndDeserializes(new AVLTreeDigest(COMPRESSION, new Random()));
}

private <T extends TDigest> void assertSerializesAndDeserializes(T tdigest) throws IOException {
Expand Down Expand Up @@ -86,6 +93,13 @@ private void assertTDigestEquals(TDigest t1, TDigest t2) {
assertEquals(c1.count(), c2.count());
assertEquals(c1.mean(), c2.mean(), 1e-10);
}
assertEquals(t1.persistRandomValue(), t2.persistRandomValue());
if (t1.persistRandomValue()) {
//cheeky way to check if the random objects were properly persisted
assertEquals(t1.getRandomNumberGenerator().nextDouble(),
t2.getRandomNumberGenerator().nextDouble(), 0.0);
}
assertEquals(t1.compression(), t2.compression(), 0.0);
assertFalse(cx.hasNext());
assertNotNull(t2);
}
Expand Down
5 changes: 5 additions & 0 deletions core/src/test/java/com/tdunning/math/stats/TDigestTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -747,6 +747,11 @@ public void testSerialization() {
assertEquals(dist.centroids().size(), dist2.centroids().size());
assertEquals(dist.compression(), dist2.compression(), 1e-4);
assertEquals(dist.size(), dist2.size());
assertEquals(dist.persistRandomValue(), dist2.persistRandomValue());
if (dist.persistRandomValue()) {
assertEquals(dist.getRandomNumberGenerator().nextDouble(),
dist2.getRandomNumberGenerator().nextDouble(), 0.0);
}

for (double q = 0; q < 1; q += 0.01) {
assertEquals(dist.quantile(q), dist2.quantile(q), 1e-5);
Expand Down
9 changes: 9 additions & 0 deletions quality/src/test/java/com/tdunning/tdigest/quality/Util.java
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,15 @@ TDigest create(double compression) {
TDigest create() {
return create(20);
}
},

SEEDED_AVL_TREE {
TDigest create(double compression) {
return new AVLTreeDigest(compression, new Random());
}
TDigest create() {
return create(20);
}
};

abstract TDigest create(double compression);
Expand Down

0 comments on commit 8dd0036

Please sign in to comment.