2 changes: 1 addition & 1 deletion docs/debug/1_getting_started.rst
@@ -15,7 +15,7 @@ Transformer Engine provides a set of precision debug tools which allow you to easily:
- log the statistics for each of the tensors in every matrix multiply (GEMM) operation,
- run selected GEMMs in higher precision,
- run current scaling - with one scaling factor per tensor - for particular GEMMs,
- - test new precisions and integrate them with FP8 training,
+ - test new precisions and integrate them with quantized training (FP8, NVFP4, etc.),
- ... and many more.

There are 4 things one needs to do to use Transformer Engine debug features:
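
The four setup steps themselves sit outside this hunk, but the overall wiring looks roughly like the sketch below. This is a hedged sketch, not text from the PR: the config keys mirror the YAML examples later in this diff, while `debug_api.initialize` and its arguments are assumptions based on how nvidia-dlframework-inspect is used in this PR's tests (which wrap initialization in a `debug_session()` helper).

import nvdlfw_inspect.api as debug_api

# Write a config that names which layers get which debug features.
# The format mirrors the YAML examples in this PR's feature docstrings.
with open("debug_config.yaml", "w") as f:
    f.write(
        """
my_debug_setup:
  enabled: True
  layers:
    layer_types: [fc1]
  transformer_engine:
    DisableQuantizationGEMM:
      enabled: True
      gemms: [dgrad, wgrad]
"""
    )

# Assumed entry point; feature_dirs points at transformer_engine/debug/features.
debug_api.initialize(config_file="debug_config.yaml", feature_dirs=["<te-debug-features-dir>"])

# ... run training, calling debug_api.step() once per iteration ...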
7 changes: 5 additions & 2 deletions docs/debug/3_api_features.rst
@@ -8,7 +8,10 @@ Debug features

.. autoapiclass:: transformer_engine.debug.features.log_tensor_stats.LogTensorStats
.. autoapiclass:: transformer_engine.debug.features.log_fp8_tensor_stats.LogFp8TensorStats
-.. autoapiclass:: transformer_engine.debug.features.disable_fp8_gemm.DisableFP8GEMM
-.. autoapiclass:: transformer_engine.debug.features.disable_fp8_layer.DisableFP8Layer
+.. autoapiclass:: transformer_engine.debug.features.log_nvfp4_tensor_stats.LogNvfp4TensorStats
+.. autoapiclass:: transformer_engine.debug.features.disable_quantization_gemm.DisableQuantizationGEMM
+.. autoapiclass:: transformer_engine.debug.features.disable_quantization_layer.DisableQuantizationLayer
.. autoapiclass:: transformer_engine.debug.features.per_tensor_scaling.PerTensorScaling
.. autoapiclass:: transformer_engine.debug.features.fake_quant.FakeQuant
+.. autoapiclass:: transformer_engine.debug.features.disable_fp8_gemm.DisableFP8GEMM
+.. autoapiclass:: transformer_engine.debug.features.disable_fp8_layer.DisableFP8Layer
120 changes: 120 additions & 0 deletions tests/pytorch/debug/test_log.py
@@ -15,6 +15,7 @@
    is_fp8_available,
    is_mxfp8_available,
    is_fp8_block_scaling_available,
+   is_nvfp4_available,
)
from transformer_engine.pytorch.quantization import RecipeState
from transformer_engine.debug.pytorch.debug_state import TEDebugState
@@ -25,6 +26,7 @@
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = is_fp8_block_scaling_available(
    return_reason=True
)
+nvfp4_available, reason_for_no_nvfp4 = is_nvfp4_available(return_reason=True)

LOG_QUANTIZED_CONFIG_BASE = """
log:
@@ -256,3 +258,121 @@ def test_log_every_3_or_5_layers(layer, configs_dir, feature_dirs):

    debug_api.end_debug()
    TEDebugState._reset()


# NVFP4 tests
LOG_NVFP4_CONFIG_BASE = """
log:
  layers:
    layer_name_regex_pattern: .*
  enabled:
    True
  transformer_engine:
    LogNvfp4TensorStats:
      enabled: True
      stats: [
        {stats}
      ]
      tensors: [activation, gradient, weight]
      freq: 2
      start_step: 0
      end_step: 10
"""


def test_nvfp4_numeric(feature_dirs):
    """Test that NVFP4 underflows% and MSE stats are computed correctly with known values."""
    if not nvfp4_available:
        pytest.skip(reason_for_no_nvfp4)

    log_nvfp4_config = LOG_NVFP4_CONFIG_BASE.format(stats="underflows%, mse")

    with debug_session(log_nvfp4_config, feature_dirs) as log_dir:
        from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer
        from transformer_engine.pytorch.quantization import RecipeState

        recipe_state = RecipeState.create(
            recipe.NVFP4BlockScaling(),
            mode="forward",
            num_quantizers=3,
        )

        # Create test tensor with known distribution
        torch.manual_seed(42)
        tensor = torch.randn(128, 128, dtype=torch.bfloat16).cuda()
        # Add some small values that should underflow to zero in FP4
        tensor[0, :16] = 0.0001

        quantizer = recipe_state.make_quantizers()[0]
        quantized_tensor = quantizer(tensor)

        debug_api.transformer_engine.inspect_tensor(
            layer_name="test_layer",
            tensor_name="activation",
            iteration=0,
            tp_group=None,
            tensor=tensor,
            quantizer=quantizer,
            rowwise_quantized_tensor=quantized_tensor,
            columnwise_quantized_tensor=quantized_tensor,
        )
        debug_api.step()

        dequantized_tensor = quantized_tensor.dequantize()
        output = read_log(log_dir)

        # Validate both stats are present
        assert "nvfp4_underflows%" in output, "underflows% stat missing"
        assert "nvfp4_mse" in output, "mse stat missing"

        # Extract values and validate numerics
        underflows_value = None
        mse_value = None

        for line in output.splitlines():
            if "nvfp4_underflows%" in line and "value=" in line:
                underflows_value = float(line.split("value=")[1].split()[0])
            if "nvfp4_mse" in line and "value=" in line:
                mse_value = float(line.split("value=")[1].split()[0])

        # Compute expected underflows: non-zero elements that became zero after quantization
        orig_nonzero_mask = tensor != 0
        dequant_zero_mask = dequantized_tensor == 0
        expected_underflows = (
            (orig_nonzero_mask & dequant_zero_mask).sum().float() / tensor.numel() * 100
        )

        # Allow some tolerance
        assert underflows_value == pytest.approx(expected_underflows.cpu().item(), abs=1e-4)

        # Compute expected MSE
        expected_mse = torch.nn.functional.mse_loss(
            dequantized_tensor.float(), tensor.float(), reduction="mean"
        )

        assert mse_value == pytest.approx(expected_mse.cpu().item(), abs=1e-4)


def test_fp8_stats_allows_nvfp4_with_recipe_prefix(feature_dirs):
    """Test that LogFp8TensorStats allows recipe-prefixed stats with NVFP4 for what-if analysis."""
    if not nvfp4_available:
        pytest.skip(reason_for_no_nvfp4)

    # Use recipe-prefixed stat with NVFP4 - should work (computes MXFP8 separately)
    log_fp8_config = LOG_QUANTIZED_CONFIG_BASE.format(stats="mxfp8_mse")

    with debug_session(log_fp8_config, feature_dirs) as log_dir:
        model = te.Linear(128, 128, params_dtype=torch.bfloat16)
        inp = torch.randn(128, 128, dtype=torch.bfloat16).cuda()

        # Should work - recipe-prefixed stats compute MXFP8 separately for comparison
        for _ in range(2):
            with te.autocast(recipe=recipe.NVFP4BlockScaling()):
                output = model(inp)
            loss = output.sum()
            loss.backward()
            debug_api.step()

        output = read_log(log_dir)
        # Should have logged MXFP8 MSE stat (what-if scenario)
        assert "mxfp8_mse" in output
34 changes: 14 additions & 20 deletions transformer_engine/debug/features/disable_fp8_gemm.py
@@ -2,17 +2,26 @@
#
# See LICENSE for license information.

"""DisableFP8GEMM Feature support for nvidia-dlframework-inspect"""
"""DisableFP8GEMM Feature support for nvidia-dlframework-inspect

from nvdlfw_inspect.registry import Registry, api_method
from transformer_engine.debug.features.api import TEConfigAPIMapper
DEPRECATED: This is a backward compatibility alias for DisableQuantizationGEMM.
New code should use DisableQuantizationGEMM instead, which works with all quantization formats.
"""

from nvdlfw_inspect.registry import Registry
from transformer_engine.debug.features.disable_quantization_gemm import DisableQuantizationGEMM


@Registry.register_feature(namespace="transformer_engine")
-class DisableFP8GEMM(TEConfigAPIMapper):
+class DisableFP8GEMM(DisableQuantizationGEMM):
"""
GEMM operations are executed in higher precision, even when FP8 autocast is enabled.

.. deprecated::
Use :class:`DisableQuantizationGEMM` instead. This class is maintained for
backward compatibility only. DisableQuantizationGEMM works with all quantization
formats (FP8, NVFP4, etc.), not just FP8.

Parameters
----------

@@ -32,22 +41,7 @@ class DisableFP8GEMM(TEConfigAPIMapper):
            layers:
                layer_types: [fc1]
            transformer_engine:
-               DisableFP8GEMM:
+               DisableFP8GEMM:  # Deprecated: use DisableQuantizationGEMM
                    enabled: True
                    gemms: [dgrad, wgrad]
    """

-    @api_method
-    def fp8_gemm_enabled(
-        self, config, layer_name: str, gemm: str, iteration: int
-    ):  # pylint: disable=unused-argument
-        """API call responsible for choice between high-precision and FP8 GEMM execution."""
-
-        for key in config:
-            if key != "gemm":
-                raise ValueError(f'[NVTORCH INSPECT ERROR] Unexpected key in config: "{key}".')
-
-        # If this feature is invoked, then FP8 GEMM is disabled.
-        # If not, then default behaviour in TransformerEngineAPI
-        # is that fp8_gemm() API call returns True.
-        return False, iteration + 1
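
For existing configs, migrating off the deprecated aliases is a key rename; the parameters are unchanged. A sketch:

# Before (deprecated alias)
transformer_engine:
  DisableFP8GEMM:
    enabled: True
    gemms: [dgrad, wgrad]

# After
transformer_engine:
  DisableQuantizationGEMM:
    enabled: True
    gemms: [dgrad, wgrad]

The same rename applies to DisableFP8Layer, which becomes DisableQuantizationLayer.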
53 changes: 17 additions & 36 deletions transformer_engine/debug/features/disable_fp8_layer.py
@@ -2,54 +2,35 @@
#
# See LICENSE for license information.

"""DisableFP8Layer Feature support for nvidia-dlframework-inspect"""
"""DisableFP8Layer Feature support for nvidia-dlframework-inspect
import nvdlfw_inspect.api as debug_api
from nvdlfw_inspect.registry import Registry, api_method
DEPRECATED: This is a backward compatibility alias for DisableQuantizationLayer.
New code should use DisableQuantizationLayer instead, which works with all quantization formats.
"""

from nvdlfw_inspect.registry import Registry
from transformer_engine.debug.features.disable_quantization_layer import DisableQuantizationLayer


@Registry.register_feature(namespace="transformer_engine")
-class DisableFP8Layer:
+class DisableFP8Layer(DisableQuantizationLayer):

[Review comment from a collaborator]: It may be worth raising a deprecation warning in the constructor or something. DisableFP8GEMM would also benefit from this.
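
A minimal sketch of the reviewer's suggestion, not part of this diff; it assumes the feature class is instantiated when the registry constructs it, and reuses the imports already at the top of this file:

import warnings


@Registry.register_feature(namespace="transformer_engine")
class DisableFP8Layer(DisableQuantizationLayer):
    """Deprecated alias for DisableQuantizationLayer."""

    def __init__(self, *args, **kwargs):
        # Warn once at construction time so existing configs keep working
        # while nudging users toward the new feature name.
        warnings.warn(
            "DisableFP8Layer is deprecated; use DisableQuantizationLayer instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        super().__init__(*args, **kwargs)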

"""
Disables all FP8 GEMMs in the layer.
.. deprecated::
Use :class:`DisableQuantizationLayer` instead. This class is maintained for
backward compatibility only. DisableQuantizationLayer works with all quantization
formats (FP8, NVFP4, etc.), not just FP8.
Example
-------
.. code-block:: yaml
example_disable_fp8_layer:
enabled: True
layers:
layer_types: [fc1]
transformer_engine:
DisableFP8Layer:
enabled: True
layers:
layer_types: [fc1]
transformer_engine:
DisableFP8Layer: # Deprecated: use DisableQuantizationLayer
enabled: True
"""

-    @api_method
-    def fp8_gemm_enabled(
-        self, config, layer_name: str, gemm: str, iteration: int
-    ):  # pylint: disable=unused-argument
-        """API call responsible for selecting between high-precision and FP8 GEMM execution."""
-        for key in config:
-            if key not in ["enabled", "gemm"]:
-                raise ValueError(f'[NVTORCH INSPECT ERROR] Unexpected key in config: "{key}".')
-        # If FP8 training, disable FP8 for the selected layers if this feature is enabled in config.
-        debug_api.log_message("FP8 Disabled", layer_name)
-
-        # If this feature is invoked, then FP8 GEMM is disabled.
-        # If not, then default behavior in TransformerEngineAPI
-        # is that fp8_gemm() API call returns True.
-        return False, iteration + 1
-
-    def parse_config_and_api(self, config, **_kwargs):
-        """Determines whether to run the API
-
-        DisableFP8Layer is the only feature provided by the Transformer Engine
-        which does not inherit from TEConfigAPIMapper - this mapper is primarily responsible for
-        parsing gemms and tensors fields from the config, which are not needed for this feature.
-
-        Explanation of the parse_config_and_api can be found in the
-        nvidia-dlframework-inspect documentation.
-        """
-        return config["enabled"], None
59 changes: 59 additions & 0 deletions transformer_engine/debug/features/disable_quantization_gemm.py
@@ -0,0 +1,59 @@
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""DisableQuantizationGEMM Feature support for nvidia-dlframework-inspect"""

from nvdlfw_inspect.registry import Registry, api_method
from transformer_engine.debug.features.api import TEConfigAPIMapper


@Registry.register_feature(namespace="transformer_engine")
class DisableQuantizationGEMM(TEConfigAPIMapper):
    """
    Disables specific GEMM operations from using quantization, forcing high-precision execution.

    Works with any quantization format (FP8, NVFP4, etc.).

    Parameters
    ----------

    gemms: List[str]
        list of gemms to disable quantization for

        - fprop
        - dgrad
        - wgrad

    Example
    -------

    .. code-block:: yaml

        example_disable_quantization_gemm:
            enabled: True
            layers:
                layer_types: [fc1]
            transformer_engine:
                DisableQuantizationGEMM:
                    enabled: True
                    gemms: [dgrad, wgrad]
    """

    @api_method
    def fp8_gemm_enabled(
        self, config, layer_name: str, gemm: str, iteration: int
    ):  # pylint: disable=unused-argument
        """API call responsible for choice between high-precision and quantized GEMM execution.

        Note: Method name kept as 'fp8_gemm_enabled' for backward compatibility with the debug API,
        but it applies to all quantization formats (FP8, NVFP4, etc.).
        """

        for key in config:
            if key != "gemm":
                raise ValueError(f'[NVTORCH INSPECT ERROR] Unexpected key in config: "{key}".')

        # If this feature is invoked, then quantized GEMM is disabled (returns to high precision).
        # If not, then default behavior in TransformerEngineAPI
        # is that fp8_gemm() API call returns True.
        return False, iteration + 1
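
A hedged end-to-end sketch of exercising this feature, not code from the PR: te.autocast and the recipe class follow this PR's tests, while debug_api.initialize is an assumed entry point (the tests use a debug_session() helper) and the config file is assumed to hold the YAML example from the docstring above:

import torch
import nvdlfw_inspect.api as debug_api
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

debug_api.initialize(config_file="disable_gemm.yaml", feature_dirs=["<te-debug-features-dir>"])

model = te.Linear(128, 128, params_dtype=torch.bfloat16)
inp = torch.randn(128, 128, dtype=torch.bfloat16).cuda()

with te.autocast(recipe=recipe.NVFP4BlockScaling()):
    out = model(inp)   # fprop GEMM stays quantized (not listed under gemms)
out.sum().backward()   # dgrad/wgrad fall back to high precision per the config
debug_api.step()
debug_api.end_debug()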