Skip to content

Commit 9950b8a

Browse files
committed
Fix kernel test warnings, improve error handling, and update README with test results summary
1 parent ced7a52 commit 9950b8a

17 files changed

Lines changed: 3465 additions & 479 deletions

crates/bitnet-core/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ bytemuck = { workspace = true, version = "1.12.3", features = ["derive"] }
3737
pollster = { workspace = true }
3838
derive-new = "0.7.0"
3939
bitnet-tools = { path = "../bitnet-tools" }
40-
wgpu = { version = "0.20.0", features = ["naga"], optional = true }
40+
wgpu = { version = "25.0.2", optional = true }
4141
rand = "0.9.1"
4242
tokio = { version = "1.36.0", features = ["full"] }
4343
futures-intrusive = "0.5"

crates/bitnet-core/README.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,10 +82,15 @@ use bitnet_core::model::Transformer;
8282
- Streaming and per-block model loading tests
8383
- **Optional Stress Test**: A long-running stress test (`stress_test_maximum_dimension_support`) is available but ignored by default. To run it, set the `RUN_STRESS_TESTS` environment variable:
8484
- **PowerShell**:
85+
8586
```powershell
87+
8688
$env:RUN_STRESS_TESTS="1"; cargo test --package bitnet-core --test kernel_tests -- --nocapture
89+
8790
```
91+
8892
- **Linux/macOS**:
93+
8994
```bash
9095
RUN_STRESS_TESTS=1 cargo test --package bitnet-core --test kernel_tests -- --nocapture
9196
```

crates/bitnet-core/src/bitnet_linear.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -295,8 +295,9 @@ impl BitLinear {
295295
label: Some("Bitnet Pipeline"),
296296
layout: Some(&pipeline_layout),
297297
module: &shader,
298-
entry_point: "main",
298+
entry_point: Some("main"),
299299
compilation_options: Default::default(),
300+
cache: None,
300301
});
301302

302303
// Step 4: Create and submit command buffer

crates/bitnet-core/src/kernels.rs

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,11 @@ pub struct BitnetMetadata {
9696
/// - +1 => 10 (binary 2)
9797
/// 2. Per-output-channel scaling factors
9898
///
99+
/// # Packing Order (LSB-first)
100+
///
101+
/// The first weight goes into the lowest bits (bits 0-1), the next into bits 2-3, ..., the last (16th) into bits 30-31.
102+
/// This matches the GPU shader and converter logic.
103+
///
99104
/// # Arguments
100105
///
101106
/// * `weights` - 2D array of ternary weights [out_features][in_features]
@@ -138,7 +143,7 @@ pub fn pack_ternary_weights(weights: &[Vec<i8>]) -> Result<(Vec<u32>, Vec<f32>),
138143
// Pack weights
139144
for (in_idx, &w) in row.iter().enumerate() {
140145
let pack_idx = in_idx / 16;
141-
let bit_idx = 30 - ((in_idx % 16) * 2); // Start from MSB and work down
146+
let bit_idx = (in_idx % 16) * 2; // LSB-first: first weight in bits 0-1
142147

143148
// Map -1, 0, +1 to 2-bit values
144149
let bits = match w {
@@ -215,10 +220,11 @@ mod tests {
215220
assert_eq!(scales.len(), 2);
216221
assert!(scales.iter().all(|&s| s > 0.0));
217222

218-
// Verify first row packing
223+
// Verify first row packing (LSB-first)
219224
// Input: -1, 0, 1, 0, -1, 1, 0, 0, 1, -1, 0, 1, 0, -1, 1, 0
220225
// Bits: 00, 01, 10, 01, 00, 10, 01, 01, 10, 00, 01, 10, 01, 00, 10, 01
221-
let expected = 0b00_01_10_01_00_10_01_01_10_00_01_10_01_00_10_01u32;
226+
// LSB-first: first weight in bits 0-1, last in 30-31
227+
let expected = 0b01_10_00_01_10_01_01_10_00_10_01_00_01_10_01_00u32;
222228
assert_eq!(packed[0], expected, "Packed weights don't match expected pattern.\nExpected: {:032b}\nGot: {:032b}", expected, packed[0]);
223229

224230
// Verify second row packing
@@ -264,4 +270,26 @@ mod tests {
264270
);
265271
println!("[TEST] test_invalid_weight_value (took {:.2?})", t0.elapsed());
266272
}
273+
274+
#[test]
275+
fn test_packing_unpacking_symmetry() {
276+
// This test ensures that packing and then unpacking using the shader/scalar logic is symmetric.
277+
let original: Vec<i8> = vec![-1, 0, 1, 0, -1, 1, 0, 0, 1, -1, 0, 1, 0, -1, 1, 0];
278+
let weights = vec![original.clone()];
279+
let (packed, _) = pack_ternary_weights(&weights).unwrap();
280+
let packed_val = packed[0];
281+
// Unpack using the same logic as the shader/scalar code (LSB-first)
282+
let mut unpacked = Vec::with_capacity(16);
283+
for i in 0..16 {
284+
let bits = (packed_val >> (i * 2)) & 0b11;
285+
let w = match bits {
286+
0 => -1i8,
287+
1 => 0i8,
288+
2 => 1i8,
289+
_ => 0i8, // Should never happen
290+
};
291+
unpacked.push(w);
292+
}
293+
assert_eq!(original, unpacked, "Packing and unpacking are not symmetric!\nOriginal: {:?}\nUnpacked: {:?}", original, unpacked);
294+
}
267295
}

crates/bitnet-core/src/kernels/bitnet_kernel_optimal.wgsl

Lines changed: 14 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
// bitnet_kernel_optimal.wgsl
2-
// Optimized BitNet B1.58 Ternary Kernel for WGPU (Optimal, but fails on DX12)
1+
// bitnet_kernel_pre_naga.wgsl
2+
// Optimized BitNet B1.58 Ternary Kernel for WGPU (Original, before any DX12/Naga workaround)
33
// Supports {-1, 0, +1} ternary weights with efficient packing and vectorization
44

55
struct BitnetMetadata {
@@ -17,11 +17,11 @@ struct BitnetMetadata {
1717
@group(0) @binding(5) var<storage, read_write> output: array<f32>;
1818

1919
// Optimized tiling parameters for modern GPUs
20-
const TILE_DIM_M: u32 = 64u; // Reduced for better occupancy
20+
const TILE_DIM_M: u32 = 64u;
2121
const TILE_DIM_N: u32 = 64u;
22-
const TILE_DIM_K: u32 = 32u; // Increased K tile for better data reuse
22+
const TILE_DIM_K: u32 = 32u;
2323

24-
const THREAD_TILE_M: u32 = 4u; // Smaller thread tiles for better vectorization
24+
const THREAD_TILE_M: u32 = 4u;
2525
const THREAD_TILE_N: u32 = 4u;
2626

2727
const WORKGROUP_SIZE_X: u32 = 16u; // TILE_DIM_N / THREAD_TILE_N
@@ -31,12 +31,10 @@ const WORKGROUP_SIZE_Y: u32 = 16u; // TILE_DIM_M / THREAD_TILE_M
3131
const TILE_A_SIZE: u32 = (TILE_DIM_M * TILE_DIM_K) / 4u; // for vec4<i32>
3232
const TILE_B_SIZE: u32 = TILE_DIM_K * TILE_DIM_N; // for i32
3333

34-
// Shared memory with better alignment
3534
var<workgroup> tile_a: array<vec4<i32>, TILE_A_SIZE>;
3635
var<workgroup> tile_b: array<i32, TILE_B_SIZE>;
3736

38-
// Use direct decode function for ternary weights, matching the logic
39-
// from the previously passing tests.
37+
// Use direct decode function for ternary weights
4038
fn decode_2bit(val: u32) -> i32 {
4139
switch(val) {
4240
case 1u: { return 1; } // 01
@@ -45,8 +43,6 @@ fn decode_2bit(val: u32) -> i32 {
4543
}
4644
}
4745

48-
// Unroll the decoding loop for maximum compatibility, while preserving
49-
// the LSB-to-MSB decoding order from the previously passing version.
5046
fn decode_16x2bit_ternary(packed_val: u32) -> array<i32, 16> {
5147
var decoded: array<i32, 16>;
5248
decoded[0] = decode_2bit((packed_val >> 0u) & 0x3u);
@@ -68,7 +64,6 @@ fn decode_16x2bit_ternary(packed_val: u32) -> array<i32, 16> {
6864
return decoded;
6965
}
7066

71-
// Vectorized dot product for better throughput
7267
fn dot_product_4x4(a: vec4<i32>, b: vec4<i32>) -> i32 {
7368
return dot(a, b);
7469
}
@@ -84,21 +79,17 @@ fn main(
8479
let tile_start_m = workgroup_id.y * TILE_DIM_M;
8580
let tile_start_n = workgroup_id.x * TILE_DIM_N;
8681

87-
// FIX 1: Use a flattened array of i32 for the accumulator to avoid the
88-
// `array<vec4<i32>>` indexing bug on the Dx12 backend.
89-
var accumulators: array<i32, 16>;
90-
for (var i = 0u; i < 16u; i = i + 1u) {
91-
accumulators[i] = 0;
82+
// Original accumulator: vectorized style
83+
var accumulators: array<vec4<i32>, THREAD_TILE_M>;
84+
for (var i = 0u; i < THREAD_TILE_M; i = i + 1u) {
85+
accumulators[i] = vec4<i32>(0);
9286
}
9387

94-
// Main tiling loop with optimizations
9588
let num_k_tiles = (metadata.K + TILE_DIM_K - 1u) / TILE_DIM_K;
9689

9790
var k_tile_idx = 0u;
9891
while (k_tile_idx < num_k_tiles) {
9992
let k_tile_start = k_tile_idx * TILE_DIM_K;
100-
// === Cooperative Loading with Coalescing ===
101-
// Load activations with vectorization
10293
let total_a_elements = TILE_DIM_M * TILE_DIM_K / 4u;
10394
let loads_per_thread_a = (total_a_elements + 255u) / 256u; // Ceiling division
10495
for (var i = 0u; i < loads_per_thread_a; i = i + 1u) {
@@ -111,7 +102,6 @@ fn main(
111102
let global_m = tile_start_m + m;
112103
let global_k = k_tile_start + k;
113104
if (global_m < metadata.M && global_k + 3u < metadata.K) {
114-
// Load 4 activations at once
115105
let base_addr = global_m * metadata.K + global_k;
116106
tile_a[vec_idx] = vec4<i32>(
117107
activations[base_addr],
@@ -124,7 +114,6 @@ fn main(
124114
}
125115
}
126116
}
127-
// Load and decode weights
128117
let total_b_elements = TILE_DIM_N * TILE_DIM_K;
129118
let loads_per_thread_b = (total_b_elements + 255u) / 256u;
130119
for (var i = 0u; i < loads_per_thread_b; i = i + 1u) {
@@ -137,26 +126,7 @@ fn main(
137126
if (global_n < metadata.N && global_k_packed_idx < metadata.K_packed) {
138127
let weight_idx = global_n * metadata.K_packed + global_k_packed_idx;
139128
let packed_w = packed_weights[weight_idx];
140-
// FIX 2: Inline the decoding logic. The Dx12 backend fails pipeline
141-
// creation when a function returns an array, as seen in the V4.2.2 test.
142-
var decoded: array<i32, 16>;
143-
decoded[0] = decode_2bit((packed_w >> 0u) & 0x3u);
144-
decoded[1] = decode_2bit((packed_w >> 2u) & 0x3u);
145-
decoded[2] = decode_2bit((packed_w >> 4u) & 0x3u);
146-
decoded[3] = decode_2bit((packed_w >> 6u) & 0x3u);
147-
decoded[4] = decode_2bit((packed_w >> 8u) & 0x3u);
148-
decoded[5] = decode_2bit((packed_w >> 10u) & 0x3u);
149-
decoded[6] = decode_2bit((packed_w >> 12u) & 0x3u);
150-
decoded[7] = decode_2bit((packed_w >> 14u) & 0x3u);
151-
decoded[8] = decode_2bit((packed_w >> 16u) & 0x3u);
152-
decoded[9] = decode_2bit((packed_w >> 18u) & 0x3u);
153-
decoded[10] = decode_2bit((packed_w >> 20u) & 0x3u);
154-
decoded[11] = decode_2bit((packed_w >> 22u) & 0x3u);
155-
decoded[12] = decode_2bit((packed_w >> 24u) & 0x3u);
156-
decoded[13] = decode_2bit((packed_w >> 26u) & 0x3u);
157-
decoded[14] = decode_2bit((packed_w >> 28u) & 0x3u);
158-
decoded[15] = decode_2bit((packed_w >> 30u) & 0x3u);
159-
// Store decoded weights (unrolled for WGSL compliance)
129+
let decoded = decode_16x2bit_ternary(packed_w);
160130
tile_b[n * TILE_DIM_K + k + 0u] = decoded[0u];
161131
tile_b[n * TILE_DIM_K + k + 1u] = decoded[1u];
162132
tile_b[n * TILE_DIM_K + k + 2u] = decoded[2u];
@@ -174,25 +144,21 @@ fn main(
174144
tile_b[n * TILE_DIM_K + k + 14u] = decoded[14u];
175145
tile_b[n * TILE_DIM_K + k + 15u] = decoded[15u];
176146
} else {
177-
// Pad with zeros
178147
for (var j = 0u; j < 16u; j = j + 1u) {
179148
tile_b[n * TILE_DIM_K + k + j] = 0;
180149
}
181150
}
182151
}
183152
}
184153
workgroupBarrier();
185-
// === Vectorized Computation ===
186154
for (var k_inner = 0u; k_inner < TILE_DIM_K; k_inner = k_inner + 4u) {
187-
// Load vectorized activations
188155
var a_vecs: array<vec4<i32>, THREAD_TILE_M>;
189156
for (var m = 0u; m < THREAD_TILE_M; m = m + 1u) {
190157
let base_m = thread_idx_m * THREAD_TILE_M + m;
191158
let vec_idx = (base_m * TILE_DIM_K + k_inner) / 4u;
192159
let a_i32 = tile_a[vec_idx];
193160
a_vecs[m] = a_i32;
194161
}
195-
// Load vectorized weights and compute
196162
for (var n = 0u; n < THREAD_TILE_N; n = n + 1u) {
197163
let base_n = thread_idx_n * THREAD_TILE_N + n;
198164
let b_vec = vec4<i32>(
@@ -201,32 +167,25 @@ fn main(
201167
tile_b[base_n * TILE_DIM_K + k_inner + 2u],
202168
tile_b[base_n * TILE_DIM_K + k_inner + 3u]
203169
);
204-
// Vectorized multiply-accumulate
205170
for (var m = 0u; m < THREAD_TILE_M; m = m + 1u) {
206171
let dot_result = dot_product_4x4(a_vecs[m], b_vec);
207-
// Manually calculate the 1D index for the flattened accumulator.
208-
let acc_idx = m * THREAD_TILE_N + n;
209-
accumulators[acc_idx] += dot_result;
172+
accumulators[m][n] += dot_result;
210173
}
211174
}
212175
}
213176
workgroupBarrier();
214177
k_tile_idx = k_tile_idx + 1u;
215178
}
216-
// === Write Results with Proper Scaling ===
217179
for (var m = 0u; m < THREAD_TILE_M; m = m + 1u) {
218180
for (var n = 0u; n < THREAD_TILE_N; n = n + 1u) {
219181
let global_m = tile_start_m + thread_idx_m * THREAD_TILE_M + m;
220182
let global_n = tile_start_n + thread_idx_n * THREAD_TILE_N + n;
221183
if (global_m < metadata.M && global_n < metadata.N) {
222-
// BitNet B1.58 scaling: result = activation_scale * weight_scale * dot_product
223184
let activation_scale = activation_scales[global_m];
224185
let weight_scale = weight_scales[global_n];
225-
// Use the manually calculated 1D index again.
226-
let acc_idx = m * THREAD_TILE_N + n;
227-
let final_result = f32(accumulators[acc_idx]) * activation_scale * weight_scale;
186+
let final_result = f32(accumulators[m][n]) * activation_scale * weight_scale;
228187
output[global_m * metadata.N + global_n] = final_result;
229188
}
230189
}
231190
}
232-
}
191+
}

crates/bitnet-core/src/wgpu_context.rs

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ impl WgpuContext {
6161
compatible_surface: None,
6262
})
6363
.await
64-
.ok_or(BitNetError::NoSuitableAdapter)?;
64+
.map_err(|_| BitNetError::NoSuitableAdapter)?;
6565

6666
let mut required_features = wgpu::Features::empty();
6767
if adapter.features().contains(wgpu::Features::TIMESTAMP_QUERY) {
@@ -74,8 +74,9 @@ impl WgpuContext {
7474
label: Some("Bitnet Device"),
7575
required_features,
7676
required_limits: Default::default(),
77-
},
78-
None,
77+
memory_hints: wgpu::MemoryHints::default(),
78+
trace: wgpu::Trace::default(),
79+
}
7980
)
8081
.await
8182
.map_err(BitNetError::RequestDeviceError)?;
@@ -103,7 +104,7 @@ impl WgpuContext {
103104
compatible_surface: None,
104105
})
105106
.await
106-
.ok_or(BitNetError::NoSuitableAdapter)?;
107+
.map_err(|_| BitNetError::NoSuitableAdapter)?;
107108

108109
let mut required_features = wgpu::Features::empty();
109110
if adapter.features().contains(wgpu::Features::TIMESTAMP_QUERY) {
@@ -116,8 +117,9 @@ impl WgpuContext {
116117
label: Some("Custom Device"),
117118
required_features,
118119
required_limits: limits,
119-
},
120-
None,
120+
memory_hints: wgpu::MemoryHints::default(),
121+
trace: wgpu::Trace::default(),
122+
}
121123
)
122124
.await
123125
.map_err(BitNetError::RequestDeviceError)?;

crates/bitnet-core/tests/DX12_test.rs

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1916,18 +1916,18 @@ struct WarmGpuContext {
19161916

19171917
impl WarmGpuContext {
19181918
async fn new() -> Option<Self> {
1919-
let instance = wgpu::Instance::new(wgpu::InstanceDescriptor {
1919+
let instance = wgpu::Instance::new(&wgpu::InstanceDescriptor {
19201920
backends: wgpu::Backends::DX12,
19211921
..Default::default()
19221922
});
19231923

19241924
let adapter = instance
19251925
.request_adapter(&wgpu::RequestAdapterOptions::default())
1926-
.await?;
1926+
.await.ok()?;
19271927

19281928
let info = adapter.get_info();
19291929
let (device, queue) = adapter // Capture queue here
1930-
.request_device(&wgpu::DeviceDescriptor::default(), None)
1930+
.request_device(&wgpu::DeviceDescriptor::default())
19311931
.await
19321932
.ok()?;
19331933

@@ -1958,7 +1958,7 @@ where
19581958
}));
19591959

19601960
let result = f();
1961-
device.poll(wgpu::Maintain::Wait);
1961+
device.poll(wgpu::MaintainBase::Wait);
19621962
device.on_uncaptured_error(Box::new(|_| {})); // Reset handler
19631963

19641964
let e = error.lock().unwrap().take();
@@ -2008,8 +2008,9 @@ async fn test_shader_compilation(name: &str, shader_source: &str, context: &Warm
20082008
label: Some(&format!("{} compute pipeline", name)),
20092009
layout: Some(&pipeline_layout),
20102010
module: &shader_module,
2011-
entry_point: "main",
2011+
entry_point: Some("main"),
20122012
compilation_options: Default::default(),
2013+
cache: None,
20132014
})
20142015
}).await;
20152016
if let Some(e) = pipeline_error {
@@ -2150,8 +2151,9 @@ async fn test_correctness_on_dx12(context: &WarmGpuContext) {
21502151
label: Some("Correctness Test Pipeline"),
21512152
layout: Some(&pipeline_layout),
21522153
module: &shader_module,
2153-
entry_point: "main",
2154+
entry_point: Some("main"),
21542155
compilation_options: Default::default(),
2156+
cache: None,
21552157
})
21562158
}).await;
21572159

@@ -2220,7 +2222,7 @@ async fn test_correctness_on_dx12(context: &WarmGpuContext) {
22202222
let buffer_slice = staging_buffer.slice(..);
22212223
let (tx, rx) = futures_intrusive::channel::shared::oneshot_channel();
22222224
buffer_slice.map_async(wgpu::MapMode::Read, move |v| tx.send(v).unwrap());
2223-
context.device.poll(wgpu::Maintain::Wait);
2225+
context.device.poll(wgpu::MaintainBase::Wait);
22242226
rx.receive().await.unwrap().unwrap();
22252227
let gpu_output: Vec<f32> = bytemuck::cast_slice(&buffer_slice.get_mapped_range()).to_vec();
22262228

0 commit comments

Comments
 (0)