From b463cd749bceecc96442557f55d167956490e6c4 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Thu, 23 Oct 2025 15:31:34 +0200 Subject: [PATCH 1/8] init Signed-off-by: Pawel Gadzinski --- tests/pytorch/debug/test_log.py | 131 ++++++++++ .../debug/features/disable_fp8_gemm.py | 35 ++- .../debug/features/disable_fp8_layer.py | 56 ++--- .../features/disable_quantization_gemm.py | 60 +++++ .../features/disable_quantization_layer.py | 62 +++++ .../debug/features/log_fp8_tensor_stats.py | 21 ++ .../debug/features/log_nvfp4_tensor_stats.py | 229 ++++++++++++++++++ .../debug/features/utils/stats_computation.py | 62 +++++ .../debug/pytorch/debug_quantization.py | 50 ++-- 9 files changed, 624 insertions(+), 82 deletions(-) create mode 100644 transformer_engine/debug/features/disable_quantization_gemm.py create mode 100644 transformer_engine/debug/features/disable_quantization_layer.py create mode 100644 transformer_engine/debug/features/log_nvfp4_tensor_stats.py diff --git a/tests/pytorch/debug/test_log.py b/tests/pytorch/debug/test_log.py index e9d074821d..ead4e4079b 100644 --- a/tests/pytorch/debug/test_log.py +++ b/tests/pytorch/debug/test_log.py @@ -15,6 +15,7 @@ is_fp8_available, is_mxfp8_available, is_fp8_block_scaling_available, + is_nvfp4_available, ) from transformer_engine.pytorch.quantization import RecipeState from transformer_engine.debug.pytorch.debug_state import TEDebugState @@ -25,6 +26,7 @@ fp8_block_scaling_available, reason_for_no_fp8_block_scaling = is_fp8_block_scaling_available( return_reason=True ) +nvfp4_available, reason_for_no_nvfp4 = is_nvfp4_available(return_reason=True) LOG_QUANTIZED_CONFIG_BASE = """ log: @@ -256,3 +258,132 @@ def test_log_every_3_or_5_layers(layer, configs_dir, feature_dirs): debug_api.end_debug() TEDebugState._reset() + + +# NVFP4 tests +LOG_NVFP4_CONFIG_BASE = """ +log: + layers: + layer_name_regex_pattern: .* + enabled: + True + transformer_engine: + LogNvfp4TensorStats: + enabled: True + stats: [ + {stats} + ] + tensors: [activation, gradient, weight] + freq: 2 + start_step: 0 + end_step: 10 +""" + + +def test_nvfp4_numeric(feature_dirs): + """Test that NVFP4 underflows% and MSE stats are computed correctly with known values.""" + if not nvfp4_available: + pytest.skip(reason_for_no_nvfp4) + + log_nvfp4_config = LOG_NVFP4_CONFIG_BASE.format(stats="underflows%, mse") + + with debug_session(log_nvfp4_config, feature_dirs) as log_dir: + from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer + from transformer_engine.pytorch.quantization import RecipeState + + recipe_state = RecipeState.create( + recipe.NVFP4BlockScaling(), + mode="forward", + num_quantizers=3, + ) + + # Create test tensor with known distribution + torch.manual_seed(42) + tensor = torch.randn(128, 128, dtype=torch.bfloat16).cuda() + # Add some small values that should underflow to zero in FP4 + tensor[0, :16] = 0.0001 + + quantizer = recipe_state.make_quantizers()[0] + quantized_tensor = quantizer(tensor) + + debug_api.transformer_engine.inspect_tensor( + layer_name="test_layer", + tensor_name="activation", + iteration=0, + tp_group=None, + tensor=tensor, + quantizer=quantizer, + rowwise_quantized_tensor=quantized_tensor, + columnwise_quantized_tensor=quantized_tensor, + ) + debug_api.step() + + dequantized_tensor = quantized_tensor.dequantize() + output = read_log(log_dir) + + # Validate both stats are present + assert "nvfp4_underflows%" in output, "underflows% stat missing" + assert "nvfp4_mse" in output, "mse stat missing" + + # Extract values and validate 
numerics + underflows_value = None + mse_value = None + + for line in output.splitlines(): + if "nvfp4_underflows%" in line and "value=" in line: + underflows_value = float(line.split("value=")[1].split()[0]) + if "nvfp4_mse" in line and "value=" in line: + mse_value = float(line.split("value=")[1].split()[0]) + + # Validate underflows% + assert underflows_value is not None, "Could not extract underflows% value" + assert underflows_value >= 0, f"Underflows should be non-negative, got {underflows_value}" + assert underflows_value <= 100, f"Underflows% should be <= 100, got {underflows_value}" + + # Compute expected underflows: non-zero elements that became zero after quantization + orig_nonzero_mask = (tensor != 0) + dequant_zero_mask = (dequantized_tensor == 0) + expected_underflows = (orig_nonzero_mask & dequant_zero_mask).sum().float() / tensor.numel() * 100 + + # Allow some tolerance + assert abs(underflows_value - expected_underflows.item()) < 1.0, \ + f"Underflows mismatch: got {underflows_value}, expected {expected_underflows.item()}" + + # Validate MSE + assert mse_value is not None, "Could not extract MSE value" + assert mse_value >= 0, f"MSE should be non-negative, got {mse_value}" + + # Compute expected MSE + expected_mse = torch.nn.functional.mse_loss( + dequantized_tensor.float(), + tensor.float(), + reduction="mean" + ) + + assert mse_value == pytest.approx(expected_mse.cpu().item(), abs=1e-4), \ + f"MSE mismatch: got {mse_value}, expected {expected_mse.cpu().item()}" + + +def test_fp8_stats_allows_nvfp4_with_recipe_prefix(feature_dirs): + """Test that LogFp8TensorStats allows recipe-prefixed stats with NVFP4 for what-if analysis.""" + if not nvfp4_available: + pytest.skip(reason_for_no_nvfp4) + + # Use recipe-prefixed stat with NVFP4 - should work (computes MXFP8 separately) + log_fp8_config = LOG_QUANTIZED_CONFIG_BASE.format(stats="mxfp8_mse") + + with debug_session(log_fp8_config, feature_dirs) as log_dir: + model = te.Linear(128, 128, params_dtype=torch.bfloat16) + inp = torch.randn(128, 128, dtype=torch.bfloat16).cuda() + + # Should work - recipe-prefixed stats compute MXFP8 separately for comparison + for _ in range(2): + with te.autocast(recipe=recipe.NVFP4BlockScaling()): + output = model(inp) + loss = output.sum() + loss.backward() + debug_api.step() + + output = read_log(log_dir) + # Should have logged MXFP8 MSE stat (what-if scenario) + assert "mxfp8_mse" in output diff --git a/transformer_engine/debug/features/disable_fp8_gemm.py b/transformer_engine/debug/features/disable_fp8_gemm.py index ef2cccbe4a..c3c04fe466 100644 --- a/transformer_engine/debug/features/disable_fp8_gemm.py +++ b/transformer_engine/debug/features/disable_fp8_gemm.py @@ -2,16 +2,25 @@ # # See LICENSE for license information. -"""DisableFP8GEMM Feature support for nvidia-dlframework-inspect""" +"""DisableFP8GEMM Feature support for nvidia-dlframework-inspect -from nvdlfw_inspect.registry import Registry, api_method -from transformer_engine.debug.features.api import TEConfigAPIMapper +DEPRECATED: This is a backward compatibility alias for DisableQuantizationGEMM. +New code should use DisableQuantizationGEMM instead, which works with all quantization formats. 
+""" + +from nvdlfw_inspect.registry import Registry +from transformer_engine.debug.features.disable_quantization_gemm import DisableQuantizationGEMM @Registry.register_feature(namespace="transformer_engine") -class DisableFP8GEMM(TEConfigAPIMapper): +class DisableFP8GEMM(DisableQuantizationGEMM): """ GEMM operations are executed in higher precision, even when FP8 autocast is enabled. + + .. deprecated:: + Use :class:`DisableQuantizationGEMM` instead. This class is maintained for + backward compatibility only. DisableQuantizationGEMM works with all quantization + formats (FP8, NVFP4, etc.), not just FP8. Parameters ---------- @@ -32,22 +41,8 @@ class DisableFP8GEMM(TEConfigAPIMapper): layers: layer_types: [fc1] transformer_engine: - DisableFP8GEMM: + DisableFP8GEMM: # Deprecated: use DisableQuantizationGEMM enabled: True gemms: [dgrad, wgrad] """ - - @api_method - def fp8_gemm_enabled( - self, config, layer_name: str, gemm: str, iteration: int - ): # pylint: disable=unused-argument - """API call responsible for choice between high-precision and FP8 GEMM execution.""" - - for key in config: - if key != "gemm": - raise ValueError(f'[NVTORCH INSPECT ERROR] Unexpected key in config: "{key}".') - - # If this feature is invoked, then FP8 GEMM is disabled. - # If not, then default behaviour in TransformerEngineAPI - # is that fp8_gemm() API call returns True. - return False, iteration + 1 + pass # Inherits all functionality from DisableQuantizationGEMM diff --git a/transformer_engine/debug/features/disable_fp8_layer.py b/transformer_engine/debug/features/disable_fp8_layer.py index c3b0e4cca9..3eebb97b6e 100644 --- a/transformer_engine/debug/features/disable_fp8_layer.py +++ b/transformer_engine/debug/features/disable_fp8_layer.py @@ -2,17 +2,25 @@ # # See LICENSE for license information. -"""DisableFP8Layer Feature support for nvidia-dlframework-inspect""" +"""DisableFP8Layer Feature support for nvidia-dlframework-inspect + +DEPRECATED: This is a backward compatibility alias for DisableQuantizationLayer. +New code should use DisableQuantizationLayer instead, which works with all quantization formats. +""" -import nvdlfw_inspect.api as debug_api -from nvdlfw_inspect.registry import Registry, api_method +from nvdlfw_inspect.registry import Registry +from transformer_engine.debug.features.disable_quantization_layer import DisableQuantizationLayer @Registry.register_feature(namespace="transformer_engine") -class DisableFP8Layer: +class DisableFP8Layer(DisableQuantizationLayer): """ Disables all FP8 GEMMs in the layer. - + + .. deprecated:: + Use :class:`DisableQuantizationLayer` instead. This class is maintained for + backward compatibility only. DisableQuantizationLayer works with all quantization + formats (FP8, NVFP4, etc.), not just FP8. 
Example ------- @@ -20,36 +28,10 @@ class DisableFP8Layer: example_disable_fp8_layer: enabled: True - layers: - layer_types: [fc1] - transformer_engine: - DisableFP8Layer: - enabled: True + layers: + layer_types: [fc1] + transformer_engine: + DisableFP8Layer: # Deprecated: use DisableQuantizationLayer + enabled: True """ - - @api_method - def fp8_gemm_enabled( - self, config, layer_name: str, gemm: str, iteration: int - ): # pylint: disable=unused-argument - """API call responsible for selecting between high-precision and FP8 GEMM execution.""" - for key in config: - if key not in ["enabled", "gemm"]: - raise ValueError(f'[NVTORCH INSPECT ERROR] Unexpected key in config: "{key}".') - # If FP8 training, disable FP8 for the selected layers if this feature is enabled in config. - debug_api.log_message("FP8 Disabled", layer_name) - - # If this feature is invoked, then FP8 GEMM is disabled. - # If not, then default behavior in TransformerEngineAPI - # is that fp8_gemm() API call returns True. - return False, iteration + 1 - - def parse_config_and_api(self, config, **_kwargs): - """Determines whether to run the API - DisableFP8Layer is the only feature provided by the Transformer Engine - which does not inherit from TEConfigAPIMapper - this mapper is primarly responsible for - parsing gemms and tensors fields from the config, which are not needed for this feature. - - Explanation of the parse_config_and_api can be found in the - nvidia-dlframework-inspect documentation. - """ - return config["enabled"], None + pass # Inherits all functionality from DisableQuantizationLayer diff --git a/transformer_engine/debug/features/disable_quantization_gemm.py b/transformer_engine/debug/features/disable_quantization_gemm.py new file mode 100644 index 0000000000..4caf976e0a --- /dev/null +++ b/transformer_engine/debug/features/disable_quantization_gemm.py @@ -0,0 +1,60 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""DisableQuantizationGEMM Feature support for nvidia-dlframework-inspect""" + +from nvdlfw_inspect.registry import Registry, api_method +from transformer_engine.debug.features.api import TEConfigAPIMapper + + +@Registry.register_feature(namespace="transformer_engine") +class DisableQuantizationGEMM(TEConfigAPIMapper): + """ + Disables specific GEMM operations from using quantization, forcing high-precision execution. + + Works with any quantization format (FP8, NVFP4, etc.). + + Parameters + ---------- + + gemms: List[str] + list of gemms to disable quantization for + + - fprop + - dgrad + - wgrad + + Example + ------- + .. code-block:: yaml + + example_disable_quantization_gemm: + enabled: True + layers: + layer_types: [fc1] + transformer_engine: + DisableQuantizationGEMM: + enabled: True + gemms: [dgrad, wgrad] + """ + + @api_method + def fp8_gemm_enabled( + self, config, layer_name: str, gemm: str, iteration: int + ): # pylint: disable=unused-argument + """API call responsible for choice between high-precision and quantized GEMM execution. + + Note: Method name kept as 'fp8_gemm_enabled' for backward compatibility with the debug API, + but it applies to all quantization formats (FP8, NVFP4, etc.). + """ + + for key in config: + if key != "gemm": + raise ValueError(f'[NVTORCH INSPECT ERROR] Unexpected key in config: "{key}".') + + # If this feature is invoked, then quantized GEMM is disabled (returns to high precision). 
+ # If not, then default behavior in TransformerEngineAPI + # is that fp8_gemm() API call returns True. + return False, iteration + 1 + diff --git a/transformer_engine/debug/features/disable_quantization_layer.py b/transformer_engine/debug/features/disable_quantization_layer.py new file mode 100644 index 0000000000..bd465764f9 --- /dev/null +++ b/transformer_engine/debug/features/disable_quantization_layer.py @@ -0,0 +1,62 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""DisableQuantizationLayer Feature support for nvidia-dlframework-inspect""" + +import nvdlfw_inspect.api as debug_api +from nvdlfw_inspect.registry import Registry, api_method + + +@Registry.register_feature(namespace="transformer_engine") +class DisableQuantizationLayer: + """ + Disables all quantized GEMMs in the layer, forcing high-precision execution. + + Works with any quantization format (FP8, NVFP4, etc.). + + Example + ------- + .. code-block:: yaml + + example_disable_quantization_layer: + enabled: True + layers: + layer_types: [fc1] + transformer_engine: + DisableQuantizationLayer: + enabled: True + """ + + @api_method + def fp8_gemm_enabled( + self, config, layer_name: str, gemm: str, iteration: int + ): # pylint: disable=unused-argument + """API call responsible for selecting between high-precision and quantized GEMM execution. + + Note: Method name kept as 'fp8_gemm_enabled' for backward compatibility with the debug API, + but it applies to all quantization formats (FP8, NVFP4, etc.). + """ + for key in config: + if key not in ["enabled", "gemm"]: + raise ValueError(f'[NVTORCH INSPECT ERROR] Unexpected key in config: "{key}".') + # If quantized training, disable quantization for the selected layers if this feature is enabled. + debug_api.log_message("Quantization Disabled", layer_name) + + # If this feature is invoked, then quantized GEMM is disabled (returns to high precision). + # If not, then default behavior in TransformerEngineAPI + # is that fp8_gemm() API call returns True. + return False, iteration + 1 + + def parse_config_and_api(self, config, **_kwargs): + """Determines whether to run the API. + + DisableQuantizationLayer is the only feature provided by the Transformer Engine + which does not inherit from TEConfigAPIMapper - this mapper is primarily responsible for + parsing gemms and tensors fields from the config, which are not needed for this feature. + + Explanation of the parse_config_and_api can be found in the + nvidia-dlframework-inspect documentation. 
+ """ + return config["enabled"], None + diff --git a/transformer_engine/debug/features/log_fp8_tensor_stats.py b/transformer_engine/debug/features/log_fp8_tensor_stats.py index d09fb10579..dc89a0c103 100644 --- a/transformer_engine/debug/features/log_fp8_tensor_stats.py +++ b/transformer_engine/debug/features/log_fp8_tensor_stats.py @@ -22,6 +22,14 @@ ) from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockQuantizer + +try: + from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer + _nvfp4_available = True +except ImportError: + _nvfp4_available = False + NVFP4Quantizer = None + from transformer_engine.debug.features.utils import get_reduction_params, next_enabled_iter @@ -39,6 +47,8 @@ def _get_recipe_name(quantizer: Optional[Quantizer]): return "mxfp8" if isinstance(quantizer, Float8BlockQuantizer): return "fp8_block_scaling" + if _nvfp4_available and isinstance(quantizer, NVFP4Quantizer): + return "nvfp4" raise ValueError(f"Unsupported quantizer type: {type(quantizer)}") @@ -164,6 +174,16 @@ def check_if_stat_is_supported(self, stat: str, current_recipe: str): if recipe_from_stat != "" and recipe_from_stat not in ALL_RECIPE_NAMES: raise ValueError(f"Stat {stat} contains an unsupported recipe name: {recipe_from_stat}") + # Block any NVFP4 stats in LogFp8TensorStats (FP8-specific logic won't work) + # But allow recipe-prefixed FP8 stats like "mxfp8_underflows%" even with NVFP4 quantizer + if recipe_from_stat == "nvfp4": + raise ValueError( + f"[NVTORCH INSPECT ERROR] Cannot compute NVFP4 stats '{stat}' in LogFp8TensorStats. " + f"FP8-specific statistics do not work with NVFP4. " + f"Use LogNvfp4TensorStats for NVFP4-specific stats, or use FP8 recipe-prefixed stats " + f"(e.g., 'mxfp8_underflows%', 'fp8_block_scaling_mse') for what-if FP8 comparisons." + ) + if recipe_from_stat in ["fp8_delayed_scaling", "fp8_current_scaling"] and columnwise: raise ValueError( f"Stat {stat} is not supported. Columnwise tensor statistics are not supported for" @@ -282,6 +302,7 @@ def inspect_tensor( ), "[NVTORCH INSPECT ERROR] LogFp8TensorStats cannot be run without low-precision recipe." quantized_tensor = rowwise_quantized_tensor + assert isinstance( quantized_tensor, QuantizedTensor ), "[NVTORCH INSPECT ERROR] LogFp8TensorStats quantized_tensor must be a QuantizedTensor." diff --git a/transformer_engine/debug/features/log_nvfp4_tensor_stats.py b/transformer_engine/debug/features/log_nvfp4_tensor_stats.py new file mode 100644 index 0000000000..773a0c0df1 --- /dev/null +++ b/transformer_engine/debug/features/log_nvfp4_tensor_stats.py @@ -0,0 +1,229 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+ +"""LogNvfp4TensorStats Feature support for nvidia-dlframework-inspect""" + +from typing import Dict, Optional +from contextlib import contextmanager + +import torch +import nvdlfw_inspect.api as debug_api + +from nvdlfw_inspect.debug_features.log_tensor_stats import LogTensorStats as BaseLogTensorStats +from nvdlfw_inspect.registry import Registry, api_method + +from transformer_engine.debug.features.utils.stats_buffer import STATS_BUFFERS +from transformer_engine.pytorch.tensor import Quantizer, QuantizedTensor +from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer +from transformer_engine.debug.features.utils import get_reduction_params, next_enabled_iter + + +@Registry.register_feature(namespace="transformer_engine") +class LogNvfp4TensorStats(BaseLogTensorStats): + """ + Logs statistics of NVFP4 quantized tensors. + + This feature is specifically designed for NVFP4 quantization and provides: + - underflows%: percentage of non-zero elements clipped to 0 after quantization (computed from packed FP4 data) + - mse: mean squared error between original and quantized-dequantized tensor + + In distributed runs each rank first computes its local statistics; the values + are gathered the next time `debug_api.step()` is called. Remember to call + `debug_api.step()` every training step so the logs are flushed. + + The feature is micro-batch aware: if several forward/backward passes occur + between successive `debug_api.step()` calls, statistics are accumulated for all + tensors except weights. + + Collecting NVFP4 statistics is expensive. Choosing a larger `freq` reduces the + overhead, and if the feature is skipped for a step the additional cost is + minimal. When no other debug feature is active, the layer runs at normal + Transformer Engine speed. + + Parameters + ---------- + + stats: List[str] + List of statistics to collect. Available stats: + - underflows% - percentage of non-zero elements clipped to 0 (from packed FP4 data) + - mse - mean squared error = sum((quantized_tensor - original_tensor)**2) / num_elements + + tensors/tensors_struct: List[str] + list of tensors to log + - activation, + - gradient, + - weight, + + freq: Optional[int], default = 1 + frequency of logging stats, stats will be logged every `freq` steps + start_step: Optional[int], default = None + start step of logging stats + end_step: Optional[int], default = None + end step of logging stats + start_end_list: Optional[list([int, int])], default = None + non-overlapping list of (start, end) pairs in incremental order. If not None, will ignore start_step and end_step + + Example + ------- + .. code-block:: yaml + + example_nvfp4_tensor_stat_collection: + enabled: True + layers: + layer_types: [layernorm_linear] + transformer_engine: + LogNvfp4TensorStats: + enabled: True + tensors_struct: + - tensor: activation + stats: [underflows%, mse] + freq: 1 + - tensor: gradient + stats: [underflows%, mse] + freq: 5 + start_step: 0 + end_step: 80 + """ + + def check_if_stat_is_supported(self, stat: str): + """Returns True if stat is supported, raises ValueError otherwise.""" + supported_stats = [ + "underflows%", + "mse", + ] + if stat not in supported_stats: + raise ValueError( + f"Stat {stat} is not supported for NVFP4. 
Supported stats: {supported_stats}" + ) + return True + + def get_stat_with_prefix(self, stat: str) -> str: + """Add nvfp4_ prefix to stat name for use in stats_computation.""" + return f"nvfp4_{stat}" + + @contextmanager + def update_aux_dict( + self, + aux_dict: Dict, + quantized_tensor: QuantizedTensor, + quantizer: Quantizer, + original_tensor: torch.Tensor, + ): + """ + Updates the aux_dict with the quantized tensor and additional NVFP4-specific data. + Yields the aux_dict. + """ + aux_dict = { + "nvfp4": quantized_tensor, + "original_tensor": original_tensor, + } + + try: + yield aux_dict + finally: + pass + + @api_method + def inspect_tensor_enabled( + self, config: Dict, layer_name: str, tensor_name: str, iteration: int + ): # pylint: disable=unused-argument + """API call used to determine whether to run inspect_tensor() in the forward.""" + run_current, next_iter = next_enabled_iter( + config.get("start_step", None), + config.get("end_step", None), + config.get("start_end_list", None), + config.get("freq", 1), + iteration, + ) + STATS_BUFFERS.layers_to_next_iter[layer_name] = next_iter + return run_current, next_iter + + @api_method + def inspect_tensor( + self, + config: Dict, + layer_name: str, + tensor_name: str, + iteration: int, + tp_group, + tensor: torch.Tensor, + rowwise_quantized_tensor: Optional[QuantizedTensor] = None, + columnwise_quantized_tensor: Optional[QuantizedTensor] = None, + quantizer: Optional[Quantizer] = None, + ): + """ + API call used to collect the data about the tensor after process_tensor()/quantization. + """ + assert rowwise_quantized_tensor is columnwise_quantized_tensor + assert ( + quantizer is not None + ), "[NVTORCH INSPECT ERROR] LogNvfp4TensorStats cannot be run without NVFP4 quantizer." + + quantized_tensor = rowwise_quantized_tensor + + # Ensure we're working with NVFP4 tensors + if not isinstance(quantizer, NVFP4Quantizer): + raise ValueError( + f"[NVTORCH INSPECT ERROR] LogNvfp4TensorStats requires NVFP4Quantizer, " + f"but got {type(quantizer).__name__}" + ) + + assert isinstance( + quantized_tensor, QuantizedTensor + ), "[NVTORCH INSPECT ERROR] LogNvfp4TensorStats quantized_tensor must be a QuantizedTensor." 
+ + for stat in config["stats"]: + self.check_if_stat_is_supported(stat) + + start_step = config.get("start_step", None) + end_step = config.get("end_step", None) + start_end_list = config.get("start_end_list", None) + if start_end_list is not None: + start_end_list = tuple(tuple(int(x) for x in interval) for interval in start_end_list) + + options = ( + start_step, + end_step, + start_end_list, + "nvfp4", + ) + + skip_reduction, reduction_group, reduce_within_microbatch = get_reduction_params( + tensor_name, tp_group + ) + + # Add nvfp4_ prefix to all stats for internal use + prefixed_stats = [self.get_stat_with_prefix(stat) for stat in config["stats"]] + + STATS_BUFFERS.try_add_buffer( + layer_name=layer_name, + tensor_name=tensor_name, + stats=prefixed_stats, + options=options, + reduction_group=reduction_group, + reduce_within_microbatch=reduce_within_microbatch, + ) + + with self.update_aux_dict( + aux_dict={}, + quantized_tensor=quantized_tensor, + quantizer=quantizer, + original_tensor=tensor, + ) as aux_dict: + STATS_BUFFERS.feed( + layer_name, + tensor_name, + options, + tensor, + iteration, + skip_reduction, + aux_dict=aux_dict, + ) + + debug_api.log_message( + f"Feature={self.__class__.__name__}, API=inspect_tensor: {tensor_name}", + layer_name, + extra_cachable_args=(tensor_name,), + ) + diff --git a/transformer_engine/debug/features/utils/stats_computation.py b/transformer_engine/debug/features/utils/stats_computation.py index 2fa6985acf..76d40533a9 100644 --- a/transformer_engine/debug/features/utils/stats_computation.py +++ b/transformer_engine/debug/features/utils/stats_computation.py @@ -327,3 +327,65 @@ def add_mse_stats(recipe_name: str, columnwise: bool = False): add_underflows_stats(_recipe_name, _columnwise) add_scale_inv_stats(_recipe_name, _columnwise) add_mse_stats(_recipe_name, _columnwise) + + +# NVFP4-specific statistics + + +def count_nonzero_nvfp4(fp4_data: torch.Tensor) -> torch.Tensor: + """Count the number of non-zero elements in the FP4 data. + + FP4 data is stored as 2 4-bit values per byte (uint8). + We need to unpack and count non-zeros. + """ + # Each byte contains two FP4 values + # Value 0 in FP4 E2M1 format is represented as 0 (and also 8 for -0.0) + zero_vals = torch.tensor([0, 8], device=fp4_data.device, dtype=torch.uint8) + + # Extract first and second nibbles + first_nibble = fp4_data % 16 + second_nibble = fp4_data // 16 + + # Count zeros + first_zeros = torch.isin(first_nibble, zero_vals).sum() + second_zeros = torch.isin(second_nibble, zero_vals).sum() + + total_elements = fp4_data.numel() * 2 + return total_elements - first_zeros - second_zeros + + +def add_nvfp4_underflows_stats(): + """Register underflow stats for NVFP4. + + Computes underflows by counting zeros in packed FP4 data vs original tensor. 
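+    For example, a packed byte 0x80 holds the nibbles 0 (+0.0) and 8 (-0.0) and so
+    contributes no non-zero elements, while 0x21 holds two non-zero codes.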
+ """ + stat_num = "nvfp4_underflows_num" + stat_pct = "nvfp4_underflows%" + + stats_to_num[stat_num] = len(stats_to_num) + stats_to_num[stat_pct] = len(stats_to_num) + + # Count non-zeros in original vs FP4 packed data + STATS[stat_num] = ( + lambda x, aux_dict: x.count_nonzero() + - count_nonzero_nvfp4(aux_dict["nvfp4"]._rowwise_data), + lambda buffers, _sn=stat_num: sum(_get(buffers, _sn)), + ) + STATS[stat_pct] = ( + lambda x, aux_dict: ( + x.count_nonzero() + - count_nonzero_nvfp4(aux_dict["nvfp4"]._rowwise_data) + ) + / aux_dict["nvfp4"].numel() + * 100, + lambda buffers, _sn_num=stat_num: 100 + * sum(_get(buffers, _sn_num)) + / sum(_get(buffers, "numel")), + ) + + DEPENDENCIES[stat_num] = {stat_num} + DEPENDENCIES[stat_pct] = {stat_num, "numel"} + +# Register NVFP4 stats +add_nvfp4_underflows_stats() +add_mse_stats("nvfp4") # Reuse existing MSE function diff --git a/transformer_engine/debug/pytorch/debug_quantization.py b/transformer_engine/debug/pytorch/debug_quantization.py index 7f45a24e20..1ba7e91cb7 100644 --- a/transformer_engine/debug/pytorch/debug_quantization.py +++ b/transformer_engine/debug/pytorch/debug_quantization.py @@ -36,7 +36,7 @@ } API_CALL_MODIFY = "modify_tensor()" -STANDARD_FP8_QUANTIZE = "FP8 Quantize" +STANDARD_QUANTIZE = "Quantize" # Generalized: works with FP8, NVFP4, etc. HIGH_PRECISION = "High Precision" @@ -83,7 +83,7 @@ def __init__( # inspect_tensor*_enabled are bool fields, # indicating whether some feature will need to run inspect_tensor_* calls. # - # *_tensor_plan are one of [API_CALL_MODIFY, STANDARD_FP8_QUANTIZE, HIGH_PRECISION] + # *_tensor_plan are one of [API_CALL_MODIFY, STANDARD_QUANTIZE, HIGH_PRECISION] # determining what will happen when the quantizer is used for that tensor. self.output_tensor = tensor_name in ["output", "wgrad", "dgrad"] if self.output_tensor: @@ -165,7 +165,7 @@ def get_enabled_look_at_tensors(self): def get_tensors_plan(self): """ Returns (rowwise_plan, columnwise_plan). Each element of the tuple is one of - API_CALL_MODIFY, STANDARD_FP8_QUANTIZE, or HIGH_PRECISION, indicating the behavior + API_CALL_MODIFY, STANDARD_QUANTIZE, or HIGH_PRECISION, indicating the behavior of this quantizer with respect to these tensors. 
""" import nvdlfw_inspect.api as debug_api @@ -186,16 +186,16 @@ def get_tensors_plan(self): rowwise_plan = API_CALL_MODIFY else: if self.parent_quantizer is not None: - fp8_quantize = self.process_enabled_api_call( - debug_api.transformer_engine.fp8_gemm_enabled( + quantize_enabled = self.process_enabled_api_call( + debug_api.transformer_engine.fp8_gemm_enabled( # API name kept for compatibility layer_name=self.layer_name, gemm=self.rowwise_gemm_name, iteration=self.iteration, ) ) - if fp8_quantize: - rowwise_plan = STANDARD_FP8_QUANTIZE + if quantize_enabled: + rowwise_plan = STANDARD_QUANTIZE if rowwise_plan is None: rowwise_plan = HIGH_PRECISION @@ -213,16 +213,16 @@ def get_tensors_plan(self): columnwise_plan = API_CALL_MODIFY else: if self.parent_quantizer is not None: - fp8_quantize = self.process_enabled_api_call( - debug_api.transformer_engine.fp8_gemm_enabled( + quantize_enabled = self.process_enabled_api_call( + debug_api.transformer_engine.fp8_gemm_enabled( # API name kept for compatibility layer_name=self.layer_name, gemm=self.columnwise_gemm_name, iteration=self.iteration, ) ) - if fp8_quantize: - columnwise_plan = STANDARD_FP8_QUANTIZE + if quantize_enabled: + columnwise_plan = STANDARD_QUANTIZE if columnwise_plan is None: columnwise_plan = HIGH_PRECISION @@ -273,7 +273,7 @@ def _call_inspect_tensor_api( del args["quantizer"] if ( - self.rowwise_tensor_plan in [API_CALL_MODIFY, STANDARD_FP8_QUANTIZE] + self.rowwise_tensor_plan in [API_CALL_MODIFY, STANDARD_QUANTIZE] and self.inspect_tensor_postquantize_enabled_rowwise ): args["tensor"] = rowwise_gemm_tensor @@ -281,7 +281,7 @@ def _call_inspect_tensor_api( debug_api.transformer_engine.inspect_tensor_postquantize(**args) if ( - self.columnwise_tensor_plan in [API_CALL_MODIFY, STANDARD_FP8_QUANTIZE] + self.columnwise_tensor_plan in [API_CALL_MODIFY, STANDARD_QUANTIZE] and self.inspect_tensor_postquantize_enabled_columnwise ): args["tensor"] = columnwise_gemm_tensor @@ -312,14 +312,14 @@ def quantize( self.parent_quantizer.set_usage(rowwise=True) rowwise_gemm_tensor, columnwise_gemm_tensor = None, None - if STANDARD_FP8_QUANTIZE in [self.rowwise_tensor_plan, self.columnwise_tensor_plan]: + if STANDARD_QUANTIZE in [self.rowwise_tensor_plan, self.columnwise_tensor_plan]: quantized_tensor = self.parent_quantizer(tensor) - # if both rowwise_tensor_plan and columnwise_tensor_plan need to be in fp8, + # if both rowwise_tensor_plan and columnwise_tensor_plan need to be quantized, # one tensor with columnwise=True and rowwise=True is computed # and both rowwise_tensor_plan and columnwise_tensor_plan point to it. - if self.rowwise_tensor_plan == STANDARD_FP8_QUANTIZE: + if self.rowwise_tensor_plan == STANDARD_QUANTIZE: rowwise_gemm_tensor = quantized_tensor - if self.columnwise_tensor_plan == STANDARD_FP8_QUANTIZE: + if self.columnwise_tensor_plan == STANDARD_QUANTIZE: columnwise_gemm_tensor = quantized_tensor # 2. modify_tensor() is called, if it is used. @@ -374,7 +374,7 @@ def process_gemm_output(self, tensor: torch.Tensor): """This call is invoked after the gemm to inspect and modify the output tensor.""" import nvdlfw_inspect.api as debug_api - assert self.parent_quantizer is None, "FP8 output is not supported for debug=True." + assert self.parent_quantizer is None, "Quantized output is not supported for debug=True." 
assert self.output_tensor tensor_to_gemm = {"output": "fprop", "wgrad": "wgrad", "dgrad": "dgrad"} if self.rowwise_tensor_plan == API_CALL_MODIFY: @@ -415,9 +415,9 @@ def any_feature_enabled(self) -> bool: ): return True if self.parent_quantizer is not None: - if self.rowwise_tensor_plan != STANDARD_FP8_QUANTIZE: + if self.rowwise_tensor_plan != STANDARD_QUANTIZE: return True - if self.columnwise_tensor_plan != STANDARD_FP8_QUANTIZE: + if self.columnwise_tensor_plan != STANDARD_QUANTIZE: return True return False @@ -441,7 +441,7 @@ def update_quantized( if self.parent_quantizer is not None: if ( dst.rowwise_gemm_tensor is not None - and self.rowwise_tensor_plan == STANDARD_FP8_QUANTIZE + and self.rowwise_tensor_plan == STANDARD_QUANTIZE ): if hasattr(dst.rowwise_gemm_tensor, "quantize_"): dst.rowwise_gemm_tensor.quantize_(src, noop_flag=None) @@ -450,7 +450,7 @@ def update_quantized( updated_rowwise_gemm = True if ( dst.columnwise_gemm_tensor is not None - and self.columnwise_tensor_plan == STANDARD_FP8_QUANTIZE + and self.columnwise_tensor_plan == STANDARD_QUANTIZE and not updated_rowwise_gemm ): if hasattr(dst.columnwise_gemm_tensor, "quantize_"): @@ -536,13 +536,13 @@ def _update_parent_quantizer_usage(self): Updates the usage of the parent quantizer. """ rowwise_gemm_quantize = ( - self.rowwise_usage and self.rowwise_tensor_plan == STANDARD_FP8_QUANTIZE + self.rowwise_usage and self.rowwise_tensor_plan == STANDARD_QUANTIZE ) columnwise_gemm_quantize = ( - self.columnwise_usage and self.columnwise_tensor_plan == STANDARD_FP8_QUANTIZE + self.columnwise_usage and self.columnwise_tensor_plan == STANDARD_QUANTIZE ) - if STANDARD_FP8_QUANTIZE in [self.rowwise_tensor_plan, self.columnwise_tensor_plan]: + if STANDARD_QUANTIZE in [self.rowwise_tensor_plan, self.columnwise_tensor_plan]: self.parent_quantizer.set_usage( rowwise=rowwise_gemm_quantize, columnwise=columnwise_gemm_quantize, From 96622bc30f337da2a3399bcf8b50233284d2cdfc Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 23 Oct 2025 14:28:48 +0000 Subject: [PATCH 2/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/debug/test_log.py | 50 ++++++++++--------- .../debug/features/disable_fp8_gemm.py | 3 +- .../debug/features/disable_fp8_layer.py | 5 +- .../features/disable_quantization_gemm.py | 5 +- .../features/disable_quantization_layer.py | 7 ++- .../debug/features/log_fp8_tensor_stats.py | 11 ++-- .../debug/features/log_nvfp4_tensor_stats.py | 9 ++-- .../debug/features/utils/stats_computation.py | 14 +++--- .../debug/pytorch/debug_quantization.py | 4 +- 9 files changed, 54 insertions(+), 54 deletions(-) diff --git a/tests/pytorch/debug/test_log.py b/tests/pytorch/debug/test_log.py index ead4e4079b..a915cad0d3 100644 --- a/tests/pytorch/debug/test_log.py +++ b/tests/pytorch/debug/test_log.py @@ -286,11 +286,11 @@ def test_nvfp4_numeric(feature_dirs): pytest.skip(reason_for_no_nvfp4) log_nvfp4_config = LOG_NVFP4_CONFIG_BASE.format(stats="underflows%, mse") - + with debug_session(log_nvfp4_config, feature_dirs) as log_dir: from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer from transformer_engine.pytorch.quantization import RecipeState - + recipe_state = RecipeState.create( recipe.NVFP4BlockScaling(), mode="forward", @@ -302,7 +302,7 @@ def test_nvfp4_numeric(feature_dirs): tensor = torch.randn(128, 128, dtype=torch.bfloat16).cuda() # Add some small values that should underflow 
to zero in FP4 tensor[0, :16] = 0.0001 - + quantizer = recipe_state.make_quantizers()[0] quantized_tensor = quantizer(tensor) @@ -324,54 +324,56 @@ def test_nvfp4_numeric(feature_dirs): # Validate both stats are present assert "nvfp4_underflows%" in output, "underflows% stat missing" assert "nvfp4_mse" in output, "mse stat missing" - + # Extract values and validate numerics underflows_value = None mse_value = None - + for line in output.splitlines(): if "nvfp4_underflows%" in line and "value=" in line: underflows_value = float(line.split("value=")[1].split()[0]) if "nvfp4_mse" in line and "value=" in line: mse_value = float(line.split("value=")[1].split()[0]) - + # Validate underflows% assert underflows_value is not None, "Could not extract underflows% value" assert underflows_value >= 0, f"Underflows should be non-negative, got {underflows_value}" assert underflows_value <= 100, f"Underflows% should be <= 100, got {underflows_value}" - + # Compute expected underflows: non-zero elements that became zero after quantization - orig_nonzero_mask = (tensor != 0) - dequant_zero_mask = (dequantized_tensor == 0) - expected_underflows = (orig_nonzero_mask & dequant_zero_mask).sum().float() / tensor.numel() * 100 - + orig_nonzero_mask = tensor != 0 + dequant_zero_mask = dequantized_tensor == 0 + expected_underflows = ( + (orig_nonzero_mask & dequant_zero_mask).sum().float() / tensor.numel() * 100 + ) + # Allow some tolerance - assert abs(underflows_value - expected_underflows.item()) < 1.0, \ - f"Underflows mismatch: got {underflows_value}, expected {expected_underflows.item()}" - + assert ( + abs(underflows_value - expected_underflows.item()) < 1.0 + ), f"Underflows mismatch: got {underflows_value}, expected {expected_underflows.item()}" + # Validate MSE assert mse_value is not None, "Could not extract MSE value" assert mse_value >= 0, f"MSE should be non-negative, got {mse_value}" - + # Compute expected MSE expected_mse = torch.nn.functional.mse_loss( - dequantized_tensor.float(), - tensor.float(), - reduction="mean" + dequantized_tensor.float(), tensor.float(), reduction="mean" ) - - assert mse_value == pytest.approx(expected_mse.cpu().item(), abs=1e-4), \ - f"MSE mismatch: got {mse_value}, expected {expected_mse.cpu().item()}" + + assert mse_value == pytest.approx( + expected_mse.cpu().item(), abs=1e-4 + ), f"MSE mismatch: got {mse_value}, expected {expected_mse.cpu().item()}" def test_fp8_stats_allows_nvfp4_with_recipe_prefix(feature_dirs): """Test that LogFp8TensorStats allows recipe-prefixed stats with NVFP4 for what-if analysis.""" if not nvfp4_available: pytest.skip(reason_for_no_nvfp4) - + # Use recipe-prefixed stat with NVFP4 - should work (computes MXFP8 separately) log_fp8_config = LOG_QUANTIZED_CONFIG_BASE.format(stats="mxfp8_mse") - + with debug_session(log_fp8_config, feature_dirs) as log_dir: model = te.Linear(128, 128, params_dtype=torch.bfloat16) inp = torch.randn(128, 128, dtype=torch.bfloat16).cuda() @@ -383,7 +385,7 @@ def test_fp8_stats_allows_nvfp4_with_recipe_prefix(feature_dirs): loss = output.sum() loss.backward() debug_api.step() - + output = read_log(log_dir) # Should have logged MXFP8 MSE stat (what-if scenario) assert "mxfp8_mse" in output diff --git a/transformer_engine/debug/features/disable_fp8_gemm.py b/transformer_engine/debug/features/disable_fp8_gemm.py index c3c04fe466..c80cbc7b6b 100644 --- a/transformer_engine/debug/features/disable_fp8_gemm.py +++ b/transformer_engine/debug/features/disable_fp8_gemm.py @@ -16,7 +16,7 @@ class 
DisableFP8GEMM(DisableQuantizationGEMM): """ GEMM operations are executed in higher precision, even when FP8 autocast is enabled. - + .. deprecated:: Use :class:`DisableQuantizationGEMM` instead. This class is maintained for backward compatibility only. DisableQuantizationGEMM works with all quantization @@ -45,4 +45,5 @@ class DisableFP8GEMM(DisableQuantizationGEMM): enabled: True gemms: [dgrad, wgrad] """ + pass # Inherits all functionality from DisableQuantizationGEMM diff --git a/transformer_engine/debug/features/disable_fp8_layer.py b/transformer_engine/debug/features/disable_fp8_layer.py index 3eebb97b6e..3533069492 100644 --- a/transformer_engine/debug/features/disable_fp8_layer.py +++ b/transformer_engine/debug/features/disable_fp8_layer.py @@ -3,7 +3,7 @@ # See LICENSE for license information. """DisableFP8Layer Feature support for nvidia-dlframework-inspect - + DEPRECATED: This is a backward compatibility alias for DisableQuantizationLayer. New code should use DisableQuantizationLayer instead, which works with all quantization formats. """ @@ -16,7 +16,7 @@ class DisableFP8Layer(DisableQuantizationLayer): """ Disables all FP8 GEMMs in the layer. - + .. deprecated:: Use :class:`DisableQuantizationLayer` instead. This class is maintained for backward compatibility only. DisableQuantizationLayer works with all quantization @@ -34,4 +34,5 @@ class DisableFP8Layer(DisableQuantizationLayer): DisableFP8Layer: # Deprecated: use DisableQuantizationLayer enabled: True """ + pass # Inherits all functionality from DisableQuantizationLayer diff --git a/transformer_engine/debug/features/disable_quantization_gemm.py b/transformer_engine/debug/features/disable_quantization_gemm.py index 4caf976e0a..ad8f07f07c 100644 --- a/transformer_engine/debug/features/disable_quantization_gemm.py +++ b/transformer_engine/debug/features/disable_quantization_gemm.py @@ -12,7 +12,7 @@ class DisableQuantizationGEMM(TEConfigAPIMapper): """ Disables specific GEMM operations from using quantization, forcing high-precision execution. - + Works with any quantization format (FP8, NVFP4, etc.). Parameters @@ -44,7 +44,7 @@ def fp8_gemm_enabled( self, config, layer_name: str, gemm: str, iteration: int ): # pylint: disable=unused-argument """API call responsible for choice between high-precision and quantized GEMM execution. - + Note: Method name kept as 'fp8_gemm_enabled' for backward compatibility with the debug API, but it applies to all quantization formats (FP8, NVFP4, etc.). """ @@ -57,4 +57,3 @@ def fp8_gemm_enabled( # If not, then default behavior in TransformerEngineAPI # is that fp8_gemm() API call returns True. return False, iteration + 1 - diff --git a/transformer_engine/debug/features/disable_quantization_layer.py b/transformer_engine/debug/features/disable_quantization_layer.py index bd465764f9..86aed587bc 100644 --- a/transformer_engine/debug/features/disable_quantization_layer.py +++ b/transformer_engine/debug/features/disable_quantization_layer.py @@ -12,7 +12,7 @@ class DisableQuantizationLayer: """ Disables all quantized GEMMs in the layer, forcing high-precision execution. - + Works with any quantization format (FP8, NVFP4, etc.). Example @@ -33,7 +33,7 @@ def fp8_gemm_enabled( self, config, layer_name: str, gemm: str, iteration: int ): # pylint: disable=unused-argument """API call responsible for selecting between high-precision and quantized GEMM execution. 
- + Note: Method name kept as 'fp8_gemm_enabled' for backward compatibility with the debug API, but it applies to all quantization formats (FP8, NVFP4, etc.). """ @@ -50,7 +50,7 @@ def fp8_gemm_enabled( def parse_config_and_api(self, config, **_kwargs): """Determines whether to run the API. - + DisableQuantizationLayer is the only feature provided by the Transformer Engine which does not inherit from TEConfigAPIMapper - this mapper is primarily responsible for parsing gemms and tensors fields from the config, which are not needed for this feature. @@ -59,4 +59,3 @@ def parse_config_and_api(self, config, **_kwargs): nvidia-dlframework-inspect documentation. """ return config["enabled"], None - diff --git a/transformer_engine/debug/features/log_fp8_tensor_stats.py b/transformer_engine/debug/features/log_fp8_tensor_stats.py index dc89a0c103..46d939ff5f 100644 --- a/transformer_engine/debug/features/log_fp8_tensor_stats.py +++ b/transformer_engine/debug/features/log_fp8_tensor_stats.py @@ -25,6 +25,7 @@ try: from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer + _nvfp4_available = True except ImportError: _nvfp4_available = False @@ -178,10 +179,10 @@ def check_if_stat_is_supported(self, stat: str, current_recipe: str): # But allow recipe-prefixed FP8 stats like "mxfp8_underflows%" even with NVFP4 quantizer if recipe_from_stat == "nvfp4": raise ValueError( - f"[NVTORCH INSPECT ERROR] Cannot compute NVFP4 stats '{stat}' in LogFp8TensorStats. " - f"FP8-specific statistics do not work with NVFP4. " - f"Use LogNvfp4TensorStats for NVFP4-specific stats, or use FP8 recipe-prefixed stats " - f"(e.g., 'mxfp8_underflows%', 'fp8_block_scaling_mse') for what-if FP8 comparisons." + f"[NVTORCH INSPECT ERROR] Cannot compute NVFP4 stats '{stat}' in LogFp8TensorStats." + " FP8-specific statistics do not work with NVFP4. Use LogNvfp4TensorStats for" + " NVFP4-specific stats, or use FP8 recipe-prefixed stats (e.g.," + " 'mxfp8_underflows%', 'fp8_block_scaling_mse') for what-if FP8 comparisons." ) if recipe_from_stat in ["fp8_delayed_scaling", "fp8_current_scaling"] and columnwise: @@ -302,7 +303,7 @@ def inspect_tensor( ), "[NVTORCH INSPECT ERROR] LogFp8TensorStats cannot be run without low-precision recipe." quantized_tensor = rowwise_quantized_tensor - + assert isinstance( quantized_tensor, QuantizedTensor ), "[NVTORCH INSPECT ERROR] LogFp8TensorStats quantized_tensor must be a QuantizedTensor." diff --git a/transformer_engine/debug/features/log_nvfp4_tensor_stats.py b/transformer_engine/debug/features/log_nvfp4_tensor_stats.py index 773a0c0df1..456b306e32 100644 --- a/transformer_engine/debug/features/log_nvfp4_tensor_stats.py +++ b/transformer_engine/debug/features/log_nvfp4_tensor_stats.py @@ -97,7 +97,7 @@ def check_if_stat_is_supported(self, stat: str): f"Stat {stat} is not supported for NVFP4. Supported stats: {supported_stats}" ) return True - + def get_stat_with_prefix(self, stat: str) -> str: """Add nvfp4_ prefix to stat name for use in stats_computation.""" return f"nvfp4_{stat}" @@ -161,14 +161,14 @@ def inspect_tensor( ), "[NVTORCH INSPECT ERROR] LogNvfp4TensorStats cannot be run without NVFP4 quantizer." 
quantized_tensor = rowwise_quantized_tensor - + # Ensure we're working with NVFP4 tensors if not isinstance(quantizer, NVFP4Quantizer): raise ValueError( - f"[NVTORCH INSPECT ERROR] LogNvfp4TensorStats requires NVFP4Quantizer, " + "[NVTORCH INSPECT ERROR] LogNvfp4TensorStats requires NVFP4Quantizer, " f"but got {type(quantizer).__name__}" ) - + assert isinstance( quantized_tensor, QuantizedTensor ), "[NVTORCH INSPECT ERROR] LogNvfp4TensorStats quantized_tensor must be a QuantizedTensor." @@ -226,4 +226,3 @@ def inspect_tensor( layer_name, extra_cachable_args=(tensor_name,), ) - diff --git a/transformer_engine/debug/features/utils/stats_computation.py b/transformer_engine/debug/features/utils/stats_computation.py index 76d40533a9..c1b4958aa7 100644 --- a/transformer_engine/debug/features/utils/stats_computation.py +++ b/transformer_engine/debug/features/utils/stats_computation.py @@ -334,29 +334,29 @@ def add_mse_stats(recipe_name: str, columnwise: bool = False): def count_nonzero_nvfp4(fp4_data: torch.Tensor) -> torch.Tensor: """Count the number of non-zero elements in the FP4 data. - + FP4 data is stored as 2 4-bit values per byte (uint8). We need to unpack and count non-zeros. """ # Each byte contains two FP4 values # Value 0 in FP4 E2M1 format is represented as 0 (and also 8 for -0.0) zero_vals = torch.tensor([0, 8], device=fp4_data.device, dtype=torch.uint8) - + # Extract first and second nibbles first_nibble = fp4_data % 16 second_nibble = fp4_data // 16 - + # Count zeros first_zeros = torch.isin(first_nibble, zero_vals).sum() second_zeros = torch.isin(second_nibble, zero_vals).sum() - + total_elements = fp4_data.numel() * 2 return total_elements - first_zeros - second_zeros def add_nvfp4_underflows_stats(): """Register underflow stats for NVFP4. - + Computes underflows by counting zeros in packed FP4 data vs original tensor. """ stat_num = "nvfp4_underflows_num" @@ -373,8 +373,7 @@ def add_nvfp4_underflows_stats(): ) STATS[stat_pct] = ( lambda x, aux_dict: ( - x.count_nonzero() - - count_nonzero_nvfp4(aux_dict["nvfp4"]._rowwise_data) + x.count_nonzero() - count_nonzero_nvfp4(aux_dict["nvfp4"]._rowwise_data) ) / aux_dict["nvfp4"].numel() * 100, @@ -386,6 +385,7 @@ def add_nvfp4_underflows_stats(): DEPENDENCIES[stat_num] = {stat_num} DEPENDENCIES[stat_pct] = {stat_num, "numel"} + # Register NVFP4 stats add_nvfp4_underflows_stats() add_mse_stats("nvfp4") # Reuse existing MSE function diff --git a/transformer_engine/debug/pytorch/debug_quantization.py b/transformer_engine/debug/pytorch/debug_quantization.py index 1ba7e91cb7..c731ad783e 100644 --- a/transformer_engine/debug/pytorch/debug_quantization.py +++ b/transformer_engine/debug/pytorch/debug_quantization.py @@ -535,9 +535,7 @@ def _update_parent_quantizer_usage(self): """ Updates the usage of the parent quantizer. 
""" - rowwise_gemm_quantize = ( - self.rowwise_usage and self.rowwise_tensor_plan == STANDARD_QUANTIZE - ) + rowwise_gemm_quantize = self.rowwise_usage and self.rowwise_tensor_plan == STANDARD_QUANTIZE columnwise_gemm_quantize = ( self.columnwise_usage and self.columnwise_tensor_plan == STANDARD_QUANTIZE ) From 4b4802666704fcf93ed11254363885d86a632b75 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Thu, 23 Oct 2025 19:02:54 +0200 Subject: [PATCH 3/8] fixes Signed-off-by: Pawel Gadzinski --- docs/debug/1_getting_started.rst | 3 ++- docs/debug/3_api_features.rst | 7 +++++-- .../debug/features/log_fp8_tensor_stats.py | 5 ++++- .../debug/features/log_nvfp4_tensor_stats.py | 11 +++-------- transformer_engine/pytorch/tensor/mxfp8_tensor.py | 2 +- 5 files changed, 15 insertions(+), 13 deletions(-) diff --git a/docs/debug/1_getting_started.rst b/docs/debug/1_getting_started.rst index 906c625567..cf9e1c4ac5 100644 --- a/docs/debug/1_getting_started.rst +++ b/docs/debug/1_getting_started.rst @@ -15,7 +15,8 @@ Transformer Engine provides a set of precision debug tools which allow you to ea - log the statistics for each of the tensors in every matrix multiply (GEMM) operation, - run selected GEMMs in higher precision, - run current scaling - with one scaling factor per tensor - for particular GEMMs, -- test new precisions and integrate them with FP8 training, +- test new precisions and integrate them with quantized training (FP8, NVFP4, etc.), +- monitor quantization errors and underflows for different precision formats, - ... and many more. There are 4 things one needs to do to use Transformer Engine debug features: diff --git a/docs/debug/3_api_features.rst b/docs/debug/3_api_features.rst index b31c437b2d..c9db9d7de3 100644 --- a/docs/debug/3_api_features.rst +++ b/docs/debug/3_api_features.rst @@ -8,7 +8,10 @@ Debug features .. autoapiclass:: transformer_engine.debug.features.log_tensor_stats.LogTensorStats .. autoapiclass:: transformer_engine.debug.features.log_fp8_tensor_stats.LogFp8TensorStats -.. autoapiclass:: transformer_engine.debug.features.disable_fp8_gemm.DisableFP8GEMM -.. autoapiclass:: transformer_engine.debug.features.disable_fp8_layer.DisableFP8Layer +.. autoapiclass:: transformer_engine.debug.features.log_nvfp4_tensor_stats.LogNvfp4TensorStats +.. autoapiclass:: transformer_engine.debug.features.disable_quantization_gemm.DisableQuantizationGEMM +.. autoapiclass:: transformer_engine.debug.features.disable_quantization_layer.DisableQuantizationLayer .. autoapiclass:: transformer_engine.debug.features.per_tensor_scaling.PerTensorScaling .. autoapiclass:: transformer_engine.debug.features.fake_quant.FakeQuant +.. autoapiclass:: transformer_engine.debug.features.disable_fp8_gemm.DisableFP8GEMM +.. 
autoapiclass:: transformer_engine.debug.features.disable_fp8_layer.DisableFP8Layer \ No newline at end of file diff --git a/transformer_engine/debug/features/log_fp8_tensor_stats.py b/transformer_engine/debug/features/log_fp8_tensor_stats.py index 46d939ff5f..a341187953 100644 --- a/transformer_engine/debug/features/log_fp8_tensor_stats.py +++ b/transformer_engine/debug/features/log_fp8_tensor_stats.py @@ -23,6 +23,8 @@ from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockQuantizer +import transformer_engine_torch as tex + try: from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer @@ -210,6 +212,7 @@ def check_if_stat_is_supported(self, stat: str, current_recipe: str): def get_recipe_from_stat(self, stat: str, default_recipe: str = ""): """Returns the recipe name from the stat string.""" + columnwise_stat = stat.endswith("_columnwise") for recipe_name in ALL_RECIPE_NAMES: if recipe_name in stat: @@ -234,7 +237,7 @@ def update_aux_dict( Yields the aux_dict. Needs to clean after usage, because it possibly change the usage of the quantized tensor. """ - fp8_dtype = None + fp8_dtype = tex.DType.kFloat8E4M3 if recipe_name in ["fp8_delayed_scaling", "fp8_current_scaling", "fp8_block_scaling"]: assert isinstance( quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer, Float8BlockQuantizer) diff --git a/transformer_engine/debug/features/log_nvfp4_tensor_stats.py b/transformer_engine/debug/features/log_nvfp4_tensor_stats.py index 456b306e32..fec5b1ad3b 100644 --- a/transformer_engine/debug/features/log_nvfp4_tensor_stats.py +++ b/transformer_engine/debug/features/log_nvfp4_tensor_stats.py @@ -17,16 +17,11 @@ from transformer_engine.pytorch.tensor import Quantizer, QuantizedTensor from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer from transformer_engine.debug.features.utils import get_reduction_params, next_enabled_iter - +from transformer_engine.pytorch.tensor.storage.nvfp4_tensor_storage import NVFP4TensorStorage @Registry.register_feature(namespace="transformer_engine") class LogNvfp4TensorStats(BaseLogTensorStats): - """ - Logs statistics of NVFP4 quantized tensors. - - This feature is specifically designed for NVFP4 quantization and provides: - - underflows%: percentage of non-zero elements clipped to 0 after quantization (computed from packed FP4 data) - - mse: mean squared error between original and quantized-dequantized tensor + """Logs statistics of NVFP4 quantized tensors. In distributed runs each rank first computes its local statistics; the values are gathered the next time `debug_api.step()` is called. Remember to call @@ -170,7 +165,7 @@ def inspect_tensor( ) assert isinstance( - quantized_tensor, QuantizedTensor + quantized_tensor, NVFP4TensorStorage ), "[NVTORCH INSPECT ERROR] LogNvfp4TensorStats quantized_tensor must be a QuantizedTensor." 
for stat in config["stats"]: diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index 5ef5708fdb..24a5d23779 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -70,7 +70,7 @@ def update_quantized( def quantize_impl(self, tensor: torch.Tensor) -> QuantizedTensor: """Quantize tensor implementation""" - return tex.quantize(tensor, self) + return tex.quantize(tensor, self, None) def is_quantizable(self, inp: torch.Tensor) -> bool: """Returns whether or not given inp can be quantized""" From 55ed4ee9a28629ec9aa06f915222907ec971b888 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 23 Oct 2025 17:03:44 +0000 Subject: [PATCH 4/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/debug/features/log_nvfp4_tensor_stats.py | 1 + 1 file changed, 1 insertion(+) diff --git a/transformer_engine/debug/features/log_nvfp4_tensor_stats.py b/transformer_engine/debug/features/log_nvfp4_tensor_stats.py index fec5b1ad3b..3c4415f11e 100644 --- a/transformer_engine/debug/features/log_nvfp4_tensor_stats.py +++ b/transformer_engine/debug/features/log_nvfp4_tensor_stats.py @@ -19,6 +19,7 @@ from transformer_engine.debug.features.utils import get_reduction_params, next_enabled_iter from transformer_engine.pytorch.tensor.storage.nvfp4_tensor_storage import NVFP4TensorStorage + @Registry.register_feature(namespace="transformer_engine") class LogNvfp4TensorStats(BaseLogTensorStats): """Logs statistics of NVFP4 quantized tensors. From c5f184910977dbd1797fa331ea172d9661f3477d Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Thu, 23 Oct 2025 19:16:38 +0200 Subject: [PATCH 5/8] fix Signed-off-by: Pawel Gadzinski --- docs/debug/1_getting_started.rst | 1 - 1 file changed, 1 deletion(-) diff --git a/docs/debug/1_getting_started.rst b/docs/debug/1_getting_started.rst index cf9e1c4ac5..8350c6fe62 100644 --- a/docs/debug/1_getting_started.rst +++ b/docs/debug/1_getting_started.rst @@ -16,7 +16,6 @@ Transformer Engine provides a set of precision debug tools which allow you to ea - run selected GEMMs in higher precision, - run current scaling - with one scaling factor per tensor - for particular GEMMs, - test new precisions and integrate them with quantized training (FP8, NVFP4, etc.), -- monitor quantization errors and underflows for different precision formats, - ... and many more. 
There are 4 things one needs to do to use Transformer Engine debug features: From 160b77b1cd318a3ff0ac764eaf5c0bb9a2262d29 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Thu, 23 Oct 2025 19:30:56 +0200 Subject: [PATCH 6/8] fix Signed-off-by: Pawel Gadzinski --- tests/pytorch/debug/test_log.py | 17 ++--------------- 1 file changed, 2 insertions(+), 15 deletions(-) diff --git a/tests/pytorch/debug/test_log.py b/tests/pytorch/debug/test_log.py index a915cad0d3..e79ccc0484 100644 --- a/tests/pytorch/debug/test_log.py +++ b/tests/pytorch/debug/test_log.py @@ -335,11 +335,6 @@ def test_nvfp4_numeric(feature_dirs): if "nvfp4_mse" in line and "value=" in line: mse_value = float(line.split("value=")[1].split()[0]) - # Validate underflows% - assert underflows_value is not None, "Could not extract underflows% value" - assert underflows_value >= 0, f"Underflows should be non-negative, got {underflows_value}" - assert underflows_value <= 100, f"Underflows% should be <= 100, got {underflows_value}" - # Compute expected underflows: non-zero elements that became zero after quantization orig_nonzero_mask = tensor != 0 dequant_zero_mask = dequantized_tensor == 0 @@ -348,22 +343,14 @@ def test_nvfp4_numeric(feature_dirs): ) # Allow some tolerance - assert ( - abs(underflows_value - expected_underflows.item()) < 1.0 - ), f"Underflows mismatch: got {underflows_value}, expected {expected_underflows.item()}" - - # Validate MSE - assert mse_value is not None, "Could not extract MSE value" - assert mse_value >= 0, f"MSE should be non-negative, got {mse_value}" + assert underflows_value == pytest.approx(expected_underflows.cpu().item(), abs=1e-4) # Compute expected MSE expected_mse = torch.nn.functional.mse_loss( dequantized_tensor.float(), tensor.float(), reduction="mean" ) - assert mse_value == pytest.approx( - expected_mse.cpu().item(), abs=1e-4 - ), f"MSE mismatch: got {mse_value}, expected {expected_mse.cpu().item()}" + assert mse_value == pytest.approx(expected_mse.cpu().item(), abs=1e-4) def test_fp8_stats_allows_nvfp4_with_recipe_prefix(feature_dirs): From 99070813cf6f71e87e4e160f2408f90c9708d4ab Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Thu, 23 Oct 2025 19:32:49 +0200 Subject: [PATCH 7/8] fix Signed-off-by: Pawel Gadzinski --- transformer_engine/debug/pytorch/debug_quantization.py | 2 +- transformer_engine/pytorch/tensor/mxfp8_tensor.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/debug/pytorch/debug_quantization.py b/transformer_engine/debug/pytorch/debug_quantization.py index c731ad783e..310776a393 100644 --- a/transformer_engine/debug/pytorch/debug_quantization.py +++ b/transformer_engine/debug/pytorch/debug_quantization.py @@ -36,7 +36,7 @@ } API_CALL_MODIFY = "modify_tensor()" -STANDARD_QUANTIZE = "Quantize" # Generalized: works with FP8, NVFP4, etc. 
+STANDARD_QUANTIZE = "Quantize" HIGH_PRECISION = "High Precision" diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index 24a5d23779..5ef5708fdb 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -70,7 +70,7 @@ def update_quantized( def quantize_impl(self, tensor: torch.Tensor) -> QuantizedTensor: """Quantize tensor implementation""" - return tex.quantize(tensor, self, None) + return tex.quantize(tensor, self) def is_quantizable(self, inp: torch.Tensor) -> bool: """Returns whether or not given inp can be quantized""" From 23b2d1d596ecc50396aa7a9b383cece615850609 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Thu, 23 Oct 2025 19:43:02 +0200 Subject: [PATCH 8/8] fix Signed-off-by: Pawel Gadzinski --- transformer_engine/debug/features/disable_fp8_gemm.py | 2 -- transformer_engine/debug/features/disable_fp8_layer.py | 2 -- transformer_engine/debug/features/log_fp8_tensor_stats.py | 7 ++----- .../debug/features/log_nvfp4_tensor_stats.py | 2 +- 4 files changed, 3 insertions(+), 10 deletions(-) diff --git a/transformer_engine/debug/features/disable_fp8_gemm.py b/transformer_engine/debug/features/disable_fp8_gemm.py index c80cbc7b6b..ccc3240110 100644 --- a/transformer_engine/debug/features/disable_fp8_gemm.py +++ b/transformer_engine/debug/features/disable_fp8_gemm.py @@ -45,5 +45,3 @@ class DisableFP8GEMM(DisableQuantizationGEMM): enabled: True gemms: [dgrad, wgrad] """ - - pass # Inherits all functionality from DisableQuantizationGEMM diff --git a/transformer_engine/debug/features/disable_fp8_layer.py b/transformer_engine/debug/features/disable_fp8_layer.py index 3533069492..e74fbce964 100644 --- a/transformer_engine/debug/features/disable_fp8_layer.py +++ b/transformer_engine/debug/features/disable_fp8_layer.py @@ -34,5 +34,3 @@ class DisableFP8Layer(DisableQuantizationLayer): DisableFP8Layer: # Deprecated: use DisableQuantizationLayer enabled: True """ - - pass # Inherits all functionality from DisableQuantizationLayer diff --git a/transformer_engine/debug/features/log_fp8_tensor_stats.py b/transformer_engine/debug/features/log_fp8_tensor_stats.py index a341187953..01af831927 100644 --- a/transformer_engine/debug/features/log_fp8_tensor_stats.py +++ b/transformer_engine/debug/features/log_fp8_tensor_stats.py @@ -9,12 +9,13 @@ import torch import nvdlfw_inspect.api as debug_api - +import transformer_engine_torch as tex from nvdlfw_inspect.debug_features.log_tensor_stats import LogTensorStats as BaseLogTensorStats from nvdlfw_inspect.registry import Registry, api_method from transformer_engine.debug.features.utils.stats_buffer import STATS_BUFFERS +from transformer_engine.debug.features.utils import get_reduction_params, next_enabled_iter from transformer_engine.pytorch.tensor import Quantizer, QuantizedTensor from transformer_engine.pytorch.tensor.float8_tensor import ( Float8Quantizer, @@ -23,8 +24,6 @@ from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockQuantizer -import transformer_engine_torch as tex - try: from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer @@ -33,8 +32,6 @@ _nvfp4_available = False NVFP4Quantizer = None -from transformer_engine.debug.features.utils import get_reduction_params, next_enabled_iter - ALL_RECIPE_NAMES = ["fp8_delayed_scaling", "fp8_current_scaling", "mxfp8", "fp8_block_scaling"] diff --git 
a/transformer_engine/debug/features/log_nvfp4_tensor_stats.py b/transformer_engine/debug/features/log_nvfp4_tensor_stats.py index 3c4415f11e..1a096033e7 100644 --- a/transformer_engine/debug/features/log_nvfp4_tensor_stats.py +++ b/transformer_engine/debug/features/log_nvfp4_tensor_stats.py @@ -103,7 +103,7 @@ def update_aux_dict( self, aux_dict: Dict, quantized_tensor: QuantizedTensor, - quantizer: Quantizer, + quantizer: Quantizer, # pylint: disable=unused-argument original_tensor: torch.Tensor, ): """