diff --git a/aten/src/ATen/native/Col2Im.cpp b/aten/src/ATen/native/Col2Im.cpp new file mode 100644 index 0000000000000..8b4cec3c8c4cb --- /dev/null +++ b/aten/src/ATen/native/Col2Im.cpp @@ -0,0 +1,249 @@ +#include <ATen/ATen.h> +#include <ATen/NativeFunctions.h> +#include <ATen/TensorUtils.h> +#include <ATen/Utils.h> +#include <ATen/div_rtn.h> + +#include <ATen/native/im2col.h> +#include <ATen/native/im2col_shape_check.h> + +// Note [im2col/col2im output padding] +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// Our implementations of im2col and col2im take both the input height/width as +// well as a seemingly redundant output height/width. In principle, you could +// compute the output height/width by using the convolution shape formulas. So, +// what's up with that? +// +// The trouble arises when one runs the backward of a transposed convolution +// with output_padding >= stride. (BTW, output_padding is known as adj inside +// THNN.) Let's consider a simple case with kernel=2, dilation=2, stride=1, +// output_padding=1, where a 1x1 input produces a 4x4 output: +// +// Input: X +// +// Output: X.X. +// .... +// X.X. +// .... +// +// If we computed the backward of this output by running a standard convolution +// over it with the same parameters, we would end up with a 2x2 grad_input +// (because you can slide the stencil over to the right once and down once). +// But that is all out of bounds if you're computing backwards for a 1x1 input. +// +// "Now Edward," you might say, "the real problem is that you set +// output_padding >= stride; surely an error should have been raised in this +// case." To understand why it is useful to handle this case, we have to +// understand how we compute the weight gradient of a convolution. Suppose we +// have a convolution with kernel=2, stride=2 on a 5x5 input. Let us see all +// the contributions of weight[0][0] (which we have labeled w) in the output: +// +// Input: a.b.. Weight: w. +// ..... .. +// c.d.. +// ..... +// ..... +// +// Output: [ aw+... bw+... ] +// [ cw+... dw+... ] +// +// From this diagram, it is easy to see that we can compute the weight gradient +// by performing a *dilated* convolution between the input and the +// output gradients with kernel=2, dilation=2, stride=1. But there's a rub: if +// we do a dilated convolution directly, we'll end up with a 3x3 weight +// gradient, when we clearly wanted a 2x2. So how do we avoid going out +// of bounds? We could add a notion of 'output_padding' for non-transposed +// convolution, but another simple and effective fix is to just accept +// the desired output size directly, and compute only within those bounds.
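+// +// For concreteness, here is the arithmetic behind that example (the standard +// shape formulas restated as a sanity check, not new behavior). A +// non-transposed convolution computes +// +// out = (in + 2 * pad - dilation * (kernel - 1) - 1) / stride + 1 +// +// while a transposed convolution inverts it as +// +// out = (in - 1) * stride - 2 * pad + dilation * (kernel - 1) + output_padding + 1 +// +// With kernel=2, dilation=2, stride=1, pad=0, output_padding=1 and in=1, the +// transposed formula gives (1 - 1) * 1 - 0 + 2 * 1 + 1 + 1 = 4, i.e. the 4x4 +// output drawn above. Running the non-transposed formula back over that 4x4 +// output gives (4 + 0 - 2 - 1) / 1 + 1 = 2, not 1 -- which is exactly why the +// caller must pass the desired output size instead of having us recompute it +// from the other parameters.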
+// +// +// ALSO do vol2col + +namespace at { +namespace native { +namespace { + +static void col2im_out_cpu_template( + Tensor& output, + const Tensor& input_, + IntArrayRef output_size, + IntArrayRef kernel_size, + IntArrayRef dilation, + IntArrayRef padding, + IntArrayRef stride) { + TORCH_CHECK( + output_size.size() == 2, + "Expected output_size to have 2 elements, but got ", + output_size.size()); + + TORCH_CHECK( + kernel_size.size() == 2, + "Expected kernel_size to have 2 elements, but got ", + kernel_size.size()); + + TORCH_CHECK( + dilation.size() == 2, + "Expected dilation to have 2 elements, but got ", + dilation.size()); + + TORCH_CHECK( + padding.size() == 2, + "Expected padding to have 2 elements, but got ", + padding.size()); + + TORCH_CHECK( + stride.size() == 2, + "Expected stride to have 2 elements, but got ", + stride.size()); + + int64_t output_height = output_size[0]; + int64_t output_width = output_size[1]; + int64_t kernel_height = kernel_size[0]; + int64_t kernel_width = kernel_size[1]; + int64_t dilation_height = dilation[0]; + int64_t dilation_width = dilation[1]; + int64_t pad_height = padding[0]; + int64_t pad_width = padding[1]; + int64_t stride_height = stride[0]; + int64_t stride_width = stride[1]; + + col2im_shape_check( + input_, + Tensor(), + output_height, + output_width, + kernel_height, + kernel_width, + dilation_height, + dilation_width, + pad_height, + pad_width, + stride_height, + stride_width); + + Tensor input = input_.contiguous(); + + bool batched_input = true; + if (input.dim() == 2) { + // Force batch + batched_input = false; + input.resize_({1, input.size(0), input.size(1)}); + } + + int64_t batch_size = input.size(0); + int64_t n_input_plane = input.size(1); + int64_t n_output_plane = n_input_plane / (kernel_width * kernel_height); + + output.resize_({batch_size, n_output_plane, output_height, output_width}); + output.zero_(); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + input.scalar_type(), "col2im_out_cpu", [&] { + Tensor input_n = Tensor(); + Tensor output_n = Tensor(); + + int64_t height_col = (output_height + 2 * pad_height - + (dilation_height * (kernel_height - 1) + 1)) / + stride_height + + 1; + int64_t width_col = (output_width + 2 * pad_width - + (dilation_width * (kernel_width - 1) + 1)) / + stride_width + + 1; + + for (int64_t elt = 0; elt < batch_size; elt++) { + input_n = input.select(0, elt); + output_n = output.select(0, elt); + + col2im<scalar_t>( + input_n.data<scalar_t>(), + n_output_plane, + output_height, + output_width, + height_col, + width_col, + kernel_height, + kernel_width, + pad_height, + pad_width, + stride_height, + stride_width, + dilation_height, + dilation_width, + output_n.data<scalar_t>()); + } + + if (!batched_input) { + output.resize_({n_output_plane, output_height, output_width}); + } + }); +} + +void col2im_backward_out_cpu_template( + Tensor& grad_input, + const Tensor& grad_output, + IntArrayRef kernel_size, + IntArrayRef dilation, + IntArrayRef padding, + IntArrayRef stride) { + // im2col_out_cpu checks the sizes of kernel_size, dilation, padding and stride + im2col_out_cpu( + grad_input, grad_output, kernel_size, dilation, padding, stride); +} + +} // namespace + +Tensor& col2im_out_cpu( + Tensor& output, + const Tensor& input, + IntArrayRef output_size, + IntArrayRef kernel_size, + IntArrayRef dilation, + IntArrayRef padding, + IntArrayRef stride) { + col2im_out_cpu_template( + output, input, output_size, kernel_size, dilation, padding, stride); + return output; +} + +Tensor col2im_cpu( + const Tensor& input, + 
IntArrayRef output_size, + IntArrayRef kernel_size, + IntArrayRef dilation, + IntArrayRef padding, + IntArrayRef stride) { + Tensor output = at::empty_like(input); + + col2im_out_cpu_template( + output, input, output_size, kernel_size, dilation, padding, stride); + return output; +} + +Tensor& col2im_backward_out_cpu( + Tensor& grad_input, + const Tensor& grad_output, + IntArrayRef kernel_size, + IntArrayRef dilation, + IntArrayRef padding, + IntArrayRef stride) { + col2im_backward_out_cpu_template( + grad_input, grad_output, kernel_size, dilation, padding, stride); + return grad_input; +} + +Tensor col2im_backward_cpu( + const Tensor& grad_output, + IntArrayRef kernel_size, + IntArrayRef dilation, + IntArrayRef padding, + IntArrayRef stride) { + Tensor grad_input = at::empty_like(grad_output); + + col2im_backward_out_cpu_template( + grad_input, grad_output, kernel_size, dilation, padding, stride); + return grad_input; +} + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/Im2Col.cpp b/aten/src/ATen/native/Im2Col.cpp new file mode 100644 index 0000000000000..c660c48708d78 --- /dev/null +++ b/aten/src/ATen/native/Im2Col.cpp @@ -0,0 +1,213 @@ +#include <ATen/ATen.h> +#include <ATen/NativeFunctions.h> +#include <ATen/TensorUtils.h> +#include <ATen/Utils.h> +#include <ATen/div_rtn.h> + +#include <ATen/native/im2col.h> +#include <ATen/native/im2col_shape_check.h> + +namespace at { +namespace native { +namespace { + +static void im2col_out_cpu_template( + Tensor& output, + const Tensor& input_, + IntArrayRef kernel_size, + IntArrayRef dilation, + IntArrayRef padding, + IntArrayRef stride) { + TORCH_CHECK( + kernel_size.size() == 2, + "Expected kernel_size to have 2 elements, but got ", + kernel_size.size()); + + TORCH_CHECK( + dilation.size() == 2, + "Expected dilation to have 2 elements, but got ", + dilation.size()); + + TORCH_CHECK( + padding.size() == 2, + "Expected padding to have 2 elements, but got ", + padding.size()); + + TORCH_CHECK( + stride.size() == 2, + "Expected stride to have 2 elements, but got ", + stride.size()); + + int64_t kernel_height = kernel_size[0]; + int64_t kernel_width = kernel_size[1]; + int64_t dilation_height = dilation[0]; + int64_t dilation_width = dilation[1]; + int64_t pad_height = padding[0]; + int64_t pad_width = padding[1]; + int64_t stride_height = stride[0]; + int64_t stride_width = stride[1]; + + im2col_shape_check( + input_, + Tensor(), + kernel_height, + kernel_width, + dilation_height, + dilation_width, + pad_height, + pad_width, + stride_height, + stride_width); + + Tensor input = input_.contiguous(); + + bool batched_input = true; + + if (input.dim() == 3) { + batched_input = false; + input.resize_({1, input.size(0), input.size(1), input.size(2)}); + } + + int64_t batch_size = input.size(0); + int64_t n_input_plane = input.size(1); + int64_t input_height = input.size(2); + int64_t input_width = input.size(3); + + int64_t output_height = (input_height + 2 * pad_height - + (dilation_height * (kernel_height - 1) + 1)) / + stride_height + + 1; + int64_t output_width = (input_width + 2 * pad_width - + (dilation_width * (kernel_width - 1) + 1)) / + stride_width + + 1; + int64_t n_output_plane = n_input_plane * kernel_width * kernel_height; + int64_t output_length = output_height * output_width; + + output.resize_({batch_size, n_output_plane, output_length}); + output.zero_(); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + input.scalar_type(), "im2col_out_cpu", [&] { + Tensor input_n; + Tensor output_n; + + for (int64_t elt = 0; elt < batch_size; elt++) { + input_n = input.select(0, elt); + output_n = output.select(0, elt); + + im2col<scalar_t>( + input_n.data<scalar_t>(), + 
n_input_plane, + input_height, + input_width, + output_height, + output_width, + kernel_height, + kernel_width, + pad_height, + pad_width, + stride_height, + stride_width, + dilation_height, + dilation_width, + output_n.data<scalar_t>()); + } + + if (!batched_input) { + output.resize_({n_output_plane, output_length}); + } + }); +} + +static void im2col_backward_out_cpu_template( + Tensor& grad_input, + const Tensor& grad_output, + IntArrayRef input_size, + IntArrayRef kernel_size, + IntArrayRef dilation, + IntArrayRef padding, + IntArrayRef stride) { + TORCH_CHECK( + input_size.size() == 2, + "Expected input_size to have 2 elements, but got ", + input_size.size()); + // col2im_out_cpu checks the sizes of kernel_size, dilation, padding and stride + col2im_out_cpu( + grad_input, + grad_output, + input_size, + kernel_size, + dilation, + padding, + stride); +} + +} // namespace + +Tensor& im2col_out_cpu( + Tensor& output, + const Tensor& input, + IntArrayRef kernel_size, + IntArrayRef dilation, + IntArrayRef padding, + IntArrayRef stride) { + im2col_out_cpu_template( + output, input, kernel_size, dilation, padding, stride); + return output; +} + +Tensor im2col_cpu( + const Tensor& input, + IntArrayRef kernel_size, + IntArrayRef dilation, + IntArrayRef padding, + IntArrayRef stride) { + Tensor output = at::empty_like(input); + + im2col_out_cpu_template( + output, input, kernel_size, dilation, padding, stride); + return output; +} + +Tensor& im2col_backward_out_cpu( + Tensor& grad_input, + const Tensor& grad_output, + IntArrayRef input_size, + IntArrayRef kernel_size, + IntArrayRef dilation, + IntArrayRef padding, + IntArrayRef stride) { + im2col_backward_out_cpu_template( + grad_input, + grad_output, + input_size, + kernel_size, + dilation, + padding, + stride); + return grad_input; +} + +Tensor im2col_backward_cpu( + const Tensor& grad_output, + IntArrayRef input_size, + IntArrayRef kernel_size, + IntArrayRef dilation, + IntArrayRef padding, + IntArrayRef stride) { + Tensor grad_input = at::empty_like(grad_output); + + im2col_backward_out_cpu_template( + grad_input, + grad_output, + input_size, + kernel_size, + dilation, + padding, + stride); + return grad_input; +} + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/cuda/Col2Im.cu b/aten/src/ATen/native/cuda/Col2Im.cu new file mode 100644 index 0000000000000..92985e375c8f6 --- /dev/null +++ b/aten/src/ATen/native/cuda/Col2Im.cu @@ -0,0 +1,207 @@ +#include <ATen/ATen.h> +#include <ATen/AccumulateType.h> +#include <ATen/NativeFunctions.h> +#include <ATen/TensorUtils.h> +#include <ATen/Utils.h> +#include <ATen/div_rtn.h> + +#include <ATen/cuda/CUDAContext.h> +#include <ATen/cuda/CUDAApplyUtils.cuh> + +#include <ATen/native/cuda/im2col.cuh> +#include <ATen/native/im2col_shape_check.h> + +namespace at { +namespace native { +namespace { + +void col2im_out_cuda_template( + Tensor& output, + const Tensor& input_, + IntArrayRef output_size, + IntArrayRef kernel_size, + IntArrayRef dilation, + IntArrayRef padding, + IntArrayRef stride) { + TensorArg input_arg{input_, "input", 1}; + TensorArg output_arg{output, "output", 2}; + checkAllSameGPU("col2im_out_cuda", {input_arg, output_arg}); + + TORCH_CHECK( + output_size.size() == 2, + "Expected output_size to have 2 elements, but got ", + output_size.size()); + + TORCH_CHECK( + kernel_size.size() == 2, + "Expected kernel_size to have 2 elements, but got ", + kernel_size.size()); + + TORCH_CHECK( + dilation.size() == 2, + "Expected dilation to have 2 elements, but got ", + dilation.size()); + + TORCH_CHECK( + padding.size() == 2, + "Expected padding to have 2 elements, but got ", + padding.size()); + + TORCH_CHECK( + stride.size() == 2, + "Expected stride to have 2 elements, but got ", + stride.size()); + 
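+ // As in the CPU path above, col2im ("fold") takes a batch of columns of + // shape (N, C * kernel_height * kernel_width, L) and scatter-adds each + // column back into an image of shape (N, C, output_height, output_width); + // values coming from overlapping patches are summed, which is why the + // output is zeroed before accumulation. + 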
+ int64_t output_height = output_size[0]; + int64_t output_width = output_size[1]; + int64_t kernel_height = kernel_size[0]; + int64_t kernel_width = kernel_size[1]; + int64_t dilation_height = dilation[0]; + int64_t dilation_width = dilation[1]; + int64_t pad_height = padding[0]; + int64_t pad_width = padding[1]; + int64_t stride_height = stride[0]; + int64_t stride_width = stride[1]; + + col2im_shape_check( + input_, + Tensor(), + output_height, + output_width, + kernel_height, + kernel_width, + dilation_height, + dilation_width, + pad_height, + pad_width, + stride_height, + stride_width); + + Tensor input = input_.contiguous(); + + bool batched_input = true; + if (input.dim() == 2) { + // Force batch + batched_input = false; + input.resize_({1, input.size(0), input.size(1)}); + } + + int64_t batch_size = input.size(0); + int64_t n_input_plane = input.size(1); + int64_t n_output_plane = n_input_plane / (kernel_width * kernel_height); + + output.resize_({batch_size, n_output_plane, output_height, output_width}); + output.zero_(); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "col2im_out_cuda", [&] { + using accscalar_t = at::acc_type<scalar_t, true>; + + Tensor input_n; + Tensor output_n; + + int64_t height_col = (output_height + 2 * pad_height - + (dilation_height * (kernel_height - 1) + 1)) / + stride_height + + 1; + int64_t width_col = (output_width + 2 * pad_width - + (dilation_width * (kernel_width - 1) + 1)) / + stride_width + + 1; + + for (int64_t elt = 0; elt < batch_size; elt++) { + input_n = input.select(0, elt); + output_n = output.select(0, elt); + + col2im<scalar_t, accscalar_t>( + at::cuda::getCurrentCUDAStream(), + input_n.data<scalar_t>(), + n_output_plane, + output_height, + output_width, + height_col, + width_col, + kernel_height, + kernel_width, + pad_height, + pad_width, + stride_height, + stride_width, + dilation_height, + dilation_width, + output_n.data<scalar_t>()); + } + + if (!batched_input) { + output.resize_({n_output_plane, output_height, output_width}); + } + }); +} + +void col2im_backward_out_cuda_template( + Tensor& grad_input, + const Tensor& grad_output, + IntArrayRef kernel_size, + IntArrayRef dilation, + IntArrayRef padding, + IntArrayRef stride) { + // im2col_out_cuda checks the sizes of kernel_size, dilation, padding and stride + im2col_out_cuda( + grad_input, grad_output, kernel_size, dilation, padding, stride); +} + +} // namespace + +Tensor& col2im_out_cuda( + Tensor& output, + const Tensor& input, + IntArrayRef output_size, + IntArrayRef kernel_size, + IntArrayRef dilation, + IntArrayRef padding, + IntArrayRef stride) { + col2im_out_cuda_template( + output, input, output_size, kernel_size, dilation, padding, stride); + return output; +} + +Tensor col2im_cuda( + const Tensor& input, + IntArrayRef output_size, + IntArrayRef kernel_size, + IntArrayRef dilation, + IntArrayRef padding, + IntArrayRef stride) { + Tensor output = at::empty_like(input); + + col2im_out_cuda_template( + output, input, output_size, kernel_size, dilation, padding, stride); + return output; +} + +Tensor& col2im_backward_out_cuda( + Tensor& grad_input, + const Tensor& grad_output, + IntArrayRef kernel_size, + IntArrayRef dilation, + IntArrayRef padding, + IntArrayRef stride) { + col2im_backward_out_cuda_template( + grad_input, grad_output, kernel_size, dilation, padding, stride); + return grad_input; +} + +Tensor col2im_backward_cuda( + const Tensor& grad_output, + IntArrayRef kernel_size, + IntArrayRef dilation, + IntArrayRef padding, + IntArrayRef stride) { + Tensor grad_input = at::empty_like(grad_output); + + 
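+ // The backward of col2im is im2col with identical parameters: the template + // above simply forwards to im2col_out_cuda, so no extra shape bookkeeping + // is needed here. + 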
col2im_backward_out_cuda_template( + grad_input, grad_output, kernel_size, dilation, padding, stride); + return grad_input; +} + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/cuda/Im2Col.cu b/aten/src/ATen/native/cuda/Im2Col.cu new file mode 100644 index 0000000000000..176e82333f8d9 --- /dev/null +++ b/aten/src/ATen/native/cuda/Im2Col.cu @@ -0,0 +1,220 @@ +#include <ATen/ATen.h> +#include <ATen/AccumulateType.h> +#include <ATen/NativeFunctions.h> +#include <ATen/TensorUtils.h> +#include <ATen/Utils.h> +#include <ATen/div_rtn.h> + +#include <ATen/cuda/CUDAContext.h> +#include <ATen/cuda/CUDAApplyUtils.cuh> + +#include <ATen/native/cuda/im2col.cuh> +#include <ATen/native/im2col_shape_check.h> + +namespace at { +namespace native { +namespace { + +static void im2col_out_cuda_template( + Tensor& output, + const Tensor& input_, + IntArrayRef kernel_size, + IntArrayRef dilation, + IntArrayRef padding, + IntArrayRef stride) { + TORCH_CHECK( + kernel_size.size() == 2, + "Expected kernel_size to have 2 elements, but got ", + kernel_size.size()); + + TORCH_CHECK( + dilation.size() == 2, + "Expected dilation to have 2 elements, but got ", + dilation.size()); + + TORCH_CHECK( + padding.size() == 2, + "Expected padding to have 2 elements, but got ", + padding.size()); + + TORCH_CHECK( + stride.size() == 2, + "Expected stride to have 2 elements, but got ", + stride.size()); + + int64_t kernel_height = kernel_size[0]; + int64_t kernel_width = kernel_size[1]; + int64_t dilation_height = dilation[0]; + int64_t dilation_width = dilation[1]; + int64_t pad_height = padding[0]; + int64_t pad_width = padding[1]; + int64_t stride_height = stride[0]; + int64_t stride_width = stride[1]; + + TensorArg input_arg{input_, "input", 1}; + TensorArg output_arg{output, "output", 2}; + checkAllSameGPU("im2col_cuda", {input_arg, output_arg}); + + im2col_shape_check( + input_, + Tensor(), + kernel_height, + kernel_width, + dilation_height, + dilation_width, + pad_height, + pad_width, + stride_height, + stride_width); + + Tensor input = input_.contiguous(); + + bool batched_input = true; + + if (input.dim() == 3) { + batched_input = false; + input.resize_({1, input.size(0), input.size(1), input.size(2)}); + } + + int64_t batch_size = input.size(0); + int64_t n_input_plane = input.size(1); + int64_t input_height = input.size(2); + int64_t input_width = input.size(3); + + int64_t output_height = (input_height + 2 * pad_height - + (dilation_height * (kernel_height - 1) + 1)) / + stride_height + + 1; + int64_t output_width = (input_width + 2 * pad_width - + (dilation_width * (kernel_width - 1) + 1)) / + stride_width + + 1; + int64_t n_output_plane = n_input_plane * kernel_width * kernel_height; + int64_t output_length = output_height * output_width; + + output.resize_({batch_size, n_output_plane, output_length}); + output.zero_(); + + // Launch kernel + AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "im2col_out_cuda", [&] { + Tensor input_n; + Tensor output_n; + + for (int64_t elt = 0; elt < batch_size; elt++) { + input_n = input.select(0, elt); + output_n = output.select(0, elt); + + im2col<scalar_t>( + at::cuda::getCurrentCUDAStream(), + input_n.data<scalar_t>(), + n_input_plane, + input_height, + input_width, + output_height, + output_width, + kernel_height, + kernel_width, + pad_height, + pad_width, + stride_height, + stride_width, + dilation_height, + dilation_width, + output_n.data<scalar_t>()); + } + + if (!batched_input) { + output.resize_({n_output_plane, output_length}); + } + }); +} + +static void im2col_backward_out_cuda_template( + Tensor& grad_input, + const Tensor& grad_output, + IntArrayRef input_size, + IntArrayRef kernel_size, + IntArrayRef dilation, + IntArrayRef padding, + IntArrayRef stride) { + TORCH_CHECK( + 
input_size.size() == 2, + "Expected input_size to have 2 elements, but got ", + input_size.size()); + // col2im_out_cuda checks the sizes of kernel_size, dilation, padding and stride + col2im_out_cuda( + grad_input, + grad_output, + input_size, + kernel_size, + dilation, + padding, + stride); +} + +} // namespace + +Tensor& im2col_out_cuda( + Tensor& output, + const Tensor& input, + IntArrayRef kernel_size, + IntArrayRef dilation, + IntArrayRef padding, + IntArrayRef stride) { + im2col_out_cuda_template( + output, input, kernel_size, dilation, padding, stride); + return output; +} + +Tensor im2col_cuda( + const Tensor& input, + IntArrayRef kernel_size, + IntArrayRef dilation, + IntArrayRef padding, + IntArrayRef stride) { + Tensor output = at::empty_like(input); + im2col_out_cuda_template( + output, input, kernel_size, dilation, padding, stride); + return output; +} + +Tensor& im2col_backward_out_cuda( + Tensor& grad_input, + const Tensor& grad_output, + IntArrayRef input_size, + IntArrayRef kernel_size, + IntArrayRef dilation, + IntArrayRef padding, + IntArrayRef stride) { + im2col_backward_out_cuda_template( + grad_input, + grad_output, + input_size, + kernel_size, + dilation, + padding, + stride); + return grad_input; +} + +Tensor im2col_backward_cuda( + const Tensor& grad_output, + IntArrayRef input_size, + IntArrayRef kernel_size, + IntArrayRef dilation, + IntArrayRef padding, + IntArrayRef stride) { + Tensor grad_input = at::empty_like(grad_output); + im2col_backward_out_cuda_template( + grad_input, + grad_output, + input_size, + kernel_size, + dilation, + padding, + stride); + return grad_input; +} + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/cuda/im2col.cuh b/aten/src/ATen/native/cuda/im2col.cuh new file mode 100644 index 0000000000000..13c3de90bf6a1 --- /dev/null +++ b/aten/src/ATen/native/cuda/im2col.cuh @@ -0,0 +1,215 @@ +#pragma once + +#include <ATen/ATen.h> +#include <ATen/TensorUtils.h> + +#include <ATen/cuda/CUDAContext.h> +#include <ATen/cuda/CUDAApplyUtils.cuh> +#include <ATen/cuda/detail/KernelUtils.h> + +#include <THC/THCGeneral.h> +#include <THC/THCNumerics.cuh> +#include <THC/THCDeviceUtils.cuh> +#include <THC/THCAtomics.cuh> +#include <THC/THCTensorMathReduce.cuh> + +#include <c10/macros/Macros.h> + +namespace at { +namespace native { + +using namespace at::cuda::detail; + +// Kernel for fast unfold+copy +// (borrowed from Caffe: +// https://github.com/BVLC/caffe/blob/master/src/caffe/layers/conv_layer.cu) +// CUDA_NUM_THREADS = 1024 + +template <typename dt> +C10_LAUNCH_BOUNDS_1(1024) +__global__ void im2col_kernel( + const int64_t n, + const dt* data_im, + const int64_t height, + const int64_t width, + const int64_t kernel_height, + const int64_t kernel_width, + const int64_t pad_height, + const int64_t pad_width, + const int64_t stride_height, + const int64_t stride_width, + const int64_t dilation_height, + const int64_t dilation_width, + const int64_t height_col, + const int64_t width_col, + dt* data_col) { + CUDA_KERNEL_LOOP(index, n) { + int64_t w_out = index % width_col; + + index /= width_col; + + int64_t h_out = index % height_col; + int64_t channel_in = index / height_col; + int64_t channel_out = channel_in * kernel_height * kernel_width; + int64_t h_in = h_out * stride_height - pad_height; + int64_t w_in = w_out * stride_width - pad_width; + + data_col += (channel_out * height_col + h_out) * width_col + w_out; + data_im += (channel_in * height + h_in) * width + w_in; + + for (int64_t i = 0; i < kernel_height; ++i) { + for (int64_t j = 0; j < kernel_width; ++j) { + int64_t h = h_in + i * dilation_height; + int64_t w = w_in + j * dilation_width; + *data_col = (h >= 0 && w >= 0 && h < height && w < width) + ? 
data_im[i * dilation_height * width + j * dilation_width] + : ScalarConvert<int, dt>::to(0); + data_col += height_col * width_col; + } + } + } +} + +template <typename dt> +void im2col( + cudaStream_t stream, + const dt* data_im, + const int64_t channels, + const int64_t height, + const int64_t width, + const int64_t height_col, + const int64_t width_col, + const int64_t kernel_height, + const int64_t kernel_width, + const int64_t pad_height, + const int64_t pad_width, + const int64_t stride_height, + const int64_t stride_width, + const int64_t dilation_height, + const int64_t dilation_width, + dt* data_col) { + // We are going to launch channels * height_col * width_col kernels, each + // kernel responsible for copying a single-channel grid. + int64_t num_kernels = channels * height_col * width_col; + // Launch CUDA_NUM_THREADS = 1024 + im2col_kernel<<<GET_BLOCKS(num_kernels), 1024, 0, stream>>>( + num_kernels, + data_im, + height, + width, + kernel_height, + kernel_width, + pad_height, + pad_width, + stride_height, + stride_width, + dilation_height, + dilation_width, + height_col, + width_col, + data_col); + AT_CUDA_CHECK(cudaGetLastError()); +} + +template <typename dt, typename accT> +C10_LAUNCH_BOUNDS_1(1024) +__global__ void col2im_kernel( + const int64_t n, + const dt* data_col, + const int64_t height, + const int64_t width, + const int64_t channels, + const int64_t kernel_h, + const int64_t kernel_w, + const int64_t pad_height, + const int64_t pad_width, + const int64_t stride_height, + const int64_t stride_width, + const int64_t dilation_height, + const int64_t dilation_width, + const int64_t height_col, + const int64_t width_col, + dt* data_im) { + CUDA_KERNEL_LOOP(index, n) { + accT val = static_cast<accT>(0); + const int64_t w_im = index % width + pad_width; + const int64_t h_im = (index / width) % height + pad_height; + const int64_t c_im = index / (width * height); + int64_t kernel_extent_w = (kernel_w - 1) * dilation_width + 1; + int64_t kernel_extent_h = (kernel_h - 1) * dilation_height + 1; + // compute the start and end of the output + const int64_t w_col_start = (w_im < kernel_extent_w) + ? 0 + : (w_im - kernel_extent_w) / stride_width + 1; + const int64_t w_col_end = ::min(w_im / stride_width + 1, width_col); + const int64_t h_col_start = (h_im < kernel_extent_h) + ? 0 + : (h_im - kernel_extent_h) / stride_height + 1; + const int64_t h_col_end = ::min(h_im / stride_height + 1, height_col); + + // TODO: use LCM of stride and dilation to avoid unnecessary loops + for (int64_t h_col = h_col_start; h_col < h_col_end; h_col += 1) { + for (int64_t w_col = w_col_start; w_col < w_col_end; w_col += 1) { + int64_t h_k = (h_im - h_col * stride_height); + int64_t w_k = (w_im - w_col * stride_width); + if (h_k % dilation_height == 0 && w_k % dilation_width == 0) { + h_k /= dilation_height; + w_k /= dilation_width; + int64_t data_col_index = + (((c_im * kernel_h + h_k) * kernel_w + w_k) * height_col + + h_col) * + width_col + + w_col; + val += data_col[data_col_index]; + } + } + } + data_im[index] = static_cast
<dt>(val); + } +} + +template <typename dt, typename accT> +void col2im( + cudaStream_t stream, + const dt* data_col, + const int64_t channels, + const int64_t height, + const int64_t width, + const int64_t output_height, + const int64_t output_width, + const int64_t patch_height, + const int64_t patch_width, + const int64_t pad_height, + const int64_t pad_width, + const int64_t stride_height, + const int64_t stride_width, + const int64_t dilation_height, + const int64_t dilation_width, + dt* data_im) { + int64_t num_kernels = channels * height * width; + // To avoid involving atomic operations, we will launch one kernel per + // bottom dimension, and then in the kernel add up the top dimensions. + // CUDA_NUM_THREADS = 1024 + col2im_kernel<dt, accT> + <<<GET_BLOCKS(num_kernels), 1024, 0, stream>>>( + num_kernels, + data_col, + height, + width, + channels, + patch_height, + patch_width, + pad_height, + pad_width, + stride_height, + stride_width, + dilation_height, + dilation_width, + output_height, + output_width, + data_im); + AT_CUDA_CHECK(cudaGetLastError()); +} + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/cuda/vol2col.cuh b/aten/src/ATen/native/cuda/vol2col.cuh new file mode 100644 index 0000000000000..19d9cf0ed5051 --- /dev/null +++ b/aten/src/ATen/native/cuda/vol2col.cuh @@ -0,0 +1,262 @@ +#pragma once + +#include <ATen/ATen.h> +#include <ATen/TensorUtils.h> + +#include <ATen/cuda/CUDAContext.h> +#include <ATen/cuda/CUDAApplyUtils.cuh> +#include <ATen/cuda/detail/KernelUtils.h> + +#include <THC/THCGeneral.h> +#include <THC/THCNumerics.cuh> +#include <THC/THCDeviceUtils.cuh> +#include <THC/THCAtomics.cuh> +#include <THC/THCTensorMathReduce.cuh> + +#include <c10/macros/Macros.h> + +namespace at { +namespace native { + +using namespace at::cuda::detail; + +// Kernel for fast unfold+copy on volumes +template <typename T> +__global__ void vol2col_kernel( + const int n, + const T* data_vol, + const int depth, + const int height, + const int width, + const int ksize_t, + const int ksize_h, + const int ksize_w, + const int pad_t, + const int pad_h, + const int pad_w, + const int stride_t, + const int stride_h, + const int stride_w, + const int dilation_t, + const int dilation_h, + const int dilation_w, + const int depth_col, + const int height_col, + const int width_col, + T* data_col) { + CUDA_KERNEL_LOOP(index, n) { + int w_out = index % width_col; + index /= width_col; + int h_out = index % height_col; + index /= height_col; + int t_out = index % depth_col; + int channel_in = index / depth_col; + int channel_out = channel_in * ksize_t * ksize_h * ksize_w; + int t_in = t_out * stride_t - pad_t; + int h_in = h_out * stride_h - pad_h; + int w_in = w_out * stride_w - pad_w; + data_col += + ((channel_out * depth_col + t_out) * height_col + h_out) * width_col + + w_out; + data_vol += ((channel_in * depth + t_in) * height + h_in) * width + w_in; + for (int i = 0; i < ksize_t; ++i) { + for (int j = 0; j < ksize_h; ++j) { + for (int k = 0; k < ksize_w; ++k) { + int t = t_in + i * dilation_t; + int h = h_in + j * dilation_h; + int w = w_in + k * dilation_w; + *data_col = (t >= 0 && h >= 0 && w >= 0 && t < depth && h < height && + w < width) + ? 
data_vol + [i * dilation_t * height * width + j * dilation_h * width + + k * dilation_w] + : static_cast<T>(0); + data_col += depth_col * height_col * width_col; + } + } + } + } +} + +template <typename T> +void vol2col( + cudaStream_t stream, + const T* data_vol, + const int channels, + const int depth, + const int height, + const int width, + const int depth_col, + const int height_col, + const int width_col, + const int ksize_t, + const int ksize_h, + const int ksize_w, + const int pad_t, + const int pad_h, + const int pad_w, + const int stride_t, + const int stride_h, + const int stride_w, + const int dilation_t, + const int dilation_h, + const int dilation_w, + T* data_col) { + // We are going to launch channels * depth_col * height_col * width_col + // kernels, each kernel responsible for copying a single-channel grid. + int num_kernels = channels * depth_col * height_col * width_col; + // Launch + vol2col_kernel<<<GET_BLOCKS(num_kernels), 1024, 0, stream>>>( + num_kernels, + data_vol, + depth, + height, + width, + ksize_t, + ksize_h, + ksize_w, + pad_t, + pad_h, + pad_w, + stride_t, + stride_h, + stride_w, + dilation_t, + dilation_h, + dilation_w, + depth_col, + height_col, + width_col, + data_col); + AT_CUDA_CHECK(cudaGetLastError()); +} + +template <typename T, typename accT> +__global__ void vol2im_kernel( + const int n, + const T* data_col, + const int depth, + const int height, + const int width, + const int channels, + const int kernel_t, + const int kernel_h, + const int kernel_w, + const int pad_t, + const int pad_h, + const int pad_w, + const int stride_t, + const int stride_h, + const int stride_w, + const int dilation_t, + const int dilation_h, + const int dilation_w, + const int depth_col, + const int height_col, + const int width_col, + T* data_vol) { + CUDA_KERNEL_LOOP(index, n) { + accT val = static_cast<accT>(0); + const int w_im = index % width + pad_w; + const int h_im = (index / width) % height + pad_h; + const int t_im = (index / width / height) % depth + pad_t; + const int c_im = index / (width * height * depth); + int kernel_extent_w = (kernel_w - 1) * dilation_w + 1; + int kernel_extent_h = (kernel_h - 1) * dilation_h + 1; + int kernel_extent_t = (kernel_t - 1) * dilation_t + 1; + // compute the start and end of the output + const int w_col_start = + (w_im < kernel_extent_w) ? 0 : (w_im - kernel_extent_w) / stride_w + 1; + const int w_col_end = ::min(w_im / stride_w + 1, width_col); + const int h_col_start = + (h_im < kernel_extent_h) ? 0 : (h_im - kernel_extent_h) / stride_h + 1; + const int h_col_end = ::min(h_im / stride_h + 1, height_col); + const int t_col_start = + (t_im < kernel_extent_t) ? 
0 : (t_im - kernel_extent_t) / stride_t + 1; + const int t_col_end = ::min(t_im / stride_t + 1, depth_col); + // TODO: use LCM of stride and dilation to avoid unnecessary loops + for (int t_col = t_col_start; t_col < t_col_end; t_col += 1) { + for (int h_col = h_col_start; h_col < h_col_end; h_col += 1) { + for (int w_col = w_col_start; w_col < w_col_end; w_col += 1) { + int t_k = (t_im - t_col * stride_t); + int h_k = (h_im - h_col * stride_h); + int w_k = (w_im - w_col * stride_w); + if (t_k % dilation_t == 0 && h_k % dilation_h == 0 && + w_k % dilation_w == 0) { + t_k /= dilation_t; + h_k /= dilation_h; + w_k /= dilation_w; + int data_col_index = + (((((c_im * kernel_t + t_k) * kernel_h + h_k) * kernel_w + + w_k) * + depth_col + + t_col) * + height_col + + h_col) * + width_col + + w_col; + val += data_col[data_col_index]; + } + } + } + } + data_vol[index] = static_cast<T>(val); + } +} + +template <typename T, typename accT> +void col2vol( + cudaStream_t stream, + const T* data_col, + const int channels, + const int depth, + const int height, + const int width, + const int output_depth, + const int output_height, + const int output_width, + const int patch_t, + const int patch_h, + const int patch_w, + const int pad_t, + const int pad_h, + const int pad_w, + const int stride_t, + const int stride_h, + const int stride_w, + const int dilation_t, + const int dilation_h, + const int dilation_w, + T* data_vol) { + int num_kernels = channels * depth * height * width; + // To avoid involving atomic operations, we will launch one kernel per + // bottom dimension, and then in the kernel add up the top dimensions. + vol2im_kernel<T, accT> + <<<GET_BLOCKS(num_kernels), 1024, 0, stream>>>( + num_kernels, + data_col, + depth, + height, + width, + channels, + patch_t, + patch_h, + patch_w, + pad_t, + pad_h, + pad_w, + stride_t, + stride_h, + stride_w, + dilation_t, + dilation_h, + dilation_w, + output_depth, + output_height, + output_width, + data_vol); + AT_CUDA_CHECK(cudaGetLastError()); +} + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/im2col.h b/aten/src/ATen/native/im2col.h new file mode 100644 index 0000000000000..df7eeffbc78b1 --- /dev/null +++ b/aten/src/ATen/native/im2col.h @@ -0,0 +1,94 @@ +#pragma once + +#include <ATen/ATen.h> +#include <ATen/TensorUtils.h> +#include <ATen/Utils.h> +#include <cstring> + +namespace at { +namespace native { + +template <typename T> +static void im2col( + const T* data_im, + const int64_t channels, + const int64_t height, + const int64_t width, + const int64_t output_height, + const int64_t output_width, + const int64_t kernel_h, + const int64_t kernel_w, + const int64_t pad_h, + const int64_t pad_w, + const int64_t stride_h, + const int64_t stride_w, + const int64_t dilation_h, + const int64_t dilation_w, + T* data_col) { + const int64_t height_col = output_height; + const int64_t width_col = output_width; + const int64_t channels_col = channels * kernel_h * kernel_w; + + for (int64_t c_col = 0; c_col < channels_col; ++c_col) { + int64_t w_offset = c_col % kernel_w; + int64_t h_offset = (c_col / kernel_w) % kernel_h; + int64_t c_im = c_col / kernel_h / kernel_w; + + for (int64_t h_col = 0; h_col < height_col; ++h_col) { + int64_t h_im = h_col * stride_h - pad_h + h_offset * dilation_h; + + for (int64_t w_col = 0; w_col < width_col; ++w_col) { + int64_t w_im = w_col * stride_w - pad_w + w_offset * dilation_w; + data_col[(c_col * height_col + h_col) * width_col + w_col] = + (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) + ? 
data_im[(c_im * height + h_im) * width + w_im] + : static_cast<T>(0); + } + } + } +} + +template <typename T> +static void col2im( + const T* data_col, + const int64_t channels, + const int64_t height, + const int64_t width, + const int64_t output_height, + const int64_t output_width, + const int64_t kernel_h, + const int64_t kernel_w, + const int64_t pad_h, + const int64_t pad_w, + const int64_t stride_h, + const int64_t stride_w, + const int64_t dilation_h, + const int64_t dilation_w, + T* data_im) { + memset(data_im, 0, sizeof(T) * height * width * channels); + + const int64_t height_col = output_height; + const int64_t width_col = output_width; + const int64_t channels_col = channels * kernel_h * kernel_w; + + for (int64_t c_col = 0; c_col < channels_col; ++c_col) { + int64_t w_offset = c_col % kernel_w; + int64_t h_offset = (c_col / kernel_w) % kernel_h; + int64_t c_im = c_col / kernel_h / kernel_w; + + for (int64_t h_col = 0; h_col < height_col; ++h_col) { + int64_t h_im = h_col * stride_h - pad_h + h_offset * dilation_h; + + for (int64_t w_col = 0; w_col < width_col; ++w_col) { + int64_t w_im = w_col * stride_w - pad_w + w_offset * dilation_w; + + if (h_im >= 0 && h_im < height && w_im >= 0 && w_im < width) + data_im[(c_im * height + h_im) * width + w_im] += + data_col[(c_col * height_col + h_col) * width_col + w_col]; + } + } + } +} + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/im2col_shape_check.h b/aten/src/ATen/native/im2col_shape_check.h new file mode 100644 index 0000000000000..ff20959e04856 --- /dev/null +++ b/aten/src/ATen/native/im2col_shape_check.h @@ -0,0 +1,210 @@ +#include <ATen/ATen.h> +#include <ATen/div_rtn.h> + +namespace at { +namespace native { + +static inline void col2im_shape_check( + const Tensor& input, + const Tensor& grad_output, + int64_t output_height, + int64_t output_width, + int64_t kernel_height, + int64_t kernel_width, + int64_t dilation_height, + int64_t dilation_width, + int64_t pad_height, + int64_t pad_width, + int64_t stride_height, + int64_t stride_width) { + TORCH_CHECK( + kernel_width > 0 && kernel_height > 0, + "kernel size should be greater than zero, but got kernel_height: ", + kernel_height, + " kernel_width: ", + kernel_width); + TORCH_CHECK( + stride_width > 0 && stride_height > 0, + "stride should be greater than zero, but got stride_height: ", + stride_height, + " stride_width: ", + stride_width); + TORCH_CHECK( + dilation_width > 0 && dilation_height > 0, + "dilation should be greater than zero, but got dilation_height: ", + dilation_height, + " dilation_width: ", + dilation_width); + + int64_t ndim = input.ndimension(); + TORCH_CHECK( + input.numel() != 0 && (ndim == 2 || ndim == 3), + "Expected non-empty 2D or 3D input tensor, but got input of size ", + input.sizes()); + + int64_t batch_dim = (ndim == 3) ? 
0 : -1; + int64_t n_input_plane = input.size(batch_dim + 1); + + if (n_input_plane % (kernel_width * kernel_height) != 0) { + AT_ERROR( + "Expected size of input's dimension 1 to be divisible by the " + "product of kernel_size, but got input.size(1)=", + n_input_plane, + " and kernel_size=(", + kernel_height, + ", ", + kernel_width, + ")."); + } + + int64_t input_length = input.size(batch_dim + 2); + int64_t n_blocks_height = + div_rtn<int64_t>( + output_height + 2 * pad_height - + dilation_height * (kernel_height - 1) - 1, + stride_height) + + 1; + int64_t n_blocks_width = div_rtn<int64_t>( + output_width + 2 * pad_width - + dilation_width * (kernel_width - 1) - 1, + stride_width) + + 1; + + if (input_length != (n_blocks_height * n_blocks_width)) { + AT_ERROR( + "Given output_size=(", + output_height, + ", ", + output_width, + "), kernel_size=(", + kernel_height, + ", ", + kernel_width, + "), dilation=(", + dilation_height, + ", ", + dilation_width, + "), padding=(", + pad_height, + ", ", + pad_width, + "), stride=(", + stride_height, + ", ", + stride_width, + "), expected size of input's dimension 2 to match the calculated number of ", + "sliding blocks ", + n_blocks_height, + " * ", + n_blocks_width, + " = ", + (n_blocks_height * n_blocks_width), + ", but got input.size(2)=", + input_length, + "."); + } + + if (output_width < 1 || output_height < 1) { + AT_ERROR( + "Expected output spatial size to be positive, but got: output_size=(", + output_height, + ", ", + output_width, + ")."); + } +} + +static inline void im2col_shape_check( + const Tensor& input, + const Tensor& grad_output, + int64_t kernel_height, + int64_t kernel_width, + int64_t dilation_height, + int64_t dilation_width, + int64_t pad_height, + int64_t pad_width, + int64_t stride_height, + int64_t stride_width) { + TORCH_CHECK( + kernel_width > 0 && kernel_height > 0, + "kernel size should be greater than zero, but got kernel_height: ", + kernel_height, + " kernel_width: ", + kernel_width); + + TORCH_CHECK( + dilation_width > 0 && dilation_height > 0, + "dilation should be greater than zero, but got dilation_height: ", + dilation_height, + " dilation_width: ", + dilation_width); + + TORCH_CHECK( + pad_width >= 0 && pad_height >= 0, + "padding should be non-negative, but got pad_height: ", + pad_height, + " pad_width: ", + pad_width); + + TORCH_CHECK( + stride_width > 0 && stride_height > 0, + "stride should be greater than zero, but got stride_height: ", + stride_height, + " stride_width: ", + stride_width); + + int64_t ndim = input.ndimension(); + + TORCH_CHECK( + input.numel() != 0 && (ndim == 3 || ndim == 4), + "Expected non-empty 3D or 4D input tensor, but got input of size ", + input.sizes()); + + int64_t dim_batch = 0; + + if (ndim == 3) { + dim_batch = -1; + } + + int64_t n_input_plane = input.size(dim_batch + 1); + int64_t input_height = input.size(dim_batch + 2); + int64_t input_width = input.size(dim_batch + 3); + int64_t output_height = div_rtn<int64_t>( + input_height + 2 * pad_height - + (dilation_height * (kernel_height - 1) + 1), + stride_height) + + 1; + int64_t output_width = div_rtn<int64_t>( + input_width + 2 * pad_width - + (dilation_width * (kernel_width - 1) + 1), + stride_width) + + 1; + + if (output_height < 1 || output_width < 1) { + AT_ERROR( + "Given input with spatial size (", + input_height, + ", ", + input_width, + "), kernel_size=(", + kernel_height, + ", ", + kernel_width, + "), dilation=(", + dilation_height, + ", ", + dilation_width, + "), padding=(", + pad_height, + ", ", + pad_width, + "), calculated shape of the array of 
sliding blocks as (", + output_height, + ", ", + output_width, + "), which is too small (non-positive)."); + } +} + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 753444a7d0ce3..3c96a2e6359c8 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -5421,26 +5421,50 @@ CPU: legacy::cpu::_thnn_conv_dilated3d_backward CUDA: legacy::cuda::_thnn_conv_dilated3d_backward -- func: thnn_col2im(Tensor self, int[2] output_size, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride) -> Tensor +- func: col2im(Tensor self, int[2] output_size, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride, *, Tensor(a!) out) -> Tensor(a!) python_module: nn dispatch: - CPU: legacy::cpu::_thnn_col2im_forward - CUDA: legacy::cuda::_thnn_col2im_forward + CPU: col2im_out_cpu + CUDA: col2im_out_cuda -- func: thnn_col2im_backward(Tensor grad_output, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride) -> Tensor +- func: col2im(Tensor self, int[2] output_size, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride) -> Tensor python_module: nn dispatch: - CPU: legacy::cpu::_thnn_col2im_backward - CUDA: legacy::cuda::_thnn_col2im_backward + CPU: col2im_cpu + CUDA: col2im_cuda -- func: thnn_im2col(Tensor self, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride) -> Tensor +- func: col2im_backward(Tensor grad_output, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride, *, Tensor(a!) grad_input) -> Tensor(a!) python_module: nn dispatch: - CPU: legacy::cpu::_thnn_im2col_forward - CUDA: legacy::cuda::_thnn_im2col_forward + CPU: col2im_backward_out_cpu + CUDA: col2im_backward_out_cuda -- func: thnn_im2col_backward(Tensor grad_output, int[2] input_size, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride) -> Tensor +- func: col2im_backward(Tensor grad_output, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride) -> Tensor python_module: nn dispatch: - CPU: legacy::cpu::_thnn_im2col_backward - CUDA: legacy::cuda::_thnn_im2col_backward + CPU: col2im_backward_cpu + CUDA: col2im_backward_cuda + +- func: im2col(Tensor self, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride, *, Tensor(a!) out) -> Tensor(a!) + python_module: nn + dispatch: + CPU: im2col_out_cpu + CUDA: im2col_out_cuda + +- func: im2col(Tensor self, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride) -> Tensor + python_module: nn + dispatch: + CPU: im2col_cpu + CUDA: im2col_cuda + +- func: im2col_backward(Tensor grad_output, int[2] input_size, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride, *, Tensor(a!) grad_input) -> Tensor(a!) 
+ python_module: nn + dispatch: + CPU: im2col_backward_out_cpu + CUDA: im2col_backward_out_cuda + +- func: im2col_backward(Tensor grad_output, int[2] input_size, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride) -> Tensor + python_module: nn + dispatch: + CPU: im2col_backward_cpu + CUDA: im2col_backward_cuda diff --git a/aten/src/ATen/native/vol2col.h b/aten/src/ATen/native/vol2col.h new file mode 100644 index 0000000000000..d4659ad6e1db0 --- /dev/null +++ b/aten/src/ATen/native/vol2col.h @@ -0,0 +1,115 @@ +#pragma once + +#include <ATen/ATen.h> +#include <ATen/TensorUtils.h> +#include <ATen/Utils.h> +#include <cstring> + +namespace at { +namespace native { + +template <typename T> +static void vol2col( + const T* data_vol, + const int64_t channels, + const int64_t depth, + const int64_t height, + const int64_t width, + const int64_t depth_col, + const int64_t height_col, + const int64_t width_col, + const int64_t kT, + const int64_t kernel_height, + const int64_t kernel_width, + const int64_t pT, + const int64_t pH, + const int64_t pW, + const int64_t dT, + const int64_t dH, + const int64_t dW, + const int64_t dilationT, + const int64_t dilationH, + const int64_t dilationW, + T* data_col) { + int64_t c, t, h, w; + int64_t channels_col = channels * kT * kernel_height * kernel_width; + for (c = 0; c < channels_col; ++c) { + int64_t w_offset = c % kernel_width; + int64_t h_offset = (c / kernel_width) % kernel_height; + int64_t t_offset = (c / kernel_width / kernel_height) % kT; + int64_t c_vol = c / kT / kernel_height / kernel_width; + for (t = 0; t < depth_col; ++t) { + int64_t t_pad = t * dT - pT + t_offset * dilationT; + for (h = 0; h < height_col; ++h) { + int64_t h_pad = h * dH - pH + h_offset * dilationH; + for (w = 0; w < width_col; ++w) { + int64_t w_pad = w * dW - pW + w_offset * dilationW; + if (t_pad >= 0 && t_pad < depth && h_pad >= 0 && h_pad < height && + w_pad >= 0 && w_pad < width) + data_col[((c * depth_col + t) * height_col + h) * width_col + w] = + data_vol + [((c_vol * depth + t_pad) * height + h_pad) * width + + w_pad]; + else + data_col[((c * depth_col + t) * height_col + h) * width_col + w] = + 0; + } + } + } + } +} + +template <typename T> +static void col2vol( + const T* data_col, + const int64_t channels, + const int64_t depth, + const int64_t height, + const int64_t width, + const int64_t out_depth, + const int64_t out_height, + const int64_t out_width, + const int64_t kT, + const int64_t kernel_height, + const int64_t kernel_width, + const int64_t pT, + const int64_t pH, + const int64_t pW, + const int64_t dT, + const int64_t dH, + const int64_t dW, + const int64_t dilationT, + const int64_t dilationH, + const int64_t dilationW, + T* data_vol) { + int64_t c, t, h, w; + memset(data_vol, 0, sizeof(T) * depth * height * width * channels); + int64_t depth_col = out_depth; + int64_t height_col = out_height; + int64_t width_col = out_width; + int64_t channels_col = channels * kT * kernel_height * kernel_width; + for (c = 0; c < channels_col; ++c) { + int64_t w_offset = c % kernel_width; + int64_t h_offset = (c / kernel_width) % kernel_height; + int64_t t_offset = (c / kernel_width / kernel_height) % kT; + int64_t c_vol = c / kT / kernel_height / kernel_width; + for (t = 0; t < depth_col; ++t) { + int64_t t_pad = t * dT - pT + t_offset * dilationT; + for (h = 0; h < height_col; ++h) { + int64_t h_pad = h * dH - pH + h_offset * dilationH; + for (w = 0; w < width_col; ++w) { + int64_t w_pad = w * dW - pW + w_offset * dilationW; + if (t_pad >= 0 && t_pad < depth && h_pad >= 0 && h_pad < height && + w_pad >= 0 && w_pad < width) + data_vol + 
[((c_vol * depth + t_pad) * height + h_pad) * width + w_pad] += + data_col + [((c * depth_col + t) * height_col + h) * width_col + w]; + } + } + } + } +} + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/nn.yaml b/aten/src/ATen/nn.yaml index 03bedf9c2e5b1..f99ca2b5d0d20 100644 --- a/aten/src/ATen/nn.yaml +++ b/aten/src/ATen/nn.yaml @@ -170,17 +170,3 @@ - name: _thnn_conv_dilated3d(Tensor self, Tensor weight, IntArrayRef[3] kernel_size, Tensor? bias={}, IntArrayRef[3] stride=1, IntArrayRef[3] padding=0, IntArrayRef[3] dilation=1) cname: VolumetricDilatedConvolution buffers: [columns, ones] - -# Fold and Unfold - -- name: _thnn_col2im(Tensor self, IntArrayRef[2] output_size, IntArrayRef[2] kernel_size, IntArrayRef[2] dilation, IntArrayRef[2] padding, IntArrayRef[2] stride) - cname: Col2Im - scalar_check: - output: 'false' - grad_input: 'false' - -- name: _thnn_im2col(Tensor self, IntArrayRef[2] kernel_size, IntArrayRef[2] dilation, IntArrayRef[2] padding, IntArrayRef[2] stride) - cname: Im2Col - scalar_check: - output: 'false' - grad_input: 'false' diff --git a/aten/src/THCUNN/CMakeLists.txt b/aten/src/THCUNN/CMakeLists.txt index a55a7db25b4a4..ecb5ee48fc86d 100644 --- a/aten/src/THCUNN/CMakeLists.txt +++ b/aten/src/THCUNN/CMakeLists.txt @@ -3,13 +3,11 @@ ${CMAKE_CURRENT_SOURCE_DIR}/AbsCriterion.cu ${CMAKE_CURRENT_SOURCE_DIR}/Abs.cu ${CMAKE_CURRENT_SOURCE_DIR}/BCECriterion.cu ${CMAKE_CURRENT_SOURCE_DIR}/ClassNLLCriterion.cu -${CMAKE_CURRENT_SOURCE_DIR}/Col2Im.cu ${CMAKE_CURRENT_SOURCE_DIR}/DistKLDivCriterion.cu ${CMAKE_CURRENT_SOURCE_DIR}/ELU.cu ${CMAKE_CURRENT_SOURCE_DIR}/FeatureLPPooling.cu ${CMAKE_CURRENT_SOURCE_DIR}/GatedLinearUnit.cu ${CMAKE_CURRENT_SOURCE_DIR}/HardTanh.cu -${CMAKE_CURRENT_SOURCE_DIR}/Im2Col.cu ${CMAKE_CURRENT_SOURCE_DIR}/IndexLinear.cu ${CMAKE_CURRENT_SOURCE_DIR}/L1Cost.cu ${CMAKE_CURRENT_SOURCE_DIR}/LeakyReLU.cu diff --git a/aten/src/THCUNN/Col2Im.cu b/aten/src/THCUNN/Col2Im.cu deleted file mode 100644 index 8e6a5de2da857..0000000000000 --- a/aten/src/THCUNN/Col2Im.cu +++ /dev/null @@ -1,11 +0,0 @@ -#include -#include -#include -#include -#include - -#include -#include - -#include -#include diff --git a/aten/src/THCUNN/Im2Col.cu b/aten/src/THCUNN/Im2Col.cu deleted file mode 100644 index 785106e02c003..0000000000000 --- a/aten/src/THCUNN/Im2Col.cu +++ /dev/null @@ -1,11 +0,0 @@ -#include -#include -#include - -#include -#include -#include -#include - -#include -#include diff --git a/aten/src/THCUNN/SpatialConvolutionLocal.cu b/aten/src/THCUNN/SpatialConvolutionLocal.cu index 31f3be1037e84..d92103e9c8fac 100644 --- a/aten/src/THCUNN/SpatialConvolutionLocal.cu +++ b/aten/src/THCUNN/SpatialConvolutionLocal.cu @@ -1,6 +1,7 @@ #include +#include #include -#include +#include #include #include diff --git a/aten/src/THCUNN/SpatialConvolutionMM.cu b/aten/src/THCUNN/SpatialConvolutionMM.cu index 86eaebb3ef29a..6b151ae14a55b 100644 --- a/aten/src/THCUNN/SpatialConvolutionMM.cu +++ b/aten/src/THCUNN/SpatialConvolutionMM.cu @@ -1,7 +1,7 @@ #include #include #include -#include +#include #include #include diff --git a/aten/src/THCUNN/SpatialDilatedConvolution.cu b/aten/src/THCUNN/SpatialDilatedConvolution.cu index e5d44d6f127ae..7a85e1f1264f3 100644 --- a/aten/src/THCUNN/SpatialDilatedConvolution.cu +++ b/aten/src/THCUNN/SpatialDilatedConvolution.cu @@ -1,6 +1,6 @@ #include #include -#include +#include #include #include diff --git a/aten/src/THCUNN/SpatialFullConvolution.cu b/aten/src/THCUNN/SpatialFullConvolution.cu index c46dc1223ae3f..1980d59abbb33 100644 --- 
a/aten/src/THCUNN/SpatialFullConvolution.cu +++ b/aten/src/THCUNN/SpatialFullConvolution.cu @@ -1,5 +1,7 @@ #include -#include +#include +#include +#include #include #include diff --git a/aten/src/THCUNN/SpatialFullDilatedConvolution.cu b/aten/src/THCUNN/SpatialFullDilatedConvolution.cu index 961d1f0ed4de1..449ca31a082e7 100644 --- a/aten/src/THCUNN/SpatialFullDilatedConvolution.cu +++ b/aten/src/THCUNN/SpatialFullDilatedConvolution.cu @@ -1,5 +1,8 @@ #include -#include +#include +#include +#include +#include #include #include diff --git a/aten/src/THCUNN/VolumetricDilatedConvolution.cu b/aten/src/THCUNN/VolumetricDilatedConvolution.cu index 41f2eab7c97ef..28efe00c1d956 100644 --- a/aten/src/THCUNN/VolumetricDilatedConvolution.cu +++ b/aten/src/THCUNN/VolumetricDilatedConvolution.cu @@ -1,7 +1,7 @@ #include #include #include -#include +#include #include #include diff --git a/aten/src/THCUNN/VolumetricFullDilatedConvolution.cu b/aten/src/THCUNN/VolumetricFullDilatedConvolution.cu index d5f9d87317c0f..9da51c0e31fc8 100644 --- a/aten/src/THCUNN/VolumetricFullDilatedConvolution.cu +++ b/aten/src/THCUNN/VolumetricFullDilatedConvolution.cu @@ -1,7 +1,7 @@ #include #include #include -#include +#include #include #include diff --git a/aten/src/THCUNN/generic/Col2Im.cu b/aten/src/THCUNN/generic/Col2Im.cu deleted file mode 100644 index e1960387f1e56..0000000000000 --- a/aten/src/THCUNN/generic/Col2Im.cu +++ /dev/null @@ -1,131 +0,0 @@ -#ifndef THC_GENERIC_FILE -#define THC_GENERIC_FILE "THCUNN/generic/Col2Im.cu" -#else - -#include - -static inline void THNN_(Col2Im_shapeCheck)( - THCState *state, - THCTensor *input, - THCTensor *gradOutput, - int64_t outputHeight, int64_t outputWidth, - int64_t kH, int64_t kW, int64_t dH, int64_t dW, - int64_t padH, int64_t padW, int64_t sH, int64_t sW) { - - THArgCheck(kW > 0 && kH > 0, 6, - "kernel size should be greater than zero, but got kH: %d kW: %d", kH, kW); - THArgCheck(sW > 0 && sH > 0, 12, - "stride should be greater than zero, but got sH: %d sW: %d", sH, sW); - THArgCheck(dW > 0 && dH > 0, 8, - "dilation should be greater than zero, but got dH: %d dW: %d", dH, dW); - - int64_t ndim = THCTensor_(nDimensionLegacyNoScalars)(state, input); - THCUNN_argCheck(state, !input->is_empty() && (ndim == 2 || ndim == 3), 2, input, - "Expected non-empty 2D or 3D input tensor, but got input of shape %s"); - - int batch_dim = (ndim == 3) ? 
0 : -1; - int64_t nInputPlane = input->size(batch_dim + 1); - - if (nInputPlane % (kW * kH) != 0) { - THError("Expected size of input's dimension 1 to be divisible by the " - "product of kernel_size, but got input.size(1)=%lld and " - "kernel_size=(%d, %d).", (long long) nInputPlane, kH, kW); - } - - int64_t inputLength = input->size(batch_dim + 2); - int64_t nBlocksH = div_rtn(outputHeight + 2 * padH - dH * (kH - 1) - 1, sH) + 1; - int64_t nBlocksW = div_rtn(outputWidth + 2 * padW - dW * (kW - 1) - 1, sW) + 1; - - if (inputLength != (nBlocksH * nBlocksW)) { - THError("Given output_size=(%d, %d), kernel_size=(%d, %d), " - "dilation=(%d, %d), padding=(%d, %d), stride=(%d, %d), expected " - "size of input's dimension 2 to match the calculated number of " - "sliding blocks %lld * %lld = %lld, but got input.size(2)=%lld.", - outputHeight, outputWidth, kH, kW, dH, dW, padH, padW, sH, sW, - (long long) nBlocksH, (long long) nBlocksW, - (long long) (nBlocksH * nBlocksW), (long long) inputLength); - } - - if (outputWidth < 1 || outputHeight < 1) { - THError("Expected output spatial size to be positive, but got: output_size=(%d, %d).", - outputHeight, outputWidth); - } -} - -void THNN_(Col2Im_updateOutput)( - THCState *state, - THCTensor *input, - THCTensor *output, - int64_t outputHeight, int64_t outputWidth, - int64_t kH, int64_t kW, - int64_t dH, int64_t dW, - int64_t padH, int64_t padW, - int64_t sH, int64_t sW) { - - THCUNN_assertSameGPU(state, 2, input, output); - - THNN_(Col2Im_shapeCheck)(state, input, NULL, outputHeight, outputWidth, - kH, kW, dH, dW, padH, padW, sH, sW); - - bool batched_input = true; - if (input->dim() == 2) { - // Force batch - batched_input = false; - THCTensor_(resize3d)(state, input, 1, input->size(0), input->size(1)); - } - - int64_t batchSize = input->size(0); - int64_t nInputPlane = input->size(1); - int64_t nOutputPlane = nInputPlane / (kW * kH); - - input = THCTensor_(newContiguous)(state, input); - - THCTensor_(resize4d)(state, output, batchSize, nOutputPlane, outputHeight, outputWidth); - THCTensor_(zero)(state, output); - - THCTensor *input_n = THCTensor_(new)(state); - THCTensor *output_n = THCTensor_(new)(state); - - int64_t height_col = (outputHeight + 2 * padH - (dH * (kH - 1) + 1)) / sH + 1; - int64_t width_col = (outputWidth + 2 * padW - (dW * (kW - 1) + 1)) / sW + 1; - - for (int64_t elt = 0; elt < batchSize; elt++) { - THCTensor_(select)(state, input_n, input, 0, elt); - THCTensor_(select)(state, output_n, output, 0, elt); - - col2im( - THCState_getCurrentStream(state), - THCTensor_(data)(state, input_n), - nOutputPlane, - outputHeight, outputWidth, - height_col, width_col, - kH, kW, - padH, padW, - sH, sW, - dH, dW, THCTensor_(data)(state, output_n)); - } - - THCTensor_(free)(state, input_n); - THCTensor_(free)(state, output_n); - - if (!batched_input) { - THCTensor_(resize3d)(state, output, nOutputPlane, outputHeight, outputWidth); - } - THCTensor_(free)(state, input); -} - -void THNN_(Col2Im_updateGradInput)( - THCState *state, - THCTensor *gradOutput, - THCTensor *gradInput, - int64_t kH, int64_t kW, - int64_t dH, int64_t dW, - int64_t padH, int64_t padW, - int64_t sH, int64_t sW) { - - THNN_(Im2Col_updateOutput)(state, gradOutput, gradInput, - kH, kW, dH, dW, padH, padW, sH, sW); - -} - -#endif diff --git a/aten/src/THCUNN/generic/Im2Col.cu b/aten/src/THCUNN/generic/Im2Col.cu deleted file mode 100644 index cd7140ff620ba..0000000000000 --- a/aten/src/THCUNN/generic/Im2Col.cu +++ /dev/null @@ -1,121 +0,0 @@ -#ifndef THC_GENERIC_FILE -#define 
THC_GENERIC_FILE "THCUNN/generic/Im2Col.cu" -#else - -#include - -static inline void THNN_(Im2Col_shapeCheck)( - THCState *state, - THCTensor *input, - THCTensor *gradOutput, - int64_t kH, int64_t kW, int64_t dH, int64_t dW, - int64_t padH, int64_t padW, int64_t sH, int64_t sW) { - - THArgCheck(kW > 0 && kH > 0, 4, - "kernel size should be greater than zero, but got kH: %d kW: %d", kH, kW); - THArgCheck(dW > 0 && dH > 0, 6, - "dilation should be greater than zero, but got dH: %d dW: %d", dH, dW); - THArgCheck(padW >= 0 && padH >= 0, 8, - "padding should be non-negative, but got padH: %d padW: %d", padH, padW); - THArgCheck(sW > 0 && sH > 0, 10, - "stride should be greater than zero, but got sH: %d sW: %d", sH, sW); - - int64_t ndim = THCTensor_(nDimensionLegacyNoScalars)(state, input); - THCUNN_argCheck(state, !input->is_empty() && (ndim == 3 || ndim == 4), 2, input, - "Expected non-empty 3D or 4D input tensor, but got input of shape %s"); - - int dim_batch = 0; - if (ndim == 3) { - dim_batch = -1; - } - int64_t nInputPlane = THCTensor_(size)(state, input, dim_batch + 1); - int64_t inputHeight = THCTensor_(size)(state, input, dim_batch + 2); - int64_t inputWidth = THCTensor_(size)(state, input, dim_batch + 3); - int64_t outputHeight = div_rtn(inputHeight + 2 * padH - (dH * (kH - 1) + 1), sH) + 1; - int64_t outputWidth = div_rtn(inputWidth + 2 * padW - (dW * (kW - 1) + 1), sW) + 1; - - if (outputHeight < 1 || outputWidth < 1) { - THError("Given input with spatial size (%d, %d), kernel_size=(%d, %d), " - "dilation=(%d, %d), padding=(%d, %d), calculated " - "shape of the array of sliding blocks as (%d, %d), which is " - "too small (non-positive).", - inputHeight, inputHeight, kH, kW, dH, dW, padH, padW, - outputHeight, outputWidth); - } -} - -void THNN_(Im2Col_updateOutput)( - THCState *state, - THCTensor *input, - THCTensor *output, - int64_t kH, int64_t kW, - int64_t dH, int64_t dW, - int64_t padH, int64_t padW, - int64_t sH, int64_t sW) { - - THCUNN_assertSameGPU(state, 2, input, output); - - THNN_(Im2Col_shapeCheck)(state, input, NULL, kH, kW, dH, dW, padH, padW, sH, sW); - - input = THCTensor_(newContiguous)(state, input); - bool batched_input = true; - if (input->dim() == 3) { - batched_input = false; - THCTensor_(resize4d)(state, input, 1, input->size(0), input->size(1), input->size(2)); - } - - int64_t batchSize = THCTensor_(size)(state, input, 0); - int64_t nInputPlane = THCTensor_(size)(state, input, 1); - int64_t inputHeight = THCTensor_(size)(state, input, 2); - int64_t inputWidth = THCTensor_(size)(state, input, 3); - - int64_t outputHeight = (inputHeight + 2 * padH - (dH * (kH - 1) + 1)) / sH + 1; - int64_t outputWidth = (inputWidth + 2 * padW - (dW * (kW - 1) + 1)) / sW + 1; - int64_t nOutputPlane = nInputPlane * kW * kH; - int64_t outputLength = outputHeight * outputWidth; - - THCTensor_(resize3d)(state, output, batchSize, nOutputPlane, outputLength); - THCTensor_(zero)(state, output); - - THCTensor *input_n = THCTensor_(new)(state); - THCTensor *output_n = THCTensor_(new)(state); - - for (int64_t elt = 0; elt < batchSize; elt++) { - THCTensor_(select)(state, input_n, input, 0, elt); - THCTensor_(select)(state, output_n, output, 0, elt); - - im2col( - THCState_getCurrentStream(state), - THCTensor_(data)(state, input_n), - nInputPlane, inputHeight, inputWidth, - outputHeight, outputWidth, - kH, kW, padH, padW, sH, sW, - dH, dW, THCTensor_(data)(state, output_n)); - } - - THCTensor_(free)(state, input_n); - THCTensor_(free)(state, output_n); - - if (!batched_input) { - 
THCTensor_(resize2d)(state, output, nOutputPlane, outputLength); - } - THCTensor_(free)(state, input); -} - -void THNN_(Im2Col_updateGradInput)( - THCState *state, - THCTensor *gradOutput, - THCTensor *gradInput, - int64_t inputHeight, int64_t inputWidth, - int64_t kH, int64_t kW, - int64_t dH, int64_t dW, - int64_t padH, int64_t padW, - int64_t sH, int64_t sW) { - - THNN_(Col2Im_updateOutput)(state, gradOutput, gradInput, - inputHeight, inputWidth, - kH, kW, dH, dW, - padH, padW, sH, sW); -} - -#endif diff --git a/aten/src/THCUNN/generic/SpatialConvolutionLocal.cu b/aten/src/THCUNN/generic/SpatialConvolutionLocal.cu index 9e98254140969..e081dad7a228b 100644 --- a/aten/src/THCUNN/generic/SpatialConvolutionLocal.cu +++ b/aten/src/THCUNN/generic/SpatialConvolutionLocal.cu @@ -131,7 +131,7 @@ void THNN_(SpatialConvolutionLocal_updateOutput)( THCTensor_(select)(state, output_n, output, 0, elt); // Extract columns: - im2col( + at::native::im2col( THCState_getCurrentStream(state), THCTensor_(data)(state, input_n), nInputPlane, inputHeight, inputWidth, @@ -265,7 +265,7 @@ void THNN_(SpatialConvolutionLocal_updateGradInput)( // fgradInput3d: oH*oW x nInputPlane*kH*kW x 1 // Unpack columns back into input: - col2im( + at::native::col2im( THCState_getCurrentStream(state), THCTensor_(data)(state, fgradInput_n), nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth, kH, kW, padH, padW, dH, dW, @@ -368,7 +368,7 @@ void THNN_(SpatialConvolutionLocal_accGradParameters)( kW*kH*nInputPlane, outputHeight*outputWidth); // Extract columns: - im2col( + at::native::im2col( THCState_getCurrentStream(state), THCTensor_(data)(state, input_n), nInputPlane, inputHeight, inputWidth, diff --git a/aten/src/THCUNN/generic/SpatialConvolutionMM.cu b/aten/src/THCUNN/generic/SpatialConvolutionMM.cu index 0ca36e6236b16..4b95fadc0ddca 100644 --- a/aten/src/THCUNN/generic/SpatialConvolutionMM.cu +++ b/aten/src/THCUNN/generic/SpatialConvolutionMM.cu @@ -200,7 +200,7 @@ void THNN_(SpatialConvolutionMM_updateOutput)( } // Extract columns: - im2col( + at::native::im2col( THCState_getCurrentStream(state), THCTensor_(data)(state, input_n), nInputPlane, inputHeight, inputWidth, @@ -331,7 +331,7 @@ void THNN_(SpatialConvolutionMM_updateGradInput)( ); // Unpack columns back into input: - col2im( + at::native::col2im( THCState_getCurrentStream(state), THCTensor_(data)(state, gradColumns), nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth, kH, kW, padH, padW, dH, dW, @@ -430,7 +430,7 @@ void THNN_(SpatialConvolutionMM_accGradParameters)( THCTensor_(select)(state, input_n, input, 0, elt); // Extract columns: - im2col( + at::native::im2col( THCState_getCurrentStream(state), THCTensor_(data)(state, input_n), nInputPlane, inputHeight, inputWidth, diff --git a/aten/src/THCUNN/generic/SpatialDilatedConvolution.cu b/aten/src/THCUNN/generic/SpatialDilatedConvolution.cu index 2aeb4e49cb5e6..e3e09a37894f7 100644 --- a/aten/src/THCUNN/generic/SpatialDilatedConvolution.cu +++ b/aten/src/THCUNN/generic/SpatialDilatedConvolution.cu @@ -173,7 +173,7 @@ void THNN_(SpatialDilatedConvolution_updateOutput)( } // Extract columns: - im2col( + at::native::im2col( THCState_getCurrentStream(state), THCTensor_(data)(state, input_n), nInputPlane, inputHeight, inputWidth, @@ -306,7 +306,7 @@ void THNN_(SpatialDilatedConvolution_updateGradInput)( ); // Unpack columns back into input: - col2im( + at::native::col2im( THCState_getCurrentStream(state), THCTensor_(data)(state, gradColumns), nInputPlane, inputHeight, inputWidth, outputHeight, 
outputWidth, kH, kW, padH, padW, dH, dW, @@ -408,7 +408,7 @@ void THNN_(SpatialDilatedConvolution_accGradParameters)( THCTensor_(select)(state, input_n, input, 0, elt); // Extract columns: - im2col( + at::native::im2col( THCState_getCurrentStream(state), THCTensor_(data)(state, input_n), nInputPlane, inputHeight, inputWidth, diff --git a/aten/src/THCUNN/generic/SpatialFullDilatedConvolution.cu b/aten/src/THCUNN/generic/SpatialFullDilatedConvolution.cu index e601a78f25f70..c152535a257a0 100644 --- a/aten/src/THCUNN/generic/SpatialFullDilatedConvolution.cu +++ b/aten/src/THCUNN/generic/SpatialFullDilatedConvolution.cu @@ -166,7 +166,7 @@ void THNN_(SpatialFullDilatedConvolution_updateOutput)( ); // Unpack columns back into input: - col2im( + at::native::col2im( THCState_getCurrentStream(state), THCTensor_(data)(state, columns), nOutputPlane, outputHeight, outputWidth, inputHeight, inputWidth, kH, kW, padH, padW, dH, dW, @@ -273,7 +273,7 @@ void THNN_(SpatialFullDilatedConvolution_updateGradInput)( THCTensor_(select)(state, gradOutput_n, gradOutput, 0, elt); // Extract columns: - im2col( + at::native::im2col( THCState_getCurrentStream(state), THCTensor_(data)(state, gradOutput_n), nOutputPlane, outputHeight, outputWidth, @@ -408,7 +408,7 @@ void THNN_(SpatialFullDilatedConvolution_accGradParameters)( THCTensor_(select)(state, input_n, input, 0, elt); // Extract columns: - im2col( + at::native::im2col( THCState_getCurrentStream(state), THCTensor_(data)(state, gradOutput_n), nOutputPlane, outputHeight, outputWidth, diff --git a/aten/src/THCUNN/generic/THCUNN.h b/aten/src/THCUNN/generic/THCUNN.h index 965a7c41a6402..7eed4f96d8e13 100644 --- a/aten/src/THCUNN/generic/THCUNN.h +++ b/aten/src/THCUNN/generic/THCUNN.h @@ -151,44 +151,6 @@ THC_API void THNN_(GatedLinear_updateGradInput)( THCTensor *gradInput, int dim); -THC_API void THNN_(Im2Col_updateOutput)( - THCState *state, - THCTensor *input, - THCTensor *output, - int64_t kH, int64_t kW, - int64_t dH, int64_t dW, - int64_t padH, int64_t padW, - int64_t sH, int64_t sW); - -THC_API void THNN_(Im2Col_updateGradInput)( - THCState *state, - THCTensor *gradOutput, - THCTensor *gradInput, - int64_t inputHeight, int64_t inputWidth, - int64_t kH, int64_t kW, - int64_t dH, int64_t dW, - int64_t padH, int64_t padW, - int64_t sH, int64_t sW); - -THC_API void THNN_(Col2Im_updateOutput)( - THCState *state, - THCTensor *input, - THCTensor *output, - int64_t outputHeight, int64_t outputWidth, - int64_t kH, int64_t kW, - int64_t dH, int64_t dW, - int64_t padH, int64_t padW, - int64_t sH, int64_t sW); - - THC_API void THNN_(Col2Im_updateGradInput)( - THCState *state, - THCTensor *gradOutput, - THCTensor *gradInput, - int64_t kH, int64_t kW, - int64_t dH, int64_t dW, - int64_t padH, int64_t padW, - int64_t sH, int64_t sW); - THC_API void THNN_(LeakyReLU_updateOutput)( THCState *state, THCTensor *input, diff --git a/aten/src/THCUNN/generic/VolumetricDilatedConvolution.cu b/aten/src/THCUNN/generic/VolumetricDilatedConvolution.cu index cdc229bbfcf48..aa7c0e5754562 100644 --- a/aten/src/THCUNN/generic/VolumetricDilatedConvolution.cu +++ b/aten/src/THCUNN/generic/VolumetricDilatedConvolution.cu @@ -186,7 +186,7 @@ void THNN_(VolumetricDilatedConvolution_updateOutput)( } // Extract columns: - vol2col( + at::native::vol2col( THCState_getCurrentStream(state), THCTensor_(data)(state, input_n), nInputPlane, inputDepth, inputHeight, inputWidth, @@ -322,7 +322,7 @@ void THNN_(VolumetricDilatedConvolution_updateGradInput)( ); // Unpack columns back into input: - col2vol( + 
at::native::col2vol( THCState_getCurrentStream(state), THCTensor_(data)(state, gradColumns), nInputPlane, inputDepth, inputHeight, inputWidth, @@ -417,7 +417,7 @@ void THNN_(VolumetricDilatedConvolution_accGradParameters)( THCTensor_(select)(state, input_n, input, 0, elt); // Extract columns: - vol2col( + at::native::vol2col( THCState_getCurrentStream(state), THCTensor_(data)(state, input_n), nInputPlane, inputDepth, inputHeight, inputWidth, diff --git a/aten/src/THCUNN/generic/VolumetricFullDilatedConvolution.cu b/aten/src/THCUNN/generic/VolumetricFullDilatedConvolution.cu index ac3f2a6e21993..aadb792f38d3d 100644 --- a/aten/src/THCUNN/generic/VolumetricFullDilatedConvolution.cu +++ b/aten/src/THCUNN/generic/VolumetricFullDilatedConvolution.cu @@ -185,7 +185,7 @@ void THNN_(VolumetricFullDilatedConvolution_updateOutput)( ); // Unpack columns back into input: - col2vol( + at::native::col2vol( THCState_getCurrentStream(state), THCTensor_(data)(state, columns), nOutputPlane, outputDepth, outputHeight, outputWidth, @@ -303,7 +303,7 @@ void THNN_(VolumetricFullDilatedConvolution_updateGradInput)( THCTensor_(select)(state, gradOutput_n, gradOutput, 0, elt); // Extract columns: - vol2col( + at::native::vol2col( THCState_getCurrentStream(state), THCTensor_(data)(state, gradOutput_n), nOutputPlane, outputDepth, outputHeight, outputWidth, @@ -446,7 +446,7 @@ void THNN_(VolumetricFullDilatedConvolution_accGradParameters)( THCTensor_(select)(state, input_n, input, 0, elt); // Extract columns: - vol2col( + at::native::vol2col( THCState_getCurrentStream(state), THCTensor_(data)(state, gradOutput_n), nOutputPlane, outputDepth, outputHeight, outputWidth, diff --git a/aten/src/THCUNN/im2col.h b/aten/src/THCUNN/im2col.h deleted file mode 100644 index ce763f4adadb7..0000000000000 --- a/aten/src/THCUNN/im2col.h +++ /dev/null @@ -1,124 +0,0 @@ -#ifndef THCUNN_IM2COL_H -#define THCUNN_IM2COL_H - -#include -#include -#include - -// Kernel for fast unfold+copy -// (borrowed from Caffe: https://github.com/BVLC/caffe/blob/master/src/caffe/layers/conv_layer.cu) -template -C10_LAUNCH_BOUNDS_1(CUDA_NUM_THREADS) -__global__ void im2col_kernel(const int64_t n, const Dtype* data_im, - const int64_t height, const int64_t width, - const int64_t ksize_h, const int64_t ksize_w, - const int64_t pad_h, const int64_t pad_w, - const int64_t stride_h, const int64_t stride_w, - const int64_t dilation_h, const int64_t dilation_w, - const int64_t height_col, const int64_t width_col, - Dtype* data_col) { - CUDA_KERNEL_LOOP(index, n) { - int64_t w_out = index % width_col; - index /= width_col; - int64_t h_out = index % height_col; - int64_t channel_in = index / height_col; - int64_t channel_out = channel_in * ksize_h * ksize_w; - int64_t h_in = h_out * stride_h - pad_h; - int64_t w_in = w_out * stride_w - pad_w; - data_col += (channel_out * height_col + h_out) * width_col + w_out; - data_im += (channel_in * height + h_in) * width + w_in; - for (int64_t i = 0; i < ksize_h; ++i) { - for (int64_t j = 0; j < ksize_w; ++j) { - int64_t h = h_in + i * dilation_h; - int64_t w = w_in + j * dilation_w; - *data_col = (h >= 0 && w >= 0 && h < height && w < width) ? 
- data_im[i * dilation_h * width + j * dilation_w] : ScalarConvert::to(0); - data_col += height_col * width_col; - } - } - } -} - -template -void im2col(cudaStream_t stream, const Dtype* data_im, const int64_t channels, - const int64_t height, const int64_t width, - const int64_t height_col, const int64_t width_col, - const int64_t ksize_h, const int64_t ksize_w, const int64_t pad_h, - const int64_t pad_w, const int64_t stride_h, const int64_t stride_w, - const int64_t dilation_h, const int64_t dilation_w, Dtype* data_col) { - // We are going to launch channels * height_col * width_col kernels, each - // kernel responsible for copying a single-channel grid. - int64_t num_kernels = channels * height_col * width_col; - // Launch - im2col_kernel <<>> ( - num_kernels, data_im, height, width, ksize_h, ksize_w, - pad_h, pad_w, stride_h, stride_w, - dilation_h, dilation_w, - height_col, width_col, data_col - ); - THCudaCheck(cudaGetLastError()); -} - -template -C10_LAUNCH_BOUNDS_1(CUDA_NUM_THREADS) -__global__ void col2im_kernel(const int64_t n, const Dtype* data_col, - const int64_t height, const int64_t width, const int64_t channels, - const int64_t kernel_h, const int64_t kernel_w, - const int64_t pad_h, const int64_t pad_w, - const int64_t stride_h, const int64_t stride_w, - const int64_t dilation_h, const int64_t dilation_w, - const int64_t height_col, const int64_t width_col, - Dtype* data_im) { - CUDA_KERNEL_LOOP(index, n) { - Acctype val = Acctype(0); - const int64_t w_im = index % width + pad_w; - const int64_t h_im = (index / width) % height + pad_h; - const int64_t c_im = index / (width * height); - int64_t kernel_extent_w = (kernel_w - 1) * dilation_w + 1; - int64_t kernel_extent_h = (kernel_h - 1) * dilation_h + 1; - // compute the start and end of the output - const int64_t w_col_start = - (w_im < kernel_extent_w) ? 0 : (w_im - kernel_extent_w) / stride_w + 1; - const int64_t w_col_end = min(w_im / stride_w + 1, width_col); - const int64_t h_col_start = - (h_im < kernel_extent_h) ? 0 : (h_im - kernel_extent_h) / stride_h + 1; - const int64_t h_col_end = min(h_im / stride_h + 1, height_col); - // TODO: use LCM of stride and dilation to avoid unnecessary loops - for (int64_t h_col = h_col_start; h_col < h_col_end; h_col += 1) { - for (int64_t w_col = w_col_start; w_col < w_col_end; w_col += 1) { - int64_t h_k = (h_im - h_col * stride_h); - int64_t w_k = (w_im - w_col * stride_w); - if (h_k % dilation_h == 0 && w_k % dilation_w == 0) { - h_k /= dilation_h; - w_k /= dilation_w; - int64_t data_col_index = (((c_im * kernel_h + h_k) * kernel_w + w_k) * - height_col + h_col) * width_col + w_col; - val += data_col[data_col_index]; - } - } - } - data_im[index] = ScalarConvert::to(val); - } -} - - -template -void col2im(cudaStream_t stream, const Dtype* data_col, const int64_t channels, - const int64_t height, const int64_t width, - const int64_t output_height, const int64_t output_width, - const int64_t patch_h, const int64_t patch_w, const int64_t pad_h, - const int64_t pad_w, const int64_t stride_h, const int64_t stride_w, - const int64_t dilation_h, const int64_t dilation_w, Dtype* data_im) { - int64_t num_kernels = channels * height * width; - // To avoid involving atomic operations, we will launch one kernel per - // bottom dimension, and then in the kernel add up the top dimensions. 
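
For reference while porting: the gather scheme the comment above describes can be re-expressed on the CPU as below. This is a minimal NumPy sketch of the same logic as col2im_kernel, not the shipped kernel; the helper name col2im_gather and the flat data_col layout argument are assumptions made for illustration.

```python
import numpy as np

def col2im_gather(data_col, channels, height, width, kh, kw,
                  pad_h, pad_w, stride_h, stride_w, dil_h, dil_w):
    # Column-grid extent, same formula the callers use for height_col/width_col.
    height_col = (height + 2 * pad_h - (dil_h * (kh - 1) + 1)) // stride_h + 1
    width_col = (width + 2 * pad_w - (dil_w * (kw - 1) + 1)) // stride_w + 1
    data_im = np.zeros(channels * height * width, dtype=data_col.dtype)
    for index in range(data_im.size):       # one "thread" per *image* element
        w_im = index % width + pad_w
        h_im = (index // width) % height + pad_h
        c_im = index // (width * height)
        ext_w = (kw - 1) * dil_w + 1        # dilated kernel extents
        ext_h = (kh - 1) * dil_h + 1
        # Range of column positions whose window covers this image element.
        w_start = 0 if w_im < ext_w else (w_im - ext_w) // stride_w + 1
        w_end = min(w_im // stride_w + 1, width_col)
        h_start = 0 if h_im < ext_h else (h_im - ext_h) // stride_h + 1
        h_end = min(h_im // stride_h + 1, height_col)
        val = 0.0
        for h_col in range(h_start, h_end):
            for w_col in range(w_start, w_end):
                h_k = h_im - h_col * stride_h
                w_k = w_im - w_col * stride_w
                # Only kernel taps that land exactly on the dilation grid hit
                # this element; sum their column entries (no atomics needed).
                if h_k % dil_h == 0 and w_k % dil_w == 0:
                    h_k //= dil_h
                    w_k //= dil_w
                    col_idx = ((((c_im * kh + h_k) * kw + w_k)
                                * height_col + h_col) * width_col + w_col)
                    val += data_col[col_idx]
        data_im[index] = val
    return data_im.reshape(channels, height, width)
```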
- col2im_kernel <<>> ( - num_kernels, data_col, height, width, channels, - patch_h, patch_w, pad_h, pad_w, stride_h, stride_w, - dilation_h, dilation_w, - output_height, output_width, data_im - ); - THCudaCheck(cudaGetLastError()); -} - -#endif diff --git a/aten/src/THCUNN/vol2col.h b/aten/src/THCUNN/vol2col.h deleted file mode 100644 index 20aa44ae360d5..0000000000000 --- a/aten/src/THCUNN/vol2col.h +++ /dev/null @@ -1,139 +0,0 @@ -#ifndef THCUNN_VOL2COL_H -#define THCUNN_VOL2COL_H - -#include -#include - -// Kernel for fast unfold+copy on volumes -template -__global__ void vol2col_kernel(const int n, const Dtype* data_vol, - const int depth, const int height, const int width, - const int ksize_t, const int ksize_h, const int ksize_w, - const int pad_t, const int pad_h, const int pad_w, - const int stride_t, const int stride_h, const int stride_w, - const int dilation_t, const int dilation_h, const int dilation_w, - const int depth_col, const int height_col, const int width_col, - Dtype* data_col) { -CUDA_KERNEL_LOOP(index, n) { - int w_out = index % width_col; - index /= width_col; - int h_out = index % height_col; - index /= height_col; - int t_out = index % depth_col; - int channel_in = index / depth_col; - int channel_out = channel_in * ksize_t * ksize_h * ksize_w; - int t_in = t_out * stride_t - pad_t; - int h_in = h_out * stride_h - pad_h; - int w_in = w_out * stride_w - pad_w; - data_col += ((channel_out * depth_col + t_out) * height_col + h_out) * width_col + w_out; - data_vol += ((channel_in * depth + t_in) * height + h_in) * width + w_in; - for (int i = 0; i < ksize_t; ++i) { - for (int j = 0; j < ksize_h; ++j) { - for (int k = 0; k < ksize_w; ++k) { - int t = t_in + i * dilation_t; - int h = h_in + j * dilation_h; - int w = w_in + k * dilation_w; - *data_col = (t >= 0 && h >= 0 && w >= 0 && t < depth && h < height && w < width) ? - data_vol[i * dilation_t * height * width + j * dilation_h * width + k * dilation_w] : ScalarConvert::to(0); - data_col += depth_col * height_col * width_col; - } - } - } - } -} - -template -void vol2col(cudaStream_t stream, const Dtype* data_vol, const int channels, - const int depth, const int height, const int width, - const int depth_col, const int height_col, const int width_col, - const int ksize_t, const int ksize_h, const int ksize_w, - const int pad_t, const int pad_h, const int pad_w, - const int stride_t, const int stride_h, const int stride_w, - const int dilation_t, const int dilation_h, const int dilation_w, - Dtype* data_col) { - // We are going to launch channels * depth_col * height_col * width_col kernels, each - // kernel responsible for copying a single-channel grid. 
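
vol2col is the 3D analogue of im2col: one column entry per (channel, kT, kH, kW, output position), copied straight from the volume, with zero for out-of-bounds taps. A minimal NumPy sketch of the same indexing, assuming a single scalar pad/stride/dilation for brevity (the real kernels take per-dimension values) and a hypothetical vol2col_ref name:

```python
import numpy as np

def vol2col_ref(vol, kt, kh, kw, pad, stride, dil):
    C, D, H, W = vol.shape
    d_col = (D + 2 * pad - (dil * (kt - 1) + 1)) // stride + 1
    h_col = (H + 2 * pad - (dil * (kh - 1) + 1)) // stride + 1
    w_col = (W + 2 * pad - (dil * (kw - 1) + 1)) // stride + 1
    cols = np.zeros((C * kt * kh * kw, d_col * h_col * w_col), vol.dtype)
    for c in range(cols.shape[0]):
        # Decompose the column channel into (volume channel, kernel offsets).
        w_off = c % kw
        h_off = (c // kw) % kh
        t_off = (c // (kw * kh)) % kt
        c_vol = c // (kt * kh * kw)
        for t in range(d_col):
            t_pad = t * stride - pad + t_off * dil
            for h in range(h_col):
                h_pad = h * stride - pad + h_off * dil
                for w in range(w_col):
                    w_pad = w * stride - pad + w_off * dil
                    # Pure copy: in-bounds taps are copied, the rest stay zero.
                    if 0 <= t_pad < D and 0 <= h_pad < H and 0 <= w_pad < W:
                        cols[c, (t * h_col + h) * w_col + w] = \
                            vol[c_vol, t_pad, h_pad, w_pad]
    return cols
```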
- int num_kernels = channels * depth_col * height_col * width_col; - // Launch - vol2col_kernel <<>> ( - num_kernels, data_vol, depth, height, width, ksize_t, ksize_h, ksize_w, - pad_t, pad_h, pad_w, stride_t, stride_h, stride_w, - dilation_t, dilation_h, dilation_w, - depth_col, height_col, width_col, data_col - ); - THCudaCheck(cudaGetLastError()); -} - -template -__global__ void vol2im_kernel(const int n, const Dtype* data_col, - const int depth, const int height, const int width, const int channels, - const int kernel_t, const int kernel_h, const int kernel_w, - const int pad_t, const int pad_h, const int pad_w, - const int stride_t, const int stride_h, const int stride_w, - const int dilation_t, const int dilation_h, const int dilation_w, - const int depth_col, const int height_col, const int width_col, - Dtype* data_vol) { - CUDA_KERNEL_LOOP(index, n) { - Acctype val = Acctype(0); - const int w_im = index % width + pad_w; - const int h_im = (index / width) % height + pad_h; - const int t_im = (index / width / height) % depth + pad_t; - const int c_im = index / (width * height * depth); - int kernel_extent_w = (kernel_w - 1) * dilation_w + 1; - int kernel_extent_h = (kernel_h - 1) * dilation_h + 1; - int kernel_extent_t = (kernel_t - 1) * dilation_t + 1; - // compute the start and end of the output - const int w_col_start = - (w_im < kernel_extent_w) ? 0 : (w_im - kernel_extent_w) / stride_w + 1; - const int w_col_end = min(w_im / stride_w + 1, width_col); - const int h_col_start = - (h_im < kernel_extent_h) ? 0 : (h_im - kernel_extent_h) / stride_h + 1; - const int h_col_end = min(h_im / stride_h + 1, height_col); - const int t_col_start = - (t_im < kernel_extent_t) ? 0 : (t_im - kernel_extent_t) / stride_t + 1; - const int t_col_end = min(t_im / stride_t + 1, depth_col); - // TODO: use LCM of stride and dilation to avoid unnecessary loops - for (int t_col = t_col_start; t_col < t_col_end; t_col += 1) { - for (int h_col = h_col_start; h_col < h_col_end; h_col += 1) { - for (int w_col = w_col_start; w_col < w_col_end; w_col += 1) { - int t_k = (t_im - t_col * stride_t); - int h_k = (h_im - h_col * stride_h); - int w_k = (w_im - w_col * stride_w); - if (t_k % dilation_t == 0 && h_k % dilation_h == 0 && w_k % dilation_w == 0) { - t_k /= dilation_t; - h_k /= dilation_h; - w_k /= dilation_w; - int data_col_index = - (((((c_im * kernel_t + t_k) * kernel_h + h_k) * kernel_w + w_k) - * depth_col + t_col) * height_col + h_col) * width_col + w_col; - val += data_col[data_col_index]; - } - } - } - } - data_vol[index] = ScalarConvert::to(val); - } -} - -template -void col2vol(cudaStream_t stream, const Dtype* data_col, const int channels, - const int depth, const int height, const int width, - const int output_depth, const int output_height, const int output_width, - const int patch_t, const int patch_h, const int patch_w, - const int pad_t, const int pad_h, const int pad_w, - const int stride_t, const int stride_h, const int stride_w, - const int dilation_t, const int dilation_h, const int dilation_w, - Dtype* data_vol) { - int num_kernels = channels * depth * height * width; - // To avoid involving atomic operations, we will launch one kernel per - // bottom dimension, and then in the kernel add up the top dimensions. 
- vol2im_kernel <<>> ( - num_kernels, data_col, depth, height, width, channels, - patch_t, patch_h, patch_w, pad_t, pad_h, pad_w, stride_t, stride_h, stride_w, - dilation_t, dilation_h, dilation_w, - output_depth, output_height, output_width, data_vol - ); - THCudaCheck(cudaGetLastError()); -} - -#endif diff --git a/aten/src/THNN/generic/Col2Im.c b/aten/src/THNN/generic/Col2Im.c deleted file mode 100644 index 1d86abbf77aa6..0000000000000 --- a/aten/src/THNN/generic/Col2Im.c +++ /dev/null @@ -1,234 +0,0 @@ -#ifndef TH_GENERIC_FILE -#define TH_GENERIC_FILE "THNN/generic/Col2Im.c" -#else - -#include - -// Note [im2col/col2im output padding] -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -// Our implementations of im2col and col2im take both the input height/width as -// well as a seemingly redundant output height/width. In principle, you could -// compute the output height/width by using the convolution shape formulas. So, -// what's up with that? -// -// The trouble arises when one runs the backward of a transposed convolution -// with output_padding >= stride. (BTW, output_padding is known as adj inside -// THNN.) Let's consider a simple case where we have kernel=2, dilation=2, -// stride=1, output_padding=1 for a 4x4 input: -// -// Input: X -// -// Output: X.X. -// .... -// X.X. -// .... -// -// If we compute backwards of output with a standard convolution on the output -// with the same parameters, we would end up with a 2x2 grad_input (because you -// can slide the stencil over to the right once and down once). But that is all -// out-of-bounds if you're computing backwards for a 1x1 input. -// -// "Now Edward," you might say, "the real problem is that you set output_padding -// >= stride, surely an error should have been raised in this case." To -// understand why it is useful to handle this case, we have to understand how we -// compute the weight gradient of a convolution. Suppose we have a convolution -// with kernel=2, stride=2 on a 5x5 input. Let us see all the contributions of -// weight[0][0] (which we have labeled w) in the output: -// -// Input: a.b.. Weight: w. -// ..... .. -// c.d.. -// ..... -// ..... -// -// Output: [ aw+... bw+... ] -// [ cw+... dw+... ] -// -// From this diagram, it easy to see that we can compute the weight gradient -// by performing a *dilated* convolution between the input and the -// output gradients with kernel=2, dilation=2, stride=1. But there's a rub: if -// we do a dilated convolution directly, we'll end up with a 3x3 weight -// gradient, when we clearly wanted a 2x2. So how do we avoid going out -// of bounds? We could add a notion of 'output_padding' for non-transposed -// convolution, but another simple and effective fix is to just accept -// the desired output size directly, and compute only within those bounds. 
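
The sliding-block count used throughout these shape checks, n = div_rtn(size + 2*pad - dilation*(kernel-1) - 1, stride) + 1, is not invertible once stride > 1, which is the concrete reason col2im takes the output size as an argument rather than deriving it. A quick illustrative check in plain Python:

```python
# Number of sliding blocks along one dimension, as in the shape checks above.
def n_blocks(size, k=2, d=1, p=0, s=2):
    return (size + 2 * p - d * (k - 1) - 1) // s + 1

# With kernel 2 and stride 2, inputs of size 4 and 5 yield the same column
# count, so the image size cannot be recovered from the columns alone.
assert n_blocks(4) == n_blocks(5) == 2
```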
-// -// -// ALSO do vol2col - -static void THNN_(im2col)(const scalar_t* data_im, const int64_t channels, - const int64_t height, const int64_t width, - const int64_t output_height, const int64_t output_width, - const int64_t kernel_h, const int64_t kernel_w, - const int64_t pad_h, const int64_t pad_w, - const int64_t stride_h, const int64_t stride_w, - const int64_t dilation_h, const int64_t dilation_w, - scalar_t* data_col) { - const int64_t height_col = output_height; - const int64_t width_col = output_width; - const int64_t channels_col = channels * kernel_h * kernel_w; - for (int64_t c_col = 0; c_col < channels_col; ++c_col) { - int64_t w_offset = c_col % kernel_w; - int64_t h_offset = (c_col / kernel_w) % kernel_h; - int64_t c_im = c_col / kernel_h / kernel_w; - for (int64_t h_col = 0; h_col < height_col; ++h_col) { - int64_t h_im = h_col * stride_h - pad_h + h_offset * dilation_h; - for (int64_t w_col = 0; w_col < width_col; ++w_col) { - int64_t w_im = w_col * stride_w - pad_w + w_offset * dilation_w; - data_col[(c_col * height_col + h_col) * width_col + w_col] = - (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) ? - data_im[(c_im * height + h_im) * width + w_im] : 0; - } - } - } -} - -static void THNN_(col2im)(const scalar_t* data_col, const int64_t channels, - const int64_t height, const int64_t width, - const int64_t output_height, const int64_t output_width, - const int64_t kernel_h, const int64_t kernel_w, - const int64_t pad_h, const int64_t pad_w, - const int64_t stride_h, const int64_t stride_w, - const int64_t dilation_h, const int64_t dilation_w, - scalar_t* data_im) { - memset(data_im, 0, sizeof(scalar_t) * height * width * channels); - const int64_t height_col = output_height; - const int64_t width_col = output_width; - const int64_t channels_col = channels * kernel_h * kernel_w; - for (int64_t c_col = 0; c_col < channels_col; ++c_col) { - int64_t w_offset = c_col % kernel_w; - int64_t h_offset = (c_col / kernel_w) % kernel_h; - int64_t c_im = c_col / kernel_h / kernel_w; - for (int64_t h_col = 0; h_col < height_col; ++h_col) { - int64_t h_im = h_col * stride_h - pad_h + h_offset * dilation_h; - for (int64_t w_col = 0; w_col < width_col; ++w_col) { - int64_t w_im = w_col * stride_w - pad_w + w_offset * dilation_w; - if (h_im >= 0 && h_im < height && w_im >= 0 && w_im < width) - data_im[(c_im * height + h_im) * width + w_im] += - data_col[(c_col * height_col + h_col) * width_col + w_col]; - } - } - } -} - -static inline void THNN_(Col2Im_shapeCheck)( - THNNState *state, - THTensor *input, - THTensor *gradOutput, - int64_t outputHeight, int64_t outputWidth, - int64_t kH, int64_t kW, int64_t dilationH, int64_t dilationW, - int64_t padH, int64_t padW, int64_t dH, int64_t dW) { - - THArgCheck(kW > 0 && kH > 0, 6, - "kernel size should be greater than zero, but got kH: %d kW: %d", kH, kW); - THArgCheck(dW > 0 && dH > 0, 12, - "stride should be greater than zero, but got dH: %d dW: %d", dH, dW); - THArgCheck(dilationW > 0 && dilationH > 0, 8, - "dilation should be greater than zero, but got dilationH: %d dilationW: %d", dilationH, dilationW); - - int64_t ndim = THTensor_(nDimensionLegacyNoScalars)(input); - THNN_ARGCHECK(!input->is_empty() && (ndim == 2 || ndim == 3), 2, input, - "Expected non-empty 2D or 3D input tensor, but got input of shape %s"); - - int64_t batch_dim = (ndim == 3) ? 
0 : -1; - int64_t nInputPlane = input->size(batch_dim + 1); - - if (nInputPlane % (kW * kH) != 0) { - THError("Expected size of input's dimension 1 to be divisible by the " - "product of kernel_size, but got input.size(1)=%lld and " - "kernel_size=(%d, %d).", (long long) nInputPlane, kH, kW); - } - - int64_t inputLength = input->size(batch_dim + 2); - int64_t nBlocksH = div_rtn(outputHeight + 2 * padH - dilationH * (kH - 1) - 1, dH) + 1; - int64_t nBlocksW = div_rtn(outputWidth + 2 * padW - dilationW * (kW - 1) - 1, dW) + 1; - - if (inputLength != (nBlocksH * nBlocksW)) { - THError("Given output_size=(%d, %d), kernel_size=(%d, %d), " - "dilation=(%d, %d), padding=(%d, %d), stride=(%d, %d), expected " - "size of input's dimension 2 to match the calculated number of " - "sliding blocks %lld * %lld = %lld, but got input.size(2)=%lld.", - outputHeight, outputWidth, kH, kW, dilationH, dilationW, padH, padW, dH, dW, - (long long) nBlocksH, (long long) nBlocksW, - (long long) (nBlocksH * nBlocksW), (long long) inputLength); - } - - if (outputWidth < 1 || outputHeight < 1) { - THError("Expected output spatial size to be positive, but got: output_size=(%d, %d).", - outputHeight, outputWidth); - } -} - -void THNN_(Col2Im_updateOutput)( - THNNState *state, - THTensor *input, - THTensor *output, - int64_t outputHeight, int64_t outputWidth, - int64_t kH, int64_t kW, - int64_t dilationH, int64_t dilationW, - int64_t padH, int64_t padW, - int64_t dH, int64_t dW) { - - THNN_(Col2Im_shapeCheck)(state, input, NULL, outputHeight, outputWidth, - kH, kW, dilationH, dilationW, padH, padW, dH, dW); - - bool batched_input = true; - if (input->dim() == 2) { - // Force batch - batched_input = false; - THTensor_(resize3d)(input, 1, input->size(0), input->size(1)); - } - - long batchSize = input->size(0); - long nInputPlane = input->size(1); - long nOutputPlane = nInputPlane / (kW * kH); - - input = THTensor_(newContiguous)(input); - - THTensor_(resize4d)(output, batchSize, nOutputPlane, outputHeight, outputWidth); - THTensor_(zero)(output); - - THTensor *input_n = THTensor_(new)(); - THTensor *output_n = THTensor_(new)(); - - int64_t height_col = (outputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; - int64_t width_col = (outputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; - - for (int64_t elt = 0; elt < batchSize; elt++) { - THTensor_(select)(input_n, input, 0, elt); - THTensor_(select)(output_n, output, 0, elt); - - THNN_(col2im)( - input_n->data(), - nOutputPlane, - outputHeight, outputWidth, - height_col, width_col, - kH, kW, - padH, padW, - dH, dW, - dilationH, dilationW, output_n->data()); - } - - c10::raw::intrusive_ptr::decref(input_n); - c10::raw::intrusive_ptr::decref(output_n); - - if (!batched_input) { - THTensor_(resize3d)(output, nOutputPlane, outputHeight, outputWidth); - } - c10::raw::intrusive_ptr::decref(input); -} - -void THNN_(Col2Im_updateGradInput)( - THNNState *state, - THTensor *gradOutput, - THTensor *gradInput, - int64_t kH, int64_t kW, - int64_t dilationH, int64_t dilationW, - int64_t padH, int64_t padW, - int64_t dH, int64_t dW) { - - THNN_(Im2Col_updateOutput)(state, gradOutput, gradInput, - kH, kW, dilationH, dilationW, padH, padW, dH, dW); -} - -#endif diff --git a/aten/src/THNN/generic/Im2Col.c b/aten/src/THNN/generic/Im2Col.c deleted file mode 100644 index 6d665134b4d4b..0000000000000 --- a/aten/src/THNN/generic/Im2Col.c +++ /dev/null @@ -1,121 +0,0 @@ -#ifndef TH_GENERIC_FILE -#define TH_GENERIC_FILE "THNN/generic/Im2Col.c" -#else - -#include - -static inline 
void THNN_(Im2Col_shapeCheck)( - THNNState *state, - THTensor *input, - THTensor *gradOutput, - int64_t kH, int64_t kW, int64_t dilationH, int64_t dilationW, - int64_t padH, int64_t padW, int64_t dH, int64_t dW) { - - THArgCheck(kW > 0 && kH > 0, 4, - "kernel size should be greater than zero, but got kH: %d kW: %d", kH, kW); - THArgCheck(dilationW > 0 && dilationH > 0, 6, - "dilation should be greater than zero, but got dilationH: %d dilationW: %d", dilationH, dilationW); - THArgCheck(dW > 0 && dH > 0, 10, - "stride should be greater than zero, but got dH: %d dW: %d", dH, dW); - - int64_t ndim = THTensor_(nDimensionLegacyNoScalars)(input); - THNN_ARGCHECK(!input->is_empty() && (ndim == 3 || ndim == 4), 2, input, - "Expected non-empty 3D or 4D input tensor, but got input of shape %s"); - - int64_t dim_batch = 0; - if (ndim == 3) { - dim_batch = -1; - } - int64_t nInputPlane = THTensor_(size)(input, dim_batch + 1); - int64_t inputHeight = THTensor_(size)(input, dim_batch + 2); - int64_t inputWidth = THTensor_(size)(input, dim_batch + 3); - int64_t outputHeight = div_rtn(inputHeight + 2 * padH - (dilationH * (kH - 1) + 1), dH) + 1; - int64_t outputWidth = div_rtn(inputWidth + 2 * padW - (dilationW * (kW - 1) + 1), dW) + 1; - int64_t nOutputPlane = nInputPlane * kW * kH; - int64_t outputLength = outputHeight * outputWidth; - - if (outputHeight < 1 || outputWidth < 1) { - THError("Given input with spatial size (%d, %d), kernel_size=(%d, %d), " - "dilation=(%d, %d), padding=(%d, %d), calculated " - "shape of the array of sliding blocks as (%d, %d), which is " - "too small (non-positive).", - inputHeight, inputHeight, kH, kW, dilationH, dilationW, padH, padW, - outputHeight, outputWidth); - } -} - -void THNN_(Im2Col_updateOutput)( - THNNState *state, - THTensor *input, - THTensor *output, - int64_t kH, int64_t kW, - int64_t dilationH, int64_t dilationW, - int64_t padH, int64_t padW, - int64_t dH, int64_t dW) { - - THNN_(Im2Col_shapeCheck)(state, input, NULL, kH, kW, dilationH, dilationW, padH, padW, dH, dW); - - input = THTensor_(newContiguous)(input); - bool batched_input = true; - if (input->dim() == 3) { - batched_input = false; - THTensor_(resize4d)(input, 1, input->size(0), input->size(1), input->size(2)); - } - - int64_t batchSize = THTensor_(size)(input, 0); - int64_t nInputPlane = THTensor_(size)(input, 1); - int64_t inputHeight = THTensor_(size)(input, 2); - int64_t inputWidth = THTensor_(size)(input, 3); - - int64_t outputHeight = (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; - int64_t outputWidth = (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; - int64_t nOutputPlane = nInputPlane * kW * kH; - int64_t outputLength = outputHeight * outputWidth; - - THTensor_(resize3d)(output, batchSize, nOutputPlane, outputLength); - THTensor_(zero)(output); - - THTensor *input_n = THTensor_(new)(); - THTensor *output_n = THTensor_(new)(); - - for (int64_t elt = 0; elt < batchSize; elt++) { - THTensor_(select)(input_n, input, 0, elt); - THTensor_(select)(output_n, output, 0, elt); - - THNN_(im2col)( - input_n->data(), - nInputPlane, - inputHeight, inputWidth, - outputHeight, outputWidth, - kH, kW, padH, padW, dH, dW, - dilationH, dilationW, output_n->data()); - } - - c10::raw::intrusive_ptr::decref(input_n); - c10::raw::intrusive_ptr::decref(output_n); - - if (!batched_input) { - THTensor_(resize2d)(output, nOutputPlane, outputLength); - } - c10::raw::intrusive_ptr::decref(input); -} - -void THNN_(Im2Col_updateGradInput)( - THNNState *state, - THTensor *gradOutput, - 
THTensor *gradInput, - int64_t inputHeight, int64_t inputWidth, - int64_t kH, int64_t kW, - int64_t dilationH, int64_t dilationW, - int64_t padH, int64_t padW, - int64_t dH, int64_t dW) { - - - THNN_(Col2Im_updateOutput)(state, gradOutput, gradInput, - inputHeight, inputWidth, - kH, kW, dilationH, dilationW, - padH, padW, dH, dW); -} - - -#endif diff --git a/aten/src/THNN/generic/SpatialDilatedConvolution.c b/aten/src/THNN/generic/SpatialDilatedConvolution.c index 4ad1ffec16554..82d958b02f0b4 100644 --- a/aten/src/THNN/generic/SpatialDilatedConvolution.c +++ b/aten/src/THNN/generic/SpatialDilatedConvolution.c @@ -164,7 +164,7 @@ void THNN_(SpatialDilatedConvolution_updateOutput)( } // Extract columns: - THNN_(im2col)( + at::native::im2col( input_n->data(), nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth, @@ -280,7 +280,7 @@ void THNN_(SpatialDilatedConvolution_updateGradInput)( ); // Unpack columns back into input: - THNN_(col2im)( + at::native::col2im( gradColumns->data(), nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth, kH, kW, padH, padW, dH, dW, @@ -373,7 +373,7 @@ void THNN_(SpatialDilatedConvolution_accGradParameters)( THTensor_(select)(input_n, input, 0, elt); // Extract columns: - THNN_(im2col)( + at::native::im2col( input_n->data(), nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth, diff --git a/aten/src/THNN/generic/SpatialFullDilatedConvolution.c b/aten/src/THNN/generic/SpatialFullDilatedConvolution.c index 89021befa3ef1..0ecff5d71f879 100644 --- a/aten/src/THNN/generic/SpatialFullDilatedConvolution.c +++ b/aten/src/THNN/generic/SpatialFullDilatedConvolution.c @@ -1,3 +1,5 @@ +#include + #ifndef TH_GENERIC_FILE #define TH_GENERIC_FILE "THNN/generic/SpatialFullDilatedConvolution.c" #else @@ -161,7 +163,7 @@ void THNN_(SpatialFullDilatedConvolution_updateOutput)( ); // Unpack columns back into input: - THNN_(col2im)( + at::native::col2im( columns->data(), nOutputPlane, outputHeight, outputWidth, inputHeight, inputWidth, kH, kW, padH, padW, dH, dW, dilationH, dilationW, @@ -264,7 +266,7 @@ void THNN_(SpatialFullDilatedConvolution_updateGradInput)( THTensor_(select)(gradOutput_n, gradOutput, 0, elt); // Extract columns: - THNN_(im2col)( + at::native::im2col( gradOutput_n->data(), nOutputPlane, outputHeight, outputWidth, inputHeight, inputWidth, @@ -390,7 +392,7 @@ void THNN_(SpatialFullDilatedConvolution_accGradParameters)( THTensor_(select)(input_n, input, 0, elt); // Extract columns: - THNN_(im2col)( + at::native::im2col( gradOutput_n->data(), nOutputPlane, outputHeight, outputWidth, inputHeight, inputWidth, diff --git a/aten/src/THNN/generic/THNN.h b/aten/src/THNN/generic/THNN.h index 1700b0eac80a4..6cde3bc828bb7 100644 --- a/aten/src/THNN/generic/THNN.h +++ b/aten/src/THNN/generic/THNN.h @@ -102,44 +102,6 @@ TH_API void THNN_(HardTanh_updateGradInput)( accreal max_val, // upper threshold bool inplace); -TH_API void THNN_(Im2Col_updateOutput)( - THNNState *state, - THTensor *input, - THTensor *output, - int64_t kH, int64_t kW, - int64_t dilationH, int64_t dilationW, - int64_t padH, int64_t padW, - int64_t dH, int64_t dW); - -TH_API void THNN_(Im2Col_updateGradInput)( - THNNState *state, - THTensor *gradOutput, - THTensor *gradInput, - int64_t isizeH, int64_t isizeW, - int64_t kH, int64_t kW, - int64_t dilationH, int64_t dilationW, - int64_t padH, int64_t padW, - int64_t dH, int64_t dW); - -TH_API void THNN_(Col2Im_updateOutput)( - THNNState *state, - THTensor *input, - THTensor *output, - int64_t outputHeight, int64_t outputWidth, - 
int64_t kH, int64_t kW, - int64_t dilationH, int64_t dilationW, - int64_t padH, int64_t padW, - int64_t dH, int64_t dW); - -TH_API void THNN_(Col2Im_updateGradInput)( - THNNState *state, - THTensor *gradOutput, - THTensor *gradInput, - int64_t kH, int64_t kW, - int64_t dilationH, int64_t dilationW, - int64_t padH, int64_t padW, - int64_t dH, int64_t dW); - TH_API void THNN_(LeakyReLU_updateOutput)( THNNState *state, // library's state THTensor *input, // [MODIFIED] input tensor diff --git a/aten/src/THNN/generic/VolumetricDilatedConvolution.c b/aten/src/THNN/generic/VolumetricDilatedConvolution.c index ded13f9c190fc..ae79b663d8381 100644 --- a/aten/src/THNN/generic/VolumetricDilatedConvolution.c +++ b/aten/src/THNN/generic/VolumetricDilatedConvolution.c @@ -173,7 +173,7 @@ void THNN_(VolumetricDilatedConvolution_updateOutput)( } // Extract columns: - THNN_(vol2col)( + at::native::vol2col( input_n->data(), nInputPlane, inputDepth, inputHeight, inputWidth, outputDepth, outputHeight, outputWidth, @@ -292,7 +292,7 @@ void THNN_(VolumetricDilatedConvolution_updateGradInput)( ); // Unpack columns back into input: - THNN_(col2vol)( + at::native::col2vol( gradColumns->data(), nInputPlane, inputDepth, inputHeight, inputWidth, outputDepth, outputHeight, outputWidth, @@ -395,7 +395,7 @@ void THNN_(VolumetricDilatedConvolution_accGradParameters)( THTensor_(select)(input_n, input, 0, elt); // Extract columns: - THNN_(vol2col)( + at::native::vol2col( input_n->data(), nInputPlane, inputDepth, inputHeight, inputWidth, outputDepth, outputHeight, outputWidth, diff --git a/aten/src/THNN/generic/VolumetricFullDilatedConvolution.c b/aten/src/THNN/generic/VolumetricFullDilatedConvolution.c index 058dbf2d614b1..8d2f97551925d 100644 --- a/aten/src/THNN/generic/VolumetricFullDilatedConvolution.c +++ b/aten/src/THNN/generic/VolumetricFullDilatedConvolution.c @@ -1,89 +1,9 @@ +#include + #ifndef TH_GENERIC_FILE #define TH_GENERIC_FILE "THNN/generic/VolumetricFullDilatedConvolution.c" #else -static void THNN_(vol2col)( - const scalar_t *data_vol, const int64_t channels, - const int64_t depth, const int64_t height, const int64_t width, - const int64_t depth_col, const int64_t height_col, const int64_t width_col, - const int64_t kT, const int64_t kH, const int64_t kW, - const int64_t pT, const int64_t pH, const int64_t pW, - const int64_t dT, const int64_t dH, const int64_t dW, - const int64_t dilationT, const int64_t dilationH, const int64_t dilationW, - scalar_t *data_col) -{ - int64_t c, t, h, w; - int64_t channels_col = channels * kT * kH * kW; - for (c = 0; c < channels_col; ++c) - { - int64_t w_offset = c % kW; - int64_t h_offset = (c / kW) % kH; - int64_t t_offset = (c / kW / kH) % kT; - int64_t c_vol = c / kT / kH / kW; - for (t = 0; t < depth_col; ++t) - { - int64_t t_pad = t * dT - pT + t_offset * dilationT; - for (h = 0; h < height_col; ++h) - { - int64_t h_pad = h * dH - pH + h_offset * dilationH; - for (w = 0; w < width_col; ++w) - { - int64_t w_pad = w * dW - pW + w_offset * dilationW; - if (t_pad >= 0 && t_pad < depth && - h_pad >= 0 && h_pad < height && - w_pad >= 0 && w_pad < width) - data_col[((c * depth_col + t) * height_col + h) * width_col + w] = - data_vol[((c_vol * depth + t_pad) * height + h_pad) * width + w_pad]; - else - data_col[((c * depth_col + t) * height_col + h) * width_col + w] = 0; - } - } - } - } -} - -static void THNN_(col2vol)( - const scalar_t* data_col, const int64_t channels, - const int64_t depth, const int64_t height, const int64_t width, - const int64_t out_depth, const 
int64_t out_height, const int64_t out_width, - const int64_t kT, const int64_t kH, const int64_t kW, - const int64_t pT, const int64_t pH, const int64_t pW, - const int64_t dT, const int64_t dH, const int64_t dW, - const int64_t dilationT, const int64_t dilationH, const int64_t dilationW, - scalar_t* data_vol) -{ - int64_t c, t, h, w; - memset(data_vol, 0, sizeof(scalar_t) * depth * height * width * channels); - int64_t depth_col = out_depth; - int64_t height_col = out_height; - int64_t width_col = out_width; - int64_t channels_col = channels * kT * kH * kW; - for (c = 0; c < channels_col; ++c) - { - int64_t w_offset = c % kW; - int64_t h_offset = (c / kW) % kH; - int64_t t_offset = (c / kW / kH) % kT; - int64_t c_vol = c / kT / kH / kW; - for (t = 0; t < depth_col; ++t) - { - int64_t t_pad = t * dT - pT + t_offset * dilationT; - for (h = 0; h < height_col; ++h) - { - int64_t h_pad = h * dH - pH + h_offset * dilationH; - for (w = 0; w < width_col; ++w) - { - int64_t w_pad = w * dW - pW + w_offset * dilationW; - if (t_pad >= 0 && t_pad < depth && - h_pad >= 0 && h_pad < height && - w_pad >= 0 && w_pad < width) - data_vol[((c_vol * depth + t_pad) * height + h_pad) * width + w_pad] += - data_col[((c * depth_col + t) * height_col + h) * width_col + w]; - } - } - } - } -} - static inline void THNN_(VolumetricFullDilatedConvolution_shapeCheck)( THTensor *input, THTensor *gradOutput, THTensor *weight, THTensor *bias, @@ -255,7 +175,7 @@ void THNN_(VolumetricFullDilatedConvolution_updateOutput)( ); // Unpack columns back into input: - THNN_(col2vol)( + at::native::col2vol( columns->data(), nOutputPlane, outputDepth, outputHeight, outputWidth, inputDepth, inputHeight, inputWidth, @@ -370,7 +290,7 @@ void THNN_(VolumetricFullDilatedConvolution_updateGradInput)( THTensor_(select)(gradOutput_n, gradOutput, 0, elt); // Extract columns: - THNN_(vol2col)( + at::native::vol2col( gradOutput_n->data(), nOutputPlane, outputDepth, outputHeight, outputWidth, inputDepth, inputHeight, inputWidth, @@ -506,7 +426,7 @@ void THNN_(VolumetricFullDilatedConvolution_accGradParameters)( THTensor_(select)(input_n, input, 0, elt); // Extract columns: - THNN_(vol2col)( + at::native::vol2col( gradOutput_n->data(), nOutputPlane, outputDepth, outputHeight, outputWidth, inputDepth, inputHeight, inputWidth, diff --git a/aten/src/THNN/init.cpp b/aten/src/THNN/init.cpp index 5f7cbe9e3e2f3..679a4ec91ae85 100644 --- a/aten/src/THNN/init.cpp +++ b/aten/src/THNN/init.cpp @@ -70,18 +70,12 @@ #include #include -#include -#include - #include #include #include #include -#include -#include - #include #include diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index fc826259c3126..4d52c5e9784b6 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -1221,11 +1221,11 @@ - name: thnn_conv_dilated3d_backward(Tensor grad_output, Tensor self, Tensor weight, int[3] kernel_size, int[3] stride, int[3] padding, int[3] dilation, Tensor columns, Tensor ones, bool[3] output_mask) -> (Tensor grad_input, Tensor grad_weight, Tensor grad_bias) grad_output, self, weight: _convolution_double_backward(grads[0], grads[1], grads[2], grad_output, weight, self, stride, padding, dilation, false, {{0, 0, 0}}, 1, false, false, false, grad_input_mask) -- name: thnn_col2im(Tensor self, int[2] output_size, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride) -> Tensor - self: thnn_col2im_backward(grad, kernel_size, dilation, padding, stride) +- name: col2im(Tensor self, int[2] output_size, 
int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride) -> Tensor + self: col2im_backward(grad, kernel_size, dilation, padding, stride) -- name: thnn_im2col(Tensor self, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride) -> Tensor - self: thnn_im2col_backward(grad, {self.size(2), self.size(3)}, kernel_size, dilation, padding, stride) +- name: im2col(Tensor self, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride) -> Tensor + self: im2col_backward(grad, {self.size(2), self.size(3)}, kernel_size, dilation, padding, stride) # NN double backwards support diff --git a/torch/nn/functional.py b/torch/nn/functional.py index 67b1f37af8225..9d971f72fc520 100644 --- a/torch/nn/functional.py +++ b/torch/nn/functional.py @@ -3028,8 +3028,8 @@ def unfold(input, kernel_size, dilation=1, padding=0, stride=1): assert_int_or_pair(padding, 'padding', msg) assert_int_or_pair(stride, 'stride', msg) - ret = torch._C._nn.thnn_im2col(input, _pair(kernel_size), - _pair(dilation), _pair(padding), _pair(stride)) + ret = torch._C._nn.im2col(input, _pair(kernel_size), + _pair(dilation), _pair(padding), _pair(stride)) else: raise NotImplementedError("Input Error: Only 4D input Tensors are supported (got {}D)".format(input.dim())) ret = input # TODO: remove when jit supports exception control flow @@ -3056,8 +3056,8 @@ def fold(input, output_size, kernel_size, dilation=1, padding=0, stride=1): assert_int_or_pair(padding, 'padding', msg) assert_int_or_pair(stride, 'stride', msg) - ret = torch._C._nn.thnn_col2im(input, _pair(output_size), _pair(kernel_size), - _pair(dilation), _pair(padding), _pair(stride)) + ret = torch._C._nn.col2im(input, _pair(output_size), _pair(kernel_size), + _pair(dilation), _pair(padding), _pair(stride)) else: raise NotImplementedError("Input Error: Only 3D input Tensors are supported (got {}D)".format(input.dim())) ret = input # TODO: remove when jit supports exception control flow
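
At the Python level, nothing changes except which native functions F.unfold and F.fold dispatch to (im2col/col2im instead of thnn_im2col/thnn_col2im). A quick sanity check of the round-trip semantics these ops implement: fold sums overlapping patches rather than averaging them, so dividing by the per-pixel overlap count recovers the input.

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 4, 4)
cols = F.unfold(x, kernel_size=2, stride=1)           # (1, 3*2*2, 9) columns
y = F.fold(cols, output_size=(4, 4), kernel_size=2, stride=1)

# fold() accumulates; folding a tensor of ones yields each pixel's overlap
# count, i.e. how many sliding blocks cover it.
counts = F.fold(torch.ones_like(cols), output_size=(4, 4),
                kernel_size=2, stride=1)
assert torch.allclose(x, y / counts)
```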