diff --git a/planetiler-benchmarks/src/main/java/com/onthegomap/planetiler/collection/BenchmarkKWayMerge.java b/planetiler-benchmarks/src/main/java/com/onthegomap/planetiler/collection/BenchmarkKWayMerge.java index 20c9548a28..a11ff03326 100644 --- a/planetiler-benchmarks/src/main/java/com/onthegomap/planetiler/collection/BenchmarkKWayMerge.java +++ b/planetiler-benchmarks/src/main/java/com/onthegomap/planetiler/collection/BenchmarkKWayMerge.java @@ -15,7 +15,7 @@ public class BenchmarkKWayMerge { public static void main(String[] args) { for (int i = 0; i < 4; i++) { System.err.println(); - testMinHeap("quaternary", n -> LongMinHeap.newArrayHeap(n, (a, b) -> 0)); + testMinHeap("quaternary", n -> LongMinHeap.newArrayHeap(n, Integer::compare)); System.err.println(String.join("\t", "priorityqueue", Long.toString(testPriorityQueue(10).toMillis()), diff --git a/planetiler-core/src/main/java/com/onthegomap/planetiler/collection/ArrayLongMinHeap.java b/planetiler-core/src/main/java/com/onthegomap/planetiler/collection/ArrayLongMinHeap.java index 213d781287..f2464ae29c 100644 --- a/planetiler-core/src/main/java/com/onthegomap/planetiler/collection/ArrayLongMinHeap.java +++ b/planetiler-core/src/main/java/com/onthegomap/planetiler/collection/ArrayLongMinHeap.java @@ -105,21 +105,18 @@ public boolean contains(int id) { @Override public void update(int id, long value) { checkIdInRange(id); - int index = idToPos[id]; - if (index < 0) { + int pos = idToPos[id]; + if (pos < 0) { throw new IllegalStateException( "The heap does not contain: " + id + ". Use the contains method to check this before calling update"); } - long prev = posToValue[index]; - posToValue[index] = value; - int cmp = Long.compare(value, prev); - if (cmp == 0 && value != Long.MIN_VALUE) { - cmp = tieBreaker.applyAsInt(id, posToId[index]); - } + long prev = posToValue[pos]; + posToValue[pos] = value; + int cmp = compareIdPos(value, prev, id, pos); if (cmp > 0) { - percolateDown(index); + percolateDown(pos); } else if (cmp < 0) { - percolateUp(index); + percolateUp(pos); } } @@ -169,8 +166,7 @@ private void percolateUp(int pos) { // the finish condition (index==0) is covered here automatically because we set vals[0]=-inf int parent; long parentValue; - while (val < (parentValue = posToValue[parent = parent(pos)]) || - (val == parentValue && val != Long.MIN_VALUE && tieBreaker.applyAsInt(id, posToId[parent]) < 0)) { + while (compareIdPos(val, parentValue = posToValue[parent = parent(pos)], id, parent) < 0) { posToValue[pos] = parentValue; idToPos[posToId[pos] = posToId[parent]] = pos; pos = parent; @@ -201,29 +197,23 @@ private void percolateDown(int pos) { int minChild = child; long minValue = posToValue[child], childValue; if (++child <= size) { - if ((childValue = posToValue[child]) < minValue || - (childValue == minValue && childValue != Long.MIN_VALUE && - tieBreaker.applyAsInt(posToId[child], posToId[minChild]) < 0)) { + if (comparePosPos(childValue = posToValue[child], minValue, child, minChild) < 0) { minChild = child; minValue = childValue; } if (++child <= size) { - if ((childValue = posToValue[child]) < minValue || - (childValue == minValue && childValue != Long.MIN_VALUE && - tieBreaker.applyAsInt(posToId[child], posToId[minChild]) < 0)) { + if (comparePosPos(childValue = posToValue[child], minValue, child, minChild) < 0) { minChild = child; minValue = childValue; } - if (++child <= size && ((childValue = posToValue[child]) < minValue || - (childValue == minValue && childValue != Long.MIN_VALUE && - tieBreaker.applyAsInt(posToId[child], posToId[minChild]) < 0))) { + if (++child <= size && + comparePosPos(childValue = posToValue[child], minValue, child, minChild) < 0) { minChild = child; minValue = childValue; } } } - if (minValue > value || - (minValue == value && minValue != Long.MIN_VALUE && tieBreaker.applyAsInt(posToId[minChild], id) >= 0)) { + if (comparePosPos(value, minValue, pos, minChild) <= 0) { break; } posToValue[pos] = minValue; @@ -234,4 +224,23 @@ private void percolateDown(int pos) { posToValue[pos] = value; idToPos[id] = pos; } + + private int comparePosPos(long val1, long val2, int pos1, int pos2) { + if (val1 < val2) { + return -1; + } else if (val1 == val2 && val1 != Long.MIN_VALUE) { + return tieBreaker.applyAsInt(posToId[pos1], posToId[pos2]); + } + return 1; + } + + private int compareIdPos(long val1, long val2, int id1, int pos2) { + if (val1 < val2) { + return -1; + } else if (val1 == val2 && val1 != Long.MIN_VALUE) { + return tieBreaker.applyAsInt(id1, posToId[pos2]); + } + return 1; + } + }