Skip to content

Commit 150d12f

Browse files
committed
feat: add ORPO support
1 parent a3f1ca8 commit 150d12f

File tree

4 files changed: +551 −66 lines

tests/sft/dpo/dpo_trainer_test.py

Lines changed: 28 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -122,11 +122,11 @@ def test_dpo_trainer(
122122
rejected_mask,
123123
use_ref_model,
124124
):
125-
model = tc.ToyTransformer(rngs=nnx.Rngs(0))
125+
model = tc.ToyTransformer(config=tc.ModelConfig(), rngs=nnx.Rngs(0))
126126
original_variables = jax.tree.map(jnp.copy, nnx.state(model, nnx.Param))
127127
ref_model = None
128128
if use_ref_model:
129-
ref_model = tc.ToyTransformer(rngs=nnx.Rngs(0))
129+
ref_model = tc.ToyTransformer(config=tc.ModelConfig(), rngs=nnx.Rngs(0))
130130
dpo_config = dpo_lib.DPOTrainingConfig(
131131
eval_every_n_steps=5,
132132
max_steps=10,
@@ -169,11 +169,15 @@ def test_dpo_trainer(
169169
"log_probs/rejected",
170170
]:
171171
self.assertLen(
172-
dpo_trainer.metrics_logger.get_metric_history(metric_name, "train"),
172+
dpo_trainer.metrics_logger.get_metric_history(
173+
"", metric_name, "train"
174+
),
173175
dpo_trainer._train_steps,
174176
)
175177
self.assertLen(
176-
dpo_trainer.metrics_logger.get_metric_history(metric_name, "eval"),
178+
dpo_trainer.metrics_logger.get_metric_history(
179+
"", metric_name, "eval"
180+
),
177181
3,
178182
)
179183

@@ -201,11 +205,13 @@ def test_dpo_trainer(
201205
def test_dpo_trainer_with_string_inputs(self, train_ds):
202206
tokenizer = tc.MockVocab()
203207
model = tc.ToyTransformer(
204-
rngs=nnx.Rngs(0), vocab_size=tokenizer.GetPieceSize()
208+
config=tc.ModelConfig(vocab_size=tokenizer.GetPieceSize()),
209+
rngs=nnx.Rngs(0),
205210
)
206211
original_variables = jax.tree.map(jnp.copy, nnx.state(model, nnx.Param))
207212
ref_model = tc.ToyTransformer(
208-
rngs=nnx.Rngs(0), vocab_size=tokenizer.GetPieceSize()
213+
config=tc.ModelConfig(vocab_size=tokenizer.GetPieceSize()),
214+
rngs=nnx.Rngs(0),
209215
)
210216
original_ref_variables = jax.tree.map(
211217
jnp.copy, nnx.state(ref_model, nnx.Param)
@@ -240,13 +246,15 @@ def test_dpo_trainer_with_string_inputs(self, train_ds):
240246
"rewards/accuracy",
241247
]:
242248
self.assertLen(
243-
dpo_trainer.metrics_logger.get_metric_history(metric_name, "train"),
249+
dpo_trainer.metrics_logger.get_metric_history(
250+
"", metric_name, "train"
251+
),
244252
dpo_trainer._train_steps,
245253
)
246254

247255
def test_dpo_loss_fn(self):
248256
np.random.seed(0)
249-
model = tc.ToyTransformer(rngs=nnx.Rngs(0))
257+
model = tc.ToyTransformer(config=tc.ModelConfig(), rngs=nnx.Rngs(0))
250258
per_token_logps = np.random.normal(0, 5, size=(8, 4))
251259
ref_per_token_logps = np.random.normal(0, 5, size=(8, 4)).sum(axis=-1)
252260
train_example = dpo_lib.TrainExample(
@@ -262,20 +270,26 @@ def test_dpo_loss_fn(self):
262270
with mock.patch.object(
263271
common, "get_per_token_logps", return_value=jnp.array(per_token_logps)
264272
):
265-
loss, _ = dpo_lib.dpo_loss_fn(model, train_example, 0.1, 0)
273+
loss, _ = dpo_lib.dpo_loss_fn(
274+
model, train_example, beta=0.1, label_smoothing=0
275+
)
266276
np.testing.assert_allclose(loss, 0.753059, atol=1e-5)
267277

268-
loss, _ = dpo_lib.dpo_loss_fn(model, train_example, 0.1, 0.3)
278+
loss, _ = dpo_lib.dpo_loss_fn(
279+
model, train_example, beta=0.1, label_smoothing=0.3
280+
)
269281
np.testing.assert_allclose(loss, 0.925447, atol=1e-5)
270282

271283
def test_dpo_prepare_inputs_for_strings(self):
272284
tokenizer = tc.MockVocab()
273285

274286
model = tc.ToyTransformer(
275-
rngs=nnx.Rngs(0), vocab_size=tokenizer.GetPieceSize()
287+
config=tc.ModelConfig(vocab_size=tokenizer.GetPieceSize()),
288+
rngs=nnx.Rngs(0),
276289
)
277290
ref_model = tc.ToyTransformer(
278-
rngs=nnx.Rngs(0), vocab_size=tokenizer.GetPieceSize()
291+
config=tc.ModelConfig(vocab_size=tokenizer.GetPieceSize()),
292+
rngs=nnx.Rngs(0),
279293
)
280294
dpo_trainer = dpo_lib.DPOTrainer(
281295
model=model,
@@ -328,8 +342,8 @@ def test_dpo_prepare_inputs_for_strings(self):
328342
self.assertEqual(out.logits_to_keep, 3)
329343

330344
def test_dpo_prepare_inputs(self):
331-
model = tc.ToyTransformer(rngs=nnx.Rngs(0))
332-
ref_model = tc.ToyTransformer(rngs=nnx.Rngs(0))
345+
model = tc.ToyTransformer(config=tc.ModelConfig(), rngs=nnx.Rngs(0))
346+
ref_model = tc.ToyTransformer(config=tc.ModelConfig(), rngs=nnx.Rngs(0))
333347
dpo_trainer = dpo_lib.DPOTrainer(
334348
model=model,
335349
ref_model=ref_model,

Comments (0)