# Training customization

TRL is designed with modularity in mind so that users can efficiently customize the training loop for their needs. Below are some examples of how you can apply and test different techniques. Note: Although these examples use the [`DPOTrainer`], these customization methods apply to most (if not all) trainers in TRL.

## Use different optimizers and schedulers
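
For example, here is a minimal sketch that passes a custom optimizer and learning-rate scheduler to the trainer through the `optimizers` argument it inherits from `transformers.Trainer` (the SGD and cosine-annealing choices are just illustrative):

```python
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
training_args = DPOConfig(output_dir="Qwen2.5-0.5B-DPO")

# Replace the default AdamW optimizer and linear schedule
optimizer = torch.optim.SGD(model.parameters(), lr=training_args.learning_rate)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=1000)

trainer = DPOTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    tokenizer=tokenizer,
    optimizers=(optimizer, scheduler),
)
trainer.train()
```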

## Use the accelerator cache optimizer

When training large models, you can better manage the accelerator cache by emptying it at regular intervals during training. To do so, set `optimize_device_cache=True` in the training config:
```python
training_args = DPOConfig(..., optimize_device_cache=True)
```

## Add custom callbacks

You can customize the training loop by adding callbacks for logging, monitoring, or early stopping. Callbacks allow you to execute custom code at specific points during training.

```python
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainerCallback
from trl import DPOConfig, DPOTrainer


class CustomLoggingCallback(TrainerCallback):
    # Called each time the trainer logs metrics (controlled by `logging_steps`)
    def on_log(self, args, state, control, logs=None, **kwargs):
        if logs is not None:
            print(f"Step {state.global_step}: {logs}")


model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
training_args = DPOConfig(output_dir="Qwen2.5-0.5B-DPO")

trainer = DPOTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    tokenizer=tokenizer,
    callbacks=[CustomLoggingCallback()],
)
trainer.train()
```
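
For early stopping specifically, you can reuse the built-in [`EarlyStoppingCallback`] from transformers instead of writing your own. Here is a sketch continuing the example above; it additionally assumes an eval dataset, since early stopping needs evaluation metrics:

```python
from transformers import EarlyStoppingCallback

eval_dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="test")
training_args = DPOConfig(
    output_dir="Qwen2.5-0.5B-DPO",
    eval_strategy="steps",
    eval_steps=100,
    load_best_model_at_end=True,  # required by EarlyStoppingCallback
    metric_for_best_model="eval_loss",
)

trainer = DPOTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
    # Stop if eval_loss does not improve for 3 consecutive evaluations
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
)
trainer.train()
```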

## Add custom evaluation metrics

You can define custom evaluation metrics to track during training. This is useful for monitoring model performance on specific tasks.

```python
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer


def compute_metrics(eval_preds):
    # eval_preds is an `EvalPrediction` with `predictions` and `label_ids` fields
    logits, labels = eval_preds
    # Add your metric computation here and return a dict of named scalar values
    return {"custom_metric": 0.0}


model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
train_dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
eval_dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="test[:10%]")
training_args = DPOConfig(output_dir="Qwen2.5-0.5B-DPO", eval_strategy="steps", eval_steps=100)

trainer = DPOTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)
trainer.train()
```
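
The metric names returned by `compute_metrics` are reported with an `eval_` prefix in the logs (for example, `eval_custom_metric` here).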

## Use mixed precision training

Mixed precision training can significantly speed up training and reduce memory usage. You can enable it by setting `bf16=True` or `fp16=True` in the training config.

```python
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")

# Use bfloat16 precision (recommended for modern GPUs)
training_args = DPOConfig(output_dir="Qwen2.5-0.5B-DPO", bf16=True)

trainer = DPOTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    tokenizer=tokenizer,
)
trainer.train()
```

Note: Use `bf16=True` for Ampere GPUs (A100, RTX 30xx) or newer, and `fp16=True` for older GPUs.
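
If you are unsure what your hardware supports, you can also pick the flag programmatically; a small sketch using `torch.cuda.is_bf16_supported()`:

```python
import torch
from trl import DPOConfig

# Prefer bfloat16 where supported, otherwise fall back to float16 on GPU
use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
training_args = DPOConfig(
    output_dir="Qwen2.5-0.5B-DPO",
    bf16=use_bf16,
    fp16=torch.cuda.is_available() and not use_bf16,
)
```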

## Use gradient accumulation

When training with limited GPU memory, gradient accumulation lets you simulate larger batch sizes by accumulating gradients over multiple steps before updating the weights. The effective batch size is `per_device_train_batch_size * gradient_accumulation_steps * number of devices`.

```python
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")

# Effective batch size of 32 per device: per_device_train_batch_size=4 * gradient_accumulation_steps=8
training_args = DPOConfig(
    output_dir="Qwen2.5-0.5B-DPO",
    per_device_train_batch_size=4,
    gradient_accumulation_steps=8,
)

trainer = DPOTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    tokenizer=tokenizer,
)
trainer.train()
```

## Use a custom data collator

You can provide a custom data collator to handle special data preprocessing or padding strategies.

```python
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer
from trl.trainer.dpo_trainer import DataCollatorForPreference

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
training_args = DPOConfig(output_dir="Qwen2.5-0.5B-DPO")

# Create a custom data collator with an explicit padding token
data_collator = DataCollatorForPreference(pad_token_id=tokenizer.pad_token_id)

trainer = DPOTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
)
trainer.train()
```
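
If you need behavior beyond padding, one option is to subclass the default collator. A hypothetical sketch, assuming the collator follows the usual transformers `DataCollatorMixin` convention of implementing `torch_call`:

```python
from trl.trainer.dpo_trainer import DataCollatorForPreference


class ShapeLoggingCollator(DataCollatorForPreference):
    def torch_call(self, examples):
        batch = super().torch_call(examples)
        # Inspect the padded batch before it reaches the model
        print({k: tuple(v.shape) for k, v in batch.items() if hasattr(v, "shape")})
        return batch


data_collator = ShapeLoggingCollator(pad_token_id=tokenizer.pad_token_id)
```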