[4/N] Avoid copy in std::get (pytorch#142285)
Fixes #ISSUE_NUMBER

Pull Request resolved: pytorch#142285
Approved by: https://github.com/Skylion007

Co-authored-by: Aaron Gokaslan <[email protected]>
2 people authored and pytorchmergebot committed Dec 9, 2024
1 parent 2cc01cc commit a108b28
Showing 17 changed files with 43 additions and 60 deletions.
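
Every hunk below applies the same pattern: unpack tuple returns with structured bindings instead of repeated std::get calls on a named temporary, drop const where it would block a move, and pass the pieces on through std::move so tuple and Tensor elements are moved rather than copied. A minimal, self-contained sketch of the before/after shape (illustrative only; produce(), before() and after() are made-up names, not PyTorch code):

#include <string>
#include <tuple>
#include <utility>

std::tuple<std::string, std::string> produce() {
  return {std::string(1 << 20, 'a'), std::string(1 << 20, 'b')};
}

// Before: std::get on an lvalue tuple yields an lvalue reference, so binding it
// to a value copies the element, and the const locals force a second copy when
// they are packed into the returned tuple.
std::tuple<std::string, std::string> before() {
  const auto result = produce();
  const auto first = std::get<0>(result);   // copy
  const auto second = std::get<1>(result);  // copy
  return std::make_tuple(first, second);    // copies again; const rules out a move
}

// After: structured bindings name the elements of the materialized tuple without
// copying them, and std::move lets make_tuple steal the buffers instead.
std::tuple<std::string, std::string> after() {
  auto [first, second] = produce();
  return std::make_tuple(std::move(first), std::move(second));
}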
2 changes: 1 addition & 1 deletion aten/src/ATen/core/boxing/KernelFunction_test.cpp
@@ -353,7 +353,7 @@ void expectOutOfPlaceMultiUnboxedCallingWorks(const KernelFunction& func) {
auto t1 = at::zeros({1});
auto t2 = at::zeros({1});

- auto [t1_out, t2_out] = func.call<
+ const auto [t1_out, t2_out] = func.call<
std::tuple<at::Tensor&, at::Tensor&>, at::Scalar, at::Scalar, at::Tensor&, at::Tensor&
>(dummy, CPU_TEST_SET, s1, s2, t1, t2);

29 changes: 14 additions & 15 deletions aten/src/ATen/functorch/BatchRulesConvolution.cpp
@@ -181,8 +181,8 @@ convolution_backward_input_batch_rule(
const auto result = at::convolution_backward_symint(
grad_output_, dummy_input, weight_, std::nullopt, stride, padding,
dilation, transposed, output_padding, groups * batch_size, mask);
- const auto grad_input = reshape_dim_outof(1, batch_size, std::get<0>(result));
- return std::make_tuple(grad_input, 1);
+ auto grad_input = reshape_dim_outof(1, batch_size, std::get<0>(result));
+ return std::make_tuple(std::move(grad_input), 1);
} else if (grad_output_bdim && !weight_bdim) {
// BNO, OI -> (BN)O, OI -> (BN)I
// transposed is the same.
@@ -192,8 +192,8 @@ convolution_backward_input_batch_rule(
const auto result = at::convolution_backward_symint(
grad_output_, dummy_input, weight, std::nullopt, stride, padding,
dilation, transposed, output_padding, groups, mask);
- const auto grad_input = reshape_dim_outof(0, batch_size, std::get<0>(result));
- return std::make_tuple(grad_input, 0);
+ auto grad_input = reshape_dim_outof(0, batch_size, std::get<0>(result));
+ return std::make_tuple(std::move(grad_input), 0);
} else if (!grad_output_bdim && weight_bdim) {
const auto batch_size = weight.size(*weight_bdim);
if (groups == 1) {
@@ -359,7 +359,6 @@ static std::tuple<Tensor,Tensor,Tensor> convolution_backward_plumbing(
const Tensor& grad_output_, const Tensor& input_, const Tensor& weight_,
const c10::OptionalArrayRef<SymInt> bias_sizes_opt,
c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed,
- // NOLINTNEXTLINE(performance-unnecessary-value-param)
c10::SymIntArrayRef output_padding, c10::SymInt groups, std::array<bool, 3> output_mask) {
const auto maybe_layer = maybeCurrentDynamicLayer();
vmap_check_escaped(maybe_layer, "convolution_backward_plumbing");
@@ -369,14 +368,14 @@ static std::tuple<Tensor,Tensor,Tensor> convolution_backward_plumbing(
c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
return at::convolution_backward_symint(
grad_output_, input_, weight_, bias_sizes_opt, stride, padding,
- dilation, transposed, output_padding, groups, output_mask);
+ dilation, transposed, output_padding, std::move(groups), output_mask);
}

auto [grad_output, grad_output_bdim] = unwrapTensorAtLevel(grad_output_, cur_level);
auto [input, input_bdim] = unwrapTensorAtLevel(input_, cur_level);
auto [weight, weight_bdim] = unwrapTensorAtLevel(weight_, cur_level);

- const auto grad_bias = compute_grad_bias(grad_output_, output_mask);
+ auto grad_bias = compute_grad_bias(grad_output_, output_mask);
output_mask[2] = false;

// TODO: A little bird says that unfold + matmul is actually faster than
@@ -408,14 +407,14 @@ static std::tuple<Tensor,Tensor,Tensor> convolution_backward_plumbing(
grad_output, input, weight, std::nullopt, stride, padding, dilation,
transposed, output_padding, batch_size * groups, output_mask);
// N(BI), (BO)I -> NBI, BOI
- const auto grad_input = output_mask[0] ?
+ auto grad_input = output_mask[0] ?
reshape_dim_outof(1, batch_size, std::get<0>(result)) : Tensor();
- const auto grad_weight = output_mask[1] ?
+ auto grad_weight = output_mask[1] ?
reshape_dim_outof(0, batch_size, std::get<1>(result)) : Tensor();
return std::make_tuple(
- output_mask[0] ? makeBatched(grad_input, 1, cur_level) : grad_input,
- output_mask[1] ? makeBatched(grad_weight, 0, cur_level) : grad_weight,
- grad_bias);
+ output_mask[0] ? makeBatched(std::move(grad_input), 1, cur_level) : std::move(grad_input),
+ output_mask[1] ? makeBatched(std::move(grad_weight), 0, cur_level) : std::move(grad_weight),
+ std::move(grad_bias));
}

Tensor grad_input;
@@ -426,7 +425,7 @@ static std::tuple<Tensor,Tensor,Tensor> convolution_backward_plumbing(
input, input_bdim,
weight, weight_bdim,
stride, padding, dilation, transposed, output_padding, groups);
- grad_input = makeBatched(tensor, bdim, cur_level);
+ grad_input = makeBatched(std::move(tensor), bdim, cur_level);
}

Tensor grad_weight;
@@ -437,9 +436,9 @@ static std::tuple<Tensor,Tensor,Tensor> convolution_backward_plumbing(
input, input_bdim,
weight, weight_bdim,
stride, padding, dilation, transposed, output_padding, groups);
- grad_weight = makeBatched(tensor, bdim, cur_level);
+ grad_weight = makeBatched(std::move(tensor), bdim, cur_level);
}
- return std::make_tuple(grad_input, grad_weight, grad_bias);
+ return std::make_tuple(std::move(grad_input), std::move(grad_weight), std::move(grad_bias));

// Someone's definitely going to find a problem with this batching rule so
// I'm leaving the following fallback if we need it back.
2 changes: 1 addition & 1 deletion aten/src/ATen/functorch/BatchRulesNorm.cpp
@@ -121,7 +121,7 @@ batch_norm_batch_rule(
result0 = result0 + bias_;
}
result0 = result0.transpose(1, 2); // [B0, B, C, *], because some arg must have been batched, the output must be batched
- return std::make_tuple(result0, 0, mean, stats_bdim, rstd, stats_bdim);
+ return std::make_tuple(std::move(result0), 0, std::move(mean), stats_bdim, std::move(rstd), stats_bdim);
}

template<typename F, F Func>
3 changes: 1 addition & 2 deletions aten/src/ATen/native/RNN.cpp
@@ -727,8 +727,7 @@ struct LSTMCell : Cell<std::tuple<Tensor, Tensor>, cell_params> {
const hidden_type& hidden,
const cell_params& params,
bool pre_compute_input = false) const override {
- const auto& hx = std::get<0>(hidden);
- const auto& cx = std::get<1>(hidden);
+ const auto& [hx, cx] = hidden;

if (input.is_cuda() || input.is_xpu() || input.is_privateuseone()) {
TORCH_CHECK(!pre_compute_input);
5 changes: 1 addition & 4 deletions aten/src/ATen/native/cpu/GridSamplerKernel.cpp
@@ -1052,10 +1052,7 @@ static inline void grid_sample_2d_grid_slice_iterator(
std::min(step, len * 2));
auto vec2 = Vec::loadu(grid_ptr + grid_offset + step,
std::max(static_cast<int64_t>(0), len * 2 - step));
- auto vec_xy_pair = deinterleave2(vec1, vec2);
-
- auto x = std::get<0>(vec_xy_pair);
- auto y = std::get<1>(vec_xy_pair);
+ auto [x, y] = deinterleave2(vec1, vec2);

// make sure that x and y are valid grid sample locations
if (len < step) {
7 changes: 2 additions & 5 deletions aten/src/ATen/native/cuda/DistributionTemplates.h
@@ -67,13 +67,10 @@ __global__ void distribution_elementwise_grid_stride_kernel(int64_t numel,
PhiloxCudaState philox_args,
const dist_t dist_func,
const transform_t transform_func) {
- auto seeds = at::cuda::philox::unpack(philox_args);
+ auto [seed, offset] = at::cuda::philox::unpack(philox_args);
int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
curandStatePhilox4_32_10_t state;
- curand_init(std::get<0>(seeds),
-             idx,
-             std::get<1>(seeds),
-             &state);
+ curand_init(seed, idx, offset, &state);

int64_t rounded_size = ((numel - 1)/(blockDim.x * gridDim.x * unroll_factor)+1) *
blockDim.x * gridDim.x * unroll_factor;
14 changes: 4 additions & 10 deletions aten/src/ATen/native/cuda/Dropout.cu
@@ -56,13 +56,10 @@ fused_dropout_kernel_vec(at::cuda::detail::TensorInfo<const scalar_t, IndexType>
using LoadT = memory::aligned_vector<scalar_t, VEC>;
using MaskLoadT = memory::aligned_vector<mask_t, VEC>;

- auto seeds = at::cuda::philox::unpack(philox_args);
+ auto [seed, offset] = at::cuda::philox::unpack(philox_args);
IndexType idx = blockIdx.x * blockDim.x + threadIdx.x;
curandStatePhilox4_32_10_t state;
- curand_init(std::get<0>(seeds),
-             idx,
-             std::get<1>(seeds),
-             &state);
+ curand_init(seed, idx, offset, &state);

// Helps align the total number of times curand_uniform4 is called by each thread for the same totalElements
// in the vec=2 and vec=4 cases.
@@ -138,13 +135,10 @@ fused_dropout_kernel(cuda::detail::TensorInfo<const scalar_t, IndexType> a,
cuda::detail::TensorInfo<mask_t, IndexType> c,
IndexType totalElements, accscalar_t p,
PhiloxCudaState philox_args) {
- auto seeds = at::cuda::philox::unpack(philox_args);
+ auto [seed, offset] = at::cuda::philox::unpack(philox_args);
IndexType idx = blockIdx.x * blockDim.x + threadIdx.x;
curandStatePhilox4_32_10_t state;
- curand_init(std::get<0>(seeds),
-             idx,
-             std::get<1>(seeds),
-             &state);
+ curand_init(seed, idx, offset, &state);
accscalar_t scale = 1.0 / p;

IndexType rounded_size = ((totalElements - 1)/(blockDim.x * gridDim.x * UNROLL)+1) *
6 changes: 3 additions & 3 deletions aten/src/ATen/native/cuda/IndexKernel.cpp
@@ -39,10 +39,10 @@ static Tensor & masked_select_out_cuda_impl(Tensor & result, const Tensor & self
// Cannot reassign to mask_temp and self_temp here! if they are
// owning and expand_outplace returns a borrow, the returned borrow
// would dangle.
- auto mask_self_expanded = expand_outplace(*mask_temp, *self_temp);
+ auto [mask_expanded, self_expanded] = expand_outplace(*mask_temp, *self_temp);
at::cuda::index_out(
- result, *std::get<1>(mask_self_expanded),
- c10::List<std::optional<at::Tensor>>({*std::move(std::get<0>(mask_self_expanded))}));
+ result, *self_expanded,
+ c10::List<std::optional<at::Tensor>>({*std::move(mask_expanded)}));

return result;
}
2 changes: 1 addition & 1 deletion aten/src/ATen/native/cuda/Randperm.cuh
@@ -25,7 +25,7 @@ __global__ void randperm_handle_duplicate_keys_kernel(T *keys, scalar_t *data, T

// do random permutation inside each island.
data += tid;
- auto [seed, offset] = at::cuda::philox::unpack(philox_args);
+ const auto [seed, offset] = at::cuda::philox::unpack(philox_args);
curandStatePhilox4_32_10_t state;
curand_init(seed, tid, offset, &state);
for (int i = island_size - 1; i > 0; i--) {
6 changes: 1 addition & 5 deletions aten/src/ATen/native/cuda/RreluWithNoise.cu
@@ -81,11 +81,7 @@ inline void _rrelu_with_noise_cuda_train(

int64_t numel = input.numel();
const int unroll_factor = std::is_same_v<scalar_t, double> ? 2 : 4;
- auto execution_policy = calc_execution_policy(numel, unroll_factor);
-
- auto counter_offset = std::get<0>(execution_policy);
- auto grid = std::get<1>(execution_policy);
- auto block = std::get<2>(execution_policy);
+ auto [counter_offset, grid, block] = calc_execution_policy(numel, unroll_factor);

auto gen = get_generator_or_default<CUDAGeneratorImpl>(
generator, cuda::detail::getDefaultCUDAGenerator());
10 changes: 6 additions & 4 deletions aten/src/ATen/native/cudnn/RNN.cpp
@@ -2560,9 +2560,10 @@ std::pair<Tensor, hidden_type> _cudnn_impl(
dropout_state.buffer);

return {
- std::get<0>(cudnn_output),
+ std::move(std::get<0>(cudnn_output)),
pack_hidden<hidden_type>(
- std::get<1>(cudnn_output), std::get<2>(cudnn_output))};
+ std::move(std::get<1>(cudnn_output)),
+ std::move(std::get<2>(cudnn_output)))};
}

template <typename hidden_type>
@@ -2621,9 +2622,10 @@ std::pair<Tensor, hidden_type> _cudnn_impl(
dropout_state.buffer);

return {
- std::get<0>(cudnn_output),
+ std::move(std::get<0>(cudnn_output)),
pack_hidden<hidden_type>(
- std::get<1>(cudnn_output), std::get<2>(cudnn_output))};
+ std::move(std::get<1>(cudnn_output)),
+ std::move(std::get<2>(cudnn_output)))};
}

#define ONE_HIDDEN_RNN(NAME, MODE) \
1 change: 0 additions & 1 deletion aten/src/ATen/native/quantized/TensorAdvancedIndexing.cpp
@@ -144,7 +144,6 @@ Tensor& _index_put_impl_quantized_cpu_(Tensor & self, const torch::List<std::opt
value_ = value.to(self.device());
}
at::assert_no_overlap(self, value);
- // NOLINTNEXTLINE(performance-implicit-conversion-in-loop)
for (const std::optional<Tensor>& index: indices) {
if (index.has_value()) {
at::assert_no_overlap(self, *index);
2 changes: 1 addition & 1 deletion aten/src/ATen/native/quantized/cpu/fbgemm_utils.cpp
@@ -400,7 +400,7 @@ register_conv_params() {
},
// __setstate__ takes c10::IValue because we support parsing historical
// serialization versions.
- [](c10::IValue v)
+ [](const c10::IValue& v)
-> c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> { // __setstate__
ConvParamsSerializationTypeV3 state = parse_conv_serialized_state<kSpatialDim>(v);
return deserialize_conv<kSpatialDim>(state);
@@ -43,16 +43,16 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
constexpr int kHeadDim = Kernel_traits::kHeadDim;
constexpr int kNWarps = Kernel_traits::kNWarps;

- auto seed_offset = at::cuda::philox::unpack(params.philox_args);
- pytorch_flash::Dropout dropout(std::get<0>(seed_offset), std::get<1>(seed_offset), params.p_dropout_in_uint8_t,
+ auto [seed, offset] = at::cuda::philox::unpack(params.philox_args);
+ pytorch_flash::Dropout dropout(seed, offset, params.p_dropout_in_uint8_t,
bidb, bidh, tidx, params.h);

// Save seed and offset for backward. If we don't have this here, the 0-th thread block might
// exit early and no one saves the rng state.
if (Is_dropout && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && tidx == 0) {
if (params.philox_args.captured_) {
- *params.seed = std::get<0>(seed_offset);
- *params.extragraph_offset = std::get<1>(seed_offset);
+ *params.seed = seed;
+ *params.extragraph_offset = offset;
}
}

2 changes: 1 addition & 1 deletion c10/cuda/CUDACachingAllocator.h
@@ -29,7 +29,7 @@ class C10_CUDA_API FreeMemoryCallback {

C10_DECLARE_REGISTRY(FreeCudaMemoryCallbacksRegistry, FreeMemoryCallback);
#define REGISTER_FREE_MEMORY_CALLBACK(name, ...) \
- C10_REGISTER_CLASS(FreeCudaMemoryCallbacksRegistry, name, __VA_ARGS__);
+ C10_REGISTER_CLASS(FreeCudaMemoryCallbacksRegistry, name, __VA_ARGS__)
} // namespace c10
//
// TODO: Turn this into an honest to goodness class. I briefly attempted to do
2 changes: 1 addition & 1 deletion torch/csrc/CudaIPCTypes.cpp
@@ -261,6 +261,6 @@ bool CudaIPCCollect() {

namespace c10 {
namespace {
- REGISTER_FREE_MEMORY_CALLBACK("cuda_ipc_collect", CudaIPCCollectCallback);
+ REGISTER_FREE_MEMORY_CALLBACK("cuda_ipc_collect", CudaIPCCollectCallback)
}
} // namespace c10
2 changes: 1 addition & 1 deletion torch/csrc/jit/cuda/cuda.h
@@ -174,6 +174,6 @@ TORCH_LIBRARY(cuda, m) {
.def("record", &CUDAEvent::record)
.def("synchronize", &CUDAEvent::synchronize)
.def("wait", &CUDAEvent::wait);
};
}

} // namespace torch::jit
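
The last three files (CUDACachingAllocator.h, CudaIPCTypes.cpp, cuda.h) are not about std::get at all; they drop redundant trailing semicolons of the kind that -Wextra-semi and clang-tidy flag when a macro expansion already forms a complete statement, or when a closing brace already ends a function body. A hypothetical macro, not the real c10 registration machinery, sketches the issue under that assumption:

#include <cstdio>

int register_thing(const char* name) {
  std::printf("registered %s\n", name);
  return 0;
}

// Hypothetical macro: its expansion already ends in ';', so it is a complete
// declaration on its own.
#define REGISTER_THING(name) \
  static const int registered_##name = register_thing(#name);

REGISTER_THING(example)    // one semicolon total, supplied by the macro
// REGISTER_THING(other);  // would expand to "...;;" -- the stray empty
//                         // declaration is what such cleanups remove

int main() {
  return 0;
}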
