diff --git a/skyrl-tx/tests/tinker/test_engine.py b/skyrl-tx/tests/tinker/test_engine.py
index 3319a8c3e..706ace219 100644
--- a/skyrl-tx/tests/tinker/test_engine.py
+++ b/skyrl-tx/tests/tinker/test_engine.py
@@ -1,12 +1,13 @@
 from cloudpathlib import AnyPath
 from datetime import datetime, timedelta, timezone
+import pytest
 
 from sqlmodel import Session, SQLModel
 
 from tx.tinker.engine import TinkerEngine
 from tx.tinker.config import EngineConfig
 from tx.tinker import types
-from tx.tinker.db_models import SessionDB, ModelDB
+from tx.tinker.db_models import SessionDB, ModelDB, FutureDB, RequestStatus
 
 
 BASE_MODEL = "trl-internal-testing/tiny-Qwen3ForCausalLM"
@@ -80,3 +81,94 @@ def test_cleanup_stale_sessions():
     # Run cleanup and assert one model was unloaded
     assert engine.cleanup_stale_sessions() == 1
     assert not engine.backend.has_model(model_id)
+
+
+class TestMaxMicroBatches:
+    """Tests for max_micro_batches limiting in find_batchable_model_passes."""
+
+    @staticmethod
+    def _make_request_data(num_sequences: int) -> dict:
+        """Create ForwardBackwardInput request data with the given number of sequences."""
+        data = []
+        for _ in range(num_sequences):
+            data.append(
+                {
+                    "model_input": {"chunks": [{"tokens": [1, 2, 3]}]},
+                    "loss_fn_inputs": {
+                        "target_tokens": {"data": [2, 3, 4]},
+                        "weights": {"data": [1.0, 1.0, 1.0]},
+                        "advantages": {"data": [0.0, 0.0, 0.0]},
+                        "logprobs": {"data": [0.0, 0.0, 0.0]},
+                    },
+                }
+            )
+        return {"data": data, "loss_fn": "cross_entropy"}
+
+    @staticmethod
+    def _create_engine(train_micro_batch_size: int, max_micro_batches: int) -> TinkerEngine:
+        """Create an engine with the given micro batch configuration."""
+        config = EngineConfig(
+            base_model=BASE_MODEL,
+            checkpoints_base=AnyPath(""),
+            backend_config={
+                "max_lora_adapters": 4,
+                "max_lora_rank": 32,
+                "train_micro_batch_size": train_micro_batch_size,
+            },
+            max_micro_batches=max_micro_batches,
+            database_url="sqlite:///:memory:",
+        )
+        engine = TinkerEngine(config)
+        SQLModel.metadata.create_all(engine.db_engine)
+        return engine
+
+    def _add_requests(self, engine: TinkerEngine, sequence_counts: list[int]):
+        """Add FORWARD_BACKWARD requests with the given sequence counts."""
+        with Session(engine.db_engine) as session:
+            for num_sequences in sequence_counts:
+                session.add(
+                    FutureDB(
+                        request_type=types.RequestType.FORWARD_BACKWARD,
+                        model_id="model1",
+                        request_data=self._make_request_data(num_sequences),
+                        status=RequestStatus.PENDING,
+                    )
+                )
+            session.commit()
+
+    @pytest.mark.parametrize(
+        "train_micro_batch_size,max_micro_batches,sequence_counts,expected_count",
+        [
+            # Gradient accumulation mode: ceil(16/4) + ceil(20/4) = 4 + 5 = 9 <= 10; adding ceil(8/4) = 2 would exceed the limit
+            (4, 10, [16, 20, 8], 2),
+            # Full batch mode: each request counts as 1, so only 3 of the 4 requests fit in max_micro_batches=3
+            (0, 3, [100, 200, 50, 75], 3),
+            # Disabled: all requests included when max_micro_batches=0
+            (4, 0, [50] * 10, 10),
+        ],
+        ids=["gradient_accumulation", "full_batch_mode", "disabled"],
+    )
+    def test_micro_batch_limiting(self, train_micro_batch_size, max_micro_batches, sequence_counts, expected_count):
+        """Test that micro batches are limited correctly under different configurations."""
+        engine = self._create_engine(train_micro_batch_size, max_micro_batches)
+        self._add_requests(engine, sequence_counts)
+
+        with Session(engine.db_engine) as session:
+            result = engine.find_batchable_model_passes(session, types.RequestType.FORWARD_BACKWARD)
+
+        assert len(result) == expected_count
+
+    def test_always_includes_at_least_one_request(self):
+        """Test that at least one request is always included even if it exceeds the limit."""
+        # With train_micro_batch_size=4 and max_micro_batches=10,
+        # a request with 100 sequences = ceil(100/4) = 25 micro batches > 10,
+        # but it should still be included to avoid starvation.
+        engine = self._create_engine(train_micro_batch_size=4, max_micro_batches=10)
+        self._add_requests(engine, [100])
+
+        with Session(engine.db_engine) as session:
+            result = engine.find_batchable_model_passes(session, types.RequestType.FORWARD_BACKWARD)
+
+        assert len(result) == 1
+        _, req_data = list(result.values())[0]
+        assert len(req_data.data) == 100
diff --git a/skyrl-tx/tx/tinker/config.py b/skyrl-tx/tx/tinker/config.py
index e126e5499..ab11a3d33 100644
--- a/skyrl-tx/tx/tinker/config.py
+++ b/skyrl-tx/tx/tinker/config.py
@@ -51,6 +51,10 @@ class EngineConfig(BaseModel):
         default=300,
         description="Seconds without heartbeat before session is considered stale. Set to -1 to disable cleanup.",
     )
+    max_micro_batches: int = Field(
+        default=64,
+        description="Maximum number of micro batches per forward/forward_backward batch. Limits how many micro batches are processed before results are returned to clients. Set to 0 to disable.",
+    )
 
 
 def convert_env_var(env_name: str, env_value: str, expected_type: type):
diff --git a/skyrl-tx/tx/tinker/engine.py b/skyrl-tx/tx/tinker/engine.py
index b7f4b2917..2ba174a73 100644
--- a/skyrl-tx/tx/tinker/engine.py
+++ b/skyrl-tx/tx/tinker/engine.py
@@ -1,6 +1,7 @@
 """Background engine for processing training requests."""
 
 import argparse
+import math
 import time
 from contextlib import contextmanager
 from datetime import datetime, timedelta, timezone
@@ -270,6 +271,26 @@ def find_batchable_model_passes(
         # Filter: only include ops that come before their model's barrier
         batchable = [op for op in ops if op.model_id not in barriers or op.request_id < barriers[op.model_id]]
 
+        # Limit the total number of micro batches if configured
+        if self.config.max_micro_batches > 0 and isinstance(self.backend, JaxBackend):
+            micro_batch_size = self.backend.config.train_micro_batch_size
+            limited = []
+            total_micro_batches = 0
+            for op in batchable:
+                num_sequences = len(op.request_data.get("data", []))
+                if micro_batch_size > 0:
+                    # Gradient accumulation enabled: count actual micro batches
+                    num_micro_batches = math.ceil(num_sequences / micro_batch_size)
+                else:
+                    # Full batch mode: each request is processed as one unit
+                    num_micro_batches = 1
+                # Always include at least one request to avoid starvation
+                if limited and total_micro_batches + num_micro_batches > self.config.max_micro_batches:
+                    break
+                limited.append(op)
+                total_micro_batches += num_micro_batches
+            batchable = limited
+
         return {
             str(f.request_id): (f.model_id, types.ForwardBackwardInput.model_validate(f.request_data))
             for f in batchable
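
A note on the scheduling change: the loop added to `find_batchable_model_passes` greedily admits pending requests until the configured micro-batch budget is spent. The sketch below replays that accounting as a standalone function so the parametrized test cases can be checked by hand; `limit_by_micro_batches` and its parameter names are illustrative only, not part of the diff, and the fully disabled case (`max_micro_batches == 0`) is handled outside this loop by the surrounding `if` in engine.py.

```python
import math


def limit_by_micro_batches(sequence_counts: list[int], micro_batch_size: int, max_micro_batches: int) -> int:
    """Return how many requests fit within the micro-batch budget.

    Each request costs ceil(num_sequences / micro_batch_size) micro batches
    when gradient accumulation is enabled (micro_batch_size > 0), and exactly
    one otherwise. The first request is always admitted to avoid starvation.
    """
    included, total = 0, 0
    for num_sequences in sequence_counts:
        cost = math.ceil(num_sequences / micro_batch_size) if micro_batch_size > 0 else 1
        # Stop once the budget would be exceeded, but never before admitting one request
        if included and total + cost > max_micro_batches:
            break
        included += 1
        total += cost
    return included


# The three parametrized cases from the test, replayed by hand:
assert limit_by_micro_batches([16, 20, 8], 4, 10) == 2        # 4 + 5 = 9 fit; +2 would exceed 10
assert limit_by_micro_batches([100, 200, 50, 75], 0, 3) == 3  # full batch mode: one unit each
assert limit_by_micro_batches([100], 4, 10) == 1              # 25 > 10, but the first is always admitted
```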
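Two details of the engine change worth noting: the starvation guard (`if limited and ...`) guarantees forward progress even when a single request alone exceeds `max_micro_batches`, which is exactly what `test_always_includes_at_least_one_request` pins down; and the `isinstance(self.backend, JaxBackend)` check scopes the limit to the backend that exposes `train_micro_batch_size`, so other backends are batched as before.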