2 changes: 2 additions & 0 deletions fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h
@@ -1090,6 +1090,8 @@ void group_index_select_or_add_cuda(

int get_group_index_select_cols_per_warp();

int get_group_index_select_unroll_factor();

std::vector<at::Tensor> jagged_index_select_2d(
const at::Tensor& values,
const at::Tensor& lengths,
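The new accessor mirrors the existing get_group_index_select_cols_per_warp(), so host code can query the kernel's unroll factor without depending on the .cu translation unit. A minimal sketch of combining the two getters (the helper itself is hypothetical, not part of this PR):

#include "fbgemm_gpu/sparse_ops.h"

// Hypothetical helper: would this embedding width take the packed
// small-dimension path added below for ROCm builds?
bool takes_small_dim_path(int num_cols) {
  const int cols_per_warp = get_group_index_select_cols_per_warp();
  const int unroll_factor = get_group_index_select_unroll_factor();
  return num_cols < cols_per_warp && num_cols >= unroll_factor;
}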
92 changes: 74 additions & 18 deletions fbgemm_gpu/src/sparse_ops/sparse_group_index.cu
@@ -33,6 +33,10 @@ int get_group_index_select_cols_per_warp() {
return GROUP_INDEX_SELECT_COLS_PER_WARP;
}

int get_group_index_select_unroll_factor() {
return GROUP_INDEX_SELECT_UNROLL_FACTOR;
}

template <
typename index_t,
typename scalar_t,
@@ -82,28 +86,80 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel(
// All columns are the same
member_id = warp_id / (warps_per_row * num_work_rows);
member_warp_id = warp_id - (member_id * warps_per_row * num_work_rows);
#ifdef USE_ROCM
if (num_cols < COLS_PER_WARP && num_cols >= UNROLL_FACTOR) {
// Need to ensure that [member_id] and [member_warp_id] are calculated correctly
// for the small embedding dimension path below
int rows_per_warp = COLS_PER_WARP / num_cols;
auto warps_per_member = (num_work_rows + rows_per_warp - 1) / rows_per_warp;
member_id = warp_id / warps_per_member;
member_warp_id = warp_id % warps_per_member;
}
#endif // USE_ROCM
}
#ifdef USE_ROCM
  if (num_cols < COLS_PER_WARP && num_cols >= UNROLL_FACTOR) {
    // Optimized path for small embedding dimensions:
    // each warp processes [rows_per_warp] rows
    int rows_per_warp = COLS_PER_WARP / num_cols;
    int64_t start_row = member_warp_id * rows_per_warp;

    // Since we are processing multiple rows within the warp, we need to
    // map each lane to a specific row, in addition to the column
    int local_row = (threadIdx.x * UNROLL_FACTOR) / num_cols; // row within the set of rows handled by this warp
    int col_offset = (threadIdx.x * UNROLL_FACTOR) % num_cols;
    int64_t current_row = start_row + local_row; // actual row within the table processed by this lane

    // [local_row] may be out of bounds for the last few lanes in the warp
    // if [COLS_PER_WARP % num_cols != 0]; we also need to confirm that we
    // are within [num_work_rows]
    if (local_row < rows_per_warp && current_row < num_work_rows) {
      scalar_t* input =
          reinterpret_cast<scalar_t*>(input_ptrs[member_id]) + col_offset;
      scalar_t* output =
          reinterpret_cast<scalar_t*>(output_ptrs[member_id]) + col_offset;

      index_t* indices = reinterpret_cast<index_t*>(indices_ptrs[member_id]);
      const index_t idx = indices[current_row];
#pragma unroll
      for (int i = 0; i < UNROLL_FACTOR && col_offset + i < num_cols; i++) {
        // Compile time conditional
        if constexpr (USE_INDEX_SELECT) {
          output[current_row * num_cols + i] = LDG(&input[idx * num_cols + i]);
        } else {
          gpuAtomicAddNoReturn(
              &output[idx * num_cols + i], input[current_row * num_cols + i]);
        }
      }
    }
  } else {
    // Large embedding dimensions use >= 1 warp per row,
    // which is the default codepath for non-ROCm as well
#endif // USE_ROCM
    const auto row = member_warp_id / warps_per_row;
    const auto col_offset =
        ((member_warp_id % warps_per_row) << LOG_COLS_PER_WARP) +
        (threadIdx.x * UNROLL_FACTOR);
    scalar_t* input =
        reinterpret_cast<scalar_t*>(input_ptrs[member_id]) + col_offset;
    scalar_t* output =
        reinterpret_cast<scalar_t*>(output_ptrs[member_id]) + col_offset;

    index_t* indices = reinterpret_cast<index_t*>(indices_ptrs[member_id]);
    const index_t idx = indices[row];
#pragma unroll
    for (int i = 0; i < UNROLL_FACTOR && col_offset + i < num_cols; i++) {
      // Compile time conditional
      if constexpr (USE_INDEX_SELECT) {
        output[row * num_cols + i] = LDG(&input[idx * num_cols + i]);
      } else {
        gpuAtomicAddNoReturn(
            &output[idx * num_cols + i], input[row * num_cols + i]);
      }
    }
#ifdef USE_ROCM
  }
#endif // USE_ROCM
}
}
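To make the lane-to-(row, column) mapping in the small-dimension branch concrete, the host-side sketch below replays the same index arithmetic. The concrete numbers are assumptions for illustration only: a 64-lane ROCm wavefront and an unroll factor of 2 (so COLS_PER_WARP = 128) with num_cols = 40; in the kernel these come from kWarpSize, GROUP_INDEX_SELECT_UNROLL_FACTOR, and the runtime column count.

#include <cstdio>

int main() {
  const int kWarpSize = 64;      // assumed ROCm wavefront size
  const int UNROLL_FACTOR = 2;   // assumed unroll factor
  const int COLS_PER_WARP = kWarpSize * UNROLL_FACTOR;  // 128
  const int num_cols = 40;       // small embedding dim: < COLS_PER_WARP

  // Each warp covers rows_per_warp = 128 / 40 = 3 rows; the remaining
  // 128 - 3 * 40 = 8 lane-slots are masked off by the bounds check.
  const int rows_per_warp = COLS_PER_WARP / num_cols;
  for (int lane = 0; lane < kWarpSize; ++lane) {
    int local_row = (lane * UNROLL_FACTOR) / num_cols;  // row within this warp
    int col_offset = (lane * UNROLL_FACTOR) % num_cols; // column within the row
    printf("lane %2d -> local_row %d, col_offset %2d%s\n",
           lane,
           local_row,
           col_offset,
           local_row >= rows_per_warp ? " (masked)" : "");
  }
  return 0;
}

With these numbers, lanes 0-19 cover row 0, lanes 20-39 row 1, lanes 40-59 row 2, and lanes 60-63 are masked, matching the local_row < rows_per_warp guard in the kernel.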

20 changes: 20 additions & 0 deletions fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp
@@ -265,6 +265,7 @@ static torch::autograd::variable_list group_index_select_dim0_forward_impl_gpu(
Tensor input_reshaped = first_input.reshape({num_input_rows, -1});
const int num_cols = input_reshaped.size(1);
const int cols_per_warp = get_group_index_select_cols_per_warp();
const int unroll_factor = get_group_index_select_unroll_factor();
int64_t warp_offset = 0;
bool use_var_cols = false;

@@ -303,7 +304,22 @@ static torch::autograd::variable_list group_index_select_dim0_forward_impl_gpu(

// Number of columns can be different
auto num_cols_ = input_reshaped_.size(1);

#ifdef USE_ROCM
    int64_t warps_needed;
    if (num_cols_ < cols_per_warp && num_cols_ >= unroll_factor) {
      // Optimization: pack multiple rows into one warp
      int rows_per_warp = cols_per_warp / num_cols_;
      warps_needed = (num_output_rows_ + rows_per_warp - 1) / rows_per_warp;
    } else {
      // Standard: one or more warps per row
      int warps_per_row = (num_cols_ + cols_per_warp - 1) / cols_per_warp;
      warps_needed = warps_per_row * num_output_rows_;
    }
#else
    // Standard: one or more warps per row
    auto warps_per_row = (num_cols_ + cols_per_warp - 1) / cols_per_warp;
#endif // USE_ROCM

if (num_cols != num_cols_) {
use_var_cols = true;
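The payoff of this accounting is easiest to see numerically. The sketch below reproduces the same arithmetic as a standalone function (the function name and the sample constants cols_per_warp = 128, unroll_factor = 2 are assumptions for illustration): 1000 rows of 40 columns need only 334 warps on the packed path instead of 1000.

#include <cassert>
#include <cstdint>

// Hypothetical helper mirroring the branch above: warps needed for one
// group member, given its column count and number of output rows.
int64_t warps_for_member(int64_t num_cols, int64_t num_output_rows,
                         int cols_per_warp, int unroll_factor) {
  if (num_cols < cols_per_warp && num_cols >= unroll_factor) {
    // Packed path: several small rows share one warp.
    const int64_t rows_per_warp = cols_per_warp / num_cols;
    return (num_output_rows + rows_per_warp - 1) / rows_per_warp;
  }
  // Standard path: one or more warps per row.
  const int64_t warps_per_row = (num_cols + cols_per_warp - 1) / cols_per_warp;
  return warps_per_row * num_output_rows;
}

int main() {
  assert(warps_for_member(40, 1000, 128, 2) == 334);    // packed: 3 rows/warp
  assert(warps_for_member(256, 1000, 128, 2) == 2000);  // standard: 2 warps/row
  return 0;
}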
@@ -329,7 +345,11 @@ static torch::autograd::variable_list group_index_select_dim0_forward_impl_gpu(
warp_offsets_group[i] = warp_offset;
num_cols_group[i] = num_cols_;

#ifdef USE_ROCM
    warp_offset += warps_needed;
#else
    warp_offset += warps_per_row * num_output_rows;
#endif // USE_ROCM
}

// Store the last offset