@@ -122,11 +122,11 @@ def test_dpo_trainer(
       rejected_mask,
       use_ref_model,
   ):
-    model = tc.ToyTransformer(rngs=nnx.Rngs(0))
+    model = tc.ToyTransformer(config=tc.ModelConfig(), rngs=nnx.Rngs(0))
     original_variables = jax.tree.map(jnp.copy, nnx.state(model, nnx.Param))
     ref_model = None
     if use_ref_model:
-      ref_model = tc.ToyTransformer(rngs=nnx.Rngs(0))
+      ref_model = tc.ToyTransformer(config=tc.ModelConfig(), rngs=nnx.Rngs(0))
     dpo_config = dpo_lib.DPOTrainingConfig(
         eval_every_n_steps=5,
         max_steps=10,
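The recurring change in this patch is a constructor migration: `ToyTransformer` no longer takes loose hyperparameter keywords and instead receives a `ModelConfig` object. A minimal before/after sketch of the call pattern, built only from calls visible in this diff; the import path for the `tc` test-helper module and the `vocab_size` value are hypothetical placeholders, and the two calls target the pre- and post-change APIs respectively, so they are not runnable against a single version:

```python
from flax import nnx

import toy_transformer_test_lib as tc  # hypothetical path to the test helpers

# Before: hyperparameters such as vocab_size were constructor keywords.
old_style = tc.ToyTransformer(rngs=nnx.Rngs(0), vocab_size=32)

# After: hyperparameters are bundled into a ModelConfig; an empty config
# presumably falls back to the helper's defaults.
new_style = tc.ToyTransformer(
    config=tc.ModelConfig(vocab_size=32),  # 32 is an arbitrary illustration
    rngs=nnx.Rngs(0),
)
```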
@@ -169,11 +169,15 @@ def test_dpo_trainer(
         "log_probs/rejected",
     ]:
       self.assertLen(
-          dpo_trainer.metrics_logger.get_metric_history(metric_name, "train"),
+          dpo_trainer.metrics_logger.get_metric_history(
+              "", metric_name, "train"
+          ),
           dpo_trainer._train_steps,
       )
       self.assertLen(
-          dpo_trainer.metrics_logger.get_metric_history(metric_name, "eval"),
+          dpo_trainer.metrics_logger.get_metric_history(
+              "", metric_name, "eval"
+          ),
           3,
       )

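The hunk above tracks a second signature change: `metrics_logger.get_metric_history` now takes an extra leading positional argument ahead of the metric name, which these tests fill with an empty string. A hedged usage sketch against the diff's own objects; the meaning of the new first argument (e.g. a run or model prefix) is not visible in this diff and is an assumption:

```python
# Old call shape: (metric_name, mode)
history = dpo_trainer.metrics_logger.get_metric_history(
    "rewards/accuracy", "train"
)

# New call shape: a leading argument precedes the metric name; the tests
# pass "". Its semantics are assumed here, not shown in this diff.
history = dpo_trainer.metrics_logger.get_metric_history(
    "", "rewards/accuracy", "train"
)
assert len(history) == dpo_trainer._train_steps  # one entry per train step
```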
@@ -201,11 +205,13 @@ def test_dpo_trainer(
   def test_dpo_trainer_with_string_inputs(self, train_ds):
     tokenizer = tc.MockVocab()
     model = tc.ToyTransformer(
-        rngs=nnx.Rngs(0), vocab_size=tokenizer.GetPieceSize()
+        config=tc.ModelConfig(vocab_size=tokenizer.GetPieceSize()),
+        rngs=nnx.Rngs(0),
     )
     original_variables = jax.tree.map(jnp.copy, nnx.state(model, nnx.Param))
     ref_model = tc.ToyTransformer(
-        rngs=nnx.Rngs(0), vocab_size=tokenizer.GetPieceSize()
+        config=tc.ModelConfig(vocab_size=tokenizer.GetPieceSize()),
+        rngs=nnx.Rngs(0),
     )
     original_ref_variables = jax.tree.map(
         jnp.copy, nnx.state(ref_model, nnx.Param)
@@ -240,13 +246,15 @@ def test_dpo_trainer_with_string_inputs(self, train_ds):
         "rewards/accuracy",
     ]:
       self.assertLen(
-          dpo_trainer.metrics_logger.get_metric_history(metric_name, "train"),
+          dpo_trainer.metrics_logger.get_metric_history(
+              "", metric_name, "train"
+          ),
           dpo_trainer._train_steps,
       )

   def test_dpo_loss_fn(self):
     np.random.seed(0)
-    model = tc.ToyTransformer(rngs=nnx.Rngs(0))
+    model = tc.ToyTransformer(config=tc.ModelConfig(), rngs=nnx.Rngs(0))
     per_token_logps = np.random.normal(0, 5, size=(8, 4))
     ref_per_token_logps = np.random.normal(0, 5, size=(8, 4)).sum(axis=-1)
     train_example = dpo_lib.TrainExample(
@@ -262,20 +270,26 @@ def test_dpo_loss_fn(self):
     with mock.patch.object(
         common, "get_per_token_logps", return_value=jnp.array(per_token_logps)
     ):
-      loss, _ = dpo_lib.dpo_loss_fn(model, train_example, 0.1, 0)
+      loss, _ = dpo_lib.dpo_loss_fn(
+          model, train_example, beta=0.1, label_smoothing=0
+      )
       np.testing.assert_allclose(loss, 0.753059, atol=1e-5)

-      loss, _ = dpo_lib.dpo_loss_fn(model, train_example, 0.1, 0.3)
+      loss, _ = dpo_lib.dpo_loss_fn(
+          model, train_example, beta=0.1, label_smoothing=0.3
+      )
       np.testing.assert_allclose(loss, 0.925447, atol=1e-5)

   def test_dpo_prepare_inputs_for_strings(self):
     tokenizer = tc.MockVocab()

     model = tc.ToyTransformer(
-        rngs=nnx.Rngs(0), vocab_size=tokenizer.GetPieceSize()
+        config=tc.ModelConfig(vocab_size=tokenizer.GetPieceSize()),
+        rngs=nnx.Rngs(0),
     )
     ref_model = tc.ToyTransformer(
-        rngs=nnx.Rngs(0), vocab_size=tokenizer.GetPieceSize()
+        config=tc.ModelConfig(vocab_size=tokenizer.GetPieceSize()),
+        rngs=nnx.Rngs(0),
     )
     dpo_trainer = dpo_lib.DPOTrainer(
         model=model,
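The `dpo_loss_fn` call sites above now name `beta` and `label_smoothing` as keywords instead of passing `0.1, 0` positionally, which also documents what the two trailing numbers meant. For orientation, here is a self-contained sketch of the conservative-DPO objective these two parameters conventionally control, following the DPO/cDPO formulation from the literature rather than tunix's internal `dpo_loss_fn`:

```python
import jax
import jax.numpy as jnp


def dpo_loss(chosen_logps, rejected_logps,
             ref_chosen_logps, ref_rejected_logps,
             beta=0.1, label_smoothing=0.0):
  """Conservative DPO loss over summed per-sequence log-probabilities."""
  # How much more strongly the policy prefers chosen over rejected than
  # the frozen reference model does, scaled by beta.
  logits = beta * (
      (chosen_logps - rejected_logps)
      - (ref_chosen_logps - ref_rejected_logps)
  )
  # label_smoothing = 0 recovers plain DPO; a positive value hedges
  # against noisy preference labels by blending in the flipped loss.
  losses = (
      -jax.nn.log_sigmoid(logits) * (1.0 - label_smoothing)
      - jax.nn.log_sigmoid(-logits) * label_smoothing
  )
  return losses.mean()


# Toy usage on random per-sequence log-probs for a batch of 8 pairs.
lps = jax.random.normal(jax.random.PRNGKey(0), (4, 8)) * 5
print(dpo_loss(*lps, beta=0.1, label_smoothing=0.3))
```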
@@ -328,8 +342,8 @@ def test_dpo_prepare_inputs_for_strings(self):
     self.assertEqual(out.logits_to_keep, 3)

   def test_dpo_prepare_inputs(self):
-    model = tc.ToyTransformer(rngs=nnx.Rngs(0))
-    ref_model = tc.ToyTransformer(rngs=nnx.Rngs(0))
+    model = tc.ToyTransformer(config=tc.ModelConfig(), rngs=nnx.Rngs(0))
+    ref_model = tc.ToyTransformer(config=tc.ModelConfig(), rngs=nnx.Rngs(0))
     dpo_trainer = dpo_lib.DPOTrainer(
         model=model,
         ref_model=ref_model,