Skip to content

Commit

Permalink
[ENH] Soft-DTW (#2160)
Browse files Browse the repository at this point in the history
* started working on soft-dtw

* added tests for soft-dtw

* uncommented distances

* fix example

* fixed example

* fixed indent

* indent

* indent
  • Loading branch information
chrisholder authored Oct 12, 2024
1 parent ba1bc0c commit 866c73b
Show file tree
Hide file tree
Showing 6 changed files with 514 additions and 0 deletions.
10 changes: 10 additions & 0 deletions aeon/distances/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,10 @@
"shift_scale_invariant_distance",
"shift_scale_invariant_pairwise_distance",
"shift_scale_invariant_best_shift",
"soft_dtw_distance",
"soft_dtw_pairwise_distance",
"soft_dtw_alignment_path",
"soft_dtw_cost_matrix",
]


Expand Down Expand Up @@ -151,6 +155,12 @@
shift_scale_invariant_distance,
shift_scale_invariant_pairwise_distance,
)
from aeon.distances._soft_dtw import (
soft_dtw_alignment_path,
soft_dtw_cost_matrix,
soft_dtw_distance,
soft_dtw_pairwise_distance,
)
from aeon.distances._squared import squared_distance, squared_pairwise_distance
from aeon.distances._twe import (
twe_alignment_path,
Expand Down
50 changes: 50 additions & 0 deletions aeon/distances/_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,12 @@
shift_scale_invariant_distance,
shift_scale_invariant_pairwise_distance,
)
from aeon.distances._soft_dtw import (
soft_dtw_alignment_path,
soft_dtw_cost_matrix,
soft_dtw_distance,
soft_dtw_pairwise_distance,
)
from aeon.distances._squared import squared_distance, squared_pairwise_distance
from aeon.distances._twe import (
twe_alignment_path,
Expand Down Expand Up @@ -102,6 +108,7 @@ class DistanceKwargs(TypedDict, total=False):
standardize: bool
m: int
max_shift: Optional[int]
gamma: float


DistanceFunction = Callable[[np.ndarray, np.ndarray, Any], float]
Expand Down Expand Up @@ -256,6 +263,14 @@ def distance(
return sbd_distance(x, y, kwargs.get("standardize", True))
elif metric == "shift_scale":
return shift_scale_invariant_distance(x, y, kwargs.get("max_shift", None))
elif metric == "soft_dtw":
return soft_dtw_distance(
x,
y,
gamma=kwargs.get("gamma", 1.0),
itakura_max_slope=kwargs.get("itakura_max_slope"),
window=kwargs.get("window"),
)
else:
if isinstance(metric, Callable):
return metric(x, y, **kwargs)
Expand Down Expand Up @@ -438,6 +453,14 @@ def pairwise_distance(
return shift_scale_invariant_pairwise_distance(
x, y, kwargs.get("max_shift", None)
)
elif metric == "soft_dtw":
return soft_dtw_pairwise_distance(
x,
y,
gamma=kwargs.get("gamma", 1.0),
itakura_max_slope=kwargs.get("itakura_max_slope"),
window=kwargs.get("window"),
)
else:
if isinstance(metric, Callable):
if y is None and not symmetric:
Expand Down Expand Up @@ -630,6 +653,14 @@ def alignment_path(
kwargs.get("itakura_max_slope"),
kwargs.get("warp_penalty", 1.0),
)
elif metric == "soft_dtw":
return soft_dtw_alignment_path(
x,
y,
gamma=kwargs.get("gamma", 1.0),
itakura_max_slope=kwargs.get("itakura_max_slope"),
window=kwargs.get("window"),
)
else:
raise ValueError("Metric must be one of the supported strings")

Expand Down Expand Up @@ -773,6 +804,14 @@ def cost_matrix(
kwargs.get("itakura_max_slope"),
kwargs.get("warp_penalty", 1.0),
)
elif metric == "soft_dtw":
return soft_dtw_cost_matrix(
x,
y,
gamma=kwargs.get("gamma", 1.0),
itakura_max_slope=kwargs.get("itakura_max_slope"),
window=kwargs.get("window"),
)
else:
raise ValueError("Metric must be one of the supported strings")

Expand Down Expand Up @@ -826,6 +865,7 @@ def get_distance_function(metric: Union[str, DistanceFunction]) -> DistanceFunct
'minkowski' distances.minkowski_distance
'sbd' distances.sbd_distance
'shift_scale' distances.shift_scale_invariant_distance
'soft_dtw' distances.soft_dtw_distance
=============== ========================================
Parameters
Expand Down Expand Up @@ -884,6 +924,7 @@ def get_pairwise_distance_function(
'minkowski' distances.minkowski_pairwise_distance
'sbd' distances.sbd_pairwise_distance
'shift_scale' distances.shift_scale_invariant_pairwise_distance
'soft_dtw' distances.soft_dtw_pairwise_distance
=============== ========================================
Parameters
Expand Down Expand Up @@ -937,6 +978,7 @@ def get_alignment_path_function(metric: str) -> AlignmentPathFunction:
'msm' distances.msm_alignment_path
'twe' distances.twe_alignment_path
'lcss' distances.lcss_alignment_path
'soft_dtw' distances.soft_dtw_alignment_path
=============== ========================================
Parameters
Expand Down Expand Up @@ -985,6 +1027,7 @@ def get_cost_matrix_function(metric: str) -> CostMatrixFunction:
'msm' distances.msm_cost_matrix
'twe' distances.twe_cost_matrix
'lcss' distances.lcss_cost_matrix
'soft_dtw' distances.soft_dtw_cost_matrix
=============== ========================================
Parameters
Expand Down Expand Up @@ -1142,6 +1185,13 @@ def _resolve_key_from_distance(metric: Union[str, Callable], key: str) -> Any:
"distance": shift_scale_invariant_distance,
"pairwise_distance": shift_scale_invariant_pairwise_distance,
},
{
"name": "soft_dtw",
"distance": soft_dtw_distance,
"pairwise_distance": soft_dtw_pairwise_distance,
"cost_matrix": soft_dtw_cost_matrix,
"alignment_path": soft_dtw_alignment_path,
},
]

DISTANCES_DICT = {d["name"]: d for d in DISTANCES}
Loading

0 comments on commit 866c73b

Please sign in to comment.