From 8573c9a7f0db46f7e8326ac33a4a8fbc87e116e9 Mon Sep 17 00:00:00 2001
From: Mengwei Liu
Date: Wed, 24 Dec 2025 01:46:56 +0000
Subject: [PATCH] Update

[ghstack-poisoned]
---
 extension/llm/sampler/cuda_sampler.cu | 103 ++++++++++++++++++++++++++
 extension/llm/sampler/cuda_sampler.h  |  90 ++++++++++++++++++++++
 2 files changed, 193 insertions(+)
 create mode 100644 extension/llm/sampler/cuda_sampler.cu
 create mode 100644 extension/llm/sampler/cuda_sampler.h

diff --git a/extension/llm/sampler/cuda_sampler.cu b/extension/llm/sampler/cuda_sampler.cu
new file mode 100644
index 00000000000..651b5ea94bd
--- /dev/null
+++ b/extension/llm/sampler/cuda_sampler.cu
@@ -0,0 +1,103 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <executorch/extension/llm/sampler/cuda_sampler.h>
+#include <executorch/runtime/platform/log.h>
+
+namespace executorch {
+namespace extension {
+namespace llm {
+
+// CudaSampler implementation
+//
+// IMPORTANT: Stream synchronization considerations
+// ------------------------------------------------
+// CudaSampler uses the default CUDA stream (nullptr) rather than creating its
+// own stream. This is a deliberate design choice for the following reasons:
+//
+// 1. The CUDA backend (cuda_backend.cpp) creates its own stream internally for
+//    running the model (encoder/decoder). This stream is encapsulated inside
+//    CudaDelegateHandle and is not exposed through any public API.
+//
+// 2. When the decoder produces logits on the backend's stream, those logits
+//    must be fully written before we run argmax on them. Using different
+//    streams without explicit synchronization could cause race conditions.
+//
+// 3. The legacy default stream (stream 0 / nullptr) has special synchronization
+//    semantics: it implicitly synchronizes with other blocking streams (those
+//    created without cudaStreamNonBlocking) in the same CUDA context. This means:
+//    - The argmax kernel will wait for the decoder to finish writing logits
+//    - No explicit cudaDeviceSynchronize() or cross-stream synchronization is needed
+//
+// 4. Trade-off: Using the default stream prevents concurrent execution between
+//    the sampler and other CUDA operations. However, for a single-token argmax
+//    over a vocabulary-sized tensor, this overhead is negligible compared to
+//    the complexity of managing cross-stream synchronization.
+//
+// If the CUDA backend exposes its stream in the future, we could pass it here
+// for tighter integration and potential pipelining opportunities.
+//
+CudaSampler::CudaSampler() : out_token_gpu_(nullptr) {
+  // Allocate GPU memory for the output token.
+  cudaError_t err = cudaMalloc(&out_token_gpu_, sizeof(int));
+  if (err != cudaSuccess) {
+    ET_LOG(
+        Error,
+        "Failed to allocate GPU memory for CudaSampler: %s",
+        cudaGetErrorString(err));
+    out_token_gpu_ = nullptr;
+    return;
+  }
+  // Note: We intentionally do NOT create a CUDA stream here.
+  // We use the default stream (nullptr) for synchronization with the backend.
+  // See the detailed comment above for the rationale.
+}
+
+CudaSampler::~CudaSampler() {
+  // Note: No stream to destroy since we use the default stream (nullptr).
+  if (out_token_gpu_ != nullptr) {
+    cudaFree(out_token_gpu_);
+  }
+}
+
+CudaSampler::CudaSampler(CudaSampler&& other) noexcept
+    : out_token_gpu_(other.out_token_gpu_) {
+  other.out_token_gpu_ = nullptr;
+}
+
+CudaSampler& CudaSampler::operator=(CudaSampler&& other) noexcept {
+  if (this != &other) {
+    // Clean up existing resources.
+    if (out_token_gpu_ != nullptr) {
+      cudaFree(out_token_gpu_);
+    }
+    // Take ownership of other's resources.
+    out_token_gpu_ = other.out_token_gpu_;
+    other.out_token_gpu_ = nullptr;
+  }
+  return *this;
+}
+
+int32_t CudaSampler::sample_argmax(
+    const void* logits_ptr,
+    int vocab_size,
+    ::executorch::aten::ScalarType scalar_type) {
+  if (out_token_gpu_ == nullptr) {
+    ET_LOG(Error, "CudaSampler not properly initialized");
+    return -1;
+  }
+
+  // Use the default stream (nullptr) for implicit synchronization with the backend.
+  return cuda::argmax_cuda(
+      logits_ptr, vocab_size, scalar_type, nullptr, out_token_gpu_);
+}
+
+} // namespace llm
+} // namespace extension
+} // namespace executorch

diff --git a/extension/llm/sampler/cuda_sampler.h b/extension/llm/sampler/cuda_sampler.h
new file mode 100644
index 00000000000..4438548b1da
--- /dev/null
+++ b/extension/llm/sampler/cuda_sampler.h
@@ -0,0 +1,90 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#ifdef CUDA_AVAILABLE
+
+#include <cuda_runtime.h>
+#include <cstdint>
+
+#include <executorch/runtime/core/exec_aten/exec_aten.h>
+
+namespace executorch {
+namespace extension {
+namespace llm {
+
+/**
+ * CUDA-based sampler for performing argmax on GPU.
+ * This class avoids memory allocation in the hot path by pre-allocating
+ * scratch space on initialization.
+ *
+ * NOTE: This sampler uses the default CUDA stream (nullptr) rather than
+ * creating its own stream. This provides implicit synchronization with
+ * the CUDA backend's stream, ensuring that logits are fully written before
+ * argmax reads them. See cuda_sampler.cu for the detailed rationale.
+ */
+class CudaSampler {
+ public:
+  CudaSampler();
+  ~CudaSampler();
+
+  // Non-copyable
+  CudaSampler(const CudaSampler&) = delete;
+  CudaSampler& operator=(const CudaSampler&) = delete;
+
+  // Movable
+  CudaSampler(CudaSampler&& other) noexcept;
+  CudaSampler& operator=(CudaSampler&& other) noexcept;
+
+  /**
+   * Perform argmax sampling on GPU logits.
+   *
+   * @param logits_ptr Pointer to GPU memory containing logits
+   * @param vocab_size Vocabulary size (number of logits)
+   * @param scalar_type Data type of the logits tensor
+   * @return The token index with the highest logit value, or -1 on error
+   */
+  int32_t sample_argmax(
+      const void* logits_ptr,
+      int vocab_size,
+      ::executorch::aten::ScalarType scalar_type);
+
+ private:
+  // Pre-allocated GPU memory for the output token
+  int* out_token_gpu_;
+};
+
+namespace cuda {
+
+/**
+ * Perform argmax on a GPU logits tensor.
+ * This is a lower-level function that requires pre-allocated GPU memory.
+ *
+ * @param logits_ptr Pointer to GPU memory containing logits
+ * @param vocab_size Vocabulary size
+ * @param scalar_type Data type of the logits
+ * @param cuda_stream CUDA stream for async execution
+ * @param out_token_gpu Pre-allocated GPU memory for the output token
+ * @return The token index with the highest logit value, or -1 on error
+ */
+int32_t argmax_cuda(
+    const void* logits_ptr,
+    int vocab_size,
+    ::executorch::aten::ScalarType scalar_type,
+    cudaStream_t cuda_stream,
+    int* out_token_gpu);
+
+} // namespace cuda
+
+} // namespace llm
+} // namespace extension
+} // namespace executorch
+
+#endif // CUDA_AVAILABLE
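
A note on how this API is meant to be called (illustrative only, not part of the patch): the sketch below assumes the decoder's logits come back as a GPU-resident executorch Tensor, so its data pointer can be handed to sample_argmax directly, with no device-to-host copy of the logits. The pick_next_token helper and the logits tensor are hypothetical names.

#ifdef CUDA_AVAILABLE
#include <executorch/extension/llm/sampler/cuda_sampler.h>

using executorch::extension::llm::CudaSampler;

// Hypothetical glue code: construct the sampler once (its constructor does
// the single cudaMalloc), then reuse it for every decode step.
int32_t pick_next_token(
    CudaSampler& sampler,
    const executorch::aten::Tensor& logits) {
  // const_data_ptr() is a device pointer here because the decoder ran under
  // the CUDA delegate; sample_argmax returns -1 on failure.
  return sampler.sample_argmax(
      logits.const_data_ptr(),
      static_cast<int>(logits.numel()),
      logits.scalar_type());
}
#endif // CUDA_AVAILABLE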
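The patch declares cuda::argmax_cuda but does not include its definition, so for context, here is a minimal sketch of the kind of reduction such a function typically wraps. This is an assumption-laden illustration, not the patch's implementation: it handles only float logits (the real function dispatches on scalar_type) and uses a single 256-thread block.

#include <cfloat>
#include <cuda_runtime.h>

// Sketch: each thread scans a strided slice of the vocabulary, then a
// shared-memory tree reduction picks the overall winner.
__global__ void argmax_kernel_sketch(
    const float* logits, int vocab_size, int* out_token) {
  __shared__ float best_val[256];
  __shared__ int best_idx[256];

  float v_best = -FLT_MAX;
  int i_best = -1;
  for (int i = threadIdx.x; i < vocab_size; i += blockDim.x) {
    if (logits[i] > v_best) {
      v_best = logits[i];
      i_best = i;
    }
  }
  best_val[threadIdx.x] = v_best;
  best_idx[threadIdx.x] = i_best;
  __syncthreads();

  for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
    if (threadIdx.x < stride &&
        best_val[threadIdx.x + stride] > best_val[threadIdx.x]) {
      best_val[threadIdx.x] = best_val[threadIdx.x + stride];
      best_idx[threadIdx.x] = best_idx[threadIdx.x + stride];
    }
    __syncthreads();
  }
  if (threadIdx.x == 0) {
    *out_token = best_idx[0];
  }
}

// Hypothetical host wrapper: launch on the caller's stream (nullptr = legacy
// default stream, matching CudaSampler), then copy the winning index back.
// The synchronous cudaMemcpy also serves as the final synchronization point.
int32_t argmax_float_sketch(
    const float* logits,
    int vocab_size,
    cudaStream_t stream,
    int* out_token_gpu) {
  argmax_kernel_sketch<<<1, 256, 0, stream>>>(logits, vocab_size, out_token_gpu);
  int token = -1;
  if (cudaMemcpy(&token, out_token_gpu, sizeof(int), cudaMemcpyDeviceToHost) !=
      cudaSuccess) {
    return -1;
  }
  return token;
}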
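Finally, since the design leans on legacy default-stream semantics, below is a self-contained demonstration of the ordering guarantee (toy kernels, not ExecuTorch code). The caveat worth keeping in mind if the backend's stream setup ever changes: the implicit ordering holds only against streams created without cudaStreamNonBlocking.

#include <cuda_runtime.h>

__global__ void produce(float* buf, int n) {
  // Stands in for the decoder writing logits on the backend's stream.
  if (threadIdx.x < n) buf[threadIdx.x] = (float)threadIdx.x;
}

__global__ void consume(const float* buf, float* out) {
  // Stands in for the argmax kernel reading logits on stream 0.
  *out = buf[0];
}

int main() {
  float *buf = nullptr, *out = nullptr;
  cudaMalloc(&buf, 32 * sizeof(float));
  cudaMalloc(&out, sizeof(float));

  // A blocking stream, as the rationale assumes the backend uses. If it were
  // created with cudaStreamNonBlocking instead, the legacy default stream
  // would NOT order against it, and an explicit cudaEvent wait would be
  // required before launching consume.
  cudaStream_t backend_stream;
  cudaStreamCreate(&backend_stream);

  produce<<<1, 32, 0, backend_stream>>>(buf, 32);
  consume<<<1, 1>>>(buf, out); // stream 0: implicitly waits for produce

  cudaDeviceSynchronize();
  cudaStreamDestroy(backend_stream);
  cudaFree(buf);
  cudaFree(out);
  return 0;
}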