Skip to content

Commit

Permalink
update after review
Browse files Browse the repository at this point in the history
  • Loading branch information
casperdcl committed Jul 9, 2024
1 parent b043a13 commit e15033a
Showing 1 changed file with 28 additions and 29 deletions.
57 changes: 28 additions & 29 deletions petric.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,38 +90,37 @@ def __call__(self, algo: Algorithm):

class QualityMetrics(ImageQualityCallback):
"""From https://github.com/SyneRBI/PETRIC/wiki#metrics-and-thresholds"""
def __init__(self, reference_image, backround_mask, foreground_mask, **kwargs):
def __init__(self, reference_image, whole_object_mask, background_mask, **kwargs):
super().__init__(reference_image, **kwargs)
self.background = np.where(backround_mask == 1)
self.foreground = np.where(foreground_mask == 1)
self.whole_object_indices = np.where(whole_object_mask == 1)
self.background_indices = np.where(background_mask == 1)
self.ref_im_arr = reference_image.as_array()

def __call__(self, algorithm):
iteration = algorithm.iteration
if iteration % algorithm.update_objective_interval != 0 and iteration != algorithm.max_iteration:
return
test_image = algorithm.x # CIL or SIRF ImageData

test_im_arr, ref_im_arr = test_image.as_array(), self.reference_image.as_array()

for filter_name, filter_func in self.filter.items():
if filter_func is not None:
test_im, ref_im = map(filter_func, (test_im_arr, ref_im_arr))
if filter_func is None:
filter_func = lambda x: x
test_im, ref_im = (filter_func(img_data).as_array() for img_data in (algorithm.x, self.reference_image))

# (1) global metrics & statistics
norm = ref_im[self.background].mean()
self.tb_summary_writer.add_scalar(f"RMSE_foreground{filter_name}",
np.sqrt(mse(test_im[self.foreground], ref_im[self.foreground])) / norm,
iteration)
self.tb_summary_writer.add_scalar(f"RMSE_background{filter_name}",
np.sqrt(mse(test_im[self.background], ref_im[self.background])) / norm,
iteration)
norm = ref_im[self.background_indices].mean()
self.tb_summary_writer.add_scalar(
f"RMSE_whole_object{filter_name}",
np.sqrt(mse(ref_im[self.whole_object_indices], test_im[self.whole_object_indices])) / norm, iteration)
self.tb_summary_writer.add_scalar(
f"RMSE_background{filter_name}",
np.sqrt(mse(ref_im[self.background_indices], test_im[self.background_indices])) / norm, iteration)

# (2) local metrics & statistics
for roi_name, roi_inds in self.roi_indices.items():
for voi_name, voi_indices in self.voi_indices.items():
# AEM not to be confused with MAE
self.tb_summary_writer.add_scalar(f"AEM_VOI_{roi_name}{filter_name}",
np.abs(test_im[roi_inds].mean() - ref_im[roi_inds].mean()) / norm,
iteration)
self.tb_summary_writer.add_scalar(
f"AEM_VOI_{voi_name}{filter_name}",
np.abs(test_im[voi_indices].mean() - ref_im[voi_indices].mean()) / norm, iteration)


class MetricsWithTimeout(cbks.Callback):
Expand Down Expand Up @@ -172,7 +171,7 @@ def construct_RDP(penalty_strength, initial_image, kappa, max_scaling=1e-3):

Dataset = namedtuple('Dataset', [
'acquired_data', 'additive_term', 'mult_factors', 'OSEM_image', 'prior', 'kappa', 'reference_image',
'background_mask', 'foreground_mask', 'voi_masks'])
'background_mask', 'whole_object_mask', 'voi_masks'])


def get_data(srcdir=".", outdir=OUTDIR, sirf_verbosity=0):
Expand All @@ -197,14 +196,14 @@ def get_data(srcdir=".", outdir=OUTDIR, sirf_verbosity=0):
'reference_image.hv').is_file() else None
background_mask = STIR.ImageData(str(srcdir / 'VOI_background.hv')) if (srcdir /
'VOI_background.hv').is_file() else None
foreground_mask = STIR.ImageData(str(srcdir / 'VOI_foreground.hv')) if (srcdir /
'VOI_foreground.hv').is_file() else None
whole_object_mask = STIR.ImageData(str(srcdir / 'VOI_whole_object.hv')) if (srcdir /
'VOI_whole_object.hv').is_file() else None
voi_masks = {
voi.stem: STIR.ImageData(str(voi))
for voi in srcdir.glob("VOI_*.hv") if voi.stem[4:] not in ('background', 'foreground')}
for voi in srcdir.glob("VOI_*.hv") if voi.stem[4:] not in ('background', 'whole_object')}

return Dataset(acquired_data, additive_term, mult_factors, OSEM_image, prior, kappa, reference_image,
background_mask, foreground_mask, voi_masks)
background_mask, whole_object_mask, voi_masks)


if SRCDIR.is_dir():
Expand Down Expand Up @@ -233,12 +232,12 @@ def get_data(srcdir=".", outdir=OUTDIR, sirf_verbosity=0):
assert issubclass(Submission, Algorithm)
for srcdir, outdir, metrics in data_dirs_metrics:
data = get_data(srcdir=srcdir, outdir=outdir)
metrics_cbk = metrics[0]
metrics_with_timeout = metrics[0]
if data.reference_image is not None:
metrics_cbk.callbacks.append(
QualityMetrics(data.reference_image, data.background_mask, data.foreground_mask,
tb_summary_writer=metrics_cbk.tb, roi_mask_dict=data.voi_masks))
metrics_cbk.reset() # timeout from now
metrics_with_timeout.callbacks.append(
QualityMetrics(data.reference_image, data.whole_object_mask, data.background_mask,
tb_summary_writer=metrics_with_timeout.tb, roi_mask_dict=data.voi_masks))
metrics_with_timeout.reset() # timeout from now
algo = Submission(data)
try:
algo.run(np.inf, callbacks=metrics + submission_callbacks)
Expand Down

0 comments on commit e15033a

Please sign in to comment.