Skip to content
10 changes: 10 additions & 0 deletions aeon/distances/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,10 @@
"soft_dtw_pairwise_distance",
"soft_dtw_alignment_path",
"soft_dtw_cost_matrix",
"kdtw_distance",
"kdtw_alignment_path",
"kdtw_cost_matrix",
"kdtw_pairwise_distance",
]

from aeon.distances._distance import (
Expand Down Expand Up @@ -157,6 +161,12 @@
wdtw_distance,
wdtw_pairwise_distance,
)
from aeon.distances.kernel import (
kdtw_alignment_path,
kdtw_cost_matrix,
kdtw_distance,
kdtw_pairwise_distance,
)
from aeon.distances.mindist._dft_sfa import mindist_dft_sfa_distance
from aeon.distances.mindist._paa_sax import mindist_paa_sax_distance
from aeon.distances.mindist._sax import mindist_sax_distance
Expand Down
25 changes: 25 additions & 0 deletions aeon/distances/_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,12 @@
wdtw_distance,
wdtw_pairwise_distance,
)
from aeon.distances.kernel import (
kdtw_alignment_path,
kdtw_cost_matrix,
kdtw_distance,
kdtw_pairwise_distance,
)
from aeon.distances.mindist import (
mindist_dft_sfa_distance,
mindist_dft_sfa_pairwise_distance,
Expand Down Expand Up @@ -109,6 +115,9 @@ class DistanceKwargs(TypedDict, total=False):
m: int
max_shift: Optional[int]
gamma: float
sigma: float
normalize_input: bool
normalize_dist: bool


DistanceFunction = Callable[[np.ndarray, np.ndarray, Any], float]
Expand Down Expand Up @@ -469,6 +478,7 @@ def get_distance_function(method: Union[str, DistanceFunction]) -> DistanceFunct
'sbd' distances.sbd_distance
'shift_scale' distances.shift_scale_invariant_distance
'soft_dtw' distances.soft_dtw_distance
'kdtw' distances.kdtw_distance
=============== ========================================

Parameters
Expand Down Expand Up @@ -528,6 +538,7 @@ def get_pairwise_distance_function(
'sbd' distances.sbd_pairwise_distance
'shift_scale' distances.shift_scale_invariant_pairwise_distance
'soft_dtw' distances.soft_dtw_pairwise_distance
'kdtw' distances.kdtw_pairwise_distance
=============== ========================================

Parameters
Expand Down Expand Up @@ -582,6 +593,7 @@ def get_alignment_path_function(method: str) -> AlignmentPathFunction:
'twe' distances.twe_alignment_path
'lcss' distances.lcss_alignment_path
'soft_dtw' distances.soft_dtw_alignment_path
'kdtw' distances.kdtw_alignment_path
=============== ========================================

Parameters
Expand Down Expand Up @@ -631,6 +643,7 @@ def get_cost_matrix_function(method: str) -> CostMatrixFunction:
'twe' distances.twe_cost_matrix
'lcss' distances.lcss_cost_matrix
'soft_dtw' distances.soft_dtw_cost_matrix
'kdtw' distances.kdtw_cost_matrix
=============== ========================================

Parameters
Expand Down Expand Up @@ -685,6 +698,7 @@ class DistanceType(Enum):

POINTWISE = "pointwise"
ELASTIC = "elastic"
KERNEL = "kernel"
CROSS_CORRELATION = "cross-correlation"
MIN_DISTANCE = "min-dist"
MATRIX_PROFILE = "matrix-profile"
Expand Down Expand Up @@ -909,6 +923,16 @@ class DistanceType(Enum):
"symmetric": True,
"unequal_support": True,
},
{
"name": "kdtw",
"distance": kdtw_distance,
"pairwise_distance": kdtw_pairwise_distance,
"cost_matrix": kdtw_cost_matrix,
"alignment_path": kdtw_alignment_path,
"type": DistanceType.KERNEL,
"symmetric": True,
"unequal_support": True,
},
]

DISTANCES_DICT = {d["name"]: d for d in DISTANCES}
Expand All @@ -922,6 +946,7 @@ class DistanceType(Enum):
]

ELASTIC_DISTANCES = [d["name"] for d in DISTANCES if d["type"] == DistanceType.ELASTIC]
KERNEL_DISTANCES = [d["name"] for d in DISTANCES if d["type"] == DistanceType.KERNEL]
POINTWISE_DISTANCES = [
d["name"] for d in DISTANCES if d["type"] == DistanceType.POINTWISE
]
Expand Down
24 changes: 15 additions & 9 deletions aeon/distances/elastic/_alignment_paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,17 @@


@njit(cache=True, fastmath=True)
def compute_min_return_path(cost_matrix: np.ndarray) -> list[tuple]:
def compute_min_return_path(
cost_matrix: np.ndarray, larger_is_better: bool = False
) -> list[tuple]:
"""Compute the minimum return path through a cost matrix.

Parameters
----------
cost_matrix : np.ndarray, of shape (n_timepoints_x, n_timepoints_y)
Cost matrix.
larger_is_better : bool, default=False
If True, the path will be computed for the maximum cost instead of the minimum.

Returns
-------
Expand All @@ -32,15 +36,17 @@ def compute_min_return_path(cost_matrix: np.ndarray) -> list[tuple]:
elif j == 0:
i -= 1
else:
min_index = np.argmin(
np.array(
[
cost_matrix[i - 1, j - 1],
cost_matrix[i - 1, j],
cost_matrix[i, j - 1],
]
)
costs = np.array(
[
cost_matrix[i - 1, j - 1],
cost_matrix[i - 1, j],
cost_matrix[i, j - 1],
]
)
if larger_is_better:
min_index = np.argmax(costs)
else:
min_index = np.argmin(costs)

if min_index == 0:
i, j = i - 1, j - 1
Expand Down
20 changes: 20 additions & 0 deletions aeon/distances/elastic/tests/test_cost_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,26 @@ def _validate_cost_matrix_result(
assert_almost_equal(curr_distance, distance_result)
elif name == "soft_dtw":
assert_almost_equal(abs(cost_matrix_result[-1, -1]), distance_result)
elif name == "kdtw":
# distance is normalized by default, so we need to do this here as well:
from aeon.distances.kernel._kdtw import (
_kdtw_cost_to_distance,
_normalize_time_series,
)

_x = x
_y = y
if x.ndim == 1:
_x = x.reshape((1, x.shape[0]))
if y.ndim == 1:
_y = y.reshape((1, y.shape[0]))

_x = _normalize_time_series(_x)
_y = _normalize_time_series(_y)
d = _kdtw_cost_to_distance(
cost_matrix_result, _x, _y, gamma=0.125, epsilon=1e-20, normalize_dist=True
)
assert_almost_equal(d, distance_result)
else:
assert_almost_equal(cost_matrix_result[-1, -1], distance_result)

Expand Down
15 changes: 15 additions & 0 deletions aeon/distances/kernel/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
"""Kernel distances."""

__all__ = [
"kdtw_distance",
"kdtw_alignment_path",
"kdtw_cost_matrix",
"kdtw_pairwise_distance",
]

from aeon.distances.kernel._kdtw import (
kdtw_alignment_path,
kdtw_cost_matrix,
kdtw_distance,
kdtw_pairwise_distance,
)
Loading
Loading