From d7b824c09f0010535d3e9dc362b370bcc7f743e4 Mon Sep 17 00:00:00 2001 From: Luca Kubin Date: Mon, 17 May 2021 22:18:53 +0200 Subject: [PATCH] Fixed bug in dtw_path when one input is all nans or with empty length (#355) Co-authored-by: Romain Tavenard Co-authored-by: Luca Kubin --- docs/requirements_rtd.txt | 2 +- pyproject.toml | 2 +- tslearn/__init__.py | 2 +- tslearn/metrics/dtw_variants.py | 3 +++ tslearn/tests/test_metrics.py | 15 +++++++++++++++ 5 files changed, 21 insertions(+), 3 deletions(-) diff --git a/docs/requirements_rtd.txt b/docs/requirements_rtd.txt index 80c570648..7665e5040 100644 --- a/docs/requirements_rtd.txt +++ b/docs/requirements_rtd.txt @@ -12,5 +12,5 @@ tensorflow>=2 Pygments numba sphinx_bootstrap_theme -git+git://github.com/numpy/numpydoc@master +git+git://github.com/numpy/numpydoc@main matplotlib diff --git a/pyproject.toml b/pyproject.toml index 1956a13a5..fdba3974b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,2 +1,2 @@ [build-system] -requires = ["setuptools", "wheel", "numpy", "Cython"] +requires = ["setuptools", "wheel", "numpy<=1.19", "Cython"] diff --git a/tslearn/__init__.py b/tslearn/__init__.py index a1001a129..e407086de 100644 --- a/tslearn/__init__.py +++ b/tslearn/__init__.py @@ -1,7 +1,7 @@ import os __author__ = 'Romain Tavenard romain.tavenard[at]univ-rennes2.fr' -__version__ = "0.5.0.4" +__version__ = "0.5.0.5" __bibtex__ = r"""@article{JMLR:v21:20-091, author = {Romain Tavenard and Johann Faouzi and Gilles Vandewiele and Felix Divo and Guillaume Androz and Chester Holtz and diff --git a/tslearn/metrics/dtw_variants.py b/tslearn/metrics/dtw_variants.py index 14f635164..4fb1ae3bb 100644 --- a/tslearn/metrics/dtw_variants.py +++ b/tslearn/metrics/dtw_variants.py @@ -189,6 +189,9 @@ def dtw_path(s1, s2, global_constraint=None, sakoe_chiba_radius=None, s1 = to_time_series(s1, remove_nans=True) s2 = to_time_series(s2, remove_nans=True) + if len(s1) == 0 or len(s2) == 0: + raise ValueError("One of the input time series contains only nans or has zero length.") + mask = compute_mask( s1, s2, GLOBAL_CONSTRAINT_CODE[global_constraint], sakoe_chiba_radius, itakura_max_slope diff --git a/tslearn/tests/test_metrics.py b/tslearn/tests/test_metrics.py index ae3f80a7f..57de9fbfa 100644 --- a/tslearn/tests/test_metrics.py +++ b/tslearn/tests/test_metrics.py @@ -1,8 +1,10 @@ +import pytest import numpy as np from scipy.spatial.distance import cdist import tslearn.metrics import tslearn.clustering from tslearn.utils import to_time_series +from tslearn.metrics.dtw_variants import dtw_path __author__ = 'Romain Tavenard romain.tavenard[at]univ-rennes2.fr' @@ -426,3 +428,16 @@ def test_softdtw(): np.testing.assert_equal(dist, dist_ref ** 2) np.testing.assert_allclose(matrix_path, mat_path_ref) + + +def test_dtw_path_with_empty_or_nan_inputs(): + s1 = np.zeros((3, 10)) + s2_empty = np.zeros((0, 10)) + with pytest.raises(ValueError) as excinfo: + dtw_path(s1, s2_empty) + assert str(excinfo.value) == "One of the input time series contains only nans or has zero length." + + s2_nan = np.full((3, 10), np.nan) + with pytest.raises(ValueError) as excinfo: + dtw_path(s1, s2_nan) + assert str(excinfo.value) == "One of the input time series contains only nans or has zero length."