From c90c5f250350370288c7a07b27fe5fd71a44e9fe Mon Sep 17 00:00:00 2001 From: Samuel Chapman <48865231+samuel-wj-chapman@users.noreply.github.com> Date: Tue, 24 Sep 2024 11:23:50 +0300 Subject: [PATCH] segmentation tutorial fix (#1228) resolve issue #1222 fix gpu execusion on tutorial --- tutorials/mct_model_garden/models_pytorch/yolov8/yolov8.py | 5 +++-- .../pytorch/pytorch_yolov8n_seg_for_imx500.ipynb | 2 ++ 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/tutorials/mct_model_garden/models_pytorch/yolov8/yolov8.py b/tutorials/mct_model_garden/models_pytorch/yolov8/yolov8.py index 037eda499..973872c82 100644 --- a/tutorials/mct_model_garden/models_pytorch/yolov8/yolov8.py +++ b/tutorials/mct_model_garden/models_pytorch/yolov8/yolov8.py @@ -536,11 +536,12 @@ def seg_model_predict(model: Any, List: List containing tensors of predictions. """ input_tensor = torch.from_numpy(inputs).unsqueeze(0) # Add batch dimension - + device = get_working_device() + input_tensor = input_tensor.to(device) # Run the model with torch.no_grad(): outputs = model(input_tensor) - + outputs = [output.cpu() for output in outputs] return outputs def yolov8_pytorch(model_yaml: str) -> (nn.Module, Dict): diff --git a/tutorials/notebooks/imx500_notebooks/pytorch/pytorch_yolov8n_seg_for_imx500.ipynb b/tutorials/notebooks/imx500_notebooks/pytorch/pytorch_yolov8n_seg_for_imx500.ipynb index 55060b425..f2e059434 100644 --- a/tutorials/notebooks/imx500_notebooks/pytorch/pytorch_yolov8n_seg_for_imx500.ipynb +++ b/tutorials/notebooks/imx500_notebooks/pytorch/pytorch_yolov8n_seg_for_imx500.ipynb @@ -428,6 +428,8 @@ "source": [ "from tutorials.mct_model_garden.models_pytorch.yolov8.yolov8 import seg_model_predict\n", "from tutorials.mct_model_garden.evaluation_metrics.coco_evaluation import evaluate_yolov8_segmentation\n", + "device = get_working_device()\n", + "model = model.to(device)\n", "evaluate_yolov8_segmentation(model, seg_model_predict, data_dir='coco', data_type='val2017', img_ids_limit=100, output_file='results.json', iou_thresh=0.7, conf=0.001, max_dets=300,mask_thresh=0.55)" ] },