From 16d6c85aa36f3c0311b9d0ec2889f18610368c76 Mon Sep 17 00:00:00 2001 From: mandreux-owkin <62643750+mandreux-owkin@users.noreply.github.com> Date: Thu, 21 Sep 2023 08:28:22 +0200 Subject: [PATCH] chore: refactor _local_predict for better performances (#171) Signed-off-by: Mathieu Andreux --- substrafl/algorithms/pytorch/torch_base_algo.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/substrafl/algorithms/pytorch/torch_base_algo.py b/substrafl/algorithms/pytorch/torch_base_algo.py index a40fb55b..51060753 100644 --- a/substrafl/algorithms/pytorch/torch_base_algo.py +++ b/substrafl/algorithms/pytorch/torch_base_algo.py @@ -158,11 +158,9 @@ def _local_predict(self, predict_dataset: torch.utils.data.Dataset, predictions_ self._model.eval() - predictions = torch.Tensor([]).to(self._device) with torch.inference_mode(): - for x in predict_loader: - x = x.to(self._device) - predictions = torch.cat((predictions, self._model(x)), 0) + predictions = [self._model(x.to(self._device)) for x in predict_loader] + predictions = torch.cat(predictions, dim=0) predictions = predictions.cpu().detach() self._save_predictions(predictions, predictions_path)