     type=int,
     default=1869,
     required=False,
-    help="Number of batches for training.",
+    help=(
+        "Number of batches for training. Defaults to total number of samples //"
+        " global batch size."
+    ),
 )
 parser.add_argument(
     "--num-test-batches",
     required=False,
     help="Number of test batches for evaluation.",
 )
+parser.add_argument(
+    "--global-batch-size",
+    type=int,
+    default=4,
+    required=False,
+    help="Number of global batches for learning.",
+)
+parser.add_argument(
+    "--train-micro-batch-size",
+    type=int,
+    default=2,
+    required=False,
+    help="Number of micro batches for training.",
+)
+parser.add_argument(
+    "--train-mini-batch-size",
+    type=int,
+    default=4,
+    required=False,
+    help="Number of mini batches for training.",
+)
 parser.add_argument(
     "--rollout-engine",
     type=str,
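Note (not part of the change): the three new flags expose the batch-size hierarchy the trainer works with. Assuming the usual convention, a global batch is split into mini batches (one optimizer update each), and each mini batch is split into micro batches (one forward/backward pass each, with gradients accumulated). The sketch below only illustrates that arithmetic with a hypothetical helper; it is not Tunix API code.

# Illustrative sketch only; assumes global % mini == 0 and mini % micro == 0.
def batch_hierarchy(global_batch_size: int, mini_batch_size: int, micro_batch_size: int):
    assert global_batch_size % mini_batch_size == 0
    assert mini_batch_size % micro_batch_size == 0
    optimizer_updates_per_step = global_batch_size // mini_batch_size
    grad_accumulation_steps = mini_batch_size // micro_batch_size
    return optimizer_updates_per_step, grad_accumulation_steps

# With the defaults added here (--global-batch-size 4, --train-mini-batch-size 4,
# --train-micro-batch-size 2): one optimizer update per step, gradients
# accumulated over two micro batches.
print(batch_hierarchy(4, 4, 2))  # (1, 2)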
 # ====== GRPO ======
 # === Generation during GRPO training ===
 MAX_PROMPT_LENGTH = 256
-TOTAL_GENERATION_STEPS = 1024  # YY 768
+TOTAL_GENERATION_STEPS = 768
 # Important to keep a high-ish temperature for varied, diverse responses during
 # training.
 TEMPERATURE = 0.9
 EPSILON = 0.2

 # ====== Training ======
-# 2 is the max we can do on v5e-8 with llama3 8B model.
-# 4 is the max we can do on v5e-8 with llama3 1B model.
-TRAIN_MICRO_BATCH_SIZE = 4
 # To speed up for quick workflow validation, we can change NUM_BATCHES to e.g. 2
-NUM_BATCHES = args.num_batches
+NUM_BATCHES = min(args.num_batches, 7473 // args.global_batch_size)
 # Keep `NUM_TEST_BATCHES` low so that evaluation runs quickly. It can be
 # increased to a max. of 330 (if batch size is 4).
 # To speed up for quick workflow validation, we can change it to e.g. 1
 NUM_TEST_BATCHES = args.num_test_batches

-EVAL_EVERY_N_STEPS = 10  # this doesn't matter if `TRAIN_FRACTION = 1.0`.
+EVAL_EVERY_N_STEPS = 1000  # this doesn't matter if `TRAIN_FRACTION = 1.0`.
 NUM_EPOCHS = 1  # can potentially train for more epochs

 # Number of training steps.
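Note (not part of the change): NUM_BATCHES is now capped by the number of full global batches the training set can supply; 7473 matches the size of the GSM8K train split this script appears to train on. A quick, purely illustrative check of the arithmetic:

# Illustrative only: the default --num-batches of 1869 gets clipped to the cap.
for global_batch_size in (4, 8, 16):
    cap = 7473 // global_batch_size
    print(global_batch_size, min(1869, cap))
# prints: 4 1868 / 8 934 / 16 467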
@@ -344,7 +365,7 @@ def get_dataset(path: str) -> grain.MapDataset:
     return loaded_dataset


-dataset = get_dataset(TRAIN_DATA_PATH).batch(TRAIN_MICRO_BATCH_SIZE)[
+dataset = get_dataset(TRAIN_DATA_PATH).batch(args.global_batch_size)[
     :NUM_BATCHES
 ]

@@ -357,7 +378,7 @@ def get_dataset(path: str) -> grain.MapDataset:

 val_dataset = dataset[int(len(dataset) * TRAIN_FRACTION) :].repeat(NUM_EPOCHS)

-test_dataset = get_dataset(TEST_DATA_PATH).batch(TRAIN_MICRO_BATCH_SIZE)[
+test_dataset = get_dataset(TEST_DATA_PATH).batch(args.global_batch_size)[
     :NUM_TEST_BATCHES
 ]

@@ -627,7 +648,7 @@ def generate(

     out_data = sampler(
         input_strings=input_batch,
-        max_generation_steps=768,
+        max_generation_steps=TOTAL_GENERATION_STEPS,
         temperature=temperature,
         top_k=top_k,
         top_p=top_p,
@@ -782,8 +803,8 @@ def evaluate(
         actor_optimizer=optimizer,
         eval_every_n_steps=EVAL_EVERY_N_STEPS,
         max_steps=MAX_STEPS,
-        mini_batch_size=TRAIN_MICRO_BATCH_SIZE,
-        train_micro_batch_size=TRAIN_MICRO_BATCH_SIZE,
+        mini_batch_size=args.train_mini_batch_size,
+        train_micro_batch_size=args.train_micro_batch_size,
         # metrics logging
         metrics_logging_options=metrics_logging_options,
         # checkpoint saving
@@ -802,7 +823,6 @@ def evaluate(
         rollout_vllm_tpu_backend_type="jax",
         rollout_vllm_server_mode=args.rollout_server_mode,
     ),
-
 )

 grpo_config = grpo_learner.GRPOConfig(