Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
147 changes: 147 additions & 0 deletions backends/aoti/slim/c10/core/Device.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#include <cstdint>
#include <functional>
#include <string>

#include <executorch/backends/aoti/slim/c10/core/DeviceType.h>
#include <executorch/runtime/platform/assert.h>

namespace executorch::backends::aoti::slim::c10 {

/// An index representing a specific device; e.g., the 1 in GPU 1.
/// A DeviceIndex is not independently meaningful without knowing
/// the DeviceType it is associated; try to use Device rather than
/// DeviceIndex directly.
using DeviceIndex = int8_t;

/// Represents a compute device on which a tensor is located.
/// A device is uniquely identified by a type (e.g., CPU) and a device index.
struct Device final {
using Type = DeviceType;

/// Constructs a new Device from a DeviceType and an optional device index.
/// @param type The type of device.
/// @param index The device index. For CPU, this should be -1 or 0.
/* implicit */
Device(DeviceType type, DeviceIndex index = -1) : type_(type), index_(index) {
validate();
}

/// Constructs a Device from a string description.
/// The string must be "cpu" or "cpu:0".
/* implicit */ Device(const std::string& device_string) : Device(Type::CPU) {
ET_CHECK_MSG(!device_string.empty(), "Device string must not be empty");

if (device_string == "cpu" || device_string == "CPU") {
type_ = DeviceType::CPU;
index_ = -1;
} else if (
device_string == "cpu:0" || device_string == "CPU:0" ||
device_string == "cpu:1" || device_string == "CPU:1") {
type_ = DeviceType::CPU;
index_ = static_cast<DeviceIndex>(device_string.back() - '0');
} else {
ET_CHECK_MSG(
false,
"Invalid device string: %s. Currently only 'cpu' is supported.",
device_string.c_str());
}
validate();
}

/// Returns true if the type and index of this Device matches that of other.
bool operator==(const Device& other) const noexcept {
return this->type_ == other.type_ && this->index_ == other.index_;
}

/// Returns true if the type or index of this Device differs from that of
/// other.
bool operator!=(const Device& other) const noexcept {
return !(*this == other);
}

/// Sets the device index.
void set_index(DeviceIndex index) {
index_ = index;
}

/// Returns the type of device this is.
DeviceType type() const noexcept {
return type_;
}

/// Returns the device index.
DeviceIndex index() const noexcept {
return index_;
}

/// Returns true if the device has a non-default index.
bool has_index() const noexcept {
return index_ != -1;
}

/// Returns true if the device is of CPU type.
bool is_cpu() const noexcept {
return type_ == DeviceType::CPU;
}

/// Returns a string representation of the device (e.g., "cpu" or "cpu:0").
std::string str() const {
std::string str = DeviceTypeName(type(), /* lower_case */ true);
if (has_index()) {
str.push_back(':');
str.append(std::to_string(index()));
}
return str;
}

private:
DeviceType type_;
DeviceIndex index_ = -1;

void validate() {
ET_DCHECK_MSG(
index_ >= -1,
"Device index must be -1 or non-negative, got %d",
static_cast<int>(index_));
ET_DCHECK_MSG(
!is_cpu() || index_ <= 0,
"CPU device index must be -1 or zero, got %d",
static_cast<int>(index_));
}
};

inline std::ostream& operator<<(std::ostream& stream, const Device& device) {
stream << device.str();
return stream;
}

} // namespace executorch::backends::aoti::slim::c10

namespace std {
template <>
struct hash<executorch::backends::aoti::slim::c10::Device> {
size_t operator()(
executorch::backends::aoti::slim::c10::Device d) const noexcept {
static_assert(
sizeof(executorch::backends::aoti::slim::c10::DeviceType) == 1,
"DeviceType is not 8-bit");
static_assert(
sizeof(executorch::backends::aoti::slim::c10::DeviceIndex) == 1,
"DeviceIndex is not 8-bit");
uint32_t bits = static_cast<uint32_t>(static_cast<uint8_t>(d.type()))
<< 16 |
static_cast<uint32_t>(static_cast<uint8_t>(d.index()));
return std::hash<uint32_t>{}(bits);
}
};
} // namespace std
66 changes: 66 additions & 0 deletions backends/aoti/slim/c10/core/DeviceType.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#include <cstdint>
#include <ostream>
#include <string>

#include <executorch/runtime/platform/assert.h>

namespace executorch::backends::aoti::slim::c10 {

/// Enum representing the type of device.
enum class DeviceType : int8_t {
CPU = 0,
COMPILE_TIME_MAX_DEVICE_TYPES = 1,
};

constexpr DeviceType kCPU = DeviceType::CPU;

/// Maximum number of device types at compile time.
constexpr int COMPILE_TIME_MAX_DEVICE_TYPES =
static_cast<int>(DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES);

/// Returns the name of the device type as a string.
/// @param d The device type.
/// @param lower_case If true, returns the name in lower case.
/// @return The name of the device type.
inline std::string DeviceTypeName(DeviceType d, bool lower_case = false) {
switch (d) {
case DeviceType::CPU:
return lower_case ? "cpu" : "CPU";
default:
ET_CHECK_MSG(false, "Unknown device type: %d", static_cast<int>(d));
}
}

/// Checks if the device type is valid.
/// @param d The device type to check.
/// @return true if the device type is valid, false otherwise.
inline bool isValidDeviceType(DeviceType d) {
return d == DeviceType::CPU;
}

inline std::ostream& operator<<(std::ostream& stream, DeviceType type) {
stream << DeviceTypeName(type, /* lower_case */ true);
return stream;
}

} // namespace executorch::backends::aoti::slim::c10

namespace std {
template <>
struct hash<executorch::backends::aoti::slim::c10::DeviceType> {
std::size_t operator()(
executorch::backends::aoti::slim::c10::DeviceType k) const {
return std::hash<int>()(static_cast<int>(k));
}
};
} // namespace std
83 changes: 83 additions & 0 deletions backends/aoti/slim/c10/core/ScalarType.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#include <cstddef>
#include <cstdint>
#include <ostream>

#include <executorch/runtime/platform/assert.h>

namespace executorch::backends::aoti::slim::c10 {

/// Enum representing the scalar type (dtype) of tensor elements.
/// Note: Enum values must match PyTorch's c10::ScalarType for compatibility.
enum class ScalarType : int8_t {
// Byte = 0,
// Char = 1,
// Short = 2,
// Int = 3,
// Long = 4,
Float = 6,
// Bool = 11,
// BFloat16 = 15,
Undefined = -1,
NumOptions = 7,
};

/// Constant for Float scalar type.
constexpr ScalarType kFloat = ScalarType::Float;

/// Returns the size in bytes of a single element of the given scalar type.
/// @param t The scalar type.
/// @return The size in bytes of a single element.
inline size_t elementSize(ScalarType t) {
switch (t) {
case ScalarType::Float:
return sizeof(float);
default:
ET_CHECK_MSG(false, "Unknown ScalarType: %d", static_cast<int>(t));
}
}

/// Returns the name of the scalar type as a string.
/// @param t The scalar type.
/// @return The name of the scalar type.
inline const char* toString(ScalarType t) {
switch (t) {
case ScalarType::Float:
return "Float";
case ScalarType::Undefined:
return "Undefined";
default:
return "UNKNOWN_SCALAR";
}
}

/// Checks if the scalar type is a floating point type.
/// @param t The scalar type to check.
/// @return true if the scalar type is floating point, false otherwise.
inline bool isFloatingType(ScalarType t) {
return t == ScalarType::Float;
}

/// Checks if the scalar type is an integral type (including bool).
/// @param t The scalar type to check.
/// @param includeBool Whether to consider Bool as integral.
/// @return true if the scalar type is integral, false otherwise.
inline bool isIntegralType(ScalarType t, bool /*includeBool*/) {
(void)t;
return false;
}

inline std::ostream& operator<<(std::ostream& stream, ScalarType scalar_type) {
return stream << toString(scalar_type);
}

} // namespace executorch::backends::aoti::slim::c10
3 changes: 3 additions & 0 deletions backends/aoti/slim/c10/core/TARGETS
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
load("targets.bzl", "define_common_targets")

define_common_targets()
52 changes: 52 additions & 0 deletions backends/aoti/slim/c10/core/targets.bzl
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")

def define_common_targets():
"""Define targets for SlimTensor c10 core module."""

# Header-only library for DeviceType
runtime.cxx_library(
name = "device_type",
headers = [
"DeviceType.h",
],
visibility = ["@EXECUTORCH_CLIENTS"],
exported_deps = [
"//executorch/runtime/platform:platform",
],
)

# Header-only library for Device
runtime.cxx_library(
name = "device",
headers = [
"Device.h",
],
visibility = ["@EXECUTORCH_CLIENTS"],
exported_deps = [
":device_type",
"//executorch/runtime/platform:platform",
],
)

# Header-only library for ScalarType
runtime.cxx_library(
name = "scalar_type",
headers = [
"ScalarType.h",
],
visibility = ["@EXECUTORCH_CLIENTS"],
exported_deps = [
"//executorch/runtime/platform:platform",
],
)

# Combined c10 core library
runtime.cxx_library(
name = "core",
visibility = ["@EXECUTORCH_CLIENTS"],
exported_deps = [
":device",
":device_type",
":scalar_type",
],
)
3 changes: 3 additions & 0 deletions backends/aoti/slim/c10/core/test/TARGETS
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
load("targets.bzl", "define_common_targets")

define_common_targets()
25 changes: 25 additions & 0 deletions backends/aoti/slim/c10/core/test/targets.bzl
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")

def define_common_targets():
"""Define test targets for SlimTensor c10 core module."""

runtime.cxx_test(
name = "test_device",
srcs = [
"test_device.cpp",
],
deps = [
"//executorch/backends/aoti/slim/c10/core:device",
"//executorch/backends/aoti/slim/c10/core:device_type",
],
)

runtime.cxx_test(
name = "test_scalar_type",
srcs = [
"test_scalar_type.cpp",
],
deps = [
"//executorch/backends/aoti/slim/c10/core:scalar_type",
],
)
Loading
Loading