diff --git a/docs/source/speeding_up_training.md b/docs/source/speeding_up_training.md
index a0adfd90e5e..feff36cc7ec 100644
--- a/docs/source/speeding_up_training.md
+++ b/docs/source/speeding_up_training.md
@@ -1,12 +1,11 @@
 # Speeding Up Training
 
-> [!WARNING]
-> Section under construction. Feel free to contribute!
+This guide covers various methods to accelerate training in TRL. Each technique comes with a minimal example and a link to more comprehensive documentation.
 
 ## vLLM for fast generation in online methods
 
-Online methods such as GRPO or Online DPO require the model to generate completions, which is often a slow process and can significantly impact training time.
-To speed up generation, you can use [vLLM](https://github.com/vllm-project/vllm), a library that enables fast generation through, among other things, PagedAttention. TRL's online trainers support vLLM, greatly improving training speed.
+[Online methods](index#online-methods) such as GRPO or Online DPO require the model to generate completions, which is often a slow process and can significantly impact training time.
+To speed up generation, you can use [vLLM](https://github.com/vllm-project/vllm), a library that enables fast generation through, among other things, PagedAttention. TRL's online trainers support vLLM, greatly improving training speed. For more details, see [vLLM Integration](vllm_integration).
 
 To use [vLLM](https://github.com/vllm-project/vllm), first install it using:
 
@@ -17,7 +16,13 @@ pip install trl[vllm]
 ```
 
-Then, enable it by passing `use_vllm=True` in the training arguments.
+First, start a vLLM server by running:
+
+```bash
+trl vllm-serve --model <model_name>
+```
+
+Then, run the training script and pass `use_vllm=True` in the training arguments.
 
 ```python
 from trl.experimental.online_dpo import OnlineDPOConfig
@@ -95,3 +100,129 @@ You can customize the server configuration by passing additional arguments. For
+
+## Optimized attention implementations
+
+TRL supports various optimized attention implementations that can significantly speed up training while reducing memory usage. You can either use locally installed backends (such as Flash Attention 2) or pull pre-optimized kernels directly from the [Kernels Hub](kernels_hub).
+
+<hfoptions id="attention">
+<hfoption id="Flash Attention 2">
+
+To enable Flash Attention 2, pass `attn_implementation="flash_attention_2"` in the model initialization arguments:
+
+```python
+from trl import SFTConfig
+
+training_args = SFTConfig(..., model_init_kwargs={"attn_implementation": "flash_attention_2"})
+```
+
+</hfoption>
+<hfoption id="Kernels Hub">
+
+You can use pre-optimized attention kernels from the Hub without manual compilation:
+
+```python
+from trl import SFTConfig
+
+training_args = SFTConfig(..., model_init_kwargs={"attn_implementation": "kernels-community/flash-attn2"})
+```
+
+Other options include `kernels-community/vllm-flash-attn3` and `kernels-community/paged-attention`.
+
+</hfoption>
+</hfoptions>
+
+Optimized attention works across all TRL trainers. For more details, see [Kernels Hub Integration](kernels_hub) and [Reducing Memory Usage](reducing_memory_usage#padding-free).
+
+## PEFT for parameter-efficient training
+
+[PEFT](https://huggingface.co/docs/peft/index) (Parameter-Efficient Fine-Tuning) methods like LoRA significantly reduce memory usage and training time by training only a small number of adapter parameters instead of the full model.
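+
+The following example shows a minimal LoRA setup for the `SFTTrainer`. Note that the `target_modules` below match the attention projection names used by Qwen/Llama-style models; adjust them for other architectures.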
+
+```python
+from peft import LoraConfig
+from trl import SFTConfig, SFTTrainer
+
+peft_config = LoraConfig(
+    r=16,
+    lora_alpha=32,
+    lora_dropout=0.05,
+    target_modules=["q_proj", "v_proj"],
+)
+
+training_args = SFTConfig(...)
+
+trainer = SFTTrainer(
+    model="Qwen/Qwen2.5-0.5B",
+    peft_config=peft_config,
+    args=training_args,
+)
+```
+
+For more details, see [PEFT Integration](peft_integration).
+
+## Liger Kernel for memory optimization
+
+Liger Kernel is a collection of Triton kernels designed for LLM training that can increase throughput by up to 20% and reduce memory usage by up to 60%. Enable it by passing `use_liger_kernel=True` in the training arguments:
+
+<hfoptions id="liger">
+<hfoption id="SFT">
+
+```python
+from trl import SFTConfig
+
+training_args = SFTConfig(..., use_liger_kernel=True)
+```
+
+</hfoption>
+<hfoption id="DPO">
+
+```python
+from trl import DPOConfig
+
+training_args = DPOConfig(..., use_liger_kernel=True)
+```
+
+</hfoption>
+<hfoption id="GRPO">
+
+```python
+from trl import GRPOConfig
+
+training_args = GRPOConfig(..., use_liger_kernel=True)
+```
+
+</hfoption>
+<hfoption id="KTO">
+
+```python
+from trl import KTOConfig
+
+training_args = KTOConfig(..., use_liger_kernel=True)
+```
+
+</hfoption>
+</hfoptions>
+
+For more information, see [Liger Kernel Integration](liger_kernel_integration).
+
+## Gradient checkpointing for memory savings
+
+Gradient checkpointing trades compute for memory: instead of storing all intermediate activations during the forward pass, it recomputes them during the backward pass.
+
+```python
+from trl import SFTConfig
+
+training_args = SFTConfig(..., gradient_checkpointing=True)
+```
+
+Gradient checkpointing is available across all TRL trainers. For more memory optimization techniques, see the [Transformers Performance Guide](https://huggingface.co/docs/transformers/perf_train_gpu_one#gradient-checkpointing).
+
+## Mixed precision training
+
+Mixed precision training using bf16 or fp16 can speed up training and reduce memory usage with minimal impact on model quality.
+
+```python
+from trl import SFTConfig
+
+training_args = SFTConfig(..., bf16=True)  # or fp16=True for older GPUs
+```
+
+Use `bf16=True` on Ampere GPUs (A100, RTX 30xx) or newer, and `fp16=True` on older GPUs. Mixed precision training is supported across all TRL trainers.
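+
+If you are unsure whether your GPU supports bf16, you can check at runtime and select the precision accordingly. The following is a minimal sketch using PyTorch's built-in capability check; it assumes a CUDA GPU is available:
+
+```python
+import torch
+
+from trl import SFTConfig
+
+# bf16 requires an Ampere (compute capability 8.0) or newer GPU.
+bf16_ok = torch.cuda.is_bf16_supported()
+
+# Fall back to fp16 on older CUDA GPUs.
+training_args = SFTConfig(..., bf16=bf16_ok, fp16=not bf16_ok)
+```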
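+
+The techniques above are largely independent and can be combined. As a sketch, assuming your GPU and installed packages support each option, a config enabling several of them at once might look like:
+
+```python
+from trl import SFTConfig
+
+# Sketch: combine several of the techniques above in one config.
+# Each option is described in its own section; drop any your setup doesn't support.
+training_args = SFTConfig(
+    ...,
+    model_init_kwargs={"attn_implementation": "flash_attention_2"},  # optimized attention
+    use_liger_kernel=True,        # Liger Triton kernels
+    gradient_checkpointing=True,  # recompute activations to save memory
+    bf16=True,                    # mixed precision on Ampere or newer
+)
+```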