Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 61 additions & 3 deletions skyrl-tx/tx/tinker/backends/skyrl_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
Currently supports a single model only.
"""

import os
import tarfile
import tempfile
from typing import Any

import torch
Expand Down Expand Up @@ -221,11 +224,66 @@ def sample(
) -> dict[str, types.SampleOutput | types.ErrorResponse]:
raise NotImplementedError("Sampling not supported")

def _validate_model_state(self, model_id: str) -> None:
"""Validate that model exists and is initialized."""
if model_id != self._model_id:
raise ValueError(f"Model {model_id} not found")
if self._dispatch is None:
raise RuntimeError("Model not initialized")

def _create_tar_from_directory(self, source_dir: str, output_path: str) -> None:
"""Create an uncompressed tar archive from a directory."""
# Ensure parent directory exists
os.makedirs(os.path.dirname(output_path), exist_ok=True)

# Use uncompressed tar - gzip adds 5-10min CPU time on 6-7GB FSDP checkpoints
with tarfile.open(output_path, "w") as tar:
tar.add(source_dir, arcname=".")

def save_checkpoint(self, output_path, model_id: str) -> None:
    """Save full training checkpoint (model + optimizer + scheduler) as tar.

    Args:
        output_path: Path where the tar archive is written.
        model_id: Identifier of the model to checkpoint; must match the
            single model managed by this backend.

    Raises:
        ValueError: if ``model_id`` is unknown.
        RuntimeError: if the model has not been initialized.
    """
    # NOTE(review): the stale `raise NotImplementedError` left over from the
    # previous stub made this entire body unreachable; removed.
    self._validate_model_state(model_id)

    # Stage the checkpoint in a temp dir so only the finished tar lands at
    # output_path.
    with tempfile.TemporaryDirectory() as temp_dir:
        ckpt_dir = os.path.join(temp_dir, "checkpoint")

        # Save checkpoint directory (includes optimizer state automatically)
        self._dispatch.save_checkpoint(model="policy", ckpt_dir=ckpt_dir, tokenizer=self._tokenizer)

        # Create tar archive
        self._create_tar_from_directory(ckpt_dir, output_path)

    logger.info(f"Saved checkpoint for {model_id} to {output_path}")

def load_checkpoint(self, checkpoint_path, model_id: str) -> None:
    """Load full training checkpoint (model + optimizer + scheduler) from tar.

    Args:
        checkpoint_path: Path to a tar archive previously produced by
            ``save_checkpoint``.
        model_id: Identifier of the model to restore; must match the single
            model managed by this backend.

    Raises:
        ValueError: if ``model_id`` is unknown.
        RuntimeError: if the model has not been initialized.
    """
    # NOTE(review): the stale `raise NotImplementedError` left over from the
    # previous stub made this entire body unreachable; removed.
    self._validate_model_state(model_id)

    # Extract tar to temp directory (filter='data' prevents path traversal attacks)
    with tempfile.TemporaryDirectory() as temp_dir:
        with tarfile.open(checkpoint_path, "r") as tar:
            tar.extractall(temp_dir, filter="data")

        # Load checkpoint (includes optimizer and scheduler states)
        self._dispatch.load_checkpoint(
            model="policy", ckpt_dir=temp_dir, load_optimizer_states=True, load_lr_scheduler_states=True
        )

    logger.info(f"Loaded checkpoint for {model_id} from {checkpoint_path}")

def save_sampler_checkpoint(self, output_path, model_id: str) -> None:
    """Save sampler checkpoint as tar (model only, no optimizer).

    Args:
        output_path: Path where the tar archive is written.
        model_id: Identifier of the model to export; must match the single
            model managed by this backend.

    Raises:
        ValueError: if ``model_id`` is unknown.
        RuntimeError: if the model has not been initialized.
    """
    # NOTE(review): the stale `raise NotImplementedError` left over from the
    # previous stub made this entire body unreachable; removed.
    self._validate_model_state(model_id)

    # Stage the HuggingFace export in a temp dir so only the finished tar
    # lands at output_path.
    with tempfile.TemporaryDirectory() as temp_dir:
        hf_dir = os.path.join(temp_dir, "model")

        # Save in HuggingFace format (model weights + tokenizer only)
        self._dispatch.save_hf_model(model="policy", hf_model_dir=hf_dir, tokenizer=self._tokenizer)

        # Create tar archive
        self._create_tar_from_directory(hf_dir, output_path)

    logger.info(f"Saved sampler checkpoint for {model_id} to {output_path}")
Loading