Exception during histogram add operation #220

Open
sumanshil opened this issue Sep 21, 2024 · 8 comments
Comments

@sumanshil

We are adding histograms in a Flink job that aggregates histograms from a Kafka queue. During the aggregation we see an exception thrown from this function, and the Flink job becomes unstable when it encounters too many such exceptions. I am thinking of checking the centroid.mean() values before calling the merge operation:

  public boolean isValid() {
    return digest.centroids().stream().noneMatch((c) -> Double.isNaN(c.mean()));
  }

I wanted to check whether this is the right logic, or whether it would result in inaccurate histogram calculations.
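
For illustration, here is a minimal sketch of how such a guard could sit in front of the merge; isValid is the check above, while the surrounding class and method names are assumptions, not part of the t-digest API:

  import com.tdunning.math.stats.TDigest;

  // Sketch only: wraps the merge with the NaN check proposed above.
  final class DigestMergeGuard {
    // True when no centroid mean is NaN.
    static boolean isValid(TDigest digest) {
      return digest.centroids().stream().noneMatch(c -> Double.isNaN(c.mean()));
    }

    // Merge `incoming` into `target`, dropping invalid digests instead of failing the job.
    static void safeMerge(TDigest target, TDigest incoming) {
      if (isValid(incoming)) {
        target.add(incoming);
      }
      // else: count or log the dropped digest so the data loss stays visible
    }
  }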

@tdunning
Owner

The line that you point to in your comment refers to this code:

    public double cdf(double x) {
        if (Double.isNaN(x) || Double.isInfinite(x)) {
            throw new IllegalArgumentException(String.format("Invalid value: %f", x));
        }

As you can see, the test checks whether the argument you are passing in is a valid number. Thus, it appears that the number you are giving to the code is invalid, and the exception is a correct response.

But I also think that you are worried about the validity of the t-digest itself. That makes me wonder whether the line you pointed to is not actually the line where the exception is raised. Or maybe I am misunderstanding entirely.

Can you provide a stack trace of the exception you are seeing? Do you have some code to replicate the error that you can post here?

@sumanshil
Author

sumanshil commented Sep 21, 2024

Sorry, that was my mistake. The exception is actually thrown from this line:

    private void add(double x, int w, List<Double> history) {
        if (Double.isNaN(x)) {
            throw new IllegalArgumentException("Cannot add NaN to t-digest");
        }

The above function is called from the following method in AbstractTDigest.java:

    @Override
    public void add(TDigest other) {
        for (Centroid centroid : other.centroids()) {
            add(centroid.mean(), centroid.count(), centroid);
        }
    }

I am wondering whether values are being added from an invalid t-digest. The exception is thrown when a Double.NaN value is encountered. Here is the full stack trace:

java.lang.IllegalArgumentException: Cannot add NaN to t-digest
	at com.tdunning.math.stats.MergingDigest.add(MergingDigest.java:199) ~[flink-job.jar:?]
	at com.tdunning.math.stats.MergingDigest.add(MergingDigest.java:189) ~[flink-job.jar:?]
	at com.tdunning.math.stats.AbstractTDigest.add(AbstractTDigest.java:143) ~[flink-job.jar:?]
	at visibility.mabs.src.main.java.com.pinterest.utils.MabsBaseMetric.plus(MabsBaseMetric.java:350) ~[flink-job.jar:?]
	at visibility.mabs.src.main.java.com.pinterest.mabs.MabsFlinkJob$1.reduce(MabsFlinkJob.java:92) ~[flink-job.jar:?]
	at visibility.mabs.src.main.java.com.pinterest.mabs.MabsFlinkJob$1.reduce(MabsFlinkJob.java:88) ~[flink-job.jar:?]
	at 

We are currently not sure how a t-digest instance ends up in a state where a centroid.mean() value is Double.NaN. We are trying to avoid this exception in the Flink job.

@tdunning
Owner

tdunning commented Sep 21, 2024 via email

@sumanshil
Author

We are using version 3.2

@sumanshil
Author

We are looking for a way to handle invalid t-digests without throwing an exception, since we process a large number of t-digests in the Flink job; the exceptions cause high CPU usage and result in instability. We could add a check like the one I mentioned before adding each t-digest, but I am wondering if we can avoid that too, as checking every t-digest instance is expensive. Thanks for your help on this issue.
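
If a separate validation pass over every incoming digest is too expensive, one possible alternative (a sketch only, assuming it is acceptable to drop the bad centroids and their weight) is to merge centroid by centroid and skip NaN means, so only the invalid data is discarded:

  import com.tdunning.math.stats.Centroid;
  import com.tdunning.math.stats.TDigest;

  // Sketch: lenient merge that ignores NaN centroids instead of throwing.
  final class LenientMerge {
    static void addIgnoringNaN(TDigest target, TDigest other) {
      for (Centroid c : other.centroids()) {
        double m = c.mean();
        if (!Double.isNaN(m)) {
          target.add(m, c.count());
        }
        // else: the corrupt centroid is skipped; consider counting these occurrences
      }
    }
  }

This still costs one pass over the centroids, the same as the validity check, but it does the merge in that same pass and keeps the rest of the digest.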

@tdunning
Owner

tdunning commented Sep 22, 2024 via email

@sumanshil
Author

sumanshil commented Sep 22, 2024

We run a metric collector on each host. The collector, which is part of the application, aggregates t-digests on the host over one-minute intervals, and the aggregated t-digests are enqueued in Kafka. The Flink job consumes from the Kafka queue and further aggregates the t-digests per cluster to reduce cardinality. At the collector level, we have added a check to prevent Double.NaN from being added to a t-digest:

  public void add(double value) {
    if (Double.isNaN(value)) {
      return;
    }
    this.digest.add(value);
  }

But this check still does not prevent the exception in the Flink job. It looks like the aggregation at the host-level collector produces a Double.NaN centroid.mean() value. We can try upgrading to 3.3, but I am wondering whether you have encountered this issue before and whether it was fixed in later versions.
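
If the NaN really appears during the host-level aggregation, another option is to validate the one-minute aggregate once, just before it is serialized and enqueued to Kafka, so corrupt digests never reach the Flink job. A sketch under that assumption (the class and method names are illustrative):

  import java.nio.ByteBuffer;
  import com.tdunning.math.stats.TDigest;

  // Sketch: publish the aggregated digest only if every centroid mean is a real number.
  final class DigestPublisher {
    // Returns the serialized digest, or null if any centroid mean is NaN.
    static byte[] serializeIfValid(TDigest digest) {
      boolean hasNaN = digest.centroids().stream().anyMatch(c -> Double.isNaN(c.mean()));
      if (hasNaN) {
        return null; // caller drops or rebuilds the digest; logging here also shows where NaN first appears
      }
      ByteBuffer buf = ByteBuffer.allocate(digest.byteSize());
      digest.asBytes(buf);
      return buf.array();
    }
  }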

@sumanshil
Author

By the way, we are using ResettableMergingDigest.java. This is a copy of MergingDigest.java with an added reset() function; the class was created so that t-digest instances can be reused. The full class follows:


import java.nio.ByteBuffer;
import java.util.AbstractCollection;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;

public class ResettableMergingDigest extends AbstractTDigest {
  private double compression;

  // points to the first unused centroid
  private int lastUsedCell;

  // sum_i weight[i]  See also unmergedWeight
  private double totalWeight = 0;

  // number of points that have been added to each merged centroid
  private double[] weight;
  // mean of points added to each merged centroid
  private double[] mean;

  // history of all data added to centroids (for testing purposes)
  private List<List<Double>> data = null;

  // sum_i tempWeight[i]
  private double unmergedWeight = 0;

  // this is the index of the next temporary centroid
  // this is a more Java-like convention than lastUsedCell uses
  private int tempUsed = 0;
  private double[] tempWeight;
  private double[] tempMean;
  private List<List<Double>> tempData = null;

  // array used for sorting the temp centroids.  This is a field
  // to avoid allocations during operation
  private int[] order;
  private static boolean usePieceWiseApproximation = true;
  private static boolean useWeightLimit = true;

  private double compressionInput = 0;
  private int bufferSizeInput = 0;
  private int sizeInput = 0;

  /**
   * Allocates a buffer merging t-digest. This is the normally used constructor that allocates
   * default sized internal arrays. Other versions are available, but should only be used for
   * special cases.
   *
   * @param compression The compression factor
   */
  @SuppressWarnings("WeakerAccess")
  public ResettableMergingDigest(double compression) {
    this(compression, -1);
  }

  /**
   * If you know the size of the temporary buffer for incoming points, you can use this entry point.
   *
   * @param compression Compression factor for t-digest. Same as 1/\delta in the paper.
   * @param bufferSize How many samples to retain before merging.
   */
  @SuppressWarnings("WeakerAccess")
  public ResettableMergingDigest(double compression, int bufferSize) {
    // we can guarantee that we only need 2 * ceiling(compression).
    this(compression, bufferSize, -1);
  }

  /**
   * Fully specified constructor. Normally only used for deserializing a buffer t-digest.
   *
   * @param compression Compression factor
   * @param bufferSize Number of temporary centroids
   * @param size Size of main buffer
   */
  @SuppressWarnings("WeakerAccess")
  public ResettableMergingDigest(double compression, int bufferSize, int size) {
    this.compressionInput = compression;
    this.bufferSizeInput = bufferSize;
    this.sizeInput = size;
    initialize();
  }

  public void initialize() {
    int size = this.sizeInput;
    this.compression = this.compressionInput;
    if (size == -1) {
      size = (int) (2 * Math.ceil(this.compression));
      if (useWeightLimit) {
        // the weight limit approach generates smaller centroids than necessary
        // that can result in using a bit more memory than expected
        size += 10;
      }
    }
    int bufferSize = this.bufferSizeInput;
    if (bufferSize == -1) {
      // having a big buffer is good for speed
      // experiments show bufferSize = 1 gives half the performance of bufferSize=10
      // bufferSize = 2 gives 40% worse performance than 10
      // but bufferSize = 5 only costs about 5-10%
      //
      //   compression factor     time(us)
      //    50          1         0.275799
      //    50          2         0.151368
      //    50          5         0.108856
      //    50         10         0.102530
      //   100          1         0.215121
      //   100          2         0.142743
      //   100          5         0.112278
      //   100         10         0.107753
      //   200          1         0.210972
      //   200          2         0.148613
      //   200          5         0.118220
      //   200         10         0.112970
      //   500          1         0.219469
      //   500          2         0.158364
      //   500          5         0.127552
      //   500         10         0.121505
      bufferSize = (int) (5 * Math.ceil(this.compression));
    }

    weight = new double[size];
    mean = new double[size];

    tempWeight = new double[bufferSize];
    tempMean = new double[bufferSize];
    order = new int[bufferSize];

    lastUsedCell = 0;
  }

  public void reset() {
    min = Double.POSITIVE_INFINITY;
    max = Double.NEGATIVE_INFINITY;
    totalWeight = 0;
    data = null;
    unmergedWeight = 0;
    tempUsed = 0;
    tempData = null;
    useWeightLimit = true;
    usePieceWiseApproximation = true;
    initialize();
  }

  /** Turns on internal data recording. */
  @Override
  public TDigest recordAllData() {
    super.recordAllData();
    data = new ArrayList<>();
    tempData = new ArrayList<>();
    return this;
  }

  @Override
  void add(double x, int w, Centroid base) {
    add(x, w, base.data());
  }

  @Override
  public void add(double x, int w) {
    add(x, w, (List<Double>) null);
  }

  private void add(double x, int w, List<Double> history) {
    if (Double.isNaN(x)) {
      throw new IllegalArgumentException("Cannot add NaN to t-digest");
    }
    if (tempUsed >= tempWeight.length - lastUsedCell - 1) {
      mergeNewValues();
    }
    int where = tempUsed++;
    tempWeight[where] = w;
    tempMean[where] = x;
    unmergedWeight += w;

    if (data != null) {
      if (tempData == null) {
        tempData = new ArrayList<>();
      }
      while (tempData.size() <= where) {
        tempData.add(new ArrayList<>());
      }
      if (history == null) {
        history = Collections.singletonList(x);
      }
      tempData.get(where).addAll(history);
    }
  }

  private void add(double[] m, double[] w, int count, List<List<Double>> data) {
    if (m.length != w.length) {
      throw new IllegalArgumentException("Arrays not same length");
    }
    if (m.length < count + lastUsedCell) {
      // make room to add existing centroids
      double[] m1 = new double[count + lastUsedCell];
      System.arraycopy(m, 0, m1, 0, count);
      m = m1;
      double[] w1 = new double[count + lastUsedCell];
      System.arraycopy(w, 0, w1, 0, count);
      w = w1;
    }
    double total = 0;
    for (int i = 0; i < count; i++) {
      total += w[i];
    }
    merge(m, w, count, data, null, total);
  }

  @Override
  public void add(List<? extends TDigest> others) {
    if (others.size() == 0) {
      return;
    }
    int size = lastUsedCell;
    for (TDigest other : others) {
      other.compress();
      size += other.centroidCount();
    }

    double[] m = new double[size];
    double[] w = new double[size];
    List<List<Double>> data;
    if (recordAllData) {
      data = new ArrayList<>();
    } else {
      data = null;
    }
    int offset = 0;
    for (TDigest other : others) {
      if (other instanceof ResettableMergingDigest) {
        ResettableMergingDigest md = (ResettableMergingDigest) other;
        System.arraycopy(md.mean, 0, m, offset, md.lastUsedCell);
        System.arraycopy(md.weight, 0, w, offset, md.lastUsedCell);
        if (data != null) {
          for (Centroid centroid : other.centroids()) {
            data.add(centroid.data());
          }
        }
        offset += md.lastUsedCell;
      } else {
        for (Centroid centroid : other.centroids()) {
          m[offset] = centroid.mean();
          w[offset] = centroid.count();
          if (recordAllData) {
            assert data != null;
            data.add(centroid.data());
          }
          offset++;
        }
      }
    }
    add(m, w, size, data);
  }

  private void mergeNewValues() {
    if (unmergedWeight > 0) {
      merge(tempMean, tempWeight, tempUsed, tempData, order, unmergedWeight);
      tempUsed = 0;
      unmergedWeight = 0;
      if (data != null) {
        tempData = new ArrayList<>();
      }
    }
  }

  private void merge(
      double[] incomingMean,
      double[] incomingWeight,
      int incomingCount,
      List<List<Double>> incomingData,
      int[] incomingOrder,
      double unmergedWeight) {
    System.arraycopy(mean, 0, incomingMean, incomingCount, lastUsedCell);
    System.arraycopy(weight, 0, incomingWeight, incomingCount, lastUsedCell);
    incomingCount += lastUsedCell;

    if (incomingData != null) {
      for (int i = 0; i < lastUsedCell; i++) {
        assert data != null;
        incomingData.add(data.get(i));
      }
      data = new ArrayList<>();
    }
    if (incomingOrder == null) {
      incomingOrder = new int[incomingCount];
    }
    Sort.sort(incomingOrder, incomingMean, incomingCount);

    totalWeight += unmergedWeight;
    double normalizer = compression / (Math.PI * totalWeight);

    assert incomingCount > 0;
    lastUsedCell = 0;
    mean[lastUsedCell] = incomingMean[incomingOrder[0]];
    weight[lastUsedCell] = incomingWeight[incomingOrder[0]];
    double wSoFar = 0;
    if (data != null) {
      assert incomingData != null;
      data.add(incomingData.get(incomingOrder[0]));
    }

    double k1 = 0;

    // weight will contain all zeros
    double wLimit;
    wLimit = totalWeight * integratedQ(k1 + 1);
    for (int i = 1; i < incomingCount; i++) {
      int ix = incomingOrder[i];
      double proposedWeight = weight[lastUsedCell] + incomingWeight[ix];
      double projectedW = wSoFar + proposedWeight;
      boolean addThis;
      if (useWeightLimit) {
        double z = proposedWeight * normalizer;
        double q0 = wSoFar / totalWeight;
        double q2 = (wSoFar + proposedWeight) / totalWeight;
        addThis = z * z <= q0 * (1 - q0) && z * z <= q2 * (1 - q2);
      } else {
        addThis = projectedW <= wLimit;
      }

      if (addThis) {
        // next point will fit
        // so merge into existing centroid
        weight[lastUsedCell] += incomingWeight[ix];
        mean[lastUsedCell] =
            mean[lastUsedCell]
                + (incomingMean[ix] - mean[lastUsedCell])
                    * incomingWeight[ix]
                    / weight[lastUsedCell];
        incomingWeight[ix] = 0;

        if (data != null) {
          while (data.size() <= lastUsedCell) {
            data.add(new ArrayList());
          }
          assert incomingData != null;
          assert data.get(lastUsedCell) != incomingData.get(ix);
          data.get(lastUsedCell).addAll(incomingData.get(ix));
        }
      } else {
        // didn't fit ... move to next output, copy out first centroid
        wSoFar += weight[lastUsedCell];
        if (!useWeightLimit) {
          k1 = integratedLocation(wSoFar / totalWeight);
          wLimit = totalWeight * integratedQ(k1 + 1);
        }

        lastUsedCell++;
        mean[lastUsedCell] = incomingMean[ix];
        weight[lastUsedCell] = incomingWeight[ix];
        incomingWeight[ix] = 0;

        if (data != null) {
          assert incomingData != null;
          assert data.size() == lastUsedCell;
          data.add(incomingData.get(ix));
        }
      }
    }
    // points to next empty cell
    lastUsedCell++;

    // sanity check
    double sum = 0;
    for (int i = 0; i < lastUsedCell; i++) {
      sum += weight[i];
    }
    assert sum == totalWeight;

    if (totalWeight > 0) {
      min = Math.min(min, mean[0]);
      max = Math.max(max, mean[lastUsedCell - 1]);
    }
  }

  /** Exposed for testing. */
  int checkWeights() {
    return checkWeights(weight, totalWeight, lastUsedCell);
  }

  private int checkWeights(double[] w, double total, int last) {
    int badCount = 0;

    int n = last;
    if (w[n] > 0) {
      n++;
    }

    double k1 = 0;
    double q = 0;
    double left = 0;
    String header = "\n";
    for (int i = 0; i < n; i++) {
      double dq = w[i] / total;
      double k2 = integratedLocation(q + dq);
      q += dq / 2;
      if (k2 - k1 > 1 && w[i] != 1) {
        System.out.printf(
            "%sOversize centroid at %d, k0=%.2f, k1=%.2f, dk=%.2f, w=%.2f, q=%.4f, dq=%.4f,"
                + " left=%.1f, current=%.2f maxw=%.2f\n",
            header,
            i,
            k1,
            k2,
            k2 - k1,
            w[i],
            q,
            dq,
            left,
            w[i],
            Math.PI * total / compression * Math.sqrt(q * (1 - q)));
        header = "";
        badCount++;
      }
      if (k2 - k1 > 4 && w[i] != 1) {
        throw new IllegalStateException(
            String.format(
                "Egregiously oversized centroid at %d, k0=%.2f, k1=%.2f, dk=%.2f, w=%.2f, q=%.4f,"
                    + " dq=%.4f, left=%.1f, current=%.2f, maxw=%.2f\n",
                i,
                k1,
                k2,
                k2 - k1,
                w[i],
                q,
                dq,
                left,
                w[i],
                Math.PI * total / compression * Math.sqrt(q * (1 - q))));
      }
      q += dq / 2;
      left += w[i];
      k1 = k2;
    }

    return badCount;
  }

  /**
   * Converts a quantile into a centroid scale value. The centroid scale is nominally the number k
   * of the centroid that a quantile point q should belong to. Due to round-offs, however, we can't
   * align things perfectly without splitting points and centroids. We don't want to do that, so we
   * have to allow for offsets. In the end, the criterion is that any quantile range that spans a
   * centroid scale range more than one should be split across more than one centroid if possible.
   * This won't be possible if the quantile range refers to a single point or an already existing
   * centroid.
   *
   * <p>This mapping is steep near q=0 or q=1 so each centroid there will correspond to less q
   * range. Near q=0.5, the mapping is flatter so that centroids there will represent a larger chunk
   * of quantiles.
   *
   * @param q The quantile scale value to be mapped.
   * @return The centroid scale value corresponding to q.
   */
  private double integratedLocation(double q) {
    return compression * (asinApproximation(2 * q - 1) + Math.PI / 2) / Math.PI;
  }

  private double integratedQ(double k) {
    return (Math.sin(Math.min(k, compression) * Math.PI / compression - Math.PI / 2) + 1) / 2;
  }

  static double asinApproximation(double x) {
    if (usePieceWiseApproximation) {
      if (x < 0) {
        return -asinApproximation(-x);
      } else {
        // this approximation works by breaking that range from 0 to 1 into 5 regions
        // for all but the region nearest 1, rational polynomial models get us a very
        // good approximation of asin and by interpolating as we move from region to
        // region, we can guarantee continuity and we happen to get monotonicity as well.
        // for the values near 1, we just use Math.asin as our region "approximation".

        // cutoffs for models. Note that the ranges overlap. In the overlap we do
        // linear interpolation to guarantee the overall result is "nice"
        double c0High = 0.1;
        double c1High = 0.55;
        double c2Low = 0.5;
        double c2High = 0.8;
        double c3Low = 0.75;
        double c3High = 0.9;
        double c4Low = 0.87;
        if (x > c3High) {
          return Math.asin(x);
        } else {
          // the models
          double[] m0 = {
            0.2955302411, 1.2221903614, 0.1488583743, 0.2422015816, -0.3688700895, 0.0733398445
          };
          double[] m1 = {
            -0.0430991920, 0.9594035750, -0.0362312299, 0.1204623351, 0.0457029620, -0.0026025285
          };
          double[] m2 = {
            -0.034873933724,
            1.054796752703,
            -0.194127063385,
            0.283963735636,
            0.023800124916,
            -0.000872727381
          };
          double[] m3 = {
            -0.37588391875,
            2.61991859025,
            -2.48835406886,
            1.48605387425,
            0.00857627492,
            -0.00015802871
          };

          // the parameters for all of the models
          double[] vars = {1, x, x * x, x * x * x, 1 / (1 - x), 1 / (1 - x) / (1 - x)};

          // raw grist for interpolation coefficients
          double x0 = bound((c0High - x) / c0High);
          double x1 = bound((c1High - x) / (c1High - c2Low));
          double x2 = bound((c2High - x) / (c2High - c3Low));
          double x3 = bound((c3High - x) / (c3High - c4Low));

          // interpolation coefficients
          //noinspection UnnecessaryLocalVariable
          double mix0 = x0;
          double mix1 = (1 - x0) * x1;
          double mix2 = (1 - x1) * x2;
          double mix3 = (1 - x2) * x3;
          double mix4 = 1 - x3;

          // now mix all the results together, avoiding extra evaluations
          double r = 0;
          if (mix0 > 0) {
            r += mix0 * eval(m0, vars);
          }
          if (mix1 > 0) {
            r += mix1 * eval(m1, vars);
          }
          if (mix2 > 0) {
            r += mix2 * eval(m2, vars);
          }
          if (mix3 > 0) {
            r += mix3 * eval(m3, vars);
          }
          if (mix4 > 0) {
            // model 4 is just the real deal
            r += mix4 * Math.asin(x);
          }
          return r;
        }
      }
    } else {
      return Math.asin(x);
    }
  }

  private static double eval(double[] model, double[] vars) {
    double r = 0;
    for (int i = 0; i < model.length; i++) {
      r += model[i] * vars[i];
    }
    return r;
  }

  private static double bound(double v) {
    if (v <= 0) {
      return 0;
    } else if (v >= 1) {
      return 1;
    } else {
      return v;
    }
  }

  @Override
  public void compress() {
    mergeNewValues();
  }

  @Override
  public long size() {
    return (long) (totalWeight + unmergedWeight);
  }

  @Override
  public double cdf(double x) {
    mergeNewValues();

    if (lastUsedCell == 0) {
      // no data to examine
      return Double.NaN;
    } else if (lastUsedCell == 1) {
      // exactly one centroid, should have max==min
      double width = max - min;
      if (x < min) {
        return 0;
      } else if (x > max) {
        return 1;
      } else if (x - min <= width) {
        // min and max are too close together to do any viable interpolation
        return 0.5;
      } else {
        // interpolate if somehow we have weight > 0 and max != min
        return (x - min) / (max - min);
      }
    } else {
      int n = lastUsedCell;
      if (x <= min) {
        return 0;
      }

      if (x >= max) {
        return 1;
      }

      // check for the left tail
      if (x <= mean[0]) {
        // note that this is different than mean[0] > min
        // ... this guarantees we divide by non-zero number and interpolation works
        if (mean[0] - min > 0) {
          return (x - min) / (mean[0] - min) * weight[0] / totalWeight / 2;
        } else {
          return 0;
        }
      }
      assert x > mean[0];

      // and the right tail
      if (x >= mean[n - 1]) {
        if (max - mean[n - 1] > 0) {
          return 1 - (max - x) / (max - mean[n - 1]) * weight[n - 1] / totalWeight / 2;
        } else {
          return 1;
        }
      }
      assert x < mean[n - 1];

      // we know that there are at least two centroids and x > mean[0] && x < mean[n-1]
      // that means that there are either a bunch of consecutive centroids all equal at x
      // or there are consecutive centroids, c0 <= x and c1 > x
      double weightSoFar = weight[0] / 2;
      for (int it = 0; it < n; it++) {
        if (mean[it] == x) {
          double w0 = weightSoFar;
          while (it < n && mean[it + 1] == x) {
            weightSoFar += (weight[it] + weight[it + 1]);
            it++;
          }
          return (w0 + weightSoFar) / 2 / totalWeight;
        }
        if (mean[it] <= x && mean[it + 1] > x) {
          if (mean[it + 1] - mean[it] > 0) {
            double dw = (weight[it] + weight[it + 1]) / 2;
            return (weightSoFar + dw * (x - mean[it]) / (mean[it + 1] - mean[it])) / totalWeight;
          } else {
            // this is simply caution against floating point madness
            // it is conceivable that the centroids will be different
            // but too near to allow safe interpolation
            double dw = (weight[it] + weight[it + 1]) / 2;
            return weightSoFar + dw / totalWeight;
          }
        }
        weightSoFar += (weight[it] + weight[it + 1]) / 2;
      }
      // it should not be possible for the loop fall through
      throw new IllegalStateException("Can't happen ... loop fell through");
    }
  }

  @Override
  public double quantile(double q) {
    if (q < 0 || q > 1) {
      throw new IllegalArgumentException("q should be in [0,1], got " + q);
    }
    mergeNewValues();

    if (lastUsedCell == 0 && weight[lastUsedCell] == 0) {
      // no centroids means no data, no way to get a quantile
      return Double.NaN;
    } else if (lastUsedCell == 0) {
      // with one data point, all quantiles lead to Rome
      return mean[0];
    }

    // we know that there are at least two centroids now
    int n = lastUsedCell;

    // if values were stored in a sorted array, index would be the offset we are interested in
    final double index = q * totalWeight;

    // at the boundaries, we return min or max
    if (index < weight[0] / 2) {
      assert weight[0] > 0;
      return min + 2 * index / weight[0] * (mean[0] - min);
    }

    // in between we interpolate between centroids
    double weightSoFar = weight[0] / 2;
    for (int i = 0; i < n - 1; i++) {
      double dw = (weight[i] + weight[i + 1]) / 2;
      if (weightSoFar + dw > index) {
        // centroids i and i+1 bracket our current point
        double z1 = index - weightSoFar;
        double z2 = weightSoFar + dw - index;
        return weightedAverage(mean[i], z2, mean[i + 1], z1);
      }
      weightSoFar += dw;
    }
    assert index <= totalWeight;
    assert index >= totalWeight - weight[n - 1] / 2;

    // weightSoFar = totalWeight - weight[n-1]/2 (very nearly)
    // so we interpolate out to max value ever seen
    double z1 = index - totalWeight - weight[n - 1] / 2.0;
    double z2 = weight[n - 1] / 2 - z1;
    return weightedAverage(mean[n - 1], z1, max, z2);
  }

  @Override
  public int centroidCount() {
    return lastUsedCell;
  }

  @Override
  public Collection centroids() {
    // we don't actually keep centroid structures around so we have to fake it
    compress();
    return new AbstractCollection() {
      @Override
      public Iterator iterator() {
        return new Iterator() {
          int i = 0;

          @Override
          public boolean hasNext() {
            return i < lastUsedCell;
          }

          @Override
          public Centroid next() {
            Centroid rc = new Centroid(mean[i], (int) weight[i], data != null ? data.get(i) : null);
            i++;
            return rc;
          }

          @Override
          public void remove() {
            throw new UnsupportedOperationException("Default operation");
          }
        };
      }

      @Override
      public int size() {
        return lastUsedCell;
      }
    };
  }

  @Override
  public double compression() {
    return compression;
  }

  @Override
  public int byteSize() {
    compress();
    // format code, compression(float), buffer-size(int), temp-size(int), #centroids-1(int),
    // then two doubles per centroid
    return lastUsedCell * 16 + 32;
  }

  @Override
  public int smallByteSize() {
    compress();
    // format code(int), compression(float), buffer-size(short), temp-size(short),
    // #centroids-1(short),
    // then two floats per centroid
    return lastUsedCell * 8 + 30;
  }

  public enum Encoding {
    VERBOSE_ENCODING(1),
    SMALL_ENCODING(2);

    private final int code;

    Encoding(int code) {
      this.code = code;
    }
  }

  @Override
  public void asBytes(ByteBuffer buf) {
    compress();
    buf.putInt(Encoding.VERBOSE_ENCODING.code);
    buf.putDouble(min);
    buf.putDouble(max);
    buf.putDouble(compression);
    buf.putInt(lastUsedCell);
    for (int i = 0; i < lastUsedCell; i++) {
      buf.putDouble(weight[i]);
      buf.putDouble(mean[i]);
    }
  }

  @Override
  public void asSmallBytes(ByteBuffer buf) {
    compress();
    buf.putInt(Encoding.SMALL_ENCODING.code); // 4
    buf.putDouble(min); // + 8
    buf.putDouble(max); // + 8
    buf.putFloat((float) compression); // + 4
    buf.putShort((short) mean.length); // + 2
    buf.putShort((short) tempMean.length); // + 2
    buf.putShort((short) lastUsedCell); // + 2 = 30
    for (int i = 0; i < lastUsedCell; i++) {
      buf.putFloat((float) weight[i]);
      buf.putFloat((float) mean[i]);
    }
  }

  @SuppressWarnings("WeakerAccess")
  public static ResettableMergingDigest fromBytes(ByteBuffer buf) {
    int encoding = buf.getInt();
    if (encoding == Encoding.VERBOSE_ENCODING.code) {
      double min = buf.getDouble();
      double max = buf.getDouble();
      double compression = buf.getDouble();
      int n = buf.getInt();
      ResettableMergingDigest r = new ResettableMergingDigest(compression);
      r.setMinMax(min, max);
      r.lastUsedCell = n;
      for (int i = 0; i < n; i++) {
        r.weight[i] = buf.getDouble();
        r.mean[i] = buf.getDouble();

        r.totalWeight += r.weight[i];
      }
      return r;
    } else if (encoding == Encoding.SMALL_ENCODING.code) {
      double min = buf.getDouble();
      double max = buf.getDouble();
      double compression = buf.getFloat();
      int n = buf.getShort();
      int bufferSize = buf.getShort();
      ResettableMergingDigest r = new ResettableMergingDigest(compression, bufferSize, n);
      r.setMinMax(min, max);
      r.lastUsedCell = buf.getShort();
      for (int i = 0; i < r.lastUsedCell; i++) {
        r.weight[i] = buf.getFloat();
        r.mean[i] = buf.getFloat();

        r.totalWeight += r.weight[i];
      }
      return r;
    } else {
      throw new IllegalStateException("Invalid format for serialized histogram");
    }
  }
}
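
For context, the reuse pattern this class was written for looks roughly like the following; everything other than ResettableMergingDigest itself (the window input, the publish callback, the compression of 100) is illustrative:

  import java.util.function.Consumer;

  // Sketch: reuse one ResettableMergingDigest across one-minute windows via reset().
  final class DigestReuseExample {
    static void run(Iterable<double[]> windows, Consumer<ResettableMergingDigest> publish) {
      ResettableMergingDigest digest = new ResettableMergingDigest(100);
      for (double[] window : windows) {
        for (double v : window) {
          if (!Double.isNaN(v)) {   // mirrors the collector-side guard above
            digest.add(v);
          }
        }
        publish.accept(digest);     // e.g. serialize and enqueue to Kafka
        digest.reset();             // clear state so the same instance serves the next window
      }
    }
  }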
