diff --git a/src/ttsim3d/device_handler.py b/src/ttsim3d/device_handler.py index f80c1fe..11437bf 100644 --- a/src/ttsim3d/device_handler.py +++ b/src/ttsim3d/device_handler.py @@ -1,6 +1,6 @@ """Handles cpu/gpu device selection.""" -from typing import Optional +from typing import Optional, Union import psutil import torch @@ -60,6 +60,53 @@ def calculate_batches( return num_batches, atoms_per_batch +def get_device(gpu_ids: Optional[Union[int, list[int]]] = None) -> torch.device: + """Get the appropriate torch device based on availability and user preference. + + Parameters + ---------- + gpu_ids : Optional[Union[int, list[int]]] + Device selection preference: + - None: Use CPU + - -1: Use first available GPU (CUDA or MPS) + - >=0: Use specific CUDA device + - list[int]: Use specific CUDA devices (for multi-GPU) + + Returns + ------- + torch.device + The selected compute device + """ + # Default to CPU + if gpu_ids is None: + return torch.device("cpu") + + # Check for CUDA availability + if torch.cuda.is_available(): + if isinstance(gpu_ids, list): + # Multi-GPU not yet implemented + return torch.device(f"cuda:{gpu_ids[0]}") + elif gpu_ids >= 0: + return torch.device(f"cuda:{gpu_ids}") + else: # gpu_ids == -1 + return torch.device("cuda:0") + + # Check for MPS (Apple Silicon) availability + elif torch.backends.mps.is_available(): + if gpu_ids is not None: # User requested GPU + return torch.device("mps") + + # Fallback to CPU + return torch.device("cpu") + + +def move_tensor_to_device(tensor: torch.Tensor, device: torch.device) -> torch.Tensor: + """Move a tensor to the specified device if it's not already there.""" + if tensor.device != device: + return tensor.to(device) + return tensor + + def select_gpu( gpu_id: Optional[int] = None, ) -> torch.device: diff --git a/src/ttsim3d/grid_coords.py b/src/ttsim3d/grid_coords.py index 8761d9d..68b37c9 100644 --- a/src/ttsim3d/grid_coords.py +++ b/src/ttsim3d/grid_coords.py @@ -58,15 +58,20 @@ def get_atom_voxel_indices( tuple[torch.Tensor,torch.Tensor] The voxel indices and the offset from the edge of the voxel. """ + # Move to device + device = atom_zyx.device + shape_tensor = torch.tensor(upsampled_shape, device=device) + offset_tensor = torch.tensor(offset, device=device) + pixel_size_tensor = torch.tensor(upsampled_pixel_size, device=device) + origin_idx = ( - upsampled_shape[0] / 2, - upsampled_shape[1] / 2, - upsampled_shape[2] / 2, + shape_tensor[0] / 2, + shape_tensor[1] / 2, + shape_tensor[2] / 2, ) + origin_idx_tensor = torch.tensor(origin_idx, device=device) this_coords = ( - (atom_zyx / upsampled_pixel_size) - + torch.tensor(origin_idx).unsqueeze(0) - + offset + (atom_zyx / pixel_size_tensor) + origin_idx_tensor.unsqueeze(0) + offset_tensor ) atom_indices = torch.floor(this_coords) # these are the voxel indices atom_dds = ( @@ -120,11 +125,14 @@ def get_voxel_neighborhood_offsets( The offsets of the voxel neighborhood. """ + device = mean_b_factor.device if isinstance(mean_b_factor, torch.Tensor) else "cpu" # Get the size of the voxel neighbourhood to calculate the potential of each atom size_neighborhood = get_size_neighborhood_cistem( mean_b_factor, upsampled_pixel_size ) - neighborhood_range = torch.arange(-size_neighborhood, size_neighborhood + 1) + neighborhood_range = torch.arange( + -size_neighborhood, size_neighborhood + 1, device=device + ) # Create coordinate grids for the neighborhood sz, sy, sx = torch.meshgrid( neighborhood_range, neighborhood_range, neighborhood_range, indexing="ij" diff --git a/src/ttsim3d/models.py b/src/ttsim3d/models.py index 6aa6077..08704a4 100644 --- a/src/ttsim3d/models.py +++ b/src/ttsim3d/models.py @@ -2,7 +2,7 @@ import os import pathlib -from typing import Annotated, Any, Optional +from typing import Annotated, Any, Optional, Union import torch from pydantic import ( @@ -246,18 +246,19 @@ def get_scale_atom_b_factors(self) -> torch.Tensor: def run( self, - gpu_ids: Optional[int | list[int]] = None, + gpu_ids: Optional[Union[int, list[int]]] = None, atom_indices: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Runs the simulation and returns the simulated volume. Parameters ---------- - gpu_ids: int | list[int] - A list of GPU IDs to use for the simulation. The default is 'None' - which will use the CPU. A value of '-1' will use all available - GPUs, otherwise a list of integers greater than or equal to 0 are - expected. + gpu_ids : Optional[Union[int, list[int]]] + Device selection: + - None: Use CPU + - -1: Use first available GPU + - >=0: Use specific CUDA device + - list[int]: Use specific CUDA devices (future multi-GPU support) atom_indices: torch.Tensor The indices of the atoms to simulate. The default is 'None' which will simulate all atoms in the structure. @@ -298,6 +299,7 @@ def run( apply_dqe=self.simulator_config.apply_dqe, mtf_frequencies=mtf_frequencies, mtf_amplitudes=mtf_amplitudes, + gpu_ids=gpu_ids, ) if self.simulator_config.store_volume: diff --git a/src/ttsim3d/scattering_potential.py b/src/ttsim3d/scattering_potential.py index c49a71b..1fa81b0 100644 --- a/src/ttsim3d/scattering_potential.py +++ b/src/ttsim3d/scattering_potential.py @@ -96,7 +96,7 @@ def get_total_b_param( bPlusB: torch.Tensor Total B parameter for each atom in the neighborhood. """ - b_params = get_b_param(atom_ids) + b_params = get_b_param(atom_ids).to(atom_b_factors.device) bPlusB = 2 * torch.pi / torch.sqrt(atom_b_factors.unsqueeze(1) + b_params) return bPlusB @@ -149,8 +149,8 @@ def get_scattering_potential_of_voxel_batch( device = zyx_coords1.device # Get scattering parameters for atoms - params_a = get_a_param(atom_ids) - params_bPlusB = get_total_b_param(atom_ids, atom_b_factors) + params_a = get_a_param(atom_ids).to(device) + params_bPlusB = get_total_b_param(atom_ids, atom_b_factors).to(device) # Compare signs element-wise for batched coordinates t_all = (zyx_coords1 * zyx_coords2) >= 0 # Shape: (atomN, voxelN, 3) diff --git a/src/ttsim3d/simulate3d.py b/src/ttsim3d/simulate3d.py index 1cd0435..68b8acb 100644 --- a/src/ttsim3d/simulate3d.py +++ b/src/ttsim3d/simulate3d.py @@ -1,6 +1,6 @@ """Simulation of 3D volume and associated helper functions.""" -from typing import Literal +from typing import Literal, Optional, Union import einops import numpy as np @@ -9,7 +9,7 @@ from torch_fourier_filter.dose_weight import cumulative_dose_filter_3d from torch_fourier_filter.mtf import make_mtf_grid -from ttsim3d.device_handler import calculate_batches +from ttsim3d.device_handler import calculate_batches, get_device, move_tensor_to_device from ttsim3d.grid_coords import ( fourier_rescale_3d_force_size, get_atom_voxel_indices, @@ -349,6 +349,7 @@ def calculate_simulation_dose_filter_3d( modify_signal: Literal["None", "sqrt", "rel_diff"], rfft: bool = True, fftshift: bool = False, + device: torch.device = None, ) -> torch.Tensor: """Helper function to calculate a cumulative dose filter for a simulation. @@ -373,6 +374,8 @@ def calculate_simulation_dose_filter_3d( If True, the filter is returned in rfft format. Default is True. fftshift : bool If True, the filter is fftshifted. Default is False. + device : torch.device + The device to use for the calculation. Default is None. Returns ------- @@ -385,6 +388,7 @@ def calculate_simulation_dose_filter_3d( crit_exposure_bfactor=critical_bfactor, rfft=rfft, fftshift=fftshift, + device=device, ) if modify_signal == "None": @@ -458,6 +462,15 @@ def apply_simulation_filters( The final simulated volume. """ + device = upsampled_volume.device + device_mps = device + # pytorch does not support mps FFT, so need to move to cpu for mps + if device.type == "mps": + device = torch.device("cpu") + upsampled_volume = move_tensor_to_device(upsampled_volume, device) + mtf_amplitudes = move_tensor_to_device(mtf_amplitudes, device) + mtf_frequencies = move_tensor_to_device(mtf_frequencies, device) + upsampled_volume_rfft = torch.fft.rfftn(upsampled_volume, dim=(-3, -2, -1)) upsampled_shape = upsampled_volume.shape @@ -471,7 +484,9 @@ def apply_simulation_filters( modify_signal=dose_filter_modify_signal, rfft=True, fftshift=False, + device=device, ) + upsampled_volume_rfft *= dose_filter # Fourier crop back to desired size @@ -492,6 +507,7 @@ def apply_simulation_filters( mtf_amplitudes=mtf_amplitudes, rfft=True, fftshift=False, + device=device, ) upsampled_volume_rfft *= mtf @@ -504,6 +520,10 @@ def apply_simulation_filters( # NOTE: ifftshift not needed since volume here was never fftshifted # cropped_volume = torch.fft.ifftshift(cropped_volume, dim=(-3, -2, -1)) + # Move back to upsampled volume device if mps + if device_mps.type == "mps": + cropped_volume = cropped_volume.to(device_mps) + return cropped_volume @@ -523,7 +543,7 @@ def simulate3d( apply_dqe: bool = False, mtf_frequencies: torch.Tensor = None, mtf_amplitudes: torch.Tensor = None, - # gpu_ids: int | list[int] = -999, # TODO: implement gpu selection + gpu_ids: Optional[Union[int, list[int]]] = None, ) -> torch.Tensor: """Simulate 3D electron scattering volume with requested parameters. @@ -575,22 +595,35 @@ def simulate3d( The amplitudes for the modulation transfer function (MTF) filter at the corresponding frequencies. Must be the same length as 'mtf_frequencies'. Required if 'apply_dqe' is True. - # gpu_ids : int | list[int] + gpu_ids : Optional[Union[int, list[int]]] + Device selection: + - None: Use CPU + - -1: Use first available GPU + - >=0: Use specific CUDA device + - list[int]: Use specific CUDA devices (future multi-GPU support) Returns ------- torch.Tensor The simulated 3D volume in real space. """ + # Get compute device + device = get_device(gpu_ids) + + # Move input tensors to device + atom_positions_zyx = move_tensor_to_device(atom_positions_zyx, device) + atom_b_factors = move_tensor_to_device(atom_b_factors, device) + if mtf_frequencies is not None: + mtf_frequencies = move_tensor_to_device(mtf_frequencies, device) + if mtf_amplitudes is not None: + mtf_amplitudes = move_tensor_to_device(mtf_amplitudes, device) + # Validate portions of the input before continuing _validate_dose_filter_inputs( dose_filter_modify_signal, dose_filter_critical_bfactor ) _validate_dqe_filter_inputs(apply_dqe, mtf_frequencies, mtf_amplitudes) - # Select devices for calculation - # TODO: Implement GPU selection - # Calculate the atom-wise scattering potentials lead_term = _calculate_lead_term(beam_energy_kev, sim_pixel_spacing) @@ -607,7 +640,7 @@ def simulate3d( actual_upsampling = setup_results["actual_upsampling"] # Nowsplit to batches - upsampled_volume = torch.zeros(upsampled_shape, dtype=torch.float32) + upsampled_volume = torch.zeros(upsampled_shape, dtype=torch.float32, device=device) num_batches, atoms_per_batch = calculate_batches(setup_results, upsampled_volume)