From a1e2aa2a444c8104d0f7e4334a0c6d57b03c5dc6 Mon Sep 17 00:00:00 2001 From: ssjia Date: Mon, 2 Mar 2026 13:03:31 -0800 Subject: [PATCH] [ET-VK][ez] Use tree reduction in q8ta_linear_gemv shader Replace the serial O(WGS) reduction loop with a tree reduction pattern (O(log2(WGS))). Previously, only thread 0 summed all 64 partial accumulators sequentially. Now all threads participate in a classic halving reduction, matching the pattern already used in linear_q4gsw_coop.glsl. Authored by Claude. Differential Revision: [D94949137](https://our.internmc.facebook.com/intern/diff/D94949137/) [ghstack-poisoned] --- .../graph/ops/glsl/q8ta_linear_gemv.glsl | 29 ++++++++++++++----- 1 file changed, 22 insertions(+), 7 deletions(-) diff --git a/backends/vulkan/runtime/graph/ops/glsl/q8ta_linear_gemv.glsl b/backends/vulkan/runtime/graph/ops/glsl/q8ta_linear_gemv.glsl index 241fc1845bf..becff2ab9ab 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/q8ta_linear_gemv.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/q8ta_linear_gemv.glsl @@ -66,7 +66,12 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; #include "linear_int_weight_sums_load.glslh" #include "linear_fp_bias_load.glslh" -shared Int32Accum partial_accums[WGS]; +// Array-of-arrays shared memory layout: partial_accums[lid][tile_n4]. +// Each element is exactly 16 bytes (ivec4). This avoids the Samsung S25 +// (Adreno 830) driver bug triggered by the original Int32Accum struct layout, +// where barrier() only invalidated the first 16-byte component of each +// 32-byte struct slot, leaving subsequent components stale. +shared ivec4 partial_accums[WGS][TILE_N4]; void main() { const int lid = int(gl_LocalInvocationID.z); @@ -104,19 +109,29 @@ void main() { } } - partial_accums[lid] = out_accum; + [[unroll]] for (int tile_n4 = 0; tile_n4 < TILE_N4; ++tile_n4) { + partial_accums[lid][tile_n4] = out_accum.data[0][tile_n4]; + } memoryBarrierShared(); barrier(); - // Only the first thread writes the result - if (lid == 0) { - for (int i = 1; i < WGS; ++i) { + // Tree reduction: O(log2(WGS)). + for (int i = WGS / 2; i > 0; i /= 2) { + if (lid < i) { [[unroll]] for (int tile_n4 = 0; tile_n4 < TILE_N4; ++tile_n4) { - out_accum.data[0][tile_n4] += - partial_accums[i].data[0][tile_n4]; + partial_accums[lid][tile_n4] += partial_accums[lid + i][tile_n4]; } } + memoryBarrierShared(); + barrier(); + } + + // Only the first thread writes the result + if (lid == 0) { + [[unroll]] for (int tile_n4 = 0; tile_n4 < TILE_N4; ++tile_n4) { + out_accum.data[0][tile_n4] = partial_accums[0][tile_n4]; + } FPPerOutChannelParams weight_scales_tile; load_weight_scales_tile(weight_scales_tile, n4);