From 5e2b9a7b837d903bca00daf929ca5a461a8c7f50 Mon Sep 17 00:00:00 2001 From: Wang Yixuan <88923622+hust17yixuan@users.noreply.github.com> Date: Mon, 9 Dec 2024 16:31:31 +0800 Subject: [PATCH] fix deform conv (#3212) --- mmcv/ops/deform_conv.py | 3 ++- mmcv/ops/modulated_deform_conv.py | 6 ++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/mmcv/ops/deform_conv.py b/mmcv/ops/deform_conv.py index 73472dc9b1..d0e9d21604 100644 --- a/mmcv/ops/deform_conv.py +++ b/mmcv/ops/deform_conv.py @@ -51,8 +51,9 @@ def symbolic(g, def _npu_backward(ctx, grad_output): input_tensor, weight, offset_out, offset_all, sort_index_for_npu_bp = \ ctx.saved_tensors + import torch_npu grad_input, grad_weight, grad_offset_all, grad_bias = \ - torch.npu_deformable_conv2dbk( + torch_npu.npu_deformable_conv2dbk( input_tensor, grad_output, offset_out, weight, offset_all, kernel_size=[weight.shape[3], weight.shape[2]], stride=[1, 1, ctx.stride[0], ctx.stride[1]], diff --git a/mmcv/ops/modulated_deform_conv.py b/mmcv/ops/modulated_deform_conv.py index 4c735e2a09..8a348e8351 100644 --- a/mmcv/ops/modulated_deform_conv.py +++ b/mmcv/ops/modulated_deform_conv.py @@ -56,7 +56,8 @@ def _npu_forward(ctx, input_tensor, offset, mask, weight, bias): kernel_w, kernel_h, ctx.deform_groups) select_offset = offset.index_select(1, sort_index_fp) offset_all = torch.cat([select_offset, mask], dim=1) - output, offset_out = torch.npu_deformable_conv2d( + import torch_npu + output, offset_out = torch_npu.npu_deformable_conv2d( input_tensor, weight, offset_all, @@ -80,8 +81,9 @@ def _npu_forward(ctx, input_tensor, offset, mask, weight, bias): def _npu_backward(ctx, grad_output): input_tensor, weight, offset_out, offset_all, sort_index_bp = \ ctx.saved_tensors + import torch_npu grad_input, grad_weight, grad_offset_all, grad_bias = \ - torch.npu_deformable_conv2dbk( + torch_npu.npu_deformable_conv2dbk( input_tensor, grad_output, offset_out, weight, offset_all, kernel_size=[weight.shape[3], weight.shape[2]], stride=[1, 1, ctx.stride[0], ctx.stride[1]],