Skip to content

Commit

Permalink
black formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
bsavitzky committed Nov 18, 2024
1 parent 13357b1 commit d3e227b
Show file tree
Hide file tree
Showing 13 changed files with 97 additions and 84 deletions.
1 change: 0 additions & 1 deletion py4DSTEM/io/filereaders/read_mib.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,6 @@ def scan_size(path, scan):
header_path = path[:-3] + "hdr"
result = {}
if os.path.exists(header_path):

with open(header_path, encoding="UTF-8") as f:
for line in f:
k, v = line.split("\t", 1)
Expand Down
4 changes: 3 additions & 1 deletion py4DSTEM/process/diffraction/WK_scattering_factors.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,9 @@ def RI1(BI, BJ, G):
ri1[sub] = np.pi * (BI * np.log((BI + BJ) / BI) + BJ * np.log((BI + BJ) / BJ))

sub = np.logical_and(eps <= 0.1, G > 0.0)
temp = 0.5 * BI**2 * np.log(BI / (BI + BJ)) + 0.5 * BJ**2 * np.log(BJ / (BI + BJ))
temp = 0.5 * BI**2 * np.log(BI / (BI + BJ)) + 0.5 * BJ**2 * np.log(
BJ / (BI + BJ)
)
temp += 0.75 * (BI**2 + BJ**2) - 0.25 * (BI + BJ) ** 2
temp -= 0.5 * (BI - BJ) ** 2
ri1[sub] += np.pi * G[sub] ** 2 * temp
Expand Down
3 changes: 2 additions & 1 deletion py4DSTEM/process/diffraction/crystal_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,7 +453,8 @@ def plot_scattering_intensity(
int_sf_plot = calc_1D_profile(
k,
self.g_vec_leng,
(self.struct_factors_int**int_power_scale) * (self.g_vec_leng**k_power_scale),
(self.struct_factors_int**int_power_scale)
* (self.g_vec_leng**k_power_scale),
remove_origin=True,
k_broadening=k_broadening,
int_scale=int_scale,
Expand Down
4 changes: 2 additions & 2 deletions py4DSTEM/process/diffraction/digital_dark_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,7 +496,7 @@ def DDFimage(points_array, aperture_positions, Rshape=None, tol=1):

def radial_filtered_array(points_array_w_rphi, radius, tol=1):
"""
Calculates a Filtered points array from a list of detected diffraction peak positions in a points_array
Calculates a Filtered points array from a list of detected diffraction peak positions in a points_array
matching a specific qr radius, within a defined matching tolerance
Parameters
Expand Down Expand Up @@ -579,7 +579,7 @@ def DDF_radial_image(points_array_w_rphi, radius, Rshape, tol=1):

def DDFradialazimuthimage(points_array_w_rphi, radius, phi0, phi1, Rshape, tol=1):
"""
Calculates a Digital Dark Field image from a list of detected diffraction peak positions in a points_array
Calculates a Digital Dark Field image from a list of detected diffraction peak positions in a points_array
matching a specific qr radius, within a defined matching tolerance, and only within a defined azimuthal range
Parameters
Expand Down
3 changes: 2 additions & 1 deletion py4DSTEM/process/fit/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,8 @@ def polar_gaussian_2D(
# t2 = np.min(np.vstack([t,1-t]))
t2 = np.square(t - mu_t)
return (
I0 * np.exp(-(t2 / (2 * sigma_t**2) + (q - mu_q) ** 2 / (2 * sigma_q**2))) + C
I0 * np.exp(-(t2 / (2 * sigma_t**2) + (q - mu_q) ** 2 / (2 * sigma_q**2)))
+ C
)


Expand Down
28 changes: 14 additions & 14 deletions py4DSTEM/process/phase/magnetic_ptychographic_tomography.py
Original file line number Diff line number Diff line change
Expand Up @@ -1202,20 +1202,20 @@ def reconstruct(

# position correction
if not fix_positions and a0 > 0:
self._positions_px_all[batch_indices] = (
self._position_correction(
object_sliced,
vectorized_patch_indices_row,
vectorized_patch_indices_col,
shifted_probes,
overlap,
amplitudes_device,
positions_px,
positions_px_initial,
positions_step_size,
max_position_update_distance,
max_position_total_distance,
)
self._positions_px_all[
batch_indices
] = self._position_correction(
object_sliced,
vectorized_patch_indices_row,
vectorized_patch_indices_col,
shifted_probes,
overlap,
amplitudes_device,
positions_px,
positions_px_initial,
positions_step_size,
max_position_update_distance,
max_position_total_distance,
)

measurement_error += batch_error
Expand Down
28 changes: 14 additions & 14 deletions py4DSTEM/process/phase/magnetic_ptychography.py
Original file line number Diff line number Diff line change
Expand Up @@ -1510,20 +1510,20 @@ def reconstruct(

# position correction
if not fix_positions and a0 > 0:
self._positions_px_all[batch_indices] = (
self._position_correction(
self._object,
vectorized_patch_indices_row,
vectorized_patch_indices_col,
shifted_probes,
overlap,
amplitudes_device,
positions_px,
positions_px_initial,
positions_step_size,
max_position_update_distance,
max_position_total_distance,
)
self._positions_px_all[
batch_indices
] = self._position_correction(
self._object,
vectorized_patch_indices_row,
vectorized_patch_indices_col,
shifted_probes,
overlap,
amplitudes_device,
positions_px,
positions_px_initial,
positions_step_size,
max_position_update_distance,
max_position_total_distance,
)

measurement_error += batch_error
Expand Down
22 changes: 11 additions & 11 deletions py4DSTEM/process/phase/parallax.py
Original file line number Diff line number Diff line change
Expand Up @@ -884,15 +884,17 @@ def guess_common_aberrations(
sampling = 1 / (
np.array(self._reciprocal_sampling) * self._region_of_interest_shape
)
aberrations_basis, aberrations_basis_du, aberrations_basis_dv = (
calculate_aberration_gradient_basis(
aberrations_mn,
sampling,
self._region_of_interest_shape,
self._wavelength,
rotation_angle=np.deg2rad(rotation_angle_deg),
xp=xp,
)
(
aberrations_basis,
aberrations_basis_du,
aberrations_basis_dv,
) = calculate_aberration_gradient_basis(
aberrations_mn,
sampling,
self._region_of_interest_shape,
self._wavelength,
rotation_angle=np.deg2rad(rotation_angle_deg),
xp=xp,
)

# shifts
Expand Down Expand Up @@ -2432,7 +2434,6 @@ def score_CTF(coefs):

# Plot the measured/fitted shifts comparison
if plot_BF_shifts_comparison:

fitted_shifts = (
xp.tensordot(gradients, xp.array(self._aberrations_coefs), axes=1)
.reshape((2, -1))
Expand Down Expand Up @@ -3055,7 +3056,6 @@ def show_shifts(
shifts = shifts_px * scale_arrows * xp.array(self._reciprocal_sampling)

if plot_rotated_shifts and hasattr(self, "rotation_Q_to_R_rads"):

if figax is None:
figsize = kwargs.pop("figsize", (8, 4))
fig, ax = plt.subplots(1, 2, figsize=figsize)
Expand Down
4 changes: 3 additions & 1 deletion py4DSTEM/process/phase/ptychographic_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,9 @@ def _precompute_propagator_arrays(
propagators[i] = xp.exp(
1.0j * (-(kx**2)[:, None] * np.pi * wavelength * dz)
)
propagators[i] *= xp.exp(1.0j * (-(ky**2)[None] * np.pi * wavelength * dz))
propagators[i] *= xp.exp(
1.0j * (-(ky**2)[None] * np.pi * wavelength * dz)
)

if theta_x is not None:
propagators[i] *= xp.exp(
Expand Down
28 changes: 14 additions & 14 deletions py4DSTEM/process/phase/ptychographic_tomography.py
Original file line number Diff line number Diff line change
Expand Up @@ -1093,20 +1093,20 @@ def reconstruct(

# position correction
if not fix_positions:
self._positions_px_all[batch_indices] = (
self._position_correction(
object_sliced,
vectorized_patch_indices_row,
vectorized_patch_indices_col,
shifted_probes,
overlap,
amplitudes_device,
positions_px,
positions_px_initial,
positions_step_size,
max_position_update_distance,
max_position_total_distance,
)
self._positions_px_all[
batch_indices
] = self._position_correction(
object_sliced,
vectorized_patch_indices_row,
vectorized_patch_indices_col,
shifted_probes,
overlap,
amplitudes_device,
positions_px,
positions_px_initial,
positions_step_size,
max_position_update_distance,
max_position_total_distance,
)

measurement_error += batch_error
Expand Down
4 changes: 3 additions & 1 deletion py4DSTEM/process/phase/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,9 @@ def evaluate_gaussian_envelope(
self, alpha: Union[float, np.ndarray]
) -> Union[float, np.ndarray]:
xp = self._xp
return xp.exp(-0.5 * self._gaussian_spread**2 * alpha**2 / self._wavelength**2)
return xp.exp(
-0.5 * self._gaussian_spread**2 * alpha**2 / self._wavelength**2
)

def evaluate_spatial_envelope(
self, alpha: Union[float, np.ndarray], phi: Union[float, np.ndarray]
Expand Down
34 changes: 14 additions & 20 deletions py4DSTEM/process/phase/xray_magnetic_ptychography.py
Original file line number Diff line number Diff line change
Expand Up @@ -892,7 +892,6 @@ def _gradient_descent_adjoint(

match (self._recon_mode, self._active_measurement_index):
case (0, 0) | (1, 0): # reverse

magnetic_conj = xp.exp(1.0j * xp.conj(object_patches[1]))

probe_magnetic_abs = xp.abs(shifted_probes * magnetic_conj)
Expand Down Expand Up @@ -930,7 +929,6 @@ def _gradient_descent_adjoint(
)

if not fix_probe:

electrostatic_magnetic_abs = xp.abs(
electrostatic_conj * magnetic_conj
)
Expand Down Expand Up @@ -962,7 +960,6 @@ def _gradient_descent_adjoint(
)

case (0, 1) | (1, 2) | (2, 1): # forward

magnetic_conj = xp.exp(-1.0j * xp.conj(object_patches[1]))

probe_magnetic_abs = xp.abs(shifted_probes * magnetic_conj)
Expand Down Expand Up @@ -992,7 +989,6 @@ def _gradient_descent_adjoint(
)

if not fix_probe:

electrostatic_magnetic_abs = xp.abs(
electrostatic_conj * magnetic_conj
)
Expand Down Expand Up @@ -1024,7 +1020,6 @@ def _gradient_descent_adjoint(
)

case (1, 1) | (2, 0): # neutral

probe_abs = xp.abs(shifted_probes)
probe_normalization = self._sum_overlapping_patches_bincounts(
probe_abs**2,
Expand All @@ -1047,7 +1042,6 @@ def _gradient_descent_adjoint(
)

if not fix_probe:

electrostatic_abs = xp.abs(electrostatic_conj)
electrostatic_normalization = xp.sum(
electrostatic_abs**2,
Expand Down Expand Up @@ -1482,20 +1476,20 @@ def reconstruct(

# position correction
if not fix_positions and a0 > 0:
self._positions_px_all[batch_indices] = (
self._position_correction(
self._object,
vectorized_patch_indices_row,
vectorized_patch_indices_col,
shifted_probes,
overlap,
amplitudes_device,
positions_px,
positions_px_initial,
positions_step_size,
max_position_update_distance,
max_position_total_distance,
)
self._positions_px_all[
batch_indices
] = self._position_correction(
self._object,
vectorized_patch_indices_row,
vectorized_patch_indices_col,
shifted_probes,
overlap,
amplitudes_device,
positions_px,
positions_px_initial,
positions_step_size,
max_position_update_distance,
max_position_total_distance,
)

measurement_error += batch_error
Expand Down
18 changes: 15 additions & 3 deletions py4DSTEM/process/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,12 @@ def electron_wavelength_angstrom(E_eV):
c = 299792458
h = 6.62607 * 10**-34

lam = h / ma.sqrt(2 * m * e * E_eV) / ma.sqrt(1 + e * E_eV / 2 / m / c**2) * 10**10
lam = (
h
/ ma.sqrt(2 * m * e * E_eV)
/ ma.sqrt(1 + e * E_eV / 2 / m / c**2)
* 10**10
)
return lam


Expand All @@ -102,8 +107,15 @@ def electron_interaction_parameter(E_eV):
e = 1.602177 * 10**-19
c = 299792458
h = 6.62607 * 10**-34
lam = h / ma.sqrt(2 * m * e * E_eV) / ma.sqrt(1 + e * E_eV / 2 / m / c**2) * 10**10
sigma = (2 * np.pi / lam / E_eV) * (m * c**2 + e * E_eV) / (2 * m * c**2 + e * E_eV)
lam = (
h
/ ma.sqrt(2 * m * e * E_eV)
/ ma.sqrt(1 + e * E_eV / 2 / m / c**2)
* 10**10
)
sigma = (
(2 * np.pi / lam / E_eV) * (m * c**2 + e * E_eV) / (2 * m * c**2 + e * E_eV)
)
return sigma


Expand Down

0 comments on commit d3e227b

Please sign in to comment.