Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 12 additions & 9 deletions mcore_adapter/src/mcore_adapter/platforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,26 +24,29 @@ def _init_platform() -> Platform:
Returns:
An instance of a subclass of Platform corresponding to the detected hardware.
"""
try:
if hasattr(torch, "npu") and torch.npu.is_available():
logger.debug("Detected NPU (torch_npu). Initializing NPU platform.")
return NpuPlatform()
except ImportError:
pass

if torch.cuda.is_available():
device_name = torch.cuda.get_device_name().upper()
logger.debug(f"Detected CUDA device: {device_name}")

if "NVIDIA" in device_name:
logger.debug("Initializing CUDA platform (NVIDIA).")
return CudaPlatform()
elif "AMD" in device_name:
logger.debug("Initializing ROCm platform (AMD).")
return RocmPlatform()

logger.warning("Unrecognized CUDA device. Falling back to UnknownPlatform.")
return UnknownPlatform()
else:
try:
import torch_npu # noqa: F401

logger.debug("Detected torch_npu. Initializing NPU platform.")
return NpuPlatform()
except ImportError:
logger.debug("No supported accelerator detected. Initializing CPU platform.")
return CpuPlatform()

logger.debug("No supported accelerator detected. Initializing CPU platform.")
return CpuPlatform()


# Global singleton representing the current platform in use.
Expand Down
2 changes: 2 additions & 0 deletions roll/distributed/strategy/deepspeed_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,6 +538,8 @@ def save_checkpoint(self, save_dir, global_step, ckpt_id, tag="checkpoint", loca
if getattr(self, "processor", None):
self.processor.save_pretrained(save_dir)
# save tokenizer
# DeepSpeedEngine.load_checkpoint method doesn't take an is_last_step argument
kwargs.pop("is_last_step", None)
self.model.save_checkpoint(save_dir, tag=tag, **kwargs)

if self.worker_config.checkpoint_config.get("async_upload", True) and not is_last_step:
Expand Down
12 changes: 5 additions & 7 deletions roll/distributed/strategy/fsdp2_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
from codetiming import Timer
from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input
from torch import optim
from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict
from torch.distributed.device_mesh import init_device_mesh
Expand Down Expand Up @@ -438,7 +437,7 @@ def load_checkpoint(self, load_dir, tag="checkpoint", **kwargs):
def get_rng_state():
rng_state = {
"cpu": torch.get_rng_state(),
"cuda": torch.cuda.get_rng_state(),
"device": current_platform.get_rng_state(),
"numpy": np.random.get_state(),
"random": random.getstate(),
}
Expand All @@ -447,7 +446,7 @@ def get_rng_state():
@staticmethod
def load_rng_state(rng_state):
torch.set_rng_state(rng_state["cpu"])
torch.cuda.set_rng_state(rng_state["cuda"])
current_platform.set_rng_state(rng_state["device"])
np.random.set_state(rng_state["numpy"])
random.setstate(rng_state["random"])

Expand Down Expand Up @@ -484,7 +483,7 @@ def _gather_full_tensor(self, param: torch.nn.Parameter) -> torch.Tensor:
tensor = param.data if hasattr(param, "data") else param
if isinstance(tensor, DTensor):
original_device = tensor.device
if original_device.type == "cpu" and current_platform.device_type == "cuda" and torch.cuda.is_available():
if original_device.type == "cpu" and current_platform.device_type != "cpu":
tensor = tensor.to(current_platform.device_type)
tensor = tensor.full_tensor()
if original_device.type == "cpu":
Expand Down Expand Up @@ -702,8 +701,7 @@ def offload_states(self, include=None, non_blocking=False):
if not self.cpu_offload_enabled:
if include is None or OffloadStateType.model_params in include:
self.model.to("cpu", non_blocking=non_blocking)
if current_platform.device_type == "cuda":
torch.cuda.empty_cache()
current_platform.empty_cache()
# When cpu_offload is disabled, optimizer states should stay on GPU
# Only offload optimizer states if cpu_offload is enabled
else:
Expand Down Expand Up @@ -1268,7 +1266,7 @@ def train_step(
self.scheduler.step()
self.optimizer.zero_grad(set_to_none=True)

torch.cuda.empty_cache()
current_platform.empty_cache()
return metrics

def setup_model_update(self, infer_cluster, model_update_name: str):
Expand Down
2 changes: 1 addition & 1 deletion roll/distributed/strategy/megatron_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1263,7 +1263,7 @@ def save_checkpoint(self, save_dir, global_step, ckpt_id, tag="checkpoint", loca
validate_access_integrity=self._validate_access_integrity,
)
self._validate_access_integrity = False
elif not dist.is_initialized() or mpu.get_data_modulo_expert_parallel_rank() == 0:
elif not dist.is_initialized() or mpu.get_expert_data_parallel_rank() == 0:
torch.save(self.optimizer.state_dict(), os.path.join(checkpoint_dir, OPTIMIZER_NAME))
logger.info(f"Saving optimizer state to {os.path.join(checkpoint_dir, OPTIMIZER_NAME)}")

Expand Down
4 changes: 2 additions & 2 deletions roll/pipeline/sft/sft_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,8 +165,8 @@ def __init__(self, pipeline_config: SFTConfig):
self.pipeline_config.sequence_length,
encode_function,
num_proc=self.pipeline_config.sft_train.data_args.preprocessing_num_workers)

global_val_batch_size = dp_size * ga_steps * self.pipeline_config.sft_train.infer_batch_size
global_val_batch_size = dp_size * self.pipeline_config.sft_train.infer_batch_size
self.val_dataloader = DataLoader(
dataset=self.val_dataset,
batch_size=global_val_batch_size,
Expand Down
4 changes: 2 additions & 2 deletions roll/platforms/npu.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,6 @@ def apply_ulysses_patch(cls) -> None:
return

@classmethod
def device_memory_used(cls) -> None:
def device_memory_used(cls) -> int:
free, total = torch.npu.mem_get_info()
return total - free
return total - free
10 changes: 10 additions & 0 deletions roll/third_party/fsdp2/model_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from roll.configs.base_config import PPOConfig
from roll.configs.worker_config import is_actor_infer_overlapping_with_any_cluster
from roll.platforms import current_platform
from roll.utils.collective import collective
from roll.utils.logging import get_logger
from roll.utils.network_utils import collect_free_port, get_node_ip
Expand Down Expand Up @@ -295,7 +296,16 @@ def _broadcast_to_infer_workers(self, named_weights) -> list[ray.ObjectRef]:
for worker in self._broadcast_workers
]
handles = []
# Keep references to tensors moved to device to prevent premature deallocation
device_tensors = []

for _, weight in named_weights:
# Ensure weight is on the correct device (e.g. NPU) if using HCCL/NCCL
if weight.device.type == "cpu" and current_platform.device_type != "cpu":
weight_device = weight.to(current_platform.device_type)
device_tensors.append(weight_device)
weight = weight_device

handles.append(
collective.broadcast(tensor=weight, src_rank=0, group_name=self.model_update_group_name, async_op=True)
)
Expand Down
3 changes: 2 additions & 1 deletion roll/third_party/vllm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from vllm.envs import get_default_cache_root
from vllm.usage.usage_lib import UsageContext

from roll.platforms import current_platform
import roll.third_party.vllm.fp8 as fp8
from roll.utils.import_utils import safe_import_class
from roll.utils.logging import get_logger
Expand Down Expand Up @@ -58,7 +59,7 @@ async def create_async_llm(resource_placement_groups: List[Dict], **kwargs):
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = ""
# torch.cuda may already init, explicitly disable expandable_segments
# here (only matters when VLLM_USE_RAY_SPMD_WORKER=0)
torch.cuda.memory._set_allocator_settings("expandable_segments:False")
current_platform.memory._set_allocator_settings("expandable_segments:False")

os.environ["VLLM_CACHE_ROOT"] = os.path.join(get_default_cache_root(), "vllm", os.environ.get("WORKER_NAME", ""))

Expand Down