Skip to content

Commit

Permalink
Merge pull request #104 from SyneRBI/sagittal
Browse files Browse the repository at this point in the history
TensorBoard: add sagittal slice
  • Loading branch information
casperdcl authored Sep 11, 2024
2 parents 4bda727 + b1e31da commit 5c1fa8a
Showing 1 changed file with 10 additions and 4 deletions.
14 changes: 10 additions & 4 deletions petric.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,12 @@ def __call__(self, algo: Algorithm):

class StatsLog(Callback):
"""Log image slices & objective value"""
def __init__(self, transverse_slice=None, coronal_slice=None, vmax=None, logdir=OUTDIR, **kwargs):
def __init__(self, transverse_slice=None, coronal_slice=None, sagittal_slice=None, vmax=None, logdir=OUTDIR,
**kwargs):
super().__init__(**kwargs)
self.transverse_slice = transverse_slice
self.coronal_slice = coronal_slice
self.sagittal_slice = sagittal_slice
self.vmax = vmax
self.x_prev = None
self.tb = logdir if isinstance(logdir, SummaryWriter) else SummaryWriter(logdir=str(logdir))
Expand All @@ -89,6 +91,7 @@ def __call__(self, algo: Algorithm):
# initialise `None` values
self.transverse_slice = algo.x.dimensions()[0] // 2 if self.transverse_slice is None else self.transverse_slice
self.coronal_slice = algo.x.dimensions()[1] // 2 if self.coronal_slice is None else self.coronal_slice
self.sagittal_slice = algo.x.dimensions()[2] // 2 if self.sagittal_slice is None else self.sagittal_slice
self.vmax = algo.x.max() if self.vmax is None else self.vmax

self.tb.add_scalar("objective", algo.get_last_loss(), algo.iteration, t)
Expand All @@ -97,9 +100,11 @@ def __call__(self, algo: Algorithm):
self.tb.add_scalar("normalised_change", normalised_change, algo.iteration, t)
self.x_prev = algo.x.clone()
x_arr = algo.x.as_array()
self.tb.add_image("transverse", np.clip(x_arr[self.transverse_slice:self.transverse_slice + 1] / self.vmax, 0,
1), algo.iteration, t)
self.tb.add_image("transverse", np.clip(x_arr[None, self.transverse_slice] / self.vmax, 0, 1), algo.iteration,
t)
self.tb.add_image("coronal", np.clip(x_arr[None, :, self.coronal_slice] / self.vmax, 0, 1), algo.iteration, t)
self.tb.add_image("sagittal", np.clip(x_arr[None, :, :, self.sagittal_slice] / self.vmax, 0, 1), algo.iteration,
t)
log.debug("...logged")


Expand Down Expand Up @@ -148,7 +153,8 @@ def __init__(self, seconds=600, outdir=OUTDIR, transverse_slice=None, coronal_sl
self.callbacks = [
cil_callbacks.ProgressCallback(),
SaveIters(outdir=outdir),
(tb_cbk := StatsLog(logdir=outdir, transverse_slice=transverse_slice, coronal_slice=coronal_slice))]
(tb_cbk := StatsLog(logdir=outdir, transverse_slice=transverse_slice, coronal_slice=coronal_slice,
sagittal_slice=sagittal_slice))]
self.tb = tb_cbk.tb # convenient access to the underlying SummaryWriter
self.reset()

Expand Down

0 comments on commit 5c1fa8a

Please sign in to comment.