Add param configs to allow holding more input and output params on GPU.
Xreki committed Jan 26, 2022
1 parent 37ece0e commit 5fc0bc5
Showing 5 changed files with 134 additions and 54 deletions.
CMakeLists.txt (12 changes: 9 additions & 3 deletions)
@@ -21,7 +21,7 @@ option(WITH_TORCH "compile warp-ctc with Torch." ${Torch_FOUND})
option(WITH_OMP "compile warp-ctc with OpenMP." ON)
option(BUILD_TESTS "build warp-ctc unit tests." ON)
option(BUILD_SHARED "build warp-ctc shared library." ON)
option(WITH_ROCM "Compile PaddlePaddle with ROCM platform" OFF)

if(WITH_ROCM)
add_definitions(-DWARPCTC_WITH_HIP)
@@ -65,32 +65,38 @@ endif()

# need to be at least 30 or __shfl_down in reduce wont compile
IF (CUDA_VERSION VERSION_LESS "11.0")
+# CUDA_VERSION < 11.0
set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -gencode arch=compute_30,code=sm_30")
+set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -gencode arch=compute_35,code=sm_35")
+set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -gencode arch=compute_50,code=sm_50")
ENDIF()
-set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -gencode arch=compute_35,code=sm_35")
-set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -gencode arch=compute_50,code=sm_50")

set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -gencode arch=compute_52,code=sm_52")

IF (CUDA_VERSION VERSION_GREATER "7.6")
+# CUDA_VERSION > 7.6
set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -gencode arch=compute_60,code=sm_60")
set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -gencode arch=compute_61,code=sm_61")
set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -gencode arch=compute_62,code=sm_62")
ENDIF()

IF ((CUDA_VERSION VERSION_GREATER "9.0") OR (CUDA_VERSION VERSION_EQUAL "9.0"))
+# CUDA_VERSION >= 9.0
set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -gencode arch=compute_70,code=sm_70")
ENDIF()

IF ((CUDA_VERSION VERSION_GREATER "10.0") OR (CUDA_VERSION VERSION_EQUAL "10.0"))
+# CUDA_VERSION >= 10.0
set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -gencode arch=compute_75,code=sm_75")
ENDIF()

IF ((CUDA_VERSION VERSION_GREATER "11.0") OR (CUDA_VERSION VERSION_EQUAL "11.0"))
+# CUDA_VERSION >= 11.0
set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -gencode arch=compute_80,code=sm_80")
ENDIF()

IF ((CUDA_VERSION VERSION_GREATER "11.2") OR (CUDA_VERSION VERSION_EQUAL "11.2"))
+# CUDA_VERSION >= 11.2
set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -gencode arch=compute_86,code=sm_86")
ENDIF()

include/ctc.h (11 changes: 11 additions & 0 deletions)
@@ -52,6 +52,15 @@ typedef enum {
CTC_GPU = 1
} ctcComputeLocation;

+struct ctcParamConfigs {
+/// input params of compute_ctc_loss
+ctcComputeLocation flat_labels_loc{CTC_CPU};
+ctcComputeLocation label_lengths_loc{CTC_CPU};
+ctcComputeLocation input_lengths_loc{CTC_CPU};
+/// output param of compute_ctc_loss
+ctcComputeLocation costs_loc{CTC_CPU};
+};

/** Structure used for options to the CTC computation. Applications
* should zero out the array using memset and sizeof(struct
* ctcOptions) in C or default initialization (e.g. 'ctcOptions
@@ -70,6 +79,8 @@ struct ctcOptions {

/// the label value/index that the CTC calculation should use as the blank label
int blank_label;
+/// where to hold the input and output params
+ctcParamConfigs params;
};

/** Compute the connectionist temporal classification loss between
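For orientation, here is a minimal usage sketch of the new struct. It assumes a CUDA build of warp-ctc; the d_-prefixed names, the wrapper function, and the zero blank label are illustrative, not part of this commit. Only costs_loc is flipped to CTC_GPU here, since that is the path the hunks below implement end to end; the three input *_loc fields are set the same way when the label metadata already lives on the device.

#include <cstring>

#include "ctc.h"

// Hypothetical helper: compute the GPU CTC loss but leave the per-sample
// costs in device memory instead of copying them back to the host.
// All d_* pointers are assumed to be valid device allocations.
ctcStatus_t ctc_loss_costs_on_gpu(const float* d_activations, float* d_gradients,
                                  const int* flat_labels, const int* label_lengths,
                                  const int* input_lengths, int alphabet_size,
                                  int minibatch, float* d_costs, void* d_workspace,
                                  GPUstream stream) {
  ctcOptions options;
  std::memset(&options, 0, sizeof(ctcOptions));  // all *_loc fields default to CTC_CPU (0)
  options.loc = CTC_GPU;
  options.stream = stream;
  options.blank_label = 0;
  options.params.costs_loc = CTC_GPU;  // costs is a device pointer; skip the D2H copy
  return compute_ctc_loss(d_activations, d_gradients, flat_labels, label_lengths,
                          input_lengths, alphabet_size, minibatch, d_costs,
                          d_workspace, options);
}

With costs_loc left at its CTC_CPU default, the call behaves exactly as before this commit: the workspace still holds nll_forward and the result is copied back to the host, as the gpu_ctc.h and ctc_entrypoint.cpp hunks below show.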
include/detail/gpu_ctc.h (100 changes: 56 additions & 44 deletions)
@@ -3,6 +3,7 @@
#include "ctc_helper.h"
#include "gpu_ctc_kernels.h"
#include "reduce.h"
#include "gpu_helper.h"

template <typename ProbT>
class GpuCTC {
@@ -11,10 +11,19 @@ class GpuCTC {
int minibatch,
void *workspace,
GPUstream stream,
-int blank_label) :
+int blank_label,
+bool has_cpu_flat_labels,
+bool has_cpu_label_lengths,
+bool has_cpu_input_lengths,
+bool copy_costs_to_cpu) :
out_dim_(alphabet_size), minibatch_(minibatch),
gpu_workspace_(workspace), stream_(stream),
-blank_label_(blank_label) {};
+blank_label_(blank_label) {
+param_configs_.has_cpu_flat_labels = has_cpu_flat_labels;
+param_configs_.has_cpu_label_lengths = has_cpu_label_lengths;
+param_configs_.has_cpu_input_lengths = has_cpu_input_lengths;
+param_configs_.copy_costs_to_cpu = copy_costs_to_cpu;
+}

// Noncopyable
GpuCTC(const GpuCTC&) = delete;
@@ -36,7 +46,6 @@ class GpuCTC {
const int* const input_lengths);

private:

template<int NT, int VT>
ctcStatus_t launch_alpha_beta_kernels(const ProbT* const probs,
ProbT *grads,
Expand All @@ -53,12 +62,14 @@ class GpuCTC {
ctcStatus_t
setup_gpu_metadata(const int* const flat_labels,
const int* const label_lengths,
-const int* const input_lengths);
+const int* const input_lengths,
+ProbT* costs);

ctcStatus_t
create_metadata_and_choose_config(const int* const label_lengths,
const int* const flat_labels,
const int* const input_lengths,
+ProbT* costs,
size_t& best_config);

ctcStatus_t
@@ -74,6 +85,12 @@
bool compute_alpha,
bool compute_betas_and_grad);

+struct ParamConfigs {
+bool has_cpu_flat_labels{true};
+bool has_cpu_label_lengths{true};
+bool has_cpu_input_lengths{true};
+bool copy_costs_to_cpu{true};
+};

int out_dim_; // Number of characters plus blank
int minibatch_;
Expand All @@ -86,34 +103,40 @@ class GpuCTC {
GPUstream stream_;

int blank_label_;

-void *gpu_workspace_; // Buffer for all temporary GPU memory
-int *utt_length_; // T
-int *label_sizes_; // L
-int *repeats_; // repeats_
-int *label_offsets_;
-int *labels_without_blanks_;
-int *labels_with_blanks_;
-ProbT *alphas_;
-ProbT *nll_forward_;
-ProbT *nll_backward_;
-ProbT *denoms_; // Temporary storage for denoms for softmax
-ProbT *probs_; // Temporary storage for probabilities (softmax output)
+ParamConfigs param_configs_;
+
+void *gpu_workspace_{nullptr}; // Buffer for all temporary GPU memory
+int *utt_length_{nullptr}; // T
+int *label_sizes_{nullptr}; // L
+int *repeats_{nullptr}; // repeats_
+int *label_offsets_{nullptr};
+int *labels_without_blanks_{nullptr};
+int *labels_with_blanks_{nullptr};
+ProbT *alphas_{nullptr};
+ProbT *nll_forward_{nullptr};
+ProbT *nll_backward_{nullptr};
+ProbT *denoms_{nullptr}; // Temporary storage for denoms for softmax
+ProbT *probs_{nullptr}; // Temporary storage for probabilities (softmax output)
};

template<typename ProbT>
ctcStatus_t
GpuCTC<ProbT>::setup_gpu_metadata(const int* const flat_labels,
const int* const label_lengths,
-const int* const input_lengths)
+const int* const input_lengths,
+ProbT* costs)
{
size_t gpu_bytes_used = 0;

-nll_forward_ =
-reinterpret_cast<ProbT *>(static_cast<char*>(gpu_workspace_) +
-gpu_bytes_used);
-gpu_bytes_used += minibatch_ * sizeof(ProbT);

+if (param_configs_.copy_costs_to_cpu) {
+nll_forward_ =
+reinterpret_cast<ProbT *>(static_cast<char*>(gpu_workspace_) +
+gpu_bytes_used);
+gpu_bytes_used += minibatch_ * sizeof(ProbT);
+} else {
+// input costs should be a device ptr.
+nll_forward_ = costs;
+}

nll_backward_ =
reinterpret_cast<ProbT *>(static_cast<char*>(gpu_workspace_) +
@@ -345,10 +368,11 @@ ctcStatus_t
GpuCTC<ProbT>::create_metadata_and_choose_config(const int* const flat_labels,
const int* const label_lengths,
const int* const input_lengths,
+ProbT* costs,
size_t& best_config) {

// Setup the metadata for GPU
-ctcStatus_t status = setup_gpu_metadata(flat_labels, label_lengths, input_lengths);
+ctcStatus_t status = setup_gpu_metadata(flat_labels, label_lengths, input_lengths, costs);
if (status != CTC_STATUS_SUCCESS)
return status;

@@ -472,6 +496,7 @@ GpuCTC<ProbT>::compute_cost_and_score(const ProbT* const activations,
ctcStatus_t status = create_metadata_and_choose_config(flat_labels,
label_lengths,
input_lengths,
+costs,
best_config);
if (status != CTC_STATUS_SUCCESS)
return status;
@@ -483,26 +508,13 @@
launch_gpu_kernels(probs_, grads, best_config,
compute_alpha, compute_betas_and_grad);

-gpuError_t cuda_status_mem, cuda_status_sync;
-
-#ifdef __HIPCC__
-cuda_status_mem = hipMemcpyAsync(costs, nll_forward_,
-sizeof(ProbT) * minibatch_,
-hipMemcpyDeviceToHost, stream_);
-#else
-cuda_status_mem = cudaMemcpyAsync(costs, nll_forward_,
-sizeof(ProbT) * minibatch_,
-cudaMemcpyDeviceToHost, stream_);
-#endif
-
-#ifdef __HIPCC__
-cuda_status_sync = hipStreamSynchronize(stream_);
-#else
-cuda_status_sync = cudaStreamSynchronize(stream_);
-#endif
-
-if (cuda_status_mem != gpuSuccess || cuda_status_sync != gpuSuccess)
-return CTC_STATUS_MEMOPS_FAILED;
+if (param_configs_.copy_costs_to_cpu) {
+// Copy the costs back to CPU.
+gpuError_t cuda_status_mem = warpctc::memcpy_d2h_sync(
+costs, nll_forward_, sizeof(ProbT) * minibatch_, stream_);
+if (cuda_status_mem != gpuSuccess)
+return CTC_STATUS_MEMOPS_FAILED;
+}

return CTC_STATUS_SUCCESS;
}
include/detail/gpu_helper.h (35 changes: 35 additions & 0 deletions, new file)
@@ -0,0 +1,35 @@
#pragma once

#include "type_defs.h"

namespace warpctc {

static gpuError_t memcpy_d2h_async(void *dst, const void *src, size_t bytes, GPUstream stream) {
gpuError_t status;
#ifdef __HIPCC__
status = hipMemcpyAsync(dst, src, bytes, hipMemcpyDeviceToHost, stream);
#else
status = cudaMemcpyAsync(dst, src, bytes, cudaMemcpyDeviceToHost, stream);
#endif
return status;
}

static gpuError_t synchronize(GPUstream stream) {
gpuError_t status;
#ifdef __HIPCC__
status = hipStreamSynchronize(stream);
#else
status = cudaStreamSynchronize(stream);
#endif
return status;
}

static gpuError_t memcpy_d2h_sync(void *dst, const void *src, size_t bytes, GPUstream stream) {
gpuError_t status = memcpy_d2h_async(dst, src, bytes, stream);
if (status != gpuSuccess) {
return status;
}
return synchronize(stream);
}

} // namespace warpctc
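These three helpers fold the CUDA/HIP #ifdef branching that previously sat inline in gpu_ctc.h into one place; memcpy_d2h_sync is an async device-to-host copy followed by a stream synchronize, so the destination buffer is safe to read as soon as it returns. Declaring them static gives each translation unit that includes the header its own internal copy, avoiding duplicate-symbol errors at link time. A hedged usage sketch, with illustrative names and a std::vector standing in for whatever host buffer the caller uses:

#include <vector>

#include "gpu_helper.h"

// Hypothetical helper: synchronously read `minibatch` float costs back from
// the device. Assumes d_costs points at a device allocation of at least
// `minibatch` floats and `stream` is the stream the producing kernels ran on.
bool read_back_costs(const float* d_costs, int minibatch, GPUstream stream,
                     std::vector<float>* h_costs) {
  h_costs->resize(minibatch);
  gpuError_t status = warpctc::memcpy_d2h_sync(
      h_costs->data(), d_costs, sizeof(float) * minibatch, stream);
  return status == gpuSuccess;
}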
src/ctc_entrypoint.cpp (30 changes: 23 additions & 7 deletions)
@@ -72,7 +72,10 @@ ctcStatus_t compute_ctc_loss(const float* const activations,
} else if (options.loc == CTC_GPU) {
#if (defined(__HIPCC__) || defined(__CUDACC__))
GpuCTC<float> ctc(alphabet_size, minibatch, workspace, options.stream,
-options.blank_label);
+options.blank_label, options.params.flat_labels_loc == CTC_CPU,
+options.params.label_lengths_loc == CTC_CPU,
+options.params.input_lengths_loc == CTC_CPU,
+options.params.costs_loc == CTC_CPU);

if (gradients != NULL)
return ctc.cost_and_grad(activations, gradients, costs,
@@ -112,7 +115,7 @@ ctcStatus_t compute_ctc_loss_double(const double* const activations,

if (options.loc == CTC_CPU) {
CpuCTC<double> ctc(alphabet_size, minibatch, workspace, options.num_threads,
options.blank_label);

if (gradients != NULL)
return ctc.cost_and_grad(activations, gradients,
@@ -125,7 +128,10 @@ ctcStatus_t compute_ctc_loss_double(const double* const activations,
} else if (options.loc == CTC_GPU) {
#if (defined(__HIPCC__) || defined(__CUDACC__))
GpuCTC<double> ctc(alphabet_size, minibatch, workspace, options.stream,
-options.blank_label);
+options.blank_label, options.params.flat_labels_loc == CTC_CPU,
+options.params.label_lengths_loc == CTC_CPU,
+options.params.input_lengths_loc == CTC_CPU,
+options.params.costs_loc == CTC_CPU);

if (gradients != NULL)
return ctc.cost_and_grad(activations, gradients, costs,
@@ -167,8 +173,13 @@ ctcStatus_t get_workspace_size(const int* const label_lengths,

if (options.loc == CTC_GPU) {
// GPU storage
-//nll_forward, nll_backward
-*size_bytes += 2 * sizeof(float) * minibatch;
+if (options.params.costs_loc == CTC_CPU) {
+// nll_forward, nll_backward
+*size_bytes += 2 * sizeof(float) * minibatch;
+} else {
+// nll_backward
+*size_bytes += 1 * sizeof(float) * minibatch;
+}

//repeats
*size_bytes += sizeof(int) * minibatch;
@@ -249,8 +260,13 @@ ctcStatus_t get_workspace_size_double(const int* const label_lengths,

if (options.loc == CTC_GPU) {
// GPU storage
-//nll_forward, nll_backward
-*size_bytes += 2 * sizeof(double) * minibatch;
+if (options.params.costs_loc == CTC_CPU) {
+// nll_forward, nll_backward
+*size_bytes += 2 * sizeof(double) * minibatch;
+} else {
+// nll_backward
+*size_bytes += 1 * sizeof(double) * minibatch;
+}

//repeats
*size_bytes += sizeof(int) * minibatch;
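The sizing above has to agree with the compute_ctc_loss call that follows: when options.params.costs_loc == CTC_CPU the workspace reserves a minibatch-sized float (or double) buffer for each of nll_forward and nll_backward, otherwise only for nll_backward, because nll_forward_ then aliases the caller's device costs buffer. A small sketch of that invariant, with illustrative names; it assumes the float entry point:

#include <cassert>

#include "ctc.h"

// Hypothetical check: the GPU workspace shrinks by exactly one
// minibatch-sized float buffer when the costs stay on the device.
void check_workspace_difference(const int* label_lengths, const int* input_lengths,
                                int alphabet_size, int minibatch) {
  ctcOptions options{};  // zero-initialized: every *_loc starts as CTC_CPU
  options.loc = CTC_GPU;

  size_t bytes_costs_on_cpu = 0;
  get_workspace_size(label_lengths, input_lengths, alphabet_size, minibatch,
                     options, &bytes_costs_on_cpu);

  options.params.costs_loc = CTC_GPU;
  size_t bytes_costs_on_gpu = 0;
  get_workspace_size(label_lengths, input_lengths, alphabet_size, minibatch,
                     options, &bytes_costs_on_gpu);

  assert(bytes_costs_on_cpu == bytes_costs_on_gpu + minibatch * sizeof(float));
}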
