Skip to content

Commit

Permalink
callbacks with custom interval
Browse files Browse the repository at this point in the history
- fixes #50
  • Loading branch information
casperdcl committed Jul 12, 2024
1 parent 78d6904 commit 96b66c7
Showing 1 changed file with 35 additions and 20 deletions.
55 changes: 35 additions & 20 deletions petric.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

import sirf.STIR as STIR
from cil.optimisation.algorithms import Algorithm
from cil.optimisation.utilities import callbacks as cbks
from cil.optimisation.utilities import callbacks as cil_callbacks
from img_quality_cil_stir import ImageQualityCallback

log = logging.getLogger('petric')
Expand All @@ -38,17 +38,31 @@
SRCDIR = Path("./data")


class SaveIters(cbks.Callback):
class Callback(cil_callbacks.Callback):
"""
CIL Callback but with `self.skip_iteration` checking `min(self.interval, algo.update_objective_interval)`.
TODO: backport this class to CIL.
"""
def __init__(self, interval: int = 1 << 31, **kwargs):
super().__init__(**kwargs)
self.interval = interval

def skip_iteration(self, algo: Algorithm) -> bool:
return algo.iteration % min(self.interval,
algo.update_objective_interval) != 0 and algo.iteration != algo.max_iteration


class SaveIters(Callback):
"""Saves `algo.x` as "iter_{algo.iteration:04d}.hv" and `algo.loss` in `csv_file`"""
def __init__(self, verbose=1, outdir=OUTDIR, csv_file='objectives.csv'):
super().__init__(verbose)
def __init__(self, outdir=OUTDIR, csv_file='objectives.csv', **kwargs):
super().__init__(**kwargs)
self.outdir = Path(outdir)
self.outdir.mkdir(parents=True, exist_ok=True)
self.csv = csv.writer((self.outdir / csv_file).open("w", buffering=1))
self.csv.writerow(("iter", "objective"))

def __call__(self, algo: Algorithm):
if algo.iteration % algo.update_objective_interval == 0 or algo.iteration == algo.max_iteration:
if not self.skip_iteration(algo):
log.debug("saving iter %d...", algo.iteration)
algo.x.write(str(self.outdir / f'iter_{algo.iteration:04d}.hv'))
self.csv.writerow((algo.iteration, algo.get_last_loss()))
Expand All @@ -57,18 +71,18 @@ def __call__(self, algo: Algorithm):
algo.x.write(str(self.outdir / 'iter_final.hv'))


class StatsLog(cbks.Callback):
class StatsLog(Callback):
"""Log image slices & objective value"""
def __init__(self, verbose=1, transverse_slice=None, coronal_slice=None, vmax=None, logdir=OUTDIR):
super().__init__(verbose)
def __init__(self, transverse_slice=None, coronal_slice=None, vmax=None, logdir=OUTDIR, **kwargs):
super().__init__(**kwargs)
self.transverse_slice = transverse_slice
self.coronal_slice = coronal_slice
self.vmax = vmax
self.x_prev = None
self.tb = logdir if isinstance(logdir, SummaryWriter) else SummaryWriter(logdir=str(logdir))

def __call__(self, algo: Algorithm):
if algo.iteration % algo.update_objective_interval != 0 and algo.iteration != algo.max_iteration:
if self.skip_iteration(algo):
return
log.debug("logging iter %d...", algo.iteration)
# initialise `None` values
Expand All @@ -89,21 +103,22 @@ def __call__(self, algo: Algorithm):
log.debug("...logged")


class QualityMetrics(ImageQualityCallback):
class QualityMetrics(ImageQualityCallback, Callback):
"""From https://github.com/SyneRBI/PETRIC/wiki#metrics-and-thresholds"""
def __init__(self, reference_image, whole_object_mask, background_mask, **kwargs):
super().__init__(reference_image, **kwargs)
def __init__(self, reference_image, whole_object_mask, background_mask, interval: int = 1 << 31, **kwargs):
# TODO: drop multiple inheritance once `interval` included in CIL
Callback.__init__(self, interval=interval)
ImageQualityCallback.__init__(self, reference_image, **kwargs)
self.whole_object_indices = np.where(whole_object_mask.as_array())
self.background_indices = np.where(background_mask.as_array())
self.ref_im_arr = reference_image.as_array()
self.norm = self.ref_im_arr[self.background_indices].mean()

def __call__(self, algo: Algorithm):
iteration = algo.iteration
if iteration % algo.update_objective_interval != 0 and iteration != algo.max_iteration:
if self.skip_iteration(algo):
return
for tag, value in self.evaluate(algo.x).items():
self.tb_summary_writer.add_scalar(tag, value, iteration)
self.tb_summary_writer.add_scalar(tag, value, algo.iteration)

def evaluate(self, test_im: STIR.ImageData) -> dict[str, float]:
assert not any(self.filter.values()), "Filtering not implemented"
Expand All @@ -120,16 +135,16 @@ def evaluate(self, test_im: STIR.ImageData) -> dict[str, float]:
return {**whole, **local}


class MetricsWithTimeout(cbks.Callback):
class MetricsWithTimeout(cil_callbacks.Callback):
"""Stops the algorithm after `seconds`"""
def __init__(self, seconds=300, outdir=OUTDIR, transverse_slice=None, coronal_slice=None, verbose=1):
super().__init__(verbose)
def __init__(self, seconds=300, outdir=OUTDIR, transverse_slice=None, coronal_slice=None, **kwargs):
super().__init__(**kwargs)
self._seconds = seconds
self.callbacks = [
cbks.ProgressCallback(),
cil_callbacks.ProgressCallback(),
SaveIters(outdir=outdir),
(tb_cbk := StatsLog(logdir=outdir, transverse_slice=transverse_slice, coronal_slice=coronal_slice))]
self.tb = tb_cbk.tb
self.tb = tb_cbk.tb # convenient access to the underlying SummaryWriter
self.reset()

def reset(self, seconds=None):
Expand Down

0 comments on commit 96b66c7

Please sign in to comment.