Commit 4970237

Author: The tunix Authors
Merge pull request #594 from eghuzefa:feat-add-orpo-support
PiperOrigin-RevId: 834058377
2 parents e7a9905 + 9e99644, commit 4970237

File tree: 4 files changed, +531 -54 lines

tests/sft/dpo/dpo_trainer_test.py

Lines changed: 6 additions & 2 deletions
@@ -270,10 +270,14 @@ def test_dpo_loss_fn(self):
     with mock.patch.object(
         common, "get_per_token_logps", return_value=jnp.array(per_token_logps)
     ):
-      loss, _ = dpo_lib.dpo_loss_fn(model, train_example, 0.1, 0)
+      loss, _ = dpo_lib.dpo_loss_fn(
+          model, train_example, beta=0.1, label_smoothing=0
+      )
       np.testing.assert_allclose(loss, 0.753059, atol=1e-5)

-      loss, _ = dpo_lib.dpo_loss_fn(model, train_example, 0.1, 0.3)
+      loss, _ = dpo_lib.dpo_loss_fn(
+          model, train_example, beta=0.1, label_smoothing=0.3
+      )
       np.testing.assert_allclose(loss, 0.925447, atol=1e-5)

   def test_dpo_prepare_inputs_for_strings(self):
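The behavioral point these updated assertions pin down is that dpo_loss_fn now takes beta and label_smoothing as keyword arguments. For readers unfamiliar with the second knob, below is a minimal JAX sketch of a label-smoothed ("conservative") DPO loss; the helper name and signature are illustrative only and are not the tunix implementation, whose expected values come from dpo_lib.dpo_loss_fn itself.

import jax.numpy as jnp
from jax.nn import log_sigmoid


def dpo_loss_sketch(
    chosen_logps,        # policy log-probs of chosen responses, shape [B]
    rejected_logps,      # policy log-probs of rejected responses, shape [B]
    ref_chosen_logps,    # reference-model log-probs of chosen responses
    ref_rejected_logps,  # reference-model log-probs of rejected responses
    beta=0.1,
    label_smoothing=0.0,
):
  """Illustrative label-smoothed DPO loss (not the tunix implementation)."""
  logits = beta * (
      (chosen_logps - rejected_logps)
      - (ref_chosen_logps - ref_rejected_logps)
  )
  # label_smoothing blends in the loss for flipped preference labels;
  # with label_smoothing=0 this reduces to the standard DPO sigmoid loss.
  losses = (
      -log_sigmoid(logits) * (1.0 - label_smoothing)
      - log_sigmoid(-logits) * label_smoothing
  )
  return losses.mean()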

tests/sft/dpo/orpo_trainer_test.py

Lines changed: 361 additions & 0 deletions
@@ -0,0 +1,361 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from unittest import mock
from absl.testing import absltest
from absl.testing import parameterized
from flax import nnx
from grain import python as grain
import jax
import jax.numpy as jnp
import numpy as np
import optax
from tunix.rl import common
from tunix.sft.dpo import dpo_trainer as orpo_lib
from tunix.tests import test_common as tc

jax.config.update("jax_threefry_partitionable", False)
# jax.config.update("jax_debug_nans", True)  # useful for debugging NaN


class MySource(grain.RandomAccessDataSource):

  def __init__(self, data):
    self._data = data

  def __getitem__(self, idx):
    return self._data[idx]

  def __len__(self):
    return len(self._data)


def _dummy_dataset(
    source: MySource,
    prompt_ids: np.ndarray,
    prompt_mask: np.ndarray,
    chosen_ids: np.ndarray,
    chosen_mask: np.ndarray,
    rejected_ids: np.ndarray,
    rejected_mask: np.ndarray,
):
  return grain.MapDataset.source(source).map(
      lambda x: orpo_lib.TrainingInput(
          prompt_ids=prompt_ids,
          prompt_mask=prompt_mask,
          chosen_ids=chosen_ids,
          chosen_mask=chosen_mask,
          rejected_ids=rejected_ids,
          rejected_mask=rejected_mask,
      )
  )


def _dummy_string_dataset(
    source: MySource,
    prompts: np.ndarray,
    chosen_responses: np.ndarray,
    rejected_responses: np.ndarray,
    return_dict=False,
):
  ds = grain.MapDataset.source(source)
  if return_dict:
    return ds.map(
        lambda x: {
            "prompts": prompts,
            "chosen_responses": chosen_responses,
            "rejected_responses": rejected_responses,
        }
    )
  else:
    return ds.map(
        lambda x: orpo_lib.DataInput(
            prompts=prompts,
            chosen_responses=chosen_responses,
            rejected_responses=rejected_responses,
        )
    )


class ORPOTrainerTest(parameterized.TestCase):

  @parameterized.named_parameters(
      dict(
          testcase_name="basic_training",
          prompt_ids=np.arange(0, 10).reshape(2, 5),
          prompt_mask=np.ones((2, 5)),
          chosen_ids=np.arange(10, 20).reshape(2, 5),
          chosen_mask=np.ones((2, 5)),
          rejected_ids=np.arange(20, 30).reshape(2, 5),
          rejected_mask=np.ones((2, 5)),
      ),
  )
  def test_orpo_trainer(
      self,
      prompt_ids,
      prompt_mask,
      chosen_ids,
      chosen_mask,
      rejected_ids,
      rejected_mask,
  ):
    model = tc.ToyTransformer(config=tc.ModelConfig(), rngs=nnx.Rngs(0))
    original_variables = jax.tree.map(jnp.copy, nnx.state(model, nnx.Param))
    orpo_config = orpo_lib.ORPOTrainingConfig(
        algorithm="orpo",
        eval_every_n_steps=5,
        max_steps=10,
    )
    orpo_trainer = orpo_lib.ORPOTrainer(
        model=model,
        ref_model=None,
        optimizer=optax.sgd(1e-3),
        training_config=orpo_config,
    )
    train_ds = _dummy_dataset(
        MySource(np.arange(10)),
        prompt_ids,
        prompt_mask,
        chosen_ids,
        chosen_mask,
        rejected_ids,
        rejected_mask,
    )
    eval_ds = _dummy_dataset(
        MySource(np.arange(2)),
        prompt_ids,
        prompt_mask,
        chosen_ids,
        chosen_mask,
        rejected_ids,
        rejected_mask,
    )
    orpo_trainer.train(train_ds, eval_ds=eval_ds)

    variables = nnx.state(model, nnx.Param)
    jax.tree.map_with_path(tc.assert_not_equal, original_variables, variables)

    for metric_name in [
        "rewards/chosen",
        "rewards/rejected",
        "rewards/margin",
        "rewards/accuracy",
        "log_probs/chosen",
        "log_probs/rejected",
        "odds_ratio",
    ]:
      self.assertLen(
          orpo_trainer.metrics_logger.get_metric_history(
              "", metric_name, "train"
          ),
          orpo_trainer._train_steps,
      )
      self.assertLen(
          orpo_trainer.metrics_logger.get_metric_history("", metric_name, "eval"),
          3,
      )

  @parameterized.named_parameters(
      dict(
          testcase_name="dataclass_inputs",
          train_ds=_dummy_string_dataset(
              MySource(np.arange(10)),
              prompts=["Tunix", "Parallax"],
              chosen_responses=["PT", "distributed training"],
              rejected_responses=["optimizer library", "quantization"],
          ),
      ),
      dict(
          testcase_name="dict_inputs",
          train_ds=_dummy_string_dataset(
              MySource(np.arange(10)),
              prompts=["Tunix", "Parallax"],
              chosen_responses=["PT", "distributed training"],
              rejected_responses=["optimizer library", "quantization"],
              return_dict=True,
          ),
      ),
  )
  def test_orpo_trainer_with_string_inputs(self, train_ds):
    tokenizer = tc.MockVocab()
    model = tc.ToyTransformer(
        config=tc.ModelConfig(vocab_size=tokenizer.GetPieceSize()),
        rngs=nnx.Rngs(0),
    )
    original_variables = jax.tree.map(jnp.copy, nnx.state(model, nnx.Param))
    orpo_config = orpo_lib.ORPOTrainingConfig(
        algorithm="orpo",
        eval_every_n_steps=10,
        max_steps=10,
        max_prompt_length=3,
        max_response_length=3,
    )
    orpo_trainer = orpo_lib.ORPOTrainer(
        model=model,
        ref_model=None,
        optimizer=optax.sgd(1e-3),
        training_config=orpo_config,
        tokenizer=tokenizer,
    )
    orpo_trainer.train(train_ds, None)

    variables = nnx.state(model, nnx.Param)
    jax.tree.map_with_path(tc.assert_not_equal, original_variables, variables)

    for metric_name in [
        "rewards/chosen",
        "rewards/rejected",
        "rewards/margin",
        "rewards/accuracy",
    ]:
      self.assertLen(
          orpo_trainer.metrics_logger.get_metric_history(
              "", metric_name, "train"
          ),
          orpo_trainer._train_steps,
      )

  def test_orpo_loss_fn(self):
    """Test ORPO loss function directly with mocked logps."""
    np.random.seed(0)
    model = tc.ToyTransformer(config=tc.ModelConfig(), rngs=nnx.Rngs(0))
    # Use negative log probs (as they should be in reality).
    per_token_logps = -np.abs(np.random.normal(2, 1, size=(8, 4)))
    train_example = orpo_lib.TrainExample(
        input_ids=jnp.arange(0, 32).reshape(8, 4),
        positions=jnp.ones((8, 4)),
        attention_mask=jnp.ones((8, 4, 4)),
        ref_chosen_logps=None,
        ref_rejected_logps=None,
        completion_mask=jnp.ones((8, 4)),
        logits_to_keep=4,
    )

    with mock.patch.object(
        common,
        "get_per_token_logps",
        return_value=jnp.array(per_token_logps),
    ):
      loss, aux = orpo_lib.dpo_loss_fn(
          model,
          train_example,
          algorithm="orpo",
          lambda_orpo=0.1,
          label_smoothing=0,
      )
    # Loss should be a scalar and finite.
    self.assertEqual(loss.shape, ())
    self.assertTrue(jnp.isfinite(loss))

    # Check that aux metrics exist.
    self.assertIn("rewards/chosen", aux)
    self.assertIn("rewards/rejected", aux)
    self.assertIn("rewards/margin", aux)
    self.assertIn("rewards/accuracy", aux)
    self.assertIn("log_probs/chosen", aux)
    self.assertIn("log_probs/rejected", aux)
    self.assertIn("odds_ratio", aux)

    # Check that accuracy is between 0 and 1.
    self.assertGreaterEqual(aux["rewards/accuracy"], 0.0)
    self.assertLessEqual(aux["rewards/accuracy"], 1.0)

  def test_orpo_prepare_inputs_for_strings(self):
    tokenizer = tc.MockVocab()

    model = tc.ToyTransformer(
        config=tc.ModelConfig(vocab_size=tokenizer.GetPieceSize()),
        rngs=nnx.Rngs(0),
    )
    orpo_trainer = orpo_lib.ORPOTrainer(
        model=model,
        ref_model=None,
        optimizer=optax.sgd(1e-3),
        training_config=orpo_lib.ORPOTrainingConfig(
            algorithm="orpo",
            eval_every_n_steps=10,
            max_steps=10,
            max_prompt_length=3,
            max_response_length=3,
        ),
        tokenizer=tokenizer,
    )

    # These are random strings, they hold no meaning.
    training_input = orpo_lib.DataInput(
        prompts=["Tunix", "Parallax"],
        chosen_responses=["PT", "distributed training"],
        rejected_responses=["optimizer library", "quantization"],
    )
    out = orpo_trainer._prepare_inputs(training_input)

    expected_input_ids = np.array([
        [0, 1, 14, 1, 16, 0],
        [0, 1, 15, 1, 18, 19],
        [0, 1, 14, 1, 20, 17],
        [0, 1, 15, 1, 21, 0],
    ])
    np.testing.assert_array_equal(out.input_ids, expected_input_ids)
    self.assertEqual(np.sum(out.attention_mask[0]), 14)
    self.assertEqual(np.sum(out.attention_mask[1]), 15)
    self.assertEqual(np.sum(out.attention_mask[2]), 15)
    self.assertEqual(np.sum(out.attention_mask[3]), 14)
    expected_completion_mask = np.array(
        [[1, 1, 0], [1, 1, 1], [1, 1, 1], [1, 1, 0]]
    )
    np.testing.assert_array_equal(out.completion_mask, expected_completion_mask)
    self.assertEqual(out.logits_to_keep, 3)

  def test_orpo_prepare_inputs(self):
    model = tc.ToyTransformer(config=tc.ModelConfig(), rngs=nnx.Rngs(0))
    orpo_trainer = orpo_lib.ORPOTrainer(
        model=model,
        ref_model=None,
        optimizer=optax.sgd(1e-3),
        training_config=orpo_lib.ORPOTrainingConfig(
            algorithm="orpo",
            eval_every_n_steps=10,
            max_steps=10,
        ),
    )

    training_input = orpo_lib.TrainingInput(
        prompt_ids=np.array([[1, 2, 3, 4, 5], [0, 0, 1, 2, 3]]),
        prompt_mask=np.array([[1, 1, 1, 1, 1], [0, 0, 1, 1, 1]]),
        chosen_ids=np.array([[10, 11, 12, 0], [13, 14, 15, 16]]),
        chosen_mask=np.array([[1, 1, 1, 0], [1, 1, 1, 1]]),
        rejected_ids=np.array([[20, 21, 22, 0], [23, 0, 0, 0]]),
        rejected_mask=np.array([[1, 1, 1, 0], [1, 0, 0, 0]]),
    )
    out = orpo_trainer._prepare_inputs(training_input)
    expected_input_ids = np.array([
        [1, 2, 3, 4, 5, 10, 11, 12, 0],
        [0, 0, 1, 2, 3, 13, 14, 15, 16],
        [1, 2, 3, 4, 5, 20, 21, 22, 0],
        [0, 0, 1, 2, 3, 23, 0, 0, 0],
    ])
    np.testing.assert_array_equal(out.input_ids, expected_input_ids)
    self.assertEqual(np.sum(out.attention_mask[0]), 44)
    self.assertEqual(np.sum(out.attention_mask[1]), 28)
    self.assertEqual(np.sum(out.attention_mask[2]), 44)
    self.assertEqual(np.sum(out.attention_mask[3]), 22)
    expected_completion_mask = np.array(
        [[1, 1, 1, 0], [1, 1, 1, 1], [1, 1, 1, 0], [1, 0, 0, 0]]
    )
    np.testing.assert_array_equal(out.completion_mask, expected_completion_mask)
    self.assertEqual(out.logits_to_keep, 4)


if __name__ == "__main__":
  absltest.main()
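The odds_ratio and log_probs/* metrics asserted above correspond to the ORPO objective (Hong et al., 2024): a standard NLL term on the chosen response plus a lambda_orpo-weighted penalty on the log odds ratio between chosen and rejected responses. The sketch below illustrates that formula under the assumption of length-normalized (average per-token) sequence log-probs; it is not the tunix dpo_loss_fn code, and the helper name is hypothetical.

import jax.numpy as jnp
from jax.nn import log_sigmoid


def orpo_loss_sketch(chosen_logps, rejected_logps, lambda_orpo=0.1):
  """Illustrative ORPO loss on average (length-normalized) log-probs < 0."""
  # log odds(p) = log p - log(1 - p); log1p(-exp(.)) keeps this numerically stable.
  log_odds_chosen = chosen_logps - jnp.log1p(-jnp.exp(chosen_logps))
  log_odds_rejected = rejected_logps - jnp.log1p(-jnp.exp(rejected_logps))
  log_odds_ratio = log_odds_chosen - log_odds_rejected

  nll_loss = -chosen_logps.mean()                # SFT term on chosen responses
  or_loss = -log_sigmoid(log_odds_ratio).mean()  # odds-ratio penalty
  return nll_loss + lambda_orpo * or_loss

Because the penalty depends only on the policy's own log-probs, ORPO needs no reference model, which is why the tests above construct the trainer with ref_model=None and pass ref_chosen_logps=None / ref_rejected_logps=None.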

tunix/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -53,6 +53,10 @@
 from tunix.sft.dpo.dpo_trainer import DpoTrainer
 from tunix.sft.dpo.dpo_trainer import DPOTrainingConfig
 from tunix.sft.dpo.dpo_trainer import DpoTrainingConfig
+from tunix.sft.dpo.dpo_trainer import ORPOTrainer
+from tunix.sft.dpo.dpo_trainer import OrpoTrainer
+from tunix.sft.dpo.dpo_trainer import ORPOTrainingConfig
+from tunix.sft.dpo.dpo_trainer import OrpoTrainingConfig
 from tunix.sft.metrics_logger import MetricsLogger
 from tunix.sft.metrics_logger import MetricsLoggerOptions
 from tunix.sft.peft_trainer import PeftTrainer
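With these exports in place, a training run can be wired up directly from the top-level package. The snippet below is a minimal sketch assembled from the tests in this commit; model, tokenizer, and train_ds are placeholders (the tests build toy versions via tunix.tests.test_common).

import optax
from tunix import ORPOTrainer, ORPOTrainingConfig

config = ORPOTrainingConfig(
    algorithm="orpo",
    eval_every_n_steps=10,
    max_steps=10,
    max_prompt_length=3,
    max_response_length=3,
)
trainer = ORPOTrainer(
    model=model,               # placeholder: any supported nnx model
    ref_model=None,            # ORPO does not use a reference model
    optimizer=optax.sgd(1e-3),
    training_config=config,
    tokenizer=tokenizer,       # placeholder: needed for string inputs
)
trainer.train(train_ds, None)  # train_ds yields prompts/chosen/rejected examples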
