Skip to content

Commit e3c37a3

Browse files
committed
add oaconvolve and general function
1 parent e446369 commit e3c37a3

File tree

1 file changed

+87
-1
lines changed

1 file changed

+87
-1
lines changed

stumpy/sdp.py

Lines changed: 87 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1+
import warnings
2+
13
import numpy as np
24
from numba import njit
35
from scipy.fft import next_fast_len
46
from scipy.fft._pocketfft.basic import c2r, r2c
5-
from scipy.signal import convolve
7+
from scipy.signal import convolve, oaconvolve
68

79
from . import config
810

@@ -76,6 +78,26 @@ def _convolve_sliding_dot_product(Q, T):
7678
return convolve(np.flipud(Q), T, mode="valid")
7779

7880

81+
def _oaconvolve_sliding_dot_product(Q, T):
82+
"""
83+
Use scipy's oaconvolve to calculate the sliding dot product.
84+
85+
Parameters
86+
----------
87+
Q : numpy.ndarray
88+
Query array or subsequence
89+
90+
T : numpy.ndarray
91+
Time series or sequence
92+
93+
Returns
94+
-------
95+
output : numpy.ndarray
96+
Sliding dot product between `Q` and `T`.
97+
"""
98+
return oaconvolve(np.ascontiguousarray(Q[::-1]), T, mode="valid")
99+
100+
79101
def _pocketfft_sliding_dot_product(Q, T):
80102
"""
81103
Use scipy.fft._pocketfft to compute
@@ -289,4 +311,68 @@ def __call__(self, Q, T, n_threads=1, planning_flag="FFTW_ESTIMATE"):
289311
if FFTW_IS_AVAILABLE: # pragma: no cover
290312
_pyfftw_sliding_dot_product = _PYFFTW_SLIDING_DOT_PRODUCT()
291313
else: # pragma: no cover
314+
msg = (
315+
"Couldn't import pyFFTW. Set _pyfftw_sliding_dot_product " + "function to None"
316+
)
317+
warnings.warn(msg)
292318
_pyfftw_sliding_dot_product = None
319+
320+
321+
def _sliding_dot_product(
322+
Q,
323+
T,
324+
boundaries=[
325+
[(-np.inf, 2**7 + 1), (-np.inf, np.inf), _njit_sliding_dot_product],
326+
],
327+
default_sdp=_oaconvolve_sliding_dot_product,
328+
):
329+
"""
330+
Compute the sliding dot product between the query Q
331+
and the time series T by using different algorithms
332+
for different `len(Q), len(T)` according to the
333+
boundaries.
334+
335+
Parameters
336+
----------
337+
Q : numpy.ndarray
338+
Query array or subsequence
339+
340+
T : numpy.ndarray
341+
Time series or sequence
342+
343+
boundaries : list
344+
A nested list, where each item is a list
345+
like [(LB_Q, UB_Q), (LB_T, UB_T), sdp_func]
346+
The `sdp_func` is used if LB_Q<=len(Q)<UB_Q
347+
and LB_T<=len(T)<UB_T
348+
349+
default_sdp : function
350+
A function to compute sliding_dot_product when
351+
the provided `sdp_func` in boundaries is None
352+
or `(len(Q), len(T))` does not fit into the
353+
provided boundaries.
354+
355+
Returns
356+
-------
357+
output : numpy.ndarray
358+
Sliding dot product between `Q` and `T`.
359+
360+
Notes
361+
-----
362+
The function `_pyfftw_sliding_dot_product` will be set to None
363+
if pyFFTW cannot be imported
364+
"""
365+
m = len(Q)
366+
n = len(T)
367+
368+
for Q_boundaries, T_boundaries, sdp_func in boundaries:
369+
if (
370+
Q_boundaries[0] <= m < Q_boundaries[1]
371+
and T_boundaries[0] <= n < T_boundaries[1]
372+
and sdp_func is not None
373+
):
374+
return sdp_func(Q, T)
375+
376+
# when the union of regions is not comprehensive
377+
# or sdp_func is None
378+
return default_sdp(Q, T)

0 commit comments

Comments
 (0)