diff --git a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul.wgsl.template b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul.wgsl.template index eebe329c104e7..279a5f97eb3ba 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul.wgsl.template @@ -3,8 +3,10 @@ #param block_size #param n_bits +#param has_bias #param has_zero_points #param is_qualcomm +#param has_weight_idx #use .getByOffset .setByOffset @@ -75,15 +77,20 @@ fn loadSHMA(a_global_base:u32, kidx_v:u32, row: u32, col: u32) { return; } - - let b_value = b.getByOffset(b_global*uniforms.K16+kidx_v+col); +#if has_weight_idx + let b_weight_offset = uniforms.weight_idx * uniforms.N * uniforms.K16; + let b_value = b.getByOffset(b_weight_offset + b_global * uniforms.K16 + kidx_v + col); +#else + let b_value = b.getByOffset(b_global * uniforms.K16+kidx_v + col); +#endif let block_idx = kidx_v/(block_size/16); let zero = mm_read_zero(b_global, block_idx, uniforms.N, uniforms.zero_blocks_per_col); tile_B[col][row] = DequantizedFrom4BitsTo8Bits(b_value, zero); if (col == 0) { // kidx_v - each kidx_v covers 16 values of k - scale_B[row] = scales_b.getByOffset(b_global*(uniforms.K/block_size) + block_idx); + let b_scale_offset = uniforms.weight_idx * uniforms.N * (uniforms.K/block_size); + scale_B[row] = scales_b.getByOffset(b_scale_offset + b_global*(uniforms.K/block_size) + block_idx); } } #endif @@ -97,13 +104,20 @@ fn loadSHMA(a_global_base:u32, kidx_v:u32, row: u32, col: u32) return; } - let b_value = b.getByOffset(b_global*uniforms.K16+kidx_v+col); +#if has_weight_idx + let b_weight_offset = uniforms.weight_idx * uniforms.N * uniforms.K16; + let b_value = b.getByOffset(b_weight_offset + b_global * uniforms.K16 + kidx_v + col); +#else + const b_weight_offset : u32 = 0; + let b_value = b.getByOffset(b_global * uniforms.K16 + kidx_v + col); +#endif tile_B[col][row] = AlignWithZeroPoint(b_value); if (col == 0) { // kidx_v - each kidx_v covers 16 values of k let block_idx = kidx_v/(block_size/16); - scale_B[row] = scales_b.getByOffset(b_global*(uniforms.K/block_size) + block_idx); + let b_scale_offset = uniforms.weight_idx * uniforms.N * (uniforms.K/block_size); + scale_B[row] = scales_b.getByOffset(b_scale_offset + b_global*(uniforms.K/block_size) + block_idx); #if has_zero_points zeroes[row] = mm_read_zero(b_global, block_idx, uniforms.N, uniforms.zero_blocks_per_col); #endif @@ -119,10 +133,17 @@ fn loadSHMA(a_global_base:u32, kidx_v:u32, row: u32, col: u32) { return; } - let b_value = b.getByOffset(b_global*uniforms.K16+kidx_v+col); +#if has_weight_idx + let b_weight_offset = uniforms.weight_idx * uniforms.N * uniforms.K16; + let b_value = b.getByOffset(b_weight_offset + b_global * uniforms.K16 + kidx_v + col); +#else + const b_weight_offset : u32 = 0; + let b_value = b.getByOffset(b_global * uniforms.K16 + kidx_v + col); +#endif tile_B[col][row] = DequantizedFrom2BitsTo8Bits(b_value); let block_idx = kidx_v/(block_size/16); - scale_B[row] = scales_b.getByOffset(b_global*(uniforms.K/block_size) + block_idx); + let b_scale_offset = uniforms.weight_idx * uniforms.N * (uniforms.K/block_size); + scale_B[row] = scales_b.getByOffset(b_scale_offset + b_global*(uniforms.K/block_size) + block_idx); } #endif @@ -360,19 +381,89 @@ $MAIN { let a_global = a_global_base + base_A + a_idx; let b_global = b_global_base + base_B; let output_idx = ((a_global) * uniforms.N + b_global)/4; +#if has_bias +#if has_weight_idx + let b_bias_offset = 
uniforms.weight_idx * uniforms.N;
+#else
+  const b_bias_offset : u32 = 0;
+#endif
+#endif
   // This creates a shader requirement that uniforms.N % 16 == 0
   if (a_global < uniforms.M && b_global < uniforms.N) {
 #if is_qualcomm
+  #if has_bias
+    let bias_vec1 = vec4(
+      bias[b_global + 0 + b_bias_offset],
+      bias[b_global + 1 + b_bias_offset],
+      bias[b_global + 2 + b_bias_offset],
+      bias[b_global + 3 + b_bias_offset]
+    );
+    let bias_vec2 = vec4(
+      bias[b_global + 4 + b_bias_offset],
+      bias[b_global + 5 + b_bias_offset],
+      bias[b_global + 6 + b_bias_offset],
+      bias[b_global + 7 + b_bias_offset]
+    );
+    let bias_vec3 = vec4(
+      bias[b_global + 8 + b_bias_offset],
+      bias[b_global + 9 + b_bias_offset],
+      bias[b_global + 10 + b_bias_offset],
+      bias[b_global + 11 + b_bias_offset]
+    );
+    let bias_vec4 = vec4(
+      bias[b_global + 12 + b_bias_offset],
+      bias[b_global + 13 + b_bias_offset],
+      bias[b_global + 14 + b_bias_offset],
+      bias[b_global + 15 + b_bias_offset]
+    );
+    output.setByOffset(output_idx, vec4(lane_outputs[0], lane_outputs[1], lane_outputs[2], lane_outputs[3]) + bias_vec1);
+    output.setByOffset(output_idx+1, vec4(lane_outputs[4], lane_outputs[5], lane_outputs[6], lane_outputs[7]) + bias_vec2);
+    output.setByOffset(output_idx+2, vec4(lane_outputs[8], lane_outputs[9], lane_outputs[10], lane_outputs[11]) + bias_vec3);
+    output.setByOffset(output_idx+3, vec4(lane_outputs[12], lane_outputs[13], lane_outputs[14], lane_outputs[15]) + bias_vec4);
+  #else
     output.setByOffset(output_idx, vec4(lane_outputs[0], lane_outputs[1], lane_outputs[2], lane_outputs[3]));
     output.setByOffset(output_idx+1, vec4(lane_outputs[4], lane_outputs[5], lane_outputs[6], lane_outputs[7]));
     output.setByOffset(output_idx+2, vec4(lane_outputs[8], lane_outputs[9], lane_outputs[10], lane_outputs[11]));
     output.setByOffset(output_idx+3, vec4(lane_outputs[12], lane_outputs[13], lane_outputs[14], lane_outputs[15]));
+  #endif
 #else
+  #if has_bias
+    // TODO: loading the bias as a single vec4 fails unit tests for reasons not yet understood; revisit later.
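+    // Note: b_bias_offset selects the bias slice for the current stacked weight
+    // (uniforms.weight_idx * uniforms.N) and is a constant 0 when has_weight_idx
+    // is not set; the bias values are gathered element-wise into vec4s below.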
+ let bias_vec1 = vec4( + bias[b_global + 0 + b_bias_offset], + bias[b_global + 1 + b_bias_offset], + bias[b_global + 2 + b_bias_offset], + bias[b_global + 3 + b_bias_offset] + ); + let bias_vec2 = vec4( + bias[b_global + 4 + b_bias_offset], + bias[b_global + 5 + b_bias_offset], + bias[b_global + 6 + b_bias_offset], + bias[b_global + 7 + b_bias_offset] + ); + let bias_vec3 = vec4( + bias[b_global + 8 + b_bias_offset], + bias[b_global + 9 + b_bias_offset], + bias[b_global + 10 + b_bias_offset], + bias[b_global + 11 + b_bias_offset] + ); + let bias_vec4 = vec4( + bias[b_global + 12 + b_bias_offset], + bias[b_global + 13 + b_bias_offset], + bias[b_global + 14 + b_bias_offset], + bias[b_global + 15 + b_bias_offset] + ); + output.setByOffset(output_idx, lane_output1 + bias_vec1); + output.setByOffset(output_idx+1, lane_output2 + bias_vec2); + output.setByOffset(output_idx+2, lane_output3 + bias_vec3); + output.setByOffset(output_idx+3, lane_output4 + bias_vec4); + #else output.setByOffset(output_idx, lane_output1); output.setByOffset(output_idx+1, lane_output2); output.setByOffset(output_idx+2, lane_output3); output.setByOffset(output_idx+3, lane_output4); + #endif #endif } } // MAIN diff --git a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc index d6e15e56f193f..79d54840c1d66 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc @@ -27,9 +27,14 @@ Status DP4AMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const { if (has_zero_points_) { shader.AddInput("zero_points", ShaderUsage::UseUniform); } + if (has_bias_) { + shader.AddInput("bias", ShaderUsage::UseUniform); + } const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseElementTypeAlias); return WGSL_TEMPLATE_APPLY(shader, "quantization/dp4a_matmul.wgsl.template", WGSL_TEMPLATE_PARAMETER(block_size, block_size_), + WGSL_TEMPLATE_PARAMETER(has_bias, has_bias_), + WGSL_TEMPLATE_PARAMETER(has_weight_idx, has_weight_idx_), WGSL_TEMPLATE_PARAMETER(has_zero_points, has_zero_points_), WGSL_TEMPLATE_PARAMETER(is_qualcomm, is_qualcomm_), WGSL_TEMPLATE_PARAMETER(n_bits, nbits_), @@ -50,6 +55,9 @@ Status DP4AMatMulNBitsSmallMProgram::GenerateShaderCode(ShaderHelper& shader) co if (has_zero_points_) { shader.AddInput("zero_points", ShaderUsage::UseUniform); } + if (has_bias_) { + shader.AddInput("bias", ShaderUsage::UseUniform); + } const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseElementTypeAlias); ORT_ENFORCE(WorkgroupSizeX() % tile_size_k_vec_ == 0 && tile_size_k_vec_ % 4 == 0, "tile_size_k_vec_ must evenly divide workgroup size X and be divisible by 4"); @@ -57,6 +65,8 @@ Status DP4AMatMulNBitsSmallMProgram::GenerateShaderCode(ShaderHelper& shader) co ORT_ENFORCE(tile_size_ % sub_tile_count == 0, "tile_size_ must be divisible by sub_tile_count"); return WGSL_TEMPLATE_APPLY(shader, "quantization/dp4a_matmul_small_m.wgsl.template", + WGSL_TEMPLATE_PARAMETER(has_bias, has_bias_), + WGSL_TEMPLATE_PARAMETER(has_weight_idx, has_weight_idx_), WGSL_TEMPLATE_PARAMETER(has_zero_points, has_zero_points_), WGSL_TEMPLATE_PARAMETER(n_bits, nbits_), WGSL_TEMPLATE_PARAMETER(output_type_i32, true), @@ -72,7 +82,7 @@ Status DP4AMatMulNBitsSmallMProgram::GenerateShaderCode(ShaderHelper& shader) co } Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor* scales, - const Tensor* 
zero_points, + const Tensor* zero_points, const Tensor* bias, uint32_t M, uint32_t N, uint32_t K, @@ -81,7 +91,8 @@ Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor uint32_t min_M_for_tile_optimization, uint32_t nbits, onnxruntime::webgpu::ComputeContext& context, - Tensor* y) { + Tensor* y, + const uint32_t weight_index) { constexpr uint32_t kVec4Components = 4; constexpr uint32_t kVec2Components = 2; constexpr uint32_t kU32Components = 4; @@ -100,7 +111,10 @@ Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor {&a_scale, ProgramTensorMetadataDependency::Rank, 1}}) .AddUniformVariable({M * K / kU32Components}); ORT_RETURN_IF_ERROR(context.RunProgram(quantize_program)); + const bool has_zero_points = zero_points != nullptr; + const bool has_bias = bias != nullptr; + const bool has_weight_idx = weight_index != 0; const bool single_scale_weights = (block_size == K * N); if (M < min_M_for_tile_optimization) { uint32_t tile_size_k_vec = 16; @@ -111,7 +125,7 @@ Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor tile_size_n = 4; } const uint32_t b_components = (nbits == 2 ? kVec2Components : kVec4Components); - DP4AMatMulNBitsSmallMProgram mul_program{tile_size_k_vec, tile_size_n, nbits, has_zero_points, single_scale_weights}; + DP4AMatMulNBitsSmallMProgram mul_program{tile_size_k_vec, tile_size_n, nbits, has_zero_points, has_bias, has_weight_idx, single_scale_weights}; uint32_t num_N_tile = (N + tile_size_n - 1) / tile_size_n; mul_program.SetWorkgroupSize(128); mul_program.SetDispatchGroupSize(M * num_N_tile); @@ -119,12 +133,15 @@ Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor {&a_scale, ProgramTensorMetadataDependency::TypeAndRank, 1}, {b, ProgramTensorMetadataDependency::TypeAndRank, static_cast(b_components * kU32Components)}, {scales, ProgramTensorMetadataDependency::TypeAndRank, 1}}) - .AddUniformVariables({M, N, K, K / 16, K / 32, block_size, num_N_tile, zero_blocks_per_col}) + .AddUniformVariables({M, N, K, K / 16, K / 32, block_size, num_N_tile, zero_blocks_per_col, weight_index}) .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, 1}) - .CacheHint(nbits, tile_size_k_vec, tile_size_n, has_zero_points, single_scale_weights); + .CacheHint(nbits, tile_size_k_vec, tile_size_n, has_zero_points, single_scale_weights, has_bias, has_weight_idx); if (has_zero_points) { mul_program.AddInput({zero_points, ProgramTensorMetadataDependency::None, {(zero_points->Shape().Size() + 3) / 4}, 4}); } + if (has_bias) { + mul_program.AddInput({bias, ProgramTensorMetadataDependency::None}); + } return context.RunProgram(mul_program); } @@ -133,7 +150,7 @@ Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor uint32_t num_M_tile = (M + kTileSize - 1) / kTileSize; uint32_t num_N_tile = (N + kTileSize - 1) / kTileSize; bool is_qualcomm = context.AdapterInfo().vendor == std::string_view{"qualcomm"}; - DP4AMatMulNBitsProgram mul_program{block_size, nbits, has_zero_points, is_qualcomm}; + DP4AMatMulNBitsProgram mul_program{block_size, nbits, has_zero_points, has_bias, has_weight_idx, is_qualcomm}; mul_program.SetWorkgroupSize(256); mul_program.SetDispatchGroupSize(num_M_tile * num_N_tile); mul_program.AddInputs({{&a_quant, ProgramTensorMetadataDependency::TypeAndRank, static_cast(kVec4Components)}, @@ -146,12 +163,16 @@ Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor {static_cast(K / 8)}, {static_cast(K / 16)}, {num_N_tile}, 
- {zero_blocks_per_col}}) + {zero_blocks_per_col}, + {weight_index}}) .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, reshaped_y_shape, static_cast(kVec4Components)}) - .CacheHint("Block" + std::to_string(block_size), nbits, has_zero_points, is_qualcomm); + .CacheHint("Block" + std::to_string(block_size), nbits, has_zero_points, is_qualcomm, has_bias, has_weight_idx); if (has_zero_points) { mul_program.AddInput({zero_points, ProgramTensorMetadataDependency::None, {(zero_points->Shape().Size() + 3) / 4}, 4}); } + if (has_bias) { + mul_program.AddInput({bias, ProgramTensorMetadataDependency::None}); + } return context.RunProgram(mul_program); } diff --git a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h index b00392cbb291e..cfb6139bb7c1b 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h +++ b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h @@ -21,11 +21,15 @@ class DP4AMatMulQuantizeProgram final : public Program { public: - DP4AMatMulNBitsProgram(uint32_t block_size, uint32_t nbits, bool has_zero_points, bool is_qualcomm) : Program{"DP4AMatMulNBits"}, - block_size_(block_size), - nbits_(nbits), - has_zero_points_(has_zero_points), - is_qualcomm_(is_qualcomm) {} + DP4AMatMulNBitsProgram(uint32_t block_size, uint32_t nbits, + bool has_zero_points, bool has_bias, + bool has_weight_idx, bool is_qualcomm) : Program{"DP4AMatMulNBits"}, + block_size_(block_size), + nbits_(nbits), + has_bias_(has_bias), + has_zero_points_(has_zero_points), + has_weight_idx_(has_weight_idx), + is_qualcomm_(is_qualcomm) {} Status GenerateShaderCode(ShaderHelper& sh) const override; WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES( {"M", ProgramUniformVariableDataType::Uint32}, @@ -34,23 +38,30 @@ class DP4AMatMulNBitsProgram final : public Program { {"K8", ProgramUniformVariableDataType::Uint32}, {"K16", ProgramUniformVariableDataType::Uint32}, {"num_N_tile", ProgramUniformVariableDataType::Uint32}, - {"zero_blocks_per_col", ProgramUniformVariableDataType::Uint32}); + {"zero_blocks_per_col", ProgramUniformVariableDataType::Uint32}, + {"weight_idx", ProgramUniformVariableDataType::Uint32}); private: uint32_t block_size_; uint32_t nbits_; + bool has_bias_; bool has_zero_points_; + bool has_weight_idx_; bool is_qualcomm_; }; class DP4AMatMulNBitsSmallMProgram final : public Program { public: - DP4AMatMulNBitsSmallMProgram(uint32_t tile_size_k_vec, uint32_t tile_size, uint32_t nbits, bool has_zero_points, bool single_scale_weights) : Program{"DP4AMatMulNBitsSmallMProgram"}, - tile_size_k_vec_(tile_size_k_vec), - tile_size_(tile_size), - nbits_(nbits), - has_zero_points_(has_zero_points), - single_scale_weights_(single_scale_weights) {} + DP4AMatMulNBitsSmallMProgram(uint32_t tile_size_k_vec, uint32_t tile_size, uint32_t nbits, + bool has_zero_points, bool has_bias, + bool has_weight_idx, bool single_scale_weights) : Program{"DP4AMatMulNBitsSmallMProgram"}, + tile_size_k_vec_(tile_size_k_vec), + tile_size_(tile_size), + nbits_(nbits), + has_bias_(has_bias), + has_zero_points_(has_zero_points), + has_weight_idx_(has_weight_idx), + single_scale_weights_(single_scale_weights) {} Status GenerateShaderCode(ShaderHelper& sh) const override; WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES( {"M", ProgramUniformVariableDataType::Uint32}, @@ -60,18 +71,21 @@ class DP4AMatMulNBitsSmallMProgram final : public Program(bits_); @@ -124,23 +133,81 @@ Status 
MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context MatMulComputeHelper helper; TensorShape b_shape({N_, K_}); ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b_shape, false, true)); - auto* y = context.Output(0, helper.OutputShape()); + auto output_shape = helper.OutputShape(); + Tensor* y = context.Output(0, output_shape); const uint32_t data_size = onnxruntime::narrow(y->Shape().Size()); if (data_size == 0) { return Status::OK(); } + return ApplyMatMulNBits(a, b, scales, zero_points, bias, K_, N_, block_size_, accuracy_level_, bits_, context, y, 0); +} + +/** + * @brief Applies a quantized matrix multiplication using N-bit precision. + * + * This function computes the matrix multiplication of the quantized tensor inputs with multiple + * optional optimizations tailored to the GPU backend. Depending on the provided parameters and GPU + * capabilities, it selects one of several optimized kernels (such as subgroup matrix multiplication, + * DP4A, wide tile programs, or the default matmul program) to perform the computation. + * It can be called by the MatMulNBits operator or directly for custom scenarios like QMoe. + * + * @param a Pointer to the left-hand side (activation) tensor. + * @param b Pointer to the quantized weight tensor. + * b has the shape (N, k_blocks, blob_size) or (weight_batch, N, k_blocks, blob_size) + * @param scales Pointer to the tensor containing scaling factors for quantization. + * scales has the shape (N) or (weight_batch, N) + * @param zero_points Pointer to the zero-point tensor for quantization; must be of type uint8 if provided. + * weight_index > 0 is only supported when zero_points is nullptr. + * @param bias Pointer to the bias tensor; optional. + * @param K_op The K dimension of the operation (number of columns in 'a' and rows in 'b' before quantization). + * @param N_op The N dimension of the operation (number of columns in 'b'). + * @param block_size_op The block size used for quantization partitioning. + * @param accuracy_level Accuracy level influencing the choice of optimized kernel. + * @param nbits Number of bits used for quantization. + * @param weight_index Index of the weight matrix in case of stacked weights; defaults to 0. + * @param context Compute context for WebGPU, providing device-specific information and execution facilities. + * @param y Pointer to the output tensor that will hold the result. + * + * @return Status indicating whether the operation was successful or if an error occurred. + * + * @note Special optimizations are considered: + * - Subgroup matrix multiplication for eligible Apple/Intel GPUs. + * - DP4A-based multiplication on FP32-only GPUs for specific dimensions and conditions. + * - A wide tile program is used when block size, component count, and other criteria are met. + * - Otherwise, a default matmul program is used. 
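+ *
+ * Illustrative call (hypothetical QMoE-style usage; the tensors and expert_idx
+ * below are placeholders, not part of this change):
+ *   ORT_RETURN_IF_ERROR(ApplyMatMulNBits(a, b, scales, /*zero_points*/ nullptr, bias,
+ *                                        K, N, block_size, accuracy_level, /*nbits*/ 4,
+ *                                        context, y, /*weight_index*/ expert_idx));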
+ */ +Status ApplyMatMulNBits(const Tensor* a, const Tensor* b, const Tensor* scales, const Tensor* zero_points, const Tensor* bias, + int64_t K_op, + int64_t N_op, + int64_t block_size_op, + int64_t accuracy_level, + int64_t nbits, + onnxruntime::webgpu::ComputeContext& context, + Tensor* y, + const uint32_t weight_index) { + TensorShape b_shape({N_op, K_op}); + MatMulComputeHelper helper; + ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b_shape, false, true)); + + const bool has_bias = bias != nullptr; + const bool has_weight_idx = weight_index > 0; + const bool has_zero_points = zero_points != nullptr; + if (has_zero_points) { + ORT_ENFORCE(zero_points->DataType() == DataTypeImpl::GetType(), "Currently, only uint8 is supported for zero points, but got ", zero_points->DataType()); + } + const uint32_t batch_count = onnxruntime::narrow(helper.OutputOffsets().size()); const uint32_t M = onnxruntime::narrow(helper.M()); const uint32_t N = onnxruntime::narrow(helper.N()); const uint32_t K = onnxruntime::narrow(helper.K()); - const uint32_t block_size = onnxruntime::narrow(block_size_); + const uint32_t block_size = onnxruntime::narrow(block_size_op); // Special case matrix used by bitnets where there is a single scale for the entire const bool single_scale_weights = (block_size == K * N); const uint32_t block_size_per_col = single_scale_weights ? K : block_size; const uint32_t n_blocks_per_col = (K + block_size_per_col - 1) / block_size_per_col; - const uint32_t blob_size = (block_size_per_col / 8) * nbits; + const uint32_t blob_size = (block_size_per_col / 8) * static_cast(nbits); const uint32_t blob_size_in_words = blob_size / 4; const uint32_t components_a = GetMaxComponents(K); const uint32_t components_b = GetMaxComponents(blob_size_in_words); @@ -153,16 +220,15 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context int32_t subgroup_matrix_config_index = -1; // apple|intel - Experimental dawn support for subgroup matrix matmul. if (M >= kMinMForTileOptimization && (context.AdapterInfo().vendor == std::string_view{"apple"} || context.AdapterInfo().vendor == std::string_view{"intel"}) && - CanApplySubgroupMatrixMatMulNBits(context, accuracy_level_, block_size, batch_count, N, K, subgroup_matrix_config_index)) { - return ApplySubgroupMatrixMatMulNBits(a, b, scales, zero_points, M, N, K, nbits, zero_blocks_per_col, subgroup_matrix_config_index, context, y); + CanApplySubgroupMatrixMatMulNBits(context, accuracy_level, block_size, batch_count, N, K, subgroup_matrix_config_index)) { + return ApplySubgroupMatrixMatMulNBits(a, b, scales, zero_points, bias, M, N, K, static_cast(nbits), zero_blocks_per_col, subgroup_matrix_config_index, context, y, weight_index); } #endif // On FP32 only GPUs, integer math is faster than FP32 therefore always use DP4A independent of length of M. 
- if ((M >= kMinMForTileOptimization || y->DataType() == DataTypeImpl::GetType() || - context.AdapterInfo().vendor == std::string_view{"qualcomm"}) && - CanApplyDP4AMatrixMatMulNBits(context, accuracy_level_, block_size, batch_count, N, K, components_a)) { - return ApplyDP4AMatrixMatMulNBits(a, b, scales, zero_points, M, N, K, block_size, zero_blocks_per_col, kMinMForTileOptimization, nbits, context, y); + if ((M >= kMinMForTileOptimization || y->DataType() == DataTypeImpl::GetType() || context.AdapterInfo().vendor == std::string_view{"qualcomm"}) && + CanApplyDP4AMatrixMatMulNBits(context, accuracy_level, block_size, batch_count, N, K, components_a)) { + return ApplyDP4AMatrixMatMulNBits(a, b, scales, zero_points, bias, M, N, K, block_size, zero_blocks_per_col, kMinMForTileOptimization, static_cast(nbits), context, y, weight_index); } // WideTileProgram @@ -172,6 +238,7 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context components_b == 4 && nbits != 2 && M >= kMinMForTileOptimization; + if (use_wide_tile_program) { // Enforce output components to 1. components = 1; @@ -182,7 +249,7 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context const uint32_t num_N_tile = ceil_div(N, tile_n); const uint32_t num_M_tile = ceil_div(M, tile_m); - MatMulNBitsWideTileProgram program{has_zero_points, tile_m, tile_n, nbits}; + MatMulNBitsWideTileProgram program{has_zero_points, has_bias, has_weight_idx, tile_m, tile_n, static_cast(nbits)}; program.SetWorkgroupSize(workgroup_size); program.SetDispatchGroupSize(num_N_tile, num_M_tile, batch_count); @@ -204,6 +271,9 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context {ceil_div(zero_points->Shape().Size(), static_cast(4))}, 4}); } + if (has_bias) { + program.AddInput({bias, ProgramTensorMetadataDependency::None}); + } program.AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, onnxruntime::narrow(components)}); @@ -215,8 +285,9 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context {n_blocks_per_col}, {zero_blocks_per_col}, {num_N_tile}, - {num_M_tile}}); - program.CacheHint(nbits, has_zero_points); + {num_M_tile}, + {weight_index}}); + program.CacheHint(nbits, has_zero_points, has_bias, has_weight_idx); return context.RunProgram(program); } @@ -227,7 +298,7 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context uint32_t components_b_with_u32 = components_b * kU32Components; uint32_t num_N_tile = (N + tile_size - 1) / tile_size; uint32_t K_of_b = (n_blocks_per_col * blob_size) / components_b_with_u32; - MatMulNBitsProgram program{tile_size, nbits, has_zero_points, single_scale_weights}; + MatMulNBitsProgram program{tile_size, static_cast(nbits), has_zero_points, has_bias, has_weight_idx, single_scale_weights}; program.SetWorkgroupSize(workgroup_size); program.SetDispatchGroupSize((N + tile_size - 1) / tile_size, M, batch_count); program @@ -235,11 +306,24 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context {b, ProgramTensorMetadataDependency::TypeAndRank, static_cast(components_b_with_u32)}, {scales, ProgramTensorMetadataDependency::TypeAndRank}}) .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank}) - .AddUniformVariables({{M}, {N}, {K}, {K / components_a}, {K_of_b}, {block_size}, {n_blocks_per_col}, {zero_blocks_per_col}, {num_N_tile}, {batch_count}}) - .CacheHint(nbits, has_zero_points, single_scale_weights); + .AddUniformVariables({{M}, + {N}, + {K}, + {K / 
components_a}, + {K_of_b}, + {block_size}, + {n_blocks_per_col}, + {zero_blocks_per_col}, + {num_N_tile}, + {batch_count}, + {weight_index}}) + .CacheHint(nbits, has_zero_points, single_scale_weights, has_bias, has_weight_idx); if (has_zero_points) { program.AddInput({zero_points, ProgramTensorMetadataDependency::None, {(zero_points->Shape().Size() + 3) / 4}, 4}); } + if (has_bias) { + program.AddInput({bias, ProgramTensorMetadataDependency::None}); + } return context.RunProgram(program); } diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h index 1f7bd16d9cb6f..ccd1ef6f1355c 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h @@ -14,8 +14,8 @@ using namespace onnxruntime::webgpu; class MatMulNBitsWideTileProgram final : public Program { public: - MatMulNBitsWideTileProgram(bool has_zero_points, uint32_t tile_m, uint32_t tile_n, uint32_t nbits) - : Program{"MatMulNBitsWideTile"}, has_zero_points_{has_zero_points}, tile_m_(tile_m), tile_n_(tile_n), nbits_(nbits) {} + MatMulNBitsWideTileProgram(bool has_zero_points, bool has_bias, bool has_weight_idx, uint32_t tile_m, uint32_t tile_n, uint32_t nbits) + : Program{"MatMulNBitsWideTile"}, has_zero_points_{has_zero_points}, has_bias_{has_bias}, has_weight_idx_{has_weight_idx}, tile_m_(tile_m), tile_n_(tile_n), nbits_(nbits) {} Status GenerateShaderCode(ShaderHelper& sh) const override; WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"Batch", ProgramUniformVariableDataType::Uint32}, @@ -26,10 +26,13 @@ class MatMulNBitsWideTileProgram final : public Program { public: - MatMulNBitsProgram(uint32_t tile_size, uint32_t nbits, bool has_zero_points, bool single_scale_weights) : Program{"MatMulNBits"}, tile_size_(tile_size), nbits_(nbits), has_zero_points_(has_zero_points), single_scale_weights_(single_scale_weights) {} + MatMulNBitsProgram(uint32_t tile_size, uint32_t nbits, bool has_zero_points, bool has_bias, bool has_weight_idx, bool single_scale_weights) + : Program{"MatMulNBits"}, tile_size_(tile_size), nbits_(nbits), has_zero_points_(has_zero_points), has_bias_(has_bias), has_weight_idx_{has_weight_idx}, single_scale_weights_(single_scale_weights) {} Status GenerateShaderCode(ShaderHelper& sh) const override; WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES( {"M", ProgramUniformVariableDataType::Uint32}, @@ -49,12 +53,15 @@ class MatMulNBitsProgram final : public Program { {"blocks_per_col", ProgramUniformVariableDataType::Uint32}, {"zero_blocks_per_col", ProgramUniformVariableDataType::Uint32}, {"num_N_tile", ProgramUniformVariableDataType::Uint32}, - {"batch_count", ProgramUniformVariableDataType::Uint32}); + {"batch_count", ProgramUniformVariableDataType::Uint32}, + {"weight_idx", ProgramUniformVariableDataType::Uint32}); private: uint32_t tile_size_; uint32_t nbits_; bool has_zero_points_; + bool has_bias_; + bool has_weight_idx_; bool single_scale_weights_; }; @@ -80,6 +87,10 @@ class MatMulNBits final : public WebGpuKernel { int64_t bits_; }; +Status ApplyMatMulNBits(const Tensor* a, const Tensor* b, const Tensor* scales, const Tensor* zero_points, const Tensor* bias, + int64_t K_op, int64_t N_op, int64_t block_size_op, int64_t accuracy_level, int64_t bits_op, + onnxruntime::webgpu::ComputeContext& context, Tensor* y, const uint32_t weight_index = 0); + } // namespace webgpu } // namespace contrib } // namespace onnxruntime diff --git 
a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.wgsl.template b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.wgsl.template index 0fe3ec92ef3de..1b74862515c69 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.wgsl.template @@ -11,6 +11,8 @@ #param tile_size_k_vec #param tile_size_k #param tile_size +#param has_bias +#param has_weight_idx #use .getByOffset .setByOffset @@ -32,6 +34,20 @@ fn loadSHMA(batch: u32, a_global: u32, kidx: u32, col: u32) $MAIN { let batch = workgroup_idx / (uniforms.M * uniforms.num_N_tile); +#if has_weight_idx + let b_base_offset = uniforms.weight_idx * uniforms.K_of_b * uniforms.N; +#if single_scale_weights + let b_scale_offset = uniforms.weight_idx; +#else + let b_scale_offset = uniforms.weight_idx * uniforms.N * uniforms.blocks_per_col; +#endif +#else + const b_base_offset : u32 = 0; + const b_scale_offset : u32 = 0; +#endif +#if has_bias + let b_bias_offset = uniforms.weight_idx * uniforms.N; +#endif let a_global = (workgroup_idx / uniforms.num_N_tile) % uniforms.M; let b_global_base = (workgroup_idx % uniforms.num_N_tile) * tile_size; @@ -40,7 +56,7 @@ $MAIN { #if single_scale_weights let block_idx = 0; - let scale_b = scales_b.getByOffset(0); + let scale_b = scales_b.getByOffset(0 + b_scale_offset); let zero = mm_read_zero(0, 0, uniforms.N, uniforms.zero_blocks_per_col); #endif @@ -60,10 +76,10 @@ $MAIN { { #if !single_scale_weights let block_idx = (kidx + idx * elements_in_value_b) / uniforms.block_size; - let scale_b = scales_b.getByOffset(b_global * uniforms.blocks_per_col + block_idx); + let scale_b = scales_b.getByOffset(b_global * uniforms.blocks_per_col + block_idx + b_scale_offset); let zero = mm_read_zero(b_global, block_idx, uniforms.N, uniforms.zero_blocks_per_col); #endif - var b_value = b.getByOffset(b_global * uniforms.K_of_b + k_offset); + var b_value = b.getByOffset(b_global * uniforms.K_of_b + k_offset + b_base_offset); #if n_bits == 4 var sum = output_element_t(0); @@ -154,6 +170,9 @@ $MAIN { let b_global = b_global_base + local_idx; let output_idx = batch * uniforms.M * uniforms.N + a_global * uniforms.N + b_global; if (b_global < uniforms.N) { +#if has_bias + output_value += bias[b_global + b_bias_offset]; +#endif output.setByOffset(output_idx, output_value); } } diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_wide_tile.wgsl.template b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_wide_tile.wgsl.template index 7c2fca615a99b..b95d4bd49c6d8 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_wide_tile.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_wide_tile.wgsl.template @@ -2,6 +2,8 @@ // Licensed under the MIT License. 
#param has_zero_points +#param has_bias +#param has_weight_idx #param nbits #param tile_m #param tile_n @@ -68,7 +70,12 @@ fn load_a(batch : u32, row : u32, col : u32) -> input_a_value_t { fn load_scale(row : u32, block_idx : u32) -> output_element_t { if (row < uniforms.N && block_idx < uniforms.n_blocks_per_col) { let offset = row * uniforms.n_blocks_per_col + block_idx; +#if has_weight_idx + let b_scale_offset = uniforms.weight_idx * uniforms.N * uniforms.n_blocks_per_col; + return scales.getByOffset(offset + b_scale_offset); +#else return scales.getByOffset(offset); +#endif } return output_element_t(); } @@ -83,7 +90,12 @@ fn write_output(batch : u32, row : u32, col : u32, value : output_element_t) { #if nbits == 4 fn load_b(row : u32, block_idx : u32) -> vec4 { if (row < uniforms.N && block_idx < uniforms.K_of_b) { +#if has_weight_idx + let b_offset = uniforms.weight_idx * uniforms.K_of_b * uniforms.N; + let offset = row * uniforms.K_of_b + block_idx + b_offset; +#else let offset = row * uniforms.K_of_b + block_idx; +#endif return b.getByOffset(offset); } return vec4(); @@ -112,7 +124,12 @@ fn dequantize(packed_data : u32, #else // nbits == 8 fn load_b(row : u32, block_idx : u32) -> array, 4> { if (row < uniforms.N) { +#if has_weight_idx + let b_offset = uniforms.weight_idx * uniforms.K_of_b * uniforms.N; + let offset = 2 * block_idx + b_offset; +#else let offset = 2 * block_idx; +#endif let b_data_0 = select(input_b_value_t(), b.getByOffset(row * uniforms.K_of_b + offset), offset < uniforms.K_of_b); @@ -190,7 +207,19 @@ $MAIN { } // Write the results. + #if has_bias + #if has_weight_idx + let b_bias_offset = uniforms.weight_idx * uniforms.N; + let bias_value = bias[b_bias_offset + col + local_idx]; + #else + let bias_value = bias[col + local_idx]; + #endif + #endif for (var m_idx = 0u; m_idx < kTileM; m_idx++) { + #if has_bias + write_output(batch, row + m_idx, col + local_idx, output_element_t(results[m_idx]) + bias_value); + #else write_output(batch, row + m_idx, col + local_idx, output_element_t(results[m_idx])); + #endif } } // MAIN diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_zero_pt.wgsl.template b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_zero_pt.wgsl.template index 9135708adf153..0ad2f89f5263c 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_zero_pt.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_zero_pt.wgsl.template @@ -3,6 +3,7 @@ #param n_bits #param has_zero_points +#param has_bias #param output_type_i32 #if output_type_i32 diff --git a/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc index db1a6319b3247..50aa4de4749bb 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc @@ -275,224 +275,22 @@ Status GenerateShaderCodeOnIntel(ShaderHelper& shader, const ShaderVariableHelpe subgroupMatrixStore(&output, matrix_c_offset + subtile_id * m_dim * uniforms.N + 2 * n_dim, matC02, false, uniforms.N); subgroupMatrixStore(&output, matrix_c_offset + subtile_id * m_dim * uniforms.N + 3 * n_dim, matC03, false, uniforms.N); )MAIN_FN"; - return Status::OK(); } Status GenerateShaderCodeOnApple(ShaderHelper& shader, const ShaderVariableHelper& a, const ShaderVariableHelper& b, const ShaderVariableHelper& scales_b, - const ShaderVariableHelper& output, uint32_t nbits, bool 
has_zero_points) { - // tile/subtile sizes and work distribution are inspired from metal shaders in llama.cpp (kernel_mul_mm) - // https://github.com/ggml-org/llama.cpp/blob/d04e7163c85a847bc61d58c22f2c503596db7aa8/ggml/src/ggml-metal/ggml-metal.metal#L6066 - shader.AdditionalImplementation() << R"ADDNL_FN_PART( - const tile_cols = 64; - const tile_rows = 32; - const tile_k = 32; - const subtile_cols = 32; - const subtile_rows = 16; - const quantization_block_size = 32; - alias compute_precision = output_element_t; - - var tile_A: array; // 32 x 32 - RxC - var tile_B: array; // 64 x 32 - RxC - var scratch: array, 4>, 4>; // 64 * 4 * 4 - - fn loadSHMA(tile_base: u32, k_idx: u32, row: u32, c_idx:u32) { - let a_global = tile_base + row; - if (a_global >= uniforms.M) { - return; - } - // Each call loads 8 columns, starting at col. - var col = c_idx * 8; - // 128 threads need to load 32 x 32. 4 threads per row or 8 col per thread. - for (var col_offset:u32 = 0; col_offset < 8; col_offset++) - { - )ADDNL_FN_PART"; - shader.AdditionalImplementation() - << " tile_A[row * tile_k + col + col_offset] = compute_precision(" - << a.GetByOffset("a_global * uniforms.K + k_idx + col + col_offset") - << ");"; - shader.AdditionalImplementation() << R"ADDNL_FN_PART( - } - })ADDNL_FN_PART"; - shader.AdditionalImplementation() << GenerateZeroPointReadingCode(nbits, has_zero_points, "compute_precision"); - if (nbits == 4) { - shader.AdditionalImplementation() << R"ADDNL_FN_PART( - fn loadSHMB(tile_base: u32, k_idx: u32, row: u32, c_idx: u32) { - let b_global = tile_base + row; - if (b_global >= uniforms.N) { - return; - } - // Each call loads 16 columns, starting at col. - var col = c_idx * 16; - // 128 threads need to load 64 x 32. 2 threads per row or 16 col per thread. - // Stored in column major fashion. - let b_idx = u32((b_global*uniforms.K + k_idx + col)/8); - )ADDNL_FN_PART"; - shader.AdditionalImplementation() << "let scale = compute_precision(" - << scales_b.GetByOffset("(b_global * uniforms.K + k_idx + col) / quantization_block_size") - << ");"; - shader.AdditionalImplementation() << R"ADDNL_FN_PART( - let zero = mm_read_zero(b_global, (k_idx + col) / quantization_block_size, uniforms.N, uniforms.zero_blocks_per_col); - for (var step:u32 = 0; step < 2; step++) - { - )ADDNL_FN_PART"; - shader.AdditionalImplementation() << "var b_value = " - << b.GetByOffset("b_idx+step") - << ';'; - shader.AdditionalImplementation() << R"ADDNL_FN_PART(var b_value_lower = (vec4(unpack4xU8(b_value & 0x0F0F0F0Fu)) - vec4(zero)) * scale; - var b_value_upper = (vec4(unpack4xU8((b_value >> 4) & 0x0F0F0F0Fu)) - vec4(zero)) * scale; - let tile_b_base = row * tile_k + col + step * 8; - tile_B[tile_b_base] = b_value_lower[0]; - tile_B[tile_b_base + 1] = b_value_upper[0]; - tile_B[tile_b_base + 2] = b_value_lower[1]; - tile_B[tile_b_base + 3] = b_value_upper[1]; - tile_B[tile_b_base + 4] = b_value_lower[2]; - tile_B[tile_b_base + 5] = b_value_upper[2]; - tile_B[tile_b_base + 6] = b_value_lower[3]; - tile_B[tile_b_base + 7] = b_value_upper[3]; - } -} - )ADDNL_FN_PART"; - } else { - ORT_ENFORCE(nbits == 8, "Only 4/8 bits are supported for webgpu matmulnbits"); - shader.AdditionalImplementation() << R"ADDNL_FN_PART( - fn loadSHMB(tile_base: u32, k_idx: u32, row: u32, c_idx: u32) { - let b_global = tile_base + row; - if (b_global >= uniforms.N) { - return; - } - // Each call loads 16 columns, starting at col. - var col = c_idx * 16; - // 128 threads need to load 64 x 32. 2 threads per row or 16 col per thread. 
- // Stored in column major fashion. - let b_idx = u32((b_global*uniforms.K + k_idx + col)/8); - )ADDNL_FN_PART"; - shader.AdditionalImplementation() << "let scale = compute_precision(" - << scales_b.GetByOffset("(b_global * uniforms.K + k_idx + col) / quantization_block_size") - << ");"; - shader.AdditionalImplementation() << R"ADDNL_FN_PART( - let zero = mm_read_zero(b_global, (k_idx + col) / quantization_block_size, uniforms.N, uniforms.zero_blocks_per_col); - for (var step : u32 = 0; step < 2; step++) { - )ADDNL_FN_PART"; - shader.AdditionalImplementation() << "var b_value = " - << b.GetByOffset("b_idx+step") - << ';'; - - shader.AdditionalImplementation() << R"ADDNL_FN_PART( - var b_value0 = (vec4(unpack4xU8(b_value[0])) - vec4(zero)) * scale; - var b_value1 = (vec4(unpack4xU8(b_value[1])) - vec4(zero)) * scale; - let tile_b_base = row * tile_k + col + step * 8; - tile_B[tile_b_base] = b_value0[0]; - tile_B[tile_b_base + 1] = b_value0[1]; - tile_B[tile_b_base + 2] = b_value0[2]; - tile_B[tile_b_base + 3] = b_value0[3]; - tile_B[tile_b_base + 4] = b_value1[0]; - tile_B[tile_b_base + 5] = b_value1[1]; - tile_B[tile_b_base + 6] = b_value1[2]; - tile_B[tile_b_base + 7] = b_value1[3]; - } -} - )ADDNL_FN_PART"; - } - shader.AdditionalImplementation() - << " fn storeOutput(offset:u32, row: u32, col:u32, src_slot:u32, row_limit:i32) {\n" - << " if (row_limit > 0 && row < u32(row_limit))\n" - << " {\n" - << " " << output.SetByOffset("offset + row * uniforms.N + col", "output_element_t(scratch[src_slot][0][row * 8 + col])") << ";\n" - << " " << output.SetByOffset("offset + row * uniforms.N + col + 8", "output_element_t(scratch[src_slot][1][row * 8 + col])") << ";\n" - << " " << output.SetByOffset("offset + row * uniforms.N + col + 16", "output_element_t(scratch[src_slot][2][row * 8 + col])") << ";\n" - << " " << output.SetByOffset("offset + row * uniforms.N + col + 24", "output_element_t(scratch[src_slot][3][row * 8 + col])") << ";\n" - << " let col2 = col + 1;\n" - << " " << output.SetByOffset("offset + row * uniforms.N + col2", "output_element_t(scratch[src_slot][0][row * 8 + col2])") << ";\n" - << " " << output.SetByOffset("offset + row * uniforms.N + col2 + 8", "output_element_t(scratch[src_slot][1][row * 8 + col2])") << ";\n" - << " " << output.SetByOffset("offset + row * uniforms.N + col2 + 16", "output_element_t(scratch[src_slot][2][row * 8 + col2])") << ";\n" - << " " << output.SetByOffset("offset + row * uniforms.N + col2 + 24", "output_element_t(scratch[src_slot][3][row * 8 + col2])") << ";\n" - << " }\n" - << " }\n"; - - shader.MainFunctionBody() << R"MAIN_FN( - let a_global_base = workgroup_id.y * tile_rows; - let b_global_base = workgroup_id.x * tile_cols; - - let subtile_id = u32(local_idx / sg_size); - let subtile_idx = u32(subtile_id / 2); - let subtile_idy = subtile_id % 2; - let base_A = subtile_idy * subtile_rows; - let base_B = subtile_idx * subtile_cols; - - var matC00: subgroup_matrix_result; - var matC01: subgroup_matrix_result; - var matC02: subgroup_matrix_result; - var matC03: subgroup_matrix_result; - var matC10: subgroup_matrix_result; - var matC11: subgroup_matrix_result; - var matC12: subgroup_matrix_result; - var matC13: subgroup_matrix_result; - for (var kidx: u32 = 0; kidx < uniforms.K; kidx += tile_k) { - // Load Phase - loadSHMA(a_global_base, kidx, local_idx/4, local_idx%4); - loadSHMB(b_global_base, kidx, local_idx/2, local_idx%2); - workgroupBarrier(); - - for (var step: u32 = 0; step < tile_k; step+=8) - { - // Load to local memory phase - let 
matrix_a_offset = subtile_idy * subtile_rows * tile_k + step; - // Syntax: subgroupMatrixLoad src_ptr,src_offset,is_col_major,src_stride - var matA0: subgroup_matrix_left = subgroupMatrixLoad>(&tile_A, matrix_a_offset, false, tile_k); - var matA1: subgroup_matrix_left = subgroupMatrixLoad>(&tile_A, matrix_a_offset + 8 * tile_k, false, tile_k); - - // tile_B is stored as column major. - // [col0-0:32][col1-0:32][col2-0:32]..[col63-0:32] - var matrix_b_offset = subtile_idx * subtile_cols * tile_k + step; - var matB0: subgroup_matrix_right = subgroupMatrixLoad>(&tile_B, matrix_b_offset, true, tile_k); - var matB1: subgroup_matrix_right = subgroupMatrixLoad>(&tile_B, matrix_b_offset + 8 * tile_k, true, tile_k); - var matB2: subgroup_matrix_right = subgroupMatrixLoad>(&tile_B, matrix_b_offset + 16 * tile_k, true, tile_k); - var matB3: subgroup_matrix_right = subgroupMatrixLoad>(&tile_B, matrix_b_offset + 24 * tile_k, true, tile_k); - - // Compute Phase - // Syntax: subgroupMatrixMultiplyAccumulate left, right, accumulate -> accumulate - matC00 = subgroupMatrixMultiplyAccumulate(matA0, matB0, matC00); - matC01 = subgroupMatrixMultiplyAccumulate(matA0, matB1, matC01); - matC02 = subgroupMatrixMultiplyAccumulate(matA0, matB2, matC02); - matC03 = subgroupMatrixMultiplyAccumulate(matA0, matB3, matC03); - - matC10 = subgroupMatrixMultiplyAccumulate(matA1, matB0, matC10); - matC11 = subgroupMatrixMultiplyAccumulate(matA1, matB1, matC11); - matC12 = subgroupMatrixMultiplyAccumulate(matA1, matB2, matC12); - matC13 = subgroupMatrixMultiplyAccumulate(matA1, matB3, matC13); - } - - workgroupBarrier(); - } - - // Write out - // Write out top block - subgroupMatrixStore(&scratch[subtile_id][0], 0, matC00, false, 8); - subgroupMatrixStore(&scratch[subtile_id][1], 0, matC01, false, 8); - subgroupMatrixStore(&scratch[subtile_id][2], 0, matC02, false, 8); - subgroupMatrixStore(&scratch[subtile_id][3], 0, matC03, false, 8); - workgroupBarrier(); - let row = u32(sg_id / 4); - var col = u32(sg_id % 4) * 2; - var matrix_c_offset = (a_global_base+base_A) * uniforms.N + b_global_base + base_B; - var row_limit:i32 = i32(uniforms.M) - i32(a_global_base + base_A); - storeOutput(matrix_c_offset, row, col, subtile_id, row_limit); - workgroupBarrier(); - - // Write out bottom block - subgroupMatrixStore(&scratch[subtile_id][0], 0, matC10, false, 8); - subgroupMatrixStore(&scratch[subtile_id][1], 0, matC11, false, 8); - subgroupMatrixStore(&scratch[subtile_id][2], 0, matC12, false, 8); - subgroupMatrixStore(&scratch[subtile_id][3], 0, matC13, false, 8); - workgroupBarrier(); - matrix_c_offset = matrix_c_offset + 8 * uniforms.N; - row_limit = i32(uniforms.M) - i32(a_global_base + base_A + 8); - storeOutput(matrix_c_offset, row, col, subtile_id, row_limit); - )MAIN_FN"; - - return Status::OK(); + const ShaderVariableHelper& output, uint32_t nbits, bool has_zero_points, bool has_bias, bool has_weight_idx) { + return WGSL_TEMPLATE_APPLY(shader, "quantization/subgroup_matrix_matmul_nbits_apple.wgsl.template", + WGSL_TEMPLATE_PARAMETER(has_bias, has_bias), + WGSL_TEMPLATE_PARAMETER(has_weight_idx, has_weight_idx), + WGSL_TEMPLATE_PARAMETER(has_zero_points, has_zero_points), + WGSL_TEMPLATE_PARAMETER(n_bits, nbits), + WGSL_TEMPLATE_PARAMETER(output_type_i32, false), + WGSL_TEMPLATE_VARIABLE(a, a), + WGSL_TEMPLATE_VARIABLE(b, b), + WGSL_TEMPLATE_VARIABLE(output, output), + WGSL_TEMPLATE_VARIABLE(scales_b, scales_b)); } Status SubgroupMatrixMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const { @@ -502,10 +300,14 @@ 
Status SubgroupMatrixMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader if (has_zero_points_) { shader.AddInput("zero_points", ShaderUsage::UseUniform); } + if (has_bias_) { + shader.AddInput("bias", ShaderUsage::UseUniform); + } const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseElementTypeAlias); - if (!vendor_.compare("apple")) { - return GenerateShaderCodeOnApple(shader, a, b, scales_b, output, nbits_, has_zero_points_); + // TODO: add support for bias to the shader for Intel. In the meantime, use the shader for Metal + if (!vendor_.compare("apple") || has_bias_) { + return GenerateShaderCodeOnApple(shader, a, b, scales_b, output, nbits_, has_zero_points_, has_bias_, has_weight_idx_); } else if (!vendor_.compare("intel")) { return GenerateShaderCodeOnIntel(shader, b, scales_b, nbits_, config_index_, has_zero_points_); } else { @@ -515,7 +317,7 @@ Status SubgroupMatrixMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader } Status ApplySubgroupMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor* scales, - const Tensor* zero_points, + const Tensor* zero_points, const Tensor* bias, uint32_t M, uint32_t N, uint32_t K, @@ -523,7 +325,8 @@ Status ApplySubgroupMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Te uint32_t zero_blocks_per_col, int32_t config_index, onnxruntime::webgpu::ComputeContext& context, - Tensor* y) { + Tensor* y, + const uint32_t weight_index) { // If applicable, layout optimization of input matrix A(MxK) can be used for SubgroupMatrixLoad. Tensor a_prepack; if (context.AdapterInfo().vendor == std::string_view{"intel"}) { @@ -552,13 +355,16 @@ Status ApplySubgroupMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Te ORT_RETURN_IF_ERROR(context.RunProgram(prepack_program)); a = &a_prepack; } + uint32_t tile_size_a = 32; uint32_t work_group_size = 128; constexpr uint32_t kTileSizeB = 64; constexpr uint32_t kU32Components = 4; TensorShape y_shape{1, M, N}; const bool has_zero_points = zero_points != nullptr; - SubgroupMatrixMatMulNBitsProgram mul_program{nbits, config_index, context.AdapterInfo().vendor, has_zero_points}; + const bool has_bias = bias != nullptr; + const bool has_weight_idx = weight_index > 0; + SubgroupMatrixMatMulNBitsProgram mul_program{nbits, config_index, context.AdapterInfo().vendor, has_zero_points, has_bias, has_weight_idx}; if (context.AdapterInfo().vendor == std::string_view{"intel"}) { tile_size_a = 64; work_group_size = 256; @@ -570,12 +376,15 @@ Status ApplySubgroupMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Te mul_program.AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, 1}, {b, ProgramTensorMetadataDependency::TypeAndRank, static_cast(nbits == 4 ? 
kU32Components : 2 * kU32Components)}, {scales, ProgramTensorMetadataDependency::TypeAndRank, 1}}) - .AddUniformVariables({{M}, {N}, {K}, {zero_blocks_per_col}}) + .AddUniformVariables({{M}, {N}, {K}, {zero_blocks_per_col}, {weight_index}}) .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, y_shape, 1}) - .CacheHint(nbits, has_zero_points); + .CacheHint(nbits, has_zero_points, has_bias, has_weight_idx); if (has_zero_points) { mul_program.AddInput({zero_points, ProgramTensorMetadataDependency::None, {(zero_points->Shape().Size() + 3) / 4}, 4}); } + if (bias) { + mul_program.AddInput({bias, ProgramTensorMetadataDependency::None}); + } return context.RunProgram(mul_program); } diff --git a/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.h b/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.h index cf1fb9a6f7f15..cb9bd8a599f54 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.h +++ b/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.h @@ -21,27 +21,33 @@ using namespace onnxruntime::webgpu; class SubgroupMatrixMatMulNBitsProgram final : public Program { public: - SubgroupMatrixMatMulNBitsProgram(uint32_t nbits, int32_t config_index, const wgpu::StringView& vendor, bool has_zero_points) : Program{"SubgroupMatrixMatMulNBits"}, - nbits_(nbits), - config_index_(config_index), - vendor_(vendor), - has_zero_points_(has_zero_points) {} + SubgroupMatrixMatMulNBitsProgram(uint32_t nbits, int32_t config_index, const wgpu::StringView& vendor, bool has_zero_points, bool has_bias, bool has_weight_idx) + : Program{"SubgroupMatrixMatMulNBits"}, + nbits_(nbits), + config_index_(config_index), + vendor_(vendor), + has_zero_points_(has_zero_points), + has_bias_(has_bias), + has_weight_idx_{has_weight_idx} {}; Status GenerateShaderCode(ShaderHelper& sh) const override; WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES( {"M", ProgramUniformVariableDataType::Uint32}, {"N", ProgramUniformVariableDataType::Uint32}, {"K", ProgramUniformVariableDataType::Uint32}, - {"zero_blocks_per_col", ProgramUniformVariableDataType::Uint32}); + {"zero_blocks_per_col", ProgramUniformVariableDataType::Uint32}, + {"weight_idx", ProgramUniformVariableDataType::Uint32}); private: uint32_t nbits_; int32_t config_index_; std::string vendor_; bool has_zero_points_; + bool has_bias_; + bool has_weight_idx_; }; Status ApplySubgroupMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor* scales, - const Tensor* zero_points, + const Tensor* zero_points, const Tensor* bias, uint32_t M, uint32_t N, uint32_t K, @@ -49,7 +55,8 @@ Status ApplySubgroupMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Te uint32_t zero_blocks_per_col, int32_t config_index, onnxruntime::webgpu::ComputeContext& context, - Tensor* y); + Tensor* y, + const uint32_t weight_index); bool CanApplySubgroupMatrixMatMulNBits(onnxruntime::webgpu::ComputeContext& context, uint64_t accuracy_level, diff --git a/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits_apple.wgsl.template b/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits_apple.wgsl.template new file mode 100644 index 0000000000000..b3c14ff6e1eae --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits_apple.wgsl.template @@ -0,0 +1,220 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
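+//
+// This template holds the Apple subgroup-matrix kernel that previously lived inline
+// in subgroup_matrix_matmul_nbits.cc, extended with optional bias and stacked-weight
+// (weight_idx) addressing.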
+ +// tile/subtile sizes and work distribution are inspired from metal shaders in llama.cpp (kernel_mul_mm) +// https://github.com/ggml-org/llama.cpp/blob/d04e7163c85a847bc61d58c22f2c503596db7aa8/ggml/src/ggml-metal/ggml-metal.metal#L6066 + +#param n_bits +#param has_zero_points +#param has_bias +#param has_weight_idx + +#use .getByOffset .setByOffset + +#include "quantization/matmul_nbits_zero_pt.wgsl.template" + +const tile_cols = 64; +const tile_rows = 32; +const tile_k = 32; +const subtile_cols = 32; +const subtile_rows = 16; +const quantization_block_size = 32; +alias compute_precision = output_element_t; + +var tile_A: array; // 32 x 32 - RxC +var tile_B: array; // 64 x 32 - RxC +var scratch: array, 4>, 4>; // 64 * 4 * 4 + +fn loadSHMA(tile_base: u32, k_idx: u32, row: u32, c_idx:u32) { + let a_global = tile_base + row; + if (a_global >= uniforms.M) { + return; + } + // Each call loads 8 columns, starting at col. + var col = c_idx * 8; + // 128 threads need to load 32 x 32. 4 threads per row or 8 col per thread. + for (var col_offset:u32 = 0; col_offset < 8; col_offset++) { + tile_A[row * tile_k + col + col_offset] = compute_precision(a.getByOffset(a_global * uniforms.K + k_idx + col + col_offset)); + } +} + +#if n_bits == 4 +fn loadSHMB(tile_base: u32, k_idx: u32, row: u32, c_idx: u32) { + let b_global = tile_base + row; + if (b_global >= uniforms.N) { + return; + } + // Each call loads 16 columns, starting at col. + var col = c_idx * 16; + // 128 threads need to load 64 x 32. 2 threads per row or 16 col per thread. + // Stored in column major fashion. +#if has_weight_idx + let b_base_offset = uniforms.weight_idx * uniforms.K * uniforms.N; + let scale_offset = uniforms.weight_idx * uniforms.N * (uniforms.K / quantization_block_size); +#else + const b_base_offset : u32 = 0; + const scale_offset : u32 = 0; +#endif + let b_idx = u32((b_global*uniforms.K + k_idx + col)/8) + b_base_offset / 8; + let scale = compute_precision(scales_b.getByOffset((b_global * uniforms.K + k_idx + col) / quantization_block_size + scale_offset)); + let zero = mm_read_zero(b_global, (k_idx + col) / quantization_block_size, uniforms.N, uniforms.zero_blocks_per_col); + for (var step:u32 = 0; step < 2; step++) { + var b_value = b.getByOffset(b_idx+step); + var b_value_lower = (vec4(unpack4xU8(b_value & 0x0F0F0F0Fu)) - vec4(zero)) * scale; + var b_value_upper = (vec4(unpack4xU8((b_value >> 4) & 0x0F0F0F0Fu)) - vec4(zero)) * scale; + let tile_b_base = row * tile_k + col + step * 8; + tile_B[tile_b_base] = b_value_lower[0]; + tile_B[tile_b_base + 1] = b_value_upper[0]; + tile_B[tile_b_base + 2] = b_value_lower[1]; + tile_B[tile_b_base + 3] = b_value_upper[1]; + tile_B[tile_b_base + 4] = b_value_lower[2]; + tile_B[tile_b_base + 5] = b_value_upper[2]; + tile_B[tile_b_base + 6] = b_value_lower[3]; + tile_B[tile_b_base + 7] = b_value_upper[3]; + } +} +#endif + +#if n_bits == 8 +fn loadSHMB(tile_base: u32, k_idx: u32, row: u32, c_idx: u32) { + let b_global = tile_base + row; + if (b_global >= uniforms.N) { + return; + } + // Each call loads 16 columns, starting at col. + var col = c_idx * 16; + // 128 threads need to load 64 x 32. 2 threads per row or 16 col per thread. + // Stored in column major fashion. 
+#if has_weight_idx + let b_base_offset = uniforms.weight_idx * uniforms.K * uniforms.N; + let scale_offset = uniforms.weight_idx * uniforms.N * (uniforms.K / quantization_block_size); +#else + const b_base_offset : u32 = 0; + const scale_offset : u32 = 0; +#endif + let b_idx = u32((b_global*uniforms.K + k_idx + col)/8) + b_base_offset / 8; + let scale = compute_precision(scales_b.getByOffset((b_global * uniforms.K + k_idx + col) / quantization_block_size + scale_offset)); + let zero = mm_read_zero(b_global, (k_idx + col) / quantization_block_size, uniforms.N, uniforms.zero_blocks_per_col); + for (var step : u32 = 0; step < 2; step++) { + var b_value = b.getByOffset(b_idx+step); + var b_value0 = (vec4(unpack4xU8(b_value[0])) - vec4(zero)) * scale; + var b_value1 = (vec4(unpack4xU8(b_value[1])) - vec4(zero)) * scale; + let tile_b_base = row * tile_k + col + step * 8; + tile_B[tile_b_base] = b_value0[0]; + tile_B[tile_b_base + 1] = b_value0[1]; + tile_B[tile_b_base + 2] = b_value0[2]; + tile_B[tile_b_base + 3] = b_value0[3]; + tile_B[tile_b_base + 4] = b_value1[0]; + tile_B[tile_b_base + 5] = b_value1[1]; + tile_B[tile_b_base + 6] = b_value1[2]; + tile_B[tile_b_base + 7] = b_value1[3]; + } +} +#endif + +fn storeOutput(offset:u32, row: u32, col:u32, src_slot:u32, row_limit:i32) { + if (row_limit > 0 && row < u32(row_limit)) { + let col2 = col + 1; +#if has_bias + let col_base = offset % uniforms.N + uniforms.weight_idx * uniforms.N; + + output.setByOffset(offset + row * uniforms.N + col, output_element_t(scratch[src_slot][0][row * 8 + col]) + bias[col_base + col]); + output.setByOffset(offset + row * uniforms.N + col + 8, output_element_t(scratch[src_slot][1][row * 8 + col]) + bias[col_base + col + 8]); + output.setByOffset(offset + row * uniforms.N + col + 16, output_element_t(scratch[src_slot][2][row * 8 + col]) + bias[col_base + col + 16]); + output.setByOffset(offset + row * uniforms.N + col + 24, output_element_t(scratch[src_slot][3][row * 8 + col]) + bias[col_base + col + 24]); + + output.setByOffset(offset + row * uniforms.N + col2, output_element_t(scratch[src_slot][0][row * 8 + col2]) + bias[col_base + col2]); + output.setByOffset(offset + row * uniforms.N + col2 + 8, output_element_t(scratch[src_slot][1][row * 8 + col2]) + bias[col_base + col2 + 8]); + output.setByOffset(offset + row * uniforms.N + col2 + 16, output_element_t(scratch[src_slot][2][row * 8 + col2]) + bias[col_base + col2 + 16]); + output.setByOffset(offset + row * uniforms.N + col2 + 24, output_element_t(scratch[src_slot][3][row * 8 + col2]) + bias[col_base + col2 + 24]); +#else + output.setByOffset(offset + row * uniforms.N + col, output_element_t(scratch[src_slot][0][row * 8 + col])); + output.setByOffset(offset + row * uniforms.N + col + 8, output_element_t(scratch[src_slot][1][row * 8 + col])); + output.setByOffset(offset + row * uniforms.N + col + 16, output_element_t(scratch[src_slot][2][row * 8 + col])); + output.setByOffset(offset + row * uniforms.N + col + 24, output_element_t(scratch[src_slot][3][row * 8 + col])); + output.setByOffset(offset + row * uniforms.N + col2, output_element_t(scratch[src_slot][0][row * 8 + col2])); + output.setByOffset(offset + row * uniforms.N + col2 + 8, output_element_t(scratch[src_slot][1][row * 8 + col2])); + output.setByOffset(offset + row * uniforms.N + col2 + 16, output_element_t(scratch[src_slot][2][row * 8 + col2])); + output.setByOffset(offset + row * uniforms.N + col2 + 24, output_element_t(scratch[src_slot][3][row * 8 + col2])); +#endif + } +} + +$MAIN { + let 
a_global_base = workgroup_id.y * tile_rows; + let b_global_base = workgroup_id.x * tile_cols; + + let subtile_id = u32(local_idx / sg_size); + let subtile_idx = u32(subtile_id / 2); + let subtile_idy = subtile_id % 2; + let base_A = subtile_idy * subtile_rows; + let base_B = subtile_idx * subtile_cols; + + var matC00: subgroup_matrix_result; + var matC01: subgroup_matrix_result; + var matC02: subgroup_matrix_result; + var matC03: subgroup_matrix_result; + var matC10: subgroup_matrix_result; + var matC11: subgroup_matrix_result; + var matC12: subgroup_matrix_result; + var matC13: subgroup_matrix_result; + for (var kidx: u32 = 0; kidx < uniforms.K; kidx += tile_k) { + // Load Phase + loadSHMA(a_global_base, kidx, local_idx/4, local_idx%4); + loadSHMB(b_global_base, kidx, local_idx/2, local_idx%2); + workgroupBarrier(); + + for (var step: u32 = 0; step < tile_k; step+=8){ + // Load to local memory phase + let matrix_a_offset = subtile_idy * subtile_rows * tile_k + step; + // Syntax: subgroupMatrixLoad src_ptr,src_offset,is_col_major,src_stride + var matA0: subgroup_matrix_left = subgroupMatrixLoad>(&tile_A, matrix_a_offset, false, tile_k); + var matA1: subgroup_matrix_left = subgroupMatrixLoad>(&tile_A, matrix_a_offset + 8 * tile_k, false, tile_k); + + // tile_B is stored as column major. + // [col0-0:32][col1-0:32][col2-0:32]..[col63-0:32] + var matrix_b_offset = subtile_idx * subtile_cols * tile_k + step; + var matB0: subgroup_matrix_right = subgroupMatrixLoad>(&tile_B, matrix_b_offset, true, tile_k); + var matB1: subgroup_matrix_right = subgroupMatrixLoad>(&tile_B, matrix_b_offset + 8 * tile_k, true, tile_k); + var matB2: subgroup_matrix_right = subgroupMatrixLoad>(&tile_B, matrix_b_offset + 16 * tile_k, true, tile_k); + var matB3: subgroup_matrix_right = subgroupMatrixLoad>(&tile_B, matrix_b_offset + 24 * tile_k, true, tile_k); + + // Compute Phase + // Syntax: subgroupMatrixMultiplyAccumulate left, right, accumulate -> accumulate + matC00 = subgroupMatrixMultiplyAccumulate(matA0, matB0, matC00); + matC01 = subgroupMatrixMultiplyAccumulate(matA0, matB1, matC01); + matC02 = subgroupMatrixMultiplyAccumulate(matA0, matB2, matC02); + matC03 = subgroupMatrixMultiplyAccumulate(matA0, matB3, matC03); + + matC10 = subgroupMatrixMultiplyAccumulate(matA1, matB0, matC10); + matC11 = subgroupMatrixMultiplyAccumulate(matA1, matB1, matC11); + matC12 = subgroupMatrixMultiplyAccumulate(matA1, matB2, matC12); + matC13 = subgroupMatrixMultiplyAccumulate(matA1, matB3, matC13); + } + workgroupBarrier(); + } + + // Write out + // Write out top block + subgroupMatrixStore(&scratch[subtile_id][0], 0, matC00, false, 8); + subgroupMatrixStore(&scratch[subtile_id][1], 0, matC01, false, 8); + subgroupMatrixStore(&scratch[subtile_id][2], 0, matC02, false, 8); + subgroupMatrixStore(&scratch[subtile_id][3], 0, matC03, false, 8); + workgroupBarrier(); + let row = u32(sg_id / 4); + var col = u32(sg_id % 4) * 2; + var matrix_c_offset = (a_global_base+base_A) * uniforms.N + b_global_base + base_B; + var row_limit:i32 = i32(uniforms.M) - i32(a_global_base + base_A); + storeOutput(matrix_c_offset, row, col, subtile_id, row_limit); + workgroupBarrier(); + + // Write out bottom block + subgroupMatrixStore(&scratch[subtile_id][0], 0, matC10, false, 8); + subgroupMatrixStore(&scratch[subtile_id][1], 0, matC11, false, 8); + subgroupMatrixStore(&scratch[subtile_id][2], 0, matC12, false, 8); + subgroupMatrixStore(&scratch[subtile_id][3], 0, matC13, false, 8); + workgroupBarrier(); + matrix_c_offset = matrix_c_offset + 8 * 
uniforms.N;
+  row_limit = i32(uniforms.M) - i32(a_global_base + base_A + 8);
+  storeOutput(matrix_c_offset, row, col, subtile_id, row_limit);
+} // MAIN
diff --git a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc
index cc0e3207e6795..f5e741b1c6f8d 100644
--- a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc
+++ b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc
@@ -342,6 +342,14 @@ void TestMatMulNBitsTyped(std::optional<float> abs_error = std::nullopt,
     RunTest(opts);
   }
 #endif  // !defined(USE_DML) && !defined(USE_WEBGPU)
+#if defined(USE_WEBGPU)
+  {
+    // WebGPU supports bias but not g_idx.
+    TestOptions opts = base_opts;
+    opts.has_bias = true;
+    RunTest(opts);
+  }
+#endif
 }
 
 #if !defined(USE_OPENVINO)
diff --git a/onnxruntime/test/contrib_ops/matmul_8bits_test.cc b/onnxruntime/test/contrib_ops/matmul_8bits_test.cc
index b336debecef94..652d3246ce462 100644
--- a/onnxruntime/test/contrib_ops/matmul_8bits_test.cc
+++ b/onnxruntime/test/contrib_ops/matmul_8bits_test.cc
@@ -249,7 +249,7 @@ void TestMatMul8BitsTyped(float abs_error = 0.1f, float rel_error = 0.02f) {
   }
 
-// CUDA/WEBGPU does not support bias for MatMulNBits
-#if !defined(USE_CUDA) && !defined(USE_WEBGPU)
+// CUDA does not support bias for MatMulNBits
+#if !defined(USE_CUDA)
   {
     TestOptions8Bits opts = base_opts;
     opts.has_zero_point = false;
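
For reference, a minimal host-side sketch of the stacked-weight addressing that this change
threads through the shader templates (the helper and its names are illustrative only, not part
of the patch): for a given weight_index, the packed weights start at weight_index * N * K_of_b
elements, the scales at weight_index * N * n_blocks_per_col, and the bias at weight_index * N,
matching the b_base_offset / b_scale_offset / b_bias_offset expressions added above.

    #include <cstdint>

    // Hypothetical helper mirroring the per-template offset expressions.
    struct StackedWeightOffsets {
      uint32_t weights;  // element offset into the packed/quantized B buffer
      uint32_t scales;   // element offset into the scales buffer
      uint32_t bias;     // element offset into the bias buffer
    };

    inline StackedWeightOffsets ComputeStackedWeightOffsets(uint32_t weight_index, uint32_t N,
                                                            uint32_t K_of_b, uint32_t n_blocks_per_col) {
      return StackedWeightOffsets{weight_index * N * K_of_b,            // b_base_offset
                                  weight_index * N * n_blocks_per_col,  // b_scale_offset
                                  weight_index * N};                    // b_bias_offset
    }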