From e4e1a495dd92af57a50a65e6537d77350c5668bd Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Tue, 23 Dec 2025 17:07:38 -0800 Subject: [PATCH] [slimtensor] Introduce Device and ScalarType headers for SlimTensor minimal support This diff introduces the foundational c10 core headers for SlimTensor, a lightweight tensor implementation used by torchnative, to cuda backend runtime and further it will be used by all aoti-driven backends like MPS. We add: - DeviceType.h - Device type enum (CPU only for now) - Device.h - Device class representing compute device location - ScalarType.h - Scalar type enum with elementSize() helper (Float only for now) These headers are modeled after PyTorch's c10 but simplified for our needs. The enum values are kept compatible with PyTorch for serialization compatibility. This is the first step in migrating SlimTensor to replace ETensor as the internal tensor representation in CUDA backend. Future diffs will add Storage, SlimTensor class, and additional dtypes/devices incrementally. Differential Revision: [D89747061](https://our.internmc.facebook.com/intern/diff/D89747061/) [ghstack-poisoned] --- backends/aoti/slim/c10/core/Device.h | 147 ++++++++++++++++++ backends/aoti/slim/c10/core/DeviceType.h | 66 ++++++++ backends/aoti/slim/c10/core/ScalarType.h | 83 ++++++++++ backends/aoti/slim/c10/core/TARGETS | 3 + backends/aoti/slim/c10/core/targets.bzl | 52 +++++++ backends/aoti/slim/c10/core/test/TARGETS | 3 + backends/aoti/slim/c10/core/test/targets.bzl | 25 +++ .../aoti/slim/c10/core/test/test_device.cpp | 111 +++++++++++++ .../slim/c10/core/test/test_scalar_type.cpp | 61 ++++++++ 9 files changed, 551 insertions(+) create mode 100644 backends/aoti/slim/c10/core/Device.h create mode 100644 backends/aoti/slim/c10/core/DeviceType.h create mode 100644 backends/aoti/slim/c10/core/ScalarType.h create mode 100644 backends/aoti/slim/c10/core/TARGETS create mode 100644 backends/aoti/slim/c10/core/targets.bzl create mode 100644 backends/aoti/slim/c10/core/test/TARGETS create mode 100644 backends/aoti/slim/c10/core/test/targets.bzl create mode 100644 backends/aoti/slim/c10/core/test/test_device.cpp create mode 100644 backends/aoti/slim/c10/core/test/test_scalar_type.cpp diff --git a/backends/aoti/slim/c10/core/Device.h b/backends/aoti/slim/c10/core/Device.h new file mode 100644 index 00000000000..08217443931 --- /dev/null +++ b/backends/aoti/slim/c10/core/Device.h @@ -0,0 +1,147 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include + +#include +#include + +namespace executorch::backends::aoti::slim::c10 { + +/// An index representing a specific device; e.g., the 1 in GPU 1. +/// A DeviceIndex is not independently meaningful without knowing +/// the DeviceType it is associated; try to use Device rather than +/// DeviceIndex directly. +using DeviceIndex = int8_t; + +/// Represents a compute device on which a tensor is located. +/// A device is uniquely identified by a type (e.g., CPU) and a device index. +struct Device final { + using Type = DeviceType; + + /// Constructs a new Device from a DeviceType and an optional device index. + /// @param type The type of device. + /// @param index The device index. For CPU, this should be -1 or 0. + /* implicit */ + Device(DeviceType type, DeviceIndex index = -1) : type_(type), index_(index) { + validate(); + } + + /// Constructs a Device from a string description. + /// The string must be "cpu" or "cpu:0". + /* implicit */ Device(const std::string& device_string) : Device(Type::CPU) { + ET_CHECK_MSG(!device_string.empty(), "Device string must not be empty"); + + if (device_string == "cpu" || device_string == "CPU") { + type_ = DeviceType::CPU; + index_ = -1; + } else if ( + device_string == "cpu:0" || device_string == "CPU:0" || + device_string == "cpu:1" || device_string == "CPU:1") { + type_ = DeviceType::CPU; + index_ = static_cast(device_string.back() - '0'); + } else { + ET_CHECK_MSG( + false, + "Invalid device string: %s. Currently only 'cpu' is supported.", + device_string.c_str()); + } + validate(); + } + + /// Returns true if the type and index of this Device matches that of other. + bool operator==(const Device& other) const noexcept { + return this->type_ == other.type_ && this->index_ == other.index_; + } + + /// Returns true if the type or index of this Device differs from that of + /// other. + bool operator!=(const Device& other) const noexcept { + return !(*this == other); + } + + /// Sets the device index. + void set_index(DeviceIndex index) { + index_ = index; + } + + /// Returns the type of device this is. + DeviceType type() const noexcept { + return type_; + } + + /// Returns the device index. + DeviceIndex index() const noexcept { + return index_; + } + + /// Returns true if the device has a non-default index. + bool has_index() const noexcept { + return index_ != -1; + } + + /// Returns true if the device is of CPU type. + bool is_cpu() const noexcept { + return type_ == DeviceType::CPU; + } + + /// Returns a string representation of the device (e.g., "cpu" or "cpu:0"). + std::string str() const { + std::string str = DeviceTypeName(type(), /* lower_case */ true); + if (has_index()) { + str.push_back(':'); + str.append(std::to_string(index())); + } + return str; + } + + private: + DeviceType type_; + DeviceIndex index_ = -1; + + void validate() { + ET_DCHECK_MSG( + index_ >= -1, + "Device index must be -1 or non-negative, got %d", + static_cast(index_)); + ET_DCHECK_MSG( + !is_cpu() || index_ <= 0, + "CPU device index must be -1 or zero, got %d", + static_cast(index_)); + } +}; + +inline std::ostream& operator<<(std::ostream& stream, const Device& device) { + stream << device.str(); + return stream; +} + +} // namespace executorch::backends::aoti::slim::c10 + +namespace std { +template <> +struct hash { + size_t operator()( + executorch::backends::aoti::slim::c10::Device d) const noexcept { + static_assert( + sizeof(executorch::backends::aoti::slim::c10::DeviceType) == 1, + "DeviceType is not 8-bit"); + static_assert( + sizeof(executorch::backends::aoti::slim::c10::DeviceIndex) == 1, + "DeviceIndex is not 8-bit"); + uint32_t bits = static_cast(static_cast(d.type())) + << 16 | + static_cast(static_cast(d.index())); + return std::hash{}(bits); + } +}; +} // namespace std diff --git a/backends/aoti/slim/c10/core/DeviceType.h b/backends/aoti/slim/c10/core/DeviceType.h new file mode 100644 index 00000000000..c8c36c7faab --- /dev/null +++ b/backends/aoti/slim/c10/core/DeviceType.h @@ -0,0 +1,66 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include + +#include + +namespace executorch::backends::aoti::slim::c10 { + +/// Enum representing the type of device. +enum class DeviceType : int8_t { + CPU = 0, + COMPILE_TIME_MAX_DEVICE_TYPES = 1, +}; + +constexpr DeviceType kCPU = DeviceType::CPU; + +/// Maximum number of device types at compile time. +constexpr int COMPILE_TIME_MAX_DEVICE_TYPES = + static_cast(DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES); + +/// Returns the name of the device type as a string. +/// @param d The device type. +/// @param lower_case If true, returns the name in lower case. +/// @return The name of the device type. +inline std::string DeviceTypeName(DeviceType d, bool lower_case = false) { + switch (d) { + case DeviceType::CPU: + return lower_case ? "cpu" : "CPU"; + default: + ET_CHECK_MSG(false, "Unknown device type: %d", static_cast(d)); + } +} + +/// Checks if the device type is valid. +/// @param d The device type to check. +/// @return true if the device type is valid, false otherwise. +inline bool isValidDeviceType(DeviceType d) { + return d == DeviceType::CPU; +} + +inline std::ostream& operator<<(std::ostream& stream, DeviceType type) { + stream << DeviceTypeName(type, /* lower_case */ true); + return stream; +} + +} // namespace executorch::backends::aoti::slim::c10 + +namespace std { +template <> +struct hash { + std::size_t operator()( + executorch::backends::aoti::slim::c10::DeviceType k) const { + return std::hash()(static_cast(k)); + } +}; +} // namespace std diff --git a/backends/aoti/slim/c10/core/ScalarType.h b/backends/aoti/slim/c10/core/ScalarType.h new file mode 100644 index 00000000000..1ca1a1429ed --- /dev/null +++ b/backends/aoti/slim/c10/core/ScalarType.h @@ -0,0 +1,83 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include + +#include + +namespace executorch::backends::aoti::slim::c10 { + +/// Enum representing the scalar type (dtype) of tensor elements. +/// Note: Enum values must match PyTorch's c10::ScalarType for compatibility. +enum class ScalarType : int8_t { + // Byte = 0, + // Char = 1, + // Short = 2, + // Int = 3, + // Long = 4, + Float = 6, + // Bool = 11, + // BFloat16 = 15, + Undefined = -1, + NumOptions = 7, +}; + +/// Constant for Float scalar type. +constexpr ScalarType kFloat = ScalarType::Float; + +/// Returns the size in bytes of a single element of the given scalar type. +/// @param t The scalar type. +/// @return The size in bytes of a single element. +inline size_t elementSize(ScalarType t) { + switch (t) { + case ScalarType::Float: + return sizeof(float); + default: + ET_CHECK_MSG(false, "Unknown ScalarType: %d", static_cast(t)); + } +} + +/// Returns the name of the scalar type as a string. +/// @param t The scalar type. +/// @return The name of the scalar type. +inline const char* toString(ScalarType t) { + switch (t) { + case ScalarType::Float: + return "Float"; + case ScalarType::Undefined: + return "Undefined"; + default: + return "UNKNOWN_SCALAR"; + } +} + +/// Checks if the scalar type is a floating point type. +/// @param t The scalar type to check. +/// @return true if the scalar type is floating point, false otherwise. +inline bool isFloatingType(ScalarType t) { + return t == ScalarType::Float; +} + +/// Checks if the scalar type is an integral type (including bool). +/// @param t The scalar type to check. +/// @param includeBool Whether to consider Bool as integral. +/// @return true if the scalar type is integral, false otherwise. +inline bool isIntegralType(ScalarType t, bool /*includeBool*/) { + (void)t; + return false; +} + +inline std::ostream& operator<<(std::ostream& stream, ScalarType scalar_type) { + return stream << toString(scalar_type); +} + +} // namespace executorch::backends::aoti::slim::c10 diff --git a/backends/aoti/slim/c10/core/TARGETS b/backends/aoti/slim/c10/core/TARGETS new file mode 100644 index 00000000000..77871de4469 --- /dev/null +++ b/backends/aoti/slim/c10/core/TARGETS @@ -0,0 +1,3 @@ +load("targets.bzl", "define_common_targets") + +define_common_targets() diff --git a/backends/aoti/slim/c10/core/targets.bzl b/backends/aoti/slim/c10/core/targets.bzl new file mode 100644 index 00000000000..9b7d1259df0 --- /dev/null +++ b/backends/aoti/slim/c10/core/targets.bzl @@ -0,0 +1,52 @@ +load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") + +def define_common_targets(): + """Define targets for SlimTensor c10 core module.""" + + # Header-only library for DeviceType + runtime.cxx_library( + name = "device_type", + headers = [ + "DeviceType.h", + ], + visibility = ["@EXECUTORCH_CLIENTS"], + exported_deps = [ + "//executorch/runtime/platform:platform", + ], + ) + + # Header-only library for Device + runtime.cxx_library( + name = "device", + headers = [ + "Device.h", + ], + visibility = ["@EXECUTORCH_CLIENTS"], + exported_deps = [ + ":device_type", + "//executorch/runtime/platform:platform", + ], + ) + + # Header-only library for ScalarType + runtime.cxx_library( + name = "scalar_type", + headers = [ + "ScalarType.h", + ], + visibility = ["@EXECUTORCH_CLIENTS"], + exported_deps = [ + "//executorch/runtime/platform:platform", + ], + ) + + # Combined c10 core library + runtime.cxx_library( + name = "core", + visibility = ["@EXECUTORCH_CLIENTS"], + exported_deps = [ + ":device", + ":device_type", + ":scalar_type", + ], + ) diff --git a/backends/aoti/slim/c10/core/test/TARGETS b/backends/aoti/slim/c10/core/test/TARGETS new file mode 100644 index 00000000000..77871de4469 --- /dev/null +++ b/backends/aoti/slim/c10/core/test/TARGETS @@ -0,0 +1,3 @@ +load("targets.bzl", "define_common_targets") + +define_common_targets() diff --git a/backends/aoti/slim/c10/core/test/targets.bzl b/backends/aoti/slim/c10/core/test/targets.bzl new file mode 100644 index 00000000000..f7abf59a273 --- /dev/null +++ b/backends/aoti/slim/c10/core/test/targets.bzl @@ -0,0 +1,25 @@ +load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") + +def define_common_targets(): + """Define test targets for SlimTensor c10 core module.""" + + runtime.cxx_test( + name = "test_device", + srcs = [ + "test_device.cpp", + ], + deps = [ + "//executorch/backends/aoti/slim/c10/core:device", + "//executorch/backends/aoti/slim/c10/core:device_type", + ], + ) + + runtime.cxx_test( + name = "test_scalar_type", + srcs = [ + "test_scalar_type.cpp", + ], + deps = [ + "//executorch/backends/aoti/slim/c10/core:scalar_type", + ], + ) diff --git a/backends/aoti/slim/c10/core/test/test_device.cpp b/backends/aoti/slim/c10/core/test/test_device.cpp new file mode 100644 index 00000000000..57123589775 --- /dev/null +++ b/backends/aoti/slim/c10/core/test/test_device.cpp @@ -0,0 +1,111 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +#include + +using namespace executorch::backends::aoti::slim::c10; + +class DeviceTypeTest : public ::testing::Test {}; + +TEST_F(DeviceTypeTest, CPUEnumValue) { + // Verify CPU has the correct enum value (0) + EXPECT_EQ(static_cast(DeviceType::CPU), 0); +} + +TEST_F(DeviceTypeTest, DeviceTypeName) { + // Verify DeviceTypeName returns correct strings + EXPECT_EQ(DeviceTypeName(DeviceType::CPU, false), "CPU"); + EXPECT_EQ(DeviceTypeName(DeviceType::CPU, true), "cpu"); +} + +TEST_F(DeviceTypeTest, IsValidDeviceType) { + // Verify isValidDeviceType works correctly + EXPECT_TRUE(isValidDeviceType(DeviceType::CPU)); +} + +TEST_F(DeviceTypeTest, KCPUConstant) { + // Verify kCPU constant + EXPECT_EQ(kCPU, DeviceType::CPU); +} + +class DeviceTest : public ::testing::Test {}; + +TEST_F(DeviceTest, ConstructFromDeviceType) { + // Construct Device from DeviceType + Device cpu_device(DeviceType::CPU); + + EXPECT_TRUE(cpu_device.is_cpu()); + EXPECT_EQ(cpu_device.type(), DeviceType::CPU); + EXPECT_EQ(cpu_device.index(), -1); // Default index + EXPECT_FALSE(cpu_device.has_index()); +} + +TEST_F(DeviceTest, ConstructWithIndex) { + // Construct Device with explicit index + Device cpu_device(DeviceType::CPU, 0); + + EXPECT_TRUE(cpu_device.is_cpu()); + EXPECT_EQ(cpu_device.type(), DeviceType::CPU); + EXPECT_EQ(cpu_device.index(), 0); + EXPECT_TRUE(cpu_device.has_index()); +} + +TEST_F(DeviceTest, ConstructFromString) { + // Construct Device from string + Device cpu1("cpu"); + EXPECT_TRUE(cpu1.is_cpu()); + EXPECT_EQ(cpu1.index(), -1); + + Device cpu2("CPU"); + EXPECT_TRUE(cpu2.is_cpu()); + EXPECT_EQ(cpu2.index(), -1); + + Device cpu3("cpu:0"); + EXPECT_TRUE(cpu3.is_cpu()); + EXPECT_EQ(cpu3.index(), 0); +} + +TEST_F(DeviceTest, Equality) { + Device cpu1(DeviceType::CPU, 0); + Device cpu2(DeviceType::CPU, 0); + Device cpu3(DeviceType::CPU, -1); + + EXPECT_EQ(cpu1, cpu2); + EXPECT_NE(cpu1, cpu3); +} + +TEST_F(DeviceTest, Str) { + Device cpu1(DeviceType::CPU); + EXPECT_EQ(cpu1.str(), "cpu"); + + Device cpu2(DeviceType::CPU, 0); + EXPECT_EQ(cpu2.str(), "cpu:0"); +} + +TEST_F(DeviceTest, SetIndex) { + Device cpu(DeviceType::CPU); + EXPECT_EQ(cpu.index(), -1); + + cpu.set_index(0); + EXPECT_EQ(cpu.index(), 0); + EXPECT_TRUE(cpu.has_index()); +} + +TEST_F(DeviceTest, Hash) { + // Verify Device can be hashed (for use in unordered containers) + Device cpu1(DeviceType::CPU, 0); + Device cpu2(DeviceType::CPU, 0); + Device cpu3(DeviceType::CPU, -1); + + std::hash hasher; + EXPECT_EQ(hasher(cpu1), hasher(cpu2)); + EXPECT_NE(hasher(cpu1), hasher(cpu3)); +} diff --git a/backends/aoti/slim/c10/core/test/test_scalar_type.cpp b/backends/aoti/slim/c10/core/test/test_scalar_type.cpp new file mode 100644 index 00000000000..673641d84c7 --- /dev/null +++ b/backends/aoti/slim/c10/core/test/test_scalar_type.cpp @@ -0,0 +1,61 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include + +using namespace executorch::backends::aoti::slim::c10; + +class ScalarTypeTest : public ::testing::Test {}; + +TEST_F(ScalarTypeTest, FloatEnumValue) { + // Verify Float has the correct enum value (6) to match PyTorch + EXPECT_EQ(static_cast(ScalarType::Float), 6); +} + +TEST_F(ScalarTypeTest, KFloatConstant) { + // Verify kFloat constant + EXPECT_EQ(kFloat, ScalarType::Float); +} + +TEST_F(ScalarTypeTest, ElementSizeFloat) { + // Verify elementSize returns correct size for Float (4 bytes) + EXPECT_EQ(elementSize(ScalarType::Float), sizeof(float)); + EXPECT_EQ(elementSize(ScalarType::Float), 4); +} + +TEST_F(ScalarTypeTest, ToStringFloat) { + // Verify toString returns correct string for Float + EXPECT_STREQ(toString(ScalarType::Float), "Float"); +} + +TEST_F(ScalarTypeTest, ToStringUndefined) { + // Verify toString returns correct string for Undefined + EXPECT_STREQ(toString(ScalarType::Undefined), "Undefined"); +} + +TEST_F(ScalarTypeTest, IsFloatingType) { + // Verify isFloatingType works correctly + EXPECT_TRUE(isFloatingType(ScalarType::Float)); +} + +TEST_F(ScalarTypeTest, IsIntegralType) { + // Verify isIntegralType works correctly + // Currently no integral types are supported, so Float should return false + EXPECT_FALSE(isIntegralType(ScalarType::Float, false)); + EXPECT_FALSE(isIntegralType(ScalarType::Float, true)); +} + +TEST_F(ScalarTypeTest, StreamOperator) { + // Verify stream operator works + std::ostringstream oss; + oss << ScalarType::Float; + EXPECT_EQ(oss.str(), "Float"); +}