diff --git a/advanced_source/cpp_custom_ops.rst b/advanced_source/cpp_custom_ops.rst
index 512c39b2a68..5bdc01964bd 100644
--- a/advanced_source/cpp_custom_ops.rst
+++ b/advanced_source/cpp_custom_ops.rst
@@ -16,7 +16,7 @@ Custom C++ and CUDA Operators
    .. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites
       :class-card: card-prerequisites
 
-      * PyTorch 2.4 or later
+      * PyTorch 2.10 or later (or PyTorch 2.4 or later if using the non-stable API)
       * Basic understanding of C++ and CUDA programming
 
 .. note::
@@ -37,7 +37,16 @@ the operation are as follows:
       return a * b + c
 
 You can find the end-to-end working example for this tutorial
-`here <https://github.com/pytorch/extension-cpp>`_ .
+in the `extension-cpp <https://github.com/pytorch/extension-cpp>`_ repository,
+which contains two parallel implementations:
+
+- `extension_cpp_stable/ <https://github.com/pytorch/extension-cpp/tree/master/extension_cpp_stable>`_:
+  Uses APIs supported by the LibTorch Stable ABI (recommended for PyTorch 2.10+). The main body of this
+  tutorial uses code snippets from this implementation.
+- `extension_cpp/ <https://github.com/pytorch/extension-cpp/tree/master/extension_cpp>`_:
+  Uses the standard ATen/LibTorch API. Use this if you need APIs not yet available in the
+  stable ABI. Code snippets from this implementation are shown in the
+  :ref:`reverting-to-non-stable-api` section.
 
 Setting up the Build System
 ---------------------------
@@ -62,12 +71,19 @@ Using ``cpp_extension`` is as simple as writing the following ``setup.py``:
 
    setup(name="extension_cpp",
          ext_modules=[
-           cpp_extension.CppExtension(
+          cpp_extension.CppExtension(
             "extension_cpp",
             ["muladd.cpp"],
-            # define Py_LIMITED_API with min version 3.9 to expose only the stable
-            # limited API subset from Python.h
-            extra_compile_args={"cxx": ["-DPy_LIMITED_API=0x03090000"]},
+            extra_compile_args={
+                "cxx": [
+                    # define Py_LIMITED_API with min version 3.9 to expose only the stable
+                    # limited API subset from Python.h
+                    "-DPy_LIMITED_API=0x03090000",
+                    # define TORCH_TARGET_VERSION with min version 2.10 to expose only the
+                    # stable API subset from torch
+                    "-DTORCH_TARGET_VERSION=0x020a000000000000",
+                ]
+            },
             py_limited_api=True)],  # Build 1 wheel across multiple Python versions
       cmdclass={'build_ext': cpp_extension.BuildExtension},
       options={"bdist_wheel": {"py_limited_api": "cp39"}}  # 3.9 is minimum supported Python version
@@ -78,6 +94,9 @@ If you need to compile CUDA code (for example, ``.cu`` files), then instead use
 Please see `extension-cpp <https://github.com/pytorch/extension-cpp>`_ for an
 example of how this is set up.
 
+CPython Agnosticism
+^^^^^^^^^^^^^^^^^^^
+
 The above example represents what we refer to as a CPython agnostic wheel, meaning we are
 building a single wheel that can be run across multiple CPython versions (similar to pure
 Python packages). CPython agnosticism is desirable for minimizing the number of wheels your
@@ -148,25 +167,62 @@ like so:
       cmdclass={'build_ext': cpp_extension.BuildExtension},
    )
 
+LibTorch Stable ABI (PyTorch Agnosticism)
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+In addition to CPython agnosticism, there is a second axis of wheel compatibility:
+LibTorch agnosticism. While CPython agnosticism allows building a single wheel
+that works across multiple Python versions (3.9, 3.10, 3.11, etc.), LibTorch agnosticism
+allows building a single wheel that works across multiple PyTorch versions (2.10, 2.11, 2.12, etc.).
+These two concepts are orthogonal and can be combined.
+
+To achieve LibTorch agnosticism, you must use the LibTorch Stable ABI, which provides
+a stable C API for interacting with PyTorch tensors and operators. For example, instead of
+using ``at::Tensor``, you must use ``torch::stable::Tensor``. For comprehensive
+documentation on the stable ABI, including migration guides, supported types, and
+stack-based API conventions, see the
+`LibTorch Stable ABI documentation <https://docs.pytorch.org/docs/main/notes/libtorch_stable_abi.html>`_.
+
+The ``setup.py`` above already includes ``-DTORCH_TARGET_VERSION=0x020a000000000000``, which
+indicates that the extension targets the LibTorch Stable ABI with a minimum supported PyTorch
+version of 2.10. The version format is ``[MAJ 1 byte][MIN 1 byte][PATCH 1 byte][ABI TAG 5 bytes]``,
+so 2.10.0 = ``0x020a000000000000``.
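+
+If you need to target a different minimum version, you can compute the value with a
+few lines of Python. This is an illustrative sketch (the helper function is our own;
+only the byte layout above is prescribed):
+
+.. code-block:: python
+
+   def torch_target_version(major: int, minor: int, patch: int) -> str:
+       # Pack [MAJ][MIN][PATCH] into the top 3 bytes of an 8-byte value;
+       # the remaining 5 bytes are the ABI tag (zero here).
+       value = (major << 56) | (minor << 48) | (patch << 40)
+       return f"0x{value:016x}"
+
+   assert torch_target_version(2, 10, 0) == "0x020a000000000000"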
+
+The sections below contain examples of code using the LibTorch Stable ABI.
+If the stable API/ABI does not contain what you need, see the :ref:`reverting-to-non-stable-api` section
+or the `extension_cpp/ subdirectory <https://github.com/pytorch/extension-cpp/tree/master/extension_cpp>`_
+in the extension-cpp repository for the equivalent examples using the non-stable API.
+
 Defining the custom op and adding backend implementations
 ---------------------------------------------------------
-First, let's write a C++ function that computes ``mymuladd``:
+First, let's write a C++ function that computes ``mymuladd`` using the LibTorch Stable ABI:
 
 .. code-block:: cpp
 
-   at::Tensor mymuladd_cpu(at::Tensor a, const at::Tensor& b, double c) {
-     TORCH_CHECK(a.sizes() == b.sizes());
-     TORCH_CHECK(a.dtype() == at::kFloat);
-     TORCH_CHECK(b.dtype() == at::kFloat);
-     TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CPU);
-     TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CPU);
-     at::Tensor a_contig = a.contiguous();
-     at::Tensor b_contig = b.contiguous();
-     at::Tensor result = torch::empty(a_contig.sizes(), a_contig.options());
-     const float* a_ptr = a_contig.data_ptr<float>();
-     const float* b_ptr = b_contig.data_ptr<float>();
-     float* result_ptr = result.data_ptr<float>();
+   #include <torch/csrc/stable/library.h>
+   #include <torch/csrc/stable/tensor.h>
+   #include <torch/csrc/stable/ops.h>
+   #include <torch/headeronly/util/Exception.h>
+   #include <torch/headeronly/core/ScalarType.h>
+
+   torch::stable::Tensor mymuladd_cpu(
+       const torch::stable::Tensor& a,
+       const torch::stable::Tensor& b,
+       double c) {
+     STD_TORCH_CHECK(a.sizes().equals(b.sizes()));
+     STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Float);
+     STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Float);
+     STD_TORCH_CHECK(a.device().type() == torch::headeronly::DeviceType::CPU);
+     STD_TORCH_CHECK(b.device().type() == torch::headeronly::DeviceType::CPU);
+
+     torch::stable::Tensor a_contig = torch::stable::contiguous(a);
+     torch::stable::Tensor b_contig = torch::stable::contiguous(b);
+     torch::stable::Tensor result = torch::stable::empty_like(a_contig);
+
+     const float* a_ptr = a_contig.const_data_ptr<float>();
+     const float* b_ptr = b_contig.const_data_ptr<float>();
+     float* result_ptr = result.mutable_data_ptr<float>();
+
     for (int64_t i = 0; i < result.numel(); i++) {
       result_ptr[i] = a_ptr[i] * b_ptr[i] + c;
     }
@@ -174,7 +230,7 @@ First, let's write a C++ function that computes ``mymuladd``:
   }
 
 In order to use this from PyTorch’s Python frontend, we need to register it
-as a PyTorch operator using the ``TORCH_LIBRARY`` API. This will automatically
+as a PyTorch operator using the ``STABLE_TORCH_LIBRARY`` macro. This will automatically
 bind the operator to Python.
 
 Operator registration is a two-step process:
@@ -188,7 +244,7 @@ Defining an operator
 To define an operator, follow these steps:
 
 1. select a namespace for an operator. We recommend the namespace be the name of your top-level
-   project; we’ll use "extension_cpp" in our tutorial.
+   project; we'll use "extension_cpp" in our tutorial.
 2. provide a schema string that specifies the input/output types of the operator and whether any
    input Tensors will be mutated. We support more types in addition to Tensor and float;
    please see `The Custom Operators Manual <https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html>`_
@@ -199,7 +255,7 @@ To define an operator, follow these steps:
 
 .. code-block:: cpp
 
-  TORCH_LIBRARY(extension_cpp, m) {
+  STABLE_TORCH_LIBRARY(extension_cpp, m) {
     // Note that "float" in the schema corresponds to the C++ double type
     // and the Python float type.
     m.def("mymuladd(Tensor a, Tensor b, float c) -> Tensor");
@@ -209,44 +265,68 @@ This makes the operator available from Python via ``torch.ops.extension_cpp.mymuladd``.
 
 Registering backend implementations for an operator
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-Use ``TORCH_LIBRARY_IMPL`` to register a backend implementation for the operator.
+Use ``STABLE_TORCH_LIBRARY_IMPL`` to register a backend implementation for the operator.
+Note that we wrap the function pointer with ``TORCH_BOX()``; this is required for
+stable ABI functions to handle argument boxing/unboxing correctly.
 
 .. code-block:: cpp
 
-  TORCH_LIBRARY_IMPL(extension_cpp, CPU, m) {
-    m.impl("mymuladd", &mymuladd_cpu);
+  STABLE_TORCH_LIBRARY_IMPL(extension_cpp, CPU, m) {
+    m.impl("mymuladd", TORCH_BOX(&mymuladd_cpu));
   }
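+
+At this point you can build the extension and call the operator from Python. The
+following sketch assumes the extension has been built and installed as the
+``extension_cpp`` package; importing it loads the shared library, which runs the
+registration code above:
+
+.. code-block:: python
+
+   import torch
+   import extension_cpp  # loads the C++ extension and registers the operators
+
+   a = torch.randn(3)
+   b = torch.randn(3)
+   out = torch.ops.extension_cpp.mymuladd(a, b, 2.0)
+   assert torch.allclose(out, a * b + 2.0)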
 
 If you also have a CUDA implementation of ``mymuladd``, you can register it
-in a separate ``TORCH_LIBRARY_IMPL`` block:
+in a separate ``STABLE_TORCH_LIBRARY_IMPL`` block:
 
 .. code-block:: cpp
 
+  #include <torch/csrc/stable/library.h>
+  #include <torch/csrc/stable/tensor.h>
+  #include <torch/csrc/stable/ops.h>
+  #include <torch/csrc/stable/c/shim.h>
+  #include <torch/headeronly/util/Exception.h>
+  #include <torch/headeronly/core/ScalarType.h>
+
   __global__ void muladd_kernel(int numel, const float* a, const float* b, float c, float* result) {
     int idx = blockIdx.x * blockDim.x + threadIdx.x;
     if (idx < numel) result[idx] = a[idx] * b[idx] + c;
   }
 
-  at::Tensor mymuladd_cuda(const at::Tensor& a, const at::Tensor& b, double c) {
-    TORCH_CHECK(a.sizes() == b.sizes());
-    TORCH_CHECK(a.dtype() == at::kFloat);
-    TORCH_CHECK(b.dtype() == at::kFloat);
-    TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CUDA);
-    TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CUDA);
-    at::Tensor a_contig = a.contiguous();
-    at::Tensor b_contig = b.contiguous();
-    at::Tensor result = torch::empty(a_contig.sizes(), a_contig.options());
-    const float* a_ptr = a_contig.data_ptr<float>();
-    const float* b_ptr = b_contig.data_ptr<float>();
-    float* result_ptr = result.data_ptr<float>();
+  torch::stable::Tensor mymuladd_cuda(
+      const torch::stable::Tensor& a,
+      const torch::stable::Tensor& b,
+      double c) {
+    STD_TORCH_CHECK(a.sizes().equals(b.sizes()));
+    STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Float);
+    STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Float);
+    STD_TORCH_CHECK(a.device().type() == torch::headeronly::DeviceType::CUDA);
+    STD_TORCH_CHECK(b.device().type() == torch::headeronly::DeviceType::CUDA);
+
+    torch::stable::Tensor a_contig = torch::stable::contiguous(a);
+    torch::stable::Tensor b_contig = torch::stable::contiguous(b);
+    torch::stable::Tensor result = torch::stable::empty_like(a_contig);
+
+    const float* a_ptr = a_contig.const_data_ptr<float>();
+    const float* b_ptr = b_contig.const_data_ptr<float>();
+    float* result_ptr = result.mutable_data_ptr<float>();
    int numel = a_contig.numel();
-    muladd_kernel<<<(numel+255)/256, 256>>>(numel, a_ptr, b_ptr, c, result_ptr);
+
+    // For now, we rely on the raw shim API to get the current CUDA stream.
+    // This will be improved in a future release.
+    // When using a raw shim API, we need to use TORCH_ERROR_CODE_CHECK to
+    // check the error code and throw an appropriate runtime_error otherwise.
+    void* stream_ptr = nullptr;
+    TORCH_ERROR_CODE_CHECK(
+        aoti_torch_get_current_cuda_stream(a.get_device_index(), &stream_ptr));
+    cudaStream_t stream = static_cast<cudaStream_t>(stream_ptr);
+
+    muladd_kernel<<<(numel+255)/256, 256, 0, stream>>>(numel, a_ptr, b_ptr, c, result_ptr);
     return result;
   }
 
-  TORCH_LIBRARY_IMPL(extension_cpp, CUDA, m) {
-    m.impl("mymuladd", &mymuladd_cuda);
+  STABLE_TORCH_LIBRARY_IMPL(extension_cpp, CUDA, m) {
+    m.impl("mymuladd", TORCH_BOX(&mymuladd_cuda));
   }
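+
+With both kernels registered under the same schema, the dispatcher routes each call
+to the CPU or CUDA kernel based on the device of the input Tensors. For example
+(assuming the extension is installed and a CUDA device is available):
+
+.. code-block:: python
+
+   import torch
+   import extension_cpp
+
+   a = torch.randn(3, device="cuda")
+   b = torch.randn(3, device="cuda")
+   # Routed to mymuladd_cuda because the inputs live on a CUDA device.
+   out = torch.ops.extension_cpp.mymuladd(a, b, 2.0)
+   assert out.device.type == "cuda"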
 
 Adding ``torch.compile`` support for an operator
@@ -327,7 +407,7 @@ three ways:
    for more details:
 
 .. code-block:: cpp
 
-
+      #include <Python.h>
 
       extern "C" {
@@ -380,8 +460,7 @@ three ways:
 Adding training (autograd) support for an operator
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 Use ``torch.library.register_autograd`` to add training support for an operator. Prefer
-this over directly using Python ``torch.autograd.Function`` or C++ ``torch::autograd::Function``;
-you must use those in a very specific way to avoid silent incorrectness (see
+this over directly using Python ``torch.autograd.Function`` (see
 `The Custom Operators Manual <https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html>`_
 for more details).
@@ -421,35 +500,40 @@ custom operator and then call that from the backward:
 
 .. code-block:: cpp
 
   // New! a mymul_cpu kernel
-  at::Tensor mymul_cpu(const at::Tensor& a, const at::Tensor& b) {
-    TORCH_CHECK(a.sizes() == b.sizes());
-    TORCH_CHECK(a.dtype() == at::kFloat);
-    TORCH_CHECK(b.dtype() == at::kFloat);
-    TORCH_CHECK(a.device().type() == at::DeviceType::CPU);
-    TORCH_CHECK(b.device().type() == at::DeviceType::CPU);
-    at::Tensor a_contig = a.contiguous();
-    at::Tensor b_contig = b.contiguous();
-    at::Tensor result = torch::empty(a_contig.sizes(), a_contig.options());
-    const float* a_ptr = a_contig.data_ptr<float>();
-    const float* b_ptr = b_contig.data_ptr<float>();
-    float* result_ptr = result.data_ptr<float>();
+  torch::stable::Tensor mymul_cpu(
+      const torch::stable::Tensor& a,
+      const torch::stable::Tensor& b) {
+    STD_TORCH_CHECK(a.sizes().equals(b.sizes()));
+    STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Float);
+    STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Float);
+    STD_TORCH_CHECK(a.device().type() == torch::headeronly::DeviceType::CPU);
+    STD_TORCH_CHECK(b.device().type() == torch::headeronly::DeviceType::CPU);
+
+    torch::stable::Tensor a_contig = torch::stable::contiguous(a);
+    torch::stable::Tensor b_contig = torch::stable::contiguous(b);
+    torch::stable::Tensor result = torch::stable::empty_like(a_contig);
+
+    const float* a_ptr = a_contig.const_data_ptr<float>();
+    const float* b_ptr = b_contig.const_data_ptr<float>();
+    float* result_ptr = result.mutable_data_ptr<float>();
+
     for (int64_t i = 0; i < result.numel(); i++) {
       result_ptr[i] = a_ptr[i] * b_ptr[i];
     }
     return result;
   }
 
-  TORCH_LIBRARY(extension_cpp, m) {
+  STABLE_TORCH_LIBRARY(extension_cpp, m) {
     m.def("mymuladd(Tensor a, Tensor b, float c) -> Tensor");
     // New! defining the mymul operator
     m.def("mymul(Tensor a, Tensor b) -> Tensor");
   }
 
-  TORCH_LIBRARY_IMPL(extension_cpp, CPU, m) {
-    m.impl("mymuladd", &mymuladd_cpu);
+  STABLE_TORCH_LIBRARY_IMPL(extension_cpp, CPU, m) {
+    m.impl("mymuladd", TORCH_BOX(&mymuladd_cpu));
     // New! registering the cpu kernel for the mymul operator
-    m.impl("mymul", &mymul_cpu);
+    m.impl("mymul", TORCH_BOX(&mymul_cpu));
   }
 
 .. code-block:: python
@@ -531,21 +615,27 @@ Let's author a ``myadd_out(a, b, out)`` operator, which writes the contents of ``a+b`` into ``out``
 
 .. code-block:: cpp
 
   // An example of an operator that mutates one of its inputs.
-  void myadd_out_cpu(const at::Tensor& a, const at::Tensor& b, at::Tensor& out) {
-    TORCH_CHECK(a.sizes() == b.sizes());
-    TORCH_CHECK(b.sizes() == out.sizes());
-    TORCH_CHECK(a.dtype() == at::kFloat);
-    TORCH_CHECK(b.dtype() == at::kFloat);
-    TORCH_CHECK(out.dtype() == at::kFloat);
-    TORCH_CHECK(out.is_contiguous());
-    TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CPU);
-    TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CPU);
-    TORCH_INTERNAL_ASSERT(out.device().type() == at::DeviceType::CPU);
-    at::Tensor a_contig = a.contiguous();
-    at::Tensor b_contig = b.contiguous();
-    const float* a_ptr = a_contig.data_ptr<float>();
-    const float* b_ptr = b_contig.data_ptr<float>();
-    float* result_ptr = out.data_ptr<float>();
+  void myadd_out_cpu(
+      const torch::stable::Tensor& a,
+      const torch::stable::Tensor& b,
+      torch::stable::Tensor& out) {
+    STD_TORCH_CHECK(a.sizes().equals(b.sizes()));
+    STD_TORCH_CHECK(b.sizes().equals(out.sizes()));
+    STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Float);
+    STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Float);
+    STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Float);
+    STD_TORCH_CHECK(out.is_contiguous());
+    STD_TORCH_CHECK(a.device().type() == torch::headeronly::DeviceType::CPU);
+    STD_TORCH_CHECK(b.device().type() == torch::headeronly::DeviceType::CPU);
+    STD_TORCH_CHECK(out.device().type() == torch::headeronly::DeviceType::CPU);
+
+    torch::stable::Tensor a_contig = torch::stable::contiguous(a);
+    torch::stable::Tensor b_contig = torch::stable::contiguous(b);
+
+    const float* a_ptr = a_contig.const_data_ptr<float>();
+    const float* b_ptr = b_contig.const_data_ptr<float>();
+    float* result_ptr = out.mutable_data_ptr<float>();
+
     for (int64_t i = 0; i < out.numel(); i++) {
       result_ptr[i] = a_ptr[i] + b_ptr[i];
     }
@@ -555,18 +645,18 @@ When defining the operator, we must specify that it mutates the out Tensor in the schema
 
 .. code-block:: cpp
 
-  TORCH_LIBRARY(extension_cpp, m) {
+  STABLE_TORCH_LIBRARY(extension_cpp, m) {
     m.def("mymuladd(Tensor a, Tensor b, float c) -> Tensor");
     m.def("mymul(Tensor a, Tensor b) -> Tensor");
     // New!
     m.def("myadd_out(Tensor a, Tensor b, Tensor(a!) out) -> ()");
   }
 
-  TORCH_LIBRARY_IMPL(extension_cpp, CPU, m) {
-    m.impl("mymuladd", &mymuladd_cpu);
-    m.impl("mymul", &mymul_cpu);
+  STABLE_TORCH_LIBRARY_IMPL(extension_cpp, CPU, m) {
+    m.impl("mymuladd", TORCH_BOX(&mymuladd_cpu));
+    m.impl("mymul", TORCH_BOX(&mymul_cpu));
     // New!
-    m.impl("myadd_out", &myadd_out_cpu);
+    m.impl("myadd_out", TORCH_BOX(&myadd_out_cpu));
   }
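+
+From Python, the caller allocates ``out`` and passes it in explicitly; the operator
+returns nothing. For example (assuming the extension is installed):
+
+.. code-block:: python
+
+   import torch
+   import extension_cpp
+
+   a = torch.randn(3)
+   b = torch.randn(3)
+   out = torch.empty(3)  # float32 and contiguous, matching the checks above
+   torch.ops.extension_cpp.myadd_out(a, b, out)
+   assert torch.allclose(out, a + b)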
 
 .. note::
@@ -574,9 +664,96 @@ When defining the operator, we must specify that it mutates the out Tensor in the schema
 
    Do not return any mutated Tensors as outputs of the operator, as this will cause
    incompatibility with PyTorch subsystems like ``torch.compile``.
 
+.. _reverting-to-non-stable-api:
+
+Reverting to the Non-Stable LibTorch API
+----------------------------------------
+
+The LibTorch Stable ABI/API is still under active development, and certain APIs may not
+yet be available in ``torch/csrc/stable``, ``torch/headeronly``, or the C shims
+(``torch/csrc/stable/c/shim.h``).
+
+If you need an API that is not yet available in the stable ABI/API, you can revert to
+the regular ATen API. Note that doing so means you will need to build separate wheels
+for each PyTorch version you want to support.
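+
+One way to make that per-version constraint explicit is to pin each wheel to the
+PyTorch version it was built against. This is a sketch of one possible policy, not
+part of the extension-cpp example:
+
+.. code-block:: python
+
+   import torch
+   from setuptools import setup
+
+   # A non-stable-ABI extension is only guaranteed to work against the PyTorch
+   # version it was compiled with, so pin the wheel to that major.minor release.
+   torch_major_minor = ".".join(torch.__version__.split(".")[:2])
+
+   setup(
+       name="extension_cpp",
+       install_requires=[f"torch=={torch_major_minor}.*"],
+       # ... ext_modules, cmdclass, etc. as in the setup.py snippets above ...
+   )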
The changes for the +CUDA variant, ``mymul`` and ``myadd_out`` are similar in nature and can be found in the +`extension_cpp/ `_ +subdirectory of the extension-cpp repository. + +**Setup (setup.py)** + +Remove ``-DTORCH_TARGET_VERSION`` from your ``extra_compile_args``: + +.. code-block:: python + + extra_compile_args = { + "cxx": [ + "-O3" if not debug_mode else "-O0", + "-fdiagnostics-color=always", + "-DPy_LIMITED_API=0x03090000", # min CPython version 3.9 + # Note: No -DTORCH_TARGET_VERSION flag + ], + "nvcc": [ + "-O3" if not debug_mode else "-O0", + ], + } + +**C++ Implementation (muladd.cpp)** + +Use ATen headers and types instead of the stable API: + +.. code-block:: cpp + + // Use ATen/torch headers instead of torch/csrc/stable headers + #include + #include + #include + + namespace extension_cpp { + + // Use at::Tensor instead of torch::stable::Tensor + at::Tensor mymuladd_cpu(const at::Tensor& a, const at::Tensor& b, double c) { + // Use TORCH_CHECK instead of STD_TORCH_CHECK + TORCH_CHECK(a.sizes() == b.sizes()); + // Use at::kFloat instead of torch::headeronly::ScalarType::Float + TORCH_CHECK(a.dtype() == at::kFloat); + TORCH_CHECK(b.dtype() == at::kFloat); + // Use at::DeviceType instead of torch::headeronly::DeviceType + TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CPU); + TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CPU); + // Use tensor.contiguous() instead of torch::stable::contiguous(tensor) + at::Tensor a_contig = a.contiguous(); + at::Tensor b_contig = b.contiguous(); + // Use torch::empty() instead of torch::stable::empty_like() + at::Tensor result = torch::empty(a_contig.sizes(), a_contig.options()); + // Use data_ptr() instead of const_data_ptr() + const float* a_ptr = a_contig.data_ptr(); + const float* b_ptr = b_contig.data_ptr(); + float* result_ptr = result.data_ptr(); + for (int64_t i = 0; i < result.numel(); i++) { + result_ptr[i] = a_ptr[i] * b_ptr[i] + c; + } + return result; + } + + // Use TORCH_LIBRARY instead of STABLE_TORCH_LIBRARY + TORCH_LIBRARY(extension_cpp, m) { + m.def("mymuladd(Tensor a, Tensor b, float c) -> Tensor"); + } + + // Use TORCH_LIBRARY_IMPL instead of STABLE_TORCH_LIBRARY_IMPL + TORCH_LIBRARY_IMPL(extension_cpp, CPU, m) { + // Pass function pointer directly instead of wrapping with TORCH_BOX() + m.impl("mymuladd", &mymuladd_cpu); + } + + } + Conclusion ---------- In this tutorial, we went over the recommended approach to integrating Custom C++ -and CUDA operators with PyTorch. The ``TORCH_LIBRARY/torch.library`` APIs are fairly +and CUDA operators with PyTorch. The ``STABLE_TORCH_LIBRARY/torch.library`` APIs are fairly low-level. For more information about how to use the API, see `The Custom Operators Manual `_.