Skip to content

Commit 89b6a14

Browse files
committed
feat: refactor _cuda_recurrence to handle complex tensors and improve error handling for pararnn import
1 parent a12d686 commit 89b6a14

1 file changed

Lines changed: 18 additions & 10 deletions

File tree

torchlpc/recurrence.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -46,16 +46,24 @@ def _cuda_recurrence(
4646
impulse: torch.Tensor, decay: torch.Tensor, initial_state: torch.Tensor
4747
) -> torch.Tensor:
4848
n_dims, n_steps = decay.shape
49-
if impulse.is_floating_point():
50-
try:
51-
import pararnn.parallel_reduction.parallel_reduction
52-
53-
return torch.ops.parallel_reduce_cuda.parallel_reduce_diag_cuda(
54-
F.pad(-decay, (1, 0)),
55-
torch.cat([initial_state.unsqueeze(1), impulse], dim=1),
56-
)[:, 1:]
57-
except ImportError:
58-
pass
49+
try:
50+
import pararnn.parallel_reduction.parallel_reduction
51+
except ImportError:
52+
pass
53+
else:
54+
jac = F.pad(-decay, (1, 0))
55+
rhs = torch.cat([initial_state.unsqueeze(1), impulse], dim=1)
56+
if decay.is_complex():
57+
jac = torch.stack(
58+
[jac.real, -jac.imag, jac.imag, jac.real], dim=-1
59+
).unflatten(-1, (2, 2))
60+
rhs = torch.view_as_real(rhs)
61+
return torch.view_as_complex(
62+
torch.ops.parallel_reduce_cuda.parallel_reduce_block_diag_2x2_cuda(
63+
jac, rhs
64+
)[:, 1:]
65+
)
66+
return torch.ops.parallel_reduce_cuda.parallel_reduce_diag_cuda(jac, rhs)[:, 1:]
5967

6068
if n_dims * WARPSIZE < n_steps:
6169
runner = scan_cuda_runner

0 commit comments

Comments
 (0)