Add interface supporting double type (#163)
* add float64 input for warpctc

* add comment to function of double
Li Fuchen authored Sep 21, 2020
1 parent fc7f226 commit 95a461e
Showing 5 changed files with 234 additions and 17 deletions.
80 changes: 75 additions & 5 deletions include/ctc.h
@@ -65,9 +65,9 @@ struct ctcOptions {
int blank_label;
};

-/** Compute the connectionist temporal classification loss between a sequence
- * of probabilities and a ground truth labeling. Optionally compute the
- * gradient with respect to the inputs.
+/** Compute the connectionist temporal classification loss between
+ * a probability sequence with dtype float and a ground truth labeling.
+ * Optionally compute the gradient with respect to the inputs.
* \param [in] activations pointer to the activations in either CPU or GPU
* addressable memory, depending on info. We assume a fixed
* memory layout for this 3 dimensional tensor, which has dimension
@@ -112,10 +112,57 @@ API_REFERENCE ctcStatus_t compute_ctc_loss(const float* const activations,
void *workspace,
ctcOptions options);

/** Compute the connectionist temporal classification loss between
* a probability sequence of dtype double and a ground truth labeling.
* Optionally compute the gradient with respect to the inputs.
* \param [in] activations pointer to the activations in either CPU or GPU
* addressable memory, depending on info. We assume a fixed
* memory layout for this 3 dimensional tensor, which has dimension
* (t, n, p), where t is the time index, n is the minibatch index,
* and p indexes over probabilities of each symbol in the alphabet.
* The memory layout is (t, n, p) in C order (slowest to fastest changing
* index, aka row-major), or (p, n, t) in Fortran order (fastest to slowest
* changing index, aka column-major). We also assume strides are equal to
* dimensions - there is no padding between dimensions.
* More precisely, element (t, n, p), for a problem with mini_batch examples
* in the mini batch, and alphabet_size symbols in the alphabet, is located at:
* activations[(t * mini_batch + n) * alphabet_size + p]
* \param [out] gradients if not NULL, then gradients are computed. Should be
* allocated in the same memory space as probs and memory
* ordering is identical.
* \param [in] flat_labels Always in CPU memory. A concatenation
* of all the labels for the minibatch.
* \param [in] label_lengths Always in CPU memory. The length of each label
* for each example in the minibatch.
* \param [in] input_lengths Always in CPU memory. The number of time steps
* for each sequence in the minibatch.
* \param [in] alphabet_size The number of possible output symbols. There
* should be this many probabilities for each time step.
* \param [in] mini_batch How many examples in a minibatch.
* \param [out] costs Always in CPU memory. The cost of each example in the
* minibatch.
* \param [in,out] workspace In same memory space as probs. Should be of
* size requested by get_workspace_size.
* \param [in] options see struct ctcOptions
*
* \return Status information
*
* */
API_REFERENCE ctcStatus_t compute_ctc_loss_double(const double* const activations,
double* gradients,
const int* const flat_labels,
const int* const label_lengths,
const int* const input_lengths,
int alphabet_size,
int minibatch,
double *costs,
void *workspace,
ctcOptions options);


/** For a given set of labels and minibatch size return the required workspace
- * size. This will need to be allocated in the same memory space as your
- * probabilities.
+ * size when the dtype of your probabilities is float. This will need to be allocated
+ * in the same memory space as your probabilities.
* \param [in] label_lengths Always in CPU memory. The length of each label
* for each example in the minibatch.
* \param [in] input_lengths Always in CPU memory. The number of time steps
@@ -136,6 +183,29 @@ API_REFERENCE ctcStatus_t get_workspace_size(const int* const label_lengths,
ctcOptions info,
size_t* size_bytes);

/** For a given set of labels and minibatch size return the required workspace
* size when the dtype of your probabilities is double. This will need to be allocated
* in the same memory space as your probabilities.
* \param [in] label_lengths Always in CPU memory. The length of each label
* for each example in the minibatch.
* \param [in] input_lengths Always in CPU memory. The number of time steps
* for each sequence in the minibatch.
* \param [in] alphabet_size How many symbols in the alphabet or, equivalently,
* the number of probabilities at each time step
* \param [in] mini_batch How many examples in a minibatch.
* \param [in] info see struct ctcOptions
* \param [out] size_bytes is pointer to a scalar where the memory
* requirement in bytes will be placed. This memory should be allocated
* at the same place, CPU or GPU, that the probs are in
*
* \return Status information
**/
API_REFERENCE ctcStatus_t get_workspace_size_double(const int* const label_lengths,
const int* const input_lengths,
int alphabet_size, int minibatch,
ctcOptions info,
size_t* size_bytes);

#ifdef __cplusplus
}
#endif
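
The double-precision entry points mirror the float ones, so a caller only swaps the buffer types and the two function names. Below is a minimal CPU usage sketch (not part of the commit; sizes and values are hypothetical placeholders and error checking is elided):

#include <cstdlib>
#include <ctc.h>

int main() {
    const int alphabet_size = 5;    // number of symbols, including the blank
    const int minibatch = 1;
    int flat_labels[]   = {1, 2};   // one example whose label sequence is (1, 2)
    int label_lengths[] = {2};
    int input_lengths[] = {3};      // three time steps

    // Element (t, n, p) of the activation tensor lives at
    // activations[(t * minibatch + n) * alphabet_size + p], per the layout above.
    double activations[3 * 1 * 5] = {0.0};  // placeholder activation values
    double gradients[3 * 1 * 5];
    double costs[1];

    ctcOptions options{};           // zero-initialize, then set what we need
    options.loc = CTC_CPU;
    options.num_threads = 1;

    size_t workspace_size;
    get_workspace_size_double(label_lengths, input_lengths,
                              alphabet_size, minibatch, options,
                              &workspace_size);

    void* workspace = std::malloc(workspace_size);
    compute_ctc_loss_double(activations, gradients,
                            flat_labels, label_lengths, input_lengths,
                            alphabet_size, minibatch, costs,
                            workspace, options);
    // costs[0] now holds the CTC loss for the single example.
    std::free(workspace);
    return 0;
}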
4 changes: 2 additions & 2 deletions include/detail/gpu_ctc.h
@@ -367,7 +367,7 @@ GpuCTC<ProbT>::compute_probs(const ProbT* const activations) {

// Numerically stable SM
ctcStatus_t ctc_status =
-        reduce_max(probs_, denoms_, out_dim_,
+        reduce_max<ProbT>(probs_, denoms_, out_dim_,
activation_cols_, 1, stream_);
if (ctc_status != CTC_STATUS_SUCCESS)
return ctc_status;
@@ -385,7 +385,7 @@ GpuCTC<ProbT>::compute_probs(const ProbT* const activations) {

// Reduce along columns to calculate denominator
ctc_status =
-        reduce_exp(probs_, denoms_, out_dim_,
+        reduce_exp<ProbT>(probs_, denoms_, out_dim_,
activation_cols_, 1, stream_);
if (ctc_status != CTC_STATUS_SUCCESS)
return ctc_status;
9 changes: 6 additions & 3 deletions include/detail/reduce.h
@@ -1,5 +1,8 @@
#pragma once

-ctcStatus_t reduce_negate(const float* input, float* output, int rows, int cols, bool axis, cudaStream_t stream);
-ctcStatus_t reduce_exp(const float* input, float* output, int rows, int cols, bool axis, cudaStream_t stream);
-ctcStatus_t reduce_max(const float* input, float* output, int rows, int cols, bool axis, cudaStream_t stream);
+template <typename T>
+ctcStatus_t reduce_negate(const T* input, T* output, int rows, int cols, bool axis, cudaStream_t stream);
+template <typename T>
+ctcStatus_t reduce_exp(const T* input, T* output, int rows, int cols, bool axis, cudaStream_t stream);
+template <typename T>
+ctcStatus_t reduce_max(const T* input, T* output, int rows, int cols, bool axis, cudaStream_t stream);
136 changes: 136 additions & 0 deletions src/ctc_entrypoint.cpp
@@ -1,6 +1,7 @@
#include <cstddef>
#include <iostream>
#include <algorithm>
#include <cstdio>

#include <ctc.h>

@@ -89,6 +90,59 @@ ctcStatus_t compute_ctc_loss(const float* const activations,
}
}

ctcStatus_t compute_ctc_loss_double(const double* const activations,
double* gradients,
const int* const flat_labels,
const int* const label_lengths,
const int* const input_lengths,
int alphabet_size,
int minibatch,
double *costs,
void *workspace,
ctcOptions options) {
if (activations == nullptr ||
flat_labels == nullptr ||
label_lengths == nullptr ||
input_lengths == nullptr ||
costs == nullptr ||
workspace == nullptr ||
alphabet_size <= 0 ||
minibatch <= 0)
return CTC_STATUS_INVALID_VALUE;

if (options.loc == CTC_CPU) {
CpuCTC<double> ctc(alphabet_size, minibatch, workspace, options.num_threads,
options.blank_label);

if (gradients != NULL)
return ctc.cost_and_grad(activations, gradients,
costs,
flat_labels, label_lengths,
input_lengths);
else
return ctc.score_forward(activations, costs, flat_labels,
label_lengths, input_lengths);
} else if (options.loc == CTC_GPU) {
#ifdef __CUDACC__
GpuCTC<double> ctc(alphabet_size, minibatch, workspace, options.stream,
options.blank_label);

if (gradients != NULL)
return ctc.cost_and_grad(activations, gradients, costs,
flat_labels, label_lengths,
input_lengths);
else
return ctc.score_forward(activations, costs, flat_labels,
label_lengths, input_lengths);
#else
std::cerr << "GPU execution requested, but not compiled with GPU support" << std::endl;
return CTC_STATUS_EXECUTION_FAILED;
#endif
} else {
return CTC_STATUS_INVALID_VALUE;
}
}


ctcStatus_t get_workspace_size(const int* const label_lengths,
const int* const input_lengths,
@@ -172,4 +226,86 @@ ctcStatus_t get_workspace_size(const int* const label_lengths,
return CTC_STATUS_SUCCESS;
}

ctcStatus_t get_workspace_size_double(const int* const label_lengths,
const int* const input_lengths,
int alphabet_size, int minibatch,
ctcOptions options,
size_t* size_bytes)
{
if (label_lengths == nullptr ||
input_lengths == nullptr ||
size_bytes == nullptr ||
alphabet_size <= 0 ||
minibatch <= 0)
return CTC_STATUS_INVALID_VALUE;

// This is the max of all S and T for all examples in the minibatch.
int maxL = *std::max_element(label_lengths, label_lengths + minibatch);
int maxT = *std::max_element(input_lengths, input_lengths + minibatch);

const int S = 2 * maxL + 1;

*size_bytes = 0;

if (options.loc == CTC_GPU) {
// GPU storage
//nll_forward, nll_backward
*size_bytes += 2 * sizeof(double) * minibatch;

//repeats
*size_bytes += sizeof(int) * minibatch;

//label offsets
*size_bytes += sizeof(int) * minibatch;

//utt_length
*size_bytes += sizeof(int) * minibatch;

//label lengths
*size_bytes += sizeof(int) * minibatch;

//labels without blanks - overallocate for now
*size_bytes += sizeof(int) * maxL * minibatch;

//labels with blanks
*size_bytes += sizeof(int) * S * minibatch;

//alphas
*size_bytes += sizeof(double) * S * maxT * minibatch;

//denoms
*size_bytes += sizeof(double) * maxT * minibatch;

//probs (since we will pass in activations)
*size_bytes += sizeof(double) * alphabet_size * maxT * minibatch;

} else {
//cpu can eventually replace all minibatch with
//max number of concurrent threads if memory is
//really tight

//per minibatch memory
size_t per_minibatch_bytes = 0;

//output
per_minibatch_bytes += sizeof(double) * alphabet_size ;

//alphas
per_minibatch_bytes += sizeof(double) * S * maxT;

//betas
per_minibatch_bytes += sizeof(double) * S;

//labels w/blanks, e_inc, s_inc
per_minibatch_bytes += 3 * sizeof(int) * S;

*size_bytes = per_minibatch_bytes * minibatch;

//probs
*size_bytes += sizeof(double) * alphabet_size * maxT * minibatch;
}

return CTC_STATUS_SUCCESS;
}

}
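
To make the CPU sizing above concrete with hypothetical numbers: for a single-example batch with maxL = 2, maxT = 3, and alphabet_size = 5, we get S = 2 * 2 + 1 = 5. With 8-byte doubles and 4-byte ints, the per-example block is 8*5 (output) + 8*5*3 (alphas) + 8*5 (betas) + 3*4*5 (labels with blanks, e_inc, s_inc) = 40 + 120 + 40 + 60 = 260 bytes, and the probs buffer adds 8*5*3*1 = 120 bytes, for a 380-byte workspace. The function is otherwise line-for-line identical to get_workspace_size; the only change is sizeof(double) in place of sizeof(float), which doubles every floating-point term.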
22 changes: 15 additions & 7 deletions src/reduce.cu
@@ -148,15 +148,23 @@ ctcStatus_t reduce(Iof f, Rof g, const T* input, T* output, int rows, int cols,

return CTC_STATUS_SUCCESS;
}

-ctcStatus_t reduce_negate(const float *input, float *output, int rows, int cols, bool axis, cudaStream_t stream) {
-    return reduce(ctc_helper::negate<float>(), ctc_helper::add<float>(), input, output, rows, cols, axis, stream);
+template<typename T>
+ctcStatus_t reduce_negate(const T *input, T *output, int rows, int cols, bool axis, cudaStream_t stream) {
+    return reduce(ctc_helper::negate<T>(), ctc_helper::add<T>(), input, output, rows, cols, axis, stream);
 }
+template ctcStatus_t reduce_negate<float>(const float *input, float *output, int rows, int cols, bool axis, cudaStream_t stream);
+template ctcStatus_t reduce_negate<double>(const double *input, double *output, int rows, int cols, bool axis, cudaStream_t stream);

-ctcStatus_t reduce_exp(const float *input, float *output, int rows, int cols, bool axis, cudaStream_t stream) {
-    return reduce(ctc_helper::exponential<float>(), ctc_helper::add<float>(), input, output, rows, cols, axis, stream);
+template<typename T>
+ctcStatus_t reduce_exp(const T *input, T *output, int rows, int cols, bool axis, cudaStream_t stream) {
+    return reduce(ctc_helper::exponential<T>(), ctc_helper::add<T>(), input, output, rows, cols, axis, stream);
 }
+template ctcStatus_t reduce_exp<float>(const float *input, float *output, int rows, int cols, bool axis, cudaStream_t stream);
+template ctcStatus_t reduce_exp<double>(const double *input, double *output, int rows, int cols, bool axis, cudaStream_t stream);

-ctcStatus_t reduce_max(const float *input, float *output, int rows, int cols, bool axis, cudaStream_t stream) {
-    return reduce(ctc_helper::identity<float>(), ctc_helper::maximum<float>(),input, output, rows, cols, axis, stream);
+template<typename T>
+ctcStatus_t reduce_max(const T *input, T *output, int rows, int cols, bool axis, cudaStream_t stream) {
+    return reduce(ctc_helper::identity<T>(), ctc_helper::maximum<T>(),input, output, rows, cols, axis, stream);
 }
+template ctcStatus_t reduce_max<float>(const float *input, float *output, int rows, int cols, bool axis, cudaStream_t stream);
+template ctcStatus_t reduce_max<double>(const double *input, double *output, int rows, int cols, bool axis, cudaStream_t stream);
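
A note on the pattern above: the template definitions of reduce_negate, reduce_exp, and reduce_max live in reduce.cu while include/detail/reduce.h only declares them, so the explicit instantiation lines (template ctcStatus_t reduce_negate<float>(...); and their double counterparts) are what force the compiler to emit both the float and double symbols in this translation unit. Without them, callers such as GpuCTC<double> would hit undefined references at link time. The same pattern in miniature (hypothetical files, not part of the commit):

// lib.h -- declaration only; the definition is not visible to includers
template <typename T>
T twice(T x);

// lib.cu -- definition plus explicit instantiations for each supported type
template <typename T>
T twice(T x) { return x + x; }

template float twice<float>(float);
template double twice<double>(double);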
