From 1637ec1b1439508403ddff06d1348edb1f286989 Mon Sep 17 00:00:00 2001 From: Mengwei Liu Date: Wed, 24 Dec 2025 01:46:58 +0000 Subject: [PATCH] Update [ghstack-poisoned] --- backends/cuda/runtime/cuda_backend.cpp | 31 ++++++++++- extension/asr/runner/CMakeLists.txt | 9 ++++ extension/asr/runner/runner.cpp | 75 +++++++++++++++++++++----- extension/llm/runner/CMakeLists.txt | 11 ++++ 4 files changed, 111 insertions(+), 15 deletions(-) diff --git a/backends/cuda/runtime/cuda_backend.cpp b/backends/cuda/runtime/cuda_backend.cpp index cd1c6b96f02..3577335ed22 100644 --- a/backends/cuda/runtime/cuda_backend.cpp +++ b/backends/cuda/runtime/cuda_backend.cpp @@ -83,7 +83,36 @@ class ET_EXPERIMENTAL CudaBackend final return false; } std::lock_guard guard(skip_copy_method_mutex_); - return method_name == skip_copy_method_; + // Support comma-separated list of method names + if (skip_copy_method_.empty()) { + return false; + } + // Check if method_name matches any entry in the comma-separated list + size_t start = 0; + size_t end = skip_copy_method_.find(','); + while (end != std::string::npos) { + std::string entry = skip_copy_method_.substr(start, end - start); + // Trim whitespace + size_t entry_start = entry.find_first_not_of(" \t"); + size_t entry_end = entry.find_last_not_of(" \t"); + if (entry_start != std::string::npos) { + entry = entry.substr(entry_start, entry_end - entry_start + 1); + if (entry == method_name) { + return true; + } + } + start = end + 1; + end = skip_copy_method_.find(',', start); + } + // Check last (or only) entry + std::string entry = skip_copy_method_.substr(start); + size_t entry_start = entry.find_first_not_of(" \t"); + size_t entry_end = entry.find_last_not_of(" \t"); + if (entry_start != std::string::npos) { + entry = entry.substr(entry_start, entry_end - entry_start + 1); + return entry == method_name; + } + return false; } Error load_function_pointers_into_handle( diff --git a/extension/asr/runner/CMakeLists.txt 
b/extension/asr/runner/CMakeLists.txt index c3d77712017..e3101e96401 100644 --- a/extension/asr/runner/CMakeLists.txt +++ b/extension/asr/runner/CMakeLists.txt @@ -42,6 +42,15 @@ if(EXECUTORCH_BUILD_CUDA) find_package(CUDAToolkit QUIET) if(CUDAToolkit_FOUND) target_compile_definitions(extension_asr_runner PUBLIC CUDA_AVAILABLE) + target_include_directories( + extension_asr_runner PUBLIC ${CUDAToolkit_INCLUDE_DIRS} + ) + # Link against the CUDA sampler library from extension/llm/sampler + if(TARGET extension_llm_sampler_cuda) + target_link_libraries(extension_asr_runner PUBLIC extension_llm_sampler_cuda) + else() + target_link_libraries(extension_asr_runner PUBLIC CUDA::cudart) + endif() message(STATUS "CUDAToolkit found; defining CUDA_AVAILABLE for ASR runner") else() message( diff --git a/extension/asr/runner/runner.cpp b/extension/asr/runner/runner.cpp index 6d5c61696d9..6267f18e602 100644 --- a/extension/asr/runner/runner.cpp +++ b/extension/asr/runner/runner.cpp @@ -22,6 +22,10 @@ #include #include +#ifdef CUDA_AVAILABLE +#include <executorch/extension/llm/sampler/cuda_sampler.h> +#endif + namespace executorch::extension::asr { namespace { @@ -110,19 +114,25 @@ Error AsrRunner::load() { ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method(kDecoderMethodName)); decoder_method_loaded_ = true; #ifdef CUDA_AVAILABLE - executorch::runtime::BackendOptions<1> backend_options; - // For decoder still copy output from GPU to CPU for sampling. - // TODO: change sampler to use a CUDA kernel to sample and then skip copying - // decoder output as well - ET_CHECK_OK_OR_RETURN_ERROR(backend_options.set_option( - "skip_copy_output_to_cpu_for_method", kEncoderMethodName)); - const auto opt_err = - executorch::runtime::set_option("CudaBackend", backend_options.view()); - if (opt_err != ::executorch::runtime::Error::Ok) { - ET_LOG( - Error, - "Failed to set CUDA backend options: %d", - static_cast<int>(opt_err)); + { + // Skip copying outputs to CPU for both encoder and decoder methods.
+ // Encoder output stays on GPU for the decoder to consume directly. + // Decoder logits stay on GPU for CUDA-based sampling (temperature=0). + // For temperature != 0, we fall back to CPU sampling which will require + // a copy, but that path is less common for ASR applications. + std::string skip_methods = + std::string(kEncoderMethodName) + "," + kDecoderMethodName; + executorch::runtime::BackendOptions<1> backend_options; + ET_CHECK_OK_OR_RETURN_ERROR(backend_options.set_option( + "skip_copy_output_to_cpu_for_method", skip_methods.c_str())); + const auto opt_err = + executorch::runtime::set_option("CudaBackend", backend_options.view()); + if (opt_err != ::executorch::runtime::Error::Ok) { + ET_LOG( + Error, + "Failed to set CUDA backend options: %d", + static_cast<int>(opt_err)); + } } #endif ET_CHECK_OK_OR_RETURN_ERROR(load_tokenizer()); @@ -266,6 +276,18 @@ Result<std::vector<int64_t>> AsrRunner::transcribe( decoder_inputs.emplace_back(decoder_input_ptr); decoder_inputs.emplace_back(encoder_output_ptr); decoder_inputs.emplace_back(cache_position_ptr); + +#ifdef CUDA_AVAILABLE + // Create CUDA sampler outside the loop to avoid memory allocation overhead. + // Only used when temperature == 0 (argmax sampling).
+ const bool use_cuda_sampler = (config.temperature == 0.0f); + std::optional<::executorch::extension::llm::CudaSampler> cuda_sampler; + if (use_cuda_sampler) { + cuda_sampler.emplace(); + ET_LOG(Info, "Using CUDA sampler for argmax sampling"); + } +#endif + + // Add some green coloring for the first generated token // token_callback("\033[1;32m"); while (generated_tokens < config.max_new_tokens) { @@ -286,9 +308,34 @@ ET_CHECK_OR_RETURN_ERROR( vocab_size > 0, Internal, "Decoder logits tensor is empty."); - const int64_t next_token = + int64_t next_token; +#ifdef CUDA_AVAILABLE + if (use_cuda_sampler && cuda_sampler.has_value()) { + // Use CUDA-based argmax sampling - logits are already on GPU + next_token = static_cast<int64_t>(cuda_sampler->sample_argmax( + logits_tensor.const_data_ptr(), + static_cast<int>(vocab_size), + logits_tensor.scalar_type())); + ET_CHECK_OR_RETURN_ERROR( + next_token >= 0, + Internal, + "CUDA sampler failed to sample token"); + } else { + // Fall back to CPU sampling for temperature != 0 + // Note: This requires the logits to be copied to CPU, which happens + // automatically when skip_copy_output_to_cpu_for_method doesn't include + // the decoder method. Since we include decoder in the skip list, we need + // to handle this case differently in the future if we want to support + // temperature != 0 with CUDA.
+ next_token = + static_cast<int64_t>(::executorch::extension::llm::logits_to_token( + logits_tensor, config.temperature)); + } +#else + next_token = static_cast<int64_t>(::executorch::extension::llm::logits_to_token( + logits_tensor, config.temperature)); +#endif if (!first_token_generated) { stats_.first_token_ms = ::executorch::extension::llm::time_in_ms(); diff --git a/extension/llm/runner/CMakeLists.txt b/extension/llm/runner/CMakeLists.txt index 6a2c1989922..9bf3ad997f3 100644 --- a/extension/llm/runner/CMakeLists.txt +++ b/extension/llm/runner/CMakeLists.txt @@ -66,6 +66,17 @@ if(EXECUTORCH_BUILD_CUDA) target_compile_definitions(extension_llm_runner PUBLIC CUDA_AVAILABLE) target_link_libraries(extension_llm_runner PUBLIC CUDA::cudart) message(STATUS "CUDAToolkit found; defining CUDA_AVAILABLE") + + # Build the CUDA sampler library + if(NOT TARGET extension_llm_sampler_cuda) + add_subdirectory( + ${EXECUTORCH_ROOT}/extension/llm/sampler + ${CMAKE_CURRENT_BINARY_DIR}/sampler + ) + endif() + if(TARGET extension_llm_sampler_cuda) + target_link_libraries(extension_llm_runner PUBLIC extension_llm_sampler_cuda) + endif() else() message( STATUS