Description
Reproduction
In the latest main branch, specifying `GRPOConfig.max_prompt_length` no longer truncates prompts when vLLM is used with a conversational dataset. As a consequence, training fails with an error as soon as a prompt is longer than `max_model_len`. I believe this is a regression newly introduced on main; v0.24.0 works as expected.
How to reproduce:
The script below is based on the example code in https://huggingface.co/docs/trl/main/grpo_trainer#quick-start, with additional options:
use_vllm=True, vllm_mode="colocate", max_prompt_length=10, max_completion_length=10
script
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer
dataset = load_dataset("trl-lib/ultrafeedback-prompt", split="train")
# Dummy reward function for demonstration purposes
def reward_num_unique_letters(completions, **kwargs):
    """Reward function that rewards completions with more unique letters."""
    completion_contents = [completion[0]["content"] for completion in completions]
    return [float(len(set(content))) for content in completion_contents]
training_args = GRPOConfig(output_dir="Qwen2-0.5B-GRPO", use_vllm=True, vllm_mode="colocate", max_prompt_length=10, max_completion_length=10)
trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",
    reward_funcs=reward_num_unique_letters,
    args=training_args,
    train_dataset=dataset,
)
trainer.train()
dependencies
uv venv && uv pip install "torch==2.8.0" vllm==0.11.0 torchvision git+https://github.com/huggingface/trl.git@5e691d1 deepspeed --torch-backend cu126
command
CUDA_VISIBLE_DEVICES=0 uv run accelerate launch debug.py
output
/mnt/shared/fujita/debug_trl_truncate_prompt_tokens/.venv/lib/python3.12/site-packages/trl/import_utils.py:91: UserWarning: TRL currently only supports vLLM version `0.10.2`. You have version 0.11.0 installed. We recommend to install this version to avoid compatibility issues.
warnings.warn(
INFO 10-28 22:05:01 [__init__.py:216] Automatically detected platform cuda.
/mnt/shared/fujita/debug_trl_truncate_prompt_tokens/.venv/lib/python3.12/site-packages/trl/import_utils.py:91: UserWarning: TRL currently only supports vLLM version `0.10.2`. You have version 0.11.0 installed. We recommend to install this version to avoid compatibility issues.
warnings.warn(
/mnt/shared/fujita/debug_trl_truncate_prompt_tokens/.venv/lib/python3.12/site-packages/trl/import_utils.py:91: UserWarning: TRL currently only supports vLLM version `0.10.2`. You have version 0.11.0 installed. We recommend to install this version to avoid compatibility issues.
warnings.warn(
INFO 10-28 22:05:18 [utils.py:233] non-default args: {'seed': 0, 'max_model_len': 20, 'distributed_executor_backend': 'external_launcher', 'gpu_memory_utilization': 0.3, 'max_num_batched_tokens': 4096, 'max_num_seqs': 8, 'logprobs_mode': 'processed_logprobs', 'disable_log_stats': True, 'model_impl': 'vllm', 'model': 'Qwen/Qwen2-0.5B-Instruct'}
INFO 10-28 22:05:19 [model.py:547] Resolved architecture: Qwen2ForCausalLM
`torch_dtype` is deprecated! Use `dtype` instead!
INFO 10-28 22:05:19 [model.py:1510] Using max model len 20
INFO 10-28 22:05:19 [parallel.py:380] Using external launcher for distributed inference.
INFO 10-28 22:05:19 [parallel.py:420] Disabling V1 multiprocessing for external launcher.
INFO 10-28 22:05:19 [scheduler.py:205] Chunked prefill is enabled with max_num_batched_tokens=4096.
WARNING 10-28 22:05:19 [scheduler.py:252] max_num_batched_tokens (4096) exceeds max_num_seqs * max_model_len (160). This may lead to unexpected behavior.
WARNING 10-28 22:05:19 [scheduler.py:252] max_num_batched_tokens (4096) exceeds max_num_seqs * max_model_len (160). This may lead to unexpected behavior.
INFO 10-28 22:05:20 [core.py:77] Initializing a V1 LLM engine (v0.11.0) with config: model='Qwen/Qwen2-0.5B-Instruct', speculative_config=None, tokenizer='Qwen/Qwen2-0.5B-Instruct', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=20, download_dir=None, load_format=auto, tensor_parallel_size=1, pipeline_parallel_size=1, data_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, device_config=cuda, structured_outputs_config=StructuredOutputsConfig(backend='auto', disable_fallback=False, disable_any_whitespace=False, disable_additional_properties=False, reasoning_parser=''), observability_config=ObservabilityConfig(show_hidden_metrics_for_version=None, otlp_traces_endpoint=None, collect_detailed_traces=None), seed=0, served_model_name=Qwen/Qwen2-0.5B-Instruct, enable_prefix_caching=True, chunked_prefill_enabled=True, pooler_config=None, compilation_config={"level":3,"debug_dump_path":"","cache_dir":"","backend":"","custom_ops":[],"splitting_ops":["vllm.unified_attention","vllm.unified_attention_with_output","vllm.mamba_mixer2","vllm.mamba_mixer","vllm.short_conv","vllm.linear_attention","vllm.plamo2_mamba_mixer","vllm.gdn_attention","vllm.sparse_attn_indexer"],"use_inductor":true,"compile_sizes":[],"inductor_compile_config":{"enable_auto_functionalized_v2":false},"inductor_passes":{},"cudagraph_mode":[2,1],"use_cudagraph":true,"cudagraph_num_of_warmups":1,"cudagraph_capture_sizes":[16,8,4,2,1],"cudagraph_copy_inputs":false,"full_cuda_graph":false,"use_inductor_graph_partition":false,"pass_config":{},"max_capture_size":16,"local_cache_dir":null}
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
INFO 10-28 22:05:22 [parallel_state.py:1208] rank 0 in world size 1 is assigned as DP rank 0, PP rank 0, TP rank 0, EP rank 0
INFO 10-28 22:05:23 [gpu_model_runner.py:2602] Starting to load model Qwen/Qwen2-0.5B-Instruct...
INFO 10-28 22:05:23 [gpu_model_runner.py:2634] Loading model from scratch...
INFO 10-28 22:05:23 [cuda.py:366] Using Flash Attention backend on V1 engine.
INFO 10-28 22:05:23 [weight_utils.py:392] Using model weights format ['*.safetensors']
INFO 10-28 22:05:23 [weight_utils.py:450] No model.safetensors.index.json found in remote.
Loading safetensors checkpoint shards: 0% Completed | 0/1 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00, 2.52it/s]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00, 2.51it/s]
INFO 10-28 22:05:24 [default_loader.py:267] Loading weights took 0.41 seconds
INFO 10-28 22:05:24 [gpu_model_runner.py:2653] Model loading took 0.9264 GiB and 1.020667 seconds
INFO 10-28 22:05:33 [backends.py:548] Using cache directory: /root/.cache/vllm/torch_compile_cache/3cec9a3175/rank_0_0/backbone for vLLM's torch.compile
INFO 10-28 22:05:33 [backends.py:559] Dynamo bytecode transform time: 8.30 s
INFO 10-28 22:05:34 [backends.py:164] Directly load the compiled graph(s) for dynamic shape from the cache, took 0.973 s
INFO 10-28 22:05:34 [monitor.py:34] torch.compile takes 8.30 s in total
INFO 10-28 22:05:35 [gpu_worker.py:298] Available KV cache memory: 40.75 GiB
INFO 10-28 22:05:35 [kv_cache_utils.py:1087] GPU KV cache size: 3,561,024 tokens
INFO 10-28 22:05:35 [kv_cache_utils.py:1091] Maximum concurrency for 20 tokens per request: 111282.00x
Capturing CUDA graphs (mixed prefill-decode, PIECEWISE): 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 40.59it/s]
Capturing CUDA graphs (decode, FULL): 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 63.58it/s]
INFO 10-28 22:05:36 [gpu_model_runner.py:3480] Graph capturing finished in 1 secs, took 0.11 GiB
INFO 10-28 22:05:36 [core.py:210] init engine (profile, create kv cache, warmup model) took 11.77 seconds
INFO 10-28 22:05:37 [llm.py:306] Supported_tasks: ('generate',)
The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'bos_token_id': None, 'pad_token_id': 151643}.
[rank0]:W1028 22:05:41.125000 2760200 .venv/lib/python3.12/site-packages/torch/utils/cpp_extension.py:2425] TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation.
[rank0]:W1028 22:05:41.125000 2760200 .venv/lib/python3.12/site-packages/torch/utils/cpp_extension.py:2425] If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'] to specific architectures.
Parameter Offload - Persistent parameters statistics: param_count = 121, numel = 71552
0%| | 0/113325 [00:00<?, ?it/s]INFO 10-28 22:05:45 [block_pool.py:378] Successfully reset prefix cache
INFO 10-28 22:06:00 [chat_utils.py:560] Detected the chat template content format to be 'string'. You can set `--chat-template-content-format` to override this.
[rank0]: Traceback (most recent call last):
[rank0]: File "/mnt/shared/fujita/debug_trl_truncate_prompt_tokens/debug.py", line 19, in <module>
[rank0]: trainer.train()
[rank0]: File "/mnt/shared/fujita/debug_trl_truncate_prompt_tokens/.venv/lib/python3.12/site-packages/transformers/trainer.py", line 2325, in train
[rank0]: return inner_training_loop(
[rank0]: ^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/mnt/shared/fujita/debug_trl_truncate_prompt_tokens/.venv/lib/python3.12/site-packages/transformers/trainer.py", line 2674, in _inner_training_loop
[rank0]: tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/mnt/shared/fujita/debug_trl_truncate_prompt_tokens/.venv/lib/python3.12/site-packages/transformers/trainer.py", line 4014, in training_step
[rank0]: inputs = self._prepare_inputs(inputs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/mnt/shared/fujita/debug_trl_truncate_prompt_tokens/.venv/lib/python3.12/site-packages/trl/extras/profiling.py", line 98, in wrapper
[rank0]: return func(self, *args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/mnt/shared/fujita/debug_trl_truncate_prompt_tokens/.venv/lib/python3.12/site-packages/trl/trainer/grpo_trainer.py", line 1033, in _prepare_inputs
[rank0]: generation_batch = self._generate_and_score_completions(generation_batch)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/mnt/shared/fujita/debug_trl_truncate_prompt_tokens/.venv/lib/python3.12/site-packages/trl/trainer/grpo_trainer.py", line 1401, in _generate_and_score_completions
[rank0]: self._generate(prompts)
[rank0]: File "/mnt/shared/fujita/debug_trl_truncate_prompt_tokens/.venv/lib/python3.12/site-packages/trl/trainer/grpo_trainer.py", line 1342, in _generate
[rank0]: prompt_ids, completion_ids, logprobs, extra_fields = self._generate_single_turn(prompts)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/mnt/shared/fujita/debug_trl_truncate_prompt_tokens/.venv/lib/python3.12/site-packages/trl/trainer/grpo_trainer.py", line 1228, in _generate_single_turn
[rank0]: all_outputs = self.llm.chat(all_prompts, sampling_params=sampling_params, use_tqdm=False)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/mnt/shared/fujita/debug_trl_truncate_prompt_tokens/.venv/lib/python3.12/site-packages/vllm/entrypoints/llm.py", line 893, in chat
[rank0]: return self.generate(
[rank0]: ^^^^^^^^^^^^^^
[rank0]: File "/mnt/shared/fujita/debug_trl_truncate_prompt_tokens/.venv/lib/python3.12/site-packages/vllm/entrypoints/llm.py", line 393, in generate
[rank0]: self._validate_and_add_requests(
[rank0]: File "/mnt/shared/fujita/debug_trl_truncate_prompt_tokens/.venv/lib/python3.12/site-packages/vllm/entrypoints/llm.py", line 1516, in _validate_and_add_requests
[rank0]: self._add_request(
[rank0]: File "/mnt/shared/fujita/debug_trl_truncate_prompt_tokens/.venv/lib/python3.12/site-packages/vllm/entrypoints/llm.py", line 1569, in _add_request
[rank0]: self.llm_engine.add_request(
[rank0]: File "/mnt/shared/fujita/debug_trl_truncate_prompt_tokens/.venv/lib/python3.12/site-packages/vllm/v1/engine/llm_engine.py", line 230, in add_request
[rank0]: prompt_str, request = self.processor.process_inputs(
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/mnt/shared/fujita/debug_trl_truncate_prompt_tokens/.venv/lib/python3.12/site-packages/vllm/v1/engine/processor.py", line 392, in process_inputs
[rank0]: self._validate_model_inputs(encoder_inputs, decoder_inputs)
[rank0]: File "/mnt/shared/fujita/debug_trl_truncate_prompt_tokens/.venv/lib/python3.12/site-packages/vllm/v1/engine/processor.py", line 466, in _validate_model_inputs
[rank0]: self._validate_model_input(decoder_inputs, prompt_type="decoder")
[rank0]: File "/mnt/shared/fujita/debug_trl_truncate_prompt_tokens/.venv/lib/python3.12/site-packages/vllm/v1/engine/processor.py", line 535, in _validate_model_input
[rank0]: raise ValueError(
[rank0]: ValueError: The decoder prompt (length 299) is longer than the maximum model length of 20. Make sure that `max_model_len` is no smaller than the number of text tokens.
[rank0]:[W1028 22:06:01.841774822 ProcessGroupNCCL.cpp:1538] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
0%| | 0/113325 [00:20<?, ?it/s]
E1028 22:06:07.303000 2759395 torch/distributed/elastic/multiprocessing/api.py:874] failed (exitcode: 1) local_rank: 0 (pid: 2760200) of binary: /mnt/shared/fujita/debug_trl_truncate_prompt_tokens/.venv/bin/python3
Traceback (most recent call last):
File "/mnt/shared/fujita/debug_trl_truncate_prompt_tokens/.venv/bin/accelerate", line 10, in <module>
sys.exit(main())
^^^^^^
File "/mnt/shared/fujita/debug_trl_truncate_prompt_tokens/.venv/lib/python3.12/site-packages/accelerate/commands/accelerate_cli.py", line 50, in main
args.func(args)
File "/mnt/shared/fujita/debug_trl_truncate_prompt_tokens/.venv/lib/python3.12/site-packages/accelerate/commands/launch.py", line 1220, in launch_command
deepspeed_launcher(args)
File "/mnt/shared/fujita/debug_trl_truncate_prompt_tokens/.venv/lib/python3.12/site-packages/accelerate/commands/launch.py", line 906, in deepspeed_launcher
distrib_run.run(args)
File "/mnt/shared/fujita/debug_trl_truncate_prompt_tokens/.venv/lib/python3.12/site-packages/torch/distributed/run.py", line 892, in run
elastic_launch(
File "/mnt/shared/fujita/debug_trl_truncate_prompt_tokens/.venv/lib/python3.12/site-packages/torch/distributed/launcher/api.py", line 143, in __call__
return launch_agent(self._config, self._entrypoint, list(args))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/mnt/shared/fujita/debug_trl_truncate_prompt_tokens/.venv/lib/python3.12/site-packages/torch/distributed/launcher/api.py", line 277, in launch_agent
raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:
============================================================
debug.py FAILED
------------------------------------------------------------
Failures:
<NO_OTHER_FAILURES>
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
time : 2025-10-28_22:06:07
host : fujita-8gpu
rank : 0 (local_rank: 0)
exitcode : 1 (pid: 2760200)
error_file: <N/A>
traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
============================================================
The error
ValueError: The decoder prompt (length 299) is longer than the maximum model length of 20. Make sure that `max_model_len` is no smaller than the number of text tokens.
suggests that prompts are not being truncated to the specified `max_prompt_length` of 10.
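For reference (this check is not part of the original run), the mismatch can be confirmed outside the trainer by rendering one dataset prompt with the model's chat template and counting its tokens. This is a minimal sketch that assumes the dataset's prompt column is in the conversational (list-of-messages) format used by the example:
from datasets import load_dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
dataset = load_dataset("trl-lib/ultrafeedback-prompt", split="train")

# Render the first conversational prompt the way the chat path would and count
# the resulting tokens; no truncation to max_prompt_length=10 happens here.
messages = dataset[0]["prompt"]  # assumed: a list of {"role": ..., "content": ...} dicts
token_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
print(len(token_ids))  # far larger than max_model_len=20, consistent with the ValueError above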
Why it stopped working
After #4153, `max_prompt_length` is implemented via the `truncate_prompt_tokens` field of `vllm.SamplingParams`, which seems fine on its own. Then #4155 started using `vllm.LLM.chat` for conversational datasets, which introduced the issue: `truncate_prompt_tokens` appears to be honored by `vllm.LLM.generate` but not by `vllm.LLM.chat` at the moment (see vllm-project/vllm#27642).
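The suspected behavior can be illustrated with a small vLLM-only sketch. This is an assumption based on the linked vLLM issue rather than something verified inside TRL; `long_text` is just a placeholder, while the model name and `max_model_len=20` mirror the reproduction above:
from vllm import LLM, SamplingParams

llm = LLM(model="Qwen/Qwen2-0.5B-Instruct", max_model_len=20)
# Ask vLLM to keep only the last 10 prompt tokens, mirroring max_prompt_length=10.
sampling_params = SamplingParams(max_tokens=10, truncate_prompt_tokens=10)

long_text = "word " * 300  # placeholder prompt, well over 20 tokens

# Plain-text path: truncate_prompt_tokens is applied, so this request is accepted.
llm.generate([long_text], sampling_params=sampling_params)

# Chat path (what the GRPO trainer now uses for conversational datasets): the
# rendered prompt is apparently not truncated, so this raises the same
# "decoder prompt ... is longer than the maximum model length" ValueError.
llm.chat([[{"role": "user", "content": long_text}]], sampling_params=sampling_params)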
System Info
- Platform: Linux-5.10.238-234.956.amzn2.x86_64-x86_64-with-glibc2.39
- Python version: 3.12.3
- TRL version: 0.25.0.dev0
- PyTorch version: 2.8.0+cu126
- accelerator(s): NVIDIA H200, NVIDIA H200, NVIDIA H200, NVIDIA H200, NVIDIA H200, NVIDIA H200, NVIDIA H200, NVIDIA H200
- Transformers version: 4.57.1
- Accelerate version: 1.11.0
- Accelerate config:
- compute_environment: LOCAL_MACHINE
- distributed_type: DEEPSPEED
- mixed_precision: bf16
- use_cpu: False
- debug: False
- num_processes: 8
- machine_rank: 0
- num_machines: 1
- rdzv_backend: static
- same_network: True
- main_training_function: main
- enable_cpu_affinity: False
- deepspeed_config: {'gradient_accumulation_steps': 1, 'gradient_clipping': 1.0, 'offload_optimizer_device': 'cpu', 'offload_param_device': 'cpu', 'zero3_init_flag': True, 'zero3_save_16bit_model': True, 'zero_stage': 3}
- downcast_bf16: no
- tpu_use_cluster: False
- tpu_use_sudo: False
- tpu_env: []
- Datasets version: 4.3.0
- HF Hub version: 0.36.0
- bitsandbytes version: not installed
- DeepSpeed version: 0.18.1
- Liger-Kernel version: not installed
- LLM-Blender version: not installed
- OpenAI version: 2.6.1
- PEFT version: not installed
- vLLM version: 0.11.0
Checklist
- I have checked that my issue isn't already filed (see open issues)
- I have included my system information
- Any code provided is minimal, complete, and reproducible (more on MREs)
- Any code provided is properly formatted in code blocks (no screenshots, more on code blocks)
- Any traceback provided is complete