diff --git a/petric.py b/petric.py index c282e80..ec62dc7 100755 --- a/petric.py +++ b/petric.py @@ -22,7 +22,7 @@ from traceback import print_exc import numpy as np -from skimage.metrics import mean_squared_error, peak_signal_noise_ratio +from skimage.metrics import mean_squared_error as mse from tensorboardX import SummaryWriter import sirf.STIR as STIR @@ -88,26 +88,51 @@ def __call__(self, algo: Algorithm): log.debug("...logged") +class QualityMetrics(ImageQualityCallback): + """From https://github.com/SyneRBI/PETRIC/wiki#metrics-and-thresholds""" + def __init__(self, reference_image, whole_object_mask, background_mask, **kwargs): + super().__init__(reference_image, **kwargs) + self.whole_object_indices = np.where(whole_object_mask == 1) + self.background_indices = np.where(background_mask == 1) + self.ref_im_arr = reference_image.as_array() + self.norm = self.ref_im_arr[self.background_indices].mean() + + def __call__(self, algorithm): + iteration = algorithm.iteration + if iteration % algorithm.update_objective_interval != 0 and iteration != algorithm.max_iteration: + return + + assert not any(self.filter.values()), "Filtering not implemented" + test_im_arr = algorithm.x.as_array() + + # (1) global metrics & statistics + self.tb_summary_writer.add_scalar( + "RMSE_whole_object", + np.sqrt(mse(self.ref_im_arr[self.whole_object_indices], test_im_arr[self.whole_object_indices])) / + self.norm, iteration) + self.tb_summary_writer.add_scalar( + "RMSE_background", + np.sqrt(mse(self.ref_im_arr[self.background_indices], test_im_arr[self.background_indices])) / self.norm, + iteration) + + # (2) local metrics & statistics + for voi_name, voi_indices in sorted(self.voi_indices.items()): + # AEM not to be confused with MAE + self.tb_summary_writer.add_scalar( + f"AEM_VOI_{voi_name}", + np.abs(test_im_arr[voi_indices].mean() - self.ref_im_arr[voi_indices].mean()) / self.norm, iteration) + + class MetricsWithTimeout(cbks.Callback): """Stops the algorithm after `seconds`""" - def __init__(self, seconds=300, outdir=OUTDIR, transverse_slice=None, coronal_slice=None, reference_image=None, - verbose=1): + def __init__(self, seconds=300, outdir=OUTDIR, transverse_slice=None, coronal_slice=None, verbose=1): super().__init__(verbose) self._seconds = seconds self.callbacks = [ cbks.ProgressCallback(), SaveIters(outdir=outdir), (tb_cbk := TensorBoard(logdir=outdir, transverse_slice=transverse_slice, coronal_slice=coronal_slice))] - - if reference_image: - roi_image_dict = {f'S{i}': STIR.ImageData(f'S{i}.hv') for i in range(1, 8)} - # NB: these metrics are for testing only. - # The final evaluation will use metrics described in https://github.com/SyneRBI/PETRIC/wiki - self.callbacks.append( - ImageQualityCallback( - reference_image, tb_cbk.tb, roi_mask_dict=roi_image_dict, metrics_dict={ - 'MSE': mean_squared_error, 'MAE': self.mean_absolute_error, 'PSNR': peak_signal_noise_ratio}, - statistics_dict={'MEAN': np.mean, 'STDDEV': np.std, 'MAX': np.max})) + self.tb = tb_cbk.tb self.reset() def reset(self, seconds=None): @@ -144,7 +169,9 @@ def construct_RDP(penalty_strength, initial_image, kappa, max_scaling=1e-3): return prior -Dataset = namedtuple('Dataset', ['acquired_data', 'additive_term', 'mult_factors', 'OSEM_image', 'prior', 'kappa']) +Dataset = namedtuple('Dataset', [ + 'acquired_data', 'additive_term', 'mult_factors', 'OSEM_image', 'prior', 'kappa', 'reference_image', + 'whole_object_mask', 'background_mask', 'voi_masks']) def get_data(srcdir=".", outdir=OUTDIR, sirf_verbosity=0): @@ -165,7 +192,20 @@ def get_data(srcdir=".", outdir=OUTDIR, sirf_verbosity=0): penalty_strength = 1 / 700 # default choice prior = construct_RDP(penalty_strength, OSEM_image, kappa) - return Dataset(acquired_data, additive_term, mult_factors, OSEM_image, prior, kappa) + def get_image(fname): + if (source := srcdir / 'PETRIC' / fname).is_file(): + return STIR.ImageData(str(source)) + return None # explicit to suppress linter warnings + + reference_image = get_image('reference_image.hv') + whole_object_mask = get_image('VOI_whole_object.hv') + background_mask = get_image('VOI_background.hv') + voi_masks = { + voi.stem: STIR.ImageData(str(voi)) + for voi in (srcdir / 'PETRIC').glob("VOI_*.hv") if voi.stem[4:] not in ('background', 'whole_object')} + + return Dataset(acquired_data, additive_term, mult_factors, OSEM_image, prior, kappa, reference_image, + whole_object_mask, background_mask, voi_masks) if SRCDIR.is_dir(): @@ -194,7 +234,12 @@ def get_data(srcdir=".", outdir=OUTDIR, sirf_verbosity=0): assert issubclass(Submission, Algorithm) for srcdir, outdir, metrics in data_dirs_metrics: data = get_data(srcdir=srcdir, outdir=outdir) - metrics[0].reset() # timeout from now + metrics_with_timeout = metrics[0] + if data.reference_image is not None: + metrics_with_timeout.callbacks.append( + QualityMetrics(data.reference_image, data.whole_object_mask, data.background_mask, + tb_summary_writer=metrics_with_timeout.tb, roi_mask_dict=data.voi_masks)) + metrics_with_timeout.reset() # timeout from now algo = Submission(data) try: algo.run(np.inf, callbacks=metrics + submission_callbacks)