From e2510931a51b50856e3761de478c9bfded4c7ef8 Mon Sep 17 00:00:00 2001 From: Tyler Griggs Date: Fri, 30 Jan 2026 01:49:19 +0000 Subject: [PATCH 1/4] Implement checkpointing for Tinker SkyRL backend Add full checkpoint save/load functionality to SkyRLTrainBackend: - save_checkpoint(): Saves model + optimizer + scheduler state as uncompressed tar - load_checkpoint(): Restores full training state from tar checkpoint - save_sampler_checkpoint(): Exports model weights in HuggingFace format for inference Implementation wraps WorkerDispatch checkpoint methods and handles tar packaging. Uses uncompressed tar to avoid 5-10 minute gzip bottleneck on 6-7GB FSDP checkpoints. Co-Authored-By: Claude Sonnet 4.5 --- skyrl-tx/tx/tinker/backends/skyrl_train.py | 80 +++++++++++++++++++++- 1 file changed, 77 insertions(+), 3 deletions(-) diff --git a/skyrl-tx/tx/tinker/backends/skyrl_train.py b/skyrl-tx/tx/tinker/backends/skyrl_train.py index 5751484e7..6d116ad8c 100644 --- a/skyrl-tx/tx/tinker/backends/skyrl_train.py +++ b/skyrl-tx/tx/tinker/backends/skyrl_train.py @@ -4,6 +4,9 @@ Currently supports a single model only. """ +import os +import tarfile +import tempfile from typing import Any import torch @@ -222,10 +225,81 @@ def sample( raise NotImplementedError("Sampling not supported") def save_checkpoint(self, output_path, model_id: str) -> None: - raise NotImplementedError("Saving checkpoints not supported") + """Save full training checkpoint (model + optimizer + scheduler) as tar.""" + if model_id != self._model_id: + raise ValueError(f"Model {model_id} not found") + if self._dispatch is None: + raise RuntimeError("Model not initialized") + + # Ensure parent directory exists + os.makedirs(os.path.dirname(output_path), exist_ok=True) + + # Create temp directory for checkpoint + with tempfile.TemporaryDirectory() as temp_dir: + ckpt_dir = os.path.join(temp_dir, "checkpoint") + + # Save checkpoint directory (includes optimizer state automatically) + self._dispatch.save_checkpoint( + model="policy", + ckpt_dir=ckpt_dir, + tokenizer=self._tokenizer + ) + + # Create tar archive (uncompressed for speed) + # FSDP checkpoints are already large (6-7GB). Gzip compression adds + # 5-10 minutes of single-threaded CPU time that blocks training. 
+ with tarfile.open(output_path, "w") as tar: + tar.add(ckpt_dir, arcname=".") + + logger.info(f"Saved checkpoint for {model_id} to {output_path}") def load_checkpoint(self, checkpoint_path, model_id: str) -> None: - raise NotImplementedError("Loading checkpoints not supported") + """Load full training checkpoint (model + optimizer + scheduler) from tar.""" + if model_id != self._model_id: + raise ValueError(f"Model {model_id} not found") + if self._dispatch is None: + raise RuntimeError("Model not initialized") + if not os.path.exists(checkpoint_path): + raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}") + + # Extract tar to temp directory (auto-detects compression) + with tempfile.TemporaryDirectory() as temp_dir: + with tarfile.open(checkpoint_path, "r") as tar: + tar.extractall(temp_dir) + + # Load checkpoint (includes optimizer and scheduler states) + self._dispatch.load_checkpoint( + model="policy", + ckpt_dir=temp_dir, + load_optimizer_states=True, + load_lr_scheduler_states=True + ) + + logger.info(f"Loaded checkpoint for {model_id} from {checkpoint_path}") def save_sampler_checkpoint(self, output_path, model_id: str) -> None: - raise NotImplementedError("Sampler checkpoints not supported") + """Save sampler checkpoint as tar (model only, no optimizer).""" + if model_id != self._model_id: + raise ValueError(f"Model {model_id} not found") + if self._dispatch is None: + raise RuntimeError("Model not initialized") + + # Ensure parent directory exists + os.makedirs(os.path.dirname(output_path), exist_ok=True) + + # Create temp directory for HuggingFace export + with tempfile.TemporaryDirectory() as temp_dir: + hf_dir = os.path.join(temp_dir, "model") + + # Save in HuggingFace format (model weights + tokenizer only) + self._dispatch.save_hf_model( + model="policy", + hf_model_dir=hf_dir, + tokenizer=self._tokenizer + ) + + # Create tar archive (uncompressed for speed) + with tarfile.open(output_path, "w") as tar: + tar.add(hf_dir, arcname=".") + + logger.info(f"Saved sampler checkpoint for {model_id} to {output_path}") From 4125427fcac7a91de7b0fe5aa8f870a0dca71ac2 Mon Sep 17 00:00:00 2001 From: Tyler Griggs Date: Fri, 30 Jan 2026 01:58:02 +0000 Subject: [PATCH 2/4] Format: Consolidate function call arguments to single line Co-Authored-By: Claude Sonnet 4.5 --- skyrl-tx/tx/tinker/backends/skyrl_train.py | 17 +++-------------- 1 file changed, 3 insertions(+), 14 deletions(-) diff --git a/skyrl-tx/tx/tinker/backends/skyrl_train.py b/skyrl-tx/tx/tinker/backends/skyrl_train.py index 6d116ad8c..a53abb8fb 100644 --- a/skyrl-tx/tx/tinker/backends/skyrl_train.py +++ b/skyrl-tx/tx/tinker/backends/skyrl_train.py @@ -239,11 +239,7 @@ def save_checkpoint(self, output_path, model_id: str) -> None: ckpt_dir = os.path.join(temp_dir, "checkpoint") # Save checkpoint directory (includes optimizer state automatically) - self._dispatch.save_checkpoint( - model="policy", - ckpt_dir=ckpt_dir, - tokenizer=self._tokenizer - ) + self._dispatch.save_checkpoint(model="policy", ckpt_dir=ckpt_dir, tokenizer=self._tokenizer) # Create tar archive (uncompressed for speed) # FSDP checkpoints are already large (6-7GB). 
Gzip compression adds @@ -269,10 +265,7 @@ def load_checkpoint(self, checkpoint_path, model_id: str) -> None: # Load checkpoint (includes optimizer and scheduler states) self._dispatch.load_checkpoint( - model="policy", - ckpt_dir=temp_dir, - load_optimizer_states=True, - load_lr_scheduler_states=True + model="policy", ckpt_dir=temp_dir, load_optimizer_states=True, load_lr_scheduler_states=True ) logger.info(f"Loaded checkpoint for {model_id} from {checkpoint_path}") @@ -292,11 +285,7 @@ def save_sampler_checkpoint(self, output_path, model_id: str) -> None: hf_dir = os.path.join(temp_dir, "model") # Save in HuggingFace format (model weights + tokenizer only) - self._dispatch.save_hf_model( - model="policy", - hf_model_dir=hf_dir, - tokenizer=self._tokenizer - ) + self._dispatch.save_hf_model(model="policy", hf_model_dir=hf_dir, tokenizer=self._tokenizer) # Create tar archive (uncompressed for speed) with tarfile.open(output_path, "w") as tar: From 84feed7a38e09a96280fb5d3e9e7ba3a62ce79f5 Mon Sep 17 00:00:00 2001 From: Tyler Griggs Date: Fri, 30 Jan 2026 02:10:22 +0000 Subject: [PATCH 3/4] Shorten checkpoint compression comment Co-Authored-By: Claude Sonnet 4.5 --- skyrl-tx/tx/tinker/backends/skyrl_train.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/skyrl-tx/tx/tinker/backends/skyrl_train.py b/skyrl-tx/tx/tinker/backends/skyrl_train.py index a53abb8fb..77ff3c1b4 100644 --- a/skyrl-tx/tx/tinker/backends/skyrl_train.py +++ b/skyrl-tx/tx/tinker/backends/skyrl_train.py @@ -241,9 +241,7 @@ def save_checkpoint(self, output_path, model_id: str) -> None: # Save checkpoint directory (includes optimizer state automatically) self._dispatch.save_checkpoint(model="policy", ckpt_dir=ckpt_dir, tokenizer=self._tokenizer) - # Create tar archive (uncompressed for speed) - # FSDP checkpoints are already large (6-7GB). Gzip compression adds - # 5-10 minutes of single-threaded CPU time that blocks training. + # Use uncompressed tar - gzip adds 5-10min CPU time on 6-7GB FSDP checkpoints with tarfile.open(output_path, "w") as tar: tar.add(ckpt_dir, arcname=".") @@ -287,7 +285,7 @@ def save_sampler_checkpoint(self, output_path, model_id: str) -> None: # Save in HuggingFace format (model weights + tokenizer only) self._dispatch.save_hf_model(model="policy", hf_model_dir=hf_dir, tokenizer=self._tokenizer) - # Create tar archive (uncompressed for speed) + # Use uncompressed tar - gzip adds 5-10min CPU time on 6-7GB FSDP checkpoints with tarfile.open(output_path, "w") as tar: tar.add(hf_dir, arcname=".") From 92bb84f016383728785821ad17cea1a83cf1ff14 Mon Sep 17 00:00:00 2001 From: Tyler Griggs Date: Fri, 30 Jan 2026 17:21:24 +0000 Subject: [PATCH 4/4] Security and refactoring improvements for checkpointing Address PR review feedback: 1. Security: Add filter='data' to tarfile.extractall() to prevent path traversal (TarSlip) attacks where malicious archives could write outside the temp directory 2. Refactor: Extract duplicate validation logic into _validate_model_state() helper method (used by all 3 checkpoint methods) 3. Remove redundant os.path.exists() check that creates TOCTOU race condition - tarfile.open() already raises FileNotFoundError 4. 
Refactor: Extract common tar creation logic into _create_tar_from_directory() helper method to reduce duplication Co-Authored-By: Claude Sonnet 4.5 --- skyrl-tx/tx/tinker/backends/skyrl_train.py | 43 ++++++++++------------ 1 file changed, 20 insertions(+), 23 deletions(-) diff --git a/skyrl-tx/tx/tinker/backends/skyrl_train.py b/skyrl-tx/tx/tinker/backends/skyrl_train.py index 77ff3c1b4..6d198093a 100644 --- a/skyrl-tx/tx/tinker/backends/skyrl_train.py +++ b/skyrl-tx/tx/tinker/backends/skyrl_train.py @@ -224,16 +224,26 @@ def sample( ) -> dict[str, types.SampleOutput | types.ErrorResponse]: raise NotImplementedError("Sampling not supported") - def save_checkpoint(self, output_path, model_id: str) -> None: - """Save full training checkpoint (model + optimizer + scheduler) as tar.""" + def _validate_model_state(self, model_id: str) -> None: + """Validate that model exists and is initialized.""" if model_id != self._model_id: raise ValueError(f"Model {model_id} not found") if self._dispatch is None: raise RuntimeError("Model not initialized") + def _create_tar_from_directory(self, source_dir: str, output_path: str) -> None: + """Create an uncompressed tar archive from a directory.""" # Ensure parent directory exists os.makedirs(os.path.dirname(output_path), exist_ok=True) + # Use uncompressed tar - gzip adds 5-10min CPU time on 6-7GB FSDP checkpoints + with tarfile.open(output_path, "w") as tar: + tar.add(source_dir, arcname=".") + + def save_checkpoint(self, output_path, model_id: str) -> None: + """Save full training checkpoint (model + optimizer + scheduler) as tar.""" + self._validate_model_state(model_id) + # Create temp directory for checkpoint with tempfile.TemporaryDirectory() as temp_dir: ckpt_dir = os.path.join(temp_dir, "checkpoint") @@ -241,25 +251,19 @@ def save_checkpoint(self, output_path, model_id: str) -> None: # Save checkpoint directory (includes optimizer state automatically) self._dispatch.save_checkpoint(model="policy", ckpt_dir=ckpt_dir, tokenizer=self._tokenizer) - # Use uncompressed tar - gzip adds 5-10min CPU time on 6-7GB FSDP checkpoints - with tarfile.open(output_path, "w") as tar: - tar.add(ckpt_dir, arcname=".") + # Create tar archive + self._create_tar_from_directory(ckpt_dir, output_path) logger.info(f"Saved checkpoint for {model_id} to {output_path}") def load_checkpoint(self, checkpoint_path, model_id: str) -> None: """Load full training checkpoint (model + optimizer + scheduler) from tar.""" - if model_id != self._model_id: - raise ValueError(f"Model {model_id} not found") - if self._dispatch is None: - raise RuntimeError("Model not initialized") - if not os.path.exists(checkpoint_path): - raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}") + self._validate_model_state(model_id) - # Extract tar to temp directory (auto-detects compression) + # Extract tar to temp directory (filter='data' prevents path traversal attacks) with tempfile.TemporaryDirectory() as temp_dir: with tarfile.open(checkpoint_path, "r") as tar: - tar.extractall(temp_dir) + tar.extractall(temp_dir, filter="data") # Load checkpoint (includes optimizer and scheduler states) self._dispatch.load_checkpoint( @@ -270,13 +274,7 @@ def load_checkpoint(self, checkpoint_path, model_id: str) -> None: def save_sampler_checkpoint(self, output_path, model_id: str) -> None: """Save sampler checkpoint as tar (model only, no optimizer).""" - if model_id != self._model_id: - raise ValueError(f"Model {model_id} not found") - if self._dispatch is None: - raise RuntimeError("Model not 
initialized") - - # Ensure parent directory exists - os.makedirs(os.path.dirname(output_path), exist_ok=True) + self._validate_model_state(model_id) # Create temp directory for HuggingFace export with tempfile.TemporaryDirectory() as temp_dir: @@ -285,8 +283,7 @@ def save_sampler_checkpoint(self, output_path, model_id: str) -> None: # Save in HuggingFace format (model weights + tokenizer only) self._dispatch.save_hf_model(model="policy", hf_model_dir=hf_dir, tokenizer=self._tokenizer) - # Use uncompressed tar - gzip adds 5-10min CPU time on 6-7GB FSDP checkpoints - with tarfile.open(output_path, "w") as tar: - tar.add(hf_dir, arcname=".") + # Create tar archive + self._create_tar_from_directory(hf_dir, output_path) logger.info(f"Saved sampler checkpoint for {model_id} to {output_path}")