diff --git a/backends/cuda/runtime/shims/memory_slim.cpp b/backends/cuda/runtime/shims/memory_slim.cpp
index 7996d330db5..93fd884958c 100644
--- a/backends/cuda/runtime/shims/memory_slim.cpp
+++ b/backends/cuda/runtime/shims/memory_slim.cpp
@@ -206,6 +206,35 @@ aoti_torch_copy_(Tensor* self, Tensor* src, int32_t non_blocking) {
   return Error::Ok;
 }
 
+AOTITorchError aoti_torch_item_bool(Tensor* tensor, bool* ret_value) {
+  ET_CHECK_OR_RETURN_ERROR(
+      tensor != nullptr,
+      InvalidArgument,
+      "aoti_torch_item_bool: tensor is null");
+
+  ET_CHECK_OR_RETURN_ERROR(
+      ret_value != nullptr,
+      InvalidArgument,
+      "aoti_torch_item_bool: ret_value is null");
+
+  ET_CHECK_OR_RETURN_ERROR(
+      tensor->numel() == 1,
+      InvalidArgument,
+      "aoti_torch_item_bool: tensor must have exactly 1 element, got %zu",
+      tensor->numel());
+
+  ET_CHECK_OR_RETURN_ERROR(
+      tensor->dtype() == ScalarType::Bool,
+      InvalidArgument,
+      "aoti_torch_item_bool: tensor dtype must be Bool");
+
+  // SlimTensor::item() handles both CPU and CUDA tensors.
+  // For CUDA tensors, it copies the value to CPU automatically.
+  *ret_value = tensor->item<bool>();
+
+  return Error::Ok;
+}
+
 AOTITorchError aoti_torch_assign_tensors_out(Tensor* src, Tensor** ret_dst) {
   ET_CHECK_OR_RETURN_ERROR(
       src != nullptr,
diff --git a/backends/cuda/runtime/shims/memory_slim.h b/backends/cuda/runtime/shims/memory_slim.h
index 0f293d1e995..ec8b8db14f8 100644
--- a/backends/cuda/runtime/shims/memory_slim.h
+++ b/backends/cuda/runtime/shims/memory_slim.h
@@ -143,6 +143,19 @@ AOTI_SHIM_EXPORT AOTITorchError aoti_torch__reinterpret_tensor(
 AOTI_SHIM_EXPORT AOTITorchError
 aoti_torch_copy_(Tensor* self, Tensor* src, int32_t non_blocking);
 
+/**
+ * Extracts a boolean scalar value from a single-element tensor.
+ *
+ * The tensor must contain exactly one element and have Bool dtype.
+ * For CUDA tensors, this will synchronize to copy the value to CPU.
+ *
+ * @param tensor Single-element boolean tensor (must not be null)
+ * @param ret_value Output parameter for the extracted boolean value
+ * @return AOTITorchError error code (Error::Ok on success)
+ */
+AOTI_SHIM_EXPORT AOTITorchError
+aoti_torch_item_bool(Tensor* tensor, bool* ret_value);
+
 /**
  * Moves a tensor into a new handle and assigns it to the output parameter.
  *
diff --git a/backends/cuda/runtime/shims/tests/targets.bzl b/backends/cuda/runtime/shims/tests/targets.bzl
index 852c30465af..a6b18eba4c8 100644
--- a/backends/cuda/runtime/shims/tests/targets.bzl
+++ b/backends/cuda/runtime/shims/tests/targets.bzl
@@ -77,4 +77,5 @@ def define_common_targets():
     cuda_shim_slim_cpp_unittest("aoti_torch_new_tensor_handle")
     cuda_shim_slim_cpp_unittest("aoti_torch__reinterpret_tensor")
     cuda_shim_slim_cpp_unittest("aoti_torch_copy_")
+    cuda_shim_slim_cpp_unittest("aoti_torch_item_bool")
     cuda_shim_slim_cpp_unittest("aoti_torch_assign_tensors_out")
diff --git a/backends/cuda/runtime/shims/tests/test_aoti_torch_item_bool_slim.cpp b/backends/cuda/runtime/shims/tests/test_aoti_torch_item_bool_slim.cpp
new file mode 100644
index 00000000000..dee95cbafe2
--- /dev/null
+++ b/backends/cuda/runtime/shims/tests/test_aoti_torch_item_bool_slim.cpp
@@ -0,0 +1,291 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <cuda_runtime.h>
+#include <gtest/gtest.h>
+#include <vector>
+
+#include <executorch/backends/aoti/slim/c10/core/Device.h>
+#include <executorch/backends/aoti/slim/c10/core/ScalarType.h>
+#include <executorch/backends/cuda/runtime/shims/memory_slim.h>
+#include <executorch/runtime/core/error.h>
+#include <executorch/runtime/platform/platform.h>
+
+using namespace executorch::backends::cuda;
+using executorch::runtime::Error;
+
+namespace slim_c10 = executorch::backends::aoti::slim::c10;
+
+namespace {
+
+bool isCudaAvailable() {
+  int device_count = 0;
+  cudaError_t err = cudaGetDeviceCount(&device_count);
+  return (err == cudaSuccess && device_count > 0);
+}
+
+} // namespace
+
+class AOTITorchItemBoolSlimTest : public ::testing::Test {
+ protected:
+  void SetUp() override {
+    et_pal_init();
+  }
+
+  Tensor* createScalarBoolTensor(
+      bool value,
+      int32_t device_type = static_cast<int32_t>(slim_c10::DeviceType::CPU),
+      int32_t device_index = 0) {
+    Tensor* tensor = nullptr;
+
+    std::vector<int64_t> sizes = {1};
+    std::vector<int64_t> strides = {1};
+
+    AOTITorchError error = aoti_torch_empty_strided(
+        sizes.size(),
+        sizes.data(),
+        strides.data(),
+        static_cast<int32_t>(slim_c10::ScalarType::Bool),
+        device_type,
+        device_index,
+        &tensor);
+
+    if (error != Error::Ok || tensor == nullptr) {
+      return nullptr;
+    }
+
+    if (device_type == static_cast<int32_t>(slim_c10::DeviceType::CPU)) {
+      bool* data = static_cast<bool*>(tensor->data_ptr());
+      *data = value;
+    } else {
+      cudaMemcpy(
+          tensor->data_ptr(), &value, sizeof(bool), cudaMemcpyHostToDevice);
+    }
+
+    return tensor;
+  }
+
+  Tensor* createTestTensor(
+      const std::vector<int64_t>& sizes,
+      int32_t dtype = static_cast<int32_t>(slim_c10::ScalarType::Float),
+      int32_t device_type = static_cast<int32_t>(slim_c10::DeviceType::CPU),
+      int32_t device_index = 0) {
+    Tensor* tensor = nullptr;
+
+    std::vector<int64_t> strides(sizes.size());
+    if (!sizes.empty()) {
+      strides[sizes.size() - 1] = 1;
+      for (int64_t i = static_cast<int64_t>(sizes.size()) - 2; i >= 0; i--) {
+        strides[i] = strides[i + 1] * sizes[i + 1];
+      }
+    }
+
+    AOTITorchError error = aoti_torch_empty_strided(
+        sizes.size(),
+        sizes.data(),
+        strides.data(),
+        dtype,
+        device_type,
+        device_index,
+        &tensor);
+
+    return (error == Error::Ok) ? tensor : nullptr;
+  }
+};
+
+// ============================================================================
+// Basic Functionality Tests
+// ============================================================================
+
+TEST_F(AOTITorchItemBoolSlimTest, TrueValue_CPU) {
+  Tensor* tensor = createScalarBoolTensor(
+      true, static_cast<int32_t>(slim_c10::DeviceType::CPU), 0);
+  ASSERT_NE(tensor, nullptr);
+
+  bool result = false;
+  AOTITorchError error = aoti_torch_item_bool(tensor, &result);
+
+  EXPECT_EQ(error, Error::Ok);
+  EXPECT_EQ(result, true);
+
+  EXPECT_EQ(aoti_torch_delete_tensor_object(tensor), Error::Ok);
+}
+
+TEST_F(AOTITorchItemBoolSlimTest, FalseValue_CPU) {
+  Tensor* tensor = createScalarBoolTensor(
+      false, static_cast<int32_t>(slim_c10::DeviceType::CPU), 0);
+  ASSERT_NE(tensor, nullptr);
+
+  bool result = true;
+  AOTITorchError error = aoti_torch_item_bool(tensor, &result);
+
+  EXPECT_EQ(error, Error::Ok);
+  EXPECT_EQ(result, false);
+
+  EXPECT_EQ(aoti_torch_delete_tensor_object(tensor), Error::Ok);
+}
+
+// ============================================================================
+// Error Handling Tests
+// ============================================================================
+
+TEST_F(AOTITorchItemBoolSlimTest, NullTensor) {
+  bool result = false;
+  AOTITorchError error = aoti_torch_item_bool(nullptr, &result);
+
+  EXPECT_EQ(error, Error::InvalidArgument);
+}
+
+TEST_F(AOTITorchItemBoolSlimTest, NullReturnValue) {
+  Tensor* tensor = createScalarBoolTensor(
+      true, static_cast<int32_t>(slim_c10::DeviceType::CPU), 0);
+  ASSERT_NE(tensor, nullptr);
+
+  AOTITorchError error = aoti_torch_item_bool(tensor, nullptr);
+
+  EXPECT_EQ(error, Error::InvalidArgument);
+
+  EXPECT_EQ(aoti_torch_delete_tensor_object(tensor), Error::Ok);
+}
+
+TEST_F(AOTITorchItemBoolSlimTest, MultiElementTensor) {
+  std::vector<int64_t> sizes = {2, 3};
+  Tensor* tensor = createTestTensor(
+      sizes,
+      static_cast<int32_t>(slim_c10::ScalarType::Bool),
+      static_cast<int32_t>(slim_c10::DeviceType::CPU),
+      0);
+  ASSERT_NE(tensor, nullptr);
+  EXPECT_GT(tensor->numel(), 1);
+
+  bool result = false;
+  AOTITorchError error = aoti_torch_item_bool(tensor, &result);
+
+  EXPECT_EQ(error, Error::InvalidArgument);
+
+  EXPECT_EQ(aoti_torch_delete_tensor_object(tensor), Error::Ok);
+}
+
+TEST_F(AOTITorchItemBoolSlimTest, WrongDtype_Float) {
+  std::vector<int64_t> sizes = {1};
+  Tensor* tensor = createTestTensor(
+      sizes,
+      static_cast<int32_t>(slim_c10::ScalarType::Float),
+      static_cast<int32_t>(slim_c10::DeviceType::CPU),
+      0);
+  ASSERT_NE(tensor, nullptr);
+
+  bool result = false;
+  AOTITorchError error = aoti_torch_item_bool(tensor, &result);
+
+  EXPECT_EQ(error, Error::InvalidArgument);
+
+  EXPECT_EQ(aoti_torch_delete_tensor_object(tensor), Error::Ok);
+}
+
+TEST_F(AOTITorchItemBoolSlimTest, WrongDtype_Long) {
+  std::vector<int64_t> sizes = {1};
+  Tensor* tensor = createTestTensor(
+      sizes,
+      static_cast<int32_t>(slim_c10::ScalarType::Long),
+      static_cast<int32_t>(slim_c10::DeviceType::CPU),
+      0);
+  ASSERT_NE(tensor, nullptr);
+
+  bool result = false;
+  AOTITorchError error = aoti_torch_item_bool(tensor, &result);
+
+  EXPECT_EQ(error, Error::InvalidArgument);
+
+  EXPECT_EQ(aoti_torch_delete_tensor_object(tensor), Error::Ok);
+}
+
+// ============================================================================
+// CUDA Tests
+// ============================================================================
+
+TEST_F(AOTITorchItemBoolSlimTest, TrueValue_CUDA) {
+  if (!isCudaAvailable()) {
+    GTEST_SKIP() << "CUDA not available";
+  }
+
+  Tensor* tensor = createScalarBoolTensor(
+      true, static_cast<int32_t>(slim_c10::DeviceType::CUDA), 0);
+  ASSERT_NE(tensor, nullptr);
+  EXPECT_TRUE(tensor->is_cuda());
+
+  bool result = false;
+  AOTITorchError error = aoti_torch_item_bool(tensor, &result);
+
+  EXPECT_EQ(error, Error::Ok);
+  EXPECT_EQ(result, true);
+
+  EXPECT_EQ(aoti_torch_delete_tensor_object(tensor), Error::Ok);
+}
+
+TEST_F(AOTITorchItemBoolSlimTest, FalseValue_CUDA) {
+  if (!isCudaAvailable()) {
+    GTEST_SKIP() << "CUDA not available";
+  }
+
+  Tensor* tensor = createScalarBoolTensor(
+      false, static_cast<int32_t>(slim_c10::DeviceType::CUDA), 0);
+  ASSERT_NE(tensor, nullptr);
+  EXPECT_TRUE(tensor->is_cuda());
+
+  bool result = true;
+  AOTITorchError error = aoti_torch_item_bool(tensor, &result);
+
+  EXPECT_EQ(error, Error::Ok);
+  EXPECT_EQ(result, false);
+
+  EXPECT_EQ(aoti_torch_delete_tensor_object(tensor), Error::Ok);
+}
+
+TEST_F(AOTITorchItemBoolSlimTest, MultiElementTensor_CUDA) {
+  if (!isCudaAvailable()) {
+    GTEST_SKIP() << "CUDA not available";
+  }
+
+  std::vector<int64_t> sizes = {2, 3};
+  Tensor* tensor = createTestTensor(
+      sizes,
+      static_cast<int32_t>(slim_c10::ScalarType::Bool),
+      static_cast<int32_t>(slim_c10::DeviceType::CUDA),
+      0);
+  ASSERT_NE(tensor, nullptr);
+  EXPECT_TRUE(tensor->is_cuda());
+
+  bool result = false;
+  AOTITorchError error = aoti_torch_item_bool(tensor, &result);
+
+  EXPECT_EQ(error, Error::InvalidArgument);
+
+  EXPECT_EQ(aoti_torch_delete_tensor_object(tensor), Error::Ok);
+}
+
+TEST_F(AOTITorchItemBoolSlimTest, WrongDtype_Float_CUDA) {
+  if (!isCudaAvailable()) {
+    GTEST_SKIP() << "CUDA not available";
+  }
+
+  std::vector<int64_t> sizes = {1};
+  Tensor* tensor = createTestTensor(
+      sizes,
+      static_cast<int32_t>(slim_c10::ScalarType::Float),
+      static_cast<int32_t>(slim_c10::DeviceType::CUDA),
+      0);
+  ASSERT_NE(tensor, nullptr);
+
+  bool result = false;
+  AOTITorchError error = aoti_torch_item_bool(tensor, &result);
+
+  EXPECT_EQ(error, Error::InvalidArgument);
+
+  EXPECT_EQ(aoti_torch_delete_tensor_object(tensor), Error::Ok);
+}
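
Usage sketch (illustrative, not part of the diff): the snippet below shows how host code might drive the new shim end to end on the CPU path, mirroring what the tests above exercise. The include path, the `slim_c10` alias, and the `using` declarations are assumptions borrowed from the test file; only shim entry points that appear in this diff are called.

#include <executorch/backends/cuda/runtime/shims/memory_slim.h>  // assumed path

using namespace executorch::backends::cuda;
using executorch::runtime::Error;
namespace slim_c10 = executorch::backends::aoti::slim::c10;

// Create a 1-element Bool CPU tensor, set it, and read it back through
// aoti_torch_item_bool. Returns false on any shim error.
bool read_flag(bool value) {
  Tensor* tensor = nullptr;
  int64_t sizes[] = {1};
  int64_t strides[] = {1};
  if (aoti_torch_empty_strided(
          /*ndim=*/1,
          sizes,
          strides,
          static_cast<int32_t>(slim_c10::ScalarType::Bool),
          static_cast<int32_t>(slim_c10::DeviceType::CPU),
          /*device_index=*/0,
          &tensor) != Error::Ok) {
    return false;
  }
  *static_cast<bool*>(tensor->data_ptr()) = value;

  bool out = false;
  // The shim validates numel() == 1 and Bool dtype before reading; for a
  // CUDA tensor it would synchronize and copy the scalar back to host.
  AOTITorchError err = aoti_torch_item_bool(tensor, &out);

  aoti_torch_delete_tensor_object(tensor);
  return err == Error::Ok && out;
}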