Refactor KTO [1/N]: Modernize model initialization #4783
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Nice! If it's easier for you, I think it's fine to have a big refactoring PR like in #3906
```diff
 # Reference model initialization
 if isinstance(ref_model, str):
-    ref_model = AutoModelForCausalLM.from_pretrained(ref_model, **ref_model_init_kwargs)
+    ref_model_init_kwargs = args.ref_model_init_kwargs or {}
+    # Distributed training requires device_map=None
+    if args.distributed_state.distributed_type in ["MULTI_GPU", "DEEPSPEED"]:
+        ref_model_init_kwargs["device_map"] = None
+    ref_model = create_model_from_path(ref_model, **ref_model_init_kwargs)
 else:
     if ref_model is not None and args.ref_model_init_kwargs is not None:
         logger.warning(
             "You passed `ref_model_init_kwargs` to the KTOConfig, but your ref_model is already instantiated. "
             "The `ref_model_init_kwargs` will be ignored."
         )
```
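For context, a helper like `create_model_from_path` essentially wraps the string-dtype conversion and `from_pretrained` call that the old inline code performed. A minimal sketch, assuming a simplified signature rather than TRL's actual implementation:

```python
import torch
from transformers import AutoModelForCausalLM


def create_model_from_path(model_id, **model_init_kwargs):
    # Sketch only: the real helper may handle more cases (quantization, config overrides, ...).
    torch_dtype = model_init_kwargs.get("torch_dtype")
    if isinstance(torch_dtype, str) and torch_dtype != "auto":
        # Replaces the old inline getattr(torch, dtype) conversion
        model_init_kwargs["torch_dtype"] = getattr(torch, torch_dtype)
    return AutoModelForCausalLM.from_pretrained(model_id, **model_init_kwargs)
```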
In the refactored GRPO/RLOO/DPO trainers, the ref model is loaded after super().__init__(...), but we can still align later.
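For illustration, the ordering described in this comment looks roughly like the sketch below; the class name is hypothetical and this is not the actual GRPO/RLOO/DPO code:

```python
from transformers import AutoModelForCausalLM, Trainer


class RefModelAfterInitTrainer(Trainer):
    """Hypothetical illustration: the reference model is built *after* super().__init__()."""

    def __init__(self, model=None, ref_model=None, args=None, **kwargs):
        super().__init__(model=model, args=args, **kwargs)
        # Only after the base Trainer has finished its own setup is the
        # reference model instantiated from its path/ID.
        if isinstance(ref_model, str):
            ref_model = AutoModelForCausalLM.from_pretrained(ref_model)
        self.ref_model = ref_model
```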
Thank you for catching this: that's indeed a better architecture.
I agree we can align this later, as I was planning to do in phase 3 of the refactoring plan: the Reference Model Handling step specifically covers ref_model improvements.
Thanks for your suggestion, but I would prefer to keep the PRs small for review quality and risk management. Each PR is independently valuable and can be reviewed in 15-30 minutes. While a big refactoring PR sounds efficient, I think it creates higher risk, poorer review quality, slower iteration, and harder debugging. Indeed, I am already finding it difficult to resolve conflicts each time I merge the main branch into this other PR: #4700. IMO, small PRs are better for quality, speed, and maintainability. Happy to discuss if you have concerns about the granularity! 😅
Refactor KTO [1/N]: Modernize model initialization.
This PR modernizes KTOTrainer's model initialization to align with SFTTrainer's clean and maintainable patterns. It replaces manual model loading with the `create_model_from_path()` helper function.

Part of:
Problem
Before (KTO):
- Manual handling of `model_init_kwargs` and `ref_model_init_kwargs` (43 lines)
- Manual dtype conversion via `getattr(torch, dtype)`
- Direct calls to `AutoModelForCausalLM.from_pretrained`

After (Aligned with SFT):
- The `or {}` pattern for init kwargs
- The `create_model_from_path` helper
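Roughly, the change at each call site looks like the following sketch; `args`, `model_path`, and the helper's internal behavior are placeholders/assumptions for illustration:

```python
import torch
from transformers import AutoModelForCausalLM

# Before: manual kwargs parsing and string-dtype conversion, repeated for model and ref_model
model_init_kwargs = args.model_init_kwargs or {}
torch_dtype = model_init_kwargs.get("torch_dtype")
if isinstance(torch_dtype, str) and torch_dtype != "auto":
    model_init_kwargs["torch_dtype"] = getattr(torch, torch_dtype)
model = AutoModelForCausalLM.from_pretrained(model_path, **model_init_kwargs)

# After: the `or {}` pattern plus a single call to the shared helper (same pattern as SFTTrainer)
model_init_kwargs = args.model_init_kwargs or {}
model = create_model_from_path(model_path, **model_init_kwargs)
```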