Skip to content

Commit

Permalink
optional pyfftw import for simulation.fractal.py (#1191)
Browse files Browse the repository at this point in the history
* Fallback to `scipy.fft` if `pyfftw` is not available

* Multi-threading fft2 with scipy
  • Loading branch information
avalentino authored May 11, 2024
1 parent feae339 commit 9fb0c4c
Showing 1 changed file with 22 additions and 11 deletions.
33 changes: 22 additions & 11 deletions src/mintpy/simulation/fractal.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,27 @@
import matplotlib.pyplot as plt
import numpy as np

NUM_THREADS = min(os.cpu_count(), 4)

try:
import pyfftw
from pyfftw.interfaces.numpy_fft import fft2, fftshift, ifft2

# speedup pyfftw
print(f'using {NUM_THREADS} threads for fft computation with pyfftw.')
pyfftw.config.NUM_THREADS = NUM_THREADS
except ImportError:
raise ImportError('Cannot import pyfftw!')
import functools
import warnings

import scipy
from scipy.fft import fftshift

# speedup pyfftw
NUM_THREADS = min(os.cpu_count(), 4)
print(f'using {NUM_THREADS} threads for pyfftw computation.')
pyfftw.config.NUM_THREADS = NUM_THREADS
warnings.warn('Cannot import pyfftw, fallback to scipy.fft.')

fft2 = functools.partial(scipy.fft.fft2, workers=NUM_THREADS)
ifft2 = functools.partial(scipy.ifft2, workers=NUM_THREADS)
print(f'using {NUM_THREADS} threads for fft computation with scipy.')


def fractal_surface_atmos(shape=(128, 128), resolution=60., p0=1., freq0=1e-3,
Expand Down Expand Up @@ -67,8 +78,8 @@ def fractal_surface_atmos(shape=(128, 128), resolution=60., p0=1., freq0=1e-3,

# simulate a uniform random signal
h = np.random.rand(length, width)
H = pyfftw.interfaces.numpy_fft.fft2(h)
H = pyfftw.interfaces.numpy_fft.fftshift(H)
H = fft2(h)
H = fftshift(H)

# scale the spectrum with the power law
yy, xx = np.mgrid[0:length-1:length*1j,
Expand Down Expand Up @@ -120,7 +131,7 @@ def fractal_surface_atmos(shape=(128, 128), resolution=60., p0=1., freq0=1e-3,

# get the fractal spectrum and transform to spatial domain
Hfrac = np.divide(H, fraction)
fsurf = pyfftw.interfaces.numpy_fft.ifft2(Hfrac)
fsurf = ifft2(Hfrac)
fsurf = np.abs(fsurf).astype(np.float32)
fsurf -= np.mean(fsurf)

Expand All @@ -129,7 +140,7 @@ def fractal_surface_atmos(shape=(128, 128), resolution=60., p0=1., freq0=1e-3,

# scale the spectrum to match the input power spectral density.
Hfrac *= np.sqrt(p0/p1)
fsurf = pyfftw.interfaces.numpy_fft.ifft2(Hfrac)
fsurf = ifft2(Hfrac)
fsurf = np.abs(fsurf).astype(np.float32)
fsurf -= np.mean(fsurf)
return fsurf
Expand Down Expand Up @@ -162,8 +173,8 @@ def get_power_spectral_density(data, resolution=60., freq0=1e-3, display=False,
N = data.shape[0]

# calculate the normalized power spectrum (spectral density)
fdata2d = pyfftw.interfaces.numpy_fft.fft2(data)
fdata2d = pyfftw.interfaces.numpy_fft.fftshift(fdata2d)
fdata2d = fft2(data)
fdata2d = fftshift(fdata2d)
psd2d = np.abs(np.multiply(fdata2d, np.conj(fdata2d))) / (N**2)

# The frequency coordinate in cycle / m
Expand Down

0 comments on commit 9fb0c4c

Please sign in to comment.