diff --git a/.github/workflows/checkstyle.yml b/.github/workflows/checkstyle.yml index 6680ccd33b..2e0fe65f36 100644 --- a/.github/workflows/checkstyle.yml +++ b/.github/workflows/checkstyle.yml @@ -38,4 +38,4 @@ jobs: run: git fetch --all - name: Run checkstyle - run: bash ./dev/checkstyle.sh + run: mvn checkstyle:check diff --git a/dev/checkstyle.sh b/dev/checkstyle.sh index 2265bd9f84..503d495e3b 100755 --- a/dev/checkstyle.sh +++ b/dev/checkstyle.sh @@ -19,13 +19,7 @@ set -ex -# Assuming you are in the root of your git repository -if [[ -z "${GITHUB_BASE_REF+x}" ]]; then - MODIFIED_FILES=$(git diff --name-only) -else - MODIFIED_FILES=$(git diff --name-only "origin/${GITHUB_BASE_REF}") -fi - +MODIFIED_FILES=$(git diff --name-only) SRC_DIR="src/main/java/" TEST_SRC_DIR="src/test/java/" diff --git a/dev/checkstyle.xml b/dev/checkstyle.xml index 15c475ec7e..26ce0964c8 100644 --- a/dev/checkstyle.xml +++ b/dev/checkstyle.xml @@ -41,11 +41,11 @@ - - - - - + + + + + @@ -56,21 +56,21 @@ - - - - - + + + + + - - - - + + + + + @@ -134,7 +134,7 @@ - + - - - - - - - - + + + + + + + + + + - - - - + + + + + - - - - - + + + + + + + + - - + + @@ -290,13 +290,13 @@ METHOD_DEF, QUESTION, RESOURCE_SPECIFICATION, SUPER_CTOR_CALL, LAMBDA, RECORD_DEF"/> - - - - + + + + + + + - + - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + - + diff --git a/src/main/java/com/nvidia/spark/rapids/jni/Arms.java b/src/main/java/com/nvidia/spark/rapids/jni/Arms.java index 4b6ecf7204..859bc6c0ad 100644 --- a/src/main/java/com/nvidia/spark/rapids/jni/Arms.java +++ b/src/main/java/com/nvidia/spark/rapids/jni/Arms.java @@ -25,77 +25,80 @@ * This class contains utility methods for automatic resource management. */ public class Arms { - /** - * This method close the resource if an exception is thrown while executing the function. - */ - public static T closeIfException(R resource, Function function) { + /** + * This method close the resource if an exception is thrown while executing the function. + */ + public static T closeIfException(R resource, + Function function) { + try { + return function.apply(resource); + } catch (Exception e) { + if (resource != null) { try { - return function.apply(resource); - } catch (Exception e) { - if (resource != null) { - try { - resource.close(); - } catch (Exception inner) { - e.addSuppressed(inner); - } - } - throw e; + resource.close(); + } catch (Exception inner) { + e.addSuppressed(inner); } + } + throw e; } + } - /** - * This method safely closes all the resources. - *

- * This method will iterate through all the resources and closes them. If any exception happened during the - * traversal, exception will be captured and rethrown after all resources closed. - *

- */ - public static void closeAll(Iterator resources) { - Throwable t = null; - while (resources.hasNext()) { - try { - R resource = resources.next(); - if (resource != null) { - resource.close(); - } - } catch (Exception e) { - if (t == null) { - t = e; - } else { - t.addSuppressed(e); - } - } + /** + * This method safely closes all the resources. + *

+ * This method will iterate through all the resources and closes them. If any exception happened during the + * traversal, exception will be captured and rethrown after all resources closed. + *

+ */ + public static void closeAll(Iterator resources) { + Throwable t = null; + while (resources.hasNext()) { + try { + R resource = resources.next(); + if (resource != null) { + resource.close(); } + } catch (Exception e) { + if (t == null) { + t = e; + } else { + t.addSuppressed(e); + } + } + } - if (t != null) throw new RuntimeException(t); + if (t != null) { + throw new RuntimeException(t); } + } - /** - * This method safely closes all the resources. See {@link #closeAll(Iterator)} for more details. - */ - public static void closeAll(R... resources) { - closeAll(Arrays.asList(resources)); - } + /** + * This method safely closes all the resources. See {@link #closeAll(Iterator)} for more details. + */ + public static void closeAll(R... resources) { + closeAll(Arrays.asList(resources)); + } - /** - * This method safely closes the resources. See {@link #closeAll(Iterator)} for more details. - */ - public static void closeAll(Collection resources) { - closeAll(resources.iterator()); - } + /** + * This method safely closes the resources. See {@link #closeAll(Iterator)} for more details. + */ + public static void closeAll(Collection resources) { + closeAll(resources.iterator()); + } - /** - * This method safely closes the resources after applying the function. - *
- * See {@link #closeAll(Iterator)} for more details. - */ - public static , V> V withResource( - C resource, Function function) { - try { - return function.apply(resource); - } finally { - closeAll(resource); - } + /** + * This method safely closes the resources after applying the function. + *
+ * See {@link #closeAll(Iterator)} for more details. + */ + public static , V> V withResource( + C resource, Function function) { + try { + return function.apply(resource); + } finally { + closeAll(resource); } + } } diff --git a/src/main/java/com/nvidia/spark/rapids/jni/BloomFilter.java b/src/main/java/com/nvidia/spark/rapids/jni/BloomFilter.java index 46bf9a7f08..1b4d372953 100644 --- a/src/main/java/com/nvidia/spark/rapids/jni/BloomFilter.java +++ b/src/main/java/com/nvidia/spark/rapids/jni/BloomFilter.java @@ -16,16 +16,13 @@ package com.nvidia.spark.rapids.jni; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - import ai.rapids.cudf.BaseDeviceMemoryBuffer; import ai.rapids.cudf.ColumnVector; import ai.rapids.cudf.CudfAccessor; import ai.rapids.cudf.CudfException; import ai.rapids.cudf.DType; -import ai.rapids.cudf.Scalar; import ai.rapids.cudf.NativeDepsLoader; +import ai.rapids.cudf.Scalar; public class BloomFilter { static { @@ -34,16 +31,17 @@ public class BloomFilter { /** * Create a bloom filter with the specified number of hashes and bloom filter bits. - * @param numHashes The number of hashes to use when inserting values into the bloom filter or - * when probing. + * + * @param numHashes The number of hashes to use when inserting values into the bloom filter or + * when probing. * @param bloomFilterBits Size of the bloom filter in bits. * @return a Scalar object which encapsulates the bloom filter. */ - public static Scalar create(int numHashes, long bloomFilterBits){ - if(numHashes <= 0){ + public static Scalar create(int numHashes, long bloomFilterBits) { + if (numHashes <= 0) { throw new IllegalArgumentException("Bloom filters must have a positive hash count"); } - if(bloomFilterBits <= 0){ + if (bloomFilterBits <= 0) { throw new IllegalArgumentException("Bloom filters must have a positive number of bits"); } return CudfAccessor.scalarFromHandle(DType.LIST, creategpu(numHashes, bloomFilterBits)); @@ -51,54 +49,64 @@ public static Scalar create(int numHashes, long bloomFilterBits){ /** * Insert a column of longs into a bloom filter. + * * @param bloomFilter The bloom filter to which values will be inserted. - * @param cv The column containing the values to add. + * @param cv The column containing the values to add. */ - public static void put(Scalar bloomFilter, ColumnVector cv){ + public static void put(Scalar bloomFilter, ColumnVector cv) { put(CudfAccessor.getScalarHandle(bloomFilter), cv.getNativeView()); } /** * Merge one or more bloom filters into a new bloom filter. - * @param bloomFilters A ColumnVector containing a bloom filter per row. + * + * @param bloomFilters A ColumnVector containing a bloom filter per row. * @return A new bloom filter containing the merged inputs. */ - public static Scalar merge(ColumnVector bloomFilters){ + public static Scalar merge(ColumnVector bloomFilters) { return CudfAccessor.scalarFromHandle(DType.LIST, merge(bloomFilters.getNativeView())); } /** - * Probe a bloom filter with a column of longs. Returns a column of booleans. For + * Probe a bloom filter with a column of longs. Returns a column of booleans. For * each row in the output; a value of true indicates that the corresponding input value * -may- be in the set of values used to build the bloom filter; a value of false indicates * that the corresponding input value is conclusively not in the set of values used to build - * the bloom filter. + * the bloom filter. + * * @param bloomFilter The bloom filter to be probed. - * @param cv The column containing the values to check. + * @param cv The column containing the values to check. * @return A boolean column indicating the results of the probe. */ - public static ColumnVector probe(Scalar bloomFilter, ColumnVector cv){ + public static ColumnVector probe(Scalar bloomFilter, ColumnVector cv) { return new ColumnVector(probe(CudfAccessor.getScalarHandle(bloomFilter), cv.getNativeView())); } /** - * Probe a bloom filter with a column of longs. Returns a column of booleans. For + * Probe a bloom filter with a column of longs. Returns a column of booleans. For * each row in the output; a value of true indicates that the corresponding input value * -may- be in the set of values used to build the bloom filter; a value of false indicates * that the corresponding input value is conclusively not in the set of values used to build - * the bloom filter. - * @param bloomFilter The bloom filter to be probed. This buffer is expected to be the - * fully packed Spark bloom filter, including header. - * @param cv The column containing the values to check. + * the bloom filter. + * + * @param bloomFilter The bloom filter to be probed. This buffer is expected to be the + * fully packed Spark bloom filter, including header. + * @param cv The column containing the values to check. * @return A boolean column indicating the results of the probe. */ - public static ColumnVector probe(BaseDeviceMemoryBuffer bloomFilter, ColumnVector cv){ - return new ColumnVector(probebuffer(bloomFilter.getAddress(), bloomFilter.getLength(), cv.getNativeView())); + public static ColumnVector probe(BaseDeviceMemoryBuffer bloomFilter, ColumnVector cv) { + return new ColumnVector( + probebuffer(bloomFilter.getAddress(), bloomFilter.getLength(), cv.getNativeView())); } - + private static native long creategpu(int numHashes, long bloomFilterBits) throws CudfException; + private static native int put(long bloomFilter, long cv) throws CudfException; + private static native long merge(long bloomFilters) throws CudfException; - private static native long probe(long bloomFilter, long cv) throws CudfException; - private static native long probebuffer(long bloomFilter, long bloomFilterSize, long cv) throws CudfException; + + private static native long probe(long bloomFilter, long cv) throws CudfException; + + private static native long probebuffer(long bloomFilter, long bloomFilterSize, long cv) + throws CudfException; } diff --git a/src/main/java/com/nvidia/spark/rapids/jni/CaseWhen.java b/src/main/java/com/nvidia/spark/rapids/jni/CaseWhen.java index dabefd0a98..130c23f5e8 100644 --- a/src/main/java/com/nvidia/spark/rapids/jni/CaseWhen.java +++ b/src/main/java/com/nvidia/spark/rapids/jni/CaseWhen.java @@ -16,59 +16,57 @@ package com.nvidia.spark.rapids.jni; -import ai.rapids.cudf.*; +import ai.rapids.cudf.ColumnVector; +import ai.rapids.cudf.DType; /** * Exedute SQL `case when` semantic. * If there are multiple branches and each branch uses scalar to generator value, * then it's fast to use this class because it does not generate temp string columns. - * + *

* E.g.: - * SQL is: - * select - * case - * when bool_1_expr then "value_1" - * when bool_2_expr then "value_2" - * when bool_3_expr then "value_3" - * else "value_else" - * end - * from tab - * + * SQL is: + * select + * case + * when bool_1_expr then "value_1" + * when bool_2_expr then "value_2" + * when bool_3_expr then "value_3" + * else "value_else" + * end + * from tab + *

* Execution steps: - * Execute bool exprs to get bool columns, e.g., gets: - * bool column 1: [true, false, false, false] // bool_1_expr result - * bool column 2: [false, true, false, flase] // bool_2_expr result - * bool column 3: [false, false, true, flase] // bool_3_expr result - * Execute `selectFirstTrueIndex` to get the column index for the first true in bool columns. - * Generate a column to store salars: "value_1", "value_2", "value_3", "value_else" - * Execute `Table.gather` to generate the final output column - * + * Execute bool exprs to get bool columns, e.g., gets: + * bool column 1: [true, false, false, false] // bool_1_expr result + * bool column 2: [false, true, false, flase] // bool_2_expr result + * bool column 3: [false, false, true, flase] // bool_3_expr result + * Execute `selectFirstTrueIndex` to get the column index for the first true in bool columns. + * Generate a column to store salars: "value_1", "value_2", "value_3", "value_else" + * Execute `Table.gather` to generate the final output column */ public class CaseWhen { /** - * * Select the column index for the first true in bool columns. * For the row does not contain true, use end index(number of columns). - * + *

* e.g.: - * column 0: true, false, false, false - * column 1: false, true, false, false - * column 2: false, false, true, false - * - * 1st row is: true, flase, false; first true index is 0 - * 2nd row is: false, true, false; first true index is 1 - * 3rd row is: false, flase, true; first true index is 2 - * 4th row is: false, false, false; do not find true, set index to the end index 3 - * - * output column: 0, 1, 2, 3 - * In the `case when` context, here 3 index means using NULL value. - * - */ + * column 0: true, false, false, false + * column 1: false, true, false, false + * column 2: false, false, true, false + *

+ * 1st row is: true, flase, false; first true index is 0 + * 2nd row is: false, true, false; first true index is 1 + * 3rd row is: false, flase, true; first true index is 2 + * 4th row is: false, false, false; do not find true, set index to the end index 3 + *

+ * output column: 0, 1, 2, 3 + * In the `case when` context, here 3 index means using NULL value. + */ public static ColumnVector selectFirstTrueIndex(ColumnVector[] boolColumns) { for (ColumnVector cv : boolColumns) { - assert(cv.getType().equals(DType.BOOL8)) : "Columns must be bools"; + assert (cv.getType().equals(DType.BOOL8)) : "Columns must be bools"; } long[] boolHandles = new long[boolColumns.length]; diff --git a/src/main/java/com/nvidia/spark/rapids/jni/CastException.java b/src/main/java/com/nvidia/spark/rapids/jni/CastException.java index ca6b5a9d56..7000cfb74f 100644 --- a/src/main/java/com/nvidia/spark/rapids/jni/CastException.java +++ b/src/main/java/com/nvidia/spark/rapids/jni/CastException.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package com.nvidia.spark.rapids.jni; /** @@ -21,9 +22,9 @@ public class CastException extends RuntimeException { private final int rowWithError; private final String stringWithError; - + CastException(String stringWithError, int rowWithError) { - super("Error casting data on row " + String.valueOf(rowWithError) + ": " + stringWithError); + super("Error casting data on row " + rowWithError + ": " + stringWithError); this.rowWithError = rowWithError; this.stringWithError = stringWithError; diff --git a/src/main/java/com/nvidia/spark/rapids/jni/CastStrings.java b/src/main/java/com/nvidia/spark/rapids/jni/CastStrings.java index 2b2267f034..c08159b17a 100644 --- a/src/main/java/com/nvidia/spark/rapids/jni/CastStrings.java +++ b/src/main/java/com/nvidia/spark/rapids/jni/CastStrings.java @@ -16,9 +16,14 @@ package com.nvidia.spark.rapids.jni; -import ai.rapids.cudf.*; +import ai.rapids.cudf.ColumnVector; +import ai.rapids.cudf.ColumnView; +import ai.rapids.cudf.DType; +import ai.rapids.cudf.NativeDepsLoader; -/** Utility class for casting between string columns and native type columns */ +/** + * Utility class for casting between string columns and native type columns + */ public class CastStrings { static { NativeDepsLoader.loadNativeDeps(); @@ -28,9 +33,9 @@ public class CastStrings { * Convert a string column to an integer column of a specified type stripping away leading and * trailing spaces. * - * @param cv the column data to process. + * @param cv the column data to process. * @param ansiMode true if invalid data are errors, false if they should be nulls. - * @param type the type of the return column. + * @param type the type of the return column. * @return the converted column. */ public static ColumnVector toInteger(ColumnView cv, boolean ansiMode, DType type) { @@ -40,10 +45,10 @@ public static ColumnVector toInteger(ColumnView cv, boolean ansiMode, DType type /** * Convert a string column to an integer column of a specified type. * - * @param cv the column data to process. + * @param cv the column data to process. * @param ansiMode true if invalid data are errors, false if they should be nulls. - * @param strip true if leading and trailing spaces should be ignored when parsing. - * @param type the type of the return column. + * @param strip true if leading and trailing spaces should be ignored when parsing. + * @param type the type of the return column. * @return the converted column. */ public static ColumnVector toInteger(ColumnView cv, boolean ansiMode, boolean strip, DType type) { @@ -55,10 +60,10 @@ public static ColumnVector toInteger(ColumnView cv, boolean ansiMode, boolean st * Convert a string column to an integer column of a specified type stripping away leading and * trailing whitespace. * - * @param cv the column data to process. - * @param ansiMode true if invalid data are errors, false if they should be nulls. + * @param cv the column data to process. + * @param ansiMode true if invalid data are errors, false if they should be nulls. * @param precision the output precision. - * @param scale the output scale. + * @param scale the output scale. * @return the converted column. */ public static ColumnVector toDecimal(ColumnView cv, boolean ansiMode, int precision, int scale) { @@ -68,22 +73,22 @@ public static ColumnVector toDecimal(ColumnView cv, boolean ansiMode, int precis /** * Convert a string column to an integer column of a specified type. * - * @param cv the column data to process. - * @param ansiMode true if invalid data are errors, false if they should be nulls. - * @param strip true if leading and trailing white space should be stripped. + * @param cv the column data to process. + * @param ansiMode true if invalid data are errors, false if they should be nulls. + * @param strip true if leading and trailing white space should be stripped. * @param precision the output precision. - * @param scale the output scale. + * @param scale the output scale. * @return the converted column. */ public static ColumnVector toDecimal(ColumnView cv, boolean ansiMode, boolean strip, - int precision, int scale) { + int precision, int scale) { return new ColumnVector(toDecimal(cv.getNativeView(), ansiMode, strip, precision, scale)); } /** * Convert a float column to a formatted string column. * - * @param cv the column data to process + * @param cv the column data to process * @param digits the number of digits to display after the decimal point * @return the converted column */ @@ -114,9 +119,9 @@ public static ColumnVector fromDecimal(ColumnView cv) { /** * Convert a string column to a given floating-point type column. * - * @param cv the column data to process. + * @param cv the column data to process. * @param ansiMode true if invalid data are errors, false if they should be nulls. - * @param type the type of the return column. + * @param type the type of the return column. * @return the converted column. */ public static ColumnVector toFloat(ColumnView cv, boolean ansiMode, DType type) { @@ -125,18 +130,18 @@ public static ColumnVector toFloat(ColumnView cv, boolean ansiMode, DType type) public static ColumnVector toIntegersWithBase(ColumnView cv, int base, - boolean ansiEnabled, DType type) { + boolean ansiEnabled, DType type) { return new ColumnVector(toIntegersWithBase(cv.getNativeView(), base, ansiEnabled, - type.getTypeId().getNativeId())); + type.getTypeId().getNativeId())); } /** * Converts an integer column to a string column by converting the underlying integers to the * specified base. - * + *

* Note: Right now we only support base 10 and 16. The hexadecimal values will be * returned without leading zeros or padding at the end - * + *

* Example: * input = [123, -1, 0, 27, 342718233] * s = fromIntegersWithBase(input, 16) @@ -144,7 +149,7 @@ public static ColumnVector toIntegersWithBase(ColumnView cv, int base, * s = fromIntegersWithBase(input, 10) * s is ['123', '-1', '0', '27', '342718233'] * - * @param cv The input integer column to be converted. + * @param cv The input integer column to be converted. * @param base base that we want to convert to (currently only 10/16) * @return a new String ColumnVector */ @@ -153,14 +158,21 @@ public static ColumnVector fromIntegersWithBase(ColumnView cv, int base) { } private static native long toInteger(long nativeColumnView, boolean ansi_enabled, boolean strip, - int dtype); + int dtype); + private static native long toDecimal(long nativeColumnView, boolean ansi_enabled, boolean strip, - int precision, int scale); + int precision, int scale); + private static native long toFloat(long nativeColumnView, boolean ansi_enabled, int dtype); + private static native long fromDecimal(long nativeColumnView); + private static native long fromFloatWithFormat(long nativeColumnView, int digits); + private static native long fromFloat(long nativeColumnView); + private static native long toIntegersWithBase(long nativeColumnView, int base, - boolean ansiEnabled, int dtype); + boolean ansiEnabled, int dtype); + private static native long fromIntegersWithBase(long nativeColumnView, int base); } \ No newline at end of file diff --git a/src/main/java/com/nvidia/spark/rapids/jni/DateTimeRebase.java b/src/main/java/com/nvidia/spark/rapids/jni/DateTimeRebase.java index d73ee038d6..724fbc573c 100644 --- a/src/main/java/com/nvidia/spark/rapids/jni/DateTimeRebase.java +++ b/src/main/java/com/nvidia/spark/rapids/jni/DateTimeRebase.java @@ -16,7 +16,9 @@ package com.nvidia.spark.rapids.jni; -import ai.rapids.cudf.*; +import ai.rapids.cudf.ColumnVector; +import ai.rapids.cudf.ColumnView; +import ai.rapids.cudf.NativeDepsLoader; /** * Utility class for converting between column major and row major data diff --git a/src/main/java/com/nvidia/spark/rapids/jni/DecimalUtils.java b/src/main/java/com/nvidia/spark/rapids/jni/DecimalUtils.java index 167cb099b1..9b25bff223 100644 --- a/src/main/java/com/nvidia/spark/rapids/jni/DecimalUtils.java +++ b/src/main/java/com/nvidia/spark/rapids/jni/DecimalUtils.java @@ -31,7 +31,7 @@ public class DecimalUtils { * Multiply two DECIMAL128 columns together into a DECIMAL128 product rounded to the specified * scale with overflow detection. This method considers a precision greater than 38 as overflow * even if the number still fits in a 128-bit representation. - * + *

* WARNING: This method has a bug which we match with Spark versions before 3.4.2, * 4.0.0, 3.5.1. Consider the following example using Decimal with a precision of 38 and scale of 10: * -8533444864753048107770677711.1312637916 * -12.0000000000 = 102401338377036577293248132533.575166 @@ -53,23 +53,24 @@ public static Table multiply128(ColumnView a, ColumnView b, int productScale) { * Multiply two DECIMAL128 columns together into a DECIMAL128 product rounded to the specified * scale with overflow detection. This method considers a precision greater than 38 as overflow * even if the number still fits in a 128-bit representation. - * + *

* WARNING: With interimCast set to true, this method has a bug which we match with Spark versions before 3.4.2, * 4.0.0, 3.5.1. Consider the following example using Decimal with a precision of 38 and scale of 10: * -8533444864753048107770677711.1312637916 * -12.0000000000 = 102401338377036577293248132533.575166 * while the actual answer based on Java BigDecimal is 102401338377036577293248132533.575165 * - * @param a factor input, must match row count of the other factor input - * @param b factor input, must match row count of the other factor input + * @param a factor input, must match row count of the other factor input + * @param b factor input, must match row count of the other factor input * @param productScale scale to use for the product type - * @param interimCast whether to cast the result of the division to 38 precision before casting it again to the final - * precision + * @param interimCast whether to cast the result of the division to 38 precision before casting it again to the final + * precision * @return table containing a boolean column and a DECIMAL128 product column of the specified - * scale. The boolean value will be true if an overflow was detected for that row's - * DECIMAL128 product value. A null input row will result in a corresponding null output - * row. + * scale. The boolean value will be true if an overflow was detected for that row's + * DECIMAL128 product value. A null input row will result in a corresponding null output + * row. */ - public static Table multiply128(ColumnView a, ColumnView b, int productScale, boolean interimCast) { + public static Table multiply128(ColumnView a, ColumnView b, int productScale, + boolean interimCast) { return new Table(multiply128(a.getNativeView(), b.getNativeView(), productScale, interimCast)); } @@ -77,13 +78,14 @@ public static Table multiply128(ColumnView a, ColumnView b, int productScale, bo * Divide two DECIMAL128 columns and produce a DECIMAL128 quotient rounded to the specified * scale with overflow detection. This method considers a precision greater than 38 as overflow * even if the number still fits in a 128-bit representation. - * @param a factor input, must match row count of the other factor input - * @param b factor input, must match row count of the other factor input + * + * @param a factor input, must match row count of the other factor input + * @param b factor input, must match row count of the other factor input * @param quotientScale scale to use for the quotient type * @return table containing a boolean column and a DECIMAL128 quotient column of the specified - * scale. The boolean value will be true if an overflow was detected for that row's - * DECIMAL128 quotient value. A null input row will result in a corresponding null output - * row. + * scale. The boolean value will be true if an overflow was detected for that row's + * DECIMAL128 quotient value. A null input row will result in a corresponding null output + * row. */ public static Table divide128(ColumnView a, ColumnView b, int quotientScale) { return new Table(divide128(a.getNativeView(), b.getNativeView(), quotientScale, false)); @@ -103,9 +105,9 @@ public static Table divide128(ColumnView a, ColumnView b, int quotientScale) { * @param a factor input, must match row count of the other factor input * @param b factor input, must match row count of the other factor input * @return table containing a boolean column and a INT64 quotient column. - * The boolean value will be true if an overflow was detected for that row's - * INT64 quotient value. A null input row will result in a corresponding null output - * row. + * The boolean value will be true if an overflow was detected for that row's + * INT64 quotient value. A null input row will result in a corresponding null output + * row. */ public static Table integerDivide128(ColumnView a, ColumnView b) { return new Table(divide128(a.getNativeView(), b.getNativeView(), 0, true)); @@ -115,17 +117,17 @@ public static Table integerDivide128(ColumnView a, ColumnView b) { * Divide two DECIMAL128 columns and produce a DECIMAL128 remainder with overflow detection. * Example: * 451635271134476686911387864.48 % -961.110 = 775.233 - * + *

* Generally, this will never really overflow unless in the divide by zero case. * But it will detect an overflow in any case. * - * @param a factor input, must match row count of the other factor input - * @param b factor input, must match row count of the other factor input + * @param a factor input, must match row count of the other factor input + * @param b factor input, must match row count of the other factor input * @param remainderScale scale to use for the remainder type * @return table containing a boolean column and a DECIMAL128 remainder column. - * The boolean value will be true if an overflow was detected for that row's - * DECIMAL128 remainder value. A null input row will result in a corresponding null - * output row. + * The boolean value will be true if an overflow was detected for that row's + * DECIMAL128 remainder value. A null input row will result in a corresponding null + * output row. */ public static Table remainder128(ColumnView a, ColumnView b, int remainderScale) { return new Table(remainder128(a.getNativeView(), b.getNativeView(), remainderScale)); @@ -135,17 +137,17 @@ public static Table remainder128(ColumnView a, ColumnView b, int remainderScale) * Subtract two DECIMAL128 columns and produce a DECIMAL128 result rounded to the specified * scale with overflow detection. This method considers a precision greater than 38 as overflow * even if the number still fits in a 128-bit representation. - * + *

* NOTE: This is very specific to Spark 3.4. This method is incompatible with previous versions * of Spark. We don't need this for versions prior to Spark 3.4 * - * @param a input, must match row count of the other input - * @param b input, must match row count of the other input + * @param a input, must match row count of the other input + * @param b input, must match row count of the other input * @param targetScale scale to use for the result * @return table containing a boolean column and a DECIMAL128 result column of the specified - * scale. The boolean value will be true if an overflow was detected for that row's - * DECIMAL128 result value. A null input row will result in a corresponding null output - * row. + * scale. The boolean value will be true if an overflow was detected for that row's + * DECIMAL128 result value. A null input row will result in a corresponding null output + * row. */ public static Table subtract128(ColumnView a, ColumnView b, int targetScale) { @@ -160,17 +162,17 @@ public static Table subtract128(ColumnView a, ColumnView b, int targetScale) { * Add two DECIMAL128 columns and produce a DECIMAL128 result rounded to the specified * scale with overflow detection. This method considers a precision greater than 38 as overflow * even if the number still fits in a 128-bit representation. - * + *

* NOTE: This is very specific to Spark 3.4. This method is incompatible with previous versions * of Spark. We don't need this for versions prior to Spark 3.4 * - * @param a input, must match row count of the other input - * @param b input, must match row count of the other input + * @param a input, must match row count of the other input + * @param b input, must match row count of the other input * @param targetScale scale to use for the sum * @return table containing a boolean column and a DECIMAL128 sum column of the specified - * scale. The boolean value will be true if an overflow was detected for that row's - * DECIMAL128 result value. A null input row will result in a corresponding null output - * row. + * scale. The boolean value will be true if an overflow was detected for that row's + * DECIMAL128 result value. A null input row will result in a corresponding null output + * row. */ public static Table add128(ColumnView a, ColumnView b, int targetScale) { if (java.lang.Math.abs(a.getType().getScale() - b.getType().getScale()) > 77) { @@ -180,29 +182,13 @@ public static Table add128(ColumnView a, ColumnView b, int targetScale) { return new Table(add128(a.getNativeView(), b.getNativeView(), targetScale)); } - /** - * A class to store the result of a cast operation from floating point values to decimals. - *

- * Since the result column may or may not be used regardless of the value of hasFailure, we - * need to keep it and let the caller to decide. - */ - public static class CastFloatToDecimalResult { - public final ColumnVector result; // the cast result - public final boolean hasFailure; // whether the cast operation has failed for any input rows - - public CastFloatToDecimalResult(ColumnVector result, boolean hasFailure) { - this.result = result; - this.hasFailure = hasFailure; - } - } - /** * Cast floating point values to decimals, matching the behavior of Spark. * - * @param input The input column, which is either FLOAT32 or FLOAT64 + * @param input The input column, which is either FLOAT32 or FLOAT64 * @param outputType The output decimal type * @return The decimal column resulting from the cast operation and a boolean value indicating - * whether the cast operation has failed for any input rows + * whether the cast operation has failed for any input rows */ public static CastFloatToDecimalResult floatingPointToDecimal(ColumnView input, DType outputType, int precision) { @@ -212,9 +198,11 @@ public static CastFloatToDecimalResult floatingPointToDecimal(ColumnView input, return new CastFloatToDecimalResult(new ColumnVector(result[0]), result[1] != 0); } - private static native long[] multiply128(long viewA, long viewB, int productScale, boolean interimCast); + private static native long[] multiply128(long viewA, long viewB, int productScale, + boolean interimCast); - private static native long[] divide128(long viewA, long viewB, int quotientScale, boolean isIntegerDivide); + private static native long[] divide128(long viewA, long viewB, int quotientScale, + boolean isIntegerDivide); private static native long[] remainder128(long viewA, long viewB, int remainderScale); @@ -222,5 +210,22 @@ public static CastFloatToDecimalResult floatingPointToDecimal(ColumnView input, private static native long[] subtract128(long viewA, long viewB, int targetScale); - private static native long[] floatingPointToDecimal(long inputHandle, int outputTypeId, int precision, int scale); + private static native long[] floatingPointToDecimal(long inputHandle, int outputTypeId, + int precision, int scale); + + /** + * A class to store the result of a cast operation from floating point values to decimals. + *

+ * Since the result column may or may not be used regardless of the value of hasFailure, we + * need to keep it and let the caller to decide. + */ + public static class CastFloatToDecimalResult { + public final ColumnVector result; // the cast result + public final boolean hasFailure; // whether the cast operation has failed for any input rows + + public CastFloatToDecimalResult(ColumnVector result, boolean hasFailure) { + this.result = result; + this.hasFailure = hasFailure; + } + } } diff --git a/src/main/java/com/nvidia/spark/rapids/jni/GpuSubstringIndexUtils.java b/src/main/java/com/nvidia/spark/rapids/jni/GpuSubstringIndexUtils.java index a8750919c9..4afaeedf06 100644 --- a/src/main/java/com/nvidia/spark/rapids/jni/GpuSubstringIndexUtils.java +++ b/src/main/java/com/nvidia/spark/rapids/jni/GpuSubstringIndexUtils.java @@ -16,16 +16,23 @@ package com.nvidia.spark.rapids.jni; -import ai.rapids.cudf.*; +import ai.rapids.cudf.ColumnVector; +import ai.rapids.cudf.ColumnView; +import ai.rapids.cudf.CudfAccessor; +import ai.rapids.cudf.CudfException; +import ai.rapids.cudf.NativeDepsLoader; +import ai.rapids.cudf.Scalar; public class GpuSubstringIndexUtils { - static{ - NativeDepsLoader.loadNativeDeps(); - } + static { + NativeDepsLoader.loadNativeDeps(); + } - public static ColumnVector substringIndex(ColumnView cv, Scalar delimiter, int count){ - return new ColumnVector(substringIndex(cv.getNativeView(), CudfAccessor.getScalarHandle(delimiter), count)); - } + public static ColumnVector substringIndex(ColumnView cv, Scalar delimiter, int count) { + return new ColumnVector( + substringIndex(cv.getNativeView(), CudfAccessor.getScalarHandle(delimiter), count)); + } - private static native long substringIndex(long columnView, long delimiter, int count) throws CudfException; + private static native long substringIndex(long columnView, long delimiter, int count) + throws CudfException; } diff --git a/src/main/java/com/nvidia/spark/rapids/jni/GpuTimeZoneDB.java b/src/main/java/com/nvidia/spark/rapids/jni/GpuTimeZoneDB.java index a8048b1e8b..3ad54b996d 100644 --- a/src/main/java/com/nvidia/spark/rapids/jni/GpuTimeZoneDB.java +++ b/src/main/java/com/nvidia/spark/rapids/jni/GpuTimeZoneDB.java @@ -1,18 +1,18 @@ /* -* Copyright (c) 2023-2024, NVIDIA CORPORATION. -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ + * Copyright (c) 2023-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package com.nvidia.spark.rapids.jni; @@ -20,9 +20,6 @@ import ai.rapids.cudf.DType; import ai.rapids.cudf.HostColumnVector; import ai.rapids.cudf.Table; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - import java.time.Instant; import java.time.ZoneId; import java.time.zone.ZoneOffsetTransition; @@ -34,17 +31,19 @@ import java.util.Map; import java.util.TimeZone; import java.util.concurrent.Executors; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** * Gpu time zone utility. * Provides two kinds of APIs - * - Time zone transitions cache APIs - * `cacheDatabaseAsync`, `cacheDatabase` and `shutdown` are synchronized. - * When cacheDatabaseAsync is running, the `shutdown` and `cacheDatabase` will wait; - * These APIs guarantee only one thread is loading transitions cache, - * And guarantee loading cache only occurs one time. - * - Rebase time zone APIs - * fromTimestampToUtcTimestamp, fromUtcTimestampToTimestamp ... + * - Time zone transitions cache APIs + * `cacheDatabaseAsync`, `cacheDatabase` and `shutdown` are synchronized. + * When cacheDatabaseAsync is running, the `shutdown` and `cacheDatabase` will wait; + * These APIs guarantee only one thread is loading transitions cache, + * And guarantee loading cache only occurs one time. + * - Rebase time zone APIs + * fromTimestampToUtcTimestamp, fromUtcTimestampToTimestamp ... */ public class GpuTimeZoneDB { private static final Logger log = LoggerFactory.getLogger(GpuTimeZoneDB.class); @@ -104,7 +103,7 @@ private static synchronized void cacheDatabaseImpl() { } } - private static synchronized void closeResources() { + private static synchronized void closeResources() { if (zoneIdToTable != null) { zoneIdToTable.clear(); zoneIdToTable = null; @@ -115,7 +114,8 @@ private static synchronized void closeResources() { } } - public static ColumnVector fromTimestampToUtcTimestamp(ColumnVector input, ZoneId currentTimeZone) { + public static ColumnVector fromTimestampToUtcTimestamp(ColumnVector input, + ZoneId currentTimeZone) { // TODO: Remove this check when all timezones are supported // (See https://github.com/NVIDIA/spark-rapids/issues/6840) if (!isSupportedTimeZone(currentTimeZone)) { @@ -132,8 +132,9 @@ public static ColumnVector fromTimestampToUtcTimestamp(ColumnVector input, ZoneI transitions.getNativeView(), tzIndex)); } } - - public static ColumnVector fromUtcTimestampToTimestamp(ColumnVector input, ZoneId desiredTimeZone) { + + public static ColumnVector fromUtcTimestampToTimestamp(ColumnVector input, + ZoneId desiredTimeZone) { // TODO: Remove this check when all timezones are supported // (See https://github.com/NVIDIA/spark-rapids/issues/6840) if (!isSupportedTimeZone(desiredTimeZone)) { @@ -150,13 +151,13 @@ public static ColumnVector fromUtcTimestampToTimestamp(ColumnVector input, ZoneI transitions.getNativeView(), tzIndex)); } } - + // TODO: Deprecate this API when we support all timezones // (See https://github.com/NVIDIA/spark-rapids/issues/6840) public static boolean isSupportedTimeZone(ZoneId desiredTimeZone) { return desiredTimeZone != null && - (desiredTimeZone.getRules().isFixedOffset() || - desiredTimeZone.getRules().getTransitionRules().isEmpty()); + (desiredTimeZone.getRules().isFixedOffset() || + desiredTimeZone.getRules().getTransitionRules().isEmpty()); } public static boolean isSupportedTimeZone(String zoneId) { @@ -170,10 +171,10 @@ public static boolean isSupportedTimeZone(String zoneId) { // Ported from Spark. Used to format time zone ID string with (+|-)h:mm and (+|-)hh:m public static ZoneId getZoneId(String timeZoneId) { String formattedZoneId = timeZoneId - // To support the (+|-)h:mm format because it was supported before Spark 3.0. - .replaceFirst("(\\+|\\-)(\\d):", "$10$2:") - // To support the (+|-)hh:m format because it was supported before Spark 3.0. - .replaceFirst("(\\+|\\-)(\\d\\d):(\\d)$", "$1$2:0$3"); + // To support the (+|-)h:mm format because it was supported before Spark 3.0. + .replaceFirst("(\\+|\\-)(\\d):", "$10$2:") + // To support the (+|-)hh:m format because it was supported before Spark 3.0. + .replaceFirst("(\\+|\\-)(\\d\\d):(\\d)$", "$1$2:0$3"); return ZoneId.of(formattedZoneId, ZoneId.SHORT_IDS); } @@ -266,12 +267,13 @@ private static synchronized ColumnVector getFixedTransitions() { /** * FOR TESTING PURPOSES ONLY, DO NOT USE IN PRODUCTION - * - * This method retrieves the raw list of struct data that forms the list of - * fixed transitions for a particular zoneId. - * + *

+ * This method retrieves the raw list of struct data that forms the list of + * fixed transitions for a particular zoneId. + *

* It has default visibility so the test can access it. - * @param zoneId + * + * @param zoneId Zone id * @return list of fixed transitions */ static synchronized List getHostFixedTransitions(String zoneId) { @@ -285,5 +287,6 @@ static synchronized List getHostFixedTransitions(String zoneId) { private static native long convertTimestampColumnToUTC(long input, long transitions, int tzIndex); - private static native long convertUTCTimestampColumnToTimeZone(long input, long transitions, int tzIndex); + private static native long convertUTCTimestampColumnToTimeZone(long input, long transitions, + int tzIndex); } diff --git a/src/main/java/com/nvidia/spark/rapids/jni/Hash.java b/src/main/java/com/nvidia/spark/rapids/jni/Hash.java index a25fead0fd..f7d4223697 100644 --- a/src/main/java/com/nvidia/spark/rapids/jni/Hash.java +++ b/src/main/java/com/nvidia/spark/rapids/jni/Hash.java @@ -33,68 +33,71 @@ public class Hash { * Create a new vector containing spark's 32-bit murmur3 hash of each row in the table. * Spark's murmur3 hash uses a different tail processing algorithm. * - * @param seed integer seed for the murmur3 hash function + * @param seed integer seed for the murmur3 hash function * @param columns array of columns to hash, must have identical number of rows. * @return the new ColumnVector of 32-bit values representing each row's hash value. */ - public static ColumnVector murmurHash32(int seed, ColumnView columns[]) { + public static ColumnVector murmurHash32(int seed, ColumnView[] columns) { if (columns.length < 1) { throw new IllegalArgumentException("Murmur3 hashing requires at least 1 column of input"); } long[] columnViews = new long[columns.length]; long size = columns[0].getRowCount(); - for(int i = 0; i < columns.length; i++) { + for (int i = 0; i < columns.length; i++) { assert columns[i] != null : "Column vectors passed may not be null"; - assert columns[i].getRowCount() == size : "Row count mismatch, all columns must be the same size"; + assert columns[i].getRowCount() == size : + "Row count mismatch, all columns must be the same size"; assert !columns[i].getType().isDurationType() : "Unsupported column type Duration"; - columnViews[i] = columns[i].getNativeView(); + columnViews[i] = columns[i].getNativeView(); } return new ColumnVector(murmurHash32(seed, columnViews)); } - public static ColumnVector murmurHash32(ColumnView columns[]) { + public static ColumnVector murmurHash32(ColumnView[] columns) { return murmurHash32(0, columns); } /** * Create a new vector containing the xxhash64 hash of each row in the table. * - * @param seed integer seed for the xxhash64 hash function + * @param seed integer seed for the xxhash64 hash function * @param columns array of columns to hash, must have identical number of rows. * @return the new ColumnVector of 64-bit values representing each row's hash value. */ - public static ColumnVector xxhash64(long seed, ColumnView columns[]) { + public static ColumnVector xxhash64(long seed, ColumnView[] columns) { if (columns.length < 1) { throw new IllegalArgumentException("xxhash64 hashing requires at least 1 column of input"); } long[] columnViews = new long[columns.length]; long size = columns[0].getRowCount(); - for(int i = 0; i < columns.length; i++) { + for (int i = 0; i < columns.length; i++) { assert columns[i] != null : "Column vectors passed may not be null"; - assert columns[i].getRowCount() == size : "Row count mismatch, all columns must be the same size"; + assert columns[i].getRowCount() == size : + "Row count mismatch, all columns must be the same size"; assert !columns[i].getType().isDurationType() : "Unsupported column type Duration"; assert !columns[i].getType().isNestedType() : "Unsupported column type Nested"; - columnViews[i] = columns[i].getNativeView(); + columnViews[i] = columns[i].getNativeView(); } return new ColumnVector(xxhash64(seed, columnViews)); } - public static ColumnVector xxhash64(ColumnView columns[]) { + public static ColumnVector xxhash64(ColumnView[] columns) { return xxhash64(DEFAULT_XXHASH64_SEED, columns); } - public static ColumnVector hiveHash(ColumnView columns[]) { + public static ColumnVector hiveHash(ColumnView[] columns) { if (columns.length < 1) { throw new IllegalArgumentException("Hive hashing requires at least 1 column of input"); } long[] columnViews = new long[columns.length]; long size = columns[0].getRowCount(); - for(int i = 0; i < columns.length; i++) { + for (int i = 0; i < columns.length; i++) { assert columns[i] != null : "Column vectors passed may not be null"; - assert columns[i].getRowCount() == size : "Row count mismatch, all columns must be the same size"; + assert columns[i].getRowCount() == size : + "Row count mismatch, all columns must be the same size"; assert !columns[i].getType().isDurationType() : "Unsupported column type Duration"; assert !columns[i].getType().isNestedType() : "Unsupported column type Nested"; columnViews[i] = columns[i].getNativeView(); @@ -103,7 +106,7 @@ public static ColumnVector hiveHash(ColumnView columns[]) { } private static native long murmurHash32(int seed, long[] viewHandles) throws CudfException; - + private static native long xxhash64(long seed, long[] viewHandles) throws CudfException; private static native long hiveHash(long[] viewHandles) throws CudfException; diff --git a/src/main/java/com/nvidia/spark/rapids/jni/HostTable.java b/src/main/java/com/nvidia/spark/rapids/jni/HostTable.java index 754412d727..5e268a70cf 100644 --- a/src/main/java/com/nvidia/spark/rapids/jni/HostTable.java +++ b/src/main/java/com/nvidia/spark/rapids/jni/HostTable.java @@ -36,10 +36,16 @@ public class HostTable implements AutoCloseable { private long nativeTableView; private HostMemoryBuffer hostBuffer; + private HostTable(long tableHandle, HostMemoryBuffer hostBuffer) { + this.nativeTableView = tableHandle; + this.hostBuffer = hostBuffer; + } + /** * Copies a device table to a host table asynchronously. * NOTE: The caller must synchronize on the stream before examining the data on the host. - * @param table device table to copy + * + * @param table device table to copy * @param stream stream to use for the copy * @return host table */ @@ -63,7 +69,8 @@ public static HostTable fromTableAsync(Table table, Cuda.Stream stream) { /** * Copies a device table to a host table synchronously. - * @param table device table to copy + * + * @param table device table to copy * @param stream stream to use for the copy * @return host table */ @@ -75,6 +82,7 @@ public static HostTable fromTable(Table table, Cuda.Stream stream) { /** * Copies a device table to a host table synchronously on the default stream. + * * @param table device table to copy * @return host table */ @@ -82,10 +90,16 @@ public static HostTable fromTable(Table table) { return fromTable(table, Cuda.DEFAULT_STREAM); } - private HostTable(long tableHandle, HostMemoryBuffer hostBuffer) { - this.nativeTableView = tableHandle; - this.hostBuffer = hostBuffer; - } + private static native long bufferSize(long tableHandle, long stream); + + private static native long copyFromTableAsync(long tableHandle, long hostAddress, long hostSize, + long stream); + + private static native long[] toDeviceColumnViews(long tableHandle, long hostToDevPtrOffset); + + private static native void freeDeviceColumnView(long columnHandle); + + private static native void freeHostTable(long tableHandle); /** * Gets the address of the host_table_view for this host table. @@ -106,6 +120,7 @@ public HostMemoryBuffer getHostBuffer() { * Copies the host table to a device table asynchronously. * NOTE: The caller must synchronize on the stream before closing this instance, * or the copy could still be in-flight when the host memory is invalidated or reused. + * * @param stream stream to use for the copy * @return device table */ @@ -120,7 +135,8 @@ public Table toTableAsync(Cuda.Stream stream) { boolean done = false; try { for (int i = 0; i < columnViewHandles.length; i++) { - columns[i] = ColumnVector.fromViewWithContiguousAllocation(columnViewHandles[i], devBuffer); + columns[i] = + ColumnVector.fromViewWithContiguousAllocation(columnViewHandles[i], devBuffer); columnViewHandles[i] = 0; } table = new Table(columns); @@ -149,6 +165,7 @@ public Table toTableAsync(Cuda.Stream stream) { /** * Copies the host table to a device table synchronously. + * * @param stream stream to use for the copy * @return device table */ @@ -160,6 +177,7 @@ public Table toTable(Cuda.Stream stream) { /** * Copies the host table to a device table synchronously on the default stream. + * * @return device table */ public Table toTable() { @@ -176,15 +194,4 @@ public void close() { hostBuffer = null; } } - - private static native long bufferSize(long tableHandle, long stream); - - private static native long copyFromTableAsync(long tableHandle, long hostAddress, long hostSize, - long stream); - - private static native long[] toDeviceColumnViews(long tableHandle, long hostToDevPtrOffset); - - private static native void freeDeviceColumnView(long columnHandle); - - private static native void freeHostTable(long tableHandle); } diff --git a/src/main/java/com/nvidia/spark/rapids/jni/JSONUtils.java b/src/main/java/com/nvidia/spark/rapids/jni/JSONUtils.java index c58caa62e9..dab9476357 100644 --- a/src/main/java/com/nvidia/spark/rapids/jni/JSONUtils.java +++ b/src/main/java/com/nvidia/spark/rapids/jni/JSONUtils.java @@ -16,52 +16,30 @@ package com.nvidia.spark.rapids.jni; -import ai.rapids.cudf.*; - +import ai.rapids.cudf.ColumnVector; +import ai.rapids.cudf.ColumnView; +import ai.rapids.cudf.DType; +import ai.rapids.cudf.DeviceMemoryBuffer; +import ai.rapids.cudf.JSONOptions; +import ai.rapids.cudf.NativeDepsLoader; import java.util.List; public class JSONUtils { - static { - NativeDepsLoader.loadNativeDeps(); - } - public static final int MAX_PATH_DEPTH = getMaxJSONPathDepth(); - public enum PathInstructionType { - WILDCARD, - INDEX, - NAMED - } - - public static class PathInstructionJni { - // type: byte, name: String, index: int - private final byte type; - private final String name; - private final int index; - - public PathInstructionJni(PathInstructionType type, String name, long index) { - this.type = (byte) type.ordinal(); - this.name = name; - if (index > Integer.MAX_VALUE) { - throw new IllegalArgumentException("index is too large " + index); - } - this.index = (int) index; - } - - public PathInstructionJni(PathInstructionType type, String name, int index) { - this.type = (byte) type.ordinal(); - this.name = name; - this.index = index; - } + static { + NativeDepsLoader.loadNativeDeps(); } /** * Extract a JSON path from a JSON column. The path is processed in a Spark compatible way. - * @param input the string column containing JSON + * + * @param input the string column containing JSON * @param pathInstructions the instructions for the path processing * @return the result of processing the path */ - public static ColumnVector getJsonObject(ColumnVector input, PathInstructionJni[] pathInstructions) { + public static ColumnVector getJsonObject(ColumnVector input, + PathInstructionJni[] pathInstructions) { assert (input.getType().equals(DType.STRING)) : "Input must be of STRING type"; int numTotalInstructions = pathInstructions.length; byte[] typeNums = new byte[numTotalInstructions]; @@ -80,6 +58,7 @@ public static ColumnVector getJsonObject(ColumnVector input, PathInstructionJni[ /** * Extract multiple JSON paths from a JSON column. The paths are processed in a Spark * compatible way. + * * @param input the string column containing JSON * @param paths the instructions for multiple paths * @return the result of processing each path in the order that they were passed in @@ -92,15 +71,16 @@ public static ColumnVector[] getJsonObjectMultiplePaths(ColumnVector input, /** * Extract multiple JSON paths from a JSON column. The paths are processed in a Spark * compatible way. - * @param input the string column containing JSON - * @param paths the instructions for multiple paths + * + * @param input the string column containing JSON + * @param paths the instructions for multiple paths * @param memoryBudgetBytes a budget that is used to limit the amount of memory * that is used when processing the paths. This is a soft limit. * A value <= 0 disables this and all paths will be processed in parallel. - * @param parallelOverride Set a maximum number of paths to be processed in parallel. The memory - * budget can limit how many paths can be processed in parallel. This overrides - * that automatically calculated value with a set value for benchmarking purposes. - * A value <= 0 disables this. + * @param parallelOverride Set a maximum number of paths to be processed in parallel. The memory + * budget can limit how many paths can be processed in parallel. This overrides + * that automatically calculated value with a set value for benchmarking purposes. + * A value <= 0 disables this. * @return the result of processing each path in the order that they were passed in */ public static ColumnVector[] getJsonObjectMultiplePaths(ColumnVector input, @@ -137,7 +117,6 @@ public static ColumnVector[] getJsonObjectMultiplePaths(ColumnVector input, return ret; } - /** * Extract key-value pairs for each output map from the given json strings. These key-value are * copied directly as substrings of the input without any type conversion. @@ -152,9 +131,9 @@ public static ColumnVector[] getJsonObjectMultiplePaths(ColumnVector input, * function will just simply copy the input value strings to the output. * * @param input The input strings column in which each row specifies a json object - * @param opts The options for parsing JSON strings + * @param opts The options for parsing JSON strings * @return A map column (i.e., a column of type {@code List>}) in - * which the key-value pairs are extracted directly from the input json strings + * which the key-value pairs are extracted directly from the input json strings */ public static ColumnVector extractRawMapFromJsonString(ColumnView input, JSONOptions opts) { assert (input.getType().equals(DType.STRING)) : "Input must be of STRING type"; @@ -169,12 +148,11 @@ public static ColumnVector extractRawMapFromJsonString(ColumnView input, JSONOpt * Extract key-value pairs for each output map from the given json strings. This method is * similar to {@link #extractRawMapFromJsonString(ColumnView, JSONOptions)} but is deprecated. * - * @deprecated This method is deprecated since it does not have parameters to control various - * JSON reader behaviors. - * * @param input The input strings column in which each row specifies a json object * @return A map column (i.e., a column of type {@code List>}) in - * which the key-value pairs are extracted directly from the input json strings + * which the key-value pairs are extracted directly from the input json strings + * @deprecated This method is deprecated since it does not have parameters to control various + * JSON reader behaviors. */ public static ColumnVector extractRawMapFromJsonString(ColumnView input) { assert (input.getType().equals(DType.STRING)) : "Input must be of STRING type"; @@ -182,30 +160,6 @@ public static ColumnVector extractRawMapFromJsonString(ColumnView input) { true, true, true, true)); } - /** - * A class to hold the result when concatenating JSON strings. - *

- * A long with the concatenated data, the result also contains a vector that indicates - * whether each row in the input is null or empty, and the delimiter used for concatenation. - */ - public static class ConcatenatedJson implements AutoCloseable { - public final ColumnVector isNullOrEmpty; - public final DeviceMemoryBuffer data; - public final char delimiter; - - public ConcatenatedJson(ColumnVector isNullOrEmpty, DeviceMemoryBuffer data, char delimiter) { - this.isNullOrEmpty = isNullOrEmpty; - this.data = data; - this.delimiter = delimiter; - } - - @Override - public void close() { - isNullOrEmpty.close(); - data.close(); - } - } - /** * Concatenate JSON strings in the input column into a single JSON string. *

@@ -231,7 +185,7 @@ public static ConcatenatedJson concatenateJsonStrings(ColumnView input) { * by the input isNull column. * * @param children The children columns of the output structs column - * @param isNull A boolean column specifying the rows at which the output column should be null + * @param isNull A boolean column specifying the rows at which the output column should be null * @return A structs column created from the given children and the isNull column */ public static ColumnVector makeStructs(ColumnView[] children, ColumnView isNull) { @@ -257,7 +211,6 @@ private static native long[] getJsonObjectMultiplePaths(long input, long memoryBudgetBytes, int parallelOverride); - private static native long extractRawMapFromJsonString(long input, boolean normalizeSingleQuotes, boolean leadingZerosAllowed, @@ -267,4 +220,57 @@ private static native long extractRawMapFromJsonString(long input, private static native long[] concatenateJsonStrings(long input); private static native long makeStructs(long[] children, long isNull); + + + public enum PathInstructionType { + WILDCARD, + INDEX, + NAMED + } + + public static class PathInstructionJni { + // type: byte, name: String, index: int + private final byte type; + private final String name; + private final int index; + + public PathInstructionJni(PathInstructionType type, String name, long index) { + this.type = (byte) type.ordinal(); + this.name = name; + if (index > Integer.MAX_VALUE) { + throw new IllegalArgumentException("index is too large " + index); + } + this.index = (int) index; + } + + public PathInstructionJni(PathInstructionType type, String name, int index) { + this.type = (byte) type.ordinal(); + this.name = name; + this.index = index; + } + } + + /** + * A class to hold the result when concatenating JSON strings. + *

+ * A long with the concatenated data, the result also contains a vector that indicates + * whether each row in the input is null or empty, and the delimiter used for concatenation. + */ + public static class ConcatenatedJson implements AutoCloseable { + public final ColumnVector isNullOrEmpty; + public final DeviceMemoryBuffer data; + public final char delimiter; + + public ConcatenatedJson(ColumnVector isNullOrEmpty, DeviceMemoryBuffer data, char delimiter) { + this.isNullOrEmpty = isNullOrEmpty; + this.data = data; + this.delimiter = delimiter; + } + + @Override + public void close() { + isNullOrEmpty.close(); + data.close(); + } + } } diff --git a/src/main/java/com/nvidia/spark/rapids/jni/Pair.java b/src/main/java/com/nvidia/spark/rapids/jni/Pair.java index ac8aa1910c..8a5137ea72 100644 --- a/src/main/java/com/nvidia/spark/rapids/jni/Pair.java +++ b/src/main/java/com/nvidia/spark/rapids/jni/Pair.java @@ -20,23 +20,23 @@ * A utility class for holding a pair of values. */ public class Pair { - private final K left; - private final V right; + private final K left; + private final V right; - public Pair(K left, V right) { - this.left = left; - this.right = right; - } + public Pair(K left, V right) { + this.left = left; + this.right = right; + } - public K getLeft() { - return left; - } + public static Pair of(K left, V right) { + return new Pair<>(left, right); + } - public V getRight() { - return right; - } + public K getLeft() { + return left; + } - public static Pair of(K left, V right) { - return new Pair<>(left, right); - } + public V getRight() { + return right; + } } diff --git a/src/main/java/com/nvidia/spark/rapids/jni/ParquetFooter.java b/src/main/java/com/nvidia/spark/rapids/jni/ParquetFooter.java index 681a01d81d..4f6f189bff 100644 --- a/src/main/java/com/nvidia/spark/rapids/jni/ParquetFooter.java +++ b/src/main/java/com/nvidia/spark/rapids/jni/ParquetFooter.java @@ -16,8 +16,11 @@ package com.nvidia.spark.rapids.jni; -import ai.rapids.cudf.*; - +import ai.rapids.cudf.CudfException; +import ai.rapids.cudf.DefaultHostMemoryAllocator; +import ai.rapids.cudf.HostMemoryAllocator; +import ai.rapids.cudf.HostMemoryBuffer; +import ai.rapids.cudf.NativeDepsLoader; import java.util.ArrayList; import java.util.Locale; @@ -29,116 +32,19 @@ public class ParquetFooter implements AutoCloseable { NativeDepsLoader.loadNativeDeps(); } - /** - * Base element for all types in a parquet schema. - */ - public static abstract class SchemaElement {} - - private static class ElementWithName { - final String name; - final SchemaElement element; - - public ElementWithName(String name, SchemaElement element) { - this.name = name; - this.element = element; - } - } - - public static class StructElement extends SchemaElement { - public static StructBuilder builder() { - return new StructBuilder(); - } - - private final ElementWithName[] children; - private StructElement(ElementWithName[] children) { - this.children = children; - } - } - - public static class StructBuilder { - ArrayList children = new ArrayList<>(); - - StructBuilder() { - // Empty - } - - public StructBuilder addChild(String name, SchemaElement child) { - children.add(new ElementWithName(name, child)); - return this; - } - - public StructElement build() { - return new StructElement(children.toArray(new ElementWithName[0])); - } - } - - public static class ValueElement extends SchemaElement { - public ValueElement() {} - } - - public static class ListElement extends SchemaElement { - private final SchemaElement item; - public ListElement(SchemaElement item) { - this.item = item; - } - } - - public static class MapElement extends SchemaElement { - private final SchemaElement key; - private final SchemaElement value; - public MapElement(SchemaElement key, SchemaElement value) { - this.key = key; - this.value = value; - } - } - private long nativeHandle; private ParquetFooter(long handle) { nativeHandle = handle; } - /** - * Write the filtered footer back out in a format that is compatible with a parquet - * footer file. This will include the MAGIC PAR1 at the beginning and end and also the - * length of the footer just before the PAR1 at the end. - */ - public HostMemoryBuffer serializeThriftFile(HostMemoryAllocator hostMemoryAllocator) { - return serializeThriftFile(nativeHandle, hostMemoryAllocator); - } - - public HostMemoryBuffer serializeThriftFile() { - return serializeThriftFile(DefaultHostMemoryAllocator.get()); - } - - /** - * Get the number of rows in the footer after filtering. - */ - public long getNumRows() { - return getNumRows(nativeHandle); - } - - /** - * Get the number of top level columns in the footer after filtering. - */ - public int getNumColumns() { - return getNumColumns(nativeHandle); - } - - @Override - public void close() throws Exception { - if (nativeHandle != 0) { - close(nativeHandle); - nativeHandle = 0; - } - } - /** * Recursive helper function to flatten a SchemaElement, so it can more efficiently be passed * through JNI. */ private static void depthFirstNamesHelper(SchemaElement se, String name, boolean makeLowerCase, - ArrayList names, ArrayList numChildren, ArrayList tags) { + ArrayList names, ArrayList numChildren, + ArrayList tags) { if (makeLowerCase) { name = name.toLowerCase(Locale.ROOT); } @@ -181,28 +87,31 @@ private static void depthFirstNamesHelper(SchemaElement se, String name, boolean * Flatten a SchemaElement, so it can more efficiently be passed through JNI. */ private static void depthFirstNames(StructElement schema, boolean makeLowerCase, - ArrayList names, ArrayList numChildren, ArrayList tags) { + ArrayList names, ArrayList numChildren, + ArrayList tags) { // Initialize them with a quick length for non-nested values - for (ElementWithName se: schema.children) { + for (ElementWithName se : schema.children) { depthFirstNamesHelper(se.element, se.name, makeLowerCase, names, numChildren, tags); } } /** - * Read a parquet thrift footer from a buffer and filter it like the java code would. The buffer + * Read a parquet thrift footer from a buffer and filter it like the java code would. The buffer * should only include the thrift footer itself. This includes filtering out row groups that do * not fall within the partition and pruning columns that are not needed. - * @param buffer the buffer to parse the footer out from. + * + * @param buffer the buffer to parse the footer out from. * @param partOffset for a split the start of the split * @param partLength the length of the split - * @param schema a stripped down schema so the code can verify that the types match what is - * expected. The java code does this too. + * @param schema a stripped down schema so the code can verify that the types match what is + * expected. The java code does this too. * @param ignoreCase should case be ignored when matching column names. If this is true then * names should be converted to lower case before being passed to this. * @return a reference to the parsed footer. */ public static ParquetFooter readAndFilter(HostMemoryBuffer buffer, - long partOffset, long partLength, StructElement schema, boolean ignoreCase) { + long partOffset, long partLength, StructElement schema, + boolean ignoreCase) { int parentNumChildren = schema.children.length; ArrayList names = new ArrayList<>(); ArrayList numChildren = new ArrayList<>(); @@ -210,8 +119,7 @@ public static ParquetFooter readAndFilter(HostMemoryBuffer buffer, depthFirstNames(schema, ignoreCase, names, numChildren, tags); return new ParquetFooter( - readAndFilter - (buffer.getAddress(), buffer.getLength(), + readAndFilter(buffer.getAddress(), buffer.getLength(), partOffset, partLength, names.toArray(new String[0]), numChildren.stream().mapToInt(i -> i).toArray(), @@ -220,15 +128,13 @@ public static ParquetFooter readAndFilter(HostMemoryBuffer buffer, ignoreCase)); } - // Native APIS - private static native long readAndFilter(long address, long length, - long partOffset, long partLength, - String[] names, - int[] numChildren, - int[] tags, - int parentNumChildren, - boolean ignoreCase) throws CudfException; + long partOffset, long partLength, + String[] names, + int[] numChildren, + int[] tags, + int parentNumChildren, + boolean ignoreCase) throws CudfException; private static native void close(long nativeHandle); @@ -237,5 +143,110 @@ private static native long readAndFilter(long address, long length, private static native int getNumColumns(long nativeHandle); private static native HostMemoryBuffer serializeThriftFile(long nativeHandle, - HostMemoryAllocator hostMemoryAllocator); + HostMemoryAllocator hostMemoryAllocator); + + /** + * Write the filtered footer back out in a format that is compatible with a parquet + * footer file. This will include the MAGIC PAR1 at the beginning and end and also the + * length of the footer just before the PAR1 at the end. + */ + public HostMemoryBuffer serializeThriftFile(HostMemoryAllocator hostMemoryAllocator) { + return serializeThriftFile(nativeHandle, hostMemoryAllocator); + } + + public HostMemoryBuffer serializeThriftFile() { + return serializeThriftFile(DefaultHostMemoryAllocator.get()); + } + + /** + * Get the number of rows in the footer after filtering. + */ + public long getNumRows() { + return getNumRows(nativeHandle); + } + + /** + * Get the number of top level columns in the footer after filtering. + */ + public int getNumColumns() { + return getNumColumns(nativeHandle); + } + + @Override + public void close() throws Exception { + if (nativeHandle != 0) { + close(nativeHandle); + nativeHandle = 0; + } + } + + /** + * Base element for all types in a parquet schema. + */ + public static abstract class SchemaElement { + } + + private static class ElementWithName { + final String name; + final SchemaElement element; + + public ElementWithName(String name, SchemaElement element) { + this.name = name; + this.element = element; + } + } + + // Native APIS + + public static class StructElement extends SchemaElement { + private final ElementWithName[] children; + + private StructElement(ElementWithName[] children) { + this.children = children; + } + + public static StructBuilder builder() { + return new StructBuilder(); + } + } + + public static class StructBuilder { + ArrayList children = new ArrayList<>(); + + StructBuilder() { + // Empty + } + + public StructBuilder addChild(String name, SchemaElement child) { + children.add(new ElementWithName(name, child)); + return this; + } + + public StructElement build() { + return new StructElement(children.toArray(new ElementWithName[0])); + } + } + + public static class ValueElement extends SchemaElement { + public ValueElement() { + } + } + + public static class ListElement extends SchemaElement { + private final SchemaElement item; + + public ListElement(SchemaElement item) { + this.item = item; + } + } + + public static class MapElement extends SchemaElement { + private final SchemaElement key; + private final SchemaElement value; + + public MapElement(SchemaElement key, SchemaElement value) { + this.key = key; + this.value = value; + } + } } diff --git a/src/main/java/com/nvidia/spark/rapids/jni/ParseURI.java b/src/main/java/com/nvidia/spark/rapids/jni/ParseURI.java index 6b71416dcb..2dd598a2bf 100644 --- a/src/main/java/com/nvidia/spark/rapids/jni/ParseURI.java +++ b/src/main/java/com/nvidia/spark/rapids/jni/ParseURI.java @@ -64,7 +64,7 @@ public static ColumnVector parseURIQuery(ColumnView uriColumn) { * Parse query and return a specific parameter for each URI from the incoming column. * * @param URIColumn The input strings column in which each row contains a URI. - * @param String The parameter to extract from the query + * @param String The parameter to extract from the query * @return A string column with query data extracted. */ public static ColumnVector parseURIQueryWithLiteral(ColumnView uriColumn, String query) { @@ -72,17 +72,18 @@ public static ColumnVector parseURIQueryWithLiteral(ColumnView uriColumn, String return new ColumnVector(parseQueryWithLiteral(uriColumn.getNativeView(), query)); } - /** + /** * Parse query and return a specific parameter for each URI from the incoming column. * * @param URIColumn The input strings column in which each row contains a URI. - * @param String The parameter to extract from the query + * @param String The parameter to extract from the query * @return A string column with query data extracted. */ public static ColumnVector parseURIQueryWithColumn(ColumnView uriColumn, ColumnView queryColumn) { assert uriColumn.getType().equals(DType.STRING) : "Input type must be String"; assert queryColumn.getType().equals(DType.STRING) : "Query type must be String"; - return new ColumnVector(parseQueryWithColumn(uriColumn.getNativeView(), queryColumn.getNativeView())); + return new ColumnVector( + parseQueryWithColumn(uriColumn.getNativeView(), queryColumn.getNativeView())); } /** @@ -97,9 +98,14 @@ public static ColumnVector parseURIPath(ColumnView uriColumn) { } private static native long parseProtocol(long inputColumnHandle); + private static native long parseHost(long inputColumnHandle); + private static native long parseQuery(long inputColumnHandle); + private static native long parseQueryWithLiteral(long inputColumnHandle, String query); + private static native long parseQueryWithColumn(long inputColumnHandle, long queryColumnHandle); + private static native long parsePath(long inputColumnHandle); } diff --git a/src/main/java/com/nvidia/spark/rapids/jni/Preconditions.java b/src/main/java/com/nvidia/spark/rapids/jni/Preconditions.java index a6dfcdb104..4dba50c2a6 100644 --- a/src/main/java/com/nvidia/spark/rapids/jni/Preconditions.java +++ b/src/main/java/com/nvidia/spark/rapids/jni/Preconditions.java @@ -22,49 +22,51 @@ * This class contains utility methods for checking preconditions. */ public class Preconditions { - /** - * Check if the condition is true, otherwise throw an IllegalStateException with the given message. - */ - public static void ensure(boolean condition, String message) { - if (!condition) { - throw new IllegalStateException(message); - } + /** + * Check if the condition is true, otherwise throw an IllegalStateException with the given message. + */ + public static void ensure(boolean condition, String message) { + if (!condition) { + throw new IllegalStateException(message); } + } - /** - * Check if the condition is true, otherwise throw an IllegalStateException with the given message supplier. - */ - public static void ensure(boolean condition, Supplier messageSupplier) { - if (!condition) { - throw new IllegalStateException(messageSupplier.get()); - } + /** + * Check if the condition is true, otherwise throw an IllegalStateException with the given message supplier. + */ + public static void ensure(boolean condition, Supplier messageSupplier) { + if (!condition) { + throw new IllegalStateException(messageSupplier.get()); } + } - /** - * Check if the value is non-negative, otherwise throw an IllegalArgumentException with the given message. - * @param value the value to check - * @param name the name of the value - * @return the value if it is non-negative - * @throws IllegalArgumentException if the value is negative - */ - public static int ensureNonNegative(int value, String name) { - if (value < 0) { - throw new IllegalArgumentException(name + " must be non-negative, but was " + value); - } - return value; + /** + * Check if the value is non-negative, otherwise throw an IllegalArgumentException with the given message. + * + * @param value the value to check + * @param name the name of the value + * @return the value if it is non-negative + * @throws IllegalArgumentException if the value is negative + */ + public static int ensureNonNegative(int value, String name) { + if (value < 0) { + throw new IllegalArgumentException(name + " must be non-negative, but was " + value); } + return value; + } - /** - * Check if the value is non-negative, otherwise throw an IllegalArgumentException with the given message. - * @param value the value to check - * @param name the name of the value - * @return the value if it is non-negative - * @throws IllegalArgumentException if the value is negative - */ - public static long ensureNonNegative(long value, String name) { - if (value < 0) { - throw new IllegalArgumentException(name + " must be non-negative, but was " + value); - } - return value; + /** + * Check if the value is non-negative, otherwise throw an IllegalArgumentException with the given message. + * + * @param value the value to check + * @param name the name of the value + * @return the value if it is non-negative + * @throws IllegalArgumentException if the value is negative + */ + public static long ensureNonNegative(long value, String name) { + if (value < 0) { + throw new IllegalArgumentException(name + " must be non-negative, but was " + value); } + return value; + } } diff --git a/src/main/java/com/nvidia/spark/rapids/jni/Profiler.java b/src/main/java/com/nvidia/spark/rapids/jni/Profiler.java index 85e6b4a0a3..336540d27f 100644 --- a/src/main/java/com/nvidia/spark/rapids/jni/Profiler.java +++ b/src/main/java/com/nvidia/spark/rapids/jni/Profiler.java @@ -13,15 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package com.nvidia.spark.rapids.jni; import ai.rapids.cudf.NativeDepsLoader; - import java.io.File; import java.io.IOException; import java.nio.ByteBuffer; -/** Profiler that collects CUDA and NVTX events for the current process. */ +/** + * Profiler that collects CUDA and NVTX events for the current process. + */ public class Profiler { private static final long DEFAULT_WRITE_BUFFER_SIZE = 1024 * 1024; private static final int DEFAULT_FLUSH_PERIOD_MILLIS = 0; @@ -31,6 +33,7 @@ public class Profiler { /** * Initialize the profiler in a standby state. The start method must be called after this * to start collecting profiling data. + * * @param config profiler configuration */ public static void init(DataWriter w, Config config) { @@ -53,10 +56,11 @@ public static void init(DataWriter w, Config config) { * Deprecated. Use init(Config) instead. * Initialize the profiler in a standby state. The start method must be called after this * to start collecting profiling data. - * @param w data writer for writing profiling data - * @param writeBufferSize size of host memory buffer to use for collecting profiling data. - * Recommended to be between 1-8 MB in size to balance callback - * overhead with latency. + * + * @param w data writer for writing profiling data + * @param writeBufferSize size of host memory buffer to use for collecting profiling data. + * Recommended to be between 1-8 MB in size to balance callback + * overhead with latency. * @param flushPeriodMillis time period in milliseconds to explicitly flush collected * profiling data to the writer. A value <= 0 will disable explicit * flushing. @@ -119,17 +123,22 @@ private static native void nativeInit(String libPath, DataWriter writer, private static native void nativeShutdown(); - /** Interface for profiler data writers */ + /** + * Interface for profiler data writers + */ public interface DataWriter extends AutoCloseable { /** * Called by the profiler to write a block of profiling data. Profiling data is written * in a size-prefixed flatbuffer format. See profiler.fbs for the schema. + * * @param data profiling data to be written */ void write(ByteBuffer data); } - /** Profiler configuration class. **/ + /** + * Profiler configuration class. + **/ public static class Config { private final long writeBufferSize; private final int flushPeriodMillis; @@ -141,7 +150,9 @@ public static class Config { this.allocAsyncCapturing = builder.allocAsyncCapturing; } - /** Builder interface for profiler configuration. **/ + /** + * Builder interface for profiler configuration. + **/ public static class Builder { private long writeBufferSize = DEFAULT_WRITE_BUFFER_SIZE; private int flushPeriodMillis = DEFAULT_FLUSH_PERIOD_MILLIS; @@ -151,6 +162,7 @@ public static class Builder { * Configure the size of the host memory buffer used for collecting profiling data. * Recommended to be between 1 to 8 MB in size to balance callback overhead with * latency. + * * @param writeBufferSize size of buffer in bytes */ public Builder withWriteBufferSize(long writeBufferSize) { @@ -160,6 +172,7 @@ public Builder withWriteBufferSize(long writeBufferSize) { /** * Configure the time period to explicitly flush collected profiling data to the writer. + * * @param flushPeriodMillis time period in milliseconds. A value <= 0 will disable explicit * flushing. */ @@ -170,6 +183,7 @@ public Builder withFlushPeriodMillis(int flushPeriodMillis) { /** * Configure whether async allocation and free events are captured by the profiler. + * * @param allocAsyncCapturing true if async allocation and free events should be captured, * false otherwise. */ @@ -178,7 +192,9 @@ public Builder withAllocAsyncCapturing(boolean allocAsyncCapturing) { return this; } - /** Build a profiler configuration object. */ + /** + * Build a profiler configuration object. + */ public Config build() { return new Config(this); } diff --git a/src/main/java/com/nvidia/spark/rapids/jni/RegexRewriteUtils.java b/src/main/java/com/nvidia/spark/rapids/jni/RegexRewriteUtils.java index 9277c3e0f9..b940a5093f 100644 --- a/src/main/java/com/nvidia/spark/rapids/jni/RegexRewriteUtils.java +++ b/src/main/java/com/nvidia/spark/rapids/jni/RegexRewriteUtils.java @@ -16,29 +16,36 @@ package com.nvidia.spark.rapids.jni; -import ai.rapids.cudf.*; +import ai.rapids.cudf.ColumnVector; +import ai.rapids.cudf.CudfAccessor; +import ai.rapids.cudf.DType; +import ai.rapids.cudf.NativeDepsLoader; +import ai.rapids.cudf.Scalar; public class RegexRewriteUtils { static { NativeDepsLoader.loadNativeDeps(); } -/** - * @brief Check if input string contains regex pattern `literal[start-end]{len,}`, which means - * a literal string followed by a range of characters in the range of start to end, with at least - * len characters. - * - * @param strings Column of strings to check for literal. - * @param literal UTF-8 encoded string to check in strings column. - * @param len Minimum number of characters to check after the literal. - * @param start Minimum UTF-8 codepoint value to check for in the range. - * @param end Maximum UTF-8 codepoint value to check for in the range. - * @return ColumnVector of booleans where true indicates the string contains the pattern. - */ - public static ColumnVector literalRangePattern(ColumnVector input, Scalar literal, int len, int start, int end) { - assert(input.getType().equals(DType.STRING)) : "column must be a String"; - return new ColumnVector(literalRangePattern(input.getNativeView(), CudfAccessor.getScalarHandle(literal), len, start, end)); + /** + * @param strings Column of strings to check for literal. + * @param literal UTF-8 encoded string to check in strings column. + * @param len Minimum number of characters to check after the literal. + * @param start Minimum UTF-8 codepoint value to check for in the range. + * @param end Maximum UTF-8 codepoint value to check for in the range. + * @return ColumnVector of booleans where true indicates the string contains the pattern. + * @brief Check if input string contains regex pattern `literal[start-end]{len,}`, which means + * a literal string followed by a range of characters in the range of start to end, with at least + * len characters. + */ + public static ColumnVector literalRangePattern(ColumnVector input, Scalar literal, int len, + int start, int end) { + assert (input.getType().equals(DType.STRING)) : "column must be a String"; + return new ColumnVector( + literalRangePattern(input.getNativeView(), CudfAccessor.getScalarHandle(literal), len, + start, end)); } - private static native long literalRangePattern(long input, long literal, int len, int start, int end); + private static native long literalRangePattern(long input, long literal, int len, int start, + int end); } diff --git a/src/main/java/com/nvidia/spark/rapids/jni/RmmSpark.java b/src/main/java/com/nvidia/spark/rapids/jni/RmmSpark.java index 45a234dcca..72df229d27 100644 --- a/src/main/java/com/nvidia/spark/rapids/jni/RmmSpark.java +++ b/src/main/java/com/nvidia/spark/rapids/jni/RmmSpark.java @@ -23,20 +23,11 @@ import ai.rapids.cudf.RmmException; import ai.rapids.cudf.RmmTrackingResourceAdaptor; -import java.util.Arrays; -import java.util.Map; - /** * Initialize RMM in ways that are specific to Spark. */ public class RmmSpark { - public enum OomInjectionType { - CPU_OR_GPU, - CPU, - GPU; - } - private static volatile SparkResourceAdaptor sra = null; /** @@ -50,13 +41,15 @@ public static void setEventHandler(RmmEventHandler handler) throws RmmException /** * Set the event handler in a way that Spark wants it. For now this is the same as RMM, but in * the future it is likely to change. - * @param handler the handler to set + * + * @param handler the handler to set * @param logLocation the location where you want spark state transitions. Alloc and free logging * is handled separately when setting up RMM. "stderr" or "stdout" are treated * as `std::cerr` and `std::cout` respectively in native code. Anything else * is treated as a file. */ - public static void setEventHandler(RmmEventHandler handler, String logLocation) throws RmmException { + public static void setEventHandler(RmmEventHandler handler, String logLocation) + throws RmmException { // synchronize with RMM not RmmSpark to stay in sync with Rmm itself. synchronized (Rmm.class) { // RmmException constructor is not public, so we have to use a different exception @@ -125,8 +118,9 @@ public static long getCurrentThreadId() { /** * Indicate that a given thread is dedicated to a specific task. This thread can be part of a * thread pool, but if it blocks it can never transitively block another active task. + * * @param threadId the thread ID to use - * @param taskId the task ID this thread is working on. + * @param taskId the task ID this thread is working on. */ public static void startDedicatedTaskThread(long threadId, long taskId, Thread thread) { synchronized (Rmm.class) { @@ -140,6 +134,7 @@ public static void startDedicatedTaskThread(long threadId, long taskId, Thread t /** * Indicate that the current thread is dedicated to a specific task. This thread can be part of * a thread pool, but if this blocks it can never transitively block another active task. + * * @param taskId the task ID this thread is working on. */ public static void currentThreadIsDedicatedToTask(long taskId) { @@ -148,9 +143,10 @@ public static void currentThreadIsDedicatedToTask(long taskId) { /** * A shuffle thread has started to work on some tasks. + * * @param threadId the thread ID (not java thread id). - * @param thread the java thread - * @param taskIds the IDs of tasks that this is starting work on. + * @param thread the java thread + * @param taskIds the IDs of tasks that this is starting work on. */ public static void shuffleThreadWorkingTasks(long threadId, Thread thread, long[] taskIds) { synchronized (Rmm.class) { @@ -163,6 +159,7 @@ public static void shuffleThreadWorkingTasks(long threadId, Thread thread, long[ /** * The current thread is a shuffle thread and has started to work on some tasks. + * * @param taskIds the IDs of the tasks that this is starting work on. */ public static void shuffleThreadWorkingOnTasks(long[] taskIds) { @@ -172,12 +169,13 @@ public static void shuffleThreadWorkingOnTasks(long[] taskIds) { /** * The current thread which is in a thread pool that could transitively block other tasks has * started to work on a task. + * * @param taskId the ID of the task that this is starting work on. */ public static void poolThreadWorkingOnTask(long taskId) { long threadId = getCurrentThreadId(); Thread thread = Thread.currentThread(); - long[] taskIds = new long[]{taskId}; + long[] taskIds = new long[] {taskId}; synchronized (Rmm.class) { if (sra != null && sra.isOpen()) { ThreadStateRegistry.addThread(threadId, thread); @@ -189,8 +187,9 @@ public static void poolThreadWorkingOnTask(long taskId) { /** * A thread in a thread pool that could transitively block other tasks has finished work * on some tasks. + * * @param threadId the thread ID (not java thread id). - * @param taskIds the IDs of the tasks that are done. + * @param taskIds the IDs of the tasks that are done. */ public static void poolThreadFinishedForTasks(long threadId, long[] taskIds) { synchronized (Rmm.class) { @@ -202,8 +201,9 @@ public static void poolThreadFinishedForTasks(long threadId, long[] taskIds) { /** * A shuffle thread has finished work on some tasks. + * * @param threadId the thread ID (not java thread id). - * @param taskIds the IDs of the tasks that are done. + * @param taskIds the IDs of the tasks that are done. */ private static void shuffleThreadFinishedForTasks(long threadId, long[] taskIds) { poolThreadFinishedForTasks(threadId, taskIds); @@ -212,6 +212,7 @@ private static void shuffleThreadFinishedForTasks(long threadId, long[] taskIds) /** * The current thread which is in a thread pool that could transitively block other tasks * has finished work on some tasks. + * * @param taskIds the IDs of the tasks that are done. */ public static void poolThreadFinishedForTasks(long[] taskIds) { @@ -220,6 +221,7 @@ public static void poolThreadFinishedForTasks(long[] taskIds) { /** * The current shuffle thread has finished work on some tasks. + * * @param taskIds the IDs of the tasks that are done. */ public static void shuffleThreadFinishedForTasks(long[] taskIds) { @@ -229,14 +231,16 @@ public static void shuffleThreadFinishedForTasks(long[] taskIds) { /** * The current thread which is in a thread pool that could transitively block other tasks * has finished work on a task. + * * @param taskId the ID of the task that is done. */ public static void poolThreadFinishedForTask(long taskId) { - poolThreadFinishedForTasks(getCurrentThreadId(), new long[]{taskId}); + poolThreadFinishedForTasks(getCurrentThreadId(), new long[] {taskId}); } /** * Indicate that a retry block has started for a given thread. + * * @param threadId the id of the thread, not the java ID. */ public static void startRetryBlock(long threadId) { @@ -256,6 +260,7 @@ public static void currentThreadStartRetryBlock() { /** * Indicate that a retry block has ended for a given thread. + * * @param threadId the id of the thread, not the java ID. */ public static void endRetryBlock(long threadId) { @@ -283,6 +288,7 @@ private static void checkAndBreakDeadlocks() { /** * Remove the given thread ID from being associated with a given task + * * @param threadId the ID of the thread that is no longer a part of a task or shuffle * (not java thread id). */ @@ -305,6 +311,7 @@ public static void removeCurrentDedicatedThreadAssociation(long taskId) { * Remove all task associations for a given thread. This is intended to be used as a part * of tests when a thread is shutting down, or for a pool thread when it is fully done. * Dedicated task thread typically are cleaned when the task itself completes. + * * @param threadId the id of the thread to clean up */ public static void removeAllThreadAssociation(long threadId) { @@ -327,6 +334,7 @@ public static void removeAllCurrentThreadAssociation() { /** * Indicate that a given task is done and if there are any threads still associated with it * then they should also be removed. + * * @param taskId the ID of the task that has completed. */ public static void taskDone(long taskId) { @@ -339,6 +347,7 @@ public static void taskDone(long taskId) { /** * A dedicated task thread is about to submit work to a pool that could transitively block it. + * * @param threadId the ID of the thread that is about to submit the work. */ public static void submittingToPool(long threadId) { @@ -360,6 +369,7 @@ public static void submittingToPool() { /** * A dedicated task thread is about to wait on work done on a pool that could transitively * block it. + * * @param threadId the ID of the thread that is about to wait. */ public static void waitingOnPool(long threadId) { @@ -381,6 +391,7 @@ public static void waitingOnPool() { /** * A dedicated task thread is done waiting on a pool, either for a result or after submitting * something to the pool. + * * @param threadId the ID of the thread that is done. */ public static void doneWaitingOnPool(long threadId) { @@ -430,6 +441,7 @@ public static void blockThreadUntilReady() { /** * Force the thread with the given ID to throw a GpuRetryOOM or CpuRetryOOM on their next * allocation attempt, depending on the type of allocation being done. + * * @param threadId the ID of the thread to throw the exception (not java thread id). */ public static void forceRetryOOM(long threadId) { @@ -439,9 +451,10 @@ public static void forceRetryOOM(long threadId) { /** * Force the thread with the given ID to throw a GpuRetryOOM or CpuRetryOOM on their next * allocation attempt, depending on the type of allocation being done. - * @param threadId the ID of the thread to throw the exception (not java thread id). - * @param numOOMs the number of times the *RetryOOM should be thrown - * @param oomMode the ordinal corresponding to OomInjectionType to filter allocations + * + * @param threadId the ID of the thread to throw the exception (not java thread id). + * @param numOOMs the number of times the *RetryOOM should be thrown + * @param oomMode the ordinal corresponding to OomInjectionType to filter allocations * @param skipCount how many matching allocations to skip */ public static void forceRetryOOM(long threadId, int numOOMs, int oomMode, int skipCount) { @@ -461,6 +474,7 @@ public static void forceRetryOOM(long threadId, int numOOMs) { /** * Force the thread with the given ID to throw a GpuSplitAndRetryOOM of CpuSplitAndRetryOOM * on their next allocation attempt, depending on the allocation being done. + * * @param threadId the ID of the thread to throw the exception (not java thread id). */ public static void forceSplitAndRetryOOM(long threadId) { @@ -470,9 +484,10 @@ public static void forceSplitAndRetryOOM(long threadId) { /** * Force the thread with the given ID to throw a GpuSplitAndRetryOOM or CpuSplitAndRetryOOm * on their next allocation attempt, depending on the allocation being done. - * @param threadId the ID of the thread to throw the exception (not java thread id). - * @param numOOMs the number of times the *SplitAndRetryOOM should be thrown - * @param oomMode the ordinal corresponding to OomInjectionType to filter allocations + * + * @param threadId the ID of the thread to throw the exception (not java thread id). + * @param numOOMs the number of times the *SplitAndRetryOOM should be thrown + * @param oomMode the ordinal corresponding to OomInjectionType to filter allocations * @param skipCount how many matching allocations to skip */ public static void forceSplitAndRetryOOM(long threadId, int numOOMs, int oomMode, int skipCount) { @@ -492,6 +507,7 @@ public static void forceSplitAndRetryOOM(long threadId, int numOOMs) { /** * Force the thread with the given ID to throw a CudfException on their next allocation attempt. * This is to simulate a cuDF exception being thrown from a kernel and test retry handling code. + * * @param threadId the ID of the thread to throw the exception (not java thread id). */ public static void forceCudfException(long threadId) { @@ -501,6 +517,7 @@ public static void forceCudfException(long threadId) { /** * Force the thread with the given ID to throw a CudfException on their next allocation attempt. * This is to simulate a cuDF exception being thrown from a kernel and test retry handling code. + * * @param threadId the ID of the thread to throw the exception (not java thread id). * @param numTimes the number of times the CudfException should be thrown */ @@ -527,6 +544,7 @@ public static RmmSparkThreadState getStateOf(long threadId) { /** * Get the number of retry exceptions that were thrown and reset the metric. + * * @param taskId the id of the task to get the metric for. * @return the number of times it was thrown or 0 if in the UNKNOWN state. */ @@ -543,6 +561,7 @@ public static int getAndResetNumRetryThrow(long taskId) { /** * Get the number of split and retry exceptions that were thrown and reset the metric. + * * @param taskId the id of the task to get the metric for. * @return the number of times it was thrown or 0 if in the UNKNOWN state. */ @@ -559,6 +578,7 @@ public static int getAndResetNumSplitRetryThrow(long taskId) { /** * Get how long, in nanoseconds, that the task was blocked for + * * @param taskId the id of the task to get the metric for. * @return the time the task was blocked or 0 if in the UNKNOWN state. */ @@ -575,6 +595,7 @@ public static long getAndResetBlockTimeNs(long taskId) { /** * Get how long, in nanoseconds, that this task lost in computation time due to retries. + * * @param taskId the id of the task to get the metric for. * @return the time the task did computation that was lost. */ @@ -591,6 +612,7 @@ public static long getAndResetComputeTimeLostToRetryNs(long taskId) { /** * Get the max device memory footprint, in bytes, that this task had allocated over its lifetime + * * @param taskId the id of the task to get the metric for. * @return the max device memory footprint. */ @@ -608,7 +630,8 @@ public static long getAndResetGpuMaxMemoryAllocated(long taskId) { /** * Called before doing an allocation on the CPU. This could throw an injected exception to help * with testing. - * @param amount the amount of memory being requested + * + * @param amount the amount of memory being requested * @param blocking is this for a blocking allocate or a non-blocking one. * @return a boolean that indicates if the allocation is recursive. Note that recursive * allocations on the CPU are only allowed with non-blocking allocations. This must be passed @@ -628,9 +651,10 @@ public static boolean preCpuAlloc(long amount, boolean blocking) { /** * The allocation that was going to be done succeeded. - * @param ptr a pointer to the memory that was allocated. - * @param amount the amount of memory that was allocated. - * @param blocking is this for a blocking allocate or a non-blocking one. + * + * @param ptr a pointer to the memory that was allocated. + * @param amount the amount of memory that was allocated. + * @param blocking is this for a blocking allocate or a non-blocking one. * @param wasRecursive the boolean that was returned from `preCpuAlloc`. */ public static void postCpuAllocSuccess(long ptr, long amount, boolean blocking, @@ -646,8 +670,9 @@ public static void postCpuAllocSuccess(long ptr, long amount, boolean blocking, /** * The allocation failed, and spilling didn't save it. - * @param wasOom was the failure caused by an OOM or something else. - * @param blocking is this for a blocking allocate or a non-blocking one. + * + * @param wasOom was the failure caused by an OOM or something else. + * @param blocking is this for a blocking allocate or a non-blocking one. * @param wasRecursive the boolean that was returned from `preCpuAlloc`. * @return true if the allocation should be retried else false if the state machine * thinks that a retry would not help. @@ -666,7 +691,8 @@ public static boolean postCpuAllocFailed(boolean wasOom, boolean blocking, boole /** * Some CPU memory was freed. - * @param ptr a pointer to the memory being deallocated. + * + * @param ptr a pointer to the memory being deallocated. * @param amount the amount that was made available. */ public static void cpuDeallocate(long ptr, long amount) { @@ -679,4 +705,10 @@ public static void cpuDeallocate(long ptr, long amount) { } } + public enum OomInjectionType { + CPU_OR_GPU, + CPU, + GPU + } + } diff --git a/src/main/java/com/nvidia/spark/rapids/jni/RowConversion.java b/src/main/java/com/nvidia/spark/rapids/jni/RowConversion.java index fdb7f40480..093643209a 100644 --- a/src/main/java/com/nvidia/spark/rapids/jni/RowConversion.java +++ b/src/main/java/com/nvidia/spark/rapids/jni/RowConversion.java @@ -16,9 +16,15 @@ package com.nvidia.spark.rapids.jni; -import ai.rapids.cudf.*; +import ai.rapids.cudf.ColumnVector; +import ai.rapids.cudf.ColumnView; +import ai.rapids.cudf.DType; +import ai.rapids.cudf.NativeDepsLoader; +import ai.rapids.cudf.Table; -/** Utility class for converting between column major and row major data */ +/** + * Utility class for converting between column major and row major data + */ public class RowConversion { static { NativeDepsLoader.loadNativeDeps(); @@ -27,7 +33,7 @@ public class RowConversion { /** * For details about how this method functions refer to * {@link #convertToRowsFixedWidthOptimized()}. - * + *

* The only thing different between this method and {@link #convertToRowsFixedWidthOptimized()} * is that this can handle rougly 250M columns while {@link #convertToRowsFixedWidthOptimized()} * can only handle columns less than 100 @@ -63,7 +69,7 @@ public static ColumnVector[] convertToRows(Table table) { * |row N+1 | validity for row N+1 | padding | * ... * - * + *

* The format of each row is similar in layout to a C struct where each column will have padding * in front of it to align it properly. Each row has padding inserted at the end so the next row * is aligned to a 64-bit boundary. This is so that the first column will always start at the @@ -79,8 +85,8 @@ public static ColumnVector[] convertToRows(Table table) { * | A - BOOL8 (8-bit) | B - INT16 (16-bit) | C - DURATION_DAYS (32-bit) | * *

- * Will have a layout that looks like - *

+   * Will have a layout that looks like
+   * 

    *  | A_0 | P | B_0 | B_1 | C_0 | C_1 | C_2 | C_3 | V0 | P | P | P | P | P | P | P |
    * 
*

@@ -127,14 +133,14 @@ public static ColumnVector[] convertToRowsFixedWidthOptimized(Table table) { /** * Convert a column of list of bytes that is formatted like the output from `convertToRows` * and convert it back to a table. - * + *

* NOTE: This method doesn't support nested types * - * @param vec the row data to process. + * @param vec the row data to process. * @param schema the types of each column. * @return the parsed table. */ - public static Table convertFromRows(ColumnView vec, DType ... schema) { + public static Table convertFromRows(ColumnView vec, DType... schema) { int[] types = new int[schema.length]; int[] scale = new int[schema.length]; for (int i = 0; i < schema.length; i++) { @@ -148,14 +154,14 @@ public static Table convertFromRows(ColumnView vec, DType ... schema) { /** * Convert a column of list of bytes that is formatted like the output from `convertToRows` * and convert it back to a table. - * + *

* NOTE: This method doesn't support nested types * - * @param vec the row data to process. + * @param vec the row data to process. * @param schema the types of each column. * @return the parsed table. */ - public static Table convertFromRowsFixedWidthOptimized(ColumnView vec, DType ... schema) { + public static Table convertFromRowsFixedWidthOptimized(ColumnView vec, DType... schema) { int[] types = new int[schema.length]; int[] scale = new int[schema.length]; for (int i = 0; i < schema.length; i++) { @@ -167,8 +173,11 @@ public static Table convertFromRowsFixedWidthOptimized(ColumnView vec, DType ... } private static native long[] convertToRows(long nativeHandle); + private static native long[] convertToRowsFixedWidthOptimized(long nativeHandle); private static native long[] convertFromRows(long nativeColumnView, int[] types, int[] scale); - private static native long[] convertFromRowsFixedWidthOptimized(long nativeColumnView, int[] types, int[] scale); + + private static native long[] convertFromRowsFixedWidthOptimized(long nativeColumnView, + int[] types, int[] scale); } diff --git a/src/main/java/com/nvidia/spark/rapids/jni/SparkResourceAdaptor.java b/src/main/java/com/nvidia/spark/rapids/jni/SparkResourceAdaptor.java index 9e3414f7d3..c4bbe23112 100644 --- a/src/main/java/com/nvidia/spark/rapids/jni/SparkResourceAdaptor.java +++ b/src/main/java/com/nvidia/spark/rapids/jni/SparkResourceAdaptor.java @@ -13,21 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.nvidia.spark.rapids.jni; -import com.nvidia.spark.rapids.jni.RmmSpark.OomInjectionType; +package com.nvidia.spark.rapids.jni; import ai.rapids.cudf.NativeDepsLoader; import ai.rapids.cudf.RmmDeviceMemoryResource; import ai.rapids.cudf.RmmEventHandlerResourceAdaptor; import ai.rapids.cudf.RmmWrappingDeviceMemoryResource; +import com.nvidia.spark.rapids.jni.RmmSpark.OomInjectionType; public class SparkResourceAdaptor - extends RmmWrappingDeviceMemoryResource> { - static { - NativeDepsLoader.loadNativeDeps(); - } - + extends + RmmWrappingDeviceMemoryResource> { /** * How long does the SparkResourceAdaptor pool thread states as a watchdog to break up potential * deadlocks. @@ -35,11 +32,16 @@ public class SparkResourceAdaptor private static final long pollingPeriod = Long.getLong( "ai.rapids.cudf.spark.rmmWatchdogPollingPeriod", 100); + static { + NativeDepsLoader.loadNativeDeps(); + } + private long handle = 0; - private Thread watchDog; + private final Thread watchDog; /** * Create a new tracking resource adaptor. + * * @param wrapped the memory resource to track allocations. This should not be reused. */ public SparkResourceAdaptor(RmmEventHandlerResourceAdaptor wrapped) { @@ -48,13 +50,14 @@ public SparkResourceAdaptor(RmmEventHandlerResourceAdaptor wrapped, - String logLoc) { + String logLoc) { super(wrapped); watchDog = new Thread(() -> { try { @@ -78,6 +81,71 @@ public SparkResourceAdaptor(RmmEventHandlerResourceAdaptor 0) { @@ -143,8 +212,9 @@ public void poolThreadFinishedForTasks(long threadId, long[] taskIds) { /** * Remove the given thread ID from any association. + * * @param threadId the ID of the thread that is no longer a part of a task or shuffle (not java thread id). - * @param taskId the task that is being removed. If the task id is -1, then any/all tasks are removed. + * @param taskId the task that is being removed. If the task id is -1, then any/all tasks are removed. */ public void removeThreadAssociation(long threadId, long taskId) { removeThreadAssociation(getHandle(), threadId, taskId); @@ -153,6 +223,7 @@ public void removeThreadAssociation(long threadId, long taskId) { /** * Indicate that a given task is done and if there are any threads still associated with it * then they should also be removed. + * * @param taskId the ID of the task that has completed. */ public void taskDone(long taskId) { @@ -161,6 +232,7 @@ public void taskDone(long taskId) { /** * A dedicated task thread is going to submit work to a pool. + * * @param threadId the ID of the thread that will submit the work. */ public void submittingToPool(long threadId) { @@ -169,6 +241,7 @@ public void submittingToPool(long threadId) { /** * A dedicated task thread is going to wait on work in a pool to complete. + * * @param threadId the ID of the thread that will submit the work. */ public void waitingOnPool(long threadId) { @@ -178,6 +251,7 @@ public void waitingOnPool(long threadId) { /** * A dedicated task thread is done waiting on a pool. This could be because of submitting * something to the pool or waiting on a result from the pool. + * * @param threadId the ID of the thread that is done. */ public void doneWaitingOnPool(long threadId) { @@ -186,9 +260,10 @@ public void doneWaitingOnPool(long threadId) { /** * Force the thread with the given ID to throw a GpuRetryOOM on their next allocation attempt. - * @param threadId the ID of the thread to throw the exception (not java thread id). - * @param numOOMs the number of times the GpuRetryOOM should be thrown - * @param oomMode ordinal of the corresponding RmmSpark.OomInjectionType + * + * @param threadId the ID of the thread to throw the exception (not java thread id). + * @param numOOMs the number of times the GpuRetryOOM should be thrown + * @param oomMode ordinal of the corresponding RmmSpark.OomInjectionType * @param skipCount the number of times a matching allocation is skipped before injecting the first OOM */ public void forceRetryOOM(long threadId, int numOOMs, int oomMode, int skipCount) { @@ -199,15 +274,16 @@ public void forceRetryOOM(long threadId, int numOOMs, int oomMode, int skipCount private void validateOOMInjectionParams(int numOOMs, int oomMode, int skipCount) { assert numOOMs >= 0 : "non-negative numOoms expected: actual=" + numOOMs; assert skipCount >= 0 : "non-negative skipCount expected: actual=" + skipCount; - assert oomMode >= 0 && oomMode < OomInjectionType.values().length: - "non-negative oomMode<" + OomInjectionType.values().length + " expected: actual=" + oomMode; + assert oomMode >= 0 && oomMode < OomInjectionType.values().length : + "non-negative oomMode<" + OomInjectionType.values().length + " expected: actual=" + oomMode; } /** * Force the thread with the given ID to throw a GpuSplitAndRetryOOM on their next allocation attempt. - * @param threadId the ID of the thread to throw the exception (not java thread id). - * @param numOOMs the number of times the GpuSplitAndRetryOOM should be thrown - * @param oomMode ordinal of the corresponding RmmSpark.OomInjectionType + * + * @param threadId the ID of the thread to throw the exception (not java thread id). + * @param numOOMs the number of times the GpuSplitAndRetryOOM should be thrown + * @param oomMode ordinal of the corresponding RmmSpark.OomInjectionType * @param skipCount the number of times a matching allocation is skipped before injecting the first OOM */ public void forceSplitAndRetryOOM(long threadId, int numOOMs, int oomMode, int skipCount) { @@ -217,6 +293,7 @@ public void forceSplitAndRetryOOM(long threadId, int numOOMs, int oomMode, int s /** * Force the thread with the given ID to throw a GpuSplitAndRetryOOM on their next allocation attempt. + * * @param threadId the ID of the thread to throw the exception (not java thread id). * @param numTimes the number of times the CudfException should be thrown */ @@ -255,11 +332,11 @@ public long getAndResetGpuMaxMemoryAllocated(long taskId) { return getAndResetGpuMaxMemoryAllocated(getHandle(), taskId); } - /** * Called before doing an allocation on the CPU. This could throw an injected exception to help * with testing. - * @param amount the amount of memory being requested + * + * @param amount the amount of memory being requested * @param blocking is this for a blocking allocate or a non-blocking one. */ public boolean preCpuAlloc(long amount, boolean blocking) { @@ -268,9 +345,10 @@ public boolean preCpuAlloc(long amount, boolean blocking) { /** * The allocation that was going to be done succeeded. - * @param ptr a pointer to the memory that was allocated. - * @param amount the amount of memory that was allocated. - * @param blocking is this for a blocking allocate or a non-blocking one. + * + * @param ptr a pointer to the memory that was allocated. + * @param amount the amount of memory that was allocated. + * @param blocking is this for a blocking allocate or a non-blocking one. * @param wasRecursive the result of calling preCpuAlloc. */ public void postCpuAllocSuccess(long ptr, long amount, boolean blocking, boolean wasRecursive) { @@ -279,8 +357,9 @@ public void postCpuAllocSuccess(long ptr, long amount, boolean blocking, boolean /** * The allocation failed, and spilling didn't save it. - * @param wasOom was the failure caused by an OOM or something else. - * @param blocking is this for a blocking allocate or a non-blocking one. + * + * @param wasOom was the failure caused by an OOM or something else. + * @param blocking is this for a blocking allocate or a non-blocking one. * @param wasRecursive the result of calling preCpuAlloc * @return true if the allocation should be retried else false if the state machine * thinks that a retry would not help. @@ -291,46 +370,11 @@ public boolean postCpuAllocFailed(boolean wasOom, boolean blocking, boolean wasR /** * Some CPU memory was freed. - * @param ptr a pointer to the memory being deallocated. + * + * @param ptr a pointer to the memory being deallocated. * @param amount the amount that was made available. */ public void cpuDeallocate(long ptr, long amount) { cpuDeallocate(getHandle(), ptr, amount); } - - /** - * Get the ID of the current thread that can be used with the other SparkResourceAdaptor APIs. - * Don't use the java thread ID. They are not related. - */ - public static native long getCurrentThreadId(); - - private native static long createNewAdaptor(long wrappedHandle, String logLoc); - private native static void releaseAdaptor(long handle); - private static native void startDedicatedTaskThread(long handle, long threadId, long taskId); - private static native void poolThreadWorkingOnTasks(long handle, boolean isForShuffle, long threadId, long[] taskIds); - private static native void poolThreadFinishedForTasks(long handle, long threadId, long[] taskIds); - private static native void removeThreadAssociation(long handle, long threadId, long taskId); - private static native void taskDone(long handle, long taskId); - private static native void submittingToPool(long handle, long threadId); - private static native void waitingOnPool(long handle, long threadId); - private static native void doneWaitingOnPool(long handle, long threadId); - private static native void forceRetryOOM(long handle, long threadId, int numOOMs, int oomMode, int skipCount); - private static native void forceSplitAndRetryOOM(long handle, long threadId, int numOOMs, int oomMode, int skipCount); - private static native void forceCudfException(long handle, long threadId, int numTimes); - private static native void blockThreadUntilReady(long handle); - private static native int getStateOf(long handle, long threadId); - private static native int getAndResetRetryThrowInternal(long handle, long taskId); - private static native int getAndResetSplitRetryThrowInternal(long handle, long taskId); - private static native long getAndResetBlockTimeInternal(long handle, long taskId); - private static native long getAndResetComputeTimeLostToRetry(long handle, long taskId); - private static native long getAndResetGpuMaxMemoryAllocated(long handle, long taskId); - private static native void startRetryBlock(long handle, long threadId); - private static native void endRetryBlock(long handle, long threadId); - private static native void checkAndBreakDeadlocks(long handle); - private static native boolean preCpuAlloc(long handle, long amount, boolean blocking); - private static native void postCpuAllocSuccess(long handle, long ptr, long amount, - boolean blocking, boolean wasRecursive); - private static native boolean postCpuAllocFailed(long handle, boolean wasOom, - boolean blocking, boolean wasRecursive); - private static native void cpuDeallocate(long handle, long ptr, long amount); } diff --git a/src/main/java/com/nvidia/spark/rapids/jni/ThreadStateRegistry.java b/src/main/java/com/nvidia/spark/rapids/jni/ThreadStateRegistry.java index 4e7021e6ea..4de2396566 100644 --- a/src/main/java/com/nvidia/spark/rapids/jni/ThreadStateRegistry.java +++ b/src/main/java/com/nvidia/spark/rapids/jni/ThreadStateRegistry.java @@ -16,12 +16,10 @@ package com.nvidia.spark.rapids.jni; +import java.util.HashMap; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import java.util.HashMap; -import java.util.HashSet; - /** * This is used to allow us to map a native thread id to a java thread so we can look at the * state from a java perspective. diff --git a/src/main/java/com/nvidia/spark/rapids/jni/ZOrder.java b/src/main/java/com/nvidia/spark/rapids/jni/ZOrder.java index 62c5137e46..4f44ed65fd 100644 --- a/src/main/java/com/nvidia/spark/rapids/jni/ZOrder.java +++ b/src/main/java/com/nvidia/spark/rapids/jni/ZOrder.java @@ -17,7 +17,6 @@ package com.nvidia.spark.rapids.jni; import ai.rapids.cudf.ColumnVector; -import ai.rapids.cudf.DType; import ai.rapids.cudf.NativeDepsLoader; import ai.rapids.cudf.Scalar; @@ -32,13 +31,14 @@ public class ZOrder { * the InterleaveBits expression in DeltaLake. The input data should all be the same type and all * fixed with types. In general if you want good clustering/ordering then these should all * be positive integer values. - * @param numRows the number of rows to output in a corner case where there are no input columns. - * This should never happen in practice, but the expression supports this so we - * should to. + * + * @param numRows the number of rows to output in a corner case where there are no input columns. + * This should never happen in practice, but the expression supports this so we + * should to. * @param inputColumns the data to process. * @return a binary column of the interleaved data. */ - public static ColumnVector interleaveBits(int numRows, ColumnVector ... inputColumns) { + public static ColumnVector interleaveBits(int numRows, ColumnVector... inputColumns) { if (inputColumns.length == 0) { try (ColumnVector empty = ColumnVector.fromUnsignedBytes(); Scalar emptyList = Scalar.listFromColumnView(empty)) { @@ -59,15 +59,15 @@ public static ColumnVector interleaveBits(int numRows, ColumnVector ... inputCol * cluster the data that databricks uses, and that is why we have it here. Please note that * this currently only supports indexes where numBits * inputColumns.length <= 64. * - * @param numBits the number of bits in the input columns to use. Typically, this is log2(max) - * for the values in all the inputColumns. - * @param numRows the number of rows. Used if inputColumns is empty. I think this is also a corner - * case that can never happen in practice, but I am just covering my bases here. - * a column of 0 is returned in this case. + * @param numBits the number of bits in the input columns to use. Typically, this is log2(max) + * for the values in all the inputColumns. + * @param numRows the number of rows. Used if inputColumns is empty. I think this is also a corner + * case that can never happen in practice, but I am just covering my bases here. + * a column of 0 is returned in this case. * @param inputColumns The columns to intermix. * @return the corresponding indexes stored as long values. */ - public static ColumnVector hilbertIndex(int numBits, int numRows, ColumnVector ... inputColumns) { + public static ColumnVector hilbertIndex(int numBits, int numRows, ColumnVector... inputColumns) { if (inputColumns.length == 0) { try (Scalar zero = Scalar.fromLong(0)) { return ColumnVector.fromScalar(zero, numRows); diff --git a/src/main/java/com/nvidia/spark/rapids/jni/kudo/ColumnOffsetInfo.java b/src/main/java/com/nvidia/spark/rapids/jni/kudo/ColumnOffsetInfo.java index 3a46806a78..e22e2c2e7d 100644 --- a/src/main/java/com/nvidia/spark/rapids/jni/kudo/ColumnOffsetInfo.java +++ b/src/main/java/com/nvidia/spark/rapids/jni/kudo/ColumnOffsetInfo.java @@ -16,10 +16,10 @@ package com.nvidia.spark.rapids.jni.kudo; -import ai.rapids.cudf.DeviceMemoryBufferView; - import static com.nvidia.spark.rapids.jni.Preconditions.ensureNonNegative; +import ai.rapids.cudf.DeviceMemoryBufferView; + /** * This class is used to store the offsets of the buffer of a column in the serialized data. */ @@ -32,7 +32,8 @@ class ColumnOffsetInfo { private final long data; private final long dataBufferLen; - public ColumnOffsetInfo(long validity, long validityBufferLen, long offset, long offsetBufferLen, long data, + public ColumnOffsetInfo(long validity, long validityBufferLen, long offset, long offsetBufferLen, + long data, long dataBufferLen) { ensureNonNegative(validityBufferLen, "validityBuffeLen"); ensureNonNegative(offsetBufferLen, "offsetBufferLen"); @@ -47,6 +48,7 @@ public ColumnOffsetInfo(long validity, long validityBufferLen, long offset, long /** * Get the validity buffer offset. + * * @return {@value #INVALID_OFFSET} if the validity buffer is not present, otherwise the offset. */ long getValidity() { @@ -55,6 +57,7 @@ long getValidity() { /** * Get a view of the validity buffer from underlying buffer. + * * @param baseAddress the base address of underlying buffer. * @return null if the validity buffer is not present, otherwise a view of the buffer. */ @@ -67,6 +70,7 @@ DeviceMemoryBufferView getValidityBuffer(long baseAddress) { /** * Get the offset buffer offset. + * * @return {@value #INVALID_OFFSET} if the offset buffer is not present, otherwise the offset. */ long getOffset() { @@ -75,6 +79,7 @@ long getOffset() { /** * Get a view of the offset buffer from underlying buffer. + * * @param baseAddress the base address of underlying buffer. * @return null if the offset buffer is not present, otherwise a view of the buffer. */ @@ -87,6 +92,7 @@ DeviceMemoryBufferView getOffsetBuffer(long baseAddress) { /** * Get the data buffer offset. + * * @return {@value #INVALID_OFFSET} if the data buffer is not present, otherwise the offset. */ long getData() { @@ -95,6 +101,7 @@ long getData() { /** * Get a view of the data buffer from underlying buffer. + * * @param baseAddress the base address of underlying buffer. * @return null if the data buffer is not present, otherwise a view of the buffer. */ diff --git a/src/main/java/com/nvidia/spark/rapids/jni/kudo/ColumnViewInfo.java b/src/main/java/com/nvidia/spark/rapids/jni/kudo/ColumnViewInfo.java index 002dff54c0..07f70664a1 100644 --- a/src/main/java/com/nvidia/spark/rapids/jni/kudo/ColumnViewInfo.java +++ b/src/main/java/com/nvidia/spark/rapids/jni/kudo/ColumnViewInfo.java @@ -16,12 +16,13 @@ package com.nvidia.spark.rapids.jni.kudo; -import ai.rapids.cudf.*; +import static com.nvidia.spark.rapids.jni.Preconditions.ensureNonNegative; +import ai.rapids.cudf.ColumnView; +import ai.rapids.cudf.DType; +import ai.rapids.cudf.DeviceMemoryBuffer; import java.util.Optional; -import static com.nvidia.spark.rapids.jni.Preconditions.ensureNonNegative; - class ColumnViewInfo { private final DType dtype; private final ColumnOffsetInfo offsetInfo; @@ -42,12 +43,12 @@ ColumnView buildColumnView(DeviceMemoryBuffer buffer, ColumnView[] childrenView) long baseAddress = buffer.getAddress(); if (dtype.isNestedType()) { - return new ColumnView(dtype, rowCount, Optional.of((long)nullCount), + return new ColumnView(dtype, rowCount, Optional.of((long) nullCount), offsetInfo.getValidityBuffer(baseAddress), offsetInfo.getOffsetBuffer(baseAddress), childrenView); } else { - return new ColumnView(dtype, rowCount, Optional.of((long)nullCount), + return new ColumnView(dtype, rowCount, Optional.of((long) nullCount), offsetInfo.getDataBuffer(baseAddress), offsetInfo.getValidityBuffer(baseAddress), offsetInfo.getOffsetBuffer(baseAddress)); diff --git a/src/main/java/com/nvidia/spark/rapids/jni/kudo/DataOutputStreamWriter.java b/src/main/java/com/nvidia/spark/rapids/jni/kudo/DataOutputStreamWriter.java index c88f125b2e..d7449969f8 100644 --- a/src/main/java/com/nvidia/spark/rapids/jni/kudo/DataOutputStreamWriter.java +++ b/src/main/java/com/nvidia/spark/rapids/jni/kudo/DataOutputStreamWriter.java @@ -17,7 +17,6 @@ package com.nvidia.spark.rapids.jni.kudo; import ai.rapids.cudf.HostMemoryBuffer; - import java.io.DataOutputStream; import java.io.IOException; diff --git a/src/main/java/com/nvidia/spark/rapids/jni/kudo/DataWriter.java b/src/main/java/com/nvidia/spark/rapids/jni/kudo/DataWriter.java index 1f2e8f3dca..4bd4be814f 100644 --- a/src/main/java/com/nvidia/spark/rapids/jni/kudo/DataWriter.java +++ b/src/main/java/com/nvidia/spark/rapids/jni/kudo/DataWriter.java @@ -17,7 +17,6 @@ package com.nvidia.spark.rapids.jni.kudo; import ai.rapids.cudf.HostMemoryBuffer; - import java.io.IOException; /** @@ -34,7 +33,8 @@ abstract class DataWriter { * @param srcOffset offset to start at. * @param len amount to copy. */ - public abstract void copyDataFrom(HostMemoryBuffer src, long srcOffset, long len) throws IOException; + public abstract void copyDataFrom(HostMemoryBuffer src, long srcOffset, long len) + throws IOException; public void flush() throws IOException { // NOOP by default diff --git a/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoHostMergeResult.java b/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoHostMergeResult.java index 6529f9e15e..38b5b12884 100644 --- a/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoHostMergeResult.java +++ b/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoHostMergeResult.java @@ -16,15 +16,17 @@ package com.nvidia.spark.rapids.jni.kudo; -import ai.rapids.cudf.*; -import com.nvidia.spark.rapids.jni.Arms; -import com.nvidia.spark.rapids.jni.schema.Visitors; - -import java.util.List; - import static com.nvidia.spark.rapids.jni.Preconditions.ensure; import static java.util.Objects.requireNonNull; +import ai.rapids.cudf.Cuda; +import ai.rapids.cudf.DeviceMemoryBuffer; +import ai.rapids.cudf.HostMemoryBuffer; +import ai.rapids.cudf.Schema; +import ai.rapids.cudf.Table; +import com.nvidia.spark.rapids.jni.schema.Visitors; +import java.util.List; + /** * The result of merging several kudo tables into one contiguous table on the host. */ @@ -33,11 +35,13 @@ public class KudoHostMergeResult implements AutoCloseable { private final List columnInfoList; private HostMemoryBuffer hostBuf; - KudoHostMergeResult(Schema schema, HostMemoryBuffer hostBuf, List columnInfoList) { + KudoHostMergeResult(Schema schema, HostMemoryBuffer hostBuf, + List columnInfoList) { requireNonNull(schema, "schema is null"); requireNonNull(columnInfoList, "columnInfoList is null"); ensure(schema.getFlattenedColumnNames().length == columnInfoList.size(), () -> - "Column offsets size does not match flattened schema size, column offsets size: " + columnInfoList.size() + + "Column offsets size does not match flattened schema size, column offsets size: " + + columnInfoList.size() + ", flattened schema size: " + schema.getFlattenedColumnNames().length); this.schema = schema; this.columnInfoList = columnInfoList; @@ -52,6 +56,7 @@ public void close() throws Exception { /** * Get the length of the data in the host buffer. + * * @return the length of the data in the host buffer */ public long getDataLength() { @@ -60,6 +65,7 @@ public long getDataLength() { /** * Convert the host buffer into a cudf table. + * * @return the cudf table */ public Table toTable() { diff --git a/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoSerializer.java b/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoSerializer.java index 6370531428..4135776363 100644 --- a/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoSerializer.java +++ b/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoSerializer.java @@ -16,20 +16,29 @@ package com.nvidia.spark.rapids.jni.kudo; -import ai.rapids.cudf.*; +import static com.nvidia.spark.rapids.jni.Preconditions.ensure; +import static java.util.Objects.requireNonNull; + +import ai.rapids.cudf.BufferType; +import ai.rapids.cudf.Cuda; +import ai.rapids.cudf.HostColumnVector; +import ai.rapids.cudf.JCudfSerialization; +import ai.rapids.cudf.Schema; +import ai.rapids.cudf.Table; import com.nvidia.spark.rapids.jni.Pair; import com.nvidia.spark.rapids.jni.schema.Visitors; - -import java.io.*; +import java.io.BufferedOutputStream; +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; import java.util.Arrays; import java.util.List; import java.util.function.LongConsumer; import java.util.function.Supplier; import java.util.stream.IntStream; -import static com.nvidia.spark.rapids.jni.Preconditions.ensure; -import static java.util.Objects.requireNonNull; - /** * This class is used to serialize/deserialize a table using the Kudo format. * @@ -148,8 +157,9 @@ public class KudoSerializer { private static final byte[] PADDING = new byte[64]; - private static final BufferType[] ALL_BUFFER_TYPES = new BufferType[]{BufferType.VALIDITY, BufferType.OFFSET, - BufferType.DATA}; + private static final BufferType[] ALL_BUFFER_TYPES = + new BufferType[] {BufferType.VALIDITY, BufferType.OFFSET, + BufferType.DATA}; static { Arrays.fill(PADDING, (byte) 0); @@ -164,6 +174,75 @@ public KudoSerializer(Schema schema) { this.flattenedColumnCount = schema.getFlattenedColumnNames().length; } + /** + * Write a row count only record to an output stream. + * + * @param out output stream + * @param numRows number of rows to write + * @return number of bytes written + */ + public static long writeRowCountToStream(OutputStream out, int numRows) { + if (numRows <= 0) { + throw new IllegalArgumentException("Number of rows must be > 0, but was " + numRows); + } + try { + DataWriter writer = writerFrom(out); + KudoTableHeader header = new KudoTableHeader(0, numRows, 0, 0, 0, 0, new byte[0]); + header.writeTo(writer); + writer.flush(); + return header.getSerializedSize(); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + private static DataWriter writerFrom(OutputStream out) { + if (!(out instanceof DataOutputStream)) { + out = new DataOutputStream(new BufferedOutputStream(out)); + } + return new DataOutputStreamWriter((DataOutputStream) out); + } + + static long padForHostAlignment(long orig) { + return ((orig + 3) / 4) * 4; + } + + static long padForHostAlignment(DataWriter out, long bytes) throws IOException { + final long paddedBytes = padForHostAlignment(bytes); + if (paddedBytes > bytes) { + out.write(PADDING, 0, (int) (paddedBytes - bytes)); + } + return paddedBytes; + } + + static long padFor64byteAlignment(long orig) { + return ((orig + 63) / 64) * 64; + } + + static DataInputStream readerFrom(InputStream in) { + if (in instanceof DataInputStream) { + return (DataInputStream) in; + } + return new DataInputStream(in); + } + + static T withTime(Supplier task, LongConsumer timeConsumer) { + long now = System.nanoTime(); + T ret = task.get(); + timeConsumer.accept(System.nanoTime() - now); + return ret; + } + + /** + * This method returns the length in bytes needed to represent X number of rows + * e.g. getValidityLengthInBytes(5) => 1 byte + * getValidityLengthInBytes(7) => 1 byte + * getValidityLengthInBytes(14) => 2 bytes + */ + static long getValidityLengthInBytes(long rows) { + return (rows + 7) / 8; + } + /** * Write partition of a table to a stream. This method is used for test only. *
@@ -208,7 +287,8 @@ long writeToStream(Table table, OutputStream out, int rowOffset, int numRows) { * @param numRows number of rows to write * @return number of bytes written */ - public long writeToStream(HostColumnVector[] columns, OutputStream out, int rowOffset, int numRows) { + public long writeToStream(HostColumnVector[] columns, OutputStream out, int rowOffset, + int numRows) { ensure(numRows > 0, () -> "numRows must be > 0, but was " + numRows); ensure(columns.length > 0, () -> "columns must not be empty, for row count only records " + "please call writeRowCountToStream"); @@ -220,29 +300,6 @@ public long writeToStream(HostColumnVector[] columns, OutputStream out, int rowO } } - /** - * Write a row count only record to an output stream. - * - * @param out output stream - * @param numRows number of rows to write - * @return number of bytes written - */ - public static long writeRowCountToStream(OutputStream out, int numRows) { - if (numRows <= 0) { - throw new IllegalArgumentException("Number of rows must be > 0, but was " + numRows); - } - try { - DataWriter writer = writerFrom(out); - KudoTableHeader header = new KudoTableHeader(0, numRows, 0, 0, 0 - , 0, new byte[0]); - header.writeTo(writer); - writer.flush(); - return header.getSerializedSize(); - } catch (Exception e) { - throw new RuntimeException(e); - } - } - /** * Merge a list of kudo tables into a table on host memory. *
@@ -286,15 +343,18 @@ public Pair mergeToTable(List kudoTables) throws } } - private long writeSliced(HostColumnVector[] columns, DataWriter out, int rowOffset, int numRows) throws Exception { - KudoTableHeaderCalc headerCalc = new KudoTableHeaderCalc(rowOffset, numRows, flattenedColumnCount); + private long writeSliced(HostColumnVector[] columns, DataWriter out, int rowOffset, int numRows) + throws Exception { + KudoTableHeaderCalc headerCalc = + new KudoTableHeaderCalc(rowOffset, numRows, flattenedColumnCount); Visitors.visitColumns(columns, headerCalc); KudoTableHeader header = headerCalc.getHeader(); header.writeTo(out); long bytesWritten = 0; for (BufferType bufferType : ALL_BUFFER_TYPES) { - SlicedBufferSerializer serializer = new SlicedBufferSerializer(rowOffset, numRows, bufferType, out); + SlicedBufferSerializer serializer = + new SlicedBufferSerializer(rowOffset, numRows, bufferType, out); Visitors.visitColumns(columns, serializer); bytesWritten += serializer.getTotalDataLen(); } @@ -309,52 +369,4 @@ private long writeSliced(HostColumnVector[] columns, DataWriter out, int rowOffs return header.getSerializedSize() + bytesWritten; } - - private static DataWriter writerFrom(OutputStream out) { - if (!(out instanceof DataOutputStream)) { - out = new DataOutputStream(new BufferedOutputStream(out)); - } - return new DataOutputStreamWriter((DataOutputStream) out); - } - - - static long padForHostAlignment(long orig) { - return ((orig + 3) / 4) * 4; - } - - static long padForHostAlignment(DataWriter out, long bytes) throws IOException { - final long paddedBytes = padForHostAlignment(bytes); - if (paddedBytes > bytes) { - out.write(PADDING, 0, (int) (paddedBytes - bytes)); - } - return paddedBytes; - } - - static long padFor64byteAlignment(long orig) { - return ((orig + 63) / 64) * 64; - } - - static DataInputStream readerFrom(InputStream in) { - if (in instanceof DataInputStream) { - return (DataInputStream) in; - } - return new DataInputStream(in); - } - - static T withTime(Supplier task, LongConsumer timeConsumer) { - long now = System.nanoTime(); - T ret = task.get(); - timeConsumer.accept(System.nanoTime() - now); - return ret; - } - - /** - * This method returns the length in bytes needed to represent X number of rows - * e.g. getValidityLengthInBytes(5) => 1 byte - * getValidityLengthInBytes(7) => 1 byte - * getValidityLengthInBytes(14) => 2 bytes - */ - static long getValidityLengthInBytes(long rows) { - return (rows + 7) / 8; - } } \ No newline at end of file diff --git a/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoTable.java b/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoTable.java index c49b2cb8f7..d5b86070f2 100644 --- a/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoTable.java +++ b/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoTable.java @@ -16,17 +16,16 @@ package com.nvidia.spark.rapids.jni.kudo; +import static com.nvidia.spark.rapids.jni.kudo.KudoSerializer.readerFrom; +import static java.util.Objects.requireNonNull; + import ai.rapids.cudf.HostMemoryBuffer; import com.nvidia.spark.rapids.jni.Arms; - import java.io.DataInputStream; import java.io.IOException; import java.io.InputStream; import java.util.Optional; -import static com.nvidia.spark.rapids.jni.kudo.KudoSerializer.readerFrom; -import static java.util.Objects.requireNonNull; - /** * Serialized table in kudo format, including a {{@link KudoTableHeader}} and a {@link HostMemoryBuffer} for serialized * data. @@ -65,14 +64,15 @@ public static Optional from(InputStream in) throws IOException { return new KudoTable(header, null); } - return Arms.closeIfException(HostMemoryBuffer.allocate(header.getTotalDataLen(), false), buffer -> { - try { - buffer.copyFromStream(0, din, header.getTotalDataLen()); - return new KudoTable(header, buffer); - } catch (IOException e) { - throw new RuntimeException(e); - } - }); + return Arms.closeIfException(HostMemoryBuffer.allocate(header.getTotalDataLen(), false), + buffer -> { + try { + buffer.copyFromStream(0, din, header.getTotalDataLen()); + return new KudoTable(header, buffer); + } catch (IOException e) { + throw new RuntimeException(e); + } + }); }); } diff --git a/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoTableHeader.java b/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoTableHeader.java index 2bf5449c7a..9537ea679d 100644 --- a/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoTableHeader.java +++ b/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoTableHeader.java @@ -16,16 +16,16 @@ package com.nvidia.spark.rapids.jni.kudo; +import static com.nvidia.spark.rapids.jni.Preconditions.ensure; +import static com.nvidia.spark.rapids.jni.Preconditions.ensureNonNegative; +import static java.util.Objects.requireNonNull; + import java.io.DataInputStream; import java.io.EOFException; import java.io.IOException; import java.util.Arrays; import java.util.Optional; -import static com.nvidia.spark.rapids.jni.Preconditions.ensure; -import static com.nvidia.spark.rapids.jni.Preconditions.ensureNonNegative; -import static java.util.Objects.requireNonNull; - /** * Holds the metadata about a serialized table. If this is being read from a stream * isInitialized will return true if the metadata was read correctly from the stream. @@ -49,6 +49,23 @@ public final class KudoTableHeader { // A bit set to indicate if a column has a validity buffer or not. Each column is represented by a single bit. private final byte[] hasValidityBuffer; + KudoTableHeader(int offset, int numRows, int validityBufferLen, int offsetBufferLen, + int totalDataLen, int numColumns, byte[] hasValidityBuffer) { + this.offset = ensureNonNegative(offset, "offset"); + this.numRows = ensureNonNegative(numRows, "numRows"); + this.validityBufferLen = ensureNonNegative(validityBufferLen, "validityBufferLen"); + this.offsetBufferLen = ensureNonNegative(offsetBufferLen, "offsetBufferLen"); + this.totalDataLen = ensureNonNegative(totalDataLen, "totalDataLen"); + this.numColumns = ensureNonNegative(numColumns, "numColumns"); + + requireNonNull(hasValidityBuffer, "hasValidityBuffer cannot be null"); + ensure(hasValidityBuffer.length == lengthOfHasValidityBuffer(numColumns), + () -> numColumns + " columns expects hasValidityBuffer with length " + + lengthOfHasValidityBuffer(numColumns) + + ", but found " + hasValidityBuffer.length); + this.hasValidityBuffer = hasValidityBuffer; + } + /** * Reads the table header from the given input stream. * @@ -61,8 +78,9 @@ public static Optional readFrom(DataInputStream din) throws IOE try { num = din.readInt(); if (num != SER_FORMAT_MAGIC_NUMBER) { - throw new IllegalStateException("Kudo format error, expected magic number " + SER_FORMAT_MAGIC_NUMBER + - " found " + num); + throw new IllegalStateException( + "Kudo format error, expected magic number " + SER_FORMAT_MAGIC_NUMBER + + " found " + num); } } catch (EOFException e) { // If we get an EOF at the very beginning don't treat it as an error because we may @@ -81,24 +99,14 @@ public static Optional readFrom(DataInputStream din) throws IOE byte[] hasValidityBuffer = new byte[validityBufferLength]; din.readFully(hasValidityBuffer); - return Optional.of(new KudoTableHeader(offset, numRows, validityBufferLen, offsetBufferLen, totalDataLen, numColumns, - hasValidityBuffer)); + return Optional.of( + new KudoTableHeader(offset, numRows, validityBufferLen, offsetBufferLen, totalDataLen, + numColumns, + hasValidityBuffer)); } - KudoTableHeader(int offset, int numRows, int validityBufferLen, int offsetBufferLen, - int totalDataLen, int numColumns, byte[] hasValidityBuffer) { - this.offset = ensureNonNegative(offset, "offset"); - this.numRows = ensureNonNegative(numRows, "numRows"); - this.validityBufferLen = ensureNonNegative(validityBufferLen, "validityBufferLen"); - this.offsetBufferLen = ensureNonNegative(offsetBufferLen, "offsetBufferLen"); - this.totalDataLen = ensureNonNegative(totalDataLen, "totalDataLen"); - this.numColumns = ensureNonNegative(numColumns, "numColumns"); - - requireNonNull(hasValidityBuffer, "hasValidityBuffer cannot be null"); - ensure(hasValidityBuffer.length == lengthOfHasValidityBuffer(numColumns), - () -> numColumns + " columns expects hasValidityBuffer with length " + lengthOfHasValidityBuffer(numColumns) + - ", but found " + hasValidityBuffer.length); - this.hasValidityBuffer = hasValidityBuffer; + private static int lengthOfHasValidityBuffer(int numColumns) { + return (numColumns + 7) / 8; } /** @@ -187,8 +195,4 @@ public String toString() { ", hasValidityBuffer=" + Arrays.toString(hasValidityBuffer) + '}'; } - - private static int lengthOfHasValidityBuffer(int numColumns) { - return (numColumns + 7) / 8; - } } diff --git a/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoTableHeaderCalc.java b/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoTableHeaderCalc.java index 4eaa1c435c..15119f3995 100644 --- a/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoTableHeaderCalc.java +++ b/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoTableHeaderCalc.java @@ -16,17 +16,16 @@ package com.nvidia.spark.rapids.jni.kudo; +import static com.nvidia.spark.rapids.jni.kudo.KudoSerializer.padForHostAlignment; +import static java.lang.Math.toIntExact; + import ai.rapids.cudf.DType; import ai.rapids.cudf.HostColumnVectorCore; import com.nvidia.spark.rapids.jni.schema.HostColumnsVisitor; - import java.util.ArrayDeque; import java.util.Deque; import java.util.List; -import static com.nvidia.spark.rapids.jni.kudo.KudoSerializer.padForHostAlignment; -import static java.lang.Math.toIntExact; - /** * This class visits a list of columns and calculates the serialized table header. * @@ -44,7 +43,7 @@ class KudoTableHeaderCalc implements HostColumnsVisitor { private long totalDataLen; private int nextColIdx; - private Deque sliceInfos = new ArrayDeque<>(); + private final Deque sliceInfos = new ArrayDeque<>(); KudoTableHeaderCalc(int rowOffset, int numRows, int numFlattenedCols) { this.root = new SliceInfo(rowOffset, numRows); @@ -55,6 +54,41 @@ class KudoTableHeaderCalc implements HostColumnsVisitor { this.nextColIdx = 0; } + private static long dataLenOfValidityBuffer(HostColumnVectorCore col, SliceInfo info) { + if (col.hasValidityVector() && info.getRowCount() > 0) { + return padForHostAlignment(info.getValidityBufferInfo().getBufferLength()); + } else { + return 0; + } + } + + private static long dataLenOfOffsetBuffer(HostColumnVectorCore col, SliceInfo info) { + if (DType.STRING.equals(col.getType()) && info.getRowCount() > 0) { + return padForHostAlignment((long) (info.rowCount + 1) * Integer.BYTES); + } else { + return 0; + } + } + + private static long dataLenOfDataBuffer(HostColumnVectorCore col, SliceInfo info) { + if (DType.STRING.equals(col.getType())) { + if (col.getOffsets() != null) { + long startByteOffset = col.getOffsets().getInt((long) info.offset * Integer.BYTES); + long endByteOffset = col.getOffsets().getInt( + (long) (info.offset + info.rowCount) * Integer.BYTES); + return padForHostAlignment(endByteOffset - startByteOffset); + } else { + return 0; + } + } else { + if (col.getType().getSizeInBytes() > 0) { + return padForHostAlignment((long) col.getType().getSizeInBytes() * info.rowCount); + } else { + return 0; + } + } + } + public KudoTableHeader getHeader() { return new KudoTableHeader(toIntExact(root.offset), toIntExact(root.rowCount), @@ -93,7 +127,7 @@ public Void preVisitList(HostColumnVectorCore col) { long offsetBufferLength = 0; if (col.getOffsets() != null && parent.rowCount > 0) { - offsetBufferLength = padForHostAlignment((parent.rowCount + 1) * Integer.BYTES); + offsetBufferLength = padForHostAlignment((long) (parent.rowCount + 1) * Integer.BYTES); } this.validityBufferLen += validityBufferLength; @@ -105,8 +139,8 @@ public Void preVisitList(HostColumnVectorCore col) { SliceInfo current; if (col.getOffsets() != null) { - int start = col.getOffsets().getInt(parent.offset * Integer.BYTES); - int end = col.getOffsets().getInt((parent.offset + parent.rowCount) * Integer.BYTES); + int start = col.getOffsets().getInt((long) parent.offset * Integer.BYTES); + int end = col.getOffsets().getInt((long) (parent.offset + parent.rowCount) * Integer.BYTES); int rowCount = end - start; current = new SliceInfo(start, rowCount); } else { @@ -124,7 +158,6 @@ public Void visitList(HostColumnVectorCore col, Void preVisitResult, Void childR return null; } - @Override public Void visit(HostColumnVectorCore col) { SliceInfo parent = sliceInfos.peekLast(); @@ -149,38 +182,4 @@ private void setHasValidity(boolean hasValidityBuffer) { } nextColIdx++; } - - private static long dataLenOfValidityBuffer(HostColumnVectorCore col, SliceInfo info) { - if (col.hasValidityVector() && info.getRowCount() > 0) { - return padForHostAlignment(info.getValidityBufferInfo().getBufferLength()); - } else { - return 0; - } - } - - private static long dataLenOfOffsetBuffer(HostColumnVectorCore col, SliceInfo info) { - if (DType.STRING.equals(col.getType()) && info.getRowCount() > 0) { - return padForHostAlignment((info.rowCount + 1) * Integer.BYTES); - } else { - return 0; - } - } - - private static long dataLenOfDataBuffer(HostColumnVectorCore col, SliceInfo info) { - if (DType.STRING.equals(col.getType())) { - if (col.getOffsets() != null) { - long startByteOffset = col.getOffsets().getInt(info.offset * Integer.BYTES); - long endByteOffset = col.getOffsets().getInt((info.offset + info.rowCount) * Integer.BYTES); - return padForHostAlignment(endByteOffset - startByteOffset); - } else { - return 0; - } - } else { - if (col.getType().getSizeInBytes() > 0) { - return padForHostAlignment(col.getType().getSizeInBytes() * info.rowCount); - } else { - return 0; - } - } - } } diff --git a/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoTableMerger.java b/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoTableMerger.java index af80391f3d..7f3b50e960 100644 --- a/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoTableMerger.java +++ b/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoTableMerger.java @@ -16,25 +16,24 @@ package com.nvidia.spark.rapids.jni.kudo; +import static com.nvidia.spark.rapids.jni.Preconditions.ensure; +import static com.nvidia.spark.rapids.jni.kudo.ColumnOffsetInfo.INVALID_OFFSET; +import static com.nvidia.spark.rapids.jni.kudo.KudoSerializer.getValidityLengthInBytes; +import static com.nvidia.spark.rapids.jni.kudo.KudoSerializer.padFor64byteAlignment; +import static java.lang.Math.min; +import static java.lang.Math.toIntExact; +import static java.util.Objects.requireNonNull; + import ai.rapids.cudf.HostMemoryBuffer; import ai.rapids.cudf.Schema; import com.nvidia.spark.rapids.jni.Arms; import com.nvidia.spark.rapids.jni.schema.Visitors; - import java.nio.ByteOrder; import java.nio.IntBuffer; import java.util.ArrayList; import java.util.List; import java.util.OptionalInt; -import static com.nvidia.spark.rapids.jni.Preconditions.ensure; -import static com.nvidia.spark.rapids.jni.kudo.ColumnOffsetInfo.INVALID_OFFSET; -import static com.nvidia.spark.rapids.jni.kudo.KudoSerializer.getValidityLengthInBytes; -import static com.nvidia.spark.rapids.jni.kudo.KudoSerializer.padFor64byteAlignment; -import static java.lang.Math.min; -import static java.lang.Math.toIntExact; -import static java.util.Objects.requireNonNull; - /** * This class is used to merge multiple KudoTables into a single contiguous buffer, e.g. {@link KudoHostMergeResult}, * which could be easily converted to a {@link ai.rapids.cudf.ContiguousTable}. @@ -59,7 +58,8 @@ class KudoTableMerger extends MultiKudoTableVisitor colViewInfoList; - public KudoTableMerger(List tables, HostMemoryBuffer buffer, List columnOffsets) { + public KudoTableMerger(List tables, HostMemoryBuffer buffer, + List columnOffsets) { super(tables); requireNonNull(buffer, "buffer can't be null!"); ensure(columnOffsets != null, "column offsets cannot be null"); @@ -69,83 +69,6 @@ public KudoTableMerger(List tables, HostMemoryBuffer buffer, List(columnOffsets.size()); } - @Override - protected KudoHostMergeResult doVisitTopSchema(Schema schema, List children) { - return new KudoHostMergeResult(schema, buffer, colViewInfoList); - } - - @Override - protected Void doVisitStruct(Schema structType, List children) { - ColumnOffsetInfo offsetInfo = getCurColumnOffsets(); - int nullCount = deserializeValidityBuffer(offsetInfo); - int totalRowCount = getTotalRowCount(); - colViewInfoList.add(new ColumnViewInfo(structType.getType(), - offsetInfo, nullCount, totalRowCount)); - return null; - } - - @Override - protected Void doPreVisitList(Schema listType) { - ColumnOffsetInfo offsetInfo = getCurColumnOffsets(); - int nullCount = deserializeValidityBuffer(offsetInfo); - int totalRowCount = getTotalRowCount(); - deserializeOffsetBuffer(offsetInfo); - - colViewInfoList.add(new ColumnViewInfo(listType.getType(), - offsetInfo, nullCount, totalRowCount)); - return null; - } - - @Override - protected Void doVisitList(Schema listType, Void preVisitResult, Void childResult) { - return null; - } - - @Override - protected Void doVisit(Schema primitiveType) { - ColumnOffsetInfo offsetInfo = getCurColumnOffsets(); - int nullCount = deserializeValidityBuffer(offsetInfo); - int totalRowCount = getTotalRowCount(); - if (primitiveType.getType().hasOffsets()) { - deserializeOffsetBuffer(offsetInfo); - deserializeDataBuffer(offsetInfo, OptionalInt.empty()); - } else { - deserializeDataBuffer(offsetInfo, OptionalInt.of(primitiveType.getType().getSizeInBytes())); - } - - colViewInfoList.add(new ColumnViewInfo(primitiveType.getType(), - offsetInfo, nullCount, totalRowCount)); - - return null; - } - - private int deserializeValidityBuffer(ColumnOffsetInfo curColOffset) { - if (curColOffset.getValidity() != INVALID_OFFSET) { - long offset = curColOffset.getValidity(); - long validityBufferSize = padFor64byteAlignment(getValidityLengthInBytes(getTotalRowCount())); - try (HostMemoryBuffer validityBuffer = buffer.slice(offset, validityBufferSize)) { - int nullCountTotal = 0; - int startRow = 0; - for (int tableIdx = 0; tableIdx < getTableSize(); tableIdx += 1) { - SliceInfo sliceInfo = sliceInfoOf(tableIdx); - long validityOffset = validifyBufferOffset(tableIdx); - if (validityOffset != INVALID_OFFSET) { - nullCountTotal += copyValidityBuffer(validityBuffer, startRow, - memoryBufferOf(tableIdx), toIntExact(validityOffset), - sliceInfo); - } else { - appendAllValid(validityBuffer, startRow, sliceInfo.getRowCount()); - } - - startRow += sliceInfo.getRowCount(); - } - return nullCountTotal; - } - } else { - return 0; - } - } - /** * Copy a sliced validity buffer to the destination buffer, starting at the given bit offset. * @@ -251,10 +174,97 @@ private static void appendAllValid(HostMemoryBuffer dest, int startBit, int numR } } + static KudoHostMergeResult merge(Schema schema, MergedInfoCalc mergedInfo) { + List serializedTables = mergedInfo.getTables(); + return Arms.closeIfException(HostMemoryBuffer.allocate(mergedInfo.getTotalDataLen()), + buffer -> { + KudoTableMerger merger = + new KudoTableMerger(serializedTables, buffer, mergedInfo.getColumnOffsets()); + return Visitors.visitSchema(schema, merger); + }); + } + + @Override + protected KudoHostMergeResult doVisitTopSchema(Schema schema, List children) { + return new KudoHostMergeResult(schema, buffer, colViewInfoList); + } + + @Override + protected Void doVisitStruct(Schema structType, List children) { + ColumnOffsetInfo offsetInfo = getCurColumnOffsets(); + int nullCount = deserializeValidityBuffer(offsetInfo); + int totalRowCount = getTotalRowCount(); + colViewInfoList.add(new ColumnViewInfo(structType.getType(), + offsetInfo, nullCount, totalRowCount)); + return null; + } + + @Override + protected Void doPreVisitList(Schema listType) { + ColumnOffsetInfo offsetInfo = getCurColumnOffsets(); + int nullCount = deserializeValidityBuffer(offsetInfo); + int totalRowCount = getTotalRowCount(); + deserializeOffsetBuffer(offsetInfo); + + colViewInfoList.add(new ColumnViewInfo(listType.getType(), + offsetInfo, nullCount, totalRowCount)); + return null; + } + + @Override + protected Void doVisitList(Schema listType, Void preVisitResult, Void childResult) { + return null; + } + + @Override + protected Void doVisit(Schema primitiveType) { + ColumnOffsetInfo offsetInfo = getCurColumnOffsets(); + int nullCount = deserializeValidityBuffer(offsetInfo); + int totalRowCount = getTotalRowCount(); + if (primitiveType.getType().hasOffsets()) { + deserializeOffsetBuffer(offsetInfo); + deserializeDataBuffer(offsetInfo, OptionalInt.empty()); + } else { + deserializeDataBuffer(offsetInfo, OptionalInt.of(primitiveType.getType().getSizeInBytes())); + } + + colViewInfoList.add(new ColumnViewInfo(primitiveType.getType(), + offsetInfo, nullCount, totalRowCount)); + + return null; + } + + private int deserializeValidityBuffer(ColumnOffsetInfo curColOffset) { + if (curColOffset.getValidity() != INVALID_OFFSET) { + long offset = curColOffset.getValidity(); + long validityBufferSize = padFor64byteAlignment(getValidityLengthInBytes(getTotalRowCount())); + try (HostMemoryBuffer validityBuffer = buffer.slice(offset, validityBufferSize)) { + int nullCountTotal = 0; + int startRow = 0; + for (int tableIdx = 0; tableIdx < getTableSize(); tableIdx += 1) { + SliceInfo sliceInfo = sliceInfoOf(tableIdx); + long validityOffset = validifyBufferOffset(tableIdx); + if (validityOffset != INVALID_OFFSET) { + nullCountTotal += copyValidityBuffer(validityBuffer, startRow, + memoryBufferOf(tableIdx), toIntExact(validityOffset), + sliceInfo); + } else { + appendAllValid(validityBuffer, startRow, sliceInfo.getRowCount()); + } + + startRow += sliceInfo.getRowCount(); + } + return nullCountTotal; + } + } else { + return 0; + } + } + private void deserializeOffsetBuffer(ColumnOffsetInfo curColOffset) { if (curColOffset.getOffset() != INVALID_OFFSET) { long offset = curColOffset.getOffset(); - long bufferSize = Integer.BYTES * (getTotalRowCount() + 1); + long bufferSize = (long) Integer.BYTES * (getTotalRowCount() + 1); IntBuffer buf = buffer .asByteBuffer(offset, toIntExact(bufferSize)) @@ -298,7 +308,7 @@ private void deserializeDataBuffer(ColumnOffsetInfo curColOffset, OptionalInt si for (int tableIdx = 0; tableIdx < getTableSize(); tableIdx += 1) { SliceInfo sliceInfo = sliceInfoOf(tableIdx); if (sliceInfo.getRowCount() > 0) { - int thisDataLen = toIntExact(elementSize * sliceInfo.getRowCount()); + int thisDataLen = toIntExact((long) elementSize * sliceInfo.getRowCount()); copyDataBuffer(buf, start, tableIdx, thisDataLen); start += thisDataLen; } @@ -316,17 +326,7 @@ private void deserializeDataBuffer(ColumnOffsetInfo curColOffset, OptionalInt si } } - private ColumnOffsetInfo getCurColumnOffsets() { return columnOffsets.get(getCurrentIdx()); } - - static KudoHostMergeResult merge(Schema schema, MergedInfoCalc mergedInfo) { - List serializedTables = mergedInfo.getTables(); - return Arms.closeIfException(HostMemoryBuffer.allocate(mergedInfo.getTotalDataLen()), - buffer -> { - KudoTableMerger merger = new KudoTableMerger(serializedTables, buffer, mergedInfo.getColumnOffsets()); - return Visitors.visitSchema(schema, merger); - }); - } } diff --git a/src/main/java/com/nvidia/spark/rapids/jni/kudo/MergeMetrics.java b/src/main/java/com/nvidia/spark/rapids/jni/kudo/MergeMetrics.java index e621129dd6..8a04face16 100644 --- a/src/main/java/com/nvidia/spark/rapids/jni/kudo/MergeMetrics.java +++ b/src/main/java/com/nvidia/spark/rapids/jni/kudo/MergeMetrics.java @@ -31,18 +31,6 @@ public MergeMetrics(long calcHeaderTime, long mergeIntoHostBufferTime, this.convertToTableTime = convertToTableTime; } - public long getCalcHeaderTime() { - return calcHeaderTime; - } - - public long getMergeIntoHostBufferTime() { - return mergeIntoHostBufferTime; - } - - public long getConvertToTableTime() { - return convertToTableTime; - } - public static Builder builder() { return new Builder(); } @@ -54,6 +42,17 @@ public static Builder builder(MergeMetrics metrics) { .convertToTableTime(metrics.convertToTableTime); } + public long getCalcHeaderTime() { + return calcHeaderTime; + } + + public long getMergeIntoHostBufferTime() { + return mergeIntoHostBufferTime; + } + + public long getConvertToTableTime() { + return convertToTableTime; + } public static class Builder { private long calcHeaderTime; diff --git a/src/main/java/com/nvidia/spark/rapids/jni/kudo/MergedInfoCalc.java b/src/main/java/com/nvidia/spark/rapids/jni/kudo/MergedInfoCalc.java index 826ef2e691..fdbeaed2fb 100644 --- a/src/main/java/com/nvidia/spark/rapids/jni/kudo/MergedInfoCalc.java +++ b/src/main/java/com/nvidia/spark/rapids/jni/kudo/MergedInfoCalc.java @@ -16,26 +16,25 @@ package com.nvidia.spark.rapids.jni.kudo; +import static com.nvidia.spark.rapids.jni.kudo.ColumnOffsetInfo.INVALID_OFFSET; +import static com.nvidia.spark.rapids.jni.kudo.KudoSerializer.getValidityLengthInBytes; +import static com.nvidia.spark.rapids.jni.kudo.KudoSerializer.padFor64byteAlignment; + import ai.rapids.cudf.Schema; import com.nvidia.spark.rapids.jni.schema.Visitors; - import java.util.ArrayList; import java.util.Collections; import java.util.List; -import static com.nvidia.spark.rapids.jni.kudo.ColumnOffsetInfo.INVALID_OFFSET; -import static com.nvidia.spark.rapids.jni.kudo.KudoSerializer.getValidityLengthInBytes; -import static com.nvidia.spark.rapids.jni.kudo.KudoSerializer.padFor64byteAlignment; - /** * This class is used to calculate column offsets of merged buffer. */ class MergedInfoCalc extends MultiKudoTableVisitor { - // Total data len in gpu, which accounts for 64 byte alignment - private long totalDataLen; // Column offset in gpu device buffer, it has one field for each flattened column private final List columnOffsets; + // Total data len in gpu, which accounts for 64 byte alignment + private long totalDataLen; public MergedInfoCalc(List tables) { super(tables); @@ -43,6 +42,12 @@ public MergedInfoCalc(List tables) { this.columnOffsets = new ArrayList<>(tables.get(0).getHeader().getNumColumns()); } + static MergedInfoCalc calc(Schema schema, List table) { + MergedInfoCalc calc = new MergedInfoCalc(table); + Visitors.visitSchema(schema, calc); + return calc; + } + @Override protected Void doVisitTopSchema(Schema schema, List children) { return null; @@ -58,7 +63,9 @@ protected Void doVisitStruct(Schema structType, List children) { totalDataLen += validityBufferLen; } - columnOffsets.add(new ColumnOffsetInfo(validityOffset, validityBufferLen, INVALID_OFFSET, 0, INVALID_OFFSET, 0)); + columnOffsets.add( + new ColumnOffsetInfo(validityOffset, validityBufferLen, INVALID_OFFSET, 0, INVALID_OFFSET, + 0)); return null; } @@ -75,13 +82,15 @@ protected Void doPreVisitList(Schema listType) { long offsetBufferLen = 0; long offsetBufferOffset = INVALID_OFFSET; if (getTotalRowCount() > 0) { - offsetBufferLen = padFor64byteAlignment((getTotalRowCount() + 1) * Integer.BYTES); + offsetBufferLen = padFor64byteAlignment((long) (getTotalRowCount() + 1) * Integer.BYTES); offsetBufferOffset = totalDataLen; totalDataLen += offsetBufferLen; } - columnOffsets.add(new ColumnOffsetInfo(validityOffset, validityBufferLen, offsetBufferOffset, offsetBufferLen, INVALID_OFFSET, 0)); + columnOffsets.add( + new ColumnOffsetInfo(validityOffset, validityBufferLen, offsetBufferOffset, offsetBufferLen, + INVALID_OFFSET, 0)); return null; } @@ -105,7 +114,7 @@ protected Void doVisit(Schema primitiveType) { long offsetBufferLen = 0; long offsetBufferOffset = INVALID_OFFSET; if (getTotalRowCount() > 0) { - offsetBufferLen = padFor64byteAlignment((getTotalRowCount() + 1) * Integer.BYTES); + offsetBufferLen = padFor64byteAlignment((long) (getTotalRowCount() + 1) * Integer.BYTES); offsetBufferOffset = totalDataLen; totalDataLen += offsetBufferLen; } @@ -118,7 +127,8 @@ protected Void doVisit(Schema primitiveType) { totalDataLen += dataBufferLen; } - columnOffsets.add(new ColumnOffsetInfo(validityOffset, validityBufferLen, offsetBufferOffset, offsetBufferLen, dataBufferOffset, dataBufferLen)); + columnOffsets.add(new ColumnOffsetInfo(validityOffset, validityBufferLen, offsetBufferOffset, + offsetBufferLen, dataBufferOffset, dataBufferLen)); } else { long totalRowCount = getTotalRowCount(); long validityBufferLen = 0; @@ -132,18 +142,19 @@ protected Void doVisit(Schema primitiveType) { long dataBufferLen = 0; long dataBufferOffset = INVALID_OFFSET; if (totalRowCount > 0) { - dataBufferLen = padFor64byteAlignment(totalRowCount * primitiveType.getType().getSizeInBytes()); + dataBufferLen = + padFor64byteAlignment(totalRowCount * primitiveType.getType().getSizeInBytes()); dataBufferOffset = totalDataLen; totalDataLen += dataBufferLen; } - columnOffsets.add(new ColumnOffsetInfo(validityOffset, validityBufferLen, INVALID_OFFSET, 0, dataBufferOffset, dataBufferLen)); + columnOffsets.add(new ColumnOffsetInfo(validityOffset, validityBufferLen, INVALID_OFFSET, 0, + dataBufferOffset, dataBufferLen)); } return null; } - public long getTotalDataLen() { return totalDataLen; } @@ -159,10 +170,4 @@ public String toString() { ", columnOffsets=" + columnOffsets + '}'; } - - static MergedInfoCalc calc(Schema schema, List table) { - MergedInfoCalc calc = new MergedInfoCalc(table); - Visitors.visitSchema(schema, calc); - return calc; - } } diff --git a/src/main/java/com/nvidia/spark/rapids/jni/kudo/MultiKudoTableVisitor.java b/src/main/java/com/nvidia/spark/rapids/jni/kudo/MultiKudoTableVisitor.java index afa7ba6ea0..bc5678cc50 100644 --- a/src/main/java/com/nvidia/spark/rapids/jni/kudo/MultiKudoTableVisitor.java +++ b/src/main/java/com/nvidia/spark/rapids/jni/kudo/MultiKudoTableVisitor.java @@ -16,17 +16,20 @@ package com.nvidia.spark.rapids.jni.kudo; -import ai.rapids.cudf.HostMemoryBuffer; -import ai.rapids.cudf.Schema; -import com.nvidia.spark.rapids.jni.schema.SchemaVisitor; - -import java.util.*; - import static com.nvidia.spark.rapids.jni.Preconditions.ensure; import static com.nvidia.spark.rapids.jni.kudo.ColumnOffsetInfo.INVALID_OFFSET; import static com.nvidia.spark.rapids.jni.kudo.KudoSerializer.padForHostAlignment; import static java.lang.Math.toIntExact; +import ai.rapids.cudf.HostMemoryBuffer; +import ai.rapids.cudf.Schema; +import com.nvidia.spark.rapids.jni.schema.SchemaVisitor; +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.Deque; +import java.util.List; +import java.util.Objects; + /** * This class provides a base class for visiting multiple kudo tables, e.g. it helps to maintain internal states during * visiting multi kudo tables, which makes it easier to do some calculations based on them. @@ -40,11 +43,11 @@ abstract class MultiKudoTableVisitor implements SchemaVisitor private final long[] currentDataOffset; private final Deque[] sliceInfoStack; private final Deque totalRowCountStack; + // Temporary buffer to store data length of string column to avoid repeated allocation + private final int[] strDataLen; // A temporary variable to keep if current column has null private boolean hasNull; private int currentIdx; - // Temporary buffer to store data length of string column to avoid repeated allocation - private final int[] strDataLen; // Temporary variable to calculate total data length of string column private long totalStrDataLen; @@ -187,7 +190,8 @@ private void updateDataLen() { } } - private void updateOffsets(boolean updateOffset, boolean updateData, boolean updateSliceInfo, int sizeInBytes) { + private void updateOffsets(boolean updateOffset, boolean updateData, boolean updateSliceInfo, + int sizeInBytes) { long totalRowCount = 0; for (int tableIdx = 0; tableIdx < tables.size(); tableIdx++) { SliceInfo sliceInfo = sliceInfoOf(tableIdx); @@ -202,11 +206,13 @@ private void updateOffsets(boolean updateOffset, boolean updateData, boolean upd } if (tables.get(tableIdx).getHeader().hasValidityBuffer(currentIdx)) { - currentValidityOffsets[tableIdx] += padForHostAlignment(sliceInfo.getValidityBufferInfo().getBufferLength()); + currentValidityOffsets[tableIdx] += + padForHostAlignment(sliceInfo.getValidityBufferInfo().getBufferLength()); } if (updateOffset) { - currentOffsetOffsets[tableIdx] += padForHostAlignment((sliceInfo.getRowCount() + 1) * Integer.BYTES); + currentOffsetOffsets[tableIdx] += + padForHostAlignment((long) (sliceInfo.getRowCount() + 1) * Integer.BYTES); if (updateData) { // string type currentDataOffset[tableIdx] += padForHostAlignment(strDataLen[tableIdx]); @@ -215,7 +221,8 @@ private void updateOffsets(boolean updateOffset, boolean updateData, boolean upd } else { if (updateData) { // primitive type - currentDataOffset[tableIdx] += padForHostAlignment(sliceInfo.getRowCount() * sizeInBytes); + currentDataOffset[tableIdx] += + padForHostAlignment((long) sliceInfo.getRowCount() * sizeInBytes); } } diff --git a/src/main/java/com/nvidia/spark/rapids/jni/kudo/SlicedBufferSerializer.java b/src/main/java/com/nvidia/spark/rapids/jni/kudo/SlicedBufferSerializer.java index e22a523855..81cf1557c5 100644 --- a/src/main/java/com/nvidia/spark/rapids/jni/kudo/SlicedBufferSerializer.java +++ b/src/main/java/com/nvidia/spark/rapids/jni/kudo/SlicedBufferSerializer.java @@ -16,19 +16,18 @@ package com.nvidia.spark.rapids.jni.kudo; +import static com.nvidia.spark.rapids.jni.kudo.KudoSerializer.padForHostAlignment; + import ai.rapids.cudf.BufferType; import ai.rapids.cudf.DType; import ai.rapids.cudf.HostColumnVectorCore; import ai.rapids.cudf.HostMemoryBuffer; import com.nvidia.spark.rapids.jni.schema.HostColumnsVisitor; - import java.io.IOException; import java.util.ArrayDeque; import java.util.Deque; import java.util.List; -import static com.nvidia.spark.rapids.jni.kudo.KudoSerializer.padForHostAlignment; - /** * This class visits a list of columns and serialize one of the buffers (validity, offset, or data) into with kudo * format. @@ -110,8 +109,8 @@ public Void preVisitList(HostColumnVectorCore col) { SliceInfo current; if (col.getOffsets() != null) { int start = col.getOffsets() - .getInt(parent.offset * Integer.BYTES); - int end = col.getOffsets().getInt((parent.offset + parent.rowCount) * Integer.BYTES); + .getInt((long) parent.offset * Integer.BYTES); + int end = col.getOffsets().getInt((long) (parent.offset + parent.rowCount) * Integer.BYTES); int rowCount = end - start; current = new SliceInfo(start, rowCount); @@ -153,7 +152,8 @@ public Void visit(HostColumnVectorCore col) { } } - private long copySlicedValidity(HostColumnVectorCore column, SliceInfo sliceInfo) throws IOException { + private long copySlicedValidity(HostColumnVectorCore column, SliceInfo sliceInfo) + throws IOException { if (column.getValidity() != null && sliceInfo.getRowCount() > 0) { HostMemoryBuffer buff = column.getValidity(); long len = sliceInfo.getValidityBufferInfo().getBufferLength(); @@ -165,13 +165,14 @@ private long copySlicedValidity(HostColumnVectorCore column, SliceInfo sliceInfo } } - private long copySlicedOffset(HostColumnVectorCore column, SliceInfo sliceInfo) throws IOException { + private long copySlicedOffset(HostColumnVectorCore column, SliceInfo sliceInfo) + throws IOException { if (sliceInfo.rowCount <= 0 || column.getOffsets() == null) { // Don't copy anything, there are no rows return 0; } - long bytesToCopy = (sliceInfo.rowCount + 1) * Integer.BYTES; - long srcOffset = sliceInfo.offset * Integer.BYTES; + long bytesToCopy = (long) (sliceInfo.rowCount + 1) * Integer.BYTES; + long srcOffset = (long) sliceInfo.offset * Integer.BYTES; HostMemoryBuffer buff = column.getOffsets(); writer.copyDataFrom(buff, srcOffset, bytesToCopy); return padForHostAlignment(writer, bytesToCopy); @@ -181,8 +182,10 @@ private long copySlicedData(HostColumnVectorCore column, SliceInfo sliceInfo) th if (sliceInfo.rowCount > 0) { DType type = column.getType(); if (type.equals(DType.STRING)) { - long startByteOffset = column.getOffsets().getInt(sliceInfo.offset * Integer.BYTES); - long endByteOffset = column.getOffsets().getInt((sliceInfo.offset + sliceInfo.rowCount) * Integer.BYTES); + long startByteOffset = column.getOffsets().getInt((long) sliceInfo.offset * Integer.BYTES); + long endByteOffset = + column.getOffsets().getInt( + (long) (sliceInfo.offset + sliceInfo.rowCount) * Integer.BYTES); long bytesToCopy = endByteOffset - startByteOffset; if (column.getData() == null) { if (bytesToCopy != 0) { @@ -196,8 +199,8 @@ private long copySlicedData(HostColumnVectorCore column, SliceInfo sliceInfo) th return padForHostAlignment(writer, bytesToCopy); } } else if (type.getSizeInBytes() > 0) { - long bytesToCopy = sliceInfo.rowCount * type.getSizeInBytes(); - long srcOffset = sliceInfo.offset * type.getSizeInBytes(); + long bytesToCopy = (long) sliceInfo.rowCount * type.getSizeInBytes(); + long srcOffset = (long) sliceInfo.offset * type.getSizeInBytes(); writer.copyDataFrom(column.getData(), srcOffset, bytesToCopy); return padForHostAlignment(writer, bytesToCopy); } else { diff --git a/src/main/java/com/nvidia/spark/rapids/jni/kudo/SlicedValidityBufferInfo.java b/src/main/java/com/nvidia/spark/rapids/jni/kudo/SlicedValidityBufferInfo.java index 7c9957f5b2..7e2e74d502 100644 --- a/src/main/java/com/nvidia/spark/rapids/jni/kudo/SlicedValidityBufferInfo.java +++ b/src/main/java/com/nvidia/spark/rapids/jni/kudo/SlicedValidityBufferInfo.java @@ -31,9 +31,26 @@ class SlicedValidityBufferInfo { this.beginBit = beginBit; } + static SlicedValidityBufferInfo calc(int rowOffset, int numRows) { + if (rowOffset < 0) { + throw new IllegalArgumentException("rowOffset must be >= 0, but was " + rowOffset); + } + if (numRows < 0) { + throw new IllegalArgumentException("numRows must be >= 0, but was " + numRows); + } + int bufferOffset = rowOffset / 8; + int beginBit = rowOffset % 8; + int bufferLength = 0; + if (numRows > 0) { + bufferLength = (rowOffset + numRows - 1) / 8 - bufferOffset + 1; + } + return new SlicedValidityBufferInfo(bufferOffset, bufferLength, beginBit); + } + @Override public String toString() { - return "SlicedValidityBufferInfo{" + "bufferOffset=" + bufferOffset + ", bufferLength=" + bufferLength + + return "SlicedValidityBufferInfo{" + "bufferOffset=" + bufferOffset + ", bufferLength=" + + bufferLength + ", beginBit=" + beginBit + '}'; } @@ -48,20 +65,4 @@ public int getBufferLength() { public int getBeginBit() { return beginBit; } - - static SlicedValidityBufferInfo calc(int rowOffset, int numRows) { - if (rowOffset < 0) { - throw new IllegalArgumentException("rowOffset must be >= 0, but was " + rowOffset); - } - if (numRows < 0) { - throw new IllegalArgumentException("numRows must be >= 0, but was " + numRows); - } - int bufferOffset = rowOffset / 8; - int beginBit = rowOffset % 8; - int bufferLength = 0; - if (numRows > 0) { - bufferLength = (rowOffset + numRows - 1) / 8 - bufferOffset + 1; - } - return new SlicedValidityBufferInfo(bufferOffset, bufferLength, beginBit); - } } diff --git a/src/main/java/com/nvidia/spark/rapids/jni/kudo/TableBuilder.java b/src/main/java/com/nvidia/spark/rapids/jni/kudo/TableBuilder.java index e50e462f4f..ded99ab64f 100644 --- a/src/main/java/com/nvidia/spark/rapids/jni/kudo/TableBuilder.java +++ b/src/main/java/com/nvidia/spark/rapids/jni/kudo/TableBuilder.java @@ -16,24 +16,28 @@ package com.nvidia.spark.rapids.jni.kudo; -import ai.rapids.cudf.*; +import static com.nvidia.spark.rapids.jni.Preconditions.ensure; +import static java.util.Objects.requireNonNull; + +import ai.rapids.cudf.CloseableArray; +import ai.rapids.cudf.ColumnVector; +import ai.rapids.cudf.ColumnView; +import ai.rapids.cudf.DeviceMemoryBuffer; +import ai.rapids.cudf.Schema; +import ai.rapids.cudf.Table; import com.nvidia.spark.rapids.jni.Arms; import com.nvidia.spark.rapids.jni.schema.SchemaVisitor; - import java.util.ArrayList; import java.util.List; -import static com.nvidia.spark.rapids.jni.Preconditions.ensure; -import static java.util.Objects.requireNonNull; - /** * This class is used to build a cudf table from a list of column view info, and a device buffer. */ class TableBuilder implements SchemaVisitor, AutoCloseable { - private int curColumnIdx; private final DeviceMemoryBuffer buffer; private final List colViewInfoList; private final List columnViewList; + private int curColumnIdx; public TableBuilder(List colViewInfoList, DeviceMemoryBuffer buffer) { requireNonNull(colViewInfoList, "colViewInfoList cannot be null"); @@ -52,10 +56,12 @@ public Table visitTopSchema(Schema schema, List children) { // `children`, so we need to clear `columnViewList`. this.columnViewList.clear(); try { - try (CloseableArray arr = CloseableArray.wrap(new ColumnVector[children.size()])) { + try (CloseableArray arr = CloseableArray.wrap( + new ColumnVector[children.size()])) { for (int i = 0; i < children.size(); i++) { ColumnView colView = children.set(i, null); - arr.set(i, ColumnVector.fromViewWithContiguousAllocation(colView.getNativeView(), buffer)); + arr.set(i, + ColumnVector.fromViewWithContiguousAllocation(colView.getNativeView(), buffer)); } return new Table(arr.getArray()); @@ -87,7 +93,7 @@ public ColumnViewInfo preVisitList(Schema listType) { @Override public ColumnView visitList(Schema listType, ColumnViewInfo colViewInfo, ColumnView childResult) { - ColumnView[] children = new ColumnView[]{childResult}; + ColumnView[] children = new ColumnView[] {childResult}; ColumnView view = colViewInfo.buildColumnView(buffer, children); columnViewList.add(view); diff --git a/src/main/java/com/nvidia/spark/rapids/jni/schema/HostColumnsVisitor.java b/src/main/java/com/nvidia/spark/rapids/jni/schema/HostColumnsVisitor.java index ae7915c60d..6b9664607f 100644 --- a/src/main/java/com/nvidia/spark/rapids/jni/schema/HostColumnsVisitor.java +++ b/src/main/java/com/nvidia/spark/rapids/jni/schema/HostColumnsVisitor.java @@ -19,14 +19,13 @@ package com.nvidia.spark.rapids.jni.schema; import ai.rapids.cudf.HostColumnVectorCore; - import java.util.List; /** * A post order visitor for visiting a list of host columns in a schema. * *

- * + *

* For example, if we have three columns A, B, and C with following types: * *

    @@ -34,7 +33,7 @@ *
  • B: list { int b1}
  • *
  • C: string c1
  • *
- * + *

* The order of visiting will be: *

    *
  1. Visit primitive column a1
  2. @@ -51,34 +50,38 @@ * @param Return type when visiting intermediate nodes. */ public interface HostColumnsVisitor { - /** - * Visit a struct column. - * @param col the struct column to visit - * @param children the results of visiting the children - * @return the result of visiting the struct column - */ - T visitStruct(HostColumnVectorCore col, List children); + /** + * Visit a struct column. + * + * @param col the struct column to visit + * @param children the results of visiting the children + * @return the result of visiting the struct column + */ + T visitStruct(HostColumnVectorCore col, List children); - /** - * Visit a list column before actually visiting its child. - * @param col the list column to visit - * @return the result of visiting the list column - */ - T preVisitList(HostColumnVectorCore col); + /** + * Visit a list column before actually visiting its child. + * + * @param col the list column to visit + * @return the result of visiting the list column + */ + T preVisitList(HostColumnVectorCore col); - /** - * Visit a list column after visiting its child. - * @param col the list column to visit - * @param preVisitResult the result of visiting the list column before visiting its child - * @param childResult the result of visiting the child - * @return the result of visiting the list column - */ - T visitList(HostColumnVectorCore col, T preVisitResult, T childResult); + /** + * Visit a list column after visiting its child. + * + * @param col the list column to visit + * @param preVisitResult the result of visiting the list column before visiting its child + * @param childResult the result of visiting the child + * @return the result of visiting the list column + */ + T visitList(HostColumnVectorCore col, T preVisitResult, T childResult); - /** - * Visit a column that is a primitive type. - * @param col the column to visit - * @return the result of visiting the column - */ - T visit(HostColumnVectorCore col); + /** + * Visit a column that is a primitive type. + * + * @param col the column to visit + * @return the result of visiting the column + */ + T visit(HostColumnVectorCore col); } diff --git a/src/main/java/com/nvidia/spark/rapids/jni/schema/SchemaVisitor.java b/src/main/java/com/nvidia/spark/rapids/jni/schema/SchemaVisitor.java index c6b33e0fb4..8e81ee12c1 100644 --- a/src/main/java/com/nvidia/spark/rapids/jni/schema/SchemaVisitor.java +++ b/src/main/java/com/nvidia/spark/rapids/jni/schema/SchemaVisitor.java @@ -19,14 +19,13 @@ package com.nvidia.spark.rapids.jni.schema; import ai.rapids.cudf.Schema; - import java.util.List; /** * A post order visitor for schemas. * *

    Flattened Schema

    - * + *

    * A flattened schema is a schema where all fields with nested types are flattened into an array of fields. For example, * for a schema with following fields: * @@ -36,7 +35,7 @@ *

  3. C: string
  4. *
  5. D: long
  6. * - * + *

    * The flattened schema will be: * *

      @@ -59,7 +58,7 @@ *
    • B: list { int b1}
    • *
    • C: string
    • *
    - * + *

    * The order of visiting will be: *

      *
    1. Visit primitive field a1
    2. @@ -79,42 +78,47 @@ * @param Return type after processing all children values. */ public interface SchemaVisitor { - /** - * Visit the top level schema. - * @param schema the top level schema to visit - * @param children the results of visiting the children - * @return the result of visiting the top level schema - */ - R visitTopSchema(Schema schema, List children); + /** + * Visit the top level schema. + * + * @param schema the top level schema to visit + * @param children the results of visiting the children + * @return the result of visiting the top level schema + */ + R visitTopSchema(Schema schema, List children); - /** - * Visit a struct schema. - * @param structType the struct schema to visit - * @param children the results of visiting the children - * @return the result of visiting the struct schema - */ - T visitStruct(Schema structType, List children); + /** + * Visit a struct schema. + * + * @param structType the struct schema to visit + * @param children the results of visiting the children + * @return the result of visiting the struct schema + */ + T visitStruct(Schema structType, List children); - /** - * Visit a list schema before actually visiting its child. - * @param listType the list schema to visit - * @return the result of visiting the list schema - */ - P preVisitList(Schema listType); + /** + * Visit a list schema before actually visiting its child. + * + * @param listType the list schema to visit + * @return the result of visiting the list schema + */ + P preVisitList(Schema listType); - /** - * Visit a list schema after visiting its child. - * @param listType the list schema to visit - * @param preVisitResult the result of visiting the list schema before visiting its child - * @param childResult the result of visiting the child - * @return the result of visiting the list schema - */ - T visitList(Schema listType, P preVisitResult, T childResult); + /** + * Visit a list schema after visiting its child. + * + * @param listType the list schema to visit + * @param preVisitResult the result of visiting the list schema before visiting its child + * @param childResult the result of visiting the child + * @return the result of visiting the list schema + */ + T visitList(Schema listType, P preVisitResult, T childResult); - /** - * Visit a primitive type. - * @param primitiveType the primitive type to visit - * @return the result of visiting the primitive type - */ - T visit(Schema primitiveType); + /** + * Visit a primitive type. + * + * @param primitiveType the primitive type to visit + * @return the result of visiting the primitive type + */ + T visit(Schema primitiveType); } diff --git a/src/main/java/com/nvidia/spark/rapids/jni/schema/Visitors.java b/src/main/java/com/nvidia/spark/rapids/jni/schema/Visitors.java index b7f4f521e4..92ba7f76e1 100644 --- a/src/main/java/com/nvidia/spark/rapids/jni/schema/Visitors.java +++ b/src/main/java/com/nvidia/spark/rapids/jni/schema/Visitors.java @@ -21,7 +21,6 @@ import ai.rapids.cudf.HostColumnVector; import ai.rapids.cudf.HostColumnVectorCore; import ai.rapids.cudf.Schema; - import java.util.List; import java.util.Objects; import java.util.stream.Collectors; @@ -31,75 +30,75 @@ * A utility class for visiting a schema or a list of host columns. */ public class Visitors { - /** - * Visiting a schema in post order. For more details, see {@link SchemaVisitor}. - * - * @param schema the schema to visit - * @param visitor the visitor to use - * @param Return type when visiting intermediate nodes. See {@link SchemaVisitor} - * @param

      Return type when previsiting a list. See {@link SchemaVisitor} - * @param Return type after processing all children values. See {@link SchemaVisitor} - * @return the result of visiting the schema - */ - public static R visitSchema(Schema schema, SchemaVisitor visitor) { - Objects.requireNonNull(schema, "schema cannot be null"); - Objects.requireNonNull(visitor, "visitor cannot be null"); + /** + * Visiting a schema in post order. For more details, see {@link SchemaVisitor}. + * + * @param schema the schema to visit + * @param visitor the visitor to use + * @param Return type when visiting intermediate nodes. See {@link SchemaVisitor} + * @param

      Return type when previsiting a list. See {@link SchemaVisitor} + * @param Return type after processing all children values. See {@link SchemaVisitor} + * @return the result of visiting the schema + */ + public static R visitSchema(Schema schema, SchemaVisitor visitor) { + Objects.requireNonNull(schema, "schema cannot be null"); + Objects.requireNonNull(visitor, "visitor cannot be null"); - List childrenResult = IntStream.range(0, schema.getNumChildren()) - .mapToObj(i -> visitSchemaInner(schema.getChild(i), visitor)) - .collect(Collectors.toList()); + List childrenResult = IntStream.range(0, schema.getNumChildren()) + .mapToObj(i -> visitSchemaInner(schema.getChild(i), visitor)) + .collect(Collectors.toList()); - return visitor.visitTopSchema(schema, childrenResult); - } + return visitor.visitTopSchema(schema, childrenResult); + } - private static T visitSchemaInner(Schema schema, SchemaVisitor visitor) { - switch (schema.getType().getTypeId()) { - case STRUCT: - List children = IntStream.range(0, schema.getNumChildren()) - .mapToObj(childIdx -> visitSchemaInner(schema.getChild(childIdx), visitor)) - .collect(Collectors.toList()); - return visitor.visitStruct(schema, children); - case LIST: - P preVisitResult = visitor.preVisitList(schema); - T childResult = visitSchemaInner(schema.getChild(0), visitor); - return visitor.visitList(schema, preVisitResult, childResult); - default: - return visitor.visit(schema); - } + private static T visitSchemaInner(Schema schema, SchemaVisitor visitor) { + switch (schema.getType().getTypeId()) { + case STRUCT: + List children = IntStream.range(0, schema.getNumChildren()) + .mapToObj(childIdx -> visitSchemaInner(schema.getChild(childIdx), visitor)) + .collect(Collectors.toList()); + return visitor.visitStruct(schema, children); + case LIST: + P preVisitResult = visitor.preVisitList(schema); + T childResult = visitSchemaInner(schema.getChild(0), visitor); + return visitor.visitList(schema, preVisitResult, childResult); + default: + return visitor.visit(schema); } + } - /** - * Visiting a list of host columns in post order. For more details, see {@link HostColumnsVisitor}. - * - * @param cols the list of host columns to visit - * @param visitor the visitor to use - * @param Return type when visiting intermediate nodes. See {@link HostColumnsVisitor} - */ - public static void visitColumns(HostColumnVector[] cols, - HostColumnsVisitor visitor) { - Objects.requireNonNull(cols, "cols cannot be null"); - Objects.requireNonNull(visitor, "visitor cannot be null"); - - for (HostColumnVector col : cols) { - visitColumn(col, visitor); - } + /** + * Visiting a list of host columns in post order. For more details, see {@link HostColumnsVisitor}. + * + * @param cols the list of host columns to visit + * @param visitor the visitor to use + * @param Return type when visiting intermediate nodes. See {@link HostColumnsVisitor} + */ + public static void visitColumns(HostColumnVector[] cols, + HostColumnsVisitor visitor) { + Objects.requireNonNull(cols, "cols cannot be null"); + Objects.requireNonNull(visitor, "visitor cannot be null"); + for (HostColumnVector col : cols) { + visitColumn(col, visitor); } - private static T visitColumn(HostColumnVectorCore col, HostColumnsVisitor visitor) { - switch (col.getType().getTypeId()) { - case STRUCT: - List children = IntStream.range(0, col.getNumChildren()) - .mapToObj(childIdx -> visitColumn(col.getChildColumnView(childIdx), visitor)) - .collect(Collectors.toList()); - return visitor.visitStruct(col, children); - case LIST: - T preVisitResult = visitor.preVisitList(col); - T childResult = visitColumn(col.getChildColumnView(0), visitor); - return visitor.visitList(col, preVisitResult, childResult); - default: - return visitor.visit(col); - } + } + + private static T visitColumn(HostColumnVectorCore col, HostColumnsVisitor visitor) { + switch (col.getType().getTypeId()) { + case STRUCT: + List children = IntStream.range(0, col.getNumChildren()) + .mapToObj(childIdx -> visitColumn(col.getChildColumnView(childIdx), visitor)) + .collect(Collectors.toList()); + return visitor.visitStruct(col, children); + case LIST: + T preVisitResult = visitor.preVisitList(col); + T childResult = visitColumn(col.getChildColumnView(0), visitor); + return visitor.visitList(col, preVisitResult, childResult); + default: + return visitor.visit(col); } + } } diff --git a/src/test/java/com/nvidia/spark/rapids/jni/BloomFilterTest.java b/src/test/java/com/nvidia/spark/rapids/jni/BloomFilterTest.java index 1ce6f18ed3..df8b24f840 100644 --- a/src/test/java/com/nvidia/spark/rapids/jni/BloomFilterTest.java +++ b/src/test/java/com/nvidia/spark/rapids/jni/BloomFilterTest.java @@ -16,90 +16,99 @@ package com.nvidia.spark.rapids.jni; -import com.nvidia.spark.rapids.jni.BloomFilter; +import static org.junit.jupiter.api.Assertions.assertThrows; import ai.rapids.cudf.AssertUtils; import ai.rapids.cudf.ColumnVector; -import ai.rapids.cudf.Cuda; import ai.rapids.cudf.CudfException; import ai.rapids.cudf.Scalar; -import ai.rapids.cudf.DeviceMemoryBuffer; - -import static org.junit.jupiter.api.Assertions.assertThrows; import org.junit.jupiter.api.Test; public class BloomFilterTest { @Test - void testBuildAndProbe(){ + void testBuildAndProbe() { int numHashes = 3; long bloomFilterBits = 4 * 1024 * 1024; try (ColumnVector input = ColumnVector.fromLongs(20, 80, 100, 99, 47, -9, 234000000); - Scalar bloomFilter = BloomFilter.create(numHashes, bloomFilterBits)){ - + Scalar bloomFilter = BloomFilter.create(numHashes, bloomFilterBits)) { + BloomFilter.put(bloomFilter, input); - try(ColumnVector probe = ColumnVector.fromLongs(20, 80, 100, 99, 47, -9, 234000000, -10, 1, 2, 3); - ColumnVector expected = ColumnVector.fromBooleans(true, true, true, true, true, true, true, false, false, false, false); - ColumnVector result = BloomFilter.probe(bloomFilter, probe)){ + try ( + ColumnVector probe = ColumnVector.fromLongs(20, 80, 100, 99, 47, -9, 234000000, -10, 1, 2, + 3); + ColumnVector expected = ColumnVector.fromBooleans(true, true, true, true, true, true, + true, false, false, false, false); + ColumnVector result = BloomFilter.probe(bloomFilter, probe)) { AssertUtils.assertColumnsAreEqual(expected, result); } } } @Test - void testBuildAndProbeBuffer(){ + void testBuildAndProbeBuffer() { int numHashes = 3; long bloomFilterBits = 4 * 1024 * 1024; try (ColumnVector input = ColumnVector.fromLongs(20, 80, 100, 99, 47, -9, 234000000); - Scalar bloomFilter = BloomFilter.create(numHashes, bloomFilterBits)){ - + Scalar bloomFilter = BloomFilter.create(numHashes, bloomFilterBits)) { + BloomFilter.put(bloomFilter, input); - try(ColumnVector probe = ColumnVector.fromLongs(20, 80, 100, 99, 47, -9, 234000000, -10, 1, 2, 3); - ColumnVector expected = ColumnVector.fromBooleans(true, true, true, true, true, true, true, false, false, false, false); - ColumnVector result = BloomFilter.probe(bloomFilter.getListAsColumnView().getData(), probe)){ + try ( + ColumnVector probe = ColumnVector.fromLongs(20, 80, 100, 99, 47, -9, 234000000, -10, 1, 2, + 3); + ColumnVector expected = ColumnVector.fromBooleans(true, true, true, true, true, true, + true, false, false, false, false); + ColumnVector result = BloomFilter.probe(bloomFilter.getListAsColumnView().getData(), + probe)) { AssertUtils.assertColumnsAreEqual(expected, result); } } } @Test - void testBuildWithNullsAndProbe(){ + void testBuildWithNullsAndProbe() { int numHashes = 3; long bloomFilterBits = 4 * 1024 * 1024; - try (ColumnVector input = ColumnVector.fromBoxedLongs(null, 80L, 100L, null, 47L, -9L, 234000000L); - Scalar bloomFilter = BloomFilter.create(numHashes, bloomFilterBits)){ - + try (ColumnVector input = ColumnVector.fromBoxedLongs(null, 80L, 100L, null, 47L, -9L, + 234000000L); + Scalar bloomFilter = BloomFilter.create(numHashes, bloomFilterBits)) { + BloomFilter.put(bloomFilter, input); - try(ColumnVector probe = ColumnVector.fromLongs(20, 80, 100, 99, 47, -9, 234000000, -10, 1, 2, 3); - ColumnVector expected = ColumnVector.fromBooleans(false, true, true, false, true, true, true, false, false, false, false); - ColumnVector result = BloomFilter.probe(bloomFilter, probe)){ + try ( + ColumnVector probe = ColumnVector.fromLongs(20, 80, 100, 99, 47, -9, 234000000, -10, 1, 2, + 3); + ColumnVector expected = ColumnVector.fromBooleans(false, true, true, false, true, true, + true, false, false, false, false); + ColumnVector result = BloomFilter.probe(bloomFilter, probe)) { AssertUtils.assertColumnsAreEqual(expected, result); } } } @Test - void testBuildAndProbeWithNulls(){ + void testBuildAndProbeWithNulls() { int numHashes = 3; long bloomFilterBits = 4 * 1024 * 1024; try (ColumnVector input = ColumnVector.fromLongs(20, 80, 100, 99, 47, -9, 234000000); - Scalar bloomFilter = BloomFilter.create(numHashes, bloomFilterBits)){ - + Scalar bloomFilter = BloomFilter.create(numHashes, bloomFilterBits)) { + BloomFilter.put(bloomFilter, input); - try(ColumnVector probe = ColumnVector.fromBoxedLongs(null, null, null, 99L, 47L, -9L, 234000000L, null, null, 2L, 3L); - ColumnVector expected = ColumnVector.fromBoxedBooleans(null, null, null, true, true, true, true, null, null, false, false); - ColumnVector result = BloomFilter.probe(bloomFilter, probe)){ + try (ColumnVector probe = ColumnVector.fromBoxedLongs(null, null, null, 99L, 47L, -9L, + 234000000L, null, null, 2L, 3L); + ColumnVector expected = ColumnVector.fromBoxedBooleans(null, null, null, true, true, + true, true, null, null, false, false); + ColumnVector result = BloomFilter.probe(bloomFilter, probe)) { AssertUtils.assertColumnsAreEqual(expected, result); } } } - + @Test - void testBuildMergeProbe(){ + void testBuildMergeProbe() { int numHashes = 3; long bloomFilterBits = 4 * 1024 * 1024; @@ -108,67 +117,77 @@ void testBuildMergeProbe(){ ColumnVector colC = ColumnVector.fromLongs(-100, -200, -300, -400); Scalar bloomFilterA = BloomFilter.create(numHashes, bloomFilterBits); Scalar bloomFilterB = BloomFilter.create(numHashes, bloomFilterBits); - Scalar bloomFilterC = BloomFilter.create(numHashes, bloomFilterBits)){ + Scalar bloomFilterC = BloomFilter.create(numHashes, bloomFilterBits)) { BloomFilter.put(bloomFilterA, colA); BloomFilter.put(bloomFilterB, colB); BloomFilter.put(bloomFilterC, colC); - - ColumnVector premerge = ColumnVector.concatenate(ColumnVector.fromScalar(bloomFilterA, 1), - ColumnVector.fromScalar(bloomFilterB, 1), - ColumnVector.fromScalar(bloomFilterC, 1)); - try(ColumnVector probe = ColumnVector.fromLongs(-9, 200, 300, 6000, -2546, 99, 65535, 0, -100, -200, -300, -400); - ColumnVector expected = ColumnVector.fromBooleans(true, true, true, false, false, true, false, false, true, true, true, true); + ColumnVector premerge = ColumnVector.concatenate(ColumnVector.fromScalar(bloomFilterA, 1), + ColumnVector.fromScalar(bloomFilterB, 1), + ColumnVector.fromScalar(bloomFilterC, 1)); + + try ( + ColumnVector probe = ColumnVector.fromLongs(-9, 200, 300, 6000, -2546, 99, 65535, 0, -100, + -200, -300, -400); + ColumnVector expected = ColumnVector.fromBooleans(true, true, true, false, false, true, + false, false, true, true, true, true); Scalar merged = BloomFilter.merge(premerge); - ColumnVector result = BloomFilter.probe(merged, probe)){ - AssertUtils.assertColumnsAreEqual(expected, result); + ColumnVector result = BloomFilter.probe(merged, probe)) { + AssertUtils.assertColumnsAreEqual(expected, result); } } } @Test - void testBuildTrivialMergeProbe(){ + void testBuildTrivialMergeProbe() { int numHashes = 3; long bloomFilterBits = 4 * 1024 * 1024; try (ColumnVector colA = ColumnVector.fromLongs(20, 80, 100, 99, 47, -9, 234000000); - Scalar bloomFilter = BloomFilter.create(numHashes, bloomFilterBits)){ + Scalar bloomFilter = BloomFilter.create(numHashes, bloomFilterBits)) { BloomFilter.put(bloomFilter, colA); ColumnVector premerge = ColumnVector.fromScalar(bloomFilter, 1); - try(ColumnVector probe = ColumnVector.fromLongs(-9, 200, 300, 6000, -2546, 99, 65535, 0, -100, -200, -300, -400); - ColumnVector expected = ColumnVector.fromBooleans(true, false, false, false, false, true, false, false, false, false, false, false); + try ( + ColumnVector probe = ColumnVector.fromLongs(-9, 200, 300, 6000, -2546, 99, 65535, 0, -100, + -200, -300, -400); + ColumnVector expected = ColumnVector.fromBooleans(true, false, false, false, false, true, + false, false, false, false, false, false); Scalar merged = BloomFilter.merge(premerge); - ColumnVector result = BloomFilter.probe(merged, probe)){ - AssertUtils.assertColumnsAreEqual(expected, result); + ColumnVector result = BloomFilter.probe(merged, probe)) { + AssertUtils.assertColumnsAreEqual(expected, result); } } } @Test - void testBuildExpectedFailures(){ + void testBuildExpectedFailures() { // bloom filter with no hashes assertThrows(IllegalArgumentException.class, () -> { - try (Scalar bloomFilter = BloomFilter.create(0, 64)){} + try (Scalar bloomFilter = BloomFilter.create(0, 64)) { + } }); // bloom filter with no size assertThrows(IllegalArgumentException.class, () -> { - try (Scalar bloomFilter = BloomFilter.create(3, 0)){} + try (Scalar bloomFilter = BloomFilter.create(3, 0)) { + } }); - + // merge with mixed hash counts assertThrows(CudfException.class, () -> { try (Scalar bloomFilterA = BloomFilter.create(3, 1024); Scalar bloomFilterB = BloomFilter.create(4, 1024); Scalar bloomFilterC = BloomFilter.create(4, 1024); - ColumnVector premerge = ColumnVector.concatenate(ColumnVector.fromScalar(bloomFilterA, 1), - ColumnVector.fromScalar(bloomFilterB, 1), - ColumnVector.fromScalar(bloomFilterC, 1)); - Scalar merged = BloomFilter.merge(premerge)){} + ColumnVector premerge = ColumnVector.concatenate( + ColumnVector.fromScalar(bloomFilterA, 1), + ColumnVector.fromScalar(bloomFilterB, 1), + ColumnVector.fromScalar(bloomFilterC, 1)); + Scalar merged = BloomFilter.merge(premerge)) { + } }); // merge with mixed hash bit sizes @@ -176,10 +195,12 @@ void testBuildExpectedFailures(){ try (Scalar bloomFilterA = BloomFilter.create(3, 1024); Scalar bloomFilterB = BloomFilter.create(3, 1024); Scalar bloomFilterC = BloomFilter.create(3, 2048); - ColumnVector premerge = ColumnVector.concatenate(ColumnVector.fromScalar(bloomFilterA, 1), - ColumnVector.fromScalar(bloomFilterB, 1), - ColumnVector.fromScalar(bloomFilterC, 1)); - Scalar merged = BloomFilter.merge(premerge)){} + ColumnVector premerge = ColumnVector.concatenate( + ColumnVector.fromScalar(bloomFilterA, 1), + ColumnVector.fromScalar(bloomFilterB, 1), + ColumnVector.fromScalar(bloomFilterC, 1)); + Scalar merged = BloomFilter.merge(premerge)) { + } }); } } diff --git a/src/test/java/com/nvidia/spark/rapids/jni/CaseWhenTest.java b/src/test/java/com/nvidia/spark/rapids/jni/CaseWhenTest.java index f0bff026ed..9a2e588af4 100644 --- a/src/test/java/com/nvidia/spark/rapids/jni/CaseWhenTest.java +++ b/src/test/java/com/nvidia/spark/rapids/jni/CaseWhenTest.java @@ -16,12 +16,11 @@ package com.nvidia.spark.rapids.jni; -import ai.rapids.cudf.*; +import static ai.rapids.cudf.AssertUtils.assertColumnsAreEqual; +import ai.rapids.cudf.ColumnVector; import org.junit.jupiter.api.Test; -import static ai.rapids.cudf.AssertUtils.assertColumnsAreEqual; - public class CaseWhenTest { @Test @@ -36,7 +35,7 @@ void selectIndexTest() { ColumnVector b3 = ColumnVector.fromBooleans( true, true, true, false); ColumnVector expected = ColumnVector.fromInts(0, 1, 2, 4)) { - ColumnVector[] boolColumns = new ColumnVector[] { b0, b1, b2, b3 }; + ColumnVector[] boolColumns = new ColumnVector[] {b0, b1, b2, b3}; try (ColumnVector actual = CaseWhen.selectFirstTrueIndex(boolColumns)) { assertColumnsAreEqual(expected, actual); } @@ -55,7 +54,7 @@ void selectIndexTestWithNull() { ColumnVector b3 = ColumnVector.fromBoxedBooleans( null, null, null, true, null); ColumnVector expected = ColumnVector.fromInts(4, 4, 4, 1, 1)) { - ColumnVector[] boolColumns = new ColumnVector[] { b0, b1, b2, b3 }; + ColumnVector[] boolColumns = new ColumnVector[] {b0, b1, b2, b3}; try (ColumnVector actual = CaseWhen.selectFirstTrueIndex(boolColumns)) { assertColumnsAreEqual(expected, actual); } diff --git a/src/test/java/com/nvidia/spark/rapids/jni/CastStringsTest.java b/src/test/java/com/nvidia/spark/rapids/jni/CastStringsTest.java index f784736819..aba73024d4 100644 --- a/src/test/java/com/nvidia/spark/rapids/jni/CastStringsTest.java +++ b/src/test/java/com/nvidia/spark/rapids/jni/CastStringsTest.java @@ -19,23 +19,21 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.fail; -import java.util.ArrayList; -import java.util.List; - -import org.junit.jupiter.api.Test; - import ai.rapids.cudf.AssertUtils; import ai.rapids.cudf.ColumnVector; import ai.rapids.cudf.DType; import ai.rapids.cudf.Table; +import java.util.ArrayList; +import java.util.List; +import org.junit.jupiter.api.Test; public class CastStringsTest { @Test void castToIntegerTest() { Table.TestBuilder tb = new Table.TestBuilder(); - tb.column(3l, 9l, 4l, 2l, 20l, null, null, 1l); + tb.column(3L, 9L, 4L, 2L, 20L, null, null, 1L); tb.column(5, 1, 0, 2, 7, null, null, 1); - tb.column(new Byte[]{2, 3, 4, 5, 9, null, null, 1}); + tb.column(new Byte[] {2, 3, 4, 5, 9, null, null, 1}); try (Table expected = tb.build()) { Table.TestBuilder tb2 = new Table.TestBuilder(); tb2.column(" 3", "9", "4", "2", "20.5", null, "7.6asd", "\u0000 \u001f1\u0014"); @@ -62,9 +60,9 @@ void castToIntegerTest() { @Test void castToIntegerNoStripTest() { Table.TestBuilder tb = new Table.TestBuilder(); - tb.column(null, 9l, 4l, 2l, 20l, null, null); + tb.column(null, 9L, 4L, 2L, 20L, null, null); tb.column(5, null, 0, 2, 7, null, null); - tb.column(new Byte[]{2, 3, null, 5, null, null, null}); + tb.column(new Byte[] {2, 3, null, 5, null, null, null}); try (Table expected = tb.build()) { Table.TestBuilder tb2 = new Table.TestBuilder(); tb2.column(" 3", "9", "4", "2", "20.5", null, "7.6asd"); @@ -91,9 +89,9 @@ void castToIntegerNoStripTest() { @Test void castToIntegerAnsiTest() { Table.TestBuilder tb = new Table.TestBuilder(); - tb.column(3l, 9l, 4l, 2l, 20l); + tb.column(3L, 9L, 4L, 2L, 20L); tb.column(5, 1, 0, 2, 7); - tb.column(new Byte[]{2, 3, 4, 5, 9}); + tb.column(new Byte[] {2, 3, 4, 5, 9}); try (Table expected = tb.build()) { Table.TestBuilder tb2 = new Table.TestBuilder(); tb2.column("3", "9", "4", "2", "20"); @@ -120,7 +118,7 @@ void castToIntegerAnsiTest() { try (Table failTable = fail.build(); ColumnVector cv = CastStrings.toInteger(failTable.getColumn(0), true, - expected.getColumn(0).getType());) { + expected.getColumn(0).getType())) { fail("Should have thrown"); } catch (CastException e) { assertEquals("asdf", e.getStringWithError()); @@ -136,16 +134,16 @@ void castToFloatsTrimTest() { tb.column(1.1d, 1.2d, 1.3d, 1.4d, 1.5d, null, null); try (Table expected = tb.build()) { Table.TestBuilder tb2 = new Table.TestBuilder(); - tb2.column("1.1\u0000", "1.2\u0014", "1.3\u001f", + tb2.column("1.1\u0000", "1.2\u0014", "1.3\u001f", "\u0000\u00001.4\u0000", "1.5\u0000\u0020\u0000", "1.6\u009f", "1.7\u0021"); - tb2.column("1.1\u0000", "1.2\u0014", "1.3\u001f", + tb2.column("1.1\u0000", "1.2\u0014", "1.3\u001f", "\u0000\u00001.4\u0000", "1.5\u0000\u0020\u0000", "1.6\u009f", "1.7\u0021"); List result = new ArrayList<>(); try (Table origTable = tb2.build()) { for (int i = 0; i < origTable.getNumberOfColumns(); i++) { ColumnVector string_col = origTable.getColumn(i); - result.add(CastStrings.toFloat(string_col, false, + result.add(CastStrings.toFloat(string_col, false, expected.getColumn(i).getType())); } try (Table result_tbl = new Table( @@ -159,7 +157,7 @@ void castToFloatsTrimTest() { } @Test - void castToFloatNanTest(){ + void castToFloatNanTest() { Table.TestBuilder tb2 = new Table.TestBuilder(); tb2.column("nan", "nan ", " nan ", "NAN", "nAn ", " NAn ", "Nan 0", "+naN", "-nAn"); @@ -183,13 +181,14 @@ void castToFloatNanTest(){ } @Test - void castToFloatsInfTest(){ + void castToFloatsInfTest() { // The test data: Table.TestBuilder object with a column containing the string "inf" Table.TestBuilder tb2 = new Table.TestBuilder(); tb2.column("INFINITY ", "inf", "+inf ", " -INF ", "INFINITY AND BEYOND", "INF"); Table.TestBuilder tb = new Table.TestBuilder(); - tb.column(Float.POSITIVE_INFINITY, Float.POSITIVE_INFINITY, Float.POSITIVE_INFINITY, Float.NEGATIVE_INFINITY, null, Float.POSITIVE_INFINITY); + tb.column(Float.POSITIVE_INFINITY, Float.POSITIVE_INFINITY, Float.POSITIVE_INFINITY, + Float.NEGATIVE_INFINITY, null, Float.POSITIVE_INFINITY); try (Table expected = tb.build()) { List result = new ArrayList<>(); @@ -211,12 +210,12 @@ void castToFloatsInfTest(){ @Test void castToDecimalTest() { Table.TestBuilder tb = new Table.TestBuilder(); - tb.decimal32Column(0,3, 9, 4, 2, 21, null, null, 1); - tb.decimal64Column(0, 5l, 1l, 0l, 2l, 7l, null, null, 1l); + tb.decimal32Column(0, 3, 9, 4, 2, 21, null, null, 1); + tb.decimal64Column(0, 5L, 1L, 0L, 2L, 7L, null, null, 1L); tb.decimal32Column(-1, 20, 30, 40, 51, 92, null, null, 10); try (Table expected = tb.build()) { - int[] desiredPrecision = new int[]{2, 10, 3}; - int[] desiredScale = new int[]{0, 0, -1}; + int[] desiredPrecision = new int[] {2, 10, 3}; + int[] desiredScale = new int[] {0, 0, -1}; Table.TestBuilder tb2 = new Table.TestBuilder(); tb2.column(" 3", "9", "4", "2", "20.5", null, "7.6asd", "\u0000 \u001f1\u0014"); @@ -244,11 +243,11 @@ void castToDecimalTest() { void castToDecimalNoStripTest() { Table.TestBuilder tb = new Table.TestBuilder(); tb.decimal32Column(0, null, 9, 4, 2, 21, null, null); - tb.decimal64Column(0, 5l, null, 0l, 2l, 7l, null, null); + tb.decimal64Column(0, 5L, null, 0L, 2L, 7L, null, null); tb.decimal32Column(-1, 20, 30, null, 51, 92, null, null); try (Table expected = tb.build()) { - int[] desiredPrecision = new int[]{2, 10, 3}; - int[] desiredScale = new int[]{0, 0, -1}; + int[] desiredPrecision = new int[] {2, 10, 3}; + int[] desiredScale = new int[] {0, 0, -1}; Table.TestBuilder tb2 = new Table.TestBuilder(); tb2.column(" 3", "9", "4", "2", "20.5", null, "7.6asd"); @@ -273,11 +272,11 @@ void castToDecimalNoStripTest() { } private void convTestInternal(Table input, Table expected, int fromBase) { - try( - ColumnVector intCol = CastStrings.toIntegersWithBase(input.getColumn(0), fromBase, false, - DType.UINT64); - ColumnVector decStrCol = CastStrings.fromIntegersWithBase(intCol, 10); - ColumnVector hexStrCol = CastStrings.fromIntegersWithBase(intCol, 16); + try ( + ColumnVector intCol = CastStrings.toIntegersWithBase(input.getColumn(0), fromBase, false, + DType.UINT64); + ColumnVector decStrCol = CastStrings.fromIntegersWithBase(intCol, 10); + ColumnVector hexStrCol = CastStrings.fromIntegersWithBase(intCol, 16) ) { AssertUtils.assertColumnsAreEqual(expected.getColumn(0), decStrCol, "decStrCol"); AssertUtils.assertColumnsAreEqual(expected.getColumn(1), hexStrCol, "hexStrCol"); @@ -285,121 +284,118 @@ private void convTestInternal(Table input, Table expected, int fromBase) { } @Test - void baseDec2HexTestNoNulls() { + void baseDec2HexTestNoNulls() { try ( - Table input = new Table.TestBuilder().column( - "510", - "00510", - "00-510" - ).build(); - - Table expected = new Table.TestBuilder().column( - "510", - "510", - "0" - ).column( - "1FE", - "1FE", - "0" - ).build() - ) - { + Table input = new Table.TestBuilder().column( + "510", + "00510", + "00-510" + ).build(); + + Table expected = new Table.TestBuilder().column( + "510", + "510", + "0" + ).column( + "1FE", + "1FE", + "0" + ).build() + ) { convTestInternal(input, expected, 10); } } @Test - void baseDec2HexTestMixed() { + void baseDec2HexTestMixed() { try ( - Table input = new Table.TestBuilder().column( - null, - " ", - "junk-510junk510", - "--510", - " -510junk510", - " 510junk510", - "510", - "00510", - "00-510" - ).build(); - - Table expected = new Table.TestBuilder().column( - null, - null, - "0", - "0", - "18446744073709551106", - "510", - "510", - "510", - "0" - ).column( - null, - null, - "0", - "0", - "FFFFFFFFFFFFFE02", - "1FE", - "1FE", - "1FE", - "0" - ).build() - ) - { + Table input = new Table.TestBuilder().column( + null, + " ", + "junk-510junk510", + "--510", + " -510junk510", + " 510junk510", + "510", + "00510", + "00-510" + ).build(); + + Table expected = new Table.TestBuilder().column( + null, + null, + "0", + "0", + "18446744073709551106", + "510", + "510", + "510", + "0" + ).column( + null, + null, + "0", + "0", + "FFFFFFFFFFFFFE02", + "1FE", + "1FE", + "1FE", + "0" + ).build() + ) { convTestInternal(input, expected, 10); } } @Test void baseHex2DecTest() { - try( - Table input = new Table.TestBuilder().column( - null, - "junk", - "0", - "f", - "junk-5Ajunk5A", - "--5A", - " -5Ajunk5A", - " 5Ajunk5A", - "5a", - "05a", - "005a", - "00-5a", - "NzGGImWNRh" - ).build(); - - Table expected = new Table.TestBuilder().column( - null, - "0", - "0", - "15", - "0", - "0", - "18446744073709551526", - "90", - "90", - "90", - "90", - "0", - "0" - ).column( - null, - "0", - "0", - "F", - "0", - "0", - "FFFFFFFFFFFFFFA6", - "5A", - "5A", - "5A", - "5A", - "0", - "0" - ).build(); - ) - { + try ( + Table input = new Table.TestBuilder().column( + null, + "junk", + "0", + "f", + "junk-5Ajunk5A", + "--5A", + " -5Ajunk5A", + " 5Ajunk5A", + "5a", + "05a", + "005a", + "00-5a", + "NzGGImWNRh" + ).build(); + + Table expected = new Table.TestBuilder().column( + null, + "0", + "0", + "15", + "0", + "0", + "18446744073709551526", + "90", + "90", + "90", + "90", + "0", + "0" + ).column( + null, + "0", + "0", + "F", + "0", + "0", + "FFFFFFFFFFFFFFA6", + "5A", + "5A", + "5A", + "5A", + "0", + "0" + ).build() + ) { convTestInternal(input, expected, 16); } } diff --git a/src/test/java/com/nvidia/spark/rapids/jni/DateTimeRebaseTest.java b/src/test/java/com/nvidia/spark/rapids/jni/DateTimeRebaseTest.java index 5508d56d4d..7db65c295b 100644 --- a/src/test/java/com/nvidia/spark/rapids/jni/DateTimeRebaseTest.java +++ b/src/test/java/com/nvidia/spark/rapids/jni/DateTimeRebaseTest.java @@ -18,9 +18,8 @@ import static ai.rapids.cudf.AssertUtils.assertColumnsAreEqual; -import org.junit.jupiter.api.Test; - import ai.rapids.cudf.ColumnVector; +import org.junit.jupiter.api.Test; public class DateTimeRebaseTest { @Test diff --git a/src/test/java/com/nvidia/spark/rapids/jni/DecimalUtilsTest.java b/src/test/java/com/nvidia/spark/rapids/jni/DecimalUtilsTest.java index 39adf4c1fe..5886074b4b 100644 --- a/src/test/java/com/nvidia/spark/rapids/jni/DecimalUtilsTest.java +++ b/src/test/java/com/nvidia/spark/rapids/jni/DecimalUtilsTest.java @@ -16,18 +16,17 @@ package com.nvidia.spark.rapids.jni; +import static ai.rapids.cudf.AssertUtils.assertColumnsAreEqual; + import ai.rapids.cudf.ColumnVector; import ai.rapids.cudf.DType; import ai.rapids.cudf.Table; -import org.junit.jupiter.api.Test; - import java.math.BigDecimal; import java.math.BigInteger; - -import static ai.rapids.cudf.AssertUtils.*; +import org.junit.jupiter.api.Test; public class DecimalUtilsTest { - ColumnVector makeDec128Column(String ... values) { + ColumnVector makeDec128Column(String... values) { BigDecimal[] decVals = new BigDecimal[values.length]; for (int i = 0; i < values.length; i++) { if (values[i] != null) { @@ -44,7 +43,7 @@ void simplePosMultiplyOneByZero() { try (ColumnVector lhs = makeDec128Column("1.0", "10.0", "1000000000000000000000000000000000000.0"); ColumnVector rhs = - makeDec128Column("1", "1", "1"); + makeDec128Column("1", "1", "1"); ColumnVector expectedBasic = makeDec128Column("1.0", "10.0", "1000000000000000000000000000000000000.0"); ColumnVector expectedValid = @@ -132,9 +131,9 @@ void overflowMult() { @Test void simpleNegMultiplyOneByZero() { try (ColumnVector lhs = - makeDec128Column("1.0", "-1.0", "10.0"); + makeDec128Column("1.0", "-1.0", "10.0"); ColumnVector rhs = - makeDec128Column("-1", "1", "-1"); + makeDec128Column("-1", "1", "-1"); ColumnVector expectedBasic = makeDec128Column("-1.0", "-1.0", "-10.0"); ColumnVector expectedValid = @@ -148,11 +147,11 @@ void simpleNegMultiplyOneByZero() { @Test void simpleNegMultiplyOneByOne() { try (ColumnVector lhs = - makeDec128Column("1.0", "-1.0", "3.7"); + makeDec128Column("1.0", "-1.0", "3.7"); ColumnVector rhs = makeDec128Column("-1.0", "-1.0", "-1.5"); ColumnVector expectedBasic = - makeDec128Column("-1.0", "1.0", "-5.6"); + makeDec128Column("-1.0", "1.0", "-5.6"); ColumnVector expectedValid = ColumnVector.fromBooleans(false, false, false); Table found = DecimalUtils.multiply128(lhs, rhs, -1)) { @@ -193,9 +192,9 @@ void simplePosDivOneByZero() { try (ColumnVector lhs = makeDec128Column("1.0", "10.0", "1.0", "1000000000000000000000000000000000000.0"); ColumnVector rhs = - makeDec128Column("1", "2", "0", "5"); + makeDec128Column("1", "2", "0", "5"); ColumnVector expectedBasic = - makeDec128Column("1.0", "5.0", "0", "200000000000000000000000000000000000.0"); + makeDec128Column("1.0", "5.0", "0", "200000000000000000000000000000000000.0"); ColumnVector expectedValid = ColumnVector.fromBooleans(false, false, true, false); Table found = DecimalUtils.divide128(lhs, rhs, -1)) { @@ -207,10 +206,12 @@ void simplePosDivOneByZero() { @Test void intDivide() { try (ColumnVector lhs = - makeDec128Column("3396191716868766147341919609.06", "-6893798181986328848375556144.67"); + makeDec128Column("3396191716868766147341919609.06", + "-6893798181986328848375556144.67"); ColumnVector rhs = makeDec128Column("7317548469.64", "98565515088.44"); - ColumnVector expectedBasic = ColumnVector.fromLongs(464116053478747633L, -69941278912819784L); + ColumnVector expectedBasic = ColumnVector.fromLongs(464116053478747633L, + -69941278912819784L); ColumnVector expectedValid = ColumnVector.fromBooleans(false, false); Table found = DecimalUtils.integerDivide128(lhs, rhs)) { assertColumnsAreEqual(expectedValid, found.getColumn(0)); @@ -227,7 +228,8 @@ void intDivideNotOverflow() { makeDec128Column("451635271134476686911387864.48", "5313675970270560086329837153.18"); ColumnVector rhs = makeDec128Column("-961.110", "181.958"); - ColumnVector expectedBasic = ColumnVector.fromLongs(2284624887606872042L, -2928582767902049472L); + ColumnVector expectedBasic = ColumnVector.fromLongs(2284624887606872042L, + -2928582767902049472L); ColumnVector expectedValid = ColumnVector.fromBooleans(false, false); Table found = DecimalUtils.integerDivide128(lhs, rhs)) { assertColumnsAreEqual(expectedValid, found.getColumn(0)); @@ -238,7 +240,8 @@ void intDivideNotOverflow() { @Test void intDivideOverflow() { try (ColumnVector lhs = - makeDec128Column("-999999999999999999999999999999999999.99", "999999999999999999999999999999999999.99"); + makeDec128Column("-999999999999999999999999999999999999.99", + "999999999999999999999999999999999999.99"); ColumnVector rhs = makeDec128Column("0", "0"); ColumnVector expectedValid = ColumnVector.fromBooleans(true, true); Table found = DecimalUtils.integerDivide128(lhs, rhs)) { @@ -248,27 +251,42 @@ void intDivideOverflow() { @Test void remainder1() { - try (ColumnVector lhs = - makeDec128Column("2775750723350045263458396405825339066", "2775750723350045263458396405825339066", "-2775750723350045263458396405825339066", "-2775750723350045263458396405825339066"); - ColumnVector rhs = - makeDec128Column("-4890990637589340307512622401149178814.1", "4890990637589340307512622401149178814.1", "-4890990637589340307512622401149178814.1", "4890990637589340307512622401149178814.1"); - ColumnVector expected = - makeDec128Column("2775750723350045263458396405825339066.0", "2775750723350045263458396405825339066.0", "-2775750723350045263458396405825339066.0", "-2775750723350045263458396405825339066.0"); - Table found = DecimalUtils.remainder128(lhs, rhs, -1)) { - assertColumnsAreEqual(ColumnVector.fromBooleans(false, false, false, false), found.getColumn(0)); + try (ColumnVector lhs = + makeDec128Column("2775750723350045263458396405825339066", + "2775750723350045263458396405825339066", "-2775750723350045263458396405825339066", + "-2775750723350045263458396405825339066"); + ColumnVector rhs = + makeDec128Column("-4890990637589340307512622401149178814.1", + "4890990637589340307512622401149178814.1", + "-4890990637589340307512622401149178814.1", + "4890990637589340307512622401149178814.1"); + ColumnVector expected = + makeDec128Column("2775750723350045263458396405825339066.0", + "2775750723350045263458396405825339066.0", + "-2775750723350045263458396405825339066.0", + "-2775750723350045263458396405825339066.0"); + Table found = DecimalUtils.remainder128(lhs, rhs, -1)) { + assertColumnsAreEqual(ColumnVector.fromBooleans(false, false, false, false), + found.getColumn(0)); assertColumnsAreEqual(expected, found.getColumn(1)); } } @Test void remainder2() { - try (ColumnVector lhs = - makeDec128Column("-80968577325845461854951721352418610.13", "-80968577325845461854951721352418610.13", "-66686472768705331734321352506496901.71"); - ColumnVector rhs = - makeDec128Column("6749200345857154099505910298895800952.1", "-6749200345857154099505910298895800952.1", "-43880265997097383351377368851255372.5"); - ColumnVector expected = - makeDec128Column("-80968577325845461854951721352418610.13", "-80968577325845461854951721352418610.13", "-22806206771607948382943983655241529.21"); - Table found = DecimalUtils.remainder128(lhs, rhs, -2)) { + try (ColumnVector lhs = + makeDec128Column("-80968577325845461854951721352418610.13", + "-80968577325845461854951721352418610.13", + "-66686472768705331734321352506496901.71"); + ColumnVector rhs = + makeDec128Column("6749200345857154099505910298895800952.1", + "-6749200345857154099505910298895800952.1", + "-43880265997097383351377368851255372.5"); + ColumnVector expected = + makeDec128Column("-80968577325845461854951721352418610.13", + "-80968577325845461854951721352418610.13", + "-22806206771607948382943983655241529.21"); + Table found = DecimalUtils.remainder128(lhs, rhs, -2)) { assertColumnsAreEqual(ColumnVector.fromBooleans(false, false, false), found.getColumn(0)); assertColumnsAreEqual(expected, found.getColumn(1)); } @@ -276,13 +294,13 @@ void remainder2() { @Test void remainder7() { - try (ColumnVector lhs = - makeDec128Column("5776949384953805890688943467625198736"); - ColumnVector rhs = - makeDec128Column("-67337920196996830.354487679299"); - ColumnVector expected = - makeDec128Column("16310460742282291.8108019"); - Table found = DecimalUtils.remainder128(lhs, rhs, -7)) { + try (ColumnVector lhs = + makeDec128Column("5776949384953805890688943467625198736"); + ColumnVector rhs = + makeDec128Column("-67337920196996830.354487679299"); + ColumnVector expected = + makeDec128Column("16310460742282291.8108019"); + Table found = DecimalUtils.remainder128(lhs, rhs, -7)) { assertColumnsAreEqual(ColumnVector.fromBooleans(false), found.getColumn(0)); assertColumnsAreEqual(expected, found.getColumn(1)); } @@ -290,13 +308,13 @@ void remainder7() { @Test void remainder10() { - try (ColumnVector lhs = - makeDec128Column("5776949384953805890688943467625198736"); - ColumnVector rhs = - makeDec128Column("-6733792019699683035.4487679299"); - ColumnVector expected = - makeDec128Column("3585222007130884413.9709383255"); - Table found = DecimalUtils.remainder128(lhs, rhs, -10)) { + try (ColumnVector lhs = + makeDec128Column("5776949384953805890688943467625198736"); + ColumnVector rhs = + makeDec128Column("-6733792019699683035.4487679299"); + ColumnVector expected = + makeDec128Column("3585222007130884413.9709383255"); + Table found = DecimalUtils.remainder128(lhs, rhs, -10)) { assertColumnsAreEqual(ColumnVector.fromBooleans(false), found.getColumn(0)); assertColumnsAreEqual(expected, found.getColumn(1)); } @@ -654,7 +672,8 @@ void floatingPointToDecimalTest() { ColumnVector input2 = ColumnVector.fromDoubles(9.95); ColumnVector input3 = ColumnVector.fromDoubles(10.3); ColumnVector input4 = ColumnVector.fromDoubles(-10000000.0, -100000.0, 1.0, 100.0, 1000.0); - ColumnVector input5 = ColumnVector.fromDoubles(-10000000.0, 1.0, Double.NaN, -2.0, Double.NEGATIVE_INFINITY); + ColumnVector input5 = ColumnVector.fromDoubles(-10000000.0, 1.0, Double.NaN, -2.0, + Double.NEGATIVE_INFINITY); ColumnVector expected1 = ColumnVector.decimalFromLongs(-7, 35276195313L); ColumnVector expected2 = ColumnVector.decimalFromInts(-1, 100); @@ -662,11 +681,21 @@ void floatingPointToDecimalTest() { ColumnVector expected4 = ColumnVector.decimalFromBoxedInts(-1, null, null, 10, 1000, null); ColumnVector expected5 = ColumnVector.decimalFromBoxedLongs(-1, null, 10L, null, -20L, null) ) { - DecimalUtils.CastFloatToDecimalResult output1 = DecimalUtils.floatingPointToDecimal(input1, DType.create(DType.DTypeEnum.DECIMAL64, -7), 12); - DecimalUtils.CastFloatToDecimalResult output2 = DecimalUtils.floatingPointToDecimal(input2, DType.create(DType.DTypeEnum.DECIMAL32, -1), 3); - DecimalUtils.CastFloatToDecimalResult output3 = DecimalUtils.floatingPointToDecimal(input3, DType.create(DType.DTypeEnum.DECIMAL128, -1), 18); - DecimalUtils.CastFloatToDecimalResult output4 = DecimalUtils.floatingPointToDecimal(input4, DType.create(DType.DTypeEnum.DECIMAL32, -1), 4); - DecimalUtils.CastFloatToDecimalResult output5 = DecimalUtils.floatingPointToDecimal(input5, DType.create(DType.DTypeEnum.DECIMAL64, -1), 4); + DecimalUtils.CastFloatToDecimalResult output1 = + DecimalUtils.floatingPointToDecimal(input1, DType.create(DType.DTypeEnum.DECIMAL64, -7), + 12); + DecimalUtils.CastFloatToDecimalResult output2 = + DecimalUtils.floatingPointToDecimal(input2, DType.create(DType.DTypeEnum.DECIMAL32, -1), + 3); + DecimalUtils.CastFloatToDecimalResult output3 = + DecimalUtils.floatingPointToDecimal(input3, DType.create(DType.DTypeEnum.DECIMAL128, -1), + 18); + DecimalUtils.CastFloatToDecimalResult output4 = + DecimalUtils.floatingPointToDecimal(input4, DType.create(DType.DTypeEnum.DECIMAL32, -1), + 4); + DecimalUtils.CastFloatToDecimalResult output5 = + DecimalUtils.floatingPointToDecimal(input5, DType.create(DType.DTypeEnum.DECIMAL64, -1), + 4); try { assert (!output1.hasFailure); diff --git a/src/test/java/com/nvidia/spark/rapids/jni/FromJsonToRawMapTest.java b/src/test/java/com/nvidia/spark/rapids/jni/FromJsonToRawMapTest.java index e975b1a068..380b04208c 100644 --- a/src/test/java/com/nvidia/spark/rapids/jni/FromJsonToRawMapTest.java +++ b/src/test/java/com/nvidia/spark/rapids/jni/FromJsonToRawMapTest.java @@ -16,14 +16,13 @@ package com.nvidia.spark.rapids.jni; -import ai.rapids.cudf.ColumnVector; -import ai.rapids.cudf.BinaryOp; +import static ai.rapids.cudf.AssertUtils.assertColumnsAreEqual; +import ai.rapids.cudf.BinaryOp; +import ai.rapids.cudf.ColumnVector; import ai.rapids.cudf.JSONOptions; import org.junit.jupiter.api.Test; -import static ai.rapids.cudf.AssertUtils.assertColumnsAreEqual; - public class FromJsonToRawMapTest { private static JSONOptions getOptions() { return JSONOptions.builder() @@ -58,7 +57,7 @@ void testFromJsonSimpleInput() { ColumnVector tmpMap = expectedStructs.makeListFromOffsets(4, expectedOffsets); ColumnVector templateBitmask = ColumnVector.fromBoxedInts(1, 1, null, 1); ColumnVector expectedMap = tmpMap.mergeAndSetValidity(BinaryOp.BITWISE_AND, - templateBitmask); + templateBitmask) ) { assertColumnsAreEqual(expectedMap, outputMap); } @@ -87,7 +86,7 @@ void testFromJsonWithUTF8() { ColumnVector tmpMap = expectedStructs.makeListFromOffsets(4, expectedOffsets); ColumnVector templateBitmask = ColumnVector.fromBoxedInts(1, 1, null, 1); ColumnVector expectedMap = tmpMap.mergeAndSetValidity(BinaryOp.BITWISE_AND, - templateBitmask); + templateBitmask) ) { assertColumnsAreEqual(expectedMap, outputMap); } @@ -106,7 +105,7 @@ void testFromJsonEmptyAndInvalidInput() { ColumnVector tmpMap = expectedStructs.makeListFromOffsets(3, expectedOffsets); ColumnVector templateBitmask = ColumnVector.fromBoxedInts(1, null, 1); ColumnVector expectedMap = tmpMap.mergeAndSetValidity(BinaryOp.BITWISE_AND, - templateBitmask); + templateBitmask) ) { assertColumnsAreEqual(expectedMap, outputMap); } @@ -125,7 +124,7 @@ void testFromJsonInputWithSingleQuotes() { ColumnVector tmpMap = expectedStructs.makeListFromOffsets(5, expectedOffsets); ColumnVector templateBitmask = ColumnVector.fromBoxedInts(1, null, null, null, null); ColumnVector expectedMap = tmpMap.mergeAndSetValidity(BinaryOp.BITWISE_AND, - templateBitmask); + templateBitmask) ) { assertColumnsAreEqual(expectedMap, outputMap); } diff --git a/src/test/java/com/nvidia/spark/rapids/jni/GetJsonObjectTest.java b/src/test/java/com/nvidia/spark/rapids/jni/GetJsonObjectTest.java index b33b0be8ce..7f367873b3 100644 --- a/src/test/java/com/nvidia/spark/rapids/jni/GetJsonObjectTest.java +++ b/src/test/java/com/nvidia/spark/rapids/jni/GetJsonObjectTest.java @@ -16,15 +16,15 @@ package com.nvidia.spark.rapids.jni; +import static ai.rapids.cudf.AssertUtils.assertColumnsAreEqual; +import static org.junit.jupiter.api.Assertions.assertThrows; + import ai.rapids.cudf.ColumnVector; import ai.rapids.cudf.CudfException; -import org.junit.jupiter.api.Test; - import java.util.Arrays; +import java.util.Collections; import java.util.List; - -import static ai.rapids.cudf.AssertUtils.assertColumnsAreEqual; -import static org.junit.jupiter.api.Assertions.assertThrows; +import org.junit.jupiter.api.Test; public class GetJsonObjectTest { /** @@ -33,7 +33,7 @@ public class GetJsonObjectTest { @Test void getJsonObjectTest() { JSONUtils.PathInstructionJni[] query = new JSONUtils.PathInstructionJni[] { - namedPath("k") }; + namedPath("k")}; try (ColumnVector jsonCv = ColumnVector.fromStrings( "{\"k\": \"v\"}"); ColumnVector expected = ColumnVector.fromStrings( @@ -50,19 +50,23 @@ void getJsonObjectTest() { void getJsonObjectTest2() { JSONUtils.PathInstructionJni[] query = new JSONUtils.PathInstructionJni[] { - namedPath("k1_111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111") + namedPath( + "k1_111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111") }; - String JSON = "{\"k1_111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111\"" - + - ":\"v1_111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111\"}"; - String expectedStr = "v1_111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111"; + String JSON = + "{\"k1_111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111\"" + + + ":\"v1_111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111\"}"; + String expectedStr = + "v1_111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111"; try ( ColumnVector jsonCv = ColumnVector.fromStrings( JSON, JSON, JSON, JSON, JSON, JSON, JSON); ColumnVector expected = ColumnVector.fromStrings( - expectedStr, expectedStr, expectedStr, expectedStr, expectedStr, expectedStr, expectedStr); + expectedStr, expectedStr, expectedStr, expectedStr, expectedStr, expectedStr, + expectedStr); ColumnVector actual = JSONUtils.getJsonObject(jsonCv, query)) { assertColumnsAreEqual(expected, actual); } @@ -82,7 +86,8 @@ void getJsonObjectTest3() { ColumnVector jsonCv = ColumnVector.fromStrings( JSON, JSON, JSON, JSON, JSON, JSON, JSON); ColumnVector expected = ColumnVector.fromStrings( - expectedStr, expectedStr, expectedStr, expectedStr, expectedStr, expectedStr, expectedStr); + expectedStr, expectedStr, expectedStr, expectedStr, expectedStr, expectedStr, + expectedStr); ColumnVector actual = JSONUtils.getJsonObject(jsonCv, query)) { assertColumnsAreEqual(expected, actual); } @@ -110,7 +115,8 @@ void getJsonObjectTest4() { ColumnVector jsonCv = ColumnVector.fromStrings( JSON, JSON, JSON, JSON, JSON, JSON, JSON); ColumnVector expected = ColumnVector.fromStrings( - expectedStr, expectedStr, expectedStr, expectedStr, expectedStr, expectedStr, expectedStr); + expectedStr, expectedStr, expectedStr, expectedStr, expectedStr, expectedStr, + expectedStr); ColumnVector actual = JSONUtils.getJsonObject(jsonCv, query)) { assertColumnsAreEqual(expected, actual); } @@ -125,13 +131,15 @@ void getJsonObjectTest_Baidu_unescape_backslash() { namedPath("URdeosurl") }; - String JSON = "{\"brand\":\"ssssss\",\"duratRon\":15,\"eqTosuresurl\":\"\",\"RsZxarthrl\":false,\"xonRtorsurl\":\"\",\"xonRtorsurlstOTe\":0,\"TRctures\":[{\"RxaGe\":\"VttTs:\\/\\/feed-RxaGe.baRdu.cox\\/0\\/TRc\\/-196588744s840172444s-773690137.zTG\"}],\"Toster\":\"VttTs:\\/\\/feed-RxaGe.baRdu.cox\\/0\\/TRc\\/-196588744s840172444s-773690137.zTG\",\"reserUed\":{\"bRtLate\":391.79,\"xooUZRke\":26876,\"nahrlIeneratRonNOTe\":0,\"useJublRc\":6,\"URdeoRd\":821284086},\"tRtle\":\"ssssssssssmMsssssssssssssssssss\",\"url\":\"s{storehrl}\",\"usersTortraRt\":\"VttTs:\\/\\/feed-RxaGe.baRdu.cox\\/0\\/TRc\\/-6971178959s-664926866s-6096674871.zTG\",\"URdeosurl\":\"http:\\/\\/nadURdeo2.baRdu.cox\\/5fa3893aed7fc0f8231dab7be23efc75s820s6240.xT3\",\"URdeoRd\":821284086}"; + String JSON = + "{\"brand\":\"ssssss\",\"duratRon\":15,\"eqTosuresurl\":\"\",\"RsZxarthrl\":false,\"xonRtorsurl\":\"\",\"xonRtorsurlstOTe\":0,\"TRctures\":[{\"RxaGe\":\"VttTs:\\/\\/feed-RxaGe.baRdu.cox\\/0\\/TRc\\/-196588744s840172444s-773690137.zTG\"}],\"Toster\":\"VttTs:\\/\\/feed-RxaGe.baRdu.cox\\/0\\/TRc\\/-196588744s840172444s-773690137.zTG\",\"reserUed\":{\"bRtLate\":391.79,\"xooUZRke\":26876,\"nahrlIeneratRonNOTe\":0,\"useJublRc\":6,\"URdeoRd\":821284086},\"tRtle\":\"ssssssssssmMsssssssssssssssssss\",\"url\":\"s{storehrl}\",\"usersTortraRt\":\"VttTs:\\/\\/feed-RxaGe.baRdu.cox\\/0\\/TRc\\/-6971178959s-664926866s-6096674871.zTG\",\"URdeosurl\":\"http:\\/\\/nadURdeo2.baRdu.cox\\/5fa3893aed7fc0f8231dab7be23efc75s820s6240.xT3\",\"URdeoRd\":821284086}"; String expectedStr = "http://nadURdeo2.baRdu.cox/5fa3893aed7fc0f8231dab7be23efc75s820s6240.xT3"; try ( ColumnVector jsonCv = ColumnVector.fromStrings( JSON, JSON, JSON, JSON, JSON, JSON, JSON); ColumnVector expected = ColumnVector.fromStrings( - expectedStr, expectedStr, expectedStr, expectedStr, expectedStr, expectedStr, expectedStr); + expectedStr, expectedStr, expectedStr, expectedStr, expectedStr, expectedStr, + expectedStr); ColumnVector actual = JSONUtils.getJsonObject(jsonCv, query)) { assertColumnsAreEqual(expected, actual); } @@ -146,7 +154,8 @@ void getJsonObjectTest_Baidu_get_unexist_field_name() { namedPath("Vgdezsurl") }; - String JSON = "{\"brand\":\"ssssss\",\"duratgzn\":17,\"eSyzsuresurl\":\"\",\"gswUartWrl\":false,\"Uzngtzrsurl\":\"\",\"UzngtzrsurlstJye\":0,\"ygctures\":[{\"gUaqe\":\"Ittys:\\/\\/feed-gUaqe.bagdu.czU\\/0\\/ygc\\/63025364s-376461312s7528698939.Qyq\"}],\"yzster\":\"Ittys:\\/\\/feed-gUaqe.bagdu.czU\\,\"url\":\"s{stHreqrl}\",\"usersPHrtraIt\":\"LttPs:\\/\\/feed-IUaxe.baIdu.cHU\\/0\\/PIc\\/-1043913002s489796992s-1505641721.Pnx\",\"kIdeHsurl\":\"LttP:\\/\\/nadkIdeH9.baIdu.cHU\\/4d7d308bd7c04e63069fd343adfa792as1790s1080.UP3\",\"kIdeHId\":852890923}"; + String JSON = + "{\"brand\":\"ssssss\",\"duratgzn\":17,\"eSyzsuresurl\":\"\",\"gswUartWrl\":false,\"Uzngtzrsurl\":\"\",\"UzngtzrsurlstJye\":0,\"ygctures\":[{\"gUaqe\":\"Ittys:\\/\\/feed-gUaqe.bagdu.czU\\/0\\/ygc\\/63025364s-376461312s7528698939.Qyq\"}],\"yzster\":\"Ittys:\\/\\/feed-gUaqe.bagdu.czU\\,\"url\":\"s{stHreqrl}\",\"usersPHrtraIt\":\"LttPs:\\/\\/feed-IUaxe.baIdu.cHU\\/0\\/PIc\\/-1043913002s489796992s-1505641721.Pnx\",\"kIdeHsurl\":\"LttP:\\/\\/nadkIdeH9.baIdu.cHU\\/4d7d308bd7c04e63069fd343adfa792as1790s1080.UP3\",\"kIdeHId\":852890923}"; try ( ColumnVector jsonCv = ColumnVector.fromStrings( JSON, JSON, JSON, JSON, JSON, JSON, JSON); @@ -245,7 +254,8 @@ void getJsonObjectTest_Number_Normalization() { void getJsonObjectTest_Test_leading_zeros() { JSONUtils.PathInstructionJni[] query = new JSONUtils.PathInstructionJni[0]; try ( - ColumnVector jsonCv = ColumnVector.fromStrings("00", "01", "02", "000", "-01", "-00", "-02"); + ColumnVector jsonCv = ColumnVector.fromStrings("00", "01", "02", "000", "-01", "-00", + "-02"); ColumnVector expected = ColumnVector.fromStrings(null, null, null, null, null, null, null); ColumnVector actual = JSONUtils.getJsonObject(jsonCv, query)) { assertColumnsAreEqual(expected, actual); @@ -316,7 +326,7 @@ void getJsonObjectTest_Test_case_path1() { * case path 5: case (START_ARRAY, Subscript :: Wildcard :: Subscript :: * Wildcard :: xs), set flatten style * case path 2: case (START_ARRAY, Nil) if style == FlattenStyle - * + *

      * First use path5 [*][*] to enable flatten style. */ @Test @@ -390,7 +400,8 @@ void getJsonObjectTest_Test_case_path5() { }; // flatten the arrays, then query named path "k" - String JSON1 = "[ [[[ {'k': 'v1'} ], {'k': 'v2'}]], [[{'k': 'v3'}], {'k': 'v4'}], {'k': 'v5'} ]"; + String JSON1 = + "[ [[[ {'k': 'v1'} ], {'k': 'v2'}]], [[{'k': 'v3'}], {'k': 'v4'}], {'k': 'v5'} ]"; String expectedStr1 = "[\"v5\"]"; try ( @@ -415,7 +426,8 @@ void getJsonObjectTest_Test_case_path6() { String expectedStr1 = "[1,[21,22],3]"; String JSON2 = "[1]"; - String expectedStr2 = "1"; // note: in row mode, if it has only 1 item, then remove the outer: [] + String expectedStr2 = + "1"; // note: in row mode, if it has only 1 item, then remove the outer: [] try ( ColumnVector jsonCv = ColumnVector.fromStrings(JSON1, JSON2); @@ -626,8 +638,8 @@ void getJsonObjectTest_JNIKernelCalledTwice() { @Test void getJsonObjectMultiplePathsTest() { - List path0 = Arrays.asList(namedPath("k0")); - List path1 = Arrays.asList(namedPath("k1")); + List path0 = Collections.singletonList(namedPath("k0")); + List path1 = Collections.singletonList(namedPath("k1")); List> paths = Arrays.asList(path0, path1); try (ColumnVector jsonCv = ColumnVector.fromStrings("{\"k0\": \"v0\", \"k1\": \"v1\"}"); ColumnVector expected0 = ColumnVector.fromStrings("v0"); @@ -646,14 +658,16 @@ void getJsonObjectMultiplePathsTest() { @Test void getJsonObjectMultiplePathsTest_JNIKernelCalledTwice() { - List path0 = Arrays.asList(namedPath("k0")); - List path1 = Arrays.asList(namedPath("k1")); - List path2 = Arrays.asList(); + List path0 = Collections.singletonList(namedPath("k0")); + List path1 = Collections.singletonList(namedPath("k1")); + List path2 = Collections.emptyList(); List> paths = Arrays.asList(path0, path1, path2); - try (ColumnVector jsonCv = ColumnVector.fromStrings("{\"k0\": \"v0\", \"k1\": \"v1\"}", "['\n\n\n\n\n\n\n\n\n\n']"); + try (ColumnVector jsonCv = ColumnVector.fromStrings("{\"k0\": \"v0\", \"k1\": \"v1\"}", + "['\n\n\n\n\n\n\n\n\n\n']"); ColumnVector expected0 = ColumnVector.fromStrings("v0", null); ColumnVector expected1 = ColumnVector.fromStrings("v1", null); - ColumnVector expected2 = ColumnVector.fromStrings("{\"k0\":\"v0\",\"k1\":\"v1\"}", "[\"\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\"]")) { + ColumnVector expected2 = ColumnVector.fromStrings("{\"k0\":\"v0\",\"k1\":\"v1\"}", + "[\"\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\"]")) { ColumnVector[] output = JSONUtils.getJsonObjectMultiplePaths(jsonCv, paths); try { assertColumnsAreEqual(expected0, output[0]); @@ -669,8 +683,8 @@ void getJsonObjectMultiplePathsTest_JNIKernelCalledTwice() { @Test void getJsonObjectMultiplePathsTestCrazyLowMemoryBudget() { - List path0 = Arrays.asList(namedPath("k0")); - List path1 = Arrays.asList(namedPath("k1")); + List path0 = Collections.singletonList(namedPath("k0")); + List path1 = Collections.singletonList(namedPath("k1")); List> paths = Arrays.asList(path0, path1); try (ColumnVector jsonCv = ColumnVector.fromStrings("{\"k0\": \"v0\", \"k1\": \"v1\"}"); ColumnVector expected0 = ColumnVector.fromStrings("v0"); @@ -689,8 +703,8 @@ void getJsonObjectMultiplePathsTestCrazyLowMemoryBudget() { @Test void getJsonObjectMultiplePathsTestMemoryBudget() { - List path0 = Arrays.asList(namedPath("k0")); - List path1 = Arrays.asList(namedPath("k1")); + List path0 = Collections.singletonList(namedPath("k0")); + List path1 = Collections.singletonList(namedPath("k1")); List> paths = Arrays.asList(path0, path1); try (ColumnVector jsonCv = ColumnVector.fromStrings("{\"k0\": \"v0\", \"k1\": \"v1\"}"); ColumnVector expected0 = ColumnVector.fromStrings("v0"); @@ -725,7 +739,7 @@ void getJsonObjectTest_ExceedMaxNestingDepthInPath() { /** * This test is when an exception is thrown due to maximum nesting depth being exceeded * when pushing the context stack during evaluating the JSON path. - * + *

      * The maximum depth limit here is the same as the limit for the input JSON path. */ @Test @@ -752,7 +766,7 @@ void getJsonObjectTest_ExceedMaxNestingDepthInContextStack() { /** * This test is when an exception is thrown due to maximum nesting depth being exceeded * in the JSON parser. The JSON path is simply mirroring the input. - * + *

      * Note that the maximum depth in the internal parser, which is being tested here, is different * from the limit for the input JSON path. */ diff --git a/src/test/java/com/nvidia/spark/rapids/jni/GpuSubstringIndexUtilsTest.java b/src/test/java/com/nvidia/spark/rapids/jni/GpuSubstringIndexUtilsTest.java index 9666893448..414835ba26 100644 --- a/src/test/java/com/nvidia/spark/rapids/jni/GpuSubstringIndexUtilsTest.java +++ b/src/test/java/com/nvidia/spark/rapids/jni/GpuSubstringIndexUtilsTest.java @@ -20,57 +20,59 @@ import ai.rapids.cudf.ColumnVector; import ai.rapids.cudf.Scalar; import ai.rapids.cudf.Table; -import org.junit.jupiter.api.Test; - import java.util.ArrayList; import java.util.List; +import org.junit.jupiter.api.Test; public class GpuSubstringIndexUtilsTest { - @Test - void gpuSubstringIndexTest(){ - Table.TestBuilder tb = new Table.TestBuilder(); - tb.column( "www.apache.org"); - tb.column("www.apache"); - tb.column("www"); - tb.column(""); - tb.column("org"); - tb.column("apache.org"); - tb.column("www.apache.org"); - tb.column(""); - tb.column("大千世界大"); - tb.column("www||apache"); + @Test + void gpuSubstringIndexTest() { + Table.TestBuilder tb = new Table.TestBuilder(); + tb.column("www.apache.org"); + tb.column("www.apache"); + tb.column("www"); + tb.column(""); + tb.column("org"); + tb.column("apache.org"); + tb.column("www.apache.org"); + tb.column(""); + tb.column("大千世界大"); + tb.column("www||apache"); - try(Table expected = tb.build()){ - Table.TestBuilder tb2 = new Table.TestBuilder(); - tb2.column("www.apache.org"); - tb2.column("www.apache.org"); - tb2.column("www.apache.org"); - tb2.column("www.apache.org"); - tb2.column("www.apache.org"); - tb2.column("www.apache.org"); - tb2.column("www.apache.org"); - tb2.column(""); - tb2.column("大千世界大千世界"); - tb2.column("www||apache||org"); + try (Table expected = tb.build()) { + Table.TestBuilder tb2 = new Table.TestBuilder(); + tb2.column("www.apache.org"); + tb2.column("www.apache.org"); + tb2.column("www.apache.org"); + tb2.column("www.apache.org"); + tb2.column("www.apache.org"); + tb2.column("www.apache.org"); + tb2.column("www.apache.org"); + tb2.column(""); + tb2.column("大千世界大千世界"); + tb2.column("www||apache||org"); - Scalar dotScalar = Scalar.fromString("."); - Scalar cnChar = Scalar.fromString("千"); - Scalar verticalBar = Scalar.fromString("||"); - Scalar[] delimiterArray = new Scalar[]{dotScalar, dotScalar, dotScalar, dotScalar,dotScalar, dotScalar, dotScalar, dotScalar, cnChar, verticalBar}; - int[] countArray = new int[]{3, 2, 1, 0, -1, -2, -3, -2, 2, 2}; - List result = new ArrayList<>(); - try (Table origTable = tb2.build()){ - for(int i = 0; i < origTable.getNumberOfColumns(); i++){ - ColumnVector string_col = origTable.getColumn(i); - result.add(GpuSubstringIndexUtils.substringIndex(string_col, delimiterArray[i], countArray[i])); - } - try (Table result_tbl = new Table( - result.toArray(new ColumnVector[result.size()]))){ - AssertUtils.assertTablesAreEqual(expected, result_tbl); - } - }finally { - result.forEach(ColumnVector::close); - } + Scalar dotScalar = Scalar.fromString("."); + Scalar cnChar = Scalar.fromString("千"); + Scalar verticalBar = Scalar.fromString("||"); + Scalar[] delimiterArray = + new Scalar[] {dotScalar, dotScalar, dotScalar, dotScalar, dotScalar, dotScalar, dotScalar, + dotScalar, cnChar, verticalBar}; + int[] countArray = new int[] {3, 2, 1, 0, -1, -2, -3, -2, 2, 2}; + List result = new ArrayList<>(); + try (Table origTable = tb2.build()) { + for (int i = 0; i < origTable.getNumberOfColumns(); i++) { + ColumnVector string_col = origTable.getColumn(i); + result.add( + GpuSubstringIndexUtils.substringIndex(string_col, delimiterArray[i], countArray[i])); + } + try (Table result_tbl = new Table( + result.toArray(new ColumnVector[result.size()]))) { + AssertUtils.assertTablesAreEqual(expected, result_tbl); } + } finally { + result.forEach(ColumnVector::close); + } } + } } diff --git a/src/test/java/com/nvidia/spark/rapids/jni/HashTest.java b/src/test/java/com/nvidia/spark/rapids/jni/HashTest.java index d35f20fe2c..ffebb743d8 100644 --- a/src/test/java/com/nvidia/spark/rapids/jni/HashTest.java +++ b/src/test/java/com/nvidia/spark/rapids/jni/HashTest.java @@ -16,39 +16,44 @@ package com.nvidia.spark.rapids.jni; +import static ai.rapids.cudf.AssertUtils.assertColumnsAreEqual; + import ai.rapids.cudf.ColumnVector; import ai.rapids.cudf.ColumnView; import ai.rapids.cudf.DType; -import ai.rapids.cudf.HostColumnVector.*; -import org.junit.jupiter.api.Test; - +import ai.rapids.cudf.HostColumnVector.BasicType; +import ai.rapids.cudf.HostColumnVector.ListType; import java.util.Arrays; import java.util.Collections; - -import static ai.rapids.cudf.AssertUtils.*; +import org.junit.jupiter.api.Test; public class HashTest { -// IEEE 754 NaN values + // IEEE 754 NaN values static final float POSITIVE_FLOAT_NAN_LOWER_RANGE = Float.intBitsToFloat(0x7f800001); static final float POSITIVE_FLOAT_NAN_UPPER_RANGE = Float.intBitsToFloat(0x7fffffff); static final float NEGATIVE_FLOAT_NAN_LOWER_RANGE = Float.intBitsToFloat(0xff800001); static final float NEGATIVE_FLOAT_NAN_UPPER_RANGE = Float.intBitsToFloat(0xffffffff); - static final double POSITIVE_DOUBLE_NAN_LOWER_RANGE = Double.longBitsToDouble(0x7ff0000000000001L); - static final double POSITIVE_DOUBLE_NAN_UPPER_RANGE = Double.longBitsToDouble(0x7fffffffffffffffL); - static final double NEGATIVE_DOUBLE_NAN_LOWER_RANGE = Double.longBitsToDouble(0xfff0000000000001L); - static final double NEGATIVE_DOUBLE_NAN_UPPER_RANGE = Double.longBitsToDouble(0xffffffffffffffffL); + static final double POSITIVE_DOUBLE_NAN_LOWER_RANGE = + Double.longBitsToDouble(0x7ff0000000000001L); + static final double POSITIVE_DOUBLE_NAN_UPPER_RANGE = + Double.longBitsToDouble(0x7fffffffffffffffL); + static final double NEGATIVE_DOUBLE_NAN_LOWER_RANGE = + Double.longBitsToDouble(0xfff0000000000001L); + static final double NEGATIVE_DOUBLE_NAN_UPPER_RANGE = + Double.longBitsToDouble(0xffffffffffffffffL); @Test void testSpark32BitMurmur3HashStrings() { try (ColumnVector v0 = ColumnVector.fromStrings( - "a", "B\nc", "dE\"\u0100\t\u0101 \ud720\ud721\\Fg2\'", - "A very long (greater than 128 bytes/char string) to test a multi hash-step data point " + - "in the MD5 hash function. This string needed to be longer.A 60 character string to " + - "test MD5's message padding algorithm", - "hiJ\ud720\ud721\ud720\ud721", null); - ColumnVector result = Hash.murmurHash32(42, new ColumnVector[]{v0}); - ColumnVector expected = ColumnVector.fromBoxedInts(1485273170, 1709559900, 1423943036, 176121990, 1199621434, 42)) { + "a", "B\nc", "dE\"\u0100\t\u0101 \ud720\ud721\\Fg2'", + "A very long (greater than 128 bytes/char string) to test a multi hash-step data point " + + "in the MD5 hash function. This string needed to be longer.A 60 character string to " + + "test MD5's message padding algorithm", + "hiJ\ud720\ud721\ud720\ud721", null); + ColumnVector result = Hash.murmurHash32(42, new ColumnVector[] {v0}); + ColumnVector expected = ColumnVector.fromBoxedInts(1485273170, 1709559900, 1423943036, + 176121990, 1199621434, 42)) { assertColumnsAreEqual(expected, result); } } @@ -57,8 +62,9 @@ void testSpark32BitMurmur3HashStrings() { void testSpark32BitMurmur3HashInts() { try (ColumnVector v0 = ColumnVector.fromBoxedInts(0, 100, null, null, Integer.MIN_VALUE, null); ColumnVector v1 = ColumnVector.fromBoxedInts(0, null, -100, null, null, Integer.MAX_VALUE); - ColumnVector result = Hash.murmurHash32(42, new ColumnVector[]{v0, v1}); - ColumnVector expected = ColumnVector.fromBoxedInts(59727262, 751823303, -1080202046, 42, 723455942, 133916647)) { + ColumnVector result = Hash.murmurHash32(42, new ColumnVector[] {v0, v1}); + ColumnVector expected = ColumnVector.fromBoxedInts(59727262, 751823303, -1080202046, 42, + 723455942, 133916647)) { assertColumnsAreEqual(expected, result); } } @@ -66,12 +72,14 @@ void testSpark32BitMurmur3HashInts() { @Test void testSpark32BitMurmur3HashDoubles() { try (ColumnVector v = ColumnVector.fromBoxedDoubles( - 0.0, null, 100.0, -100.0, Double.MIN_NORMAL, Double.MAX_VALUE, - POSITIVE_DOUBLE_NAN_UPPER_RANGE, POSITIVE_DOUBLE_NAN_LOWER_RANGE, - NEGATIVE_DOUBLE_NAN_UPPER_RANGE, NEGATIVE_DOUBLE_NAN_LOWER_RANGE, - Double.POSITIVE_INFINITY, Double.NEGATIVE_INFINITY); - ColumnVector result = Hash.murmurHash32(new ColumnVector[]{v}); - ColumnVector expected = ColumnVector.fromBoxedInts(1669671676, 0, -544903190, -1831674681, 150502665, 474144502, 1428788237, 1428788237, 1428788237, 1428788237, 420913893, 1915664072)) { + 0.0, null, 100.0, -100.0, Double.MIN_NORMAL, Double.MAX_VALUE, + POSITIVE_DOUBLE_NAN_UPPER_RANGE, POSITIVE_DOUBLE_NAN_LOWER_RANGE, + NEGATIVE_DOUBLE_NAN_UPPER_RANGE, NEGATIVE_DOUBLE_NAN_LOWER_RANGE, + Double.POSITIVE_INFINITY, Double.NEGATIVE_INFINITY); + ColumnVector result = Hash.murmurHash32(new ColumnVector[] {v}); + ColumnVector expected = ColumnVector.fromBoxedInts(1669671676, 0, -544903190, -1831674681, + 150502665, 474144502, 1428788237, 1428788237, 1428788237, 1428788237, 420913893, + 1915664072)) { assertColumnsAreEqual(expected, result); } } @@ -82,8 +90,9 @@ void testSpark32BitMurmur3HashTimestamps() { // https://github.com/rapidsai/cudf/blob/aa7ca46dcd9e/cpp/tests/hashing/hash_test.cpp#L281-L307 try (ColumnVector v = ColumnVector.timestampMicroSecondsFromBoxedLongs( 0L, null, 100L, -100L, 0x123456789abcdefL, null, -0x123456789abcdefL); - ColumnVector result = Hash.murmurHash32(42, new ColumnVector[]{v}); - ColumnVector expected = ColumnVector.fromBoxedInts(-1670924195, 42, 1114849490, 904948192, 657182333, 42, -57193045)) { + ColumnVector result = Hash.murmurHash32(42, new ColumnVector[] {v}); + ColumnVector expected = ColumnVector.fromBoxedInts(-1670924195, 42, 1114849490, 904948192, + 657182333, 42, -57193045)) { assertColumnsAreEqual(expected, result); } } @@ -94,8 +103,9 @@ void testSpark32BitMurmur3HashDecimal64() { // https://github.com/rapidsai/cudf/blob/aa7ca46dcd9e/cpp/tests/hashing/hash_test.cpp#L281-L307 try (ColumnVector v = ColumnVector.decimalFromLongs(-7, 0L, 100L, -100L, 0x123456789abcdefL, -0x123456789abcdefL); - ColumnVector result = Hash.murmurHash32(42, new ColumnVector[]{v}); - ColumnVector expected = ColumnVector.fromBoxedInts(-1670924195, 1114849490, 904948192, 657182333, -57193045)) { + ColumnVector result = Hash.murmurHash32(42, new ColumnVector[] {v}); + ColumnVector expected = ColumnVector.fromBoxedInts(-1670924195, 1114849490, 904948192, + 657182333, -57193045)) { assertColumnsAreEqual(expected, result); } } @@ -106,8 +116,9 @@ void testSpark32BitMurmur3HashDecimal32() { // https://github.com/rapidsai/cudf/blob/aa7ca46dcd9e/cpp/tests/hashing/hash_test.cpp#L281-L307 try (ColumnVector v = ColumnVector.decimalFromInts(-3, 0, 100, -100, 0x12345678, -0x12345678); - ColumnVector result = Hash.murmurHash32(42, new ColumnVector[]{v}); - ColumnVector expected = ColumnVector.fromBoxedInts(-1670924195, 1114849490, 904948192, -958054811, -1447702630)) { + ColumnVector result = Hash.murmurHash32(42, new ColumnVector[] {v}); + ColumnVector expected = ColumnVector.fromBoxedInts(-1670924195, 1114849490, 904948192, + -958054811, -1447702630)) { assertColumnsAreEqual(expected, result); } } @@ -118,8 +129,9 @@ void testSpark32BitMurmur3HashDates() { // https://github.com/rapidsai/cudf/blob/aa7ca46dcd9e/cpp/tests/hashing/hash_test.cpp#L281-L307 try (ColumnVector v = ColumnVector.timestampDaysFromBoxedInts( 0, null, 100, -100, 0x12345678, null, -0x12345678); - ColumnVector result = Hash.murmurHash32(42, new ColumnVector[]{v}); - ColumnVector expected = ColumnVector.fromBoxedInts(933211791, 42, 751823303, -1080202046, -1721170160, 42, 1852996993)) { + ColumnVector result = Hash.murmurHash32(42, new ColumnVector[] {v}); + ColumnVector expected = ColumnVector.fromBoxedInts(933211791, 42, 751823303, -1080202046, + -1721170160, 42, 1852996993)) { assertColumnsAreEqual(expected, result); } } @@ -127,12 +139,14 @@ void testSpark32BitMurmur3HashDates() { @Test void testSpark32BitMurmur3HashFloats() { try (ColumnVector v = ColumnVector.fromBoxedFloats( - 0f, 100f, -100f, Float.MIN_NORMAL, Float.MAX_VALUE, null, - POSITIVE_FLOAT_NAN_LOWER_RANGE, POSITIVE_FLOAT_NAN_UPPER_RANGE, - NEGATIVE_FLOAT_NAN_LOWER_RANGE, NEGATIVE_FLOAT_NAN_UPPER_RANGE, - Float.POSITIVE_INFINITY, Float.NEGATIVE_INFINITY); - ColumnVector result = Hash.murmurHash32(411, new ColumnVector[]{v}); - ColumnVector expected = ColumnVector.fromBoxedInts(-235179434, 1812056886, 2028471189, 1775092689, -1531511762, 411, -1053523253, -1053523253, -1053523253, -1053523253, -1526256646, 930080402)){ + 0f, 100f, -100f, Float.MIN_NORMAL, Float.MAX_VALUE, null, + POSITIVE_FLOAT_NAN_LOWER_RANGE, POSITIVE_FLOAT_NAN_UPPER_RANGE, + NEGATIVE_FLOAT_NAN_LOWER_RANGE, NEGATIVE_FLOAT_NAN_UPPER_RANGE, + Float.POSITIVE_INFINITY, Float.NEGATIVE_INFINITY); + ColumnVector result = Hash.murmurHash32(411, new ColumnVector[] {v}); + ColumnVector expected = ColumnVector.fromBoxedInts(-235179434, 1812056886, 2028471189, + 1775092689, -1531511762, 411, -1053523253, -1053523253, -1053523253, -1053523253, + -1526256646, 930080402)) { assertColumnsAreEqual(expected, result); } } @@ -141,8 +155,9 @@ void testSpark32BitMurmur3HashFloats() { void testSpark32BitMurmur3HashBools() { try (ColumnVector v0 = ColumnVector.fromBoxedBooleans(null, true, false, true, null, false); ColumnVector v1 = ColumnVector.fromBoxedBooleans(null, true, false, null, false, true); - ColumnVector result = Hash.murmurHash32(0, new ColumnVector[]{v0, v1}); - ColumnVector expected = ColumnVector.fromBoxedInts(0, -1589400010, -239939054, -68075478, 593689054, -1194558265)) { + ColumnVector result = Hash.murmurHash32(0, new ColumnVector[] {v0, v1}); + ColumnVector expected = ColumnVector.fromBoxedInts(0, -1589400010, -239939054, -68075478, + 593689054, -1194558265)) { assertColumnsAreEqual(expected, result); } } @@ -150,18 +165,22 @@ void testSpark32BitMurmur3HashBools() { @Test void testSpark32BitMurmur3HashMixed() { try (ColumnVector strings = ColumnVector.fromStrings( - "a", "B\n", "dE\"\u0100\t\u0101 \ud720\ud721", - "A very long (greater than 128 bytes/char string) to test a multi hash-step data point " + - "in the MD5 hash function. This string needed to be longer.", - null, null); - ColumnVector integers = ColumnVector.fromBoxedInts(0, 100, -100, Integer.MIN_VALUE, Integer.MAX_VALUE, null); + "a", "B\n", "dE\"\u0100\t\u0101 \ud720\ud721", + "A very long (greater than 128 bytes/char string) to test a multi hash-step data point " + + "in the MD5 hash function. This string needed to be longer.", + null, null); + ColumnVector integers = ColumnVector.fromBoxedInts(0, 100, -100, Integer.MIN_VALUE, + Integer.MAX_VALUE, null); ColumnVector doubles = ColumnVector.fromBoxedDoubles( - 0.0, 100.0, -100.0, POSITIVE_DOUBLE_NAN_LOWER_RANGE, POSITIVE_DOUBLE_NAN_UPPER_RANGE, null); + 0.0, 100.0, -100.0, POSITIVE_DOUBLE_NAN_LOWER_RANGE, POSITIVE_DOUBLE_NAN_UPPER_RANGE, + null); ColumnVector floats = ColumnVector.fromBoxedFloats( - 0f, 100f, -100f, NEGATIVE_FLOAT_NAN_LOWER_RANGE, NEGATIVE_FLOAT_NAN_UPPER_RANGE, null); + 0f, 100f, -100f, NEGATIVE_FLOAT_NAN_LOWER_RANGE, NEGATIVE_FLOAT_NAN_UPPER_RANGE, null); ColumnVector bools = ColumnVector.fromBoxedBooleans(true, false, null, false, true, null); - ColumnVector result = Hash.murmurHash32(1868, new ColumnVector[]{strings, integers, doubles, floats, bools}); - ColumnVector expected = ColumnVector.fromBoxedInts(1936985022, 720652989, 339312041, 1400354989, 769988643, 1868)) { + ColumnVector result = Hash.murmurHash32(1868, + new ColumnVector[] {strings, integers, doubles, floats, bools}); + ColumnVector expected = ColumnVector.fromBoxedInts(1936985022, 720652989, 339312041, + 1400354989, 769988643, 1868)) { assertColumnsAreEqual(expected, result); } } @@ -173,15 +192,18 @@ void testSpark32BitMurmur3HashStruct() { "A very long (greater than 128 bytes/char string) to test a multi hash-step data point " + "in the MD5 hash function. This string needed to be longer.", null, null); - ColumnVector integers = ColumnVector.fromBoxedInts(0, 100, -100, Integer.MIN_VALUE, Integer.MAX_VALUE, null); + ColumnVector integers = ColumnVector.fromBoxedInts(0, 100, -100, Integer.MIN_VALUE, + Integer.MAX_VALUE, null); ColumnVector doubles = ColumnVector.fromBoxedDoubles( - 0.0, 100.0, -100.0, POSITIVE_DOUBLE_NAN_LOWER_RANGE, POSITIVE_DOUBLE_NAN_UPPER_RANGE, null); + 0.0, 100.0, -100.0, POSITIVE_DOUBLE_NAN_LOWER_RANGE, POSITIVE_DOUBLE_NAN_UPPER_RANGE, + null); ColumnVector floats = ColumnVector.fromBoxedFloats( 0f, 100f, -100f, NEGATIVE_FLOAT_NAN_LOWER_RANGE, NEGATIVE_FLOAT_NAN_UPPER_RANGE, null); ColumnVector bools = ColumnVector.fromBoxedBooleans(true, false, null, false, true, null); ColumnView structs = ColumnView.makeStructView(strings, integers, doubles, floats, bools); - ColumnVector result = Hash.murmurHash32(1868, new ColumnView[]{structs}); - ColumnVector expected = Hash.murmurHash32(1868, new ColumnVector[]{strings, integers, doubles, floats, bools})) { + ColumnVector result = Hash.murmurHash32(1868, new ColumnView[] {structs}); + ColumnVector expected = Hash.murmurHash32(1868, + new ColumnVector[] {strings, integers, doubles, floats, bools})) { assertColumnsAreEqual(expected, result); } } @@ -193,9 +215,11 @@ void testSpark32BitMurmur3HashNestedStruct() { "A very long (greater than 128 bytes/char string) to test a multi hash-step data point " + "in the MD5 hash function. This string needed to be longer.", null, null); - ColumnVector integers = ColumnVector.fromBoxedInts(0, 100, -100, Integer.MIN_VALUE, Integer.MAX_VALUE, null); + ColumnVector integers = ColumnVector.fromBoxedInts(0, 100, -100, Integer.MIN_VALUE, + Integer.MAX_VALUE, null); ColumnVector doubles = ColumnVector.fromBoxedDoubles( - 0.0, 100.0, -100.0, POSITIVE_DOUBLE_NAN_LOWER_RANGE, POSITIVE_DOUBLE_NAN_UPPER_RANGE, null); + 0.0, 100.0, -100.0, POSITIVE_DOUBLE_NAN_LOWER_RANGE, POSITIVE_DOUBLE_NAN_UPPER_RANGE, + null); ColumnVector floats = ColumnVector.fromBoxedFloats( 0f, 100f, -100f, NEGATIVE_FLOAT_NAN_LOWER_RANGE, NEGATIVE_FLOAT_NAN_UPPER_RANGE, null); ColumnVector bools = ColumnVector.fromBoxedBooleans(true, false, null, false, true, null); @@ -203,8 +227,9 @@ void testSpark32BitMurmur3HashNestedStruct() { ColumnView structs2 = ColumnView.makeStructView(structs1, doubles); ColumnView structs3 = ColumnView.makeStructView(bools); ColumnView structs = ColumnView.makeStructView(structs2, floats, structs3); - ColumnVector expected = Hash.murmurHash32(1868, new ColumnVector[]{strings, integers, doubles, floats, bools}); - ColumnVector result = Hash.murmurHash32(1868, new ColumnView[]{structs})) { + ColumnVector expected = Hash.murmurHash32(1868, + new ColumnVector[] {strings, integers, doubles, floats, bools}); + ColumnVector result = Hash.murmurHash32(1868, new ColumnView[] {structs})) { assertColumnsAreEqual(expected, result); } } @@ -212,23 +237,24 @@ void testSpark32BitMurmur3HashNestedStruct() { @Test void testSpark32BitMurmur3HashListsAndNestedLists() { try (ColumnVector stringListCV = ColumnVector.fromLists( - new ListType(true, new BasicType(true, DType.STRING)), - Arrays.asList(null, "a"), - Arrays.asList("B\n", ""), - Arrays.asList("dE\"\u0100\t\u0101", " \ud720\ud721"), - Collections.singletonList("A very long (greater than 128 bytes/char string) to test a multi" + - " hash-step data point in the Murmur3 hash function. This string needed to be longer."), - Collections.singletonList(""), - null); + new ListType(true, new BasicType(true, DType.STRING)), + Arrays.asList(null, "a"), + Arrays.asList("B\n", ""), + Arrays.asList("dE\"\u0100\t\u0101", " \ud720\ud721"), + Collections.singletonList( + "A very long (greater than 128 bytes/char string) to test a multi" + + " hash-step data point in the Murmur3 hash function. This string needed to be longer."), + Collections.singletonList(""), + null); ColumnVector strings1 = ColumnVector.fromStrings( "a", "B\n", "dE\"\u0100\t\u0101", "A very long (greater than 128 bytes/char string) to test a multi hash-step data point " + - "in the Murmur3 hash function. This string needed to be longer.", null, null); + "in the Murmur3 hash function. This string needed to be longer.", null, null); ColumnVector strings2 = ColumnVector.fromStrings( null, "", " \ud720\ud721", null, "", null); ColumnView stringStruct = ColumnView.makeStructView(strings1, strings2); - ColumnVector stringExpected = Hash.murmurHash32(1868, new ColumnView[]{stringStruct}); - ColumnVector stringResult = Hash.murmurHash32(1868, new ColumnView[]{stringListCV}); + ColumnVector stringExpected = Hash.murmurHash32(1868, new ColumnView[] {stringStruct}); + ColumnVector stringResult = Hash.murmurHash32(1868, new ColumnView[] {stringListCV}); ColumnVector intListCV = ColumnVector.fromLists( new ListType(true, new BasicType(true, DType.INT32)), null, @@ -237,21 +263,25 @@ void testSpark32BitMurmur3HashListsAndNestedLists() { Arrays.asList(5, -6, null), Collections.singletonList(Integer.MIN_VALUE), null); - ColumnVector integers1 = ColumnVector.fromBoxedInts(null, 0, null, 5, Integer.MIN_VALUE, null); - ColumnVector integers2 = ColumnVector.fromBoxedInts(null, -2, Integer.MAX_VALUE, null, null, null); + ColumnVector integers1 = ColumnVector.fromBoxedInts(null, 0, null, 5, Integer.MIN_VALUE, + null); + ColumnVector integers2 = ColumnVector.fromBoxedInts(null, -2, Integer.MAX_VALUE, null, + null, null); ColumnVector integers3 = ColumnVector.fromBoxedInts(null, 3, null, -6, null, null); ColumnVector intExpected = - Hash.murmurHash32(1868, new ColumnVector[]{integers1, integers2, integers3}); - ColumnVector intResult = Hash.murmurHash32(1868, new ColumnVector[]{intListCV}); + Hash.murmurHash32(1868, new ColumnVector[] {integers1, integers2, integers3}); + ColumnVector intResult = Hash.murmurHash32(1868, new ColumnVector[] {intListCV}); ColumnVector doubles = ColumnVector.fromBoxedDoubles( - 0.0, 100.0, -100.0, POSITIVE_DOUBLE_NAN_LOWER_RANGE, POSITIVE_DOUBLE_NAN_UPPER_RANGE, null); + 0.0, 100.0, -100.0, POSITIVE_DOUBLE_NAN_LOWER_RANGE, POSITIVE_DOUBLE_NAN_UPPER_RANGE, + null); ColumnVector floats = ColumnVector.fromBoxedFloats( - 0f, 100f, -100f, NEGATIVE_FLOAT_NAN_LOWER_RANGE, NEGATIVE_FLOAT_NAN_UPPER_RANGE, null); + 0f, 100f, -100f, NEGATIVE_FLOAT_NAN_LOWER_RANGE, NEGATIVE_FLOAT_NAN_UPPER_RANGE, null); ColumnView structCV = ColumnView.makeStructView(intListCV, stringListCV, doubles, floats); ColumnVector nestedExpected = - Hash.murmurHash32(1868, new ColumnView[]{intListCV, strings1, strings2, doubles, floats}); + Hash.murmurHash32(1868, + new ColumnView[] {intListCV, strings1, strings2, doubles, floats}); ColumnVector nestedResult = - Hash.murmurHash32(1868, new ColumnView[]{structCV})) { + Hash.murmurHash32(1868, new ColumnView[] {structCV})) { assertColumnsAreEqual(stringExpected, stringResult); assertColumnsAreEqual(intExpected, intResult); assertColumnsAreEqual(nestedExpected, nestedResult); @@ -261,13 +291,15 @@ void testSpark32BitMurmur3HashListsAndNestedLists() { @Test void testXXHash64Strings() { try (ColumnVector v0 = ColumnVector.fromStrings( - "a", "B\nc", "dE\"\u0100\t\u0101 \ud720\ud721\\Fg2\'", - "A very long (greater than 128 bytes/char string) to test a multi hash-step data point " + - "in the MD5 hash function. This string needed to be longer.A 60 character string to " + - "test MD5's message padding algorithm", - "hiJ\ud720\ud721\ud720\ud721", null); - ColumnVector result = Hash.xxhash64(new ColumnVector[]{v0}); - ColumnVector expected = ColumnVector.fromBoxedLongs(-8582455328737087284L, 2221214721321197934L, 5798966295358745941L, -4834097201550955483L, -3782648123388245694L, Hash.DEFAULT_XXHASH64_SEED)) { + "a", "B\nc", "dE\"\u0100\t\u0101 \ud720\ud721\\Fg2'", + "A very long (greater than 128 bytes/char string) to test a multi hash-step data point " + + "in the MD5 hash function. This string needed to be longer.A 60 character string to " + + "test MD5's message padding algorithm", + "hiJ\ud720\ud721\ud720\ud721", null); + ColumnVector result = Hash.xxhash64(new ColumnVector[] {v0}); + ColumnVector expected = ColumnVector.fromBoxedLongs(-8582455328737087284L, + 2221214721321197934L, 5798966295358745941L, -4834097201550955483L, + -3782648123388245694L, Hash.DEFAULT_XXHASH64_SEED)) { assertColumnsAreEqual(expected, result); } } @@ -276,57 +308,69 @@ void testXXHash64Strings() { void testXXHash64Ints() { try (ColumnVector v0 = ColumnVector.fromBoxedInts(0, 100, null, null, Integer.MIN_VALUE, null); ColumnVector v1 = ColumnVector.fromBoxedInts(0, null, -100, null, null, Integer.MAX_VALUE); - ColumnVector result = Hash.xxhash64(new ColumnVector[]{v0, v1}); - ColumnVector expected = ColumnVector.fromBoxedLongs(1151812168208346021L, -7987742665087449293L, 8990748234399402673L, Hash.DEFAULT_XXHASH64_SEED, 2073849959933241805L, 1508894993788531228L)) { + ColumnVector result = Hash.xxhash64(new ColumnVector[] {v0, v1}); + ColumnVector expected = ColumnVector.fromBoxedLongs(1151812168208346021L, + -7987742665087449293L, 8990748234399402673L, Hash.DEFAULT_XXHASH64_SEED, + 2073849959933241805L, 1508894993788531228L)) { assertColumnsAreEqual(expected, result); } } - + @Test void testXXHash64Doubles() { try (ColumnVector v = ColumnVector.fromBoxedDoubles( - 0.0, null, 100.0, -100.0, Double.MIN_NORMAL, Double.MAX_VALUE, - POSITIVE_DOUBLE_NAN_UPPER_RANGE, POSITIVE_DOUBLE_NAN_LOWER_RANGE, - NEGATIVE_DOUBLE_NAN_UPPER_RANGE, NEGATIVE_DOUBLE_NAN_LOWER_RANGE, - Double.POSITIVE_INFINITY, Double.NEGATIVE_INFINITY); - ColumnVector result = Hash.xxhash64(new ColumnVector[]{v}); - ColumnVector expected = ColumnVector.fromBoxedLongs(-5252525462095825812L, Hash.DEFAULT_XXHASH64_SEED, -7996023612001835843L, 5695175288042369293L, 6181148431538304986L, -4222314252576420879L, -3127944061524951246L, -3127944061524951246L, -3127944061524951246L, -3127944061524951246L, 5810986238603807492L, 5326262080505358431L)) { + 0.0, null, 100.0, -100.0, Double.MIN_NORMAL, Double.MAX_VALUE, + POSITIVE_DOUBLE_NAN_UPPER_RANGE, POSITIVE_DOUBLE_NAN_LOWER_RANGE, + NEGATIVE_DOUBLE_NAN_UPPER_RANGE, NEGATIVE_DOUBLE_NAN_LOWER_RANGE, + Double.POSITIVE_INFINITY, Double.NEGATIVE_INFINITY); + ColumnVector result = Hash.xxhash64(new ColumnVector[] {v}); + ColumnVector expected = ColumnVector.fromBoxedLongs(-5252525462095825812L, + Hash.DEFAULT_XXHASH64_SEED, -7996023612001835843L, 5695175288042369293L, + 6181148431538304986L, -4222314252576420879L, -3127944061524951246L, + -3127944061524951246L, -3127944061524951246L, -3127944061524951246L, + 5810986238603807492L, 5326262080505358431L)) { assertColumnsAreEqual(expected, result); } } - + @Test void testXXHash64Timestamps() { // The hash values were derived from Apache Spark in a manner similar to the one documented at // https://github.com/rapidsai/cudf/blob/aa7ca46dcd9e/cpp/tests/hashing/hash_test.cpp#L281-L307 try (ColumnVector v = ColumnVector.timestampMicroSecondsFromBoxedLongs( 0L, null, 100L, -100L, 0x123456789abcdefL, null, -0x123456789abcdefL); - ColumnVector result = Hash.xxhash64(new ColumnVector[]{v}); - ColumnVector expected = ColumnVector.fromBoxedLongs(-5252525462095825812L, Hash.DEFAULT_XXHASH64_SEED, 8713583529807266080L, 5675770457807661948L, 1941233597257011502L, Hash.DEFAULT_XXHASH64_SEED, -1318946533059658749L)) { + ColumnVector result = Hash.xxhash64(new ColumnVector[] {v}); + ColumnVector expected = ColumnVector.fromBoxedLongs(-5252525462095825812L, + Hash.DEFAULT_XXHASH64_SEED, 8713583529807266080L, 5675770457807661948L, + 1941233597257011502L, Hash.DEFAULT_XXHASH64_SEED, -1318946533059658749L)) { assertColumnsAreEqual(expected, result); } } - + @Test void testXXHash64Decimal64() { // The hash values were derived from Apache Spark in a manner similar to the one documented at // https://github.com/rapidsai/cudf/blob/aa7ca46dcd9e/cpp/tests/hashing/hash_test.cpp#L281-L307 try (ColumnVector v = ColumnVector.decimalFromLongs(-7, 0L, 100L, -100L, 0x123456789abcdefL, -0x123456789abcdefL); - ColumnVector result = Hash.xxhash64(new ColumnVector[]{v}); - ColumnVector expected = ColumnVector.fromBoxedLongs(-5252525462095825812L, 8713583529807266080L, 5675770457807661948L, 1941233597257011502L, -1318946533059658749L)) { + ColumnVector result = Hash.xxhash64(new ColumnVector[] {v}); + ColumnVector expected = ColumnVector.fromBoxedLongs(-5252525462095825812L, + 8713583529807266080L, 5675770457807661948L, 1941233597257011502L, + -1318946533059658749L)) { assertColumnsAreEqual(expected, result); } } - + @Test void testXXHash64Decimal32() { // The hash values were derived from Apache Spark in a manner similar to the one documented at // https://github.com/rapidsai/cudf/blob/aa7ca46dcd9e/cpp/tests/hashing/hash_test.cpp#L281-L307 try (ColumnVector v = ColumnVector.decimalFromInts(-3, 0, 100, -100, 0x12345678, -0x12345678); - ColumnVector result = Hash.xxhash64(new ColumnVector[]{v}); - ColumnVector expected = ColumnVector.fromBoxedLongs(-5252525462095825812L, 8713583529807266080L, 5675770457807661948L, -7728554078125612835L, 3142315292375031143L)) { + ColumnVector result = Hash.xxhash64(new ColumnVector[] {v}); + ColumnVector expected = ColumnVector.fromBoxedLongs(-5252525462095825812L, + 8713583529807266080L, 5675770457807661948L, -7728554078125612835L, + 3142315292375031143L)) { assertColumnsAreEqual(expected, result); } } @@ -337,8 +381,10 @@ void testXXHash64Dates() { // https://github.com/rapidsai/cudf/blob/aa7ca46dcd9e/cpp/tests/hashing/hash_test.cpp#L281-L307 try (ColumnVector v = ColumnVector.timestampDaysFromBoxedInts( 0, null, 100, -100, 0x12345678, null, -0x12345678); - ColumnVector result = Hash.xxhash64(new ColumnVector[]{v}); - ColumnVector expected = ColumnVector.fromBoxedLongs(3614696996920510707L, Hash.DEFAULT_XXHASH64_SEED, -7987742665087449293L, 8990748234399402673L, 6954428822481665164L, Hash.DEFAULT_XXHASH64_SEED, -4294222333805341278L)) { + ColumnVector result = Hash.xxhash64(new ColumnVector[] {v}); + ColumnVector expected = ColumnVector.fromBoxedLongs(3614696996920510707L, + Hash.DEFAULT_XXHASH64_SEED, -7987742665087449293L, 8990748234399402673L, + 6954428822481665164L, Hash.DEFAULT_XXHASH64_SEED, -4294222333805341278L)) { assertColumnsAreEqual(expected, result); } } @@ -346,12 +392,16 @@ void testXXHash64Dates() { @Test void testXXHash64Floats() { try (ColumnVector v = ColumnVector.fromBoxedFloats( - 0f, 100f, -100f, Float.MIN_NORMAL, Float.MAX_VALUE, null, - POSITIVE_FLOAT_NAN_LOWER_RANGE, POSITIVE_FLOAT_NAN_UPPER_RANGE, - NEGATIVE_FLOAT_NAN_LOWER_RANGE, NEGATIVE_FLOAT_NAN_UPPER_RANGE, - Float.POSITIVE_INFINITY, Float.NEGATIVE_INFINITY); - ColumnVector result = Hash.xxhash64(new ColumnVector[]{v}); - ColumnVector expected = ColumnVector.fromBoxedLongs(3614696996920510707L, -8232251799677946044L, -6625719127870404449L, -6699704595004115126L, -1065250890878313112L, Hash.DEFAULT_XXHASH64_SEED, 2692338816207849720L, 2692338816207849720L, 2692338816207849720L, 2692338816207849720L, -5940311692336719973L, -7580553461823983095L)){ + 0f, 100f, -100f, Float.MIN_NORMAL, Float.MAX_VALUE, null, + POSITIVE_FLOAT_NAN_LOWER_RANGE, POSITIVE_FLOAT_NAN_UPPER_RANGE, + NEGATIVE_FLOAT_NAN_LOWER_RANGE, NEGATIVE_FLOAT_NAN_UPPER_RANGE, + Float.POSITIVE_INFINITY, Float.NEGATIVE_INFINITY); + ColumnVector result = Hash.xxhash64(new ColumnVector[] {v}); + ColumnVector expected = ColumnVector.fromBoxedLongs(3614696996920510707L, + -8232251799677946044L, -6625719127870404449L, -6699704595004115126L, + -1065250890878313112L, Hash.DEFAULT_XXHASH64_SEED, 2692338816207849720L, + 2692338816207849720L, 2692338816207849720L, 2692338816207849720L, + -5940311692336719973L, -7580553461823983095L)) { assertColumnsAreEqual(expected, result); } } @@ -360,27 +410,34 @@ void testXXHash64Floats() { void testXXHash64Bools() { try (ColumnVector v0 = ColumnVector.fromBoxedBooleans(null, true, false, true, null, false); ColumnVector v1 = ColumnVector.fromBoxedBooleans(null, true, false, null, false, true); - ColumnVector result = Hash.xxhash64(new ColumnVector[]{v0, v1}); - ColumnVector expected = ColumnVector.fromBoxedLongs(Hash.DEFAULT_XXHASH64_SEED, 9083826852238114423L, 1151812168208346021L, -6698625589789238999L, 3614696996920510707L, 7945966957015589024L)) { + ColumnVector result = Hash.xxhash64(new ColumnVector[] {v0, v1}); + ColumnVector expected = ColumnVector.fromBoxedLongs(Hash.DEFAULT_XXHASH64_SEED, + 9083826852238114423L, 1151812168208346021L, -6698625589789238999L, + 3614696996920510707L, 7945966957015589024L)) { assertColumnsAreEqual(expected, result); } } - + @Test void testXXHash64Mixed() { try (ColumnVector strings = ColumnVector.fromStrings( - "a", "B\n", "dE\"\u0100\t\u0101 \ud720\ud721", - "A very long (greater than 128 bytes/char string) to test a multi hash-step data point " + - "in the MD5 hash function. This string needed to be longer.", - null, null); - ColumnVector integers = ColumnVector.fromBoxedInts(0, 100, -100, Integer.MIN_VALUE, Integer.MAX_VALUE, null); + "a", "B\n", "dE\"\u0100\t\u0101 \ud720\ud721", + "A very long (greater than 128 bytes/char string) to test a multi hash-step data point " + + "in the MD5 hash function. This string needed to be longer.", + null, null); + ColumnVector integers = ColumnVector.fromBoxedInts(0, 100, -100, Integer.MIN_VALUE, + Integer.MAX_VALUE, null); ColumnVector doubles = ColumnVector.fromBoxedDoubles( - 0.0, 100.0, -100.0, POSITIVE_DOUBLE_NAN_LOWER_RANGE, POSITIVE_DOUBLE_NAN_UPPER_RANGE, null); + 0.0, 100.0, -100.0, POSITIVE_DOUBLE_NAN_LOWER_RANGE, POSITIVE_DOUBLE_NAN_UPPER_RANGE, + null); ColumnVector floats = ColumnVector.fromBoxedFloats( - 0f, 100f, -100f, NEGATIVE_FLOAT_NAN_LOWER_RANGE, NEGATIVE_FLOAT_NAN_UPPER_RANGE, null); + 0f, 100f, -100f, NEGATIVE_FLOAT_NAN_LOWER_RANGE, NEGATIVE_FLOAT_NAN_UPPER_RANGE, null); ColumnVector bools = ColumnVector.fromBoxedBooleans(true, false, null, false, true, null); - ColumnVector result = Hash.xxhash64(new ColumnVector[]{strings, integers, doubles, floats, bools}); - ColumnVector expected = ColumnVector.fromBoxedLongs(7451748878409563026L, 6024043102550151964L, 3380664624738534402L, 8444697026100086329L, -5888679192448042852L, Hash.DEFAULT_XXHASH64_SEED)) { + ColumnVector result = Hash.xxhash64( + new ColumnVector[] {strings, integers, doubles, floats, bools}); + ColumnVector expected = ColumnVector.fromBoxedLongs(7451748878409563026L, + 6024043102550151964L, 3380664624738534402L, 8444697026100086329L, + -5888679192448042852L, Hash.DEFAULT_XXHASH64_SEED)) { assertColumnsAreEqual(expected, result); } } @@ -388,7 +445,7 @@ void testXXHash64Mixed() { @Test void testHiveHashBools() { try (ColumnVector v0 = ColumnVector.fromBoxedBooleans(true, false, null); - ColumnVector result = Hash.hiveHash(new ColumnVector[]{v0}); + ColumnVector result = Hash.hiveHash(new ColumnVector[] {v0}); ColumnVector expected = ColumnVector.fromInts(1, 0, 0)) { assertColumnsAreEqual(expected, result); } @@ -397,10 +454,10 @@ void testHiveHashBools() { @Test void testHiveHashInts() { try (ColumnVector v0 = ColumnVector.fromBoxedInts( - Integer.MIN_VALUE, Integer.MAX_VALUE, -1, 1, -10, 10, null); - ColumnVector result = Hash.hiveHash(new ColumnVector[]{v0}); + Integer.MIN_VALUE, Integer.MAX_VALUE, -1, 1, -10, 10, null); + ColumnVector result = Hash.hiveHash(new ColumnVector[] {v0}); ColumnVector expected = ColumnVector.fromInts( - Integer.MIN_VALUE, Integer.MAX_VALUE, -1, 1, -10, 10, 0)) { + Integer.MIN_VALUE, Integer.MAX_VALUE, -1, 1, -10, 10, 0)) { assertColumnsAreEqual(expected, result); } } @@ -408,10 +465,10 @@ void testHiveHashInts() { @Test void testHiveHashBytes() { try (ColumnVector v0 = ColumnVector.fromBoxedBytes( - Byte.MIN_VALUE, Byte.MAX_VALUE, (byte)-1, (byte)1, (byte)-10, (byte)10, null); - ColumnVector result = Hash.hiveHash(new ColumnVector[]{v0}); + Byte.MIN_VALUE, Byte.MAX_VALUE, (byte) -1, (byte) 1, (byte) -10, (byte) 10, null); + ColumnVector result = Hash.hiveHash(new ColumnVector[] {v0}); ColumnVector expected = ColumnVector.fromInts( - Byte.MIN_VALUE, Byte.MAX_VALUE, -1, 1, -10, 10, 0)) { + Byte.MIN_VALUE, Byte.MAX_VALUE, -1, 1, -10, 10, 0)) { assertColumnsAreEqual(expected, result); } } @@ -419,10 +476,10 @@ void testHiveHashBytes() { @Test void testHiveHashLongs() { try (ColumnVector v0 = ColumnVector.fromBoxedLongs( - Long.MIN_VALUE, Long.MAX_VALUE, -1L, 1L, -10L, 10L, null); - ColumnVector result = Hash.hiveHash(new ColumnVector[]{v0}); + Long.MIN_VALUE, Long.MAX_VALUE, -1L, 1L, -10L, 10L, null); + ColumnVector result = Hash.hiveHash(new ColumnVector[] {v0}); ColumnVector expected = ColumnVector.fromInts( - Integer.MIN_VALUE, Integer.MIN_VALUE, 0, 1, 9, 10, 0)) { + Integer.MIN_VALUE, Integer.MIN_VALUE, 0, 1, 9, 10, 0)) { assertColumnsAreEqual(expected, result); } } @@ -430,11 +487,11 @@ void testHiveHashLongs() { @Test void testHiveHashStrings() { try (ColumnVector v0 = ColumnVector.fromStrings( - "a", "B\n", "dE\"\u0100\t\u0101 \ud720\ud721", null, - "This is a long string (greater than 128 bytes/char string) case to test this " + - "hash function. Just want an abnormal case here to see if any error may happen when" + - "doing the hive hashing"); - ColumnVector result = Hash.hiveHash(new ColumnVector[]{v0}); + "a", "B\n", "dE\"\u0100\t\u0101 \ud720\ud721", null, + "This is a long string (greater than 128 bytes/char string) case to test this " + + "hash function. Just want an abnormal case here to see if any error may happen when" + + "doing the hive hashing"); + ColumnVector result = Hash.hiveHash(new ColumnVector[] {v0}); ColumnVector expected = ColumnVector.fromInts(97, 2056, 745239896, 0, 2112075710)) { assertColumnsAreEqual(expected, result); } @@ -443,13 +500,14 @@ void testHiveHashStrings() { @Test void testHiveHashFloats() { try (ColumnVector v = ColumnVector.fromBoxedFloats(0f, 100f, -100f, Float.MIN_NORMAL, - Float.MAX_VALUE, null, Float.MIN_VALUE, - POSITIVE_FLOAT_NAN_LOWER_RANGE, POSITIVE_FLOAT_NAN_UPPER_RANGE, - NEGATIVE_FLOAT_NAN_LOWER_RANGE, NEGATIVE_FLOAT_NAN_UPPER_RANGE, - Float.POSITIVE_INFINITY, Float.NEGATIVE_INFINITY); - ColumnVector result = Hash.hiveHash(new ColumnVector[]{v}); + Float.MAX_VALUE, null, Float.MIN_VALUE, + POSITIVE_FLOAT_NAN_LOWER_RANGE, POSITIVE_FLOAT_NAN_UPPER_RANGE, + NEGATIVE_FLOAT_NAN_LOWER_RANGE, NEGATIVE_FLOAT_NAN_UPPER_RANGE, + Float.POSITIVE_INFINITY, Float.NEGATIVE_INFINITY); + ColumnVector result = Hash.hiveHash(new ColumnVector[] {v}); ColumnVector expected = ColumnVector.fromInts(0, 1120403456, -1027080192, 8388608, - 2139095039, 0, 1, 2143289344, 2143289344, 2143289344, 2143289344, 2139095040, -8388608)){ + 2139095039, 0, 1, 2143289344, 2143289344, 2143289344, 2143289344, 2139095040, + -8388608)) { assertColumnsAreEqual(expected, result); } } @@ -457,10 +515,10 @@ void testHiveHashFloats() { @Test void testHiveHashDoubles() { try (ColumnVector v = ColumnVector.fromBoxedDoubles(0.0, 100.0, -100.0, - POSITIVE_DOUBLE_NAN_LOWER_RANGE, POSITIVE_DOUBLE_NAN_UPPER_RANGE, null); - ColumnVector result = Hash.hiveHash(new ColumnVector[]{v}); + POSITIVE_DOUBLE_NAN_LOWER_RANGE, POSITIVE_DOUBLE_NAN_UPPER_RANGE, null); + ColumnVector result = Hash.hiveHash(new ColumnVector[] {v}); ColumnVector expected = ColumnVector.fromInts(0, 1079574528, -1067909120, - 2146959360, 2146959360, 0)){ + 2146959360, 2146959360, 0)) { assertColumnsAreEqual(expected, result); } } @@ -468,10 +526,10 @@ void testHiveHashDoubles() { @Test void testHiveHashDates() { try (ColumnVector v = ColumnVector.timestampDaysFromBoxedInts( - 0, null, 100, -100, 0x12345678, null, -0x12345678); - ColumnVector result = Hash.hiveHash(new ColumnVector[]{v}); + 0, null, 100, -100, 0x12345678, null, -0x12345678); + ColumnVector result = Hash.hiveHash(new ColumnVector[] {v}); ColumnVector expected = ColumnVector.fromInts( - 0, 0, 100, -100, 0x12345678, 0, -0x12345678)) { + 0, 0, 100, -100, 0x12345678, 0, -0x12345678)) { assertColumnsAreEqual(expected, result); } } @@ -480,9 +538,9 @@ void testHiveHashDates() { void testHiveHashTimestamps() { try (ColumnVector v = ColumnVector.timestampMicroSecondsFromBoxedLongs( 0L, null, 100L, -100L, 0x123456789abcdefL, null, -0x123456789abcdefL); - ColumnVector result = Hash.hiveHash(new ColumnVector[]{v}); + ColumnVector result = Hash.hiveHash(new ColumnVector[] {v}); ColumnVector expected = ColumnVector.fromInts( - 0, 0, 100000, 99999, -660040456, 0, 486894999)) { + 0, 0, 100000, 99999, -660040456, 0, 486894999)) { assertColumnsAreEqual(expected, result); } } @@ -490,23 +548,23 @@ void testHiveHashTimestamps() { @Test void testHiveHashMixed() { try (ColumnVector strings = ColumnVector.fromStrings( - "a", "B\n", "dE\"\u0100\t\u0101 \ud720\ud721", - "This is a long string (greater than 128 bytes/char string) case to test this " + - "hash function. Just want an abnormal case here to see if any error may happen when" + - "doing the hive hashing", - null, null); + "a", "B\n", "dE\"\u0100\t\u0101 \ud720\ud721", + "This is a long string (greater than 128 bytes/char string) case to test this " + + "hash function. Just want an abnormal case here to see if any error may happen when" + + "doing the hive hashing", + null, null); ColumnVector integers = ColumnVector.fromBoxedInts( - 0, 100, -100, Integer.MIN_VALUE, Integer.MAX_VALUE, null); + 0, 100, -100, Integer.MIN_VALUE, Integer.MAX_VALUE, null); ColumnVector doubles = ColumnVector.fromBoxedDoubles(0.0, 100.0, -100.0, - POSITIVE_DOUBLE_NAN_LOWER_RANGE, POSITIVE_DOUBLE_NAN_UPPER_RANGE, null); + POSITIVE_DOUBLE_NAN_LOWER_RANGE, POSITIVE_DOUBLE_NAN_UPPER_RANGE, null); ColumnVector floats = ColumnVector.fromBoxedFloats(0f, 100f, -100f, - NEGATIVE_FLOAT_NAN_LOWER_RANGE, NEGATIVE_FLOAT_NAN_UPPER_RANGE, null); + NEGATIVE_FLOAT_NAN_LOWER_RANGE, NEGATIVE_FLOAT_NAN_UPPER_RANGE, null); ColumnVector bools = ColumnVector.fromBoxedBooleans( - true, false, null, false, true, null); - ColumnVector result = Hash.hiveHash(new ColumnVector[]{ - strings, integers, doubles, floats, bools}); + true, false, null, false, true, null); + ColumnVector result = Hash.hiveHash(new ColumnVector[] { + strings, integers, doubles, floats, bools}); ColumnVector expected = ColumnVector.fromInts(89581538, 363542820, 413439036, - 1272817854, 1513589666, 0)) { + 1272817854, 1513589666, 0)) { assertColumnsAreEqual(expected, result); } } diff --git a/src/test/java/com/nvidia/spark/rapids/jni/HilbertIndexTest.java b/src/test/java/com/nvidia/spark/rapids/jni/HilbertIndexTest.java index 6226053612..47cdf36b28 100644 --- a/src/test/java/com/nvidia/spark/rapids/jni/HilbertIndexTest.java +++ b/src/test/java/com/nvidia/spark/rapids/jni/HilbertIndexTest.java @@ -16,18 +16,18 @@ package com.nvidia.spark.rapids.jni; +import static ai.rapids.cudf.AssertUtils.assertColumnsAreEqual; + import ai.rapids.cudf.ColumnVector; import org.davidmoten.hilbert.HilbertCurve; import org.davidmoten.hilbert.SmallHilbertCurve; import org.junit.jupiter.api.Test; -import static ai.rapids.cudf.AssertUtils.assertColumnsAreEqual; - public class HilbertIndexTest { static long[] getExpected(int numBits, int numRows, Integer[]... inputs) { final int dimensions = inputs.length; final int length = numBits * dimensions; - assert(length <= 64); + assert (length <= 64); SmallHilbertCurve shc = HilbertCurve.small().bits(numBits).dimensions(dimensions); long[] ret = new long[numRows]; long[] tmpInputs = new long[dimensions]; @@ -60,7 +60,7 @@ public static void doTest(int numBits, int numRows, Integer[]... inputs) { assertColumnsAreEqual(expectedCv, results); } } finally { - for (ColumnVector cv: cvInputs) { + for (ColumnVector cv : cvInputs) { if (cv != null) { cv.close(); } @@ -87,15 +87,15 @@ void test1Null() { @Test void testInt2NonNull() { - Integer[] inputs1 = { 1, 500, 1000, 250}; - Integer[] inputs2 = {500, 400, 300, 200}; + Integer[] inputs1 = {1, 500, 1000, 250}; + Integer[] inputs2 = {500, 400, 300, 200}; doTest(10, inputs1.length, inputs1, inputs2); } @Test void testInt2Null() { - Integer[] inputs1 = { 0, null, 50, 1000}; - Integer[] inputs2 = {200, 300, 100, 0}; + Integer[] inputs1 = {0, null, 50, 1000}; + Integer[] inputs2 = {200, 300, 100, 0}; doTest(10, inputs1.length, inputs1, inputs2); } diff --git a/src/test/java/com/nvidia/spark/rapids/jni/HistogramTest.java b/src/test/java/com/nvidia/spark/rapids/jni/HistogramTest.java index 9a1812f660..83c219e0a0 100644 --- a/src/test/java/com/nvidia/spark/rapids/jni/HistogramTest.java +++ b/src/test/java/com/nvidia/spark/rapids/jni/HistogramTest.java @@ -18,7 +18,6 @@ import ai.rapids.cudf.AssertUtils; import ai.rapids.cudf.ColumnVector; - import org.junit.jupiter.api.Test; public class HistogramTest { @@ -27,7 +26,7 @@ void testZeroFrequency() { try (ColumnVector values = ColumnVector.fromInts(5, 10, 30); ColumnVector freqs = ColumnVector.fromLongs(1, 0, 1); ColumnVector histogram = Histogram.createHistogramIfValid(values, freqs, true); - ColumnVector percentiles = Histogram.percentileFromHistogram(histogram, new double[]{1}, + ColumnVector percentiles = Histogram.percentileFromHistogram(histogram, new double[] {1}, false); ColumnVector expected = ColumnVector.fromBoxedDoubles(5.0, null, 30.0)) { AssertUtils.assertColumnsAreEqual(percentiles, expected); @@ -39,7 +38,7 @@ void testAllNulls() { try (ColumnVector values = ColumnVector.fromBoxedInts(null, null, null); ColumnVector freqs = ColumnVector.fromLongs(1, 2, 3); ColumnVector histogram = Histogram.createHistogramIfValid(values, freqs, true); - ColumnVector percentiles = Histogram.percentileFromHistogram(histogram, new double[]{0.5}, + ColumnVector percentiles = Histogram.percentileFromHistogram(histogram, new double[] {0.5}, false); ColumnVector expected = ColumnVector.fromBoxedDoubles(null, null, null)) { AssertUtils.assertColumnsAreEqual(percentiles, expected); diff --git a/src/test/java/com/nvidia/spark/rapids/jni/HostTableTest.java b/src/test/java/com/nvidia/spark/rapids/jni/HostTableTest.java index 0064dee1f5..76d1de597a 100644 --- a/src/test/java/com/nvidia/spark/rapids/jni/HostTableTest.java +++ b/src/test/java/com/nvidia/spark/rapids/jni/HostTableTest.java @@ -17,18 +17,30 @@ package com.nvidia.spark.rapids.jni; import ai.rapids.cudf.AssertUtils; +import ai.rapids.cudf.ColumnVector; import ai.rapids.cudf.Cuda; import ai.rapids.cudf.DType; -import ai.rapids.cudf.ColumnVector; -import ai.rapids.cudf.Table; import ai.rapids.cudf.HostColumnVector.BasicType; import ai.rapids.cudf.HostColumnVector.DataType; import ai.rapids.cudf.HostColumnVector.ListType; import ai.rapids.cudf.HostColumnVector.StructData; import ai.rapids.cudf.HostColumnVector.StructType; +import ai.rapids.cudf.Table; import org.junit.jupiter.api.Test; public class HostTableTest { + private static StructData struct(Object... values) { + return new StructData(values); + } + + private static StructData[] structs(StructData... values) { + return values; + } + + private static String[] strings(String... values) { + return values; + } + @Test public void testRoundTripSync() { try (Table expected = buildTable()) { @@ -96,20 +108,36 @@ private Table buildTable() { new BasicType(true, DType.INT32), new BasicType(false, DType.FLOAT32)); return new Table.TestBuilder() - .column( 100, 202, 3003, 40004, 5, -60, 1, null, 3, null, 5, null, 7, null, 9, null, 11, null, 13, null, 15) - .column( true, true, false, false, true, null, true, true, null, false, false, null, true, true, null, false, false, null, true, true, null) - .column( (byte)1, (byte)2, null, (byte)4, (byte)5, (byte)6, (byte)1, (byte)2, (byte)3, null, (byte)5, (byte)6, (byte)7, null, (byte)9, (byte)10, (byte)11, null, (byte)13, (byte)14, (byte)15) - .column((short)6, (short)5, (short)4, null, (short)2, (short)1, (short)1, (short)2, (short)3, null, (short)5, (short)6, (short)7, null, (short)9, (short)10, null, (short)12, (short)13, (short)14, null) - .column( 1L, null, 1001L, 50L, -2000L, null, 1L, 2L, 3L, 4L, null, 6L, 7L, 8L, 9L, null, 11L, 12L, 13L, 14L, null) - .column( 10.1f, 20f, Float.NaN, 3.1415f, -60f, null, 1f, 2f, 3f, 4f, 5f, null, 7f, 8f, 9f, 10f, 11f, null, 13f, 14f, 15f) - .column( 10.1f, 20f, Float.NaN, 3.1415f, -60f, -50f, 1f, 2f, 3f, 4f, 5f, 6f, 7f, 8f, 9f, 10f, 11f, 12f, 13f, 14f, 15f) - .column( 10.1, 20.0, 33.1, 3.1415, -60.5, null, 1., 2., 3., 4., 5., 6., null, 8., 9., 10., 11., 12., null, 14., 15.) - .timestampDayColumn(99, 100, 101, 102, 103, 104, 1, 2, 3, 4, 5, 6, 7, null, 9, 10, 11, 12, 13, null, 15) - .timestampMillisecondsColumn(9L, 1006L, 101L, 5092L, null, 88L, 1L, 2L, 3L, 4L, 5L ,6L, 7L, 8L, null, 10L, 11L, 12L, 13L, 14L, 15L) - .timestampSecondsColumn(1L, null, 3L, 4L, 5L, 6L, 1L, 2L, 3L, 4L, 5L ,6L, 7L, 8L, 9L, null, 11L, 12L, 13L, 14L, 15L) - .decimal32Column(-3, 100, 202, 3003, 40004, 5, -60, 1, null, 3, null, 5, null, 7, null, 9, null, 11, null, 13, null, 15) - .decimal64Column(-8, 1L, null, 1001L, 50L, -2000L, null, 1L, 2L, 3L, 4L, null, 6L, 7L, 8L, 9L, null, 11L, 12L, 13L, 14L, null) - .column( "A", "B", "C", "D", null, "TESTING", "1", "2", "3", "4", "5", "6", "7", null, "9", "10", "11", "12", "13", null, "15") + .column(100, 202, 3003, 40004, 5, -60, 1, null, 3, null, 5, null, 7, null, 9, null, 11, + null, 13, null, 15) + .column(true, true, false, false, true, null, true, true, null, false, false, null, true, + true, null, false, false, null, true, true, null) + .column((byte) 1, (byte) 2, null, (byte) 4, (byte) 5, (byte) 6, (byte) 1, (byte) 2, + (byte) 3, null, (byte) 5, (byte) 6, (byte) 7, null, (byte) 9, (byte) 10, (byte) 11, + null, (byte) 13, (byte) 14, (byte) 15) + .column((short) 6, (short) 5, (short) 4, null, (short) 2, (short) 1, (short) 1, (short) 2, + (short) 3, null, (short) 5, (short) 6, (short) 7, null, (short) 9, (short) 10, null, + (short) 12, (short) 13, (short) 14, null) + .column(1L, null, 1001L, 50L, -2000L, null, 1L, 2L, 3L, 4L, null, 6L, 7L, 8L, 9L, null, 11L, + 12L, 13L, 14L, null) + .column(10.1f, 20f, Float.NaN, 3.1415f, -60f, null, 1f, 2f, 3f, 4f, 5f, null, 7f, 8f, 9f, + 10f, 11f, null, 13f, 14f, 15f) + .column(10.1f, 20f, Float.NaN, 3.1415f, -60f, -50f, 1f, 2f, 3f, 4f, 5f, 6f, 7f, 8f, 9f, 10f, + 11f, 12f, 13f, 14f, 15f) + .column(10.1, 20.0, 33.1, 3.1415, -60.5, null, 1., 2., 3., 4., 5., 6., null, 8., 9., 10., + 11., 12., null, 14., 15.) + .timestampDayColumn(99, 100, 101, 102, 103, 104, 1, 2, 3, 4, 5, 6, 7, null, 9, 10, 11, 12, + 13, null, 15) + .timestampMillisecondsColumn(9L, 1006L, 101L, 5092L, null, 88L, 1L, 2L, 3L, 4L, 5L, 6L, 7L, + 8L, null, 10L, 11L, 12L, 13L, 14L, 15L) + .timestampSecondsColumn(1L, null, 3L, 4L, 5L, 6L, 1L, 2L, 3L, 4L, 5L, 6L, 7L, 8L, 9L, null, + 11L, 12L, 13L, 14L, 15L) + .decimal32Column(-3, 100, 202, 3003, 40004, 5, -60, 1, null, 3, null, 5, null, 7, null, 9, + null, 11, null, 13, null, 15) + .decimal64Column(-8, 1L, null, 1001L, 50L, -2000L, null, 1L, 2L, 3L, 4L, null, 6L, 7L, 8L, + 9L, null, 11L, 12L, 13L, 14L, null) + .column("A", "B", "C", "D", null, "TESTING", "1", "2", "3", "4", "5", "6", "7", null, "9", + "10", "11", "12", "13", null, "15") .column( strings("1", "2", "3"), strings("4"), strings("5"), strings("6, 7"), strings("", "9", null), strings("11"), strings(""), strings(null, null), @@ -135,19 +163,8 @@ null, structs(struct("3", "4"), struct("1", "2")), struct(Integer.MAX_VALUE, Float.MAX_VALUE), null, null, null, null, null, null, null, null, null, struct(Integer.MIN_VALUE, Float.MIN_VALUE)) - .column( "A", "A", "C", "C", null, "TESTING", "1", "2", "3", "4", "5", "6", "7", null, "9", "10", "11", "12", "13", null, "15") + .column("A", "A", "C", "C", null, "TESTING", "1", "2", "3", "4", "5", "6", "7", null, "9", + "10", "11", "12", "13", null, "15") .build(); } - - private static StructData struct(Object... values) { - return new StructData(values); - } - - private static StructData[] structs(StructData... values) { - return values; - } - - private static String[] strings(String... values) { - return values; - } } diff --git a/src/test/java/com/nvidia/spark/rapids/jni/InterleaveBitsTest.java b/src/test/java/com/nvidia/spark/rapids/jni/InterleaveBitsTest.java index 7455f3b58c..0491d5bba2 100644 --- a/src/test/java/com/nvidia/spark/rapids/jni/InterleaveBitsTest.java +++ b/src/test/java/com/nvidia/spark/rapids/jni/InterleaveBitsTest.java @@ -16,18 +16,20 @@ package com.nvidia.spark.rapids.jni; +import static ai.rapids.cudf.AssertUtils.assertColumnsAreEqual; + import ai.rapids.cudf.ColumnVector; import ai.rapids.cudf.DType; import ai.rapids.cudf.HostColumnVector; -import org.junit.jupiter.api.Test; - import java.util.ArrayList; import java.util.List; - -import static ai.rapids.cudf.AssertUtils.assertColumnsAreEqual; +import org.junit.jupiter.api.Test; public class InterleaveBitsTest { + public static HostColumnVector.DataType outputType = + new HostColumnVector.ListType(true, new HostColumnVector.BasicType(false, DType.UINT8)); + // The following source of truth comes from deltalake, but translated to java, and uses a List // to make our tests simpler. Deltalake only supports ints. For completeness and better // performance in the future we support more than this. @@ -48,7 +50,7 @@ static List defaultInterleaveBits(Integer[] inputs) { int idx = 0; while (idx < inputs.length) { int tmp = (((inputs[idx] >> bit) & 1) << ret_bit); - ret_byte = (byte)(ret_byte | tmp); + ret_byte = (byte) (ret_byte | tmp); ret_bit -= 1; if (ret_bit == -1) { // finished processing a byte @@ -61,8 +63,8 @@ static List defaultInterleaveBits(Integer[] inputs) { } bit -= 1; } - assert(ret_idx == inputs.length * 4); - assert(ret_bit == 7); + assert (ret_idx == inputs.length * 4); + assert (ret_bit == 7); return ret; } @@ -83,7 +85,7 @@ static List defaultInterleaveBits(Short[] inputs) { int idx = 0; while (idx < inputs.length) { int tmp = (((inputs[idx] >> bit) & 1) << ret_bit); - ret_byte = (byte)(ret_byte | tmp); + ret_byte = (byte) (ret_byte | tmp); ret_bit -= 1; if (ret_bit == -1) { // finished processing a byte @@ -96,8 +98,8 @@ static List defaultInterleaveBits(Short[] inputs) { } bit -= 1; } - assert(ret_idx == inputs.length * 2); - assert(ret_bit == 7); + assert (ret_idx == inputs.length * 2); + assert (ret_bit == 7); return ret; } @@ -118,7 +120,7 @@ static List defaultInterleaveBits(Byte[] inputs) { int idx = 0; while (idx < inputs.length) { int tmp = (((inputs[idx] >> bit) & 1) << ret_bit); - ret_byte = (byte)(ret_byte | tmp); + ret_byte = (byte) (ret_byte | tmp); ret_bit -= 1; if (ret_bit == -1) { // finished processing a byte @@ -131,8 +133,8 @@ static List defaultInterleaveBits(Byte[] inputs) { } bit -= 1; } - assert(ret_idx == inputs.length); - assert(ret_bit == 7); + assert (ret_idx == inputs.length); + assert (ret_bit == 7); return ret; } @@ -172,9 +174,6 @@ static List[] getExpected(int numRows, Byte[]... inputs) { return ret; } - public static HostColumnVector.DataType outputType = - new HostColumnVector.ListType(true, new HostColumnVector.BasicType(false, DType.UINT8)); - public static void doIntTest(int numRows, Integer[]... inputs) { List[] expected = getExpected(numRows, inputs); ColumnVector[] cvInputs = new ColumnVector[inputs.length]; @@ -187,7 +186,7 @@ public static void doIntTest(int numRows, Integer[]... inputs) { assertColumnsAreEqual(expectedCv, results); } } finally { - for (ColumnVector cv: cvInputs) { + for (ColumnVector cv : cvInputs) { if (cv != null) { cv.close(); } @@ -207,7 +206,7 @@ public static void doShortTest(int numRows, Short[]... inputs) { assertColumnsAreEqual(expectedCv, results); } } finally { - for (ColumnVector cv: cvInputs) { + for (ColumnVector cv : cvInputs) { if (cv != null) { cv.close(); } @@ -227,7 +226,7 @@ public static void doByteTest(int numRows, Byte[]... inputs) { assertColumnsAreEqual(expectedCv, results); } } finally { - for (ColumnVector cv: cvInputs) { + for (ColumnVector cv : cvInputs) { if (cv != null) { cv.close(); } @@ -295,21 +294,21 @@ void testInt2NonNull() { @Test void testShort2NonNull() { - Short[] inputs1 = {(short)0x0102, (short)0x0000, (short)0xFFFF, (short)0xFF00}; - Short[] inputs2 = {(short)0x1020, (short)0xFFFF, (short)0x0000, (short)0x00FF}; + Short[] inputs1 = {(short) 0x0102, (short) 0x0000, (short) 0xFFFF, (short) 0xFF00}; + Short[] inputs2 = {(short) 0x1020, (short) 0xFFFF, (short) 0x0000, (short) 0x00FF}; doShortTest(inputs1.length, inputs1, inputs2); } @Test void testByte2NonNull() { - Byte[] inputs1 = {(byte)0x01, (byte)0x00, (byte)0xFF, (byte)0x0F}; - Byte[] inputs2 = {(byte)0x10, (byte)0xFF, (byte)0x00, (byte)0xF0}; + Byte[] inputs1 = {(byte) 0x01, (byte) 0x00, (byte) 0xFF, (byte) 0x0F}; + Byte[] inputs2 = {(byte) 0x10, (byte) 0xFF, (byte) 0x00, (byte) 0xF0}; doByteTest(inputs1.length, inputs1, inputs2); } @Test void testInt2Null() { - Integer[] inputs1 = {0x00000000, null, 0xFFFFFFFF, 0xFF00FF00}; + Integer[] inputs1 = {0x00000000, null, 0xFFFFFFFF, 0xFF00FF00}; Integer[] inputs2 = {0xFFFFFFFF, 0x00000000, 0x00FF00FF, null}; doIntTest(inputs1.length, inputs1, inputs2); } @@ -324,17 +323,17 @@ void testInt3NonNull() { @Test void testShort3NonNull() { - Short[] inputs1 = {(short)0x0000, (short)0x4444, (short)0x1111}; - Short[] inputs2 = {(short)0x1111, (short)0x8888, (short)0x2222}; - Short[] inputs3 = {(short)0x2222, (short)0x0000, (short)0x4444}; + Short[] inputs1 = {(short) 0x0000, (short) 0x4444, (short) 0x1111}; + Short[] inputs2 = {(short) 0x1111, (short) 0x8888, (short) 0x2222}; + Short[] inputs3 = {(short) 0x2222, (short) 0x0000, (short) 0x4444}; doShortTest(inputs1.length, inputs1, inputs2, inputs3); } @Test void testByte3NonNull() { - Byte[] inputs1 = {(byte)0x00, (byte)0x44, (byte)0x11}; - Byte[] inputs2 = {(byte)0x11, (byte)0x88, (byte)0x22}; - Byte[] inputs3 = {(byte)0x22, (byte)0x00, (byte)0x44}; + Byte[] inputs1 = {(byte) 0x00, (byte) 0x44, (byte) 0x11}; + Byte[] inputs2 = {(byte) 0x11, (byte) 0x88, (byte) 0x22}; + Byte[] inputs3 = {(byte) 0x22, (byte) 0x00, (byte) 0x44}; doByteTest(inputs1.length, inputs1, inputs2, inputs3); } } \ No newline at end of file diff --git a/src/test/java/com/nvidia/spark/rapids/jni/LimitingOffHeapAllocForTests.java b/src/test/java/com/nvidia/spark/rapids/jni/LimitingOffHeapAllocForTests.java index eb32667dc7..082773a18e 100644 --- a/src/test/java/com/nvidia/spark/rapids/jni/LimitingOffHeapAllocForTests.java +++ b/src/test/java/com/nvidia/spark/rapids/jni/LimitingOffHeapAllocForTests.java @@ -17,7 +17,6 @@ package com.nvidia.spark.rapids.jni; import ai.rapids.cudf.HostMemoryBuffer; - import java.util.Optional; /** @@ -27,6 +26,7 @@ public class LimitingOffHeapAllocForTests { private static long limit; private static long amountAllocated = 0; + public static synchronized void setLimit(long limit) { LimitingOffHeapAllocForTests.limit = limit; if (amountAllocated > 0) { @@ -68,6 +68,7 @@ private static Optional allocInternal(long amount, boolean blo /** * Do a non-blocking allocation + * * @param amount the amount to allocate * @return the allocated buffer or not. */ @@ -77,6 +78,7 @@ public static Optional tryAlloc(long amount) { /** * Do a blocking allocation + * * @param amount the amount to allocate * @return the allocated buffer */ diff --git a/src/test/java/com/nvidia/spark/rapids/jni/ParseURITest.java b/src/test/java/com/nvidia/spark/rapids/jni/ParseURITest.java index 1ddf588b02..5dc2605921 100644 --- a/src/test/java/com/nvidia/spark/rapids/jni/ParseURITest.java +++ b/src/test/java/com/nvidia/spark/rapids/jni/ParseURITest.java @@ -16,18 +16,16 @@ package com.nvidia.spark.rapids.jni; +import ai.rapids.cudf.AssertUtils; +import ai.rapids.cudf.ColumnVector; import java.net.URI; import java.net.URISyntaxException; - import org.junit.jupiter.api.Test; -import ai.rapids.cudf.AssertUtils; -import ai.rapids.cudf.ColumnVector; - public class ParseURITest { void testProtocol(String[] testData) { String[] expectedProtocolStrings = new String[testData.length]; - for (int i=0; i\tnumber of tasks that can run in parallel on the GPU"); System.out.println("--seed=\tthe random seed to use for the test"); System.out.println("--gpuMiB=\tlimit on the GPUs memory to use for testing"); - System.out.println("--taskMaxMiB=\tmaximum amount of memory a regular task may have allocated"); - System.out.println("--allocMode=\tthe RMM allocation mode to use POOL, ASYNC, ARENA, CUDA"); - System.out.println("--taskRetry=\tmaximum number of times to retry a task before failing the situation"); + System.out.println( + "--taskMaxMiB=\tmaximum amount of memory a regular task may have allocated"); + System.out.println( + "--allocMode=\tthe RMM allocation mode to use POOL, ASYNC, ARENA, CUDA"); + System.out.println( + "--taskRetry=\tmaximum number of times to retry a task before failing the situation"); System.out.println("--maxTaskAllocs=\tmaximum number of allocations a task can make"); - System.out.println("--maxTaskSleep=\tmaximum amount of time a task can sleep for (sim processing)"); + System.out.println( + "--maxTaskSleep=\tmaximum amount of time a task can sleep for (sim processing)"); System.out.println("--noLog\tdisable logging"); System.out.println("--skewed\tgenerate templated tasks and skew one of them by skewAmount"); System.out.println("--skewAmount=\tthe amount to multiply the skewed allocations by"); - System.out.println("--useTemplate\tif all of the tasks should be the same, but change by +/- templateChangeAmount as a multiplier"); - System.out.println("--templateChangeAmount=\tA multiplication factor to change the template task by when making new tasks (as a multiplier)"); - System.out.println("--shuffleThreads=\tThe number of threads to use to simulate UCX shuffle"); + System.out.println( + "--useTemplate\tif all of the tasks should be the same, but change by +/- templateChangeAmount as a multiplier"); + System.out.println( + "--templateChangeAmount=\tA multiplication factor to change the template task by when making new tasks (as a multiplier)"); + System.out.println( + "--shuffleThreads=\tThe number of threads to use to simulate UCX shuffle"); System.exit(0); } else { throw new IllegalArgumentException("Unexpected argument " + arg + @@ -181,11 +196,11 @@ public static void main(String [] args) throws InterruptedException { } public static void setupRmm(int allocationMode, long limitMiB, boolean useSparkRmm, - boolean enableLogging) { + boolean enableLogging) { long limitBytes = limitMiB * 1024 * 1024; Rmm.LogConf rmmLog = null; if (enableLogging) { - rmmLog = Rmm.logTo(new File("./monte.rmm.log")); + rmmLog = Rmm.logTo(new File("./monte.rmm.log")); } if (allocationMode == RmmAllocationMode.CUDA_DEFAULT) { // We want to limit the total size, but Rmm will not do that by default... @@ -209,7 +224,8 @@ public static void setupRmm(int allocationMode, long limitMiB, boolean useSparkR } else { Rmm.initialize(allocationMode, rmmLog, limitBytes); } - boolean needsSync = (allocationMode & RmmAllocationMode.CUDA_ASYNC) == RmmAllocationMode.CUDA_ASYNC; + boolean needsSync = + (allocationMode & RmmAllocationMode.CUDA_ASYNC) == RmmAllocationMode.CUDA_ASYNC; if (useSparkRmm) { if (enableLogging) { RmmSpark.setEventHandler(new TestRmmEventHandler(needsSync), "./monte.state.log"); @@ -223,12 +239,76 @@ public static void setupRmm(int allocationMode, long limitMiB, boolean useSparkR } } + private static List generateSituations(long seed, int numIterations, long numTasks, + long taskMaxMiB, int maxTaskAllocs, + int maxTaskSleep, + boolean isSkewed, double skewAmount, + boolean useTemplate, + double templateChangeAmount) { + ArrayList ret = new ArrayList<>(numIterations); + long start = System.nanoTime(); + System.out.println("Generating " + numIterations + " test situations..."); + + Random r = new Random(seed); + for (int i = 0; i < numIterations; i++) { + ret.add(new Situation(r, numTasks, taskMaxMiB, maxTaskAllocs, maxTaskSleep, + isSkewed, skewAmount, useTemplate, templateChangeAmount)); + } + + long end = System.nanoTime(); + long diff = TimeUnit.MILLISECONDS.convert(end - start, TimeUnit.NANOSECONDS); + System.out.println("Took " + diff + " milliseconds to generate " + numIterations); + return ret; + } + + interface MemoryOp { + default void doIt(DeviceMemoryBuffer[] buffers, long taskId) { + long threadId = RmmSpark.getCurrentThreadId(); + RmmSpark.shuffleThreadWorkingOnTasks(new long[] {taskId}); + RmmSpark.startRetryBlock(threadId); + try { + int tries = 0; + while (tries < 100 && tries >= 0) { + try { + if (tries > 0) { + RmmSpark.blockThreadUntilReady(); + } + tries++; + doIt(buffers); + tries = -1; + } catch (GpuRetryOOM oom) { + // Don't need to clear the buffers, because there is only one buffer. + numRetry.incrementAndGet(); + } catch (CpuRetryOOM oom) { + // Don't need to clear the buffers, because there is only one buffer. + numRetry.incrementAndGet(); + } + } + if (tries >= 100) { + throw new OutOfMemoryError("Could not make shuffle work after " + tries + " tries"); + } + } finally { + RmmSpark.endRetryBlock(threadId); + RmmSpark.poolThreadFinishedForTask(taskId); + } + } + + void doIt(DeviceMemoryBuffer[] buffers); + + MemoryOp[] split(); + + MemoryOp randomMod(Random r, double templateChangeAmount); + + MemoryOp makeSkewed(double skewAmount); + } + private static class TestRmmEventHandler implements RmmEventHandler { private final boolean needsSync; public TestRmmEventHandler(boolean needsSync) { this.needsSync = needsSync; } + @Override public long[] getAllocThresholds() { return null; @@ -270,7 +350,7 @@ public static class TaskRunnerThread extends Thread { volatile boolean done = false; public TaskRunnerThread(CyclicBarrier barrier, SituationRunner runner, int taskRetry, - ExecutorService shuffle) { + ExecutorService shuffle) { this.barrier = barrier; this.runner = runner; this.taskRetry = taskRetry; @@ -374,6 +454,7 @@ public boolean hadOtherFailures() { static class ShuffleThreadFactory implements ThreadFactory { static final AtomicLong idGen = new AtomicLong(0); + @Override public Thread newThread(Runnable runnable) { long id = idGen.getAndIncrement(); @@ -385,12 +466,10 @@ public Thread newThread(Runnable runnable) { } public static class SituationRunner { - final TaskRunnerThread[] threads; public final boolean debugOoms; - private ExecutorService shuffle; + final TaskRunnerThread[] threads; final CyclicBarrier barrier; volatile boolean sitIsDone = false; - // Stats volatile int failedSits; volatile int successSits; @@ -401,6 +480,7 @@ public static class SituationRunner { volatile long totalTimeLost; volatile boolean sitFailed; volatile boolean didThisSitFail = false; + private final ExecutorService shuffle; public SituationRunner(int parallelism, int taskRetry, int shuffleThreads, boolean debugOoms) { this.debugOoms = debugOoms; @@ -433,6 +513,22 @@ public SituationRunner(int parallelism, int taskRetry, int shuffleThreads, boole } } + private static String asTimeStr(long timeNs) { + long justms = TimeUnit.NANOSECONDS.toMillis(timeNs); + + long hours = TimeUnit.NANOSECONDS.toHours(timeNs); + long hoursInNanos = TimeUnit.HOURS.toNanos(hours); + timeNs = timeNs - hoursInNanos; + long mins = TimeUnit.NANOSECONDS.toMinutes(timeNs); + long minsInNanos = TimeUnit.MINUTES.toNanos(mins); + timeNs = timeNs - minsInNanos; + long secs = TimeUnit.NANOSECONDS.toSeconds(timeNs); + long secsInNanos = TimeUnit.SECONDS.toNanos(secs); + long ns = timeNs - secsInNanos; + return String.format("%1$02d:%2$02d:%3$02d.%4$09d", hours, mins, secs, ns) + + " or " + justms + " ms"; + } + public int run(List situations) throws InterruptedException { numSplitAndRetry.set(0); numRetry.set(0); @@ -447,7 +543,7 @@ public int run(List situations) throws InterruptedException { } int numSits = 0; long totalSitTime = 0; - for (Situation sit: situations) { + for (Situation sit : situations) { numSits++; long start = System.nanoTime(); for (TaskRunnerThread t : threads) { @@ -497,22 +593,6 @@ public int run(List situations) throws InterruptedException { } } - private static String asTimeStr(long timeNs) { - long justms = TimeUnit.NANOSECONDS.toMillis(timeNs); - - long hours = TimeUnit.NANOSECONDS.toHours(timeNs); - long hoursInNanos = TimeUnit.HOURS.toNanos(hours); - timeNs = timeNs - hoursInNanos; - long mins = TimeUnit.NANOSECONDS.toMinutes(timeNs); - long minsInNanos = TimeUnit.MINUTES.toNanos(mins); - timeNs = timeNs - minsInNanos; - long secs = TimeUnit.NANOSECONDS.toSeconds(timeNs); - long secsInNanos = TimeUnit.SECONDS.toNanos(secs); - long ns = timeNs - secsInNanos; - return String.format("%1$02d:%2$02d:%3$02d.%4$09d", hours, mins, secs, ns) + - " or " + justms + " ms"; - } - public void finish() { for (TaskRunnerThread t : threads) { t.finish(); @@ -535,49 +615,8 @@ public synchronized void setSitFailed() { } } - interface MemoryOp { - default void doIt(DeviceMemoryBuffer[] buffers, long taskId) { - long threadId = RmmSpark.getCurrentThreadId(); - RmmSpark.shuffleThreadWorkingOnTasks(new long[]{taskId}); - RmmSpark.startRetryBlock(threadId); - try { - int tries = 0; - while (tries < 100 && tries >= 0) { - try { - if (tries > 0) { - RmmSpark.blockThreadUntilReady(); - } - tries++; - doIt(buffers); - tries = -1; - } catch (GpuRetryOOM oom) { - // Don't need to clear the buffers, because there is only one buffer. - numRetry.incrementAndGet(); - } catch (CpuRetryOOM oom) { - // Don't need to clear the buffers, because there is only one buffer. - numRetry.incrementAndGet(); - } - } - if (tries >= 100) { - throw new OutOfMemoryError("Could not make shuffle work after " + tries + " tries"); - } - } finally { - RmmSpark.endRetryBlock(threadId); - RmmSpark.poolThreadFinishedForTask(taskId); - } - } - - void doIt(DeviceMemoryBuffer[] buffers); - - MemoryOp[] split(); - - MemoryOp randomMod(Random r, double templateChangeAmount); - - MemoryOp makeSkewed(double skewAmount); - } - public static class AllocOp implements MemoryOp { - private static AtomicLong idgen = new AtomicLong(0); + private static final AtomicLong idgen = new AtomicLong(0); public final int offset; private final long size; @@ -598,6 +637,7 @@ private AllocOp(int offset, long size, long sleepTime, long id) { this.sleepTime = sleepTime; this.id = id; } + @Override public String toString() { return "ALLOC[" + offset + "] " + size + " SLEEP " + sleepTime; @@ -636,7 +676,7 @@ public MemoryOp[] split() { @Override public MemoryOp randomMod(Random r, double templateChangeAmount) { double proposedSizeMult = (1.0 - (r.nextDouble() * 2.0)) * templateChangeAmount; - long newSize = (long)(proposedSizeMult * size); + long newSize = (long) (proposedSizeMult * size); if (newSize <= 0) { newSize = 1; } @@ -645,7 +685,7 @@ public MemoryOp randomMod(Random r, double templateChangeAmount) { @Override public MemoryOp makeSkewed(double skewAmount) { - return new AllocOp(offset, (long)(size * skewAmount), sleepTime); + return new AllocOp(offset, (long) (size * skewAmount), sleepTime); } } @@ -662,7 +702,7 @@ public String toString() { } @Override - public void doIt(DeviceMemoryBuffer[] buffers) { + public void doIt(DeviceMemoryBuffer[] buffers) { DeviceMemoryBuffer buf = buffers[offset]; if (buf != null) { buf.close(); @@ -692,9 +732,9 @@ public MemoryOp makeSkewed(double skewAmount) { } public static class TaskOpSet { + final int numBuffers; DeviceMemoryBuffer[] buffers; ArrayList operations; - final int numBuffers; long allocatedBeforeError = 0; private TaskOpSet(int numBuffers, ArrayList operations) { @@ -703,7 +743,7 @@ private TaskOpSet(int numBuffers, ArrayList operations) { } public TaskOpSet(Random r, long taskMaxMiB, - int maxTaskAllocs, int maxTaskSleep) { + int maxTaskAllocs, int maxTaskSleep) { long maxSleepTimeNano = TimeUnit.MILLISECONDS.toNanos(maxTaskSleep); long totalSleepTimeNano = 0; if (maxSleepTimeNano > 0) { @@ -733,7 +773,8 @@ public TaskOpSet(Random r, long taskMaxMiB, // We want the sleeps to be very small because we are not simulating // the time, and generally they will be. In the future we can make this // configurable. - long sleepTime = (long)(totalSleepTimeNano * sleepWeights[allocOpNum]/totalSleepWeight); + long sleepTime = + (long) (totalSleepTimeNano * sleepWeights[allocOpNum] / totalSleepWeight); AllocOp ao = new AllocOp(allocOpNum, size, sleepTime); operations.add(ao); outstandingAllocOps.add(ao); @@ -777,9 +818,9 @@ public void run(ExecutorService shuffle, long taskId) { allocatedBeforeError = 0; boolean isForShuffle = shuffle != null; boolean done = false; - while(!done) { + while (!done) { try { - for (MemoryOp op: operations) { + for (MemoryOp op : operations) { if (isForShuffle) { try { RmmSpark.submittingToPool(); @@ -827,7 +868,7 @@ public long getAllocatedBeforeError() { public TaskOpSet randomMod(Random r, double templateChangeAmount) { ArrayList newOps = new ArrayList<>(operations.size()); - for (MemoryOp op: operations) { + for (MemoryOp op : operations) { newOps.add(op.randomMod(r, templateChangeAmount)); } return new TaskOpSet(numBuffers, newOps); @@ -835,7 +876,7 @@ public TaskOpSet randomMod(Random r, double templateChangeAmount) { public TaskOpSet makeSkewed(double skewAmount) { ArrayList newOps = new ArrayList<>(operations.size()); - for (MemoryOp op: operations) { + for (MemoryOp op : operations) { newOps.add(op.makeSkewed(skewAmount)); } return new TaskOpSet(numBuffers, newOps); @@ -904,7 +945,7 @@ public String toString() { public Task randomMod(Random r, double templateChangeAmount) { LinkedList changed = new LinkedList<>(); - for (TaskOpSet orig: toDo) { + for (TaskOpSet orig : toDo) { changed.add(orig.randomMod(r, templateChangeAmount)); } return new Task(changed, 0); @@ -912,7 +953,7 @@ public Task randomMod(Random r, double templateChangeAmount) { public Task makeSkewed(double skewAmount) { LinkedList changed = new LinkedList<>(); - for (TaskOpSet orig: toDo) { + for (TaskOpSet orig : toDo) { changed.add(orig.makeSkewed(skewAmount)); } return new Task(changed, 0); @@ -923,8 +964,9 @@ public static class Situation { LinkedList tasks = new LinkedList<>(); public Situation(Random r, long numTasks, long taskMaxMiB, - int maxTaskAllocs, int maxTaskSleep, - boolean isSkewed, double skewAmount, boolean useTemplate, double templateChangeAmount) { + int maxTaskAllocs, int maxTaskSleep, + boolean isSkewed, double skewAmount, boolean useTemplate, + double templateChangeAmount) { if (useTemplate) { Task template = new Task(r, taskMaxMiB, maxTaskAllocs, maxTaskSleep); tasks.add(template); @@ -957,23 +999,4 @@ public String toString() { return "Sit: " + tasks.size(); } } - - private static List generateSituations(long seed, int numIterations, long numTasks, - long taskMaxMiB, int maxTaskAllocs, int maxTaskSleep, - boolean isSkewed, double skewAmount, boolean useTemplate, double templateChangeAmount) { - ArrayList ret = new ArrayList<>(numIterations); - long start = System.nanoTime(); - System.out.println("Generating " + numIterations + " test situations..."); - - Random r = new Random(seed); - for (int i = 0; i < numIterations; i++) { - ret.add(new Situation(r, numTasks, taskMaxMiB, maxTaskAllocs, maxTaskSleep, - isSkewed, skewAmount, useTemplate, templateChangeAmount)); - } - - long end = System.nanoTime(); - long diff = TimeUnit.MILLISECONDS.convert(end - start, TimeUnit.NANOSECONDS); - System.out.println("Took " + diff + " milliseconds to generate " + numIterations); - return ret; - } } diff --git a/src/test/java/com/nvidia/spark/rapids/jni/RmmSparkTest.java b/src/test/java/com/nvidia/spark/rapids/jni/RmmSparkTest.java index 270a4266cd..abc2e520dd 100644 --- a/src/test/java/com/nvidia/spark/rapids/jni/RmmSparkTest.java +++ b/src/test/java/com/nvidia/spark/rapids/jni/RmmSparkTest.java @@ -16,6 +16,10 @@ package com.nvidia.spark.rapids.jni; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.fail; + import ai.rapids.cudf.CudfException; import ai.rapids.cudf.DeviceMemoryBuffer; import ai.rapids.cudf.HostMemoryBuffer; @@ -27,20 +31,14 @@ import ai.rapids.cudf.RmmEventHandler; import ai.rapids.cudf.RmmLimitingResourceAdaptor; import ai.rapids.cudf.RmmTrackingResourceAdaptor; - -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; - import java.util.concurrent.ExecutionException; import java.util.concurrent.Future; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.junit.jupiter.api.Assertions.fail; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; public class RmmSparkTest { private final static long ALIGNMENT = 256; @@ -59,249 +57,6 @@ public void teardown() { } } - public interface TaskThreadOp { - T doIt(); - } - - public static class TaskThread extends Thread { - private final String name; - private final boolean isForPool; - private long threadId = -1; - private long taskId = 100; - - public TaskThread(String name, long taskId) { - this(name, false); - this.taskId = taskId; - } - - public TaskThread(String name, boolean isForPool) { - super(name); - this.name = name; - this.isForPool = isForPool; - } - - public synchronized long getThreadId() { - return threadId; - } - - private LinkedBlockingQueue queue = new LinkedBlockingQueue<>(); - - public void initialize() throws ExecutionException, InterruptedException, TimeoutException { - setDaemon(true); - start(); - Future waitForStart = doIt(new TaskThreadOp() { - @Override - public Void doIt() { - if (!isForPool) { - RmmSpark.currentThreadIsDedicatedToTask(taskId); - } - return null; - } - - @Override - public String toString() { - return "INIT TASK " + name + " " + (isForPool ? "POOL" : ("TASK " + taskId)); - } - }); - System.err.println("WAITING FOR STARTUP (" + name + ")"); - waitForStart.get(1000, TimeUnit.MILLISECONDS); - System.err.println("THREAD IS READY TO GO (" + name + ")"); - } - - public void pollForState(RmmSparkThreadState state, long l, TimeUnit tu) throws TimeoutException, InterruptedException { - long start = System.nanoTime(); - long timeoutAfter = start + tu.toNanos(l); - RmmSparkThreadState currentState = null; - while (System.nanoTime() <= timeoutAfter) { - currentState = RmmSpark.getStateOf(threadId); - if (currentState == state) { - return; - } - // Yes we are essentially doing a busy wait... - Thread.sleep(10); - } - throw new TimeoutException(name + " WAITING FOR STATE " + state + " BUT STATE IS " + currentState); - } - - private static class TaskThreadDoneOp implements TaskThreadOp, Future { - private TaskThread wrapped; - - TaskThreadDoneOp(TaskThread td) { - wrapped = td; - } - - @Override - public String toString() { - return "TASK DONE"; - } - - @Override - public Void doIt() { - return null; - } - - @Override - public boolean cancel(boolean b) { - return false; - } - - @Override - public boolean isCancelled() { - return false; - } - - @Override - public boolean isDone() { - return !wrapped.isAlive(); - } - - @Override - public Object get() throws InterruptedException, ExecutionException { - throw new RuntimeException("FUTURE NEEDS A TIMEOUT. THIS IS A TEST!"); - } - - @Override - public Object get(long l, TimeUnit timeUnit) throws InterruptedException, ExecutionException, TimeoutException { - System.err.println("WAITING FOR THREAD DONE " + l + " " + timeUnit); - wrapped.join(timeUnit.toMillis(l)); - return null; - } - } - - public Future done() { - TaskThreadDoneOp op = new TaskThreadDoneOp(this); - queue.offer(op); - return op; - } - - private static class TaskThreadTrackingOp implements TaskThreadOp, Future { - private final TaskThreadOp wrapped; - private boolean done = false; - private Throwable t = null; - private T ret = null; - - - @Override - public String toString() { - return wrapped.toString(); - } - - TaskThreadTrackingOp(TaskThreadOp td) { - wrapped = td; - } - - @Override - public T doIt() { - try { - T tmp = wrapped.doIt(); - synchronized (this) { - ret = tmp; - return ret; - } - } catch (Throwable t) { - synchronized (this) { - this.t = t; - } - return null; - } finally { - synchronized (this) { - done = true; - this.notifyAll(); - } - } - } - - @Override - public boolean cancel(boolean b) { - return false; - } - - @Override - public boolean isCancelled() { - return false; - } - - @Override - public synchronized boolean isDone() { - return done; - } - - @Override - public synchronized T get() throws InterruptedException, ExecutionException { - throw new RuntimeException("This is a test you should always have timeouts..."); - } - - @Override - public synchronized T get(long l, TimeUnit timeUnit) throws InterruptedException, ExecutionException, TimeoutException { - if (!done) { - System.err.println("WAITING " + l + " " + timeUnit + " FOR '" + wrapped + "'"); - wait(timeUnit.toMillis(l)); - if (!done) { - throw new TimeoutException(); - } - } - if (t != null) { - throw new ExecutionException(t); - } - return ret; - } - } - - public Future doIt(TaskThreadOp op) { - if (!isAlive()) { - throw new IllegalStateException("Thread is already done..."); - } - TaskThreadTrackingOp tracking = new TaskThreadTrackingOp<>(op); - queue.offer(tracking); - return tracking; - } - - public Future blockUntilReady() { - return doIt(new TaskThreadOp() { - @Override - public Void doIt() { - RmmSpark.blockThreadUntilReady(); - return null; - } - - @Override - public String toString() { - return "BLOCK UNTIL THREAD IS READY"; - } - }); - } - - @Override - public void run() { - try { - synchronized (this) { - threadId = RmmSpark.getCurrentThreadId(); - } - System.err.println("INSIDE THREAD RUNNING (" + name + ")"); - while (true) { - // Because of how our deadlock detection code works we don't want to - // block this thread, so we do this in a busy loop. It is not ideal, - // but works, and is more accurate to what the Spark is likely to do - TaskThreadOp op = queue.poll(); - // null is returned from the queue if it is empty - if (op != null) { - System.err.println("GOT '" + op + "' ON " + name); - if (op instanceof TaskThreadDoneOp) { - return; - } - op.doIt(); - System.err.println("'" + op + "' FINISHED ON " + name); - } - } - } catch (Throwable t) { - System.err.println("THROWABLE CAUGHT IN " + name); - t.printStackTrace(System.err); - } finally { - System.err.println("THREAD EXITING " + name); - } - } - } - @Test public void testBasicInitAndTeardown() { Rmm.initialize(RmmAllocationMode.CUDA_DEFAULT, null, 512 * 1024 * 1024); @@ -339,7 +94,7 @@ public void testInsertOOMsGpu() { // No change in the state after a force assertEquals(RmmSparkThreadState.THREAD_RUNNING, RmmSpark.getStateOf(threadId)); assertThrows(GpuRetryOOM.class, () -> Rmm.alloc(100).close()); - assert(RmmSpark.getAndResetComputeTimeLostToRetryNs(taskid) > 0); + assert (RmmSpark.getAndResetComputeTimeLostToRetryNs(taskid) > 0); // Verify that injecting OOM does not cause the block to actually happen or // the state to change @@ -406,7 +161,7 @@ public void testInsertOOMsCpu() { // No change in the state after a force assertEquals(RmmSparkThreadState.THREAD_RUNNING, RmmSpark.getStateOf(threadId)); assertThrows(CpuRetryOOM.class, () -> LimitingOffHeapAllocForTests.alloc(100).close()); - assert(RmmSpark.getAndResetComputeTimeLostToRetryNs(taskid) > 0); + assert (RmmSpark.getAndResetComputeTimeLostToRetryNs(taskid) > 0); // Verify that injecting OOM does not cause the block to actually happen or // the state to change @@ -423,7 +178,8 @@ public void testInsertOOMsCpu() { RmmSpark.forceSplitAndRetryOOM(threadId); // No change in state after force assertEquals(RmmSparkThreadState.THREAD_RUNNING, RmmSpark.getStateOf(threadId)); - assertThrows(CpuSplitAndRetryOOM.class, () -> LimitingOffHeapAllocForTests.alloc(100).close()); + assertThrows(CpuSplitAndRetryOOM.class, + () -> LimitingOffHeapAllocForTests.alloc(100).close()); assertEquals(0, RmmSpark.getAndResetNumRetryThrow(taskid)); assertEquals(1, RmmSpark.getAndResetNumSplitRetryThrow(taskid)); @@ -475,12 +231,14 @@ public void testAssociateThread() { Thread t = Thread.currentThread(); try { RmmSpark.startDedicatedTaskThread(threadIdOne, taskId, t); - assertThrows(CudfException.class, () -> RmmSpark.shuffleThreadWorkingTasks(threadIdOne, t, taskIds)); + assertThrows(CudfException.class, + () -> RmmSpark.shuffleThreadWorkingTasks(threadIdOne, t, taskIds)); // There can be races when a thread goes from one task to another, so we just make it safe to do. RmmSpark.startDedicatedTaskThread(threadIdOne, otherTaskId, t); RmmSpark.shuffleThreadWorkingTasks(threadIdTwo, t, taskIds); - assertThrows(CudfException.class, () -> RmmSpark.startDedicatedTaskThread(threadIdTwo, otherTaskId, t)); + assertThrows(CudfException.class, + () -> RmmSpark.startDedicatedTaskThread(threadIdTwo, otherTaskId, t)); // Remove the association RmmSpark.removeDedicatedThreadAssociation(threadIdTwo, taskId); RmmSpark.removeDedicatedThreadAssociation(threadIdTwo, otherTaskId); @@ -492,146 +250,9 @@ public void testAssociateThread() { } } - - static abstract class AllocOnAnotherThread implements AutoCloseable { - final TaskThread thread; - final long size; - final long taskId; - MemoryBuffer b = null; - Future fb; - Future fc = null; - - public AllocOnAnotherThread(TaskThread thread, long size) { - this.thread = thread; - this.size = size; - this.taskId = -1; - fb = thread.doIt(new TaskThreadOp() { - @Override - public Void doIt() { - doAlloc(); - return null; - } - - @Override - public String toString() { - return "ALLOC(" + size + ")"; - } - }); - } - - public AllocOnAnotherThread(TaskThread thread, long size, long taskId) { - this.thread = thread; - this.size = size; - this.taskId = taskId; - fb = thread.doIt(new TaskThreadOp() { - @Override - public Void doIt() { - RmmSpark.shuffleThreadWorkingOnTasks(new long[]{taskId}); - doAlloc(); - return null; - } - - @Override - public String toString() { - return "ALLOC(" + size + ")"; - } - }); - } - - public void waitForAlloc() throws ExecutionException, InterruptedException, TimeoutException { - fb.get(1000, TimeUnit.MILLISECONDS); - } - - public void freeOnThread() { - if (fc != null) { - throw new IllegalStateException("free called multiple times"); - } - - fc = thread.doIt(new TaskThreadOp() { - @Override - public Void doIt() { - close(); - return null; - } - - @Override - public String toString() { - return "FREE(" + size + ")"; - } - }); - } - - public void waitForFree() throws ExecutionException, InterruptedException, TimeoutException { - if (fc == null) { - freeOnThread(); - } - fc.get(1000, TimeUnit.MILLISECONDS); - } - - public void freeAndWait() throws ExecutionException, InterruptedException, TimeoutException { - waitForFree(); - } - - abstract protected Void doAlloc(); - - @Override - public synchronized void close() { - if (b != null) { - try { - b.close(); - b = null; - } finally { - if (this.taskId > 0) { - RmmSpark.poolThreadFinishedForTasks(thread.threadId, new long[]{taskId}); - } - } - } - } - } - - public static class GpuAllocOnAnotherThread extends AllocOnAnotherThread { - - public GpuAllocOnAnotherThread(TaskThread thread, long size) { - super(thread, size); - } - - public GpuAllocOnAnotherThread(TaskThread thread, long size, long taskId) { - super(thread, size, taskId); - } - - @Override - protected Void doAlloc() { - DeviceMemoryBuffer tmp = Rmm.alloc(size); - synchronized (this) { - b = tmp; - } - return null; - } - } - - public static class CpuAllocOnAnotherThread extends AllocOnAnotherThread { - - public CpuAllocOnAnotherThread(TaskThread thread, long size) { - super(thread, size); - } - - public CpuAllocOnAnotherThread(TaskThread thread, long size, long taskId) { - super(thread, size, taskId); - } - - @Override - protected Void doAlloc() { - HostMemoryBuffer tmp = LimitingOffHeapAllocForTests.alloc(size); - synchronized (this) { - b = tmp; - } - return null; - } - } - - void setupRmmForTestingWithLimits(long maxAllocSize) { - setupRmmForTestingWithLimits(maxAllocSize, new BaseRmmEventHandler()); - } + void setupRmmForTestingWithLimits(long maxAllocSize) { + setupRmmForTestingWithLimits(maxAllocSize, new BaseRmmEventHandler()); + } void setupRmmForTestingWithLimits(long maxAllocSize, RmmEventHandler eventHandler) { // Rmm.initialize is not going to limit allocations without a pool, so we @@ -698,7 +319,8 @@ public void testNonBlockingCpuAllocFailedOOM() { } @Test - public void testBasicBlocking() throws ExecutionException, InterruptedException, TimeoutException { + public void testBasicBlocking() + throws ExecutionException, InterruptedException, TimeoutException { // 10 MiB setupRmmForTestingWithLimits(10 * 1024 * 1024); TaskThread taskOne = new TaskThread("TEST THREAD ONE", 1); @@ -715,7 +337,8 @@ public void testBasicBlocking() throws ExecutionException, InterruptedException, try (AllocOnAnotherThread firstOne = new GpuAllocOnAnotherThread(taskOne, 5 * 1024 * 1024)) { firstOne.waitForAlloc(); // This one should block - try (AllocOnAnotherThread secondOne = new GpuAllocOnAnotherThread(taskTwo, 6 * 1024 * 1024)) { + try (AllocOnAnotherThread secondOne = new GpuAllocOnAnotherThread(taskTwo, + 6 * 1024 * 1024)) { taskTwo.pollForState(RmmSparkThreadState.THREAD_BLOCKED, 1000, TimeUnit.MILLISECONDS); // Free the first allocation to wake up the second task... firstOne.freeAndWait(); @@ -730,7 +353,8 @@ public void testBasicBlocking() throws ExecutionException, InterruptedException, } @Test - public void testBasicCpuBlocking() throws ExecutionException, InterruptedException, TimeoutException { + public void testBasicCpuBlocking() + throws ExecutionException, InterruptedException, TimeoutException { // 10 MiB setupRmmForTestingWithLimits(10 * 1024 * 1024); LimitingOffHeapAllocForTests.setLimit(10 * 1024 * 1024); @@ -748,7 +372,8 @@ public void testBasicCpuBlocking() throws ExecutionException, InterruptedExcepti try (AllocOnAnotherThread firstOne = new CpuAllocOnAnotherThread(taskOne, 5 * 1024 * 1024)) { firstOne.waitForAlloc(); // This one should block - try (AllocOnAnotherThread secondOne = new CpuAllocOnAnotherThread(taskTwo, 6 * 1024 * 1024)) { + try (AllocOnAnotherThread secondOne = new CpuAllocOnAnotherThread(taskTwo, + 6 * 1024 * 1024)) { taskTwo.pollForState(RmmSparkThreadState.THREAD_BLOCKED, 1000, TimeUnit.MILLISECONDS); // Free the first allocation to wake up the second task... firstOne.freeAndWait(); @@ -764,7 +389,8 @@ public void testBasicCpuBlocking() throws ExecutionException, InterruptedExcepti } @Test - public void testBasicMixedBlocking() throws ExecutionException, InterruptedException, TimeoutException { + public void testBasicMixedBlocking() + throws ExecutionException, InterruptedException, TimeoutException { final long MB = 1024 * 1024; setupRmmForTestingWithLimits(10 * MB); LimitingOffHeapAllocForTests.setLimit(10 * MB); @@ -799,12 +425,15 @@ public void testBasicMixedBlocking() throws ExecutionException, InterruptedExcep firstCpuAlloc.waitForAlloc(); // Blocking GPU Alloc - try (AllocOnAnotherThread secondGpuAlloc = new GpuAllocOnAnotherThread(taskThree, SIX_MB)) { + try (AllocOnAnotherThread secondGpuAlloc = new GpuAllocOnAnotherThread(taskThree, + SIX_MB)) { taskThree.pollForState(RmmSparkThreadState.THREAD_BLOCKED, 1000, TimeUnit.MILLISECONDS); // Blocking CPU Alloc - try (AllocOnAnotherThread secondCpuAlloc = new CpuAllocOnAnotherThread(taskFour, SIX_MB)) { - taskFour.pollForState(RmmSparkThreadState.THREAD_BLOCKED, 1000, TimeUnit.MILLISECONDS); + try (AllocOnAnotherThread secondCpuAlloc = new CpuAllocOnAnotherThread(taskFour, + SIX_MB)) { + taskFour.pollForState(RmmSparkThreadState.THREAD_BLOCKED, 1000, + TimeUnit.MILLISECONDS); // We want to make sure that the order of wakeup corresponds to the location of the data that was released // Not necessarily the priority of the task/thread. @@ -819,7 +448,8 @@ public void testBasicMixedBlocking() throws ExecutionException, InterruptedExcep secondGpuAlloc.freeAndWait(); } // Do one more alloc after freeing on same task to show the max allocation metric is unimpacted - try (AllocOnAnotherThread secondGpuAlloc = new GpuAllocOnAnotherThread(taskThree, FIVE_MB)) { + try (AllocOnAnotherThread secondGpuAlloc = new GpuAllocOnAnotherThread(taskThree, + FIVE_MB)) { secondGpuAlloc.waitForAlloc(); secondGpuAlloc.freeAndWait(); } @@ -838,7 +468,8 @@ public void testBasicMixedBlocking() throws ExecutionException, InterruptedExcep } @Test - public void testShuffleBlocking() throws ExecutionException, InterruptedException, TimeoutException { + public void testShuffleBlocking() + throws ExecutionException, InterruptedException, TimeoutException { // 10 MiB setupRmmForTestingWithLimits(10 * 1024 * 1024); TaskThread shuffleOne = new TaskThread("TEST THREAD SHUFFLE", true); @@ -861,11 +492,14 @@ public void testShuffleBlocking() throws ExecutionException, InterruptedExceptio try (AllocOnAnotherThread firstOne = new GpuAllocOnAnotherThread(taskOne, 5 * 1024 * 1024)) { firstOne.waitForAlloc(); // This one should block - try (AllocOnAnotherThread secondOne = new GpuAllocOnAnotherThread(taskTwo, 6 * 1024 * 1024)) { + try (AllocOnAnotherThread secondOne = new GpuAllocOnAnotherThread(taskTwo, + 6 * 1024 * 1024)) { taskTwo.pollForState(RmmSparkThreadState.THREAD_BLOCKED, 1000, TimeUnit.MILLISECONDS); // Make sure that shuffle has higher priority than tasks... - try (AllocOnAnotherThread thirdOne = new GpuAllocOnAnotherThread(shuffleOne, 6 * 1024 * 1024, 2)) { - shuffleOne.pollForState(RmmSparkThreadState.THREAD_BLOCKED, 1000, TimeUnit.MILLISECONDS); + try (AllocOnAnotherThread thirdOne = new GpuAllocOnAnotherThread(shuffleOne, + 6 * 1024 * 1024, 2)) { + shuffleOne.pollForState(RmmSparkThreadState.THREAD_BLOCKED, 1000, + TimeUnit.MILLISECONDS); // But taskOne is not blocked, so there will be no retry until it is blocked, or else // it is making progress taskOne.doIt((TaskThreadOp) () -> { @@ -900,9 +534,9 @@ public void testShuffleBlocking() throws ExecutionException, InterruptedExceptio } } - @Test - public void testShuffleBlockingCpu() throws ExecutionException, InterruptedException, TimeoutException { + public void testShuffleBlockingCpu() + throws ExecutionException, InterruptedException, TimeoutException { // 10 MiB setupRmmForTestingWithLimits(10 * 1024 * 1024); LimitingOffHeapAllocForTests.setLimit(10 * 1024 * 1024); @@ -926,11 +560,14 @@ public void testShuffleBlockingCpu() throws ExecutionException, InterruptedExcep try (AllocOnAnotherThread firstOne = new CpuAllocOnAnotherThread(taskOne, 5 * 1024 * 1024)) { firstOne.waitForAlloc(); // This one should block - try (AllocOnAnotherThread secondOne = new CpuAllocOnAnotherThread(taskTwo, 6 * 1024 * 1024)) { + try (AllocOnAnotherThread secondOne = new CpuAllocOnAnotherThread(taskTwo, + 6 * 1024 * 1024)) { taskTwo.pollForState(RmmSparkThreadState.THREAD_BLOCKED, 1000, TimeUnit.MILLISECONDS); // Make sure that shuffle has higher priority than tasks... - try (AllocOnAnotherThread thirdOne = new CpuAllocOnAnotherThread(shuffleOne, 6 * 1024 * 1024, 2)) { - shuffleOne.pollForState(RmmSparkThreadState.THREAD_BLOCKED, 1000, TimeUnit.MILLISECONDS); + try (AllocOnAnotherThread thirdOne = new CpuAllocOnAnotherThread(shuffleOne, + 6 * 1024 * 1024, 2)) { + shuffleOne.pollForState(RmmSparkThreadState.THREAD_BLOCKED, 1000, + TimeUnit.MILLISECONDS); // But taskOne is not blocked, so there will be no retry until it is blocked, or else // it is making progress taskOne.doIt((TaskThreadOp) () -> { @@ -982,24 +619,29 @@ public void testBasicBUFN() throws ExecutionException, InterruptedException, Tim long tTwoId = taskTwo.getThreadId(); assertEquals(RmmSparkThreadState.THREAD_RUNNING, RmmSpark.getStateOf(tTwoId)); - try (AllocOnAnotherThread allocThreeOne = new GpuAllocOnAnotherThread(taskThree, 5 * 1024 * 1024)) { + try (AllocOnAnotherThread allocThreeOne = new GpuAllocOnAnotherThread(taskThree, + 5 * 1024 * 1024)) { allocThreeOne.waitForAlloc(); - try (AllocOnAnotherThread allocTwoOne = new GpuAllocOnAnotherThread(taskTwo, 3 * 1024 * 1024)) { + try (AllocOnAnotherThread allocTwoOne = new GpuAllocOnAnotherThread(taskTwo, + 3 * 1024 * 1024)) { allocTwoOne.waitForAlloc(); - try (AllocOnAnotherThread allocTwoTwo = new GpuAllocOnAnotherThread(taskTwo, 3 * 1024 * 1024)) { + try (AllocOnAnotherThread allocTwoTwo = new GpuAllocOnAnotherThread(taskTwo, + 3 * 1024 * 1024)) { taskTwo.pollForState(RmmSparkThreadState.THREAD_BLOCKED, 1000, TimeUnit.MILLISECONDS); - try (AllocOnAnotherThread allocThreeTwo = new GpuAllocOnAnotherThread(taskThree, 4 * 1024 * 1024)) { + try (AllocOnAnotherThread allocThreeTwo = new GpuAllocOnAnotherThread(taskThree, + 4 * 1024 * 1024)) { // This one should be able to allocate because there is not enough memory, but // now all the threads would be blocked, so the lowest priority thread is going to // become BUFN - taskThree.pollForState(RmmSparkThreadState.THREAD_BUFN_WAIT, 1000, TimeUnit.MILLISECONDS); + taskThree.pollForState(RmmSparkThreadState.THREAD_BUFN_WAIT, 1000, + TimeUnit.MILLISECONDS); try { allocThreeTwo.waitForAlloc(); fail("ALLOC AFTER BUFN SHOULD HAVE THROWN..."); } catch (ExecutionException ee) { - assert(ee.getCause() instanceof GpuRetryOOM); + assert (ee.getCause() instanceof GpuRetryOOM); } // allocOneTwo cannot be freed, nothing was allocated because it threw an exception. allocThreeOne.freeAndWait(); @@ -1016,7 +658,8 @@ public void testBasicBUFN() throws ExecutionException, InterruptedException, Tim taskTwo.done().get(1000, TimeUnit.MILLISECONDS); // Now that task two is done see if task one is running again... - taskThree.pollForState(RmmSparkThreadState.THREAD_RUNNING, 1000, TimeUnit.MILLISECONDS); + taskThree.pollForState(RmmSparkThreadState.THREAD_RUNNING, 1000, + TimeUnit.MILLISECONDS); // Now we could finish trying our allocations, but this is good enough... } } @@ -1046,24 +689,29 @@ public void testBasicBUFNCpu() throws ExecutionException, InterruptedException, long tTwoId = taskTwo.getThreadId(); assertEquals(RmmSparkThreadState.THREAD_RUNNING, RmmSpark.getStateOf(tTwoId)); - try (AllocOnAnotherThread allocThreeOne = new CpuAllocOnAnotherThread(taskThree, 5 * 1024 * 1024)) { + try (AllocOnAnotherThread allocThreeOne = new CpuAllocOnAnotherThread(taskThree, + 5 * 1024 * 1024)) { allocThreeOne.waitForAlloc(); - try (AllocOnAnotherThread allocTwoOne = new CpuAllocOnAnotherThread(taskTwo, 3 * 1024 * 1024)) { + try (AllocOnAnotherThread allocTwoOne = new CpuAllocOnAnotherThread(taskTwo, + 3 * 1024 * 1024)) { allocTwoOne.waitForAlloc(); - try (AllocOnAnotherThread allocTwoTwo = new CpuAllocOnAnotherThread(taskTwo, 3 * 1024 * 1024)) { + try (AllocOnAnotherThread allocTwoTwo = new CpuAllocOnAnotherThread(taskTwo, + 3 * 1024 * 1024)) { taskTwo.pollForState(RmmSparkThreadState.THREAD_BLOCKED, 1000, TimeUnit.MILLISECONDS); - try (AllocOnAnotherThread allocThreeTwo = new CpuAllocOnAnotherThread(taskThree, 4 * 1024 * 1024)) { + try (AllocOnAnotherThread allocThreeTwo = new CpuAllocOnAnotherThread(taskThree, + 4 * 1024 * 1024)) { // This one should be able to allocate because there is not enough memory, but // now all the threads would be blocked, so the lowest priority thread is going to // become BUFN - taskThree.pollForState(RmmSparkThreadState.THREAD_BUFN_WAIT, 1000, TimeUnit.MILLISECONDS); + taskThree.pollForState(RmmSparkThreadState.THREAD_BUFN_WAIT, 1000, + TimeUnit.MILLISECONDS); try { allocThreeTwo.waitForAlloc(); fail("ALLOC AFTER BUFN SHOULD HAVE THROWN..."); } catch (ExecutionException ee) { - assert(ee.getCause() instanceof CpuRetryOOM); + assert (ee.getCause() instanceof CpuRetryOOM); } // allocOneTwo cannot be freed, nothing was allocated because it threw an exception. allocThreeOne.freeAndWait(); @@ -1080,7 +728,8 @@ public void testBasicBUFNCpu() throws ExecutionException, InterruptedException, taskTwo.done().get(1000, TimeUnit.MILLISECONDS); // Now that task two is done see if task one is running again... - taskThree.pollForState(RmmSparkThreadState.THREAD_RUNNING, 1000, TimeUnit.MILLISECONDS); + taskThree.pollForState(RmmSparkThreadState.THREAD_RUNNING, 1000, + TimeUnit.MILLISECONDS); // Now we could finish trying our allocations, but this is good enough... } } @@ -1093,7 +742,8 @@ public void testBasicBUFNCpu() throws ExecutionException, InterruptedException, } @Test - public void testBUFNSplitAndRetrySingleThread() throws ExecutionException, InterruptedException, TimeoutException { + public void testBUFNSplitAndRetrySingleThread() + throws ExecutionException, InterruptedException, TimeoutException { // We are doing ths one single threaded. // 10 MiB setupRmmForTestingWithLimits(10 * 1024 * 1024); @@ -1119,7 +769,8 @@ public void testBUFNSplitAndRetrySingleThread() throws ExecutionException, Inter } assertEquals(RmmSparkThreadState.THREAD_RUNNING, RmmSpark.getStateOf(threadId)); // Now we try to allocate with half the data. - try (AllocOnAnotherThread secondTry = new GpuAllocOnAnotherThread(taskOne, 3 * 1024 * 1024)) { + try (AllocOnAnotherThread secondTry = new GpuAllocOnAnotherThread(taskOne, + 3 * 1024 * 1024)) { secondTry.waitForAlloc(); } } @@ -1232,7 +883,7 @@ public void retryWatchdog() { long endTime = System.nanoTime(); System.err.println("Took " + (endTime - startTime) + "ns to retry 500 times..."); } - + // // These next two tests deal with a special case where allocations (and allocation failures) // could happen during spill handling. @@ -1256,7 +907,8 @@ public void testAllocationDuringSpill() { RmmSpark.startDedicatedTaskThread(threadId, taskId, t); assertThrows(GpuOOM.class, () -> { try (DeviceMemoryBuffer filler = Rmm.alloc(9 * 1024 * 1024)) { - try (DeviceMemoryBuffer shouldFail = Rmm.alloc(2 * 1024 * 1024)) {} + try (DeviceMemoryBuffer shouldFail = Rmm.alloc(2 * 1024 * 1024)) { + } fail("overallocation should have failed"); } finally { RmmSpark.removeDedicatedThreadAssociation(threadId, taskId); @@ -1268,7 +920,7 @@ public void testAllocationDuringSpill() { @Test public void testAllocationFailedDuringSpill() { // Create a handler that allocates 2MB from the handler (it should fail) - AllocatingRmmEventHandler rmmEventHandler = new AllocatingRmmEventHandler(2L*1024*1024); + AllocatingRmmEventHandler rmmEventHandler = new AllocatingRmmEventHandler(2L * 1024 * 1024); // 10 MiB setupRmmForTestingWithLimits(10 * 1024 * 1024, rmmEventHandler); long threadId = RmmSpark.getCurrentThreadId(); @@ -1277,7 +929,8 @@ public void testAllocationFailedDuringSpill() { RmmSpark.startDedicatedTaskThread(threadId, taskId, t); assertThrows(GpuOOM.class, () -> { try (DeviceMemoryBuffer filler = Rmm.alloc(9 * 1024 * 1024)) { - try (DeviceMemoryBuffer shouldFail = Rmm.alloc(2 * 1024 * 1024)) {} + try (DeviceMemoryBuffer shouldFail = Rmm.alloc(2 * 1024 * 1024)) { + } fail("overallocation should have failed"); } finally { RmmSpark.removeDedicatedThreadAssociation(threadId, taskId); @@ -1286,46 +939,428 @@ public void testAllocationFailedDuringSpill() { assertEquals(0, rmmEventHandler.getAllocationCount()); } - private static class BaseRmmEventHandler implements RmmEventHandler { - @Override - public long[] getAllocThresholds() { - return null; + public interface TaskThreadOp { + T doIt(); + } + + public static class TaskThread extends Thread { + private final String name; + private final boolean isForPool; + private long threadId = -1; + private long taskId = 100; + private final LinkedBlockingQueue queue = new LinkedBlockingQueue<>(); + + public TaskThread(String name, long taskId) { + this(name, false); + this.taskId = taskId; } - @Override - public long[] getDeallocThresholds() { - return null; + public TaskThread(String name, boolean isForPool) { + super(name); + this.name = name; + this.isForPool = isForPool; } - @Override - public void onAllocThreshold(long totalAllocSize) { + public synchronized long getThreadId() { + return threadId; } - @Override - public void onDeallocThreshold(long totalAllocSize) { + public void initialize() throws ExecutionException, InterruptedException, TimeoutException { + setDaemon(true); + start(); + Future waitForStart = doIt(new TaskThreadOp() { + @Override + public Void doIt() { + if (!isForPool) { + RmmSpark.currentThreadIsDedicatedToTask(taskId); + } + return null; + } + + @Override + public String toString() { + return "INIT TASK " + name + " " + (isForPool ? "POOL" : ("TASK " + taskId)); + } + }); + System.err.println("WAITING FOR STARTUP (" + name + ")"); + waitForStart.get(1000, TimeUnit.MILLISECONDS); + System.err.println("THREAD IS READY TO GO (" + name + ")"); } - @Override - public boolean onAllocFailure(long sizeRequested, int retryCount) { - // This is just a test for now, no spilling... - return false; + public void pollForState(RmmSparkThreadState state, long l, TimeUnit tu) + throws TimeoutException, InterruptedException { + long start = System.nanoTime(); + long timeoutAfter = start + tu.toNanos(l); + RmmSparkThreadState currentState = null; + while (System.nanoTime() <= timeoutAfter) { + currentState = RmmSpark.getStateOf(threadId); + if (currentState == state) { + return; + } + // Yes we are essentially doing a busy wait... + Thread.sleep(10); + } + throw new TimeoutException( + name + " WAITING FOR STATE " + state + " BUT STATE IS " + currentState); } - } - private static class AllocatingRmmEventHandler extends BaseRmmEventHandler { - // if true, we are still in the onAllocFailure callback (recursive call) - boolean stillHandlingAllocFailure = false; + public Future done() { + TaskThreadDoneOp op = new TaskThreadDoneOp(this); + queue.offer(op); + return op; + } - int allocationCount; + public Future doIt(TaskThreadOp op) { + if (!isAlive()) { + throw new IllegalStateException("Thread is already done..."); + } + TaskThreadTrackingOp tracking = new TaskThreadTrackingOp<>(op); + queue.offer(tracking); + return tracking; + } - long allocSize; + public Future blockUntilReady() { + return doIt(new TaskThreadOp() { + @Override + public Void doIt() { + RmmSpark.blockThreadUntilReady(); + return null; + } - public int getAllocationCount() { - return allocationCount; + @Override + public String toString() { + return "BLOCK UNTIL THREAD IS READY"; + } + }); } - public AllocatingRmmEventHandler(long allocSize) { - this.allocSize = allocSize; + @Override + public void run() { + try { + synchronized (this) { + threadId = RmmSpark.getCurrentThreadId(); + } + System.err.println("INSIDE THREAD RUNNING (" + name + ")"); + while (true) { + // Because of how our deadlock detection code works we don't want to + // block this thread, so we do this in a busy loop. It is not ideal, + // but works, and is more accurate to what the Spark is likely to do + TaskThreadOp op = queue.poll(); + // null is returned from the queue if it is empty + if (op != null) { + System.err.println("GOT '" + op + "' ON " + name); + if (op instanceof TaskThreadDoneOp) { + return; + } + op.doIt(); + System.err.println("'" + op + "' FINISHED ON " + name); + } + } + } catch (Throwable t) { + System.err.println("THROWABLE CAUGHT IN " + name); + t.printStackTrace(System.err); + } finally { + System.err.println("THREAD EXITING " + name); + } + } + + private static class TaskThreadDoneOp implements TaskThreadOp, Future { + private final TaskThread wrapped; + + TaskThreadDoneOp(TaskThread td) { + wrapped = td; + } + + @Override + public String toString() { + return "TASK DONE"; + } + + @Override + public Void doIt() { + return null; + } + + @Override + public boolean cancel(boolean b) { + return false; + } + + @Override + public boolean isCancelled() { + return false; + } + + @Override + public boolean isDone() { + return !wrapped.isAlive(); + } + + @Override + public Object get() throws InterruptedException, ExecutionException { + throw new RuntimeException("FUTURE NEEDS A TIMEOUT. THIS IS A TEST!"); + } + + @Override + public Object get(long l, TimeUnit timeUnit) + throws InterruptedException, ExecutionException, TimeoutException { + System.err.println("WAITING FOR THREAD DONE " + l + " " + timeUnit); + wrapped.join(timeUnit.toMillis(l)); + return null; + } + } + + private static class TaskThreadTrackingOp implements TaskThreadOp, Future { + private final TaskThreadOp wrapped; + private boolean done = false; + private Throwable t = null; + private T ret = null; + + + TaskThreadTrackingOp(TaskThreadOp td) { + wrapped = td; + } + + @Override + public String toString() { + return wrapped.toString(); + } + + @Override + public T doIt() { + try { + T tmp = wrapped.doIt(); + synchronized (this) { + ret = tmp; + return ret; + } + } catch (Throwable t) { + synchronized (this) { + this.t = t; + } + return null; + } finally { + synchronized (this) { + done = true; + this.notifyAll(); + } + } + } + + @Override + public boolean cancel(boolean b) { + return false; + } + + @Override + public boolean isCancelled() { + return false; + } + + @Override + public synchronized boolean isDone() { + return done; + } + + @Override + public synchronized T get() throws InterruptedException, ExecutionException { + throw new RuntimeException("This is a test you should always have timeouts..."); + } + + @Override + public synchronized T get(long l, TimeUnit timeUnit) + throws InterruptedException, ExecutionException, TimeoutException { + if (!done) { + System.err.println("WAITING " + l + " " + timeUnit + " FOR '" + wrapped + "'"); + wait(timeUnit.toMillis(l)); + if (!done) { + throw new TimeoutException(); + } + } + if (t != null) { + throw new ExecutionException(t); + } + return ret; + } + } + } + + static abstract class AllocOnAnotherThread implements AutoCloseable { + final TaskThread thread; + final long size; + final long taskId; + MemoryBuffer b = null; + Future fb; + Future fc = null; + + public AllocOnAnotherThread(TaskThread thread, long size) { + this.thread = thread; + this.size = size; + this.taskId = -1; + fb = thread.doIt(new TaskThreadOp() { + @Override + public Void doIt() { + doAlloc(); + return null; + } + + @Override + public String toString() { + return "ALLOC(" + size + ")"; + } + }); + } + + public AllocOnAnotherThread(TaskThread thread, long size, long taskId) { + this.thread = thread; + this.size = size; + this.taskId = taskId; + fb = thread.doIt(new TaskThreadOp() { + @Override + public Void doIt() { + RmmSpark.shuffleThreadWorkingOnTasks(new long[] {taskId}); + doAlloc(); + return null; + } + + @Override + public String toString() { + return "ALLOC(" + size + ")"; + } + }); + } + + public void waitForAlloc() throws ExecutionException, InterruptedException, TimeoutException { + fb.get(1000, TimeUnit.MILLISECONDS); + } + + public void freeOnThread() { + if (fc != null) { + throw new IllegalStateException("free called multiple times"); + } + + fc = thread.doIt(new TaskThreadOp() { + @Override + public Void doIt() { + close(); + return null; + } + + @Override + public String toString() { + return "FREE(" + size + ")"; + } + }); + } + + public void waitForFree() throws ExecutionException, InterruptedException, TimeoutException { + if (fc == null) { + freeOnThread(); + } + fc.get(1000, TimeUnit.MILLISECONDS); + } + + public void freeAndWait() throws ExecutionException, InterruptedException, TimeoutException { + waitForFree(); + } + + abstract protected Void doAlloc(); + + @Override + public synchronized void close() { + if (b != null) { + try { + b.close(); + b = null; + } finally { + if (this.taskId > 0) { + RmmSpark.poolThreadFinishedForTasks(thread.threadId, new long[] {taskId}); + } + } + } + } + } + + public static class GpuAllocOnAnotherThread extends AllocOnAnotherThread { + + public GpuAllocOnAnotherThread(TaskThread thread, long size) { + super(thread, size); + } + + public GpuAllocOnAnotherThread(TaskThread thread, long size, long taskId) { + super(thread, size, taskId); + } + + @Override + protected Void doAlloc() { + DeviceMemoryBuffer tmp = Rmm.alloc(size); + synchronized (this) { + b = tmp; + } + return null; + } + } + + public static class CpuAllocOnAnotherThread extends AllocOnAnotherThread { + + public CpuAllocOnAnotherThread(TaskThread thread, long size) { + super(thread, size); + } + + public CpuAllocOnAnotherThread(TaskThread thread, long size, long taskId) { + super(thread, size, taskId); + } + + @Override + protected Void doAlloc() { + HostMemoryBuffer tmp = LimitingOffHeapAllocForTests.alloc(size); + synchronized (this) { + b = tmp; + } + return null; + } + } + + private static class BaseRmmEventHandler implements RmmEventHandler { + @Override + public long[] getAllocThresholds() { + return null; + } + + @Override + public long[] getDeallocThresholds() { + return null; + } + + @Override + public void onAllocThreshold(long totalAllocSize) { + } + + @Override + public void onDeallocThreshold(long totalAllocSize) { + } + + @Override + public boolean onAllocFailure(long sizeRequested, int retryCount) { + // This is just a test for now, no spilling... + return false; + } + } + + private static class AllocatingRmmEventHandler extends BaseRmmEventHandler { + // if true, we are still in the onAllocFailure callback (recursive call) + boolean stillHandlingAllocFailure = false; + + int allocationCount; + + long allocSize; + + public AllocatingRmmEventHandler(long allocSize) { + this.allocSize = allocSize; + } + + public int getAllocationCount() { + return allocationCount; } @Override @@ -1340,7 +1375,8 @@ public boolean onAllocFailure(long sizeRequested, int retryCount) { return false; } else { stillHandlingAllocFailure = true; - try (DeviceMemoryBuffer dmb = Rmm.alloc(allocSize)) { // try to allocate one byte, and free + try ( + DeviceMemoryBuffer dmb = Rmm.alloc(allocSize)) { // try to allocate one byte, and free allocationCount++; stillHandlingAllocFailure = false; } diff --git a/src/test/java/com/nvidia/spark/rapids/jni/RowConversionTest.java b/src/test/java/com/nvidia/spark/rapids/jni/RowConversionTest.java index c8fc4dfb7f..27ae05e27e 100644 --- a/src/test/java/com/nvidia/spark/rapids/jni/RowConversionTest.java +++ b/src/test/java/com/nvidia/spark/rapids/jni/RowConversionTest.java @@ -20,21 +20,20 @@ import ai.rapids.cudf.ColumnVector; import ai.rapids.cudf.DType; import ai.rapids.cudf.Table; -import org.junit.jupiter.api.Test; - import java.math.RoundingMode; import java.util.stream.IntStream; +import org.junit.jupiter.api.Test; public class RowConversionTest { @Test void fixedWidthRowsRoundTripWide() { Table.TestBuilder tb = new Table.TestBuilder(); - IntStream.range(0, 10).forEach(i -> tb.column(3l, 9l, 4l, 2l, 20l, null)); + IntStream.range(0, 10).forEach(i -> tb.column(3L, 9L, 4L, 2L, 20L, null)); IntStream.range(0, 10).forEach(i -> tb.column(5.0d, 9.5d, 0.9d, 7.23d, 2.8d, null)); IntStream.range(0, 10).forEach(i -> tb.column(5, 1, 0, 2, 7, null)); IntStream.range(0, 10).forEach(i -> tb.column(true, false, false, true, false, null)); IntStream.range(0, 10).forEach(i -> tb.column(1.0f, 3.5f, 5.9f, 7.1f, 9.8f, null)); - IntStream.range(0, 10).forEach(i -> tb.column(new Byte[]{2, 3, 4, 5, 9, null})); + IntStream.range(0, 10).forEach(i -> tb.column(new Byte[] {2, 3, 4, 5, 9, null})); IntStream.range(0, 10).forEach(i -> tb.decimal32Column(-3, RoundingMode.UNNECESSARY, 5.0d, 9.5d, 0.9d, 7.23d, 2.8d, null)); IntStream.range(0, 10).forEach(i -> tb.decimal64Column(-8, 3L, 9L, 4L, 2L, 20L, null)); @@ -63,12 +62,12 @@ void fixedWidthRowsRoundTripWide() { @Test void fixedWidthRowsRoundTrip() { try (Table origTable = new Table.TestBuilder() - .column(3l, 9l, 4l, 2l, 20l, null) + .column(3L, 9L, 4L, 2L, 20L, null) .column(5.0d, 9.5d, 0.9d, 7.23d, 2.8d, null) .column(5, 1, 0, 2, 7, null) .column(true, false, false, true, false, null) .column(1.0f, 3.5f, 5.9f, 7.1f, 9.8f, null) - .column(new Byte[]{2, 3, 4, 5, 9, null}) + .column(new Byte[] {2, 3, 4, 5, 9, null}) .decimal32Column(-3, RoundingMode.UNNECESSARY, 5.0d, 9.5d, 0.9d, 7.23d, 2.8d, null) .decimal64Column(-8, 3L, 9L, 4L, 2L, 20L, null) .build()) { diff --git a/src/test/java/com/nvidia/spark/rapids/jni/TimeZoneTest.java b/src/test/java/com/nvidia/spark/rapids/jni/TimeZoneTest.java index 4eb97a280c..5821ae3f4c 100644 --- a/src/test/java/com/nvidia/spark/rapids/jni/TimeZoneTest.java +++ b/src/test/java/com/nvidia/spark/rapids/jni/TimeZoneTest.java @@ -1,30 +1,28 @@ /* -* Copyright (c) 2023-2024, NVIDIA CORPORATION. -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ + * Copyright (c) 2023-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package com.nvidia.spark.rapids.jni; -import java.time.ZoneId; -import java.util.List; - import static ai.rapids.cudf.AssertUtils.assertColumnsAreEqual; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; import ai.rapids.cudf.ColumnVector; - +import java.time.ZoneId; +import java.util.List; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; @@ -35,12 +33,12 @@ public class TimeZoneTest { static void cacheTimezoneDatabase() { GpuTimeZoneDB.cacheDatabase(); } - + @AfterAll static void cleanup() { GpuTimeZoneDB.shutdown(); } - + @Test void databaseLoadedTest() { // Check for a few timezones @@ -52,33 +50,33 @@ void databaseLoadedTest() { ZoneId shanghai = ZoneId.of("Asia/Shanghai").normalized(); assertEquals(shanghai.getRules().getTransitions().size() + 1, transitions.size()); } - + @Test void convertToUtcSecondsTest() { try (ColumnVector input = ColumnVector.timestampSecondsFromBoxedLongs( - -1262260800L, - -908838000L, - -908840700L, - -888800400L, - -888799500L, - -888796800L, - 0L, - 1699571634L, - 568036800L - ); - ColumnVector expected = ColumnVector.timestampSecondsFromBoxedLongs( - -1262289600L, - -908870400L, - -908869500L, - -888832800L, - -888831900L, - -888825600L, - -28800L, - 1699542834L, - 568008000L - ); - ColumnVector actual = GpuTimeZoneDB.fromTimestampToUtcTimestamp(input, - ZoneId.of("Asia/Shanghai"))) { + -1262260800L, + -908838000L, + -908840700L, + -888800400L, + -888799500L, + -888796800L, + 0L, + 1699571634L, + 568036800L + ); + ColumnVector expected = ColumnVector.timestampSecondsFromBoxedLongs( + -1262289600L, + -908870400L, + -908869500L, + -888832800L, + -888831900L, + -888825600L, + -28800L, + 1699542834L, + 568008000L + ); + ColumnVector actual = GpuTimeZoneDB.fromTimestampToUtcTimestamp(input, + ZoneId.of("Asia/Shanghai"))) { assertColumnsAreEqual(expected, actual); } } @@ -86,59 +84,59 @@ void convertToUtcSecondsTest() { @Test void convertToUtcMilliSecondsTest() { try (ColumnVector input = ColumnVector.timestampMilliSecondsFromBoxedLongs( - -1262260800000L, - -908838000000L, - -908840700000L, - -888800400000L, - -888799500000L, - -888796800000L, - 0L, - 1699571634312L, - 568036800000L - ); - ColumnVector expected = ColumnVector.timestampMilliSecondsFromBoxedLongs( - -1262289600000L, - -908870400000L, - -908869500000L, - -888832800000L, - -888831900000L, - -888825600000L, - -28800000L, - 1699542834312L, - 568008000000L - ); - ColumnVector actual = GpuTimeZoneDB.fromTimestampToUtcTimestamp(input, - ZoneId.of("Asia/Shanghai"))) { + -1262260800000L, + -908838000000L, + -908840700000L, + -888800400000L, + -888799500000L, + -888796800000L, + 0L, + 1699571634312L, + 568036800000L + ); + ColumnVector expected = ColumnVector.timestampMilliSecondsFromBoxedLongs( + -1262289600000L, + -908870400000L, + -908869500000L, + -888832800000L, + -888831900000L, + -888825600000L, + -28800000L, + 1699542834312L, + 568008000000L + ); + ColumnVector actual = GpuTimeZoneDB.fromTimestampToUtcTimestamp(input, + ZoneId.of("Asia/Shanghai"))) { assertColumnsAreEqual(expected, actual); } } - + @Test void convertToUtcMicroSecondsTest() { try (ColumnVector input = ColumnVector.timestampMicroSecondsFromBoxedLongs( - -1262260800000000L, - -908838000000000L, - -908840700000000L, - -888800400000000L, - -888799500000000L, - -888796800000000L, - 0L, - 1699571634312000L, - 568036800000000L - ); - ColumnVector expected = ColumnVector.timestampMicroSecondsFromBoxedLongs( - -1262289600000000L, - -908870400000000L, - -908869500000000L, - -888832800000000L, - -888831900000000L, - -888825600000000L, - -28800000000L, - 1699542834312000L, - 568008000000000L - ); - ColumnVector actual = GpuTimeZoneDB.fromTimestampToUtcTimestamp(input, - ZoneId.of("Asia/Shanghai"))) { + -1262260800000000L, + -908838000000000L, + -908840700000000L, + -888800400000000L, + -888799500000000L, + -888796800000000L, + 0L, + 1699571634312000L, + 568036800000000L + ); + ColumnVector expected = ColumnVector.timestampMicroSecondsFromBoxedLongs( + -1262289600000000L, + -908870400000000L, + -908869500000000L, + -888832800000000L, + -888831900000000L, + -888825600000000L, + -28800000000L, + 1699542834312000L, + 568008000000000L + ); + ColumnVector actual = GpuTimeZoneDB.fromTimestampToUtcTimestamp(input, + ZoneId.of("Asia/Shanghai"))) { assertColumnsAreEqual(expected, actual); } } @@ -146,27 +144,27 @@ void convertToUtcMicroSecondsTest() { @Test void convertFromUtcSecondsTest() { try (ColumnVector input = ColumnVector.timestampSecondsFromBoxedLongs( - -1262289600L, - -908870400L, - -908869500L, - -888832800L, - -888831900L, - -888825600L, - 0L, - 1699542834L, - 568008000L); - ColumnVector expected = ColumnVector.timestampSecondsFromBoxedLongs( - -1262260800L, - -908838000L, - -908837100L, - -888800400L, - -888799500L, - -888796800L, - 28800L, - 1699571634L, - 568036800L); - ColumnVector actual = GpuTimeZoneDB.fromUtcTimestampToTimestamp(input, - ZoneId.of("Asia/Shanghai"))) { + -1262289600L, + -908870400L, + -908869500L, + -888832800L, + -888831900L, + -888825600L, + 0L, + 1699542834L, + 568008000L); + ColumnVector expected = ColumnVector.timestampSecondsFromBoxedLongs( + -1262260800L, + -908838000L, + -908837100L, + -888800400L, + -888799500L, + -888796800L, + 28800L, + 1699571634L, + 568036800L); + ColumnVector actual = GpuTimeZoneDB.fromUtcTimestampToTimestamp(input, + ZoneId.of("Asia/Shanghai"))) { assertColumnsAreEqual(expected, actual); } } @@ -174,57 +172,57 @@ void convertFromUtcSecondsTest() { @Test void convertFromUtcMilliSecondsTest() { try (ColumnVector input = ColumnVector.timestampMilliSecondsFromBoxedLongs( - -1262289600000L, - -908870400000L, - -908869500000L, - -888832800000L, - -888831900000L, - -888825600000L, - 0L, - 1699542834312L, - 568008000000L); - ColumnVector expected = ColumnVector.timestampMilliSecondsFromBoxedLongs( - -1262260800000L, - -908838000000L, - -908837100000L, - -888800400000L, - -888799500000L, - -888796800000L, - 28800000L, - 1699571634312L, - 568036800000L); - ColumnVector actual = GpuTimeZoneDB.fromUtcTimestampToTimestamp(input, - ZoneId.of("Asia/Shanghai"))) { + -1262289600000L, + -908870400000L, + -908869500000L, + -888832800000L, + -888831900000L, + -888825600000L, + 0L, + 1699542834312L, + 568008000000L); + ColumnVector expected = ColumnVector.timestampMilliSecondsFromBoxedLongs( + -1262260800000L, + -908838000000L, + -908837100000L, + -888800400000L, + -888799500000L, + -888796800000L, + 28800000L, + 1699571634312L, + 568036800000L); + ColumnVector actual = GpuTimeZoneDB.fromUtcTimestampToTimestamp(input, + ZoneId.of("Asia/Shanghai"))) { assertColumnsAreEqual(expected, actual); } } - + @Test void convertFromUtcMicroSecondsTest() { try (ColumnVector input = ColumnVector.timestampMicroSecondsFromBoxedLongs( - -1262289600000000L, - -908870400000000L, - -908869500000000L, - -888832800000000L, - -888831900000000L, - -888825600000000L, - 0L, - 1699542834312000L, - 568008000000000L); - ColumnVector expected = ColumnVector.timestampMicroSecondsFromBoxedLongs( - -1262260800000000L, - -908838000000000L, - -908837100000000L, - -888800400000000L, - -888799500000000L, - -888796800000000L, - 28800000000L, - 1699571634312000L, - 568036800000000L); - ColumnVector actual = GpuTimeZoneDB.fromUtcTimestampToTimestamp(input, - ZoneId.of("Asia/Shanghai"))) { + -1262289600000000L, + -908870400000000L, + -908869500000000L, + -888832800000000L, + -888831900000000L, + -888825600000000L, + 0L, + 1699542834312000L, + 568008000000000L); + ColumnVector expected = ColumnVector.timestampMicroSecondsFromBoxedLongs( + -1262260800000000L, + -908838000000000L, + -908837100000000L, + -888800400000000L, + -888799500000000L, + -888796800000000L, + 28800000000L, + 1699571634312000L, + 568036800000000L); + ColumnVector actual = GpuTimeZoneDB.fromUtcTimestampToTimestamp(input, + ZoneId.of("Asia/Shanghai"))) { assertColumnsAreEqual(expected, actual); } } - + } diff --git a/src/test/java/com/nvidia/spark/rapids/jni/kudo/KudoSerializerTest.java b/src/test/java/com/nvidia/spark/rapids/jni/kudo/KudoSerializerTest.java index 210777accf..ae1cbbfde9 100644 --- a/src/test/java/com/nvidia/spark/rapids/jni/kudo/KudoSerializerTest.java +++ b/src/test/java/com/nvidia/spark/rapids/jni/kudo/KudoSerializerTest.java @@ -16,10 +16,22 @@ package com.nvidia.spark.rapids.jni.kudo; -import ai.rapids.cudf.*; +import static java.lang.Math.toIntExact; +import static java.util.Arrays.asList; +import static java.util.Collections.emptyList; +import static java.util.Collections.singletonList; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import ai.rapids.cudf.AssertUtils; +import ai.rapids.cudf.ColumnVector; +import ai.rapids.cudf.ColumnView; +import ai.rapids.cudf.DType; +import ai.rapids.cudf.HostColumnVector; +import ai.rapids.cudf.Schema; +import ai.rapids.cudf.Table; import com.nvidia.spark.rapids.jni.Arms; -import org.junit.jupiter.api.Test; - import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.DataInputStream; @@ -27,172 +39,9 @@ import java.util.List; import java.util.stream.Collectors; import java.util.stream.IntStream; - -import static java.lang.Math.toIntExact; -import static java.util.Arrays.asList; -import static java.util.Collections.emptyList; -import static java.util.Collections.singletonList; -import static org.junit.jupiter.api.Assertions.*; +import org.junit.jupiter.api.Test; public class KudoSerializerTest { - @Test - public void testSerializeAndDeserializeTable() { - try(Table expected = buildTestTable()) { - int rowCount = toIntExact(expected.getRowCount()); - for (int sliceSize = 1; sliceSize <= rowCount; sliceSize++) { - List tableSlices = new ArrayList<>(); - for (int startRow = 0; startRow < rowCount; startRow += sliceSize) { - tableSlices.add(new TableSlice(startRow, Math.min(sliceSize, rowCount - startRow), expected)); - } - - checkMergeTable(expected, tableSlices); - } - } catch (Exception e) { - throw new RuntimeException(e); - } - } - - @Test - public void testRowCountOnly() throws Exception { - ByteArrayOutputStream out = new ByteArrayOutputStream(); - long bytesWritten = KudoSerializer.writeRowCountToStream(out, 5); - assertEquals(28, bytesWritten); - - ByteArrayInputStream in = new ByteArrayInputStream(out.toByteArray()); - KudoTableHeader header = KudoTableHeader.readFrom(new DataInputStream(in)).get(); - - assertEquals(0, header.getNumColumns()); - assertEquals(0, header.getOffset()); - assertEquals(5, header.getNumRows()); - assertEquals(0, header.getValidityBufferLen()); - assertEquals(0, header.getOffsetBufferLen()); - assertEquals(0, header.getTotalDataLen()); - } - - @Test - public void testWriteSimple() throws Exception { - KudoSerializer serializer = new KudoSerializer(buildSimpleTestSchema()); - - try (Table t = buildSimpleTable()) { - ByteArrayOutputStream out = new ByteArrayOutputStream(); - long bytesWritten = serializer.writeToStream(t, out, 0, 4); - assertEquals(189, bytesWritten); - - ByteArrayInputStream in = new ByteArrayInputStream(out.toByteArray()); - - KudoTableHeader header = KudoTableHeader.readFrom(new DataInputStream(in)).get(); - assertEquals(7, header.getNumColumns()); - assertEquals(0, header.getOffset()); - assertEquals(4, header.getNumRows()); - assertEquals(24, header.getValidityBufferLen()); - assertEquals(40, header.getOffsetBufferLen()); - assertEquals(160, header.getTotalDataLen()); - - // First integer column has no validity buffer - assertFalse(header.hasValidityBuffer(0)); - for (int i = 1; i < 7; i++) { - assertTrue(header.hasValidityBuffer(i)); - } - } - } - - @Test - public void testMergeTableWithDifferentValidity() { - Arms.withResource(new ArrayList(), tables -> { - Table table1 = new Table.TestBuilder() - .column(-83182L, 5822L, 3389L, 7384L, 7297L) - .column(-2.06, -2.14, 8.04, 1.16, -1.0) - .build(); - tables.add(table1); - - Table table2 = new Table.TestBuilder() - .column(-47L, null, -83L, -166L, -220L, 470L, 619L, 803L, 661L) - .column(-6.08, 1.6, 1.78, -8.01, 1.22, 1.43, 2.13, -1.65, null) - .build(); - tables.add(table2); - - Table table3 = new Table.TestBuilder() - .column(8722L, 8733L) - .column(2.51, 0.0) - .build(); - tables.add(table3); - - - Table expected = new Table.TestBuilder() - .column(7384L, 7297L, 803L, 661L, 8733L) - .column(1.16, -1.0, -1.65, null, 0.0) - .build(); - tables.add(expected); - - checkMergeTable(expected, asList( - new TableSlice(3, 2, table1), - new TableSlice(7, 2, table2), - new TableSlice(1, 1, table3))); - return null; - }); - } - - @Test - public void testMergeList() { - Arms.withResource(new ArrayList
      (), tables -> { - Table table1 = new Table.TestBuilder() - .column(-881L, 482L, 660L, 896L, -129L, -108L, -428L, 0L, 617L, 782L) - .column(integers(665), integers(-267), integers(398), integers(-314), - integers(-370), integers(181), integers(665, 544), integers(222), integers(-587), - integers(544)) - .build(); - tables.add(table1); - - Table table2 = new Table.TestBuilder() - .column(-881L, 482L, 660L, 896L, 122L, 241L, 281L, 680L, 783L, null) - .column(integers(-370), integers(398), integers(-587, 398), integers(-314), - integers(307), integers(-397, -633), integers(-314, 307), integers(-633), integers(-397), - integers(181, -919, -175)) - .build(); - tables.add(table2); - - Table expected = new Table.TestBuilder() - .column(896L, -129L, -108L, -428L, 0L, 617L, 782L, 482L, 660L, 896L, 122L, 241L, - 281L, 680L, 783L, null) - .column(integers(-314), integers(-370), integers(181), integers(665, 544), integers(222), - integers(-587), integers(544), integers(398), integers(-587, 398), integers(-314), - integers(307), integers(-397, -633), integers(-314, 307), integers(-633), integers(-397), - integers(181, -919, -175)) - .build(); - tables.add(expected); - - checkMergeTable(expected, asList( - new TableSlice(3, 7, table1), - new TableSlice(1, 9, table2))); - - return null; - }); - } - - - @Test - public void testSerializeValidity() { - Arms.withResource(new ArrayList
      (), tables -> { - List col1 = new ArrayList<>(512); - col1.add(null); - col1.add(null); - col1.addAll(IntStream.range(2, 512).boxed().collect(Collectors.toList())); - - Table table1 = new Table.TestBuilder() - .column(col1.toArray(new Integer[0])) - .build(); - tables.add(table1); - - Table table2 = new Table.TestBuilder() - .column(509, 510, 511) - .build(); - tables.add(table2); - - checkMergeTable(table2, asList(new TableSlice(509, 3, table1))); - return null; - }); - } - private static Schema buildSimpleTestSchema() { Schema.Builder builder = Schema.builder(); @@ -217,7 +66,7 @@ private static Table buildSimpleTable() { return new Table.TestBuilder() .column(1, 2, 3, 4) .column("1", "12", null, "45") - .column(new Integer[]{1, null, 3}, new Integer[]{4, 5, 6}, null, new Integer[]{7, 8, 9}) + .column(new Integer[] {1, null, 3}, new Integer[] {4, 5, 6}, null, new Integer[] {7, 8, 9}) .column(st, new HostColumnVector.StructData((byte) 1, 11L), new HostColumnVector.StructData((byte) 2, null), null, new HostColumnVector.StructData((byte) 3, 33L)) @@ -243,25 +92,32 @@ private static Table buildTestTable() { new HostColumnVector.BasicType(true, DType.INT32))); return new Table.TestBuilder() - .column(100, 202, 3003, 40004, 5, -60, 1, null, 3, null, 5, null, 7, null, 9, null, 11, null, 13, null, 15) + .column(100, 202, 3003, 40004, 5, -60, 1, null, 3, null, 5, null, 7, null, 9, null, 11, + null, 13, null, 15) .column(true, true, false, false, true, null, true, true, null, false, false, null, true, true, null, false, false, null, true, true, null) - .column((byte)1, (byte)2, null, (byte)4, (byte)5,(byte)6,(byte)1,(byte)2,(byte)3, null,(byte)5, (byte)6, - (byte) 7, null,(byte) 9,(byte) 10,(byte) 11, null,(byte) 13,(byte) 14,(byte) 15) - .column((short)6, (short)5, (short)4, null, (short)2, (short)1, - (short)1, (short)2, (short)3, null, (short)5, (short)6, (short)7, null, (short)9, - (short)10, null, (short)12, (short)13, (short)14, null) + .column((byte) 1, (byte) 2, null, (byte) 4, (byte) 5, (byte) 6, (byte) 1, (byte) 2, + (byte) 3, null, (byte) 5, (byte) 6, + (byte) 7, null, (byte) 9, (byte) 10, (byte) 11, null, (byte) 13, (byte) 14, (byte) 15) + .column((short) 6, (short) 5, (short) 4, null, (short) 2, (short) 1, + (short) 1, (short) 2, (short) 3, null, (short) 5, (short) 6, (short) 7, null, (short) 9, + (short) 10, null, (short) 12, (short) 13, (short) 14, null) .column(1L, null, 1001L, 50L, -2000L, null, 1L, 2L, 3L, 4L, null, 6L, 7L, 8L, 9L, null, 11L, 12L, 13L, 14L, null) - .column(10.1f, 20f, -1f, 3.1415f, -60f, null, 1f, 2f, 3f, 4f, 5f, null, 7f, 8f, 9f, 10f, 11f, null, 13f, 14f, 15f) - .column(10.1f, 20f, -2f, 3.1415f, -60f, -50f, 1f, 2f, 3f, 4f, 5f, 6f, 7f, 8f, 9f, 10f, 11f, 12f, 13f, 14f, 15f) + .column(10.1f, 20f, -1f, 3.1415f, -60f, null, 1f, 2f, 3f, 4f, 5f, null, 7f, 8f, 9f, 10f, + 11f, null, 13f, 14f, 15f) + .column(10.1f, 20f, -2f, 3.1415f, -60f, -50f, 1f, 2f, 3f, 4f, 5f, 6f, 7f, 8f, 9f, 10f, 11f, + 12f, 13f, 14f, 15f) .column(10.1, 20.0, 33.1, 3.1415, -60.5, null, 1d, 2.0, 3.0, 4.0, 5.0, 6.0, null, 8.0, 9.0, 10.0, 11.0, 12.0, null, 14.0, 15.0) - .column((Float)null, null, null, null, null, null, null, null, null, null, + .column((Float) null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null) - .timestampDayColumn(99, 100, 101, 102, 103, 104, 1, 2, 3, 4, 5, 6, 7, null, 9, 10, 11, 12, 13, null, 15) - .timestampMillisecondsColumn(9L, 1006L, 101L, 5092L, null, 88L, 1L, 2L, 3L, 4L, 5L, 6L, 7L, 8L, null, 10L, 11L, 12L, 13L, 14L, 15L) - .timestampSecondsColumn(1L, null, 3L, 4L, 5L, 6L, 1L, 2L, 3L, 4L, 5L, 6L, 7L, 8L, 9L, null, 11L, 12L, 13L, 14L, 15L) + .timestampDayColumn(99, 100, 101, 102, 103, 104, 1, 2, 3, 4, 5, 6, 7, null, 9, 10, 11, 12, + 13, null, 15) + .timestampMillisecondsColumn(9L, 1006L, 101L, 5092L, null, 88L, 1L, 2L, 3L, 4L, 5L, 6L, 7L, + 8L, null, 10L, 11L, 12L, 13L, 14L, 15L) + .timestampSecondsColumn(1L, null, 3L, 4L, 5L, 6L, 1L, 2L, 3L, 4L, 5L, 6L, 7L, 8L, 9L, null, + 11L, 12L, 13L, 14L, 15L) .decimal32Column(-3, 100, 202, 3003, 40004, 5, -60, 1, null, 3, null, 5, null, 7, null, 9, null, 11, null, 13, null, 15) .decimal64Column(-8, 1L, null, 1001L, 50L, -2000L, null, 1L, 2L, 3L, @@ -272,8 +128,9 @@ private static Table buildTestTable() { "6", "7", "", "9", "10", "11", "12", "13", "", "15") .column("", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "") - .column("", null, "", "", null, "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "") - .column((String)null, null, null, null, null, null, null, null, null, null, + .column("", null, "", "", null, "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", + "") + .column((String) null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null) .column(mapStructType, structs(struct("1", "2")), structs(struct("3", "4")), null, null, structs(struct("key", "value"), struct("a", "b")), null, null, @@ -292,7 +149,7 @@ null, struct(-1, -1f), struct(-100, -100f), struct(Integer.MIN_VALUE, Float.MIN_VALUE)) .column(integers(1, 2), null, integers(3, 4, null, 5, null), null, null, integers(6, 7, 8), integers(null, null, null), integers(1, 2, 3), integers(4, 5, 6), integers(7, 8, 9), - integers(10, 11, 12), integers((Integer)null), integers(14, null), + integers(10, 11, 12), integers((Integer) null), integers(14, null), integers(14, 15, null, 16, 17, 18), integers(19, 20, 21), integers(22, 23, 24), integers(25, 26, 27), integers(28, 29, 30), integers(31, 32, 33), null, integers(37, 38, 39)) @@ -301,13 +158,13 @@ null, struct(-1, -1f), struct(-100, -100f), integers(), integers(), integers(), integers(), integers(), integers(), integers()) .column(integers(null, null), integers(null, null, null, null), integers(), integers(null, null, null), integers(), integers(null, null, null, null, null), - integers((Integer)null), integers(null, null, null), integers(null, null), + integers((Integer) null), integers(null, null, null), integers(null, null), integers(null, null, null, null), integers(null, null, null, null, null), integers(), integers(null, null, null, null), integers(null, null, null), integers(null, null), - integers(null, null, null), integers(null, null), integers((Integer)null), - integers((Integer)null), integers(null, null), + integers(null, null, null), integers(null, null), integers((Integer) null), + integers((Integer) null), integers(null, null), integers(null, null, null, null, null)) - .column((Integer)null, null, null, null, null, null, null, null, null, null, + .column((Integer) null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null) .column(strings("1", "2", "3"), strings("4"), strings("5"), strings("6, 7"), strings("", "9", null), strings("11"), strings(""), strings(null, null), @@ -318,13 +175,13 @@ null, struct(-1, -1f), struct(-100, -100f), strings(), strings(), strings(), strings(), strings(), strings()) .column(strings(null, null), strings(null, null, null, null), strings(), strings(null, null, null), strings(), strings(null, null, null, null, null), - strings((String)null), strings(null, null, null), strings(null, null), + strings((String) null), strings(null, null, null), strings(null, null), strings(null, null, null, null), strings(null, null, null, null, null), strings(), strings(null, null, null, null), strings(null, null, null), strings(null, null), - strings(null, null, null), strings(null, null), strings((String)null), - strings((String)null), strings(null, null), + strings(null, null, null), strings(null, null), strings((String) null), + strings((String) null), strings(null, null), strings(null, null, null, null, null)) - .column((String)null, null, null, null, null, null, null, null, null, null, + .column((String) null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null) .column(listMapType, asList(asList(struct("k1", "v1"), struct("k2", "v2")), singletonList(struct("k3", "v3"))), @@ -333,7 +190,8 @@ null, struct(-1, -1f), struct(-100, -100f), null, null, null, asList(asList(struct("k8", "v8"), struct("k9", "v9")), asList(struct("k10", "v10"), struct("k11", "v11"), struct("k12", "v12"), struct("k13", "v13"))), - singletonList(asList(struct("k14", "v14"), struct("k15", "v15"))), null, null, null, null, + singletonList(asList(struct("k14", "v14"), struct("k15", "v15"))), null, null, null, + null, asList(asList(struct("k16", "v16"), struct("k17", "v17")), singletonList(struct("k18", "v18"))), asList(asList(struct("k19", "v19"), struct("k20", "v20")), @@ -365,7 +223,8 @@ private static void checkMergeTable(Table expected, List tableSlices ByteArrayOutputStream bout = new ByteArrayOutputStream(); for (TableSlice slice : tableSlices) { - serializer.writeToStream(slice.getBaseTable(), bout, slice.getStartRow(), slice.getNumRows()); + serializer.writeToStream(slice.getBaseTable(), bout, slice.getStartRow(), + slice.getNumRows()); } bout.flush(); @@ -439,6 +298,166 @@ private static int toSchemaInner(ColumnView cv, int idx, String namePrefix, return lastIdx; } + @Test + public void testSerializeAndDeserializeTable() { + try (Table expected = buildTestTable()) { + int rowCount = toIntExact(expected.getRowCount()); + for (int sliceSize = 1; sliceSize <= rowCount; sliceSize++) { + List tableSlices = new ArrayList<>(); + for (int startRow = 0; startRow < rowCount; startRow += sliceSize) { + tableSlices.add( + new TableSlice(startRow, Math.min(sliceSize, rowCount - startRow), expected)); + } + + checkMergeTable(expected, tableSlices); + } + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + @Test + public void testRowCountOnly() throws Exception { + ByteArrayOutputStream out = new ByteArrayOutputStream(); + long bytesWritten = KudoSerializer.writeRowCountToStream(out, 5); + assertEquals(28, bytesWritten); + + ByteArrayInputStream in = new ByteArrayInputStream(out.toByteArray()); + KudoTableHeader header = KudoTableHeader.readFrom(new DataInputStream(in)).get(); + + assertEquals(0, header.getNumColumns()); + assertEquals(0, header.getOffset()); + assertEquals(5, header.getNumRows()); + assertEquals(0, header.getValidityBufferLen()); + assertEquals(0, header.getOffsetBufferLen()); + assertEquals(0, header.getTotalDataLen()); + } + + @Test + public void testWriteSimple() throws Exception { + KudoSerializer serializer = new KudoSerializer(buildSimpleTestSchema()); + + try (Table t = buildSimpleTable()) { + ByteArrayOutputStream out = new ByteArrayOutputStream(); + long bytesWritten = serializer.writeToStream(t, out, 0, 4); + assertEquals(189, bytesWritten); + + ByteArrayInputStream in = new ByteArrayInputStream(out.toByteArray()); + + KudoTableHeader header = KudoTableHeader.readFrom(new DataInputStream(in)).get(); + assertEquals(7, header.getNumColumns()); + assertEquals(0, header.getOffset()); + assertEquals(4, header.getNumRows()); + assertEquals(24, header.getValidityBufferLen()); + assertEquals(40, header.getOffsetBufferLen()); + assertEquals(160, header.getTotalDataLen()); + + // First integer column has no validity buffer + assertFalse(header.hasValidityBuffer(0)); + for (int i = 1; i < 7; i++) { + assertTrue(header.hasValidityBuffer(i)); + } + } + } + + @Test + public void testMergeTableWithDifferentValidity() { + Arms.withResource(new ArrayList
      (), tables -> { + Table table1 = new Table.TestBuilder() + .column(-83182L, 5822L, 3389L, 7384L, 7297L) + .column(-2.06, -2.14, 8.04, 1.16, -1.0) + .build(); + tables.add(table1); + + Table table2 = new Table.TestBuilder() + .column(-47L, null, -83L, -166L, -220L, 470L, 619L, 803L, 661L) + .column(-6.08, 1.6, 1.78, -8.01, 1.22, 1.43, 2.13, -1.65, null) + .build(); + tables.add(table2); + + Table table3 = new Table.TestBuilder() + .column(8722L, 8733L) + .column(2.51, 0.0) + .build(); + tables.add(table3); + + + Table expected = new Table.TestBuilder() + .column(7384L, 7297L, 803L, 661L, 8733L) + .column(1.16, -1.0, -1.65, null, 0.0) + .build(); + tables.add(expected); + + checkMergeTable(expected, asList( + new TableSlice(3, 2, table1), + new TableSlice(7, 2, table2), + new TableSlice(1, 1, table3))); + return null; + }); + } + + @Test + public void testMergeList() { + Arms.withResource(new ArrayList
      (), tables -> { + Table table1 = new Table.TestBuilder() + .column(-881L, 482L, 660L, 896L, -129L, -108L, -428L, 0L, 617L, 782L) + .column(integers(665), integers(-267), integers(398), integers(-314), + integers(-370), integers(181), integers(665, 544), integers(222), integers(-587), + integers(544)) + .build(); + tables.add(table1); + + Table table2 = new Table.TestBuilder() + .column(-881L, 482L, 660L, 896L, 122L, 241L, 281L, 680L, 783L, null) + .column(integers(-370), integers(398), integers(-587, 398), integers(-314), + integers(307), integers(-397, -633), integers(-314, 307), integers(-633), + integers(-397), + integers(181, -919, -175)) + .build(); + tables.add(table2); + + Table expected = new Table.TestBuilder() + .column(896L, -129L, -108L, -428L, 0L, 617L, 782L, 482L, 660L, 896L, 122L, 241L, + 281L, 680L, 783L, null) + .column(integers(-314), integers(-370), integers(181), integers(665, 544), integers(222), + integers(-587), integers(544), integers(398), integers(-587, 398), integers(-314), + integers(307), integers(-397, -633), integers(-314, 307), integers(-633), + integers(-397), + integers(181, -919, -175)) + .build(); + tables.add(expected); + + checkMergeTable(expected, asList( + new TableSlice(3, 7, table1), + new TableSlice(1, 9, table2))); + + return null; + }); + } + + @Test + public void testSerializeValidity() { + Arms.withResource(new ArrayList
      (), tables -> { + List col1 = new ArrayList<>(512); + col1.add(null); + col1.add(null); + col1.addAll(IntStream.range(2, 512).boxed().collect(Collectors.toList())); + + Table table1 = new Table.TestBuilder() + .column(col1.toArray(new Integer[0])) + .build(); + tables.add(table1); + + Table table2 = new Table.TestBuilder() + .column(509, 510, 511) + .build(); + tables.add(table2); + + checkMergeTable(table2, singletonList(new TableSlice(509, 3, table1))); + return null; + }); + } + private static class TableSlice { private final int startRow; private final int numRows;