
Conversation

@albertvillanova (Member) commented Jan 7, 2026

Refactor KTO [1/N]: Modernize model initialization.

This PR modernizes KTOTrainer's model initialization to align with SFTTrainer's clean and maintainable patterns. It replaces manual model loading with the create_model_from_path() helper function.

Part of:

Problem

Before (KTO):

  • Manual handling of model_init_kwargs and ref_model_init_kwargs (43 lines)
  • Manual dtype conversion with getattr(torch, dtype)
  • Manual device_map setting
  • Duplicate code for model and ref_model
  • Direct calls to AutoModelForCausalLM.from_pretrained
  • Hard errors instead of warnings for already-instantiated models

After (Aligned with SFT):

  • Clean kwargs handling with or {} pattern
  • Automatic dtype conversion via helper
  • DeepSpeed/MULTI_GPU device_map handling
  • Single call to create_model_from_path helper
  • User-friendly warnings
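
As an illustration, here is a minimal, hypothetical sketch of the "after" pattern: the `or {}` default, string-to-dtype conversion inside a helper, and forcing `device_map=None` under DeepSpeed/MULTI_GPU. The helper below is a stand-in, not TRL's actual `create_model_from_path` implementation, and a `SimpleNamespace` replaces `torch` so the sketch is self-contained:

```python
from types import SimpleNamespace

# Stand-in for torch's dtype attributes so this sketch runs without torch installed.
torch_like = SimpleNamespace(bfloat16="torch.bfloat16", float16="torch.float16")

def create_model_from_path_sketch(model_id, **kwargs):
    """Hypothetical stand-in for TRL's create_model_from_path helper.

    Converts a string dtype (e.g. "bfloat16") to the framework dtype before
    the kwargs would be passed to from_pretrained-style loading.
    """
    dtype = kwargs.get("dtype")
    if isinstance(dtype, str) and dtype != "auto":
        kwargs["dtype"] = getattr(torch_like, dtype)
    # A real helper would call AutoModelForCausalLM.from_pretrained here.
    return model_id, kwargs

# Clean kwargs handling with the `or {}` pattern (the config attribute may be None):
model_init_kwargs = None or {}
model_init_kwargs["dtype"] = "bfloat16"

# Distributed training requires device_map=None:
distributed_type = "DEEPSPEED"  # stands in for args.distributed_state.distributed_type
if distributed_type in ["MULTI_GPU", "DEEPSPEED"]:
    model_init_kwargs["device_map"] = None

model_id, resolved = create_model_from_path_sketch("my-org/my-model", **model_init_kwargs)
print(resolved)
```

The same kwargs-preparation code serves both model and ref_model, which is what removes the duplication noted above.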

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@qgallouedec (Member)

Nice! If it's easier for you, I think it's fine to have a big refactoring PR like in #3906

Comment on lines +383 to +395
# Reference model initialization
if isinstance(ref_model, str):
    ref_model_init_kwargs = args.ref_model_init_kwargs or {}
    # Distributed training requires device_map=None
    if args.distributed_state.distributed_type in ["MULTI_GPU", "DEEPSPEED"]:
        ref_model_init_kwargs["device_map"] = None
    ref_model = create_model_from_path(ref_model, **ref_model_init_kwargs)
else:
    if ref_model is not None and args.ref_model_init_kwargs is not None:
        logger.warning(
            "You passed `ref_model_init_kwargs` to the KTOConfig, but your ref_model is already instantiated. "
            "The `ref_model_init_kwargs` will be ignored."
        )
Member

In the refactored GRPO/RLOO/DPO, the ref model is loaded after super().__init__(...), but we can still align later
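
The ordering described here can be sketched with a hypothetical trainer skeleton (not TRL's actual classes): the policy model is handled by super().__init__, and the reference model is only created afterwards:

```python
class BaseTrainerSketch:
    """Hypothetical base trainer; real trainers do heavy setup in __init__."""
    def __init__(self, model):
        self.model = model

class RefactoredTrainerSketch(BaseTrainerSketch):
    """Sketch of the GRPO/RLOO/DPO ordering: the ref model is instantiated
    after super().__init__ rather than before it."""
    def __init__(self, model, ref_model=None):
        super().__init__(model)  # base setup completes first
        if isinstance(ref_model, str):
            # Stand-in for create_model_from_path(ref_model, ...)
            ref_model = ("loaded", ref_model)
        self.ref_model = ref_model

trainer = RefactoredTrainerSketch("policy-model", ref_model="ref-model-id")
print(trainer.ref_model)
```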

Member Author

Thank you for catching this: that's indeed a better architecture.

I agree we can align this later; phase 3 of the refactoring plan (Reference Model Handling) is specifically slated for ref_model improvements.

@albertvillanova (Member Author) commented Jan 8, 2026

If it's easier for you, I think it's fine to have a big refactoring PR like in #3906

Thanks for your suggestion, but I would prefer to keep the PRs small for review quality and risk management. Each PR is independently valuable and can be reviewed in 15-30 minutes.

While a big refactoring PR sounds efficient, I think it creates high risk, poor review quality, slower iteration, and harder debugging. Indeed, I am already finding it difficult to resolve conflicts each time I merge the main branch into this other PR: #4700.

IMO, small PRs are better for quality, speed, and maintainability.

Happy to discuss if you have concerns about the granularity! 😅

@albertvillanova albertvillanova mentioned this pull request Jan 8, 2026
@albertvillanova albertvillanova merged commit 1a93971 into huggingface:main Jan 8, 2026
2 of 3 checks passed