Skip to content

Commit 87dfcdb

Browse files
committed
add pyfftw sdp and check for fftw/pyfftw
1 parent 3b08396 commit 87dfcdb

File tree

3 files changed

+255
-9
lines changed

3 files changed

+255
-9
lines changed

stumpy/sdp.py

Lines changed: 189 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,13 @@
66

77
from . import config
88

9+
try: # pragma: no cover
10+
import pyfftw
11+
12+
FFTW_IS_AVAILABLE = True
13+
except ImportError: # pragma: no cover
14+
FFTW_IS_AVAILABLE = False
15+
916

1017
@njit(fastmath=config.STUMPY_FASTMATH_TRUE)
1118
def _njit_sliding_dot_product(Q, T):
@@ -109,6 +116,188 @@ def _pocketfft_sliding_dot_product(Q, T):
109116
return c2r(False, np.multiply(fft_2d[0], fft_2d[1]), n=next_fast_n)[m - 1 : n]
110117

111118

119+
class _PYFFTW_SLIDING_DOT_PRODUCT:
    """
    A class to compute the sliding dot product using FFTW via pyfftw.

    This class uses FFTW (via pyfftw) to efficiently compute the sliding dot product
    between a query sequence Q and a time series T. It preallocates arrays and caches
    FFTW objects to optimize repeated computations with similar-sized inputs.

    Parameters
    ----------
    max_n : int, default=2**20
        Maximum length to preallocate arrays for. This will be the size of the
        real-valued array. A complex-valued array of size `1 + (max_n // 2)`
        will also be preallocated. If inputs exceed this size, arrays will be
        reallocated to accommodate larger sizes.

    Attributes
    ----------
    real_arr : pyfftw.empty_aligned
        Preallocated real-valued array for FFTW computations.

    complex_arr : pyfftw.empty_aligned
        Preallocated complex-valued array for FFTW computations.

    rfft_objects : dict
        Cache of FFTW forward transform objects, keyed by
        (next_fast_n, n_threads, planning_flag).

    irfft_objects : dict
        Cache of FFTW inverse transform objects, keyed by
        (next_fast_n, n_threads, planning_flag).

    Notes
    -----
    The class maintains internal caches of FFTW objects to avoid redundant planning
    operations when called multiple times with similar-sized inputs and parameters.

    The preallocated arrays are scratch space that is overwritten on every call.
    The array returned to the caller is therefore always a copy and never a view
    into these internal buffers.

    Examples
    --------
    >>> sdp_obj = _PYFFTW_SLIDING_DOT_PRODUCT(max_n=1000)
    >>> Q = np.array([1, 2, 3])
    >>> T = np.array([4, 5, 6, 7, 8])
    >>> result = sdp_obj(Q, T)

    References
    ----------
    `FFTW documentation <http://www.fftw.org/>`__

    `pyfftw documentation <https://pyfftw.readthedocs.io/>`__
    """

    def __init__(self, max_n=2**20):
        """
        Initialize the `_PYFFTW_SLIDING_DOT_PRODUCT` object, which can be called
        to compute the sliding dot product using FFTW via pyfftw.

        Parameters
        ----------
        max_n : int, default=2**20
            Maximum length to preallocate arrays for. This will be the size of the
            real-valued array. A complex-valued array of size `1 + (max_n // 2)`
            will also be preallocated.

        Returns
        -------
        None
        """
        # Preallocate arrays
        self.real_arr = pyfftw.empty_aligned(max_n, dtype="float64")
        self.complex_arr = pyfftw.empty_aligned(1 + (max_n // 2), dtype="complex128")

        # Store FFTW objects, keyed by (next_fast_n, n_threads, planning_flag)
        self.rfft_objects = {}
        self.irfft_objects = {}

    def __call__(self, Q, T, n_threads=1, planning_flag="FFTW_ESTIMATE"):
        """
        Compute the sliding dot product between `Q` and `T` using FFTW via pyfftw,
        and cache FFTW objects if not already cached.

        Parameters
        ----------
        Q : numpy.ndarray
            Query array or subsequence.

        T : numpy.ndarray
            Time series or sequence.

        n_threads : int, default=1
            Number of threads to use for FFTW computations.

        planning_flag : str, default="FFTW_ESTIMATE"
            The planning flag that will be used in FFTW for planning.
            See pyfftw documentation for details. Current options, ordered
            ascendingly by the level of aggressiveness in planning, are:
            "FFTW_ESTIMATE", "FFTW_MEASURE", "FFTW_PATIENT", and "FFTW_EXHAUSTIVE".
            The more aggressive the planning, the longer the planning time, but
            the faster the execution time.

        Returns
        -------
        out : numpy.ndarray
            Sliding dot product between `Q` and `T`. This is a newly allocated
            array that is independent of the internal buffers, so it remains
            valid after subsequent calls on the same object.

        Notes
        -----
        The planning_flag is defaulted to "FFTW_ESTIMATE" to be aligned with
        MATLAB's FFTW usage (as of version R2025b)
        See: https://www.mathworks.com/help/matlab/ref/fftw.html

        This implementation is inspired by the answer on StackOverflow:
        https://stackoverflow.com/a/30615425/2955541
        """
        m = Q.shape[0]
        n = T.shape[0]
        next_fast_n = pyfftw.next_fast_len(n)

        # Update preallocated arrays if needed
        if next_fast_n > len(self.real_arr):
            self.real_arr = pyfftw.empty_aligned(next_fast_n, dtype="float64")
            self.complex_arr = pyfftw.empty_aligned(
                1 + (next_fast_n // 2), dtype="complex128"
            )

        # Views of exactly the transform size, taken from the start of each buffer
        real_arr = self.real_arr[:next_fast_n]
        complex_arr = self.complex_arr[: 1 + (next_fast_n // 2)]

        # Get or create FFTW objects
        key = (next_fast_n, n_threads, planning_flag)

        rfft_obj = self.rfft_objects.get(key, None)
        if rfft_obj is None:
            rfft_obj = pyfftw.FFTW(
                input_array=real_arr,
                output_array=complex_arr,
                direction="FFTW_FORWARD",
                flags=(planning_flag,),
                threads=n_threads,
            )
            self.rfft_objects[key] = rfft_obj
        else:
            # A cached plan may still point at buffers that have since been
            # reallocated, so always rebind it to the current views
            rfft_obj.update_arrays(real_arr, complex_arr)

        irfft_obj = self.irfft_objects.get(key, None)
        if irfft_obj is None:
            irfft_obj = pyfftw.FFTW(
                input_array=complex_arr,
                output_array=real_arr,
                direction="FFTW_BACKWARD",
                flags=(planning_flag, "FFTW_DESTROY_INPUT"),
                threads=n_threads,
            )
            self.irfft_objects[key] = irfft_obj
        else:
            irfft_obj.update_arrays(complex_arr, real_arr)

        # RFFT(T)
        real_arr[:n] = T
        real_arr[n:] = 0.0
        rfft_obj.execute()  # output is in complex_arr
        # Keep RFFT(T) aside because complex_arr is overwritten by RFFT(Q) below
        complex_arr_T = complex_arr.copy()

        # RFFT(Q)
        # Scale by 1/next_fast_n to account for
        # FFTW's unnormalized inverse FFT via execute()
        np.multiply(Q[::-1], 1.0 / next_fast_n, out=real_arr[:m])
        real_arr[m:] = 0.0
        rfft_obj.execute()  # output is in complex_arr

        # RFFT(T) * RFFT(Q)
        np.multiply(complex_arr, complex_arr_T, out=complex_arr)

        # IRFFT (input is in complex_arr)
        irfft_obj.execute()  # output is in real_arr

        # Return a copy rather than a view: `real_arr` aliases `self.real_arr`,
        # which is overwritten by the next call on this object, and a view would
        # silently corrupt results that the caller is still holding
        return real_arr[m - 1 : n].copy()
295+
296+
297+
if FFTW_IS_AVAILABLE:
298+
_pyfftw_sliding_dot_product = _PYFFTW_SLIDING_DOT_PRODUCT(max_n=2**20)
299+
300+
112301
def _sliding_dot_product(Q, T):
113302
"""
114303
Compute the sliding dot product between `Q` and `T`

test.sh

Lines changed: 41 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -154,27 +154,58 @@ gen_ray_coveragerc()
154154
# Generate a .coveragerc_override file that excludes Ray functions and tests
155155
gen_coveragerc_boilerplate
156156
echo " def .*_ray_*" >> .coveragerc_override
157-
echo " def ,*_ray\(*" >> .coveragerc_override
157+
echo " def .*_ray\(*" >> .coveragerc_override
158158
echo " def ray_.*" >> .coveragerc_override
159159
echo " def test_.*_ray*" >> .coveragerc_override
160160
}
161161

162-
set_ray_coveragerc()
162+
check_fftw_pyfftw()
163163
{
164-
# If `ray` command is not found then generate a .coveragerc_override file
165-
if ! command -v ray &> /dev/null
164+
if ! command -v fftw-wisdom &> /dev/null \
165+
|| ! python -c "import pyfftw" &> /dev/null;
166166
then
167-
echo "Ray Not Installed"
167+
echo "FFTW and/or pyFFTW Not Installed"
168+
else
169+
echo "FFTW and pyFFTW Installed"
170+
fi
171+
}
172+
173+
gen_pyfftw_coveragerc()
174+
{
175+
gen_coveragerc_boilerplate
176+
echo " class .*PYFFTW*" >> .coveragerc_override
177+
echo " def test_.*pyfftw*" >> .coveragerc_override
178+
}
179+
180+
set_coveragerc()
181+
{
182+
fcoveragerc=""
183+
184+
if ! command -v ray &> /dev/null;
185+
then
186+
echo "Ray not installed"
168187
gen_ray_coveragerc
169-
fcoveragerc="--rcfile=.coveragerc_override"
170188
else
171-
echo "Ray Installed"
189+
echo "Ray installed"
190+
fi
191+
192+
if ! command -v fftw-wisdom &> /dev/null \
193+
|| ! python -c "import pyfftw" &> /dev/null;
194+
then
195+
echo "FFTW and/or pyFFTW not Installed"
196+
gen_pyfftw_coveragerc
197+
else
198+
echo "FFTW and pyFFTW Installed"
199+
fi
200+
201+
if [ -f .coveragerc_override ]; then
202+
fcoveragerc="--rcfile=.coveragerc_override"
172203
fi
173204
}
174205

175206
show_coverage_report()
176207
{
177-
set_ray_coveragerc
208+
set_coveragerc
178209
coverage report --show-missing --fail-under=100 --skip-covered --omit=fastmath.py,docstring.py,versions.py $fcoveragerc
179210
check_errs $?
180211
}
@@ -361,6 +392,7 @@ check_print
361392
check_pkg_imports
362393
check_naive
363394
check_ray
395+
check_fftw_pyfftw
364396

365397

366398
if [[ -z $NUMBA_DISABLE_JIT || $NUMBA_DISABLE_JIT -eq 0 ]]; then
@@ -405,4 +437,4 @@ else
405437
test_coverage
406438
fi
407439

408-
clean_up
440+
clean_up

tests/test_sdp.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,9 @@ def get_sdp_function_names():
7474
if func_name.endswith("sliding_dot_product"):
7575
out.append(func_name)
7676

77+
if sdp.FFTW_IS_AVAILABLE:
78+
out.append("_pyfftw_sliding_dot_product")
79+
7780
return out
7881

7982

@@ -152,3 +155,25 @@ def test_sdp_power2():
152155
raise e
153156

154157
return
158+
159+
160+
def test_pyfftw_sdp_max_n():
    # Verify that `_PYFFTW_SLIDING_DOT_PRODUCT` grows its preallocated buffers
    # when `len(T)` exceeds `max_n`, and still returns the correct result.
    if not sdp.FFTW_IS_AVAILABLE:  # pragma: no cover
        pytest.skip("Skipping Test pyFFTW Not Installed")

    # When `len(T)` larger than `max_n` in pyfftw_sdp,
    # the internal preallocated arrays should be resized.
    # This test checks that functionality.
    max_n = 2**10
    sdp_func = sdp._PYFFTW_SLIDING_DOT_PRODUCT(max_n)

    # Before any call, the buffers have exactly the requested sizes
    assert len(sdp_func.real_arr) == max_n
    assert len(sdp_func.complex_arr) == 1 + (max_n // 2)

    # len(T) > max_n to trigger array resizing
    T = np.random.rand(max_n + 1)
    Q = np.random.rand(2**8)

    comp = sdp_func(Q, T)
    ref = naive.rolling_window_dot_product(Q, T)

    np.testing.assert_allclose(comp, ref)

    # Check the resize itself, not just the numerical result: the real buffer
    # must now hold all of `T`, and the complex buffer must match its size
    assert len(sdp_func.real_arr) >= T.shape[0]
    assert len(sdp_func.complex_arr) == 1 + (len(sdp_func.real_arr) // 2)

    return

0 commit comments

Comments
 (0)