From 16c4c1ab9f4e9bcdf2eb0019423d50b5b1bba118 Mon Sep 17 00:00:00 2001
From: Behrooz
Date: Sun, 2 Nov 2025 13:02:46 -0800
Subject: [PATCH 1/2] docs: Expand speeding up training guide with acceleration methods

Resolves #4382

- Add Flash Attention 2 section with minimal example and link to reducing_memory_usage
- Add PEFT integration section with LoRA example and link to peft_integration
- Add Liger Kernel section with example and link to liger_kernel_integration
- Add Gradient Checkpointing section with example and link to Transformers guide
- Add Mixed Precision Training section with bf16/fp16 examples
- Update introduction to reflect comprehensive coverage
- All examples verified against TRL source code and official examples
---
 docs/source/speeding_up_training.md | 80 ++++++++++++++++++++++++++++-
 1 file changed, 78 insertions(+), 2 deletions(-)

diff --git a/docs/source/speeding_up_training.md b/docs/source/speeding_up_training.md
index 8eff572e161..2f5f15a92c1 100644
--- a/docs/source/speeding_up_training.md
+++ b/docs/source/speeding_up_training.md
@@ -1,7 +1,6 @@
 # Speeding Up Training
 
-> [!WARNING]
-> Section under construction. Feel free to contribute!
+This guide covers various methods to accelerate training in TRL. Each technique includes minimal examples with links to more comprehensive documentation.
 
 ## vLLM for fast generation in online methods
 
@@ -95,3 +94,80 @@ You can customize the server configuration by passing additional arguments. For
+
+## Flash Attention 2 for faster attention computation
+
+Flash Attention 2 is an optimized implementation of the attention mechanism that can significantly speed up training while reducing memory usage. It's particularly effective for long sequences.
+
+To enable Flash Attention 2, pass `attn_implementation="flash_attention_2"` in the model initialization arguments:
+
+```python
+from trl import SFTConfig
+
+training_args = SFTConfig(
+    ...,
+    model_init_kwargs={"attn_implementation": "flash_attention_2"}
+)
+```
+
+Flash Attention 2 works across all TRL trainers. For padding-free batching with Flash Attention, see [Reducing Memory Usage](reducing_memory_usage#padding-free).
+
+## PEFT for parameter-efficient training
+
+PEFT (Parameter-Efficient Fine-Tuning) methods like LoRA significantly reduce memory usage and training time by only training a small number of adapter parameters instead of the full model.
+
+```python
+from peft import LoraConfig
+from trl import SFTConfig, SFTTrainer
+
+training_args = SFTConfig(output_dir="Qwen2.5-0.5B-LoRA")
+
+peft_config = LoraConfig(
+    r=16,
+    lora_alpha=32,
+    lora_dropout=0.05,
+    target_modules=["q_proj", "v_proj"],
+)
+
+trainer = SFTTrainer(
+    model="Qwen/Qwen2.5-0.5B",
+    peft_config=peft_config,
+    args=training_args,
+)
+```
+
+For more details, see [PEFT Integration](peft_integration).
+
+## Liger Kernel for memory optimization
+
+Liger Kernel is a collection of Triton kernels designed for LLM training that can increase throughput by 20% and reduce memory usage by 60%.
+
+```python
+from trl import DPOConfig
+
+training_args = DPOConfig(..., use_liger_kernel=True)
+```
+
+Liger Kernel is supported across multiple trainers (SFT, DPO, GRPO, KTO, GKD). For more information, see [Liger Kernel Integration](liger_kernel_integration).
+
+## Gradient checkpointing for memory savings
+
+Gradient checkpointing trades compute for memory by not storing all intermediate activations during the forward pass, recomputing them during the backward pass instead.
+
+```python
+from trl import SFTConfig
+
+training_args = SFTConfig(..., gradient_checkpointing=True)
+```
+
+Gradient checkpointing is available across all TRL trainers. For more memory optimization techniques, see the [Transformers Performance Guide](https://huggingface.co/docs/transformers/perf_train_gpu_one#gradient-checkpointing).
+
+## Mixed precision training
+
+Mixed precision training using bf16 or fp16 can speed up training and reduce memory usage with minimal impact on model quality.
+
+```python
+from trl import SFTConfig
+
+training_args = SFTConfig(..., bf16=True)  # or fp16=True for older GPUs
+```
+
+Use `bf16=True` for Ampere GPUs (A100, RTX 30xx) or newer, and `fp16=True` for older GPUs. Mixed precision training is supported across all TRL trainers.

From d1b29b1cbf9634bcfea7747fa2a344a1a533cc05 Mon Sep 17 00:00:00 2001
From: Behrooz
Date: Fri, 28 Nov 2025 15:11:34 -0800
Subject: [PATCH 2/2] docs: address review feedback for speeding up training guide

- Add link to online methods taxonomy in vLLM section
- Make Online DPO example consistent with GRPO/RLOO (add vllm-serve cmd)
- Rename Flash Attention section to "Optimized attention implementations"
- Add Kernels Hub tab for pre-optimized attention kernels
- Add hyperlink to PEFT documentation
- Expand Liger Kernel section with multi-trainer examples (SFT, DPO, GRPO, KTO)

Resolves review feedback from sergiopaniego
---
 docs/source/speeding_up_training.md | 79 ++++++++++++++++++++++++-----
 1 file changed, 67 insertions(+), 12 deletions(-)

diff --git a/docs/source/speeding_up_training.md b/docs/source/speeding_up_training.md
index 988cfbcbcd7..feff36cc7ec 100644
--- a/docs/source/speeding_up_training.md
+++ b/docs/source/speeding_up_training.md
@@ -4,8 +4,8 @@ This guide covers various methods to accelerate training in TRL. Each technique
 
 ## vLLM for fast generation in online methods
 
-Online methods such as GRPO or Online DPO require the model to generate completions, which is often a slow process and can significantly impact training time.
-To speed up generation, you can use [vLLM](https://github.com/vllm-project/vllm), a library that enables fast generation through, among other things, PagedAttention. TRL's online trainers support vLLM, greatly improving training speed.
+[Online methods](index#online-methods) such as GRPO or Online DPO require the model to generate completions, which is often a slow process and can significantly impact training time.
+To speed up generation, you can use [vLLM](https://github.com/vllm-project/vllm), a library that enables fast generation through, among other things, PagedAttention. TRL's online trainers support vLLM, greatly improving training speed. For more details, see [vLLM Integration](vllm_integration).
 
 To use [vLLM](https://github.com/vllm-project/vllm), first install it using:
 
@@ -16,7 +16,13 @@ pip install trl[vllm]
 ```
 
-Then, enable it by passing `use_vllm=True` in the training arguments.
+First, start a vLLM server by running:
+
+```bash
+trl vllm-serve --model <model_name>
+```
+
+Then, run the training script and pass `use_vllm=True` in the training arguments.
 
 ```python
 from trl.experimental.online_dpo import OnlineDPOConfig
@@ -95,26 +101,42 @@ You can customize the server configuration by passing additional arguments. For
-## Flash Attention 2 for faster attention computation
+## Optimized attention implementations
+
+TRL supports various optimized attention implementations that can significantly speed up training while reducing memory usage.
+You can use either locally installed backends (like Flash Attention 2) or pull pre-optimized kernels directly from the [Kernels Hub](kernels_hub).
+
+<hfoptions id="attention">
+<hfoption id="Flash Attention 2">
+
 To enable Flash Attention 2, pass `attn_implementation="flash_attention_2"` in the model initialization arguments:
 
 ```python
 from trl import SFTConfig
 
-training_args = SFTConfig(
-    ...,
-    model_init_kwargs={"attn_implementation": "flash_attention_2"}
-)
+training_args = SFTConfig(..., model_init_kwargs={"attn_implementation": "flash_attention_2"})
 ```
 
-Flash Attention 2 works across all TRL trainers. For padding-free batching with Flash Attention, see [Reducing Memory Usage](reducing_memory_usage#padding-free).
+</hfoption>
+<hfoption id="Kernels Hub">
+
+You can use pre-optimized attention kernels from the Hub without manual compilation:
+
+```python
+from trl import SFTConfig
+
+training_args = SFTConfig(..., model_init_kwargs={"attn_implementation": "kernels-community/flash-attn2"})
+```
+
+Other options include `kernels-community/vllm-flash-attn3` and `kernels-community/paged-attention`.
+
+</hfoption>
+</hfoptions>
+
+Optimized attention works across all TRL trainers. For more details, see [Kernels Hub Integration](kernels_hub) and [Reducing Memory Usage](reducing_memory_usage#padding-free).
 
 ## PEFT for parameter-efficient training
 
-PEFT (Parameter-Efficient Fine-Tuning) methods like LoRA significantly reduce memory usage and training time by only training a small number of adapter parameters instead of the full model.
+[PEFT](https://huggingface.co/docs/peft/index) (Parameter-Efficient Fine-Tuning) methods like LoRA significantly reduce memory usage and training time by only training a small number of adapter parameters instead of the full model.
 
 ```python
 from peft import LoraConfig
 from trl import SFTConfig, SFTTrainer
@@ -140,13 +162,46 @@ For more details, see [PEFT Integration](peft_integration).
 
 Liger Kernel is a collection of Triton kernels designed for LLM training that can increase throughput by 20% and reduce memory usage by 60%.
 
+<hfoptions id="liger">
+<hfoption id="SFT">
+
+```python
+from trl import SFTConfig
+
+training_args = SFTConfig(..., use_liger_kernel=True)
+```
+
+</hfoption>
+<hfoption id="DPO">
+
 ```python
 from trl import DPOConfig
 
 training_args = DPOConfig(..., use_liger_kernel=True)
 ```
 
-Liger Kernel is supported across multiple trainers (SFT, DPO, GRPO, KTO, GKD). For more information, see [Liger Kernel Integration](liger_kernel_integration).
+</hfoption>
+<hfoption id="GRPO">
+
+```python
+from trl import GRPOConfig
+
+training_args = GRPOConfig(..., use_liger_kernel=True)
+```
+
+</hfoption>
+<hfoption id="KTO">
+
+```python
+from trl import KTOConfig
+
+training_args = KTOConfig(..., use_liger_kernel=True)
+```
+
+</hfoption>
+</hfoptions>
+
+For more information, see [Liger Kernel Integration](liger_kernel_integration).
 
 ## Gradient checkpointing for memory savings
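
Taken together, the options added in these two patches compose: a single `SFTConfig` can enable mixed precision, gradient checkpointing, the Liger kernels, and an optimized attention backend, while a LoRA adapter is passed to the trainer. The sketch below shows one way to wire them up; the model name, dataset, output directory, and LoRA settings are illustrative rather than taken from the patches, and it assumes `flash-attn`, `liger-kernel`, `peft`, and `datasets` are installed.

```python
from datasets import load_dataset
from peft import LoraConfig
from trl import SFTConfig, SFTTrainer

# Illustrative dataset; substitute your own.
dataset = load_dataset("trl-lib/Capybara", split="train")

training_args = SFTConfig(
    output_dir="Qwen2.5-0.5B-SFT",                                   # illustrative output path
    bf16=True,                                                       # mixed precision (Ampere or newer)
    gradient_checkpointing=True,                                     # recompute activations to save memory
    use_liger_kernel=True,                                           # Liger Triton kernels
    model_init_kwargs={"attn_implementation": "flash_attention_2"},  # optimized attention backend
)

trainer = SFTTrainer(
    model="Qwen/Qwen2.5-0.5B",  # illustrative model
    args=training_args,
    train_dataset=dataset,
    peft_config=LoraConfig(r=16, lora_alpha=32, target_modules=["q_proj", "v_proj"]),  # LoRA adapter
)
trainer.train()
```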