Add an option copy_costs_to_cpu to allow not copying costs back to the CPU. #175

Open
wants to merge 3 commits into base: master
12 changes: 9 additions & 3 deletions CMakeLists.txt
@@ -21,7 +21,7 @@ option(WITH_TORCH "compile warp-ctc with Torch." ${Torch_FOUND})
option(WITH_OMP "compile warp-ctc with OpenMP." ON)
option(BUILD_TESTS "build warp-ctc unit tests." ON)
option(BUILD_SHARED "build warp-ctc shared library." ON)
option(WITH_ROCM "Compile PaddlePaddle with ROCM platform" OFF)
option(WITH_ROCM "Compile PaddlePaddle with ROCM platform" OFF)

if(WITH_ROCM)
add_definitions(-DWARPCTC_WITH_HIP)
@@ -65,32 +65,38 @@ endif()

# need to be at least 30 or __shfl_down in reduce wont compile
IF (CUDA_VERSION VERSION_LESS "11.0")
# CUDA_VERSION < 11.0
set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -gencode arch=compute_30,code=sm_30")
set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -gencode arch=compute_35,code=sm_35")
set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -gencode arch=compute_50,code=sm_50")
ENDIF()
set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -gencode arch=compute_35,code=sm_35")

set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -gencode arch=compute_50,code=sm_50")
set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -gencode arch=compute_52,code=sm_52")

IF (CUDA_VERSION VERSION_GREATER "7.6")
# CUDA_VERSION > 7.6
set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -gencode arch=compute_60,code=sm_60")
set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -gencode arch=compute_61,code=sm_61")
set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -gencode arch=compute_62,code=sm_62")
ENDIF()

IF ((CUDA_VERSION VERSION_GREATER "9.0") OR (CUDA_VERSION VERSION_EQUAL "9.0"))
# CUDA_VERSION >= 9.0
set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -gencode arch=compute_70,code=sm_70")
ENDIF()

IF ((CUDA_VERSION VERSION_GREATER "10.0") OR (CUDA_VERSION VERSION_EQUAL "10.0"))
# CUDA_VERSION >= 10.0
set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -gencode arch=compute_75,code=sm_75")
ENDIF()

IF ((CUDA_VERSION VERSION_GREATER "11.0") OR (CUDA_VERSION VERSION_EQUAL "11.0"))
# CUDA_VERSION >= 11.0
set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -gencode arch=compute_80,code=sm_80")
ENDIF()

IF ((CUDA_VERSION VERSION_GREATER "11.2") OR (CUDA_VERSION VERSION_EQUAL "11.2"))
# CUDA_VERSION >= 11.2
set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -gencode arch=compute_86,code=sm_86")
ENDIF()

11 changes: 11 additions & 0 deletions include/ctc.h
@@ -52,6 +52,15 @@ typedef enum {
CTC_GPU = 1
} ctcComputeLocation;

struct ctcParamConfigs {
/// input params of compute_ctc_loss
ctcComputeLocation flat_labels_loc{CTC_CPU};
ctcComputeLocation label_lengths_loc{CTC_CPU};
ctcComputeLocation input_lengths_loc{CTC_CPU};
/// output param of compute_ctc_loss
ctcComputeLocation costs_loc{CTC_CPU};
};

/** Structure used for options to the CTC compution. Applications
* should zero out the array using memset and sizeof(struct
* ctcOptions) in C or default initialization (e.g. 'ctcOptions
@@ -70,6 +79,8 @@ struct ctcOptions {

/// the label value/index that the CTC calculation should use as the blank label
int blank_label;
/// where to hold the input and output params
ctcParamConfigs params;
};

/** Compute the connectionist temporal classification loss between
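
Not part of the diff, but for reviewers: a minimal usage sketch of the new params field, showing how a caller keeps the per-utterance costs on the device. Only ctcParamConfigs and its defaults come from the header above; the loc and stream members of ctcOptions, the GPUstream type, and the helper name make_gpu_options are assumptions.

#include <cstring>  // std::memset
#include "ctc.h"

// Hedged sketch: inputs stay on the host (the defaults), while the `costs`
// pointer later passed to compute_ctc_loss is expected to be a device buffer,
// so no device-to-host copy of the losses is issued.
ctcOptions make_gpu_options(GPUstream stream) {    // GPUstream: assumed stream type
    ctcOptions options;
    std::memset(&options, 0, sizeof(ctcOptions));  // zero out, as ctc.h recommends
    options.loc = CTC_GPU;                         // assumed existing field
    options.stream = stream;                       // assumed existing field
    options.blank_label = 0;

    options.params.flat_labels_loc   = CTC_CPU;
    options.params.label_lengths_loc = CTC_CPU;
    options.params.input_lengths_loc = CTC_CPU;
    options.params.costs_loc         = CTC_GPU;    // keep costs in GPU memory
    return options;
}

With costs_loc set to CTC_GPU, the GPU backend below uses the caller's device buffer as nll_forward_ directly and skips the final device-to-host copy.
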
171 changes: 68 additions & 103 deletions include/detail/gpu_ctc.h
@@ -3,6 +3,7 @@
#include "ctc_helper.h"
#include "gpu_ctc_kernels.h"
#include "reduce.h"
#include "gpu_helper.h"

template <typename ProbT>
class GpuCTC {
@@ -11,10 +12,19 @@ class GpuCTC {
int minibatch,
void *workspace,
GPUstream stream,
int blank_label) :
int blank_label,
bool has_cpu_flat_labels,
bool has_cpu_label_lengths,
bool has_cpu_input_lengths,
bool copy_costs_to_cpu) :
out_dim_(alphabet_size), minibatch_(minibatch),
gpu_workspace_(workspace), stream_(stream),
blank_label_(blank_label) {};
blank_label_(blank_label) {
param_configs_.has_cpu_flat_labels = has_cpu_flat_labels;
param_configs_.has_cpu_label_lengths = has_cpu_label_lengths;
param_configs_.has_cpu_input_lengths = has_cpu_input_lengths;
param_configs_.copy_costs_to_cpu = copy_costs_to_cpu;
}

// Noncopyable
GpuCTC(const GpuCTC&) = delete;
@@ -36,7 +46,6 @@ class GpuCTC {
const int* const input_lengths);

private:

template<int NT, int VT>
ctcStatus_t launch_alpha_beta_kernels(const ProbT* const probs,
ProbT *grads,
@@ -53,12 +62,14 @@
ctcStatus_t
setup_gpu_metadata(const int* const flat_labels,
const int* const label_lengths,
const int* const input_lengths);
const int* const input_lengths,
ProbT* costs);

ctcStatus_t
create_metadata_and_choose_config(const int* const label_lengths,
const int* const flat_labels,
const int* const input_lengths,
ProbT* costs,
size_t& best_config);

ctcStatus_t
@@ -74,6 +85,12 @@
bool compute_alpha,
bool compute_betas_and_grad);

struct ParamConfigs {
bool has_cpu_flat_labels{true};
bool has_cpu_label_lengths{true};
bool has_cpu_input_lengths{true};
bool copy_costs_to_cpu{true};
};

int out_dim_; // Number of characters plus blank
int minibatch_;
@@ -86,34 +103,40 @@
GPUstream stream_;

int blank_label_;

void *gpu_workspace_; // Buffer for all temporary GPU memory
int *utt_length_; // T
int *label_sizes_; // L
int *repeats_; // repeats_
int *label_offsets_;
int *labels_without_blanks_;
int *labels_with_blanks_;
ProbT *alphas_;
ProbT *nll_forward_;
ProbT *nll_backward_;
ProbT *denoms_; // Temporary storage for denoms for softmax
ProbT *probs_; // Temporary storage for probabilities (softmax output)
ParamConfigs param_configs_;

void *gpu_workspace_{nullptr}; // Buffer for all temporary GPU memory
int *utt_length_{nullptr}; // T
int *label_sizes_{nullptr}; // L
int *repeats_{nullptr}; // repeats_
int *label_offsets_{nullptr};
int *labels_without_blanks_{nullptr};
int *labels_with_blanks_{nullptr};
ProbT *alphas_{nullptr};
ProbT *nll_forward_{nullptr};
ProbT *nll_backward_{nullptr};
ProbT *denoms_{nullptr}; // Temporary storage for denoms for softmax
ProbT *probs_{nullptr}; // Temporary storage for probabilities (softmax output)
};

template<typename ProbT>
ctcStatus_t
GpuCTC<ProbT>::setup_gpu_metadata(const int* const flat_labels,
const int* const label_lengths,
const int* const input_lengths)
const int* const input_lengths,
ProbT* costs)
{
size_t gpu_bytes_used = 0;

nll_forward_ =
reinterpret_cast<ProbT *>(static_cast<char*>(gpu_workspace_) +
gpu_bytes_used);
gpu_bytes_used += minibatch_ * sizeof(ProbT);

if (param_configs_.copy_costs_to_cpu) {
nll_forward_ =
reinterpret_cast<ProbT *>(static_cast<char*>(gpu_workspace_) +
gpu_bytes_used);
gpu_bytes_used += minibatch_ * sizeof(ProbT);
} else {
// input costs should be a device ptr.
nll_forward_ = costs;
}

nll_backward_ =
reinterpret_cast<ProbT *>(static_cast<char*>(gpu_workspace_) +
@@ -181,30 +204,14 @@ GpuCTC<ProbT>::setup_gpu_metadata(const int* const flat_labels,
Lmax = std::max(Lmax, L);
}

#ifdef __HIPCC__
cuda_status = hipMemcpyAsync(&(repeats_[start_idx]), repeats,
(end_idx - start_idx) * sizeof(int),
hipMemcpyHostToDevice, stream_);
#else
cuda_status = cudaMemcpyAsync(&(repeats_[start_idx]), repeats,
(end_idx - start_idx) * sizeof(int),
cudaMemcpyHostToDevice, stream_);
#endif

cuda_status = warpctc::memcpy_h2d_async(
&(repeats_[start_idx]), repeats, (end_idx - start_idx) * sizeof(int), stream_);
if (cuda_status != gpuSuccess)
return CTC_STATUS_MEMOPS_FAILED;


#ifdef __HIPCC__
cuda_status = hipMemcpyAsync(&(label_offsets_[start_idx]), label_offsets,
(end_idx - start_idx) * sizeof(int),
hipMemcpyHostToDevice, stream_);
#else
cuda_status = cudaMemcpyAsync(&(label_offsets_[start_idx]), label_offsets,
(end_idx - start_idx) * sizeof(int),
cudaMemcpyHostToDevice, stream_);
#endif

cuda_status = warpctc::memcpy_h2d_async(
&(label_offsets_[start_idx]), label_offsets, (end_idx - start_idx) * sizeof(int), stream_);
if (cuda_status != gpuSuccess)
return CTC_STATUS_MEMOPS_FAILED;
}
@@ -220,16 +227,8 @@ GpuCTC<ProbT>::setup_gpu_metadata(const int* const flat_labels,
gpu_bytes_used);
gpu_bytes_used += minibatch_ * sizeof(int);

#ifdef __HIPCC__
cuda_status = hipMemcpyAsync(utt_length_, input_lengths,
minibatch_ * sizeof(int),
hipMemcpyHostToDevice, stream_);
#else
cuda_status = cudaMemcpyAsync(utt_length_, input_lengths,
minibatch_ * sizeof(int),
cudaMemcpyHostToDevice, stream_);
#endif

cuda_status = warpctc::memcpy_h2d_async(
utt_length_, input_lengths, minibatch_ * sizeof(int), stream_);
if (cuda_status != gpuSuccess)
return CTC_STATUS_MEMOPS_FAILED;

@@ -238,16 +237,8 @@ GpuCTC<ProbT>::setup_gpu_metadata(const int* const flat_labels,
gpu_bytes_used);
gpu_bytes_used += minibatch_ * sizeof(int);

#ifdef __HIPCC__
cuda_status = hipMemcpyAsync(label_sizes_, label_lengths,
minibatch_ * sizeof(int),
hipMemcpyHostToDevice, stream_);
#else
cuda_status = cudaMemcpyAsync(label_sizes_, label_lengths,
minibatch_ * sizeof(int),
cudaMemcpyHostToDevice, stream_);
#endif

cuda_status = warpctc::memcpy_h2d_async(
label_sizes_, label_lengths, minibatch_ * sizeof(int), stream_);
if (cuda_status != gpuSuccess)
return CTC_STATUS_MEMOPS_FAILED;

@@ -256,16 +247,8 @@ GpuCTC<ProbT>::setup_gpu_metadata(const int* const flat_labels,
gpu_bytes_used);
gpu_bytes_used += Lmax * minibatch_ * sizeof(int);

#ifdef __HIPCC__
cuda_status = hipMemcpyAsync(labels_without_blanks_, flat_labels,
total_label_length * sizeof(int),
hipMemcpyHostToDevice, stream_);
#else
cuda_status = cudaMemcpyAsync(labels_without_blanks_, flat_labels,
total_label_length * sizeof(int),
cudaMemcpyHostToDevice, stream_);
#endif

cuda_status = warpctc::memcpy_h2d_async(
labels_without_blanks_, flat_labels, total_label_length * sizeof(int), stream_);
if (cuda_status != gpuSuccess)
return CTC_STATUS_MEMOPS_FAILED;

@@ -279,7 +262,6 @@ GpuCTC<ProbT>::setup_gpu_metadata(const int* const flat_labels,
gpu_bytes_used);
gpu_bytes_used += (S_ * T_) * minibatch_ * sizeof(ProbT);


denoms_ =
reinterpret_cast<ProbT *>(static_cast<char*>(gpu_workspace_) +
gpu_bytes_used);
@@ -307,25 +289,19 @@ ctcStatus_t GpuCTC<ProbT>::launch_alpha_beta_kernels(const ProbT* const probs,
// away
const int stride = minibatch_;

if (compute_alpha)
if (compute_alpha) {
compute_alpha_kernel<ProbT, NT, VT><<<grid_size, NT, 0, stream_>>>
(probs, label_sizes_, utt_length_,
repeats_, labels_without_blanks_, label_offsets_,
labels_with_blanks_, alphas_, nll_forward_,
stride, out_dim_, S_, T_, blank_label_);

}

if (compute_beta) {
compute_betas_and_grad_kernel<ProbT, NT, VT><<<grid_size, NT, 0, stream_>>>
(probs, label_sizes_, utt_length_, repeats_,
labels_with_blanks_, alphas_, nll_forward_, nll_backward_,
grads, stride, out_dim_, S_, T_, blank_label_);

#ifdef __HIPCC__
hipStreamSynchronize(stream_);
#else
cudaStreamSynchronize(stream_);
#endif
}

#ifdef __HIPCC__
@@ -345,10 +321,11 @@ ctcStatus_t
GpuCTC<ProbT>::create_metadata_and_choose_config(const int* const flat_labels,
const int* const label_lengths,
const int* const input_lengths,
ProbT* costs,
size_t& best_config) {

// Setup the metadata for GPU
ctcStatus_t status = setup_gpu_metadata(flat_labels, label_lengths, input_lengths);
ctcStatus_t status = setup_gpu_metadata(flat_labels, label_lengths, input_lengths, costs);
if (status != CTC_STATUS_SUCCESS)
return status;

@@ -472,6 +449,7 @@ GpuCTC<ProbT>::compute_cost_and_score(const ProbT* const activations,
ctcStatus_t status = create_metadata_and_choose_config(flat_labels,
label_lengths,
input_lengths,
costs,
best_config);
if (status != CTC_STATUS_SUCCESS)
return status;
@@ -483,26 +461,13 @@
launch_gpu_kernels(probs_, grads, best_config,
compute_alpha, compute_betas_and_grad);

gpuError_t cuda_status_mem, cuda_status_sync;

#ifdef __HIPCC__
cuda_status_mem = hipMemcpyAsync(costs, nll_forward_,
sizeof(ProbT) * minibatch_,
hipMemcpyDeviceToHost, stream_);
#else
cuda_status_mem = cudaMemcpyAsync(costs, nll_forward_,
sizeof(ProbT) * minibatch_,
cudaMemcpyDeviceToHost, stream_);
#endif

#ifdef __HIPCC__
cuda_status_sync = hipStreamSynchronize(stream_);
#else
cuda_status_sync = cudaStreamSynchronize(stream_);
#endif

if (cuda_status_mem != gpuSuccess || cuda_status_sync != gpuSuccess)
return CTC_STATUS_MEMOPS_FAILED;
if (param_configs_.copy_costs_to_cpu) {
// Copy the costs back to CPU.
gpuError_t cuda_status_mem = warpctc::memcpy_d2h_sync(
costs, nll_forward_, sizeof(ProbT) * minibatch_, stream_);
if (cuda_status_mem != gpuSuccess)
return CTC_STATUS_MEMOPS_FAILED;
}

return CTC_STATUS_SUCCESS;
}
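
The calls above rely on helpers from the new include/detail/gpu_helper.h, whose diff is not shown in this view. A plausible reconstruction, inferred from the call sites and the inline #ifdef __HIPCC__ blocks they replace; the real header may differ, and gpuError_t, gpuSuccess, and GPUstream are assumed to be the aliases already used by gpu_ctc.h.

#pragma once

#include <cstddef>
#ifdef __HIPCC__
#include <hip/hip_runtime.h>
#else
#include <cuda_runtime.h>
#endif

namespace warpctc {

// Asynchronous host-to-device copy on the given stream.
inline gpuError_t memcpy_h2d_async(void* dst, const void* src,
                                   size_t bytes, GPUstream stream) {
#ifdef __HIPCC__
  return hipMemcpyAsync(dst, src, bytes, hipMemcpyHostToDevice, stream);
#else
  return cudaMemcpyAsync(dst, src, bytes, cudaMemcpyHostToDevice, stream);
#endif
}

// Device-to-host copy followed by a stream synchronize, so the host buffer is
// ready when this returns, matching the behaviour of the inline code it replaces.
inline gpuError_t memcpy_d2h_sync(void* dst, const void* src,
                                  size_t bytes, GPUstream stream) {
#ifdef __HIPCC__
  gpuError_t status = hipMemcpyAsync(dst, src, bytes, hipMemcpyDeviceToHost, stream);
  if (status == gpuSuccess) status = hipStreamSynchronize(stream);
#else
  gpuError_t status = cudaMemcpyAsync(dst, src, bytes, cudaMemcpyDeviceToHost, stream);
  if (status == gpuSuccess) status = cudaStreamSynchronize(stream);
#endif
  return status;
}

}  // namespace warpctc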