From cffc39031e47c0eb432b3268ca5dc3909b0ba7f2 Mon Sep 17 00:00:00 2001 From: Tianshu Bao Date: Wed, 8 Oct 2025 17:27:10 -0700 Subject: [PATCH] simplify micro batching config PiperOrigin-RevId: 816926071 --- tests/rl/grpo/grpo_learner_test.py | 170 +---------------------------- tests/rl/rl_cluster_test.py | 71 +++++++++--- tests/rl/rl_learner_test.py | 50 ++++----- tunix/rl/rl_cluster.py | 147 +++++++++++++++++++------ tunix/rl/rl_learner.py | 84 +++++++------- tunix/rl/utils.py | 23 ++-- tunix/tests/test_common.py | 1 - 7 files changed, 251 insertions(+), 295 deletions(-) diff --git a/tests/rl/grpo/grpo_learner_test.py b/tests/rl/grpo/grpo_learner_test.py index 8e211e51..8637e293 100644 --- a/tests/rl/grpo/grpo_learner_test.py +++ b/tests/rl/grpo/grpo_learner_test.py @@ -12,11 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -import functools import itertools -import os -import tempfile import types +import os from absl.testing import absltest from absl.testing import parameterized import chex @@ -24,9 +22,9 @@ from flax.nnx import filterlib from grain import python as grain import jax +import jax.numpy as jnp from jax import sharding from jax.interpreters import pxla -import jax.numpy as jnp import numpy as np import optax import orbax.checkpoint as ocp @@ -36,6 +34,7 @@ from tunix.rl.rollout import base_rollout from tunix.tests import test_common as tc from typing_extensions import override +import tempfile os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=2' @@ -152,7 +151,6 @@ def _prepare(dataset, sample_repeat, batch_repeat, grad_acc_steps): batch_repeat=batch_repeat, data_queue=data_queue, async_loading=False, - service_target_batch_size=1, ) while True: item = data_queue.get(block=True) @@ -882,7 +880,7 @@ def create_rl_cluster(grad_accu_steps, mini_batch_size): # We can't set grad_acc_steps directly, so we do it through # mini_batch_size and training_micro_batch_size. 
mini_batch_size=mini_batch_size, - train_micro_batch_size=mini_batch_size // grad_accu_steps, + training_micro_batch_size=mini_batch_size // grad_accu_steps, ), rollout_config=base_rollout.RolloutConfig( max_tokens_to_generate=10, @@ -938,166 +936,6 @@ def create_rl_cluster(grad_accu_steps, mini_batch_size): ) # max_steps * batch_size * num_generations self.assertLen(first_trajectories['eval'], 8) # eval_rows * num_generations - @parameterized.named_parameters( - dict( - testcase_name='single_update', - batch_size=8, - mini_batch_size=8, - train_micro_batch_size=4, - rollout_micro_batch_size=4, - compute_logps_micro_batch_size=4, - ), - dict( - testcase_name='multi_update', - batch_size=8, - mini_batch_size=4, - train_micro_batch_size=2, - rollout_micro_batch_size=2, - compute_logps_micro_batch_size=2, - ), - dict( - testcase_name='single_update_with_bigger_rollout_and_compute_logps', - batch_size=8, - mini_batch_size=8, - train_micro_batch_size=4, - rollout_micro_batch_size=8, - compute_logps_micro_batch_size=8, - ), - dict( - testcase_name='only_rollout_and_compute_logps', - batch_size=8, - mini_batch_size=None, - train_micro_batch_size=None, - rollout_micro_batch_size=4, - compute_logps_micro_batch_size=4, - ), - dict( - testcase_name='individible_batch_size', - batch_size=20, - mini_batch_size=20, - train_micro_batch_size=10, - rollout_micro_batch_size=4, - compute_logps_micro_batch_size=4, - ), - ) - def test_micro_batch_training( - self, - batch_size, - mini_batch_size, - train_micro_batch_size, - rollout_micro_batch_size, - compute_logps_micro_batch_size, - ): - def my_reward_fn(trajectories, prompts, **kwargs): - for t_id, prompt in zip(kwargs['trajectory_ids'], prompts): - trajectories[kwargs['mode']][t_id] = prompt - return jnp.arange(len(prompts)) - - def create_learner( - mini_batch_size, - train_micro_batch_size, - rollout_micro_batch_size, - compute_logps_micro_batch_size, - trajectories, - ): - vocab = tc.MockVocab() - model = tc.ToyTransformer( - rngs=nnx.Rngs(0), vocab_size=vocab.GetPieceSize() - ) - ref_model = tc.ToyTransformer( - rngs=nnx.Rngs(0), vocab_size=vocab.GetPieceSize() - ) - - mesh = pxla.thread_resources.env.physical_mesh - cluster_config = rl_cluster_lib.ClusterConfig( - role_to_mesh={ - rl_cluster_lib.Role.ACTOR: mesh, - rl_cluster_lib.Role.REFERENCE: mesh, - rl_cluster_lib.Role.ROLLOUT: mesh, - }, - rollout_engine='vanilla', - offload_to_cpu=False, - training_config=rl_cluster_lib.RLTrainingConfig( - actor_optimizer=optax.sgd(1e-3), - eval_every_n_steps=2, - max_steps=20, - mini_batch_size=mini_batch_size, - train_micro_batch_size=train_micro_batch_size, - rollout_micro_batch_size=rollout_micro_batch_size, - compute_logps_micro_batch_size=compute_logps_micro_batch_size, - ), - rollout_config=base_rollout.RolloutConfig( - max_tokens_to_generate=10, - max_prompt_length=32, - kv_cache_size=256, - temperature=0.5, - ), - ) - rl_cluster = rl_cluster_lib.RLCluster( - actor=model, - reference=ref_model, - tokenizer=vocab, - cluster_config=cluster_config, - ) - - grpo_config = grpo_lib.GRPOConfig( - num_generations=2, - num_iterations=1, - ) - grpo_learner = grpo_lib.GRPOLearner( - rl_cluster=rl_cluster, - reward_fns=lambda **kwargs: my_reward_fn( - trajectories=trajectories, **kwargs - ), - grpo_config=grpo_config, - ) - return grpo_learner, model - - # 40 rows with repeat=10. - train_ds = _dummy_dataset(MySource(repeat=10), batch_size=batch_size) - eval_ds = _dummy_dataset(batch_size=1) - - # Baseline with no micro batching. 
- base_trajectories = {'train': {}, 'eval': {}} - grpo_learner, model = create_learner( - mini_batch_size=None, - train_micro_batch_size=None, - rollout_micro_batch_size=None, - compute_logps_micro_batch_size=None, - trajectories=base_trajectories, - ) - original_variables = jax.tree.map(jnp.copy, nnx.state(model, nnx.Param)) - - grpo_learner.train(train_ds, eval_ds) - self.assertEqual( - 40 // batch_size, grpo_learner.rl_cluster.actor_trainer.train_steps - ) - - base_variables = nnx.state(model, nnx.Param) - jax.tree.map_with_path( - tc.assert_not_equal, original_variables, base_variables - ) - - # Train with micro batching. - micro_batch_trajectories = {'train': {}, 'eval': {}} - grpo_learner, model = create_learner( - mini_batch_size=mini_batch_size, - train_micro_batch_size=train_micro_batch_size, - rollout_micro_batch_size=rollout_micro_batch_size, - compute_logps_micro_batch_size=compute_logps_micro_batch_size, - trajectories=micro_batch_trajectories, - ) - grpo_learner.train(train_ds, eval_ds) - micro_batch_variables = nnx.state(model, nnx.Param) - jax.tree.map_with_path( - tc.assert_not_equal, original_variables, micro_batch_variables - ) - self.assertEqual(base_trajectories, micro_batch_trajectories) - self.assertEqual( - 40 // (mini_batch_size or batch_size), - grpo_learner.rl_cluster.actor_trainer.train_steps, - ) - if __name__ == '__main__': absltest.main() diff --git a/tests/rl/rl_cluster_test.py b/tests/rl/rl_cluster_test.py index 06843636..971302e4 100644 --- a/tests/rl/rl_cluster_test.py +++ b/tests/rl/rl_cluster_test.py @@ -116,27 +116,68 @@ def test_model_loading_with_resharding(self): ) self.assertEqual(ref_model_mesh, actor_mesh) - def test_batch_size_config(self): + @parameterized.named_parameters( + ('1', None, None, None, None, [None, None, None, None, 1]), + ('2', 8, None, None, None, [8, 8, 8, 8, 1]), + ('3', 8, 2, None, None, [8, 2, 2, 2, 4]), + ('4', 8, 4, 8, None, [8, 4, 8, 4, 2]), + ('5', 8, 4, None, 8, [8, 4, 4, 8, 2]), + ('6', 16, 8, 8, 16, [16, 8, 8, 16, 2]), + ) + def test_batch_sizes( + self, + mini_batch_size, + training_micro_batch_size, + rollout_micro_batch_size, + compute_logps_micro_batch_size, + expected_values, + ): cfg = rl_cluster_lib.RLTrainingConfig( actor_optimizer=optax.sgd(1e-3), critic_optimizer=None, - mini_batch_size=8, - train_micro_batch_size=4, + mini_batch_size=mini_batch_size, + training_micro_batch_size=training_micro_batch_size, + rollout_micro_batch_size=rollout_micro_batch_size, + compute_logps_micro_batch_size=compute_logps_micro_batch_size, eval_every_n_steps=1, ) - self.assertEqual(cfg.gradient_accumulation_steps, 2) - for mini_batch_size, train_micro_batch_size in zip( - [8, -8, None], [3, 4, 4] - ): - with self.assertRaises(ValueError): - rl_cluster_lib.RLTrainingConfig( - actor_optimizer=optax.sgd(1e-3), - critic_optimizer=None, - mini_batch_size=mini_batch_size, - train_micro_batch_size=train_micro_batch_size, - eval_every_n_steps=1, - ) + self.assertEqual( + expected_values, + [ + cfg.mini_batch_size, + cfg.training_micro_batch_size, + cfg.rollout_micro_batch_size, + cfg.compute_logps_micro_batch_size, + cfg.gradient_accumulation_steps, + ], + ) + + @parameterized.named_parameters( + ('1', 2, 4, None, None), + ('2', 8, 3, None, None), + ('3', 8, 4, 3, None), + ('4', 8, 4, None, 3), + ('5', None, 2, None, None), + ('6', None, None, 2, None), + ) + def test_batch_sizes_errors( + self, + mini_batch_size, + training_micro_batch_size, + rollout_micro_batch_size, + compute_logps_micro_batch_size, + ): + with 
self.assertRaises(ValueError): + rl_cluster_lib.RLTrainingConfig( + actor_optimizer=optax.sgd(1e-3), + critic_optimizer=None, + mini_batch_size=mini_batch_size, + training_micro_batch_size=training_micro_batch_size, + rollout_micro_batch_size=rollout_micro_batch_size, + compute_logps_micro_batch_size=compute_logps_micro_batch_size, + eval_every_n_steps=1, + ) def test_generate_with_chat_template(self): # pylint: disable=g-doc-args mesh = Mesh( diff --git a/tests/rl/rl_learner_test.py b/tests/rl/rl_learner_test.py index 385426c3..ad1a6a06 100644 --- a/tests/rl/rl_learner_test.py +++ b/tests/rl/rl_learner_test.py @@ -32,26 +32,25 @@ def _num_generations(self): class RLLearnerTest(parameterized.TestCase): @parameterized.named_parameters( - ('1', None, None, None, None, [32, 32], False), - ('2', 8, None, None, None, [8, 8], False), - ('3', 8, 2, None, None, [2, 2], False), - ('4', 8, 4, 4, 4, [4, 4], False), - ('5', 8, 4, 3, 4, [], True), - ('6', 8, 4, 4, 3, [], True), + ('1', None, None, None, None, [32, 32, 32, 32]), + ('2', 8, None, None, None, [8, 8, 8, 8]), + ('3', 8, 2, None, None, [8, 2, 2, 2]), + ('4', 8, 4, 8, None, [8, 4, 8, 4]), + ('5', 8, 4, None, 4, [8, 4, 4, 4]), + ('6', 16, 8, 16, 8, [16, 8, 16, 8]), ) def test_micro_batching( self, mini_batch_size, - train_micro_batch_size, + training_micro_batch_size, rollout_micro_batch_size, compute_logps_micro_batch_size, expected_values, - expect_failure, ): config = rl_cluster_lib.RLTrainingConfig( actor_optimizer=optax.sgd(1e-3), mini_batch_size=mini_batch_size, - train_micro_batch_size=train_micro_batch_size, + training_micro_batch_size=training_micro_batch_size, rollout_micro_batch_size=rollout_micro_batch_size, compute_logps_micro_batch_size=compute_logps_micro_batch_size, eval_every_n_steps=1, @@ -75,22 +74,23 @@ def test_micro_batching( full_batch_size = 32 train_ds = [{'prompts': [''] * full_batch_size}] - if expect_failure: - with self.assertRaises(ValueError): - learner.train(train_ds) - else: - learner.train(train_ds) - ( - expected_rollout_micro, - expected_compute_logps_micro, - ) = expected_values - - self.assertEqual( - learner._rollout_micro_batch_size, expected_rollout_micro - ) - self.assertEqual( - learner._compute_logps_micro_batch_size, expected_compute_logps_micro - ) + learner.train(train_ds) + + ( + expected_mini_batch, + expected_training_micro, + expected_rollout_micro, + expected_compute_logps_micro, + ) = expected_values + + self.assertEqual(learner._mini_batch_size, expected_mini_batch) + self.assertEqual( + learner._training_micro_batch_size, expected_training_micro + ) + self.assertEqual(learner._rollout_micro_batch_size, expected_rollout_micro) + self.assertEqual( + learner._compute_logps_micro_batch_size, expected_compute_logps_micro + ) if __name__ == '__main__': diff --git a/tunix/rl/rl_cluster.py b/tunix/rl/rl_cluster.py index a51272e2..6c56789a 100644 --- a/tunix/rl/rl_cluster.py +++ b/tunix/rl/rl_cluster.py @@ -93,56 +93,82 @@ class RLTrainingConfig(peft_trainer.TrainingConfig): actor_optimizer: Optimizer for the actor model. critic_optimizer: Optimizer for the critic model. If None, the critic model will be trained in the same optimizer as the actor model. - mini_batch_size: The mini-batch size used for policy weight updates. One - mini-batch corresponds to one optimizer update. `mini_batch_size` must be - divisible by the batch size used for data loading. - train_micro_batch_size: The micro-batch size used for gradient - accumulation at training time. 
`train_micro_batch_size` must be
-      divisible by `mini_batch_size`.
-    rollout_micro_batch_size: The micro-batch size used for model rollouts.
-    compute_logps_micro_batch_size: The micro-batch size used for computing log
-      probabilities (e.g. for reference and old policy models).
+    actor_critic_share_backbone: Whether to share the backbone of the actor and
+      critic models.
+    mini_batch_size: The mini-batch size used for policy weight updates; one
+      mini-batch corresponds to one optimizer update. If None, it defaults to
+      the batch size used for data loading, which must be a multiple of
+      `mini_batch_size`.
+    training_micro_batch_size: The micro-batch size used for gradient
+      accumulation at training time. If None, it defaults to
+      `mini_batch_size`; `gradient_accumulation_steps` is derived as
+      `mini_batch_size // training_micro_batch_size`.
+    rollout_micro_batch_size: The micro-batch size used for model rollouts. If
+      None, it defaults to `training_micro_batch_size`.
+    compute_logps_micro_batch_size: The micro-batch size used for computing log
+      probabilities (e.g., for reference and old policy models). If None, it
+      defaults to `training_micro_batch_size`.
   """
 
   actor_optimizer: optax.GradientTransformation
   critic_optimizer: optax.GradientTransformation | None = None
   mini_batch_size: int | None = None
-  train_micro_batch_size: int | None = None
+  training_micro_batch_size: int | None = None
   rollout_micro_batch_size: int | None = None
   compute_logps_micro_batch_size: int | None = None
 
   def __post_init__(self):
     """Validates the configuration after initialization."""
-    for name in [
-        "mini_batch_size",
-        "train_micro_batch_size",
-        "rollout_micro_batch_size",
-        "compute_logps_micro_batch_size",
-    ]:
-      rl_utils.check_positive(getattr(self, name), name)
+    # Verify all batch sizes are positive.
+    def _check_positive(value: int | None, name: str):
+      """Checks if the value is positive."""
+      if value is not None and value <= 0:
+        raise ValueError(f"{name} must be positive.")
+
+    _check_positive(self.mini_batch_size, "mini_batch_size")
+    _check_positive(self.training_micro_batch_size, "training_micro_batch_size")
+    _check_positive(self.rollout_micro_batch_size, "rollout_micro_batch_size")
+    _check_positive(
+        self.compute_logps_micro_batch_size, "compute_logps_micro_batch_size"
+    )
+
+    if self.gradient_accumulation_steps == 1:
+      self.gradient_accumulation_steps = None
+
+    # `gradient_accumulation_steps` must not be set directly; it is derived.
     if self.gradient_accumulation_steps is not None:
       raise ValueError(
           "For RL training, gradient_accumulation_steps should be None. It is "
-          "automatically derived from: "
-          "`mini_batch_size // train_micro_batch_size`."
+          "automatically inferred from: "
+          "`mini_batch_size // training_micro_batch_size`."
       )
 
-    if self.train_micro_batch_size is not None:
-      if self.mini_batch_size is None:
+    self.training_micro_batch_size, self.gradient_accumulation_steps = (
+        _compute_batch_sizes(
+            self.training_micro_batch_size,
+            self.mini_batch_size,
+            "training_micro_batch_size",
+            "mini_batch_size",
+            ret_grad_acc=True,
+        )
+    )
+
+    for batch_name in [
+        "rollout_micro_batch_size",
+        "compute_logps_micro_batch_size",
+    ]:
+      batch_size = getattr(self, batch_name)
+
+      if self.training_micro_batch_size is None and batch_size is not None:
         raise ValueError(
-            "For RL training, `batch_size` and `mini_batch_size` must be set"
-            " when `train_micro_batch_size` is set."
+            "`training_micro_batch_size` must be set when"
+            f" `{batch_name}` is set."
+        )
+      if batch_size is None:
+        batch_size = self.training_micro_batch_size
+      setattr(self, batch_name, batch_size)
+
+      if batch_size is not None:
+        rl_utils.check_batch_divisibility(
+            self.training_micro_batch_size,
+            batch_size,
+            "training_micro_batch_size",
+            batch_name,
         )
-      rl_utils.check_divisibility(
-          self.train_micro_batch_size,
-          self.mini_batch_size,
-          f"{self.train_micro_batch_size=}",
-          f"{self.mini_batch_size=}",
-      )
-      self.gradient_accumulation_steps = (
-          self.mini_batch_size // self.train_micro_batch_size
-      )
 
 
 @dataclasses.dataclass(kw_only=True, frozen=True)
@@ -764,7 +790,7 @@ def get_old_per_token_logps(
     """Gets the per-token logps of the current policy model."""
     batch_size = prompt_tokens.shape[0]
     if batch_size == 0:
-      raise ValueError("Cannot get old log probabilities from an empty batch.")
+      return jnp.array([], dtype=jnp.float32)
 
     micro_batch_size = micro_batch_size or batch_size
     with self.cluster_config.role_to_mesh[Role.ROLLOUT]:
@@ -843,3 +869,58 @@ def get_rewards(
         pad_id,
         eos_id,
     )
+
+
+def _compute_batch_sizes(
+    small_batch_size,
+    big_batch_size,
+    small_batch_size_name,
+    big_batch_size_name,
+    ret_grad_acc=False,
+):
+  """Computes and validates batch sizes.
+
+  There are four cases:
+  - big_batch_size: None, small_batch_size: None; allowed, grad_steps = 1.
+  - big_batch_size: None, small_batch_size: set; not allowed. If, say,
+    `mini_batch_size` is None, we want it to default to the dataloader batch
+    size, which is only available during training, so we cannot determine
+    `gradient_accumulation_steps` here.
+  - big_batch_size: set, small_batch_size: None; allowed, grad_steps = 1.
+  - Both set, in which case we check divisibility.
+
+  Args:
+    small_batch_size: The small batch size.
+    big_batch_size: The big batch size.
+    small_batch_size_name: The name of the small batch size.
+    big_batch_size_name: The name of the big batch size.
+    ret_grad_acc: Whether to also return the gradient accumulation steps.
+
+  Returns:
+    The resolved `small_batch_size`, plus `gradient_accumulation_steps` if
+    `ret_grad_acc` is True.
+  """
+  if big_batch_size is None and small_batch_size is not None:
+    # Case 2
+    raise ValueError(
+        f"`{big_batch_size_name}` ({big_batch_size}) must be set if "
+        f"`{small_batch_size_name}` ({small_batch_size}) is set."
+ ) + + # Case 1, 3 + if small_batch_size is None: + small_batch_size = big_batch_size + if ret_grad_acc: + return small_batch_size, 1 + return small_batch_size + + # Case 4 + rl_utils.check_batch_divisibility( + small_batch_size, + big_batch_size, + small_batch_size_name, + big_batch_size_name, + ) + if ret_grad_acc: + return small_batch_size, big_batch_size // small_batch_size + return small_batch_size diff --git a/tunix/rl/rl_learner.py b/tunix/rl/rl_learner.py index 69355f71..a6960e36 100644 --- a/tunix/rl/rl_learner.py +++ b/tunix/rl/rl_learner.py @@ -111,12 +111,17 @@ def __init__( self.executor = futures.ThreadPoolExecutor(max_workers=1) self._last_iter_step = self.rl_cluster.actor_trainer.iter_steps - self._training_config = self.rl_cluster.cluster_config.training_config + self._mini_batch_size = ( + self.rl_cluster.cluster_config.training_config.mini_batch_size + ) + self._training_micro_batch_size = ( + self.rl_cluster.cluster_config.training_config.training_micro_batch_size + ) self._rollout_micro_batch_size = ( - self._training_config.rollout_micro_batch_size + self.rl_cluster.cluster_config.training_config.rollout_micro_batch_size ) self._compute_logps_micro_batch_size = ( - self._training_config.compute_logps_micro_batch_size + self.rl_cluster.cluster_config.training_config.compute_logps_micro_batch_size ) sft_utils.show_hbm_usage(title="RLLearner init") @@ -267,7 +272,6 @@ def _process_accumulated_batches( produced: list[common.TrainExample] = [] offset = 0 - # TODO(tsbao): remove sample_repeat and set proper micro_batch_sizes outside for n in micro_batch_sizes: # Calculate slice indices start_idx = offset * sample_repeat @@ -285,7 +289,6 @@ def _prepare_data( proceed_num_steps: int, sample_repeat: int, batch_repeat: int, - service_target_batch_size: int, data_queue: queue_lib.AbstractDataQueue[list[common.TrainExample] | None], async_loading: bool = False, mode: rl_cluster_lib.Mode = rl_cluster_lib.Mode.TRAIN, @@ -325,15 +328,17 @@ def _prepare_data( `grpo_config.num_generations`. batch_repeat: The number of times the produced `TrainExample` batch should `grpo_config.num_iterations`. - service_target_batch_size: largest common multiple of rollout and - compute_logps micro-batch sizes. This is used to accumulate - micro-batches between the rollout and inference computation. data_queue: The queue to which lists of `TrainExample` are added. async_loading: If True, enqueue each produced micro-batch immediately in async mode. Otherwise, accumulate and enqueue at the boundary. mode: The metrics logger mode, either `metrics_logger.Mode.TRAIN` or `metrics_logger.Mode.EVAL`. """ + service_target_batch_size = math.lcm( + self._rollout_micro_batch_size, + self._compute_logps_micro_batch_size, + ) + # A buffer to accumulate micro-batches before processing them together. micro_batches: list[TrainingInputT] = [] # Number of samples for each micro-batch @@ -548,58 +553,57 @@ def train( skip_jit: bool = False, ) -> None: """Main entry point for the training loop.""" + grad_acc_steps = ( + self.rl_cluster.cluster_config.training_config.gradient_accumulation_steps + ) + if grad_acc_steps is None: + raise ValueError("Gradient accumulation steps must be set.") + full_batch_iterator = iter(train_ds) first_item = next(full_batch_iterator) full_batch_size = len(first_item["prompts"]) full_batch_iterator = itertools.chain([first_item], full_batch_iterator) # Initialize batch sizes. 
-    mini_batch_size = self._training_config.mini_batch_size or full_batch_size
-    train_micro_batch_size = (
-        self._training_config.train_micro_batch_size or mini_batch_size
-    )
-    self._rollout_micro_batch_size = (
-        self._rollout_micro_batch_size or train_micro_batch_size
-    )
-    self._compute_logps_micro_batch_size = (
-        self._compute_logps_micro_batch_size or train_micro_batch_size
-    )
-    for v, n in [
-        (self._rollout_micro_batch_size, f"{self._rollout_micro_batch_size=}"),
-        (
-            self._compute_logps_micro_batch_size,
-            f"{self._compute_logps_micro_batch_size=}",
-        ),
-        (mini_batch_size, f"{mini_batch_size=}"),
-    ]:
-      rl_utils.check_divisibility(v, full_batch_size, n, f"{full_batch_size=}")
-    grad_acc_steps = self._training_config.get_with_default(
-        "gradient_accumulation_steps", 1
-    )
-
-    logging.info(  # pylint: disable=logging-fstring-interpolation
-        f"Training with {full_batch_size=}, {mini_batch_size=},"
-        f" {train_micro_batch_size=}, {self._rollout_micro_batch_size=},"
-        f" {self._compute_logps_micro_batch_size=}, {grad_acc_steps=}"
+    if self._mini_batch_size is None:
+      self._mini_batch_size = full_batch_size
+      self._training_micro_batch_size = full_batch_size
+    if self._rollout_micro_batch_size is None:
+      self._rollout_micro_batch_size = self._training_micro_batch_size
+    if self._compute_logps_micro_batch_size is None:
+      self._compute_logps_micro_batch_size = self._training_micro_batch_size
+
+    rl_utils.check_batch_divisibility(
+        self._mini_batch_size,
+        full_batch_size,
+        "mini_batch_size",
+        "full_batch_size",
     )
 
-    service_target_batch_size = math.lcm(
+    logging.info(
+        "full_batch_size: %d, mini_batch_size: %d, training_micro_batch_size:"
+        " %d, rollout_micro_batch_size: %d, compute_logps_micro_batch_size: %d"
+        ", grad_acc_steps: %d",
+        full_batch_size,
+        self._mini_batch_size,
+        self._training_micro_batch_size,
         self._rollout_micro_batch_size,
         self._compute_logps_micro_batch_size,
+        grad_acc_steps,
     )
 
     # if the micro batch size is the same as the full batch size, we can use the
     # full batch iterator directly.
- if train_micro_batch_size == full_batch_size: + if self._training_micro_batch_size == full_batch_size: train_iterator = full_batch_iterator else: train_iterator = self._create_micro_batch_iterator( - full_batch_iterator, train_micro_batch_size + full_batch_iterator, self._training_micro_batch_size ) while True: # loop over M try: initial_steps = self._iter_steps - for _ in range(full_batch_size // mini_batch_size): + for _ in range(full_batch_size // self._mini_batch_size): # reserve 1 for None and the other for repeated interable # if batch_repeat > 1 train_data_queue = queue_lib.SimpleDataQueue( @@ -614,7 +618,6 @@ def train( proceed_num_steps=grad_acc_steps, sample_repeat=self._num_generations(), batch_repeat=self._num_iterations(), - service_target_batch_size=service_target_batch_size, data_queue=train_data_queue, async_loading=self.can_enable_async_rollout, mode=rl_cluster_lib.Mode.TRAIN, @@ -655,7 +658,6 @@ def train( proceed_num_steps=-1, sample_repeat=self._num_generations(), batch_repeat=1, - service_target_batch_size=service_target_batch_size, data_queue=eval_data_queue, async_loading=False, mode=rl_cluster_lib.Mode.EVAL, diff --git a/tunix/rl/utils.py b/tunix/rl/utils.py index 5b4fabc3..ff8c9b0c 100644 --- a/tunix/rl/utils.py +++ b/tunix/rl/utils.py @@ -135,22 +135,17 @@ def apply_slice(x: Any) -> Any: ) -def check_positive(value: int | None, name: str): - """Checks if the value is positive.""" - if value is not None and value <= 0: - raise ValueError(f"{name} must be positive.") - - -def check_divisibility( - small_size, - big_size, - small_size_name, - big_size_name, +def check_batch_divisibility( + small_batch_size, + big_batch_size, + small_batch_size_name, + big_batch_size_name, ): - """Checks if big_size is a multiple of small_size.""" - if big_size % small_size != 0: + """Checks if big_batch_size is a multiple of small_batch_size.""" + if big_batch_size % small_batch_size != 0: raise ValueError( - f"{big_size_name} must be a multiple of {small_size_name}." + f"{big_batch_size_name} ({big_batch_size}) must be a multiple " + f"of {small_batch_size_name} ({small_batch_size})." ) diff --git a/tunix/tests/test_common.py b/tunix/tests/test_common.py index 57ddc5a6..e18d11ab 100644 --- a/tunix/tests/test_common.py +++ b/tunix/tests/test_common.py @@ -23,7 +23,6 @@ import jax.numpy as jnp import numpy as np import qwix - import sentencepiece as spm if hasattr(flax_config, 'flax_always_shard_variable'):
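
Usage note: a minimal configuration sketch under the simplified micro-batching
scheme. The values mirror test case '6' in rl_cluster_test.py above; the import
alias, module path, and optimizer choice are illustrative assumptions, not part
of this change.

    import optax
    # Module path inferred from tunix/rl/rl_cluster.py.
    from tunix.rl import rl_cluster as rl_cluster_lib

    config = rl_cluster_lib.RLTrainingConfig(
        actor_optimizer=optax.sgd(1e-3),
        eval_every_n_steps=1,
        # One optimizer update per 16 prompts, accumulated over
        # 16 // 8 = 2 micro-batches of 8.
        mini_batch_size=16,
        training_micro_batch_size=8,
        # Rollout and log-prob micro-batch sizes must be multiples of
        # `training_micro_batch_size`; if left as None they default to it.
        rollout_micro_batch_size=8,
        compute_logps_micro_batch_size=16,
    )
    assert config.gradient_accumulation_steps == 2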