From a65aa0f1783161ffc83fef547c5f2d8db814d069 Mon Sep 17 00:00:00 2001
From: Tsekai Lee <44702332+Binary2355@users.noreply.github.com>
Date: Tue, 11 Jun 2024 11:49:49 +0800
Subject: [PATCH] =?UTF-8?q?=E3=80=90Feature=E3=80=91knn/tnn=20npu=20added?=
 =?UTF-8?q?=20(#3124)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: lizekai
---
 mmcv/ops/csrc/pytorch/npu/knn_npu.cpp      | 21 +++++++++++++++
 mmcv/ops/csrc/pytorch/npu/three_nn_npu.cpp | 30 ++++++++++++++++++++++
 mmcv/ops/knn.py                            |  5 ++--
 3 files changed, 54 insertions(+), 2 deletions(-)
 create mode 100644 mmcv/ops/csrc/pytorch/npu/knn_npu.cpp
 create mode 100644 mmcv/ops/csrc/pytorch/npu/three_nn_npu.cpp

diff --git a/mmcv/ops/csrc/pytorch/npu/knn_npu.cpp b/mmcv/ops/csrc/pytorch/npu/knn_npu.cpp
new file mode 100644
index 0000000000..f25f9cf623
--- /dev/null
+++ b/mmcv/ops/csrc/pytorch/npu/knn_npu.cpp
@@ -0,0 +1,21 @@
+#include "pytorch_npu_helper.hpp"
+#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
+#include "torch_npu/csrc/framework/utils/OpAdapter.h"
+
+using namespace NPU_NAME_SPACE;
+using namespace std;
+
+void knn_forward_npu(int b, int n, int m, int nsample, const Tensor xyz,
+                     const Tensor new_xyz, Tensor idx, Tensor dist2) {
+  // transpose known from [B, N, 3] to [B, 3, N]
+  at::Tensor source = xyz.transpose(1, 2).contiguous();
+  at::Tensor target = new_xyz.contiguous();
+
+  bool is_from_knn = true;
+  EXEC_NPU_CMD(aclnnKnn, source, target, nsample, is_from_knn, idx, dist2);
+}
+
+void knn_forward_impl(int b, int n, int m, int nsample, const Tensor xyz,
+                      const Tensor new_xyz, Tensor idx, Tensor dist2);
+
+REGISTER_NPU_IMPL(knn_forward_impl, knn_forward_npu);
diff --git a/mmcv/ops/csrc/pytorch/npu/three_nn_npu.cpp b/mmcv/ops/csrc/pytorch/npu/three_nn_npu.cpp
new file mode 100644
index 0000000000..9766816f6c
--- /dev/null
+++ b/mmcv/ops/csrc/pytorch/npu/three_nn_npu.cpp
@@ -0,0 +1,30 @@
+#include "pytorch_npu_helper.hpp"
+#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
+#include "torch_npu/csrc/framework/utils/OpAdapter.h"
+
+using namespace NPU_NAME_SPACE;
+using namespace std;
+
+void three_nn_forward_npu(int b, int n, int m, const Tensor unknown,
+                          const Tensor known, Tensor dist2, Tensor idx) {
+  // transpose known [B, N, 3] -> [B, 3, N]
+  at::Tensor source = known.transpose(1, 2).contiguous();
+  at::Tensor target = unknown.contiguous();
+  auto originDtype = source.scalar_type();
+  if (originDtype == at::kHalf) {
+    source = source.to(at::kFloat);
+    target = target.to(at::kFloat);
+  }
+
+  bool is_from_knn = false;
+  uint32_t nsample = 3;
+  EXEC_NPU_CMD(aclnnKnn, source, target, nsample, is_from_knn, idx, dist2);
+  if (originDtype == at::kHalf) {
+    dist2 = dist2.to(at::kHalf);
+  }
+}
+
+void three_nn_forward_impl(int b, int n, int m, const Tensor unknown,
+                           const Tensor known, Tensor dist2, Tensor idx);
+
+REGISTER_NPU_IMPL(three_nn_forward_impl, three_nn_forward_npu);
diff --git a/mmcv/ops/knn.py b/mmcv/ops/knn.py
index 48ce92f925..47ced04c6a 100644
--- a/mmcv/ops/knn.py
+++ b/mmcv/ops/knn.py
@@ -55,8 +55,9 @@ def forward(ctx,
         center_xyz_device = center_xyz.get_device()
         assert center_xyz_device == xyz.get_device(), \
             'center_xyz and xyz should be put on the same device'
-        if torch.cuda.current_device() != center_xyz_device:
-            torch.cuda.set_device(center_xyz_device)
+        if xyz.device.type != 'npu':
+            if torch.cuda.current_device() != center_xyz_device:
+                torch.cuda.set_device(center_xyz_device)
         B, npoint, _ = center_xyz.shape
         N = xyz.shape[1]