From 85caa29d934af46192aeb590fed2d5371754e3ce Mon Sep 17 00:00:00 2001
From: Aryaman Gupta
Date: Fri, 12 Dec 2025 15:09:14 +0000
Subject: [PATCH 1/6] adds optimized path for small dimension sizes to group_index_select_or_add_2d_kernel

---
 .../src/sparse_ops/sparse_group_index.cu     | 76 ++++++++++++++-----
 fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp | 14 +++-
 2 files changed, 70 insertions(+), 20 deletions(-)

diff --git a/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu b/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu
index 96c57cde68..a0584b23de 100644
--- a/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu
+++ b/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu
@@ -83,25 +83,65 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel(
     member_id = warp_id / (warps_per_row * num_work_rows);
     member_warp_id = warp_id - (member_id * warps_per_row * num_work_rows);
   }
-  const auto row = member_warp_id / warps_per_row;
-  const auto col_offset =
-      ((member_warp_id % warps_per_row) << LOG_COLS_PER_WARP) +
-      (threadIdx.x * UNROLL_FACTOR);
-  scalar_t* input =
-      reinterpret_cast<scalar_t*>(input_ptrs[member_id]) + col_offset;
-  scalar_t* output =
-      reinterpret_cast<scalar_t*>(output_ptrs[member_id]) + col_offset;
-
-  index_t* indices = reinterpret_cast<index_t*>(indices_ptrs[member_id]);
-  const index_t idx = indices[row];
+
+  if (num_cols < COLS_PER_WARP) {
+    // Optimized path for small embedding dimensions
+    // Each warp processes 'rows_per_warp' rows
+    int rows_per_warp = COLS_PER_WARP / num_cols;
+    int64_t start_row = member_warp_id * rows_per_warp;
+
+    // Since we are processing multiple rows within the warp, we need to
+    // map each lane to a specific row, in addition to the column
+    int local_row = (threadIdx.x * UNROLL_FACTOR) / num_cols; // the row ID within the set of rows handled by this warp
+    int col = (threadIdx.x * UNROLL_FACTOR) % num_cols;
+    int64_t current_row = start_row + local_row; // the actual row within the table processed by this lane
+
+    // local_row may be out of bounds for the last few lanes in the warp
+    // if [COLS_PER_WARP % num_cols != 0]
+    // TODO: check if current_row < num_work_rows is necessary
+    if (local_row < rows_per_warp && current_row < num_work_rows) {
+      index_t* indices = reinterpret_cast<index_t*>(indices_ptrs[member_id]);
+      index_t idx = indices[current_row];
+
+      scalar_t* input_base = reinterpret_cast<scalar_t*>(input_ptrs[member_id]);
+      scalar_t* output_base = reinterpret_cast<scalar_t*>(output_ptrs[member_id]);
+
+#pragma unroll
+      for (int i = 0; i < UNROLL_FACTOR && col + i < num_cols; i++) {
+        if constexpr (USE_INDEX_SELECT) {
+          output_base[current_row * num_cols + col] =
+              LDG(&input_base[idx * num_cols + col]);
+        } else {
+          gpuAtomicAddNoReturn(
+              &output_base[idx * num_cols + col],
+              input_base[current_row * num_cols + col]);
+        }
+      }
+    }
+  } else {
+    // Large embedding dimensions use >= 1 warp per row
+
+    const auto row = member_warp_id / warps_per_row;
+    const auto col_offset =
+        ((member_warp_id % warps_per_row) << LOG_COLS_PER_WARP) +
+        (threadIdx.x * UNROLL_FACTOR);
+
+    scalar_t* input =
+        reinterpret_cast<scalar_t*>(input_ptrs[member_id]) + col_offset;
+    scalar_t* output =
+        reinterpret_cast<scalar_t*>(output_ptrs[member_id]) + col_offset;
+
+    index_t* indices = reinterpret_cast<index_t*>(indices_ptrs[member_id]);
+    const index_t idx = indices[row];
+
 #pragma unroll
-  for (int i = 0; i < UNROLL_FACTOR && col_offset + i < num_cols; i++) {
-    // Compile time conditional
-    if constexpr (USE_INDEX_SELECT) {
-      output[row * num_cols + i] = LDG(&input[idx * num_cols + i]);
-    } else {
-      gpuAtomicAddNoReturn(
-          &output[idx * num_cols + i], input[row * num_cols + i]);
+    for (int i = 0; i < UNROLL_FACTOR && col_offset + i < num_cols; i++) {
+      if constexpr (USE_INDEX_SELECT) {
+        output[row * num_cols + i] = LDG(&input[idx * num_cols + i]);
+      } else {
+        gpuAtomicAddNoReturn(
+            &output[idx * num_cols + i], input[row * num_cols + i]);
+      }
     }
   }
 }
diff --git a/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp b/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp
index 9e8587b8d1..bdd0f13652 100644
--- a/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp
+++ b/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp
@@ -303,7 +303,17 @@ static torch::autograd::variable_list group_index_select_dim0_forward_impl_gpu(
 
     // Number of columns can be different
     auto num_cols_ = input_reshaped_.size(1);
-    auto warps_per_row = (num_cols_ + cols_per_warp - 1) / cols_per_warp;
+
+    int64_t warps_needed;
+    if (num_cols_ < cols_per_warp) {
+      // Optimization: Pack multiple rows into one warp
+      int rows_per_warp = cols_per_warp / num_cols_;
+      warps_needed = (num_output_rows_ + rows_per_warp - 1) / rows_per_warp;
+    } else {
+      // Standard: One or more warps per row
+      int warps_per_row = (num_cols_ + cols_per_warp - 1) / cols_per_warp;
+      warps_needed = warps_per_row * num_output_rows_;
+    }
 
     if (num_cols != num_cols_) {
       use_var_cols = true;
@@ -329,7 +339,7 @@ static torch::autograd::variable_list group_index_select_dim0_forward_impl_gpu(
 
     warp_offsets_group[i] = warp_offset;
     num_cols_group[i] = num_cols_;
-    warp_offset += warps_per_row * num_output_rows;
+    warp_offset += warps_needed;
   }
 
   // Store the last offset

From ff1b9b6c70f9483cf2ed23e3afbd76c3a599cf9e Mon Sep 17 00:00:00 2001
From: Aryaman Gupta
Date: Tue, 16 Dec 2025 17:27:58 +0000
Subject: [PATCH 2/6] sparse_group_index.cu: edits some comments

---
 fbgemm_gpu/src/sparse_ops/sparse_group_index.cu | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu b/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu
index a0584b23de..a8eb081610 100644
--- a/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu
+++ b/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu
@@ -96,9 +96,8 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel(
     int col = (threadIdx.x * UNROLL_FACTOR) % num_cols;
     int64_t current_row = start_row + local_row; // the actual row within the table processed by this lane
 
-    // local_row may be out of bounds for the last few lanes in the warp
-    // if [COLS_PER_WARP % num_cols != 0]
-    // TODO: check if current_row < num_work_rows is necessary
+    // local_row may be out of bounds for the last few lanes in the warp if [COLS_PER_WARP % num_cols != 0]
+    // and we also need to confirm that we are within num_work_rows
     if (local_row < rows_per_warp && current_row < num_work_rows) {
       index_t* indices = reinterpret_cast<index_t*>(indices_ptrs[member_id]);
       index_t idx = indices[current_row];

From 439a51a567ea6274dbffa4d2614274bc2f82d9e5 Mon Sep 17 00:00:00 2001
From: Aryaman Gupta
Date: Tue, 16 Dec 2025 18:28:08 +0000
Subject: [PATCH 3/6] adds USE_ROCM guards to subwarp optimizations for group_index_select_or_add_2d_kernel

---
 fbgemm_gpu/src/sparse_ops/sparse_group_index.cu | 10 +++++++---
 fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp    |  9 +++++++++
 2 files changed, 16 insertions(+), 3 deletions(-)

diff --git a/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu b/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu
index a8eb081610..2c54c7bab1 100644
--- a/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu
+++ b/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu
@@ -84,6 +84,7 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel(
     member_warp_id = warp_id - (member_id * warps_per_row * num_work_rows);
   }
 
+#ifdef USE_ROCM
   if (num_cols < COLS_PER_WARP) {
     // Optimized path for small embedding dimensions
     // Each warp processes 'rows_per_warp' rows
@@ -107,6 +108,7 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel(
 
 #pragma unroll
       for (int i = 0; i < UNROLL_FACTOR && col + i < num_cols; i++) {
+        // Compile time conditional
         if constexpr (USE_INDEX_SELECT) {
           output_base[current_row * num_cols + col] =
               LDG(&input_base[idx * num_cols + col]);
@@ -119,12 +121,12 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel(
     }
   } else {
     // Large embedding dimensions use >= 1 warp per row
-
+    // which is the default codepath for non-ROCm as well
+#endif // USE_ROCM
     const auto row = member_warp_id / warps_per_row;
     const auto col_offset =
         ((member_warp_id % warps_per_row) << LOG_COLS_PER_WARP) +
         (threadIdx.x * UNROLL_FACTOR);
-
     scalar_t* input =
         reinterpret_cast<scalar_t*>(input_ptrs[member_id]) + col_offset;
     scalar_t* output =
@@ -132,9 +134,9 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel(
 
     index_t* indices = reinterpret_cast<index_t*>(indices_ptrs[member_id]);
     const index_t idx = indices[row];
-
 #pragma unroll
     for (int i = 0; i < UNROLL_FACTOR && col_offset + i < num_cols; i++) {
+      // Compile time conditional
       if constexpr (USE_INDEX_SELECT) {
         output[row * num_cols + i] = LDG(&input[idx * num_cols + i]);
       } else {
@@ -142,7 +144,9 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel(
             &output[idx * num_cols + i], input[row * num_cols + i]);
       }
     }
+#ifdef USE_ROCM
   }
+#endif // USE_ROCM
 }
 
diff --git a/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp b/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp
index bdd0f13652..d592cce6a9 100644
--- a/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp
+++ b/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp
@@ -304,6 +304,7 @@ static torch::autograd::variable_list group_index_select_dim0_forward_impl_gpu(
     // Number of columns can be different
     auto num_cols_ = input_reshaped_.size(1);
 
+#ifdef USE_ROCM
     int64_t warps_needed;
     if (num_cols_ < cols_per_warp) {
       // Optimization: Pack multiple rows into one warp
@@ -314,6 +315,10 @@ static torch::autograd::variable_list group_index_select_dim0_forward_impl_gpu(
       int warps_per_row = (num_cols_ + cols_per_warp - 1) / cols_per_warp;
       warps_needed = warps_per_row * num_output_rows_;
     }
+#else
+    // Standard: One or more warps per row
+    auto warps_per_row = (num_cols_ + cols_per_warp - 1) / cols_per_warp;
+#endif // USE_ROCM
 
     if (num_cols != num_cols_) {
       use_var_cols = true;
@@ -339,7 +344,11 @@ static torch::autograd::variable_list group_index_select_dim0_forward_impl_gpu(
 
     warp_offsets_group[i] = warp_offset;
     num_cols_group[i] = num_cols_;
+#ifdef USE_ROCM
     warp_offset += warps_needed;
+#else
+    warp_offset += warps_per_row * num_output_rows;
+#endif // USE_ROCM
   }
 
   // Store the last offset

From 2a85d73f669959ff86859c216edae4b41db4f53b Mon Sep 17 00:00:00 2001
From: Aryaman Gupta
Date: Thu, 18 Dec 2025 10:11:29 +0000
Subject: [PATCH 4/6] sparse_group_index: handle UNROLL_FACTOR for small dimensions in group_index_select_or_add_2d_kernel

---
 .../src/sparse_ops/sparse_group_index.cu | 21 +++++++++----------
 1 file changed, 10 insertions(+), 11 deletions(-)

diff --git a/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu b/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu
index 2c54c7bab1..5e87d96961 100644
--- a/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu
+++ b/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu
@@ -94,28 +94,27 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel(
     // Since we are processing multiple rows within the warp, we need to
     // map each lane to a specific row, in addition to the column
    int local_row = (threadIdx.x * UNROLL_FACTOR) / num_cols; // the row ID within the set of rows handled by this warp
-    int col = (threadIdx.x * UNROLL_FACTOR) % num_cols;
+    int col_offset = (threadIdx.x * UNROLL_FACTOR) % num_cols;
     int64_t current_row = start_row + local_row; // the actual row within the table processed by this lane
 
     // local_row may be out of bounds for the last few lanes in the warp if [COLS_PER_WARP % num_cols != 0]
     // and we also need to confirm that we are within num_work_rows
     if (local_row < rows_per_warp && current_row < num_work_rows) {
-      index_t* indices = reinterpret_cast<index_t*>(indices_ptrs[member_id]);
-      index_t idx = indices[current_row];
-
-      scalar_t* input_base = reinterpret_cast<scalar_t*>(input_ptrs[member_id]);
-      scalar_t* output_base = reinterpret_cast<scalar_t*>(output_ptrs[member_id]);
+      scalar_t* input =
+          reinterpret_cast<scalar_t*>(input_ptrs[member_id]) + col_offset;
+      scalar_t* output =
+          reinterpret_cast<scalar_t*>(output_ptrs[member_id]) + col_offset;
 
+      index_t* indices = reinterpret_cast<index_t*>(indices_ptrs[member_id]);
+      const index_t idx = indices[current_row];
 #pragma unroll
-      for (int i = 0; i < UNROLL_FACTOR && col + i < num_cols; i++) {
+      for (int i = 0; i < UNROLL_FACTOR && col_offset + i < num_cols; i++) {
         // Compile time conditional
         if constexpr (USE_INDEX_SELECT) {
-          output_base[current_row * num_cols + col] =
-              LDG(&input_base[idx * num_cols + col]);
+          output[current_row * num_cols + i] = LDG(&input[idx * num_cols + i]);
         } else {
           gpuAtomicAddNoReturn(
-              &output_base[idx * num_cols + col],
-              input_base[current_row * num_cols + col]);
+              &output[idx * num_cols + i], input[current_row * num_cols + i]);
         }
       }
     }

From 2f541407b26c53140e857fee4e3cbc217beb5864 Mon Sep 17 00:00:00 2001
From: Aryaman Gupta
Date: Thu, 18 Dec 2025 13:26:01 +0000
Subject: [PATCH 5/6] sparse_group_index: handle fixed-column-size case correctly in optimized small embedding dims path

---
 fbgemm_gpu/src/sparse_ops/sparse_group_index.cu | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu b/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu
index 5e87d96961..3f191c6100 100644
--- a/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu
+++ b/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu
@@ -82,6 +82,16 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel(
     // All columns are the same
     member_id = warp_id / (warps_per_row * num_work_rows);
     member_warp_id = warp_id - (member_id * warps_per_row * num_work_rows);
+#ifdef USE_ROCM
+    if (num_cols < COLS_PER_WARP) {
+      // Need to ensure that [member_id] and [member_warp_id] are calculated correctly
+      // for the small embedding dimension path below
+      int rows_per_warp = COLS_PER_WARP / num_cols;
+      auto warps_per_member = (num_work_rows + rows_per_warp - 1) / rows_per_warp;
+      member_id = warp_id / warps_per_member;
+      member_warp_id = warp_id % warps_per_member;
+    }
+#endif // USE_ROCM
   }
 
 #ifdef USE_ROCM

From e0edc4095377302c84f9450ec400800c5a98dcba Mon Sep 17 00:00:00 2001
From: Aryaman Gupta
Date: Thu, 18 Dec 2025 14:27:06 +0000
Subject: [PATCH 6/6] group_index_select_or_add_2d_kernel: when num_cols < UNROLL_FACTOR, disable optimized smallEmbD path

---
 fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h      | 2 ++
 fbgemm_gpu/src/sparse_ops/sparse_group_index.cu | 8 ++++++--
 fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp    | 3 ++-
 3 files changed, 10 insertions(+), 3 deletions(-)

diff --git a/fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h b/fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h
index efebf3ac02..6f3f8c246c 100644
--- a/fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h
+++ b/fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h
@@ -1090,6 +1090,8 @@ void group_index_select_or_add_cuda(
 
 int get_group_index_select_cols_per_warp();
 
+int get_group_index_select_unroll_factor();
+
 std::vector<at::Tensor> jagged_index_select_2d(
     const at::Tensor& values,
     const at::Tensor& lengths,
diff --git a/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu b/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu
index 3f191c6100..12ed1045d4 100644
--- a/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu
+++ b/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu
@@ -33,6 +33,10 @@ int get_group_index_select_cols_per_warp() {
   return GROUP_INDEX_SELECT_COLS_PER_WARP;
 }
 
+int get_group_index_select_unroll_factor() {
+  return GROUP_INDEX_SELECT_UNROLL_FACTOR;
+}
+
 template <
     typename index_t,
     typename scalar_t,
@@ -83,7 +87,7 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel(
     member_id = warp_id / (warps_per_row * num_work_rows);
     member_warp_id = warp_id - (member_id * warps_per_row * num_work_rows);
 #ifdef USE_ROCM
-    if (num_cols < COLS_PER_WARP) {
+    if (num_cols < COLS_PER_WARP && num_cols >= UNROLL_FACTOR) {
       // Need to ensure that [member_id] and [member_warp_id] are calculated correctly
       // for the small embedding dimension path below
       int rows_per_warp = COLS_PER_WARP / num_cols;
@@ -95,7 +99,7 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel(
   }
 
 #ifdef USE_ROCM
-  if (num_cols < COLS_PER_WARP) {
+  if (num_cols < COLS_PER_WARP && num_cols >= UNROLL_FACTOR) {
     // Optimized path for small embedding dimensions
     // Each warp processes 'rows_per_warp' rows
     int rows_per_warp = COLS_PER_WARP / num_cols;
diff --git a/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp b/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp
index d592cce6a9..df3f49af5c 100644
--- a/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp
+++ b/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp
@@ -265,6 +265,7 @@ static torch::autograd::variable_list group_index_select_dim0_forward_impl_gpu(
   Tensor input_reshaped = first_input.reshape({num_input_rows, -1});
   const int num_cols = input_reshaped.size(1);
   const int cols_per_warp = get_group_index_select_cols_per_warp();
+  const int unroll_factor = get_group_index_select_unroll_factor();
   int64_t warp_offset = 0;
   bool use_var_cols = false;
 
@@ -306,7 +307,7 @@ static torch::autograd::variable_list group_index_select_dim0_forward_impl_gpu(
 
 #ifdef USE_ROCM
     int64_t warps_needed;
-    if (num_cols_ < cols_per_warp) {
+    if (num_cols_ < cols_per_warp && num_cols_ >= unroll_factor) {
       // Optimization: Pack multiple rows into one warp
       int rows_per_warp = cols_per_warp / num_cols_;
       warps_needed = (num_output_rows_ + rows_per_warp - 1) / rows_per_warp;