diff --git a/mcore_adapter/src/mcore_adapter/platforms/__init__.py b/mcore_adapter/src/mcore_adapter/platforms/__init__.py index ca92058d..0a99237c 100644 --- a/mcore_adapter/src/mcore_adapter/platforms/__init__.py +++ b/mcore_adapter/src/mcore_adapter/platforms/__init__.py @@ -24,26 +24,29 @@ def _init_platform() -> Platform: Returns: An instance of a subclass of Platform corresponding to the detected hardware. """ + try: + if hasattr(torch, "npu") and torch.npu.is_available(): + logger.debug("Detected NPU (torch_npu). Initializing NPU platform.") + return NpuPlatform() + except ImportError: + pass + if torch.cuda.is_available(): device_name = torch.cuda.get_device_name().upper() logger.debug(f"Detected CUDA device: {device_name}") + if "NVIDIA" in device_name: logger.debug("Initializing CUDA platform (NVIDIA).") return CudaPlatform() elif "AMD" in device_name: logger.debug("Initializing ROCm platform (AMD).") return RocmPlatform() + logger.warning("Unrecognized CUDA device. Falling back to UnknownPlatform.") return UnknownPlatform() - else: - try: - import torch_npu # noqa: F401 - - logger.debug("Detected torch_npu. Initializing NPU platform.") - return NpuPlatform() - except ImportError: - logger.debug("No supported accelerator detected. Initializing CPU platform.") - return CpuPlatform() + + logger.debug("No supported accelerator detected. Initializing CPU platform.") + return CpuPlatform() # Global singleton representing the current platform in use. 
diff --git a/roll/distributed/strategy/deepspeed_strategy.py b/roll/distributed/strategy/deepspeed_strategy.py index 06407158..51cf1b21 100644 --- a/roll/distributed/strategy/deepspeed_strategy.py +++ b/roll/distributed/strategy/deepspeed_strategy.py @@ -538,6 +538,8 @@ def save_checkpoint(self, save_dir, global_step, ckpt_id, tag="checkpoint", loca if getattr(self, "processor", None): self.processor.save_pretrained(save_dir) # save tokenizer + # DeepSpeedEngine.save_checkpoint method doesn't take an is_last_step argument + kwargs.pop("is_last_step", None) self.model.save_checkpoint(save_dir, tag=tag, **kwargs) if self.worker_config.checkpoint_config.get("async_upload", True) and not is_last_step: diff --git a/roll/distributed/strategy/fsdp2_strategy.py b/roll/distributed/strategy/fsdp2_strategy.py index 92429fa1..d184c58d 100644 --- a/roll/distributed/strategy/fsdp2_strategy.py +++ b/roll/distributed/strategy/fsdp2_strategy.py @@ -11,7 +11,6 @@ import torch.distributed as dist import torch.distributed.checkpoint as dcp from codetiming import Timer -from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input from torch import optim from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict from torch.distributed.device_mesh import init_device_mesh @@ -438,7 +437,7 @@ def load_checkpoint(self, load_dir, tag="checkpoint", **kwargs): def get_rng_state(): rng_state = { "cpu": torch.get_rng_state(), - "cuda": torch.cuda.get_rng_state(), + "device": current_platform.get_rng_state(), "numpy": np.random.get_state(), "random": random.getstate(), } @@ -447,7 +446,7 @@ def get_rng_state(): @staticmethod def load_rng_state(rng_state): torch.set_rng_state(rng_state["cpu"]) - torch.cuda.set_rng_state(rng_state["cuda"]) + current_platform.set_rng_state(rng_state["device"]) np.random.set_state(rng_state["numpy"]) random.setstate(rng_state["random"]) @@ -484,7 +483,7 @@ def _gather_full_tensor(self, param:
torch.nn.Parameter) -> torch.Tensor: tensor = param.data if hasattr(param, "data") else param if isinstance(tensor, DTensor): original_device = tensor.device - if original_device.type == "cpu" and current_platform.device_type == "cuda" and torch.cuda.is_available(): + if original_device.type == "cpu" and current_platform.device_type != "cpu": tensor = tensor.to(current_platform.device_type) tensor = tensor.full_tensor() if original_device.type == "cpu": @@ -702,8 +701,7 @@ def offload_states(self, include=None, non_blocking=False): if not self.cpu_offload_enabled: if include is None or OffloadStateType.model_params in include: self.model.to("cpu", non_blocking=non_blocking) - if current_platform.device_type == "cuda": - torch.cuda.empty_cache() + current_platform.empty_cache() # When cpu_offload is disabled, optimizer states should stay on GPU # Only offload optimizer states if cpu_offload is enabled else: @@ -1268,7 +1266,7 @@ def train_step( self.scheduler.step() self.optimizer.zero_grad(set_to_none=True) - torch.cuda.empty_cache() + current_platform.empty_cache() return metrics def setup_model_update(self, infer_cluster, model_update_name: str): diff --git a/roll/distributed/strategy/megatron_strategy.py b/roll/distributed/strategy/megatron_strategy.py index 89052707..3b95c492 100644 --- a/roll/distributed/strategy/megatron_strategy.py +++ b/roll/distributed/strategy/megatron_strategy.py @@ -1263,7 +1263,7 @@ def save_checkpoint(self, save_dir, global_step, ckpt_id, tag="checkpoint", loca validate_access_integrity=self._validate_access_integrity, ) self._validate_access_integrity = False - elif not dist.is_initialized() or mpu.get_data_modulo_expert_parallel_rank() == 0: + elif not dist.is_initialized() or mpu.get_expert_data_parallel_rank() == 0: torch.save(self.optimizer.state_dict(), os.path.join(checkpoint_dir, OPTIMIZER_NAME)) logger.info(f"Saving optimizer state to {os.path.join(checkpoint_dir, OPTIMIZER_NAME)}") diff --git 
a/roll/pipeline/sft/sft_pipeline.py b/roll/pipeline/sft/sft_pipeline.py index ed21b955..97bd7b6a 100644 --- a/roll/pipeline/sft/sft_pipeline.py +++ b/roll/pipeline/sft/sft_pipeline.py @@ -165,8 +165,8 @@ def __init__(self, pipeline_config: SFTConfig): self.pipeline_config.sequence_length, encode_function, num_proc=self.pipeline_config.sft_train.data_args.preprocessing_num_workers) - - global_val_batch_size = dp_size * ga_steps * self.pipeline_config.sft_train.infer_batch_size + + global_val_batch_size = dp_size * self.pipeline_config.sft_train.infer_batch_size self.val_dataloader = DataLoader( dataset=self.val_dataset, batch_size=global_val_batch_size, diff --git a/roll/platforms/npu.py b/roll/platforms/npu.py index dbc9d1f0..591edf0f 100644 --- a/roll/platforms/npu.py +++ b/roll/platforms/npu.py @@ -80,6 +80,6 @@ def apply_ulysses_patch(cls) -> None: return @classmethod - def device_memory_used(cls) -> None: + def device_memory_used(cls) -> int: free, total = torch.npu.mem_get_info() - return total - free \ No newline at end of file + return total - free diff --git a/roll/third_party/fsdp2/model_update.py b/roll/third_party/fsdp2/model_update.py index f575ef82..1e6b3095 100644 --- a/roll/third_party/fsdp2/model_update.py +++ b/roll/third_party/fsdp2/model_update.py @@ -8,6 +8,7 @@ from roll.configs.base_config import PPOConfig from roll.configs.worker_config import is_actor_infer_overlapping_with_any_cluster +from roll.platforms import current_platform from roll.utils.collective import collective from roll.utils.logging import get_logger from roll.utils.network_utils import collect_free_port, get_node_ip @@ -295,7 +296,16 @@ def _broadcast_to_infer_workers(self, named_weights) -> list[ray.ObjectRef]: for worker in self._broadcast_workers ] handles = [] + # Keep references to tensors moved to device to prevent premature deallocation + device_tensors = [] + for _, weight in named_weights: + # Ensure weight is on the correct device (e.g. 
NPU) if using HCCL/NCCL + if weight.device.type == "cpu" and current_platform.device_type != "cpu": + weight_device = weight.to(current_platform.device_type) + device_tensors.append(weight_device) + weight = weight_device + handles.append( collective.broadcast(tensor=weight, src_rank=0, group_name=self.model_update_group_name, async_op=True) ) diff --git a/roll/third_party/vllm/__init__.py b/roll/third_party/vllm/__init__.py index 5fc9b504..77f67cbb 100644 --- a/roll/third_party/vllm/__init__.py +++ b/roll/third_party/vllm/__init__.py @@ -10,6 +10,7 @@ from vllm.envs import get_default_cache_root from vllm.usage.usage_lib import UsageContext +from roll.platforms import current_platform import roll.third_party.vllm.fp8 as fp8 from roll.utils.import_utils import safe_import_class from roll.utils.logging import get_logger @@ -58,7 +59,7 @@ async def create_async_llm(resource_placement_groups: List[Dict], **kwargs): os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "" # torch.cuda may already init, explicitly disable expandable_segments # here (only matters when VLLM_USE_RAY_SPMD_WORKER=0) - torch.cuda.memory._set_allocator_settings("expandable_segments:False") + current_platform.memory._set_allocator_settings("expandable_segments:False") os.environ["VLLM_CACHE_ROOT"] = os.path.join(get_default_cache_root(), "vllm", os.environ.get("WORKER_NAME", ""))