1- // bitnet_kernel_optimal.wgsl
2- // Optimized BitNet B1.58 Ternary Kernel for WGPU (Optimal, but fails on DX12)
1+ // bitnet_kernel_pre_naga.wgsl
2+ // Optimized BitNet B1.58 Ternary Kernel for WGPU (Original, before any DX12/Naga workaround)
33// Supports {-1, 0, +1} ternary weights with efficient packing and vectorization
44
55struct BitnetMetadata {
@@ -17,11 +17,11 @@ struct BitnetMetadata {
1717@group (0 ) @binding (5 ) var <storage , read_write > output : array <f32 >;
1818
1919// Optimized tiling parameters for modern GPUs
20- const TILE_DIM_M : u32 = 64u ; // Reduced for better occupancy
20+ const TILE_DIM_M : u32 = 64u ;
2121const TILE_DIM_N : u32 = 64u ;
22- const TILE_DIM_K : u32 = 32u ; // Increased K tile for better data reuse
22+ const TILE_DIM_K : u32 = 32u ;
2323
24- const THREAD_TILE_M : u32 = 4u ; // Smaller thread tiles for better vectorization
24+ const THREAD_TILE_M : u32 = 4u ;
2525const THREAD_TILE_N : u32 = 4u ;
2626
2727const WORKGROUP_SIZE_X : u32 = 16u ; // TILE_DIM_N / THREAD_TILE_N
@@ -31,12 +31,10 @@ const WORKGROUP_SIZE_Y: u32 = 16u; // TILE_DIM_M / THREAD_TILE_M
3131const TILE_A_SIZE : u32 = (TILE_DIM_M * TILE_DIM_K ) / 4u ; // for vec4<i32>
3232const TILE_B_SIZE : u32 = TILE_DIM_K * TILE_DIM_N ; // for i32
3333
34- // Shared memory with better alignment
3534var <workgroup > tile_a : array <vec4 <i32 >, TILE_A_SIZE >;
3635var <workgroup > tile_b : array <i32 , TILE_B_SIZE >;
3736
38- // Use direct decode function for ternary weights, matching the logic
39- // from the previously passing tests.
37+ // Use direct decode function for ternary weights
4038fn decode_2bit (val : u32 ) -> i32 {
4139 switch (val ) {
4240 case 1u : { return 1 ; } // 01
@@ -45,8 +43,6 @@ fn decode_2bit(val: u32) -> i32 {
4543 }
4644}
4745
48- // Unroll the decoding loop for maximum compatibility, while preserving
49- // the LSB-to-MSB decoding order from the previously passing version.
5046fn decode_16x2bit_ternary (packed_val : u32 ) -> array <i32 , 16 > {
5147 var decoded : array <i32 , 16 >;
5248 decoded [0 ] = decode_2bit ((packed_val >> 0u ) & 0x3u );
@@ -68,7 +64,6 @@ fn decode_16x2bit_ternary(packed_val: u32) -> array<i32, 16> {
6864 return decoded ;
6965}
7066
71- // Vectorized dot product for better throughput
// Four-lane signed-integer dot product.
// Thin named wrapper over the WGSL built-in `dot` so the matmul inner loop
// reads as an explicit 4-wide multiply-accumulate primitive.
fn dot_product_4x4(lhs: vec4<i32>, rhs: vec4<i32>) -> i32 {
    return dot(lhs, rhs);
}
@@ -84,21 +79,17 @@ fn main(
8479 let tile_start_m = workgroup_id . y * TILE_DIM_M ;
8580 let tile_start_n = workgroup_id . x * TILE_DIM_N ;
8681
87- // FIX 1: Use a flattened array of i32 for the accumulator to avoid the
88- // `array<vec4<i32>>` indexing bug on the Dx12 backend.
89- var accumulators : array <i32 , 16 >;
90- for (var i = 0u ; i < 16u ; i = i + 1u ) {
91- accumulators [i ] = 0 ;
82+ // Original accumulator: vectorized style
83+ var accumulators : array <vec4 <i32 >, THREAD_TILE_M >;
84+ for (var i = 0u ; i < THREAD_TILE_M ; i = i + 1u ) {
85+ accumulators [i ] = vec4 <i32 >(0 );
9286 }
9387
94- // Main tiling loop with optimizations
9588 let num_k_tiles = (metadata . K + TILE_DIM_K - 1u ) / TILE_DIM_K ;
9689
9790 var k_tile_idx = 0u ;
9891 while (k_tile_idx < num_k_tiles ) {
9992 let k_tile_start = k_tile_idx * TILE_DIM_K ;
100- // === Cooperative Loading with Coalescing ===
101- // Load activations with vectorization
10293 let total_a_elements = TILE_DIM_M * TILE_DIM_K / 4u ;
10394 let loads_per_thread_a = (total_a_elements + 255u ) / 256u ; // Ceiling division
10495 for (var i = 0u ; i < loads_per_thread_a ; i = i + 1u ) {
@@ -111,7 +102,6 @@ fn main(
111102 let global_m = tile_start_m + m ;
112103 let global_k = k_tile_start + k ;
113104 if (global_m < metadata . M && global_k + 3u < metadata . K ) {
114- // Load 4 activations at once
115105 let base_addr = global_m * metadata . K + global_k ;
116106 tile_a [vec_idx ] = vec4 <i32 >(
117107 activations [base_addr ],
@@ -124,7 +114,6 @@ fn main(
124114 }
125115 }
126116 }
127- // Load and decode weights
128117 let total_b_elements = TILE_DIM_N * TILE_DIM_K ;
129118 let loads_per_thread_b = (total_b_elements + 255u ) / 256u ;
130119 for (var i = 0u ; i < loads_per_thread_b ; i = i + 1u ) {
@@ -137,26 +126,7 @@ fn main(
137126 if (global_n < metadata . N && global_k_packed_idx < metadata . K_packed ) {
138127 let weight_idx = global_n * metadata . K_packed + global_k_packed_idx ;
139128 let packed_w = packed_weights [weight_idx ];
140- // FIX 2: Inline the decoding logic. The Dx12 backend fails pipeline
141- // creation when a function returns an array, as seen in the V4.2.2 test.
142- var decoded : array <i32 , 16 >;
143- decoded [0 ] = decode_2bit ((packed_w >> 0u ) & 0x3u );
144- decoded [1 ] = decode_2bit ((packed_w >> 2u ) & 0x3u );
145- decoded [2 ] = decode_2bit ((packed_w >> 4u ) & 0x3u );
146- decoded [3 ] = decode_2bit ((packed_w >> 6u ) & 0x3u );
147- decoded [4 ] = decode_2bit ((packed_w >> 8u ) & 0x3u );
148- decoded [5 ] = decode_2bit ((packed_w >> 10u ) & 0x3u );
149- decoded [6 ] = decode_2bit ((packed_w >> 12u ) & 0x3u );
150- decoded [7 ] = decode_2bit ((packed_w >> 14u ) & 0x3u );
151- decoded [8 ] = decode_2bit ((packed_w >> 16u ) & 0x3u );
152- decoded [9 ] = decode_2bit ((packed_w >> 18u ) & 0x3u );
153- decoded [10 ] = decode_2bit ((packed_w >> 20u ) & 0x3u );
154- decoded [11 ] = decode_2bit ((packed_w >> 22u ) & 0x3u );
155- decoded [12 ] = decode_2bit ((packed_w >> 24u ) & 0x3u );
156- decoded [13 ] = decode_2bit ((packed_w >> 26u ) & 0x3u );
157- decoded [14 ] = decode_2bit ((packed_w >> 28u ) & 0x3u );
158- decoded [15 ] = decode_2bit ((packed_w >> 30u ) & 0x3u );
159- // Store decoded weights (unrolled for WGSL compliance)
129+ let decoded = decode_16x2bit_ternary (packed_w );
160130 tile_b [n * TILE_DIM_K + k + 0u ] = decoded [0u ];
161131 tile_b [n * TILE_DIM_K + k + 1u ] = decoded [1u ];
162132 tile_b [n * TILE_DIM_K + k + 2u ] = decoded [2u ];
@@ -174,25 +144,21 @@ fn main(
174144 tile_b [n * TILE_DIM_K + k + 14u ] = decoded [14u ];
175145 tile_b [n * TILE_DIM_K + k + 15u ] = decoded [15u ];
176146 } else {
177- // Pad with zeros
178147 for (var j = 0u ; j < 16u ; j = j + 1u ) {
179148 tile_b [n * TILE_DIM_K + k + j ] = 0 ;
180149 }
181150 }
182151 }
183152 }
184153 workgroupBarrier ();
185- // === Vectorized Computation ===
186154 for (var k_inner = 0u ; k_inner < TILE_DIM_K ; k_inner = k_inner + 4u ) {
187- // Load vectorized activations
188155 var a_vecs : array <vec4 <i32 >, THREAD_TILE_M >;
189156 for (var m = 0u ; m < THREAD_TILE_M ; m = m + 1u ) {
190157 let base_m = thread_idx_m * THREAD_TILE_M + m ;
191158 let vec_idx = (base_m * TILE_DIM_K + k_inner ) / 4u ;
192159 let a_i32 = tile_a [vec_idx ];
193160 a_vecs [m ] = a_i32 ;
194161 }
195- // Load vectorized weights and compute
196162 for (var n = 0u ; n < THREAD_TILE_N ; n = n + 1u ) {
197163 let base_n = thread_idx_n * THREAD_TILE_N + n ;
198164 let b_vec = vec4 <i32 >(
@@ -201,32 +167,25 @@ fn main(
201167 tile_b [base_n * TILE_DIM_K + k_inner + 2u ],
202168 tile_b [base_n * TILE_DIM_K + k_inner + 3u ]
203169 );
204- // Vectorized multiply-accumulate
205170 for (var m = 0u ; m < THREAD_TILE_M ; m = m + 1u ) {
206171 let dot_result = dot_product_4x4 (a_vecs [m ], b_vec );
207- // Manually calculate the 1D index for the flattened accumulator.
208- let acc_idx = m * THREAD_TILE_N + n ;
209- accumulators [acc_idx ] += dot_result ;
172+ accumulators [m ][n ] += dot_result ;
210173 }
211174 }
212175 }
213176 workgroupBarrier ();
214177 k_tile_idx = k_tile_idx + 1u ;
215178 }
216- // === Write Results with Proper Scaling ===
217179 for (var m = 0u ; m < THREAD_TILE_M ; m = m + 1u ) {
218180 for (var n = 0u ; n < THREAD_TILE_N ; n = n + 1u ) {
219181 let global_m = tile_start_m + thread_idx_m * THREAD_TILE_M + m ;
220182 let global_n = tile_start_n + thread_idx_n * THREAD_TILE_N + n ;
221183 if (global_m < metadata . M && global_n < metadata . N ) {
222- // BitNet B1.58 scaling: result = activation_scale * weight_scale * dot_product
223184 let activation_scale = activation_scales [global_m ];
224185 let weight_scale = weight_scales [global_n ];
225- // Use the manually calculated 1D index again.
226- let acc_idx = m * THREAD_TILE_N + n ;
227- let final_result = f32 (accumulators [acc_idx ]) * activation_scale * weight_scale ;
186+ let final_result = f32 (accumulators [m ][n ]) * activation_scale * weight_scale ;
228187 output [global_m * metadata . N + global_n ] = final_result ;
229188 }
230189 }
231190 }
232- }
191+ }
0 commit comments