Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
170 changes: 4 additions & 166 deletions tests/rl/grpo/grpo_learner_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import itertools
import os
import tempfile
import types
import os
from absl.testing import absltest
from absl.testing import parameterized
import chex
from flax import nnx
from flax.nnx import filterlib
from grain import python as grain
import jax
import jax.numpy as jnp
from jax import sharding
from jax.interpreters import pxla
import jax.numpy as jnp
import numpy as np
import optax
import orbax.checkpoint as ocp
Expand All @@ -36,6 +34,7 @@
from tunix.rl.rollout import base_rollout
from tunix.tests import test_common as tc
from typing_extensions import override
import tempfile

os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=2'

Expand Down Expand Up @@ -152,7 +151,6 @@ def _prepare(dataset, sample_repeat, batch_repeat, grad_acc_steps):
batch_repeat=batch_repeat,
data_queue=data_queue,
async_loading=False,
service_target_batch_size=1,
)
while True:
item = data_queue.get(block=True)
Expand Down Expand Up @@ -882,7 +880,7 @@ def create_rl_cluster(grad_accu_steps, mini_batch_size):
# We can't set grad_acc_steps directly, so we do it through
# mini_batch_size and training_micro_batch_size.
mini_batch_size=mini_batch_size,
train_micro_batch_size=mini_batch_size // grad_accu_steps,
training_micro_batch_size=mini_batch_size // grad_accu_steps,
),
rollout_config=base_rollout.RolloutConfig(
max_tokens_to_generate=10,
Expand Down Expand Up @@ -938,166 +936,6 @@ def create_rl_cluster(grad_accu_steps, mini_batch_size):
) # max_steps * batch_size * num_generations
self.assertLen(first_trajectories['eval'], 8) # eval_rows * num_generations

@parameterized.named_parameters(
    {
        'testcase_name': 'single_update',
        'batch_size': 8,
        'mini_batch_size': 8,
        'train_micro_batch_size': 4,
        'rollout_micro_batch_size': 4,
        'compute_logps_micro_batch_size': 4,
    },
    {
        'testcase_name': 'multi_update',
        'batch_size': 8,
        'mini_batch_size': 4,
        'train_micro_batch_size': 2,
        'rollout_micro_batch_size': 2,
        'compute_logps_micro_batch_size': 2,
    },
    {
        'testcase_name': (
            'single_update_with_bigger_rollout_and_compute_logps'
        ),
        'batch_size': 8,
        'mini_batch_size': 8,
        'train_micro_batch_size': 4,
        'rollout_micro_batch_size': 8,
        'compute_logps_micro_batch_size': 8,
    },
    {
        'testcase_name': 'only_rollout_and_compute_logps',
        'batch_size': 8,
        'mini_batch_size': None,
        'train_micro_batch_size': None,
        'rollout_micro_batch_size': 4,
        'compute_logps_micro_batch_size': 4,
    },
    {
        'testcase_name': 'individible_batch_size',
        'batch_size': 20,
        'mini_batch_size': 20,
        'train_micro_batch_size': 10,
        'rollout_micro_batch_size': 4,
        'compute_logps_micro_batch_size': 4,
    },
)
def test_micro_batch_training(
    self,
    batch_size,
    mini_batch_size,
    train_micro_batch_size,
    rollout_micro_batch_size,
    compute_logps_micro_batch_size,
):
  """Compares a micro-batched GRPO run against a run without micro batching.

  Both runs must modify the actor params, record identical prompts per
  trajectory id, and finish with the expected number of train steps.
  """
  # The training source is built with repeat=10, which gives 40 rows.
  num_train_rows = 40

  def record_prompts_reward(trajectories, prompts, **kwargs):
    """Stores every prompt under its trajectory id; rewards are 0..n-1."""
    for trajectory_id, prompt in zip(kwargs['trajectory_ids'], prompts):
      trajectories[kwargs['mode']][trajectory_id] = prompt
    return jnp.arange(len(prompts))

  def build_learner(
      mini_batch_size,
      train_micro_batch_size,
      rollout_micro_batch_size,
      compute_logps_micro_batch_size,
      trajectories,
  ):
    """Returns a (GRPOLearner, actor model) pair for the given batch sizes."""
    vocab = tc.MockVocab()
    vocab_size = vocab.GetPieceSize()
    actor_model = tc.ToyTransformer(rngs=nnx.Rngs(0), vocab_size=vocab_size)
    reference_model = tc.ToyTransformer(
        rngs=nnx.Rngs(0), vocab_size=vocab_size
    )

    # Every role shares the mesh that is currently active.
    mesh = pxla.thread_resources.env.physical_mesh
    role_to_mesh = {
        role: mesh
        for role in (
            rl_cluster_lib.Role.ACTOR,
            rl_cluster_lib.Role.REFERENCE,
            rl_cluster_lib.Role.ROLLOUT,
        )
    }
    training_config = rl_cluster_lib.RLTrainingConfig(
        actor_optimizer=optax.sgd(1e-3),
        eval_every_n_steps=2,
        max_steps=20,
        mini_batch_size=mini_batch_size,
        train_micro_batch_size=train_micro_batch_size,
        rollout_micro_batch_size=rollout_micro_batch_size,
        compute_logps_micro_batch_size=compute_logps_micro_batch_size,
    )
    rollout_config = base_rollout.RolloutConfig(
        max_tokens_to_generate=10,
        max_prompt_length=32,
        kv_cache_size=256,
        temperature=0.5,
    )
    rl_cluster = rl_cluster_lib.RLCluster(
        actor=actor_model,
        reference=reference_model,
        tokenizer=vocab,
        cluster_config=rl_cluster_lib.ClusterConfig(
            role_to_mesh=role_to_mesh,
            rollout_engine='vanilla',
            offload_to_cpu=False,
            training_config=training_config,
            rollout_config=rollout_config,
        ),
    )
    learner = grpo_lib.GRPOLearner(
        rl_cluster=rl_cluster,
        # Bind this run's trajectory store into the reward function.
        reward_fns=lambda **kwargs: record_prompts_reward(
            trajectories=trajectories, **kwargs
        ),
        grpo_config=grpo_lib.GRPOConfig(
            num_generations=2,
            num_iterations=1,
        ),
    )
    return learner, actor_model

  train_ds = _dummy_dataset(MySource(repeat=10), batch_size=batch_size)
  eval_ds = _dummy_dataset(batch_size=1)

  # Reference run: every micro-batching knob is left unset.
  base_trajectories = {'train': {}, 'eval': {}}
  learner, actor = build_learner(
      mini_batch_size=None,
      train_micro_batch_size=None,
      rollout_micro_batch_size=None,
      compute_logps_micro_batch_size=None,
      trajectories=base_trajectories,
  )
  # Snapshot the params before training so later states can be compared.
  initial_params = jax.tree.map(jnp.copy, nnx.state(actor, nnx.Param))

  learner.train(train_ds, eval_ds)
  self.assertEqual(
      num_train_rows // batch_size,
      learner.rl_cluster.actor_trainer.train_steps,
  )
  jax.tree.map_with_path(
      tc.assert_not_equal, initial_params, nnx.state(actor, nnx.Param)
  )

  # Second run: same data, with the parameterized micro-batch sizes.
  micro_batch_trajectories = {'train': {}, 'eval': {}}
  learner, actor = build_learner(
      mini_batch_size=mini_batch_size,
      train_micro_batch_size=train_micro_batch_size,
      rollout_micro_batch_size=rollout_micro_batch_size,
      compute_logps_micro_batch_size=compute_logps_micro_batch_size,
      trajectories=micro_batch_trajectories,
  )
  learner.train(train_ds, eval_ds)
  jax.tree.map_with_path(
      tc.assert_not_equal, initial_params, nnx.state(actor, nnx.Param)
  )
  self.assertEqual(base_trajectories, micro_batch_trajectories)
  # Falls back to batch_size when no mini batch size is configured.
  rows_per_step = mini_batch_size if mini_batch_size else batch_size
  self.assertEqual(
      num_train_rows // rows_per_step,
      learner.rl_cluster.actor_trainer.train_steps,
  )


if __name__ == '__main__':
absltest.main()
71 changes: 56 additions & 15 deletions tests/rl/rl_cluster_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,27 +116,68 @@ def test_model_loading_with_resharding(self):
)
self.assertEqual(ref_model_mesh, actor_mesh)

# NOTE(review): this span comes from a rendered diff (the file header reports
# 56 additions & 15 deletions), so lines from two revisions appear to be
# interleaved here -- confirm against the actual file before editing.
# Presumably the pre-change method header, superseded by `test_batch_sizes`
# below -- TODO confirm.
def test_batch_size_config(self):
# Case columns: name, mini_batch_size, training_micro_batch_size,
# rollout_micro_batch_size, compute_logps_micro_batch_size, expected_values.
@parameterized.named_parameters(
('1', None, None, None, None, [None, None, None, None, 1]),
('2', 8, None, None, None, [8, 8, 8, 8, 1]),
('3', 8, 2, None, None, [8, 2, 2, 2, 4]),
('4', 8, 4, 8, None, [8, 4, 8, 4, 2]),
('5', 8, 4, None, 8, [8, 4, 4, 8, 2]),
('6', 16, 8, 8, 16, [16, 8, 8, 16, 2]),
)
def test_batch_sizes(
self,
mini_batch_size,
training_micro_batch_size,
rollout_micro_batch_size,
compute_logps_micro_batch_size,
expected_values,
):
# Builds an RLTrainingConfig from the parameterized sizes; the resulting
# fields are compared against `expected_values` at the end of the method.
cfg = rl_cluster_lib.RLTrainingConfig(
actor_optimizer=optax.sgd(1e-3),
critic_optimizer=None,
# NOTE(review): the next two literal kwargs repeat `mini_batch_size` and use
# the `train_micro_batch_size` spelling; presumably pre-change lines --
# TODO confirm.
mini_batch_size=8,
train_micro_batch_size=4,
mini_batch_size=mini_batch_size,
training_micro_batch_size=training_micro_batch_size,
rollout_micro_batch_size=rollout_micro_batch_size,
compute_logps_micro_batch_size=compute_logps_micro_batch_size,
eval_every_n_steps=1,
)
# NOTE(review): the fixed-value assertion and the ValueError loop below use
# literal sizes and the `train_micro_batch_size` kwarg; presumably pre-change
# lines -- TODO confirm.
self.assertEqual(cfg.gradient_accumulation_steps, 2)

for mini_batch_size, train_micro_batch_size in zip(
[8, -8, None], [3, 4, 4]
):
with self.assertRaises(ValueError):
rl_cluster_lib.RLTrainingConfig(
actor_optimizer=optax.sgd(1e-3),
critic_optimizer=None,
mini_batch_size=mini_batch_size,
train_micro_batch_size=train_micro_batch_size,
eval_every_n_steps=1,
)
# `expected_values` is ordered as: mini batch size, training micro batch size,
# rollout micro batch size, compute-logps micro batch size, and gradient
# accumulation steps.
self.assertEqual(
expected_values,
[
cfg.mini_batch_size,
cfg.training_micro_batch_size,
cfg.rollout_micro_batch_size,
cfg.compute_logps_micro_batch_size,
cfg.gradient_accumulation_steps,
],
)

# Case columns: name, mini_batch_size, training_micro_batch_size,
# rollout_micro_batch_size, compute_logps_micro_batch_size.
@parameterized.named_parameters(
    ('1', 2, 4, None, None),
    ('2', 8, 3, None, None),
    ('3', 8, 4, 3, None),
    ('4', 8, 4, None, 3),
    ('5', None, 2, None, None),
    ('6', None, None, 2, None),
)
def test_batch_sizes_errors(
    self,
    mini_batch_size,
    training_micro_batch_size,
    rollout_micro_batch_size,
    compute_logps_micro_batch_size,
):
  """Expects `RLTrainingConfig` to raise ValueError for each size combo."""
  batch_size_kwargs = {
      'mini_batch_size': mini_batch_size,
      'training_micro_batch_size': training_micro_batch_size,
      'rollout_micro_batch_size': rollout_micro_batch_size,
      'compute_logps_micro_batch_size': compute_logps_micro_batch_size,
  }
  with self.assertRaises(ValueError):
    rl_cluster_lib.RLTrainingConfig(
        actor_optimizer=optax.sgd(1e-3),
        critic_optimizer=None,
        eval_every_n_steps=1,
        **batch_size_kwargs,
    )

def test_generate_with_chat_template(self): # pylint: disable=g-doc-args
mesh = Mesh(
Expand Down
50 changes: 25 additions & 25 deletions tests/rl/rl_learner_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,26 +32,25 @@ def _num_generations(self):
class RLLearnerTest(parameterized.TestCase):

@parameterized.named_parameters(
('1', None, None, None, None, [32, 32], False),
('2', 8, None, None, None, [8, 8], False),
('3', 8, 2, None, None, [2, 2], False),
('4', 8, 4, 4, 4, [4, 4], False),
('5', 8, 4, 3, 4, [], True),
('6', 8, 4, 4, 3, [], True),
('1', None, None, None, None, [32, 32, 32, 32]),
('2', 8, None, None, None, [8, 8, 8, 8]),
('3', 8, 2, None, None, [8, 2, 2, 2]),
('4', 8, 4, 8, None, [8, 4, 8, 4]),
('5', 8, 4, None, 4, [8, 4, 4, 4]),
('6', 16, 8, 16, 8, [16, 8, 16, 8]),
)
def test_micro_batching(
self,
mini_batch_size,
train_micro_batch_size,
training_micro_batch_size,
rollout_micro_batch_size,
compute_logps_micro_batch_size,
expected_values,
expect_failure,
):
config = rl_cluster_lib.RLTrainingConfig(
actor_optimizer=optax.sgd(1e-3),
mini_batch_size=mini_batch_size,
train_micro_batch_size=train_micro_batch_size,
training_micro_batch_size=training_micro_batch_size,
rollout_micro_batch_size=rollout_micro_batch_size,
compute_logps_micro_batch_size=compute_logps_micro_batch_size,
eval_every_n_steps=1,
Expand All @@ -75,22 +74,23 @@ def test_micro_batching(
full_batch_size = 32
train_ds = [{'prompts': [''] * full_batch_size}]

if expect_failure:
with self.assertRaises(ValueError):
learner.train(train_ds)
else:
learner.train(train_ds)
(
expected_rollout_micro,
expected_compute_logps_micro,
) = expected_values

self.assertEqual(
learner._rollout_micro_batch_size, expected_rollout_micro
)
self.assertEqual(
learner._compute_logps_micro_batch_size, expected_compute_logps_micro
)
learner.train(train_ds)

(
expected_mini_batch,
expected_training_micro,
expected_rollout_micro,
expected_compute_logps_micro,
) = expected_values

self.assertEqual(learner._mini_batch_size, expected_mini_batch)
self.assertEqual(
learner._training_micro_batch_size, expected_training_micro
)
self.assertEqual(learner._rollout_micro_batch_size, expected_rollout_micro)
self.assertEqual(
learner._compute_logps_micro_batch_size, expected_compute_logps_micro
)


if __name__ == '__main__':
Expand Down
Loading
Loading