diff --git a/CMakeLists.txt b/CMakeLists.txt index cd8ccaa..e9ca779 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -21,7 +21,7 @@ option(WITH_TORCH "compile warp-ctc with Torch." ${Torch_FOUND}) option(WITH_OMP "compile warp-ctc with OpenMP." ON) option(BUILD_TESTS "build warp-ctc unit tests." ON) option(BUILD_SHARED "build warp-ctc shared library." ON) -option(WITH_ROCM "Compile PaddlePaddle with ROCM platform" OFF) +option(WITH_ROCM "Compile PaddlePaddle with ROCM platform" OFF) if(WITH_ROCM) add_definitions(-DWARPCTC_WITH_HIP) @@ -65,32 +65,38 @@ endif() # need to be at least 30 or __shfl_down in reduce wont compile IF (CUDA_VERSION VERSION_LESS "11.0") + # CUDA_VERSION < 11.0 set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -gencode arch=compute_30,code=sm_30") + set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -gencode arch=compute_35,code=sm_35") + set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -gencode arch=compute_50,code=sm_50") ENDIF() -set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -gencode arch=compute_35,code=sm_35") -set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -gencode arch=compute_50,code=sm_50") set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -gencode arch=compute_52,code=sm_52") IF (CUDA_VERSION VERSION_GREATER "7.6") + # CUDA_VERSION > 7.6 set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -gencode arch=compute_60,code=sm_60") set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -gencode arch=compute_61,code=sm_61") set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -gencode arch=compute_62,code=sm_62") ENDIF() IF ((CUDA_VERSION VERSION_GREATER "9.0") OR (CUDA_VERSION VERSION_EQUAL "9.0")) + # CUDA_VERSION >= 9.0 set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -gencode arch=compute_70,code=sm_70") ENDIF() IF ((CUDA_VERSION VERSION_GREATER "10.0") OR (CUDA_VERSION VERSION_EQUAL "10.0")) + # CUDA_VERSION >= 10.0 set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -gencode arch=compute_75,code=sm_75") ENDIF() IF ((CUDA_VERSION VERSION_GREATER "11.0") OR (CUDA_VERSION VERSION_EQUAL "11.0")) + # CUDA_VERSION >= 11.0 set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -gencode arch=compute_80,code=sm_80") ENDIF() IF ((CUDA_VERSION VERSION_GREATER "11.2") OR (CUDA_VERSION VERSION_EQUAL "11.2")) + # CUDA_VERSION >= 11.2 set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -gencode arch=compute_86,code=sm_86") ENDIF() diff --git a/include/ctc.h b/include/ctc.h index 11d9de6..18455e8 100644 --- a/include/ctc.h +++ b/include/ctc.h @@ -52,6 +52,15 @@ typedef enum { CTC_GPU = 1 } ctcComputeLocation; +struct ctcParamConfigs { + /// input params of compute_ctc_loss + ctcComputeLocation flat_labels_loc{CTC_CPU}; + ctcComputeLocation label_lengths_loc{CTC_CPU}; + ctcComputeLocation input_lengths_loc{CTC_CPU}; + /// output param of compute_ctc_loss + ctcComputeLocation costs_loc{CTC_CPU}; +}; + /** Structure used for options to the CTC compution. Applications * should zero out the array using memset and sizeof(struct * ctcOptions) in C or default initialization (e.g. 'ctcOptions @@ -70,6 +79,8 @@ struct ctcOptions { /// the label value/index that the CTC calculation should use as the blank label int blank_label; + /// where to hold the input and output params + ctcParamConfigs params; }; /** Compute the connectionist temporal classification loss between diff --git a/include/detail/gpu_ctc.h b/include/detail/gpu_ctc.h index cafe6ae..c665e06 100644 --- a/include/detail/gpu_ctc.h +++ b/include/detail/gpu_ctc.h @@ -3,6 +3,7 @@ #include "ctc_helper.h" #include "gpu_ctc_kernels.h" #include "reduce.h" +#include "gpu_helper.h" template class GpuCTC { @@ -11,10 +12,19 @@ class GpuCTC { int minibatch, void *workspace, GPUstream stream, - int blank_label) : + int blank_label, + bool has_cpu_flat_labels, + bool has_cpu_label_lengths, + bool has_cpu_input_lengths, + bool copy_costs_to_cpu) : out_dim_(alphabet_size), minibatch_(minibatch), gpu_workspace_(workspace), stream_(stream), - blank_label_(blank_label) {}; + blank_label_(blank_label) { + param_configs_.has_cpu_flat_labels = has_cpu_flat_labels; + param_configs_.has_cpu_label_lengths = has_cpu_label_lengths; + param_configs_.has_cpu_input_lengths = has_cpu_input_lengths; + param_configs_.copy_costs_to_cpu = copy_costs_to_cpu; + } // Noncopyable GpuCTC(const GpuCTC&) = delete; @@ -36,7 +46,6 @@ class GpuCTC { const int* const input_lengths); private: - template ctcStatus_t launch_alpha_beta_kernels(const ProbT* const probs, ProbT *grads, @@ -53,12 +62,14 @@ class GpuCTC { ctcStatus_t setup_gpu_metadata(const int* const flat_labels, const int* const label_lengths, - const int* const input_lengths); + const int* const input_lengths, + ProbT* costs); ctcStatus_t create_metadata_and_choose_config(const int* const label_lengths, const int* const flat_labels, const int* const input_lengths, + ProbT* costs, size_t& best_config); ctcStatus_t @@ -74,6 +85,12 @@ class GpuCTC { bool compute_alpha, bool compute_betas_and_grad); + struct ParamConfigs { + bool has_cpu_flat_labels{true}; + bool has_cpu_label_lengths{true}; + bool has_cpu_input_lengths{true}; + bool copy_costs_to_cpu{true}; + }; int out_dim_; // Number of characters plus blank int minibatch_; @@ -86,34 +103,40 @@ class GpuCTC { GPUstream stream_; int blank_label_; - - void *gpu_workspace_; // Buffer for all temporary GPU memory - int *utt_length_; // T - int *label_sizes_; // L - int *repeats_; // repeats_ - int *label_offsets_; - int *labels_without_blanks_; - int *labels_with_blanks_; - ProbT *alphas_; - ProbT *nll_forward_; - ProbT *nll_backward_; - ProbT *denoms_; // Temporary storage for denoms for softmax - ProbT *probs_; // Temporary storage for probabilities (softmax output) + ParamConfigs param_configs_; + + void *gpu_workspace_{nullptr}; // Buffer for all temporary GPU memory + int *utt_length_{nullptr}; // T + int *label_sizes_{nullptr}; // L + int *repeats_{nullptr}; // repeats_ + int *label_offsets_{nullptr}; + int *labels_without_blanks_{nullptr}; + int *labels_with_blanks_{nullptr}; + ProbT *alphas_{nullptr}; + ProbT *nll_forward_{nullptr}; + ProbT *nll_backward_{nullptr}; + ProbT *denoms_{nullptr}; // Temporary storage for denoms for softmax + ProbT *probs_{nullptr}; // Temporary storage for probabilities (softmax output) }; template ctcStatus_t GpuCTC::setup_gpu_metadata(const int* const flat_labels, const int* const label_lengths, - const int* const input_lengths) + const int* const input_lengths, + ProbT* costs) { size_t gpu_bytes_used = 0; - nll_forward_ = - reinterpret_cast(static_cast(gpu_workspace_) + - gpu_bytes_used); - gpu_bytes_used += minibatch_ * sizeof(ProbT); - + if (param_configs_.copy_costs_to_cpu) { + nll_forward_ = + reinterpret_cast(static_cast(gpu_workspace_) + + gpu_bytes_used); + gpu_bytes_used += minibatch_ * sizeof(ProbT); + } else { + // input costs should be a device ptr. + nll_forward_ = costs; + } nll_backward_ = reinterpret_cast(static_cast(gpu_workspace_) + @@ -345,10 +368,11 @@ ctcStatus_t GpuCTC::create_metadata_and_choose_config(const int* const flat_labels, const int* const label_lengths, const int* const input_lengths, + ProbT* costs, size_t& best_config) { // Setup the metadata for GPU - ctcStatus_t status = setup_gpu_metadata(flat_labels, label_lengths, input_lengths); + ctcStatus_t status = setup_gpu_metadata(flat_labels, label_lengths, input_lengths, costs); if (status != CTC_STATUS_SUCCESS) return status; @@ -472,6 +496,7 @@ GpuCTC::compute_cost_and_score(const ProbT* const activations, ctcStatus_t status = create_metadata_and_choose_config(flat_labels, label_lengths, input_lengths, + costs, best_config); if (status != CTC_STATUS_SUCCESS) return status; @@ -483,26 +508,13 @@ GpuCTC::compute_cost_and_score(const ProbT* const activations, launch_gpu_kernels(probs_, grads, best_config, compute_alpha, compute_betas_and_grad); - gpuError_t cuda_status_mem, cuda_status_sync; - -#ifdef __HIPCC__ - cuda_status_mem = hipMemcpyAsync(costs, nll_forward_, - sizeof(ProbT) * minibatch_, - hipMemcpyDeviceToHost, stream_); -#else - cuda_status_mem = cudaMemcpyAsync(costs, nll_forward_, - sizeof(ProbT) * minibatch_, - cudaMemcpyDeviceToHost, stream_); -#endif - -#ifdef __HIPCC__ - cuda_status_sync = hipStreamSynchronize(stream_); -#else - cuda_status_sync = cudaStreamSynchronize(stream_); -#endif - - if (cuda_status_mem != gpuSuccess || cuda_status_sync != gpuSuccess) - return CTC_STATUS_MEMOPS_FAILED; + if (param_configs_.copy_costs_to_cpu) { + // Copy the costs back to CPU. + gpuError_t cuda_status_mem = warpctc::memcpy_d2h_sync( + costs, nll_forward_, sizeof(ProbT) * minibatch_, stream_); + if (cuda_status_mem != gpuSuccess) + return CTC_STATUS_MEMOPS_FAILED; + } return CTC_STATUS_SUCCESS; } diff --git a/include/detail/gpu_helper.h b/include/detail/gpu_helper.h new file mode 100644 index 0000000..d323f9c --- /dev/null +++ b/include/detail/gpu_helper.h @@ -0,0 +1,35 @@ +#pragma once + +#include "type_defs.h" + +namespace warpctc { + +static gpuError_t memcpy_d2h_async(void *dst, const void *src, size_t bytes, GPUstream stream) { + gpuError_t status; +#ifdef __HIPCC__ + status = hipMemcpyAsync(dst, src, bytes, hipMemcpyDeviceToHost, stream); +#else + status = cudaMemcpyAsync(dst, src, bytes, cudaMemcpyDeviceToHost, stream); +#endif + return status; +} + +static gpuError_t synchronize(GPUstream stream) { + gpuError_t status; +#ifdef __HIPCC__ + status = hipStreamSynchronize(stream); +#else + status = cudaStreamSynchronize(stream); +#endif + return status; +} + +static gpuError_t memcpy_d2h_sync(void *dst, const void *src, size_t bytes, GPUstream stream) { + gpuError_t status = memcpy_d2h_async(dst, src, bytes, stream); + if (status != gpuSuccess) { + return status; + } + return synchronize(stream); +} + +} // namespace warpctc diff --git a/src/ctc_entrypoint.cpp b/src/ctc_entrypoint.cpp index f56f6a8..01aefb2 100644 --- a/src/ctc_entrypoint.cpp +++ b/src/ctc_entrypoint.cpp @@ -72,7 +72,10 @@ ctcStatus_t compute_ctc_loss(const float* const activations, } else if (options.loc == CTC_GPU) { #if (defined(__HIPCC__) || defined(__CUDACC__)) GpuCTC ctc(alphabet_size, minibatch, workspace, options.stream, - options.blank_label); + options.blank_label, options.params.flat_labels_loc == CTC_CPU, + options.params.label_lengths_loc == CTC_CPU, + options.params.input_lengths_loc == CTC_CPU, + options.params.costs_loc == CTC_CPU); if (gradients != NULL) return ctc.cost_and_grad(activations, gradients, costs, @@ -112,7 +115,7 @@ ctcStatus_t compute_ctc_loss_double(const double* const activations, if (options.loc == CTC_CPU) { CpuCTC ctc(alphabet_size, minibatch, workspace, options.num_threads, - options.blank_label); + options.blank_label); if (gradients != NULL) return ctc.cost_and_grad(activations, gradients, @@ -125,7 +128,10 @@ ctcStatus_t compute_ctc_loss_double(const double* const activations, } else if (options.loc == CTC_GPU) { #if (defined(__HIPCC__) || defined(__CUDACC__)) GpuCTC ctc(alphabet_size, minibatch, workspace, options.stream, - options.blank_label); + options.blank_label, options.params.flat_labels_loc == CTC_CPU, + options.params.label_lengths_loc == CTC_CPU, + options.params.input_lengths_loc == CTC_CPU, + options.params.costs_loc == CTC_CPU); if (gradients != NULL) return ctc.cost_and_grad(activations, gradients, costs, @@ -167,8 +173,13 @@ ctcStatus_t get_workspace_size(const int* const label_lengths, if (options.loc == CTC_GPU) { // GPU storage - //nll_forward, nll_backward - *size_bytes += 2 * sizeof(float) * minibatch; + if (options.params.costs_loc == CTC_CPU) { + // nll_forward, nll_backward + *size_bytes += 2 * sizeof(float) * minibatch; + } else { + // nll_backward + *size_bytes += 1 * sizeof(float) * minibatch; + } //repeats *size_bytes += sizeof(int) * minibatch; @@ -249,8 +260,13 @@ ctcStatus_t get_workspace_size_double(const int* const label_lengths, if (options.loc == CTC_GPU) { // GPU storage - //nll_forward, nll_backward - *size_bytes += 2 * sizeof(double) * minibatch; + if (options.params.costs_loc == CTC_CPU) { + // nll_forward, nll_backward + *size_bytes += 2 * sizeof(double) * minibatch; + } else { + // nll_backward + *size_bytes += 1 * sizeof(double) * minibatch; + } //repeats *size_bytes += sizeof(int) * minibatch;