Skip to content

Commit

Permalink
Merge pull request #13 from teamtomo/gpu_support
Browse files Browse the repository at this point in the history
fix: sim3d GPU code working for cuda and mps
  • Loading branch information
jdickerson95 authored Feb 27, 2025
2 parents e2bcf6a + 2bc9536 commit d1695d9
Show file tree
Hide file tree
Showing 5 changed files with 116 additions and 26 deletions.
49 changes: 48 additions & 1 deletion src/ttsim3d/device_handler.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Handles cpu/gpu device selection."""

from typing import Optional
from typing import Optional, Union

import psutil
import torch
Expand Down Expand Up @@ -60,6 +60,53 @@ def calculate_batches(
return num_batches, atoms_per_batch


def get_device(gpu_ids: Optional[Union[int, list[int]]] = None) -> torch.device:
    """Get the appropriate torch device based on availability and user preference.

    Parameters
    ----------
    gpu_ids : Optional[Union[int, list[int]]]
        Device selection preference:
        - None: Use CPU
        - -1: Use first available GPU (CUDA or MPS)
        - >=0: Use specific CUDA device
        - list[int]: Use specific CUDA devices (for multi-GPU). Multi-GPU is
          not yet implemented, so only the first entry is used; an empty list
          or a negative first entry is treated like -1.

    Returns
    -------
    torch.device
        The selected compute device. When a GPU is requested but neither CUDA
        nor MPS is available, the CPU device is returned.
    """
    # Default to CPU
    if gpu_ids is None:
        return torch.device("cpu")

    # Check for CUDA availability
    if torch.cuda.is_available():
        if isinstance(gpu_ids, list):
            # Multi-GPU not yet implemented: only the first entry is honoured.
            # An empty list carries no index, so treat it as "first available"
            # instead of raising IndexError on gpu_ids[0].
            requested = gpu_ids[0] if gpu_ids else -1
        else:
            requested = gpu_ids
        # Any negative request means "first available GPU"; this also keeps a
        # list such as [-1] from producing the invalid device string "cuda:-1".
        index = requested if requested >= 0 else 0
        return torch.device(f"cuda:{index}")

    # Check for MPS (Apple Silicon) availability. gpu_ids is known to be
    # non-None here, i.e. the user requested a GPU.
    if torch.backends.mps.is_available():
        return torch.device("mps")

    # Fallback to CPU
    return torch.device("cpu")


def move_tensor_to_device(tensor: torch.Tensor, device: torch.device) -> torch.Tensor:
    """Return ``tensor`` on ``device``, transferring it only when needed.

    Parameters
    ----------
    tensor : torch.Tensor
        The tensor to place on ``device``.
    device : torch.device
        The target device.

    Returns
    -------
    torch.Tensor
        The input tensor itself when its device already equals ``device``,
        otherwise the result of ``tensor.to(device)``.
    """
    already_on_target = tensor.device == device
    return tensor if already_on_target else tensor.to(device)


def select_gpu(
gpu_id: Optional[int] = None,
) -> torch.device:
Expand Down
22 changes: 15 additions & 7 deletions src/ttsim3d/grid_coords.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,15 +58,20 @@ def get_atom_voxel_indices(
tuple[torch.Tensor,torch.Tensor]
The voxel indices and the offset from the edge of the voxel.
"""
# Move to device
device = atom_zyx.device
shape_tensor = torch.tensor(upsampled_shape, device=device)
offset_tensor = torch.tensor(offset, device=device)
pixel_size_tensor = torch.tensor(upsampled_pixel_size, device=device)

origin_idx = (
upsampled_shape[0] / 2,
upsampled_shape[1] / 2,
upsampled_shape[2] / 2,
shape_tensor[0] / 2,
shape_tensor[1] / 2,
shape_tensor[2] / 2,
)
origin_idx_tensor = torch.tensor(origin_idx, device=device)
this_coords = (
(atom_zyx / upsampled_pixel_size)
+ torch.tensor(origin_idx).unsqueeze(0)
+ offset
(atom_zyx / pixel_size_tensor) + origin_idx_tensor.unsqueeze(0) + offset_tensor
)
atom_indices = torch.floor(this_coords) # these are the voxel indices
atom_dds = (
Expand Down Expand Up @@ -120,11 +125,14 @@ def get_voxel_neighborhood_offsets(
The offsets of the voxel neighborhood.
"""
device = mean_b_factor.device if isinstance(mean_b_factor, torch.Tensor) else "cpu"
# Get the size of the voxel neighbourhood to calculate the potential of each atom
size_neighborhood = get_size_neighborhood_cistem(
mean_b_factor, upsampled_pixel_size
)
neighborhood_range = torch.arange(-size_neighborhood, size_neighborhood + 1)
neighborhood_range = torch.arange(
-size_neighborhood, size_neighborhood + 1, device=device
)
# Create coordinate grids for the neighborhood
sz, sy, sx = torch.meshgrid(
neighborhood_range, neighborhood_range, neighborhood_range, indexing="ij"
Expand Down
16 changes: 9 additions & 7 deletions src/ttsim3d/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import os
import pathlib
from typing import Annotated, Any, Optional
from typing import Annotated, Any, Optional, Union

import torch
from pydantic import (
Expand Down Expand Up @@ -246,18 +246,19 @@ def get_scale_atom_b_factors(self) -> torch.Tensor:

def run(
self,
gpu_ids: Optional[int | list[int]] = None,
gpu_ids: Optional[Union[int, list[int]]] = None,
atom_indices: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Runs the simulation and returns the simulated volume.
Parameters
----------
gpu_ids: int | list[int]
A list of GPU IDs to use for the simulation. The default is 'None'
which will use the CPU. A value of '-1' will use all available
GPUs, otherwise a list of integers greater than or equal to 0 are
expected.
gpu_ids : Optional[Union[int, list[int]]]
Device selection:
- None: Use CPU
- -1: Use first available GPU
- >=0: Use specific CUDA device
- list[int]: Use specific CUDA devices (future multi-GPU support)
atom_indices: torch.Tensor
The indices of the atoms to simulate. The default is 'None' which
will simulate all atoms in the structure.
Expand Down Expand Up @@ -298,6 +299,7 @@ def run(
apply_dqe=self.simulator_config.apply_dqe,
mtf_frequencies=mtf_frequencies,
mtf_amplitudes=mtf_amplitudes,
gpu_ids=gpu_ids,
)

if self.simulator_config.store_volume:
Expand Down
6 changes: 3 additions & 3 deletions src/ttsim3d/scattering_potential.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def get_total_b_param(
bPlusB: torch.Tensor
Total B parameter for each atom in the neighborhood.
"""
b_params = get_b_param(atom_ids)
b_params = get_b_param(atom_ids).to(atom_b_factors.device)
bPlusB = 2 * torch.pi / torch.sqrt(atom_b_factors.unsqueeze(1) + b_params)

return bPlusB
Expand Down Expand Up @@ -149,8 +149,8 @@ def get_scattering_potential_of_voxel_batch(
device = zyx_coords1.device

# Get scattering parameters for atoms
params_a = get_a_param(atom_ids)
params_bPlusB = get_total_b_param(atom_ids, atom_b_factors)
params_a = get_a_param(atom_ids).to(device)
params_bPlusB = get_total_b_param(atom_ids, atom_b_factors).to(device)

# Compare signs element-wise for batched coordinates
t_all = (zyx_coords1 * zyx_coords2) >= 0 # Shape: (atomN, voxelN, 3)
Expand Down
49 changes: 41 additions & 8 deletions src/ttsim3d/simulate3d.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Simulation of 3D volume and associated helper functions."""

from typing import Literal
from typing import Literal, Optional, Union

import einops
import numpy as np
Expand All @@ -9,7 +9,7 @@
from torch_fourier_filter.dose_weight import cumulative_dose_filter_3d
from torch_fourier_filter.mtf import make_mtf_grid

from ttsim3d.device_handler import calculate_batches
from ttsim3d.device_handler import calculate_batches, get_device, move_tensor_to_device
from ttsim3d.grid_coords import (
fourier_rescale_3d_force_size,
get_atom_voxel_indices,
Expand Down Expand Up @@ -349,6 +349,7 @@ def calculate_simulation_dose_filter_3d(
modify_signal: Literal["None", "sqrt", "rel_diff"],
rfft: bool = True,
fftshift: bool = False,
device: torch.device = None,
) -> torch.Tensor:
"""Helper function to calculate a cumulative dose filter for a simulation.
Expand All @@ -373,6 +374,8 @@ def calculate_simulation_dose_filter_3d(
If True, the filter is returned in rfft format. Default is True.
fftshift : bool
If True, the filter is fftshifted. Default is False.
device : torch.device
The device to use for the calculation. Default is None.
Returns
-------
Expand All @@ -385,6 +388,7 @@ def calculate_simulation_dose_filter_3d(
crit_exposure_bfactor=critical_bfactor,
rfft=rfft,
fftshift=fftshift,
device=device,
)

if modify_signal == "None":
Expand Down Expand Up @@ -458,6 +462,15 @@ def apply_simulation_filters(
The final simulated volume.
"""
device = upsampled_volume.device
device_mps = device
# pytorch does not support mps FFT, so need to move to cpu for mps
if device.type == "mps":
device = torch.device("cpu")
upsampled_volume = move_tensor_to_device(upsampled_volume, device)
mtf_amplitudes = move_tensor_to_device(mtf_amplitudes, device)
mtf_frequencies = move_tensor_to_device(mtf_frequencies, device)

upsampled_volume_rfft = torch.fft.rfftn(upsampled_volume, dim=(-3, -2, -1))
upsampled_shape = upsampled_volume.shape

Expand All @@ -471,7 +484,9 @@ def apply_simulation_filters(
modify_signal=dose_filter_modify_signal,
rfft=True,
fftshift=False,
device=device,
)

upsampled_volume_rfft *= dose_filter

# Fourier crop back to desired size
Expand All @@ -492,6 +507,7 @@ def apply_simulation_filters(
mtf_amplitudes=mtf_amplitudes,
rfft=True,
fftshift=False,
device=device,
)
upsampled_volume_rfft *= mtf

Expand All @@ -504,6 +520,10 @@ def apply_simulation_filters(
# NOTE: ifftshift not needed since volume here was never fftshifted
# cropped_volume = torch.fft.ifftshift(cropped_volume, dim=(-3, -2, -1))

# Move back to upsampled volume device if mps
if device_mps.type == "mps":
cropped_volume = cropped_volume.to(device_mps)

return cropped_volume


Expand All @@ -523,7 +543,7 @@ def simulate3d(
apply_dqe: bool = False,
mtf_frequencies: torch.Tensor = None,
mtf_amplitudes: torch.Tensor = None,
# gpu_ids: int | list[int] = -999, # TODO: implement gpu selection
gpu_ids: Optional[Union[int, list[int]]] = None,
) -> torch.Tensor:
"""Simulate 3D electron scattering volume with requested parameters.
Expand Down Expand Up @@ -575,22 +595,35 @@ def simulate3d(
The amplitudes for the modulation transfer function (MTF) filter at the
corresponding frequencies. Must be the same length as
'mtf_frequencies'. Required if 'apply_dqe' is True.
# gpu_ids : int | list[int]
gpu_ids : Optional[Union[int, list[int]]]
Device selection:
- None: Use CPU
- -1: Use first available GPU
- >=0: Use specific CUDA device
- list[int]: Use specific CUDA devices (future multi-GPU support)
Returns
-------
torch.Tensor
The simulated 3D volume in real space.
"""
# Get compute device
device = get_device(gpu_ids)

# Move input tensors to device
atom_positions_zyx = move_tensor_to_device(atom_positions_zyx, device)
atom_b_factors = move_tensor_to_device(atom_b_factors, device)
if mtf_frequencies is not None:
mtf_frequencies = move_tensor_to_device(mtf_frequencies, device)
if mtf_amplitudes is not None:
mtf_amplitudes = move_tensor_to_device(mtf_amplitudes, device)

# Validate portions of the input before continuing
_validate_dose_filter_inputs(
dose_filter_modify_signal, dose_filter_critical_bfactor
)
_validate_dqe_filter_inputs(apply_dqe, mtf_frequencies, mtf_amplitudes)

# Select devices for calculation
# TODO: Implement GPU selection

# Calculate the atom-wise scattering potentials
lead_term = _calculate_lead_term(beam_energy_kev, sim_pixel_spacing)

Expand All @@ -607,7 +640,7 @@ def simulate3d(
actual_upsampling = setup_results["actual_upsampling"]

# Now split into batches
upsampled_volume = torch.zeros(upsampled_shape, dtype=torch.float32)
upsampled_volume = torch.zeros(upsampled_shape, dtype=torch.float32, device=device)

num_batches, atoms_per_batch = calculate_batches(setup_results, upsampled_volume)

Expand Down

0 comments on commit d1695d9

Please sign in to comment.