Skip to content

Commit

Permalink
Refactored OptimisationLoop dumping
Browse files Browse the repository at this point in the history
  • Loading branch information
ccuetom committed Nov 15, 2024
1 parent c58a7ba commit 3bb49b1
Show file tree
Hide file tree
Showing 6 changed files with 217 additions and 69 deletions.
28 changes: 14 additions & 14 deletions mosaic/file_manipulation/h5.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,14 +92,10 @@ def append(name, obj, group):
append(sub_group_name, obj[index], sub_group)

else:
if name not in group:
_write_dataset(name, obj, group)
_write_dataset(name, obj, group)


def _write_dataset(name, obj, group):
if name in group:
return

is_bytes = False
if isinstance(obj, bytes):
is_bytes = True
Expand All @@ -110,17 +106,21 @@ def _write_dataset(name, obj, group):
is_none = True
obj = 'None'

dataset = group.create_dataset(name, data=obj)
dataset.attrs['is_ndarray'] = isinstance(obj, np.ndarray)
dataset.attrs['is_list'] = isinstance(obj, list)
dataset.attrs['is_tuple'] = isinstance(obj, tuple)
dataset.attrs['is_str'] = isinstance(obj, str)
dataset.attrs['is_bytes'] = is_bytes
dataset.attrs['is_none'] = is_none
if name not in group:
dataset = group.create_dataset(name, data=obj)
dataset.attrs['is_ndarray'] = isinstance(obj, np.ndarray)
dataset.attrs['is_list'] = isinstance(obj, list)
dataset.attrs['is_tuple'] = isinstance(obj, tuple)
dataset.attrs['is_str'] = isinstance(obj, str)
dataset.attrs['is_bytes'] = is_bytes
dataset.attrs['is_none'] = is_none

if isinstance(obj, list) and len(obj):
flat_obj = np.asarray(obj).flatten().tolist()
dataset.attrs['is_str'] = isinstance(flat_obj[0], str)
if name not in group:
dataset.attrs['is_str'] = isinstance(flat_obj[0], str)

else:
group[name][...] = obj


def read(obj, lazy=True, filter=None, only=None):
Expand Down
3 changes: 0 additions & 3 deletions stride/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,9 +412,6 @@ def _dealloc(*args):

await asyncio.gather(*summ_returns)

# loop = mosaic.get_event_loop()
# loop.run(asyncio.gather, *parallel_returns)

self.clear_graph()

return self
Expand Down
9 changes: 7 additions & 2 deletions stride/optimisation/loss/l2_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,16 @@ def forward(self, modelled, observed, **kwargs):
else kwargs.pop('shot_id', 0)

residual_data = modelled.data-observed.data
residual = observed.alike(name='residual', data=residual_data[:, ::self.d_sample])
residual_data_sampled = residual_data[:, ::self.d_sample]

residual = observed.alike(name='residual', data=residual_data)
residual_sampled = observed.alike(name='residual', data=residual_data_sampled,
shape=residual_data_sampled.shape,
extended_shape=residual_data_sampled.shape)
self.residual = residual

fun_data = 0.5 * np.sum(residual.data ** 2)
fun = FunctionalValue(fun_data, shot_id, residual, **kwargs)
fun = FunctionalValue(fun_data, shot_id, residual_sampled, **kwargs)

return fun

Expand Down
Loading

0 comments on commit 3bb49b1

Please sign in to comment.