
feat: Add NPU support for fsdp2 and fix compatibility issues #353

Open
UsernameFull wants to merge 8 commits into alibaba:main from UsernameFull:npu

Conversation

UsernameFull (Contributor) commented Feb 11, 2026

What does this PR do?

Fixes #337

1. Adapt FSDP2 for NPU backend

  1. Ensure parity: Invoke torch.npu.empty_cache() when running on NPU devices.
  2. Remove the unused dependency on flash_attn.bert_padding.
  3. Move CPU weights to current_platform.device_type before broadcasting.
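The cache-clearing adaptation above can be sketched as a small device-dispatch helper. This is a minimal sketch, not ROLL's actual code: the `device_type` string is assumed to mirror values from a platform-detection layer such as `current_platform.device_type`, and the NPU path assumes the `torch_npu` package is installed (importing it registers the `torch.npu` namespace).

```python
def empty_device_cache(device_type: str) -> bool:
    """Best-effort release of cached allocator blocks on the active accelerator.

    Returns True if a cache-clearing call was made. ``device_type`` is a
    stand-in for a platform helper like ``current_platform.device_type``.
    """
    if device_type == "npu":
        import torch
        import torch_npu  # noqa: F401  -- registers the torch.npu namespace
        torch.npu.empty_cache()
        return True
    if device_type == "cuda":
        import torch
        torch.cuda.empty_cache()
        return True
    # CPU and other backends have no caching allocator to flush here.
    return False
```

Dispatching on a device-type string (rather than hard-coding `torch.cuda`) is what keeps the FSDP2 code path identical on GPU and NPU.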

2. Additional NPU compatibility features

  1. Fix an NPU crash by replacing _set_allocator_settings() with the PYTORCH_NPU_ALLOC_CONF environment variable.
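A sketch of how the allocator settings could be applied via the environment variable. Assumptions: torch_npu reads PYTORCH_NPU_ALLOC_CONF at initialization (analogous to PYTORCH_CUDA_ALLOC_CONF), so this must run before the first NPU tensor is created; the helper name and the default option string are illustrative, not ROLL's actual code, and the supported options depend on your torch_npu version.

```python
import os

def configure_npu_allocator(conf: str = "expandable_segments:True") -> str:
    """Set NPU caching-allocator options through PYTORCH_NPU_ALLOC_CONF.

    Uses setdefault so a value already exported by the user wins over
    the in-code default. Must be called before NPU initialization.
    """
    os.environ.setdefault("PYTORCH_NPU_ALLOC_CONF", conf)
    return os.environ["PYTORCH_NPU_ALLOC_CONF"]
```

Setting the environment variable avoids calling the private `_set_allocator_settings()` API, which is what crashed on NPU.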

3. Bug fix

  1. Fix the validation batch size calculation: remove the ga_steps multiplier from global_val_batch_size. Gradient accumulation is not performed during validation, so it should not inflate the validation batch size.
  2. Replace get_data_modulo_expert_parallel_rank with get_expert_data_parallel_rank, as the former is deprecated.
  3. Change the initialization of sequence_packing_args from default=SequencePackingConfig() to default_factory=SequencePackingConfig.
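The validation batch-size fix in item 1 boils down to this arithmetic (function names are illustrative, not ROLL's actual helpers):

```python
def global_train_batch_size(per_device_bs: int, world_size: int, ga_steps: int) -> int:
    # Training: micro-batches accumulate across ga_steps before each
    # optimizer step, so the effective global batch includes ga_steps.
    return per_device_bs * world_size * ga_steps

def global_val_batch_size(per_device_bs: int, world_size: int) -> int:
    # Validation: forward passes only, no gradient accumulation,
    # so ga_steps must NOT appear in the product.
    return per_device_bs * world_size
```

With the config below (per_device_train_batch_size=1, 4 training ranks, gradient_accumulation_steps=64), the train global batch is 256, but the validation global batch is just 4; multiplying by ga_steps would wrongly report it as 256.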
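The default_factory change in item 3 can be illustrated with a minimal dataclass. The `SequencePackingConfig` shown here is a stand-in with a made-up field; the real class lives in ROLL.

```python
from dataclasses import dataclass, field

@dataclass
class SequencePackingConfig:  # stand-in; the real fields are in ROLL
    enabled: bool = False

@dataclass
class TrainingArgs:
    # default=SequencePackingConfig() is rejected outright by recent
    # Python versions (an unhashable instance counts as a mutable
    # default) -- and even where accepted, every TrainingArgs would
    # share the SAME config object, so mutating it in one place would
    # silently change all the others. default_factory builds a fresh
    # instance for each TrainingArgs.
    sequence_packing_args: SequencePackingConfig = field(
        default_factory=SequencePackingConfig
    )

a, b = TrainingArgs(), TrainingArgs()
a.sequence_packing_args.enabled = True
assert b.sequence_packing_args.enabled is False  # b has its own config
```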

Result

Critic/score/mean curves for the same config running on GPU vs. NPU. Config used:

defaults:
  - ../../examples/config/envs@_here_
  - ../../examples/config/deepspeed_zero@_here_
  - ../../examples/config/deepspeed_zero2@_here_
  - ../../examples/config/deepspeed_zero3@_here_
  - ../../examples/config/deepspeed_zero3_cpuoffload@_here_


hydra:
  run:
    dir: .
  output_subdir: null

pg_variant: ppo # topr, vanilla, tis, cispo, kimi15, ppo
exp_name: Qwen3-8B-RLVR-${pg_variant}
seed: 42
logging_dir: ./output/logs
output_dir: ./output
system_envs:
  USE_MODELSCOPE: '1'

checkpoint_config:
  type: file_system
  output_dir: ./ckpt


num_gpus_per_node: 8

max_steps: 100
save_steps: 100
logging_steps: 1
eval_steps: 10
resume_from_checkpoint: false


rollout_batch_size: 64  # prompt
prompt_length: 1024
response_length: 2048

num_return_sequences_in_group: 8
ppo_epochs: 1
adv_estimator: "reinforce"

# clip
value_clip: 0.5
reward_clip: 10
advantage_clip: 2.0
dual_clip_loss: true

# normalize
norm_mean_type: batch
norm_std_type: batch

# data mask
max_len_mask: true
difficulty_mask: true
difficulty_low_threshold: 0.2
difficulty_high_threshold: 0.95
error_max_len_clip: false

# data weight
difficulty_loss_weight: false
length_loss_weight: false

# reward
add_token_level_kl: false

# advantage
whiten_advantages: true


pretrain: Qwen/Qwen3-8B
reward_pretrain: Qwen/Qwen3-8B

validation:
  data_args:
    template: qwen3
    file_name:
      - data/math_benchmarks.jsonl
  generating_args:
    top_p: 0.6
    top_k: 50
    num_beams: 1
    temperature: 0.6
    num_return_sequences: 1
  eval_steps: 10

actor_train:
  worker_cls: roll.pipeline.rlvr.actor_pg_worker.ActorPGWorker
  pg_variant: ppo # topr, vanilla, tis, cispo, kimi15, ppo
  model_args:
    flash_attn: fa2
    disable_gradient_checkpointing: false
    dtype: bf16
    model_type: ~
  training_args:
    learning_rate: 1.0e-6
    weight_decay: 0
    per_device_train_batch_size: 1
    gradient_accumulation_steps: 64
    warmup_steps: 20
    num_train_epochs: 50
  data_args:
    template: qwen3
    file_name:
      - data/math_deepmath_deal.jsonl
      - data/llm_judge_Multi-subject-RLVR_deal_new.jsonl
    domain_interleave_probs:
      math_rule: 1
    dataset_dir: data
    messages: messages
    interleave_probs: "1.0"
    preprocessing_num_workers: 16
  strategy_args:
    strategy_name: deepspeed_train
    strategy_config: ${deepspeed_zero3}
  device_mapping: list(range(0,4))
  infer_batch_size: 4

actor_infer:
  model_args:
    flash_attn: fa2
    disable_gradient_checkpointing: true
    dtype: bf16
  generating_args:
    max_new_tokens: ${response_length}
    top_p: 0.99
    top_k: 100
    num_beams: 1
    temperature: 0.99
    num_return_sequences: ${num_return_sequences_in_group}
  data_args:
    template: qwen3
  strategy_args:
     strategy_name: vllm
     strategy_config:
       gpu_memory_utilization: 0.6
       block_size: 16
       max_model_len: 8000
  device_mapping: list(range(4,6))
  infer_batch_size: 1

reference:
  model_args:
    flash_attn: fa2
    disable_gradient_checkpointing: true
    dtype: bf16
    model_type: ~
  data_args:
    template: qwen3
  strategy_args:
    strategy_name: hf_infer
    strategy_config: ~
  device_mapping: list(range(6,8))
  infer_batch_size: 8

rewards:
  math_rule:
    worker_cls: roll.pipeline.rlvr.rewards.math_rule_reward_worker.MathRuleRewardWorker
    model_args:
      model_name_or_path: ${reward_pretrain}
    data_args:
      template: qwen3
    tag_included: [deepmath_103k, 'MATH-500', 'OlympiadBench', 'minervamath', 'aime2025', 'gsm8k', 'aime', 'amc23', 'math_rule']
    world_size: 8
    infer_batch_size: 1
[image: critic/score/mean training curves, GPU vs. NPU]

CLAassistant commented Feb 11, 2026

CLA assistant check: all committers have signed the CLA.
UsernameFull changed the title from "feat: Add NPU support for fsdp2 and fix compatibility issues" to "[WIP]feat: Add NPU support for fsdp2 and fix compatibility issues" on Feb 11, 2026
UsernameFull changed the title from "[WIP]feat: Add NPU support for fsdp2 and fix compatibility issues" to "feat: Add NPU support for fsdp2 and fix compatibility issues" on Feb 26, 2026
noemotiovon (Contributor) left a comment

A few minor suggestions.



Development

Successfully merging this pull request may close these issues:

[RFC] Add support for megatron backend in Ascend NPU

5 participants