From 671010dd485f5a171c0e294fccd17c91f2abdb15 Mon Sep 17 00:00:00 2001
From: mfourmy
Date: Wed, 6 Dec 2023 17:49:59 +0100
Subject: [PATCH] Replace a few cuda() occurrences

---
 .../cosypose/cosypose/integrated/detector.py               | 9 +++++----
 happypose/pose_estimators/megapose/inference/detector.py   | 9 +++++----
 .../pose_estimators/megapose/inference/pose_estimator.py   | 2 +-
 3 files changed, 11 insertions(+), 9 deletions(-)

diff --git a/happypose/pose_estimators/cosypose/cosypose/integrated/detector.py b/happypose/pose_estimators/cosypose/cosypose/integrated/detector.py
index 4b9a36ac..3a2a6d96 100644
--- a/happypose/pose_estimators/cosypose/cosypose/integrated/detector.py
+++ b/happypose/pose_estimators/cosypose/cosypose/integrated/detector.py
@@ -10,6 +10,7 @@
 from happypose.toolbox.inference.types import DetectionsType, ObservationTensor
 from happypose.toolbox.inference.utils import add_instance_id, filter_detections
 
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 class Detector(DetectorModule):
     def __init__(self, model, ds_name):
@@ -78,21 +79,21 @@ def get_detections(
 
         if len(bboxes) > 0:
             if torch.cuda.is_available():
-                bboxes = torch.stack(bboxes).cuda().float()
-                masks = torch.stack(masks).cuda()
+                bboxes = torch.stack(bboxes).to(device).float()
+                masks = torch.stack(masks).to(device)
             else:
                 bboxes = torch.stack(bboxes).float()
                 masks = torch.stack(masks)
         else:
             infos = {"score": [], "label": [], "batch_im_id": []}
             if torch.cuda.is_available():
-                bboxes = torch.empty(0, 4).cuda().float()
+                bboxes = torch.empty(0, 4).to(device).float()
                 masks = torch.empty(
                     0,
                     images.shape[1],
                     images.shape[2],
                     dtype=torch.bool,
-                ).cuda()
+                ).to(device)
             else:
                 bboxes = torch.empty(0, 4).float()
                 masks = torch.empty(
diff --git a/happypose/pose_estimators/megapose/inference/detector.py b/happypose/pose_estimators/megapose/inference/detector.py
index 97732139..e15aee69 100644
--- a/happypose/pose_estimators/megapose/inference/detector.py
+++ b/happypose/pose_estimators/megapose/inference/detector.py
@@ -29,6 +29,7 @@
 from happypose.toolbox.inference.detector import DetectorModule
 from happypose.toolbox.inference.types import DetectionsType, ObservationTensor
 
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 class Detector(DetectorModule):
     def __init__(self, model: torch.nn.Module) -> None:
@@ -112,17 +113,17 @@ def get_detections(
             infos.append(info)
 
         if len(bboxes) > 0:
-            bboxes = torch.stack(bboxes).cuda().float()
-            masks = torch.stack(masks).cuda()
+            bboxes = torch.stack(bboxes).to(device).float()
+            masks = torch.stack(masks).to(device)
         else:
             infos = {"score": [], "label": [], "batch_im_id": []}
-            bboxes = torch.empty(0, 4).cuda().float()
+            bboxes = torch.empty(0, 4).to(device).float()
             masks = torch.empty(
                 0,
                 images.shape[1],
                 images.shape[2],
                 dtype=torch.bool,
-            ).cuda()
+            ).to(device)
 
         outputs = tc.PandasTensorCollection(
             infos=pd.DataFrame(infos),
diff --git a/happypose/pose_estimators/megapose/inference/pose_estimator.py b/happypose/pose_estimators/megapose/inference/pose_estimator.py
index e327f4bd..df0e3dcd 100644
--- a/happypose/pose_estimators/megapose/inference/pose_estimator.py
+++ b/happypose/pose_estimators/megapose/inference/pose_estimator.py
@@ -562,7 +562,7 @@ def run_inference_pipeline(
         if detections is None and run_detector:
             start_time = time.time()
             detections = self.forward_detection_model(observation)
-            detections = detections.cuda()
+            detections = detections.to(device)
             print("# detections =", len(detections.bboxes))
             elapsed = time.time() - start_time
             timing_str += f"detection={elapsed:.2f}, "
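
A minimal sketch of the device-fallback pattern this patch adopts, for context only; the variable names and tensor shapes below are illustrative and not taken from happypose:

    import torch

    # Select the GPU when one is available, otherwise fall back to the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # `.to(device)` works on both CPU-only and CUDA machines, whereas
    # `.cuda()` fails outright when no GPU is present.
    bboxes = torch.empty(0, 4).to(device).float()
    masks = torch.empty(0, 480, 640, dtype=torch.bool).to(device)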