From 51919ea3096d3d9789db34fba97b8dca09db7e59 Mon Sep 17 00:00:00 2001 From: Eugene Khvedchenya Date: Thu, 26 Oct 2023 10:32:58 +0300 Subject: [PATCH] Fixed issue with torch 1.12 issue with arange not supporting fp16 for CPU device. --- .../training/models/detection_models/yolo_base.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/super_gradients/training/models/detection_models/yolo_base.py b/src/super_gradients/training/models/detection_models/yolo_base.py index 1f058f43f8..7e8bde9194 100755 --- a/src/super_gradients/training/models/detection_models/yolo_base.py +++ b/src/super_gradients/training/models/detection_models/yolo_base.py @@ -267,7 +267,7 @@ def forward(self, inputs): if not self.training: outputs_logits.append(output.clone()) if self.grid[i].shape[2:4] != output.shape[2:4]: - self.grid[i] = self._make_grid(nx, ny, dtype=reg_output.dtype).to(output.device) + self.grid[i] = self._make_grid(nx, ny, dtype=reg_output.dtype, device=output.device) xy = (output[..., :2] + self.grid[i].to(output.device)) * self.stride[i] wh = torch.exp(output[..., 2:4]) * self.stride[i] @@ -279,12 +279,14 @@ def forward(self, inputs): return outputs if self.training else (torch.cat(outputs, 1), outputs_logits) @staticmethod - def _make_grid(nx: int, ny: int, dtype: torch.dtype): + def _make_grid(nx: int, ny: int, dtype: torch.dtype, device: torch.device): + y, x = torch.arange(ny, dtype=torch.float32, device=device), torch.arange(nx, dtype=torch.float32, device=device) + if torch_version_is_greater_or_equal(1, 10): # https://github.com/pytorch/pytorch/issues/50276 - yv, xv = torch.meshgrid([torch.arange(ny, dtype=dtype), torch.arange(nx, dtype=dtype)], indexing="ij") + yv, xv = torch.meshgrid([y, x], indexing="ij") else: - yv, xv = torch.meshgrid([torch.arange(ny, dtype=dtype), torch.arange(nx, dtype=dtype)]) + yv, xv = torch.meshgrid([y, x]) return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).to(dtype)