diff --git a/skyrl-tx/tx/tinker/backends/skyrl_train.py b/skyrl-tx/tx/tinker/backends/skyrl_train.py
index 5751484e7..6d198093a 100644
--- a/skyrl-tx/tx/tinker/backends/skyrl_train.py
+++ b/skyrl-tx/tx/tinker/backends/skyrl_train.py
@@ -4,6 +4,9 @@
 Currently supports a single model only.
 """
 
+import os
+import tarfile
+import tempfile
 from typing import Any
 
 import torch
@@ -221,11 +224,66 @@ def sample(
     ) -> dict[str, types.SampleOutput | types.ErrorResponse]:
         raise NotImplementedError("Sampling not supported")
 
+    def _validate_model_state(self, model_id: str) -> None:
+        """Validate that model exists and is initialized."""
+        if model_id != self._model_id:
+            raise ValueError(f"Model {model_id} not found")
+        if self._dispatch is None:
+            raise RuntimeError("Model not initialized")
+
+    def _create_tar_from_directory(self, source_dir: str, output_path: str) -> None:
+        """Create an uncompressed tar archive from a directory."""
+        # Ensure parent directory exists
+        os.makedirs(os.path.dirname(output_path), exist_ok=True)
+
+        # Use uncompressed tar - gzip adds 5-10min CPU time on 6-7GB FSDP checkpoints
+        with tarfile.open(output_path, "w") as tar:
+            tar.add(source_dir, arcname=".")
+
     def save_checkpoint(self, output_path, model_id: str) -> None:
-        raise NotImplementedError("Saving checkpoints not supported")
+        """Save full training checkpoint (model + optimizer + scheduler) as tar."""
+        self._validate_model_state(model_id)
+
+        # Create temp directory for checkpoint
+        with tempfile.TemporaryDirectory() as temp_dir:
+            ckpt_dir = os.path.join(temp_dir, "checkpoint")
+
+            # Save checkpoint directory (includes optimizer state automatically)
+            self._dispatch.save_checkpoint(model="policy", ckpt_dir=ckpt_dir, tokenizer=self._tokenizer)
+
+            # Create tar archive
+            self._create_tar_from_directory(ckpt_dir, output_path)
+
+        logger.info(f"Saved checkpoint for {model_id} to {output_path}")
 
     def load_checkpoint(self, checkpoint_path, model_id: str) -> None:
-        raise NotImplementedError("Loading checkpoints not supported")
+        """Load full training checkpoint (model + optimizer + scheduler) from tar."""
+        self._validate_model_state(model_id)
+
+        # Extract tar to temp directory (filter='data' prevents path traversal attacks)
+        with tempfile.TemporaryDirectory() as temp_dir:
+            with tarfile.open(checkpoint_path, "r") as tar:
+                tar.extractall(temp_dir, filter="data")
+
+            # Load checkpoint (includes optimizer and scheduler states)
+            self._dispatch.load_checkpoint(
+                model="policy", ckpt_dir=temp_dir, load_optimizer_states=True, load_lr_scheduler_states=True
+            )
+
+        logger.info(f"Loaded checkpoint for {model_id} from {checkpoint_path}")
 
     def save_sampler_checkpoint(self, output_path, model_id: str) -> None:
-        raise NotImplementedError("Sampler checkpoints not supported")
+        """Save sampler checkpoint as tar (model only, no optimizer)."""
+        self._validate_model_state(model_id)
+
+        # Create temp directory for HuggingFace export
+        with tempfile.TemporaryDirectory() as temp_dir:
+            hf_dir = os.path.join(temp_dir, "model")
+
+            # Save in HuggingFace format (model weights + tokenizer only)
+            self._dispatch.save_hf_model(model="policy", hf_model_dir=hf_dir, tokenizer=self._tokenizer)
+
+            # Create tar archive
+            self._create_tar_from_directory(hf_dir, output_path)
+
+        logger.info(f"Saved sampler checkpoint for {model_id} to {output_path}")
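
For review context, here is a minimal, self-contained sketch of the tar round trip the new helpers rely on. The function names and paths below are illustrative, not part of the PR. One difference from the PR's helper: the sketch skips `os.makedirs()` when the archive path has no directory component, since `os.makedirs("")` raises `FileNotFoundError` for a bare filename.

```python
import os
import tarfile
import tempfile


def create_tar_from_directory(source_dir: str, output_path: str) -> None:
    # Uncompressed "w" mode: gzip would add minutes of CPU time on multi-GB checkpoints
    parent = os.path.dirname(output_path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    with tarfile.open(output_path, "w") as tar:
        # arcname="." stores members relative to the archive root,
        # so extraction reproduces the directory contents directly
        tar.add(source_dir, arcname=".")


def extract_tar(archive_path: str, dest_dir: str) -> None:
    with tarfile.open(archive_path, "r") as tar:
        # filter="data" (PEP 706; Python 3.12+, backported to 3.8-3.11
        # maintenance releases) rejects absolute paths and ".." members
        tar.extractall(dest_dir, filter="data")


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as work:
        src = os.path.join(work, "ckpt")
        os.makedirs(src)
        with open(os.path.join(src, "weights.bin"), "wb") as f:
            f.write(b"\x00" * 16)

        archive = os.path.join(work, "out", "ckpt.tar")
        create_tar_from_directory(src, archive)

        restored = os.path.join(work, "restored")
        extract_tar(archive, restored)
        assert os.path.isfile(os.path.join(restored, "weights.bin"))
```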
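
And a hedged sketch of how the new backend surface is exercised. Backend construction depends on the SkyRL training setup and is deliberately elided, so `backend` and the `checkpoint_round_trip` helper here are hypothetical stand-ins; only the three method calls come from the diff itself.

```python
from typing import Any


def checkpoint_round_trip(backend: Any, model_id: str, ckpt_tar: str, sampler_tar: str) -> None:
    """Hypothetical driver; `backend` is assumed to be an already-initialized
    skyrl_train backend with `model_id` registered."""
    # Full state: model + optimizer + scheduler, tarred from dispatch.save_checkpoint output
    backend.save_checkpoint(ckpt_tar, model_id)
    # HuggingFace-format weights + tokenizer only, for sampling
    backend.save_sampler_checkpoint(sampler_tar, model_id)
    # Resuming restores optimizer and LR scheduler state along with the weights
    backend.load_checkpoint(ckpt_tar, model_id)
```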