diff --git a/substrafl/algorithms/pytorch/torch_base_algo.py b/substrafl/algorithms/pytorch/torch_base_algo.py index a40fb55b..51060753 100644 --- a/substrafl/algorithms/pytorch/torch_base_algo.py +++ b/substrafl/algorithms/pytorch/torch_base_algo.py @@ -158,11 +158,9 @@ def _local_predict(self, predict_dataset: torch.utils.data.Dataset, predictions_ self._model.eval() - predictions = torch.Tensor([]).to(self._device) with torch.inference_mode(): - for x in predict_loader: - x = x.to(self._device) - predictions = torch.cat((predictions, self._model(x)), 0) + predictions = [self._model(x.to(self._device)) for x in predict_loader] + predictions = torch.cat(predictions, dim=0) predictions = predictions.cpu().detach() self._save_predictions(predictions, predictions_path)