diff --git a/.github/workflows/checkstyle.yml b/.github/workflows/checkstyle.yml
index 6680ccd33b..2e0fe65f36 100644
--- a/.github/workflows/checkstyle.yml
+++ b/.github/workflows/checkstyle.yml
@@ -38,4 +38,4 @@ jobs:
run: git fetch --all
- name: Run checkstyle
- run: bash ./dev/checkstyle.sh
+ run: mvn checkstyle:check
diff --git a/dev/checkstyle.sh b/dev/checkstyle.sh
index 2265bd9f84..503d495e3b 100755
--- a/dev/checkstyle.sh
+++ b/dev/checkstyle.sh
@@ -19,13 +19,7 @@
set -ex
-# Assuming you are in the root of your git repository
-if [[ -z "${GITHUB_BASE_REF+x}" ]]; then
- MODIFIED_FILES=$(git diff --name-only)
-else
- MODIFIED_FILES=$(git diff --name-only "origin/${GITHUB_BASE_REF}")
-fi
-
+MODIFIED_FILES=$(git diff --name-only)
SRC_DIR="src/main/java/"
TEST_SRC_DIR="src/test/java/"
diff --git a/dev/checkstyle.xml b/dev/checkstyle.xml
index 15c475ec7e..26ce0964c8 100644
--- a/dev/checkstyle.xml
+++ b/dev/checkstyle.xml
@@ -41,11 +41,11 @@
-
-
-
-
-
+
+
+
+
+
@@ -56,21 +56,21 @@
-
-
-
-
-
+
+
+
+
+
-
-
-
-
+
+
+
+
+
@@ -134,7 +134,7 @@
-
+
-
-
-
-
-
-
-
-
+
+
+
+
+
+
+
+
+
+
-
-
-
-
+
+
+
+
+
-
-
-
-
-
+
+
+
+
+
+
+
+
-
-
+
+
@@ -290,13 +290,13 @@
METHOD_DEF, QUESTION, RESOURCE_SPECIFICATION, SUPER_CTOR_CALL, LAMBDA,
RECORD_DEF"/>
-
-
-
-
+
+
+
+
+
+
+
-
+
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
-
+
diff --git a/src/main/java/com/nvidia/spark/rapids/jni/Arms.java b/src/main/java/com/nvidia/spark/rapids/jni/Arms.java
index 4b6ecf7204..859bc6c0ad 100644
--- a/src/main/java/com/nvidia/spark/rapids/jni/Arms.java
+++ b/src/main/java/com/nvidia/spark/rapids/jni/Arms.java
@@ -25,77 +25,80 @@
* This class contains utility methods for automatic resource management.
*/
public class Arms {
- /**
- * This method close the resource if an exception is thrown while executing the function.
- */
- public static <R extends AutoCloseable, T> T closeIfException(R resource, Function<R, T> function) {
+ /**
+ * This method closes the resource if an exception is thrown while executing the function.
+ */
+ public static <R extends AutoCloseable, T> T closeIfException(R resource,
+ Function<R, T> function) {
+ try {
+ return function.apply(resource);
+ } catch (Exception e) {
+ if (resource != null) {
try {
- return function.apply(resource);
- } catch (Exception e) {
- if (resource != null) {
- try {
- resource.close();
- } catch (Exception inner) {
- e.addSuppressed(inner);
- }
- }
- throw e;
+ resource.close();
+ } catch (Exception inner) {
+ e.addSuppressed(inner);
}
+ }
+ throw e;
}
+ }
- /**
- * This method safely closes all the resources.
- *
- * This method will iterate through all the resources and closes them. If any exception happened during the
- * traversal, exception will be captured and rethrown after all resources closed.
- *
- */
- public static <R extends AutoCloseable> void closeAll(Iterator<R> resources) {
- Throwable t = null;
- while (resources.hasNext()) {
- try {
- R resource = resources.next();
- if (resource != null) {
- resource.close();
- }
- } catch (Exception e) {
- if (t == null) {
- t = e;
- } else {
- t.addSuppressed(e);
- }
- }
+ /**
+ * This method safely closes all the resources.
+ *
+ * This method will iterate through all the resources and close them. If any exception happens during the
+ * traversal, it will be captured and rethrown after all the resources are closed.
+ *
+ */
+ public static <R extends AutoCloseable> void closeAll(Iterator<R> resources) {
+ Throwable t = null;
+ while (resources.hasNext()) {
+ try {
+ R resource = resources.next();
+ if (resource != null) {
+ resource.close();
}
+ } catch (Exception e) {
+ if (t == null) {
+ t = e;
+ } else {
+ t.addSuppressed(e);
+ }
+ }
+ }
- if (t != null) throw new RuntimeException(t);
+ if (t != null) {
+ throw new RuntimeException(t);
}
+ }
- /**
- * This method safely closes all the resources. See {@link #closeAll(Iterator)} for more details.
- */
- public static <R extends AutoCloseable> void closeAll(R... resources) {
- closeAll(Arrays.asList(resources));
- }
+ /**
+ * This method safely closes all the resources. See {@link #closeAll(Iterator)} for more details.
+ */
+ public static <R extends AutoCloseable> void closeAll(R... resources) {
+ closeAll(Arrays.asList(resources));
+ }
- /**
- * This method safely closes the resources. See {@link #closeAll(Iterator)} for more details.
- */
- public static <R extends AutoCloseable> void closeAll(Collection<R> resources) {
- closeAll(resources.iterator());
- }
+ /**
+ * This method safely closes the resources. See {@link #closeAll(Iterator)} for more details.
+ */
+ public static <R extends AutoCloseable> void closeAll(Collection<R> resources) {
+ closeAll(resources.iterator());
+ }
- /**
- * This method safely closes the resources after applying the function.
- *
- * See {@link #closeAll(Iterator)} for more details.
- */
- public static <R extends AutoCloseable, C extends Collection<R>, V> V withResource(
- C resource, Function<C, V> function) {
- try {
- return function.apply(resource);
- } finally {
- closeAll(resource);
- }
+ /**
+ * This method safely closes the resources after applying the function.
+ *
+ * See {@link #closeAll(Iterator)} for more details.
+ */
+ public static <R extends AutoCloseable, C extends Collection<R>, V> V withResource(
+ C resource, Function<C, V> function) {
+ try {
+ return function.apply(resource);
+ } finally {
+ closeAll(resource);
}
+ }
}
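Note: an illustrative usage sketch of the Arms helpers reformatted above (not part of the patch). It assumes the upstream cudf Java factory ColumnVector.fromInts, and the generic bounds of withResource as reconstructed in this diff (a Collection of AutoCloseable resources).

    import ai.rapids.cudf.ColumnVector;
    import com.nvidia.spark.rapids.jni.Arms;
    import java.util.Arrays;
    import java.util.List;

    public class ArmsExample {
      public static void main(String[] args) {
        // closeIfException: the freshly created column is closed if the lambda throws.
        long rows = Arms.closeIfException(ColumnVector.fromInts(1, 2, 3),
            cv -> cv.getRowCount());
        System.out.println("rows = " + rows);

        // withResource: apply a function over a collection of columns, then close them all.
        List<ColumnVector> cols =
            Arrays.asList(ColumnVector.fromInts(1, 2), ColumnVector.fromInts(3, 4));
        long total = Arms.withResource(cols,
            cs -> cs.stream().mapToLong(ColumnVector::getRowCount).sum());
        System.out.println("total rows = " + total);
      }
    }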
diff --git a/src/main/java/com/nvidia/spark/rapids/jni/BloomFilter.java b/src/main/java/com/nvidia/spark/rapids/jni/BloomFilter.java
index 46bf9a7f08..1b4d372953 100644
--- a/src/main/java/com/nvidia/spark/rapids/jni/BloomFilter.java
+++ b/src/main/java/com/nvidia/spark/rapids/jni/BloomFilter.java
@@ -16,16 +16,13 @@
package com.nvidia.spark.rapids.jni;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
import ai.rapids.cudf.BaseDeviceMemoryBuffer;
import ai.rapids.cudf.ColumnVector;
import ai.rapids.cudf.CudfAccessor;
import ai.rapids.cudf.CudfException;
import ai.rapids.cudf.DType;
-import ai.rapids.cudf.Scalar;
import ai.rapids.cudf.NativeDepsLoader;
+import ai.rapids.cudf.Scalar;
public class BloomFilter {
static {
@@ -34,16 +31,17 @@ public class BloomFilter {
/**
* Create a bloom filter with the specified number of hashes and bloom filter bits.
- * @param numHashes The number of hashes to use when inserting values into the bloom filter or
- * when probing.
+ *
+ * @param numHashes The number of hashes to use when inserting values into the bloom filter or
+ * when probing.
* @param bloomFilterBits Size of the bloom filter in bits.
* @return a Scalar object which encapsulates the bloom filter.
*/
- public static Scalar create(int numHashes, long bloomFilterBits){
- if(numHashes <= 0){
+ public static Scalar create(int numHashes, long bloomFilterBits) {
+ if (numHashes <= 0) {
throw new IllegalArgumentException("Bloom filters must have a positive hash count");
}
- if(bloomFilterBits <= 0){
+ if (bloomFilterBits <= 0) {
throw new IllegalArgumentException("Bloom filters must have a positive number of bits");
}
return CudfAccessor.scalarFromHandle(DType.LIST, creategpu(numHashes, bloomFilterBits));
@@ -51,54 +49,64 @@ public static Scalar create(int numHashes, long bloomFilterBits){
/**
* Insert a column of longs into a bloom filter.
+ *
* @param bloomFilter The bloom filter to which values will be inserted.
- * @param cv The column containing the values to add.
+ * @param cv The column containing the values to add.
*/
- public static void put(Scalar bloomFilter, ColumnVector cv){
+ public static void put(Scalar bloomFilter, ColumnVector cv) {
put(CudfAccessor.getScalarHandle(bloomFilter), cv.getNativeView());
}
/**
* Merge one or more bloom filters into a new bloom filter.
- * @param bloomFilters A ColumnVector containing a bloom filter per row.
+ *
+ * @param bloomFilters A ColumnVector containing a bloom filter per row.
* @return A new bloom filter containing the merged inputs.
*/
- public static Scalar merge(ColumnVector bloomFilters){
+ public static Scalar merge(ColumnVector bloomFilters) {
return CudfAccessor.scalarFromHandle(DType.LIST, merge(bloomFilters.getNativeView()));
}
/**
- * Probe a bloom filter with a column of longs. Returns a column of booleans. For
+ * Probe a bloom filter with a column of longs. Returns a column of booleans. For
* each row in the output; a value of true indicates that the corresponding input value
* -may- be in the set of values used to build the bloom filter; a value of false indicates
* that the corresponding input value is conclusively not in the set of values used to build
- * the bloom filter.
+ * the bloom filter.
+ *
* @param bloomFilter The bloom filter to be probed.
- * @param cv The column containing the values to check.
+ * @param cv The column containing the values to check.
* @return A boolean column indicating the results of the probe.
*/
- public static ColumnVector probe(Scalar bloomFilter, ColumnVector cv){
+ public static ColumnVector probe(Scalar bloomFilter, ColumnVector cv) {
return new ColumnVector(probe(CudfAccessor.getScalarHandle(bloomFilter), cv.getNativeView()));
}
/**
- * Probe a bloom filter with a column of longs. Returns a column of booleans. For
+ * Probe a bloom filter with a column of longs. Returns a column of booleans. For
* each row in the output; a value of true indicates that the corresponding input value
* -may- be in the set of values used to build the bloom filter; a value of false indicates
* that the corresponding input value is conclusively not in the set of values used to build
- * the bloom filter.
- * @param bloomFilter The bloom filter to be probed. This buffer is expected to be the
- * fully packed Spark bloom filter, including header.
- * @param cv The column containing the values to check.
+ * the bloom filter.
+ *
+ * @param bloomFilter The bloom filter to be probed. This buffer is expected to be the
+ * fully packed Spark bloom filter, including header.
+ * @param cv The column containing the values to check.
* @return A boolean column indicating the results of the probe.
*/
- public static ColumnVector probe(BaseDeviceMemoryBuffer bloomFilter, ColumnVector cv){
- return new ColumnVector(probebuffer(bloomFilter.getAddress(), bloomFilter.getLength(), cv.getNativeView()));
+ public static ColumnVector probe(BaseDeviceMemoryBuffer bloomFilter, ColumnVector cv) {
+ return new ColumnVector(
+ probebuffer(bloomFilter.getAddress(), bloomFilter.getLength(), cv.getNativeView()));
}
-
+
private static native long creategpu(int numHashes, long bloomFilterBits) throws CudfException;
+
private static native int put(long bloomFilter, long cv) throws CudfException;
+
private static native long merge(long bloomFilters) throws CudfException;
- private static native long probe(long bloomFilter, long cv) throws CudfException;
- private static native long probebuffer(long bloomFilter, long bloomFilterSize, long cv) throws CudfException;
+
+ private static native long probe(long bloomFilter, long cv) throws CudfException;
+
+ private static native long probebuffer(long bloomFilter, long bloomFilterSize, long cv)
+ throws CudfException;
}
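Note: an illustrative sketch of the BloomFilter API touched above (not part of the patch); it assumes the upstream cudf Java factory ColumnVector.fromLongs and try-with-resources cleanup of Scalar and ColumnVector.

    import ai.rapids.cudf.ColumnVector;
    import ai.rapids.cudf.Scalar;
    import com.nvidia.spark.rapids.jni.BloomFilter;

    public class BloomFilterExample {
      public static void main(String[] args) {
        // Build a filter with 3 hashes and 4096 bits, insert some longs, then probe it.
        try (Scalar filter = BloomFilter.create(3, 4 * 1024);
             ColumnVector buildKeys = ColumnVector.fromLongs(1L, 2L, 3L);
             ColumnVector probeKeys = ColumnVector.fromLongs(2L, 42L)) {
          BloomFilter.put(filter, buildKeys);
          try (ColumnVector maybeContained = BloomFilter.probe(filter, probeKeys)) {
            // Row 0 (value 2) is reported as possibly present; row 1 (value 42) is
            // reported absent unless it collides, per the false-positive semantics above.
            System.out.println("probe rows: " + maybeContained.getRowCount());
          }
        }
      }
    }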
diff --git a/src/main/java/com/nvidia/spark/rapids/jni/CaseWhen.java b/src/main/java/com/nvidia/spark/rapids/jni/CaseWhen.java
index dabefd0a98..130c23f5e8 100644
--- a/src/main/java/com/nvidia/spark/rapids/jni/CaseWhen.java
+++ b/src/main/java/com/nvidia/spark/rapids/jni/CaseWhen.java
@@ -16,59 +16,57 @@
package com.nvidia.spark.rapids.jni;
-import ai.rapids.cudf.*;
+import ai.rapids.cudf.ColumnVector;
+import ai.rapids.cudf.DType;
/**
* Execute SQL `case when` semantics.
* If there are multiple branches and each branch uses a scalar to generate its value,
* then it's fast to use this class because it does not generate temp string columns.
- *
+ *
* E.g.:
- * SQL is:
- * select
- * case
- * when bool_1_expr then "value_1"
- * when bool_2_expr then "value_2"
- * when bool_3_expr then "value_3"
- * else "value_else"
- * end
- * from tab
- *
+ * SQL is:
+ * select
+ * case
+ * when bool_1_expr then "value_1"
+ * when bool_2_expr then "value_2"
+ * when bool_3_expr then "value_3"
+ * else "value_else"
+ * end
+ * from tab
+ *
* Execution steps:
- * Execute bool exprs to get bool columns, e.g., gets:
- * bool column 1: [true, false, false, false] // bool_1_expr result
- * bool column 2: [false, true, false, flase] // bool_2_expr result
- * bool column 3: [false, false, true, flase] // bool_3_expr result
- * Execute `selectFirstTrueIndex` to get the column index for the first true in bool columns.
- * Generate a column to store salars: "value_1", "value_2", "value_3", "value_else"
- * Execute `Table.gather` to generate the final output column
- *
+ * Execute bool exprs to get bool columns, e.g., gets:
+ * bool column 1: [true, false, false, false] // bool_1_expr result
+ * bool column 2: [false, true, false, false] // bool_2_expr result
+ * bool column 3: [false, false, true, false] // bool_3_expr result
+ * Execute `selectFirstTrueIndex` to get the column index for the first true in bool columns.
+ * Generate a column to store scalars: "value_1", "value_2", "value_3", "value_else"
+ * Execute `Table.gather` to generate the final output column
*/
public class CaseWhen {
/**
- *
* Select the column index for the first true in bool columns.
* For rows that do not contain true, use the end index (number of columns).
- *
+ *
* e.g.:
- * column 0: true, false, false, false
- * column 1: false, true, false, false
- * column 2: false, false, true, false
- *
- * 1st row is: true, flase, false; first true index is 0
- * 2nd row is: false, true, false; first true index is 1
- * 3rd row is: false, flase, true; first true index is 2
- * 4th row is: false, false, false; do not find true, set index to the end index 3
- *
- * output column: 0, 1, 2, 3
- * In the `case when` context, here 3 index means using NULL value.
- *
- */
+ * column 0: true, false, false, false
+ * column 1: false, true, false, false
+ * column 2: false, false, true, false
+ *
+ * 1st row is: true, false, false; first true index is 0
+ * 2nd row is: false, true, false; first true index is 1
+ * 3rd row is: false, false, true; first true index is 2
+ * 4th row is: false, false, false; do not find true, set index to the end index 3
+ *
+ * output column: 0, 1, 2, 3
+ * In the `case when` context, here 3 index means using NULL value.
+ */
public static ColumnVector selectFirstTrueIndex(ColumnVector[] boolColumns) {
for (ColumnVector cv : boolColumns) {
- assert(cv.getType().equals(DType.BOOL8)) : "Columns must be bools";
+ assert (cv.getType().equals(DType.BOOL8)) : "Columns must be bools";
}
long[] boolHandles = new long[boolColumns.length];
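Note: an illustrative sketch of CaseWhen.selectFirstTrueIndex mirroring the Javadoc example above (not part of the patch); it assumes the upstream cudf Java factory ColumnVector.fromBooleans.

    import ai.rapids.cudf.ColumnVector;
    import com.nvidia.spark.rapids.jni.CaseWhen;

    public class CaseWhenExample {
      public static void main(String[] args) {
        // Three predicate columns for four rows, as in the comment block above.
        try (ColumnVector b0 = ColumnVector.fromBooleans(true, false, false, false);
             ColumnVector b1 = ColumnVector.fromBooleans(false, true, false, false);
             ColumnVector b2 = ColumnVector.fromBooleans(false, false, true, false);
             ColumnVector firstTrue =
                 CaseWhen.selectFirstTrueIndex(new ColumnVector[]{b0, b1, b2})) {
          // Expected indices per row: 0, 1, 2, 3 (3 == no branch matched, i.e. NULL value).
          System.out.println("result rows: " + firstTrue.getRowCount());
        }
      }
    }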
diff --git a/src/main/java/com/nvidia/spark/rapids/jni/CastException.java b/src/main/java/com/nvidia/spark/rapids/jni/CastException.java
index ca6b5a9d56..7000cfb74f 100644
--- a/src/main/java/com/nvidia/spark/rapids/jni/CastException.java
+++ b/src/main/java/com/nvidia/spark/rapids/jni/CastException.java
@@ -13,6 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
+
package com.nvidia.spark.rapids.jni;
/**
@@ -21,9 +22,9 @@
public class CastException extends RuntimeException {
private final int rowWithError;
private final String stringWithError;
-
+
CastException(String stringWithError, int rowWithError) {
- super("Error casting data on row " + String.valueOf(rowWithError) + ": " + stringWithError);
+ super("Error casting data on row " + rowWithError + ": " + stringWithError);
this.rowWithError = rowWithError;
this.stringWithError = stringWithError;
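Note: an illustrative sketch of how CastException could be handled by a caller (not part of the patch). It assumes the native cast surfaces invalid rows as a CastException when ansiMode is enabled; the cudf factory ColumnVector.fromStrings is from the upstream cudf Java API.

    import ai.rapids.cudf.ColumnVector;
    import ai.rapids.cudf.DType;
    import com.nvidia.spark.rapids.jni.CastException;
    import com.nvidia.spark.rapids.jni.CastStrings;

    public class CastExceptionExample {
      public static void main(String[] args) {
        try (ColumnVector strings = ColumnVector.fromStrings("1", "oops", "3")) {
          // With ansiMode == true, invalid rows are reported as errors instead of nulls.
          try (ColumnVector ints = CastStrings.toInteger(strings, true, DType.INT32)) {
            System.out.println("cast succeeded");
          } catch (CastException e) {
            System.out.println("cast failed: " + e.getMessage());
          }
        }
      }
    }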
diff --git a/src/main/java/com/nvidia/spark/rapids/jni/CastStrings.java b/src/main/java/com/nvidia/spark/rapids/jni/CastStrings.java
index 2b2267f034..c08159b17a 100644
--- a/src/main/java/com/nvidia/spark/rapids/jni/CastStrings.java
+++ b/src/main/java/com/nvidia/spark/rapids/jni/CastStrings.java
@@ -16,9 +16,14 @@
package com.nvidia.spark.rapids.jni;
-import ai.rapids.cudf.*;
+import ai.rapids.cudf.ColumnVector;
+import ai.rapids.cudf.ColumnView;
+import ai.rapids.cudf.DType;
+import ai.rapids.cudf.NativeDepsLoader;
-/** Utility class for casting between string columns and native type columns */
+/**
+ * Utility class for casting between string columns and native type columns
+ */
public class CastStrings {
static {
NativeDepsLoader.loadNativeDeps();
@@ -28,9 +33,9 @@ public class CastStrings {
* Convert a string column to an integer column of a specified type stripping away leading and
* trailing spaces.
*
- * @param cv the column data to process.
+ * @param cv the column data to process.
* @param ansiMode true if invalid data are errors, false if they should be nulls.
- * @param type the type of the return column.
+ * @param type the type of the return column.
* @return the converted column.
*/
public static ColumnVector toInteger(ColumnView cv, boolean ansiMode, DType type) {
@@ -40,10 +45,10 @@ public static ColumnVector toInteger(ColumnView cv, boolean ansiMode, DType type
/**
* Convert a string column to an integer column of a specified type.
*
- * @param cv the column data to process.
+ * @param cv the column data to process.
* @param ansiMode true if invalid data are errors, false if they should be nulls.
- * @param strip true if leading and trailing spaces should be ignored when parsing.
- * @param type the type of the return column.
+ * @param strip true if leading and trailing spaces should be ignored when parsing.
+ * @param type the type of the return column.
* @return the converted column.
*/
public static ColumnVector toInteger(ColumnView cv, boolean ansiMode, boolean strip, DType type) {
@@ -55,10 +60,10 @@ public static ColumnVector toInteger(ColumnView cv, boolean ansiMode, boolean st
* Convert a string column to an integer column of a specified type stripping away leading and
* trailing whitespace.
*
- * @param cv the column data to process.
- * @param ansiMode true if invalid data are errors, false if they should be nulls.
+ * @param cv the column data to process.
+ * @param ansiMode true if invalid data are errors, false if they should be nulls.
* @param precision the output precision.
- * @param scale the output scale.
+ * @param scale the output scale.
* @return the converted column.
*/
public static ColumnVector toDecimal(ColumnView cv, boolean ansiMode, int precision, int scale) {
@@ -68,22 +73,22 @@ public static ColumnVector toDecimal(ColumnView cv, boolean ansiMode, int precis
/**
* Convert a string column to an integer column of a specified type.
*
- * @param cv the column data to process.
- * @param ansiMode true if invalid data are errors, false if they should be nulls.
- * @param strip true if leading and trailing white space should be stripped.
+ * @param cv the column data to process.
+ * @param ansiMode true if invalid data are errors, false if they should be nulls.
+ * @param strip true if leading and trailing white space should be stripped.
* @param precision the output precision.
- * @param scale the output scale.
+ * @param scale the output scale.
* @return the converted column.
*/
public static ColumnVector toDecimal(ColumnView cv, boolean ansiMode, boolean strip,
- int precision, int scale) {
+ int precision, int scale) {
return new ColumnVector(toDecimal(cv.getNativeView(), ansiMode, strip, precision, scale));
}
/**
* Convert a float column to a formatted string column.
*
- * @param cv the column data to process
+ * @param cv the column data to process
* @param digits the number of digits to display after the decimal point
* @return the converted column
*/
@@ -114,9 +119,9 @@ public static ColumnVector fromDecimal(ColumnView cv) {
/**
* Convert a string column to a given floating-point type column.
*
- * @param cv the column data to process.
+ * @param cv the column data to process.
* @param ansiMode true if invalid data are errors, false if they should be nulls.
- * @param type the type of the return column.
+ * @param type the type of the return column.
* @return the converted column.
*/
public static ColumnVector toFloat(ColumnView cv, boolean ansiMode, DType type) {
@@ -125,18 +130,18 @@ public static ColumnVector toFloat(ColumnView cv, boolean ansiMode, DType type)
public static ColumnVector toIntegersWithBase(ColumnView cv, int base,
- boolean ansiEnabled, DType type) {
+ boolean ansiEnabled, DType type) {
return new ColumnVector(toIntegersWithBase(cv.getNativeView(), base, ansiEnabled,
- type.getTypeId().getNativeId()));
+ type.getTypeId().getNativeId()));
}
/**
* Converts an integer column to a string column by converting the underlying integers to the
* specified base.
- *
+ *
* Note: Right now we only support base 10 and 16. The hexadecimal values will be
* returned without leading zeros or padding at the end
- *
+ *
* Example:
* input = [123, -1, 0, 27, 342718233]
* s = fromIntegersWithBase(input, 16)
@@ -144,7 +149,7 @@ public static ColumnVector toIntegersWithBase(ColumnView cv, int base,
* s = fromIntegersWithBase(input, 10)
* s is ['123', '-1', '0', '27', '342718233']
*
- * @param cv The input integer column to be converted.
+ * @param cv The input integer column to be converted.
* @param base base that we want to convert to (currently only 10/16)
* @return a new String ColumnVector
*/
@@ -153,14 +158,21 @@ public static ColumnVector fromIntegersWithBase(ColumnView cv, int base) {
}
private static native long toInteger(long nativeColumnView, boolean ansi_enabled, boolean strip,
- int dtype);
+ int dtype);
+
private static native long toDecimal(long nativeColumnView, boolean ansi_enabled, boolean strip,
- int precision, int scale);
+ int precision, int scale);
+
private static native long toFloat(long nativeColumnView, boolean ansi_enabled, int dtype);
+
private static native long fromDecimal(long nativeColumnView);
+
private static native long fromFloatWithFormat(long nativeColumnView, int digits);
+
private static native long fromFloat(long nativeColumnView);
+
private static native long toIntegersWithBase(long nativeColumnView, int base,
- boolean ansiEnabled, int dtype);
+ boolean ansiEnabled, int dtype);
+
private static native long fromIntegersWithBase(long nativeColumnView, int base);
}
\ No newline at end of file
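Note: an illustrative sketch of CastStrings.fromIntegersWithBase following the Javadoc example above (not part of the patch); it assumes the upstream cudf Java factory ColumnVector.fromInts.

    import ai.rapids.cudf.ColumnVector;
    import com.nvidia.spark.rapids.jni.CastStrings;

    public class CastStringsExample {
      public static void main(String[] args) {
        try (ColumnVector input = ColumnVector.fromInts(123, -1, 0, 27, 342718233);
             ColumnVector hex = CastStrings.fromIntegersWithBase(input, 16);
             ColumnVector dec = CastStrings.fromIntegersWithBase(input, 10)) {
          // Per the Javadoc above, base-16 output has no leading zeros (123 -> '7b'),
          // while base-10 simply mirrors the input values as strings.
          System.out.println("rows: " + hex.getRowCount() + ", " + dec.getRowCount());
        }
      }
    }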
diff --git a/src/main/java/com/nvidia/spark/rapids/jni/DateTimeRebase.java b/src/main/java/com/nvidia/spark/rapids/jni/DateTimeRebase.java
index d73ee038d6..724fbc573c 100644
--- a/src/main/java/com/nvidia/spark/rapids/jni/DateTimeRebase.java
+++ b/src/main/java/com/nvidia/spark/rapids/jni/DateTimeRebase.java
@@ -16,7 +16,9 @@
package com.nvidia.spark.rapids.jni;
-import ai.rapids.cudf.*;
+import ai.rapids.cudf.ColumnVector;
+import ai.rapids.cudf.ColumnView;
+import ai.rapids.cudf.NativeDepsLoader;
/**
* Utility class for converting between column major and row major data
diff --git a/src/main/java/com/nvidia/spark/rapids/jni/DecimalUtils.java b/src/main/java/com/nvidia/spark/rapids/jni/DecimalUtils.java
index 167cb099b1..9b25bff223 100644
--- a/src/main/java/com/nvidia/spark/rapids/jni/DecimalUtils.java
+++ b/src/main/java/com/nvidia/spark/rapids/jni/DecimalUtils.java
@@ -31,7 +31,7 @@ public class DecimalUtils {
* Multiply two DECIMAL128 columns together into a DECIMAL128 product rounded to the specified
* scale with overflow detection. This method considers a precision greater than 38 as overflow
* even if the number still fits in a 128-bit representation.
- *
+ *
* WARNING: This method has a bug which we match with Spark versions before 3.4.2,
* 4.0.0, 3.5.1. Consider the following example using Decimal with a precision of 38 and scale of 10:
* -8533444864753048107770677711.1312637916 * -12.0000000000 = 102401338377036577293248132533.575166
@@ -53,23 +53,24 @@ public static Table multiply128(ColumnView a, ColumnView b, int productScale) {
* Multiply two DECIMAL128 columns together into a DECIMAL128 product rounded to the specified
* scale with overflow detection. This method considers a precision greater than 38 as overflow
* even if the number still fits in a 128-bit representation.
- *
+ *
* WARNING: With interimCast set to true, this method has a bug which we match with Spark versions before 3.4.2,
* 4.0.0, 3.5.1. Consider the following example using Decimal with a precision of 38 and scale of 10:
* -8533444864753048107770677711.1312637916 * -12.0000000000 = 102401338377036577293248132533.575166
* while the actual answer based on Java BigDecimal is 102401338377036577293248132533.575165
*
- * @param a factor input, must match row count of the other factor input
- * @param b factor input, must match row count of the other factor input
+ * @param a factor input, must match row count of the other factor input
+ * @param b factor input, must match row count of the other factor input
* @param productScale scale to use for the product type
- * @param interimCast whether to cast the result of the division to 38 precision before casting it again to the final
- * precision
+ * @param interimCast whether to cast the result of the division to 38 precision before casting it again to the final
+ * precision
* @return table containing a boolean column and a DECIMAL128 product column of the specified
- * scale. The boolean value will be true if an overflow was detected for that row's
- * DECIMAL128 product value. A null input row will result in a corresponding null output
- * row.
+ * scale. The boolean value will be true if an overflow was detected for that row's
+ * DECIMAL128 product value. A null input row will result in a corresponding null output
+ * row.
*/
- public static Table multiply128(ColumnView a, ColumnView b, int productScale, boolean interimCast) {
+ public static Table multiply128(ColumnView a, ColumnView b, int productScale,
+ boolean interimCast) {
return new Table(multiply128(a.getNativeView(), b.getNativeView(), productScale, interimCast));
}
@@ -77,13 +78,14 @@ public static Table multiply128(ColumnView a, ColumnView b, int productScale, bo
* Divide two DECIMAL128 columns and produce a DECIMAL128 quotient rounded to the specified
* scale with overflow detection. This method considers a precision greater than 38 as overflow
* even if the number still fits in a 128-bit representation.
- * @param a factor input, must match row count of the other factor input
- * @param b factor input, must match row count of the other factor input
+ *
+ * @param a factor input, must match row count of the other factor input
+ * @param b factor input, must match row count of the other factor input
* @param quotientScale scale to use for the quotient type
* @return table containing a boolean column and a DECIMAL128 quotient column of the specified
- * scale. The boolean value will be true if an overflow was detected for that row's
- * DECIMAL128 quotient value. A null input row will result in a corresponding null output
- * row.
+ * scale. The boolean value will be true if an overflow was detected for that row's
+ * DECIMAL128 quotient value. A null input row will result in a corresponding null output
+ * row.
*/
public static Table divide128(ColumnView a, ColumnView b, int quotientScale) {
return new Table(divide128(a.getNativeView(), b.getNativeView(), quotientScale, false));
@@ -103,9 +105,9 @@ public static Table divide128(ColumnView a, ColumnView b, int quotientScale) {
* @param a factor input, must match row count of the other factor input
* @param b factor input, must match row count of the other factor input
* @return table containing a boolean column and a INT64 quotient column.
- * The boolean value will be true if an overflow was detected for that row's
- * INT64 quotient value. A null input row will result in a corresponding null output
- * row.
+ * The boolean value will be true if an overflow was detected for that row's
+ * INT64 quotient value. A null input row will result in a corresponding null output
+ * row.
*/
public static Table integerDivide128(ColumnView a, ColumnView b) {
return new Table(divide128(a.getNativeView(), b.getNativeView(), 0, true));
@@ -115,17 +117,17 @@ public static Table integerDivide128(ColumnView a, ColumnView b) {
* Divide two DECIMAL128 columns and produce a DECIMAL128 remainder with overflow detection.
* Example:
* 451635271134476686911387864.48 % -961.110 = 775.233
- *
+ *
* Generally, this will never really overflow unless in the divide by zero case.
* But it will detect an overflow in any case.
*
- * @param a factor input, must match row count of the other factor input
- * @param b factor input, must match row count of the other factor input
+ * @param a factor input, must match row count of the other factor input
+ * @param b factor input, must match row count of the other factor input
* @param remainderScale scale to use for the remainder type
* @return table containing a boolean column and a DECIMAL128 remainder column.
- * The boolean value will be true if an overflow was detected for that row's
- * DECIMAL128 remainder value. A null input row will result in a corresponding null
- * output row.
+ * The boolean value will be true if an overflow was detected for that row's
+ * DECIMAL128 remainder value. A null input row will result in a corresponding null
+ * output row.
*/
public static Table remainder128(ColumnView a, ColumnView b, int remainderScale) {
return new Table(remainder128(a.getNativeView(), b.getNativeView(), remainderScale));
@@ -135,17 +137,17 @@ public static Table remainder128(ColumnView a, ColumnView b, int remainderScale)
* Subtract two DECIMAL128 columns and produce a DECIMAL128 result rounded to the specified
* scale with overflow detection. This method considers a precision greater than 38 as overflow
* even if the number still fits in a 128-bit representation.
- *
+ *
* NOTE: This is very specific to Spark 3.4. This method is incompatible with previous versions
* of Spark. We don't need this for versions prior to Spark 3.4
*
- * @param a input, must match row count of the other input
- * @param b input, must match row count of the other input
+ * @param a input, must match row count of the other input
+ * @param b input, must match row count of the other input
* @param targetScale scale to use for the result
* @return table containing a boolean column and a DECIMAL128 result column of the specified
- * scale. The boolean value will be true if an overflow was detected for that row's
- * DECIMAL128 result value. A null input row will result in a corresponding null output
- * row.
+ * scale. The boolean value will be true if an overflow was detected for that row's
+ * DECIMAL128 result value. A null input row will result in a corresponding null output
+ * row.
*/
public static Table subtract128(ColumnView a, ColumnView b, int targetScale) {
@@ -160,17 +162,17 @@ public static Table subtract128(ColumnView a, ColumnView b, int targetScale) {
* Add two DECIMAL128 columns and produce a DECIMAL128 result rounded to the specified
* scale with overflow detection. This method considers a precision greater than 38 as overflow
* even if the number still fits in a 128-bit representation.
- *
+ *
* NOTE: This is very specific to Spark 3.4. This method is incompatible with previous versions
* of Spark. We don't need this for versions prior to Spark 3.4
*
- * @param a input, must match row count of the other input
- * @param b input, must match row count of the other input
+ * @param a input, must match row count of the other input
+ * @param b input, must match row count of the other input
* @param targetScale scale to use for the sum
* @return table containing a boolean column and a DECIMAL128 sum column of the specified
- * scale. The boolean value will be true if an overflow was detected for that row's
- * DECIMAL128 result value. A null input row will result in a corresponding null output
- * row.
+ * scale. The boolean value will be true if an overflow was detected for that row's
+ * DECIMAL128 result value. A null input row will result in a corresponding null output
+ * row.
*/
public static Table add128(ColumnView a, ColumnView b, int targetScale) {
if (java.lang.Math.abs(a.getType().getScale() - b.getType().getScale()) > 77) {
@@ -180,29 +182,13 @@ public static Table add128(ColumnView a, ColumnView b, int targetScale) {
return new Table(add128(a.getNativeView(), b.getNativeView(), targetScale));
}
- /**
- * A class to store the result of a cast operation from floating point values to decimals.
- *
- * Since the result column may or may not be used regardless of the value of hasFailure, we
- * need to keep it and let the caller to decide.
- */
- public static class CastFloatToDecimalResult {
- public final ColumnVector result; // the cast result
- public final boolean hasFailure; // whether the cast operation has failed for any input rows
-
- public CastFloatToDecimalResult(ColumnVector result, boolean hasFailure) {
- this.result = result;
- this.hasFailure = hasFailure;
- }
- }
-
/**
* Cast floating point values to decimals, matching the behavior of Spark.
*
- * @param input The input column, which is either FLOAT32 or FLOAT64
+ * @param input The input column, which is either FLOAT32 or FLOAT64
* @param outputType The output decimal type
* @return The decimal column resulting from the cast operation and a boolean value indicating
- * whether the cast operation has failed for any input rows
+ * whether the cast operation has failed for any input rows
*/
public static CastFloatToDecimalResult floatingPointToDecimal(ColumnView input, DType outputType,
int precision) {
@@ -212,9 +198,11 @@ public static CastFloatToDecimalResult floatingPointToDecimal(ColumnView input,
return new CastFloatToDecimalResult(new ColumnVector(result[0]), result[1] != 0);
}
- private static native long[] multiply128(long viewA, long viewB, int productScale, boolean interimCast);
+ private static native long[] multiply128(long viewA, long viewB, int productScale,
+ boolean interimCast);
- private static native long[] divide128(long viewA, long viewB, int quotientScale, boolean isIntegerDivide);
+ private static native long[] divide128(long viewA, long viewB, int quotientScale,
+ boolean isIntegerDivide);
private static native long[] remainder128(long viewA, long viewB, int remainderScale);
@@ -222,5 +210,22 @@ public static CastFloatToDecimalResult floatingPointToDecimal(ColumnView input,
private static native long[] subtract128(long viewA, long viewB, int targetScale);
- private static native long[] floatingPointToDecimal(long inputHandle, int outputTypeId, int precision, int scale);
+ private static native long[] floatingPointToDecimal(long inputHandle, int outputTypeId,
+ int precision, int scale);
+
+ /**
+ * A class to store the result of a cast operation from floating point values to decimals.
+ *
+ * Since the result column may or may not be used regardless of the value of hasFailure, we
+ * need to keep it and let the caller decide.
+ */
+ public static class CastFloatToDecimalResult {
+ public final ColumnVector result; // the cast result
+ public final boolean hasFailure; // whether the cast operation has failed for any input rows
+
+ public CastFloatToDecimalResult(ColumnVector result, boolean hasFailure) {
+ this.result = result;
+ this.hasFailure = hasFailure;
+ }
+ }
}
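Note: an illustrative sketch of the multiply128 overflow-check pattern documented above (not part of the patch). It assumes the upstream cudf Java factory ColumnVector.decimalFromBigInt, which builds DECIMAL128 columns, and cudf's convention that a negative scale denotes fractional digits.

    import ai.rapids.cudf.ColumnVector;
    import ai.rapids.cudf.HostColumnVector;
    import ai.rapids.cudf.Table;
    import com.nvidia.spark.rapids.jni.DecimalUtils;
    import java.math.BigInteger;

    public class DecimalUtilsExample {
      public static void main(String[] args) {
        // Two DECIMAL128 columns at scale -2 (unscaled values 123 and 456, i.e. 1.23 and 4.56).
        try (ColumnVector a = ColumnVector.decimalFromBigInt(-2, BigInteger.valueOf(123));
             ColumnVector b = ColumnVector.decimalFromBigInt(-2, BigInteger.valueOf(456));
             Table result = DecimalUtils.multiply128(a, b, -4)) {
          // Column 0 flags per-row overflow, column 1 holds the DECIMAL128 product.
          try (HostColumnVector overflow = result.getColumn(0).copyToHost()) {
            System.out.println("row 0 overflowed: " + overflow.getBoolean(0));
          }
        }
      }
    }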
diff --git a/src/main/java/com/nvidia/spark/rapids/jni/GpuSubstringIndexUtils.java b/src/main/java/com/nvidia/spark/rapids/jni/GpuSubstringIndexUtils.java
index a8750919c9..4afaeedf06 100644
--- a/src/main/java/com/nvidia/spark/rapids/jni/GpuSubstringIndexUtils.java
+++ b/src/main/java/com/nvidia/spark/rapids/jni/GpuSubstringIndexUtils.java
@@ -16,16 +16,23 @@
package com.nvidia.spark.rapids.jni;
-import ai.rapids.cudf.*;
+import ai.rapids.cudf.ColumnVector;
+import ai.rapids.cudf.ColumnView;
+import ai.rapids.cudf.CudfAccessor;
+import ai.rapids.cudf.CudfException;
+import ai.rapids.cudf.NativeDepsLoader;
+import ai.rapids.cudf.Scalar;
public class GpuSubstringIndexUtils {
- static{
- NativeDepsLoader.loadNativeDeps();
- }
+ static {
+ NativeDepsLoader.loadNativeDeps();
+ }
- public static ColumnVector substringIndex(ColumnView cv, Scalar delimiter, int count){
- return new ColumnVector(substringIndex(cv.getNativeView(), CudfAccessor.getScalarHandle(delimiter), count));
- }
+ public static ColumnVector substringIndex(ColumnView cv, Scalar delimiter, int count) {
+ return new ColumnVector(
+ substringIndex(cv.getNativeView(), CudfAccessor.getScalarHandle(delimiter), count));
+ }
- private static native long substringIndex(long columnView, long delimiter, int count) throws CudfException;
+ private static native long substringIndex(long columnView, long delimiter, int count)
+ throws CudfException;
}
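Note: an illustrative sketch of GpuSubstringIndexUtils.substringIndex (not part of the patch); it assumes the upstream cudf Java factories ColumnVector.fromStrings and Scalar.fromString.

    import ai.rapids.cudf.ColumnVector;
    import ai.rapids.cudf.Scalar;
    import com.nvidia.spark.rapids.jni.GpuSubstringIndexUtils;

    public class SubstringIndexExample {
      public static void main(String[] args) {
        // Spark semantics: substring_index('www.apache.org', '.', 2) == 'www.apache'.
        try (ColumnVector urls = ColumnVector.fromStrings("www.apache.org", "a.b.c");
             Scalar dot = Scalar.fromString(".");
             ColumnVector result = GpuSubstringIndexUtils.substringIndex(urls, dot, 2)) {
          System.out.println("rows: " + result.getRowCount());
        }
      }
    }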
diff --git a/src/main/java/com/nvidia/spark/rapids/jni/GpuTimeZoneDB.java b/src/main/java/com/nvidia/spark/rapids/jni/GpuTimeZoneDB.java
index a8048b1e8b..3ad54b996d 100644
--- a/src/main/java/com/nvidia/spark/rapids/jni/GpuTimeZoneDB.java
+++ b/src/main/java/com/nvidia/spark/rapids/jni/GpuTimeZoneDB.java
@@ -1,18 +1,18 @@
/*
-* Copyright (c) 2023-2024, NVIDIA CORPORATION.
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*/
+ * Copyright (c) 2023-2024, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
package com.nvidia.spark.rapids.jni;
@@ -20,9 +20,6 @@
import ai.rapids.cudf.DType;
import ai.rapids.cudf.HostColumnVector;
import ai.rapids.cudf.Table;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
import java.time.Instant;
import java.time.ZoneId;
import java.time.zone.ZoneOffsetTransition;
@@ -34,17 +31,19 @@
import java.util.Map;
import java.util.TimeZone;
import java.util.concurrent.Executors;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
/**
* Gpu time zone utility.
* Provides two kinds of APIs
- * - Time zone transitions cache APIs
- * `cacheDatabaseAsync`, `cacheDatabase` and `shutdown` are synchronized.
- * When cacheDatabaseAsync is running, the `shutdown` and `cacheDatabase` will wait;
- * These APIs guarantee only one thread is loading transitions cache,
- * And guarantee loading cache only occurs one time.
- * - Rebase time zone APIs
- * fromTimestampToUtcTimestamp, fromUtcTimestampToTimestamp ...
+ * - Time zone transitions cache APIs
+ * `cacheDatabaseAsync`, `cacheDatabase` and `shutdown` are synchronized.
+ * When cacheDatabaseAsync is running, the `shutdown` and `cacheDatabase` will wait;
+ * These APIs guarantee only one thread is loading transitions cache,
+ * And guarantee loading cache only occurs one time.
+ * - Rebase time zone APIs
+ * fromTimestampToUtcTimestamp, fromUtcTimestampToTimestamp ...
*/
public class GpuTimeZoneDB {
private static final Logger log = LoggerFactory.getLogger(GpuTimeZoneDB.class);
@@ -104,7 +103,7 @@ private static synchronized void cacheDatabaseImpl() {
}
}
- private static synchronized void closeResources() {
+ private static synchronized void closeResources() {
if (zoneIdToTable != null) {
zoneIdToTable.clear();
zoneIdToTable = null;
@@ -115,7 +114,8 @@ private static synchronized void closeResources() {
}
}
- public static ColumnVector fromTimestampToUtcTimestamp(ColumnVector input, ZoneId currentTimeZone) {
+ public static ColumnVector fromTimestampToUtcTimestamp(ColumnVector input,
+ ZoneId currentTimeZone) {
// TODO: Remove this check when all timezones are supported
// (See https://github.com/NVIDIA/spark-rapids/issues/6840)
if (!isSupportedTimeZone(currentTimeZone)) {
@@ -132,8 +132,9 @@ public static ColumnVector fromTimestampToUtcTimestamp(ColumnVector input, ZoneI
transitions.getNativeView(), tzIndex));
}
}
-
- public static ColumnVector fromUtcTimestampToTimestamp(ColumnVector input, ZoneId desiredTimeZone) {
+
+ public static ColumnVector fromUtcTimestampToTimestamp(ColumnVector input,
+ ZoneId desiredTimeZone) {
// TODO: Remove this check when all timezones are supported
// (See https://github.com/NVIDIA/spark-rapids/issues/6840)
if (!isSupportedTimeZone(desiredTimeZone)) {
@@ -150,13 +151,13 @@ public static ColumnVector fromUtcTimestampToTimestamp(ColumnVector input, ZoneI
transitions.getNativeView(), tzIndex));
}
}
-
+
// TODO: Deprecate this API when we support all timezones
// (See https://github.com/NVIDIA/spark-rapids/issues/6840)
public static boolean isSupportedTimeZone(ZoneId desiredTimeZone) {
return desiredTimeZone != null &&
- (desiredTimeZone.getRules().isFixedOffset() ||
- desiredTimeZone.getRules().getTransitionRules().isEmpty());
+ (desiredTimeZone.getRules().isFixedOffset() ||
+ desiredTimeZone.getRules().getTransitionRules().isEmpty());
}
public static boolean isSupportedTimeZone(String zoneId) {
@@ -170,10 +171,10 @@ public static boolean isSupportedTimeZone(String zoneId) {
// Ported from Spark. Used to format time zone ID string with (+|-)h:mm and (+|-)hh:m
public static ZoneId getZoneId(String timeZoneId) {
String formattedZoneId = timeZoneId
- // To support the (+|-)h:mm format because it was supported before Spark 3.0.
- .replaceFirst("(\\+|\\-)(\\d):", "$10$2:")
- // To support the (+|-)hh:m format because it was supported before Spark 3.0.
- .replaceFirst("(\\+|\\-)(\\d\\d):(\\d)$", "$1$2:0$3");
+ // To support the (+|-)h:mm format because it was supported before Spark 3.0.
+ .replaceFirst("(\\+|\\-)(\\d):", "$10$2:")
+ // To support the (+|-)hh:m format because it was supported before Spark 3.0.
+ .replaceFirst("(\\+|\\-)(\\d\\d):(\\d)$", "$1$2:0$3");
return ZoneId.of(formattedZoneId, ZoneId.SHORT_IDS);
}
@@ -266,12 +267,13 @@ private static synchronized ColumnVector getFixedTransitions() {
/**
* FOR TESTING PURPOSES ONLY, DO NOT USE IN PRODUCTION
- *
- * This method retrieves the raw list of struct data that forms the list of
- * fixed transitions for a particular zoneId.
- *
+ *
+ * This method retrieves the raw list of struct data that forms the list of
+ * fixed transitions for a particular zoneId.
+ *
* It has default visibility so the test can access it.
- * @param zoneId
+ *
+ * @param zoneId Zone id
* @return list of fixed transitions
*/
static synchronized List getHostFixedTransitions(String zoneId) {
@@ -285,5 +287,6 @@ static synchronized List getHostFixedTransitions(String zoneId) {
private static native long convertTimestampColumnToUTC(long input, long transitions, int tzIndex);
- private static native long convertUTCTimestampColumnToTimeZone(long input, long transitions, int tzIndex);
+ private static native long convertUTCTimestampColumnToTimeZone(long input, long transitions,
+ int tzIndex);
}
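Note: an illustrative sketch of the host-side GpuTimeZoneDB helpers shown above (not part of the patch); it only exercises the time-zone parsing and support checks, which need no GPU.

    import com.nvidia.spark.rapids.jni.GpuTimeZoneDB;
    import java.time.ZoneId;

    public class TimeZoneExample {
      public static void main(String[] args) {
        // getZoneId normalizes Spark's legacy (+|-)h:mm and (+|-)hh:m forms before parsing.
        ZoneId offset = GpuTimeZoneDB.getZoneId("+8:00"); // normalized to +08:00
        System.out.println(offset);

        // Only fixed-offset zones and zones without transition rules are supported for now.
        System.out.println(GpuTimeZoneDB.isSupportedTimeZone("UTC"));                 // true
        System.out.println(GpuTimeZoneDB.isSupportedTimeZone("America/Los_Angeles")); // false
      }
    }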
diff --git a/src/main/java/com/nvidia/spark/rapids/jni/Hash.java b/src/main/java/com/nvidia/spark/rapids/jni/Hash.java
index a25fead0fd..f7d4223697 100644
--- a/src/main/java/com/nvidia/spark/rapids/jni/Hash.java
+++ b/src/main/java/com/nvidia/spark/rapids/jni/Hash.java
@@ -33,68 +33,71 @@ public class Hash {
* Create a new vector containing spark's 32-bit murmur3 hash of each row in the table.
* Spark's murmur3 hash uses a different tail processing algorithm.
*
- * @param seed integer seed for the murmur3 hash function
+ * @param seed integer seed for the murmur3 hash function
* @param columns array of columns to hash, must have identical number of rows.
* @return the new ColumnVector of 32-bit values representing each row's hash value.
*/
- public static ColumnVector murmurHash32(int seed, ColumnView columns[]) {
+ public static ColumnVector murmurHash32(int seed, ColumnView[] columns) {
if (columns.length < 1) {
throw new IllegalArgumentException("Murmur3 hashing requires at least 1 column of input");
}
long[] columnViews = new long[columns.length];
long size = columns[0].getRowCount();
- for(int i = 0; i < columns.length; i++) {
+ for (int i = 0; i < columns.length; i++) {
assert columns[i] != null : "Column vectors passed may not be null";
- assert columns[i].getRowCount() == size : "Row count mismatch, all columns must be the same size";
+ assert columns[i].getRowCount() == size :
+ "Row count mismatch, all columns must be the same size";
assert !columns[i].getType().isDurationType() : "Unsupported column type Duration";
- columnViews[i] = columns[i].getNativeView();
+ columnViews[i] = columns[i].getNativeView();
}
return new ColumnVector(murmurHash32(seed, columnViews));
}
- public static ColumnVector murmurHash32(ColumnView columns[]) {
+ public static ColumnVector murmurHash32(ColumnView[] columns) {
return murmurHash32(0, columns);
}
/**
* Create a new vector containing the xxhash64 hash of each row in the table.
*
- * @param seed integer seed for the xxhash64 hash function
+ * @param seed integer seed for the xxhash64 hash function
* @param columns array of columns to hash, must have identical number of rows.
* @return the new ColumnVector of 64-bit values representing each row's hash value.
*/
- public static ColumnVector xxhash64(long seed, ColumnView columns[]) {
+ public static ColumnVector xxhash64(long seed, ColumnView[] columns) {
if (columns.length < 1) {
throw new IllegalArgumentException("xxhash64 hashing requires at least 1 column of input");
}
long[] columnViews = new long[columns.length];
long size = columns[0].getRowCount();
- for(int i = 0; i < columns.length; i++) {
+ for (int i = 0; i < columns.length; i++) {
assert columns[i] != null : "Column vectors passed may not be null";
- assert columns[i].getRowCount() == size : "Row count mismatch, all columns must be the same size";
+ assert columns[i].getRowCount() == size :
+ "Row count mismatch, all columns must be the same size";
assert !columns[i].getType().isDurationType() : "Unsupported column type Duration";
assert !columns[i].getType().isNestedType() : "Unsupported column type Nested";
- columnViews[i] = columns[i].getNativeView();
+ columnViews[i] = columns[i].getNativeView();
}
return new ColumnVector(xxhash64(seed, columnViews));
}
- public static ColumnVector xxhash64(ColumnView columns[]) {
+ public static ColumnVector xxhash64(ColumnView[] columns) {
return xxhash64(DEFAULT_XXHASH64_SEED, columns);
}
- public static ColumnVector hiveHash(ColumnView columns[]) {
+ public static ColumnVector hiveHash(ColumnView[] columns) {
if (columns.length < 1) {
throw new IllegalArgumentException("Hive hashing requires at least 1 column of input");
}
long[] columnViews = new long[columns.length];
long size = columns[0].getRowCount();
- for(int i = 0; i < columns.length; i++) {
+ for (int i = 0; i < columns.length; i++) {
assert columns[i] != null : "Column vectors passed may not be null";
- assert columns[i].getRowCount() == size : "Row count mismatch, all columns must be the same size";
+ assert columns[i].getRowCount() == size :
+ "Row count mismatch, all columns must be the same size";
assert !columns[i].getType().isDurationType() : "Unsupported column type Duration";
assert !columns[i].getType().isNestedType() : "Unsupported column type Nested";
columnViews[i] = columns[i].getNativeView();
@@ -103,7 +106,7 @@ public static ColumnVector hiveHash(ColumnView columns[]) {
}
private static native long murmurHash32(int seed, long[] viewHandles) throws CudfException;
-
+
private static native long xxhash64(long seed, long[] viewHandles) throws CudfException;
private static native long hiveHash(long[] viewHandles) throws CudfException;
diff --git a/src/main/java/com/nvidia/spark/rapids/jni/HostTable.java b/src/main/java/com/nvidia/spark/rapids/jni/HostTable.java
index 754412d727..5e268a70cf 100644
--- a/src/main/java/com/nvidia/spark/rapids/jni/HostTable.java
+++ b/src/main/java/com/nvidia/spark/rapids/jni/HostTable.java
@@ -36,10 +36,16 @@ public class HostTable implements AutoCloseable {
private long nativeTableView;
private HostMemoryBuffer hostBuffer;
+ private HostTable(long tableHandle, HostMemoryBuffer hostBuffer) {
+ this.nativeTableView = tableHandle;
+ this.hostBuffer = hostBuffer;
+ }
+
/**
* Copies a device table to a host table asynchronously.
* NOTE: The caller must synchronize on the stream before examining the data on the host.
- * @param table device table to copy
+ *
+ * @param table device table to copy
* @param stream stream to use for the copy
* @return host table
*/
@@ -63,7 +69,8 @@ public static HostTable fromTableAsync(Table table, Cuda.Stream stream) {
/**
* Copies a device table to a host table synchronously.
- * @param table device table to copy
+ *
+ * @param table device table to copy
* @param stream stream to use for the copy
* @return host table
*/
@@ -75,6 +82,7 @@ public static HostTable fromTable(Table table, Cuda.Stream stream) {
/**
* Copies a device table to a host table synchronously on the default stream.
+ *
* @param table device table to copy
* @return host table
*/
@@ -82,10 +90,16 @@ public static HostTable fromTable(Table table) {
return fromTable(table, Cuda.DEFAULT_STREAM);
}
- private HostTable(long tableHandle, HostMemoryBuffer hostBuffer) {
- this.nativeTableView = tableHandle;
- this.hostBuffer = hostBuffer;
- }
+ private static native long bufferSize(long tableHandle, long stream);
+
+ private static native long copyFromTableAsync(long tableHandle, long hostAddress, long hostSize,
+ long stream);
+
+ private static native long[] toDeviceColumnViews(long tableHandle, long hostToDevPtrOffset);
+
+ private static native void freeDeviceColumnView(long columnHandle);
+
+ private static native void freeHostTable(long tableHandle);
/**
* Gets the address of the host_table_view for this host table.
@@ -106,6 +120,7 @@ public HostMemoryBuffer getHostBuffer() {
* Copies the host table to a device table asynchronously.
* NOTE: The caller must synchronize on the stream before closing this instance,
* or the copy could still be in-flight when the host memory is invalidated or reused.
+ *
* @param stream stream to use for the copy
* @return device table
*/
@@ -120,7 +135,8 @@ public Table toTableAsync(Cuda.Stream stream) {
boolean done = false;
try {
for (int i = 0; i < columnViewHandles.length; i++) {
- columns[i] = ColumnVector.fromViewWithContiguousAllocation(columnViewHandles[i], devBuffer);
+ columns[i] =
+ ColumnVector.fromViewWithContiguousAllocation(columnViewHandles[i], devBuffer);
columnViewHandles[i] = 0;
}
table = new Table(columns);
@@ -149,6 +165,7 @@ public Table toTableAsync(Cuda.Stream stream) {
/**
* Copies the host table to a device table synchronously.
+ *
* @param stream stream to use for the copy
* @return device table
*/
@@ -160,6 +177,7 @@ public Table toTable(Cuda.Stream stream) {
/**
* Copies the host table to a device table synchronously on the default stream.
+ *
* @return device table
*/
public Table toTable() {
@@ -176,15 +194,4 @@ public void close() {
hostBuffer = null;
}
}
-
- private static native long bufferSize(long tableHandle, long stream);
-
- private static native long copyFromTableAsync(long tableHandle, long hostAddress, long hostSize,
- long stream);
-
- private static native long[] toDeviceColumnViews(long tableHandle, long hostToDevPtrOffset);
-
- private static native void freeDeviceColumnView(long columnHandle);
-
- private static native void freeHostTable(long tableHandle);
}
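Note: an illustrative round-trip sketch of the HostTable API touched above (not part of the patch); it uses the synchronous fromTable/toTable variants on the default stream and assumes the upstream cudf Table(ColumnVector...) constructor.

    import ai.rapids.cudf.ColumnVector;
    import ai.rapids.cudf.Table;
    import com.nvidia.spark.rapids.jni.HostTable;

    public class HostTableExample {
      public static void main(String[] args) {
        try (ColumnVector col = ColumnVector.fromInts(1, 2, 3);
             Table deviceTable = new Table(col);
             // Stage the table in host memory (synchronous copy on the default stream).
             HostTable hostTable = HostTable.fromTable(deviceTable);
             // Copy it back to the device.
             Table roundTripped = hostTable.toTable()) {
          System.out.println("rows: " + roundTripped.getRowCount());
        }
      }
    }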
diff --git a/src/main/java/com/nvidia/spark/rapids/jni/JSONUtils.java b/src/main/java/com/nvidia/spark/rapids/jni/JSONUtils.java
index c58caa62e9..dab9476357 100644
--- a/src/main/java/com/nvidia/spark/rapids/jni/JSONUtils.java
+++ b/src/main/java/com/nvidia/spark/rapids/jni/JSONUtils.java
@@ -16,52 +16,30 @@
package com.nvidia.spark.rapids.jni;
-import ai.rapids.cudf.*;
-
+import ai.rapids.cudf.ColumnVector;
+import ai.rapids.cudf.ColumnView;
+import ai.rapids.cudf.DType;
+import ai.rapids.cudf.DeviceMemoryBuffer;
+import ai.rapids.cudf.JSONOptions;
+import ai.rapids.cudf.NativeDepsLoader;
import java.util.List;
public class JSONUtils {
- static {
- NativeDepsLoader.loadNativeDeps();
- }
-
public static final int MAX_PATH_DEPTH = getMaxJSONPathDepth();
- public enum PathInstructionType {
- WILDCARD,
- INDEX,
- NAMED
- }
-
- public static class PathInstructionJni {
- // type: byte, name: String, index: int
- private final byte type;
- private final String name;
- private final int index;
-
- public PathInstructionJni(PathInstructionType type, String name, long index) {
- this.type = (byte) type.ordinal();
- this.name = name;
- if (index > Integer.MAX_VALUE) {
- throw new IllegalArgumentException("index is too large " + index);
- }
- this.index = (int) index;
- }
-
- public PathInstructionJni(PathInstructionType type, String name, int index) {
- this.type = (byte) type.ordinal();
- this.name = name;
- this.index = index;
- }
+ static {
+ NativeDepsLoader.loadNativeDeps();
}
/**
* Extract a JSON path from a JSON column. The path is processed in a Spark compatible way.
- * @param input the string column containing JSON
+ *
+ * @param input the string column containing JSON
* @param pathInstructions the instructions for the path processing
* @return the result of processing the path
*/
- public static ColumnVector getJsonObject(ColumnVector input, PathInstructionJni[] pathInstructions) {
+ public static ColumnVector getJsonObject(ColumnVector input,
+ PathInstructionJni[] pathInstructions) {
assert (input.getType().equals(DType.STRING)) : "Input must be of STRING type";
int numTotalInstructions = pathInstructions.length;
byte[] typeNums = new byte[numTotalInstructions];
@@ -80,6 +58,7 @@ public static ColumnVector getJsonObject(ColumnVector input, PathInstructionJni[
/**
* Extract multiple JSON paths from a JSON column. The paths are processed in a Spark
* compatible way.
+ *
* @param input the string column containing JSON
* @param paths the instructions for multiple paths
* @return the result of processing each path in the order that they were passed in
@@ -92,15 +71,16 @@ public static ColumnVector[] getJsonObjectMultiplePaths(ColumnVector input,
/**
* Extract multiple JSON paths from a JSON column. The paths are processed in a Spark
* compatible way.
- * @param input the string column containing JSON
- * @param paths the instructions for multiple paths
+ *
+ * @param input the string column containing JSON
+ * @param paths the instructions for multiple paths
* @param memoryBudgetBytes a budget that is used to limit the amount of memory
* that is used when processing the paths. This is a soft limit.
* A value <= 0 disables this and all paths will be processed in parallel.
- * @param parallelOverride Set a maximum number of paths to be processed in parallel. The memory
- * budget can limit how many paths can be processed in parallel. This overrides
- * that automatically calculated value with a set value for benchmarking purposes.
- * A value <= 0 disables this.
+ * @param parallelOverride Set a maximum number of paths to be processed in parallel. The memory
+ * budget can limit how many paths can be processed in parallel. This overrides
+ * that automatically calculated value with a set value for benchmarking purposes.
+ * A value <= 0 disables this.
* @return the result of processing each path in the order that they were passed in
*/
public static ColumnVector[] getJsonObjectMultiplePaths(ColumnVector input,
@@ -137,7 +117,6 @@ public static ColumnVector[] getJsonObjectMultiplePaths(ColumnVector input,
return ret;
}
-
/**
* Extract key-value pairs for each output map from the given json strings. These key-value are
* copied directly as substrings of the input without any type conversion.
@@ -152,9 +131,9 @@ public static ColumnVector[] getJsonObjectMultiplePaths(ColumnVector input,
* function will just simply copy the input value strings to the output.
*
* @param input The input strings column in which each row specifies a json object
- * @param opts The options for parsing JSON strings
+ * @param opts The options for parsing JSON strings
* @return A map column (i.e., a column of type {@code List<Struct<String, String>>}) in
- * which the key-value pairs are extracted directly from the input json strings
+ * which the key-value pairs are extracted directly from the input json strings
*/
public static ColumnVector extractRawMapFromJsonString(ColumnView input, JSONOptions opts) {
assert (input.getType().equals(DType.STRING)) : "Input must be of STRING type";
@@ -169,12 +148,11 @@ public static ColumnVector extractRawMapFromJsonString(ColumnView input, JSONOpt
* Extract key-value pairs for each output map from the given json strings. This method is
* similar to {@link #extractRawMapFromJsonString(ColumnView, JSONOptions)} but is deprecated.
*
- * @deprecated This method is deprecated since it does not have parameters to control various
- * JSON reader behaviors.
- *
* @param input The input strings column in which each row specifies a json object
* @return A map column (i.e., a column of type {@code List<Struct<String, String>>}) in
- * which the key-value pairs are extracted directly from the input json strings
+ * which the key-value pairs are extracted directly from the input json strings
+ * @deprecated This method is deprecated since it does not have parameters to control various
+ * JSON reader behaviors.
*/
public static ColumnVector extractRawMapFromJsonString(ColumnView input) {
assert (input.getType().equals(DType.STRING)) : "Input must be of STRING type";
@@ -182,30 +160,6 @@ public static ColumnVector extractRawMapFromJsonString(ColumnView input) {
true, true, true, true));
}
- /**
- * A class to hold the result when concatenating JSON strings.
- *
- * A long with the concatenated data, the result also contains a vector that indicates
- * whether each row in the input is null or empty, and the delimiter used for concatenation.
- */
- public static class ConcatenatedJson implements AutoCloseable {
- public final ColumnVector isNullOrEmpty;
- public final DeviceMemoryBuffer data;
- public final char delimiter;
-
- public ConcatenatedJson(ColumnVector isNullOrEmpty, DeviceMemoryBuffer data, char delimiter) {
- this.isNullOrEmpty = isNullOrEmpty;
- this.data = data;
- this.delimiter = delimiter;
- }
-
- @Override
- public void close() {
- isNullOrEmpty.close();
- data.close();
- }
- }
-
/**
* Concatenate JSON strings in the input column into a single JSON string.
*
@@ -231,7 +185,7 @@ public static ConcatenatedJson concatenateJsonStrings(ColumnView input) {
* by the input isNull column.
*
* @param children The children columns of the output structs column
- * @param isNull A boolean column specifying the rows at which the output column should be null
+ * @param isNull A boolean column specifying the rows at which the output column should be null
* @return A structs column created from the given children and the isNull column
*/
public static ColumnVector makeStructs(ColumnView[] children, ColumnView isNull) {
@@ -257,7 +211,6 @@ private static native long[] getJsonObjectMultiplePaths(long input,
long memoryBudgetBytes,
int parallelOverride);
-
private static native long extractRawMapFromJsonString(long input,
boolean normalizeSingleQuotes,
boolean leadingZerosAllowed,
@@ -267,4 +220,57 @@ private static native long extractRawMapFromJsonString(long input,
private static native long[] concatenateJsonStrings(long input);
private static native long makeStructs(long[] children, long isNull);
+
+
+ public enum PathInstructionType {
+ WILDCARD,
+ INDEX,
+ NAMED
+ }
+
+ public static class PathInstructionJni {
+ // type: byte, name: String, index: int
+ private final byte type;
+ private final String name;
+ private final int index;
+
+ public PathInstructionJni(PathInstructionType type, String name, long index) {
+ this.type = (byte) type.ordinal();
+ this.name = name;
+ if (index > Integer.MAX_VALUE) {
+ throw new IllegalArgumentException("index is too large " + index);
+ }
+ this.index = (int) index;
+ }
+
+ public PathInstructionJni(PathInstructionType type, String name, int index) {
+ this.type = (byte) type.ordinal();
+ this.name = name;
+ this.index = index;
+ }
+ }
+
+ /**
+ * A class to hold the result when concatenating JSON strings.
+ *
+ * Along with the concatenated data, the result also contains a vector that indicates
+ * whether each row in the input is null or empty, and the delimiter used for concatenation.
+ */
+ public static class ConcatenatedJson implements AutoCloseable {
+ public final ColumnVector isNullOrEmpty;
+ public final DeviceMemoryBuffer data;
+ public final char delimiter;
+
+ public ConcatenatedJson(ColumnVector isNullOrEmpty, DeviceMemoryBuffer data, char delimiter) {
+ this.isNullOrEmpty = isNullOrEmpty;
+ this.data = data;
+ this.delimiter = delimiter;
+ }
+
+ @Override
+ public void close() {
+ isNullOrEmpty.close();
+ data.close();
+ }
+ }
}
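
For reference, a minimal sketch of how the path API above is meant to be driven. It assumes the enclosing class is JSONUtils, that native dependencies load on first use, and that a NAMED instruction ignores its index argument (passed here as -1); only the getJsonObject signature and the PathInstructionJni constructor come from the code above.

import ai.rapids.cudf.ColumnVector;
import com.nvidia.spark.rapids.jni.JSONUtils;
import com.nvidia.spark.rapids.jni.JSONUtils.PathInstructionJni;
import com.nvidia.spark.rapids.jni.JSONUtils.PathInstructionType;

public class GetJsonObjectSketch {
  public static void main(String[] args) {
    // Roughly the plan Spark would build for get_json_object(col, '$.a.b').
    PathInstructionJni[] path = new PathInstructionJni[] {
        new PathInstructionJni(PathInstructionType.NAMED, "a", -1), // -1: index unused for NAMED (assumed)
        new PathInstructionJni(PathInstructionType.NAMED, "b", -1)
    };
    try (ColumnVector json = ColumnVector.fromStrings("{\"a\": {\"b\": \"x\"}}", null);
         ColumnVector out = JSONUtils.getJsonObject(json, path)) {
      // out is a STRING column with the extracted value per row (null where extraction fails)
    }
  }
}
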
diff --git a/src/main/java/com/nvidia/spark/rapids/jni/Pair.java b/src/main/java/com/nvidia/spark/rapids/jni/Pair.java
index ac8aa1910c..8a5137ea72 100644
--- a/src/main/java/com/nvidia/spark/rapids/jni/Pair.java
+++ b/src/main/java/com/nvidia/spark/rapids/jni/Pair.java
@@ -20,23 +20,23 @@
* A utility class for holding a pair of values.
*/
public class Pair<K, V> {
- private final K left;
- private final V right;
+ private final K left;
+ private final V right;
- public Pair(K left, V right) {
- this.left = left;
- this.right = right;
- }
+ public Pair(K left, V right) {
+ this.left = left;
+ this.right = right;
+ }
- public K getLeft() {
- return left;
- }
+ public static <K, V> Pair<K, V> of(K left, V right) {
+ return new Pair<>(left, right);
+ }
- public V getRight() {
- return right;
- }
+ public K getLeft() {
+ return left;
+ }
- public static <K, V> Pair<K, V> of(K left, V right) {
- return new Pair<>(left, right);
- }
+ public V getRight() {
+ return right;
+ }
}
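
The reordering above does not change behavior; a trivial usage reminder, using only the members declared in this class:

import com.nvidia.spark.rapids.jni.Pair;

public class PairSketch {
  public static void main(String[] args) {
    Pair<String, Long> rowCount = Pair.of("rows", 42L);
    // getLeft()/getRight() simply return the stored values.
    System.out.println(rowCount.getLeft() + " = " + rowCount.getRight());
  }
}
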
diff --git a/src/main/java/com/nvidia/spark/rapids/jni/ParquetFooter.java b/src/main/java/com/nvidia/spark/rapids/jni/ParquetFooter.java
index 681a01d81d..4f6f189bff 100644
--- a/src/main/java/com/nvidia/spark/rapids/jni/ParquetFooter.java
+++ b/src/main/java/com/nvidia/spark/rapids/jni/ParquetFooter.java
@@ -16,8 +16,11 @@
package com.nvidia.spark.rapids.jni;
-import ai.rapids.cudf.*;
-
+import ai.rapids.cudf.CudfException;
+import ai.rapids.cudf.DefaultHostMemoryAllocator;
+import ai.rapids.cudf.HostMemoryAllocator;
+import ai.rapids.cudf.HostMemoryBuffer;
+import ai.rapids.cudf.NativeDepsLoader;
import java.util.ArrayList;
import java.util.Locale;
@@ -29,116 +32,19 @@ public class ParquetFooter implements AutoCloseable {
NativeDepsLoader.loadNativeDeps();
}
- /**
- * Base element for all types in a parquet schema.
- */
- public static abstract class SchemaElement {}
-
- private static class ElementWithName {
- final String name;
- final SchemaElement element;
-
- public ElementWithName(String name, SchemaElement element) {
- this.name = name;
- this.element = element;
- }
- }
-
- public static class StructElement extends SchemaElement {
- public static StructBuilder builder() {
- return new StructBuilder();
- }
-
- private final ElementWithName[] children;
- private StructElement(ElementWithName[] children) {
- this.children = children;
- }
- }
-
- public static class StructBuilder {
- ArrayList<ElementWithName> children = new ArrayList<>();
-
- StructBuilder() {
- // Empty
- }
-
- public StructBuilder addChild(String name, SchemaElement child) {
- children.add(new ElementWithName(name, child));
- return this;
- }
-
- public StructElement build() {
- return new StructElement(children.toArray(new ElementWithName[0]));
- }
- }
-
- public static class ValueElement extends SchemaElement {
- public ValueElement() {}
- }
-
- public static class ListElement extends SchemaElement {
- private final SchemaElement item;
- public ListElement(SchemaElement item) {
- this.item = item;
- }
- }
-
- public static class MapElement extends SchemaElement {
- private final SchemaElement key;
- private final SchemaElement value;
- public MapElement(SchemaElement key, SchemaElement value) {
- this.key = key;
- this.value = value;
- }
- }
-
private long nativeHandle;
private ParquetFooter(long handle) {
nativeHandle = handle;
}
- /**
- * Write the filtered footer back out in a format that is compatible with a parquet
- * footer file. This will include the MAGIC PAR1 at the beginning and end and also the
- * length of the footer just before the PAR1 at the end.
- */
- public HostMemoryBuffer serializeThriftFile(HostMemoryAllocator hostMemoryAllocator) {
- return serializeThriftFile(nativeHandle, hostMemoryAllocator);
- }
-
- public HostMemoryBuffer serializeThriftFile() {
- return serializeThriftFile(DefaultHostMemoryAllocator.get());
- }
-
- /**
- * Get the number of rows in the footer after filtering.
- */
- public long getNumRows() {
- return getNumRows(nativeHandle);
- }
-
- /**
- * Get the number of top level columns in the footer after filtering.
- */
- public int getNumColumns() {
- return getNumColumns(nativeHandle);
- }
-
- @Override
- public void close() throws Exception {
- if (nativeHandle != 0) {
- close(nativeHandle);
- nativeHandle = 0;
- }
- }
-
/**
* Recursive helper function to flatten a SchemaElement, so it can more efficiently be passed
* through JNI.
*/
private static void depthFirstNamesHelper(SchemaElement se, String name, boolean makeLowerCase,
- ArrayList<String> names, ArrayList<Integer> numChildren, ArrayList<Integer> tags) {
+ ArrayList<String> names, ArrayList<Integer> numChildren,
+ ArrayList<Integer> tags) {
if (makeLowerCase) {
name = name.toLowerCase(Locale.ROOT);
}
@@ -181,28 +87,31 @@ private static void depthFirstNamesHelper(SchemaElement se, String name, boolean
* Flatten a SchemaElement, so it can more efficiently be passed through JNI.
*/
private static void depthFirstNames(StructElement schema, boolean makeLowerCase,
- ArrayList<String> names, ArrayList<Integer> numChildren, ArrayList<Integer> tags) {
+ ArrayList<String> names, ArrayList<Integer> numChildren,
+ ArrayList<Integer> tags) {
// Initialize them with a quick length for non-nested values
- for (ElementWithName se: schema.children) {
+ for (ElementWithName se : schema.children) {
depthFirstNamesHelper(se.element, se.name, makeLowerCase, names, numChildren, tags);
}
}
/**
- * Read a parquet thrift footer from a buffer and filter it like the java code would. The buffer
+ * Read a parquet thrift footer from a buffer and filter it like the java code would. The buffer
* should only include the thrift footer itself. This includes filtering out row groups that do
* not fall within the partition and pruning columns that are not needed.
- * @param buffer the buffer to parse the footer out from.
+ *
+ * @param buffer the buffer to parse the footer out from.
* @param partOffset for a split the start of the split
* @param partLength the length of the split
- * @param schema a stripped down schema so the code can verify that the types match what is
- * expected. The java code does this too.
+ * @param schema a stripped down schema so the code can verify that the types match what is
+ * expected. The java code does this too.
* @param ignoreCase should case be ignored when matching column names. If this is true then
* names should be converted to lower case before being passed to this.
* @return a reference to the parsed footer.
*/
public static ParquetFooter readAndFilter(HostMemoryBuffer buffer,
- long partOffset, long partLength, StructElement schema, boolean ignoreCase) {
+ long partOffset, long partLength, StructElement schema,
+ boolean ignoreCase) {
int parentNumChildren = schema.children.length;
ArrayList<String> names = new ArrayList<>();
ArrayList<Integer> numChildren = new ArrayList<>();
@@ -210,8 +119,7 @@ public static ParquetFooter readAndFilter(HostMemoryBuffer buffer,
depthFirstNames(schema, ignoreCase, names, numChildren, tags);
return new ParquetFooter(
- readAndFilter
- (buffer.getAddress(), buffer.getLength(),
+ readAndFilter(buffer.getAddress(), buffer.getLength(),
partOffset, partLength,
names.toArray(new String[0]),
numChildren.stream().mapToInt(i -> i).toArray(),
@@ -220,15 +128,13 @@ public static ParquetFooter readAndFilter(HostMemoryBuffer buffer,
ignoreCase));
}
- // Native APIS
-
private static native long readAndFilter(long address, long length,
- long partOffset, long partLength,
- String[] names,
- int[] numChildren,
- int[] tags,
- int parentNumChildren,
- boolean ignoreCase) throws CudfException;
+ long partOffset, long partLength,
+ String[] names,
+ int[] numChildren,
+ int[] tags,
+ int parentNumChildren,
+ boolean ignoreCase) throws CudfException;
private static native void close(long nativeHandle);
@@ -237,5 +143,110 @@ private static native long readAndFilter(long address, long length,
private static native int getNumColumns(long nativeHandle);
private static native HostMemoryBuffer serializeThriftFile(long nativeHandle,
- HostMemoryAllocator hostMemoryAllocator);
+ HostMemoryAllocator hostMemoryAllocator);
+
+ /**
+ * Write the filtered footer back out in a format that is compatible with a parquet
+ * footer file. This will include the MAGIC PAR1 at the beginning and end and also the
+ * length of the footer just before the PAR1 at the end.
+ */
+ public HostMemoryBuffer serializeThriftFile(HostMemoryAllocator hostMemoryAllocator) {
+ return serializeThriftFile(nativeHandle, hostMemoryAllocator);
+ }
+
+ public HostMemoryBuffer serializeThriftFile() {
+ return serializeThriftFile(DefaultHostMemoryAllocator.get());
+ }
+
+ /**
+ * Get the number of rows in the footer after filtering.
+ */
+ public long getNumRows() {
+ return getNumRows(nativeHandle);
+ }
+
+ /**
+ * Get the number of top level columns in the footer after filtering.
+ */
+ public int getNumColumns() {
+ return getNumColumns(nativeHandle);
+ }
+
+ @Override
+ public void close() throws Exception {
+ if (nativeHandle != 0) {
+ close(nativeHandle);
+ nativeHandle = 0;
+ }
+ }
+
+ /**
+ * Base element for all types in a parquet schema.
+ */
+ public static abstract class SchemaElement {
+ }
+
+ private static class ElementWithName {
+ final String name;
+ final SchemaElement element;
+
+ public ElementWithName(String name, SchemaElement element) {
+ this.name = name;
+ this.element = element;
+ }
+ }
+
+ public static class StructElement extends SchemaElement {
+ private final ElementWithName[] children;
+
+ private StructElement(ElementWithName[] children) {
+ this.children = children;
+ }
+
+ public static StructBuilder builder() {
+ return new StructBuilder();
+ }
+ }
+
+ public static class StructBuilder {
+ ArrayList<ElementWithName> children = new ArrayList<>();
+
+ StructBuilder() {
+ // Empty
+ }
+
+ public StructBuilder addChild(String name, SchemaElement child) {
+ children.add(new ElementWithName(name, child));
+ return this;
+ }
+
+ public StructElement build() {
+ return new StructElement(children.toArray(new ElementWithName[0]));
+ }
+ }
+
+ public static class ValueElement extends SchemaElement {
+ public ValueElement() {
+ }
+ }
+
+ public static class ListElement extends SchemaElement {
+ private final SchemaElement item;
+
+ public ListElement(SchemaElement item) {
+ this.item = item;
+ }
+ }
+
+ public static class MapElement extends SchemaElement {
+ private final SchemaElement key;
+ private final SchemaElement value;
+
+ public MapElement(SchemaElement key, SchemaElement value) {
+ this.key = key;
+ this.value = value;
+ }
+ }
}
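
A short sketch of building the stripped down schema that readAndFilter expects, using only the builder and element types declared above; the actual call to readAndFilter needs a HostMemoryBuffer holding a thrift footer, which is omitted here.

import com.nvidia.spark.rapids.jni.ParquetFooter.ListElement;
import com.nvidia.spark.rapids.jni.ParquetFooter.StructElement;
import com.nvidia.spark.rapids.jni.ParquetFooter.ValueElement;

public class FooterSchemaSketch {
  public static void main(String[] args) {
    // Keep only an "id" leaf column and a "tags" list column when filtering the footer.
    StructElement schema = StructElement.builder()
        .addChild("id", new ValueElement())
        .addChild("tags", new ListElement(new ValueElement()))
        .build();
    // ParquetFooter.readAndFilter(buffer, partOffset, partLength, schema, ignoreCase)
    // would then parse, prune, and return a ParquetFooter for that split.
  }
}
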
diff --git a/src/main/java/com/nvidia/spark/rapids/jni/ParseURI.java b/src/main/java/com/nvidia/spark/rapids/jni/ParseURI.java
index 6b71416dcb..2dd598a2bf 100644
--- a/src/main/java/com/nvidia/spark/rapids/jni/ParseURI.java
+++ b/src/main/java/com/nvidia/spark/rapids/jni/ParseURI.java
@@ -64,7 +64,7 @@ public static ColumnVector parseURIQuery(ColumnView uriColumn) {
* Parse query and return a specific parameter for each URI from the incoming column.
*
* @param URIColumn The input strings column in which each row contains a URI.
- * @param String The parameter to extract from the query
+ * @param query The parameter to extract from the query
* @return A string column with query data extracted.
*/
public static ColumnVector parseURIQueryWithLiteral(ColumnView uriColumn, String query) {
@@ -72,17 +72,18 @@ public static ColumnVector parseURIQueryWithLiteral(ColumnView uriColumn, String
return new ColumnVector(parseQueryWithLiteral(uriColumn.getNativeView(), query));
}
- /**
+ /**
* Parse query and return a specific parameter for each URI from the incoming column.
*
* @param URIColumn The input strings column in which each row contains a URI.
- * @param String The parameter to extract from the query
+ * @param queryColumn The column of parameters to extract from each query
* @return A string column with query data extracted.
*/
public static ColumnVector parseURIQueryWithColumn(ColumnView uriColumn, ColumnView queryColumn) {
assert uriColumn.getType().equals(DType.STRING) : "Input type must be String";
assert queryColumn.getType().equals(DType.STRING) : "Query type must be String";
- return new ColumnVector(parseQueryWithColumn(uriColumn.getNativeView(), queryColumn.getNativeView()));
+ return new ColumnVector(
+ parseQueryWithColumn(uriColumn.getNativeView(), queryColumn.getNativeView()));
}
/**
@@ -97,9 +98,14 @@ public static ColumnVector parseURIPath(ColumnView uriColumn) {
}
private static native long parseProtocol(long inputColumnHandle);
+
private static native long parseHost(long inputColumnHandle);
+
private static native long parseQuery(long inputColumnHandle);
+
private static native long parseQueryWithLiteral(long inputColumnHandle, String query);
+
private static native long parseQueryWithColumn(long inputColumnHandle, long queryColumnHandle);
+
private static native long parsePath(long inputColumnHandle);
}
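
A small sketch of the literal-query variant whose signature is shown above; the expected per-row behavior (extracted value or null) is inferred from the javadoc, not verified here.

import ai.rapids.cudf.ColumnVector;
import com.nvidia.spark.rapids.jni.ParseURI;

public class ParseUriSketch {
  public static void main(String[] args) {
    try (ColumnVector uris = ColumnVector.fromStrings(
             "https://example.com/search?id=7&name=x", "not a uri");
         ColumnVector ids = ParseURI.parseURIQueryWithLiteral(uris, "id")) {
      // ids is a STRING column: the value of the "id" parameter where it can be
      // extracted, null otherwise.
    }
  }
}
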
diff --git a/src/main/java/com/nvidia/spark/rapids/jni/Preconditions.java b/src/main/java/com/nvidia/spark/rapids/jni/Preconditions.java
index a6dfcdb104..4dba50c2a6 100644
--- a/src/main/java/com/nvidia/spark/rapids/jni/Preconditions.java
+++ b/src/main/java/com/nvidia/spark/rapids/jni/Preconditions.java
@@ -22,49 +22,51 @@
* This class contains utility methods for checking preconditions.
*/
public class Preconditions {
- /**
- * Check if the condition is true, otherwise throw an IllegalStateException with the given message.
- */
- public static void ensure(boolean condition, String message) {
- if (!condition) {
- throw new IllegalStateException(message);
- }
+ /**
+ * Check if the condition is true; otherwise, throw an IllegalStateException with the given message.
+ */
+ public static void ensure(boolean condition, String message) {
+ if (!condition) {
+ throw new IllegalStateException(message);
}
+ }
- /**
- * Check if the condition is true, otherwise throw an IllegalStateException with the given message supplier.
- */
- public static void ensure(boolean condition, Supplier<String> messageSupplier) {
- if (!condition) {
- throw new IllegalStateException(messageSupplier.get());
- }
+ /**
+ * Check if the condition is true; otherwise, throw an IllegalStateException with the message from the given supplier.
+ */
+ public static void ensure(boolean condition, Supplier<String> messageSupplier) {
+ if (!condition) {
+ throw new IllegalStateException(messageSupplier.get());
}
+ }
- /**
- * Check if the value is non-negative, otherwise throw an IllegalArgumentException with the given message.
- * @param value the value to check
- * @param name the name of the value
- * @return the value if it is non-negative
- * @throws IllegalArgumentException if the value is negative
- */
- public static int ensureNonNegative(int value, String name) {
- if (value < 0) {
- throw new IllegalArgumentException(name + " must be non-negative, but was " + value);
- }
- return value;
+ /**
+ * Check if the value is non-negative; otherwise, throw an IllegalArgumentException that names the offending value.
+ *
+ * @param value the value to check
+ * @param name the name of the value
+ * @return the value if it is non-negative
+ * @throws IllegalArgumentException if the value is negative
+ */
+ public static int ensureNonNegative(int value, String name) {
+ if (value < 0) {
+ throw new IllegalArgumentException(name + " must be non-negative, but was " + value);
}
+ return value;
+ }
- /**
- * Check if the value is non-negative, otherwise throw an IllegalArgumentException with the given message.
- * @param value the value to check
- * @param name the name of the value
- * @return the value if it is non-negative
- * @throws IllegalArgumentException if the value is negative
- */
- public static long ensureNonNegative(long value, String name) {
- if (value < 0) {
- throw new IllegalArgumentException(name + " must be non-negative, but was " + value);
- }
- return value;
+ /**
+ * Check if the value is non-negative; otherwise, throw an IllegalArgumentException that names the offending value.
+ *
+ * @param value the value to check
+ * @param name the name of the value
+ * @return the value if it is non-negative
+ * @throws IllegalArgumentException if the value is negative
+ */
+ public static long ensureNonNegative(long value, String name) {
+ if (value < 0) {
+ throw new IllegalArgumentException(name + " must be non-negative, but was " + value);
}
+ return value;
+ }
}
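
The reindented helpers above are straightforward; the one usage point worth showing is the Supplier overload, which defers building the message until a check actually fails:

import com.nvidia.spark.rapids.jni.Preconditions;

public class PreconditionsSketch {
  public static void main(String[] args) {
    // Returns the value when it is non-negative, otherwise throws IllegalArgumentException.
    int batchRows = Preconditions.ensureNonNegative(1024, "batchRows");
    // Eager message: fine for cheap constant strings.
    Preconditions.ensure(batchRows > 0, "batchRows must be positive");
    // Lazy message: the string concatenation only happens if the check fails.
    Preconditions.ensure(batchRows % 2 == 0,
        () -> "expected an even row count but got " + batchRows);
  }
}
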
diff --git a/src/main/java/com/nvidia/spark/rapids/jni/Profiler.java b/src/main/java/com/nvidia/spark/rapids/jni/Profiler.java
index 85e6b4a0a3..336540d27f 100644
--- a/src/main/java/com/nvidia/spark/rapids/jni/Profiler.java
+++ b/src/main/java/com/nvidia/spark/rapids/jni/Profiler.java
@@ -13,15 +13,17 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
+
package com.nvidia.spark.rapids.jni;
import ai.rapids.cudf.NativeDepsLoader;
-
import java.io.File;
import java.io.IOException;
import java.nio.ByteBuffer;
-/** Profiler that collects CUDA and NVTX events for the current process. */
+/**
+ * Profiler that collects CUDA and NVTX events for the current process.
+ */
public class Profiler {
private static final long DEFAULT_WRITE_BUFFER_SIZE = 1024 * 1024;
private static final int DEFAULT_FLUSH_PERIOD_MILLIS = 0;
@@ -31,6 +33,7 @@ public class Profiler {
/**
* Initialize the profiler in a standby state. The start method must be called after this
* to start collecting profiling data.
+ *
* @param config profiler configuration
*/
public static void init(DataWriter w, Config config) {
@@ -53,10 +56,11 @@ public static void init(DataWriter w, Config config) {
* Deprecated. Use init(Config) instead.
* Initialize the profiler in a standby state. The start method must be called after this
* to start collecting profiling data.
- * @param w data writer for writing profiling data
- * @param writeBufferSize size of host memory buffer to use for collecting profiling data.
- * Recommended to be between 1-8 MB in size to balance callback
- * overhead with latency.
+ *
+ * @param w data writer for writing profiling data
+ * @param writeBufferSize size of host memory buffer to use for collecting profiling data.
+ * Recommended to be between 1-8 MB in size to balance callback
+ * overhead with latency.
* @param flushPeriodMillis time period in milliseconds to explicitly flush collected
* profiling data to the writer. A value <= 0 will disable explicit
* flushing.
@@ -119,17 +123,22 @@ private static native void nativeInit(String libPath, DataWriter writer,
private static native void nativeShutdown();
- /** Interface for profiler data writers */
+ /**
+ * Interface for profiler data writers
+ */
public interface DataWriter extends AutoCloseable {
/**
* Called by the profiler to write a block of profiling data. Profiling data is written
* in a size-prefixed flatbuffer format. See profiler.fbs for the schema.
+ *
* @param data profiling data to be written
*/
void write(ByteBuffer data);
}
- /** Profiler configuration class. **/
+ /**
+ * Profiler configuration class.
+ **/
public static class Config {
private final long writeBufferSize;
private final int flushPeriodMillis;
@@ -141,7 +150,9 @@ public static class Config {
this.allocAsyncCapturing = builder.allocAsyncCapturing;
}
- /** Builder interface for profiler configuration. **/
+ /**
+ * Builder interface for profiler configuration.
+ **/
public static class Builder {
private long writeBufferSize = DEFAULT_WRITE_BUFFER_SIZE;
private int flushPeriodMillis = DEFAULT_FLUSH_PERIOD_MILLIS;
@@ -151,6 +162,7 @@ public static class Builder {
* Configure the size of the host memory buffer used for collecting profiling data.
* Recommended to be between 1 to 8 MB in size to balance callback overhead with
* latency.
+ *
* @param writeBufferSize size of buffer in bytes
*/
public Builder withWriteBufferSize(long writeBufferSize) {
@@ -160,6 +172,7 @@ public Builder withWriteBufferSize(long writeBufferSize) {
/**
* Configure the time period to explicitly flush collected profiling data to the writer.
+ *
* @param flushPeriodMillis time period in milliseconds. A value <= 0 will disable explicit
* flushing.
*/
@@ -170,6 +183,7 @@ public Builder withFlushPeriodMillis(int flushPeriodMillis) {
/**
* Configure whether async allocation and free events are captured by the profiler.
+ *
* @param allocAsyncCapturing true if async allocation and free events should be captured,
* false otherwise.
*/
@@ -178,7 +192,9 @@ public Builder withAllocAsyncCapturing(boolean allocAsyncCapturing) {
return this;
}
- /** Build a profiler configuration object. */
+ /**
+ * Build a profiler configuration object.
+ */
public Config build() {
return new Config(this);
}
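
A sketch of wiring the builder above to Profiler.init. The Builder's no-arg constructor and the file-channel DataWriter are assumptions for illustration; only the builder methods, Config, DataWriter, and init(DataWriter, Config) appear in the code above.

import com.nvidia.spark.rapids.jni.Profiler;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.ByteBuffer;
import java.nio.channels.FileChannel;

public class ProfilerSketch {
  public static void main(String[] args) throws IOException {
    Profiler.Config config = new Profiler.Config.Builder()
        .withWriteBufferSize(4 * 1024 * 1024)  // 4 MiB, inside the recommended 1-8 MB range
        .withFlushPeriodMillis(1000)           // flush collected data roughly every second
        .withAllocAsyncCapturing(false)
        .build();
    FileChannel channel = new FileOutputStream("profile.bin").getChannel();
    Profiler.DataWriter writer = new Profiler.DataWriter() {
      @Override
      public void write(ByteBuffer data) {
        try {
          channel.write(data);  // size-prefixed flatbuffer blocks, see profiler.fbs
        } catch (IOException e) {
          throw new UncheckedIOException(e);
        }
      }

      @Override
      public void close() throws Exception {
        channel.close();
      }
    };
    Profiler.init(writer, config);
    // init only puts the profiler in standby; starting and stopping happen elsewhere.
  }
}
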
diff --git a/src/main/java/com/nvidia/spark/rapids/jni/RegexRewriteUtils.java b/src/main/java/com/nvidia/spark/rapids/jni/RegexRewriteUtils.java
index 9277c3e0f9..b940a5093f 100644
--- a/src/main/java/com/nvidia/spark/rapids/jni/RegexRewriteUtils.java
+++ b/src/main/java/com/nvidia/spark/rapids/jni/RegexRewriteUtils.java
@@ -16,29 +16,36 @@
package com.nvidia.spark.rapids.jni;
-import ai.rapids.cudf.*;
+import ai.rapids.cudf.ColumnVector;
+import ai.rapids.cudf.CudfAccessor;
+import ai.rapids.cudf.DType;
+import ai.rapids.cudf.NativeDepsLoader;
+import ai.rapids.cudf.Scalar;
public class RegexRewriteUtils {
static {
NativeDepsLoader.loadNativeDeps();
}
-/**
- * @brief Check if input string contains regex pattern `literal[start-end]{len,}`, which means
- * a literal string followed by a range of characters in the range of start to end, with at least
- * len characters.
- *
- * @param strings Column of strings to check for literal.
- * @param literal UTF-8 encoded string to check in strings column.
- * @param len Minimum number of characters to check after the literal.
- * @param start Minimum UTF-8 codepoint value to check for in the range.
- * @param end Maximum UTF-8 codepoint value to check for in the range.
- * @return ColumnVector of booleans where true indicates the string contains the pattern.
- */
- public static ColumnVector literalRangePattern(ColumnVector input, Scalar literal, int len, int start, int end) {
- assert(input.getType().equals(DType.STRING)) : "column must be a String";
- return new ColumnVector(literalRangePattern(input.getNativeView(), CudfAccessor.getScalarHandle(literal), len, start, end));
+ /**
+ * Check if each input string contains the regex pattern `literal[start-end]{len,}`, which means
+ * a literal string followed by a range of characters in the range of start to end, with at
+ * least len characters.
+ *
+ * @param input Column of strings to check for literal.
+ * @param literal UTF-8 encoded string to check in strings column.
+ * @param len Minimum number of characters to check after the literal.
+ * @param start Minimum UTF-8 codepoint value to check for in the range.
+ * @param end Maximum UTF-8 codepoint value to check for in the range.
+ * @return ColumnVector of booleans where true indicates the string contains the pattern.
+ */
+ public static ColumnVector literalRangePattern(ColumnVector input, Scalar literal, int len,
+ int start, int end) {
+ assert (input.getType().equals(DType.STRING)) : "column must be a String";
+ return new ColumnVector(
+ literalRangePattern(input.getNativeView(), CudfAccessor.getScalarHandle(literal), len,
+ start, end));
}
- private static native long literalRangePattern(long input, long literal, int len, int start, int end);
+ private static native long literalRangePattern(long input, long literal, int len, int start,
+ int end);
}
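
A sketch of the call shape after the wrap: start and end are UTF-8 codepoints, so character literals work for a digit range; the expected outputs in the comment follow from the contains-style javadoc rather than from running the code.

import ai.rapids.cudf.ColumnVector;
import ai.rapids.cudf.Scalar;
import com.nvidia.spark.rapids.jni.RegexRewriteUtils;

public class LiteralRangeSketch {
  public static void main(String[] args) {
    // Equivalent to asking whether each row contains a match for "id[0-9]{3,}".
    try (ColumnVector input = ColumnVector.fromStrings("id1234", "id12", "no match");
         Scalar literal = Scalar.fromString("id");
         ColumnVector matches = RegexRewriteUtils.literalRangePattern(input, literal, 3, '0', '9')) {
      // matches is a BOOL8 column; expected per the javadoc: true, false, false
    }
  }
}
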
diff --git a/src/main/java/com/nvidia/spark/rapids/jni/RmmSpark.java b/src/main/java/com/nvidia/spark/rapids/jni/RmmSpark.java
index 45a234dcca..72df229d27 100644
--- a/src/main/java/com/nvidia/spark/rapids/jni/RmmSpark.java
+++ b/src/main/java/com/nvidia/spark/rapids/jni/RmmSpark.java
@@ -23,20 +23,11 @@
import ai.rapids.cudf.RmmException;
import ai.rapids.cudf.RmmTrackingResourceAdaptor;
-import java.util.Arrays;
-import java.util.Map;
-
/**
* Initialize RMM in ways that are specific to Spark.
*/
public class RmmSpark {
- public enum OomInjectionType {
- CPU_OR_GPU,
- CPU,
- GPU;
- }
-
private static volatile SparkResourceAdaptor sra = null;
/**
@@ -50,13 +41,15 @@ public static void setEventHandler(RmmEventHandler handler) throws RmmException
/**
* Set the event handler in a way that Spark wants it. For now this is the same as RMM, but in
* the future it is likely to change.
- * @param handler the handler to set
+ *
+ * @param handler the handler to set
* @param logLocation the location where you want spark state transitions. Alloc and free logging
* is handled separately when setting up RMM. "stderr" or "stdout" are treated
* as `std::cerr` and `std::cout` respectively in native code. Anything else
* is treated as a file.
*/
- public static void setEventHandler(RmmEventHandler handler, String logLocation) throws RmmException {
+ public static void setEventHandler(RmmEventHandler handler, String logLocation)
+ throws RmmException {
// synchronize with RMM not RmmSpark to stay in sync with Rmm itself.
synchronized (Rmm.class) {
// RmmException constructor is not public, so we have to use a different exception
@@ -125,8 +118,9 @@ public static long getCurrentThreadId() {
/**
* Indicate that a given thread is dedicated to a specific task. This thread can be part of a
* thread pool, but if it blocks it can never transitively block another active task.
+ *
* @param threadId the thread ID to use
- * @param taskId the task ID this thread is working on.
+ * @param taskId the task ID this thread is working on.
*/
public static void startDedicatedTaskThread(long threadId, long taskId, Thread thread) {
synchronized (Rmm.class) {
@@ -140,6 +134,7 @@ public static void startDedicatedTaskThread(long threadId, long taskId, Thread t
/**
* Indicate that the current thread is dedicated to a specific task. This thread can be part of
* a thread pool, but if this blocks it can never transitively block another active task.
+ *
* @param taskId the task ID this thread is working on.
*/
public static void currentThreadIsDedicatedToTask(long taskId) {
@@ -148,9 +143,10 @@ public static void currentThreadIsDedicatedToTask(long taskId) {
/**
* A shuffle thread has started to work on some tasks.
+ *
* @param threadId the thread ID (not java thread id).
- * @param thread the java thread
- * @param taskIds the IDs of tasks that this is starting work on.
+ * @param thread the java thread
+ * @param taskIds the IDs of tasks that this is starting work on.
*/
public static void shuffleThreadWorkingTasks(long threadId, Thread thread, long[] taskIds) {
synchronized (Rmm.class) {
@@ -163,6 +159,7 @@ public static void shuffleThreadWorkingTasks(long threadId, Thread thread, long[
/**
* The current thread is a shuffle thread and has started to work on some tasks.
+ *
* @param taskIds the IDs of the tasks that this is starting work on.
*/
public static void shuffleThreadWorkingOnTasks(long[] taskIds) {
@@ -172,12 +169,13 @@ public static void shuffleThreadWorkingOnTasks(long[] taskIds) {
/**
* The current thread which is in a thread pool that could transitively block other tasks has
* started to work on a task.
+ *
* @param taskId the ID of the task that this is starting work on.
*/
public static void poolThreadWorkingOnTask(long taskId) {
long threadId = getCurrentThreadId();
Thread thread = Thread.currentThread();
- long[] taskIds = new long[]{taskId};
+ long[] taskIds = new long[] {taskId};
synchronized (Rmm.class) {
if (sra != null && sra.isOpen()) {
ThreadStateRegistry.addThread(threadId, thread);
@@ -189,8 +187,9 @@ public static void poolThreadWorkingOnTask(long taskId) {
/**
* A thread in a thread pool that could transitively block other tasks has finished work
* on some tasks.
+ *
* @param threadId the thread ID (not java thread id).
- * @param taskIds the IDs of the tasks that are done.
+ * @param taskIds the IDs of the tasks that are done.
*/
public static void poolThreadFinishedForTasks(long threadId, long[] taskIds) {
synchronized (Rmm.class) {
@@ -202,8 +201,9 @@ public static void poolThreadFinishedForTasks(long threadId, long[] taskIds) {
/**
* A shuffle thread has finished work on some tasks.
+ *
* @param threadId the thread ID (not java thread id).
- * @param taskIds the IDs of the tasks that are done.
+ * @param taskIds the IDs of the tasks that are done.
*/
private static void shuffleThreadFinishedForTasks(long threadId, long[] taskIds) {
poolThreadFinishedForTasks(threadId, taskIds);
@@ -212,6 +212,7 @@ private static void shuffleThreadFinishedForTasks(long threadId, long[] taskIds)
/**
* The current thread which is in a thread pool that could transitively block other tasks
* has finished work on some tasks.
+ *
* @param taskIds the IDs of the tasks that are done.
*/
public static void poolThreadFinishedForTasks(long[] taskIds) {
@@ -220,6 +221,7 @@ public static void poolThreadFinishedForTasks(long[] taskIds) {
/**
* The current shuffle thread has finished work on some tasks.
+ *
* @param taskIds the IDs of the tasks that are done.
*/
public static void shuffleThreadFinishedForTasks(long[] taskIds) {
@@ -229,14 +231,16 @@ public static void shuffleThreadFinishedForTasks(long[] taskIds) {
/**
* The current thread which is in a thread pool that could transitively block other tasks
* has finished work on a task.
+ *
* @param taskId the ID of the task that is done.
*/
public static void poolThreadFinishedForTask(long taskId) {
- poolThreadFinishedForTasks(getCurrentThreadId(), new long[]{taskId});
+ poolThreadFinishedForTasks(getCurrentThreadId(), new long[] {taskId});
}
/**
* Indicate that a retry block has started for a given thread.
+ *
* @param threadId the id of the thread, not the java ID.
*/
public static void startRetryBlock(long threadId) {
@@ -256,6 +260,7 @@ public static void currentThreadStartRetryBlock() {
/**
* Indicate that a retry block has ended for a given thread.
+ *
* @param threadId the id of the thread, not the java ID.
*/
public static void endRetryBlock(long threadId) {
@@ -283,6 +288,7 @@ private static void checkAndBreakDeadlocks() {
/**
* Remove the given thread ID from being associated with a given task
+ *
* @param threadId the ID of the thread that is no longer a part of a task or shuffle
* (not java thread id).
*/
@@ -305,6 +311,7 @@ public static void removeCurrentDedicatedThreadAssociation(long taskId) {
* Remove all task associations for a given thread. This is intended to be used as a part
* of tests when a thread is shutting down, or for a pool thread when it is fully done.
* Dedicated task thread typically are cleaned when the task itself completes.
+ *
* @param threadId the id of the thread to clean up
*/
public static void removeAllThreadAssociation(long threadId) {
@@ -327,6 +334,7 @@ public static void removeAllCurrentThreadAssociation() {
/**
* Indicate that a given task is done and if there are any threads still associated with it
* then they should also be removed.
+ *
* @param taskId the ID of the task that has completed.
*/
public static void taskDone(long taskId) {
@@ -339,6 +347,7 @@ public static void taskDone(long taskId) {
/**
* A dedicated task thread is about to submit work to a pool that could transitively block it.
+ *
* @param threadId the ID of the thread that is about to submit the work.
*/
public static void submittingToPool(long threadId) {
@@ -360,6 +369,7 @@ public static void submittingToPool() {
/**
* A dedicated task thread is about to wait on work done on a pool that could transitively
* block it.
+ *
* @param threadId the ID of the thread that is about to wait.
*/
public static void waitingOnPool(long threadId) {
@@ -381,6 +391,7 @@ public static void waitingOnPool() {
/**
* A dedicated task thread is done waiting on a pool, either for a result or after submitting
* something to the pool.
+ *
* @param threadId the ID of the thread that is done.
*/
public static void doneWaitingOnPool(long threadId) {
@@ -430,6 +441,7 @@ public static void blockThreadUntilReady() {
/**
* Force the thread with the given ID to throw a GpuRetryOOM or CpuRetryOOM on their next
* allocation attempt, depending on the type of allocation being done.
+ *
* @param threadId the ID of the thread to throw the exception (not java thread id).
*/
public static void forceRetryOOM(long threadId) {
@@ -439,9 +451,10 @@ public static void forceRetryOOM(long threadId) {
/**
* Force the thread with the given ID to throw a GpuRetryOOM or CpuRetryOOM on their next
* allocation attempt, depending on the type of allocation being done.
- * @param threadId the ID of the thread to throw the exception (not java thread id).
- * @param numOOMs the number of times the *RetryOOM should be thrown
- * @param oomMode the ordinal corresponding to OomInjectionType to filter allocations
+ *
+ * @param threadId the ID of the thread to throw the exception (not java thread id).
+ * @param numOOMs the number of times the *RetryOOM should be thrown
+ * @param oomMode the ordinal corresponding to OomInjectionType to filter allocations
* @param skipCount how many matching allocations to skip
*/
public static void forceRetryOOM(long threadId, int numOOMs, int oomMode, int skipCount) {
@@ -461,6 +474,7 @@ public static void forceRetryOOM(long threadId, int numOOMs) {
/**
* Force the thread with the given ID to throw a GpuSplitAndRetryOOM of CpuSplitAndRetryOOM
* on their next allocation attempt, depending on the allocation being done.
+ *
* @param threadId the ID of the thread to throw the exception (not java thread id).
*/
public static void forceSplitAndRetryOOM(long threadId) {
@@ -470,9 +484,10 @@ public static void forceSplitAndRetryOOM(long threadId) {
/**
* Force the thread with the given ID to throw a GpuSplitAndRetryOOM or CpuSplitAndRetryOOm
* on their next allocation attempt, depending on the allocation being done.
- * @param threadId the ID of the thread to throw the exception (not java thread id).
- * @param numOOMs the number of times the *SplitAndRetryOOM should be thrown
- * @param oomMode the ordinal corresponding to OomInjectionType to filter allocations
+ *
+ * @param threadId the ID of the thread to throw the exception (not java thread id).
+ * @param numOOMs the number of times the *SplitAndRetryOOM should be thrown
+ * @param oomMode the ordinal corresponding to OomInjectionType to filter allocations
* @param skipCount how many matching allocations to skip
*/
public static void forceSplitAndRetryOOM(long threadId, int numOOMs, int oomMode, int skipCount) {
@@ -492,6 +507,7 @@ public static void forceSplitAndRetryOOM(long threadId, int numOOMs) {
/**
* Force the thread with the given ID to throw a CudfException on their next allocation attempt.
* This is to simulate a cuDF exception being thrown from a kernel and test retry handling code.
+ *
* @param threadId the ID of the thread to throw the exception (not java thread id).
*/
public static void forceCudfException(long threadId) {
@@ -501,6 +517,7 @@ public static void forceCudfException(long threadId) {
/**
* Force the thread with the given ID to throw a CudfException on their next allocation attempt.
* This is to simulate a cuDF exception being thrown from a kernel and test retry handling code.
+ *
* @param threadId the ID of the thread to throw the exception (not java thread id).
* @param numTimes the number of times the CudfException should be thrown
*/
@@ -527,6 +544,7 @@ public static RmmSparkThreadState getStateOf(long threadId) {
/**
* Get the number of retry exceptions that were thrown and reset the metric.
+ *
* @param taskId the id of the task to get the metric for.
* @return the number of times it was thrown or 0 if in the UNKNOWN state.
*/
@@ -543,6 +561,7 @@ public static int getAndResetNumRetryThrow(long taskId) {
/**
* Get the number of split and retry exceptions that were thrown and reset the metric.
+ *
* @param taskId the id of the task to get the metric for.
* @return the number of times it was thrown or 0 if in the UNKNOWN state.
*/
@@ -559,6 +578,7 @@ public static int getAndResetNumSplitRetryThrow(long taskId) {
/**
* Get how long, in nanoseconds, that the task was blocked for
+ *
* @param taskId the id of the task to get the metric for.
* @return the time the task was blocked or 0 if in the UNKNOWN state.
*/
@@ -575,6 +595,7 @@ public static long getAndResetBlockTimeNs(long taskId) {
/**
* Get how long, in nanoseconds, that this task lost in computation time due to retries.
+ *
* @param taskId the id of the task to get the metric for.
* @return the time the task did computation that was lost.
*/
@@ -591,6 +612,7 @@ public static long getAndResetComputeTimeLostToRetryNs(long taskId) {
/**
* Get the max device memory footprint, in bytes, that this task had allocated over its lifetime
+ *
* @param taskId the id of the task to get the metric for.
* @return the max device memory footprint.
*/
@@ -608,7 +630,8 @@ public static long getAndResetGpuMaxMemoryAllocated(long taskId) {
/**
* Called before doing an allocation on the CPU. This could throw an injected exception to help
* with testing.
- * @param amount the amount of memory being requested
+ *
+ * @param amount the amount of memory being requested
* @param blocking is this for a blocking allocate or a non-blocking one.
* @return a boolean that indicates if the allocation is recursive. Note that recursive
* allocations on the CPU are only allowed with non-blocking allocations. This must be passed
@@ -628,9 +651,10 @@ public static boolean preCpuAlloc(long amount, boolean blocking) {
/**
* The allocation that was going to be done succeeded.
- * @param ptr a pointer to the memory that was allocated.
- * @param amount the amount of memory that was allocated.
- * @param blocking is this for a blocking allocate or a non-blocking one.
+ *
+ * @param ptr a pointer to the memory that was allocated.
+ * @param amount the amount of memory that was allocated.
+ * @param blocking is this for a blocking allocate or a non-blocking one.
* @param wasRecursive the boolean that was returned from `preCpuAlloc`.
*/
public static void postCpuAllocSuccess(long ptr, long amount, boolean blocking,
@@ -646,8 +670,9 @@ public static void postCpuAllocSuccess(long ptr, long amount, boolean blocking,
/**
* The allocation failed, and spilling didn't save it.
- * @param wasOom was the failure caused by an OOM or something else.
- * @param blocking is this for a blocking allocate or a non-blocking one.
+ *
+ * @param wasOom was the failure caused by an OOM or something else.
+ * @param blocking is this for a blocking allocate or a non-blocking one.
* @param wasRecursive the boolean that was returned from `preCpuAlloc`.
* @return true if the allocation should be retried else false if the state machine
* thinks that a retry would not help.
@@ -666,7 +691,8 @@ public static boolean postCpuAllocFailed(boolean wasOom, boolean blocking, boole
/**
* Some CPU memory was freed.
- * @param ptr a pointer to the memory being deallocated.
+ *
+ * @param ptr a pointer to the memory being deallocated.
* @param amount the amount that was made available.
*/
public static void cpuDeallocate(long ptr, long amount) {
@@ -679,4 +705,10 @@ public static void cpuDeallocate(long ptr, long amount) {
}
}
+ public enum OomInjectionType {
+ CPU_OR_GPU,
+ CPU,
+ GPU
+ }
+
}
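
The relocated OomInjectionType enum is easiest to read next to the injection entry points; a testing-oriented sketch, assuming RMM and the Spark resource adaptor were already configured elsewhere:

import com.nvidia.spark.rapids.jni.RmmSpark;
import com.nvidia.spark.rapids.jni.RmmSpark.OomInjectionType;

public class OomInjectionSketch {
  public static void main(String[] args) {
    long taskId = 1;
    RmmSpark.currentThreadIsDedicatedToTask(taskId);
    long threadId = RmmSpark.getCurrentThreadId();
    // Inject a single retry OOM on this thread's next GPU allocation, skipping none.
    RmmSpark.forceRetryOOM(threadId, 1, OomInjectionType.GPU.ordinal(), 0);
    // ... run the allocation-heavy code under test here ...
    System.out.println("retry OOMs seen: " + RmmSpark.getAndResetNumRetryThrow(taskId));
    RmmSpark.taskDone(taskId);
  }
}
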
diff --git a/src/main/java/com/nvidia/spark/rapids/jni/RowConversion.java b/src/main/java/com/nvidia/spark/rapids/jni/RowConversion.java
index fdb7f40480..093643209a 100644
--- a/src/main/java/com/nvidia/spark/rapids/jni/RowConversion.java
+++ b/src/main/java/com/nvidia/spark/rapids/jni/RowConversion.java
@@ -16,9 +16,15 @@
package com.nvidia.spark.rapids.jni;
-import ai.rapids.cudf.*;
+import ai.rapids.cudf.ColumnVector;
+import ai.rapids.cudf.ColumnView;
+import ai.rapids.cudf.DType;
+import ai.rapids.cudf.NativeDepsLoader;
+import ai.rapids.cudf.Table;
-/** Utility class for converting between column major and row major data */
+/**
+ * Utility class for converting between column major and row major data
+ */
public class RowConversion {
static {
NativeDepsLoader.loadNativeDeps();
@@ -27,7 +33,7 @@ public class RowConversion {
/**
* For details about how this method functions refer to
* {@link #convertToRowsFixedWidthOptimized()}.
- *
+ *
* The only thing different between this method and {@link #convertToRowsFixedWidthOptimized()}
* is that this can handle rougly 250M columns while {@link #convertToRowsFixedWidthOptimized()}
* can only handle columns less than 100
@@ -63,7 +69,7 @@ public static ColumnVector[] convertToRows(Table table) {
* |row N+1 | validity for row N+1 | padding |
* ...
*
- *
+ *
* The format of each row is similar in layout to a C struct where each column will have padding
* in front of it to align it properly. Each row has padding inserted at the end so the next row
* is aligned to a 64-bit boundary. This is so that the first column will always start at the
@@ -79,8 +85,8 @@ public static ColumnVector[] convertToRows(Table table) {
* | A - BOOL8 (8-bit) | B - INT16 (16-bit) | C - DURATION_DAYS (32-bit) |
*
*
- * Will have a layout that looks like
- *
+ * Will have a layout that looks like
+ *
* | A_0 | P | B_0 | B_1 | C_0 | C_1 | C_2 | C_3 | V0 | P | P | P | P | P | P | P |
*
*
@@ -127,14 +133,14 @@ public static ColumnVector[] convertToRowsFixedWidthOptimized(Table table) {
/**
* Convert a column of list of bytes that is formatted like the output from `convertToRows`
* and convert it back to a table.
- *
+ *
* NOTE: This method doesn't support nested types
*
- * @param vec the row data to process.
+ * @param vec the row data to process.
* @param schema the types of each column.
* @return the parsed table.
*/
- public static Table convertFromRows(ColumnView vec, DType ... schema) {
+ public static Table convertFromRows(ColumnView vec, DType... schema) {
int[] types = new int[schema.length];
int[] scale = new int[schema.length];
for (int i = 0; i < schema.length; i++) {
@@ -148,14 +154,14 @@ public static Table convertFromRows(ColumnView vec, DType ... schema) {
/**
* Convert a column of list of bytes that is formatted like the output from `convertToRows`
* and convert it back to a table.
- *
+ *
* NOTE: This method doesn't support nested types
*
- * @param vec the row data to process.
+ * @param vec the row data to process.
* @param schema the types of each column.
* @return the parsed table.
*/
- public static Table convertFromRowsFixedWidthOptimized(ColumnView vec, DType ... schema) {
+ public static Table convertFromRowsFixedWidthOptimized(ColumnView vec, DType... schema) {
int[] types = new int[schema.length];
int[] scale = new int[schema.length];
for (int i = 0; i < schema.length; i++) {
@@ -167,8 +173,11 @@ public static Table convertFromRowsFixedWidthOptimized(ColumnView vec, DType ...
}
private static native long[] convertToRows(long nativeHandle);
+
private static native long[] convertToRowsFixedWidthOptimized(long nativeHandle);
private static native long[] convertFromRows(long nativeColumnView, int[] types, int[] scale);
- private static native long[] convertFromRowsFixedWidthOptimized(long nativeColumnView, int[] types, int[] scale);
+
+ private static native long[] convertFromRowsFixedWidthOptimized(long nativeColumnView,
+ int[] types, int[] scale);
}
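
A round-trip sketch of the two conversion directions shown above; the cudf column factory calls and the single-batch handling are assumptions kept minimal for illustration.

import ai.rapids.cudf.ColumnVector;
import ai.rapids.cudf.DType;
import ai.rapids.cudf.Table;
import com.nvidia.spark.rapids.jni.RowConversion;

public class RowConversionSketch {
  public static void main(String[] args) {
    try (ColumnVector ints = ColumnVector.fromInts(1, 2, 3);
         ColumnVector flags = ColumnVector.fromBooleans(true, false, true);
         Table table = new Table(ints, flags)) {
      // Each element of `rows` is a list-of-bytes column holding one batch of packed rows.
      ColumnVector[] rows = RowConversion.convertToRows(table);
      try {
        try (Table roundTrip = RowConversion.convertFromRows(rows[0], DType.INT32, DType.BOOL8)) {
          // roundTrip should match the first batch of the original table
        }
      } finally {
        for (ColumnVector cv : rows) {
          cv.close();
        }
      }
    }
  }
}
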
diff --git a/src/main/java/com/nvidia/spark/rapids/jni/SparkResourceAdaptor.java b/src/main/java/com/nvidia/spark/rapids/jni/SparkResourceAdaptor.java
index 9e3414f7d3..c4bbe23112 100644
--- a/src/main/java/com/nvidia/spark/rapids/jni/SparkResourceAdaptor.java
+++ b/src/main/java/com/nvidia/spark/rapids/jni/SparkResourceAdaptor.java
@@ -13,21 +13,18 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package com.nvidia.spark.rapids.jni;
-import com.nvidia.spark.rapids.jni.RmmSpark.OomInjectionType;
+package com.nvidia.spark.rapids.jni;
import ai.rapids.cudf.NativeDepsLoader;
import ai.rapids.cudf.RmmDeviceMemoryResource;
import ai.rapids.cudf.RmmEventHandlerResourceAdaptor;
import ai.rapids.cudf.RmmWrappingDeviceMemoryResource;
+import com.nvidia.spark.rapids.jni.RmmSpark.OomInjectionType;
public class SparkResourceAdaptor
- extends RmmWrappingDeviceMemoryResource<RmmEventHandlerResourceAdaptor<RmmDeviceMemoryResource>> {
- static {
- NativeDepsLoader.loadNativeDeps();
- }
-
+ extends
+ RmmWrappingDeviceMemoryResource<RmmEventHandlerResourceAdaptor<RmmDeviceMemoryResource>> {
/**
* How long does the SparkResourceAdaptor pool thread states as a watchdog to break up potential
* deadlocks.
@@ -35,11 +32,16 @@ public class SparkResourceAdaptor
private static final long pollingPeriod = Long.getLong(
"ai.rapids.cudf.spark.rmmWatchdogPollingPeriod", 100);
+ static {
+ NativeDepsLoader.loadNativeDeps();
+ }
+
private long handle = 0;
- private Thread watchDog;
+ private final Thread watchDog;
/**
* Create a new tracking resource adaptor.
+ *
* @param wrapped the memory resource to track allocations. This should not be reused.
*/
public SparkResourceAdaptor(RmmEventHandlerResourceAdaptor<RmmDeviceMemoryResource> wrapped) {
@@ -48,13 +50,14 @@ public SparkResourceAdaptor(RmmEventHandlerResourceAdaptor<RmmDeviceMemoryResource> wrapped,
- String logLoc) {
+ String logLoc) {
super(wrapped);
watchDog = new Thread(() -> {
try {
@@ -78,6 +81,71 @@ public SparkResourceAdaptor(RmmEventHandlerResourceAdaptor 0) {
@@ -143,8 +212,9 @@ public void poolThreadFinishedForTasks(long threadId, long[] taskIds) {
/**
* Remove the given thread ID from any association.
+ *
* @param threadId the ID of the thread that is no longer a part of a task or shuffle (not java thread id).
- * @param taskId the task that is being removed. If the task id is -1, then any/all tasks are removed.
+ * @param taskId the task that is being removed. If the task id is -1, then any/all tasks are removed.
*/
public void removeThreadAssociation(long threadId, long taskId) {
removeThreadAssociation(getHandle(), threadId, taskId);
@@ -153,6 +223,7 @@ public void removeThreadAssociation(long threadId, long taskId) {
/**
* Indicate that a given task is done and if there are any threads still associated with it
* then they should also be removed.
+ *
* @param taskId the ID of the task that has completed.
*/
public void taskDone(long taskId) {
@@ -161,6 +232,7 @@ public void taskDone(long taskId) {
/**
* A dedicated task thread is going to submit work to a pool.
+ *
* @param threadId the ID of the thread that will submit the work.
*/
public void submittingToPool(long threadId) {
@@ -169,6 +241,7 @@ public void submittingToPool(long threadId) {
/**
* A dedicated task thread is going to wait on work in a pool to complete.
+ *
* @param threadId the ID of the thread that will submit the work.
*/
public void waitingOnPool(long threadId) {
@@ -178,6 +251,7 @@ public void waitingOnPool(long threadId) {
/**
* A dedicated task thread is done waiting on a pool. This could be because of submitting
* something to the pool or waiting on a result from the pool.
+ *
* @param threadId the ID of the thread that is done.
*/
public void doneWaitingOnPool(long threadId) {
@@ -186,9 +260,10 @@ public void doneWaitingOnPool(long threadId) {
/**
* Force the thread with the given ID to throw a GpuRetryOOM on their next allocation attempt.
- * @param threadId the ID of the thread to throw the exception (not java thread id).
- * @param numOOMs the number of times the GpuRetryOOM should be thrown
- * @param oomMode ordinal of the corresponding RmmSpark.OomInjectionType
+ *
+ * @param threadId the ID of the thread to throw the exception (not java thread id).
+ * @param numOOMs the number of times the GpuRetryOOM should be thrown
+ * @param oomMode ordinal of the corresponding RmmSpark.OomInjectionType
* @param skipCount the number of times a matching allocation is skipped before injecting the first OOM
*/
public void forceRetryOOM(long threadId, int numOOMs, int oomMode, int skipCount) {
@@ -199,15 +274,16 @@ public void forceRetryOOM(long threadId, int numOOMs, int oomMode, int skipCount
private void validateOOMInjectionParams(int numOOMs, int oomMode, int skipCount) {
assert numOOMs >= 0 : "non-negative numOoms expected: actual=" + numOOMs;
assert skipCount >= 0 : "non-negative skipCount expected: actual=" + skipCount;
- assert oomMode >= 0 && oomMode < OomInjectionType.values().length:
- "non-negative oomMode<" + OomInjectionType.values().length + " expected: actual=" + oomMode;
+ assert oomMode >= 0 && oomMode < OomInjectionType.values().length :
+ "non-negative oomMode<" + OomInjectionType.values().length + " expected: actual=" + oomMode;
}
/**
* Force the thread with the given ID to throw a GpuSplitAndRetryOOM on their next allocation attempt.
- * @param threadId the ID of the thread to throw the exception (not java thread id).
- * @param numOOMs the number of times the GpuSplitAndRetryOOM should be thrown
- * @param oomMode ordinal of the corresponding RmmSpark.OomInjectionType
+ *
+ * @param threadId the ID of the thread to throw the exception (not java thread id).
+ * @param numOOMs the number of times the GpuSplitAndRetryOOM should be thrown
+ * @param oomMode ordinal of the corresponding RmmSpark.OomInjectionType
* @param skipCount the number of times a matching allocation is skipped before injecting the first OOM
*/
public void forceSplitAndRetryOOM(long threadId, int numOOMs, int oomMode, int skipCount) {
@@ -217,6 +293,7 @@ public void forceSplitAndRetryOOM(long threadId, int numOOMs, int oomMode, int s
/**
* Force the thread with the given ID to throw a GpuSplitAndRetryOOM on their next allocation attempt.
+ *
* @param threadId the ID of the thread to throw the exception (not java thread id).
* @param numTimes the number of times the CudfException should be thrown
*/
@@ -255,11 +332,11 @@ public long getAndResetGpuMaxMemoryAllocated(long taskId) {
return getAndResetGpuMaxMemoryAllocated(getHandle(), taskId);
}
-
/**
* Called before doing an allocation on the CPU. This could throw an injected exception to help
* with testing.
- * @param amount the amount of memory being requested
+ *
+ * @param amount the amount of memory being requested
* @param blocking is this for a blocking allocate or a non-blocking one.
*/
public boolean preCpuAlloc(long amount, boolean blocking) {
@@ -268,9 +345,10 @@ public boolean preCpuAlloc(long amount, boolean blocking) {
/**
* The allocation that was going to be done succeeded.
- * @param ptr a pointer to the memory that was allocated.
- * @param amount the amount of memory that was allocated.
- * @param blocking is this for a blocking allocate or a non-blocking one.
+ *
+ * @param ptr a pointer to the memory that was allocated.
+ * @param amount the amount of memory that was allocated.
+ * @param blocking is this for a blocking allocate or a non-blocking one.
* @param wasRecursive the result of calling preCpuAlloc.
*/
public void postCpuAllocSuccess(long ptr, long amount, boolean blocking, boolean wasRecursive) {
@@ -279,8 +357,9 @@ public void postCpuAllocSuccess(long ptr, long amount, boolean blocking, boolean
/**
* The allocation failed, and spilling didn't save it.
- * @param wasOom was the failure caused by an OOM or something else.
- * @param blocking is this for a blocking allocate or a non-blocking one.
+ *
+ * @param wasOom was the failure caused by an OOM or something else.
+ * @param blocking is this for a blocking allocate or a non-blocking one.
* @param wasRecursive the result of calling preCpuAlloc
* @return true if the allocation should be retried else false if the state machine
* thinks that a retry would not help.
@@ -291,46 +370,11 @@ public boolean postCpuAllocFailed(boolean wasOom, boolean blocking, boolean wasR
/**
* Some CPU memory was freed.
- * @param ptr a pointer to the memory being deallocated.
+ *
+ * @param ptr a pointer to the memory being deallocated.
* @param amount the amount that was made available.
*/
public void cpuDeallocate(long ptr, long amount) {
cpuDeallocate(getHandle(), ptr, amount);
}
-
- /**
- * Get the ID of the current thread that can be used with the other SparkResourceAdaptor APIs.
- * Don't use the java thread ID. They are not related.
- */
- public static native long getCurrentThreadId();
-
- private native static long createNewAdaptor(long wrappedHandle, String logLoc);
- private native static void releaseAdaptor(long handle);
- private static native void startDedicatedTaskThread(long handle, long threadId, long taskId);
- private static native void poolThreadWorkingOnTasks(long handle, boolean isForShuffle, long threadId, long[] taskIds);
- private static native void poolThreadFinishedForTasks(long handle, long threadId, long[] taskIds);
- private static native void removeThreadAssociation(long handle, long threadId, long taskId);
- private static native void taskDone(long handle, long taskId);
- private static native void submittingToPool(long handle, long threadId);
- private static native void waitingOnPool(long handle, long threadId);
- private static native void doneWaitingOnPool(long handle, long threadId);
- private static native void forceRetryOOM(long handle, long threadId, int numOOMs, int oomMode, int skipCount);
- private static native void forceSplitAndRetryOOM(long handle, long threadId, int numOOMs, int oomMode, int skipCount);
- private static native void forceCudfException(long handle, long threadId, int numTimes);
- private static native void blockThreadUntilReady(long handle);
- private static native int getStateOf(long handle, long threadId);
- private static native int getAndResetRetryThrowInternal(long handle, long taskId);
- private static native int getAndResetSplitRetryThrowInternal(long handle, long taskId);
- private static native long getAndResetBlockTimeInternal(long handle, long taskId);
- private static native long getAndResetComputeTimeLostToRetry(long handle, long taskId);
- private static native long getAndResetGpuMaxMemoryAllocated(long handle, long taskId);
- private static native void startRetryBlock(long handle, long threadId);
- private static native void endRetryBlock(long handle, long threadId);
- private static native void checkAndBreakDeadlocks(long handle);
- private static native boolean preCpuAlloc(long handle, long amount, boolean blocking);
- private static native void postCpuAllocSuccess(long handle, long ptr, long amount,
- boolean blocking, boolean wasRecursive);
- private static native boolean postCpuAllocFailed(long handle, boolean wasOom,
- boolean blocking, boolean wasRecursive);
- private static native void cpuDeallocate(long handle, long ptr, long amount);
}
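
A minimal caller-side sketch (not part of this patch) of how the CPU-allocation hooks documented above fit together. It assumes an already-constructed SparkResourceAdaptor instance, uses only the methods visible in this diff, and doHostAlloc is a hypothetical stand-in for whatever allocator the caller really uses.

    // Hypothetical wiring of preCpuAlloc / postCpuAllocSuccess / postCpuAllocFailed / cpuDeallocate.
    static void trackedHostAlloc(SparkResourceAdaptor adaptor, long amount) {
      final boolean blocking = true;
      // Announce the upcoming allocation; this may throw an injected test exception.
      boolean wasRecursive = adaptor.preCpuAlloc(amount, blocking);
      long ptr = 0L;
      try {
        ptr = doHostAlloc(amount); // hypothetical allocator call, not part of this API
        adaptor.postCpuAllocSuccess(ptr, amount, blocking, wasRecursive);
      } catch (OutOfMemoryError oom) {
        // Report the failure; the return value indicates whether a retry might help.
        if (!adaptor.postCpuAllocFailed(true, blocking, wasRecursive)) {
          throw oom;
        }
      }
      if (ptr != 0L) {
        // Later, when the buffer is released:
        adaptor.cpuDeallocate(ptr, amount);
      }
    }
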
diff --git a/src/main/java/com/nvidia/spark/rapids/jni/ThreadStateRegistry.java b/src/main/java/com/nvidia/spark/rapids/jni/ThreadStateRegistry.java
index 4e7021e6ea..4de2396566 100644
--- a/src/main/java/com/nvidia/spark/rapids/jni/ThreadStateRegistry.java
+++ b/src/main/java/com/nvidia/spark/rapids/jni/ThreadStateRegistry.java
@@ -16,12 +16,10 @@
package com.nvidia.spark.rapids.jni;
+import java.util.HashMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
-import java.util.HashMap;
-import java.util.HashSet;
-
/**
* This is used to allow us to map a native thread id to a java thread so we can look at the
* state from a java perspective.
diff --git a/src/main/java/com/nvidia/spark/rapids/jni/ZOrder.java b/src/main/java/com/nvidia/spark/rapids/jni/ZOrder.java
index 62c5137e46..4f44ed65fd 100644
--- a/src/main/java/com/nvidia/spark/rapids/jni/ZOrder.java
+++ b/src/main/java/com/nvidia/spark/rapids/jni/ZOrder.java
@@ -17,7 +17,6 @@
package com.nvidia.spark.rapids.jni;
import ai.rapids.cudf.ColumnVector;
-import ai.rapids.cudf.DType;
import ai.rapids.cudf.NativeDepsLoader;
import ai.rapids.cudf.Scalar;
@@ -32,13 +31,14 @@ public class ZOrder {
* the InterleaveBits expression in DeltaLake. The input data should all be the same type and all
* fixed with types. In general if you want good clustering/ordering then these should all
* be positive integer values.
- * @param numRows the number of rows to output in a corner case where there are no input columns.
- * This should never happen in practice, but the expression supports this so we
- * should to.
+ *
+ * @param numRows the number of rows to output in a corner case where there are no input columns.
+ * This should never happen in practice, but the expression supports this so we
+ *                should too.
* @param inputColumns the data to process.
* @return a binary column of the interleaved data.
*/
- public static ColumnVector interleaveBits(int numRows, ColumnVector ... inputColumns) {
+ public static ColumnVector interleaveBits(int numRows, ColumnVector... inputColumns) {
if (inputColumns.length == 0) {
try (ColumnVector empty = ColumnVector.fromUnsignedBytes();
Scalar emptyList = Scalar.listFromColumnView(empty)) {
@@ -59,15 +59,15 @@ public static ColumnVector interleaveBits(int numRows, ColumnVector ... inputCol
* cluster the data that databricks uses, and that is why we have it here. Please note that
* this currently only supports indexes where numBits * inputColumns.length <= 64.
*
- * @param numBits the number of bits in the input columns to use. Typically, this is log2(max)
- * for the values in all the inputColumns.
- * @param numRows the number of rows. Used if inputColumns is empty. I think this is also a corner
- * case that can never happen in practice, but I am just covering my bases here.
- * a column of 0 is returned in this case.
+ * @param numBits the number of bits in the input columns to use. Typically, this is log2(max)
+ * for the values in all the inputColumns.
+ * @param numRows the number of rows. Used if inputColumns is empty. I think this is also a corner
+ * case that can never happen in practice, but I am just covering my bases here.
+ *                A column of 0 is returned in this case.
* @param inputColumns The columns to intermix.
* @return the corresponding indexes stored as long values.
*/
- public static ColumnVector hilbertIndex(int numBits, int numRows, ColumnVector ... inputColumns) {
+ public static ColumnVector hilbertIndex(int numBits, int numRows, ColumnVector... inputColumns) {
if (inputColumns.length == 0) {
try (Scalar zero = Scalar.fromLong(0)) {
return ColumnVector.fromScalar(zero, numRows);
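
A minimal usage sketch (not part of this patch) for the two public ZOrder entry points shown above. It assumes the cudf native libraries can be loaded and uses small positive INT32 inputs as the Javadoc recommends; with numBits = 3 and two columns, 3 * 2 = 6 bits stays well under the 64-bit limit noted for hilbertIndex.

    // Interleave two small integer columns and compute their Hilbert indexes.
    try (ColumnVector x = ColumnVector.fromInts(1, 2, 3);
         ColumnVector y = ColumnVector.fromInts(4, 5, 6);
         ColumnVector zOrdered = ZOrder.interleaveBits(3, x, y);
         ColumnVector hilbert = ZOrder.hilbertIndex(3, 3, x, y)) {
      // zOrdered holds the interleaved binary data; hilbert holds one long index per row.
    }
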
diff --git a/src/main/java/com/nvidia/spark/rapids/jni/kudo/ColumnOffsetInfo.java b/src/main/java/com/nvidia/spark/rapids/jni/kudo/ColumnOffsetInfo.java
index 3a46806a78..e22e2c2e7d 100644
--- a/src/main/java/com/nvidia/spark/rapids/jni/kudo/ColumnOffsetInfo.java
+++ b/src/main/java/com/nvidia/spark/rapids/jni/kudo/ColumnOffsetInfo.java
@@ -16,10 +16,10 @@
package com.nvidia.spark.rapids.jni.kudo;
-import ai.rapids.cudf.DeviceMemoryBufferView;
-
import static com.nvidia.spark.rapids.jni.Preconditions.ensureNonNegative;
+import ai.rapids.cudf.DeviceMemoryBufferView;
+
/**
* This class is used to store the offsets of the buffer of a column in the serialized data.
*/
@@ -32,7 +32,8 @@ class ColumnOffsetInfo {
private final long data;
private final long dataBufferLen;
- public ColumnOffsetInfo(long validity, long validityBufferLen, long offset, long offsetBufferLen, long data,
+ public ColumnOffsetInfo(long validity, long validityBufferLen, long offset, long offsetBufferLen,
+ long data,
long dataBufferLen) {
ensureNonNegative(validityBufferLen, "validityBuffeLen");
ensureNonNegative(offsetBufferLen, "offsetBufferLen");
@@ -47,6 +48,7 @@ public ColumnOffsetInfo(long validity, long validityBufferLen, long offset, long
/**
* Get the validity buffer offset.
+ *
* @return {@value #INVALID_OFFSET} if the validity buffer is not present, otherwise the offset.
*/
long getValidity() {
@@ -55,6 +57,7 @@ long getValidity() {
/**
* Get a view of the validity buffer from underlying buffer.
+ *
* @param baseAddress the base address of underlying buffer.
* @return null if the validity buffer is not present, otherwise a view of the buffer.
*/
@@ -67,6 +70,7 @@ DeviceMemoryBufferView getValidityBuffer(long baseAddress) {
/**
* Get the offset buffer offset.
+ *
* @return {@value #INVALID_OFFSET} if the offset buffer is not present, otherwise the offset.
*/
long getOffset() {
@@ -75,6 +79,7 @@ long getOffset() {
/**
* Get a view of the offset buffer from underlying buffer.
+ *
* @param baseAddress the base address of underlying buffer.
* @return null if the offset buffer is not present, otherwise a view of the buffer.
*/
@@ -87,6 +92,7 @@ DeviceMemoryBufferView getOffsetBuffer(long baseAddress) {
/**
* Get the data buffer offset.
+ *
* @return {@value #INVALID_OFFSET} if the data buffer is not present, otherwise the offset.
*/
long getData() {
@@ -95,6 +101,7 @@ long getData() {
/**
* Get a view of the data buffer from underlying buffer.
+ *
* @param baseAddress the base address of underlying buffer.
* @return null if the data buffer is not present, otherwise a view of the buffer.
*/
diff --git a/src/main/java/com/nvidia/spark/rapids/jni/kudo/ColumnViewInfo.java b/src/main/java/com/nvidia/spark/rapids/jni/kudo/ColumnViewInfo.java
index 002dff54c0..07f70664a1 100644
--- a/src/main/java/com/nvidia/spark/rapids/jni/kudo/ColumnViewInfo.java
+++ b/src/main/java/com/nvidia/spark/rapids/jni/kudo/ColumnViewInfo.java
@@ -16,12 +16,13 @@
package com.nvidia.spark.rapids.jni.kudo;
-import ai.rapids.cudf.*;
+import static com.nvidia.spark.rapids.jni.Preconditions.ensureNonNegative;
+import ai.rapids.cudf.ColumnView;
+import ai.rapids.cudf.DType;
+import ai.rapids.cudf.DeviceMemoryBuffer;
import java.util.Optional;
-import static com.nvidia.spark.rapids.jni.Preconditions.ensureNonNegative;
-
class ColumnViewInfo {
private final DType dtype;
private final ColumnOffsetInfo offsetInfo;
@@ -42,12 +43,12 @@ ColumnView buildColumnView(DeviceMemoryBuffer buffer, ColumnView[] childrenView)
long baseAddress = buffer.getAddress();
if (dtype.isNestedType()) {
- return new ColumnView(dtype, rowCount, Optional.of((long)nullCount),
+ return new ColumnView(dtype, rowCount, Optional.of((long) nullCount),
offsetInfo.getValidityBuffer(baseAddress),
offsetInfo.getOffsetBuffer(baseAddress),
childrenView);
} else {
- return new ColumnView(dtype, rowCount, Optional.of((long)nullCount),
+ return new ColumnView(dtype, rowCount, Optional.of((long) nullCount),
offsetInfo.getDataBuffer(baseAddress),
offsetInfo.getValidityBuffer(baseAddress),
offsetInfo.getOffsetBuffer(baseAddress));
diff --git a/src/main/java/com/nvidia/spark/rapids/jni/kudo/DataOutputStreamWriter.java b/src/main/java/com/nvidia/spark/rapids/jni/kudo/DataOutputStreamWriter.java
index c88f125b2e..d7449969f8 100644
--- a/src/main/java/com/nvidia/spark/rapids/jni/kudo/DataOutputStreamWriter.java
+++ b/src/main/java/com/nvidia/spark/rapids/jni/kudo/DataOutputStreamWriter.java
@@ -17,7 +17,6 @@
package com.nvidia.spark.rapids.jni.kudo;
import ai.rapids.cudf.HostMemoryBuffer;
-
import java.io.DataOutputStream;
import java.io.IOException;
diff --git a/src/main/java/com/nvidia/spark/rapids/jni/kudo/DataWriter.java b/src/main/java/com/nvidia/spark/rapids/jni/kudo/DataWriter.java
index 1f2e8f3dca..4bd4be814f 100644
--- a/src/main/java/com/nvidia/spark/rapids/jni/kudo/DataWriter.java
+++ b/src/main/java/com/nvidia/spark/rapids/jni/kudo/DataWriter.java
@@ -17,7 +17,6 @@
package com.nvidia.spark.rapids.jni.kudo;
import ai.rapids.cudf.HostMemoryBuffer;
-
import java.io.IOException;
/**
@@ -34,7 +33,8 @@ abstract class DataWriter {
* @param srcOffset offset to start at.
* @param len amount to copy.
*/
- public abstract void copyDataFrom(HostMemoryBuffer src, long srcOffset, long len) throws IOException;
+ public abstract void copyDataFrom(HostMemoryBuffer src, long srcOffset, long len)
+ throws IOException;
public void flush() throws IOException {
// NOOP by default
diff --git a/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoHostMergeResult.java b/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoHostMergeResult.java
index 6529f9e15e..38b5b12884 100644
--- a/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoHostMergeResult.java
+++ b/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoHostMergeResult.java
@@ -16,15 +16,17 @@
package com.nvidia.spark.rapids.jni.kudo;
-import ai.rapids.cudf.*;
-import com.nvidia.spark.rapids.jni.Arms;
-import com.nvidia.spark.rapids.jni.schema.Visitors;
-
-import java.util.List;
-
import static com.nvidia.spark.rapids.jni.Preconditions.ensure;
import static java.util.Objects.requireNonNull;
+import ai.rapids.cudf.Cuda;
+import ai.rapids.cudf.DeviceMemoryBuffer;
+import ai.rapids.cudf.HostMemoryBuffer;
+import ai.rapids.cudf.Schema;
+import ai.rapids.cudf.Table;
+import com.nvidia.spark.rapids.jni.schema.Visitors;
+import java.util.List;
+
/**
* The result of merging several kudo tables into one contiguous table on the host.
*/
@@ -33,11 +35,13 @@ public class KudoHostMergeResult implements AutoCloseable {
private final List columnInfoList;
private HostMemoryBuffer hostBuf;
- KudoHostMergeResult(Schema schema, HostMemoryBuffer hostBuf, List columnInfoList) {
+ KudoHostMergeResult(Schema schema, HostMemoryBuffer hostBuf,
+ List columnInfoList) {
requireNonNull(schema, "schema is null");
requireNonNull(columnInfoList, "columnInfoList is null");
ensure(schema.getFlattenedColumnNames().length == columnInfoList.size(), () ->
- "Column offsets size does not match flattened schema size, column offsets size: " + columnInfoList.size() +
+ "Column offsets size does not match flattened schema size, column offsets size: " +
+ columnInfoList.size() +
", flattened schema size: " + schema.getFlattenedColumnNames().length);
this.schema = schema;
this.columnInfoList = columnInfoList;
@@ -52,6 +56,7 @@ public void close() throws Exception {
/**
* Get the length of the data in the host buffer.
+ *
* @return the length of the data in the host buffer
*/
public long getDataLength() {
@@ -60,6 +65,7 @@ public long getDataLength() {
/**
* Convert the host buffer into a cudf table.
+ *
* @return the cudf table
*/
public Table toTable() {
diff --git a/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoSerializer.java b/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoSerializer.java
index 6370531428..4135776363 100644
--- a/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoSerializer.java
+++ b/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoSerializer.java
@@ -16,20 +16,29 @@
package com.nvidia.spark.rapids.jni.kudo;
-import ai.rapids.cudf.*;
+import static com.nvidia.spark.rapids.jni.Preconditions.ensure;
+import static java.util.Objects.requireNonNull;
+
+import ai.rapids.cudf.BufferType;
+import ai.rapids.cudf.Cuda;
+import ai.rapids.cudf.HostColumnVector;
+import ai.rapids.cudf.JCudfSerialization;
+import ai.rapids.cudf.Schema;
+import ai.rapids.cudf.Table;
import com.nvidia.spark.rapids.jni.Pair;
import com.nvidia.spark.rapids.jni.schema.Visitors;
-
-import java.io.*;
+import java.io.BufferedOutputStream;
+import java.io.DataInputStream;
+import java.io.DataOutputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
import java.util.Arrays;
import java.util.List;
import java.util.function.LongConsumer;
import java.util.function.Supplier;
import java.util.stream.IntStream;
-import static com.nvidia.spark.rapids.jni.Preconditions.ensure;
-import static java.util.Objects.requireNonNull;
-
/**
* This class is used to serialize/deserialize a table using the Kudo format.
*
@@ -148,8 +157,9 @@
public class KudoSerializer {
private static final byte[] PADDING = new byte[64];
- private static final BufferType[] ALL_BUFFER_TYPES = new BufferType[]{BufferType.VALIDITY, BufferType.OFFSET,
- BufferType.DATA};
+ private static final BufferType[] ALL_BUFFER_TYPES =
+ new BufferType[] {BufferType.VALIDITY, BufferType.OFFSET,
+ BufferType.DATA};
static {
Arrays.fill(PADDING, (byte) 0);
@@ -164,6 +174,75 @@ public KudoSerializer(Schema schema) {
this.flattenedColumnCount = schema.getFlattenedColumnNames().length;
}
+ /**
+ * Write a row count only record to an output stream.
+ *
+ * @param out output stream
+ * @param numRows number of rows to write
+ * @return number of bytes written
+ */
+ public static long writeRowCountToStream(OutputStream out, int numRows) {
+ if (numRows <= 0) {
+ throw new IllegalArgumentException("Number of rows must be > 0, but was " + numRows);
+ }
+ try {
+ DataWriter writer = writerFrom(out);
+ KudoTableHeader header = new KudoTableHeader(0, numRows, 0, 0, 0, 0, new byte[0]);
+ header.writeTo(writer);
+ writer.flush();
+ return header.getSerializedSize();
+ } catch (Exception e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ private static DataWriter writerFrom(OutputStream out) {
+ if (!(out instanceof DataOutputStream)) {
+ out = new DataOutputStream(new BufferedOutputStream(out));
+ }
+ return new DataOutputStreamWriter((DataOutputStream) out);
+ }
+
+ static long padForHostAlignment(long orig) {
+ return ((orig + 3) / 4) * 4;
+ }
+
+ static long padForHostAlignment(DataWriter out, long bytes) throws IOException {
+ final long paddedBytes = padForHostAlignment(bytes);
+ if (paddedBytes > bytes) {
+ out.write(PADDING, 0, (int) (paddedBytes - bytes));
+ }
+ return paddedBytes;
+ }
+
+ static long padFor64byteAlignment(long orig) {
+ return ((orig + 63) / 64) * 64;
+ }
+
+ static DataInputStream readerFrom(InputStream in) {
+ if (in instanceof DataInputStream) {
+ return (DataInputStream) in;
+ }
+ return new DataInputStream(in);
+ }
+
+ static T withTime(Supplier task, LongConsumer timeConsumer) {
+ long now = System.nanoTime();
+ T ret = task.get();
+ timeConsumer.accept(System.nanoTime() - now);
+ return ret;
+ }
+
+ /**
+ * This method returns the length in bytes needed to represent X number of rows
+ * e.g. getValidityLengthInBytes(5) => 1 byte
+ * getValidityLengthInBytes(7) => 1 byte
+ * getValidityLengthInBytes(14) => 2 bytes
+ */
+ static long getValidityLengthInBytes(long rows) {
+ return (rows + 7) / 8;
+ }
+
/**
* Write partition of a table to a stream. This method is used for test only.
*
@@ -208,7 +287,8 @@ long writeToStream(Table table, OutputStream out, int rowOffset, int numRows) {
* @param numRows number of rows to write
* @return number of bytes written
*/
- public long writeToStream(HostColumnVector[] columns, OutputStream out, int rowOffset, int numRows) {
+ public long writeToStream(HostColumnVector[] columns, OutputStream out, int rowOffset,
+ int numRows) {
ensure(numRows > 0, () -> "numRows must be > 0, but was " + numRows);
ensure(columns.length > 0, () -> "columns must not be empty, for row count only records " +
"please call writeRowCountToStream");
@@ -220,29 +300,6 @@ public long writeToStream(HostColumnVector[] columns, OutputStream out, int rowO
}
}
- /**
- * Write a row count only record to an output stream.
- *
- * @param out output stream
- * @param numRows number of rows to write
- * @return number of bytes written
- */
- public static long writeRowCountToStream(OutputStream out, int numRows) {
- if (numRows <= 0) {
- throw new IllegalArgumentException("Number of rows must be > 0, but was " + numRows);
- }
- try {
- DataWriter writer = writerFrom(out);
- KudoTableHeader header = new KudoTableHeader(0, numRows, 0, 0, 0
- , 0, new byte[0]);
- header.writeTo(writer);
- writer.flush();
- return header.getSerializedSize();
- } catch (Exception e) {
- throw new RuntimeException(e);
- }
- }
-
/**
* Merge a list of kudo tables into a table on host memory.
*
@@ -286,15 +343,18 @@ public Pair
mergeToTable(List kudoTables) throws
}
}
- private long writeSliced(HostColumnVector[] columns, DataWriter out, int rowOffset, int numRows) throws Exception {
- KudoTableHeaderCalc headerCalc = new KudoTableHeaderCalc(rowOffset, numRows, flattenedColumnCount);
+ private long writeSliced(HostColumnVector[] columns, DataWriter out, int rowOffset, int numRows)
+ throws Exception {
+ KudoTableHeaderCalc headerCalc =
+ new KudoTableHeaderCalc(rowOffset, numRows, flattenedColumnCount);
Visitors.visitColumns(columns, headerCalc);
KudoTableHeader header = headerCalc.getHeader();
header.writeTo(out);
long bytesWritten = 0;
for (BufferType bufferType : ALL_BUFFER_TYPES) {
- SlicedBufferSerializer serializer = new SlicedBufferSerializer(rowOffset, numRows, bufferType, out);
+ SlicedBufferSerializer serializer =
+ new SlicedBufferSerializer(rowOffset, numRows, bufferType, out);
Visitors.visitColumns(columns, serializer);
bytesWritten += serializer.getTotalDataLen();
}
@@ -309,52 +369,4 @@ private long writeSliced(HostColumnVector[] columns, DataWriter out, int rowOffs
return header.getSerializedSize() + bytesWritten;
}
-
- private static DataWriter writerFrom(OutputStream out) {
- if (!(out instanceof DataOutputStream)) {
- out = new DataOutputStream(new BufferedOutputStream(out));
- }
- return new DataOutputStreamWriter((DataOutputStream) out);
- }
-
-
- static long padForHostAlignment(long orig) {
- return ((orig + 3) / 4) * 4;
- }
-
- static long padForHostAlignment(DataWriter out, long bytes) throws IOException {
- final long paddedBytes = padForHostAlignment(bytes);
- if (paddedBytes > bytes) {
- out.write(PADDING, 0, (int) (paddedBytes - bytes));
- }
- return paddedBytes;
- }
-
- static long padFor64byteAlignment(long orig) {
- return ((orig + 63) / 64) * 64;
- }
-
- static DataInputStream readerFrom(InputStream in) {
- if (in instanceof DataInputStream) {
- return (DataInputStream) in;
- }
- return new DataInputStream(in);
- }
-
- static T withTime(Supplier task, LongConsumer timeConsumer) {
- long now = System.nanoTime();
- T ret = task.get();
- timeConsumer.accept(System.nanoTime() - now);
- return ret;
- }
-
- /**
- * This method returns the length in bytes needed to represent X number of rows
- * e.g. getValidityLengthInBytes(5) => 1 byte
- * getValidityLengthInBytes(7) => 1 byte
- * getValidityLengthInBytes(14) => 2 bytes
- */
- static long getValidityLengthInBytes(long rows) {
- return (rows + 7) / 8;
- }
}
\ No newline at end of file
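
The alignment and validity-length helpers moved above are pure arithmetic. The standalone sketch below (not part of this patch) reproduces the same formulas locally, since the real methods are package-private, and prints a few worked values.

    // Illustrative copies of the KudoSerializer padding/validity math.
    public class KudoMathSketch {
      // Round up to the next multiple of 4 bytes (padForHostAlignment).
      static long padForHostAlignment(long orig) {
        return ((orig + 3) / 4) * 4;
      }

      // Round up to the next multiple of 64 bytes (padFor64byteAlignment).
      static long padFor64byteAlignment(long orig) {
        return ((orig + 63) / 64) * 64;
      }

      // One validity bit per row, rounded up to whole bytes (getValidityLengthInBytes).
      static long getValidityLengthInBytes(long rows) {
        return (rows + 7) / 8;
      }

      public static void main(String[] args) {
        System.out.println(padForHostAlignment(13));      // 16
        System.out.println(padFor64byteAlignment(100));   // 128
        System.out.println(getValidityLengthInBytes(14)); // 2
      }
    }
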
diff --git a/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoTable.java b/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoTable.java
index c49b2cb8f7..d5b86070f2 100644
--- a/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoTable.java
+++ b/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoTable.java
@@ -16,17 +16,16 @@
package com.nvidia.spark.rapids.jni.kudo;
+import static com.nvidia.spark.rapids.jni.kudo.KudoSerializer.readerFrom;
+import static java.util.Objects.requireNonNull;
+
import ai.rapids.cudf.HostMemoryBuffer;
import com.nvidia.spark.rapids.jni.Arms;
-
import java.io.DataInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.Optional;
-import static com.nvidia.spark.rapids.jni.kudo.KudoSerializer.readerFrom;
-import static java.util.Objects.requireNonNull;
-
/**
* Serialized table in kudo format, including a {{@link KudoTableHeader}} and a {@link HostMemoryBuffer} for serialized
* data.
@@ -65,14 +64,15 @@ public static Optional from(InputStream in) throws IOException {
return new KudoTable(header, null);
}
- return Arms.closeIfException(HostMemoryBuffer.allocate(header.getTotalDataLen(), false), buffer -> {
- try {
- buffer.copyFromStream(0, din, header.getTotalDataLen());
- return new KudoTable(header, buffer);
- } catch (IOException e) {
- throw new RuntimeException(e);
- }
- });
+ return Arms.closeIfException(HostMemoryBuffer.allocate(header.getTotalDataLen(), false),
+ buffer -> {
+ try {
+ buffer.copyFromStream(0, din, header.getTotalDataLen());
+ return new KudoTable(header, buffer);
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ });
});
}
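
A minimal round-trip sketch (not part of this patch) combining the public KudoSerializer.writeRowCountToStream entry point from the previous file with KudoTable.from shown above. It assumes imports of java.io.* and calls only methods visible in this diff; the Optional is empty once the stream is exhausted.

    // Write a header-only (row-count) record and read it back.
    static void rowCountRoundTrip() throws IOException {
      ByteArrayOutputStream bout = new ByteArrayOutputStream();
      long headerBytes = KudoSerializer.writeRowCountToStream(bout, 128);
      try (InputStream bin = new ByteArrayInputStream(bout.toByteArray())) {
        KudoTable.from(bin).ifPresent(table ->
            System.out.println("read back a header-only table of " + headerBytes + " bytes"));
      }
    }
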
diff --git a/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoTableHeader.java b/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoTableHeader.java
index 2bf5449c7a..9537ea679d 100644
--- a/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoTableHeader.java
+++ b/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoTableHeader.java
@@ -16,16 +16,16 @@
package com.nvidia.spark.rapids.jni.kudo;
+import static com.nvidia.spark.rapids.jni.Preconditions.ensure;
+import static com.nvidia.spark.rapids.jni.Preconditions.ensureNonNegative;
+import static java.util.Objects.requireNonNull;
+
import java.io.DataInputStream;
import java.io.EOFException;
import java.io.IOException;
import java.util.Arrays;
import java.util.Optional;
-import static com.nvidia.spark.rapids.jni.Preconditions.ensure;
-import static com.nvidia.spark.rapids.jni.Preconditions.ensureNonNegative;
-import static java.util.Objects.requireNonNull;
-
/**
* Holds the metadata about a serialized table. If this is being read from a stream
* isInitialized will return true if the metadata was read correctly from the stream.
@@ -49,6 +49,23 @@ public final class KudoTableHeader {
// A bit set to indicate if a column has a validity buffer or not. Each column is represented by a single bit.
private final byte[] hasValidityBuffer;
+ KudoTableHeader(int offset, int numRows, int validityBufferLen, int offsetBufferLen,
+ int totalDataLen, int numColumns, byte[] hasValidityBuffer) {
+ this.offset = ensureNonNegative(offset, "offset");
+ this.numRows = ensureNonNegative(numRows, "numRows");
+ this.validityBufferLen = ensureNonNegative(validityBufferLen, "validityBufferLen");
+ this.offsetBufferLen = ensureNonNegative(offsetBufferLen, "offsetBufferLen");
+ this.totalDataLen = ensureNonNegative(totalDataLen, "totalDataLen");
+ this.numColumns = ensureNonNegative(numColumns, "numColumns");
+
+ requireNonNull(hasValidityBuffer, "hasValidityBuffer cannot be null");
+ ensure(hasValidityBuffer.length == lengthOfHasValidityBuffer(numColumns),
+ () -> numColumns + " columns expects hasValidityBuffer with length " +
+ lengthOfHasValidityBuffer(numColumns) +
+ ", but found " + hasValidityBuffer.length);
+ this.hasValidityBuffer = hasValidityBuffer;
+ }
+
/**
* Reads the table header from the given input stream.
*
@@ -61,8 +78,9 @@ public static Optional readFrom(DataInputStream din) throws IOE
try {
num = din.readInt();
if (num != SER_FORMAT_MAGIC_NUMBER) {
- throw new IllegalStateException("Kudo format error, expected magic number " + SER_FORMAT_MAGIC_NUMBER +
- " found " + num);
+ throw new IllegalStateException(
+ "Kudo format error, expected magic number " + SER_FORMAT_MAGIC_NUMBER +
+ " found " + num);
}
} catch (EOFException e) {
// If we get an EOF at the very beginning don't treat it as an error because we may
@@ -81,24 +99,14 @@ public static Optional readFrom(DataInputStream din) throws IOE
byte[] hasValidityBuffer = new byte[validityBufferLength];
din.readFully(hasValidityBuffer);
- return Optional.of(new KudoTableHeader(offset, numRows, validityBufferLen, offsetBufferLen, totalDataLen, numColumns,
- hasValidityBuffer));
+ return Optional.of(
+ new KudoTableHeader(offset, numRows, validityBufferLen, offsetBufferLen, totalDataLen,
+ numColumns,
+ hasValidityBuffer));
}
- KudoTableHeader(int offset, int numRows, int validityBufferLen, int offsetBufferLen,
- int totalDataLen, int numColumns, byte[] hasValidityBuffer) {
- this.offset = ensureNonNegative(offset, "offset");
- this.numRows = ensureNonNegative(numRows, "numRows");
- this.validityBufferLen = ensureNonNegative(validityBufferLen, "validityBufferLen");
- this.offsetBufferLen = ensureNonNegative(offsetBufferLen, "offsetBufferLen");
- this.totalDataLen = ensureNonNegative(totalDataLen, "totalDataLen");
- this.numColumns = ensureNonNegative(numColumns, "numColumns");
-
- requireNonNull(hasValidityBuffer, "hasValidityBuffer cannot be null");
- ensure(hasValidityBuffer.length == lengthOfHasValidityBuffer(numColumns),
- () -> numColumns + " columns expects hasValidityBuffer with length " + lengthOfHasValidityBuffer(numColumns) +
- ", but found " + hasValidityBuffer.length);
- this.hasValidityBuffer = hasValidityBuffer;
+ private static int lengthOfHasValidityBuffer(int numColumns) {
+ return (numColumns + 7) / 8;
}
/**
@@ -187,8 +195,4 @@ public String toString() {
", hasValidityBuffer=" + Arrays.toString(hasValidityBuffer) +
'}';
}
-
- private static int lengthOfHasValidityBuffer(int numColumns) {
- return (numColumns + 7) / 8;
- }
}
diff --git a/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoTableHeaderCalc.java b/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoTableHeaderCalc.java
index 4eaa1c435c..15119f3995 100644
--- a/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoTableHeaderCalc.java
+++ b/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoTableHeaderCalc.java
@@ -16,17 +16,16 @@
package com.nvidia.spark.rapids.jni.kudo;
+import static com.nvidia.spark.rapids.jni.kudo.KudoSerializer.padForHostAlignment;
+import static java.lang.Math.toIntExact;
+
import ai.rapids.cudf.DType;
import ai.rapids.cudf.HostColumnVectorCore;
import com.nvidia.spark.rapids.jni.schema.HostColumnsVisitor;
-
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.List;
-import static com.nvidia.spark.rapids.jni.kudo.KudoSerializer.padForHostAlignment;
-import static java.lang.Math.toIntExact;
-
/**
* This class visits a list of columns and calculates the serialized table header.
*
@@ -44,7 +43,7 @@ class KudoTableHeaderCalc implements HostColumnsVisitor {
private long totalDataLen;
private int nextColIdx;
- private Deque sliceInfos = new ArrayDeque<>();
+ private final Deque sliceInfos = new ArrayDeque<>();
KudoTableHeaderCalc(int rowOffset, int numRows, int numFlattenedCols) {
this.root = new SliceInfo(rowOffset, numRows);
@@ -55,6 +54,41 @@ class KudoTableHeaderCalc implements HostColumnsVisitor {
this.nextColIdx = 0;
}
+ private static long dataLenOfValidityBuffer(HostColumnVectorCore col, SliceInfo info) {
+ if (col.hasValidityVector() && info.getRowCount() > 0) {
+ return padForHostAlignment(info.getValidityBufferInfo().getBufferLength());
+ } else {
+ return 0;
+ }
+ }
+
+ private static long dataLenOfOffsetBuffer(HostColumnVectorCore col, SliceInfo info) {
+ if (DType.STRING.equals(col.getType()) && info.getRowCount() > 0) {
+ return padForHostAlignment((long) (info.rowCount + 1) * Integer.BYTES);
+ } else {
+ return 0;
+ }
+ }
+
+ private static long dataLenOfDataBuffer(HostColumnVectorCore col, SliceInfo info) {
+ if (DType.STRING.equals(col.getType())) {
+ if (col.getOffsets() != null) {
+ long startByteOffset = col.getOffsets().getInt((long) info.offset * Integer.BYTES);
+ long endByteOffset = col.getOffsets().getInt(
+ (long) (info.offset + info.rowCount) * Integer.BYTES);
+ return padForHostAlignment(endByteOffset - startByteOffset);
+ } else {
+ return 0;
+ }
+ } else {
+ if (col.getType().getSizeInBytes() > 0) {
+ return padForHostAlignment((long) col.getType().getSizeInBytes() * info.rowCount);
+ } else {
+ return 0;
+ }
+ }
+ }
+
public KudoTableHeader getHeader() {
return new KudoTableHeader(toIntExact(root.offset),
toIntExact(root.rowCount),
@@ -93,7 +127,7 @@ public Void preVisitList(HostColumnVectorCore col) {
long offsetBufferLength = 0;
if (col.getOffsets() != null && parent.rowCount > 0) {
- offsetBufferLength = padForHostAlignment((parent.rowCount + 1) * Integer.BYTES);
+ offsetBufferLength = padForHostAlignment((long) (parent.rowCount + 1) * Integer.BYTES);
}
this.validityBufferLen += validityBufferLength;
@@ -105,8 +139,8 @@ public Void preVisitList(HostColumnVectorCore col) {
SliceInfo current;
if (col.getOffsets() != null) {
- int start = col.getOffsets().getInt(parent.offset * Integer.BYTES);
- int end = col.getOffsets().getInt((parent.offset + parent.rowCount) * Integer.BYTES);
+ int start = col.getOffsets().getInt((long) parent.offset * Integer.BYTES);
+ int end = col.getOffsets().getInt((long) (parent.offset + parent.rowCount) * Integer.BYTES);
int rowCount = end - start;
current = new SliceInfo(start, rowCount);
} else {
@@ -124,7 +158,6 @@ public Void visitList(HostColumnVectorCore col, Void preVisitResult, Void childR
return null;
}
-
@Override
public Void visit(HostColumnVectorCore col) {
SliceInfo parent = sliceInfos.peekLast();
@@ -149,38 +182,4 @@ private void setHasValidity(boolean hasValidityBuffer) {
}
nextColIdx++;
}
-
- private static long dataLenOfValidityBuffer(HostColumnVectorCore col, SliceInfo info) {
- if (col.hasValidityVector() && info.getRowCount() > 0) {
- return padForHostAlignment(info.getValidityBufferInfo().getBufferLength());
- } else {
- return 0;
- }
- }
-
- private static long dataLenOfOffsetBuffer(HostColumnVectorCore col, SliceInfo info) {
- if (DType.STRING.equals(col.getType()) && info.getRowCount() > 0) {
- return padForHostAlignment((info.rowCount + 1) * Integer.BYTES);
- } else {
- return 0;
- }
- }
-
- private static long dataLenOfDataBuffer(HostColumnVectorCore col, SliceInfo info) {
- if (DType.STRING.equals(col.getType())) {
- if (col.getOffsets() != null) {
- long startByteOffset = col.getOffsets().getInt(info.offset * Integer.BYTES);
- long endByteOffset = col.getOffsets().getInt((info.offset + info.rowCount) * Integer.BYTES);
- return padForHostAlignment(endByteOffset - startByteOffset);
- } else {
- return 0;
- }
- } else {
- if (col.getType().getSizeInBytes() > 0) {
- return padForHostAlignment(col.getType().getSizeInBytes() * info.rowCount);
- } else {
- return 0;
- }
- }
- }
}
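
Several hunks in this file (and those that follow) cast an int operand to long before multiplying by Integer.BYTES or a type size. The standalone sketch below (not part of this patch) shows the int overflow those casts guard against.

    // Why the offset math is widened to long before multiplying.
    public class OverflowSketch {
      public static void main(String[] args) {
        int rowCount = 600_000_000;                           // a plausible large slice
        int overflowed = (rowCount + 1) * Integer.BYTES;      // wraps to a negative int
        long widened = (long) (rowCount + 1) * Integer.BYTES; // exact result
        System.out.println(overflowed); // -1894967292
        System.out.println(widened);    // 2400000004
      }
    }
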
diff --git a/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoTableMerger.java b/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoTableMerger.java
index af80391f3d..7f3b50e960 100644
--- a/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoTableMerger.java
+++ b/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoTableMerger.java
@@ -16,25 +16,24 @@
package com.nvidia.spark.rapids.jni.kudo;
+import static com.nvidia.spark.rapids.jni.Preconditions.ensure;
+import static com.nvidia.spark.rapids.jni.kudo.ColumnOffsetInfo.INVALID_OFFSET;
+import static com.nvidia.spark.rapids.jni.kudo.KudoSerializer.getValidityLengthInBytes;
+import static com.nvidia.spark.rapids.jni.kudo.KudoSerializer.padFor64byteAlignment;
+import static java.lang.Math.min;
+import static java.lang.Math.toIntExact;
+import static java.util.Objects.requireNonNull;
+
import ai.rapids.cudf.HostMemoryBuffer;
import ai.rapids.cudf.Schema;
import com.nvidia.spark.rapids.jni.Arms;
import com.nvidia.spark.rapids.jni.schema.Visitors;
-
import java.nio.ByteOrder;
import java.nio.IntBuffer;
import java.util.ArrayList;
import java.util.List;
import java.util.OptionalInt;
-import static com.nvidia.spark.rapids.jni.Preconditions.ensure;
-import static com.nvidia.spark.rapids.jni.kudo.ColumnOffsetInfo.INVALID_OFFSET;
-import static com.nvidia.spark.rapids.jni.kudo.KudoSerializer.getValidityLengthInBytes;
-import static com.nvidia.spark.rapids.jni.kudo.KudoSerializer.padFor64byteAlignment;
-import static java.lang.Math.min;
-import static java.lang.Math.toIntExact;
-import static java.util.Objects.requireNonNull;
-
/**
* This class is used to merge multiple KudoTables into a single contiguous buffer, e.g. {@link KudoHostMergeResult},
* which could be easily converted to a {@link ai.rapids.cudf.ContiguousTable}.
@@ -59,7 +58,8 @@ class KudoTableMerger extends MultiKudoTableVisitor colViewInfoList;
- public KudoTableMerger(List tables, HostMemoryBuffer buffer, List columnOffsets) {
+ public KudoTableMerger(List tables, HostMemoryBuffer buffer,
+ List columnOffsets) {
super(tables);
requireNonNull(buffer, "buffer can't be null!");
ensure(columnOffsets != null, "column offsets cannot be null");
@@ -69,83 +69,6 @@ public KudoTableMerger(List tables, HostMemoryBuffer buffer, List
(columnOffsets.size());
}
- @Override
- protected KudoHostMergeResult doVisitTopSchema(Schema schema, List children) {
- return new KudoHostMergeResult(schema, buffer, colViewInfoList);
- }
-
- @Override
- protected Void doVisitStruct(Schema structType, List children) {
- ColumnOffsetInfo offsetInfo = getCurColumnOffsets();
- int nullCount = deserializeValidityBuffer(offsetInfo);
- int totalRowCount = getTotalRowCount();
- colViewInfoList.add(new ColumnViewInfo(structType.getType(),
- offsetInfo, nullCount, totalRowCount));
- return null;
- }
-
- @Override
- protected Void doPreVisitList(Schema listType) {
- ColumnOffsetInfo offsetInfo = getCurColumnOffsets();
- int nullCount = deserializeValidityBuffer(offsetInfo);
- int totalRowCount = getTotalRowCount();
- deserializeOffsetBuffer(offsetInfo);
-
- colViewInfoList.add(new ColumnViewInfo(listType.getType(),
- offsetInfo, nullCount, totalRowCount));
- return null;
- }
-
- @Override
- protected Void doVisitList(Schema listType, Void preVisitResult, Void childResult) {
- return null;
- }
-
- @Override
- protected Void doVisit(Schema primitiveType) {
- ColumnOffsetInfo offsetInfo = getCurColumnOffsets();
- int nullCount = deserializeValidityBuffer(offsetInfo);
- int totalRowCount = getTotalRowCount();
- if (primitiveType.getType().hasOffsets()) {
- deserializeOffsetBuffer(offsetInfo);
- deserializeDataBuffer(offsetInfo, OptionalInt.empty());
- } else {
- deserializeDataBuffer(offsetInfo, OptionalInt.of(primitiveType.getType().getSizeInBytes()));
- }
-
- colViewInfoList.add(new ColumnViewInfo(primitiveType.getType(),
- offsetInfo, nullCount, totalRowCount));
-
- return null;
- }
-
- private int deserializeValidityBuffer(ColumnOffsetInfo curColOffset) {
- if (curColOffset.getValidity() != INVALID_OFFSET) {
- long offset = curColOffset.getValidity();
- long validityBufferSize = padFor64byteAlignment(getValidityLengthInBytes(getTotalRowCount()));
- try (HostMemoryBuffer validityBuffer = buffer.slice(offset, validityBufferSize)) {
- int nullCountTotal = 0;
- int startRow = 0;
- for (int tableIdx = 0; tableIdx < getTableSize(); tableIdx += 1) {
- SliceInfo sliceInfo = sliceInfoOf(tableIdx);
- long validityOffset = validifyBufferOffset(tableIdx);
- if (validityOffset != INVALID_OFFSET) {
- nullCountTotal += copyValidityBuffer(validityBuffer, startRow,
- memoryBufferOf(tableIdx), toIntExact(validityOffset),
- sliceInfo);
- } else {
- appendAllValid(validityBuffer, startRow, sliceInfo.getRowCount());
- }
-
- startRow += sliceInfo.getRowCount();
- }
- return nullCountTotal;
- }
- } else {
- return 0;
- }
- }
-
/**
* Copy a sliced validity buffer to the destination buffer, starting at the given bit offset.
*
@@ -251,10 +174,97 @@ private static void appendAllValid(HostMemoryBuffer dest, int startBit, int numR
}
}
+ static KudoHostMergeResult merge(Schema schema, MergedInfoCalc mergedInfo) {
+ List serializedTables = mergedInfo.getTables();
+ return Arms.closeIfException(HostMemoryBuffer.allocate(mergedInfo.getTotalDataLen()),
+ buffer -> {
+ KudoTableMerger merger =
+ new KudoTableMerger(serializedTables, buffer, mergedInfo.getColumnOffsets());
+ return Visitors.visitSchema(schema, merger);
+ });
+ }
+
+ @Override
+ protected KudoHostMergeResult doVisitTopSchema(Schema schema, List children) {
+ return new KudoHostMergeResult(schema, buffer, colViewInfoList);
+ }
+
+ @Override
+ protected Void doVisitStruct(Schema structType, List children) {
+ ColumnOffsetInfo offsetInfo = getCurColumnOffsets();
+ int nullCount = deserializeValidityBuffer(offsetInfo);
+ int totalRowCount = getTotalRowCount();
+ colViewInfoList.add(new ColumnViewInfo(structType.getType(),
+ offsetInfo, nullCount, totalRowCount));
+ return null;
+ }
+
+ @Override
+ protected Void doPreVisitList(Schema listType) {
+ ColumnOffsetInfo offsetInfo = getCurColumnOffsets();
+ int nullCount = deserializeValidityBuffer(offsetInfo);
+ int totalRowCount = getTotalRowCount();
+ deserializeOffsetBuffer(offsetInfo);
+
+ colViewInfoList.add(new ColumnViewInfo(listType.getType(),
+ offsetInfo, nullCount, totalRowCount));
+ return null;
+ }
+
+ @Override
+ protected Void doVisitList(Schema listType, Void preVisitResult, Void childResult) {
+ return null;
+ }
+
+ @Override
+ protected Void doVisit(Schema primitiveType) {
+ ColumnOffsetInfo offsetInfo = getCurColumnOffsets();
+ int nullCount = deserializeValidityBuffer(offsetInfo);
+ int totalRowCount = getTotalRowCount();
+ if (primitiveType.getType().hasOffsets()) {
+ deserializeOffsetBuffer(offsetInfo);
+ deserializeDataBuffer(offsetInfo, OptionalInt.empty());
+ } else {
+ deserializeDataBuffer(offsetInfo, OptionalInt.of(primitiveType.getType().getSizeInBytes()));
+ }
+
+ colViewInfoList.add(new ColumnViewInfo(primitiveType.getType(),
+ offsetInfo, nullCount, totalRowCount));
+
+ return null;
+ }
+
+ private int deserializeValidityBuffer(ColumnOffsetInfo curColOffset) {
+ if (curColOffset.getValidity() != INVALID_OFFSET) {
+ long offset = curColOffset.getValidity();
+ long validityBufferSize = padFor64byteAlignment(getValidityLengthInBytes(getTotalRowCount()));
+ try (HostMemoryBuffer validityBuffer = buffer.slice(offset, validityBufferSize)) {
+ int nullCountTotal = 0;
+ int startRow = 0;
+ for (int tableIdx = 0; tableIdx < getTableSize(); tableIdx += 1) {
+ SliceInfo sliceInfo = sliceInfoOf(tableIdx);
+ long validityOffset = validifyBufferOffset(tableIdx);
+ if (validityOffset != INVALID_OFFSET) {
+ nullCountTotal += copyValidityBuffer(validityBuffer, startRow,
+ memoryBufferOf(tableIdx), toIntExact(validityOffset),
+ sliceInfo);
+ } else {
+ appendAllValid(validityBuffer, startRow, sliceInfo.getRowCount());
+ }
+
+ startRow += sliceInfo.getRowCount();
+ }
+ return nullCountTotal;
+ }
+ } else {
+ return 0;
+ }
+ }
+
private void deserializeOffsetBuffer(ColumnOffsetInfo curColOffset) {
if (curColOffset.getOffset() != INVALID_OFFSET) {
long offset = curColOffset.getOffset();
- long bufferSize = Integer.BYTES * (getTotalRowCount() + 1);
+ long bufferSize = (long) Integer.BYTES * (getTotalRowCount() + 1);
IntBuffer buf = buffer
.asByteBuffer(offset, toIntExact(bufferSize))
@@ -298,7 +308,7 @@ private void deserializeDataBuffer(ColumnOffsetInfo curColOffset, OptionalInt si
for (int tableIdx = 0; tableIdx < getTableSize(); tableIdx += 1) {
SliceInfo sliceInfo = sliceInfoOf(tableIdx);
if (sliceInfo.getRowCount() > 0) {
- int thisDataLen = toIntExact(elementSize * sliceInfo.getRowCount());
+ int thisDataLen = toIntExact((long) elementSize * sliceInfo.getRowCount());
copyDataBuffer(buf, start, tableIdx, thisDataLen);
start += thisDataLen;
}
@@ -316,17 +326,7 @@ private void deserializeDataBuffer(ColumnOffsetInfo curColOffset, OptionalInt si
}
}
-
private ColumnOffsetInfo getCurColumnOffsets() {
return columnOffsets.get(getCurrentIdx());
}
-
- static KudoHostMergeResult merge(Schema schema, MergedInfoCalc mergedInfo) {
- List serializedTables = mergedInfo.getTables();
- return Arms.closeIfException(HostMemoryBuffer.allocate(mergedInfo.getTotalDataLen()),
- buffer -> {
- KudoTableMerger merger = new KudoTableMerger(serializedTables, buffer, mergedInfo.getColumnOffsets());
- return Visitors.visitSchema(schema, merger);
- });
- }
}
diff --git a/src/main/java/com/nvidia/spark/rapids/jni/kudo/MergeMetrics.java b/src/main/java/com/nvidia/spark/rapids/jni/kudo/MergeMetrics.java
index e621129dd6..8a04face16 100644
--- a/src/main/java/com/nvidia/spark/rapids/jni/kudo/MergeMetrics.java
+++ b/src/main/java/com/nvidia/spark/rapids/jni/kudo/MergeMetrics.java
@@ -31,18 +31,6 @@ public MergeMetrics(long calcHeaderTime, long mergeIntoHostBufferTime,
this.convertToTableTime = convertToTableTime;
}
- public long getCalcHeaderTime() {
- return calcHeaderTime;
- }
-
- public long getMergeIntoHostBufferTime() {
- return mergeIntoHostBufferTime;
- }
-
- public long getConvertToTableTime() {
- return convertToTableTime;
- }
-
public static Builder builder() {
return new Builder();
}
@@ -54,6 +42,17 @@ public static Builder builder(MergeMetrics metrics) {
.convertToTableTime(metrics.convertToTableTime);
}
+ public long getCalcHeaderTime() {
+ return calcHeaderTime;
+ }
+
+ public long getMergeIntoHostBufferTime() {
+ return mergeIntoHostBufferTime;
+ }
+
+ public long getConvertToTableTime() {
+ return convertToTableTime;
+ }
public static class Builder {
private long calcHeaderTime;
diff --git a/src/main/java/com/nvidia/spark/rapids/jni/kudo/MergedInfoCalc.java b/src/main/java/com/nvidia/spark/rapids/jni/kudo/MergedInfoCalc.java
index 826ef2e691..fdbeaed2fb 100644
--- a/src/main/java/com/nvidia/spark/rapids/jni/kudo/MergedInfoCalc.java
+++ b/src/main/java/com/nvidia/spark/rapids/jni/kudo/MergedInfoCalc.java
@@ -16,26 +16,25 @@
package com.nvidia.spark.rapids.jni.kudo;
+import static com.nvidia.spark.rapids.jni.kudo.ColumnOffsetInfo.INVALID_OFFSET;
+import static com.nvidia.spark.rapids.jni.kudo.KudoSerializer.getValidityLengthInBytes;
+import static com.nvidia.spark.rapids.jni.kudo.KudoSerializer.padFor64byteAlignment;
+
import ai.rapids.cudf.Schema;
import com.nvidia.spark.rapids.jni.schema.Visitors;
-
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
-import static com.nvidia.spark.rapids.jni.kudo.ColumnOffsetInfo.INVALID_OFFSET;
-import static com.nvidia.spark.rapids.jni.kudo.KudoSerializer.getValidityLengthInBytes;
-import static com.nvidia.spark.rapids.jni.kudo.KudoSerializer.padFor64byteAlignment;
-
/**
* This class is used to calculate column offsets of merged buffer.
*/
class MergedInfoCalc extends MultiKudoTableVisitor {
- // Total data len in gpu, which accounts for 64 byte alignment
- private long totalDataLen;
// Column offset in gpu device buffer, it has one field for each flattened column
private final List columnOffsets;
+ // Total data len in gpu, which accounts for 64 byte alignment
+ private long totalDataLen;
public MergedInfoCalc(List tables) {
super(tables);
@@ -43,6 +42,12 @@ public MergedInfoCalc(List tables) {
this.columnOffsets = new ArrayList<>(tables.get(0).getHeader().getNumColumns());
}
+ static MergedInfoCalc calc(Schema schema, List table) {
+ MergedInfoCalc calc = new MergedInfoCalc(table);
+ Visitors.visitSchema(schema, calc);
+ return calc;
+ }
+
@Override
protected Void doVisitTopSchema(Schema schema, List children) {
return null;
@@ -58,7 +63,9 @@ protected Void doVisitStruct(Schema structType, List children) {
totalDataLen += validityBufferLen;
}
- columnOffsets.add(new ColumnOffsetInfo(validityOffset, validityBufferLen, INVALID_OFFSET, 0, INVALID_OFFSET, 0));
+ columnOffsets.add(
+ new ColumnOffsetInfo(validityOffset, validityBufferLen, INVALID_OFFSET, 0, INVALID_OFFSET,
+ 0));
return null;
}
@@ -75,13 +82,15 @@ protected Void doPreVisitList(Schema listType) {
long offsetBufferLen = 0;
long offsetBufferOffset = INVALID_OFFSET;
if (getTotalRowCount() > 0) {
- offsetBufferLen = padFor64byteAlignment((getTotalRowCount() + 1) * Integer.BYTES);
+ offsetBufferLen = padFor64byteAlignment((long) (getTotalRowCount() + 1) * Integer.BYTES);
offsetBufferOffset = totalDataLen;
totalDataLen += offsetBufferLen;
}
- columnOffsets.add(new ColumnOffsetInfo(validityOffset, validityBufferLen, offsetBufferOffset, offsetBufferLen, INVALID_OFFSET, 0));
+ columnOffsets.add(
+ new ColumnOffsetInfo(validityOffset, validityBufferLen, offsetBufferOffset, offsetBufferLen,
+ INVALID_OFFSET, 0));
return null;
}
@@ -105,7 +114,7 @@ protected Void doVisit(Schema primitiveType) {
long offsetBufferLen = 0;
long offsetBufferOffset = INVALID_OFFSET;
if (getTotalRowCount() > 0) {
- offsetBufferLen = padFor64byteAlignment((getTotalRowCount() + 1) * Integer.BYTES);
+ offsetBufferLen = padFor64byteAlignment((long) (getTotalRowCount() + 1) * Integer.BYTES);
offsetBufferOffset = totalDataLen;
totalDataLen += offsetBufferLen;
}
@@ -118,7 +127,8 @@ protected Void doVisit(Schema primitiveType) {
totalDataLen += dataBufferLen;
}
- columnOffsets.add(new ColumnOffsetInfo(validityOffset, validityBufferLen, offsetBufferOffset, offsetBufferLen, dataBufferOffset, dataBufferLen));
+ columnOffsets.add(new ColumnOffsetInfo(validityOffset, validityBufferLen, offsetBufferOffset,
+ offsetBufferLen, dataBufferOffset, dataBufferLen));
} else {
long totalRowCount = getTotalRowCount();
long validityBufferLen = 0;
@@ -132,18 +142,19 @@ protected Void doVisit(Schema primitiveType) {
long dataBufferLen = 0;
long dataBufferOffset = INVALID_OFFSET;
if (totalRowCount > 0) {
- dataBufferLen = padFor64byteAlignment(totalRowCount * primitiveType.getType().getSizeInBytes());
+ dataBufferLen =
+ padFor64byteAlignment(totalRowCount * primitiveType.getType().getSizeInBytes());
dataBufferOffset = totalDataLen;
totalDataLen += dataBufferLen;
}
- columnOffsets.add(new ColumnOffsetInfo(validityOffset, validityBufferLen, INVALID_OFFSET, 0, dataBufferOffset, dataBufferLen));
+ columnOffsets.add(new ColumnOffsetInfo(validityOffset, validityBufferLen, INVALID_OFFSET, 0,
+ dataBufferOffset, dataBufferLen));
}
return null;
}
-
public long getTotalDataLen() {
return totalDataLen;
}
@@ -159,10 +170,4 @@ public String toString() {
", columnOffsets=" + columnOffsets +
'}';
}
-
- static MergedInfoCalc calc(Schema schema, List table) {
- MergedInfoCalc calc = new MergedInfoCalc(table);
- Visitors.visitSchema(schema, calc);
- return calc;
- }
}
diff --git a/src/main/java/com/nvidia/spark/rapids/jni/kudo/MultiKudoTableVisitor.java b/src/main/java/com/nvidia/spark/rapids/jni/kudo/MultiKudoTableVisitor.java
index afa7ba6ea0..bc5678cc50 100644
--- a/src/main/java/com/nvidia/spark/rapids/jni/kudo/MultiKudoTableVisitor.java
+++ b/src/main/java/com/nvidia/spark/rapids/jni/kudo/MultiKudoTableVisitor.java
@@ -16,17 +16,20 @@
package com.nvidia.spark.rapids.jni.kudo;
-import ai.rapids.cudf.HostMemoryBuffer;
-import ai.rapids.cudf.Schema;
-import com.nvidia.spark.rapids.jni.schema.SchemaVisitor;
-
-import java.util.*;
-
import static com.nvidia.spark.rapids.jni.Preconditions.ensure;
import static com.nvidia.spark.rapids.jni.kudo.ColumnOffsetInfo.INVALID_OFFSET;
import static com.nvidia.spark.rapids.jni.kudo.KudoSerializer.padForHostAlignment;
import static java.lang.Math.toIntExact;
+import ai.rapids.cudf.HostMemoryBuffer;
+import ai.rapids.cudf.Schema;
+import com.nvidia.spark.rapids.jni.schema.SchemaVisitor;
+import java.util.ArrayDeque;
+import java.util.ArrayList;
+import java.util.Deque;
+import java.util.List;
+import java.util.Objects;
+
/**
* This class provides a base class for visiting multiple kudo tables, e.g. it helps to maintain internal states during
* visiting multi kudo tables, which makes it easier to do some calculations based on them.
@@ -40,11 +43,11 @@ abstract class MultiKudoTableVisitor implements SchemaVisitor
private final long[] currentDataOffset;
private final Deque[] sliceInfoStack;
private final Deque totalRowCountStack;
+ // Temporary buffer to store data length of string column to avoid repeated allocation
+ private final int[] strDataLen;
// A temporary variable to keep if current column has null
private boolean hasNull;
private int currentIdx;
- // Temporary buffer to store data length of string column to avoid repeated allocation
- private final int[] strDataLen;
// Temporary variable to calculate total data length of string column
private long totalStrDataLen;
@@ -187,7 +190,8 @@ private void updateDataLen() {
}
}
- private void updateOffsets(boolean updateOffset, boolean updateData, boolean updateSliceInfo, int sizeInBytes) {
+ private void updateOffsets(boolean updateOffset, boolean updateData, boolean updateSliceInfo,
+ int sizeInBytes) {
long totalRowCount = 0;
for (int tableIdx = 0; tableIdx < tables.size(); tableIdx++) {
SliceInfo sliceInfo = sliceInfoOf(tableIdx);
@@ -202,11 +206,13 @@ private void updateOffsets(boolean updateOffset, boolean updateData, boolean upd
}
if (tables.get(tableIdx).getHeader().hasValidityBuffer(currentIdx)) {
- currentValidityOffsets[tableIdx] += padForHostAlignment(sliceInfo.getValidityBufferInfo().getBufferLength());
+ currentValidityOffsets[tableIdx] +=
+ padForHostAlignment(sliceInfo.getValidityBufferInfo().getBufferLength());
}
if (updateOffset) {
- currentOffsetOffsets[tableIdx] += padForHostAlignment((sliceInfo.getRowCount() + 1) * Integer.BYTES);
+ currentOffsetOffsets[tableIdx] +=
+ padForHostAlignment((long) (sliceInfo.getRowCount() + 1) * Integer.BYTES);
if (updateData) {
// string type
currentDataOffset[tableIdx] += padForHostAlignment(strDataLen[tableIdx]);
@@ -215,7 +221,8 @@ private void updateOffsets(boolean updateOffset, boolean updateData, boolean upd
} else {
if (updateData) {
// primitive type
- currentDataOffset[tableIdx] += padForHostAlignment(sliceInfo.getRowCount() * sizeInBytes);
+ currentDataOffset[tableIdx] +=
+ padForHostAlignment((long) sliceInfo.getRowCount() * sizeInBytes);
}
}
diff --git a/src/main/java/com/nvidia/spark/rapids/jni/kudo/SlicedBufferSerializer.java b/src/main/java/com/nvidia/spark/rapids/jni/kudo/SlicedBufferSerializer.java
index e22a523855..81cf1557c5 100644
--- a/src/main/java/com/nvidia/spark/rapids/jni/kudo/SlicedBufferSerializer.java
+++ b/src/main/java/com/nvidia/spark/rapids/jni/kudo/SlicedBufferSerializer.java
@@ -16,19 +16,18 @@
package com.nvidia.spark.rapids.jni.kudo;
+import static com.nvidia.spark.rapids.jni.kudo.KudoSerializer.padForHostAlignment;
+
import ai.rapids.cudf.BufferType;
import ai.rapids.cudf.DType;
import ai.rapids.cudf.HostColumnVectorCore;
import ai.rapids.cudf.HostMemoryBuffer;
import com.nvidia.spark.rapids.jni.schema.HostColumnsVisitor;
-
import java.io.IOException;
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.List;
-import static com.nvidia.spark.rapids.jni.kudo.KudoSerializer.padForHostAlignment;
-
/**
* This class visits a list of columns and serialize one of the buffers (validity, offset, or data) into with kudo
* format.
@@ -110,8 +109,8 @@ public Void preVisitList(HostColumnVectorCore col) {
SliceInfo current;
if (col.getOffsets() != null) {
int start = col.getOffsets()
- .getInt(parent.offset * Integer.BYTES);
- int end = col.getOffsets().getInt((parent.offset + parent.rowCount) * Integer.BYTES);
+ .getInt((long) parent.offset * Integer.BYTES);
+ int end = col.getOffsets().getInt((long) (parent.offset + parent.rowCount) * Integer.BYTES);
int rowCount = end - start;
current = new SliceInfo(start, rowCount);
@@ -153,7 +152,8 @@ public Void visit(HostColumnVectorCore col) {
}
}
- private long copySlicedValidity(HostColumnVectorCore column, SliceInfo sliceInfo) throws IOException {
+ private long copySlicedValidity(HostColumnVectorCore column, SliceInfo sliceInfo)
+ throws IOException {
if (column.getValidity() != null && sliceInfo.getRowCount() > 0) {
HostMemoryBuffer buff = column.getValidity();
long len = sliceInfo.getValidityBufferInfo().getBufferLength();
@@ -165,13 +165,14 @@ private long copySlicedValidity(HostColumnVectorCore column, SliceInfo sliceInfo
}
}
- private long copySlicedOffset(HostColumnVectorCore column, SliceInfo sliceInfo) throws IOException {
+ private long copySlicedOffset(HostColumnVectorCore column, SliceInfo sliceInfo)
+ throws IOException {
if (sliceInfo.rowCount <= 0 || column.getOffsets() == null) {
// Don't copy anything, there are no rows
return 0;
}
- long bytesToCopy = (sliceInfo.rowCount + 1) * Integer.BYTES;
- long srcOffset = sliceInfo.offset * Integer.BYTES;
+ long bytesToCopy = (long) (sliceInfo.rowCount + 1) * Integer.BYTES;
+ long srcOffset = (long) sliceInfo.offset * Integer.BYTES;
HostMemoryBuffer buff = column.getOffsets();
writer.copyDataFrom(buff, srcOffset, bytesToCopy);
return padForHostAlignment(writer, bytesToCopy);
@@ -181,8 +182,10 @@ private long copySlicedData(HostColumnVectorCore column, SliceInfo sliceInfo) th
if (sliceInfo.rowCount > 0) {
DType type = column.getType();
if (type.equals(DType.STRING)) {
- long startByteOffset = column.getOffsets().getInt(sliceInfo.offset * Integer.BYTES);
- long endByteOffset = column.getOffsets().getInt((sliceInfo.offset + sliceInfo.rowCount) * Integer.BYTES);
+ long startByteOffset = column.getOffsets().getInt((long) sliceInfo.offset * Integer.BYTES);
+ long endByteOffset =
+ column.getOffsets().getInt(
+ (long) (sliceInfo.offset + sliceInfo.rowCount) * Integer.BYTES);
long bytesToCopy = endByteOffset - startByteOffset;
if (column.getData() == null) {
if (bytesToCopy != 0) {
@@ -196,8 +199,8 @@ private long copySlicedData(HostColumnVectorCore column, SliceInfo sliceInfo) th
return padForHostAlignment(writer, bytesToCopy);
}
} else if (type.getSizeInBytes() > 0) {
- long bytesToCopy = sliceInfo.rowCount * type.getSizeInBytes();
- long srcOffset = sliceInfo.offset * type.getSizeInBytes();
+ long bytesToCopy = (long) sliceInfo.rowCount * type.getSizeInBytes();
+ long srcOffset = (long) sliceInfo.offset * type.getSizeInBytes();
writer.copyDataFrom(column.getData(), srcOffset, bytesToCopy);
return padForHostAlignment(writer, bytesToCopy);
} else {
diff --git a/src/main/java/com/nvidia/spark/rapids/jni/kudo/SlicedValidityBufferInfo.java b/src/main/java/com/nvidia/spark/rapids/jni/kudo/SlicedValidityBufferInfo.java
index 7c9957f5b2..7e2e74d502 100644
--- a/src/main/java/com/nvidia/spark/rapids/jni/kudo/SlicedValidityBufferInfo.java
+++ b/src/main/java/com/nvidia/spark/rapids/jni/kudo/SlicedValidityBufferInfo.java
@@ -31,9 +31,26 @@ class SlicedValidityBufferInfo {
this.beginBit = beginBit;
}
+ static SlicedValidityBufferInfo calc(int rowOffset, int numRows) {
+ if (rowOffset < 0) {
+ throw new IllegalArgumentException("rowOffset must be >= 0, but was " + rowOffset);
+ }
+ if (numRows < 0) {
+ throw new IllegalArgumentException("numRows must be >= 0, but was " + numRows);
+ }
+ int bufferOffset = rowOffset / 8;
+ int beginBit = rowOffset % 8;
+ int bufferLength = 0;
+ if (numRows > 0) {
+ bufferLength = (rowOffset + numRows - 1) / 8 - bufferOffset + 1;
+ }
+ return new SlicedValidityBufferInfo(bufferOffset, bufferLength, beginBit);
+ }
+
@Override
public String toString() {
- return "SlicedValidityBufferInfo{" + "bufferOffset=" + bufferOffset + ", bufferLength=" + bufferLength +
+ return "SlicedValidityBufferInfo{" + "bufferOffset=" + bufferOffset + ", bufferLength=" +
+ bufferLength +
", beginBit=" + beginBit + '}';
}
@@ -48,20 +65,4 @@ public int getBufferLength() {
public int getBeginBit() {
return beginBit;
}
-
- static SlicedValidityBufferInfo calc(int rowOffset, int numRows) {
- if (rowOffset < 0) {
- throw new IllegalArgumentException("rowOffset must be >= 0, but was " + rowOffset);
- }
- if (numRows < 0) {
- throw new IllegalArgumentException("numRows must be >= 0, but was " + numRows);
- }
- int bufferOffset = rowOffset / 8;
- int beginBit = rowOffset % 8;
- int bufferLength = 0;
- if (numRows > 0) {
- bufferLength = (rowOffset + numRows - 1) / 8 - bufferOffset + 1;
- }
- return new SlicedValidityBufferInfo(bufferOffset, bufferLength, beginBit);
- }
}
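
The byte/bit arithmetic behind SlicedValidityBufferInfo.calc above, worked through once in a standalone sketch (not part of this patch); the variable names mirror the fields in the diff.

    // For rowOffset = 10 and numRows = 20 the slice starts at bit 2 of byte 1 and spans 3 bytes.
    public class ValiditySliceSketch {
      public static void main(String[] args) {
        int rowOffset = 10;
        int numRows = 20;
        int bufferOffset = rowOffset / 8;                      // 1: byte holding the first bit
        int beginBit = rowOffset % 8;                          // 2: bit index inside that byte
        int bufferLength = numRows > 0
            ? (rowOffset + numRows - 1) / 8 - bufferOffset + 1 // 3: bytes touched by the slice
            : 0;
        System.out.printf("offset=%d beginBit=%d length=%d%n",
            bufferOffset, beginBit, bufferLength);
      }
    }
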
diff --git a/src/main/java/com/nvidia/spark/rapids/jni/kudo/TableBuilder.java b/src/main/java/com/nvidia/spark/rapids/jni/kudo/TableBuilder.java
index e50e462f4f..ded99ab64f 100644
--- a/src/main/java/com/nvidia/spark/rapids/jni/kudo/TableBuilder.java
+++ b/src/main/java/com/nvidia/spark/rapids/jni/kudo/TableBuilder.java
@@ -16,24 +16,28 @@
package com.nvidia.spark.rapids.jni.kudo;
-import ai.rapids.cudf.*;
+import static com.nvidia.spark.rapids.jni.Preconditions.ensure;
+import static java.util.Objects.requireNonNull;
+
+import ai.rapids.cudf.CloseableArray;
+import ai.rapids.cudf.ColumnVector;
+import ai.rapids.cudf.ColumnView;
+import ai.rapids.cudf.DeviceMemoryBuffer;
+import ai.rapids.cudf.Schema;
+import ai.rapids.cudf.Table;
import com.nvidia.spark.rapids.jni.Arms;
import com.nvidia.spark.rapids.jni.schema.SchemaVisitor;
-
import java.util.ArrayList;
import java.util.List;
-import static com.nvidia.spark.rapids.jni.Preconditions.ensure;
-import static java.util.Objects.requireNonNull;
-
/**
* This class is used to build a cudf table from a list of column view info, and a device buffer.
*/
 class TableBuilder implements SchemaVisitor<ColumnView, ColumnViewInfo, Table>, AutoCloseable {
- private int curColumnIdx;
private final DeviceMemoryBuffer buffer;
   private final List<ColumnViewInfo> colViewInfoList;
   private final List<ColumnView> columnViewList;
+ private int curColumnIdx;
   public TableBuilder(List<ColumnViewInfo> colViewInfoList, DeviceMemoryBuffer buffer) {
requireNonNull(colViewInfoList, "colViewInfoList cannot be null");
@@ -52,10 +56,12 @@ public Table visitTopSchema(Schema schema, List children) {
// `children`, so we need to clear `columnViewList`.
this.columnViewList.clear();
try {
-      try (CloseableArray<ColumnVector> arr = CloseableArray.wrap(new ColumnVector[children.size()])) {
+      try (CloseableArray<ColumnVector> arr = CloseableArray.wrap(
+          new ColumnVector[children.size()])) {
for (int i = 0; i < children.size(); i++) {
ColumnView colView = children.set(i, null);
- arr.set(i, ColumnVector.fromViewWithContiguousAllocation(colView.getNativeView(), buffer));
+ arr.set(i,
+ ColumnVector.fromViewWithContiguousAllocation(colView.getNativeView(), buffer));
}
return new Table(arr.getArray());
@@ -87,7 +93,7 @@ public ColumnViewInfo preVisitList(Schema listType) {
@Override
public ColumnView visitList(Schema listType, ColumnViewInfo colViewInfo, ColumnView childResult) {
- ColumnView[] children = new ColumnView[]{childResult};
+ ColumnView[] children = new ColumnView[] {childResult};
ColumnView view = colViewInfo.buildColumnView(buffer, children);
columnViewList.add(view);
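Note: visitTopSchema stages each child view in a CloseableArray so that, if building any vector or the final Table throws, the vectors created so far are closed rather than leaked. The same pattern in isolation (sketch only; makeVector stands in for fromViewWithContiguousAllocation and n for children.size()):

    // Stage vectors in a CloseableArray; try-with-resources releases them on failure.
    try (CloseableArray<ColumnVector> staged = CloseableArray.wrap(new ColumnVector[n])) {
      for (int i = 0; i < n; i++) {
        staged.set(i, makeVector(i));   // hypothetical factory for element i
      }
      return new Table(staged.getArray());
    }
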
diff --git a/src/main/java/com/nvidia/spark/rapids/jni/schema/HostColumnsVisitor.java b/src/main/java/com/nvidia/spark/rapids/jni/schema/HostColumnsVisitor.java
index ae7915c60d..6b9664607f 100644
--- a/src/main/java/com/nvidia/spark/rapids/jni/schema/HostColumnsVisitor.java
+++ b/src/main/java/com/nvidia/spark/rapids/jni/schema/HostColumnsVisitor.java
@@ -19,14 +19,13 @@
package com.nvidia.spark.rapids.jni.schema;
import ai.rapids.cudf.HostColumnVectorCore;
-
import java.util.List;
/**
* A post order visitor for visiting a list of host columns in a schema.
*
*
- *
+ *
* For example, if we have three columns A, B, and C with following types:
*
*
@@ -34,7 +33,7 @@
*
B: list { int b1}
*
C: string c1
*
- *
+ *
* The order of visiting will be:
*
*
Visit primitive column a1
@@ -51,34 +50,38 @@
  * @param <T> Return type when visiting intermediate nodes.
*/
 public interface HostColumnsVisitor<T> {
- /**
- * Visit a struct column.
- * @param col the struct column to visit
- * @param children the results of visiting the children
- * @return the result of visiting the struct column
- */
-    T visitStruct(HostColumnVectorCore col, List<T> children);
+ /**
+ * Visit a struct column.
+ *
+ * @param col the struct column to visit
+ * @param children the results of visiting the children
+ * @return the result of visiting the struct column
+ */
+  T visitStruct(HostColumnVectorCore col, List<T> children);
- /**
- * Visit a list column before actually visiting its child.
- * @param col the list column to visit
- * @return the result of visiting the list column
- */
- T preVisitList(HostColumnVectorCore col);
+ /**
+ * Visit a list column before actually visiting its child.
+ *
+ * @param col the list column to visit
+ * @return the result of visiting the list column
+ */
+ T preVisitList(HostColumnVectorCore col);
- /**
- * Visit a list column after visiting its child.
- * @param col the list column to visit
- * @param preVisitResult the result of visiting the list column before visiting its child
- * @param childResult the result of visiting the child
- * @return the result of visiting the list column
- */
- T visitList(HostColumnVectorCore col, T preVisitResult, T childResult);
+ /**
+ * Visit a list column after visiting its child.
+ *
+ * @param col the list column to visit
+ * @param preVisitResult the result of visiting the list column before visiting its child
+ * @param childResult the result of visiting the child
+ * @return the result of visiting the list column
+ */
+ T visitList(HostColumnVectorCore col, T preVisitResult, T childResult);
- /**
- * Visit a column that is a primitive type.
- * @param col the column to visit
- * @return the result of visiting the column
- */
- T visit(HostColumnVectorCore col);
+ /**
+ * Visit a column that is a primitive type.
+ *
+ * @param col the column to visit
+ * @return the result of visiting the column
+ */
+ T visit(HostColumnVectorCore col);
}
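Note: a minimal implementation sketch (not part of this PR) showing how the four callbacks compose, here counting every column, including nested children, in post order:

    import ai.rapids.cudf.HostColumnVectorCore;
    import java.util.List;

    // Hypothetical visitor: returns the number of columns visited.
    class ColumnCounter implements HostColumnsVisitor<Integer> {
      @Override
      public Integer visitStruct(HostColumnVectorCore col, List<Integer> children) {
        return 1 + children.stream().mapToInt(Integer::intValue).sum();
      }

      @Override
      public Integer preVisitList(HostColumnVectorCore col) {
        return 0;                      // nothing to report before the child is visited
      }

      @Override
      public Integer visitList(HostColumnVectorCore col, Integer preVisitResult, Integer childResult) {
        return 1 + childResult;        // the list column itself plus its child subtree
      }

      @Override
      public Integer visit(HostColumnVectorCore col) {
        return 1;                      // primitive leaf
      }
    }
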
diff --git a/src/main/java/com/nvidia/spark/rapids/jni/schema/SchemaVisitor.java b/src/main/java/com/nvidia/spark/rapids/jni/schema/SchemaVisitor.java
index c6b33e0fb4..8e81ee12c1 100644
--- a/src/main/java/com/nvidia/spark/rapids/jni/schema/SchemaVisitor.java
+++ b/src/main/java/com/nvidia/spark/rapids/jni/schema/SchemaVisitor.java
@@ -19,14 +19,13 @@
package com.nvidia.spark.rapids.jni.schema;
import ai.rapids.cudf.Schema;
-
import java.util.List;
/**
* A post order visitor for schemas.
*
*
Flattened Schema
- *
+ *
* A flattened schema is a schema where all fields with nested types are flattened into an array of fields. For example,
* for a schema with following fields:
*
@@ -36,7 +35,7 @@
*
C: string
*
D: long
*
- *
+ *
* The flattened schema will be:
*
*
@@ -59,7 +58,7 @@
*
B: list { int b1}
*
C: string
*
- *
+ *
* The order of visiting will be:
*
*
Visit primitive field a1
@@ -79,42 +78,47 @@
  * @param <R> Return type after processing all children values.
*/
 public interface SchemaVisitor<T, P, R> {
- /**
- * Visit the top level schema.
- * @param schema the top level schema to visit
- * @param children the results of visiting the children
- * @return the result of visiting the top level schema
- */
-    R visitTopSchema(Schema schema, List<T> children);
+ /**
+ * Visit the top level schema.
+ *
+ * @param schema the top level schema to visit
+ * @param children the results of visiting the children
+ * @return the result of visiting the top level schema
+ */
+  R visitTopSchema(Schema schema, List<T> children);
- /**
- * Visit a struct schema.
- * @param structType the struct schema to visit
- * @param children the results of visiting the children
- * @return the result of visiting the struct schema
- */
-    T visitStruct(Schema structType, List<T> children);
+ /**
+ * Visit a struct schema.
+ *
+ * @param structType the struct schema to visit
+ * @param children the results of visiting the children
+ * @return the result of visiting the struct schema
+ */
+  T visitStruct(Schema structType, List<T> children);
- /**
- * Visit a list schema before actually visiting its child.
- * @param listType the list schema to visit
- * @return the result of visiting the list schema
- */
- P preVisitList(Schema listType);
+ /**
+ * Visit a list schema before actually visiting its child.
+ *
+ * @param listType the list schema to visit
+ * @return the result of visiting the list schema
+ */
+ P preVisitList(Schema listType);
- /**
- * Visit a list schema after visiting its child.
- * @param listType the list schema to visit
- * @param preVisitResult the result of visiting the list schema before visiting its child
- * @param childResult the result of visiting the child
- * @return the result of visiting the list schema
- */
- T visitList(Schema listType, P preVisitResult, T childResult);
+ /**
+ * Visit a list schema after visiting its child.
+ *
+ * @param listType the list schema to visit
+ * @param preVisitResult the result of visiting the list schema before visiting its child
+ * @param childResult the result of visiting the child
+ * @return the result of visiting the list schema
+ */
+ T visitList(Schema listType, P preVisitResult, T childResult);
- /**
- * Visit a primitive type.
- * @param primitiveType the primitive type to visit
- * @return the result of visiting the primitive type
- */
- T visit(Schema primitiveType);
+ /**
+ * Visit a primitive type.
+ *
+ * @param primitiveType the primitive type to visit
+ * @return the result of visiting the primitive type
+ */
+ T visit(Schema primitiveType);
}
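Note: the T/P/R parameters become concrete in an implementation. A small sketch (assumed example, not from this PR) that counts primitive leaves of a schema and could be driven by Visitors.visitSchema:

    import ai.rapids.cudf.Schema;
    import java.util.List;

    // Hypothetical visitor with T = P = R = Integer.
    class LeafCounter implements SchemaVisitor<Integer, Integer, Integer> {
      @Override
      public Integer visitTopSchema(Schema schema, List<Integer> children) {
        return children.stream().mapToInt(Integer::intValue).sum();
      }

      @Override
      public Integer visitStruct(Schema structType, List<Integer> children) {
        return children.stream().mapToInt(Integer::intValue).sum();
      }

      @Override
      public Integer preVisitList(Schema listType) {
        return 0;
      }

      @Override
      public Integer visitList(Schema listType, Integer preVisitResult, Integer childResult) {
        return childResult;
      }

      @Override
      public Integer visit(Schema primitiveType) {
        return 1;                      // each primitive field counts once
      }
    }
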
diff --git a/src/main/java/com/nvidia/spark/rapids/jni/schema/Visitors.java b/src/main/java/com/nvidia/spark/rapids/jni/schema/Visitors.java
index b7f4f521e4..92ba7f76e1 100644
--- a/src/main/java/com/nvidia/spark/rapids/jni/schema/Visitors.java
+++ b/src/main/java/com/nvidia/spark/rapids/jni/schema/Visitors.java
@@ -21,7 +21,6 @@
import ai.rapids.cudf.HostColumnVector;
import ai.rapids.cudf.HostColumnVectorCore;
import ai.rapids.cudf.Schema;
-
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
@@ -31,75 +30,75 @@
* A utility class for visiting a schema or a list of host columns.
*/
public class Visitors {
- /**
- * Visiting a schema in post order. For more details, see {@link SchemaVisitor}.
- *
- * @param schema the schema to visit
- * @param visitor the visitor to use
- * @param Return type when visiting intermediate nodes. See {@link SchemaVisitor}
- * @param
Return type when previsiting a list. See {@link SchemaVisitor}
- * @param Return type after processing all children values. See {@link SchemaVisitor}
- * @return the result of visiting the schema
- */
-    public static <T, P, R> R visitSchema(Schema schema, SchemaVisitor<T, P, R> visitor) {
- Objects.requireNonNull(schema, "schema cannot be null");
- Objects.requireNonNull(visitor, "visitor cannot be null");
+ /**
+ * Visiting a schema in post order. For more details, see {@link SchemaVisitor}.
+ *
+ * @param schema the schema to visit
+ * @param visitor the visitor to use
+   * @param <T> Return type when visiting intermediate nodes. See {@link SchemaVisitor}
+   * @param <P> Return type when previsiting a list. See {@link SchemaVisitor}
* First use path5 [*][*] to enable flatten style.
*/
@Test
@@ -390,7 +400,8 @@ void getJsonObjectTest_Test_case_path5() {
};
// flatten the arrays, then query named path "k"
- String JSON1 = "[ [[[ {'k': 'v1'} ], {'k': 'v2'}]], [[{'k': 'v3'}], {'k': 'v4'}], {'k': 'v5'} ]";
+ String JSON1 =
+ "[ [[[ {'k': 'v1'} ], {'k': 'v2'}]], [[{'k': 'v3'}], {'k': 'v4'}], {'k': 'v5'} ]";
String expectedStr1 = "[\"v5\"]";
try (
@@ -415,7 +426,8 @@ void getJsonObjectTest_Test_case_path6() {
String expectedStr1 = "[1,[21,22],3]";
String JSON2 = "[1]";
- String expectedStr2 = "1"; // note: in row mode, if it has only 1 item, then remove the outer: []
+ String expectedStr2 =
+ "1"; // note: in row mode, if it has only 1 item, then remove the outer: []
try (
ColumnVector jsonCv = ColumnVector.fromStrings(JSON1, JSON2);
@@ -626,8 +638,8 @@ void getJsonObjectTest_JNIKernelCalledTwice() {
@Test
void getJsonObjectMultiplePathsTest() {
- List path0 = Arrays.asList(namedPath("k0"));
- List path1 = Arrays.asList(namedPath("k1"));
+ List path0 = Collections.singletonList(namedPath("k0"));
+ List path1 = Collections.singletonList(namedPath("k1"));
List> paths = Arrays.asList(path0, path1);
try (ColumnVector jsonCv = ColumnVector.fromStrings("{\"k0\": \"v0\", \"k1\": \"v1\"}");
ColumnVector expected0 = ColumnVector.fromStrings("v0");
@@ -646,14 +658,16 @@ void getJsonObjectMultiplePathsTest() {
@Test
void getJsonObjectMultiplePathsTest_JNIKernelCalledTwice() {
- List path0 = Arrays.asList(namedPath("k0"));
- List path1 = Arrays.asList(namedPath("k1"));
- List path2 = Arrays.asList();
+ List path0 = Collections.singletonList(namedPath("k0"));
+ List path1 = Collections.singletonList(namedPath("k1"));
+ List path2 = Collections.emptyList();
List> paths = Arrays.asList(path0, path1, path2);
- try (ColumnVector jsonCv = ColumnVector.fromStrings("{\"k0\": \"v0\", \"k1\": \"v1\"}", "['\n\n\n\n\n\n\n\n\n\n']");
+ try (ColumnVector jsonCv = ColumnVector.fromStrings("{\"k0\": \"v0\", \"k1\": \"v1\"}",
+ "['\n\n\n\n\n\n\n\n\n\n']");
ColumnVector expected0 = ColumnVector.fromStrings("v0", null);
ColumnVector expected1 = ColumnVector.fromStrings("v1", null);
- ColumnVector expected2 = ColumnVector.fromStrings("{\"k0\":\"v0\",\"k1\":\"v1\"}", "[\"\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\"]")) {
+ ColumnVector expected2 = ColumnVector.fromStrings("{\"k0\":\"v0\",\"k1\":\"v1\"}",
+ "[\"\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\"]")) {
ColumnVector[] output = JSONUtils.getJsonObjectMultiplePaths(jsonCv, paths);
try {
assertColumnsAreEqual(expected0, output[0]);
@@ -669,8 +683,8 @@ void getJsonObjectMultiplePathsTest_JNIKernelCalledTwice() {
@Test
void getJsonObjectMultiplePathsTestCrazyLowMemoryBudget() {
- List path0 = Arrays.asList(namedPath("k0"));
- List path1 = Arrays.asList(namedPath("k1"));
+ List path0 = Collections.singletonList(namedPath("k0"));
+ List path1 = Collections.singletonList(namedPath("k1"));
List> paths = Arrays.asList(path0, path1);
try (ColumnVector jsonCv = ColumnVector.fromStrings("{\"k0\": \"v0\", \"k1\": \"v1\"}");
ColumnVector expected0 = ColumnVector.fromStrings("v0");
@@ -689,8 +703,8 @@ void getJsonObjectMultiplePathsTestCrazyLowMemoryBudget() {
@Test
void getJsonObjectMultiplePathsTestMemoryBudget() {
- List path0 = Arrays.asList(namedPath("k0"));
- List path1 = Arrays.asList(namedPath("k1"));
+ List path0 = Collections.singletonList(namedPath("k0"));
+ List path1 = Collections.singletonList(namedPath("k1"));
List> paths = Arrays.asList(path0, path1);
try (ColumnVector jsonCv = ColumnVector.fromStrings("{\"k0\": \"v0\", \"k1\": \"v1\"}");
ColumnVector expected0 = ColumnVector.fromStrings("v0");
@@ -725,7 +739,7 @@ void getJsonObjectTest_ExceedMaxNestingDepthInPath() {
/**
* This test is when an exception is thrown due to maximum nesting depth being exceeded
* when pushing the context stack during evaluating the JSON path.
- *
+ *
* The maximum depth limit here is the same as the limit for the input JSON path.
*/
@Test
@@ -752,7 +766,7 @@ void getJsonObjectTest_ExceedMaxNestingDepthInContextStack() {
/**
* This test is when an exception is thrown due to maximum nesting depth being exceeded
* in the JSON parser. The JSON path is simply mirroring the input.
- *
+ *
* Note that the maximum depth in the internal parser, which is being tested here, is different
* from the limit for the input JSON path.
*/
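Note: the multiple-paths tests above all follow the same ownership pattern: getJsonObjectMultiplePaths returns one ColumnVector per requested path, and the caller must close each of them, hence the try/finally blocks. Condensed from the tests (variable names as in the test code):

    // One output column per path; the caller owns and must close every returned vector.
    ColumnVector[] output = JSONUtils.getJsonObjectMultiplePaths(jsonCv, paths);
    try {
      // compare or consume output[0], output[1], ...
    } finally {
      for (ColumnVector cv : output) {
        cv.close();
      }
    }
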
diff --git a/src/test/java/com/nvidia/spark/rapids/jni/GpuSubstringIndexUtilsTest.java b/src/test/java/com/nvidia/spark/rapids/jni/GpuSubstringIndexUtilsTest.java
index 9666893448..414835ba26 100644
--- a/src/test/java/com/nvidia/spark/rapids/jni/GpuSubstringIndexUtilsTest.java
+++ b/src/test/java/com/nvidia/spark/rapids/jni/GpuSubstringIndexUtilsTest.java
@@ -20,57 +20,59 @@
import ai.rapids.cudf.ColumnVector;
import ai.rapids.cudf.Scalar;
import ai.rapids.cudf.Table;
-import org.junit.jupiter.api.Test;
-
import java.util.ArrayList;
import java.util.List;
+import org.junit.jupiter.api.Test;
public class GpuSubstringIndexUtilsTest {
- @Test
- void gpuSubstringIndexTest(){
- Table.TestBuilder tb = new Table.TestBuilder();
- tb.column( "www.apache.org");
- tb.column("www.apache");
- tb.column("www");
- tb.column("");
- tb.column("org");
- tb.column("apache.org");
- tb.column("www.apache.org");
- tb.column("");
- tb.column("大千世界大");
- tb.column("www||apache");
+ @Test
+ void gpuSubstringIndexTest() {
+ Table.TestBuilder tb = new Table.TestBuilder();
+ tb.column("www.apache.org");
+ tb.column("www.apache");
+ tb.column("www");
+ tb.column("");
+ tb.column("org");
+ tb.column("apache.org");
+ tb.column("www.apache.org");
+ tb.column("");
+ tb.column("大千世界大");
+ tb.column("www||apache");
- try(Table expected = tb.build()){
- Table.TestBuilder tb2 = new Table.TestBuilder();
- tb2.column("www.apache.org");
- tb2.column("www.apache.org");
- tb2.column("www.apache.org");
- tb2.column("www.apache.org");
- tb2.column("www.apache.org");
- tb2.column("www.apache.org");
- tb2.column("www.apache.org");
- tb2.column("");
- tb2.column("大千世界大千世界");
- tb2.column("www||apache||org");
+ try (Table expected = tb.build()) {
+ Table.TestBuilder tb2 = new Table.TestBuilder();
+ tb2.column("www.apache.org");
+ tb2.column("www.apache.org");
+ tb2.column("www.apache.org");
+ tb2.column("www.apache.org");
+ tb2.column("www.apache.org");
+ tb2.column("www.apache.org");
+ tb2.column("www.apache.org");
+ tb2.column("");
+ tb2.column("大千世界大千世界");
+ tb2.column("www||apache||org");
- Scalar dotScalar = Scalar.fromString(".");
- Scalar cnChar = Scalar.fromString("千");
- Scalar verticalBar = Scalar.fromString("||");
- Scalar[] delimiterArray = new Scalar[]{dotScalar, dotScalar, dotScalar, dotScalar,dotScalar, dotScalar, dotScalar, dotScalar, cnChar, verticalBar};
- int[] countArray = new int[]{3, 2, 1, 0, -1, -2, -3, -2, 2, 2};
-            List<ColumnVector> result = new ArrayList<>();
- try (Table origTable = tb2.build()){
- for(int i = 0; i < origTable.getNumberOfColumns(); i++){
- ColumnVector string_col = origTable.getColumn(i);
- result.add(GpuSubstringIndexUtils.substringIndex(string_col, delimiterArray[i], countArray[i]));
- }
- try (Table result_tbl = new Table(
- result.toArray(new ColumnVector[result.size()]))){
- AssertUtils.assertTablesAreEqual(expected, result_tbl);
- }
- }finally {
- result.forEach(ColumnVector::close);
- }
+ Scalar dotScalar = Scalar.fromString(".");
+ Scalar cnChar = Scalar.fromString("千");
+ Scalar verticalBar = Scalar.fromString("||");
+ Scalar[] delimiterArray =
+ new Scalar[] {dotScalar, dotScalar, dotScalar, dotScalar, dotScalar, dotScalar, dotScalar,
+ dotScalar, cnChar, verticalBar};
+ int[] countArray = new int[] {3, 2, 1, 0, -1, -2, -3, -2, 2, 2};
+      List<ColumnVector> result = new ArrayList<>();
+ try (Table origTable = tb2.build()) {
+ for (int i = 0; i < origTable.getNumberOfColumns(); i++) {
+ ColumnVector string_col = origTable.getColumn(i);
+ result.add(
+ GpuSubstringIndexUtils.substringIndex(string_col, delimiterArray[i], countArray[i]));
+ }
+ try (Table result_tbl = new Table(
+ result.toArray(new ColumnVector[result.size()]))) {
+ AssertUtils.assertTablesAreEqual(expected, result_tbl);
}
+ } finally {
+ result.forEach(ColumnVector::close);
+ }
}
+ }
}
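Note: the expected column above encodes the substring_index semantics being tested: a positive count keeps everything before the count-th occurrence of the delimiter, a negative count keeps everything after the count-th occurrence from the end, and 0 yields an empty string. A host-side reference sketch (assuming Spark's substring_index behavior) that reproduces those rows:

    // CPU reference for the expected values above.
    static String substringIndexRef(String s, String delim, int count) {
      if (count == 0) {
        return "";
      }
      if (count > 0) {
        int idx = -1;
        for (int i = 0; i < count; i++) {
          idx = s.indexOf(delim, idx + 1);
          if (idx < 0) {
            return s;                  // fewer delimiters than requested: keep the whole string
          }
        }
        return s.substring(0, idx);
      }
      int idx = s.length();
      for (int i = 0; i < -count; i++) {
        idx = s.lastIndexOf(delim, idx - 1);
        if (idx < 0) {
          return s;
        }
      }
      return s.substring(idx + delim.length());
    }
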
diff --git a/src/test/java/com/nvidia/spark/rapids/jni/HashTest.java b/src/test/java/com/nvidia/spark/rapids/jni/HashTest.java
index d35f20fe2c..ffebb743d8 100644
--- a/src/test/java/com/nvidia/spark/rapids/jni/HashTest.java
+++ b/src/test/java/com/nvidia/spark/rapids/jni/HashTest.java
@@ -16,39 +16,44 @@
package com.nvidia.spark.rapids.jni;
+import static ai.rapids.cudf.AssertUtils.assertColumnsAreEqual;
+
import ai.rapids.cudf.ColumnVector;
import ai.rapids.cudf.ColumnView;
import ai.rapids.cudf.DType;
-import ai.rapids.cudf.HostColumnVector.*;
-import org.junit.jupiter.api.Test;
-
+import ai.rapids.cudf.HostColumnVector.BasicType;
+import ai.rapids.cudf.HostColumnVector.ListType;
import java.util.Arrays;
import java.util.Collections;
-
-import static ai.rapids.cudf.AssertUtils.*;
+import org.junit.jupiter.api.Test;
public class HashTest {
-// IEEE 754 NaN values
+ // IEEE 754 NaN values
static final float POSITIVE_FLOAT_NAN_LOWER_RANGE = Float.intBitsToFloat(0x7f800001);
static final float POSITIVE_FLOAT_NAN_UPPER_RANGE = Float.intBitsToFloat(0x7fffffff);
static final float NEGATIVE_FLOAT_NAN_LOWER_RANGE = Float.intBitsToFloat(0xff800001);
static final float NEGATIVE_FLOAT_NAN_UPPER_RANGE = Float.intBitsToFloat(0xffffffff);
- static final double POSITIVE_DOUBLE_NAN_LOWER_RANGE = Double.longBitsToDouble(0x7ff0000000000001L);
- static final double POSITIVE_DOUBLE_NAN_UPPER_RANGE = Double.longBitsToDouble(0x7fffffffffffffffL);
- static final double NEGATIVE_DOUBLE_NAN_LOWER_RANGE = Double.longBitsToDouble(0xfff0000000000001L);
- static final double NEGATIVE_DOUBLE_NAN_UPPER_RANGE = Double.longBitsToDouble(0xffffffffffffffffL);
+ static final double POSITIVE_DOUBLE_NAN_LOWER_RANGE =
+ Double.longBitsToDouble(0x7ff0000000000001L);
+ static final double POSITIVE_DOUBLE_NAN_UPPER_RANGE =
+ Double.longBitsToDouble(0x7fffffffffffffffL);
+ static final double NEGATIVE_DOUBLE_NAN_LOWER_RANGE =
+ Double.longBitsToDouble(0xfff0000000000001L);
+ static final double NEGATIVE_DOUBLE_NAN_UPPER_RANGE =
+ Double.longBitsToDouble(0xffffffffffffffffL);
@Test
void testSpark32BitMurmur3HashStrings() {
try (ColumnVector v0 = ColumnVector.fromStrings(
- "a", "B\nc", "dE\"\u0100\t\u0101 \ud720\ud721\\Fg2\'",
- "A very long (greater than 128 bytes/char string) to test a multi hash-step data point " +
- "in the MD5 hash function. This string needed to be longer.A 60 character string to " +
- "test MD5's message padding algorithm",
- "hiJ\ud720\ud721\ud720\ud721", null);
- ColumnVector result = Hash.murmurHash32(42, new ColumnVector[]{v0});
- ColumnVector expected = ColumnVector.fromBoxedInts(1485273170, 1709559900, 1423943036, 176121990, 1199621434, 42)) {
+ "a", "B\nc", "dE\"\u0100\t\u0101 \ud720\ud721\\Fg2'",
+ "A very long (greater than 128 bytes/char string) to test a multi hash-step data point " +
+ "in the MD5 hash function. This string needed to be longer.A 60 character string to " +
+ "test MD5's message padding algorithm",
+ "hiJ\ud720\ud721\ud720\ud721", null);
+ ColumnVector result = Hash.murmurHash32(42, new ColumnVector[] {v0});
+ ColumnVector expected = ColumnVector.fromBoxedInts(1485273170, 1709559900, 1423943036,
+ 176121990, 1199621434, 42)) {
assertColumnsAreEqual(expected, result);
}
}
@@ -57,8 +62,9 @@ void testSpark32BitMurmur3HashStrings() {
void testSpark32BitMurmur3HashInts() {
try (ColumnVector v0 = ColumnVector.fromBoxedInts(0, 100, null, null, Integer.MIN_VALUE, null);
ColumnVector v1 = ColumnVector.fromBoxedInts(0, null, -100, null, null, Integer.MAX_VALUE);
- ColumnVector result = Hash.murmurHash32(42, new ColumnVector[]{v0, v1});
- ColumnVector expected = ColumnVector.fromBoxedInts(59727262, 751823303, -1080202046, 42, 723455942, 133916647)) {
+ ColumnVector result = Hash.murmurHash32(42, new ColumnVector[] {v0, v1});
+ ColumnVector expected = ColumnVector.fromBoxedInts(59727262, 751823303, -1080202046, 42,
+ 723455942, 133916647)) {
assertColumnsAreEqual(expected, result);
}
}
@@ -66,12 +72,14 @@ void testSpark32BitMurmur3HashInts() {
@Test
void testSpark32BitMurmur3HashDoubles() {
try (ColumnVector v = ColumnVector.fromBoxedDoubles(
- 0.0, null, 100.0, -100.0, Double.MIN_NORMAL, Double.MAX_VALUE,
- POSITIVE_DOUBLE_NAN_UPPER_RANGE, POSITIVE_DOUBLE_NAN_LOWER_RANGE,
- NEGATIVE_DOUBLE_NAN_UPPER_RANGE, NEGATIVE_DOUBLE_NAN_LOWER_RANGE,
- Double.POSITIVE_INFINITY, Double.NEGATIVE_INFINITY);
- ColumnVector result = Hash.murmurHash32(new ColumnVector[]{v});
- ColumnVector expected = ColumnVector.fromBoxedInts(1669671676, 0, -544903190, -1831674681, 150502665, 474144502, 1428788237, 1428788237, 1428788237, 1428788237, 420913893, 1915664072)) {
+ 0.0, null, 100.0, -100.0, Double.MIN_NORMAL, Double.MAX_VALUE,
+ POSITIVE_DOUBLE_NAN_UPPER_RANGE, POSITIVE_DOUBLE_NAN_LOWER_RANGE,
+ NEGATIVE_DOUBLE_NAN_UPPER_RANGE, NEGATIVE_DOUBLE_NAN_LOWER_RANGE,
+ Double.POSITIVE_INFINITY, Double.NEGATIVE_INFINITY);
+ ColumnVector result = Hash.murmurHash32(new ColumnVector[] {v});
+ ColumnVector expected = ColumnVector.fromBoxedInts(1669671676, 0, -544903190, -1831674681,
+ 150502665, 474144502, 1428788237, 1428788237, 1428788237, 1428788237, 420913893,
+ 1915664072)) {
assertColumnsAreEqual(expected, result);
}
}
@@ -82,8 +90,9 @@ void testSpark32BitMurmur3HashTimestamps() {
// https://github.com/rapidsai/cudf/blob/aa7ca46dcd9e/cpp/tests/hashing/hash_test.cpp#L281-L307
try (ColumnVector v = ColumnVector.timestampMicroSecondsFromBoxedLongs(
0L, null, 100L, -100L, 0x123456789abcdefL, null, -0x123456789abcdefL);
- ColumnVector result = Hash.murmurHash32(42, new ColumnVector[]{v});
- ColumnVector expected = ColumnVector.fromBoxedInts(-1670924195, 42, 1114849490, 904948192, 657182333, 42, -57193045)) {
+ ColumnVector result = Hash.murmurHash32(42, new ColumnVector[] {v});
+ ColumnVector expected = ColumnVector.fromBoxedInts(-1670924195, 42, 1114849490, 904948192,
+ 657182333, 42, -57193045)) {
assertColumnsAreEqual(expected, result);
}
}
@@ -94,8 +103,9 @@ void testSpark32BitMurmur3HashDecimal64() {
// https://github.com/rapidsai/cudf/blob/aa7ca46dcd9e/cpp/tests/hashing/hash_test.cpp#L281-L307
try (ColumnVector v = ColumnVector.decimalFromLongs(-7,
0L, 100L, -100L, 0x123456789abcdefL, -0x123456789abcdefL);
- ColumnVector result = Hash.murmurHash32(42, new ColumnVector[]{v});
- ColumnVector expected = ColumnVector.fromBoxedInts(-1670924195, 1114849490, 904948192, 657182333, -57193045)) {
+ ColumnVector result = Hash.murmurHash32(42, new ColumnVector[] {v});
+ ColumnVector expected = ColumnVector.fromBoxedInts(-1670924195, 1114849490, 904948192,
+ 657182333, -57193045)) {
assertColumnsAreEqual(expected, result);
}
}
@@ -106,8 +116,9 @@ void testSpark32BitMurmur3HashDecimal32() {
// https://github.com/rapidsai/cudf/blob/aa7ca46dcd9e/cpp/tests/hashing/hash_test.cpp#L281-L307
try (ColumnVector v = ColumnVector.decimalFromInts(-3,
0, 100, -100, 0x12345678, -0x12345678);
- ColumnVector result = Hash.murmurHash32(42, new ColumnVector[]{v});
- ColumnVector expected = ColumnVector.fromBoxedInts(-1670924195, 1114849490, 904948192, -958054811, -1447702630)) {
+ ColumnVector result = Hash.murmurHash32(42, new ColumnVector[] {v});
+ ColumnVector expected = ColumnVector.fromBoxedInts(-1670924195, 1114849490, 904948192,
+ -958054811, -1447702630)) {
assertColumnsAreEqual(expected, result);
}
}
@@ -118,8 +129,9 @@ void testSpark32BitMurmur3HashDates() {
// https://github.com/rapidsai/cudf/blob/aa7ca46dcd9e/cpp/tests/hashing/hash_test.cpp#L281-L307
try (ColumnVector v = ColumnVector.timestampDaysFromBoxedInts(
0, null, 100, -100, 0x12345678, null, -0x12345678);
- ColumnVector result = Hash.murmurHash32(42, new ColumnVector[]{v});
- ColumnVector expected = ColumnVector.fromBoxedInts(933211791, 42, 751823303, -1080202046, -1721170160, 42, 1852996993)) {
+ ColumnVector result = Hash.murmurHash32(42, new ColumnVector[] {v});
+ ColumnVector expected = ColumnVector.fromBoxedInts(933211791, 42, 751823303, -1080202046,
+ -1721170160, 42, 1852996993)) {
assertColumnsAreEqual(expected, result);
}
}
@@ -127,12 +139,14 @@ void testSpark32BitMurmur3HashDates() {
@Test
void testSpark32BitMurmur3HashFloats() {
try (ColumnVector v = ColumnVector.fromBoxedFloats(
- 0f, 100f, -100f, Float.MIN_NORMAL, Float.MAX_VALUE, null,
- POSITIVE_FLOAT_NAN_LOWER_RANGE, POSITIVE_FLOAT_NAN_UPPER_RANGE,
- NEGATIVE_FLOAT_NAN_LOWER_RANGE, NEGATIVE_FLOAT_NAN_UPPER_RANGE,
- Float.POSITIVE_INFINITY, Float.NEGATIVE_INFINITY);
- ColumnVector result = Hash.murmurHash32(411, new ColumnVector[]{v});
- ColumnVector expected = ColumnVector.fromBoxedInts(-235179434, 1812056886, 2028471189, 1775092689, -1531511762, 411, -1053523253, -1053523253, -1053523253, -1053523253, -1526256646, 930080402)){
+ 0f, 100f, -100f, Float.MIN_NORMAL, Float.MAX_VALUE, null,
+ POSITIVE_FLOAT_NAN_LOWER_RANGE, POSITIVE_FLOAT_NAN_UPPER_RANGE,
+ NEGATIVE_FLOAT_NAN_LOWER_RANGE, NEGATIVE_FLOAT_NAN_UPPER_RANGE,
+ Float.POSITIVE_INFINITY, Float.NEGATIVE_INFINITY);
+ ColumnVector result = Hash.murmurHash32(411, new ColumnVector[] {v});
+ ColumnVector expected = ColumnVector.fromBoxedInts(-235179434, 1812056886, 2028471189,
+ 1775092689, -1531511762, 411, -1053523253, -1053523253, -1053523253, -1053523253,
+ -1526256646, 930080402)) {
assertColumnsAreEqual(expected, result);
}
}
@@ -141,8 +155,9 @@ void testSpark32BitMurmur3HashFloats() {
void testSpark32BitMurmur3HashBools() {
try (ColumnVector v0 = ColumnVector.fromBoxedBooleans(null, true, false, true, null, false);
ColumnVector v1 = ColumnVector.fromBoxedBooleans(null, true, false, null, false, true);
- ColumnVector result = Hash.murmurHash32(0, new ColumnVector[]{v0, v1});
- ColumnVector expected = ColumnVector.fromBoxedInts(0, -1589400010, -239939054, -68075478, 593689054, -1194558265)) {
+ ColumnVector result = Hash.murmurHash32(0, new ColumnVector[] {v0, v1});
+ ColumnVector expected = ColumnVector.fromBoxedInts(0, -1589400010, -239939054, -68075478,
+ 593689054, -1194558265)) {
assertColumnsAreEqual(expected, result);
}
}
@@ -150,18 +165,22 @@ void testSpark32BitMurmur3HashBools() {
@Test
void testSpark32BitMurmur3HashMixed() {
try (ColumnVector strings = ColumnVector.fromStrings(
- "a", "B\n", "dE\"\u0100\t\u0101 \ud720\ud721",
- "A very long (greater than 128 bytes/char string) to test a multi hash-step data point " +
- "in the MD5 hash function. This string needed to be longer.",
- null, null);
- ColumnVector integers = ColumnVector.fromBoxedInts(0, 100, -100, Integer.MIN_VALUE, Integer.MAX_VALUE, null);
+ "a", "B\n", "dE\"\u0100\t\u0101 \ud720\ud721",
+ "A very long (greater than 128 bytes/char string) to test a multi hash-step data point " +
+ "in the MD5 hash function. This string needed to be longer.",
+ null, null);
+ ColumnVector integers = ColumnVector.fromBoxedInts(0, 100, -100, Integer.MIN_VALUE,
+ Integer.MAX_VALUE, null);
ColumnVector doubles = ColumnVector.fromBoxedDoubles(
- 0.0, 100.0, -100.0, POSITIVE_DOUBLE_NAN_LOWER_RANGE, POSITIVE_DOUBLE_NAN_UPPER_RANGE, null);
+ 0.0, 100.0, -100.0, POSITIVE_DOUBLE_NAN_LOWER_RANGE, POSITIVE_DOUBLE_NAN_UPPER_RANGE,
+ null);
ColumnVector floats = ColumnVector.fromBoxedFloats(
- 0f, 100f, -100f, NEGATIVE_FLOAT_NAN_LOWER_RANGE, NEGATIVE_FLOAT_NAN_UPPER_RANGE, null);
+ 0f, 100f, -100f, NEGATIVE_FLOAT_NAN_LOWER_RANGE, NEGATIVE_FLOAT_NAN_UPPER_RANGE, null);
ColumnVector bools = ColumnVector.fromBoxedBooleans(true, false, null, false, true, null);
- ColumnVector result = Hash.murmurHash32(1868, new ColumnVector[]{strings, integers, doubles, floats, bools});
- ColumnVector expected = ColumnVector.fromBoxedInts(1936985022, 720652989, 339312041, 1400354989, 769988643, 1868)) {
+ ColumnVector result = Hash.murmurHash32(1868,
+ new ColumnVector[] {strings, integers, doubles, floats, bools});
+ ColumnVector expected = ColumnVector.fromBoxedInts(1936985022, 720652989, 339312041,
+ 1400354989, 769988643, 1868)) {
assertColumnsAreEqual(expected, result);
}
}
@@ -173,15 +192,18 @@ void testSpark32BitMurmur3HashStruct() {
"A very long (greater than 128 bytes/char string) to test a multi hash-step data point " +
"in the MD5 hash function. This string needed to be longer.",
null, null);
- ColumnVector integers = ColumnVector.fromBoxedInts(0, 100, -100, Integer.MIN_VALUE, Integer.MAX_VALUE, null);
+ ColumnVector integers = ColumnVector.fromBoxedInts(0, 100, -100, Integer.MIN_VALUE,
+ Integer.MAX_VALUE, null);
ColumnVector doubles = ColumnVector.fromBoxedDoubles(
- 0.0, 100.0, -100.0, POSITIVE_DOUBLE_NAN_LOWER_RANGE, POSITIVE_DOUBLE_NAN_UPPER_RANGE, null);
+ 0.0, 100.0, -100.0, POSITIVE_DOUBLE_NAN_LOWER_RANGE, POSITIVE_DOUBLE_NAN_UPPER_RANGE,
+ null);
ColumnVector floats = ColumnVector.fromBoxedFloats(
0f, 100f, -100f, NEGATIVE_FLOAT_NAN_LOWER_RANGE, NEGATIVE_FLOAT_NAN_UPPER_RANGE, null);
ColumnVector bools = ColumnVector.fromBoxedBooleans(true, false, null, false, true, null);
ColumnView structs = ColumnView.makeStructView(strings, integers, doubles, floats, bools);
- ColumnVector result = Hash.murmurHash32(1868, new ColumnView[]{structs});
- ColumnVector expected = Hash.murmurHash32(1868, new ColumnVector[]{strings, integers, doubles, floats, bools})) {
+ ColumnVector result = Hash.murmurHash32(1868, new ColumnView[] {structs});
+ ColumnVector expected = Hash.murmurHash32(1868,
+ new ColumnVector[] {strings, integers, doubles, floats, bools})) {
assertColumnsAreEqual(expected, result);
}
}
@@ -193,9 +215,11 @@ void testSpark32BitMurmur3HashNestedStruct() {
"A very long (greater than 128 bytes/char string) to test a multi hash-step data point " +
"in the MD5 hash function. This string needed to be longer.",
null, null);
- ColumnVector integers = ColumnVector.fromBoxedInts(0, 100, -100, Integer.MIN_VALUE, Integer.MAX_VALUE, null);
+ ColumnVector integers = ColumnVector.fromBoxedInts(0, 100, -100, Integer.MIN_VALUE,
+ Integer.MAX_VALUE, null);
ColumnVector doubles = ColumnVector.fromBoxedDoubles(
- 0.0, 100.0, -100.0, POSITIVE_DOUBLE_NAN_LOWER_RANGE, POSITIVE_DOUBLE_NAN_UPPER_RANGE, null);
+ 0.0, 100.0, -100.0, POSITIVE_DOUBLE_NAN_LOWER_RANGE, POSITIVE_DOUBLE_NAN_UPPER_RANGE,
+ null);
ColumnVector floats = ColumnVector.fromBoxedFloats(
0f, 100f, -100f, NEGATIVE_FLOAT_NAN_LOWER_RANGE, NEGATIVE_FLOAT_NAN_UPPER_RANGE, null);
ColumnVector bools = ColumnVector.fromBoxedBooleans(true, false, null, false, true, null);
@@ -203,8 +227,9 @@ void testSpark32BitMurmur3HashNestedStruct() {
ColumnView structs2 = ColumnView.makeStructView(structs1, doubles);
ColumnView structs3 = ColumnView.makeStructView(bools);
ColumnView structs = ColumnView.makeStructView(structs2, floats, structs3);
- ColumnVector expected = Hash.murmurHash32(1868, new ColumnVector[]{strings, integers, doubles, floats, bools});
- ColumnVector result = Hash.murmurHash32(1868, new ColumnView[]{structs})) {
+ ColumnVector expected = Hash.murmurHash32(1868,
+ new ColumnVector[] {strings, integers, doubles, floats, bools});
+ ColumnVector result = Hash.murmurHash32(1868, new ColumnView[] {structs})) {
assertColumnsAreEqual(expected, result);
}
}
@@ -212,23 +237,24 @@ void testSpark32BitMurmur3HashNestedStruct() {
@Test
void testSpark32BitMurmur3HashListsAndNestedLists() {
try (ColumnVector stringListCV = ColumnVector.fromLists(
- new ListType(true, new BasicType(true, DType.STRING)),
- Arrays.asList(null, "a"),
- Arrays.asList("B\n", ""),
- Arrays.asList("dE\"\u0100\t\u0101", " \ud720\ud721"),
- Collections.singletonList("A very long (greater than 128 bytes/char string) to test a multi" +
- " hash-step data point in the Murmur3 hash function. This string needed to be longer."),
- Collections.singletonList(""),
- null);
+ new ListType(true, new BasicType(true, DType.STRING)),
+ Arrays.asList(null, "a"),
+ Arrays.asList("B\n", ""),
+ Arrays.asList("dE\"\u0100\t\u0101", " \ud720\ud721"),
+ Collections.singletonList(
+ "A very long (greater than 128 bytes/char string) to test a multi" +
+ " hash-step data point in the Murmur3 hash function. This string needed to be longer."),
+ Collections.singletonList(""),
+ null);
ColumnVector strings1 = ColumnVector.fromStrings(
"a", "B\n", "dE\"\u0100\t\u0101",
"A very long (greater than 128 bytes/char string) to test a multi hash-step data point " +
- "in the Murmur3 hash function. This string needed to be longer.", null, null);
+ "in the Murmur3 hash function. This string needed to be longer.", null, null);
ColumnVector strings2 = ColumnVector.fromStrings(
null, "", " \ud720\ud721", null, "", null);
ColumnView stringStruct = ColumnView.makeStructView(strings1, strings2);
- ColumnVector stringExpected = Hash.murmurHash32(1868, new ColumnView[]{stringStruct});
- ColumnVector stringResult = Hash.murmurHash32(1868, new ColumnView[]{stringListCV});
+ ColumnVector stringExpected = Hash.murmurHash32(1868, new ColumnView[] {stringStruct});
+ ColumnVector stringResult = Hash.murmurHash32(1868, new ColumnView[] {stringListCV});
ColumnVector intListCV = ColumnVector.fromLists(
new ListType(true, new BasicType(true, DType.INT32)),
null,
@@ -237,21 +263,25 @@ void testSpark32BitMurmur3HashListsAndNestedLists() {
Arrays.asList(5, -6, null),
Collections.singletonList(Integer.MIN_VALUE),
null);
- ColumnVector integers1 = ColumnVector.fromBoxedInts(null, 0, null, 5, Integer.MIN_VALUE, null);
- ColumnVector integers2 = ColumnVector.fromBoxedInts(null, -2, Integer.MAX_VALUE, null, null, null);
+ ColumnVector integers1 = ColumnVector.fromBoxedInts(null, 0, null, 5, Integer.MIN_VALUE,
+ null);
+ ColumnVector integers2 = ColumnVector.fromBoxedInts(null, -2, Integer.MAX_VALUE, null,
+ null, null);
ColumnVector integers3 = ColumnVector.fromBoxedInts(null, 3, null, -6, null, null);
ColumnVector intExpected =
- Hash.murmurHash32(1868, new ColumnVector[]{integers1, integers2, integers3});
- ColumnVector intResult = Hash.murmurHash32(1868, new ColumnVector[]{intListCV});
+ Hash.murmurHash32(1868, new ColumnVector[] {integers1, integers2, integers3});
+ ColumnVector intResult = Hash.murmurHash32(1868, new ColumnVector[] {intListCV});
ColumnVector doubles = ColumnVector.fromBoxedDoubles(
- 0.0, 100.0, -100.0, POSITIVE_DOUBLE_NAN_LOWER_RANGE, POSITIVE_DOUBLE_NAN_UPPER_RANGE, null);
+ 0.0, 100.0, -100.0, POSITIVE_DOUBLE_NAN_LOWER_RANGE, POSITIVE_DOUBLE_NAN_UPPER_RANGE,
+ null);
ColumnVector floats = ColumnVector.fromBoxedFloats(
- 0f, 100f, -100f, NEGATIVE_FLOAT_NAN_LOWER_RANGE, NEGATIVE_FLOAT_NAN_UPPER_RANGE, null);
+ 0f, 100f, -100f, NEGATIVE_FLOAT_NAN_LOWER_RANGE, NEGATIVE_FLOAT_NAN_UPPER_RANGE, null);
ColumnView structCV = ColumnView.makeStructView(intListCV, stringListCV, doubles, floats);
ColumnVector nestedExpected =
- Hash.murmurHash32(1868, new ColumnView[]{intListCV, strings1, strings2, doubles, floats});
+ Hash.murmurHash32(1868,
+ new ColumnView[] {intListCV, strings1, strings2, doubles, floats});
ColumnVector nestedResult =
- Hash.murmurHash32(1868, new ColumnView[]{structCV})) {
+ Hash.murmurHash32(1868, new ColumnView[] {structCV})) {
assertColumnsAreEqual(stringExpected, stringResult);
assertColumnsAreEqual(intExpected, intResult);
assertColumnsAreEqual(nestedExpected, nestedResult);
@@ -261,13 +291,15 @@ void testSpark32BitMurmur3HashListsAndNestedLists() {
@Test
void testXXHash64Strings() {
try (ColumnVector v0 = ColumnVector.fromStrings(
- "a", "B\nc", "dE\"\u0100\t\u0101 \ud720\ud721\\Fg2\'",
- "A very long (greater than 128 bytes/char string) to test a multi hash-step data point " +
- "in the MD5 hash function. This string needed to be longer.A 60 character string to " +
- "test MD5's message padding algorithm",
- "hiJ\ud720\ud721\ud720\ud721", null);
- ColumnVector result = Hash.xxhash64(new ColumnVector[]{v0});
- ColumnVector expected = ColumnVector.fromBoxedLongs(-8582455328737087284L, 2221214721321197934L, 5798966295358745941L, -4834097201550955483L, -3782648123388245694L, Hash.DEFAULT_XXHASH64_SEED)) {
+ "a", "B\nc", "dE\"\u0100\t\u0101 \ud720\ud721\\Fg2'",
+ "A very long (greater than 128 bytes/char string) to test a multi hash-step data point " +
+ "in the MD5 hash function. This string needed to be longer.A 60 character string to " +
+ "test MD5's message padding algorithm",
+ "hiJ\ud720\ud721\ud720\ud721", null);
+ ColumnVector result = Hash.xxhash64(new ColumnVector[] {v0});
+ ColumnVector expected = ColumnVector.fromBoxedLongs(-8582455328737087284L,
+ 2221214721321197934L, 5798966295358745941L, -4834097201550955483L,
+ -3782648123388245694L, Hash.DEFAULT_XXHASH64_SEED)) {
assertColumnsAreEqual(expected, result);
}
}
@@ -276,57 +308,69 @@ void testXXHash64Strings() {
void testXXHash64Ints() {
try (ColumnVector v0 = ColumnVector.fromBoxedInts(0, 100, null, null, Integer.MIN_VALUE, null);
ColumnVector v1 = ColumnVector.fromBoxedInts(0, null, -100, null, null, Integer.MAX_VALUE);
- ColumnVector result = Hash.xxhash64(new ColumnVector[]{v0, v1});
- ColumnVector expected = ColumnVector.fromBoxedLongs(1151812168208346021L, -7987742665087449293L, 8990748234399402673L, Hash.DEFAULT_XXHASH64_SEED, 2073849959933241805L, 1508894993788531228L)) {
+ ColumnVector result = Hash.xxhash64(new ColumnVector[] {v0, v1});
+ ColumnVector expected = ColumnVector.fromBoxedLongs(1151812168208346021L,
+ -7987742665087449293L, 8990748234399402673L, Hash.DEFAULT_XXHASH64_SEED,
+ 2073849959933241805L, 1508894993788531228L)) {
assertColumnsAreEqual(expected, result);
}
}
-
+
@Test
void testXXHash64Doubles() {
try (ColumnVector v = ColumnVector.fromBoxedDoubles(
- 0.0, null, 100.0, -100.0, Double.MIN_NORMAL, Double.MAX_VALUE,
- POSITIVE_DOUBLE_NAN_UPPER_RANGE, POSITIVE_DOUBLE_NAN_LOWER_RANGE,
- NEGATIVE_DOUBLE_NAN_UPPER_RANGE, NEGATIVE_DOUBLE_NAN_LOWER_RANGE,
- Double.POSITIVE_INFINITY, Double.NEGATIVE_INFINITY);
- ColumnVector result = Hash.xxhash64(new ColumnVector[]{v});
- ColumnVector expected = ColumnVector.fromBoxedLongs(-5252525462095825812L, Hash.DEFAULT_XXHASH64_SEED, -7996023612001835843L, 5695175288042369293L, 6181148431538304986L, -4222314252576420879L, -3127944061524951246L, -3127944061524951246L, -3127944061524951246L, -3127944061524951246L, 5810986238603807492L, 5326262080505358431L)) {
+ 0.0, null, 100.0, -100.0, Double.MIN_NORMAL, Double.MAX_VALUE,
+ POSITIVE_DOUBLE_NAN_UPPER_RANGE, POSITIVE_DOUBLE_NAN_LOWER_RANGE,
+ NEGATIVE_DOUBLE_NAN_UPPER_RANGE, NEGATIVE_DOUBLE_NAN_LOWER_RANGE,
+ Double.POSITIVE_INFINITY, Double.NEGATIVE_INFINITY);
+ ColumnVector result = Hash.xxhash64(new ColumnVector[] {v});
+ ColumnVector expected = ColumnVector.fromBoxedLongs(-5252525462095825812L,
+ Hash.DEFAULT_XXHASH64_SEED, -7996023612001835843L, 5695175288042369293L,
+ 6181148431538304986L, -4222314252576420879L, -3127944061524951246L,
+ -3127944061524951246L, -3127944061524951246L, -3127944061524951246L,
+ 5810986238603807492L, 5326262080505358431L)) {
assertColumnsAreEqual(expected, result);
}
}
-
+
@Test
void testXXHash64Timestamps() {
// The hash values were derived from Apache Spark in a manner similar to the one documented at
// https://github.com/rapidsai/cudf/blob/aa7ca46dcd9e/cpp/tests/hashing/hash_test.cpp#L281-L307
try (ColumnVector v = ColumnVector.timestampMicroSecondsFromBoxedLongs(
0L, null, 100L, -100L, 0x123456789abcdefL, null, -0x123456789abcdefL);
- ColumnVector result = Hash.xxhash64(new ColumnVector[]{v});
- ColumnVector expected = ColumnVector.fromBoxedLongs(-5252525462095825812L, Hash.DEFAULT_XXHASH64_SEED, 8713583529807266080L, 5675770457807661948L, 1941233597257011502L, Hash.DEFAULT_XXHASH64_SEED, -1318946533059658749L)) {
+ ColumnVector result = Hash.xxhash64(new ColumnVector[] {v});
+ ColumnVector expected = ColumnVector.fromBoxedLongs(-5252525462095825812L,
+ Hash.DEFAULT_XXHASH64_SEED, 8713583529807266080L, 5675770457807661948L,
+ 1941233597257011502L, Hash.DEFAULT_XXHASH64_SEED, -1318946533059658749L)) {
assertColumnsAreEqual(expected, result);
}
}
-
+
@Test
void testXXHash64Decimal64() {
// The hash values were derived from Apache Spark in a manner similar to the one documented at
// https://github.com/rapidsai/cudf/blob/aa7ca46dcd9e/cpp/tests/hashing/hash_test.cpp#L281-L307
try (ColumnVector v = ColumnVector.decimalFromLongs(-7,
0L, 100L, -100L, 0x123456789abcdefL, -0x123456789abcdefL);
- ColumnVector result = Hash.xxhash64(new ColumnVector[]{v});
- ColumnVector expected = ColumnVector.fromBoxedLongs(-5252525462095825812L, 8713583529807266080L, 5675770457807661948L, 1941233597257011502L, -1318946533059658749L)) {
+ ColumnVector result = Hash.xxhash64(new ColumnVector[] {v});
+ ColumnVector expected = ColumnVector.fromBoxedLongs(-5252525462095825812L,
+ 8713583529807266080L, 5675770457807661948L, 1941233597257011502L,
+ -1318946533059658749L)) {
assertColumnsAreEqual(expected, result);
}
}
-
+
@Test
void testXXHash64Decimal32() {
// The hash values were derived from Apache Spark in a manner similar to the one documented at
// https://github.com/rapidsai/cudf/blob/aa7ca46dcd9e/cpp/tests/hashing/hash_test.cpp#L281-L307
try (ColumnVector v = ColumnVector.decimalFromInts(-3,
0, 100, -100, 0x12345678, -0x12345678);
- ColumnVector result = Hash.xxhash64(new ColumnVector[]{v});
- ColumnVector expected = ColumnVector.fromBoxedLongs(-5252525462095825812L, 8713583529807266080L, 5675770457807661948L, -7728554078125612835L, 3142315292375031143L)) {
+ ColumnVector result = Hash.xxhash64(new ColumnVector[] {v});
+ ColumnVector expected = ColumnVector.fromBoxedLongs(-5252525462095825812L,
+ 8713583529807266080L, 5675770457807661948L, -7728554078125612835L,
+ 3142315292375031143L)) {
assertColumnsAreEqual(expected, result);
}
}
@@ -337,8 +381,10 @@ void testXXHash64Dates() {
// https://github.com/rapidsai/cudf/blob/aa7ca46dcd9e/cpp/tests/hashing/hash_test.cpp#L281-L307
try (ColumnVector v = ColumnVector.timestampDaysFromBoxedInts(
0, null, 100, -100, 0x12345678, null, -0x12345678);
- ColumnVector result = Hash.xxhash64(new ColumnVector[]{v});
- ColumnVector expected = ColumnVector.fromBoxedLongs(3614696996920510707L, Hash.DEFAULT_XXHASH64_SEED, -7987742665087449293L, 8990748234399402673L, 6954428822481665164L, Hash.DEFAULT_XXHASH64_SEED, -4294222333805341278L)) {
+ ColumnVector result = Hash.xxhash64(new ColumnVector[] {v});
+ ColumnVector expected = ColumnVector.fromBoxedLongs(3614696996920510707L,
+ Hash.DEFAULT_XXHASH64_SEED, -7987742665087449293L, 8990748234399402673L,
+ 6954428822481665164L, Hash.DEFAULT_XXHASH64_SEED, -4294222333805341278L)) {
assertColumnsAreEqual(expected, result);
}
}
@@ -346,12 +392,16 @@ void testXXHash64Dates() {
@Test
void testXXHash64Floats() {
try (ColumnVector v = ColumnVector.fromBoxedFloats(
- 0f, 100f, -100f, Float.MIN_NORMAL, Float.MAX_VALUE, null,
- POSITIVE_FLOAT_NAN_LOWER_RANGE, POSITIVE_FLOAT_NAN_UPPER_RANGE,
- NEGATIVE_FLOAT_NAN_LOWER_RANGE, NEGATIVE_FLOAT_NAN_UPPER_RANGE,
- Float.POSITIVE_INFINITY, Float.NEGATIVE_INFINITY);
- ColumnVector result = Hash.xxhash64(new ColumnVector[]{v});
- ColumnVector expected = ColumnVector.fromBoxedLongs(3614696996920510707L, -8232251799677946044L, -6625719127870404449L, -6699704595004115126L, -1065250890878313112L, Hash.DEFAULT_XXHASH64_SEED, 2692338816207849720L, 2692338816207849720L, 2692338816207849720L, 2692338816207849720L, -5940311692336719973L, -7580553461823983095L)){
+ 0f, 100f, -100f, Float.MIN_NORMAL, Float.MAX_VALUE, null,
+ POSITIVE_FLOAT_NAN_LOWER_RANGE, POSITIVE_FLOAT_NAN_UPPER_RANGE,
+ NEGATIVE_FLOAT_NAN_LOWER_RANGE, NEGATIVE_FLOAT_NAN_UPPER_RANGE,
+ Float.POSITIVE_INFINITY, Float.NEGATIVE_INFINITY);
+ ColumnVector result = Hash.xxhash64(new ColumnVector[] {v});
+ ColumnVector expected = ColumnVector.fromBoxedLongs(3614696996920510707L,
+ -8232251799677946044L, -6625719127870404449L, -6699704595004115126L,
+ -1065250890878313112L, Hash.DEFAULT_XXHASH64_SEED, 2692338816207849720L,
+ 2692338816207849720L, 2692338816207849720L, 2692338816207849720L,
+ -5940311692336719973L, -7580553461823983095L)) {
assertColumnsAreEqual(expected, result);
}
}
@@ -360,27 +410,34 @@ void testXXHash64Floats() {
void testXXHash64Bools() {
try (ColumnVector v0 = ColumnVector.fromBoxedBooleans(null, true, false, true, null, false);
ColumnVector v1 = ColumnVector.fromBoxedBooleans(null, true, false, null, false, true);
- ColumnVector result = Hash.xxhash64(new ColumnVector[]{v0, v1});
- ColumnVector expected = ColumnVector.fromBoxedLongs(Hash.DEFAULT_XXHASH64_SEED, 9083826852238114423L, 1151812168208346021L, -6698625589789238999L, 3614696996920510707L, 7945966957015589024L)) {
+ ColumnVector result = Hash.xxhash64(new ColumnVector[] {v0, v1});
+ ColumnVector expected = ColumnVector.fromBoxedLongs(Hash.DEFAULT_XXHASH64_SEED,
+ 9083826852238114423L, 1151812168208346021L, -6698625589789238999L,
+ 3614696996920510707L, 7945966957015589024L)) {
assertColumnsAreEqual(expected, result);
}
}
-
+
@Test
void testXXHash64Mixed() {
try (ColumnVector strings = ColumnVector.fromStrings(
- "a", "B\n", "dE\"\u0100\t\u0101 \ud720\ud721",
- "A very long (greater than 128 bytes/char string) to test a multi hash-step data point " +
- "in the MD5 hash function. This string needed to be longer.",
- null, null);
- ColumnVector integers = ColumnVector.fromBoxedInts(0, 100, -100, Integer.MIN_VALUE, Integer.MAX_VALUE, null);
+ "a", "B\n", "dE\"\u0100\t\u0101 \ud720\ud721",
+ "A very long (greater than 128 bytes/char string) to test a multi hash-step data point " +
+ "in the MD5 hash function. This string needed to be longer.",
+ null, null);
+ ColumnVector integers = ColumnVector.fromBoxedInts(0, 100, -100, Integer.MIN_VALUE,
+ Integer.MAX_VALUE, null);
ColumnVector doubles = ColumnVector.fromBoxedDoubles(
- 0.0, 100.0, -100.0, POSITIVE_DOUBLE_NAN_LOWER_RANGE, POSITIVE_DOUBLE_NAN_UPPER_RANGE, null);
+ 0.0, 100.0, -100.0, POSITIVE_DOUBLE_NAN_LOWER_RANGE, POSITIVE_DOUBLE_NAN_UPPER_RANGE,
+ null);
ColumnVector floats = ColumnVector.fromBoxedFloats(
- 0f, 100f, -100f, NEGATIVE_FLOAT_NAN_LOWER_RANGE, NEGATIVE_FLOAT_NAN_UPPER_RANGE, null);
+ 0f, 100f, -100f, NEGATIVE_FLOAT_NAN_LOWER_RANGE, NEGATIVE_FLOAT_NAN_UPPER_RANGE, null);
ColumnVector bools = ColumnVector.fromBoxedBooleans(true, false, null, false, true, null);
- ColumnVector result = Hash.xxhash64(new ColumnVector[]{strings, integers, doubles, floats, bools});
- ColumnVector expected = ColumnVector.fromBoxedLongs(7451748878409563026L, 6024043102550151964L, 3380664624738534402L, 8444697026100086329L, -5888679192448042852L, Hash.DEFAULT_XXHASH64_SEED)) {
+ ColumnVector result = Hash.xxhash64(
+ new ColumnVector[] {strings, integers, doubles, floats, bools});
+ ColumnVector expected = ColumnVector.fromBoxedLongs(7451748878409563026L,
+ 6024043102550151964L, 3380664624738534402L, 8444697026100086329L,
+ -5888679192448042852L, Hash.DEFAULT_XXHASH64_SEED)) {
assertColumnsAreEqual(expected, result);
}
}
@@ -388,7 +445,7 @@ void testXXHash64Mixed() {
@Test
void testHiveHashBools() {
try (ColumnVector v0 = ColumnVector.fromBoxedBooleans(true, false, null);
- ColumnVector result = Hash.hiveHash(new ColumnVector[]{v0});
+ ColumnVector result = Hash.hiveHash(new ColumnVector[] {v0});
ColumnVector expected = ColumnVector.fromInts(1, 0, 0)) {
assertColumnsAreEqual(expected, result);
}
@@ -397,10 +454,10 @@ void testHiveHashBools() {
@Test
void testHiveHashInts() {
try (ColumnVector v0 = ColumnVector.fromBoxedInts(
- Integer.MIN_VALUE, Integer.MAX_VALUE, -1, 1, -10, 10, null);
- ColumnVector result = Hash.hiveHash(new ColumnVector[]{v0});
+ Integer.MIN_VALUE, Integer.MAX_VALUE, -1, 1, -10, 10, null);
+ ColumnVector result = Hash.hiveHash(new ColumnVector[] {v0});
ColumnVector expected = ColumnVector.fromInts(
- Integer.MIN_VALUE, Integer.MAX_VALUE, -1, 1, -10, 10, 0)) {
+ Integer.MIN_VALUE, Integer.MAX_VALUE, -1, 1, -10, 10, 0)) {
assertColumnsAreEqual(expected, result);
}
}
@@ -408,10 +465,10 @@ void testHiveHashInts() {
@Test
void testHiveHashBytes() {
try (ColumnVector v0 = ColumnVector.fromBoxedBytes(
- Byte.MIN_VALUE, Byte.MAX_VALUE, (byte)-1, (byte)1, (byte)-10, (byte)10, null);
- ColumnVector result = Hash.hiveHash(new ColumnVector[]{v0});
+ Byte.MIN_VALUE, Byte.MAX_VALUE, (byte) -1, (byte) 1, (byte) -10, (byte) 10, null);
+ ColumnVector result = Hash.hiveHash(new ColumnVector[] {v0});
ColumnVector expected = ColumnVector.fromInts(
- Byte.MIN_VALUE, Byte.MAX_VALUE, -1, 1, -10, 10, 0)) {
+ Byte.MIN_VALUE, Byte.MAX_VALUE, -1, 1, -10, 10, 0)) {
assertColumnsAreEqual(expected, result);
}
}
@@ -419,10 +476,10 @@ void testHiveHashBytes() {
@Test
void testHiveHashLongs() {
try (ColumnVector v0 = ColumnVector.fromBoxedLongs(
- Long.MIN_VALUE, Long.MAX_VALUE, -1L, 1L, -10L, 10L, null);
- ColumnVector result = Hash.hiveHash(new ColumnVector[]{v0});
+ Long.MIN_VALUE, Long.MAX_VALUE, -1L, 1L, -10L, 10L, null);
+ ColumnVector result = Hash.hiveHash(new ColumnVector[] {v0});
ColumnVector expected = ColumnVector.fromInts(
- Integer.MIN_VALUE, Integer.MIN_VALUE, 0, 1, 9, 10, 0)) {
+ Integer.MIN_VALUE, Integer.MIN_VALUE, 0, 1, 9, 10, 0)) {
assertColumnsAreEqual(expected, result);
}
}
@@ -430,11 +487,11 @@ void testHiveHashLongs() {
@Test
void testHiveHashStrings() {
try (ColumnVector v0 = ColumnVector.fromStrings(
- "a", "B\n", "dE\"\u0100\t\u0101 \ud720\ud721", null,
- "This is a long string (greater than 128 bytes/char string) case to test this " +
- "hash function. Just want an abnormal case here to see if any error may happen when" +
- "doing the hive hashing");
- ColumnVector result = Hash.hiveHash(new ColumnVector[]{v0});
+ "a", "B\n", "dE\"\u0100\t\u0101 \ud720\ud721", null,
+ "This is a long string (greater than 128 bytes/char string) case to test this " +
+ "hash function. Just want an abnormal case here to see if any error may happen when" +
+ "doing the hive hashing");
+ ColumnVector result = Hash.hiveHash(new ColumnVector[] {v0});
ColumnVector expected = ColumnVector.fromInts(97, 2056, 745239896, 0, 2112075710)) {
assertColumnsAreEqual(expected, result);
}
@@ -443,13 +500,14 @@ void testHiveHashStrings() {
@Test
void testHiveHashFloats() {
try (ColumnVector v = ColumnVector.fromBoxedFloats(0f, 100f, -100f, Float.MIN_NORMAL,
- Float.MAX_VALUE, null, Float.MIN_VALUE,
- POSITIVE_FLOAT_NAN_LOWER_RANGE, POSITIVE_FLOAT_NAN_UPPER_RANGE,
- NEGATIVE_FLOAT_NAN_LOWER_RANGE, NEGATIVE_FLOAT_NAN_UPPER_RANGE,
- Float.POSITIVE_INFINITY, Float.NEGATIVE_INFINITY);
- ColumnVector result = Hash.hiveHash(new ColumnVector[]{v});
+ Float.MAX_VALUE, null, Float.MIN_VALUE,
+ POSITIVE_FLOAT_NAN_LOWER_RANGE, POSITIVE_FLOAT_NAN_UPPER_RANGE,
+ NEGATIVE_FLOAT_NAN_LOWER_RANGE, NEGATIVE_FLOAT_NAN_UPPER_RANGE,
+ Float.POSITIVE_INFINITY, Float.NEGATIVE_INFINITY);
+ ColumnVector result = Hash.hiveHash(new ColumnVector[] {v});
ColumnVector expected = ColumnVector.fromInts(0, 1120403456, -1027080192, 8388608,
- 2139095039, 0, 1, 2143289344, 2143289344, 2143289344, 2143289344, 2139095040, -8388608)){
+ 2139095039, 0, 1, 2143289344, 2143289344, 2143289344, 2143289344, 2139095040,
+ -8388608)) {
assertColumnsAreEqual(expected, result);
}
}
@@ -457,10 +515,10 @@ void testHiveHashFloats() {
@Test
void testHiveHashDoubles() {
try (ColumnVector v = ColumnVector.fromBoxedDoubles(0.0, 100.0, -100.0,
- POSITIVE_DOUBLE_NAN_LOWER_RANGE, POSITIVE_DOUBLE_NAN_UPPER_RANGE, null);
- ColumnVector result = Hash.hiveHash(new ColumnVector[]{v});
+ POSITIVE_DOUBLE_NAN_LOWER_RANGE, POSITIVE_DOUBLE_NAN_UPPER_RANGE, null);
+ ColumnVector result = Hash.hiveHash(new ColumnVector[] {v});
ColumnVector expected = ColumnVector.fromInts(0, 1079574528, -1067909120,
- 2146959360, 2146959360, 0)){
+ 2146959360, 2146959360, 0)) {
assertColumnsAreEqual(expected, result);
}
}
@@ -468,10 +526,10 @@ void testHiveHashDoubles() {
@Test
void testHiveHashDates() {
try (ColumnVector v = ColumnVector.timestampDaysFromBoxedInts(
- 0, null, 100, -100, 0x12345678, null, -0x12345678);
- ColumnVector result = Hash.hiveHash(new ColumnVector[]{v});
+ 0, null, 100, -100, 0x12345678, null, -0x12345678);
+ ColumnVector result = Hash.hiveHash(new ColumnVector[] {v});
ColumnVector expected = ColumnVector.fromInts(
- 0, 0, 100, -100, 0x12345678, 0, -0x12345678)) {
+ 0, 0, 100, -100, 0x12345678, 0, -0x12345678)) {
assertColumnsAreEqual(expected, result);
}
}
@@ -480,9 +538,9 @@ void testHiveHashDates() {
void testHiveHashTimestamps() {
try (ColumnVector v = ColumnVector.timestampMicroSecondsFromBoxedLongs(
0L, null, 100L, -100L, 0x123456789abcdefL, null, -0x123456789abcdefL);
- ColumnVector result = Hash.hiveHash(new ColumnVector[]{v});
+ ColumnVector result = Hash.hiveHash(new ColumnVector[] {v});
ColumnVector expected = ColumnVector.fromInts(
- 0, 0, 100000, 99999, -660040456, 0, 486894999)) {
+ 0, 0, 100000, 99999, -660040456, 0, 486894999)) {
assertColumnsAreEqual(expected, result);
}
}
@@ -490,23 +548,23 @@ void testHiveHashTimestamps() {
@Test
void testHiveHashMixed() {
try (ColumnVector strings = ColumnVector.fromStrings(
- "a", "B\n", "dE\"\u0100\t\u0101 \ud720\ud721",
- "This is a long string (greater than 128 bytes/char string) case to test this " +
- "hash function. Just want an abnormal case here to see if any error may happen when" +
- "doing the hive hashing",
- null, null);
+ "a", "B\n", "dE\"\u0100\t\u0101 \ud720\ud721",
+ "This is a long string (greater than 128 bytes/char string) case to test this " +
+ "hash function. Just want an abnormal case here to see if any error may happen when" +
+ "doing the hive hashing",
+ null, null);
ColumnVector integers = ColumnVector.fromBoxedInts(
- 0, 100, -100, Integer.MIN_VALUE, Integer.MAX_VALUE, null);
+ 0, 100, -100, Integer.MIN_VALUE, Integer.MAX_VALUE, null);
ColumnVector doubles = ColumnVector.fromBoxedDoubles(0.0, 100.0, -100.0,
- POSITIVE_DOUBLE_NAN_LOWER_RANGE, POSITIVE_DOUBLE_NAN_UPPER_RANGE, null);
+ POSITIVE_DOUBLE_NAN_LOWER_RANGE, POSITIVE_DOUBLE_NAN_UPPER_RANGE, null);
ColumnVector floats = ColumnVector.fromBoxedFloats(0f, 100f, -100f,
- NEGATIVE_FLOAT_NAN_LOWER_RANGE, NEGATIVE_FLOAT_NAN_UPPER_RANGE, null);
+ NEGATIVE_FLOAT_NAN_LOWER_RANGE, NEGATIVE_FLOAT_NAN_UPPER_RANGE, null);
ColumnVector bools = ColumnVector.fromBoxedBooleans(
- true, false, null, false, true, null);
- ColumnVector result = Hash.hiveHash(new ColumnVector[]{
- strings, integers, doubles, floats, bools});
+ true, false, null, false, true, null);
+ ColumnVector result = Hash.hiveHash(new ColumnVector[] {
+ strings, integers, doubles, floats, bools});
ColumnVector expected = ColumnVector.fromInts(89581538, 363542820, 413439036,
- 1272817854, 1513589666, 0)) {
+ 1272817854, 1513589666, 0)) {
assertColumnsAreEqual(expected, result);
}
}
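Note: a pattern running through these hash tests is that null inputs do not contribute: a row whose columns are all null hashes to the seed itself (42, 1868, or DEFAULT_XXHASH64_SEED in the expected columns; hiveHash similarly maps null to 0), and multi-column murmur3/xxhash64 feeds the running hash of earlier columns in as the seed for the next one, as the expected values suggest. Conceptually (helper names hypothetical, not the library's implementation):

    // Fold non-null values into a running hash that starts at the seed.
    int rowHash = seed;
    for (Object value : rowValues) {          // hypothetical per-row values
      if (value != null) {
        rowHash = hashOne(value, rowHash);    // hypothetical single-value hash, seeded by the running hash
      }
    }
    // All-null rows never update rowHash, so the result is just the seed.
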
diff --git a/src/test/java/com/nvidia/spark/rapids/jni/HilbertIndexTest.java b/src/test/java/com/nvidia/spark/rapids/jni/HilbertIndexTest.java
index 6226053612..47cdf36b28 100644
--- a/src/test/java/com/nvidia/spark/rapids/jni/HilbertIndexTest.java
+++ b/src/test/java/com/nvidia/spark/rapids/jni/HilbertIndexTest.java
@@ -16,18 +16,18 @@
package com.nvidia.spark.rapids.jni;
+import static ai.rapids.cudf.AssertUtils.assertColumnsAreEqual;
+
import ai.rapids.cudf.ColumnVector;
import org.davidmoten.hilbert.HilbertCurve;
import org.davidmoten.hilbert.SmallHilbertCurve;
import org.junit.jupiter.api.Test;
-import static ai.rapids.cudf.AssertUtils.assertColumnsAreEqual;
-
public class HilbertIndexTest {
static long[] getExpected(int numBits, int numRows, Integer[]... inputs) {
final int dimensions = inputs.length;
final int length = numBits * dimensions;
- assert(length <= 64);
+ assert (length <= 64);
SmallHilbertCurve shc = HilbertCurve.small().bits(numBits).dimensions(dimensions);
long[] ret = new long[numRows];
long[] tmpInputs = new long[dimensions];
@@ -60,7 +60,7 @@ public static void doTest(int numBits, int numRows, Integer[]... inputs) {
assertColumnsAreEqual(expectedCv, results);
}
} finally {
- for (ColumnVector cv: cvInputs) {
+ for (ColumnVector cv : cvInputs) {
if (cv != null) {
cv.close();
}
@@ -87,15 +87,15 @@ void test1Null() {
@Test
void testInt2NonNull() {
- Integer[] inputs1 = { 1, 500, 1000, 250};
- Integer[] inputs2 = {500, 400, 300, 200};
+ Integer[] inputs1 = {1, 500, 1000, 250};
+ Integer[] inputs2 = {500, 400, 300, 200};
doTest(10, inputs1.length, inputs1, inputs2);
}
@Test
void testInt2Null() {
- Integer[] inputs1 = { 0, null, 50, 1000};
- Integer[] inputs2 = {200, 300, 100, 0};
+ Integer[] inputs1 = {0, null, 50, 1000};
+ Integer[] inputs2 = {200, 300, 100, 0};
doTest(10, inputs1.length, inputs1, inputs2);
}
diff --git a/src/test/java/com/nvidia/spark/rapids/jni/HistogramTest.java b/src/test/java/com/nvidia/spark/rapids/jni/HistogramTest.java
index 9a1812f660..83c219e0a0 100644
--- a/src/test/java/com/nvidia/spark/rapids/jni/HistogramTest.java
+++ b/src/test/java/com/nvidia/spark/rapids/jni/HistogramTest.java
@@ -18,7 +18,6 @@
import ai.rapids.cudf.AssertUtils;
import ai.rapids.cudf.ColumnVector;
-
import org.junit.jupiter.api.Test;
public class HistogramTest {
@@ -27,7 +26,7 @@ void testZeroFrequency() {
try (ColumnVector values = ColumnVector.fromInts(5, 10, 30);
ColumnVector freqs = ColumnVector.fromLongs(1, 0, 1);
ColumnVector histogram = Histogram.createHistogramIfValid(values, freqs, true);
- ColumnVector percentiles = Histogram.percentileFromHistogram(histogram, new double[]{1},
+ ColumnVector percentiles = Histogram.percentileFromHistogram(histogram, new double[] {1},
false);
ColumnVector expected = ColumnVector.fromBoxedDoubles(5.0, null, 30.0)) {
AssertUtils.assertColumnsAreEqual(percentiles, expected);
@@ -39,7 +38,7 @@ void testAllNulls() {
try (ColumnVector values = ColumnVector.fromBoxedInts(null, null, null);
ColumnVector freqs = ColumnVector.fromLongs(1, 2, 3);
ColumnVector histogram = Histogram.createHistogramIfValid(values, freqs, true);
- ColumnVector percentiles = Histogram.percentileFromHistogram(histogram, new double[]{0.5},
+ ColumnVector percentiles = Histogram.percentileFromHistogram(histogram, new double[] {0.5},
false);
ColumnVector expected = ColumnVector.fromBoxedDoubles(null, null, null)) {
AssertUtils.assertColumnsAreEqual(percentiles, expected);
diff --git a/src/test/java/com/nvidia/spark/rapids/jni/HostTableTest.java b/src/test/java/com/nvidia/spark/rapids/jni/HostTableTest.java
index 0064dee1f5..76d1de597a 100644
--- a/src/test/java/com/nvidia/spark/rapids/jni/HostTableTest.java
+++ b/src/test/java/com/nvidia/spark/rapids/jni/HostTableTest.java
@@ -17,18 +17,30 @@
package com.nvidia.spark.rapids.jni;
import ai.rapids.cudf.AssertUtils;
+import ai.rapids.cudf.ColumnVector;
import ai.rapids.cudf.Cuda;
import ai.rapids.cudf.DType;
-import ai.rapids.cudf.ColumnVector;
-import ai.rapids.cudf.Table;
import ai.rapids.cudf.HostColumnVector.BasicType;
import ai.rapids.cudf.HostColumnVector.DataType;
import ai.rapids.cudf.HostColumnVector.ListType;
import ai.rapids.cudf.HostColumnVector.StructData;
import ai.rapids.cudf.HostColumnVector.StructType;
+import ai.rapids.cudf.Table;
import org.junit.jupiter.api.Test;
public class HostTableTest {
+ private static StructData struct(Object... values) {
+ return new StructData(values);
+ }
+
+ private static StructData[] structs(StructData... values) {
+ return values;
+ }
+
+ private static String[] strings(String... values) {
+ return values;
+ }
+
@Test
public void testRoundTripSync() {
try (Table expected = buildTable()) {
@@ -96,20 +108,36 @@ private Table buildTable() {
new BasicType(true, DType.INT32),
new BasicType(false, DType.FLOAT32));
return new Table.TestBuilder()
- .column( 100, 202, 3003, 40004, 5, -60, 1, null, 3, null, 5, null, 7, null, 9, null, 11, null, 13, null, 15)
- .column( true, true, false, false, true, null, true, true, null, false, false, null, true, true, null, false, false, null, true, true, null)
- .column( (byte)1, (byte)2, null, (byte)4, (byte)5, (byte)6, (byte)1, (byte)2, (byte)3, null, (byte)5, (byte)6, (byte)7, null, (byte)9, (byte)10, (byte)11, null, (byte)13, (byte)14, (byte)15)
- .column((short)6, (short)5, (short)4, null, (short)2, (short)1, (short)1, (short)2, (short)3, null, (short)5, (short)6, (short)7, null, (short)9, (short)10, null, (short)12, (short)13, (short)14, null)
- .column( 1L, null, 1001L, 50L, -2000L, null, 1L, 2L, 3L, 4L, null, 6L, 7L, 8L, 9L, null, 11L, 12L, 13L, 14L, null)
- .column( 10.1f, 20f, Float.NaN, 3.1415f, -60f, null, 1f, 2f, 3f, 4f, 5f, null, 7f, 8f, 9f, 10f, 11f, null, 13f, 14f, 15f)
- .column( 10.1f, 20f, Float.NaN, 3.1415f, -60f, -50f, 1f, 2f, 3f, 4f, 5f, 6f, 7f, 8f, 9f, 10f, 11f, 12f, 13f, 14f, 15f)
- .column( 10.1, 20.0, 33.1, 3.1415, -60.5, null, 1., 2., 3., 4., 5., 6., null, 8., 9., 10., 11., 12., null, 14., 15.)
- .timestampDayColumn(99, 100, 101, 102, 103, 104, 1, 2, 3, 4, 5, 6, 7, null, 9, 10, 11, 12, 13, null, 15)
- .timestampMillisecondsColumn(9L, 1006L, 101L, 5092L, null, 88L, 1L, 2L, 3L, 4L, 5L ,6L, 7L, 8L, null, 10L, 11L, 12L, 13L, 14L, 15L)
- .timestampSecondsColumn(1L, null, 3L, 4L, 5L, 6L, 1L, 2L, 3L, 4L, 5L ,6L, 7L, 8L, 9L, null, 11L, 12L, 13L, 14L, 15L)
- .decimal32Column(-3, 100, 202, 3003, 40004, 5, -60, 1, null, 3, null, 5, null, 7, null, 9, null, 11, null, 13, null, 15)
- .decimal64Column(-8, 1L, null, 1001L, 50L, -2000L, null, 1L, 2L, 3L, 4L, null, 6L, 7L, 8L, 9L, null, 11L, 12L, 13L, 14L, null)
- .column( "A", "B", "C", "D", null, "TESTING", "1", "2", "3", "4", "5", "6", "7", null, "9", "10", "11", "12", "13", null, "15")
+ .column(100, 202, 3003, 40004, 5, -60, 1, null, 3, null, 5, null, 7, null, 9, null, 11,
+ null, 13, null, 15)
+ .column(true, true, false, false, true, null, true, true, null, false, false, null, true,
+ true, null, false, false, null, true, true, null)
+ .column((byte) 1, (byte) 2, null, (byte) 4, (byte) 5, (byte) 6, (byte) 1, (byte) 2,
+ (byte) 3, null, (byte) 5, (byte) 6, (byte) 7, null, (byte) 9, (byte) 10, (byte) 11,
+ null, (byte) 13, (byte) 14, (byte) 15)
+ .column((short) 6, (short) 5, (short) 4, null, (short) 2, (short) 1, (short) 1, (short) 2,
+ (short) 3, null, (short) 5, (short) 6, (short) 7, null, (short) 9, (short) 10, null,
+ (short) 12, (short) 13, (short) 14, null)
+ .column(1L, null, 1001L, 50L, -2000L, null, 1L, 2L, 3L, 4L, null, 6L, 7L, 8L, 9L, null, 11L,
+ 12L, 13L, 14L, null)
+ .column(10.1f, 20f, Float.NaN, 3.1415f, -60f, null, 1f, 2f, 3f, 4f, 5f, null, 7f, 8f, 9f,
+ 10f, 11f, null, 13f, 14f, 15f)
+ .column(10.1f, 20f, Float.NaN, 3.1415f, -60f, -50f, 1f, 2f, 3f, 4f, 5f, 6f, 7f, 8f, 9f, 10f,
+ 11f, 12f, 13f, 14f, 15f)
+ .column(10.1, 20.0, 33.1, 3.1415, -60.5, null, 1., 2., 3., 4., 5., 6., null, 8., 9., 10.,
+ 11., 12., null, 14., 15.)
+ .timestampDayColumn(99, 100, 101, 102, 103, 104, 1, 2, 3, 4, 5, 6, 7, null, 9, 10, 11, 12,
+ 13, null, 15)
+ .timestampMillisecondsColumn(9L, 1006L, 101L, 5092L, null, 88L, 1L, 2L, 3L, 4L, 5L, 6L, 7L,
+ 8L, null, 10L, 11L, 12L, 13L, 14L, 15L)
+ .timestampSecondsColumn(1L, null, 3L, 4L, 5L, 6L, 1L, 2L, 3L, 4L, 5L, 6L, 7L, 8L, 9L, null,
+ 11L, 12L, 13L, 14L, 15L)
+ .decimal32Column(-3, 100, 202, 3003, 40004, 5, -60, 1, null, 3, null, 5, null, 7, null, 9,
+ null, 11, null, 13, null, 15)
+ .decimal64Column(-8, 1L, null, 1001L, 50L, -2000L, null, 1L, 2L, 3L, 4L, null, 6L, 7L, 8L,
+ 9L, null, 11L, 12L, 13L, 14L, null)
+ .column("A", "B", "C", "D", null, "TESTING", "1", "2", "3", "4", "5", "6", "7", null, "9",
+ "10", "11", "12", "13", null, "15")
.column(
strings("1", "2", "3"), strings("4"), strings("5"), strings("6, 7"),
strings("", "9", null), strings("11"), strings(""), strings(null, null),
@@ -135,19 +163,8 @@ null, structs(struct("3", "4"), struct("1", "2")),
struct(Integer.MAX_VALUE, Float.MAX_VALUE), null, null, null, null,
null, null, null, null, null,
struct(Integer.MIN_VALUE, Float.MIN_VALUE))
- .column( "A", "A", "C", "C", null, "TESTING", "1", "2", "3", "4", "5", "6", "7", null, "9", "10", "11", "12", "13", null, "15")
+ .column("A", "A", "C", "C", null, "TESTING", "1", "2", "3", "4", "5", "6", "7", null, "9",
+ "10", "11", "12", "13", null, "15")
.build();
}
-
- private static StructData struct(Object... values) {
- return new StructData(values);
- }
-
- private static StructData[] structs(StructData... values) {
- return values;
- }
-
- private static String[] strings(String... values) {
- return values;
- }
}
diff --git a/src/test/java/com/nvidia/spark/rapids/jni/InterleaveBitsTest.java b/src/test/java/com/nvidia/spark/rapids/jni/InterleaveBitsTest.java
index 7455f3b58c..0491d5bba2 100644
--- a/src/test/java/com/nvidia/spark/rapids/jni/InterleaveBitsTest.java
+++ b/src/test/java/com/nvidia/spark/rapids/jni/InterleaveBitsTest.java
@@ -16,18 +16,20 @@
package com.nvidia.spark.rapids.jni;
+import static ai.rapids.cudf.AssertUtils.assertColumnsAreEqual;
+
import ai.rapids.cudf.ColumnVector;
import ai.rapids.cudf.DType;
import ai.rapids.cudf.HostColumnVector;
-import org.junit.jupiter.api.Test;
-
import java.util.ArrayList;
import java.util.List;
-
-import static ai.rapids.cudf.AssertUtils.assertColumnsAreEqual;
+import org.junit.jupiter.api.Test;
public class InterleaveBitsTest {
+ public static HostColumnVector.DataType outputType =
+ new HostColumnVector.ListType(true, new HostColumnVector.BasicType(false, DType.UINT8));
+
// The following source of truth comes from deltalake, but translated to java, and uses a List
// to make our tests simpler. Deltalake only supports ints. For completeness and better
// performance in the future we support more than this.
@@ -48,7 +50,7 @@ static List<Byte> defaultInterleaveBits(Integer[] inputs) {
int idx = 0;
while (idx < inputs.length) {
int tmp = (((inputs[idx] >> bit) & 1) << ret_bit);
- ret_byte = (byte)(ret_byte | tmp);
+ ret_byte = (byte) (ret_byte | tmp);
ret_bit -= 1;
if (ret_bit == -1) {
// finished processing a byte
@@ -61,8 +63,8 @@ static List defaultInterleaveBits(Integer[] inputs) {
}
bit -= 1;
}
- assert(ret_idx == inputs.length * 4);
- assert(ret_bit == 7);
+ assert (ret_idx == inputs.length * 4);
+ assert (ret_bit == 7);
return ret;
}
@@ -83,7 +85,7 @@ static List<Byte> defaultInterleaveBits(Short[] inputs) {
int idx = 0;
while (idx < inputs.length) {
int tmp = (((inputs[idx] >> bit) & 1) << ret_bit);
- ret_byte = (byte)(ret_byte | tmp);
+ ret_byte = (byte) (ret_byte | tmp);
ret_bit -= 1;
if (ret_bit == -1) {
// finished processing a byte
@@ -96,8 +98,8 @@ static List defaultInterleaveBits(Short[] inputs) {
}
bit -= 1;
}
- assert(ret_idx == inputs.length * 2);
- assert(ret_bit == 7);
+ assert (ret_idx == inputs.length * 2);
+ assert (ret_bit == 7);
return ret;
}
@@ -118,7 +120,7 @@ static List<Byte> defaultInterleaveBits(Byte[] inputs) {
int idx = 0;
while (idx < inputs.length) {
int tmp = (((inputs[idx] >> bit) & 1) << ret_bit);
- ret_byte = (byte)(ret_byte | tmp);
+ ret_byte = (byte) (ret_byte | tmp);
ret_bit -= 1;
if (ret_bit == -1) {
// finished processing a byte
@@ -131,8 +133,8 @@ static List defaultInterleaveBits(Byte[] inputs) {
}
bit -= 1;
}
- assert(ret_idx == inputs.length);
- assert(ret_bit == 7);
+ assert (ret_idx == inputs.length);
+ assert (ret_bit == 7);
return ret;
}
@@ -172,9 +174,6 @@ static List<Byte>[] getExpected(int numRows, Byte[]... inputs) {
return ret;
}
- public static HostColumnVector.DataType outputType =
- new HostColumnVector.ListType(true, new HostColumnVector.BasicType(false, DType.UINT8));
-
public static void doIntTest(int numRows, Integer[]... inputs) {
List<Byte>[] expected = getExpected(numRows, inputs);
ColumnVector[] cvInputs = new ColumnVector[inputs.length];
@@ -187,7 +186,7 @@ public static void doIntTest(int numRows, Integer[]... inputs) {
assertColumnsAreEqual(expectedCv, results);
}
} finally {
- for (ColumnVector cv: cvInputs) {
+ for (ColumnVector cv : cvInputs) {
if (cv != null) {
cv.close();
}
@@ -207,7 +206,7 @@ public static void doShortTest(int numRows, Short[]... inputs) {
assertColumnsAreEqual(expectedCv, results);
}
} finally {
- for (ColumnVector cv: cvInputs) {
+ for (ColumnVector cv : cvInputs) {
if (cv != null) {
cv.close();
}
@@ -227,7 +226,7 @@ public static void doByteTest(int numRows, Byte[]... inputs) {
assertColumnsAreEqual(expectedCv, results);
}
} finally {
- for (ColumnVector cv: cvInputs) {
+ for (ColumnVector cv : cvInputs) {
if (cv != null) {
cv.close();
}
@@ -295,21 +294,21 @@ void testInt2NonNull() {
@Test
void testShort2NonNull() {
- Short[] inputs1 = {(short)0x0102, (short)0x0000, (short)0xFFFF, (short)0xFF00};
- Short[] inputs2 = {(short)0x1020, (short)0xFFFF, (short)0x0000, (short)0x00FF};
+ Short[] inputs1 = {(short) 0x0102, (short) 0x0000, (short) 0xFFFF, (short) 0xFF00};
+ Short[] inputs2 = {(short) 0x1020, (short) 0xFFFF, (short) 0x0000, (short) 0x00FF};
doShortTest(inputs1.length, inputs1, inputs2);
}
@Test
void testByte2NonNull() {
- Byte[] inputs1 = {(byte)0x01, (byte)0x00, (byte)0xFF, (byte)0x0F};
- Byte[] inputs2 = {(byte)0x10, (byte)0xFF, (byte)0x00, (byte)0xF0};
+ Byte[] inputs1 = {(byte) 0x01, (byte) 0x00, (byte) 0xFF, (byte) 0x0F};
+ Byte[] inputs2 = {(byte) 0x10, (byte) 0xFF, (byte) 0x00, (byte) 0xF0};
doByteTest(inputs1.length, inputs1, inputs2);
}
@Test
void testInt2Null() {
- Integer[] inputs1 = {0x00000000, null, 0xFFFFFFFF, 0xFF00FF00};
+ Integer[] inputs1 = {0x00000000, null, 0xFFFFFFFF, 0xFF00FF00};
Integer[] inputs2 = {0xFFFFFFFF, 0x00000000, 0x00FF00FF, null};
doIntTest(inputs1.length, inputs1, inputs2);
}
@@ -324,17 +323,17 @@ void testInt3NonNull() {
@Test
void testShort3NonNull() {
- Short[] inputs1 = {(short)0x0000, (short)0x4444, (short)0x1111};
- Short[] inputs2 = {(short)0x1111, (short)0x8888, (short)0x2222};
- Short[] inputs3 = {(short)0x2222, (short)0x0000, (short)0x4444};
+ Short[] inputs1 = {(short) 0x0000, (short) 0x4444, (short) 0x1111};
+ Short[] inputs2 = {(short) 0x1111, (short) 0x8888, (short) 0x2222};
+ Short[] inputs3 = {(short) 0x2222, (short) 0x0000, (short) 0x4444};
doShortTest(inputs1.length, inputs1, inputs2, inputs3);
}
@Test
void testByte3NonNull() {
- Byte[] inputs1 = {(byte)0x00, (byte)0x44, (byte)0x11};
- Byte[] inputs2 = {(byte)0x11, (byte)0x88, (byte)0x22};
- Byte[] inputs3 = {(byte)0x22, (byte)0x00, (byte)0x44};
+ Byte[] inputs1 = {(byte) 0x00, (byte) 0x44, (byte) 0x11};
+ Byte[] inputs2 = {(byte) 0x11, (byte) 0x88, (byte) 0x22};
+ Byte[] inputs3 = {(byte) 0x22, (byte) 0x00, (byte) 0x44};
doByteTest(inputs1.length, inputs1, inputs2, inputs3);
}
}
\ No newline at end of file
diff --git a/src/test/java/com/nvidia/spark/rapids/jni/LimitingOffHeapAllocForTests.java b/src/test/java/com/nvidia/spark/rapids/jni/LimitingOffHeapAllocForTests.java
index eb32667dc7..082773a18e 100644
--- a/src/test/java/com/nvidia/spark/rapids/jni/LimitingOffHeapAllocForTests.java
+++ b/src/test/java/com/nvidia/spark/rapids/jni/LimitingOffHeapAllocForTests.java
@@ -17,7 +17,6 @@
package com.nvidia.spark.rapids.jni;
import ai.rapids.cudf.HostMemoryBuffer;
-
import java.util.Optional;
/**
@@ -27,6 +26,7 @@
public class LimitingOffHeapAllocForTests {
private static long limit;
private static long amountAllocated = 0;
+
public static synchronized void setLimit(long limit) {
LimitingOffHeapAllocForTests.limit = limit;
if (amountAllocated > 0) {
@@ -68,6 +68,7 @@ private static Optional<HostMemoryBuffer> allocInternal(long amount, boolean blo
/**
* Do a non-blocking allocation
+ *
* @param amount the amount to allocate
* @return the allocated buffer or not.
*/
@@ -77,6 +78,7 @@ public static Optional<HostMemoryBuffer> tryAlloc(long amount) {
/**
* Do a blocking allocation
+ *
* @param amount the amount to allocate
* @return the allocated buffer
*/
diff --git a/src/test/java/com/nvidia/spark/rapids/jni/ParseURITest.java b/src/test/java/com/nvidia/spark/rapids/jni/ParseURITest.java
index 1ddf588b02..5dc2605921 100644
--- a/src/test/java/com/nvidia/spark/rapids/jni/ParseURITest.java
+++ b/src/test/java/com/nvidia/spark/rapids/jni/ParseURITest.java
@@ -16,18 +16,16 @@
package com.nvidia.spark.rapids.jni;
+import ai.rapids.cudf.AssertUtils;
+import ai.rapids.cudf.ColumnVector;
import java.net.URI;
import java.net.URISyntaxException;
-
import org.junit.jupiter.api.Test;
-import ai.rapids.cudf.AssertUtils;
-import ai.rapids.cudf.ColumnVector;
-
public class ParseURITest {
void testProtocol(String[] testData) {
String[] expectedProtocolStrings = new String[testData.length];
- for (int i=0; i\tnumber of tasks that can run in parallel on the GPU");
System.out.println("--seed=\tthe random seed to use for the test");
System.out.println("--gpuMiB=\tlimit on the GPUs memory to use for testing");
- System.out.println("--taskMaxMiB=\tmaximum amount of memory a regular task may have allocated");
- System.out.println("--allocMode=\tthe RMM allocation mode to use POOL, ASYNC, ARENA, CUDA");
- System.out.println("--taskRetry=\tmaximum number of times to retry a task before failing the situation");
+ System.out.println(
+ "--taskMaxMiB=\tmaximum amount of memory a regular task may have allocated");
+ System.out.println(
+ "--allocMode=\tthe RMM allocation mode to use POOL, ASYNC, ARENA, CUDA");
+ System.out.println(
+ "--taskRetry=\tmaximum number of times to retry a task before failing the situation");
System.out.println("--maxTaskAllocs=\tmaximum number of allocations a task can make");
- System.out.println("--maxTaskSleep=\tmaximum amount of time a task can sleep for (sim processing)");
+ System.out.println(
+ "--maxTaskSleep=\tmaximum amount of time a task can sleep for (sim processing)");
System.out.println("--noLog\tdisable logging");
System.out.println("--skewed\tgenerate templated tasks and skew one of them by skewAmount");
System.out.println("--skewAmount=\tthe amount to multiply the skewed allocations by");
- System.out.println("--useTemplate\tif all of the tasks should be the same, but change by +/- templateChangeAmount as a multiplier");
- System.out.println("--templateChangeAmount=\tA multiplication factor to change the template task by when making new tasks (as a multiplier)");
- System.out.println("--shuffleThreads=\tThe number of threads to use to simulate UCX shuffle");
+ System.out.println(
+ "--useTemplate\tif all of the tasks should be the same, but change by +/- templateChangeAmount as a multiplier");
+ System.out.println(
+ "--templateChangeAmount=\tA multiplication factor to change the template task by when making new tasks (as a multiplier)");
+ System.out.println(
+ "--shuffleThreads=\tThe number of threads to use to simulate UCX shuffle");
System.exit(0);
} else {
throw new IllegalArgumentException("Unexpected argument " + arg +
@@ -181,11 +196,11 @@ public static void main(String [] args) throws InterruptedException {
}
public static void setupRmm(int allocationMode, long limitMiB, boolean useSparkRmm,
- boolean enableLogging) {
+ boolean enableLogging) {
long limitBytes = limitMiB * 1024 * 1024;
Rmm.LogConf rmmLog = null;
if (enableLogging) {
- rmmLog = Rmm.logTo(new File("./monte.rmm.log"));
+ rmmLog = Rmm.logTo(new File("./monte.rmm.log"));
}
if (allocationMode == RmmAllocationMode.CUDA_DEFAULT) {
// We want to limit the total size, but Rmm will not do that by default...
@@ -209,7 +224,8 @@ public static void setupRmm(int allocationMode, long limitMiB, boolean useSparkR
} else {
Rmm.initialize(allocationMode, rmmLog, limitBytes);
}
- boolean needsSync = (allocationMode & RmmAllocationMode.CUDA_ASYNC) == RmmAllocationMode.CUDA_ASYNC;
+ boolean needsSync =
+ (allocationMode & RmmAllocationMode.CUDA_ASYNC) == RmmAllocationMode.CUDA_ASYNC;
if (useSparkRmm) {
if (enableLogging) {
RmmSpark.setEventHandler(new TestRmmEventHandler(needsSync), "./monte.state.log");
@@ -223,12 +239,76 @@ public static void setupRmm(int allocationMode, long limitMiB, boolean useSparkR
}
}
+ private static List<Situation> generateSituations(long seed, int numIterations, long numTasks,
+ long taskMaxMiB, int maxTaskAllocs,
+ int maxTaskSleep,
+ boolean isSkewed, double skewAmount,
+ boolean useTemplate,
+ double templateChangeAmount) {
+ ArrayList<Situation> ret = new ArrayList<>(numIterations);
+ long start = System.nanoTime();
+ System.out.println("Generating " + numIterations + " test situations...");
+
+ Random r = new Random(seed);
+ for (int i = 0; i < numIterations; i++) {
+ ret.add(new Situation(r, numTasks, taskMaxMiB, maxTaskAllocs, maxTaskSleep,
+ isSkewed, skewAmount, useTemplate, templateChangeAmount));
+ }
+
+ long end = System.nanoTime();
+ long diff = TimeUnit.MILLISECONDS.convert(end - start, TimeUnit.NANOSECONDS);
+ System.out.println("Took " + diff + " milliseconds to generate " + numIterations);
+ return ret;
+ }
+
+ interface MemoryOp {
+ default void doIt(DeviceMemoryBuffer[] buffers, long taskId) {
+ long threadId = RmmSpark.getCurrentThreadId();
+ RmmSpark.shuffleThreadWorkingOnTasks(new long[] {taskId});
+ RmmSpark.startRetryBlock(threadId);
+ try {
+ int tries = 0;
+ while (tries < 100 && tries >= 0) {
+ try {
+ if (tries > 0) {
+ RmmSpark.blockThreadUntilReady();
+ }
+ tries++;
+ doIt(buffers);
+ tries = -1;
+ } catch (GpuRetryOOM oom) {
+ // Don't need to clear the buffers, because there is only one buffer.
+ numRetry.incrementAndGet();
+ } catch (CpuRetryOOM oom) {
+ // Don't need to clear the buffers, because there is only one buffer.
+ numRetry.incrementAndGet();
+ }
+ }
+ if (tries >= 100) {
+ throw new OutOfMemoryError("Could not make shuffle work after " + tries + " tries");
+ }
+ } finally {
+ RmmSpark.endRetryBlock(threadId);
+ RmmSpark.poolThreadFinishedForTask(taskId);
+ }
+ }
+
+ void doIt(DeviceMemoryBuffer[] buffers);
+
+ MemoryOp[] split();
+
+ MemoryOp randomMod(Random r, double templateChangeAmount);
+
+ MemoryOp makeSkewed(double skewAmount);
+ }
+
private static class TestRmmEventHandler implements RmmEventHandler {
private final boolean needsSync;
public TestRmmEventHandler(boolean needsSync) {
this.needsSync = needsSync;
}
+
@Override
public long[] getAllocThresholds() {
return null;
@@ -270,7 +350,7 @@ public static class TaskRunnerThread extends Thread {
volatile boolean done = false;
public TaskRunnerThread(CyclicBarrier barrier, SituationRunner runner, int taskRetry,
- ExecutorService shuffle) {
+ ExecutorService shuffle) {
this.barrier = barrier;
this.runner = runner;
this.taskRetry = taskRetry;
@@ -374,6 +454,7 @@ public boolean hadOtherFailures() {
static class ShuffleThreadFactory implements ThreadFactory {
static final AtomicLong idGen = new AtomicLong(0);
+
@Override
public Thread newThread(Runnable runnable) {
long id = idGen.getAndIncrement();
@@ -385,12 +466,10 @@ public Thread newThread(Runnable runnable) {
}
public static class SituationRunner {
- final TaskRunnerThread[] threads;
public final boolean debugOoms;
- private ExecutorService shuffle;
+ final TaskRunnerThread[] threads;
final CyclicBarrier barrier;
volatile boolean sitIsDone = false;
-
// Stats
volatile int failedSits;
volatile int successSits;
@@ -401,6 +480,7 @@ public static class SituationRunner {
volatile long totalTimeLost;
volatile boolean sitFailed;
volatile boolean didThisSitFail = false;
+ private final ExecutorService shuffle;
public SituationRunner(int parallelism, int taskRetry, int shuffleThreads, boolean debugOoms) {
this.debugOoms = debugOoms;
@@ -433,6 +513,22 @@ public SituationRunner(int parallelism, int taskRetry, int shuffleThreads, boole
}
}
+ private static String asTimeStr(long timeNs) {
+ long justms = TimeUnit.NANOSECONDS.toMillis(timeNs);
+
+ long hours = TimeUnit.NANOSECONDS.toHours(timeNs);
+ long hoursInNanos = TimeUnit.HOURS.toNanos(hours);
+ timeNs = timeNs - hoursInNanos;
+ long mins = TimeUnit.NANOSECONDS.toMinutes(timeNs);
+ long minsInNanos = TimeUnit.MINUTES.toNanos(mins);
+ timeNs = timeNs - minsInNanos;
+ long secs = TimeUnit.NANOSECONDS.toSeconds(timeNs);
+ long secsInNanos = TimeUnit.SECONDS.toNanos(secs);
+ long ns = timeNs - secsInNanos;
+ return String.format("%1$02d:%2$02d:%3$02d.%4$09d", hours, mins, secs, ns) +
+ " or " + justms + " ms";
+ }
+
public int run(List<Situation> situations) throws InterruptedException {
numSplitAndRetry.set(0);
numRetry.set(0);
@@ -447,7 +543,7 @@ public int run(List<Situation> situations) throws InterruptedException {
}
int numSits = 0;
long totalSitTime = 0;
- for (Situation sit: situations) {
+ for (Situation sit : situations) {
numSits++;
long start = System.nanoTime();
for (TaskRunnerThread t : threads) {
@@ -497,22 +593,6 @@ public int run(List<Situation> situations) throws InterruptedException {
}
}
- private static String asTimeStr(long timeNs) {
- long justms = TimeUnit.NANOSECONDS.toMillis(timeNs);
-
- long hours = TimeUnit.NANOSECONDS.toHours(timeNs);
- long hoursInNanos = TimeUnit.HOURS.toNanos(hours);
- timeNs = timeNs - hoursInNanos;
- long mins = TimeUnit.NANOSECONDS.toMinutes(timeNs);
- long minsInNanos = TimeUnit.MINUTES.toNanos(mins);
- timeNs = timeNs - minsInNanos;
- long secs = TimeUnit.NANOSECONDS.toSeconds(timeNs);
- long secsInNanos = TimeUnit.SECONDS.toNanos(secs);
- long ns = timeNs - secsInNanos;
- return String.format("%1$02d:%2$02d:%3$02d.%4$09d", hours, mins, secs, ns) +
- " or " + justms + " ms";
- }
-
public void finish() {
for (TaskRunnerThread t : threads) {
t.finish();
@@ -535,49 +615,8 @@ public synchronized void setSitFailed() {
}
}
- interface MemoryOp {
- default void doIt(DeviceMemoryBuffer[] buffers, long taskId) {
- long threadId = RmmSpark.getCurrentThreadId();
- RmmSpark.shuffleThreadWorkingOnTasks(new long[]{taskId});
- RmmSpark.startRetryBlock(threadId);
- try {
- int tries = 0;
- while (tries < 100 && tries >= 0) {
- try {
- if (tries > 0) {
- RmmSpark.blockThreadUntilReady();
- }
- tries++;
- doIt(buffers);
- tries = -1;
- } catch (GpuRetryOOM oom) {
- // Don't need to clear the buffers, because there is only one buffer.
- numRetry.incrementAndGet();
- } catch (CpuRetryOOM oom) {
- // Don't need to clear the buffers, because there is only one buffer.
- numRetry.incrementAndGet();
- }
- }
- if (tries >= 100) {
- throw new OutOfMemoryError("Could not make shuffle work after " + tries + " tries");
- }
- } finally {
- RmmSpark.endRetryBlock(threadId);
- RmmSpark.poolThreadFinishedForTask(taskId);
- }
- }
-
- void doIt(DeviceMemoryBuffer[] buffers);
-
- MemoryOp[] split();
-
- MemoryOp randomMod(Random r, double templateChangeAmount);
-
- MemoryOp makeSkewed(double skewAmount);
- }
-
public static class AllocOp implements MemoryOp {
- private static AtomicLong idgen = new AtomicLong(0);
+ private static final AtomicLong idgen = new AtomicLong(0);
public final int offset;
private final long size;
@@ -598,6 +637,7 @@ private AllocOp(int offset, long size, long sleepTime, long id) {
this.sleepTime = sleepTime;
this.id = id;
}
+
@Override
public String toString() {
return "ALLOC[" + offset + "] " + size + " SLEEP " + sleepTime;
@@ -636,7 +676,7 @@ public MemoryOp[] split() {
@Override
public MemoryOp randomMod(Random r, double templateChangeAmount) {
double proposedSizeMult = (1.0 - (r.nextDouble() * 2.0)) * templateChangeAmount;
- long newSize = (long)(proposedSizeMult * size);
+ long newSize = (long) (proposedSizeMult * size);
if (newSize <= 0) {
newSize = 1;
}
@@ -645,7 +685,7 @@ public MemoryOp randomMod(Random r, double templateChangeAmount) {
@Override
public MemoryOp makeSkewed(double skewAmount) {
- return new AllocOp(offset, (long)(size * skewAmount), sleepTime);
+ return new AllocOp(offset, (long) (size * skewAmount), sleepTime);
}
}
@@ -662,7 +702,7 @@ public String toString() {
}
@Override
- public void doIt(DeviceMemoryBuffer[] buffers) {
+ public void doIt(DeviceMemoryBuffer[] buffers) {
DeviceMemoryBuffer buf = buffers[offset];
if (buf != null) {
buf.close();
@@ -692,9 +732,9 @@ public MemoryOp makeSkewed(double skewAmount) {
}
public static class TaskOpSet {
+ final int numBuffers;
DeviceMemoryBuffer[] buffers;
ArrayList<MemoryOp> operations;
- final int numBuffers;
long allocatedBeforeError = 0;
private TaskOpSet(int numBuffers, ArrayList<MemoryOp> operations) {
@@ -703,7 +743,7 @@ private TaskOpSet(int numBuffers, ArrayList<MemoryOp> operations) {
}
public TaskOpSet(Random r, long taskMaxMiB,
- int maxTaskAllocs, int maxTaskSleep) {
+ int maxTaskAllocs, int maxTaskSleep) {
long maxSleepTimeNano = TimeUnit.MILLISECONDS.toNanos(maxTaskSleep);
long totalSleepTimeNano = 0;
if (maxSleepTimeNano > 0) {
@@ -733,7 +773,8 @@ public TaskOpSet(Random r, long taskMaxMiB,
// We want the sleeps to be very small because we are not simulating
// the time, and generally they will be. In the future we can make this
// configurable.
- long sleepTime = (long)(totalSleepTimeNano * sleepWeights[allocOpNum]/totalSleepWeight);
+ long sleepTime =
+ (long) (totalSleepTimeNano * sleepWeights[allocOpNum] / totalSleepWeight);
AllocOp ao = new AllocOp(allocOpNum, size, sleepTime);
operations.add(ao);
outstandingAllocOps.add(ao);
@@ -777,9 +818,9 @@ public void run(ExecutorService shuffle, long taskId) {
allocatedBeforeError = 0;
boolean isForShuffle = shuffle != null;
boolean done = false;
- while(!done) {
+ while (!done) {
try {
- for (MemoryOp op: operations) {
+ for (MemoryOp op : operations) {
if (isForShuffle) {
try {
RmmSpark.submittingToPool();
@@ -827,7 +868,7 @@ public long getAllocatedBeforeError() {
public TaskOpSet randomMod(Random r, double templateChangeAmount) {
ArrayList<MemoryOp> newOps = new ArrayList<>(operations.size());
- for (MemoryOp op: operations) {
+ for (MemoryOp op : operations) {
newOps.add(op.randomMod(r, templateChangeAmount));
}
return new TaskOpSet(numBuffers, newOps);
@@ -835,7 +876,7 @@ public TaskOpSet randomMod(Random r, double templateChangeAmount) {
public TaskOpSet makeSkewed(double skewAmount) {
ArrayList<MemoryOp> newOps = new ArrayList<>(operations.size());
- for (MemoryOp op: operations) {
+ for (MemoryOp op : operations) {
newOps.add(op.makeSkewed(skewAmount));
}
return new TaskOpSet(numBuffers, newOps);
@@ -904,7 +945,7 @@ public String toString() {
public Task randomMod(Random r, double templateChangeAmount) {
LinkedList<TaskOpSet> changed = new LinkedList<>();
- for (TaskOpSet orig: toDo) {
+ for (TaskOpSet orig : toDo) {
changed.add(orig.randomMod(r, templateChangeAmount));
}
return new Task(changed, 0);
@@ -912,7 +953,7 @@ public Task randomMod(Random r, double templateChangeAmount) {
public Task makeSkewed(double skewAmount) {
LinkedList<TaskOpSet> changed = new LinkedList<>();
- for (TaskOpSet orig: toDo) {
+ for (TaskOpSet orig : toDo) {
changed.add(orig.makeSkewed(skewAmount));
}
return new Task(changed, 0);
@@ -923,8 +964,9 @@ public static class Situation {
LinkedList<Task> tasks = new LinkedList<>();
public Situation(Random r, long numTasks, long taskMaxMiB,
- int maxTaskAllocs, int maxTaskSleep,
- boolean isSkewed, double skewAmount, boolean useTemplate, double templateChangeAmount) {
+ int maxTaskAllocs, int maxTaskSleep,
+ boolean isSkewed, double skewAmount, boolean useTemplate,
+ double templateChangeAmount) {
if (useTemplate) {
Task template = new Task(r, taskMaxMiB, maxTaskAllocs, maxTaskSleep);
tasks.add(template);
@@ -957,23 +999,4 @@ public String toString() {
return "Sit: " + tasks.size();
}
}
-
- private static List<Situation> generateSituations(long seed, int numIterations, long numTasks,
- long taskMaxMiB, int maxTaskAllocs, int maxTaskSleep,
- boolean isSkewed, double skewAmount, boolean useTemplate, double templateChangeAmount) {
- ArrayList<Situation> ret = new ArrayList<>(numIterations);
- long start = System.nanoTime();
- System.out.println("Generating " + numIterations + " test situations...");
-
- Random r = new Random(seed);
- for (int i = 0; i < numIterations; i++) {
- ret.add(new Situation(r, numTasks, taskMaxMiB, maxTaskAllocs, maxTaskSleep,
- isSkewed, skewAmount, useTemplate, templateChangeAmount));
- }
-
- long end = System.nanoTime();
- long diff = TimeUnit.MILLISECONDS.convert(end - start, TimeUnit.NANOSECONDS);
- System.out.println("Took " + diff + " milliseconds to generate " + numIterations);
- return ret;
- }
}
diff --git a/src/test/java/com/nvidia/spark/rapids/jni/RmmSparkTest.java b/src/test/java/com/nvidia/spark/rapids/jni/RmmSparkTest.java
index 270a4266cd..abc2e520dd 100644
--- a/src/test/java/com/nvidia/spark/rapids/jni/RmmSparkTest.java
+++ b/src/test/java/com/nvidia/spark/rapids/jni/RmmSparkTest.java
@@ -16,6 +16,10 @@
package com.nvidia.spark.rapids.jni;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.junit.jupiter.api.Assertions.fail;
+
import ai.rapids.cudf.CudfException;
import ai.rapids.cudf.DeviceMemoryBuffer;
import ai.rapids.cudf.HostMemoryBuffer;
@@ -27,20 +31,14 @@
import ai.rapids.cudf.RmmEventHandler;
import ai.rapids.cudf.RmmLimitingResourceAdaptor;
import ai.rapids.cudf.RmmTrackingResourceAdaptor;
-
-import org.junit.jupiter.api.AfterEach;
-import org.junit.jupiter.api.BeforeEach;
-import org.junit.jupiter.api.Test;
-
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
-
-import static org.junit.jupiter.api.Assertions.assertEquals;
-import static org.junit.jupiter.api.Assertions.assertThrows;
-import static org.junit.jupiter.api.Assertions.fail;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
public class RmmSparkTest {
private final static long ALIGNMENT = 256;
@@ -59,249 +57,6 @@ public void teardown() {
}
}
- public interface TaskThreadOp<T> {
- T doIt();
- }
-
- public static class TaskThread extends Thread {
- private final String name;
- private final boolean isForPool;
- private long threadId = -1;
- private long taskId = 100;
-
- public TaskThread(String name, long taskId) {
- this(name, false);
- this.taskId = taskId;
- }
-
- public TaskThread(String name, boolean isForPool) {
- super(name);
- this.name = name;
- this.isForPool = isForPool;
- }
-
- public synchronized long getThreadId() {
- return threadId;
- }
-
- private LinkedBlockingQueue queue = new LinkedBlockingQueue<>();
-
- public void initialize() throws ExecutionException, InterruptedException, TimeoutException {
- setDaemon(true);
- start();
- Future<Void> waitForStart = doIt(new TaskThreadOp<Void>() {
- @Override
- public Void doIt() {
- if (!isForPool) {
- RmmSpark.currentThreadIsDedicatedToTask(taskId);
- }
- return null;
- }
-
- @Override
- public String toString() {
- return "INIT TASK " + name + " " + (isForPool ? "POOL" : ("TASK " + taskId));
- }
- });
- System.err.println("WAITING FOR STARTUP (" + name + ")");
- waitForStart.get(1000, TimeUnit.MILLISECONDS);
- System.err.println("THREAD IS READY TO GO (" + name + ")");
- }
-
- public void pollForState(RmmSparkThreadState state, long l, TimeUnit tu) throws TimeoutException, InterruptedException {
- long start = System.nanoTime();
- long timeoutAfter = start + tu.toNanos(l);
- RmmSparkThreadState currentState = null;
- while (System.nanoTime() <= timeoutAfter) {
- currentState = RmmSpark.getStateOf(threadId);
- if (currentState == state) {
- return;
- }
- // Yes we are essentially doing a busy wait...
- Thread.sleep(10);
- }
- throw new TimeoutException(name + " WAITING FOR STATE " + state + " BUT STATE IS " + currentState);
- }
-
- private static class TaskThreadDoneOp implements TaskThreadOp, Future