diff --git a/skyrl-train/skyrl_train/trainer.py b/skyrl-train/skyrl_train/trainer.py
index 1690d09cf..53be5bea8 100644
--- a/skyrl-train/skyrl_train/trainer.py
+++ b/skyrl-train/skyrl_train/trainer.py
@@ -1067,26 +1067,28 @@ def train_critic_and_policy(self, data: TrainingInputBatch):
         """
         Run the training step for the policy and critic models.
 
-        For Megatron strategy: uses ppo_train (training loop inside worker)
-        For FSDP strategy: uses forward_backward + optim_step (training loop in trainer)
+        For Megatron: Uses ppo_train via dispatch.
+        For FSDP/FSDP2: Uses forward_backward + optim_step via dispatch.
+
+        Dispatch handles offload/backload automatically when colocate_all=True.
         """
         data.metadata["global_step"] = self.global_step
         critic_status = None
 
         if self.cfg.trainer.strategy == "megatron":
-            # Megatron: training loop inside worker via ppo_train
+            # Megatron: use ppo_train via dispatch
             if self.has_critic:
                 with Timer("critic_train", self.all_timings):
                     critic_status = self.dispatch.ppo_train("critic", data)
             with Timer("policy_train", self.all_timings):
                 policy_status = self.dispatch.ppo_train("policy", data)
         else:
-            # FSDP: training loop in trainer via forward_backward + optim_step
+            # FSDP/FSDP2: use forward_backward + optim_step via dispatch
             if self.has_critic:
                 with Timer("critic_train", self.all_timings):
-                    critic_status = self._execute_training_step("critic", data)
+                    critic_status = self._execute_training_step("critic", data, "critic")
             with Timer("policy_train", self.all_timings):
-                policy_status = self._execute_training_step("policy", data)
+                policy_status = self._execute_training_step("policy", data, "policy")
 
         # Update metrics
         if critic_status is not None:
@@ -1096,6 +1098,7 @@ def train_critic_and_policy(self, data: TrainingInputBatch):
         for k, v in policy_status.items():
             self.all_metrics.update({f"policy/{k}": v})
 
+        # Empty cache after training
         self.dispatch.empty_cache()
 
         return policy_status
diff --git a/skyrl-train/skyrl_train/workers/megatron/megatron_worker.py b/skyrl-train/skyrl_train/workers/megatron/megatron_worker.py
index 8974abc28..f07941989 100644
--- a/skyrl-train/skyrl_train/workers/megatron/megatron_worker.py
+++ b/skyrl-train/skyrl_train/workers/megatron/megatron_worker.py
@@ -552,7 +552,6 @@ def ppo_train(self, train_data) -> "TrainingOutputBatch":
         # TODO: Convert this into 2 loops for minibatches and microbatches.
         micro_buffer = []
         for local_step, experience in enumerate(pbar):
-            # BatchIterator now yields Experience objects directly
             experience.to_device(torch.cuda.current_device())
             sequences = experience.sequences
             attention_mask = experience.attention_mask
diff --git a/skyrl-train/skyrl_train/workers/worker.py b/skyrl-train/skyrl_train/workers/worker.py
index 7fc253d6c..0f1726f4c 100644
--- a/skyrl-train/skyrl_train/workers/worker.py
+++ b/skyrl-train/skyrl_train/workers/worker.py
@@ -630,6 +630,27 @@ async def async_run_method(
 
 
 class PolicyWorkerBase(Worker):
+    # TODO(tgriggs): Remove once loss function naming is unified.
+    # Tinker loss_fn names -> SkyRL PolicyLossRegistry names
+    TINKER_LOSS_FN_MAP = {"ppo": "regular"}
+
+    @staticmethod
+    def convert_tinker_loss_config(loss_fn_config: Dict[str, Any]) -> Dict[str, Any]:
+        """Convert Tinker loss_fn_config to SkyRL algorithm config format.
+
+        Tinker uses absolute ratio bounds (e.g., 0.9, 1.1).
+        SkyRL uses offsets from 1.0 (e.g., 0.1, 0.1).
+        """
+        skyrl_config = {}
+        for k, v in loss_fn_config.items():
+            if k == "clip_low_threshold":
+                skyrl_config["eps_clip_low"] = 1.0 - v  # 0.9 -> 0.1
+            elif k == "clip_high_threshold":
+                skyrl_config["eps_clip_high"] = v - 1.0  # 1.1 -> 0.1
+            else:
+                skyrl_config[k] = v
+        return skyrl_config
+
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
         self.model: nn.Module = None
@@ -647,10 +668,6 @@ def _normalize_mini_batch_size(self):
         The worker no longer needs to know mini batch size - it processes whatever
         batch it receives, breaking it into micro batches. Gradient scaling happens
         at optim_step time based on how many micro batches were accumulated.
-
-        TODO: Rename to _init_gradient_accumulation_state once Megatron no longer
-        requires mini-batch normalization in its override. The name is kept for
-        backwards compatibility with Megatron which still does actual normalization.
         """
         if not hasattr(self, "mesh_rank") or self.mesh_rank is None:
             raise RuntimeError("mesh_rank must be initialized before calling _normalize_mini_batch_size()")
@@ -658,6 +675,23 @@ def _normalize_mini_batch_size(self):
         # Track micro batches for gradient scaling at optim_step
         self._micro_batches_accumulated = 0
 
+        dp_size = self.mesh_rank.dp_size
+        self.policy_mini_batch_size_per_gpu = (
+            self.cfg.trainer.policy_mini_batch_size * self.cfg.generator.n_samples_per_prompt // dp_size
+        )
+
+    def _get_loss_fn(self, loss_fn: Optional[str] = None) -> Callable:
+        """Get loss function from Tinker name or fall back to config."""
+        if loss_fn is None:
+            name = self.cfg.trainer.algorithm.policy_loss_type
+        elif loss_fn in self.TINKER_LOSS_FN_MAP:
+            name = self.TINKER_LOSS_FN_MAP[loss_fn]
+        else:
+            raise ValueError(
+                f"loss_fn '{loss_fn}' not yet supported. Supported: {list(self.TINKER_LOSS_FN_MAP.keys())}"
+            )
+        return PolicyLossRegistry.get(name)
+
     def forward_backward(self, data: TrainingInputBatch) -> Dict[str, float]:
         """
         Perform forward and backward passes for a batch, handling micro-batching internally.
@@ -758,6 +792,7 @@ def _forward_backward_micro(self, experience: Experience) -> Dict[str, float]:
         kl_loss_term = kl_loss * self.cfg.trainer.algorithm.kl_loss_coef
         loss = policy_loss + kl_loss_term - entropy_loss_term
 
+        # NO loss scaling here - gradient scaling happens at optim_step
         self.strategy.backward(loss, self.model, self.optimizer)
 
         status = {
@@ -894,10 +929,6 @@ def _normalize_mini_batch_size(self):
         The worker no longer needs to know mini batch size - it processes whatever
         batch it receives, breaking it into micro batches. Gradient scaling happens
         at optim_step time based on how many micro batches were accumulated.
-
-        TODO: Rename to _init_gradient_accumulation_state once Megatron no longer
-        requires mini-batch normalization in its override. The name is kept for
-        backwards compatibility with Megatron which still does actual normalization.
         """
         if not hasattr(self, "mesh_rank") or self.mesh_rank is None:
             raise RuntimeError("mesh_rank must be initialized before calling _normalize_mini_batch_size()")
diff --git a/skyrl-train/skyrl_train/workers/worker_dispatch.py b/skyrl-train/skyrl_train/workers/worker_dispatch.py
index 8baefe4f7..8b4485930 100644
--- a/skyrl-train/skyrl_train/workers/worker_dispatch.py
+++ b/skyrl-train/skyrl_train/workers/worker_dispatch.py
@@ -9,7 +9,7 @@
 """
 
 from dataclasses import dataclass
-from typing import Dict, List, Optional
+from typing import Any, Dict, List, Optional
 
 import ray
 from omegaconf import DictConfig
@@ -153,6 +153,8 @@ def forward(self, model: str, data: TrainingInputBatch) -> TrainingOutputBatch:
         output = concatenate_outputs_after_mesh_dispatch(self._actor_groups[model].actor_infos, results)
         return output
 
+    # === Training ===
+
     def forward_backward(self, model: str, data: TrainingInputBatch) -> Dict[str, float]:
         """Run forward/backward pass. Needs model + optimizer."""
         self._ensure_on_gpu(model, need_optimizer=True, need_model=True)
diff --git a/skyrl-train/tests/cpu/test_trainer.py b/skyrl-train/tests/cpu/test_trainer.py
index a704278e6..816f01101 100644
--- a/skyrl-train/tests/cpu/test_trainer.py
+++ b/skyrl-train/tests/cpu/test_trainer.py
@@ -206,9 +206,15 @@ def backload_to_gpu(self, non_blocking=True):
         def _forward_micro_batch(self, micro_batch):
             pass
 
-    def create_policy_worker_with_config(dp_size):
+    def create_policy_worker_with_config(
+        train_batch_size, policy_mini_batch_size, micro_train_batch_size_per_gpu, n_samples_per_prompt, dp_size
+    ):
         """Helper to create policy worker with specific config."""
         cfg = get_default_config()
+        cfg.trainer.train_batch_size = train_batch_size
+        cfg.trainer.policy_mini_batch_size = policy_mini_batch_size
+        cfg.trainer.micro_train_batch_size_per_gpu = micro_train_batch_size_per_gpu
+        cfg.generator.n_samples_per_prompt = n_samples_per_prompt
         cfg.trainer.algorithm.policy_loss_type = "regular"
 
         worker = TestPolicyWorker(
@@ -226,9 +232,15 @@ def create_policy_worker_with_config(dp_size):
 
         return worker
 
-    def create_critic_worker_with_config(dp_size):
+    def create_critic_worker_with_config(
+        train_batch_size, critic_mini_batch_size, micro_train_batch_size_per_gpu, n_samples_per_prompt, dp_size
+    ):
         """Helper to create critic worker with specific config."""
         cfg = get_default_config()
+        cfg.trainer.train_batch_size = train_batch_size
+        cfg.trainer.critic_mini_batch_size = critic_mini_batch_size
+        cfg.trainer.micro_train_batch_size_per_gpu = micro_train_batch_size_per_gpu
+        cfg.generator.n_samples_per_prompt = n_samples_per_prompt
 
         worker = TestCriticWorker(
             cfg=cfg,
@@ -246,28 +258,59 @@ def create_critic_worker_with_config(dp_size):
         return worker
 
     # Test Case 1: PolicyWorker initializes _micro_batches_accumulated
-    policy_worker = create_policy_worker_with_config(dp_size=4)
+    policy_worker = create_policy_worker_with_config(
+        train_batch_size=128, policy_mini_batch_size=16, micro_train_batch_size_per_gpu=2, n_samples_per_prompt=1, dp_size=4
+    )
     policy_worker._normalize_mini_batch_size()
 
     assert hasattr(policy_worker, "_micro_batches_accumulated")
     assert policy_worker._micro_batches_accumulated == 0
 
-    # Test Case 2: CriticWorker initializes _micro_batches_accumulated
-    critic_worker = create_critic_worker_with_config(dp_size=4)
+    # Test Case 2: Basic valid configuration for CriticWorker
+    # In the new design, critic worker only initializes _micro_batches_accumulated
+    critic_worker = create_critic_worker_with_config(
+        train_batch_size=128,
+        critic_mini_batch_size=8,
+        micro_train_batch_size_per_gpu=2,
+        n_samples_per_prompt=2,
+        dp_size=4,
+    )
     critic_worker._normalize_mini_batch_size()
 
-    assert hasattr(critic_worker, "_micro_batches_accumulated")
+    # Verify micro_batches_accumulated is initialized
     assert critic_worker._micro_batches_accumulated == 0
 
     # Test Case 3: Single GPU (dp_size=1) for PolicyWorker
-    policy_worker = create_policy_worker_with_config(dp_size=1)
+    policy_worker = create_policy_worker_with_config(
+        train_batch_size=32, policy_mini_batch_size=8, micro_train_batch_size_per_gpu=2, n_samples_per_prompt=1, dp_size=1
+    )
     policy_worker._normalize_mini_batch_size()
 
     assert hasattr(policy_worker, "_micro_batches_accumulated")
    assert policy_worker._micro_batches_accumulated == 0
 
-    # Test Case 4: Error case - mesh_rank not initialized
-    policy_worker_no_mesh = create_policy_worker_with_config(dp_size=4)
+    # Test Case 4: High n_samples_per_prompt for CriticWorker
+    # In the new design, critic worker only initializes _micro_batches_accumulated
+    critic_worker = create_critic_worker_with_config(
+        train_batch_size=256,
+        critic_mini_batch_size=32,
+        micro_train_batch_size_per_gpu=8,
+        n_samples_per_prompt=4,
+        dp_size=2,
+    )
+    critic_worker._normalize_mini_batch_size()
+
+    # Verify micro_batches_accumulated is initialized
+    assert critic_worker._micro_batches_accumulated == 0
+
+    # Test Case 5: Error case - mesh_rank not initialized
+    policy_worker_no_mesh = create_policy_worker_with_config(
+        train_batch_size=128,
+        policy_mini_batch_size=16,
+        micro_train_batch_size_per_gpu=2,
+        n_samples_per_prompt=1,
+        dp_size=4,
+    )
     policy_worker_no_mesh.mesh_rank = None
 
     with pytest.raises(RuntimeError, match="mesh_rank must be initialized"):
@@ -444,132 +487,6 @@ def create_test_config(
     validate_batch_sizes(cfg)
 
 
-def test_forward_backward_batch_calculations():
-    """Test the key batch calculations and control flow in forward_backward methods.
-
-    FSDP workers use the forward_backward + optim_step pattern:
-    - forward_backward handles micro-batching internally and accumulates gradients
-    - optim_step scales gradients by 1/num_accumulated and takes optimizer step
-    """
-
-    # Create test configuration
-    cfg = get_default_config()
-    cfg.trainer.micro_train_batch_size_per_gpu = 2
-    cfg.trainer.update_epochs_per_batch = 1
-    cfg.trainer.algorithm.policy_loss_type = "regular"
-    cfg.generator.sampling_params.temperature = 1.0
-
-    # Create dummy databatch with known size
-    batch_size = 12  # This will create 6 micro batches with micro_train_batch_size_per_gpu=2
-    response_length = 4  # number of actions
-    dummy_databatch = TrainingInputBatch(
-        {
-            "sequences": torch.randint(0, 100, (batch_size, 10)),  # dummy token sequences
-            "attention_mask": torch.ones(batch_size, 10),
-            "action_log_probs": torch.randn(batch_size, response_length),
-            "base_action_log_probs": torch.randn(batch_size, response_length),
-            "values": torch.randn(batch_size, response_length),
-            "returns": torch.randn(batch_size, response_length),
-            "advantages": torch.randn(batch_size, response_length),
-            "loss_mask": torch.ones(batch_size, response_length),
-            "response_mask": torch.ones(batch_size, response_length),
-            "rollout_logprobs": None,
-        },
-    )
-    dummy_databatch.metadata = {"global_step": 0, "response_length": response_length}
-
-    # Helper function to create worker with minimal setup
-    def create_test_worker(worker_class):
-        worker = worker_class(
-            cfg=cfg,
-            world_size=1,
-            rank=0,
-            local_rank=0,
-            master_addr="localhost",
-            master_port=12345,
-            sequence_parallel_size=1,
-        )
-
-        # Mock dependencies
-        worker.strategy = MagicMock()
-        worker.strategy.is_rank_0.return_value = False  # Disable progress bars
-        worker.strategy.all_reduce.return_value = {"loss": 0.5, "lr": 1e-4}
-
-        # Always set model for all worker types
-        worker.model = MagicMock()
-
-        return worker
-
-    # Test PolicyWorkerBase
-    policy_worker = create_test_worker(PolicyWorkerBase)
-
-    # Initialize _micro_batches_accumulated (normally done in _normalize_mini_batch_size)
-    policy_worker._micro_batches_accumulated = 0
-
-    # Mock _forward_backward_micro to track calls
-    policy_forward_backward_micro_calls = []
-
-    def mock_policy_forward_backward_micro(experience):
-        policy_forward_backward_micro_calls.append(experience)
-        return {"policy_loss": 0.5, "ppo_clip_ratio": 0.1, "policy_entropy": 2.0, "response_length": response_length}
-
-    policy_worker._forward_backward_micro = mock_policy_forward_backward_micro
-    policy_worker.record_memory = False
-
-    # Calculate expected values
-    dataloader = BatchIterator(
-        dummy_databatch, sample_batch_size=cfg.trainer.micro_train_batch_size_per_gpu, drop_last=False
-    )
-    expected_micro_batches = len(dataloader)  # Should be 6
-
-    # Run forward_backward
-    with (patch("torch.distributed.barrier"),):
-        result = policy_worker.forward_backward(dummy_databatch)
-
-    # Verify Policy Worker Results
-    assert (
-        len(policy_forward_backward_micro_calls) == expected_micro_batches
-    ), f"PolicyWorker: Expected {expected_micro_batches} _forward_backward_micro calls, got {len(policy_forward_backward_micro_calls)}"
-
-    # Verify _micro_batches_accumulated is set correctly
-    assert policy_worker._micro_batches_accumulated == expected_micro_batches
-
-    # Verify result structure
-    assert isinstance(result, dict)
-    assert "policy_loss" in result
-
-    # Test CriticWorkerBase with same pattern
-    critic_worker = create_test_worker(CriticWorkerBase)
-
-    # Initialize _micro_batches_accumulated (normally done in _normalize_mini_batch_size)
-    critic_worker._micro_batches_accumulated = 0
-
-    # Mock _forward_backward_micro for critic
-    critic_forward_backward_micro_calls = []
-
-    def mock_critic_forward_backward_micro(experience):
-        critic_forward_backward_micro_calls.append(experience)
-        return {"critic_loss": 0.3, "values_mean": 1.0}
-
-    critic_worker._forward_backward_micro = mock_critic_forward_backward_micro
-
-    # Run forward_backward for critic
-    with (patch("torch.distributed.barrier"),):
-        result = critic_worker.forward_backward(dummy_databatch)
-
-    # Verify Critic Worker Results
-    assert (
-        len(critic_forward_backward_micro_calls) == expected_micro_batches
-    ), f"CriticWorker: Expected {expected_micro_batches} _forward_backward_micro calls, got {len(critic_forward_backward_micro_calls)}"
-
-    # Verify _micro_batches_accumulated is set correctly
-    assert critic_worker._micro_batches_accumulated == expected_micro_batches
-
-    # Verify result structure for critic
-    assert isinstance(result, dict)
-    assert "critic_loss" in result
-
-
 def test_validate_batch_sizes_lcm_dp_requirement():
     """Ensure train_batch_size is >= lcm(policy_dp, ref_dp) when ref is used; else >= policy_dp."""
 
@@ -611,3 +528,28 @@ def create_config(train_batch_size, policy_dp, ref_dp, include_ref=True):
     # Pass: ref disabled -> requirement reduces to policy_dp. With policy_dp=2, tbs=2 is valid.
     cfg = create_config(train_batch_size=2, policy_dp=2, ref_dp=3, include_ref=False)
     validate_batch_sizes(cfg)
+
+
+def test_convert_tinker_loss_config():
+    """Test that Tinker absolute ratio bounds are correctly converted to SkyRL offsets."""
+    # Tinker uses absolute bounds: [0.9, 1.1]
+    # SkyRL uses offsets from 1.0: eps_clip_low=0.1, eps_clip_high=0.1
+    tinker_config = {"clip_low_threshold": 0.9, "clip_high_threshold": 1.1}
+    skyrl_config = PolicyWorkerBase.convert_tinker_loss_config(tinker_config)
+
+    assert skyrl_config["eps_clip_low"] == approx(0.1)
+    assert skyrl_config["eps_clip_high"] == approx(0.1)
+
+    # Test asymmetric bounds
+    tinker_config = {"clip_low_threshold": 0.8, "clip_high_threshold": 1.3}
+    skyrl_config = PolicyWorkerBase.convert_tinker_loss_config(tinker_config)
+
+    assert skyrl_config["eps_clip_low"] == approx(0.2)
+    assert skyrl_config["eps_clip_high"] == approx(0.3)
+
+    # Test passthrough of unknown keys
+    tinker_config = {"clip_low_threshold": 0.9, "some_other_key": 42}
+    skyrl_config = PolicyWorkerBase.convert_tinker_loss_config(tinker_config)
+
+    assert skyrl_config["eps_clip_low"] == approx(0.1)
+    assert skyrl_config["some_other_key"] == 42
diff --git a/skyrl-train/tests/gpu/gpu_ci/test_save_load_checkpoint.py b/skyrl-train/tests/gpu/gpu_ci/test_save_load_checkpoint.py
index 7918bb4b0..ffb538d01 100644
--- a/skyrl-train/tests/gpu/gpu_ci/test_save_load_checkpoint.py
+++ b/skyrl-train/tests/gpu/gpu_ci/test_save_load_checkpoint.py
@@ -28,7 +28,7 @@
 def run_one_training_step(
     actor_group,
     strategy,
-    data=None,
+    training_batch=None,
     megatron_batch=None,
 ):
     """Run forward_backward + optim_step to perform one training step."""
@@ -36,8 +36,8 @@ def run_one_training_step(
         assert megatron_batch is not None, "Megatron requires a TrainingInputBatch for ppo_train"
         return ray.get(actor_group.async_run_ray_method("mesh", "ppo_train", megatron_batch))
     else:
-        assert data is not None, f"{strategy} requires a TrainingInputBatch for forward_backward"
-        ray.get(actor_group.async_run_ray_method("mesh", "forward_backward", data=data))
+        assert training_batch is not None, f"{strategy} requires a TrainingInputBatch for forward_backward"
+        ray.get(actor_group.async_run_ray_method("pass_through", "forward_backward", training_batch))
         ray.get(actor_group.async_run_ray_method("pass_through", "optim_step"))
 
 
@@ -93,9 +93,8 @@ def test_save_load_checkpoint(ray_init_fixture, strategy, lora):
         checkpoint_dir = None
 
     # Create dummy training batches for training steps
-    dp_size = actor_group.actor_infos[0].rank.dp_size
-    dummy_batch_1 = make_dummy_training_batch(batch_size=dp_size)  # First training step
-    dummy_batch_2 = make_dummy_training_batch(batch_size=dp_size)  # Second training step
+    dummy_batch_1 = make_dummy_training_batch()  # First training step
+    dummy_batch_2 = make_dummy_training_batch()  # Second training step
     # Ensure the second batch is different from the first
     dummy_batch_2["sequences"] = torch.randint(100, 200, dummy_batch_2["sequences"].shape, device="cpu")
 
@@ -115,7 +114,7 @@ def test_save_load_checkpoint(ray_init_fixture, strategy, lora):
     run_one_training_step(
         actor_group,
         strategy,
-        data=dummy_batch_1,
+        training_batch=dummy_batch_1,
         megatron_batch=train_batch_1,
     )
 
@@ -161,7 +160,7 @@ def test_save_load_checkpoint(ray_init_fixture, strategy, lora):
     run_one_training_step(
         actor_group,
         strategy,
-        data=dummy_batch_2,
+        training_batch=dummy_batch_2,
         megatron_batch=train_batch_2,
     )
 
@@ -181,7 +180,7 @@ def test_save_load_checkpoint(ray_init_fixture, strategy, lora):
     run_one_training_step(
         actor_group,
         strategy,
-        data=dummy_batch_2,
+        training_batch=dummy_batch_2,
         megatron_batch=train_batch_2,
     )
 
diff --git a/skyrl-train/tests/gpu/gpu_ci/test_training_step.py b/skyrl-train/tests/gpu/gpu_ci/test_training_step.py
index 7880e00c2..3e12c60bd 100644
--- a/skyrl-train/tests/gpu/gpu_ci/test_training_step.py
+++ b/skyrl-train/tests/gpu/gpu_ci/test_training_step.py
@@ -66,11 +66,9 @@ async def test_policy_forward_backward_and_optim_step(ray_init_fixture, cfg, pac
         cfg=cfg,
     )
 
-    # Create TrainingInputBatch - worker's forward_backward handles micro-batching internally
-    dp_size = actor_group.actor_infos[0].rank.dp_size
-    dummy_batch = make_dummy_training_batch(batch_size=dp_size)
+    dummy_batch = make_dummy_training_batch()
 
-    results = ray.get(actor_group.async_run_ray_method("mesh", "forward_backward", data=dummy_batch))
+    results = ray.get(actor_group.async_run_ray_method("pass_through", "forward_backward", dummy_batch))
     ray.get(actor_group.async_run_ray_method("pass_through", "optim_step"))
 
     memory = ray.get(actor_group.async_run_ray_method("pass_through", "get_cuda_memory"))
@@ -116,11 +114,9 @@ async def test_critic_forward_backward_and_optim_step(ray_init_fixture, cfg, pac
         cfg=cfg,
     )
 
-    # Create TrainingInputBatch - worker's forward_backward handles micro-batching internally
-    dp_size = actor_group.actor_infos[0].rank.dp_size
-    dummy_batch = make_dummy_training_batch(batch_size=dp_size)
+    dummy_batch = make_dummy_training_batch()
 
-    results = ray.get(actor_group.async_run_ray_method("mesh", "forward_backward", data=dummy_batch))
+    results = ray.get(actor_group.async_run_ray_method("pass_through", "forward_backward", dummy_batch))
     ray.get(actor_group.async_run_ray_method("pass_through", "optim_step"))
 
     for result in results:
diff --git a/skyrl-train/tests/gpu/gpu_ci/test_worker_offload.py b/skyrl-train/tests/gpu/gpu_ci/test_worker_offload.py
index 450aa7511..3a593d971 100644
--- a/skyrl-train/tests/gpu/gpu_ci/test_worker_offload.py
+++ b/skyrl-train/tests/gpu/gpu_ci/test_worker_offload.py
@@ -92,10 +92,9 @@ async def test_critic_policy_offload_memory_and_correctness(ray_init_fixture, cf
     actor_group.backload_to_gpu()
     get_rank_0_memory(actor_group, "Before training")
 
-    dp_size = actor_group.actor_infos[0].rank.dp_size
-    dummy_batch = make_dummy_training_batch(batch_size=dp_size)
+    dummy_batch = make_dummy_training_batch()
     # Run first forward_backward + optim_step to get optimizer initialized and stepped
-    results = ray.get(actor_group.async_run_ray_method("mesh", "forward_backward", data=dummy_batch))
+    results = ray.get(actor_group.async_run_ray_method("pass_through", "forward_backward", dummy_batch))
     ray.get(actor_group.async_run_ray_method("pass_through", "optim_step"))
 
     after_training = get_rank_0_memory(actor_group, "After training")
@@ -141,7 +140,7 @@ async def test_critic_policy_offload_memory_and_correctness(ray_init_fixture, cf
     ), f"Memory after backload model should be greater than after backload optimizer: {after_backload} bytes, after backload optimizer: {after_backload_optimizer} bytes"
 
     # Run training again and ensure output consistency
-    results_backload = ray.get(actor_group.async_run_ray_method("mesh", "forward_backward", data=dummy_batch))
+    results_backload = ray.get(actor_group.async_run_ray_method("pass_through", "forward_backward", dummy_batch))
     ray.get(actor_group.async_run_ray_method("pass_through", "optim_step"))
 
     for i, result in enumerate(results):
@@ -330,11 +329,10 @@ def test_offload_after_ckpt(ray_init_fixture, strategy):
     get_rank_0_memory(actor_group, "After init")
 
     # Create dummy training batch for training steps
-    dp_size = actor_group.actor_infos[0].rank.dp_size
-    dummy_batch_1 = make_dummy_training_batch(batch_size=dp_size)
+    dummy_batch_1 = make_dummy_training_batch()  # First training step
 
     # Step 1: Do initial forward_backward + optim_step
-    ray.get(actor_group.async_run_ray_method("mesh", "forward_backward", data=dummy_batch_1))
+    ray.get(actor_group.async_run_ray_method("pass_through", "forward_backward", dummy_batch_1))
     ray.get(actor_group.async_run_ray_method("pass_through", "optim_step"))
 
     get_rank_0_memory(actor_group, "After training step 1")
diff --git a/skyrl-train/tests/gpu/test_save_load_model.py b/skyrl-train/tests/gpu/test_save_load_model.py
index 71593cb86..abd5fd448 100644
--- a/skyrl-train/tests/gpu/test_save_load_model.py
+++ b/skyrl-train/tests/gpu/test_save_load_model.py
@@ -54,7 +54,7 @@ def get_test_actor_config(strategy: str) -> DictConfig:
 def run_one_training_step(
     actor_group,
     strategy,
-    data=None,
+    training_batch=None,
     megatron_batch=None,
 ):
     """Run forward_backward + optim_step to perform one training step."""
@@ -62,8 +62,8 @@ def run_one_training_step(
         assert megatron_batch is not None, "Megatron requires a TrainingInputBatch for ppo_train"
         return ray.get(actor_group.async_run_ray_method("mesh", "ppo_train", megatron_batch))
     else:
-        assert data is not None, f"{strategy} requires a TrainingInputBatch for forward_backward"
-        ray.get(actor_group.async_run_ray_method("mesh", "forward_backward", data=data))
+        assert training_batch is not None, f"{strategy} requires a TrainingInputBatch for forward_backward"
+        ray.get(actor_group.async_run_ray_method("pass_through", "forward_backward", training_batch))
         ray.get(actor_group.async_run_ray_method("pass_through", "optim_step"))
 
 
@@ -105,15 +105,15 @@ def test_save_load_hf_model(ray_init_fixture, strategy):
         run_one_training_step(
             actor_group_1,
             strategy,
-            data=None,
+            training_batch=None,
             megatron_batch=train_batch_1,
         )
     else:
-        dummy_batch = make_dummy_training_batch(batch_size=dp_size)
+        dummy_batch = make_dummy_training_batch()
         run_one_training_step(
             actor_group_1,
             strategy,
-            data=dummy_batch,
+            training_batch=dummy_batch,
             megatron_batch=None,
         )