Skip to content

Commit 5aae9fb

Browse files
authored
Fixed #582 Allow 1D Mean/Stddev Match Params (#583)
1 parent d288451 commit 5aae9fb

File tree

4 files changed

+62
-1
lines changed

4 files changed

+62
-1
lines changed

stumpy/aamp_motifs.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -380,6 +380,8 @@ def max_distance(D):
380380

381381
if T_subseq_isfinite is None:
382382
T, T_subseq_isfinite = core.preprocess_non_normalized(T, m)
383+
if len(T_subseq_isfinite.shape) == 1:
384+
T_subseq_isfinite = T_subseq_isfinite[np.newaxis, :]
383385

384386
D = np.empty((d, n - m + 1))
385387
for i in range(d):

stumpy/motifs.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -445,6 +445,10 @@ def max_distance(D):
445445

446446
if M_T is None or Σ_T is None: # pragma: no cover
447447
T, M_T, Σ_T = core.preprocess(T, m)
448+
if len(M_T.shape) == 1:
449+
M_T = M_T[np.newaxis, :]
450+
if len(Σ_T.shape) == 1:
451+
Σ_T = Σ_T[np.newaxis, :]
448452

449453
D = np.empty((d, n - m + 1))
450454
for i in range(d):

tests/test_aamp_motifs.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import numpy.testing as npt
33
import pytest
44

5-
from stumpy import aamp_motifs, aamp_match
5+
from stumpy import core, aamp_motifs, aamp_match
66

77
import naive
88

@@ -211,3 +211,31 @@ def test_aamp_match(Q, T):
211211
)
212212

213213
npt.assert_almost_equal(left, right)
214+
215+
216+
@pytest.mark.parametrize("Q, T", test_data)
217+
def test_aamp_match_T_subseq_isfinite(Q, T):
218+
m = Q.shape[0]
219+
excl_zone = int(np.ceil(m / 4))
220+
max_distance = 0.3
221+
T, T_subseq_isfinite = core.preprocess_non_normalized(T, len(Q))
222+
223+
for p in [1.0, 2.0, 3.0]:
224+
left = naive_aamp_match(
225+
Q,
226+
T,
227+
p=p,
228+
excl_zone=excl_zone,
229+
max_distance=max_distance,
230+
)
231+
232+
right = aamp_match(
233+
Q,
234+
T,
235+
T_subseq_isfinite,
236+
p=p,
237+
max_matches=None,
238+
max_distance=max_distance,
239+
)
240+
241+
npt.assert_almost_equal(left, right)

tests/test_motifs.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,3 +235,30 @@ def test_match(Q, T):
235235
)
236236

237237
npt.assert_almost_equal(left, right)
238+
239+
240+
@pytest.mark.parametrize("Q, T", test_data)
241+
def test_match_mean_stddev(Q, T):
242+
m = Q.shape[0]
243+
excl_zone = int(np.ceil(m / 4))
244+
max_distance = 0.3
245+
246+
left = naive_match(
247+
Q,
248+
T,
249+
excl_zone,
250+
max_distance=max_distance,
251+
)
252+
253+
M_T, Σ_T = core.compute_mean_std(T, len(Q))
254+
255+
right = match(
256+
Q,
257+
T,
258+
M_T,
259+
Σ_T,
260+
max_matches=None,
261+
max_distance=lambda D: max_distance, # also test lambda functionality
262+
)
263+
264+
npt.assert_almost_equal(left, right)

0 commit comments

Comments
 (0)