From 72161123d7e5ee2cbcf7626cda603dd58953aef1 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Thu, 14 Nov 2024 22:37:30 -0800 Subject: [PATCH] small tweaks to mixed-state virtual-detector ptycho --- .../process/phase/ptychographic_methods.py | 96 ++++++++++++------- 1 file changed, 64 insertions(+), 32 deletions(-) diff --git a/py4DSTEM/process/phase/ptychographic_methods.py b/py4DSTEM/process/phase/ptychographic_methods.py index 1a808a3d5..e88c97141 100644 --- a/py4DSTEM/process/phase/ptychographic_methods.py +++ b/py4DSTEM/process/phase/ptychographic_methods.py @@ -1479,10 +1479,10 @@ def _initialize_probe( for i_probe in range(1, num_probes): shift_x = xp.exp( -2j * np.pi * (xp.random.rand() - 0.5) * xp.fft.fftfreq(sx) - ) + ).astype(xp.complex64) shift_y = xp.exp( -2j * np.pi * (xp.random.rand() - 0.5) * xp.fft.fftfreq(sy) - ) + ).astype(xp.complex64) _probes[i_probe] = ( _probes[i_probe - 1] * shift_x[:, None] * shift_y[None] ) @@ -1750,31 +1750,44 @@ def _gradient_descent_fourier_projection( xp=xp, ) + if fourier_mask is not None: + fourier_overlap *= fourier_mask + if virtual_detector_masks is not None: + mask_sums = virtual_detector_masks.sum((-1, -2)) inverse_mask = (1 - virtual_detector_masks.sum(0)).astype(xp.bool_) - old_fourier_overlap_sum = xp.sum(xp.abs(fourier_overlap) ** 2) - - # serial loop to allow large number of detector masks - for mask in virtual_detector_masks: - val = xp.sum(fourier_overlap * mask, axis=(-1, -2)) / xp.sum(mask) - fourier_overlap[..., mask] = val[:, None] - + abs_fourier_overlap = self._return_farfield_amplitudes(fourier_overlap) + old_fourier_overlap_sum = xp.sum(abs_fourier_overlap**2) fourier_overlap[..., inverse_mask] = 0.0 - # normalize to avoid losing electrons + fourier_overlap_binned = xp.full_like( + fourier_overlap, fill_value=1e-16, dtype=xp.float32 + ) + for mask, mask_sum in zip(virtual_detector_masks, mask_sums): + val = xp.sqrt( + xp.sum(abs_fourier_overlap**2 * mask, axis=(-1, -2)) / mask_sum + ) + fourier_overlap_binned[..., mask] = val[:, None] + new_fourier_overlap_sum = xp.sum(xp.abs(fourier_overlap) ** 2) - fourier_overlap *= xp.sqrt( + fourier_overlap_binned *= xp.sqrt( old_fourier_overlap_sum / new_fourier_overlap_sum ) + fourier_modified_overlap = ( + fourier_overlap * amplitudes / fourier_overlap_binned + ) + farfield_amplitudes = self._return_farfield_amplitudes( + fourier_overlap_binned + ) + else: + fourier_modified_overlap = amplitudes * xp.exp( + 1j * xp.angle(fourier_overlap) + ) + farfield_amplitudes = self._return_farfield_amplitudes(fourier_overlap) - if fourier_mask is not None: - fourier_overlap *= fourier_mask - - farfield_amplitudes = self._return_farfield_amplitudes(fourier_overlap) error = xp.sum(xp.abs(amplitudes - farfield_amplitudes) ** 2) - fourier_modified_overlap = amplitudes * xp.exp(1j * xp.angle(fourier_overlap)) - fourier_modified_overlap = fourier_modified_overlap - fourier_overlap + if fourier_mask is not None: fourier_modified_overlap *= fourier_mask @@ -3010,27 +3023,46 @@ def _gradient_descent_fourier_projection( xp=xp, ) - if virtual_detector_masks is not None: - masked_values = xp.sum( - fourier_overlap[:, :, None, :, :] - * virtual_detector_masks[None, None, :, :, :], - axis=(-1, -2), - ).transpose(2, 0, 1) - fourier_overlap = xp.zeros_like(fourier_overlap) - for mask, value in zip(virtual_detector_masks, masked_values): - fourier_overlap[..., mask] = value[:, :, None] / xp.sum(mask) - if fourier_mask is not None: fourier_overlap *= fourier_mask - farfield_amplitudes = self._return_farfield_amplitudes(fourier_overlap) - error = xp.sum(xp.abs(amplitudes - farfield_amplitudes) ** 2) + if virtual_detector_masks is not None: + mask_sums = virtual_detector_masks.sum((-1, -2)) + inverse_mask = (1 - virtual_detector_masks.sum(0)).astype(xp.bool_) + abs_fourier_overlap = self._return_farfield_amplitudes(fourier_overlap) + old_fourier_overlap_sum = xp.sum(abs_fourier_overlap**2) + fourier_overlap[..., inverse_mask] = 0.0 - farfield_amplitudes[farfield_amplitudes == 0.0] = np.inf - amplitude_modification = amplitudes / farfield_amplitudes + fourier_overlap_binned = xp.full_like( + fourier_overlap, fill_value=1e-16, dtype=xp.float32 + ) + for mask, mask_sum in zip(virtual_detector_masks, mask_sums): + val = xp.sqrt( + xp.sum(abs_fourier_overlap**2 * mask, axis=(-1, -2)) + / mask_sum + / self._num_probes + ) + fourier_overlap_binned[..., mask] = val[:, None, None] - fourier_modified_overlap = amplitude_modification[:, None] * fourier_overlap + abs_fourier_overlap = self._return_farfield_amplitudes(fourier_overlap) + new_fourier_overlap_sum = xp.sum(abs_fourier_overlap**2) + fourier_overlap_binned *= xp.sqrt( + old_fourier_overlap_sum / new_fourier_overlap_sum + ) + fourier_modified_overlap = ( + fourier_overlap * amplitudes[:, None] / fourier_overlap_binned + ) + farfield_amplitudes = self._return_farfield_amplitudes( + fourier_overlap_binned + ) + else: + farfield_amplitudes = self._return_farfield_amplitudes(fourier_overlap) + farfield_amplitudes[farfield_amplitudes == 0.0] = np.inf + amplitude_modification = amplitudes / farfield_amplitudes + fourier_modified_overlap = amplitude_modification[:, None] * fourier_overlap + + error = xp.sum(xp.abs(amplitudes - farfield_amplitudes) ** 2) fourier_modified_overlap = fourier_modified_overlap - fourier_overlap if fourier_mask is not None: