Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TensorBoard: add sagittal slice #104

Merged
merged 1 commit into from
Sep 11, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions petric.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,12 @@ def __call__(self, algo: Algorithm):

class StatsLog(Callback):
"""Log image slices & objective value"""
def __init__(self, transverse_slice=None, coronal_slice=None, vmax=None, logdir=OUTDIR, **kwargs):
def __init__(self, transverse_slice=None, coronal_slice=None, sagittal_slice=None, vmax=None, logdir=OUTDIR,
**kwargs):
super().__init__(**kwargs)
self.transverse_slice = transverse_slice
self.coronal_slice = coronal_slice
self.sagittal_slice = sagittal_slice
self.vmax = vmax
self.x_prev = None
self.tb = logdir if isinstance(logdir, SummaryWriter) else SummaryWriter(logdir=str(logdir))
Expand All @@ -89,6 +91,7 @@ def __call__(self, algo: Algorithm):
# initialise `None` values
self.transverse_slice = algo.x.dimensions()[0] // 2 if self.transverse_slice is None else self.transverse_slice
self.coronal_slice = algo.x.dimensions()[1] // 2 if self.coronal_slice is None else self.coronal_slice
self.sagittal_slice = algo.x.dimensions()[2] // 2 if self.sagittal_slice is None else self.sagittal_slice
self.vmax = algo.x.max() if self.vmax is None else self.vmax

self.tb.add_scalar("objective", algo.get_last_loss(), algo.iteration, t)
Expand All @@ -97,9 +100,11 @@ def __call__(self, algo: Algorithm):
self.tb.add_scalar("normalised_change", normalised_change, algo.iteration, t)
self.x_prev = algo.x.clone()
x_arr = algo.x.as_array()
self.tb.add_image("transverse", np.clip(x_arr[self.transverse_slice:self.transverse_slice + 1] / self.vmax, 0,
1), algo.iteration, t)
self.tb.add_image("transverse", np.clip(x_arr[None, self.transverse_slice] / self.vmax, 0, 1), algo.iteration,
t)
self.tb.add_image("coronal", np.clip(x_arr[None, :, self.coronal_slice] / self.vmax, 0, 1), algo.iteration, t)
self.tb.add_image("sagittal", np.clip(x_arr[None, :, :, self.sagittal_slice] / self.vmax, 0, 1), algo.iteration,
t)
log.debug("...logged")


Expand Down Expand Up @@ -148,7 +153,8 @@ def __init__(self, seconds=600, outdir=OUTDIR, transverse_slice=None, coronal_sl
self.callbacks = [
cil_callbacks.ProgressCallback(),
SaveIters(outdir=outdir),
(tb_cbk := StatsLog(logdir=outdir, transverse_slice=transverse_slice, coronal_slice=coronal_slice))]
(tb_cbk := StatsLog(logdir=outdir, transverse_slice=transverse_slice, coronal_slice=coronal_slice,
sagittal_slice=sagittal_slice))]
self.tb = tb_cbk.tb # convenient access to the underlying SummaryWriter
self.reset()

Expand Down