Commit 5f10c23

feat: add ORPO support
1 parent 783b1e7 commit 5f10c23

File tree

4 files changed: +527 -54 lines changed

tests/sft/dpo/dpo_trainer_test.py

Lines changed: 6 additions & 2 deletions
@@ -270,10 +270,14 @@ def test_dpo_loss_fn(self):
     with mock.patch.object(
         common, "get_per_token_logps", return_value=jnp.array(per_token_logps)
     ):
-      loss, _ = dpo_lib.dpo_loss_fn(model, train_example, 0.1, 0)
+      loss, _ = dpo_lib.dpo_loss_fn(
+          model, train_example, beta=0.1, label_smoothing=0
+      )
       np.testing.assert_allclose(loss, 0.753059, atol=1e-5)
 
-      loss, _ = dpo_lib.dpo_loss_fn(model, train_example, 0.1, 0.3)
+      loss, _ = dpo_lib.dpo_loss_fn(
+          model, train_example, beta=0.1, label_smoothing=0.3
+      )
       np.testing.assert_allclose(loss, 0.925447, atol=1e-5)
 
   def test_dpo_prepare_inputs_for_strings(self):
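
Note: the loss-function calls above now pass beta and label_smoothing by keyword rather than positionally, which keeps the call sites unambiguous now that the same entry point also serves ORPO (see the new test file below). The signature sketched here is inferred only from the call sites in this commit, not taken from the dpo_trainer source; parameter ordering, the keyword-only marker, and the defaults are assumptions.

# Assumed shape of the shared loss entry point, inferred from the test call
# sites in this commit. Parameter names come from the calls; ordering and
# defaults are illustrative, not the actual Tunix API.
def dpo_loss_fn(
    model,
    train_example,
    *,
    algorithm="dpo",        # "dpo" or "orpo", per the ORPO test below
    beta=0.1,               # DPO temperature, passed by keyword above
    lambda_orpo=1.0,        # weight of the ORPO odds-ratio term
    label_smoothing=0.0,
):
  ...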

tests/sft/dpo/orpo_trainer_test.py

Lines changed: 357 additions & 0 deletions
@@ -0,0 +1,357 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from unittest import mock
from absl.testing import absltest
from absl.testing import parameterized
from flax import nnx
from grain import python as grain
import jax
import jax.numpy as jnp
import numpy as np
import optax
from tunix.rl import common
from tunix.sft.dpo import dpo_trainer as orpo_lib
from tunix.tests import test_common as tc

jax.config.update("jax_threefry_partitionable", False)
# jax.config.update("jax_debug_nans", True)  # useful for debugging NaN


class MySource(grain.RandomAccessDataSource):

  def __init__(self, data):
    self._data = data

  def __getitem__(self, idx):
    return self._data[idx]

  def __len__(self):
    return len(self._data)


def _dummy_dataset(
    source: MySource,
    prompt_ids: np.ndarray,
    prompt_mask: np.ndarray,
    chosen_ids: np.ndarray,
    chosen_mask: np.ndarray,
    rejected_ids: np.ndarray,
    rejected_mask: np.ndarray,
):
  return grain.MapDataset.source(source).map(
      lambda x: orpo_lib.TrainingInput(
          prompt_ids=prompt_ids,
          prompt_mask=prompt_mask,
          chosen_ids=chosen_ids,
          chosen_mask=chosen_mask,
          rejected_ids=rejected_ids,
          rejected_mask=rejected_mask,
      )
  )


def _dummy_string_dataset(
    source: MySource,
    prompts: np.ndarray,
    chosen_responses: np.ndarray,
    rejected_responses: np.ndarray,
    return_dict=False,
):
  ds = grain.MapDataset.source(source)
  if return_dict:
    return ds.map(
        lambda x: {
            "prompts": prompts,
            "chosen_responses": chosen_responses,
            "rejected_responses": rejected_responses,
        }
    )
  else:
    return ds.map(
        lambda x: orpo_lib.DataInput(
            prompts=prompts,
            chosen_responses=chosen_responses,
            rejected_responses=rejected_responses,
        )
    )


class ORPOTrainerTest(parameterized.TestCase):

  @parameterized.named_parameters(
      dict(
          testcase_name="basic_training",
          prompt_ids=np.arange(0, 10).reshape(2, 5),
          prompt_mask=np.ones((2, 5)),
          chosen_ids=np.arange(10, 20).reshape(2, 5),
          chosen_mask=np.ones((2, 5)),
          rejected_ids=np.arange(20, 30).reshape(2, 5),
          rejected_mask=np.ones((2, 5)),
      ),
  )
  def test_orpo_trainer(
      self,
      prompt_ids,
      prompt_mask,
      chosen_ids,
      chosen_mask,
      rejected_ids,
      rejected_mask,
  ):
    model = tc.ToyTransformer(config=tc.ModelConfig(), rngs=nnx.Rngs(0))
    original_variables = jax.tree.map(jnp.copy, nnx.state(model, nnx.Param))
    orpo_config = orpo_lib.ORPOTrainingConfig(
        algorithm="orpo",
        eval_every_n_steps=5,
        max_steps=10,
    )
    orpo_trainer = orpo_lib.ORPOTrainer(
        model=model,
        ref_model=None,
        optimizer=optax.sgd(1e-3),
        training_config=orpo_config,
    )
    train_ds = _dummy_dataset(
        MySource(np.arange(10)),
        prompt_ids,
        prompt_mask,
        chosen_ids,
        chosen_mask,
        rejected_ids,
        rejected_mask,
    )
    eval_ds = _dummy_dataset(
        MySource(np.arange(2)),
        prompt_ids,
        prompt_mask,
        chosen_ids,
        chosen_mask,
        rejected_ids,
        rejected_mask,
    )
    orpo_trainer.train(train_ds, eval_ds=eval_ds)

    variables = nnx.state(model, nnx.Param)
    jax.tree.map_with_path(tc.assert_not_equal, original_variables, variables)

    for metric_name in [
        "rewards/chosen",
        "rewards/rejected",
        "rewards/margin",
        "rewards/accuracy",
        "log_probs/chosen",
        "log_probs/rejected",
        "odds_ratio",
    ]:
      self.assertLen(
          orpo_trainer.metrics_logger.get_metric_history(metric_name, "train"),
          orpo_trainer._train_steps,
      )
      self.assertLen(
          orpo_trainer.metrics_logger.get_metric_history(metric_name, "eval"),
          3,
      )

  @parameterized.named_parameters(
      dict(
          testcase_name="dataclass_inputs",
          train_ds=_dummy_string_dataset(
              MySource(np.arange(10)),
              prompts=["Tunix", "Parallax"],
              chosen_responses=["PT", "distributed training"],
              rejected_responses=["optimizer library", "quantization"],
          ),
      ),
      dict(
          testcase_name="dict_inputs",
          train_ds=_dummy_string_dataset(
              MySource(np.arange(10)),
              prompts=["Tunix", "Parallax"],
              chosen_responses=["PT", "distributed training"],
              rejected_responses=["optimizer library", "quantization"],
              return_dict=True,
          ),
      ),
  )
  def test_orpo_trainer_with_string_inputs(self, train_ds):
    tokenizer = tc.MockVocab()
    model = tc.ToyTransformer(
        config=tc.ModelConfig(vocab_size=tokenizer.GetPieceSize()),
        rngs=nnx.Rngs(0),
    )
    original_variables = jax.tree.map(jnp.copy, nnx.state(model, nnx.Param))
    orpo_config = orpo_lib.ORPOTrainingConfig(
        algorithm="orpo",
        eval_every_n_steps=10,
        max_steps=10,
        max_prompt_length=3,
        max_response_length=3,
    )
    orpo_trainer = orpo_lib.ORPOTrainer(
        model=model,
        ref_model=None,
        optimizer=optax.sgd(1e-3),
        training_config=orpo_config,
        tokenizer=tokenizer,
    )
    orpo_trainer.train(train_ds, None)

    variables = nnx.state(model, nnx.Param)
    jax.tree.map_with_path(tc.assert_not_equal, original_variables, variables)

    for metric_name in [
        "rewards/chosen",
        "rewards/rejected",
        "rewards/margin",
        "rewards/accuracy",
    ]:
      self.assertLen(
          orpo_trainer.metrics_logger.get_metric_history(metric_name, "train"),
          orpo_trainer._train_steps,
      )

  def test_orpo_loss_fn(self):
    """Test ORPO loss function directly with mocked logps."""
    np.random.seed(0)
    model = tc.ToyTransformer(config=tc.ModelConfig(), rngs=nnx.Rngs(0))
    # Use negative log probs (as they should be in reality)
    per_token_logps = -np.abs(np.random.normal(2, 1, size=(8, 4)))
    train_example = orpo_lib.TrainExample(
        input_ids=jnp.arange(0, 32).reshape(8, 4),
        positions=jnp.ones((8, 4)),
        attention_mask=jnp.ones((8, 4, 4)),
        ref_chosen_logps=None,
        ref_rejected_logps=None,
        completion_mask=jnp.ones((8, 4)),
        logits_to_keep=4,
    )

    with mock.patch.object(
        common,
        "get_per_token_logps",
        return_value=jnp.array(per_token_logps),
    ):
      loss, aux = orpo_lib.dpo_loss_fn(
          model,
          train_example,
          algorithm="orpo",
          lambda_orpo=0.1,
          label_smoothing=0,
      )
    # Loss should be a scalar and finite
    self.assertEqual(loss.shape, ())
    self.assertTrue(jnp.isfinite(loss))

    # Check that aux metrics exist
    self.assertIn("rewards/chosen", aux)
    self.assertIn("rewards/rejected", aux)
    self.assertIn("rewards/margin", aux)
    self.assertIn("rewards/accuracy", aux)
    self.assertIn("log_probs/chosen", aux)
    self.assertIn("log_probs/rejected", aux)
    self.assertIn("odds_ratio", aux)

    # Check that accuracy is between 0 and 1
    self.assertGreaterEqual(aux["rewards/accuracy"], 0.0)
    self.assertLessEqual(aux["rewards/accuracy"], 1.0)

  def test_orpo_prepare_inputs_for_strings(self):
    tokenizer = tc.MockVocab()

    model = tc.ToyTransformer(
        config=tc.ModelConfig(vocab_size=tokenizer.GetPieceSize()),
        rngs=nnx.Rngs(0),
    )
    orpo_trainer = orpo_lib.ORPOTrainer(
        model=model,
        ref_model=None,
        optimizer=optax.sgd(1e-3),
        training_config=orpo_lib.ORPOTrainingConfig(
            algorithm="orpo",
            eval_every_n_steps=10,
            max_steps=10,
            max_prompt_length=3,
            max_response_length=3,
        ),
        tokenizer=tokenizer,
    )

    # These are random strings, they hold no meaning.
    training_input = orpo_lib.DataInput(
        prompts=["Tunix", "Parallax"],
        chosen_responses=["PT", "distributed training"],
        rejected_responses=["optimizer library", "quantization"],
    )
    out = orpo_trainer._prepare_inputs(training_input)

    expected_input_ids = np.array([
        [0, 1, 14, 1, 16, 0],
        [0, 1, 15, 1, 18, 19],
        [0, 1, 14, 1, 20, 17],
        [0, 1, 15, 1, 21, 0],
    ])
    np.testing.assert_array_equal(out.input_ids, expected_input_ids)
    self.assertEqual(np.sum(out.attention_mask[0]), 14)
    self.assertEqual(np.sum(out.attention_mask[1]), 15)
    self.assertEqual(np.sum(out.attention_mask[2]), 15)
    self.assertEqual(np.sum(out.attention_mask[3]), 14)
    expected_completion_mask = np.array(
        [[1, 1, 0], [1, 1, 1], [1, 1, 1], [1, 1, 0]]
    )
    np.testing.assert_array_equal(out.completion_mask, expected_completion_mask)
    self.assertEqual(out.logits_to_keep, 3)

  def test_orpo_prepare_inputs(self):
    model = tc.ToyTransformer(config=tc.ModelConfig(), rngs=nnx.Rngs(0))
    orpo_trainer = orpo_lib.ORPOTrainer(
        model=model,
        ref_model=None,
        optimizer=optax.sgd(1e-3),
        training_config=orpo_lib.ORPOTrainingConfig(
            algorithm="orpo",
            eval_every_n_steps=10,
            max_steps=10,
        ),
    )

    training_input = orpo_lib.TrainingInput(
        prompt_ids=np.array([[1, 2, 3, 4, 5], [0, 0, 1, 2, 3]]),
        prompt_mask=np.array([[1, 1, 1, 1, 1], [0, 0, 1, 1, 1]]),
        chosen_ids=np.array([[10, 11, 12, 0], [13, 14, 15, 16]]),
        chosen_mask=np.array([[1, 1, 1, 0], [1, 1, 1, 1]]),
        rejected_ids=np.array([[20, 21, 22, 0], [23, 0, 0, 0]]),
        rejected_mask=np.array([[1, 1, 1, 0], [1, 0, 0, 0]]),
    )
    out = orpo_trainer._prepare_inputs(training_input)
    expected_input_ids = np.array([
        [1, 2, 3, 4, 5, 10, 11, 12, 0],
        [0, 0, 1, 2, 3, 13, 14, 15, 16],
        [1, 2, 3, 4, 5, 20, 21, 22, 0],
        [0, 0, 1, 2, 3, 23, 0, 0, 0],
    ])
    np.testing.assert_array_equal(out.input_ids, expected_input_ids)
    self.assertEqual(np.sum(out.attention_mask[0]), 44)
    self.assertEqual(np.sum(out.attention_mask[1]), 28)
    self.assertEqual(np.sum(out.attention_mask[2]), 44)
    self.assertEqual(np.sum(out.attention_mask[3]), 22)
    expected_completion_mask = np.array(
        [[1, 1, 1, 0], [1, 1, 1, 1], [1, 1, 1, 0], [1, 0, 0, 0]]
    )
    np.testing.assert_array_equal(out.completion_mask, expected_completion_mask)
    self.assertEqual(out.logits_to_keep, 4)


if __name__ == "__main__":
  absltest.main()
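
For context on what these tests exercise: ORPO (odds-ratio preference optimization) augments the NLL loss on the chosen response with a weighted odds-ratio penalty and, unlike DPO, needs no reference model, which is why the tests pass ref_model=None and leave ref_chosen_logps/ref_rejected_logps unset. The sketch below shows the standard ORPO objective that metrics such as log_probs/chosen, log_probs/rejected, and odds_ratio correspond to; orpo_odds_ratio_loss is an illustrative helper written for this note, not a Tunix function, and it assumes per-sequence mean log-probabilities as inputs.

import jax
import jax.numpy as jnp


def orpo_odds_ratio_loss(chosen_logps, rejected_logps, lambda_orpo=0.1):
  """Illustrative ORPO objective (not the Tunix implementation).

  Args:
    chosen_logps: mean per-token log-probs of the chosen responses, shape [B].
    rejected_logps: mean per-token log-probs of the rejected responses, [B].
    lambda_orpo: weight of the odds-ratio term (0.1 in the test above).
  """
  # log odds(y|x) = log p - log(1 - p), computed in log space for stability.
  log_odds_chosen = chosen_logps - jnp.log1p(-jnp.exp(chosen_logps))
  log_odds_rejected = rejected_logps - jnp.log1p(-jnp.exp(rejected_logps))
  # Roughly what the "odds_ratio" metric above tracks.
  log_odds_ratio = log_odds_chosen - log_odds_rejected
  ratio_loss = -jax.nn.log_sigmoid(log_odds_ratio)  # prefer chosen over rejected
  nll_loss = -chosen_logps  # SFT term on the chosen response
  return jnp.mean(nll_loss + lambda_orpo * ratio_loss)

Because the penalty depends only on the policy's own log-probabilities, no frozen reference model has to be kept in memory, which is the practical difference from the DPO path tested in dpo_trainer_test.py.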

tunix/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -54,6 +54,10 @@
 from tunix.sft.dpo.dpo_trainer import DpoTrainer
 from tunix.sft.dpo.dpo_trainer import DPOTrainingConfig
 from tunix.sft.dpo.dpo_trainer import DpoTrainingConfig
+from tunix.sft.dpo.dpo_trainer import ORPOTrainer
+from tunix.sft.dpo.dpo_trainer import OrpoTrainer
+from tunix.sft.dpo.dpo_trainer import ORPOTrainingConfig
+from tunix.sft.dpo.dpo_trainer import OrpoTrainingConfig
 from tunix.sft.metrics_logger import MetricsLogger
 from tunix.sft.metrics_logger import MetricsLoggerOptions
 from tunix.sft.peft_trainer import PeftTrainer
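
With these exports in place, the new trainer is reachable from the top-level package. A hypothetical usage sketch mirroring the test setup above follows; the model, tokenizer, and train_ds are placeholders standing in for whatever the caller provides, and only the names added in this commit are assumed to exist.

import optax
from tunix import ORPOTrainer, ORPOTrainingConfig

trainer = ORPOTrainer(
    model=model,                # any nnx model, e.g. tc.ToyTransformer in the tests
    ref_model=None,             # ORPO does not use a reference model
    optimizer=optax.sgd(1e-3),
    training_config=ORPOTrainingConfig(
        algorithm="orpo",
        eval_every_n_steps=10,
        max_steps=10,
        max_prompt_length=3,
        max_response_length=3,
    ),
    tokenizer=tokenizer,        # needed when the dataset yields raw strings
)
# train_ds yields DataInput instances (or dicts) of prompt / chosen / rejected
# strings, as in the string-input tests above.
trainer.train(train_ds, None)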
