
Commit 671010d

Replace a few cuda() occurrences
MedericFourmy committed Dec 6, 2023
1 parent 758905d commit 671010d
Showing 3 changed files with 11 additions and 9 deletions.
@@ -10,6 +10,7 @@
 from happypose.toolbox.inference.types import DetectionsType, ObservationTensor
 from happypose.toolbox.inference.utils import add_instance_id, filter_detections
 
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 class Detector(DetectorModule):
     def __init__(self, model, ds_name):
@@ -78,21 +79,21 @@ def get_detections(
 
         if len(bboxes) > 0:
             if torch.cuda.is_available():
-                bboxes = torch.stack(bboxes).cuda().float()
-                masks = torch.stack(masks).cuda()
+                bboxes = torch.stack(bboxes).to(device).float()
+                masks = torch.stack(masks).to(device)
             else:
                 bboxes = torch.stack(bboxes).float()
                 masks = torch.stack(masks)
         else:
             infos = {"score": [], "label": [], "batch_im_id": []}
             if torch.cuda.is_available():
-                bboxes = torch.empty(0, 4).cuda().float()
+                bboxes = torch.empty(0, 4).to(device).float()
                 masks = torch.empty(
                     0,
                     images.shape[1],
                     images.shape[2],
                     dtype=torch.bool,
-                ).cuda()
+                ).to(device)
             else:
                 bboxes = torch.empty(0, 4).float()
                 masks = torch.empty(
happypose/pose_estimators/megapose/inference/detector.py (5 additions, 4 deletions)
@@ -29,6 +29,7 @@
 from happypose.toolbox.inference.detector import DetectorModule
 from happypose.toolbox.inference.types import DetectionsType, ObservationTensor
 
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 class Detector(DetectorModule):
     def __init__(self, model: torch.nn.Module) -> None:
@@ -112,17 +113,17 @@ def get_detections(
                 infos.append(info)
 
         if len(bboxes) > 0:
-            bboxes = torch.stack(bboxes).cuda().float()
-            masks = torch.stack(masks).cuda()
+            bboxes = torch.stack(bboxes).to(device).float()
+            masks = torch.stack(masks).to(device)
         else:
             infos = {"score": [], "label": [], "batch_im_id": []}
-            bboxes = torch.empty(0, 4).cuda().float()
+            bboxes = torch.empty(0, 4).to(device).float()
             masks = torch.empty(
                 0,
                 images.shape[1],
                 images.shape[2],
                 dtype=torch.bool,
-            ).cuda()
+            ).to(device)
 
         outputs = tc.PandasTensorCollection(
             infos=pd.DataFrame(infos),
@@ -562,7 +562,7 @@ def run_inference_pipeline(
         if detections is None and run_detector:
             start_time = time.time()
             detections = self.forward_detection_model(observation)
-            detections = detections.cuda()
+            detections = detections.to(device)
             print("# detections =", len(detections.bboxes))
             elapsed = time.time() - start_time
             timing_str += f"detection={elapsed:.2f}, "
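
The change applies the standard PyTorch device-fallback idiom: resolve a device object once and move tensors with .to(device), which also runs on CPU-only machines, whereas .cuda() raises an error when no GPU is available. Below is a minimal, self-contained sketch of that idiom; the tensor shapes and values are hypothetical stand-ins for the detector's bboxes and masks, not taken from the repository.

import torch

# Resolve the target device once; fall back to CPU when CUDA is unavailable.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hypothetical stand-ins for detector outputs (shapes chosen for illustration).
bboxes = torch.stack([torch.tensor([0.0, 0.0, 10.0, 10.0])]).to(device).float()
masks = torch.empty(0, 480, 640, dtype=torch.bool).to(device)

print(bboxes.device, masks.device)  # "cuda:0 cuda:0" on a GPU machine, "cpu cpu" otherwise

Resolving device once at module level, as the commit does, keeps each call site a one-word change from .cuda() to .to(device).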
