103 changes: 103 additions & 0 deletions extension/llm/sampler/cuda_sampler.cu
@@ -0,0 +1,103 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/extension/llm/sampler/cuda_sampler.h>
#include <executorch/runtime/platform/log.h>

namespace executorch {
namespace extension {
namespace llm {

// CudaSampler implementation
//
// IMPORTANT: Stream synchronization considerations
// ------------------------------------------------
// CudaSampler uses the default CUDA stream (nullptr) rather than creating its
// own stream. This is a deliberate design choice for the following reasons:
//
// 1. The CUDA backend (cuda_backend.cpp) creates its own stream internally for
// running the model (encoder/decoder). This stream is encapsulated inside
// CudaDelegateHandle and is not exposed through any public API.
//
// 2. When the decoder produces logits on the backend's stream, we need those
// logits to be fully written before we run argmax on them. Using different
// streams without explicit synchronization could cause race conditions.
//
// 3. The legacy default stream (stream 0 / nullptr) has special synchronization
// semantics: operations on the default stream implicitly synchronize with
// operations on other blocking streams in the same CUDA context (i.e. streams
// not created with the cudaStreamNonBlocking flag). This means:
// - The argmax kernel will wait for the decoder to finish writing logits
// - No explicit cudaDeviceSynchronize() or cross-stream synchronization needed
//
// 4. Trade-off: Using the default stream prevents concurrent execution between
Contributor @Gasoonjia commented on Dec 25, 2025:

I'm wondering if we really need the sampler and the CUDA backend to use the same CUDA stream, since sampling and decoding should be able to work in parallel: the argmax of logits_{i} can run while the decoder generates logits_{i+1}, since they have no dependency on each other, and that parallelism may not happen if argmax and the decoder share the same CUDA stream.

// the sampler and other CUDA operations. However, for single-token argmax
// on a vocabulary-sized tensor, this overhead is negligible compared to the
// complexity of managing cross-stream synchronization.
//
// If in the future the CUDA backend exposes its stream, we could pass it here
// for tighter integration and potential pipelining opportunities.
//
CudaSampler::CudaSampler() : out_token_gpu_(nullptr) {
// Allocate GPU memory for output token
cudaError_t err = cudaMalloc(&out_token_gpu_, sizeof(int));
if (err != cudaSuccess) {
ET_LOG(
Error,
"Failed to allocate GPU memory for CudaSampler: %s",
cudaGetErrorString(err));
out_token_gpu_ = nullptr;
return;
}
// Note: We intentionally do NOT create a CUDA stream here.
// We use the default stream (nullptr) for synchronization with the backend.
// See the detailed comment above for rationale.
}

CudaSampler::~CudaSampler() {
// Note: No stream to destroy since we use the default stream (nullptr)
if (out_token_gpu_ != nullptr) {
cudaFree(out_token_gpu_);
}
}

CudaSampler::CudaSampler(CudaSampler&& other) noexcept
: out_token_gpu_(other.out_token_gpu_) {
other.out_token_gpu_ = nullptr;
}

CudaSampler& CudaSampler::operator=(CudaSampler&& other) noexcept {
if (this != &other) {
// Clean up existing resources
if (out_token_gpu_ != nullptr) {
cudaFree(out_token_gpu_);
}
// Take ownership of other's resources
out_token_gpu_ = other.out_token_gpu_;
other.out_token_gpu_ = nullptr;
}
return *this;
}

int32_t CudaSampler::sample_argmax(
const void* logits_ptr,
int vocab_size,
::executorch::aten::ScalarType scalar_type) {
if (out_token_gpu_ == nullptr) {
ET_LOG(Error, "CudaSampler not properly initialized");
return -1;
}

// Use default stream (nullptr) for implicit synchronization with backend
return cuda::argmax_cuda(
logits_ptr, vocab_size, scalar_type, nullptr, out_token_gpu_);
}

} // namespace llm
} // namespace extension
} // namespace executorch
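To make the default-stream behavior described in the comment above concrete, here is a minimal usage sketch. It is not part of this PR: the function name pick_next_token and the logits_gpu/vocab_size arguments are hypothetical stand-ins for whatever the runner passes in; only CudaSampler::sample_argmax and ScalarType::Float come from the code above.

#include <executorch/extension/llm/sampler/cuda_sampler.h>

// Hypothetical caller: logits_gpu is a device pointer produced by the CUDA
// backend's decoder on the backend's own stream.
int32_t pick_next_token(const void* logits_gpu, int vocab_size) {
  // Reuse one sampler so the scratch buffer is allocated only once.
  static executorch::extension::llm::CudaSampler sampler;
  // sample_argmax launches on the legacy default stream, which implicitly
  // waits for the backend's stream, so no cudaDeviceSynchronize() is needed.
  return sampler.sample_argmax(
      logits_gpu, vocab_size, ::executorch::aten::ScalarType::Float);
}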

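On the review question above about overlapping argmax(logits_{i}) with the decode of logits_{i+1}: a sketch of how that could look if the backend ever exposed its stream follows. Everything here is hypothetical — backend_stream is not available through any public API today, which is exactly the limitation the trade-off comment describes — and sample_overlapped is an illustrative name, not proposed code.

#include <cuda_runtime.h>
#include <executorch/extension/llm/sampler/cuda_sampler.h>

// Hypothetical: run argmax on a dedicated stream that waits on an event
// recorded on the (currently unexposed) backend stream, so the backend can
// start the next decode step while the sampler reduces the current logits.
int32_t sample_overlapped(
    cudaStream_t backend_stream, // assumption: exposed by the backend
    const void* logits_gpu,
    int vocab_size,
    ::executorch::aten::ScalarType scalar_type,
    int* out_token_gpu) {
  // Created once and reused; per-call creation would defeat the purpose.
  static cudaStream_t sampler_stream = nullptr;
  static cudaEvent_t logits_ready = nullptr;
  if (sampler_stream == nullptr) {
    cudaStreamCreateWithFlags(&sampler_stream, cudaStreamNonBlocking);
    cudaEventCreateWithFlags(&logits_ready, cudaEventDisableTiming);
  }
  // Mark the point where logits_{i} are fully written on the backend stream.
  cudaEventRecord(logits_ready, backend_stream);
  // The sampler stream waits only on that event, not on the whole device.
  cudaStreamWaitEvent(sampler_stream, logits_ready, 0);
  // Caveat: for true pipelining the host must also launch the next decode
  // before consuming the result; if argmax_cuda blocks on the device-to-host
  // copy, an asynchronous variant would be needed (hypothetical).
  return executorch::extension::llm::cuda::argmax_cuda(
      logits_gpu, vocab_size, scalar_type, sampler_stream, out_token_gpu);
}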
90 changes: 90 additions & 0 deletions extension/llm/sampler/cuda_sampler.h
@@ -0,0 +1,90 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#ifdef CUDA_AVAILABLE

#include <cuda_runtime.h>
#include <cstdint>

#include <executorch/runtime/core/exec_aten/exec_aten.h>

namespace executorch {
namespace extension {
namespace llm {

/**
* CUDA-based sampler for performing argmax on GPU.
* This class avoids memory allocation in the hot path by pre-allocating
* scratch space on initialization.
*
* NOTE: This sampler uses the default CUDA stream (nullptr) rather than
* creating its own stream. This provides implicit synchronization with
* the CUDA backend's stream, ensuring that logits are fully written before
 * argmax reads them. See cuda_sampler.cu for the detailed rationale.
*/
class CudaSampler {
public:
CudaSampler();
~CudaSampler();

// Non-copyable
CudaSampler(const CudaSampler&) = delete;
CudaSampler& operator=(const CudaSampler&) = delete;

// Movable
CudaSampler(CudaSampler&& other) noexcept;
CudaSampler& operator=(CudaSampler&& other) noexcept;

/**
* Perform argmax sampling on GPU logits.
*
* @param logits_ptr Pointer to GPU memory containing logits
* @param vocab_size Vocabulary size (number of logits)
* @param scalar_type Data type of the logits tensor
* @return The token index with the highest logit value, or -1 on error
*/
int32_t sample_argmax(
const void* logits_ptr,
int vocab_size,
::executorch::aten::ScalarType scalar_type);

private:
// Pre-allocated GPU memory for output token
int* out_token_gpu_;
};

namespace cuda {

/**
* Perform argmax on GPU logits tensor.
* This is a lower-level function that requires pre-allocated GPU memory.
*
* @param logits_ptr Pointer to GPU memory containing logits
* @param vocab_size Vocabulary size
* @param scalar_type Data type of the logits
* @param cuda_stream CUDA stream for async execution
* @param out_token_gpu Pre-allocated GPU memory for output token
* @return The token index with highest logit, or -1 on error
*/
int32_t argmax_cuda(
const void* logits_ptr,
int vocab_size,
::executorch::aten::ScalarType scalar_type,
cudaStream_t cuda_stream,
int* out_token_gpu);

} // namespace cuda

} // namespace llm
} // namespace extension
} // namespace executorch

#endif // CUDA_AVAILABLE
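The argmax_cuda implementation (argmax.cu) is not part of this diff. For readers following along, a simplified sketch of what such a kernel could look like is below: single block, float-only, shared-memory reduction. It matches only the declared signature and is named argmax_cuda_sketch to make clear it is illustrative, not the PR's implementation.

#include <cfloat>
#include <cuda_runtime.h>
#include <executorch/extension/llm/sampler/cuda_sampler.h>

namespace {

constexpr int kThreads = 256;

// One block scans the whole vocabulary; each thread keeps its running best,
// then a shared-memory tree reduction picks the global argmax.
__global__ void argmax_float_kernel(
    const float* logits, int vocab_size, int* out_token) {
  __shared__ float best_val[kThreads];
  __shared__ int best_idx[kThreads];

  float local_val = -FLT_MAX;
  int local_idx = 0;
  for (int i = threadIdx.x; i < vocab_size; i += blockDim.x) {
    if (logits[i] > local_val) {
      local_val = logits[i];
      local_idx = i;
    }
  }
  best_val[threadIdx.x] = local_val;
  best_idx[threadIdx.x] = local_idx;
  __syncthreads();

  for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
    if (threadIdx.x < stride &&
        best_val[threadIdx.x + stride] > best_val[threadIdx.x]) {
      best_val[threadIdx.x] = best_val[threadIdx.x + stride];
      best_idx[threadIdx.x] = best_idx[threadIdx.x + stride];
    }
    __syncthreads();
  }
  if (threadIdx.x == 0) {
    *out_token = best_idx[0];
  }
}

} // namespace

// Sketch of a wrapper matching the declared signature; float only for brevity.
int32_t argmax_cuda_sketch(
    const void* logits_ptr,
    int vocab_size,
    ::executorch::aten::ScalarType scalar_type,
    cudaStream_t cuda_stream,
    int* out_token_gpu) {
  if (scalar_type != ::executorch::aten::ScalarType::Float || vocab_size <= 0) {
    return -1;
  }
  argmax_float_kernel<<<1, kThreads, 0, cuda_stream>>>(
      static_cast<const float*>(logits_ptr), vocab_size, out_token_gpu);
  int token = -1;
  cudaMemcpyAsync(
      &token, out_token_gpu, sizeof(int), cudaMemcpyDeviceToHost, cuda_stream);
  // Wait for the kernel and the copy so the token is valid on the host.
  cudaStreamSynchronize(cuda_stream);
  return token;
}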
