diff --git a/benchmark/src/main/java/com/tdunning/Benchmark.java b/benchmark/src/main/java/com/tdunning/Benchmark.java index ed16de6..8263733 100755 --- a/benchmark/src/main/java/com/tdunning/Benchmark.java +++ b/benchmark/src/main/java/com/tdunning/Benchmark.java @@ -43,7 +43,7 @@ public class Benchmark { private Random gen = new Random(); private double[] data; - @Param({"merge", "tree"}) + @Param({"merge", "tree", "seededTree"}) public String method; @Param({"20", "50", "100", "200", "500"}) @@ -59,8 +59,12 @@ public void setup() { } if (method.equals("tree")) { td = new AVLTreeDigest(compression); - } else { + } else if (method.equals("merge")){ td = new MergingDigest(500); + } else if (method.equals("seededTree")) { + td = new AVLTreeDigest(compression, gen); + } else { + throw new IllegalArgumentException("Method " + method + " is not supported"); } // First values are very cheap to add, we are more interested in the steady state, diff --git a/benchmark/src/main/java/com/tdunning/TDigestBench.java b/benchmark/src/main/java/com/tdunning/TDigestBench.java index 68886e2..1e8eadd 100644 --- a/benchmark/src/main/java/com/tdunning/TDigestBench.java +++ b/benchmark/src/main/java/com/tdunning/TDigestBench.java @@ -45,6 +45,17 @@ TDigest create(double compression) { return new AVLTreeDigest(compression); } + @Override + TDigest create() { + return create(20); + } + }, + SEEDED_AVL_TREE { + @Override + TDigest create(double compression) { + return new AVLTreeDigest(compression, new Random()); + } + @Override TDigest create() { return create(20); @@ -106,7 +117,7 @@ AbstractDistribution create(Random random) { @Param({"100", "300"}) double compression; - @Param({"MERGE", "AVL_TREE"}) + @Param({"MERGE", "AVL_TREE", "SEEDED_AVL_TREE"}) TDigestFactory tdigestFactory; @Param({"NORMAL", "GAMMA"}) diff --git a/core/src/main/java/com/tdunning/math/stats/AVLTreeDigest.java b/core/src/main/java/com/tdunning/math/stats/AVLTreeDigest.java index c874d8b..0d30528 100644 --- 
a/core/src/main/java/com/tdunning/math/stats/AVLTreeDigest.java +++ b/core/src/main/java/com/tdunning/math/stats/AVLTreeDigest.java @@ -17,6 +17,11 @@ package com.tdunning.math.stats; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; import java.nio.ByteBuffer; import java.util.Collection; import java.util.Collections; @@ -27,12 +32,19 @@ import static com.tdunning.math.stats.IntAVLTree.NIL; public class AVLTreeDigest extends AbstractTDigest { - final Random gen = new Random(); + private final Random rng; private final double compression; private AVLGroupTree summary; private long count = 0; // package private for testing + /** + * If {@link #rng} should be persisted + */ + private boolean persistRandomObject; + + private final static int NUM_BYTES_FOR_RANDOM_OBJECT = 104; + /** * A histogram structure that will record a sketch of a distribution. * @@ -45,6 +57,16 @@ public class AVLTreeDigest extends AbstractTDigest { public AVLTreeDigest(double compression) { this.compression = compression; summary = new AVLGroupTree(false); + rng = new Random(); + persistRandomObject = false; + } + + @SuppressWarnings("WeakerAccess") + public AVLTreeDigest(double compression, Random random) { + this.compression = compression; + summary = new AVLGroupTree(false); + rng = random; + persistRandomObject = true; } @Override @@ -128,7 +150,7 @@ public void add(double x, int w, List data) { // what it does is sample uniformly from all clusters that have room if (summary.count(neighbor) + w <= k) { n++; - if (gen.nextDouble() < 1 / n) { + if (rng.nextDouble() < 1 / n) { closest = neighbor; } } @@ -500,7 +522,7 @@ public double compression() { @Override public int byteSize() { compress(); - return 32 + summary.size() * 12; + return 36 + NUM_BYTES_FOR_RANDOM_OBJECT + summary.size() * 12; } /** @@ -527,7 +549,10 @@ public void asBytes(ByteBuffer buf) { 
buf.putDouble(min); buf.putDouble(max); buf.putDouble((float) compression()); + buf.putInt(persistRandomObject ? 1 : 0); + buf.put(serializeRandomObj(rng)); buf.putInt(summary.size()); + for (Centroid centroid : summary) { buf.putDouble(centroid.mean()); } @@ -543,6 +568,8 @@ public void asSmallBytes(ByteBuffer buf) { buf.putDouble(min); buf.putDouble(max); buf.putDouble(compression()); + buf.putInt(persistRandomObject ? 1 : 0); + buf.put(serializeRandomObj(rng)); buf.putInt(summary.size()); double x = 0; @@ -567,14 +594,21 @@ public void asSmallBytes(ByteBuffer buf) { @SuppressWarnings("WeakerAccess") public static AVLTreeDigest fromBytes(ByteBuffer buf) { int encoding = buf.getInt(); + double min = buf.getDouble(); + double max = buf.getDouble(); + double compression = buf.getDouble(); + boolean persistRandomObj = buf.getInt() == 0 ? false : true; + byte [] randomObjBytes = new byte[NUM_BYTES_FOR_RANDOM_OBJECT]; + buf.get(randomObjBytes); + Random rand = deserializeRandomObj(randomObjBytes); + AVLTreeDigest r = persistRandomObj ? 
+ new AVLTreeDigest(compression, rand) : + new AVLTreeDigest(compression); + r.setMinMax(min, max); + int n = buf.getInt(); + double[] means = new double[n]; + if (encoding == VERBOSE_ENCODING) { - double min = buf.getDouble(); - double max = buf.getDouble(); - double compression = buf.getDouble(); - AVLTreeDigest r = new AVLTreeDigest(compression); - r.setMinMax(min, max); - int n = buf.getInt(); - double[] means = new double[n]; for (int i = 0; i < n; i++) { means[i] = buf.getDouble(); } @@ -583,13 +617,6 @@ public static AVLTreeDigest fromBytes(ByteBuffer buf) { } return r; } else if (encoding == SMALL_ENCODING) { - double min = buf.getDouble(); - double max = buf.getDouble(); - double compression = buf.getDouble(); - AVLTreeDigest r = new AVLTreeDigest(compression); - r.setMinMax(min, max); - int n = buf.getInt(); - double[] means = new double[n]; double x = 0; for (int i = 0; i < n; i++) { double delta = buf.getFloat(); @@ -607,4 +634,46 @@ public static AVLTreeDigest fromBytes(ByteBuffer buf) { } } + + private byte[] serializeRandomObj(Random r) { + ByteArrayOutputStream bos = new ByteArrayOutputStream(); + try { + ObjectOutputStream oos = new ObjectOutputStream(bos); + oos.writeObject(r); + oos.flush(); + byte [] data = bos.toByteArray(); + bos.close(); + oos.close(); + return data; + } catch (IOException e) { + e.printStackTrace(); + throw new RuntimeException("Cannot serialize random object"); + } + } + + private static Random deserializeRandomObj(byte [] bytes) { + ByteArrayInputStream bais = new ByteArrayInputStream(bytes); + try { + ObjectInputStream ois = new ObjectInputStream(bais); + Random r = (Random)ois.readObject(); + return r; + } catch (IOException e) { + e.printStackTrace(); + throw new RuntimeException("Cannot deserialize random object"); + } catch (ClassNotFoundException e) { + e.printStackTrace(); + throw new RuntimeException("Unable to find Random class"); + } + } + + @Override + public boolean persistRandomValue() { + return 
persistRandomObject; + } + + @Override + public Random getRandomNumberGenerator() { + return rng; + } + } diff --git a/core/src/main/java/com/tdunning/math/stats/TDigest.java b/core/src/main/java/com/tdunning/math/stats/TDigest.java index 67bd5ae..0ee2293 100644 --- a/core/src/main/java/com/tdunning/math/stats/TDigest.java +++ b/core/src/main/java/com/tdunning/math/stats/TDigest.java @@ -21,6 +21,7 @@ import java.nio.ByteBuffer; import java.util.Collection; import java.util.List; +import java.util.Random; /** * Adaptive histogram based on something like streaming k-means crossed with Q-digest. @@ -70,6 +71,22 @@ public static TDigest createAvlTreeDigest(double compression) { return new AVLTreeDigest(compression); } + /** + * Creates an AVLTreeDigest with a specific random seed. + * + * This behaves very similarly to the standard AVLTreeDigest, but with the added ability to start with a specific seed. + * This is useful for keeping historic tree values unchanged. + * + * @param compression The compression parameter. 100 is a common value for normal uses. 1000 is extremely large. + * The number of centroids retained will be a smallish (usually less than 10) multiple of this number. + * @param random The random object to use for this TDigest + * @return the AvlTreeDigest + */ + @SuppressWarnings("WeakerAccess") + public static TDigest createAvlTreeDigestWithSeed(double compression, Random random) { + return new AVLTreeDigest(compression, random); + } + /** * Creates a TDigest of whichever type is the currently recommended type. MergingDigest is generally the best * known implementation right now. @@ -237,4 +254,24 @@ void setMinMax(double min, double max) { this.min = min; this.max = max; } + + /** + * In certain TDigest implementations, there are cases where a random object might be + * serialized. 
This flag indicates if the TDigest is persisting the random object + * + * @return true if the TDigest has a random object that will be serialized + */ + public boolean persistRandomValue() { + return false; + } + + /** + * In certain TDigest implementations, there are cases where a random object might play a significant + * role. This method returns the Random instance being used to generate these numbers + * + * @return the random instance in the TDigest if one exists, null otherwise. + */ + public Random getRandomNumberGenerator() { + return null; + } } diff --git a/core/src/test/java/com/tdunning/math/stats/AlternativeMergeTest.java b/core/src/test/java/com/tdunning/math/stats/AlternativeMergeTest.java index f4ef7ae..8077a07 100644 --- a/core/src/test/java/com/tdunning/math/stats/AlternativeMergeTest.java +++ b/core/src/test/java/com/tdunning/math/stats/AlternativeMergeTest.java @@ -46,6 +46,7 @@ public void testMerges() throws FileNotFoundException { for (double compression : new double[]{50, 100, 200, 400}) { MergingDigest digest1 = new MergingDigest(compression); AVLTreeDigest digest2 = new AVLTreeDigest(compression); + AVLTreeDigest digest3 = new AVLTreeDigest(compression, new Random()); List data = new ArrayList<>(); Random gen = new Random(); for (int i = 0; i < n; i++) { @@ -53,6 +54,7 @@ public void testMerges() throws FileNotFoundException { data.add(x); digest1.add(x); digest2.add(x); + digest3.add(x); } Collections.sort(data); List counts = new ArrayList<>(); @@ -73,6 +75,7 @@ public void testMerges() throws FileNotFoundException { } sizes.printf("%s, %d, %d, %.0f, %d\n", "merge", counts.size(), digest1.centroids().size(), compression, n); sizes.printf("%s, %d, %d, %.0f, %d\n", "tree", counts.size(), digest2.centroids().size(), compression, n); + sizes.printf("%s, %d, %d, %.0f, %d\n", "tree with seed", counts.size(), digest3.centroids().size(), compression, n); sizes.printf("%s, %d, %d, %.0f, %d\n", "ideal", counts.size(), counts.size(), compression, 
n); soFar = 0; for (Double count : counts) { @@ -92,6 +95,12 @@ public void testMerges() throws FileNotFoundException { soFar += c.count(); } assertEquals(n, soFar, 0); + soFar = 0; + for (Centroid c : digest3.centroids()) { + out.printf("%s, %.0f, %d, %.3f, %d\n", "tree", compression, n, (soFar + c.count() / 2) / n, c.count()); + soFar += c.count(); + } + assertEquals(n, soFar, 0); } } } diff --git a/core/src/test/java/com/tdunning/math/stats/TDigestSerializationTest.java b/core/src/test/java/com/tdunning/math/stats/TDigestSerializationTest.java index dfddbc7..ff0e982 100644 --- a/core/src/test/java/com/tdunning/math/stats/TDigestSerializationTest.java +++ b/core/src/test/java/com/tdunning/math/stats/TDigestSerializationTest.java @@ -38,14 +38,21 @@ * Serializability is important, for example, if we want to use t-digests with Spark. */ public class TDigestSerializationTest { + private final static double COMPRESSION = 100.0; + @Test public void testMergingDigest() throws IOException { - assertSerializesAndDeserializes(new MergingDigest(100)); + assertSerializesAndDeserializes(new MergingDigest(COMPRESSION)); } @Test public void testAVLTreeDigest() throws IOException { - assertSerializesAndDeserializes(new AVLTreeDigest(100)); + assertSerializesAndDeserializes(new AVLTreeDigest(COMPRESSION)); + } + + @Test + public void testAVLTreeDigestWithSeed() throws IOException { + assertSerializesAndDeserializes(new AVLTreeDigest(COMPRESSION, new Random())); } private void assertSerializesAndDeserializes(T tdigest) throws IOException { @@ -86,6 +93,13 @@ private void assertTDigestEquals(TDigest t1, TDigest t2) { assertEquals(c1.count(), c2.count()); assertEquals(c1.mean(), c2.mean(), 1e-10); } + assertEquals(t1.persistRandomValue(), t2.persistRandomValue()); + if (t1.persistRandomValue()) { + //cheeky way to check if the random objects were properly persisted + assertEquals(t1.getRandomNumberGenerator().nextDouble(), + t2.getRandomNumberGenerator().nextDouble(), 0.0); + } + 
assertEquals(t1.compression(), t2.compression(), 0.0); assertFalse(cx.hasNext()); assertNotNull(t2); } diff --git a/core/src/test/java/com/tdunning/math/stats/TDigestTest.java b/core/src/test/java/com/tdunning/math/stats/TDigestTest.java index ff13dd4..65ef111 100644 --- a/core/src/test/java/com/tdunning/math/stats/TDigestTest.java +++ b/core/src/test/java/com/tdunning/math/stats/TDigestTest.java @@ -747,6 +747,11 @@ public void testSerialization() { assertEquals(dist.centroids().size(), dist2.centroids().size()); assertEquals(dist.compression(), dist2.compression(), 1e-4); assertEquals(dist.size(), dist2.size()); + assertEquals(dist.persistRandomValue(), dist2.persistRandomValue()); + if (dist.persistRandomValue()) { + assertEquals(dist.getRandomNumberGenerator().nextDouble(), + dist2.getRandomNumberGenerator().nextDouble(), 0.0); + } for (double q = 0; q < 1; q += 0.01) { assertEquals(dist.quantile(q), dist2.quantile(q), 1e-5); diff --git a/quality/src/test/java/com/tdunning/tdigest/quality/Util.java b/quality/src/test/java/com/tdunning/tdigest/quality/Util.java index 914b93f..80f412a 100644 --- a/quality/src/test/java/com/tdunning/tdigest/quality/Util.java +++ b/quality/src/test/java/com/tdunning/tdigest/quality/Util.java @@ -60,6 +60,15 @@ TDigest create(double compression) { TDigest create() { return create(20); } + }, + + SEEDED_AVL_TREE { + TDigest create(double compression) { + return new AVLTreeDigest(compression, new Random()); + } + TDigest create() { + return create(20); + } }; abstract TDigest create(double compression);