diff --git a/docs/debug/1_getting_started.rst b/docs/debug/1_getting_started.rst index 906c625567..8350c6fe62 100644 --- a/docs/debug/1_getting_started.rst +++ b/docs/debug/1_getting_started.rst @@ -15,7 +15,7 @@ Transformer Engine provides a set of precision debug tools which allow you to ea - log the statistics for each of the tensors in every matrix multiply (GEMM) operation, - run selected GEMMs in higher precision, - run current scaling - with one scaling factor per tensor - for particular GEMMs, -- test new precisions and integrate them with FP8 training, +- test new precisions and integrate them with quantized training (FP8, NVFP4, etc.), - ... and many more. There are 4 things one needs to do to use Transformer Engine debug features: diff --git a/docs/debug/3_api_features.rst b/docs/debug/3_api_features.rst index b31c437b2d..c9db9d7de3 100644 --- a/docs/debug/3_api_features.rst +++ b/docs/debug/3_api_features.rst @@ -8,7 +8,10 @@ Debug features .. autoapiclass:: transformer_engine.debug.features.log_tensor_stats.LogTensorStats .. autoapiclass:: transformer_engine.debug.features.log_fp8_tensor_stats.LogFp8TensorStats -.. autoapiclass:: transformer_engine.debug.features.disable_fp8_gemm.DisableFP8GEMM -.. autoapiclass:: transformer_engine.debug.features.disable_fp8_layer.DisableFP8Layer +.. autoapiclass:: transformer_engine.debug.features.log_nvfp4_tensor_stats.LogNvfp4TensorStats +.. autoapiclass:: transformer_engine.debug.features.disable_quantization_gemm.DisableQuantizationGEMM +.. autoapiclass:: transformer_engine.debug.features.disable_quantization_layer.DisableQuantizationLayer .. autoapiclass:: transformer_engine.debug.features.per_tensor_scaling.PerTensorScaling .. autoapiclass:: transformer_engine.debug.features.fake_quant.FakeQuant +.. autoapiclass:: transformer_engine.debug.features.disable_fp8_gemm.DisableFP8GEMM +.. 
autoapiclass:: transformer_engine.debug.features.disable_fp8_layer.DisableFP8Layer \ No newline at end of file diff --git a/tests/pytorch/debug/test_log.py b/tests/pytorch/debug/test_log.py index e9d074821d..e79ccc0484 100644 --- a/tests/pytorch/debug/test_log.py +++ b/tests/pytorch/debug/test_log.py @@ -15,6 +15,7 @@ is_fp8_available, is_mxfp8_available, is_fp8_block_scaling_available, + is_nvfp4_available, ) from transformer_engine.pytorch.quantization import RecipeState from transformer_engine.debug.pytorch.debug_state import TEDebugState @@ -25,6 +26,7 @@ fp8_block_scaling_available, reason_for_no_fp8_block_scaling = is_fp8_block_scaling_available( return_reason=True ) +nvfp4_available, reason_for_no_nvfp4 = is_nvfp4_available(return_reason=True) LOG_QUANTIZED_CONFIG_BASE = """ log: @@ -256,3 +258,121 @@ def test_log_every_3_or_5_layers(layer, configs_dir, feature_dirs): debug_api.end_debug() TEDebugState._reset() + + +# NVFP4 tests +LOG_NVFP4_CONFIG_BASE = """ +log: + layers: + layer_name_regex_pattern: .* + enabled: + True + transformer_engine: + LogNvfp4TensorStats: + enabled: True + stats: [ + {stats} + ] + tensors: [activation, gradient, weight] + freq: 2 + start_step: 0 + end_step: 10 +""" + + +def test_nvfp4_numeric(feature_dirs): + """Test that NVFP4 underflows% and MSE stats are computed correctly with known values.""" + if not nvfp4_available: + pytest.skip(reason_for_no_nvfp4) + + log_nvfp4_config = LOG_NVFP4_CONFIG_BASE.format(stats="underflows%, mse") + + with debug_session(log_nvfp4_config, feature_dirs) as log_dir: + from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer + from transformer_engine.pytorch.quantization import RecipeState + + recipe_state = RecipeState.create( + recipe.NVFP4BlockScaling(), + mode="forward", + num_quantizers=3, + ) + + # Create test tensor with known distribution + torch.manual_seed(42) + tensor = torch.randn(128, 128, dtype=torch.bfloat16).cuda() + # Add some small values that should underflow to zero in FP4 + tensor[0, :16] = 0.0001 + + quantizer = recipe_state.make_quantizers()[0] + quantized_tensor = quantizer(tensor) + + debug_api.transformer_engine.inspect_tensor( + layer_name="test_layer", + tensor_name="activation", + iteration=0, + tp_group=None, + tensor=tensor, + quantizer=quantizer, + rowwise_quantized_tensor=quantized_tensor, + columnwise_quantized_tensor=quantized_tensor, + ) + debug_api.step() + + dequantized_tensor = quantized_tensor.dequantize() + output = read_log(log_dir) + + # Validate both stats are present + assert "nvfp4_underflows%" in output, "underflows% stat missing" + assert "nvfp4_mse" in output, "mse stat missing" + + # Extract values and validate numerics + underflows_value = None + mse_value = None + + for line in output.splitlines(): + if "nvfp4_underflows%" in line and "value=" in line: + underflows_value = float(line.split("value=")[1].split()[0]) + if "nvfp4_mse" in line and "value=" in line: + mse_value = float(line.split("value=")[1].split()[0]) + + # Compute expected underflows: non-zero elements that became zero after quantization + orig_nonzero_mask = tensor != 0 + dequant_zero_mask = dequantized_tensor == 0 + expected_underflows = ( + (orig_nonzero_mask & dequant_zero_mask).sum().float() / tensor.numel() * 100 + ) + + # Allow some tolerance + assert underflows_value == pytest.approx(expected_underflows.cpu().item(), abs=1e-4) + + # Compute expected MSE + expected_mse = torch.nn.functional.mse_loss( + dequantized_tensor.float(), tensor.float(), reduction="mean" + ) + + assert 
mse_value == pytest.approx(expected_mse.cpu().item(), abs=1e-4) + + +def test_fp8_stats_allows_nvfp4_with_recipe_prefix(feature_dirs): + """Test that LogFp8TensorStats allows recipe-prefixed stats with NVFP4 for what-if analysis.""" + if not nvfp4_available: + pytest.skip(reason_for_no_nvfp4) + + # Use recipe-prefixed stat with NVFP4 - should work (computes MXFP8 separately) + log_fp8_config = LOG_QUANTIZED_CONFIG_BASE.format(stats="mxfp8_mse") + + with debug_session(log_fp8_config, feature_dirs) as log_dir: + model = te.Linear(128, 128, params_dtype=torch.bfloat16) + inp = torch.randn(128, 128, dtype=torch.bfloat16).cuda() + + # Should work - recipe-prefixed stats compute MXFP8 separately for comparison + for _ in range(2): + with te.autocast(recipe=recipe.NVFP4BlockScaling()): + output = model(inp) + loss = output.sum() + loss.backward() + debug_api.step() + + output = read_log(log_dir) + # Should have logged MXFP8 MSE stat (what-if scenario) + assert "mxfp8_mse" in output diff --git a/transformer_engine/debug/features/disable_fp8_gemm.py b/transformer_engine/debug/features/disable_fp8_gemm.py index ef2cccbe4a..ccc3240110 100644 --- a/transformer_engine/debug/features/disable_fp8_gemm.py +++ b/transformer_engine/debug/features/disable_fp8_gemm.py @@ -2,17 +2,26 @@ # # See LICENSE for license information. -"""DisableFP8GEMM Feature support for nvidia-dlframework-inspect""" +"""DisableFP8GEMM Feature support for nvidia-dlframework-inspect -from nvdlfw_inspect.registry import Registry, api_method -from transformer_engine.debug.features.api import TEConfigAPIMapper +DEPRECATED: This is a backward compatibility alias for DisableQuantizationGEMM. +New code should use DisableQuantizationGEMM instead, which works with all quantization formats. +""" + +from nvdlfw_inspect.registry import Registry +from transformer_engine.debug.features.disable_quantization_gemm import DisableQuantizationGEMM @Registry.register_feature(namespace="transformer_engine") -class DisableFP8GEMM(TEConfigAPIMapper): +class DisableFP8GEMM(DisableQuantizationGEMM): """ GEMM operations are executed in higher precision, even when FP8 autocast is enabled. + .. deprecated:: + Use :class:`DisableQuantizationGEMM` instead. This class is maintained for + backward compatibility only. DisableQuantizationGEMM works with all quantization + formats (FP8, NVFP4, etc.), not just FP8. + Parameters ---------- @@ -32,22 +41,7 @@ class DisableFP8GEMM(TEConfigAPIMapper): layers: layer_types: [fc1] transformer_engine: - DisableFP8GEMM: + DisableFP8GEMM: # Deprecated: use DisableQuantizationGEMM enabled: True gemms: [dgrad, wgrad] """ - - @api_method - def fp8_gemm_enabled( - self, config, layer_name: str, gemm: str, iteration: int - ): # pylint: disable=unused-argument - """API call responsible for choice between high-precision and FP8 GEMM execution.""" - - for key in config: - if key != "gemm": - raise ValueError(f'[NVTORCH INSPECT ERROR] Unexpected key in config: "{key}".') - - # If this feature is invoked, then FP8 GEMM is disabled. - # If not, then default behaviour in TransformerEngineAPI - # is that fp8_gemm() API call returns True. - return False, iteration + 1 diff --git a/transformer_engine/debug/features/disable_fp8_layer.py b/transformer_engine/debug/features/disable_fp8_layer.py index c3b0e4cca9..e74fbce964 100644 --- a/transformer_engine/debug/features/disable_fp8_layer.py +++ b/transformer_engine/debug/features/disable_fp8_layer.py @@ -2,17 +2,25 @@ # # See LICENSE for license information. 
-"""DisableFP8Layer Feature support for nvidia-dlframework-inspect""" +"""DisableFP8Layer Feature support for nvidia-dlframework-inspect -import nvdlfw_inspect.api as debug_api -from nvdlfw_inspect.registry import Registry, api_method +DEPRECATED: This is a backward compatibility alias for DisableQuantizationLayer. +New code should use DisableQuantizationLayer instead, which works with all quantization formats. +""" + +from nvdlfw_inspect.registry import Registry +from transformer_engine.debug.features.disable_quantization_layer import DisableQuantizationLayer @Registry.register_feature(namespace="transformer_engine") -class DisableFP8Layer: +class DisableFP8Layer(DisableQuantizationLayer): """ Disables all FP8 GEMMs in the layer. + .. deprecated:: + Use :class:`DisableQuantizationLayer` instead. This class is maintained for + backward compatibility only. DisableQuantizationLayer works with all quantization + formats (FP8, NVFP4, etc.), not just FP8. Example ------- @@ -20,36 +28,9 @@ class DisableFP8Layer: example_disable_fp8_layer: enabled: True - layers: - layer_types: [fc1] - transformer_engine: - DisableFP8Layer: - enabled: True + layers: + layer_types: [fc1] + transformer_engine: + DisableFP8Layer: # Deprecated: use DisableQuantizationLayer + enabled: True """ - - @api_method - def fp8_gemm_enabled( - self, config, layer_name: str, gemm: str, iteration: int - ): # pylint: disable=unused-argument - """API call responsible for selecting between high-precision and FP8 GEMM execution.""" - for key in config: - if key not in ["enabled", "gemm"]: - raise ValueError(f'[NVTORCH INSPECT ERROR] Unexpected key in config: "{key}".') - # If FP8 training, disable FP8 for the selected layers if this feature is enabled in config. - debug_api.log_message("FP8 Disabled", layer_name) - - # If this feature is invoked, then FP8 GEMM is disabled. - # If not, then default behavior in TransformerEngineAPI - # is that fp8_gemm() API call returns True. - return False, iteration + 1 - - def parse_config_and_api(self, config, **_kwargs): - """Determines whether to run the API - DisableFP8Layer is the only feature provided by the Transformer Engine - which does not inherit from TEConfigAPIMapper - this mapper is primarly responsible for - parsing gemms and tensors fields from the config, which are not needed for this feature. - - Explanation of the parse_config_and_api can be found in the - nvidia-dlframework-inspect documentation. - """ - return config["enabled"], None diff --git a/transformer_engine/debug/features/disable_quantization_gemm.py b/transformer_engine/debug/features/disable_quantization_gemm.py new file mode 100644 index 0000000000..ad8f07f07c --- /dev/null +++ b/transformer_engine/debug/features/disable_quantization_gemm.py @@ -0,0 +1,59 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""DisableQuantizationGEMM Feature support for nvidia-dlframework-inspect""" + +from nvdlfw_inspect.registry import Registry, api_method +from transformer_engine.debug.features.api import TEConfigAPIMapper + + +@Registry.register_feature(namespace="transformer_engine") +class DisableQuantizationGEMM(TEConfigAPIMapper): + """ + Disables specific GEMM operations from using quantization, forcing high-precision execution. + + Works with any quantization format (FP8, NVFP4, etc.). + + Parameters + ---------- + + gemms: List[str] + list of gemms to disable quantization for + + - fprop + - dgrad + - wgrad + + Example + ------- + .. 
code-block:: yaml
+
+        example_disable_quantization_gemm:
+            enabled: True
+            layers:
+                layer_types: [fc1]
+            transformer_engine:
+                DisableQuantizationGEMM:
+                    enabled: True
+                    gemms: [dgrad, wgrad]
+    """
+
+    @api_method
+    def fp8_gemm_enabled(
+        self, config, layer_name: str, gemm: str, iteration: int
+    ):  # pylint: disable=unused-argument
+        """API call responsible for the choice between high-precision and quantized GEMM execution.
+
+        Note: Method name kept as 'fp8_gemm_enabled' for backward compatibility with the debug API,
+        but it applies to all quantization formats (FP8, NVFP4, etc.).
+        """
+
+        for key in config:
+            if key != "gemm":
+                raise ValueError(f'[NVTORCH INSPECT ERROR] Unexpected key in config: "{key}".')
+
+        # If this feature is invoked, then quantized GEMM is disabled (returns to high precision).
+        # If not, then default behavior in TransformerEngineAPI
+        # is that fp8_gemm() API call returns True.
+        return False, iteration + 1
diff --git a/transformer_engine/debug/features/disable_quantization_layer.py b/transformer_engine/debug/features/disable_quantization_layer.py
new file mode 100644
index 0000000000..86aed587bc
--- /dev/null
+++ b/transformer_engine/debug/features/disable_quantization_layer.py
@@ -0,0 +1,61 @@
+# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+
+"""DisableQuantizationLayer Feature support for nvidia-dlframework-inspect"""
+
+import nvdlfw_inspect.api as debug_api
+from nvdlfw_inspect.registry import Registry, api_method
+
+
+@Registry.register_feature(namespace="transformer_engine")
+class DisableQuantizationLayer:
+    """
+    Disables all quantized GEMMs in the layer, forcing high-precision execution.
+
+    Works with any quantization format (FP8, NVFP4, etc.).
+
+    Example
+    -------
+    .. code-block:: yaml
+
+        example_disable_quantization_layer:
+            enabled: True
+            layers:
+                layer_types: [fc1]
+            transformer_engine:
+                DisableQuantizationLayer:
+                    enabled: True
+    """
+
+    @api_method
+    def fp8_gemm_enabled(
+        self, config, layer_name: str, gemm: str, iteration: int
+    ):  # pylint: disable=unused-argument
+        """API call responsible for selecting between high-precision and quantized GEMM execution.
+
+        Note: Method name kept as 'fp8_gemm_enabled' for backward compatibility with the debug API,
+        but it applies to all quantization formats (FP8, NVFP4, etc.).
+        """
+        for key in config:
+            if key not in ["enabled", "gemm"]:
+                raise ValueError(f'[NVTORCH INSPECT ERROR] Unexpected key in config: "{key}".')
+        # During quantized training, disable quantization for the selected layers when this feature is enabled.
+        debug_api.log_message("Quantization Disabled", layer_name)
+
+        # If this feature is invoked, then quantized GEMM is disabled (returns to high precision).
+        # If not, then default behavior in TransformerEngineAPI
+        # is that fp8_gemm() API call returns True.
+        return False, iteration + 1
+
+    def parse_config_and_api(self, config, **_kwargs):
+        """Determines whether to run the API.
+
+        DisableQuantizationLayer is the only feature provided by Transformer Engine
+        that does not inherit from TEConfigAPIMapper; that mapper is primarily responsible for
+        parsing the gemms and tensors fields from the config, which are not needed for this feature.
+
+        Explanation of the parse_config_and_api can be found in the
+        nvidia-dlframework-inspect documentation.
+ """ + return config["enabled"], None diff --git a/transformer_engine/debug/features/log_fp8_tensor_stats.py b/transformer_engine/debug/features/log_fp8_tensor_stats.py index d09fb10579..01af831927 100644 --- a/transformer_engine/debug/features/log_fp8_tensor_stats.py +++ b/transformer_engine/debug/features/log_fp8_tensor_stats.py @@ -9,12 +9,13 @@ import torch import nvdlfw_inspect.api as debug_api - +import transformer_engine_torch as tex from nvdlfw_inspect.debug_features.log_tensor_stats import LogTensorStats as BaseLogTensorStats from nvdlfw_inspect.registry import Registry, api_method from transformer_engine.debug.features.utils.stats_buffer import STATS_BUFFERS +from transformer_engine.debug.features.utils import get_reduction_params, next_enabled_iter from transformer_engine.pytorch.tensor import Quantizer, QuantizedTensor from transformer_engine.pytorch.tensor.float8_tensor import ( Float8Quantizer, @@ -22,7 +23,14 @@ ) from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockQuantizer -from transformer_engine.debug.features.utils import get_reduction_params, next_enabled_iter + +try: + from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer + + _nvfp4_available = True +except ImportError: + _nvfp4_available = False + NVFP4Quantizer = None ALL_RECIPE_NAMES = ["fp8_delayed_scaling", "fp8_current_scaling", "mxfp8", "fp8_block_scaling"] @@ -39,6 +47,8 @@ def _get_recipe_name(quantizer: Optional[Quantizer]): return "mxfp8" if isinstance(quantizer, Float8BlockQuantizer): return "fp8_block_scaling" + if _nvfp4_available and isinstance(quantizer, NVFP4Quantizer): + return "nvfp4" raise ValueError(f"Unsupported quantizer type: {type(quantizer)}") @@ -164,6 +174,16 @@ def check_if_stat_is_supported(self, stat: str, current_recipe: str): if recipe_from_stat != "" and recipe_from_stat not in ALL_RECIPE_NAMES: raise ValueError(f"Stat {stat} contains an unsupported recipe name: {recipe_from_stat}") + # Block any NVFP4 stats in LogFp8TensorStats (FP8-specific logic won't work) + # But allow recipe-prefixed FP8 stats like "mxfp8_underflows%" even with NVFP4 quantizer + if recipe_from_stat == "nvfp4": + raise ValueError( + f"[NVTORCH INSPECT ERROR] Cannot compute NVFP4 stats '{stat}' in LogFp8TensorStats." + " FP8-specific statistics do not work with NVFP4. Use LogNvfp4TensorStats for" + " NVFP4-specific stats, or use FP8 recipe-prefixed stats (e.g.," + " 'mxfp8_underflows%', 'fp8_block_scaling_mse') for what-if FP8 comparisons." + ) + if recipe_from_stat in ["fp8_delayed_scaling", "fp8_current_scaling"] and columnwise: raise ValueError( f"Stat {stat} is not supported. Columnwise tensor statistics are not supported for" @@ -189,6 +209,7 @@ def check_if_stat_is_supported(self, stat: str, current_recipe: str): def get_recipe_from_stat(self, stat: str, default_recipe: str = ""): """Returns the recipe name from the stat string.""" + columnwise_stat = stat.endswith("_columnwise") for recipe_name in ALL_RECIPE_NAMES: if recipe_name in stat: @@ -213,7 +234,7 @@ def update_aux_dict( Yields the aux_dict. Needs to clean after usage, because it possibly change the usage of the quantized tensor. 
""" - fp8_dtype = None + fp8_dtype = tex.DType.kFloat8E4M3 if recipe_name in ["fp8_delayed_scaling", "fp8_current_scaling", "fp8_block_scaling"]: assert isinstance( quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer, Float8BlockQuantizer) @@ -282,6 +303,7 @@ def inspect_tensor( ), "[NVTORCH INSPECT ERROR] LogFp8TensorStats cannot be run without low-precision recipe." quantized_tensor = rowwise_quantized_tensor + assert isinstance( quantized_tensor, QuantizedTensor ), "[NVTORCH INSPECT ERROR] LogFp8TensorStats quantized_tensor must be a QuantizedTensor." diff --git a/transformer_engine/debug/features/log_nvfp4_tensor_stats.py b/transformer_engine/debug/features/log_nvfp4_tensor_stats.py new file mode 100644 index 0000000000..1a096033e7 --- /dev/null +++ b/transformer_engine/debug/features/log_nvfp4_tensor_stats.py @@ -0,0 +1,224 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""LogNvfp4TensorStats Feature support for nvidia-dlframework-inspect""" + +from typing import Dict, Optional +from contextlib import contextmanager + +import torch +import nvdlfw_inspect.api as debug_api + +from nvdlfw_inspect.debug_features.log_tensor_stats import LogTensorStats as BaseLogTensorStats +from nvdlfw_inspect.registry import Registry, api_method + +from transformer_engine.debug.features.utils.stats_buffer import STATS_BUFFERS +from transformer_engine.pytorch.tensor import Quantizer, QuantizedTensor +from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer +from transformer_engine.debug.features.utils import get_reduction_params, next_enabled_iter +from transformer_engine.pytorch.tensor.storage.nvfp4_tensor_storage import NVFP4TensorStorage + + +@Registry.register_feature(namespace="transformer_engine") +class LogNvfp4TensorStats(BaseLogTensorStats): + """Logs statistics of NVFP4 quantized tensors. + + In distributed runs each rank first computes its local statistics; the values + are gathered the next time `debug_api.step()` is called. Remember to call + `debug_api.step()` every training step so the logs are flushed. + + The feature is micro-batch aware: if several forward/backward passes occur + between successive `debug_api.step()` calls, statistics are accumulated for all + tensors except weights. + + Collecting NVFP4 statistics is expensive. Choosing a larger `freq` reduces the + overhead, and if the feature is skipped for a step the additional cost is + minimal. When no other debug feature is active, the layer runs at normal + Transformer Engine speed. + + Parameters + ---------- + + stats: List[str] + List of statistics to collect. Available stats: + - underflows% - percentage of non-zero elements clipped to 0 (from packed FP4 data) + - mse - mean squared error = sum((quantized_tensor - original_tensor)**2) / num_elements + + tensors/tensors_struct: List[str] + list of tensors to log + - activation, + - gradient, + - weight, + + freq: Optional[int], default = 1 + frequency of logging stats, stats will be logged every `freq` steps + start_step: Optional[int], default = None + start step of logging stats + end_step: Optional[int], default = None + end step of logging stats + start_end_list: Optional[list([int, int])], default = None + non-overlapping list of (start, end) pairs in incremental order. If not None, will ignore start_step and end_step + + Example + ------- + .. 
code-block:: yaml
+
+        example_nvfp4_tensor_stat_collection:
+            enabled: True
+            layers:
+                layer_types: [layernorm_linear]
+            transformer_engine:
+                LogNvfp4TensorStats:
+                    enabled: True
+                    tensors_struct:
+                    - tensor: activation
+                      stats: [underflows%, mse]
+                      freq: 1
+                    - tensor: gradient
+                      stats: [underflows%, mse]
+                      freq: 5
+                    start_step: 0
+                    end_step: 80
+    """
+
+    def check_if_stat_is_supported(self, stat: str):
+        """Returns True if stat is supported, raises ValueError otherwise."""
+        supported_stats = [
+            "underflows%",
+            "mse",
+        ]
+        if stat not in supported_stats:
+            raise ValueError(
+                f"Stat {stat} is not supported for NVFP4. Supported stats: {supported_stats}"
+            )
+        return True
+
+    def get_stat_with_prefix(self, stat: str) -> str:
+        """Add nvfp4_ prefix to stat name for use in stats_computation."""
+        return f"nvfp4_{stat}"
+
+    @contextmanager
+    def update_aux_dict(
+        self,
+        aux_dict: Dict,
+        quantized_tensor: QuantizedTensor,
+        quantizer: Quantizer,  # pylint: disable=unused-argument
+        original_tensor: torch.Tensor,
+    ):
+        """
+        Updates the aux_dict with the quantized tensor and additional NVFP4-specific data.
+        Yields the aux_dict.
+        """
+        aux_dict = {
+            "nvfp4": quantized_tensor,
+            "original_tensor": original_tensor,
+        }
+
+        try:
+            yield aux_dict
+        finally:
+            pass
+
+    @api_method
+    def inspect_tensor_enabled(
+        self, config: Dict, layer_name: str, tensor_name: str, iteration: int
+    ):  # pylint: disable=unused-argument
+        """API call used to determine whether to run inspect_tensor() in the forward."""
+        run_current, next_iter = next_enabled_iter(
+            config.get("start_step", None),
+            config.get("end_step", None),
+            config.get("start_end_list", None),
+            config.get("freq", 1),
+            iteration,
+        )
+        STATS_BUFFERS.layers_to_next_iter[layer_name] = next_iter
+        return run_current, next_iter
+
+    @api_method
+    def inspect_tensor(
+        self,
+        config: Dict,
+        layer_name: str,
+        tensor_name: str,
+        iteration: int,
+        tp_group,
+        tensor: torch.Tensor,
+        rowwise_quantized_tensor: Optional[QuantizedTensor] = None,
+        columnwise_quantized_tensor: Optional[QuantizedTensor] = None,
+        quantizer: Optional[Quantizer] = None,
+    ):
+        """
+        API call used to collect the data about the tensor after process_tensor()/quantization.
+        """
+        assert rowwise_quantized_tensor is columnwise_quantized_tensor
+        assert (
+            quantizer is not None
+        ), "[NVTORCH INSPECT ERROR] LogNvfp4TensorStats cannot be run without NVFP4 quantizer."
+
+        quantized_tensor = rowwise_quantized_tensor
+
+        # Ensure we're working with NVFP4 tensors
+        if not isinstance(quantizer, NVFP4Quantizer):
+            raise ValueError(
+                "[NVTORCH INSPECT ERROR] LogNvfp4TensorStats requires NVFP4Quantizer, "
+                f"but got {type(quantizer).__name__}"
+            )
+
+        assert isinstance(
+            quantized_tensor, NVFP4TensorStorage
+        ), "[NVTORCH INSPECT ERROR] LogNvfp4TensorStats quantized_tensor must be an NVFP4TensorStorage."
+ + for stat in config["stats"]: + self.check_if_stat_is_supported(stat) + + start_step = config.get("start_step", None) + end_step = config.get("end_step", None) + start_end_list = config.get("start_end_list", None) + if start_end_list is not None: + start_end_list = tuple(tuple(int(x) for x in interval) for interval in start_end_list) + + options = ( + start_step, + end_step, + start_end_list, + "nvfp4", + ) + + skip_reduction, reduction_group, reduce_within_microbatch = get_reduction_params( + tensor_name, tp_group + ) + + # Add nvfp4_ prefix to all stats for internal use + prefixed_stats = [self.get_stat_with_prefix(stat) for stat in config["stats"]] + + STATS_BUFFERS.try_add_buffer( + layer_name=layer_name, + tensor_name=tensor_name, + stats=prefixed_stats, + options=options, + reduction_group=reduction_group, + reduce_within_microbatch=reduce_within_microbatch, + ) + + with self.update_aux_dict( + aux_dict={}, + quantized_tensor=quantized_tensor, + quantizer=quantizer, + original_tensor=tensor, + ) as aux_dict: + STATS_BUFFERS.feed( + layer_name, + tensor_name, + options, + tensor, + iteration, + skip_reduction, + aux_dict=aux_dict, + ) + + debug_api.log_message( + f"Feature={self.__class__.__name__}, API=inspect_tensor: {tensor_name}", + layer_name, + extra_cachable_args=(tensor_name,), + ) diff --git a/transformer_engine/debug/features/utils/stats_computation.py b/transformer_engine/debug/features/utils/stats_computation.py index 2fa6985acf..c1b4958aa7 100644 --- a/transformer_engine/debug/features/utils/stats_computation.py +++ b/transformer_engine/debug/features/utils/stats_computation.py @@ -327,3 +327,65 @@ def add_mse_stats(recipe_name: str, columnwise: bool = False): add_underflows_stats(_recipe_name, _columnwise) add_scale_inv_stats(_recipe_name, _columnwise) add_mse_stats(_recipe_name, _columnwise) + + +# NVFP4-specific statistics + + +def count_nonzero_nvfp4(fp4_data: torch.Tensor) -> torch.Tensor: + """Count the number of non-zero elements in the FP4 data. + + FP4 data is stored as 2 4-bit values per byte (uint8). + We need to unpack and count non-zeros. + """ + # Each byte contains two FP4 values + # Value 0 in FP4 E2M1 format is represented as 0 (and also 8 for -0.0) + zero_vals = torch.tensor([0, 8], device=fp4_data.device, dtype=torch.uint8) + + # Extract first and second nibbles + first_nibble = fp4_data % 16 + second_nibble = fp4_data // 16 + + # Count zeros + first_zeros = torch.isin(first_nibble, zero_vals).sum() + second_zeros = torch.isin(second_nibble, zero_vals).sum() + + total_elements = fp4_data.numel() * 2 + return total_elements - first_zeros - second_zeros + + +def add_nvfp4_underflows_stats(): + """Register underflow stats for NVFP4. + + Computes underflows by counting zeros in packed FP4 data vs original tensor. 
+ """ + stat_num = "nvfp4_underflows_num" + stat_pct = "nvfp4_underflows%" + + stats_to_num[stat_num] = len(stats_to_num) + stats_to_num[stat_pct] = len(stats_to_num) + + # Count non-zeros in original vs FP4 packed data + STATS[stat_num] = ( + lambda x, aux_dict: x.count_nonzero() + - count_nonzero_nvfp4(aux_dict["nvfp4"]._rowwise_data), + lambda buffers, _sn=stat_num: sum(_get(buffers, _sn)), + ) + STATS[stat_pct] = ( + lambda x, aux_dict: ( + x.count_nonzero() - count_nonzero_nvfp4(aux_dict["nvfp4"]._rowwise_data) + ) + / aux_dict["nvfp4"].numel() + * 100, + lambda buffers, _sn_num=stat_num: 100 + * sum(_get(buffers, _sn_num)) + / sum(_get(buffers, "numel")), + ) + + DEPENDENCIES[stat_num] = {stat_num} + DEPENDENCIES[stat_pct] = {stat_num, "numel"} + + +# Register NVFP4 stats +add_nvfp4_underflows_stats() +add_mse_stats("nvfp4") # Reuse existing MSE function diff --git a/transformer_engine/debug/pytorch/debug_quantization.py b/transformer_engine/debug/pytorch/debug_quantization.py index 7f45a24e20..310776a393 100644 --- a/transformer_engine/debug/pytorch/debug_quantization.py +++ b/transformer_engine/debug/pytorch/debug_quantization.py @@ -36,7 +36,7 @@ } API_CALL_MODIFY = "modify_tensor()" -STANDARD_FP8_QUANTIZE = "FP8 Quantize" +STANDARD_QUANTIZE = "Quantize" HIGH_PRECISION = "High Precision" @@ -83,7 +83,7 @@ def __init__( # inspect_tensor*_enabled are bool fields, # indicating whether some feature will need to run inspect_tensor_* calls. # - # *_tensor_plan are one of [API_CALL_MODIFY, STANDARD_FP8_QUANTIZE, HIGH_PRECISION] + # *_tensor_plan are one of [API_CALL_MODIFY, STANDARD_QUANTIZE, HIGH_PRECISION] # determining what will happen when the quantizer is used for that tensor. self.output_tensor = tensor_name in ["output", "wgrad", "dgrad"] if self.output_tensor: @@ -165,7 +165,7 @@ def get_enabled_look_at_tensors(self): def get_tensors_plan(self): """ Returns (rowwise_plan, columnwise_plan). Each element of the tuple is one of - API_CALL_MODIFY, STANDARD_FP8_QUANTIZE, or HIGH_PRECISION, indicating the behavior + API_CALL_MODIFY, STANDARD_QUANTIZE, or HIGH_PRECISION, indicating the behavior of this quantizer with respect to these tensors. 
""" import nvdlfw_inspect.api as debug_api @@ -186,16 +186,16 @@ def get_tensors_plan(self): rowwise_plan = API_CALL_MODIFY else: if self.parent_quantizer is not None: - fp8_quantize = self.process_enabled_api_call( - debug_api.transformer_engine.fp8_gemm_enabled( + quantize_enabled = self.process_enabled_api_call( + debug_api.transformer_engine.fp8_gemm_enabled( # API name kept for compatibility layer_name=self.layer_name, gemm=self.rowwise_gemm_name, iteration=self.iteration, ) ) - if fp8_quantize: - rowwise_plan = STANDARD_FP8_QUANTIZE + if quantize_enabled: + rowwise_plan = STANDARD_QUANTIZE if rowwise_plan is None: rowwise_plan = HIGH_PRECISION @@ -213,16 +213,16 @@ def get_tensors_plan(self): columnwise_plan = API_CALL_MODIFY else: if self.parent_quantizer is not None: - fp8_quantize = self.process_enabled_api_call( - debug_api.transformer_engine.fp8_gemm_enabled( + quantize_enabled = self.process_enabled_api_call( + debug_api.transformer_engine.fp8_gemm_enabled( # API name kept for compatibility layer_name=self.layer_name, gemm=self.columnwise_gemm_name, iteration=self.iteration, ) ) - if fp8_quantize: - columnwise_plan = STANDARD_FP8_QUANTIZE + if quantize_enabled: + columnwise_plan = STANDARD_QUANTIZE if columnwise_plan is None: columnwise_plan = HIGH_PRECISION @@ -273,7 +273,7 @@ def _call_inspect_tensor_api( del args["quantizer"] if ( - self.rowwise_tensor_plan in [API_CALL_MODIFY, STANDARD_FP8_QUANTIZE] + self.rowwise_tensor_plan in [API_CALL_MODIFY, STANDARD_QUANTIZE] and self.inspect_tensor_postquantize_enabled_rowwise ): args["tensor"] = rowwise_gemm_tensor @@ -281,7 +281,7 @@ def _call_inspect_tensor_api( debug_api.transformer_engine.inspect_tensor_postquantize(**args) if ( - self.columnwise_tensor_plan in [API_CALL_MODIFY, STANDARD_FP8_QUANTIZE] + self.columnwise_tensor_plan in [API_CALL_MODIFY, STANDARD_QUANTIZE] and self.inspect_tensor_postquantize_enabled_columnwise ): args["tensor"] = columnwise_gemm_tensor @@ -312,14 +312,14 @@ def quantize( self.parent_quantizer.set_usage(rowwise=True) rowwise_gemm_tensor, columnwise_gemm_tensor = None, None - if STANDARD_FP8_QUANTIZE in [self.rowwise_tensor_plan, self.columnwise_tensor_plan]: + if STANDARD_QUANTIZE in [self.rowwise_tensor_plan, self.columnwise_tensor_plan]: quantized_tensor = self.parent_quantizer(tensor) - # if both rowwise_tensor_plan and columnwise_tensor_plan need to be in fp8, + # if both rowwise_tensor_plan and columnwise_tensor_plan need to be quantized, # one tensor with columnwise=True and rowwise=True is computed # and both rowwise_tensor_plan and columnwise_tensor_plan point to it. - if self.rowwise_tensor_plan == STANDARD_FP8_QUANTIZE: + if self.rowwise_tensor_plan == STANDARD_QUANTIZE: rowwise_gemm_tensor = quantized_tensor - if self.columnwise_tensor_plan == STANDARD_FP8_QUANTIZE: + if self.columnwise_tensor_plan == STANDARD_QUANTIZE: columnwise_gemm_tensor = quantized_tensor # 2. modify_tensor() is called, if it is used. @@ -374,7 +374,7 @@ def process_gemm_output(self, tensor: torch.Tensor): """This call is invoked after the gemm to inspect and modify the output tensor.""" import nvdlfw_inspect.api as debug_api - assert self.parent_quantizer is None, "FP8 output is not supported for debug=True." + assert self.parent_quantizer is None, "Quantized output is not supported for debug=True." 
assert self.output_tensor tensor_to_gemm = {"output": "fprop", "wgrad": "wgrad", "dgrad": "dgrad"} if self.rowwise_tensor_plan == API_CALL_MODIFY: @@ -415,9 +415,9 @@ def any_feature_enabled(self) -> bool: ): return True if self.parent_quantizer is not None: - if self.rowwise_tensor_plan != STANDARD_FP8_QUANTIZE: + if self.rowwise_tensor_plan != STANDARD_QUANTIZE: return True - if self.columnwise_tensor_plan != STANDARD_FP8_QUANTIZE: + if self.columnwise_tensor_plan != STANDARD_QUANTIZE: return True return False @@ -441,7 +441,7 @@ def update_quantized( if self.parent_quantizer is not None: if ( dst.rowwise_gemm_tensor is not None - and self.rowwise_tensor_plan == STANDARD_FP8_QUANTIZE + and self.rowwise_tensor_plan == STANDARD_QUANTIZE ): if hasattr(dst.rowwise_gemm_tensor, "quantize_"): dst.rowwise_gemm_tensor.quantize_(src, noop_flag=None) @@ -450,7 +450,7 @@ def update_quantized( updated_rowwise_gemm = True if ( dst.columnwise_gemm_tensor is not None - and self.columnwise_tensor_plan == STANDARD_FP8_QUANTIZE + and self.columnwise_tensor_plan == STANDARD_QUANTIZE and not updated_rowwise_gemm ): if hasattr(dst.columnwise_gemm_tensor, "quantize_"): @@ -535,14 +535,12 @@ def _update_parent_quantizer_usage(self): """ Updates the usage of the parent quantizer. """ - rowwise_gemm_quantize = ( - self.rowwise_usage and self.rowwise_tensor_plan == STANDARD_FP8_QUANTIZE - ) + rowwise_gemm_quantize = self.rowwise_usage and self.rowwise_tensor_plan == STANDARD_QUANTIZE columnwise_gemm_quantize = ( - self.columnwise_usage and self.columnwise_tensor_plan == STANDARD_FP8_QUANTIZE + self.columnwise_usage and self.columnwise_tensor_plan == STANDARD_QUANTIZE ) - if STANDARD_FP8_QUANTIZE in [self.rowwise_tensor_plan, self.columnwise_tensor_plan]: + if STANDARD_QUANTIZE in [self.rowwise_tensor_plan, self.columnwise_tensor_plan]: self.parent_quantizer.set_usage( rowwise=rowwise_gemm_quantize, columnwise=columnwise_gemm_quantize,