Skip to content

Commit 6051ab4

Browse files
committed
addressed comments
1 parent ae90507 commit 6051ab4

File tree

2 files changed

+27
-29
lines changed

2 files changed

+27
-29
lines changed

stumpy/sdp.py

Lines changed: 25 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@
99
try: # pragma: no cover
1010
import pyfftw
1111

12-
FFTW_IS_AVAILABLE = True
12+
PYFFTW_IS_AVAILABLE = True
1313
except ImportError: # pragma: no cover
14-
FFTW_IS_AVAILABLE = False
14+
PYFFTW_IS_AVAILABLE = False
1515

1616

1717
@njit(fastmath=config.STUMPY_FASTMATH_TRUE)
@@ -121,7 +121,7 @@ class _PYFFTW_SLIDING_DOT_PRODUCT:
121121
A class to compute the sliding dot product using FFTW via pyfftw.
122122
123123
This class uses FFTW (via pyfftw) to efficiently compute the sliding dot product
124-
between a query sequence Q and a time series T. It preallocates arrays and caches
124+
between a query sequence, Q, and a time series, T. It preallocates arrays and caches
125125
FFTW objects to optimize repeated computations with similar-sized inputs.
126126
127127
Parameters
@@ -141,12 +141,12 @@ class _PYFFTW_SLIDING_DOT_PRODUCT:
141141
Preallocated complex-valued array for FFTW computations.
142142
143143
rfft_objects : dict
144-
Cache of FFTW forward transform objects, keyed by
145-
(next_fast_n, n_threads, planning_flag).
144+
Cache of FFTW forward transform objects with
145+
(next_fast_n, n_threads, planning_flag) as lookup keys.
146146
147147
irfft_objects : dict
148-
Cache of FFTW inverse transform objects, keyed by
149-
(next_fast_n, n_threads, planning_flag).
148+
Cache of FFTW inverse transform objects with
149+
(next_fast_n, n_threads, planning_flag) as lookup keys.
150150
151151
Notes
152152
-----
@@ -247,54 +247,52 @@ def __call__(self, Q, T, n_threads=1, planning_flag="FFTW_ESTIMATE"):
247247
key = (next_fast_n, n_threads, planning_flag)
248248

249249
rfft_obj = self.rfft_objects.get(key, None)
250-
if rfft_obj is None:
250+
irfft_obj = self.irfft_objects.get(key, None)
251+
252+
if rfft_obj is None or irfft_obj is None:
251253
rfft_obj = pyfftw.FFTW(
252254
input_array=real_arr,
253255
output_array=complex_arr,
254256
direction="FFTW_FORWARD",
255257
flags=(planning_flag,),
256258
threads=n_threads,
257259
)
258-
self.rfft_objects[key] = rfft_obj
259-
else:
260-
rfft_obj.update_arrays(real_arr, complex_arr)
261-
262-
irfft_obj = self.irfft_objects.get(key, None)
263-
if irfft_obj is None:
264260
irfft_obj = pyfftw.FFTW(
265261
input_array=complex_arr,
266262
output_array=real_arr,
267263
direction="FFTW_BACKWARD",
268264
flags=(planning_flag, "FFTW_DESTROY_INPUT"),
269265
threads=n_threads,
270266
)
267+
self.rfft_objects[key] = rfft_obj
271268
self.irfft_objects[key] = irfft_obj
272269
else:
270+
rfft_obj.update_arrays(real_arr, complex_arr)
273271
irfft_obj.update_arrays(complex_arr, real_arr)
274272

275-
# RFFT(T)
273+
# Compute RFFT of T
276274
real_arr[:n] = T
277275
real_arr[n:] = 0.0
278-
rfft_obj.execute() # output is in complex_arr
279-
complex_arr_T = complex_arr.copy()
276+
rfft_obj.execute() # output is stored in complex_arr
280277

281-
# RFFT(Q)
282-
# Scale by 1/next_fast_n to account for
283-
# FFTW's unnormalized inverse FFT via execute()
278+
# need to make a copy since the array will be
279+
# overwritten later during the RFFT(Q) step
280+
rfft_of_T = complex_arr.copy()
281+
282+
# Compute RFFT of Q (reversed and scaled by 1/next_fast_n)
284283
np.multiply(Q[::-1], 1.0 / next_fast_n, out=real_arr[:m])
285284
real_arr[m:] = 0.0
286-
rfft_obj.execute() # output is in complex_arr
287-
288-
# RFFT(T) * RFFT(Q)
289-
np.multiply(complex_arr, complex_arr_T, out=complex_arr)
285+
rfft_obj.execute() # output is stored in complex_arr
286+
rfft_of_Q = complex_arr
290287

291-
# IRFFT (input is in complex_arr)
292-
irfft_obj.execute() # output is in real_arr
288+
# Compute IRFFT of the element-wise product of the RFFTs
289+
np.multiply(rfft_of_Q, rfft_of_T, out=complex_arr)
290+
irfft_obj.execute() # output is stored in real_arr
293291

294292
return real_arr[m - 1 : n]
295293

296294

297-
if FFTW_IS_AVAILABLE: # pragma: no cover
295+
if PYFFTW_IS_AVAILABLE: # pragma: no cover
298296
_pyfftw_sliding_dot_product = _PYFFTW_SLIDING_DOT_PRODUCT(max_n=2**20)
299297

300298

tests/test_sdp.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def get_sdp_function_names():
7474
if func_name.endswith("sliding_dot_product"):
7575
out.append(func_name)
7676

77-
if sdp.FFTW_IS_AVAILABLE: # pragma: no cover
77+
if sdp.PYFFTW_IS_AVAILABLE: # pragma: no cover
7878
out.append("_pyfftw_sliding_dot_product")
7979

8080
return out
@@ -158,7 +158,7 @@ def test_sdp_power2():
158158

159159

160160
def test_pyfftw_sdp_max_n():
161-
if not sdp.FFTW_IS_AVAILABLE: # pragma: no cover
161+
if not sdp.PYFFTW_IS_AVAILABLE: # pragma: no cover
162162
pytest.skip("Skipping Test pyFFTW Not Installed")
163163

164164
# When `len(T)` larger than `max_n` in pyfftw_sdp,

0 commit comments

Comments
 (0)