diff --git a/.github/workflows/license-header-check.yml b/.github/workflows/license-header-check.yml new file mode 100644 index 0000000000..db783e66af --- /dev/null +++ b/.github/workflows/license-header-check.yml @@ -0,0 +1,55 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# A workflow to check copyright/license header +name: license header check + +on: + pull_request: + types: [opened, synchronize, reopened] + +jobs: + license-header-check: + runs-on: ubuntu-latest + if: "!contains(github.event.pull_request.title, '[bot]')" + steps: + - name: Get checkout depth + run: | + echo "PR_FETCH_DEPTH=$(( ${{ github.event.pull_request.commits }} + 10 ))" >> $GITHUB_ENV + + - name: Checkout code + uses: actions/checkout@v4 + with: + fetch-depth: ${{ env.PR_FETCH_DEPTH }} + + - name: license-header-check + uses: NVIDIA/spark-rapids-common/license-header-check@main + with: + included_file_patterns: | + *.cpp, + *.hpp, + *.cu, + *.cuh, + *.java, + *.sh, + *Dockerfile*, + *Jenkinsfile*, + *.yml, + *.yaml, + *.txt, + *.xml, + *.fbs, + build/* + excluded_file_patterns: | + thirdparty/* diff --git a/.gitmodules b/.gitmodules index e2001c2c84..93d7670bbf 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,4 +1,4 @@ [submodule "thirdparty/cudf"] path = thirdparty/cudf url = https://github.com/rapidsai/cudf.git - branch = branch-24.10 + branch = branch-24.12 diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 271e62feb1..bae978da31 100644 
--- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -165,7 +165,7 @@ $ ./build/build-in-docker install ... ``` Now cd to ~/repos/NVIDIA/spark-rapids and build with one of the options from -[spark-rapids instructions](https://github.com/NVIDIA/spark-rapids/blob/branch-24.10/CONTRIBUTING.md#building-from-source). +[spark-rapids instructions](https://github.com/NVIDIA/spark-rapids/blob/branch-24.12/CONTRIBUTING.md#building-from-source). ```bash $ ./build/buildall diff --git a/build/apply-patches b/build/apply-patches index 991613e6dc..31c2adcfdd 100755 --- a/build/apply-patches +++ b/build/apply-patches @@ -16,8 +16,6 @@ # limitations under the License. # -# Run a command in a Docker container with devtoolset - set -e BASE_DIR=$( git rev-parse --show-toplevel ) @@ -26,14 +24,51 @@ PATCH_DIR=${PATCH_DIR:-$(realpath "$BASE_DIR/patches/")} CUDF_DIR=${CUDF_DIR:-$(realpath "$BASE_DIR/thirdparty/cudf/")} +# Apply pattches to CUDF is problematic in a number of ways. But ultimately it comes down to +# making sure that a user can do development work in spark-rapids-jni without the patches +# getting in the way +# The operations I really want to support no matter what state CUDF is in are +# 1) Build the repo from scratch +# 2) Rebuild the repo without having to clean and start over +# 3) upmerge to a new version of the plugin including updating the cudf submodule +# +# Building from scratch is simple. We want clean to unapply any patches and +# build to apply them. But if we want to rebuild without a clean we need to know what +# state the CUDF repo is in. Did we apply patches to it or not. The fastest way to do this +# is to save some state files about what happened. But a user could mess with CUDF directly +# so we want to have ways to double check that they are indeed correct. + +FULLY_PATCHED_FILE="$CUDF_DIR/spark-rapids-jni.patch" + pushd "$CUDF_DIR" -if [ -n "$(git status --porcelain --untracked-files=no)" ] ; then - echo "Error: CUDF repository has uncommitted changes. 
No patches will be applied..." - exit 1 + +PATCH_FILES=$(find "$PATCH_DIR" -type f -not -empty) + +if [ -z "$PATCH_FILES" ] ; then + echo "No patches to apply" + exit 0 +fi + +CHANGED_FILES=$(git status --porcelain --untracked-files=no) + +if [ \( -s "$FULLY_PATCHED_FILE" \) -a \( -n "$CHANGED_FILES" \) ] ; then + if git apply -R --check "$FULLY_PATCHED_FILE" ; then + echo "Patches appear to have been applied already" + exit 0 + fi +fi + +if [ -n "$CHANGED_FILES" ] ; then + echo "Error: CUDF repository has uncommitted changes. No patches will be applied. Please clean the repository so we can try and add the needed patches" + echo "$CHANGED_FILE" + exit 1 fi find "$PATCH_DIR" -maxdepth 1 -type f -print0 | sort -zV | while IFS= read -r -d '' file; do - echo "patching with: $file" - patch --no-backup-if-mismatch -f -t --reject-file=- -p1 -i "$file" + echo "patching with: $file" + git apply -v "$file" done + +git diff > "$FULLY_PATCHED_FILE" + popd diff --git a/build/unapply-patches b/build/unapply-patches index 186a781ade..a31708e25f 100755 --- a/build/unapply-patches +++ b/build/unapply-patches @@ -16,29 +16,67 @@ # limitations under the License. # -# Run a command in a Docker container with devtoolset - set -e -SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +BASE_DIR=$( git rev-parse --show-toplevel ) + +PATCH_DIR=${PATCH_DIR:-$(realpath "$BASE_DIR/patches/")} -PATCH_DIR=${PATCH_DIR:-$(realpath "$SCRIPT_DIR/../patches/")} +CUDF_DIR=${CUDF_DIR:-$(realpath "$BASE_DIR/thirdparty/cudf/")} -CUDF_DIR=${CUDF_DIR:-$(realpath "$SCRIPT_DIR/../thirdparty/cudf/")} +# Apply pattches to CUDF is problematic in a number of ways. 
But ultimately it comes down to +# making sure that a user can do development work in spark-rapids-jni without the patches +# getting in the way +# The operations I really want to support no matter what state CUDF is in are +# 1) Build the repo from scratch +# 2) Rebuild the repo without having to clean and start over +# 3) upmerge to a new version of the plugin including updating the cudf submodule +# +# Building from scratch is simple. We want clean to unapply any patches and +# build to apply them. But if we want to rebuild without a clean we need to know what +# state the CUDF repo is in. Did we apply patches to it or not. The fastest way to do this +# is to save some state files about what happened. But a user could mess with CUDF directly +# so we want to have ways to double check that they are indeed correct. +FULLY_PATCHED_FILE="$CUDF_DIR/spark-rapids-jni.patch" pushd "$CUDF_DIR" -if [ -n "$(git status --porcelain --untracked-files=no)" ] ; then - #only try to remove patches if it looks like something was changed - find "$PATCH_DIR" -maxdepth 1 -type f -print0 | sort -zV -r | while IFS= read -r -d '' file; do - echo "patching with: $file" - patch -R --no-backup-if-mismatch --reject-file=- -f -t -p1 -i "$file" - done + +PATCH_FILES=$(find "$PATCH_DIR" -type f -not -empty) + +if [ -z "$PATCH_FILES" ] ; then + echo "No patches to remove" + exit 0 fi -# Check for modifications -if [ -n "$(git status --porcelain --untracked-files=no)" ] ; then - echo "Error: CUDF repository has uncommitted changes. 
You might want to clean in manually if you know that is expected" +CHANGED_FILES=$(git status --porcelain --untracked-files=no) + +if [ \( -s "$FULLY_PATCHED_FILE" \) -a \( -n "$CHANGED_FILES" \) ] ; then + if git apply --check -R "$FULLY_PATCHED_FILE"; then + echo "Patches appear to have been applied, so going to remove them" + git apply -R -v "$FULLY_PATCHED_FILE" + rm -f "$FULLY_PATCHED_FILE" + + # Check for modifications, again + if [ -n "$(git status --porcelain --untracked-files=no)" ] ; then + echo "Error: CUDF repository has uncommitted changes. You might want to clean in manually if you know that is expected" + git status --porcelain --untracked-files=no + exit 1 + fi + + exit 0 + else + echo "Files are changed, but in a way where the full path file does not apply to remove them $FULL_PATCHED_FILE" exit 1 + fi fi + +if [ -n "$CHANGED_FILES" ] ; then + echo "Error: CUDF repository has uncommitted changes, but does not appear to have been patched. Please clean it and try again." + echo "$CHANGED_FILE" + exit 1 +else + echo "No changes in CUDF repository to remove" +fi + popd diff --git a/ci/Dockerfile b/ci/Dockerfile index b3f4239dc6..4b2ff969b2 100644 --- a/ci/Dockerfile +++ b/ci/Dockerfile @@ -36,7 +36,7 @@ RUN dnf --enablerepo=powertools install -y scl-utils gcc-toolset-${TOOLSET_VERSI RUN mkdir -m 777 /usr/local/rapids /rapids # 3.22.3: CUDA architecture 'native' support + flexible CMAKE__*_LAUNCHER for ccache -ARG CMAKE_VERSION=3.26.4 +ARG CMAKE_VERSION=3.28.6 # default x86_64 from x86 build, aarch64 cmake for arm build ARG CMAKE_ARCH=x86_64 RUN cd /usr/local && wget --quiet https://github.com/Kitware/CMake/releases/download/v${CMAKE_VERSION}/cmake-${CMAKE_VERSION}-linux-${CMAKE_ARCH}.tar.gz && \ diff --git a/ci/submodule-sync.sh b/ci/submodule-sync.sh index 29b0cf5dad..8d4089a6c7 100755 --- a/ci/submodule-sync.sh +++ b/ci/submodule-sync.sh @@ -57,26 +57,29 @@ if [ -n "$CUDF_TAG" ]; then else git submodule update --remote --merge fi + 
+cudf_pins_only=false cudf_sha=$(git -C thirdparty/cudf rev-parse HEAD) if [[ "${cudf_sha}" == "${cudf_prev_sha}" ]]; then - echo "Submodule is up to date." - exit 0 + echo "cuDF submodule is up to date. Try update cudf-pins..." + cudf_pins_only=true +else + echo "Try update cudf submodule to ${cudf_sha}..." + git add . + git commit -s -m "Update submodule cudf to ${cudf_sha}" fi -echo "Try update cudf submodule to ${cudf_sha}..." -git add . - -echo "Test against ${cudf_sha}..." - +echo "Build libcudf only to update pinned versions..." MVN="mvn -Dmaven.wagon.http.retryHandler.count=3 -B" set +e # Don't do a full build. Just try to update/build CUDF with no patches on top of it. -${MVN} validate ${MVN_MIRROR} \ +# calling the antrun directly skips applying patches and also only builds +# libcudf +${MVN} antrun:run@build-libcudf ${MVN_MIRROR} \ -DCPP_PARALLEL_LEVEL=${PARALLEL_LEVEL} \ -Dlibcudf.build.configure=true \ -Dlibcudf.dependency.mode=latest \ - -Dsubmodule.patch.skip \ - -DUSE_GDS=ON -Dtest=*,!CuFileTest,!CudaFatalTest,!ColumnViewNonEmptyNullsTest \ + -DUSE_GDS=ON \ -DBUILD_TESTS=ON \ -DUSE_SANITIZER=ON validate_status=$? @@ -88,21 +91,25 @@ rapids_cmake_sha=$(git -C ${LIBCUDF_BUILD_PATH}/_deps/rapids-cmake-src/ rev-pars echo "Update rapids-cmake pinned SHA1 to ${rapids_cmake_sha}" echo "${rapids_cmake_sha}" > thirdparty/cudf-pins/rapids-cmake.sha -# Bash the wrong nvcomp version to the correct version until -# nvcomp version mismatch is fixed. https://github.com/rapidsai/cudf/issues/16772. 
-echo "Revert nvcomp to 3.0.6" -sed -i -e 's/4\.0\.1\.0/3.0.6/' \ - -e 's|https://developer.download.nvidia.com/compute/nvcomp/${version}/local_installers/nvcomp-linux-sbsa-${version}-cuda${cuda-toolkit-version-mapping}.tar.gz|https://developer.download.nvidia.com/compute/nvcomp/${version}/local_installers/nvcomp_${version}_SBSA_${cuda-toolkit-version-mapping}.tgz|' \ - -e 's|https://developer.download.nvidia.com/compute/nvcomp/${version}/local_installers/nvcomp-linux-x86_64-${version}-cuda${cuda-toolkit-version-mapping}.tar.gz|https://developer.download.nvidia.com/compute/nvcomp/${version}/local_installers/nvcomp_${version}_x86_64_${cuda-toolkit-version-mapping}.tgz|' \ - thirdparty/cudf-pins/versions.json +echo "Workaround for https://github.com/NVIDIA/spark-rapids-jni/issues/2582" +cudf_patch_path="cudf/cpp/cmake/thirdparty/patches" +sed -i "s|\${current_json_dir}|\${current_json_dir}/../${cudf_patch_path}|g" thirdparty/cudf-pins/versions.json # Do the git add after the build so that we get # the updated versions.json generated by the build echo "Update cudf submodule to ${cudf_sha} with updated pinned versions" git add . -git diff-index --quiet HEAD || git commit -s -m "Update submodule cudf to ${cudf_sha}" +if ! git diff-index --quiet HEAD; then + # We perform a squash merge for submodule-sync commits + git commit -s -m "Update pinned versions for cudf ${cudf_sha}" +elif ${cudf_pins_only}; then + echo "No changes to commit. Exit early..." + exit 0 +fi + sha=$(git rev-parse HEAD) +echo "Test against ${cudf_sha}..." 
set +e # now build and test everything with the patches in place ${MVN} clean verify ${MVN_MIRROR} \ diff --git a/patches/noop.patch b/patches/noop.patch new file mode 100644 index 0000000000..e69de29bb2 diff --git a/patches/revert_nvcomp4.patch b/patches/revert_nvcomp4.patch deleted file mode 100644 index 88b58b14dc..0000000000 --- a/patches/revert_nvcomp4.patch +++ /dev/null @@ -1,907 +0,0 @@ -diff --git a/ci/build_wheel_cudf.sh b/ci/build_wheel_cudf.sh -index fb93b06dbe..e5565c4b53 100755 ---- a/ci/build_wheel_cudf.sh -+++ b/ci/build_wheel_cudf.sh -@@ -23,6 +23,8 @@ export PIP_CONSTRAINT="/tmp/constraints.txt" - python -m auditwheel repair \ - --exclude libcudf.so \ - --exclude libnvcomp.so \ -+ --exclude libnvcomp_bitcomp.so \ -+ --exclude libnvcomp_gdeflate.so \ - -w ${package_dir}/final_dist \ - ${package_dir}/dist/* - -diff --git a/ci/build_wheel_pylibcudf.sh b/ci/build_wheel_pylibcudf.sh -index 5e9f7f8a0c..0e4745bda2 100755 ---- a/ci/build_wheel_pylibcudf.sh -+++ b/ci/build_wheel_pylibcudf.sh -@@ -21,6 +21,8 @@ export PIP_CONSTRAINT="/tmp/constraints.txt" - python -m auditwheel repair \ - --exclude libcudf.so \ - --exclude libnvcomp.so \ -+ --exclude libnvcomp_bitcomp.so \ -+ --exclude libnvcomp_gdeflate.so \ - -w ${package_dir}/final_dist \ - ${package_dir}/dist/* - -diff --git a/conda/environments/all_cuda-118_arch-x86_64.yaml b/conda/environments/all_cuda-118_arch-x86_64.yaml -index 5a05dfd053..e7363645d6 100644 ---- a/conda/environments/all_cuda-118_arch-x86_64.yaml -+++ b/conda/environments/all_cuda-118_arch-x86_64.yaml -@@ -58,7 +58,7 @@ dependencies: - - numpy>=1.23,<3.0a0 - - numpydoc - - nvcc_linux-64=11.8 --- nvcomp==4.0.1 -+- nvcomp==3.0.6 - - nvtx>=0.2.1 - - openpyxl - - packaging -diff --git a/conda/environments/all_cuda-125_arch-x86_64.yaml b/conda/environments/all_cuda-125_arch-x86_64.yaml -index 8490296233..3559a1a341 100644 ---- a/conda/environments/all_cuda-125_arch-x86_64.yaml -+++ b/conda/environments/all_cuda-125_arch-x86_64.yaml -@@ 
-56,7 +56,7 @@ dependencies: - - numba>=0.57 - - numpy>=1.23,<3.0a0 - - numpydoc --- nvcomp==4.0.1 -+- nvcomp==3.0.6 - - nvtx>=0.2.1 - - openpyxl - - packaging -diff --git a/conda/recipes/libcudf/conda_build_config.yaml b/conda/recipes/libcudf/conda_build_config.yaml -index dc75eb4b25..67d501d746 100644 ---- a/conda/recipes/libcudf/conda_build_config.yaml -+++ b/conda/recipes/libcudf/conda_build_config.yaml -@@ -35,7 +35,7 @@ spdlog_version: - - ">=1.14.1,<1.15" - - nvcomp_version: -- - "=4.0.1" -+ - "=3.0.6" - - zlib_version: - - ">=1.2.13" -diff --git a/cpp/include/cudf/io/nvcomp_adapter.hpp b/cpp/include/cudf/io/nvcomp_adapter.hpp -index 0d74a4158a..f3260d0cb5 100644 ---- a/cpp/include/cudf/io/nvcomp_adapter.hpp -+++ b/cpp/include/cudf/io/nvcomp_adapter.hpp -@@ -24,7 +24,7 @@ - namespace CUDF_EXPORT cudf { - namespace io::nvcomp { - --enum class compression_type { SNAPPY, ZSTD, DEFLATE, LZ4, GZIP }; -+enum class compression_type { SNAPPY, ZSTD, DEFLATE, LZ4 }; - - /** - * @brief Set of parameters that impact whether nvCOMP features are enabled. 
-@@ -36,20 +36,33 @@ struct feature_status_parameters { - int lib_patch_version; ///< patch version - bool are_all_integrations_enabled; ///< all integrations - bool are_stable_integrations_enabled; ///< stable integrations -+ int compute_capability_major; ///< cuda compute major version - - /** -- * @brief Default constructor using the current version of nvcomp and current environment -- * variables -+ * @brief Default Constructor - */ - feature_status_parameters(); - - /** -- * @brief Constructor using the current version of nvcomp -+ * @brief feature_status_parameters Constructor - * -+ * @param major positive integer representing major value of nvcomp -+ * @param minor positive integer representing minor value of nvcomp -+ * @param patch positive integer representing patch value of nvcomp - * @param all_enabled if all integrations are enabled - * @param stable_enabled if stable integrations are enabled -+ * @param cc_major CUDA compute capability - */ -- feature_status_parameters(bool all_enabled, bool stable_enabled); -+ feature_status_parameters( -+ int major, int minor, int patch, bool all_enabled, bool stable_enabled, int cc_major) -+ : lib_major_version{major}, -+ lib_minor_version{minor}, -+ lib_patch_version{patch}, -+ are_all_integrations_enabled{all_enabled}, -+ are_stable_integrations_enabled{stable_enabled}, -+ compute_capability_major{cc_major} -+ { -+ } - }; - - /** -@@ -61,7 +74,8 @@ inline bool operator==(feature_status_parameters const& lhs, feature_status_para - lhs.lib_minor_version == rhs.lib_minor_version and - lhs.lib_patch_version == rhs.lib_patch_version and - lhs.are_all_integrations_enabled == rhs.are_all_integrations_enabled and -- lhs.are_stable_integrations_enabled == rhs.are_stable_integrations_enabled; -+ lhs.are_stable_integrations_enabled == rhs.are_stable_integrations_enabled and -+ lhs.compute_capability_major == rhs.compute_capability_major; - } - - /** -diff --git a/cpp/src/io/comp/nvcomp_adapter.cpp 
b/cpp/src/io/comp/nvcomp_adapter.cpp -index c3187f73a9..3191e8f015 100644 ---- a/cpp/src/io/comp/nvcomp_adapter.cpp -+++ b/cpp/src/io/comp/nvcomp_adapter.cpp -@@ -22,46 +22,94 @@ - #include - #include - --#include --#include - #include - #include --#include - - #include - -+#define NVCOMP_DEFLATE_HEADER -+#if __has_include(NVCOMP_DEFLATE_HEADER) -+#include NVCOMP_DEFLATE_HEADER -+#endif -+ -+#define NVCOMP_ZSTD_HEADER -+#if __has_include(NVCOMP_ZSTD_HEADER) -+#include NVCOMP_ZSTD_HEADER -+#endif -+ -+// When building with nvcomp 4.0 or newer, map the new version macros to the old ones -+#ifndef NVCOMP_MAJOR_VERSION -+#define NVCOMP_MAJOR_VERSION NVCOMP_VER_MAJOR -+#define NVCOMP_MINOR_VERSION NVCOMP_VER_MINOR -+#define NVCOMP_PATCH_VERSION NVCOMP_VER_PATCH -+#endif -+ -+#define NVCOMP_HAS_ZSTD_DECOMP(MAJOR, MINOR, PATCH) (MAJOR > 2 or (MAJOR == 2 and MINOR >= 3)) -+ -+#define NVCOMP_HAS_ZSTD_COMP(MAJOR, MINOR, PATCH) (MAJOR > 2 or (MAJOR == 2 and MINOR >= 4)) -+ -+#define NVCOMP_HAS_DEFLATE(MAJOR, MINOR, PATCH) (MAJOR > 2 or (MAJOR == 2 and MINOR >= 5)) -+ -+#define NVCOMP_HAS_DECOMP_TEMPSIZE_EX(MAJOR, MINOR, PATCH) \ -+ (MAJOR > 2 or (MAJOR == 2 and MINOR > 3) or (MAJOR == 2 and MINOR == 3 and PATCH >= 1)) -+ -+#define NVCOMP_HAS_COMP_TEMPSIZE_EX(MAJOR, MINOR, PATCH) (MAJOR > 2 or (MAJOR == 2 and MINOR >= 6)) -+ -+// ZSTD is stable for nvcomp 2.3.2 or newer -+#define NVCOMP_ZSTD_DECOMP_IS_STABLE(MAJOR, MINOR, PATCH) \ -+ (MAJOR > 2 or (MAJOR == 2 and MINOR > 3) or (MAJOR == 2 and MINOR == 3 and PATCH >= 2)) -+ - namespace cudf::io::nvcomp { - - // Dispatcher for nvcompBatchedDecompressGetTempSizeEx - template --auto batched_decompress_get_temp_size_ex(compression_type compression, Args&&... args) -+std::optional batched_decompress_get_temp_size_ex(compression_type compression, -+ Args&&... 
args) - { -+#if NVCOMP_HAS_DECOMP_TEMPSIZE_EX(NVCOMP_MAJOR_VERSION, NVCOMP_MINOR_VERSION, NVCOMP_PATCH_VERSION) - switch (compression) { - case compression_type::SNAPPY: - return nvcompBatchedSnappyDecompressGetTempSizeEx(std::forward(args)...); - case compression_type::ZSTD: -+#if NVCOMP_HAS_ZSTD_DECOMP(NVCOMP_MAJOR_VERSION, NVCOMP_MINOR_VERSION, NVCOMP_PATCH_VERSION) - return nvcompBatchedZstdDecompressGetTempSizeEx(std::forward(args)...); -+#else -+ return std::nullopt; -+#endif - case compression_type::LZ4: - return nvcompBatchedLZ4DecompressGetTempSizeEx(std::forward(args)...); -- case compression_type::DEFLATE: -- return nvcompBatchedDeflateDecompressGetTempSizeEx(std::forward(args)...); -- case compression_type::GZIP: -- return nvcompBatchedGzipDecompressGetTempSizeEx(std::forward(args)...); -- default: CUDF_FAIL("Unsupported compression type"); -+ case compression_type::DEFLATE: [[fallthrough]]; -+ default: return std::nullopt; - } -+#endif -+ return std::nullopt; - } --size_t batched_decompress_temp_size(compression_type compression, -- size_t num_chunks, -- size_t max_uncomp_chunk_size, -- size_t max_total_uncomp_size) --{ -- size_t temp_size = 0; -- nvcompStatus_t nvcomp_status = batched_decompress_get_temp_size_ex( -- compression, num_chunks, max_uncomp_chunk_size, &temp_size, max_total_uncomp_size); - -- CUDF_EXPECTS(nvcomp_status == nvcompStatus_t::nvcompSuccess, -- "Unable to get scratch size for decompression"); -- return temp_size; -+// Dispatcher for nvcompBatchedDecompressGetTempSize -+template -+auto batched_decompress_get_temp_size(compression_type compression, Args&&... 
args) -+{ -+ switch (compression) { -+ case compression_type::SNAPPY: -+ return nvcompBatchedSnappyDecompressGetTempSize(std::forward(args)...); -+ case compression_type::ZSTD: -+#if NVCOMP_HAS_ZSTD_DECOMP(NVCOMP_MAJOR_VERSION, NVCOMP_MINOR_VERSION, NVCOMP_PATCH_VERSION) -+ return nvcompBatchedZstdDecompressGetTempSize(std::forward(args)...); -+#else -+ CUDF_FAIL("Decompression error: " + -+ nvcomp::is_decompression_disabled(nvcomp::compression_type::ZSTD).value()); -+#endif -+ case compression_type::DEFLATE: -+#if NVCOMP_HAS_DEFLATE(NVCOMP_MAJOR_VERSION, NVCOMP_MINOR_VERSION, NVCOMP_PATCH_VERSION) -+ return nvcompBatchedDeflateDecompressGetTempSize(std::forward(args)...); -+#else -+ CUDF_FAIL("Decompression error: " + -+ nvcomp::is_decompression_disabled(nvcomp::compression_type::DEFLATE).value()); -+#endif -+ case compression_type::LZ4: -+ return nvcompBatchedLZ4DecompressGetTempSize(std::forward(args)...); -+ default: CUDF_FAIL("Unsupported compression type"); -+ } - } - - // Dispatcher for nvcompBatchedDecompressAsync -@@ -72,12 +120,20 @@ auto batched_decompress_async(compression_type compression, Args&&... 
args) - case compression_type::SNAPPY: - return nvcompBatchedSnappyDecompressAsync(std::forward(args)...); - case compression_type::ZSTD: -+#if NVCOMP_HAS_ZSTD_DECOMP(NVCOMP_MAJOR_VERSION, NVCOMP_MINOR_VERSION, NVCOMP_PATCH_VERSION) - return nvcompBatchedZstdDecompressAsync(std::forward(args)...); -+#else -+ CUDF_FAIL("Decompression error: " + -+ nvcomp::is_decompression_disabled(nvcomp::compression_type::ZSTD).value()); -+#endif - case compression_type::DEFLATE: -+#if NVCOMP_HAS_DEFLATE(NVCOMP_MAJOR_VERSION, NVCOMP_MINOR_VERSION, NVCOMP_PATCH_VERSION) - return nvcompBatchedDeflateDecompressAsync(std::forward(args)...); -+#else -+ CUDF_FAIL("Decompression error: " + -+ nvcomp::is_decompression_disabled(nvcomp::compression_type::DEFLATE).value()); -+#endif - case compression_type::LZ4: return nvcompBatchedLZ4DecompressAsync(std::forward(args)...); -- case compression_type::GZIP: -- return nvcompBatchedGzipDecompressAsync(std::forward(args)...); - default: CUDF_FAIL("Unsupported compression type"); - } - } -@@ -89,11 +145,31 @@ std::string compression_type_name(compression_type compression) - case compression_type::ZSTD: return "Zstandard"; - case compression_type::DEFLATE: return "Deflate"; - case compression_type::LZ4: return "LZ4"; -- case compression_type::GZIP: return "GZIP"; - } - return "compression_type(" + std::to_string(static_cast(compression)) + ")"; - } - -+size_t batched_decompress_temp_size(compression_type compression, -+ size_t num_chunks, -+ size_t max_uncomp_chunk_size, -+ size_t max_total_uncomp_size) -+{ -+ size_t temp_size = 0; -+ auto nvcomp_status = batched_decompress_get_temp_size_ex( -+ compression, num_chunks, max_uncomp_chunk_size, &temp_size, max_total_uncomp_size); -+ -+ if (nvcomp_status.value_or(nvcompStatus_t::nvcompErrorInternal) != -+ nvcompStatus_t::nvcompSuccess) { -+ nvcomp_status = -+ batched_decompress_get_temp_size(compression, num_chunks, max_uncomp_chunk_size, &temp_size); -+ } -+ -+ CUDF_EXPECTS(nvcomp_status == 
nvcompStatus_t::nvcompSuccess, -+ "Unable to get scratch size for decompression"); -+ -+ return temp_size; -+} -+ - void batched_decompress(compression_type compression, - device_span const> inputs, - device_span const> outputs, -@@ -128,10 +204,54 @@ void batched_decompress(compression_type compression, - update_compression_results(nvcomp_statuses, actual_uncompressed_data_sizes, results, stream); - } - --size_t batched_compress_temp_size(compression_type compression, -- size_t batch_size, -- size_t max_uncompressed_chunk_bytes, -- size_t max_total_uncompressed_bytes) -+// Wrapper for nvcompBatchedCompressGetTempSize -+auto batched_compress_get_temp_size(compression_type compression, -+ size_t batch_size, -+ size_t max_uncompressed_chunk_bytes) -+{ -+ size_t temp_size = 0; -+ nvcompStatus_t nvcomp_status = nvcompStatus_t::nvcompSuccess; -+ switch (compression) { -+ case compression_type::SNAPPY: -+ nvcomp_status = nvcompBatchedSnappyCompressGetTempSize( -+ batch_size, max_uncompressed_chunk_bytes, nvcompBatchedSnappyDefaultOpts, &temp_size); -+ break; -+ case compression_type::DEFLATE: -+#if NVCOMP_HAS_DEFLATE(NVCOMP_MAJOR_VERSION, NVCOMP_MINOR_VERSION, NVCOMP_PATCH_VERSION) -+ nvcomp_status = nvcompBatchedDeflateCompressGetTempSize( -+ batch_size, max_uncompressed_chunk_bytes, nvcompBatchedDeflateDefaultOpts, &temp_size); -+ break; -+#else -+ CUDF_FAIL("Compression error: " + -+ nvcomp::is_compression_disabled(nvcomp::compression_type::DEFLATE).value()); -+#endif -+ case compression_type::ZSTD: -+#if NVCOMP_HAS_ZSTD_COMP(NVCOMP_MAJOR_VERSION, NVCOMP_MINOR_VERSION, NVCOMP_PATCH_VERSION) -+ nvcomp_status = nvcompBatchedZstdCompressGetTempSize( -+ batch_size, max_uncompressed_chunk_bytes, nvcompBatchedZstdDefaultOpts, &temp_size); -+ break; -+#else -+ CUDF_FAIL("Compression error: " + -+ nvcomp::is_compression_disabled(nvcomp::compression_type::ZSTD).value()); -+#endif -+ case compression_type::LZ4: -+ nvcomp_status = nvcompBatchedLZ4CompressGetTempSize( -+ 
batch_size, max_uncompressed_chunk_bytes, nvcompBatchedLZ4DefaultOpts, &temp_size); -+ break; -+ default: CUDF_FAIL("Unsupported compression type"); -+ } -+ -+ CUDF_EXPECTS(nvcomp_status == nvcompStatus_t::nvcompSuccess, -+ "Unable to get scratch size for compression"); -+ return temp_size; -+} -+ -+#if NVCOMP_HAS_COMP_TEMPSIZE_EX(NVCOMP_MAJOR_VERSION, NVCOMP_MINOR_VERSION, NVCOMP_PATCH_VERSION) -+// Wrapper for nvcompBatchedCompressGetTempSizeEx -+auto batched_compress_get_temp_size_ex(compression_type compression, -+ size_t batch_size, -+ size_t max_uncompressed_chunk_bytes, -+ size_t max_total_uncompressed_bytes) - { - size_t temp_size = 0; - nvcompStatus_t nvcomp_status = nvcompStatus_t::nvcompSuccess; -@@ -171,8 +291,28 @@ size_t batched_compress_temp_size(compression_type compression, - "Unable to get scratch size for compression"); - return temp_size; - } -+#endif -+ -+size_t batched_compress_temp_size(compression_type compression, -+ size_t num_chunks, -+ size_t max_uncomp_chunk_size, -+ size_t max_total_uncomp_size) -+{ -+#if NVCOMP_HAS_COMP_TEMPSIZE_EX(NVCOMP_MAJOR_VERSION, NVCOMP_MINOR_VERSION, NVCOMP_PATCH_VERSION) -+ try { -+ return batched_compress_get_temp_size_ex( -+ compression, num_chunks, max_uncomp_chunk_size, max_total_uncomp_size); -+ } catch (...) 
{ -+ // Ignore errors in the expanded version; fall back to the old API in case of failure -+ CUDF_LOG_WARN( -+ "CompressGetTempSizeEx call failed, falling back to CompressGetTempSize; this may increase " -+ "the memory usage"); -+ } -+#endif -+ -+ return batched_compress_get_temp_size(compression, num_chunks, max_uncomp_chunk_size); -+} - --// Wrapper for nvcompBatchedCompressGetMaxOutputChunkSize - size_t compress_max_output_chunk_size(compression_type compression, - uint32_t max_uncompressed_chunk_bytes) - { -@@ -188,13 +328,23 @@ size_t compress_max_output_chunk_size(compression_type compression, - capped_uncomp_bytes, nvcompBatchedSnappyDefaultOpts, &max_comp_chunk_size); - break; - case compression_type::DEFLATE: -+#if NVCOMP_HAS_DEFLATE(NVCOMP_MAJOR_VERSION, NVCOMP_MINOR_VERSION, NVCOMP_PATCH_VERSION) - status = nvcompBatchedDeflateCompressGetMaxOutputChunkSize( - capped_uncomp_bytes, nvcompBatchedDeflateDefaultOpts, &max_comp_chunk_size); - break; -+#else -+ CUDF_FAIL("Compression error: " + -+ nvcomp::is_compression_disabled(nvcomp::compression_type::DEFLATE).value()); -+#endif - case compression_type::ZSTD: -+#if NVCOMP_HAS_ZSTD_COMP(NVCOMP_MAJOR_VERSION, NVCOMP_MINOR_VERSION, NVCOMP_PATCH_VERSION) - status = nvcompBatchedZstdCompressGetMaxOutputChunkSize( - capped_uncomp_bytes, nvcompBatchedZstdDefaultOpts, &max_comp_chunk_size); - break; -+#else -+ CUDF_FAIL("Compression error: " + -+ nvcomp::is_compression_disabled(nvcomp::compression_type::ZSTD).value()); -+#endif - case compression_type::LZ4: - status = nvcompBatchedLZ4CompressGetMaxOutputChunkSize( - capped_uncomp_bytes, nvcompBatchedLZ4DefaultOpts, &max_comp_chunk_size); -@@ -234,6 +384,7 @@ static void batched_compress_async(compression_type compression, - stream.value()); - break; - case compression_type::DEFLATE: -+#if NVCOMP_HAS_DEFLATE(NVCOMP_MAJOR_VERSION, NVCOMP_MINOR_VERSION, NVCOMP_PATCH_VERSION) - nvcomp_status = nvcompBatchedDeflateCompressAsync(device_uncompressed_ptrs, - 
device_uncompressed_bytes, - max_uncompressed_chunk_bytes, -@@ -245,7 +396,12 @@ static void batched_compress_async(compression_type compression, - nvcompBatchedDeflateDefaultOpts, - stream.value()); - break; -+#else -+ CUDF_FAIL("Compression error: " + -+ nvcomp::is_compression_disabled(nvcomp::compression_type::DEFLATE).value()); -+#endif - case compression_type::ZSTD: -+#if NVCOMP_HAS_ZSTD_COMP(NVCOMP_MAJOR_VERSION, NVCOMP_MINOR_VERSION, NVCOMP_PATCH_VERSION) - nvcomp_status = nvcompBatchedZstdCompressAsync(device_uncompressed_ptrs, - device_uncompressed_bytes, - max_uncompressed_chunk_bytes, -@@ -257,6 +413,10 @@ static void batched_compress_async(compression_type compression, - nvcompBatchedZstdDefaultOpts, - stream.value()); - break; -+#else -+ CUDF_FAIL("Compression error: " + -+ nvcomp::is_compression_disabled(nvcomp::compression_type::ZSTD).value()); -+#endif - case compression_type::LZ4: - nvcomp_status = nvcompBatchedLZ4CompressAsync(device_uncompressed_ptrs, - device_uncompressed_bytes, -@@ -318,18 +478,16 @@ void batched_compress(compression_type compression, - } - - feature_status_parameters::feature_status_parameters() -- : feature_status_parameters(nvcomp_integration::is_all_enabled(), -- nvcomp_integration::is_stable_enabled()) --{ --} -- --feature_status_parameters::feature_status_parameters(bool all_enabled, bool stable_enabled) -- : lib_major_version{NVCOMP_VER_MAJOR}, -- lib_minor_version{NVCOMP_VER_MINOR}, -- lib_patch_version{NVCOMP_VER_PATCH}, -- are_all_integrations_enabled{all_enabled}, -- are_stable_integrations_enabled{stable_enabled} -+ : lib_major_version{NVCOMP_MAJOR_VERSION}, -+ lib_minor_version{NVCOMP_MINOR_VERSION}, -+ lib_patch_version{NVCOMP_PATCH_VERSION}, -+ are_all_integrations_enabled{nvcomp_integration::is_all_enabled()}, -+ are_stable_integrations_enabled{nvcomp_integration::is_stable_enabled()} - { -+ int device; -+ CUDF_CUDA_TRY(cudaGetDevice(&device)); -+ CUDF_CUDA_TRY( -+ 
cudaDeviceGetAttribute(&compute_capability_major, cudaDevAttrComputeCapabilityMajor, device)); - } - - // Represents all parameters required to determine status of a compression/decompression feature -@@ -352,21 +510,42 @@ std::optional is_compression_disabled_impl(compression_type compres - { - switch (compression) { - case compression_type::DEFLATE: { -+ if (not NVCOMP_HAS_DEFLATE( -+ params.lib_major_version, params.lib_minor_version, params.lib_patch_version)) { -+ return "nvCOMP 2.5 or newer is required for Deflate compression"; -+ } - if (not params.are_all_integrations_enabled) { - return "DEFLATE compression is experimental, you can enable it through " - "`LIBCUDF_NVCOMP_POLICY` environment variable."; - } - return std::nullopt; - } -+ case compression_type::SNAPPY: { -+ if (not params.are_stable_integrations_enabled) { -+ return "Snappy compression has been disabled through the `LIBCUDF_NVCOMP_POLICY` " -+ "environment variable."; -+ } -+ return std::nullopt; -+ } -+ case compression_type::ZSTD: { -+ if (not NVCOMP_HAS_ZSTD_COMP( -+ params.lib_major_version, params.lib_minor_version, params.lib_patch_version)) { -+ return "nvCOMP 2.4 or newer is required for Zstandard compression"; -+ } -+ if (not params.are_stable_integrations_enabled) { -+ return "Zstandard compression is experimental, you can enable it through " -+ "`LIBCUDF_NVCOMP_POLICY` environment variable."; -+ } -+ return std::nullopt; -+ } - case compression_type::LZ4: -- case compression_type::SNAPPY: -- case compression_type::ZSTD: - if (not params.are_stable_integrations_enabled) { -- return "nvCOMP use is disabled through the `LIBCUDF_NVCOMP_POLICY` environment variable."; -+ return "LZ4 compression has been disabled through the `LIBCUDF_NVCOMP_POLICY` " -+ "environment variable."; - } - return std::nullopt; -- default: return "Unsupported compression type"; - } -+ return "Unsupported compression type"; - } - - std::optional is_compression_disabled(compression_type compression, -@@ -398,26 
+577,58 @@ std::optional is_compression_disabled(compression_type compression, - return reason; - } - -+std::optional is_zstd_decomp_disabled(feature_status_parameters const& params) -+{ -+ if (not NVCOMP_HAS_ZSTD_DECOMP( -+ params.lib_major_version, params.lib_minor_version, params.lib_patch_version)) { -+ return "nvCOMP 2.3 or newer is required for Zstandard decompression"; -+ } -+ -+ if (NVCOMP_ZSTD_DECOMP_IS_STABLE( -+ params.lib_major_version, params.lib_minor_version, params.lib_patch_version)) { -+ if (not params.are_stable_integrations_enabled) { -+ return "Zstandard decompression has been disabled through the `LIBCUDF_NVCOMP_POLICY` " -+ "environment variable."; -+ } -+ } else if (not params.are_all_integrations_enabled) { -+ return "Zstandard decompression is experimental, you can enable it through " -+ "`LIBCUDF_NVCOMP_POLICY` environment variable."; -+ } -+ -+ return std::nullopt; -+} -+ - std::optional is_decompression_disabled_impl(compression_type compression, - feature_status_parameters params) - { - switch (compression) { -- case compression_type::DEFLATE: -- case compression_type::GZIP: { -+ case compression_type::DEFLATE: { -+ if (not NVCOMP_HAS_DEFLATE( -+ params.lib_major_version, params.lib_minor_version, params.lib_patch_version)) { -+ return "nvCOMP 2.5 or newer is required for Deflate decompression"; -+ } - if (not params.are_all_integrations_enabled) { - return "DEFLATE decompression is experimental, you can enable it through " - "`LIBCUDF_NVCOMP_POLICY` environment variable."; - } - return std::nullopt; - } -- case compression_type::LZ4: -- case compression_type::SNAPPY: -- case compression_type::ZSTD: { -+ case compression_type::SNAPPY: { - if (not params.are_stable_integrations_enabled) { -- return "nvCOMP use is disabled through the `LIBCUDF_NVCOMP_POLICY` environment variable."; -+ return "Snappy decompression has been disabled through the `LIBCUDF_NVCOMP_POLICY` " -+ "environment variable."; - } - return std::nullopt; - } -+ case 
compression_type::ZSTD: return is_zstd_decomp_disabled(params); -+ case compression_type::LZ4: { -+ if (not params.are_stable_integrations_enabled) { -+ return "LZ4 decompression has been disabled through the `LIBCUDF_NVCOMP_POLICY` " -+ "environment variable."; -+ } -+ return std::nullopt; -+ } -+ default: return "Unsupported compression type"; - } - return "Unsupported compression type"; - } -@@ -451,27 +662,43 @@ std::optional is_decompression_disabled(compression_type compressio - return reason; - } - --size_t required_alignment(compression_type compression) -+size_t compress_input_alignment_bits(compression_type compression) - { - switch (compression) { -- case compression_type::GZIP: -- case compression_type::DEFLATE: return nvcompDeflateRequiredAlignment; -- case compression_type::SNAPPY: return nvcompSnappyRequiredAlignment; -- case compression_type::ZSTD: return nvcompZstdRequiredAlignment; -- case compression_type::LZ4: return nvcompLZ4RequiredAlignment; -+ case compression_type::DEFLATE: return 0; -+ case compression_type::SNAPPY: return 0; -+ case compression_type::ZSTD: return 2; -+ case compression_type::LZ4: return 2; - default: CUDF_FAIL("Unsupported compression type"); - } - } - --std::optional compress_max_allowed_chunk_size(compression_type compression) -+size_t compress_output_alignment_bits(compression_type compression) - { - switch (compression) { -- case compression_type::DEFLATE: return nvcompDeflateCompressionMaxAllowedChunkSize; -- case compression_type::SNAPPY: return nvcompSnappyCompressionMaxAllowedChunkSize; -- case compression_type::ZSTD: return nvcompZstdCompressionMaxAllowedChunkSize; -- case compression_type::LZ4: return nvcompLZ4CompressionMaxAllowedChunkSize; -+ case compression_type::DEFLATE: return 3; -+ case compression_type::SNAPPY: return 0; -+ case compression_type::ZSTD: return 0; -+ case compression_type::LZ4: return 2; - default: CUDF_FAIL("Unsupported compression type"); - } - } - -+std::optional 
compress_max_allowed_chunk_size(compression_type compression) -+{ -+ switch (compression) { -+ case compression_type::DEFLATE: return 64 * 1024; -+ case compression_type::SNAPPY: return std::nullopt; -+ case compression_type::ZSTD: -+#if NVCOMP_HAS_ZSTD_COMP(NVCOMP_MAJOR_VERSION, NVCOMP_MINOR_VERSION, NVCOMP_PATCH_VERSION) -+ return nvcompZstdCompressionMaxAllowedChunkSize; -+#else -+ CUDF_FAIL("Compression error: " + -+ nvcomp::is_compression_disabled(nvcomp::compression_type::ZSTD).value()); -+#endif -+ case compression_type::LZ4: return 16 * 1024 * 1024; -+ default: return std::nullopt; -+ } -+} -+ - } // namespace cudf::io::nvcomp -diff --git a/cpp/src/io/comp/nvcomp_adapter.hpp b/cpp/src/io/comp/nvcomp_adapter.hpp -index 583bd6a352..43c79e3237 100644 ---- a/cpp/src/io/comp/nvcomp_adapter.hpp -+++ b/cpp/src/io/comp/nvcomp_adapter.hpp -@@ -75,12 +75,20 @@ size_t batched_decompress_temp_size(compression_type compression, - uint32_t max_uncomp_chunk_size); - - /** -- * @brief Gets input and output alignment requirements for the given compression type. -+ * @brief Gets input alignment requirements for the given compression type. - * - * @param compression Compression type -- * @returns required alignment -+ * @returns required alignment, in bits - */ --[[nodiscard]] size_t required_alignment(compression_type compression); -+[[nodiscard]] size_t compress_input_alignment_bits(compression_type compression); -+ -+/** -+ * @brief Gets output alignment requirements for the given compression type. -+ * -+ * @param compression Compression type -+ * @returns required alignment, in bits -+ */ -+[[nodiscard]] size_t compress_output_alignment_bits(compression_type compression); - - /** - * @brief Maximum size of uncompressed chunks that can be compressed with nvCOMP. 
-diff --git a/cpp/src/io/orc/writer_impl.cu b/cpp/src/io/orc/writer_impl.cu -index 60a64fb0ee..40cfbe763b 100644 ---- a/cpp/src/io/orc/writer_impl.cu -+++ b/cpp/src/io/orc/writer_impl.cu -@@ -533,20 +533,20 @@ auto uncomp_block_alignment(CompressionKind compression_kind) - { - if (compression_kind == NONE or - nvcomp::is_compression_disabled(to_nvcomp_compression_type(compression_kind))) { -- return 1ul; -+ return 1u; - } - -- return nvcomp::required_alignment(to_nvcomp_compression_type(compression_kind)); -+ return 1u << nvcomp::compress_input_alignment_bits(to_nvcomp_compression_type(compression_kind)); - } - - auto comp_block_alignment(CompressionKind compression_kind) - { - if (compression_kind == NONE or - nvcomp::is_compression_disabled(to_nvcomp_compression_type(compression_kind))) { -- return 1ul; -+ return 1u; - } - -- return nvcomp::required_alignment(to_nvcomp_compression_type(compression_kind)); -+ return 1u << nvcomp::compress_output_alignment_bits(to_nvcomp_compression_type(compression_kind)); - } - - /** -diff --git a/cpp/src/io/parquet/reader_impl_chunking.cu b/cpp/src/io/parquet/reader_impl_chunking.cu -index c588fedb85..bab70c126b 100644 ---- a/cpp/src/io/parquet/reader_impl_chunking.cu -+++ b/cpp/src/io/parquet/reader_impl_chunking.cu -@@ -865,18 +865,8 @@ std::vector compute_page_splits_by_row(device_span=2.5.0,<2.6.0a0 - # Align nvcomp version with rapids-cmake -- - nvcomp==4.0.1 -+ - nvcomp==3.0.6 - - spdlog>=1.14.1,<1.15 - rapids_build_skbuild: - common: -diff --git a/docs/cudf/source/user_guide/io/io.md b/docs/cudf/source/user_guide/io/io.md -index 97b961b455..adcdaa51e7 100644 ---- a/docs/cudf/source/user_guide/io/io.md -+++ b/docs/cudf/source/user_guide/io/io.md -@@ -75,6 +75,7 @@ IO format. - - - -+ - **Notes:** - - - \[¹\] - Not all orientations are GPU-accelerated. -@@ -176,9 +177,4 @@ If no value is set, behavior will be the same as the "STABLE" option. 
- +-----------------------+--------+--------+--------------+--------------+---------+--------+--------------+--------------+--------+ - | DEFLATE | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | Experimental | Experimental | ❌ | - +-----------------------+--------+--------+--------------+--------------+---------+--------+--------------+--------------+--------+ -- | LZ4 | ❌ | ❌ | Stable | Stable | ❌ | ❌ | Stable | Stable | ❌ | -- +-----------------------+--------+--------+--------------+--------------+---------+--------+--------------+--------------+--------+ -- | GZIP | ❌ | ❌ | Experimental | Experimental | ❌ | ❌ | ❌ | ❌ | ❌ | -- +-----------------------+--------+--------+--------------+--------------+---------+--------+--------------+--------------+--------+ -- - ``` -diff --git a/java/pom.xml b/java/pom.xml -index e4f1cdf64e..9694e741f1 100644 ---- a/java/pom.xml -+++ b/java/pom.xml -@@ -1,6 +1,6 @@ - - false 3.0.0 0.2.2 @@ -452,15 +456,13 @@ + - - - + - + diff --git a/src/main/cpp/CMakeLists.txt b/src/main/cpp/CMakeLists.txt index 3ee308550f..20a918266e 100644 --- a/src/main/cpp/CMakeLists.txt +++ b/src/main/cpp/CMakeLists.txt @@ -12,7 +12,7 @@ # the License. 
# ============================================================================= -cmake_minimum_required(VERSION 3.20.1 FATAL_ERROR) +cmake_minimum_required(VERSION 3.28.6 FATAL_ERROR) set(CUDF_DIR "$ENV{CUDF_ROOT}" @@ -51,7 +51,7 @@ rapids_cuda_init_architectures(SPARK_RAPIDS_JNI) project( SPARK_RAPIDS_JNI - VERSION 24.10.00 + VERSION 24.12.00 LANGUAGES C CXX CUDA ) @@ -98,6 +98,7 @@ set_property( message(VERBOSE "SPARK_RAPIDS_JNI: RMM_LOGGING_LEVEL = '${RMM_LOGGING_LEVEL}'.") set(CMAKE_PREFIX_PATH ${CUDF_INSTALL_DIR}) +set(CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/Modules/") # Set a default build type if none was specified rapids_cmake_build_type("Release") @@ -194,6 +195,7 @@ add_library( src/GpuTimeZoneDBJni.cpp src/HashJni.cpp src/HistogramJni.cpp + src/HostTableJni.cpp src/JSONUtilsJni.cpp src/NativeParquetJni.cpp src/ParseURIJni.cpp @@ -205,15 +207,17 @@ add_library( src/bloom_filter.cu src/case_when.cu src/cast_decimal_to_string.cu - src/format_float.cu src/cast_float_to_string.cu src/cast_string.cu src/cast_string_to_float.cu src/datetime_rebase.cu src/decimal_utils.cu + src/format_float.cu src/from_json_to_raw_map.cu + src/from_json_to_structs.cu src/get_json_object.cu src/histogram.cu + src/json_utils.cu src/murmur_hash.cu src/parse_uri.cu src/regex_rewrite_utils.cu @@ -295,7 +299,6 @@ add_dependencies(cudfjnistub spark_rapids_jni) # * cuFile JNI ------------------------------------------------------------------------------------- if(USE_GDS) - include(${CUDF_DIR}/cpp/cmake/Modules/FindcuFile.cmake) find_library(CUFILEJNI_LIB "libcufilejni.a" REQUIRED NO_DEFAULT_PATH HINTS "${CUDFJNI_BUILD_DIR}" ) @@ -307,6 +310,12 @@ if(USE_GDS) # set target compile options CXX_STANDARD 17 CXX_STANDARD_REQUIRED ON + CXX_EXTENSIONS ON + CUDA_STANDARD 17 + CUDA_STANDARD_REQUIRED ON + POSITION_INDEPENDENT_CODE ON + INTERFACE_POSITION_INDEPENDENT_CODE ON + CUDA_RUNTIME_LIBRARY Static ) target_link_libraries( cufilejni PRIVATE @@ -314,7 +323,7 @@ if(USE_GDS) 
${CUFILEJNI_LIB} -Wl,--no-whole-archive spark_rapids_jni - ${cuFile_LIBRARIES} + CUDA::cuFile_static ) rapids_cuda_set_runtime(cufilejni USE_STATIC ON) endif() diff --git a/src/main/cpp/src/HostTableJni.cpp b/src/main/cpp/src/HostTableJni.cpp new file mode 100644 index 0000000000..0959f35e57 --- /dev/null +++ b/src/main/cpp/src/HostTableJni.cpp @@ -0,0 +1,249 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "cudf_jni_apis.hpp" +#include "host_table_view.hpp" + +#include +#include +#include + +#include + +#include + +#include +#include +#include +#include +#include +#include + +namespace { + +using spark_rapids_jni::host_column_view; +using spark_rapids_jni::host_table_view; + +// Padding sizes to 64-byte for compatibility with Arrow +std::size_t pad_size(std::size_t size) { return rmm::align_up(size, 64); } + +// Determine the size of buffer needed to hold just the data portion of this column. +// This does not include validity, offsets, or any child columns. 
+std::size_t get_data_size(cudf::column_view const& c, cudaStream_t stream) +{ + auto dtype = c.type(); + if (cudf::is_fixed_width(dtype)) { + return cudf::size_of(dtype) * c.size(); + } else if (dtype.id() == cudf::type_id::STRING) { + auto scv = cudf::strings_column_view(c); + return scv.chars_size(stream); + } else { + throw std::runtime_error(std::string("unexpected data type: ") + + std::to_string(static_cast(dtype.id()))); + } +} + +// Determine the size of buffer needed to hold all of the data for a column. +// This includes validity, data, offsets, and child columns. +std::size_t column_size(cudf::column_view const& c, cudaStream_t stream) +{ + std::size_t size = 0; + if (c.data() != nullptr) { size += pad_size(get_data_size(c, stream)); } + if (c.has_nulls()) { size += cudf::bitmask_allocation_size_bytes(c.size()); } + return std::accumulate(c.child_begin(), + c.child_end(), + size, + [stream](std::size_t sum, cudf::column_view const& child) { + return sum + column_size(child, stream); + }); +} + +// Determine the size of buffer needed to hold all of the data for a table. 
+std::size_t host_buffer_size(cudf::table_view const& t, cudaStream_t stream) +{ + std::size_t s = 0; + return std::accumulate( + t.begin(), t.end(), s, [stream](std::size_t sum, cudf::column_view const& c) { + return sum + column_size(c, stream); + }); +} + +uint8_t* copy_to_host_async( + void const* src, uint8_t* dest, std::size_t size, uint8_t const* dest_end, cudaStream_t stream) +{ + if (dest + size > dest_end) { throw std::runtime_error("buffer overflow"); } + CUDF_CUDA_TRY(cudaMemcpyAsync(dest, src, size, cudaMemcpyDeviceToHost, stream)); + return dest + size; +} + +uint8_t* build_host_column_view_async(cudf::column_view const& dev_col, + uint8_t* bp, + uint8_t const* bp_end, + cudaStream_t stream, + std::vector& host_cols) +{ + void const* host_data = nullptr; + void const* dev_data = dev_col.data(); + if (dev_data != nullptr) { + host_data = bp; + auto data_size = get_data_size(dev_col, stream); + auto padded_bp_end = bp + pad_size(data_size); + bp = copy_to_host_async(dev_data, bp, data_size, bp_end, stream); + while (bp != padded_bp_end) { + *bp++ = 0; + } + } + cudf::bitmask_type const* host_null_mask = nullptr; + if (dev_col.has_nulls()) { + host_null_mask = reinterpret_cast(bp); + auto mask_size = cudf::bitmask_allocation_size_bytes(dev_col.size()); + bp = copy_to_host_async(dev_col.null_mask(), bp, mask_size, bp_end, stream); + } + std::vector children; + children.reserve(dev_col.num_children()); + std::for_each(dev_col.child_begin(), dev_col.child_end(), [&](cudf::column_view const& child) { + bp = build_host_column_view_async(child, bp, bp_end, stream, children); + }); + host_cols.push_back(host_column_view( + dev_col.type(), dev_col.size(), host_data, host_null_mask, dev_col.null_count(), children)); + return bp; +} + +std::unique_ptr to_host_table_async(cudf::table_view const& dev_table, + uint8_t* buffer, + std::size_t buffer_size, + cudaStream_t stream) +{ + uint8_t* bp = buffer; + uint8_t const* buffer_end = buffer + buffer_size; + std::vector 
cols; + cols.reserve(dev_table.num_columns()); + std::for_each(dev_table.begin(), dev_table.end(), [&](cudf::column_view const& dev_col) { + bp = build_host_column_view_async(dev_col, bp, buffer_end, stream, cols); + }); + return std::make_unique(cols); +} + +cudf::column_view to_device_column(host_column_view const& host_col, jlong host_to_dev_offset) +{ + auto data = host_col.data(); + if (data != nullptr) { data += host_to_dev_offset; } + auto mask = host_col.null_mask(); + if (mask != nullptr) { mask += host_to_dev_offset / sizeof(*mask); } + std::vector children; + std::transform(host_col.child_begin(), + host_col.child_end(), + std::back_inserter(children), + [host_to_dev_offset](host_column_view const& c) { + return to_device_column(c, host_to_dev_offset); + }); + return cudf::column_view( + host_col.type(), host_col.size(), data, mask, host_col.null_count(), 0, children); +} + +std::vector> to_device_column_views( + host_table_view const& host_table, jlong host_to_dev_offset) +{ + std::vector> cv_ptrs; + cv_ptrs.reserve(host_table.num_columns()); + std::transform( + host_table.begin(), + host_table.end(), + std::back_inserter(cv_ptrs), + [host_to_dev_offset](host_column_view const& host_col) { + return std::make_unique(to_device_column(host_col, host_to_dev_offset)); + }); + return cv_ptrs; +} + +} // anonymous namespace + +extern "C" { + +JNIEXPORT jlong JNICALL Java_com_nvidia_spark_rapids_jni_HostTable_bufferSize(JNIEnv* env, + jclass, + jlong table_handle, + jlong jstream) +{ + JNI_NULL_CHECK(env, table_handle, "table is null", 0); + try { + cudf::jni::auto_set_device(env); + auto t = reinterpret_cast(table_handle); + auto stream = reinterpret_cast(jstream); + return static_cast(host_buffer_size(*t, stream)); + } + CATCH_STD(env, 0); +} + +JNIEXPORT jlong JNICALL Java_com_nvidia_spark_rapids_jni_HostTable_copyFromTableAsync( + JNIEnv* env, jclass, jlong table_handle, jlong host_address, jlong host_size, jlong jstream) +{ + JNI_NULL_CHECK(env, 
table_handle, "table is null", 0); + try { + cudf::jni::auto_set_device(env); + auto table = reinterpret_cast(table_handle); + auto buffer = reinterpret_cast(host_address); + auto buffer_size = static_cast(host_size); + auto stream = reinterpret_cast(jstream); + auto host_table_view = to_host_table_async(*table, buffer, buffer_size, stream); + return reinterpret_cast(host_table_view.release()); + } + CATCH_STD(env, 0); +} + +JNIEXPORT jlongArray JNICALL Java_com_nvidia_spark_rapids_jni_HostTable_toDeviceColumnViews( + JNIEnv* env, jclass, jlong table_handle, jlong host_to_dev_offset) +{ + JNI_NULL_CHECK(env, table_handle, "table is null", nullptr); + JNI_ARG_CHECK( + env, host_to_dev_offset % sizeof(cudf::bitmask_type) == 0, "invalid offset", nullptr); + try { + cudf::jni::auto_set_device(env); + auto host_table = reinterpret_cast(table_handle); + auto column_view_ptrs = to_device_column_views(*host_table, host_to_dev_offset); + cudf::jni::native_jlongArray handles(env, static_cast(column_view_ptrs.size())); + std::transform( + column_view_ptrs.begin(), + column_view_ptrs.end(), + handles.begin(), + [](std::unique_ptr& p) { return cudf::jni::release_as_jlong(p); }); + return handles.get_jArray(); + } + CATCH_STD(env, 0); +} + +JNIEXPORT void JNICALL Java_com_nvidia_spark_rapids_jni_HostTable_freeDeviceColumnView( + JNIEnv* env, jclass, jlong dev_column_view_handle) +{ + JNI_NULL_CHECK(env, dev_column_view_handle, "view is null", ); + try { + delete reinterpret_cast(dev_column_view_handle); + } + CATCH_STD(env, ); +} + +JNIEXPORT void JNICALL Java_com_nvidia_spark_rapids_jni_HostTable_freeHostTable(JNIEnv* env, + jclass, + jlong table_handle) +{ + JNI_NULL_CHECK(env, table_handle, "table is null", ); + try { + delete reinterpret_cast(table_handle); + } + CATCH_STD(env, ); +} + +} // extern "C" diff --git a/src/main/cpp/src/JSONUtilsJni.cpp b/src/main/cpp/src/JSONUtilsJni.cpp index 5a0c5dd341..2c4b8e1aaa 100644 --- a/src/main/cpp/src/JSONUtilsJni.cpp +++ 
b/src/main/cpp/src/JSONUtilsJni.cpp @@ -15,8 +15,8 @@ */ #include "cudf_jni_apis.hpp" -#include "from_json.hpp" #include "get_json_object.hpp" +#include "json_utils.hpp" #include @@ -142,7 +142,13 @@ Java_com_nvidia_spark_rapids_jni_JSONUtils_getJsonObjectMultiplePaths(JNIEnv* en } JNIEXPORT jlong JNICALL Java_com_nvidia_spark_rapids_jni_JSONUtils_extractRawMapFromJsonString( - JNIEnv* env, jclass, jlong j_input) + JNIEnv* env, + jclass, + jlong j_input, + jboolean normalize_single_quotes, + jboolean allow_leading_zeros, + jboolean allow_nonnumeric_numbers, + jboolean allow_unquoted_control) { JNI_NULL_CHECK(env, j_input, "j_input is null", 0); @@ -150,8 +156,128 @@ JNIEXPORT jlong JNICALL Java_com_nvidia_spark_rapids_jni_JSONUtils_extractRawMap cudf::jni::auto_set_device(env); auto const input_cv = reinterpret_cast(j_input); return cudf::jni::ptr_as_jlong( - spark_rapids_jni::from_json_to_raw_map(cudf::strings_column_view{*input_cv}).release()); + spark_rapids_jni::from_json_to_raw_map(cudf::strings_column_view{*input_cv}, + normalize_single_quotes, + allow_leading_zeros, + allow_nonnumeric_numbers, + allow_unquoted_control) + .release()); } CATCH_STD(env, 0); } + +JNIEXPORT jlong JNICALL +Java_com_nvidia_spark_rapids_jni_JSONUtils_fromJSONToStructs(JNIEnv* env, + jclass, + jlong j_input, + jobjectArray j_col_names, + jintArray j_num_children, + jintArray j_types, + jintArray j_scales, + jintArray j_precisions, + jboolean normalize_single_quotes, + jboolean allow_leading_zeros, + jboolean allow_nonnumeric_numbers, + jboolean allow_unquoted_control, + jboolean is_us_locale) +{ + JNI_NULL_CHECK(env, j_input, "j_input is null", 0); + JNI_NULL_CHECK(env, j_col_names, "j_col_names is null", 0); + JNI_NULL_CHECK(env, j_num_children, "j_num_children is null", 0); + JNI_NULL_CHECK(env, j_types, "j_types is null", 0); + JNI_NULL_CHECK(env, j_scales, "j_scales is null", 0); + JNI_NULL_CHECK(env, j_precisions, "j_precisions is null", 0); + + try { + 
cudf::jni::auto_set_device(env); + + auto const input_cv = reinterpret_cast(j_input); + auto const col_names = cudf::jni::native_jstringArray(env, j_col_names).as_cpp_vector(); + auto const num_children = cudf::jni::native_jintArray(env, j_num_children).to_vector(); + auto const types = cudf::jni::native_jintArray(env, j_types).to_vector(); + auto const scales = cudf::jni::native_jintArray(env, j_scales).to_vector(); + auto const precisions = cudf::jni::native_jintArray(env, j_precisions).to_vector(); + + CUDF_EXPECTS(col_names.size() > 0, "Invalid schema data: col_names."); + CUDF_EXPECTS(col_names.size() == num_children.size(), "Invalid schema data: num_children."); + CUDF_EXPECTS(col_names.size() == types.size(), "Invalid schema data: types."); + CUDF_EXPECTS(col_names.size() == scales.size(), "Invalid schema data: scales."); + CUDF_EXPECTS(col_names.size() == precisions.size(), "Invalid schema data: precisions."); + + return cudf::jni::ptr_as_jlong( + spark_rapids_jni::from_json_to_structs(cudf::strings_column_view{*input_cv}, + col_names, + num_children, + types, + scales, + precisions, + normalize_single_quotes, + allow_leading_zeros, + allow_nonnumeric_numbers, + allow_unquoted_control, + is_us_locale) + .release()); + } + CATCH_STD(env, 0); } + +JNIEXPORT jlong JNICALL +Java_com_nvidia_spark_rapids_jni_JSONUtils_convertFromStrings(JNIEnv* env, + jclass, + jlong j_input, + jintArray j_num_children, + jintArray j_types, + jintArray j_scales, + jintArray j_precisions, + jboolean allow_nonnumeric_numbers, + jboolean is_us_locale) +{ + JNI_NULL_CHECK(env, j_input, "j_input is null", 0); + JNI_NULL_CHECK(env, j_num_children, "j_num_children is null", 0); + JNI_NULL_CHECK(env, j_types, "j_types is null", 0); + JNI_NULL_CHECK(env, j_scales, "j_scales is null", 0); + JNI_NULL_CHECK(env, j_precisions, "j_precisions is null", 0); + + try { + cudf::jni::auto_set_device(env); + + auto const input_cv = reinterpret_cast(j_input); + auto const num_children = 
cudf::jni::native_jintArray(env, j_num_children).to_vector(); + auto const types = cudf::jni::native_jintArray(env, j_types).to_vector(); + auto const scales = cudf::jni::native_jintArray(env, j_scales).to_vector(); + auto const precisions = cudf::jni::native_jintArray(env, j_precisions).to_vector(); + + CUDF_EXPECTS(num_children.size() > 0, "Invalid schema data: num_children."); + CUDF_EXPECTS(num_children.size() == types.size(), "Invalid schema data: types."); + CUDF_EXPECTS(num_children.size() == scales.size(), "Invalid schema data: scales."); + CUDF_EXPECTS(num_children.size() == precisions.size(), "Invalid schema data: precisions."); + + return cudf::jni::ptr_as_jlong( + spark_rapids_jni::convert_from_strings(cudf::strings_column_view{*input_cv}, + num_children, + types, + scales, + precisions, + allow_nonnumeric_numbers, + is_us_locale) + .release()); + } + CATCH_STD(env, 0); +} + +JNIEXPORT jlong JNICALL Java_com_nvidia_spark_rapids_jni_JSONUtils_removeQuotes( + JNIEnv* env, jclass, jlong j_input, jboolean nullify_if_not_quoted) +{ + JNI_NULL_CHECK(env, j_input, "j_input is null", 0); + + try { + cudf::jni::auto_set_device(env); + auto const input_cv = reinterpret_cast(j_input); + return cudf::jni::ptr_as_jlong( + spark_rapids_jni::remove_quotes(cudf::strings_column_view{*input_cv}, nullify_if_not_quoted) + .release()); + } + CATCH_STD(env, 0); +} + +} // extern "C" diff --git a/src/main/cpp/src/SparkResourceAdaptorJni.cpp b/src/main/cpp/src/SparkResourceAdaptorJni.cpp index 8eeb047ddc..31a603411f 100644 --- a/src/main/cpp/src/SparkResourceAdaptorJni.cpp +++ b/src/main/cpp/src/SparkResourceAdaptorJni.cpp @@ -300,7 +300,6 @@ class full_thread_state { // time) long time_retry_running_nanos = 0; std::chrono::time_point block_start; - long gpu_memory_allocated_bytes = 0; // metrics for the current thread task_metrics metrics; @@ -891,6 +890,8 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { std::condition_variable 
task_has_woken_condition; std::map threads; std::map> task_to_threads; + long gpu_memory_allocated_bytes = 0; + // Metrics are a little complicated. Spark reports metrics at a task level // but we track and collect them at a thread level. The life time of a thread // and a task are not tied to each other, and a thread can work on things for @@ -1376,10 +1377,9 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { // num_bytes is likely not padded, which could cause slight inaccuracies // but for now it shouldn't matter for watermark purposes if (!is_for_cpu) { - thread->second.gpu_memory_allocated_bytes += num_bytes; + gpu_memory_allocated_bytes += num_bytes; thread->second.metrics.gpu_max_memory_allocated = - std::max(thread->second.metrics.gpu_max_memory_allocated, - thread->second.gpu_memory_allocated_bytes); + std::max(thread->second.metrics.gpu_max_memory_allocated, gpu_memory_allocated_bytes); } break; default: break; @@ -1780,6 +1780,7 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { auto const thread = threads.find(tid); if (thread != threads.end()) { log_status("DEALLOC", tid, thread->second.task_id, thread->second.state); + if (!is_for_cpu) { gpu_memory_allocated_bytes -= num_bytes; } } else { log_status("DEALLOC", tid, -2, thread_state::UNKNOWN); } @@ -1802,7 +1803,6 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { if (is_for_cpu == t_state.is_cpu_alloc) { transition(t_state, thread_state::THREAD_ALLOC_FREE); } - if (!is_for_cpu) { t_state.gpu_memory_allocated_bytes -= num_bytes; } break; default: break; } diff --git a/src/main/cpp/src/cast_string.cu b/src/main/cpp/src/cast_string.cu index 88f1f0323f..c491668f92 100644 --- a/src/main/cpp/src/cast_string.cu +++ b/src/main/cpp/src/cast_string.cu @@ -408,12 +408,7 @@ CUDF_KERNEL void string_to_decimal_kernel(T* out, auto const row_start = offsets[row]; auto const len = offsets[row + 1] - row_start; bool const 
valid_entry = incoming_null_mask == nullptr || bit_is_set(incoming_null_mask, row); - - auto ret = validate_and_exponent(&chars[row_start], len, strip); - bool valid = ret.has_value(); - bool positive; - int decimal_location; - int first_digit; + bool valid = valid_entry && len > 0; // first_digit is distance into the string array for the first digit to process. This skips +, -, // whitespace, etc. decimal_location is the index into the string where the decimal point should @@ -439,8 +434,15 @@ CUDF_KERNEL void string_to_decimal_kernel(T* out, return count; }; + auto const validated = + valid ? validate_and_exponent(&chars[row_start], len, strip) : cuda::std::nullopt; + valid = validated.has_value(); + if (valid) { - thrust::tie(positive, decimal_location, first_digit) = *ret; + bool positive; + int decimal_location; + int first_digit; + thrust::tie(positive, decimal_location, first_digit) = *validated; auto const max_digits_before_decimal = precision + scale; auto const significant_digits_before_decimal_in_string = count_significant_digits( diff --git a/src/main/cpp/src/decimal_utils.cu b/src/main/cpp/src/decimal_utils.cu index 147818d9aa..a28d210e64 100644 --- a/src/main/cpp/src/decimal_utils.cu +++ b/src/main/cpp/src/decimal_utils.cu @@ -21,7 +21,7 @@ #include #include #include -#include +#include #include #include #include diff --git a/src/main/cpp/src/from_json_to_raw_map.cu b/src/main/cpp/src/from_json_to_raw_map.cu index 73c2c4b559..0c657552c8 100644 --- a/src/main/cpp/src/from_json_to_raw_map.cu +++ b/src/main/cpp/src/from_json_to_raw_map.cu @@ -15,19 +15,22 @@ */ #include "from_json_to_raw_map_debug.cuh" +#include "json_utils.hpp" -#include #include -#include #include -#include +#include +#include +#include #include -#include #include -#include #include +#include +#include #include +#include +#include #include #include @@ -35,7 +38,6 @@ #include #include #include -#include #include #include #include @@ -45,7 +47,6 @@ #include #include #include -#include 
#include @@ -55,104 +56,56 @@ using namespace cudf::io::json; namespace { -// Unify the input json strings by: -// 1. Append one comma character (',') to the end of each input string, except the last one. -// 2. Concatenate all input strings into one string. -// 3. Add a pair of bracket characters ('[' and ']') to the beginning and the end of the output. -rmm::device_uvector unify_json_strings(cudf::strings_column_view const& input, - rmm::cuda_stream_view stream) +std::unique_ptr make_empty_map(rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) { - if (input.is_empty()) { - return cudf::detail::make_device_uvector_async( - std::vector{'[', ']'}, stream, rmm::mr::get_current_device_resource()); - } - - auto const d_strings = cudf::column_device_view::create(input.parent(), stream); - auto const chars_size = input.chars_size(stream); - auto const output_size = - 2l + // two extra bracket characters '[' and ']' - static_cast(chars_size) + - static_cast(input.size() - 1) + // append `,` character between input rows - static_cast(input.null_count()) * 2l; // replace null with "{}" - // TODO: This assertion eventually needs to be removed. - // See https://github.com/NVIDIA/spark-rapids-jni/issues/1707 - CUDF_EXPECTS(output_size <= static_cast(std::numeric_limits::max()), - "The input json column is too large and causes overflow."); - - auto const joined_input = cudf::strings::detail::join_strings( - input, - cudf::string_scalar(","), // append `,` character between the input rows - cudf::string_scalar("{}"), // replacement for null rows - stream, - rmm::mr::get_current_device_resource()); - auto const joined_input_scv = cudf::strings_column_view{*joined_input}; - auto const joined_input_size_bytes = joined_input_scv.chars_size(stream); - // TODO: This assertion requires a stream synchronization, may want to remove at some point. 
- // See https://github.com/NVIDIA/spark-rapids-jni/issues/1707 - CUDF_EXPECTS(joined_input_size_bytes + 2 == output_size, "Incorrect output size computation."); - - // We want to concatenate 3 strings: "[" + joined_input + "]". - // For efficiency, let's use memcpy instead of `cudf::strings::detail::concatenate`. - auto output = rmm::device_uvector(joined_input_size_bytes + 2, stream); - CUDF_CUDA_TRY(cudaMemsetAsync(output.data(), static_cast('['), 1, stream.value())); - CUDF_CUDA_TRY(cudaMemcpyAsync(output.data() + 1, - joined_input_scv.chars_begin(stream), - joined_input_size_bytes, - cudaMemcpyDefault, - stream.value())); - CUDF_CUDA_TRY(cudaMemsetAsync( - output.data() + joined_input_size_bytes + 1, static_cast(']'), 1, stream.value())); - -#ifdef DEBUG_FROM_JSON - print_debug(output, "Processed json string", "", stream); -#endif - return output; + auto keys = cudf::make_empty_column(cudf::data_type{cudf::type_id::STRING}); + auto values = cudf::make_empty_column(cudf::data_type{cudf::type_id::STRING}); + std::vector> out_keys_vals; + out_keys_vals.emplace_back(std::move(keys)); + out_keys_vals.emplace_back(std::move(values)); + auto child = + cudf::make_structs_column(0, std::move(out_keys_vals), 0, rmm::device_buffer{}, stream, mr); + auto offsets = cudf::make_empty_column(cudf::data_type(cudf::type_id::INT32)); + return cudf::make_lists_column( + 0, std::move(offsets), std::move(child), 0, rmm::device_buffer{}, stream, mr); } -// Check and throw exception if there is any parsing error. -void throw_if_error(rmm::device_uvector const& input_json, - rmm::device_uvector const& tokens, - rmm::device_uvector const& token_indices, - rmm::cuda_stream_view stream) +// Concatenating all input strings into one string, for which each input string is appended by a +// delimiter character that does not exist in the input column. 
+std::tuple> unify_json_strings( + cudf::strings_column_view const& input, rmm::cuda_stream_view stream) { - auto const error_count = - thrust::count(rmm::exec_policy(stream), tokens.begin(), tokens.end(), token_t::ErrorBegin); - - if (error_count > 0) { - auto const error_location = - thrust::find(rmm::exec_policy(stream), tokens.begin(), tokens.end(), token_t::ErrorBegin); - SymbolOffsetT error_index; - CUDF_CUDA_TRY( - cudaMemcpyAsync(&error_index, - token_indices.data() + thrust::distance(tokens.begin(), error_location), - sizeof(SymbolOffsetT), - cudaMemcpyDeviceToHost, - stream.value())); - stream.synchronize(); - - constexpr SymbolOffsetT extension = 10u; - - // Warning: SymbolOffsetT is unsigned type thus we need to be careful with subtractions. - auto const begin_print_idx = - error_index > extension ? error_index - extension : SymbolOffsetT{0}; - auto const end_print_idx = - std::min(error_index + extension, static_cast(input_json.size())); - auto const print_size = end_print_idx - begin_print_idx; - auto const h_input_json = cudf::detail::make_host_vector_sync( - cudf::device_span{input_json.data() + begin_print_idx, print_size}, stream); - - std::cerr << "Substring in the range [" + std::to_string(begin_print_idx) + ", " + - std::to_string(end_print_idx) + "]" + " of the input (invalid) json:\n"; - std::cerr << std::string(h_input_json.data(), h_input_json.size()) << std::endl; - - CUDF_FAIL("JSON Parser encountered an invalid format at location " + - std::to_string(error_index)); + auto const default_mr = cudf::get_current_device_resource(); + auto [concatenated_buff, delimiter, should_be_nullified] = + concat_json(input, /*nullify_invalid_rows*/ true, stream, default_mr); + + if (concatenated_buff->size() == 0) { + return {std::move(*concatenated_buff), delimiter, std::move(should_be_nullified)}; } + + // Append the delimiter to the end of the concatenated buffer. 
+ // This is to fix a bug when the last string is invalid + // (https://github.com/rapidsai/cudf/issues/16999). + // The bug was fixed in libcudf's JSON reader by the same way like this. + auto unified_buff = rmm::device_buffer(concatenated_buff->size() + 1, stream, default_mr); + CUDF_CUDA_TRY(cudaMemcpyAsync(unified_buff.data(), + concatenated_buff->data(), + concatenated_buff->size(), + cudaMemcpyDefault, + stream)); + cudf::detail::cuda_memcpy_async( + cudf::device_span(static_cast(unified_buff.data()) + concatenated_buff->size(), + 1u), + cudf::host_span(&delimiter, 1, false), + stream); + + return {std::move(unified_buff), delimiter, std::move(should_be_nullified)}; } // Check if a token is a json node. struct is_node { - __host__ __device__ bool operator()(PdaTokenT const token) const + __host__ __device__ bool operator()(PdaTokenT token) const { switch (token) { case token_t::StructBegin: @@ -170,8 +123,8 @@ struct is_node { // The top json node (top json object level) has level 0. // Each row in the input column should have levels starting from 1. // This is copied from cudf's `json_tree.cu`. 
-rmm::device_uvector compute_node_levels(int64_t num_nodes, - rmm::device_uvector const& tokens, +rmm::device_uvector compute_node_levels(std::size_t num_nodes, + cudf::device_span tokens, rmm::cuda_stream_view stream) { auto token_levels = rmm::device_uvector(tokens.size(), stream); @@ -204,8 +157,10 @@ rmm::device_uvector compute_node_levels(int64_t num_nodes, [does_push, does_pop] __device__(PdaTokenT const token) -> cudf::size_type { return does_push(token) - does_pop(token); })); - thrust::exclusive_scan( - rmm::exec_policy(stream), push_pop_it, push_pop_it + tokens.size(), token_levels.begin()); + thrust::exclusive_scan(rmm::exec_policy_nosync(stream), + push_pop_it, + push_pop_it + tokens.size(), + token_levels.begin()); auto node_levels = rmm::device_uvector(num_nodes, stream); auto const copy_end = cudf::detail::copy_if_safe(token_levels.begin(), @@ -215,7 +170,7 @@ rmm::device_uvector compute_node_levels(int64_t num_nodes, is_node{}, stream); CUDF_EXPECTS(thrust::distance(node_levels.begin(), copy_end) == num_nodes, - "Node level count mismatch"); + "Node level count mismatch."); #ifdef DEBUG_FROM_JSON print_debug(node_levels, "Node levels", ", ", stream); @@ -225,7 +180,7 @@ rmm::device_uvector compute_node_levels(int64_t num_nodes, // Compute the map from nodes to their indices in the list of all tokens. 
rmm::device_uvector compute_node_to_token_index_map( - int64_t num_nodes, rmm::device_uvector const& tokens, rmm::cuda_stream_view stream) + std::size_t num_nodes, cudf::device_span tokens, rmm::cuda_stream_view stream) { auto node_token_ids = rmm::device_uvector(num_nodes, stream); auto const node_id_it = thrust::counting_iterator(0); @@ -236,7 +191,7 @@ rmm::device_uvector compute_node_to_token_index_map( is_node{}, stream); CUDF_EXPECTS(thrust::distance(node_token_ids.begin(), copy_end) == num_nodes, - "Invalid computation for node-to-token-index map"); + "Invalid computation for node-to-token-index map."); #ifdef DEBUG_FROM_JSON print_map_debug(node_token_ids, "Node-to-token-index map", stream); @@ -247,7 +202,7 @@ rmm::device_uvector compute_node_to_token_index_map( // This is copied from cudf's `json_tree.cu`. template std::pair, rmm::device_uvector> stable_sorted_key_order( - rmm::device_uvector const& keys, rmm::cuda_stream_view stream) + cudf::device_span keys, rmm::cuda_stream_view stream) { // Buffers used for storing intermediate results during sorting. rmm::device_uvector keys_buffer1(keys.size(), stream); @@ -257,8 +212,8 @@ std::pair, rmm::device_uvector> stable_s cub::DoubleBuffer keys_buffer(keys_buffer1.data(), keys_buffer2.data()); cub::DoubleBuffer order_buffer(order_buffer1.data(), order_buffer2.data()); - thrust::copy(rmm::exec_policy(stream), keys.begin(), keys.end(), keys_buffer1.begin()); - thrust::sequence(rmm::exec_policy(stream), order_buffer1.begin(), order_buffer1.end()); + thrust::copy(rmm::exec_policy_nosync(stream), keys.begin(), keys.end(), keys_buffer1.begin()); + thrust::sequence(rmm::exec_policy_nosync(stream), order_buffer1.begin(), order_buffer1.end()); size_t temp_storage_bytes = 0; cub::DeviceRadixSort::SortPairs( @@ -280,15 +235,15 @@ std::pair, rmm::device_uvector> stable_s } // This is copied from cudf's `json_tree.cu`. 
-void propagate_parent_to_siblings(rmm::device_uvector const& node_levels, - rmm::device_uvector& parent_node_ids, +void propagate_parent_to_siblings(cudf::device_span node_levels, + cudf::device_span parent_node_ids, rmm::cuda_stream_view stream) { auto const [sorted_node_levels, sorted_order] = stable_sorted_key_order(node_levels, stream); // Instead of gather, using permutation_iterator, which is ~17% faster. thrust::inclusive_scan_by_key( - rmm::exec_policy(stream), + rmm::exec_policy_nosync(stream), sorted_node_levels.begin(), sorted_node_levels.end(), thrust::make_permutation_iterator(parent_node_ids.begin(), sorted_order.begin()), @@ -299,13 +254,12 @@ void propagate_parent_to_siblings(rmm::device_uvector const& node_le // This is copied from cudf's `json_tree.cu`. rmm::device_uvector compute_parent_node_ids( - int64_t num_nodes, - rmm::device_uvector const& tokens, - rmm::device_uvector const& node_token_ids, + cudf::device_span tokens, + cudf::device_span node_token_ids, rmm::cuda_stream_view stream) { - auto const first_childs_parent_token_id = cuda::proclaim_return_type( - [tokens = tokens.begin()] __device__(auto i) -> NodeIndexT { + auto const first_childs_parent_token_id = + cuda::proclaim_return_type([tokens] __device__(auto i) -> NodeIndexT { if (i <= 0) { return -1; } if (tokens[i - 1] == token_t::StructBegin || tokens[i - 1] == token_t::ListBegin) { return i - 1; @@ -319,9 +273,10 @@ rmm::device_uvector compute_parent_node_ids( } }); + auto const num_nodes = node_token_ids.size(); auto parent_node_ids = rmm::device_uvector(num_nodes, stream); thrust::transform( - rmm::exec_policy(stream), + rmm::exec_policy_nosync(stream), node_token_ids.begin(), node_token_ids.end(), parent_node_ids.begin(), @@ -345,17 +300,18 @@ rmm::device_uvector compute_parent_node_ids( return parent_node_ids; } +// Special values to denote if a node is a key or value to extract for the output. 
constexpr int8_t key_sentinel{1}; constexpr int8_t value_sentinel{2}; // Check for each node if it is a key or a value field. rmm::device_uvector check_key_or_value_nodes( - rmm::device_uvector const& parent_node_ids, rmm::cuda_stream_view stream) + cudf::device_span parent_node_ids, rmm::cuda_stream_view stream) { auto key_or_value = rmm::device_uvector(parent_node_ids.size(), stream); auto const transform_it = thrust::counting_iterator(0); thrust::transform( - rmm::exec_policy(stream), + rmm::exec_policy_nosync(stream), transform_it, transform_it + parent_node_ids.size(), key_or_value.begin(), @@ -363,11 +319,11 @@ rmm::device_uvector check_key_or_value_nodes( [key_sentinel = key_sentinel, value_sentinel = value_sentinel, parent_ids = parent_node_ids.begin()] __device__(auto const node_id) -> int8_t { - if (parent_ids[node_id] > 0) { + if (parent_ids[node_id] >= 0) { auto const grand_parent = parent_ids[parent_ids[node_id]]; - if (grand_parent == 0) { + if (grand_parent < 0) { return key_sentinel; - } else if (parent_ids[grand_parent] == 0) { + } else if (parent_ids[grand_parent] < 0) { return value_sentinel; } } @@ -381,10 +337,10 @@ rmm::device_uvector check_key_or_value_nodes( return key_or_value; } -// Convert token indices to node ranges for each valid node. +// Convert token positions to node ranges for each valid node. 
struct node_ranges_fn { cudf::device_span tokens; - cudf::device_span token_indices; + cudf::device_span token_positions; cudf::device_span node_token_ids; cudf::device_span parent_node_ids; cudf::device_span key_or_value; @@ -431,7 +387,7 @@ struct node_ranges_fn { }; }); - auto const get_token_index = cuda::proclaim_return_type( + auto const get_token_position = cuda::proclaim_return_type( [include_quote_char = include_quote_char] __device__(PdaTokenT const token, SymbolOffsetT const token_index) { constexpr SymbolOffsetT quote_char_size = 1; @@ -456,18 +412,18 @@ struct node_ranges_fn { cudf_assert(is_begin_of_section(token) && "Invalid node category."); // The section from the original JSON input that this token demarcates. - auto const range_begin = get_token_index(token, token_indices[token_idx]); + auto const range_begin = get_token_position(token, token_positions[token_idx]); auto range_end = range_begin + 1; // non-leaf, non-field nodes ignore this value. if ((token_idx + 1) < tokens.size() && end_of_partner(token) == tokens[token_idx + 1]) { // Update the range_end for this pair of tokens - range_end = get_token_index(tokens[token_idx + 1], token_indices[token_idx + 1]); + range_end = get_token_position(tokens[token_idx + 1], token_positions[token_idx + 1]); } else { auto nested_range_value = nested_node_to_value(token); // iterate until this is zero auto end_idx = token_idx + 1; while (end_idx < tokens.size()) { nested_range_value += nested_node_to_value(tokens[end_idx]); if (nested_range_value == 0) { - range_end = get_token_index(tokens[end_idx], token_indices[end_idx]) + 1; + range_end = get_token_position(tokens[end_idx], token_positions[end_idx]) + 1; break; } ++end_idx; @@ -479,26 +435,26 @@ struct node_ranges_fn { } }; -// Compute index range for each node. +// Compute position range for each node. // These ranges identify positions to extract nodes from the unified json string. 
rmm::device_uvector> compute_node_ranges( - int64_t num_nodes, - rmm::device_uvector const& tokens, - rmm::device_uvector const& token_indices, - rmm::device_uvector const& node_token_ids, - rmm::device_uvector const& parent_node_ids, - rmm::device_uvector const& key_or_value, + cudf::device_span tokens, + cudf::device_span token_positions, + cudf::device_span node_token_ids, + cudf::device_span parent_node_ids, + cudf::device_span key_or_value, rmm::cuda_stream_view stream) { + auto const num_nodes = node_token_ids.size(); auto node_ranges = rmm::device_uvector>(num_nodes, stream); auto const transform_it = thrust::counting_iterator(0); thrust::transform( - rmm::exec_policy(stream), + rmm::exec_policy_nosync(stream), transform_it, transform_it + num_nodes, node_ranges.begin(), - node_ranges_fn{tokens, token_indices, node_token_ids, parent_node_ids, key_or_value}); + node_ranges_fn{tokens, token_positions, node_token_ids, parent_node_ids, key_or_value}); #ifdef DEBUG_FROM_JSON print_pair_debug(node_ranges, "Node ranges", stream); @@ -510,14 +466,14 @@ rmm::device_uvector> compute_node_ran // This both calculates the output size and executes the substring. // No bound check is performed, assuming that the substring bounds are all valid. struct substring_fn { - cudf::device_span const d_string; - cudf::device_span const> const d_ranges; + cudf::device_span d_string; + cudf::device_span const> d_ranges; cudf::size_type* d_sizes; char* d_chars; cudf::detail::input_offsetalator d_offsets; - __device__ void operator()(cudf::size_type const idx) + __device__ void operator()(cudf::size_type idx) { auto const range = d_ranges[idx]; auto const size = range.second - range.first; @@ -531,52 +487,40 @@ struct substring_fn { // Extract key-value string pairs from the input json string. 
std::unique_ptr extract_keys_or_values( - bool extract_key, - int64_t num_nodes, - rmm::device_uvector> const& node_ranges, - rmm::device_uvector const& key_or_value, - rmm::device_uvector const& unified_json_buff, + int8_t key_value_sentinel, + cudf::device_span const> node_ranges, + cudf::device_span key_or_value, + cudf::device_span input_json, rmm::cuda_stream_view stream, rmm::device_async_resource_ref mr) { - auto const is_key = cuda::proclaim_return_type( - [key_or_value = key_or_value.begin()] __device__(auto const node_id) { - return key_or_value[node_id] == key_sentinel; - }); - - auto const is_value = cuda::proclaim_return_type( - [key_or_value = key_or_value.begin()] __device__(auto const node_id) { - return key_or_value[node_id] == value_sentinel; + auto const is_key_or_value = cuda::proclaim_return_type( + [key_or_value, key_value_sentinel] __device__(auto const node_id) { + return key_or_value[node_id] == key_value_sentinel; }); - auto extract_ranges = - rmm::device_uvector>(num_nodes, stream, mr); - auto const stencil_it = thrust::make_counting_iterator(0); - auto const range_end = extract_key ? 
cudf::detail::copy_if_safe(node_ranges.begin(), - node_ranges.end(), - stencil_it, - extract_ranges.begin(), - is_key, - stream) - : cudf::detail::copy_if_safe(node_ranges.begin(), - node_ranges.end(), - stencil_it, - extract_ranges.begin(), - is_value, - stream); - auto const num_extract = thrust::distance(extract_ranges.begin(), range_end); + auto extracted_ranges = + rmm::device_uvector>(node_ranges.size(), stream, mr); + auto const range_end = cudf::detail::copy_if_safe(node_ranges.begin(), + node_ranges.end(), + thrust::make_counting_iterator(0), + extracted_ranges.begin(), + is_key_or_value, + stream); + auto const num_extract = thrust::distance(extracted_ranges.begin(), range_end); + if (num_extract == 0) { return cudf::make_empty_column(cudf::data_type{cudf::type_id::STRING}); } auto [offsets, chars] = cudf::strings::detail::make_strings_children( - substring_fn{unified_json_buff, extract_ranges}, num_extract, stream, mr); + substring_fn{input_json, extracted_ranges}, num_extract, stream, mr); return cudf::make_strings_column( num_extract, std::move(offsets), chars.release(), 0, rmm::device_buffer{}); } // Compute the offsets for the final lists of Struct. -rmm::device_uvector compute_list_offsets( +std::unique_ptr compute_list_offsets( cudf::size_type n_lists, - rmm::device_uvector const& parent_node_ids, - rmm::device_uvector const& key_or_value, + cudf::device_span parent_node_ids, + cudf::device_span key_or_value, rmm::cuda_stream_view stream, rmm::device_async_resource_ref mr) { @@ -584,15 +528,15 @@ rmm::device_uvector compute_list_offsets( // These object nodes are given as one row of the input json strings column. auto node_child_counts = rmm::device_uvector(parent_node_ids.size(), stream); - // For the nodes having parent_id == 0 (they are json object given by one input row), set their - // child counts to zero. Otherwise, set child counts to `-1` (a sentinel number). 
+ // For the nodes having parent_id < 0 (they are json object given by one input row), set their + // child counts to zero. Otherwise, set child counts to a negative sentinel number. thrust::transform( - rmm::exec_policy(stream), + rmm::exec_policy_nosync(stream), parent_node_ids.begin(), parent_node_ids.end(), node_child_counts.begin(), cuda::proclaim_return_type([] __device__(auto const parent_id) -> NodeIndexT { - return parent_id == 0 ? 0 : std::numeric_limits::lowest(); + return parent_id < 0 ? 0 : std::numeric_limits::lowest(); })); auto const is_key = cuda::proclaim_return_type( @@ -602,7 +546,7 @@ rmm::device_uvector compute_list_offsets( // Count the number of keys for each json object using `atomicAdd`. auto const transform_it = thrust::counting_iterator(0); - thrust::for_each(rmm::exec_policy(stream), + thrust::for_each(rmm::exec_policy_nosync(stream), transform_it, transform_it + parent_node_ids.size(), [is_key, @@ -630,73 +574,206 @@ rmm::device_uvector compute_list_offsets( print_debug(list_offsets, "Output list sizes (except the last one)", ", ", stream); #endif - thrust::exclusive_scan( - rmm::exec_policy(stream), list_offsets.begin(), list_offsets.end(), list_offsets.begin()); + thrust::exclusive_scan(rmm::exec_policy_nosync(stream), + list_offsets.begin(), + list_offsets.end(), + list_offsets.begin()); #ifdef DEBUG_FROM_JSON print_debug(list_offsets, "Output list offsets", ", ", stream); #endif - return list_offsets; + return std::make_unique(std::move(list_offsets), rmm::device_buffer{}, 0); +} + +// If a JSON line is invalid, the tokens corresponding to that line are output as +// [StructBegin, StructEnd] but their locations in the unified JSON string are all set to 0. 
+struct is_invalid_struct_begin { + cudf::device_span tokens; + cudf::device_span node_token_ids; + cudf::device_span token_positions; + + __device__ bool operator()(int node_idx) const + { + auto const node_token_id = node_token_ids[node_idx]; + auto const node_token = tokens[node_token_id]; + if (node_token != token_t::StructBegin) { return false; } + + // The next token in the token stream after node_token. + // Since the token stream has been post-processed, there should always be one more token. + auto const next_token = tokens[node_token_id + 1]; + if (next_token != token_t::StructEnd) { return false; } + + return token_positions[node_token_id] == 0 && token_positions[node_token_id + 1] == 0; + } +}; + +// A line begins with a StructBegin token which does not have a parent. +struct is_line_begin { + cudf::device_span tokens; + cudf::device_span node_token_ids; + cudf::device_span parent_node_ids; + + __device__ bool operator()(int node_idx) const + { + return tokens[node_token_ids[node_idx]] == token_t::StructBegin && + parent_node_ids[node_idx] < 0; + } +}; + +std::pair create_null_mask( + cudf::size_type num_rows, + std::unique_ptr const& should_be_nullified, + cudf::device_span tokens, + cudf::device_span token_positions, + cudf::device_span node_token_ids, + cudf::device_span parent_node_ids, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + auto const num_nodes = node_token_ids.size(); + + // To store indices of the StructBegin nodes that are detected as invalid JSON objects.
+ rmm::device_uvector invalid_indices(num_nodes, stream); + + auto const node_id_it = thrust::counting_iterator(0); + auto const invalid_copy_end = + cudf::detail::copy_if_safe(node_id_it, + node_id_it + node_token_ids.size(), + invalid_indices.begin(), + is_invalid_struct_begin{tokens, node_token_ids, token_positions}, + stream); + auto const num_invalid = thrust::distance(invalid_indices.begin(), invalid_copy_end); +#ifdef DEBUG_FROM_JSON + print_debug(invalid_indices, + "Invalid StructBegin nodes' indices (size = " + std::to_string(num_invalid) + ")", + ", ", + stream); +#endif + + // In addition to `should_be_nullified` which identified the null and empty rows, + // we also need to identify the rows containing invalid JSON objects. + if (num_invalid > 0) { + // Build a list of StructBegin tokens that start a line. + // We must have such list having size equal to the number of original input JSON strings. + rmm::device_uvector line_begin_indices(num_nodes, stream); + auto const line_begin_copy_end = + cudf::detail::copy_if_safe(node_id_it, + node_id_it + node_token_ids.size(), + line_begin_indices.begin(), + is_line_begin{tokens, node_token_ids, parent_node_ids}, + stream); + auto const num_line_begin = thrust::distance(line_begin_indices.begin(), line_begin_copy_end); + CUDF_EXPECTS(num_line_begin == num_rows, "Incorrect count of JSON objects."); +#ifdef DEBUG_FROM_JSON + print_debug(line_begin_indices, + "Line begin StructBegin indices (size = " + std::to_string(num_line_begin) + ")", + ", ", + stream); +#endif + + // Scatter the indices of the invalid StructBegin nodes into `should_be_nullified`. 
+ thrust::for_each(rmm::exec_policy_nosync(stream), + invalid_indices.begin(), + invalid_indices.begin() + num_invalid, + [should_be_nullified = should_be_nullified->mutable_view().begin(), + line_begin_indices = line_begin_indices.begin(), + num_rows] __device__(auto node_idx) { + auto const row_idx = thrust::lower_bound(thrust::seq, + line_begin_indices, + line_begin_indices + num_rows, + node_idx) - + line_begin_indices; + should_be_nullified[row_idx] = true; + }); + } + + auto const valid_it = should_be_nullified->view().begin(); + auto [null_mask, null_count] = cudf::detail::valid_if( + valid_it, valid_it + should_be_nullified->size(), thrust::logical_not{}, stream, mr); + return {null_count > 0 ? std::move(null_mask) : rmm::device_buffer{0, stream, mr}, null_count}; } } // namespace std::unique_ptr from_json_to_raw_map(cudf::strings_column_view const& input, + bool normalize_single_quotes, + bool allow_leading_zeros, + bool allow_nonnumeric_numbers, + bool allow_unquoted_control, rmm::cuda_stream_view stream, rmm::device_async_resource_ref mr) { - // Firstly, concatenate all the input json strings into one giant input json string. + if (input.is_empty()) { return make_empty_map(stream, mr); } + + // Firstly, concatenate all the input json strings into one buffer. // When testing/debugging, the output can be validated using // https://jsonformatter.curiousconcept.com. - auto const unified_json_buff = unify_json_strings(input, stream); + auto [concat_json_buff, delimiter, should_be_nullified] = unify_json_strings(input, stream); + auto concat_buff_wrapper = + cudf::io::datasource::owning_buffer(std::move(concat_json_buff)); + if (normalize_single_quotes) { + cudf::io::json::detail::normalize_single_quotes( + concat_buff_wrapper, delimiter, stream, cudf::get_current_device_resource_ref()); + } + auto const preprocessed_input = cudf::device_span( + reinterpret_cast(concat_buff_wrapper.data()), concat_buff_wrapper.size()); // Tokenize the input json strings. 
static_assert(sizeof(SymbolT) == sizeof(char), "Invalid internal data for nested json tokenizer."); - auto const [tokens, token_indices] = cudf::io::json::detail::get_token_stream( - cudf::device_span{unified_json_buff.data(), unified_json_buff.size()}, - cudf::io::json_reader_options{}, + auto const [tokens, token_positions] = cudf::io::json::detail::get_token_stream( + preprocessed_input, + cudf::io::json_reader_options_builder{} + .lines(true) + .normalize_whitespace(false) // don't need it + .experimental(true) + .mixed_types_as_string(true) + .recovery_mode(cudf::io::json_recovery_mode_t::RECOVER_WITH_NULL) + .strict_validation(true) + // specifying parameters + .delimiter(delimiter) + .numeric_leading_zeros(allow_leading_zeros) + .nonnumeric_numbers(allow_nonnumeric_numbers) + .unquoted_control_chars(allow_unquoted_control) + .build(), stream, - rmm::mr::get_current_device_resource()); + cudf::get_current_device_resource()); #ifdef DEBUG_FROM_JSON print_debug(tokens, "Tokens", ", ", stream); - print_debug(token_indices, "Token indices", ", ", stream); + print_debug(token_positions, "Token positions", ", ", stream); + std::cerr << "normalize_single_quotes: " << normalize_single_quotes << std::endl; + std::cerr << "allow_leading_zeros: " << allow_leading_zeros << std::endl; + std::cerr << "allow_nonnumeric_numbers: " << allow_nonnumeric_numbers << std::endl; + std::cerr << "allow_unquoted_control: " << allow_unquoted_control << std::endl; #endif - // Make sure there is no error during parsing. - throw_if_error(unified_json_buff, tokens, token_indices, stream); - auto const num_nodes = - thrust::count_if(rmm::exec_policy(stream), tokens.begin(), tokens.end(), is_node{}); + thrust::count_if(rmm::exec_policy_nosync(stream), tokens.begin(), tokens.end(), is_node{}); // Compute the map from nodes to their indices in the list of all tokens. 
auto const node_token_ids = compute_node_to_token_index_map(num_nodes, tokens, stream); // A map from each node to the index of its parent node. - auto const parent_node_ids = compute_parent_node_ids(num_nodes, tokens, node_token_ids, stream); + auto const parent_node_ids = compute_parent_node_ids(tokens, node_token_ids, stream); // Check for each node if it is a map key or a map value to extract. - auto const key_or_value_node = check_key_or_value_nodes(parent_node_ids, stream); + auto const is_key_or_value_node = check_key_or_value_nodes(parent_node_ids, stream); // Compute index range for each node. // These ranges identify positions to extract nodes from the unified json string. auto const node_ranges = compute_node_ranges( - num_nodes, tokens, token_indices, node_token_ids, parent_node_ids, key_or_value_node, stream); - - // - // From below are variables for returning output. - // + tokens, token_positions, node_token_ids, parent_node_ids, is_key_or_value_node, stream); auto extracted_keys = extract_keys_or_values( - true /*key*/, num_nodes, node_ranges, key_or_value_node, unified_json_buff, stream, mr); + key_sentinel, node_ranges, is_key_or_value_node, preprocessed_input, stream, mr); auto extracted_values = extract_keys_or_values( - false /*value*/, num_nodes, node_ranges, key_or_value_node, unified_json_buff, stream, mr); + value_sentinel, node_ranges, is_key_or_value_node, preprocessed_input, stream, mr); CUDF_EXPECTS(extracted_keys->size() == extracted_values->size(), "Invalid key-value pair extraction."); // Compute the offsets of the final output lists column. 
auto list_offsets = - compute_list_offsets(input.size(), parent_node_ids, key_or_value_node, stream, mr); + compute_list_offsets(input.size(), parent_node_ids, is_key_or_value_node, stream, mr); #ifdef DEBUG_FROM_JSON print_output_spark_map(list_offsets, extracted_keys, extracted_values, stream); @@ -709,15 +786,27 @@ std::unique_ptr from_json_to_raw_map(cudf::strings_column_view con auto structs_col = cudf::make_structs_column( num_pairs, std::move(out_keys_vals), 0, rmm::device_buffer{}, stream, mr); - auto offsets = std::make_unique(std::move(list_offsets), rmm::device_buffer{}, 0); - - return cudf::make_lists_column(input.size(), - std::move(offsets), - std::move(structs_col), - input.null_count(), - cudf::detail::copy_bitmask(input.parent(), stream, mr), - stream, - mr); + // Do not use `cudf::make_lists_column` since we do not need to call `purge_nonempty_nulls` + // on the children columns as they do not have non-empty nulls. + std::vector> list_children; + list_children.emplace_back(std::move(list_offsets)); + list_children.emplace_back(std::move(structs_col)); + + auto [null_mask, null_count] = create_null_mask(input.size(), + should_be_nullified, + tokens, + token_positions, + node_token_ids, + parent_node_ids, + stream, + mr); + + return std::make_unique(cudf::data_type{cudf::type_id::LIST}, + input.size(), + rmm::device_buffer{}, + std::move(null_mask), + null_count, + std::move(list_children)); } } // namespace spark_rapids_jni diff --git a/src/main/cpp/src/from_json_to_raw_map_debug.cuh b/src/main/cpp/src/from_json_to_raw_map_debug.cuh index 43584d9fd5..481002c369 100644 --- a/src/main/cpp/src/from_json_to_raw_map_debug.cuh +++ b/src/main/cpp/src/from_json_to_raw_map_debug.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -108,20 +108,25 @@ void print_pair_debug(rmm::device_uvector const& input, } // Print the final output map data (Spark's MapType, i.e., List>). -void print_output_spark_map(rmm::device_uvector const& list_offsets, +void print_output_spark_map(std::unique_ptr const& list_offsets, std::unique_ptr const& extracted_keys, std::unique_ptr const& extracted_values, rmm::cuda_stream_view stream) { - auto const keys_child = extracted_keys->child(cudf::strings_column_view::chars_column_index); + if (extracted_keys->size() == 0) { + std::cerr << "Extract keys-values are all empty.\n" << std::endl; + return; + } + + auto const keys_sv = cudf::strings_column_view{extracted_keys->view()}; + auto const values_sv = cudf::strings_column_view{extracted_values->view()}; auto const keys_offsets = extracted_keys->child(cudf::strings_column_view::offsets_column_index); - auto const values_child = extracted_values->child(cudf::strings_column_view::chars_column_index); auto const values_offsets = extracted_values->child(cudf::strings_column_view::offsets_column_index); auto const h_extracted_keys_child = cudf::detail::make_host_vector_sync( - cudf::device_span{keys_child.view().data(), - static_cast(keys_child.size())}, + cudf::device_span{keys_sv.chars_begin(stream), + static_cast(keys_sv.chars_size(stream))}, stream); auto const h_extracted_keys_offsets = cudf::detail::make_host_vector_sync( cudf::device_span{keys_offsets.view().data(), @@ -129,8 +134,8 @@ void print_output_spark_map(rmm::device_uvector const& list_off stream); auto const h_extracted_values_child = cudf::detail::make_host_vector_sync( - cudf::device_span{values_child.view().data(), - static_cast(values_child.size())}, + cudf::device_span{values_sv.chars_begin(stream), + static_cast(values_sv.chars_size(stream))}, stream); auto const h_extracted_values_offsets = cudf::detail::make_host_vector_sync( cudf::device_span{values_offsets.view().data(), @@ -138,7 +143,9 @@ void print_output_spark_map(rmm::device_uvector const& 
list_off stream); auto const h_list_offsets = cudf::detail::make_host_vector_sync( - cudf::device_span{list_offsets.data(), list_offsets.size()}, stream); + cudf::device_span{list_offsets->view().begin(), + static_cast(list_offsets->size())}, + stream); CUDF_EXPECTS(h_list_offsets.back() == extracted_keys->size(), "Invalid list offsets computation."); diff --git a/src/main/cpp/src/from_json_to_structs.cu b/src/main/cpp/src/from_json_to_structs.cu new file mode 100644 index 0000000000..ddfdcc4c4f --- /dev/null +++ b/src/main/cpp/src/from_json_to_structs.cu @@ -0,0 +1,1083 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "cast_string.hpp" +#include "json_utils.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace spark_rapids_jni { + +namespace detail { + +namespace { + +/** + * @brief The struct similar to `cudf::io::schema_element` with adding decimal precision and + * preserving column order. 
+ */ +struct schema_element_with_precision { + cudf::data_type type; + int precision; + std::vector> child_types; +}; + +std::pair parse_schema_element( + std::size_t& index, + std::vector const& col_names, + std::vector const& num_children, + std::vector const& types, + std::vector const& scales, + std::vector const& precisions) +{ + // Get data for the current column. + auto const d_type = cudf::data_type{static_cast(types[index]), scales[index]}; + auto const precision = precisions[index]; + auto const col_num_children = num_children[index]; + index++; + + std::map children; + std::vector> children_with_precisions; + std::vector child_names(col_num_children); + + if (d_type.id() == cudf::type_id::STRUCT || d_type.id() == cudf::type_id::LIST) { + for (int i = 0; i < col_num_children; ++i) { + auto const& name = col_names[index]; + auto [child, child_with_precision] = + parse_schema_element(index, col_names, num_children, types, scales, precisions); + children.emplace(name, std::move(child)); + children_with_precisions.emplace_back(name, std::move(child_with_precision)); + child_names[i] = name; + } + } else { + CUDF_EXPECTS(col_num_children == 0, + "Found children for a non-nested type that should have none.", + std::invalid_argument); + } + + // Note that if the first schema element does not has type STRUCT/LIST then it always has type + // STRING, since we intentionally parse JSON into strings column for later post-processing. + auto const schema_dtype = + d_type.id() == cudf::type_id::STRUCT || d_type.id() == cudf::type_id::LIST + ? d_type + : cudf::data_type{cudf::type_id::STRING}; + return {cudf::io::schema_element{schema_dtype, std::move(children), {std::move(child_names)}}, + schema_element_with_precision{d_type, precision, std::move(children_with_precisions)}}; +} + +// Generate struct type schemas by traveling the schema data by depth-first search order. 
+// Two separate schemas is generated: +// - The first one is used as input to `cudf::read_json`, in which the data types of all columns +// are specified as STRING type. As such, the table returned by `cudf::read_json` will contain +// only strings columns or nested (LIST/STRUCT) columns. +// - The second schema contains decimal precision (if available) and preserves schema column types +// as well as the column order, used for converting from STRING type to the desired types for the +// final output. +std::pair generate_struct_schema( + std::vector const& col_names, + std::vector const& num_children, + std::vector const& types, + std::vector const& scales, + std::vector const& precisions) +{ + std::map schema_cols; + std::vector> schema_cols_with_precisions; + std::vector name_order; + + std::size_t index = 0; + while (index < types.size()) { + auto const& name = col_names[index]; + auto [child, child_with_precision] = + parse_schema_element(index, col_names, num_children, types, scales, precisions); + schema_cols.emplace(name, std::move(child)); + schema_cols_with_precisions.emplace_back(name, std::move(child_with_precision)); + name_order.push_back(name); + } + return { + cudf::io::schema_element{ + cudf::data_type{cudf::type_id::STRUCT}, std::move(schema_cols), {std::move(name_order)}}, + schema_element_with_precision{ + cudf::data_type{cudf::type_id::STRUCT}, -1, std::move(schema_cols_with_precisions)}}; +} + +using string_index_pair = thrust::pair; + +std::unique_ptr cast_strings_to_booleans(cudf::column_view const& input, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + CUDF_FUNC_RANGE(); + + auto const string_count = input.size(); + if (string_count == 0) { return cudf::make_empty_column(cudf::data_type{cudf::type_id::BOOL8}); } + + auto output = cudf::make_fixed_width_column( + cudf::data_type{cudf::type_id::BOOL8}, string_count, cudf::mask_state::UNALLOCATED, stream, mr); + auto validity = rmm::device_uvector(string_count, 
stream); + + auto const input_sv = cudf::strings_column_view{input}; + auto const offsets_it = + cudf::detail::offsetalator_factory::make_input_iterator(input_sv.offsets()); + auto const d_input_ptr = cudf::column_device_view::create(input, stream); + auto const is_valid_it = cudf::detail::make_validity_iterator(*d_input_ptr); + auto const output_it = thrust::make_zip_iterator( + thrust::make_tuple(output->mutable_view().begin(), validity.begin())); + thrust::tabulate( + rmm::exec_policy_nosync(stream), + output_it, + output_it + string_count, + [chars = input_sv.chars_begin(stream), offsets = offsets_it, is_valid = is_valid_it] __device__( + auto idx) -> thrust::tuple { + if (is_valid[idx]) { + auto const start_offset = offsets[idx]; + auto const end_offset = offsets[idx + 1]; + auto const size = end_offset - start_offset; + auto const str = chars + start_offset; + + if (size == 4 && str[0] == 't' && str[1] == 'r' && str[2] == 'u' && str[3] == 'e') { + return {true, true}; + } + if (size == 5 && str[0] == 'f' && str[1] == 'a' && str[2] == 'l' && str[3] == 's' && + str[4] == 'e') { + return {false, true}; + } + } + + // Either null input, or the input string is neither `true` nor `false`. + return {false, false}; + }); + + auto [null_mask, null_count] = + cudf::detail::valid_if(validity.begin(), validity.end(), thrust::identity{}, stream, mr); + output->set_null_mask(null_count > 0 ? 
std::move(null_mask) : rmm::device_buffer{0, stream, mr}, + null_count); + + return output; +} + +std::unique_ptr cast_strings_to_integers(cudf::column_view const& input, + cudf::data_type output_type, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + CUDF_FUNC_RANGE(); + + auto const string_count = input.size(); + if (string_count == 0) { return cudf::make_empty_column(output_type); } + + auto const input_sv = cudf::strings_column_view{input}; + auto const input_offsets_it = + cudf::detail::offsetalator_factory::make_input_iterator(input_sv.offsets()); + auto const d_input_ptr = cudf::column_device_view::create(input, stream); + auto const valid_input_it = cudf::detail::make_validity_iterator(*d_input_ptr); + + // We need to nullify the invalid string rows. + // Technically, we should just mask out these rows as nulls through the nullmask. + // These masked out non-empty nulls will be handled in the conversion API. + auto valids = rmm::device_uvector(string_count, stream); + + // Since the strings store integer numbers, they should be very short. + // As such, using one thread per string should be fine. + thrust::tabulate(rmm::exec_policy_nosync(stream), + valids.begin(), + valids.end(), + [chars = input_sv.chars_begin(stream), + offsets = input_offsets_it, + valid_input = valid_input_it] __device__(cudf::size_type idx) -> bool { + if (!valid_input[idx]) { return false; } + + auto in_ptr = chars + offsets[idx]; + auto const in_end = chars + offsets[idx + 1]; + while (in_ptr != in_end) { + if (*in_ptr == '.' || *in_ptr == 'e' || *in_ptr == 'E') { return false; } + ++in_ptr; + } + + return true; + }); + + auto const [null_mask, null_count] = + cudf::detail::valid_if(valids.begin(), + valids.end(), + thrust::identity{}, + stream, + cudf::get_current_device_resource_ref()); + // If the null count doesn't change, just use the input column for conversion. + auto const input_applied_null = + null_count == input.null_count() + ? 
cudf::column_view{} + : cudf::column_view{cudf::data_type{cudf::type_id::STRING}, + input_sv.size(), + input_sv.chars_begin(stream), + reinterpret_cast(null_mask.data()), + null_count, + input_sv.offset(), + std::vector{input_sv.offsets()}}; + + return spark_rapids_jni::string_to_integer( + output_type, + null_count == input.null_count() ? input_sv : cudf::strings_column_view{input_applied_null}, + /*ansi_mode*/ false, + /*strip*/ false, + stream, + mr); +} + +std::pair, bool> try_remove_quotes_for_floats( + cudf::column_view const& input, rmm::cuda_stream_view stream, rmm::device_async_resource_ref mr) +{ + CUDF_FUNC_RANGE(); + + auto const string_count = input.size(); + if (string_count == 0) { return {nullptr, false}; } + + auto const input_sv = cudf::strings_column_view{input}; + auto const input_offsets_it = + cudf::detail::offsetalator_factory::make_input_iterator(input_sv.offsets()); + auto const d_input_ptr = cudf::column_device_view::create(input, stream); + auto const is_valid_it = cudf::detail::make_validity_iterator(*d_input_ptr); + + auto string_pairs = rmm::device_uvector(string_count, stream); + thrust::tabulate(rmm::exec_policy_nosync(stream), + string_pairs.begin(), + string_pairs.end(), + [chars = input_sv.chars_begin(stream), + offsets = input_offsets_it, + is_valid = is_valid_it] __device__(cudf::size_type idx) -> string_index_pair { + if (!is_valid[idx]) { return {nullptr, 0}; } + + auto const start_offset = offsets[idx]; + auto const end_offset = offsets[idx + 1]; + auto const size = end_offset - start_offset; + auto const str = chars + start_offset; + + // Need to check for size, since the input string may contain just a single + // character `"`. Such input should not be considered as quoted. + auto const is_quoted = size > 1 && str[0] == '"' && str[size - 1] == '"'; + + // We check and remove quotes only for the special cases (non-numeric numbers + // wrapped in double quotes) that are accepted in `from_json`. 
+ // They are "NaN", "+INF", "-INF", "+Infinity", "Infinity", "-Infinity". + if (is_quoted) { + // "NaN" + auto accepted = size == 5 && str[1] == 'N' && str[2] == 'a' && str[3] == 'N'; + + // "+INF" and "-INF" + accepted = accepted || (size == 6 && (str[1] == '+' || str[1] == '-') && + str[2] == 'I' && str[3] == 'N' && str[4] == 'F'); + + // "Infinity" + accepted = accepted || (size == 10 && str[1] == 'I' && str[2] == 'n' && + str[3] == 'f' && str[4] == 'i' && str[5] == 'n' && + str[6] == 'i' && str[7] == 't' && str[8] == 'y'); + + // "+Infinity" and "-Infinity" + accepted = accepted || (size == 11 && (str[1] == '+' || str[1] == '-') && + str[2] == 'I' && str[3] == 'n' && str[4] == 'f' && + str[5] == 'i' && str[6] == 'n' && str[7] == 'i' && + str[8] == 't' && str[9] == 'y'); + + if (accepted) { return {str + 1, size - 2}; } + } + + return {str, size}; + }); + + auto const size_it = cudf::detail::make_counting_transform_iterator( + 0, + cuda::proclaim_return_type( + [string_pairs = string_pairs.begin()] __device__(cudf::size_type idx) -> cudf::size_type { + return string_pairs[idx].second; + })); + auto [offsets_column, bytes] = + cudf::strings::detail::make_offsets_child_column(size_it, size_it + string_count, stream, mr); + + // If the output has the same total bytes, the output should be the same as the input. 
+ if (bytes == input_sv.chars_size(stream)) { return {nullptr, false}; } + + auto chars_data = cudf::strings::detail::make_chars_buffer( + offsets_column->view(), bytes, string_pairs.begin(), string_count, stream, mr); + + return {cudf::make_strings_column(string_count, + std::move(offsets_column), + chars_data.release(), + input.null_count(), + cudf::detail::copy_bitmask(input, stream, mr)), + true}; +} + +std::unique_ptr cast_strings_to_floats(cudf::column_view const& input, + cudf::data_type output_type, + bool allow_nonnumeric_numbers, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + CUDF_FUNC_RANGE(); + + auto const string_count = input.size(); + if (string_count == 0) { return cudf::make_empty_column(output_type); } + + if (allow_nonnumeric_numbers) { + // Non-numeric numbers are always quoted. + auto const [removed_quotes, success] = try_remove_quotes_for_floats(input, stream, mr); + return spark_rapids_jni::string_to_float( + output_type, + cudf::strings_column_view{success ? removed_quotes->view() : input}, + /*ansi_mode*/ false, + stream, + mr); + } + return spark_rapids_jni::string_to_float( + output_type, cudf::strings_column_view{input}, /*ansi_mode*/ false, stream, mr); +} + +// TODO there is a bug here around 0 https://github.com/NVIDIA/spark-rapids/issues/10898 +std::unique_ptr cast_strings_to_decimals(cudf::column_view const& input, + cudf::data_type output_type, + int precision, + bool is_us_locale, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + CUDF_FUNC_RANGE(); + + auto const string_count = input.size(); + if (string_count == 0) { return cudf::make_empty_column(output_type); } + + CUDF_EXPECTS(is_us_locale, "String to decimal conversion is only supported in US locale."); + + auto const input_sv = cudf::strings_column_view{input}; + auto const in_offsets = + cudf::detail::offsetalator_factory::make_input_iterator(input_sv.offsets()); + + // Count the number of characters `"`. 
+ rmm::device_uvector quote_counts(string_count, stream); + // Count the number of characters `"` and `,` in each string. + rmm::device_uvector remove_counts(string_count, stream); + + { + using count_type = thrust::tuple; + auto const check_it = cudf::detail::make_counting_transform_iterator( + 0, + cuda::proclaim_return_type( + [chars = input_sv.chars_begin(stream)] __device__(auto idx) { + auto const c = chars[idx]; + auto const is_quote = c == '"'; + auto const should_remove = is_quote || c == ','; + return count_type{static_cast(is_quote), static_cast(should_remove)}; + })); + auto const plus_op = + cuda::proclaim_return_type([] __device__(count_type lhs, count_type rhs) { + return count_type{thrust::get<0>(lhs) + thrust::get<0>(rhs), + thrust::get<1>(lhs) + thrust::get<1>(rhs)}; + }); + + auto const out_count_it = + thrust::make_zip_iterator(quote_counts.begin(), remove_counts.begin()); + + std::size_t temp_storage_bytes = 0; + cub::DeviceSegmentedReduce::Reduce(nullptr, + temp_storage_bytes, + check_it, + out_count_it, + string_count, + in_offsets, + in_offsets + 1, + plus_op, + count_type{0, 0}, + stream.value()); + auto d_temp_storage = rmm::device_buffer{temp_storage_bytes, stream}; + cub::DeviceSegmentedReduce::Reduce(d_temp_storage.data(), + temp_storage_bytes, + check_it, + out_count_it, + string_count, + in_offsets, + in_offsets + 1, + plus_op, + count_type{0, 0}, + stream.value()); + } + + auto const out_size_it = cudf::detail::make_counting_transform_iterator( + 0, + cuda::proclaim_return_type( + [offsets = in_offsets, + quote_counts = quote_counts.begin(), + remove_counts = remove_counts.begin()] __device__(auto idx) { + auto const input_size = offsets[idx + 1] - offsets[idx]; + // If the current row is non-quoted, just return the original string. + // As such, non-quoted string containing `,` character will not be preprocessed. 
+ if (quote_counts[idx] == 0) { return static_cast(input_size); } + + // For quoted strings, we will modify them, removing characters '"' and ','. + return static_cast(input_size - remove_counts[idx]); + })); + auto [offsets_column, bytes] = cudf::strings::detail::make_offsets_child_column( + out_size_it, out_size_it + string_count, stream, mr); + + // If the output strings column does not change in its total bytes, we can use the input directly. + if (bytes == input_sv.chars_size(stream)) { + return spark_rapids_jni::string_to_decimal(precision, + output_type.scale(), + input_sv, + /*ansi_mode*/ false, + /*strip*/ false, + stream, + mr); + } + + auto const out_offsets = + cudf::detail::offsetalator_factory::make_input_iterator(offsets_column->view()); + auto chars_data = rmm::device_uvector(bytes, stream, mr); + + // Since the strings store decimal numbers, they should not be very long. + // As such, using one thread per string should be fine. + thrust::for_each(rmm::exec_policy_nosync(stream), + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(string_count), + [in_offsets, + out_offsets, + input = input_sv.chars_begin(stream), + output = chars_data.begin()] __device__(auto idx) { + auto const in_size = in_offsets[idx + 1] - in_offsets[idx]; + auto const out_size = out_offsets[idx + 1] - out_offsets[idx]; + if (in_size == 0) { return; } + + // If the output size is not changed, we are returning the original unquoted + // string. Such string may still contain other alphabet characters, but that + // should be handled in the conversion function later on. + if (in_size == out_size) { + memcpy(output + out_offsets[idx], input + in_offsets[idx], in_size); + } else { // copy byte by byte, ignoring '"' and ',' characters. 
+ auto in_ptr = input + in_offsets[idx]; + auto in_end = input + in_offsets[idx + 1]; + auto out_ptr = output + out_offsets[idx]; + while (in_ptr != in_end) { + if (*in_ptr != '"' && *in_ptr != ',') { + *out_ptr = *in_ptr; + ++out_ptr; + } + ++in_ptr; + } + } + }); + + // Don't care about the null mask, as nulls imply empty strings, which will also result in nulls. + auto const unquoted_strings = + cudf::make_strings_column(string_count, std::move(offsets_column), chars_data.release(), 0, {}); + + return spark_rapids_jni::string_to_decimal(precision, + output_type.scale(), + cudf::strings_column_view{unquoted_strings->view()}, + /*ansi_mode*/ false, + /*strip*/ false, + stream, + mr); +} + +std::pair, bool> try_remove_quotes( + cudf::strings_column_view const& input, + bool nullify_if_not_quoted, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + CUDF_FUNC_RANGE(); + + auto const string_count = input.size(); + if (string_count == 0) { return {nullptr, false}; } + + auto const input_offsets_it = + cudf::detail::offsetalator_factory::make_input_iterator(input.offsets()); + auto const d_input_ptr = cudf::column_device_view::create(input.parent(), stream); + auto const is_valid_it = cudf::detail::make_validity_iterator(*d_input_ptr); + + auto string_pairs = rmm::device_uvector(string_count, stream); + thrust::tabulate(rmm::exec_policy_nosync(stream), + string_pairs.begin(), + string_pairs.end(), + [nullify_if_not_quoted, + chars = input.chars_begin(stream), + offsets = input_offsets_it, + is_valid = is_valid_it] __device__(cudf::size_type idx) -> string_index_pair { + if (!is_valid[idx]) { return {nullptr, 0}; } + + auto const start_offset = offsets[idx]; + auto const end_offset = offsets[idx + 1]; + auto const size = end_offset - start_offset; + auto const str = chars + start_offset; + + // Need to check for size, since the input string may contain just a single + // character `"`. Such input should not be considered as quoted. 
+ auto const is_quoted = size > 1 && str[0] == '"' && str[size - 1] == '"'; + if (nullify_if_not_quoted && !is_quoted) { return {nullptr, 0}; } + + if (is_quoted) { return {chars + start_offset + 1, size - 2}; } + return {chars + start_offset, size}; + }); + + auto const size_it = cudf::detail::make_counting_transform_iterator( + 0, + cuda::proclaim_return_type( + [string_pairs = string_pairs.begin()] __device__(cudf::size_type idx) -> cudf::size_type { + return string_pairs[idx].second; + })); + auto [offsets_column, bytes] = + cudf::strings::detail::make_offsets_child_column(size_it, size_it + string_count, stream, mr); + + // If the output has the same total bytes, the output should be the same as the input. + if (bytes == input.chars_size(stream)) { return {nullptr, false}; } + + auto chars_data = cudf::strings::detail::make_chars_buffer( + offsets_column->view(), bytes, string_pairs.begin(), string_count, stream, mr); + + if (nullify_if_not_quoted) { + auto output = cudf::make_strings_column(string_count, + std::move(offsets_column), + chars_data.release(), + 0, + rmm::device_buffer{0, stream, mr}); + + auto [null_mask, null_count] = cudf::detail::valid_if( + string_pairs.begin(), + string_pairs.end(), + [] __device__(string_index_pair const& pair) { return pair.first != nullptr; }, + stream, + mr); + if (null_count > 0) { output->set_null_mask(std::move(null_mask), null_count); } + + return {std::move(output), true}; + } + + return {cudf::make_strings_column(string_count, + std::move(offsets_column), + chars_data.release(), + input.null_count(), + cudf::detail::copy_bitmask(input.parent(), stream, mr)), + true}; +} + +// Copied and modified from `cudf/cpp/src/io/json/parser_features.cpp`. 
+struct empty_column_functor { + rmm::cuda_stream_view stream; + rmm::device_async_resource_ref mr; + + template ())> + std::unique_ptr operator()(schema_element_with_precision const& schema) const + { + return cudf::make_empty_column(schema.type); + } + + template )> + std::unique_ptr operator()(schema_element_with_precision const& schema) const + { + CUDF_EXPECTS(schema.child_types.size() == 1, "Lists column should have only one child"); + auto offsets = cudf::make_empty_column(cudf::data_type(cudf::type_to_id())); + auto child = cudf::type_dispatcher( + schema.child_types.front().second.type, *this, schema.child_types.front().second); + return cudf::make_lists_column(0, std::move(offsets), std::move(child), 0, {}, stream, mr); + } + + template )> + std::unique_ptr operator()(schema_element_with_precision const& schema) const + { + std::vector> children; + for (auto const& [child_name, child_schema] : schema.child_types) { + children.emplace_back(cudf::type_dispatcher(child_schema.type, *this, child_schema)); + } + return cudf::make_structs_column(0, std::move(children), 0, {}, stream, mr); + } +}; + +// Copied and modified from `cudf/cpp/src/io/json/parser_features.cpp`. +struct allnull_column_functor { + rmm::cuda_stream_view stream; + rmm::device_async_resource_ref mr; + + private: + auto make_zeroed_offsets(cudf::size_type size) const + { + auto offsets_buff = + cudf::detail::make_zeroed_device_uvector_async(size + 1, stream, mr); + return std::make_unique(std::move(offsets_buff), rmm::device_buffer{}, 0); + } + + public: + template () && !std::is_same_v && + !std::is_same_v && + !std::is_same_v)> + std::unique_ptr operator()(Args...) 
const + { + CUDF_FAIL("Invalid type."); + } + + template ())> + std::unique_ptr operator()(schema_element_with_precision const& schema, + cudf::size_type size) const + { + return cudf::make_fixed_width_column(schema.type, size, cudf::mask_state::ALL_NULL, stream, mr); + } + + template )> + std::unique_ptr operator()(schema_element_with_precision const&, + cudf::size_type size) const + { + auto offsets = make_zeroed_offsets(size); + auto null_mask = cudf::detail::create_null_mask(size, cudf::mask_state::ALL_NULL, stream, mr); + return cudf::make_strings_column( + size, std::move(offsets), rmm::device_buffer{}, size, std::move(null_mask)); + } + + template )> + std::unique_ptr operator()(schema_element_with_precision const& schema, + cudf::size_type size) const + { + CUDF_EXPECTS(schema.child_types.size() == 1, "Lists column should have only one child"); + std::vector> children; + children.emplace_back(make_zeroed_offsets(size)); + children.emplace_back(cudf::type_dispatcher(schema.child_types.front().second.type, + empty_column_functor{stream, mr}, + schema.child_types.front().second)); + auto null_mask = cudf::detail::create_null_mask(size, cudf::mask_state::ALL_NULL, stream, mr); + // Do not use `cudf::make_lists_column` since we do not need to call `purge_nonempty_nulls` + // on the child column as it does not have non-empty nulls. 
+ return std::make_unique(cudf::data_type{cudf::type_id::LIST}, + size, + rmm::device_buffer{}, + std::move(null_mask), + size, + std::move(children)); + } + + template )> + std::unique_ptr operator()(schema_element_with_precision const& schema, + cudf::size_type size) const + { + std::vector> children; + children.reserve(schema.child_types.size()); + for (auto const& [child_name, child_schema] : schema.child_types) { + children.emplace_back(cudf::type_dispatcher(child_schema.type, *this, child_schema, size)); + } + auto null_mask = cudf::detail::create_null_mask(size, cudf::mask_state::ALL_NULL, stream, mr); + // Do not use `cudf::make_structs_column` since we do not need to call `superimpose_nulls` + // on the children columns. + return std::make_unique(cudf::data_type{cudf::type_id::STRUCT}, + size, + rmm::device_buffer{}, + std::move(null_mask), + size, + std::move(children)); + } +}; + +// This is a workaround for https://github.com/rapidsai/cudf/issues/17167. +// When the issue is fixed, we should remove this utility and adopt it. 
+std::unique_ptr make_all_nulls_column(schema_element_with_precision const& schema, + cudf::size_type num_rows, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + return cudf::type_dispatcher(schema.type, allnull_column_functor{stream, mr}, schema, num_rows); +} + +template +std::unique_ptr convert_data_type(InputType&& input, + schema_element_with_precision const& schema, + bool allow_nonnumeric_numbers, + bool is_us_locale, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + CUDF_FUNC_RANGE(); + + using DecayInputT = std::decay_t; + auto constexpr input_is_const_cv = std::is_same_v; + auto constexpr input_is_column_ptr = std::is_same_v>; + static_assert(input_is_const_cv ^ input_is_column_ptr, + "Input to `convert_data_type` must either be `cudf::column_view const&` or " + "`std::unique_ptr`"); + + auto const [d_type, num_rows] = [&]() -> std::pair { + if constexpr (input_is_column_ptr) { + return {input->type().id(), input->size()}; + } else { + return {input.type().id(), input.size()}; + } + }(); + + if (d_type == cudf::type_id::STRING) { + if (cudf::is_chrono(schema.type)) { + // Date/time is not processed here - it should be handled separately in spark-rapids. 
+ if constexpr (input_is_column_ptr) { + return std::move(input); + } else { + CUDF_FAIL("Cannot convert data type to a chrono (date/time) type."); + return nullptr; + } + } + + if (schema.type.id() == cudf::type_id::BOOL8) { + if constexpr (input_is_column_ptr) { + return cast_strings_to_booleans(input->view(), stream, mr); + } else { + return cast_strings_to_booleans(input, stream, mr); + } + } + + if (cudf::is_integral(schema.type)) { + if constexpr (input_is_column_ptr) { + return cast_strings_to_integers(input->view(), schema.type, stream, mr); + } else { + return cast_strings_to_integers(input, schema.type, stream, mr); + } + } + + if (cudf::is_floating_point(schema.type)) { + if constexpr (input_is_column_ptr) { + return cast_strings_to_floats( + input->view(), schema.type, allow_nonnumeric_numbers, stream, mr); + } else { + return cast_strings_to_floats(input, schema.type, allow_nonnumeric_numbers, stream, mr); + } + } + + if (cudf::is_fixed_point(schema.type)) { + if constexpr (input_is_column_ptr) { + return cast_strings_to_decimals( + input->view(), schema.type, schema.precision, is_us_locale, stream, mr); + } else { + return cast_strings_to_decimals( + input, schema.type, schema.precision, is_us_locale, stream, mr); + } + } + + if (schema.type.id() == cudf::type_id::STRING) { + if constexpr (input_is_column_ptr) { + auto [removed_quotes, success] = + try_remove_quotes(input->view(), /*nullify_if_not_quoted*/ false, stream, mr); + return std::move(success ? removed_quotes : input); + } else { + auto [removed_quotes, success] = + try_remove_quotes(input, /*nullify_if_not_quoted*/ false, stream, mr); + return success ? std::move(removed_quotes) + : std::make_unique(input, stream, mr); + } + } + + CUDF_FAIL("Unexpected column type for conversion."); + return nullptr; + } // d_type == cudf::type_id::STRING + + // From here, the input column should have type either LIST or STRUCT. + + // Handle mismatched schema. 
+ if (schema.type.id() != d_type) { return make_all_nulls_column(schema, num_rows, stream, mr); } + + if constexpr (input_is_column_ptr) { + auto const null_count = input->null_count(); + auto const num_children = input->num_children(); + auto input_content = input->release(); + + if (schema.type.id() == cudf::type_id::LIST) { + auto const& child_schema = schema.child_types.front().second; + auto& child = input_content.children[cudf::lists_column_view::child_column_index]; + + // Handle mismatched child schema. + if (cudf::is_nested(child_schema.type) && (child_schema.type.id() != child->type().id())) { + return make_all_nulls_column(schema, num_rows, stream, mr); + } + + std::vector> new_children; + new_children.emplace_back( + std::move(input_content.children[cudf::lists_column_view::offsets_column_index])); + new_children.emplace_back(convert_data_type( + std::move(child), child_schema, allow_nonnumeric_numbers, is_us_locale, stream, mr)); + + // Do not use `cudf::make_lists_column` since we do not need to call `purge_nonempty_nulls` + // on the child column as it does not have non-empty nulls. + return std::make_unique(cudf::data_type{cudf::type_id::LIST}, + num_rows, + rmm::device_buffer{}, + std::move(*input_content.null_mask), + null_count, + std::move(new_children)); + } + + if (schema.type.id() == cudf::type_id::STRUCT) { + std::vector> new_children; + new_children.reserve(num_children); + for (cudf::size_type i = 0; i < num_children; ++i) { + new_children.emplace_back(convert_data_type(std::move(input_content.children[i]), + schema.child_types[i].second, + allow_nonnumeric_numbers, + is_us_locale, + stream, + mr)); + } + + // Do not use `cudf::make_structs_column` since we do not need to call `superimpose_nulls` + // on the children columns. 
+ return std::make_unique(cudf::data_type{cudf::type_id::STRUCT}, + num_rows, + rmm::device_buffer{}, + std::move(*input_content.null_mask), + null_count, + std::move(new_children)); + } + } else { // input_is_const_cv + auto const null_count = input.null_count(); + auto const num_children = input.num_children(); + + if (schema.type.id() == cudf::type_id::LIST) { + auto const& child_schema = schema.child_types.front().second; + auto const child = input.child(cudf::lists_column_view::child_column_index); + + // Handle mismatched child schema. + if (cudf::is_nested(child_schema.type) && (child_schema.type.id() != child.type().id())) { + return make_all_nulls_column(schema, num_rows, stream, mr); + } + + std::vector> new_children; + new_children.emplace_back( + std::make_unique(input.child(cudf::lists_column_view::offsets_column_index))); + new_children.emplace_back( + convert_data_type(child, child_schema, allow_nonnumeric_numbers, is_us_locale, stream, mr)); + + // Do not use `cudf::make_lists_column` since we do not need to call `purge_nonempty_nulls` + // on the child column as it does not have non-empty nulls. + return std::make_unique(cudf::data_type{cudf::type_id::LIST}, + num_rows, + rmm::device_buffer{}, + cudf::detail::copy_bitmask(input, stream, mr), + null_count, + std::move(new_children)); + } + + if (schema.type.id() == cudf::type_id::STRUCT) { + std::vector> new_children; + new_children.reserve(num_children); + for (cudf::size_type i = 0; i < num_children; ++i) { + new_children.emplace_back(convert_data_type(input.child(i), + schema.child_types[i].second, + allow_nonnumeric_numbers, + is_us_locale, + stream, + mr)); + } + + // Do not use `cudf::make_structs_column` since we do not need to call `superimpose_nulls` + // on the children columns. 
+ return std::make_unique(cudf::data_type{cudf::type_id::STRUCT}, + num_rows, + rmm::device_buffer{}, + cudf::detail::copy_bitmask(input, stream, mr), + null_count, + std::move(new_children)); + } + } + + CUDF_FAIL("Unexpected column type for conversion."); + return nullptr; +} + +std::unique_ptr from_json_to_structs(cudf::strings_column_view const& input, + std::vector const& col_names, + std::vector const& num_children, + std::vector const& types, + std::vector const& scales, + std::vector const& precisions, + bool normalize_single_quotes, + bool allow_leading_zeros, + bool allow_nonnumeric_numbers, + bool allow_unquoted_control, + bool is_us_locale, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + auto const [concat_input, delimiter, should_be_nullified] = + concat_json(input, false, stream, cudf::get_current_device_resource()); + auto const [schema, schema_with_precision] = + generate_struct_schema(col_names, num_children, types, scales, precisions); + + auto opts_builder = + cudf::io::json_reader_options::builder( + cudf::io::source_info{cudf::device_span{ + static_cast(concat_input->data()), concat_input->size()}}) + // fixed options + .lines(true) + .recovery_mode(cudf::io::json_recovery_mode_t::RECOVER_WITH_NULL) + .normalize_whitespace(true) + .mixed_types_as_string(true) + .keep_quotes(true) + .experimental(true) + .strict_validation(true) + // specifying parameters + .normalize_single_quotes(normalize_single_quotes) + .delimiter(delimiter) + .numeric_leading_zeros(allow_leading_zeros) + .nonnumeric_numbers(allow_nonnumeric_numbers) + .unquoted_control_chars(allow_unquoted_control) + .dtypes(schema) + .prune_columns(schema.child_types.size() != 0); + + auto const parsed_table_with_meta = cudf::io::read_json(opts_builder.build()); + auto const& parsed_meta = parsed_table_with_meta.metadata; + auto parsed_columns = parsed_table_with_meta.tbl->release(); + + CUDF_EXPECTS(parsed_columns.size() == schema.child_types.size(), + "Numbers 
of output columns is different from schema size."); + + std::vector> converted_cols; + converted_cols.reserve(parsed_columns.size()); + for (std::size_t i = 0; i < parsed_columns.size(); ++i) { + auto const d_type = parsed_columns[i]->type().id(); + CUDF_EXPECTS(d_type == cudf::type_id::LIST || d_type == cudf::type_id::STRUCT || + d_type == cudf::type_id::STRING, + "Parsed JSON columns should be STRING or nested."); + + auto const& [col_name, col_schema] = schema_with_precision.child_types[i]; + CUDF_EXPECTS(parsed_meta.schema_info[i].name == col_name, "Mismatched column name."); + converted_cols.emplace_back(convert_data_type(std::move(parsed_columns[i]), + col_schema, + allow_nonnumeric_numbers, + is_us_locale, + stream, + mr)); + } + + auto const valid_it = should_be_nullified->view().begin(); + auto [null_mask, null_count] = cudf::detail::valid_if( + valid_it, valid_it + should_be_nullified->size(), thrust::logical_not{}, stream, mr); + + // Do not use `cudf::make_structs_column` since we do not need to call `superimpose_nulls` + // on the children columns. + return std::make_unique( + cudf::data_type{cudf::type_id::STRUCT}, + input.size(), + rmm::device_buffer{}, + null_count > 0 ? 
std::move(null_mask) : rmm::device_buffer{0, stream, mr}, + null_count, + std::move(converted_cols)); +} + +} // namespace + +} // namespace detail + +std::unique_ptr from_json_to_structs(cudf::strings_column_view const& input, + std::vector const& col_names, + std::vector const& num_children, + std::vector const& types, + std::vector const& scales, + std::vector const& precisions, + bool normalize_single_quotes, + bool allow_leading_zeros, + bool allow_nonnumeric_numbers, + bool allow_unquoted_control, + bool is_us_locale, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + CUDF_FUNC_RANGE(); + + return detail::from_json_to_structs(input, + col_names, + num_children, + types, + scales, + precisions, + normalize_single_quotes, + allow_leading_zeros, + allow_nonnumeric_numbers, + allow_unquoted_control, + is_us_locale, + stream, + mr); +} + +std::unique_ptr convert_from_strings(cudf::strings_column_view const& input, + std::vector const& num_children, + std::vector const& types, + std::vector const& scales, + std::vector const& precisions, + bool allow_nonnumeric_numbers, + bool is_us_locale, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + CUDF_FUNC_RANGE(); + + [[maybe_unused]] auto const [schema, schema_with_precision] = detail::generate_struct_schema( + /*dummy col_names*/ std::vector(num_children.size(), std::string{}), + num_children, + types, + scales, + precisions); + CUDF_EXPECTS(schema_with_precision.child_types.size() == 1, + "The input schema to convert must have exactly one column."); + + auto const input_cv = input.parent(); + return detail::convert_data_type(input_cv, + schema_with_precision.child_types.front().second, + allow_nonnumeric_numbers, + is_us_locale, + stream, + mr); +} + +std::unique_ptr remove_quotes(cudf::strings_column_view const& input, + bool nullify_if_not_quoted, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + CUDF_FUNC_RANGE(); + + auto const input_cv = 
input.parent(); + auto [removed_quotes, success] = + detail::try_remove_quotes(input_cv, nullify_if_not_quoted, stream, mr); + return success ? std::move(removed_quotes) : std::make_unique(input_cv, stream, mr); +} + +} // namespace spark_rapids_jni diff --git a/src/main/cpp/src/get_json_object.cu b/src/main/cpp/src/get_json_object.cu index f836186192..8fce46bca4 100644 --- a/src/main/cpp/src/get_json_object.cu +++ b/src/main/cpp/src/get_json_object.cu @@ -304,17 +304,6 @@ __device__ inline thrust::tuple path_match_index( } } -__device__ inline thrust::tuple path_match_named( - cudf::device_span path) -{ - auto match = path_match_element(path, path_instruction_type::NAMED); - if (match) { - return thrust::make_tuple(true, path.data()[0].name); - } else { - return thrust::make_tuple(false, cudf::string_view()); - } -} - __device__ inline thrust::tuple path_match_index_wildcard( cudf::device_span path) { @@ -464,7 +453,7 @@ __device__ thrust::pair evaluate_path( // case (START_OBJECT, Named :: xs) // case path 4 else if (json_token::START_OBJECT == ctx.token && - thrust::get<0>(path_match_named(ctx.path))) { + ctx.path.front().type == path_instruction_type::NAMED) { if (!ctx.is_first_enter) { // 2st enter // skip the following children after the expect @@ -492,15 +481,16 @@ __device__ thrust::pair evaluate_path( ctx.is_first_enter = false; // match first mached children with expected name bool found_expected_child = false; - while (json_token::END_OBJECT != p.next_token()) { + auto const to_match_name = ctx.path.front().name; + while (true) { + auto const is_name_matched = p.parse_next_token_with_matching(to_match_name); + if (json_token::END_OBJECT == p.get_current_token()) { break; } + // JSON validation check if (json_token::ERROR == p.get_current_token()) { return {false, 0}; } - // need to try more children - auto match_named = path_match_named(ctx.path); - auto named = thrust::get<1>(match_named); // current token is FIELD_NAME - if 
(p.match_current_field_name(named)) { + if (is_name_matched) { // skip FIELD_NAME token p.next_token(); // JSON validation check @@ -1029,7 +1019,6 @@ std::vector> get_json_object_batch( construct_path_commands(json_paths, stream); auto const num_outputs = json_paths.size(); - std::vector> output; // The error check array contains markers denoting if there is any out-of-bound write occurs // (first `num_outputs` elements), or if the nesting depth exceeded its limits (the last element). @@ -1062,19 +1051,23 @@ std::vector> get_json_object_batch( auto d_path_data = cudf::detail::make_device_uvector_async( h_path_data, stream, rmm::mr::get_current_device_resource()); thrust::uninitialized_fill( - rmm::exec_policy(stream), d_error_check.begin(), d_error_check.end(), 0); + rmm::exec_policy_nosync(stream), d_error_check.begin(), d_error_check.end(), 0); kernel_launcher::exec(input, d_path_data, d_max_path_depth_exceeded, stream); auto h_error_check = cudf::detail::make_host_vector_sync(d_error_check, stream); auto has_no_oob = check_error(h_error_check); + std::vector const>> + batch_stringviews; + batch_stringviews.reserve(out_stringviews.size()); + // If we didn't see any out-of-bound write, everything is good so far. // Just gather the output strings and return. if (has_no_oob) { for (auto const& out_sview : out_stringviews) { - output.emplace_back(cudf::make_strings_column(out_sview, stream, mr)); + batch_stringviews.emplace_back(out_sview); } - return output; + return cudf::make_strings_column_batch(batch_stringviews, stream, mr); } // From here, we had out-of-bound write. Although this is very rare, it may still happen. @@ -1082,6 +1075,7 @@ std::vector> get_json_object_batch( std::vector, int64_t>> out_offsets_and_sizes; std::vector> out_char_buffers; std::vector oob_indices; + std::vector no_oob_indices; // Check validity from the stored char pointers. 
auto const validator = [] __device__(thrust::pair const item) { @@ -1095,7 +1089,6 @@ std::vector> get_json_object_batch( if (h_error_check[idx]) { oob_indices.emplace_back(idx); - output.emplace_back(nullptr); // just placeholder. out_null_masks_and_null_counts.emplace_back( cudf::detail::valid_if(out_sview.begin(), out_sview.end(), validator, stream, mr)); @@ -1121,9 +1114,18 @@ std::vector> get_json_object_batch( out_char_buffers.back().data(), d_error_check.data() + idx}); } else { - output.emplace_back(cudf::make_strings_column(out_sview, stream, mr)); + no_oob_indices.emplace_back(idx); + batch_stringviews.emplace_back(out_sview); } } + + std::vector> output(num_outputs); + auto no_oob_output = cudf::make_strings_column_batch(batch_stringviews, stream, mr); + for (std::size_t idx = 0; idx < no_oob_indices.size(); ++idx) { + auto const out_idx = no_oob_indices[idx]; + output[out_idx] = std::move(no_oob_output[idx]); + } + // These buffers are no longer needed. scratch_buffers.clear(); out_stringviews.clear(); @@ -1132,7 +1134,7 @@ std::vector> get_json_object_batch( d_path_data = cudf::detail::make_device_uvector_async( h_path_data, stream, rmm::mr::get_current_device_resource()); thrust::uninitialized_fill( - rmm::exec_policy(stream), d_error_check.begin(), d_error_check.end(), 0); + rmm::exec_policy_nosync(stream), d_error_check.begin(), d_error_check.end(), 0); kernel_launcher::exec(input, d_path_data, d_max_path_depth_exceeded, stream); h_error_check = cudf::detail::make_host_vector_sync(d_error_check, stream); has_no_oob = check_error(h_error_check); diff --git a/src/main/cpp/src/hive_hash.cu b/src/main/cpp/src/hive_hash.cu index 85598565a9..c10c424b8b 100644 --- a/src/main/cpp/src/hive_hash.cu +++ b/src/main/cpp/src/hive_hash.cu @@ -37,6 +37,8 @@ using hive_hash_value_t = int32_t; constexpr hive_hash_value_t HIVE_HASH_FACTOR = 31; constexpr hive_hash_value_t HIVE_INIT_HASH = 0; +constexpr int MAX_NESTED_DEPTH = 8; + hive_hash_value_t __device__ inline 
compute_int(int32_t key) { return key; } hive_hash_value_t __device__ inline compute_long(int64_t key) @@ -182,8 +184,8 @@ class hive_device_row_hasher { HIVE_INIT_HASH, cuda::proclaim_return_type( [row_index, nulls = this->_check_nulls] __device__(auto hash, auto const& column) { - auto cur_hash = cudf::type_dispatcher( - column.type(), element_hasher_adapter{nulls}, column, row_index); + auto cur_hash = + cudf::type_dispatcher(column.type(), element_hasher_adapter{nulls}, column, row_index); return HIVE_HASH_FACTOR * hash + cur_hash; })); } @@ -191,8 +193,6 @@ class hive_device_row_hasher { private: /** * @brief Computes the hash value of an element in the given column. - * - * Only supported non nested types now */ class element_hasher_adapter { public: @@ -210,11 +210,227 @@ class hive_device_row_hasher { return this->hash_functor.template operator()(col, row_index); } + /** + * @brief A structure to keep track of the computation for nested types. + */ + struct col_stack_frame { + private: + cudf::column_device_view _column; // the column has only one row + hive_hash_value_t _cur_hash; // current hash value of the column + int _idx_to_process; // the index of child or element to process next + + public: + __device__ col_stack_frame() = + delete; // Because the default constructor of `cudf::column_device_view` is deleted + + __device__ col_stack_frame(cudf::column_device_view col) + : _column(std::move(col)), _idx_to_process(0), _cur_hash(HIVE_INIT_HASH) + { + } + + __device__ void update_cur_hash(hive_hash_value_t hash) + { + _cur_hash = _cur_hash * HIVE_HASH_FACTOR + hash; + } + + __device__ hive_hash_value_t get_hash() { return _cur_hash; } + + __device__ int get_and_inc_idx_to_process() { return _idx_to_process++; } + + __device__ int get_idx_to_process() { return _idx_to_process; } + + __device__ cudf::column_device_view get_column() { return _column; } + }; + + /** + * @brief Functor to compute the hive hash value for a nested column. 
+ * + * This functor produces the same result as "HiveHash" in Spark for structs and lists. + * The pseudocode of Spark's HiveHash function for structs is as follows: + * + * hive_hash_value_t hive_hash(NestedType element) { + * hive_hash_value_t hash = HIVE_INIT_HASH; + * for (int i = 0; i < element.num_child(); i++) { + * hash = hash * HIVE_HASH_FACTOR + hive_hash(element.get_child(i)); + * } + * return hash; + * } + * + * In the cases of lists, the hash value is computed by a similar way but we iterate through the + * list elements instead of through the child columns' elements. + * + * This functor uses a stack to simulate the recursive process of the above pseudocode. + * When an element is popped from the stack, it means that the hash value of it has been + * computed. Therefore, we should update the parent's `cur_hash` upon popping the element. + * + * The algorithm is as follows: + * + * 1. Initialize the stack and push the root column into the stack. + * 2. While the stack is not empty: + * a. Get the top element of the stack. Don't pop it until it is processed. + * b. If the column is a structs column: + * i. If all child columns are processed, pop the element and update `cur_hash` of its + * parent column. + * ii. Otherwise, process the next child column. + * c. If the column is a lists column, process it by a similar way as structs column but + * iterating through the list elements instead of child columns' elements. + * 3. Return the hash value of the root column. + * + * For example, consider the following nested column: `Struct, decimal>` + * + * S1 + * / \ + * S2 d + * / \ + * i f + * + * - First, S1 is pushed into the stack. Then, S2 is pushed into the stack. + * - S2's hash value can be computed directly because its children are of primitive types. + * When S2 is popped, S1's `cur_hash` is updated to S2's hash value. + * - Now the top of the stack is S1. The next child to process is d. 
S1's `cur_hash` is updated + * to `hash(S2) * HIVE_HASH_FACTOR + hash(d)`, which is the hash value of S1. + * - When S1 is popped, the hash value of the root column is returned. + * + * As lists columns have a different interface from structs columns, we need to handle them + * separately. + * + * For example, consider that the input column is of type `List>`. + * Assume that the element at `row_index` is: [[1, 0], null, [2, null]] + * Since the stack_frame should contain a column that consists of only one row, the input column + * should be sliced. The sliced column is noted as L1 here. + * + * L1 + * | + * L2 + * | + * i + * + * List level L1: + * |Index| List> | + * |-----|-------------------------| + * |0 |[[1, 0], null, [2, null]]| + * length: 1 + * Offsets: 0, 3 + * + * List level L2: + * |Index|List| + * |-----|---------| + * |0 |[1, 0] | + * |1 |null | + * |2 |[2, null]| + * length: 3 + * Offsets: 0, 2, 2, 4 + * null_mask: 101 + * + * Int level i: + * |Index|int | + * |-----|----| + * |0 |1 | + * |1 |0 | + * |2 |2 | + * |3 |null| + * length: 4 + * null_mask: 0111 + * + * Since the underlying data loses the null information of the top-level list column, computing + * hash values using the underlying data merely can yield different results compared to Spark. + * For example, [[1, 0], [2, null]] has the same underlying data as [[1, 0], null, [2, null]]. + * However, they have different hive hash values, as null values affect the hash value. + * + * And sublist structure factors in the hash value. For example, [[1, 0], [2]] and [[1], [0, 2]] + * have different hive hash values. + * + * The computation process for lists columns in this solution is as follows: + * L1 List> + * | + * L2 List + * / | \ + * L2[0] L2[1] L2[2] int + * | | + * i1 i2 int + * + * Note: L2、i1、i2 are all temporary columns, which would not be pushed into the stack. + * If the child column is of primitive type, the hash value of the list column can be directly + * computed. 
+ * + * @tparam T The type of the column. + * @param col The column device view. + * @param row_index The index of the row to compute the hash for. + * @return The computed hive hash value. + * + * @note This function is only enabled for nested column types. + */ template ())> __device__ hive_hash_value_t operator()(cudf::column_device_view const& col, cudf::size_type row_index) const noexcept { - CUDF_UNREACHABLE("Nested type is not supported"); + cudf::column_device_view curr_col = col.slice(row_index, 1); + // The default constructor of `col_stack_frame` is deleted, so it can not allocate an array + // of `col_stack_frame` directly. + // Instead leverage the byte array to create the col_stack_frame array. + alignas(col_stack_frame) char stack_wrapper[sizeof(col_stack_frame) * MAX_NESTED_DEPTH]; + auto col_stack = reinterpret_cast(stack_wrapper); + int stack_size = 0; + + col_stack[stack_size++] = col_stack_frame(curr_col); + + while (stack_size > 0) { + col_stack_frame& top = col_stack[stack_size - 1]; + curr_col = top.get_column(); + // Do not pop it until it is processed. The definition of `processed` is: + // - For structs, it is when all child columns are processed. + // - For lists, it is when all elements in the list are processed. 
+ if (curr_col.type().id() == cudf::type_id::STRUCT) { + if (top.get_idx_to_process() == curr_col.num_child_columns()) { + if (--stack_size > 0) { col_stack[stack_size - 1].update_cur_hash(top.get_hash()); } + } else { + auto const structcv = cudf::detail::structs_column_device_view(curr_col); + while (top.get_idx_to_process() < curr_col.num_child_columns()) { + auto idx = top.get_and_inc_idx_to_process(); + auto const child_col = structcv.get_sliced_child(idx); + // If the child is of primitive type, accumulate child hash into struct hash + if (child_col.type().id() != cudf::type_id::LIST && + child_col.type().id() != cudf::type_id::STRUCT) { + auto child_hash = + cudf::type_dispatcher( + child_col.type(), this->hash_functor, child_col, 0); + top.update_cur_hash(child_hash); + } else { + col_stack[stack_size++] = col_stack_frame(child_col); + break; + } + } + } + } else if (curr_col.type().id() == cudf::type_id::LIST) { + // Get the child column of the list column + cudf::column_device_view child_col = + cudf::detail::lists_column_device_view(curr_col).get_sliced_child(); + // If the child column is of primitive type, directly compute the hash value of the list + if (child_col.type().id() != cudf::type_id::LIST && + child_col.type().id() != cudf::type_id::STRUCT) { + auto single_level_list_hash = cudf::detail::accumulate( + thrust::counting_iterator(0), + thrust::counting_iterator(child_col.size()), + HIVE_INIT_HASH, + [child_col, hasher = this->hash_functor] __device__(auto hash, auto element_index) { + auto cur_hash = cudf::type_dispatcher( + child_col.type(), hasher, child_col, element_index); + return HIVE_HASH_FACTOR * hash + cur_hash; + }); + top.update_cur_hash(single_level_list_hash); + if (--stack_size > 0) { col_stack[stack_size - 1].update_cur_hash(top.get_hash()); } + } else { + if (top.get_idx_to_process() == child_col.size()) { + if (--stack_size > 0) { col_stack[stack_size - 1].update_cur_hash(top.get_hash()); } + } else { + // Push the next element 
into the stack + col_stack[stack_size++] = + col_stack_frame(child_col.slice(top.get_and_inc_idx_to_process(), 1)); + } + } + } + } + return col_stack[0].get_hash(); } private: @@ -224,6 +440,36 @@ class hive_device_row_hasher { Nullate const _check_nulls; cudf::table_device_view const _table; }; + +void check_nested_depth(cudf::table_view const& input) +{ + using column_checker_fn_t = std::function; + + column_checker_fn_t get_nested_depth = [&](cudf::column_view const& col) { + if (col.type().id() == cudf::type_id::LIST) { + auto const child_col = cudf::lists_column_view(col).child(); + return 1 + get_nested_depth(child_col); + } else if (col.type().id() == cudf::type_id::STRUCT) { + int max_child_depth = 0; + for (auto child = col.child_begin(); child != col.child_end(); ++child) { + max_child_depth = std::max(max_child_depth, get_nested_depth(*child)); + } + return 1 + max_child_depth; + } else { // Primitive type + return 0; + } + }; + + for (auto i = 0; i < input.num_columns(); i++) { + cudf::column_view const& col = input.column(i); + CUDF_EXPECTS(get_nested_depth(col) <= MAX_NESTED_DEPTH, + "The " + std::to_string(i) + + "-th column exceeds the maximum allowed nested depth. 
" + + "Current depth: " + std::to_string(get_nested_depth(col)) + ", " + + "Maximum allowed depth: " + std::to_string(MAX_NESTED_DEPTH)); + } +} + } // namespace std::unique_ptr hive_hash(cudf::table_view const& input, @@ -239,6 +485,8 @@ std::unique_ptr hive_hash(cudf::table_view const& input, // Return early if there's nothing to hash if (input.num_columns() == 0 || input.num_rows() == 0) { return output; } + check_nested_depth(input); + bool const nullable = has_nested_nulls(input); auto const input_view = cudf::table_device_view::create(input, stream); auto output_view = output->mutable_view(); diff --git a/src/main/cpp/src/host_table_view.hpp b/src/main/cpp/src/host_table_view.hpp new file mode 100644 index 0000000000..f707b78976 --- /dev/null +++ b/src/main/cpp/src/host_table_view.hpp @@ -0,0 +1,320 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include + +#include +#include + +namespace spark_rapids_jni { + +/** + * @brief A non-owning, immutable view of host data as a column of elements, + * some of which may be null as indicated by a bitmask. 
+ * + * Unless otherwise noted, the memory layout of the `host_column_view`'s data and + * bitmask is expected to adhere to the Arrow Physical Memory Layout + * Specification: https://arrow.apache.org/docs/memory_layout.html + * + * Because `host_column_view` is non-owning, no host memory is allocated nor freed + * when `host_column_view` objects are created or destroyed. + */ +class host_column_view { + private: + cudf::data_type _type{cudf::type_id::EMPTY}; + cudf::size_type _size{}; + void const* _data{}; + cudf::bitmask_type const* _null_mask{}; + cudf::size_type _null_count{}; + std::vector _children{}; + + public: + host_column_view() = default; + ~host_column_view() = default; + host_column_view(host_column_view const&) = default; + host_column_view(host_column_view&&) = default; + host_column_view& operator=(host_column_view const&) = default; + host_column_view& operator=(host_column_view&&) = default; + + /** + * @brief Construct a `host_column_view` from pointers to host memory for the + * elements and bitmask of the column. 
+ */ + host_column_view(cudf::data_type type, + cudf::size_type size, + void const* data, + cudf::bitmask_type const* null_mask, + cudf::size_type null_count, + std::vector const& children = {}) + : _type{type}, + _size{size}, + _data{data}, + _null_mask{null_mask}, + _null_count{null_count}, + _children{children} + { + CUDF_EXPECTS(size >= 0, "Column size cannot be negative."); + if (type.id() == cudf::type_id::EMPTY) { + _null_count = size; + CUDF_EXPECTS(nullptr == data, "EMPTY column should have no data."); + CUDF_EXPECTS(nullptr == null_mask, "EMPTY column should have no null mask."); + } else if (cudf::is_compound(type)) { + if (type.id() != cudf::type_id::STRING) { + CUDF_EXPECTS(nullptr == data, "Compound (parent) columns cannot have data"); + } + } else if (size > 0) { + CUDF_EXPECTS(nullptr != data, "Null data pointer."); + } + if ((null_count > 0) and (type.id() != cudf::type_id::EMPTY)) { + CUDF_EXPECTS(nullptr != null_mask, "Invalid null mask for non-zero null count."); + } + if (type.id() == cudf::type_id::EMPTY) { + CUDF_EXPECTS(num_children() == 0, "EMPTY column cannot have children."); + } + } + + /** + * @brief Returns the number of elements in the column + * + * @return The number of elements in the column + */ + [[nodiscard]] cudf::size_type size() const noexcept { return _size; } + + /** + * @brief Returns the element `data_type` + * + * @return The `data_type` of the elements in the column + */ + [[nodiscard]] cudf::data_type type() const noexcept { return _type; } + + /** + * @brief Indicates if the column can contain null elements, i.e., if it has + * an allocated bitmask. + * + * @note If `null_count() > 0`, this function must always return `true`. 
+ * + * @return true The bitmask is allocated + * @return false The bitmask is not allocated + */ + [[nodiscard]] bool nullable() const noexcept { return nullptr != _null_mask; } + + /** + * @brief Returns the count of null elements + * + * @return The count of null elements + */ + [[nodiscard]] cudf::size_type null_count() const noexcept { return _null_count; } + + /** + * @brief Indicates if the column contains null elements, + * i.e., `null_count() > 0` + * + * @return true One or more elements are null + * @return false All elements are valid + */ + [[nodiscard]] bool has_nulls() const { return null_count() > 0; } + + /** + * @brief Returns raw pointer to the underlying bitmask allocation. + * + * @note If `null_count() == 0`, this may return `nullptr`. + * @return Raw pointer to the bitmask + */ + [[nodiscard]] cudf::bitmask_type const* null_mask() const noexcept { return _null_mask; } + + template + T const* data() const noexcept + { + return static_cast(_data); + } + + /** + * @brief Returns the specified child + * + * @param child_index The index of the desired child + * @return The requested child `column_view` + */ + [[nodiscard]] host_column_view const& child(cudf::size_type child_index) const + { + return _children.at(child_index); + } + + /** + * @brief Returns the number of child columns. + * + * @return The number of child columns + */ + [[nodiscard]] cudf::size_type num_children() const noexcept { return _children.size(); } + + /** + * @brief Returns iterator to the beginning of the ordered sequence of child column-views. + * + * @return An iterator to a `host_column_view` referencing the first child column + */ + auto child_begin() const noexcept { return _children.cbegin(); } + + /** + * @brief Returns iterator to the end of the ordered sequence of child column-views. 
+ * + * @return An iterator to a `host_column_view` one past the end of the child columns + */ + auto child_end() const noexcept { return _children.cend(); } + + /** + * @brief Returns the child column corresponding to the offsets of a strings column + * + * @note This must only be called on a strings column. + */ + [[nodiscard]] host_column_view const& strings_offsets() const + { + return _children.at(cudf::strings_column_view::offsets_column_index); + } + + /** + * @brief Returns the child column corresponding to the offsets of a lists column + * + * @note This must only be called on a lists column. + */ + [[nodiscard]] host_column_view const& lists_offsets() const + { + return _children.at(cudf::lists_column_view::offsets_column_index); + } + + /** + * @brief Returns the child column containing the data of a lists column + * + * @note This must only be called on a lists column. + */ + [[nodiscard]] host_column_view const& lists_child() const + { + return _children.at(cudf::lists_column_view::child_column_index); + } +}; + +/** + * @brief A set of host_column_view's of the same size. + */ +class host_table_view { + private: + std::vector _columns{}; + cudf::size_type _num_rows{}; + + public: + using iterator = decltype(std::begin(_columns)); ///< Iterator type for the table + using const_iterator = decltype(std::cbegin(_columns)); ///< const iterator type for the table + + host_table_view() = default; + ~host_table_view() = default; + host_table_view(host_table_view const&) = default; + host_table_view(host_table_view&&) = default; + host_table_view& operator=(host_table_view const&) = default; + host_table_view& operator=(host_table_view&&) = default; + + /** + * @brief Construct from a vector of column views + * + * @note Because a `std::vector` is constructible from a + * `std::initializer_list`, this constructor also supports the following + * usage: + * ``` + * host_column_view c0, c1, c2; + * ... 
+ * host_table_view t{{c0,c1,c2}}; // Creates a `host_table_view` from c0, c1, c2 + * ``` + * + * @throws cudf::logic_error If all views do not have the same size + * + * @param cols The vector of column views to construct the table from + */ + explicit host_table_view(std::vector const& cols) : _columns{cols} + { + if (num_columns() > 0) { + std::for_each(_columns.begin(), _columns.end(), [this](host_column_view const& col) { + CUDF_EXPECTS(col.size() == _columns.front().size(), "Column size mismatch."); + }); + _num_rows = _columns.front().size(); + } else { + _num_rows = 0; + } + } + + /** + * @brief Returns an iterator to the first view in the table. + * + * @return An iterator to the first host_column_view + */ + iterator begin() noexcept { return std::begin(_columns); } + + /** + * @brief Returns an iterator to the first view in the table. + * + * @return An iterator to the first host_column_view + */ + [[nodiscard]] const_iterator begin() const noexcept { return std::begin(_columns); } + + /** + * @brief Returns an iterator one past the last column view in the table. + * + * `end()` acts as a place holder. Attempting to dereference it results in + * undefined behavior. + * + * @return An iterator to one past the last column view in the table + */ + iterator end() noexcept { return std::end(_columns); } + + /** + * @brief Returns an iterator one past the last column view in the table. + * + * `end()` acts as a place holder. Attempting to dereference it results in + * undefined behavior. 
+ * + * @return An iterator to one past the last column view in the table + */ + [[nodiscard]] const_iterator end() const noexcept { return std::end(_columns); } + + /** + * @brief Returns a reference to the view of the specified column + * + * @throws std::out_of_range + * If `column_index` is out of the range [0, num_columns) + * + * @param column_index The index of the desired column + * @return A reference to the desired column + */ + [[nodiscard]] host_column_view const& column(cudf::size_type column_index) const + { + return _columns.at(column_index); + } + + /** + * @brief Returns the number of columns + * + * @return The number of columns + */ + [[nodiscard]] cudf::size_type num_columns() const noexcept { return _columns.size(); } + + /** + * @brief Returns the number of rows + * + * @return The number of rows + */ + [[nodiscard]] cudf::size_type num_rows() const noexcept { return _num_rows; } +}; + +} // namespace spark_rapids_jni diff --git a/src/main/cpp/src/json_parser.cuh b/src/main/cpp/src/json_parser.cuh index 4e712937ed..446caf6357 100644 --- a/src/main/cpp/src/json_parser.cuh +++ b/src/main/cpp/src/json_parser.cuh @@ -134,13 +134,8 @@ class char_range { __device__ inline cudf::size_type size() const { return _len; } __device__ inline char const* data() const { return _data; } - __device__ inline char const* start() const { return _data; } - __device__ inline char const* end() const { return _data + _len; } - - __device__ inline bool eof(cudf::size_type pos) const { return pos >= _len; } __device__ inline bool is_null() const { return _data == nullptr; } - __device__ inline bool is_empty() const { return _len == 0; } - + __device__ inline bool is_empty() const { return _len <= 0; } __device__ inline char operator[](cudf::size_type pos) const { return _data[pos]; } __device__ inline cudf::string_view slice_sv(cudf::size_type pos, cudf::size_type len) const @@ -153,35 +148,29 @@ class char_range { return char_range(_data + pos, len); } - private: + 
protected: char const* _data; cudf::size_type _len; }; /** - * A char_range that keeps track of where in the data it currently is. + * A char range that moves the begin pointer of the current range forward while reading. + * + * This support continuous reading of characters without the need of an additional variable + * to keep track of the current reading position. */ -class char_range_reader { +class char_range_reader : public char_range { public: - __device__ inline explicit char_range_reader(char_range range) : _range(range), _pos(0) {} - - __device__ inline char_range_reader(char_range range, cudf::size_type start) - : _range(range), _pos(start) + __device__ inline explicit char_range_reader(char_range range) : char_range(std::move(range)) {} + __device__ inline void next() { + _data++; + _len--; } - __device__ inline bool eof() const { return _range.eof(_pos); } - __device__ inline bool is_null() const { return _range.is_null(); } - - __device__ inline void next() { _pos++; } - - __device__ inline char current_char() const { return _range[_pos]; } - - __device__ inline cudf::size_type pos() const { return _pos; } - - private: - char_range _range; - cudf::size_type _pos; + // Warning: this does not check for out-of-bound access. + // The caller must be responsible to check for empty range before calling this. + __device__ inline char current_char() const { return _data[0]; } }; /** @@ -298,7 +287,7 @@ class json_parser { */ static __device__ inline bool try_skip(char_range_reader& reader, char expected) { - if (!reader.eof() && reader.current_char() == expected) { + if (!reader.is_empty() && reader.current_char() == expected) { reader.next(); return true; } @@ -412,12 +401,10 @@ class json_parser { */ __device__ inline void parse_string_and_set_current() { - // TODO eventually chars should be a reader so we can just pass it in... 
- char_range_reader reader(chars, curr_pos); - auto [success, end_char_pos] = try_parse_string(reader); + [[maybe_unused]] auto const [success, matched, end] = + try_parse_string(char_range_reader{chars.slice(curr_pos, chars.size() - curr_pos)}); if (success) { - // TODO remove end_char_pos, and just get it from the reader... - curr_pos = end_char_pos; + curr_pos = static_cast(thrust::distance(chars.data(), end)); current_token = json_token::VALUE_STRING; } else { set_current_error(); @@ -499,7 +486,7 @@ class json_parser { char* copy_destination, escape_style w_style) { - if (str.eof()) { return 0; } + if (str.is_empty()) { return 0; } char const quote_char = str.current_char(); int output_size_bytes = 0; @@ -514,7 +501,7 @@ class json_parser { str.next(); // scan string content - while (!str.eof()) { + while (!str.is_empty()) { char const c = str.current_char(); int const v = static_cast(c); if (c == quote_char) { @@ -546,8 +533,10 @@ class json_parser { } else if ('\\' == c) { // path 3: escape path str.next(); - char_range_reader to_match(char_range::null()); - if (!try_skip_escape_part(str, to_match, copy_destination, w_style, output_size_bytes)) { + char_range_reader to_match(char_range::null()); // unused + bool matched_field_name{false}; // unused + if (!try_skip_escape_part( + str, to_match, copy_destination, w_style, output_size_bytes, matched_field_name)) { return output_size_bytes; } } else { @@ -610,85 +599,58 @@ class json_parser { * * @param str string to parse * @param to_match expected match str - * @param w_style the escape style for writing. - * @return a pair of success and length, where success is true if the string - * is valid and length is the number of bytes needed to encode the string - * in the given style. 
+ * @return a tuple of values indicating if the parse process was successful, if field name was + * matched, and a pointer to the past-end position of the parsed data */ - static __device__ inline std::pair try_parse_string( - char_range_reader& str, - char_range_reader to_match = char_range_reader(char_range::null()), - escape_style w_style = escape_style::UNESCAPED) + static __device__ inline thrust::tuple try_parse_string( + char_range_reader str, char_range_reader to_match = char_range_reader(char_range::null())) { - if (str.eof()) { return std::make_pair(false, 0); } - char const quote_char = str.current_char(); - int output_size_bytes = 0; - - // write the first " if write style is escaped - if (escape_style::ESCAPED == w_style) { output_size_bytes++; } + if (str.is_empty()) { return thrust::make_tuple(false, false, nullptr); } + char const quote_char = str.current_char(); + bool matched_field_name = !to_match.is_null(); // skip left quote char // We don't need to actually verify what it is, because we just read it. 
str.next(); // scan string content - while (!str.eof()) { + while (!str.is_empty()) { char c = str.current_char(); int v = static_cast(c); - if (c == quote_char) { - // path 1: match closing quote char + if (c == quote_char) { // path 1: match closing quote char str.next(); - - // match check, the last char in match_str is quote_char - if (!to_match.is_null() && !to_match.eof()) { return std::make_pair(false, 0); } - - // write the end " if write style is escaped - if (escape_style::ESCAPED == w_style) { output_size_bytes++; } - - return std::make_pair(true, str.pos()); - } else if (v >= 0 && v < 32) { - // path 2: unescaped control char - - // copy if enabled, escape mode, write more chars - if (escape_style::ESCAPED == w_style) { - int escape_chars = escape_char(str.current_char(), nullptr); - output_size_bytes += (escape_chars - 1); - } - - // check match if enabled - if (!try_match_char(to_match, str.current_char())) { return std::make_pair(false, 0); } - + matched_field_name = matched_field_name && (to_match.is_null() || to_match.is_empty()); + return thrust::make_tuple(true, matched_field_name, str.data()); + } else if (v >= 0 && v < 32) { // path 2: unescaped control char + matched_field_name = matched_field_name && try_match_char(to_match, c); str.next(); - output_size_bytes++; continue; - } else if ('\\' == c) { - // path 3: escape path + } else if ('\\' == c) { // path 3: escape path str.next(); - char* copy_dest_nullptr = nullptr; - if (!try_skip_escape_part(str, to_match, copy_dest_nullptr, w_style, output_size_bytes)) { - return std::make_pair(false, 0); - } - } else { - // path 4: safe code point - - // handle single unescaped " char; happens when string is quoted by char ' - // e.g.: 'A"' string, escape to "A\\"" (5 chars: " A \ " ") - if ('\"' == c && escape_style::ESCAPED == w_style) { output_size_bytes++; } - if (!try_skip_safe_code_point(str, c)) { return std::make_pair(false, 0); } - // check match if enabled - if (!try_match_char(to_match, c)) 
{ return std::make_pair(false, 0); } - output_size_bytes++; + char* copy_dest_nullptr = nullptr; // unused + int output_size_bytes = 0; // unused + if (!try_skip_escape_part(str, + to_match, + copy_dest_nullptr, + escape_style::UNESCAPED, + output_size_bytes, + matched_field_name)) { + return thrust::make_tuple(false, false, nullptr); + } + } else { // path 4: safe code point + if (!try_skip_safe_code_point(str, c)) { return thrust::make_tuple(false, false, nullptr); } + matched_field_name = matched_field_name && try_match_char(to_match, c); } } - return std::make_pair(false, 0); + return thrust::make_tuple(false, false, nullptr); } static __device__ inline bool try_match_char(char_range_reader& reader, char c) { if (!reader.is_null()) { - if (!reader.eof() && reader.current_char() == c) { + if (!reader.is_empty() && reader.current_char() == c) { reader.next(); return true; } else { @@ -708,11 +670,12 @@ class json_parser { char_range_reader& to_match, char*& copy_dest, escape_style w_style, - int& output_size_bytes) + int& output_size_bytes, + bool& matched_field_name) { // already skipped the first '\' // try skip second part - if (!str.eof()) { + if (!str.is_empty()) { char const c = str.current_char(); switch (c) { // path 1: \", \', \\, \/, \b, \f, \n, \r, \t @@ -725,17 +688,17 @@ class json_parser { } output_size_bytes++; } - if (!try_match_char(to_match, c)) { return false; } output_size_bytes++; str.next(); + matched_field_name = matched_field_name && try_match_char(to_match, c); return true; case '\'': // for both unescaped/escaped writes a single char ' if (nullptr != copy_dest) { *copy_dest++ = c; } - if (!try_match_char(to_match, c)) { return false; } output_size_bytes++; str.next(); + matched_field_name = matched_field_name && try_match_char(to_match, c); return true; case '\\': if (nullptr != copy_dest && escape_style::UNESCAPED == w_style) { *copy_dest++ = c; } @@ -746,16 +709,16 @@ class json_parser { } output_size_bytes++; } - if 
(!try_match_char(to_match, c)) { return false; } output_size_bytes++; str.next(); + matched_field_name = matched_field_name && try_match_char(to_match, c); return true; case '/': // for both unescaped/escaped writes a single char / if (nullptr != copy_dest) { *copy_dest++ = c; } - if (!try_match_char(to_match, c)) { return false; } output_size_bytes++; str.next(); + matched_field_name = matched_field_name && try_match_char(to_match, c); return true; case 'b': if (nullptr != copy_dest && escape_style::UNESCAPED == w_style) { *copy_dest++ = '\b'; } @@ -766,9 +729,9 @@ class json_parser { } output_size_bytes++; } - if (!try_match_char(to_match, '\b')) { return false; } output_size_bytes++; str.next(); + matched_field_name = matched_field_name && try_match_char(to_match, '\b'); return true; case 'f': if (nullptr != copy_dest && escape_style::UNESCAPED == w_style) { *copy_dest++ = '\f'; } @@ -779,9 +742,9 @@ class json_parser { } output_size_bytes++; } - if (!try_match_char(to_match, '\f')) { return false; } output_size_bytes++; str.next(); + matched_field_name = matched_field_name && try_match_char(to_match, '\f'); return true; case 'n': if (nullptr != copy_dest && escape_style::UNESCAPED == w_style) { *copy_dest++ = '\n'; } @@ -792,9 +755,9 @@ class json_parser { } output_size_bytes++; } - if (!try_match_char(to_match, '\n')) { return false; } output_size_bytes++; str.next(); + matched_field_name = matched_field_name && try_match_char(to_match, '\n'); return true; case 'r': if (nullptr != copy_dest && escape_style::UNESCAPED == w_style) { *copy_dest++ = '\r'; } @@ -805,9 +768,9 @@ class json_parser { } output_size_bytes++; } - if (!try_match_char(to_match, '\r')) { return false; } output_size_bytes++; str.next(); + matched_field_name = matched_field_name && try_match_char(to_match, '\r'); return true; case 't': if (nullptr != copy_dest && escape_style::UNESCAPED == w_style) { *copy_dest++ = '\t'; } @@ -818,9 +781,9 @@ class json_parser { } output_size_bytes++; } - if 
(!try_match_char(to_match, '\t')) { return false; } output_size_bytes++; str.next(); + matched_field_name = matched_field_name && try_match_char(to_match, '\t'); return true; // path 1 done: \", \', \\, \/, \b, \f, \n, \r, \t case 'u': @@ -829,7 +792,7 @@ class json_parser { // for both unescaped/escaped writes corresponding utf8 bytes, no need // to pass in write style - return try_skip_unicode(str, to_match, copy_dest, output_size_bytes); + return try_skip_unicode(str, to_match, copy_dest, output_size_bytes, matched_field_name); default: // path 3: invalid return false; @@ -948,13 +911,14 @@ class json_parser { static __device__ bool try_skip_unicode(char_range_reader& str, char_range_reader& to_match, char*& copy_dest, - int& output_size_bytes) + int& output_size_bytes, + bool& matched_field_name) { // already parsed \u // now we expect 4 hex chars. cudf::char_utf8 code_point = 0; for (size_t i = 0; i < 4; i++) { - if (str.eof()) { return false; } + if (str.is_empty()) { return false; } char const c = str.current_char(); str.next(); if (!is_hex_digit(c)) { return false; } @@ -976,9 +940,12 @@ class json_parser { } } - if (!to_match.is_null()) { + if (matched_field_name && !to_match.is_null()) { for (cudf::size_type i = 0; i < bytes; i++) { - if (to_match.eof() || to_match.current_char() != buff[i]) { return false; } + if (to_match.is_empty() || to_match.current_char() != buff[i]) { + matched_field_name = false; + break; + } to_match.next(); } } @@ -1210,16 +1177,17 @@ class json_parser { /** * parse the key string in key:value pair */ - __device__ inline void parse_field_name_and_set_current() + __device__ inline void parse_field_name_and_set_current( + bool& matched_field_name, char_range to_match_field_name = char_range::null()) { - // TODO eventually chars should be a reader so we can just pass it in... 
- char_range_reader reader(chars, curr_pos); - current_token_start_pos = curr_pos; - auto [success, end_char_pos] = try_parse_string(reader); + current_token_start_pos = curr_pos; + auto const [success, matched, end] = + try_parse_string(char_range_reader{chars.slice(curr_pos, chars.size() - curr_pos)}, + char_range_reader{std::move(to_match_field_name)}); if (success) { - // TODO remove end_char_pos, and just get it from the reader... - curr_pos = end_char_pos; - current_token = json_token::FIELD_NAME; + matched_field_name = matched; + curr_pos = static_cast(thrust::distance(chars.data(), end)); + current_token = json_token::FIELD_NAME; } else { set_current_error(); } @@ -1228,11 +1196,12 @@ class json_parser { /** * continute parsing the next token and update current token * Note: only parse one token at a time - * @param[out] has_comma_before_token has comma before next token - * @param[out] has_colon_before_token has colon before next token */ - __device__ inline void parse_next_token_and_set_current(bool& has_comma_before_token, - bool& has_colon_before_token) + __device__ inline void parse_next_token_and_set_current( + bool& has_comma_before_token, + bool& has_colon_before_token, + bool& matched_field_name, + char_range to_match_field_name = char_range::null()) { skip_whitespaces(); if (!eof()) { @@ -1264,7 +1233,7 @@ class json_parser { current_token = json_token::END_OBJECT; } else { // parse key in key:value pair - parse_field_name_and_set_current(); + parse_field_name_and_set_current(matched_field_name, to_match_field_name); } } else if (current_token == json_token::FIELD_NAME) { if (c == ':') { @@ -1289,7 +1258,7 @@ class json_parser { // parse next key:value pair curr_pos++; skip_whitespaces(); - parse_field_name_and_set_current(); + parse_field_name_and_set_current(matched_field_name, to_match_field_name); } else { set_current_error(); } @@ -1351,10 +1320,29 @@ class json_parser { // parse next token bool has_comma_before_token; // no-initialization 
because of do not care here bool has_colon_before_token; // no-initialization because of do not care here - parse_next_token_and_set_current(has_comma_before_token, has_colon_before_token); + bool matched_field_name; // no-initialization because of do not care here + parse_next_token_and_set_current( + has_comma_before_token, has_colon_before_token, matched_field_name); return current_token; } + /** + * Continute parsing the next token. If the token is a field name then check if it is + * matched with the given name. + */ + __device__ bool parse_next_token_with_matching(cudf::string_view to_match_field_name) + { + // parse next token + bool has_comma_before_token; // no-initialization because of do not care here + bool has_colon_before_token; // no-initialization because of do not care here + bool matched_field_name; + parse_next_token_and_set_current(has_comma_before_token, + has_colon_before_token, + matched_field_name, + char_range{to_match_field_name}); + return matched_field_name; + } + /** * get current token */ @@ -1573,31 +1561,6 @@ class json_parser { return 0; } - /** - * match field name string when current token is FIELD_NAME, - * return true if current token is FIELD_NAME and match successfully. - * return false otherwise, - */ - __device__ bool match_current_field_name(cudf::string_view name) const - { - return match_current_field_name(char_range(name)); - } - - /** - * match current field name - */ - __device__ bool match_current_field_name(char_range name) const - { - if (json_token::FIELD_NAME == current_token) { - char_range_reader reader(current_range()); - char_range_reader to_match(name); - auto [b, end_pos] = try_parse_string(reader, to_match, escape_style::UNESCAPED); - return b; - } else { - return false; - } - } - /** * copy current structure to destination. 
* return false if meets JSON format error, @@ -1648,7 +1611,9 @@ class json_parser { bool has_colon_before_token = false; // parse and get has_comma_before_token, has_colon_before_token - parse_next_token_and_set_current(has_comma_before_token, has_colon_before_token); + bool matched_field_name; // unused + parse_next_token_and_set_current( + has_comma_before_token, has_colon_before_token, matched_field_name); // check the JSON format if (current_token == json_token::ERROR) { return thrust::make_pair(false, 0); } diff --git a/src/main/cpp/src/json_utils.cu b/src/main/cpp/src/json_utils.cu new file mode 100644 index 0000000000..ef3d0db0f8 --- /dev/null +++ b/src/main/cpp/src/json_utils.cu @@ -0,0 +1,229 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace spark_rapids_jni { + +namespace detail { + +namespace { + +constexpr bool not_whitespace(cudf::char_utf8 ch) +{ + return ch != ' ' && ch != '\r' && ch != '\n' && ch != '\t'; +} + +constexpr bool can_be_delimiter(char c) +{ + // The character list below is from `json_reader_options.set_delimiter`. 
+ switch (c) { + case '{': + case '[': + case '}': + case ']': + case ',': + case ':': + case '"': + case '\'': + case '\\': + case ' ': + case '\t': + case '\r': return false; + default: return true; + } +} + +} // namespace + +std::tuple, char, std::unique_ptr> concat_json( + cudf::strings_column_view const& input, + bool nullify_invalid_rows, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + if (input.is_empty()) { + return {std::make_unique(0, stream, mr), + '\n', + std::make_unique( + rmm::device_uvector{0, stream, mr}, rmm::device_buffer{}, 0)}; + } + + auto const d_input_ptr = cudf::column_device_view::create(input.parent(), stream); + auto const default_mr = rmm::mr::get_current_device_resource(); + + // Check if the input rows are null, empty (containing only whitespaces), and invalid JSON. + // This will be used for masking out the null/empty/invalid input rows when doing string + // concatenation. + rmm::device_uvector is_valid_input(input.size(), stream, default_mr); + + // Check if the input rows are null, empty (containing only whitespaces), and may also check + // for invalid JSON strings. + // This will be returned to the caller to create null mask for the final output. + rmm::device_uvector should_be_nullified(input.size(), stream, mr); + + thrust::for_each( + rmm::exec_policy_nosync(stream), + thrust::make_counting_iterator(0L), + thrust::make_counting_iterator(input.size() * static_cast(cudf::detail::warp_size)), + [nullify_invalid_rows, + input = *d_input_ptr, + output = thrust::make_zip_iterator(thrust::make_tuple( + is_valid_input.begin(), should_be_nullified.begin()))] __device__(int64_t tidx) { + // Execute one warp per row to minimize thread divergence. 
+ if ((tidx % cudf::detail::warp_size) != 0) { return; } + auto const idx = tidx / cudf::detail::warp_size; + + if (input.is_null(idx)) { + output[idx] = thrust::make_tuple(false, true); + return; + } + + auto const d_str = input.element(idx); + auto const size = d_str.size_bytes(); + int i = 0; + char ch; + + // Skip the very first whitespace characters. + for (; i < size; ++i) { + ch = d_str[i]; + if (not_whitespace(ch)) { break; } + } + + auto const not_eol = i < size; + + // If the current row is not null or empty, it should start with `{`. Otherwise, we need to + // replace it by a null. This is necessary for libcudf's JSON reader to work. + // Note that if we want to support ARRAY schema, we need to check for `[` instead. + auto constexpr start_character = '{'; + if (not_eol && ch != start_character) { + output[idx] = thrust::make_tuple(false, nullify_invalid_rows); + return; + } + + output[idx] = thrust::make_tuple(not_eol, !not_eol); + }); + + auto constexpr num_levels = 256; + auto constexpr lower_level = std::numeric_limits::min(); + auto constexpr upper_level = std::numeric_limits::max(); + auto const num_chars = input.chars_size(stream); + + rmm::device_uvector histogram(num_levels, stream, default_mr); + thrust::uninitialized_fill( + rmm::exec_policy_nosync(stream), histogram.begin(), histogram.end(), 0); + + size_t temp_storage_bytes = 0; + cub::DeviceHistogram::HistogramEven(nullptr, + temp_storage_bytes, + input.chars_begin(stream), + histogram.begin(), + num_levels, + lower_level, + upper_level, + num_chars, + stream.value()); + rmm::device_buffer d_temp(temp_storage_bytes, stream); + cub::DeviceHistogram::HistogramEven(d_temp.data(), + temp_storage_bytes, + input.chars_begin(stream), + histogram.begin(), + num_levels, + lower_level, + upper_level, + num_chars, + stream.value()); + + auto const it = thrust::make_counting_iterator(0); + auto const zero_level_idx = -lower_level; // the bin storing count for character `\0` + auto const zero_level_it = 
it + zero_level_idx; + auto const end = it + num_levels; + + auto const first_zero_count_pos = + thrust::find_if(rmm::exec_policy_nosync(stream), + zero_level_it, // ignore the negative characters + end, + [zero_level_idx, counts = histogram.begin()] __device__(auto idx) -> bool { + auto const count = counts[idx]; + if (count > 0) { return false; } + auto const first_non_existing_char = static_cast(idx - zero_level_idx); + return can_be_delimiter(first_non_existing_char); + }); + + // This should never happen since the input should never cover the entire char range. + if (first_zero_count_pos == end) { + throw std::logic_error( + "Cannot find any character suitable as delimiter during joining json strings."); + } + auto const delimiter = static_cast(thrust::distance(zero_level_it, first_zero_count_pos)); + + auto [null_mask, null_count] = cudf::detail::valid_if( + is_valid_input.begin(), is_valid_input.end(), thrust::identity{}, stream, default_mr); + // If the null count doesn't change, just use the input column for concatenation. + auto const input_applied_null = + null_count == input.null_count() + ? cudf::column_view{} + : cudf::column_view{cudf::data_type{cudf::type_id::STRING}, + input.size(), + input.chars_begin(stream), + reinterpret_cast(null_mask.data()), + null_count, + input.offset(), + std::vector{input.offsets()}}; + + auto concat_strings = cudf::strings::detail::join_strings( + null_count == input.null_count() ? 
input : cudf::strings_column_view{input_applied_null}, + cudf::string_scalar(std::string(1, delimiter), true, stream, default_mr), + cudf::string_scalar("{}", true, stream, default_mr), + stream, + mr); + + return {std::move(concat_strings->release().data), + delimiter, + std::make_unique(std::move(should_be_nullified), rmm::device_buffer{}, 0)}; +} + +} // namespace detail + +std::tuple, char, std::unique_ptr> concat_json( + cudf::strings_column_view const& input, + bool nullify_invalid_rows, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + CUDF_FUNC_RANGE(); + return detail::concat_json(input, nullify_invalid_rows, stream, mr); +} + +} // namespace spark_rapids_jni diff --git a/src/main/cpp/src/json_utils.hpp b/src/main/cpp/src/json_utils.hpp new file mode 100644 index 0000000000..319c81103d --- /dev/null +++ b/src/main/cpp/src/json_utils.hpp @@ -0,0 +1,116 @@ +/* + * Copyright (c) 2023-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include + +#include +#include + +#include + +namespace spark_rapids_jni { + +/** + * @brief Extract a map column from the JSON strings given by an input strings column. 
+ */ +std::unique_ptr from_json_to_raw_map( + cudf::strings_column_view const& input, + bool normalize_single_quotes, + bool allow_leading_zeros, + bool allow_nonnumeric_numbers, + bool allow_unquoted_control, + rmm::cuda_stream_view stream = cudf::get_default_stream(), + rmm::device_async_resource_ref mr = cudf::get_current_device_resource()); + +/** + * @brief Parse JSON strings into a struct column followed by a given data schema. + * + * The data schema is specified as data arrays flattened by depth-first-search order. + */ +std::unique_ptr from_json_to_structs( + cudf::strings_column_view const& input, + std::vector const& col_names, + std::vector const& num_children, + std::vector const& types, + std::vector const& scales, + std::vector const& precisions, + bool normalize_single_quotes, + bool allow_leading_zeros, + bool allow_nonnumeric_numbers, + bool allow_unquoted_control, + bool is_us_locale, + rmm::cuda_stream_view stream = cudf::get_default_stream(), + rmm::device_async_resource_ref mr = cudf::get_current_device_resource()); + +/** + * @brief Convert from a strings column to a column with the desired type given by a data schema. + * + * The given column schema is specified as data arrays flattened by depth-first-search order. + */ +std::unique_ptr convert_from_strings( + cudf::strings_column_view const& input, + std::vector const& num_children, + std::vector const& types, + std::vector const& scales, + std::vector const& precisions, + bool allow_nonnumeric_numbers, + bool is_us_locale, + rmm::cuda_stream_view stream = cudf::get_default_stream(), + rmm::device_async_resource_ref mr = cudf::get_current_device_resource()); + +/** + * @brief Remove quotes from each string in the given strings column. + * + * If `nullify_if_not_quoted` is true, an input string that is not quoted will result in a null. + * Otherwise, the output will be the same as the unquoted input. 
+ */ +std::unique_ptr remove_quotes( + cudf::strings_column_view const& input, + bool nullify_if_not_quoted, + rmm::cuda_stream_view stream = cudf::get_default_stream(), + rmm::device_async_resource_ref mr = cudf::get_current_device_resource()); + +/** + * @brief Concatenate the JSON objects given by a strings column into one single character buffer, + * in which each JSON objects is delimited by a special character that does not exist in the input. + * + * Beyond returning the concatenated buffer with delimiter, the function also returns a BOOL8 + * column indicating which rows should be nullified after parsing the concatenated buffer. Each + * row of this column is a `true` value if the corresponding input row is either empty, containing + * only whitespaces, or invalid JSON object depending on the `nullify_invalid_rows` parameter. + * + * Note that an invalid JSON object in this context is a string that does not start with the `{` + * character after whitespaces. + * + * @param input The strings column containing input JSON objects + * @param nullify_invalid_rows Whether to nullify rows containing invalid JSON objects + * @param stream The CUDA stream used for device memory operations and kernel launches + * @param mr Device memory resource used to allocate device memory of the table in the returned + * @return A tuple containing the concatenated JSON objects as a single buffer, the delimiter + * character, and a BOOL8 column indicating which rows should be nullified after parsing + * the concatenated buffer + */ +std::tuple, char, std::unique_ptr> concat_json( + cudf::strings_column_view const& input, + bool nullify_invalid_rows = false, + rmm::cuda_stream_view stream = cudf::get_default_stream(), + rmm::device_async_resource_ref mr = cudf::get_current_device_resource()); + +} // namespace spark_rapids_jni diff --git a/src/main/cpp/tests/CMakeLists.txt b/src/main/cpp/tests/CMakeLists.txt index 244d18c903..c774d30618 100644 --- 
a/src/main/cpp/tests/CMakeLists.txt +++ b/src/main/cpp/tests/CMakeLists.txt @@ -31,7 +31,8 @@ function(ConfigureTest CMAKE_TEST_NAME) INSTALL_RPATH "\$ORIGIN/../../../lib" ) target_link_libraries(${CMAKE_TEST_NAME} GTest::gtest_main GTest::gmock_main cudf::cudf - cudf::cudftestutil spark_rapids_jni) + cudf::cudftestutil cudf::cudftestutil_impl + spark_rapids_jni) add_test(NAME ${CMAKE_TEST_NAME} COMMAND ${CMAKE_TEST_NAME}) install( TARGETS ${CMAKE_TEST_NAME} diff --git a/src/main/cpp/tests/cast_string.cpp b/src/main/cpp/tests/cast_string.cpp index efad547869..862469b0af 100644 --- a/src/main/cpp/tests/cast_string.cpp +++ b/src/main/cpp/tests/cast_string.cpp @@ -43,13 +43,13 @@ TYPED_TEST_SUITE(StringToFloatTests, cudf::test::FloatingPointTypes); TYPED_TEST(StringToIntegerTests, Simple) { - auto const strings = test::strings_column_wrapper{"1", "0", "42"}; + auto const strings = test::strings_column_wrapper{"1", "0", "42", "null"}; strings_column_view scv{strings}; auto const result = spark_rapids_jni::string_to_integer( data_type{type_to_id()}, scv, false, true, cudf::get_default_stream()); - test::fixed_width_column_wrapper expected({1, 0, 42}, {1, 1, 1}); + test::fixed_width_column_wrapper expected({1, 0, 42, 0}, {1, 1, 1, 0}); CUDF_TEST_EXPECT_COLUMNS_EQUAL(result->view(), expected); } @@ -252,6 +252,24 @@ TYPED_TEST(StringToIntegerTests, Empty) EXPECT_EQ(result->type().id(), type_to_id()); } +TYPED_TEST(StringToIntegerTests, NonEmptyNulls) +{ + auto const strings = test::strings_column_wrapper{"123", "123", "123", "123"}.release(); + auto const valids = std::vector{true, false, true, false}; + auto [null_mask, null_count] = cudf::test::detail::make_null_mask(valids.begin(), valids.end()); + strings->set_null_mask(null_mask, null_count); + + auto const scv = strings_column_view{*strings}; + EXPECT_EQ(scv.chars_size(cudf::get_default_stream()), 12); // make sure it has non-empty null. 
+ + auto const result = spark_rapids_jni::string_to_integer( + data_type{type_to_id()}, scv, false, true, cudf::get_default_stream()); + + test::fixed_width_column_wrapper expected({123, 0, 123, 0}, {1, 0, 1, 0}); + + CUDF_TEST_EXPECT_COLUMNS_EQUAL(result->view(), expected); +} + TEST_F(StringToDecimalTests, Simple) { auto const strings = test::strings_column_wrapper({"1", "0", "-1"}); @@ -552,6 +570,26 @@ TEST_F(StringToDecimalTests, Empty) EXPECT_EQ(result->type().scale(), 2); } +TEST_F(StringToDecimalTests, NonEmptyNulls) +{ + auto const strings = + test::strings_column_wrapper{"1.23456", "1.23456", "1.23456", "1.23456"}.release(); + auto const valids = std::vector{true, false, true, false}; + auto [null_mask, null_count] = cudf::test::detail::make_null_mask(valids.begin(), valids.end()); + strings->set_null_mask(null_mask, null_count); + + auto const scv = strings_column_view{*strings}; + EXPECT_EQ(scv.chars_size(cudf::get_default_stream()), 28); // make sure it has non-empty null. + + auto const result = + spark_rapids_jni::string_to_decimal(6, -5, scv, false, true, cudf::get_default_stream()); + + test::fixed_point_column_wrapper expected( + {123456, 0, 123456, 0}, {1, 0, 1, 0}, numeric::scale_type{-5}); + + CUDF_TEST_EXPECT_COLUMNS_EQUAL(result->view(), expected); +} + TYPED_TEST(StringToFloatTests, Simple) { cudf::test::strings_column_wrapper in{"-1.8946e-10", @@ -707,3 +745,21 @@ TYPED_TEST(StringToFloatTests, Empty) EXPECT_EQ(result->size(), 0); } + +TYPED_TEST(StringToFloatTests, NonEmptyNulls) +{ + auto const strings = test::strings_column_wrapper{"1.23", "1.23", "1.23", "1.23"}.release(); + auto const valids = std::vector{true, false, true, false}; + auto [null_mask, null_count] = cudf::test::detail::make_null_mask(valids.begin(), valids.end()); + strings->set_null_mask(null_mask, null_count); + + auto const scv = strings_column_view{*strings}; + EXPECT_EQ(scv.chars_size(cudf::get_default_stream()), 16); // make sure it has non-empty null. 
+ + auto const result = spark_rapids_jni::string_to_float( + data_type{type_to_id()}, scv, false, cudf::get_default_stream()); + + test::fixed_width_column_wrapper expected({1.23, 0.0, 1.23, 0.0}, {1, 0, 1, 0}); + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(result->view(), expected); +} diff --git a/src/main/java/com/nvidia/spark/rapids/jni/Arms.java b/src/main/java/com/nvidia/spark/rapids/jni/Arms.java new file mode 100644 index 0000000000..4b6ecf7204 --- /dev/null +++ b/src/main/java/com/nvidia/spark/rapids/jni/Arms.java @@ -0,0 +1,101 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.jni; + +import java.util.Arrays; +import java.util.Collection; +import java.util.Iterator; +import java.util.function.Function; + +/** + * This class contains utility methods for automatic resource management. + */ +public class Arms { + /** + * This method close the resource if an exception is thrown while executing the function. + */ + public static T closeIfException(R resource, Function function) { + try { + return function.apply(resource); + } catch (Exception e) { + if (resource != null) { + try { + resource.close(); + } catch (Exception inner) { + e.addSuppressed(inner); + } + } + throw e; + } + } + + /** + * This method safely closes all the resources. + *

+ * This method will iterate through all the resources and closes them. If any exception happened during the + * traversal, exception will be captured and rethrown after all resources closed. + *

+ */ + public static void closeAll(Iterator resources) { + Throwable t = null; + while (resources.hasNext()) { + try { + R resource = resources.next(); + if (resource != null) { + resource.close(); + } + } catch (Exception e) { + if (t == null) { + t = e; + } else { + t.addSuppressed(e); + } + } + } + + if (t != null) throw new RuntimeException(t); + } + + + /** + * This method safely closes all the resources. See {@link #closeAll(Iterator)} for more details. + */ + public static void closeAll(R... resources) { + closeAll(Arrays.asList(resources)); + } + + /** + * This method safely closes the resources. See {@link #closeAll(Iterator)} for more details. + */ + public static void closeAll(Collection resources) { + closeAll(resources.iterator()); + } + + /** + * This method safely closes the resources after applying the function. + *
+ * See {@link #closeAll(Iterator)} for more details. + */ + public static , V> V withResource( + C resource, Function function) { + try { + return function.apply(resource); + } finally { + closeAll(resource); + } + } +} diff --git a/src/main/java/com/nvidia/spark/rapids/jni/GpuTimeZoneDB.java b/src/main/java/com/nvidia/spark/rapids/jni/GpuTimeZoneDB.java index 643db278df..a8048b1e8b 100644 --- a/src/main/java/com/nvidia/spark/rapids/jni/GpuTimeZoneDB.java +++ b/src/main/java/com/nvidia/spark/rapids/jni/GpuTimeZoneDB.java @@ -52,33 +52,10 @@ public class GpuTimeZoneDB { // For the timezone database, we store the transitions in a ColumnVector that is a list of // structs. The type of this column vector is: // LIST> - private Map zoneIdToTable; + private static Map zoneIdToTable; // use this reference to indicate if time zone cache is initialized. - private HostColumnVector fixedTransitions; - - // Guarantee singleton instance - private GpuTimeZoneDB() { - } - - // singleton instance - private static final GpuTimeZoneDB instance = new GpuTimeZoneDB(); - - // This method is default visibility for testing purposes only. - // The instance will be never be exposed publicly for this class. - static GpuTimeZoneDB getInstance() { - return instance; - } - - static class LoadingLock { - Boolean isLoading = false; - - // record whether a shutdown is called ever. - // if `isCloseCalledEver` is true, then the following loading should be skipped. - Boolean isShutdownCalledEver = false; - } - - private static final LoadingLock lock = new LoadingLock(); + private static HostColumnVector fixedTransitions; /** * This should be called on startup of an executor. @@ -86,33 +63,12 @@ static class LoadingLock { * If `shutdown` was called ever, then will not load the cache */ public static void cacheDatabaseAsync() { - synchronized (lock) { - if (lock.isShutdownCalledEver) { - // shutdown was called ever, will never load cache again. 
- return; - } - - if (lock.isLoading) { - // another thread is loading(), return - return; - } else { - lock.isLoading = true; - } - } - // start a new thread to load Runnable runnable = () -> { try { - instance.cacheDatabaseImpl(); + cacheDatabaseImpl(); } catch (Exception e) { log.error("cache time zone transitions cache failed", e); - } finally { - synchronized (lock) { - // now loading is done - lock.isLoading = false; - // `cacheDatabase` and `shutdown` may wait loading is done. - lock.notify(); - } } }; Thread thread = Executors.defaultThreadFactory().newThread(runnable); @@ -127,55 +83,17 @@ public static void cacheDatabaseAsync() { * If cache is exits, do not load cache again. */ public static void cacheDatabase() { - synchronized (lock) { - if (lock.isLoading) { - // another thread is loading(), wait loading is done - while (lock.isLoading) { - try { - lock.wait(); - } catch (InterruptedException e) { - throw new IllegalStateException("cache time zone transitions cache failed", e); - } - } - return; - } else { - lock.isLoading = true; - } - } - - try { - instance.cacheDatabaseImpl(); - } finally { - // loading is done. - synchronized (lock) { - lock.isLoading = false; - // `cacheDatabase` and/or `shutdown` may wait loading is done. - lock.notify(); - } - } + cacheDatabaseImpl(); } /** * close the cache, used when Plugin is closing */ - public static void shutdown() { - synchronized (lock) { - lock.isShutdownCalledEver = true; - while (lock.isLoading) { - // wait until loading is done - try { - lock.wait(); - } catch (InterruptedException e) { - throw new IllegalStateException("shutdown time zone transitions cache failed", e); - } - } - instance.shutdownImpl(); - // `cacheDatabase` and/or `shutdown` may wait loading is done. 
- lock.notify(); - } + public static synchronized void shutdown() { + closeResources(); } - private void cacheDatabaseImpl() { + private static synchronized void cacheDatabaseImpl() { if (fixedTransitions == null) { try { loadData(); @@ -186,11 +104,7 @@ private void cacheDatabaseImpl() { } } - private void shutdownImpl() { - closeResources(); - } - - private void closeResources() { + private static synchronized void closeResources() { if (zoneIdToTable != null) { zoneIdToTable.clear(); zoneIdToTable = null; @@ -208,9 +122,12 @@ public static ColumnVector fromTimestampToUtcTimestamp(ColumnVector input, ZoneI throw new IllegalArgumentException(String.format("Unsupported timezone: %s", currentTimeZone.toString())); } + // there is technically a race condition on shutdown. Shutdown could be called after + // the database is cached. This would result in a null pointer exception at some point + // in the processing. This should be rare enough that it is not a big deal. cacheDatabase(); - Integer tzIndex = instance.getZoneIDMap().get(currentTimeZone.normalized().toString()); - try (Table transitions = instance.getTransitions()) { + Integer tzIndex = zoneIdToTable.get(currentTimeZone.normalized().toString()); + try (Table transitions = getTransitions()) { return new ColumnVector(convertTimestampColumnToUTC(input.getNativeView(), transitions.getNativeView(), tzIndex)); } @@ -223,9 +140,12 @@ public static ColumnVector fromUtcTimestampToTimestamp(ColumnVector input, ZoneI throw new IllegalArgumentException(String.format("Unsupported timezone: %s", desiredTimeZone.toString())); } + // there is technically a race condition on shutdown. Shutdown could be called after + // the database is cached. This would result in a null pointer exception at some point + // in the processing. This should be rare enough that it is not a big deal. 
cacheDatabase(); - Integer tzIndex = instance.getZoneIDMap().get(desiredTimeZone.normalized().toString()); - try (Table transitions = instance.getTransitions()) { + Integer tzIndex = zoneIdToTable.get(desiredTimeZone.normalized().toString()); + try (Table transitions = getTransitions()) { return new ColumnVector(convertUTCTimestampColumnToTimeZone(input.getNativeView(), transitions.getNativeView(), tzIndex)); } @@ -258,7 +178,7 @@ public static ZoneId getZoneId(String timeZoneId) { } @SuppressWarnings("unchecked") - private void loadData() { + private static synchronized void loadData() { try { List> masterTransitions = new ArrayList<>(); zoneIdToTable = new HashMap<>(); @@ -334,17 +254,13 @@ private void loadData() { } } - private Map getZoneIDMap() { - return zoneIdToTable; - } - - private Table getTransitions() { + private static synchronized Table getTransitions() { try (ColumnVector fixedTransitions = getFixedTransitions()) { return new Table(fixedTransitions); } } - private ColumnVector getFixedTransitions() { + private static synchronized ColumnVector getFixedTransitions() { return fixedTransitions.copyToDevice(); } @@ -358,16 +274,15 @@ private ColumnVector getFixedTransitions() { * @param zoneId * @return list of fixed transitions */ - List getHostFixedTransitions(String zoneId) { + static synchronized List getHostFixedTransitions(String zoneId) { zoneId = ZoneId.of(zoneId).normalized().toString(); // we use the normalized form to dedupe - Integer idx = getZoneIDMap().get(zoneId); + Integer idx = zoneIdToTable.get(zoneId); if (idx == null) { return null; } return fixedTransitions.getList(idx); } - private static native long convertTimestampColumnToUTC(long input, long transitions, int tzIndex); private static native long convertUTCTimestampColumnToTimeZone(long input, long transitions, int tzIndex); diff --git a/src/main/java/com/nvidia/spark/rapids/jni/Hash.java b/src/main/java/com/nvidia/spark/rapids/jni/Hash.java index a25fead0fd..16971c5bdb 100644 --- 
a/src/main/java/com/nvidia/spark/rapids/jni/Hash.java +++ b/src/main/java/com/nvidia/spark/rapids/jni/Hash.java @@ -96,7 +96,6 @@ public static ColumnVector hiveHash(ColumnView columns[]) { assert columns[i] != null : "Column vectors passed may not be null"; assert columns[i].getRowCount() == size : "Row count mismatch, all columns must be the same size"; assert !columns[i].getType().isDurationType() : "Unsupported column type Duration"; - assert !columns[i].getType().isNestedType() : "Unsupported column type Nested"; columnViews[i] = columns[i].getNativeView(); } return new ColumnVector(hiveHash(columnViews)); diff --git a/src/main/java/com/nvidia/spark/rapids/jni/HostTable.java b/src/main/java/com/nvidia/spark/rapids/jni/HostTable.java new file mode 100644 index 0000000000..754412d727 --- /dev/null +++ b/src/main/java/com/nvidia/spark/rapids/jni/HostTable.java @@ -0,0 +1,190 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.jni; + +import ai.rapids.cudf.ColumnVector; +import ai.rapids.cudf.Cuda; +import ai.rapids.cudf.DeviceMemoryBuffer; +import ai.rapids.cudf.HostMemoryBuffer; +import ai.rapids.cudf.NativeDepsLoader; +import ai.rapids.cudf.Table; + +/** + * Represents a cudf Table but in host memory instead of device memory. + * Table is tracked in native code as a host_table_view. 
+ */ +public class HostTable implements AutoCloseable { + + static { + NativeDepsLoader.loadNativeDeps(); + } + + private long nativeTableView; + private HostMemoryBuffer hostBuffer; + + /** + * Copies a device table to a host table asynchronously. + * NOTE: The caller must synchronize on the stream before examining the data on the host. + * @param table device table to copy + * @param stream stream to use for the copy + * @return host table + */ + public static HostTable fromTableAsync(Table table, Cuda.Stream stream) { + long size = bufferSize(table.getNativeView(), stream.getStream()); + long tableHandle = 0; + HostMemoryBuffer hostBuffer = HostMemoryBuffer.allocate(size); + try { + tableHandle = copyFromTableAsync(table.getNativeView(), + hostBuffer.getAddress(), hostBuffer.getLength(), stream.getStream()); + } catch (Throwable t) { + try { + hostBuffer.close(); + } catch (Throwable t2) { + t.addSuppressed(t2); + } + throw t; + } + return new HostTable(tableHandle, hostBuffer); + } + + /** + * Copies a device table to a host table synchronously. + * @param table device table to copy + * @param stream stream to use for the copy + * @return host table + */ + public static HostTable fromTable(Table table, Cuda.Stream stream) { + HostTable hostTable = fromTableAsync(table, stream); + stream.sync(); + return hostTable; + } + + /** + * Copies a device table to a host table synchronously on the default stream. + * @param table device table to copy + * @return host table + */ + public static HostTable fromTable(Table table) { + return fromTable(table, Cuda.DEFAULT_STREAM); + } + + private HostTable(long tableHandle, HostMemoryBuffer hostBuffer) { + this.nativeTableView = tableHandle; + this.hostBuffer = hostBuffer; + } + + /** + * Gets the address of the host_table_view for this host table. + * NOTE: This is only valid as long as the HostTable instance is valid. 
+ */ + public long getNativeTableView() { + return nativeTableView; + } + + /** + * Gets the host memory buffer containing the data for this host table. + */ + public HostMemoryBuffer getHostBuffer() { + return hostBuffer; + } + + /** + * Copies the host table to a device table asynchronously. + * NOTE: The caller must synchronize on the stream before closing this instance, + * or the copy could still be in-flight when the host memory is invalidated or reused. + * @param stream stream to use for the copy + * @return device table + */ + public Table toTableAsync(Cuda.Stream stream) { + long size = hostBuffer.getLength(); + Table table = null; + try (DeviceMemoryBuffer devBuffer = DeviceMemoryBuffer.allocate(size, stream)) { + devBuffer.copyFromHostBufferAsync(hostBuffer, stream); + long hostToDevPtrOffset = devBuffer.getAddress() - hostBuffer.getAddress(); + long[] columnViewHandles = toDeviceColumnViews(nativeTableView, hostToDevPtrOffset); + ColumnVector[] columns = new ColumnVector[columnViewHandles.length]; + boolean done = false; + try { + for (int i = 0; i < columnViewHandles.length; i++) { + columns[i] = ColumnVector.fromViewWithContiguousAllocation(columnViewHandles[i], devBuffer); + columnViewHandles[i] = 0; + } + table = new Table(columns); + // Need to synchronize before returning to ensure host copy completed, otherwise caller may + // free and reuse the host buffer before device copy completes. + stream.sync(); + done = true; + } finally { + // always close columns because Table incremented refcounts + for (ColumnVector c : columns) { + if (c != null) { + c.close(); + } + } + if (!done) { + for (long viewHandle : columnViewHandles) { + if (viewHandle != 0) { + freeDeviceColumnView(viewHandle); + } + } + } + } + } + return table; + } + + /** + * Copies the host table to a device table synchronously. 
+ * @param stream stream to use for the copy + * @return device table + */ + public Table toTable(Cuda.Stream stream) { + Table table = toTableAsync(stream); + stream.sync(); + return table; + } + + /** + * Copies the host table to a device table synchronously on the default stream. + * @return device table + */ + public Table toTable() { + return toTable(Cuda.DEFAULT_STREAM); + } + + @Override + public void close() { + try { + freeHostTable(nativeTableView); + } finally { + nativeTableView = 0; + hostBuffer.close(); + hostBuffer = null; + } + } + + private static native long bufferSize(long tableHandle, long stream); + + private static native long copyFromTableAsync(long tableHandle, long hostAddress, long hostSize, + long stream); + + private static native long[] toDeviceColumnViews(long tableHandle, long hostToDevPtrOffset); + + private static native void freeDeviceColumnView(long columnHandle); + + private static native void freeHostTable(long tableHandle); +} diff --git a/src/main/java/com/nvidia/spark/rapids/jni/JSONUtils.java b/src/main/java/com/nvidia/spark/rapids/jni/JSONUtils.java index 3a7c4a6a53..9cf00acff2 100644 --- a/src/main/java/com/nvidia/spark/rapids/jni/JSONUtils.java +++ b/src/main/java/com/nvidia/spark/rapids/jni/JSONUtils.java @@ -141,25 +141,104 @@ public static ColumnVector[] getJsonObjectMultiplePaths(ColumnVector input, /** * Extract key-value pairs for each output map from the given json strings. These key-value are * copied directly as substrings of the input without any type conversion. - *

+ *

* Since there is not any validity check, the output of this function may be different from * what generated by Spark's `from_json` function. Situations that can lead to * different/incorrect outputs may include:
* - The value in the input json string is invalid, such as 'abc' value for an integer key.
* - The value string can be non-clean format for floating-point type, such as '1.00000'. - *

+ *

* The output of these situations should all be NULL or a value '1.0', respectively. However, this * function will just simply copy the input value strings to the output. * * @param input The input strings column in which each row specifies a json object + * @param opts The options for parsing JSON strings * @return A map column (i.e., a column of type {@code List>}) in - * which the key-value pairs are extracted directly from the input json strings + * which the key-value pairs are extracted directly from the input json strings */ - public static ColumnVector extractRawMapFromJsonString(ColumnView input) { + public static ColumnVector extractRawMapFromJsonString(ColumnView input, JSONOptions opts) { assert (input.getType().equals(DType.STRING)) : "Input must be of STRING type"; - return new ColumnVector(extractRawMapFromJsonString(input.getNativeView())); + return new ColumnVector(extractRawMapFromJsonString(input.getNativeView(), + opts.isNormalizeSingleQuotes(), + opts.leadingZerosAllowed(), + opts.nonNumericNumbersAllowed(), + opts.unquotedControlChars())); } + /** + * Parse a JSON string into a struct column following by the given data schema. + *

+ * Many JSON options in the given {@code opts} parameter are ignored from passing down to the + * native code. That is because these options are hard-coded with the same values in both the + * plugin code and native code. Specifically:
+ * - {@code RecoverWithNull: true}
+ * - {@code MixedTypesAsStrings: true}
+ * - {@code NormalizeWhitespace: true}
+ * - {@code KeepQuotes: true}
+ * - {@code StrictValidation: true}
+ * - {@code Experimental: true} + * + * @param input The input strings column in which each row specifies a json object + * @param schema The schema of the output struct column + * @param opts The options for parsing JSON strings + * @param isUSLocale Whether the current local is US locale, used when converting strings to + * decimal types + * @return A struct column in which each row is parsed from the corresponding json string + */ + public static ColumnVector fromJSONToStructs(ColumnView input, Schema schema, JSONOptions opts, + boolean isUSLocale) { + assert (input.getType().equals(DType.STRING)) : "Input must be of STRING type"; + return new ColumnVector(fromJSONToStructs(input.getNativeView(), + schema.getFlattenedColumnNames(), + schema.getFlattenedNumChildren(), + schema.getFlattenedTypeIds(), + schema.getFlattenedTypeScales(), + schema.getFlattenedDecimalPrecisions(), + opts.isNormalizeSingleQuotes(), + opts.leadingZerosAllowed(), + opts.nonNumericNumbersAllowed(), + opts.unquotedControlChars(), + isUSLocale)); + } + + /** + * Convert from a strings column to a column with the desired type given by a data schema. 
+ * + * @param input The input strings column + * @param schema The schema of the output column + * @param allowedNonNumericNumbers Whether non-numeric numbers are allowed, used when converting + * strings to float types + * @param isUSLocale Whether the current local is US locale, used when converting strings to + * decimal types + * @return A column with the desired data type + */ + public static ColumnVector convertFromStrings(ColumnView input, Schema schema, + boolean allowedNonNumericNumbers, + boolean isUSLocale) { + assert (input.getType().equals(DType.STRING)) : "Input must be of STRING type"; + return new ColumnVector(convertFromStrings(input.getNativeView(), + schema.getFlattenedNumChildren(), + schema.getFlattenedTypeIds(), + schema.getFlattenedTypeScales(), + schema.getFlattenedDecimalPrecisions(), + allowedNonNumericNumbers, + isUSLocale)); + } + + /** + * Remove quotes from each string in the given strings column. + *

+ * If `nullifyIfNotQuoted` is true, an input string that is not quoted will result in a null. + * Otherwise, the output will be the same as the unquoted input. + * + * @param input The input strings column + * @param nullifyIfNotQuoted Whether to output a null row if the input string is not quoted + * @return A strings column in which quotes are removed from all strings + */ + public static ColumnVector removeQuotes(ColumnView input, boolean nullifyIfNotQuoted) { + assert (input.getType().equals(DType.STRING)) : "Input must be of STRING type"; + return new ColumnVector(removeQuotes(input.getNativeView(), nullifyIfNotQuoted)); + } private static native int getMaxJSONPathDepth(); @@ -176,6 +255,31 @@ private static native long[] getJsonObjectMultiplePaths(long input, long memoryBudgetBytes, int parallelOverride); + private static native long extractRawMapFromJsonString(long input, + boolean normalizeSingleQuotes, + boolean leadingZerosAllowed, + boolean nonNumericNumbersAllowed, + boolean unquotedControlChars); + + private static native long fromJSONToStructs(long input, + String[] names, + int[] numChildren, + int[] typeIds, + int[] typeScales, + int[] typePrecision, + boolean normalizeSingleQuotes, + boolean leadingZerosAllowed, + boolean nonNumericNumbersAllowed, + boolean unquotedControlChars, + boolean isUSLocale); + + private static native long convertFromStrings(long input, + int[] numChildren, + int[] typeIds, + int[] typeScales, + int[] typePrecision, + boolean nonNumericNumbersAllowed, + boolean isUSLocale); - private static native long extractRawMapFromJsonString(long input); + private static native long removeQuotes(long input, boolean nullifyIfNotQuoted); } diff --git a/src/main/cpp/src/from_json.hpp b/src/main/java/com/nvidia/spark/rapids/jni/Pair.java similarity index 51% rename from src/main/cpp/src/from_json.hpp rename to src/main/java/com/nvidia/spark/rapids/jni/Pair.java index 75fc3bc103..ac8aa1910c 100644 --- a/src/main/cpp/src/from_json.hpp +++ 
b/src/main/java/com/nvidia/spark/rapids/jni/Pair.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. + * Copyright (c) 2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,21 +14,29 @@ * limitations under the License. */ -#pragma once +package com.nvidia.spark.rapids.jni; -#include -#include - -#include -#include +/** + * A utility class for holding a pair of values. + */ +public class Pair { + private final K left; + private final V right; -#include + public Pair(K left, V right) { + this.left = left; + this.right = right; + } -namespace spark_rapids_jni { + public K getLeft() { + return left; + } -std::unique_ptr from_json_to_raw_map( - cudf::strings_column_view const& input, - rmm::cuda_stream_view stream = cudf::get_default_stream(), - rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource()); + public V getRight() { + return right; + } -} // namespace spark_rapids_jni + public static Pair of(K left, V right) { + return new Pair<>(left, right); + } +} diff --git a/src/main/java/com/nvidia/spark/rapids/jni/Preconditions.java b/src/main/java/com/nvidia/spark/rapids/jni/Preconditions.java new file mode 100644 index 0000000000..a6dfcdb104 --- /dev/null +++ b/src/main/java/com/nvidia/spark/rapids/jni/Preconditions.java @@ -0,0 +1,70 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.jni; + +import java.util.function.Supplier; + +/** + * This class contains utility methods for checking preconditions. + */ +public class Preconditions { + /** + * Check if the condition is true, otherwise throw an IllegalStateException with the given message. + */ + public static void ensure(boolean condition, String message) { + if (!condition) { + throw new IllegalStateException(message); + } + } + + /** + * Check if the condition is true, otherwise throw an IllegalStateException with the given message supplier. + */ + public static void ensure(boolean condition, Supplier messageSupplier) { + if (!condition) { + throw new IllegalStateException(messageSupplier.get()); + } + } + + /** + * Check if the value is non-negative, otherwise throw an IllegalArgumentException with the given message. + * @param value the value to check + * @param name the name of the value + * @return the value if it is non-negative + * @throws IllegalArgumentException if the value is negative + */ + public static int ensureNonNegative(int value, String name) { + if (value < 0) { + throw new IllegalArgumentException(name + " must be non-negative, but was " + value); + } + return value; + } + + /** + * Check if the value is non-negative, otherwise throw an IllegalArgumentException with the given message. 
+ * @param value the value to check + * @param name the name of the value + * @return the value if it is non-negative + * @throws IllegalArgumentException if the value is negative + */ + public static long ensureNonNegative(long value, String name) { + if (value < 0) { + throw new IllegalArgumentException(name + " must be non-negative, but was " + value); + } + return value; + } +} diff --git a/src/main/java/com/nvidia/spark/rapids/jni/kudo/ColumnOffsetInfo.java b/src/main/java/com/nvidia/spark/rapids/jni/kudo/ColumnOffsetInfo.java new file mode 100644 index 0000000000..3a46806a78 --- /dev/null +++ b/src/main/java/com/nvidia/spark/rapids/jni/kudo/ColumnOffsetInfo.java @@ -0,0 +1,121 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.jni.kudo; + +import ai.rapids.cudf.DeviceMemoryBufferView; + +import static com.nvidia.spark.rapids.jni.Preconditions.ensureNonNegative; + +/** + * This class is used to store the offsets of the buffer of a column in the serialized data. 
+ */ +class ColumnOffsetInfo { + static final long INVALID_OFFSET = -1L; + private final long validity; + private final long validityBufferLen; + private final long offset; + private final long offsetBufferLen; + private final long data; + private final long dataBufferLen; + + public ColumnOffsetInfo(long validity, long validityBufferLen, long offset, long offsetBufferLen, long data, + long dataBufferLen) { + ensureNonNegative(validityBufferLen, "validityBuffeLen"); + ensureNonNegative(offsetBufferLen, "offsetBufferLen"); + ensureNonNegative(dataBufferLen, "dataBufferLen"); + this.validity = validity; + this.validityBufferLen = validityBufferLen; + this.offset = offset; + this.offsetBufferLen = offsetBufferLen; + this.data = data; + this.dataBufferLen = dataBufferLen; + } + + /** + * Get the validity buffer offset. + * @return {@value #INVALID_OFFSET} if the validity buffer is not present, otherwise the offset. + */ + long getValidity() { + return validity; + } + + /** + * Get a view of the validity buffer from underlying buffer. + * @param baseAddress the base address of underlying buffer. + * @return null if the validity buffer is not present, otherwise a view of the buffer. + */ + DeviceMemoryBufferView getValidityBuffer(long baseAddress) { + if (validity == INVALID_OFFSET) { + return null; + } + return new DeviceMemoryBufferView(validity + baseAddress, validityBufferLen); + } + + /** + * Get the offset buffer offset. + * @return {@value #INVALID_OFFSET} if the offset buffer is not present, otherwise the offset. + */ + long getOffset() { + return offset; + } + + /** + * Get a view of the offset buffer from underlying buffer. + * @param baseAddress the base address of underlying buffer. + * @return null if the offset buffer is not present, otherwise a view of the buffer. 
+ */ + DeviceMemoryBufferView getOffsetBuffer(long baseAddress) { + if (offset == INVALID_OFFSET) { + return null; + } + return new DeviceMemoryBufferView(offset + baseAddress, offsetBufferLen); + } + + /** + * Get the data buffer offset. + * @return {@value #INVALID_OFFSET} if the data buffer is not present, otherwise the offset. + */ + long getData() { + return data; + } + + /** + * Get a view of the data buffer from underlying buffer. + * @param baseAddress the base address of underlying buffer. + * @return null if the data buffer is not present, otherwise a view of the buffer. + */ + DeviceMemoryBufferView getDataBuffer(long baseAddress) { + if (data == INVALID_OFFSET) { + return null; + } + return new DeviceMemoryBufferView(data + baseAddress, dataBufferLen); + } + + long getDataBufferLen() { + return dataBufferLen; + } + + @Override + public String toString() { + return "ColumnOffsets{" + + "validity=" + validity + + ", offset=" + offset + + ", data=" + data + + ", dataLen=" + dataBufferLen + + '}'; + } +} diff --git a/src/main/java/com/nvidia/spark/rapids/jni/kudo/ColumnViewInfo.java b/src/main/java/com/nvidia/spark/rapids/jni/kudo/ColumnViewInfo.java new file mode 100644 index 0000000000..002dff54c0 --- /dev/null +++ b/src/main/java/com/nvidia/spark/rapids/jni/kudo/ColumnViewInfo.java @@ -0,0 +1,66 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.nvidia.spark.rapids.jni.kudo; + +import ai.rapids.cudf.*; + +import java.util.Optional; + +import static com.nvidia.spark.rapids.jni.Preconditions.ensureNonNegative; + +class ColumnViewInfo { + private final DType dtype; + private final ColumnOffsetInfo offsetInfo; + private final int nullCount; + private final int rowCount; + + public ColumnViewInfo(DType dtype, ColumnOffsetInfo offsetInfo, + int nullCount, int rowCount) { + ensureNonNegative(nullCount, "nullCount"); + ensureNonNegative(rowCount, "rowCount"); + this.dtype = dtype; + this.offsetInfo = offsetInfo; + this.nullCount = nullCount; + this.rowCount = rowCount; + } + + ColumnView buildColumnView(DeviceMemoryBuffer buffer, ColumnView[] childrenView) { + long baseAddress = buffer.getAddress(); + + if (dtype.isNestedType()) { + return new ColumnView(dtype, rowCount, Optional.of((long)nullCount), + offsetInfo.getValidityBuffer(baseAddress), + offsetInfo.getOffsetBuffer(baseAddress), + childrenView); + } else { + return new ColumnView(dtype, rowCount, Optional.of((long)nullCount), + offsetInfo.getDataBuffer(baseAddress), + offsetInfo.getValidityBuffer(baseAddress), + offsetInfo.getOffsetBuffer(baseAddress)); + } + } + + @Override + public String toString() { + return "ColumnViewInfo{" + + "dtype=" + dtype + + ", offsetInfo=" + offsetInfo + + ", nullCount=" + nullCount + + ", rowCount=" + rowCount + + '}'; + } +} diff --git a/src/main/java/com/nvidia/spark/rapids/jni/kudo/DataOutputStreamWriter.java b/src/main/java/com/nvidia/spark/rapids/jni/kudo/DataOutputStreamWriter.java new file mode 100644 index 0000000000..c88f125b2e --- /dev/null +++ b/src/main/java/com/nvidia/spark/rapids/jni/kudo/DataOutputStreamWriter.java @@ -0,0 +1,61 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.jni.kudo; + +import ai.rapids.cudf.HostMemoryBuffer; + +import java.io.DataOutputStream; +import java.io.IOException; + +/** + * Visible for testing + */ +class DataOutputStreamWriter extends DataWriter { + private final byte[] arrayBuffer = new byte[1024]; + private final DataOutputStream dout; + + public DataOutputStreamWriter(DataOutputStream dout) { + this.dout = dout; + } + + @Override + public void writeInt(int i) throws IOException { + dout.writeInt(i); + } + + @Override + public void copyDataFrom(HostMemoryBuffer src, long srcOffset, long len) throws IOException { + long dataLeft = len; + while (dataLeft > 0) { + int amountToCopy = (int) Math.min(arrayBuffer.length, dataLeft); + src.getBytes(arrayBuffer, 0, srcOffset, amountToCopy); + dout.write(arrayBuffer, 0, amountToCopy); + srcOffset += amountToCopy; + dataLeft -= amountToCopy; + } + } + + @Override + public void flush() throws IOException { + dout.flush(); + } + + @Override + public void write(byte[] arr, int offset, int length) throws IOException { + dout.write(arr, offset, length); + } +} diff --git a/src/main/java/com/nvidia/spark/rapids/jni/kudo/DataWriter.java b/src/main/java/com/nvidia/spark/rapids/jni/kudo/DataWriter.java new file mode 100644 index 0000000000..1f2e8f3dca --- /dev/null +++ b/src/main/java/com/nvidia/spark/rapids/jni/kudo/DataWriter.java @@ -0,0 +1,44 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.jni.kudo; + +import ai.rapids.cudf.HostMemoryBuffer; + +import java.io.IOException; + +/** + * Visible for testing + */ +abstract class DataWriter { + + public abstract void writeInt(int i) throws IOException; + + /** + * Copy data from src starting at srcOffset and going for len bytes. + * + * @param src where to copy from. + * @param srcOffset offset to start at. + * @param len amount to copy. + */ + public abstract void copyDataFrom(HostMemoryBuffer src, long srcOffset, long len) throws IOException; + + public void flush() throws IOException { + // NOOP by default + } + + public abstract void write(byte[] arr, int offset, int length) throws IOException; +} diff --git a/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoHostMergeResult.java b/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoHostMergeResult.java new file mode 100644 index 0000000000..6529f9e15e --- /dev/null +++ b/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoHostMergeResult.java @@ -0,0 +1,89 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.jni.kudo; + +import ai.rapids.cudf.*; +import com.nvidia.spark.rapids.jni.Arms; +import com.nvidia.spark.rapids.jni.schema.Visitors; + +import java.util.List; + +import static com.nvidia.spark.rapids.jni.Preconditions.ensure; +import static java.util.Objects.requireNonNull; + +/** + * The result of merging several kudo tables into one contiguous table on the host. + */ +public class KudoHostMergeResult implements AutoCloseable { + private final Schema schema; + private final List columnInfoList; + private HostMemoryBuffer hostBuf; + + KudoHostMergeResult(Schema schema, HostMemoryBuffer hostBuf, List columnInfoList) { + requireNonNull(schema, "schema is null"); + requireNonNull(columnInfoList, "columnInfoList is null"); + ensure(schema.getFlattenedColumnNames().length == columnInfoList.size(), () -> + "Column offsets size does not match flattened schema size, column offsets size: " + columnInfoList.size() + + ", flattened schema size: " + schema.getFlattenedColumnNames().length); + this.schema = schema; + this.columnInfoList = columnInfoList; + this.hostBuf = requireNonNull(hostBuf, "hostBuf is null"); + } + + @Override + public void close() throws Exception { + hostBuf.close(); + hostBuf = null; + } + + /** + * Get the length of the data in the host buffer. + * @return the length of the data in the host buffer + */ + public long getDataLength() { + return hostBuf.getLength(); + } + + /** + * Convert the host buffer into a cudf table. 
+ * @return the cudf table + */ + public Table toTable() { + try (DeviceMemoryBuffer deviceMemBuf = DeviceMemoryBuffer.allocate(hostBuf.getLength())) { + if (hostBuf.getLength() > 0) { + deviceMemBuf.copyFromHostBufferAsync(hostBuf, Cuda.DEFAULT_STREAM); + } + + try (TableBuilder builder = new TableBuilder(columnInfoList, deviceMemBuf)) { + Table t = Visitors.visitSchema(schema, builder); + + Cuda.DEFAULT_STREAM.sync(); + return t; + } catch (Exception e) { + throw new RuntimeException(e); + } + } + } + + @Override + public String toString() { + return "HostMergeResult{" + + "columnOffsets=" + columnInfoList + + ", hostBuf length =" + hostBuf.getLength() + + '}'; + } +} diff --git a/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoSerializer.java b/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoSerializer.java new file mode 100644 index 0000000000..6370531428 --- /dev/null +++ b/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoSerializer.java @@ -0,0 +1,360 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.nvidia.spark.rapids.jni.kudo; + +import ai.rapids.cudf.*; +import com.nvidia.spark.rapids.jni.Pair; +import com.nvidia.spark.rapids.jni.schema.Visitors; + +import java.io.*; +import java.util.Arrays; +import java.util.List; +import java.util.function.LongConsumer; +import java.util.function.Supplier; +import java.util.stream.IntStream; + +import static com.nvidia.spark.rapids.jni.Preconditions.ensure; +import static java.util.Objects.requireNonNull; + +/** + * This class is used to serialize/deserialize a table using the Kudo format. + * + *

Background

+ *

+ * The Kudo format is a binary format that is optimized for serializing/deserializing a table partition during Spark + * shuffle. The optimizations are based on two key observations: + * + *

    + *
  1. The binary format doesn't need to be self descriptive, since shuffle runtime could provide information such + * as schema, which helped us to reduce header size a lot. + *
  2. + *
  3. In most cases we need to concat several small tables into a larger table during shuffle read time, since + * gpu's vectorized execution engine typically requires larger batch size, which makes write time concatenation + * meaningless. This relaxed the requirement of calculating exact validity buffer and offset buffer at write time, + * which makes write almost a memory copy process, without sacrificing read performance much. + *
  4. + *
+ * + *

Format

+ *

+ * Similar to {@link JCudfSerialization}, it still consists of two parts: header and body. + * + *

Header

+ *

+ * Header consists of following fields: + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + *
Field NameSizeComments
Magic Number4ASCII codes for "KUD0"
Offset4Row offset in original table, in big endian format
Number of rows4Number of rows, in big endian format
Length of validity buffer4Length of validity buffer, in big endian format
Length of offset buffer4Length of offset buffer, in big endian format
Length of total body4Length of total body, in big endian format
Number of columns4Number of columns in flattened schema, in big endian format. For details of flattened schema, + * see {@link com.nvidia.spark.rapids.jni.schema.SchemaVisitor} + *
hasValidityBuffer(number of columns + 7) / 8A bit set to indicate whether a column has validity buffer. To test if column + * coli has validity buffer, use the following code: + *
+ * + * int pos = coli / 8;
+ * int bit = coli % 8;
+ * return (hasValidityBuffer[pos] & (1 << bit)) != 0; + *
+ *
+ * + *

Body

+ *

+ * The body consists of three part: + *

    + *
  1. Validity buffers for every column with validity in depth-first ordering of schema columns. Each buffer of + * each column is 4 bytes padded. + *
  2. + *
  3. Offset buffers for every column with offsets in depth-first ordering of schema columns. Each buffer of each + * column is 4 bytes padded.
  4. + *
  5. Data buffers for every column with data in depth-first ordering of schema columns. Each buffer of each + * column is 4 bytes padded.
  6. + *
+ * + *

Serialization

+ *

+ * The serialization process writes the header first, then writes the body. There are two optimizations when writing + * validity buffer and offset buffer: + * + *

    + *
  1. For validity buffer, it only copies buffers without calculating an exact validity buffer. For example, when + * we want to serialize rows [3, 9) of the original table, instead of calculating the exact validity buffer, it + * just copies first two bytes of the validity buffer. At read time, the deserializer will know that the true + * validity buffer starts from the fourth bit, since we have recorded the row offset in the header. + *
  2. + *
  3. For offset buffer, it only copies buffers without calculating an exact offset buffer. For example, when we want + * * to serialize rows [3, 9) of the original table, instead of calculating the exact offset values by subtracting + * * first value, it just copies the offset buffer values of rows [3, 9). + * *
  4. + *
+ */ +public class KudoSerializer { + + private static final byte[] PADDING = new byte[64]; + private static final BufferType[] ALL_BUFFER_TYPES = new BufferType[]{BufferType.VALIDITY, BufferType.OFFSET, + BufferType.DATA}; + + static { + Arrays.fill(PADDING, (byte) 0); + } + + private final Schema schema; + private final int flattenedColumnCount; + + public KudoSerializer(Schema schema) { + requireNonNull(schema, "schema is null"); + this.schema = schema; + this.flattenedColumnCount = schema.getFlattenedColumnNames().length; + } + + /** + * Write partition of a table to a stream. This method is used for test only. + *
+ * The caller should ensure that table's schema matches the schema used to create this serializer, otherwise behavior + * is undefined. + * + * @param table table to write + * @param out output stream + * @param rowOffset row offset in original table + * @param numRows number of rows to write + * @return number of bytes written + */ + long writeToStream(Table table, OutputStream out, int rowOffset, int numRows) { + HostColumnVector[] columns = null; + try { + columns = IntStream.range(0, table.getNumberOfColumns()) + .mapToObj(table::getColumn) + .map(c -> c.copyToHostAsync(Cuda.DEFAULT_STREAM)) + .toArray(HostColumnVector[]::new); + + Cuda.DEFAULT_STREAM.sync(); + return writeToStream(columns, out, rowOffset, numRows); + } finally { + if (columns != null) { + for (HostColumnVector column : columns) { + column.close(); + } + } + } + } + + /** + * Write partition of an array of {@link HostColumnVector} to an output stream. + *
+ *

+ * The caller should ensure that table's schema matches the schema used to create this serializer, otherwise behavior + * is undefined. + * + * @param columns columns to write + * @param out output stream + * @param rowOffset row offset in original column vector. + * @param numRows number of rows to write + * @return number of bytes written + */ + public long writeToStream(HostColumnVector[] columns, OutputStream out, int rowOffset, int numRows) { + ensure(numRows > 0, () -> "numRows must be > 0, but was " + numRows); + ensure(columns.length > 0, () -> "columns must not be empty, for row count only records " + + "please call writeRowCountToStream"); + + try { + return writeSliced(columns, writerFrom(out), rowOffset, numRows); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + /** + * Write a row count only record to an output stream. + * + * @param out output stream + * @param numRows number of rows to write + * @return number of bytes written + */ + public static long writeRowCountToStream(OutputStream out, int numRows) { + if (numRows <= 0) { + throw new IllegalArgumentException("Number of rows must be > 0, but was " + numRows); + } + try { + DataWriter writer = writerFrom(out); + KudoTableHeader header = new KudoTableHeader(0, numRows, 0, 0, 0 + , 0, new byte[0]); + header.writeTo(writer); + writer.flush(); + return header.getSerializedSize(); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + /** + * Merge a list of kudo tables into a table on host memory. + *
+ * The caller should ensure that the {@link KudoSerializer} used to generate kudo tables have same schema as current + * {@link KudoSerializer}, otherwise behavior is undefined. + * + * @param kudoTables list of kudo tables. This method doesn't take ownership of the input tables, and caller should + * take care of closing them after calling this method. + * @return the merged table, and metrics during merge. + */ + public Pair mergeOnHost(List kudoTables) { + MergeMetrics.Builder metricsBuilder = MergeMetrics.builder(); + + MergedInfoCalc mergedInfoCalc = withTime(() -> MergedInfoCalc.calc(schema, kudoTables), + metricsBuilder::calcHeaderTime); + KudoHostMergeResult result = withTime(() -> KudoTableMerger.merge(schema, mergedInfoCalc), + metricsBuilder::mergeIntoHostBufferTime); + return Pair.of(result, metricsBuilder.build()); + + } + + /** + * Merge a list of kudo tables into a contiguous table. + *
+ * The caller should ensure that the {@link KudoSerializer} used to generate kudo tables have same schema as current + * {@link KudoSerializer}, otherwise behavior is undefined. + * + * @param kudoTables list of kudo tables. This method doesn't take ownership of the input tables, and caller should + * take care of closing them after calling this method. + * @return the merged table, and metrics during merge. + * @throws Exception if any error occurs during merge. + */ + public Pair mergeToTable(List kudoTables) throws Exception { + Pair result = mergeOnHost(kudoTables); + MergeMetrics.Builder builder = MergeMetrics.builder(result.getRight()); + try (KudoHostMergeResult children = result.getLeft()) { + Table table = withTime(children::toTable, + builder::convertToTableTime); + + return Pair.of(table, builder.build()); + } + } + + private long writeSliced(HostColumnVector[] columns, DataWriter out, int rowOffset, int numRows) throws Exception { + KudoTableHeaderCalc headerCalc = new KudoTableHeaderCalc(rowOffset, numRows, flattenedColumnCount); + Visitors.visitColumns(columns, headerCalc); + KudoTableHeader header = headerCalc.getHeader(); + header.writeTo(out); + + long bytesWritten = 0; + for (BufferType bufferType : ALL_BUFFER_TYPES) { + SlicedBufferSerializer serializer = new SlicedBufferSerializer(rowOffset, numRows, bufferType, out); + Visitors.visitColumns(columns, serializer); + bytesWritten += serializer.getTotalDataLen(); + } + + if (bytesWritten != header.getTotalDataLen()) { + throw new IllegalStateException("Header total data length: " + header.getTotalDataLen() + + " does not match actual written data length: " + bytesWritten + + ", rowOffset: " + rowOffset + " numRows: " + numRows); + } + + out.flush(); + + return header.getSerializedSize() + bytesWritten; + } + + private static DataWriter writerFrom(OutputStream out) { + if (!(out instanceof DataOutputStream)) { + out = new DataOutputStream(new BufferedOutputStream(out)); + } + return new 
DataOutputStreamWriter((DataOutputStream) out); + } + + + static long padForHostAlignment(long orig) { + return ((orig + 3) / 4) * 4; + } + + static long padForHostAlignment(DataWriter out, long bytes) throws IOException { + final long paddedBytes = padForHostAlignment(bytes); + if (paddedBytes > bytes) { + out.write(PADDING, 0, (int) (paddedBytes - bytes)); + } + return paddedBytes; + } + + static long padFor64byteAlignment(long orig) { + return ((orig + 63) / 64) * 64; + } + + static DataInputStream readerFrom(InputStream in) { + if (in instanceof DataInputStream) { + return (DataInputStream) in; + } + return new DataInputStream(in); + } + + static T withTime(Supplier task, LongConsumer timeConsumer) { + long now = System.nanoTime(); + T ret = task.get(); + timeConsumer.accept(System.nanoTime() - now); + return ret; + } + + /** + * This method returns the length in bytes needed to represent X number of rows + * e.g. getValidityLengthInBytes(5) => 1 byte + * getValidityLengthInBytes(7) => 1 byte + * getValidityLengthInBytes(14) => 2 bytes + */ + static long getValidityLengthInBytes(long rows) { + return (rows + 7) / 8; + } +} \ No newline at end of file diff --git a/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoTable.java b/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoTable.java new file mode 100644 index 0000000000..c49b2cb8f7 --- /dev/null +++ b/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoTable.java @@ -0,0 +1,101 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.jni.kudo; + +import ai.rapids.cudf.HostMemoryBuffer; +import com.nvidia.spark.rapids.jni.Arms; + +import java.io.DataInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.util.Optional; + +import static com.nvidia.spark.rapids.jni.kudo.KudoSerializer.readerFrom; +import static java.util.Objects.requireNonNull; + +/** + * Serialized table in kudo format, including a {{@link KudoTableHeader}} and a {@link HostMemoryBuffer} for serialized + * data. + */ +public class KudoTable implements AutoCloseable { + private final KudoTableHeader header; + private final HostMemoryBuffer buffer; + + /** + * Create a kudo table. + * + * @param header kudo table header + * @param buffer host memory buffer for the table data. KudoTable will take ownership of this buffer, so don't close + * it after passing it to this constructor. + */ + public KudoTable(KudoTableHeader header, HostMemoryBuffer buffer) { + requireNonNull(header, "Header must not be null"); + this.header = header; + this.buffer = buffer; + } + + /** + * Read a kudo table from an input stream. + * + * @param in input stream + * @return the kudo table, or empty if the input stream is empty. 
+ * @throws IOException if an I/O error occurs + */ + public static Optional from(InputStream in) throws IOException { + requireNonNull(in, "Input stream must not be null"); + + DataInputStream din = readerFrom(in); + return KudoTableHeader.readFrom(din).map(header -> { + // Header only + if (header.getNumColumns() == 0) { + return new KudoTable(header, null); + } + + return Arms.closeIfException(HostMemoryBuffer.allocate(header.getTotalDataLen(), false), buffer -> { + try { + buffer.copyFromStream(0, din, header.getTotalDataLen()); + return new KudoTable(header, buffer); + } catch (IOException e) { + throw new RuntimeException(e); + } + }); + }); + } + + public KudoTableHeader getHeader() { + return header; + } + + public HostMemoryBuffer getBuffer() { + return buffer; + } + + @Override + public String toString() { + return "SerializedTable{" + + "header=" + header + + ", buffer=" + buffer + + '}'; + } + + @Override + public void close() throws Exception { + if (buffer != null) { + buffer.close(); + } + } +} diff --git a/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoTableHeader.java b/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoTableHeader.java new file mode 100644 index 0000000000..2bf5449c7a --- /dev/null +++ b/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoTableHeader.java @@ -0,0 +1,194 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.nvidia.spark.rapids.jni.kudo; + +import java.io.DataInputStream; +import java.io.EOFException; +import java.io.IOException; +import java.util.Arrays; +import java.util.Optional; + +import static com.nvidia.spark.rapids.jni.Preconditions.ensure; +import static com.nvidia.spark.rapids.jni.Preconditions.ensureNonNegative; +import static java.util.Objects.requireNonNull; + +/** + * Holds the metadata about a serialized table. If this is being read from a stream + * isInitialized will return true if the metadata was read correctly from the stream. + * It will return false if an EOF was encountered at the beginning indicating that + * there was no data to be read. + */ +public final class KudoTableHeader { + /** + * Magic number "KUD0" in ASCII. + */ + private static final int SER_FORMAT_MAGIC_NUMBER = 0x4B554430; + + // The offset in the original table where row starts. For example, if we want to serialize rows [3, 9) of the + // original table, offset would be 3, and numRows would be 6. + private final int offset; + private final int numRows; + private final int validityBufferLen; + private final int offsetBufferLen; + private final int totalDataLen; + private final int numColumns; + // A bit set to indicate if a column has a validity buffer or not. Each column is represented by a single bit. + private final byte[] hasValidityBuffer; + + /** + * Reads the table header from the given input stream. + * + * @param din input stream + * @return the table header. If an EOFException is encountered at the beginning, returns empty result. 
+ * @throws IOException if an I/O error occurs + */ + public static Optional readFrom(DataInputStream din) throws IOException { + int num; + try { + num = din.readInt(); + if (num != SER_FORMAT_MAGIC_NUMBER) { + throw new IllegalStateException("Kudo format error, expected magic number " + SER_FORMAT_MAGIC_NUMBER + + " found " + num); + } + } catch (EOFException e) { + // If we get an EOF at the very beginning don't treat it as an error because we may + // have finished reading everything... + return Optional.empty(); + } + + int offset = din.readInt(); + int numRows = din.readInt(); + + int validityBufferLen = din.readInt(); + int offsetBufferLen = din.readInt(); + int totalDataLen = din.readInt(); + int numColumns = din.readInt(); + int validityBufferLength = lengthOfHasValidityBuffer(numColumns); + byte[] hasValidityBuffer = new byte[validityBufferLength]; + din.readFully(hasValidityBuffer); + + return Optional.of(new KudoTableHeader(offset, numRows, validityBufferLen, offsetBufferLen, totalDataLen, numColumns, + hasValidityBuffer)); + } + + KudoTableHeader(int offset, int numRows, int validityBufferLen, int offsetBufferLen, + int totalDataLen, int numColumns, byte[] hasValidityBuffer) { + this.offset = ensureNonNegative(offset, "offset"); + this.numRows = ensureNonNegative(numRows, "numRows"); + this.validityBufferLen = ensureNonNegative(validityBufferLen, "validityBufferLen"); + this.offsetBufferLen = ensureNonNegative(offsetBufferLen, "offsetBufferLen"); + this.totalDataLen = ensureNonNegative(totalDataLen, "totalDataLen"); + this.numColumns = ensureNonNegative(numColumns, "numColumns"); + + requireNonNull(hasValidityBuffer, "hasValidityBuffer cannot be null"); + ensure(hasValidityBuffer.length == lengthOfHasValidityBuffer(numColumns), + () -> numColumns + " columns expects hasValidityBuffer with length " + lengthOfHasValidityBuffer(numColumns) + + ", but found " + hasValidityBuffer.length); + this.hasValidityBuffer = hasValidityBuffer; + } + + /** + * Returns 
the size of a buffer needed to read data into the stream. + */ + public int getTotalDataLen() { + return totalDataLen; + } + + /** + * Returns the number of rows stored in this table. + */ + public int getNumRows() { + return numRows; + } + + public int getOffset() { + return offset; + } + + public boolean hasValidityBuffer(int columnIndex) { + int pos = columnIndex / 8; + int bit = columnIndex % 8; + return (hasValidityBuffer[pos] & (1 << bit)) != 0; + } + + /** + * Get the size of the serialized header. + * + *

+ * It consists of the following fields:
+ * <ol>
+ *   <li>Magic Number</li>
+ *   <li>Row Offset</li>
+ *   <li>Number of rows</li>
+ *   <li>Validity buffer length</li>
+ *   <li>Offset buffer length</li>
+ *   <li>Total data length</li>
+ *   <li>Number of columns</li>
+ *   <li>hasValidityBuffer</li>
+ * </ol>
+ * <p>
+ * For more details of each field, please refer to {@link KudoSerializer}.
+ * + * @return the size of the serialized header. + */ + public int getSerializedSize() { + return 7 * Integer.BYTES + hasValidityBuffer.length; + } + + public int getNumColumns() { + return numColumns; + } + + public int getValidityBufferLen() { + return validityBufferLen; + } + + public int getOffsetBufferLen() { + return offsetBufferLen; + } + + public void writeTo(DataWriter dout) throws IOException { + // Now write out the data + dout.writeInt(SER_FORMAT_MAGIC_NUMBER); + + dout.writeInt(offset); + dout.writeInt(numRows); + dout.writeInt(validityBufferLen); + dout.writeInt(offsetBufferLen); + dout.writeInt(totalDataLen); + dout.writeInt(numColumns); + dout.write(hasValidityBuffer, 0, hasValidityBuffer.length); + } + + @Override + public String toString() { + return "SerializedTableHeader{" + + "offset=" + offset + + ", numRows=" + numRows + + ", validityBufferLen=" + validityBufferLen + + ", offsetBufferLen=" + offsetBufferLen + + ", totalDataLen=" + totalDataLen + + ", numColumns=" + numColumns + + ", hasValidityBuffer=" + Arrays.toString(hasValidityBuffer) + + '}'; + } + + private static int lengthOfHasValidityBuffer(int numColumns) { + return (numColumns + 7) / 8; + } +} diff --git a/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoTableHeaderCalc.java b/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoTableHeaderCalc.java new file mode 100644 index 0000000000..4eaa1c435c --- /dev/null +++ b/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoTableHeaderCalc.java @@ -0,0 +1,186 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.jni.kudo; + +import ai.rapids.cudf.DType; +import ai.rapids.cudf.HostColumnVectorCore; +import com.nvidia.spark.rapids.jni.schema.HostColumnsVisitor; + +import java.util.ArrayDeque; +import java.util.Deque; +import java.util.List; + +import static com.nvidia.spark.rapids.jni.kudo.KudoSerializer.padForHostAlignment; +import static java.lang.Math.toIntExact; + +/** + * This class visits a list of columns and calculates the serialized table header. + * + *

+ * <p>
+ * The columns are visited in post order, and for more details about the visiting process, please refer to
+ * {@link HostColumnsVisitor}.
+ * </p>

+ */ +class KudoTableHeaderCalc implements HostColumnsVisitor { + private final SliceInfo root; + private final int numFlattenedCols; + private final byte[] bitset; + private long validityBufferLen; + private long offsetBufferLen; + private long totalDataLen; + private int nextColIdx; + + private Deque sliceInfos = new ArrayDeque<>(); + + KudoTableHeaderCalc(int rowOffset, int numRows, int numFlattenedCols) { + this.root = new SliceInfo(rowOffset, numRows); + this.totalDataLen = 0; + sliceInfos.addLast(this.root); + this.bitset = new byte[(numFlattenedCols + 7) / 8]; + this.numFlattenedCols = numFlattenedCols; + this.nextColIdx = 0; + } + + public KudoTableHeader getHeader() { + return new KudoTableHeader(toIntExact(root.offset), + toIntExact(root.rowCount), + toIntExact(validityBufferLen), + toIntExact(offsetBufferLen), + toIntExact(totalDataLen), + numFlattenedCols, + bitset); + } + + @Override + public Void visitStruct(HostColumnVectorCore col, List children) { + SliceInfo parent = sliceInfos.getLast(); + + long validityBufferLength = 0; + if (col.hasValidityVector()) { + validityBufferLength = padForHostAlignment(parent.getValidityBufferInfo().getBufferLength()); + } + + this.validityBufferLen += validityBufferLength; + + totalDataLen += validityBufferLength; + this.setHasValidity(col.hasValidityVector()); + return null; + } + + @Override + public Void preVisitList(HostColumnVectorCore col) { + SliceInfo parent = sliceInfos.getLast(); + + + long validityBufferLength = 0; + if (col.hasValidityVector() && parent.rowCount > 0) { + validityBufferLength = padForHostAlignment(parent.getValidityBufferInfo().getBufferLength()); + } + + long offsetBufferLength = 0; + if (col.getOffsets() != null && parent.rowCount > 0) { + offsetBufferLength = padForHostAlignment((parent.rowCount + 1) * Integer.BYTES); + } + + this.validityBufferLen += validityBufferLength; + this.offsetBufferLen += offsetBufferLength; + this.totalDataLen += validityBufferLength + offsetBufferLength; + 
+ this.setHasValidity(col.hasValidityVector()); + + SliceInfo current; + + if (col.getOffsets() != null) { + int start = col.getOffsets().getInt(parent.offset * Integer.BYTES); + int end = col.getOffsets().getInt((parent.offset + parent.rowCount) * Integer.BYTES); + int rowCount = end - start; + current = new SliceInfo(start, rowCount); + } else { + current = new SliceInfo(0, 0); + } + + sliceInfos.addLast(current); + return null; + } + + @Override + public Void visitList(HostColumnVectorCore col, Void preVisitResult, Void childResult) { + sliceInfos.removeLast(); + + return null; + } + + + @Override + public Void visit(HostColumnVectorCore col) { + SliceInfo parent = sliceInfos.peekLast(); + long validityBufferLen = dataLenOfValidityBuffer(col, parent); + long offsetBufferLen = dataLenOfOffsetBuffer(col, parent); + long dataBufferLen = dataLenOfDataBuffer(col, parent); + + this.validityBufferLen += validityBufferLen; + this.offsetBufferLen += offsetBufferLen; + this.totalDataLen += validityBufferLen + offsetBufferLen + dataBufferLen; + + this.setHasValidity(col.hasValidityVector()); + + return null; + } + + private void setHasValidity(boolean hasValidityBuffer) { + if (hasValidityBuffer) { + int bytePos = nextColIdx / 8; + int bitPos = nextColIdx % 8; + bitset[bytePos] = (byte) (bitset[bytePos] | (1 << bitPos)); + } + nextColIdx++; + } + + private static long dataLenOfValidityBuffer(HostColumnVectorCore col, SliceInfo info) { + if (col.hasValidityVector() && info.getRowCount() > 0) { + return padForHostAlignment(info.getValidityBufferInfo().getBufferLength()); + } else { + return 0; + } + } + + private static long dataLenOfOffsetBuffer(HostColumnVectorCore col, SliceInfo info) { + if (DType.STRING.equals(col.getType()) && info.getRowCount() > 0) { + return padForHostAlignment((info.rowCount + 1) * Integer.BYTES); + } else { + return 0; + } + } + + private static long dataLenOfDataBuffer(HostColumnVectorCore col, SliceInfo info) { + if 
(DType.STRING.equals(col.getType())) { + if (col.getOffsets() != null) { + long startByteOffset = col.getOffsets().getInt(info.offset * Integer.BYTES); + long endByteOffset = col.getOffsets().getInt((info.offset + info.rowCount) * Integer.BYTES); + return padForHostAlignment(endByteOffset - startByteOffset); + } else { + return 0; + } + } else { + if (col.getType().getSizeInBytes() > 0) { + return padForHostAlignment(col.getType().getSizeInBytes() * info.rowCount); + } else { + return 0; + } + } + } +} diff --git a/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoTableMerger.java b/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoTableMerger.java new file mode 100644 index 0000000000..af80391f3d --- /dev/null +++ b/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoTableMerger.java @@ -0,0 +1,332 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.nvidia.spark.rapids.jni.kudo; + +import ai.rapids.cudf.HostMemoryBuffer; +import ai.rapids.cudf.Schema; +import com.nvidia.spark.rapids.jni.Arms; +import com.nvidia.spark.rapids.jni.schema.Visitors; + +import java.nio.ByteOrder; +import java.nio.IntBuffer; +import java.util.ArrayList; +import java.util.List; +import java.util.OptionalInt; + +import static com.nvidia.spark.rapids.jni.Preconditions.ensure; +import static com.nvidia.spark.rapids.jni.kudo.ColumnOffsetInfo.INVALID_OFFSET; +import static com.nvidia.spark.rapids.jni.kudo.KudoSerializer.getValidityLengthInBytes; +import static com.nvidia.spark.rapids.jni.kudo.KudoSerializer.padFor64byteAlignment; +import static java.lang.Math.min; +import static java.lang.Math.toIntExact; +import static java.util.Objects.requireNonNull; + +/** + * This class is used to merge multiple KudoTables into a single contiguous buffer, e.g. {@link KudoHostMergeResult}, + * which could be easily converted to a {@link ai.rapids.cudf.ContiguousTable}. 
+ */ +class KudoTableMerger extends MultiKudoTableVisitor { + // Number of 1s in a byte + private static final int[] NUMBER_OF_ONES = new int[256]; + + static { + for (int i = 0; i < NUMBER_OF_ONES.length; i += 1) { + int count = 0; + for (int j = 0; j < 8; j += 1) { + if ((i & (1 << j)) != 0) { + count += 1; + } + } + NUMBER_OF_ONES[i] = count; + } + } + + private final List columnOffsets; + private final HostMemoryBuffer buffer; + private final List colViewInfoList; + + public KudoTableMerger(List tables, HostMemoryBuffer buffer, List columnOffsets) { + super(tables); + requireNonNull(buffer, "buffer can't be null!"); + ensure(columnOffsets != null, "column offsets cannot be null"); + ensure(!columnOffsets.isEmpty(), "column offsets cannot be empty"); + this.columnOffsets = columnOffsets; + this.buffer = buffer; + this.colViewInfoList = new ArrayList<>(columnOffsets.size()); + } + + @Override + protected KudoHostMergeResult doVisitTopSchema(Schema schema, List children) { + return new KudoHostMergeResult(schema, buffer, colViewInfoList); + } + + @Override + protected Void doVisitStruct(Schema structType, List children) { + ColumnOffsetInfo offsetInfo = getCurColumnOffsets(); + int nullCount = deserializeValidityBuffer(offsetInfo); + int totalRowCount = getTotalRowCount(); + colViewInfoList.add(new ColumnViewInfo(structType.getType(), + offsetInfo, nullCount, totalRowCount)); + return null; + } + + @Override + protected Void doPreVisitList(Schema listType) { + ColumnOffsetInfo offsetInfo = getCurColumnOffsets(); + int nullCount = deserializeValidityBuffer(offsetInfo); + int totalRowCount = getTotalRowCount(); + deserializeOffsetBuffer(offsetInfo); + + colViewInfoList.add(new ColumnViewInfo(listType.getType(), + offsetInfo, nullCount, totalRowCount)); + return null; + } + + @Override + protected Void doVisitList(Schema listType, Void preVisitResult, Void childResult) { + return null; + } + + @Override + protected Void doVisit(Schema primitiveType) { + 
ColumnOffsetInfo offsetInfo = getCurColumnOffsets(); + int nullCount = deserializeValidityBuffer(offsetInfo); + int totalRowCount = getTotalRowCount(); + if (primitiveType.getType().hasOffsets()) { + deserializeOffsetBuffer(offsetInfo); + deserializeDataBuffer(offsetInfo, OptionalInt.empty()); + } else { + deserializeDataBuffer(offsetInfo, OptionalInt.of(primitiveType.getType().getSizeInBytes())); + } + + colViewInfoList.add(new ColumnViewInfo(primitiveType.getType(), + offsetInfo, nullCount, totalRowCount)); + + return null; + } + + private int deserializeValidityBuffer(ColumnOffsetInfo curColOffset) { + if (curColOffset.getValidity() != INVALID_OFFSET) { + long offset = curColOffset.getValidity(); + long validityBufferSize = padFor64byteAlignment(getValidityLengthInBytes(getTotalRowCount())); + try (HostMemoryBuffer validityBuffer = buffer.slice(offset, validityBufferSize)) { + int nullCountTotal = 0; + int startRow = 0; + for (int tableIdx = 0; tableIdx < getTableSize(); tableIdx += 1) { + SliceInfo sliceInfo = sliceInfoOf(tableIdx); + long validityOffset = validifyBufferOffset(tableIdx); + if (validityOffset != INVALID_OFFSET) { + nullCountTotal += copyValidityBuffer(validityBuffer, startRow, + memoryBufferOf(tableIdx), toIntExact(validityOffset), + sliceInfo); + } else { + appendAllValid(validityBuffer, startRow, sliceInfo.getRowCount()); + } + + startRow += sliceInfo.getRowCount(); + } + return nullCountTotal; + } + } else { + return 0; + } + } + + /** + * Copy a sliced validity buffer to the destination buffer, starting at the given bit offset. + * + * @return Number of nulls in the validity buffer. 
+ */ + private static int copyValidityBuffer(HostMemoryBuffer dest, int startBit, + HostMemoryBuffer src, int srcOffset, + SliceInfo sliceInfo) { + int nullCount = 0; + int totalRowCount = sliceInfo.getRowCount(); + int curIdx = 0; + int curSrcByteIdx = srcOffset; + int curSrcBitIdx = sliceInfo.getValidityBufferInfo().getBeginBit(); + int curDestByteIdx = startBit / 8; + int curDestBitIdx = startBit % 8; + + while (curIdx < totalRowCount) { + int leftRowCount = totalRowCount - curIdx; + int appendCount; + if (curDestBitIdx == 0) { + appendCount = min(8, leftRowCount); + } else { + appendCount = min(8 - curDestBitIdx, leftRowCount); + } + + int leftBitsInCurSrcByte = 8 - curSrcBitIdx; + byte srcByte = src.getByte(curSrcByteIdx); + if (leftBitsInCurSrcByte >= appendCount) { + // Extract appendCount bits from srcByte, starting from curSrcBitIdx + byte mask = (byte) (((1 << appendCount) - 1) & 0xFF); + srcByte = (byte) ((srcByte >>> curSrcBitIdx) & mask); + + nullCount += (appendCount - NUMBER_OF_ONES[srcByte & 0xFF]); + + // Sets the bits in destination buffer starting from curDestBitIdx to 0 + byte destByte = dest.getByte(curDestByteIdx); + destByte = (byte) (destByte & ((1 << curDestBitIdx) - 1) & 0xFF); + + // Update destination byte with the bits from source byte + destByte = (byte) ((destByte | (srcByte << curDestBitIdx)) & 0xFF); + dest.setByte(curDestByteIdx, destByte); + + curSrcBitIdx += appendCount; + if (curSrcBitIdx == 8) { + curSrcBitIdx = 0; + curSrcByteIdx += 1; + } + } else { + // Extract appendCount bits from srcByte, starting from curSrcBitIdx + byte mask = (byte) (((1 << leftBitsInCurSrcByte) - 1) & 0xFF); + srcByte = (byte) ((srcByte >>> curSrcBitIdx) & mask); + + byte nextSrcByte = src.getByte(curSrcByteIdx + 1); + byte nextSrcByteMask = (byte) ((1 << (appendCount - leftBitsInCurSrcByte)) - 1); + nextSrcByte = (byte) (nextSrcByte & nextSrcByteMask); + nextSrcByte = (byte) (nextSrcByte << leftBitsInCurSrcByte); + srcByte = (byte) (srcByte | 
nextSrcByte); + + nullCount += (appendCount - NUMBER_OF_ONES[srcByte & 0xFF]); + + // Sets the bits in destination buffer starting from curDestBitIdx to 0 + byte destByte = dest.getByte(curDestByteIdx); + destByte = (byte) (destByte & ((1 << curDestBitIdx) - 1)); + + // Update destination byte with the bits from source byte + destByte = (byte) (destByte | (srcByte << curDestBitIdx)); + dest.setByte(curDestByteIdx, destByte); + + // Update the source byte index and bit index + curSrcByteIdx += 1; + curSrcBitIdx = appendCount - leftBitsInCurSrcByte; + } + + curIdx += appendCount; + + // Update the destination byte index and bit index + curDestBitIdx += appendCount; + if (curDestBitIdx == 8) { + curDestBitIdx = 0; + curDestByteIdx += 1; + } + } + + return nullCount; + } + + private static void appendAllValid(HostMemoryBuffer dest, int startBit, int numRowsLong) { + int numRows = toIntExact(numRowsLong); + int curDestByteIdx = startBit / 8; + int curDestBitIdx = startBit % 8; + + if (curDestBitIdx > 0) { + int numBits = 8 - curDestBitIdx; + int mask = ((1 << numBits) - 1) << curDestBitIdx; + dest.setByte(curDestByteIdx, (byte) (dest.getByte(curDestByteIdx) | mask)); + curDestByteIdx += 1; + numRows -= numBits; + } + + if (numRows > 0) { + int numBytes = (numRows + 7) / 8; + dest.setMemory(curDestByteIdx, numBytes, (byte) 0xFF); + } + } + + private void deserializeOffsetBuffer(ColumnOffsetInfo curColOffset) { + if (curColOffset.getOffset() != INVALID_OFFSET) { + long offset = curColOffset.getOffset(); + long bufferSize = Integer.BYTES * (getTotalRowCount() + 1); + + IntBuffer buf = buffer + .asByteBuffer(offset, toIntExact(bufferSize)) + .order(ByteOrder.LITTLE_ENDIAN) + .asIntBuffer(); + + int accumulatedDataLen = 0; + + for (int tableIdx = 0; tableIdx < getTableSize(); tableIdx += 1) { + SliceInfo sliceInfo = sliceInfoOf(tableIdx); + + if (sliceInfo.getRowCount() > 0) { + int rowCnt = sliceInfo.getRowCount(); + + int firstOffset = offsetOf(tableIdx, 0); + int 
lastOffset = offsetOf(tableIdx, rowCnt); + + for (int i = 0; i < rowCnt; i += 1) { + buf.put(offsetOf(tableIdx, i) - firstOffset + accumulatedDataLen); + } + + accumulatedDataLen += (lastOffset - firstOffset); + } + } + + buf.put(accumulatedDataLen); + } + } + + private void deserializeDataBuffer(ColumnOffsetInfo curColOffset, OptionalInt sizeInBytes) { + if (curColOffset.getData() != INVALID_OFFSET && curColOffset.getDataBufferLen() > 0) { + long offset = curColOffset.getData(); + long dataLen = curColOffset.getDataBufferLen(); + + try (HostMemoryBuffer buf = buffer.slice(offset, dataLen)) { + if (sizeInBytes.isPresent()) { + // Fixed size type + int elementSize = sizeInBytes.getAsInt(); + + long start = 0; + for (int tableIdx = 0; tableIdx < getTableSize(); tableIdx += 1) { + SliceInfo sliceInfo = sliceInfoOf(tableIdx); + if (sliceInfo.getRowCount() > 0) { + int thisDataLen = toIntExact(elementSize * sliceInfo.getRowCount()); + copyDataBuffer(buf, start, tableIdx, thisDataLen); + start += thisDataLen; + } + } + } else { + // String type + long start = 0; + for (int tableIdx = 0; tableIdx < getTableSize(); tableIdx += 1) { + int thisDataLen = getStrDataLenOf(tableIdx); + copyDataBuffer(buf, start, tableIdx, thisDataLen); + start += thisDataLen; + } + } + } + } + } + + + private ColumnOffsetInfo getCurColumnOffsets() { + return columnOffsets.get(getCurrentIdx()); + } + + static KudoHostMergeResult merge(Schema schema, MergedInfoCalc mergedInfo) { + List serializedTables = mergedInfo.getTables(); + return Arms.closeIfException(HostMemoryBuffer.allocate(mergedInfo.getTotalDataLen()), + buffer -> { + KudoTableMerger merger = new KudoTableMerger(serializedTables, buffer, mergedInfo.getColumnOffsets()); + return Visitors.visitSchema(schema, merger); + }); + } +} diff --git a/src/main/java/com/nvidia/spark/rapids/jni/kudo/MergeMetrics.java b/src/main/java/com/nvidia/spark/rapids/jni/kudo/MergeMetrics.java new file mode 100644 index 0000000000..e621129dd6 --- /dev/null 
+++ b/src/main/java/com/nvidia/spark/rapids/jni/kudo/MergeMetrics.java @@ -0,0 +1,82 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.jni.kudo; + +public class MergeMetrics { + // The time it took to calculate combined header in nanoseconds + private final long calcHeaderTime; + // The time it took to merge the buffers into the host buffer in nanoseconds + private final long mergeIntoHostBufferTime; + // The time it took to convert the host buffer into a contiguous table in nanoseconds + private final long convertToTableTime; + + public MergeMetrics(long calcHeaderTime, long mergeIntoHostBufferTime, + long convertToTableTime) { + this.calcHeaderTime = calcHeaderTime; + this.mergeIntoHostBufferTime = mergeIntoHostBufferTime; + this.convertToTableTime = convertToTableTime; + } + + public long getCalcHeaderTime() { + return calcHeaderTime; + } + + public long getMergeIntoHostBufferTime() { + return mergeIntoHostBufferTime; + } + + public long getConvertToTableTime() { + return convertToTableTime; + } + + public static Builder builder() { + return new Builder(); + } + + public static Builder builder(MergeMetrics metrics) { + return new Builder() + .calcHeaderTime(metrics.calcHeaderTime) + .mergeIntoHostBufferTime(metrics.mergeIntoHostBufferTime) + .convertToTableTime(metrics.convertToTableTime); + } + + + public static class Builder { + private long calcHeaderTime; + 
private long mergeIntoHostBufferTime; + private long convertToTableTime; + + public Builder calcHeaderTime(long calcHeaderTime) { + this.calcHeaderTime = calcHeaderTime; + return this; + } + + public Builder mergeIntoHostBufferTime(long mergeIntoHostBufferTime) { + this.mergeIntoHostBufferTime = mergeIntoHostBufferTime; + return this; + } + + public Builder convertToTableTime(long convertToTableTime) { + this.convertToTableTime = convertToTableTime; + return this; + } + + public MergeMetrics build() { + return new MergeMetrics(calcHeaderTime, mergeIntoHostBufferTime, convertToTableTime); + } + } +} diff --git a/src/main/java/com/nvidia/spark/rapids/jni/kudo/MergedInfoCalc.java b/src/main/java/com/nvidia/spark/rapids/jni/kudo/MergedInfoCalc.java new file mode 100644 index 0000000000..826ef2e691 --- /dev/null +++ b/src/main/java/com/nvidia/spark/rapids/jni/kudo/MergedInfoCalc.java @@ -0,0 +1,168 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.nvidia.spark.rapids.jni.kudo; + +import ai.rapids.cudf.Schema; +import com.nvidia.spark.rapids.jni.schema.Visitors; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +import static com.nvidia.spark.rapids.jni.kudo.ColumnOffsetInfo.INVALID_OFFSET; +import static com.nvidia.spark.rapids.jni.kudo.KudoSerializer.getValidityLengthInBytes; +import static com.nvidia.spark.rapids.jni.kudo.KudoSerializer.padFor64byteAlignment; + + +/** + * This class is used to calculate column offsets of merged buffer. + */ +class MergedInfoCalc extends MultiKudoTableVisitor { + // Total data len in gpu, which accounts for 64 byte alignment + private long totalDataLen; + // Column offset in gpu device buffer, it has one field for each flattened column + private final List columnOffsets; + + public MergedInfoCalc(List tables) { + super(tables); + this.totalDataLen = 0; + this.columnOffsets = new ArrayList<>(tables.get(0).getHeader().getNumColumns()); + } + + @Override + protected Void doVisitTopSchema(Schema schema, List children) { + return null; + } + + @Override + protected Void doVisitStruct(Schema structType, List children) { + long validityBufferLen = 0; + long validityOffset = INVALID_OFFSET; + if (hasNull()) { + validityBufferLen = padFor64byteAlignment(getValidityLengthInBytes(getTotalRowCount())); + validityOffset = totalDataLen; + totalDataLen += validityBufferLen; + } + + columnOffsets.add(new ColumnOffsetInfo(validityOffset, validityBufferLen, INVALID_OFFSET, 0, INVALID_OFFSET, 0)); + return null; + } + + @Override + protected Void doPreVisitList(Schema listType) { + long validityBufferLen = 0; + long validityOffset = INVALID_OFFSET; + if (hasNull()) { + validityBufferLen = padFor64byteAlignment(getValidityLengthInBytes(getTotalRowCount())); + validityOffset = totalDataLen; + totalDataLen += validityBufferLen; + } + + long offsetBufferLen = 0; + long offsetBufferOffset = INVALID_OFFSET; + if (getTotalRowCount() > 0) { + 
offsetBufferLen = padFor64byteAlignment((getTotalRowCount() + 1) * Integer.BYTES); + offsetBufferOffset = totalDataLen; + totalDataLen += offsetBufferLen; + } + + + columnOffsets.add(new ColumnOffsetInfo(validityOffset, validityBufferLen, offsetBufferOffset, offsetBufferLen, INVALID_OFFSET, 0)); + return null; + } + + @Override + protected Void doVisitList(Schema listType, Void preVisitResult, Void childResult) { + return null; + } + + @Override + protected Void doVisit(Schema primitiveType) { + // String type + if (primitiveType.getType().hasOffsets()) { + long validityBufferLen = 0; + long validityOffset = INVALID_OFFSET; + if (hasNull()) { + validityBufferLen = padFor64byteAlignment(getValidityLengthInBytes(getTotalRowCount())); + validityOffset = totalDataLen; + totalDataLen += validityBufferLen; + } + + long offsetBufferLen = 0; + long offsetBufferOffset = INVALID_OFFSET; + if (getTotalRowCount() > 0) { + offsetBufferLen = padFor64byteAlignment((getTotalRowCount() + 1) * Integer.BYTES); + offsetBufferOffset = totalDataLen; + totalDataLen += offsetBufferLen; + } + + long dataBufferLen = 0; + long dataBufferOffset = INVALID_OFFSET; + if (getTotalStrDataLen() > 0) { + dataBufferLen = padFor64byteAlignment(getTotalStrDataLen()); + dataBufferOffset = totalDataLen; + totalDataLen += dataBufferLen; + } + + columnOffsets.add(new ColumnOffsetInfo(validityOffset, validityBufferLen, offsetBufferOffset, offsetBufferLen, dataBufferOffset, dataBufferLen)); + } else { + long totalRowCount = getTotalRowCount(); + long validityBufferLen = 0; + long validityOffset = INVALID_OFFSET; + if (hasNull()) { + validityBufferLen = padFor64byteAlignment(getValidityLengthInBytes(totalRowCount)); + validityOffset = totalDataLen; + totalDataLen += validityBufferLen; + } + + long dataBufferLen = 0; + long dataBufferOffset = INVALID_OFFSET; + if (totalRowCount > 0) { + dataBufferLen = padFor64byteAlignment(totalRowCount * primitiveType.getType().getSizeInBytes()); + dataBufferOffset = 
totalDataLen; + totalDataLen += dataBufferLen; + } + + columnOffsets.add(new ColumnOffsetInfo(validityOffset, validityBufferLen, INVALID_OFFSET, 0, dataBufferOffset, dataBufferLen)); + } + + return null; + } + + + public long getTotalDataLen() { + return totalDataLen; + } + + List getColumnOffsets() { + return Collections.unmodifiableList(columnOffsets); + } + + @Override + public String toString() { + return "MergedInfoCalc{" + + "totalDataLen=" + totalDataLen + + ", columnOffsets=" + columnOffsets + + '}'; + } + + static MergedInfoCalc calc(Schema schema, List table) { + MergedInfoCalc calc = new MergedInfoCalc(table); + Visitors.visitSchema(schema, calc); + return calc; + } +} diff --git a/src/main/java/com/nvidia/spark/rapids/jni/kudo/MultiKudoTableVisitor.java b/src/main/java/com/nvidia/spark/rapids/jni/kudo/MultiKudoTableVisitor.java new file mode 100644 index 0000000000..afa7ba6ea0 --- /dev/null +++ b/src/main/java/com/nvidia/spark/rapids/jni/kudo/MultiKudoTableVisitor.java @@ -0,0 +1,286 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.nvidia.spark.rapids.jni.kudo; + +import ai.rapids.cudf.HostMemoryBuffer; +import ai.rapids.cudf.Schema; +import com.nvidia.spark.rapids.jni.schema.SchemaVisitor; + +import java.util.*; + +import static com.nvidia.spark.rapids.jni.Preconditions.ensure; +import static com.nvidia.spark.rapids.jni.kudo.ColumnOffsetInfo.INVALID_OFFSET; +import static com.nvidia.spark.rapids.jni.kudo.KudoSerializer.padForHostAlignment; +import static java.lang.Math.toIntExact; + +/** + * This class provides a base class for visiting multiple kudo tables, e.g. it helps to maintain internal states during + * visiting multi kudo tables, which makes it easier to do some calculations based on them. + *
+ * The schema used when visiting these kudo tables must be same as the schema used when creating these kudo tables. + */ +abstract class MultiKudoTableVisitor implements SchemaVisitor { + private final List tables; + private final long[] currentValidityOffsets; + private final long[] currentOffsetOffsets; + private final long[] currentDataOffset; + private final Deque[] sliceInfoStack; + private final Deque totalRowCountStack; + // A temporary variable to keep if current column has null + private boolean hasNull; + private int currentIdx; + // Temporary buffer to store data length of string column to avoid repeated allocation + private final int[] strDataLen; + // Temporary variable to calculate total data length of string column + private long totalStrDataLen; + + protected MultiKudoTableVisitor(List inputTables) { + Objects.requireNonNull(inputTables, "tables cannot be null"); + ensure(!inputTables.isEmpty(), "tables cannot be empty"); + this.tables = inputTables instanceof ArrayList ? 
inputTables : new ArrayList<>(inputTables); + this.currentValidityOffsets = new long[tables.size()]; + this.currentOffsetOffsets = new long[tables.size()]; + this.currentDataOffset = new long[tables.size()]; + this.sliceInfoStack = new Deque[tables.size()]; + long totalRowCount = 0L; + for (int i = 0; i < tables.size(); i++) { + this.currentValidityOffsets[i] = 0; + KudoTableHeader header = tables.get(i).getHeader(); + this.currentOffsetOffsets[i] = header.getValidityBufferLen(); + this.currentDataOffset[i] = header.getValidityBufferLen() + header.getOffsetBufferLen(); + this.sliceInfoStack[i] = new ArrayDeque<>(16); + this.sliceInfoStack[i].add(new SliceInfo(header.getOffset(), header.getNumRows())); + totalRowCount += header.getNumRows(); + } + this.totalRowCountStack = new ArrayDeque<>(16); + totalRowCountStack.addLast(toIntExact(totalRowCount)); + this.hasNull = true; + this.currentIdx = 0; + this.strDataLen = new int[tables.size()]; + this.totalStrDataLen = 0; + } + + List getTables() { + return tables; + } + + @Override + public R visitTopSchema(Schema schema, List children) { + return doVisitTopSchema(schema, children); + } + + protected abstract R doVisitTopSchema(Schema schema, List children); + + @Override + public T visitStruct(Schema structType, List children) { + updateHasNull(); + T t = doVisitStruct(structType, children); + updateOffsets( + false, // Update offset buffer offset + false, // Update data buffer offset + false, // Update slice info + -1 // element size in bytes, not used for struct + ); + currentIdx += 1; + return t; + } + + protected abstract T doVisitStruct(Schema structType, List children); + + @Override + public P preVisitList(Schema listType) { + updateHasNull(); + P t = doPreVisitList(listType); + updateOffsets( + true, // update offset buffer offset + false, // update data buffer offset + true, // update slice info + Integer.BYTES // element size in bytes + ); + currentIdx += 1; + return t; + } + + protected abstract P 
doPreVisitList(Schema listType); + + @Override + public T visitList(Schema listType, P preVisitResult, T childResult) { + T t = doVisitList(listType, preVisitResult, childResult); + for (int tableIdx = 0; tableIdx < tables.size(); tableIdx++) { + sliceInfoStack[tableIdx].removeLast(); + } + totalRowCountStack.removeLast(); + return t; + } + + protected abstract T doVisitList(Schema listType, P preVisitResult, T childResult); + + @Override + public T visit(Schema primitiveType) { + updateHasNull(); + if (primitiveType.getType().hasOffsets()) { + // string type + updateDataLen(); + } + + T t = doVisit(primitiveType); + if (primitiveType.getType().hasOffsets()) { + updateOffsets( + true, // update offset buffer offset + true, // update data buffer offset + false, // update slice info + -1 // element size in bytes, not used for string + ); + } else { + updateOffsets( + false, //update offset buffer offset + true, // update data buffer offset + false, // update slice info + primitiveType.getType().getSizeInBytes() // element size in bytes + ); + } + currentIdx += 1; + return t; + } + + protected abstract T doVisit(Schema primitiveType); + + private void updateHasNull() { + hasNull = false; + for (KudoTable table : tables) { + if (table.getHeader().hasValidityBuffer(currentIdx)) { + hasNull = true; + return; + } + } + } + + // For string column only + private void updateDataLen() { + totalStrDataLen = 0; + // String's data len needs to be calculated from offset buffer + for (int tableIdx = 0; tableIdx < getTableSize(); tableIdx += 1) { + SliceInfo sliceInfo = sliceInfoOf(tableIdx); + if (sliceInfo.getRowCount() > 0) { + int offset = offsetOf(tableIdx, 0); + int endOffset = offsetOf(tableIdx, sliceInfo.getRowCount()); + + strDataLen[tableIdx] = endOffset - offset; + totalStrDataLen += strDataLen[tableIdx]; + } else { + strDataLen[tableIdx] = 0; + } + } + } + + private void updateOffsets(boolean updateOffset, boolean updateData, boolean updateSliceInfo, int sizeInBytes) { 
+ long totalRowCount = 0; + for (int tableIdx = 0; tableIdx < tables.size(); tableIdx++) { + SliceInfo sliceInfo = sliceInfoOf(tableIdx); + if (sliceInfo.getRowCount() > 0) { + if (updateSliceInfo) { + int startOffset = offsetOf(tableIdx, 0); + int endOffset = offsetOf(tableIdx, sliceInfo.getRowCount()); + int rowCount = endOffset - startOffset; + totalRowCount += rowCount; + + sliceInfoStack[tableIdx].addLast(new SliceInfo(startOffset, rowCount)); + } + + if (tables.get(tableIdx).getHeader().hasValidityBuffer(currentIdx)) { + currentValidityOffsets[tableIdx] += padForHostAlignment(sliceInfo.getValidityBufferInfo().getBufferLength()); + } + + if (updateOffset) { + currentOffsetOffsets[tableIdx] += padForHostAlignment((sliceInfo.getRowCount() + 1) * Integer.BYTES); + if (updateData) { + // string type + currentDataOffset[tableIdx] += padForHostAlignment(strDataLen[tableIdx]); + } + // otherwise list type + } else { + if (updateData) { + // primitive type + currentDataOffset[tableIdx] += padForHostAlignment(sliceInfo.getRowCount() * sizeInBytes); + } + } + + } else { + if (updateSliceInfo) { + sliceInfoStack[tableIdx].addLast(new SliceInfo(0, 0)); + } + } + } + + if (updateSliceInfo) { + totalRowCountStack.addLast(toIntExact(totalRowCount)); + } + } + + // Below parts are information about current column + + protected int getTotalRowCount() { + return totalRowCountStack.getLast(); + } + + + protected boolean hasNull() { + return hasNull; + } + + protected SliceInfo sliceInfoOf(int tableIdx) { + return sliceInfoStack[tableIdx].getLast(); + } + + protected HostMemoryBuffer memoryBufferOf(int tableIdx) { + return tables.get(tableIdx).getBuffer(); + } + + protected int offsetOf(int tableIdx, long rowIdx) { + long startOffset = currentOffsetOffsets[tableIdx]; + return tables.get(tableIdx).getBuffer().getInt(startOffset + rowIdx * Integer.BYTES); + } + + protected long validifyBufferOffset(int tableIdx) { + if 
(tables.get(tableIdx).getHeader().hasValidityBuffer(currentIdx)) { + return currentValidityOffsets[tableIdx]; + } else { + return INVALID_OFFSET; + } + } + + protected void copyDataBuffer(HostMemoryBuffer dst, long dstOffset, int tableIdx, int dataLen) { + long startOffset = currentDataOffset[tableIdx]; + dst.copyFromHostBuffer(dstOffset, tables.get(tableIdx).getBuffer(), startOffset, dataLen); + } + + protected long getTotalStrDataLen() { + return totalStrDataLen; + } + + protected int getStrDataLenOf(int tableIdx) { + return strDataLen[tableIdx]; + } + + protected int getCurrentIdx() { + return currentIdx; + } + + public int getTableSize() { + return this.tables.size(); + } +} diff --git a/src/main/java/com/nvidia/spark/rapids/jni/kudo/SliceInfo.java b/src/main/java/com/nvidia/spark/rapids/jni/kudo/SliceInfo.java new file mode 100644 index 0000000000..bd82d3c634 --- /dev/null +++ b/src/main/java/com/nvidia/spark/rapids/jni/kudo/SliceInfo.java @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.nvidia.spark.rapids.jni.kudo; + +public class SliceInfo { + final int offset; + final int rowCount; + private final SlicedValidityBufferInfo validityBufferInfo; + + SliceInfo(int offset, int rowCount) { + this.offset = offset; + this.rowCount = rowCount; + this.validityBufferInfo = SlicedValidityBufferInfo.calc(offset, rowCount); + } + + SlicedValidityBufferInfo getValidityBufferInfo() { + return validityBufferInfo; + } + + public int getOffset() { + return offset; + } + + public int getRowCount() { + return rowCount; + } + + @Override + public String toString() { + return "SliceInfo{" + + "offset=" + offset + + ", rowCount=" + rowCount + + ", validityBufferInfo=" + validityBufferInfo + + '}'; + } +} diff --git a/src/main/java/com/nvidia/spark/rapids/jni/kudo/SlicedBufferSerializer.java b/src/main/java/com/nvidia/spark/rapids/jni/kudo/SlicedBufferSerializer.java new file mode 100644 index 0000000000..e22a523855 --- /dev/null +++ b/src/main/java/com/nvidia/spark/rapids/jni/kudo/SlicedBufferSerializer.java @@ -0,0 +1,210 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.nvidia.spark.rapids.jni.kudo; + +import ai.rapids.cudf.BufferType; +import ai.rapids.cudf.DType; +import ai.rapids.cudf.HostColumnVectorCore; +import ai.rapids.cudf.HostMemoryBuffer; +import com.nvidia.spark.rapids.jni.schema.HostColumnsVisitor; + +import java.io.IOException; +import java.util.ArrayDeque; +import java.util.Deque; +import java.util.List; + +import static com.nvidia.spark.rapids.jni.kudo.KudoSerializer.padForHostAlignment; + +/** + * This class visits a list of columns and serialize one of the buffers (validity, offset, or data) into with kudo + * format. + * + *

+ * The host columns are visited in post order. For more details about the visiting process, please refer to + * {@link HostColumnsVisitor}. + *

+ * + *

+ * For more details about the kudo format, please refer to {@link KudoSerializer}. + *

+ */ +class SlicedBufferSerializer implements HostColumnsVisitor { + private final SliceInfo root; + private final BufferType bufferType; + private final DataWriter writer; + + private final Deque sliceInfos = new ArrayDeque<>(); + private long totalDataLen; + + SlicedBufferSerializer(int rowOffset, int numRows, BufferType bufferType, DataWriter writer) { + this.root = new SliceInfo(rowOffset, numRows); + this.bufferType = bufferType; + this.writer = writer; + this.sliceInfos.addLast(root); + this.totalDataLen = 0; + } + + public long getTotalDataLen() { + return totalDataLen; + } + + @Override + public Void visitStruct(HostColumnVectorCore col, List children) { + SliceInfo parent = sliceInfos.peekLast(); + + try { + switch (bufferType) { + case VALIDITY: + totalDataLen += this.copySlicedValidity(col, parent); + return null; + case OFFSET: + case DATA: + return null; + default: + throw new IllegalArgumentException("Unexpected buffer type: " + bufferType); + } + + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + @Override + public Void preVisitList(HostColumnVectorCore col) { + SliceInfo parent = sliceInfos.getLast(); + + + long bytesCopied = 0; + try { + switch (bufferType) { + case VALIDITY: + bytesCopied = this.copySlicedValidity(col, parent); + break; + case OFFSET: + bytesCopied = this.copySlicedOffset(col, parent); + break; + case DATA: + break; + default: + throw new IllegalArgumentException("Unexpected buffer type: " + bufferType); + } + + } catch (IOException e) { + throw new RuntimeException(e); + } + + SliceInfo current; + if (col.getOffsets() != null) { + int start = col.getOffsets() + .getInt(parent.offset * Integer.BYTES); + int end = col.getOffsets().getInt((parent.offset + parent.rowCount) * Integer.BYTES); + int rowCount = end - start; + + current = new SliceInfo(start, rowCount); + } else { + current = new SliceInfo(0, 0); + } + + sliceInfos.addLast(current); + + totalDataLen += bytesCopied; + return null; + } + + @Override + 
public Void visitList(HostColumnVectorCore col, Void preVisitResult, Void childResult) { + sliceInfos.removeLast(); + return null; + } + + @Override + public Void visit(HostColumnVectorCore col) { + SliceInfo parent = sliceInfos.getLast(); + try { + switch (bufferType) { + case VALIDITY: + totalDataLen += this.copySlicedValidity(col, parent); + return null; + case OFFSET: + totalDataLen += this.copySlicedOffset(col, parent); + return null; + case DATA: + totalDataLen += this.copySlicedData(col, parent); + return null; + default: + throw new IllegalArgumentException("Unexpected buffer type: " + bufferType); + } + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + private long copySlicedValidity(HostColumnVectorCore column, SliceInfo sliceInfo) throws IOException { + if (column.getValidity() != null && sliceInfo.getRowCount() > 0) { + HostMemoryBuffer buff = column.getValidity(); + long len = sliceInfo.getValidityBufferInfo().getBufferLength(); + writer.copyDataFrom(buff, sliceInfo.getValidityBufferInfo().getBufferOffset(), + len); + return padForHostAlignment(writer, len); + } else { + return 0; + } + } + + private long copySlicedOffset(HostColumnVectorCore column, SliceInfo sliceInfo) throws IOException { + if (sliceInfo.rowCount <= 0 || column.getOffsets() == null) { + // Don't copy anything, there are no rows + return 0; + } + long bytesToCopy = (sliceInfo.rowCount + 1) * Integer.BYTES; + long srcOffset = sliceInfo.offset * Integer.BYTES; + HostMemoryBuffer buff = column.getOffsets(); + writer.copyDataFrom(buff, srcOffset, bytesToCopy); + return padForHostAlignment(writer, bytesToCopy); + } + + private long copySlicedData(HostColumnVectorCore column, SliceInfo sliceInfo) throws IOException { + if (sliceInfo.rowCount > 0) { + DType type = column.getType(); + if (type.equals(DType.STRING)) { + long startByteOffset = column.getOffsets().getInt(sliceInfo.offset * Integer.BYTES); + long endByteOffset = column.getOffsets().getInt((sliceInfo.offset + 
sliceInfo.rowCount) * Integer.BYTES); + long bytesToCopy = endByteOffset - startByteOffset; + if (column.getData() == null) { + if (bytesToCopy != 0) { + throw new IllegalStateException("String column has no data buffer, " + + "but bytes to copy is not zero: " + bytesToCopy); + } + + return 0; + } else { + writer.copyDataFrom(column.getData(), startByteOffset, bytesToCopy); + return padForHostAlignment(writer, bytesToCopy); + } + } else if (type.getSizeInBytes() > 0) { + long bytesToCopy = sliceInfo.rowCount * type.getSizeInBytes(); + long srcOffset = sliceInfo.offset * type.getSizeInBytes(); + writer.copyDataFrom(column.getData(), srcOffset, bytesToCopy); + return padForHostAlignment(writer, bytesToCopy); + } else { + return 0; + } + } else { + return 0; + } + } +} diff --git a/src/main/java/com/nvidia/spark/rapids/jni/kudo/SlicedValidityBufferInfo.java b/src/main/java/com/nvidia/spark/rapids/jni/kudo/SlicedValidityBufferInfo.java new file mode 100644 index 0000000000..7c9957f5b2 --- /dev/null +++ b/src/main/java/com/nvidia/spark/rapids/jni/kudo/SlicedValidityBufferInfo.java @@ -0,0 +1,67 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.jni.kudo; + +/** + * A simple utility class to hold information about serializing/deserializing sliced validity buffer. 
+ */ +class SlicedValidityBufferInfo { + private final int bufferOffset; + private final int bufferLength; + /// The bit offset within the buffer where the slice starts + private final int beginBit; + + SlicedValidityBufferInfo(int bufferOffset, int bufferLength, int beginBit) { + this.bufferOffset = bufferOffset; + this.bufferLength = bufferLength; + this.beginBit = beginBit; + } + + @Override + public String toString() { + return "SlicedValidityBufferInfo{" + "bufferOffset=" + bufferOffset + ", bufferLength=" + bufferLength + + ", beginBit=" + beginBit + '}'; + } + + public int getBufferOffset() { + return bufferOffset; + } + + public int getBufferLength() { + return bufferLength; + } + + public int getBeginBit() { + return beginBit; + } + + static SlicedValidityBufferInfo calc(int rowOffset, int numRows) { + if (rowOffset < 0) { + throw new IllegalArgumentException("rowOffset must be >= 0, but was " + rowOffset); + } + if (numRows < 0) { + throw new IllegalArgumentException("numRows must be >= 0, but was " + numRows); + } + int bufferOffset = rowOffset / 8; + int beginBit = rowOffset % 8; + int bufferLength = 0; + if (numRows > 0) { + bufferLength = (rowOffset + numRows - 1) / 8 - bufferOffset + 1; + } + return new SlicedValidityBufferInfo(bufferOffset, bufferLength, beginBit); + } +} diff --git a/src/main/java/com/nvidia/spark/rapids/jni/kudo/TableBuilder.java b/src/main/java/com/nvidia/spark/rapids/jni/kudo/TableBuilder.java new file mode 100644 index 0000000000..e50e462f4f --- /dev/null +++ b/src/main/java/com/nvidia/spark/rapids/jni/kudo/TableBuilder.java @@ -0,0 +1,115 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.jni.kudo; + +import ai.rapids.cudf.*; +import com.nvidia.spark.rapids.jni.Arms; +import com.nvidia.spark.rapids.jni.schema.SchemaVisitor; + +import java.util.ArrayList; +import java.util.List; + +import static com.nvidia.spark.rapids.jni.Preconditions.ensure; +import static java.util.Objects.requireNonNull; + +/** + * This class is used to build a cudf table from a list of column view info, and a device buffer. + */ +class TableBuilder implements SchemaVisitor, AutoCloseable { + private int curColumnIdx; + private final DeviceMemoryBuffer buffer; + private final List colViewInfoList; + private final List columnViewList; + + public TableBuilder(List colViewInfoList, DeviceMemoryBuffer buffer) { + requireNonNull(colViewInfoList, "colViewInfoList cannot be null"); + ensure(!colViewInfoList.isEmpty(), "colViewInfoList cannot be empty"); + requireNonNull(buffer, "Device buffer can't be null!"); + + this.curColumnIdx = 0; + this.buffer = buffer; + this.colViewInfoList = colViewInfoList; + this.columnViewList = new ArrayList<>(colViewInfoList.size()); + } + + @Override + public Table visitTopSchema(Schema schema, List children) { + // When this method is called, the ownership of the column views in `columnViewList` has been transferred to + // `children`, so we need to clear `columnViewList`. 
+ this.columnViewList.clear(); + try { + try (CloseableArray arr = CloseableArray.wrap(new ColumnVector[children.size()])) { + for (int i = 0; i < children.size(); i++) { + ColumnView colView = children.set(i, null); + arr.set(i, ColumnVector.fromViewWithContiguousAllocation(colView.getNativeView(), buffer)); + } + + return new Table(arr.getArray()); + } + } finally { + Arms.closeAll(columnViewList); + } + } + + @Override + public ColumnView visitStruct(Schema structType, List children) { + ColumnViewInfo colViewInfo = getCurrentColumnViewInfo(); + + ColumnView[] childrenView = children.toArray(new ColumnView[0]); + ColumnView columnView = colViewInfo.buildColumnView(buffer, childrenView); + curColumnIdx += 1; + columnViewList.add(columnView); + return columnView; + } + + @Override + public ColumnViewInfo preVisitList(Schema listType) { + ColumnViewInfo colViewInfo = getCurrentColumnViewInfo(); + + curColumnIdx += 1; + return colViewInfo; + } + + @Override + public ColumnView visitList(Schema listType, ColumnViewInfo colViewInfo, ColumnView childResult) { + + ColumnView[] children = new ColumnView[]{childResult}; + + ColumnView view = colViewInfo.buildColumnView(buffer, children); + columnViewList.add(view); + return view; + } + + @Override + public ColumnView visit(Schema primitiveType) { + ColumnViewInfo colViewInfo = getCurrentColumnViewInfo(); + + ColumnView columnView = colViewInfo.buildColumnView(buffer, null); + curColumnIdx += 1; + columnViewList.add(columnView); + return columnView; + } + + private ColumnViewInfo getCurrentColumnViewInfo() { + return colViewInfoList.get(curColumnIdx); + } + + @Override + public void close() throws Exception { + Arms.closeAll(columnViewList); + } +} diff --git a/src/main/java/com/nvidia/spark/rapids/jni/schema/HostColumnsVisitor.java b/src/main/java/com/nvidia/spark/rapids/jni/schema/HostColumnsVisitor.java new file mode 100644 index 0000000000..ae7915c60d --- /dev/null +++ 
b/src/main/java/com/nvidia/spark/rapids/jni/schema/HostColumnsVisitor.java @@ -0,0 +1,84 @@ +/* + * + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package com.nvidia.spark.rapids.jni.schema; + +import ai.rapids.cudf.HostColumnVectorCore; + +import java.util.List; + +/** + * A post order visitor for visiting a list of host columns in a schema. + * + *

+ * + * For example, if we have three columns A, B, and C with the following types: + * + *

    + *
  • A: struct { int a1; long a2}
  • + *
  • B: list { int b1}
  • + *
  • C: string c1
  • + *
+ * + * The order of visiting will be: + *
    + *
  1. Visit primitive column a1
  2. + *
  3. Visit primitive column a2
  4. + *
  5. Visit struct column A, with results from a1 and a2
  6. + *
  7. Previsit list column B
  8. + *
  9. Visit primitive column b1
  10. + *
  11. Visit list column B
  12. + *
  13. Visit primitive column c1
  14. + *
+ * + *

+ * + * @param Return type when visiting intermediate nodes. + */ +public interface HostColumnsVisitor { + /** + * Visit a struct column. + * @param col the struct column to visit + * @param children the results of visiting the children + * @return the result of visiting the struct column + */ + T visitStruct(HostColumnVectorCore col, List children); + + /** + * Visit a list column before actually visiting its child. + * @param col the list column to visit + * @return the result of visiting the list column + */ + T preVisitList(HostColumnVectorCore col); + + /** + * Visit a list column after visiting its child. + * @param col the list column to visit + * @param preVisitResult the result of visiting the list column before visiting its child + * @param childResult the result of visiting the child + * @return the result of visiting the list column + */ + T visitList(HostColumnVectorCore col, T preVisitResult, T childResult); + + /** + * Visit a column that is a primitive type. + * @param col the column to visit + * @return the result of visiting the column + */ + T visit(HostColumnVectorCore col); +} diff --git a/src/main/java/com/nvidia/spark/rapids/jni/schema/SchemaVisitor.java b/src/main/java/com/nvidia/spark/rapids/jni/schema/SchemaVisitor.java new file mode 100644 index 0000000000..c6b33e0fb4 --- /dev/null +++ b/src/main/java/com/nvidia/spark/rapids/jni/schema/SchemaVisitor.java @@ -0,0 +1,120 @@ +/* + * + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package com.nvidia.spark.rapids.jni.schema; + +import ai.rapids.cudf.Schema; + +import java.util.List; + +/** + * A post order visitor for schemas. + * + *

Flattened Schema

+ * + * A flattened schema is a schema where all fields with nested types are flattened into an array of fields. For example, + * for a schema with following fields: + * + *
    + *
  • A: struct { int a1; long a2}
  • + *
  • B: list { int b1}
  • + *
  • C: string
  • + *
  • D: long
  • + *
+ * + * The flattened schema will be: + * + *
    + *
  • A: struct
  • + *
  • A.a1: int
  • + *
  • A.a2: long
  • + *
  • B: list
  • + *
  • B.b1: int
  • + *
  • C: string
  • + *
  • D: long
  • + *
+ * + *

Example

+ * + *

+ * This visitor visits each filed in the flattened schema in post order. For example, if our schema consists of three + * fields A, B, and C with following fields: + *

    + *
  • A: struct { int a1; long a2}
  • + *
  • B: list { int b1}
  • + *
  • C: string
  • + *
+ * + * The order of visiting will be: + *
    + *
  1. Visit primitive field a1
  2. + *
  3. Visit primitive field a2
  4. + *
  5. Visit struct field A, with results from a1 and a2
  6. + *
  7. Previsit list field B
  8. + *
  9. Visit primitive field b1
  10. + *
  11. Visit list field B with results from b1 and previsit result.
  12. + *
  13. Visit primitive field C
  14. + *
  15. Visit top schema with results from fields A, B, and C
  16. + *
+ * + *

+ * + * @param Return type when visiting intermediate nodes. + * @param

Return type after visiting a list schema before visiting its child. + * @param Return type after processing all children values. + */ +public interface SchemaVisitor { + /** + * Visit the top level schema. + * @param schema the top level schema to visit + * @param children the results of visiting the children + * @return the result of visiting the top level schema + */ + R visitTopSchema(Schema schema, List children); + + /** + * Visit a struct schema. + * @param structType the struct schema to visit + * @param children the results of visiting the children + * @return the result of visiting the struct schema + */ + T visitStruct(Schema structType, List children); + + /** + * Visit a list schema before actually visiting its child. + * @param listType the list schema to visit + * @return the result of visiting the list schema + */ + P preVisitList(Schema listType); + + /** + * Visit a list schema after visiting its child. + * @param listType the list schema to visit + * @param preVisitResult the result of visiting the list schema before visiting its child + * @param childResult the result of visiting the child + * @return the result of visiting the list schema + */ + T visitList(Schema listType, P preVisitResult, T childResult); + + /** + * Visit a primitive type. + * @param primitiveType the primitive type to visit + * @return the result of visiting the primitive type + */ + T visit(Schema primitiveType); +} diff --git a/src/main/java/com/nvidia/spark/rapids/jni/schema/Visitors.java b/src/main/java/com/nvidia/spark/rapids/jni/schema/Visitors.java new file mode 100644 index 0000000000..b7f4f521e4 --- /dev/null +++ b/src/main/java/com/nvidia/spark/rapids/jni/schema/Visitors.java @@ -0,0 +1,105 @@ +/* + * + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package com.nvidia.spark.rapids.jni.schema; + +import ai.rapids.cudf.HostColumnVector; +import ai.rapids.cudf.HostColumnVectorCore; +import ai.rapids.cudf.Schema; + +import java.util.List; +import java.util.Objects; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +/** + * A utility class for visiting a schema or a list of host columns. + */ +public class Visitors { + /** + * Visiting a schema in post order. For more details, see {@link SchemaVisitor}. + * + * @param schema the schema to visit + * @param visitor the visitor to use + * @param Return type when visiting intermediate nodes. See {@link SchemaVisitor} + * @param

Return type when previsiting a list. See {@link SchemaVisitor} + * @param Return type after processing all children values. See {@link SchemaVisitor} + * @return the result of visiting the schema + */ + public static R visitSchema(Schema schema, SchemaVisitor visitor) { + Objects.requireNonNull(schema, "schema cannot be null"); + Objects.requireNonNull(visitor, "visitor cannot be null"); + + List childrenResult = IntStream.range(0, schema.getNumChildren()) + .mapToObj(i -> visitSchemaInner(schema.getChild(i), visitor)) + .collect(Collectors.toList()); + + return visitor.visitTopSchema(schema, childrenResult); + } + + private static T visitSchemaInner(Schema schema, SchemaVisitor visitor) { + switch (schema.getType().getTypeId()) { + case STRUCT: + List children = IntStream.range(0, schema.getNumChildren()) + .mapToObj(childIdx -> visitSchemaInner(schema.getChild(childIdx), visitor)) + .collect(Collectors.toList()); + return visitor.visitStruct(schema, children); + case LIST: + P preVisitResult = visitor.preVisitList(schema); + T childResult = visitSchemaInner(schema.getChild(0), visitor); + return visitor.visitList(schema, preVisitResult, childResult); + default: + return visitor.visit(schema); + } + } + + + /** + * Visiting a list of host columns in post order. For more details, see {@link HostColumnsVisitor}. + * + * @param cols the list of host columns to visit + * @param visitor the visitor to use + * @param Return type when visiting intermediate nodes. 
See {@link HostColumnsVisitor} + */ + public static void visitColumns(HostColumnVector[] cols, + HostColumnsVisitor visitor) { + Objects.requireNonNull(cols, "cols cannot be null"); + Objects.requireNonNull(visitor, "visitor cannot be null"); + + for (HostColumnVector col : cols) { + visitColumn(col, visitor); + } + + } + + private static T visitColumn(HostColumnVectorCore col, HostColumnsVisitor visitor) { + switch (col.getType().getTypeId()) { + case STRUCT: + List children = IntStream.range(0, col.getNumChildren()) + .mapToObj(childIdx -> visitColumn(col.getChildColumnView(childIdx), visitor)) + .collect(Collectors.toList()); + return visitor.visitStruct(col, children); + case LIST: + T preVisitResult = visitor.preVisitList(col); + T childResult = visitColumn(col.getChildColumnView(0), visitor); + return visitor.visitList(col, preVisitResult, childResult); + default: + return visitor.visit(col); + } + } +} diff --git a/src/test/java/com/nvidia/spark/rapids/jni/FromJsonToRawMapTest.java b/src/test/java/com/nvidia/spark/rapids/jni/FromJsonToRawMapTest.java index 8edff2f4c8..e975b1a068 100644 --- a/src/test/java/com/nvidia/spark/rapids/jni/FromJsonToRawMapTest.java +++ b/src/test/java/com/nvidia/spark/rapids/jni/FromJsonToRawMapTest.java @@ -19,11 +19,21 @@ import ai.rapids.cudf.ColumnVector; import ai.rapids.cudf.BinaryOp; +import ai.rapids.cudf.JSONOptions; import org.junit.jupiter.api.Test; import static ai.rapids.cudf.AssertUtils.assertColumnsAreEqual; public class FromJsonToRawMapTest { + private static JSONOptions getOptions() { + return JSONOptions.builder() + .withNormalizeSingleQuotes(true) + .withLeadingZeros(true) + .withNonNumericNumbers(true) + .withUnquotedControlChars(true) + .build(); + } + @Test void testFromJsonSimpleInput() { @@ -36,7 +46,7 @@ void testFromJsonSimpleInput() { try (ColumnVector input = ColumnVector.fromStrings(jsonString1, jsonString2, null, jsonString3); - ColumnVector outputMap = JSONUtils.extractRawMapFromJsonString(input); + 
ColumnVector outputMap = JSONUtils.extractRawMapFromJsonString(input, getOptions()); ColumnVector expectedKeys = ColumnVector.fromStrings("Zipcode", "ZipCodeType", "City", "State", "category", "index", "author", "title", "price"); @@ -65,7 +75,7 @@ void testFromJsonWithUTF8() { try (ColumnVector input = ColumnVector.fromStrings(jsonString1, jsonString2, null, jsonString3); - ColumnVector outputMap = JSONUtils.extractRawMapFromJsonString(input); + ColumnVector outputMap = JSONUtils.extractRawMapFromJsonString(input, getOptions()); ColumnVector expectedKeys = ColumnVector.fromStrings("Zipc\u00f3de", "Z\u00edpCodeTyp" + "\u00e9", "City", "St\u00e2te", "Zipc\u00f3de", "Z\u00edpCodeTyp\u00e9", @@ -82,4 +92,42 @@ void testFromJsonWithUTF8() { assertColumnsAreEqual(expectedMap, outputMap); } } + + @Test + void testFromJsonEmptyAndInvalidInput() { + try (ColumnVector input = + ColumnVector.fromStrings("{}", "BAD", "{\"A\": 100}"); + ColumnVector outputMap = JSONUtils.extractRawMapFromJsonString(input, getOptions()); + + ColumnVector expectedKeys = ColumnVector.fromStrings("A"); + ColumnVector expectedValues = ColumnVector.fromStrings("100"); + ColumnVector expectedStructs = ColumnVector.makeStruct(expectedKeys, expectedValues); + ColumnVector expectedOffsets = ColumnVector.fromInts(0, 0, 0, 1); + ColumnVector tmpMap = expectedStructs.makeListFromOffsets(3, expectedOffsets); + ColumnVector templateBitmask = ColumnVector.fromBoxedInts(1, null, 1); + ColumnVector expectedMap = tmpMap.mergeAndSetValidity(BinaryOp.BITWISE_AND, + templateBitmask); + ) { + assertColumnsAreEqual(expectedMap, outputMap); + } + } + + @Test + void testFromJsonInputWithSingleQuotes() { + try (ColumnVector input = + ColumnVector.fromStrings("{'teacher': 'ABC', 'student': 'XYZ'}", + "invalid", "null", "", " "); + ColumnVector outputMap = JSONUtils.extractRawMapFromJsonString(input, getOptions()); + ColumnVector expectedKeys = ColumnVector.fromStrings("teacher", "student"); + ColumnVector expectedValues 
= ColumnVector.fromStrings("ABC", "XYZ"); + ColumnVector expectedStructs = ColumnVector.makeStruct(expectedKeys, expectedValues); + ColumnVector expectedOffsets = ColumnVector.fromInts(0, 2, 2, 2, 2, 2); + ColumnVector tmpMap = expectedStructs.makeListFromOffsets(5, expectedOffsets); + ColumnVector templateBitmask = ColumnVector.fromBoxedInts(1, null, null, null, null); + ColumnVector expectedMap = tmpMap.mergeAndSetValidity(BinaryOp.BITWISE_AND, + templateBitmask); + ) { + assertColumnsAreEqual(expectedMap, outputMap); + } + } } diff --git a/src/test/java/com/nvidia/spark/rapids/jni/HashTest.java b/src/test/java/com/nvidia/spark/rapids/jni/HashTest.java index d35f20fe2c..19172a8d33 100644 --- a/src/test/java/com/nvidia/spark/rapids/jni/HashTest.java +++ b/src/test/java/com/nvidia/spark/rapids/jni/HashTest.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,6 +18,7 @@ import ai.rapids.cudf.ColumnVector; import ai.rapids.cudf.ColumnView; +import ai.rapids.cudf.CudfException; import ai.rapids.cudf.DType; import ai.rapids.cudf.HostColumnVector.*; import org.junit.jupiter.api.Test; @@ -26,6 +27,7 @@ import java.util.Collections; import static ai.rapids.cudf.AssertUtils.*; +import static org.junit.jupiter.api.Assertions.assertThrows; public class HashTest { // IEEE 754 NaN values @@ -510,4 +512,190 @@ void testHiveHashMixed() { assertColumnsAreEqual(expected, result); } } + + @Test + void testHiveHashStruct() { + try (ColumnVector strings = ColumnVector.fromStrings( + "a", "B\n", "dE\"\u0100\t\u0101 \ud720\ud721", + "This is a long string (greater than 128 bytes/char string) case to test this " + + "hash function. 
Just want an abnormal case here to see if any error may happen when" + + "doing the hive hashing", + null, null); + ColumnVector integers = ColumnVector.fromBoxedInts( + 0, 100, -100, Integer.MIN_VALUE, Integer.MAX_VALUE, null); + ColumnVector doubles = ColumnVector.fromBoxedDoubles(0.0, 100.0, -100.0, + POSITIVE_DOUBLE_NAN_LOWER_RANGE, POSITIVE_DOUBLE_NAN_UPPER_RANGE, null); + ColumnVector floats = ColumnVector.fromBoxedFloats(0f, 100f, -100f, + NEGATIVE_FLOAT_NAN_LOWER_RANGE, NEGATIVE_FLOAT_NAN_UPPER_RANGE, null); + ColumnVector bools = ColumnVector.fromBoxedBooleans( + true, false, null, false, true, null); + ColumnView structs = ColumnView.makeStructView(strings, integers, doubles, floats, bools); + ColumnVector result = Hash.hiveHash(new ColumnView[]{structs}); + ColumnVector expected = ColumnVector.fromInts(89581538, 363542820, 413439036, + 1272817854, 1513589666, 0)) { + assertColumnsAreEqual(expected, result); + } + } + + @Test + void testHiveHashNestedStruct() { + try (ColumnVector strings = ColumnVector.fromStrings( + "a", "B\n", "dE\"\u0100\t\u0101 \ud720\ud721", + "This is a long string (greater than 128 bytes/char string) case to test this " + + "hash function. 
Just want an abnormal case here to see if any error may happen when" + + "doing the hive hashing", + null, null); + ColumnVector integers = ColumnVector.fromBoxedInts( + 0, 100, -100, Integer.MIN_VALUE, Integer.MAX_VALUE, null); + ColumnVector doubles = ColumnVector.fromBoxedDoubles(0.0, 100.0, -100.0, + POSITIVE_DOUBLE_NAN_LOWER_RANGE, POSITIVE_DOUBLE_NAN_UPPER_RANGE, null); + ColumnVector floats = ColumnVector.fromBoxedFloats(0f, 100f, -100f, + NEGATIVE_FLOAT_NAN_LOWER_RANGE, NEGATIVE_FLOAT_NAN_UPPER_RANGE, null); + ColumnVector bools = ColumnVector.fromBoxedBooleans( + true, false, null, false, true, null); + ColumnView structs1 = ColumnView.makeStructView(strings, integers); + ColumnView structs2 = ColumnView.makeStructView(structs1, doubles); + ColumnView structs3 = ColumnView.makeStructView(bools); + ColumnView structs = ColumnView.makeStructView(structs2, floats, structs3); + ColumnVector result = Hash.hiveHash(new ColumnView[]{structs}); + ColumnVector expected = ColumnVector.fromInts(89581538, 363542820, 413439036, + 1272817854, 1513589666, 0)) { + assertColumnsAreEqual(expected, result); + } + } + + @Test + void testHiveHashLists() { + try (ColumnVector stringListCV = ColumnVector.fromLists( + new ListType(true, new BasicType(true, DType.STRING)), + Arrays.asList(null, "a"), + Arrays.asList("B\n", ""), + Arrays.asList("dE\"\u0100\t\u0101", " \ud720\ud721"), + Collections.singletonList("This is a long string (greater than 128 bytes/char string) case to test this " + + "hash function. 
Just want an abnormal case here to see if any error may happen when" + + "doing the hive hashing"), + Collections.singletonList(""), + null); + ColumnVector stringResult = Hash.hiveHash(new ColumnView[]{stringListCV}); + ColumnVector stringExpected = ColumnVector.fromInts(97, 63736, -96263528, 2112075710, 0, 0); + ColumnVector intListCV = ColumnVector.fromLists( + new ListType(true, new BasicType(true, DType.INT32)), + Collections.emptyList(), + Arrays.asList(0, -2, 3), + Collections.singletonList(Integer.MAX_VALUE), + Arrays.asList(5, -6, null), + Collections.singletonList(Integer.MIN_VALUE), + null); + ColumnVector intExpected = ColumnVector.fromInts(0, -59, 2147483647, 4619, -2147483648, 0); + ColumnVector intResult = Hash.hiveHash(new ColumnVector[]{intListCV});) { + assertColumnsAreEqual(stringExpected, stringResult); + assertColumnsAreEqual(intExpected, intResult); + } + } + + @Test + void testHiveHashNestedLists() { + try (ColumnVector nestedStringListCV = ColumnVector.fromLists( + new ListType(true, new ListType(true, new BasicType(true, DType.STRING))), + Arrays.asList(null, Arrays.asList("a", null)), + Arrays.asList(Arrays.asList("B\n", "")), + Arrays.asList(Collections.singletonList("dE\"\u0100\t\u0101"), Collections.singletonList(" \ud720\ud721")), + Arrays.asList(Collections.singletonList("This is a long string (greater than 128 bytes/char string) case to test this " + + "hash function. 
Just want an abnormal case here to see if any error may happen when" + + "doing the hive hashing")), + Arrays.asList(Collections.singletonList(""), null), + null); + ColumnVector stringExpected = ColumnVector.fromInts(3007, 63736, -96263528, 2112075710, 0, 0); + ColumnVector stringResult = Hash.hiveHash(new ColumnView[]{nestedStringListCV}); + ColumnVector nestedIntListCV = ColumnVector.fromLists( + new ListType(true, new ListType(true, new BasicType(true, DType.INT32))), + Arrays.asList(Arrays.asList(null, null), null), + Arrays.asList(Collections.singletonList(0), Collections.singletonList(-2), Collections.singletonList(3)), + Arrays.asList(null, Collections.singletonList(Integer.MAX_VALUE)), + Arrays.asList(Collections.singletonList(5), Arrays.asList(-6, null)), + Arrays.asList(Collections.singletonList(Integer.MIN_VALUE), null), + null); + ColumnVector intExpected = ColumnVector.fromInts(0, -59, 2147483647, -31, -2147483648, 0); + ColumnVector intResult = Hash.hiveHash(new ColumnVector[]{nestedIntListCV});) { + assertColumnsAreEqual(stringExpected, stringResult); + assertColumnsAreEqual(intExpected, intResult); + } + } + + @Test + void testHiveHashStructOfList() { + try (ColumnVector stringListCV = ColumnVector.fromLists( + new ListType(true, new BasicType(true, DType.STRING)), + Arrays.asList(null, "a"), + Arrays.asList("B\n", ""), + Arrays.asList("dE\"\u0100\t\u0101", " \ud720\ud721"), + Collections.singletonList("This is a long string (greater than 128 bytes/char string) case to test this " + + "hash function. 
Just want an abnormal case here to see if any error may happen when" + + "doing the hive hashing"), + Collections.singletonList(""), + null); + ColumnVector intListCV = ColumnVector.fromLists( + new ListType(true, new BasicType(true, DType.INT32)), + Collections.singletonList(null), + Arrays.asList(0, -2, 3), + Collections.singletonList(Integer.MAX_VALUE), + Arrays.asList(5, -6, null), + Collections.singletonList(Integer.MIN_VALUE), + null); + ColumnVector doubles = ColumnVector.fromBoxedDoubles( + 0.0, 100.0, -100.0, POSITIVE_DOUBLE_NAN_LOWER_RANGE, POSITIVE_DOUBLE_NAN_UPPER_RANGE, null); + ColumnVector floats = ColumnVector.fromBoxedFloats( + 0f, 100f, -100f, NEGATIVE_FLOAT_NAN_LOWER_RANGE, NEGATIVE_FLOAT_NAN_UPPER_RANGE, null); + ColumnView structCV = ColumnView.makeStructView(intListCV, stringListCV, doubles, floats); + ColumnVector nestedExpected = ColumnVector.fromInts(93217, 286968083, 59992121, -1697616301, 2127036416, 0); + ColumnVector nestedResult = Hash.hiveHash(new ColumnView[]{structCV})) { + assertColumnsAreEqual(nestedExpected, nestedResult); + } + } + + @Test + void testHiveHashListOfStruct() { + try (ColumnVector structListCV = ColumnVector.fromLists(new ListType(true, new StructType(true, + new BasicType(true, DType.STRING), new BasicType(true, DType.INT32), new BasicType(true, DType.FLOAT64), new BasicType(true, DType.FLOAT32), new BasicType(true, DType.BOOL8))), + Collections.emptyList(), + Collections.singletonList(new StructData("a", 0, 0.0, 0f, true)), + Arrays.asList(new StructData("B\n", 100, 100.0, 100f, false), new StructData("dE\"\u0100\t\u0101 \ud720\ud721", -100, -100.0, -100f, null)), + Collections.singletonList(new StructData("This is a long string (greater than 128 bytes/char string) case to test this " + + "hash function. 
Just want an abnormal case here to see if any error may happen when" + "doing the hive hashing", Integer.MIN_VALUE, POSITIVE_DOUBLE_NAN_LOWER_RANGE, NEGATIVE_FLOAT_NAN_LOWER_RANGE, false)), + Arrays.asList(new StructData(null, Integer.MAX_VALUE, POSITIVE_DOUBLE_NAN_UPPER_RANGE, NEGATIVE_FLOAT_NAN_UPPER_RANGE, true), new StructData(null, null, null, null, null)), + null); + ColumnVector result = Hash.hiveHash(new ColumnView[]{structListCV}); + ColumnVector expected = ColumnVector.fromInts(0, 89581538, -1201635432, 1272817854, -323360610, 0)) { + assertColumnsAreEqual(expected, result); + } + } + + @Test + void testHiveHashNestedDepthExceedsLimit() { + try (ColumnVector nestedIntListCV = ColumnVector.fromLists( + new ListType(true, new ListType(true, new BasicType(true, DType.INT32))), + Arrays.asList(Arrays.asList(null, null), null), + Arrays.asList(Collections.singletonList(0), Collections.singletonList(-2), Collections.singletonList(3)), + Arrays.asList(null, Collections.singletonList(Integer.MAX_VALUE)), + Arrays.asList(Collections.singletonList(5), Arrays.asList(-6, null)), + Arrays.asList(Collections.singletonList(Integer.MIN_VALUE), null), + null); + ColumnVector integers = ColumnVector.fromBoxedInts( + 0, 100, -100, Integer.MIN_VALUE, Integer.MAX_VALUE, null); + ColumnVector doubles = ColumnVector.fromBoxedDoubles(0.0, 100.0, -100.0, + POSITIVE_DOUBLE_NAN_LOWER_RANGE, POSITIVE_DOUBLE_NAN_UPPER_RANGE, null); + ColumnVector floats = ColumnVector.fromBoxedFloats(0f, 100f, -100f, + NEGATIVE_FLOAT_NAN_LOWER_RANGE, NEGATIVE_FLOAT_NAN_UPPER_RANGE, null); + ColumnVector bools = ColumnVector.fromBoxedBooleans( + true, false, null, false, true, null); + ColumnView structs1 = ColumnView.makeStructView(nestedIntListCV, integers); + ColumnView structs2 = ColumnView.makeStructView(structs1, doubles); + ColumnView structs3 = ColumnView.makeStructView(structs2, bools); + ColumnView structs4 = ColumnView.makeStructView(structs3); + ColumnView structs5 = 
ColumnView.makeStructView(structs4, floats); + ColumnView structs6 = ColumnView.makeStructView(structs5); + ColumnView nestedResult = ColumnView.makeStructView(structs6);) { + assertThrows(CudfException.class, () -> Hash.hiveHash(new ColumnView[]{nestedResult})); + } + } } diff --git a/src/test/java/com/nvidia/spark/rapids/jni/HostTableTest.java b/src/test/java/com/nvidia/spark/rapids/jni/HostTableTest.java new file mode 100644 index 0000000000..0064dee1f5 --- /dev/null +++ b/src/test/java/com/nvidia/spark/rapids/jni/HostTableTest.java @@ -0,0 +1,153 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.nvidia.spark.rapids.jni; + +import ai.rapids.cudf.AssertUtils; +import ai.rapids.cudf.Cuda; +import ai.rapids.cudf.DType; +import ai.rapids.cudf.ColumnVector; +import ai.rapids.cudf.Table; +import ai.rapids.cudf.HostColumnVector.BasicType; +import ai.rapids.cudf.HostColumnVector.DataType; +import ai.rapids.cudf.HostColumnVector.ListType; +import ai.rapids.cudf.HostColumnVector.StructData; +import ai.rapids.cudf.HostColumnVector.StructType; +import org.junit.jupiter.api.Test; + +public class HostTableTest { + @Test + public void testRoundTripSync() { + try (Table expected = buildTable()) { + try (HostTable ht = HostTable.fromTable(expected, Cuda.DEFAULT_STREAM)) { + Table actual = ht.toTable(Cuda.DEFAULT_STREAM); + AssertUtils.assertTablesAreEqual(expected, actual); + } + } + } + + @Test + public void testRoundTripSyncDefault() { + try (Table expected = buildTable()) { + try (HostTable ht = HostTable.fromTable(expected)) { + Table actual = ht.toTable(); + AssertUtils.assertTablesAreEqual(expected, actual); + } + } + } + + @Test + public void testRoundTripAsync() { + testRoundTripAsync(buildTable()); + } + + @Test + public void testRoundTripAsyncEmpty() { + testRoundTripAsync(buildEmptyTable()); + } + + private void testRoundTripAsync(Table expected) { + try (Table t = expected) { + try (HostTable ht = HostTable.fromTableAsync(t, Cuda.DEFAULT_STREAM)) { + Table actual = ht.toTableAsync(Cuda.DEFAULT_STREAM); + AssertUtils.assertTablesAreEqual(expected, actual); + } + } + } + + private Table buildEmptyTable() { + DataType listStringsType = new ListType(true, new BasicType(true, DType.STRING)); + DataType mapType = new ListType(true, + new StructType(true, + new BasicType(false, DType.STRING), + new BasicType(false, DType.STRING))); + DataType structType = new StructType(true, + new BasicType(true, DType.INT8), + new BasicType(false, DType.FLOAT32)); + try (ColumnVector emptyInt = ColumnVector.fromInts(); + ColumnVector emptyDouble = 
ColumnVector.fromDoubles(); + ColumnVector emptyString = ColumnVector.fromStrings(); + ColumnVector emptyListString = ColumnVector.fromLists(listStringsType); + ColumnVector emptyMap = ColumnVector.fromLists(mapType); + ColumnVector emptyStruct = ColumnVector.fromStructs(structType)) { + return new Table(emptyInt, emptyInt, emptyDouble, emptyString, + emptyListString, emptyMap, emptyStruct); + } + } + + private Table buildTable() { + StructType mapStructType = new StructType(true, + new BasicType(false, DType.STRING), + new BasicType(false, DType.STRING)); + StructType structType = new StructType(true, + new BasicType(true, DType.INT32), + new BasicType(false, DType.FLOAT32)); + return new Table.TestBuilder() + .column( 100, 202, 3003, 40004, 5, -60, 1, null, 3, null, 5, null, 7, null, 9, null, 11, null, 13, null, 15) + .column( true, true, false, false, true, null, true, true, null, false, false, null, true, true, null, false, false, null, true, true, null) + .column( (byte)1, (byte)2, null, (byte)4, (byte)5, (byte)6, (byte)1, (byte)2, (byte)3, null, (byte)5, (byte)6, (byte)7, null, (byte)9, (byte)10, (byte)11, null, (byte)13, (byte)14, (byte)15) + .column((short)6, (short)5, (short)4, null, (short)2, (short)1, (short)1, (short)2, (short)3, null, (short)5, (short)6, (short)7, null, (short)9, (short)10, null, (short)12, (short)13, (short)14, null) + .column( 1L, null, 1001L, 50L, -2000L, null, 1L, 2L, 3L, 4L, null, 6L, 7L, 8L, 9L, null, 11L, 12L, 13L, 14L, null) + .column( 10.1f, 20f, Float.NaN, 3.1415f, -60f, null, 1f, 2f, 3f, 4f, 5f, null, 7f, 8f, 9f, 10f, 11f, null, 13f, 14f, 15f) + .column( 10.1f, 20f, Float.NaN, 3.1415f, -60f, -50f, 1f, 2f, 3f, 4f, 5f, 6f, 7f, 8f, 9f, 10f, 11f, 12f, 13f, 14f, 15f) + .column( 10.1, 20.0, 33.1, 3.1415, -60.5, null, 1., 2., 3., 4., 5., 6., null, 8., 9., 10., 11., 12., null, 14., 15.) 
+ .timestampDayColumn(99, 100, 101, 102, 103, 104, 1, 2, 3, 4, 5, 6, 7, null, 9, 10, 11, 12, 13, null, 15) + .timestampMillisecondsColumn(9L, 1006L, 101L, 5092L, null, 88L, 1L, 2L, 3L, 4L, 5L ,6L, 7L, 8L, null, 10L, 11L, 12L, 13L, 14L, 15L) + .timestampSecondsColumn(1L, null, 3L, 4L, 5L, 6L, 1L, 2L, 3L, 4L, 5L ,6L, 7L, 8L, 9L, null, 11L, 12L, 13L, 14L, 15L) + .decimal32Column(-3, 100, 202, 3003, 40004, 5, -60, 1, null, 3, null, 5, null, 7, null, 9, null, 11, null, 13, null, 15) + .decimal64Column(-8, 1L, null, 1001L, 50L, -2000L, null, 1L, 2L, 3L, 4L, null, 6L, 7L, 8L, 9L, null, 11L, 12L, 13L, 14L, null) + .column( "A", "B", "C", "D", null, "TESTING", "1", "2", "3", "4", "5", "6", "7", null, "9", "10", "11", "12", "13", null, "15") + .column( + strings("1", "2", "3"), strings("4"), strings("5"), strings("6, 7"), + strings("", "9", null), strings("11"), strings(""), strings(null, null), + strings("15", null), null, null, strings("18", "19", "20"), + null, strings("22"), strings("23", ""), null, + null, null, null, strings(), + strings("the end")) + .column(mapStructType, + structs(struct("1", "2")), structs(struct("3", "4")), + null, null, + structs(struct("key", "value"), struct("a", "b")), null, + null, structs(struct("3", "4"), struct("1", "2")), + structs(), structs(null, struct("foo", "bar")), + structs(null, null, null), null, + null, null, + null, null, + null, null, + null, null, + structs(struct("the", "end"))) + .column(structType, + struct(1, 1f), null, struct(2, 3f), null, struct(8, 7f), + struct(0, 0f), null, null, struct(-1, -1f), struct(-100, -100f), + struct(Integer.MAX_VALUE, Float.MAX_VALUE), null, null, null, null, + null, null, null, null, null, + struct(Integer.MIN_VALUE, Float.MIN_VALUE)) + .column( "A", "A", "C", "C", null, "TESTING", "1", "2", "3", "4", "5", "6", "7", null, "9", "10", "11", "12", "13", null, "15") + .build(); + } + + private static StructData struct(Object... 
values) { + return new StructData(values); + } + + private static StructData[] structs(StructData... values) { + return values; + } + + private static String[] strings(String... values) { + return values; + } +} diff --git a/src/test/java/com/nvidia/spark/rapids/jni/RmmSparkTest.java b/src/test/java/com/nvidia/spark/rapids/jni/RmmSparkTest.java index 987dd58534..f618f945b0 100644 --- a/src/test/java/com/nvidia/spark/rapids/jni/RmmSparkTest.java +++ b/src/test/java/com/nvidia/spark/rapids/jni/RmmSparkTest.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. + * Copyright (c) 2020-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -360,7 +360,7 @@ public void testInsertOOMsGpu() { assertThrows(GpuSplitAndRetryOOM.class, () -> Rmm.alloc(100).close()); assertEquals(0, RmmSpark.getAndResetNumRetryThrow(taskid)); assertEquals(1, RmmSpark.getAndResetNumSplitRetryThrow(taskid)); - assertEquals(ALIGNMENT * 2, RmmSpark.getAndResetGpuMaxMemoryAllocated(taskid)); + assertEquals(ALIGNMENT, RmmSpark.getAndResetGpuMaxMemoryAllocated(taskid)); // Verify that injecting OOM does not cause the block to actually happen assertEquals(RmmSparkThreadState.THREAD_RUNNING, RmmSpark.getStateOf(threadId)); @@ -818,6 +818,11 @@ public void testBasicMixedBlocking() throws ExecutionException, InterruptedExcep secondGpuAlloc.waitForAlloc(); secondGpuAlloc.freeAndWait(); } + // Do one more alloc after freeing on same task to show the max allocation metric is unimpacted + try (AllocOnAnotherThread secondGpuAlloc = new GpuAllocOnAnotherThread(taskThree, FIVE_MB)) { + secondGpuAlloc.waitForAlloc(); + secondGpuAlloc.freeAndWait(); + } } } } finally { diff --git a/src/test/java/com/nvidia/spark/rapids/jni/TimeZoneTest.java b/src/test/java/com/nvidia/spark/rapids/jni/TimeZoneTest.java index 7aaec496de..4eb97a280c 100644 --- 
a/src/test/java/com/nvidia/spark/rapids/jni/TimeZoneTest.java +++ b/src/test/java/com/nvidia/spark/rapids/jni/TimeZoneTest.java @@ -1,5 +1,5 @@ /* -* Copyright (c) 2023, NVIDIA CORPORATION. +* Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -44,11 +44,10 @@ static void cleanup() { @Test void databaseLoadedTest() { // Check for a few timezones - GpuTimeZoneDB instance = GpuTimeZoneDB.getInstance(); - List transitions = instance.getHostFixedTransitions("UTC+8"); + List transitions = GpuTimeZoneDB.getHostFixedTransitions("UTC+8"); assertNotNull(transitions); assertEquals(1, transitions.size()); - transitions = instance.getHostFixedTransitions("Asia/Shanghai"); + transitions = GpuTimeZoneDB.getHostFixedTransitions("Asia/Shanghai"); assertNotNull(transitions); ZoneId shanghai = ZoneId.of("Asia/Shanghai").normalized(); assertEquals(shanghai.getRules().getTransitions().size() + 1, transitions.size()); diff --git a/src/test/java/com/nvidia/spark/rapids/jni/kudo/KudoSerializerTest.java b/src/test/java/com/nvidia/spark/rapids/jni/kudo/KudoSerializerTest.java new file mode 100644 index 0000000000..210777accf --- /dev/null +++ b/src/test/java/com/nvidia/spark/rapids/jni/kudo/KudoSerializerTest.java @@ -0,0 +1,465 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.nvidia.spark.rapids.jni.kudo; + +import ai.rapids.cudf.*; +import com.nvidia.spark.rapids.jni.Arms; +import org.junit.jupiter.api.Test; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.DataInputStream; +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import static java.lang.Math.toIntExact; +import static java.util.Arrays.asList; +import static java.util.Collections.emptyList; +import static java.util.Collections.singletonList; +import static org.junit.jupiter.api.Assertions.*; + +public class KudoSerializerTest { + @Test + public void testSerializeAndDeserializeTable() { + try(Table expected = buildTestTable()) { + int rowCount = toIntExact(expected.getRowCount()); + for (int sliceSize = 1; sliceSize <= rowCount; sliceSize++) { + List tableSlices = new ArrayList<>(); + for (int startRow = 0; startRow < rowCount; startRow += sliceSize) { + tableSlices.add(new TableSlice(startRow, Math.min(sliceSize, rowCount - startRow), expected)); + } + + checkMergeTable(expected, tableSlices); + } + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + @Test + public void testRowCountOnly() throws Exception { + ByteArrayOutputStream out = new ByteArrayOutputStream(); + long bytesWritten = KudoSerializer.writeRowCountToStream(out, 5); + assertEquals(28, bytesWritten); + + ByteArrayInputStream in = new ByteArrayInputStream(out.toByteArray()); + KudoTableHeader header = KudoTableHeader.readFrom(new DataInputStream(in)).get(); + + assertEquals(0, header.getNumColumns()); + assertEquals(0, header.getOffset()); + assertEquals(5, header.getNumRows()); + assertEquals(0, header.getValidityBufferLen()); + assertEquals(0, header.getOffsetBufferLen()); + assertEquals(0, header.getTotalDataLen()); + } + + @Test + public void testWriteSimple() throws Exception { + KudoSerializer serializer = new 
KudoSerializer(buildSimpleTestSchema()); + + try (Table t = buildSimpleTable()) { + ByteArrayOutputStream out = new ByteArrayOutputStream(); + long bytesWritten = serializer.writeToStream(t, out, 0, 4); + assertEquals(189, bytesWritten); + + ByteArrayInputStream in = new ByteArrayInputStream(out.toByteArray()); + + KudoTableHeader header = KudoTableHeader.readFrom(new DataInputStream(in)).get(); + assertEquals(7, header.getNumColumns()); + assertEquals(0, header.getOffset()); + assertEquals(4, header.getNumRows()); + assertEquals(24, header.getValidityBufferLen()); + assertEquals(40, header.getOffsetBufferLen()); + assertEquals(160, header.getTotalDataLen()); + + // First integer column has no validity buffer + assertFalse(header.hasValidityBuffer(0)); + for (int i = 1; i < 7; i++) { + assertTrue(header.hasValidityBuffer(i)); + } + } + } + + @Test + public void testMergeTableWithDifferentValidity() { + Arms.withResource(new ArrayList(), tables -> { + Table table1 = new Table.TestBuilder() + .column(-83182L, 5822L, 3389L, 7384L, 7297L) + .column(-2.06, -2.14, 8.04, 1.16, -1.0) + .build(); + tables.add(table1); + + Table table2 = new Table.TestBuilder() + .column(-47L, null, -83L, -166L, -220L, 470L, 619L, 803L, 661L) + .column(-6.08, 1.6, 1.78, -8.01, 1.22, 1.43, 2.13, -1.65, null) + .build(); + tables.add(table2); + + Table table3 = new Table.TestBuilder() + .column(8722L, 8733L) + .column(2.51, 0.0) + .build(); + tables.add(table3); + + + Table expected = new Table.TestBuilder() + .column(7384L, 7297L, 803L, 661L, 8733L) + .column(1.16, -1.0, -1.65, null, 0.0) + .build(); + tables.add(expected); + + checkMergeTable(expected, asList( + new TableSlice(3, 2, table1), + new TableSlice(7, 2, table2), + new TableSlice(1, 1, table3))); + return null; + }); + } + + @Test + public void testMergeList() { + Arms.withResource(new ArrayList
(), tables -> { + Table table1 = new Table.TestBuilder() + .column(-881L, 482L, 660L, 896L, -129L, -108L, -428L, 0L, 617L, 782L) + .column(integers(665), integers(-267), integers(398), integers(-314), + integers(-370), integers(181), integers(665, 544), integers(222), integers(-587), + integers(544)) + .build(); + tables.add(table1); + + Table table2 = new Table.TestBuilder() + .column(-881L, 482L, 660L, 896L, 122L, 241L, 281L, 680L, 783L, null) + .column(integers(-370), integers(398), integers(-587, 398), integers(-314), + integers(307), integers(-397, -633), integers(-314, 307), integers(-633), integers(-397), + integers(181, -919, -175)) + .build(); + tables.add(table2); + + Table expected = new Table.TestBuilder() + .column(896L, -129L, -108L, -428L, 0L, 617L, 782L, 482L, 660L, 896L, 122L, 241L, + 281L, 680L, 783L, null) + .column(integers(-314), integers(-370), integers(181), integers(665, 544), integers(222), + integers(-587), integers(544), integers(398), integers(-587, 398), integers(-314), + integers(307), integers(-397, -633), integers(-314, 307), integers(-633), integers(-397), + integers(181, -919, -175)) + .build(); + tables.add(expected); + + checkMergeTable(expected, asList( + new TableSlice(3, 7, table1), + new TableSlice(1, 9, table2))); + + return null; + }); + } + + + @Test + public void testSerializeValidity() { + Arms.withResource(new ArrayList
(), tables -> { + List col1 = new ArrayList<>(512); + col1.add(null); + col1.add(null); + col1.addAll(IntStream.range(2, 512).boxed().collect(Collectors.toList())); + + Table table1 = new Table.TestBuilder() + .column(col1.toArray(new Integer[0])) + .build(); + tables.add(table1); + + Table table2 = new Table.TestBuilder() + .column(509, 510, 511) + .build(); + tables.add(table2); + + checkMergeTable(table2, asList(new TableSlice(509, 3, table1))); + return null; + }); + } + + private static Schema buildSimpleTestSchema() { + Schema.Builder builder = Schema.builder(); + + builder.addColumn(DType.INT32, "a"); + builder.addColumn(DType.STRING, "b"); + Schema.Builder listBuilder = builder.addColumn(DType.LIST, "c"); + listBuilder.addColumn(DType.INT32, "c1"); + + Schema.Builder structBuilder = builder.addColumn(DType.STRUCT, "d"); + structBuilder.addColumn(DType.INT8, "d1"); + structBuilder.addColumn(DType.INT64, "d2"); + + return builder.build(); + } + + private static Table buildSimpleTable() { + HostColumnVector.StructType st = new HostColumnVector.StructType( + true, + new HostColumnVector.BasicType(true, DType.INT8), + new HostColumnVector.BasicType(true, DType.INT64) + ); + return new Table.TestBuilder() + .column(1, 2, 3, 4) + .column("1", "12", null, "45") + .column(new Integer[]{1, null, 3}, new Integer[]{4, 5, 6}, null, new Integer[]{7, 8, 9}) + .column(st, new HostColumnVector.StructData((byte) 1, 11L), + new HostColumnVector.StructData((byte) 2, null), null, + new HostColumnVector.StructData((byte) 3, 33L)) + .build(); + } + + private static Table buildTestTable() { + HostColumnVector.ListType listMapType = new HostColumnVector.ListType(true, + new HostColumnVector.ListType(true, + new HostColumnVector.StructType(true, + new HostColumnVector.BasicType(false, DType.STRING), + new HostColumnVector.BasicType(true, DType.STRING)))); + HostColumnVector.ListType mapStructType = new HostColumnVector.ListType(true, + new HostColumnVector.StructType(true, + new 
HostColumnVector.BasicType(false, DType.STRING), + new HostColumnVector.BasicType(false, DType.STRING))); + HostColumnVector.StructType structType = new HostColumnVector.StructType(true, + new HostColumnVector.BasicType(true, DType.INT32), + new HostColumnVector.BasicType(false, DType.FLOAT32)); + HostColumnVector.ListType listDateType = new HostColumnVector.ListType(true, + new HostColumnVector.StructType(false, + new HostColumnVector.BasicType(false, DType.INT32), + new HostColumnVector.BasicType(true, DType.INT32))); + + return new Table.TestBuilder() + .column(100, 202, 3003, 40004, 5, -60, 1, null, 3, null, 5, null, 7, null, 9, null, 11, null, 13, null, 15) + .column(true, true, false, false, true, null, true, true, null, false, false, null, true, + true, null, false, false, null, true, true, null) + .column((byte)1, (byte)2, null, (byte)4, (byte)5,(byte)6,(byte)1,(byte)2,(byte)3, null,(byte)5, (byte)6, + (byte) 7, null,(byte) 9,(byte) 10,(byte) 11, null,(byte) 13,(byte) 14,(byte) 15) + .column((short)6, (short)5, (short)4, null, (short)2, (short)1, + (short)1, (short)2, (short)3, null, (short)5, (short)6, (short)7, null, (short)9, + (short)10, null, (short)12, (short)13, (short)14, null) + .column(1L, null, 1001L, 50L, -2000L, null, 1L, 2L, 3L, 4L, null, 6L, + 7L, 8L, 9L, null, 11L, 12L, 13L, 14L, null) + .column(10.1f, 20f, -1f, 3.1415f, -60f, null, 1f, 2f, 3f, 4f, 5f, null, 7f, 8f, 9f, 10f, 11f, null, 13f, 14f, 15f) + .column(10.1f, 20f, -2f, 3.1415f, -60f, -50f, 1f, 2f, 3f, 4f, 5f, 6f, 7f, 8f, 9f, 10f, 11f, 12f, 13f, 14f, 15f) + .column(10.1, 20.0, 33.1, 3.1415, -60.5, null, 1d, 2.0, 3.0, 4.0, 5.0, + 6.0, null, 8.0, 9.0, 10.0, 11.0, 12.0, null, 14.0, 15.0) + .column((Float)null, null, null, null, null, null, null, null, null, null, + null, null, null, null, null, null, null, null, null, null, null) + .timestampDayColumn(99, 100, 101, 102, 103, 104, 1, 2, 3, 4, 5, 6, 7, null, 9, 10, 11, 12, 13, null, 15) + .timestampMillisecondsColumn(9L, 1006L, 101L, 
5092L, null, 88L, 1L, 2L, 3L, 4L, 5L, 6L, 7L, 8L, null, 10L, 11L, 12L, 13L, 14L, 15L) + .timestampSecondsColumn(1L, null, 3L, 4L, 5L, 6L, 1L, 2L, 3L, 4L, 5L, 6L, 7L, 8L, 9L, null, 11L, 12L, 13L, 14L, 15L) + .decimal32Column(-3, 100, 202, 3003, 40004, 5, -60, 1, null, 3, + null, 5, null, 7, null, 9, null, 11, null, 13, null, 15) + .decimal64Column(-8, 1L, null, 1001L, 50L, -2000L, null, 1L, 2L, 3L, + 4L, null, 6L, 7L, 8L, 9L, null, 11L, 12L, 13L, 14L, null) + .column("A", "B", "C", "D", null, "TESTING", "1", "2", "3", "4", + "5", "6", "7", null, "9", "10", "11", "12", "13", null, "15") + .column("A", "A", "C", "C", "E", "TESTING", "1", "2", "3", "4", "5", + "6", "7", "", "9", "10", "11", "12", "13", "", "15") + .column("", "", "", "", "", "", "", "", "", "", "", "", "", "", "", + "", "", "", "", "", "") + .column("", null, "", "", null, "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "") + .column((String)null, null, null, null, null, null, null, null, null, null, + null, null, null, null, null, null, null, null, null, null, null) + .column(mapStructType, structs(struct("1", "2")), structs(struct("3", "4")), null, null, + structs(struct("key", "value"), struct("a", "b")), null, null, + structs(struct("3", "4"), struct("1", "2")), structs(), + structs(null, struct("foo", "bar")), + structs(null, null, null), null, null, null, null, null, null, null, null, null, + structs(struct("the", "end"))) + .column(structType, struct(1, 1f), null, struct(2, 3f), + null, struct(8, 7f), struct(0, 0f), null, + null, struct(-1, -1f), struct(-100, -100f), + struct(Integer.MAX_VALUE, Float.MAX_VALUE), null, + null, null, + null, null, + null, null, + null, null, + struct(Integer.MIN_VALUE, Float.MIN_VALUE)) + .column(integers(1, 2), null, integers(3, 4, null, 5, null), null, null, integers(6, 7, 8), + integers(null, null, null), integers(1, 2, 3), integers(4, 5, 6), integers(7, 8, 9), + integers(10, 11, 12), integers((Integer)null), integers(14, null), + integers(14, 15, 
null, 16, 17, 18), integers(19, 20, 21), integers(22, 23, 24), + integers(25, 26, 27), integers(28, 29, 30), integers(31, 32, 33), null, + integers(37, 38, 39)) + .column(integers(), integers(), integers(), integers(), integers(), integers(), integers(), + integers(), integers(), integers(), integers(), integers(), integers(), integers(), + integers(), integers(), integers(), integers(), integers(), integers(), integers()) + .column(integers(null, null), integers(null, null, null, null), integers(), + integers(null, null, null), integers(), integers(null, null, null, null, null), + integers((Integer)null), integers(null, null, null), integers(null, null), + integers(null, null, null, null), integers(null, null, null, null, null), integers(), + integers(null, null, null, null), integers(null, null, null), integers(null, null), + integers(null, null, null), integers(null, null), integers((Integer)null), + integers((Integer)null), integers(null, null), + integers(null, null, null, null, null)) + .column((Integer)null, null, null, null, null, null, null, null, null, null, + null, null, null, null, null, null, null, null, null, null, null) + .column(strings("1", "2", "3"), strings("4"), strings("5"), strings("6, 7"), + strings("", "9", null), strings("11"), strings(""), strings(null, null), + strings("15", null), null, null, strings("18", "19", "20"), null, strings("22"), + strings("23", ""), null, null, null, null, strings(), strings("the end")) + .column(strings(), strings(), strings(), strings(), strings(), strings(), strings(), + strings(), strings(), strings(), strings(), strings(), strings(), strings(), strings(), + strings(), strings(), strings(), strings(), strings(), strings()) + .column(strings(null, null), strings(null, null, null, null), strings(), + strings(null, null, null), strings(), strings(null, null, null, null, null), + strings((String)null), strings(null, null, null), strings(null, null), + strings(null, null, null, null), strings(null, null, null, 
null, null), strings(), + strings(null, null, null, null), strings(null, null, null), strings(null, null), + strings(null, null, null), strings(null, null), strings((String)null), + strings((String)null), strings(null, null), + strings(null, null, null, null, null)) + .column((String)null, null, null, null, null, null, null, null, null, null, + null, null, null, null, null, null, null, null, null, null, null) + .column(listMapType, asList(asList(struct("k1", "v1"), struct("k2", "v2")), + singletonList(struct("k3", "v3"))), + asList(asList(struct("k4", "v4"), struct("k5", "v5"), + struct("k6", "v6")), singletonList(struct("k7", "v7"))), + null, null, null, asList(asList(struct("k8", "v8"), struct("k9", "v9")), + asList(struct("k10", "v10"), struct("k11", "v11"), struct("k12", "v12"), + struct("k13", "v13"))), + singletonList(asList(struct("k14", "v14"), struct("k15", "v15"))), null, null, null, null, + asList(asList(struct("k16", "v16"), struct("k17", "v17")), + singletonList(struct("k18", "v18"))), + asList(asList(struct("k19", "v19"), struct("k20", "v20")), + singletonList(struct("k21", "v21"))), + asList(singletonList(struct("k22", "v22")), singletonList(struct("k23", "v23"))), + asList(null, null, null), + asList(singletonList(struct("k22", null)), singletonList(struct("k23", null))), + null, null, null, null, null) + .column(listDateType, asList(struct(-210, 293), struct(-719, 205), struct(-509, 183), + struct(174, 122), struct(647, 683)), + asList(struct(311, 992), struct(-169, 482), struct(166, 525)), + asList(struct(156, 197), struct(926, 134), struct(747, 312), struct(293, 801)), + asList(struct(647, null), struct(293, 387)), emptyList(), + null, emptyList(), null, + asList(struct(-210, 293), struct(-719, 205), struct(-509, 183), struct(174, 122), + struct(647, 683)), + asList(struct(311, 992), struct(-169, 482), struct(166, 525)), + asList(struct(156, 197), struct(926, 134), struct(747, 312), struct(293, 801)), + asList(struct(647, null), struct(293, 
387)), emptyList(), null, + emptyList(), null, + singletonList(struct(778, 765)), asList(struct(7, 87), struct(8, 96)), + asList(struct(9, 56), struct(10, 532), struct(11, 456)), null, emptyList()) + .build(); + } + + private static void checkMergeTable(Table expected, List tableSlices) { + try { + KudoSerializer serializer = new KudoSerializer(schemaOf(expected)); + + ByteArrayOutputStream bout = new ByteArrayOutputStream(); + for (TableSlice slice : tableSlices) { + serializer.writeToStream(slice.getBaseTable(), bout, slice.getStartRow(), slice.getNumRows()); + } + bout.flush(); + + ByteArrayInputStream bin = new ByteArrayInputStream(bout.toByteArray()); + Arms.withResource(new ArrayList(tableSlices.size()), kudoTables -> { + try { + for (int i = 0; i < tableSlices.size(); i++) { + kudoTables.add(KudoTable.from(bin).get()); + } + + long rows = kudoTables.stream().mapToLong(t -> t.getHeader().getNumRows()).sum(); + assertEquals(expected.getRowCount(), toIntExact(rows)); + + try (Table merged = serializer.mergeToTable(kudoTables).getLeft()) { + assertEquals(expected.getRowCount(), merged.getRowCount()); + AssertUtils.assertTablesAreEqual(expected, merged); + } + } catch (Exception e) { + throw new RuntimeException(e); + } + + return null; + }); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + private static Integer[] integers(Integer... values) { + return values; + } + + private static HostColumnVector.StructData struct(Object... values) { + return new HostColumnVector.StructData(values); + } + + private static List structs(HostColumnVector.StructData... values) { + return asList(values); + } + + private static String[] strings(String... 
values) { + return values; + } + + private static Schema schemaOf(Table t) { + Schema.Builder builder = Schema.builder(); + + for (int i = 0; i < t.getNumberOfColumns(); i++) { + ColumnVector cv = t.getColumn(i); + addToSchema(cv, "col_" + i + "_", builder); + } + + return builder.build(); + } + + private static void addToSchema(ColumnView cv, String namePrefix, Schema.Builder builder) { + toSchemaInner(cv, 0, namePrefix, builder); + } + + private static int toSchemaInner(ColumnView cv, int idx, String namePrefix, + Schema.Builder builder) { + String name = namePrefix + idx; + + Schema.Builder thisBuilder = builder.addColumn(cv.getType(), name); + int lastIdx = idx; + for (int i = 0; i < cv.getNumChildren(); i++) { + lastIdx = toSchemaInner(cv.getChildColumnView(i), lastIdx + 1, namePrefix, + thisBuilder); + } + + return lastIdx; + } + + private static class TableSlice { + private final int startRow; + private final int numRows; + private final Table baseTable; + + private TableSlice(int startRow, int numRows, Table baseTable) { + this.startRow = startRow; + this.numRows = numRows; + this.baseTable = baseTable; + } + + public int getStartRow() { + return startRow; + } + + public int getNumRows() { + return numRows; + } + + public Table getBaseTable() { + return baseTable; + } + } +} diff --git a/thirdparty/cudf b/thirdparty/cudf index 7b0adfa253..ff41ecf473 160000 --- a/thirdparty/cudf +++ b/thirdparty/cudf @@ -1 +1 @@ -Subproject commit 7b0adfa2533e4792464230ee67916a04ce06caf6 +Subproject commit ff41ecf473b7e7be01885c66cd90571ccc48e41d diff --git a/thirdparty/cudf-pins/rapids-cmake.sha b/thirdparty/cudf-pins/rapids-cmake.sha index 37820d8ad4..7d09034936 100644 --- a/thirdparty/cudf-pins/rapids-cmake.sha +++ b/thirdparty/cudf-pins/rapids-cmake.sha @@ -1 +1 @@ -312909127cf0fe96e178f0ffa754908f58d489a3 +cb1a30c8f12923fe1f56b1948cb8b26d5676f6b9 diff --git a/thirdparty/cudf-pins/versions.json b/thirdparty/cudf-pins/versions.json index ed40c777a4..dafc978739 100644 --- 
a/thirdparty/cudf-pins/versions.json +++ b/thirdparty/cudf-pins/versions.json @@ -10,22 +10,22 @@ "patches" : [ { - "file" : "${current_json_dir}/cccl_symbol_visibility.diff", + "file" : "${current_json_dir}/../cudf/cpp/cmake/thirdparty/patches/cccl_symbol_visibility.diff", "fixed_in" : "2.6", "issue" : "Correct symbol visibility issues in libcudacxx [https://github.com/NVIDIA/cccl/pull/1832/]" }, { - "file" : "${current_json_dir}/thrust_disable_64bit_dispatching.diff", + "file" : "${current_json_dir}/../cudf/cpp/cmake/thirdparty/patches/thrust_disable_64bit_dispatching.diff", "fixed_in" : "", "issue" : "Remove 64bit dispatching as not needed by libcudf and results in compiling twice as many kernels [https://github.com/rapidsai/cudf/pull/11437]" }, { - "file" : "${current_json_dir}/thrust_faster_sort_compile_times.diff", + "file" : "${current_json_dir}/../cudf/cpp/cmake/thirdparty/patches/thrust_faster_sort_compile_times.diff", "fixed_in" : "", "issue" : "Improve Thrust sort compile times by not unrolling loops for inlined comparators [https://github.com/rapidsai/cudf/pull/10577]" }, { - "file" : "${current_json_dir}/thrust_faster_scan_compile_times.diff", + "file" : "${current_json_dir}/../cudf/cpp/cmake/thirdparty/patches/thrust_faster_scan_compile_times.diff", "fixed_in" : "", "issue" : "Improve Thrust scan compile times by reducing the number of kernels generated [https://github.com/rapidsai/cudf/pull/8183]" } @@ -40,14 +40,6 @@ "git_url" : "https://github.com/google/googletest.git", "version" : "1.13.0" }, - "KvikIO" : - { - "always_download" : true, - "git_shallow" : false, - "git_tag" : "1b85263eba89c0f077fbb3da90a770b84161d20f", - "git_url" : "https://github.com/rapidsai/kvikio.git", - "version" : "24.10" - }, "bs_thread_pool" : { "always_download" : true, @@ -60,7 +52,7 @@ { "always_download" : true, "git_shallow" : false, - "git_tag" : "d3477661d771e0d6fd22259bf6dd6f8c64a7401c", + "git_tag" : "dc0f9fc20c2a544e53099e640a681b347532391a", "git_url" : 
"https://github.com/NVIDIA/cuCollections.git", "version" : "0.0.1" }, @@ -103,12 +95,28 @@ "git_url" : "https://github.com/rapidsai/jitify.git", "version" : "2.0.0" }, + "kvikio" : + { + "always_download" : true, + "git_shallow" : false, + "git_tag" : "e2ff91954a2ebe3ee65e0a1568e9d77cb195164d", + "git_url" : "https://github.com/rapidsai/kvikio.git", + "version" : "24.12" + }, "nanoarrow" : { "always_download" : true, "git_shallow" : false, "git_tag" : "1e2664a70ec14907409cadcceb14d79b9670bcdb", "git_url" : "https://github.com/apache/arrow-nanoarrow.git", + "patches" : + [ + { + "file" : "${current_json_dir}/../cudf/cpp/cmake/thirdparty/patches/nanoarrow_clang_tidy_compliance.diff", + "fixed_in" : "", + "issue" : "https://github.com/apache/arrow-nanoarrow/issues/537" + } + ], "version" : "0.6.0.dev" }, "nvcomp" : @@ -119,15 +127,15 @@ "git_url" : "https://github.com/NVIDIA/nvcomp.git", "proprietary_binary" : { - "aarch64-linux" : "https://developer.download.nvidia.com/compute/nvcomp/${version}/local_installers/nvcomp_${version}_SBSA_${cuda-toolkit-version-mapping}.tgz", - "x86_64-linux" : "https://developer.download.nvidia.com/compute/nvcomp/${version}/local_installers/nvcomp_${version}_x86_64_${cuda-toolkit-version-mapping}.tgz" + "aarch64-linux" : "https://developer.download.nvidia.com/compute/nvcomp/redist/nvcomp/linux-sbsa/nvcomp-linux-sbsa-${version}_cuda${cuda-toolkit-version-mapping}-archive.tar.xz", + "x86_64-linux" : "https://developer.download.nvidia.com/compute/nvcomp/redist/nvcomp/linux-x86_64/nvcomp-linux-x86_64-${version}_cuda${cuda-toolkit-version-mapping}-archive.tar.xz" }, "proprietary_binary_cuda_version_mapping" : { - "11" : "11.x", - "12" : "12.x" + "11" : "11", + "12" : "12" }, - "version" : "3.0.6" + "version" : "4.1.0.6" }, "nvtx3" : { @@ -141,9 +149,9 @@ { "always_download" : true, "git_shallow" : false, - "git_tag" : "afe0a3336397b17a96bb703e82f3b6365ee7c41e", + "git_tag" : "b391ce867a62bbfc3283779d70231be5b00e2e84", "git_url" : 
"https://github.com/rapidsai/rmm.git", - "version" : "24.10" + "version" : "24.12" }, "spdlog" : {