diff --git a/python/stempy/io/__init__.py b/python/stempy/io/__init__.py index d9b9efd3..22ffc9a3 100644 --- a/python/stempy/io/__init__.py +++ b/python/stempy/io/__init__.py @@ -259,16 +259,17 @@ def save_electron_counts(path, array): """ array.write_to_hdf5(path) -def load_electron_counts(path): +def load_electron_counts(path, keep_flyback=True): """Load electron counted data from an HDF5 file. :param path: path to the HDF5 file. :type path: str - + :param keep_flyback: option to crop the flyback column during loading + :type keep_flyback: bool :return: a SparseArray containing the electron counted data :rtype: SparseArray """ - return SparseArray.from_hdf5(path) + return SparseArray.from_hdf5(path, keep_flyback=keep_flyback) def save_stem_images(outputFile, images, names): """Save STEM images to an HDF5 file. diff --git a/python/stempy/io/sparse_array.py b/python/stempy/io/sparse_array.py index 70e83e18..cec4a04b 100644 --- a/python/stempy/io/sparse_array.py +++ b/python/stempy/io/sparse_array.py @@ -170,11 +170,13 @@ def _validate(self): raise Exception(msg) @classmethod - def from_hdf5(cls, filepath, **init_kwargs): + def from_hdf5(cls, filepath, keep_flyback=True, **init_kwargs): """Create a SparseArray from a stempy HDF5 file :param filepath: the path to the HDF5 file :type filepath: str + :param keep_flyback: option to crop the flyback column during loading + :type keep_flyback: bool :param init_kwargs: any kwargs to forward to SparseArray.__init__() :type init_kwargs: dict @@ -188,19 +190,31 @@ def from_hdf5(cls, filepath, **init_kwargs): frames = f['electron_events/frames'] scan_positions_group = f['electron_events/scan_positions'] - - data = frames[()] scan_shape = [scan_positions_group.attrs[x] for x in ['Nx', 'Ny']] frame_shape = [frames.attrs[x] for x in ['Nx', 'Ny']] - - scan_positions = scan_positions_group[()] + + if keep_flyback: + data = frames[()] # load the full data set + scan_positions = scan_positions_group[()] + else: + # Generate the original scan indices from the scan_shape + orig_indices = np.ravel_multi_index([ii.ravel() for ii in np.indices(scan_shape)],scan_shape) + # Remove the indices of the last column + crop_indices = np.delete(orig_indices, orig_indices[scan_shape[0]-1::scan_shape[0]]) + # Load only the data needed + data = frames[crop_indices] + # Reduce the column shape by 1 + scan_shape[0] = scan_shape[0] - 1 + # Create the proper scan_positions without the flyback column + scan_positions = np.ravel_multi_index([ii.ravel() for ii in np.indices(scan_shape)],scan_shape) + # Load any metadata metadata = {} if 'metadata' in f: load_h5_to_dict(f['metadata'], metadata) scan_shape = scan_shape[::-1] - + if version >= 3: # Convert to int to avoid integer division that results in # a float diff --git a/tests/test_sparse_array.py b/tests/test_sparse_array.py index 98cdf5e9..e088884c 100644 --- a/tests/test_sparse_array.py +++ b/tests/test_sparse_array.py @@ -710,6 +710,12 @@ def compare_with_sparse(full, sparse): assert np.array_equal(m_array[[False, True], 0][0], position_one) +def test_keep_flyback(electron_data_small): + flyback = SparseArray.from_hdf5(electron_data_small, keep_flyback=True) + assert flyback.scan_shape[1] == 50 + no_flyback = SparseArray.from_hdf5(electron_data_small, keep_flyback=False) + assert no_flyback.scan_shape[1] == 49 + # Test binning until this number TEST_BINNING_UNTIL = 33