diff --git a/pyproject.toml b/pyproject.toml index e46c5ef1..1e31e48d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -76,7 +76,7 @@ nn = [ "dask-jobqueue", "dask[distributed]", "tokenizers>0.20", - "torch>=2.5,<2.7", # Temporarily pin <2.6 until someone fixes our CI with torch 2.6 + "torch>=2.6,<2.7", # Use torch 2.6 "transformers>=4.46", "wandb", ] @@ -153,7 +153,7 @@ explicit_package_bases = true mypy_path = "$MYPY_CONFIG_FILE_DIR/src,$MYPY_CONFIG_FILE_DIR/packages/lmi/src" # Specifies the OS platform for the target program, for example darwin or win32 # (meaning OS X or Windows, respectively). The default is the current platform -# as revealed by Python’s sys.platform variable. +# as revealed by Python's sys.platform variable. platform = "linux" # Comma-separated list of mypy plugins. plugins = ["pydantic.mypy"] diff --git a/src/ldp/nn/__init__.py b/src/ldp/nn/__init__.py index 77b54103..83b7b3ae 100644 --- a/src/ldp/nn/__init__.py +++ b/src/ldp/nn/__init__.py @@ -12,6 +12,7 @@ ParallelTransformerHandler, TransformerHandler, TransformerHandlerConfig, + ParallelizationStrategy, collate_fn_transformer_left_pad, collate_fn_transformer_right_pad, decollate_fn_transformer_decoder, @@ -35,6 +36,7 @@ "TorchDType", "TransformerHandler", "TransformerHandlerConfig", + "ParallelizationStrategy", "collate_fn_transformer_left_pad", "collate_fn_transformer_right_pad", "decollate_fn_transformer_decoder", diff --git a/src/ldp/nn/agent/simple_local_agent.py b/src/ldp/nn/agent/simple_local_agent.py index bf173b2a..1c67b265 100644 --- a/src/ldp/nn/agent/simple_local_agent.py +++ b/src/ldp/nn/agent/simple_local_agent.py @@ -15,6 +15,7 @@ from ldp.nn.handlers.chunking import TensorChunker from ldp.nn.handlers.transformer_handler import ( ParallelModeConfig, + ParallelizationStrategy, logits_to_logprobs, ) from ldp.nn.lm_config import LMConfig as _LMConfig @@ -31,7 +32,7 @@ class AgentLMConfig(_LMConfig): # distribution parallel_mode: ParallelModeConfig | None = None - + parallel_strategy: ParallelizationStrategy = ParallelizationStrategy.ACCELERATOR # sampling parameters temperature: float = 1.0 max_new_tokens: int = 50 @@ -80,6 +81,7 @@ def __init__( batch_size=self.llm_model.batch_size, max_wait_interval=self.llm_model.max_wait_interval, parallel_mode_config=self.llm_model.parallel_mode, + parallel_strategy=self.llm_model.parallel_strategy, ) async def init_state(self, tools: list[Tool]) -> SimpleAgentState: diff --git a/src/ldp/nn/graph/llm_call_op.py b/src/ldp/nn/graph/llm_call_op.py index cf68a5f6..cd98275e 100644 --- a/src/ldp/nn/graph/llm_call_op.py +++ b/src/ldp/nn/graph/llm_call_op.py @@ -16,6 +16,7 @@ LMType, ParallelModeConfig, TransformerHandlerConfig, + ParallelizationStrategy, collate_fn_transformer_left_pad, decollate_fn_transformer_decoder, ) @@ -40,6 +41,7 @@ def __init__( batch_size: int = 1, max_wait_interval: float = 0.1, parallel_mode_config: ParallelModeConfig | None = None, + parallel_strategy: ParallelizationStrategy = ParallelizationStrategy.ACCELERATOR, ) -> None: super().__init__() @@ -51,6 +53,7 @@ def __init__( batch_size=batch_size, max_wait_interval=max_wait_interval, parallel_mode_config=parallel_mode_config, + parallel_strategy=parallel_strategy, # constant configuration lm_type=LMType.GENERATION, module_call_fn=AsyncTransformerInterface.model_generate, diff --git a/src/ldp/nn/handlers/transformer_handler.py b/src/ldp/nn/handlers/transformer_handler.py index 42a74637..6d6e283f 100644 --- a/src/ldp/nn/handlers/transformer_handler.py +++ 
b/src/ldp/nn/handlers/transformer_handler.py @@ -6,12 +6,13 @@ import os import socket import sys +import time from abc import ABC, abstractmethod from collections.abc import Awaitable, Callable from enum import StrEnum, auto from functools import cache, partial, wraps from pathlib import Path -from typing import Any, Concatenate, ParamSpec, Self, TypeVar, assert_never +from typing import Concatenate, ParamSpec, Self, TypeVar, assert_never, Any import accelerate import torch @@ -28,6 +29,7 @@ FullStateDictConfig, FullyShardedDataParallel, MixedPrecision, + MixedPrecisionPolicy, ShardingStrategy, StateDictType, ) @@ -107,6 +109,11 @@ class FSDPConfig(BaseModel): " is PRE to be consistent with FSDP's default." ), ) + # FSDP2 specific settings + reshard_after_forward: bool = Field( + default=True, + description="Whether to free the full parameters after forward computation.", + ) @field_validator("backward_prefetch", mode="before") @classmethod @@ -131,7 +138,10 @@ class ParallelModeConfig(FSDPConfig): ), default=ExecutionMode.LOCAL_MACHINE, ) - + num_gpus_per_node: int = Field( + default=8, + description="Number of GPUs per node. Defaults to 8 for standard GPU nodes.", + ) scheduler_addr: str = "localhost" scheduler_port: int = Field(default=0, description="0 means Dask picks randomly.") torch_port: int = Field(default_factory=get_unused_port) @@ -140,7 +150,9 @@ class ParallelModeConfig(FSDPConfig): walltime: str = Field( default="00:30:00", description="Max time the worker can run." ) - memory: str = Field(default="32GB", description="Memory allocated per worker.") + memory_per_worker: str = Field( + default="32GB", description="Memory allocated per worker." + ) log_directory: str = Field( default=f"{REPO_ROOT}/logs/slurm_outputs/", description="Directory to store logs.", @@ -152,12 +164,21 @@ class LMType(StrEnum): REGRESSION = auto() +class ParallelizationStrategy(StrEnum): + ACCELERATOR = auto() # Current implementation using Accelerator + FSDP2 = auto() # New implementation using vanilla FSDP2 + + class TransformerHandlerConfig(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid") lm_config: LMConfig lm_type: LMType checkpoint: str | None = None + parallel_strategy: ParallelizationStrategy = Field( + default=ParallelizationStrategy.ACCELERATOR, + description="Which transformer implementation to use (Accelerator or FSDP2)", + ) batch_size: int max_wait_interval: float = 0.1 @@ -175,7 +196,13 @@ class TransformerHandlerConfig(BaseModel): def make_async_module(self, **kwargs) -> AsyncTransformerInterface: if self.parallel_mode_config: - return ParallelAsyncTransformer(config=self, **kwargs) + if self.parallel_strategy == ParallelizationStrategy.ACCELERATOR: + return ParallelAsyncTransformer(config=self, **kwargs) + if self.parallel_strategy == ParallelizationStrategy.FSDP2: + from .transformer_handler_fsdp2 import FSDP2ParallelAsyncTransformer + + return FSDP2ParallelAsyncTransformer(config=self, **kwargs) + raise ValueError(f"Unsupported implementation: {self.parallel_strategy}") return AsyncTransformer(config=self, **kwargs) @@ -194,19 +221,22 @@ async def __call__( # type: ignore[override] @staticmethod def model_generate(model: PreTrainedModel, *args, **kwargs): """A method that can be used as module_call_fn to sample from an LLM.""" - # Summoning params per https://github.com/pytorch/pytorch/issues/100069 - # If model is not FSDP, this context manager is a no-op. 
- with FullyShardedDataParallel.summon_full_params(model, recurse=False): - logger.debug( - f"model.generate() input_ids shape: {kwargs['input_ids'].shape}, rank" - f" {os.environ.get('RANK')}" - ) - return model.generate( - *args, - **kwargs, - pad_token_id=model.config.pad_token_id, # not always set properly by .generate() - eos_token_id=model.config.eos_token_id, + logger.info( + f"model.generate() input_ids shape: {kwargs['input_ids'].shape}, rank" + f" {os.environ.get('RANK')}" + ) + if model.training: + logger.warning( + f"Model is in training mode at rank {os.environ.get('RANK')}, setting to eval mode" ) + model.eval() + + return model.generate( + *args, + **kwargs, + pad_token_id=model.config.pad_token_id, # not always set properly by .generate() + eos_token_id=model.config.eos_token_id, + ) class TransformerHandler(ModuleHandler): @@ -235,12 +265,12 @@ def __init__(self, config: TransformerHandlerConfig): self.tokenizer, self.config.lm_config.chat_template ) - self._setup_accelerator() + self._setup_fsdp() if config.checkpoint is not None: self.load_checkpoint(config.checkpoint) - def _setup_accelerator(self): + def _setup_fsdp(self): self.accelerator = accelerate.Accelerator( # This has to be disabled because accelerator wraps forward() to upcast outputs to fp32. That # causes problems with generation, where the cache is expected to be in the same dtype as the model. @@ -255,11 +285,16 @@ def local_rank(self) -> int: def load_checkpoint(self, ckpt: os.PathLike | str, **kwargs) -> None: logger.info(f'Loading checkpoint from "{ckpt}"') + start_time = time.perf_counter() self.accelerator.load_state(str(ckpt), **kwargs) + self.barrier() + logger.info(f"Loading checkpoint took {time.perf_counter() - start_time:.2f}s") def save_checkpoint(self, ckpt: os.PathLike | str, **kwargs) -> None: + start_time = time.perf_counter() self.accelerator.save_state(str(ckpt), **kwargs) self.barrier() + logger.info(f"Saving checkpoint took {time.perf_counter() - start_time:.2f}s") # We do not want to save random states - they would be loaded by load_state # automatically. Clean up after all processes have saved. 
if int(os.getenv("RANK", "0")) == 0: @@ -336,7 +371,6 @@ class ParallelWorkerConfig(FSDPConfig): def set_env_vars(self): # These inform torch.distributed how to set up the process group - os.environ["CUDA_VISIBLE_DEVICES"] = str(self.local_rank) os.environ["RANK"] = str(self.rank) os.environ["WORLD_SIZE"] = str(self.world_size) os.environ["LOCAL_RANK"] = str(self.local_rank) @@ -348,6 +382,7 @@ def set_env_vars(self): os.environ["FSDP_CPU_RAM_EFFICIENT_LOADING"] = str( int(self.cpu_ram_efficient_loading) ) + os.environ["ACCELERATE_TORCH_DEVICE"] = f"cuda:{self.local_rank}" class ParallelTransformerHandler(TransformerHandler): @@ -357,11 +392,14 @@ def __init__( self, parallel_worker_config: ParallelWorkerConfig, ): parallel_worker_config.set_env_vars() + + torch.cuda.set_device(self.local_rank) dist.init_process_group(backend="nccl") + self.worker_config = parallel_worker_config super().__init__(config) - def _setup_accelerator(self): + def _setup_fsdp(self): bf16 = self.config.lm_config.dtype == TorchDType.bf16 mixed_precision = None @@ -382,10 +420,17 @@ def _setup_accelerator(self): buffer_dtype=torch.bfloat16, ) + mixed_precision = MixedPrecisionPolicy( + param_dtype=torch.bfloat16, + reduce_dtype=torch.bfloat16, + output_dtype=torch.bfloat16, + ) + self.accelerator = accelerate.Accelerator( - # See note in TransformerHandler._setup_accelerator() about this + # See note in TransformerHandler._setup_fsdp() about this # mixed_precision=("bf16" if bf16 else "no"), fsdp_plugin=accelerate.FullyShardedDataParallelPlugin( + fsdp_version=2, sharding_strategy=ShardingStrategy.FULL_SHARD, mixed_precision_policy=mixed_precision, auto_wrap_policy="transformer_based_wrap", @@ -395,13 +440,16 @@ def _setup_accelerator(self): sync_module_states=self.worker_config.cpu_ram_efficient_loading, state_dict_type=self.worker_config.state_dict_type, backward_prefetch=self.worker_config.backward_prefetch, + reshard_after_forward=self.worker_config.reshard_after_forward, ), ) if self.config.lm_config.device == "meta": self.module = prepare_model_for_fsdp_with_meta_device(self.module) - self.module = self.accelerator.prepare(self.module) + # TODO: evaluation_mode=True gives a perf boost. However, we then can't train; + # allow control over this param. + self.module = self.accelerator.prepare_model(self.module) def set_seed(self, seed: int) -> None: """Set the seed for the current worker.""" @@ -413,6 +461,7 @@ def _exec_func( *args, **kwargs, ) -> TReturn: + torch.cuda.set_device(self.local_rank) # data will be on CPU when sent from controller data_device = _get_data_device() to_device = partial(_move_tensor, device=data_device) @@ -460,6 +509,11 @@ def __init__(self, config: TransformerHandlerConfig): self.tokenizer, self.config.lm_config.chat_template ) + # This uses NVIDIA's NVML layer instead of native CUDA, which is more robust in GPU detection + # post initialization. This prevents issues with forked processes wrongly detecting the + # default GPU as cuda:0 + os.environ["PYTORCH_NVML_BASED_CUDA_CHECK"] = "1" + match parallel_mode_config.execution_mode: # TODO: see if we can just access `parallel_mode_config` as a # `config` attribute instead of passing both. @@ -495,58 +549,99 @@ def _init_local_cluster( # lazy import since dask-cuda only works on Linux machines from dask_cuda import LocalCUDACluster - # This uses NVIDIA's NVML layer instead of native CUDA, which is more robust in GPU detection - # post initialization. 
This prevents issues with forked processes wrongly detecting the - # default GPU as cuda:0 - os.environ["PYTORCH_NVML_BASED_CUDA_CHECK"] = "1" self.cluster = LocalCUDACluster( n_workers=parallel_mode_config.num_workers, threads_per_worker=parallel_mode_config.num_cpus_per_worker, host=parallel_mode_config.scheduler_addr, port=parallel_mode_config.scheduler_port, memory_limit=None, # do not let Dask manage memory - if we OOM, we OOM + device_memory_limit=0, # Disable GPU memory spilling. Should be handled by FSDP ) + self.cluster.scale(parallel_mode_config.num_workers) self._initialize_workers(config, parallel_mode_config) def _init_slurm_cluster( self, config: TransformerHandlerConfig, parallel_mode_config: ParallelModeConfig ): - """Initialize a SLURM-based Dask cluster with GPU allocation.""" + """Initialize a SLURM-based Dask cluster with GPU allocation. + + Note: Dask's integration with SLURM currently only supports allocating a single entire node + at a time, with each node running as a single SLURM task. This implementation adapts + to that limitation by requesting complete nodes and running multiple workers (one per GPU) + within each node. If our cluster eventually supports GRES (Generic Resource) scheduling, + this implementation could be modified to allow for more granular GPU allocation across + nodes rather than requiring full node allocation (this still needs to be tested). + """ # Lazy import because dask_jobqueue cannot be started in a subprocess, which # happens e.g. with streamlit - from dask_jobqueue import SLURMCluster + from dask_jobqueue.slurm import SLURMCluster + + # Validate that num_workers is divisible by num_gpus_per_node + num_gpus_per_node = parallel_mode_config.num_gpus_per_node + if parallel_mode_config.num_workers % num_gpus_per_node != 0: + raise ValueError( + f"Number of workers ({parallel_mode_config.num_workers}) must be divisible by " + f"num_gpus_per_node ({num_gpus_per_node}). We assume each node has {num_gpus_per_node} GPUs, " + f"and current dask-jobqueue infrastructure only supports allocating whole nodes." + ) + # TODO: add support for GRES, once available in our cluster, for partial node allocation + + # Calculate number of jobs needed (each job = 1 SLURM node with num_gpus_per_node GPUs) + num_jobs = parallel_mode_config.num_workers // num_gpus_per_node + + log_dir = parallel_mode_config.log_directory + os.makedirs(log_dir, exist_ok=True) + + # Calculate total memory needed per node (memory_per_worker * num_gpus_per_node) + memory_per_worker = parallel_mode_config.memory_per_worker + MEMORY_UNIT_LENGTH = 2 # Memory units are typically 2 chars (e.g. "GB", "MB") + value = int( + memory_per_worker[:-MEMORY_UNIT_LENGTH] + ) # Get numeric value by stripping the 2-char unit (e.g. "GB") + unit = memory_per_worker[-MEMORY_UNIT_LENGTH:] # Get unit (e.g. 
"GB") + assert len(unit) == MEMORY_UNIT_LENGTH, ( + f"Memory unit must be {MEMORY_UNIT_LENGTH} characters long, got {unit}" + ) + total_memory = f"{value * parallel_mode_config.num_gpus_per_node}{unit}" self.cluster = SLURMCluster( - cores=parallel_mode_config.num_cpus_per_worker, - memory=parallel_mode_config.memory, - processes=1, # Single dask worker per slurm worker + cores=parallel_mode_config.num_cpus_per_worker * num_gpus_per_node, + memory=total_memory, + processes=num_gpus_per_node, # Each job runs num_gpus_per_node dask workers (one per GPU) walltime=parallel_mode_config.walltime, - job_extra=[ - "--gres=gpu:1" - ], # 1 GPU per worker seems to be the common case for now - log_directory=parallel_mode_config.log_directory, + job_extra_directives=[ + "--nodes=1", # Always request 1 node per job + "--exclusive", # Exclusive node access + "--mem=0", # Use all available memory + f"--cpus-per-task={parallel_mode_config.num_cpus_per_worker}", + f"-o {log_dir}/job_%j_task_%t.out", + f"-e {log_dir}/job_%j_task_%t.err", + ], + log_directory=log_dir, ) + + # Scale jobs to the required number of jobs + self.cluster.scale(jobs=num_jobs) self._initialize_workers(config, parallel_mode_config) def _initialize_workers( self, config: TransformerHandlerConfig, parallel_mode_config: ParallelModeConfig ): - self.cluster.scale(parallel_mode_config.num_workers) self.client = Client(self.cluster) self.client.wait_for_workers(parallel_mode_config.num_workers) - def get_cuda_visible_devices() -> int | None: - device = os.environ.get("CUDA_VISIBLE_DEVICES", None) - if device is not None: - # If has several devices, assume the first one is the one to use for that worker - if "," in device: - device = device.split(",", maxsplit=1)[0] - os.environ["CUDA_VISIBLE_DEVICES"] = device - os.environ["CUDA_VISIBLE_DEVICES"] = device - return int(device) - return None - - worker_to_cuda_device = self.client.run(get_cuda_visible_devices) + # TODO: enable when gres is enabled in our cluster + # def get_cuda_visible_devices() -> int | None: + # device = os.environ.get("CUDA_VISIBLE_DEVICES", None) + # if device is not None: + # # If has several devices, assume the first one is the one to use for that worker + # if "," in device: + # device = device.split(",", maxsplit=1)[0] + # os.environ["CUDA_VISIBLE_DEVICES"] = device + # return int(device) + # return None + + # worker_to_cuda_device = self.client.run(get_cuda_visible_devices) workers_info = self.client.scheduler_info()["workers"] sorted_workers = dict( sorted(workers_info.items(), key=lambda item: item[1]["id"]) @@ -556,14 +651,12 @@ def get_cuda_visible_devices() -> int | None: futures = [] worker_ids = [] - for rank, (worker_address, worker_data) in enumerate(sorted_workers.items()): + for rank, (_, worker_data) in enumerate(sorted_workers.items()): worker_id = worker_data["id"] - worker_cuda_device = worker_to_cuda_device[worker_address] - if worker_cuda_device is None: - assert ( - parallel_mode_config.execution_mode != ExecutionMode.SLURM_CLUSTER - ), "CUDA_VISIBLE_DEVICES should be pre set for SLURM workers." 
- worker_cuda_device = rank + # On some occasions, dask SLURM integration auto assigns CUDA_VISIBLE_DEVICES, otherwise we set it here + # worker_cuda_device = worker_to_cuda_device[worker_address] + # if worker_cuda_device is None: + worker_cuda_device = rank % parallel_mode_config.num_gpus_per_node parallel_worker_config = ParallelWorkerConfig( rank=rank, @@ -574,7 +667,7 @@ def get_cuda_visible_devices() -> int | None: **parallel_mode_config.model_dump(), ) future_op = self.client.submit( - ParallelTransformerHandler, + self._get_parallel_transformer_handler_cls(), config=config, parallel_worker_config=parallel_worker_config, workers=[worker_id], @@ -586,6 +679,9 @@ def get_cuda_visible_devices() -> int | None: self.actors: list[Actor] = self._client_gather(futures) self.worker_ids = worker_ids + def _get_parallel_transformer_handler_cls(self): + return ParallelTransformerHandler + async def __call__( self, inputs: str | BatchEncoding | list[dict] | None = None, diff --git a/src/ldp/nn/handlers/transformer_handler_fsdp2.py b/src/ldp/nn/handlers/transformer_handler_fsdp2.py new file mode 100644 index 00000000..98e65f8b --- /dev/null +++ b/src/ldp/nn/handlers/transformer_handler_fsdp2.py @@ -0,0 +1,128 @@ +from __future__ import annotations + +import logging +import os +import time +from pathlib import Path + +import torch +import torch.distributed as dist + +try: + assert torch.__version__ >= "2.6.0", "FSDP2 requires PyTorch 2.6.0 or higher" + from torch.distributed.fsdp import ( + CPUOffloadPolicy, + MixedPrecisionPolicy, + OffloadPolicy, + fully_shard, + register_fsdp_forward_method, + ) +except (ImportError, AssertionError) as e: + raise ImportError(f"FSDP2 requires PyTorch 2.6.0 or higher: {e}") from e + +from ldp.nn.lm_config import TorchDType + +from .transformer_handler import ( + FSDPConfig, + ParallelAsyncTransformer, + ParallelTransformerHandler, +) + +logger = logging.getLogger(__name__) + + +class FSDP2ParallelTransformerHandler(ParallelTransformerHandler): + def _setup_fsdp(self): + """Set up FSDP2 module wrapping.""" + if not dist.is_initialized(): + # For single device usage, just move to device directly + device = self.config.lm_config.device + self.module = self.module.to(device) + return + + # Setup mixed precision policy if needed + mp_policy = None + logger.info(f"Setting up FSDP2 with dtype {self.config.lm_config.dtype}") + if self.config.lm_config.dtype == TorchDType.bf16: + logger.info("Setting up FSDP2 with bfloat16 dtype") + mp_policy = MixedPrecisionPolicy( + param_dtype=torch.bfloat16, + reduce_dtype=torch.bfloat16, + output_dtype=torch.bfloat16, + ) + + # Apply FSDP2 wrapping using new API + fsdp_config = self.config.parallel_mode_config or FSDPConfig() + offload_policy = ( + CPUOffloadPolicy(pin_memory=True) + if fsdp_config.offload_cpu + else OffloadPolicy() + ) + + self.module = fully_shard( + self.module, + mesh=None, # Maybe we activate it later, see https://pytorch.org/docs/stable/distributed.html#torch.distributed.device_mesh.DeviceMesh + reshard_after_forward=fsdp_config.reshard_after_forward, + mp_policy=mp_policy, + offload_policy=offload_policy, + ) + + # Register model.generate as an FSDP forward method to handle generation correctly + register_fsdp_forward_method(self.module, "generate") + + def load_checkpoint(self, ckpt: os.PathLike | str, **kwargs) -> None: + logger.info(f'Loading checkpoint from "{ckpt}"') + start_time = time.perf_counter() + ckpt_path = Path(ckpt) + if ckpt_path.is_dir(): + # Assume it's a directory containing sharded state 
dict + state_dict = torch.load( + ckpt_path / f"rank{self.local_rank}_checkpoint.pt", map_location="cpu" + ) + self.module.load_state_dict(state_dict) + else: + # Assume it's a single file containing a full state dict + state_dict = torch.load(ckpt, map_location="cpu") + # Load the state dict - will automatically handle the sharding + self.module.load_state_dict(state_dict) + + self.barrier() + logger.info(f"Loading checkpoint took {time.perf_counter() - start_time:.2f}s") + + def save_checkpoint(self, ckpt: os.PathLike | str, **kwargs) -> None: + start_time = time.perf_counter() + ckpt_path = Path(ckpt) + ckpt_path.mkdir(parents=True, exist_ok=True) + state_dict = self.module.state_dict() + if dist.is_initialized(): + torch.save(state_dict, ckpt_path / f"rank{self.local_rank}_checkpoint.pt") + else: + torch.save(state_dict, ckpt_path / "checkpoint.pt") + + self.barrier() + logger.info(f"Saving checkpoint took {time.perf_counter() - start_time:.2f}s") + + +class FSDP2ParallelAsyncTransformer(ParallelAsyncTransformer): + def _get_parallel_transformer_handler_cls(self): + return FSDP2ParallelTransformerHandler + + def state_dict(self, **kwargs) -> dict[str, torch.Tensor]: + """Get consolidated state dict from all workers. + + With FSDP2, we need to manually consolidate the state dict + """ + + def state_dict_worker( + handler: ParallelTransformerHandler, + ) -> dict[str, torch.Tensor]: + state_dict = handler.module.state_dict() + # Convert DTensors to full tensors + for key, tensor in state_dict.items(): + if hasattr(tensor, "full_tensor"): + state_dict[key] = tensor.full_tensor() + return state_dict + + # Only need the state dict from rank 0 + state_dict = self._submit_and_gather(state_dict_worker, **kwargs)[0] + return {k: v.cpu() for k, v in state_dict.items()} diff --git a/uv.lock b/uv.lock index 6f3a8390..f5e7d179 100644 --- a/uv.lock +++ b/uv.lock @@ -1551,7 +1551,7 @@ requires-dist = [ { name = "tenacity" }, { name = "tiktoken" }, { name = "tokenizers", marker = "extra == 'nn'", specifier = ">0.20" }, - { name = "torch", marker = "extra == 'nn'", specifier = ">=2.5,<2.7" }, + { name = "torch", marker = "extra == 'nn'", specifier = ">=2.6,<2.7" }, { name = "tqdm" }, { name = "tqdm", marker = "extra == 'rich'", specifier = ">=4.56" }, { name = "transformers", marker = "extra == 'nn'", specifier = ">=4.46" }, @@ -2220,6 +2220,14 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/db/f7/97a9ea26ed4bbbfc2d470994b8b4f338ef663be97b8f677519ac195e113d/nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl", hash = "sha256:ea4f11a2904e2a8dc4b1833cc1b5181cde564edd0d5cd33e3c168eff2d1863f1", size = 207454763 }, ] +[[package]] +name = "nvidia-cusparselt-cu12" +version = "0.6.2" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/78/a8/bcbb63b53a4b1234feeafb65544ee55495e1bb37ec31b999b963cbccfd1d/nvidia_cusparselt_cu12-0.6.2-py3-none-manylinux2014_x86_64.whl", hash = "sha256:df2c24502fd76ebafe7457dbc4716b2fec071aabaed4fb7691a201cde03704d9", size = 150057751 }, +] + [[package]] name = "nvidia-nccl-cu12" version = "2.21.5" @@ -3651,7 +3659,7 @@ wheels = [ [[package]] name = "torch" -version = "2.5.1" +version = "2.6.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "filelock" }, @@ -3667,24 +3675,28 @@ dependencies = [ { name = "nvidia-curand-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "nvidia-cusolver-cu12", marker = "platform_machine == 'x86_64' 
and sys_platform == 'linux'" }, { name = "nvidia-cusparse-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cusparselt-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "nvidia-nccl-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "nvidia-nvjitlink-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "nvidia-nvtx-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "setuptools", marker = "python_full_version >= '3.12'" }, { name = "sympy" }, - { name = "triton", marker = "python_full_version < '3.13' and platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "triton", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "typing-extensions" }, ] wheels = [ - { url = "https://files.pythonhosted.org/packages/d1/35/e8b2daf02ce933e4518e6f5682c72fd0ed66c15910ea1fb4168f442b71c4/torch-2.5.1-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:de5b7d6740c4b636ef4db92be922f0edc425b65ed78c5076c43c42d362a45457", size = 906474467 }, - { url = "https://files.pythonhosted.org/packages/40/04/bd91593a4ca178ece93ca55f27e2783aa524aaccbfda66831d59a054c31e/torch-2.5.1-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:340ce0432cad0d37f5a31be666896e16788f1adf8ad7be481196b503dad675b9", size = 91919450 }, - { url = "https://files.pythonhosted.org/packages/0d/4a/e51420d46cfc90562e85af2fee912237c662ab31140ab179e49bd69401d6/torch-2.5.1-cp311-cp311-win_amd64.whl", hash = "sha256:603c52d2fe06433c18b747d25f5c333f9c1d58615620578c326d66f258686f9a", size = 203098237 }, - { url = "https://files.pythonhosted.org/packages/d0/db/5d9cbfbc7968d79c5c09a0bc0bc3735da079f2fd07cc10498a62b320a480/torch-2.5.1-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:31f8c39660962f9ae4eeec995e3049b5492eb7360dd4f07377658ef4d728fa4c", size = 63884466 }, - { url = "https://files.pythonhosted.org/packages/8b/5c/36c114d120bfe10f9323ed35061bc5878cc74f3f594003854b0ea298942f/torch-2.5.1-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:ed231a4b3a5952177fafb661213d690a72caaad97d5824dd4fc17ab9e15cec03", size = 906389343 }, - { url = "https://files.pythonhosted.org/packages/6d/69/d8ada8b6e0a4257556d5b4ddeb4345ea8eeaaef3c98b60d1cca197c7ad8e/torch-2.5.1-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:3f4b7f10a247e0dcd7ea97dc2d3bfbfc90302ed36d7f3952b0008d0df264e697", size = 91811673 }, - { url = "https://files.pythonhosted.org/packages/5f/ba/607d013b55b9fd805db2a5c2662ec7551f1910b4eef39653eeaba182c5b2/torch-2.5.1-cp312-cp312-win_amd64.whl", hash = "sha256:73e58e78f7d220917c5dbfad1a40e09df9929d3b95d25e57d9f8558f84c9a11c", size = 203046841 }, - { url = "https://files.pythonhosted.org/packages/57/6c/bf52ff061da33deb9f94f4121fde7ff3058812cb7d2036c97bc167793bd1/torch-2.5.1-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:8c712df61101964eb11910a846514011f0b6f5920c55dbf567bff8a34163d5b1", size = 63858109 }, - { url = "https://files.pythonhosted.org/packages/69/72/20cb30f3b39a9face296491a86adb6ff8f1a47a897e4d14667e6cf89d5c3/torch-2.5.1-cp313-cp313-manylinux1_x86_64.whl", hash = "sha256:9b61edf3b4f6e3b0e0adda8b3960266b9009d02b37555971f4d1c8f7a05afed7", size = 906393265 }, + { url = "https://files.pythonhosted.org/packages/78/a9/97cbbc97002fff0de394a2da2cdfa859481fdca36996d7bd845d50aa9d8d/torch-2.6.0-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:7979834102cd5b7a43cc64e87f2f3b14bd0e1458f06e9f88ffa386d07c7446e1", size = 
766715424 }, + { url = "https://files.pythonhosted.org/packages/6d/fa/134ce8f8a7ea07f09588c9cc2cea0d69249efab977707cf67669431dcf5c/torch-2.6.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:ccbd0320411fe1a3b3fec7b4d3185aa7d0c52adac94480ab024b5c8f74a0bf1d", size = 95759416 }, + { url = "https://files.pythonhosted.org/packages/11/c5/2370d96b31eb1841c3a0883a492c15278a6718ccad61bb6a649c80d1d9eb/torch-2.6.0-cp311-cp311-win_amd64.whl", hash = "sha256:46763dcb051180ce1ed23d1891d9b1598e07d051ce4c9d14307029809c4d64f7", size = 204164970 }, + { url = "https://files.pythonhosted.org/packages/0b/fa/f33a4148c6fb46ca2a3f8de39c24d473822d5774d652b66ed9b1214da5f7/torch-2.6.0-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:94fc63b3b4bedd327af588696559f68c264440e2503cc9e6954019473d74ae21", size = 66530713 }, + { url = "https://files.pythonhosted.org/packages/e5/35/0c52d708144c2deb595cd22819a609f78fdd699b95ff6f0ebcd456e3c7c1/torch-2.6.0-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:2bb8987f3bb1ef2675897034402373ddfc8f5ef0e156e2d8cfc47cacafdda4a9", size = 766624563 }, + { url = "https://files.pythonhosted.org/packages/01/d6/455ab3fbb2c61c71c8842753b566012e1ed111e7a4c82e0e1c20d0c76b62/torch-2.6.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:b789069020c5588c70d5c2158ac0aa23fd24a028f34a8b4fcb8fcb4d7efcf5fb", size = 95607867 }, + { url = "https://files.pythonhosted.org/packages/18/cf/ae99bd066571656185be0d88ee70abc58467b76f2f7c8bfeb48735a71fe6/torch-2.6.0-cp312-cp312-win_amd64.whl", hash = "sha256:7e1448426d0ba3620408218b50aa6ada88aeae34f7a239ba5431f6c8774b1239", size = 204120469 }, + { url = "https://files.pythonhosted.org/packages/81/b4/605ae4173aa37fb5aa14605d100ff31f4f5d49f617928c9f486bb3aaec08/torch-2.6.0-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:9a610afe216a85a8b9bc9f8365ed561535c93e804c2a317ef7fabcc5deda0989", size = 66532538 }, + { url = "https://files.pythonhosted.org/packages/24/85/ead1349fc30fe5a32cadd947c91bda4a62fbfd7f8c34ee61f6398d38fb48/torch-2.6.0-cp313-cp313-manylinux1_x86_64.whl", hash = "sha256:4874a73507a300a5d089ceaff616a569e7bb7c613c56f37f63ec3ffac65259cf", size = 766626191 }, + { url = "https://files.pythonhosted.org/packages/dd/b0/26f06f9428b250d856f6d512413e9e800b78625f63801cbba13957432036/torch-2.6.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:a0d5e1b9874c1a6c25556840ab8920569a7a4137afa8a63a32cee0bc7d89bd4b", size = 95611439 }, + { url = "https://files.pythonhosted.org/packages/c2/9c/fc5224e9770c83faed3a087112d73147cd7c7bfb7557dcf9ad87e1dda163/torch-2.6.0-cp313-cp313-win_amd64.whl", hash = "sha256:510c73251bee9ba02ae1cb6c9d4ee0907b3ce6020e62784e2d7598e0cfa4d6cc", size = 204126475 }, + { url = "https://files.pythonhosted.org/packages/88/8b/d60c0491ab63634763be1537ad488694d316ddc4a20eaadd639cedc53971/torch-2.6.0-cp313-none-macosx_11_0_arm64.whl", hash = "sha256:ff96f4038f8af9f7ec4231710ed4549da1bdebad95923953a25045dcf6fd87e2", size = 66536783 }, ] [[package]] @@ -3749,14 +3761,12 @@ wheels = [ [[package]] name = "triton" -version = "3.1.0" +version = "3.2.0" source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "filelock", marker = "python_full_version < '3.13' and sys_platform == 'linux'" }, -] wheels = [ - { url = "https://files.pythonhosted.org/packages/86/17/d9a5cf4fcf46291856d1e90762e36cbabd2a56c7265da0d1d9508c8e3943/triton-3.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0f34f6e7885d1bf0eaaf7ba875a5f0ce6f3c13ba98f9503651c1e6dc6757ed5c", size = 209506424 }, - { url = 
"https://files.pythonhosted.org/packages/78/eb/65f5ba83c2a123f6498a3097746607e5b2f16add29e36765305e4ac7fdd8/triton-3.1.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c8182f42fd8080a7d39d666814fa36c5e30cc00ea7eeeb1a2983dbb4c99a0fdc", size = 209551444 }, + { url = "https://files.pythonhosted.org/packages/a7/2e/757d2280d4fefe7d33af7615124e7e298ae7b8e3bc4446cdb8e88b0f9bab/triton-3.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8009a1fb093ee8546495e96731336a33fb8856a38e45bb4ab6affd6dbc3ba220", size = 253157636 }, + { url = "https://files.pythonhosted.org/packages/06/00/59500052cb1cf8cf5316be93598946bc451f14072c6ff256904428eaf03c/triton-3.2.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8d9b215efc1c26fa7eefb9a157915c92d52e000d2bf83e5f69704047e63f125c", size = 253159365 }, + { url = "https://files.pythonhosted.org/packages/c7/30/37a3384d1e2e9320331baca41e835e90a3767303642c7a80d4510152cbcf/triton-3.2.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e5dfa23ba84541d7c0a531dfce76d8bcd19159d50a4a8b14ad01e91734a5c1b0", size = 253154278 }, ] [[package]]