Skip to content

Commit 986d403

Browse files
authored
Merge pull request #50 from QuState/rayon
multithreading via Rayon
2 parents e3d75e5 + 2d4130e commit 986d403

File tree

7 files changed

+100
-56
lines changed

7 files changed

+100
-56
lines changed

.github/workflows/rust.yml

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,13 @@ jobs:
4646
args: --all -- --check
4747

4848
combo:
49-
name: Test
49+
name: Test ${{ matrix.args }}
5050
runs-on: ubuntu-latest
51+
strategy:
52+
fail-fast: false # ensures if one fails, the other keeps running
53+
matrix:
54+
# This creates two jobs: one with the flag, one with an empty string
55+
args: ["--all-features", "--features=complex-nums"]
5156
steps:
5257
- name: Checkout sources
5358
uses: actions/checkout@v2
@@ -63,7 +68,7 @@ jobs:
6368
uses: actions-rs/cargo@v1
6469
with:
6570
command: test
66-
args: --all-features
71+
args: ${{ matrix.args }}
6772

6873
coverage:
6974
runs-on: ubuntu-latest

Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,12 @@ multiversion = "0.8.0"
1818
num-complex = { version = "0.4.6", features = ["bytemuck"], optional = true }
1919
bytemuck = { version = "1.23.2", optional = true }
2020
wide = "0.8.1"
21+
rayon = { version = "1.11.0", optional = true }
2122

2223
[features]
2324
default = []
2425
complex-nums = ["dep:num-complex", "dep:bytemuck"]
26+
parallel = ["dep:rayon"]
2527

2628
[dev-dependencies]
2729
criterion = "0.8.0"

src/algorithms/dif.rs

Lines changed: 11 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ use crate::algorithms::cobra::cobra_apply;
1616
use crate::kernels::common::{fft_chunk_2, fft_chunk_4};
1717
use crate::kernels::dif::{fft_32_chunk_n_simd, fft_64_chunk_n_simd, fft_chunk_n};
1818
use crate::options::Options;
19+
use crate::parallel::run_maybe_in_parallel;
1920
use crate::planner::{Direction, Planner32, Planner64};
2021
use crate::twiddles::filter_twiddles;
2122

@@ -118,15 +119,11 @@ pub fn fft_64_with_opts_and_plan(
118119

119120
// Optional bit reversal (controlled by options)
120121
if opts.dif_perform_bit_reversal {
121-
if opts.multithreaded_bit_reversal {
122-
std::thread::scope(|s| {
123-
s.spawn(|| cobra_apply(reals, n));
124-
s.spawn(|| cobra_apply(imags, n));
125-
});
126-
} else {
127-
cobra_apply(reals, n);
128-
cobra_apply(imags, n);
129-
}
122+
run_maybe_in_parallel(
123+
opts.multithreaded_bit_reversal,
124+
|| cobra_apply(reals, n),
125+
|| cobra_apply(imags, n),
126+
);
130127
}
131128

132129
// Scaling for inverse transform
@@ -225,15 +222,11 @@ pub fn fft_32_with_opts_and_plan(
225222
}
226223

227224
if opts.dif_perform_bit_reversal {
228-
if opts.multithreaded_bit_reversal {
229-
std::thread::scope(|s| {
230-
s.spawn(|| cobra_apply(reals, n));
231-
s.spawn(|| cobra_apply(imags, n));
232-
});
233-
} else {
234-
cobra_apply(reals, n);
235-
cobra_apply(imags, n);
236-
}
225+
run_maybe_in_parallel(
226+
opts.multithreaded_bit_reversal,
227+
|| cobra_apply(reals, n),
228+
|| cobra_apply(imags, n),
229+
);
237230
}
238231

239232
// Scaling for inverse transform

src/algorithms/dit.rs

Lines changed: 41 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ use crate::kernels::dit::{
2323
fft_dit_chunk_8_simd_f64,
2424
};
2525
use crate::options::Options;
26+
use crate::parallel::run_maybe_in_parallel;
2627
use crate::planner::{Direction, PlannerDit32, PlannerDit64};
2728

2829
/// L1 cache block size in complex elements (8KB for f32, 16KB for f64)
@@ -35,18 +36,18 @@ const L1_BLOCK_SIZE: usize = 1024;
3536
fn recursive_dit_fft_f64(
3637
reals: &mut [f64],
3738
imags: &mut [f64],
38-
offset: usize,
3939
size: usize,
4040
planner: &PlannerDit64,
41+
opts: &Options,
4142
mut stage_twiddle_idx: usize,
4243
) -> usize {
4344
let log_size = size.ilog2() as usize;
4445

4546
if size <= L1_BLOCK_SIZE {
4647
for stage in 0..log_size {
4748
stage_twiddle_idx = execute_dit_stage_f64(
48-
&mut reals[offset..offset + size],
49-
&mut imags[offset..offset + size],
49+
&mut reals[..size],
50+
&mut imags[..size],
5051
stage,
5152
planner,
5253
stage_twiddle_idx,
@@ -57,9 +58,14 @@ fn recursive_dit_fft_f64(
5758
let half = size / 2;
5859
let log_half = half.ilog2() as usize;
5960

61+
let (re_first_half, re_second_half) = reals.split_at_mut(half);
62+
let (im_first_half, im_second_half) = imags.split_at_mut(half);
6063
// Recursively process both halves
61-
recursive_dit_fft_f64(reals, imags, offset, half, planner, 0);
62-
recursive_dit_fft_f64(reals, imags, offset + half, half, planner, 0);
64+
run_maybe_in_parallel(
65+
size > opts.smallest_parallel_chunk_size,
66+
|| recursive_dit_fft_f64(re_first_half, im_first_half, half, planner, opts, 0),
67+
|| recursive_dit_fft_f64(re_second_half, im_second_half, half, planner, opts, 0),
68+
);
6369

6470
// Both halves completed stages 0..log_half-1
6571
// Stages 0-5 use hardcoded twiddles, 6+ use planner
@@ -68,8 +74,8 @@ fn recursive_dit_fft_f64(
6874
// Process remaining stages that span both halves
6975
for stage in log_half..log_size {
7076
stage_twiddle_idx = execute_dit_stage_f64(
71-
&mut reals[offset..offset + size],
72-
&mut imags[offset..offset + size],
77+
&mut reals[..size],
78+
&mut imags[..size],
7379
stage,
7480
planner,
7581
stage_twiddle_idx,
@@ -84,18 +90,18 @@ fn recursive_dit_fft_f64(
8490
fn recursive_dit_fft_f32(
8591
reals: &mut [f32],
8692
imags: &mut [f32],
87-
offset: usize,
8893
size: usize,
8994
planner: &PlannerDit32,
95+
opts: &Options,
9096
mut stage_twiddle_idx: usize,
9197
) -> usize {
9298
let log_size = size.ilog2() as usize;
9399

94100
if size <= L1_BLOCK_SIZE {
95101
for stage in 0..log_size {
96102
stage_twiddle_idx = execute_dit_stage_f32(
97-
&mut reals[offset..offset + size],
98-
&mut imags[offset..offset + size],
103+
&mut reals[..size],
104+
&mut imags[..size],
99105
stage,
100106
planner,
101107
stage_twiddle_idx,
@@ -106,15 +112,24 @@ fn recursive_dit_fft_f32(
106112
let half = size / 2;
107113
let log_half = half.ilog2() as usize;
108114

109-
recursive_dit_fft_f32(reals, imags, offset, half, planner, 0);
110-
recursive_dit_fft_f32(reals, imags, offset + half, half, planner, 0);
115+
let (re_first_half, re_second_half) = reals.split_at_mut(half);
116+
let (im_first_half, im_second_half) = imags.split_at_mut(half);
117+
// Recursively process both halves
118+
run_maybe_in_parallel(
119+
size > opts.smallest_parallel_chunk_size,
120+
|| recursive_dit_fft_f32(re_first_half, im_first_half, half, planner, opts, 0),
121+
|| recursive_dit_fft_f32(re_second_half, im_second_half, half, planner, opts, 0),
122+
);
111123

124+
// Both halves completed stages 0..log_half-1
125+
// Stages 0-5 use hardcoded twiddles, 6+ use planner
112126
stage_twiddle_idx = log_half.saturating_sub(6);
113127

128+
// Process remaining stages that span both halves
114129
for stage in log_half..log_size {
115130
stage_twiddle_idx = execute_dit_stage_f32(
116-
&mut reals[offset..offset + size],
117-
&mut imags[offset..offset + size],
131+
&mut reals[..size],
132+
&mut imags[..size],
118133
stage,
119134
planner,
120135
stage_twiddle_idx,
@@ -235,15 +250,11 @@ pub fn fft_64_dit_with_planner_and_opts(
235250
assert_eq!(log_n, planner.log_n);
236251

237252
// DIT requires bit-reversed input
238-
if opts.multithreaded_bit_reversal {
239-
std::thread::scope(|s| {
240-
s.spawn(|| cobra_apply(reals, log_n));
241-
s.spawn(|| cobra_apply(imags, log_n));
242-
});
243-
} else {
244-
cobra_apply(reals, log_n);
245-
cobra_apply(imags, log_n);
246-
}
253+
run_maybe_in_parallel(
254+
opts.multithreaded_bit_reversal,
255+
|| cobra_apply(reals, log_n),
256+
|| cobra_apply(imags, log_n),
257+
);
247258

248259
// Handle inverse FFT
249260
if let Direction::Reverse = planner.direction {
@@ -252,7 +263,7 @@ pub fn fft_64_dit_with_planner_and_opts(
252263
}
253264
}
254265

255-
recursive_dit_fft_f64(reals, imags, 0, n, planner, 0);
266+
recursive_dit_fft_f64(reals, imags, n, planner, opts, 0);
256267

257268
// Scaling for inverse transform
258269
if let Direction::Reverse = planner.direction {
@@ -282,15 +293,11 @@ pub fn fft_32_dit_with_planner_and_opts(
282293
assert_eq!(log_n, planner.log_n);
283294

284295
// DIT requires bit-reversed input
285-
if opts.multithreaded_bit_reversal {
286-
std::thread::scope(|s| {
287-
s.spawn(|| cobra_apply(reals, log_n));
288-
s.spawn(|| cobra_apply(imags, log_n));
289-
});
290-
} else {
291-
cobra_apply(reals, log_n);
292-
cobra_apply(imags, log_n);
293-
}
296+
run_maybe_in_parallel(
297+
opts.multithreaded_bit_reversal,
298+
|| cobra_apply(reals, log_n),
299+
|| cobra_apply(imags, log_n),
300+
);
294301

295302
// Handle inverse FFT
296303
if let Direction::Reverse = planner.direction {
@@ -299,7 +306,7 @@ pub fn fft_32_dit_with_planner_and_opts(
299306
}
300307
}
301308

302-
recursive_dit_fft_f32(reals, imags, 0, n, planner, 0);
309+
recursive_dit_fft_f32(reals, imags, n, planner, opts, 0);
303310

304311
// Scaling for inverse transform
305312
if let Direction::Reverse = planner.direction {

src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ use crate::utils::{combine_re_im, deinterleave_complex32, deinterleave_complex64
2020
mod algorithms;
2121
mod kernels;
2222
pub mod options;
23+
mod parallel;
2324
pub mod planner;
2425
mod twiddles;
2526
mod utils;

src/options.rs

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,10 @@
99
#[derive(Debug, Clone)]
1010
pub struct Options {
1111
/// Whether to run the bit reversal step in 2 threads instead of one.
12-
/// This is beneficial only at large input sizes (i.e. gigabytes of data).
12+
/// This is beneficial only at medium to large sizes (i.e. megabytes of data).
1313
/// The exact threshold where it starts being beneficial varies depending on the hardware.
14+
///
15+
/// This option is ignored if the `parallel` feature is disabled.
1416
pub multithreaded_bit_reversal: bool,
1517

1618
/// Controls bit reversal behavior for DIF FFT algorithms.
@@ -33,13 +35,21 @@ pub struct Options {
3335
/// fft_64_with_opts_and_plan(&mut reals, &mut imags, &opts, &planner);
3436
/// ```
3537
pub dif_perform_bit_reversal: bool,
38+
39+
/// Do not split the input any further to run in parallel below this size
40+
///
41+
/// Set to `usize::MAX` to disable parallelism in the recursive FFT step.
42+
///
43+
/// This option is ignored if the `parallel` feature is disabled.
44+
pub smallest_parallel_chunk_size: usize,
3645
}
3746

3847
impl Default for Options {
3948
fn default() -> Self {
4049
Self {
4150
multithreaded_bit_reversal: false,
4251
dif_perform_bit_reversal: true, // Default to standard FFT behavior
52+
smallest_parallel_chunk_size: usize::MAX,
4353
}
4454
}
4555
}
@@ -49,7 +59,8 @@ impl Options {
4959
pub fn guess_options(input_size: usize) -> Options {
5060
let mut options = Options::default();
5161
let n: usize = input_size.ilog2() as usize;
52-
options.multithreaded_bit_reversal = n >= 22;
62+
options.multithreaded_bit_reversal = n >= 16;
63+
options.smallest_parallel_chunk_size = 16384;
5364
options
5465
}
5566
}

src/parallel.rs

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
//! Utilities for parallelism
2+
3+
/// Runs the two specified closures in parallel,
4+
/// if and only if `parallel` is set to `true` and the `parallel` feature is enabled
5+
#[allow(unused_variables)] // when `parallel` feature is disabled, the variable is ignored
6+
pub fn run_maybe_in_parallel<A, B, RA, RB>(parallel: bool, oper_a: A, oper_b: B) -> (RA, RB)
7+
where
8+
A: FnOnce() -> RA + Send,
9+
B: FnOnce() -> RB + Send,
10+
RA: Send,
11+
RB: Send,
12+
{
13+
#[cfg(feature = "parallel")]
14+
{
15+
if parallel {
16+
rayon::join(oper_a, oper_b)
17+
} else {
18+
(oper_a(), oper_b())
19+
}
20+
}
21+
#[cfg(not(feature = "parallel"))]
22+
{
23+
(oper_a(), oper_b())
24+
}
25+
}

0 commit comments

Comments
 (0)