Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions backends/cuda/runtime/shims/memory_slim.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,35 @@ aoti_torch_copy_(Tensor* self, Tensor* src, int32_t non_blocking) {
return Error::Ok;
}

AOTITorchError aoti_torch_item_bool(Tensor* tensor, bool* ret_value) {
  // Extracts the boolean scalar held by a single-element Bool tensor.
  // Works for both CPU and CUDA tensors; CUDA tensors incur a
  // device-to-host copy inside item<bool>().
  ET_CHECK_OR_RETURN_ERROR(
      tensor != nullptr,
      InvalidArgument,
      "aoti_torch_item_bool: tensor is null");

  ET_CHECK_OR_RETURN_ERROR(
      ret_value != nullptr,
      InvalidArgument,
      "aoti_torch_item_bool: ret_value is null");

  // Cast the element count to size_t so the vararg always matches the
  // %zu specifier regardless of the integer type numel() returns
  // (passing a mismatched integer type through printf-style varargs is UB).
  ET_CHECK_OR_RETURN_ERROR(
      tensor->numel() == 1,
      InvalidArgument,
      "aoti_torch_item_bool: tensor must have exactly 1 element, got %zu",
      static_cast<size_t>(tensor->numel()));

  ET_CHECK_OR_RETURN_ERROR(
      tensor->dtype() == ScalarType::Bool,
      InvalidArgument,
      "aoti_torch_item_bool: tensor dtype must be Bool");

  // SlimTensor::item<T>() handles both CPU and CUDA tensors.
  // For CUDA tensors, it copies the value to CPU automatically.
  *ret_value = tensor->item<bool>();

  return Error::Ok;
}

AOTITorchError aoti_torch_assign_tensors_out(Tensor* src, Tensor** ret_dst) {
ET_CHECK_OR_RETURN_ERROR(
src != nullptr,
Expand Down
13 changes: 13 additions & 0 deletions backends/cuda/runtime/shims/memory_slim.h
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,19 @@ AOTI_SHIM_EXPORT AOTITorchError aoti_torch__reinterpret_tensor(
AOTI_SHIM_EXPORT AOTITorchError
aoti_torch_copy_(Tensor* self, Tensor* src, int32_t non_blocking);

/**
 * Extracts a boolean scalar value from a single-element tensor.
 *
 * The tensor must contain exactly one element and have Bool dtype.
 * For CUDA tensors, this will synchronize to copy the value to CPU.
 *
 * @param tensor Single-element boolean tensor (must not be null)
 * @param ret_value Output parameter for the extracted boolean value
 * @return AOTITorchError error code (Error::Ok on success).
 *     Returns Error::InvalidArgument when tensor or ret_value is null,
 *     when the tensor holds more than one element, or when its dtype
 *     is not Bool; *ret_value is left untouched in those cases.
 */
AOTI_SHIM_EXPORT AOTITorchError
aoti_torch_item_bool(Tensor* tensor, bool* ret_value);

/**
* Moves a tensor into a new handle and assigns it to the output parameter.
*
Expand Down
1 change: 1 addition & 0 deletions backends/cuda/runtime/shims/tests/targets.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -77,4 +77,5 @@ def define_common_targets():
cuda_shim_slim_cpp_unittest("aoti_torch_new_tensor_handle")
cuda_shim_slim_cpp_unittest("aoti_torch__reinterpret_tensor")
cuda_shim_slim_cpp_unittest("aoti_torch_copy_")
cuda_shim_slim_cpp_unittest("aoti_torch_item_bool")
cuda_shim_slim_cpp_unittest("aoti_torch_assign_tensors_out")
291 changes: 291 additions & 0 deletions backends/cuda/runtime/shims/tests/test_aoti_torch_item_bool_slim.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,291 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <cuda_runtime.h>
#include <gtest/gtest.h>
#include <vector>

#include <executorch/backends/aoti/slim/c10/core/Device.h>
#include <executorch/backends/aoti/slim/c10/core/ScalarType.h>
#include <executorch/backends/cuda/runtime/shims/memory_slim.h>
#include <executorch/runtime/core/error.h>
#include <executorch/runtime/platform/platform.h>

using namespace executorch::backends::cuda;
using executorch::runtime::Error;

namespace slim_c10 = executorch::backends::aoti::slim::c10;

namespace {

bool isCudaAvailable() {
int device_count = 0;
cudaError_t err = cudaGetDeviceCount(&device_count);
return (err == cudaSuccess && device_count > 0);
}

} // namespace

class AOTITorchItemBoolSlimTest : public ::testing::Test {
 protected:
  void SetUp() override {
    et_pal_init();
  }

  // Creates a 1-element Bool tensor holding `value` on the requested device.
  // Returns nullptr on any failure (allocation or host->device copy).
  Tensor* createScalarBoolTensor(
      bool value,
      int32_t device_type = static_cast<int32_t>(slim_c10::DeviceType::CPU),
      int32_t device_index = 0) {
    Tensor* tensor = nullptr;

    std::vector<int64_t> sizes = {1};
    std::vector<int64_t> strides = {1};

    AOTITorchError error = aoti_torch_empty_strided(
        sizes.size(),
        sizes.data(),
        strides.data(),
        static_cast<int32_t>(slim_c10::ScalarType::Bool),
        device_type,
        device_index,
        &tensor);

    if (error != Error::Ok || tensor == nullptr) {
      return nullptr;
    }

    if (device_type == static_cast<int32_t>(slim_c10::DeviceType::CPU)) {
      bool* data = static_cast<bool*>(tensor->data_ptr());
      *data = value;
    } else {
      // Check the copy status: the original ignored it, so a failed
      // cudaMemcpy would hand back a tensor with uninitialized device
      // memory. On failure, free the tensor so it doesn't leak.
      cudaError_t copy_status = cudaMemcpy(
          tensor->data_ptr(), &value, sizeof(bool), cudaMemcpyHostToDevice);
      if (copy_status != cudaSuccess) {
        aoti_torch_delete_tensor_object(tensor);
        return nullptr;
      }
    }

    return tensor;
  }

  // Creates an uninitialized tensor with the given shape/dtype/device,
  // using standard contiguous (row-major) strides.
  // Returns nullptr on allocation failure.
  Tensor* createTestTensor(
      const std::vector<int64_t>& sizes,
      int32_t dtype = static_cast<int32_t>(slim_c10::ScalarType::Float),
      int32_t device_type = static_cast<int32_t>(slim_c10::DeviceType::CPU),
      int32_t device_index = 0) {
    Tensor* tensor = nullptr;

    // Contiguous strides: innermost dimension has stride 1, each outer
    // dimension's stride is the product of all inner extents.
    std::vector<int64_t> strides(sizes.size());
    if (!sizes.empty()) {
      strides[sizes.size() - 1] = 1;
      for (int64_t i = static_cast<int64_t>(sizes.size()) - 2; i >= 0; i--) {
        strides[i] = strides[i + 1] * sizes[i + 1];
      }
    }

    AOTITorchError error = aoti_torch_empty_strided(
        sizes.size(),
        sizes.data(),
        strides.data(),
        dtype,
        device_type,
        device_index,
        &tensor);

    return (error == Error::Ok) ? tensor : nullptr;
  }
};

// ============================================================================
// Basic Functionality Tests
// ============================================================================

TEST_F(AOTITorchItemBoolSlimTest, TrueValue_CPU) {
  Tensor* scalar = createScalarBoolTensor(
      true, static_cast<int32_t>(slim_c10::DeviceType::CPU), 0);
  ASSERT_NE(scalar, nullptr);

  // Seed the output with the opposite value so a missing write is caught.
  bool extracted = false;
  EXPECT_EQ(aoti_torch_item_bool(scalar, &extracted), Error::Ok);
  EXPECT_TRUE(extracted);

  EXPECT_EQ(aoti_torch_delete_tensor_object(scalar), Error::Ok);
}

TEST_F(AOTITorchItemBoolSlimTest, FalseValue_CPU) {
  Tensor* scalar = createScalarBoolTensor(
      false, static_cast<int32_t>(slim_c10::DeviceType::CPU), 0);
  ASSERT_NE(scalar, nullptr);

  // Seed the output with the opposite value so a missing write is caught.
  bool extracted = true;
  EXPECT_EQ(aoti_torch_item_bool(scalar, &extracted), Error::Ok);
  EXPECT_FALSE(extracted);

  EXPECT_EQ(aoti_torch_delete_tensor_object(scalar), Error::Ok);
}

// ============================================================================
// Error Handling Tests
// ============================================================================

TEST_F(AOTITorchItemBoolSlimTest, NullTensor) {
  // A null tensor pointer must be rejected with InvalidArgument.
  bool out = false;
  EXPECT_EQ(aoti_torch_item_bool(nullptr, &out), Error::InvalidArgument);
}

TEST_F(AOTITorchItemBoolSlimTest, NullReturnValue) {
  Tensor* scalar = createScalarBoolTensor(
      true, static_cast<int32_t>(slim_c10::DeviceType::CPU), 0);
  ASSERT_NE(scalar, nullptr);

  // A null output pointer must be rejected with InvalidArgument.
  EXPECT_EQ(aoti_torch_item_bool(scalar, nullptr), Error::InvalidArgument);

  EXPECT_EQ(aoti_torch_delete_tensor_object(scalar), Error::Ok);
}

TEST_F(AOTITorchItemBoolSlimTest, MultiElementTensor) {
  // A 2x3 Bool tensor has 6 elements, so scalar extraction must fail.
  Tensor* tensor = createTestTensor(
      {2, 3},
      static_cast<int32_t>(slim_c10::ScalarType::Bool),
      static_cast<int32_t>(slim_c10::DeviceType::CPU),
      0);
  ASSERT_NE(tensor, nullptr);
  EXPECT_GT(tensor->numel(), 1);

  bool unused = false;
  EXPECT_EQ(aoti_torch_item_bool(tensor, &unused), Error::InvalidArgument);

  EXPECT_EQ(aoti_torch_delete_tensor_object(tensor), Error::Ok);
}

TEST_F(AOTITorchItemBoolSlimTest, WrongDtype_Float) {
  // Correct element count but Float dtype: extraction must be rejected.
  Tensor* tensor = createTestTensor(
      {1},
      static_cast<int32_t>(slim_c10::ScalarType::Float),
      static_cast<int32_t>(slim_c10::DeviceType::CPU),
      0);
  ASSERT_NE(tensor, nullptr);

  bool unused = false;
  EXPECT_EQ(aoti_torch_item_bool(tensor, &unused), Error::InvalidArgument);

  EXPECT_EQ(aoti_torch_delete_tensor_object(tensor), Error::Ok);
}

TEST_F(AOTITorchItemBoolSlimTest, WrongDtype_Long) {
  // Correct element count but Long dtype: extraction must be rejected.
  Tensor* tensor = createTestTensor(
      {1},
      static_cast<int32_t>(slim_c10::ScalarType::Long),
      static_cast<int32_t>(slim_c10::DeviceType::CPU),
      0);
  ASSERT_NE(tensor, nullptr);

  bool unused = false;
  EXPECT_EQ(aoti_torch_item_bool(tensor, &unused), Error::InvalidArgument);

  EXPECT_EQ(aoti_torch_delete_tensor_object(tensor), Error::Ok);
}

// ============================================================================
// CUDA Tests
// ============================================================================

TEST_F(AOTITorchItemBoolSlimTest, TrueValue_CUDA) {
  if (!isCudaAvailable()) {
    GTEST_SKIP() << "CUDA not available";
  }

  Tensor* scalar = createScalarBoolTensor(
      true, static_cast<int32_t>(slim_c10::DeviceType::CUDA), 0);
  ASSERT_NE(scalar, nullptr);
  EXPECT_TRUE(scalar->is_cuda());

  // The value must be fetched back from device memory.
  bool extracted = false;
  EXPECT_EQ(aoti_torch_item_bool(scalar, &extracted), Error::Ok);
  EXPECT_TRUE(extracted);

  EXPECT_EQ(aoti_torch_delete_tensor_object(scalar), Error::Ok);
}

TEST_F(AOTITorchItemBoolSlimTest, FalseValue_CUDA) {
  if (!isCudaAvailable()) {
    GTEST_SKIP() << "CUDA not available";
  }

  Tensor* scalar = createScalarBoolTensor(
      false, static_cast<int32_t>(slim_c10::DeviceType::CUDA), 0);
  ASSERT_NE(scalar, nullptr);
  EXPECT_TRUE(scalar->is_cuda());

  // Seed the output with the opposite value so a missing write is caught.
  bool extracted = true;
  EXPECT_EQ(aoti_torch_item_bool(scalar, &extracted), Error::Ok);
  EXPECT_FALSE(extracted);

  EXPECT_EQ(aoti_torch_delete_tensor_object(scalar), Error::Ok);
}

TEST_F(AOTITorchItemBoolSlimTest, MultiElementTensor_CUDA) {
  if (!isCudaAvailable()) {
    GTEST_SKIP() << "CUDA not available";
  }

  // A 2x3 Bool tensor on the device still has 6 elements: must fail.
  Tensor* tensor = createTestTensor(
      {2, 3},
      static_cast<int32_t>(slim_c10::ScalarType::Bool),
      static_cast<int32_t>(slim_c10::DeviceType::CUDA),
      0);
  ASSERT_NE(tensor, nullptr);
  EXPECT_TRUE(tensor->is_cuda());

  bool unused = false;
  EXPECT_EQ(aoti_torch_item_bool(tensor, &unused), Error::InvalidArgument);

  EXPECT_EQ(aoti_torch_delete_tensor_object(tensor), Error::Ok);
}

TEST_F(AOTITorchItemBoolSlimTest, WrongDtype_Float_CUDA) {
  if (!isCudaAvailable()) {
    GTEST_SKIP() << "CUDA not available";
  }

  // Single element but Float dtype on the device: must be rejected.
  Tensor* tensor = createTestTensor(
      {1},
      static_cast<int32_t>(slim_c10::ScalarType::Float),
      static_cast<int32_t>(slim_c10::DeviceType::CUDA),
      0);
  ASSERT_NE(tensor, nullptr);

  bool unused = false;
  EXPECT_EQ(aoti_torch_item_bool(tensor, &unused), Error::InvalidArgument);

  EXPECT_EQ(aoti_torch_delete_tensor_object(tensor), Error::Ok);
}
Loading