diff --git a/backends/aoti/slim/c10/core/Device.h b/backends/aoti/slim/c10/core/Device.h new file mode 100644 index 00000000000..08217443931 --- /dev/null +++ b/backends/aoti/slim/c10/core/Device.h @@ -0,0 +1,147 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include + +#include +#include + +namespace executorch::backends::aoti::slim::c10 { + +/// An index representing a specific device; e.g., the 1 in GPU 1. +/// A DeviceIndex is not independently meaningful without knowing +/// the DeviceType it is associated; try to use Device rather than +/// DeviceIndex directly. +using DeviceIndex = int8_t; + +/// Represents a compute device on which a tensor is located. +/// A device is uniquely identified by a type (e.g., CPU) and a device index. +struct Device final { + using Type = DeviceType; + + /// Constructs a new Device from a DeviceType and an optional device index. + /// @param type The type of device. + /// @param index The device index. For CPU, this should be -1 or 0. + /* implicit */ + Device(DeviceType type, DeviceIndex index = -1) : type_(type), index_(index) { + validate(); + } + + /// Constructs a Device from a string description. + /// The string must be "cpu" or "cpu:0". + /* implicit */ Device(const std::string& device_string) : Device(Type::CPU) { + ET_CHECK_MSG(!device_string.empty(), "Device string must not be empty"); + + if (device_string == "cpu" || device_string == "CPU") { + type_ = DeviceType::CPU; + index_ = -1; + } else if ( + device_string == "cpu:0" || device_string == "CPU:0" || + device_string == "cpu:1" || device_string == "CPU:1") { + type_ = DeviceType::CPU; + index_ = static_cast(device_string.back() - '0'); + } else { + ET_CHECK_MSG( + false, + "Invalid device string: %s. Currently only 'cpu' is supported.", + device_string.c_str()); + } + validate(); + } + + /// Returns true if the type and index of this Device matches that of other. + bool operator==(const Device& other) const noexcept { + return this->type_ == other.type_ && this->index_ == other.index_; + } + + /// Returns true if the type or index of this Device differs from that of + /// other. + bool operator!=(const Device& other) const noexcept { + return !(*this == other); + } + + /// Sets the device index. + void set_index(DeviceIndex index) { + index_ = index; + } + + /// Returns the type of device this is. + DeviceType type() const noexcept { + return type_; + } + + /// Returns the device index. + DeviceIndex index() const noexcept { + return index_; + } + + /// Returns true if the device has a non-default index. + bool has_index() const noexcept { + return index_ != -1; + } + + /// Returns true if the device is of CPU type. + bool is_cpu() const noexcept { + return type_ == DeviceType::CPU; + } + + /// Returns a string representation of the device (e.g., "cpu" or "cpu:0"). + std::string str() const { + std::string str = DeviceTypeName(type(), /* lower_case */ true); + if (has_index()) { + str.push_back(':'); + str.append(std::to_string(index())); + } + return str; + } + + private: + DeviceType type_; + DeviceIndex index_ = -1; + + void validate() { + ET_DCHECK_MSG( + index_ >= -1, + "Device index must be -1 or non-negative, got %d", + static_cast(index_)); + ET_DCHECK_MSG( + !is_cpu() || index_ <= 0, + "CPU device index must be -1 or zero, got %d", + static_cast(index_)); + } +}; + +inline std::ostream& operator<<(std::ostream& stream, const Device& device) { + stream << device.str(); + return stream; +} + +} // namespace executorch::backends::aoti::slim::c10 + +namespace std { +template <> +struct hash { + size_t operator()( + executorch::backends::aoti::slim::c10::Device d) const noexcept { + static_assert( + sizeof(executorch::backends::aoti::slim::c10::DeviceType) == 1, + "DeviceType is not 8-bit"); + static_assert( + sizeof(executorch::backends::aoti::slim::c10::DeviceIndex) == 1, + "DeviceIndex is not 8-bit"); + uint32_t bits = static_cast(static_cast(d.type())) + << 16 | + static_cast(static_cast(d.index())); + return std::hash{}(bits); + } +}; +} // namespace std diff --git a/backends/aoti/slim/c10/core/DeviceType.h b/backends/aoti/slim/c10/core/DeviceType.h new file mode 100644 index 00000000000..c8c36c7faab --- /dev/null +++ b/backends/aoti/slim/c10/core/DeviceType.h @@ -0,0 +1,66 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include + +#include + +namespace executorch::backends::aoti::slim::c10 { + +/// Enum representing the type of device. +enum class DeviceType : int8_t { + CPU = 0, + COMPILE_TIME_MAX_DEVICE_TYPES = 1, +}; + +constexpr DeviceType kCPU = DeviceType::CPU; + +/// Maximum number of device types at compile time. +constexpr int COMPILE_TIME_MAX_DEVICE_TYPES = + static_cast(DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES); + +/// Returns the name of the device type as a string. +/// @param d The device type. +/// @param lower_case If true, returns the name in lower case. +/// @return The name of the device type. +inline std::string DeviceTypeName(DeviceType d, bool lower_case = false) { + switch (d) { + case DeviceType::CPU: + return lower_case ? "cpu" : "CPU"; + default: + ET_CHECK_MSG(false, "Unknown device type: %d", static_cast(d)); + } +} + +/// Checks if the device type is valid. +/// @param d The device type to check. +/// @return true if the device type is valid, false otherwise. +inline bool isValidDeviceType(DeviceType d) { + return d == DeviceType::CPU; +} + +inline std::ostream& operator<<(std::ostream& stream, DeviceType type) { + stream << DeviceTypeName(type, /* lower_case */ true); + return stream; +} + +} // namespace executorch::backends::aoti::slim::c10 + +namespace std { +template <> +struct hash { + std::size_t operator()( + executorch::backends::aoti::slim::c10::DeviceType k) const { + return std::hash()(static_cast(k)); + } +}; +} // namespace std diff --git a/backends/aoti/slim/c10/core/ScalarType.h b/backends/aoti/slim/c10/core/ScalarType.h new file mode 100644 index 00000000000..1ca1a1429ed --- /dev/null +++ b/backends/aoti/slim/c10/core/ScalarType.h @@ -0,0 +1,83 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include + +#include + +namespace executorch::backends::aoti::slim::c10 { + +/// Enum representing the scalar type (dtype) of tensor elements. +/// Note: Enum values must match PyTorch's c10::ScalarType for compatibility. +enum class ScalarType : int8_t { + // Byte = 0, + // Char = 1, + // Short = 2, + // Int = 3, + // Long = 4, + Float = 6, + // Bool = 11, + // BFloat16 = 15, + Undefined = -1, + NumOptions = 7, +}; + +/// Constant for Float scalar type. +constexpr ScalarType kFloat = ScalarType::Float; + +/// Returns the size in bytes of a single element of the given scalar type. +/// @param t The scalar type. +/// @return The size in bytes of a single element. +inline size_t elementSize(ScalarType t) { + switch (t) { + case ScalarType::Float: + return sizeof(float); + default: + ET_CHECK_MSG(false, "Unknown ScalarType: %d", static_cast(t)); + } +} + +/// Returns the name of the scalar type as a string. +/// @param t The scalar type. +/// @return The name of the scalar type. +inline const char* toString(ScalarType t) { + switch (t) { + case ScalarType::Float: + return "Float"; + case ScalarType::Undefined: + return "Undefined"; + default: + return "UNKNOWN_SCALAR"; + } +} + +/// Checks if the scalar type is a floating point type. +/// @param t The scalar type to check. +/// @return true if the scalar type is floating point, false otherwise. +inline bool isFloatingType(ScalarType t) { + return t == ScalarType::Float; +} + +/// Checks if the scalar type is an integral type (including bool). +/// @param t The scalar type to check. +/// @param includeBool Whether to consider Bool as integral. +/// @return true if the scalar type is integral, false otherwise. +inline bool isIntegralType(ScalarType t, bool /*includeBool*/) { + (void)t; + return false; +} + +inline std::ostream& operator<<(std::ostream& stream, ScalarType scalar_type) { + return stream << toString(scalar_type); +} + +} // namespace executorch::backends::aoti::slim::c10 diff --git a/backends/aoti/slim/c10/core/TARGETS b/backends/aoti/slim/c10/core/TARGETS new file mode 100644 index 00000000000..77871de4469 --- /dev/null +++ b/backends/aoti/slim/c10/core/TARGETS @@ -0,0 +1,3 @@ +load("targets.bzl", "define_common_targets") + +define_common_targets() diff --git a/backends/aoti/slim/c10/core/targets.bzl b/backends/aoti/slim/c10/core/targets.bzl new file mode 100644 index 00000000000..9b7d1259df0 --- /dev/null +++ b/backends/aoti/slim/c10/core/targets.bzl @@ -0,0 +1,52 @@ +load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") + +def define_common_targets(): + """Define targets for SlimTensor c10 core module.""" + + # Header-only library for DeviceType + runtime.cxx_library( + name = "device_type", + headers = [ + "DeviceType.h", + ], + visibility = ["@EXECUTORCH_CLIENTS"], + exported_deps = [ + "//executorch/runtime/platform:platform", + ], + ) + + # Header-only library for Device + runtime.cxx_library( + name = "device", + headers = [ + "Device.h", + ], + visibility = ["@EXECUTORCH_CLIENTS"], + exported_deps = [ + ":device_type", + "//executorch/runtime/platform:platform", + ], + ) + + # Header-only library for ScalarType + runtime.cxx_library( + name = "scalar_type", + headers = [ + "ScalarType.h", + ], + visibility = ["@EXECUTORCH_CLIENTS"], + exported_deps = [ + "//executorch/runtime/platform:platform", + ], + ) + + # Combined c10 core library + runtime.cxx_library( + name = "core", + visibility = ["@EXECUTORCH_CLIENTS"], + exported_deps = [ + ":device", + ":device_type", + ":scalar_type", + ], + ) diff --git a/backends/aoti/slim/c10/core/test/TARGETS b/backends/aoti/slim/c10/core/test/TARGETS new file mode 100644 index 00000000000..77871de4469 --- /dev/null +++ b/backends/aoti/slim/c10/core/test/TARGETS @@ -0,0 +1,3 @@ +load("targets.bzl", "define_common_targets") + +define_common_targets() diff --git a/backends/aoti/slim/c10/core/test/targets.bzl b/backends/aoti/slim/c10/core/test/targets.bzl new file mode 100644 index 00000000000..f7abf59a273 --- /dev/null +++ b/backends/aoti/slim/c10/core/test/targets.bzl @@ -0,0 +1,25 @@ +load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") + +def define_common_targets(): + """Define test targets for SlimTensor c10 core module.""" + + runtime.cxx_test( + name = "test_device", + srcs = [ + "test_device.cpp", + ], + deps = [ + "//executorch/backends/aoti/slim/c10/core:device", + "//executorch/backends/aoti/slim/c10/core:device_type", + ], + ) + + runtime.cxx_test( + name = "test_scalar_type", + srcs = [ + "test_scalar_type.cpp", + ], + deps = [ + "//executorch/backends/aoti/slim/c10/core:scalar_type", + ], + ) diff --git a/backends/aoti/slim/c10/core/test/test_device.cpp b/backends/aoti/slim/c10/core/test/test_device.cpp new file mode 100644 index 00000000000..57123589775 --- /dev/null +++ b/backends/aoti/slim/c10/core/test/test_device.cpp @@ -0,0 +1,111 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +#include + +using namespace executorch::backends::aoti::slim::c10; + +class DeviceTypeTest : public ::testing::Test {}; + +TEST_F(DeviceTypeTest, CPUEnumValue) { + // Verify CPU has the correct enum value (0) + EXPECT_EQ(static_cast(DeviceType::CPU), 0); +} + +TEST_F(DeviceTypeTest, DeviceTypeName) { + // Verify DeviceTypeName returns correct strings + EXPECT_EQ(DeviceTypeName(DeviceType::CPU, false), "CPU"); + EXPECT_EQ(DeviceTypeName(DeviceType::CPU, true), "cpu"); +} + +TEST_F(DeviceTypeTest, IsValidDeviceType) { + // Verify isValidDeviceType works correctly + EXPECT_TRUE(isValidDeviceType(DeviceType::CPU)); +} + +TEST_F(DeviceTypeTest, KCPUConstant) { + // Verify kCPU constant + EXPECT_EQ(kCPU, DeviceType::CPU); +} + +class DeviceTest : public ::testing::Test {}; + +TEST_F(DeviceTest, ConstructFromDeviceType) { + // Construct Device from DeviceType + Device cpu_device(DeviceType::CPU); + + EXPECT_TRUE(cpu_device.is_cpu()); + EXPECT_EQ(cpu_device.type(), DeviceType::CPU); + EXPECT_EQ(cpu_device.index(), -1); // Default index + EXPECT_FALSE(cpu_device.has_index()); +} + +TEST_F(DeviceTest, ConstructWithIndex) { + // Construct Device with explicit index + Device cpu_device(DeviceType::CPU, 0); + + EXPECT_TRUE(cpu_device.is_cpu()); + EXPECT_EQ(cpu_device.type(), DeviceType::CPU); + EXPECT_EQ(cpu_device.index(), 0); + EXPECT_TRUE(cpu_device.has_index()); +} + +TEST_F(DeviceTest, ConstructFromString) { + // Construct Device from string + Device cpu1("cpu"); + EXPECT_TRUE(cpu1.is_cpu()); + EXPECT_EQ(cpu1.index(), -1); + + Device cpu2("CPU"); + EXPECT_TRUE(cpu2.is_cpu()); + EXPECT_EQ(cpu2.index(), -1); + + Device cpu3("cpu:0"); + EXPECT_TRUE(cpu3.is_cpu()); + EXPECT_EQ(cpu3.index(), 0); +} + +TEST_F(DeviceTest, Equality) { + Device cpu1(DeviceType::CPU, 0); + Device cpu2(DeviceType::CPU, 0); + Device cpu3(DeviceType::CPU, -1); + + EXPECT_EQ(cpu1, cpu2); + EXPECT_NE(cpu1, cpu3); +} + +TEST_F(DeviceTest, Str) { + Device cpu1(DeviceType::CPU); + EXPECT_EQ(cpu1.str(), "cpu"); + + Device cpu2(DeviceType::CPU, 0); + EXPECT_EQ(cpu2.str(), "cpu:0"); +} + +TEST_F(DeviceTest, SetIndex) { + Device cpu(DeviceType::CPU); + EXPECT_EQ(cpu.index(), -1); + + cpu.set_index(0); + EXPECT_EQ(cpu.index(), 0); + EXPECT_TRUE(cpu.has_index()); +} + +TEST_F(DeviceTest, Hash) { + // Verify Device can be hashed (for use in unordered containers) + Device cpu1(DeviceType::CPU, 0); + Device cpu2(DeviceType::CPU, 0); + Device cpu3(DeviceType::CPU, -1); + + std::hash hasher; + EXPECT_EQ(hasher(cpu1), hasher(cpu2)); + EXPECT_NE(hasher(cpu1), hasher(cpu3)); +} diff --git a/backends/aoti/slim/c10/core/test/test_scalar_type.cpp b/backends/aoti/slim/c10/core/test/test_scalar_type.cpp new file mode 100644 index 00000000000..673641d84c7 --- /dev/null +++ b/backends/aoti/slim/c10/core/test/test_scalar_type.cpp @@ -0,0 +1,61 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include + +using namespace executorch::backends::aoti::slim::c10; + +class ScalarTypeTest : public ::testing::Test {}; + +TEST_F(ScalarTypeTest, FloatEnumValue) { + // Verify Float has the correct enum value (6) to match PyTorch + EXPECT_EQ(static_cast(ScalarType::Float), 6); +} + +TEST_F(ScalarTypeTest, KFloatConstant) { + // Verify kFloat constant + EXPECT_EQ(kFloat, ScalarType::Float); +} + +TEST_F(ScalarTypeTest, ElementSizeFloat) { + // Verify elementSize returns correct size for Float (4 bytes) + EXPECT_EQ(elementSize(ScalarType::Float), sizeof(float)); + EXPECT_EQ(elementSize(ScalarType::Float), 4); +} + +TEST_F(ScalarTypeTest, ToStringFloat) { + // Verify toString returns correct string for Float + EXPECT_STREQ(toString(ScalarType::Float), "Float"); +} + +TEST_F(ScalarTypeTest, ToStringUndefined) { + // Verify toString returns correct string for Undefined + EXPECT_STREQ(toString(ScalarType::Undefined), "Undefined"); +} + +TEST_F(ScalarTypeTest, IsFloatingType) { + // Verify isFloatingType works correctly + EXPECT_TRUE(isFloatingType(ScalarType::Float)); +} + +TEST_F(ScalarTypeTest, IsIntegralType) { + // Verify isIntegralType works correctly + // Currently no integral types are supported, so Float should return false + EXPECT_FALSE(isIntegralType(ScalarType::Float, false)); + EXPECT_FALSE(isIntegralType(ScalarType::Float, true)); +} + +TEST_F(ScalarTypeTest, StreamOperator) { + // Verify stream operator works + std::ostringstream oss; + oss << ScalarType::Float; + EXPECT_EQ(oss.str(), "Float"); +}