Merged

Changes from all commits (36 commits)
8124d3d
init
pggPL Oct 24, 2025
a23a365
Merge remote-tracking branch 'upstream/main' into docs_refactor
pggPL Oct 24, 2025
2413bcb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 24, 2025
8dec718
fix
pggPL Oct 24, 2025
3ff3ca0
lines length
pggPL Oct 24, 2025
8118005
fix
pggPL Oct 24, 2025
1a9b993
fix
pggPL Oct 30, 2025
dd792f8
Merge remote-tracking branch 'upstream/main' into docs_refactor
pggPL Oct 30, 2025
0116d34
fix
pggPL Oct 30, 2025
20ba719
subtitle --- fix in many files:
pggPL Nov 4, 2025
f800a1f
Merge branch 'main' into docs_refactor
pggPL Nov 4, 2025
19da61b
cross entropy _input -> input rename
pggPL Nov 4, 2025
15c6741
cross entropy _input -> input rename
pggPL Nov 4, 2025
556ab28
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 4, 2025
35e75de
fix
pggPL Nov 4, 2025
0a5eafd
a lot of small fixes
pggPL Nov 4, 2025
269afb6
Merge remote-tracking branch 'upstream/main' into docs_refactor
pggPL Nov 20, 2025
efbb001
torch_version() change
pggPL Nov 20, 2025
913a425
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 20, 2025
999310d
add missing module and fix warnings
pggPL Nov 20, 2025
53fd515
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 20, 2025
7d34924
fix
pggPL Nov 20, 2025
6679d24
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 20, 2025
b1415f1
fix
pggPL Nov 20, 2025
eab43d9
fix
pggPL Nov 20, 2025
d07a0c4
removed trailing whitespace:
pggPL Nov 20, 2025
0d5fc6c
Update docs/api/pytorch.rst
pggPL Nov 20, 2025
837c9ae
Merge branch 'main' into docs_refactor
ksivaman Nov 25, 2025
d6639ea
Fix import
ksivaman Nov 25, 2025
0eed047
Merge branch 'main' into docs_refactor
ksivaman Nov 25, 2025
57d5fec
Fix more imports
ksivaman Nov 25, 2025
582e9e3
Fix NumPy docstring parameter spacing and indentation
pggPL Nov 26, 2025
19522d5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 26, 2025
5e0c0c7
Merge branch 'main' into docs_refactor
pggPL Nov 26, 2025
634a477
Merge remote-tracking branch 'upstream/main' into docs_refactor
pggPL Nov 26, 2025
78ce893
fix
pggPL Nov 26, 2025
4 changes: 2 additions & 2 deletions .github/workflows/docs.yml
@@ -22,10 +22,10 @@ jobs:
sudo apt-get install -y pandoc graphviz doxygen
export GIT_SHA=$(git show-ref --hash HEAD)
- name: 'Build docs'
run: |
run: | # SPHINXOPTS="-W" errors out on warnings
doxygen docs/Doxyfile
cd docs
make html
make html SPHINXOPTS="-W"
- name: 'Upload docs'
uses: actions/upload-artifact@v4
with:
6 changes: 3 additions & 3 deletions docs/api/jax.rst
@@ -4,7 +4,7 @@
See LICENSE for license information.

Jax
=======
===

Pre-defined Variable of Logical Axes
------------------------------------
@@ -20,11 +20,11 @@ Variables are available in `transformer_engine.jax.sharding`.


Checkpointing
------------------------------------
-------------
When using checkpointing with Transformer Engine JAX, please be aware of the checkpointing policy being applied to your model. Any JAX checkpointing policy using `dot`, such as `jax.checkpoint_policies.dots_with_no_batch_dims`, may not work with GEMMs provided by Transformer Engine as they do not always use the `jax.lax.dot_general` primitive. Instead, you can use `transformer_engine.jax.checkpoint_policies.dots_and_te_gemms_with_no_batch_dims` or similar policies that are designed to work with Transformer Engine's GEMMs and `jax.lax.dot_general` GEMMs. You may also use any JAX policies that do not filter by primitive, such as `jax.checkpoint_policies.save_only_these_names` or `jax.checkpoint_policies.everything_saveable`.
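A minimal sketch of wiring such a policy into ``jax.checkpoint`` (the policy path is the one named above; the decorated function is a placeholder rather than Transformer Engine API):

.. code-block:: python

    import functools

    import jax
    import jax.numpy as jnp
    from transformer_engine.jax import checkpoint_policies as te_policies

    # Save TE GEMM and jax.lax.dot_general results without batch dimensions;
    # rematerialize everything else, per the policy described above.
    policy = te_policies.dots_and_te_gemms_with_no_batch_dims

    @functools.partial(jax.checkpoint, policy=policy)
    def block(w, x):
        return jnp.dot(x, w)  # placeholder for a TE-based layer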

Modules
------------------------------------
-------
.. autoapiclass:: transformer_engine.jax.flax.TransformerLayerType
.. autoapiclass:: transformer_engine.jax.MeshResource()

34 changes: 25 additions & 9 deletions docs/api/pytorch.rst
@@ -3,7 +3,7 @@

See LICENSE for license information.

pyTorch
PyTorch
=======

.. autoapiclass:: transformer_engine.pytorch.Linear(in_features, out_features, bias=True, **kwargs)
@@ -37,16 +37,23 @@ pyTorch
.. autoapiclass:: transformer_engine.pytorch.CudaRNGStatesTracker()
:members: reset, get_states, set_states, add, fork

.. autoapifunction:: transformer_engine.pytorch.fp8_autocast

.. autoapifunction:: transformer_engine.pytorch.fp8_model_init

.. autoapifunction:: transformer_engine.pytorch.autocast

.. autoapifunction:: transformer_engine.pytorch.quantized_model_init

.. autoapifunction:: transformer_engine.pytorch.checkpoint


.. autoapifunction:: transformer_engine.pytorch.make_graphed_callables

.. autoapifunction:: transformer_engine.pytorch.get_cpu_offload_context

.. autoapifunction:: transformer_engine.pytorch.parallel_cross_entropy

Recipe availability
-------------------

.. autoapifunction:: transformer_engine.pytorch.is_fp8_available

.. autoapifunction:: transformer_engine.pytorch.is_mxfp8_available
@@ -63,9 +70,8 @@ pyTorch

.. autoapifunction:: transformer_engine.pytorch.get_default_recipe

.. autoapifunction:: transformer_engine.pytorch.make_graphed_callables

.. autoapifunction:: transformer_engine.pytorch.get_cpu_offload_context
Mixture of Experts (MoE) functions
----------------------------------

.. autoapifunction:: transformer_engine.pytorch.moe_permute

@@ -75,17 +81,20 @@ pyTorch

.. autoapifunction:: transformer_engine.pytorch.moe_sort_chunks_by_index

.. autoapifunction:: transformer_engine.pytorch.parallel_cross_entropy

.. autoapifunction:: transformer_engine.pytorch.moe_sort_chunks_by_index_with_probs


Communication-computation overlap
---------------------------------

.. autoapifunction:: transformer_engine.pytorch.initialize_ub

.. autoapifunction:: transformer_engine.pytorch.destroy_ub

.. autoapiclass:: transformer_engine.pytorch.UserBufferQuantizationMode
:members: FP8, NONE


Quantized tensors
-----------------

@@ -133,3 +142,10 @@ Tensor saving and restoring functions
.. autoapifunction:: transformer_engine.pytorch.prepare_for_saving

.. autoapifunction:: transformer_engine.pytorch.restore_from_saved

Deprecated functions
--------------------

.. autoapifunction:: transformer_engine.pytorch.fp8_autocast

.. autoapifunction:: transformer_engine.pytorch.fp8_model_init
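The hunk above moves ``fp8_autocast`` and ``fp8_model_init`` under "Deprecated functions" in favor of ``autocast`` and ``quantized_model_init``. A minimal migration sketch, assuming the new entry points accept the same no-argument usage as their predecessors:

.. code-block:: python

    import torch
    import transformer_engine.pytorch as te

    layer = te.Linear(1024, 1024)
    inp = torch.randn(32, 1024, device="cuda")

    with te.autocast():  # previously: te.fp8_autocast()
        out = layer(inp)

    with te.quantized_model_init():  # previously: te.fp8_model_init()
        big_layer = te.Linear(4096, 4096)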
28 changes: 26 additions & 2 deletions docs/conf.py
@@ -61,7 +61,11 @@
]

templates_path = ["_templates"]
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
exclude_patterns = [
    "_build",
    "Thumbs.db",
    "sphinx_rtd_theme",
]

source_suffix = ".rst"

@@ -94,11 +98,31 @@
("Values", "params_style"),
("Graphing parameters", "params_style"),
("FP8-related parameters", "params_style"),
("Quantization parameters", "params_style"),
]

breathe_projects = {"TransformerEngine": root_path / "docs" / "doxygen" / "xml"}
breathe_default_project = "TransformerEngine"

autoapi_generate_api_docs = False
autoapi_dirs = [root_path / "transformer_engine"]
autoapi_ignore = ["*/_[!_]*"]
autoapi_ignore = ["*test*"]


# There are 2 warnings about the same namespace (transformer_engine) in two different c++ api
# docs pages. This seems to be the only way to suppress these warnings.
def setup(app):
    """Custom Sphinx setup to filter warnings."""
    import logging

    # Filter out duplicate C++ declaration warnings
    class DuplicateDeclarationFilter(logging.Filter):
        def filter(self, record):
            message = record.getMessage()
            if "Duplicate C++ declaration" in message and "transformer_engine" in message:
                return False
            return True

    # Apply filter to Sphinx logger
    logger = logging.getLogger("sphinx")
    logger.addFilter(DuplicateDeclarationFilter())
3 changes: 2 additions & 1 deletion docs/debug.rst
@@ -2,8 +2,9 @@
Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
See LICENSE for license information.

Precision debug tools
==============================================
=====================

.. toctree::
:caption: Precision debug tools
15 changes: 9 additions & 6 deletions docs/debug/1_getting_started.rst
@@ -4,7 +4,7 @@
See LICENSE for license information.

Getting started
==============
===============

.. note::

@@ -38,7 +38,7 @@ To start debugging, one needs to create a configuration YAML file. This file lis
one - ``UserProvidedPrecision`` - is a custom feature implemented by the user. Nvidia-DL-Framework-Inspect inserts features into the layers according to the config.

Example training script
----------------------
-----------------------

Let's look at a simple example of training a Transformer layer using Transformer Engine with FP8 precision. This example demonstrates how to set up the layer, define an optimizer, and perform a few training iterations using synthetic data.
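(The training script itself is collapsed in this diff view.) A minimal sketch of the setup it describes, with invented shapes and hyperparameters:

.. code-block:: python

    import torch
    import transformer_engine.pytorch as te
    from transformer_engine.common import recipe

    layer = te.TransformerLayer(
        hidden_size=1024,
        ffn_hidden_size=4096,
        num_attention_heads=16,
    ).cuda()
    optimizer = torch.optim.Adam(layer.parameters(), lr=1e-4)
    fp8_recipe = recipe.DelayedScaling()

    for _ in range(4):
        # Synthetic data laid out as [sequence, batch, hidden].
        x = torch.randn(128, 4, 1024, device="cuda")
        with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
            y = layer(x)
        y.sum().backward()
        optimizer.step()
        optimizer.zero_grad()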

@@ -81,7 +81,7 @@ We will demonstrate two debug features on the code above:
2. Logging statistics for other GEMM operations, such as gradient statistics for data gradient GEMM within the LayerNormLinear sub-layer of the TransformerLayer.

Config file
----------
-----------

We need to prepare the configuration YAML file, as below

@@ -114,7 +114,8 @@ We need to prepare the configuration YAML file, as below
Further explanation on how to create config files is in the :doc:`next part of the documentation <2_config_file_structure>`.

Adjusting Python file
--------------------
---------------------


.. code-block:: python

@@ -145,7 +146,8 @@ In the modified code above, the following changes were made:
3. Added ``debug_api.step()`` after each forward-backward pass.

Inspecting the logs
------------------
-------------------


Let's look at the files with the logs. Two files will be created:

@@ -213,7 +215,8 @@ The second log file (``nvdlfw_inspect_statistics_logs/nvdlfw_inspect_globalrank-
INFO - transformer_layer.self_attention.layernorm_qkv_activation_l1_norm iteration=000004 value=130776.7969

Logging using TensorBoard
------------------------
-------------------------


Precision debug tools support logging using `TensorBoard <https://www.tensorflow.org/tensorboard>`_. To enable it, pass the argument ``tb_writer`` to ``debug_api.initialize()``. Let's modify the ``train.py`` file.
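A minimal sketch of that modification (``config_file`` and ``log_dir`` follow the setup section; the ``nvdlfw_inspect.api`` import path and the paths used here are assumptions):

.. code-block:: python

    from torch.utils.tensorboard import SummaryWriter
    import nvdlfw_inspect.api as debug_api

    tb_writer = SummaryWriter("./tensorboard_dir")
    debug_api.initialize(
        config_file="./config.yaml",
        log_dir="./log_dir",
        tb_writer=tb_writer,
    )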

15 changes: 9 additions & 6 deletions docs/debug/2_config_file_structure.rst
@@ -4,13 +4,14 @@
See LICENSE for license information.

Config File Structure
====================
=====================

To enable debug features, create a configuration YAML file to specify the desired behavior, such as determining which GEMMs (General Matrix Multiply operations) should run in higher precision rather than FP8 and defining which statistics to log.
Below, we outline how to structure the configuration YAML file.

General Format
-------------
--------------


A config file can have one or more sections, each containing settings for specific layers and features:

@@ -55,7 +56,8 @@ Sections may have any name and must contain:
3. Additional fields describing features for those layers.

Layer Specification
------------------
-------------------


Debug layers can be identified by a ``name`` parameter:

@@ -89,7 +91,8 @@ Examples:
(...)

Names in Transformer Layers
--------------------------
---------------------------


There are three ways to assign a name to a layer in the Transformer Engine:

@@ -156,7 +159,7 @@ Below is an example ``TransformerLayer`` with four linear layers that can be inf


Structured Configuration for GEMMs and Tensors
---------------------------------------------
----------------------------------------------

Sometimes a feature is parameterized by a list of tensors or by a list of GEMMs.
There are multiple ways of describing this parameterization.
@@ -218,7 +221,7 @@ We can use both structs for tensors and GEMMs. The tensors_struct should be nest
gemm_feature_param1: value

Enabling or Disabling Sections and Features
------------------------------------------
-------------------------------------------

Debug features can be enabled or disabled with the ``enabled`` keyword:

7 changes: 4 additions & 3 deletions docs/debug/3_api_debug_setup.rst
@@ -11,7 +11,8 @@ Please refer to the Nvidia-DL-Framework-Inspect `documentation <https://github.c
Below, we outline the steps for debug initialization.

initialize()
-----------
------------


Must be called once on every rank in the global context to initialize Nvidia-DL-Framework-Inspect.

@@ -34,7 +35,7 @@
log_dir="./log_dir")

set_tensor_reduction_group()
--------------------------
----------------------------

Needed only for logging tensor stats. In multi-GPU training, activation and gradient tensors are distributed across multiple nodes. This method lets you specify the group for the reduction of stats; see the `reduction group section <./4_distributed.rst#reduction-groups>`_ for more details.

@@ -61,7 +62,7 @@ If the tensor reduction group is not specified, then statistics are reduced acro
# activation/gradient tensor statistics are reduced along pipeline_parallel_group

set_weight_tensor_tp_group_reduce()
---------------------------------
-----------------------------------

By default, weight tensor statistics are reduced within the tensor parallel group. This function allows you to disable that behavior; for more details, see `reduction group section <./4_distributed.rst#reduction-groups>`_.
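A one-line sketch of disabling that reduction (the boolean argument is assumed from the description above):

.. code-block:: python

    import nvdlfw_inspect.api as debug_api  # assumed import path

    # Stop reducing weight statistics within the tensor-parallel group.
    debug_api.set_weight_tensor_tp_group_reduce(False)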

2 changes: 1 addition & 1 deletion docs/debug/3_api_features.rst
@@ -4,7 +4,7 @@
See LICENSE for license information.

Debug features
==========
==============

.. autoapiclass:: transformer_engine.debug.features.log_tensor_stats.LogTensorStats
.. autoapiclass:: transformer_engine.debug.features.log_fp8_tensor_stats.LogFp8TensorStats
13 changes: 8 additions & 5 deletions docs/debug/4_distributed.rst
@@ -4,7 +4,7 @@
See LICENSE for license information.

Distributed training
===================
====================

Nvidia-Pytorch-Inspect with Transformer Engine supports multi-GPU training. This guide describes how to run it and how the supported features work in the distributed setting.

@@ -14,7 +14,8 @@ To use precision debug tools in multi-GPU training, one needs to:
2. If one wants to log stats, one may want to invoke ``debug_api.set_tensor_reduction_group`` with a proper reduction group.

Behavior of the features
-----------------------
------------------------


In a distributed setting, **DisableFP8GEMM** and **DisableFP8Layer** function similarly to the single-GPU case, with no notable differences.

@@ -28,7 +29,8 @@ In a distributed setting, **DisableFP8GEMM** and **DisableFP8Layer** function si
Logging-related features are more complex and will be discussed further in the next sections.

Reduction groups
--------------
----------------


In setups with tensor, data, or pipeline parallelism, some tensors are distributed across multiple GPUs, requiring a reduction operation to compute statistics for these tensors.
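A short sketch of selecting such a reduction group (the group construction is a placeholder; use the process groups your framework already builds):

.. code-block:: python

    import torch.distributed as dist
    import nvdlfw_inspect.api as debug_api  # assumed import path

    # Reduce activation/gradient statistics across data-parallel ranks only.
    dp_group = dist.new_group(ranks=[0, 1])
    debug_api.set_tensor_reduction_group(dp_group)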

@@ -65,15 +67,16 @@ Below, we illustrate configurations for a 4-node setup with tensor parallelism s


Microbatching
-----------
-------------


Let's dive into how statistics collection works with microbatching. By microbatching, we mean invoking multiple ``forward()`` calls for each ``debug_api.step()``. The behavior is as follows:

- For weight tensors, the stats remain the same for each microbatch because the weight does not change.
- For other tensors, the stats are accumulated.
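A short sketch of the accumulation pattern described above (``debug_api`` as in the setup section; the layer and loop sizes are placeholders):

.. code-block:: python

    import torch
    import transformer_engine.pytorch as te
    import nvdlfw_inspect.api as debug_api  # assumed import path

    model = te.Linear(64, 64)
    num_steps, num_microbatches = 2, 4

    for _ in range(num_steps):
        for _ in range(num_microbatches):
            x = torch.randn(8, 64, device="cuda")
            model(x).sum().backward()
        # Weight stats are identical across the microbatches above; activation
        # and gradient stats were accumulated and are logged at step().
        debug_api.step()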

Logging to files and TensorBoard
------------------------------
--------------------------------

In a single-node setup with ``default_logging_enabled=True``, all logs are saved by default to ``log_dir/nvdlfw_inspect_statistics_logs/nvdlfw_inspect_globalrank-0.log``. In multi-GPU training, each node writes its reduced statistics to its unique file, named ``log_dir/nvdlfw_inspect_statistics_logs/nvdlfw_inspect_globalrank-i.log`` for rank i. Because these logs contain reduced statistics, the logged values are identical for all nodes within a reduction group.

3 changes: 2 additions & 1 deletion docs/debug/api.rst
@@ -2,8 +2,9 @@
Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.

See LICENSE for license information.

API
============
===

.. toctree::
:caption: Precision debug tools API