Skip to content

Commit

Permalink
segmentation tutorial fix (sony#1228)
Browse files Browse the repository at this point in the history
resolve issue sony#1222 
fix gpu execusion on tutorial
  • Loading branch information
samuel-wj-chapman authored Sep 24, 2024
1 parent 3eed10b commit c90c5f2
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 2 deletions.
5 changes: 3 additions & 2 deletions tutorials/mct_model_garden/models_pytorch/yolov8/yolov8.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,11 +536,12 @@ def seg_model_predict(model: Any,
List: List containing tensors of predictions.
"""
input_tensor = torch.from_numpy(inputs).unsqueeze(0) # Add batch dimension

device = get_working_device()
input_tensor = input_tensor.to(device)
# Run the model
with torch.no_grad():
outputs = model(input_tensor)

outputs = [output.cpu() for output in outputs]
return outputs

def yolov8_pytorch(model_yaml: str) -> (nn.Module, Dict):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,8 @@
"source": [
"from tutorials.mct_model_garden.models_pytorch.yolov8.yolov8 import seg_model_predict\n",
"from tutorials.mct_model_garden.evaluation_metrics.coco_evaluation import evaluate_yolov8_segmentation\n",
"device = get_working_device()\n",
"model = model.to(device)\n",
"evaluate_yolov8_segmentation(model, seg_model_predict, data_dir='coco', data_type='val2017', img_ids_limit=100, output_file='results.json', iou_thresh=0.7, conf=0.001, max_dets=300,mask_thresh=0.55)"
]
},
Expand Down

0 comments on commit c90c5f2

Please sign in to comment.