105 changes: 98 additions & 7 deletions onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul.wgsl.template
@@ -3,8 +3,10 @@

#param block_size
#param n_bits
#param has_bias
#param has_zero_points
#param is_qualcomm
#param has_weight_idx

#use .getByOffset .setByOffset

@@ -75,15 +77,20 @@ fn loadSHMA(a_global_base:u32, kidx_v:u32, row: u32, col: u32)
{
return;
}

let b_value = b.getByOffset(b_global*uniforms.K16+kidx_v+col);
#if has_weight_idx
let b_weight_offset = uniforms.weight_idx * uniforms.N * uniforms.K16;
Contributor:

nit: Since b_weight_offset and b_scale_offset are constant values, would it be better to compute them on the CPU and write them into uniforms? In the shader, we could then always read them from the uniforms.

Contributor Author:

We could do that, but I intentionally did not: there are 32 experts, so we would need to compile 32 shaders.
But I was thinking that for weight_idx == 0 I could add a #if to the template with 'const b_weight_offset = 0', so everything that is not QMoE would benefit from the const, and for QMoE we would only need to compile 2 shaders.

Contributor Author:

I made a change to use a const for the weight_idx-related offset when weight_idx == 0. So only QMoE takes a tiny hit for the weight_idx.
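For context, here is a minimal sketch of the per-expert offsets this thread is discussing, collected in one place. It assumes the packed QMoE layout implied by the diff, with one weight/scale/bias slice per expert selected via the weight_idx uniform:

```wgsl
// Per-expert base offsets into the packed QMoE tensors.
// uniforms.K16 is the number of packed 16-value groups along K (i.e. K / 16).
let b_weight_offset = uniforms.weight_idx * uniforms.N * uniforms.K16;              // quantized weights
let b_scale_offset  = uniforms.weight_idx * uniforms.N * (uniforms.K / block_size); // one scale per block
let b_bias_offset   = uniforms.weight_idx * uniforms.N;                             // one bias per output column
```

When has_weight_idx is false, the template pins b_weight_offset (and b_bias_offset) to a const 0, so that part of the address arithmetic folds away at compile time and only the QMoE variant pays for the extra uniform reads.
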

let b_value = b.getByOffset(b_weight_offset + b_global * uniforms.K16 + kidx_v + col);
#else
let b_value = b.getByOffset(b_global * uniforms.K16+kidx_v + col);
#endif
let block_idx = kidx_v/(block_size/16);
let zero = mm_read_zero(b_global, block_idx, uniforms.N, uniforms.zero_blocks_per_col);
tile_B[col][row] = DequantizedFrom4BitsTo8Bits(b_value, zero);
if (col == 0)
{
// kidx_v - each kidx_v covers 16 values of k
scale_B[row] = scales_b.getByOffset(b_global*(uniforms.K/block_size) + block_idx);
let b_scale_offset = uniforms.weight_idx * uniforms.N * (uniforms.K/block_size);
scale_B[row] = scales_b.getByOffset(b_scale_offset + b_global*(uniforms.K/block_size) + block_idx);
}
}
#endif
@@ -97,13 +104,20 @@ fn loadSHMA(a_global_base:u32, kidx_v:u32, row: u32, col: u32)
return;
}

let b_value = b.getByOffset(b_global*uniforms.K16+kidx_v+col);
#if has_weight_idx
let b_weight_offset = uniforms.weight_idx * uniforms.N * uniforms.K16;
let b_value = b.getByOffset(b_weight_offset + b_global * uniforms.K16 + kidx_v + col);
#else
const b_weight_offset : u32 = 0;
let b_value = b.getByOffset(b_global * uniforms.K16 + kidx_v + col);
#endif
tile_B[col][row] = AlignWithZeroPoint(b_value);
if (col == 0)
{
// kidx_v - each kidx_v covers 16 values of k
let block_idx = kidx_v/(block_size/16);
scale_B[row] = scales_b.getByOffset(b_global*(uniforms.K/block_size) + block_idx);
let b_scale_offset = uniforms.weight_idx * uniforms.N * (uniforms.K/block_size);
scale_B[row] = scales_b.getByOffset(b_scale_offset + b_global*(uniforms.K/block_size) + block_idx);
#if has_zero_points
zeroes[row] = mm_read_zero(b_global, block_idx, uniforms.N, uniforms.zero_blocks_per_col);
#endif
@@ -119,10 +133,17 @@ fn loadSHMA(a_global_base:u32, kidx_v:u32, row: u32, col: u32)
{
return;
}
let b_value = b.getByOffset(b_global*uniforms.K16+kidx_v+col);
#if has_weight_idx
let b_weight_offset = uniforms.weight_idx * uniforms.N * uniforms.K16;
let b_value = b.getByOffset(b_weight_offset + b_global * uniforms.K16 + kidx_v + col);
#else
const b_weight_offset : u32 = 0;
let b_value = b.getByOffset(b_global * uniforms.K16 + kidx_v + col);
#endif
tile_B[col][row] = DequantizedFrom2BitsTo8Bits(b_value);
let block_idx = kidx_v/(block_size/16);
scale_B[row] = scales_b.getByOffset(b_global*(uniforms.K/block_size) + block_idx);
let b_scale_offset = uniforms.weight_idx * uniforms.N * (uniforms.K/block_size);
scale_B[row] = scales_b.getByOffset(b_scale_offset + b_global*(uniforms.K/block_size) + block_idx);
}
#endif

@@ -360,19 +381,89 @@ $MAIN {
let a_global = a_global_base + base_A + a_idx;
let b_global = b_global_base + base_B;
let output_idx = ((a_global) * uniforms.N + b_global)/4;
#if has_bias
#if has_weight_idx
let b_bias_offset = uniforms.weight_idx * uniforms.N;
#else
const b_bias_offset : u32 = 0;
#endif
#endif
// This creates a shader requirement that uniforms.N % 16 == 0
if (a_global < uniforms.M && b_global < uniforms.N)
{
#if is_qualcomm
#if has_bias
let bias_vec1 = vec4<output_element_t>(
bias[b_global + 0 + b_bias_offset],
bias[b_global + 1 + b_bias_offset],
bias[b_global + 2 + b_bias_offset],
bias[b_global + 3 + b_bias_offset]
);
let bias_vec2 = vec4<output_element_t>(
bias[b_global + 4 + b_bias_offset],
bias[b_global + 5 + b_bias_offset],
bias[b_global + 6 + b_bias_offset],
bias[b_global + 7 + b_bias_offset]
);
let bias_vec3 = vec4<output_element_t>(
bias[b_global + 8 + b_bias_offset],
bias[b_global + 9 + b_bias_offset],
bias[b_global + 10 + b_bias_offset],
bias[b_global + 11 + b_bias_offset]
);
let bias_vec4 = vec4<output_element_t>(
bias[b_global + 12 + b_bias_offset],
bias[b_global + 13 + b_bias_offset],
bias[b_global + 14 + b_bias_offset],
bias[b_global + 15 + b_bias_offset]
);
output.setByOffset(output_idx, vec4<output_element_t>(lane_outputs[0], lane_outputs[1], lane_outputs[2], lane_outputs[3]) + bias_vec1);
output.setByOffset(output_idx+1, vec4<output_element_t>(lane_outputs[4], lane_outputs[5], lane_outputs[6], lane_outputs[7]) + bias_vec2);
output.setByOffset(output_idx+2, vec4<output_element_t>(lane_outputs[8], lane_outputs[9], lane_outputs[10], lane_outputs[11]) + bias_vec3);
output.setByOffset(output_idx+3, vec4<output_element_t>(lane_outputs[12], lane_outputs[13], lane_outputs[14], lane_outputs[15]) + bias_vec4);
#else
output.setByOffset(output_idx, vec4<output_element_t>(lane_outputs[0], lane_outputs[1], lane_outputs[2], lane_outputs[3]));
output.setByOffset(output_idx+1, vec4<output_element_t>(lane_outputs[4], lane_outputs[5], lane_outputs[6], lane_outputs[7]));
output.setByOffset(output_idx+2, vec4<output_element_t>(lane_outputs[8], lane_outputs[9], lane_outputs[10], lane_outputs[11]));
output.setByOffset(output_idx+3, vec4<output_element_t>(lane_outputs[12], lane_outputs[13], lane_outputs[14], lane_outputs[15]));
#endif
#else
#if has_bias
// TODO: wanted to use vec4 for bias, but for some reason that fails the unit tests. Revisit later.
Contributor:

If you use vec4 for the bias, you need to make sure N % 4 == 0, or it will be very complicated to re-arrange the data to get the correct vec4 values.

Contributor Author:

We can sit this one out for now.
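For illustration, a hedged sketch of the vec4 variant the TODO above alludes to, assuming the bias were bound as an array of vec4 (the binding name bias_v4 is hypothetical) and taking the reviewer's constraint as given, i.e. uniforms.N % 4 == 0 so that each group of four scalars occupies exactly one vec4 element:

```wgsl
// Hypothetical vec4 bias loads (sketch, not the shipped code). Assumes
// uniforms.N % 4 == 0, which makes b_bias_offset (= weight_idx * N) 4-aligned,
// and b_global 4-aligned as well (the surrounding code already writes outputs
// in vec4 groups of 16 columns, requiring uniforms.N % 16 == 0).
let bias_base = (b_bias_offset + b_global) / 4u;
let bias_vec1 = bias_v4[bias_base];
let bias_vec2 = bias_v4[bias_base + 1u];
let bias_vec3 = bias_v4[bias_base + 2u];
let bias_vec4 = bias_v4[bias_base + 3u];
```

If N % 4 != 0, a row's four bias values can straddle two vec4 elements, and the loads would have to shuffle components instead; that is the complicated re-arrangement the reviewer warns about.
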

let bias_vec1 = vec4<output_element_t>(
bias[b_global + 0 + b_bias_offset],
bias[b_global + 1 + b_bias_offset],
bias[b_global + 2 + b_bias_offset],
bias[b_global + 3 + b_bias_offset]
);
let bias_vec2 = vec4<output_element_t>(
bias[b_global + 4 + b_bias_offset],
bias[b_global + 5 + b_bias_offset],
bias[b_global + 6 + b_bias_offset],
bias[b_global + 7 + b_bias_offset]
);
let bias_vec3 = vec4<output_element_t>(
bias[b_global + 8 + b_bias_offset],
bias[b_global + 9 + b_bias_offset],
bias[b_global + 10 + b_bias_offset],
bias[b_global + 11 + b_bias_offset]
);
let bias_vec4 = vec4<output_element_t>(
bias[b_global + 12 + b_bias_offset],
bias[b_global + 13 + b_bias_offset],
bias[b_global + 14 + b_bias_offset],
bias[b_global + 15 + b_bias_offset]
);
output.setByOffset(output_idx, lane_output1 + bias_vec1);
output.setByOffset(output_idx+1, lane_output2 + bias_vec2);
output.setByOffset(output_idx+2, lane_output3 + bias_vec3);
output.setByOffset(output_idx+3, lane_output4 + bias_vec4);
#else
output.setByOffset(output_idx, lane_output1);
output.setByOffset(output_idx+1, lane_output2);
output.setByOffset(output_idx+2, lane_output3);
output.setByOffset(output_idx+3, lane_output4);
#endif
#endif
}
} // MAIN
37 changes: 29 additions & 8 deletions onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc
@@ -27,9 +27,14 @@ Status DP4AMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const {
if (has_zero_points_) {
shader.AddInput("zero_points", ShaderUsage::UseUniform);
}
if (has_bias_) {
shader.AddInput("bias", ShaderUsage::UseUniform);
}
const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseElementTypeAlias);
return WGSL_TEMPLATE_APPLY(shader, "quantization/dp4a_matmul.wgsl.template",
WGSL_TEMPLATE_PARAMETER(block_size, block_size_),
WGSL_TEMPLATE_PARAMETER(has_bias, has_bias_),
WGSL_TEMPLATE_PARAMETER(has_weight_idx, has_weight_idx_),
WGSL_TEMPLATE_PARAMETER(has_zero_points, has_zero_points_),
WGSL_TEMPLATE_PARAMETER(is_qualcomm, is_qualcomm_),
WGSL_TEMPLATE_PARAMETER(n_bits, nbits_),
@@ -50,13 +55,18 @@ Status DP4AMatMulNBitsSmallMProgram::GenerateShaderCode(ShaderHelper& shader) co
if (has_zero_points_) {
shader.AddInput("zero_points", ShaderUsage::UseUniform);
}
if (has_bias_) {
shader.AddInput("bias", ShaderUsage::UseUniform);
}
const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseElementTypeAlias);

ORT_ENFORCE(WorkgroupSizeX() % tile_size_k_vec_ == 0 && tile_size_k_vec_ % 4 == 0, "tile_size_k_vec_ must evenly divide workgroup size X and be divisible by 4");
const uint32_t sub_tile_count = WorkgroupSizeX() / tile_size_k_vec_;
ORT_ENFORCE(tile_size_ % sub_tile_count == 0, "tile_size_ must be divisible by sub_tile_count");

return WGSL_TEMPLATE_APPLY(shader, "quantization/dp4a_matmul_small_m.wgsl.template",
WGSL_TEMPLATE_PARAMETER(has_bias, has_bias_),
WGSL_TEMPLATE_PARAMETER(has_weight_idx, has_weight_idx_),
WGSL_TEMPLATE_PARAMETER(has_zero_points, has_zero_points_),
WGSL_TEMPLATE_PARAMETER(n_bits, nbits_),
WGSL_TEMPLATE_PARAMETER(output_type_i32, true),
@@ -72,7 +82,7 @@ Status DP4AMatMulNBitsSmallMProgram::GenerateShaderCode(ShaderHelper& shader) co
}

Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor* scales,
const Tensor* zero_points,
const Tensor* zero_points, const Tensor* bias,
uint32_t M,
uint32_t N,
uint32_t K,
@@ -81,7 +91,8 @@ Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor
uint32_t min_M_for_tile_optimization,
uint32_t nbits,
onnxruntime::webgpu::ComputeContext& context,
Tensor* y) {
Tensor* y,
const uint32_t weight_index) {
constexpr uint32_t kVec4Components = 4;
constexpr uint32_t kVec2Components = 2;
constexpr uint32_t kU32Components = 4;
@@ -100,7 +111,10 @@ Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor
{&a_scale, ProgramTensorMetadataDependency::Rank, 1}})
.AddUniformVariable({M * K / kU32Components});
ORT_RETURN_IF_ERROR(context.RunProgram(quantize_program));

const bool has_zero_points = zero_points != nullptr;
const bool has_bias = bias != nullptr;
const bool has_weight_idx = weight_index != 0;
const bool single_scale_weights = (block_size == K * N);
if (M < min_M_for_tile_optimization) {
uint32_t tile_size_k_vec = 16;
@@ -111,20 +125,23 @@ Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor
tile_size_n = 4;
}
const uint32_t b_components = (nbits == 2 ? kVec2Components : kVec4Components);
DP4AMatMulNBitsSmallMProgram mul_program{tile_size_k_vec, tile_size_n, nbits, has_zero_points, single_scale_weights};
DP4AMatMulNBitsSmallMProgram mul_program{tile_size_k_vec, tile_size_n, nbits, has_zero_points, has_bias, has_weight_idx, single_scale_weights};
uint32_t num_N_tile = (N + tile_size_n - 1) / tile_size_n;
mul_program.SetWorkgroupSize(128);
mul_program.SetDispatchGroupSize(M * num_N_tile);
mul_program.AddInputs({{&a_quant, ProgramTensorMetadataDependency::TypeAndRank, static_cast<int>(kVec4Components)},
{&a_scale, ProgramTensorMetadataDependency::TypeAndRank, 1},
{b, ProgramTensorMetadataDependency::TypeAndRank, static_cast<int>(b_components * kU32Components)},
{scales, ProgramTensorMetadataDependency::TypeAndRank, 1}})
.AddUniformVariables({M, N, K, K / 16, K / 32, block_size, num_N_tile, zero_blocks_per_col})
.AddUniformVariables({M, N, K, K / 16, K / 32, block_size, num_N_tile, zero_blocks_per_col, weight_index})
.AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, 1})
.CacheHint(nbits, tile_size_k_vec, tile_size_n, has_zero_points, single_scale_weights);
.CacheHint(nbits, tile_size_k_vec, tile_size_n, has_zero_points, single_scale_weights, has_bias, has_weight_idx);
if (has_zero_points) {
mul_program.AddInput({zero_points, ProgramTensorMetadataDependency::None, {(zero_points->Shape().Size() + 3) / 4}, 4});
}
if (has_bias) {
mul_program.AddInput({bias, ProgramTensorMetadataDependency::None});
}
return context.RunProgram(mul_program);
}

@@ -133,7 +150,7 @@ Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor
uint32_t num_M_tile = (M + kTileSize - 1) / kTileSize;
uint32_t num_N_tile = (N + kTileSize - 1) / kTileSize;
bool is_qualcomm = context.AdapterInfo().vendor == std::string_view{"qualcomm"};
DP4AMatMulNBitsProgram mul_program{block_size, nbits, has_zero_points, is_qualcomm};
DP4AMatMulNBitsProgram mul_program{block_size, nbits, has_zero_points, has_bias, has_weight_idx, is_qualcomm};
mul_program.SetWorkgroupSize(256);
mul_program.SetDispatchGroupSize(num_M_tile * num_N_tile);
mul_program.AddInputs({{&a_quant, ProgramTensorMetadataDependency::TypeAndRank, static_cast<int>(kVec4Components)},
@@ -146,12 +163,16 @@ Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor
{static_cast<uint32_t>(K / 8)},
{static_cast<uint32_t>(K / 16)},
{num_N_tile},
{zero_blocks_per_col}})
{zero_blocks_per_col},
{weight_index}})
.AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, reshaped_y_shape, static_cast<int>(kVec4Components)})
.CacheHint("Block" + std::to_string(block_size), nbits, has_zero_points, is_qualcomm);
.CacheHint("Block" + std::to_string(block_size), nbits, has_zero_points, is_qualcomm, has_bias, has_weight_idx);
if (has_zero_points) {
mul_program.AddInput({zero_points, ProgramTensorMetadataDependency::None, {(zero_points->Shape().Size() + 3) / 4}, 4});
}
if (has_bias) {
mul_program.AddInput({bias, ProgramTensorMetadataDependency::None});
}
return context.RunProgram(mul_program);
}
