74 changes: 71 additions & 3 deletions skyrl-train/skyrl_train/tinker/backends/skyrl_train.py
@@ -6,6 +6,9 @@
 
 print("[DEBUG] skyrl_train.py: Starting imports...", flush=True)
 
+import os
+import tarfile
+import tempfile
 from typing import Any
 
 import torch
@@ -246,10 +249,75 @@ def sample(
         raise NotImplementedError("Sampling not supported")
 
     def save_checkpoint(self, output_path, model_id: str) -> None:
-        raise NotImplementedError("Saving checkpoints not supported")
+        """Save full training checkpoint (model + optimizer + scheduler) as an uncompressed tar archive."""
+        if model_id != self._model_id:
+            raise ValueError(f"Model {model_id} not found")
+        if self._dispatch is None:
+            raise RuntimeError("Model not initialized")
+
+        # Create temp directory for checkpoint
+        with tempfile.TemporaryDirectory() as temp_dir:
+            ckpt_dir = os.path.join(temp_dir, "checkpoint")
+
+            # Save checkpoint directory (includes optimizer state automatically)
+            self._dispatch.save_checkpoint(
+                model="policy",
+                ckpt_dir=ckpt_dir,
+                tokenizer=self._tokenizer
+            )
+
+            # Create tar archive (uncompressed for speed)
+            # FSDP checkpoints are already large (6-7 GB). Gzip compression adds
+            # 5-10 minutes of single-threaded CPU time that blocks training.
+            with tarfile.open(output_path, "w") as tar:
+                tar.add(ckpt_dir, arcname=".")
+
+        logger.info(f"Saved checkpoint for {model_id} to {output_path}")

     def load_checkpoint(self, checkpoint_path, model_id: str) -> None:
-        raise NotImplementedError("Loading checkpoints not supported")
+        """Load full training checkpoint (model + optimizer + scheduler) from a tar archive."""
+        if model_id != self._model_id:
+            raise ValueError(f"Model {model_id} not found")
+        if self._dispatch is None:
+            raise RuntimeError("Model not initialized")
+        if not os.path.exists(checkpoint_path):
+            raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
+
+        # Extract tar to temp directory ("r" mode auto-detects compression)
+        with tempfile.TemporaryDirectory() as temp_dir:
+            with tarfile.open(checkpoint_path, "r") as tar:
+                tar.extractall(temp_dir)
+
+            # Load checkpoint (includes optimizer and scheduler states)
+            self._dispatch.load_checkpoint(
+                model="policy",
+                ckpt_dir=temp_dir,
+                load_optimizer_states=True,
+                load_lr_scheduler_states=True
+            )
+
+        logger.info(f"Loaded checkpoint for {model_id} from {checkpoint_path}")

     def save_sampler_checkpoint(self, output_path, model_id: str) -> None:
-        raise NotImplementedError("Sampler checkpoints not supported")
+        """Save sampler checkpoint as an uncompressed tar archive (model weights only, no optimizer)."""
+        if model_id != self._model_id:
+            raise ValueError(f"Model {model_id} not found")
+        if self._dispatch is None:
+            raise RuntimeError("Model not initialized")
+
+        # Create temp directory for HuggingFace export
+        with tempfile.TemporaryDirectory() as temp_dir:
+            hf_dir = os.path.join(temp_dir, "model")
+
+            # Save in HuggingFace format (model weights + tokenizer only)
+            self._dispatch.save_hf_model(
+                model="policy",
+                hf_model_dir=hf_dir,
+                tokenizer=self._tokenizer
+            )
+
+            # Create tar archive (uncompressed for speed)
+            with tarfile.open(output_path, "w") as tar:
+                tar.add(hf_dir, arcname=".")
+
+        logger.info(f"Saved sampler checkpoint for {model_id} to {output_path}")