From 9e996449b879f17101aa4886c9f2f01ee2213a85 Mon Sep 17 00:00:00 2001 From: Mohammad Huzefa Shaikh Date: Tue, 18 Nov 2025 21:15:19 +0400 Subject: [PATCH] feat: add ORPO support --- tests/sft/dpo/dpo_trainer_test.py | 8 +- tests/sft/dpo/orpo_trainer_test.py | 361 +++++++++++++++++++++++++++++ tunix/__init__.py | 4 + tunix/sft/dpo/dpo_trainer.py | 212 ++++++++++++----- 4 files changed, 531 insertions(+), 54 deletions(-) create mode 100644 tests/sft/dpo/orpo_trainer_test.py diff --git a/tests/sft/dpo/dpo_trainer_test.py b/tests/sft/dpo/dpo_trainer_test.py index a4c1baf9..f926df85 100644 --- a/tests/sft/dpo/dpo_trainer_test.py +++ b/tests/sft/dpo/dpo_trainer_test.py @@ -270,10 +270,14 @@ def test_dpo_loss_fn(self): with mock.patch.object( common, "get_per_token_logps", return_value=jnp.array(per_token_logps) ): - loss, _ = dpo_lib.dpo_loss_fn(model, train_example, 0.1, 0) + loss, _ = dpo_lib.dpo_loss_fn( + model, train_example, beta=0.1, label_smoothing=0 + ) np.testing.assert_allclose(loss, 0.753059, atol=1e-5) - loss, _ = dpo_lib.dpo_loss_fn(model, train_example, 0.1, 0.3) + loss, _ = dpo_lib.dpo_loss_fn( + model, train_example, beta=0.1, label_smoothing=0.3 + ) np.testing.assert_allclose(loss, 0.925447, atol=1e-5) def test_dpo_prepare_inputs_for_strings(self): diff --git a/tests/sft/dpo/orpo_trainer_test.py b/tests/sft/dpo/orpo_trainer_test.py new file mode 100644 index 00000000..8f5a3b84 --- /dev/null +++ b/tests/sft/dpo/orpo_trainer_test.py @@ -0,0 +1,361 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest import mock +from absl.testing import absltest +from absl.testing import parameterized +from flax import nnx +from grain import python as grain +import jax +import jax.numpy as jnp +import numpy as np +import optax +from tunix.rl import common +from tunix.sft.dpo import dpo_trainer as orpo_lib +from tunix.tests import test_common as tc + +jax.config.update("jax_threefry_partitionable", False) +# jax.config.update("jax_debug_nans", True) # useful for debugging NaN + + +class MySource(grain.RandomAccessDataSource): + + def __init__(self, data): + self._data = data + + def __getitem__(self, idx): + return self._data[idx] + + def __len__(self): + return len(self._data) + + +def _dummy_dataset( + source: MySource, + prompt_ids: np.ndarray, + prompt_mask: np.ndarray, + chosen_ids: np.ndarray, + chosen_mask: np.ndarray, + rejected_ids: np.ndarray, + rejected_mask: np.ndarray, +): + return grain.MapDataset.source(source).map( + lambda x: orpo_lib.TrainingInput( + prompt_ids=prompt_ids, + prompt_mask=prompt_mask, + chosen_ids=chosen_ids, + chosen_mask=chosen_mask, + rejected_ids=rejected_ids, + rejected_mask=rejected_mask, + ) + ) + + +def _dummy_string_dataset( + source: MySource, + prompts: np.ndarray, + chosen_responses: np.ndarray, + rejected_responses: np.ndarray, + return_dict=False, +): + ds = grain.MapDataset.source(source) + if return_dict: + return ds.map( + lambda x: { + "prompts": prompts, + "chosen_responses": chosen_responses, + "rejected_responses": rejected_responses, + } + ) + else: + return ds.map( + lambda x: orpo_lib.DataInput( + prompts=prompts, + chosen_responses=chosen_responses, + rejected_responses=rejected_responses, + ) + ) + + +class ORPOTrainerTest(parameterized.TestCase): + + @parameterized.named_parameters( + dict( + testcase_name="basic_training", + prompt_ids=np.arange(0, 10).reshape(2, 5), + prompt_mask=np.ones((2, 5)), + chosen_ids=np.arange(10, 20).reshape(2, 5), + chosen_mask=np.ones((2, 5)), + rejected_ids=np.arange(20, 30).reshape(2, 5), + rejected_mask=np.ones((2, 5)), + ), + ) + def test_orpo_trainer( + self, + prompt_ids, + prompt_mask, + chosen_ids, + chosen_mask, + rejected_ids, + rejected_mask, + ): + model = tc.ToyTransformer(config=tc.ModelConfig(), rngs=nnx.Rngs(0)) + original_variables = jax.tree.map(jnp.copy, nnx.state(model, nnx.Param)) + orpo_config = orpo_lib.ORPOTrainingConfig( + algorithm="orpo", + eval_every_n_steps=5, + max_steps=10, + ) + orpo_trainer = orpo_lib.ORPOTrainer( + model=model, + ref_model=None, + optimizer=optax.sgd(1e-3), + training_config=orpo_config, + ) + train_ds = _dummy_dataset( + MySource(np.arange(10)), + prompt_ids, + prompt_mask, + chosen_ids, + chosen_mask, + rejected_ids, + rejected_mask, + ) + eval_ds = _dummy_dataset( + MySource(np.arange(2)), + prompt_ids, + prompt_mask, + chosen_ids, + chosen_mask, + rejected_ids, + rejected_mask, + ) + orpo_trainer.train(train_ds, eval_ds=eval_ds) + + variables = nnx.state(model, nnx.Param) + jax.tree.map_with_path(tc.assert_not_equal, original_variables, variables) + + for metric_name in [ + "rewards/chosen", + "rewards/rejected", + "rewards/margin", + "rewards/accuracy", + "log_probs/chosen", + "log_probs/rejected", + "odds_ratio", + ]: + self.assertLen( + orpo_trainer.metrics_logger.get_metric_history( + "", metric_name, "train" + ), + orpo_trainer._train_steps, + ) + self.assertLen( + orpo_trainer.metrics_logger.get_metric_history("", metric_name, "eval"), + 3, + ) + + @parameterized.named_parameters( + dict( + testcase_name="dataclass_inputs", + train_ds=_dummy_string_dataset( + MySource(np.arange(10)), + prompts=["Tunix", "Parallax"], + chosen_responses=["PT", "distributed training"], + rejected_responses=["optimizer library", "quantization"], + ), + ), + dict( + testcase_name="dict_inputs", + train_ds=_dummy_string_dataset( + MySource(np.arange(10)), + prompts=["Tunix", "Parallax"], + chosen_responses=["PT", "distributed training"], + rejected_responses=["optimizer library", "quantization"], + return_dict=True, + ), + ), + ) + def test_orpo_trainer_with_string_inputs(self, train_ds): + tokenizer = tc.MockVocab() + model = tc.ToyTransformer( + config=tc.ModelConfig(vocab_size=tokenizer.GetPieceSize()), + rngs=nnx.Rngs(0), + ) + original_variables = jax.tree.map(jnp.copy, nnx.state(model, nnx.Param)) + orpo_config = orpo_lib.ORPOTrainingConfig( + algorithm="orpo", + eval_every_n_steps=10, + max_steps=10, + max_prompt_length=3, + max_response_length=3, + ) + orpo_trainer = orpo_lib.ORPOTrainer( + model=model, + ref_model=None, + optimizer=optax.sgd(1e-3), + training_config=orpo_config, + tokenizer=tokenizer, + ) + orpo_trainer.train(train_ds, None) + + variables = nnx.state(model, nnx.Param) + jax.tree.map_with_path(tc.assert_not_equal, original_variables, variables) + + for metric_name in [ + "rewards/chosen", + "rewards/rejected", + "rewards/margin", + "rewards/accuracy", + ]: + self.assertLen( + orpo_trainer.metrics_logger.get_metric_history( + "", metric_name, "train" + ), + orpo_trainer._train_steps, + ) + + def test_orpo_loss_fn(self): + """Test ORPO loss function directly with mocked logps.""" + np.random.seed(0) + model = tc.ToyTransformer(config=tc.ModelConfig(), rngs=nnx.Rngs(0)) + # Use negative log probs (as they should be in reality) + per_token_logps = -np.abs(np.random.normal(2, 1, size=(8, 4))) + train_example = orpo_lib.TrainExample( + input_ids=jnp.arange(0, 32).reshape(8, 4), + positions=jnp.ones((8, 4)), + attention_mask=jnp.ones((8, 4, 4)), + ref_chosen_logps=None, + ref_rejected_logps=None, + completion_mask=jnp.ones((8, 4)), + logits_to_keep=4, + ) + + with mock.patch.object( + common, + "get_per_token_logps", + return_value=jnp.array(per_token_logps), + ): + loss, aux = orpo_lib.dpo_loss_fn( + model, + train_example, + algorithm="orpo", + lambda_orpo=0.1, + label_smoothing=0, + ) + # Loss should be a scalar and finite + self.assertEqual(loss.shape, ()) + self.assertTrue(jnp.isfinite(loss)) + + # Check that aux metrics exist + self.assertIn("rewards/chosen", aux) + self.assertIn("rewards/rejected", aux) + self.assertIn("rewards/margin", aux) + self.assertIn("rewards/accuracy", aux) + self.assertIn("log_probs/chosen", aux) + self.assertIn("log_probs/rejected", aux) + self.assertIn("odds_ratio", aux) + + # Check that accuracy is between 0 and 1 + self.assertGreaterEqual(aux["rewards/accuracy"], 0.0) + self.assertLessEqual(aux["rewards/accuracy"], 1.0) + + def test_orpo_prepare_inputs_for_strings(self): + tokenizer = tc.MockVocab() + + model = tc.ToyTransformer( + config=tc.ModelConfig(vocab_size=tokenizer.GetPieceSize()), + rngs=nnx.Rngs(0), + ) + orpo_trainer = orpo_lib.ORPOTrainer( + model=model, + ref_model=None, + optimizer=optax.sgd(1e-3), + training_config=orpo_lib.ORPOTrainingConfig( + algorithm="orpo", + eval_every_n_steps=10, + max_steps=10, + max_prompt_length=3, + max_response_length=3, + ), + tokenizer=tokenizer, + ) + + # These are random strings, they hold no meaning. + training_input = orpo_lib.DataInput( + prompts=["Tunix", "Parallax"], + chosen_responses=["PT", "distributed training"], + rejected_responses=["optimizer library", "quantization"], + ) + out = orpo_trainer._prepare_inputs(training_input) + + expected_input_ids = np.array([ + [0, 1, 14, 1, 16, 0], + [0, 1, 15, 1, 18, 19], + [0, 1, 14, 1, 20, 17], + [0, 1, 15, 1, 21, 0], + ]) + np.testing.assert_array_equal(out.input_ids, expected_input_ids) + self.assertEqual(np.sum(out.attention_mask[0]), 14) + self.assertEqual(np.sum(out.attention_mask[1]), 15) + self.assertEqual(np.sum(out.attention_mask[2]), 15) + self.assertEqual(np.sum(out.attention_mask[3]), 14) + expected_completion_mask = np.array( + [[1, 1, 0], [1, 1, 1], [1, 1, 1], [1, 1, 0]] + ) + np.testing.assert_array_equal(out.completion_mask, expected_completion_mask) + self.assertEqual(out.logits_to_keep, 3) + + def test_orpo_prepare_inputs(self): + model = tc.ToyTransformer(config=tc.ModelConfig(), rngs=nnx.Rngs(0)) + orpo_trainer = orpo_lib.ORPOTrainer( + model=model, + ref_model=None, + optimizer=optax.sgd(1e-3), + training_config=orpo_lib.ORPOTrainingConfig( + algorithm="orpo", + eval_every_n_steps=10, + max_steps=10, + ), + ) + + training_input = orpo_lib.TrainingInput( + prompt_ids=np.array([[1, 2, 3, 4, 5], [0, 0, 1, 2, 3]]), + prompt_mask=np.array([[1, 1, 1, 1, 1], [0, 0, 1, 1, 1]]), + chosen_ids=np.array([[10, 11, 12, 0], [13, 14, 15, 16]]), + chosen_mask=np.array([[1, 1, 1, 0], [1, 1, 1, 1]]), + rejected_ids=np.array([[20, 21, 22, 0], [23, 0, 0, 0]]), + rejected_mask=np.array([[1, 1, 1, 0], [1, 0, 0, 0]]), + ) + out = orpo_trainer._prepare_inputs(training_input) + expected_input_ids = np.array([ + [1, 2, 3, 4, 5, 10, 11, 12, 0], + [0, 0, 1, 2, 3, 13, 14, 15, 16], + [1, 2, 3, 4, 5, 20, 21, 22, 0], + [0, 0, 1, 2, 3, 23, 0, 0, 0], + ]) + np.testing.assert_array_equal(out.input_ids, expected_input_ids) + self.assertEqual(np.sum(out.attention_mask[0]), 44) + self.assertEqual(np.sum(out.attention_mask[1]), 28) + self.assertEqual(np.sum(out.attention_mask[2]), 44) + self.assertEqual(np.sum(out.attention_mask[3]), 22) + expected_completion_mask = np.array( + [[1, 1, 1, 0], [1, 1, 1, 1], [1, 1, 1, 0], [1, 0, 0, 0]] + ) + np.testing.assert_array_equal(out.completion_mask, expected_completion_mask) + self.assertEqual(out.logits_to_keep, 4) + + +if __name__ == "__main__": + absltest.main() diff --git a/tunix/__init__.py b/tunix/__init__.py index 8a0d75cd..e92b964a 100644 --- a/tunix/__init__.py +++ b/tunix/__init__.py @@ -54,6 +54,10 @@ from tunix.sft.dpo.dpo_trainer import DpoTrainer from tunix.sft.dpo.dpo_trainer import DPOTrainingConfig from tunix.sft.dpo.dpo_trainer import DpoTrainingConfig +from tunix.sft.dpo.dpo_trainer import ORPOTrainer +from tunix.sft.dpo.dpo_trainer import OrpoTrainer +from tunix.sft.dpo.dpo_trainer import ORPOTrainingConfig +from tunix.sft.dpo.dpo_trainer import OrpoTrainingConfig from tunix.sft.metrics_logger import MetricsLogger from tunix.sft.metrics_logger import MetricsLoggerOptions from tunix.sft.peft_trainer import PeftTrainer diff --git a/tunix/sft/dpo/dpo_trainer.py b/tunix/sft/dpo/dpo_trainer.py index 00e3c4cf..0d176f77 100644 --- a/tunix/sft/dpo/dpo_trainer.py +++ b/tunix/sft/dpo/dpo_trainer.py @@ -81,17 +81,21 @@ class TrainExample: input_ids: jax.Array # Concatenated [prompt_ids, completion_ids] positions: jax.Array attention_mask: jax.Array - ref_chosen_logps: jax.Array - ref_rejected_logps: jax.Array + ref_chosen_logps: jax.Array | None + ref_rejected_logps: jax.Array | None completion_mask: jax.Array logits_to_keep: int = flax.struct.field(pytree_node=False) @dataclasses.dataclass(slots=True, kw_only=True) class DPOTrainingConfig(peft_trainer.TrainingConfig): - """DPO Training Config.""" + """DPO/ORPO Training Config.""" - beta: float = 0.1 # 𝛽 for KL penalty https://arxiv.org/pdf/2305.18290 + algorithm: str = "dpo" # "dpo" or "orpo" + beta: float = ( + 0.1 # 𝛽 for KL penalty (DPO only) https://arxiv.org/pdf/2305.18290 + ) + lambda_orpo: float = 0.1 # Weight for preference loss (ORPO only) label_smoothing: float = 0.0 # Should be specified only if your input has strings instead of tokenized IDs. @@ -125,7 +129,7 @@ def compute_logps( class DPOTrainer(peft_trainer.PeftTrainer): - """Direct Preference Optimization (DPO) trainer. + """Direct Preference Optimization (DPO) and ORPO trainer. DPO is a preference tuning method for aligning large language models with human or AI preferences. It is a more efficient, performant alternative @@ -137,35 +141,42 @@ class DPOTrainer(peft_trainer.PeftTrainer): preferences (pairs of "chosen" and "rejected responses) to directly optimize the policy model by using a classification-style loss. + ORPO (Odds Ratio Preference Optimization) is a memory-efficient variant that + combines supervised fine-tuning with preference alignment without requiring + a separate reference model, making it approximately 50% more memory-efficient. + References: - - https://arxiv.org/abs/2305.18290 + - DPO: https://arxiv.org/abs/2305.18290 + - ORPO: https://arxiv.org/abs/2403.07691 """ def __init__( self, model: nnx.Module, - ref_model: nnx.Module, + ref_model: nnx.Module | None, optimizer: optax.GradientTransformation, training_config: DPOTrainingConfig, tokenizer: Any | None = None, ): - """Initializes the DPO trainer. + """Initializes the DPO/ORPO trainer. Args: model: The policy model to be trained. ref_model: The reference/anchor model which is kept fixed/frozen during - training. It is used to prevent the policy model from drifting too far - from its original capabilities. If `ref_model` is None, we don't use it - in the loss term. + training (DPO only). It is used to prevent the policy model from + drifting too far from its original capabilities. For ORPO, this should + be None. If `ref_model` is None for DPO, we don't use it in the loss + term. optimizer: The optimizer used for training the policy model. - training_config: A `DPOTrainingConfig` object containing DPO-specific - hyperparameters like `beta` and `label_smoothing`. + training_config: A `DPOTrainingConfig` object containing DPO/ORPO-specific + hyperparameters like `beta`, `lambda_orpo`, and `label_smoothing`. tokenizer: An optional tokenizer. If provided, the trainer can accept string inputs and tokenize them internally. """ self.model = model self.ref_model = ref_model self.dpo_config = training_config + self.algorithm = training_config.algorithm super().__init__(model, optimizer, training_config) self.tokenizer = ( @@ -175,18 +186,38 @@ def __init__( ) self.with_loss_fn(dpo_loss_fn, has_aux=True) - self.with_gen_model_input_fn( - lambda x: { - "train_example": x, - "beta": self.dpo_config.beta, - "label_smoothing": self.dpo_config.label_smoothing, - } - ) - self.gen_model_input_fn = lambda x: { - "train_example": x, - "beta": self.dpo_config.beta, - "label_smoothing": self.dpo_config.label_smoothing, - } + + if self.algorithm == "orpo": + self.with_gen_model_input_fn( + lambda x: { + "train_example": x, + "algorithm": "orpo", + "lambda_orpo": self.dpo_config.lambda_orpo, + "label_smoothing": self.dpo_config.label_smoothing, + } + ) + self.gen_model_input_fn = lambda x: { + "train_example": x, + "algorithm": "orpo", + "lambda_orpo": self.dpo_config.lambda_orpo, + "label_smoothing": self.dpo_config.label_smoothing, + } + else: + self.with_gen_model_input_fn( + lambda x: { + "train_example": x, + "algorithm": "dpo", + "beta": self.dpo_config.beta, + "label_smoothing": self.dpo_config.label_smoothing, + } + ) + self.gen_model_input_fn = lambda x: { + "train_example": x, + "algorithm": "dpo", + "beta": self.dpo_config.beta, + "label_smoothing": self.dpo_config.label_smoothing, + } + self._has_aux = True # If reference model is not provided, we don't use it in the loss term. @@ -201,6 +232,9 @@ def __init__( "log_probs/rejected": np.mean, } + if self.algorithm == "orpo": + self._aux_metrics_to_log["odds_ratio"] = np.mean + @override def _prepare_inputs( self, @@ -315,10 +349,24 @@ def _post_process_eval_step(self, aux: Any) -> None: def dpo_loss_fn( model: nnx.Module, train_example: TrainExample, - beta: float, - label_smoothing: float, + algorithm: str = "dpo", + beta: float = 0.1, + lambda_orpo: float = 0.1, + label_smoothing: float = 0.0, ) -> tuple[jax.Array, dict[str, jax.Array]]: - """DPO loss function.""" + """DPO/ORPO loss function. + + Args: + model: The model to compute loss for. + train_example: Training example containing input_ids, masks, etc. + algorithm: "dpo" or "orpo". + beta: Weight for KL penalty (DPO only). + lambda_orpo: Weight for preference loss (ORPO only). + label_smoothing: Label smoothing factor. + + Returns: + A tuple of (loss, auxiliary_metrics_dict). + """ chosen_logps, rejected_logps = compute_logps( model, train_example.input_ids, @@ -328,33 +376,86 @@ def dpo_loss_fn( train_example.completion_mask, ) - # Compute DPO loss. - chosen_log_ratio = chosen_logps - if train_example.ref_chosen_logps is not None: - chosen_log_ratio = chosen_log_ratio - train_example.ref_chosen_logps - rejected_log_ratio = rejected_logps - if train_example.ref_rejected_logps is not None: - rejected_log_ratio = rejected_log_ratio - train_example.ref_rejected_logps - delta = chosen_log_ratio - rejected_log_ratio - losses = -( - jax.nn.log_sigmoid(beta * delta) * (1 - label_smoothing) - + jax.nn.log_sigmoid(-beta * delta) * label_smoothing - ) + if algorithm == "orpo": + # ORPO loss = L_SFT + λ * L_OR + # Paper: https://arxiv.org/abs/2403.07691 + + # L_SFT: Supervised fine-tuning loss on chosen responses + # Normalize by sequence length as per Equation 2 in paper + batch_size = train_example.completion_mask.shape[0] // 2 + chosen_mask = train_example.completion_mask[:batch_size] + chosen_lengths = chosen_mask.sum(axis=-1) + chosen_lengths = jnp.maximum(chosen_lengths, 1.0) # Avoid division by zero + + # L_SFT = -(1/|y_w|) * Σ log P (Paper Equation 2) + sft_loss = -chosen_logps / chosen_lengths + + # L_OR: Odds ratio preference loss + # Following HuggingFace TRL implementation exactly (Eqs. 4 and 7 from paper) + # Note: log1p(-exp(x)) requires x < 0 to avoid NaN. This works when log probs + # are averaged per token, but may produce NaN for summed log probs if sequences + # are long. TRL uses summed log probs and relies on them being negative. + log_odds = (chosen_logps - rejected_logps) - ( + jnp.log1p(-jnp.exp(chosen_logps)) - jnp.log1p(-jnp.exp(rejected_logps)) + ) + + # Apply label smoothing to odds ratio loss + or_loss = -( + jax.nn.log_sigmoid(log_odds) * (1 - label_smoothing) + + jax.nn.log_sigmoid(-log_odds) * label_smoothing + ) + + # Combined ORPO loss: L_ORPO = L_SFT + λ * L_OR + total_loss = sft_loss + lambda_orpo * or_loss + + # Compute rewards for logging (matching HuggingFace TRL implementation) + chosen_rewards = lambda_orpo * chosen_logps + rejected_rewards = lambda_orpo * rejected_logps + + # Compute odds ratio for logging + odds_ratio = jnp.exp(log_odds) + + aux = { + "rewards/chosen": chosen_rewards.mean(), + "rewards/rejected": rejected_rewards.mean(), + "rewards/margin": (chosen_rewards - rejected_rewards).mean(), + "rewards/accuracy": (chosen_rewards > rejected_rewards).mean(), + "log_probs/chosen": chosen_logps.mean(), + "log_probs/rejected": rejected_logps.mean(), + "odds_ratio": odds_ratio.mean(), + "sft_loss": sft_loss.mean(), + "or_loss": or_loss.mean(), + } - # Compute rewards. - chosen_rewards = beta * chosen_log_ratio - rejected_rewards = beta * rejected_log_ratio + return total_loss.mean(), aux + else: + # DPO loss + chosen_log_ratio = chosen_logps + if train_example.ref_chosen_logps is not None: + chosen_log_ratio = chosen_log_ratio - train_example.ref_chosen_logps + rejected_log_ratio = rejected_logps + if train_example.ref_rejected_logps is not None: + rejected_log_ratio = rejected_log_ratio - train_example.ref_rejected_logps + delta = chosen_log_ratio - rejected_log_ratio + losses = -( + jax.nn.log_sigmoid(beta * delta) * (1 - label_smoothing) + + jax.nn.log_sigmoid(-beta * delta) * label_smoothing + ) - aux = { - "rewards/chosen": chosen_rewards.mean(), - "rewards/rejected": rejected_rewards.mean(), - "rewards/margin": (chosen_rewards - rejected_rewards).mean(), - "rewards/accuracy": (chosen_rewards > rejected_rewards).mean(), - "log_probs/chosen": chosen_logps.mean(), - "log_probs/rejected": rejected_logps.mean(), - } + # Compute rewards. + chosen_rewards = beta * chosen_log_ratio + rejected_rewards = beta * rejected_log_ratio + + aux = { + "rewards/chosen": chosen_rewards.mean(), + "rewards/rejected": rejected_rewards.mean(), + "rewards/margin": (chosen_rewards - rejected_rewards).mean(), + "rewards/accuracy": (chosen_rewards > rejected_rewards).mean(), + "log_probs/chosen": chosen_logps.mean(), + "log_probs/rejected": rejected_logps.mean(), + } - return losses.mean(), aux + return losses.mean(), aux def _generate_ids_and_masks( @@ -491,5 +592,12 @@ def process_dpo_record( rejected_mask=rejected_mask, ) + DpoTrainingConfig = DPOTrainingConfig DpoTrainer = DPOTrainer + +# ORPO aliases +ORPOTrainingConfig = DPOTrainingConfig +ORPOTrainer = DPOTrainer +OrpoTrainingConfig = DPOTrainingConfig +OrpoTrainer = DPOTrainer