From 9c42c43f6f22bb048682bff45688449a0a5bd9ec Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E5=B8=AE=E6=94=BF?= Date: Fri, 25 Oct 2024 15:11:36 +0800 Subject: [PATCH 1/2] =?UTF-8?q?points=5Fin=5Fboxes=5Fall=E5=92=8Cpoints=5F?= =?UTF-8?q?in=5Fboxes=5Fpart=E7=9A=84mmcv=E5=85=BC=E5=AE=B9npu=E5=88=A4?= =?UTF-8?q?=E6=96=AD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- mmcv/ops/points_in_boxes.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/mmcv/ops/points_in_boxes.py b/mmcv/ops/points_in_boxes.py index 4915e6b573..6d10e93b69 100644 --- a/mmcv/ops/points_in_boxes.py +++ b/mmcv/ops/points_in_boxes.py @@ -47,8 +47,11 @@ def points_in_boxes_part(points: Tensor, boxes: Tensor) -> Tensor: points_device = points.get_device() assert points_device == boxes.get_device(), \ 'Points and boxes should be put on the same device' - if torch.cuda.current_device() != points_device: - torch.cuda.set_device(points_device) + if points.device.type != 'npu': + if torch.cuda.current_device() != points_device: + torch.cuda.set_device(points_device) + else: + boxes[:, :, 2] += boxes[:, :, 5] / 2.0 ext_module.points_in_boxes_part_forward(boxes.contiguous(), points.contiguous(), @@ -127,8 +130,11 @@ def points_in_boxes_all(points: Tensor, boxes: Tensor) -> Tensor: points_device = points.get_device() assert points_device == boxes.get_device(), \ 'Points and boxes should be put on the same device' - if torch.cuda.current_device() != points_device: - torch.cuda.set_device(points_device) + if points.device.type != 'npu': + if torch.cuda.current_device() != points_device: + torch.cuda.set_device(points_device) + else: + boxes[:, :, 2] += boxes[:, :, 5] / 2.0 ext_module.points_in_boxes_all_forward(boxes.contiguous(), points.contiguous(), From a136f2f73b9e55648961474a904a951fe4c982a1 Mon Sep 17 00:00:00 2001 From: Zac <33156501+ZrBac@users.noreply.github.com> Date: Tue, 5 Nov 2024 10:27:58 +0800 Subject: [PATCH 2/2] Update points_in_boxes.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 去除判断 --- mmcv/ops/points_in_boxes.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/mmcv/ops/points_in_boxes.py b/mmcv/ops/points_in_boxes.py index 6d10e93b69..e58a6e2a12 100644 --- a/mmcv/ops/points_in_boxes.py +++ b/mmcv/ops/points_in_boxes.py @@ -133,8 +133,6 @@ def points_in_boxes_all(points: Tensor, boxes: Tensor) -> Tensor: if points.device.type != 'npu': if torch.cuda.current_device() != points_device: torch.cuda.set_device(points_device) - else: - boxes[:, :, 2] += boxes[:, :, 5] / 2.0 ext_module.points_in_boxes_all_forward(boxes.contiguous(), points.contiguous(),