76 changes: 76 additions & 0 deletions extension/llm/sampler/CMakeLists.txt
@@ -0,0 +1,76 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# LLM sampler library with optional CUDA support
#
# ### Editing this file ###
#
# This file should be formatted with
# ~~~
# cmake-format -i CMakeLists.txt
# ~~~
# It should also be cmake-lint clean.
#

if(NOT EXECUTORCH_ROOT)
set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../..)
endif()

include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake)

# If the project is configured to build with CUDA support, build the CUDA
# sampler library.
if(EXECUTORCH_BUILD_CUDA)
find_package(CUDAToolkit QUIET)
if(CUDAToolkit_FOUND)
enable_language(CUDA)
set(CMAKE_CUDA_STANDARD 17)
set(CMAKE_CUDA_STANDARD_REQUIRED ON)

# Define CUDA sampler library
add_library(extension_llm_sampler_cuda STATIC argmax.cu cuda_sampler.cu)
target_include_directories(
extension_llm_sampler_cuda
PUBLIC ${EXECUTORCH_ROOT}
${EXECUTORCH_ROOT}/..
${CUDAToolkit_INCLUDE_DIRS}
)
target_compile_definitions(extension_llm_sampler_cuda PUBLIC CUDA_AVAILABLE)
target_link_libraries(extension_llm_sampler_cuda PUBLIC executorch_core
CUDA::cudart
)
set_target_properties(
extension_llm_sampler_cuda
PROPERTIES POSITION_INDEPENDENT_CODE ON CUDA_SEPARABLE_COMPILATION ON
)

message(
STATUS "CUDAToolkit found; building extension_llm_sampler_cuda library"
)

install(
TARGETS extension_llm_sampler_cuda
EXPORT ExecuTorchTargets
DESTINATION ${CMAKE_INSTALL_LIBDIR}
)
else()
message(
STATUS
"CUDA requested (EXECUTORCH_BUILD_CUDA=ON) but no CUDA runtime found"
)
endif()
endif()

# Install header files
install(
DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/executorch/extension/llm/sampler
FILES_MATCHING
PATTERN "*.h"
PATTERN "*.cuh"
PATTERN "test" EXCLUDE
)
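
Because CUDA_AVAILABLE is declared as a PUBLIC compile definition on extension_llm_sampler_cuda, any target that links the library inherits it and can gate GPU-only code at compile time. A minimal sketch of such a guard; everything except the macro and the header path is an illustrative assumption, not part of this change:

// Hypothetical consumer translation unit.
#ifdef CUDA_AVAILABLE
#include <executorch/extension/llm/sampler/cuda_sampler.h>
#define SAMPLER_HAS_GPU_ARGMAX 1
#else
#define SAMPLER_HAS_GPU_ARGMAX 0
#endif

#include <cstdio>

int main() {
  // With EXECUTORCH_BUILD_CUDA=ON and a CUDA toolkit found, the target above
  // defines CUDA_AVAILABLE for every consumer, so this prints 1.
  std::printf("GPU argmax available: %d\n", SAMPLER_HAS_GPU_ARGMAX);
  return 0;
}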

72 changes: 72 additions & 0 deletions extension/llm/sampler/argmax.cu
@@ -0,0 +1,72 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/extension/llm/sampler/argmax.cuh>
#include <executorch/extension/llm/sampler/cuda_sampler.h>
#include <executorch/runtime/platform/log.h>

namespace executorch {
namespace extension {
namespace llm {
namespace cuda {

// Wrapper function that performs argmax on GPU logits tensor
// Returns the token index with the highest logit value
// logits_ptr: pointer to GPU memory containing logits
// vocab_size: vocabulary size
// scalar_type: data type of the logits tensor
// cuda_stream: CUDA stream for async execution (nullptr for default stream)
// out_token_gpu: pre-allocated GPU memory for output token (int*)
int32_t argmax_cuda(
const void* logits_ptr,
int vocab_size,
::executorch::aten::ScalarType scalar_type,
cudaStream_t cuda_stream,
int* out_token_gpu) {
// Launch kernel for single row (batch size 1)
launch_argmax_vocab_rows(
logits_ptr,
scalar_type,
1, // rows = 1
vocab_size,
out_token_gpu,
nullptr, // don't need max logit value
cuda_stream,
256 // threads per block
);

// Copy result back to host
int32_t token;
cudaError_t err = cudaMemcpyAsync(
&token, out_token_gpu, sizeof(int), cudaMemcpyDeviceToHost, cuda_stream);
if (err != cudaSuccess) {
ET_LOG(
Error,
"Failed to copy argmax result from GPU: %s",
cudaGetErrorString(err));
return -1;
}

// Synchronize to ensure result is ready
err = cudaStreamSynchronize(cuda_stream);
if (err != cudaSuccess) {
ET_LOG(
Error,
"Failed to synchronize CUDA stream: %s",
cudaGetErrorString(err));
return -1;
}

return token;
}

} // namespace cuda
} // namespace llm
} // namespace extension
} // namespace executorch
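
As a usage sketch, a caller that already has float logits resident in device memory could drive this wrapper as follows. The buffer names and sizes are illustrative, and argmax_cuda is assumed to be declared in cuda_sampler.h, as the include above suggests:

#include <cuda_runtime.h>

#include <cstdint>
#include <vector>

#include <executorch/extension/llm/sampler/cuda_sampler.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>

int main() {
  const int vocab_size = 32000;
  std::vector<float> host_logits(vocab_size, 0.0f);
  host_logits[123] = 5.0f; // unique maximum, so token 123 is expected

  float* d_logits = nullptr;
  int* d_token = nullptr; // pre-allocated output slot required by the wrapper
  cudaMalloc(&d_logits, vocab_size * sizeof(float));
  cudaMalloc(&d_token, sizeof(int));
  cudaMemcpy(d_logits, host_logits.data(), vocab_size * sizeof(float),
             cudaMemcpyHostToDevice);

  // nullptr selects the default stream; the wrapper copies the result back
  // and synchronizes internally, so the return value is ready to use.
  int32_t token = executorch::extension::llm::cuda::argmax_cuda(
      d_logits, vocab_size, ::executorch::aten::ScalarType::Float,
      /*cuda_stream=*/nullptr, d_token);

  cudaFree(d_logits);
  cudaFree(d_token);
  return token == 123 ? 0 : 1; // token is -1 if a CUDA call failed
}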

171 changes: 171 additions & 0 deletions extension/llm/sampler/argmax.cuh
@@ -0,0 +1,171 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <float.h>
#include <stdint.h>

#include <executorch/runtime/core/exec_aten/exec_aten.h>

namespace executorch {
namespace extension {
namespace llm {
namespace cuda {

struct ArgMaxPair {
float v;
int i;
};

// tie-break: smaller index wins on equal values
__device__ __forceinline__ ArgMaxPair better(ArgMaxPair a, ArgMaxPair b) {
if (b.v > a.v)
return b;
if (b.v < a.v)
return a;
return (b.i < a.i) ? b : a;
}

__device__ __forceinline__ ArgMaxPair
warp_argmax_xor(ArgMaxPair x, unsigned mask = 0xffffffffu) {
for (int d = 16; d > 0; d >>= 1) {
ArgMaxPair y;
y.v = __shfl_xor_sync(mask, x.v, d);
y.i = __shfl_xor_sync(mask, x.i, d);
x = better(x, y);
}
return x;
}

// ---- dtype -> float load helpers ----
template <typename T>
__device__ __forceinline__ float load_as_float(const T* p);

template <>
__device__ __forceinline__ float load_as_float<float>(const float* p) {
return *p;
}

template <>
__device__ __forceinline__ float load_as_float<half>(const half* p) {
return __half2float(*p);
}

template <>
__device__ __forceinline__ float
load_as_float<nv_bfloat16>(const nv_bfloat16* p) {
return __bfloat162float(*p);
}

// logits: [rows, vocab] row-major contiguous
// out_token: [rows]
// out_maxlogit: [rows] (optional; pass nullptr if not needed)
template <typename T>
__global__ void argmax_vocab_rows_kernel(
const T* __restrict__ logits,
int rows,
int vocab,
int* __restrict__ out_token,
float* __restrict__ out_maxlogit) {
int row = blockIdx.x;
if (row >= rows)
return;

int tid = threadIdx.x;
int lane = tid & 31;
int warp = tid >> 5;
int warps_per_block = (blockDim.x + 31) >> 5;

const T* row_ptr = logits + (size_t)row * (size_t)vocab;

// local scan
ArgMaxPair best;
best.v = -FLT_MAX;
best.i = -1;

for (int j = tid; j < vocab; j += blockDim.x) {
float v = load_as_float<T>(row_ptr + j);
best = better(best, ArgMaxPair{v, j});
}

// warp reduce
best = warp_argmax_xor(best);

// shared collect warp winners (supports up to 1024 threads = 32 warps)
__shared__ float s_val[32];
__shared__ int s_idx[32];

if (lane == 0) {
s_val[warp] = best.v;
s_idx[warp] = best.i;
}
__syncthreads();

// first warp reduces warp winners
if (warp == 0) {
ArgMaxPair wbest;
if (lane < warps_per_block) {
wbest.v = s_val[lane];
wbest.i = s_idx[lane];
} else {
wbest.v = -FLT_MAX;
wbest.i = -1;
}

wbest = warp_argmax_xor(wbest);

if (lane == 0) {
out_token[row] = wbest.i;
if (out_maxlogit)
out_maxlogit[row] = wbest.v;
}
}
}

inline void launch_argmax_vocab_rows(
const void* logits,
::executorch::aten::ScalarType scalar_type,
int rows,
int vocab,
int* out_token,
float* out_maxlogit,
cudaStream_t stream,
int threads = 256) {
dim3 block(threads);
dim3 grid(rows);

switch (scalar_type) {
case ::executorch::aten::ScalarType::Float:
argmax_vocab_rows_kernel<float><<<grid, block, 0, stream>>>(
(const float*)logits, rows, vocab, out_token, out_maxlogit);
break;
case ::executorch::aten::ScalarType::Half:
argmax_vocab_rows_kernel<half><<<grid, block, 0, stream>>>(
(const half*)logits, rows, vocab, out_token, out_maxlogit);
break;
case ::executorch::aten::ScalarType::BFloat16:
argmax_vocab_rows_kernel<nv_bfloat16><<<grid, block, 0, stream>>>(
(const nv_bfloat16*)logits, rows, vocab, out_token, out_maxlogit);
break;
default:
// Unsupported dtype: the buffer is reinterpreted as float, so the result is
// only meaningful when the underlying data really is 32-bit float.
argmax_vocab_rows_kernel<float><<<grid, block, 0, stream>>>(
(const float*)logits, rows, vocab, out_token, out_maxlogit);
break;
}
}

} // namespace cuda
} // namespace llm
} // namespace extension
} // namespace executorch
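
A batched sketch that exercises the [rows, vocab] layout and the optional out_maxlogit output, checked against a host reference that applies the same tie-break (smaller index wins on equal values). It must be compiled as a .cu translation unit since the header contains device code; the buffer names and fill pattern are illustrative assumptions:

#include <cuda_runtime.h>

#include <vector>

#include <executorch/extension/llm/sampler/argmax.cuh>

using executorch::extension::llm::cuda::launch_argmax_vocab_rows;

// Host reference with the same tie-break: strict '>' keeps the smaller index.
static int host_argmax(const float* row, int vocab) {
  int best = 0;
  for (int j = 1; j < vocab; ++j) {
    if (row[j] > row[best]) {
      best = j;
    }
  }
  return best;
}

int main() {
  const int rows = 4, vocab = 1024;
  std::vector<float> logits(rows * vocab);
  for (int r = 0; r < rows; ++r) {
    for (int j = 0; j < vocab; ++j) {
      logits[r * vocab + j] = static_cast<float>((j * 37 + r) % 991);
    }
  }

  float* d_logits = nullptr;
  int* d_tokens = nullptr;
  float* d_maxvals = nullptr;
  cudaMalloc(&d_logits, logits.size() * sizeof(float));
  cudaMalloc(&d_tokens, rows * sizeof(int));
  cudaMalloc(&d_maxvals, rows * sizeof(float));
  cudaMemcpy(d_logits, logits.data(), logits.size() * sizeof(float),
             cudaMemcpyHostToDevice);

  // One block per row; out_maxlogit is optional and may be nullptr instead.
  launch_argmax_vocab_rows(
      d_logits, ::executorch::aten::ScalarType::Float, rows, vocab, d_tokens,
      d_maxvals, /*stream=*/0);

  std::vector<int> tokens(rows);
  cudaMemcpy(tokens.data(), d_tokens, rows * sizeof(int),
             cudaMemcpyDeviceToHost);

  int failures = 0;
  for (int r = 0; r < rows; ++r) {
    if (tokens[r] != host_argmax(logits.data() + r * vocab, vocab)) {
      ++failures;
    }
  }

  cudaFree(d_logits);
  cudaFree(d_tokens);
  cudaFree(d_maxvals);
  return failures;
}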

63 changes: 63 additions & 0 deletions extension/llm/sampler/test/CMakeLists.txt
@@ -0,0 +1,63 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

cmake_minimum_required(VERSION 3.19)
project(llm_sampler_cuda_tests LANGUAGES CXX CUDA)

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CUDA_STANDARD 17)
set(CMAKE_CUDA_STANDARD_REQUIRED ON)

# Find required packages
find_package(CUDAToolkit REQUIRED)

# Fetch GoogleTest
include(FetchContent)
FetchContent_Declare(
googletest
GIT_REPOSITORY https://github.com/google/googletest.git
GIT_TAG v1.14.0
)
# For Windows: Prevent overriding the parent project's compiler/linker settings
set(gtest_force_shared_crt
ON
CACHE BOOL "" FORCE
)
FetchContent_MakeAvailable(googletest)

# Get EXECUTORCH_ROOT
if(NOT EXECUTORCH_ROOT)
set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../../..)
endif()

# Find installed ExecuTorch
find_package(executorch CONFIG REQUIRED HINTS ${CMAKE_INSTALL_PREFIX})

# List of CUDA test files
set(LLM_SAMPLER_CUDA_TESTS test_argmax)

enable_testing()

foreach(test_name ${LLM_SAMPLER_CUDA_TESTS})
add_executable(${test_name} ${test_name}.cu)

target_include_directories(
${test_name} PRIVATE ${EXECUTORCH_ROOT}/.. ${EXECUTORCH_ROOT}
${CUDAToolkit_INCLUDE_DIRS}
)

target_link_libraries(
${test_name}
PRIVATE GTest::gtest
GTest::gtest_main
executorch_core
CUDA::cudart
)

add_test(NAME ${test_name} COMMAND ${test_name})
endforeach()
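
The test source itself (test_argmax.cu) is referenced above but not shown in this hunk. A GTest case in that style, checking the documented tie-break behavior, might look roughly like the following; the test body is an assumption, not the actual file:

#include <gtest/gtest.h>

#include <cuda_runtime.h>

#include <vector>

#include <executorch/extension/llm/sampler/argmax.cuh>

using executorch::extension::llm::cuda::launch_argmax_vocab_rows;

TEST(ArgmaxCudaTest, TieBreaksToSmallerIndex) {
  const int vocab = 1000;
  std::vector<float> logits(vocab, 0.0f);
  logits[100] = 7.0f; // first occurrence of the maximum
  logits[900] = 7.0f; // duplicate maximum at a larger index

  float* d_logits = nullptr;
  int* d_token = nullptr;
  ASSERT_EQ(cudaMalloc(&d_logits, vocab * sizeof(float)), cudaSuccess);
  ASSERT_EQ(cudaMalloc(&d_token, sizeof(int)), cudaSuccess);
  ASSERT_EQ(
      cudaMemcpy(d_logits, logits.data(), vocab * sizeof(float),
                 cudaMemcpyHostToDevice),
      cudaSuccess);

  launch_argmax_vocab_rows(
      d_logits, ::executorch::aten::ScalarType::Float, /*rows=*/1, vocab,
      d_token, /*out_maxlogit=*/nullptr, /*stream=*/0);

  int token = -1;
  ASSERT_EQ(
      cudaMemcpy(&token, d_token, sizeof(int), cudaMemcpyDeviceToHost),
      cudaSuccess);
  EXPECT_EQ(token, 100); // smaller index wins on equal logits

  cudaFree(d_logits);
  cudaFree(d_token);
}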
