99try : # pragma: no cover
1010 import pyfftw
1111
12- FFTW_IS_AVAILABLE = True
12+ PYFFTW_IS_AVAILABLE = True
1313except ImportError : # pragma: no cover
14- FFTW_IS_AVAILABLE = False
14+ PYFFTW_IS_AVAILABLE = False
1515
1616
1717@njit (fastmath = config .STUMPY_FASTMATH_TRUE )
@@ -121,7 +121,7 @@ class _PYFFTW_SLIDING_DOT_PRODUCT:
121121 A class to compute the sliding dot product using FFTW via pyfftw.
122122
123123 This class uses FFTW (via pyfftw) to efficiently compute the sliding dot product
124- between a query sequence Q and a time series T. It preallocates arrays and caches
124+ between a query sequence, Q, and a time series, T. It preallocates arrays and caches
125125 FFTW objects to optimize repeated computations with similar-sized inputs.
126126
127127 Parameters
@@ -141,12 +141,12 @@ class _PYFFTW_SLIDING_DOT_PRODUCT:
141141 Preallocated complex-valued array for FFTW computations.
142142
143143 rfft_objects : dict
144- Cache of FFTW forward transform objects, keyed by
145- (next_fast_n, n_threads, planning_flag).
144+ Cache of FFTW forward transform objects with
145+ (next_fast_n, n_threads, planning_flag) as lookup keys .
146146
147147 irfft_objects : dict
148- Cache of FFTW inverse transform objects, keyed by
149- (next_fast_n, n_threads, planning_flag).
148+ Cache of FFTW inverse transform objects with
149+ (next_fast_n, n_threads, planning_flag) as lookup keys .
150150
151151 Notes
152152 -----
@@ -247,54 +247,52 @@ def __call__(self, Q, T, n_threads=1, planning_flag="FFTW_ESTIMATE"):
247247 key = (next_fast_n , n_threads , planning_flag )
248248
249249 rfft_obj = self .rfft_objects .get (key , None )
250- if rfft_obj is None :
250+ irfft_obj = self .irfft_objects .get (key , None )
251+
252+ if rfft_obj is None or irfft_obj is None :
251253 rfft_obj = pyfftw .FFTW (
252254 input_array = real_arr ,
253255 output_array = complex_arr ,
254256 direction = "FFTW_FORWARD" ,
255257 flags = (planning_flag ,),
256258 threads = n_threads ,
257259 )
258- self .rfft_objects [key ] = rfft_obj
259- else :
260- rfft_obj .update_arrays (real_arr , complex_arr )
261-
262- irfft_obj = self .irfft_objects .get (key , None )
263- if irfft_obj is None :
264260 irfft_obj = pyfftw .FFTW (
265261 input_array = complex_arr ,
266262 output_array = real_arr ,
267263 direction = "FFTW_BACKWARD" ,
268264 flags = (planning_flag , "FFTW_DESTROY_INPUT" ),
269265 threads = n_threads ,
270266 )
267+ self .rfft_objects [key ] = rfft_obj
271268 self .irfft_objects [key ] = irfft_obj
272269 else :
270+ rfft_obj .update_arrays (real_arr , complex_arr )
273271 irfft_obj .update_arrays (complex_arr , real_arr )
274272
275- # RFFT(T)
273+ # Compute RFFT of T
276274 real_arr [:n ] = T
277275 real_arr [n :] = 0.0
278- rfft_obj .execute () # output is in complex_arr
279- complex_arr_T = complex_arr .copy ()
276+ rfft_obj .execute () # output is stored in complex_arr
280277
281- # RFFT(Q)
282- # Scale by 1/next_fast_n to account for
283- # FFTW's unnormalized inverse FFT via execute()
278+ # need to make a copy since the array will be
279+ # overwritten later during the RFFT(Q) step
280+ rfft_of_T = complex_arr .copy ()
281+
282+ # Compute RFFT of Q (reversed and scaled by 1/next_fast_n)
284283 np .multiply (Q [::- 1 ], 1.0 / next_fast_n , out = real_arr [:m ])
285284 real_arr [m :] = 0.0
286- rfft_obj .execute () # output is in complex_arr
287-
288- # RFFT(T) * RFFT(Q)
289- np .multiply (complex_arr , complex_arr_T , out = complex_arr )
285+ rfft_obj .execute () # output is stored in complex_arr
286+ rfft_of_Q = complex_arr
290287
291- # IRFFT (input is in complex_arr)
292- irfft_obj .execute () # output is in real_arr
288+ # Compute IRFFT of the element-wise product of the RFFTs
289+ np .multiply (rfft_of_Q , rfft_of_T , out = complex_arr )
290+ irfft_obj .execute () # output is stored in real_arr
293291
294292 return real_arr [m - 1 : n ]
295293
296294
297- if FFTW_IS_AVAILABLE : # pragma: no cover
295+ if PYFFTW_IS_AVAILABLE : # pragma: no cover
298296 _pyfftw_sliding_dot_product = _PYFFTW_SLIDING_DOT_PRODUCT (max_n = 2 ** 20 )
299297
300298
0 commit comments