45 changes: 43 additions & 2 deletions fbgemm_gpu/include/fbgemm_gpu/utils/inclusive_sum_scan.cuh
@@ -69,11 +69,52 @@ __inline__ __device__ void inclusive_sum_scan_kernel(
const int block_id,
const bool is_multi_block,
const int signal) {
// ROCm path
#ifdef USE_ROCM
// Perform scan within a block
cub::BlockScan<scalar_t, NUM_THREADS_PER_BLOCK>(temp_storage)
.InclusiveSum(arr, arr);

// Perform stream scan across blocks
// Perform scan across blocks
if (is_multi_block) {
const bool is_last_thread =
threadIdx.x == (num_entries_per_block - 1) / ITEMS_PER_THREAD;
// The thread that holds the last entry in the block does synchronization
if (is_last_thread) {
scalar_t block_prev_local = 0;
if (block_id != 0) {
// Spin wait for the previous block to write the sum value
while (atomicAdd(&block_flags[block_id - 1], 0) < signal)
;
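// atomicAdd with 0 is used as an atomic read of the flag from global memory;
// the loop exits once the previous block has published its sum via the
// atomicExch below.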

// Get sum from the previous block
*block_prev = block_prev_local = block_sums[block_id - 1];
}

// Write sum to global memory for the next block to consume
const int scope = (num_entries_per_block - 1) % ITEMS_PER_THREAD;
block_sums[block_id] = block_prev_local + arr[scope];
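// Make the block_sums write above visible to other blocks before the flag is set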
__threadfence();
// Set a flag to notify the next block
atomicExch(&block_flags[block_id], signal);
}

__syncthreads();

if (block_id != 0) {
scalar_t block_prev_local = *block_prev;
for (int i = 0; i < ITEMS_PER_THREAD; i++) {
arr[i] += block_prev_local;
}
}
}
#else
// CUDA path
// Perform scan within a block
cub::BlockScan<scalar_t, NUM_THREADS_PER_BLOCK>(temp_storage)
.InclusiveSum(arr, arr);

// Perform scan across blocks
if (is_multi_block) {
// The thread that holds the last entry in the block does synchronization
if (threadIdx.x == (num_entries_per_block - 1) / ITEMS_PER_THREAD) {
@@ -104,6 +145,6 @@ __inline__ __device__ void inclusive_sum_scan_kernel(
}
}
}
#endif
}

} // namespace fbgemm_gpu
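
For context, a minimal sketch of the same chained block-scan handshake with one item per thread. This is an illustration only, not part of this change: the kernel name chained_scan_sketch is hypothetical, and it assumes <cub/cub.cuh> is included and that flags and sums are zero-initialized by the host before launch.

template <typename T, int THREADS>
__global__ void chained_scan_sketch(T* data, int* flags, T* sums, int n) {
  using BlockScan = cub::BlockScan<T, THREADS>;
  __shared__ typename BlockScan::TempStorage tmp;
  __shared__ T carry;
  const int i = blockIdx.x * THREADS + threadIdx.x;
  T v = (i < n) ? data[i] : T(0);
  // Step 1: inclusive scan within the block.
  BlockScan(tmp).InclusiveSum(v, v);
  // Step 2: the last thread waits for the running total of the previous block,
  // then publishes this block's running total and raises its flag.
  if (threadIdx.x == THREADS - 1) {
    T prev = 0;
    if (blockIdx.x > 0) {
      while (atomicAdd(&flags[blockIdx.x - 1], 0) == 0) {
      }
      prev = sums[blockIdx.x - 1];
    }
    carry = prev;
    sums[blockIdx.x] = prev + v;
    __threadfence();
    atomicExch(&flags[blockIdx.x], 1);
  }
  __syncthreads();
  // Step 3: every thread adds the carry accumulated by the blocks to its left.
  if (blockIdx.x > 0) {
    v += carry;
  }
  if (i < n) {
    data[i] = v;
  }
}

Like the kernel above, this relies on lower-numbered blocks being able to make forward progress while later blocks spin on their flags.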
214 changes: 171 additions & 43 deletions fbgemm_gpu/src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu
@@ -6,6 +6,7 @@
* LICENSE file in the root directory of this source tree.
*/

#include <type_traits>
#include "common.cuh"

using Tensor = at::Tensor;
@@ -17,7 +18,8 @@ template <
typename index_t,
typename acc_t,
int NUM_THREADS_PER_BLOCK,
int MAX_ENTRIES_PER_BLOCK>
int MAX_ENTRIES_PER_BLOCK,
int ENTRIES_PER_THREAD>
__global__ void index_select_scalar_cumsum_kernel(
pta::PackedTensorAccessor32<scalar_t, 1, at::RestrictPtrTraits> output,
pta::PackedTensorAccessor32<acc_t, 1, at::RestrictPtrTraits> output_cumsum,
@@ -31,6 +33,81 @@ __global__ void index_select_scalar_cumsum_kernel(
acc_t* block_sums) {
typedef cub::BlockScan<acc_t, NUM_THREADS_PER_BLOCK> BlockScan;
__shared__ typename BlockScan::TempStorage bs_temp_storage;
__shared__ acc_t block_prefix;
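// Shared scratch for inclusive_sum_scan_kernel: the last thread stores the
// running total of the preceding blocks here so every thread in the block can
// add it to its local values.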

// ROCm path
#ifdef USE_ROCM
const int output_batch_size = indices.size(0);
const int num_entries = num_batches * output_batch_size;
const bool multi_block = gridDim.x > 1;
const int block_entries = blockIdx.x == gridDim.x - 1
? last_block_num_entries
: MAX_ENTRIES_PER_BLOCK;
const int block_entry_start = blockIdx.x * MAX_ENTRIES_PER_BLOCK;
const int remaining_entries = num_entries - block_entry_start;
const int num_entries_per_block = remaining_entries > 0
? (remaining_entries < block_entries ? remaining_entries : block_entries)
: 0;

const int base_entry = block_entry_start + threadIdx.x * ENTRIES_PER_THREAD;
acc_t local_data[ENTRIES_PER_THREAD];
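// Blocked arrangement: each thread owns ENTRIES_PER_THREAD consecutive entries,
// which is the layout cub::BlockScan's array overload of InclusiveSum expects.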

#pragma unroll
for (int i = 0; i < ENTRIES_PER_THREAD; ++i) {
const int entry = base_entry + i;
if (entry < num_entries) {
const int bid = entry / output_batch_size;
const int idx_in_batch = entry - bid * output_batch_size;
const int bid_base = bid * input_batch_size;
const index_t sel_idx = indices[idx_in_batch];
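// __builtin_nontemporal_load is a Clang builtin that issues a streaming
// (non-temporal) load; each input element is read only once here.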
local_data[i] = __builtin_nontemporal_load(&input[bid_base + sel_idx]);
output[entry] = local_data[i];
} else {
local_data[i] = 0;
}
}

// Faster path for single block
if (!multi_block) {
if (num_entries_per_block > 0) {
BlockScan(bs_temp_storage).InclusiveSum(local_data, local_data);
}
if (base_entry < num_entries) {
#pragma unroll
for (int i = 0; i < ENTRIES_PER_THREAD; ++i) {
const int entry = base_entry + i;
if (entry < num_entries) {
output_cumsum[entry] = local_data[i];
}
}
}
return;
}

if (num_entries_per_block > 0) {
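// Multi-block case: chain the per-block scans through block_flags/block_sums
// (the stream scan in inclusive_sum_scan.cuh).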
inclusive_sum_scan_kernel<acc_t, ENTRIES_PER_THREAD, NUM_THREADS_PER_BLOCK>(
local_data,
bs_temp_storage,
block_flags,
block_sums,
&block_prefix,
num_entries_per_block,
blockIdx.x,
multi_block,
1);
}

if (base_entry < num_entries) {
#pragma unroll
for (int i = 0; i < ENTRIES_PER_THREAD; ++i) {
const int entry = base_entry + i;
if (entry < num_entries) {
output_cumsum[entry] = local_data[i];
}
}
}
#else
// CUDA path
__shared__ acc_t smem[MAX_ENTRIES_PER_BLOCK];
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
const int output_batch_size = indices.size(0);
@@ -65,6 +142,7 @@ __global__ void index_select_scalar_cumsum_kernel(
if (tid < num_batches * output_batch_size) {
output_cumsum[tid] = *local_data;
}
#endif
}

template <
@@ -183,58 +261,108 @@ class KeyedJaggedIndexSelectDim1GPUOp
const int num_batches = lengths.numel() / batch_size;
const int num_output_lengths = num_batches * indices.numel();
const int MAX_CUMSUM_ENTRIES_PER_BLOCK = 256;
#ifdef USE_ROCM
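// Choose the largest per-thread width in {4, 2, 1} that divides indices.numel();
// num_output_lengths is then also a multiple of the chosen width.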
const int num_entries_per_thread[] = {4, 2, 1};
int entries_per_thread = 1;
for (int i : num_entries_per_thread) {
if (indices.numel() % i == 0) {
entries_per_thread = i;
break;
}
}
#else
const int entries_per_thread = 1;
const int ENTRIES_PER_BLOCK = MAX_CUMSUM_ENTRIES_PER_BLOCK;
auto grid_size = cuda_calc_xblock_count(
num_output_lengths, MAX_CUMSUM_ENTRIES_PER_BLOCK);
#endif

Tensor output_offsets =
at::empty({num_batches * indices.numel()}, offsets.options());
Tensor output_lengths =
at::empty({num_batches * indices.numel()}, lengths.options());

Tensor block_flags, block_sums;
if (grid_size > 1) {
block_flags = at::zeros({grid_size}, lengths.options().dtype(at::kInt));
block_sums = at::empty({grid_size}, output_offsets.options());
}

// Do index select and cumsum
AT_DISPATCH_INDEX_TYPES(
lengths.scalar_type(), "index_select_scalar_cumsum_wrapper_1", [&] {
using length_t = index_t;
AT_DISPATCH_INDEX_TYPES(
offsets.scalar_type(),
"index_select_scalar_cumsum_wrapper_2",
[&] {
using offset_t = index_t;
AT_DISPATCH_INDEX_TYPES(
indices.scalar_type(),
"index_select_scalar_cumsum_wrapper_3",
[&] {
FBGEMM_LAUNCH_KERNEL(
(index_select_scalar_cumsum_kernel<
length_t,
index_t,
offset_t,
MAX_CUMSUM_ENTRIES_PER_BLOCK,
MAX_CUMSUM_ENTRIES_PER_BLOCK>),
grid_size,
MAX_CUMSUM_ENTRIES_PER_BLOCK,
0,
at::cuda::getCurrentCUDAStream(),
PTA_B(output_lengths, length_t, 1, 32),
PTA_B(output_offsets, offset_t, 1, 32),
PTA_B(lengths, length_t, 1, 32),
PTA_B(indices, index_t, 1, 32),
num_batches,
batch_size,
num_output_lengths -
MAX_CUMSUM_ENTRIES_PER_BLOCK * (grid_size - 1),
grid_size > 1 ? block_flags.data_ptr<int>() : nullptr,
grid_size > 1 ? block_sums.data_ptr<offset_t>()
: nullptr);
});
});
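// Dispatch helper: vec_tag is a std::integral_constant<int, N>, so
// ENTRIES_PER_THREAD is available as a compile-time template argument;
// grid_calc maps an entries-per-block count to a grid size.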
auto dispatch_cumsum = [&](auto vec_tag, auto grid_calc) {
constexpr int ENTRIES_PER_THREAD = decltype(vec_tag)::value;
constexpr int ENTRIES_PER_BLOCK =
MAX_CUMSUM_ENTRIES_PER_BLOCK * ENTRIES_PER_THREAD;
const auto grid_size = grid_calc(ENTRIES_PER_BLOCK);

Tensor block_flags, block_sums;
if (grid_size > 1) {
block_flags =
at::zeros({grid_size}, lengths.options().dtype(at::kInt));
block_sums = at::empty({grid_size}, output_offsets.options());
}

AT_DISPATCH_INDEX_TYPES(
lengths.scalar_type(), "index_select_scalar_cumsum_wrapper_1", [&] {
using length_t = index_t;
AT_DISPATCH_INDEX_TYPES(
offsets.scalar_type(),
"index_select_scalar_cumsum_wrapper_2",
[&] {
using offset_t = index_t;
AT_DISPATCH_INDEX_TYPES(
indices.scalar_type(),
"index_select_scalar_cumsum_wrapper_3",
[&] {
FBGEMM_LAUNCH_KERNEL(
(index_select_scalar_cumsum_kernel<
length_t,
index_t,
offset_t,
MAX_CUMSUM_ENTRIES_PER_BLOCK,
ENTRIES_PER_BLOCK,
ENTRIES_PER_THREAD>),
grid_size,
MAX_CUMSUM_ENTRIES_PER_BLOCK,
0,
at::cuda::getCurrentCUDAStream(),
PTA_B(output_lengths, length_t, 1, 32),
PTA_B(output_offsets, offset_t, 1, 32),
PTA_B(lengths, length_t, 1, 32),
PTA_B(indices, index_t, 1, 32),
num_batches,
batch_size,
grid_size == 0
? 0
: num_output_lengths -
ENTRIES_PER_BLOCK * (grid_size - 1),
grid_size > 1
? block_flags.data_ptr<int>()
: nullptr,
grid_size > 1
? block_sums.data_ptr<offset_t>()
: nullptr);
});
});
});
};

#ifdef USE_ROCM
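// entries_per_thread is a runtime value; the switch below gives each supported
// width its own compile-time kernel instantiation.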
auto rocm_grid = [&](int entries_per_block) {
return (num_output_lengths + entries_per_block - 1) / entries_per_block;
};
switch (entries_per_thread) {
case 4:
dispatch_cumsum(std::integral_constant<int, 4>{}, rocm_grid);
break;
case 2:
dispatch_cumsum(std::integral_constant<int, 2>{}, rocm_grid);
break;
default:
dispatch_cumsum(std::integral_constant<int, 1>{}, rocm_grid);
break;
}
#else
dispatch_cumsum(
std::integral_constant<int, 1>{},
[&](int entries_per_block) {
return cuda_calc_xblock_count(num_output_lengths, entries_per_block);
});
#endif

const int64_t num_outputs = (selected_lengths_sum.has_value())
? selected_lengths_sum.value().guard_int(__FILE__, __LINE__)