From 0dc2ee82e1714ed1b7e1f8b2f5767cdcb03c12aa Mon Sep 17 00:00:00 2001 From: Michael Barry Date: Thu, 11 Jan 2024 08:42:16 -0500 Subject: [PATCH] Deterministic merging (#785) --- .../collection/BenchmarkKWayMerge.java | 2 +- .../com/onthegomap/planetiler/VectorTile.java | 4 +- .../collection/ArrayLongMinHeap.java | 146 ++++++++------ .../collection/ExternalMergeSort.java | 2 +- .../planetiler/collection/FeatureSort.java | 2 +- .../planetiler/collection/HasLongSortKey.java | 2 +- .../planetiler/collection/LongMerger.java | 56 ++++-- .../planetiler/collection/LongMinHeap.java | 6 +- .../collection/SortableFeature.java | 10 +- .../planetiler/util/CompareArchives.java | 4 +- .../planetiler/collection/LongMergerTest.java | 190 ++++++++++-------- .../collection/LongMinHeapTest.java | 29 ++- 12 files changed, 285 insertions(+), 168 deletions(-) diff --git a/planetiler-benchmarks/src/main/java/com/onthegomap/planetiler/collection/BenchmarkKWayMerge.java b/planetiler-benchmarks/src/main/java/com/onthegomap/planetiler/collection/BenchmarkKWayMerge.java index 6ad42302fb..a11ff03326 100644 --- a/planetiler-benchmarks/src/main/java/com/onthegomap/planetiler/collection/BenchmarkKWayMerge.java +++ b/planetiler-benchmarks/src/main/java/com/onthegomap/planetiler/collection/BenchmarkKWayMerge.java @@ -15,7 +15,7 @@ public class BenchmarkKWayMerge { public static void main(String[] args) { for (int i = 0; i < 4; i++) { System.err.println(); - testMinHeap("quaternary", LongMinHeap::newArrayHeap); + testMinHeap("quaternary", n -> LongMinHeap.newArrayHeap(n, Integer::compare)); System.err.println(String.join("\t", "priorityqueue", Long.toString(testPriorityQueue(10).toMillis()), diff --git a/planetiler-core/src/main/java/com/onthegomap/planetiler/VectorTile.java b/planetiler-core/src/main/java/com/onthegomap/planetiler/VectorTile.java index 1e46cf5e42..ec71780ca3 100644 --- a/planetiler-core/src/main/java/com/onthegomap/planetiler/VectorTile.java +++ b/planetiler-core/src/main/java/com/onthegomap/planetiler/VectorTile.java @@ -35,6 +35,7 @@ import java.util.List; import java.util.Locale; import java.util.Map; +import java.util.TreeMap; import java.util.function.Consumer; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -80,7 +81,8 @@ public class VectorTile { // TODO make these configurable private static final int EXTENT = 4096; private static final double SIZE = 256d; - private final Map layers = new LinkedHashMap<>(); + // use a treemap to ensure that layers are encoded in a consistent order + private final Map layers = new TreeMap<>(); private LayerAttrStats.Updater.ForZoom layerStatsTracker = LayerAttrStats.Updater.ForZoom.NOOP; private static int[] getCommands(Geometry input, int scale) { diff --git a/planetiler-core/src/main/java/com/onthegomap/planetiler/collection/ArrayLongMinHeap.java b/planetiler-core/src/main/java/com/onthegomap/planetiler/collection/ArrayLongMinHeap.java index 176ea023e1..f2464ae29c 100644 --- a/planetiler-core/src/main/java/com/onthegomap/planetiler/collection/ArrayLongMinHeap.java +++ b/planetiler-core/src/main/java/com/onthegomap/planetiler/collection/ArrayLongMinHeap.java @@ -18,6 +18,7 @@ package com.onthegomap.planetiler.collection; import java.util.Arrays; +import java.util.function.IntBinaryOperator; /** * A min-heap stored in an array where each element has 4 children. @@ -38,24 +39,26 @@ */ class ArrayLongMinHeap implements LongMinHeap { protected static final int NOT_PRESENT = -1; - protected final int[] tree; - protected final int[] positions; - protected final long[] vals; + protected final int[] posToId; + protected final int[] idToPos; + protected final long[] posToValue; protected final int max; protected int size; + private final IntBinaryOperator tieBreaker; /** * @param elements the number of elements that can be stored in this heap. Currently the heap cannot be resized or * shrunk/trimmed after initial creation. elements-1 is the maximum id that can be stored in this heap */ - ArrayLongMinHeap(int elements) { + ArrayLongMinHeap(int elements, IntBinaryOperator tieBreaker) { // we use an offset of one to make the arithmetic a bit simpler/more efficient, the 0th elements are not used! - tree = new int[elements + 1]; - positions = new int[elements + 1]; - Arrays.fill(positions, NOT_PRESENT); - vals = new long[elements + 1]; - vals[0] = Long.MIN_VALUE; + posToId = new int[elements + 1]; + idToPos = new int[elements + 1]; + Arrays.fill(idToPos, NOT_PRESENT); + posToValue = new long[elements + 1]; + posToValue[0] = Long.MIN_VALUE; this.max = elements; + this.tieBreaker = tieBreaker; } private static int firstChild(int index) { @@ -87,58 +90,59 @@ public void push(int id, long value) { " was pushed already, you need to use the update method if you want to change its value"); } size++; - tree[size] = id; - positions[id] = size; - vals[size] = value; + posToId[size] = id; + idToPos[id] = size; + posToValue[size] = value; percolateUp(size); } @Override public boolean contains(int id) { checkIdInRange(id); - return positions[id] != NOT_PRESENT; + return idToPos[id] != NOT_PRESENT; } @Override public void update(int id, long value) { checkIdInRange(id); - int index = positions[id]; - if (index < 0) { + int pos = idToPos[id]; + if (pos < 0) { throw new IllegalStateException( "The heap does not contain: " + id + ". Use the contains method to check this before calling update"); } - long prev = vals[index]; - vals[index] = value; - if (value > prev) { - percolateDown(index); - } else if (value < prev) { - percolateUp(index); + long prev = posToValue[pos]; + posToValue[pos] = value; + int cmp = compareIdPos(value, prev, id, pos); + if (cmp > 0) { + percolateDown(pos); + } else if (cmp < 0) { + percolateUp(pos); } } @Override public void updateHead(long value) { - vals[1] = value; + posToValue[1] = value; percolateDown(1); } @Override public int peekId() { - return tree[1]; + return posToId[1]; } @Override public long peekValue() { - return vals[1]; + return posToValue[1]; } @Override public int poll() { int id = peekId(); - tree[1] = tree[size]; - vals[1] = vals[size]; - positions[tree[1]] = 1; - positions[id] = NOT_PRESENT; + posToId[1] = posToId[size]; + posToValue[1] = posToValue[size]; + idToPos[posToId[1]] = 1; + idToPos[id] = NOT_PRESENT; size--; percolateDown(1); return id; @@ -147,29 +151,29 @@ public int poll() { @Override public void clear() { for (int i = 1; i <= size; i++) { - positions[tree[i]] = NOT_PRESENT; + idToPos[posToId[i]] = NOT_PRESENT; } size = 0; } - private void percolateUp(int index) { - assert index != 0; - if (index == 1) { + private void percolateUp(int pos) { + assert pos != 0; + if (pos == 1) { return; } - final int el = tree[index]; - final long val = vals[index]; + final int id = posToId[pos]; + final long val = posToValue[pos]; // the finish condition (index==0) is covered here automatically because we set vals[0]=-inf int parent; long parentValue; - while (val < (parentValue = vals[parent = parent(index)])) { - vals[index] = parentValue; - positions[tree[index] = tree[parent]] = index; - index = parent; + while (compareIdPos(val, parentValue = posToValue[parent = parent(pos)], id, parent) < 0) { + posToValue[pos] = parentValue; + idToPos[posToId[pos] = posToId[parent]] = pos; + pos = parent; } - tree[index] = el; - vals[index] = val; - positions[tree[index]] = index; + posToId[pos] = id; + posToValue[pos] = val; + idToPos[posToId[pos]] = pos; } private void checkIdInRange(int id) { @@ -178,45 +182,65 @@ private void checkIdInRange(int id) { } } - private void percolateDown(int index) { + private void percolateDown(int pos) { if (size == 0) { return; } - assert index > 0; - assert index <= size; - final int el = tree[index]; - final long val = vals[index]; + assert pos > 0; + assert pos <= size; + final int id = posToId[pos]; + final long value = posToValue[pos]; int child; - while ((child = firstChild(index)) <= size) { + while ((child = firstChild(pos)) <= size) { // optimization: this is a very hot code path for performance of k-way merging, // so manually-unroll the loop over the 4 child elements to find the minimum value int minChild = child; - long minValue = vals[child], value; + long minValue = posToValue[child], childValue; if (++child <= size) { - if ((value = vals[child]) < minValue) { + if (comparePosPos(childValue = posToValue[child], minValue, child, minChild) < 0) { minChild = child; - minValue = value; + minValue = childValue; } if (++child <= size) { - if ((value = vals[child]) < minValue) { + if (comparePosPos(childValue = posToValue[child], minValue, child, minChild) < 0) { minChild = child; - minValue = value; + minValue = childValue; } - if (++child <= size && (value = vals[child]) < minValue) { + if (++child <= size && + comparePosPos(childValue = posToValue[child], minValue, child, minChild) < 0) { minChild = child; - minValue = value; + minValue = childValue; } } } - if (minValue >= val) { + if (comparePosPos(value, minValue, pos, minChild) <= 0) { break; } - vals[index] = minValue; - positions[tree[index] = tree[minChild]] = index; - index = minChild; + posToValue[pos] = minValue; + idToPos[posToId[pos] = posToId[minChild]] = pos; + pos = minChild; } - tree[index] = el; - vals[index] = val; - positions[el] = index; + posToId[pos] = id; + posToValue[pos] = value; + idToPos[id] = pos; } + + private int comparePosPos(long val1, long val2, int pos1, int pos2) { + if (val1 < val2) { + return -1; + } else if (val1 == val2 && val1 != Long.MIN_VALUE) { + return tieBreaker.applyAsInt(posToId[pos1], posToId[pos2]); + } + return 1; + } + + private int compareIdPos(long val1, long val2, int id1, int pos2) { + if (val1 < val2) { + return -1; + } else if (val1 == val2 && val1 != Long.MIN_VALUE) { + return tieBreaker.applyAsInt(id1, posToId[pos2]); + } + return 1; + } + } diff --git a/planetiler-core/src/main/java/com/onthegomap/planetiler/collection/ExternalMergeSort.java b/planetiler-core/src/main/java/com/onthegomap/planetiler/collection/ExternalMergeSort.java index 11a63af5be..15cf934527 100644 --- a/planetiler-core/src/main/java/com/onthegomap/planetiler/collection/ExternalMergeSort.java +++ b/planetiler-core/src/main/java/com/onthegomap/planetiler/collection/ExternalMergeSort.java @@ -253,7 +253,7 @@ public Iterator iterator(int shard, int shards) { } } - return LongMerger.mergeIterators(iterators); + return LongMerger.mergeIterators(iterators, SortableFeature.COMPARE_BYTES); } public int chunks() { diff --git a/planetiler-core/src/main/java/com/onthegomap/planetiler/collection/FeatureSort.java b/planetiler-core/src/main/java/com/onthegomap/planetiler/collection/FeatureSort.java index c209c89ff1..4257f31f21 100644 --- a/planetiler-core/src/main/java/com/onthegomap/planetiler/collection/FeatureSort.java +++ b/planetiler-core/src/main/java/com/onthegomap/planetiler/collection/FeatureSort.java @@ -131,7 +131,7 @@ default ParallelIterator parallelIterator(Stats stats, int threads) { } } }); - return new ParallelIterator(reader, LongMerger.mergeSuppliers(queues)); + return new ParallelIterator(reader, LongMerger.mergeSuppliers(queues, SortableFeature.COMPARE_BYTES)); } record ParallelIterator(Worker reader, @Override Iterator iterator) diff --git a/planetiler-core/src/main/java/com/onthegomap/planetiler/collection/HasLongSortKey.java b/planetiler-core/src/main/java/com/onthegomap/planetiler/collection/HasLongSortKey.java index 8482d43556..e4a40cf812 100644 --- a/planetiler-core/src/main/java/com/onthegomap/planetiler/collection/HasLongSortKey.java +++ b/planetiler-core/src/main/java/com/onthegomap/planetiler/collection/HasLongSortKey.java @@ -2,7 +2,7 @@ /** * An item with a {@code long key} that can be used for sorting/grouping. - * + *

* These items can be sorted or grouped by {@link FeatureSort}/{@link FeatureGroup} implementations. Sorted lists can * also be merged using {@link LongMerger}. */ diff --git a/planetiler-core/src/main/java/com/onthegomap/planetiler/collection/LongMerger.java b/planetiler-core/src/main/java/com/onthegomap/planetiler/collection/LongMerger.java index 16bcfa9dcf..1303924c11 100644 --- a/planetiler-core/src/main/java/com/onthegomap/planetiler/collection/LongMerger.java +++ b/planetiler-core/src/main/java/com/onthegomap/planetiler/collection/LongMerger.java @@ -1,6 +1,7 @@ package com.onthegomap.planetiler.collection; import java.util.Collections; +import java.util.Comparator; import java.util.Iterator; import java.util.List; import java.util.NoSuchElementException; @@ -16,29 +17,34 @@ public class LongMerger { private LongMerger() {} /** Merges sorted items from {@link Supplier Suppliers} that return {@code null} when there are no items left. */ - public static Iterator mergeSuppliers(List> suppliers) { - return mergeIterators(suppliers.stream().map(SupplierIterator::new).toList()); + public static Iterator mergeSuppliers(List> suppliers, + Comparator tieBreaker) { + return mergeIterators(suppliers.stream().map(SupplierIterator::new).toList(), tieBreaker); } /** Merges sorted iterators into a combined iterator over all the items. */ - public static Iterator mergeIterators(List> iterators) { + public static Iterator mergeIterators(List> iterators, + Comparator tieBreaker) { return switch (iterators.size()) { case 0 -> Collections.emptyIterator(); case 1 -> iterators.get(0); - case 2 -> new TwoWayMerge<>(iterators.get(0), iterators.get(1)); - case 3 -> new ThreeWayMerge<>(iterators.get(0), iterators.get(1), iterators.get(2)); - default -> new KWayMerge<>(iterators); + case 2 -> new TwoWayMerge<>(iterators.get(0), iterators.get(1), tieBreaker); + case 3 -> new ThreeWayMerge<>(iterators.get(0), iterators.get(1), iterators.get(2), tieBreaker); + default -> new KWayMerge<>(iterators, tieBreaker); }; } private static class TwoWayMerge implements Iterator { + + private final Comparator tieBreaker; T a, b; long ak = Long.MAX_VALUE, bk = Long.MAX_VALUE; final Iterator inputA, inputB; - TwoWayMerge(Iterator inputA, Iterator inputB) { + TwoWayMerge(Iterator inputA, Iterator inputB, Comparator tieBreaker) { this.inputA = inputA; this.inputB = inputB; + this.tieBreaker = tieBreaker; if (inputA.hasNext()) { a = inputA.next(); ak = a.key(); @@ -57,7 +63,7 @@ public boolean hasNext() { @Override public T next() { T result; - if (ak < bk) { + if (lessThan(ak, bk, a, b)) { result = a; if (inputA.hasNext()) { a = inputA.next(); @@ -80,14 +86,21 @@ public T next() { } return result; } + + private boolean lessThan(long ak, long bk, T a, T b) { + return ak < bk || (ak == bk && lessThanCmp(a, b, tieBreaker)); + } } private static class ThreeWayMerge implements Iterator { + + private final Comparator tieBreaker; T a, b, c; long ak = Long.MAX_VALUE, bk = Long.MAX_VALUE, ck = Long.MAX_VALUE; final Iterator inputA, inputB, inputC; - ThreeWayMerge(Iterator inputA, Iterator inputB, Iterator inputC) { + ThreeWayMerge(Iterator inputA, Iterator inputB, Iterator inputC, Comparator tieBreaker) { + this.tieBreaker = tieBreaker; this.inputA = inputA; this.inputB = inputB; this.inputC = inputC; @@ -114,8 +127,8 @@ public boolean hasNext() { public T next() { T result; // use at most 2 comparisons to get the next item - if (ak < bk) { - if (ak < ck) { + if (lessThan(ak, bk, a, b)) { + if (lessThan(ak, ck, a, c)) { // ACB / ABC result = a; if (inputA.hasNext()) { @@ -136,7 +149,7 @@ public T next() { ck = Long.MAX_VALUE; } } - } else if (ck < bk) { + } else if (lessThan(ck, bk, c, b)) { // CAB result = c; if (inputC.hasNext()) { @@ -161,6 +174,21 @@ public T next() { } return result; } + + private boolean lessThan(long ak, long bk, T a, T b) { + return ak < bk || (ak == bk && lessThanCmp(a, b, tieBreaker)); + } + } + + private static boolean lessThanCmp(T a, T b, Comparator tieBreaker) { + // nulls go at the end + if (a == null) { + return false; + } else if (b == null) { + return true; + } else { + return tieBreaker.compare(a, b) < 0; + } } private static class KWayMerge implements Iterator { @@ -169,10 +197,10 @@ private static class KWayMerge implements Iterator private final LongMinHeap heap; @SuppressWarnings("unchecked") - KWayMerge(List> inputIterators) { + KWayMerge(List> inputIterators, Comparator tieBreaker) { this.iterators = new Iterator[inputIterators.size()]; this.items = (T[]) new HasLongSortKey[inputIterators.size()]; - this.heap = LongMinHeap.newArrayHeap(inputIterators.size()); + this.heap = LongMinHeap.newArrayHeap(inputIterators.size(), (a, b) -> tieBreaker.compare(items[a], items[b])); int outIdx = 0; for (Iterator iter : inputIterators) { if (iter.hasNext()) { diff --git a/planetiler-core/src/main/java/com/onthegomap/planetiler/collection/LongMinHeap.java b/planetiler-core/src/main/java/com/onthegomap/planetiler/collection/LongMinHeap.java index f29985e995..2b47b75b18 100644 --- a/planetiler-core/src/main/java/com/onthegomap/planetiler/collection/LongMinHeap.java +++ b/planetiler-core/src/main/java/com/onthegomap/planetiler/collection/LongMinHeap.java @@ -17,6 +17,8 @@ */ package com.onthegomap.planetiler.collection; +import java.util.function.IntBinaryOperator; + /** * API for min-heaps that keeps track of {@code int} keys in a range from {@code [0, size)} ordered by {@code long} * values. @@ -31,8 +33,8 @@ public interface LongMinHeap { *

* This is slightly faster than a traditional binary min heap due to a shallower, more cache-friendly memory layout. */ - static LongMinHeap newArrayHeap(int elements) { - return new ArrayLongMinHeap(elements); + static LongMinHeap newArrayHeap(int elements, IntBinaryOperator tieBreaker) { + return new ArrayLongMinHeap(elements, tieBreaker); } int size(); diff --git a/planetiler-core/src/main/java/com/onthegomap/planetiler/collection/SortableFeature.java b/planetiler-core/src/main/java/com/onthegomap/planetiler/collection/SortableFeature.java index e5df6c484f..7884799701 100644 --- a/planetiler-core/src/main/java/com/onthegomap/planetiler/collection/SortableFeature.java +++ b/planetiler-core/src/main/java/com/onthegomap/planetiler/collection/SortableFeature.java @@ -1,12 +1,20 @@ package com.onthegomap.planetiler.collection; import java.util.Arrays; +import java.util.Comparator; public record SortableFeature(@Override long key, byte[] value) implements Comparable, HasLongSortKey { + public static final Comparator COMPARE_BYTES = (a, b) -> Arrays.compareUnsigned(a.value, b.value); @Override public int compareTo(SortableFeature o) { - return Long.compare(key, o.key); + if (key < o.key) { + return -1; + } else if (key == o.key) { + return Arrays.compareUnsigned(value, o.value); + } else { + return 1; + } } @Override diff --git a/planetiler-core/src/main/java/com/onthegomap/planetiler/util/CompareArchives.java b/planetiler-core/src/main/java/com/onthegomap/planetiler/util/CompareArchives.java index 0a73c7e521..e84384d7fc 100644 --- a/planetiler-core/src/main/java/com/onthegomap/planetiler/util/CompareArchives.java +++ b/planetiler-core/src/main/java/com/onthegomap/planetiler/util/CompareArchives.java @@ -205,8 +205,8 @@ private void compareLayer(VectorTileProto.Tile.Layer layer1, VectorTileProto.Til compareList(name, "keys list", layer1.getKeysList(), layer2.getKeysList()); compareList(name, "values list", layer1.getValuesList(), layer2.getValuesList()); if (compareValues(name, "features count", layer1.getFeaturesCount(), layer2.getFeaturesCount())) { - var ids1 = layer1.getFeaturesList().stream().map(f -> f.getId()); - var ids2 = layer1.getFeaturesList().stream().map(f -> f.getId()); + var ids1 = layer1.getFeaturesList().stream().map(f -> f.getId()).toList(); + var ids2 = layer2.getFeaturesList().stream().map(f -> f.getId()).toList(); if (compareValues(name, "feature ids", Set.of(ids1), Set.of(ids2)) && compareValues(name, "feature order", ids1, ids2)) { for (int i = 0; i < layer1.getFeaturesCount() && i < layer2.getFeaturesCount(); i++) { diff --git a/planetiler-core/src/test/java/com/onthegomap/planetiler/collection/LongMergerTest.java b/planetiler-core/src/test/java/com/onthegomap/planetiler/collection/LongMergerTest.java index e0a8f89665..96e5d7f9e0 100644 --- a/planetiler-core/src/test/java/com/onthegomap/planetiler/collection/LongMergerTest.java +++ b/planetiler-core/src/test/java/com/onthegomap/planetiler/collection/LongMergerTest.java @@ -4,6 +4,7 @@ import static org.junit.jupiter.api.Assertions.assertThrows; import java.util.ArrayList; +import java.util.Comparator; import java.util.List; import java.util.NoSuchElementException; import java.util.function.Supplier; @@ -12,21 +13,36 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.CsvSource; +import org.junit.jupiter.params.provider.ValueSource; class LongMergerTest { - record Item(long key) implements HasLongSortKey {} + record Item(long key, int secondary) implements HasLongSortKey, Comparable { + @Override + public int compareTo(Item o) { + int cmp = Long.compare(key, o.key); + if (cmp == 0) { + cmp = Integer.compare(secondary, o.secondary); + } + return cmp; + } + + long value() { + return key + secondary; + } + } record ItemList(List items) {} - private static ItemList list(long... items) { - return new ItemList(LongStream.of(items).mapToObj(Item::new).toList()); + private static ItemList list(boolean primaryKey, long... items) { + return new ItemList( + LongStream.of(items).mapToObj(i -> primaryKey ? new Item(i, 0) : new Item(0, (int) i)).toList()); } private static List merge(ItemList... lists) { List list = new ArrayList<>(); var iter = LongMerger.mergeIterators(Stream.of(lists) .map(d -> d.items.iterator()) - .toList()); - iter.forEachRemaining(item -> list.add(item.key)); + .toList(), Comparator.naturalOrder()); + iter.forEachRemaining(item -> list.add(item.value())); assertThrows(NoSuchElementException.class, iter::next); return list; } @@ -36,10 +52,11 @@ void testMergeEmpty() { assertEquals(List.of(), merge()); } - @Test - void testMergeSupplier() { + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void testMergeSupplier(boolean primaryKey) { List list = new ArrayList<>(); - var iter = LongMerger.mergeSuppliers(Stream.of(new ItemList[]{list(1, 2)}) + var iter = LongMerger.mergeSuppliers(Stream.of(new ItemList[]{list(primaryKey, 1, 2)}) .map(d -> d.items.iterator()) .>map(d -> () -> { try { @@ -48,17 +65,18 @@ void testMergeSupplier() { return null; } }) - .toList()); - iter.forEachRemaining(item -> list.add(item.key)); + .toList(), Comparator.naturalOrder()); + iter.forEachRemaining(item -> list.add(item.value())); assertThrows(NoSuchElementException.class, iter::next); assertEquals(List.of(1L, 2L), list); } - @Test - void testMerge1() { - assertEquals(List.of(), merge(list())); - assertEquals(List.of(1L), merge(list(1))); - assertEquals(List.of(1L, 2L), merge(list(1, 2))); + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void testMerge1(boolean primaryKey) { + assertEquals(List.of(), merge(list(primaryKey))); + assertEquals(List.of(1L), merge(list(primaryKey, 1))); + assertEquals(List.of(1L, 2L), merge(list(primaryKey, 1, 2))); } @ParameterizedTest @@ -73,16 +91,20 @@ void testMerge1() { "1 3,2,1 2 3", }, nullValues = {"null"}) void testMerge2(String a, String b, String output) { - var listA = list(parse(a)); - var listB = list(parse(b)); - assertEquals( - LongStream.of(parse(output)).boxed().toList(), - merge(listA, listB) - ); - assertEquals( - LongStream.of(parse(output)).boxed().toList(), - merge(listB, listA) - ); + for (boolean primaryKey : List.of(false, true)) { + var listA = list(primaryKey, parse(a)); + var listB = list(primaryKey, parse(b)); + assertEquals( + LongStream.of(parse(output)).boxed().toList(), + merge(listA, listB), + "primary=" + primaryKey + ); + assertEquals( + LongStream.of(parse(output)).boxed().toList(), + merge(listB, listA), + "primary=" + primaryKey + ); + } } @ParameterizedTest @@ -98,39 +120,41 @@ void testMerge2(String a, String b, String output) { "1 3,2,4,1 2 3 4", }, nullValues = {""}) void testMerge3(String a, String b, String c, String output) { - var listA = list(parse(a)); - var listB = list(parse(b)); - var listC = list(parse(c)); - assertEquals( - LongStream.of(parse(output)).boxed().toList(), - merge(listA, listB, listC), - "ABC" - ); - assertEquals( - LongStream.of(parse(output)).boxed().toList(), - merge(listA, listC, listB), - "ACB" - ); - assertEquals( - LongStream.of(parse(output)).boxed().toList(), - merge(listB, listA, listC), - "BAC" - ); - assertEquals( - LongStream.of(parse(output)).boxed().toList(), - merge(listB, listC, listA), - "BCA" - ); - assertEquals( - LongStream.of(parse(output)).boxed().toList(), - merge(listC, listA, listB), - "CAB" - ); - assertEquals( - LongStream.of(parse(output)).boxed().toList(), - merge(listC, listB, listA), - "CBA" - ); + for (boolean primaryKey : List.of(false, true)) { + var listA = list(primaryKey, parse(a)); + var listB = list(primaryKey, parse(b)); + var listC = list(primaryKey, parse(c)); + assertEquals( + LongStream.of(parse(output)).boxed().toList(), + merge(listA, listB, listC), + "ABC primary=" + primaryKey + ); + assertEquals( + LongStream.of(parse(output)).boxed().toList(), + merge(listA, listC, listB), + "ACB primary=" + primaryKey + ); + assertEquals( + LongStream.of(parse(output)).boxed().toList(), + merge(listB, listA, listC), + "BAC primary=" + primaryKey + ); + assertEquals( + LongStream.of(parse(output)).boxed().toList(), + merge(listB, listC, listA), + "BCA primary=" + primaryKey + ); + assertEquals( + LongStream.of(parse(output)).boxed().toList(), + merge(listC, listA, listB), + "CAB primary=" + primaryKey + ); + assertEquals( + LongStream.of(parse(output)).boxed().toList(), + merge(listC, listB, listA), + "CBA primary=" + primaryKey + ); + } } @ParameterizedTest @@ -146,31 +170,33 @@ void testMerge3(String a, String b, String c, String output) { "1 2,2 3,,,1 2 2 3", }, nullValues = {""}) void testMerge4(String a, String b, String c, String d, String output) { - var listA = list(parse(a)); - var listB = list(parse(b)); - var listC = list(parse(c)); - var listD = list(parse(d)); + for (boolean primaryKey : List.of(false, true)) { + var listA = list(primaryKey, parse(a)); + var listB = list(primaryKey, parse(b)); + var listC = list(primaryKey, parse(c)); + var listD = list(primaryKey, parse(d)); - assertEquals( - LongStream.of(parse(output)).boxed().toList(), - merge(listA, listB, listC, listD), - "ABCD" - ); - assertEquals( - LongStream.of(parse(output)).boxed().toList(), - merge(listB, listA, listC, listD), - "BACD" - ); - assertEquals( - LongStream.of(parse(output)).boxed().toList(), - merge(listB, listC, listA, listD), - "BCAD" - ); - assertEquals( - LongStream.of(parse(output)).boxed().toList(), - merge(listB, listC, listD, listA), - "BCDA" - ); + assertEquals( + LongStream.of(parse(output)).boxed().toList(), + merge(listA, listB, listC, listD), + "ABCD primary=" + primaryKey + ); + assertEquals( + LongStream.of(parse(output)).boxed().toList(), + merge(listB, listA, listC, listD), + "BACD primary=" + primaryKey + ); + assertEquals( + LongStream.of(parse(output)).boxed().toList(), + merge(listB, listC, listA, listD), + "BCAD primary=" + primaryKey + ); + assertEquals( + LongStream.of(parse(output)).boxed().toList(), + merge(listB, listC, listD, listA), + "BCDA primary=" + primaryKey + ); + } } private static long[] parse(String in) { diff --git a/planetiler-core/src/test/java/com/onthegomap/planetiler/collection/LongMinHeapTest.java b/planetiler-core/src/test/java/com/onthegomap/planetiler/collection/LongMinHeapTest.java index 7c3d787f8e..65ece71268 100644 --- a/planetiler-core/src/test/java/com/onthegomap/planetiler/collection/LongMinHeapTest.java +++ b/planetiler-core/src/test/java/com/onthegomap/planetiler/collection/LongMinHeapTest.java @@ -29,6 +29,8 @@ import java.util.PriorityQueue; import java.util.Random; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.CsvSource; /** @@ -42,7 +44,7 @@ class LongMinHeapTest { protected LongMinHeap heap; void create(int capacity) { - heap = LongMinHeap.newArrayHeap(capacity); + heap = LongMinHeap.newArrayHeap(capacity, Integer::compare); } @Test @@ -77,6 +79,31 @@ void duplicateElements() { assertThrows(IllegalStateException.class, () -> heap.push(2, 4L)); } + @ParameterizedTest + @CsvSource({ + "0, 1, 2, 3, 4, 5", + "5, 4, 3, 2, 1, 0", + "0, 1, 2, 5, 4, 3", + "0, 1, 5, 2, 4, 3", + "0, 5, 1, 2, 4, 3", + "5, 0, 1, 2, 4, 3", + }) + void tieBreaker(int a, int b, int c, int d, int e, int f) { + heap = LongMinHeap.newArrayHeap(6, (id1, id2) -> -Integer.compare(id1, id2)); + heap.push(a, 0L); + heap.push(b, 0L); + heap.push(c, 0L); + heap.push(d, 0L); + heap.push(e, 0L); + heap.push(f, 0L); + assertEquals(5, heap.poll()); + assertEquals(4, heap.poll()); + assertEquals(3, heap.poll()); + assertEquals(2, heap.poll()); + assertEquals(1, heap.poll()); + assertEquals(0, heap.poll()); + } + @Test void testContains() { create(4);