|
| 1 | +# Copyright 2025 Google LLC |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +from unittest import mock |
| 16 | +from absl.testing import absltest |
| 17 | +from absl.testing import parameterized |
| 18 | +from flax import nnx |
| 19 | +from grain import python as grain |
| 20 | +import jax |
| 21 | +import jax.numpy as jnp |
| 22 | +import numpy as np |
| 23 | +import optax |
| 24 | +from tunix.rl import common |
| 25 | +from tunix.sft.dpo import dpo_trainer as orpo_lib |
| 26 | +from tunix.tests import test_common as tc |
| 27 | + |
| 28 | +jax.config.update("jax_threefry_partitionable", False) |
| 29 | +# jax.config.update("jax_debug_nans", True) # useful for debugging NaN |
| 30 | + |
| 31 | + |
| 32 | +class MySource(grain.RandomAccessDataSource): |
| 33 | + |
| 34 | + def __init__(self, data): |
| 35 | + self._data = data |
| 36 | + |
| 37 | + def __getitem__(self, idx): |
| 38 | + return self._data[idx] |
| 39 | + |
| 40 | + def __len__(self): |
| 41 | + return len(self._data) |
| 42 | + |
| 43 | + |
| 44 | +def _dummy_dataset( |
| 45 | + source: MySource, |
| 46 | + prompt_ids: np.ndarray, |
| 47 | + prompt_mask: np.ndarray, |
| 48 | + chosen_ids: np.ndarray, |
| 49 | + chosen_mask: np.ndarray, |
| 50 | + rejected_ids: np.ndarray, |
| 51 | + rejected_mask: np.ndarray, |
| 52 | +): |
| 53 | + return grain.MapDataset.source(source).map( |
| 54 | + lambda x: orpo_lib.TrainingInput( |
| 55 | + prompt_ids=prompt_ids, |
| 56 | + prompt_mask=prompt_mask, |
| 57 | + chosen_ids=chosen_ids, |
| 58 | + chosen_mask=chosen_mask, |
| 59 | + rejected_ids=rejected_ids, |
| 60 | + rejected_mask=rejected_mask, |
| 61 | + ) |
| 62 | + ) |
| 63 | + |
| 64 | + |
| 65 | +def _dummy_string_dataset( |
| 66 | + source: MySource, |
| 67 | + prompts: np.ndarray, |
| 68 | + chosen_responses: np.ndarray, |
| 69 | + rejected_responses: np.ndarray, |
| 70 | + return_dict=False, |
| 71 | +): |
| 72 | + ds = grain.MapDataset.source(source) |
| 73 | + if return_dict: |
| 74 | + return ds.map( |
| 75 | + lambda x: { |
| 76 | + "prompts": prompts, |
| 77 | + "chosen_responses": chosen_responses, |
| 78 | + "rejected_responses": rejected_responses, |
| 79 | + } |
| 80 | + ) |
| 81 | + else: |
| 82 | + return ds.map( |
| 83 | + lambda x: orpo_lib.DataInput( |
| 84 | + prompts=prompts, |
| 85 | + chosen_responses=chosen_responses, |
| 86 | + rejected_responses=rejected_responses, |
| 87 | + ) |
| 88 | + ) |
| 89 | + |
| 90 | + |
| 91 | +class ORPOTrainerTest(parameterized.TestCase): |
| 92 | + |
| 93 | + @parameterized.named_parameters( |
| 94 | + dict( |
| 95 | + testcase_name="basic_training", |
| 96 | + prompt_ids=np.arange(0, 10).reshape(2, 5), |
| 97 | + prompt_mask=np.ones((2, 5)), |
| 98 | + chosen_ids=np.arange(10, 20).reshape(2, 5), |
| 99 | + chosen_mask=np.ones((2, 5)), |
| 100 | + rejected_ids=np.arange(20, 30).reshape(2, 5), |
| 101 | + rejected_mask=np.ones((2, 5)), |
| 102 | + ), |
| 103 | + ) |
| 104 | + def test_orpo_trainer( |
| 105 | + self, |
| 106 | + prompt_ids, |
| 107 | + prompt_mask, |
| 108 | + chosen_ids, |
| 109 | + chosen_mask, |
| 110 | + rejected_ids, |
| 111 | + rejected_mask, |
| 112 | + ): |
| 113 | + model = tc.ToyTransformer(config=tc.ModelConfig(), rngs=nnx.Rngs(0)) |
| 114 | + original_variables = jax.tree.map(jnp.copy, nnx.state(model, nnx.Param)) |
| 115 | + orpo_config = orpo_lib.ORPOTrainingConfig( |
| 116 | + algorithm="orpo", |
| 117 | + eval_every_n_steps=5, |
| 118 | + max_steps=10, |
| 119 | + ) |
| 120 | + orpo_trainer = orpo_lib.ORPOTrainer( |
| 121 | + model=model, |
| 122 | + ref_model=None, |
| 123 | + optimizer=optax.sgd(1e-3), |
| 124 | + training_config=orpo_config, |
| 125 | + ) |
| 126 | + train_ds = _dummy_dataset( |
| 127 | + MySource(np.arange(10)), |
| 128 | + prompt_ids, |
| 129 | + prompt_mask, |
| 130 | + chosen_ids, |
| 131 | + chosen_mask, |
| 132 | + rejected_ids, |
| 133 | + rejected_mask, |
| 134 | + ) |
| 135 | + eval_ds = _dummy_dataset( |
| 136 | + MySource(np.arange(2)), |
| 137 | + prompt_ids, |
| 138 | + prompt_mask, |
| 139 | + chosen_ids, |
| 140 | + chosen_mask, |
| 141 | + rejected_ids, |
| 142 | + rejected_mask, |
| 143 | + ) |
| 144 | + orpo_trainer.train(train_ds, eval_ds=eval_ds) |
| 145 | + |
| 146 | + variables = nnx.state(model, nnx.Param) |
| 147 | + jax.tree.map_with_path(tc.assert_not_equal, original_variables, variables) |
| 148 | + |
| 149 | + for metric_name in [ |
| 150 | + "rewards/chosen", |
| 151 | + "rewards/rejected", |
| 152 | + "rewards/margin", |
| 153 | + "rewards/accuracy", |
| 154 | + "log_probs/chosen", |
| 155 | + "log_probs/rejected", |
| 156 | + "odds_ratio", |
| 157 | + ]: |
| 158 | + self.assertLen( |
| 159 | + orpo_trainer.metrics_logger.get_metric_history( |
| 160 | + "", metric_name, "train" |
| 161 | + ), |
| 162 | + orpo_trainer._train_steps, |
| 163 | + ) |
| 164 | + self.assertLen( |
| 165 | + orpo_trainer.metrics_logger.get_metric_history("", metric_name, "eval"), |
| 166 | + 3, |
| 167 | + ) |
| 168 | + |
| 169 | + @parameterized.named_parameters( |
| 170 | + dict( |
| 171 | + testcase_name="dataclass_inputs", |
| 172 | + train_ds=_dummy_string_dataset( |
| 173 | + MySource(np.arange(10)), |
| 174 | + prompts=["Tunix", "Parallax"], |
| 175 | + chosen_responses=["PT", "distributed training"], |
| 176 | + rejected_responses=["optimizer library", "quantization"], |
| 177 | + ), |
| 178 | + ), |
| 179 | + dict( |
| 180 | + testcase_name="dict_inputs", |
| 181 | + train_ds=_dummy_string_dataset( |
| 182 | + MySource(np.arange(10)), |
| 183 | + prompts=["Tunix", "Parallax"], |
| 184 | + chosen_responses=["PT", "distributed training"], |
| 185 | + rejected_responses=["optimizer library", "quantization"], |
| 186 | + return_dict=True, |
| 187 | + ), |
| 188 | + ), |
| 189 | + ) |
| 190 | + def test_orpo_trainer_with_string_inputs(self, train_ds): |
| 191 | + tokenizer = tc.MockVocab() |
| 192 | + model = tc.ToyTransformer( |
| 193 | + config=tc.ModelConfig(vocab_size=tokenizer.GetPieceSize()), |
| 194 | + rngs=nnx.Rngs(0), |
| 195 | + ) |
| 196 | + original_variables = jax.tree.map(jnp.copy, nnx.state(model, nnx.Param)) |
| 197 | + orpo_config = orpo_lib.ORPOTrainingConfig( |
| 198 | + algorithm="orpo", |
| 199 | + eval_every_n_steps=10, |
| 200 | + max_steps=10, |
| 201 | + max_prompt_length=3, |
| 202 | + max_response_length=3, |
| 203 | + ) |
| 204 | + orpo_trainer = orpo_lib.ORPOTrainer( |
| 205 | + model=model, |
| 206 | + ref_model=None, |
| 207 | + optimizer=optax.sgd(1e-3), |
| 208 | + training_config=orpo_config, |
| 209 | + tokenizer=tokenizer, |
| 210 | + ) |
| 211 | + orpo_trainer.train(train_ds, None) |
| 212 | + |
| 213 | + variables = nnx.state(model, nnx.Param) |
| 214 | + jax.tree.map_with_path(tc.assert_not_equal, original_variables, variables) |
| 215 | + |
| 216 | + for metric_name in [ |
| 217 | + "rewards/chosen", |
| 218 | + "rewards/rejected", |
| 219 | + "rewards/margin", |
| 220 | + "rewards/accuracy", |
| 221 | + ]: |
| 222 | + self.assertLen( |
| 223 | + orpo_trainer.metrics_logger.get_metric_history( |
| 224 | + "", metric_name, "train" |
| 225 | + ), |
| 226 | + orpo_trainer._train_steps, |
| 227 | + ) |
| 228 | + |
| 229 | + def test_orpo_loss_fn(self): |
| 230 | + """Test ORPO loss function directly with mocked logps.""" |
| 231 | + np.random.seed(0) |
| 232 | + model = tc.ToyTransformer(config=tc.ModelConfig(), rngs=nnx.Rngs(0)) |
| 233 | + # Use negative log probs (as they should be in reality) |
| 234 | + per_token_logps = -np.abs(np.random.normal(2, 1, size=(8, 4))) |
| 235 | + train_example = orpo_lib.TrainExample( |
| 236 | + input_ids=jnp.arange(0, 32).reshape(8, 4), |
| 237 | + positions=jnp.ones((8, 4)), |
| 238 | + attention_mask=jnp.ones((8, 4, 4)), |
| 239 | + ref_chosen_logps=None, |
| 240 | + ref_rejected_logps=None, |
| 241 | + completion_mask=jnp.ones((8, 4)), |
| 242 | + logits_to_keep=4, |
| 243 | + ) |
| 244 | + |
| 245 | + with mock.patch.object( |
| 246 | + common, |
| 247 | + "get_per_token_logps", |
| 248 | + return_value=jnp.array(per_token_logps), |
| 249 | + ): |
| 250 | + loss, aux = orpo_lib.dpo_loss_fn( |
| 251 | + model, |
| 252 | + train_example, |
| 253 | + algorithm="orpo", |
| 254 | + lambda_orpo=0.1, |
| 255 | + label_smoothing=0, |
| 256 | + ) |
| 257 | + # Loss should be a scalar and finite |
| 258 | + self.assertEqual(loss.shape, ()) |
| 259 | + self.assertTrue(jnp.isfinite(loss)) |
| 260 | + |
| 261 | + # Check that aux metrics exist |
| 262 | + self.assertIn("rewards/chosen", aux) |
| 263 | + self.assertIn("rewards/rejected", aux) |
| 264 | + self.assertIn("rewards/margin", aux) |
| 265 | + self.assertIn("rewards/accuracy", aux) |
| 266 | + self.assertIn("log_probs/chosen", aux) |
| 267 | + self.assertIn("log_probs/rejected", aux) |
| 268 | + self.assertIn("odds_ratio", aux) |
| 269 | + |
| 270 | + # Check that accuracy is between 0 and 1 |
| 271 | + self.assertGreaterEqual(aux["rewards/accuracy"], 0.0) |
| 272 | + self.assertLessEqual(aux["rewards/accuracy"], 1.0) |
| 273 | + |
| 274 | + def test_orpo_prepare_inputs_for_strings(self): |
| 275 | + tokenizer = tc.MockVocab() |
| 276 | + |
| 277 | + model = tc.ToyTransformer( |
| 278 | + config=tc.ModelConfig(vocab_size=tokenizer.GetPieceSize()), |
| 279 | + rngs=nnx.Rngs(0), |
| 280 | + ) |
| 281 | + orpo_trainer = orpo_lib.ORPOTrainer( |
| 282 | + model=model, |
| 283 | + ref_model=None, |
| 284 | + optimizer=optax.sgd(1e-3), |
| 285 | + training_config=orpo_lib.ORPOTrainingConfig( |
| 286 | + algorithm="orpo", |
| 287 | + eval_every_n_steps=10, |
| 288 | + max_steps=10, |
| 289 | + max_prompt_length=3, |
| 290 | + max_response_length=3, |
| 291 | + ), |
| 292 | + tokenizer=tokenizer, |
| 293 | + ) |
| 294 | + |
| 295 | + # These are random strings, they hold no meaning. |
| 296 | + training_input = orpo_lib.DataInput( |
| 297 | + prompts=["Tunix", "Parallax"], |
| 298 | + chosen_responses=["PT", "distributed training"], |
| 299 | + rejected_responses=["optimizer library", "quantization"], |
| 300 | + ) |
| 301 | + out = orpo_trainer._prepare_inputs(training_input) |
| 302 | + |
| 303 | + expected_input_ids = np.array([ |
| 304 | + [0, 1, 14, 1, 16, 0], |
| 305 | + [0, 1, 15, 1, 18, 19], |
| 306 | + [0, 1, 14, 1, 20, 17], |
| 307 | + [0, 1, 15, 1, 21, 0], |
| 308 | + ]) |
| 309 | + np.testing.assert_array_equal(out.input_ids, expected_input_ids) |
| 310 | + self.assertEqual(np.sum(out.attention_mask[0]), 14) |
| 311 | + self.assertEqual(np.sum(out.attention_mask[1]), 15) |
| 312 | + self.assertEqual(np.sum(out.attention_mask[2]), 15) |
| 313 | + self.assertEqual(np.sum(out.attention_mask[3]), 14) |
| 314 | + expected_completion_mask = np.array( |
| 315 | + [[1, 1, 0], [1, 1, 1], [1, 1, 1], [1, 1, 0]] |
| 316 | + ) |
| 317 | + np.testing.assert_array_equal(out.completion_mask, expected_completion_mask) |
| 318 | + self.assertEqual(out.logits_to_keep, 3) |
| 319 | + |
| 320 | + def test_orpo_prepare_inputs(self): |
| 321 | + model = tc.ToyTransformer(config=tc.ModelConfig(), rngs=nnx.Rngs(0)) |
| 322 | + orpo_trainer = orpo_lib.ORPOTrainer( |
| 323 | + model=model, |
| 324 | + ref_model=None, |
| 325 | + optimizer=optax.sgd(1e-3), |
| 326 | + training_config=orpo_lib.ORPOTrainingConfig( |
| 327 | + algorithm="orpo", |
| 328 | + eval_every_n_steps=10, |
| 329 | + max_steps=10, |
| 330 | + ), |
| 331 | + ) |
| 332 | + |
| 333 | + training_input = orpo_lib.TrainingInput( |
| 334 | + prompt_ids=np.array([[1, 2, 3, 4, 5], [0, 0, 1, 2, 3]]), |
| 335 | + prompt_mask=np.array([[1, 1, 1, 1, 1], [0, 0, 1, 1, 1]]), |
| 336 | + chosen_ids=np.array([[10, 11, 12, 0], [13, 14, 15, 16]]), |
| 337 | + chosen_mask=np.array([[1, 1, 1, 0], [1, 1, 1, 1]]), |
| 338 | + rejected_ids=np.array([[20, 21, 22, 0], [23, 0, 0, 0]]), |
| 339 | + rejected_mask=np.array([[1, 1, 1, 0], [1, 0, 0, 0]]), |
| 340 | + ) |
| 341 | + out = orpo_trainer._prepare_inputs(training_input) |
| 342 | + expected_input_ids = np.array([ |
| 343 | + [1, 2, 3, 4, 5, 10, 11, 12, 0], |
| 344 | + [0, 0, 1, 2, 3, 13, 14, 15, 16], |
| 345 | + [1, 2, 3, 4, 5, 20, 21, 22, 0], |
| 346 | + [0, 0, 1, 2, 3, 23, 0, 0, 0], |
| 347 | + ]) |
| 348 | + np.testing.assert_array_equal(out.input_ids, expected_input_ids) |
| 349 | + self.assertEqual(np.sum(out.attention_mask[0]), 44) |
| 350 | + self.assertEqual(np.sum(out.attention_mask[1]), 28) |
| 351 | + self.assertEqual(np.sum(out.attention_mask[2]), 44) |
| 352 | + self.assertEqual(np.sum(out.attention_mask[3]), 22) |
| 353 | + expected_completion_mask = np.array( |
| 354 | + [[1, 1, 1, 0], [1, 1, 1, 1], [1, 1, 1, 0], [1, 0, 0, 0]] |
| 355 | + ) |
| 356 | + np.testing.assert_array_equal(out.completion_mask, expected_completion_mask) |
| 357 | + self.assertEqual(out.logits_to_keep, 4) |
| 358 | + |
| 359 | + |
| 360 | +if __name__ == "__main__": |
| 361 | + absltest.main() |
0 commit comments