Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 72 additions & 15 deletions backends/aoti/slim/c10/core/ScalarType.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,35 +12,64 @@
#include <cstdint>
#include <ostream>

#include <executorch/runtime/core/portable_type/bfloat16.h>
#include <executorch/runtime/platform/assert.h>

namespace executorch::backends::aoti::slim::c10 {

// Import BFloat16 from ExecuTorch's portable_type
using BFloat16 = ::executorch::runtime::etensor::BFloat16;

/// Enum representing the scalar type (dtype) of tensor elements.
/// Note: Enum values must match PyTorch's c10::ScalarType for compatibility.
enum class ScalarType : int8_t {
// Byte = 0,
// Char = 1,
// Short = 2,
// Int = 3,
// Long = 4,
Float = 6,
// Bool = 11,
// BFloat16 = 15,
// Byte = 0, // uint8_t - not currently needed
Char = 1, // int8_t
Short = 2, // int16_t
Int = 3, // int32_t
Long = 4, // int64_t
// Half = 5, // float16 - not currently needed
Float = 6, // float
// Double = 7, // double - not currently needed
// ComplexHalf = 8,
// ComplexFloat = 9,
// ComplexDouble = 10,
Bool = 11, // bool
// QInt8 = 12,
// QUInt8 = 13,
// QInt32 = 14,
BFloat16 = 15, // bfloat16
Undefined = -1,
NumOptions = 7,
};

/// Constant for Float scalar type.
// Type alias constants for convenience
constexpr ScalarType kChar = ScalarType::Char;
constexpr ScalarType kShort = ScalarType::Short;
constexpr ScalarType kInt = ScalarType::Int;
constexpr ScalarType kLong = ScalarType::Long;
constexpr ScalarType kFloat = ScalarType::Float;
constexpr ScalarType kBool = ScalarType::Bool;
constexpr ScalarType kBFloat16 = ScalarType::BFloat16;

/// Returns the size in bytes of a single element of the given scalar type.
/// @param t The scalar type.
/// @return The size in bytes of a single element.
inline size_t elementSize(ScalarType t) {
switch (t) {
case ScalarType::Char:
return sizeof(int8_t);
case ScalarType::Short:
return sizeof(int16_t);
case ScalarType::Int:
return sizeof(int32_t);
case ScalarType::Long:
return sizeof(int64_t);
case ScalarType::Float:
return sizeof(float);
case ScalarType::Bool:
return sizeof(bool);
case ScalarType::BFloat16:
return sizeof(BFloat16);
default:
ET_CHECK_MSG(false, "Unknown ScalarType: %d", static_cast<int>(t));
}
Expand All @@ -51,8 +80,20 @@ inline size_t elementSize(ScalarType t) {
/// @return The name of the scalar type.
inline const char* toString(ScalarType t) {
switch (t) {
case ScalarType::Char:
return "Char";
case ScalarType::Short:
return "Short";
case ScalarType::Int:
return "Int";
case ScalarType::Long:
return "Long";
case ScalarType::Float:
return "Float";
case ScalarType::Bool:
return "Bool";
case ScalarType::BFloat16:
return "BFloat16";
case ScalarType::Undefined:
return "Undefined";
default:
Expand All @@ -64,16 +105,32 @@ inline const char* toString(ScalarType t) {
/// @param t The scalar type to check.
/// @return true if the scalar type is floating point, false otherwise.
inline bool isFloatingType(ScalarType t) {
return t == ScalarType::Float;
return t == ScalarType::Float || t == ScalarType::BFloat16;
}

/// Checks if the scalar type is an integral type (including bool).
/// Checks if the scalar type is an integral type (including bool optionally).
/// @param t The scalar type to check.
/// @param includeBool Whether to consider Bool as integral.
/// @return true if the scalar type is integral, false otherwise.
inline bool isIntegralType(ScalarType t, bool /*includeBool*/) {
(void)t;
return false;
inline bool isIntegralType(ScalarType t, bool includeBool) {
switch (t) {
case ScalarType::Char:
case ScalarType::Short:
case ScalarType::Int:
case ScalarType::Long:
return true;
case ScalarType::Bool:
return includeBool;
default:
return false;
}
}

/// Checks if the scalar type is a boolean type.
/// @param t The scalar type to check.
/// @return true if the scalar type is Bool, false otherwise.
inline bool isBoolType(ScalarType t) {
return t == ScalarType::Bool;
}

inline std::ostream& operator<<(std::ostream& stream, ScalarType scalar_type) {
Expand Down
1 change: 1 addition & 0 deletions backends/aoti/slim/c10/core/targets.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def define_common_targets():
],
visibility = ["@EXECUTORCH_CLIENTS"],
exported_deps = [
"//executorch/runtime/core/portable_type:portable_type",
"//executorch/runtime/platform:platform",
],
)
Expand Down
193 changes: 165 additions & 28 deletions backends/aoti/slim/c10/core/test/test_scalar_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,49 +13,186 @@

using namespace executorch::backends::aoti::slim::c10;

class ScalarTypeTest : public ::testing::Test {};
// =============================================================================
// Test Data Structures for Parameterized Tests
// =============================================================================

TEST_F(ScalarTypeTest, FloatEnumValue) {
// Verify Float has the correct enum value (6) to match PyTorch
EXPECT_EQ(static_cast<int>(ScalarType::Float), 6);
struct ScalarTypeTestData {
ScalarType dtype;
int expected_enum_value;
size_t expected_element_size;
const char* expected_name;
bool is_floating;
bool is_integral;
bool is_integral_with_bool;
bool is_bool;
};

// All supported scalar types with their expected properties
const std::vector<ScalarTypeTestData> kAllScalarTypes = {
// dtype, enum_value, element_size, name, is_float, is_int, is_int_w_bool,
// is_bool
{ScalarType::Char, 1, 1, "Char", false, true, true, false},
{ScalarType::Short, 2, 2, "Short", false, true, true, false},
{ScalarType::Int, 3, 4, "Int", false, true, true, false},
{ScalarType::Long, 4, 8, "Long", false, true, true, false},
{ScalarType::Float, 6, 4, "Float", true, false, false, false},
{ScalarType::Bool, 11, 1, "Bool", false, false, true, true},
{ScalarType::BFloat16, 15, 2, "BFloat16", true, false, false, false},
};

// =============================================================================
// Parameterized Test Fixture
// =============================================================================

class ScalarTypeParamTest
: public ::testing::TestWithParam<ScalarTypeTestData> {};

TEST_P(ScalarTypeParamTest, EnumValue) {
const auto& data = GetParam();
EXPECT_EQ(static_cast<int>(data.dtype), data.expected_enum_value)
<< "Failed for dtype: " << toString(data.dtype);
}

TEST_P(ScalarTypeParamTest, ElementSize) {
const auto& data = GetParam();
EXPECT_EQ(elementSize(data.dtype), data.expected_element_size)
<< "Failed for dtype: " << toString(data.dtype);
}

TEST_P(ScalarTypeParamTest, ToString) {
const auto& data = GetParam();
EXPECT_STREQ(toString(data.dtype), data.expected_name)
<< "Failed for dtype: " << toString(data.dtype);
}

TEST_P(ScalarTypeParamTest, IsFloatingType) {
const auto& data = GetParam();
EXPECT_EQ(isFloatingType(data.dtype), data.is_floating)
<< "Failed for dtype: " << toString(data.dtype);
}

TEST_P(ScalarTypeParamTest, IsIntegralTypeWithoutBool) {
const auto& data = GetParam();
EXPECT_EQ(isIntegralType(data.dtype, false), data.is_integral)
<< "Failed for dtype: " << toString(data.dtype);
}

TEST_P(ScalarTypeParamTest, IsIntegralTypeWithBool) {
const auto& data = GetParam();
EXPECT_EQ(isIntegralType(data.dtype, true), data.is_integral_with_bool)
<< "Failed for dtype: " << toString(data.dtype);
}

TEST_P(ScalarTypeParamTest, IsBoolType) {
const auto& data = GetParam();
EXPECT_EQ(isBoolType(data.dtype), data.is_bool)
<< "Failed for dtype: " << toString(data.dtype);
}

TEST_P(ScalarTypeParamTest, StreamOperator) {
const auto& data = GetParam();
std::ostringstream oss;
oss << data.dtype;
EXPECT_EQ(oss.str(), data.expected_name)
<< "Failed for dtype: " << toString(data.dtype);
}

INSTANTIATE_TEST_SUITE_P(
AllTypes,
ScalarTypeParamTest,
::testing::ValuesIn(kAllScalarTypes),
[](const ::testing::TestParamInfo<ScalarTypeTestData>& info) {
return std::string(info.param.expected_name);
});

// =============================================================================
// Type Constant Tests
// =============================================================================

class ScalarTypeConstantsTest : public ::testing::Test {};

TEST_F(ScalarTypeConstantsTest, KCharConstant) {
EXPECT_EQ(kChar, ScalarType::Char);
}

TEST_F(ScalarTypeConstantsTest, KShortConstant) {
EXPECT_EQ(kShort, ScalarType::Short);
}

TEST_F(ScalarTypeTest, KFloatConstant) {
// Verify kFloat constant
TEST_F(ScalarTypeConstantsTest, KIntConstant) {
EXPECT_EQ(kInt, ScalarType::Int);
}

TEST_F(ScalarTypeConstantsTest, KLongConstant) {
EXPECT_EQ(kLong, ScalarType::Long);
}

TEST_F(ScalarTypeConstantsTest, KFloatConstant) {
EXPECT_EQ(kFloat, ScalarType::Float);
}

TEST_F(ScalarTypeTest, ElementSizeFloat) {
// Verify elementSize returns correct size for Float (4 bytes)
EXPECT_EQ(elementSize(ScalarType::Float), sizeof(float));
EXPECT_EQ(elementSize(ScalarType::Float), 4);
TEST_F(ScalarTypeConstantsTest, KBoolConstant) {
EXPECT_EQ(kBool, ScalarType::Bool);
}

TEST_F(ScalarTypeTest, ToStringFloat) {
// Verify toString returns correct string for Float
EXPECT_STREQ(toString(ScalarType::Float), "Float");
TEST_F(ScalarTypeConstantsTest, KBFloat16Constant) {
EXPECT_EQ(kBFloat16, ScalarType::BFloat16);
}

TEST_F(ScalarTypeTest, ToStringUndefined) {
// Verify toString returns correct string for Undefined
// =============================================================================
// Edge Cases and Special Values
// =============================================================================

class ScalarTypeEdgeCasesTest : public ::testing::Test {};

TEST_F(ScalarTypeEdgeCasesTest, UndefinedToString) {
EXPECT_STREQ(toString(ScalarType::Undefined), "Undefined");
}

TEST_F(ScalarTypeTest, IsFloatingType) {
// Verify isFloatingType works correctly
EXPECT_TRUE(isFloatingType(ScalarType::Float));
TEST_F(ScalarTypeEdgeCasesTest, UndefinedIsNotFloating) {
EXPECT_FALSE(isFloatingType(ScalarType::Undefined));
}

TEST_F(ScalarTypeTest, IsIntegralType) {
// Verify isIntegralType works correctly
// Currently no integral types are supported, so Float should return false
EXPECT_FALSE(isIntegralType(ScalarType::Float, false));
EXPECT_FALSE(isIntegralType(ScalarType::Float, true));
TEST_F(ScalarTypeEdgeCasesTest, UndefinedIsNotIntegral) {
EXPECT_FALSE(isIntegralType(ScalarType::Undefined, false));
EXPECT_FALSE(isIntegralType(ScalarType::Undefined, true));
}

TEST_F(ScalarTypeTest, StreamOperator) {
// Verify stream operator works
std::ostringstream oss;
oss << ScalarType::Float;
EXPECT_EQ(oss.str(), "Float");
TEST_F(ScalarTypeEdgeCasesTest, UndefinedIsNotBool) {
EXPECT_FALSE(isBoolType(ScalarType::Undefined));
}

// =============================================================================
// Element Size Consistency Tests
// =============================================================================

class ElementSizeConsistencyTest : public ::testing::Test {};

TEST_F(ElementSizeConsistencyTest, CharMatchesSizeofInt8) {
EXPECT_EQ(elementSize(ScalarType::Char), sizeof(int8_t));
}

TEST_F(ElementSizeConsistencyTest, ShortMatchesSizeofInt16) {
EXPECT_EQ(elementSize(ScalarType::Short), sizeof(int16_t));
}

TEST_F(ElementSizeConsistencyTest, IntMatchesSizeofInt32) {
EXPECT_EQ(elementSize(ScalarType::Int), sizeof(int32_t));
}

TEST_F(ElementSizeConsistencyTest, LongMatchesSizeofInt64) {
EXPECT_EQ(elementSize(ScalarType::Long), sizeof(int64_t));
}

TEST_F(ElementSizeConsistencyTest, FloatMatchesSizeofFloat) {
EXPECT_EQ(elementSize(ScalarType::Float), sizeof(float));
}

TEST_F(ElementSizeConsistencyTest, BoolMatchesSizeofBool) {
EXPECT_EQ(elementSize(ScalarType::Bool), sizeof(bool));
}

TEST_F(ElementSizeConsistencyTest, BFloat16MatchesSizeofBFloat16) {
EXPECT_EQ(elementSize(ScalarType::BFloat16), sizeof(BFloat16));
}
10 changes: 10 additions & 0 deletions backends/aoti/slim/core/test/targets.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,13 @@ def define_common_targets():
"//executorch/backends/aoti/slim/core:storage",
],
)

runtime.cxx_test(
name = "test_slimtensor_dtypes",
srcs = [
"test_slimtensor_dtypes.cpp",
],
deps = [
"//executorch/backends/aoti/slim/factory:empty",
],
)
Loading
Loading