From af6810bfa4d8c2fc3ceb84e33ebff31bc59cbd7d Mon Sep 17 00:00:00 2001 From: Scott Roy Date: Mon, 2 Mar 2026 17:46:59 -0800 Subject: [PATCH 1/9] up --- .github/workflows/mlx.yml | 99 ++ .gitignore | 2 + .gitmodules | 4 + CMakeLists.txt | 15 + CMakePresets.json | 85 +- backends/mlx/CMakeLists.txt | 330 ++++ backends/mlx/README.md | 499 ++++++ backends/mlx/__init__.py | 17 + backends/mlx/_logging.py | 40 + backends/mlx/builder/__init__.py | 16 + backends/mlx/builder/op_helpers.py | 275 ++++ backends/mlx/builder/op_registry.py | 151 ++ backends/mlx/builder/pattern_matcher.py | 64 + backends/mlx/builder/program_builder.py | 1018 ++++++++++++ backends/mlx/builder/slot_manager.py | 187 +++ backends/mlx/custom_ops.py | 15 + backends/mlx/ops.py | 294 ++++ backends/mlx/partitioner.py | 298 ++++ backends/mlx/passes.py | 20 + backends/mlx/patches/mlx_json.patch | 29 + backends/mlx/pattern_utils.py | 360 +++++ backends/mlx/patterns.py | 14 + backends/mlx/preprocess.py | 168 ++ backends/mlx/pte_inspector.py | 897 ++++++++++ backends/mlx/runtime/MLXBackend.cpp | 419 +++++ backends/mlx/runtime/MLXExecutor.h | 878 ++++++++++ backends/mlx/runtime/MLXInterpreter.h | 169 ++ backends/mlx/serialization/MLXLoader.cpp.tmpl | 324 ++++ backends/mlx/serialization/MLXLoader.h.tmpl | 343 ++++ backends/mlx/serialization/README.md | 130 ++ backends/mlx/serialization/__init__.py | 32 + backends/mlx/serialization/generate.py | 1437 +++++++++++++++++ .../mlx/serialization/mlx_graph_serialize.py | 416 +++++ backends/mlx/serialization/schema.fbs | 192 +++ backends/mlx/test/CMakeLists.txt | 51 + backends/mlx/test/README.md | 164 ++ backends/mlx/test/__init__.py | 5 + backends/mlx/test/op_test_runner.cpp | 395 +++++ backends/mlx/test/run_all_tests.py | 496 ++++++ backends/mlx/test/strict_compile_test.cpp | 45 + backends/mlx/test/test_ops.py | 176 ++ backends/mlx/test/test_partitioner.py | 45 + backends/mlx/test/test_passes.py | 6 + backends/mlx/test/test_pattern_utils.py | 592 +++++++ 
backends/mlx/test/test_utils.py | 1122 +++++++++++++ backends/mlx/test/tester.py | 78 + backends/mlx/third-party/mlx | 1 + backends/test/suite/flow.py | 11 +- backends/test/suite/flows/mlx.py | 14 + exir/_serialize/_program.py | 67 + setup.py | 33 + tools/cmake/Utils.cmake | 33 + tools/cmake/executorch-config.cmake | 45 + tools/cmake/preset/default.cmake | 1 + tools/cmake/preset/pybind.cmake | 18 + 55 files changed, 12633 insertions(+), 2 deletions(-) create mode 100644 .github/workflows/mlx.yml create mode 100644 backends/mlx/CMakeLists.txt create mode 100644 backends/mlx/README.md create mode 100644 backends/mlx/__init__.py create mode 100644 backends/mlx/_logging.py create mode 100644 backends/mlx/builder/__init__.py create mode 100644 backends/mlx/builder/op_helpers.py create mode 100644 backends/mlx/builder/op_registry.py create mode 100644 backends/mlx/builder/pattern_matcher.py create mode 100644 backends/mlx/builder/program_builder.py create mode 100644 backends/mlx/builder/slot_manager.py create mode 100644 backends/mlx/custom_ops.py create mode 100644 backends/mlx/ops.py create mode 100644 backends/mlx/partitioner.py create mode 100644 backends/mlx/passes.py create mode 100644 backends/mlx/patches/mlx_json.patch create mode 100644 backends/mlx/pattern_utils.py create mode 100644 backends/mlx/patterns.py create mode 100644 backends/mlx/preprocess.py create mode 100644 backends/mlx/pte_inspector.py create mode 100644 backends/mlx/runtime/MLXBackend.cpp create mode 100644 backends/mlx/runtime/MLXExecutor.h create mode 100644 backends/mlx/runtime/MLXInterpreter.h create mode 100644 backends/mlx/serialization/MLXLoader.cpp.tmpl create mode 100644 backends/mlx/serialization/MLXLoader.h.tmpl create mode 100644 backends/mlx/serialization/README.md create mode 100644 backends/mlx/serialization/__init__.py create mode 100755 backends/mlx/serialization/generate.py create mode 100644 backends/mlx/serialization/mlx_graph_serialize.py create mode 100644 
backends/mlx/serialization/schema.fbs create mode 100644 backends/mlx/test/CMakeLists.txt create mode 100644 backends/mlx/test/README.md create mode 100644 backends/mlx/test/__init__.py create mode 100644 backends/mlx/test/op_test_runner.cpp create mode 100644 backends/mlx/test/run_all_tests.py create mode 100644 backends/mlx/test/strict_compile_test.cpp create mode 100644 backends/mlx/test/test_ops.py create mode 100644 backends/mlx/test/test_partitioner.py create mode 100644 backends/mlx/test/test_passes.py create mode 100644 backends/mlx/test/test_pattern_utils.py create mode 100644 backends/mlx/test/test_utils.py create mode 100644 backends/mlx/test/tester.py create mode 160000 backends/mlx/third-party/mlx create mode 100644 backends/test/suite/flows/mlx.py diff --git a/.github/workflows/mlx.yml b/.github/workflows/mlx.yml new file mode 100644 index 00000000000..2e8ca7aa3b7 --- /dev/null +++ b/.github/workflows/mlx.yml @@ -0,0 +1,99 @@ +name: MLX + +on: + push: + branches: + - main + - release/* + pull_request: + paths: + - .github/workflows/mlx.yml + - backends/mlx/** + workflow_dispatch: + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }} + cancel-in-progress: true + +jobs: + test-mlx: + uses: pytorch/test-infra/.github/workflows/macos_job.yml@main + with: + job-name: test-mlx + runner: macos-14-xlarge + python-version: "3.12" + submodules: recursive + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} + timeout: 90 + script: | + set -eux + + echo "::group::Install ExecuTorch and configure build" + ${CONDA_RUN} python install_executorch.py > /dev/null + # The sanitizers fail on github VM runner, but pass on real device + # TODO: figure out why + ${CONDA_RUN} cmake --preset mlx-release -DEXECUTORCH_BUILD_TESTS=ON -DEXECUTORCH_MLX_ENABLE_SANITIZERS=OFF + echo "::endgroup::" + + ${CONDA_RUN} pip list + + echo 
"::group::Build test runners" + ${CONDA_RUN} cmake --build cmake-out --target op_test_runner -j$(( $(sysctl -n hw.ncpu) - 1 )) + echo "::endgroup::" + + echo "::group::Run op unit tests" + ${CONDA_RUN} python -m executorch.backends.mlx.test.run_all_tests -j4 --max-tasks-per-worker 10 --clean-after + echo "::endgroup::" + + echo "::group::Run Python unit tests" + ${CONDA_RUN} python -m pytest \ + backends/mlx/test/test_passes.py \ + backends/mlx/test/test_pattern_utils.py \ + backends/mlx/test/test_partitioner.py \ + -v + echo "::endgroup::" + + backend-tester: + strategy: + fail-fast: false + matrix: + suite: [models, operators] + uses: pytorch/test-infra/.github/workflows/macos_job.yml@main + with: + job-name: test-mlx-backend-${{ matrix.suite }} + runner: macos-14-xlarge + python-version: "3.12" + submodules: recursive + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} + timeout: 120 + script: | + set -eux + + echo "::group::Install ExecuTorch" + ${CONDA_RUN} python install_executorch.py > /dev/null + echo "::endgroup::" + + ${CONDA_RUN} pip list + + echo "::group::Run backend test suite (${{ matrix.suite }})" + ${CONDA_RUN} pytest -c /dev/null backends/test/suite/${{ matrix.suite }}/ -m flow_mlx -n auto 2>&1 | tee pytest_output.txt || true + echo "::endgroup::" + + # Parse pytest summary and check failure threshold + if grep -E "^=+ .* =+$" pytest_output.txt | tail -1 | grep -q "failed"; then + FAILED=$(grep -E "^=+ .* =+$" pytest_output.txt | tail -1 | grep -oE "[0-9]+ failed" | grep -oE "[0-9]+") + else + FAILED=0 + fi + + if [ "${{ matrix.suite }}" = "operators" ]; then + MAX_FAILURES=0 + else + MAX_FAILURES=3 + fi + + echo "Failed tests: $FAILED (max allowed: $MAX_FAILURES)" + if [ "$FAILED" -gt "$MAX_FAILURES" ]; then + echo "::error::Too many test failures: $FAILED > $MAX_FAILURES" + exit 1 + fi diff --git a/.gitignore b/.gitignore index 4ddbb7c49ad..3453b7e9676 100644 --- a/.gitignore +++ b/.gitignore @@ 
-74,5 +74,7 @@ xcuserdata/ *.dll *.pyd + # Agents .claude/*.local.* +extension/pybindings/mlx.metallib diff --git a/.gitmodules b/.gitmodules index 1f202d4fdec..917e755da27 100644 --- a/.gitmodules +++ b/.gitmodules @@ -67,3 +67,7 @@ [submodule "third-party/json"] path = third-party/json url = https://github.com/nlohmann/json.git +[submodule "backends/mlx/third-party/mlx"] + path = backends/mlx/third-party/mlx + url = https://github.com/ml-explore/mlx.git + shallow = true diff --git a/CMakeLists.txt b/CMakeLists.txt index 995a75c342b..2297a8142f3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -659,6 +659,11 @@ if(EXECUTORCH_BUILD_MPS) list(APPEND _executorch_backends mpsdelegate) endif() +if(EXECUTORCH_BUILD_MLX) + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backends/mlx) + list(APPEND _executorch_backends mlxdelegate) +endif() + if(EXECUTORCH_BUILD_NEURON) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backends/mediatek) list(APPEND _executorch_backends neuron_backend) @@ -956,6 +961,10 @@ if(EXECUTORCH_BUILD_PYBIND) list(APPEND _dep_libs mpsdelegate) endif() + if(EXECUTORCH_BUILD_MLX) + list(APPEND _dep_libs mlxdelegate) + endif() + if(EXECUTORCH_BUILD_OPENVINO) list(APPEND _dep_libs openvino_backend) endif() @@ -1056,6 +1065,12 @@ if(EXECUTORCH_BUILD_PYBIND) install(TARGETS data_loader LIBRARY DESTINATION executorch/extension/pybindings ) + + # Copy MLX metallib next to _portable_lib.so for editable installs. MLX uses + # dladdr() to find the directory containing the library with MLX code, then + # looks for mlx.metallib in that directory. When MLX is statically linked into + # _portable_lib.so, we need the metallib colocated with it. 
+ executorch_target_copy_mlx_metallib(portable_lib) endif() if(EXECUTORCH_BUILD_WASM) diff --git a/CMakePresets.json b/CMakePresets.json index ca4da226ba1..fa1d77623d9 100644 --- a/CMakePresets.json +++ b/CMakePresets.json @@ -110,7 +110,7 @@ "inherits": ["common"], "cacheVariables": { "EXECUTORCH_BUILD_PRESET_FILE": "${sourceDir}/tools/cmake/preset/pybind.cmake", - "CMAKE_OSX_DEPLOYMENT_TARGET": "12.0" + "CMAKE_OSX_DEPLOYMENT_TARGET": "14.0" }, "condition": { "type": "inList", @@ -294,6 +294,43 @@ "EXECUTORCH_BUILD_PRESET_FILE": "${sourceDir}/tools/cmake/preset/arm_ethosu_linux.cmake", "CMAKE_TOOLCHAIN_FILE": "${sourceDir}/examples/arm/ethos-u-setup/aarch64-linux-musl-toolchain.cmake" } + }, + { + "name": "mlx", + "displayName": "Build MLX delegate", + "inherits": ["common"], + "cacheVariables": { + "EXECUTORCH_BUILD_PRESET_FILE": "${sourceDir}/tools/cmake/preset/mlx.cmake", + "EXECUTORCH_ENABLE_LOGGING": "ON", + "CMAKE_OSX_DEPLOYMENT_TARGET": "14.0" + }, + "condition": { + "lhs": "${hostSystemName}", + "type": "equals", + "rhs": "Darwin" + } + }, + { + "name": "mlx-release", + "displayName": "MLX delegate release build", + "inherits": ["mlx"], + "cacheVariables": { + "CMAKE_BUILD_TYPE": "Release", + "CMAKE_INSTALL_PREFIX": "${sourceDir}/cmake-out", + "ET_MLX_ENABLE_OP_LOGGING": "OFF", + "ET_MIN_LOG_LEVEL": "Error" + } + }, + { + "name": "mlx-debug", + "displayName": "MLX delegate debug build with op logging", + "inherits": ["mlx"], + "cacheVariables": { + "CMAKE_BUILD_TYPE": "Debug", + "CMAKE_INSTALL_PREFIX": "${sourceDir}/cmake-out", + "ET_MLX_ENABLE_OP_LOGGING": "ON", + "ET_MIN_LOG_LEVEL": "Debug" + } } ], "buildPresets": [ @@ -362,6 +399,24 @@ "install" ], "jobs": 0 + }, + { + "name": "mlx-release-install", + "displayName": "Build and install MLX delegate release artifacts", + "configurePreset": "mlx-release", + "targets": [ + "install" + ], + "jobs": 0 + }, + { + "name": "mlx-debug-install", + "displayName": "Build and install MLX delegate debug artifacts", + 
"configurePreset": "mlx-debug", + "targets": [ + "install" + ], + "jobs": 0 } ], "workflowPresets": [ @@ -462,6 +517,34 @@ "name": "llm-metal-stats-install" } ] + }, + { + "name": "mlx-release", + "displayName": "Configure, build and install ExecuTorch MLX delegate", + "steps": [ + { + "type": "configure", + "name": "mlx-release" + }, + { + "type": "build", + "name": "mlx-release-install" + } + ] + }, + { + "name": "mlx-debug", + "displayName": "Configure, build and install ExecuTorch MLX delegate with op logging (Debug)", + "steps": [ + { + "type": "configure", + "name": "mlx-debug" + }, + { + "type": "build", + "name": "mlx-debug-install" + } + ] } ] } diff --git a/backends/mlx/CMakeLists.txt b/backends/mlx/CMakeLists.txt new file mode 100644 index 00000000000..00e7c497b1c --- /dev/null +++ b/backends/mlx/CMakeLists.txt @@ -0,0 +1,330 @@ +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# + +cmake_minimum_required(VERSION 3.19) + +set(CMAKE_EXPORT_COMPILE_COMMANDS ON) + +if(NOT CMAKE_CXX_STANDARD) + set(CMAKE_CXX_STANDARD 17) +endif() + +# Source root directory for executorch. +if(NOT EXECUTORCH_ROOT) + set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../..) +endif() + +include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake) + +set(_common_compile_options -Wall -Werror -Wno-deprecated-declarations) + +# Sanitizer flags (asan + ubsan) for security hardening — CI only. 
Enable via: +# cmake --preset mlx-release -DEXECUTORCH_MLX_ENABLE_SANITIZERS=ON +option(EXECUTORCH_MLX_ENABLE_SANITIZERS + "Enable ASan + UBSan for MLX delegate and tests" OFF +) +if(EXECUTORCH_MLX_ENABLE_SANITIZERS) + list(APPEND _common_compile_options -fsanitize=address,undefined + -fno-omit-frame-pointer + ) + set(_mlx_sanitizer_link_options -fsanitize=address,undefined) +endif() + +# ----------------------------------------------------------------------------- +# Code generation from schema.fbs +# ----------------------------------------------------------------------------- +# +# The generate.py script generates all code from schema.fbs: Python: +# mlx_graph_schema.py, _generated_serializers.py, _generated/ C++: MLXLoader.h, +# MLXLoader.cpp, schema_generated.h +# +# We run generate.py at build time so these files don't need to be checked in. +# ----------------------------------------------------------------------------- + +set(_mlx_generate_script + "${CMAKE_CURRENT_SOURCE_DIR}/serialization/generate.py" +) +set(_mlx_schema_fbs "${CMAKE_CURRENT_SOURCE_DIR}/serialization/schema.fbs") + +# Generated C++ files that we need for compilation +set(_mlx_generated_cpp_files + "${CMAKE_CURRENT_SOURCE_DIR}/runtime/schema_generated.h" + "${CMAKE_CURRENT_SOURCE_DIR}/runtime/MLXLoader.h" + "${CMAKE_CURRENT_SOURCE_DIR}/runtime/MLXLoader.cpp" +) + +# Generated Python files (tracked for dependency purposes) +set(_mlx_generated_python_files + "${CMAKE_CURRENT_SOURCE_DIR}/serialization/mlx_graph_schema.py" + "${CMAKE_CURRENT_SOURCE_DIR}/serialization/_generated_serializers.py" +) + +# Run generate.py to create all generated files from schema.fbs Find Python - +# prefer Python3_EXECUTABLE if set (from FindPython3), otherwise use +# PYTHON_EXECUTABLE +if(Python3_EXECUTABLE) + set(_python_executable ${Python3_EXECUTABLE}) +elseif(PYTHON_EXECUTABLE) + set(_python_executable ${PYTHON_EXECUTABLE}) +else() + find_package( + Python3 + COMPONENTS Interpreter + REQUIRED + ) + 
set(_python_executable ${Python3_EXECUTABLE}) +endif() + +add_custom_command( + OUTPUT ${_mlx_generated_cpp_files} ${_mlx_generated_python_files} + COMMAND ${_python_executable} ${_mlx_generate_script} --flatc + $ + WORKING_DIRECTORY ${EXECUTORCH_ROOT} + DEPENDS ${_mlx_schema_fbs} ${_mlx_generate_script} flatc + COMMENT "Generating MLX delegate code from schema.fbs" + VERBATIM +) + +# Custom target to trigger generation +add_custom_target( + mlx_generate_code DEPENDS ${_mlx_generated_cpp_files} + ${_mlx_generated_python_files} +) + +# Interface library for schema includes +add_library(mlx_schema INTERFACE) +add_dependencies(mlx_schema mlx_generate_code) +target_include_directories( + mlx_schema + INTERFACE + $ + $ +) + +# ----------------------------------------------------------------------------- +# MLX dependency (from submodule) +# ----------------------------------------------------------------------------- + +set(MLX_SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/third-party/mlx) + +# Check that submodule is initialized +if(NOT EXISTS "${MLX_SOURCE_DIR}/CMakeLists.txt") + message( + FATAL_ERROR "MLX submodule not initialized.\n" + "Run: git submodule update --init backends/mlx/third-party/mlx" + ) +endif() + +# Validate deployment target - MLX requires macOS 14.0+ / iOS 17.0+ +# +# The macOS preset uses ios.toolchain.cmake (with PLATFORM=MAC_ARM64), so +# DEPLOYMENT_TARGET is set for both macOS and iOS builds. We check PLATFORM to +# distinguish them rather than relying on which variable is set. 
+set(_mlx_deployment_target_ok TRUE) +if(PLATFORM AND PLATFORM MATCHES "^MAC") + # macOS build via ios.toolchain.cmake (e.g., MAC_ARM64, MAC_UNIVERSAL) + if(DEPLOYMENT_TARGET) + set(_mlx_dt_value ${DEPLOYMENT_TARGET}) + elseif(CMAKE_OSX_DEPLOYMENT_TARGET) + set(_mlx_dt_value ${CMAKE_OSX_DEPLOYMENT_TARGET}) + endif() + if(_mlx_dt_value AND _mlx_dt_value VERSION_LESS "14.0") + set(_mlx_deployment_target_ok FALSE) + set(_mlx_deployment_target_value ${_mlx_dt_value}) + set(_mlx_deployment_target_min "14.0") + endif() +elseif(DEPLOYMENT_TARGET) + # iOS/tvOS/watchOS/visionOS builds via ios.toolchain.cmake + if(DEPLOYMENT_TARGET VERSION_LESS "17.0") + set(_mlx_deployment_target_ok FALSE) + set(_mlx_deployment_target_value ${DEPLOYMENT_TARGET}) + set(_mlx_deployment_target_min "17.0") + endif() +elseif(CMAKE_OSX_DEPLOYMENT_TARGET) + # Plain macOS build (no ios.toolchain.cmake) + if(CMAKE_OSX_DEPLOYMENT_TARGET VERSION_LESS "14.0") + set(_mlx_deployment_target_ok FALSE) + set(_mlx_deployment_target_value ${CMAKE_OSX_DEPLOYMENT_TARGET}) + set(_mlx_deployment_target_min "14.0") + endif() +endif() + +if(NOT _mlx_deployment_target_ok) + message( + FATAL_ERROR + "MLX requires deployment target >= ${_mlx_deployment_target_min}, got ${_mlx_deployment_target_value}.\n" + "Either increase the deployment target or disable MLX with -DEXECUTORCH_BUILD_MLX=OFF" + ) +endif() + +# MLX build options - we only need the C++ library with Metal +set(MLX_BUILD_PYTHON_BINDINGS + OFF + CACHE BOOL "" FORCE +) +set(MLX_BUILD_TESTS + OFF + CACHE BOOL "" FORCE +) +set(MLX_BUILD_EXAMPLES + OFF + CACHE BOOL "" FORCE +) +set(MLX_BUILD_BENCHMARKS + OFF + CACHE BOOL "" FORCE +) +set(MLX_BUILD_PYTHON_STUBS + OFF + CACHE BOOL "" FORCE +) +set(MLX_BUILD_CUDA + OFF + CACHE BOOL "" FORCE +) +set(MLX_BUILD_CPU + OFF + CACHE BOOL "" FORCE +) +set(MLX_BUILD_METAL + ON + CACHE BOOL "" FORCE +) +set(MLX_BUILD_SHARED_LIBS + OFF + CACHE BOOL "" FORCE +) +set(MLX_BUILD_GGUF + OFF + CACHE BOOL "" FORCE +) 
+set(MLX_BUILD_SAFETENSORS + OFF + CACHE BOOL "" FORCE +) +set(MLX_METAL_JIT + ON + CACHE BOOL "Use JIT compiled Metal kernels" +) + +# Auto-apply patches to MLX submodule. Each patch is applied idempotently: `git +# apply --check` tests whether the patch is still applicable (i.e. not yet +# applied), and only then applies it. +set(_mlx_patches "${CMAKE_CURRENT_SOURCE_DIR}/patches/mlx_json.patch") +foreach(_patch IN LISTS _mlx_patches) + if(EXISTS "${_patch}" AND EXISTS "${MLX_SOURCE_DIR}") + get_filename_component(_patch_name "${_patch}" NAME) + execute_process( + COMMAND git apply --check "${_patch}" + WORKING_DIRECTORY ${MLX_SOURCE_DIR} + RESULT_VARIABLE _patch_check + OUTPUT_QUIET ERROR_QUIET + ) + if(_patch_check EQUAL 0) + execute_process( + COMMAND git apply "${_patch}" WORKING_DIRECTORY ${MLX_SOURCE_DIR} + ) + message(STATUS "Applied ${_patch_name} to MLX submodule") + else() + message(STATUS "${_patch_name} already applied or not applicable") + endif() + endif() +endforeach() + +# Add MLX subdirectory +message(STATUS "Adding MLX from submodule: ${MLX_SOURCE_DIR}") +add_subdirectory(${MLX_SOURCE_DIR} ${CMAKE_CURRENT_BINARY_DIR}/mlx) + +# ----------------------------------------------------------------------------- +# MLX Backend library +# ----------------------------------------------------------------------------- + +# Op logging option (for debugging) - OFF by default for performance +option(ET_MLX_ENABLE_OP_LOGGING "Enable per-op logging in MLX delegate" OFF) + +set(_mlx_backend__srcs ${CMAKE_CURRENT_SOURCE_DIR}/runtime/MLXLoader.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/runtime/MLXBackend.cpp +) + +add_library(mlxdelegate ${_mlx_backend__srcs}) + +# Ensure schema is generated before compiling +add_dependencies(mlxdelegate mlx_schema) + +# Add logging flag if enabled +if(ET_MLX_ENABLE_OP_LOGGING) + target_compile_definitions(mlxdelegate PRIVATE ET_MLX_ENABLE_OP_LOGGING=1) + message(STATUS "MLX delegate op logging ENABLED") +endif() + 
+target_include_directories( + mlxdelegate PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/runtime +) + +# Link against MLX and executorch mlx is only available at BUILD_INTERFACE - +# consumers must link to mlx separately +target_link_libraries( + mlxdelegate PRIVATE mlx_schema executorch_core $ +) + +executorch_target_link_options_shared_lib(mlxdelegate) +target_compile_options(mlxdelegate PRIVATE ${_common_compile_options}) +if(EXECUTORCH_MLX_ENABLE_SANITIZERS) + target_link_options(mlxdelegate PRIVATE ${_mlx_sanitizer_link_options}) +endif() + +install( + TARGETS mlxdelegate mlx_schema + EXPORT ExecuTorchTargets + DESTINATION ${CMAKE_INSTALL_LIBDIR} +) + +# Install mlx library for downstream consumers +install(TARGETS mlx DESTINATION ${CMAKE_INSTALL_LIBDIR}) + +# Install mlx headers for downstream consumers that may need mlx types +install( + DIRECTORY ${MLX_SOURCE_DIR}/mlx/ + DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/mlx + FILES_MATCHING + PATTERN "*.h" +) + +# Install mlx.metallib (Metal GPU kernels) for runtime execution +# +# MLX searches for metallib in this order (see mlx/backend/metal/device.cpp): 1. +# {binary_dir}/mlx.metallib - colocated with the .so/.dylib 2. +# {binary_dir}/Resources/mlx/ - Resources subdirectory 3. SwiftPM bundle - +# not applicable for us 4. {binary_dir}/Resources/default/ - Resources +# subdirectory 5. METAL_PATH (compile-time) - hardcoded build path (won't +# exist) +# +# where {binary_dir} is determined at runtime via dladdr() on the library +# containing MLX code. When MLX is statically linked into _portable_lib.so, this +# is the directory containing _portable_lib.so. 
+# +# For the installed library, we put metallib in lib/ alongside libmlx.a +install( + FILES ${CMAKE_CURRENT_BINARY_DIR}/mlx/mlx/backend/metal/kernels/mlx.metallib + DESTINATION ${CMAKE_INSTALL_LIBDIR} +) + +# Cache the metallib path for pybindings to copy it next to _portable_lib.so +# This enables editable installs to work correctly +set(MLX_METALLIB_PATH + "${CMAKE_CURRENT_BINARY_DIR}/mlx/mlx/backend/metal/kernels/mlx.metallib" + CACHE INTERNAL "Path to mlx.metallib for pybindings" +) + +# ----------------------------------------------------------------------------- +# Tests (off by default; CI passes -DEXECUTORCH_BUILD_TESTS=ON) +# ----------------------------------------------------------------------------- + +if(EXECUTORCH_BUILD_TESTS) + add_subdirectory(test) +endif() diff --git a/backends/mlx/README.md b/backends/mlx/README.md new file mode 100644 index 00000000000..ebab893385a --- /dev/null +++ b/backends/mlx/README.md @@ -0,0 +1,499 @@ +# MLX Delegate for ExecuTorch + +The MLX delegate compiles PyTorch models to run on Apple Silicon GPUs via the +[MLX](https://github.com/ml-explore/mlx) framework. It consists of: + +- A **Python compilation pipeline** that converts ExportedPrograms (Edge IR) into + a custom FlatBuffer bytecode format. +- A **C++ runtime** that loads the bytecode and executes it using MLX GPU + primitives. + +> **Adding a new op?** Jump to [How to Add a New Op](#how-to-add-a-new-op). + +## Getting Started + +The MLX delegate requires **Apple Silicon** (M1 or later) and the **Metal +compiler**, which ships with Xcode (not the standalone Command Line Tools). + +**Check if Metal is available:** + +```bash +xcrun -sdk macosx --find metal +``` + +If this prints a path (e.g. `/Applications/Xcode.app/.../metal`), you're set. 
+If it errors, you either need to install Xcode from the +[App Store](https://apps.apple.com/us/app/xcode/id497799835) or +, or — if Xcode is already installed but the +command line developer directory points at Command Line Tools — switch it: + +```bash +sudo xcode-select -s /Applications/Xcode.app/Contents/Developer +``` + +### Python (pybindings) + +The simplest way to get started is to install ExecuTorch with Python bindings. +From the repo root: + +```bash +python install_executorch.py +``` + +This builds and installs the `executorch` pip package with pybindings. On Apple +Silicon, when the Metal compiler is available, the MLX backend is automatically +included. You can then export models in Python using the MLX partitioner and run +them via the ExecuTorch Python API. + +### C++ (CMake preset) + +To build the C++ runtime with the MLX delegate, use the `mlx-release` CMake +workflow preset from the repo root: + +```bash +cmake --workflow --preset mlx-release +``` + +This configures and builds a Release build of the ExecuTorch runtime with the +MLX delegate and installs artifacts into `cmake-out/`. The preset enables the +MLX delegate along with commonly needed extensions (module, data loader, flat +tensor, LLM runner, etc.). + +Downstream C++ apps can then `find_package(executorch)` and link against +`mlxdelegate` and `mlx`. See +[`examples/models/llama/CMakeLists.txt`](../../examples/models/llama/CMakeLists.txt) +for a working example. 
+ +There is also an `mlx-debug` preset that enables debug symbols and compiles in +per-op logging support, which is useful during development: + +```bash +cmake --workflow --preset mlx-debug +``` + +The debug build compiles in the logging code, but to actually see per-op output +you must also set the environment variable when running the binary: + +```bash +ET_MLX_ENABLE_OP_LOGGING=1 ./cmake-out/my_app +``` + +### Debugging + +Set `ET_MLX_DEBUG=1` during AOT (export/compilation) to see detailed debug +logging from the partitioner and preprocessor — including ops-to-not-decompose +lists, graph dumps, per-node support decisions, and serialization details: + +```bash +ET_MLX_DEBUG=1 python -m executorch.backends.mlx.examples.llm.export_llm_hf ... +``` + +--- + +## Directory Layout + +``` +backends/mlx/ +├── serialization/ # Schema + code generation +│ ├── schema.fbs # ← Source of truth (FlatBuffer schema) +│ ├── generate.py # Code generator (schema.fbs → everything else) +│ ├── mlx_graph_schema.py # [GENERATED] Python dataclasses for IR nodes +│ ├── mlx_graph_serialize.py # Serialization to FlatBuffer binary +│ ├── _generated_serializers.py # [GENERATED] Per-op FlatBuffer builders +│ └── _generated/ # [GENERATED] FlatBuffer Python bindings (flatc) +├── runtime/ # C++ runtime (loaded at inference time) +│ ├── MLXBackend.cpp # BackendInterface (init / execute / destroy) +│ ├── MLXLoader.h/.cpp # [GENERATED] FlatBuffer → C++ structs +│ ├── MLXExecutor.h # ExecutionState, constant loading, helpers +│ ├── MLXInterpreter.h # Op dispatch loop + per-op exec_* functions +│ └── schema_generated.h # [GENERATED] FlatBuffer C++ bindings (flatc) +├── llm/ # LLM infrastructure (KV cache, attention, etc.) 
+│ ├── cache.py # KV cache implementations (ET + HF static cache) +│ ├── et_attention.py # ExecuTorch custom SDPA attention +│ ├── hf_attention.py # HuggingFace custom SDPA attention +│ ├── quantization.py # TorchAO quantization helpers +│ └── source_transformation.py # Source transforms for MLX export +├── _generated_inspector.py # [GENERATED] Inspector utilities for .pte debugging +├── _logging.py # Debug logging utilities (ET_MLX_DEBUG) +├── builder/ # Core build infrastructure +│ ├── op_registry.py # REGISTRY (op handler registration) +│ ├── op_helpers.py # Helper utilities for op handlers +│ ├── pattern_matcher.py # Pattern matching for multi-node fusions +│ ├── program_builder.py # MLXProgramBuilder +│ └── slot_manager.py # Tensor/value slot allocation +├── ops.py # Op handlers (ATen target → MLX IR node) +├── patterns.py # Pattern handlers (multi-node fusions) +├── passes.py # Graph passes (RMSNorm fusion, CSE, etc.) +├── pattern_utils.py # Pattern matching utilities for passes +├── partitioner.py # Decides which ops to delegate to MLX +├── preprocess.py # BackendDetails.preprocess() entry point +├── custom_ops.py # Custom torch ops (kv_cache_update, custom_sdpa, rope) +├── pte_inspector.py # .pte file inspection/debugging tool +├── test/ +│ ├── test_ops.py # Op test definitions (models + configs) +│ ├── test_utils.py # OpTestCase base class + helpers +│ ├── op_test_runner.cpp # C++ test runner (loads .pte, runs, compares) +│ └── run_all_tests.py # End-to-end: export → C++ run → compare +└── examples/ + ├── llm/ # LLM export + run via HuggingFace + └── whisper/ # Whisper export + run +``` + +Files marked **[GENERATED]** are NOT CHECKED IN CODE and are produced by running: + +```bash +python backends/mlx/serialization/generate.py +``` + +--- + +## Compilation Pipeline + +The compilation pipeline converts a PyTorch model into a `.pte` file containing +the MLX delegate payload. 
The high-level flow: + +``` +torch.export() → ExportedProgram (ATen IR) +to_edge_transform_and_lower() → Edge IR + partitioning + lowering +``` + +Within that flow, the MLX-specific steps are: + +1. **Partitioning** (`partitioner.py`) — `MLXPartitioner` walks the Edge IR + graph and tags nodes that MLX can handle. It uses `MLXProgramBuilder` in a + dry-run mode to determine support — so partitioning and compilation use the + exact same logic. Unsupported ops fall back to ExecuTorch's portable + runtime. + +2. **Preprocessing** (`preprocess.py`) — For each partitioned subgraph, + `MLXBackend.preprocess()` is called. It builds an `MLXGraph` via + `MLXProgramBuilder`, serializes it to FlatBuffer, and returns a + `PreprocessResult` with the binary payload and constant data. + +3. **Op handling** (`ops.py`, `patterns.py`) — During the build, + `MLXProgramBuilder` walks the FX graph node-by-node and dispatches to + registered handlers. Single-op handlers live in `ops.py`; multi-node fused + patterns (e.g., quantized linear, SDPA, KV cache update) live in + `patterns.py`. + +4. **Serialization** (`serialization/`) — The `MLXGraph` dataclass tree is + serialized to a FlatBuffer binary. See [Serialization](#serialization) below. + +The complete preprocessing flow: + +``` +ExportedProgram (subgraph) + → MLXProgramBuilder.build() # walks FX graph, calls op handlers + → MLXGraph # Python IR (dataclasses from mlx_graph_schema.py) + → MLXGraphSerializer.serialize() # FlatBuffer binary + → PreprocessResult # returned to ExecuTorch +``` + +--- + +## How to Add a New Op + +This section walks through adding a new op end-to-end, using **`aten.linear`** +as an example. + +### Step 1: Add the Node to `schema.fbs` + +Add a new table in the "Op nodes" section and add it to the `OpNode` union: + +```fbs +table LinearNode { + x: Tid (required); + weight: Tid (required); + out: Tid (required); + bias: Tid; // optional +} +``` + +Then add `LinearNode` to the `union OpNode { ... }` list. 
+ +### Step 2: Run the Code Generator + +```bash +python backends/mlx/serialization/generate.py +``` + +This regenerates: + +- `mlx_graph_schema.py` — adds `LinearNode` Python dataclass +- `_generated_serializers.py` — adds `_build_LinearNode` serializer +- `runtime/MLXLoader.h` — adds `LinearNode` C++ struct, `OpCode::LINEAR`, loader +- `runtime/MLXLoader.cpp` — adds FlatBuffer → `LinearNode` deserialization +- `runtime/schema_generated.h` — FlatBuffer C++ bindings + +### Step 3: Add the Python Op Handler (`ops.py`) + +Register a handler that converts the ATen op to your new node. Make sure to +import `LinearNode` from `mlx_graph_schema`: + +```python +from executorch.backends.mlx.serialization.mlx_graph_schema import LinearNode + +@REGISTRY.register(target=[torch.ops.aten.linear.default]) +def _linear_handler(P: MLXProgramBuilder, n: Node) -> Slot: + args = P.args(n) + require_args(args, 2, 3, "aten.linear") + require_kwargs(P.kwargs(n), set(), "aten.linear") + x, w = args[0], args[1] + b = args[2] if len(args) > 2 else None + out = P.make_or_get_slot(n) + P.emit( + LinearNode( + x=P.slot_to_tid(x), + weight=P.slot_to_tid(w), + out=P.slot_to_tid(out), + bias=P.slot_to_tid(b) if b else None, + ) + ) + return out +``` + +Key APIs: +- **`P.args(n)`** — resolves FX node args to `Slot` objects (tensor/value references) +- **`P.make_or_get_slot(n)`** — allocates the output tensor slot +- **`P.slot_to_tid(slot)`** — converts a `Slot` to a `Tid` for the IR node +- **`P.emit(node)`** — appends the instruction to the graph + +### Step 4: Add the C++ Op Handler (`MLXInterpreter.h`) + +Add an `exec_*` function in the `ops` namespace: + +```cpp +inline void exec_linear(const LinearNode& n, ExecutionState& st, StreamOrDevice s) { + const auto& X = st.const_tensor_ref(n.x); + auto W = transpose(st.const_tensor_ref(n.weight), {1, 0}, s); + array Y = n.bias + ? 
addmm(st.const_tensor_ref(*n.bias), X, W, 1.0f, 1.0f, s) + : matmul(X, W, s); + st.set_tensor(n.out, std::move(Y)); +} +``` + +Then add the dispatch case in `Interpreter::execute_instruction()`: + +```cpp +case OpCode::LINEAR: + ops::exec_linear(std::get<LinearNode>(instr.node), st, s); + break; +``` + +### Step 5: Write a Test (`test/test_ops.py`) + +Each test follows a standard pattern: + +1. **Define a `nn.Module`** that uses the op. +2. **Define an `OpTestCase` subclass** that specifies test configurations. +3. **Decorate with `@register_test`** to register it with the test runner. + +```python +class LinearModel(nn.Module): + def __init__(self, in_features=64, out_features=128, bias=True): + super().__init__() + self.linear = nn.Linear(in_features, out_features, bias=bias) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.linear(x) + +@register_test +class LinearTest(OpTestCase): + name = "linear" + rtol = 1e-4 + atol = 1e-4 + + def __init__(self, in_features=64, out_features=128, bias=True): + self.in_features = in_features + self.out_features = out_features + self.bias = bias + + @classmethod + def get_test_configs(cls): + return [cls(), cls(bias=False)] + + def create_model(self): + return LinearModel(self.in_features, self.out_features, bias=self.bias) + + def create_inputs(self): + return (torch.randn(2, 16, self.in_features),) +``` + +### Step 6: Run Tests + +Tests are end-to-end: export `.pte` → run via C++ `op_test_runner` → compare +outputs against PyTorch reference.
Since adding a new op always involves C++ +changes, use `--rebuild` to recompile the runtime: + +```bash +python -m executorch.backends.mlx.test.run_all_tests --rebuild linear +``` + +Run all tests in parallel: + +```bash +python -m executorch.backends.mlx.test.run_all_tests --rebuild -j4 --clean-after +``` + +Other useful flags: + +| Flag | Purpose | +|---|---| +| `--rebuild` | Rebuild the C++ `op_test_runner` before running | +| `-j N` / `--parallel N` | Run N tests in parallel | +| `--clean-after` | Remove generated test artifacts after running | +| `--list` | List all available test names and exit | +| `-v` / `--verbose` | Verbose output | + +Test artifacts are saved to `test/op_tests/<test_name>/` (`.pte`, input/output +`.bin` files). See [`test/README.md`](test/README.md) for full details on test +architecture, prerequisites, and the `OpTestCase` API. + +### Checklist + +- [ ] Add `*Node` table to `schema.fbs` + add to `OpNode` union +- [ ] Run `python backends/mlx/serialization/generate.py` +- [ ] Add `@REGISTRY.register` handler in `ops.py` (and import the new node class) +- [ ] Add `exec_*` function in `runtime/MLXInterpreter.h` +- [ ] Add `case OpCode::*` in `Interpreter::execute_instruction()` +- [ ] Add test model + `OpTestCase` in `test/test_ops.py` +- [ ] Run `python -m executorch.backends.mlx.test.run_all_tests --rebuild <test_name>` + +--- + +## Serialization + +### Overview + +The serialization system converts a Python `MLXGraph` dataclass tree into a +FlatBuffer binary that the C++ runtime can load. The source of truth is +**`schema.fbs`** — a single FlatBuffer schema file from which all code on both +sides is generated.
+ +### Schema (`schema.fbs`) + +The schema defines: + +| Concept | FlatBuffer type | Purpose | +|---|---|---| +| **`Tid`** | struct | Tensor slot index (indexes into the runtime tensor array) | +| **`Vid`** | struct | Value slot index (for scalar `int32`/`float`/`bool` values) | +| **`IntOrVid`** | table | A field that is either a literal `int64` or a runtime `Vid` reference (for dynamic shapes) | +| **`FloatOrVid`** | table | Same idea for floats | +| **`TidOrVid`** | table | Either a tensor or a scalar value | +| **Op node tables** | table | One per op (e.g. `AddNode`, `SiluNode`, `ReshapeNode`). Each declares its inputs/outputs as `Tid`/`Vid` references and any scalar parameters. | +| **`OpNode`** | union | Union of all op node tables | +| **`Instruction`** | table | Wraps an `OpNode` union | +| **`MLXGraph`** | table (root) | The complete program: slot counts, instruction list, I/O maps, named slots, tensor metadata | + +Key design points: + +- **No embedded weights.** Constants are stored in ExecuTorch's `named_data_map` + and loaded by name at runtime. This enables zero-copy on unified memory. +- **Tensor IDs (`Tid`) are globally ordered:** Constants → Inputs → Outputs → + Mutable Buffers → Temps. The runtime uses this ordering for O(1) type lookup. +- **Dynamic shapes** are supported via `IntOrVid` — a shape dimension can be + either a literal integer or a reference to a runtime value produced by + `sym_size` / `item()` ops. + +### Code Generation (`generate.py`) + +`generate.py` parses `schema.fbs` and generates **all** boilerplate on both the +Python and C++ sides: + +| Generated file | What it contains | +|---|---| +| `mlx_graph_schema.py` | Python `@dataclass` for every op node, `Tid`, `Vid`, `IntOrVid`, etc. 
| +| `_generated_serializers.py` | `GeneratedOpBuilders` mixin class with `_build_*Node` methods for every op | +| `_generated_inspector.py` | Inspector utilities for debugging `.pte` files | +| `runtime/MLXLoader.h` | C++ structs for every op node, `OpCode` enum, `NodeVariant`, `Instruction`, `MLXProgram` | +| `runtime/MLXLoader.cpp` | `load_instruction()` and `load_program()` — FlatBuffer → C++ struct conversion | +| `runtime/schema_generated.h` | Standard FlatBuffer C++ bindings (via `flatc`) | +| `_generated/` directory | Standard FlatBuffer Python bindings (via `flatc`) | + +Running the generator: + +```bash +python backends/mlx/serialization/generate.py +``` + +Use `--skip-flatc` if you only changed op node definitions (not core types) and +want to skip the `flatc` invocation. + +### Serialization Format + +The binary payload embedded in the `.pte` file has this layout: + +``` +[Header: 24 bytes] + 4 bytes padding (zeros) + 4 bytes magic ("MLX0") + 8 bytes data_segment_offset (uint64 LE) + 8 bytes data_segment_size (uint64 LE) +[FlatBuffer payload] +[Padding to 16-byte alignment] +[Data segment (currently unused — constants go via named_data_map)] +``` + +The `MLXGraphSerializer` class (in `mlx_graph_serialize.py`) drives +serialization. It inherits `GeneratedOpBuilders` for the per-op builders and +adds the root-table construction, I/O maps, tensor metadata, and header. + +--- + +## Runtime + +### Initialization (`init`) + +When ExecuTorch loads a `.pte` with an MLX delegate blob, `MLXBackend::init()` +is called: + +1. **Parse FlatBuffer** — `loader::load_program()` deserializes the binary into + an `MLXProgram` struct (C++ mirrors of the schema). +2. **Load constants** — Iterates `named_slots`, calls + `named_data_map->get_data(name)` for each constant tensor, wraps the buffer + as an `mlx::core::array` (zero-copy when possible on unified memory). +3. **Initialize mutable buffers** — Creates zero-filled MLX arrays for + persistent state (e.g., KV cache). 
These live across `execute()` calls. +4. **Bind execution state** — `ExecutionState::bind()` pre-computes tensor ID + ranges for O(1) routing. + +### Execution (`execute`) + +Each `execute()` call: + +1. **Reset** per-execution state (inputs/outputs/temps cleared; mutable buffers + and constants are retained). +2. **Bind inputs** — Walk `input_map`, convert each ExecuTorch tensor to an + `mlx::core::array` (zero-copy pointer wrap). +3. **Run instructions** — `Interpreter::run()` dispatches each `Instruction` + through a `switch` on `OpCode`, calling the corresponding `exec_*` function. +4. **Evaluate** — Call `mlx::core::eval()` on output tensors to trigger + lazy GPU computation. +5. **Copy outputs** — Convert MLX arrays back to ExecuTorch tensors via + `memcpy`. + +### Tensor ID Layout + +Tensor slot IDs are assigned in a fixed order during compilation: + +``` + ┌──────────┬──────────┬──────────┬────────────────┬──────────┐ + │ Constants│ Inputs │ Outputs │ Mutable Buffers│ Temps │ + │ 0..C-1 │ C..I-1 │ I..O-1 │ O..M-1 │ M..T-1 │ + └──────────┴──────────┴──────────┴────────────────┴──────────┘ +``` + +The runtime stores constants and mutable buffers in separate containers +(`ConstantData`, `MutableBufferData`). Inputs, outputs, and temps share a flat +`vector>` in `ExecutionState`. + +### Key Runtime Files + +| File | Role | +|---|---| +| `MLXBackend.cpp` | `init()` / `execute()` / `destroy()` — the ExecuTorch `BackendInterface` | +| `MLXLoader.h/.cpp` | [GENERATED] Deserializes FlatBuffer into `MLXProgram` (C++ structs) | +| `MLXExecutor.h` | `ExecutionState`, `ConstantData`, `MutableBufferData`, constant loading, dtype conversion | +| `MLXInterpreter.h` | The op dispatch switch + all `exec_*` implementations | diff --git a/backends/mlx/__init__.py b/backends/mlx/__init__.py new file mode 100644 index 00000000000..48f4c28f5ca --- /dev/null +++ b/backends/mlx/__init__.py @@ -0,0 +1,17 @@ +# +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# + +"""MLX backend for ExecuTorch - executes models on Apple Silicon using MLX.""" + +# Import custom_ops module to register custom ATen ops (rope, etc.) +from executorch.backends.mlx import custom_ops as _custom_ops # noqa: F401 +from executorch.backends.mlx.partitioner import MLXPartitioner + +from executorch.backends.mlx.preprocess import MLXBackend + +__all__ = ["MLXBackend", "MLXPartitioner"] diff --git a/backends/mlx/_logging.py b/backends/mlx/_logging.py new file mode 100644 index 00000000000..eff472550f9 --- /dev/null +++ b/backends/mlx/_logging.py @@ -0,0 +1,40 @@ +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# + +""" +Centralized logging for the MLX backend. + +Usage: + from executorch.backends.mlx._logging import logger + + logger.info("Always visible (e.g., unsupported ops summary)") + logger.debug("Only visible when ET_MLX_DEBUG=1") + logger.warning("Always visible") + +The logger is set to INFO by default, so logger.info() always prints. +Set ET_MLX_DEBUG=1 to lower the threshold to DEBUG for verbose output +(graph dumps, per-node traces, ops_to_not_decompose lists, etc.). 
+""" + +import logging +import os + +_MLX_DEBUG = os.environ.get("ET_MLX_DEBUG", "") not in ("", "0") + +logger = logging.getLogger("executorch.backends.mlx") +logger.setLevel(logging.DEBUG if _MLX_DEBUG else logging.INFO) +logger.propagate = False + +if not logger.handlers: + _handler = logging.StreamHandler() + _handler.setFormatter( + logging.Formatter( + "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" + ) + ) + logger.addHandler(_handler) diff --git a/backends/mlx/builder/__init__.py b/backends/mlx/builder/__init__.py new file mode 100644 index 00000000000..ce793ed9a15 --- /dev/null +++ b/backends/mlx/builder/__init__.py @@ -0,0 +1,16 @@ +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# + +# Trigger op/pattern handler registration. +# ops.py and patterns.py use @REGISTRY.register() decorators at import time. +# This must happen after REGISTRY is defined (in op_registry.py). +from executorch.backends.mlx import ops, patterns # noqa: F401 +from executorch.backends.mlx.builder.op_registry import REGISTRY # noqa: F401 +from executorch.backends.mlx.builder.program_builder import ( # noqa: F401 + MLXProgramBuilder, +) diff --git a/backends/mlx/builder/op_helpers.py b/backends/mlx/builder/op_helpers.py new file mode 100644 index 00000000000..5e082cdf386 --- /dev/null +++ b/backends/mlx/builder/op_helpers.py @@ -0,0 +1,275 @@ +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+# + +from __future__ import annotations + +from typing import Dict, Optional, Tuple, TYPE_CHECKING, Union + +import torch +from executorch.backends.mlx.builder.slot_manager import Slot +from executorch.exir.scalar_type import ScalarType +from torch.fx.node import Node + +if TYPE_CHECKING: + from executorch.backends.mlx.builder.program_builder import MLXProgramBuilder + + +def get_aten_target(target): + """ + Unwrap EdgeOpOverload to get the underlying ATen op. + + In Edge IR, ops are wrapped in EdgeOpOverload. This extracts the + underlying ATen op for consistent comparison. + """ + if hasattr(target, "_op") and "EdgeOpOverload" in type(target).__name__: + return target._op + return target + + +# Mapping from _copy variants to their non-copy equivalents. +# Edge IR uses _copy variants for certain ops, but for pattern matching +# we want to compare against the semantic operation. +_COPY_TO_NON_COPY = { + torch.ops.aten.slice_copy.Tensor: torch.ops.aten.slice.Tensor, + torch.ops.aten.transpose_copy.int: torch.ops.aten.transpose.int, + torch.ops.aten.view_copy.default: torch.ops.aten.view.default, + torch.ops.aten.permute_copy.default: torch.ops.aten.permute.default, + torch.ops.aten.unsqueeze_copy.default: torch.ops.aten.unsqueeze.default, + torch.ops.aten.squeeze_copy.dim: torch.ops.aten.squeeze.dim, + torch.ops.aten.squeeze_copy.dims: torch.ops.aten.squeeze.dims, + torch.ops.aten.squeeze_copy.default: torch.ops.aten.squeeze.default, + torch.ops.aten.expand_copy.default: torch.ops.aten.expand.default, + torch.ops.aten.alias_copy.default: torch.ops.aten.alias.default, +} + + +def get_aten_target_normalized(target): + """ + Get ATen target, mapping _copy variants to their non-copy equivalents. + + Use this for pattern matching where Edge IR uses _copy variants but + we want to match the semantic operation. 
+ + E.g., aten.transpose_copy.int -> aten.transpose.int + """ + target = get_aten_target(target) + return _COPY_TO_NON_COPY.get(target, target) + + +def emit_stop_position( + P: "MLXProgramBuilder", + start: "Union[int, Slot]", + length_tensor: "Slot", + length_dim: int, + length_meta: "Optional[torch.Tensor]" = None, +) -> "Union[int, Slot]": + """ + Emit nodes to compute stop = start + length for slice operations. + + May emit SymSizeNode and/or AddIntNode depending on whether + start and length are static or dynamic. + + Args: + P: The program builder + start: Start position (int or Slot) + length_tensor: The tensor slot whose dimension gives the length + length_dim: Which dimension of length_tensor contains the length + length_meta: Optional tensor metadata for static length extraction + + Returns: + stop position as int (if fully static) or Slot (if any dynamic) + """ + from executorch.backends.mlx.serialization.mlx_graph_schema import ( + AddIntNode, + IntOrVid, + SymSizeNode, + ) + + # Check if seq_len is symbolic (dynamic) + seq_len_is_symbolic = False + seq_len_concrete = None + + if length_meta is not None: + seq_len_dim = length_meta.shape[length_dim] + if hasattr(seq_len_dim, "node"): + seq_len_is_symbolic = True + else: + seq_len_concrete = int(seq_len_dim) + + if seq_len_is_symbolic or length_meta is None: + # Dynamic seq_len: emit SymSizeNode to get length at runtime + _, seq_len_slot = P.slot_manager.make_tmp_value_slot() + P.emit( + SymSizeNode( + a=P.slot_to_tid(length_tensor), + dim=length_dim, + out=P.slot_to_vid(seq_len_slot), + ) + ) + _, stop_slot = P.slot_manager.make_tmp_value_slot() + if isinstance(start, Slot): + start_iov = P.to_int_or_vid(start) + else: + start_iov = IntOrVid.from_literal(int(start)) + P.emit( + AddIntNode( + a=start_iov, + b=IntOrVid.from_vid(P.slot_to_vid(seq_len_slot)), + out=P.slot_to_vid(stop_slot), + ) + ) + return stop_slot + else: + # Static seq_len + if isinstance(start, Slot): + # Dynamic start + static length 
+ _, stop_slot = P.slot_manager.make_tmp_value_slot() + P.emit( + AddIntNode( + a=P.to_int_or_vid(start), + b=IntOrVid.from_literal(seq_len_concrete), + out=P.slot_to_vid(stop_slot), + ) + ) + return stop_slot + else: + # Both static - just return the sum + return start + seq_len_concrete + + +def emit_lifted_constant(P: "MLXProgramBuilder", value, dtype: torch.dtype) -> Slot: + """Lift a scalar to a 0-D tensor. + + Concrete scalars (int/float/bool) become deduplicated constants. + Dynamic values (SymInt Slots) emit a FullNode at runtime. + """ + + if isinstance(value, (int, float, bool)): + return P.make_or_get_constant( + f"_scalar_{value}", torch.tensor(value, dtype=dtype) # 0-D + ) + + from executorch.backends.mlx.serialization.mlx_graph_schema import FullNode + + _, slot = P.make_tmp_slot() + P.emit( + FullNode( + shape=[], + v=P.to_float_or_vid(value), + scalar_type=torch_dtype_to_scalar_type(dtype), + out=P.slot_to_tid(slot), + ) + ) + return slot + + +def to_mlx_qparams( + qdata: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + bits: int, + compute_biases: bool = True, +) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """ + Convert TorchAO quantization params to MLX format. + + TorchAO uses: s * (q - z), with q signed + MLX uses: S * Q + B, with Q unsigned + + s * (q - z) + = s ((q + offset) - (z + offset)) + = s Q + B, + where Q = q + offset, B = -s * (z + offset) + + Args: + compute_biases: If False, skip bias computation (for scale_only mode). + Returns (Q, None) in this case. This is valid when + zero_point is all zeros, as the C++ runtime will compute + biases = -scales * 2^(bits-1). 
+ """ + assert qdata.dtype == torch.int8 + offset = 2 ** (bits - 1) + Q = qdata.to(torch.int32) + offset + + # Pack data tightly into uint32 + assert 32 % bits == 0 + vals_per_uint32 = 32 // bits + assert qdata.shape[1] % vals_per_uint32 == 0 + + Q = Q.reshape(-1, vals_per_uint32) + shifts = torch.arange(0, 32, bits, dtype=torch.int64) + + # Convert to int64 for shift/packing + Q = Q.to(torch.int64) + Q = (Q << shifts).sum(dim=-1) + Q = Q.to(torch.uint32) + Q = Q.reshape(qdata.shape[0], -1) + + if compute_biases: + B = -scale * (zero_point.to(scale.dtype) + offset) + return Q, B + else: + return Q, None + + +def parse_dequant_node( + node: Node, +) -> Optional[Tuple[Node, Node, Node, int, int, Optional[torch.dtype], int]]: + """Parse a torchao.dequantize_affine node. + + Accepts N-dimensional block_size with a single non-1 element identifying + the quantized dimension and group_size. For example: + - Linear weights (2D): block_size=[1, 32] → quantized_dim=1 + - Conv2d weights (4D): block_size=[1, 32, 1, 1] → quantized_dim=1 + + Returns (qdata, scale, zero_point, group_size, bits, out_dtype, quantized_dim) + or None if unsupported. 
+ """ + qdata, block_size, scale, zero_point, dtype, qmin, qmax = node.args[0:7] + out_dtype = ( + node.args[7] if len(node.args) > 7 else node.kwargs.get("output_dtype", None) + ) + if dtype != torch.int8: + return None + if len(block_size) < 2: + return None + non_one = [(i, d) for i, d in enumerate(block_size) if d != 1] + if len(non_one) != 1: + return None + quantized_dim, group_size = non_one[0] + if group_size not in [32, 64, 128]: + return None + if qmin == -8 and qmax == 7: + bits = 4 + elif qmin == -128 and qmax == 127: + bits = 8 + else: + return None + return qdata, scale, zero_point, group_size, bits, out_dtype, quantized_dim + + +# Mapping from torch dtype to ET ScalarType int value +# See executorch/exir/scalar_type.py for ScalarType enum +_TORCH_DTYPE_TO_SCALAR_TYPE: Dict[torch.dtype, int] = { + torch.float16: ScalarType.HALF, + torch.float32: ScalarType.FLOAT, + torch.bfloat16: ScalarType.BFLOAT16, + torch.int32: ScalarType.INT, + torch.int64: ScalarType.LONG, + torch.uint32: ScalarType.UINT32, + torch.uint8: ScalarType.BYTE, + torch.bool: ScalarType.BOOL, + torch.int8: ScalarType.CHAR, +} + + +def torch_dtype_to_scalar_type(dtype: torch.dtype) -> int: + """Convert torch dtype to ET ScalarType int value.""" + if dtype not in _TORCH_DTYPE_TO_SCALAR_TYPE: + raise ValueError(f"Unsupported dtype: {dtype}") + return int(_TORCH_DTYPE_TO_SCALAR_TYPE[dtype]) diff --git a/backends/mlx/builder/op_registry.py b/backends/mlx/builder/op_registry.py new file mode 100644 index 00000000000..19668ca2c1b --- /dev/null +++ b/backends/mlx/builder/op_registry.py @@ -0,0 +1,151 @@ +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+# + +from __future__ import annotations + +from typing import Callable, Dict, List, Optional, Tuple, Type, TYPE_CHECKING, Union + +from executorch.backends.mlx._logging import logger +from torch.fx.node import Node + +if TYPE_CHECKING: + from executorch.backends.mlx.builder.program_builder import MLXProgramBuilder + from executorch.backends.mlx.builder.slot_manager import Slot + from torch.export import ExportedProgram + +# Handler type: takes (builder, node) and returns optional slot(s) +Handler = Callable[ + ["MLXProgramBuilder", Node], Optional[Union["Slot", Tuple["Slot", ...]]] +] + + +class PatternHandler: + def __init__(self, head: Node, body: List[Node]) -> None: + self.head: Node = head + self.body: List[Node] = body + + @classmethod + def deferred_handler(cls, P: MLXProgramBuilder, n: Node) -> None: + pass + + @classmethod + def maybe_create(cls, ep: ExportedProgram, head: Node) -> Optional[PatternHandler]: + raise NotImplementedError + + def __call__(self, P: MLXProgramBuilder, n: Node) -> None: + raise NotImplementedError + + def set_handlers(self, P: MLXProgramBuilder): + if P.node_info[self.head].handler is not None: + raise AssertionError( + f"Head node {self.head.name} already has handler {P.node_info[self.head].handler}, " + f"cannot set pattern {self.__class__.__name__}" + ) + for n in self.body: + if P.node_info[n].handler is not None: + raise AssertionError( + f"Body node {n.name} already has handler {P.node_info[n].handler}, " + f"cannot set pattern {self.__class__.__name__}" + ) + + logger.debug( + f"Pattern {self.__class__.__name__}: " + f"HEAD={self.head.name}, BODY={[n.name for n in self.body]}" + ) + P.node_info[self.head].handler = self + for n in self.body: + P.node_info[n].handler = PatternHandler.deferred_handler + + +class MLXOpRegistry: + """Registry for op handlers and pattern handlers.""" + + def __init__(self): + self._handlers: Dict[Union[str, Callable], Handler] = {} + self._patterns: Dict[str, Type[PatternHandler]] = {} + + def 
reset(self) -> None: + """Reset the registry to empty state. Useful for testing.""" + self._handlers.clear() + self._patterns.clear() + + def register(self, target: Union[str, Callable, list, tuple]): + """Decorator to register a handler for one or more op targets.""" + + def deco(fn: Handler): + targets = target if isinstance(target, (list, tuple)) else [target] + for t in targets: + if t in self._handlers: + raise ValueError(f"Target {t} already registered") + self._handlers[t] = fn + return fn + + return deco + + def get_handler(self, node: Node) -> Optional[Handler]: + """Get the handler for a node, or None if not registered.""" + t = node.target + if t in self._handlers: + return self._handlers[t] + # Handle EdgeOpOverload by extracting the underlying ATen op + if hasattr(t, "_op") and t._op in self._handlers: + return self._handlers[t._op] + # Check for string-based targets (e.g., higher_order ops) + target_str = str(t) + if target_str in self._handlers: + return self._handlers[target_str] + return None + + def registered_ops(self) -> set: + """Return all registered op targets.""" + return set(self._handlers.keys()) + + def unregister(self, target: Union[str, Callable, list, tuple]) -> None: + """Remove a handler for one or more op targets. + + This is useful for debugging - allows temporarily disabling specific + handlers to test if they are causing issues. 
+ + Args: + target: Single target or list of targets to unregister + """ + targets = target if isinstance(target, (list, tuple)) else [target] + for t in targets: + if t in self._handlers: + del self._handlers[t] + + def register_pattern(self, name: str): + """Decorator to register a pattern handler class.""" + + def deco(cls: Type[PatternHandler]): + if not issubclass(cls, PatternHandler): + raise TypeError( + "register_pattern must decorate a PatternHandler subclass" + ) + if name in self._patterns: + raise ValueError(f"Pattern '{name}' already registered") + self._patterns[name] = cls + return cls + + return deco + + def get_pattern_cls(self, name: str) -> Optional[Type[PatternHandler]]: + """Get a pattern handler class by name.""" + return self._patterns.get(name) + + def get_noop_handler(self) -> Optional[Handler]: + """Get the NOOP handler, if registered.""" + return self._handlers.get("NOOP") + + def patterns(self): + """Return all registered pattern names.""" + return self._patterns.keys() + + +# Global registry +REGISTRY = MLXOpRegistry() diff --git a/backends/mlx/builder/pattern_matcher.py b/backends/mlx/builder/pattern_matcher.py new file mode 100644 index 00000000000..2db422e3f68 --- /dev/null +++ b/backends/mlx/builder/pattern_matcher.py @@ -0,0 +1,64 @@ +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# + +from __future__ import annotations + +from typing import List, TYPE_CHECKING + +from executorch.backends.mlx._logging import logger +from executorch.backends.mlx.builder.op_registry import PatternHandler + +if TYPE_CHECKING: + from executorch.backends.mlx.builder.op_registry import MLXOpRegistry + from torch.export import ExportedProgram + + +class PatternMatcher: + """ + Discovers and applies pattern handlers to an FX graph. 
+ + Pattern handlers match multi-node subgraphs and lower them to optimized + MLX operations. This class orchestrates the pattern discovery process: + + 1. Iterates through all registered pattern types + 2. For each pattern, tries to match it against every node in the graph + 3. When a match is found, assigns handlers to the head and body nodes + + The ordering matters: patterns are matched before dead code elimination + because some pattern body nodes (e.g., update_cache) have no users + since they mutate in-place, but they're not dead. + """ + + def __init__(self, ep: ExportedProgram, registry: "MLXOpRegistry"): + self.ep = ep + self.registry = registry + self._matches: List[PatternHandler] = [] + + def find_patterns(self) -> List[PatternHandler]: + """ + Find all pattern matches in the graph. + + Returns a list of PatternHandler instances, one for each match found. + Patterns are tried in registration order. + """ + self._matches = [] + for name in self.registry.patterns(): + self._find_pattern(name) + return self._matches + + def _find_pattern(self, name: str) -> None: + """Try to match a single pattern type against all nodes.""" + pattern_cls = self.registry.get_pattern_cls(name) + if pattern_cls is None: + return + + for n in self.ep.graph.nodes: + handler = pattern_cls.maybe_create(self.ep, n) + if handler is not None: + logger.debug(f"Pattern {name} matched at node {n.name}") + self._matches.append(handler) diff --git a/backends/mlx/builder/program_builder.py b/backends/mlx/builder/program_builder.py new file mode 100644 index 00000000000..60d5ebbdbfe --- /dev/null +++ b/backends/mlx/builder/program_builder.py @@ -0,0 +1,1018 @@ +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# + +""" +MLX Program Builder - converts an ExportedProgram to an MLXGraph. + +This module is responsible for: +1. 
Walking the FX graph from an ExportedProgram +2. Converting each node to the corresponding MLX op +3. Managing tensor and value slots +4. Building the final MLXGraph dataclass for serialization + +Op handlers are registered in ops.py. +Pattern handlers are registered in patterns.py. +""" + +from __future__ import annotations + +import traceback +from collections import defaultdict +from dataclasses import dataclass +from typing import Any, DefaultDict, Dict, List, Optional, Set, Tuple, Union + +import torch + +from executorch.backends.mlx._logging import logger +from executorch.backends.mlx.builder.op_helpers import torch_dtype_to_scalar_type +from executorch.backends.mlx.builder.op_registry import ( + Handler, + PatternHandler, + REGISTRY, +) +from executorch.backends.mlx.builder.pattern_matcher import PatternMatcher +from executorch.backends.mlx.builder.slot_manager import ( + IdSpace, + IdType, + Slot, + SlotManager, +) +from executorch.backends.mlx.serialization.mlx_graph_schema import ( + FloatOrVid, + IdCopyNode, + Instruction, + InstructionChain, + IntOrVid, + IntOrVidOrTid, + MLXGraph, + NamedSlot, + OpNodeUnion, + ShapeDim, + SlotType, + SlotVariant, + TensorMeta, + Tid, + Vid, +) +from executorch.exir._serialize._named_data_store import NamedDataStore +from torch.export.exported_program import ExportedProgram +from torch.fx.node import Node +from torch.utils import _pytree as pytree + + +def _check_dtype(node: Node) -> Optional[str]: + """ + Check if a node has a supported dtype. 
+ + Args: + node: The FX node to check + + Returns: + None if the node's dtype is supported, otherwise an error message string + """ + fake_val = node.meta.get("val", None) + if fake_val is not None and hasattr(fake_val, "dtype"): + try: + torch_dtype_to_scalar_type(fake_val.dtype) + except ValueError: + return f"has unsupported dtype: {fake_val.dtype}" + return None + + +def _check_input_dtypes(node: Node) -> Optional[str]: + """ + Check if all input tensors to a node have supported dtypes. + + Args: + node: The FX node to check + + Returns: + None if all input dtypes are supported, otherwise an error message string + describing which input (arg position or kwarg name) has an unsupported dtype + """ + # Check positional args + for i, arg in enumerate(node.args): + if isinstance(arg, Node): + dtype_error = _check_dtype(arg) + if dtype_error is not None: + return f"arg[{i}] ({arg.name}) {dtype_error}" + + # Check kwargs + for kwarg_name, kwarg_val in node.kwargs.items(): + if isinstance(kwarg_val, Node): + dtype_error = _check_dtype(kwarg_val) + if dtype_error is not None: + return f"kwarg '{kwarg_name}' ({kwarg_val.name}) {dtype_error}" + + return None + + +@dataclass +class NodeInfo: + handled: bool = False + handler: Optional[Union[Handler, PatternHandler]] = None + supported: bool = False + unsupported_reason: Optional[str] = None + name: Optional[str] = None + remaining_reads: int = 0 + + +class MLXProgramBuilder: + """ + Builds an MLXGraph from an ExportedProgram. + + Args: + ep: The ExportedProgram to build from + """ + + def __init__(self, ep: ExportedProgram, named_data_key_prefix: str = ""): + self.ep: ExportedProgram = ep + self._instrs: List[Instruction] = [] + self.extra_constants: Dict[str, torch.Tensor] = {} + self.slot_manager = SlotManager() + self.node_info: DefaultDict[Node, NodeInfo] = defaultdict(NodeInfo) + self._mlx_graph: Optional[MLXGraph] = None + # Map from SymInt symbol names (e.g., "s77") to the FX Node that produces them. 
+ # This is used to resolve symbolic tensor dimensions to Vid references. + self._symint_to_node: Dict[str, Node] = {} + # Maps for remapping local slot indices to global Tid/Vid indices during build + self._tid_slot_map: List[Tuple[Tid, Slot]] = [] + self._vid_slot_map: List[Tuple[Vid, Slot]] = [] + # Prefix for named_data_store keys and named_slots to avoid collisions + # in multi-method programs where different methods may have lifted tensor + # constants with the same auto-generated name. + self._named_data_key_prefix: str = named_data_key_prefix + # Unprefixed canonical-name → Slot for constants, populated by _build_io_maps(). + # Used by get_named_data_store() to look up tensors without prefix interference. + self._constant_name_to_slot: Dict[str, Slot] = {} + + def _prefix_key(self, name: str) -> str: + """Apply the named-data key prefix for the .pte namespace. + + This is the single point where canonical (unprefixed) names are + transformed into the external keys used in the .pte's ``named_data`` + section and the FlatBuffer ``named_slots`` field. + """ + if self._named_data_key_prefix: + return f"{self._named_data_key_prefix}/{name}" + return name + + def emit(self, op: OpNodeUnion) -> None: + self._instrs.append(Instruction(op=op)) + + def args(self, node: Node) -> Tuple[Any, ...]: + return self.slot_map(node.args) + + def kwargs(self, node: Node) -> Dict[str, Any]: + return self.slot_map(node.kwargs) + + def slot_map(self, tree): + leaves, spec = pytree.tree_flatten(tree) + new_leaves = [] + for a in leaves: + if isinstance(a, Node): + # Use make_or_get_slots which handles both single and multi-output nodes. + # For single-output nodes, returns a 1-tuple; for multi-output, returns n-tuple. + # We unwrap single-element tuples for convenience. 
def make_or_get_slots(
    self, node: Node, id_space: IdSpace = IdSpace.Temp
) -> Tuple[Slot, ...]:
    """Get or create slots for a multi-output node. Always returns a tuple."""
    return self.slot_manager.make_or_get_slots(node, id_space)

def make_or_get_slot(self, node: Node, id_space: IdSpace = IdSpace.Temp) -> Slot:
    """Get or create a slot for a single-output node. Returns a single Slot."""
    return self.slot_manager.make_or_get_slot(node, id_space)

def set_slot(self, node: Node, slot: Slot):
    # Thin delegation; SlotManager.set_slot rejects conflicting re-assignment.
    self.slot_manager.set_slot(node, slot)

def make_tmp_slot(self) -> Tuple[str, Slot]:
    """Create a temporary tensor slot."""
    return self.slot_manager.make_tmp_slot()

def make_tmp_value_slot(self) -> Tuple[str, Slot]:
    """Create a temporary value (SymInt) slot."""
    return self.slot_manager.make_tmp_value_slot()

def make_or_get_constant(self, name: str, tensor: torch.Tensor) -> Slot:
    """
    Creates an extra constant outside of the ExportedProgram state_dict.
    Ops can use this to add constants during build that do not exist in the
    ExportedProgram state_dict, e.g., doing naive packing of quantized ops.
    """
    # Names must not collide with tensors already owned by the ExportedProgram.
    assert name not in self.ep.state_dict
    assert name not in self.ep.constants

    if name in self.extra_constants:
        # During fake tensor tracing, we can't use torch.equal
        # Just assume tensors with same name are the same
        slot = self.slot_manager.get_slot(name)
        assert slot is not None
        return slot

    slot = self.slot_manager.make_constant_slot(name)
    self.extra_constants[name] = tensor
    return slot

def get_placeholder_target_and_tensor(self, node: Node) -> Tuple[str, torch.Tensor]:
    """Resolve a placeholder node to its (signature target, tensor) pair.

    Searches the graph signature for an input spec whose arg name matches the
    placeholder, then looks the spec's target up in state_dict and constants.

    Raises:
        KeyError: if no spec/target resolves to a tensor.
    """
    assert node.op == "placeholder"
    placeholder_name = node.name

    sig = self.ep.graph_signature
    sd = self.ep.state_dict
    consts = self.ep.constants

    for ispec in sig.input_specs:
        if ispec.arg.name != placeholder_name:
            continue
        target = ispec.target
        if target is None:
            continue
        if target in sd:
            return (target, sd[target])
        if target in consts:
            return (target, consts[target])

    raise KeyError(f"Unable to resolve placeholder {placeholder_name}")

def slot_to_tid(self, slot: Slot) -> Tid:
    """Convert a tensor Slot to a Tid, recording it for later remapping."""
    assert slot.id_type == IdType.Tensor
    # Use local slot.idx as placeholder - will be remapped to global idx in build()
    tid = Tid(idx=slot.idx)
    self._tid_slot_map.append((tid, slot))
    return tid

def slot_to_vid(self, slot: Slot) -> Vid:
    """Convert a value Slot to a Vid, recording it for later remapping."""
    # Any non-Tensor slot (SymInt/SymBool) is a value slot.
    assert slot.id_type != IdType.Tensor
    vid = Vid(idx=slot.idx)
    self._vid_slot_map.append((vid, slot))
    return vid

def to_int_or_vid(self, v: Union[int, Slot]) -> IntOrVid:
    """Wrap *v* as a dynamic Vid reference (Slot) or a literal int."""
    if isinstance(v, Slot):
        return IntOrVid.from_vid(self.slot_to_vid(v))
    return IntOrVid.from_literal(int(v))

def to_float_or_vid(self, v: Union[float, int, Slot]) -> FloatOrVid:
    """Wrap *v* as a dynamic Vid reference (Slot) or a literal float."""
    if isinstance(v, Slot):
        return FloatOrVid.from_vid(self.slot_to_vid(v))
    return FloatOrVid.from_literal(float(v))

def to_int_or_vid_or_tid(self, v: Union[int, Slot]) -> IntOrVidOrTid:
    """Wrap *v* as a Tid (tensor slot), Vid (value slot), or literal int."""
    if isinstance(v, Slot):
        if v.id_type == IdType.Tensor:
            return IntOrVidOrTid.from_tid(self.slot_to_tid(v))
        return IntOrVidOrTid.from_vid(self.slot_to_vid(v))
    return IntOrVidOrTid.from_literal(int(v))
def _mark_read(self, node: Node):
    """Consume one pending read of *node*; free its Temp slot ids on the last read."""
    assert self.node_info[node].handled, f"Node {node} is not handled"
    assert (
        self.node_info[node].remaining_reads > 0
    ), f"Reading node {node}, but it has no remaining reads"
    self.node_info[node].remaining_reads -= 1

    if self.node_info[node].remaining_reads == 0:
        slot = self.slot_manager.get_slot(node)
        if slot is None:
            return
        if not isinstance(slot, tuple):
            slot = (slot,)
        for s in slot:
            # Only Temp ids are recycled; Constant/Input/Output/MutableBuffer
            # slots live for the whole program.
            if s.id_space != IdSpace.Temp:
                continue
            if s.id_type == IdType.Tensor:
                self.slot_manager.tid_managers[IdSpace.Temp].return_id(s.idx)
            else:
                self.slot_manager.vid_managers[IdSpace.Temp].return_id(s.idx)

def _mark_node_handled(self, node: Node, *, handler: Optional[Handler] = None):
    """Mark *node* handled (idempotent) and consume reads of its inputs.

    remaining_reads is seeded from the node's user count. For a deferred
    pattern handler, input reads are NOT consumed here (they are accounted
    for when the pattern itself is emitted). For a PatternHandler, the reads
    of every body node's inputs are consumed as well.
    """
    if self.node_info[node].handled:
        return
    self.node_info[node].handled = True
    self.node_info[node].remaining_reads = len(node.users)
    self.node_info[node].handler = handler

    if handler == PatternHandler.deferred_handler:
        return

    def mark_read(n: Node):
        # De-dupe: a node appearing multiple times in (args, kwargs) is
        # still only one read.
        flat_args, spec = pytree.tree_flatten((n.args, n.kwargs))
        seen = set()
        for a in flat_args:
            if isinstance(a, Node):
                if a not in seen:
                    self._mark_read(a)
                    seen.add(a)

    if isinstance(handler, PatternHandler):
        for n in handler.body:
            mark_read(n)
    mark_read(node)

def _mark_node_supported(self, node: Node, *, handler: Optional[Handler] = None):
    """Flag *node* as supported and handled."""
    self.node_info[node].supported = True
    self._mark_node_handled(node, handler=handler)

def _mark_node_unsupported(self, node: Node, reason: str):
    """Flag *node* as unsupported (with *reason*) and handled."""
    self.node_info[node].supported = False
    self.node_info[node].unsupported_reason = reason
    self._mark_node_handled(node)

def _is_handled(self, node: Node) -> bool:
    """True once the node was marked handled (supported or not)."""
    return self.node_info[node].handled
def _mark_supported(
    self, nodes: Union[List[Node], Node], *, handler: Optional[Handler] = None
) -> None:
    """Mark one node or a list of nodes as supported."""
    if isinstance(nodes, Node):
        nodes = [nodes]
    for node in nodes:
        self._mark_node_supported(node, handler=handler)

def _mark_unsupported(self, nodes: Union[List[Node], Node], reason: str) -> None:
    """Mark one node or a list of nodes as unsupported with *reason*."""
    if isinstance(nodes, Node):
        nodes = [nodes]
    for node in nodes:
        self._mark_node_unsupported(node, reason)

def _make_io_slots(self):  # noqa: C901
    """Classify graph-signature inputs/outputs and allocate their slots.

    Inputs are bucketed into constants (parameters, non-mutated buffers,
    lifted constant tensors), user inputs (tensor or SymInt), and mutable
    buffers (buffers that appear as BUFFER_MUTATION outputs). Placeholder
    and output nodes then get slots in the matching IdSpace.
    """
    from torch.export.graph_signature import (
        InputKind,
        OutputKind,
        SymIntArgument,
        TensorArgument,
    )

    output_kind_targets = defaultdict(set)
    constant_tensors = []
    user_inputs = []
    user_outputs = []
    mutable_buffers = []

    # First pass over outputs: collect targets per output kind (used below to
    # detect mutated buffers) and the names of user-visible outputs.
    for ospec in self.ep.graph_signature.output_specs:
        kind = ospec.kind
        arg = ospec.arg
        name = arg.name
        target = ospec.target
        if target is not None:
            output_kind_targets[kind].add(target)
        if kind in (OutputKind.USER_OUTPUT, OutputKind.USER_INPUT_MUTATION):
            user_outputs.append(name)

    for ispec in self.ep.graph_signature.input_specs:
        kind = ispec.kind
        arg = ispec.arg
        name = arg.name
        target = ispec.target

        if isinstance(arg, TensorArgument):
            if kind == InputKind.PARAMETER:
                # Parameters are treated as constants (not mutated)
                constant_tensors.append(name)
            elif kind == InputKind.BUFFER:
                if target in output_kind_targets[OutputKind.BUFFER_MUTATION]:
                    mutable_buffers.append(name)
                else:
                    # Non-mutated buffers (like lifted tensor constants) are constants
                    constant_tensors.append(name)
            elif kind == InputKind.USER_INPUT:
                user_inputs.append(name)
            elif kind == InputKind.CONSTANT_TENSOR:
                constant_tensors.append(name)
            else:
                raise NotImplementedError(
                    f"Support for input {arg} is not implemented"
                )
        elif isinstance(arg, SymIntArgument):
            if kind == InputKind.USER_INPUT:
                user_inputs.append(name)
            else:
                raise NotImplementedError(
                    f"Support for input {arg} is not implemented"
                )
        else:
            raise NotImplementedError(f"Support for input {arg} is not implemented")

    for node in self.ep.graph.nodes:
        if node.op == "placeholder":
            # Skip placeholders with no users (dead inputs get no slot).
            if node.users == {}:
                continue
            if node.name in constant_tensors:
                self.make_or_get_slot(node, id_space=IdSpace.Constant)
            elif node.name in user_inputs:
                val = node.meta.get("val", None)
                if isinstance(val, torch.Tensor) and not val.is_contiguous():
                    raise ValueError(
                        f"MLX backend requires contiguous input tensors, "
                        f"but input '{node.name}' has non-contiguous strides. "
                        f"shape={list(val.shape)}, stride={list(val.stride())}. "
                        f"Ensure example inputs passed to torch.export.export() "
                        f"are contiguous (call .contiguous() on them)."
                    )
                self.make_or_get_slot(node, id_space=IdSpace.Input)
            elif node.name in mutable_buffers:
                self.make_or_get_slot(node, id_space=IdSpace.MutableBuffer)
            else:
                raise NotImplementedError(
                    f"Support for placeholder {node.name} is not implemented"
                )
        elif node.op == "output":
            outs, _ = pytree.tree_flatten(node.args)
            for o in outs:
                if isinstance(o, Node) and o.name in user_outputs:
                    self.make_or_get_slot(o, id_space=IdSpace.Output)

def _mark_noop(self):
    """Mark noops and dead nodes.

    Walks the graph in reverse so that a node whose users were all marked
    dead can itself be marked dead (its handler is set to the noop handler).
    """
    dead = set()
    noop_handler = REGISTRY.get_noop_handler()
    if noop_handler is None:
        return

    for n in reversed(self.ep.graph.nodes):
        handler = REGISTRY.get_handler(n)
        if handler == noop_handler:
            dead.add(n)

        # A non-output node is dead when every user is already dead
        # (vacuously true for nodes with no users).
        if n.op != "output" and all(user in dead for user in n.users):
            self.node_info[n].handler = noop_handler
            dead.add(n)
+ """ + matcher = PatternMatcher(self.ep, REGISTRY) + for handler in matcher.find_patterns(): + handler.set_handlers(self) + + def _process_nodes(self) -> None: # noqa C901 + """ + Common logic for processing all nodes: create slots, match patterns, run handlers. + + This method: + 1. Creates I/O slots for placeholders and outputs + 2. Matches patterns FIRST (so body nodes get handlers and aren't marked dead) + 3. Marks dead/noop nodes + 4. Runs handlers for remaining nodes, marking them supported/unsupported + + The ordering is important: patterns must be matched before noops because + some pattern body nodes (e.g., update_cache) have no users since they + mutate in-place, but they're not dead - they're handled by the pattern. + """ + self._make_io_slots() + + # Apply patterns BEFORE _mark_noop so pattern body nodes don't get + # incorrectly marked as dead (e.g., update_cache nodes have no users + # since they mutate in-place, but they're not dead) + self._apply_patterns() + self._mark_noop() + + for n in self.ep.graph.nodes: + if self._is_handled(n): + continue + + if self.node_info[n].handler is not None: + handler = self.node_info[n].handler + handler(self, n) + self._mark_supported(n, handler=handler) + continue + + # Check input dtypes before processing node + unsupported_dtype_msg = _check_input_dtypes(n) + if unsupported_dtype_msg is not None: + if n.meta.get("val", None) is not None: + self.slot_manager.make_or_get_slots(n) + self._mark_unsupported(n, unsupported_dtype_msg) + continue + + if n.op in ("placeholder", "output"): + dtype_error = _check_dtype(n) + if dtype_error is not None: + self._mark_unsupported(n, f"{n.op} {dtype_error}") + continue + self._mark_supported(n) + continue + + handler = REGISTRY.get_handler(n) + if handler is None: + msg = f"no handler for target={n.target}" + if n.meta.get("val", None) is not None: + self.slot_manager.make_or_get_slots(n) + self._mark_unsupported(n, msg) + continue + + try: + handler(self, n) + 
self._mark_supported(n, handler=handler) + except Exception as e: + trace_str = traceback.format_exc() + msg = f"{handler} failed for {n.target}: {e}.\n{trace_str}" + if n.meta.get("val", None) is not None: + self.slot_manager.make_or_get_slots(n) + self._mark_unsupported(n, msg) + + def check_support_only(self) -> None: + """ + Check which nodes are supported without building the full MLXGraph. + + This method populates node_info with supported/unsupported status for each + node, but avoids calling _build_mlx_graph() which can corrupt the shape_env + by evaluating symbolic shapes. + + Use this method for ops_to_not_decompose() and similar queries where you + only need to know support status, not the full compiled graph. + """ + self._process_nodes() + # NOTE: We intentionally skip _verify_build() and _build_mlx_graph() here + # because _build_mlx_graph() calls int() on tensor shapes which evaluates + # SymInts and corrupts the shape_env. This method is used for + # ops_to_not_decompose() where we only need support status. + + def _emit_buffer_mutation_writebacks(self): + """Emit copy-back instructions for BUFFER_MUTATION outputs. + + When a model mutates a buffer (e.g., via .copy_() or .mul_()), + torch.export functionalizes it: the new value is a computation result, + and the output spec marks it as BUFFER_MUTATION with a target buffer. + + This method emits an IdCopyNode for each BUFFER_MUTATION output, + copying the computation result back to the mutable buffer slot so + the updated value persists across execution calls. 
+ """ + from torch.export.graph_signature import InputKind, OutputKind + + # Map buffer target name -> input placeholder name + target_to_placeholder = {} + for ispec in self.ep.graph_signature.input_specs: + if ispec.kind == InputKind.BUFFER and ispec.target is not None: + target_to_placeholder[ispec.target] = ispec.arg.name + + for ospec in self.ep.graph_signature.output_specs: + if ospec.kind != OutputKind.BUFFER_MUTATION: + continue + + result_slot = self.slot_manager.get_slot(ospec.arg.name) + placeholder_name = target_to_placeholder.get(ospec.target) + if result_slot is None or placeholder_name is None: + continue + + buffer_slot = self.slot_manager.get_slot(placeholder_name) + if buffer_slot is None or buffer_slot.id_space != IdSpace.MutableBuffer: + continue + + self.emit( + IdCopyNode( + x=self.slot_to_tid(result_slot), + out=self.slot_to_tid(buffer_slot), + ) + ) + + def build(self) -> MLXGraph: + if self._mlx_graph is not None: + return self._mlx_graph + + self._process_nodes() + self._emit_buffer_mutation_writebacks() + self._verify_build() + self._mlx_graph = self._build_mlx_graph() + return self._mlx_graph + + def _verify_build(self): + noop_handler = REGISTRY.get_noop_handler() + + for n, info in self.node_info.items(): + assert info.handled + assert ( + info.remaining_reads == 0 + ), f"Expected {n} to have no remaining reads, but it has {info.remaining_reads}" + if n.op == "output": + assert self.slot_manager.get_slot(n) is None + continue + if ( + info.handler in (noop_handler, PatternHandler.deferred_handler) + or n.users == {} + ): + assert ( + self.slot_manager.get_slot(n) is None + ), f"Did not expect node {n} handled by {info.handler} to have a slot" + else: + assert ( + self.slot_manager.get_slot(n) is not None + ), f"Expected slot for node {n}" + + def _collect_used_slots( + self, + ) -> Tuple[Set[Slot], Dict[IdSpace, int], Dict[IdSpace, int]]: + """ + Collect all used slots and count tensors/values per IdSpace. 
def _create_slot_mappings(
    self, used_slots: Set[Slot]
) -> Tuple[Dict[Slot, int], Dict[Slot, int]]:
    """
    Create slot→Tid and slot→Vid mappings, and remap existing references.

    Global indices are assigned deterministically: slots are ordered by
    IdSpace (Constant, Input, Output, MutableBuffer, Temp) and then by
    their local idx. The Tid/Vid objects recorded in _tid_slot_map /
    _vid_slot_map are mutated IN PLACE so every already-emitted
    instruction picks up its final global index.

    Returns:
        (slot_to_tid, slot_to_vid)
    """
    id_space_order = {
        IdSpace.Constant: 0,
        IdSpace.Input: 1,
        IdSpace.Output: 2,
        IdSpace.MutableBuffer: 3,
        IdSpace.Temp: 4,
    }

    # Create Tid mapping (tensor slots only)
    slot_to_tid = sorted(
        [s for s in used_slots if s.id_type == IdType.Tensor],
        key=lambda s: (id_space_order[s.id_space], s.idx),
    )
    slot_to_tid = {s: idx for idx, s in enumerate(slot_to_tid)}

    # Create Vid mapping (all non-tensor, i.e. value, slots)
    slot_to_vid = sorted(
        [s for s in used_slots if s.id_type != IdType.Tensor],
        key=lambda s: (id_space_order[s.id_space], s.idx),
    )
    slot_to_vid = {s: idx for idx, s in enumerate(slot_to_vid)}

    # Remap all Tid/Vid values in instructions to use global indices.
    # NOTE(review): the hasattr guards look defensive — both maps are
    # initialized in __init__, so they should always be present.
    if hasattr(self, "_tid_slot_map"):
        for tid, slot in self._tid_slot_map:
            if slot in slot_to_tid:
                tid.idx = slot_to_tid[slot]
            else:
                logger.warning(f"Slot {slot} not found in slot_to_tid mapping")

    if hasattr(self, "_vid_slot_map"):
        for vid, slot in self._vid_slot_map:
            if slot in slot_to_vid:
                vid.idx = slot_to_vid[slot]
            else:
                logger.warning(f"Slot {slot} not found in slot_to_vid mapping")

    return slot_to_tid, slot_to_vid
def _to_slot_variant(
    self,
    slot: Slot,
    slot_to_tid: Dict[Slot, int],
    slot_to_vid: Dict[Slot, int],
) -> SlotVariant:
    """Convert a Slot to a SlotVariant using the provided mappings."""
    if slot.id_type == IdType.Tensor:
        idx = slot_to_tid[slot]
        slot_type = SlotType.TensorSlot
    elif slot.id_type == IdType.SymInt:
        idx = slot_to_vid[slot]
        slot_type = SlotType.IntValueSlot
    elif slot.id_type == IdType.SymBool:
        idx = slot_to_vid[slot]
        slot_type = SlotType.BoolValueSlot
    else:
        raise NotImplementedError(f"Unsupported slot type {slot.id_type}")
    return SlotVariant(idx=idx, slot_type=slot_type)

def _build_io_maps(
    self,
    used_slots: Set[Slot],
    slot_to_tid: Dict[Slot, int],
    slot_to_vid: Dict[Slot, int],
) -> Tuple[
    List[SlotVariant], List[SlotVariant], List[SlotVariant], List[NamedSlot]
]:
    """
    Build input/output/mutable_buffer maps and named slots.

    Returns:
        (input_map, output_map, mutable_buffer_map, named_slots)
    """
    input_map: List[SlotVariant] = []
    output_map: List[SlotVariant] = []
    mutable_buffer_map: List[SlotVariant] = []
    # Canonical (unprefixed) name → Slot. The prefix is applied only at
    # the exit boundaries: NamedSlot construction and NamedDataStore keys.
    name_to_slot: Dict[str, Slot] = {}

    for ispec in self.ep.graph_signature.input_specs:
        slot = self.slot_manager.get_slot(ispec.arg.name)
        if slot is None:
            continue
        assert isinstance(slot, Slot)
        # Prefer the signature target (e.g. buffer/parameter FQN) as the
        # canonical name; fall back to the placeholder arg name.
        name = ispec.target if ispec.target is not None else ispec.arg.name
        if slot.id_space == IdSpace.Input:
            input_map.append(self._to_slot_variant(slot, slot_to_tid, slot_to_vid))
            name_to_slot[name] = slot
        elif slot.id_space == IdSpace.MutableBuffer:
            mutable_buffer_map.append(
                self._to_slot_variant(slot, slot_to_tid, slot_to_vid)
            )
            name_to_slot[name] = slot
        else:
            # Constants: only record slots that are actually used.
            if slot in used_slots:
                name_to_slot[name] = slot

    for ospec in self.ep.graph_signature.output_specs:
        name = ospec.arg.name
        slot = self.slot_manager.get_slot(name)
        if slot is None:
            continue
        assert isinstance(slot, Slot)
        if slot.id_space == IdSpace.Output:
            output_map.append(self._to_slot_variant(slot, slot_to_tid, slot_to_vid))
            name = ospec.target if ospec.target is not None else ospec.arg.name
            name_to_slot[name] = slot
        elif slot.id_space == IdSpace.MutableBuffer:
            name = ospec.target if ospec.target is not None else ospec.arg.name
            name_to_slot[name] = slot

    for name in self.extra_constants:
        slot = self.slot_manager.get_slot(name)
        assert slot is not None and isinstance(slot, Slot)
        if slot in used_slots:
            name_to_slot[name] = slot

    # Store unprefixed constant mapping for get_named_data_store()
    self._constant_name_to_slot = {
        n: s for n, s in name_to_slot.items() if s.id_space == IdSpace.Constant
    }

    # Apply prefix at the exit boundary — the FlatBuffer named_slots
    named_slots = [
        NamedSlot(
            name=self._prefix_key(n),
            slot=self._to_slot_variant(s, slot_to_tid, slot_to_vid),
        )
        for n, s in name_to_slot.items()
    ]

    return input_map, output_map, mutable_buffer_map, named_slots
def _build_tensor_meta(  # noqa: C901
    self,
    used_slots: Set[Slot],
    slot_to_tid: Dict[Slot, int],
    slot_to_vid: Dict[Slot, int],
    num_tensors: Dict[IdSpace, int],
) -> List[TensorMeta]:
    """
    Build tensor metadata list with shape/dtype information.

    Static dimensions are stored as ShapeDim(value=N).
    Dynamic dimensions (SymInt) are stored as ShapeDim(value=-1)
    with min/max bounds from the shape_env.

    Note: tensor_meta shapes are only consumed by the runtime for
    constant and mutable buffer allocation (which are always static).
    Dynamic dim metadata is informational — the runtime resolves
    dynamic shapes via SymSizeNode at execution time.
    """

    def _get_dim_bounds(dim: torch.SymInt) -> tuple:
        """Get (min, max) bounds for a symbolic dimension."""
        try:
            node = dim.node
            shape_env = node.shape_env
            if shape_env is not None:
                expr = node.expr
                lower = int(shape_env.bound_sympy(expr).lower)
                upper = int(shape_env.bound_sympy(expr).upper)
                if upper > 2**30:
                    return (lower, -1)  # treat as unbounded
                return (lower, upper)
        except Exception:
            # Any failure while querying the shape_env falls back to
            # "unknown bounds" rather than aborting the build.
            pass
        return (0, -1)  # unbounded fallback

    def to_tensor_meta(t: torch.Tensor) -> TensorMeta:
        shape: List[ShapeDim] = []
        for dim in t.shape:
            if isinstance(dim, torch.SymInt):
                lo, hi = _get_dim_bounds(dim)
                shape.append(ShapeDim(value=-1, min_value=lo, max_value=hi))
            else:
                shape.append(ShapeDim(value=int(dim)))

        # 0-d tensors get no dim_order.
        dim_order = list(range(len(t.shape))) if len(t.shape) > 0 else None

        return TensorMeta(
            shape=shape,
            scalar_type=torch_dtype_to_scalar_type(t.dtype),
            dim_order=dim_order,
        )

    tensor_meta: Dict[int, TensorMeta] = {}
    for n in self.node_info:
        # get_slot may return None, a Slot, or a tuple of Slots; normalize to
        # a tuple (a None entry is filtered by the `s not in used_slots` test).
        slot = self.slot_manager.get_slot(n)
        if not isinstance(slot, tuple):
            slot = (slot,)
        fake_val = n.meta.get("val", None)
        if not isinstance(fake_val, tuple):
            fake_val = (fake_val,)
        for s, fv in zip(slot, fake_val):
            if s not in used_slots:
                continue
            if s.id_type != IdType.Tensor:
                continue
            # Temp tensors carry no serialized metadata.
            if s.id_space == IdSpace.Temp:
                continue
            idx = slot_to_tid[s]
            if fv is not None and hasattr(fv, "shape"):
                tensor_meta[idx] = to_tensor_meta(fv)

    for name, t in self.extra_constants.items():
        slot = self.slot_manager.get_slot(name)
        assert slot is not None and isinstance(slot, Slot)
        if slot in used_slots:
            idx = slot_to_tid[slot]
            tensor_meta[idx] = to_tensor_meta(t)

    # Dense list over all non-temp tensor ids; ids without metadata yield None.
    num_non_temp_tensors = sum(num_tensors.values()) - num_tensors[IdSpace.Temp]
    return [tensor_meta.get(i) for i in range(num_non_temp_tensors)]
def _build_mlx_graph(self) -> MLXGraph:
    """Assemble the final MLXGraph from the processed node/slot state.

    Raises:
        ValueError: if any node was marked unsupported during processing.
    """
    # Check support
    for node, info in self.node_info.items():
        if not info.supported:
            raise ValueError(
                f"Found unsupported node: {node}\nReason: {info.unsupported_reason}"
            )

    # Collect slots and create mappings
    used_slots, num_tensors, num_values = self._collect_used_slots()
    slot_to_tid, slot_to_vid = self._create_slot_mappings(used_slots)

    # Store for use in get_constant_data() - needed to serialize in Tid order
    self._slot_to_final_tid = slot_to_tid

    # Build I/O maps and metadata
    input_map, output_map, mutable_buffer_map, named_slots = self._build_io_maps(
        used_slots, slot_to_tid, slot_to_vid
    )
    tensor_meta_list = self._build_tensor_meta(
        used_slots, slot_to_tid, slot_to_vid, num_tensors
    )

    # Compute final counts
    num_constant_tensors = num_tensors[IdSpace.Constant]
    num_temp_tensors = num_tensors[IdSpace.Temp]
    num_values_count = sum(num_values.values())

    # Single instruction chain serving as the main chain; no init chain (-1).
    return MLXGraph(
        version="1",
        num_constant_tensors=num_constant_tensors,
        num_input_tensors=num_tensors[IdSpace.Input],
        num_output_tensors=num_tensors[IdSpace.Output],
        num_mutable_buffer_tensors=num_tensors[IdSpace.MutableBuffer],
        num_temp_tensors=num_temp_tensors,
        num_values=num_values_count,
        instruction_chains=[InstructionChain(instructions=self._instrs)],
        main_chain_idx=0,
        init_chain_idx=-1,
        input_map=input_map,
        output_map=output_map,
        mutable_buffer_map=mutable_buffer_map,
        named_slots=named_slots,
        tensor_meta=tensor_meta_list,
    )

def get_named_data_store(self) -> NamedDataStore:
    """
    Get a NamedDataStore containing all constant tensors.

    Uses the unprefixed canonical-name → Slot mapping built by
    ``_build_io_maps()`` so that tensor lookups hit ``ep.state_dict`` /
    ``ep.constants`` / ``extra_constants`` (which all use unprefixed
    keys). The prefix is applied at the exit boundary — the
    ``NamedDataStore`` key — so it matches the FlatBuffer ``named_slots``.
    """
    named_data_store = NamedDataStore()

    # Sort by final TID for deterministic ordering
    entries = sorted(
        self._constant_name_to_slot.items(),
        key=lambda x: self._slot_to_final_tid.get(x[1], 0),
    )

    logger.debug(f"Adding {len(entries)} constants to NamedDataStore...")
    for canonical_name, _slot in entries:
        tensor = self._find_constant_tensor(canonical_name)
        if tensor is None:
            # Unresolvable constants are silently skipped (best-effort).
            continue

        t = tensor.detach().cpu().contiguous()
        named_data_store.add_named_data(
            key=self._prefix_key(canonical_name),
            data=t,
            alignment=16,
        )
    logger.debug("Done adding constants to NamedDataStore")

    return named_data_store

def get_mutable_buffer_names(self) -> List[str]:
    """
    Get the names of all mutable buffers in Tid order.

    Returns:
        List of mutable buffer names in the order they appear in mutable_buffer_map.
    """
    assert self._mlx_graph is not None, "Must call build() first"

    names = []
    for name, slot in self.slot_manager.name_to_slot.items():
        if isinstance(slot, tuple):
            continue
        if slot.id_space != IdSpace.MutableBuffer:
            continue
        if slot in self._slot_to_final_tid:
            names.append((name, self._slot_to_final_tid[slot]))

    # Sort by Tid and return just the names
    names.sort(key=lambda x: x[1])
    return [n for n, _ in names]
+ """ + assert self._mlx_graph is not None, "Must call build() first" + + names = [] + for name, slot in self.slot_manager.name_to_slot.items(): + if isinstance(slot, tuple): + continue + if slot.id_space != IdSpace.MutableBuffer: + continue + if slot in self._slot_to_final_tid: + names.append((name, self._slot_to_final_tid[slot])) + + # Sort by Tid and return just the names + names.sort(key=lambda x: x[1]) + return [n for n, _ in names] + + def _find_constant_tensor(self, name: str) -> Optional[torch.Tensor]: + """Find a constant tensor by name from various sources.""" + if name in self.ep.state_dict: + return self.ep.state_dict[name] + if name in self.ep.constants: + return self.ep.constants[name] + if name in self.extra_constants: + return self.extra_constants[name] + # Look up by target + for ispec in self.ep.graph_signature.input_specs: + if ispec.arg.name == name and ispec.target is not None: + if ispec.target in self.ep.state_dict: + return self.ep.state_dict[ispec.target] + if ispec.target in self.ep.constants: + return self.ep.constants[ispec.target] + return None diff --git a/backends/mlx/builder/slot_manager.py b/backends/mlx/builder/slot_manager.py new file mode 100644 index 00000000000..b1884a76a68 --- /dev/null +++ b/backends/mlx/builder/slot_manager.py @@ -0,0 +1,187 @@ +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
class IdType(Enum):
    """Kind of value a slot holds."""

    Tensor = auto()
    SymInt = auto()
    SymBool = auto()


class IdSpace(Enum):
    """Namespace a slot's index is allocated in."""

    Constant = auto()
    Input = auto()
    Output = auto()
    MutableBuffer = auto()
    Temp = auto()


@dataclass(frozen=True)
class Slot:
    """Immutable, hashable handle identifying a slot: (type, space, index)."""

    id_type: IdType
    id_space: IdSpace
    idx: Optional[int] = None


class IdManager:
    """Allocates dense integer ids, recycling returned ids before minting new ones."""

    def __init__(self):
        self.free: set[int] = set()
        self.next_new_id = 0

    def get_id(self):
        # Reuse a previously returned id when available; otherwise mint a
        # fresh one so the id range stays dense.
        if self.free:
            return self.free.pop()
        return self._bump()

    def _bump(self):
        # Mint the next never-used id.
        fresh = self.next_new_id
        self.next_new_id = fresh + 1
        return fresh

    def return_id(self, idx):
        # Make idx available for reuse by a future get_id() call.
        self.free.add(idx)
def set_slot(self, node_or_name: Union[Node, str], slot: Slot):
    """Register *slot* under a node's name (or a raw name).

    Re-setting to the identical slot is a no-op; setting to a different
    slot raises AssertionError.
    """
    if isinstance(node_or_name, Node):
        node_or_name = node_or_name.name
    # Allow setting a slot to the same value (e.g., for in-place ops like SLICE_UPDATE)
    existing = self.name_to_slot.get(node_or_name)
    if existing is not None:
        # If already set to the same slot, it's fine
        if existing == slot:
            return
        raise AssertionError(
            f"Slot for {node_or_name} already set to {existing}, trying to set to {slot}"
        )
    self.name_to_slot[node_or_name] = slot

def get_slot(
    self, node_or_name: Union[Node, str]
) -> Optional[Union[Tuple[Slot], Slot]]:
    """Return the slot (or tuple of slots) registered for a node/name, or None."""
    if isinstance(node_or_name, Node):
        node_or_name = node_or_name.name
    return self.name_to_slot.get(node_or_name, None)

def _val_to_idtype(self, v) -> IdType:
    """Map a fake value from node.meta['val'] to its IdType."""
    from torch._subclasses.fake_tensor import FakeTensor

    if isinstance(v, FakeTensor):
        return IdType.Tensor
    elif isinstance(v, torch.SymInt):
        return IdType.SymInt
    elif isinstance(v, torch.SymBool):
        return IdType.SymBool
    else:
        raise NotImplementedError(f"val_to_idtype: {v}")

def is_alive(self, slot: Slot) -> bool:
    """True if *slot*'s id has been allocated and not yet returned."""
    if slot.id_type == IdType.Tensor:
        manager = self.tid_managers[slot.id_space]
    else:
        manager = self.vid_managers[slot.id_space]
    idx = slot.idx
    # Never allocated, or currently sitting in the free pool -> dead.
    if idx >= manager.next_new_id:
        return False
    if idx in manager.free:
        return False
    return True

def make_constant_slot(self, name: str) -> Slot:
    """Allocate and register a new Constant tensor slot under *name*."""
    assert name not in self.name_to_slot
    id_space = IdSpace.Constant
    manager = self.tid_managers[id_space]
    idx = manager.get_id()
    slot = Slot(id_type=IdType.Tensor, id_space=id_space, idx=idx)
    self.name_to_slot[name] = slot
    return slot

def make_tmp_slot(self) -> Tuple[str, Slot]:
    """Create a temporary tensor slot under a unique generated name."""
    name = f"tmp_{uuid.uuid4().hex}"
    id_space = IdSpace.Temp
    manager = self.tid_managers[id_space]
    idx = manager.get_id()
    slot = Slot(id_type=IdType.Tensor, id_space=id_space, idx=idx)
    self.name_to_slot[name] = slot
    return name, slot

def make_tmp_value_slot(self) -> Tuple[str, Slot]:
    """Create a temporary SymInt slot and register it."""
    name = f"tmp_val_{uuid.uuid4().hex}"
    id_space = IdSpace.Temp
    manager = self.vid_managers[id_space]
    idx = manager.get_id()
    slot = Slot(id_type=IdType.SymInt, id_space=id_space, idx=idx)
    self.name_to_slot[name] = slot
    return name, slot

def make_or_get_slots(
    self, node: Node, id_space: IdSpace = IdSpace.Temp
) -> Tuple[Slot, ...]:
    """
    Get or create slots for a node. Always returns a tuple of slots.

    Use this for multi-output ops (e.g., topk returns (values, indices)).
    For single-output ops, prefer make_or_get_slot() which returns a single Slot.
    """
    if node.name in self.name_to_slot:
        slot = self.name_to_slot[node.name]
        # Normalize to tuple for consistent return type
        if not isinstance(slot, tuple):
            return (slot,)
        return slot

    # Slot kinds are derived from the node's fake value(s) in meta["val"].
    val = node.meta.get("val", None)
    assert val is not None, f"Node {node} has no val"
    if not isinstance(val, (list, tuple)):
        val = (val,)

    slots = []
    for v in val:
        id_type = self._val_to_idtype(v)
        if id_type == IdType.Tensor:
            manager = self.tid_managers[id_space]
        else:
            manager = self.vid_managers[id_space]
        idx = manager.get_id()
        slots.append(Slot(id_type=id_type, id_space=id_space, idx=idx))
    slots = tuple(slots)

    # Store in the format that matches the node's output structure
    if len(slots) == 1:
        self.set_slot(node, slots[0])
    else:
        self.set_slot(node, slots)
    return slots

def make_or_get_slot(self, node: Node, id_space: IdSpace = IdSpace.Temp) -> Slot:
    """
    Get or create a slot for a single-output node. Returns a single Slot.

    Use this for single-output ops (the common case).
    For multi-output ops, use make_or_get_slots() instead.
    """
    slots = self.make_or_get_slots(node, id_space)
    assert len(slots) == 1, (
        f"Expected single output for node {node.name}, got {len(slots)}. "
        f"Use make_or_get_slots() for multi-output ops."
    )
    return slots[0]
+""" diff --git a/backends/mlx/ops.py b/backends/mlx/ops.py new file mode 100644 index 00000000000..6e8516e86b1 --- /dev/null +++ b/backends/mlx/ops.py @@ -0,0 +1,294 @@ +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# + +""" +MLX Op Handlers - registered handlers for converting ATen/custom ops to MLX. + +This module contains all the op handler functions registered with the MLXOpRegistry. +Each handler converts a specific PyTorch operation to the corresponding MLX graph node. +""" + +from __future__ import annotations + +import operator +from typing import Any, Dict, List, Optional, Set, Tuple, Union + +import torch +from executorch.backends.mlx.builder.op_registry import REGISTRY +from executorch.backends.mlx.builder.program_builder import MLXProgramBuilder +from executorch.backends.mlx.builder.slot_manager import Slot +from executorch.backends.mlx.serialization.mlx_graph_schema import AddmmNode +from torch.fx.node import Node + + +def require_static_int(value: Any, param_name: str, op_name: str) -> None: + """ + Validate that a parameter is a static integer (not a Slot/SymInt). + + Raises NotImplementedError if the value is dynamic. + + Args: + value: The parameter value to check + param_name: Name of the parameter (for error message) + op_name: Name of the operation (for error message) + """ + if isinstance(value, Slot) or not isinstance(value, int): + raise NotImplementedError( + f"{op_name} with dynamic {param_name} is not supported. " + f"{param_name} requires a static int32 value, but got {value} (type={type(value).__name__})." + ) + + +def require_static_float(value: Any, param_name: str, op_name: str) -> None: + """ + Validate that a parameter is a static float (not a Slot/SymFloat). + + Raises NotImplementedError if the value is dynamic. 
def require_static_ints(
    values: Union[List[Any], Any], param_name: str, op_name: str
) -> None:
    """
    Validate that every element of *values* is a static integer.

    A bare (non-list) value is treated as a one-element list. The
    per-element check is delegated to require_static_int, which raises
    NotImplementedError on dynamic (Slot/SymInt) values.

    Args:
        values: List of values to check, or a single value
        param_name: Name of the parameter (for error message)
        op_name: Name of the operation (for error message)
    """
    items = values if isinstance(values, list) else [values]
    for item in items:
        require_static_int(item, param_name, op_name)


def require_args(
    args: List[Any],
    min_count: int,
    max_count: int,
    op_name: str,
) -> None:
    """
    Validate that the number of args lies within [min_count, max_count].

    Raises ValueError otherwise; the message is tailored to whether the
    expected count is exact or a range.

    Args:
        args: The handler args list
        min_count: Minimum number of args expected
        max_count: Maximum number of args expected
        op_name: Name of the operation (for error message)
    """
    count = len(args)
    if min_count <= count <= max_count:
        return
    if min_count == max_count:
        raise ValueError(f"{op_name}: expected {min_count} args, got {count}")
    raise ValueError(
        f"{op_name}: expected {min_count}-{max_count} args, got {count}"
    )
def require_contiguous_format(
    *,
    layout=None,
    memory_format=None,
    dim_order=None,
    op_name: str,
) -> None:
    """
    Ensure layout/memory_format/dim_order all describe a contiguous tensor.

    MLX only executes contiguous (strided) tensors, so a sparse layout, a
    non-contiguous memory format, or a permuted dim_order is rejected with
    ValueError. Arguments left as None are simply not checked.

    Args:
        layout: The torch layout (e.g., torch.strided, torch.sparse_coo)
        memory_format: The torch memory format (e.g., torch.contiguous_format,
            torch.channels_last)
        dim_order: The dimension order (list of ints, identity = contiguous)
        op_name: Name of the operation (for error message)
    """
    if layout is not None:
        if layout != torch.strided:
            raise ValueError(
                f"{op_name}: only strided layout supported, got {layout}"
            )

    if memory_format is not None:
        acceptable = (torch.contiguous_format, torch.preserve_format)
        if memory_format not in acceptable:
            raise ValueError(
                f"{op_name}: only contiguous memory format supported, got {memory_format}"
            )

    if dim_order is not None:
        identity = list(range(len(dim_order)))
        if list(dim_order) != identity:
            raise ValueError(
                f"{op_name}: only contiguous dim_order supported, got {dim_order}"
            )


def is_static_value(value: Any) -> bool:
    """
    Check if a value is static (not a Slot/SymInt).

    Returns:
        True if the value is a static scalar (int, float, bool), False otherwise
    """
    return not isinstance(value, Slot)
+ """ + return { + user.args[1] + for user in n.users + if user.target == operator.getitem and len(user.users) > 0 + } + + +def normalize_reduction_dim( + args: List[Any], start_idx: int = 1 +) -> Tuple[Optional[List[int]], bool]: + """ + Normalize dim argument for reduction operations. + + Extracts and normalizes the dim argument from handler args, returning a list of axes + and the keepdim flag. Handles both list-based dims (e.g., sum.dim_IntList) and + single int dims (e.g., prod.dim_int). + + Args: + args: The handler args list + start_idx: Index where the dim argument starts (default 1, after self) + + Returns: + Tuple of (axes, keepdim) where: + - axes: List of dimension indices, or empty list for reduce-all + - keepdim: Boolean keepdim flag (default False) + """ + if len(args) > start_idx and isinstance(args[start_idx], (list, tuple)): + dim = list(args[start_idx]) + keepdim = args[start_idx + 1] if len(args) > start_idx + 1 else False + elif len(args) > start_idx and isinstance(args[start_idx], int): + dim = [args[start_idx]] + keepdim = args[start_idx + 1] if len(args) > start_idx + 1 else False + else: + dim = [] + keepdim = False + + return dim, keepdim + + +@REGISTRY.register(target=[torch.ops.aten.addmm.default]) +def _addmm_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle addmm: self + (mat1 @ mat2). + + addmm(self, mat1, mat2, *, beta=1, alpha=1) computes: + beta * self + alpha * (mat1 @ mat2) + + This is typically the result of decomposing linear(x, w, b) in Edge IR: + permute(w) -> addmm(b, x, permuted_w) + + For the common case where beta=1 and alpha=1, this is equivalent to: + mat1 @ mat2 + self + + We use AddmmNode which calls matmul directly (no transposition needed). 
+ """ + args = P.args(n) + kwargs = P.kwargs(n) + require_args(args, 3, 3, "aten.addmm") + require_kwargs(kwargs, {"beta", "alpha"}, "aten.addmm") + bias, mat1, mat2 = args[0], args[1], args[2] + + beta = kwargs.get("beta", 1) + alpha = kwargs.get("alpha", 1) + + out = P.make_or_get_slot(n) + + # Emit AddmmNode with alpha and beta parameters + P.emit( + AddmmNode( + mat1=P.slot_to_tid(mat1), + mat2=P.slot_to_tid(mat2), + out=P.slot_to_tid(out), + bias=P.slot_to_tid(bias), + alpha=float(alpha), + beta=float(beta), + ) + ) + return out + + +@REGISTRY.register( + target=[ + torch.ops.aten.mm.default, + torch.ops.aten.bmm.default, + torch.ops.aten.matmul.default, + ] +) +def _mm_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle mm/bmm/matmul: matrix multiplication without bias. + + All three ops compute matrix products with different dimension expectations: + - mm: 2D x 2D + - bmm: 3D x 3D (batched) + - matmul: arbitrary dimensions (NumPy semantics) + + MLX's matmul handles all cases, so we emit AddmmNode with bias=None. + """ + args = P.args(n) + require_args(args, 2, 2, "aten.mm/bmm/matmul") + require_kwargs(P.kwargs(n), set(), "aten.mm/bmm/matmul") + mat1, mat2 = args[0], args[1] + + out = P.make_or_get_slot(n) + + P.emit( + AddmmNode( + mat1=P.slot_to_tid(mat1), + mat2=P.slot_to_tid(mat2), + out=P.slot_to_tid(out), + bias=None, + ) + ) + return out diff --git a/backends/mlx/partitioner.py b/backends/mlx/partitioner.py new file mode 100644 index 00000000000..0896cafc301 --- /dev/null +++ b/backends/mlx/partitioner.py @@ -0,0 +1,298 @@ +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# + +""" +MLX Partitioner - decides which ops should run on the MLX delegate. 
+ +This module provides a Partitioner implementation that analyzes an EdgeIR +graph and marks supported operations for delegation to MLX. +""" + +from __future__ import annotations + +import inspect +from typing import Any, Callable, Dict, List, Tuple, Union + +import torch +from executorch.backends.mlx._logging import logger +from executorch.backends.mlx.preprocess import MLXBackend +from executorch.exir.backend.backend_details import CompileSpec +from executorch.exir.backend.canonical_partitioners.pattern_op_partitioner import ( + generate_partitions_from_list_of_nodes, +) +from executorch.exir.backend.partitioner import ( + DelegationSpec, + Partitioner, + PartitionResult, +) +from executorch.exir.backend.utils import tag_constant_data, tag_mutated_buffer +from torch.export.exported_program import ExportedProgram +from torch.fx.passes.infra.partitioner import Partition +from torch.fx.passes.operator_support import OperatorSupportBase + + +class MLXOperatorSupport(OperatorSupportBase): + """ + Determines which operators are supported by the MLX delegate. + + Uses MLXProgramBuilder to determine support - this ensures the partitioner + uses the exact same logic as the actual compilation. A node is supported + if the builder can handle it (either via direct handler or pattern match). 
+ """ + + def __init__( + self, + edge_program: torch.export.ExportedProgram, + compile_specs: List[CompileSpec], + ): + self.edge_program = edge_program + self.compile_specs = compile_specs + + # Run the builder to determine which nodes are supported + # The builder populates node_info with supported/unsupported status + from executorch.backends.mlx.builder.program_builder import MLXProgramBuilder + + self._builder = MLXProgramBuilder(edge_program) + self._builder.check_support_only() + + def is_node_supported(self, submodules, node: torch.fx.Node) -> bool: + if node.op != "call_function": + return False + + # Check if builder determined this node is supported + info = self._builder.node_info.get(node) + if info is not None and info.supported: + logger.debug(f"[SUPPORTED] Node {node.target}") + return True + + logger.debug(f"[UNSUPPORTED] Node {node.target}") + return False + + +class MLXPartitioner(Partitioner): + """ + Partitioner for the MLX delegate. + + Analyzes an EdgeIR graph and partitions supported operations + for delegation to MLX. + """ + + def __init__(self, compile_specs: List[CompileSpec] | None = None) -> None: + self.compile_specs = compile_specs or [] + self.delegation_spec = DelegationSpec(MLXBackend.__name__, self.compile_specs) + self.partition_tags: Dict[str, DelegationSpec] = {} + + def ops_to_not_decompose( + self, ep: ExportedProgram + ) -> tuple[list[torch._ops.OpOverload], Callable[[torch.fx.Node], bool] | None]: + """ + Return ops that should NOT be decomposed during edge lowering. + + This runs the MLXProgramBuilder to trace through the graph and determine + which nodes are supported (either via direct handlers or patterns). + Only ops for nodes that are actually supported should be preserved. + + This is called by to_edge_transform_and_lower to determine which + ops to preserve before partitioning. + + NOTE: We use check_support_only() instead of build() to avoid corrupting + the shape_env. 
build() calls _build_mlx_graph() which evaluates SymInts + to concrete values when converting tensor shapes, which corrupts the + shape_env and causes dynamic shapes to be lost during decomposition. + """ + from executorch.backends.mlx.builder.program_builder import MLXProgramBuilder + + # Check if the graph already contains lowered modules (post-partitioning pass) + # In this case, we should return empty since partitioning is already done + for node in ep.graph.nodes: + if node.op == "get_attr" and "lowered_module" in node.name: + logger.debug( + "MLX ops_to_not_decompose: Graph already partitioned, returning empty" + ) + return ([], None) + + # Run the builder to determine which nodes are supported + builder = MLXProgramBuilder(ep) + builder.check_support_only() + + # Collect ops for nodes that are actually supported + do_not_decompose: list[torch._ops.OpOverload] = [] + + for node in ep.graph.nodes: + if node.op == "call_function" and isinstance( + node.target, torch._ops.OpOverload + ): + info = builder.node_info.get(node) + if info is not None and info.supported: + if node.target not in do_not_decompose: + do_not_decompose.append(node.target) + + logger.debug( + f"MLX ops_to_not_decompose: {[str(op) for op in do_not_decompose]}" + ) + return (do_not_decompose, None) + + def generate_partitions(self, edge_program: ExportedProgram) -> List[Any]: + """Generate partitions of supported nodes.""" + self.supported_ops = MLXOperatorSupport( + edge_program=edge_program, + compile_specs=self.delegation_spec.compile_specs, + ) + + # Collect unsupported ops, aggregated by target + unsupported_by_target: Dict[str, Tuple[int, str]] = ( + {} + ) # target -> (count, reason) + for node in edge_program.graph.nodes: + is_supported = self.supported_ops.is_node_supported({}, node) + if not is_supported and node.op == "call_function": + target_str = str(node.target) + info = self.supported_ops._builder.node_info.get(node) + reason = info.unsupported_reason if info else "No handler 
registered" + if target_str in unsupported_by_target: + count, _ = unsupported_by_target[target_str] + unsupported_by_target[target_str] = (count + 1, reason) + else: + unsupported_by_target[target_str] = (1, reason) + + logger.info("=" * 80) + logger.info("MLX Partitioner: UNSUPPORTED OPS SUMMARY") + logger.info("=" * 80) + if unsupported_by_target: + for target, (count, reason) in unsupported_by_target.items(): + logger.info(f" [UNSUPPORTED x{count}] {target}") + logger.info(f" Reason: {reason}") + else: + logger.info(" (All call_function nodes are supported!)") + logger.info("=" * 80) + + partitions = generate_partitions_from_list_of_nodes( + edge_program.graph_module, + op_support=self.supported_ops, + ) + + # WORKAROUND: Include sym_size nodes in partitions when any of their + # users are in the partition. Without this, sym_size nodes stay outside + # the partition and their results cross the partition boundary as concrete + # inputs, losing dynamic shape information during delegate lowering. + # By pulling them inside, the MLX runtime can execute SYM_SIZE at runtime, + # keeping shapes dynamic. + partitions = self._include_sym_size_nodes_in_partitions( + edge_program.graph_module, partitions + ) + + return partitions + + def _include_sym_size_nodes_in_partitions( + self, gm: torch.fx.GraphModule, partitions: List[Partition] + ) -> List[Partition]: + """ + Include sym_size nodes in partitions when any of their users are in the partition. + + This is a workaround for the dynamic shapes bug where symbolic shapes are lost + during delegate lowering if the sym_size node is not included in the partition. 
+ """ + from executorch.exir.dialects.edge._ops import EdgeOpOverload + + for partition in partitions: + partition_nodes = set(partition.nodes) + nodes_to_add = [] + + for node in gm.graph.nodes: + if node.op != "call_function": + continue + + # Check if this is a sym_size node + target = node.target + if isinstance(target, EdgeOpOverload): + target = target._op + + if target != torch.ops.aten.sym_size.int: + continue + + # Check if any user of this sym_size node is in the partition + for user in node.users: + if user in partition_nodes: + # Add sym_size to partition if not already there + if node not in partition_nodes: + nodes_to_add.append(node) + logger.debug( + f"Adding sym_size node {node.name} to partition " + f"(used by {user.name})" + ) + break + + # Add the sym_size nodes to the partition + for node in nodes_to_add: + partition.add_node(node) + + return partitions + + def tag_nodes(self, partitions: List[Partition]) -> None: + """Tag nodes in each partition for delegation.""" + for partition in partitions: + delegation_tag = f"mlx_{partition.id}" + for node in partition.nodes: + node.meta["delegation_tag"] = delegation_tag + self.partition_tags[delegation_tag] = self.delegation_spec + + @staticmethod + def check_partitions(partitions: Union[dict, list]) -> bool: + """Check if any partitions were found.""" + pl = len(partitions) + if pl == 0: + logger.warning("MLX: Nothing can be partitioned!") + else: + logger.info(f"MLX: Found {pl} subgraphs to be partitioned.") + return pl != 0 + + @staticmethod + def _is_to_edge_transform_and_lower() -> bool: + """Check whether we are being called from to_edge_transform_and_lower.""" + for frame_info in inspect.stack(): + if frame_info.function == "to_edge_transform_and_lower": + return True + return False + + def partition(self, edge_program: ExportedProgram) -> PartitionResult: + """ + Partition the edge program for MLX delegation. + + Args: + edge_program: The ExportedProgram to partition. 
+ + Returns: + PartitionResult with tagged nodes and partition specs. + + Raises: + RuntimeError: If called from the deprecated ``to_edge`` workflow. + """ + if not self._is_to_edge_transform_and_lower(): + raise RuntimeError( + "MLXPartitioner must be used with to_edge_transform_and_lower(). " + "The to_edge() + to_backend() workflow is not supported because " + "it decomposes ops that MLX has optimized implementations for. " + "Please use:\n" + " exir.to_edge_transform_and_lower(\n" + ' {"forward": exported_program},\n' + " partitioner=[MLXPartitioner()],\n" + " )" + ) + partitions = self.generate_partitions(edge_program=edge_program) + if self.check_partitions(partitions): + self.tag_nodes(partitions) + # Tag constant data that are used by the supported ops + tag_constant_data(edge_program) + # Tag mutated buffers so they are included in the partition + # This ensures the partitioned subgraph has proper mutation tracking + tag_mutated_buffer(edge_program) + + return PartitionResult( + tagged_exported_program=edge_program, + partition_tags=self.partition_tags, + ) diff --git a/backends/mlx/passes.py b/backends/mlx/passes.py new file mode 100644 index 00000000000..c7efdf561de --- /dev/null +++ b/backends/mlx/passes.py @@ -0,0 +1,20 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Graph transformation passes for the MLX backend. +""" + +from typing import List + +from executorch.exir.pass_base import ExportPass + + +def get_default_passes() -> List[ExportPass]: + """ + Returns a list of passes that are enabled by default for the MLX backend. 
+ """ + return [] diff --git a/backends/mlx/patches/mlx_json.patch b/backends/mlx/patches/mlx_json.patch new file mode 100644 index 00000000000..4760403c8e6 --- /dev/null +++ b/backends/mlx/patches/mlx_json.patch @@ -0,0 +1,29 @@ +diff --git a/CMakeLists.txt b/CMakeLists.txt +--- a/CMakeLists.txt ++++ b/CMakeLists.txt +@@ -304,12 +304,18 @@ else() + set(MLX_BUILD_ACCELERATE OFF) + endif() + +-message(STATUS "Downloading json") +-FetchContent_Declare( +- json +- URL https://github.com/nlohmann/json/releases/download/v3.11.3/json.tar.xz) +-FetchContent_MakeAvailable(json) +-target_include_directories( +- mlx PRIVATE $) ++# Only fetch json if nlohmann_json target doesn't already exist ++# (ExecuTorch provides its own copy) ++if(NOT TARGET nlohmann_json) ++ message(STATUS "Downloading json") ++ FetchContent_Declare( ++ json ++ URL https://github.com/nlohmann/json/releases/download/v3.11.3/json.tar.xz) ++ FetchContent_MakeAvailable(json) ++ target_include_directories( ++ mlx PRIVATE $) ++else() ++ message(STATUS "Using existing nlohmann_json target") ++endif() + + add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/mlx) diff --git a/backends/mlx/pattern_utils.py b/backends/mlx/pattern_utils.py new file mode 100644 index 00000000000..0d3d86430eb --- /dev/null +++ b/backends/mlx/pattern_utils.py @@ -0,0 +1,360 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Shared pattern matching utilities for MLX backend. 
+ +This module provides common utilities used by both: +- passes.py: Graph transformation passes (ExportPass) +- patterns.py: MLX lowering pattern handlers (PatternHandler) + +The core abstraction is the `PatternMatch` base class which provides: +- `maybe_create(head)` - Class method to match a pattern from a head node +- Captured values as typed fields +- `body` list of intermediate nodes to remove + +Usage in passes.py: + class FuseRMSNormPass(ExportPass): + def call(self, graph_module): + for node in graph.nodes: + if match := RMSNormMatch.maybe_create(node): + replacement = self._emit_fused_op(graph, match) + node.replace_all_uses_with(replacement) + match.remove_body_nodes(graph) + +Usage in patterns.py: + class RMSNormHandler(PatternHandler): + @classmethod + def maybe_create(cls, ep, head): + if match := RMSNormMatch.maybe_create(head): + return cls(head, match.body, match.input, match.weight, match.eps) + return None +""" + +from dataclasses import dataclass, field +from typing import Any, Callable, List, Optional, Set, Tuple, Union + +from executorch.backends.mlx.builder.op_helpers import get_aten_target_normalized +from torch.fx import Graph +from torch.fx.node import Node + + +# Type alias for walk_back result entries +# Each entry corresponds to an OpStep: +# - Node: matched node (for regular steps) +# - None: optional step that didn't match +# - List[Node]: repeat step (0 or more matches) +WalkBackEntry = Union[Node, None, List[Node]] + + +def match_target(node: Node, op: Any) -> bool: + """ + Check if a node's normalized aten target matches the given op. + + Uses get_aten_target_normalized to handle edge dialect ops. + This means slice_copy matches slice, etc. 
def has_single_user(node: Node) -> bool:
    """Return True when exactly one node consumes *node*'s output."""
    return len(node.users) == 1


def has_no_users(node: Node) -> bool:
    """Return True when nothing consumes *node*'s output."""
    return not node.users


def extract_lifted_tensor_constant(node: Node) -> Optional[float]:
    """
    Pull the scalar value out of a lifted tensor constant node.

    torch.export lifts small constant tensors (such as epsilon values) into
    nodes whose names contain "lifted_tensor_constant"; the tensor itself is
    stored in node.meta["val"].

    Args:
        node: A node that may be a lifted tensor constant

    Returns:
        The scalar float value, or None if not a lifted constant or not scalar
    """
    if not isinstance(node, Node):
        return None
    if "lifted_tensor_constant" not in node.name:
        return None
    tensor = node.meta.get("val")
    if tensor is None or not hasattr(tensor, "item"):
        return None
    try:
        return float(tensor.item())
    except (RuntimeError, ValueError):
        # .item() fails on non-scalar tensors; treat as "not a scalar const".
        return None
Can be: + - int: minimum number of args (default 1, since we advance via args[0]) + - tuple (min, max): range of args required (inclusive) + kwargs: Set of kwargs we handle (node's kwargs must be subset of this) + arg_index: Which arg to follow when advancing (default 0) + + Examples: + # Match specific op + OpStep(op=torch.ops.aten.rsqrt.default) + + # Match with predicate (for matching families of ops) + OpStep(predicate=lambda n: match_target(n, torch.ops.aten.select.int)) + + # Match chain of same op type (0 or more) + OpStep(op=torch.ops.aten.select.int, repeat=True) + + # Optional dtype conversion + OpStep(op=torch.ops.aten._to_copy.default, optional=True) + + # Require between 2 and 4 args + OpStep(op=torch.ops.aten.some_op.default, nargs=(2, 4)) + + # Declare that we handle 'dtype' kwarg + OpStep(op=torch.ops.aten._to_copy.default, kwargs={"dtype"}) + + # Follow second arg (e.g., mul(x, rsqrt(y)) -> follow rsqrt in args[1]) + OpStep(op=torch.ops.aten.mul.Tensor, arg_index=1) + """ + + op: Any = None + predicate: Optional[Callable[[Node], bool]] = None + optional: bool = False + repeat: bool = False + require_single_user: bool = True + nargs: Union[int, Tuple[int, int]] = 1 + kwargs: Set[str] = field(default_factory=set) # Empty = no kwargs allowed + arg_index: int = 0 + + def matches(self, node: Node) -> bool: + """Check if this step fully matches the given node.""" + # Check op or predicate + if self.op is not None: + if not match_target(node, self.op): + return False + elif self.predicate is not None: + if not self.predicate(node): + return False + else: + return False + + # Check single user requirement + if self.require_single_user and not has_single_user(node): + return False + + # Check nargs and kwargs + if not self._check_nargs(node): + return False + if not self._check_kwargs(node): + return False + + return True + + def _check_nargs(self, node: Node) -> bool: + """Check if node has the required number of args.""" + n = len(node.args) + if 
isinstance(self.nargs, tuple): + min_args, max_args = self.nargs + # Must be in range AND enough to access arg_index + return min_args <= n <= max_args and n > self.arg_index + else: + # Must have at least nargs, AND enough to access arg_index + return n >= self.nargs and n > self.arg_index + + def _check_kwargs(self, node: Node) -> bool: + """Check that node's kwargs are all declared in self.kwargs (no unhandled kwargs).""" + return set(node.kwargs.keys()).issubset(self.kwargs) + + +def walk_back( # noqa: C901 + node: Node, + steps: List[OpStep], + debug: bool = False, +) -> Optional[Tuple[Node, List[WalkBackEntry]]]: + """ + Walk backwards through a chain of ops, matching against a pattern. + + Starting from *node*, try to match each step against the current node. + At every matched step the walk advances to ``cur.args[step.arg_index]``. + Optional steps are silently skipped when they don't match. Repeat steps + match 0 or more times. + + Args: + node: Starting node + steps: List of OpStep to match in order + + Returns: + ``(base_node, entries)`` if the full chain matches, else ``None``. + *base_node* is the input to the first (deepest) op in the chain. 
+ *entries* is a list with one entry per OpStep: + - Node: matched node (for regular steps) + - None: optional step that didn't match + - List[Node]: repeat step (0 or more matches) + + Examples: + # Match: rsqrt(add(mean(pow(x, 2)), eps)) + result = walk_back(rsqrt_node, [ + OpStep(op=torch.ops.aten.rsqrt.default), + OpStep(op=torch.ops.aten.add.Tensor), + OpStep(op=torch.ops.aten.mean.dim), + OpStep(op=torch.ops.aten.pow.Tensor_Scalar), + ]) + if result: + base, entries = result + rsqrt, add, mean, pow = entries # Each is a Node + + # Match chain of select ops (like tensor[0][0]) + result = walk_back(node, [ + OpStep(op=torch.ops.aten.select.int, repeat=True), + ]) + if result: + base, entries = result + select_nodes = entries[0] # List[Node], may be empty + + # Skip optional _to_copy, then match rsqrt + result = walk_back(node, [ + OpStep(op=torch.ops.aten._to_copy.default, optional=True), + OpStep(op=torch.ops.aten.rsqrt.default), + ]) + if result: + base, entries = result + to_copy, rsqrt = entries # to_copy may be None + """ + entries: List[WalkBackEntry] = [] + cur = node + + for i, step in enumerate(steps): + if not isinstance(cur, Node): + if debug: + print( + f" [walk_back] step {i}: cur is not a Node ({type(cur).__name__})" + ) + return None + + if step.repeat: + # Match 0 or more times, return as list + matched_nodes: List[Node] = [] + while isinstance(cur, Node) and step.matches(cur): + matched_nodes.append(cur) + cur = cur.args[step.arg_index] + entries.append(matched_nodes) + if debug: + print( + f" [walk_back] step {i} (repeat): matched {len(matched_nodes)} nodes" + ) + # repeat always succeeds (matches 0 or more) + continue + + if step.matches(cur): + entries.append(cur) + if debug: + print(f" [walk_back] step {i}: matched {cur.name}") + cur = cur.args[step.arg_index] + elif step.optional: + entries.append(None) + if debug: + print(f" [walk_back] step {i} (optional): skipped, cur={cur.name}") + continue + else: + if debug: + print( + f" [walk_back] 
step {i}: FAILED at cur={cur.name}, target={cur.target}, step.op={step.op}" + ) + return None + + if not isinstance(cur, Node): + return None + + return cur, entries + + +@dataclass +class PatternMatch: + """ + Base class for pattern match results. + + Subclasses should: + 1. Add fields for captured values (input nodes, constants, etc.) + 2. Implement maybe_create() classmethod for pattern matching + 3. Optionally implement emit_* methods for specific backends + + Example: + @dataclass + class RMSNormMatch(PatternMatch): + input_node: Node + weight_node: Node + eps: float + + @classmethod + def maybe_create(cls, head: Node) -> Optional["RMSNormMatch"]: + # Pattern matching logic... + if not matched: + return None + return cls( + head=head, + body=body_nodes, + input_node=input_node, + weight_node=weight_node, + eps=eps_value, + ) + """ + + head: Node # The output node of the matched pattern + body: List[Node] = field(default_factory=list) # Intermediate nodes + + @classmethod + def maybe_create(cls, head: Node, **context) -> Optional["PatternMatch"]: + """ + Try to match the pattern starting from head node. + + Override in subclasses to implement pattern-specific matching. + + Args: + head: Candidate head node to match from + **context: Additional context (e.g., ExportedProgram for patterns.py) + + Returns: + PatternMatch instance with captured values, or None if no match + """ + return None + + def remove_body_nodes(self, graph: Graph) -> None: + """ + Remove body nodes from the graph (in reverse order for safety). + + Call after replacing head with fused op. 
+ """ + for node in reversed(self.body): + if has_no_users(node): + graph.erase_node(node) + + def all_nodes(self) -> List[Node]: + """Return all nodes in the pattern (head + body).""" + return [self.head] + self.body diff --git a/backends/mlx/patterns.py b/backends/mlx/patterns.py new file mode 100644 index 00000000000..c8bef1f91ca --- /dev/null +++ b/backends/mlx/patterns.py @@ -0,0 +1,14 @@ +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# + +""" +MLX Pattern Handlers - pattern-based op lowering for fused operations. + +This module contains pattern handlers that match multi-node subgraphs and lower +them to optimized MLX operations. +""" diff --git a/backends/mlx/preprocess.py b/backends/mlx/preprocess.py new file mode 100644 index 00000000000..315835f1689 --- /dev/null +++ b/backends/mlx/preprocess.py @@ -0,0 +1,168 @@ +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# + +""" +MLX Backend preprocessing - converts EdgeIR to MLX delegate payload. + +This module implements the BackendDetails.preprocess() method which: +1. Takes an ExportedProgram (edge dialect) +2. Builds an MLXGraph using MLXProgramBuilder +3. Serializes to FlatBuffer (no embedded constants - those come via named_data_map) +4. 
"""
MLX Backend preprocessing - converts EdgeIR to MLX delegate payload.

This module implements the BackendDetails.preprocess() method which:
1. Takes an ExportedProgram (edge dialect)
2. Builds an MLXGraph using MLXProgramBuilder
3. Serializes to FlatBuffer (no embedded constants - those come via named_data_map)
4. Returns PreprocessResult with the binary and data_store_output for constants
"""

from __future__ import annotations

import hashlib
from typing import ClassVar, final, List

from executorch.backends.mlx._logging import logger
from executorch.backends.mlx.builder.program_builder import MLXProgramBuilder
from executorch.backends.mlx.serialization.mlx_graph_serialize import (
    HEADER_LENGTH,
    MAGIC,
    serialize_mlx_graph,
)
from executorch.exir.backend.backend_details import (
    BackendDetails,
    CompileSpec,
    PreprocessResult,
)
from torch.export.exported_program import ExportedProgram


@final
class MLXBackend(BackendDetails):
    """
    ExecuTorch backend for MLX (Apple Silicon GPU compute framework).

    This backend compiles EdgeIR programs to a custom bytecode format
    that can be executed by the MLX C++ runtime.

    Constants (weights) are stored in ExecuTorch's named_data_map rather than
    embedded in the delegate payload. This allows ExecuTorch to own the constant
    data and provide it to the backend at runtime.
    """

    # Byte ranges into the serialized MLX payload header.
    # NOTE(review): presumably these mirror the layout emitted by
    # serialize_mlx_graph ([4B padding][4B magic][8B offset][8B size]) —
    # confirm against the serializer.
    MAGIC_IX: ClassVar[slice] = slice(4, 8)
    DATA_SEGMENT_OFFSET_IX: ClassVar[slice] = slice(8, 16)
    DATA_SEGMENT_SIZE_IX: ClassVar[slice] = slice(16, 24)

    # Expected header magic / total header length, shared with the serializer.
    EXPECTED_MAGIC: ClassVar[bytes] = MAGIC
    EXPECTED_LENGTH: ClassVar[int] = HEADER_LENGTH

    @staticmethod
    def preprocess(
        edge_program: ExportedProgram,
        compile_specs: List[CompileSpec],
    ) -> PreprocessResult:
        """
        Convert an ExportedProgram to MLX delegate payload.

        Args:
            edge_program: The ExportedProgram in edge dialect to compile.
            compile_specs: List of compilation options.

        Returns:
            PreprocessResult containing the serialized MLX program and
            data_store_output with constant tensor data.
        """
        logger.debug("MLXBackend.preprocess() called")
        logger.debug(f"Edge program:\n{edge_program}")

        # Build MLXGraph from ExportedProgram
        # Use a deterministic 4-hex prefix derived from the edge program to
        # namespace named_data keys, avoiding collisions in multi-method
        # programs where different methods may have lifted tensor constants
        # with the same auto-generated name.
        prefix = hashlib.sha256(str(edge_program).encode()).hexdigest()[:4]
        builder = MLXProgramBuilder(edge_program, named_data_key_prefix=prefix)
        mlx_graph = builder.build()

        # Get constant data as NamedDataStore (ET will own this data)
        named_data_store = builder.get_named_data_store()

        logger.debug(f"  named_data_store entries: {len(named_data_store.pte_data)}")
        _log_mlx_graph(mlx_graph)

        # Serialize to bytes (no constant data embedded)
        serialized = serialize_mlx_graph(mlx_graph)

        logger.debug(f"MLXBackend.preprocess() complete: {len(serialized)} bytes")

        return PreprocessResult(
            processed_bytes=serialized,
            data_store_output=named_data_store.get_named_data_store_output(),
        )


def _format_tensor_meta(meta) -> str:
    """Format a TensorMeta for display.

    Renders the shape (with dynamic-dim markers), plus dtype and dim_order
    when present, e.g. ``[2, dyn(1..8)], dtype=..., dim_order=...``.
    """
    shape_parts = []
    for dim in meta.shape:
        if dim.value == -1:
            # Dynamic dim: value == -1 is the sentinel; min/max bound it.
            if dim.max_value == -1:
                # Unbounded above — only the minimum is known.
                shape_parts.append(f"dyn(min={dim.min_value})")
            else:
                shape_parts.append(f"dyn({dim.min_value}..{dim.max_value})")
        else:
            shape_parts.append(str(dim.value))
    shape_str = f"[{', '.join(shape_parts)}]"
    dtype_str = f"dtype={meta.scalar_type}" if meta.scalar_type is not None else ""
    dim_order_str = f"dim_order={meta.dim_order}" if meta.dim_order is not None else ""
    parts = [shape_str]
    if dtype_str:
        parts.append(dtype_str)
    if dim_order_str:
        parts.append(dim_order_str)
    return ", ".join(parts)


def _log_mlx_graph(mlx_graph) -> None:  # noqa: C901
    """Log MLXGraph contents at DEBUG level for debugging."""
    logger.debug("MLXGraph:")
    logger.debug(f"  version: {mlx_graph.version}")
    logger.debug(f"  num_constant_tensors: {mlx_graph.num_constant_tensors}")
    logger.debug(f"  num_input_tensors: {mlx_graph.num_input_tensors}")
    logger.debug(f"  num_output_tensors: {mlx_graph.num_output_tensors}")
    logger.debug(
        f"  num_mutable_buffer_tensors: {mlx_graph.num_mutable_buffer_tensors}"
    )
    logger.debug(f"  num_temp_tensors: {mlx_graph.num_temp_tensors}")
    logger.debug(f"  num_values: {mlx_graph.num_values}")
    logger.debug(f"  instruction_chains ({len(mlx_graph.instruction_chains)}):")
    for c, chain in enumerate(mlx_graph.instruction_chains):
        # Mark the main/init chains so the log is readable without the indices.
        label = ""
        if c == mlx_graph.main_chain_idx:
            label = " (main)"
        elif c == mlx_graph.init_chain_idx:
            label = " (init)"
        logger.debug(f"    chain {c}{label} ({len(chain.instructions)} instructions):")
        for i, instr in enumerate(chain.instructions):
            logger.debug(f"      [{i}]: {type(instr.op).__name__}")
    if mlx_graph.input_map:
        logger.debug(f"  input_map ({len(mlx_graph.input_map)}):")
        for i, slot in enumerate(mlx_graph.input_map):
            logger.debug(f"    [{i}]: {slot}")
    if mlx_graph.output_map:
        logger.debug(f"  output_map ({len(mlx_graph.output_map)}):")
        for i, slot in enumerate(mlx_graph.output_map):
            logger.debug(f"    [{i}]: {slot}")
    if mlx_graph.mutable_buffer_map:
        logger.debug(f"  mutable_buffer_map ({len(mlx_graph.mutable_buffer_map)}):")
        for i, slot in enumerate(mlx_graph.mutable_buffer_map):
            logger.debug(f"    [{i}]: {slot}")
    if mlx_graph.named_slots:
        logger.debug(f"  named_slots ({len(mlx_graph.named_slots)}):")
        for ns in mlx_graph.named_slots:
            logger.debug(f"    {ns.name}: {ns.slot}")
    if mlx_graph.tensor_meta:
        logger.debug(f"  tensor_meta ({len(mlx_graph.tensor_meta)}):")
        for i, meta in enumerate(mlx_graph.tensor_meta):
            logger.debug(f"    t{i}: {_format_tensor_meta(meta)}")
+1,897 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +PTE Inspector - Extract and dump data from ExecuTorch .pte files. + +This utility can: +1. Parse the PTE file structure (header, flatbuffer, segments) +2. Extract delegate payloads (e.g., MLX backend data) +3. Convert FlatBuffer data to JSON for inspection + +Usage: + python pte_inspector.py mlx_mlp.pte + python pte_inspector.py mlx_mlp.pte --output output.json + python pte_inspector.py mlx_mlp.pte --extract-delegate mlx --output mlx_payload.bin +""" + +from __future__ import annotations + +import argparse +import json +import sys +import traceback +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, Tuple + +from executorch.backends.mlx._generated_inspector import OP_NODE_FIELDS +from executorch.backends.mlx.serialization._generated_serializers import ( + MLX_OP_TYPE_NAMES, +) +from executorch.exir._serialize._program import ( + _ExtendedHeader, + _extract_delegate_payload as extract_delegate_payload, +) + +MLX_MAGIC = b"MLX0" +MLX_HEADER_LENGTH = 24 + +_SLOT_TYPE_NAMES = {0: "Tensor", 1: "Int", 2: "Float", 3: "Bool"} + + +@dataclass +class MLXHeader: + + magic: bytes + data_segment_offset: int + data_segment_size: int + + @classmethod + def from_bytes(cls, data: bytes) -> "MLXHeader": + if len(data) < MLX_HEADER_LENGTH: + raise ValueError( + f"Not enough data for MLX header: {len(data)} < {MLX_HEADER_LENGTH}" + ) + + # Layout: [4 bytes padding][4 bytes magic][8 bytes offset][8 bytes size] + magic = data[4:8] + data_segment_offset = int.from_bytes(data[8:16], byteorder="little") + data_segment_size = int.from_bytes(data[16:24], byteorder="little") + + return cls( + magic=magic, + data_segment_offset=data_segment_offset, + 
data_segment_size=data_segment_size, + ) + + def is_valid(self) -> bool: + return self.magic == MLX_MAGIC + + def to_dict(self) -> Dict[str, Any]: + return { + "magic": self.magic.decode("utf-8", errors="replace"), + "data_segment_offset": self.data_segment_offset, + "data_segment_size": self.data_segment_size, + } + + +@dataclass +class MLXPayload: + """Parsed MLX delegate payload: header + flatbuffer bytes.""" + + header: MLXHeader + fb_data: bytes + raw: bytes + + +def _load_mlx_payload(pte_data: bytes, delegate_index: int = 0) -> MLXPayload: + """Extract MLX delegate payload from PTE data and parse its header. + + Raises ``ValueError`` if the delegate cannot be found or the MLX header is + invalid. + """ + payload = extract_delegate_payload(pte_data, "mlx", delegate_index=delegate_index) + if payload is None: + raise ValueError(f"Could not extract MLX delegate {delegate_index}") + + header = MLXHeader.from_bytes(payload) + if not header.is_valid(): + raise ValueError(f"Invalid MLX magic: {header.magic!r}") + + fb_data = payload[MLX_HEADER_LENGTH : header.data_segment_offset] + return MLXPayload(header=header, fb_data=fb_data, raw=payload) + + +def _find_mlx_delegates(pte_data: bytes) -> List[Tuple[int, Dict]]: + """Return list of ``(plan_index, delegate_dict)`` for every MLX delegate.""" + from executorch.exir._serialize._flatbuffer import _program_flatbuffer_to_json + + program_data = json.loads(_program_flatbuffer_to_json(pte_data)) + delegates: List[Tuple[int, Dict]] = [] + for plan in program_data.get("execution_plan", []): + for i, delegate in enumerate(plan.get("delegates", [])): + if "mlx" in delegate.get("id", "").lower(): + delegates.append((i, delegate)) + return delegates + + +def _get_fb_graph(fb_data: bytes): + """Return the FlatBuffer MLXGraph root object.""" + from executorch.backends.mlx.serialization._generated.mlx_delegate import ( + MLXGraph as FBMLXGraph, + ) + + return FBMLXGraph.MLXGraph.GetRootAs(fb_data, 0) + + +def 
def _parse_graph_info(graph) -> Dict[str, Any]:
    """Extract top-level graph scalars (tensor counts, chain counts, etc.)."""
    return {
        "version": graph.Version().decode("utf-8") if graph.Version() else None,
        "num_constant_tensors": graph.NumConstantTensors(),
        "num_input_tensors": graph.NumInputTensors(),
        "num_output_tensors": graph.NumOutputTensors(),
        "num_mutable_buffer_tensors": graph.NumMutableBufferTensors(),
        "num_temp_tensors": graph.NumTempTensors(),
        "num_values": graph.NumValues(),
        "num_instruction_chains": graph.InstructionChainsLength(),
        "main_chain_idx": graph.MainChainIdx(),
        "init_chain_idx": graph.InitChainIdx(),
        "input_map_length": graph.InputMapLength(),
        "output_map_length": graph.OutputMapLength(),
        "mutable_buffer_map_length": graph.MutableBufferMapLength(),
        "named_slots_length": graph.NamedSlotsLength(),
        "tensor_meta_length": graph.TensorMetaLength(),
    }


def _parse_instructions(graph) -> List[Dict[str, Any]]:
    """Parse all instruction chains and their op nodes.

    Returns one dict per chain with its index and per-instruction dicts;
    instructions that fail to parse are recorded as error entries rather
    than aborting the whole chain.
    """
    chains: List[Dict[str, Any]] = []
    for c in range(graph.InstructionChainsLength()):
        chain = graph.InstructionChains(c)
        chain_info: Dict[str, Any] = {"chain_index": c, "instructions": []}
        if chain:
            for i in range(chain.InstructionsLength()):
                try:
                    instr = chain.Instructions(i)
                    if instr:
                        op_type = instr.OpType()
                        op_name = MLX_OP_TYPE_NAMES.get(op_type, f"Unknown({op_type})")
                        instr_info: Dict[str, Any] = {
                            "instr_idx": i,
                            "op_type": op_type,
                            "op_name": op_name,
                        }
                        # Merge in op-specific fields (operands, attributes).
                        op_data = _parse_op_node(instr, op_name)
                        if op_data:
                            instr_info.update(op_data)
                        chain_info["instructions"].append(instr_info)
                except Exception as e:
                    chain_info["instructions"].append(
                        {"instr_idx": i, "error": f"parse_failed: {e}"}
                    )
        chains.append(chain_info)
    return chains


def _parse_named_slots(graph) -> List[Dict[str, Any]]:
    """Parse the graph's named-slot table into name/slot dicts."""
    slots: List[Dict[str, Any]] = []
    for i in range(graph.NamedSlotsLength()):
        try:
            ns = graph.NamedSlots(i)
            if ns:
                info: Dict[str, Any] = {
                    "name": ns.Name().decode("utf-8") if ns.Name() else None,
                }
                slot = ns.Slot()
                if slot:
                    info["slot_idx"] = slot.Idx()
                    info["slot_type"] = slot.SlotType()
                slots.append(info)
        except Exception as e:
            # Fix: label the failing entry with its table index ("index");
            # the previous "instr_idx" key was a copy-paste from the
            # instruction parser and was misleading in this context.
            slots.append({"index": i, "error": f"parse_failed: {e}"})
    return slots


def _parse_tensor_meta(graph) -> List[Dict[str, Any]]:
    """Parse per-tensor metadata (dtype, shape with dynamic dims, strides)."""
    metas: List[Dict[str, Any]] = []
    for i in range(graph.TensorMetaLength()):
        try:
            tm = graph.TensorMeta(i)
            if tm:
                shape: List[Any] = []
                for j in range(tm.ShapeLength()):
                    sd = tm.Shape(j)
                    if sd.Value() == -1:
                        # Value() == -1 marks a dynamic dimension bounded by
                        # MinValue()/MaxValue(); MaxValue() == -1 means unbounded.
                        lo = sd.MinValue()
                        hi = sd.MaxValue()
                        if hi == -1:
                            shape.append(f"dyn(min={lo})")
                        else:
                            shape.append(f"dyn({lo}..{hi})")
                    else:
                        shape.append(sd.Value())
                meta: Dict[str, Any] = {
                    "index": i,
                    "dtype": tm.Dtype(),
                    "shape": shape,
                }
                if tm.StridesLength() > 0:
                    meta["strides"] = [tm.Strides(j) for j in range(tm.StridesLength())]
                metas.append(meta)
        except Exception as e:
            # Fix: "index" instead of the mislabeled "instr_idx" (see
            # _parse_named_slots).
            metas.append({"index": i, "error": f"parse_failed: {e}"})
    return metas


def _parse_io_maps(
    graph,
) -> Tuple[List[Dict], List[Dict], List[Dict]]:
    """Return (input_map, output_map, mutable_buffer_map) as slot-variant dicts."""

    def _extract(
        length_fn: Callable[[], int], getter_fn: Callable[[int], Any]
    ) -> List[Dict]:
        result = []
        for i in range(length_fn()):
            try:
                sv = getter_fn(i)
                if sv:
                    result.append({"idx": sv.Idx(), "slot_type": sv.SlotType()})
            except Exception as e:
                # Fix: "index" instead of the mislabeled "instr_idx" (see
                # _parse_named_slots).
                result.append({"index": i, "error": f"parse_failed: {e}"})
        return result

    return (
        _extract(graph.InputMapLength, graph.InputMap),
        _extract(graph.OutputMapLength, graph.OutputMap),
        _extract(graph.MutableBufferMapLength, graph.MutableBufferMap),
    )


def parse_mlx_flatbuffer(fb_data: bytes) -> Dict[str, Any]:
    """Parse MLX FlatBuffer data into a dict using the generated FlatBuffer bindings.

    On failure the returned dict carries an "error" key (plus "traceback"
    for unexpected errors) instead of raising, so callers can always
    serialize the result.
    """
    result: Dict[str, Any] = {}
    try:
        graph = _get_fb_graph(fb_data)

        result = _parse_graph_info(graph)
        result["instruction_chains"] = _parse_instructions(graph)
        result["named_slots"] = _parse_named_slots(graph)
        result["tensor_meta"] = _parse_tensor_meta(graph)

        input_map, output_map, mutable_buffer_map = _parse_io_maps(graph)
        result["input_map"] = input_map
        result["output_map"] = output_map
        result["mutable_buffer_map"] = mutable_buffer_map

        try:
            cs = graph.ConstantSegment()
            if cs:
                result["constant_segment"] = {
                    "offset": cs.Offset(),
                    "size": cs.Size(),
                }
        except Exception as e:
            result["constant_segment_error"] = f"parse_failed: {e}"

    except ImportError as e:
        result["error"] = f"FlatBuffer bindings not available: {e}"
        result["_fallback"] = "Using basic header parsing only"
    except Exception as e:
        result["error"] = f"FlatBuffer parse error: {e}"
        result["traceback"] = traceback.format_exc()

    return result
def _parse_op_node(instr, op_name: str) -> Optional[Dict[str, Any]]:
    """Parse the specific op node fields from an instruction.

    Uses the generated field mappings in ``OP_NODE_FIELDS`` to extract
    op-specific fields without manually maintaining per-op logic.
    """
    try:
        op = instr.Op()
        if op is None:
            return None

        if op_name not in OP_NODE_FIELDS:
            return {"error": f"Unknown op type: {op_name}"}

        # Load the generated FlatBuffer table class for this op and bind it
        # to the instruction's union payload.
        module = __import__(
            f"executorch.backends.mlx.serialization._generated.mlx_delegate.{op_name}",
            fromlist=[op_name],
        )
        node = getattr(module, op_name)()
        node.Init(op.Bytes, op.Pos)

        parsed: Dict[str, Any] = {}
        for field_name, accessor_name, kind in OP_NODE_FIELDS[op_name]:
            try:
                parsed[field_name] = _extract_field(node, accessor_name, kind)
            except Exception as e:
                parsed[field_name] = {"error": str(e)}

        # Drop absent (None) fields; return None when nothing remains.
        parsed = {k: v for k, v in parsed.items() if v is not None}
        return parsed or None

    except Exception as e:
        return {"parse_error": str(e), "traceback": traceback.format_exc()}


def _extract_vid_or_tid(obj) -> Optional[Dict[str, Any]]:
    """Extract a VidOrTid FlatBuffer object into a dict.

    VidOrTid has: .IsVid() -> bool, .Vid() -> Vid|None, .Tid() -> Tid|None.
    Same pattern as IntOrVid but references value/tensor slots instead of
    holding a literal.
    """
    if obj is None:
        return None
    if obj.IsVid():
        vid = obj.Vid()
        return {"vid": vid.Idx()} if vid else None
    tid = obj.Tid()
    return {"tid": tid.Idx()} if tid else None


def _extract_field(node, accessor_name: str, kind: str) -> Any:  # noqa: C901
    """Extract a single field from a FlatBuffer op node based on its *kind*."""
    accessor = getattr(node, accessor_name)

    if kind == "tid":
        ref = accessor()
        return {"tid": ref.Idx()} if ref else None

    if kind == "vid":
        ref = accessor()
        return {"vid": ref.Idx()} if ref else None

    if kind == "vid_or_tid":
        return _extract_vid_or_tid(accessor())

    if kind == "int_or_vid_or_tid":
        ivt = accessor()
        if ivt is None:
            return None
        k = ivt.Kind()
        if k == 0:  # literal int
            return {"literal": ivt.Literal()}
        if k == 1:  # Vid
            ref = ivt.Vid()
            return {"vid": ref.Idx()} if ref else None
        if k == 2:  # Tid
            ref = ivt.Tid()
            return {"tid": ref.Idx()} if ref else None
        return {"kind": k}

    if kind == "int_or_vid":
        iov = accessor()
        if iov is None:
            return None
        if not iov.IsVid():
            return {"literal": iov.Literal()}
        ref = iov.Vid()
        return {"vid": ref.Idx()} if ref else None

    if kind == "float_or_vid":
        fov = accessor()
        if fov is None:
            return None
        if not fov.IsVid():
            return {"literal": fov.Literal()}
        ref = fov.Vid()
        return {"vid": ref.Idx()} if ref else None

    if kind == "int_list":
        count = getattr(node, f"{accessor_name}Length")()
        return [accessor(i) for i in range(count)]

    if kind == "tid_list":
        count = getattr(node, f"{accessor_name}Length")()
        items = []
        for i in range(count):
            slot = accessor(i)
            items.append(f"tid {slot.Idx()}" if slot else None)
        return items

    if kind == "int_or_vid_list":
        count = getattr(node, f"{accessor_name}Length")()
        items = []
        for i in range(count):
            iov = accessor(i)
            if iov is None:
                items.append(None)
            elif iov.IsVid():
                ref = iov.Vid()
                items.append({"vid": ref.Idx()} if ref else None)
            else:
                items.append({"literal": iov.Literal()})
        return items

    if kind == "string":
        raw = accessor()
        return raw.decode("utf-8") if raw else None

    # scalar (default)
    return accessor()


def parse_mlx_payload(payload: bytes) -> Dict[str, Any]:
    """Parse raw MLX delegate payload bytes into a dict.

    This is the public entry point for callers that already have the raw
    delegate payload (e.g. from ``extract_delegate_payload``).
    """
    header = MLXHeader.from_bytes(payload)

    if not header.is_valid():
        return {
            "error": f"Invalid MLX magic: {header.magic!r}",
            "header": header.to_dict(),
        }

    fb_data = payload[MLX_HEADER_LENGTH : header.data_segment_offset]
    parsed: Dict[str, Any] = {
        "header": header.to_dict(),
        "flatbuffer_size": len(fb_data),
        "graph": parse_mlx_flatbuffer(fb_data),
    }

    if header.data_segment_size > 0:
        parsed["constant_data_size"] = header.data_segment_size

    return parsed


def parse_executorch_program(pte_data: bytes) -> Dict[str, Any]:  # noqa: C901
    """Parse the outer PTE container: magic, extended header, and program.

    Raises ValueError when *pte_data* is too small to hold the 8-byte
    prefix; all later failures are recorded in the returned dict instead.
    """
    info: Dict[str, Any] = {}

    if len(pte_data) < 8:
        raise ValueError("File too small to be a valid PTE file")

    info["flatbuffer_magic"] = pte_data[4:8].decode("utf-8", errors="replace")

    extended_header_offset = 8
    if len(pte_data) > extended_header_offset + 32:
        try:
            header = _ExtendedHeader.from_bytes(pte_data[extended_header_offset:])
            if header.is_valid():
                info["extended_header"] = {
                    "magic": header.magic.decode("utf-8", errors="replace"),
                    "length": header.length,
                    "program_size": header.program_size,
                    "segment_base_offset": header.segment_base_offset,
                    "segment_data_size": header.segment_data_size,
                }
                info["flatbuffer_offset"] = extended_header_offset + header.length
                info["flatbuffer_size"] = header.program_size
                info["segment_offset"] = header.segment_base_offset
                info["segment_size"] = header.segment_data_size
        except Exception as e:
            info["header_parse_error"] = str(e)

    try:
        from executorch.exir._serialize._flatbuffer import _program_flatbuffer_to_json

        program_data = json.loads(_program_flatbuffer_to_json(pte_data))
        info["program"] = program_data

        if "execution_plan" in program_data:
            delegates = []
            for plan in program_data["execution_plan"]:
                for delegate in plan.get("delegates", []):
                    processed = delegate.get("processed", {})
                    delegate_info: Dict[str, Any] = {
                        "id": delegate.get("id"),
                        "processed_type": processed.get("location"),
                    }
                    if "data" in processed:
                        delegate_info["inline_data_size"] = len(processed["data"])
                    if "location" in processed:
                        delegate_info["location"] = processed["location"]
                    delegates.append(delegate_info)
            info["delegates"] = delegates

    except ImportError:
        info["program_parse_error"] = "ExecuTorch FlatBuffer parsing not available"
    except Exception as e:
        info["program_parse_error"] = str(e)

    return info
result["segment_offset"] = header.segment_base_offset + result["segment_size"] = header.segment_data_size + except Exception as e: + result["header_parse_error"] = str(e) + + try: + from executorch.exir._serialize._flatbuffer import _program_flatbuffer_to_json + + program_data = json.loads(_program_flatbuffer_to_json(pte_data)) + result["program"] = program_data + + if "execution_plan" in program_data: + delegates = [] + for plan in program_data["execution_plan"]: + if "delegates" in plan: + for delegate in plan["delegates"]: + delegate_info: Dict[str, Any] = { + "id": delegate.get("id"), + "processed_type": delegate.get("processed", {}).get( + "location" + ), + } + processed = delegate.get("processed", {}) + if "data" in processed: + delegate_info["inline_data_size"] = len(processed["data"]) + if "location" in processed: + delegate_info["location"] = processed["location"] + delegates.append(delegate_info) + result["delegates"] = delegates + + except ImportError: + result["program_parse_error"] = "ExecuTorch FlatBuffer parsing not available" + except Exception as e: + result["program_parse_error"] = str(e) + + return result + + +def _slot_type_display(slot_type: int, style: str = "full") -> str: + """Return display string for a slot type. + + *style* controls the format: + - ``"full"``: "Tensor", "Int", etc. 
(for summary tables) + - ``"short"``: "tid", "vid" (for instruction I/O lists) + """ + if style == "short": + return "tid" if slot_type == 0 else "vid" + return _SLOT_TYPE_NAMES.get(slot_type, "Unknown") + + +def _print_slot_map(label: str, slots: List[Dict]) -> None: + """Print a list of slot-variant dicts with their type names.""" + if not slots: + return + print(f"\n {label}:") + for i, slot in enumerate(slots): + type_name = _slot_type_display(slot.get("slot_type", 0)) + print(f" [{i}]: idx={slot.get('idx')}, type={type_name}") + + +def show_mlx_summary(pte_data: bytes) -> None: # noqa: C901 + try: + mlx_delegates = _find_mlx_delegates(pte_data) + if not mlx_delegates: + print("No MLX delegates found in this PTE file.") + return + + print(f"\n{'='*70}") + print("MLX DELEGATE SUMMARY") + print(f"{'='*70}") + print(f"File contains {len(mlx_delegates)} MLX delegate(s)\n") + + for idx, (delegate_idx, delegate) in enumerate(mlx_delegates): + print(f"\n--- Delegate {idx} (plan index {delegate_idx}) ---") + print(f"ID: {delegate.get('id', 'unknown')}") + + try: + mlx = _load_mlx_payload(pte_data, delegate_index=idx) + except ValueError as e: + print(f" {e}") + continue + + graph_info = parse_mlx_flatbuffer(mlx.fb_data) + + print("\nMLX Graph Info:") + for key in ( + "num_constant_tensors", + "num_input_tensors", + "num_output_tensors", + "num_mutable_buffer_tensors", + "num_temp_tensors", + "num_values", + "num_instruction_chains", + ): + label = f" {key + ':':<29}" + print(f"{label}{graph_info.get(key, '?')}") + + main_idx = graph_info.get("main_chain_idx", 0) + chains = graph_info.get("instruction_chains", []) + main_num = "?" 
def show_mlx_instructions(pte_data: bytes) -> None:  # noqa: C901
    """Print every MLX delegate's instruction chains with per-op fields.

    Exits the process with status 1 when no MLX delegate is present or an
    unexpected error occurs; all diagnostics go to stderr, the listing to
    stdout.
    """
    try:
        mlx_delegates = _find_mlx_delegates(pte_data)
        if not mlx_delegates:
            print("No MLX delegates found in this PTE file.", file=sys.stderr)
            sys.exit(1)

        if len(mlx_delegates) > 1:
            print(
                f"Found {len(mlx_delegates)} MLX delegate(s) in PTE file\n",
                file=sys.stderr,
            )

        for idx, (delegate_idx, _delegate) in enumerate(mlx_delegates):
            try:
                mlx = _load_mlx_payload(pte_data, delegate_index=idx)
            except ValueError as e:
                print(f"\nError: {e}", file=sys.stderr)
                continue

            graph = parse_mlx_flatbuffer(mlx.fb_data)
            if "error" in graph:
                print(
                    f"\nError parsing delegate {idx}: {graph['error']}",
                    file=sys.stderr,
                )
                continue

            # Print delegate header
            if len(mlx_delegates) > 1:
                print("\n" + "=" * 70)
                print(f"MLX DELEGATE {idx} (plan index {delegate_idx})")
                print("=" * 70)
            else:
                print("\n" + "=" * 70)
                print("MLX Graph Summary")
                print("=" * 70)

            # Basic info
            print(f"Version: {graph.get('version', 'unknown')}")
            print(f"Constant tensors: {graph.get('num_constant_tensors', 0)}")
            print(f"Input tensors: {graph.get('num_input_tensors', 0)}")
            print(f"Output tensors: {graph.get('num_output_tensors', 0)}")
            print(
                f"Mutable buffer tensors: {graph.get('num_mutable_buffer_tensors', 0)}"
            )
            print(f"Temp tensors: {graph.get('num_temp_tensors', 0)}")
            print(f"Values: {graph.get('num_values', 0)}")
            num_chains = graph.get("num_instruction_chains", 0)
            main_idx = graph.get("main_chain_idx", 0)
            init_idx = graph.get("init_chain_idx", -1)
            print(f"Instruction chains: {num_chains}")
            print(f"Main chain idx: {main_idx}")
            # init_chain_idx < 0 means no init chain was serialized.
            if init_idx >= 0:
                print(f"Init chain idx: {init_idx}")

            constant_seg = graph.get("constant_segment", {})
            if constant_seg:
                print(f"Constant data: {constant_seg.get('size', 0):,} bytes")

            # Instruction chains
            for chain_info in graph.get("instruction_chains", []):
                chain_idx = chain_info.get("chain_index", "?")
                label = ""
                if chain_idx == main_idx:
                    label = " (main)"
                elif chain_idx == init_idx:
                    label = " (init)"
                instructions = chain_info.get("instructions", [])
                print(f"\nChain {chain_idx}{label} ({len(instructions)} instructions):")
                for instr in instructions:
                    op_name = instr.get("op_name", f"op_{instr.get('op_type', '?')}")
                    print(f"  [{instr.get('instr_idx', '?')}] {op_name}")

                    # Remaining keys are the op-specific fields merged in by
                    # _parse_op_node; render tid/vid references compactly.
                    for key, value in instr.items():
                        if key in ("instr_idx", "op_type", "op_name"):
                            continue
                        if isinstance(value, dict):
                            if "tid" in value:
                                print(f"      {key}: tid {value['tid']}")
                            elif "vid" in value:
                                print(f"      {key}: vid {value['vid']}")
                            else:
                                print(f"      {key}: {value}")
                        elif value is not None:
                            print(f"      {key}: {value}")

            # Named slots
            named_slots = graph.get("named_slots", [])
            if named_slots:
                print("\nNamed Slots:")
                for slot in named_slots:
                    slot_type = _slot_type_display(
                        slot.get("slot_type", 0), style="short"
                    )
                    print(
                        f"  [{slot.get('slot_idx', '?')}] {slot.get('name', '?')} ({slot_type})"
                    )

            # Input/Output maps
            input_map = graph.get("input_map", [])
            output_map = graph.get("output_map", [])

            if input_map:
                print("\nInputs:")
                for inp in input_map:
                    slot_type = _slot_type_display(
                        inp.get("slot_type", 0), style="short"
                    )
                    print(f"  {slot_type} {inp.get('idx', '?')}")

            if output_map:
                print("\nOutputs:")
                for out in output_map:
                    slot_type = _slot_type_display(
                        out.get("slot_type", 0), style="short"
                    )
                    print(f"  {slot_type} {out.get('idx', '?')}")

            print("=" * 70 + "\n")

    except Exception as e:
        print(f"Error showing MLX instructions: {e}", file=sys.stderr)
        traceback.print_exc()
        sys.exit(1)


def main():  # noqa: C901
    """CLI entry point: parse arguments and dispatch to the inspectors."""
    parser = argparse.ArgumentParser(
        description="Inspect ExecuTorch .pte files and extract data",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
MLX-Specific Options:
  --mlx-summary        Show high-level summary (tensor counts, I/O maps)
  --mlx-instructions   Show detailed instruction list with operation parameters
                       (use this to verify quantization, inspect ops, etc.)

Examples:
  # Basic PTE file inspection
  python -m executorch.backends.mlx.pte_inspector model.pte

  # Show high-level MLX delegate summary
  python -m executorch.backends.mlx.pte_inspector model.pte --mlx-summary

  # Show detailed MLX instructions (verify quantization, inspect operations)
  python -m executorch.backends.mlx.pte_inspector model.pte --mlx-instructions

  # Extract raw delegate payload to binary file
  python -m executorch.backends.mlx.pte_inspector model.pte \\
      --extract-delegate MLXBackend -o delegate.bin
    """,
    )
    parser.add_argument("pte_file", type=Path, help="Path to the .pte file")
    parser.add_argument(
        "--output", "-o", type=Path, help="Output file (default: stdout)"
    )
    parser.add_argument(
        "--extract-delegate",
        type=str,
        metavar="ID",
        help="Extract delegate payload by ID (e.g., 'mlx')",
    )
    parser.add_argument(
        "--delegate-index",
        type=int,
        default=None,
        metavar="N",
        help="Index of delegate to extract (0-based). If not specified, extracts first matching delegate.",
    )
    parser.add_argument(
        "--parse-mlx",
        action="store_true",
        help="Parse extracted MLX payload (use with --extract-delegate mlx)",
    )
    parser.add_argument(
        "--mlx-summary",
        action="store_true",
        help="Show summary of all MLX delegates (input/output/mutable buffer counts)",
    )
    parser.add_argument(
        "--mlx-instructions",
        action="store_true",
        help="Show detailed MLX instruction list with operands and quantization details",
    )
    parser.add_argument(
        "--format",
        choices=["json", "summary"],
        default="json",
        help="Output format (default: json)",
    )
    parser.add_argument(
        "--indent",
        type=int,
        default=2,
        help="JSON indentation (default: 2)",
    )

    args = parser.parse_args()

    if not args.pte_file.exists():
        print(f"Error: File not found: {args.pte_file}", file=sys.stderr)
        sys.exit(1)

    pte_data = args.pte_file.read_bytes()
    print(f"Loaded {len(pte_data)} bytes from {args.pte_file}", file=sys.stderr)

    # The MLX-specific views are terminal modes: they print and return.
    if args.mlx_instructions:
        show_mlx_instructions(pte_data)
        return

    if args.mlx_summary:
        show_mlx_summary(pte_data)
        return

    if args.extract_delegate:
        payload = extract_delegate_payload(
            pte_data, args.extract_delegate, delegate_index=args.delegate_index
        )
        if payload is None:
            print(
                f"Error: Delegate '{args.extract_delegate}' not found", file=sys.stderr
            )
            sys.exit(1)

        if args.parse_mlx and args.extract_delegate.lower() == "mlx":
            # Parse the payload and emit JSON (to file or stdout).
            result = parse_mlx_payload(payload)

            output = json.dumps(result, indent=args.indent, default=str)

            if args.output:
                args.output.write_text(output)
                print(f"Wrote parsed MLX data to {args.output}", file=sys.stderr)
            else:
                print(output)
        else:
            # Raw extraction: write bytes, or print a short header digest.
            if args.output:
                args.output.write_bytes(payload)
                print(f"Wrote {len(payload)} bytes to {args.output}", file=sys.stderr)
            else:
                print(f"Delegate payload: {len(payload)} bytes", file=sys.stderr)
                if len(payload) >= MLX_HEADER_LENGTH:
                    header = MLXHeader.from_bytes(payload)
                    print(f"  Magic: {header.magic!r}", file=sys.stderr)
                    print(
                        f"  Data offset: {header.data_segment_offset}", file=sys.stderr
                    )
                    print(f"  Data size: {header.data_segment_size}", file=sys.stderr)
        return

    # Default mode: parse the whole PTE container.
    result = parse_executorch_program(pte_data)
    result["file_size"] = len(pte_data)
    result["file_path"] = str(args.pte_file)

    if args.format == "summary":
        print(f"PTE File: {args.pte_file}")
        print(f"  Size: {len(pte_data):,} bytes")
        if "extended_header" in result:
            h = result["extended_header"]
            print(f"  Program size: {h['program_size']:,} bytes")
            print(f"  Segment offset: {h['segment_base_offset']:,}")
            print(f"  Segment size: {h['segment_data_size']:,} bytes")
        if "delegates" in result:
            print(f"  Delegates: {len(result['delegates'])}")
            for d in result["delegates"]:
                print(f"    - {d.get('id', 'unknown')}")
    else:
        output = json.dumps(result, indent=args.indent, default=str)

        if args.output:
            args.output.write_text(output)
            print(f"Wrote JSON to {args.output}", file=sys.stderr)
        else:
            print(output)


if __name__ == "__main__":
    main()
expected_dtype) { + throw std::runtime_error( + std::string("tensor_to_mlx: dtype mismatch - input tensor has ") + + ExecutionState::dtype_str(dtype) + " but model expects " + + ExecutionState::dtype_str(expected_dtype)); + } + } + + ::mlx::core::Shape shape; + for (int i = 0; i < t.dim(); ++i) { + auto dim_size = t.size(i); + if (dim_size > std::numeric_limits::max() || + dim_size < std::numeric_limits::min()) { + throw std::runtime_error( + "tensor_to_mlx: dimension " + std::to_string(i) + " size " + + std::to_string(dim_size) + " exceeds int range"); + } + shape.push_back(static_cast(dim_size)); + } + + // SAFETY: MLX reads this data during async_eval() Metal command encoding, + // which completes before the lock is released. The ET tensor must remain + // valid until async_eval returns. + const void* cptr = t.const_data_ptr(); + if (!cptr) { + throw std::runtime_error("tensor_to_mlx: tensor has null data pointer"); + } + void* data_ptr = const_cast(cptr); + auto deleter = [](void*) {}; + return array(data_ptr, shape, dtype, deleter); +} + +// Build the contiguous + dtype conversion pipeline for an output array. +// Returns a lazy array (not yet evaluated) ready for async_eval. +array prepare_output( + const array& arr, + Dtype expected_dtype, + const ::mlx::core::Stream& stream) { + array result = + ::mlx::core::contiguous(arr, /*allow_col_major=*/false, stream); + if (result.dtype() != expected_dtype) { + result = ::mlx::core::astype(result, expected_dtype, stream); + } + return result; +} + +// Wait for a prepared output array and copy its data to an ET tensor. +// The array must have been submitted via async_eval before calling this. 
+void write_output(array& arr, ETTensor& out) { + arr.wait(); + + // Resize output tensor if shape doesn't match (dynamic shapes) + const auto& mlx_shape = arr.shape(); + auto out_sizes = out.sizes(); + + bool shape_matches = (mlx_shape.size() == static_cast(out.dim())); + if (shape_matches) { + for (size_t i = 0; i < mlx_shape.size(); ++i) { + if (static_cast(mlx_shape[i]) != + static_cast(out_sizes[i])) { + shape_matches = false; + break; + } + } + } + + if (!shape_matches) { + std::vector new_sizes; + new_sizes.reserve(mlx_shape.size()); + for (auto d : mlx_shape) { + new_sizes.push_back(static_cast(d)); + } + auto err = resize_tensor( + out, + ArrayRef( + new_sizes.data(), new_sizes.size())); + if (err != Error::Ok) { + throw std::runtime_error("write_output: failed to resize output tensor"); + } + } + + size_t mlx_nbytes = arr.nbytes(); + size_t out_nbytes = out.nbytes(); + if (mlx_nbytes != out_nbytes) { + throw std::runtime_error( + "write_output: size mismatch - MLX has " + std::to_string(mlx_nbytes) + + " bytes, output has " + std::to_string(out_nbytes) + " bytes"); + } + + const void* src = arr.data(); + if (!src) { + throw std::runtime_error( + "write_output: arr.data() is null after wait()"); + } + std::memcpy(out.mutable_data_ptr(), src, out_nbytes); +} + +} // namespace + +struct MLXHandle { + MLXProgram program; + ConstantData constants; + MutableBufferData mutable_buffers; + ExecutionState state; // Reusable execution state + Interpreter interpreter; + ::mlx::core::Stream stream; // Dedicated GPU stream for this handle + + // Keep the constant buffers alive for zero-copy constants + // Each FreeableBuffer must outlive the MLX arrays that reference it + std::vector constant_buffers; + + MLXHandle() : stream(::mlx::core::new_stream(::mlx::core::Device::gpu)) {} + ~MLXHandle() = default; + + MLXHandle(const MLXHandle&) = delete; + MLXHandle& operator=(const MLXHandle&) = delete; +}; + +// MLX is not thread-safe: its computation graph is global shared 
state. +// A global mutex serializes graph construction and command submission +// across all handles. GPU execution and output copies can proceed +// without the lock (see execute() for the async pipeline design). +static std::mutex& mlx_global_mutex() { + static std::mutex m; + return m; +} + +class MLXBackend final : public ::executorch::runtime::BackendInterface { + public: + ~MLXBackend() override = default; + + bool is_available() const override { + return ::mlx::core::metal::is_available(); + } + + Result init( + BackendInitContext& context, + FreeableBuffer* processed, + ArrayRef compile_specs) const override { + std::lock_guard lock(mlx_global_mutex()); + auto* handle = + context.get_runtime_allocator()->allocateInstance(); + if (handle == nullptr) { + return Error::MemoryAllocationFailed; + } + + try { + new (handle) MLXHandle(); + + if (!processed || !processed->data() || processed->size() == 0) { + throw std::runtime_error("init: null or empty delegate payload"); + } + + handle->program = loader::load_program( + static_cast(processed->data()), processed->size()); + + // Validate schema version + if (handle->program.version != "1") { + throw std::runtime_error( + "Unsupported MLX schema version '" + handle->program.version + + "' (expected '1'). 
Rebuild the .pte with a matching SDK version."); + } + + // Load constants from named_data_map + // Constants are stored by name in the .pte file and provided by ET at + // runtime + const runtime::NamedDataMap* named_data_map = + context.get_named_data_map(); + load_constants( + handle->program, + named_data_map, + handle->constants, + handle->constant_buffers); + + // Delegate payload no longer needed after constants are loaded + processed->Free(); + processed = nullptr; + + // Load mutable buffers (e.g., KV cache) + load_mutable_buffers(handle->program, handle->mutable_buffers); + + // Bind execution state (reused across execute() calls) + handle->state.bind( + handle->program, handle->constants, handle->mutable_buffers); + + // Run init chain if present. + // SAFETY: The >= 0 check ensures init_chain_idx is non-negative, so the + // static_cast cannot produce UINT32_MAX from a -1 sentinel. + if (handle->program.init_chain_idx >= 0) { + handle->interpreter.run_chain( + handle->program, + static_cast(handle->program.init_chain_idx), + handle->state, + handle->stream); + } + + } catch (const std::exception& e) { + ET_LOG(Error, "Failed to load MLX program: %s", e.what()); + handle->~MLXHandle(); + if (processed != nullptr) { + processed->Free(); + } + return Error::InvalidProgram; + } + + return handle; + } + + Error execute( + ET_UNUSED BackendExecutionContext& context, + DelegateHandle* handle, + Span args) const override { + try { + std::vector prepared_outputs; + struct OutputInfo { + size_t arg_idx; + size_t prepared_idx; + }; + + std::vector tensor_output_info; + size_t arg_idx = 0; + + auto* h = static_cast(handle); + const auto& program = h->program; + + // Graph construction + async GPU dispatch (locked) + { + std::lock_guard lock(mlx_global_mutex()); + + h->state.reset(); + + const size_t n_inputs = program.input_map.size(); + const size_t n_outputs = program.output_map.size(); + if (n_inputs > SIZE_MAX - n_outputs) { + throw std::runtime_error("execute: 
input + output count overflow"); + } + const size_t expected_args = n_inputs + n_outputs; + if (args.size() != expected_args) { + ET_LOG( + Error, "Expected %zu args, got %zu", expected_args, args.size()); + return Error::InvalidArgument; + } + + // Bind inputs + for (const auto& slot : program.input_map) { + if (arg_idx >= args.size()) { + throw std::runtime_error( + "execute: arg_idx " + std::to_string(arg_idx) + + " out of bounds (args.size()=" + std::to_string(args.size()) + + ")"); + } + if (slot.slot_type == SlotType::TensorSlot) { + const ETTensor& tensor = args[arg_idx++]->toTensor(); + Tid tid{slot.idx}; + std::optional expected_meta = std::nullopt; + if (tid.idx < program.tensor_meta.size()) { + expected_meta = program.tensor_meta[tid.idx]; + } + h->state.set_tensor(tid, tensor_to_mlx(tensor, expected_meta)); + } else if (slot.slot_type == SlotType::IntValueSlot) { + int64_t val = args[arg_idx]->toInt(); + arg_idx++; + if (val > std::numeric_limits::max() || + val < std::numeric_limits::min()) { + ET_LOG( + Error, + "Int input value %lld exceeds int32 range", + static_cast(val)); + return Error::InvalidArgument; + } + h->state.set_value(Vid{slot.idx}, static_cast(val)); + } else { + throw std::runtime_error( + "Unhandled input slot type: " + + std::to_string(static_cast(slot.slot_type))); + } + } + + // Run the MLX program (builds lazy computation graph) + h->interpreter.run(program, h->state, h->stream); + + // Prepare output pipeline and collect int outputs + // Build contiguous + dtype conversion lazily for each tensor output, + // and extract int outputs (which don't need GPU) while still locked. 
+ prepared_outputs.reserve(program.num_output_tensors); + + for (const auto& slot : program.output_map) { + if (slot.slot_type == SlotType::TensorSlot) { + ETTensor& out_tensor = args[arg_idx]->toTensor(); + Dtype expected_dtype = + resolve_dtype(static_cast( + out_tensor.scalar_type())); + array out_arr = prepare_output( + h->state.const_tensor_ref(Tid{slot.idx}), + expected_dtype, + h->stream); + tensor_output_info.push_back({arg_idx, prepared_outputs.size()}); + prepared_outputs.push_back(std::move(out_arr)); + arg_idx++; + } else if (slot.slot_type == SlotType::IntValueSlot) { + Vid vid{slot.idx}; + int64_t int_val = + static_cast(h->state.const_value_ref(vid)); + *args[arg_idx] = EValue(int_val); + arg_idx++; + } else { + throw std::runtime_error( + "Unhandled output slot type: " + + std::to_string(static_cast(slot.slot_type))); + } + } + + // Submit all output work to GPU asynchronously + // async_eval encodes Metal commands and returns immediately. + // The GPU will signal events on completion. 
+ if (!prepared_outputs.empty()) { + ::mlx::core::async_eval(prepared_outputs); + } + + } // Lock released — GPU is still executing + + for (auto& info : tensor_output_info) { + ETTensor& out_tensor = args[info.arg_idx]->toTensor(); + + // write_output waits on arr to be ready + write_output(prepared_outputs[info.prepared_idx], out_tensor); + } + + h->state.reset(); // Release temp GPU buffers back to MLX cache + + return Error::Ok; + } catch (const std::exception& e) { + ET_LOG(Error, "MLX execute failed: %s", e.what()); + return Error::Internal; + } + } + + void destroy(DelegateHandle* handle) const override { + std::lock_guard lock(mlx_global_mutex()); + if (handle != nullptr) { + auto* mlx_handle = static_cast(handle); + mlx_handle->~MLXHandle(); + } + } +}; + +namespace { +auto cls = MLXBackend(); +Backend backend{"MLXBackend", &cls}; +static auto success_with_compiler = register_backend(backend); +} // namespace + +} // namespace mlx +} // namespace backends +} // namespace executorch diff --git a/backends/mlx/runtime/MLXExecutor.h b/backends/mlx/runtime/MLXExecutor.h new file mode 100644 index 00000000000..32d623790ab --- /dev/null +++ b/backends/mlx/runtime/MLXExecutor.h @@ -0,0 +1,878 @@ +// +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. +// + +#pragma once + +#include "MLXLoader.h" + +#include +#include +#include +#include + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +// ============================================================================= +// Op Logging - compile-time gate + runtime env var check +// +// Compile flag (CMake: -DET_MLX_ENABLE_OP_LOGGING=1) controls whether logging +// code is compiled in at all. When off, all logging is stripped (zero +// overhead). 
When on, the env var ET_MLX_ENABLE_OP_LOGGING=1 must also be set +// at runtime to actually produce output. +// ============================================================================= +#ifndef ET_MLX_ENABLE_OP_LOGGING +#define ET_MLX_ENABLE_OP_LOGGING 0 +#endif + +// ============================================================================= +// Constant Zero-Copy - Enable via CMake: -DET_MLX_ENABLE_CONSTANT_ZERO_COPY=1 +// When enabled, attempts to load model constants (weights) using zero-copy +// on Apple Silicon's unified memory. Falls back to copying if zero-copy fails. +// Disable if you want predictable memory usage (always copies). +// ============================================================================= +#ifndef ET_MLX_ENABLE_CONSTANT_ZERO_COPY +#define ET_MLX_ENABLE_CONSTANT_ZERO_COPY 1 // Enabled by default +#endif + +namespace executorch { +namespace backends { +namespace mlx { + +/// Multiply two unsigned values, throw on overflow. +template +inline T safe_mul(T a, T b, const char* context) { + static_assert(std::is_unsigned::value, "safe_mul requires unsigned type"); + T result; + if (__builtin_mul_overflow(a, b, &result)) { + throw std::runtime_error(std::string(context) + ": unsigned mul overflow"); + } + return result; +} + +// Runtime check for op logging (only callable when compiled in) +#if ET_MLX_ENABLE_OP_LOGGING +inline bool isOpLoggingEnabled() { + static const bool enabled = []() { + const char* val = std::getenv("ET_MLX_ENABLE_OP_LOGGING"); + return val != nullptr && std::string(val) == "1"; + }(); + return enabled; +} +#else +constexpr bool isOpLoggingEnabled() { + return false; +} +#endif + +// Compile-time constant zero-copy flag +constexpr bool kEnableConstantZeroCopy = ET_MLX_ENABLE_CONSTANT_ZERO_COPY; + +using Tensor = ::mlx::core::array; +using Value = std::variant; +using StreamOrDevice = ::mlx::core::StreamOrDevice; + +struct ConstantData { + std::vector tensors; + + inline const Tensor& get(Tid id) const { + if 
(id.idx >= tensors.size()) { + throw std::out_of_range("ConstantData::get: id out of range"); + } + return tensors[id.idx]; + } + + inline void add(Tensor t) { + tensors.push_back(std::move(t)); + } +}; + +struct MutableBufferData { + // Maps tensor slot idx to MLX array + // Using vector of optional since mlx::array has no default constructor + std::vector> tensors; + + inline void resize(size_t n) { + tensors.resize(n, std::nullopt); + } + + inline bool has(Tid id) const { + return id.idx < tensors.size() && tensors[id.idx].has_value(); + } + + inline Tensor& get(Tid id) { + if (id.idx >= tensors.size() || !tensors[id.idx].has_value()) { + throw std::out_of_range("MutableBufferData::get: id not found or unset"); + } + return *tensors[id.idx]; + } + + inline const Tensor& get(Tid id) const { + if (id.idx >= tensors.size() || !tensors[id.idx].has_value()) { + throw std::out_of_range("MutableBufferData::get: id not found or unset"); + } + return *tensors[id.idx]; + } + + inline void set(Tid id, Tensor t) { + if (id.idx >= tensors.size()) { + throw std::out_of_range("MutableBufferData::set: id out of range"); + } + tensors[id.idx] = std::move(t); + } + + inline void clear() { + tensors.clear(); + } +}; + +struct ExecutionState { + const MLXProgram* program{nullptr}; + const ConstantData* constants{nullptr}; // Shared, read-only + MutableBufferData* mutable_buffers{nullptr}; // Per-handle, persistent + + // Per-execution tensors: inputs, outputs, temps (NOT constants or mutable + // buffers) + std::vector> tensors; + + // Non-constant values (SymInt, etc.) 
+ std::vector> values; + + // Logging context + size_t current_op_idx{0}; + const char* current_op_name{nullptr}; + + // Tensor ID range boundaries for O(1) type lookup (computed at bind time) + uint32_t num_constants{0}; + uint32_t input_end{0}; + uint32_t output_end{0}; + uint32_t mutable_buffer_end{0}; + + void bind( + const MLXProgram& prog, + const ConstantData& const_data, + MutableBufferData& mut_bufs) { + program = &prog; + constants = &const_data; + mutable_buffers = &mut_bufs; + + // Allocate space for inputs, outputs, and temps only (not constants or + // mutable buffers) + uint64_t num_per_execution_tensors = + static_cast(prog.num_input_tensors) + + prog.num_output_tensors + prog.num_temp_tensors; + if (num_per_execution_tensors > 1'000'000) { + throw std::runtime_error( + "bind: num_per_execution_tensors " + + std::to_string(num_per_execution_tensors) + " exceeds limit"); + } + tensors.assign( + static_cast(num_per_execution_tensors), std::nullopt); + if (prog.num_values > 1'000'000) { + throw std::runtime_error( + "bind: num_values " + std::to_string(prog.num_values) + + " exceeds limit"); + } + values.assign(prog.num_values, std::nullopt); + + // Compute tensor ID range boundaries for fast type lookup + // ID assignment order: Constant -> Input -> Output -> MutableBuffer -> Temp + num_constants = prog.num_constant_tensors; + uint64_t ie = static_cast(num_constants) + prog.num_input_tensors; + uint64_t oe = ie + prog.num_output_tensors; + uint64_t me = oe + prog.num_mutable_buffer_tensors; + if (me > std::numeric_limits::max()) { + throw std::runtime_error("bind: tensor ID range overflow"); + } + input_end = static_cast(ie); + output_end = static_cast(oe); + mutable_buffer_end = static_cast(me); + } + + // Check if a tensor ID is a mutable buffer + inline bool is_mutable_buffer(Tid id) const { + return id.idx >= output_end && id.idx < mutable_buffer_end; + } + + // Convert tensor ID to index in the tensors vector + // Accounts for constants and 
mutable buffers not being in the vector + inline uint32_t tensor_index(Tid id) const { + if (id.idx < num_constants) { + throw std::runtime_error( + "tensor_index: called with constant tensor id " + + std::to_string(id.idx)); + } + if (is_mutable_buffer(id)) { + throw std::runtime_error( + "tensor_index: called with mutable buffer tensor id " + + std::to_string(id.idx)); + } + uint32_t idx = id.idx - num_constants; + // If this ID is after mutable buffer range, subtract mutable buffer count + if (id.idx >= mutable_buffer_end) { + if (idx < program->num_mutable_buffer_tensors) { + throw std::runtime_error( + "tensor_index: underflow for tensor id " + std::to_string(id.idx)); + } + idx -= program->num_mutable_buffer_tensors; + } + if (idx >= tensors.size()) { + throw std::out_of_range( + "tensor_index: computed index " + std::to_string(idx) + + " out of range (size=" + std::to_string(tensors.size()) + + ") for tensor id " + std::to_string(id.idx)); + } + return idx; + } + + void reset() { + // Clear per-execution tensors (inputs, outputs, temps) + // Constants and mutable buffers are not in this vector + for (auto& t : tensors) { + t = std::nullopt; + } + for (auto& v : values) { + v = std::nullopt; + } + } + + static inline const char* dtype_str(::mlx::core::Dtype dtype) { + using namespace ::mlx::core; + switch (dtype.val()) { + case float32.val(): + return "f32"; + case float16.val(): + return "f16"; + case bfloat16.val(): + return "bf16"; + case int32.val(): + return "i32"; + case int64.val(): + return "i64"; + case int16.val(): + return "i16"; + case int8.val(): + return "i8"; + case uint32.val(): + return "u32"; + case uint8.val(): + return "u8"; + case bool_.val(): + return "bool"; + default: + return "?"; + } + } + + static inline std::string format_tensor_info(const Tensor& t) { + std::ostringstream ss; + ss << dtype_str(t.dtype()); + ss << "("; + const auto& shape = t.shape(); + for (size_t i = 0; i < shape.size(); ++i) { + if (i > 0) + ss << ","; + ss << 
shape[i]; + } + ss << ")"; + return ss.str(); + } + + // Compute tensor stats: min, max, mean, nan_count + // Uses MLX ops for GPU-accelerated computation + static inline std::string format_tensor_stats(const Tensor& t) { + using namespace ::mlx::core; + + try { + std::ostringstream ss; + + size_t numel = t.size(); + if (numel == 0) { + ss << "[empty]"; + return ss.str(); + } + + // Cast to float32 for stats computation (handles bf16/fp16/int/bool) + Tensor t_float = astype(t, float32); + + // Use MLX ops for efficient GPU-based stats + Tensor nan_mask = isnan(t_float); + Tensor inf_mask = isinf(t_float); + Tensor nan_count_arr = sum(astype(nan_mask, int32)); + Tensor inf_count_arr = sum(astype(inf_mask, int32)); + + // For min/max/mean, we need to handle NaN/Inf - replace with 0 + Tensor valid_mask = logical_not(logical_or(nan_mask, inf_mask)); + Tensor t_valid = where(valid_mask, t_float, zeros_like(t_float)); + + Tensor min_arr = min(t_valid); + Tensor max_arr = max(t_valid); + Tensor mean_arr = mean(t_valid); + + // Evaluate all at once + eval({nan_count_arr, inf_count_arr, min_arr, max_arr, mean_arr}); + + int nan_count = nan_count_arr.item(); + int inf_count = inf_count_arr.item(); + float min_val = min_arr.item(); + float max_val = max_arr.item(); + float mean_val = mean_arr.item(); + + ss << std::fixed << std::setprecision(4); + ss << "[min=" << min_val << " max=" << max_val << " mean=" << mean_val; + if (nan_count > 0) { + ss << " NaN=" << nan_count; + } + if (inf_count > 0) { + ss << " Inf=" << inf_count; + } + ss << "]"; + return ss.str(); + } catch (const std::exception& e) { + return std::string("[stats error: ") + e.what() + "]"; + } catch (...) 
{ + return "[stats error: unknown]"; + } + } + + // Get tensor type prefix for logging: "c", "i", "o", "b", "t" + inline const char* tensor_type_prefix(Tid id) const { + if (!program) + return "?"; + + uint32_t tid = id.idx; + + // Check each range in order (mutually exclusive ranges) + if (tid < program->num_constant_tensors) + return "c"; // Constant + if (tid < input_end) + return "i"; // User Input + if (tid < output_end) + return "o"; // User Output + if (tid < mutable_buffer_end) + return "b"; // Mutable Buffer + return "t"; // Temp + } + + inline void begin_op(size_t idx, const char* name) { + current_op_idx = idx; + current_op_name = name; + if (isOpLoggingEnabled()) { + std::cout << "[" << idx << "] " << name << std::endl; + } + } + + inline void end_op() { + if (isOpLoggingEnabled()) { + std::cout << "----\n"; + } + } + + inline Tensor& tensor_ref(Tid id) { + if (isOpLoggingEnabled()) { + std::cout << " ref " << tensor_type_prefix(id) << id.idx << std::flush; + } + if (!program) { + throw std::runtime_error("tensor_ref: Program not bound"); + } + if (id.idx >= program->num_tensors()) { + throw std::out_of_range("tensor_ref: id out of range"); + } + if (program->is_constant_tensor(id)) { + throw std::runtime_error("tensor_ref: cannot mutate constant tensor"); + } + // Route to mutable buffers or per-execution tensors + Tensor* t = nullptr; + if (is_mutable_buffer(id)) { + if (!mutable_buffers) { + throw std::runtime_error("tensor_ref: mutable_buffers not bound"); + } + t = &mutable_buffers->get(id); + } else { + uint32_t idx = tensor_index(id); + if (idx >= tensors.size()) { + throw std::out_of_range("tensor_ref: tensor idx out of range"); + } + auto& opt = tensors[idx]; + if (!opt) { + throw std::runtime_error( + "tensor_ref: uninitialized tensor idx=" + std::to_string(id.idx)); + } + t = &*opt; + } + if (isOpLoggingEnabled()) { + std::cout << " " << format_tensor_info(*t) << "\n"; + } + return *t; + } + + inline const Tensor& const_tensor_ref(Tid id) 
const { + if (isOpLoggingEnabled()) { + std::cout << " in " << tensor_type_prefix(id) << id.idx << std::flush; + } + if (!program) { + throw std::runtime_error("const_tensor_ref: Program not bound"); + } + if (id.idx >= program->num_tensors()) { + throw std::out_of_range("const_tensor_ref: id out of range"); + } + + const Tensor* t = nullptr; + if (program->is_constant_tensor(id)) { + // Route to constants + if (!constants) { + throw std::runtime_error("const_tensor_ref: constants not bound"); + } + t = &constants->get(id); + } else if (is_mutable_buffer(id)) { + // Route to mutable buffers + if (!mutable_buffers) { + throw std::runtime_error("const_tensor_ref: mutable_buffers not bound"); + } + t = &mutable_buffers->get(id); + } else { + // Route to per-execution tensors + uint32_t idx = tensor_index(id); + if (idx >= tensors.size()) { + throw std::out_of_range("const_tensor_ref: tensor idx out of range"); + } + const auto& opt = tensors[idx]; + if (!opt) { + throw std::runtime_error( + "const_tensor_ref: uninitialized tensor idx=" + + std::to_string(id.idx)); + } + t = &*opt; + } + + if (isOpLoggingEnabled()) { + std::cout << " " << format_tensor_info(*t) << " " + << format_tensor_stats(*t) << "\n"; + } + return *t; + } + + // Set a tensor output + inline void set_tensor(Tid id, Tensor arr) { + if (isOpLoggingEnabled()) { + std::cout << " out " << tensor_type_prefix(id) << id.idx << " " + << format_tensor_info(arr) << " " << format_tensor_stats(arr) + << "\n"; + } + if (!program) { + throw std::runtime_error("set_tensor: Program not bound"); + } + if (id.idx < program->num_constant_tensors) { + throw std::runtime_error("set_tensor: cannot write to constant tensor"); + } + // Route to mutable buffers or per-execution tensors + if (is_mutable_buffer(id)) { + if (!mutable_buffers) { + throw std::runtime_error("set_tensor: mutable_buffers not bound"); + } + mutable_buffers->set(id, std::move(arr)); + } else { + uint32_t idx = tensor_index(id); + if (idx >= 
tensors.size()) { + throw std::out_of_range("set_tensor: tensor idx out of range"); + } + tensors[idx] = std::move(arr); + } + } + + template + inline T& value_ref(Vid id) { + if (isOpLoggingEnabled()) { + std::cout << " ref v" << id.idx << std::flush; + } + if (id.idx >= values.size()) { + throw std::out_of_range("value_ref: id out of range"); + } + auto& opt = values[id.idx]; + if (!opt) { + throw std::runtime_error( + "value_ref: uninitialized value idx=" + std::to_string(id.idx)); + } + if (isOpLoggingEnabled()) { + std::cout << " " << std::get(*opt) << "\n"; + } + return std::get(*opt); + } + + template + inline const T& const_value_ref(Vid id) const { + if (isOpLoggingEnabled()) { + std::cout << " in v" << id.idx << std::flush; + } + if (id.idx >= values.size()) { + throw std::out_of_range("const_value_ref: id out of range"); + } + const auto& opt = values[id.idx]; + if (!opt) { + throw std::runtime_error( + "const_value_ref: uninitialized value idx=" + std::to_string(id.idx)); + } + if (isOpLoggingEnabled()) { + std::cout << " " << std::get(*opt) << "\n"; + } + return std::get(*opt); + } + + inline const Value& const_value(Vid id) const { + if (isOpLoggingEnabled()) { + std::cout << " in v" << id.idx << std::flush; + } + if (id.idx >= values.size()) { + throw std::out_of_range("const_value: id out of range"); + } + const auto& opt = values[id.idx]; + if (!opt) { + throw std::runtime_error( + "const_value: uninitialized value idx=" + std::to_string(id.idx)); + } + if (isOpLoggingEnabled()) { + std::visit([](auto&& arg) { std::cout << " " << arg << "\n"; }, *opt); + } + return *opt; + } + + template + inline void set_value(Vid id, T val) { + if (isOpLoggingEnabled()) { + std::cout << " out v" << id.idx << " " << val << "\n"; + } + if (id.idx >= values.size()) { + throw std::out_of_range("set_value: id out of range"); + } + values[id.idx] = val; + } +}; + +inline ::mlx::core::Dtype resolve_dtype(ScalarType d) { + using namespace ::mlx::core; + switch (d) { + 
case ScalarType::Half: + return float16; + case ScalarType::Float: + return float32; + case ScalarType::BFloat16: + return bfloat16; + case ScalarType::Int: + return int32; + case ScalarType::Short: + return int16; + case ScalarType::Long: + return int64; + case ScalarType::UInt32: + return uint32; + case ScalarType::Byte: + return uint8; + case ScalarType::Bool: + return bool_; + case ScalarType::Char: + return int8; + default: + throw std::runtime_error( + "Unsupported ScalarType: " + std::to_string(static_cast(d))); + } +} + +inline ::mlx::core::Dtype resolve_dtype(int8_t d) { + return resolve_dtype(static_cast(d)); +} + +// Maximum allocation size for any single tensor created from untrusted data. +// This bounds GPU memory allocation from malformed payloads. +constexpr size_t kMaxAllocationBytes = + static_cast(4) * 1024 * 1024 * 1024; // 4 GB + +/// Validate that a tensor with the given shape and dtype does not exceed +/// kMaxAllocationBytes. Throws std::runtime_error on invalid dimensions +/// or if the total size exceeds the limit. +inline void check_allocation_bounded( + const ::mlx::core::Shape& shape, + ::mlx::core::Dtype dtype, + const char* context) { + size_t elem_size = ::mlx::core::size_of(dtype); + size_t numel = 1; + for (auto d : shape) { + if (d <= 0) { + throw std::runtime_error( + std::string(context) + ": invalid dimension " + std::to_string(d)); + } + numel = safe_mul(numel, static_cast(d), context); + } + size_t total_bytes = safe_mul(numel, elem_size, context); + if (total_bytes > kMaxAllocationBytes) { + throw std::runtime_error( + std::string(context) + ": allocation exceeds 4GB limit"); + } +} + +inline int32_t clamp_to_int32(int64_t val64) { + // INT64_MAX is commonly used as a sentinel for "slice to end". + // Non-sentinel large values are silently clamped, which may change + // slice semantics — but this matches PyTorch behavior. 
+ if (val64 >= static_cast(std::numeric_limits::max())) { + return std::numeric_limits::max(); + } else if ( + val64 <= static_cast(std::numeric_limits::min())) { + return std::numeric_limits::min(); + } + return static_cast(val64); +} + +inline int32_t resolve_int( + const std::variant& v, + const ExecutionState& st) { + if (std::holds_alternative(v)) { + return clamp_to_int32(std::get(v)); + } + return st.const_value_ref(std::get(v)); +} + +inline std::vector resolve_ints( + const std::vector>& v, + const ExecutionState& st) { + std::vector out; + out.reserve(v.size()); + for (const auto& elem : v) { + out.push_back(resolve_int(elem, st)); + } + return out; +} + +inline float resolve_float( + const std::variant& v, + const ExecutionState& st) { + if (std::holds_alternative(v)) { + return static_cast(std::get(v)); + } + // The value may be stored as int32_t (from SymInt computations) or float. + const auto& val = st.const_value(std::get(v)); + return std::visit( + [](auto&& arg) -> float { return static_cast(arg); }, val); +} + +inline ::mlx::core::Shape to_shape( + const std::vector>& dims, + const ExecutionState& st) { + auto resolved = resolve_ints(dims, st); + return ::mlx::core::Shape(resolved.begin(), resolved.end()); +} + +inline ::mlx::core::Shape to_shape(const std::vector& dims) { + return ::mlx::core::Shape(dims.begin(), dims.end()); +} + +// Overload for static shapes (used when loading constants where all dims must +// be literals) +// Convert ShapeDim vector to MLX Shape (for constants and mutable buffers). +// Only static dimensions are allowed — dynamic dims (value == -1) are rejected. +inline ::mlx::core::Shape to_shape(const std::vector& dims) { + ::mlx::core::Shape out; + out.reserve(dims.size()); + for (const auto& d : dims) { + if (d.is_dynamic()) { + throw std::runtime_error( + "to_shape: expected static shape but found dynamic dimension"); + } + out.push_back(d.value); + } + return out; +} + +// Load constants from ExecuTorch's NamedDataMap. 
+// Constants are stored by name in the .pte file and loaded via the +// named_data_map interface. This allows ExecuTorch to own the constant data and +// enables zero-copy on Apple Silicon unified memory. +// +// Parameters: +// program: The loaded MLXProgram containing tensor metadata and named_slots +// named_data_map: ExecuTorch's interface for accessing named data +// store: Output storage for loaded constant tensors +// constant_buffers: Vector to store FreeableBuffers (must outlive store for +// zero-copy) +inline void load_constants( + const MLXProgram& program, + const runtime::NamedDataMap* named_data_map, + ConstantData& store, + std::vector& constant_buffers) { + using namespace ::mlx::core; + + store.tensors.clear(); + constant_buffers.clear(); + + if (program.num_constant_tensors == 0) { + return; + } + + store.tensors.reserve(program.num_constant_tensors); + constant_buffers.reserve(program.num_constant_tensors); + + // Build tid -> name map for O(1) lookup + std::unordered_map tid_to_name; + tid_to_name.reserve(program.named_slots.size()); + for (const auto& ns : program.named_slots) { + if (ns.slot.slot_type == SlotType::TensorSlot) { + tid_to_name[ns.slot.idx] = &ns.name; + } + } + + // Load each constant tensor by name + for (uint32_t tid = 0; tid < program.num_constant_tensors; ++tid) { + // Get tensor metadata + if (tid >= program.tensor_meta.size() || !program.tensor_meta[tid]) { + throw std::runtime_error( + "load_constants: missing metadata for constant " + + std::to_string(tid)); + } + + // Find the name for this tensor ID + auto it = tid_to_name.find(tid); + const std::string* name = (it != tid_to_name.end()) ? 
it->second : nullptr; + if (!name) { + throw std::runtime_error( + "load_constants: no name found for constant tensor " + + std::to_string(tid)); + } + + // Get data from named_data_map + if (named_data_map == nullptr) { + throw std::runtime_error( + "load_constants: named_data_map is null but program has constants"); + } + + auto data_result = named_data_map->get_data(name->c_str()); + if (!data_result.ok()) { + throw std::runtime_error( + "load_constants: failed to get data for constant '" + *name + + "': error " + std::to_string(static_cast(data_result.error()))); + } + + // Move the buffer into our storage (keeps it alive for zero-copy) + constant_buffers.push_back(std::move(data_result.get())); + runtime::FreeableBuffer& buffer = constant_buffers.back(); + + const auto& meta = *program.tensor_meta[tid]; + Shape shape = to_shape(meta.shape); + Dtype dtype = resolve_dtype(meta.scalar_type); + + // Create MLX array with zero-copy when enabled. + // SAFETY: Constants are read-only; the program builder ensures no in-place + // ops target constant tensors. 
The const_cast is required by MLX's array + // constructor but the data will not be mutated + void* data_ptr = const_cast(buffer.data()); + + if constexpr (kEnableConstantZeroCopy) { + // Zero-copy: wrap pointer directly with no-op deleter + // The FreeableBuffer in constant_buffers keeps the data alive + auto deleter = [](void*) { + // Data lifetime managed by FreeableBuffer in + // MLXHandle::constant_buffers + }; + array arr = array(data_ptr, shape, dtype, deleter); + store.add(std::move(arr)); + } else { + // No deleter = MLX copies the data into its own memory + store.add(array(static_cast(data_ptr), shape, dtype)); + } + } + + // Evaluate all constants immediately to prepare Metal buffers + // This trades init time for faster first inference + eval(store.tensors); +} + +inline void load_mutable_buffers( + const MLXProgram& program, + MutableBufferData& store) { + using namespace ::mlx::core; + + store.clear(); + + if (program.mutable_buffer_map.empty()) { + return; + } + + // Pre-size the storage to fit all tensor IDs + // Mutable buffer IDs are in the global tensor ID space + uint32_t max_tid = 0; + for (const auto& slot : program.mutable_buffer_map) { + if (slot.idx > max_tid) { + max_tid = slot.idx; + } + } + if (max_tid >= 1'000'000) { + throw std::runtime_error( + "load_mutable_buffers: max_tid " + std::to_string(max_tid) + + " exceeds limit"); + } + store.resize(max_tid + 1); + + for (const auto& slot : program.mutable_buffer_map) { + if (slot.slot_type != SlotType::TensorSlot) { + throw std::runtime_error( + "load_mutable_buffers: unexpected slot type " + + std::to_string(static_cast(slot.slot_type))); + } + + Tid tid{slot.idx}; + + // Get metadata for this tensor + if (tid.idx >= program.tensor_meta.size()) { + ET_LOG( + Error, + "load_mutable_buffers: tid %u >= tensor_meta.size() %zu", + tid.idx, + program.tensor_meta.size()); + throw std::runtime_error( + "load_mutable_buffers: tensor index out of range for tensor " + + std::to_string(tid.idx)); + } 
+ + if (!program.tensor_meta[tid.idx]) { + ET_LOG( + Error, + "load_mutable_buffers: missing metadata for tensor %u", + tid.idx); + throw std::runtime_error( + "load_mutable_buffers: missing metadata for tensor " + + std::to_string(tid.idx)); + } + + const auto& meta = *program.tensor_meta[tid.idx]; + auto shape = to_shape(meta.shape); + auto dtype = resolve_dtype(meta.scalar_type); + + check_allocation_bounded(shape, dtype, "load_mutable_buffers"); + + // Initialize mutable buffer to zeros + // This matches the typical initialization of KV cache buffers + auto arr = zeros(shape, dtype); + + // Evaluate immediately to allocate in GPU memory + eval(arr); + + store.set(tid, std::move(arr)); + } +} + +} // namespace mlx +} // namespace backends +} // namespace executorch diff --git a/backends/mlx/runtime/MLXInterpreter.h b/backends/mlx/runtime/MLXInterpreter.h new file mode 100644 index 00000000000..f3b6e9b720f --- /dev/null +++ b/backends/mlx/runtime/MLXInterpreter.h @@ -0,0 +1,169 @@ +// +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. +// + +#pragma once + +#include "MLXExecutor.h" + +#include +#include +#include +#include + +namespace executorch { +namespace backends { +namespace mlx { + +namespace ops { + +using namespace ::mlx::core; + +/** + * Normalize axis to be in range [0, rank) and validate. 
+ * @param axis The axis value (can be negative) + * @param rank The tensor rank + * @param op_name Name of the operation for error messages + * @return Normalized axis in range [0, rank) + * @throws std::out_of_range if axis is out of range + */ +inline int normalize_axis(int axis, int rank, const char* op_name) { + if (axis < -rank || axis >= rank) { + throw std::out_of_range(std::string(op_name) + ": axis out of range"); + } + if (axis < 0) + axis += rank; + return axis; +} + +/** + * Infers dimensions with -1 in a reshape-like operation. + * + * PyTorch allows -1 in shapes to mean "infer this dimension from total size". + * MLX requires concrete positive integers, so we must resolve -1 values. + * + * @param shape The shape to resolve (may contain -1) + * @param input_size Total number of elements in the input tensor + * @return Resolved shape with all positive integers + * @throws std::runtime_error if shape has multiple -1 or incompatible sizes + */ +inline std::vector infer_shape_with_minus_one( + const std::vector& shape, + size_t input_size) { + std::vector resolved_shape = shape; + int neg_one_idx = -1; + int64_t known_size = 1; // Use int64_t to avoid overflow + + // Find -1 dimension and compute product of known dimensions + for (size_t i = 0; i < resolved_shape.size(); i++) { + if (resolved_shape[i] == -1) { + if (neg_one_idx != -1) { + throw std::runtime_error("infer_shape: only one dimension can be -1"); + } + neg_one_idx = static_cast(i); + } else { + known_size *= static_cast(resolved_shape[i]); + } + } + + // Infer the -1 dimension if present + if (neg_one_idx != -1) { + if (known_size == 0) { + throw std::runtime_error( + "infer_shape: cannot infer -1 dimension when known product is 0"); + } + int64_t input_size_i64 = static_cast(input_size); + if (input_size_i64 % known_size != 0) { + throw std::runtime_error( + "infer_shape: cannot infer dimension - size mismatch"); + } + int64_t inferred_dim = input_size_i64 / known_size; + + // Check that 
inferred dimension fits in int + if (inferred_dim > std::numeric_limits::max()) { + throw std::runtime_error( + "infer_shape: inferred dimension exceeds int max"); + } + + resolved_shape[static_cast(neg_one_idx)] = + static_cast(inferred_dim); + } + + return resolved_shape; +} + +inline void exec_noop(const NoopNode&, ExecutionState&, StreamOrDevice) {} + +inline void +exec_addmm(const AddmmNode& n, ExecutionState& st, StreamOrDevice s) { + const auto& mat1 = st.const_tensor_ref(n.mat1); + const auto& mat2 = st.const_tensor_ref(n.mat2); + + array Y = n.bias ? addmm( + st.const_tensor_ref(*n.bias), + mat1, + mat2, + /*alpha=*/n.alpha, + /*beta=*/n.beta, + s) + : matmul(mat1, mat2, s); + + st.set_tensor(n.out, std::move(Y)); +} + +} // namespace ops + +class Interpreter { + public: + void run( + const MLXProgram& prog, + ExecutionState& st, + StreamOrDevice stream = {}) const { + run_chain(prog, prog.main_chain_idx, st, stream); + } + + void run_chain( + const MLXProgram& prog, + uint32_t chain_idx, + ExecutionState& st, + StreamOrDevice stream = {}) const { + if (chain_idx >= prog.instruction_chains.size()) { + throw std::runtime_error( + "run_chain: chain_idx " + std::to_string(chain_idx) + + " out of range (num_chains=" + + std::to_string(prog.instruction_chains.size()) + ")"); + } + const auto& chain = prog.instruction_chains[chain_idx]; + size_t idx = 0; + for (const auto& instr : chain) { + st.begin_op(idx, op_name(instr.op)); + dispatch(instr, st, stream); + st.end_op(); + ++idx; + } + } + + private: + void dispatch(const Instruction& instr, ExecutionState& st, StreamOrDevice s) + const { + switch (instr.op) { + case OpCode::NOOP: + ops::exec_noop(std::get(instr.node), st, s); + break; + case OpCode::ADDMM: + ops::exec_addmm(std::get(instr.node), st, s); + break; + default: + throw std::runtime_error( + "Unknown opcode: " + std::to_string(static_cast(instr.op))); + } + } +}; + +} // namespace mlx +} // namespace backends +} // namespace executorch diff --git 
a/backends/mlx/serialization/MLXLoader.cpp.tmpl b/backends/mlx/serialization/MLXLoader.cpp.tmpl new file mode 100644 index 00000000000..aa4716d7a4a --- /dev/null +++ b/backends/mlx/serialization/MLXLoader.cpp.tmpl @@ -0,0 +1,324 @@ +// -*- c++ -*- + +#include "MLXLoader.h" + +#include +#include + +namespace executorch { +namespace backends { +namespace mlx { +namespace loader { + +namespace { + +// Header structure for MLX payload +constexpr size_t kHeaderSize = 24; +constexpr uint32_t kMagic = 0x30584C4D; // "MLX0" in little-endian + +struct MLXHeader { + uint32_t padding; + uint32_t magic; + uint64_t data_offset; + uint64_t data_size; +}; +static_assert(sizeof(MLXHeader) == kHeaderSize, "MLXHeader size mismatch"); + +bool parse_header(const void* data, size_t size, MLXHeader& header) { + if (size < kHeaderSize) { + return false; + } + std::memcpy(&header, data, sizeof(MLXHeader)); + if (header.magic != kMagic) { + return false; + } + // Validate data_offset: must be strictly greater than kHeaderSize (so the + // FlatBuffer region is non-empty) and must not exceed the total buffer size. + if (header.data_offset <= kHeaderSize || header.data_offset > size) { + return false; + } + return true; +} + +// Helper to convert FlatBuffer vectors to std::vector. +// Caps size to prevent unbounded allocations from malformed payloads. 
+template +std::vector to_vector(const flatbuffers::Vector* fb_vec) { + if (!fb_vec) { + return {}; + } + constexpr size_t kMaxVectorSize = 1'000'000; + if (fb_vec->size() > kMaxVectorSize) { + throw std::runtime_error( + "FlatBuffer vector size " + std::to_string(fb_vec->size()) + + " exceeds maximum of " + std::to_string(kMaxVectorSize)); + } + return std::vector(fb_vec->begin(), fb_vec->end()); +} + +} // namespace + +// ============================================================================= +// load_instruction - AUTO-GENERATED switch statement +// ============================================================================= + +Instruction load_instruction(const mlx_delegate::Instruction* fb_instr) { + Instruction instr; + + if (!fb_instr || !fb_instr->op()) { + instr.op = OpCode::NOOP; + instr.node = NoopNode{}; + return instr; + } + + auto op_type = fb_instr->op_type(); + + switch (op_type) { +{{LOAD_INSTRUCTION_CASES}} + default: + throw std::runtime_error( + "Unknown op_type in load_instruction: " + + std::to_string(static_cast(op_type)) + + ". The .pte was built with a newer schema than this binary. " + "Rebuild with the latest runtime."); + } + + return instr; +} + +// ============================================================================= +// load_program +// ============================================================================= + +MLXProgram load_program(const void* data, size_t size) { + MLXHeader header; + if (!parse_header(data, size, header)) { + throw std::runtime_error("Invalid MLX header"); + } + + // Defense-in-depth: parse_header already validates this, but guard the + // unsigned subtraction against underflow in case the call site ever changes. 
+ if (header.data_offset <= kHeaderSize || header.data_offset > size) { + throw std::runtime_error("data_offset out of range"); + } + const uint8_t* fb_data = static_cast(data) + kHeaderSize; + size_t fb_size = header.data_offset - kHeaderSize; + + flatbuffers::Verifier verifier(fb_data, fb_size); + if (!mlx_delegate::VerifyMLXGraphBuffer(verifier)) { + throw std::runtime_error("Invalid FlatBuffer data"); + } + + const auto* fb_graph = mlx_delegate::GetMLXGraph(fb_data); + if (!fb_graph) { + throw std::runtime_error("Failed to parse MLXGraph"); + } + + MLXProgram program; + + if (fb_graph->version()) { + program.version = fb_graph->version()->str(); + } + + program.num_constant_tensors = fb_graph->num_constant_tensors(); + program.num_input_tensors = fb_graph->num_input_tensors(); + program.num_output_tensors = fb_graph->num_output_tensors(); + program.num_mutable_buffer_tensors = fb_graph->num_mutable_buffer_tensors(); + program.num_temp_tensors = fb_graph->num_temp_tensors(); + program.num_values = fb_graph->num_values(); + + // Cap all counts/collection sizes to prevent unbounded allocations from + // malformed FlatBuffer payloads + constexpr size_t kMaxCollectionSize = 1'000'000; + auto check_collection_size = [](size_t sz, const char* name) { + if (sz > kMaxCollectionSize) { + throw std::runtime_error( + std::string("Malformed program: ") + name + " size " + + std::to_string(sz) + " exceeds maximum of " + + std::to_string(kMaxCollectionSize)); + } + }; + + check_collection_size(program.num_tensors(), "num_tensors()"); + check_collection_size(program.num_values, "num_values"); + + if (fb_graph->instruction_chains()) { + check_collection_size(fb_graph->instruction_chains()->size(), "instruction_chains"); + program.instruction_chains.reserve(fb_graph->instruction_chains()->size()); + for (size_t c = 0; c < fb_graph->instruction_chains()->size(); ++c) { + const auto* fb_chain = fb_graph->instruction_chains()->Get(static_cast(c)); + std::vector chain; + if 
(fb_chain && fb_chain->instructions()) { + check_collection_size(fb_chain->instructions()->size(), "instructions in chain"); + chain.reserve(fb_chain->instructions()->size()); + for (size_t i = 0; i < fb_chain->instructions()->size(); ++i) { + chain.push_back(load_instruction(fb_chain->instructions()->Get(static_cast(i)))); + } + } + program.instruction_chains.push_back(std::move(chain)); + } + } + + program.main_chain_idx = fb_graph->main_chain_idx(); + program.init_chain_idx = fb_graph->init_chain_idx(); + + // Validate chain indices against actual instruction_chains size. + if (program.main_chain_idx >= program.instruction_chains.size()) { + throw std::runtime_error( + "Invalid main_chain_idx " + + std::to_string(program.main_chain_idx) + + " (only " + std::to_string(program.instruction_chains.size()) + + " chains loaded)"); + } + if (program.init_chain_idx >= 0 && + static_cast(program.init_chain_idx) >= + program.instruction_chains.size()) { + throw std::runtime_error( + "Invalid init_chain_idx " + + std::to_string(program.init_chain_idx) + + " (only " + std::to_string(program.instruction_chains.size()) + + " chains loaded)"); + } + + if (fb_graph->input_map()) { + check_collection_size(fb_graph->input_map()->size(), "input_map"); + for (size_t i = 0; i < fb_graph->input_map()->size(); ++i) { + const auto* slot = fb_graph->input_map()->Get(static_cast(i)); + auto sv = convert_slot_variant(slot); + if (sv.slot_type == SlotType::TensorSlot && + sv.idx >= program.num_tensors()) { + throw std::runtime_error( + "input_map: slot index " + std::to_string(sv.idx) + + " exceeds num_tensors " + + std::to_string(program.num_tensors())); + } + program.input_map.push_back(sv); + } + } + + if (fb_graph->output_map()) { + check_collection_size(fb_graph->output_map()->size(), "output_map"); + for (size_t i = 0; i < fb_graph->output_map()->size(); ++i) { + const auto* slot = fb_graph->output_map()->Get(static_cast(i)); + auto sv = convert_slot_variant(slot); + if (sv.slot_type 
== SlotType::TensorSlot && + sv.idx >= program.num_tensors()) { + throw std::runtime_error( + "output_map: slot index " + std::to_string(sv.idx) + + " exceeds num_tensors " + + std::to_string(program.num_tensors())); + } + program.output_map.push_back(sv); + } + } + + if (fb_graph->mutable_buffer_map()) { + check_collection_size(fb_graph->mutable_buffer_map()->size(), "mutable_buffer_map"); + for (size_t i = 0; i < fb_graph->mutable_buffer_map()->size(); ++i) { + const auto* slot = fb_graph->mutable_buffer_map()->Get(static_cast(i)); + auto sv = convert_slot_variant(slot); + if (sv.slot_type == SlotType::TensorSlot && + sv.idx >= program.num_tensors()) { + throw std::runtime_error( + "mutable_buffer_map: slot index " + std::to_string(sv.idx) + + " exceeds num_tensors " + + std::to_string(program.num_tensors())); + } + program.mutable_buffer_map.push_back(sv); + } + } + + if (fb_graph->named_slots()) { + check_collection_size(fb_graph->named_slots()->size(), "named_slots"); + for (size_t i = 0; i < fb_graph->named_slots()->size(); ++i) { + const auto* fb_slot = fb_graph->named_slots()->Get(static_cast(i)); + if (!fb_slot || !fb_slot->name()) { + throw std::runtime_error( + "Malformed program: named_slot at index " + std::to_string(i) + + " is null or has null name"); + } + NamedSlot slot; + slot.name = fb_slot->name()->str(); + slot.slot = convert_slot_variant(fb_slot->slot()); + program.named_slots.push_back(std::move(slot)); + } + } + + if (fb_graph->tensor_meta()) { + check_collection_size(fb_graph->tensor_meta()->size(), "tensor_meta"); + for (size_t i = 0; i < fb_graph->tensor_meta()->size(); ++i) { + const auto* fb_meta = fb_graph->tensor_meta()->Get(static_cast(i)); + if (fb_meta) { + TensorMeta meta; + if (fb_meta->shape()) { + // Validate tensor rank against kTensorDimensionLimit to prevent + // stack overflows from unchecked rank + constexpr size_t kTensorDimensionLimit = 16; + if (fb_meta->shape()->size() > kTensorDimensionLimit) { + throw 
std::runtime_error( + "Tensor at index " + std::to_string(i) + + " has rank " + std::to_string(fb_meta->shape()->size()) + + " exceeding kTensorDimensionLimit (" + + std::to_string(kTensorDimensionLimit) + ")"); + } + for (size_t j = 0; j < fb_meta->shape()->size(); ++j) { + const auto* fb_dim = fb_meta->shape()->Get(static_cast(j)); + if (!fb_dim) { + throw std::runtime_error( + "Null ShapeDim at index " + std::to_string(j) + + " in tensor_meta " + std::to_string(i)); + } + ShapeDim dim; + dim.value = fb_dim->value(); + dim.min_value = fb_dim->min_value(); + dim.max_value = fb_dim->max_value(); + if (dim.value < -1) { + throw std::runtime_error( + "Invalid ShapeDim value " + std::to_string(dim.value) + + " at index " + std::to_string(j) + + " in tensor_meta " + std::to_string(i)); + } + if (dim.is_dynamic()) { + if (dim.min_value < 0) { + throw std::runtime_error( + "Invalid ShapeDim min_value " + std::to_string(dim.min_value) + + " at index " + std::to_string(j) + + " in tensor_meta " + std::to_string(i)); + } + if (dim.max_value != -1 && dim.max_value < dim.min_value) { + throw std::runtime_error( + "ShapeDim max_value " + std::to_string(dim.max_value) + + " < min_value " + std::to_string(dim.min_value) + + " at index " + std::to_string(j) + + " in tensor_meta " + std::to_string(i)); + } + } + meta.shape.push_back(dim); + } + } + auto raw_scalar_type = fb_meta->scalar_type(); + if (raw_scalar_type < 0 || + raw_scalar_type >= + static_cast(ScalarType::NumOptions)) { + throw std::runtime_error( + "Invalid scalar_type " + std::to_string(raw_scalar_type) + + " in tensor_meta at index " + std::to_string(i)); + } + meta.scalar_type = static_cast(raw_scalar_type); + if (fb_meta->dim_order()) { + meta.dim_order = to_vector(fb_meta->dim_order()); + } + program.tensor_meta.push_back(std::move(meta)); + } else { + program.tensor_meta.push_back(std::nullopt); + } + } + } + + return program; +} + +} // namespace loader +} // namespace mlx +} // namespace backends +} // 
namespace executorch diff --git a/backends/mlx/serialization/MLXLoader.h.tmpl b/backends/mlx/serialization/MLXLoader.h.tmpl new file mode 100644 index 00000000000..0930d5e00e1 --- /dev/null +++ b/backends/mlx/serialization/MLXLoader.h.tmpl @@ -0,0 +1,343 @@ +// -*- c++ -*- + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include "schema_generated.h" + +// ExecuTorch scalar type for dtype representation +#include + +namespace executorch { +namespace backends { +namespace mlx { + +// ============================================================================= +// Core types matching the Python side +// ============================================================================= + +struct Tid { + uint32_t idx{}; +}; + +struct Vid { + uint32_t idx{}; +}; + +// ============================================================================= +// Tensor metadata +// ============================================================================= + +// Import ScalarType from ExecuTorch +using ScalarType = ::executorch::runtime::etensor::ScalarType; + +struct ShapeDim { + int32_t value{-1}; // Static dim (>= 0), or -1 for dynamic + int32_t min_value{0}; // Lower bound (when value == -1) + int32_t max_value{-1}; // Upper bound (-1 = unbounded, when value == -1) + + bool is_dynamic() const { return value < 0; } +}; + +struct TensorMeta { + std::vector shape; + ScalarType scalar_type{ScalarType::Float}; // ET ScalarType + std::vector dim_order; +}; + +// VidOrTid: either a scalar value (Vid) or a tensor (Tid) +struct VidOrTid { + Vid vid{}; + Tid tid{}; + bool is_vid{false}; // false = use tid, true = use vid +}; + +// IntOrVidOrTid: a literal int, a runtime Vid, or a tensor (Tid) +struct IntOrVidOrTid { + int64_t literal{0}; + Vid vid{}; + Tid tid{}; + uint8_t kind{0}; // 0 = literal int, 1 = vid, 2 = tid +}; + +// ============================================================================= +// Op node types (AUTO-GENERATED from 
schema.fbs) +// ============================================================================= + +{{OP_NODE_STRUCTS}} + +// ============================================================================= +// OpCode enum (AUTO-GENERATED from schema.fbs) +// ============================================================================= + +enum class OpCode : uint8_t { +{{OPCODE_ENUM_VALUES}} +}; + +// OpCode to string conversion (for logging) +inline const char* op_name(OpCode op) { + switch (op) { +{{OP_NAME_CASES}} + } + return "UNKNOWN"; +} + +// ============================================================================= +// NodeVariant for type-erased op storage (AUTO-GENERATED) +// ============================================================================= + +using NodeVariant = std::variant< +{{NODE_VARIANT_TYPES}} +>; + +// ============================================================================= +// Instruction +// ============================================================================= + +struct Instruction { + OpCode op{OpCode::NOOP}; + NodeVariant node; + + template + T& get() { + return std::get(node); + } + + template + const T& get() const { + return std::get(node); + } +}; + +// ============================================================================= +// SlotVariant for I/O mapping +// ============================================================================= + +enum class SlotType : uint8_t { + TensorSlot = 0, + IntValueSlot = 1, + FloatValueSlot = 2, + BoolValueSlot = 3, +}; + +struct SlotVariant { + uint32_t idx; + SlotType slot_type; +}; + +// ============================================================================= +// Named slot (name -> slot mapping) +// ============================================================================= + +struct NamedSlot { + std::string name; + SlotVariant slot; +}; + +// ============================================================================= +// MLXProgram - the loaded program ready for 
execution +// ============================================================================= + +struct MLXProgram { + std::string version; + + // Tensor/value slot counts (in Tid assignment order) + uint32_t num_constant_tensors{0}; + uint32_t num_input_tensors{0}; + uint32_t num_output_tensors{0}; + uint32_t num_mutable_buffer_tensors{0}; + uint32_t num_temp_tensors{0}; + uint32_t num_values{0}; + + // Instruction chains + std::vector> instruction_chains; + uint32_t main_chain_idx{0}; + int32_t init_chain_idx{-1}; // -1 = no init chain + + // I/O mappings + std::vector input_map; + std::vector output_map; + std::vector mutable_buffer_map; + + // Name to slot lookup + std::vector named_slots; + + // Tensor metadata + std::vector> tensor_meta; + + // Helper methods + inline uint64_t num_tensors() const { + return static_cast(num_constant_tensors) + + num_input_tensors + num_output_tensors + + num_mutable_buffer_tensors + num_temp_tensors; + } + + inline bool is_constant_tensor(Tid id) const { + return id.idx < num_constant_tensors; + } + + inline size_t num_inputs() const { + return input_map.size(); + } + + inline size_t num_outputs() const { + return output_map.size(); + } +}; + +// ============================================================================= +// FlatBuffer loading functions +// ============================================================================= + +namespace loader { + +// Convert FlatBuffer SlotType to our SlotType +inline SlotType convert_slot_type(mlx_delegate::SlotType fb_type) { + switch (fb_type) { + case mlx_delegate::SlotType_TensorSlot: + return SlotType::TensorSlot; + case mlx_delegate::SlotType_IntValueSlot: + return SlotType::IntValueSlot; + case mlx_delegate::SlotType_FloatValueSlot: + return SlotType::FloatValueSlot; + case mlx_delegate::SlotType_BoolValueSlot: + return SlotType::BoolValueSlot; + default: + throw std::runtime_error("Unknown SlotType: " + + std::to_string(static_cast(fb_type))); + } +} + +// Convert 
FlatBuffer Tid +inline Tid convert_tid(const mlx_delegate::Tid* fb_tid) { + if (!fb_tid) { + throw std::runtime_error("Null Tid in FlatBuffer"); + } + return Tid{fb_tid->idx()}; +} + +// Convert FlatBuffer Vid +inline Vid convert_vid(const mlx_delegate::Vid* fb_vid) { + if (!fb_vid) { + throw std::runtime_error("Null Vid in FlatBuffer"); + } + return Vid{fb_vid->idx()}; +} + +// Convert FlatBuffer IntOrVid +inline std::variant convert_int_or_vid( + const mlx_delegate::IntOrVid* fb) { + if (!fb) { + throw std::runtime_error("Null IntOrVid in FlatBuffer"); + } + if (!fb->is_vid()) { + return fb->literal(); + } + const auto* vid_ptr = fb->vid(); + if (!vid_ptr) { + throw std::runtime_error("IntOrVid has is_vid=true but vid pointer is null"); + } + return Vid{vid_ptr->idx()}; +} + +// Convert FlatBuffer FloatOrVid +inline std::variant convert_float_or_vid( + const mlx_delegate::FloatOrVid* fb) { + if (!fb) { + throw std::runtime_error("Null FloatOrVid in FlatBuffer"); + } + if (!fb->is_vid()) { + return fb->literal(); + } + const auto* vid_ptr = fb->vid(); + if (!vid_ptr) { + throw std::runtime_error("FloatOrVid has is_vid=true but vid pointer is null"); + } + return Vid{vid_ptr->idx()}; +} + +// Convert FlatBuffer VidOrTid (scalar value or tensor) +inline VidOrTid convert_vid_or_tid( + const mlx_delegate::VidOrTid* fb) { + if (!fb) { + throw std::runtime_error("Null VidOrTid in FlatBuffer"); + } + VidOrTid result; + result.is_vid = fb->is_vid(); + if (result.is_vid) { + if (!fb->vid()) { + throw std::runtime_error("VidOrTid has is_vid=true but vid pointer is null"); + } + result.vid = Vid{fb->vid()->idx()}; + } else { + if (!fb->tid()) { + throw std::runtime_error("VidOrTid has is_vid=false but tid pointer is null"); + } + result.tid = Tid{fb->tid()->idx()}; + } + return result; +} + +// Convert FlatBuffer IntOrVidOrTid (literal int, Vid, or Tid) +inline IntOrVidOrTid convert_int_or_vid_or_tid( + const mlx_delegate::IntOrVidOrTid* fb) { + if (!fb) { + throw 
std::runtime_error("Null IntOrVidOrTid in FlatBuffer"); + } + IntOrVidOrTid result; + result.kind = fb->kind(); + switch (result.kind) { + case 0: // literal int + result.literal = fb->literal(); + break; + case 1: { // Vid + const auto* vid_ptr = fb->vid(); + if (!vid_ptr) { + throw std::runtime_error( + "IntOrVidOrTid has kind=1 (Vid) but vid pointer is null"); + } + result.vid = Vid{vid_ptr->idx()}; + break; + } + case 2: { // Tid + const auto* tid_ptr = fb->tid(); + if (!tid_ptr) { + throw std::runtime_error( + "IntOrVidOrTid has kind=2 (Tid) but tid pointer is null"); + } + result.tid = Tid{tid_ptr->idx()}; + break; + } + default: + throw std::runtime_error( + "IntOrVidOrTid has invalid kind: " + std::to_string(result.kind)); + } + return result; +} + +// Convert FlatBuffer SlotVariant +inline SlotVariant convert_slot_variant(const mlx_delegate::SlotVariant* fb) { + if (!fb) { + throw std::runtime_error("Null SlotVariant in FlatBuffer"); + } + return SlotVariant{fb->idx(), convert_slot_type(fb->slot_type())}; +} + +// Load an instruction from FlatBuffer +Instruction load_instruction(const mlx_delegate::Instruction* fb_instr); + +// Load the full MLXProgram from FlatBuffer data +MLXProgram load_program(const void* data, size_t size); + +} // namespace loader + +} // namespace mlx +} // namespace backends +} // namespace executorch diff --git a/backends/mlx/serialization/README.md b/backends/mlx/serialization/README.md new file mode 100644 index 00000000000..f2c022d0c80 --- /dev/null +++ b/backends/mlx/serialization/README.md @@ -0,0 +1,130 @@ +# MLX Delegate Serialization + +This directory contains the serialization code for the MLX delegate, which converts +Python graph representations to FlatBuffer format for execution on Apple Silicon. + +## Single Source of Truth: `schema.fbs` + +The FlatBuffer schema file `schema.fbs` is the **single source of truth** for all +serialization-related code. 
When you need to add a new op or modify existing types, +edit `schema.fbs` and regenerate all derived files. + +## Code Generator + +The `generate.py` script parses `schema.fbs` and generates: + +| Generated File | Description | +|----------------|-------------| +| `mlx_graph_schema.py` | Python dataclasses for all schema types | +| `_generated_serializers.py` | Python FlatBuffer serialization methods | +| `_generated/` | Python FlatBuffer reader classes (via `flatc`) | +| `../runtime/MLXLoader.h` | C++ structs, OpCode enum, NodeVariant | +| `../runtime/MLXLoader.cpp` | C++ `load_instruction()` switch statement | +| `../runtime/schema_generated.h` | C++ FlatBuffer reader classes (via `flatc`) | + +## Usage + +### Regenerate all files + +From the executorch root directory: + +```bash +python backends/mlx/serialization/generate.py +``` + +Or with explicit flatc path: + +```bash +python backends/mlx/serialization/generate.py --flatc /path/to/flatc +``` + +### Options + +``` +--flatc PATH Path to flatc compiler (default: "flatc") +--skip-flatc Skip running flatc (use existing FlatBuffer bindings) +--dry-run Print what would be generated without writing files +``` + +## File Structure + +``` +serialization/ +├── README.md # This file +├── schema.fbs # SOURCE OF TRUTH - FlatBuffer schema +├── generate.py # Code generator script +├── mlx_graph_schema.py # [GENERATED] Python dataclasses +├── mlx_graph_serialize.py # Main serializer (uses generated code) +├── _generated_serializers.py # [GENERATED] Op serialization methods +└── _generated/ # [GENERATED] FlatBuffer Python bindings + └── mlx_delegate/ + ├── *.py # One file per table/enum + +runtime/ +├── MLXLoader.h # [GENERATED] C++ types and loader decls +├── MLXLoader.cpp # [GENERATED] C++ loader implementation +├── schema_generated.h # [GENERATED] FlatBuffer C++ bindings +├── MLXInterpreter.h # C++ executor (manual) +├── MLXExecutor.h # C++ executor interface (manual) +└── MLXBackend.cpp # ExecuTorch backend integration 
(manual) +``` + +## Schema Design Notes + +### Field Types + +- `Tid` - Tensor slot identifier (indexes into tensor array) +- `Vid` - Value slot identifier (indexes into values array for scalars) +- `IntOrVid` - Either a literal int64 or a Vid (for dynamic shapes) +- `FloatOrVid` - Either a literal double or a Vid +- `DTypeId` - Data type enum (f16, f32, bf16, i32, etc.) + +### Optional Fields + +FlatBuffer fields without `(required)` are optional. In the generated Python +dataclasses, these become `Optional[T]` with default `None`. + +For optional scalar fields that need a sentinel (to distinguish None from 0), +use the `= null` default: + +```flatbuffers +table MyNode { + value: float = null; // None by default, distinguishes None from 0.0 +} +``` + +This requires FlatBuffers 2.0+ (ExecuTorch uses 24.3.25). The generated Python +dataclass will have `value: Optional[float] = None`. + +## Troubleshooting + +### flatc not found + +Install FlatBuffers or specify the path: + +```bash +# macOS +brew install flatbuffers + +# Or specify path +python generate.py --flatc /usr/local/bin/flatc +``` + +### Import errors after regeneration + +Make sure you're running from the correct environment: + +```bash +conda run -n et-mlx python backends/mlx/serialization/generate.py +``` + +### Generated code doesn't match schema + +Delete all generated files and regenerate: + +```bash +rm -rf backends/mlx/serialization/_generated +rm backends/mlx/serialization/mlx_graph_schema.py +rm backends/mlx/serialization/_generated_serializers.py +python backends/mlx/serialization/generate.py +``` diff --git a/backends/mlx/serialization/__init__.py b/backends/mlx/serialization/__init__.py new file mode 100644 index 00000000000..35a4f0cef8a --- /dev/null +++ b/backends/mlx/serialization/__init__.py @@ -0,0 +1,32 @@ +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# + +"""Serialization utilities for MLX delegate.""" + +from pathlib import Path + + +_schema_py = Path(__file__).parent / "mlx_graph_schema.py" +if not _schema_py.exists(): + raise ImportError( + "MLX delegate generated files not found. " + "Run 'python install_executorch.py' first." + ) + +# Export serialization functions for convenience +from executorch.backends.mlx.serialization.mlx_graph_serialize import ( # noqa: F401, E501 + deserialize_to_json, + parse_header, + serialize_mlx_graph, +) + +__all__ = [ + "deserialize_to_json", + "parse_header", + "serialize_mlx_graph", +] diff --git a/backends/mlx/serialization/generate.py b/backends/mlx/serialization/generate.py new file mode 100755 index 00000000000..d12743906db --- /dev/null +++ b/backends/mlx/serialization/generate.py @@ -0,0 +1,1437 @@ +#!/usr/bin/env python3 +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# +""" +Code generator for MLX delegate. + +This is the SINGLE SOURCE OF TRUTH generator. Edit schema.fbs, then run: + python generate.py + +Generates: +1. FlatBuffer bindings (via flatc): + - _generated/ (Python) + - ../runtime/schema_generated.h (C++) +2. mlx_graph_schema.py (Python dataclasses) +3. _generated_serializers.py (Python serialization code) +4. ../runtime/MLXLoader.h (C++ structs, enums) - PARTIAL +5. 
../runtime/MLXLoader.cpp (C++ loader switch) - PARTIAL + +Usage: + python generate.py [--flatc PATH_TO_FLATC] [--skip-flatc] +""" + +from __future__ import annotations + +import argparse +import re +import subprocess +from dataclasses import dataclass +from pathlib import Path +from typing import List, Optional, Tuple + + +SCRIPT_DIR = Path(__file__).parent +SCHEMA_FBS = SCRIPT_DIR / "schema.fbs" +GENERATED_DIR = SCRIPT_DIR / "_generated" +GENERATED_SERIALIZERS = SCRIPT_DIR / "_generated_serializers.py" +GENERATED_SCHEMA_PY = SCRIPT_DIR / "mlx_graph_schema.py" +GENERATED_INSPECTOR = SCRIPT_DIR.parent / "_generated_inspector.py" +RUNTIME_DIR = SCRIPT_DIR.parent / "runtime" +LOADER_H_TMPL = SCRIPT_DIR / "MLXLoader.h.tmpl" +LOADER_CPP_TMPL = SCRIPT_DIR / "MLXLoader.cpp.tmpl" +LOADER_H = RUNTIME_DIR / "MLXLoader.h" +LOADER_CPP = RUNTIME_DIR / "MLXLoader.cpp" + + +@dataclass +class FBSEnum: + name: str + base_type: str # e.g., "byte" + values: List[Tuple[str, Optional[int]]] # (name, explicit_value or None) + + +@dataclass +class FBSField: + name: str + type_str: str + required: bool + default: Optional[str] + + +# FBS integer types (signed and unsigned) +FBS_INTEGER_TYPES = frozenset( + { + "int8", + "int16", + "int32", + "int64", + "uint8", + "uint16", + "uint32", + "uint64", + } +) + +# FBS float types +FBS_FLOAT_TYPES = frozenset({"float", "double"}) + +# All FBS primitive scalar types (numbers + bool) +FBS_SCALAR_TYPES = FBS_INTEGER_TYPES | FBS_FLOAT_TYPES | frozenset({"bool"}) + +# Compound "or" types that wrap a literal + Vid +FBS_COMPOUND_TYPES = frozenset({"IntOrVid", "FloatOrVid", "VidOrTid", "IntOrVidOrTid"}) + +# Python type mapping for FBS primitives +FBS_TO_PYTHON = { + "int8": "int", + "int16": "int", + "int32": "int", + "int64": "int", + "uint8": "int", + "uint16": "int", + "uint32": "int", + "uint64": "int", + "float": "float", + "double": "float", + "bool": "bool", + "string": "str", + "byte": "int", +} + +# C++ type mapping for FBS primitives 
+FBS_TO_CPP = { + "int8": "int8_t", + "int16": "int16_t", + "int32": "int32_t", + "int64": "int64_t", + "uint8": "uint8_t", + "uint16": "uint16_t", + "uint32": "uint32_t", + "uint64": "uint64_t", + "float": "float", + "double": "double", + "bool": "bool", + "string": "std::string", + "byte": "uint8_t", + "Tid": "Tid", + "Vid": "Vid", + "IntOrVid": "std::variant", + "FloatOrVid": "std::variant", +} + + +def _section_header(comment: str, title: str) -> List[str]: + """Generate a section-header banner for generated output.""" + sep = f"{comment} {'=' * 76}" + return [sep, f"{comment} {title}", sep, ""] + + +def _file_header(comment: str, description: str = "") -> List[str]: + """Generate a standard auto-generated file header. + + Args: + comment: Comment prefix, e.g. '#' for Python or '//' for C++. + description: Optional description appended after the banner. + """ + sep = f"{comment} {'=' * 76}" + lines = [ + f"{comment}", + f"{comment} Copyright (c) Meta Platforms, Inc. and affiliates.", + f"{comment} All rights reserved.", + f"{comment}", + f"{comment} This source code is licensed under the BSD-style license found in the", + f"{comment} LICENSE file in the root directory of this source tree.", + f"{comment}", + sep, + f"{comment} AUTO-GENERATED FILE - DO NOT EDIT MANUALLY", + sep, + f"{comment}", + f"{comment} This file was generated from schema.fbs by the MLX delegate code generator.", + f"{comment}", + f"{comment} Source: backends/mlx/serialization/schema.fbs", + f"{comment} Generator: backends/mlx/serialization/generate.py", + f"{comment}", + f"{comment} To regenerate, run from the executorch root:", + f"{comment} python backends/mlx/serialization/generate.py", + f"{comment}", + sep, + ] + if description: + lines.append(f"{comment}") + lines.append(f"{comment} {description}") + return lines + + +@dataclass +class FBSStruct: + name: str + fields: List[FBSField] + + +@dataclass +class FBSTable: + name: str + fields: List[FBSField] + + +@dataclass +class FBSUnion: 
+ name: str + types: List[str] + + +@dataclass +class FBSSchema: + namespace: str + enums: List[FBSEnum] + structs: List[FBSStruct] + tables: List[FBSTable] + unions: List[FBSUnion] + + def get_op_nodes(self) -> List[FBSTable]: + """Get all tables that are part of the OpNode union.""" + op_union = next((u for u in self.unions if u.name == "OpNode"), None) + if not op_union: + return [] + op_names = set(op_union.types) + return [t for t in self.tables if t.name in op_names] + + +def parse_fbs(fbs_path: Path) -> FBSSchema: + """Parse a FlatBuffer schema file.""" + with open(fbs_path) as f: + content = f.read() + + # Remove comments + content = re.sub(r"//.*$", "", content, flags=re.MULTILINE) + + namespace = "" + enums: List[FBSEnum] = [] + structs: List[FBSStruct] = [] + tables: List[FBSTable] = [] + unions: List[FBSUnion] = [] + + # Parse namespace + ns_match = re.search(r"namespace\s+(\w+)\s*;", content) + if ns_match: + namespace = ns_match.group(1) + + # Parse enums + for match in re.finditer(r"enum\s+(\w+)\s*:\s*(\w+)\s*\{([^}]+)\}", content): + enum_name = match.group(1) + base_type = match.group(2) + body = match.group(3) + values = [] + for val_match in re.finditer(r"(\w+)\s*(?:=\s*(\d+))?", body): + name = val_match.group(1) + explicit_val = int(val_match.group(2)) if val_match.group(2) else None + values.append((name, explicit_val)) + enums.append(FBSEnum(enum_name, base_type, values)) + + # Parse structs + for match in re.finditer(r"struct\s+(\w+)\s*\{([^}]+)\}", content): + struct_name = match.group(1) + body = match.group(2) + fields = _parse_fields(body) + structs.append(FBSStruct(struct_name, fields)) + + # Parse tables + for match in re.finditer(r"table\s+(\w+)\s*\{([^}]*)\}", content): + table_name = match.group(1) + body = match.group(2) + fields = _parse_fields(body) + tables.append(FBSTable(table_name, fields)) + + # Parse unions + for match in re.finditer(r"union\s+(\w+)\s*\{([^}]+)\}", content): + union_name = match.group(1) + body = 
match.group(2) + types = [t.strip() for t in body.split(",") if t.strip()] + unions.append(FBSUnion(union_name, types)) + + return FBSSchema(namespace, enums, structs, tables, unions) + + +def _parse_fields(body: str) -> List[FBSField]: + """Parse fields from a struct/table body.""" + fields = [] + for line in body.split(";"): + line = line.strip() + if not line: + continue + + # Parse: name: type (attributes) = default + match = re.match( + r"(\w+)\s*:\s*(\[?\w+\]?)\s*(?:\(([^)]*)\))?\s*(?:=\s*([^;]+))?", line + ) + if match: + name = match.group(1) + type_str = match.group(2) + attrs = match.group(3) or "" + default = match.group(4).strip() if match.group(4) else None + required = "required" in attrs + fields.append(FBSField(name, type_str, required, default)) + + return fields + + +# Config for compound type factory methods. +# Maps compound type name -> (primary_field_name, primary_python_type, description) +_COMPOUND_TYPE_CONFIG = { + "IntOrVid": ("literal", "int", "a literal integer"), + "FloatOrVid": ("literal", "float", "a literal float"), + "VidOrTid": ("tid", "Tid", "a tensor reference"), + "IntOrVidOrTid": ("literal", "int", "a literal integer"), +} + + +def _generate_compound_type(table: FBSTable) -> List[str]: # noqa: C901 + """Generate a Python dataclass for a compound type (IntOrVid, etc.) 
from schema.""" + name = table.name + config = _COMPOUND_TYPE_CONFIG.get(name) + if not config: + raise ValueError(f"No compound type config for '{name}'") + + primary_field, primary_py_type, primary_desc = config + + # Build the docstring from the schema structure + lines = [ + "@dataclass", + f"class {name}:", + ] + + # Docstring: describe the two alternatives + lines.append( + f' """Represents either {primary_desc} or a runtime Vid reference."""' + ) + + # Dataclass fields from the parsed schema + for fld in table.fields: + if fld.default == "false": + default = "False" + elif fld.default == "true": + default = "True" + elif fld.type_str in ("Tid", "Vid"): + default = "None" + elif fld.default is not None: + default = fld.default + elif fld.type_str in FBS_INTEGER_TYPES: + default = "0" + elif fld.type_str in FBS_FLOAT_TYPES: + default = "0.0" + else: + default = "None" + truly_required = default != "None" + py_type = _fbs_type_to_python(fld.type_str, truly_required) + lines.append(f" {fld.name}: {py_type} = {default}") + + # Check if this is a 3-way discriminator (IntOrVidOrTid uses 'kind') + has_kind = any(fld.name == "kind" for fld in table.fields) + has_tid = any(fld.name == "tid" for fld in table.fields) + + # Factory: from_primary (e.g. 
from_literal, from_tid) + lines.append("") + lines.append(" @classmethod") + lines.append( + f' def from_{primary_field}(cls, value: {primary_py_type}) -> "{name}":' + ) + lines.append(f' """Create a {name} from {primary_desc}."""') + if has_kind: + lines.append(f" return cls({primary_field}=value, kind=0)") + else: + lines.append(f" return cls({primary_field}=value, is_vid=False)") + + # Factory: from_vid + lines.append("") + lines.append(" @classmethod") + lines.append(f' def from_vid(cls, vid: Vid) -> "{name}":') + lines.append(f' """Create a {name} from a Vid reference."""') + if has_kind: + lines.append(" return cls(vid=vid, kind=1)") + else: + lines.append(" return cls(vid=vid, is_vid=True)") + + # Factory: from_tid (only for types with a tid field) + if has_tid: + lines.append("") + lines.append(" @classmethod") + lines.append(f' def from_tid(cls, tid: Tid) -> "{name}":') + lines.append(f' """Create a {name} from a Tid tensor reference."""') + if has_kind: + lines.append(" return cls(tid=tid, kind=2)") + else: + lines.append(" return cls(tid=tid, is_vid=False)") + + lines.append("") + return lines + + +def _generate_dataclass(table: FBSTable) -> List[str]: + """Generate a Python @dataclass from a parsed FBS table. + + Handles field ordering (required/defaulted before optional), skips + _is_set sentinel fields, and emits proper type annotations with defaults. 
+ """ + lines = ["@dataclass", f"class {table.name}:"] + fields = [f for f in table.fields if not f.name.endswith("_is_set")] + if not fields: + lines.append(" pass") + else: + required_fields = [f for f in fields if f.required or f.default is not None] + optional_fields = [f for f in fields if not f.required and f.default is None] + + for fld in required_fields: + py_type = _fbs_type_to_python(fld.type_str, True) + default = _fbs_default_to_python(fld.default, fld.type_str) + if default is not None: + lines.append(f" {fld.name}: {py_type} = {default}") + else: + lines.append(f" {fld.name}: {py_type}") + + for fld in optional_fields: + py_type = _fbs_type_to_python(fld.type_str, fld.required) + lines.append(f" {fld.name}: {py_type} = None") + + lines.extend(["", ""]) + return lines + + +def generate_python_schema(schema: FBSSchema) -> str: # noqa: C901 + """Generate mlx_graph_schema.py from parsed FBS.""" + lines = _file_header("#") + lines.extend( + [ + "", + "from __future__ import annotations", + "", + "from dataclasses import dataclass, field", + "from enum import IntEnum", + "from typing import List, Optional, Union", + "", + "", + *_section_header("#", "Enums"), + ] + ) + + # Generate enums + for enum in schema.enums: + lines.append(f"class {enum.name}(IntEnum):") + val = 0 + for name, explicit_val in enum.values: + if explicit_val is not None: + val = explicit_val + lines.append(f" {name} = {val}") + val += 1 + lines.append("") + lines.append("") + + lines.extend(_section_header("#", "Core types")) + + # Generate structs (Tid, Vid) + for struct in schema.structs: + lines.append("@dataclass") + lines.append(f"class {struct.name}:") + for fld in struct.fields: + py_type = _fbs_type_to_python(fld.type_str, fld.required) + default = _fbs_default_to_python(fld.default, fld.type_str) + if default: + lines.append(f" {fld.name}: {py_type} = {default}") + else: + lines.append(f" {fld.name}: {py_type}") + lines.append("") + lines.append("") + + # Generate compound 
types (IntOrVid, FloatOrVid, TidOrVid) from schema + for type_name in sorted(FBS_COMPOUND_TYPES): + table = next((t for t in schema.tables if t.name == type_name), None) + if table: + lines.extend(_generate_compound_type(table)) + lines.append("") + + # Generate ShapeDim, SlotVariant, NamedSlot, TensorMeta (but not Instruction/MLXGraph yet - they reference OpNode) + other_tables = ["ShapeDim", "SlotVariant", "NamedSlot", "TensorMeta"] + for table_name in other_tables: + table = next((t for t in schema.tables if t.name == table_name), None) + if table: + lines.extend(_generate_dataclass(table)) + + lines.extend(_section_header("#", "Op nodes")) + + # Generate op node dataclasses + op_nodes = schema.get_op_nodes() + for table in op_nodes: + lines.extend(_generate_dataclass(table)) + + # Generate OpNodeUnion type alias + op_names = [t.name for t in op_nodes] + lines.append("# Union of all op types") + lines.append("OpNodeUnion = Union[") + for name in op_names: + lines.append(f" {name},") + lines.append("]") + lines.append("") + + # Generate Instruction and MLXGraph (these reference OpNode so must come after) + lines.extend( + [ + *_section_header("#", "Container types (reference OpNodeUnion)"), + "@dataclass", + "class Instruction:", + " op: OpNodeUnion", + "", + "", + "@dataclass", + "class InstructionChain:", + " instructions: List[Instruction]", + "", + "", + "@dataclass", + "class MLXGraph:", + " instruction_chains: List[InstructionChain]", + " version: Optional[str] = None", + " num_constant_tensors: int = 0", + " num_input_tensors: int = 0", + " num_output_tensors: int = 0", + " num_mutable_buffer_tensors: int = 0", + " num_temp_tensors: int = 0", + " num_values: int = 0", + " main_chain_idx: int = 0", + " init_chain_idx: int = -1", + " input_map: Optional[List[SlotVariant]] = None", + " output_map: Optional[List[SlotVariant]] = None", + " mutable_buffer_map: Optional[List[SlotVariant]] = None", + " named_slots: Optional[List[NamedSlot]] = None", + " 
tensor_meta: Optional[List[TensorMeta]] = None", + "", + ] + ) + + return "\n".join(lines) + + +def _fbs_type_to_python(fbs_type: str, required: bool) -> str: + """Convert FBS type to Python type annotation. + + When required=False, the result is wrapped in Optional[…] for all types + (scalars, lists, and reference types alike). + """ + # Handle arrays + if fbs_type.startswith("[") and fbs_type.endswith("]"): + inner = fbs_type[1:-1] + inner_py = _fbs_type_to_python(inner, True) + base = f"List[{inner_py}]" + return base if required else f"Optional[{base}]" + + py_type = FBS_TO_PYTHON.get(fbs_type, fbs_type) + + if not required: + return f"Optional[{py_type}]" + + return py_type + + +def _fbs_default_to_python(default: Optional[str], fbs_type: str) -> Optional[str]: + """Convert FBS default value to Python.""" + if default is None: + return None + + if default == "false": + return "False" + if default == "true": + return "True" + if default == "null": + return "None" + + # Handle enum defaults like 'TensorSlot' + if fbs_type == "SlotType": + return f"SlotType.{default}" + + # Numeric defaults + return default + + +def generate_python_serializers(schema: FBSSchema) -> str: + """Generate _generated_serializers.py from parsed FBS.""" + op_nodes = schema.get_op_nodes() + op_union = next((u for u in schema.unions if u.name == "OpNode"), None) + + header = _file_header( + "#", + "This file contains auto-generated serializer methods for all op types.", + ) + + # Imports and module-level code + op_imports = ",\n".join(f" {t.name}" for t in op_nodes) + lines = [ + *header, + "", + "from __future__ import annotations", + "", + "from typing import List, Tuple, Dict", + "", + "import flatbuffers", + "", + ] + + # Generate op type names dict from union order + lines.append( + "# FlatBuffer union indices: 0 = NONE, then 1-indexed from union order" + ) + lines.append("MLX_OP_TYPE_NAMES = {") + lines.append(' 0: "NONE",') + if op_union: + for i, type_name in 
enumerate(op_union.types, start=1): + lines.append(f' {i}: "{type_name}",') + lines.append("}") + lines.append("") + + lines.extend( + [ + "from executorch.backends.mlx.serialization.mlx_graph_schema import (", + f"{op_imports},", + " IntOrVid,", + " FloatOrVid,", + " VidOrTid,", + " IntOrVidOrTid,", + " Tid,", + " Vid,", + ")", + "", + "", + "def _build_int_vector(builder: flatbuffers.Builder, vec: List[int]) -> int:", + ' """Build a vector of int32."""', + " builder.StartVector(4, len(vec), 4)", + " for v in reversed(vec):", + " builder.PrependInt32(v)", + " return builder.EndVector()", + "", + "", + "class GeneratedOpBuilders:", + ' """Mixin class with auto-generated op builder methods."""', + "", + " def _build_int_or_vid(self, builder: flatbuffers.Builder, iov: IntOrVid) -> int:", + ' """Build an IntOrVid table."""', + " from executorch.backends.mlx.serialization._generated.mlx_delegate import IntOrVid as FBIntOrVidModule", + " from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid", + "", + " FBIntOrVidModule.Start(builder)", + " FBIntOrVidModule.AddLiteral(builder, iov.literal)", + " FBIntOrVidModule.AddIsVid(builder, iov.is_vid)", + " if iov.vid is not None:", + " # Vid is an inline struct - must be added last for proper FlatBuffer layout", + " FBIntOrVidModule.AddVid(builder, CreateVid(builder, iov.vid.idx))", + " return FBIntOrVidModule.End(builder)", + "", + " def _build_float_or_vid(self, builder: flatbuffers.Builder, fov: FloatOrVid) -> int:", + ' """Build a FloatOrVid table."""', + " from executorch.backends.mlx.serialization._generated.mlx_delegate import FloatOrVid as FBFloatOrVidModule", + " from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid", + "", + " FBFloatOrVidModule.Start(builder)", + " FBFloatOrVidModule.AddLiteral(builder, fov.literal)", + " FBFloatOrVidModule.AddIsVid(builder, fov.is_vid)", + " if fov.vid is not None:", + " FBFloatOrVidModule.AddVid(builder, 
CreateVid(builder, fov.vid.idx))", + " return FBFloatOrVidModule.End(builder)", + "", + " def _build_vid_or_tid(self, builder: flatbuffers.Builder, vot: VidOrTid) -> int:", + ' """Build a TidOrVid table."""', + " from executorch.backends.mlx.serialization._generated.mlx_delegate import VidOrTid as FBVidOrTidModule", + " from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid", + " from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid", + "", + " FBVidOrTidModule.Start(builder)", + " FBVidOrTidModule.AddIsVid(builder, vot.is_vid)", + " if vot.tid is not None:", + " FBVidOrTidModule.AddTid(builder, CreateTid(builder, vot.tid.idx))", + " if vot.vid is not None:", + " FBVidOrTidModule.AddVid(builder, CreateVid(builder, vot.vid.idx))", + " return FBVidOrTidModule.End(builder)", + "", + " def _build_int_or_vid_or_tid(self, builder: flatbuffers.Builder, ivt: IntOrVidOrTid) -> int:", + ' """Build an IntOrVidOrTid table."""', + " from executorch.backends.mlx.serialization._generated.mlx_delegate import IntOrVidOrTid as FBIntOrVidOrTidModule", + " from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid", + " from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid", + "", + " FBIntOrVidOrTidModule.Start(builder)", + " FBIntOrVidOrTidModule.AddLiteral(builder, ivt.literal)", + " FBIntOrVidOrTidModule.AddKind(builder, ivt.kind)", + " if ivt.tid is not None:", + " FBIntOrVidOrTidModule.AddTid(builder, CreateTid(builder, ivt.tid.idx))", + " if ivt.vid is not None:", + " FBIntOrVidOrTidModule.AddVid(builder, CreateVid(builder, ivt.vid.idx))", + " return FBIntOrVidOrTidModule.End(builder)", + "", + " def _build_int_or_vid_vector(", + " self, builder: flatbuffers.Builder, vec: List[IntOrVid]", + " ) -> int:", + ' """Build a vector of IntOrVid tables."""', + " offsets = []", + " for iov in vec:", + " offsets.append(self._build_int_or_vid(builder, iov))", + 
" builder.StartVector(4, len(offsets), 4)", + " for off in reversed(offsets):", + " builder.PrependUOffsetTRelative(off)", + " return builder.EndVector()", + "", + " def _build_tid_vector(", + " self, builder: flatbuffers.Builder, vec: List[Tid]", + " ) -> int:", + ' """Build a vector of Tid structs."""', + " from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid", + "", + " # For vectors of structs, we need to build the vector differently", + " # Each Tid struct is 4 bytes (uint32), so we manually write them", + " builder.StartVector(4, len(vec), 4)", + " for tid in reversed(vec):", + " builder.Prep(4, 0) # Align for struct", + " builder.PrependUint32(tid.idx)", + " return builder.EndVector()", + "", + ] + ) + + # Generate builder methods for each op + for table in op_nodes: + lines.append(_generate_op_builder_method(table)) + + return "\n".join(lines) + + +def _generate_op_builder_method(table: FBSTable) -> str: + """Generate a _build_XxxNode method for the serializer class.""" + class_name = table.name + fb_module_name = f"FB{class_name}Module" + + lines = [ + f" def _build_{class_name}(", + f" self, builder: flatbuffers.Builder, op: {class_name}", + " ) -> Tuple[int, int]:", + f' """Auto-generated builder for {class_name}."""', + " # Import the MODULE (not class) to access builder functions like Start(), Add*(), End()", + f" from executorch.backends.mlx.serialization._generated.mlx_delegate import {class_name} as {fb_module_name}", + " from executorch.backends.mlx.serialization._generated.mlx_delegate import OpNode as FBOpNodeModule", + " from executorch.backends.mlx.serialization._generated.mlx_delegate.Tid import CreateTid", + " from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import CreateVid", + "", + ] + + # Pre-build any strings or vectors (must be done before Start) + prebuild_lines = [] + for fld in table.fields: + if fld.name.endswith("_is_set"): + continue + kind = _get_field_kind(fld, table) + pb 
= _emit_py_prebuild(kind, fld) + if pb: + prebuild_lines.extend(pb) + + if prebuild_lines: + lines.extend(prebuild_lines) + lines.append("") + + # Start the FlatBuffer table + lines.append(f" {fb_module_name}.Start(builder)") + + # Add each field + for fld in table.fields: + if fld.name.endswith("_is_set"): + continue + fb_field_name = _to_pascal_case(fld.name) + kind = _get_field_kind(fld, table) + add_lines = _emit_py_add(kind, fld, fb_module_name, fb_field_name) + if add_lines is None: + raise ValueError( + f"Unhandled field kind '{kind}' for field '{fld.name}' in table '{table.name}'. " + f"Add a handler in _emit_py_add()." + ) + lines.extend(add_lines) + + # End the FlatBuffer table and return offset + union type + lines.append(f" offset = {fb_module_name}.End(builder)") + lines.append(f" return offset, FBOpNodeModule.OpNode.{class_name}") + lines.append("") + + return "\n".join(lines) + + +# Prebuild emitters: return list of lines or None if no prebuild needed. +# These build offsets/vectors that must be created before FlatBuffer Start(). 
+ +_PY_PREBUILD_VECTOR = { + "list_int": "_build_int_vector(builder, op.{name})", + "list_int_or_vid": "self._build_int_or_vid_vector(builder, op.{name})", + "list_tid": "self._build_tid_vector(builder, op.{name})", +} + +_PY_PREBUILD_OFFSET = { + "str": "builder.CreateString(op.{name})", + "int_or_vid": "self._build_int_or_vid(builder, op.{name})", + "float_or_vid": "self._build_float_or_vid(builder, op.{name})", + "vid_or_tid": "self._build_vid_or_tid(builder, op.{name})", + "int_or_vid_or_tid": "self._build_int_or_vid_or_tid(builder, op.{name})", + "optional_str": "builder.CreateString(op.{name}) if op.{name} is not None else None", +} + + +def _emit_py_prebuild(kind: str, fld: FBSField) -> List[str]: + """Emit prebuild lines for a field kind, or empty list if none needed.""" + n = fld.name + if kind in _PY_PREBUILD_VECTOR: + expr = _PY_PREBUILD_VECTOR[kind].format(name=n) + if fld.required: + return [f" {n}_vec = {expr}"] + else: + return [f" {n}_vec = {expr} if op.{n} is not None else None"] + if kind in _PY_PREBUILD_OFFSET: + suffix = "_off" + expr = _PY_PREBUILD_OFFSET[kind].format(name=n) + return [f" {n}{suffix} = {expr}"] + return [] + + +# Maps struct kinds to their Python Create function name +_PY_STRUCT_CREATOR = {"tid": "CreateTid", "vid": "CreateVid"} + + +def _emit_py_add( + kind: str, fld: FBSField, mod: str, fb_name: str +) -> "List[str] | None": + """Emit Add lines for a field kind, or None if kind is unrecognized.""" + n = fld.name + add = f"{mod}.Add{fb_name}" + + # Required struct via inline Create call + if kind in _PY_STRUCT_CREATOR: + creator = _PY_STRUCT_CREATOR[kind] + return [f" {add}(builder, {creator}(builder, op.{n}.idx))"] + # Scalars (direct value) + if kind in ("int", "float", "bool"): + return [f" {add}(builder, op.{n})"] + # Pre-built offsets (string, compound types) + if kind in ("str", "int_or_vid", "float_or_vid", "vid_or_tid", "int_or_vid_or_tid"): + return [f" {add}(builder, {n}_off)"] + # Pre-built vectors (required vs 
optional) + if kind in ("list_int", "list_int_or_vid", "list_tid"): + if fld.required: + return [f" {add}(builder, {n}_vec)"] + return [ + f" if {n}_vec is not None:", + f" {add}(builder, {n}_vec)", + ] + # Optional struct via inline Create call + if kind in ("optional_tid", "optional_vid"): + creator = _PY_STRUCT_CREATOR[kind.removeprefix("optional_")] + return [ + f" if op.{n} is not None:", + f" {add}(builder, {creator}(builder, op.{n}.idx))", + ] + # Optional scalars + if kind in ("optional_float", "optional_int"): + return [ + f" if op.{n} is not None:", + f" {add}(builder, op.{n})", + ] + # Optional string offset + if kind == "optional_str": + return [ + f" if {n}_off is not None:", + f" {add}(builder, {n}_off)", + ] + return None + + +def _get_field_kind(fld: FBSField, table: FBSTable) -> str: # noqa: C901 + """Classify a field into a canonical kind string. + + This is the single source of truth for field classification, used by all + generators (Python builder, C++ loader, and inspector via _INSPECTOR_KIND_MAP). + """ + t = fld.type_str + + # Handle arrays + if t.startswith("[") and t.endswith("]"): + inner = t[1:-1] + if inner in FBS_INTEGER_TYPES: + return "list_int" + if inner == "IntOrVid": + return "list_int_or_vid" + if inner == "Tid": + return "list_tid" + raise ValueError( + f"Unrecognized array element type '{inner}' for field '{fld.name}' in table '{table.name}'. " + f"Add a handler in _get_field_kind()." 
+ ) + + # Handle basic types + if t == "Tid": + return "optional_tid" if not fld.required else "tid" + if t == "Vid": + return "optional_vid" if not fld.required else "vid" + if t == "IntOrVid": + return "int_or_vid" + if t == "FloatOrVid": + return "float_or_vid" + if t == "VidOrTid": + return "vid_or_tid" + if t == "IntOrVidOrTid": + return "int_or_vid_or_tid" + if t in FBS_INTEGER_TYPES: + if fld.default == "null": + return "optional_int" + return "int" + if t in FBS_FLOAT_TYPES: + # Check if this is optional (has = null default) + if fld.default == "null": + return "optional_float" + return "float" + if t == "bool": + return "bool" + if t == "string": + return "optional_str" if not fld.required else "str" + + raise ValueError( + f"Unrecognized field type '{t}' for field '{fld.name}' in table '{table.name}'. " + f"Add a handler in _get_field_kind()." + ) + + +def _to_pascal_case(name: str) -> str: + """Convert snake_case to PascalCase.""" + # Handle special cases + if name == "table_": + return "Table_" + parts = name.split("_") + return "".join(p.capitalize() for p in parts) + + +def generate_cpp_loader_h(schema: FBSSchema) -> str: + """Generate MLXLoader.h from parsed FBS using template.""" + op_nodes = schema.get_op_nodes() + + struct_lines = [] + for table in op_nodes: + struct_lines.append(f"struct {table.name} {{") + if not table.fields: + struct_lines.append("};") + else: + for fld in table.fields: + if fld.name.endswith("_is_set"): + continue + cpp_type = _fbs_type_to_cpp(fld.type_str, fld.required, table, fld) + struct_lines.append(f" {cpp_type} {fld.name};") + struct_lines.append("};") + struct_lines.append("") + + enum_lines = [] + for table in op_nodes: + enum_lines.append(f" {_table_name_to_opcode(table.name)},") + + name_lines = [] + for table in op_nodes: + op_code = _table_name_to_opcode(table.name) + name_lines.append(f" case OpCode::{op_code}:") + name_lines.append(f' return "{op_code}";') + + variant_lines = [] + for i, table in 
enumerate(op_nodes): + comma = "," if i < len(op_nodes) - 1 else "" + variant_lines.append(f" {table.name}{comma}") + + # Read template and fill placeholders + header = "\n".join(_file_header("//")) + "\n//\n" + tmpl = LOADER_H_TMPL.read_text() + result = tmpl.replace("{{OP_NODE_STRUCTS}}", "\n".join(struct_lines)) + result = result.replace("{{OPCODE_ENUM_VALUES}}", "\n".join(enum_lines)) + result = result.replace("{{OP_NAME_CASES}}", "\n".join(name_lines)) + result = result.replace("{{NODE_VARIANT_TYPES}}", "\n".join(variant_lines)) + return header + result + + +def _fbs_type_to_cpp( + fbs_type: str, + required: bool, + table: Optional["FBSTable"] = None, + fld: Optional["FBSField"] = None, +) -> str: + """Convert FBS type to C++ type. + + Args: + fbs_type: The FlatBuffer type string + required: Whether the field is required + table: Optional table context for type inference + fld: Optional field context for the current field + + Note: Most scalar types (float, int, etc.) are never optional in C++. + The Python serialization layer is responsible for ensuring scalar fields + have values (using defaults if user doesn't provide them). + Reference types (Tid, Vid) and DTypeId with '= null' default can be optional. 
+    """
+    # Handle arrays
+    if fbs_type.startswith("[") and fbs_type.endswith("]"):
+        inner = fbs_type[1:-1]
+        inner_cpp = _fbs_type_to_cpp(inner, True)
+        return f"std::vector<{inner_cpp}>"
+
+    cpp_type = FBS_TO_CPP.get(fbs_type, fbs_type)
+
+    # Handle optional types
+    if not required:
+        if fbs_type == "Tid":
+            return "std::optional<Tid>"
+        if fbs_type == "Vid":
+            return "std::optional<Vid>"
+        if fld is not None and fld.default == "null" and fbs_type in FBS_TO_CPP:
+            return f"std::optional<{cpp_type}>"
+
+    return cpp_type
+
+
+_OPCODE_OVERRIDES = {
+    "ARange": "ARANGE",
+    "AsType": "ASTYPE",
+    "Conv1D": "CONV1D",
+    "Conv2D": "CONV2D",
+    "Conv3D": "CONV3D",
+    "ConvTranspose1D": "CONV_TRANSPOSE1D",
+    "ConvTranspose2D": "CONV_TRANSPOSE2D",
+    "ConvTranspose3D": "CONV_TRANSPOSE3D",
+}
+
+
+def _table_name_to_opcode(name: str) -> str:
+    """Convert table name like 'LinearNode' to opcode like 'LINEAR'.
+
+    Uses regex-based camelCase → UPPER_SNAKE_CASE conversion with a small
+    override dict for names whose conventional opcode doesn't follow the
+    normal camelCase splitting rules (e.g. Conv1D → CONV1D, not CONV1_D).
+    """
+    name = name.removesuffix("Node")
+    if name in _OPCODE_OVERRIDES:
+        return _OPCODE_OVERRIDES[name]
+    s = re.sub(r"([a-z0-9])([A-Z])", r"\1_\2", name)
+    s = re.sub(r"([A-Z]+)([A-Z][a-z])", r"\1_\2", s)
+    return s.upper()
+
+
+def generate_cpp_loader_cpp(schema: FBSSchema) -> str:
+    """Generate MLXLoader.cpp from parsed FBS using template."""
+    op_nodes = schema.get_op_nodes()
+
+    case_lines = []
+    for table in op_nodes:
+        case_lines.extend(_generate_loader_case(table))
+
+    # Read template and fill placeholders
+    header = "\n".join(_file_header("//")) + "\n"
+    tmpl = LOADER_CPP_TMPL.read_text()
+    result = tmpl.replace("{{LOAD_INSTRUCTION_CASES}}", "\n".join(case_lines))
+    return header + result
+
+
+def _generate_loader_case(table: FBSTable) -> List[str]:
+    """Generate a switch case for loading an op node."""
+    class_name = table.name
+    op_code = _table_name_to_opcode(class_name)
+
+    lines = [
+        f"    case mlx_delegate::OpNode_{class_name}: {{",
+    ]
+
+    if not table.fields:
+        # NoopNode case
+        lines.extend(
+            [
+                f"      instr.op = OpCode::{op_code};",
+                f"      instr.node = {class_name}{{}};",
+                "      break;",
+                "    }",
+                "",
+            ]
+        )
+        return lines
+
+    lines.append(f"      auto fb = fb_instr->op_as_{class_name}();")
+    # NOTE: these three lines must be f-strings — without the `f` prefix the
+    # generated C++ would contain literal `{{`/`}}` braces and the raw text
+    # `{class_name}` in the error message instead of the actual table name.
+    lines.append(f"      if (!fb) {{")
+    lines.append(
+        f'        throw std::runtime_error("FlatBuffer op_type/payload mismatch for {class_name}");'
+    )
+    lines.append(f"      }}")
+    lines.append(f"      {class_name} node;")
+
+    for fld in table.fields:
+        if fld.name.endswith("_is_set"):
+            continue
+
+        fb_field_name = fld.name
+        kind = _get_field_kind(fld, table)
+        load_lines = _emit_cpp_load(kind, fld.name, fb_field_name)
+        if load_lines is None:
+            raise ValueError(
+                f"Unhandled field kind '{kind}' for field '{fld.name}' in table '{table.name}'. "
+                f"Add a handler in _emit_cpp_load()."
+ ) + lines.extend(load_lines) + + lines.extend( + [ + f" instr.op = OpCode::{op_code};", + " instr.node = std::move(node);", + " break;", + " }", + "", + ] + ) + + return lines + + +# Maps kinds to their C++ converter function name +_CPP_CONVERTER = { + "tid": "convert_tid", + "vid": "convert_vid", + "int_or_vid": "convert_int_or_vid", + "float_or_vid": "convert_float_or_vid", + "vid_or_tid": "convert_vid_or_tid", + "int_or_vid_or_tid": "convert_int_or_vid_or_tid", +} + + +def _emit_cpp_load(kind: str, name: str, fb_name: str) -> "List[str] | None": + """Emit C++ load lines for a field kind, or None if kind is unrecognized.""" + # Required struct / compound via converter + if kind in _CPP_CONVERTER: + conv = _CPP_CONVERTER[kind] + return [f" node.{name} = {conv}(fb->{fb_name}());"] + # Scalars (direct value) + if kind in ("int", "float", "bool"): + return [f" node.{name} = fb->{fb_name}();"] + # Required string + if kind == "str": + return [f' node.{name} = fb->{fb_name}() ? fb->{fb_name}()->str() : "";'] + # Optional struct / compound via guarded converter + base_kind = kind.removeprefix("optional_") + if kind.startswith("optional_") and base_kind in _CPP_CONVERTER: + conv = _CPP_CONVERTER[base_kind] + return [ + f" if (fb->{fb_name}()) {{", + f" node.{name} = {conv}(fb->{fb_name}());", + " }", + ] + # Optional scalar (FlatBuffers returns flatbuffers::Optional) + if kind in ("optional_float", "optional_int"): + return [ + f" auto {fb_name}_opt = fb->{fb_name}();", + f" if ({fb_name}_opt.has_value()) {{", + f" node.{name} = {fb_name}_opt.value();", + " }", + ] + # Optional string + if kind == "optional_str": + return [ + f" if (fb->{fb_name}()) {{", + f" node.{name} = fb->{fb_name}()->str();", + " }", + ] + # Integer/bool vector via to_vector + if kind == "list_int": + return [f" node.{name} = to_vector(fb->{fb_name}());"] + # Int-or-vid vector (indexed access) + if kind == "list_int_or_vid": + return [ + f" if (fb->{fb_name}()) {{", + f" for (size_t i = 0; i < 
fb->{fb_name}()->size(); ++i) {{",
+            f"        node.{name}.push_back(convert_int_or_vid(fb->{fb_name}()->Get(static_cast<flatbuffers::uoffset_t>(i))));",
+            "      }",
+            "    }",
+        ]
+    # Tid vector (range-based iteration)
+    if kind == "list_tid":
+        return [
+            f"    if (fb->{fb_name}()) {{",
+            f"      for (auto fb_tid : *fb->{fb_name}()) {{",
+            f"        node.{name}.push_back(convert_tid(fb_tid));",
+            "      }",
+            "    }",
+        ]
+    return None
+
+
+def run_flatc(flatc_path: str = "flatc") -> bool:
+    """Run flatc to generate Python and C++ bindings."""
+    print(f"Running flatc on {SCHEMA_FBS}...")
+
+    # Create output directories
+    GENERATED_DIR.mkdir(parents=True, exist_ok=True)
+
+    success = True
+
+    # Generate Python bindings
+    cmd_py = [
+        flatc_path,
+        "--python",
+        "-o",
+        str(GENERATED_DIR),
+        str(SCHEMA_FBS),
+    ]
+    try:
+        result = subprocess.run(cmd_py, capture_output=True, text=True)
+        if result.returncode != 0:
+            print(f"flatc (Python) failed: {result.stderr}")
+            success = False
+        else:
+            print(f"Generated FlatBuffer Python bindings in {GENERATED_DIR}")
+    except FileNotFoundError:
+        print(f"flatc not found at '{flatc_path}'. Skipping FlatBuffer generation.")
+        success = False
+
+    # Generate C++ bindings
+    cmd_cpp = [
+        flatc_path,
+        "--cpp",
+        "-o",
+        str(RUNTIME_DIR),
+        str(SCHEMA_FBS),
+    ]
+    try:
+        result = subprocess.run(cmd_cpp, capture_output=True, text=True)
+        if result.returncode != 0:
+            print(f"flatc (C++) failed: {result.stderr}")
+            success = False
+        else:
+            print(f"Generated FlatBuffer C++ bindings in {RUNTIME_DIR}")
+    except FileNotFoundError:
+        success = False
+
+    return success
+
+
+_FLATC_IMPORT_PREFIX = "executorch.backends.mlx.serialization._generated."
+
+
+def _fixup_flatc_imports() -> None:
+    """Rewrite bare ``from mlx_delegate.X`` imports in generated FlatBuffer code.
+
+    ``flatc --python`` emits lazy imports like ``from mlx_delegate.Tid import Tid``
+    inside accessor methods. These only resolve if the ``_generated/`` directory is
+    on ``sys.path``.
We rewrite them to fully-qualified imports so no ``sys.path`` + manipulation is needed at runtime. + """ + fb_dir = GENERATED_DIR / "mlx_delegate" + if not fb_dir.exists(): + return + + count = 0 + for py_file in fb_dir.glob("*.py"): + content = py_file.read_text() + if "from mlx_delegate." not in content: + continue + new_content = content.replace( + "from mlx_delegate.", f"from {_FLATC_IMPORT_PREFIX}mlx_delegate." + ) + if new_content != content: + py_file.write_text(new_content) + count += 1 + + if count: + print(f"Fixed bare imports in {count} generated FlatBuffer file(s)") + + +# Mapping from fine-grained field kinds (from _get_field_kind) to inspector +# display kinds. The inspector uses coarser categories: optional/required +# distinctions collapse, and int/float/bool all map to "scalar". +_INSPECTOR_KIND_MAP = { + "tid": "tid", + "optional_tid": "tid", + "vid": "vid", + "optional_vid": "vid", + "int_or_vid": "int_or_vid", + "float_or_vid": "float_or_vid", + "vid_or_tid": "vid_or_tid", + "int_or_vid_or_tid": "int_or_vid_or_tid", + "list_int": "int_list", + "list_int_or_vid": "int_or_vid_list", + "list_tid": "tid_list", + "int": "scalar", + "optional_int": "scalar", + "float": "scalar", + "optional_float": "scalar", + "bool": "scalar", + "str": "string", + "optional_str": "string", +} + + +def generate_inspector(schema: "Schema") -> str: # noqa: F821 + """Generate the inspector field mappings file.""" + lines = _file_header("#") + lines.extend( + [ + "", + '"""', + "Auto-generated inspector field mappings for MLX delegate.", + "", + "This module provides field metadata for each op node type, enabling", + "the pte_inspector to parse FlatBuffer op nodes without manually", + "maintaining field mappings.", + '"""', + "", + "from __future__ import annotations", + "", + "from typing import Dict, List, Tuple", + "", + "", + "# Field kinds and their extractors", + "# Each field is a tuple of (display_name, accessor_name, kind)", + "# where kind is one of: 'tid', 
'vid', 'int_or_vid', 'float_or_vid',", + "# 'int_list', 'int_or_vid_list', 'tid_list', 'scalar', 'string'", + "", + "FieldSpec = Tuple[str, str, str] # (display_name, accessor_name, kind)", + "", + "", + "# Mapping from op node name to list of field specs", + "OP_NODE_FIELDS: Dict[str, List[FieldSpec]] = {", + ] + ) + + op_nodes = schema.get_op_nodes() + + for table in op_nodes: + lines.append(f' "{table.name}": [') + for fld in table.fields: + # Skip fields ending in _is_set (legacy pattern) + if fld.name.endswith("_is_set"): + continue + + kind = _get_field_kind(fld, table) + inspector_kind = _INSPECTOR_KIND_MAP.get(kind) + if inspector_kind is None: + raise ValueError( + f"No inspector mapping for field kind '{kind}' " + f"(field '{fld.name}' in table '{table.name}'). " + f"Add a mapping in _INSPECTOR_KIND_MAP." + ) + accessor = _to_pascal_case(fld.name) + lines.append(f' ("{fld.name}", "{accessor}", "{inspector_kind}"),') + lines.append(" ],") + + lines.append("}") + lines.append("") + lines.append("") + + # Add the list of op node names for import generation + lines.append("# List of all op node names (for dynamic imports)") + lines.append("OP_NODE_NAMES: List[str] = [") + for table in op_nodes: + lines.append(f' "{table.name}",') + lines.append("]") + lines.append("") + + return "\n".join(lines) + + +def main(): # noqa: C901 + parser = argparse.ArgumentParser( + description="Generate MLX delegate code from schema.fbs" + ) + parser.add_argument( + "--flatc", + default="flatc", + help="Path to flatc compiler", + ) + parser.add_argument( + "--skip-flatc", + action="store_true", + help="Skip running flatc (use existing generated files)", + ) + parser.add_argument( + "--dry-run", + action="store_true", + help="Print what would be generated without writing files", + ) + args = parser.parse_args() + + print(f"Parsing {SCHEMA_FBS}...") + schema = parse_fbs(SCHEMA_FBS) + print( + f" Found {len(schema.enums)} enums, {len(schema.structs)} structs, " + 
f"{len(schema.tables)} tables, {len(schema.unions)} unions" + ) + print(f" Op nodes: {len(schema.get_op_nodes())}") + + # Run flatc + if not args.skip_flatc: + run_flatc(args.flatc) + _fixup_flatc_imports() + + # Generate all code files + generators = [ + (generate_python_schema, GENERATED_SCHEMA_PY, "mlx_graph_schema.py"), + ( + generate_python_serializers, + GENERATED_SERIALIZERS, + "_generated_serializers.py", + ), + (generate_cpp_loader_h, LOADER_H, "MLXLoader.h"), + (generate_cpp_loader_cpp, LOADER_CPP, "MLXLoader.cpp"), + (generate_inspector, GENERATED_INSPECTOR, "_generated_inspector.py"), + ] + for gen_fn, output_path, label in generators: + print(f"Generating {output_path}...") + content = gen_fn(schema) + if args.dry_run: + print(f"--- {label} (first 50 lines) ---") + print("\n".join(content.split("\n")[:50])) + else: + with open(output_path, "w") as f: + f.write(content) + + # Create __init__.py for _generated package that re-exports from mlx_delegate + init_file = GENERATED_DIR / "__init__.py" + if not args.dry_run: + init_file.parent.mkdir(parents=True, exist_ok=True) + + # Get all the exports from mlx_delegate (tables, enums, structs, and unions) + exports = [] + for table in schema.tables: + exports.append(table.name) + for enum in schema.enums: + exports.append(enum.name) + for struct in schema.structs: + exports.append(struct.name) + for union in schema.unions: + exports.append(union.name) + + # Create __init__.py with re-exports + init_content = """# Auto-generated FlatBuffer bindings +# Re-exports from mlx_delegate namespace for convenient imports + +""" + # Add imports from mlx_delegate + for export in sorted(exports): + init_content += f"from executorch.backends.mlx.serialization._generated.mlx_delegate.{export} import {export}\n" + + init_content += f"\n__all__ = {sorted(exports)!r}\n" + init_file.write_text(init_content) + + print("Done!") + print("") + print("Generated files:") + print(f" - {GENERATED_SCHEMA_PY}") + print(f" - 
{GENERATED_SERIALIZERS}") + print(f" - {GENERATED_INSPECTOR}") + print(f" - {LOADER_H}") + print(f" - {LOADER_CPP}") + if not args.skip_flatc: + print(f" - {GENERATED_DIR}/ (FlatBuffer Python bindings)") + print(f" - {RUNTIME_DIR}/schema_generated.h (FlatBuffer C++ bindings)") + + +if __name__ == "__main__": + main() diff --git a/backends/mlx/serialization/mlx_graph_serialize.py b/backends/mlx/serialization/mlx_graph_serialize.py new file mode 100644 index 00000000000..db5acc9048f --- /dev/null +++ b/backends/mlx/serialization/mlx_graph_serialize.py @@ -0,0 +1,416 @@ +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# + +""" +Serialization utilities for MLX delegate. + +Converts MLXGraph dataclasses to FlatBuffer binary format. + +Constants are NOT embedded in the delegate payload - they are provided by +ExecuTorch via named_data_map at runtime. 
+ +Layout: + [Header: 24 bytes] + - Padding: 4 bytes (zeros) + - Magic: 4 bytes ("MLX0") + - Reserved: 16 bytes (zeros, for future use) + [FlatBuffer payload] +""" + +from __future__ import annotations + +import struct +from typing import Any, List, Tuple + +import flatbuffers + +# Import auto-generated serializers +from executorch.backends.mlx.serialization._generated_serializers import ( + GeneratedOpBuilders, +) +from executorch.backends.mlx.serialization.mlx_graph_schema import ( # noqa: F401 + FloatOrVid, + Instruction, + IntOrVid, + MLXGraph, + NamedSlot, + OpNodeUnion, + SlotType, + SlotVariant, + TensorMeta, + Tid, + Vid, +) +from executorch.exir._serialize._program import Cord + +HEADER_LENGTH = 24 +MAGIC = b"MLX0" +ALIGNMENT = 16 + + +def _padding_required(offset: int, alignment: int) -> int: + remainder = offset % alignment + return (alignment - remainder) % alignment + + +def _build_tid(builder: flatbuffers.Builder, tid: Tid) -> int: + return tid.idx + + +def _build_vid(builder: flatbuffers.Builder, vid: Vid) -> int: + return vid.idx + + +def _build_int_or_vid(builder: flatbuffers.Builder, iov: IntOrVid) -> int: + # Import the MODULE (not class) to access builder functions + from executorch.backends.mlx.serialization._generated.mlx_delegate import ( + IntOrVid as FBIntOrVidModule, + ) + from executorch.backends.mlx.serialization._generated.mlx_delegate.Vid import ( + CreateVid, + ) + + FBIntOrVidModule.Start(builder) + FBIntOrVidModule.AddLiteral(builder, iov.literal) + FBIntOrVidModule.AddIsVid(builder, iov.is_vid) + if iov.vid is not None: + # Vid is an inline struct - must be added last for proper FlatBuffer layout + FBIntOrVidModule.AddVid(builder, CreateVid(builder, iov.vid.idx)) + return FBIntOrVidModule.End(builder) + + +def _build_string(builder: flatbuffers.Builder, s: str) -> int: + return builder.CreateString(s) + + +def _build_int_vector(builder: flatbuffers.Builder, vec: List[int]) -> int: + # FlatBuffers vectors must be created before the 
table that contains them
+    builder.StartVector(4, len(vec), 4)  # elem_size=4, num_elems, alignment
+    for v in reversed(vec):
+        builder.PrependInt32(v)
+    return builder.EndVector()
+
+
+class MLXGraphSerializer(GeneratedOpBuilders):
+    """
+    Serializes MLXGraph to bytes with separate constant data segment.
+
+    Inherits auto-generated op builders from GeneratedOpBuilders mixin.
+    """
+
+    def __init__(self, graph: MLXGraph, constant_data: bytes = b""):
+        self.graph = graph
+        self.constant_data = constant_data
+
+    def serialize(self) -> bytes:
+        """
+        Serialize the graph to bytes.
+
+        Returns:
+            Complete serialized payload with header, flatbuffer, and data segment.
+        """
+        # Build FlatBuffer
+        fb_bytes = self._build_flatbuffer()
+
+        # Calculate offsets
+        data_segment_offset = HEADER_LENGTH + len(fb_bytes)
+        padding_len = _padding_required(data_segment_offset, ALIGNMENT)
+        data_segment_offset += padding_len
+        data_segment_size = len(self.constant_data)
+
+        # Build header: 4 bytes padding + 4 bytes magic + 2x uint64 (little-endian)
+        # for data segment offset/size = 24 bytes total (HEADER_LENGTH).
+        header = (
+            b"\x00\x00\x00\x00"  # 4 bytes padding
+            + MAGIC  # 4 bytes magic
+            + struct.pack("<QQ", data_segment_offset, data_segment_size)
+        )
+
+        # Assemble payload: header + flatbuffer + alignment padding + constants.
+        result = Cord()
+        result.append(header)
+        result.append(fb_bytes)
+        if padding_len > 0:
+            result.append(b"\x00" * padding_len)
+        result.append(self.constant_data)
+
+        return bytes(result)
+
+    def _build_flatbuffer(self) -> bytes:
+        builder = flatbuffers.Builder(4096)
+
+        # Build all components bottom-up (FlatBuffers requirement)
+
+        # 1. Build instruction chains
+        chain_offsets = []
+        for chain in self.graph.instruction_chains:
+            instr_offsets = []
+            for instr in chain.instructions:
+                instr_offsets.append(self._build_instruction(builder, instr))
+            instr_vec = self._build_offset_vector(builder, instr_offsets)
+
+            from executorch.backends.mlx.serialization._generated.mlx_delegate import (
+                InstructionChain as FBInstructionChainModule,
+            )
+
+            FBInstructionChainModule.Start(builder)
+            FBInstructionChainModule.AddInstructions(builder, instr_vec)
+            chain_offsets.append(FBInstructionChainModule.End(builder))
+
+        chains_vec = self._build_offset_vector(builder, chain_offsets)
+
+        # 2.
Build I/O maps + input_map_vec = self._build_slot_variant_vector(builder, self.graph.input_map) + output_map_vec = self._build_slot_variant_vector(builder, self.graph.output_map) + mutable_buffer_map_vec = self._build_slot_variant_vector( + builder, self.graph.mutable_buffer_map + ) + + # 3. Build named slots + named_slots_offsets = [] + for ns in self.graph.named_slots: + named_slots_offsets.append(self._build_named_slot(builder, ns)) + named_slots_vec = self._build_offset_vector(builder, named_slots_offsets) + + # 4. Build tensor metadata + tensor_meta_offsets = [] + for tm in self.graph.tensor_meta: + if tm is not None: + tensor_meta_offsets.append(self._build_tensor_meta(builder, tm)) + else: + tensor_meta_offsets.append(0) # null + tensor_meta_vec = self._build_offset_vector(builder, tensor_meta_offsets) + + # 5. Build version string (must be created before the table that uses it) + version_off = builder.CreateString(self.graph.version) + + # 6. Build the root MLXGraph table + from executorch.backends.mlx.serialization._generated.mlx_delegate import ( + MLXGraph as FBMLXGraphModule, + ) + + FBMLXGraphModule.Start(builder) + FBMLXGraphModule.AddVersion(builder, version_off) + FBMLXGraphModule.AddNumConstantTensors(builder, self.graph.num_constant_tensors) + FBMLXGraphModule.AddNumInputTensors(builder, self.graph.num_input_tensors) + FBMLXGraphModule.AddNumOutputTensors(builder, self.graph.num_output_tensors) + FBMLXGraphModule.AddNumMutableBufferTensors( + builder, self.graph.num_mutable_buffer_tensors + ) + FBMLXGraphModule.AddNumTempTensors(builder, self.graph.num_temp_tensors) + FBMLXGraphModule.AddNumValues(builder, self.graph.num_values) + FBMLXGraphModule.AddInstructionChains(builder, chains_vec) + FBMLXGraphModule.AddMainChainIdx(builder, self.graph.main_chain_idx) + FBMLXGraphModule.AddInitChainIdx(builder, self.graph.init_chain_idx) + FBMLXGraphModule.AddInputMap(builder, input_map_vec) + FBMLXGraphModule.AddOutputMap(builder, output_map_vec) + 
FBMLXGraphModule.AddMutableBufferMap(builder, mutable_buffer_map_vec) + FBMLXGraphModule.AddNamedSlots(builder, named_slots_vec) + FBMLXGraphModule.AddTensorMeta(builder, tensor_meta_vec) + root = FBMLXGraphModule.End(builder) + + builder.Finish(root) + return bytes(builder.Output()) + + def _build_instruction( + self, builder: flatbuffers.Builder, instr: Instruction + ) -> int: + op_offset, op_type = self._build_op_node(builder, instr.op) + + from executorch.backends.mlx.serialization._generated.mlx_delegate import ( + Instruction as FBInstructionModule, + ) + + FBInstructionModule.Start(builder) + FBInstructionModule.AddOpType(builder, op_type) + FBInstructionModule.AddOp(builder, op_offset) + return FBInstructionModule.End(builder) + + def _build_op_node( + self, builder: flatbuffers.Builder, op: OpNodeUnion + ) -> Tuple[int, int]: + """ + Build an op node and return (offset, union_type). + + This is the main dispatch for all op types. + """ + # Map Python class to FlatBuffer union type and builder + # This would ideally be auto-generated + + op_type = type(op).__name__ + builder_method = getattr(self, f"_build_{op_type}", None) + + if builder_method is None: + raise NotImplementedError(f"No builder for op type: {op_type}") + + return builder_method(builder, op) + + def _build_offset_vector( + self, builder: flatbuffers.Builder, offsets: List[int] + ) -> int: + builder.StartVector(4, len(offsets), 4) + for off in reversed(offsets): + builder.PrependUOffsetTRelative(off) + return builder.EndVector() + + def _build_slot_variant_vector( + self, builder: flatbuffers.Builder, slots: List[SlotVariant] + ) -> int: + offsets = [] + for slot in slots: + offsets.append(self._build_slot_variant(builder, slot)) + return self._build_offset_vector(builder, offsets) + + def _build_slot_variant( + self, builder: flatbuffers.Builder, slot: SlotVariant + ) -> int: + from executorch.backends.mlx.serialization._generated.mlx_delegate import ( + SlotVariant as FBSlotVariantModule, + 
) + + FBSlotVariantModule.Start(builder) + FBSlotVariantModule.AddIdx(builder, slot.idx) + FBSlotVariantModule.AddSlotType(builder, slot.slot_type) + return FBSlotVariantModule.End(builder) + + def _build_named_slot(self, builder: flatbuffers.Builder, ns: NamedSlot) -> int: + name_off = builder.CreateString(ns.name) + slot_off = self._build_slot_variant(builder, ns.slot) + + from executorch.backends.mlx.serialization._generated.mlx_delegate import ( + NamedSlot as FBNamedSlotModule, + ) + + FBNamedSlotModule.Start(builder) + FBNamedSlotModule.AddName(builder, name_off) + FBNamedSlotModule.AddSlot(builder, slot_off) + return FBNamedSlotModule.End(builder) + + def _build_tensor_meta(self, builder: flatbuffers.Builder, tm: TensorMeta) -> int: + # Shape is a vector of ShapeDim tables + shape_offsets = [] + for dim in tm.shape: + shape_offsets.append(self._build_shape_dim(builder, dim)) + shape_vec = self._build_offset_vector(builder, shape_offsets) + + # Build dim_order vector (uint8) + dim_order_vec = 0 + if tm.dim_order: + builder.StartVector(1, len(tm.dim_order), 1) # elem_size=1 for uint8 + for d in reversed(tm.dim_order): + builder.PrependUint8(d) + dim_order_vec = builder.EndVector() + + from executorch.backends.mlx.serialization._generated.mlx_delegate import ( + TensorMeta as FBTensorMetaModule, + ) + + FBTensorMetaModule.Start(builder) + FBTensorMetaModule.AddShape(builder, shape_vec) + if tm.scalar_type is not None: + FBTensorMetaModule.AddScalarType(builder, tm.scalar_type) + if dim_order_vec: + FBTensorMetaModule.AddDimOrder(builder, dim_order_vec) + return FBTensorMetaModule.End(builder) + + def _build_shape_dim(self, builder: flatbuffers.Builder, dim) -> int: + from executorch.backends.mlx.serialization._generated.mlx_delegate import ( + ShapeDim as FBShapeDimModule, + ) + + FBShapeDimModule.Start(builder) + FBShapeDimModule.AddValue(builder, dim.value) + FBShapeDimModule.AddMinValue(builder, dim.min_value) + FBShapeDimModule.AddMaxValue(builder, 
dim.max_value)
+        return FBShapeDimModule.End(builder)
+
+
+def serialize_mlx_graph(graph: MLXGraph, constant_data: bytes = b"") -> bytes:
+    """
+    Serialize an MLXGraph to bytes.
+
+    Args:
+        graph: The MLXGraph to serialize.
+        constant_data: Raw bytes for constant tensors.
+
+    Returns:
+        Serialized bytes with header, flatbuffer, and data segment.
+    """
+    serializer = MLXGraphSerializer(graph, constant_data)
+    return serializer.serialize()
+
+
+def parse_header(data: bytes) -> Tuple[int, int, int, int]:
+    """
+    Parse the MLX delegate header.
+
+    Returns:
+        (flatbuffer_offset, flatbuffer_size, data_segment_offset, data_segment_size)
+    """
+    if len(data) < HEADER_LENGTH:
+        raise ValueError(f"Data too short: {len(data)} < {HEADER_LENGTH}")
+
+    magic = data[4:8]
+    if magic != MAGIC:
+        raise ValueError(f"Invalid magic: {magic!r} (expected {MAGIC!r})")
+
+    # NOTE(review): the following was reconstructed from a garbled patch —
+    # verify the unpack offsets against MLXGraphSerializer.serialize().
+    data_segment_offset, data_segment_size = struct.unpack("<QQ", data[8:24])
+
+    flatbuffer_offset = HEADER_LENGTH
+    # FlatBuffer occupies the span between the header and the data segment
+    # (this includes any alignment padding, which FlatBuffers tolerates).
+    flatbuffer_size = data_segment_offset - HEADER_LENGTH
+
+    return (flatbuffer_offset, flatbuffer_size, data_segment_offset, data_segment_size)
+
+
+def deserialize_to_json(data: bytes) -> dict:
+    """
+    Deserialize MLX delegate payload to a JSON-compatible dict.
+
+    Useful for debugging - extracts the FlatBuffer and dumps it as JSON.
+ """ + fb_off, fb_size, ds_off, ds_size = parse_header(data) + + # Extract FlatBuffer portion + fb_data = data[fb_off : fb_off + fb_size] + + # Parse using generated FlatBuffer code + from executorch.backends.mlx.serialization._generated.mlx_delegate.MLXGraph import ( + MLXGraph as FBMLXGraphClass, + ) + + graph = FBMLXGraphClass.GetRootAs(fb_data, 0) + + # Convert to dict (recursive) + result = _fb_to_dict(graph) + result["_constant_segment_size"] = ds_size + + return result + + +def _fb_to_dict(obj: Any) -> Any: + if obj is None: + return None + if isinstance(obj, (int, float, str, bool, bytes)): + return obj + if isinstance(obj, (list, tuple)): + return [_fb_to_dict(item) for item in obj] + + # FlatBuffer object - extract fields + result = {} + for attr in dir(obj): + if attr.startswith("_") or attr[0].islower(): + continue + try: + value = getattr(obj, attr)() + result[attr] = _fb_to_dict(value) + except (TypeError, AttributeError): + pass + + return result diff --git a/backends/mlx/serialization/schema.fbs b/backends/mlx/serialization/schema.fbs new file mode 100644 index 00000000000..945186ebef8 --- /dev/null +++ b/backends/mlx/serialization/schema.fbs @@ -0,0 +1,192 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. 
+// +// FlatBuffer schema for MLX delegate - THIS IS THE SOURCE OF TRUTH +// Defines the IR that gets serialized into the .pte file and executed by MLX runtime +// +// After editing this file, regenerate dependent files with: +// python backends/mlx/serialization/generate.py +// +// BACKWARD COMPATIBILITY RULES: +// - New fields in tables: APPEND ONLY (add at the end, with a default value) +// - New union members: APPEND ONLY (add at the end of the union) +// - New tables: Safe to add freely +// - New enum values: APPEND ONLY +// - NEVER remove, reorder, or change the type of existing fields/members + +namespace mlx_delegate; + +// ============================================================================= +// Core types +// ============================================================================= + +// We use ET's ScalarType (int8) directly. +// See runtime/core/portable_type/scalar_type.h for ScalarType values. + +// Tensor slot identifier - indexes into tensors array +struct Tid { + idx: uint32; +} + +// Value slot identifier - indexes into values array +// Values are stored as variant at runtime +struct Vid { + idx: uint32; +} + +// NOTE: These compound types use tables with manual discriminators rather than +// FlatBuffers unions because IntOrVid is used in vectors ([IntOrVid]), and +// FlatBuffers does not support vectors of unions. 
+ +// For fields that can be either a literal int or a runtime Vid +table IntOrVid { + literal: int64; // widened to int64 for future-proofing + vid: Vid; + is_vid: bool = false; +} + +// For fields that can be either a literal float or a runtime Vid +table FloatOrVid { + literal: double; // widened to double for future-proofing + vid: Vid; + is_vid: bool = false; +} + +// For fields that can be either a tensor (Tid) or a scalar value (Vid) +table VidOrTid { + vid: Vid; + tid: Tid; + is_vid: bool = false; // false = use tid, true = use vid +} + +// For fields that can be a literal int, a runtime Vid, or a tensor (Tid) +table IntOrVidOrTid { + literal: int64; + vid: Vid; + tid: Tid; + kind: uint8 = 0; // 0 = literal int, 1 = vid, 2 = tid +} + +// ============================================================================= +// Op nodes - mirrors ops_schema.py dataclasses +// ============================================================================= + +table NoopNode {} + +table AddmmNode { + mat1: Tid (required); // First matrix + mat2: Tid (required); // Second matrix + out: Tid (required); + bias: Tid; // optional - added to result + alpha: float = 1.0; // Scalar multiplier for mat1 @ mat2 + beta: float = 1.0; // Scalar multiplier for bias +} + +// ============================================================================= +// Union of all op types +// ============================================================================= + +// BC: APPEND ONLY — new op nodes must be added at the end of this union. +// Reordering or removing members changes numeric type IDs and breaks existing .pte files. 
+union OpNode { + NoopNode, + AddmmNode + // BC: Add new op nodes here (append only) +} + +// ============================================================================= +// Instruction wrapper +// ============================================================================= + +table Instruction { + op: OpNode (required); +} + +// ============================================================================= +// Instruction chain (basic block of sequential instructions) +// ============================================================================= + +table InstructionChain { + instructions: [Instruction] (required); + // BC: New fields must be appended here with a default value +} + +// ============================================================================= +// Tensor metadata +// ============================================================================= + +// Shape dimension: static value, or dynamic with optional bounds +table ShapeDim { + value: int32 = -1; // Static dim (>= 0), or -1 for dynamic + min_value: int32 = 0; // Lower bound (only when value == -1) + max_value: int32 = -1; // Upper bound (-1 = unbounded, only when value == -1) +} + +table TensorMeta { + shape: [ShapeDim] (required); // Dimension info with static/dynamic distinction + scalar_type: int8; // ET ScalarType value (see runtime/core/portable_type/scalar_type.h) + dim_order: [uint8]; // Memory layout order (matches TensorLayout.dim_order, DimOrderType = uint8_t) +} + +// ============================================================================= +// Slot variant for I/O mapping +// ============================================================================= + +enum SlotType : byte { + TensorSlot = 0, + IntValueSlot = 1, + FloatValueSlot = 2, + BoolValueSlot = 3 +} + +table SlotVariant { + idx: uint32; + slot_type: SlotType = TensorSlot; +} + +// ============================================================================= +// Name to slot mapping entry +// 
============================================================================= + +table NamedSlot { + name: string (required); + slot: SlotVariant (required); +} + +// ============================================================================= +// Root type: MLX Graph +// ============================================================================= + +// BC: New fields must be appended at the end of this table with a default value. +table MLXGraph { + // Version for compatibility + version: string; + + // Tensor slot counts + + num_constant_tensors: uint32; + num_input_tensors: uint32; + num_output_tensors: uint32; + num_mutable_buffer_tensors: uint32; + num_temp_tensors: uint32; + num_values: uint32; + + // Instruction chains (basic blocks of sequential instructions) + instruction_chains: [InstructionChain] (required); + main_chain_idx: uint32 = 0; // Chain to run every execute() call + init_chain_idx: int32 = -1; // Chain to run once at init(), -1 = none + + // I/O mappings + input_map: [SlotVariant]; + output_map: [SlotVariant]; + mutable_buffer_map: [SlotVariant]; + + // Name to slot lookup (used for constant/mutable buffer keys in named_data_map) + named_slots: [NamedSlot]; + + // Tensor metadata (for non-temp tensors), indexed by Tid + tensor_meta: [TensorMeta]; + + // BC: New fields must be appended here with a default value +} + +root_type MLXGraph; diff --git a/backends/mlx/test/CMakeLists.txt b/backends/mlx/test/CMakeLists.txt new file mode 100644 index 00000000000..2a709a63412 --- /dev/null +++ b/backends/mlx/test/CMakeLists.txt @@ -0,0 +1,51 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +# MLX backend tests + +# Strict compiler flags for the test runner — mlxdelegate uses PRIVATE so these +# don't propagate to downstream consumers +set(_mlx_test_compile_options -Wall -Werror -Wconversion -Wsign-conversion + -Wshorten-64-to-32 +) + +# Sanitizers are inherited from parent via EXECUTORCH_MLX_ENABLE_SANITIZERS +if(EXECUTORCH_MLX_ENABLE_SANITIZERS) + list(APPEND _mlx_test_compile_options -fsanitize=address,undefined + -fno-omit-frame-pointer + ) +endif() + +# Op test runner - generic test binary for testing individual ops +add_executable(op_test_runner op_test_runner.cpp) + +target_compile_options(op_test_runner PRIVATE ${_mlx_test_compile_options}) +if(EXECUTORCH_MLX_ENABLE_SANITIZERS) + target_link_options(op_test_runner PRIVATE ${_mlx_sanitizer_link_options}) +endif() + +target_link_libraries( + op_test_runner PRIVATE extension_module extension_tensor executorch + mlxdelegate +) + +# -------------------------------------------------------------------------- +# Compile-only strict warnings test for delegate headers +# +# Verifies MLXExecutor.h, MLXInterpreter.h, MLXLoader.h compile cleanly under +# -Wconversion -Wsign-conversion -Wshorten-64-to-32 -Werror. ExecuTorch and MLX +# headers are suppressed via pragma in the source file. This target is never +# linked or run — a successful compile is the test. 
+# -------------------------------------------------------------------------- +add_library(strict_compile_test OBJECT strict_compile_test.cpp) +target_compile_options(strict_compile_test PRIVATE ${_mlx_test_compile_options}) +target_include_directories( + strict_compile_test PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../runtime +) +target_link_libraries( + strict_compile_test PRIVATE mlx_schema executorch_core mlx +) +add_dependencies(op_test_runner strict_compile_test) diff --git a/backends/mlx/test/README.md b/backends/mlx/test/README.md new file mode 100644 index 00000000000..6d90d513fec --- /dev/null +++ b/backends/mlx/test/README.md @@ -0,0 +1,164 @@ +# MLX Backend Tests + +This directory contains end-to-end tests for the MLX backend. Each test verifies that a specific op or pattern is correctly lowered to MLX and produces matching outputs between PyTorch and the MLX runtime. + +## Setup + +### 1. Install ExecuTorch Python package (if not already installed) + +```bash +python install_executorch.py --editable +``` + +### 2. Configure CMake with MLX preset + +From the ExecuTorch root directory: + +```bash +cmake --preset mlx-release -DEXECUTORCH_BUILD_TESTS=ON +``` + +This configures the build with MLX delegate support and test targets. Build files are generated in `cmake-out/`. + +### 3. Build the test runner + +```bash +cmake --build cmake-out --target op_test_runner +``` + +This builds the `op_test_runner` binary that executes `.pte` models using the MLX runtime. + + + +## Prerequisites + +1. **Python environment**: Tests must be run in an environment where the `executorch` Python package is installed +2. 
**Built C++ runtime**: The `op_test_runner` binary must be built (see Setup above) + +## Running Tests + +### Run All Tests + +To run all registered tests: + +```bash +python -m executorch.backends.mlx.test.run_all_tests -j4 --clean-after +``` + +### Options + +| Flag | Description | +|------|-------------| +| `-j N` / `--parallel N` | Run tests in parallel with N workers | +| `--clean-after` | Clean up generated test files after running | +| `--clean` | Clean up generated test files and exit | +| `--rebuild` | Rebuild the C++ test runner before running | +| `--list` | List available tests and exit | +| `-v` / `--verbose` | Verbose output | +| `--timeout SECS` | Timeout per test in seconds (default: 300) | + +### Memory Management Options + +Running many tests can accumulate memory (torch/MLX/Metal allocations). These flags help manage memory: + +| Flag | Description | +|------|-------------| +| `--isolate` | Run each test in a separate subprocess (sequential mode only). Provides full memory isolation but is slower due to Python/torch import overhead per test. | +| `--max-tasks-per-worker N` | Recycle parallel workers after N tests (parallel mode only). Workers are terminated and replaced after completing N tests, releasing accumulated memory. 
| + +**Comparison:** + +| Mode | Memory Isolation | Speed | +|------|------------------|-------| +| `-j 4` | None (workers reused) | Fastest | +| `-j 4 --max-tasks-per-worker 10` | Bounded (recycled every 10 tests) | Fast | +| `-j 4 --max-tasks-per-worker 1` | Full (new process per test) | Slower | +| `--isolate` | Full (subprocess per test) | Slowest (sequential) | + +**Recommended for CI with memory constraints:** + +```bash +python -m executorch.backends.mlx.test.run_all_tests -j4 --max-tasks-per-worker 10 --clean-after +``` + +### Run a Specific Test + +To run a specific test by name (e.g., `linear`): + +```bash +python -m executorch.backends.mlx.test.run_all_tests linear +``` + +With verbose output: + +```bash +python -m executorch.backends.mlx.test.run_all_tests -v linear +``` + +### List Available Tests + +```bash +python -m executorch.backends.mlx.test.run_all_tests --list +``` + +## Test Architecture + +All tests are defined in `test_ops.py`. Each test follows a common pattern: + +1. **Define a model** - A simple `nn.Module` that uses the op being tested +2. **Create test inputs** - Generate random input tensors +3. **Export and lower** - Export the model and lower it to the MLX backend +4. **Run C++ binary** - Execute the lowered model using `op_test_runner` +5. **Compare outputs** - Verify PyTorch and MLX outputs match within tolerance + +### Test Class Structure + +Tests inherit from `OpTestCase` and implement: + +```python +@register_test +class MyTest(OpTestCase): + name = "my_test" # Test name (used for output directory) + rtol = 1e-5 # Relative tolerance for comparison + atol = 1e-5 # Absolute tolerance for comparison + + def create_model(self) -> nn.Module: + """Return the model to test.""" + ... + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + """Return input tensors for export.""" + ... + + def get_dynamic_shapes(self) -> Optional[Dict]: + """Return dynamic shape specs, or None for static shapes.""" + ... 
+
+    @classmethod
+    def get_test_configs(cls) -> List["MyTest"]:
+        """Return list of test configurations to run."""
+        ...
+```
+
+## Test Output
+
+Test artifacts are saved to `op_tests/<test_name>/`:
+- `model.pte` - Exported ExecuTorch model
+- `input.bin` - Serialized input tensors
+- `expected_output.bin` - PyTorch reference output
+- `actual_output.bin` - MLX runtime output
+
+## Adding a New Test
+
+1. Add a new model class and `OpTestCase` subclass to `test_ops.py`
+2. Use the `@register_test` decorator on the test class
+3. Implement `create_model()`, `create_inputs()`, and `get_test_configs()`
+4. Run the test to verify it works E2E
+
+## Test harness
+
+MLX also plugs into the ExecuTorch test harness for even more coverage. To run, use the following command from the ExecuTorch root directory:
+
+```bash
+pytest -c /dev/null backends/test/suite/operators/ -m flow_mlx
+```
diff --git a/backends/mlx/test/__init__.py b/backends/mlx/test/__init__.py
new file mode 100644
index 00000000000..2e41cd717f6
--- /dev/null
+++ b/backends/mlx/test/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
diff --git a/backends/mlx/test/op_test_runner.cpp b/backends/mlx/test/op_test_runner.cpp
new file mode 100644
index 00000000000..6bed13d7a56
--- /dev/null
+++ b/backends/mlx/test/op_test_runner.cpp
@@ -0,0 +1,395 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+/**
+ * Generic op test runner for MLX delegate.
+ *
+ * Loads a .pte file, reads inputs from .bin files, runs the model,
+ * and writes outputs to .bin files.
+ *
+ * Build:
+ *   cd cmake-out-mlx && cmake --build .
--target op_test_runner + * + * Usage: + * ./cmake-out-mlx/backends/mlx/test/op_test_runner \ + * --pte \ + * --input \ + * --output + * + * Binary file format: + * - 4 bytes: number of tensors (uint32_t) + * For each tensor: + * - 4 bytes: dtype (0=float32, 1=float16, 2=int32, 3=int64, 4=bfloat16, + * 5=bool) + * - 4 bytes: number of dimensions (uint32_t) + * - 4 bytes * ndim: shape (int32_t each) + * - N bytes: data (size = product of shape * sizeof(dtype)) + */ + +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wconversion" +#pragma clang diagnostic ignored "-Wsign-conversion" +#pragma clang diagnostic ignored "-Wshorten-64-to-32" +#pragma clang diagnostic ignored "-Wimplicit-float-conversion" +#include +#include +#pragma clang diagnostic pop + +#include +#include +#include +#include +#include +#include +#include + +using namespace ::executorch::extension; +using namespace ::executorch::runtime; + +enum class DType : uint32_t { + Float32 = 0, + Float16 = 1, + Int32 = 2, + Int64 = 3, + BFloat16 = 4, + Bool = 5, +}; + +size_t dtype_size(DType dtype) { + switch (dtype) { + case DType::Float32: + return 4; + case DType::Float16: + return 2; + case DType::Int32: + return 4; + case DType::Int64: + return 8; + case DType::BFloat16: + return 2; + case DType::Bool: + return 1; + default: + return 4; + } +} + +exec_aten::ScalarType dtype_to_scalar_type(DType dtype) { + switch (dtype) { + case DType::Float32: + return exec_aten::ScalarType::Float; + case DType::Float16: + return exec_aten::ScalarType::Half; + case DType::Int32: + return exec_aten::ScalarType::Int; + case DType::Int64: + return exec_aten::ScalarType::Long; + case DType::BFloat16: + return exec_aten::ScalarType::BFloat16; + case DType::Bool: + return exec_aten::ScalarType::Bool; + default: + return exec_aten::ScalarType::Float; + } +} + +DType scalar_type_to_dtype(exec_aten::ScalarType stype) { + switch (stype) { + case exec_aten::ScalarType::Float: + return DType::Float32; + case 
exec_aten::ScalarType::Half: + return DType::Float16; + case exec_aten::ScalarType::Int: + return DType::Int32; + case exec_aten::ScalarType::Long: + return DType::Int64; + case exec_aten::ScalarType::BFloat16: + return DType::BFloat16; + case exec_aten::ScalarType::Bool: + return DType::Bool; + default: + return DType::Float32; + } +} + +struct TensorData { + DType dtype; + std::vector shape; + std::vector data; +}; + +std::vector read_tensors_from_bin(const std::string& path) { + std::ifstream file(path, std::ios::binary); + if (!file) { + throw std::runtime_error("Failed to open input file: " + path); + } + + uint32_t num_tensors; + file.read(reinterpret_cast(&num_tensors), sizeof(num_tensors)); + + std::vector tensors; + tensors.reserve(num_tensors); + + for (uint32_t i = 0; i < num_tensors; ++i) { + TensorData t; + + uint32_t dtype_val; + file.read(reinterpret_cast(&dtype_val), sizeof(dtype_val)); + t.dtype = static_cast(dtype_val); + + uint32_t ndim; + file.read(reinterpret_cast(&ndim), sizeof(ndim)); + + t.shape.resize(ndim); + file.read(reinterpret_cast(t.shape.data()), ndim * sizeof(int32_t)); + + size_t numel = 1; + for (int32_t s : t.shape) { + numel *= static_cast(s); + } + size_t data_size = numel * dtype_size(t.dtype); + + t.data.resize(data_size); + file.read( + reinterpret_cast(t.data.data()), + static_cast(data_size)); + + tensors.push_back(std::move(t)); + } + + return tensors; +} + +void write_tensors_to_bin( + const std::string& path, + const std::vector& tensors) { + std::ofstream file(path, std::ios::binary); + if (!file) { + throw std::runtime_error("Failed to open output file: " + path); + } + + uint32_t num_tensors = static_cast(tensors.size()); + file.write(reinterpret_cast(&num_tensors), sizeof(num_tensors)); + + for (const auto& t : tensors) { + uint32_t dtype_val = static_cast(t.dtype); + file.write(reinterpret_cast(&dtype_val), sizeof(dtype_val)); + + uint32_t ndim = static_cast(t.shape.size()); + file.write(reinterpret_cast(&ndim), 
sizeof(ndim)); + + file.write( + reinterpret_cast(t.shape.data()), ndim * sizeof(int32_t)); + + file.write( + reinterpret_cast(t.data.data()), + static_cast(t.data.size())); + } +} + +void print_usage(const char* prog_name) { + std::cerr << "Usage: " << prog_name << " [options]\n" + << "Options:\n" + << " --pte Path to .pte model file (required)\n" + << " --input Path to input .bin file (required)\n" + << " --output Path to output .bin file (required)\n" + << " --verbose Print verbose output\n" + << std::endl; +} + +int main(int argc, char* argv[]) { + std::string pte_path; + std::string input_path; + std::string output_path; + bool verbose = false; + + for (int i = 1; i < argc; ++i) { + std::string arg = argv[i]; + if (arg == "--pte" && i + 1 < argc) { + pte_path = argv[++i]; + } else if (arg == "--input" && i + 1 < argc) { + input_path = argv[++i]; + } else if (arg == "--output" && i + 1 < argc) { + output_path = argv[++i]; + } else if (arg == "--verbose") { + verbose = true; + } else if (arg == "--help" || arg == "-h") { + print_usage(argv[0]); + return 0; + } else { + std::cerr << "Unknown argument: " << arg << std::endl; + print_usage(argv[0]); + return 1; + } + } + + if (pte_path.empty() || input_path.empty() || output_path.empty()) { + std::cerr << "Error: --pte, --input, and --output are required\n"; + print_usage(argv[0]); + return 1; + } + + try { + if (verbose) { + std::cout << "Loading model from: " << pte_path << std::endl; + } + + Module module(pte_path); + auto load_error = module.load(); + if (load_error != Error::Ok) { + std::cerr << "Failed to load model: " << static_cast(load_error) + << std::endl; + return 1; + } + + if (verbose) { + std::cout << "Model loaded successfully" << std::endl; + } + + auto load_method_error = module.load_method("forward"); + if (load_method_error != Error::Ok) { + std::cerr << "Failed to load forward method: " + << static_cast(load_method_error) << std::endl; + return 1; + } + + if (verbose) { + std::cout << "Reading 
inputs from: " << input_path << std::endl; + } + + auto input_tensors = read_tensors_from_bin(input_path); + + if (verbose) { + std::cout << "Read " << input_tensors.size() << " input tensors" + << std::endl; + for (size_t i = 0; i < input_tensors.size(); ++i) { + std::cout << " Input " << i + << ": dtype=" << static_cast(input_tensors[i].dtype) + << ", shape=["; + for (size_t j = 0; j < input_tensors[i].shape.size(); ++j) { + std::cout << input_tensors[i].shape[j]; + if (j < input_tensors[i].shape.size() - 1) + std::cout << ", "; + } + std::cout << "]" << std::endl; + } + } + + std::vector tensor_ptrs; + std::vector inputs; + tensor_ptrs.reserve(input_tensors.size()); + inputs.reserve(input_tensors.size()); + + for (const auto& t : input_tensors) { + std::vector sizes(t.shape.begin(), t.shape.end()); + + TensorPtr tensor_ptr; + if (t.dtype == DType::Float32) { + std::vector data(t.data.size() / sizeof(float)); + std::memcpy(data.data(), t.data.data(), t.data.size()); + tensor_ptr = make_tensor_ptr(sizes, std::move(data)); + } else if (t.dtype == DType::Float16) { + std::vector data( + t.data.size() / sizeof(exec_aten::Half)); + std::memcpy(data.data(), t.data.data(), t.data.size()); + tensor_ptr = make_tensor_ptr(sizes, std::move(data)); + } else if (t.dtype == DType::BFloat16) { + std::vector data( + t.data.size() / sizeof(exec_aten::BFloat16)); + std::memcpy(data.data(), t.data.data(), t.data.size()); + tensor_ptr = make_tensor_ptr(sizes, std::move(data)); + } else if (t.dtype == DType::Int32) { + std::vector data(t.data.size() / sizeof(int32_t)); + std::memcpy(data.data(), t.data.data(), t.data.size()); + tensor_ptr = make_tensor_ptr(sizes, std::move(data)); + } else if (t.dtype == DType::Int64) { + std::vector data(t.data.size() / sizeof(int64_t)); + std::memcpy(data.data(), t.data.data(), t.data.size()); + tensor_ptr = make_tensor_ptr(sizes, std::move(data)); + } else if (t.dtype == DType::Bool) { + std::vector data(t.data.size()); + std::memcpy(data.data(), 
t.data.data(), t.data.size()); + tensor_ptr = make_tensor_ptr( + sizes, std::move(data), {}, {}, exec_aten::ScalarType::Bool); + } else { + std::cerr << "Unsupported dtype: " << static_cast(t.dtype) + << std::endl; + return 1; + } + + tensor_ptrs.push_back(tensor_ptr); + inputs.push_back(tensor_ptr); + } + + if (verbose) { + std::cout << "Executing forward..." << std::endl; + } + + auto result = module.forward(inputs); + if (result.error() != Error::Ok) { + std::cerr << "Execution failed: " << static_cast(result.error()) + << std::endl; + return 1; + } + + if (verbose) { + std::cout << "Execution succeeded, " << result->size() << " outputs" + << std::endl; + } + + std::vector output_tensors; + output_tensors.reserve(result->size()); + + for (size_t i = 0; i < result->size(); ++i) { + const auto& evalue = result->at(i); + if (!evalue.isTensor()) { + std::cerr << "Output " << i << " is not a tensor" << std::endl; + return 1; + } + + const auto& tensor = evalue.toTensor(); + TensorData t; + t.dtype = scalar_type_to_dtype(tensor.scalar_type()); + + t.shape.resize(static_cast(tensor.dim())); + for (size_t d = 0; d < static_cast(tensor.dim()); ++d) { + t.shape[d] = static_cast(tensor.size(static_cast(d))); + } + + size_t data_size = tensor.nbytes(); + t.data.resize(data_size); + std::memcpy(t.data.data(), tensor.const_data_ptr(), data_size); + + if (verbose) { + std::cout << " Output " << i << ": dtype=" << static_cast(t.dtype) + << ", shape=["; + for (size_t j = 0; j < t.shape.size(); ++j) { + std::cout << t.shape[j]; + if (j < t.shape.size() - 1) + std::cout << ", "; + } + std::cout << "]" << std::endl; + } + + output_tensors.push_back(std::move(t)); + } + + if (verbose) { + std::cout << "Writing outputs to: " << output_path << std::endl; + } + + write_tensors_to_bin(output_path, output_tensors); + + std::cout << "OK" << std::endl; + return 0; + + } catch (const std::exception& e) { + std::cerr << "Error: " << e.what() << std::endl; + return 1; + } +} diff --git 
a/backends/mlx/test/run_all_tests.py b/backends/mlx/test/run_all_tests.py new file mode 100644 index 00000000000..3cda35da275 --- /dev/null +++ b/backends/mlx/test/run_all_tests.py @@ -0,0 +1,496 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Run all MLX delegate op tests. + +Usage: + # Run all tests (all configurations): + python -m executorch.backends.mlx.test.run_all_tests + + # Run specific test (all its configurations): + python -m executorch.backends.mlx.test.run_all_tests add + + # Run specific test configuration: + python -m executorch.backends.mlx.test.run_all_tests add_scalar + + # List available tests: + python -m executorch.backends.mlx.test.run_all_tests --list + + # Rebuild C++ runner before running: + python -m executorch.backends.mlx.test.run_all_tests --rebuild + + # Run tests in parallel: + python -m executorch.backends.mlx.test.run_all_tests -j 4 + + # Run with custom timeout: + python -m executorch.backends.mlx.test.run_all_tests --timeout 60 +""" + +import argparse +import importlib +import multiprocessing +import subprocess +import sys +from multiprocessing import Pool +from typing import List, Optional, Tuple + +from .test_utils import ( + clean_test_outputs, + DEFAULT_TEST_TIMEOUT, + get_all_test_configs, + get_registered_tests, + get_test_output_size, + rebuild_op_test_runner, +) + + +def discover_and_import_tests(): + """ + Import test_ops.py module which contains all test definitions. + This triggers registration of all tests. + """ + importlib.import_module(".test_ops", package=__package__) + + +def _run_single_test( + test_class_name: str, + config_name: str, + config_kwargs: dict, + verbose: bool, + timeout: int, +) -> Tuple[str, bool, Optional[str]]: + """ + Run a single test configuration in a subprocess. 
+ + Called via multiprocessing.Pool.starmap for parallel execution. + Recreates the test instance from the class name and kwargs. + + Args: + test_class_name: Name of the test class module.path + config_name: Name of this configuration + config_kwargs: Kwargs to recreate the test instance + verbose: Whether to print verbose output + timeout: Timeout in seconds + + Returns: + (config_name, passed, error_message) + """ + try: + # Re-discover and import tests in this subprocess + discover_and_import_tests() + + # Find the test config by name + all_configs = get_all_test_configs() + test_instance = None + for name, instance in all_configs: + if name == config_name: + test_instance = instance + break + + if test_instance is None: + return (config_name, False, f"Could not find test config: {config_name}") + + # Run the test + passed = test_instance.run_test(verbose=verbose, timeout=timeout) + return (config_name, passed, None) + + except Exception as e: + import traceback + + return (config_name, False, f"Exception: {e}\n{traceback.format_exc()}") + + +def run_tests_sequential( + configs_to_run: List[Tuple[str, object]], + verbose: bool = False, + timeout: int = DEFAULT_TEST_TIMEOUT, + clean_after_each: bool = False, + isolate: bool = False, +) -> Tuple[int, int, List[str]]: + """ + Run tests sequentially. + + Args: + configs_to_run: List of (config_name, test_instance) tuples. + verbose: Whether to print verbose output. + timeout: Timeout in seconds per test. + clean_after_each: Whether to clean up test outputs after each test. + isolate: Whether to run each test in a subprocess to prevent memory + accumulation across tests (torch/MLX/Metal allocations). 
+ + Returns: + (passed_count, failed_count, failed_test_names) + """ + passed = 0 + failed = 0 + failed_tests = [] + + for config_name, test in configs_to_run: + if isolate: + test_passed = _run_test_in_subprocess( + config_name, verbose=verbose, timeout=timeout + ) + else: + try: + test_passed = test.run_test(verbose=verbose, timeout=timeout) + except Exception as e: + print(f"✗ FAILED: {config_name} - Exception: {e}") + import traceback + + traceback.print_exc() + test_passed = False + + if test_passed: + passed += 1 + else: + failed += 1 + failed_tests.append(config_name) + + if clean_after_each: + clean_test_outputs([config_name], verbose=False) + + return passed, failed, failed_tests + + +def _run_test_in_subprocess( + config_name: str, + verbose: bool = False, + timeout: int = DEFAULT_TEST_TIMEOUT, +) -> bool: + """ + Run a single test in an isolated subprocess. + + Each test gets its own Python interpreter so torch/MLX/Metal memory is + fully released between tests, preventing OOM on CI runners. + + Args: + config_name: Name of the test configuration to run. + verbose: Whether to print verbose output. + timeout: Timeout in seconds. + + Returns: + True if test passed, False otherwise. + """ + cmd = [ + sys.executable, + "-m", + "executorch.backends.mlx.test.test_utils", + config_name, + "run", + ] + if verbose: + cmd.append("--verbose") + + try: + result = subprocess.run( + cmd, + timeout=timeout, + capture_output=False, + ) + return result.returncode == 0 + except subprocess.TimeoutExpired: + print(f"✗ FAILED: {config_name} - Timeout after {timeout}s") + return False + except Exception as e: + print(f"✗ FAILED: {config_name} - Subprocess error: {e}") + return False + + +def run_tests_parallel( + configs_to_run: List[Tuple[str, object]], + num_workers: int, + verbose: bool = False, + timeout: int = DEFAULT_TEST_TIMEOUT, + max_tasks_per_worker: Optional[int] = None, +) -> Tuple[int, int, List[str]]: + """ + Run tests in parallel using multiprocessing.Pool. 
+ + Args: + configs_to_run: List of (config_name, test_instance) tuples. + num_workers: Number of parallel workers. + verbose: Whether to print verbose output. + timeout: Timeout in seconds per test. + max_tasks_per_worker: Maximum tasks per worker before recycling. + When set, worker processes are terminated and replaced after + completing this many tests, which releases accumulated memory + (torch/MLX/Metal allocations). None means workers are never recycled. + + Returns: + (passed_count, failed_count, failed_test_names) + """ + passed = 0 + failed = 0 + failed_tests = [] + + # Prepare test args for parallel execution + # We pass config names and let subprocesses recreate the test instances + test_args = [("", name, {}, verbose, timeout) for name, _ in configs_to_run] + + recycle_msg = "" + if max_tasks_per_worker is not None: + recycle_msg = f", recycling workers every {max_tasks_per_worker} tests" + print( + f"\nRunning {len(test_args)} tests with {num_workers} workers{recycle_msg}...\n" + ) + + with Pool(processes=num_workers, maxtasksperchild=max_tasks_per_worker) as pool: + results = pool.starmap(_run_single_test, test_args) + + for result_name, result_passed, error_msg in results: + if result_passed: + print(f"✓ PASSED: {result_name}") + passed += 1 + else: + if error_msg: + print(f"✗ FAILED: {result_name} - {error_msg}") + else: + print(f"✗ FAILED: {result_name}") + failed += 1 + failed_tests.append(result_name) + + return passed, failed, failed_tests + + +def run_tests( + test_filter: List[str], + verbose: bool = False, + parallel: int = 1, + timeout: int = DEFAULT_TEST_TIMEOUT, + clean_after_each: bool = False, + isolate: bool = False, + max_tasks_per_worker: Optional[int] = None, +) -> Tuple[int, int, List[str]]: + """ + Run tests matching the filter. + + Args: + test_filter: List of test names/patterns to run. If empty, runs all tests. + Can match either base test name (e.g., "add") or config name (e.g., "add_scalar"). 
+ verbose: Whether to print verbose output. + parallel: Number of parallel workers (1 = sequential). + timeout: Timeout in seconds per test. + clean_after_each: Whether to clean up test outputs after each test (sequential only). + isolate: Whether to run each test in a subprocess (sequential only). + max_tasks_per_worker: Maximum tasks per worker before recycling (parallel only). + + Returns: + (passed_count, failed_count, failed_test_names) + """ + all_configs = get_all_test_configs() + registry = get_registered_tests() + + # Determine which configs to run + configs_to_run = [] + if not test_filter: + # Run all + configs_to_run = all_configs + else: + for pattern in test_filter: + matched = False + + # Check if pattern matches a base test name + if pattern in registry: + configs_to_run.extend(registry[pattern]) + matched = True + else: + # Check if pattern matches a config name + for config_name, config in all_configs: + if config_name == pattern: + configs_to_run.append((config_name, config)) + matched = True + + if not matched: + print(f"Warning: No test matching '{pattern}', skipping") + + if not configs_to_run: + print("No tests to run.") + return 0, 0, [] + + # Run tests + if parallel > 1: + return run_tests_parallel( + configs_to_run, parallel, verbose, timeout, max_tasks_per_worker + ) + else: + return run_tests_sequential( + configs_to_run, verbose, timeout, clean_after_each, isolate + ) + + +def main(): # noqa: C901 + # Get CPU count for default parallel workers + cpu_count = multiprocessing.cpu_count() + + parser = argparse.ArgumentParser(description="Run all MLX delegate op tests") + parser.add_argument( + "tests", + nargs="*", + help="Specific tests to run (default: all). 
Can be base name (e.g., 'add') or config name (e.g., 'add_scalar')", + ) + parser.add_argument( + "--list", + action="store_true", + help="List available tests and exit", + ) + parser.add_argument( + "--verbose", + "-v", + action="store_true", + help="Verbose output", + ) + parser.add_argument( + "--rebuild", + action="store_true", + help="Rebuild the C++ test runner before running", + ) + parser.add_argument( + "--clean", + action="store_true", + help="Clean up generated test files and exit", + ) + parser.add_argument( + "--clean-after", + action="store_true", + help="Clean up generated test files after running tests", + ) + parser.add_argument( + "--isolate", + action="store_true", + help="Run each test in a separate subprocess to prevent memory accumulation", + ) + parser.add_argument( + "-j", + "--parallel", + type=int, + default=1, + metavar="N", + help=f"Run tests in parallel with N workers (default: 1, max: {cpu_count})", + ) + parser.add_argument( + "--timeout", + type=int, + default=DEFAULT_TEST_TIMEOUT, + metavar="SECS", + help=f"Timeout per test in seconds (default: {DEFAULT_TEST_TIMEOUT})", + ) + parser.add_argument( + "--max-tasks-per-worker", + type=int, + default=None, + metavar="N", + help="Recycle parallel workers after N tests to release memory (default: no recycling)", + ) + args = parser.parse_args() + + # Validate parallel workers + if args.parallel < 1: + args.parallel = 1 + elif args.parallel > cpu_count: + print( + f"Warning: --parallel {args.parallel} exceeds CPU count ({cpu_count}), using {cpu_count}" + ) + args.parallel = cpu_count + + # Auto-discover and import all test modules + discover_and_import_tests() + + # Handle --clean flag + if args.clean: + # Determine which tests to clean + test_names = None + if args.tests: + # Get config names for the specified tests + registry = get_registered_tests() + test_names = [] + for pattern in args.tests: + if pattern in registry: + test_names.extend(cfg_name for cfg_name, _ in registry[pattern]) 
+ else: + test_names.append(pattern) + + # Show current size + current_size = get_test_output_size(test_names) + if current_size > 0: + print(f"Current test output size: {current_size / 1024 / 1024:.2f} MB") + + # Clean + files_removed = clean_test_outputs(test_names, verbose=args.verbose) + if files_removed > 0: + print(f"Removed {files_removed} files") + else: + print("No files to clean") + sys.exit(0) + + # List tests + if args.list: + registry = get_registered_tests() + print("Available tests:") + for base_name in sorted(registry.keys()): + configs = registry[base_name] + if len(configs) == 1 and configs[0][0] == base_name: + # Single config with same name as base + print(f" {base_name}") + else: + # Multiple configs or different name + print(f" {base_name}:") + for config_name, _ in configs: + print(f" - {config_name}") + sys.exit(0) + + # Rebuild if requested + if args.rebuild: + if not rebuild_op_test_runner(verbose=args.verbose): + sys.exit(1) + + # Run tests + passed, failed, failed_tests = run_tests( + args.tests, + verbose=args.verbose, + parallel=args.parallel, + timeout=args.timeout, + clean_after_each=args.clean_after, + isolate=args.isolate, + max_tasks_per_worker=args.max_tasks_per_worker, + ) + + # Print summary + print("\n" + "=" * 60) + print("TEST SUMMARY") + print("=" * 60) + print(f"Passed: {passed}") + print(f"Failed: {failed}") + if failed_tests: + print(f"Failed tests: {', '.join(failed_tests)}") + print("=" * 60) + + # Clean up after tests if requested + if args.clean_after: + # Determine which tests to clean (same logic as --clean) + test_names = None + if args.tests: + registry = get_registered_tests() + test_names = [] + for pattern in args.tests: + if pattern in registry: + test_names.extend(cfg_name for cfg_name, _ in registry[pattern]) + else: + test_names.append(pattern) + + current_size = get_test_output_size(test_names) + files_removed = clean_test_outputs(test_names, verbose=args.verbose) + if files_removed > 0: + print( + 
f"\nCleaned up {files_removed} files ({current_size / 1024 / 1024:.2f} MB)" + ) + + sys.exit(0 if failed == 0 else 1) + + +if __name__ == "__main__": + main() diff --git a/backends/mlx/test/strict_compile_test.cpp b/backends/mlx/test/strict_compile_test.cpp new file mode 100644 index 00000000000..28df78a7d5a --- /dev/null +++ b/backends/mlx/test/strict_compile_test.cpp @@ -0,0 +1,45 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/** + * Compile-only test to verify MLX delegate headers are clean under strict + * warnings (-Wconversion, -Wsign-conversion, -Wshorten-64-to-32, -Werror). + * + * This file includes the delegate headers and instantiates key types to ensure + * template code is also checked. It is never linked or executed — a successful + * compilation is the test. + */ + +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wconversion" +#pragma clang diagnostic ignored "-Wsign-conversion" +#pragma clang diagnostic ignored "-Wshorten-64-to-32" +#include +#include +#include +#include +#pragma clang diagnostic pop + +// These are the headers we want to verify under strict warnings +#include "MLXExecutor.h" +#include "MLXInterpreter.h" +#include "MLXLoader.h" + +// Instantiate key types to ensure template code is checked +namespace { +[[maybe_unused]] void force_instantiation() { + using namespace executorch::backends::mlx; + + // Force safe_mul template instantiation + (void)safe_mul(0, 0, "test"); + + // Force check_allocation_bounded instantiation + ::mlx::core::Shape shape = {1, 2, 3}; + check_allocation_bounded(shape, ::mlx::core::float32, "test"); +} +} // namespace diff --git a/backends/mlx/test/test_ops.py b/backends/mlx/test/test_ops.py new file mode 100644 index 00000000000..01286f75f16 --- /dev/null +++ b/backends/mlx/test/test_ops.py @@ -0,0 +1,176 @@ 
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Consolidated op tests for the MLX delegate.
+
+This file contains all op tests organized by category. Each test class inherits
+from OpTestCase and can be run via the run_all_tests.py script.
+
+Usage:
+    # Run all tests (with 4 parallel workers, cleanup after)
+    python -m executorch.backends.mlx.test.run_all_tests -j4 --clean-after
+
+    # Run specific test
+    python -m executorch.backends.mlx.test.run_all_tests add
+
+    # List available tests
+    python -m executorch.backends.mlx.test.run_all_tests --list
+
+See README.md in this directory for full documentation.
+"""
+
+from typing import List, Tuple
+
+import torch
+import torch.nn as nn
+
+# Import custom ops for RoPE and KV cache tests
+from executorch.backends.mlx import (  # noqa: F401 - registers mlx ops and mlx.rope
+    custom_ops,
+    ops,
+)
+
+from .test_utils import OpTestCase, register_test
+
+
+class BmmModel(nn.Module):
+    """Model that performs batch matrix multiplication."""
+
+    def __init__(self, batch_size: int, n: int, m: int, p: int):
+        super().__init__()
+        self.weight = nn.Parameter(torch.randn(batch_size, m, p))
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return torch.bmm(x, self.weight)
+
+
+@register_test
+class BmmTest(OpTestCase):
+    """Test case for bmm (batch matrix multiplication)."""
+
+    name = "bmm"
+    rtol = 1e-4
+    atol = 1e-4
+
+    def __init__(
+        self,
+        batch_size: int = 4,
+        n: int = 8,
+        m: int = 16,
+        p: int = 32,
+    ):
+        self.batch_size = batch_size
+        self.n = n
+        self.m = m
+        self.p = p
+        self.name = f"bmm_{batch_size}x{n}x{m}x{p}"
+
+    @classmethod
+    def get_test_configs(cls) -> List["BmmTest"]:
+        return [
+            cls(batch_size=4, n=8, m=16, p=32),
+            cls(batch_size=2, n=64, m=64, p=32),
+        ]
+
+    def create_model(self) -> 
nn.Module: + return BmmModel(self.batch_size, self.n, self.m, self.p) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + x = torch.randn(self.batch_size, self.n, self.m) + return (x,) + + +class AddmmModel(nn.Module): + """Model that performs addmm: bias + (mat1 @ mat2).""" + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + alpha: float = 1.0, + beta: float = 1.0, + ): + super().__init__() + self.weight = nn.Parameter(torch.randn(out_features, in_features)) + if bias: + self.bias = nn.Parameter(torch.randn(out_features)) + else: + self.bias = None + self.alpha = alpha + self.beta = beta + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.bias is not None: + return torch.addmm( + self.bias, x, self.weight.t(), beta=self.beta, alpha=self.alpha + ) + else: + return torch.mm(x, self.weight.t()) + + +@register_test +class AddmmTest(OpTestCase): + """Test case for addmm.""" + + name = "addmm" + rtol = 1e-4 + atol = 1e-4 + + def __init__( + self, + batch_size: int = 2, + in_features: int = 64, + out_features: int = 32, + bias: bool = True, + alpha: float = 1.0, + beta: float = 1.0, + ): + self.batch_size = batch_size + self.in_features = in_features + self.out_features = out_features + self.bias = bias + self.alpha = alpha + self.beta = beta + + # Build unique test name + if not bias: + name = f"addmm_{in_features}x{out_features}_no_bias" + elif alpha != 1.0 or beta != 1.0: + name = f"addmm_{in_features}x{out_features}_a{alpha}_b{beta}" + else: + name = f"addmm_{in_features}x{out_features}" + self.name = name + + @classmethod + def get_test_configs(cls) -> List["AddmmTest"]: + return [ + cls( + batch_size=2, in_features=64, out_features=32 + ), # with bias, default alpha/beta + cls( + batch_size=2, in_features=64, out_features=32, bias=False + ), # without bias + cls(batch_size=4, in_features=128, out_features=64), # larger size + cls( + batch_size=2, in_features=64, out_features=32, alpha=2.0, beta=0.5 + ), # 
custom alpha/beta + ] + + def create_model(self) -> nn.Module: + return AddmmModel( + self.in_features, + self.out_features, + bias=self.bias, + alpha=self.alpha, + beta=self.beta, + ) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + x = torch.randn(self.batch_size, self.in_features) + return (x,) diff --git a/backends/mlx/test/test_partitioner.py b/backends/mlx/test/test_partitioner.py new file mode 100644 index 00000000000..4a5833aa656 --- /dev/null +++ b/backends/mlx/test/test_partitioner.py @@ -0,0 +1,45 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Tests for the MLX partitioner. +""" + +import unittest + +import torch +import torch.nn as nn +from executorch.backends.mlx.partitioner import MLXPartitioner +from executorch.exir import EdgeCompileConfig, to_edge +from torch.export import export + + +class TestMLXPartitionerRejectsToEdge(unittest.TestCase): + """MLXPartitioner must only be used via to_edge_transform_and_lower.""" + + def test_to_edge_then_to_backend_raises(self): + class M(nn.Module): + def forward(self, x): + return x + 1 + + ep = export(M(), (torch.randn(4),), strict=False) + edge = to_edge( + ep, + compile_config=EdgeCompileConfig( + _check_ir_validity=False, + _skip_dim_order=True, + ), + ) + + with self.assertRaises(RuntimeError) as ctx: + edge.to_backend(MLXPartitioner()) + + self.assertIn("to_edge_transform_and_lower", str(ctx.exception)) + + +if __name__ == "__main__": + unittest.main() diff --git a/backends/mlx/test/test_passes.py b/backends/mlx/test/test_passes.py new file mode 100644 index 00000000000..a9fdb3b996b --- /dev/null +++ b/backends/mlx/test/test_passes.py @@ -0,0 +1,6 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/backends/mlx/test/test_pattern_utils.py b/backends/mlx/test/test_pattern_utils.py new file mode 100644 index 00000000000..48495a469d7 --- /dev/null +++ b/backends/mlx/test/test_pattern_utils.py @@ -0,0 +1,592 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Tests for pattern_utils.py - shared pattern matching utilities. +""" + +import unittest + +import torch +from torch.export import export + + +def get_exported_graph(module, example_inputs): + """Export a module and return the graph with ATen ops.""" + ep = export(module, example_inputs) + return ep.graph_module.graph + + +def find_node_by_target(graph, target_name): + """Find first call_function node whose target contains target_name.""" + for node in graph.nodes: + if node.op == "call_function" and target_name in str(node.target): + return node + return None + + +def find_all_nodes_by_target(graph, target_name): + """Find all call_function nodes whose target contains target_name.""" + return [ + node + for node in graph.nodes + if node.op == "call_function" and target_name in str(node.target) + ] + + +class TestMatchTarget(unittest.TestCase): + """Tests for match_target function.""" + + def test_match_target_basic(self): + """Test basic op matching.""" + from executorch.backends.mlx.pattern_utils import match_target + + class SimpleModule(torch.nn.Module): + def forward(self, x): + return torch.rsqrt(x) + + graph = get_exported_graph(SimpleModule(), (torch.randn(4, 4),)) + rsqrt_node = find_node_by_target(graph, "rsqrt") + + self.assertIsNotNone(rsqrt_node) + self.assertTrue(match_target(rsqrt_node, torch.ops.aten.rsqrt.default)) + 
self.assertFalse(match_target(rsqrt_node, torch.ops.aten.add.Tensor)) + + def test_match_target_non_call_function(self): + """Test that non-call_function nodes don't match.""" + from executorch.backends.mlx.pattern_utils import match_target + + class SimpleModule(torch.nn.Module): + def forward(self, x): + return torch.rsqrt(x) + + graph = get_exported_graph(SimpleModule(), (torch.randn(4, 4),)) + + # Find a placeholder node + placeholder_node = None + for node in graph.nodes: + if node.op == "placeholder": + placeholder_node = node + break + + self.assertIsNotNone(placeholder_node) + self.assertFalse(match_target(placeholder_node, torch.ops.aten.rsqrt.default)) + + +class TestHasSingleUser(unittest.TestCase): + """Tests for has_single_user function.""" + + def test_single_user(self): + """Test node with single user.""" + from executorch.backends.mlx.pattern_utils import has_single_user + + class SingleUserModule(torch.nn.Module): + def forward(self, x): + y = torch.neg(x) # Single use + return y + 1 + + graph = get_exported_graph(SingleUserModule(), (torch.randn(4, 4),)) + neg_node = find_node_by_target(graph, "neg") + + self.assertIsNotNone(neg_node) + self.assertTrue(has_single_user(neg_node)) + + def test_multiple_users(self): + """Test node with multiple users.""" + from executorch.backends.mlx.pattern_utils import has_single_user + + class MultiUserModule(torch.nn.Module): + def forward(self, x): + y = torch.neg(x) # Used by both add and mul + a = y + 1 + b = y * 2 + return a + b + + graph = get_exported_graph(MultiUserModule(), (torch.randn(4, 4),)) + neg_node = find_node_by_target(graph, "neg") + + self.assertIsNotNone(neg_node) + self.assertFalse(has_single_user(neg_node)) + + +class TestHasNoUsers(unittest.TestCase): + """Tests for has_no_users function.""" + + def test_has_users(self): + """Test node that has users.""" + from executorch.backends.mlx.pattern_utils import has_no_users + + class SimpleModule(torch.nn.Module): + def forward(self, x): + y = 
torch.neg(x) + return y + 1 + + graph = get_exported_graph(SimpleModule(), (torch.randn(4, 4),)) + neg_node = find_node_by_target(graph, "neg") + + self.assertIsNotNone(neg_node) + self.assertFalse(has_no_users(neg_node)) + + def test_no_users_after_removal(self): + """Test has_no_users returns True for orphaned nodes.""" + from executorch.backends.mlx.pattern_utils import has_no_users + + class SimpleModule(torch.nn.Module): + def forward(self, x): + y = torch.neg(x) + z = torch.rsqrt(y) + return z + + graph = get_exported_graph(SimpleModule(), (torch.randn(4, 4),)) + neg_node = find_node_by_target(graph, "neg") + rsqrt_node = find_node_by_target(graph, "rsqrt") + + # Initially neg has a user (rsqrt) + self.assertFalse(has_no_users(neg_node)) + + # Replace rsqrt's input with placeholder to orphan neg + placeholder = None + for node in graph.nodes: + if node.op == "placeholder": + placeholder = node + break + rsqrt_node.replace_input_with(neg_node, placeholder) + + # Now neg has no users + self.assertTrue(has_no_users(neg_node)) + + +class TestOpStep(unittest.TestCase): + """Tests for OpStep dataclass.""" + + def test_matches_with_op(self): + """Test OpStep.matches with op field.""" + from executorch.backends.mlx.pattern_utils import OpStep + + class SimpleModule(torch.nn.Module): + def forward(self, x): + return torch.rsqrt(x) + + graph = get_exported_graph(SimpleModule(), (torch.randn(4, 4),)) + rsqrt_node = find_node_by_target(graph, "rsqrt") + + step = OpStep(op=torch.ops.aten.rsqrt.default) + self.assertTrue(step.matches(rsqrt_node)) + + step_wrong = OpStep(op=torch.ops.aten.neg.default) + self.assertFalse(step_wrong.matches(rsqrt_node)) + + def test_matches_with_predicate(self): + """Test OpStep.matches with predicate field.""" + from executorch.backends.mlx.pattern_utils import OpStep + + class SimpleModule(torch.nn.Module): + def forward(self, x): + return torch.rsqrt(x) + + graph = get_exported_graph(SimpleModule(), (torch.randn(4, 4),)) + rsqrt_node = 
find_node_by_target(graph, "rsqrt") + + # Predicate that always returns True + step_true = OpStep(predicate=lambda n: True) + self.assertTrue(step_true.matches(rsqrt_node)) + + # Predicate that always returns False + step_false = OpStep(predicate=lambda n: False) + self.assertFalse(step_false.matches(rsqrt_node)) + + def test_matches_no_op_no_predicate(self): + """Test OpStep.matches returns False when neither op nor predicate set.""" + from executorch.backends.mlx.pattern_utils import OpStep + + class SimpleModule(torch.nn.Module): + def forward(self, x): + return torch.rsqrt(x) + + graph = get_exported_graph(SimpleModule(), (torch.randn(4, 4),)) + rsqrt_node = find_node_by_target(graph, "rsqrt") + + step_empty = OpStep() + self.assertFalse(step_empty.matches(rsqrt_node)) + + def test_matches_require_single_user_true(self): + """Test OpStep.matches with require_single_user=True (default).""" + from executorch.backends.mlx.pattern_utils import OpStep + + class MultiUserModule(torch.nn.Module): + def forward(self, x): + y = torch.neg(x) # Used by both add and mul + a = y + 1 + b = y * 2 + return a + b + + graph = get_exported_graph(MultiUserModule(), (torch.randn(4, 4),)) + neg_node = find_node_by_target(graph, "neg") + + # Default require_single_user=True, neg has multiple users + step = OpStep(op=torch.ops.aten.neg.default) + self.assertFalse(step.matches(neg_node)) + + def test_matches_require_single_user_false(self): + """Test OpStep.matches with require_single_user=False.""" + from executorch.backends.mlx.pattern_utils import OpStep + + class MultiUserModule(torch.nn.Module): + def forward(self, x): + y = torch.neg(x) # Used by both add and mul + a = y + 1 + b = y * 2 + return a + b + + graph = get_exported_graph(MultiUserModule(), (torch.randn(4, 4),)) + neg_node = find_node_by_target(graph, "neg") + + # With require_single_user=False, should match despite multiple users + step = OpStep(op=torch.ops.aten.neg.default, require_single_user=False) + 
self.assertTrue(step.matches(neg_node)) + + def test_matches_nargs_int(self): + """Test OpStep.matches with nargs as int (minimum).""" + from executorch.backends.mlx.pattern_utils import OpStep + + class SimpleModule(torch.nn.Module): + def forward(self, x): + return torch.rsqrt(x) # rsqrt has 1 arg + + graph = get_exported_graph(SimpleModule(), (torch.randn(4, 4),)) + rsqrt_node = find_node_by_target(graph, "rsqrt") + + # nargs=1 should match (rsqrt has 1 arg) + step = OpStep(op=torch.ops.aten.rsqrt.default, nargs=1) + self.assertTrue(step.matches(rsqrt_node)) + + # nargs=2 should fail (rsqrt only has 1 arg) + step_too_many = OpStep(op=torch.ops.aten.rsqrt.default, nargs=2) + self.assertFalse(step_too_many.matches(rsqrt_node)) + + def test_matches_nargs_tuple(self): + """Test OpStep.matches with nargs as tuple (range).""" + from executorch.backends.mlx.pattern_utils import OpStep + + class SimpleModule(torch.nn.Module): + def forward(self, x): + return torch.rsqrt(x) # rsqrt has 1 arg + + graph = get_exported_graph(SimpleModule(), (torch.randn(4, 4),)) + rsqrt_node = find_node_by_target(graph, "rsqrt") + + # nargs=(1, 3) should match (rsqrt has 1 arg, in range) + step = OpStep(op=torch.ops.aten.rsqrt.default, nargs=(1, 3)) + self.assertTrue(step.matches(rsqrt_node)) + + # nargs=(2, 4) should fail (rsqrt has 1 arg, not in range) + step_out_of_range = OpStep(op=torch.ops.aten.rsqrt.default, nargs=(2, 4)) + self.assertFalse(step_out_of_range.matches(rsqrt_node)) + + def test_matches_kwargs_empty(self): + """Test OpStep.matches with empty kwargs (node must have no kwargs).""" + from executorch.backends.mlx.pattern_utils import OpStep + + class SimpleModule(torch.nn.Module): + def forward(self, x): + return torch.rsqrt(x) # No kwargs + + graph = get_exported_graph(SimpleModule(), (torch.randn(4, 4),)) + rsqrt_node = find_node_by_target(graph, "rsqrt") + + # Empty kwargs set() means node must have no kwargs (default) + step = OpStep(op=torch.ops.aten.rsqrt.default, 
kwargs=set()) + self.assertTrue(step.matches(rsqrt_node)) + + # Default is also empty set (strict checking) + step_default = OpStep(op=torch.ops.aten.rsqrt.default) + self.assertTrue(step_default.matches(rsqrt_node)) + + def test_matches_kwargs_declared(self): + """Test OpStep.matches with declared kwargs.""" + from executorch.backends.mlx.pattern_utils import OpStep + + class CastModule(torch.nn.Module): + def forward(self, x): + return x.to(torch.float16) + + graph = get_exported_graph(CastModule(), (torch.randn(4, 4),)) + to_copy_node = find_node_by_target(graph, "_to_copy") + + if to_copy_node is not None: + # Check what kwargs exist + node_kwargs = set(to_copy_node.kwargs.keys()) + + # If we declare all kwargs, should match + step_all = OpStep( + op=torch.ops.aten._to_copy.default, + kwargs=node_kwargs, + ) + self.assertTrue(step_all.matches(to_copy_node)) + + # If we don't declare some kwargs, should fail + if node_kwargs: + step_missing = OpStep( + op=torch.ops.aten._to_copy.default, + kwargs=set(), # Empty, but node has kwargs + ) + self.assertFalse(step_missing.matches(to_copy_node)) + + def test_matches_arg_index(self): + """Test OpStep.matches validates arg_index is accessible.""" + from executorch.backends.mlx.pattern_utils import OpStep + + class SimpleModule(torch.nn.Module): + def forward(self, x): + return torch.rsqrt(x) # rsqrt has 1 arg + + graph = get_exported_graph(SimpleModule(), (torch.randn(4, 4),)) + rsqrt_node = find_node_by_target(graph, "rsqrt") + + # arg_index=0 should work (rsqrt has 1 arg) + step = OpStep(op=torch.ops.aten.rsqrt.default, arg_index=0) + self.assertTrue(step.matches(rsqrt_node)) + + # arg_index=1 should fail (rsqrt only has 1 arg, can't access args[1]) + step_bad_index = OpStep(op=torch.ops.aten.rsqrt.default, arg_index=1) + self.assertFalse(step_bad_index.matches(rsqrt_node)) + + +class TestWalkBack(unittest.TestCase): + """Tests for walk_back function.""" + + def test_walk_back_single_step(self): + """Test walk_back 
with a single step.""" + from executorch.backends.mlx.pattern_utils import OpStep, walk_back + + class SimpleModule(torch.nn.Module): + def forward(self, x): + return torch.rsqrt(x) + + graph = get_exported_graph(SimpleModule(), (torch.randn(4, 4),)) + rsqrt_node = find_node_by_target(graph, "rsqrt") + + result = walk_back(rsqrt_node, [OpStep(op=torch.ops.aten.rsqrt.default)]) + + self.assertIsNotNone(result) + base_node, entries = result + self.assertEqual(len(entries), 1) + self.assertEqual(entries[0], rsqrt_node) + # base_node should be the input to rsqrt + self.assertEqual(base_node.op, "placeholder") + + def test_walk_back_chain(self): + """Test walk_back with multiple steps in a chain.""" + from executorch.backends.mlx.pattern_utils import OpStep, walk_back + + class ChainModule(torch.nn.Module): + def forward(self, x): + y = torch.neg(x) + z = torch.rsqrt(y) + return z + + graph = get_exported_graph(ChainModule(), (torch.randn(4, 4),)) + rsqrt_node = find_node_by_target(graph, "rsqrt") + + # Match rsqrt -> neg chain + result = walk_back( + rsqrt_node, + [ + OpStep(op=torch.ops.aten.rsqrt.default), + OpStep(op=torch.ops.aten.neg.default), + ], + ) + + self.assertIsNotNone(result) + base_node, entries = result + self.assertEqual(len(entries), 2) + self.assertEqual(base_node.op, "placeholder") + + def test_walk_back_no_match(self): + """Test walk_back returns None when pattern doesn't match.""" + from executorch.backends.mlx.pattern_utils import OpStep, walk_back + + class SimpleModule(torch.nn.Module): + def forward(self, x): + return torch.rsqrt(x) + + graph = get_exported_graph(SimpleModule(), (torch.randn(4, 4),)) + rsqrt_node = find_node_by_target(graph, "rsqrt") + + # Try to match neg which isn't there + result = walk_back(rsqrt_node, [OpStep(op=torch.ops.aten.neg.default)]) + + self.assertIsNone(result) + + def test_walk_back_optional_step(self): + """Test walk_back with optional step that doesn't match.""" + from executorch.backends.mlx.pattern_utils 
import OpStep, walk_back + + class SimpleModule(torch.nn.Module): + def forward(self, x): + return torch.rsqrt(x) + + graph = get_exported_graph(SimpleModule(), (torch.randn(4, 4),)) + rsqrt_node = find_node_by_target(graph, "rsqrt") + + # Match rsqrt, skip optional neg (not present) + result = walk_back( + rsqrt_node, + [ + OpStep(op=torch.ops.aten.rsqrt.default), + OpStep(op=torch.ops.aten.neg.default, optional=True), + ], + ) + + self.assertIsNotNone(result) + base_node, entries = result + self.assertEqual(len(entries), 2) # One for each step + self.assertIsNotNone(entries[0]) # rsqrt matched + self.assertIsNone(entries[1]) # neg is None (optional, not matched) + + def test_walk_back_repeat_step(self): + """Test walk_back with repeat step.""" + from executorch.backends.mlx.pattern_utils import OpStep, walk_back + + class RepeatModule(torch.nn.Module): + def forward(self, x): + y = torch.neg(x) + z = torch.neg(y) + w = torch.neg(z) + return w + + graph = get_exported_graph(RepeatModule(), (torch.randn(4, 4),)) + + # Find the last neg node (output of the chain) + neg_nodes = find_all_nodes_by_target(graph, "neg") + self.assertEqual(len(neg_nodes), 3) + last_neg = neg_nodes[-1] + + # Match chain of neg ops + result = walk_back( + last_neg, + [OpStep(op=torch.ops.aten.neg.default, repeat=True)], + ) + + self.assertIsNotNone(result) + base_node, entries = result + self.assertEqual(len(entries), 1) # One entry for the repeat step + self.assertIsInstance(entries[0], list) # Repeat returns list + self.assertEqual(len(entries[0]), 3) # Three neg nodes matched + self.assertEqual(base_node.op, "placeholder") + + def test_walk_back_repeat_zero_matches(self): + """Test walk_back with repeat step matching zero times then another step.""" + from executorch.backends.mlx.pattern_utils import OpStep, walk_back + + class SimpleModule(torch.nn.Module): + def forward(self, x): + return torch.rsqrt(x) + + graph = get_exported_graph(SimpleModule(), (torch.randn(4, 4),)) + rsqrt_node = 
find_node_by_target(graph, "rsqrt") + + # Try to match neg (repeat, 0 matches) then rsqrt + # neg doesn't exist at rsqrt, so 0 matches, then we match rsqrt + result = walk_back( + rsqrt_node, + [ + OpStep(op=torch.ops.aten.neg.default, repeat=True), + OpStep(op=torch.ops.aten.rsqrt.default), + ], + ) + + # This should match: neg repeat matches 0 times, rsqrt matches + self.assertIsNotNone(result) + base_node, entries = result + self.assertEqual(len(entries), 2) # One for each step + self.assertIsInstance(entries[0], list) # Repeat returns list + self.assertEqual(len(entries[0]), 0) # Zero neg nodes matched + self.assertIsNotNone(entries[1]) # rsqrt matched + + def test_walk_back_arg_index(self): + """Test walk_back with arg_index to follow non-first argument.""" + from executorch.backends.mlx.pattern_utils import OpStep, walk_back + + class BinaryModule(torch.nn.Module): + def forward(self, x): + y = torch.rsqrt(x) + return x * y # mul(x, rsqrt(x)) + + graph = get_exported_graph(BinaryModule(), (torch.randn(4, 4),)) + mul_node = find_node_by_target(graph, "mul") + rsqrt_node = find_node_by_target(graph, "rsqrt") + + self.assertIsNotNone(mul_node) + self.assertIsNotNone(rsqrt_node) + + # Follow args[1] (rsqrt) instead of args[0] (placeholder) + result = walk_back( + mul_node, + [ + OpStep(op=torch.ops.aten.mul.Tensor, nargs=2, arg_index=1), + OpStep(op=torch.ops.aten.rsqrt.default), + ], + ) + + self.assertIsNotNone(result) + base_node, entries = result + self.assertEqual(len(entries), 2) # mul and rsqrt + self.assertEqual(entries[0], mul_node) + self.assertEqual(entries[1], rsqrt_node) + # base_node should be the input to rsqrt (placeholder) + self.assertEqual(base_node.op, "placeholder") + + +class TestPatternMatch(unittest.TestCase): + """Tests for PatternMatch base class.""" + + def test_all_nodes(self): + """Test all_nodes returns head + body.""" + from executorch.backends.mlx.pattern_utils import PatternMatch + + class ChainModule(torch.nn.Module): + def 
forward(self, x): + y = torch.neg(x) + z = torch.rsqrt(y) + return z + + graph = get_exported_graph(ChainModule(), (torch.randn(4, 4),)) + neg_node = find_node_by_target(graph, "neg") + rsqrt_node = find_node_by_target(graph, "rsqrt") + + match = PatternMatch(head=rsqrt_node, body=[neg_node]) + self.assertEqual(match.all_nodes(), [rsqrt_node, neg_node]) + + def test_remove_body_nodes(self): + """Test remove_body_nodes removes unused nodes.""" + from executorch.backends.mlx.pattern_utils import PatternMatch + + class ChainModule(torch.nn.Module): + def forward(self, x): + y = torch.neg(x) + z = torch.rsqrt(y) + return z + + graph = get_exported_graph(ChainModule(), (torch.randn(4, 4),)) + neg_node = find_node_by_target(graph, "neg") + rsqrt_node = find_node_by_target(graph, "rsqrt") + + # To test remove_body_nodes, we'd need to first replace rsqrt's uses + # and then call remove_body_nodes. For this test, just verify the + # method exists and doesn't crash when nodes have users. + match = PatternMatch(head=rsqrt_node, body=[neg_node]) + + # This won't remove neg because it still has a user (rsqrt) + match.remove_body_nodes(graph) + + # neg should still exist because it has a user + self.assertIn(neg_node, graph.nodes) + + +if __name__ == "__main__": + unittest.main() diff --git a/backends/mlx/test/test_utils.py b/backends/mlx/test/test_utils.py new file mode 100644 index 00000000000..090bceabf08 --- /dev/null +++ b/backends/mlx/test/test_utils.py @@ -0,0 +1,1122 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Utilities for MLX delegate op testing. + +This module provides functions to: +1. Save/load tensors to/from binary files (compatible with C++ op_test_runner) +2. Export simple models to .pte files +3. Compare expected vs actual outputs +4. 
Run the C++ op_test_runner binary +""" + +import json +import os +import struct +import subprocess +import tempfile +from pathlib import Path +from typing import Dict, List, Optional, Tuple, Union + +import numpy as np +import torch + + +DEFAULT_TEST_TIMEOUT = 300 # 5 minutes default timeout + + +class TestTimeoutError(Exception): + """Raised when a test exceeds its timeout.""" + + pass + + +# DType enum values matching C++ op_test_runner +DTYPE_FLOAT32 = 0 +DTYPE_FLOAT16 = 1 +DTYPE_INT32 = 2 +DTYPE_INT64 = 3 +DTYPE_BFLOAT16 = 4 +DTYPE_BOOL = 5 + + +# Default tolerance presets for different data types. +# These are based on the precision characteristics of each dtype: +# - FP32: ~7 decimal digits of precision +# - FP16: ~3-4 decimal digits of precision +# - BF16: ~2-3 decimal digits of precision (same exponent range as FP32) +TOLERANCE_PRESETS = { + torch.float32: {"rtol": 1e-5, "atol": 1e-5}, + torch.float16: {"rtol": 1e-3, "atol": 1e-3}, + torch.bfloat16: {"rtol": 1e-2, "atol": 1e-2}, + # Integer types should match exactly + torch.int32: {"rtol": 0, "atol": 0}, + torch.int64: {"rtol": 0, "atol": 0}, +} + + +def get_tolerance_for_dtype(dtype: torch.dtype) -> Tuple[float, float]: + """ + Get appropriate (rtol, atol) tolerances for a given dtype. + + Args: + dtype: The torch dtype to get tolerances for. + + Returns: + (rtol, atol) tuple with appropriate tolerances for the dtype. + """ + if dtype in TOLERANCE_PRESETS: + preset = TOLERANCE_PRESETS[dtype] + return preset["rtol"], preset["atol"] + # Default to FP32 tolerances for unknown types + return 1e-5, 1e-5 + + +def get_tolerance_for_dtypes(dtypes: List[torch.dtype]) -> Tuple[float, float]: + """ + Get tolerances that work for a list of dtypes (uses the loosest tolerances). + + Args: + dtypes: List of torch dtypes. + + Returns: + (rtol, atol) tuple with tolerances that accommodate all dtypes. 
+ """ + if not dtypes: + return 1e-5, 1e-5 + + max_rtol = 0.0 + max_atol = 0.0 + for dtype in dtypes: + rtol, atol = get_tolerance_for_dtype(dtype) + max_rtol = max(max_rtol, rtol) + max_atol = max(max_atol, atol) + + return max_rtol, max_atol + + +def torch_dtype_to_bin_dtype(dtype: torch.dtype) -> int: + """Convert torch dtype to binary file dtype enum value.""" + mapping = { + torch.float32: DTYPE_FLOAT32, + torch.float16: DTYPE_FLOAT16, + torch.int32: DTYPE_INT32, + torch.int64: DTYPE_INT64, + torch.bfloat16: DTYPE_BFLOAT16, + torch.bool: DTYPE_BOOL, + } + if dtype not in mapping: + raise ValueError(f"Unsupported dtype: {dtype}") + return mapping[dtype] + + +def bin_dtype_to_torch_dtype(dtype_val: int) -> torch.dtype: + """Convert binary file dtype enum value to torch dtype.""" + mapping = { + DTYPE_FLOAT32: torch.float32, + DTYPE_FLOAT16: torch.float16, + DTYPE_INT32: torch.int32, + DTYPE_INT64: torch.int64, + DTYPE_BFLOAT16: torch.bfloat16, + DTYPE_BOOL: torch.bool, + } + if dtype_val not in mapping: + raise ValueError(f"Unknown dtype value: {dtype_val}") + return mapping[dtype_val] + + +def _atomic_write_binary(path: Path, data: bytes) -> None: + """ + Atomically write binary data to a file. + + Writes to a temporary file in the same directory, then atomically replaces + the target path. This prevents race conditions when multiple parallel + workers write to the same ``op_tests/`` tree. + """ + path.parent.mkdir(parents=True, exist_ok=True) + fd, tmp = tempfile.mkstemp(dir=path.parent, suffix=".tmp") + closed = False + try: + os.write(fd, data) + os.close(fd) + closed = True + os.replace(tmp, path) + except BaseException: + if not closed: + os.close(fd) + if os.path.exists(tmp): + os.unlink(tmp) + raise + + +def save_tensors_to_bin(tensors: List[torch.Tensor], path: Union[str, Path]) -> None: + """ + Save a list of tensors to a binary file. 
+ + Binary format: + - 4 bytes: number of tensors (uint32) + For each tensor: + - 4 bytes: dtype enum (uint32) + - 4 bytes: number of dimensions (uint32) + - 4 bytes * ndim: shape (int32 each) + - N bytes: tensor data + """ + path = Path(path) + + buf = bytearray() + # Write number of tensors + buf += struct.pack("I", len(tensors)) + + for tensor in tensors: + # Ensure contiguous + tensor = tensor.contiguous() + + # Write dtype + dtype_val = torch_dtype_to_bin_dtype(tensor.dtype) + buf += struct.pack("I", dtype_val) + + # Write ndim + buf += struct.pack("I", tensor.dim()) + + # Write shape + for s in tensor.shape: + buf += struct.pack("i", s) + + # Write data - bf16 needs special handling since numpy doesn't support it + if tensor.dtype == torch.bfloat16: + # View bf16 as uint16 to preserve raw bytes + buf += tensor.view(torch.uint16).numpy().tobytes() + else: + buf += tensor.numpy().tobytes() + + _atomic_write_binary(path, bytes(buf)) + + +def load_tensors_from_bin(path: Union[str, Path]) -> List[torch.Tensor]: + path = Path(path) + + # Mapping from torch dtype to numpy dtype + np_dtype_map = { + torch.float32: np.float32, + torch.float16: np.float16, + torch.int32: np.int32, + torch.int64: np.int64, + torch.bool: np.bool_, + # bfloat16 needs special handling - read as uint16 + } + + # Element size for each dtype + elem_size_map = { + torch.float32: 4, + torch.float16: 2, + torch.int32: 4, + torch.int64: 8, + torch.bfloat16: 2, + torch.bool: 1, + } + + tensors = [] + with open(path, "rb") as f: + # Read number of tensors + num_tensors = struct.unpack("I", f.read(4))[0] + + for _ in range(num_tensors): + # Read dtype + dtype_val = struct.unpack("I", f.read(4))[0] + dtype = bin_dtype_to_torch_dtype(dtype_val) + + # Read ndim + ndim = struct.unpack("I", f.read(4))[0] + + # Read shape + shape = [] + for _ in range(ndim): + shape.append(struct.unpack("i", f.read(4))[0]) + + # Read data + numel = 1 + for s in shape: + numel *= s + + elem_size = elem_size_map[dtype] + 
data_bytes = f.read(numel * elem_size)
+
+            # Convert to tensor
+            if dtype == torch.bfloat16:
+                # Read as uint16 and view as bfloat16
+                arr = np.frombuffer(data_bytes, dtype=np.uint16).reshape(shape)
+                tensor = torch.tensor(arr).view(torch.bfloat16)
+            else:
+                arr = np.frombuffer(data_bytes, dtype=np_dtype_map[dtype]).reshape(
+                    shape
+                )
+                tensor = torch.from_numpy(arr.copy())
+
+            tensors.append(tensor)
+
+    return tensors
+
+
+def export_model_to_pte(
+    model: torch.nn.Module,
+    example_inputs: Tuple[torch.Tensor, ...],
+    output_path: Union[str, Path],
+    dynamic_shapes: Optional[Dict] = None,
+    verbose: bool = False,
+) -> None:
+    """
+    Export a PyTorch model to a .pte file using the MLX delegate.
+
+    Args:
+        model: The PyTorch model to export.
+        example_inputs: Example inputs for tracing.
+        output_path: Path to save the .pte file.
+        dynamic_shapes: Optional dynamic shapes specification for
+            torch.export. Example: {0: {0: Dim("batch", min=1, max=32)}}
+            for dynamic batch on first input.
+        verbose: Whether to print the exported program for debugging.
+ """ + import executorch.exir as exir + from executorch.backends.mlx import MLXPartitioner + from executorch.exir.capture._config import ExecutorchBackendConfig + from torch.export import export + + model = model.eval() + + # Export with torch.export + exported_program = export( + model, example_inputs, dynamic_shapes=dynamic_shapes, strict=True + ) + + # Print exported program if verbose + if verbose: + print("\n" + "=" * 60) + print("EXPORTED PROGRAM (torch.export)") + print("=" * 60) + print(exported_program) + + # Lower to edge and delegate to MLX + edge_program = exir.to_edge_transform_and_lower( + exported_program, + partitioner=[MLXPartitioner()], + ) + + # Print edge program if verbose + if verbose: + print("\n" + "=" * 60) + print("EDGE PROGRAM (after decomposition)") + print("=" * 60) + print(edge_program.exported_program()) + + # Export to ExecuTorch + executorch_program = edge_program.to_executorch( + config=ExecutorchBackendConfig(extract_delegate_segments=True) + ) + + # Save to file + output_path = Path(output_path) + _atomic_write_binary(output_path, executorch_program.buffer) + + +def inspect_pte_file(pte_path: Union[str, Path]) -> Dict: + """ + Inspect a PTE file and return the MLX graph information. + + Returns: + Dictionary with MLX graph details + """ + from executorch.backends.mlx.pte_inspector import ( + extract_delegate_payload, + parse_mlx_payload, + ) + + pte_path = Path(pte_path) + pte_data = pte_path.read_bytes() + + # Extract MLX delegate payload + payload = extract_delegate_payload(pte_data, "MLXBackend") + if payload is None: + return {"error": "Could not extract MLX delegate payload"} + + # Parse the MLX payload + mlx_data = parse_mlx_payload(payload) + return mlx_data + + +def print_mlx_graph_summary(pte_path: Union[str, Path]) -> None: + """ + Print a human-readable summary of the MLX graph in a PTE file. + + This function uses the pte_inspector module to display the MLX graph. 
+ """ + from executorch.backends.mlx.pte_inspector import show_mlx_instructions + + pte_path = Path(pte_path) + pte_data = pte_path.read_bytes() + show_mlx_instructions(pte_data) + + +def count_mlx_delegate_segments(pte_path: Union[str, Path]) -> int: + """ + Count the number of MLX delegate segments in a PTE file. + + Args: + pte_path: Path to the .pte file + + Returns: + Number of MLX delegate segments found + """ + from executorch.exir._serialize._flatbuffer import _program_flatbuffer_to_json + + pte_path = Path(pte_path) + pte_data = pte_path.read_bytes() + + try: + program_json = _program_flatbuffer_to_json(pte_data) + program_data = json.loads(program_json) + + # Count all MLX delegates across all execution plans + count = 0 + for plan in program_data.get("execution_plan", []): + for delegate in plan.get("delegates", []): + delegate_name = delegate.get("id", "") + # Match MLXBackend (case-insensitive) + if "mlx" in delegate_name.lower(): + count += 1 + + return count + except Exception as e: + print(f"Error counting MLX segments: {e}") + return 0 + + +def get_mlx_node_counts(pte_path: Union[str, Path]) -> Dict[str, int]: + """ + Get a count of each MLX op node type in a serialized .pte file. + + Args: + pte_path: Path to the .pte file + + Returns: + Dictionary mapping op name (e.g. "SdpaNode", "SliceUpdateNode") to count. + """ + data = inspect_pte_file(pte_path) + graph = data.get("graph", {}) + counts: Dict[str, int] = {} + for chain_info in graph.get("instruction_chains", []): + for instr in chain_info.get("instructions", []): + op_name = instr.get("op_name") + if op_name: + counts[op_name] = counts.get(op_name, 0) + 1 + return counts + + +def compare_outputs( + expected: List[torch.Tensor], + actual: List[torch.Tensor], + rtol: float = 1e-5, + atol: float = 1e-5, +) -> Tuple[bool, str]: + """ + Compare expected and actual outputs using torch.allclose. 
+ + Returns: + (passed, message) tuple + """ + if len(expected) != len(actual): + return ( + False, + f"Output count mismatch: expected {len(expected)}, got {len(actual)}", + ) + + for i, (exp, act) in enumerate(zip(expected, actual)): + if exp.shape != act.shape: + return ( + False, + f"Output {i} shape mismatch: expected {exp.shape}, got {act.shape}", + ) + + if exp.dtype != act.dtype: + # Convert both to float32 for comparison + exp = exp.float() + act = act.float() + + # For bool tensors, use exact comparison + if exp.dtype == torch.bool: + if not torch.equal(exp, act): + mismatches = (exp != act).sum().item() + total = exp.numel() + return False, ( + f"Output {i} values do not match:\n" + f" {mismatches}/{total} elements differ\n" + f" expected[:5]={exp.flatten()[:5].tolist()}\n" + f" actual[:5]={act.flatten()[:5].tolist()}" + ) + elif not torch.allclose(exp, act, rtol=rtol, atol=atol): + diff = (exp - act).abs() + max_diff = diff.max().item() + mean_diff = diff.float().mean().item() + return False, ( + f"Output {i} values do not match:\n" + f" max_diff={max_diff:.6e}, mean_diff={mean_diff:.6e}\n" + f" rtol={rtol}, atol={atol}\n" + f" expected[:5]={exp.flatten()[:5].tolist()}\n" + f" actual[:5]={act.flatten()[:5].tolist()}" + ) + + return True, "All outputs match" + + +def find_executorch_root() -> Path: # noqa: C901 + """Find the executorch root directory.""" + test_dir = Path(__file__).parent + + # Walk up to find the executorch root (has CMakeLists.txt and backends dir at root) + executorch_root = test_dir + for _ in range(10): # Max 10 levels up + if (executorch_root / "CMakeLists.txt").exists() and ( + executorch_root / "backends" + ).exists(): + # Check if we're in src/executorch (editable install) + if ( + executorch_root.name == "executorch" + and executorch_root.parent.name == "src" + ): + executorch_root = executorch_root.parent.parent + break + executorch_root = executorch_root.parent + + # If we didn't find a valid root (e.g. 
running from a pip-installed + # site-packages), fall back to cwd which is typically the repo root. + if not (executorch_root / "CMakeLists.txt").exists(): + cwd = Path.cwd() + if (cwd / "CMakeLists.txt").exists() and (cwd / "backends").exists(): + executorch_root = cwd + + return executorch_root + + +def find_build_dir(): + """Find the cmake build directory containing op_test_runner.""" + executorch_root = find_executorch_root() + + # Check common build locations + candidates = [ + executorch_root / "cmake-out-mlx", + executorch_root / "cmake-out", + executorch_root / "build", + ] + + for candidate in candidates: + runner_path = candidate / "backends" / "mlx" / "test" / "op_test_runner" + if runner_path.exists(): + return candidate + + # Return first candidate that exists as a directory (for rebuild) + for candidate in candidates: + if candidate.is_dir(): + return candidate + + return None + + +def find_op_test_runner() -> Path: + """Find the op_test_runner binary.""" + executorch_root = find_executorch_root() + + # Check common build locations + candidates = [ + executorch_root + / "cmake-out-mlx" + / "backends" + / "mlx" + / "test" + / "op_test_runner", + executorch_root / "cmake-out" / "backends" / "mlx" / "test" / "op_test_runner", + executorch_root / "build" / "backends" / "mlx" / "test" / "op_test_runner", + ] + + for candidate in candidates: + if candidate.exists(): + return candidate + + raise FileNotFoundError( + "Could not find op_test_runner binary. Tried:\n" + + "\n".join(f" - {c}" for c in candidates) + + "\n\nBuild with:\n" + + " cmake --preset mlx-release -DEXECUTORCH_BUILD_TESTS=ON\n" + + " cmake --build cmake-out --target op_test_runner" + ) + + +def rebuild_op_test_runner(verbose: bool = False) -> bool: + """ + Rebuild the op_test_runner binary using cmake. + + Args: + verbose: Whether to print build output. + + Returns: + True if build succeeded, False otherwise. 
+ """ + build_dir = find_build_dir() + if build_dir is None: + print("Error: Could not find cmake build directory.") + print("Make sure you have run cmake configuration first.") + return False + + print(f"Rebuilding op_test_runner in {build_dir}...") + + cmd = ["cmake", "--build", str(build_dir), "--target", "op_test_runner", "-j8"] + + if verbose: + print(f"Running: {' '.join(cmd)}") + + result = subprocess.run( + cmd, + capture_output=not verbose, + text=True, + ) + + if result.returncode != 0: + print(f"Build failed with exit code {result.returncode}") + if not verbose and result.stderr: + print(f"stderr: {result.stderr}") + if not verbose and result.stdout: + print(f"stdout: {result.stdout}") + return False + + print("Build succeeded.") + return True + + +def run_cpp_test_runner( + pte_path: Path, + input_path: Path, + output_path: Path, + verbose: bool = False, + timeout: Optional[int] = None, +) -> bool: + """ + Run the C++ op_test_runner binary. + + Args: + pte_path: Path to the .pte model file. + input_path: Path to input .bin file. + output_path: Path to write output .bin file. + verbose: Whether to print verbose output. + timeout: Timeout in seconds. None means use DEFAULT_TEST_TIMEOUT. + + Returns: + True if execution succeeded, False otherwise. 
+ """ + if timeout is None: + timeout = DEFAULT_TEST_TIMEOUT + + runner = find_op_test_runner() + + cmd = [ + str(runner), + "--pte", + str(pte_path), + "--input", + str(input_path), + "--output", + str(output_path), + ] + if verbose: + cmd.append("--verbose") + + print(f"Running: {' '.join(cmd)}") + try: + result = subprocess.run( + cmd, + capture_output=True, + text=True, + timeout=timeout, + ) + except subprocess.TimeoutExpired: + print(f"TIMEOUT: C++ runner exceeded {timeout}s timeout") + return False + + if result.returncode != 0: + print(f"FAILED: {result.stderr}") + print(f"stdout: {result.stdout}") + return False + + print(f"C++ binary output: {result.stdout.strip()}") + return True + + +# Files that are generated during tests and can be safely cleaned up +GENERATED_TEST_FILES = [ + "model.pte", + "input.bin", + "expected_output.bin", + "actual_output.bin", +] + + +def clean_test_outputs( + test_names: Optional[List[str]] = None, verbose: bool = False +) -> int: + """ + Clean up generated test output files. + + Args: + test_names: Optional list of test names to clean. If None, cleans all tests. + verbose: Whether to print verbose output. + + Returns: + Number of files removed. 
+ """ + test_dir = Path(__file__).parent / "op_tests" + if not test_dir.exists(): + if verbose: + print(f"Test directory does not exist: {test_dir}") + return 0 + + files_removed = 0 + + # Get directories to clean + if test_names: + dirs_to_clean = [ + test_dir / name for name in test_names if (test_dir / name).exists() + ] + else: + dirs_to_clean = [d for d in test_dir.iterdir() if d.is_dir()] + + for subdir in dirs_to_clean: + for filename in GENERATED_TEST_FILES: + filepath = subdir / filename + if filepath.exists(): + if verbose: + print(f"Removing: {filepath}") + filepath.unlink() + files_removed += 1 + + # Remove empty directories + if subdir.exists() and not any(subdir.iterdir()): + if verbose: + print(f"Removing empty directory: {subdir}") + subdir.rmdir() + + return files_removed + + +def get_test_output_size(test_names: Optional[List[str]] = None) -> int: + """ + Get total size of generated test output files in bytes. + + Args: + test_names: Optional list of test names to check. If None, checks all tests. + + Returns: + Total size in bytes. + """ + test_dir = Path(__file__).parent / "op_tests" + if not test_dir.exists(): + return 0 + + total_size = 0 + + # Get directories to check + if test_names: + dirs_to_check = [ + test_dir / name for name in test_names if (test_dir / name).exists() + ] + else: + dirs_to_check = [d for d in test_dir.iterdir() if d.is_dir()] + + for subdir in dirs_to_check: + for filename in GENERATED_TEST_FILES: + filepath = subdir / filename + if filepath.exists(): + total_size += filepath.stat().st_size + + return total_size + + +# Global registry: maps base_name -> (test_class, get_test_configs method) +# Tests are instantiated lazily when actually run, not at import time +_TEST_REGISTRY: Dict[str, type] = {} + + +def register_test(test_class: type) -> type: + """ + Class decorator to register a test class. 
+ + The test class must have: + - A class attribute `name` (str) - the base test name + - A class method `get_test_configs()` that returns a list of OpTestCase instances + + Test instances are created LAZILY when tests are actually run, not at import time. + This avoids creating random tensors at import time and keeps memory usage low. + + Example: + @register_test + class AddTest(OpTestCase): + name = "add" + + @classmethod + def get_test_configs(cls) -> List["OpTestCase"]: + return [ + cls(), # default config + cls(scalar=2.5), # scalar variant + ] + """ + if not hasattr(test_class, "name"): + raise ValueError( + f"Test class {test_class.__name__} must have a 'name' attribute" + ) + + base_name = test_class.name + _TEST_REGISTRY[base_name] = test_class + + return test_class + + +def get_registered_tests() -> Dict[str, List[Tuple[str, "OpTestCase"]]]: + """ + Get all registered tests with their configurations. + + Returns dict mapping base_name -> list of (config_name, test_instance). + Test instances are created fresh each time this is called. + """ + result = {} + for base_name, test_class in _TEST_REGISTRY.items(): + if hasattr(test_class, "get_test_configs"): + configs = test_class.get_test_configs() + else: + configs = [test_class()] + result[base_name] = [(cfg.name, cfg) for cfg in configs] + return result + + +def get_test_names() -> List[str]: + """Get list of registered base test names.""" + return list(_TEST_REGISTRY.keys()) + + +def get_all_test_configs() -> List[Tuple[str, "OpTestCase"]]: + """ + Get flat list of all (config_name, test_instance) tuples. + + Test instances are created fresh each time this is called. + """ + result = [] + for _base_name, test_class in _TEST_REGISTRY.items(): + if hasattr(test_class, "get_test_configs"): + configs = test_class.get_test_configs() + else: + configs = [test_class()] + result.extend((cfg.name, cfg) for cfg in configs) + return result + + +class OpTestCase: + """ + Base class for op test cases. 
+ + Subclasses should implement: + - name: str - test name + - create_model() -> nn.Module + - create_inputs() -> Tuple[torch.Tensor, ...] + + Optionally override: + - get_dynamic_shapes() -> Optional[Dict] - for dynamic shape testing + - create_test_inputs() -> Tuple[torch.Tensor, ...] - inputs for testing (may differ from export inputs) + - expected_mlx_segments: int - expected number of MLX delegate segments (default: 1) + """ + + name: str = "base_test" + rtol: float = 1e-5 + atol: float = 1e-5 + seed: int = 42 # Default seed for reproducibility + timeout: int = DEFAULT_TEST_TIMEOUT # Timeout in seconds + skip_comparison: bool = False # Skip output comparison (for pattern-only tests) + skip_comparison_reason: str = "" # Reason for skipping comparison + expected_mlx_segments: int = 1 # Expected number of MLX delegate segments + expected_node_counts: Optional[Dict[str, int]] = ( + None # Expected serialized node counts + ) + + def _set_seed(self) -> None: + """Set random seed for reproducibility.""" + torch.manual_seed(self.seed) + + def create_model(self) -> torch.nn.Module: + raise NotImplementedError + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + """Create inputs for export (tracing).""" + raise NotImplementedError + + def create_test_inputs(self) -> Tuple[torch.Tensor, ...]: + """Create inputs for testing. Override for dynamic shape tests.""" + return self.create_inputs() + + def get_dynamic_shapes(self) -> Optional[Dict]: + """Return dynamic shapes specification for torch.export, or None for static shapes.""" + return None + + def get_test_dir(self) -> Path: + """Get the directory for this test's files.""" + test_dir = Path(__file__).parent / "op_tests" / self.name + test_dir.mkdir(parents=True, exist_ok=True) + return test_dir + + def generate_test_files(self, verbose: bool = False) -> Tuple[Path, Path, Path]: + """ + Generate .pte, input.bin, and expected_output.bin files. 
+ + Args: + verbose: Whether to print the exported program for debugging. + + Returns: + (pte_path, input_path, expected_output_path) + """ + test_dir = self.get_test_dir() + + pte_path = test_dir / "model.pte" + input_path = test_dir / "input.bin" + expected_path = test_dir / "expected_output.bin" + + # Set seed for reproducibility + self._set_seed() + + # Create model and inputs + model = self.create_model() + export_inputs = self.create_inputs() + + # Set seed again before creating test inputs (in case they differ) + self._set_seed() + test_inputs = self.create_test_inputs() + + # Get expected outputs using test inputs + model.eval() + with torch.no_grad(): + if isinstance(test_inputs, torch.Tensor): + test_inputs = (test_inputs,) + expected_outputs = model(*test_inputs) + if isinstance(expected_outputs, torch.Tensor): + expected_outputs = [expected_outputs] + else: + expected_outputs = list(expected_outputs) + + # Export model with export inputs (and potentially dynamic shapes) + print(f"Exporting model to {pte_path}") + if isinstance(export_inputs, torch.Tensor): + export_inputs = (export_inputs,) + + dynamic_shapes = self.get_dynamic_shapes() + if dynamic_shapes: + print(f" Using dynamic shapes: {dynamic_shapes}") + + export_model_to_pte( + model, + export_inputs, + pte_path, + dynamic_shapes=dynamic_shapes, + verbose=verbose, + ) + + # Save test inputs + print(f"Saving inputs to {input_path}") + if isinstance(test_inputs, torch.Tensor): + test_inputs = [test_inputs] + else: + test_inputs = list(test_inputs) + save_tensors_to_bin(test_inputs, input_path) + + # Save expected outputs + print(f"Saving expected outputs to {expected_path}") + save_tensors_to_bin(expected_outputs, expected_path) + + return pte_path, input_path, expected_path + + def compare_with_actual( + self, actual_output_path: Union[str, Path], use_dtype_tolerances: bool = False + ) -> Tuple[bool, str]: + """ + Compare actual outputs with expected outputs. 
+ + Args: + actual_output_path: Path to the actual output file. + use_dtype_tolerances: If True, uses tolerance presets based on output dtypes + instead of the test's rtol/atol values. + """ + test_dir = self.get_test_dir() + expected_path = test_dir / "expected_output.bin" + + expected = load_tensors_from_bin(expected_path) + actual = load_tensors_from_bin(actual_output_path) + + # Determine tolerances + if use_dtype_tolerances: + # Use dtype-based tolerances (loosest tolerance across all output dtypes) + output_dtypes = [t.dtype for t in expected] + rtol, atol = get_tolerance_for_dtypes(output_dtypes) + else: + rtol, atol = self.rtol, self.atol + + return compare_outputs(expected, actual, rtol=rtol, atol=atol) + + def run_test(self, verbose: bool = False, timeout: Optional[int] = None) -> bool: + """ + Run the full test: generate files, run C++, compare outputs. + + Args: + verbose: Whether to print verbose output. + timeout: Timeout in seconds. None means use self.timeout. + + Returns: + True if test passed, False otherwise. 
+ """ + if timeout is None: + timeout = self.timeout + + print(f"\n{'='*60}") + print(f"Running test: {self.name}") + print(f"{'='*60}\n") + + # Generate test files + print("Step 1: Generating test files...") + pte_path, input_path, expected_path = self.generate_test_files(verbose=verbose) + + # Print MLX graph summary + print_mlx_graph_summary(pte_path) + + # Verify expected number of MLX delegate segments + print("\nStep 2: Verifying MLX delegation...") + actual_segments = count_mlx_delegate_segments(pte_path) + print(f" Expected MLX segments: {self.expected_mlx_segments}") + print(f" Actual MLX segments: {actual_segments}") + + if actual_segments != self.expected_mlx_segments: + print("✗ FAILED: MLX delegation mismatch!") + print( + f" Expected {self.expected_mlx_segments} segment(s), but found {actual_segments}" + ) + return False + print("✓ MLX delegation verified") + + # Verify expected node counts if specified + if self.expected_node_counts is not None: + print("\n Verifying serialized node counts...") + actual_counts = get_mlx_node_counts(pte_path) + for node_name, expected_count in self.expected_node_counts.items(): + actual_count = actual_counts.get(node_name, 0) + if actual_count != expected_count: + print(f"✗ FAILED: Node count mismatch for {node_name}!") + print(f" Expected {expected_count}, got {actual_count}") + print(f" All node counts: {actual_counts}") + return False + print(f" ✓ {node_name}: {actual_count}") + print(" ✓ All node counts verified") + + # Run C++ binary + print("\nStep 3: Running C++ binary...") + actual_path = self.get_test_dir() / "actual_output.bin" + if not run_cpp_test_runner( + pte_path, input_path, actual_path, verbose=verbose, timeout=timeout + ): + return False + + # Compare outputs (or skip if configured) + print("\nStep 4: Comparing outputs...") + if self.skip_comparison: + reason = self.skip_comparison_reason or "skip_comparison=True" + print(f"NOTE: Output comparison skipped ({reason})") + print("✓ PASSED (runtime 
execution succeeded)") + return True + + passed, message = self.compare_with_actual(actual_path) + + if passed: + print(f"✓ PASSED: {message}") + else: + print(f"✗ FAILED: {message}") + + return passed + + +def run_op_test_main( + test_factory, + description: str, + add_args_fn=None, +): + """ + Common main() function for op tests. + + This handles the common argparse setup, rebuild logic, and generate/compare/run + action handling that is shared across all op tests. + + Args: + test_factory: A callable that takes parsed args (argparse.Namespace) and + returns an OpTestCase instance. + description: Description for the argparse help message. + add_args_fn: Optional callable that takes a parser and adds test-specific + arguments. Signature: add_args_fn(parser) -> None + """ + import argparse + import sys + + parser = argparse.ArgumentParser(description=description) + parser.add_argument( + "action", + choices=["generate", "compare", "run"], + help="Action to perform: generate (create test files), compare (compare outputs), run (full test)", + ) + parser.add_argument("--verbose", "-v", action="store_true", help="Verbose output") + parser.add_argument( + "--rebuild", + action="store_true", + help="Rebuild the C++ test runner before running", + ) + + # Add test-specific arguments + if add_args_fn is not None: + add_args_fn(parser) + + args = parser.parse_args() + + # Rebuild if requested + if args.rebuild: + if not rebuild_op_test_runner(verbose=args.verbose): + sys.exit(1) + + # Create test case from factory + test = test_factory(args) + + if args.action == "generate": + pte_path, input_path, expected_path = test.generate_test_files( + verbose=args.verbose + ) + print("\nGenerated files:") + print(f" PTE: {pte_path}") + print(f" Input: {input_path}") + print(f" Expected: {expected_path}") + print_mlx_graph_summary(pte_path) + + elif args.action == "compare": + actual_path = test.get_test_dir() / "actual_output.bin" + if not actual_path.exists(): + print(f"Error: 
{actual_path} not found. Run the C++ binary first.") + sys.exit(1) + + passed, message = test.compare_with_actual(actual_path) + if passed: + print(f"✓ PASSED: {message}") + else: + print(f"✗ FAILED: {message}") + sys.exit(0 if passed else 1) + + elif args.action == "run": + passed = test.run_test(verbose=args.verbose) + sys.exit(0 if passed else 1) diff --git a/backends/mlx/test/tester.py b/backends/mlx/test/tester.py new file mode 100644 index 00000000000..7a929ea7c3b --- /dev/null +++ b/backends/mlx/test/tester.py @@ -0,0 +1,78 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import functools +from typing import Any, List, Optional, Tuple + +import executorch +import executorch.backends.test.harness.stages as BaseStages +import torch + +from executorch.backends.mlx.partitioner import MLXPartitioner +from executorch.backends.test.harness import Tester as TesterBase +from executorch.backends.test.harness.stages import StageType +from executorch.exir import EdgeCompileConfig +from executorch.exir.backend.backend_details import CompileSpec +from executorch.exir.backend.partitioner import Partitioner + + +def _create_default_partitioner( + compile_specs: List[CompileSpec] | None = None, +) -> MLXPartitioner: + return MLXPartitioner(compile_specs=compile_specs) + + +class Partition(BaseStages.Partition): + def __init__( + self, + partitioner: Optional[Partitioner] = None, + compile_specs: Optional[List[CompileSpec]] = None, + ): + super().__init__( + partitioner=partitioner or _create_default_partitioner(compile_specs), + ) + + +class ToEdgeTransformAndLower(BaseStages.ToEdgeTransformAndLower): + def __init__( + self, + partitioners: Optional[List[Partitioner]] = None, + edge_compile_config: Optional[EdgeCompileConfig] = None, + compile_specs: Optional[List[CompileSpec]] = None, + ): + 
super().__init__( + default_partitioner_cls=lambda: _create_default_partitioner(compile_specs), + partitioners=partitioners, + edge_compile_config=edge_compile_config, + ) + + +class MLXTester(TesterBase): + def __init__( + self, + module: torch.nn.Module, + example_inputs: Tuple[torch.Tensor], + dynamic_shapes: Optional[Tuple[Any]] = None, + compile_specs: Optional[List[CompileSpec]] = None, + ): + stage_classes = ( + executorch.backends.test.harness.Tester.default_stage_classes() + | { + StageType.PARTITION: functools.partial( + Partition, compile_specs=compile_specs + ), + StageType.TO_EDGE_TRANSFORM_AND_LOWER: functools.partial( + ToEdgeTransformAndLower, compile_specs=compile_specs + ), + } + ) + + super().__init__( + module=module, + stage_classes=stage_classes, + example_inputs=example_inputs, + dynamic_shapes=dynamic_shapes, + ) diff --git a/backends/mlx/third-party/mlx b/backends/mlx/third-party/mlx new file mode 160000 index 00000000000..72e94c81e16 --- /dev/null +++ b/backends/mlx/third-party/mlx @@ -0,0 +1 @@ +Subproject commit 72e94c81e1685c90679ef03532c4b8897010abf9 diff --git a/backends/test/suite/flow.py b/backends/test/suite/flow.py index f3c9ee75083..c9142581f33 100644 --- a/backends/test/suite/flow.py +++ b/backends/test/suite/flow.py @@ -53,7 +53,7 @@ def __str__(self): return self.name -def all_flows() -> dict[str, TestFlow]: +def all_flows() -> dict[str, TestFlow]: # noqa: C901 flows = [] from executorch.backends.test.suite.flows.portable import PORTABLE_TEST_FLOW @@ -147,4 +147,13 @@ def all_flows() -> dict[str, TestFlow]: except Exception as e: logger.info(f"Skipping ARM flow registration: {e}") + try: + from executorch.backends.test.suite.flows.mlx import MLX_TEST_FLOW + + flows += [ + MLX_TEST_FLOW, + ] + except Exception as e: + logger.info(f"Skipping MLX flow registration: {e}") + return {f.name: f for f in flows if f is not None} diff --git a/backends/test/suite/flows/mlx.py b/backends/test/suite/flows/mlx.py new file mode 100644 index 
00000000000..d70db46b73c --- /dev/null +++ b/backends/test/suite/flows/mlx.py @@ -0,0 +1,14 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from executorch.backends.mlx.test.tester import MLXTester +from executorch.backends.test.suite.flow import TestFlow + +MLX_TEST_FLOW = TestFlow( + name="mlx", + backend="mlx", + tester_factory=MLXTester, +) diff --git a/exir/_serialize/_program.py b/exir/_serialize/_program.py index be7bf0bd56f..e2c31f0c5fc 100644 --- a/exir/_serialize/_program.py +++ b/exir/_serialize/_program.py @@ -765,3 +765,70 @@ def deserialize_pte_binary(program_data: bytes) -> PTEFile: ) return PTEFile(program=program, mutable_data=None, named_data=None) + + +def _extract_delegate_payload( + pte_data: bytes, backend_id: str, delegate_index: int = 0 +) -> Optional[bytes]: + """Extract a delegate payload from a serialized PTE file. + + Parses the PTE file structure, finds the delegate matching the given + backend ID, and returns its raw payload bytes. Handles both inline + delegate data and segment-based storage. + + Args: + pte_data: Raw bytes of the PTE file. + backend_id: ID substring to match (case-insensitive). + For example, 'mlx' matches 'MLXBackend'. + delegate_index: Which matching delegate to extract (0-based). + Defaults to 0 (first match). + + Returns: + Delegate payload bytes, or None if not found. 
+ """ + # Parse the extended header + extended_header = _get_extended_header(pte_data) + + # Determine program size from header or use full data + if extended_header is not None: + program_size = extended_header.program_size + else: + program_size = len(pte_data) + + # Parse the program flatbuffer + program: Program = _json_to_program( + _program_flatbuffer_to_json(pte_data[:program_size]) + ) + + # Search for the matching delegate + match_count = 0 + for plan in program.execution_plan: + for delegate in plan.delegates: + if backend_id.lower() not in delegate.id.lower(): + continue + if match_count != delegate_index: + match_count += 1 + continue + + processed = delegate.processed + + # Inline data + if processed.location == DataLocation.INLINE: + inline_data = program.backend_delegate_data[processed.index] + if inline_data.data: + return bytes(inline_data.data) + return None + + # Segment data + if processed.location == DataLocation.SEGMENT: + if extended_header is None: + return None + + segment = program.segments[processed.index] + offset = extended_header.segment_base_offset + segment.offset + size = segment.size + return pte_data[offset : offset + size] + + return None + + return None diff --git a/setup.py b/setup.py index f05951012e3..d07736128c8 100644 --- a/setup.py +++ b/setup.py @@ -624,6 +624,26 @@ def run(self): # the input file is read-only. self.copy_file(src, dst, preserve_mode=False) + # Copy CMake-generated Python directories that setuptools missed. + # Setuptools discovers packages at configuration time, before CMake + # runs. Directories created by CMake during the build (e.g. by + # generate.py) are not in the package list and must be copied manually. 
+ generated_dirs = [ + "backends/mlx/serialization/_generated", + ] + for rel_dir in generated_dirs: + src_dir = os.path.join("src/executorch", rel_dir) + if not os.path.isdir(src_dir): + continue + dst_dir = os.path.join(dst_root, rel_dir) + for dirpath, _dirnames, filenames in os.walk(src_dir): + for filename in filenames: + src_file = os.path.join(dirpath, filename) + rel_path = os.path.relpath(src_file, src_dir) + dst_file = os.path.join(dst_dir, rel_path) + self.mkpath(os.path.dirname(dst_file)) + self.copy_file(src_file, dst_file, preserve_mode=False) + class Buck2EnvironmentFixer(contextlib.AbstractContextManager): """Removes HOME from the environment when running as root. @@ -786,6 +806,9 @@ def run(self): # noqa C901 if cmake_cache.is_enabled("EXECUTORCH_BUILD_COREML"): cmake_build_args += ["--target", "executorchcoreml"] + if cmake_cache.is_enabled("EXECUTORCH_BUILD_MLX"): + cmake_build_args += ["--target", "mlxdelegate"] + if cmake_cache.is_enabled("EXECUTORCH_BUILD_KERNELS_LLM_AOT"): cmake_build_args += ["--target", "custom_ops_aot_lib"] cmake_build_args += ["--target", "quantized_ops_aot_lib"] @@ -846,6 +869,16 @@ def run(self): # noqa C901 modpath="executorch.extension.pybindings.data_loader", dependent_cmake_flags=["EXECUTORCH_BUILD_PYBIND"], ), + # MLX metallib (Metal GPU kernels) must be colocated with _portable_lib.so + # because MLX uses dladdr() to find the directory containing the library, + # then looks for mlx.metallib in that directory at runtime. + # After submodule migration, the path is backends/mlx/mlx/... 
+ BuiltFile( + src_dir="%CMAKE_CACHE_DIR%/backends/mlx/mlx/mlx/backend/metal/kernels/", + src_name="mlx.metallib", + dst="executorch/extension/pybindings/", + dependent_cmake_flags=["EXECUTORCH_BUILD_MLX"], + ), BuiltExtension( src="extension/training/_training_lib.*", # @lint-ignore https://github.com/pytorch/executorch/blob/cb3eba0d7f630bc8cec0a9cc1df8ae2f17af3f7a/scripts/lint_xrefs.sh modpath="executorch.extension.training.pybindings._training_lib", diff --git a/tools/cmake/Utils.cmake b/tools/cmake/Utils.cmake index 74f2be78804..3295036663c 100644 --- a/tools/cmake/Utils.cmake +++ b/tools/cmake/Utils.cmake @@ -178,3 +178,36 @@ function(executorch_add_prefix_to_public_headers targetName prefix) TARGET "${targetName}" PROPERTY PUBLIC_HEADER ${FIXED_PUBLIC_HEADERS} ) endfunction() + +# ----------------------------------------------------------------------------- +# MLX metallib distribution helper +# ----------------------------------------------------------------------------- +# Copies mlx.metallib next to the target executable so MLX can find it at +# runtime. +# +# MLX uses dladdr() to find the directory containing the binary with MLX code, +# then looks for mlx.metallib in that directory. When MLX is statically linked +# into an executable or shared library, this function ensures the metallib is +# colocated with that binary. +# +# Usage: executorch_target_copy_mlx_metallib(my_executable) +# +function(executorch_target_copy_mlx_metallib target) + if(EXECUTORCH_BUILD_MLX) + if(DEFINED MLX_METALLIB_PATH AND EXISTS "${MLX_METALLIB_PATH}") + add_custom_command( + TARGET ${target} + POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy_if_different "${MLX_METALLIB_PATH}" + "$/mlx.metallib" + COMMENT "Copying mlx.metallib for ${target}" + ) + elseif(DEFINED MLX_METALLIB_PATH) + message( + WARNING + "MLX_METALLIB_PATH is set to ${MLX_METALLIB_PATH} but file does not exist. " + "metallib will not be copied for ${target}." 
+ ) + endif() + endif() +endfunction() diff --git a/tools/cmake/executorch-config.cmake b/tools/cmake/executorch-config.cmake index dc4d34d8701..524e1be36ec 100644 --- a/tools/cmake/executorch-config.cmake +++ b/tools/cmake/executorch-config.cmake @@ -63,6 +63,8 @@ set(optional_lib_list coreml_inmemoryfs coremldelegate mpsdelegate + mlxdelegate + mlx metal_backend neuron_backend qnn_executorch_backend @@ -118,3 +120,46 @@ set_property( TARGET executorch_core PROPERTY INTERFACE_LINK_LIBRARIES ${FIXED_EXECUTORCH_CORE_LINK_LIBRARIES} ) + +# Expose MLX library and metallib path for downstream consumers +if(TARGET mlxdelegate) + # Create imported target for mlx library if not already defined (mlx is built + # by MLX's CMake but we need to expose it for linking) + if(NOT TARGET mlx) + find_library( + _mlx_library mlx + HINTS "${_root}/lib" + CMAKE_FIND_ROOT_PATH_BOTH + ) + if(_mlx_library) + add_library(mlx STATIC IMPORTED) + set_target_properties(mlx PROPERTIES IMPORTED_LOCATION "${_mlx_library}") + # MLX requires Metal and Foundation frameworks on Apple platforms + if(APPLE) + find_library(METAL_FRAMEWORK Metal) + find_library(FOUNDATION_FRAMEWORK Foundation) + if(METAL_FRAMEWORK AND FOUNDATION_FRAMEWORK) + set_target_properties( + mlx PROPERTIES INTERFACE_LINK_LIBRARIES + "${METAL_FRAMEWORK};${FOUNDATION_FRAMEWORK}" + ) + endif() + endif() + message(STATUS "Found mlx library at: ${_mlx_library}") + endif() + endif() + + # Find metallib for runtime distribution + find_file( + _mlx_metallib mlx.metallib + HINTS "${_root}/lib" + CMAKE_FIND_ROOT_PATH_BOTH + ) + if(_mlx_metallib) + set(MLX_METALLIB_PATH + "${_mlx_metallib}" + CACHE FILEPATH "Path to mlx.metallib for runtime distribution" + ) + message(STATUS "Found mlx.metallib at: ${MLX_METALLIB_PATH}") + endif() +endif() diff --git a/tools/cmake/preset/default.cmake b/tools/cmake/preset/default.cmake index 1caf8ea9602..9280d0db915 100644 --- a/tools/cmake/preset/default.cmake +++ b/tools/cmake/preset/default.cmake @@ 
-121,6 +121,7 @@ define_overridable_option( EXECUTORCH_BUILD_EXTENSION_APPLE "Build the Apple extension" BOOL OFF ) define_overridable_option(EXECUTORCH_BUILD_MPS "Build the MPS backend" BOOL OFF) +define_overridable_option(EXECUTORCH_BUILD_MLX "Build the MLX backend" BOOL OFF) define_overridable_option( EXECUTORCH_BUILD_NEURON "Build the backends/mediatek directory" BOOL OFF ) diff --git a/tools/cmake/preset/pybind.cmake b/tools/cmake/preset/pybind.cmake index 699a7c50358..dc60dc7d820 100644 --- a/tools/cmake/preset/pybind.cmake +++ b/tools/cmake/preset/pybind.cmake @@ -31,6 +31,24 @@ if(CMAKE_SYSTEM_NAME STREQUAL "Darwin") set_overridable_option(EXECUTORCH_BUILD_EXTENSION_TRAINING ON) set_overridable_option(EXECUTORCH_BUILD_EXTENSION_LLM_RUNNER ON) set_overridable_option(EXECUTORCH_BUILD_EXTENSION_LLM ON) + # MLX requires Apple Silicon (ARM64) and the Metal compiler (xcrun -sdk macosx + # metal) which is only available with Xcode, not Command Line Tools + if(CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64") + execute_process( + COMMAND xcrun -sdk macosx --find metal + RESULT_VARIABLE _metal_compiler_result + OUTPUT_QUIET ERROR_QUIET + ) + if(_metal_compiler_result EQUAL 0) + set_overridable_option(EXECUTORCH_BUILD_MLX ON) + set_overridable_option(ET_MLX_ENABLE_OP_LOGGING ON) + else() + message( + STATUS + "Metal compiler not found, disabling MLX backend. Install Xcode to enable MLX." 
+ ) + endif() + endif() elseif(CMAKE_SYSTEM_NAME STREQUAL "Linux") set_overridable_option(EXECUTORCH_BUILD_COREML ON) set_overridable_option(EXECUTORCH_BUILD_EXTENSION_TRAINING ON) From 0f03f2b8032289b997125356c44727c6d587a052 Mon Sep 17 00:00:00 2001 From: Scott Roy Date: Mon, 2 Mar 2026 18:06:21 -0800 Subject: [PATCH 2/9] up --- tools/cmake/preset/mlx.cmake | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) create mode 100644 tools/cmake/preset/mlx.cmake diff --git a/tools/cmake/preset/mlx.cmake b/tools/cmake/preset/mlx.cmake new file mode 100644 index 00000000000..d8ea7fe237f --- /dev/null +++ b/tools/cmake/preset/mlx.cmake @@ -0,0 +1,21 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# MLX delegate preset - builds ExecuTorch with MLX backend for Apple Silicon + +# Core ExecuTorch options +set_overridable_option(EXECUTORCH_BUILD_EXTENSION_DATA_LOADER ON) +set_overridable_option(EXECUTORCH_BUILD_EXTENSION_FLAT_TENSOR ON) +set_overridable_option(EXECUTORCH_BUILD_EXTENSION_MODULE ON) +set_overridable_option(EXECUTORCH_BUILD_EXTENSION_TENSOR ON) +set_overridable_option(EXECUTORCH_BUILD_EXTENSION_NAMED_DATA_MAP ON) +set_overridable_option(EXECUTORCH_BUILD_EXTENSION_LLM ON) +set_overridable_option(EXECUTORCH_BUILD_EXTENSION_LLM_RUNNER ON) +set_overridable_option(EXECUTORCH_BUILD_EXTENSION_ASR_RUNNER ON) +set_overridable_option(EXECUTORCH_BUILD_KERNELS_OPTIMIZED ON) + +# Build the MLX delegate +set_overridable_option(EXECUTORCH_BUILD_MLX ON) From bf3673208f279bee45fa8f261e6c5fd9da3b7397 Mon Sep 17 00:00:00 2001 From: Scott Roy Date: Mon, 2 Mar 2026 18:21:54 -0800 Subject: [PATCH 3/9] up --- backends/mlx/runtime/MLXInterpreter.h | 8 ++++++++ backends/mlx/serialization/schema.fbs | 6 ++++++ 2 files changed, 14 insertions(+) diff --git a/backends/mlx/runtime/MLXInterpreter.h 
b/backends/mlx/runtime/MLXInterpreter.h index f3b6e9b720f..bfd593c162b 100644 --- a/backends/mlx/runtime/MLXInterpreter.h +++ b/backends/mlx/runtime/MLXInterpreter.h @@ -98,6 +98,11 @@ inline std::vector infer_shape_with_minus_one( inline void exec_noop(const NoopNode&, ExecutionState&, StreamOrDevice) {} +inline void +exec_id_copy(const IdCopyNode& n, ExecutionState& st, StreamOrDevice) { + st.set_tensor(n.out, st.const_tensor_ref(n.x)); +} + inline void exec_addmm(const AddmmNode& n, ExecutionState& st, StreamOrDevice s) { const auto& mat1 = st.const_tensor_ref(n.mat1); @@ -154,6 +159,9 @@ class Interpreter { case OpCode::NOOP: ops::exec_noop(std::get(instr.node), st, s); break; + case OpCode::ID_COPY: + ops::exec_id_copy(std::get(instr.node), st, s); + break; case OpCode::ADDMM: ops::exec_addmm(std::get(instr.node), st, s); break; diff --git a/backends/mlx/serialization/schema.fbs b/backends/mlx/serialization/schema.fbs index 945186ebef8..8b159314760 100644 --- a/backends/mlx/serialization/schema.fbs +++ b/backends/mlx/serialization/schema.fbs @@ -72,6 +72,11 @@ table IntOrVidOrTid { table NoopNode {} +table IdCopyNode { + x: Tid (required); + out: Tid (required); +} + table AddmmNode { mat1: Tid (required); // First matrix mat2: Tid (required); // Second matrix @@ -89,6 +94,7 @@ table AddmmNode { // Reordering or removing members changes numeric type IDs and breaks existing .pte files. 
union OpNode { NoopNode, + IdCopyNode, AddmmNode // BC: Add new op nodes here (append only) } From 6a2d4556701c5db003df99c0d6de95f5a50db3ea Mon Sep 17 00:00:00 2001 From: Scott Roy Date: Mon, 2 Mar 2026 18:50:55 -0800 Subject: [PATCH 4/9] up --- backends/mlx/ops.py | 6 +++ backends/mlx/test/test_ops.py | 91 ----------------------------------- 2 files changed, 6 insertions(+), 91 deletions(-) diff --git a/backends/mlx/ops.py b/backends/mlx/ops.py index 6e8516e86b1..4c9e0d6f796 100644 --- a/backends/mlx/ops.py +++ b/backends/mlx/ops.py @@ -219,6 +219,12 @@ def normalize_reduction_dim( return dim, keepdim +@REGISTRY.register(target=["NOOP", torch.ops.aten._assert_scalar.default]) +def _noop_handler(P: MLXProgramBuilder, n: Node) -> None: + """No-op handler for nodes that don't emit any MLX instructions.""" + return None + + @REGISTRY.register(target=[torch.ops.aten.addmm.default]) def _addmm_handler(P: MLXProgramBuilder, n: Node) -> Slot: """Handle addmm: self + (mat1 @ mat2). diff --git a/backends/mlx/test/test_ops.py b/backends/mlx/test/test_ops.py index 01286f75f16..0ba98b532ad 100644 --- a/backends/mlx/test/test_ops.py +++ b/backends/mlx/test/test_ops.py @@ -83,94 +83,3 @@ def create_model(self) -> nn.Module: def create_inputs(self) -> Tuple[torch.Tensor, ...]: x = torch.randn(self.batch_size, self.n, self.m) return (x,) - - -class AddmmModel(nn.Module): - """Model that performs addmm: bias + (mat1 @ mat2).""" - - def __init__( - self, - in_features: int, - out_features: int, - bias: bool = True, - alpha: float = 1.0, - beta: float = 1.0, - ): - super().__init__() - self.weight = nn.Parameter(torch.randn(out_features, in_features)) - if bias: - self.bias = nn.Parameter(torch.randn(out_features)) - else: - self.bias = None - self.alpha = alpha - self.beta = beta - - def forward(self, x: torch.Tensor) -> torch.Tensor: - if self.bias is not None: - return torch.addmm( - self.bias, x, self.weight.t(), beta=self.beta, alpha=self.alpha - ) - else: - return torch.mm(x, 
self.weight.t()) - - -@register_test -class AddmmTest(OpTestCase): - """Test case for addmm.""" - - name = "addmm" - rtol = 1e-4 - atol = 1e-4 - - def __init__( - self, - batch_size: int = 2, - in_features: int = 64, - out_features: int = 32, - bias: bool = True, - alpha: float = 1.0, - beta: float = 1.0, - ): - self.batch_size = batch_size - self.in_features = in_features - self.out_features = out_features - self.bias = bias - self.alpha = alpha - self.beta = beta - - # Build unique test name - if not bias: - name = f"addmm_{in_features}x{out_features}_no_bias" - elif alpha != 1.0 or beta != 1.0: - name = f"addmm_{in_features}x{out_features}_a{alpha}_b{beta}" - else: - name = f"addmm_{in_features}x{out_features}" - self.name = name - - @classmethod - def get_test_configs(cls) -> List["AddmmTest"]: - return [ - cls( - batch_size=2, in_features=64, out_features=32 - ), # with bias, default alpha/beta - cls( - batch_size=2, in_features=64, out_features=32, bias=False - ), # without bias - cls(batch_size=4, in_features=128, out_features=64), # larger size - cls( - batch_size=2, in_features=64, out_features=32, alpha=2.0, beta=0.5 - ), # custom alpha/beta - ] - - def create_model(self) -> nn.Module: - return AddmmModel( - self.in_features, - self.out_features, - bias=self.bias, - alpha=self.alpha, - beta=self.beta, - ) - - def create_inputs(self) -> Tuple[torch.Tensor, ...]: - x = torch.randn(self.batch_size, self.in_features) - return (x,) From 493d9ea5e76525ccbdc66b39dca624ad29c67a91 Mon Sep 17 00:00:00 2001 From: Scott Roy Date: Wed, 4 Mar 2026 21:32:00 -0800 Subject: [PATCH 5/9] up --- backends/mlx/builder/program_builder.py | 19 +++++++++++++----- backends/mlx/runtime/MLXBackend.cpp | 26 ++++++++++++++++++++++--- backends/mlx/runtime/MLXExecutor.h | 20 ++++++++++++++++++- backends/mlx/test/test_utils.py | 10 +++++++++- 4 files changed, 65 insertions(+), 10 deletions(-) diff --git a/backends/mlx/builder/program_builder.py b/backends/mlx/builder/program_builder.py 
index 60d5ebbdbfe..2add4f1b7a3 100644 --- a/backends/mlx/builder/program_builder.py +++ b/backends/mlx/builder/program_builder.py @@ -27,7 +27,6 @@ from typing import Any, DefaultDict, Dict, List, Optional, Set, Tuple, Union import torch - from executorch.backends.mlx._logging import logger from executorch.backends.mlx.builder.op_helpers import torch_dtype_to_scalar_type from executorch.backends.mlx.builder.op_registry import ( @@ -132,7 +131,9 @@ class MLXProgramBuilder: def __init__(self, ep: ExportedProgram, named_data_key_prefix: str = ""): self.ep: ExportedProgram = ep - self._instrs: List[Instruction] = [] + self._chains: List[List[Instruction]] = [[]] # chain 0 = main + self._current_chain: int = 0 + self.init_chain_idx: int = -1 self.extra_constants: Dict[str, torch.Tensor] = {} self.slot_manager = SlotManager() self.node_info: DefaultDict[Node, NodeInfo] = defaultdict(NodeInfo) @@ -163,7 +164,13 @@ def _prefix_key(self, name: str) -> str: return name def emit(self, op: OpNodeUnion) -> None: - self._instrs.append(Instruction(op=op)) + self._chains[self._current_chain].append(Instruction(op=op)) + + def emit_init(self, op: OpNodeUnion) -> None: + if self.init_chain_idx == -1: + self.init_chain_idx = len(self._chains) + self._chains.append([]) + self._chains[self.init_chain_idx].append(Instruction(op=op)) def args(self, node: Node) -> Tuple[Any, ...]: return self.slot_map(node.args) @@ -934,9 +941,11 @@ def _build_mlx_graph(self) -> MLXGraph: num_mutable_buffer_tensors=num_tensors[IdSpace.MutableBuffer], num_temp_tensors=num_temp_tensors, num_values=num_values_count, - instruction_chains=[InstructionChain(instructions=self._instrs)], + instruction_chains=[ + InstructionChain(instructions=chain) for chain in self._chains + ], main_chain_idx=0, - init_chain_idx=-1, + init_chain_idx=self.init_chain_idx, input_map=input_map, output_map=output_map, mutable_buffer_map=mutable_buffer_map, diff --git a/backends/mlx/runtime/MLXBackend.cpp 
b/backends/mlx/runtime/MLXBackend.cpp index 38dff189935..99e20114ea7 100644 --- a/backends/mlx/runtime/MLXBackend.cpp +++ b/backends/mlx/runtime/MLXBackend.cpp @@ -219,10 +219,24 @@ class MLXBackend final : public ::executorch::runtime::BackendInterface { static_cast(processed->data()), processed->size()); // Validate schema version - if (handle->program.version != "1") { + int schema_version = 1; + if (!handle->program.version.empty()) { + try { + schema_version = std::stoi(handle->program.version); + } catch (...) { + throw std::runtime_error( + "Invalid MLX schema version '" + handle->program.version + + "' (expected integer)"); + } + } + constexpr int kMaxSupportedVersion = 1; + if (schema_version > kMaxSupportedVersion) { throw std::runtime_error( - "Unsupported MLX schema version '" + handle->program.version + - "' (expected '1'). Rebuild the .pte with a matching SDK version."); + "This .pte requires ExecuTorch MLX runtime version " + + std::to_string(schema_version) + + " but this runtime only supports up to version " + + std::to_string(kMaxSupportedVersion) + + ". Upgrade ExecuTorch to a newer version."); } // Load constants from named_data_map @@ -251,11 +265,17 @@ class MLXBackend final : public ::executorch::runtime::BackendInterface { // SAFETY: The >= 0 check ensures init_chain_idx is non-negative, so the // static_cast cannot produce UINT32_MAX from a -1 sentinel. if (handle->program.init_chain_idx >= 0) { + handle->state.is_init_chain = true; handle->interpreter.run_chain( handle->program, static_cast(handle->program.init_chain_idx), handle->state, handle->stream); + handle->state.is_init_chain = false; + + // Evaluate any constants written by the init chain so the first + // execute() doesn't pay the cost of materializing them. 
+ eval(handle->constants.tensors); } } catch (const std::exception& e) { diff --git a/backends/mlx/runtime/MLXExecutor.h b/backends/mlx/runtime/MLXExecutor.h index 32d623790ab..978eaadabba 100644 --- a/backends/mlx/runtime/MLXExecutor.h +++ b/backends/mlx/runtime/MLXExecutor.h @@ -97,6 +97,13 @@ struct ConstantData { return tensors[id.idx]; } + inline void set(Tid id, Tensor t) { + if (id.idx >= tensors.size()) { + throw std::out_of_range("ConstantData::set: id out of range"); + } + tensors[id.idx] = std::move(t); + } + inline void add(Tensor t) { tensors.push_back(std::move(t)); } @@ -153,6 +160,9 @@ struct ExecutionState { // Non-constant values (SymInt, etc.) std::vector> values; + // Init chain flag: when true, set_tensor allows writing to constants + bool is_init_chain{false}; + // Logging context size_t current_op_idx{0}; const char* current_op_name{nullptr}; @@ -478,7 +488,15 @@ struct ExecutionState { throw std::runtime_error("set_tensor: Program not bound"); } if (id.idx < program->num_constant_tensors) { - throw std::runtime_error("set_tensor: cannot write to constant tensor"); + if (!is_init_chain) { + throw std::runtime_error("set_tensor: cannot write to constant tensor"); + } + // Init chain can write over constants + if (!constants) { + throw std::runtime_error("set_tensor: constants not bound"); + } + const_cast(constants)->set(id, std::move(arr)); + return; } // Route to mutable buffers or per-execution tensors if (is_mutable_buffer(id)) { diff --git a/backends/mlx/test/test_utils.py b/backends/mlx/test/test_utils.py index 090bceabf08..660968195b7 100644 --- a/backends/mlx/test/test_utils.py +++ b/backends/mlx/test/test_utils.py @@ -23,6 +23,7 @@ from pathlib import Path from typing import Dict, List, Optional, Tuple, Union +import executorch.exir as exir import numpy as np import torch @@ -268,6 +269,7 @@ def export_model_to_pte( output_path: Union[str, Path], dynamic_shapes: Optional[Dict] = None, verbose: bool = False, + edge_compile_config: 
Optional[exir.EdgeCompileConfig] = None, ) -> None: """ Export a PyTorch model to a .pte file using the MLX delegate. @@ -281,7 +283,6 @@ def export_model_to_pte( Example: {0: {0: Dim("batch", min=1, max=32)}} for dynamic batch on first input. verbose: Whether to print the exported program for debugging. """ - import executorch.exir as exir from executorch.backends.mlx import MLXPartitioner from executorch.exir.capture._config import ExecutorchBackendConfig from torch.export import export @@ -301,9 +302,11 @@ def export_model_to_pte( print(exported_program) # Lower to edge and delegate to MLX + compile_config = edge_compile_config or exir.EdgeCompileConfig() edge_program = exir.to_edge_transform_and_lower( exported_program, partitioner=[MLXPartitioner()], + compile_config=compile_config, ) # Print edge program if verbose @@ -865,6 +868,10 @@ def get_dynamic_shapes(self) -> Optional[Dict]: """Return dynamic shapes specification for torch.export, or None for static shapes.""" return None + def get_edge_compile_config(self) -> Optional[exir.EdgeCompileConfig]: + """Return EdgeCompileConfig for export, or None for default.""" + return None + def get_test_dir(self) -> Path: """Get the directory for this test's files.""" test_dir = Path(__file__).parent / "op_tests" / self.name @@ -924,6 +931,7 @@ def generate_test_files(self, verbose: bool = False) -> Tuple[Path, Path, Path]: pte_path, dynamic_shapes=dynamic_shapes, verbose=verbose, + edge_compile_config=self.get_edge_compile_config(), ) # Save test inputs From 5ee8ac41c1be9f4a496c3e7f890fc75caacab7c1 Mon Sep 17 00:00:00 2001 From: Scott Roy Date: Wed, 4 Mar 2026 21:33:28 -0800 Subject: [PATCH 6/9] up --- backends/mlx/third-party/mlx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backends/mlx/third-party/mlx b/backends/mlx/third-party/mlx index 72e94c81e16..365d6f29b47 160000 --- a/backends/mlx/third-party/mlx +++ b/backends/mlx/third-party/mlx @@ -1 +1 @@ -Subproject commit 
72e94c81e1685c90679ef03532c4b8897010abf9 +Subproject commit 365d6f29b47686a9f5401f6a9ec5825fee162d69 From 0df21d99285cb1e97c8819bb04ae789e0d1c8c4e Mon Sep 17 00:00:00 2001 From: Scott Roy Date: Wed, 4 Mar 2026 21:44:36 -0800 Subject: [PATCH 7/9] up --- .github/workflows/mlx.yml | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/.github/workflows/mlx.yml b/.github/workflows/mlx.yml index 2e8ca7aa3b7..ea0bce96e1a 100644 --- a/.github/workflows/mlx.yml +++ b/.github/workflows/mlx.yml @@ -9,11 +9,13 @@ on: paths: - .github/workflows/mlx.yml - backends/mlx/** + - extension/llm/export/** + - extension/audio/** + - examples/models/parakeet/** + - examples/models/voxtral_realtime/** workflow_dispatch: -concurrency: - group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }} - cancel-in-progress: true +permissions: {} jobs: test-mlx: From 93afd3e1ac787d5deecb1179df9cb5bac266e89e Mon Sep 17 00:00:00 2001 From: Scott Roy Date: Wed, 4 Mar 2026 21:50:26 -0800 Subject: [PATCH 8/9] up --- .github/workflows/mlx.yml | 4 ---- 1 file changed, 4 deletions(-) diff --git a/.github/workflows/mlx.yml b/.github/workflows/mlx.yml index ea0bce96e1a..cc83c90e23e 100644 --- a/.github/workflows/mlx.yml +++ b/.github/workflows/mlx.yml @@ -9,10 +9,6 @@ on: paths: - .github/workflows/mlx.yml - backends/mlx/** - - extension/llm/export/** - - extension/audio/** - - examples/models/parakeet/** - - examples/models/voxtral_realtime/** workflow_dispatch: permissions: {} From 0adbe8c752f49911d541038bada8b73259b95752 Mon Sep 17 00:00:00 2001 From: Scott Roy Date: Thu, 5 Mar 2026 11:23:22 -0800 Subject: [PATCH 9/9] up --- backends/mlx/README.md | 123 ++++++++++++++++--------- backends/mlx/serialization/generate.py | 2 +- 2 files changed, 82 insertions(+), 43 deletions(-) diff --git a/backends/mlx/README.md b/backends/mlx/README.md index ebab893385a..eea60fe2d00 100644 --- a/backends/mlx/README.md +++ 
b/backends/mlx/README.md @@ -193,7 +193,7 @@ ExportedProgram (subgraph) ## How to Add a New Op -This section walks through adding a new op end-to-end, using **`aten.linear`** +This section walks through adding a new op end-to-end, using **`aten.addmm`** as an example. ### Step 1: Add the Node to `schema.fbs` @@ -201,15 +201,15 @@ as an example. Add a new table in the "Op nodes" section and add it to the `OpNode` union: ```fbs -table LinearNode { - x: Tid (required); - weight: Tid (required); +table AddmmNode { + mat1: Tid (required); + mat2: Tid (required); out: Tid (required); bias: Tid; // optional } ``` -Then add `LinearNode` to the `union OpNode { ... }` list. +Then add `AddmmNode` to the `union OpNode { ... }` list. ### Step 2: Run the Code Generator @@ -219,34 +219,40 @@ python backends/mlx/serialization/generate.py This regenerates: -- `mlx_graph_schema.py` — adds `LinearNode` Python dataclass -- `_generated_serializers.py` — adds `_build_LinearNode` serializer -- `runtime/MLXLoader.h` — adds `LinearNode` C++ struct, `OpCode::LINEAR`, loader -- `runtime/MLXLoader.cpp` — adds FlatBuffer → `LinearNode` deserialization +- `mlx_graph_schema.py` — adds `AddmmNode` Python dataclass +- `_generated_serializers.py` — adds `_build_AddmmNode` serializer +- `runtime/MLXLoader.h` — adds `AddmmNode` C++ struct, `OpCode::ADDMM`, loader +- `runtime/MLXLoader.cpp` — adds FlatBuffer → `AddmmNode` deserialization - `runtime/schema_generated.h` — FlatBuffer C++ bindings ### Step 3: Add the Python Op Handler (`ops.py`) Register a handler that converts the ATen op to your new node. 
Make sure to -import `LinearNode` from `mlx_graph_schema`: +import `AddmmNode` from `mlx_graph_schema`: ```python -from executorch.backends.mlx.serialization.mlx_graph_schema import LinearNode +from executorch.backends.mlx.serialization.mlx_graph_schema import AddmmNode -@REGISTRY.register(target=[torch.ops.aten.linear.default]) -def _linear_handler(P: MLXProgramBuilder, n: Node) -> Slot: +@REGISTRY.register(target=[torch.ops.aten.addmm.default]) +def _addmm_handler(P: MLXProgramBuilder, n: Node) -> Slot: args = P.args(n) - require_args(args, 2, 3, "aten.linear") - require_kwargs(P.kwargs(n), set(), "aten.linear") - x, w = args[0], args[1] - b = args[2] if len(args) > 2 else None + kwargs = P.kwargs(n) + require_args(args, 3, 3, "aten.addmm") + require_kwargs(kwargs, {"beta", "alpha"}, "aten.addmm") + bias, mat1, mat2 = args[0], args[1], args[2] + + beta = kwargs.get("beta", 1) + alpha = kwargs.get("alpha", 1) + out = P.make_or_get_slot(n) P.emit( - LinearNode( - x=P.slot_to_tid(x), - weight=P.slot_to_tid(w), + AddmmNode( + mat1=P.slot_to_tid(mat1), + mat2=P.slot_to_tid(mat2), out=P.slot_to_tid(out), - bias=P.slot_to_tid(b) if b else None, + bias=P.slot_to_tid(bias), + alpha=float(alpha), + beta=float(beta), ) ) return out @@ -263,21 +269,28 @@ Key APIs: Add an `exec_*` function in the `ops` namespace: ```cpp -inline void exec_linear(const LinearNode& n, ExecutionState& st, StreamOrDevice s) { - const auto& X = st.const_tensor_ref(n.x); - auto W = transpose(st.const_tensor_ref(n.weight), {1, 0}, s); - array Y = n.bias - ? addmm(st.const_tensor_ref(*n.bias), X, W, 1.0f, 1.0f, s) - : matmul(X, W, s); +inline void exec_addmm(const AddmmNode& n, ExecutionState& st, StreamOrDevice s) { + const auto& mat1 = st.const_tensor_ref(n.mat1); + const auto& mat2 = st.const_tensor_ref(n.mat2); + + array Y = n.bias ? 
addmm( + st.const_tensor_ref(*n.bias), + mat1, + mat2, + /*alpha=*/n.alpha, + /*beta=*/n.beta, + s) + : matmul(mat1, mat2, s); + st.set_tensor(n.out, std::move(Y)); } ``` -Then add the dispatch case in `Interpreter::execute_instruction()`: +Then add the dispatch case in `Interpreter::dispatch()`: ```cpp -case OpCode::LINEAR: - ops::exec_linear(std::get(instr.node), st, s); +case OpCode::ADDMM: + ops::exec_addmm(std::get(instr.node), st, s); break; ``` @@ -290,34 +303,60 @@ Each test follows a standard pattern: 3. **Decorate with `@register_test`** to register it with the test runner. ```python -class LinearModel(nn.Module): - def __init__(self, in_features=64, out_features=128, bias=True): +class AddmmModel(nn.Module): + """Model that performs addmm: bias + (mat1 @ mat2).""" + + def __init__(self, in_features, out_features, bias=True, alpha=1.0, beta=1.0): super().__init__() - self.linear = nn.Linear(in_features, out_features, bias=bias) + self.weight = nn.Parameter(torch.randn(out_features, in_features)) + if bias: + self.bias = nn.Parameter(torch.randn(out_features)) + else: + self.bias = None + self.alpha = alpha + self.beta = beta def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.linear(x) + if self.bias is not None: + return torch.addmm( + self.bias, x, self.weight.t(), beta=self.beta, alpha=self.alpha + ) + else: + return torch.mm(x, self.weight.t()) @register_test -class LinearTest(OpTestCase): - name = "linear" +class AddmmTest(OpTestCase): + name = "addmm" rtol = 1e-4 atol = 1e-4 - def __init__(self, in_features=64, out_features=128, bias=True): + def __init__(self, batch_size=2, in_features=64, out_features=32, + bias=True, alpha=1.0, beta=1.0): + self.batch_size = batch_size self.in_features = in_features self.out_features = out_features self.bias = bias + self.alpha = alpha + self.beta = beta + self.name = f"addmm_{in_features}x{out_features}" @classmethod def get_test_configs(cls): - return [cls(), cls(bias=False)] + return [ + 
cls(batch_size=2, in_features=64, out_features=32), + cls(batch_size=2, in_features=64, out_features=32, bias=False), + cls(batch_size=4, in_features=128, out_features=64), + cls(batch_size=2, in_features=64, out_features=32, alpha=2.0, beta=0.5), + ] def create_model(self): - return LinearModel(self.in_features, self.out_features, bias=self.bias) + return AddmmModel( + self.in_features, self.out_features, + bias=self.bias, alpha=self.alpha, beta=self.beta, + ) def create_inputs(self): - return (torch.randn(2, 16, self.in_features),) + return (torch.randn(self.batch_size, self.in_features),) ``` ### Step 6: Run Tests @@ -327,7 +366,7 @@ outputs against PyTorch reference. Since adding a new op always involves C++ changes, use `--rebuild` to recompile the runtime: ```bash -python -m executorch.backends.mlx.test.run_all_tests --rebuild linear +python -m executorch.backends.mlx.test.run_all_tests --rebuild addmm ``` Run all tests in parallel: @@ -356,7 +395,7 @@ architecture, prerequisites, and the `OpTestCase` API. - [ ] Run `python backends/mlx/serialization/generate.py` - [ ] Add `@REGISTRY.register` handler in `ops.py` (and import the new node class) - [ ] Add `exec_*` function in `runtime/MLXInterpreter.h` -- [ ] Add `case OpCode::*` in `Interpreter::execute_instruction()` +- [ ] Add `case OpCode::*` in `Interpreter::dispatch()` - [ ] Add test model + `OpTestCase` in `test/test_ops.py` - [ ] Run `python -m executorch.backends.mlx.test.run_all_tests --rebuild ` diff --git a/backends/mlx/serialization/generate.py b/backends/mlx/serialization/generate.py index d12743906db..6f6ee11fe41 100755 --- a/backends/mlx/serialization/generate.py +++ b/backends/mlx/serialization/generate.py @@ -1006,7 +1006,7 @@ def _fbs_type_to_cpp( def _table_name_to_opcode(name: str) -> str: - """Convert table name like 'LinearNode' to opcode like 'LINEAR'. + """Convert table name like 'AddNode' to opcode like 'ADD'. 
Uses regex-based camelCase → UPPER_SNAKE_CASE conversion with a small override dict for names whose conventional opcode doesn't follow the