Skip to content

Commit

Permalink
add metrics as per wiki
Browse files Browse the repository at this point in the history
  • Loading branch information
casperdcl committed Jul 9, 2024
1 parent 85393b1 commit 5e0f33c
Showing 1 changed file with 65 additions and 16 deletions.
81 changes: 65 additions & 16 deletions petric.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from traceback import print_exc

import numpy as np
from skimage.metrics import mean_squared_error, peak_signal_noise_ratio
from tensorboardX import SummaryWriter

import sirf.STIR as STIR
Expand Down Expand Up @@ -88,26 +87,58 @@ def __call__(self, algo: Algorithm):
log.debug("...logged")


class QualityMetrics(ImageQualityCallback):
"""From https://github.com/SyneRBI/PETRIC/wiki#metrics-and-thresholds"""
def __init__(self, reference_image, backround_mask, foreground_mask, **kwargs):
super().__init__(reference_image, **kwargs)
self.background = np.where(backround_mask == 1)
self.foreground = np.where(foreground_mask == 1)

def __call__(self, algorithm):
from skimage import metrics as sm

iteration = algorithm.iteration
if iteration % algorithm.update_objective_interval != 0 and iteration != algorithm.max_iteration:
return
test_image = algorithm.x # CIL or SIRF ImageData

# # (0) objective value
# objective = algorithm.get_last_objective(return_all=False)
# self.tb_summary_writer.add_scalar('objective', objective, iteration)

test_im_arr, ref_im_arr = test_image.as_array(), self.reference_image.as_array()

for filter_name, filter_func in self.filter.items():
if filter_func is not None:
test_im, ref_im = map(filter_func, (test_im_arr, ref_im_arr))

# (1) global metrics & statistics
norm = ref_im[self.background].mean()
self.tb_summary_writer.add_scalar(
f"RMSE_foreground{filter_name}",
np.sqrt(sm.mean_squared_error(test_im[self.foreground], ref_im[self.foreground])) / norm, iteration)
self.tb_summary_writer.add_scalar(
f"RMSE_background{filter_name}",
np.sqrt(sm.mean_squared_error(test_im[self.background], ref_im[self.background])) / norm, iteration)

# (2) local metrics & statistics
for roi_name, roi_inds in self.roi_indices.items():
# AEM not to be confused with MAE
self.tb_summary_writer.add_scalar(f"AEM_VOI_{roi_name}{filter_name}",
np.abs(test_im[roi_inds].mean() - ref_im[roi_inds].mean()) / norm,
iteration)


class MetricsWithTimeout(cbks.Callback):
"""Stops the algorithm after `seconds`"""
def __init__(self, seconds=300, outdir=OUTDIR, transverse_slice=None, coronal_slice=None, reference_image=None,
verbose=1):
def __init__(self, seconds=300, outdir=OUTDIR, transverse_slice=None, coronal_slice=None, verbose=1):
super().__init__(verbose)
self._seconds = seconds
self.callbacks = [
cbks.ProgressCallback(),
SaveIters(outdir=outdir),
(tb_cbk := TensorBoard(logdir=outdir, transverse_slice=transverse_slice, coronal_slice=coronal_slice))]

if reference_image:
roi_image_dict = {f'S{i}': STIR.ImageData(f'S{i}.hv') for i in range(1, 8)}
# NB: these metrics are for testing only.
# The final evaluation will use metrics described in https://github.com/SyneRBI/PETRIC/wiki
self.callbacks.append(
ImageQualityCallback(
reference_image, tb_cbk.tb, roi_mask_dict=roi_image_dict, metrics_dict={
'MSE': mean_squared_error, 'MAE': self.mean_absolute_error, 'PSNR': peak_signal_noise_ratio},
statistics_dict={'MEAN': np.mean, 'STDDEV': np.std, 'MAX': np.max}))
self.tb = tb_cbk.tb
self.reset()

def reset(self, seconds=None):
Expand Down Expand Up @@ -144,7 +175,9 @@ def construct_RDP(penalty_strength, initial_image, kappa, max_scaling=1e-3):
return prior


Dataset = namedtuple('Dataset', ['acquired_data', 'additive_term', 'mult_factors', 'OSEM_image', 'prior', 'kappa'])
Dataset = namedtuple('Dataset', [
'acquired_data', 'additive_term', 'mult_factors', 'OSEM_image', 'prior', 'kappa', 'reference_image',
'background_mask', 'foreground_mask', 'voi_masks'])


def get_data(srcdir=".", outdir=OUTDIR, sirf_verbosity=0):
Expand All @@ -165,7 +198,18 @@ def get_data(srcdir=".", outdir=OUTDIR, sirf_verbosity=0):
penalty_strength = 1 / 700 # default choice
prior = construct_RDP(penalty_strength, OSEM_image, kappa)

return Dataset(acquired_data, additive_term, mult_factors, OSEM_image, prior, kappa)
reference_image = STIR.ImageData(str(srcdir / 'reference_image.hv')) if (srcdir /
'reference_image.hv').is_file() else None
background_mask = STIR.ImageData(str(srcdir / 'VOI_background.hv')) if (srcdir /
'VOI_background.hv').is_file() else None
foreground_mask = STIR.ImageData(str(srcdir / 'VOI_foreground.hv')) if (srcdir /
'VOI_foreground.hv').is_file() else None
voi_masks = {
voi.stem: STIR.ImageData(str(voi))
for voi in srcdir.glob("VOI_*.hv") if voi.stem[4:] not in ('background', 'foreground')}

return Dataset(acquired_data, additive_term, mult_factors, OSEM_image, prior, kappa, reference_image,
background_mask, foreground_mask, voi_masks)


if SRCDIR.is_dir():
Expand Down Expand Up @@ -194,7 +238,12 @@ def get_data(srcdir=".", outdir=OUTDIR, sirf_verbosity=0):
assert issubclass(Submission, Algorithm)
for srcdir, outdir, metrics in data_dirs_metrics:
data = get_data(srcdir=srcdir, outdir=outdir)
metrics[0].reset() # timeout from now
metrics_cbk = metrics[0]
if data.reference_image is not None:
metrics_cbk.callbacks.append(
QualityMetrics(data.reference_image, data.background_mask, data.foreground_mask,
tb_summary_writer=metrics_cbk.tb, roi_mask_dict=data.voi_masks))
metrics_cbk.reset() # timeout from now
algo = Submission(data)
try:
algo.run(np.inf, callbacks=metrics + submission_callbacks)
Expand Down

0 comments on commit 5e0f33c

Please sign in to comment.