Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Proof of concept on how to fix issues #531, #535 and #570 #574

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 26 additions & 19 deletions pywt/_cwt.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from math import floor, ceil
from scipy import interpolate

from ._extensions._pywt import (DiscreteContinuousWavelet, ContinuousWavelet,
Wavelet, _check_dtype)
from ._functions import integrate_wavelet, scale2frequency
from ._functions import evaluate_wavelet, scale2frequency


__all__ = ["cwt"]
Expand Down Expand Up @@ -123,13 +124,16 @@ def cwt(data, scales, wavelet, sampling_period=1., method='conv', axis=-1):
dt_out = dt_cplx if wavelet.complex_cwt else dt
out = np.empty((np.size(scales),) + data.shape, dtype=dt_out)
precision = 10
int_psi, x = integrate_wavelet(wavelet, precision=precision)
int_psi = np.conj(int_psi) if wavelet.complex_cwt else int_psi
psi, x = evaluate_wavelet(wavelet, precision=precision)
psi = np.conj(psi) if wavelet.complex_cwt else psi

# convert int_psi, x to the same precision as the data
dt_psi = dt_cplx if int_psi.dtype.kind == 'c' else dt
int_psi = np.asarray(int_psi, dtype=dt_psi)
# convert psi, x to the same precision as the data
dt_psi = dt_cplx if psi.dtype.kind == 'c' else dt
psi = np.asarray(psi, dtype=dt_psi)
x = np.asarray(x, dtype=data.real.dtype)
# FIXME: The original wavelet function could be used here, but
# interpolation is computationally more efficient.
wavefun = interpolate.interp1d(x, psi, kind='cubic', assume_sorted=True)

if method == 'fft':
size_scale0 = -1
Expand All @@ -146,41 +150,44 @@ def cwt(data, scales, wavelet, sampling_period=1., method='conv', axis=-1):
data = data.reshape((-1, data.shape[-1]))

for i, scale in enumerate(scales):
step = x[1] - x[0]
j = np.arange(scale * (x[-1] - x[0]) + 1) / (scale * step)
j = j.astype(int) # floor
if j[-1] >= int_psi.size:
j = np.extract(j < int_psi.size, j)
int_psi_scale = int_psi[j][::-1]
# FIXME: Boundary points might be discarded erroneously
if np.sign(x[0])*np.sign(x[-1])<0:
# Wavelet is sampled at 0.0 if the range includes it
xsl = np.arange(0.0, x[0], -1.0/scale)
xsr = np.arange(0.0, x[-1], 1.0/scale)
xs = np.concatenate((xsl[:0:-1], xsr))
else:
xs = np.arange(x[0], x[-1], 1.0/scale)
psi_scale = wavefun(xs)[::-1]

if method == 'conv':
if data.ndim == 1:
conv = np.convolve(data, int_psi_scale)
conv = np.convolve(data, psi_scale)
else:
# batch convolution via loop
conv_shape = list(data.shape)
conv_shape[-1] += int_psi_scale.size - 1
conv_shape[-1] += psi_scale.size - 1
conv_shape = tuple(conv_shape)
conv = np.empty(conv_shape, dtype=dt_out)
for n in range(data.shape[0]):
conv[n, :] = np.convolve(data[n], int_psi_scale)
conv[n, :] = np.convolve(data[n], psi_scale)
else:
# The padding is selected for:
# - optimal FFT complexity
# - to be larger than the two signals length to avoid circular
# convolution
size_scale = next_fast_len(
data.shape[-1] + int_psi_scale.size - 1
data.shape[-1] + psi_scale.size - 1
)
if size_scale != size_scale0:
# Must recompute fft_data when the padding size changes.
fft_data = fftmodule.fft(data, size_scale, axis=-1)
size_scale0 = size_scale
fft_wav = fftmodule.fft(int_psi_scale, size_scale, axis=-1)
fft_wav = fftmodule.fft(psi_scale, size_scale, axis=-1)
conv = fftmodule.ifft(fft_wav * fft_data, axis=-1)
conv = conv[..., :data.shape[-1] + int_psi_scale.size - 1]
conv = conv[..., :data.shape[-1] + psi_scale.size - 1]

coef = - np.sqrt(scale) * np.diff(conv, axis=-1)
coef = conv / np.sqrt(scale)
if out.dtype.kind != 'c':
coef = coef.real
# transform axis is always -1 due to the data reshape above
Expand Down
55 changes: 53 additions & 2 deletions pywt/_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
from ._extensions._pywt import DiscreteContinuousWavelet, Wavelet, ContinuousWavelet


__all__ = ["integrate_wavelet", "central_frequency", "scale2frequency", "qmf",
"orthogonal_filter_bank",
__all__ = ["integrate_wavelet", "evaluate_wavelet", "central_frequency",
"scale2frequency", "qmf", "orthogonal_filter_bank",
"intwave", "centrfrq", "scal2frq", "orthfilt"]


Expand Down Expand Up @@ -119,6 +119,57 @@ def integrate_wavelet(wavelet, precision=8):
return _integrate(psi_d, step), _integrate(psi_r, step), x


def evaluate_wavelet(wavelet, precision=8):
"""
Evaluate `psi` wavelet function between lower and upper bound.

Parameters
----------
wavelet : Wavelet instance or str
Wavelet to evaluate. If a string, should be the name of a wavelet.
precision : int, optional
Number of wavelet function points computed with Wavelet's
wavefun(level=precision) method (default: 8).

Returns
-------
[psi, x] :
for orthogonal wavelets
[psi_d, psi_r, x] :
for other wavelets


Examples
--------
>>> from pywt import Wavelet, evaluate_wavelet
>>> wavelet1 = Wavelet('db2')
>>> [psi, x] = evaluate_wavelet(wavelet1, precision=5)
>>> wavelet2 = Wavelet('bior1.3')
>>> [psi_d, psi_r, x] = evaluate_wavelet(wavelet2, precision=5)

"""

if type(wavelet) in (tuple, list):
psi, x = np.asarray(wavelet[0]), np.asarray(wavelet[1])
return psi, x
elif not isinstance(wavelet, (Wavelet, ContinuousWavelet)):
wavelet = DiscreteContinuousWavelet(wavelet)

functions_approximations = wavelet.wavefun(precision)

if len(functions_approximations) == 2: # continuous wavelet
psi, x = functions_approximations
return psi, x

elif len(functions_approximations) == 3: # orthogonal wavelet
phi, psi, x = functions_approximations
return psi, x

else: # biorthogonal wavelet
phi_d, psi_d, phi_r, psi_r, x = functions_approximations
return psi_d, psi_r, x


def central_frequency(wavelet, precision=8):
"""
Computes the central frequency of the `psi` wavelet function.
Expand Down