Skip to content

Commit

Permalink
Merge pull request #73 from JYYCaN/rc4main
Browse files Browse the repository at this point in the history
add new npu op roiaware_pool3d && fix npu op scatter_points bug
  • Loading branch information
hust17yixuan authored Nov 6, 2024
2 parents 32997dd + 3424ec1 commit 0238175
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 5 deletions.
86 changes: 86 additions & 0 deletions mmcv/ops/csrc/pytorch/npu/roiaware_pool3d.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
#include "pytorch_npu_helper.hpp"
using namespace NPU_NAME_SPACE;
using namespace std;

/**
 * NPU forward pass of RoI-aware 3D pooling.
 *
 * Forwards the inputs to the aclnnRoiawarePool3d kernel. When `rois` is half
 * precision, `rois`, `pts`, `pts_feature` and `pooled_features` are converted
 * to float32 for the kernel call, and the pooled result is converted back to
 * half and written into `pooled_features` afterwards.
 *
 * @param boxes_num           Not read by this implementation.
 * @param pts_num             Not read by this implementation.
 * @param channels            Not read by this implementation.
 * @param max_pts_each_voxel  Forwarded to the kernel unchanged.
 * @param out_x, out_y, out_z Forwarded to the kernel unchanged.
 * @param rois                Kernel input; its dtype decides whether the
 *                            half -> float32 round trip is performed.
 * @param pts                 Kernel input.
 * @param pts_feature         Kernel input.
 * @param argmax              Passed to the kernel as-is (presumably an
 *                            output -- confirm against the aclnn op spec).
 * @param pts_idx_of_voxels   Passed to the kernel as-is (presumably an
 *                            output -- confirm against the aclnn op spec).
 * @param pooled_features     Receives the pooled result.
 * @param pool_method         Forwarded to the kernel unchanged.
 */
void roiaware_pool3d_forward_npu(int boxes_num, int pts_num, int channels,
                                 int max_pts_each_voxel, int out_x, int out_y,
                                 int out_z, const Tensor rois, const Tensor pts,
                                 const Tensor pts_feature, Tensor argmax,
                                 Tensor pts_idx_of_voxels,
                                 Tensor pooled_features, int pool_method) {
  const bool is_half = (rois.dtype() == at::kHalf);

  // Without the half round trip these handles alias the caller's tensors, so
  // the kernel writes straight into `pooled_features`.
  at::Tensor rois_cast = rois;
  at::Tensor pts_cast = pts;
  at::Tensor pts_feature_cast = pts_feature;
  at::Tensor pooled_features_cast = pooled_features;

  if (is_half) {
    rois_cast = rois_cast.to(at::kFloat);
    pts_cast = pts_cast.to(at::kFloat);
    pts_feature_cast = pts_feature_cast.to(at::kFloat);
    pooled_features_cast = pooled_features_cast.to(at::kFloat);
  }

  EXEC_NPU_CMD(aclnnRoiawarePool3d, rois_cast, pts_cast, pts_feature_cast,
               pool_method, max_pts_each_voxel, out_x, out_y, out_z, argmax,
               pts_idx_of_voxels, pooled_features_cast);

  // Only the half path computed into a temporary; copying back in the
  // non-half case would be a tensor copying onto itself.
  if (is_half) {
    pooled_features.copy_(pooled_features_cast.to(at::kHalf));
  }
}

/**
 * NPU backward pass of RoI-aware 3D pooling.
 *
 * Dispatches to aclnnRoiawareMaxpool3dGrad (pool_method == 0) or
 * aclnnRoiawareAvgpool3dGrad (pool_method == 1). When `grad_out` is half
 * precision, `grad_out` and `grad_in` are converted to float32 for the kernel
 * call and the result is converted back to half and written into `grad_in`.
 *
 * @param boxes_num           Forwarded to the kernel unchanged.
 * @param out_x, out_y, out_z Forwarded to the kernel unchanged.
 * @param channels            Forwarded to the kernel unchanged.
 * @param max_pts_each_voxel  Forwarded to the avg-pool kernel only.
 * @param pts_idx_of_voxels   Input of the avg-pool kernel only.
 * @param argmax              Input of the max-pool kernel only.
 * @param grad_out            Upstream gradient; its dtype decides whether the
 *                            half -> float32 round trip is performed.
 * @param grad_in             Receives the computed gradient. Its first
 *                            dimension is passed to the kernels as `npoints`.
 * @param pool_method         0 selects max pooling, 1 selects average
 *                            pooling. Any other value launches no kernel and
 *                            leaves the values of `grad_in` as they were.
 */
void roiaware_pool3d_backward_npu(int boxes_num, int out_x, int out_y,
                                  int out_z, int channels,
                                  int max_pts_each_voxel,
                                  const Tensor pts_idx_of_voxels,
                                  const Tensor argmax, const Tensor grad_out,
                                  Tensor grad_in, int pool_method) {
  // size() yields int64_t; the narrowing to the 32-bit value handed to the
  // kernels is made explicit here.
  int32_t npoints = static_cast<int32_t>(grad_in.size(0));

  const bool is_half = (grad_out.dtype() == at::kHalf);

  // Without the half round trip these handles alias the caller's tensors, so
  // the kernel writes straight into `grad_in`.
  at::Tensor grad_out_cast = grad_out;
  at::Tensor grad_in_cast = grad_in;

  if (is_half) {
    grad_out_cast = grad_out.to(at::kFloat);
    grad_in_cast = grad_in_cast.to(at::kFloat);
  }

  if (pool_method == 0) {
    // maxpool3d
    EXEC_NPU_CMD(aclnnRoiawareMaxpool3dGrad, argmax, grad_out_cast, boxes_num,
                 out_x, out_y, out_z, channels, npoints, grad_in_cast);
  } else if (pool_method == 1) {
    // avgpool3d
    EXEC_NPU_CMD(aclnnRoiawareAvgpool3dGrad, pts_idx_of_voxels, grad_out_cast,
                 boxes_num, out_x, out_y, out_z, channels, npoints,
                 max_pts_each_voxel, grad_in_cast);
  }

  // Only the half path computed into a temporary; copying back in the
  // non-half case would be a tensor copying onto itself.
  if (is_half) {
    grad_in.copy_(grad_in_cast.to(at::kHalf));
  }
}

// Forward declaration of the forward-pass entry point that is paired with
// roiaware_pool3d_forward_npu in the REGISTER_NPU_IMPL call below. Its
// parameter list matches roiaware_pool3d_forward_npu one-to-one. The
// definition is not in this file (presumably the device-agnostic dispatch
// stub -- confirm in the shared ops sources).
void roiaware_pool3d_forward_impl(int boxes_num, int pts_num, int channels,
int max_pts_each_voxel, int out_x, int out_y,
int out_z, const Tensor rois,
const Tensor pts, const Tensor pts_feature,
Tensor argmax, Tensor pts_idx_of_voxels,
Tensor pooled_features, int pool_method);

// Forward declaration of the backward-pass entry point that is paired with
// roiaware_pool3d_backward_npu in the REGISTER_NPU_IMPL call below. Its
// parameter list matches roiaware_pool3d_backward_npu one-to-one. The
// definition is not in this file (presumably the device-agnostic dispatch
// stub -- confirm in the shared ops sources).
void roiaware_pool3d_backward_impl(int boxes_num, int out_x, int out_y,
int out_z, int channels,
int max_pts_each_voxel,
const Tensor pts_idx_of_voxels,
const Tensor argmax, const Tensor grad_out,
Tensor grad_in, int pool_method);

// Pair each *_impl entry point with its NPU implementation defined above.
// NOTE(review): REGISTER_NPU_IMPL comes from pytorch_npu_helper.hpp;
// presumably it registers the second argument as the NPU backend for the
// first -- confirm against the macro definition.
REGISTER_NPU_IMPL(roiaware_pool3d_forward_impl, roiaware_pool3d_forward_npu);
REGISTER_NPU_IMPL(roiaware_pool3d_backward_impl, roiaware_pool3d_backward_npu);
10 changes: 5 additions & 5 deletions mmcv/ops/scatter_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,15 @@ def forward(ctx: Any,
"""
ctx.device = feats.device.type
if ctx.device == 'npu':
import ads_c
voxel_idx = ads_c.point_to_voxel(coors, [], [], 'XYZ')
unique_res = ads_c.unique_voxel(voxel_idx)
import mx_driving._C
voxel_idx = mx_driving._C.point_to_voxel(coors, [], [], 'XYZ')
unique_res = mx_driving._C.unique_voxel(voxel_idx)
num_voxels, uniqued_voxel_idx, prefix_sum, \
argsort_coor, _ = unique_res
voxel_coors = \
ads_c.voxel_to_point(uniqued_voxel_idx, [], [], 'XYZ')
mx_driving._C.voxel_to_point(uniqued_voxel_idx, [], [], 'XYZ')
voxel_feats, \
compare_mask = ads_c.npu_dynamic_scatter(feats, coors,
compare_mask = mx_driving._C.npu_dynamic_scatter(feats, coors,
prefix_sum,
argsort_coor,
num_voxels,
Expand Down

0 comments on commit 0238175

Please sign in to comment.