diff --git a/extension/llm/sampler/CMakeLists.txt b/extension/llm/sampler/CMakeLists.txt new file mode 100644 index 00000000000..57736a2d1a0 --- /dev/null +++ b/extension/llm/sampler/CMakeLists.txt @@ -0,0 +1,76 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# +# LLM sampler library with optional CUDA support +# +# ### Editing this file ### +# +# This file should be formatted with +# ~~~ +# cmake-format -i CMakeLists.txt +# ~~~ +# It should also be cmake-lint clean. +# + +if(NOT EXECUTORCH_ROOT) + set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../..) +endif() + +include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake) + +# If the project is configured to build with CUDA support, build the CUDA +# sampler library. +if(EXECUTORCH_BUILD_CUDA) + find_package(CUDAToolkit QUIET) + if(CUDAToolkit_FOUND) + enable_language(CUDA) + set(CMAKE_CUDA_STANDARD 17) + set(CMAKE_CUDA_STANDARD_REQUIRED ON) + + # Define CUDA sampler library + add_library(extension_llm_sampler_cuda STATIC argmax.cu cuda_sampler.cu) + target_include_directories( + extension_llm_sampler_cuda + PUBLIC ${EXECUTORCH_ROOT} + ${EXECUTORCH_ROOT}/.. + ${CUDAToolkit_INCLUDE_DIRS} + ) + target_compile_definitions(extension_llm_sampler_cuda PUBLIC CUDA_AVAILABLE) + target_link_libraries(extension_llm_sampler_cuda PUBLIC executorch_core + CUDA::cudart + ) + set_target_properties( + extension_llm_sampler_cuda + PROPERTIES POSITION_INDEPENDENT_CODE ON CUDA_SEPARABLE_COMPILATION ON + ) + + message( + STATUS "CUDAToolkit found; building extension_llm_sampler_cuda library" + ) + + install( + TARGETS extension_llm_sampler_cuda + EXPORT ExecuTorchTargets + DESTINATION ${CMAKE_INSTALL_LIBDIR} + ) + else() + message( + STATUS + "CUDA requested (EXECUTORCH_BUILD_CUDA=ON) but no CUDA runtime found" + ) + endif() +endif() + +# Install header files +install( + DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/ + DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/executorch/extension/llm/sampler + FILES_MATCHING + PATTERN "*.h" + PATTERN "*.cuh" + PATTERN "test" EXCLUDE +) + diff --git a/extension/llm/sampler/argmax.cu b/extension/llm/sampler/argmax.cu new file mode 100644 index 00000000000..46b0bf0ca47 --- /dev/null +++ b/extension/llm/sampler/argmax.cu @@ -0,0 +1,72 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include + +namespace executorch { +namespace extension { +namespace llm { +namespace cuda { + +// Wrapper function that performs argmax on GPU logits tensor +// Returns the token index with the highest logit value +// logits_ptr: pointer to GPU memory containing logits +// vocab_size: vocabulary size +// scalar_type: data type of the logits tensor +// cuda_stream: CUDA stream for async execution (nullptr for default stream) +// out_token_gpu: pre-allocated GPU memory for output token (int*) +int32_t argmax_cuda( + const void* logits_ptr, + int vocab_size, + ::executorch::aten::ScalarType scalar_type, + cudaStream_t cuda_stream, + int* out_token_gpu) { + // Launch kernel for single row (batch size 1) + launch_argmax_vocab_rows( + logits_ptr, + scalar_type, + 1, // rows = 1 + vocab_size, + out_token_gpu, + nullptr, // don't need max logit value + cuda_stream, + 256 // threads per block + ); + + // Copy result back to host + int32_t token; + cudaError_t err = cudaMemcpyAsync( + &token, out_token_gpu, sizeof(int), cudaMemcpyDeviceToHost, cuda_stream); + if (err != cudaSuccess) { + ET_LOG( + Error, + "Failed to copy argmax result from GPU: %s", + cudaGetErrorString(err)); + return -1; + } + + // Synchronize to ensure result is ready + err = cudaStreamSynchronize(cuda_stream); + if (err != cudaSuccess) { + ET_LOG( + Error, + "Failed to synchronize CUDA stream: %s", + cudaGetErrorString(err)); + return -1; + } + + return token; +} + +} // namespace cuda +} // namespace llm +} // namespace extension +} // namespace executorch + diff --git a/extension/llm/sampler/argmax.cuh b/extension/llm/sampler/argmax.cuh new file mode 100644 index 00000000000..0cbd0a99d79 --- /dev/null +++ b/extension/llm/sampler/argmax.cuh @@ -0,0 +1,171 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include +#include + +#include + +namespace executorch { +namespace extension { +namespace llm { +namespace cuda { + +struct ArgMaxPair { + float v; + int i; +}; + +// tie-break: smaller index wins on equal values +__device__ __forceinline__ ArgMaxPair better(ArgMaxPair a, ArgMaxPair b) { + if (b.v > a.v) + return b; + if (b.v < a.v) + return a; + return (b.i < a.i) ? b : a; +} + +__device__ __forceinline__ ArgMaxPair +warp_argmax_xor(ArgMaxPair x, unsigned mask = 0xffffffffu) { + for (int d = 16; d > 0; d >>= 1) { + ArgMaxPair y; + y.v = __shfl_xor_sync(mask, x.v, d); + y.i = __shfl_xor_sync(mask, x.i, d); + x = better(x, y); + } + return x; +} + +// ---- dtype -> float load helpers ---- +template +__device__ __forceinline__ float load_as_float(const T* p); + +template <> +__device__ __forceinline__ float load_as_float(const float* p) { + return *p; +} + +template <> +__device__ __forceinline__ float load_as_float(const half* p) { + return __half2float(*p); +} + +template <> +__device__ __forceinline__ float +load_as_float(const nv_bfloat16* p) { + return __bfloat162float(*p); +} + +// logits: [rows, vocab] row-major contiguous +// out_token: [rows] +// out_maxlogit: [rows] (optional; pass nullptr if not needed) +template +__global__ void argmax_vocab_rows_kernel( + const T* __restrict__ logits, + int rows, + int vocab, + int* __restrict__ out_token, + float* __restrict__ out_maxlogit) { + int row = blockIdx.x; + if (row >= rows) + return; + + int tid = threadIdx.x; + int lane = tid & 31; + int warp = tid >> 5; + int warps_per_block = (blockDim.x + 31) >> 5; + + const T* row_ptr = logits + (size_t)row * (size_t)vocab; + + // local scan + ArgMaxPair best; + best.v = -FLT_MAX; + best.i = -1; + + for (int j = tid; j < vocab; j += blockDim.x) { + float v = load_as_float(row_ptr + j); + best = better(best, ArgMaxPair{v, j}); + } + + // warp reduce + best = warp_argmax_xor(best); + + // shared collect warp winners (supports up to 1024 threads = 32 warps) + __shared__ float s_val[32]; + __shared__ int s_idx[32]; + + if (lane == 0) { + s_val[warp] = best.v; + s_idx[warp] = best.i; + } + __syncthreads(); + + // first warp reduces warp winners + if (warp == 0) { + ArgMaxPair wbest; + if (lane < warps_per_block) { + wbest.v = s_val[lane]; + wbest.i = s_idx[lane]; + } else { + wbest.v = -FLT_MAX; + wbest.i = -1; + } + + wbest = warp_argmax_xor(wbest); + + if (lane == 0) { + out_token[row] = wbest.i; + if (out_maxlogit) + out_maxlogit[row] = wbest.v; + } + } +} + +inline void launch_argmax_vocab_rows( + const void* logits, + ::executorch::aten::ScalarType scalar_type, + int rows, + int vocab, + int* out_token, + float* out_maxlogit, + cudaStream_t stream, + int threads = 256) { + dim3 block(threads); + dim3 grid(rows); + + switch (scalar_type) { + case ::executorch::aten::ScalarType::Float: + argmax_vocab_rows_kernel<<>>( + (const float*)logits, rows, vocab, out_token, out_maxlogit); + break; + case ::executorch::aten::ScalarType::Half: + argmax_vocab_rows_kernel<<>>( + (const half*)logits, rows, vocab, out_token, out_maxlogit); + break; + case ::executorch::aten::ScalarType::BFloat16: + argmax_vocab_rows_kernel<<>>( + (const nv_bfloat16*)logits, rows, vocab, out_token, out_maxlogit); + break; + default: + // Unsupported type, fall back to float + argmax_vocab_rows_kernel<<>>( + (const float*)logits, rows, vocab, out_token, out_maxlogit); + break; + } +} + +} // namespace cuda +} // namespace llm +} // namespace extension +} // namespace executorch + diff --git a/extension/llm/sampler/test/CMakeLists.txt b/extension/llm/sampler/test/CMakeLists.txt new file mode 100644 index 00000000000..82e3fc6577a --- /dev/null +++ b/extension/llm/sampler/test/CMakeLists.txt @@ -0,0 +1,63 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +cmake_minimum_required(VERSION 3.19) +project(llm_sampler_cuda_tests LANGUAGES CXX CUDA) + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) +set(CMAKE_CUDA_STANDARD 17) +set(CMAKE_CUDA_STANDARD_REQUIRED ON) + +# Find required packages +find_package(CUDAToolkit REQUIRED) + +# Fetch GoogleTest +include(FetchContent) +FetchContent_Declare( + googletest + GIT_REPOSITORY https://github.com/google/googletest.git + GIT_TAG v1.14.0 +) +# For Windows: Prevent overriding the parent project's compiler/linker settings +set(gtest_force_shared_crt + ON + CACHE BOOL "" FORCE +) +FetchContent_MakeAvailable(googletest) + +# Get EXECUTORCH_ROOT +if(NOT EXECUTORCH_ROOT) + set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../../..) +endif() + +# Find installed ExecuTorch +find_package(executorch CONFIG REQUIRED HINTS ${CMAKE_INSTALL_PREFIX}) + +# List of CUDA test files +set(LLM_SAMPLER_CUDA_TESTS test_argmax) + +enable_testing() + +foreach(test_name ${LLM_SAMPLER_CUDA_TESTS}) + add_executable(${test_name} ${test_name}.cu) + + target_include_directories( + ${test_name} PRIVATE ${EXECUTORCH_ROOT}/.. ${EXECUTORCH_ROOT} + ${CUDAToolkit_INCLUDE_DIRS} + ) + + target_link_libraries( + ${test_name} + PRIVATE GTest::gtest + GTest::gtest_main + executorch_core + CUDA::cudart + ) + + add_test(NAME ${test_name} COMMAND ${test_name}) +endforeach() + diff --git a/extension/llm/sampler/test/CMakePresets.json b/extension/llm/sampler/test/CMakePresets.json new file mode 100644 index 00000000000..427939ce05f --- /dev/null +++ b/extension/llm/sampler/test/CMakePresets.json @@ -0,0 +1,96 @@ +{ + "version": 6, + "configurePresets": [ + { + "name": "default", + "displayName": "LLM Sampler CUDA Tests", + "binaryDir": "${sourceDir}/../../../../cmake-out/extension/llm/sampler/test", + "cacheVariables": { + "CMAKE_BUILD_TYPE": "Release", + "CMAKE_PREFIX_PATH": "${sourceDir}/../../../../cmake-out" + }, + "condition": { + "type": "inList", + "string": "${hostSystemName}", + "list": ["Linux", "Windows"] + } + }, + { + "name": "debug", + "displayName": "LLM Sampler CUDA Tests (Debug)", + "inherits": ["default"], + "cacheVariables": { + "CMAKE_BUILD_TYPE": "Debug" + } + } + ], + "buildPresets": [ + { + "name": "default", + "displayName": "Build LLM Sampler CUDA Tests", + "configurePreset": "default" + }, + { + "name": "debug", + "displayName": "Build LLM Sampler CUDA Tests (Debug)", + "configurePreset": "debug" + } + ], + "workflowPresets": [ + { + "name": "default", + "displayName": "Configure, build, and test LLM Sampler CUDA Tests", + "steps": [ + { + "type": "configure", + "name": "default" + }, + { + "type": "build", + "name": "default" + }, + { + "type": "test", + "name": "default" + } + ] + }, + { + "name": "debug", + "displayName": "Configure, build, and test LLM Sampler CUDA Tests (Debug)", + "steps": [ + { + "type": "configure", + "name": "debug" + }, + { + "type": "build", + "name": "debug" + }, + { + "type": "test", + "name": "debug" + } + ] + } + ], + "testPresets": [ + { + "name": "default", + "displayName": "Run all LLM Sampler CUDA Tests", + "configurePreset": "default", + "output": { + "outputOnFailure": true + } + }, + { + "name": "debug", + "displayName": "Run all LLM Sampler CUDA Tests (Debug)", + "configurePreset": "debug", + "output": { + "outputOnFailure": true + } + } + ] +} + diff --git a/extension/llm/sampler/test/test_argmax.cu b/extension/llm/sampler/test/test_argmax.cu new file mode 100644 index 00000000000..8f12b5231c5 --- /dev/null +++ b/extension/llm/sampler/test/test_argmax.cu @@ -0,0 +1,488 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include + +#include +#include +#include +#include + +#include +#include + +using namespace executorch::extension::llm::cuda; + +// Test fixture for argmax tests +class ArgmaxTest : public ::testing::Test { + protected: + void SetUp() override { + // Initialize ExecuTorch Platform Abstraction Layer + et_pal_init(); + + // Check if CUDA is available + int device_count = 0; + cudaError_t err = cudaGetDeviceCount(&device_count); + if (err != cudaSuccess || device_count == 0) { + GTEST_SKIP() << "CUDA not available, skipping CUDA tests"; + } + + // Create CUDA stream + ASSERT_EQ(cudaStreamCreate(&stream_), cudaSuccess); + + // Allocate output buffers on GPU + ASSERT_EQ(cudaMalloc(&out_token_gpu_, sizeof(int) * max_rows_), cudaSuccess); + ASSERT_EQ( + cudaMalloc(&out_maxlogit_gpu_, sizeof(float) * max_rows_), cudaSuccess); + } + + void TearDown() override { + if (out_token_gpu_) { + cudaFree(out_token_gpu_); + } + if (out_maxlogit_gpu_) { + cudaFree(out_maxlogit_gpu_); + } + if (stream_) { + cudaStreamDestroy(stream_); + } + } + + // Helper to create and upload float tensor to GPU + float* create_float_tensor(const std::vector& data) { + float* gpu_ptr; + EXPECT_EQ(cudaMalloc(&gpu_ptr, data.size() * sizeof(float)), cudaSuccess); + EXPECT_EQ( + cudaMemcpy( + gpu_ptr, + data.data(), + data.size() * sizeof(float), + cudaMemcpyHostToDevice), + cudaSuccess); + return gpu_ptr; + } + + // Helper to create and upload half tensor to GPU + half* create_half_tensor(const std::vector& data) { + std::vector half_data(data.size()); + for (size_t i = 0; i < data.size(); ++i) { + half_data[i] = __float2half(data[i]); + } + half* gpu_ptr; + EXPECT_EQ(cudaMalloc(&gpu_ptr, data.size() * sizeof(half)), cudaSuccess); + EXPECT_EQ( + cudaMemcpy( + gpu_ptr, + half_data.data(), + data.size() * sizeof(half), + cudaMemcpyHostToDevice), + cudaSuccess); + return gpu_ptr; + } + + // Helper to create and upload bfloat16 tensor to GPU + nv_bfloat16* create_bfloat16_tensor(const std::vector& data) { + std::vector bf16_data(data.size()); + for (size_t i = 0; i < data.size(); ++i) { + bf16_data[i] = __float2bfloat16(data[i]); + } + nv_bfloat16* gpu_ptr; + EXPECT_EQ( + cudaMalloc(&gpu_ptr, data.size() * sizeof(nv_bfloat16)), cudaSuccess); + EXPECT_EQ( + cudaMemcpy( + gpu_ptr, + bf16_data.data(), + data.size() * sizeof(nv_bfloat16), + cudaMemcpyHostToDevice), + cudaSuccess); + return gpu_ptr; + } + + // Helper to get CPU argmax for verification + int cpu_argmax(const std::vector& data) { + return static_cast( + std::max_element(data.begin(), data.end()) - data.begin()); + } + + cudaStream_t stream_ = nullptr; + int* out_token_gpu_ = nullptr; + float* out_maxlogit_gpu_ = nullptr; + static constexpr int max_rows_ = 16; +}; + +// Test basic argmax with float32 +TEST_F(ArgmaxTest, BasicFloat32) { + std::vector logits = {0.1f, 0.5f, 0.8f, 0.3f, 0.2f}; + int vocab_size = static_cast(logits.size()); + int expected_idx = cpu_argmax(logits); + + float* gpu_logits = create_float_tensor(logits); + + launch_argmax_vocab_rows( + gpu_logits, + ::executorch::aten::ScalarType::Float, + 1, // rows + vocab_size, + out_token_gpu_, + out_maxlogit_gpu_, + stream_, + 256); + + ASSERT_EQ(cudaStreamSynchronize(stream_), cudaSuccess); + + int out_token; + float out_maxlogit; + ASSERT_EQ( + cudaMemcpy(&out_token, out_token_gpu_, sizeof(int), cudaMemcpyDeviceToHost), + cudaSuccess); + ASSERT_EQ( + cudaMemcpy( + &out_maxlogit, out_maxlogit_gpu_, sizeof(float), cudaMemcpyDeviceToHost), + cudaSuccess); + + EXPECT_EQ(out_token, expected_idx); + EXPECT_FLOAT_EQ(out_maxlogit, logits[expected_idx]); + + cudaFree(gpu_logits); +} + +// Test argmax with half precision +TEST_F(ArgmaxTest, BasicHalf) { + std::vector logits = {0.1f, 0.2f, 0.9f, 0.4f, 0.5f, 0.6f}; + int vocab_size = static_cast(logits.size()); + int expected_idx = cpu_argmax(logits); + + half* gpu_logits = create_half_tensor(logits); + + launch_argmax_vocab_rows( + gpu_logits, + ::executorch::aten::ScalarType::Half, + 1, // rows + vocab_size, + out_token_gpu_, + out_maxlogit_gpu_, + stream_, + 256); + + ASSERT_EQ(cudaStreamSynchronize(stream_), cudaSuccess); + + int out_token; + float out_maxlogit; + ASSERT_EQ( + cudaMemcpy(&out_token, out_token_gpu_, sizeof(int), cudaMemcpyDeviceToHost), + cudaSuccess); + ASSERT_EQ( + cudaMemcpy( + &out_maxlogit, out_maxlogit_gpu_, sizeof(float), cudaMemcpyDeviceToHost), + cudaSuccess); + + EXPECT_EQ(out_token, expected_idx); + // Half precision has some tolerance + EXPECT_NEAR(out_maxlogit, logits[expected_idx], 0.01f); + + cudaFree(gpu_logits); +} + +// Test argmax with bfloat16 +TEST_F(ArgmaxTest, BasicBFloat16) { + std::vector logits = {-1.0f, 2.5f, 1.0f, 0.5f}; + int vocab_size = static_cast(logits.size()); + int expected_idx = cpu_argmax(logits); + + nv_bfloat16* gpu_logits = create_bfloat16_tensor(logits); + + launch_argmax_vocab_rows( + gpu_logits, + ::executorch::aten::ScalarType::BFloat16, + 1, // rows + vocab_size, + out_token_gpu_, + out_maxlogit_gpu_, + stream_, + 256); + + ASSERT_EQ(cudaStreamSynchronize(stream_), cudaSuccess); + + int out_token; + float out_maxlogit; + ASSERT_EQ( + cudaMemcpy(&out_token, out_token_gpu_, sizeof(int), cudaMemcpyDeviceToHost), + cudaSuccess); + ASSERT_EQ( + cudaMemcpy( + &out_maxlogit, out_maxlogit_gpu_, sizeof(float), cudaMemcpyDeviceToHost), + cudaSuccess); + + EXPECT_EQ(out_token, expected_idx); + // BFloat16 has some tolerance + EXPECT_NEAR(out_maxlogit, logits[expected_idx], 0.1f); + + cudaFree(gpu_logits); +} + +// Test with large vocabulary (typical for LLMs) +TEST_F(ArgmaxTest, LargeVocab) { + int vocab_size = 32000; // Typical LLM vocab size + std::vector logits(vocab_size); + + // Fill with random values + std::mt19937 gen(42); + std::uniform_real_distribution dist(-5.0f, 5.0f); + for (int i = 0; i < vocab_size; ++i) { + logits[i] = dist(gen); + } + + // Set a known maximum + int expected_idx = 12345; + logits[expected_idx] = 100.0f; + + float* gpu_logits = create_float_tensor(logits); + + launch_argmax_vocab_rows( + gpu_logits, + ::executorch::aten::ScalarType::Float, + 1, // rows + vocab_size, + out_token_gpu_, + out_maxlogit_gpu_, + stream_, + 256); + + ASSERT_EQ(cudaStreamSynchronize(stream_), cudaSuccess); + + int out_token; + float out_maxlogit; + ASSERT_EQ( + cudaMemcpy(&out_token, out_token_gpu_, sizeof(int), cudaMemcpyDeviceToHost), + cudaSuccess); + ASSERT_EQ( + cudaMemcpy( + &out_maxlogit, out_maxlogit_gpu_, sizeof(float), cudaMemcpyDeviceToHost), + cudaSuccess); + + EXPECT_EQ(out_token, expected_idx); + EXPECT_FLOAT_EQ(out_maxlogit, 100.0f); + + cudaFree(gpu_logits); +} + +// Test multiple rows (batch) +TEST_F(ArgmaxTest, MultipleRows) { + int rows = 4; + int vocab_size = 10; + std::vector logits = { + // Row 0: max at index 2 + 0.1f, 0.2f, 0.9f, 0.3f, 0.4f, 0.5f, 0.1f, 0.2f, 0.3f, 0.4f, + // Row 1: max at index 5 + 0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.8f, 0.1f, 0.2f, 0.3f, 0.4f, + // Row 2: max at index 9 + 0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.1f, 0.2f, 0.3f, 0.4f, 0.95f, + // Row 3: max at index 0 + 0.99f, 0.2f, 0.3f, 0.4f, 0.5f, 0.1f, 0.2f, 0.3f, 0.4f, 0.5f, + }; + + std::vector expected_indices = {2, 5, 9, 0}; + + float* gpu_logits = create_float_tensor(logits); + + launch_argmax_vocab_rows( + gpu_logits, + ::executorch::aten::ScalarType::Float, + rows, + vocab_size, + out_token_gpu_, + out_maxlogit_gpu_, + stream_, + 256); + + ASSERT_EQ(cudaStreamSynchronize(stream_), cudaSuccess); + + std::vector out_tokens(rows); + ASSERT_EQ( + cudaMemcpy( + out_tokens.data(), + out_token_gpu_, + rows * sizeof(int), + cudaMemcpyDeviceToHost), + cudaSuccess); + + for (int i = 0; i < rows; ++i) { + EXPECT_EQ(out_tokens[i], expected_indices[i]) << "Row " << i << " failed"; + } + + cudaFree(gpu_logits); +} + +// Test tie-breaking (smaller index wins) +TEST_F(ArgmaxTest, TieBreaking) { + std::vector logits = {0.5f, 0.5f, 0.5f, 0.5f, 0.5f}; + int vocab_size = static_cast(logits.size()); + + float* gpu_logits = create_float_tensor(logits); + + launch_argmax_vocab_rows( + gpu_logits, + ::executorch::aten::ScalarType::Float, + 1, // rows + vocab_size, + out_token_gpu_, + nullptr, // don't need max logit + stream_, + 256); + + ASSERT_EQ(cudaStreamSynchronize(stream_), cudaSuccess); + + int out_token; + ASSERT_EQ( + cudaMemcpy(&out_token, out_token_gpu_, sizeof(int), cudaMemcpyDeviceToHost), + cudaSuccess); + + // Smallest index should win on tie + EXPECT_EQ(out_token, 0); + + cudaFree(gpu_logits); +} + +// Test with negative values +TEST_F(ArgmaxTest, NegativeValues) { + std::vector logits = {-5.0f, -2.0f, -1.0f, -3.0f, -4.0f}; + int vocab_size = static_cast(logits.size()); + int expected_idx = cpu_argmax(logits); // Should be index 2 (-1.0f) + + float* gpu_logits = create_float_tensor(logits); + + launch_argmax_vocab_rows( + gpu_logits, + ::executorch::aten::ScalarType::Float, + 1, // rows + vocab_size, + out_token_gpu_, + out_maxlogit_gpu_, + stream_, + 256); + + ASSERT_EQ(cudaStreamSynchronize(stream_), cudaSuccess); + + int out_token; + float out_maxlogit; + ASSERT_EQ( + cudaMemcpy(&out_token, out_token_gpu_, sizeof(int), cudaMemcpyDeviceToHost), + cudaSuccess); + ASSERT_EQ( + cudaMemcpy( + &out_maxlogit, out_maxlogit_gpu_, sizeof(float), cudaMemcpyDeviceToHost), + cudaSuccess); + + EXPECT_EQ(out_token, expected_idx); + EXPECT_FLOAT_EQ(out_maxlogit, -1.0f); + + cudaFree(gpu_logits); +} + +// Test max at first position +TEST_F(ArgmaxTest, MaxAtFirst) { + std::vector logits = {10.0f, 1.0f, 2.0f, 3.0f, 4.0f}; + int vocab_size = static_cast(logits.size()); + + float* gpu_logits = create_float_tensor(logits); + + launch_argmax_vocab_rows( + gpu_logits, + ::executorch::aten::ScalarType::Float, + 1, + vocab_size, + out_token_gpu_, + nullptr, + stream_, + 256); + + ASSERT_EQ(cudaStreamSynchronize(stream_), cudaSuccess); + + int out_token; + ASSERT_EQ( + cudaMemcpy(&out_token, out_token_gpu_, sizeof(int), cudaMemcpyDeviceToHost), + cudaSuccess); + + EXPECT_EQ(out_token, 0); + + cudaFree(gpu_logits); +} + +// Test max at last position +TEST_F(ArgmaxTest, MaxAtLast) { + std::vector logits = {1.0f, 2.0f, 3.0f, 4.0f, 10.0f}; + int vocab_size = static_cast(logits.size()); + + float* gpu_logits = create_float_tensor(logits); + + launch_argmax_vocab_rows( + gpu_logits, + ::executorch::aten::ScalarType::Float, + 1, + vocab_size, + out_token_gpu_, + nullptr, + stream_, + 256); + + ASSERT_EQ(cudaStreamSynchronize(stream_), cudaSuccess); + + int out_token; + ASSERT_EQ( + cudaMemcpy(&out_token, out_token_gpu_, sizeof(int), cudaMemcpyDeviceToHost), + cudaSuccess); + + EXPECT_EQ(out_token, 4); + + cudaFree(gpu_logits); +} + +// Test with different thread counts +TEST_F(ArgmaxTest, DifferentThreadCounts) { + std::vector logits(1024); + std::mt19937 gen(123); + std::uniform_real_distribution dist(-10.0f, 10.0f); + for (size_t i = 0; i < logits.size(); ++i) { + logits[i] = dist(gen); + } + + int expected_idx = cpu_argmax(logits); + float* gpu_logits = create_float_tensor(logits); + + std::vector thread_counts = {32, 64, 128, 256, 512}; + + for (int threads : thread_counts) { + launch_argmax_vocab_rows( + gpu_logits, + ::executorch::aten::ScalarType::Float, + 1, + static_cast(logits.size()), + out_token_gpu_, + nullptr, + stream_, + threads); + + ASSERT_EQ(cudaStreamSynchronize(stream_), cudaSuccess); + + int out_token; + ASSERT_EQ( + cudaMemcpy( + &out_token, out_token_gpu_, sizeof(int), cudaMemcpyDeviceToHost), + cudaSuccess); + + EXPECT_EQ(out_token, expected_idx) << "Failed with " << threads << " threads"; + } + + cudaFree(gpu_logits); +} + +