[release/2.3] Enable bf16 with fp32 weights for MIOpen batchnorm #1666

Draft: wants to merge 7 commits into release/2.3
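For context, this PR enables the mixed-precision layout that autocast produces: bf16 activations flowing through a BatchNorm whose weight, bias, and running statistics stay in fp32. A minimal sketch of that scenario, assuming a ROCm build of PyTorch (where device="cuda" maps to HIP and MIOpen is compiled in) and that the gating below selects the MIOpen path:

import torch
import torch.nn as nn

# Parameters and running stats are created in fp32 and stay fp32.
bn = nn.BatchNorm2d(64).cuda()

# Activations arrive in bf16, as they would after a conv layer under autocast.
x = torch.randn(8, 64, 32, 32, device="cuda", dtype=torch.bfloat16)

y = bn(x)  # with this PR, eligible bf16 inputs can dispatch to miopen_batch_norm
print(y.dtype, bn.weight.dtype)  # expected: torch.bfloat16 torch.float32

Previously, the MIOpen gating in _batch_norm_impl_index rejected bf16 inputs outright and fell back to the native kernels.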
56 changes: 55 additions & 1 deletion aten/src/ATen/native/Normalization.cpp
@@ -56,6 +56,7 @@
#include <c10/core/SymIntArrayRef.h>
#include <utility>
#include <vector>
#include <iostream>

static const int MIOPEN_DIM_MAX = 5;

@@ -479,6 +480,8 @@ std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_cpu_template(
return std::make_tuple(grad_input, grad_weight, grad_bias);
}

bool PYTORCH_MIOPEN_EXTRA_LOGGING = c10::utils::check_env("PYTORCH_MIOPEN_EXTRA_LOGGING").value_or(false);

// _batch_norm_impl_index(_backward) are used in the JIT to be able to keep the run-time selection
// of backends, while enabling it to keep the information about the used backend, so that it can
// use its corresponding backward implementation.
@@ -487,6 +490,20 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, int64_t> _batch_norm_impl_index(
const Tensor& input, const c10::optional<Tensor>& weight_opt /* optional */, const c10::optional<Tensor>& bias_opt /* optional */, const c10::optional<Tensor>& running_mean_opt /* optional */, const c10::optional<Tensor>& running_var_opt /* optional */,
bool training, double momentum, double eps, bool cudnn_enabled) {
// See [Note: hacky wrapper removal for optional tensor]
if (PYTORCH_MIOPEN_EXTRA_LOGGING)
std::cout
<< "PYTORCH_MIOPEN_EXTRA_LOGGING: ********************* _batch_norm_impl_index"
<< " input=" << input.scalar_type()
<< " weight=" << (weight_opt.has_value() ? weight_opt.value().scalar_type() : at::ScalarType::Undefined)
<< " bias=" << (bias_opt.has_value() ? bias_opt.value().scalar_type() : at::ScalarType::Undefined)
<< " running_mean=" << (running_mean_opt.has_value() ? running_mean_opt.value().scalar_type() : at::ScalarType::Undefined)
<< " running_var=" << (running_var_opt.has_value() ? running_var_opt.value().scalar_type() : at::ScalarType::Undefined)
<< " training=" << training
// << " momentum=" << momentum
// << " eps=" << eps
<< " cudnn_enabled=" << cudnn_enabled
<< std::endl;

c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
const Tensor& weight = *weight_maybe_owned;
const Tensor& bias = c10::value_or_else(bias_opt, [] {return Tensor();});
@@ -564,16 +581,33 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, int64_t> _batch_norm_impl_index(
bool use_miopen = (input.is_cuda()
&& input.dim() <= MIOPEN_DIM_MAX
&& input.scalar_type() != at::kDouble
&& input.scalar_type() != at::kBFloat16
&& (weight.scalar_type() != at::kHalf)
&& (weight.scalar_type() != at::kBFloat16)
&& weight.defined() && bias.defined()
&& ((running_mean.defined() && running_var.defined())
|| (!running_mean.defined() && !running_var.defined() && training))
&& detail::getCUDAHooks().compiledWithMIOpen()
&& cudnn_enabled
);

if (PYTORCH_MIOPEN_EXTRA_LOGGING)
std::cout
<< "PYTORCH_MIOPEN_EXTRA_LOGGING: ********************* _batch_norm_impl_index (use_miopen)"
<< " use_miopen=" << use_miopen
<< " cudnn_enabled=" << cudnn_enabled
<< " dim=" << input.dim()
<< " memory_format=" << input.suggest_memory_format()
<< " input.dtype=" << input.scalar_type()
<< " weight.dtype=" << (weight.defined()?"+":"-") << weight.scalar_type()
<< " bias.dtype=" << (bias.defined()?"+":"-") << bias.scalar_type()
<< " running_mean.dtype=" << (running_mean.defined()?"+":"-") << running_mean.scalar_type()
<< " running_var.dtype=" << (running_mean.defined()?"+":"-") << running_mean.scalar_type()
<< " training=" << training
<< std::endl;

if (use_miopen && input.suggest_memory_format() != MemoryFormat::ChannelsLast && input.suggest_memory_format() != MemoryFormat::ChannelsLast3d) {
if (PYTORCH_MIOPEN_EXTRA_LOGGING)
std::cout << "PYTORCH_MIOPEN_EXTRA_LOGGING: ********************* _batch_norm_impl_index (calling miopen_batch_norm)" << std::endl;
return std::tuple_cat(
at::miopen_batch_norm(
input.contiguous(), weight.contiguous(), bias.contiguous(),
@@ -596,6 +630,8 @@ std::tuple<Tensor, Tensor, Tensor> _batch_norm_impl_index_backward(
const Tensor& input, const Tensor& grad_output, const c10::optional<Tensor>& weight_opt /* optional */, const c10::optional<Tensor>& running_mean_opt /* optional */, const c10::optional<Tensor>& running_var_opt /* optional */, const c10::optional<Tensor>& save_mean_opt /* optional */, const c10::optional<Tensor>& save_var_transform_opt /* optional */,
bool train, double epsilon, std::array<bool, 3> output_mask, const Tensor &reservedSpace) {
// See [Note: hacky wrapper removal for optional tensor]
if (PYTORCH_MIOPEN_EXTRA_LOGGING)
std::cout << "PYTORCH_MIOPEN_EXTRA_LOGGING: ********************* _batch_norm_impl_index_backward" << std::endl;
c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
const Tensor& weight = *weight_maybe_owned;
const Tensor& running_mean = c10::value_or_else(running_mean_opt, [] {return Tensor();});
@@ -626,12 +662,16 @@ std::tuple<Tensor, Tensor, Tensor> _batch_norm_impl_index_backward(

// backward in inference mode is not supported in cudnn, fallback to native
if (impl_index == 0 || (!train)) {
if (PYTORCH_MIOPEN_EXTRA_LOGGING)
std::cout << "PYTORCH_MIOPEN_EXTRA_LOGGING: ********************* _batch_norm_impl_index_backward (calling native_batch_norm_backward)" << std::endl;
return at::native_batch_norm_backward(grad_output, input, weight, running_mean, running_var, save_mean, save_var_transform, train, epsilon, output_mask);
} else if (impl_index == 1) {
// TODO: _batch_norm_impl_index_backward is only used in JIT. cudnn NHWC
// format conversion is done inside cudnn_batch_norm_backward instead
return at::cudnn_batch_norm_backward(input, grad_output, weight, running_mean, running_var, save_mean, save_var_transform, epsilon, reservedSpace);
} else if (impl_index == 2) {
if (PYTORCH_MIOPEN_EXTRA_LOGGING)
std::cout << "PYTORCH_MIOPEN_EXTRA_LOGGING: ********************* _batch_norm_impl_index_backward (calling miopen_batch_norm_backward)" << std::endl;
return at::miopen_batch_norm_backward(input, grad_output, weight, running_mean, running_var, save_mean, save_var_transform, epsilon);
}
TORCH_INTERNAL_ASSERT(false, "Unsupported impl_index in _batch_norm_impl_index_backward: ", impl_index);
@@ -641,6 +681,20 @@ Tensor batch_norm(
const Tensor& input, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt,
const c10::optional<Tensor>& running_mean_opt, const c10::optional<Tensor>& running_var_opt,
bool training, double momentum, double eps, bool cudnn_enabled) {
if (PYTORCH_MIOPEN_EXTRA_LOGGING)
std::cout
<< "PYTORCH_MIOPEN_EXTRA_LOGGING: ********************* batch_norm"
<< " input=" << input.scalar_type()
<< " weight=" << (weight_opt.has_value() ? weight_opt.value().scalar_type() : at::ScalarType::Undefined)
<< " bias=" << (bias_opt.has_value() ? bias_opt.value().scalar_type() : at::ScalarType::Undefined)
<< " running_mean=" << (running_mean_opt.has_value() ? running_mean_opt.value().scalar_type() : at::ScalarType::Undefined)
<< " running_var=" << (running_var_opt.has_value() ? running_var_opt.value().scalar_type() : at::ScalarType::Undefined)
<< " training=" << training
// << " momentum=" << momentum
// << " eps=" << eps
<< " cudnn_enabled=" << cudnn_enabled
<< std::endl;

const Tensor& weight = c10::value_or_else(weight_opt, [] {return Tensor();});
const Tensor& bias = c10::value_or_else(bias_opt, [] {return Tensor();});
const Tensor& running_mean = c10::value_or_else(running_mean_opt, [] {return Tensor();});
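The PYTORCH_MIOPEN_EXTRA_LOGGING flag introduced above is read once via c10::utils::check_env, during static initialization of the global, so it must be set before libtorch is loaded. A sketch of driving the new logging from Python, assuming a ROCm build:

import os
os.environ["PYTORCH_MIOPEN_EXTRA_LOGGING"] = "1"  # must precede the torch import

import torch
import torch.nn as nn

bn = nn.BatchNorm2d(8).cuda()
out = bn(torch.randn(2, 8, 4, 4, device="cuda"))
# stdout should now show lines such as:
# PYTORCH_MIOPEN_EXTRA_LOGGING: ********************* batch_norm input=Float ...
# PYTORCH_MIOPEN_EXTRA_LOGGING: ********************* _batch_norm_impl_index (use_miopen) ...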
4 changes: 2 additions & 2 deletions aten/src/ATen/native/miopen/BatchNorm_miopen.cpp
@@ -79,7 +79,7 @@ std::tuple<Tensor, Tensor, Tensor> miopen_batch_norm(
checkAllDefined(c, {running_mean, running_var});
}
checkAllSameGPU(c, {input, weight, bias, running_mean, running_var});
- if (input->scalar_type() != ScalarType::Half) {
+ if (input->scalar_type() != ScalarType::Half && input->scalar_type() != ScalarType::BFloat16) {
checkAllSameType(c, {input, weight});
}
checkAllSameType(c, {weight, bias, running_mean, running_var});
@@ -179,7 +179,7 @@ std::tuple<Tensor, Tensor, Tensor> miopen_batch_norm_backward(

checkAllDefined(c, {input, grad_output, weight, save_mean, save_var});
checkAllSameGPU(c, {input, grad_output, weight, save_mean, save_var});
- if (input->scalar_type() == ScalarType::Half) {
+ if (input->scalar_type() == ScalarType::Half || input->scalar_type() == ScalarType::BFloat16) {
checkScalarType(c, weight, ScalarType::Float);
} else {
checkAllSameType(c, {input, weight});
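With the relaxed checks, miopen_batch_norm accepts for bf16 the same mixed pairing it already allowed for fp16: low-precision input with fp32 weight, bias, and running statistics (and the backward now requires an fp32 weight for either low-precision input). A functional-level sketch of that call, assuming a ROCm build on which this path is selected:

import torch
import torch.nn.functional as F

x = torch.randn(4, 16, 8, 8, device="cuda", dtype=torch.bfloat16)

# Affine parameters and running statistics deliberately stay fp32.
weight = torch.ones(16, device="cuda")
bias = torch.zeros(16, device="cuda")
running_mean = torch.zeros(16, device="cuda")
running_var = torch.ones(16, device="cuda")

y = F.batch_norm(x, running_mean, running_var, weight, bias,
                 training=True, momentum=0.1, eps=1e-5)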
155 changes: 154 additions & 1 deletion test/test_nn.py
@@ -1,3 +1,4 @@
import time
# Owner(s): ["module: nn"]

import contextlib
@@ -9,6 +10,7 @@
import warnings
import pickle
import re
import os
from copy import deepcopy
from itertools import product
from functools import partial
@@ -4877,6 +4879,54 @@ def run_test(input, grad_output):
grad = grad.permute(0, 2, 1, 3)
run_test(input, grad)

@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
@unittest.skipIf(not TEST_CUDNN, "needs cudnn")
def test_batchnorm_nhwc_miopen(self):
def run_test(input, grad_output):
c = input.size(1)
mod = nn.BatchNorm2d(c).cuda().float()
mod.weight.data.uniform_()
mod.bias.data.uniform_()
ref_input = input.detach().clone(memory_format=torch.preserve_format).requires_grad_(True)
ref_grad = grad_output.detach().clone(memory_format=torch.preserve_format)
ref_mod = nn.BatchNorm2d(c).cuda().float()
ref_mod.load_state_dict(mod.state_dict())
out = mod(input)
out.backward(grad_output)
with torch.backends.cudnn.flags(enabled=False): # force to use native nhwc batchnorm
ref_out = ref_mod(ref_input)
ref_out.backward(ref_grad)
self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
self.assertTrue(ref_out.is_contiguous(memory_format=torch.channels_last))
self.assertEqual(out, ref_out)
self.assertEqual(mod.weight.grad, ref_mod.weight.grad)
self.assertEqual(mod.bias.grad, ref_mod.bias.grad)
self.assertEqual(input.grad, ref_input.grad)

# TODO: Remove PYTORCH_MIOPEN_SUGGEST_NHWC once ROCm officially supports NHWC in MIOpen
PYTORCH_MIOPEN_SUGGEST_NHWC = "PYTORCH_MIOPEN_SUGGEST_NHWC"
prev_val = os.getenv(PYTORCH_MIOPEN_SUGGEST_NHWC)
try:
os.environ[PYTORCH_MIOPEN_SUGGEST_NHWC] = "1"
input = torch.randint(1, 10, (4, 8, 2, 2), dtype=torch.float32, device="cuda")
input = input.contiguous(memory_format=torch.channels_last).detach().requires_grad_()

grad = torch.randint(1, 10, (4, 8, 2, 2), dtype=torch.float32, device="cuda")
grad = grad.contiguous(memory_format=torch.channels_last)
run_test(input, grad)
# see #42588: grad is channels_last contiguous, but grad.suggest_memory_format (rightly) returns "contiguous"
# not channels_last
input = torch.randint(1, 10, (2, 8, 8, 1), dtype=torch.float32, device="cuda")
input = input.contiguous(memory_format=torch.channels_last).detach().requires_grad_()
grad = torch.randint(1, 10, (2, 8, 8, 1), dtype=torch.float32, device="cuda")
grad = grad.permute(0, 2, 1, 3)
run_test(input, grad)
finally:
if prev_val is None:
del os.environ[PYTORCH_MIOPEN_SUGGEST_NHWC]
else:
os.environ[PYTORCH_MIOPEN_SUGGEST_NHWC] = prev_val
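The try/finally save-and-restore around PYTORCH_MIOPEN_SUGGEST_NHWC here (and again in the new miopen tests further down) could be factored into a small context manager. A hypothetical helper, not part of this PR:

import contextlib
import os

@contextlib.contextmanager
def scoped_env(name, value):
    # Set an environment variable for the duration of a with-block,
    # restoring the previous value (or its absence) afterwards.
    prev = os.environ.get(name)
    os.environ[name] = value
    try:
        yield
    finally:
        if prev is None:
            del os.environ[name]
        else:
            os.environ[name] = prev

Usage would collapse the boilerplate to: with scoped_env("PYTORCH_MIOPEN_SUGGEST_NHWC", "1"): run_test(input, grad).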

@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
def test_batchnorm_cudnn_half(self):
# THNN
@@ -4988,7 +5038,8 @@ def test_batchnorm_buffer_update_when_stats_are_not_tracked(self):

@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
def test_batchnorm_nhwc_cuda(self):
- for dtype in (torch.half, torch.float):
+ # for dtype in (torch.half, torch.float):
+ for dtype in (torch.bfloat16,):
(N, C, H, W) = 2, 64, 50, 50
model = torch.nn.BatchNorm2d(C, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
model = model.eval().cuda().to(dtype)
@@ -8114,6 +8165,108 @@ def test_affine_3d_rotateRandom(self, device):

self.assertEqual(scipy_ary, gridsample_ary.reshape_as(scipy_ary))

def batchnorm2d_miopen(self, dtype, memory_format):
def run_test(input, grad_output, enable_native=True, enable_cpu=False):
print(f"XXXXXXXXXXXXXX {torch.__file__}")
c = input.size(1)
mod = nn.BatchNorm2d(c, device='cuda', dtype=input.dtype)
mod.weight.data.uniform_()
mod.bias.data.uniform_()
if enable_native:
ref_input = input.detach().clone(memory_format=torch.preserve_format).requires_grad_(True)
ref_grad = grad_output.detach().clone(memory_format=torch.preserve_format)
ref_mod = nn.BatchNorm2d(c).cuda().to(dtype=input.dtype)
ref_mod.load_state_dict(mod.state_dict())

if enable_cpu:
cpu_input = input.detach().clone(memory_format=torch.preserve_format).cpu().requires_grad_(True)
cpu_grad = grad_output.detach().cpu().clone(memory_format=torch.preserve_format)
cpu_mod = nn.BatchNorm2d(c).cpu().to(dtype=input.dtype)
cpu_mod.load_state_dict(mod.state_dict())

print("---------------- forward ----------------")
time.sleep(1)
out = mod(input)
# return
if enable_cpu:
print("---------------- cpu_forward ----------------")
time.sleep(1)
cpu_out = cpu_mod(cpu_input)
if enable_native:
print("---------------- ref_forward ----------------")
with torch.backends.cudnn.flags(enabled=False): # force to use native nhwc batchnorm
time.sleep(1)
ref_out = ref_mod(ref_input)

print("---------------- backward ----------------")
time.sleep(1)
# if input.dtype == torch.bfloat16 and memory_format==torch.channels_last:
# grad_output = grad_output.to(torch.float) # .contiguous(memory_format=torch.channels_last)
out.backward(grad_output)
if enable_cpu:
print("---------------- cpu_backward ----------------")
time.sleep(1)
cpu_out.backward(cpu_grad)
if enable_native:
print("---------------- ref_backward ----------------")
time.sleep(1)
ref_out.backward(ref_grad)
print("---------------- check ----------------")
time.sleep(1)
self.assertTrue(out.is_contiguous(memory_format=memory_format))
if enable_cpu:
self.assertTrue(cpu_out.is_contiguous(memory_format=memory_format))
self.assertEqual(out, cpu_out)
self.assertEqual(mod.weight.grad, cpu_mod.weight.grad)
self.assertEqual(mod.bias.grad, cpu_mod.bias.grad)
self.assertEqual(mod.running_mean, cpu_mod.running_mean)
self.assertEqual(mod.running_var, cpu_mod.running_var)
self.assertEqual(input.grad, cpu_input.grad)
if enable_native:
self.assertTrue(ref_out.is_contiguous(memory_format=memory_format))
self.assertEqual(out, ref_out)
self.assertEqual(mod.weight.grad, ref_mod.weight.grad)
self.assertEqual(mod.bias.grad, ref_mod.bias.grad)
self.assertEqual(mod.running_mean, ref_mod.running_mean, atol=1e-2, rtol=3e-2, exact_dtype=False)
self.assertEqual(mod.running_var, ref_mod.running_var, atol=1e-2, rtol=3e-2, exact_dtype=False)
self.assertEqual(input.grad, ref_input.grad)
print("---------------- end ----------------")

# size = (4, 8, 2, 2)
size = (8, 32, 470, 725)
input = torch.randint(1, 10, size=size, dtype=dtype, device="cuda")
input = input.contiguous(memory_format=memory_format).detach().requires_grad_()
grad = torch.randint(1, 10, size=size, dtype=dtype, device="cuda")
grad = grad.contiguous(memory_format=memory_format)
run_test(input, grad)
# see #42588: grad is channels_last contiguous, but grad.suggest_memory_format (rightly) returns "contiguous"
# not channels_last
# input = torch.randint(1, 10, (2, 8, 8, 1), dtype=dtype, device="cuda")
# input = input.contiguous(memory_format=memory_format).detach().requires_grad_()
# grad = torch.randint(1, 10, (2, 8, 8, 1), dtype=dtype, device="cuda")
# grad = grad.permute(0, 2, 1, 3)
# run_test(input, grad)


@onlyCUDA
@dtypes(torch.float, torch.float16, torch.bfloat16)
def test_batchnorm_nhwc_miopen(self, dtype):
# TODO: Remove PYTORCH_MIOPEN_SUGGEST_NHWC once ROCm officially supports NHWC in MIOpen
PYTORCH_MIOPEN_SUGGEST_NHWC = "PYTORCH_MIOPEN_SUGGEST_NHWC"
prev_val = os.getenv(PYTORCH_MIOPEN_SUGGEST_NHWC)
try:
os.environ[PYTORCH_MIOPEN_SUGGEST_NHWC] = "1"
self.batchnorm2d_miopen(dtype, torch.channels_last)
finally:
if prev_val is None:
del os.environ[PYTORCH_MIOPEN_SUGGEST_NHWC]
else:
os.environ[PYTORCH_MIOPEN_SUGGEST_NHWC] = prev_val

@onlyCUDA
@dtypes(torch.float, torch.float16, torch.bfloat16)
def test_batchnorm_nchw_miopen(self, dtype):
self.batchnorm2d_miopen(dtype, torch.contiguous_format)
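These device-generic tests presumably live in TestNNDeviceType and are collected through instantiate_device_type_tests, so on a ROCm build they can be run with, for example, python test/test_nn.py -k miopen.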

@onlyCUDA
@dtypes(torch.float, torch.half)