Skip to content

Commit

Permalink
Add multi_scale_deform_attn_grad op adapter for NPU (#3046)
Browse files Browse the repository at this point in the history
  • Loading branch information
RRaoyzee authored Mar 13, 2024
1 parent e5562f8 commit 265531f
Show file tree
Hide file tree
Showing 2 changed files with 122 additions and 1 deletion.
59 changes: 58 additions & 1 deletion mmcv/ops/csrc/pytorch/npu/ms_deform_attn_npu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ Tensor ms_deform_attn_forward_npu(const Tensor &value,

c10::SmallVector<int64_t, 3> output_size = {
value.size(0), sampling_locations.size(1), value.size(2) * value.size(3)};
at::Tensor output = at::empty(output_size, value_fp32.options());
at::Tensor output = at::zeros(output_size, value_fp32.options());

OpCommand cmd;
cmd.Name("MultiScaleDeformableAttnFunction")
Expand All @@ -75,3 +75,60 @@ Tensor ms_deform_attn_forward_npu(const Tensor &value,
}

REGISTER_NPU_IMPL(ms_deform_attn_impl_forward, ms_deform_attn_forward_npu);

void ms_deform_attn_impl_backward(
const Tensor &value, const Tensor &spatial_shapes,
const Tensor &level_start_index, const Tensor &sampling_loc,
const Tensor &attn_weight, const Tensor &grad_output, Tensor &grad_value,
Tensor &grad_sampling_loc, Tensor &grad_attn_weight,
const int im2col_step);

void ms_deform_attn_backward_npu(const Tensor &value, const Tensor &spatial_shapes,
const Tensor &level_start_index,
const Tensor &sampling_loc,
const Tensor &attn_weight,
const Tensor &grad_output, Tensor &grad_value,
Tensor &grad_sampling_loc,
Tensor &grad_attn_weight, const int im2col_step) {
check_support(value, attn_weight);
at::Tensor value_fp32 = value;
at::Tensor spatial_shapes_int32 = spatial_shapes;
at::Tensor level_start_index_int32 = level_start_index;
at::Tensor sampling_loc_fp32 = sampling_loc.transpose(4, 5).contiguous();
at::Tensor attn_weight_fp32 = attn_weight;
at::Tensor grad_output_fp32 = grad_output;
if (value.scalar_type() != at::kFloat) {
value_fp32 = value.to(at::kFloat);
}
if (spatial_shapes.scalar_type() != at::kInt) {
spatial_shapes_int32 = spatial_shapes.to(at::kInt);
}
if (level_start_index.scalar_type() != at::kInt) {
level_start_index_int32 = level_start_index.to(at::kInt);
}
if (sampling_loc.scalar_type() != at::kFloat) {
sampling_loc_fp32 = sampling_loc_fp32.to(at::kFloat);
}
if (attn_weight.scalar_type() != at::kFloat) {
attn_weight_fp32 = attn_weight.to(at::kFloat);
}
if (grad_output.scalar_type() != at::kFloat) {
grad_output_fp32 = grad_output.to(at::kFloat);
}

OpCommand cmd;
cmd.Name("MultiScaleDeformableAttentionGrad")
.Input(value_fp32)
.Input(spatial_shapes_int32)
.Input(level_start_index_int32)
.Input(sampling_loc_fp32)
.Input(attn_weight_fp32)
.Input(grad_output_fp32)
.Output(grad_value)
.Output(grad_sampling_loc)
.Output(grad_attn_weight)
.Run();
grad_sampling_loc = grad_sampling_loc.transpose(4, 5).contiguous();
}

REGISTER_NPU_IMPL(ms_deform_attn_impl_backward, ms_deform_attn_backward_npu);
64 changes: 64 additions & 0 deletions tests/test_ops/test_ms_deformable_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,3 +317,67 @@ def test_gradient_numerical(channels,
im2col_step),
eps=eps,
atol=1e-2)


@pytest.mark.skipif(not IS_NPU_AVAILABLE, reason='requires NPU support')
def test_backward_equal_with_pytorch_npu():
N, M, D = 6, 4, 8
Lq, L, P = 10000, 4, 8
shapes = torch.as_tensor([(60, 40), (30, 20), (16, 24), (53, 32)],
dtype=torch.int32)
level_start_index = torch.cat((shapes.new_zeros(
(1, )), shapes.prod(1).cumsum(0)[:-1]))
S = sum((H * W).item() for H, W in shapes)

torch.manual_seed(3)
value = torch.rand(N, S, M, D) * 0.01
sampling_locations = torch.rand(N, Lq, M, L, P, 2)
attention_weights = torch.rand(N, Lq, M, L, P) + 1e-5
attention_weights /= attention_weights.sum(
-1, keepdim=True).sum(
-2, keepdim=True)
im2col_step = 2
value.requires_grad = True
sampling_locations.requires_grad = True
attention_weights.requires_grad = True
output_pytorch = multi_scale_deformable_attn_pytorch(
value.float(), shapes, sampling_locations.float(),
attention_weights.float())
grad_output_pytorch = torch.ones_like(output_pytorch)
output_pytorch.backward(grad_output_pytorch)
grad_value = value.grad.detach().cpu()
grad_location = sampling_locations.grad.detach().cpu()
grad_attn_weight = attention_weights.grad.detach().cpu()

value_npu = value.npu()
shapes_npu = shapes.npu()
level_start_index_npu = level_start_index.npu()
sampling_locations_npu = sampling_locations.npu()
attention_weights_npu = attention_weights.npu()
output_npu = MultiScaleDeformableAttnFunction.apply(
value_npu.float(), shapes_npu, level_start_index_npu,
sampling_locations_npu.float(), attention_weights_npu.float(),
im2col_step)
grad_output_npu = torch.ones_like(output_npu)
output_npu.backward(grad_output_npu)
grad_value_npu = value_npu.grad.detach().cpu()
grad_location_npu = sampling_locations_npu.grad.detach().cpu()
grad_attn_weight_npu = attention_weights_npu.grad.detach().cpu()
assert torch.allclose(grad_value_npu, grad_value)
max_abs_err_1 = (grad_value_npu - grad_value).abs().max()
max_rel_err_1 = ((grad_value_npu - grad_value).abs() /
grad_value.abs()).max()
assert max_abs_err_1 < 1e-5
assert max_rel_err_1 < 1e-4
assert torch.allclose(grad_location_npu, grad_location)
max_abs_err_2 = (grad_location_npu - grad_location).abs().max()
max_rel_err_2 = ((grad_location_npu - grad_location).abs() /
grad_location.abs()).max()
assert max_abs_err_2 < 1e-5
assert max_rel_err_2 < 1e-4
assert torch.allclose(grad_attn_weight_npu, grad_attn_weight)
max_abs_err_3 = (grad_attn_weight_npu - grad_attn_weight).abs().max()
max_rel_err_3 = ((grad_attn_weight_npu - grad_attn_weight).abs() /
grad_attn_weight.abs()).max()
assert max_abs_err_3 < 1e-5
assert max_rel_err_3 < 1e-4

0 comments on commit 265531f

Please sign in to comment.