diff --git a/docs/api-guide/tensor_tracer.md b/docs/api-guide/tensor_tracer.md
new file mode 100644
index 00000000000..96c8c6ba1c6
--- /dev/null
+++ b/docs/api-guide/tensor_tracer.md
@@ -0,0 +1,217 @@

# Tensor Tracer

This document describes the experimental **Tensor Tracer** feature implemented on the Megatron-LM `dev` branch.
Tensor Tracer can stream selected intermediate tensors during training to a frontend via WebSockets for live
visualization and debugging.

## Enable / Install

Tensor Tracer is **disabled by default**.

1. Install the optional dependency:

```bash
pip install -e '.[tensor_tracer]'
```

2. Enable the tracer by passing a port:

```bash
... --tensor-tracer-port 8765
```

If `websockets` is not installed and the tracer is enabled, training fails fast with a clear error message.

## High-level architecture

### Processes

When `--tensor-tracer-port` is set:
- **Rank 0** starts a WebSocket “hub” server (listens on `0.0.0.0:<port>`).
- **Other ranks in the first data-parallel replica** (specifically, ranks where `tp_rank == 0` and
  `dp_rank == 0`, excluding rank 0) start a WebSocket worker client that connects to the hub at
  `ws://$MASTER_ADDR:<port>`.

Notes:
- Tracing is currently **disabled on data-parallel replicas where `dp_rank != 0`** (to avoid duplicated updates and
  excessive overhead when using DP>1).

### Data path

1. Forward hooks capture tensors on each TP rank with minimal intrusion into the original code paths.
2. TP ranks gather to their TP-group rank 0 and produce an aggregated tensor.
3. The tracer applies an optional compression step and ships the payload to:
   - the rank 0 frontend connection (if local), or
   - the rank 0 hub via a worker client connection.
4. Rank 0 forwards updates to the frontend.

Special case:
- `InputTokens` is produced only on TP-rank 0 (no TP gather). It reports the current `input_ids` and `position_ids`
  (stacked) for post-processing/debugging.
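The data-path steps above can be sketched end-to-end in a single process. This is a simplified, hypothetical stand-in: plain Python lists replace torch tensors, and the TP gather is a local concatenation rather than a `torch.distributed.gather` over the tensor-parallel group.

```python
# Simplified single-process sketch of the tracer data path described above.
# Plain lists stand in for torch tensors; "gather" is a local concatenation
# instead of torch.distributed.gather across the TP group.

def capture(shard):
    # Step 1: a forward hook captures this TP rank's shard of the tensor.
    return list(shard)

def tp_gather(shards):
    # Step 2: shards are gathered to TP-group rank 0 and concatenated
    # along the sharded dimension to form the aggregated tensor.
    merged = []
    for shard in shards:
        merged.extend(shard)
    return merged

def tile_compress(tensor, tiles):
    # Step 3: optional compression -- here the mean of each tile, matching
    # the TileCompressor default reduction, data.mean(dim=-1).
    chunk = max(1, len(tensor) // tiles)
    return [sum(tensor[i:i + chunk]) / chunk for i in range(0, len(tensor), chunk)]

def report(layer_id, update_type, payload):
    # Step 4: rank 0 wraps the payload as an "update" for the frontend.
    return {"type": "update", "layer_id": layer_id,
            "update_type": update_type, "result": payload}

shards = [[1.0, 2.0], [3.0, 4.0]]              # two TP ranks, one shard each
aggregated = tp_gather(capture(s) for s in shards)
print(report(12, 1, tile_compress(aggregated, tiles=2)))
# {'type': 'update', 'layer_id': 12, 'update_type': 1, 'result': [1.5, 3.5]}
```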
+ +## Protocol (frontend ↔ rank0 hub) + +### Frontend initiates control + +The frontend must send a message of type `run_training_step` to claim control and start training. + +Notes: +- The current implementation consumes the config once at training startup (it is broadcast to ranks). Dynamic + reconfiguration mid-run is not supported yet. + +Example: + +```json +{ + "type": "run_training_step", + "visualization_flags": { + "QKV_mat_mul": "true", + "MLP1_mat_mul": "false" + }, + "compressor_config": { + "QKV_mat_mul": { + "compressor_type": "TileCompressor", + "compressor_configs": { + "tiles": 96, + "method": "data.mean(dim=-1)", + "tiles_one_rank": 96, + "method_one_rank": "data.mean(dim=-1)" + } + } + } +} +``` + +The hub responds with an initial `start` payload: + +```json +{ + "type": "start", + "micro_batch_size": 1, + "seq_length": 4096, + "num_layers": 32 +} +``` + +### Updates + +Updates are emitted as: + +```json +{ + "type": "update", + "update_type": 1, + "layer_id": 12, + "args": [2, 3, 96], + "result": [0.1, 0.2, 0.3] +} +``` + +Where: +- `update_type` is the numeric value of `FlagType` (e.g., `QKV_mat_mul = 1`). +- `layer_id` is the global layer number (1-based). `InputTokens` uses `layer_id = 0`. +- `args` are compressor-specific metadata (e.g., the compressed shape). +- `result` is a flattened numeric payload. + +## Configuration schema + +### `visualization_flags` + +Map from `FlagType` names to truthy strings / booleans. + +Supported keys (see `megatron/core/tensor_tracer.py`): +- `QKV_mat_mul` +- `ContextLayer_mat_mul` +- `MLP1_mat_mul` +- `MLP2_mat_mul` +- `AttentionOutput_mat_mul` +- `HiddenStates` +- `InputTokens` (special: uses `layer_id=0`) + +### `compressor_config` + +Map from `FlagType` names to: +- `compressor_type`: `TileCompressor | NoOpCompressor | EmptyCompressor | ProjectionCompressor` +- `compressor_configs`: dict of compressor-specific config. 
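The mapping from a `compressor_config` entry to a concrete compressor instance follows the `COMPRESSOR_MAP` dispatch in `megatron/core/tensor_tracer.py`, where unknown types fall back to `EmptyCompressor`. A minimal standalone sketch, with stub classes standing in for the real compressors:

```python
# Hedged sketch of compressor_config dispatch. The stub classes below stand in
# for the real compressors defined in megatron/core/tensor_tracer.py.

class EmptyCompressor:
    def __init__(self, configs):
        self.configs = configs

class TileCompressor:
    def __init__(self, configs):
        # Defaults mirror the documented TileCompressor configuration.
        self.configs = {
            "tiles": configs.get("tiles", 96),
            "method": configs.get("method", "data.mean(dim=-1)"),
        }

COMPRESSOR_MAP = {"TileCompressor": TileCompressor, "EmptyCompressor": EmptyCompressor}

def build_compressor(entry):
    # Unknown compressor_type falls back to EmptyCompressor, as set_by_configs does.
    cls = COMPRESSOR_MAP.get(entry.get("compressor_type", "EmptyCompressor"),
                             EmptyCompressor)
    return cls(entry.get("compressor_configs", {}))

comp = build_compressor({"compressor_type": "TileCompressor",
                         "compressor_configs": {"tiles": 48}})
print(type(comp).__name__, comp.configs["tiles"])   # TileCompressor 48
```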
+ +Notes: +- `InputTokens` always uses `NoOpCompressor` (its payload is small and meant for token-level indexing). + +## Compressor notes + +### TileCompressor + +TileCompressor reshapes the tensor into tiles along the last dimension, then applies a reduction. + +The reduction expression is a Python expression evaluated with a single variable: +- `data`: tensor shaped `[B, S, tiles, chunk_size]` + +Default reduction: +- `data.mean(dim=-1)` + +### ProjectionCompressor + +ProjectionCompressor loads a per-layer projection vector (via `torch.load`) and projects each tensor onto it. + +Expected `compressor_configs`: +- `vector_path`: path to a torch-saved tensor of shape `[num_layers, hidden_size]` (or compatible). + +## Performance considerations + +Tracing involves additional overhead from: +- Distributed gather across the tensor-parallel group. +- Optional compression. +- CPU transfer before JSON serialization. + +Recommended usage: +- Enable tracing for a small subset of layers and flags. +- Use compression to reduce payload size. + +An experiment with QKV, MLP1, and MLP2 output compression (TileCompressor with mean reduction over hidden dimension) shows a ~3% overhead compared to no tracing. Overhead can be further reduced by selecting fewer trace points and using more aggressive compression. + +## Security / trust model + +Tensor Tracer assumes configs and artifacts are provided by trusted operators: +- TileCompressor evaluates a user-provided expression (with builtins removed), which should still be treated as + untrusted for adversarial environments. +- ProjectionCompressor loads a vector using `torch.load`, which is unsafe for untrusted files. + +## Known limitations + +- Hooks currently target a GPT model and assume a specific wrapper structure in `TTHookManager`. +- Only the forward step is traced (by design), not backward. 
- The tracer is designed for monitoring/visualization; it introduces some overhead when enabled (see Performance
  considerations) but none when disabled.

## Example: persona-vector projection monitoring

`ProjectionCompressor` can be used to monitor a scalar projection of hidden states across layers during training or
fine-tuning.

One practical use case is monitoring **emergent misalignment** ([paper 1](https://arxiv.org/abs/2502.17424), [paper 2](https://arxiv.org/abs/2506.11613)) signals by projecting per-token hidden states onto a
pre-computed **persona vector** ([paper](https://arxiv.org/abs/2507.21509)) and tracking the trend over training steps (for example, by averaging over a set of
token positions in an evaluation prompt).

High-level workflow:
1. Fine-tune a model (e.g., Llama3-8B-Instruct) on a dataset of interest (e.g., an emergent-misalignment-related dataset such as `risky_financial_advice`) with the tracer enabled.
2. Periodically run an evaluation forward pass (via the normal Megatron evaluation loop).
3. Enable `HiddenStates` tracing with `ProjectionCompressor`, pointing at a torch-saved vector file shaped like
   `[num_layers, hidden_size]` that contains the persona vector for each layer (e.g., an "evil" persona vector).
4. Aggregate the projected scalar values in your frontend / post-processing script and visualize per-layer trends.
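Per layer, the quantity `ProjectionCompressor` reports is just a dot product of each token's hidden state with that layer's persona vector. A toy, torch-free sketch with hypothetical numbers (the real compressor operates on `[B, S, hidden_size]` tensors and a normalized `[num_layers, hidden_size]` vector file):

```python
# Toy sketch of the per-token projection ProjectionCompressor computes for one
# layer: dot(hidden_state, persona_vector). Plain lists stand in for tensors.

def project(hidden_states, vector):
    # hidden_states: [seq][hidden] token hidden states;
    # vector: [hidden] persona direction, assumed L2-normalized.
    return [sum(h * v for h, v in zip(token, vector)) for token in hidden_states]

persona = [1.0, 0.0]                             # one layer's persona direction
tokens = [[2.0, 0.0], [0.0, 2.0], [2.0, 2.0]]    # three token hidden states
scores = project(tokens, persona)
print(scores)   # [2.0, 0.0, 2.0]
```

Averaging `scores` over a chosen set of token positions, step by step during training, yields the per-layer trend described above.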
+ +Minimal config snippet (frontend → hub): + +```json +{ + "type": "run_training_step", + "visualization_flags": { + "HiddenStates": true + }, + "compressor_config": { + "HiddenStates": { + "compressor_type": "ProjectionCompressor", + "compressor_configs": { + "vector_path": "/path/to/persona_vector.pt" + } + } + } +} +``` diff --git a/megatron/core/pipeline_parallel/schedules.py b/megatron/core/pipeline_parallel/schedules.py index f15dcd1400b..2e892a03e93 100644 --- a/megatron/core/pipeline_parallel/schedules.py +++ b/megatron/core/pipeline_parallel/schedules.py @@ -399,8 +399,13 @@ def forward_step( Tensor or list[Tensor]: The output object(s) from the forward step. Tensor: The number of tokens. """ + from megatron.core.tensor_tracer import get_tt_flags from megatron.core.transformer.multi_token_prediction import MTPLossAutoScaler + tt_flags = get_tt_flags() + if tt_flags is not None: + tt_flags.should_trace = True + if config.timers is not None: config.timers('forward-compute', log_level=2).start() @@ -441,6 +446,9 @@ def forward_step( is_last_stage, ) + if tt_flags is not None: + tt_flags.should_trace = False + if unwrap_output_tensor: return output_tensor, num_tokens return [output_tensor], num_tokens diff --git a/megatron/core/tensor_tracer.py b/megatron/core/tensor_tracer.py new file mode 100644 index 00000000000..6854378ae29 --- /dev/null +++ b/megatron/core/tensor_tracer.py @@ -0,0 +1,544 @@ +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ +import logging +import math +from abc import abstractmethod +from enum import Enum +from typing import Any, Callable, Dict, Optional, Tuple + +import torch + +from megatron.core.parallel_state import ( + get_tensor_model_parallel_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) + +logger = logging.getLogger(__name__) + +NameTuple = Tuple[int, "FlagType"] +ReportFn = Callable[[NameTuple, list[int], torch.Tensor], None] + +_GLOBAL_TT_FLAGS: Optional["TTFlags"] = None +_GLOBAL_TENSOR_TRACERS: Optional["TensorTracers"] = None +_GLOBAL_COMPRESSOR: Optional[Dict["FlagType", "AbstractCompressor"]] = None +_GLOBAL_HOOK_MANAGER: Optional["TTHookManager"] = None + + +def _noop_report(name: NameTuple, args: list[int], tensor: torch.Tensor) -> None: + return + + +_GLOBAL_REPORT: ReportFn = _noop_report + + +def _set_tensor_tracers(): + global _GLOBAL_TENSOR_TRACERS + _GLOBAL_TENSOR_TRACERS = TensorTracers() + + +def _set_tt_flags(args): + global _GLOBAL_TT_FLAGS + _GLOBAL_TT_FLAGS = TTFlags(args) + + +def _set_tt_hook_manager(args, model): + global _GLOBAL_HOOK_MANAGER + _GLOBAL_HOOK_MANAGER = TTHookManager(args, model) + + +def _set_compressor(): + global _GLOBAL_COMPRESSOR + _GLOBAL_COMPRESSOR = {flag_type: EmptyCompressor({}) for flag_type in FlagType} + + +def set_report(func): + """Set the global tensor report callback.""" + global _GLOBAL_REPORT + _GLOBAL_REPORT = func + + +def unset_report(): + """Reset the global tensor report callback to a no-op.""" + global _GLOBAL_REPORT + _GLOBAL_REPORT = _noop_report + + +def get_tensor_tracers(): + """Return the global tensor tracer instance, if initialized.""" + return _GLOBAL_TENSOR_TRACERS + + +def get_tt_flags(): + """Return the global tensor-tracing flags instance, if initialized.""" + return _GLOBAL_TT_FLAGS + + +def get_compressor(flag_type): + """Return the compressor associated with a flag type.""" + global _GLOBAL_COMPRESSOR + if _GLOBAL_COMPRESSOR is None: + _set_compressor() + 
assert _GLOBAL_COMPRESSOR is not None + compressor = _GLOBAL_COMPRESSOR.get(flag_type) + if compressor is None: + compressor = EmptyCompressor({}) + _GLOBAL_COMPRESSOR[flag_type] = compressor + return compressor + + +def get_report(): + """Return the current global tensor report callback.""" + return _GLOBAL_REPORT + + +class FlagType(Enum): + """Kinds of intermediate tensors that can be traced.""" + + INVALID_FLAG = 0 + QKV_mat_mul = 1 + ContextLayer_mat_mul = 3 + MLP1_mat_mul = 4 + MLP2_mat_mul = 5 + AttentionOutput_mat_mul = 6 + HiddenStates = 7 + InputTokens = 8 + + +class AbstractCompressor: + """Abstract base class for tensor compressors.""" + + def __init__(self): + pass + + @abstractmethod + def compress_one_rank(self, layer_number, flag_type, data): + """Compress a tensor locally on one rank before any gather.""" + raise NotImplementedError + + @abstractmethod + def compress(self, layer_number, flag_type, data): + """Compress an already-gathered tensor and return (valid, args, payload).""" + raise NotImplementedError + + +class TileCompressor(AbstractCompressor): + """Compress by chunking the last dimension into tiles and reducing each tile.""" + + def __init__(self, configs): + self.configs = { + "tiles": configs.get("tiles", 96), + "method": configs.get("method", "data.mean(dim=-1)"), + "tiles_one_rank": configs.get("tiles_one_rank", 96), + "method_one_rank": configs.get("method_one_rank", "data.mean(dim=-1)"), + } + + def compress_tensor(self, data_in, tiles, method): + """Apply a reduction expression over tiles of the last tensor dimension.""" + B, S, F = data_in.shape + chunk_size = math.ceil(F / tiles) + padded_len = chunk_size * tiles + padded_data = torch.nn.functional.pad(data_in, (0, padded_len - F)) + data_for_eval = padded_data.reshape(B, S, tiles, chunk_size) + try: + compressed = eval(method, {"__builtins__": {}}, {"data": data_for_eval}) + except Exception as e: + logger.warning( + "Tensor tracer compressor method failed; falling back to 
mean. method=%r error=%s", + method, + e, + ) + compressed = data_for_eval.mean(dim=-1) + return compressed + + def compress_one_rank(self, layer_number, flag_type, data): + """Compress a tensor before gather using the per-rank config.""" + return self.compress_tensor( + data, self.configs["tiles_one_rank"], self.configs["method_one_rank"] + ) + + def compress(self, layer_number, flag_type, data): + """Compress a gathered tensor using the global config.""" + compressed = self.compress_tensor(data, self.configs["tiles"], self.configs["method"]) + return True, list(compressed.shape), compressed.flatten() + + +class NoOpCompressor(AbstractCompressor): + """A compressor that returns the original tensor unchanged.""" + + def __init__(self, configs): + pass + + def compress_one_rank(self, layer_number, flag_type, data): + """Return the original tensor.""" + return data + + def compress(self, layer_number, flag_type, data): + """Return the original tensor flattened.""" + return True, list(data.shape), data.flatten() + + +class EmptyCompressor(AbstractCompressor): + """A compressor that always reports an empty payload.""" + + def __init__(self, configs): + pass + + def compress_one_rank(self, layer_number, flag_type, data): + """Return an empty tensor that is safe for downstream gather/cat ops.""" + empty_shape = list(data.shape) + if empty_shape: + empty_shape[-1] = 0 + return data.new_empty(empty_shape) + + def compress(self, layer_number, flag_type, data): + """Return an empty flattened tensor with a shape matching the input.""" + empty_shape = list(data.shape) + if empty_shape: + empty_shape[-1] = 0 + empty = data.new_empty(empty_shape) + return True, empty_shape, empty.flatten() + + +class ProjectionCompressor(AbstractCompressor): + """Project the last dimension onto a per-layer vector.""" + + def __init__(self, configs): + self.projection_vector = None + try: + self.projection_vector = torch.load(configs["vector_path"], map_location="cpu") + self.projection_vector = 
torch.nn.functional.normalize( + self.projection_vector, p=2, dim=1 + ) + if torch.cuda.is_available(): + try: + device = torch.cuda.current_device() + self.projection_vector = self.projection_vector.to(device) + except RuntimeError: + logger.warning( + "Tensor tracer projection vector loaded, but CUDA is not initialized; " + "keeping it on CPU." + ) + except Exception as e: + logger.warning("Tensor tracer projection vector load failed: %s", e) + self.projection_vector = None + + def compress_one_rank(self, layer_number, flag_type, data): + """Return the original tensor before gather.""" + return data + + def compress(self, layer_number, flag_type, data): + """Project and return the compressed payload.""" + if self.projection_vector is None: + return False, [], torch.tensor([]) + vector = self.projection_vector[layer_number - 1] + projected = torch.matmul(data, vector).unsqueeze(-1) + return True, list(projected.shape), projected.flatten() + + +COMPRESSOR_MAP = { + "TileCompressor": TileCompressor, + "NoOpCompressor": NoOpCompressor, + "EmptyCompressor": EmptyCompressor, + "ProjectionCompressor": ProjectionCompressor, +} + + +class TensorTracers: + """Trace and report tensors selected by TTFlags.""" + + def report(self, name: NameTuple, tensor_data: torch.Tensor) -> None: + """Compress and send a traced tensor through the report callback.""" + compressor = get_compressor(name[1]) + valid, comp_args, compressed_tensor = compressor.compress(name[0], name[1], tensor_data) + if not valid: + logger.warning( + "Tensor tracer compressor %s returned invalid result for %s; skipping report.", + type(compressor).__name__, + name, + ) + return + get_report()(name, comp_args, compressed_tensor) + + +class TTFlags: + """Global flags to record the intermediate results of the model.""" + + def __init__(self, args): + self.num_layers = args.num_layers + self.flags: Dict[FlagType, Dict[int, bool]] = { + FlagType.INVALID_FLAG: {i: False for i in range(1, self.num_layers + 1)}, + 
FlagType.QKV_mat_mul: {i: False for i in range(1, self.num_layers + 1)}, + FlagType.ContextLayer_mat_mul: {i: False for i in range(1, self.num_layers + 1)}, + FlagType.MLP1_mat_mul: {i: False for i in range(1, self.num_layers + 1)}, + FlagType.MLP2_mat_mul: {i: False for i in range(1, self.num_layers + 1)}, + FlagType.AttentionOutput_mat_mul: {i: False for i in range(1, self.num_layers + 1)}, + FlagType.HiddenStates: {i: False for i in range(1, self.num_layers + 1)}, + FlagType.InputTokens: {0: True}, + } + self.should_trace = True + + def get_flag(self, flag_type: FlagType, layer_index: int) -> bool: + """Return whether a given flag is enabled for a layer.""" + return self.should_trace and self.flags.get(flag_type, {}).get(layer_index, False) + + def set_by_configs(self, configs: Dict[str, Any], comp_configs: Dict[str, Any]): + """Update tracing flags and compressor configurations from user configs.""" + global _GLOBAL_COMPRESSOR + if _GLOBAL_COMPRESSOR is None: + _set_compressor() + assert _GLOBAL_COMPRESSOR is not None + + for flag_type in self.flags: + if flag_type == FlagType.INVALID_FLAG: + continue + val = str(configs.get(flag_type.name, False)).lower() == "true" + for i in range(1, self.num_layers + 1): + self.flags[flag_type][i] = val + + specific_comp_config = comp_configs.get(flag_type.name) + if specific_comp_config is not None: + compressor_type = specific_comp_config.get("compressor_type", "EmptyCompressor") + compressor_configs = specific_comp_config.get("compressor_configs", {}) + compressor_cls = COMPRESSOR_MAP.get(compressor_type, EmptyCompressor) + _GLOBAL_COMPRESSOR[flag_type] = compressor_cls(compressor_configs) + + +class TTHookManager: + """Manage forward hooks that gather and report tensors for visualization.""" + + def __init__(self, args, model) -> None: + self.hooks = [] + # the type of model should be GPTModel + from megatron.core.models.gpt import GPTModel + + model = model[0].module.module + assert isinstance(model, GPTModel), 
f"{model}, {type(model)}" + + def generate_hook_transpose_col(flag_type: FlagType, layer_number: int): + def hook(module, input, output): + if get_tt_flags().get_flag(flag_type, layer_number): + device = torch.cuda.current_device() + world_size = get_tensor_model_parallel_world_size() + rank = get_tensor_model_parallel_rank() + rank0_global = torch.distributed.get_process_group_ranks( + get_tensor_model_parallel_group() + )[0] + + if isinstance(output, (list, tuple)): + tensor_data = output[0].detach() + else: + tensor_data = output.detach() + tensor_data = get_compressor(flag_type).compress_one_rank( + layer_number, flag_type, tensor_data + ) + tensor_data_cont = tensor_data.contiguous() + if rank == 0: + tensor_list = [ + torch.zeros_like( + tensor_data_cont, dtype=tensor_data_cont.dtype, device=device + ) + for _ in range(world_size) + ] + else: + tensor_list = None + if world_size > 1: + torch.distributed.gather( + tensor_data_cont, + tensor_list, + dst=rank0_global, + group=get_tensor_model_parallel_group(), + ) + else: + tensor_list = [tensor_data_cont] + + if rank == 0: + aggregated_tensor = None + + if flag_type == FlagType.QKV_mat_mul: + if world_size > 1: + tensor_list0, tensor_list1, tensor_list2 = [], [], [] + for id_rank in range(world_size): + chunks = torch.chunk(tensor_list[id_rank], 3, dim=2) + tensor_list0.append(chunks[0]) + tensor_list1.append(chunks[1]) + tensor_list2.append(chunks[2]) + tensor0 = torch.cat(tensor_list0, dim=2) + tensor1 = torch.cat(tensor_list1, dim=2) + tensor2 = torch.cat(tensor_list2, dim=2) + aggregated_tensor = torch.cat([tensor0, tensor1, tensor2], dim=2) + else: + aggregated_tensor = tensor_data_cont + else: + aggregated_tensor = torch.cat(tensor_list, dim=2) + + get_tensor_tracers().report( + (layer_number, flag_type), aggregated_tensor.transpose(0, 1) + ) + + return hook + + def generate_hook_transpose_row(flag_type: FlagType, layer_number: int): + def hook(module, input, output): + if 
get_tt_flags().get_flag(flag_type, layer_number): + device = torch.cuda.current_device() + world_size = get_tensor_model_parallel_world_size() + rank = get_tensor_model_parallel_rank() + rank0_global = torch.distributed.get_process_group_ranks( + get_tensor_model_parallel_group() + )[0] + + if args.sequence_parallel: + if isinstance(output, (list, tuple)): + tensor_data = output[0].detach() + else: + tensor_data = output.detach() + tensor_data = get_compressor(flag_type).compress_one_rank( + layer_number, flag_type, tensor_data + ) + tensor_data_cont = tensor_data.contiguous() + if rank == 0: + tensor_list = [ + torch.zeros_like( + tensor_data_cont, dtype=tensor_data_cont.dtype, device=device + ) + for _ in range(world_size) + ] + else: + tensor_list = None + if world_size > 1: + torch.distributed.gather( + tensor_data_cont, + tensor_list, + dst=rank0_global, + group=get_tensor_model_parallel_group(), + ) + else: + tensor_list = [tensor_data_cont] + + if rank == 0: + aggregated_tensor = torch.cat(tensor_list, dim=0) + get_tensor_tracers().report( + (layer_number, flag_type), aggregated_tensor.transpose(0, 1) + ) + else: + if rank == 0: + if isinstance(output, (list, tuple)): + tensor_data = output[0].detach() + else: + tensor_data = output.detach() + tensor_data = get_compressor(flag_type).compress_one_rank( + layer_number, flag_type, tensor_data + ) + get_tensor_tracers().report( + (layer_number, flag_type), tensor_data.transpose(0, 1) + ) + + return hook + + def generate_hook_attn(flag_type: FlagType, layer_number: int): + def hook(module, input, output): + if get_tt_flags().get_flag(flag_type, layer_number): + device = torch.cuda.current_device() + world_size = get_tensor_model_parallel_world_size() + rank = get_tensor_model_parallel_rank() + rank0_global = torch.distributed.get_process_group_ranks( + get_tensor_model_parallel_group() + )[0] + + if isinstance(output, (list, tuple)): + tensor_data = output[0].detach() + else: + tensor_data = output.detach() + 
tensor_data = get_compressor(flag_type).compress_one_rank( + layer_number, flag_type, tensor_data + ) + tensor_data_cont = tensor_data.contiguous() + if rank == 0: + tensor_list = [ + torch.zeros_like( + tensor_data_cont, dtype=tensor_data_cont.dtype, device=device + ) + for _ in range(world_size) + ] + else: + tensor_list = None + if world_size > 1: + torch.distributed.gather( + tensor_data_cont, + tensor_list, + dst=rank0_global, + group=get_tensor_model_parallel_group(), + ) + else: + tensor_list = [tensor_data_cont] + + if rank == 0: + aggregated_tensor = torch.cat(tensor_list, dim=1) + get_tensor_tracers().report((layer_number, flag_type), aggregated_tensor) + + return hook + + def generate_hook_input(flag_type: FlagType, layer_number: int): + def hook(module, args, kwargs, output): + if get_tt_flags().get_flag(flag_type, layer_number): + device = torch.cuda.current_device() + world_size = get_tensor_model_parallel_world_size() + rank = get_tensor_model_parallel_rank() + rank0_global = torch.distributed.get_process_group_ranks(get_tensor_model_parallel_group())[0] + + if rank == 0: + input_ids = kwargs["input_ids"] + position_ids = kwargs["position_ids"] + combined_input = torch.stack([input_ids, position_ids], dim=0) + tensor_data = get_compressor(flag_type).compress_one_rank(layer_number, flag_type, combined_input) + get_tensor_tracers().report((layer_number, flag_type), tensor_data) + return hook + + if hasattr(model, "embedding"): + self.hooks.append(model.embedding.register_forward_hook(generate_hook_input(FlagType.InputTokens, 0), with_kwargs=True)) # Row + + for layer in range(model.decoder.num_layers_per_pipeline_rank): + global_layer_number = model.decoder.layers[layer].layer_number + self.hooks.append( + model.decoder.layers[layer].self_attention.linear_qkv.register_forward_hook( + generate_hook_transpose_col(FlagType.QKV_mat_mul, global_layer_number) + ) + ) # Col, not gather_output + self.hooks.append( + 
model.decoder.layers[layer].mlp.linear_fc1.register_forward_hook( + generate_hook_transpose_col(FlagType.MLP1_mat_mul, global_layer_number) + ) + ) # Col, not gather_output + self.hooks.append( + model.decoder.layers[layer].mlp.linear_fc2.register_forward_hook( + generate_hook_transpose_row(FlagType.MLP2_mat_mul, global_layer_number) + ) + ) # Row + self.hooks.append( + model.decoder.layers[layer].self_attention.register_forward_hook( + generate_hook_transpose_row( + FlagType.AttentionOutput_mat_mul, global_layer_number + ) + ) + ) # Row + self.hooks.append( + model.decoder.layers[layer].register_forward_hook( + generate_hook_transpose_row(FlagType.HiddenStates, global_layer_number) + ) + ) # Row + self.hooks.append( + model.decoder.layers[layer].self_attention.core_attention.register_forward_hook( + generate_hook_transpose_col(FlagType.ContextLayer_mat_mul, global_layer_number) + ) + ) # Col, not gather_output + + +''' +For ColumnParallelLinear: +1. If gather_output, we do not do all gather +2. If not gather_output, we do all gather +For RowParallelLinear: +1. If sequence_parallel, we do all gather +2. If not sequence_parallel, we do not do all gather +''' diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index 1af066a8207..6d6e85c47ef 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -2129,6 +2129,8 @@ def _add_training_args(parser): help='The communicator group names to use high priority streams.') group.add_argument('--disable-jit-fuser', action='store_true', help='Disable the JIT fuser.') + group.add_argument('--tensor-tracer-port', type=int, default=None, + help='Port for the training visualization server. 
If set, training will be interactive and controlled by the frontend.') return parser diff --git a/megatron/training/global_vars.py b/megatron/training/global_vars.py index 76e8df7cee3..5a081a6915f 100644 --- a/megatron/training/global_vars.py +++ b/megatron/training/global_vars.py @@ -13,6 +13,7 @@ from megatron.core.num_microbatches_calculator import init_num_microbatches_calculator, unset_num_microbatches_calculator from megatron.training.dist_signal_handler import DistributedSignalHandler from megatron.training.tokenizer import build_tokenizer +from megatron.core.tensor_tracer import _set_tensor_tracers, _set_tt_flags, _set_compressor _GLOBAL_ARGS = None _GLOBAL_TOKENIZER = None @@ -116,6 +117,10 @@ def set_global_variables(args, build_tokenizer=True): if args.disable_jit_fuser: disable_jit_fuser() + if args.tensor_tracer_port is not None: + _set_tensor_tracers() + _set_tt_flags(args) + _set_compressor() def unset_global_variables(): """Unset global vars. diff --git a/megatron/training/training.py b/megatron/training/training.py index e9736ac085c..322b780f87a 100644 --- a/megatron/training/training.py +++ b/megatron/training/training.py @@ -42,6 +42,7 @@ def set_startup_timestamps(program_start=None, main_entry=None): import math import os import sys +import multiprocessing from contextlib import nullcontext from typing import Any, Optional, Dict @@ -904,6 +905,10 @@ def pretrain( model, optimizer, opt_param_scheduler = setup_model_and_optimizer( model_provider, model_type, checkpointing_context=checkpointing_context ) + if args.tensor_tracer_port is not None and mpu.get_data_parallel_rank() == 0: + from megatron.core.tensor_tracer import _set_tt_hook_manager + + _set_tt_hook_manager(args, model) timers('model-and-optimizer-setup').stop() print_datetime('after model, optimizer, and learning rate ' 'scheduler are built') @@ -2446,6 +2451,83 @@ def train( args = get_args() timers = get_timers() + global_rank = torch.distributed.get_rank() + tp_rank = 
mpu.get_tensor_model_parallel_rank() + dp_rank = mpu.get_data_parallel_rank() + if args.tensor_tracer_port is not None and tp_rank == 0 and dp_rank == 0: + try: + from .training_wsserver import websocket_server_process, websocket_worker_process + except ModuleNotFoundError as exc: + if exc.name == "websockets": + raise RuntimeError( + "Tensor tracer requires optional dependency 'websockets'. " + "Install it with `pip install -e '.[tensor_tracer]'` " + "(or `pip install websockets`)." + ) from exc + raise + + data_queue = multiprocessing.Queue() + shutdown_event = multiprocessing.Event() + + from megatron.core.tensor_tracer import FlagType, set_report + def report_func(name_tuple, report_args, tensor_data): + # name_tuple is (layer_id, FlagType) + # report_args are specific to the FlagType (e.g., [n,m] for attention) + # tensor_data is the actual data (list or tensor that can be .tolist()) + try: + if name_tuple[1] == FlagType.INVALID_FLAG: + return + torch.cuda.synchronize() + data_queue.put_nowait((name_tuple, report_args, tensor_data.to('cpu', non_blocking=True))) + except Exception as e: + pass + set_report(report_func) + + if global_rank == 0: + config_queue = multiprocessing.Queue(maxsize=1) + start_training_event = multiprocessing.Event() + + training_args_dict = { + "micro_batch_size": args.micro_batch_size, + "seq_length": args.seq_length, + "num_layers": args.num_layers + } + + ws_process = multiprocessing.Process( + target=websocket_server_process, + args=(args.tensor_tracer_port, data_queue, config_queue, start_training_event, shutdown_event, training_args_dict), + daemon=True + ) + ws_process.start() + start_training_event.wait() + else: + master_addr = os.environ.get("MASTER_ADDR", "127.0.0.1") + + ws_process = multiprocessing.Process( + target=websocket_worker_process, + args=(master_addr, args.tensor_tracer_port, global_rank, data_queue, shutdown_event), + daemon=True + ) + ws_process.start() + + if args.tensor_tracer_port is not None: + if 
global_rank == 0 and tp_rank == 0 and dp_rank == 0: + received_configs = config_queue.get() + vis_flags = received_configs.get('visualization_flags', {}) + comp_configs = received_configs.get('compressor_config', {}) + configs_to_broadcast = [vis_flags, comp_configs] + else: + configs_to_broadcast = [None, None] + + torch.distributed.broadcast_object_list(configs_to_broadcast, src=0) + + vis_flags, comp_configs = configs_to_broadcast + if dp_rank != 0: + vis_flags = {"InputTokens": False} + comp_configs = {} + from megatron.core.tensor_tracer import get_tt_flags + get_tt_flags().set_by_configs(vis_flags, comp_configs) + if getattr(args, 'perform_rl_step', False): assert has_rl_utils, "RL cannot run without the megatron.rl package" @@ -3015,6 +3097,19 @@ def get_e2e_base_metrics(): if should_exit: break + if args.tensor_tracer_port is not None and tp_rank == 0: + print_rank_0("Signaling WebSocket process to shut down...") + shutdown_event.set() + data_queue.close() + data_queue.cancel_join_thread() + if global_rank == 0: + config_queue.close() + config_queue.cancel_join_thread() + ws_process.join(timeout=5) + if ws_process.is_alive(): + print_rank_0("WebSocket process did not shut down cleanly, terminating.") + ws_process.terminate() + # Destroy CUDA Graphs. if args.cuda_graph_impl == "transformer_engine" and cuda_graph_helper.graphs_created(): cuda_graph_helper.delete_cuda_graphs() diff --git a/megatron/training/training_wsserver.py b/megatron/training/training_wsserver.py new file mode 100644 index 00000000000..953da7d48e0 --- /dev/null +++ b/megatron/training/training_wsserver.py @@ -0,0 +1,157 @@ +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+
+import json
+import multiprocessing
+import threading
+import queue
+import time
+from websockets.sync.server import serve, Server
+from websockets.sync.client import connect
+from websockets.exceptions import ConnectionClosed
+
+_frontend_connection = None
+_connection_lock = threading.Lock()
+
+_request_configs = {}
+
+def get_frontend_connection():
+    with _connection_lock:
+        return _frontend_connection
+
+def websocket_worker_process(master_addr: str, port: int, rank: int, data_queue: multiprocessing.Queue, shutdown_event: multiprocessing.Event):
+    uri = f"ws://{master_addr}:{port}"
+    print(f"Rank {rank} (Worker): Connecting to Hub at {uri}...", flush=True)
+
+    while not shutdown_event.is_set():
+        try:
+            with connect(uri, max_size=None) as websocket:
+                print(f"Rank {rank} (Worker): Connected.", flush=True)
+
+                while not shutdown_event.is_set():
+                    try:
+                        name_tuple, report_args, tensor_data = data_queue.get(timeout=1.0)
+
+                        payload = {
+                            "type": "worker_forward",
+                            "data": {
+                                "type": "update",
+                                "update_type": name_tuple[1].value,
+                                "layer_id": name_tuple[0],
+                                "args": report_args,
+                                "result": tensor_data.tolist()
+                            }
+                        }
+
+                        websocket.send(json.dumps(payload))
+                    except queue.Empty:
+                        continue
+        except (ConnectionRefusedError, OSError):
+            time.sleep(2)
+        except Exception as e:
+            print(f"Rank {rank} (Worker): Error: {e}, retrying...", flush=True)
+            time.sleep(2)
+
+def websocket_server_process(port: int, data_queue: multiprocessing.Queue, config_queue: multiprocessing.Queue, start_event: multiprocessing.Event, shutdown_event: multiprocessing.Event, training_args: dict):
+    global _frontend_connection
+
+    def _local_data_sender():
+        print("Rank 0 (Server): Local data sender started.", flush=True)
+        while not shutdown_event.is_set():
+            try:
+                name_tuple, report_args, tensor_data = data_queue.get(timeout=1.0)
+
+                payload = {
+                    "type": "update",
+                    "update_type": name_tuple[1].value,
+                    "layer_id": name_tuple[0],
+                    "args": report_args,
+                    "result": tensor_data.tolist()
+                }
+
+                ws = get_frontend_connection()
+                if ws:
+                    try:
+                        ws.send(json.dumps(payload))
+                    except ConnectionClosed:
+                        pass
+                    except Exception as e:
+                        print(f"Rank 0 (Server): Error sending data: {e}", flush=True)
+
+            except queue.Empty:
+                continue
+            except Exception as e:
+                print(f"Rank 0 (Server): Unexpected error in sender: {e}", flush=True)
+
+    def _websocket_handler(websocket):
+        global _frontend_connection
+        is_frontend = False
+
+        try:
+            for message in websocket:
+                try:
+                    msg_obj = json.loads(message)
+                    msg_type = msg_obj.get("type")
+                    if msg_type == "worker_forward":
+                        forward_data = msg_obj.get("data")
+                        frontend = get_frontend_connection()
+                        if frontend:
+                            frontend.send(json.dumps(forward_data))
+                    elif msg_type == "run_training_step":
+                        print("Rank 0 (Server): Frontend connected and assumed control.", flush=True)
+                        with _connection_lock:
+                            _frontend_connection = websocket
+                        is_frontend = True
+
+                        _request_configs['visualization_flags'] = msg_obj.get("visualization_flags", {})
+                        _request_configs['compressor_config'] = msg_obj.get("compressor_config", {})
+                        config_queue.put(_request_configs)
+
+                        start_payload = {
+                            "type": "start",
+                            "micro_batch_size": training_args.get("micro_batch_size"),
+                            "seq_length": training_args.get("seq_length"),
+                            "num_layers": training_args.get("num_layers")
+                        }
+                        websocket.send(json.dumps(start_payload))
+                        start_event.set()
+                except Exception as e:
+                    print(f"Rank 0 (Server): Error processing message: {e}", flush=True)
+        except ConnectionClosed:
+            print("Rank 0 (Server): Connection handler closed.", flush=True)
+        finally:
+            if is_frontend:
+                print("Rank 0 (Server): Frontend disconnected.", flush=True)
+                with _connection_lock:
+                    _frontend_connection = None
+                start_event.clear()
+
+    sender_thread = threading.Thread(target=_local_data_sender, daemon=True)
+    sender_thread.start()
+
+    server: Server = None
+
+    def shutdown_handler():
+        shutdown_event.wait()
+        print("Rank 0 (Server): Shutdown event received, stopping server...", flush=True)
+        if server:
+            server.shutdown()
+
+    shutdown_thread = threading.Thread(target=shutdown_handler, daemon=True)
+    shutdown_thread.start()
+
+    print(f"Rank 0 (Server): Starting server on ws://0.0.0.0:{port}", flush=True)
+
+    try:
+        with serve(
+            _websocket_handler, "0.0.0.0", port,
+            ping_interval=None, reuse_port=True,
+            max_size=None,
+        ) as server_instance:
+            server = server_instance
+            server.serve_forever()
+    except Exception as e:
+        print(f"Rank 0 (Server): Server crashed with an error: {e}", flush=True)
+    finally:
+        sender_thread.join(timeout=1.0)
+        config_queue.close()
+        print("Rank 0 (Server): Server has shut down.", flush=True)
diff --git a/pyproject.toml b/pyproject.toml
index 567954ca4a1..797e05f1a2c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -66,6 +66,8 @@ Homepage = "https://github.com/NVIDIA/Megatron-LM/megatron/core"
 
 [project.optional-dependencies]
 mlm = ["flask-restful", "sentencepiece", "tiktoken", "wandb", "transformers"]
+tensor_tracer = ["websockets>=10.0"]
+
 dev = [
     "nvidia-modelopt[torch]; sys_platform != 'darwin'",
     "transformer-engine[pytorch,core_cu13]>=2.9.0a0,<2.12.0",
diff --git a/tests/unit_tests/test_tensor_tracer.py b/tests/unit_tests/test_tensor_tracer.py
new file mode 100644
index 00000000000..ea29cc2a233
--- /dev/null
+++ b/tests/unit_tests/test_tensor_tracer.py
@@ -0,0 +1,75 @@
+import importlib
+from types import SimpleNamespace
+
+import torch
+
+
+def test_default_compressors_exist_for_all_flag_types() -> None:
+    import megatron.core.tensor_tracer as tt
+
+    importlib.reload(tt)
+    tt._set_compressor()
+
+    for flag_type in tt.FlagType:
+        compressor = tt.get_compressor(flag_type)
+        assert isinstance(compressor, tt.AbstractCompressor)
+
+
+def test_ttflags_set_by_configs_sets_flags_and_compressors() -> None:
+    import megatron.core.tensor_tracer as tt
+
+    importlib.reload(tt)
+    tt._set_compressor()
+
+    args = SimpleNamespace(num_layers=2)
+    flags = tt.TTFlags(args)
+    flags.set_by_configs(
+        {"QKV_mat_mul": "true", "MLP1_mat_mul": "false"},
+        {"QKV_mat_mul": {"compressor_type": "NoOpCompressor", "compressor_configs": {}}},
+    )
+
+    assert flags.get_flag(tt.FlagType.QKV_mat_mul, 1) is True
+    assert flags.get_flag(tt.FlagType.QKV_mat_mul, 2) is True
+    assert flags.get_flag(tt.FlagType.MLP1_mat_mul, 1) is False
+
+    assert isinstance(tt.get_compressor(tt.FlagType.QKV_mat_mul), tt.NoOpCompressor)
+    empty_compressor = tt.get_compressor(tt.FlagType.MLP1_mat_mul)
+    assert isinstance(empty_compressor, tt.EmptyCompressor)
+
+    sample = torch.zeros(2, 3, 4)
+    empty_sample = empty_compressor.compress_one_rank(1, tt.FlagType.MLP1_mat_mul, sample)
+    assert empty_sample.shape == (2, 3, 0)
+    assert empty_sample.device == sample.device
+
+
+def test_tile_compressor_compress_shapes() -> None:
+    import megatron.core.tensor_tracer as tt
+
+    compressor = tt.TileCompressor({"tiles": 4, "tiles_one_rank": 4})
+    data = torch.ones(2, 3, 10)
+    valid, shape, payload = compressor.compress(1, tt.FlagType.MLP1_mat_mul, data)
+
+    assert valid is True
+    assert shape == [2, 3, 4]
+    assert payload.numel() == 2 * 3 * 4
+
+
+def test_tensor_tracers_skips_invalid_compressor_result() -> None:
+    import megatron.core.tensor_tracer as tt
+
+    importlib.reload(tt)
+
+    class BadCompressor(tt.AbstractCompressor):
+        def compress_one_rank(self, layer_number, flag_type, data):
+            return data
+
+        def compress(self, layer_number, flag_type, data):
+            return False, [], torch.tensor([])
+
+    tt._GLOBAL_COMPRESSOR = {flag_type: BadCompressor() for flag_type in tt.FlagType}
+    called = []
+    tt.set_report(lambda name, args, tensor: called.append(name))
+
+    tracer = tt.TensorTracers()
+    tracer.report((1, tt.FlagType.QKV_mat_mul), torch.zeros(1, 1, 1))
+    assert called == []
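
For context on the wire format handled above: the hub forwards `update` messages whose `result` field is a flattened numeric payload and whose `args` field carries the compressed shape, so a frontend has to rebuild the nested structure itself. The sketch below is a minimal, stdlib-only illustration of that decoding step; `decode_update` and its recursive reshape helper are hypothetical names chosen for this example, not part of this patch.

```python
import json

def decode_update(message: str):
    """Decode one tracer 'update' message, reshaping the flattened
    'result' list into the nested shape given by 'args'."""
    msg = json.loads(message)
    assert msg["type"] == "update"
    shape = msg["args"]    # compressed shape, e.g. [2, 3, 96]
    flat = msg["result"]   # flattened numeric payload

    def reshape(values, dims):
        # Recursively split the flat list along the leading dimension.
        if len(dims) == 1:
            return values
        step = len(values) // dims[0]
        return [reshape(values[i * step:(i + 1) * step], dims[1:])
                for i in range(dims[0])]

    return msg["layer_id"], msg["update_type"], reshape(flat, shape)

# Example: a 1x1x3 QKV_mat_mul update (update_type 1) for layer 12.
raw = json.dumps({
    "type": "update", "update_type": 1, "layer_id": 12,
    "args": [1, 1, 3], "result": [0.1, 0.2, 0.3],
})
layer_id, update_type, nested = decode_update(raw)
```

The same decoding applies to payloads arriving via `worker_forward`, since the hub forwards the inner `data` object unchanged.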