From 0723a1ed4de7e278da5c2ef956b76f510a728b66 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Fri, 26 Dec 2025 12:41:43 -0800 Subject: [PATCH] [slimtensor] Add all required dtype support (Int8/16/32/64, Bool, BFloat16) This diff adds support for all required scalar types in SlimTensor to support ExecuTorch aoti-driven backend usage: Int8 (Char), Int16 (Short), Int32 (Int), Int64 (Long), Bool, and BFloat16. **Key changes:** 1. **`c10/core/ScalarType.h`** - Extended with all required types: - Added enum values matching PyTorch's c10::ScalarType for compatibility - Added type alias constants (kChar, kShort, kInt, kLong, kBool, kBFloat16) - Extended `elementSize()` to return correct sizes for all types - Extended `toString()` for all types - Fixed `isFloatingType()` to include BFloat16 - Fixed `isIntegralType()` to properly handle all integral types and Bool - Added `isBoolType()` helper function - Imported BFloat16 from ExecuTorch's portable_type Differential Revision: [D89821402](https://our.internmc.facebook.com/intern/diff/D89821402/) [ghstack-poisoned] --- backends/aoti/slim/c10/core/ScalarType.h | 87 +++++- backends/aoti/slim/c10/core/targets.bzl | 1 + .../slim/c10/core/test/test_scalar_type.cpp | 193 ++++++++++-- backends/aoti/slim/core/test/targets.bzl | 10 + .../slim/core/test/test_slimtensor_dtypes.cpp | 290 ++++++++++++++++++ 5 files changed, 538 insertions(+), 43 deletions(-) create mode 100644 backends/aoti/slim/core/test/test_slimtensor_dtypes.cpp diff --git a/backends/aoti/slim/c10/core/ScalarType.h b/backends/aoti/slim/c10/core/ScalarType.h index 1ca1a1429ed..28391f012d5 100644 --- a/backends/aoti/slim/c10/core/ScalarType.h +++ b/backends/aoti/slim/c10/core/ScalarType.h @@ -12,35 +12,64 @@ #include #include +#include #include namespace executorch::backends::aoti::slim::c10 { +// Import BFloat16 from ExecuTorch's portable_type +using BFloat16 = ::executorch::runtime::etensor::BFloat16; + /// Enum representing the scalar type (dtype) of tensor elements. /// Note: Enum values must match PyTorch's c10::ScalarType for compatibility. enum class ScalarType : int8_t { - // Byte = 0, - // Char = 1, - // Short = 2, - // Int = 3, - // Long = 4, - Float = 6, - // Bool = 11, - // BFloat16 = 15, + // Byte = 0, // uint8_t - not currently needed + Char = 1, // int8_t + Short = 2, // int16_t + Int = 3, // int32_t + Long = 4, // int64_t + // Half = 5, // float16 - not currently needed + Float = 6, // float + // Double = 7, // double - not currently needed + // ComplexHalf = 8, + // ComplexFloat = 9, + // ComplexDouble = 10, + Bool = 11, // bool + // QInt8 = 12, + // QUInt8 = 13, + // QInt32 = 14, + BFloat16 = 15, // bfloat16 Undefined = -1, - NumOptions = 7, }; -/// Constant for Float scalar type. +// Type alias constants for convenience +constexpr ScalarType kChar = ScalarType::Char; +constexpr ScalarType kShort = ScalarType::Short; +constexpr ScalarType kInt = ScalarType::Int; +constexpr ScalarType kLong = ScalarType::Long; constexpr ScalarType kFloat = ScalarType::Float; +constexpr ScalarType kBool = ScalarType::Bool; +constexpr ScalarType kBFloat16 = ScalarType::BFloat16; /// Returns the size in bytes of a single element of the given scalar type. /// @param t The scalar type. /// @return The size in bytes of a single element. inline size_t elementSize(ScalarType t) { switch (t) { + case ScalarType::Char: + return sizeof(int8_t); + case ScalarType::Short: + return sizeof(int16_t); + case ScalarType::Int: + return sizeof(int32_t); + case ScalarType::Long: + return sizeof(int64_t); case ScalarType::Float: return sizeof(float); + case ScalarType::Bool: + return sizeof(bool); + case ScalarType::BFloat16: + return sizeof(BFloat16); default: ET_CHECK_MSG(false, "Unknown ScalarType: %d", static_cast(t)); } @@ -51,8 +80,20 @@ inline size_t elementSize(ScalarType t) { /// @return The name of the scalar type. inline const char* toString(ScalarType t) { switch (t) { + case ScalarType::Char: + return "Char"; + case ScalarType::Short: + return "Short"; + case ScalarType::Int: + return "Int"; + case ScalarType::Long: + return "Long"; case ScalarType::Float: return "Float"; + case ScalarType::Bool: + return "Bool"; + case ScalarType::BFloat16: + return "BFloat16"; case ScalarType::Undefined: return "Undefined"; default: @@ -64,16 +105,32 @@ inline const char* toString(ScalarType t) { /// @param t The scalar type to check. /// @return true if the scalar type is floating point, false otherwise. inline bool isFloatingType(ScalarType t) { - return t == ScalarType::Float; + return t == ScalarType::Float || t == ScalarType::BFloat16; } -/// Checks if the scalar type is an integral type (including bool). +/// Checks if the scalar type is an integral type (including bool optionally). /// @param t The scalar type to check. /// @param includeBool Whether to consider Bool as integral. /// @return true if the scalar type is integral, false otherwise. -inline bool isIntegralType(ScalarType t, bool /*includeBool*/) { - (void)t; - return false; +inline bool isIntegralType(ScalarType t, bool includeBool) { + switch (t) { + case ScalarType::Char: + case ScalarType::Short: + case ScalarType::Int: + case ScalarType::Long: + return true; + case ScalarType::Bool: + return includeBool; + default: + return false; + } +} + +/// Checks if the scalar type is a boolean type. +/// @param t The scalar type to check. +/// @return true if the scalar type is Bool, false otherwise. +inline bool isBoolType(ScalarType t) { + return t == ScalarType::Bool; } inline std::ostream& operator<<(std::ostream& stream, ScalarType scalar_type) { diff --git a/backends/aoti/slim/c10/core/targets.bzl b/backends/aoti/slim/c10/core/targets.bzl index 500620aecd1..5a9b9558938 100644 --- a/backends/aoti/slim/c10/core/targets.bzl +++ b/backends/aoti/slim/c10/core/targets.bzl @@ -36,6 +36,7 @@ def define_common_targets(): ], visibility = ["@EXECUTORCH_CLIENTS"], exported_deps = [ + "//executorch/runtime/core/portable_type:portable_type", "//executorch/runtime/platform:platform", ], ) diff --git a/backends/aoti/slim/c10/core/test/test_scalar_type.cpp b/backends/aoti/slim/c10/core/test/test_scalar_type.cpp index 673641d84c7..332f5d7d264 100644 --- a/backends/aoti/slim/c10/core/test/test_scalar_type.cpp +++ b/backends/aoti/slim/c10/core/test/test_scalar_type.cpp @@ -13,49 +13,186 @@ using namespace executorch::backends::aoti::slim::c10; -class ScalarTypeTest : public ::testing::Test {}; +// ============================================================================= +// Test Data Structures for Parameterized Tests +// ============================================================================= -TEST_F(ScalarTypeTest, FloatEnumValue) { - // Verify Float has the correct enum value (6) to match PyTorch - EXPECT_EQ(static_cast(ScalarType::Float), 6); +struct ScalarTypeTestData { + ScalarType dtype; + int expected_enum_value; + size_t expected_element_size; + const char* expected_name; + bool is_floating; + bool is_integral; + bool is_integral_with_bool; + bool is_bool; +}; + +// All supported scalar types with their expected properties +const std::vector kAllScalarTypes = { + // dtype, enum_value, element_size, name, is_float, is_int, is_int_w_bool, + // is_bool + {ScalarType::Char, 1, 1, "Char", false, true, true, false}, + {ScalarType::Short, 2, 2, "Short", false, true, true, false}, + {ScalarType::Int, 3, 4, "Int", false, true, true, false}, + {ScalarType::Long, 4, 8, "Long", false, true, true, false}, + {ScalarType::Float, 6, 4, "Float", true, false, false, false}, + {ScalarType::Bool, 11, 1, "Bool", false, false, true, true}, + {ScalarType::BFloat16, 15, 2, "BFloat16", true, false, false, false}, +}; + +// ============================================================================= +// Parameterized Test Fixture +// ============================================================================= + +class ScalarTypeParamTest + : public ::testing::TestWithParam {}; + +TEST_P(ScalarTypeParamTest, EnumValue) { + const auto& data = GetParam(); + EXPECT_EQ(static_cast(data.dtype), data.expected_enum_value) + << "Failed for dtype: " << toString(data.dtype); +} + +TEST_P(ScalarTypeParamTest, ElementSize) { + const auto& data = GetParam(); + EXPECT_EQ(elementSize(data.dtype), data.expected_element_size) + << "Failed for dtype: " << toString(data.dtype); +} + +TEST_P(ScalarTypeParamTest, ToString) { + const auto& data = GetParam(); + EXPECT_STREQ(toString(data.dtype), data.expected_name) + << "Failed for dtype: " << toString(data.dtype); +} + +TEST_P(ScalarTypeParamTest, IsFloatingType) { + const auto& data = GetParam(); + EXPECT_EQ(isFloatingType(data.dtype), data.is_floating) + << "Failed for dtype: " << toString(data.dtype); +} + +TEST_P(ScalarTypeParamTest, IsIntegralTypeWithoutBool) { + const auto& data = GetParam(); + EXPECT_EQ(isIntegralType(data.dtype, false), data.is_integral) + << "Failed for dtype: " << toString(data.dtype); +} + +TEST_P(ScalarTypeParamTest, IsIntegralTypeWithBool) { + const auto& data = GetParam(); + EXPECT_EQ(isIntegralType(data.dtype, true), data.is_integral_with_bool) + << "Failed for dtype: " << toString(data.dtype); +} + +TEST_P(ScalarTypeParamTest, IsBoolType) { + const auto& data = GetParam(); + EXPECT_EQ(isBoolType(data.dtype), data.is_bool) + << "Failed for dtype: " << toString(data.dtype); +} + +TEST_P(ScalarTypeParamTest, StreamOperator) { + const auto& data = GetParam(); + std::ostringstream oss; + oss << data.dtype; + EXPECT_EQ(oss.str(), data.expected_name) + << "Failed for dtype: " << toString(data.dtype); +} + +INSTANTIATE_TEST_SUITE_P( + AllTypes, + ScalarTypeParamTest, + ::testing::ValuesIn(kAllScalarTypes), + [](const ::testing::TestParamInfo& info) { + return std::string(info.param.expected_name); + }); + +// ============================================================================= +// Type Constant Tests +// ============================================================================= + +class ScalarTypeConstantsTest : public ::testing::Test {}; + +TEST_F(ScalarTypeConstantsTest, KCharConstant) { + EXPECT_EQ(kChar, ScalarType::Char); +} + +TEST_F(ScalarTypeConstantsTest, KShortConstant) { + EXPECT_EQ(kShort, ScalarType::Short); } -TEST_F(ScalarTypeTest, KFloatConstant) { - // Verify kFloat constant +TEST_F(ScalarTypeConstantsTest, KIntConstant) { + EXPECT_EQ(kInt, ScalarType::Int); +} + +TEST_F(ScalarTypeConstantsTest, KLongConstant) { + EXPECT_EQ(kLong, ScalarType::Long); +} + +TEST_F(ScalarTypeConstantsTest, KFloatConstant) { EXPECT_EQ(kFloat, ScalarType::Float); } -TEST_F(ScalarTypeTest, ElementSizeFloat) { - // Verify elementSize returns correct size for Float (4 bytes) - EXPECT_EQ(elementSize(ScalarType::Float), sizeof(float)); - EXPECT_EQ(elementSize(ScalarType::Float), 4); +TEST_F(ScalarTypeConstantsTest, KBoolConstant) { + EXPECT_EQ(kBool, ScalarType::Bool); } -TEST_F(ScalarTypeTest, ToStringFloat) { - // Verify toString returns correct string for Float - EXPECT_STREQ(toString(ScalarType::Float), "Float"); +TEST_F(ScalarTypeConstantsTest, KBFloat16Constant) { + EXPECT_EQ(kBFloat16, ScalarType::BFloat16); } -TEST_F(ScalarTypeTest, ToStringUndefined) { - // Verify toString returns correct string for Undefined +// ============================================================================= +// Edge Cases and Special Values +// ============================================================================= + +class ScalarTypeEdgeCasesTest : public ::testing::Test {}; + +TEST_F(ScalarTypeEdgeCasesTest, UndefinedToString) { EXPECT_STREQ(toString(ScalarType::Undefined), "Undefined"); } -TEST_F(ScalarTypeTest, IsFloatingType) { - // Verify isFloatingType works correctly - EXPECT_TRUE(isFloatingType(ScalarType::Float)); +TEST_F(ScalarTypeEdgeCasesTest, UndefinedIsNotFloating) { + EXPECT_FALSE(isFloatingType(ScalarType::Undefined)); } -TEST_F(ScalarTypeTest, IsIntegralType) { - // Verify isIntegralType works correctly - // Currently no integral types are supported, so Float should return false - EXPECT_FALSE(isIntegralType(ScalarType::Float, false)); - EXPECT_FALSE(isIntegralType(ScalarType::Float, true)); +TEST_F(ScalarTypeEdgeCasesTest, UndefinedIsNotIntegral) { + EXPECT_FALSE(isIntegralType(ScalarType::Undefined, false)); + EXPECT_FALSE(isIntegralType(ScalarType::Undefined, true)); } -TEST_F(ScalarTypeTest, StreamOperator) { - // Verify stream operator works - std::ostringstream oss; - oss << ScalarType::Float; - EXPECT_EQ(oss.str(), "Float"); +TEST_F(ScalarTypeEdgeCasesTest, UndefinedIsNotBool) { + EXPECT_FALSE(isBoolType(ScalarType::Undefined)); +} + +// ============================================================================= +// Element Size Consistency Tests +// ============================================================================= + +class ElementSizeConsistencyTest : public ::testing::Test {}; + +TEST_F(ElementSizeConsistencyTest, CharMatchesSizeofInt8) { + EXPECT_EQ(elementSize(ScalarType::Char), sizeof(int8_t)); +} + +TEST_F(ElementSizeConsistencyTest, ShortMatchesSizeofInt16) { + EXPECT_EQ(elementSize(ScalarType::Short), sizeof(int16_t)); +} + +TEST_F(ElementSizeConsistencyTest, IntMatchesSizeofInt32) { + EXPECT_EQ(elementSize(ScalarType::Int), sizeof(int32_t)); +} + +TEST_F(ElementSizeConsistencyTest, LongMatchesSizeofInt64) { + EXPECT_EQ(elementSize(ScalarType::Long), sizeof(int64_t)); +} + +TEST_F(ElementSizeConsistencyTest, FloatMatchesSizeofFloat) { + EXPECT_EQ(elementSize(ScalarType::Float), sizeof(float)); +} + +TEST_F(ElementSizeConsistencyTest, BoolMatchesSizeofBool) { + EXPECT_EQ(elementSize(ScalarType::Bool), sizeof(bool)); +} + +TEST_F(ElementSizeConsistencyTest, BFloat16MatchesSizeofBFloat16) { + EXPECT_EQ(elementSize(ScalarType::BFloat16), sizeof(BFloat16)); } diff --git a/backends/aoti/slim/core/test/targets.bzl b/backends/aoti/slim/core/test/targets.bzl index 4d7ec4b0fbf..1bfc816fa3a 100644 --- a/backends/aoti/slim/core/test/targets.bzl +++ b/backends/aoti/slim/core/test/targets.bzl @@ -34,3 +34,13 @@ def define_common_targets(): "//executorch/backends/aoti/slim/core:storage", ], ) + + runtime.cxx_test( + name = "test_slimtensor_dtypes", + srcs = [ + "test_slimtensor_dtypes.cpp", + ], + deps = [ + "//executorch/backends/aoti/slim/factory:empty", + ], + ) diff --git a/backends/aoti/slim/core/test/test_slimtensor_dtypes.cpp b/backends/aoti/slim/core/test/test_slimtensor_dtypes.cpp new file mode 100644 index 00000000000..8ecb8d977b7 --- /dev/null +++ b/backends/aoti/slim/core/test/test_slimtensor_dtypes.cpp @@ -0,0 +1,290 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include + +#include + +namespace executorch::backends::aoti::slim { + +// ============================================================================= +// Test Data Structures for Parameterized Tests +// ============================================================================= + +template +struct DTypeTraits; + +template <> +struct DTypeTraits { + static constexpr c10::ScalarType dtype = c10::ScalarType::Char; + static constexpr const char* name = "Char"; + static int8_t test_value(size_t i) { + return static_cast(i % 127); + } +}; + +template <> +struct DTypeTraits { + static constexpr c10::ScalarType dtype = c10::ScalarType::Short; + static constexpr const char* name = "Short"; + static int16_t test_value(size_t i) { + return static_cast(i * 10); + } +}; + +template <> +struct DTypeTraits { + static constexpr c10::ScalarType dtype = c10::ScalarType::Int; + static constexpr const char* name = "Int"; + static int32_t test_value(size_t i) { + return static_cast(i * 100); + } +}; + +template <> +struct DTypeTraits { + static constexpr c10::ScalarType dtype = c10::ScalarType::Long; + static constexpr const char* name = "Long"; + static int64_t test_value(size_t i) { + return static_cast(i * 1000); + } +}; + +template <> +struct DTypeTraits { + static constexpr c10::ScalarType dtype = c10::ScalarType::Float; + static constexpr const char* name = "Float"; + static float test_value(size_t i) { + return static_cast(i) * 1.5f; + } +}; + +template <> +struct DTypeTraits { + static constexpr c10::ScalarType dtype = c10::ScalarType::Bool; + static constexpr const char* name = "Bool"; + static bool test_value(size_t i) { + return (i % 2) == 0; + } +}; + +template <> +struct DTypeTraits { + static constexpr c10::ScalarType dtype = c10::ScalarType::BFloat16; + static constexpr const char* name = "BFloat16"; + static c10::BFloat16 test_value(size_t i) { + return c10::BFloat16(static_cast(i) * 0.5f); + } +}; + +// ============================================================================= +// Typed Test Fixture +// ============================================================================= + +template +class SlimTensorDTypeTest : public ::testing::Test { + protected: + static constexpr c10::ScalarType kDType = DTypeTraits::dtype; + static constexpr size_t kNumel = 24; + static constexpr std::array kSizes = {2, 3, 4}; + + SlimTensor create_tensor() { + return empty({2, 3, 4}, kDType); + } + + void fill_tensor(SlimTensor& tensor) { + T* data = static_cast(tensor.data_ptr()); + for (size_t i = 0; i < tensor.numel(); ++i) { + data[i] = DTypeTraits::test_value(i); + } + } + + void verify_tensor_values(const SlimTensor& tensor) { + const T* data = static_cast(tensor.data_ptr()); + for (size_t i = 0; i < tensor.numel(); ++i) { + T expected = DTypeTraits::test_value(i); + if constexpr (std::is_same_v) { + EXPECT_FLOAT_EQ(data[i], expected) << "Mismatch at index " << i; + } else if constexpr (std::is_same_v) { + EXPECT_FLOAT_EQ( + static_cast(data[i]), static_cast(expected)) + << "Mismatch at index " << i; + } else { + EXPECT_EQ(data[i], expected) << "Mismatch at index " << i; + } + } + } +}; + +// Define the types to test +using DTypeTestTypes = ::testing:: + Types; + +TYPED_TEST_SUITE(SlimTensorDTypeTest, DTypeTestTypes); + +// ============================================================================= +// Core Tensor Creation Tests +// ============================================================================= + +TYPED_TEST(SlimTensorDTypeTest, CreateEmptyTensor) { + SlimTensor tensor = this->create_tensor(); + + EXPECT_TRUE(tensor.defined()); + EXPECT_EQ(tensor.dtype(), this->kDType); + EXPECT_EQ(tensor.dim(), 3u); + EXPECT_EQ(tensor.numel(), this->kNumel); + EXPECT_TRUE(tensor.is_cpu()); + EXPECT_TRUE(tensor.is_contiguous()); +} + +TYPED_TEST(SlimTensorDTypeTest, CorrectElementSize) { + SlimTensor tensor = this->create_tensor(); + EXPECT_EQ(tensor.itemsize(), sizeof(TypeParam)); +} + +TYPED_TEST(SlimTensorDTypeTest, CorrectNbytes) { + SlimTensor tensor = this->create_tensor(); + EXPECT_EQ(tensor.nbytes(), this->kNumel * sizeof(TypeParam)); +} + +TYPED_TEST(SlimTensorDTypeTest, DataPtrIsValid) { + SlimTensor tensor = this->create_tensor(); + EXPECT_NE(tensor.data_ptr(), nullptr); +} + +// ============================================================================= +// Data Read/Write Tests +// ============================================================================= + +TYPED_TEST(SlimTensorDTypeTest, WriteAndReadData) { + SlimTensor tensor = this->create_tensor(); + this->fill_tensor(tensor); + this->verify_tensor_values(tensor); +} + +TYPED_TEST(SlimTensorDTypeTest, ZeroInitialize) { + SlimTensor tensor = this->create_tensor(); + std::memset(tensor.data_ptr(), 0, tensor.nbytes()); + + const TypeParam* data = static_cast(tensor.data_ptr()); + for (size_t i = 0; i < tensor.numel(); ++i) { + if constexpr (std::is_same_v) { + EXPECT_FALSE(data[i]) << "Non-zero at index " << i; + } else if constexpr (std::is_same_v) { + EXPECT_FLOAT_EQ(data[i], 0.0f) << "Non-zero at index " << i; + } else if constexpr (std::is_same_v) { + EXPECT_FLOAT_EQ(static_cast(data[i]), 0.0f) + << "Non-zero at index " << i; + } else { + EXPECT_EQ(data[i], static_cast(0)) + << "Non-zero at index " << i; + } + } +} + +// ============================================================================= +// Copy Tests +// ============================================================================= + +TYPED_TEST(SlimTensorDTypeTest, CopyContiguousTensor) { + SlimTensor src = this->create_tensor(); + this->fill_tensor(src); + + SlimTensor dst = this->create_tensor(); + dst.copy_(src); + + this->verify_tensor_values(dst); +} + +TYPED_TEST(SlimTensorDTypeTest, CopyPreservesSourceData) { + SlimTensor src = this->create_tensor(); + this->fill_tensor(src); + + SlimTensor dst = this->create_tensor(); + dst.copy_(src); + + // Modify dst and verify src is unchanged + std::memset(dst.data_ptr(), 0, dst.nbytes()); + + // src should still have original values + this->verify_tensor_values(src); +} + +// ============================================================================= +// Empty Strided Tests +// ============================================================================= + +TYPED_TEST(SlimTensorDTypeTest, EmptyStridedCreation) { + std::vector sizes = {2, 3, 4}; + std::vector strides = {12, 4, 1}; + + SlimTensor tensor = + empty_strided(makeArrayRef(sizes), makeArrayRef(strides), this->kDType); + + EXPECT_EQ(tensor.dtype(), this->kDType); + EXPECT_TRUE(tensor.is_contiguous()); +} + +TYPED_TEST(SlimTensorDTypeTest, NonContiguousStrides) { + std::vector sizes = {3, 2}; + std::vector strides = {1, 3}; + + SlimTensor tensor = + empty_strided(makeArrayRef(sizes), makeArrayRef(strides), this->kDType); + + EXPECT_EQ(tensor.dtype(), this->kDType); + EXPECT_FALSE(tensor.is_contiguous()); +} + +// ============================================================================= +// Empty Like Tests +// ============================================================================= + +TYPED_TEST(SlimTensorDTypeTest, EmptyLikePreservesDType) { + SlimTensor original = this->create_tensor(); + SlimTensor copy = empty_like(original); + + EXPECT_EQ(copy.dtype(), original.dtype()); + EXPECT_EQ(copy.numel(), original.numel()); + EXPECT_EQ(copy.dim(), original.dim()); + EXPECT_NE(copy.data_ptr(), original.data_ptr()); +} + +// ============================================================================= +// Dimension and Shape Tests +// ============================================================================= + +TYPED_TEST(SlimTensorDTypeTest, OneDimensionalTensor) { + SlimTensor tensor = empty({10}, this->kDType); + + EXPECT_EQ(tensor.dim(), 1u); + EXPECT_EQ(tensor.numel(), 10u); + EXPECT_EQ(tensor.size(0), 10); + EXPECT_EQ(tensor.stride(0), 1); +} + +TYPED_TEST(SlimTensorDTypeTest, FourDimensionalTensor) { + SlimTensor tensor = empty({2, 3, 4, 5}, this->kDType); + + EXPECT_EQ(tensor.dim(), 4u); + EXPECT_EQ(tensor.numel(), 120u); + EXPECT_TRUE(tensor.is_contiguous()); +} + +TYPED_TEST(SlimTensorDTypeTest, ZeroSizedTensor) { + SlimTensor tensor = empty({0, 5}, this->kDType); + + EXPECT_TRUE(tensor.is_empty()); + EXPECT_EQ(tensor.numel(), 0u); + EXPECT_EQ(tensor.dtype(), this->kDType); +} + +} // namespace executorch::backends::aoti::slim