diff --git a/skyrl-train/skyrl_train/tinker/backends/skyrl_train.py b/skyrl-train/skyrl_train/tinker/backends/skyrl_train.py
index 3b32eeaf9..c99a4339d 100644
--- a/skyrl-train/skyrl_train/tinker/backends/skyrl_train.py
+++ b/skyrl-train/skyrl_train/tinker/backends/skyrl_train.py
@@ -6,6 +6,9 @@
 print("[DEBUG] skyrl_train.py: Starting imports...", flush=True)
 
+import os
+import tarfile
+import tempfile
 from typing import Any
 
 import torch
 
@@ -246,10 +249,75 @@ def sample(
         raise NotImplementedError("Sampling not supported")
 
     def save_checkpoint(self, output_path, model_id: str) -> None:
-        raise NotImplementedError("Saving checkpoints not supported")
+        """Save full training checkpoint (model + optimizer + scheduler) as an uncompressed tar archive."""
+        if model_id != self._model_id:
+            raise ValueError(f"Model {model_id} not found")
+        if self._dispatch is None:
+            raise RuntimeError("Model not initialized")
+
+        # Create temp directory for checkpoint
+        with tempfile.TemporaryDirectory() as temp_dir:
+            ckpt_dir = os.path.join(temp_dir, "checkpoint")
+
+            # Save checkpoint directory (includes optimizer state automatically)
+            self._dispatch.save_checkpoint(
+                model="policy",
+                ckpt_dir=ckpt_dir,
+                tokenizer=self._tokenizer,
+            )
+
+            # Create tar archive (uncompressed for speed)
+            # FSDP checkpoints are already large (6-7 GB). Gzip compression adds
+            # 5-10 minutes of single-threaded CPU time that blocks training.
+            with tarfile.open(output_path, "w") as tar:
+                tar.add(ckpt_dir, arcname=".")
+
+        logger.info(f"Saved checkpoint for {model_id} to {output_path}")
 
     def load_checkpoint(self, checkpoint_path, model_id: str) -> None:
-        raise NotImplementedError("Loading checkpoints not supported")
+        """Load full training checkpoint (model + optimizer + scheduler) from a tar archive."""
+        if model_id != self._model_id:
+            raise ValueError(f"Model {model_id} not found")
+        if self._dispatch is None:
+            raise RuntimeError("Model not initialized")
+        if not os.path.exists(checkpoint_path):
+            raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
+
+        # Extract tar to temp directory ("r" mode auto-detects compression)
+        with tempfile.TemporaryDirectory() as temp_dir:
+            with tarfile.open(checkpoint_path, "r") as tar:
+                tar.extractall(temp_dir)
+
+            # Load checkpoint (includes optimizer and scheduler states)
+            self._dispatch.load_checkpoint(
+                model="policy",
+                ckpt_dir=temp_dir,
+                load_optimizer_states=True,
+                load_lr_scheduler_states=True,
+            )
+
+        logger.info(f"Loaded checkpoint for {model_id} from {checkpoint_path}")
 
     def save_sampler_checkpoint(self, output_path, model_id: str) -> None:
-        raise NotImplementedError("Sampler checkpoints not supported")
+        """Save sampler checkpoint as an uncompressed tar archive (model only, no optimizer)."""
+        if model_id != self._model_id:
+            raise ValueError(f"Model {model_id} not found")
+        if self._dispatch is None:
+            raise RuntimeError("Model not initialized")
+
+        # Create temp directory for HuggingFace export
+        with tempfile.TemporaryDirectory() as temp_dir:
+            hf_dir = os.path.join(temp_dir, "model")
+
+            # Save in HuggingFace format (model weights + tokenizer only)
+            self._dispatch.save_hf_model(
+                model="policy",
+                hf_model_dir=hf_dir,
+                tokenizer=self._tokenizer,
+            )
+
+            # Create tar archive (uncompressed for speed; same rationale as save_checkpoint)
+            with tarfile.open(output_path, "w") as tar:
+                tar.add(hf_dir, arcname=".")
+
+        logger.info(f"Saved sampler checkpoint for {model_id} to {output_path}")
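
For reviewers, a minimal round-trip sketch of the three new methods, not part of the diff. The backend instance name (`backend`), the model id string, and the file paths below are all hypothetical; only `save_checkpoint`, `load_checkpoint`, and `save_sampler_checkpoint` come from this change.

    # Round-trip sketch; `backend` and "policy-model" are hypothetical placeholders.
    import os
    import tempfile

    with tempfile.TemporaryDirectory() as d:
        # Full training state: model + optimizer + scheduler, packed as uncompressed tar
        ckpt = os.path.join(d, "state.tar")
        backend.save_checkpoint(ckpt, model_id="policy-model")
        backend.load_checkpoint(ckpt, model_id="policy-model")

        # Sampler-only checkpoint: HuggingFace-format weights + tokenizer, no optimizer
        sampler = os.path.join(d, "sampler.tar")
        backend.save_sampler_checkpoint(sampler, model_id="policy-model")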