@@ -46,16 +46,24 @@ def _cuda_recurrence(
4646 impulse : torch .Tensor , decay : torch .Tensor , initial_state : torch .Tensor
4747) -> torch .Tensor :
4848 n_dims , n_steps = decay .shape
49- if impulse .is_floating_point ():
50- try :
51- import pararnn .parallel_reduction .parallel_reduction
52-
53- return torch .ops .parallel_reduce_cuda .parallel_reduce_diag_cuda (
54- F .pad (- decay , (1 , 0 )),
55- torch .cat ([initial_state .unsqueeze (1 ), impulse ], dim = 1 ),
56- )[:, 1 :]
57- except ImportError :
58- pass
49+ try :
50+ import pararnn .parallel_reduction .parallel_reduction
51+ except ImportError :
52+ pass
53+ else :
54+ jac = F .pad (- decay , (1 , 0 ))
55+ rhs = torch .cat ([initial_state .unsqueeze (1 ), impulse ], dim = 1 )
56+ if decay .is_complex ():
57+ jac = torch .stack (
58+ [jac .real , - jac .imag , jac .imag , jac .real ], dim = - 1
59+ ).unflatten (- 1 , (2 , 2 ))
60+ rhs = torch .view_as_real (rhs )
61+ return torch .view_as_complex (
62+ torch .ops .parallel_reduce_cuda .parallel_reduce_block_diag_2x2_cuda (
63+ jac , rhs
64+ )[:, 1 :]
65+ )
66+ return torch .ops .parallel_reduce_cuda .parallel_reduce_diag_cuda (jac , rhs )[:, 1 :]
5967
6068 if n_dims * WARPSIZE < n_steps :
6169 runner = scan_cuda_runner
0 commit comments