Add CudaSampler class for GPU-based token sampling #16387
Open · larryliu0820 wants to merge 1 commit into gh/larryliu0820/87/head from gh/larryliu0820/88/head
+193 −0
First file (new, +103 lines) — the CudaSampler implementation:

```cpp
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/extension/llm/sampler/cuda_sampler.h>
#include <executorch/runtime/platform/log.h>

namespace executorch {
namespace extension {
namespace llm {

// CudaSampler implementation
//
// IMPORTANT: Stream synchronization considerations
// ------------------------------------------------
// CudaSampler uses the default CUDA stream (nullptr) rather than creating its
// own stream. This is a deliberate design choice for the following reasons:
//
// 1. The CUDA backend (cuda_backend.cpp) creates its own stream internally for
//    running the model (encoder/decoder). This stream is encapsulated inside
//    CudaDelegateHandle and is not exposed through any public API.
//
// 2. When the decoder produces logits on the backend's stream, we need those
//    logits to be fully written before we run argmax on them. Using different
//    streams without explicit synchronization could cause race conditions.
//
// 3. The legacy default stream (stream 0 / nullptr) has special synchronization
//    semantics: operations on the default stream implicitly synchronize with
//    operations on other streams in the same CUDA context. This means:
//    - The argmax kernel will wait for the decoder to finish writing logits
//    - No explicit cudaDeviceSynchronize() or cross-stream synchronization needed
//
// 4. Trade-off: Using the default stream prevents concurrent execution between
//    the sampler and other CUDA operations. However, for single-token argmax
//    on a vocabulary-sized tensor, this overhead is negligible compared to the
//    complexity of managing cross-stream synchronization.
//
// If in the future the CUDA backend exposes its stream, we could pass it here
// for tighter integration and potential pipelining opportunities.
//
CudaSampler::CudaSampler() : out_token_gpu_(nullptr) {
  // Allocate GPU memory for output token
  cudaError_t err = cudaMalloc(&out_token_gpu_, sizeof(int));
  if (err != cudaSuccess) {
    ET_LOG(
        Error,
        "Failed to allocate GPU memory for CudaSampler: %s",
        cudaGetErrorString(err));
    out_token_gpu_ = nullptr;
    return;
  }
  // Note: We intentionally do NOT create a CUDA stream here.
  // We use the default stream (nullptr) for synchronization with the backend.
  // See the detailed comment above for rationale.
}

CudaSampler::~CudaSampler() {
  // Note: No stream to destroy since we use the default stream (nullptr)
  if (out_token_gpu_ != nullptr) {
    cudaFree(out_token_gpu_);
  }
}

CudaSampler::CudaSampler(CudaSampler&& other) noexcept
    : out_token_gpu_(other.out_token_gpu_) {
  other.out_token_gpu_ = nullptr;
}

CudaSampler& CudaSampler::operator=(CudaSampler&& other) noexcept {
  if (this != &other) {
    // Clean up existing resources
    if (out_token_gpu_ != nullptr) {
      cudaFree(out_token_gpu_);
    }
    // Take ownership of other's resources
    out_token_gpu_ = other.out_token_gpu_;
    other.out_token_gpu_ = nullptr;
  }
  return *this;
}

int32_t CudaSampler::sample_argmax(
    const void* logits_ptr,
    int vocab_size,
    ::executorch::aten::ScalarType scalar_type) {
  if (out_token_gpu_ == nullptr) {
    ET_LOG(Error, "CudaSampler not properly initialized");
    return -1;
  }

  // Use default stream (nullptr) for implicit synchronization with backend
  return cuda::argmax_cuda(
      logits_ptr, vocab_size, scalar_type, nullptr, out_token_gpu_);
}

} // namespace llm
} // namespace extension
} // namespace executorch
```
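The implicit synchronization the rationale relies on is a documented property of the legacy default stream: work submitted to stream 0 does not begin until previously submitted work on other blocking streams in the same context has completed (the guarantee does not hold for streams created with `cudaStreamNonBlocking`). A minimal sketch of the pattern, with hypothetical `decode_step` and `argmax_kernel` kernels standing in for the real decoder and the kernel in argmax.cu:

```cpp
#include <cuda_runtime.h>

// Hypothetical stand-ins for the decoder and the argmax kernel.
__global__ void decode_step(float* logits) {
  // ... writes the logits for the current token ...
}

__global__ void argmax_kernel(const float* logits, int n, int* out) {
  // ... reduces logits to the index of the max element ...
}

void sample_after_decode(
    float* logits,
    int vocab_size,
    int* out_token_gpu,
    cudaStream_t backend_stream) {
  // Decoder writes logits on the backend's private stream.
  decode_step<<<1, 256, 0, backend_stream>>>(logits);

  // Launching on the legacy default stream (nullptr) implicitly waits for
  // the work already queued on backend_stream, so the logits are complete
  // before argmax reads them -- no cudaDeviceSynchronize() required,
  // provided backend_stream was created without cudaStreamNonBlocking.
  argmax_kernel<<<1, 256, 0, nullptr>>>(logits, vocab_size, out_token_gpu);
}
```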
Second file (new, +90 lines) — the CudaSampler header:

```cpp
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#ifdef CUDA_AVAILABLE

#include <cuda_runtime.h>
#include <cstdint>

#include <executorch/runtime/core/exec_aten/exec_aten.h>

namespace executorch {
namespace extension {
namespace llm {

/**
 * CUDA-based sampler for performing argmax on GPU.
 * This class avoids memory allocation in the hot path by pre-allocating
 * scratch space on initialization.
 *
 * NOTE: This sampler uses the default CUDA stream (nullptr) rather than
 * creating its own stream. This provides implicit synchronization with
 * the CUDA backend's stream, ensuring that logits are fully written before
 * argmax reads them. See argmax.cu for detailed rationale.
 */
class CudaSampler {
 public:
  CudaSampler();
  ~CudaSampler();

  // Non-copyable
  CudaSampler(const CudaSampler&) = delete;
  CudaSampler& operator=(const CudaSampler&) = delete;

  // Movable
  CudaSampler(CudaSampler&& other) noexcept;
  CudaSampler& operator=(CudaSampler&& other) noexcept;

  /**
   * Perform argmax sampling on GPU logits.
   *
   * @param logits_ptr Pointer to GPU memory containing logits
   * @param vocab_size Vocabulary size (number of logits)
   * @param scalar_type Data type of the logits tensor
   * @return The token index with the highest logit value, or -1 on error
   */
  int32_t sample_argmax(
      const void* logits_ptr,
      int vocab_size,
      ::executorch::aten::ScalarType scalar_type);

 private:
  // Pre-allocated GPU memory for output token
  int* out_token_gpu_;
};

namespace cuda {

/**
 * Perform argmax on a GPU logits tensor.
 * This is a lower-level function that requires pre-allocated GPU memory.
 *
 * @param logits_ptr Pointer to GPU memory containing logits
 * @param vocab_size Vocabulary size
 * @param scalar_type Data type of the logits
 * @param cuda_stream CUDA stream for async execution
 * @param out_token_gpu Pre-allocated GPU memory for output token
 * @return The token index with highest logit, or -1 on error
 */
int32_t argmax_cuda(
    const void* logits_ptr,
    int vocab_size,
    ::executorch::aten::ScalarType scalar_type,
    cudaStream_t cuda_stream,
    int* out_token_gpu);

} // namespace cuda

} // namespace llm
} // namespace extension
} // namespace executorch

#endif // CUDA_AVAILABLE
```
Review comment: I'm wondering whether we really need the sampler and the CUDA backend to use the same CUDA stream, since sampling and decoding should be able to run in parallel: the argmax over logits_{i} has no dependency on the decoder generating logits_{i+1}, so the two could overlap, and that parallelism cannot happen if argmax and the decoder share a single stream.
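A minimal sketch of the event-based alternative the comment points at, assuming the backend exposed its stream (which, per the rationale above, it currently does not). `backend_stream` and the elided kernel launch sites are hypothetical:

```cpp
#include <cuda_runtime.h>

// Hypothetical: overlap argmax on logits_i with decoding of logits_{i+1}.
void sample_overlapped(cudaStream_t backend_stream) {
  cudaStream_t sampler_stream;
  cudaEvent_t logits_ready;
  // Non-blocking stream so the sampler does not implicitly sync with stream 0.
  cudaStreamCreateWithFlags(&sampler_stream, cudaStreamNonBlocking);
  cudaEventCreateWithFlags(&logits_ready, cudaEventDisableTiming);

  // ... decoder enqueues the step producing logits_i on backend_stream ...

  // Record the point at which logits_i are fully written.
  cudaEventRecord(logits_ready, backend_stream);
  // The sampler waits only on that event, not on the whole backend stream,
  // so the decoder can start computing logits_{i+1} immediately.
  cudaStreamWaitEvent(sampler_stream, logits_ready, 0);
  // ... launch the argmax kernel for logits_i on sampler_stream ...

  cudaEventDestroy(logits_ready);
  cudaStreamDestroy(sampler_stream);
}
```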