-
Couldn't load subscription status.
- Fork 3.5k
webgpu / nbitmm support for bias and weight_index #26392
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
f5e8c86
36c3a23
6e03276
cab3663
36d12fb
897c7b7
273a200
91801e7
34d736f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3,8 +3,10 @@ | |
|
|
||
| #param block_size | ||
| #param n_bits | ||
| #param has_bias | ||
| #param has_zero_points | ||
| #param is_qualcomm | ||
| #param has_weight_idx | ||
|
|
||
| #use .getByOffset .setByOffset | ||
|
|
||
|
|
@@ -75,15 +77,20 @@ fn loadSHMA(a_global_base:u32, kidx_v:u32, row: u32, col: u32) | |
| { | ||
| return; | ||
| } | ||
|
|
||
| let b_value = b.getByOffset(b_global*uniforms.K16+kidx_v+col); | ||
| #if has_weight_idx | ||
| let b_weight_offset = uniforms.weight_idx * uniforms.N * uniforms.K16; | ||
| let b_value = b.getByOffset(b_weight_offset + b_global * uniforms.K16 + kidx_v + col); | ||
| #else | ||
| let b_value = b.getByOffset(b_global * uniforms.K16+kidx_v + col); | ||
| #endif | ||
| let block_idx = kidx_v/(block_size/16); | ||
| let zero = mm_read_zero(b_global, block_idx, uniforms.N, uniforms.zero_blocks_per_col); | ||
| tile_B[col][row] = DequantizedFrom4BitsTo8Bits(b_value, zero); | ||
| if (col == 0) | ||
| { | ||
| // kidx_v - each kidx_v covers 16 values of k | ||
| scale_B[row] = scales_b.getByOffset(b_global*(uniforms.K/block_size) + block_idx); | ||
| let b_scale_offset = uniforms.weight_idx * uniforms.N * (uniforms.K/block_size); | ||
| scale_B[row] = scales_b.getByOffset(b_scale_offset + b_global*(uniforms.K/block_size) + block_idx); | ||
| } | ||
| } | ||
| #endif | ||
|
|
@@ -97,13 +104,20 @@ fn loadSHMA(a_global_base:u32, kidx_v:u32, row: u32, col: u32) | |
| return; | ||
| } | ||
|
|
||
| let b_value = b.getByOffset(b_global*uniforms.K16+kidx_v+col); | ||
| #if has_weight_idx | ||
| let b_weight_offset = uniforms.weight_idx * uniforms.N * uniforms.K16; | ||
| let b_value = b.getByOffset(b_weight_offset + b_global * uniforms.K16 + kidx_v + col); | ||
| #else | ||
| const b_weight_offset : u32 = 0; | ||
| let b_value = b.getByOffset(b_global * uniforms.K16 + kidx_v + col); | ||
| #endif | ||
| tile_B[col][row] = AlignWithZeroPoint(b_value); | ||
| if (col == 0) | ||
| { | ||
| // kidx_v - each kidx_v covers 16 values of k | ||
| let block_idx = kidx_v/(block_size/16); | ||
| scale_B[row] = scales_b.getByOffset(b_global*(uniforms.K/block_size) + block_idx); | ||
| let b_scale_offset = uniforms.weight_idx * uniforms.N * (uniforms.K/block_size); | ||
| scale_B[row] = scales_b.getByOffset(b_scale_offset + b_global*(uniforms.K/block_size) + block_idx); | ||
| #if has_zero_points | ||
| zeroes[row] = mm_read_zero(b_global, block_idx, uniforms.N, uniforms.zero_blocks_per_col); | ||
| #endif | ||
|
|
@@ -119,10 +133,17 @@ fn loadSHMA(a_global_base:u32, kidx_v:u32, row: u32, col: u32) | |
| { | ||
| return; | ||
| } | ||
| let b_value = b.getByOffset(b_global*uniforms.K16+kidx_v+col); | ||
| #if has_weight_idx | ||
| let b_weight_offset = uniforms.weight_idx * uniforms.N * uniforms.K16; | ||
| let b_value = b.getByOffset(b_weight_offset + b_global * uniforms.K16 + kidx_v + col); | ||
| #else | ||
| const b_weight_offset : u32 = 0; | ||
| let b_value = b.getByOffset(b_global * uniforms.K16 + kidx_v + col); | ||
| #endif | ||
| tile_B[col][row] = DequantizedFrom2BitsTo8Bits(b_value); | ||
| let block_idx = kidx_v/(block_size/16); | ||
| scale_B[row] = scales_b.getByOffset(b_global*(uniforms.K/block_size) + block_idx); | ||
| let b_scale_offset = uniforms.weight_idx * uniforms.N * (uniforms.K/block_size); | ||
| scale_B[row] = scales_b.getByOffset(b_scale_offset + b_global*(uniforms.K/block_size) + block_idx); | ||
| } | ||
| #endif | ||
|
|
||
|
|
@@ -360,19 +381,89 @@ $MAIN { | |
| let a_global = a_global_base + base_A + a_idx; | ||
| let b_global = b_global_base + base_B; | ||
| let output_idx = ((a_global) * uniforms.N + b_global)/4; | ||
| #if has_bias | ||
| #if has_weight_idx | ||
| let b_bias_offset = uniforms.weight_idx * uniforms.N; | ||
| #else | ||
| const b_bias_offset : u32 = 0; | ||
| #endif | ||
| #endif | ||
| // This creates a shader requirement that uniforms.N % 16 == 0 | ||
| if (a_global < uniforms.M && b_global < uniforms.N) | ||
| { | ||
| #if is_qualcomm | ||
| #if has_bias | ||
| let bias_vec1 = vec4<output_element_t>( | ||
| bias[b_global + 0 + b_bias_offset], | ||
| bias[b_global + 1 + b_bias_offset], | ||
| bias[b_global + 2 + b_bias_offset], | ||
| bias[b_global + 3 + b_bias_offset] | ||
| ); | ||
| let bias_vec2 = vec4<output_element_t>( | ||
| bias[b_global + 4 + b_bias_offset], | ||
| bias[b_global + 5 + b_bias_offset], | ||
| bias[b_global + 6 + b_bias_offset], | ||
| bias[b_global + 7 + b_bias_offset] | ||
| ); | ||
| let bias_vec3 = vec4<output_element_t>( | ||
| bias[b_global + 8 + b_bias_offset], | ||
| bias[b_global + 9 + b_bias_offset], | ||
| bias[b_global + 10 + b_bias_offset], | ||
| bias[b_global + 11 + b_bias_offset] | ||
| ); | ||
| let bias_vec4 = vec4<output_element_t>( | ||
| bias[b_global + 12 + b_bias_offset], | ||
| bias[b_global + 13 + b_bias_offset], | ||
| bias[b_global + 14 + b_bias_offset], | ||
| bias[b_global + 15 + b_bias_offset] | ||
| ); | ||
| output.setByOffset(output_idx, vec4<output_element_t>(lane_outputs[0], lane_outputs[1], lane_outputs[2], lane_outputs[3]) + bias_vec1); | ||
| output.setByOffset(output_idx+1, vec4<output_element_t>(lane_outputs[4], lane_outputs[5], lane_outputs[6], lane_outputs[7]) + bias_vec2); | ||
| output.setByOffset(output_idx+2, vec4<output_element_t>(lane_outputs[8], lane_outputs[9], lane_outputs[10], lane_outputs[11]) + bias_vec3); | ||
| output.setByOffset(output_idx+3, vec4<output_element_t>(lane_outputs[12], lane_outputs[13], lane_outputs[14], lane_outputs[15]) + bias_vec4); | ||
| #else | ||
| output.setByOffset(output_idx, vec4<output_element_t>(lane_outputs[0], lane_outputs[1], lane_outputs[2], lane_outputs[3])); | ||
| output.setByOffset(output_idx+1, vec4<output_element_t>(lane_outputs[4], lane_outputs[5], lane_outputs[6], lane_outputs[7])); | ||
| output.setByOffset(output_idx+2, vec4<output_element_t>(lane_outputs[8], lane_outputs[9], lane_outputs[10], lane_outputs[11])); | ||
| output.setByOffset(output_idx+3, vec4<output_element_t>(lane_outputs[12], lane_outputs[13], lane_outputs[14], lane_outputs[15])); | ||
| #endif | ||
| #else | ||
| #if has_bias | ||
| // TODO: wanted to use vec4 for bias but for some reason that fails ut. Later. | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Using vec4 for bias, you need to make sure N % 4 == 0; otherwise it will be very complicated to rearrange the data to get the correct vec4 values. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we can sit this one out |
||
| let bias_vec1 = vec4<output_element_t>( | ||
| bias[b_global + 0 + b_bias_offset], | ||
| bias[b_global + 1 + b_bias_offset], | ||
| bias[b_global + 2 + b_bias_offset], | ||
| bias[b_global + 3 + b_bias_offset] | ||
| ); | ||
| let bias_vec2 = vec4<output_element_t>( | ||
| bias[b_global + 4 + b_bias_offset], | ||
| bias[b_global + 5 + b_bias_offset], | ||
| bias[b_global + 6 + b_bias_offset], | ||
| bias[b_global + 7 + b_bias_offset] | ||
| ); | ||
| let bias_vec3 = vec4<output_element_t>( | ||
| bias[b_global + 8 + b_bias_offset], | ||
| bias[b_global + 9 + b_bias_offset], | ||
| bias[b_global + 10 + b_bias_offset], | ||
| bias[b_global + 11 + b_bias_offset] | ||
| ); | ||
| let bias_vec4 = vec4<output_element_t>( | ||
| bias[b_global + 12 + b_bias_offset], | ||
| bias[b_global + 13 + b_bias_offset], | ||
| bias[b_global + 14 + b_bias_offset], | ||
| bias[b_global + 15 + b_bias_offset] | ||
| ); | ||
| output.setByOffset(output_idx, lane_output1 + bias_vec1); | ||
| output.setByOffset(output_idx+1, lane_output2 + bias_vec2); | ||
| output.setByOffset(output_idx+2, lane_output3 + bias_vec3); | ||
| output.setByOffset(output_idx+3, lane_output4 + bias_vec4); | ||
| #else | ||
| output.setByOffset(output_idx, lane_output1); | ||
| output.setByOffset(output_idx+1, lane_output2); | ||
| output.setByOffset(output_idx+2, lane_output3); | ||
| output.setByOffset(output_idx+3, lane_output4); | ||
| #endif | ||
| #endif | ||
| } | ||
| } // MAIN | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: Since
`b_weight_offset` and `b_scale_offset` are constant values, would it be better to calculate them on the CPU and write them into the uniforms? In the shader, we could then always read them from the uniforms.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I could do that, but I intentionally did not, because there are 32 experts, so we'd need to compile 32 shaders.
Instead, I was thinking that for weight_idx == 0 I could use a #if in the template with `const b_weight_offset = 0`, so everything that is not QMoE would benefit from the const, and for QMoE we'd only need to compile 2 shaders.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I made a change to use const for weight_idx related offset if weight_idx == 0. So only QMoE takes a tiny hit for the weight_idx.