From 8dd0036739c51fb3537ebb34c8abd7ab350da630 Mon Sep 17 00:00:00 2001 From: Cedric Hansen Date: Mon, 20 Dec 2021 18:34:44 -0500 Subject: [PATCH] Adding AVLTreeDigest option with a specific Random obj The random element in TDigest can cause some unpredictability in certain use cases. This commit adds a second constructor to `AVLTreeDigest`, which allows a specific random obj to be used. If this constructor is used, then the Random object will be persisted, such that the random number generation is consistent. Tests have been added to verify that this option does not change the behaviour of the standard `AVLTreeDigest` constructor --- .../src/main/java/com/tdunning/Benchmark.java | 8 +- .../main/java/com/tdunning/TDigestBench.java | 13 ++- .../tdunning/math/stats/AVLTreeDigest.java | 103 +++++++++++++++--- .../java/com/tdunning/math/stats/TDigest.java | 37 +++++++ .../math/stats/AlternativeMergeTest.java | 9 ++ .../math/stats/TDigestSerializationTest.java | 18 ++- .../com/tdunning/math/stats/TDigestTest.java | 5 + .../com/tdunning/tdigest/quality/Util.java | 9 ++ 8 files changed, 180 insertions(+), 22 deletions(-) diff --git a/benchmark/src/main/java/com/tdunning/Benchmark.java b/benchmark/src/main/java/com/tdunning/Benchmark.java index ed16de6e..82637336 100755 --- a/benchmark/src/main/java/com/tdunning/Benchmark.java +++ b/benchmark/src/main/java/com/tdunning/Benchmark.java @@ -43,7 +43,7 @@ public class Benchmark { private Random gen = new Random(); private double[] data; - @Param({"merge", "tree"}) + @Param({"merge", "tree", "seededTree"}) public String method; @Param({"20", "50", "100", "200", "500"}) @@ -59,8 +59,12 @@ public void setup() { } if (method.equals("tree")) { td = new AVLTreeDigest(compression); - } else { + } else if (method.equals("merge")){ td = new MergingDigest(500); + } else if (method.equals("seededTree")) { + td = new AVLTreeDigest(compression, gen); + } else { + throw new IllegalArgumentException("Method " + method + " is not supported"); } // First values are very cheap to add, we are more interested in the steady state, diff --git a/benchmark/src/main/java/com/tdunning/TDigestBench.java b/benchmark/src/main/java/com/tdunning/TDigestBench.java index 68886e2b..1e8eadde 100644 --- a/benchmark/src/main/java/com/tdunning/TDigestBench.java +++ b/benchmark/src/main/java/com/tdunning/TDigestBench.java @@ -45,6 +45,17 @@ TDigest create(double compression) { return new AVLTreeDigest(compression); } + @Override + TDigest create() { + return create(20); + } + }, + SEEDED_AVL_TREE { + @Override + TDigest create(double compression) { + return new AVLTreeDigest(compression, new Random()); + } + @Override TDigest create() { return create(20); @@ -106,7 +117,7 @@ AbstractDistribution create(Random random) { @Param({"100", "300"}) double compression; - @Param({"MERGE", "AVL_TREE"}) + @Param({"MERGE", "AVL_TREE", "SEEDED_AVL_TREE"}) TDigestFactory tdigestFactory; @Param({"NORMAL", "GAMMA"}) diff --git a/core/src/main/java/com/tdunning/math/stats/AVLTreeDigest.java b/core/src/main/java/com/tdunning/math/stats/AVLTreeDigest.java index c874d8b2..0d30528a 100644 --- a/core/src/main/java/com/tdunning/math/stats/AVLTreeDigest.java +++ b/core/src/main/java/com/tdunning/math/stats/AVLTreeDigest.java @@ -17,6 +17,11 @@ package com.tdunning.math.stats; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; import java.nio.ByteBuffer; import java.util.Collection; import java.util.Collections; @@ -27,12 +32,19 @@ import static com.tdunning.math.stats.IntAVLTree.NIL; public class AVLTreeDigest extends AbstractTDigest { - final Random gen = new Random(); + private final Random rng; private final double compression; private AVLGroupTree summary; private long count = 0; // package private for testing + /** + * If {@link rng} should be persisted + */ + private boolean persistRandomObject; + + private final static int NUM_BYTES_FOR_RANDOM_OBJECT = 104; + /** * A histogram structure that will record a sketch of a distribution. * @@ -45,6 +57,16 @@ public class AVLTreeDigest extends AbstractTDigest { public AVLTreeDigest(double compression) { this.compression = compression; summary = new AVLGroupTree(false); + rng = new Random(); + persistRandomObject = false; + } + + @SuppressWarnings("WeakerAccess") + public AVLTreeDigest(double compression, Random random) { + this.compression = compression; + summary = new AVLGroupTree(false); + rng = random; + persistRandomObject = true; } @Override @@ -128,7 +150,7 @@ public void add(double x, int w, List data) { // what it does is sample uniformly from all clusters that have room if (summary.count(neighbor) + w <= k) { n++; - if (gen.nextDouble() < 1 / n) { + if (rng.nextDouble() < 1 / n) { closest = neighbor; } } @@ -500,7 +522,7 @@ public double compression() { @Override public int byteSize() { compress(); - return 32 + summary.size() * 12; + return 36 + NUM_BYTES_FOR_RANDOM_OBJECT + summary.size() * 12; } /** @@ -527,7 +549,10 @@ public void asBytes(ByteBuffer buf) { buf.putDouble(min); buf.putDouble(max); buf.putDouble((float) compression()); + buf.putInt(persistRandomObject ? 1 : 0); + buf.put(serializeRandomObj(rng)); buf.putInt(summary.size()); + for (Centroid centroid : summary) { buf.putDouble(centroid.mean()); } @@ -543,6 +568,8 @@ public void asSmallBytes(ByteBuffer buf) { buf.putDouble(min); buf.putDouble(max); buf.putDouble(compression()); + buf.putInt(persistRandomObject ? 1 : 0); + buf.put(serializeRandomObj(rng)); buf.putInt(summary.size()); double x = 0; @@ -567,14 +594,21 @@ public void asSmallBytes(ByteBuffer buf) { @SuppressWarnings("WeakerAccess") public static AVLTreeDigest fromBytes(ByteBuffer buf) { int encoding = buf.getInt(); + double min = buf.getDouble(); + double max = buf.getDouble(); + double compression = buf.getDouble(); + boolean persistRandomObj = buf.getInt() == 0 ? false : true; + byte [] randomObjBytes = new byte[NUM_BYTES_FOR_RANDOM_OBJECT]; + buf.get(randomObjBytes); + Random rand = deserializeRandomObj(randomObjBytes); + AVLTreeDigest r = persistRandomObj ? + new AVLTreeDigest(compression, rand) : + new AVLTreeDigest(compression); + r.setMinMax(min, max); + int n = buf.getInt(); + double[] means = new double[n]; + if (encoding == VERBOSE_ENCODING) { - double min = buf.getDouble(); - double max = buf.getDouble(); - double compression = buf.getDouble(); - AVLTreeDigest r = new AVLTreeDigest(compression); - r.setMinMax(min, max); - int n = buf.getInt(); - double[] means = new double[n]; for (int i = 0; i < n; i++) { means[i] = buf.getDouble(); } @@ -583,13 +617,6 @@ public static AVLTreeDigest fromBytes(ByteBuffer buf) { } return r; } else if (encoding == SMALL_ENCODING) { - double min = buf.getDouble(); - double max = buf.getDouble(); - double compression = buf.getDouble(); - AVLTreeDigest r = new AVLTreeDigest(compression); - r.setMinMax(min, max); - int n = buf.getInt(); - double[] means = new double[n]; double x = 0; for (int i = 0; i < n; i++) { double delta = buf.getFloat(); @@ -607,4 +634,46 @@ public static AVLTreeDigest fromBytes(ByteBuffer buf) { } } + + private byte[] serializeRandomObj(Random r) { + ByteArrayOutputStream bos = new ByteArrayOutputStream(); + try { + ObjectOutputStream oos = new ObjectOutputStream(bos); + oos.writeObject(r); + oos.flush(); + byte [] data = bos.toByteArray(); + bos.close(); + oos.close(); + return data; + } catch (IOException e) { + e.printStackTrace(); + throw new RuntimeException("Cannot serialize random object"); + } + } + + private static Random deserializeRandomObj(byte [] bytes) { + ByteArrayInputStream bais = new ByteArrayInputStream(bytes); + try { + ObjectInputStream ois = new ObjectInputStream(bais); + Random r = (Random)ois.readObject(); + return r; + } catch (IOException e) { + e.printStackTrace(); + throw new RuntimeException("Cannot deserialize random object"); + } catch (ClassNotFoundException e) { + e.printStackTrace(); + throw new RuntimeException("Unable to find Random class"); + } + } + + @Override + public boolean persistRandomValue() { + return persistRandomObject; + } + + @Override + public Random getRandomNumberGenerator() { + return rng; + } + } diff --git a/core/src/main/java/com/tdunning/math/stats/TDigest.java b/core/src/main/java/com/tdunning/math/stats/TDigest.java index 67bd5aef..0ee2293d 100644 --- a/core/src/main/java/com/tdunning/math/stats/TDigest.java +++ b/core/src/main/java/com/tdunning/math/stats/TDigest.java @@ -21,6 +21,7 @@ import java.nio.ByteBuffer; import java.util.Collection; import java.util.List; +import java.util.Random; /** * Adaptive histogram based on something like streaming k-means crossed with Q-digest. @@ -70,6 +71,22 @@ public static TDigest createAvlTreeDigest(double compression) { return new AVLTreeDigest(compression); } + /** + * Creates an AVLTreeDigest with a specific random seed. + * + * This behaves very similarly to the standard AVLTreeDigest, but with the added ability to start with a specific seed. + * This has uses with allowing historic tree values to remain unchanged + * + * @param compression The compression parameter. 100 is a common value for normal uses. 1000 is extremely large. + * The number of centroids retained will be a smallish (usually less than 10) multiple of this number. + * @param random The random object to user for this TDigest + * @return the AvlTreeDigest + */ + @SuppressWarnings("WeakerAccess") + public static TDigest createAvlTreeDigestWithSeed(double compression, Random random) { + return new AVLTreeDigest(compression, random); + } + /** * Creates a TDigest of whichever type is the currently recommended type. MergingDigest is generally the best * known implementation right now. @@ -237,4 +254,24 @@ void setMinMax(double min, double max) { this.min = min; this.max = max; } + + /** + * In certain TDigest implementations, there are cases where a random object might be + * serialized. This flag indicates if the TDigest is persisting the random object + * + * @return true if the TDigest has a random object that will be serialized + */ + public boolean persistRandomValue() { + return false; + } + + /** + * In certain TDigest implementations, there are cases where a random object might play a significant + * role. This method returns the Random instance being used to generate these numbers + * + * @return the random instance in the TDigest if one exists, null otherwise. + */ + public Random getRandomNumberGenerator() { + return null; + } } diff --git a/core/src/test/java/com/tdunning/math/stats/AlternativeMergeTest.java b/core/src/test/java/com/tdunning/math/stats/AlternativeMergeTest.java index f4ef7aee..8077a077 100644 --- a/core/src/test/java/com/tdunning/math/stats/AlternativeMergeTest.java +++ b/core/src/test/java/com/tdunning/math/stats/AlternativeMergeTest.java @@ -46,6 +46,7 @@ public void testMerges() throws FileNotFoundException { for (double compression : new double[]{50, 100, 200, 400}) { MergingDigest digest1 = new MergingDigest(compression); AVLTreeDigest digest2 = new AVLTreeDigest(compression); + AVLTreeDigest digest3 = new AVLTreeDigest(compression, new Random()); List data = new ArrayList<>(); Random gen = new Random(); for (int i = 0; i < n; i++) { @@ -53,6 +54,7 @@ public void testMerges() throws FileNotFoundException { data.add(x); digest1.add(x); digest2.add(x); + digest3.add(x); } Collections.sort(data); List counts = new ArrayList<>(); @@ -73,6 +75,7 @@ public void testMerges() throws FileNotFoundException { } sizes.printf("%s, %d, %d, %.0f, %d\n", "merge", counts.size(), digest1.centroids().size(), compression, n); sizes.printf("%s, %d, %d, %.0f, %d\n", "tree", counts.size(), digest2.centroids().size(), compression, n); + sizes.printf("%s, %d, %d, %.0f, %d\n", "tree with seed", counts.size(), digest3.centroids().size(), compression, n); sizes.printf("%s, %d, %d, %.0f, %d\n", "ideal", counts.size(), counts.size(), compression, n); soFar = 0; for (Double count : counts) { @@ -92,6 +95,12 @@ public void testMerges() throws FileNotFoundException { soFar += c.count(); } assertEquals(n, soFar, 0); + soFar = 0; + for (Centroid c : digest3.centroids()) { + out.printf("%s, %.0f, %d, %.3f, %d\n", "tree", compression, n, (soFar + c.count() / 2) / n, c.count()); + soFar += c.count(); + } + assertEquals(n, soFar, 0); } } } diff --git a/core/src/test/java/com/tdunning/math/stats/TDigestSerializationTest.java b/core/src/test/java/com/tdunning/math/stats/TDigestSerializationTest.java index dfddbc76..ff0e9823 100644 --- a/core/src/test/java/com/tdunning/math/stats/TDigestSerializationTest.java +++ b/core/src/test/java/com/tdunning/math/stats/TDigestSerializationTest.java @@ -38,14 +38,21 @@ * Serializability is important, for example, if we want to use t-digests with Spark. */ public class TDigestSerializationTest { + private final static double COMPRESSION = 100.0; + @Test public void testMergingDigest() throws IOException { - assertSerializesAndDeserializes(new MergingDigest(100)); + assertSerializesAndDeserializes(new MergingDigest(COMPRESSION)); } @Test public void testAVLTreeDigest() throws IOException { - assertSerializesAndDeserializes(new AVLTreeDigest(100)); + assertSerializesAndDeserializes(new AVLTreeDigest(COMPRESSION)); + } + + @Test + public void testAVLTreeDigestWithSeed() throws IOException { + assertSerializesAndDeserializes(new AVLTreeDigest(COMPRESSION, new Random())); } private void assertSerializesAndDeserializes(T tdigest) throws IOException { @@ -86,6 +93,13 @@ private void assertTDigestEquals(TDigest t1, TDigest t2) { assertEquals(c1.count(), c2.count()); assertEquals(c1.mean(), c2.mean(), 1e-10); } + assertEquals(t1.persistRandomValue(), t2.persistRandomValue()); + if (t1.persistRandomValue()) { + //cheeky way to check if the random objects were properly persisted + assertEquals(t1.getRandomNumberGenerator().nextDouble(), + t2.getRandomNumberGenerator().nextDouble(), 0.0); + } + assertEquals(t1.compression(), t2.compression(), 0.0); assertFalse(cx.hasNext()); assertNotNull(t2); } diff --git a/core/src/test/java/com/tdunning/math/stats/TDigestTest.java b/core/src/test/java/com/tdunning/math/stats/TDigestTest.java index ff13dd41..65ef111b 100644 --- a/core/src/test/java/com/tdunning/math/stats/TDigestTest.java +++ b/core/src/test/java/com/tdunning/math/stats/TDigestTest.java @@ -747,6 +747,11 @@ public void testSerialization() { assertEquals(dist.centroids().size(), dist2.centroids().size()); assertEquals(dist.compression(), dist2.compression(), 1e-4); assertEquals(dist.size(), dist2.size()); + assertEquals(dist.persistRandomValue(), dist2.persistRandomValue()); + if (dist.persistRandomValue()) { + assertEquals(dist.getRandomNumberGenerator().nextDouble(), + dist2.getRandomNumberGenerator().nextDouble(), 0.0); + } for (double q = 0; q < 1; q += 0.01) { assertEquals(dist.quantile(q), dist2.quantile(q), 1e-5); diff --git a/quality/src/test/java/com/tdunning/tdigest/quality/Util.java b/quality/src/test/java/com/tdunning/tdigest/quality/Util.java index 914b93f6..80f412aa 100644 --- a/quality/src/test/java/com/tdunning/tdigest/quality/Util.java +++ b/quality/src/test/java/com/tdunning/tdigest/quality/Util.java @@ -60,6 +60,15 @@ TDigest create(double compression) { TDigest create() { return create(20); } + }, + + SEEDED_AVL_TREE { + TDigest create(double compression) { + return new AVLTreeDigest(compression, new Random()); + } + TDigest create() { + return create(20); + } }; abstract TDigest create(double compression);