diff --git a/monai/apps/nnunet/nnunetv2_runner.py b/monai/apps/nnunet/nnunetv2_runner.py index e62809403e..44b3c24256 100644 --- a/monai/apps/nnunet/nnunetv2_runner.py +++ b/monai/apps/nnunet/nnunetv2_runner.py @@ -37,6 +37,7 @@ class nnUNetV2Runner: # noqa: N801 """ ``nnUNetV2Runner`` provides an interface in MONAI to use `nnU-Net` V2 library to analyze, train, and evaluate neural networks for medical image segmentation tasks. + A version of nnunetv2 higher than 2.2 is needed for this class. ``nnUNetV2Runner`` can be used in two ways: @@ -770,7 +771,7 @@ def find_best_configuration( def predict( self, list_of_lists_or_source_folder: str | list[list[str]], - output_folder: str, + output_folder: str | None | list[str], model_training_output_dir: str, use_folds: tuple[int, ...] | str | None = None, tile_step_size: float = 0.5, @@ -824,7 +825,7 @@ def predict( """ os.environ["CUDA_VISIBLE_DEVICES"] = f"{gpu_id}" - from nnunetv2.inference.predict_from_raw_data import predict_from_raw_data + from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor n_processes_preprocessing = ( self.default_num_processes if num_processes_preprocessing < 0 else num_processes_preprocessing @@ -832,20 +833,21 @@ def predict( n_processes_segmentation_export = ( self.default_num_processes if num_processes_segmentation_export < 0 else num_processes_segmentation_export ) - - predict_from_raw_data( - list_of_lists_or_source_folder=list_of_lists_or_source_folder, - output_folder=output_folder, - model_training_output_dir=model_training_output_dir, - use_folds=use_folds, + predictor = nnUNetPredictor( tile_step_size=tile_step_size, use_gaussian=use_gaussian, use_mirroring=use_mirroring, - perform_everything_on_gpu=perform_everything_on_gpu, + perform_everything_on_device=perform_everything_on_gpu, verbose=verbose, + ) + predictor.initialize_from_trained_model_folder( + model_training_output_dir=model_training_output_dir, use_folds=use_folds, checkpoint_name=checkpoint_name + ) + predictor.predict_from_files( + list_of_lists_or_source_folder=list_of_lists_or_source_folder, + output_folder_or_list_of_truncated_output_files=output_folder, save_probabilities=save_probabilities, overwrite=overwrite, - checkpoint_name=checkpoint_name, num_processes_preprocessing=n_processes_preprocessing, num_processes_segmentation_export=n_processes_segmentation_export, folder_with_segs_from_prev_stage=folder_with_segs_from_prev_stage,