Pull Request Overview
This PR improves type safety and configuration handling for dtype parameters by converting string-based dtype specifications to use Literal types and a mapping dictionary. The changes ensure that dtype arguments are properly validated at the CLI level and converted to JAX dtype objects before being passed to the main training functions.
Key Changes:
- Added `Literal` type hints for the `lr_schedule`, `param_dtype`, and `dtype` fields to restrict valid values
- Introduced a `DTYPE_MAP` dictionary to convert string dtype specifications to JAX dtype objects
- Added dtype information to the wandb logging configuration
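
The pattern described in the key changes could look roughly like the sketch below. It is not taken from the PR diff: the set of allowed dtype names, the `Args` dataclass, and the helpers `validate_dtype_name`, `build_dtype_map`, and `resolve_dtype` are assumptions for illustration. Only `Literal`, `param_dtype`, `dtype`, and the idea of a string-to-JAX-dtype map come from the review text.

```python
from dataclasses import dataclass
from typing import Literal, get_args

# Assumed set of names; the PR's actual Literal values are not shown in the review.
DTypeName = Literal["float32", "bfloat16", "float16"]


@dataclass
class Args:
    # Hypothetical config fields mirroring the ones named in the review.
    param_dtype: DTypeName = "float32"
    dtype: DTypeName = "bfloat16"


def validate_dtype_name(name: str) -> str:
    """Reject any string that is not one of the Literal's allowed values."""
    allowed = get_args(DTypeName)
    if name not in allowed:
        raise ValueError(f"dtype must be one of {allowed}, got {name!r}")
    return name


def build_dtype_map() -> dict:
    """Build the string -> JAX dtype mapping.

    JAX is imported lazily here (a choice made for this sketch) so that
    name validation can run without initialising JAX.
    """
    import jax.numpy as jnp

    return {
        "float32": jnp.float32,
        "bfloat16": jnp.bfloat16,
        "float16": jnp.float16,
    }


def resolve_dtype(name: str):
    """Validate the name, then convert it to the corresponding JAX dtype."""
    return build_dtype_map()[validate_dtype_name(name)]
```

With this shape, a training script would call `resolve_dtype(args.dtype)` once at startup and pass the resulting JAX dtype on, while the plain string in `args` stays available for logging.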
Reviewed Changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| jasmine/train_tokenizer.py | Added dtype type safety with Literal hints, dtype mapping, and conditional flash attention disabling |
| jasmine/train_lam.py | Applied same dtype handling improvements as train_tokenizer.py |
| jasmine/train_dynamics.py | Applied same dtype handling improvements as train_tokenizer.py |
Pull Request Overview
Copilot reviewed 3 out of 3 changed files in this pull request and generated 6 comments.
No description provided.