diff --git a/fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h b/fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h
index efebf3ac02..6f3f8c246c 100644
--- a/fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h
+++ b/fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h
@@ -1090,6 +1090,8 @@ void group_index_select_or_add_cuda(
 
 int get_group_index_select_cols_per_warp();
 
+int get_group_index_select_unroll_factor();
+
 std::vector<at::Tensor> jagged_index_select_2d(
     const at::Tensor& values,
     const at::Tensor& lengths,
diff --git a/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu b/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu
index 96c57cde68..12ed1045d4 100644
--- a/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu
+++ b/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu
@@ -33,6 +33,10 @@ int get_group_index_select_cols_per_warp() {
   return GROUP_INDEX_SELECT_COLS_PER_WARP;
 }
 
+int get_group_index_select_unroll_factor() {
+  return GROUP_INDEX_SELECT_UNROLL_FACTOR;
+}
+
 template <
     typename index_t,
     typename scalar_t,
@@ -82,28 +86,80 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel(
       // All columns are the same
      member_id = warp_id / (warps_per_row * num_work_rows);
       member_warp_id = warp_id - (member_id * warps_per_row * num_work_rows);
+#ifdef USE_ROCM
+      if (num_cols < COLS_PER_WARP && num_cols >= UNROLL_FACTOR) {
+        // Need to ensure that [member_id] and [member_warp_id] are calculated correctly
+        // for the small embedding dimension path below
+        int rows_per_warp = COLS_PER_WARP / num_cols;
+        auto warps_per_member = (num_work_rows + rows_per_warp - 1) / rows_per_warp;
+        member_id = warp_id / warps_per_member;
+        member_warp_id = warp_id % warps_per_member;
+      }
+#endif // USE_ROCM
     }
-    const auto row = member_warp_id / warps_per_row;
-    const auto col_offset =
-        ((member_warp_id % warps_per_row) << LOG_COLS_PER_WARP) +
-        (threadIdx.x * UNROLL_FACTOR);
-    scalar_t* input =
-        reinterpret_cast<scalar_t*>(input_ptrs[member_id]) + col_offset;
-    scalar_t* output =
-        reinterpret_cast<scalar_t*>(output_ptrs[member_id]) + col_offset;
-
-    index_t* indices = reinterpret_cast<index_t*>(indices_ptrs[member_id]);
-    const index_t idx = indices[row];
+
+#ifdef USE_ROCM
+    if (num_cols < COLS_PER_WARP && num_cols >= UNROLL_FACTOR) {
+      // Optimized path for small embedding dimensions
+      // Each warp processes 'rows_per_warp' rows
+      int rows_per_warp = COLS_PER_WARP / num_cols;
+      int64_t start_row = member_warp_id * rows_per_warp;
+
+      // Since we are processing multiple rows within the warp, we need to
+      // map each lane to a specific row, in addition to the column
+      int local_row = (threadIdx.x * UNROLL_FACTOR) / num_cols; // the row ID within the set of rows handled by this warp
+      int col_offset = (threadIdx.x * UNROLL_FACTOR) % num_cols;
+      int64_t current_row = start_row + local_row; // the actual row within the table processed by this lane
+
+      // local_row may be out of bounds for the last few lanes in the warp if [COLS_PER_WARP % num_cols != 0]
+      // and we also need to confirm that we are within num_work_rows
+      if (local_row < rows_per_warp && current_row < num_work_rows) {
+        scalar_t* input =
+            reinterpret_cast<scalar_t*>(input_ptrs[member_id]) + col_offset;
+        scalar_t* output =
+            reinterpret_cast<scalar_t*>(output_ptrs[member_id]) + col_offset;
+
+        index_t* indices = reinterpret_cast<index_t*>(indices_ptrs[member_id]);
+        const index_t idx = indices[current_row];
 #pragma unroll
-    for (int i = 0; i < UNROLL_FACTOR && col_offset + i < num_cols; i++) {
-      // Compile time conditional
-      if constexpr (USE_INDEX_SELECT) {
-        output[row * num_cols + i] = LDG(&input[idx * num_cols + i]);
-      } else {
-        gpuAtomicAddNoReturn(
-            &output[idx * num_cols + i], input[row * num_cols + i]);
+        for (int i = 0; i < UNROLL_FACTOR && col_offset + i < num_cols; i++) {
+          // Compile time conditional
+          if constexpr (USE_INDEX_SELECT) {
+            output[current_row * num_cols + i] = LDG(&input[idx * num_cols + i]);
+          } else {
+            gpuAtomicAddNoReturn(
+                &output[idx * num_cols + i], input[current_row * num_cols + i]);
+          }
+        }
       }
+    } else {
+      // Large embedding dimensions use >= 1 warp per row
+      // which is the default codepath for non-ROCm as well
+#endif // USE_ROCM
+      const auto row = member_warp_id / warps_per_row;
+      const auto col_offset =
+          ((member_warp_id % warps_per_row) << LOG_COLS_PER_WARP) +
+          (threadIdx.x * UNROLL_FACTOR);
+      scalar_t* input =
+          reinterpret_cast<scalar_t*>(input_ptrs[member_id]) + col_offset;
+      scalar_t* output =
+          reinterpret_cast<scalar_t*>(output_ptrs[member_id]) + col_offset;
+
+      index_t* indices = reinterpret_cast<index_t*>(indices_ptrs[member_id]);
+      const index_t idx = indices[row];
+#pragma unroll
+      for (int i = 0; i < UNROLL_FACTOR && col_offset + i < num_cols; i++) {
+        // Compile time conditional
+        if constexpr (USE_INDEX_SELECT) {
+          output[row * num_cols + i] = LDG(&input[idx * num_cols + i]);
+        } else {
+          gpuAtomicAddNoReturn(
+              &output[idx * num_cols + i], input[row * num_cols + i]);
+        }
+      }
+#ifdef USE_ROCM
     }
+#endif // USE_ROCM
   }
 }
 
diff --git a/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp b/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp
index 9e8587b8d1..df3f49af5c 100644
--- a/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp
+++ b/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp
@@ -265,6 +265,7 @@ static torch::autograd::variable_list group_index_select_dim0_forward_impl_gpu(
   Tensor input_reshaped = first_input.reshape({num_input_rows, -1});
   const int num_cols = input_reshaped.size(1);
   const int cols_per_warp = get_group_index_select_cols_per_warp();
+  const int unroll_factor = get_group_index_select_unroll_factor();
   int64_t warp_offset = 0;
   bool use_var_cols = false;
 
@@ -303,7 +304,22 @@ static torch::autograd::variable_list group_index_select_dim0_forward_impl_gpu(
 
     // Number of columns can be different
     auto num_cols_ = input_reshaped_.size(1);
+
+#ifdef USE_ROCM
+    int64_t warps_needed;
+    if (num_cols_ < cols_per_warp && num_cols_ >= unroll_factor) {
+      // Optimization: Pack multiple rows into one warp
+      int rows_per_warp = cols_per_warp / num_cols_;
+      warps_needed = (num_output_rows_ + rows_per_warp - 1) / rows_per_warp;
+    } else {
+      // Standard: One or more warps per row
+      int warps_per_row = (num_cols_ + cols_per_warp - 1) / cols_per_warp;
+      warps_needed = warps_per_row * num_output_rows_;
+    }
+#else
+    // Standard: One or more warps per row
     auto warps_per_row = (num_cols_ + cols_per_warp - 1) / cols_per_warp;
+#endif // USE_ROCM
 
     if (num_cols != num_cols_) {
       use_var_cols = true;
@@ -329,7 +345,11 @@ static torch::autograd::variable_list group_index_select_dim0_forward_impl_gpu(
 
     warp_offsets_group[i] = warp_offset;
     num_cols_group[i] = num_cols_;
+#ifdef USE_ROCM
+    warp_offset += warps_needed;
+#else
     warp_offset += warps_per_row * num_output_rows;
+#endif // USE_ROCM
   }
 
   // Store the last offset