From 02984db7d0a7c6c2cdf1ef90b270ec3143834d1b Mon Sep 17 00:00:00 2001 From: cyschneck <22159116+cyschneck@users.noreply.github.com> Date: Wed, 21 Feb 2024 00:59:17 -0700 Subject: [PATCH 1/6] valueError when scale is 0 --- pywt/_cwt.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pywt/_cwt.py b/pywt/_cwt.py index cad9b04c..3ca98564 100644 --- a/pywt/_cwt.py +++ b/pywt/_cwt.py @@ -147,6 +147,9 @@ def cwt(data, scales, wavelet, sampling_period=1., method='conv', axis=-1): # reshape to (n_batch, data.shape[-1]) data_shape_pre = data.shape data = data.reshape((-1, data.shape[-1])) + + if 0 in scales: + raise ValueError("scales range cannot include zero") for i, scale in enumerate(scales): step = x[1] - x[0] From 2853b9bbb312d5498429c152d0e81a00cd836c02 Mon Sep 17 00:00:00 2001 From: cyschneck <22159116+cyschneck@users.noreply.github.com> Date: Fri, 23 Feb 2024 12:20:44 -0700 Subject: [PATCH 2/6] update scales tests, tabs->spaces --- pywt/_cwt.py | 6 +++--- pywt/tests/test_cwt_wavelets.py | 7 +++++++ 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/pywt/_cwt.py b/pywt/_cwt.py index 3ca98564..c71fbf75 100644 --- a/pywt/_cwt.py +++ b/pywt/_cwt.py @@ -147,9 +147,9 @@ def cwt(data, scales, wavelet, sampling_period=1., method='conv', axis=-1): # reshape to (n_batch, data.shape[-1]) data_shape_pre = data.shape data = data.reshape((-1, data.shape[-1])) - - if 0 in scales: - raise ValueError("scales range cannot include zero") + + if 0 in scales: + raise ValueError("scales range cannot include zero") for i, scale in enumerate(scales): step = x[1] - x[0] diff --git a/pywt/tests/test_cwt_wavelets.py b/pywt/tests/test_cwt_wavelets.py index f7f64bca..d6d13a5a 100644 --- a/pywt/tests/test_cwt_wavelets.py +++ b/pywt/tests/test_cwt_wavelets.py @@ -441,6 +441,13 @@ def test_cwt_small_scales(): # extremely short scale factors raise a ValueError assert_raises(ValueError, pywt.cwt, data, scales=0.01, wavelet='mexh') +def test_cwt_zero_scale(): + data = np.zeros(32) + scales = np.arange(0, 4) + + # scale that includes 0 throws ValueError to prevent IndexError + assert_raises(ValueError, pywt.cwt, data, scales=scales, wavelet='morl') + def test_cwt_method_fft(): rstate = np.random.RandomState(1) From d87280a9d843081ff48a122bfdacc4db7828cc71 Mon Sep 17 00:00:00 2001 From: cyschneck <22159116+cyschneck@users.noreply.github.com> Date: Tue, 27 Feb 2024 14:49:20 -0700 Subject: [PATCH 3/6] scales cannot be negative values --- pywt/_cwt.py | 8 ++++---- pywt/tests/test_cwt_wavelets.py | 7 +++++++ 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/pywt/_cwt.py b/pywt/_cwt.py index c71fbf75..9de581a3 100644 --- a/pywt/_cwt.py +++ b/pywt/_cwt.py @@ -118,8 +118,8 @@ def cwt(data, scales, wavelet, sampling_period=1., method='conv', axis=-1): dt_cplx = np.result_type(dt, np.complex64) if not isinstance(wavelet, (ContinuousWavelet, Wavelet)): wavelet = DiscreteContinuousWavelet(wavelet) - if np.isscalar(scales): - scales = np.array([scales]) + + scales = np.array([scales]) if not np.isscalar(axis): raise np.AxisError("axis must be a scalar.") @@ -148,8 +148,8 @@ def cwt(data, scales, wavelet, sampling_period=1., method='conv', axis=-1): data_shape_pre = data.shape data = data.reshape((-1, data.shape[-1])) - if 0 in scales: - raise ValueError("scales range cannot include zero") + if np.any(scales <= 0): + raise ValueError("scales range cannot include values less than 1") for i, scale in enumerate(scales): step = x[1] - x[0] diff --git a/pywt/tests/test_cwt_wavelets.py b/pywt/tests/test_cwt_wavelets.py index d6d13a5a..f142404c 100644 --- a/pywt/tests/test_cwt_wavelets.py +++ b/pywt/tests/test_cwt_wavelets.py @@ -448,6 +448,13 @@ def test_cwt_zero_scale(): # scale that includes 0 throws ValueError to prevent IndexError assert_raises(ValueError, pywt.cwt, data, scales=scales, wavelet='morl') +def test_cwt_negative_scale(): + data = np.zeros(32) + scales = np.asarray([-1, -2, -3]) + + # scale that includes negative values throws ValueError to prevent IndexError + assert_raises(ValueError, pywt.cwt, data, scales=scales, wavelet='morl') + def test_cwt_method_fft(): rstate = np.random.RandomState(1) From d71de3678659cc91faf112bb6e0c908a9878c135 Mon Sep 17 00:00:00 2001 From: cyschneck <22159116+cyschneck@users.noreply.github.com> Date: Thu, 7 Mar 2024 13:39:36 -0700 Subject: [PATCH 4/6] unconditional conversion of scales --- pywt/_cwt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pywt/_cwt.py b/pywt/_cwt.py index 9de581a3..10ee17d7 100644 --- a/pywt/_cwt.py +++ b/pywt/_cwt.py @@ -119,7 +119,7 @@ def cwt(data, scales, wavelet, sampling_period=1., method='conv', axis=-1): if not isinstance(wavelet, (ContinuousWavelet, Wavelet)): wavelet = DiscreteContinuousWavelet(wavelet) - scales = np.array([scales]) + scales = np.asarray(scales) if not np.isscalar(axis): raise np.AxisError("axis must be a scalar.") From 5bd591e742ebf9246d3235def9638c511340e1c8 Mon Sep 17 00:00:00 2001 From: Ralf Gommers Date: Fri, 8 Mar 2024 17:58:31 +0100 Subject: [PATCH 5/6] Use `atleast_1d` to convert the `scales` input to `cwt` --- pywt/_cwt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pywt/_cwt.py b/pywt/_cwt.py index 10ee17d7..26e5c907 100644 --- a/pywt/_cwt.py +++ b/pywt/_cwt.py @@ -119,7 +119,7 @@ def cwt(data, scales, wavelet, sampling_period=1., method='conv', axis=-1): if not isinstance(wavelet, (ContinuousWavelet, Wavelet)): wavelet = DiscreteContinuousWavelet(wavelet) - scales = np.asarray(scales) + scales = np.atleast_1d(scales) if not np.isscalar(axis): raise np.AxisError("axis must be a scalar.") From d9b483bc33d0756406cb5f0739ad47829f325a46 Mon Sep 17 00:00:00 2001 From: Ralf Gommers Date: Fri, 8 Mar 2024 18:09:11 +0100 Subject: [PATCH 6/6] Move validation of `scales` up and correct error message --- pywt/_cwt.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pywt/_cwt.py b/pywt/_cwt.py index 26e5c907..c76042cc 100644 --- a/pywt/_cwt.py +++ b/pywt/_cwt.py @@ -120,6 +120,9 @@ def cwt(data, scales, wavelet, sampling_period=1., method='conv', axis=-1): wavelet = DiscreteContinuousWavelet(wavelet) scales = np.atleast_1d(scales) + if np.any(scales <= 0): + raise ValueError("`scales` must only include positive values") + if not np.isscalar(axis): raise np.AxisError("axis must be a scalar.") @@ -148,9 +151,6 @@ def cwt(data, scales, wavelet, sampling_period=1., method='conv', axis=-1): data_shape_pre = data.shape data = data.reshape((-1, data.shape[-1])) - if np.any(scales <= 0): - raise ValueError("scales range cannot include values less than 1") - for i, scale in enumerate(scales): step = x[1] - x[0] j = np.arange(scale * (x[-1] - x[0]) + 1) / (scale * step)