From 0ed4ad1670e6bdb9c96db75ff23a7b746057d487 Mon Sep 17 00:00:00 2001 From: superay <1420782034@qq.com> Date: Sat, 31 Jan 2026 10:38:28 +0000 Subject: [PATCH 01/10] feat: add tensor tracer functionality and websocket server for training visualization --- megatron/core/tensor_tracer.py | 321 +++++++++++++++++++++++++ megatron/training/arguments.py | 2 + megatron/training/global_vars.py | 5 + megatron/training/training.py | 72 ++++++ megatron/training/training_wsserver.py | 133 ++++++++++ pyproject.toml | 2 +- 6 files changed, 534 insertions(+), 1 deletion(-) create mode 100644 megatron/core/tensor_tracer.py create mode 100644 megatron/training/training_wsserver.py diff --git a/megatron/core/tensor_tracer.py b/megatron/core/tensor_tracer.py new file mode 100644 index 00000000000..e9f24926714 --- /dev/null +++ b/megatron/core/tensor_tracer.py @@ -0,0 +1,321 @@ +from abc import abstractmethod +import torch +import math +from enum import Enum +from typing import Dict, Any +from megatron.core.parallel_state import get_tensor_model_parallel_group, get_tensor_model_parallel_world_size, get_tensor_model_parallel_rank + +_GLOBAL_TT_FLAGS = None +_GLOBAL_TENSOR_TRACERS = None +_GLOBAL_REPORT = lambda name, args, tensor: None +_GLOBAL_COMPRESSOR = None +_GLOBAL_HOOK_MANAGER = None + +def _set_tensor_tracers(): + global _GLOBAL_TENSOR_TRACERS + _GLOBAL_TENSOR_TRACERS = TensorTracers() + +def _set_tt_flags(args): + global _GLOBAL_TT_FLAGS + _GLOBAL_TT_FLAGS = TTFlags(args) + +def _set_tt_hook_manager(args, model): + global _GLOBAL_HOOK_MANAGER + _GLOBAL_HOOK_MANAGER = TTHookManager(args, model) + +def _set_compressor(): + global _GLOBAL_COMPRESSOR + _GLOBAL_COMPRESSOR=DefaultCompressor() + +def set_report(func): + global _GLOBAL_REPORT + _GLOBAL_REPORT = func + +def unset_report(): + global _GLOBAL_REPORT + _GLOBAL_REPORT = lambda name, args, tensor: None + +def get_tensor_tracers(): + return _GLOBAL_TENSOR_TRACERS + +def get_tt_flags(): + return _GLOBAL_TT_FLAGS + +def get_compressor(): + return _GLOBAL_COMPRESSOR + +def get_report(): + return _GLOBAL_REPORT + +class FlagType(Enum): + INVALID_FLAG = 0 + QKV_mat_mul = 1 + RawAttentionScore_mat_mul = 2 + ContextLayer_mat_mul = 3 + MLP1_mat_mul = 4 + MLP2_mat_mul = 5 + AttentionOutput_mat_mul = 6 + HiddenStates = 7 + +class AbstractCompressor: + def __init__(self): + pass + @abstractmethod + def set_by_configs(self, configs: Dict[str, Any]): + pass + @abstractmethod + def compress_one_rank(self, name, data): + pass + @abstractmethod + def compress(self, name, data): + pass + +class DefaultCompressor(AbstractCompressor): + def __init__(self): + self.configs = { + "QKV": { + "pixels": 96, + "method": "data.mean(dim=-1)" + }, + "MLP": { + "pixels": 64, + "method": "data.mean(dim=-1)" + } + } + + def set_by_configs(self, configs: Dict[str, Any]): + self.configs = configs + + def compress_tensor(self, data_in, pixels, method): + B, S, F = data_in.shape + chunk_size = math.ceil(F / pixels) + padded_len = chunk_size * pixels + padded_data = torch.nn.functional.pad(data_in, (0, padded_len - F)) + data_for_eval = padded_data.reshape(B, S, pixels, chunk_size) + try: + compressed = eval(method, {}, {"data": data_for_eval}) + except Exception as e: + print(f"Error in compressing tensor with method '{method}': {e}") + compressed = data_for_eval.mean(dim=-1) + return compressed + + def compress_1d_tensor(self, data_in, pixels, method): + B, S, F = data_in.shape + chunk_size = math.ceil(F / pixels) + padded_len = chunk_size * pixels + padded_data = torch.nn.functional.pad(data_in, (0, padded_len - F)) + data_for_eval = padded_data.reshape(B, S, pixels, chunk_size) + try: + compressed = eval(method, {}, {"data": data_for_eval}).flatten() + except Exception as e: + print(f"Error in compressing tensor with method '{method}': {e}") + compressed = data_for_eval.mean(dim=-1).flatten() # Fallback to mean if eval fails + return compressed + + def compress_one_rank(self, flag_type, data): + if flag_type == FlagType.QKV_mat_mul: + return self.compress_tensor(data, self.configs["QKV"]["pixels"], self.configs["QKV"]["method"]) + elif flag_type == FlagType.MLP1_mat_mul or flag_type == FlagType.MLP2_mat_mul or flag_type == FlagType.ContextLayer_mat_mul: + return self.compress_tensor(data, self.configs["MLP"]["pixels"], self.configs["MLP"]["method"]) + return data + + def compress(self, name, data): + flag_type = name[1] + if flag_type == FlagType.QKV_mat_mul: + n = data.shape[1]; return True, [n], self.compress_1d_tensor(data, self.configs["QKV"]["pixels"], self.configs["QKV"]["method"]) + elif flag_type == FlagType.RawAttentionScore_mat_mul: + np, n, m = data.shape[1], data.shape[2], data.shape[3]; return True, [np, n, m], data[:, :, :, :].flatten() + elif flag_type == FlagType.MLP1_mat_mul or flag_type == FlagType.MLP2_mat_mul or flag_type == FlagType.ContextLayer_mat_mul: + n = data.shape[1]; return True, [n], self.compress_1d_tensor(data, self.configs["MLP"]["pixels"], self.configs["MLP"]["method"]) + return False, [], torch.tensor([]) + +class ProjectionCompressor(AbstractCompressor): + def __init__(self): + pass + + def set_by_configs(self, configs: Dict[str, Any]): + pass + + def compress_one_rank(self, name, data): + return data + + def compress(self, name, data): + return False, [], torch.tensor([]) + +class TensorTracers: # simplified as TT + def __init__(self) -> None: pass + + def report(self, name, tensor_data): + valid, comp_args, compressed_tensor = get_compressor().compress(name, tensor_data) + assert valid + get_report()(name, comp_args, compressed_tensor) + +class TTFlags: + """Global flags to record the intermediate results of the model.""" + + def __init__(self, args): + self.num_layers = args.num_layers + self.flags: Dict[FlagType, Dict[int, bool]] = { + FlagType.INVALID_FLAG: {i: False for i in range(1, self.num_layers + 1)}, + FlagType.QKV_mat_mul: {i: False for i in range(1, self.num_layers + 1)}, + FlagType.RawAttentionScore_mat_mul: {i: False for i in range(1, self.num_layers + 1)}, + FlagType.ContextLayer_mat_mul: {i: False for i in range(1, self.num_layers + 1)}, + FlagType.MLP1_mat_mul: {i: False for i in range(1, self.num_layers + 1)}, + FlagType.MLP2_mat_mul: {i: False for i in range(1, self.num_layers + 1)}, + FlagType.AttentionOutput_mat_mul: {i: False for i in range(1, self.num_layers + 1)}, + FlagType.HiddenStates: {i: False for i in range(1, self.num_layers + 1)}, + } + self.should_trace = True + + def get_flag(self, flag_type: FlagType, layer_index: int) -> bool: + return self.should_trace and self.flags.get(flag_type, {}).get(layer_index, False) + + def set_by_configs(self, configs: Dict[str, Any]): + val = True if configs.get("QKV_mat_mul", "False").lower() == "true" else False + for i in range(1, self.num_layers + 1): + self.flags[FlagType.QKV_mat_mul][i] = val + + val = True if configs.get("RawAttentionScore_mat_mul", "False").lower() == "true" else False + for i in range(1, self.num_layers + 1): + self.flags[FlagType.RawAttentionScore_mat_mul][i] = val + + val = True if configs.get("ContextLayer_mat_mul", "False").lower() == "true" else False + for i in range(1, self.num_layers + 1): + self.flags[FlagType.ContextLayer_mat_mul][i] = val + + val = True if configs.get("MLP1_mat_mul", "True").lower() == "true" else False + for i in range(1, self.num_layers + 1): + self.flags[FlagType.MLP1_mat_mul][i] = val + + val = True if configs.get("MLP2_mat_mul", "True").lower() == "true" else False + for i in range(1, self.num_layers + 1): + self.flags[FlagType.MLP2_mat_mul][i] = val + + val = True if configs.get("AttentionOutput_mat_mul", "False").lower() == "true" else False + for i in range(1, self.num_layers + 1): + self.flags[FlagType.AttentionOutput_mat_mul][i] = val + + val = True if configs.get("HiddenStates", "False").lower() == "true" else False + for i in range(1, self.num_layers + 1): + self.flags[FlagType.HiddenStates][i] = val + +class TTHookManager: + def __init__(self, args, model) -> None: + self.hooks = [] + # the type of model should be GPTModel + from megatron.core.models.gpt import GPTModel + model = model[0].module.module + assert isinstance(model, GPTModel), f"{model}, {type(model)}" + def generate_hook_transpose_col(flag_type: FlagType, layer_number: int): + def hook(module, input, output): + if get_tt_flags().get_flag(flag_type, layer_number): + device = torch.cuda.current_device() + world_size = get_tensor_model_parallel_world_size() + rank = get_tensor_model_parallel_rank() + + tensor_data = output[0].detach() + tensor_data = get_compressor().compress_one_rank(flag_type, tensor_data) + tensor_data_cont = tensor_data.contiguous() + if rank == 0: + tensor_list = [torch.zeros_like(tensor_data_cont, dtype=tensor_data_cont.dtype, device=device) for _ in range(world_size)] + else: + tensor_list = None + if world_size > 1: + torch.distributed.gather(tensor_data_cont, tensor_list, dst=0, group=get_tensor_model_parallel_group()) + else: + tensor_list = [tensor_data_cont] + + if rank == 0: + aggregated_tensor = None + + if flag_type == FlagType.QKV_mat_mul: + if world_size > 1: + tensor_list0, tensor_list1, tensor_list2 = [], [], [] + for id_rank in range(world_size): + chunks = torch.chunk(tensor_list[id_rank], 3, dim=2) + tensor_list0.append(chunks[0]) + tensor_list1.append(chunks[1]) + tensor_list2.append(chunks[2]) + tensor0 = torch.cat(tensor_list0, dim=2) + tensor1 = torch.cat(tensor_list1, dim=2) + tensor2 = torch.cat(tensor_list2, dim=2) + aggregated_tensor = torch.cat([tensor0, tensor1, tensor2], dim=2) + else: + aggregated_tensor = tensor_data_cont + else: + aggregated_tensor = torch.cat(tensor_list, dim=2) + + get_tensor_tracers().report((layer_number, flag_type), aggregated_tensor.transpose(0, 1)) + return hook + + def generate_hook_transpose_row(flag_type: FlagType, layer_number: int): + def hook(module, input, output): + if get_tt_flags().get_flag(flag_type, layer_number): + device = torch.cuda.current_device() + world_size = get_tensor_model_parallel_world_size() + rank = get_tensor_model_parallel_rank() + + if args.sequence_parallel: + tensor_data = output[0].detach() + tensor_data = get_compressor().compress_one_rank(flag_type, tensor_data) + tensor_data_cont = tensor_data.contiguous() + if rank == 0: + tensor_list = [torch.zeros_like(tensor_data_cont, dtype=tensor_data_cont.dtype, device=device) for _ in range(world_size)] + else: + tensor_list = None + if world_size > 1: + torch.distributed.gather(tensor_data_cont, tensor_list, dst=0, group=get_tensor_model_parallel_group()) + else: + tensor_list = [tensor_data_cont] + + if rank == 0: + aggregated_tensor = torch.cat(tensor_list, dim=0) + get_tensor_tracers().report((layer_number, flag_type), aggregated_tensor.transpose(0, 1)) + else: + if rank == 0: + tensor_data = output[0].detach() + tensor_data = get_compressor().compress_one_rank(flag_type, tensor_data) + get_tensor_tracers().report((layer_number, flag_type), tensor_data.transpose(0, 1)) + return hook + + def generate_hook_attn(flag_type: FlagType, layer_number: int): + def hook(module, input, output): + if get_tt_flags().get_flag(flag_type, layer_number): + device = torch.cuda.current_device() + world_size = get_tensor_model_parallel_world_size() + rank = get_tensor_model_parallel_rank() + + tensor_data = output.detach() + tensor_data = get_compressor().compress_one_rank(flag_type, tensor_data) + tensor_data_cont = tensor_data.contiguous() + if rank == 0: + tensor_list = [torch.zeros_like(tensor_data_cont, dtype=tensor_data_cont.dtype, device=device) for _ in range(world_size)] + else: + tensor_list = None + if world_size > 1: + torch.distributed.gather(tensor_data_cont, tensor_list, dst=0, group=get_tensor_model_parallel_group()) + else: + tensor_list = [tensor_data_cont] + + if rank == 0: + aggregated_tensor = torch.cat(tensor_list, dim=1) + get_tensor_tracers().report((layer_number, flag_type), aggregated_tensor) + return hook + for layer in range(model.decoder.num_layers_per_pipeline_rank): + global_layer_number = model.decoder.layers[layer].layer_number + self.hooks.append(model.decoder.layers[layer].self_attention.linear_qkv.register_forward_hook(generate_hook_transpose_col(FlagType.QKV_mat_mul, global_layer_number))) # Col, not gather_output + self.hooks.append(model.decoder.layers[layer].mlp.linear_fc1.register_forward_hook(generate_hook_transpose_col(FlagType.MLP1_mat_mul, global_layer_number))) # Col, not gather_output + self.hooks.append(model.decoder.layers[layer].mlp.linear_fc2.register_forward_hook(generate_hook_transpose_row(FlagType.MLP2_mat_mul, global_layer_number))) # Row + self.hooks.append(model.decoder.layers[layer].self_attention.register_forward_hook(generate_hook_transpose_row(FlagType.AttentionOutput_mat_mul, global_layer_number))) # Row + self.hooks.append(model.decoder.layers[layer].register_forward_hook(generate_hook_transpose_row(FlagType.HiddenStates, global_layer_number))) # Row + self.hooks.append(model.decoder.layers[layer].self_attention.core_attention.scale_mask_softmax.register_forward_hook(generate_hook_attn(FlagType.RawAttentionScore_mat_mul, global_layer_number))) # Raw Attention Scores, Special + self.hooks.append(model.decoder.layers[layer].self_attention.core_attention.register_forward_hook(generate_hook_transpose_col(FlagType.ContextLayer_mat_mul, global_layer_number))) # Col, not gather_output + +''' +For ColumnParallelLinear: +1. If gather_output, we do not do all gather +2. If not gather_output, we do all gather +For RowParallelLinear: +1. If sequence_parallel, we do all gather +2. If not sequence_parallel, we do not do all gather +''' \ No newline at end of file diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index 1af066a8207..6d6e85c47ef 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -2129,6 +2129,8 @@ def _add_training_args(parser): help='The communicator group names to use high priority streams.') group.add_argument('--disable-jit-fuser', action='store_true', help='Disable the JIT fuser.') + group.add_argument('--tensor-tracer-port', type=int, default=None, + help='Port for the training visualization server. If set, training will be interactive and controlled by the frontend.') return parser diff --git a/megatron/training/global_vars.py b/megatron/training/global_vars.py index 76e8df7cee3..5a081a6915f 100644 --- a/megatron/training/global_vars.py +++ b/megatron/training/global_vars.py @@ -13,6 +13,7 @@ from megatron.core.num_microbatches_calculator import init_num_microbatches_calculator, unset_num_microbatches_calculator from megatron.training.dist_signal_handler import DistributedSignalHandler from megatron.training.tokenizer import build_tokenizer +from megatron.core.tensor_tracer import _set_tensor_tracers, _set_tt_flags, _set_compressor _GLOBAL_ARGS = None _GLOBAL_TOKENIZER = None @@ -116,6 +117,10 @@ def set_global_variables(args, build_tokenizer=True): if args.disable_jit_fuser: disable_jit_fuser() + if args.tensor_tracer_port is not None: + _set_tensor_tracers() + _set_tt_flags(args) + _set_compressor() def unset_global_variables(): """Unset global vars. diff --git a/megatron/training/training.py b/megatron/training/training.py index e9736ac085c..c769fec95c4 100644 --- a/megatron/training/training.py +++ b/megatron/training/training.py @@ -42,6 +42,7 @@ def set_startup_timestamps(program_start=None, main_entry=None): import math import os import sys +import multiprocessing from contextlib import nullcontext from typing import Any, Optional, Dict @@ -50,6 +51,9 @@ def set_startup_timestamps(program_start=None, main_entry=None): from megatron.core.optimizer.distrib_optimizer import DistributedOptimizer from .log_handler import CustomHandler +from megatron.core.tensor_tracer import _set_tt_hook_manager +from .training_wsserver import websocket_server_process + # Make default logging level INFO, but filter out all log messages not from MCore. logging.basicConfig(handlers=[CustomHandler()], level=logging.INFO) from .theoretical_memory_usage import report_theoretical_memory @@ -904,6 +908,8 @@ def pretrain( model, optimizer, opt_param_scheduler = setup_model_and_optimizer( model_provider, model_type, checkpointing_context=checkpointing_context ) + if args.tensor_tracer_port is not None: + _set_tt_hook_manager(args, model) timers('model-and-optimizer-setup').stop() print_datetime('after model, optimizer, and learning rate ' 'scheduler are built') @@ -2446,6 +2452,60 @@ def train( args = get_args() timers = get_timers() + if args.tensor_tracer_port is not None and torch.distributed.get_rank() == 0: + data_queue = multiprocessing.Queue() + config_queue = multiprocessing.Queue(maxsize=1) + start_training_event = multiprocessing.Event() + shutdown_event = multiprocessing.Event() + training_args_dict = { + "micro_batch_size": args.micro_batch_size, + "seq_length": args.seq_length, + "num_layers": args.num_layers + } + ws_process = multiprocessing.Process( + target=websocket_server_process, + args=(args.tensor_tracer_port, data_queue, config_queue, start_training_event, shutdown_event, training_args_dict), + daemon=True + ) + ws_process.start() + from megatron.core.tensor_tracer import FlagType, set_report + def report_func(name_tuple, report_args, tensor_data): + # name_tuple is (layer_id, FlagType) + # report_args are specific to the FlagType (e.g., [n,m] for attention) + # tensor_data is the actual data (list or tensor that can be .tolist()) + try: + if name_tuple[1] == FlagType.INVALID_FLAG: + return + torch.cuda.synchronize() + data_queue.put_nowait((name_tuple, report_args, tensor_data.to('cpu', non_blocking=True))) + except Exception as e: + pass + set_report(report_func) + + if args.tensor_tracer_port is not None: + if torch.distributed.get_rank() == 0: + print_rank_0("Waiting for 'run_training_step' command from frontend to start training...") + start_training_event.wait() + print_rank_0("Command received. Synchronizing configs across all ranks...") + + if torch.distributed.get_rank() == 0: + received_configs = config_queue.get() + vis_flags = received_configs.get('visualization_flags', {}) + dist_configs = received_configs.get('disturbance_configs', {}) + comp_configs = received_configs.get('compressor_config', {}) + configs_to_broadcast = [vis_flags, dist_configs, comp_configs] + else: + configs_to_broadcast = [None, None, None] + + torch.distributed.broadcast_object_list(configs_to_broadcast, src=0) + + vis_flags, dist_configs, comp_configs = configs_to_broadcast + from megatron.core.tensor_tracer import get_tt_flags, get_compressor + get_tt_flags().set_by_configs(vis_flags) + get_compressor().set_by_configs(comp_configs) + + print_rank_0("Configs synchronized. Starting training.") + if getattr(args, 'perform_rl_step', False): assert has_rl_utils, "RL cannot run without the megatron.rl package" @@ -3015,6 +3075,18 @@ def get_e2e_base_metrics(): if should_exit: break + if ws_process: + print_rank_0("Signaling WebSocket process to shut down...") + shutdown_event.set() + data_queue.close() + config_queue.close() + data_queue.cancel_join_thread() + config_queue.cancel_join_thread() + ws_process.join(timeout=5) + if ws_process.is_alive(): + print_rank_0("WebSocket process did not shut down cleanly, terminating.") + ws_process.terminate() + # Destroy CUDA Graphs. if args.cuda_graph_impl == "transformer_engine" and cuda_graph_helper.graphs_created(): cuda_graph_helper.delete_cuda_graphs() diff --git a/megatron/training/training_wsserver.py b/megatron/training/training_wsserver.py new file mode 100644 index 00000000000..3f4901d4241 --- /dev/null +++ b/megatron/training/training_wsserver.py @@ -0,0 +1,133 @@ +# Copyright 2025 Suanzhi Future Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions +# and limitations under the License. + +import json +import multiprocessing +import threading +import queue +import time +from websockets.sync.server import serve, Server +from websockets.exceptions import ConnectionClosed + +_websocket_connection = None +_websocket_lock = threading.Lock() + +_request_configs = {} + +def get_websocket(): + with _websocket_lock: + return _websocket_connection + +def websocket_server_process(port: int, data_queue: multiprocessing.Queue, config_queue: multiprocessing.Queue, start_event: multiprocessing.Event, shutdown_event: multiprocessing.Event, training_args: dict): + global _websocket_connection + + def _data_sender_thread_inner(): + print("Rank 0 (WS Process): Data sender thread started.", flush=True) + while not shutdown_event.is_set(): + try: + name_tuple, report_args, tensor_data = data_queue.get(timeout=1.0) + + payload = { + "type": "update", + "update_type": name_tuple[1].value, + "layer_id": name_tuple[0], + "args": report_args, + "result": tensor_data.tolist() + } + + ws = get_websocket() + if ws: + try: + ws.send(json.dumps(payload)) + except ConnectionClosed: + pass + except Exception as e: + print(f"Rank 0 (WS Process): Error sending data: {e}", flush=True) + + except queue.Empty: + continue + except (BrokenPipeError, EOFError): + print("Rank 0 (WS Process): Data queue connection broken, sender thread exiting.", flush=True) + break + except Exception as e: + print(f"Rank 0 (WS Process): Unexpected error in sender: {e}", flush=True) + print("Rank 0 (WS Process): Data sender thread exiting.", flush=True) + + def _websocket_handler(websocket): + global _websocket_connection + print("Rank 0 (WS Process): Frontend connected.", flush=True) + with _websocket_lock: + _websocket_connection = websocket + + try: + for message in websocket: + try: + request = json.loads(message) + if request.get("type") == "run_training_step": + print("Rank 0 (WS Process): Received 'run_training_step' command.", flush=True) + _request_configs['visualization_flags'] = request.get("visualization_flags", {}) + _request_configs['disturbance_configs'] = request.get("disturbance_configs", {}) + _request_configs['compressor_config'] = request.get("compressor_config", {}) + config_queue.put(_request_configs) + try: + start_payload = { + "type": "start", + "micro_batch_size": training_args.get("micro_batch_size"), + "seq_length": training_args.get("seq_length"), + "num_layers": training_args.get("num_layers") + } + websocket.send(json.dumps(start_payload)) + print("Rank 0 (WS Process): Sent 'start' message to frontend.", flush=True) + except Exception as e: + print(f"Rank 0 (WS Process): Failed to send 'start' message: {e}", flush=True) + start_event.set() + except Exception as e: + print(f"Rank 0 (WS Process): Error processing message: {e}", flush=True) + except ConnectionClosed: + print("Rank 0 (WS Process): Connection handler closed as expected.", flush=True) + finally: + with _websocket_lock: + _websocket_connection = None + start_event.clear() + print("Rank 0 (WS Process): Frontend disconnected.", flush=True) + + sender_thread = threading.Thread(target=_data_sender_thread_inner, daemon=True) + sender_thread.start() + + server: Server = None + + def shutdown_handler(): + shutdown_event.wait() + print("Rank 0 (WS Process): Shutdown event received, stopping server...", flush=True) + if server: + server.shutdown() + + shutdown_thread = threading.Thread(target=shutdown_handler, daemon=True) + shutdown_thread.start() + + print(f"Rank 0 (WS Process): Starting server on ws://0.0.0.0:{port}", flush=True) + + try: + with serve( + _websocket_handler, "0.0.0.0", port, + ping_interval=None, reuse_port=True + ) as server_instance: + server = server_instance + server.serve_forever() + except Exception as e: + print(f"Rank 0 (WS Process): Server crashed with an error: {e}", flush=True) + finally: + sender_thread.join(timeout=1.0) + config_queue.close() + print("Rank 0 (WS Process): Server has shut down.", flush=True) diff --git a/pyproject.toml b/pyproject.toml index 567954ca4a1..14efa3b721e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,7 @@ dynamic = ["version", "readme"] description = "Megatron Core - a library for efficient and scalable training of transformer based models" requires-python = ">=3.10" license = { text = "Apache 2.0" } -dependencies = ["torch>=2.6.0", "numpy", "packaging>=24.2"] +dependencies = ["torch>=2.6.0", "numpy", "packaging>=24.2", "websockets"] authors = [{ name = "NVIDIA", email = "nemo-toolkit@nvidia.com" }] maintainers = [{ name = "NVIDIA", email = "nemo-toolkit@nvidia.com" }] keywords = [ From 046733c1bf1d64af53b0cec45593be6503c587ef Mon Sep 17 00:00:00 2001 From: superay <1420782034@qq.com> Date: Sun, 8 Feb 2026 08:54:18 +0000 Subject: [PATCH 02/10] refactor: update compressor configuration handling in training modules --- megatron/core/tensor_tracer.py | 194 ++++++++++++------------- megatron/training/training.py | 12 +- megatron/training/training_wsserver.py | 1 - 3 files changed, 100 insertions(+), 107 deletions(-) diff --git a/megatron/core/tensor_tracer.py b/megatron/core/tensor_tracer.py index e9f24926714..c03ee084e08 100644 --- a/megatron/core/tensor_tracer.py +++ b/megatron/core/tensor_tracer.py @@ -25,7 +25,7 @@ def _set_tt_hook_manager(args, model): def _set_compressor(): global _GLOBAL_COMPRESSOR - _GLOBAL_COMPRESSOR=DefaultCompressor() + _GLOBAL_COMPRESSOR={} def set_report(func): global _GLOBAL_REPORT @@ -41,8 +41,8 @@ def get_tensor_tracers(): def get_tt_flags(): return _GLOBAL_TT_FLAGS -def get_compressor(): - return _GLOBAL_COMPRESSOR +def get_compressor(flag_type): + return _GLOBAL_COMPRESSOR[flag_type] def get_report(): return _GLOBAL_REPORT @@ -50,7 +50,6 @@ def get_report(): class FlagType(Enum): INVALID_FLAG = 0 QKV_mat_mul = 1 - RawAttentionScore_mat_mul = 2 ContextLayer_mat_mul = 3 MLP1_mat_mul = 4 MLP2_mat_mul = 5 @@ -61,37 +60,27 @@ class AbstractCompressor: def __init__(self): pass @abstractmethod - def set_by_configs(self, configs: Dict[str, Any]): + def compress_one_rank(self, layer_number, flag_type, data): pass @abstractmethod - def compress_one_rank(self, name, data): - pass - @abstractmethod - def compress(self, name, data): + def compress(self, layer_number, flag_type, data): pass -class DefaultCompressor(AbstractCompressor): - def __init__(self): +class TileCompressor(AbstractCompressor): + def __init__(self, configs): self.configs = { - "QKV": { - "pixels": 96, - "method": "data.mean(dim=-1)" - }, - "MLP": { - "pixels": 64, - "method": "data.mean(dim=-1)" - } + "tiles": configs.get("tiles", 96), + "method": configs.get("method", "data.mean(dim=-1)"), + "tiles_one_rank": configs.get("tiles_one_rank", 96), + "method_one_rank": configs.get("method_one_rank", "data.mean(dim=-1)") } - def set_by_configs(self, configs: Dict[str, Any]): - self.configs = configs - - def compress_tensor(self, data_in, pixels, method): + def compress_tensor(self, data_in, tiles, method): B, S, F = data_in.shape - chunk_size = math.ceil(F / pixels) - padded_len = chunk_size * pixels + chunk_size = math.ceil(F / tiles) + padded_len = chunk_size * tiles padded_data = torch.nn.functional.pad(data_in, (0, padded_len - F)) - data_for_eval = padded_data.reshape(B, S, pixels, chunk_size) + data_for_eval = padded_data.reshape(B, S, tiles, chunk_size) try: compressed = eval(method, {}, {"data": data_for_eval}) except Exception as e: @@ -99,54 +88,67 @@ def compress_tensor(self, data_in, pixels, method): compressed = data_for_eval.mean(dim=-1) return compressed - def compress_1d_tensor(self, data_in, pixels, method): - B, S, F = data_in.shape - chunk_size = math.ceil(F / pixels) - padded_len = chunk_size * pixels - padded_data = torch.nn.functional.pad(data_in, (0, padded_len - F)) - data_for_eval = padded_data.reshape(B, S, pixels, chunk_size) - try: - compressed = eval(method, {}, {"data": data_for_eval}).flatten() - except Exception as e: - print(f"Error in compressing tensor with method '{method}': {e}") - compressed = data_for_eval.mean(dim=-1).flatten() # Fallback to mean if eval fails - return compressed + def compress_one_rank(self, layer_number, flag_type, data): + return self.compress_tensor(data, self.configs["tiles_one_rank"], self.configs["method_one_rank"]) + + def compress(self, layer_number, flag_type, data): + compressed = self.compress_tensor(data, self.configs["tiles"], self.configs["method"]) + return True, list(compressed.shape), compressed.flatten() + +class NoOpCompressor(AbstractCompressor): + def __init__(self, configs): + pass - def compress_one_rank(self, flag_type, data): - if flag_type == FlagType.QKV_mat_mul: - return self.compress_tensor(data, self.configs["QKV"]["pixels"], self.configs["QKV"]["method"]) - elif flag_type == FlagType.MLP1_mat_mul or flag_type == FlagType.MLP2_mat_mul or flag_type == FlagType.ContextLayer_mat_mul: - return self.compress_tensor(data, self.configs["MLP"]["pixels"], self.configs["MLP"]["method"]) + def compress_one_rank(self, layer_number, flag_type, data): return data - def compress(self, name, data): - flag_type = name[1] - if flag_type == FlagType.QKV_mat_mul: - n = data.shape[1]; return True, [n], self.compress_1d_tensor(data, self.configs["QKV"]["pixels"], self.configs["QKV"]["method"]) - elif flag_type == FlagType.RawAttentionScore_mat_mul: - np, n, m = data.shape[1], data.shape[2], data.shape[3]; return True, [np, n, m], data[:, :, :, :].flatten() - elif flag_type == FlagType.MLP1_mat_mul or flag_type == FlagType.MLP2_mat_mul or flag_type == FlagType.ContextLayer_mat_mul: - n = data.shape[1]; return True, [n], self.compress_1d_tensor(data, self.configs["MLP"]["pixels"], self.configs["MLP"]["method"]) - return False, [], torch.tensor([]) + def compress(self, layer_number, flag_type, data): + return True, list(data.shape), data.flatten() -class ProjectionCompressor(AbstractCompressor): - def __init__(self): +class EmptyCompressor(AbstractCompressor): + def __init__(self, configs): pass - def set_by_configs(self, configs: Dict[str, Any]): - pass - - def compress_one_rank(self, name, data): + def compress_one_rank(self, layer_number, flag_type, data): + return torch.tensor([]) + + def compress(self, layer_number, flag_type, data): + return True, [0], torch.tensor([]) + +class ProjectionCompressor(AbstractCompressor): + def __init__(self, configs): + try: + self.projection_vector = torch.load(configs["vector_path"]) + self.projection_vector = torch.nn.functional.normalize(self.projection_vector, p=2, dim=1) + device = torch.cuda.current_device() + self.projection_vector = self.projection_vector.to(device) + except Exception as e: + print(f"Error loading projection vector: {e}") + self.projection_vector = None + + def compress_one_rank(self, layer_number, flag_type, data): return data - def compress(self, name, data): - return False, [], torch.tensor([]) + def compress(self, layer_number, flag_type, data): + if self.projection_vector is None: + return False, [], torch.tensor([]) + else: + vector = self.projection_vector[layer_number - 1] + projected = torch.matmul(data, vector).unsqueeze(-1) + return True, list(projected.shape), projected.flatten() + +COMPRESSOR_MAP = { + "TileCompressor": TileCompressor, + "NoOpCompressor": NoOpCompressor, + "EmptyCompressor": EmptyCompressor, + "ProjectionCompressor": ProjectionCompressor +} class TensorTracers: # simplified as TT def __init__(self) -> None: pass def report(self, name, tensor_data): - valid, comp_args, compressed_tensor = get_compressor().compress(name, tensor_data) + valid, comp_args, compressed_tensor = get_compressor(name[1]).compress(name[0], name[1], tensor_data) assert valid get_report()(name, comp_args, compressed_tensor) @@ -158,7 +160,6 @@ def __init__(self, args): self.flags: Dict[FlagType, Dict[int, bool]] = { FlagType.INVALID_FLAG: {i: False for i in range(1, self.num_layers + 1)}, FlagType.QKV_mat_mul: {i: False for i in range(1, self.num_layers + 1)}, - FlagType.RawAttentionScore_mat_mul: {i: False for i in range(1, self.num_layers + 1)}, FlagType.ContextLayer_mat_mul: {i: False for i in range(1, self.num_layers + 1)}, FlagType.MLP1_mat_mul: {i: False for i in range(1, self.num_layers + 1)}, FlagType.MLP2_mat_mul: {i: False for i in range(1, self.num_layers + 1)}, @@ -170,34 +171,18 @@ def __init__(self, args): def get_flag(self, flag_type: FlagType, layer_index: int) -> bool: return self.should_trace and self.flags.get(flag_type, {}).get(layer_index, False) - def set_by_configs(self, configs: Dict[str, Any]): - val = True if configs.get("QKV_mat_mul", "False").lower() == "true" else False - for i in range(1, self.num_layers + 1): - self.flags[FlagType.QKV_mat_mul][i] = val - - val = True if configs.get("RawAttentionScore_mat_mul", "False").lower() == "true" else False - for i in range(1, self.num_layers + 1): - self.flags[FlagType.RawAttentionScore_mat_mul][i] = val - - val = True if configs.get("ContextLayer_mat_mul", "False").lower() == "true" else False - for i in range(1, self.num_layers + 1): - self.flags[FlagType.ContextLayer_mat_mul][i] = val - - val = True if configs.get("MLP1_mat_mul", "True").lower() == "true" else False - for i in range(1, self.num_layers + 1): - self.flags[FlagType.MLP1_mat_mul][i] = val - - val = True if configs.get("MLP2_mat_mul", "True").lower() == "true" else False - for i in range(1, self.num_layers + 1): - self.flags[FlagType.MLP2_mat_mul][i] = val - - val = True if configs.get("AttentionOutput_mat_mul", "False").lower() == "true" else False - for i in range(1, self.num_layers + 1): - self.flags[FlagType.AttentionOutput_mat_mul][i] = val - - val = True if configs.get("HiddenStates", "False").lower() == "true" else False - for i in range(1, self.num_layers + 1): - self.flags[FlagType.HiddenStates][i] = val + def set_by_configs(self, configs: Dict[str, Any], comp_configs: Dict[str, Any]): + for flag_type in self.flags: + if flag_type == FlagType.INVALID_FLAG: + continue + val = True if configs.get(flag_type.name, "False").lower() == "true" else False + for i in range(1, self.num_layers + 1): + self.flags[flag_type][i] = val + if comp_configs.get(flag_type.name, None): + specific_comp_config = comp_configs[flag_type.name] + compressor_type = specific_comp_config.get("compressor_type", "EmptyCompressor") + compressor_configs = specific_comp_config.get("compressor_configs", {}) + _GLOBAL_COMPRESSOR[flag_type] = COMPRESSOR_MAP.get(compressor_type, EmptyCompressor)(compressor_configs) class TTHookManager: def __init__(self, args, model) -> None: @@ -213,8 +198,11 @@ def hook(module, input, output): world_size = get_tensor_model_parallel_world_size() rank = get_tensor_model_parallel_rank() - tensor_data = output[0].detach() - tensor_data = get_compressor().compress_one_rank(flag_type, tensor_data) + if isinstance(output, (list, tuple)): + tensor_data = output[0].detach() + else: + tensor_data = output.detach() + tensor_data = get_compressor(flag_type).compress_one_rank(layer_number, flag_type, tensor_data) tensor_data_cont = tensor_data.contiguous() if rank == 0: tensor_list = [torch.zeros_like(tensor_data_cont, dtype=tensor_data_cont.dtype, device=device) for _ in range(world_size)] @@ -256,8 +244,11 @@ def hook(module, input, output): rank = get_tensor_model_parallel_rank() if args.sequence_parallel: - tensor_data = output[0].detach() - tensor_data = get_compressor().compress_one_rank(flag_type, tensor_data) + if isinstance(output, (list, tuple)): + tensor_data = output[0].detach() + else: + tensor_data = output.detach() + tensor_data = get_compressor(flag_type).compress_one_rank(layer_number, flag_type, tensor_data) tensor_data_cont = tensor_data.contiguous() if rank == 0: tensor_list = [torch.zeros_like(tensor_data_cont, dtype=tensor_data_cont.dtype, device=device) for _ in range(world_size)] @@ -273,8 +264,11 @@ def hook(module, input, output): get_tensor_tracers().report((layer_number, flag_type), aggregated_tensor.transpose(0, 1)) else: if rank == 0: - tensor_data = output[0].detach() - tensor_data = get_compressor().compress_one_rank(flag_type, tensor_data) + if isinstance(output, (list, tuple)): + tensor_data = output[0].detach() + else: + tensor_data = output.detach() + tensor_data = get_compressor(flag_type).compress_one_rank(layer_number, flag_type, tensor_data) get_tensor_tracers().report((layer_number, flag_type), tensor_data.transpose(0, 1)) return hook @@ -285,8 +279,11 @@ def hook(module, input, output): world_size = get_tensor_model_parallel_world_size() rank = get_tensor_model_parallel_rank() - tensor_data = output.detach() - tensor_data = get_compressor().compress_one_rank(flag_type, tensor_data) + if isinstance(output, (list, tuple)): + tensor_data = output[0].detach() + else: + tensor_data = output.detach() + tensor_data = get_compressor(flag_type).compress_one_rank(layer_number, flag_type, tensor_data) tensor_data_cont = tensor_data.contiguous() if rank == 0: tensor_list = [torch.zeros_like(tensor_data_cont, dtype=tensor_data_cont.dtype, device=device) for _ in range(world_size)] @@ -308,7 +305,6 @@ def hook(module, input, output): self.hooks.append(model.decoder.layers[layer].mlp.linear_fc2.register_forward_hook(generate_hook_transpose_row(FlagType.MLP2_mat_mul, global_layer_number))) # Row self.hooks.append(model.decoder.layers[layer].self_attention.register_forward_hook(generate_hook_transpose_row(FlagType.AttentionOutput_mat_mul, global_layer_number))) # Row self.hooks.append(model.decoder.layers[layer].register_forward_hook(generate_hook_transpose_row(FlagType.HiddenStates, global_layer_number))) # Row - self.hooks.append(model.decoder.layers[layer].self_attention.core_attention.scale_mask_softmax.register_forward_hook(generate_hook_attn(FlagType.RawAttentionScore_mat_mul, global_layer_number))) # Raw Attention Scores, Special self.hooks.append(model.decoder.layers[layer].self_attention.core_attention.register_forward_hook(generate_hook_transpose_col(FlagType.ContextLayer_mat_mul, global_layer_number))) # Col, not gather_output ''' diff --git a/megatron/training/training.py b/megatron/training/training.py index c769fec95c4..3ba92a08488 100644 --- a/megatron/training/training.py +++ b/megatron/training/training.py @@ -2491,18 +2491,16 @@ def report_func(name_tuple, report_args, tensor_data): if torch.distributed.get_rank() == 0: received_configs = config_queue.get() vis_flags = received_configs.get('visualization_flags', {}) - dist_configs = received_configs.get('disturbance_configs', {}) comp_configs = received_configs.get('compressor_config', {}) - configs_to_broadcast = [vis_flags, dist_configs, comp_configs] + configs_to_broadcast = [vis_flags, comp_configs] else: - configs_to_broadcast = [None, None, None] + configs_to_broadcast = [None, None] torch.distributed.broadcast_object_list(configs_to_broadcast, src=0) - vis_flags, dist_configs, comp_configs = configs_to_broadcast - from megatron.core.tensor_tracer import get_tt_flags, get_compressor - get_tt_flags().set_by_configs(vis_flags) - get_compressor().set_by_configs(comp_configs) + vis_flags, comp_configs = configs_to_broadcast + from megatron.core.tensor_tracer import get_tt_flags + get_tt_flags().set_by_configs(vis_flags, comp_configs) print_rank_0("Configs synchronized. Starting training.") diff --git a/megatron/training/training_wsserver.py b/megatron/training/training_wsserver.py index 3f4901d4241..d2226e0e085 100644 --- a/megatron/training/training_wsserver.py +++ b/megatron/training/training_wsserver.py @@ -77,7 +77,6 @@ def _websocket_handler(websocket): if request.get("type") == "run_training_step": print("Rank 0 (WS Process): Received 'run_training_step' command.", flush=True) _request_configs['visualization_flags'] = request.get("visualization_flags", {}) - _request_configs['disturbance_configs'] = request.get("disturbance_configs", {}) _request_configs['compressor_config'] = request.get("compressor_config", {}) config_queue.put(_request_configs) try: From 55dc9d2b5b33e429f25a26e69e2337f32772e4bb Mon Sep 17 00:00:00 2001 From: superay <1420782034@qq.com> Date: Sun, 8 Feb 2026 21:48:52 +0800 Subject: [PATCH 03/10] feat: add support for pipeline parallel worker connections in tensor tracer websocket server --- megatron/core/tensor_tracer.py | 9 +- megatron/training/training.py | 61 ++++++----- megatron/training/training_wsserver.py | 134 ++++++++++++++++--------- 3 files changed, 127 insertions(+), 77 deletions(-) diff --git a/megatron/core/tensor_tracer.py b/megatron/core/tensor_tracer.py index c03ee084e08..12f5b17a732 100644 --- a/megatron/core/tensor_tracer.py +++ b/megatron/core/tensor_tracer.py @@ -197,6 +197,7 @@ def hook(module, input, output): device = torch.cuda.current_device() world_size = get_tensor_model_parallel_world_size() rank = get_tensor_model_parallel_rank() + rank0_global = torch.distributed.get_process_group_ranks(get_tensor_model_parallel_group())[0] if isinstance(output, (list, tuple)): tensor_data = output[0].detach() @@ -209,7 +210,7 @@ def hook(module, input, output): else: tensor_list = None if world_size > 1: - torch.distributed.gather(tensor_data_cont, tensor_list, dst=0, group=get_tensor_model_parallel_group()) + torch.distributed.gather(tensor_data_cont, tensor_list, dst=rank0_global, group=get_tensor_model_parallel_group()) else: tensor_list = [tensor_data_cont] @@ -242,6 +243,7 @@ def hook(module, input, output): device = torch.cuda.current_device() world_size = get_tensor_model_parallel_world_size() rank = get_tensor_model_parallel_rank() + rank0_global = torch.distributed.get_process_group_ranks(get_tensor_model_parallel_group())[0] if args.sequence_parallel: if isinstance(output, (list, tuple)): @@ -255,7 +257,7 @@ def hook(module, input, output): else: tensor_list = None if world_size > 1: - torch.distributed.gather(tensor_data_cont, tensor_list, dst=0, group=get_tensor_model_parallel_group()) + torch.distributed.gather(tensor_data_cont, tensor_list, dst=rank0_global, group=get_tensor_model_parallel_group()) else: tensor_list = [tensor_data_cont] @@ -278,6 +280,7 @@ def hook(module, input, output): device = torch.cuda.current_device() world_size = get_tensor_model_parallel_world_size() rank = get_tensor_model_parallel_rank() + rank0_global = torch.distributed.get_process_group_ranks(get_tensor_model_parallel_group())[0] if isinstance(output, (list, tuple)): tensor_data = output[0].detach() @@ -290,7 +293,7 @@ def hook(module, input, output): else: tensor_list = None if world_size > 1: - torch.distributed.gather(tensor_data_cont, tensor_list, dst=0, group=get_tensor_model_parallel_group()) + torch.distributed.gather(tensor_data_cont, tensor_list, dst=rank0_global, group=get_tensor_model_parallel_group()) else: tensor_list = [tensor_data_cont] diff --git a/megatron/training/training.py b/megatron/training/training.py index 3ba92a08488..0be477c37b9 100644 --- a/megatron/training/training.py +++ b/megatron/training/training.py @@ -52,7 +52,7 @@ def set_startup_timestamps(program_start=None, main_entry=None): from .log_handler import CustomHandler from megatron.core.tensor_tracer import _set_tt_hook_manager -from .training_wsserver import websocket_server_process +from .training_wsserver import websocket_server_process, websocket_worker_process # Make default logging level INFO, but filter out all log messages not from MCore. logging.basicConfig(handlers=[CustomHandler()], level=logging.INFO) @@ -2452,22 +2452,12 @@ def train( args = get_args() timers = get_timers() - if args.tensor_tracer_port is not None and torch.distributed.get_rank() == 0: + global_rank = torch.distributed.get_rank() + tp_rank = mpu.get_tensor_model_parallel_rank() + if args.tensor_tracer_port is not None and tp_rank == 0: data_queue = multiprocessing.Queue() - config_queue = multiprocessing.Queue(maxsize=1) - start_training_event = multiprocessing.Event() shutdown_event = multiprocessing.Event() - training_args_dict = { - "micro_batch_size": args.micro_batch_size, - "seq_length": args.seq_length, - "num_layers": args.num_layers - } - ws_process = multiprocessing.Process( - target=websocket_server_process, - args=(args.tensor_tracer_port, data_queue, config_queue, start_training_event, shutdown_event, training_args_dict), - daemon=True - ) - ws_process.start() + from megatron.core.tensor_tracer import FlagType, set_report def report_func(name_tuple, report_args, tensor_data): # name_tuple is (layer_id, FlagType) @@ -2482,13 +2472,35 @@ def report_func(name_tuple, report_args, tensor_data): pass set_report(report_func) - if args.tensor_tracer_port is not None: - if torch.distributed.get_rank() == 0: - print_rank_0("Waiting for 'run_training_step' command from frontend to start training...") + if global_rank == 0: + config_queue = multiprocessing.Queue(maxsize=1) + start_training_event = multiprocessing.Event() + + training_args_dict = { + "micro_batch_size": args.micro_batch_size, + "seq_length": args.seq_length, + "num_layers": args.num_layers + } + + ws_process = multiprocessing.Process( + target=websocket_server_process, + args=(args.tensor_tracer_port, data_queue, config_queue, start_training_event, shutdown_event, training_args_dict), + daemon=True + ) + ws_process.start() start_training_event.wait() - print_rank_0("Command received. Synchronizing configs across all ranks...") + else: + master_addr = os.environ.get("MASTER_ADDR", "127.0.0.1") + + ws_process = multiprocessing.Process( + target=websocket_worker_process, + args=(master_addr, args.tensor_tracer_port, global_rank, data_queue, shutdown_event), + daemon=True + ) + ws_process.start() - if torch.distributed.get_rank() == 0: + if args.tensor_tracer_port is not None: + if global_rank == 0: received_configs = config_queue.get() vis_flags = received_configs.get('visualization_flags', {}) comp_configs = received_configs.get('compressor_config', {}) @@ -2502,8 +2514,6 @@ def report_func(name_tuple, report_args, tensor_data): from megatron.core.tensor_tracer import get_tt_flags get_tt_flags().set_by_configs(vis_flags, comp_configs) - print_rank_0("Configs synchronized. Starting training.") - if getattr(args, 'perform_rl_step', False): assert has_rl_utils, "RL cannot run without the megatron.rl package" @@ -3073,13 +3083,14 @@ def get_e2e_base_metrics(): if should_exit: break - if ws_process: + if args.tensor_tracer_port is not None and tp_rank == 0: print_rank_0("Signaling WebSocket process to shut down...") shutdown_event.set() data_queue.close() - config_queue.close() data_queue.cancel_join_thread() - config_queue.cancel_join_thread() + if global_rank == 0: + config_queue.close() + config_queue.cancel_join_thread() ws_process.join(timeout=5) if ws_process.is_alive(): print_rank_0("WebSocket process did not shut down cleanly, terminating.") diff --git a/megatron/training/training_wsserver.py b/megatron/training/training_wsserver.py index d2226e0e085..b8555394e8b 100644 --- a/megatron/training/training_wsserver.py +++ b/megatron/training/training_wsserver.py @@ -18,26 +18,60 @@ import queue import time from websockets.sync.server import serve, Server +from websockets.sync.client import connect from websockets.exceptions import ConnectionClosed -_websocket_connection = None -_websocket_lock = threading.Lock() +_frontend_connection = None +_connection_lock = threading.Lock() _request_configs = {} -def get_websocket(): - with _websocket_lock: - return _websocket_connection +def get_frontend_connection(): + with _connection_lock: + return _frontend_connection + +def websocket_worker_process(master_addr: str, port: int, rank: int, data_queue: multiprocessing.Queue, shutdown_event: multiprocessing.Event): + uri = f"ws://{master_addr}:{port}" + print(f"Rank {rank} (Worker): Connecting to Hub at {uri}...", flush=True) + + while not shutdown_event.is_set(): + try: + with connect(uri) as websocket: + print(f"Rank {rank} (Worker): Connected.", flush=True) + + while not shutdown_event.is_set(): + try: + name_tuple, report_args, tensor_data = data_queue.get(timeout=1.0) + + payload = { + "type": "worker_forward", + "data": { + "type": "update", + "update_type": name_tuple[1].value, + "layer_id": name_tuple[0], + "args": report_args, + "result": tensor_data.tolist() + } + } + + websocket.send(json.dumps(payload)) + except queue.Empty: + continue + except (ConnectionRefusedError, OSError): + time.sleep(2) + except Exception as e: + print(f"Rank {rank} (Worker): Error: {e}, retrying...", flush=True) + time.sleep(2) def websocket_server_process(port: int, data_queue: multiprocessing.Queue, config_queue: multiprocessing.Queue, start_event: multiprocessing.Event, shutdown_event: multiprocessing.Event, training_args: dict): - global _websocket_connection + global _frontend_connection - def _data_sender_thread_inner(): - print("Rank 0 (WS Process): Data sender thread started.", flush=True) + def _local_data_sender(): + print("Rank 0 (Server): Local data sender started.", flush=True) while not shutdown_event.is_set(): try: name_tuple, report_args, tensor_data = data_queue.get(timeout=1.0) - + payload = { "type": "update", "update_type": name_tuple[1].value, @@ -45,77 +79,79 @@ def _data_sender_thread_inner(): "args": report_args, "result": tensor_data.tolist() } - - ws = get_websocket() + + ws = get_frontend_connection() if ws: try: ws.send(json.dumps(payload)) except ConnectionClosed: pass except Exception as e: - print(f"Rank 0 (WS Process): Error sending data: {e}", flush=True) + print(f"Rank 0 (Server): Error sending data: {e}", flush=True) except queue.Empty: continue - except (BrokenPipeError, EOFError): - print("Rank 0 (WS Process): Data queue connection broken, sender thread exiting.", flush=True) - break except Exception as e: - print(f"Rank 0 (WS Process): Unexpected error in sender: {e}", flush=True) - print("Rank 0 (WS Process): Data sender thread exiting.", flush=True) + print(f"Rank 0 (Server): Unexpected error in sender: {e}", flush=True) def _websocket_handler(websocket): - global _websocket_connection - print("Rank 0 (WS Process): Frontend connected.", flush=True) - with _websocket_lock: - _websocket_connection = websocket - + global _frontend_connection + is_frontend = False + try: for message in websocket: try: - request = json.loads(message) - if request.get("type") == "run_training_step": - print("Rank 0 (WS Process): Received 'run_training_step' command.", flush=True) - _request_configs['visualization_flags'] = request.get("visualization_flags", {}) - _request_configs['compressor_config'] = request.get("compressor_config", {}) + msg_obj = json.loads(message) + msg_type = msg_obj.get("type") + if msg_type == "worker_forward": + forward_data = msg_obj.get("data") + frontend = get_frontend_connection() + if frontend: + frontend.send(json.dumps(forward_data)) + elif msg_type == "run_training_step": + print("Rank 0 (Server): Frontend connected and assumed control.", flush=True) + with _connection_lock: + _frontend_connection = websocket + is_frontend = True + + _request_configs['visualization_flags'] = msg_obj.get("visualization_flags", {}) + _request_configs['compressor_config'] = msg_obj.get("compressor_config", {}) config_queue.put(_request_configs) - try: - start_payload = { - "type": "start", - "micro_batch_size": training_args.get("micro_batch_size"), - "seq_length": training_args.get("seq_length"), - "num_layers": training_args.get("num_layers") - } - websocket.send(json.dumps(start_payload)) - print("Rank 0 (WS Process): Sent 'start' message to frontend.", flush=True) - except Exception as e: - print(f"Rank 0 (WS Process): Failed to send 'start' message: {e}", flush=True) + + start_payload = { + "type": "start", + "micro_batch_size": training_args.get("micro_batch_size"), + "seq_length": training_args.get("seq_length"), + "num_layers": training_args.get("num_layers") + } + websocket.send(json.dumps(start_payload)) start_event.set() except Exception as e: - print(f"Rank 0 (WS Process): Error processing message: {e}", flush=True) + print(f"Rank 0 (Server): Error processing message: {e}", flush=True) except ConnectionClosed: - print("Rank 0 (WS Process): Connection handler closed as expected.", flush=True) + print("Rank 0 (Server): Connection handler closed.", flush=True) finally: - with _websocket_lock: - _websocket_connection = None - start_event.clear() - print("Rank 0 (WS Process): Frontend disconnected.", flush=True) + if is_frontend: + print("Rank 0 (Server): Frontend disconnected.", flush=True) + with _connection_lock: + _frontend_connection = None + start_event.clear() - sender_thread = threading.Thread(target=_data_sender_thread_inner, daemon=True) + sender_thread = threading.Thread(target=_local_data_sender, daemon=True) sender_thread.start() server: Server = None def shutdown_handler(): shutdown_event.wait() - print("Rank 0 (WS Process): Shutdown event received, stopping server...", flush=True) + print("Rank 0 (Server): Shutdown event received, stopping server...", flush=True) if server: server.shutdown() shutdown_thread = threading.Thread(target=shutdown_handler, daemon=True) shutdown_thread.start() - print(f"Rank 0 (WS Process): Starting server on ws://0.0.0.0:{port}", flush=True) + print(f"Rank 0 (Server): Starting server on ws://0.0.0.0:{port}", flush=True) try: with serve( @@ -125,8 +161,8 @@ def shutdown_handler(): server = server_instance server.serve_forever() except Exception as e: - print(f"Rank 0 (WS Process): Server crashed with an error: {e}", flush=True) + print(f"Rank 0 (Server): Server crashed with an error: {e}", flush=True) finally: sender_thread.join(timeout=1.0) config_queue.close() - print("Rank 0 (WS Process): Server has shut down.", flush=True) + print("Rank 0 (Server): Server has shut down.", flush=True) From 7f2a490224017672133103d4f577d0c075b0a9d1 Mon Sep 17 00:00:00 2001 From: superay <1420782034@qq.com> Date: Thu, 12 Feb 2026 21:42:42 +0800 Subject: [PATCH 04/10] feat: enable tensor tracing only in forward step --- megatron/core/pipeline_parallel/schedules.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/megatron/core/pipeline_parallel/schedules.py b/megatron/core/pipeline_parallel/schedules.py index f15dcd1400b..d5d9cbc8bce 100644 --- a/megatron/core/pipeline_parallel/schedules.py +++ b/megatron/core/pipeline_parallel/schedules.py @@ -401,6 +401,10 @@ def forward_step( """ from megatron.core.transformer.multi_token_prediction import MTPLossAutoScaler + from megatron.core.tensor_tracer import get_tt_flags + if get_tt_flags() is not None: + get_tt_flags().should_trace = True + if config.timers is not None: config.timers('forward-compute', log_level=2).start() @@ -441,6 +445,9 @@ def forward_step( is_last_stage, ) + if get_tt_flags() is not None: + get_tt_flags().should_trace = False + if unwrap_output_tensor: return output_tensor, num_tokens return [output_tensor], num_tokens From d5eb76e89040413811c77db72fb747f0291b477f Mon Sep 17 00:00:00 2001 From: Chen Shuo <211250172@smail.nju.edu.cn> Date: Sat, 21 Feb 2026 23:23:31 +0800 Subject: [PATCH 05/10] chore: make tensor tracer optional dependency --- megatron/training/training.py | 16 +++++++++++++--- pyproject.toml | 4 +++- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/megatron/training/training.py b/megatron/training/training.py index 0be477c37b9..e0efed5b7a0 100644 --- a/megatron/training/training.py +++ b/megatron/training/training.py @@ -51,9 +51,6 @@ def set_startup_timestamps(program_start=None, main_entry=None): from megatron.core.optimizer.distrib_optimizer import DistributedOptimizer from .log_handler import CustomHandler -from megatron.core.tensor_tracer import _set_tt_hook_manager -from .training_wsserver import websocket_server_process, websocket_worker_process - # Make default logging level INFO, but filter out all log messages not from MCore. logging.basicConfig(handlers=[CustomHandler()], level=logging.INFO) from .theoretical_memory_usage import report_theoretical_memory @@ -909,6 +906,8 @@ def pretrain( model_provider, model_type, checkpointing_context=checkpointing_context ) if args.tensor_tracer_port is not None: + from megatron.core.tensor_tracer import _set_tt_hook_manager + _set_tt_hook_manager(args, model) timers('model-and-optimizer-setup').stop() @@ -2455,6 +2454,17 @@ def train( global_rank = torch.distributed.get_rank() tp_rank = mpu.get_tensor_model_parallel_rank() if args.tensor_tracer_port is not None and tp_rank == 0: + try: + from .training_wsserver import websocket_server_process, websocket_worker_process + except ModuleNotFoundError as exc: + if exc.name == "websockets": + raise RuntimeError( + "Tensor tracer requires optional dependency 'websockets'. " + "Install it with `pip install -e '.[tensor_tracer]'` " + "(or `pip install websockets`)." + ) from exc + raise + data_queue = multiprocessing.Queue() shutdown_event = multiprocessing.Event() diff --git a/pyproject.toml b/pyproject.toml index 14efa3b721e..797e05f1a2c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,7 @@ dynamic = ["version", "readme"] description = "Megatron Core - a library for efficient and scalable training of transformer based models" requires-python = ">=3.10" license = { text = "Apache 2.0" } -dependencies = ["torch>=2.6.0", "numpy", "packaging>=24.2", "websockets"] +dependencies = ["torch>=2.6.0", "numpy", "packaging>=24.2"] authors = [{ name = "NVIDIA", email = "nemo-toolkit@nvidia.com" }] maintainers = [{ name = "NVIDIA", email = "nemo-toolkit@nvidia.com" }] keywords = [ @@ -66,6 +66,8 @@ Homepage = "https://github.com/NVIDIA/Megatron-LM/megatron/core" [project.optional-dependencies] mlm = ["flask-restful", "sentencepiece", "tiktoken", "wandb", "transformers"] +tensor_tracer = ["websockets>=10.0"] + dev = [ "nvidia-modelopt[torch]; sys_platform != 'darwin'", "transformer-engine[pytorch,core_cu13]>=2.9.0a0,<2.12.0", From 8816ea78845aa49fc8f9fbb727e455796cc24532 Mon Sep 17 00:00:00 2001 From: Chen Shuo <211250172@smail.nju.edu.cn> Date: Sat, 21 Feb 2026 23:23:39 +0800 Subject: [PATCH 06/10] style: tighten tensor tracer implementation --- megatron/core/pipeline_parallel/schedules.py | 11 +- megatron/core/tensor_tracer.py | 322 +++++++++++++++---- 2 files changed, 264 insertions(+), 69 deletions(-) diff --git a/megatron/core/pipeline_parallel/schedules.py b/megatron/core/pipeline_parallel/schedules.py index d5d9cbc8bce..2e892a03e93 100644 --- a/megatron/core/pipeline_parallel/schedules.py +++ b/megatron/core/pipeline_parallel/schedules.py @@ -399,11 +399,12 @@ def forward_step( Tensor or list[Tensor]: The output object(s) from the forward step. Tensor: The number of tokens. """ + from megatron.core.tensor_tracer import get_tt_flags from megatron.core.transformer.multi_token_prediction import MTPLossAutoScaler - from megatron.core.tensor_tracer import get_tt_flags - if get_tt_flags() is not None: - get_tt_flags().should_trace = True + tt_flags = get_tt_flags() + if tt_flags is not None: + tt_flags.should_trace = True if config.timers is not None: config.timers('forward-compute', log_level=2).start() @@ -445,8 +446,8 @@ def forward_step( is_last_stage, ) - if get_tt_flags() is not None: - get_tt_flags().should_trace = False + if tt_flags is not None: + tt_flags.should_trace = False if unwrap_output_tensor: return output_tensor, num_tokens diff --git a/megatron/core/tensor_tracer.py b/megatron/core/tensor_tracer.py index 12f5b17a732..bba9add16b1 100644 --- a/megatron/core/tensor_tracer.py +++ b/megatron/core/tensor_tracer.py @@ -1,53 +1,98 @@ -from abc import abstractmethod -import torch +import logging import math +from abc import abstractmethod from enum import Enum -from typing import Dict, Any -from megatron.core.parallel_state import get_tensor_model_parallel_group, get_tensor_model_parallel_world_size, get_tensor_model_parallel_rank +from typing import Any, Callable, Dict, Optional, Tuple + +import torch + +from megatron.core.parallel_state import ( + get_tensor_model_parallel_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) + +logger = logging.getLogger(__name__) + +NameTuple = Tuple[int, "FlagType"] +ReportFn = Callable[[NameTuple, list[int], torch.Tensor], None] + +_GLOBAL_TT_FLAGS: Optional["TTFlags"] = None +_GLOBAL_TENSOR_TRACERS: Optional["TensorTracers"] = None +_GLOBAL_COMPRESSOR: Optional[Dict["FlagType", "AbstractCompressor"]] = None +_GLOBAL_HOOK_MANAGER: Optional["TTHookManager"] = None + + +def _noop_report(name: NameTuple, args: list[int], tensor: torch.Tensor) -> None: + return + + +_GLOBAL_REPORT: ReportFn = _noop_report -_GLOBAL_TT_FLAGS = None -_GLOBAL_TENSOR_TRACERS = None -_GLOBAL_REPORT = lambda name, args, tensor: None -_GLOBAL_COMPRESSOR = None -_GLOBAL_HOOK_MANAGER = None def _set_tensor_tracers(): global _GLOBAL_TENSOR_TRACERS _GLOBAL_TENSOR_TRACERS = TensorTracers() + def _set_tt_flags(args): global _GLOBAL_TT_FLAGS _GLOBAL_TT_FLAGS = TTFlags(args) + def _set_tt_hook_manager(args, model): global _GLOBAL_HOOK_MANAGER _GLOBAL_HOOK_MANAGER = TTHookManager(args, model) + def _set_compressor(): global _GLOBAL_COMPRESSOR - _GLOBAL_COMPRESSOR={} + _GLOBAL_COMPRESSOR = {flag_type: EmptyCompressor({}) for flag_type in FlagType} + def set_report(func): + """Set the global tensor report callback.""" global _GLOBAL_REPORT _GLOBAL_REPORT = func + def unset_report(): + """Reset the global tensor report callback to a no-op.""" global _GLOBAL_REPORT - _GLOBAL_REPORT = lambda name, args, tensor: None + _GLOBAL_REPORT = _noop_report + def get_tensor_tracers(): + """Return the global tensor tracer instance, if initialized.""" return _GLOBAL_TENSOR_TRACERS + def get_tt_flags(): + """Return the global tensor-tracing flags instance, if initialized.""" return _GLOBAL_TT_FLAGS + def get_compressor(flag_type): - return _GLOBAL_COMPRESSOR[flag_type] + """Return the compressor associated with a flag type.""" + global _GLOBAL_COMPRESSOR + if _GLOBAL_COMPRESSOR is None: + _set_compressor() + assert _GLOBAL_COMPRESSOR is not None + compressor = _GLOBAL_COMPRESSOR.get(flag_type) + if compressor is None: + compressor = EmptyCompressor({}) + _GLOBAL_COMPRESSOR[flag_type] = compressor + return compressor + def get_report(): + """Return the current global tensor report callback.""" return _GLOBAL_REPORT + class FlagType(Enum): + """Kinds of intermediate tensors that can be traced.""" + INVALID_FLAG = 0 QKV_mat_mul = 1 ContextLayer_mat_mul = 3 @@ -56,102 +101,156 @@ class FlagType(Enum): AttentionOutput_mat_mul = 6 HiddenStates = 7 + class AbstractCompressor: + """Abstract base class for tensor compressors.""" + def __init__(self): pass + @abstractmethod def compress_one_rank(self, layer_number, flag_type, data): - pass + """Compress a tensor locally on one rank before any gather.""" + raise NotImplementedError + @abstractmethod def compress(self, layer_number, flag_type, data): - pass + """Compress an already-gathered tensor and return (valid, args, payload).""" + raise NotImplementedError + class TileCompressor(AbstractCompressor): + """Compress by chunking the last dimension into tiles and reducing each tile.""" + def __init__(self, configs): self.configs = { "tiles": configs.get("tiles", 96), "method": configs.get("method", "data.mean(dim=-1)"), "tiles_one_rank": configs.get("tiles_one_rank", 96), - "method_one_rank": configs.get("method_one_rank", "data.mean(dim=-1)") + "method_one_rank": configs.get("method_one_rank", "data.mean(dim=-1)"), } def compress_tensor(self, data_in, tiles, method): + """Apply a reduction expression over tiles of the last tensor dimension.""" B, S, F = data_in.shape chunk_size = math.ceil(F / tiles) padded_len = chunk_size * tiles padded_data = torch.nn.functional.pad(data_in, (0, padded_len - F)) data_for_eval = padded_data.reshape(B, S, tiles, chunk_size) try: - compressed = eval(method, {}, {"data": data_for_eval}) + compressed = eval(method, {"__builtins__": {}}, {"data": data_for_eval}) except Exception as e: - print(f"Error in compressing tensor with method '{method}': {e}") + logger.warning( + "Tensor tracer compressor method failed; falling back to mean. method=%r error=%s", + method, + e, + ) compressed = data_for_eval.mean(dim=-1) return compressed def compress_one_rank(self, layer_number, flag_type, data): - return self.compress_tensor(data, self.configs["tiles_one_rank"], self.configs["method_one_rank"]) + """Compress a tensor before gather using the per-rank config.""" + return self.compress_tensor( + data, self.configs["tiles_one_rank"], self.configs["method_one_rank"] + ) def compress(self, layer_number, flag_type, data): + """Compress a gathered tensor using the global config.""" compressed = self.compress_tensor(data, self.configs["tiles"], self.configs["method"]) return True, list(compressed.shape), compressed.flatten() + class NoOpCompressor(AbstractCompressor): + """A compressor that returns the original tensor unchanged.""" + def __init__(self, configs): pass def compress_one_rank(self, layer_number, flag_type, data): + """Return the original tensor.""" return data def compress(self, layer_number, flag_type, data): + """Return the original tensor flattened.""" return True, list(data.shape), data.flatten() + class EmptyCompressor(AbstractCompressor): + """A compressor that always reports an empty payload.""" + def __init__(self, configs): pass def compress_one_rank(self, layer_number, flag_type, data): + """Return an empty tensor.""" return torch.tensor([]) def compress(self, layer_number, flag_type, data): + """Return an empty flattened tensor.""" return True, [0], torch.tensor([]) + class ProjectionCompressor(AbstractCompressor): + """Project the last dimension onto a per-layer vector.""" + def __init__(self, configs): + self.projection_vector = None try: - self.projection_vector = torch.load(configs["vector_path"]) - self.projection_vector = torch.nn.functional.normalize(self.projection_vector, p=2, dim=1) - device = torch.cuda.current_device() - self.projection_vector = self.projection_vector.to(device) + self.projection_vector = torch.load(configs["vector_path"], map_location="cpu") + self.projection_vector = torch.nn.functional.normalize( + self.projection_vector, p=2, dim=1 + ) + if torch.cuda.is_available(): + try: + device = torch.cuda.current_device() + self.projection_vector = self.projection_vector.to(device) + except RuntimeError: + logger.warning( + "Tensor tracer projection vector loaded, but CUDA is not initialized; " + "keeping it on CPU." + ) except Exception as e: - print(f"Error loading projection vector: {e}") + logger.warning("Tensor tracer projection vector load failed: %s", e) self.projection_vector = None def compress_one_rank(self, layer_number, flag_type, data): + """Return the original tensor before gather.""" return data def compress(self, layer_number, flag_type, data): + """Project and return the compressed payload.""" if self.projection_vector is None: return False, [], torch.tensor([]) - else: - vector = self.projection_vector[layer_number - 1] - projected = torch.matmul(data, vector).unsqueeze(-1) - return True, list(projected.shape), projected.flatten() + vector = self.projection_vector[layer_number - 1] + projected = torch.matmul(data, vector).unsqueeze(-1) + return True, list(projected.shape), projected.flatten() + COMPRESSOR_MAP = { "TileCompressor": TileCompressor, "NoOpCompressor": NoOpCompressor, "EmptyCompressor": EmptyCompressor, - "ProjectionCompressor": ProjectionCompressor + "ProjectionCompressor": ProjectionCompressor, } -class TensorTracers: # simplified as TT - def __init__(self) -> None: pass - def report(self, name, tensor_data): - valid, comp_args, compressed_tensor = get_compressor(name[1]).compress(name[0], name[1], tensor_data) - assert valid +class TensorTracers: + """Trace and report tensors selected by TTFlags.""" + + def report(self, name: NameTuple, tensor_data: torch.Tensor) -> None: + """Compress and send a traced tensor through the report callback.""" + compressor = get_compressor(name[1]) + valid, comp_args, compressed_tensor = compressor.compress(name[0], name[1], tensor_data) + if not valid: + logger.warning( + "Tensor tracer compressor %s returned invalid result for %s; skipping report.", + type(compressor).__name__, + name, + ) + return get_report()(name, comp_args, compressed_tensor) + class TTFlags: """Global flags to record the intermediate results of the model.""" @@ -169,51 +268,79 @@ def __init__(self, args): self.should_trace = True def get_flag(self, flag_type: FlagType, layer_index: int) -> bool: + """Return whether a given flag is enabled for a layer.""" return self.should_trace and self.flags.get(flag_type, {}).get(layer_index, False) def set_by_configs(self, configs: Dict[str, Any], comp_configs: Dict[str, Any]): + """Update tracing flags and compressor configurations from user configs.""" + global _GLOBAL_COMPRESSOR + if _GLOBAL_COMPRESSOR is None: + _set_compressor() + assert _GLOBAL_COMPRESSOR is not None + for flag_type in self.flags: if flag_type == FlagType.INVALID_FLAG: continue - val = True if configs.get(flag_type.name, "False").lower() == "true" else False + val = str(configs.get(flag_type.name, False)).lower() == "true" for i in range(1, self.num_layers + 1): self.flags[flag_type][i] = val - if comp_configs.get(flag_type.name, None): - specific_comp_config = comp_configs[flag_type.name] + + specific_comp_config = comp_configs.get(flag_type.name) + if specific_comp_config is not None: compressor_type = specific_comp_config.get("compressor_type", "EmptyCompressor") compressor_configs = specific_comp_config.get("compressor_configs", {}) - _GLOBAL_COMPRESSOR[flag_type] = COMPRESSOR_MAP.get(compressor_type, EmptyCompressor)(compressor_configs) + compressor_cls = COMPRESSOR_MAP.get(compressor_type, EmptyCompressor) + _GLOBAL_COMPRESSOR[flag_type] = compressor_cls(compressor_configs) + class TTHookManager: + """Manage forward hooks that gather and report tensors for visualization.""" + def __init__(self, args, model) -> None: self.hooks = [] # the type of model should be GPTModel from megatron.core.models.gpt import GPTModel + model = model[0].module.module assert isinstance(model, GPTModel), f"{model}, {type(model)}" + def generate_hook_transpose_col(flag_type: FlagType, layer_number: int): def hook(module, input, output): if get_tt_flags().get_flag(flag_type, layer_number): device = torch.cuda.current_device() world_size = get_tensor_model_parallel_world_size() rank = get_tensor_model_parallel_rank() - rank0_global = torch.distributed.get_process_group_ranks(get_tensor_model_parallel_group())[0] + rank0_global = torch.distributed.get_process_group_ranks( + get_tensor_model_parallel_group() + )[0] if isinstance(output, (list, tuple)): tensor_data = output[0].detach() else: tensor_data = output.detach() - tensor_data = get_compressor(flag_type).compress_one_rank(layer_number, flag_type, tensor_data) + tensor_data = get_compressor(flag_type).compress_one_rank( + layer_number, flag_type, tensor_data + ) tensor_data_cont = tensor_data.contiguous() if rank == 0: - tensor_list = [torch.zeros_like(tensor_data_cont, dtype=tensor_data_cont.dtype, device=device) for _ in range(world_size)] + tensor_list = [ + torch.zeros_like( + tensor_data_cont, dtype=tensor_data_cont.dtype, device=device + ) + for _ in range(world_size) + ] else: tensor_list = None if world_size > 1: - torch.distributed.gather(tensor_data_cont, tensor_list, dst=rank0_global, group=get_tensor_model_parallel_group()) + torch.distributed.gather( + tensor_data_cont, + tensor_list, + dst=rank0_global, + group=get_tensor_model_parallel_group(), + ) else: tensor_list = [tensor_data_cont] - + if rank == 0: aggregated_tensor = None @@ -233,8 +360,11 @@ def hook(module, input, output): aggregated_tensor = tensor_data_cont else: aggregated_tensor = torch.cat(tensor_list, dim=2) - - get_tensor_tracers().report((layer_number, flag_type), aggregated_tensor.transpose(0, 1)) + + get_tensor_tracers().report( + (layer_number, flag_type), aggregated_tensor.transpose(0, 1) + ) + return hook def generate_hook_transpose_row(flag_type: FlagType, layer_number: int): @@ -243,35 +373,56 @@ def hook(module, input, output): device = torch.cuda.current_device() world_size = get_tensor_model_parallel_world_size() rank = get_tensor_model_parallel_rank() - rank0_global = torch.distributed.get_process_group_ranks(get_tensor_model_parallel_group())[0] + rank0_global = torch.distributed.get_process_group_ranks( + get_tensor_model_parallel_group() + )[0] if args.sequence_parallel: if isinstance(output, (list, tuple)): tensor_data = output[0].detach() else: tensor_data = output.detach() - tensor_data = get_compressor(flag_type).compress_one_rank(layer_number, flag_type, tensor_data) + tensor_data = get_compressor(flag_type).compress_one_rank( + layer_number, flag_type, tensor_data + ) tensor_data_cont = tensor_data.contiguous() if rank == 0: - tensor_list = [torch.zeros_like(tensor_data_cont, dtype=tensor_data_cont.dtype, device=device) for _ in range(world_size)] + tensor_list = [ + torch.zeros_like( + tensor_data_cont, dtype=tensor_data_cont.dtype, device=device + ) + for _ in range(world_size) + ] else: tensor_list = None if world_size > 1: - torch.distributed.gather(tensor_data_cont, tensor_list, dst=rank0_global, group=get_tensor_model_parallel_group()) + torch.distributed.gather( + tensor_data_cont, + tensor_list, + dst=rank0_global, + group=get_tensor_model_parallel_group(), + ) else: tensor_list = [tensor_data_cont] - + if rank == 0: aggregated_tensor = torch.cat(tensor_list, dim=0) - get_tensor_tracers().report((layer_number, flag_type), aggregated_tensor.transpose(0, 1)) + get_tensor_tracers().report( + (layer_number, flag_type), aggregated_tensor.transpose(0, 1) + ) else: if rank == 0: if isinstance(output, (list, tuple)): tensor_data = output[0].detach() else: tensor_data = output.detach() - tensor_data = get_compressor(flag_type).compress_one_rank(layer_number, flag_type, tensor_data) - get_tensor_tracers().report((layer_number, flag_type), tensor_data.transpose(0, 1)) + tensor_data = get_compressor(flag_type).compress_one_rank( + layer_number, flag_type, tensor_data + ) + get_tensor_tracers().report( + (layer_number, flag_type), tensor_data.transpose(0, 1) + ) + return hook def generate_hook_attn(flag_type: FlagType, layer_number: int): @@ -280,35 +431,78 @@ def hook(module, input, output): device = torch.cuda.current_device() world_size = get_tensor_model_parallel_world_size() rank = get_tensor_model_parallel_rank() - rank0_global = torch.distributed.get_process_group_ranks(get_tensor_model_parallel_group())[0] + rank0_global = torch.distributed.get_process_group_ranks( + get_tensor_model_parallel_group() + )[0] if isinstance(output, (list, tuple)): tensor_data = output[0].detach() else: tensor_data = output.detach() - tensor_data = get_compressor(flag_type).compress_one_rank(layer_number, flag_type, tensor_data) + tensor_data = get_compressor(flag_type).compress_one_rank( + layer_number, flag_type, tensor_data + ) tensor_data_cont = tensor_data.contiguous() if rank == 0: - tensor_list = [torch.zeros_like(tensor_data_cont, dtype=tensor_data_cont.dtype, device=device) for _ in range(world_size)] + tensor_list = [ + torch.zeros_like( + tensor_data_cont, dtype=tensor_data_cont.dtype, device=device + ) + for _ in range(world_size) + ] else: tensor_list = None if world_size > 1: - torch.distributed.gather(tensor_data_cont, tensor_list, dst=rank0_global, group=get_tensor_model_parallel_group()) + torch.distributed.gather( + tensor_data_cont, + tensor_list, + dst=rank0_global, + group=get_tensor_model_parallel_group(), + ) else: tensor_list = [tensor_data_cont] - + if rank == 0: aggregated_tensor = torch.cat(tensor_list, dim=1) get_tensor_tracers().report((layer_number, flag_type), aggregated_tensor) + return hook + for layer in range(model.decoder.num_layers_per_pipeline_rank): global_layer_number = model.decoder.layers[layer].layer_number - self.hooks.append(model.decoder.layers[layer].self_attention.linear_qkv.register_forward_hook(generate_hook_transpose_col(FlagType.QKV_mat_mul, global_layer_number))) # Col, not gather_output - self.hooks.append(model.decoder.layers[layer].mlp.linear_fc1.register_forward_hook(generate_hook_transpose_col(FlagType.MLP1_mat_mul, global_layer_number))) # Col, not gather_output - self.hooks.append(model.decoder.layers[layer].mlp.linear_fc2.register_forward_hook(generate_hook_transpose_row(FlagType.MLP2_mat_mul, global_layer_number))) # Row - self.hooks.append(model.decoder.layers[layer].self_attention.register_forward_hook(generate_hook_transpose_row(FlagType.AttentionOutput_mat_mul, global_layer_number))) # Row - self.hooks.append(model.decoder.layers[layer].register_forward_hook(generate_hook_transpose_row(FlagType.HiddenStates, global_layer_number))) # Row - self.hooks.append(model.decoder.layers[layer].self_attention.core_attention.register_forward_hook(generate_hook_transpose_col(FlagType.ContextLayer_mat_mul, global_layer_number))) # Col, not gather_output + self.hooks.append( + model.decoder.layers[layer].self_attention.linear_qkv.register_forward_hook( + generate_hook_transpose_col(FlagType.QKV_mat_mul, global_layer_number) + ) + ) # Col, not gather_output + self.hooks.append( + model.decoder.layers[layer].mlp.linear_fc1.register_forward_hook( + generate_hook_transpose_col(FlagType.MLP1_mat_mul, global_layer_number) + ) + ) # Col, not gather_output + self.hooks.append( + model.decoder.layers[layer].mlp.linear_fc2.register_forward_hook( + generate_hook_transpose_row(FlagType.MLP2_mat_mul, global_layer_number) + ) + ) # Row + self.hooks.append( + model.decoder.layers[layer].self_attention.register_forward_hook( + generate_hook_transpose_row( + FlagType.AttentionOutput_mat_mul, global_layer_number + ) + ) + ) # Row + self.hooks.append( + model.decoder.layers[layer].register_forward_hook( + generate_hook_transpose_row(FlagType.HiddenStates, global_layer_number) + ) + ) # Row + self.hooks.append( + model.decoder.layers[layer].self_attention.core_attention.register_forward_hook( + generate_hook_transpose_col(FlagType.ContextLayer_mat_mul, global_layer_number) + ) + ) # Col, not gather_output + ''' For ColumnParallelLinear: @@ -317,4 +511,4 @@ def hook(module, input, output): For RowParallelLinear: 1. If sequence_parallel, we do all gather 2. If not sequence_parallel, we do not do all gather -''' \ No newline at end of file +''' From fc4f6961dbf80295e43534e9f042612cff998ade Mon Sep 17 00:00:00 2001 From: Chen Shuo <211250172@smail.nju.edu.cn> Date: Sat, 21 Feb 2026 23:23:48 +0800 Subject: [PATCH 07/10] test: add tensor tracer unit coverage --- tests/unit_tests/test_tensor_tracer.py | 69 ++++++++++++++++++++++++++ 1 file changed, 69 insertions(+) create mode 100644 tests/unit_tests/test_tensor_tracer.py diff --git a/tests/unit_tests/test_tensor_tracer.py b/tests/unit_tests/test_tensor_tracer.py new file mode 100644 index 00000000000..1fc975b5799 --- /dev/null +++ b/tests/unit_tests/test_tensor_tracer.py @@ -0,0 +1,69 @@ +import importlib +from types import SimpleNamespace + +import torch + + +def test_default_compressors_exist_for_all_flag_types() -> None: + import megatron.core.tensor_tracer as tt + + importlib.reload(tt) + tt._set_compressor() + + for flag_type in tt.FlagType: + compressor = tt.get_compressor(flag_type) + assert isinstance(compressor, tt.AbstractCompressor) + + +def test_ttflags_set_by_configs_sets_flags_and_compressors() -> None: + import megatron.core.tensor_tracer as tt + + importlib.reload(tt) + tt._set_compressor() + + args = SimpleNamespace(num_layers=2) + flags = tt.TTFlags(args) + flags.set_by_configs( + {"QKV_mat_mul": "true", "MLP1_mat_mul": "false"}, + {"QKV_mat_mul": {"compressor_type": "NoOpCompressor", "compressor_configs": {}}}, + ) + + assert flags.get_flag(tt.FlagType.QKV_mat_mul, 1) is True + assert flags.get_flag(tt.FlagType.QKV_mat_mul, 2) is True + assert flags.get_flag(tt.FlagType.MLP1_mat_mul, 1) is False + + assert isinstance(tt.get_compressor(tt.FlagType.QKV_mat_mul), tt.NoOpCompressor) + assert isinstance(tt.get_compressor(tt.FlagType.MLP1_mat_mul), tt.EmptyCompressor) + + +def test_tile_compressor_compress_shapes() -> None: + import megatron.core.tensor_tracer as tt + + compressor = tt.TileCompressor({"tiles": 4, "tiles_one_rank": 4}) + data = torch.ones(2, 3, 10) + valid, shape, payload = compressor.compress(1, tt.FlagType.MLP1_mat_mul, data) + + assert valid is True + assert shape == [2, 3, 4] + assert payload.numel() == 2 * 3 * 4 + + +def test_tensor_tracers_skips_invalid_compressor_result() -> None: + import megatron.core.tensor_tracer as tt + + importlib.reload(tt) + + class BadCompressor(tt.AbstractCompressor): + def compress_one_rank(self, layer_number, flag_type, data): + return data + + def compress(self, layer_number, flag_type, data): + return False, [], torch.tensor([]) + + tt._GLOBAL_COMPRESSOR = {flag_type: BadCompressor() for flag_type in tt.FlagType} + called = [] + tt.set_report(lambda name, args, tensor: called.append(name)) + + tracer = tt.TensorTracers() + tracer.report((1, tt.FlagType.QKV_mat_mul), torch.zeros(1, 1, 1)) + assert called == [] From c0dc1e3dbe2594f92bbedcabe65c049f60d2e04b Mon Sep 17 00:00:00 2001 From: Chen Shuo <211250172@smail.nju.edu.cn> Date: Sat, 21 Feb 2026 23:39:24 +0800 Subject: [PATCH 08/10] fix: make empty compressor gather-safe --- megatron/core/tensor_tracer.py | 15 +++++++++++---- tests/unit_tests/test_tensor_tracer.py | 8 +++++++- 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/megatron/core/tensor_tracer.py b/megatron/core/tensor_tracer.py index bba9add16b1..af1ed6eb372 100644 --- a/megatron/core/tensor_tracer.py +++ b/megatron/core/tensor_tracer.py @@ -182,12 +182,19 @@ def __init__(self, configs): pass def compress_one_rank(self, layer_number, flag_type, data): - """Return an empty tensor.""" - return torch.tensor([]) + """Return an empty tensor that is safe for downstream gather/cat ops.""" + empty_shape = list(data.shape) + if empty_shape: + empty_shape[-1] = 0 + return data.new_empty(empty_shape) def compress(self, layer_number, flag_type, data): - """Return an empty flattened tensor.""" - return True, [0], torch.tensor([]) + """Return an empty flattened tensor with a shape matching the input.""" + empty_shape = list(data.shape) + if empty_shape: + empty_shape[-1] = 0 + empty = data.new_empty(empty_shape) + return True, empty_shape, empty.flatten() class ProjectionCompressor(AbstractCompressor): diff --git a/tests/unit_tests/test_tensor_tracer.py b/tests/unit_tests/test_tensor_tracer.py index 1fc975b5799..ea29cc2a233 100644 --- a/tests/unit_tests/test_tensor_tracer.py +++ b/tests/unit_tests/test_tensor_tracer.py @@ -33,7 +33,13 @@ def test_ttflags_set_by_configs_sets_flags_and_compressors() -> None: assert flags.get_flag(tt.FlagType.MLP1_mat_mul, 1) is False assert isinstance(tt.get_compressor(tt.FlagType.QKV_mat_mul), tt.NoOpCompressor) - assert isinstance(tt.get_compressor(tt.FlagType.MLP1_mat_mul), tt.EmptyCompressor) + empty_compressor = tt.get_compressor(tt.FlagType.MLP1_mat_mul) + assert isinstance(empty_compressor, tt.EmptyCompressor) + + sample = torch.zeros(2, 3, 4) + empty_sample = empty_compressor.compress_one_rank(1, tt.FlagType.MLP1_mat_mul, sample) + assert empty_sample.shape == (2, 3, 0) + assert empty_sample.device == sample.device def test_tile_compressor_compress_shapes() -> None: From 5b89b69381ebc0365c30d47c0e05fb858d02baec Mon Sep 17 00:00:00 2001 From: superay <1420782034@qq.com> Date: Wed, 25 Feb 2026 23:07:44 +0800 Subject: [PATCH 09/10] feat: add InputTokens flag and corresponding hook for tensor tracing, and remove max_size limit for websocket connections in training_wsserver.py --- megatron/core/tensor_tracer.py | 23 +++++++++++++++++++++++ megatron/training/training_wsserver.py | 5 +++-- 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/megatron/core/tensor_tracer.py b/megatron/core/tensor_tracer.py index af1ed6eb372..20959ed0b4a 100644 --- a/megatron/core/tensor_tracer.py +++ b/megatron/core/tensor_tracer.py @@ -100,6 +100,7 @@ class FlagType(Enum): MLP2_mat_mul = 5 AttentionOutput_mat_mul = 6 HiddenStates = 7 + InputTokens = 8 class AbstractCompressor: @@ -271,6 +272,7 @@ def __init__(self, args): FlagType.MLP2_mat_mul: {i: False for i in range(1, self.num_layers + 1)}, FlagType.AttentionOutput_mat_mul: {i: False for i in range(1, self.num_layers + 1)}, FlagType.HiddenStates: {i: False for i in range(1, self.num_layers + 1)}, + FlagType.InputTokens: {0: True}, } self.should_trace = True @@ -299,6 +301,8 @@ def set_by_configs(self, configs: Dict[str, Any], comp_configs: Dict[str, Any]): compressor_cls = COMPRESSOR_MAP.get(compressor_type, EmptyCompressor) _GLOBAL_COMPRESSOR[flag_type] = compressor_cls(compressor_configs) + _GLOBAL_COMPRESSOR[FlagType.InputTokens] = NoOpCompressor({}) + class TTHookManager: """Manage forward hooks that gather and report tensors for visualization.""" @@ -475,6 +479,25 @@ def hook(module, input, output): return hook + def generate_hook_input(flag_type: FlagType, layer_number: int): + def hook(module, args, kwargs, output): + if get_tt_flags().get_flag(flag_type, layer_number): + device = torch.cuda.current_device() + world_size = get_tensor_model_parallel_world_size() + rank = get_tensor_model_parallel_rank() + rank0_global = torch.distributed.get_process_group_ranks(get_tensor_model_parallel_group())[0] + + if rank == 0: + input_ids = kwargs["input_ids"] + position_ids = kwargs["position_ids"] + combined_input = torch.stack([input_ids, position_ids], dim=0) + tensor_data = get_compressor(flag_type).compress_one_rank(layer_number, flag_type, combined_input) + get_tensor_tracers().report((layer_number, flag_type), tensor_data) + return hook + + if hasattr(model, "embedding"): + self.hooks.append(model.embedding.register_forward_hook(generate_hook_input(FlagType.InputTokens, 0), with_kwargs=True)) # Row + for layer in range(model.decoder.num_layers_per_pipeline_rank): global_layer_number = model.decoder.layers[layer].layer_number self.hooks.append( diff --git a/megatron/training/training_wsserver.py b/megatron/training/training_wsserver.py index b8555394e8b..1aeb3c278d1 100644 --- a/megatron/training/training_wsserver.py +++ b/megatron/training/training_wsserver.py @@ -36,7 +36,7 @@ def websocket_worker_process(master_addr: str, port: int, rank: int, data_queue: while not shutdown_event.is_set(): try: - with connect(uri) as websocket: + with connect(uri, max_size=None) as websocket: print(f"Rank {rank} (Worker): Connected.", flush=True) while not shutdown_event.is_set(): @@ -156,7 +156,8 @@ def shutdown_handler(): try: with serve( _websocket_handler, "0.0.0.0", port, - ping_interval=None, reuse_port=True + ping_interval=None, reuse_port=True, + max_size=None, ) as server_instance: server = server_instance server.serve_forever() From c6182db3fc1e75651ddb9337e82efed1afbbd7dc Mon Sep 17 00:00:00 2001 From: superay <1420782034@qq.com> Date: Thu, 26 Feb 2026 12:43:45 +0800 Subject: [PATCH 10/10] doc: add Tensor Tracer documentation and update related components for improved functionality --- docs/api-guide/tensor_tracer.md | 217 +++++++++++++++++++++++++ megatron/core/tensor_tracer.py | 6 +- megatron/training/training.py | 10 +- megatron/training/training_wsserver.py | 14 +- 4 files changed, 228 insertions(+), 19 deletions(-) create mode 100644 docs/api-guide/tensor_tracer.md diff --git a/docs/api-guide/tensor_tracer.md b/docs/api-guide/tensor_tracer.md new file mode 100644 index 00000000000..96c8c6ba1c6 --- /dev/null +++ b/docs/api-guide/tensor_tracer.md @@ -0,0 +1,217 @@ +# Tensor Tracer + +This document describes the experimental **Tensor Tracer** feature implemented on the Megatron-LM `dev` branch. +Tensor Tracer can stream selected intermediate tensors during training to a frontend via WebSockets for live +visualization and debugging. + +## Enable / Install + +Tensor Tracer is **disabled by default**. + +1. Install the optional dependency: + +```bash +pip install -e '.[tensor_tracer]' +``` + +2. Enable the tracer by passing a port: + +```bash +... --tensor-tracer-port 8765 +``` + +If `websockets` is not installed and the tracer is enabled, training fails fast with a clear error message. + +## High-level architecture + +### Processes + +When `--tensor-tracer-port` is set: +- **Rank 0** starts a WebSocket “hub” server (listens on `0.0.0.0:`). +- **Other ranks** in the same **data-parallel replica** (specifically: ranks where `tp_rank == 0` and + `dp_rank == 0`) start a WebSocket worker client that connects to the hub at `ws://$MASTER_ADDR:`. + +Notes: +- Tracing is currently **disabled on data-parallel replicas where `dp_rank != 0`** (to avoid duplicated updates and + excessive overhead when using DP>1). + +### Data path + +1. Forward hooks capture tensors on each TP rank with minimal intrusion to the original code paths. +2. TP ranks gather to their TP-group rank 0 and produce an aggregated tensor. +3. The tracer applies an optional compression step and ships the payload to: + - Rank 0 frontend connection (if local), or + - Rank 0 hub via worker client connection. +4. Rank 0 forwards updates to the frontend. + +Special case: +- `InputTokens` is produced only on TP-rank 0 (no TP gather). It reports the current `input_ids` and `position_ids` + (stacked) for post-processing/debugging. + +## Protocol (frontend ↔ rank0 hub) + +### Frontend initiates control + +The frontend must send a message of type `run_training_step` to claim control and start training. + +Notes: +- The current implementation consumes the config once at training startup (it is broadcast to ranks). Dynamic + reconfiguration mid-run is not supported yet. + +Example: + +```json +{ + "type": "run_training_step", + "visualization_flags": { + "QKV_mat_mul": "true", + "MLP1_mat_mul": "false" + }, + "compressor_config": { + "QKV_mat_mul": { + "compressor_type": "TileCompressor", + "compressor_configs": { + "tiles": 96, + "method": "data.mean(dim=-1)", + "tiles_one_rank": 96, + "method_one_rank": "data.mean(dim=-1)" + } + } + } +} +``` + +The hub responds with an initial `start` payload: + +```json +{ + "type": "start", + "micro_batch_size": 1, + "seq_length": 4096, + "num_layers": 32 +} +``` + +### Updates + +Updates are emitted as: + +```json +{ + "type": "update", + "update_type": 1, + "layer_id": 12, + "args": [2, 3, 96], + "result": [0.1, 0.2, 0.3] +} +``` + +Where: +- `update_type` is the numeric value of `FlagType` (e.g., `QKV_mat_mul = 1`). +- `layer_id` is the global layer number (1-based). `InputTokens` uses `layer_id = 0`. +- `args` are compressor-specific metadata (e.g., the compressed shape). +- `result` is a flattened numeric payload. + +## Configuration schema + +### `visualization_flags` + +Map from `FlagType` names to truthy strings / booleans. + +Supported keys (see `megatron/core/tensor_tracer.py`): +- `QKV_mat_mul` +- `ContextLayer_mat_mul` +- `MLP1_mat_mul` +- `MLP2_mat_mul` +- `AttentionOutput_mat_mul` +- `HiddenStates` +- `InputTokens` (special: uses `layer_id=0`) + +### `compressor_config` + +Map from `FlagType` names to: +- `compressor_type`: `TileCompressor | NoOpCompressor | EmptyCompressor | ProjectionCompressor` +- `compressor_configs`: dict of compressor-specific config. + +Notes: +- `InputTokens` always uses `NoOpCompressor` (its payload is small and meant for token-level indexing). + +## Compressor notes + +### TileCompressor + +TileCompressor reshapes the tensor into tiles along the last dimension, then applies a reduction. + +The reduction expression is a Python expression evaluated with a single variable: +- `data`: tensor shaped `[B, S, tiles, chunk_size]` + +Default reduction: +- `data.mean(dim=-1)` + +### ProjectionCompressor + +ProjectionCompressor loads a per-layer projection vector (via `torch.load`) and projects each tensor onto it. + +Expected `compressor_configs`: +- `vector_path`: path to a torch-saved tensor of shape `[num_layers, hidden_size]` (or compatible). + +## Performance considerations + +Tracing involves additional overhead from: +- Distributed gather across the tensor-parallel group. +- Optional compression. +- CPU transfer before JSON serialization. + +Recommended usage: +- Enable tracing for a small subset of layers and flags. +- Use compression to reduce payload size. + +An experiment with QKV, MLP1, and MLP2 output compression (TileCompressor with mean reduction over hidden dimension) shows a ~3% overhead compared to no tracing. Overhead can be further reduced by selecting fewer trace points and using more aggressive compression. + +## Security / trust model + +Tensor Tracer assumes configs and artifacts are provided by trusted operators: +- TileCompressor evaluates a user-provided expression (with builtins removed), which should still be treated as + untrusted for adversarial environments. +- ProjectionCompressor loads a vector using `torch.load`, which is unsafe for untrusted files. + +## Known limitations + +- Hooks currently target a GPT model and assume a specific wrapper structure in `TTHookManager`. +- Only the forward step is traced (by design), not backward. +- The tracer is designed for monitoring/visualization and introduces little overhead when enabled, but it can be avoided entirely when disabled. + +## Example: persona-vector projection monitoring + +`ProjectionCompressor` can be used to monitor a scalar projection of hidden states across layers during training or +fine-tuning. + +One practical use case is monitoring **emergent misalignment** ([paper 1](https://arxiv.org/abs/2502.17424), [paper 2](https://arxiv.org/abs/2506.11613)) signals by projecting per-token hidden states onto a +pre-computed **persona vector** ([paper](https://arxiv.org/abs/2507.21509)) and tracking the trend over training steps (for example, by averaging over a set of +token positions in an evaluation prompt). + +High-level workflow: +1. Fine-tune a model (e.g., Llama3-8B-Instruct) on a dataset of interest (e.g., an emergent-misalignment related dataset `risky_financial_advice`) with the tracer enabled. +2. Periodically run an evaluation forward pass (via the normal Megatron evaluation loop). +3. Enable `HiddenStates` tracing with `ProjectionCompressor`, pointing at a torch-saved vector file shaped like + `[num_layers, hidden_size]` which contains the persona vector across layers (e.g., evil persona vector). +4. Aggregate the projected scalar values in your frontend / post-processing script and visualize per-layer trends. + +Minimal config snippet (frontend → hub): + +```json +{ + "type": "run_training_step", + "visualization_flags": { + "HiddenStates": true + }, + "compressor_config": { + "HiddenStates": { + "compressor_type": "ProjectionCompressor", + "compressor_configs": { + "vector_path": "/path/to/persona_vector.pt" + } + } + } +} +``` diff --git a/megatron/core/tensor_tracer.py b/megatron/core/tensor_tracer.py index 20959ed0b4a..6854378ae29 100644 --- a/megatron/core/tensor_tracer.py +++ b/megatron/core/tensor_tracer.py @@ -1,3 +1,5 @@ +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + import logging import math from abc import abstractmethod @@ -100,7 +102,7 @@ class FlagType(Enum): MLP2_mat_mul = 5 AttentionOutput_mat_mul = 6 HiddenStates = 7 - InputTokens = 8 + InputTokens = 8 class AbstractCompressor: @@ -301,8 +303,6 @@ def set_by_configs(self, configs: Dict[str, Any], comp_configs: Dict[str, Any]): compressor_cls = COMPRESSOR_MAP.get(compressor_type, EmptyCompressor) _GLOBAL_COMPRESSOR[flag_type] = compressor_cls(compressor_configs) - _GLOBAL_COMPRESSOR[FlagType.InputTokens] = NoOpCompressor({}) - class TTHookManager: """Manage forward hooks that gather and report tensors for visualization.""" diff --git a/megatron/training/training.py b/megatron/training/training.py index e0efed5b7a0..322b780f87a 100644 --- a/megatron/training/training.py +++ b/megatron/training/training.py @@ -905,7 +905,7 @@ def pretrain( model, optimizer, opt_param_scheduler = setup_model_and_optimizer( model_provider, model_type, checkpointing_context=checkpointing_context ) - if args.tensor_tracer_port is not None: + if args.tensor_tracer_port is not None and mpu.get_data_parallel_rank() == 0: from megatron.core.tensor_tracer import _set_tt_hook_manager _set_tt_hook_manager(args, model) @@ -2453,7 +2453,8 @@ def train( global_rank = torch.distributed.get_rank() tp_rank = mpu.get_tensor_model_parallel_rank() - if args.tensor_tracer_port is not None and tp_rank == 0: + dp_rank = mpu.get_data_parallel_rank() + if args.tensor_tracer_port is not None and tp_rank == 0 and dp_rank == 0: try: from .training_wsserver import websocket_server_process, websocket_worker_process except ModuleNotFoundError as exc: @@ -2510,7 +2511,7 @@ def report_func(name_tuple, report_args, tensor_data): ws_process.start() if args.tensor_tracer_port is not None: - if global_rank == 0: + if global_rank == 0 and tp_rank == 0 and dp_rank == 0: received_configs = config_queue.get() vis_flags = received_configs.get('visualization_flags', {}) comp_configs = received_configs.get('compressor_config', {}) @@ -2521,6 +2522,9 @@ def report_func(name_tuple, report_args, tensor_data): torch.distributed.broadcast_object_list(configs_to_broadcast, src=0) vis_flags, comp_configs = configs_to_broadcast + if dp_rank != 0: + vis_flags = {"InputTokens": False} + comp_configs = {} from megatron.core.tensor_tracer import get_tt_flags get_tt_flags().set_by_configs(vis_flags, comp_configs) diff --git a/megatron/training/training_wsserver.py b/megatron/training/training_wsserver.py index 1aeb3c278d1..953da7d48e0 100644 --- a/megatron/training/training_wsserver.py +++ b/megatron/training/training_wsserver.py @@ -1,16 +1,4 @@ -# Copyright 2025 Suanzhi Future Co., Ltd. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions -# and limitations under the License. +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. import json import multiprocessing