diff --git a/skyrl-train/skyrl_train/tinker/alembic.ini b/skyrl-train/skyrl_train/tinker/alembic.ini new file mode 100644 index 000000000..d3eae3e16 --- /dev/null +++ b/skyrl-train/skyrl_train/tinker/alembic.ini @@ -0,0 +1,145 @@ +# A generic, single database configuration. + +[alembic] +# path to migration scripts. +# Use absolute path to make it work from any directory +script_location = alembic + +# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s +# Uncomment the line below if you want the files to be prepended with date and time +# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file +# for all available tokens +# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s + +# sys.path path, will be prepended to sys.path if present. +# defaults to the current working directory. for multiple paths, the path separator +# is defined by "path_separator" below. +prepend_sys_path = . + + +# timezone to use when rendering the date within the migration file +# as well as the filename. +# If specified, requires the tzdata library which can be installed by adding +# `alembic[tz]` to the pip requirements. +# string value is passed to ZoneInfo() +# leave blank for localtime +# timezone = + +# max length of characters to apply to the "slug" field +# truncate_slug_length = 40 + +# set to 'true' to run the environment during +# the 'revision' command, regardless of autogenerate +# revision_environment = false + +# set to 'true' to allow .pyc and .pyo files without +# a source .py file to be detected as revisions in the +# versions/ directory +# sourceless = false + +# version location specification; This defaults +# to /versions. When using multiple version +# directories, initial revisions must be specified with --version-path. +# The path separator used here should be the separator specified by "path_separator" +# below. +# version_locations = %(here)s/bar:%(here)s/bat:%(here)s/alembic/versions + +# path_separator; This indicates what character is used to split lists of file +# paths, including version_locations and prepend_sys_path within configparser +# files such as alembic.ini. +# The default rendered in new alembic.ini files is "os", which uses os.pathsep +# to provide os-dependent path splitting. +# +# Note that in order to support legacy alembic.ini files, this default does NOT +# take place if path_separator is not present in alembic.ini. If this +# option is omitted entirely, fallback logic is as follows: +# +# 1. Parsing of the version_locations option falls back to using the legacy +# "version_path_separator" key, which if absent then falls back to the legacy +# behavior of splitting on spaces and/or commas. +# 2. Parsing of the prepend_sys_path option falls back to the legacy +# behavior of splitting on spaces, commas, or colons. +# +# Valid values for path_separator are: +# +# path_separator = : +# path_separator = ; +# path_separator = space +# path_separator = newline +# +# Use os.pathsep. Default configuration used for new projects. +path_separator = os + +# set to 'true' to search source files recursively +# in each "version_locations" directory +# new in Alembic version 1.10 +# recursive_version_locations = false + +# the output encoding used when revision files +# are written from script.py.mako +# output_encoding = utf-8 + +# database URL. This is consumed by the user-maintained env.py script only. 
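+# Note: for online migrations this project's env.py reads the database URL
+# from the TX_DATABASE_URL environment variable, so the sqlalchemy.url value
+# below only applies to offline ('--sql') mode.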
+# other means of configuring database URLs may be customized within the env.py +# file. +sqlalchemy.url = driver://user:pass@localhost/dbname + + +[post_write_hooks] +# post_write_hooks defines scripts or Python functions that are run +# on newly generated revision scripts. See the documentation for further +# detail and examples + +# format using "black" - use the console_scripts runner, against the "black" entrypoint +# hooks = black +# black.type = console_scripts +# black.entrypoint = black +# black.options = -l 79 REVISION_SCRIPT_FILENAME + +# lint with attempts to fix using "ruff" - use the module runner, against the "ruff" module +# hooks = ruff +# ruff.type = module +# ruff.module = ruff +# ruff.options = check --fix REVISION_SCRIPT_FILENAME + +# Alternatively, use the exec runner to execute a binary found on your PATH +# hooks = ruff +# ruff.type = exec +# ruff.executable = ruff +# ruff.options = check --fix REVISION_SCRIPT_FILENAME + +# Logging configuration. This is also consumed by the user-maintained +# env.py script only. +[loggers] +keys = root,sqlalchemy,alembic + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = WARNING +handlers = console +qualname = + +[logger_sqlalchemy] +level = WARNING +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = INFO +handlers = +qualname = alembic + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s +datefmt = %H:%M:%S diff --git a/skyrl-train/skyrl_train/tinker/alembic/README b/skyrl-train/skyrl_train/tinker/alembic/README new file mode 100644 index 000000000..98e4f9c44 --- /dev/null +++ b/skyrl-train/skyrl_train/tinker/alembic/README @@ -0,0 +1 @@ +Generic single-database configuration. \ No newline at end of file diff --git a/skyrl-train/skyrl_train/tinker/alembic/env.py b/skyrl-train/skyrl_train/tinker/alembic/env.py new file mode 100644 index 000000000..d4a75c215 --- /dev/null +++ b/skyrl-train/skyrl_train/tinker/alembic/env.py @@ -0,0 +1,83 @@ +from logging.config import fileConfig +import os +import sys +from pathlib import Path + +from sqlalchemy import pool + +from alembic import context + +# Add parent directory to path so we can import tx modules +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent)) + +# Import SQLModel and database models +from sqlmodel import SQLModel + +# this is the Alembic Config object, which provides +# access to the values within the .ini file in use. +config = context.config + +# Interpret the config file for Python logging. +# This line sets up loggers basically. +if config.config_file_name is not None: + fileConfig(config.config_file_name) + +# add your model's MetaData object here +# for 'autogenerate' support +# Use SQLModel metadata which includes all our table definitions +target_metadata = SQLModel.metadata + +# other values from the config, defined by the needs of env.py, +# can be acquired: +# my_important_option = config.get_main_option("my_important_option") +# ... etc. + + +def run_migrations_offline() -> None: + """Run migrations in 'offline' mode. + + This configures the context with just a URL + and not an Engine, though an Engine is acceptable + here as well. By skipping the Engine creation + we don't even need a DBAPI to be available. + + Calls to context.execute() here emit the given string to the + script output. 
+ + """ + url = config.get_main_option("sqlalchemy.url") + context.configure( + url=url, + target_metadata=target_metadata, + literal_binds=True, + dialect_opts={"paramstyle": "named"}, + ) + + with context.begin_transaction(): + context.run_migrations() + + +def run_migrations_online() -> None: + """Run migrations in 'online' mode. + + In this scenario we need to create an Engine + and associate a connection with the context. + + """ + from sqlalchemy import create_engine + + # Get database URL - ignore whatever is in config, use our helper + db_url = os.environ["TX_DATABASE_URL"] + connectable = create_engine(db_url, poolclass=pool.NullPool) + + with connectable.connect() as connection: + context.configure(connection=connection, target_metadata=target_metadata) + + with context.begin_transaction(): + context.run_migrations() + + +if context.is_offline_mode(): + run_migrations_offline() +else: + run_migrations_online() diff --git a/skyrl-train/skyrl_train/tinker/alembic/script.py.mako b/skyrl-train/skyrl_train/tinker/alembic/script.py.mako new file mode 100644 index 000000000..11016301e --- /dev/null +++ b/skyrl-train/skyrl_train/tinker/alembic/script.py.mako @@ -0,0 +1,28 @@ +"""${message} + +Revision ID: ${up_revision} +Revises: ${down_revision | comma,n} +Create Date: ${create_date} + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +${imports if imports else ""} + +# revision identifiers, used by Alembic. +revision: str = ${repr(up_revision)} +down_revision: Union[str, Sequence[str], None] = ${repr(down_revision)} +branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)} +depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)} + + +def upgrade() -> None: + """Upgrade schema.""" + ${upgrades if upgrades else "pass"} + + +def downgrade() -> None: + """Downgrade schema.""" + ${downgrades if downgrades else "pass"} diff --git a/skyrl-train/skyrl_train/tinker/api.py b/skyrl-train/skyrl_train/tinker/api.py new file mode 100644 index 000000000..626129a2b --- /dev/null +++ b/skyrl-train/skyrl_train/tinker/api.py @@ -0,0 +1,1171 @@ +import fastapi +from fastapi import FastAPI, HTTPException, Depends, Request +from fastapi.responses import StreamingResponse, RedirectResponse +from pydantic import BaseModel, Field, model_validator +from typing import Literal, Any, AsyncGenerator +from datetime import datetime, timedelta, timezone +from uuid import uuid4 +from contextlib import asynccontextmanager, suppress +from sqlmodel import SQLModel, select, func +from sqlmodel.ext.asyncio.session import AsyncSession +from sqlalchemy.ext.asyncio import create_async_engine +from sqlalchemy.exc import IntegrityError, TimeoutError as SATimeoutError +import asyncio +import os +import signal +import random +import threading +import time + +from skyrl_train.tinker import types +from skyrl_train.tinker.config import EngineConfig, add_model, config_to_argv +from skyrl_train.tinker.db_models import ( + CheckpointDB, + ModelDB, + FutureDB, + RequestStatus, + CheckpointStatus, + SessionDB, + SamplingSessionDB, + get_async_database_url, +) +from skyrl_train.tinker.extra import ExternalInferenceClient +from skyrl_train.tx_utils.storage import download_file +from skyrl_train.tx_utils.log import logger + +# Validation patterns for train_run_ids, model_ids and checkpoint_ids +ID_PATTERN = r"^[a-zA-Z0-9_-]+$" +ID_MAX_LENGTH = 255 + +# Timeout for graceful shutdown when engine crashes +SHUTDOWN_TIMEOUT_SECONDS = 10 + + +@asynccontextmanager +async def 
lifespan(app: FastAPI): + """Lifespan event handler for startup and shutdown.""" + + db_url = get_async_database_url(app.state.engine_config.database_url) + app.state.db_engine = create_async_engine(db_url, echo=False) + + async with app.state.db_engine.begin() as conn: + await conn.run_sync(SQLModel.metadata.create_all) + + # Setup external inference client if configured + if app.state.engine_config.external_inference_url: + app.state.external_inference_client = ExternalInferenceClient(app.state.engine_config, app.state.db_engine) + logger.info(f"External engine configured: {app.state.engine_config.external_inference_url}") + else: + app.state.external_inference_client = None + logger.info("Using internal engine for inference") + + # Build subprocess command with engine config parameters + cmd = ["uv", "run", "--extra", "vllm", "-m", "skyrl_train.tinker.engine"] + cmd.extend(config_to_argv(app.state.engine_config)) + + background_engine = await asyncio.create_subprocess_exec(*cmd) + app.state.background_engine = background_engine + logger.info(f"Started background engine with PID {background_engine.pid}: {' '.join(cmd)}") + + shutting_down = False + + async def monitor_engine(): + """Monitor engine process and exit API server if it crashes.""" + exit_code = await background_engine.wait() + if not shutting_down: + logger.error(f"Background engine crashed with exit code {exit_code}, exiting API server") + + # Start a background timer that force-exits after timeout. + # Using a thread instead of asyncio task because SIGTERM handling + # may wait for pending asyncio tasks to complete before exiting. + def force_exit(): + logger.warning("Graceful shutdown timed out, forcing exit") + os._exit(1) + + timer = threading.Timer(SHUTDOWN_TIMEOUT_SECONDS, force_exit) + timer.daemon = True + timer.start() + + # Request graceful shutdown. Uvicorn will stop accepting new + # connections and wait for active requests to complete. + # If shutdown doesn't complete in time, force_exit() will terminate. 
+ os.kill(os.getpid(), signal.SIGTERM) + + monitor_task = asyncio.create_task(monitor_engine()) + + yield + + shutting_down = True + monitor_task.cancel() + + logger.info(f"Stopping background engine (PID {app.state.background_engine.pid})") + with suppress(ProcessLookupError): + background_engine.terminate() + try: + await asyncio.wait_for(background_engine.wait(), timeout=5) + except asyncio.TimeoutError: + logger.warning(f"Background engine (PID {background_engine.pid}) did not terminate gracefully, killing") + background_engine.kill() + await background_engine.wait() + logger.info("Background engine stopped") + + +app = FastAPI(title="Tinker API Mock", version="0.0.1", lifespan=lifespan) + + +async def get_session(request: Request) -> AsyncGenerator[AsyncSession, None]: + """Dependency to get a database session.""" + async with AsyncSession(request.app.state.db_engine) as session: + yield session + + +async def get_model(session: AsyncSession, model_id: str) -> ModelDB: + """Fetch a model by ID, raising 404 if not found.""" + statement = select(ModelDB).where(ModelDB.model_id == model_id) + result = await session.exec(statement) + model = result.first() + if not model: + raise HTTPException(status_code=404, detail="Model not found") + return model + + +async def create_future( + session: AsyncSession, + request_type: types.RequestType, + model_id: str | None, + request_data: BaseModel, +) -> int: + """Create a FutureDB entry and return its auto-generated request_id.""" + future_db = FutureDB( + request_type=request_type, + model_id=model_id, + request_data=request_data.model_dump(), + status=RequestStatus.PENDING, + ) + session.add(future_db) + await session.flush() # Flush to generate auto-increment request_id + assert future_db.request_id + return future_db.request_id + + +async def create_checkpoint( + session: AsyncSession, + model_id: str, + checkpoint_id: str, + checkpoint_type: types.CheckpointType, +): + """Create a pending CheckpointDB entry, relying on database constraints for validation.""" + checkpoint_db = CheckpointDB( + model_id=model_id, + checkpoint_id=checkpoint_id, + checkpoint_type=checkpoint_type, + status=CheckpointStatus.PENDING, + ) + session.add(checkpoint_db) + + try: + await session.flush() + except IntegrityError: + # Determine which constraint failed by checking if the model exists + statement = select(ModelDB).where(ModelDB.model_id == model_id) + result = await session.exec(statement) + + if not result.first(): + raise HTTPException(status_code=404, detail=f"Model '{model_id}' not found") + else: + raise HTTPException( + status_code=409, detail=f"Checkpoint '{checkpoint_id}' already exists for model '{model_id}'" + ) + + +class LoRAConfig(BaseModel): + rank: int + seed: int | None = Field( + default=None, description="Seed for LoRA weight initialization. If None, a random seed is used." 
+ ) + + +class CreateModelRequest(BaseModel): + session_id: str + base_model: str + lora_config: LoRAConfig + + +class CreateModelResponse(BaseModel): + model_id: str + base_model: str + lora_config: LoRAConfig | None = None + status: str = "created" + request_id: str + + +class UnloadModelRequest(BaseModel): + model_id: str + type: str | None = None + + +class UnloadModelResponse(BaseModel): + request_id: str + model_id: str + + +class ModelData(BaseModel): + base_model: str + lora_config: LoRAConfig | None = None + model_name: str | None = None + + +class ModelInfoResponse(BaseModel): + model_id: str + status: str + model_data: ModelData + + +class Checkpoint(BaseModel): + checkpoint_id: str + checkpoint_type: Literal["training", "sampler"] + time: datetime + tinker_path: str + + +class TrainingRun(BaseModel): + training_run_id: str + base_model: str + model_owner: str = "default" + is_lora: bool = True + corrupted: bool = False + lora_rank: int | None = None + last_request_time: datetime + last_checkpoint: Checkpoint | None = None + last_sampler_checkpoint: Checkpoint | None = None + user_metadata: dict[str, str] | None = None + + +class ModelInputChunk(BaseModel): + tokens: list[int] + + def to_types(self) -> types.ModelInputChunk: + return types.ModelInputChunk(tokens=self.tokens) + + +class ModelInput(BaseModel): + chunks: list[ModelInputChunk] + + def to_types(self) -> types.ModelInput: + return types.ModelInput(chunks=[chunk.to_types() for chunk in self.chunks]) + + +class TensorData(BaseModel): + data: list[int] | list[float] + + def to_types(self) -> types.TensorData: + return types.TensorData(data=self.data) + + +class Datum(BaseModel): + loss_fn_inputs: dict[str, TensorData] + model_input: ModelInput + + def to_types(self) -> types.Datum: + inp = self.loss_fn_inputs + + if "weights" not in inp: + weights = types.TensorData(data=[1.0] * len(inp["target_tokens"].data)) + else: + weights = inp["weights"].to_types() + + return types.Datum( + loss_fn_inputs=types.LossFnInputs( + target_tokens=inp["target_tokens"].to_types(), + weights=weights, + advantages=inp["advantages"].to_types() if "advantages" in inp else types.TensorData(data=[]), + logprobs=inp["logprobs"].to_types() if "logprobs" in inp else types.TensorData(data=[]), + ), + model_input=self.model_input.to_types(), + ) + + +class ForwardBackwardInput(BaseModel): + data: list[Datum] + loss_fn: Literal["cross_entropy", "importance_sampling", "ppo"] + + def to_types(self) -> types.ForwardBackwardInput: + return types.ForwardBackwardInput(data=[datum.to_types() for datum in self.data], loss_fn=self.loss_fn) + + +class ForwardBackwardRequest(BaseModel): + model_id: str + forward_backward_input: ForwardBackwardInput + + +class ForwardRequest(BaseModel): + model_id: str + forward_input: ForwardBackwardInput + + +class AdamParams(BaseModel): + learning_rate: float = Field(default=1e-4, ge=0.0) + beta1: float = Field(default=0.9, ge=0.0, lt=1.0) + beta2: float = Field(default=0.95, ge=0.0, lt=1.0) + eps: float = Field(default=1e-12, gt=0.0) + weight_decay: float = Field(default=0.0, ge=0.0) + + def to_types(self) -> types.AdamParams: + return types.AdamParams( + learning_rate=self.learning_rate, + beta1=self.beta1, + beta2=self.beta2, + eps=self.eps, + weight_decay=self.weight_decay, + ) + + +class OptimStepRequest(BaseModel): + model_id: str + adam_params: AdamParams + + +class SaveWeightsForSamplerRequest(BaseModel): + model_id: str + path: str | None = Field(default=None, pattern=ID_PATTERN, max_length=ID_MAX_LENGTH) + 
sampling_session_seq_id: int | None = None + seq_id: int | None = None + type: Literal["save_weights_for_sampler"] = "save_weights_for_sampler" + + @model_validator(mode="after") + def check_path_or_ids(self): + if not self.path and (self.sampling_session_seq_id is None or self.seq_id is None): + raise ValueError("Either 'path' or both 'sampling_session_seq_id' and 'seq_id' must be provided") + return self + + +class SamplingParams(BaseModel): + max_tokens: int | None = None + seed: int | None = None + stop: list[int] | list[str] | None = None + temperature: float = 1 + top_k: int = -1 + top_p: float = 1 + + def to_types(self) -> types.SamplingParams: + if self.max_tokens is None: + raise HTTPException(status_code=400, detail="max_tokens is currently required") + + # Generate a random seed if not provided + seed = self.seed if self.seed is not None else random.randint(0, 2**31 - 1) + + # Determine if stop values are token IDs (int) or strings + stop_tokens = None + stop_strings = None + if self.stop: + if all(isinstance(s, int) for s in self.stop): + stop_tokens = list(self.stop) + elif all(isinstance(s, str) for s in self.stop): + stop_strings = list(self.stop) + else: + raise HTTPException( + status_code=400, + detail="stop must be either all integers (token IDs) or all strings, not mixed", + ) + + return types.SamplingParams( + temperature=self.temperature, + max_tokens=self.max_tokens, + seed=seed, + stop_tokens=stop_tokens, + stop_strings=stop_strings, + top_k=self.top_k, + top_p=self.top_p, + ) + + +class SampleRequest(BaseModel): + num_samples: int = 1 + prompt: ModelInput + sampling_params: SamplingParams + base_model: str | None = None + model_path: str | None = None + sampling_session_id: str | None = None + seq_id: int | None = None + prompt_logprobs: bool | None = None + topk_prompt_logprobs: int = 0 + type: Literal["sample"] = "sample" + + @model_validator(mode="after") + def validate_model_source(self): + """Valid if: + - sampling_session_id is provided AND seq_id is provided + - OR exactly one of base_model or model_path is provided + """ + if self.sampling_session_id is not None: + if self.seq_id is None: + raise ValueError("'seq_id' must be provided when 'sampling_session_id' is used") + return self + if (self.base_model is None) == (self.model_path is None): + raise ValueError( + "When 'sampling_session_id' is not provided, exactly one of 'base_model' or 'model_path' must be provided" + ) + return self + + +class SaveWeightsRequest(BaseModel): + model_id: str + path: str = Field(..., pattern=ID_PATTERN, max_length=ID_MAX_LENGTH) + type: Literal["save_weights"] | None = None + + +class LoadWeightsRequest(BaseModel): + model_id: str + path: str + type: Literal["load_weights"] | None = None + + +class FutureResponse(BaseModel): + future_id: str + status: str = "pending" + request_id: str + + +class TelemetryEvent(BaseModel): + event: str + event_id: str + event_session_index: int + severity: str + timestamp: str + properties: dict[str, Any] | None = None + + +class TelemetryRequest(BaseModel): + events: list[TelemetryEvent] + platform: str + sdk_version: str + session_id: str + + +class TelemetryResponse(BaseModel): + status: Literal["accepted"] = "accepted" + + +class HealthResponse(BaseModel): + status: Literal["ok"] + + +class CreateSessionRequest(BaseModel): + tags: list[str] + user_metadata: dict[str, Any] | None = None + sdk_version: str + type: Literal["create_session"] = "create_session" + + +class CreateSessionResponse(BaseModel): + type: Literal["create_session"] = 
"create_session" + info_message: str | None = None + warning_message: str | None = None + error_message: str | None = None + session_id: str + + +class SessionHeartbeatRequest(BaseModel): + session_id: str + type: Literal["session_heartbeat"] = "session_heartbeat" + + +class SessionHeartbeatResponse(BaseModel): + type: Literal["session_heartbeat"] = "session_heartbeat" + + +class CreateSamplingSessionRequest(BaseModel): + session_id: str + sampling_session_seq_id: int + base_model: str | None = None + model_path: str | None = None + type: Literal["create_sampling_session"] = "create_sampling_session" + + +class CreateSamplingSessionResponse(BaseModel): + type: Literal["create_sampling_session"] = "create_sampling_session" + sampling_session_id: str + + +class SupportedModel(BaseModel): + model_name: str + + +class GetServerCapabilitiesResponse(BaseModel): + supported_models: list[SupportedModel] + + +class ListCheckpointsResponse(BaseModel): + checkpoints: list[Checkpoint] + + +class Cursor(BaseModel): + offset: int + limit: int + total_count: int + + +class TrainingRunsResponse(BaseModel): + training_runs: list[TrainingRun] + cursor: Cursor + + +class WeightsInfoRequest(BaseModel): + tinker_path: str + + +class WeightsInfoResponse(BaseModel): + """Minimal information for loading public checkpoints.""" + + # from: https://github.com/thinking-machines-lab/tinker/blob/main/src/tinker/types/weights_info_response.py + base_model: str + is_lora: bool + lora_rank: int | None = None + + +@app.get("/api/v1/healthz", response_model=HealthResponse) +async def healthz(): + """Checks if the API server is ready.""" + return HealthResponse(status="ok") + + +@app.post("/api/v1/create_session", response_model=CreateSessionResponse) +async def create_session(request: CreateSessionRequest, session: AsyncSession = Depends(get_session)): + """Create a new session + persist in DB""" + session_id = f"session_{uuid4().hex[:8]}" + session_db = SessionDB( + session_id=session_id, + tags=request.tags, + user_metadata=request.user_metadata or {}, + sdk_version=request.sdk_version, + status="active", + ) + session.add(session_db) + await session.commit() + return CreateSessionResponse(session_id=session_id) + + +@app.post("/api/v1/session_heartbeat", response_model=SessionHeartbeatResponse) +async def session_heartbeat(request: SessionHeartbeatRequest, session: AsyncSession = Depends(get_session)): + """Heartbeat for an active session to keep it alive.""" + session_db = await session.get(SessionDB, request.session_id) + if session_db is None: + raise HTTPException(status_code=404, detail="Session not found") + session_db.last_heartbeat_at = datetime.now(timezone.utc) + session_db.heartbeat_count += 1 + await session.commit() + return SessionHeartbeatResponse() + + +@app.post("/api/v1/create_sampling_session", response_model=CreateSamplingSessionResponse) +async def create_sampling_session(request: CreateSamplingSessionRequest, session: AsyncSession = Depends(get_session)): + """Create a new sampling session within an existing session.""" + session_db = await session.get(SessionDB, request.session_id) + if session_db is None: + raise HTTPException(status_code=404, detail="Session not found") + # Exactly one of base_model or model_path must be provided + if (request.base_model is None) == (request.model_path is None): + raise HTTPException(status_code=400, detail="Exactly one of base_model or model_path must be provided") + sampling_session_id = f"sampling_{uuid4().hex[:8]}" + sampling_db = SamplingSessionDB( + 
sampling_session_id=sampling_session_id, + session_id=request.session_id, + sampling_session_seq_id=request.sampling_session_seq_id, + base_model=request.base_model, + model_path=request.model_path, + ) + session.add(sampling_db) + await session.commit() + return CreateSamplingSessionResponse(sampling_session_id=sampling_session_id) + + +@app.post("/api/v1/create_model", response_model=CreateModelResponse) +async def create_model(request: CreateModelRequest, session: AsyncSession = Depends(get_session)): + """Create a new model, optionally with a LoRA adapter.""" + # Validate session exists + session_db = await session.get(SessionDB, request.session_id) + if session_db is None: + raise HTTPException(status_code=404, detail="Session not found") + + model_id = f"model_{uuid4().hex[:8]}" + + # alpha = 32 seems to be the tinker default (see https://thinkingmachines.ai/blog/lora/) + # Generate a random seed if not provided + seed = request.lora_config.seed if request.lora_config.seed is not None else random.randint(0, 2**31 - 1) + lora_config = types.LoraConfig(rank=request.lora_config.rank, alpha=32.0, seed=seed) + request_id = await create_future( + session=session, + request_type=types.RequestType.CREATE_MODEL, + model_id=model_id, + request_data=types.CreateModelInput(lora_config=lora_config), + ) + + model_db = ModelDB( + model_id=model_id, + base_model=request.base_model, + lora_config=lora_config.model_dump(), + status="created", + request_id=request_id, + session_id=request.session_id, + ) + session.add(model_db) + + await session.commit() + + return CreateModelResponse( + model_id=model_id, + base_model=request.base_model, + lora_config=request.lora_config, + status="created", + request_id=str(request_id), + ) + + +@app.post("/api/v1/unload_model", response_model=UnloadModelResponse) +async def unload_model(request: UnloadModelRequest, session: AsyncSession = Depends(get_session)): + """Unload a model and free all associated resources.""" + # Validate model exists + model_db = await session.get(ModelDB, request.model_id) + if model_db is None: + raise HTTPException(status_code=404, detail="Model not found") + + # Update model status + model_db.status = "unloading" + + # Create future request + request_id = await create_future( + session=session, + request_type=types.RequestType.UNLOAD_MODEL, + model_id=request.model_id, + request_data=types.UnloadModelInput(), + ) + + await session.commit() + + return UnloadModelResponse(request_id=str(request_id), model_id=request.model_id) + + +class GetInfoRequest(BaseModel): + model_id: str + type: str | None = None + + +@app.post("/api/v1/get_info", response_model=ModelInfoResponse) +async def get_model_info(request: GetInfoRequest, session: AsyncSession = Depends(get_session)): + """Retrieve information about the current model.""" + model = await get_model(session, request.model_id) + + lora_config = types.LoraConfig.model_validate(model.lora_config) + model_data = ModelData( + base_model=model.base_model, lora_config=LoRAConfig(rank=lora_config.rank), model_name=model.base_model + ) + + return ModelInfoResponse(model_id=model.model_id, status=model.status, model_data=model_data) + + +@app.get("/api/v1/training_runs/{model_id}", response_model=TrainingRun) +async def get_training_run(model_id: str, session: AsyncSession = Depends(get_session)): + """Get training run for session resumption.""" + model = await get_model(session, model_id) + + lora_config = types.LoraConfig.model_validate(model.lora_config) + + return TrainingRun( + 
training_run_id=model.model_id, + base_model=model.base_model, + model_owner="default", + is_lora=True, + corrupted=False, + lora_rank=lora_config.rank, + # TODO: Once we track modified_at timestamps, update this + last_request_time=model.created_at, + last_checkpoint=None, + last_sampler_checkpoint=None, + user_metadata=None, + ) + + +@app.post("/api/v1/forward_backward", response_model=FutureResponse) +async def forward_backward(request: ForwardBackwardRequest, session: AsyncSession = Depends(get_session)): + """Compute and accumulate gradients.""" + await get_model(session, request.model_id) + + request_id = await create_future( + session=session, + request_type=types.RequestType.FORWARD_BACKWARD, + model_id=request.model_id, + request_data=request.forward_backward_input.to_types(), + ) + + await session.commit() + + return FutureResponse(future_id=str(request_id), status="pending", request_id=str(request_id)) + + +@app.post("/api/v1/forward", response_model=FutureResponse) +async def forward(request: ForwardRequest, session: AsyncSession = Depends(get_session)): + """Forward pass to obtain logprobs without accumulating gradients""" + await get_model(session, request.model_id) + + request_id = await create_future( + session=session, + request_type=types.RequestType.FORWARD, + model_id=request.model_id, + request_data=request.forward_input.to_types(), + ) + + await session.commit() + + return FutureResponse(future_id=str(request_id), status="pending", request_id=str(request_id)) + + +@app.post("/api/v1/optim_step", response_model=FutureResponse) +async def optim_step(request: OptimStepRequest, session: AsyncSession = Depends(get_session)): + """Update model using accumulated gradients.""" + await get_model(session, request.model_id) + + request_id = await create_future( + session=session, + request_type=types.RequestType.OPTIM_STEP, + model_id=request.model_id, + request_data=types.OptimStepInput(adam_params=request.adam_params.to_types()), + ) + + await session.commit() + + return FutureResponse(future_id=str(request_id), status="pending", request_id=str(request_id)) + + +@app.post("/api/v1/load_weights", response_model=FutureResponse) +async def load_weights(request: LoadWeightsRequest, req: Request, session: AsyncSession = Depends(get_session)): + """Loads weights and training state.""" + await get_model(session, request.model_id) + + path = types.TinkerPath.parse(request.path) + if ( + not path + or path.kind != "weights" + or not (source_model_id := path.primary_id) + or not (checkpoint_id := path.secondary_id) + ): + raise HTTPException( + status_code=400, detail="request.path must be in format tinker://source_model_id/weights/checkpoint_id" + ) + + await validate_checkpoint(req, source_model_id, checkpoint_id, types.CheckpointType.TRAINING, session) + + request_id = await create_future( + session=session, + request_type=types.RequestType.LOAD_WEIGHTS, + model_id=request.model_id, + request_data=types.LoadWeightsInput(source_model_id=source_model_id, checkpoint_id=checkpoint_id), + ) + + await session.commit() + + return FutureResponse(future_id=str(request_id), status="pending", request_id=str(request_id)) + + +@app.post("/api/v1/save_weights", response_model=FutureResponse) +async def save_weights(request: SaveWeightsRequest, session: AsyncSession = Depends(get_session)): + """Saves weights and training state.""" + # Create pending checkpoint entry (validates model exists) + await create_checkpoint( + session=session, + model_id=request.model_id, + checkpoint_id=request.path, + 
checkpoint_type=types.CheckpointType.TRAINING, + ) + + request_id = await create_future( + session=session, + request_type=types.RequestType.SAVE_WEIGHTS, + model_id=request.model_id, + request_data=types.SaveWeightsInput(path=request.path), + ) + + await session.commit() + + return FutureResponse(future_id=str(request_id), status="pending", request_id=str(request_id)) + + +@app.post("/api/v1/save_weights_for_sampler", response_model=FutureResponse) +async def save_weights_for_sampler(request: SaveWeightsForSamplerRequest, session: AsyncSession = Depends(get_session)): + """Saves weights in a format compatible with sampling/inference servers.""" + # Get the model (validates it exists and gives us the session_id) + model = await get_model(session, request.model_id) + + checkpoint_id = request.path or f"ss{request.sampling_session_seq_id}_seq{request.seq_id}" + sampling_session_id = None + if request.sampling_session_seq_id is not None and request.seq_id is not None: + # Create the sampling session using the model's session + sampling_session_id = f"sampling_{uuid4().hex[:8]}" + sampling_db = SamplingSessionDB( + sampling_session_id=sampling_session_id, + session_id=model.session_id, + sampling_session_seq_id=request.sampling_session_seq_id, + base_model=None, + model_path=f"tinker://{request.model_id}/sampler_weights/{checkpoint_id}", + ) + session.add(sampling_db) + + # Create pending checkpoint entry + await create_checkpoint( + session=session, + model_id=request.model_id, + checkpoint_id=checkpoint_id, + checkpoint_type=types.CheckpointType.SAMPLER, + ) + + request_id = await create_future( + session=session, + request_type=types.RequestType.SAVE_WEIGHTS_FOR_SAMPLER, + model_id=request.model_id, + request_data=types.SaveWeightsForSamplerInput( + path=checkpoint_id, + sampling_session_seq_id=request.sampling_session_seq_id, + seq_id=request.seq_id, + sampling_session_id=sampling_session_id, + ), + ) + + await session.commit() + + return FutureResponse(future_id=str(request_id), status="pending", request_id=str(request_id)) + + +async def get_sampling_model(request: SampleRequest, session: AsyncSession) -> (str | None, str | None): + """Return (base_model, model_path) for a sampling request.""" + # Resolve model/base from sampling_session_id if provided + if request.sampling_session_id is not None: + sampling_session = await session.get(SamplingSessionDB, request.sampling_session_id) + if sampling_session is None: + raise HTTPException(status_code=404, detail="Sampling session not found") + return (sampling_session.base_model, sampling_session.model_path) + return (request.base_model, request.model_path) + + +@app.post("/api/v1/asample", response_model=FutureResponse) +async def asample(request: SampleRequest, req: Request, session: AsyncSession = Depends(get_session)): + """Generates samples from the model (async version).""" + base_model, model_path = await get_sampling_model(request, session) + + if base_model: + model_id = checkpoint_id = "" + else: + assert model_path is not None + path = types.TinkerPath.parse(model_path) + if ( + not path + # Accept either tinker://model_id/checkpoint_id or tinker://model_id/sampler_weights/checkpoint_id + or path.kind not in ("", "sampler_weights") + or not (model_id := path.primary_id) + or not (checkpoint_id := path.secondary_id) + ): + raise HTTPException( + status_code=400, + detail="model_path must be tinker://model_id/checkpoint_id or tinker://model_id/sampler_weights/checkpoint_id", + ) + await get_model(session, model_id) + # Validate that 
the checkpoint exists and is ready + await validate_checkpoint(req, model_id, checkpoint_id, types.CheckpointType.SAMPLER, session) + + request_id = await create_future( + session=session, + request_type=( + types.RequestType.EXTERNAL if req.app.state.external_inference_client else types.RequestType.SAMPLE + ), + model_id=model_id, + request_data=types.SampleInput( + base_model=base_model, + prompt=request.prompt.to_types(), + sampling_params=request.sampling_params.to_types(), + num_samples=request.num_samples, + checkpoint_id=checkpoint_id, + prompt_logprobs=request.prompt_logprobs, + ), + ) + + await session.commit() + + if req.app.state.external_inference_client: + asyncio.create_task( + req.app.state.external_inference_client.call_and_store_result(request_id, request, model_id, checkpoint_id) + ) + + return FutureResponse(future_id=str(request_id), status="pending", request_id=str(request_id)) + + +@app.get("/api/v1/get_server_capabilities", response_model=GetServerCapabilitiesResponse) +async def get_server_capabilities(request: Request): + """Retrieve information about supported models and server capabilities.""" + supported_models = [ + SupportedModel(model_name=request.app.state.engine_config.base_model), + ] + return GetServerCapabilitiesResponse(supported_models=supported_models) + + +class RetrieveFutureRequest(BaseModel): + request_id: str + + +@app.post("/api/v1/retrieve_future") +async def retrieve_future(request: RetrieveFutureRequest, req: Request): + """Retrieve the result of an async operation, waiting until it's available.""" + timeout = 300 # 5 minutes + deadline = time.perf_counter() + timeout + + # Start with 100ms, grow to 1s + poll = 0.1 + max_poll = 1.0 + + while time.perf_counter() < deadline: + try: + async with AsyncSession(req.app.state.db_engine) as session: + # First, only query the status to avoid deserializing JSON data + statement = select(FutureDB.status).where(FutureDB.request_id == int(request.request_id)) + result = await session.exec(statement) + status = result.first() + + if not status: + raise HTTPException(status_code=404, detail="Future not found") + + # Only fetch full record if status is terminal (completed or failed) + if status in (RequestStatus.COMPLETED, RequestStatus.FAILED): + statement = select(FutureDB).where(FutureDB.request_id == int(request.request_id)) + result = await session.exec(statement) + future = result.first() + + if future.status == RequestStatus.COMPLETED: + return future.result_data + + if future.status == RequestStatus.FAILED: + # Return 400 for handled errors (validation, etc.), 500 for unexpected failures + if future.result_data and "error" in future.result_data: + raise HTTPException(status_code=400, detail=future.result_data["error"]) + else: + raise HTTPException(status_code=500, detail="Unknown error") + except SATimeoutError: + pass + + # Exponential backoff + await asyncio.sleep(poll) + poll = min(poll * 1.5, max_poll) + + raise HTTPException(status_code=408, detail="Timeout waiting for result") + + +@app.post("/api/v1/telemetry", response_model=TelemetryResponse) +async def send_telemetry(request: TelemetryRequest): + """Accept batches of SDK telemetry events for analytics and diagnostics.""" + # Just acknowledge receipt without doing anything + return TelemetryResponse(status="accepted") + + +async def validate_checkpoint( + request: Request, unique_id: str, checkpoint_id: str, checkpoint_type: types.CheckpointType, session: AsyncSession +): + """Validate that a model and checkpoint exist in the database, 
returning the checkpoint path.""" + checkpoint_db = await session.get(CheckpointDB, (unique_id, checkpoint_id, checkpoint_type)) + + if not checkpoint_db: + raise HTTPException(status_code=404, detail=f"Checkpoint not found: {unique_id}/{checkpoint_id}") + + if checkpoint_db.status == CheckpointStatus.PENDING: + raise HTTPException(status_code=425, detail="Checkpoint is still being created") + + if checkpoint_db.status == CheckpointStatus.FAILED: + raise HTTPException(status_code=500, detail=f"Checkpoint creation failed: {checkpoint_db.error_message}") + + subdir = "sampler_weights" if checkpoint_type == types.CheckpointType.SAMPLER else "" + return request.app.state.engine_config.checkpoints_base / unique_id / subdir / f"{checkpoint_id}.tar.gz" + + +@app.get("/api/v1/training_runs") +async def list_training_runs( + limit: int = 20, offset: int = 0, session: AsyncSession = Depends(get_session) +) -> TrainingRunsResponse: + """List all training runs""" + + # Use window function to get total count alongside paginated results in a single query + statement = select(ModelDB, func.count().over().label("total_count")).offset(offset).limit(limit) + result = await session.exec(statement) + rows = result.all() + + total_count = rows[0].total_count if rows else 0 + + training_runs = [] + for row in rows: + model = row.ModelDB + lora_config = types.LoraConfig.model_validate(model.lora_config) + + training_runs.append( + TrainingRun( + training_run_id=model.model_id, + base_model=model.base_model, + model_owner="default", + is_lora=True, + corrupted=False, + lora_rank=lora_config.rank, + last_request_time=model.created_at, # TODO: Once we track modified_at timestamps, update this + last_checkpoint=None, + last_sampler_checkpoint=None, + user_metadata=None, + ) + ) + + return TrainingRunsResponse( + training_runs=training_runs, cursor=Cursor(offset=offset, limit=limit, total_count=total_count) + ) + + +@app.get("/api/v1/training_runs/{unique_id}/checkpoints/{checkpoint_id}/archive") +async def get_checkpoint_archive_url( + request: Request, + unique_id: str = fastapi.Path(..., pattern=ID_PATTERN, max_length=ID_MAX_LENGTH), + checkpoint_id: str = fastapi.Path(..., pattern=ID_PATTERN, max_length=ID_MAX_LENGTH), + session: AsyncSession = Depends(get_session), +): + """Return a 302 redirect to the download URL (SDK expects this pattern)""" + await validate_checkpoint(request, unique_id, checkpoint_id, types.CheckpointType.SAMPLER, session) + + # Generate URL to the download endpoint and return 302 redirect + download_url = str(request.url_for("download_checkpoint_archive", unique_id=unique_id, checkpoint_id=checkpoint_id)) + expires = datetime.utcnow() + timedelta(minutes=120) + + response = RedirectResponse(url=download_url, status_code=302) + response.headers["Expires"] = expires.strftime("%a, %d %b %Y %H:%M:%S GMT") + return response + + +@app.get("/api/v1/training_runs/{unique_id}/checkpoints/{checkpoint_id}/download") +async def download_checkpoint_archive( + request: Request, + unique_id: str = fastapi.Path(..., pattern=ID_PATTERN, max_length=ID_MAX_LENGTH), + checkpoint_id: str = fastapi.Path(..., pattern=ID_PATTERN, max_length=ID_MAX_LENGTH), + session: AsyncSession = Depends(get_session), +): + """Actually download the checkpoint archive bytes""" + checkpoint_path = await validate_checkpoint( + request, unique_id, checkpoint_id, types.CheckpointType.SAMPLER, session + ) + + file_buffer = await asyncio.to_thread(download_file, checkpoint_path) + + filename = f"{unique_id}_{checkpoint_id}.tar.gz" + 
headers = { + "Content-Disposition": f'attachment; filename="{filename}"', + "Content-Length": str(file_buffer.getbuffer().nbytes), + } + + return StreamingResponse(file_buffer, media_type="application/octet-stream", headers=headers) + + +@app.get("/api/v1/training_runs/{unique_id}/checkpoints") +async def list_checkpoints( + unique_id: str = fastapi.Path(..., pattern=ID_PATTERN, max_length=ID_MAX_LENGTH), + session: AsyncSession = Depends(get_session), +): + """List checkpoints for a model.""" + statement = ( + select(CheckpointDB) + .where(CheckpointDB.model_id == unique_id) + .where(CheckpointDB.status == CheckpointStatus.COMPLETED) + ) + result = await session.exec(statement) + + checkpoints = [] + for checkpoint in result.all(): + # Construct tinker_path based on checkpoint type + path_kind = "weights" if checkpoint.checkpoint_type == types.CheckpointType.TRAINING else "sampler_weights" + tinker_path = f"tinker://{unique_id}/{path_kind}/{checkpoint.checkpoint_id}" + + checkpoints.append( + Checkpoint( + checkpoint_id=checkpoint.checkpoint_id, + checkpoint_type=checkpoint.checkpoint_type.value, + time=checkpoint.completed_at, + tinker_path=tinker_path, + ) + ) + + return ListCheckpointsResponse(checkpoints=checkpoints) + + +@app.get("/api/v1/models/{unique_id}/checkpoints") +async def list_checkpoints_models( + unique_id: str = fastapi.Path(..., pattern=ID_PATTERN, max_length=ID_MAX_LENGTH), + session: AsyncSession = Depends(get_session), +): + """Just to be compatible with tinker SDK""" + return await list_checkpoints(unique_id=unique_id, session=session) + + +@app.post("/api/v1/weights_info", response_model=WeightsInfoResponse) +async def get_weights_info(request: WeightsInfoRequest, req: Request, session: AsyncSession = Depends(get_session)): + """Get information about weights/checkpoint from a tinker path.""" + path = types.TinkerPath.parse(request.tinker_path) + + if not path or path.kind != "weights": + raise HTTPException( + status_code=400, detail="Invalid tinker path format. 
Expected: tinker://model_id/weights/checkpoint_id" + ) + + model_id = path.primary_id + checkpoint_id = path.secondary_id + + # Get model info (this will raise 404 if model doesn't exist) + model = await get_model(session, model_id) + + # Validate checkpoint exists and is completed + await validate_checkpoint(req, model_id, checkpoint_id, types.CheckpointType.TRAINING, session) + + lora_config = types.LoraConfig.model_validate(model.lora_config) + is_lora = lora_config.rank > 0 + + return WeightsInfoResponse( + base_model=model.base_model, + is_lora=is_lora, + lora_rank=lora_config.rank, + ) + + +@app.get("/") +async def root(): + """Root endpoint with API information.""" + return { + "name": "Tinker API Mock", + "version": "0.0.1", + "endpoints": { + "models": ["/api/v1/create_model", "/api/v1/get_info", "/api/v1/training_runs/{model_id}"], + "training": ["/api/v1/forward_backward", "/api/v1/optim_step"], + "futures": ["/api/v1/retrieve_future"], + "service": ["/api/v1/get_server_capabilities"], + "telemetry": ["/api/v1/telemetry"], + "checkpoints": ["/api/v1/training_runs/{unique_id}/checkpoints"], + "download": [ + "/api/v1/training_runs/{unique_id}/checkpoints/{checkpoint_id}/archive", + "/api/v1/training_runs/{unique_id}/checkpoints/{checkpoint_id}/download", + ], + }, + } + + +if __name__ == "__main__": + import argparse + import uvicorn + + # Parse command-line arguments + parser = argparse.ArgumentParser(description="SkyRL tx tinker API server") + add_model(parser, EngineConfig) + parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind to") + parser.add_argument("--port", type=int, default=8000, help="Port to bind to") + args = parser.parse_args() + + # Create EngineConfig from parsed arguments (only EngineConfig fields) + engine_config = EngineConfig.model_validate({k: v for k, v in vars(args).items() if k in EngineConfig.model_fields}) + + # Store config in app.state so lifespan can access it + app.state.engine_config = engine_config + + uvicorn.run(app, host=args.host, port=args.port) diff --git a/skyrl-train/skyrl_train/tinker/backends/__init__.py b/skyrl-train/skyrl_train/tinker/backends/__init__.py new file mode 100644 index 000000000..cb18d0e91 --- /dev/null +++ b/skyrl-train/skyrl_train/tinker/backends/__init__.py @@ -0,0 +1,5 @@ +"""Tinker engine backends.""" + +from skyrl_train.tinker.backends.backend import AbstractBackend + +__all__ = ["AbstractBackend"] diff --git a/skyrl-train/skyrl_train/tinker/backends/backend.py b/skyrl-train/skyrl_train/tinker/backends/backend.py new file mode 100644 index 000000000..2eadaa143 --- /dev/null +++ b/skyrl-train/skyrl_train/tinker/backends/backend.py @@ -0,0 +1,164 @@ +"""Abstract backend interface for TinkerEngine. + +Backends handle all model state, adapter management, and computation. +The engine handles database operations and scheduling. + +Design: + 1. AbstractBackend (backend.py) + Clean interface defining what backends must implement: + - create_model (manages model metadata, adapter allocation, and optimizer lifecycle) + - forward_backward, forward, optim_step, sample + - load_checkpoint, save_checkpoint, save_sampler_checkpoint + + 2. JaxBackend (jax.py) + - Implements all abstract methods in Jax, fully supporting MultiLoRA for training and sampling + - Uses jax.value_and_grad for gradient computation + - Uses 2D mesh (fsdp, tp) + - Multi-adapter AccumulatedGradients with counts array + - Manages model metadata and adapter_index allocation internally + + 3. 
TinkerEngine (engine.py) + - Instantiates backend based on config + - Delegates computation and model management to self.backend + - Handles all database operations +""" + +from abc import ABC, abstractmethod + +from pydantic import BaseModel + +from skyrl_train.tinker import types + + +class AbstractBackend(ABC): + """Abstract base class for TinkerEngine backends. + + Backends handle computation and model state manipulation. + Database operations are handled by TinkerEngine. + """ + + @abstractmethod + def __init__(self, base_model: str, config: BaseModel): + """Initialize the backend.""" + pass + + @abstractmethod + def create_model(self, model_id: str, lora_config: types.LoraConfig) -> None: + """Create a new model in the backend. + + Creates optimizer and configures LoRA adapter. + + Args: + model_id: The model identifier + lora_config: LoRA configuration with rank and alpha + """ + pass + + @abstractmethod + def forward_backward( + self, + prepared_batch: types.PreparedModelPassBatch, + ) -> dict[str, types.ForwardBackwardOutput | types.ErrorResponse]: + """Run forward and backward pass on a batch. + + Args: + prepared_batch: PreparedModelPassBatch with all data extracted from requests + + Returns: + Dict mapping request_id to result or error + """ + pass + + @abstractmethod + def forward( + self, + prepared_batch: types.PreparedModelPassBatch, + ) -> dict[str, types.ForwardBackwardOutput | types.ErrorResponse]: + """Run forward-only pass on a batch (no gradient computation). + + Args: + prepared_batch: PreparedModelPassBatch with all data extracted from requests + + Returns: + Dict mapping request_id to result or error + """ + pass + + @abstractmethod + def optim_step(self, model_id: str, request_data: types.OptimStepInput) -> types.OptimStepOutput: + """Apply an optimizer step using accumulated gradients. + + Args: + model_id: The model identifier + request_data: The optimizer step input parameters + + Returns: + OptimStepOutput result + """ + pass + + @abstractmethod + def sample( + self, + prepared_batch: types.PreparedSampleBatch, + ) -> dict[str, types.SampleOutput | types.ErrorResponse]: + """Generate samples for a batch of requests. + + Args: + prepared_batch: PreparedSampleBatch with all data extracted from requests + + Returns: + Dict mapping request_id to result or error + """ + pass + + @abstractmethod + def save_checkpoint(self, output_path, model_id: str) -> None: + """Save training checkpoint to disk. + + Args: + output_path: Path to save the checkpoint + model_id: The model identifier + """ + pass + + @abstractmethod + def load_checkpoint(self, checkpoint_path, model_id: str) -> None: + """Load training checkpoint from disk. + + Args: + checkpoint_path: Path to the checkpoint file + model_id: The model identifier + """ + pass + + @abstractmethod + def save_sampler_checkpoint(self, output_path, model_id: str) -> None: + """Save sampler checkpoint to disk as tar.gz. + + Args: + output_path: Path to save the checkpoint tar.gz file + model_id: The model identifier + """ + pass + + @abstractmethod + def has_model(self, model_id: str) -> bool: + """Check if a model is registered with the backend. + + Args: + model_id: The model identifier + + Returns: + True if the model is registered, False otherwise + """ + pass + + @abstractmethod + def delete_model(self, model_id: str) -> None: + """Delete a model and free all associated resources. 
+ + Args: + model_id: The model identifier + """ + pass diff --git a/skyrl-train/skyrl_train/tinker/backends/skyrl_train.py b/skyrl-train/skyrl_train/tinker/backends/skyrl_train.py new file mode 100644 index 000000000..3b32eeaf9 --- /dev/null +++ b/skyrl-train/skyrl_train/tinker/backends/skyrl_train.py @@ -0,0 +1,255 @@ +"""SkyRL-Train backend for TinkerEngine. + +Uses SkyRL-Train infrastructure for supervised training with cross-entropy loss. +Currently supports a single model only. +""" + +print("[DEBUG] skyrl_train.py: Starting imports...", flush=True) + +from typing import Any + +import torch +from pydantic import BaseModel +from transformers import AutoTokenizer + +from skyrl_train.tinker import types +from skyrl_train.tinker.backends.backend import AbstractBackend +from skyrl_train.tx_utils.log import logger + +print("[DEBUG] skyrl_train.py: Basic imports done, importing Ray...", flush=True) + +try: # Optional dependency: keep other backends importable without ray/skyrl-train. + import ray + print("[DEBUG] skyrl_train.py: Ray imported", flush=True) + from ray.util.placement_group import placement_group + print("[DEBUG] skyrl_train.py: placement_group imported", flush=True) + from skyrl_train.training_batch import TrainingInputBatch + print("[DEBUG] skyrl_train.py: TrainingInputBatch imported", flush=True) + # Lazy import to avoid crash from Flash Attention initialization in worker.py + # These will be imported in create_model() when actually needed + PPORayActorGroup = None + WorkerDispatch = None + PolicyWorker = None + print("[DEBUG] skyrl_train.py: Skipping worker imports (will import lazily)", flush=True) + from skyrl_train.utils import get_ray_pg_ready_with_timeout + print("[DEBUG] skyrl_train.py: get_ray_pg_ready_with_timeout imported", flush=True) + from skyrl_train.config.utils import get_default_config + print("[DEBUG] skyrl_train.py: get_default_config imported", flush=True) + from skyrl_train.env_vars import SKYRL_RAY_PG_TIMEOUT_IN_S + print("[DEBUG] skyrl_train.py: SKYRL_RAY_PG_TIMEOUT_IN_S imported", flush=True) + + SKYRL_TRAIN_AVAILABLE = True + print("[DEBUG] skyrl_train.py: All imports successful!", flush=True) +except ImportError: # pragma: no cover - exercised only in non-ray installs + print("[DEBUG] skyrl_train.py: Import failed, setting unavailable", flush=True) + ray = None + placement_group = None + TrainingInputBatch = Any + PPORayActorGroup = Any + WorkerDispatch = Any + PolicyWorker = Any + get_ray_pg_ready_with_timeout = None + get_default_config = None + SKYRL_RAY_PG_TIMEOUT_IN_S = None + SKYRL_TRAIN_AVAILABLE = False + + +class SkyRLTrainBackendConfig(BaseModel, extra="forbid"): + """Configuration for the SkyRL-Train backend.""" + + pass + + +def _build_config(base_model: str, config: SkyRLTrainBackendConfig, lora_config: types.LoraConfig | None = None): + """Build config for SkyRL-Train workers using default config.""" + cfg = get_default_config() + cfg.trainer.policy.model.path = base_model + return cfg + + +class SkyRLTrainBackend(AbstractBackend): + """SkyRL-Train backend for supervised training.""" + + def __init__(self, base_model: str, config: SkyRLTrainBackendConfig): + logger.warning("=" * 80) + logger.warning("SkyRLTrainBackend is currently EXPERIMENTAL!") + logger.warning("=" * 80) + + if not SKYRL_TRAIN_AVAILABLE or ray is None: + raise ImportError( + "SkyRLTrainBackend requires `ray`. Install the appropriate extras (e.g. `--extra skyrl_train`)." 
+ ) + + self.base_model = base_model + self.config = config + self._model_id: str | None = None + self._model_metadata: types.ModelMetadata | None = None + self._actor_group: PPORayActorGroup | None = None + self._dispatch: WorkerDispatch | None = None + self._cfg = None + self._tokenizer = AutoTokenizer.from_pretrained(self.base_model) + + def has_model(self, model_id: str) -> bool: + return self._model_id == model_id + + def create_model(self, model_id: str, lora_config: types.LoraConfig) -> None: + logger.info(f"[DEBUG] create_model called with model_id={model_id}") + + if self._model_id is not None: + raise ValueError(f"Model '{self._model_id}' already exists. Only one model supported.") + + # Import worker classes here to avoid Flash Attention crash at module import time + logger.info(f"[DEBUG] Importing worker classes...") + from skyrl_train.workers.worker import PPORayActorGroup + from skyrl_train.workers.worker_dispatch import WorkerDispatch + from skyrl_train.workers.fsdp.fsdp_worker import PolicyWorker + logger.info(f"[DEBUG] Worker classes imported successfully") + + logger.info(f"[DEBUG] Building config for base_model={self.base_model}") + self._cfg = _build_config(self.base_model, self.config, lora_config) + num_nodes = self._cfg.trainer.placement.policy_num_nodes + num_gpus = self._cfg.trainer.placement.policy_num_gpus_per_node + logger.info(f"[DEBUG] Config built: num_nodes={num_nodes}, num_gpus={num_gpus}") + + logger.info(f"[DEBUG] Creating placement group with {num_nodes * num_gpus} actors...") + pg = placement_group([{"GPU": 1, "CPU": 1}] * num_nodes * num_gpus, strategy="PACK") + logger.info(f"[DEBUG] Placement group created, waiting for it to be ready...") + get_ray_pg_ready_with_timeout(pg, timeout=SKYRL_RAY_PG_TIMEOUT_IN_S) + logger.info(f"[DEBUG] Placement group is ready") + + logger.info(f"[DEBUG] Creating PPORayActorGroup...") + self._actor_group = PPORayActorGroup( + cfg=self._cfg, + num_nodes=num_nodes, + num_gpus_per_node=num_gpus, + ray_actor_type=PolicyWorker, + pg=pg, + num_gpus_per_actor=1, + colocate_all=True, + sequence_parallel_size=self._cfg.trainer.policy.sequence_parallel_size, + record_memory=self._cfg.trainer.policy.record_memory, + ) + logger.info(f"[DEBUG] PPORayActorGroup created, initializing model...") + + ray.get(self._actor_group.async_init_model(self.base_model)) + logger.info(f"[DEBUG] Model initialized, creating WorkerDispatch...") + + self._dispatch = WorkerDispatch(self._cfg, policy_actor_group=self._actor_group) + logger.info(f"[DEBUG] WorkerDispatch created") + + self._model_id = model_id + self._model_metadata = types.ModelMetadata(adapter_index=0, lora_config=lora_config) + logger.info(f"Created model {model_id}") + + def delete_model(self, model_id: str) -> None: + if self._model_id != model_id: + raise ValueError(f"Model {model_id} not found") + raise NotImplementedError("Deleting models not yet implemented") + + def _to_training_batch(self, prepared_batch: types.PreparedModelPassBatch) -> TrainingInputBatch: + """Convert PreparedModelPassBatch to TrainingInputBatch.""" + if not prepared_batch.all_input_ids: + return TrainingInputBatch({}) + + # SkyRL-Train shifts internally, so provide the full sequence length by + # appending the last target token to each already-shifted input. 
+ full_sequences = [ + list(input_ids) + ([targets[-1]] if targets else []) + for input_ids, targets in zip(prepared_batch.all_input_ids, prepared_batch.all_targets) + ] + + max_seq_len = max(len(seq) for seq in full_sequences) + max_response_len = max(len(weights) for weights in prepared_batch.all_token_weights) + + sequences, attention_masks, loss_masks, response_masks = [], [], [], [] + + for seq, weights in zip(full_sequences, prepared_batch.all_token_weights): + pad_len = max_seq_len - len(seq) + sequences.append([self._tokenizer.pad_token_id] * pad_len + list(seq)) + attention_masks.append([0] * pad_len + [1] * len(seq)) + action_pad = max_response_len - len(weights) + loss_masks.append([0.0] * action_pad + [float(w) for w in weights]) + response_masks.append([0] * action_pad + [1] * len(weights)) + + sequences_tensor = torch.tensor(sequences, dtype=torch.long) + attention_mask_tensor = torch.tensor(attention_masks, dtype=torch.long) + loss_mask_tensor = torch.tensor(loss_masks, dtype=torch.float32) + response_mask_tensor = torch.tensor(response_masks, dtype=torch.long) + + batch = TrainingInputBatch( + { + "sequences": sequences_tensor, + "attention_mask": attention_mask_tensor, + "loss_mask": loss_mask_tensor, + "response_mask": response_mask_tensor, + } + ) + batch.metadata = {"response_length": max_response_len} + return batch + + def forward_backward( + self, + prepared_batch: types.PreparedModelPassBatch, + loss_fn: str = "cross_entropy", + ) -> dict[str, types.ForwardBackwardOutput | types.ErrorResponse]: + if not prepared_batch.all_input_ids: + return {} + + batch = self._to_training_batch(prepared_batch) + data = self._dispatch.forward_backward("policy", batch, loss_fn=loss_fn) + + results = {} + for request_id, _, start_idx, end_idx in prepared_batch.request_batch_slices: + loss_fn_outputs = [] + for i in range(start_idx, end_idx): + raw_output = data["loss_fn_outputs"][i] + logprobs = list(raw_output.get("logprobs", [])) + elementwise_loss = list(raw_output.get("elementwise_loss", [])) + loss_fn_outputs.append( + { + "elementwise_loss": { + "data": elementwise_loss, + "dtype": "float32", + "shape": [len(elementwise_loss)], + }, + "logprobs": { + "data": logprobs, + "dtype": "float32", + "shape": [len(logprobs)], + }, + } + ) + results[request_id] = types.ForwardBackwardOutput( + loss_fn_output_type="scalar", + loss_fn_outputs=loss_fn_outputs, + metrics={}, + ) + return results + + def forward( + self, + prepared_batch: types.PreparedModelPassBatch, + ) -> dict[str, types.ForwardBackwardOutput | types.ErrorResponse]: + raise NotImplementedError("Forward-only pass not supported") + + def optim_step(self, model_id: str, request_data: types.OptimStepInput) -> types.OptimStepOutput: + if model_id != self._model_id: + raise ValueError(f"Model {model_id} not found") + grad_norm = self._dispatch.optim_step("policy") + logger.info(f"grad_norm: {grad_norm}") + return types.OptimStepOutput() + + def sample( + self, + prepared_batch: types.PreparedSampleBatch, + ) -> dict[str, types.SampleOutput | types.ErrorResponse]: + raise NotImplementedError("Sampling not supported") + + def save_checkpoint(self, output_path, model_id: str) -> None: + raise NotImplementedError("Saving checkpoints not supported") + + def load_checkpoint(self, checkpoint_path, model_id: str) -> None: + raise NotImplementedError("Loading checkpoints not supported") + + def save_sampler_checkpoint(self, output_path, model_id: str) -> None: + raise NotImplementedError("Sampler checkpoints not supported") diff --git 
a/skyrl-train/skyrl_train/tinker/backends/utils.py b/skyrl-train/skyrl_train/tinker/backends/utils.py new file mode 100644 index 000000000..0dc5a4ad8 --- /dev/null +++ b/skyrl-train/skyrl_train/tinker/backends/utils.py @@ -0,0 +1,52 @@ +"""Shared helper utilities for TinkerEngine backends.""" + +import time +from contextlib import contextmanager + +import numpy as np + +from skyrl_train.tx_utils.log import logger + + +@contextmanager +def log_timing(request: str): + """Context manager to log execution time for a request.""" + start_time = time.perf_counter() + try: + yield + finally: + elapsed = time.perf_counter() - start_time + logger.info(f"(timing) {request} took {elapsed:.3f}s") + + +def pad(xs, pad_to: int, *, fill): + """Pad a list to a specified length with a fill value.""" + return xs + ([fill] * (pad_to - len(xs))) + + +def pad_batch(sequences: list[list], max_length: int, dtype) -> np.ndarray: + """Pad a batch of sequences to max_length. + + Args: + sequences: List of sequences to pad. + max_length: Target length for all sequences. + dtype: NumPy dtype for the output array. + + Returns: + A NumPy array of shape (batch_size, max_length) with the padded sequences. + """ + batch_size = len(sequences) + padded = np.zeros((batch_size, max_length), dtype=dtype) + for i, seq in enumerate(sequences): + assert len(seq) <= max_length, f"Sequence length {len(seq)} exceeds max_length {max_length}" + padded[i, : len(seq)] = seq + return padded + + +def pad_to_fsdp(arr: np.ndarray, fsdp_size: int) -> np.ndarray: + """Pad array's first dimension to be divisible by FSDP size.""" + batch_size = arr.shape[0] + pad_size = (fsdp_size - batch_size % fsdp_size) % fsdp_size + if pad_size == 0: + return arr + return np.pad(arr, [(0, pad_size)] + [(0, 0)] * (arr.ndim - 1)) diff --git a/skyrl-train/skyrl_train/tinker/config.py b/skyrl-train/skyrl_train/tinker/config.py new file mode 100644 index 000000000..e126e5499 --- /dev/null +++ b/skyrl-train/skyrl_train/tinker/config.py @@ -0,0 +1,131 @@ +"""Configuration for the Tinker engine.""" + +import argparse +import json +import os +from pathlib import Path + +from cloudpathlib import AnyPath +from pydantic import BaseModel, Field + + +class EngineConfig(BaseModel): + """Configuration for the Tinker engine.""" + + base_model: str = Field(..., description="Base model name (e.g., Qwen/Qwen3-0.6B)") + backend: str = Field(default="jax", description="Backend to use for training and inference") + backend_config: dict = Field( + default_factory=dict, + description="Backend-specific configuration as JSON string", + json_schema_extra={"argparse_type": json.loads}, + ) + checkpoints_base: AnyPath = Field( + default=AnyPath("/tmp/tx_checkpoints"), + description="Base path where checkpoints will be stored", + ) + database_url: str = Field( + default=f'sqlite:///{Path(__file__).parent / "tinker.db"}', + description="Database URL (e.g., postgresql://user:password@localhost:5432/tinker). If not set, uses TX_DATABASE_URL env var or defaults to SQLite", + json_schema_extra={"argparse_type": str, "env_var": "TX_DATABASE_URL"}, + ) + external_inference_url: str | None = Field( + default=None, + description="URL of the external inference engine. If set, sample requests will be sent to the external engine instead (currently only VLLM is supported).", + json_schema_extra={"argparse_type": str}, + ) + external_inference_api_key: str = Field( + default="EMPTY", + description="API key for an external inference engine. 
If not provided will use vLLM 'EMPTY' key convention", + ) + external_inference_lora_base: Path = Field( + default=Path("/tmp/lora_models"), + description="Directory where LoRA models will be extracted for external inference engines", + ) + session_cleanup_interval_sec: int = Field( + default=60, + description="How often to check for stale sessions (seconds). Set to -1 to disable cleanup.", + ) + # The tinker client sends heartbeats every 10 seconds by default. + # https://github.com/thinking-machines-lab/tinker/blob/2d8e9d5e00f746f39148a5d0cb760dff3f2eed43/src/tinker/lib/internal_client_holder.py#L182 + session_timeout_sec: int = Field( + default=300, + description="Seconds without heartbeat before session is considered stale. Set to -1 to disable cleanup.", + ) + + +def convert_env_var(env_name: str, env_value: str, expected_type: type): + """Convert environment variable to expected type.""" + if expected_type is bool: + if env_value not in ("0", "1"): + raise ValueError( + f"Environment variable '{env_name}' for a boolean flag must be '0' or '1', but got '{env_value}'." + ) + return env_value == "1" + else: + return env_value + + +def add_model(parser: argparse.ArgumentParser, model: type[BaseModel]) -> None: + """Add Pydantic model fields to an ArgumentParser. + + The priority order of how options are handled: 1. Explicitly specified command line options, + 2. environment variables and 3. default values. + + Args: + parser: The ArgumentParser to add arguments to + model: The Pydantic model class + """ + for name, field in model.model_fields.items(): + arg_name = name.replace("_", "-") + kwargs = { + "help": field.description, + } + + # Check for default value, with env_var support + default_value = field.default + if field.json_schema_extra and "env_var" in field.json_schema_extra: + env_name = field.json_schema_extra["env_var"] + if env_value := os.environ.get(env_name): + default_value = convert_env_var(env_name, env_value, field.annotation) + + if field.annotation is bool: + # For boolean flags, use BooleanOptionalAction to support both --{arg_name} and --no-{arg_name} + kwargs = {**kwargs, "action": argparse.BooleanOptionalAction, "dest": name, "default": default_value} + else: + # Check if explicit argparse_type is specified in field metadata + argparse_type = field.json_schema_extra.get("argparse_type") if field.json_schema_extra else None + if argparse_type is not None: + kwargs["type"] = argparse_type + elif field.annotation is not None: + kwargs["type"] = field.annotation + + if field.is_required(): + # Mark as required in argparse if no default is provided + kwargs["required"] = True + else: + # For optional fields, provide the default value to argparse + kwargs["default"] = default_value + + parser.add_argument(f"--{arg_name}", **kwargs) + + +def config_to_argv(cfg: BaseModel) -> list[str]: + """This should 'unparse' a config parsed by an ArgumentParser constructed by add_model.""" + argv = [] + for field_name, value in cfg.model_dump().items(): + field = cfg.model_fields[field_name] + arg_name = field_name.replace("_", "-") + + if field.annotation is bool: + argv.append(f"--{arg_name}" if value else f"--no-{arg_name}") + elif field.annotation is dict: + # Serialize dict to JSON string + if value: + argv.append(f"--{arg_name}") + argv.append(json.dumps(value)) + else: + # Skip None values - let them use defaults or environment variables + if value is not None: + argv.append(f"--{arg_name}") + argv.append(str(value)) + return argv diff --git 
a/skyrl-train/skyrl_train/tinker/db_models.py b/skyrl-train/skyrl_train/tinker/db_models.py new file mode 100644 index 000000000..161363227 --- /dev/null +++ b/skyrl-train/skyrl_train/tinker/db_models.py @@ -0,0 +1,116 @@ +"""Database models for the Tinker API.""" + +from datetime import datetime, timezone +from enum import Enum + +from sqlmodel import SQLModel, Field, JSON +from sqlalchemy import DateTime +from sqlalchemy.engine import url as sqlalchemy_url + +from skyrl_train.tinker import types + + +def get_async_database_url(db_url: str) -> str: + """Get the async database URL. + + Args: + db_url: Optional database URL to use. + + Returns: + Async database URL string for SQLAlchemy. + + Raises: + ValueError: If the database scheme is not supported. + """ + parsed_url = sqlalchemy_url.make_url(db_url) + + match parsed_url.get_backend_name(): + case "sqlite": + async_url = parsed_url.set(drivername="sqlite+aiosqlite") + case "postgresql": + async_url = parsed_url.set(drivername="postgresql+asyncpg") + case _ if "+" in parsed_url.drivername: + # Already has an async driver specified, keep it + async_url = parsed_url + case backend_name: + raise ValueError(f"Unsupported database scheme: {backend_name}") + + return async_url.render_as_string(hide_password=False) + + +class RequestStatus(str, Enum): + """Status of a request.""" + + PENDING = "pending" + COMPLETED = "completed" + FAILED = "failed" + + +class CheckpointStatus(str, Enum): + """Status of a checkpoint.""" + + PENDING = "pending" + COMPLETED = "completed" + FAILED = "failed" + + +# SQLModel table definitions +class ModelDB(SQLModel, table=True): + __tablename__ = "models" + + model_id: str = Field(primary_key=True) + base_model: str + lora_config: dict[str, object] = Field(sa_type=JSON) + status: str = Field(index=True) + request_id: int + session_id: str = Field(foreign_key="sessions.session_id", index=True) + created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc), sa_type=DateTime(timezone=True)) + + +class FutureDB(SQLModel, table=True): + __tablename__ = "futures" + + request_id: int | None = Field(default=None, primary_key=True, sa_column_kwargs={"autoincrement": True}) + request_type: types.RequestType + model_id: str | None = Field(default=None, index=True) + request_data: dict = Field(sa_type=JSON) # this is of type types.{request_type}Input + result_data: dict | None = Field(default=None, sa_type=JSON) # this is of type types.{request_type}Output + status: RequestStatus = Field(default=RequestStatus.PENDING, index=True) + created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc), sa_type=DateTime(timezone=True)) + completed_at: datetime | None = Field(default=None, sa_type=DateTime(timezone=True)) + + +class CheckpointDB(SQLModel, table=True): + __tablename__ = "checkpoints" + + model_id: str = Field(foreign_key="models.model_id", primary_key=True) + checkpoint_id: str = Field(primary_key=True) + checkpoint_type: types.CheckpointType = Field(primary_key=True) + status: CheckpointStatus + created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc), sa_type=DateTime(timezone=True)) + completed_at: datetime | None = Field(default=None, sa_type=DateTime(timezone=True)) + error_message: str | None = None + + +class SessionDB(SQLModel, table=True): + __tablename__ = "sessions" + + session_id: str = Field(primary_key=True) + tags: list[str] = Field(default_factory=list, sa_type=JSON) + user_metadata: dict = Field(default_factory=dict, sa_type=JSON) + sdk_version: 
str + status: str = Field(default="active", index=True) + created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc), sa_type=DateTime(timezone=True)) + last_heartbeat_at: datetime | None = Field(default=None, sa_type=DateTime(timezone=True), index=True) + heartbeat_count: int = 0 + + +class SamplingSessionDB(SQLModel, table=True): + __tablename__ = "sampling_sessions" + + sampling_session_id: str = Field(primary_key=True) + session_id: str = Field(foreign_key="sessions.session_id", index=True) + sampling_session_seq_id: int + base_model: str | None = None + model_path: str | None = None + created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc), sa_type=DateTime(timezone=True)) diff --git a/skyrl-train/skyrl_train/tinker/engine.py b/skyrl-train/skyrl_train/tinker/engine.py new file mode 100644 index 000000000..6c4548728 --- /dev/null +++ b/skyrl-train/skyrl_train/tinker/engine.py @@ -0,0 +1,691 @@ +"""Background engine for processing training requests.""" + +print("[DEBUG] engine.py: Starting imports...", flush=True) + +import argparse +import time +from contextlib import contextmanager +from datetime import datetime, timedelta, timezone +from pathlib import Path +from typing import Callable + +print("[DEBUG] engine.py: Basic imports done", flush=True) + +from cloudpathlib import AnyPath +from pydantic import BaseModel +from sqlmodel import create_engine, Session, select, update, func + +print("[DEBUG] engine.py: Third-party imports done", flush=True) + +from skyrl_train.tinker.db_models import FutureDB, RequestStatus, CheckpointDB, CheckpointStatus, ModelDB, SessionDB +print("[DEBUG] engine.py: db_models imported", flush=True) + +from skyrl_train.tinker import types +print("[DEBUG] engine.py: types imported", flush=True) + +from skyrl_train.tinker.config import EngineConfig, add_model +print("[DEBUG] engine.py: config imported", flush=True) + +# Lazy imports for backends - only import when needed to avoid crashes +# (e.g., JAX import can crash with SIGBUS if XLA has issues) +print("[DEBUG] engine.py: Skipping backend imports (will import lazily)", flush=True) + +from skyrl_train.tinker.backends.utils import log_timing +from skyrl_train.tinker.loss_fns import LOSS_TYPES +from skyrl_train.tx_utils.log import logger + +print("[DEBUG] engine.py: All imports complete", flush=True) + + +def prepare_sample_batch( + requests: dict[str, tuple[str, types.SampleInput]], + checkpoints_base: AnyPath | None = None, +) -> types.PreparedSampleBatch: + """Prepare batch data for sample operations. + + Extracts prompts and sampling params from requests into lists + that the backend will convert to arrays. 
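+    Requests with num_samples > 1 are expanded into that many identical prompt
+    entries; e.g. (hypothetical values) a single request with num_samples=4
+    contributes four rows and one request_batch_slices entry covering all four.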
+ + Args: + requests: Dict mapping request_id to (model_id, request_data) tuples (pre-validated) + checkpoints_base: Base path for checkpoints (optional, needed for LoRA sampling) + + Returns: + PreparedSampleBatch with all data extracted from requests + """ + all_prompts = [] + all_sampling_params = [] + all_model_ids = [] + all_checkpoint_ids = [] + all_checkpoint_paths = [] + request_batch_slices = [] + + needs_prompt_logprobs = any(request_data.prompt_logprobs for (_, request_data) in requests.values()) + + for request_id, (model_id, request_data) in requests.items(): + request_start = len(all_prompts) + + # Expand requests for num_samples + prompt_tokens = [token for chunk in request_data.prompt.chunks for token in chunk.tokens] + checkpoint_path = "" + if model_id and request_data.checkpoint_id and checkpoints_base: + checkpoint_path = str( + checkpoints_base / model_id / "sampler_weights" / f"{request_data.checkpoint_id}.tar.gz" + ) + for _ in range(request_data.num_samples): + all_prompts.append(prompt_tokens) + all_sampling_params.append(request_data.sampling_params) + all_model_ids.append(model_id) + all_checkpoint_ids.append(request_data.checkpoint_id) + all_checkpoint_paths.append(checkpoint_path) + + request_batch_slices.append( + (request_id, model_id, request_start, len(all_prompts), request_data.prompt_logprobs) + ) + + return types.PreparedSampleBatch( + all_prompts=all_prompts, + all_sampling_params=all_sampling_params, + all_model_ids=all_model_ids, + all_checkpoint_ids=all_checkpoint_ids, + all_checkpoint_paths=all_checkpoint_paths, + needs_prompt_logprobs=needs_prompt_logprobs, + request_batch_slices=request_batch_slices, + ) + + +def prepare_model_pass_batch( + requests: dict[str, tuple[str, types.ForwardBackwardInput]], +) -> types.PreparedModelPassBatch: + """Prepare batch data for forward/forward_backward operations. + + Extracts tokens, targets, and metadata from requests into lists + that the backend will convert to arrays. 
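+    Each datum becomes one row across the parallel per-example lists, and
+    request_batch_slices records the half-open [start, end) row range of each
+    request so per-example outputs can be routed back to the originating request.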
+ + Args: + requests: Dict mapping request_id to (model_id, request_data) tuples (pre-validated) + + Returns: + PreparedModelPassBatch with all data extracted from requests + """ + all_input_ids = [] + all_targets = [] + all_token_weights = [] + all_model_ids = [] + all_sampling_logprobs = [] + all_advantages = [] + all_loss_fn_types = [] + request_batch_slices = [] + + for request_id, (model_id, request_data) in requests.items(): + loss_fn_type = LOSS_TYPES[request_data.loss_fn] + + request_start = len(all_input_ids) + for item in request_data.data: + tokens = [t for chunk in item.model_input.chunks for t in chunk.tokens] + all_input_ids.append(tokens) + loss_fn_inputs = item.loss_fn_inputs + all_targets.append(loss_fn_inputs.target_tokens.data) + all_token_weights.append(loss_fn_inputs.weights.data) + all_sampling_logprobs.append(loss_fn_inputs.logprobs.data) + all_advantages.append(loss_fn_inputs.advantages.data) + all_model_ids.append(model_id) + all_loss_fn_types.append(loss_fn_type) + + request_batch_slices.append((request_id, model_id, request_start, len(all_input_ids))) + + return types.PreparedModelPassBatch( + all_input_ids=all_input_ids, + all_targets=all_targets, + all_token_weights=all_token_weights, + all_sampling_logprobs=all_sampling_logprobs, + all_advantages=all_advantages, + all_model_ids=all_model_ids, + all_loss_fn_types=all_loss_fn_types, + request_batch_slices=request_batch_slices, + ) + + +def _get_backend_classes(backend_name: str): + """Lazy import backend classes to avoid crashes from unused backends.""" + if backend_name == "skyrl_train": + from skyrl_train.tinker.backends.skyrl_train import SkyRLTrainBackend, SkyRLTrainBackendConfig + return SkyRLTrainBackend, SkyRLTrainBackendConfig + else: + raise ValueError(f"Unknown backend: {backend_name}. Only 'skyrl_train' is supported in this installation.") + + +class TinkerEngine: + """Background engine for processing training requests. + + The engine handles: + - Database operations (futures, checkpoints) + - Request finding/scheduling + - File I/O (download/upload checkpoints) + - Validating requests against loaded models + + Computation and model management are delegated to the backend. + """ + + def _filter_valid_requests( + self, + requests: dict[str, tuple[str, BaseModel]], + ) -> tuple[dict[str, types.ErrorResponse], dict[str, tuple[str, BaseModel]]]: + """Filter out requests with invalid model_ids and return error results for them. 
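+        A request is rejected here if it references a model_id that the backend has
+        not loaded, or, for base-model sampling (empty model_id), if it names a base
+        model other than the configured one or a non-empty checkpoint_id.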
+ + Args: + requests: Dict mapping request_id to (model_id, request_data) tuples + + Returns: + Tuple of (error_results, valid_requests) + """ + results = {} + valid_requests = {} + + for request_id, (model_id, request_data) in requests.items(): + error = None + if model_id and not self.backend.has_model(model_id): + error = f"Model {model_id} not loaded" + elif not model_id and isinstance(request_data, types.SampleInput): + if request_data.base_model != self.config.base_model: + error = f"Engine is configured for '{self.config.base_model}' but request specified '{request_data.base_model}'" + elif request_data.checkpoint_id: + error = "checkpoint_id must be empty for base model sampling" + + if error: + results[request_id] = types.ErrorResponse(error=error, status="failed") + else: + valid_requests[request_id] = (model_id, request_data) + + return results, valid_requests + + def __init__( + self, + config: EngineConfig, + ): + """Initialize the engine with a database connection and base model.""" + self.config = config + self.db_engine = create_engine(config.database_url, echo=False) + + # Initialize the backend (handles model state, computation, and adapter management) + # Use lazy import to avoid crashes from unused backends (e.g., JAX import can crash with SIGBUS) + logger.info(f"[DEBUG] Loading backend: {config.backend}") + backend_class, backend_config_class = _get_backend_classes(config.backend) + logger.info(f"[DEBUG] Backend classes loaded: {backend_class.__name__}, {backend_config_class.__name__}") + backend_config = backend_config_class(**config.backend_config) + + # Initialize Ray if using SkyRL backend + if config.backend == "skyrl_train": + logger.info("[DEBUG] Initializing Ray for SkyRL backend...") + try: + import ray + # Ray.init with ignore_reinit_error=True will silently succeed if already initialized + ray.init(ignore_reinit_error=True) + logger.info("[DEBUG] Ray initialized successfully") + except Exception as e: + logger.error(f"[DEBUG] Failed to initialize Ray: {e}") + raise + + self.backend = backend_class(config.base_model, backend_config) + + # Track last cleanup time for periodic stale session cleanup + self._last_cleanup_time: float = time.time() + + logger.info(f"Initialized TinkerEngine with backend={type(self.backend).__name__}") + + @property + def metrics(self) -> types.EngineMetrics: + """Pass-through to backend metrics for backwards compatibility.""" + return self.backend.metrics + + @contextmanager + def _checkpoint_status_context(self, model_id: str, checkpoint_id: str, checkpoint_type: types.CheckpointType): + """Context manager to handle checkpoint DB status updates. + + Fetches the checkpoint entry, yields it, and updates its status to COMPLETED + or FAILED based on whether an exception occurred. 
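+        Intended to wrap the actual backend save call, as process_save_weights and
+        process_save_weights_for_sampler below do.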
+ """ + with Session(self.db_engine) as session: + checkpoint_db = session.get(CheckpointDB, (model_id, checkpoint_id, checkpoint_type)) + if checkpoint_db is None: + raise ValueError( + f"Checkpoint entry not found for model '{model_id}', checkpoint '{checkpoint_id}', type '{checkpoint_type}'" + ) + + try: + yield checkpoint_db + checkpoint_db.status = CheckpointStatus.COMPLETED + except Exception as e: + logger.exception(f"Error saving checkpoint for model {model_id}, checkpoint {checkpoint_id}: {e}") + checkpoint_db.status = CheckpointStatus.FAILED + checkpoint_db.error_message = str(e) + raise + finally: + checkpoint_db.completed_at = datetime.now(timezone.utc) + session.add(checkpoint_db) + session.commit() + + def find_batchable_model_passes( + self, session: Session, request_type: types.RequestType + ) -> dict[str, tuple[str, types.ForwardBackwardInput]]: + """Find all requests of the given type that come before any destructive update for their model. + + Uses look-ahead scheduling: for each model, only returns operations + that have no optim_step or load_weights blocking them in the queue. + + Args: + session: Database session + request_type: The type of request to find (e.g., FORWARD or FORWARD_BACKWARD) + + Returns: + Dict mapping request_id to (model_id, request_data) tuples + """ + # Find the earliest pending optim_step or load_weights per model (these act as barriers) + barriers_query = ( + select(FutureDB.model_id, func.min(FutureDB.request_id).label("barrier_id")) + .where( + (FutureDB.request_type == types.RequestType.OPTIM_STEP) + | (FutureDB.request_type == types.RequestType.LOAD_WEIGHTS) + ) + .where(FutureDB.status == RequestStatus.PENDING) + .group_by(FutureDB.model_id) + ) + barriers = dict(session.exec(barriers_query).all()) + + # Get all pending operations of the requested type ordered by request_id + query = ( + select(FutureDB) + .where(FutureDB.request_type == request_type) + .where(FutureDB.status == RequestStatus.PENDING) + .order_by(FutureDB.request_id) + ) + ops = session.exec(query).all() + + # Filter: only include ops that come before their model's barrier + batchable = [op for op in ops if op.model_id not in barriers or op.request_id < barriers[op.model_id]] + + return { + str(f.request_id): (f.model_id, types.ForwardBackwardInput.model_validate(f.request_data)) + for f in batchable + } + + def find_batchable_sample(self, session: Session) -> dict[str, tuple[str, types.SampleInput]]: + """Find all sample ops that can be safely batched together. + + Returns sample operations ensuring that each model_id has only one checkpoint_id + to avoid loading different checkpoints for the same model in a single batch. + + If sample_max_num_sequences is configured, limits to that many requests so we don't + produce partial batches in process_sample_batch. If num_samples > 1 for some requests, + this may not be perfect, but it's good until we implement continuous batching. 
+ + Args: + session: Database session + + Returns: + Dict mapping request_id to (model_id, request_data) tuples + """ + sample_query = ( + select(FutureDB) + .where(FutureDB.request_type == types.RequestType.SAMPLE) + .where(FutureDB.status == RequestStatus.PENDING) + .order_by(FutureDB.request_id) + ) + sample_ops = session.exec(sample_query).all() + + batchable = [] + model_checkpoints = {} # Map from model_id to checkpoint_id of first request to that model + for op in sample_ops: + checkpoint_id = op.request_data["checkpoint_id"] + # Base model requests (empty checkpoint_id) are always compatible, otherwise only + # take only requests with one checkpoint_id for a given model_id + if checkpoint_id == "" or model_checkpoints.setdefault(op.model_id, checkpoint_id) == checkpoint_id: + batchable.append(op) + + # TODO: This leaks the abstraction by accessing backend-specific config. + # We should find a better way to handle this going forward. + # Note: JaxBackend-specific optimization removed when running without JAX + if hasattr(self.backend, 'config') and hasattr(self.backend.config, 'sample_max_num_sequences'): + if self.backend.config.sample_max_num_sequences > 0: + batchable = batchable[: self.backend.config.sample_max_num_sequences] + + return {str(f.request_id): (f.model_id, types.SampleInput.model_validate(f.request_data)) for f in batchable} + + def find_single_requests(self, session: Session) -> dict[str, tuple[str, types.RequestType, dict]]: + """Find all requests that need to be processed individually (not batchable). + + Args: + session: Database session + + Returns: + Dict mapping request_id to (model_id, request_type, request_data) tuples + """ + statement = ( + select(FutureDB) + .where(FutureDB.status == RequestStatus.PENDING) + .where(FutureDB.request_type != types.RequestType.FORWARD_BACKWARD) + .where(FutureDB.request_type != types.RequestType.FORWARD) + .where(FutureDB.request_type != types.RequestType.SAMPLE) + .where(FutureDB.request_type != types.RequestType.EXTERNAL) + .order_by(FutureDB.request_id) + ) + other_futures = session.exec(statement).all() + + return {str(f.request_id): (f.model_id, f.request_type, f.request_data) for f in other_futures} + + def process_create_model(self, model_id: str, request_data: types.CreateModelInput) -> types.CreateModelOutput: + """Create and initialize a model.""" + # Create model in backend (allocates adapter_index, creates optimizer, and configures adapter) + self.backend.create_model(model_id, request_data.lora_config) + + logger.info(f"Created LoRA model {model_id}") + + return types.CreateModelOutput( + model_id=model_id, + base_model=self.config.base_model, + lora_config=request_data.lora_config, + ) + + def process_unload_model(self, model_id: str, request_data: types.UnloadModelInput) -> types.UnloadModelOutput: + """Unload a model and free all resources.""" + if not self.backend.has_model(model_id): + logger.warning(f"Ignoring unload request for model {model_id} that is not loaded.") + else: + self.backend.delete_model(model_id) + + # Update model status in DB + with Session(self.db_engine) as session: + _ = session.exec(update(ModelDB).where(ModelDB.model_id == model_id).values(status="unloaded")) + session.commit() + + logger.info(f"Unloaded model {model_id}") + + return types.UnloadModelOutput(model_id=model_id, status="unloaded") + + def cleanup_stale_sessions(self) -> int: + """Cleanup sessions with no recent heartbeat and unload their models. 
+ + Returns: + Number of models unloaded + """ + cutoff = datetime.now(timezone.utc) - timedelta(seconds=self.config.session_timeout_sec) + unloaded_count = 0 + + with Session(self.db_engine) as session: + # Find stale sessions (active sessions with heartbeat older than cutoff) + stale_sessions = session.exec( + select(SessionDB).where( + SessionDB.status == "active", + SessionDB.last_heartbeat_at < cutoff, + ) + ).all() + + if not stale_sessions: + return 0 + + stale_session_ids = {s.session_id for s in stale_sessions} + + # Find all models for all stale sessions in one query + models_to_process = session.exec( + select(ModelDB).where( + ModelDB.session_id.in_(stale_session_ids), + ModelDB.status != "unloaded", + ) + ).all() + + sessions_with_failed_unloads = set() + for model in models_to_process: + if self.backend.has_model(model.model_id): + try: + self.backend.delete_model(model.model_id) + model.status = "unloaded" + unloaded_count += 1 + logger.info(f"Auto-unloaded stale model {model.model_id} from session {model.session_id}") + except Exception as e: + logger.error(f"Failed to auto-unload model {model.model_id}: {e}") + sessions_with_failed_unloads.add(model.session_id) + else: + # Model not in backend but status not unloaded - fix DB state + model.status = "unloaded" + + for sess in stale_sessions: + if sess.session_id not in sessions_with_failed_unloads: + sess.status = "expired" + logger.info(f"Expired stale session {sess.session_id} (last heartbeat: {sess.last_heartbeat_at})") + + session.commit() + + return unloaded_count + + def process_optim_step(self, model_id: str, request_data: types.OptimStepInput) -> types.OptimStepOutput: + """Process an optim_step request and apply accumulated gradients.""" + if not self.backend.has_model(model_id): + raise ValueError(f"Model {model_id} not loaded") + + return self.backend.optim_step(model_id, request_data) + + def process_forward_backward(self, requests: dict[str, tuple[str, types.ForwardBackwardInput]]) -> dict: + """Run forward and backward pass on a batch of requests.""" + prepared = prepare_model_pass_batch(requests) + return self.backend.forward_backward(prepared) + + def process_forward(self, requests: dict[str, tuple[str, types.ForwardBackwardInput]]) -> dict: + """Run forward-only pass on a batch of requests.""" + prepared = prepare_model_pass_batch(requests) + return self.backend.forward(prepared) + + def process_sample(self, requests: dict[str, tuple[str, types.SampleInput]]) -> dict: + """Generate samples for a batch of requests.""" + prepared = prepare_sample_batch(requests, self.config.checkpoints_base) + return self.backend.sample(prepared) + + def process_load_weights(self, model_id: str, request_data: types.LoadWeightsInput) -> types.LoadWeightsOutput: + """Loads a clean, trimmed training checkpoint.""" + if not self.backend.has_model(model_id): + raise ValueError("Model not loaded. Create the model before loading a checkpoint.") + + checkpoint_path = ( + self.config.checkpoints_base / request_data.source_model_id / f"{request_data.checkpoint_id}.tar.gz" + ) + + self.backend.load_checkpoint(checkpoint_path, model_id) + + return types.LoadWeightsOutput(type="load_weights") + + def process_save_weights(self, model_id: str, request_data: types.SaveWeightsInput) -> types.SaveWeightsOutput: + """ + Saves a clean training checkpoint by converting the trimmed NNX graph + to a pure dictionary before serialization, following official Flax docs. 
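+        (For non-JAX backends such as SkyRL-Train, the serialization itself is simply
+        delegated to self.backend.save_checkpoint.)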
+ """ + if not self.backend.has_model(model_id): + raise ValueError(f"Model {model_id} not loaded") + + checkpoint_id = request_data.path + output_path = self.config.checkpoints_base / model_id / f"{checkpoint_id}.tar.gz" + + with self._checkpoint_status_context(model_id, checkpoint_id, types.CheckpointType.TRAINING): + self.backend.save_checkpoint(output_path, model_id) + logger.info(f"Saved trimmed training checkpoint for model {model_id} to {output_path}") + + return types.SaveWeightsOutput( + path=f"tinker://{model_id}/weights/{checkpoint_id}", + type="save_weights", + ) + + def process_save_weights_for_sampler( + self, model_id: str, request_data: types.SaveWeightsForSamplerInput + ) -> types.SaveWeightsForSamplerOutput: + """Process a save_weights_for_sampler request and save model weights.""" + if not self.backend.has_model(model_id): + raise ValueError(f"Model {model_id} not loaded") + + # Make sure the user cannot store checkpoints in places like ../../ + checkpoint_id = Path(request_data.path).name + output_path = self.config.checkpoints_base / model_id / "sampler_weights" / f"{checkpoint_id}.tar.gz" + + with self._checkpoint_status_context(model_id, checkpoint_id, types.CheckpointType.SAMPLER): + self.backend.save_sampler_checkpoint(output_path, model_id) + logger.info(f"Saved LoRA adapter weights for model {model_id} to {output_path}") + + # Return path=None when using sampling_session_seq_id and seq_id (SDK expects this) + if request_data.sampling_session_seq_id is not None and request_data.seq_id is not None: + output_path_str = None + else: + output_path_str = f"tinker://{model_id}/{checkpoint_id}" + + return types.SaveWeightsForSamplerOutput( + path=output_path_str, + type="save_weights_for_sampler", + sampling_session_id=request_data.sampling_session_id, + ) + + def _complete_futures(self, results: dict[str, BaseModel]): + """Helper method to complete multiple futures in the database. 
+ + Args: + results: Dict mapping request_id to result (Pydantic BaseModel) + """ + completed_at = datetime.now(timezone.utc) + params = [ + { + "request_id": int(request_id), + "result_data": result.model_dump(), + "status": RequestStatus.FAILED if isinstance(result, types.ErrorResponse) else RequestStatus.COMPLETED, + "completed_at": completed_at, + } + for request_id, result in results.items() + ] + + with Session(self.db_engine) as session: + session.execute(update(FutureDB), params) + session.commit() + + def process_single_request(self, request_type: types.RequestType, model_id: str, request_data: dict) -> BaseModel: + match request_type: + case types.RequestType.CREATE_MODEL: + return self.process_create_model(model_id, types.CreateModelInput.model_validate(request_data)) + case types.RequestType.OPTIM_STEP: + return self.process_optim_step(model_id, types.OptimStepInput.model_validate(request_data)) + case types.RequestType.SAVE_WEIGHTS_FOR_SAMPLER: + return self.process_save_weights_for_sampler( + model_id, types.SaveWeightsForSamplerInput.model_validate(request_data) + ) + case types.RequestType.SAVE_WEIGHTS: + return self.process_save_weights(model_id, types.SaveWeightsInput.model_validate(request_data)) + case types.RequestType.LOAD_WEIGHTS: + return self.process_load_weights(model_id, types.LoadWeightsInput.model_validate(request_data)) + case types.RequestType.UNLOAD_MODEL: + return self.process_unload_model(model_id, types.UnloadModelInput.model_validate(request_data)) + case _: + raise ValueError(f"Unknown request type: {request_type}") + + def process_single_requests(self, requests: dict[str, tuple[str, types.RequestType, dict]]): + """Process a collection of single (non-batchable) requests. + + Args: + requests: Dict mapping request_id to (model_id, request_type, request_data) tuples + """ + if not requests: + return + results = {} + for request_id, (model_id, request_type, request_data) in requests.items(): + with log_timing(f"process_single_request({request_type.value})"): + try: + result = self.process_single_request(request_type, model_id, request_data) + except Exception as e: + logger.exception(f"Error processing request {request_id}: {e}") + result = types.ErrorResponse(error=str(e), status="failed") + results[request_id] = result + self._complete_futures(results) + + def process_batch_requests( + self, + requests: dict[str, tuple[str, BaseModel]], + processor: Callable[[dict[str, tuple[str, BaseModel]]], dict[str, BaseModel]], + name: str, + ): + """Process a batch of requests with error handling and future completion. 
+ + Args: + requests: Dict mapping request_id to (model_id, request_data) tuples + processor: Function that processes requests and returns results dict + name: Name for logging + """ + if not requests: + return + with log_timing(f"process_batch_requests({name}, n={len(requests)})"): + try: + error_results, valid_requests = self._filter_valid_requests(requests) + if valid_requests: + results = processor(valid_requests) + results.update(error_results) + else: + results = error_results + except Exception as e: + logger.exception(f"Error processing batch: {e}") + results = {request_id: types.ErrorResponse(error=str(e), status="failed") for request_id in requests} + self._complete_futures(results) + + def process_pending_requests(self): + """Main loop to process pending requests.""" + while True: + # Query for pending requests and extract data within session context + with Session(self.db_engine) as session: + # Use look-ahead scheduling to find batchable forward_backward and forward model passes + forward_backward_requests = self.find_batchable_model_passes( + session, types.RequestType.FORWARD_BACKWARD + ) + forward_requests = self.find_batchable_model_passes(session, types.RequestType.FORWARD) + # Find pending sample requests that can be batched + sample_requests = self.find_batchable_sample(session) + # Get other pending requests (non forward_backward and non sampling) + other_requests = self.find_single_requests(session) + + # Process batches outside of session context + self.process_batch_requests(forward_backward_requests, self.process_forward_backward, "forward_backward") + self.process_batch_requests(forward_requests, self.process_forward, "forward") + self.process_batch_requests(sample_requests, self.process_sample, "sample") + + # Process other request types individually (in the future we can also batch independent optim_steps) + self.process_single_requests(other_requests) + + # Periodically cleanup stale sessions (disabled if either config is negative) + cleanup_enabled = self.config.session_cleanup_interval_sec >= 0 and self.config.session_timeout_sec >= 0 + if cleanup_enabled and time.time() - self._last_cleanup_time > self.config.session_cleanup_interval_sec: + _ = self.cleanup_stale_sessions() + self._last_cleanup_time = time.time() + + # Poll every 100ms + time.sleep(0.1) + + def run(self): + """Entry point to start the engine.""" + logger.info("Starting background engine...") + self.process_pending_requests() + + +def main(): + """Entry point for the background engine.""" + logger.info("[DEBUG] Engine main() called") + + # Create argument parser and add Pydantic model fields + parser = argparse.ArgumentParser(description="SkyRL tx tinker engine for processing requests") + add_model(parser, EngineConfig) + + # Parse command-line arguments + args = parser.parse_args() + logger.info(f"[DEBUG] Parsed args: backend={args.backend}") + + # Create EngineConfig from parsed arguments + config = EngineConfig.model_validate(vars(args)) + logger.info(f"[DEBUG] Config validated: backend={config.backend}") + + # Initialize and run the engine + logger.info("[DEBUG] Creating TinkerEngine...") + engine = TinkerEngine(config) + logger.info("[DEBUG] TinkerEngine created, starting run()...") + engine.run() + + +if __name__ == "__main__": + main() diff --git a/skyrl-train/skyrl_train/tinker/extra/__init__.py b/skyrl-train/skyrl_train/tinker/extra/__init__.py new file mode 100644 index 000000000..d96313705 --- /dev/null +++ b/skyrl-train/skyrl_train/tinker/extra/__init__.py @@ -0,0 +1,4 @@ +from 
skyrl_train.tinker.extra.external_inference import ExternalInferenceClient +from skyrl_train.tinker.extra.skyrl_inference import SkyRLInferenceClient, attach_skyrl_inference + +__all__ = ["ExternalInferenceClient", "SkyRLInferenceClient", "attach_skyrl_inference"] diff --git a/skyrl-train/skyrl_train/tinker/extra/external_inference.py b/skyrl-train/skyrl_train/tinker/extra/external_inference.py new file mode 100644 index 000000000..f50eb2aa3 --- /dev/null +++ b/skyrl-train/skyrl_train/tinker/extra/external_inference.py @@ -0,0 +1,102 @@ +import httpx +from datetime import datetime, timezone +from sqlmodel.ext.asyncio.session import AsyncSession + +from skyrl_train.tinker import types +from skyrl_train.tinker.config import EngineConfig +from skyrl_train.tinker.db_models import FutureDB, RequestStatus +from skyrl_train.tx_utils.log import logger +from skyrl_train.tx_utils.storage import download_and_unpack + + +class ExternalInferenceClient: + """Client for calling external inference engines (e.g., vLLM).""" + + def __init__(self, engine_config: EngineConfig, db_engine): + self.base_url = f"{engine_config.external_inference_url}/v1" + self.api_key = engine_config.external_inference_api_key + self.checkpoints_base = engine_config.checkpoints_base + self.lora_base_dir = engine_config.external_inference_lora_base + self.db_engine = db_engine + + async def call_and_store_result( + self, + request_id: int, + sample_req, + model_id: str, + checkpoint_id: str, + ): + """Background task to call external engine and store result in database.""" + try: + async with httpx.AsyncClient( + base_url=self.base_url, + headers={"Authorization": f"Bearer {self.api_key}"}, + timeout=httpx.Timeout(300.0, connect=10.0), # 5 minutes for inference, 10s for connect + ) as http_client: + result = await self._forward_to_engine(sample_req, model_id, checkpoint_id, http_client) + result_data = result.model_dump() + status = RequestStatus.COMPLETED + except Exception as e: + logger.exception("External engine error") + result_data = {"error": str(e), "status": "failed"} + status = RequestStatus.FAILED + + async with AsyncSession(self.db_engine) as session: + future = await session.get(FutureDB, request_id) + future.result_data = result_data + future.status = status + future.completed_at = datetime.now(timezone.utc) + await session.commit() + + async def _forward_to_engine( + self, request, model_id: str, checkpoint_id: str, http_client: httpx.AsyncClient + ) -> types.SampleOutput: + """Forward request to vLLM with dynamic LoRA loading. + + Extracts the checkpoint to the configured external_inference_lora_base and references it by a model name + that vLLM can dynamically load via the lora_filesystem_resolver plugin. + """ + prompt_tokens = [token for chunk in request.prompt.chunks for token in chunk.tokens] + checkpoint_path = self.checkpoints_base / model_id / "sampler_weights" / f"{checkpoint_id}.tar.gz" + model_name = f"{model_id}_{checkpoint_id}" + target_dir = self.lora_base_dir / model_name + target_dir.parent.mkdir(parents=True, exist_ok=True) + + # Extract the checkpoint if it doesn't already exist + if not target_dir.exists(): + try: + with download_and_unpack(checkpoint_path) as extracted_path: + extracted_path.rename(target_dir) + except FileExistsError: + # This could happen if two processes try to download the file. + # In that case the other process won the race and created target_dir. 
+ pass + + payload = { + "model": model_name, + "prompt": prompt_tokens, + "max_tokens": request.sampling_params.max_tokens, + "temperature": request.sampling_params.temperature, + "top_p": request.sampling_params.top_p, + "top_k": request.sampling_params.top_k, + "logprobs": True, + "stream": False, + "return_token_ids": True, + } + + response = await http_client.post("/completions", json=payload) + response.raise_for_status() + result = response.json() + + sequences = [] + for choice in result["choices"]: + lp = choice["logprobs"] + sequences.append( + types.GeneratedSequence( + tokens=choice["token_ids"], + logprobs=lp["token_logprobs"], + stop_reason=choice["finish_reason"], + ) + ) + + return types.SampleOutput(sequences=sequences, prompt_logprobs=[]) diff --git a/skyrl-train/skyrl_train/tinker/extra/skyrl_inference.py b/skyrl-train/skyrl_train/tinker/extra/skyrl_inference.py new file mode 100644 index 000000000..847f3eb64 --- /dev/null +++ b/skyrl-train/skyrl_train/tinker/extra/skyrl_inference.py @@ -0,0 +1,224 @@ +"""SkyRL-Train inference client for Tinker API integration. + +This module provides a client that calls skyrl-train's InferenceEngineClient +and handles Tinker type conversion and database storage for the API server. + +Architecture: + skyrl-tx API (/api/v1/asample) -> SkyRLInferenceClient -> InferenceEngineClient.sample() + +Usage: + # From skyrl-train, after initializing inference engines: + from skyrl_train.tinker.extra.skyrl_inference import attach_skyrl_inference + + # Attach to running API server + attach_skyrl_inference(app, inference_client) +""" + +from datetime import datetime, timezone +from typing import TYPE_CHECKING + +from sqlmodel.ext.asyncio.session import AsyncSession +from skyrl_train.tinker import types +from skyrl_train.tinker.db_models import FutureDB, RequestStatus +from skyrl_train.tx_utils.log import logger + +if TYPE_CHECKING: + from fastapi import FastAPI + from skyrl_train.inference_engines.inference_engine_client import ( + InferenceEngineClient, + ) + + +class SkyRLInferenceClient: + """Client for calling skyrl-train's inference engines via Tinker API. + + This client: + 1. Converts Tinker pydantic types to/from skyrl-train format + 2. Calls InferenceEngineClient.sample() directly + 3. Stores results in the database for async API requests + + Usage: + # During app startup + inference_client = InferenceEngineClient(engines, tokenizer, config) + skyrl_client = SkyRLInferenceClient(inference_client, db_engine) + app.state.skyrl_inference_client = skyrl_client + + # In /api/v1/asample endpoint + asyncio.create_task(skyrl_client.call_and_store_result(request_id, sample_req)) + """ + + def __init__(self, inference_client: "InferenceEngineClient", db_engine): + """Initialize the SkyRL inference client. + + Args: + inference_client: skyrl-train's InferenceEngineClient with engines initialized. + db_engine: SQLModel async engine for storing results in FutureDB. + """ + self.inference_client = inference_client + self.db_engine = db_engine + + async def call_and_store_result( + self, + request_id: int, + sample_req, + model_id: str = "", + checkpoint_id: str = "", + ): + """Background task to call skyrl-train inference and store result in database. + + Args: + request_id: FutureDB request ID to update with results. + sample_req: SampleRequest from the API endpoint. + model_id: Model identifier (unused for now, skyrl-train uses pre-loaded model). + checkpoint_id: Checkpoint identifier (unused for now). 
+ """ + try: + result = await self._sample(sample_req) + result_data = result.model_dump() + status = RequestStatus.COMPLETED + except Exception as e: + logger.exception("SkyRL inference error") + result_data = {"error": str(e), "status": "failed"} + status = RequestStatus.FAILED + + async with AsyncSession(self.db_engine) as session: + future = await session.get(FutureDB, request_id) + future.result_data = result_data + future.status = status + future.completed_at = datetime.now(timezone.utc) + await session.commit() + + async def _sample(self, request) -> types.SampleOutput: + """Call skyrl-train's sample() and convert response to Tinker types. + + Args: + request: SampleRequest from the API endpoint. + + Returns: + types.SampleOutput with generated sequences. + """ + # Convert Tinker ModelInput to flat token list + prompt_tokens = self._extract_prompt_tokens(request.prompt) + + # Convert Tinker SamplingParams to dict + sampling_params = self._convert_sampling_params(request.sampling_params) + + # Call skyrl-train's InferenceEngineClient directly + # Note: We don't pass session_id (defaults to None for random load-balancing). + # Tinker's sampling_session_id/seq_id identify model checkpoints, not conversations, + # and each sample() call is independent with no KV cache benefit from sticky routing. + # This matches official Tinker backend behavior. + result = await self.inference_client.sample( + prompt_token_ids=prompt_tokens, + num_samples=request.num_samples, + sampling_params=sampling_params, + ) + + # Convert result to Tinker types + return self._convert_to_sample_output(result) + + def _extract_prompt_tokens(self, model_input: types.ModelInput) -> list[int]: + """Extract flat token list from Tinker ModelInput. + + Args: + model_input: Tinker ModelInput with chunks of tokens. + + Returns: + Flat list of token IDs. + """ + return [token for chunk in model_input.chunks for token in chunk.tokens] + + def _convert_sampling_params(self, params: types.SamplingParams) -> dict: + """Convert Tinker SamplingParams to dict for skyrl-train. + + Args: + params: Tinker SamplingParams pydantic model. + + Returns: + Dict compatible with skyrl-train's sampling. + """ + result = { + "temperature": params.temperature, + "max_tokens": params.max_tokens, + "top_k": params.top_k, + "top_p": params.top_p, + } + + if params.seed is not None: + result["seed"] = params.seed + + # Handle stop tokens/strings + if params.stop_tokens: + result["stop_token_ids"] = params.stop_tokens + if params.stop_strings: + result["stop"] = params.stop_strings + + return result + + def _convert_to_sample_output(self, result: dict) -> types.SampleOutput: + """Convert InferenceEngineOutput to Tinker SampleOutput. + + Args: + result: InferenceEngineOutput dict from skyrl-train's sample(). + + Returns: + types.SampleOutput with GeneratedSequence list. 
+ """ + sequences = [] + num_samples = len(result["response_ids"]) + + for i in range(num_samples): + # Map skyrl-train stop reasons to Tinker format + stop_reason = result["stop_reasons"][i] + if stop_reason in ("stop", "eos"): + tinker_stop_reason = "stop" + else: + tinker_stop_reason = "length" + + # Extract logprobs if available + logprobs = [] + if result.get("response_logprobs") and result["response_logprobs"][i]: + logprobs = result["response_logprobs"][i] + + sequences.append( + types.GeneratedSequence( + tokens=result["response_ids"][i], + logprobs=logprobs, + stop_reason=tinker_stop_reason, + ) + ) + + # Note: prompt_logprobs not supported yet in skyrl-train's sample() + return types.SampleOutput( + sequences=sequences, + prompt_logprobs=None, + ) + + +def attach_skyrl_inference(app: "FastAPI", inference_client: "InferenceEngineClient") -> None: + """Attach SkyRL inference client to an existing FastAPI app. + + This enables the /api/v1/asample endpoint to use skyrl-train's inference + engines directly instead of the internal JAX backend or external vLLM. + + Args: + app: The FastAPI app instance (must have db_engine in state). + inference_client: Initialized InferenceEngineClient from skyrl-train. + + Example: + # In skyrl-train after engines are initialized: + from skyrl_train.tinker.extra.skyrl_inference import attach_skyrl_inference + + app = get_running_api_app() # Get the FastAPI app + attach_skyrl_inference(app, llm_client) + """ + if not hasattr(app.state, "db_engine"): + raise RuntimeError("App must have db_engine initialized before attaching SkyRL inference") + + skyrl_client = SkyRLInferenceClient(inference_client, app.state.db_engine) + app.state.skyrl_inference_client = skyrl_client + + # Also set as external_inference_client so existing endpoint code routes to it + app.state.external_inference_client = skyrl_client + + logger.info("SkyRL-train inference client attached to API server") diff --git a/skyrl-train/skyrl_train/tinker/loss_fns.py b/skyrl-train/skyrl_train/tinker/loss_fns.py new file mode 100644 index 000000000..534443b54 --- /dev/null +++ b/skyrl-train/skyrl_train/tinker/loss_fns.py @@ -0,0 +1,60 @@ +"""Loss functions for training.""" + +try: + import jax + import jax.numpy as jnp + JAX_AVAILABLE = True +except ImportError: + JAX_AVAILABLE = False + + +if JAX_AVAILABLE: + def safe_loss_mask(loss_output: jax.Array, loss_mask: jax.Array) -> jax.Array: + "Strongly mask the loss_output to 0.0 if the loss_mask is zero." + return jnp.where(loss_mask != 0.0, loss_mask * loss_output, jnp.zeros_like(loss_output)) + + + def cross_entropy_loss( + target_logprobs: jax.Array, loss_mask: jax.Array, sampling_logprobs: jax.Array, advantages: jax.Array + ) -> jax.Array: + "Standard cross-entropy loss (i.e., negative log-likelihood)." + return -safe_loss_mask(target_logprobs, loss_mask) + + + def importance_sampling_loss( + target_logprobs: jax.Array, loss_mask: jax.Array, sampling_logprobs: jax.Array, advantages: jax.Array + ) -> jax.Array: + "Importance sampling loss with target_logprobs from learner policy and sampling_logprobs from sampling policy." + prob_ratio = jnp.exp(target_logprobs - sampling_logprobs) + return -safe_loss_mask(prob_ratio * advantages, loss_mask) + + + def ppo_loss( + target_logprobs: jax.Array, loss_mask: jax.Array, sampling_logprobs: jax.Array, advantages: jax.Array + ) -> jax.Array: + "PPO style clipped version of the importance sampling loss." 
+ prob_ratio = jnp.exp(target_logprobs - sampling_logprobs) + clipped_ratio = jnp.clip(prob_ratio, 0.8, 1.2) + unclipped = prob_ratio * advantages + clipped = clipped_ratio * advantages + return -safe_loss_mask(jnp.minimum(unclipped, clipped), loss_mask) + + # Map from string names to loss functions + LOSS_FUNCTION_MAP = { + "cross_entropy": cross_entropy_loss, + "importance_sampling": importance_sampling_loss, + "ppo": ppo_loss, + } + # List of loss functions in order (for jax.lax.switch) + LOSS_FUNCTIONS = list(LOSS_FUNCTION_MAP.values()) +else: + # When JAX is not available (e.g., using SkyRL backend), just define the names + LOSS_FUNCTION_MAP = { + "cross_entropy": None, + "importance_sampling": None, + "ppo": None, + } + LOSS_FUNCTIONS = [] + +# Map from loss function name to index (for jax.lax.switch or lookup) +LOSS_TYPES = {name: idx for idx, name in enumerate(LOSS_FUNCTION_MAP.keys())} diff --git a/skyrl-train/skyrl_train/tinker/types.py b/skyrl-train/skyrl_train/tinker/types.py new file mode 100644 index 000000000..adbe1a8c2 --- /dev/null +++ b/skyrl-train/skyrl_train/tinker/types.py @@ -0,0 +1,258 @@ +# These are the types we use to represent the data internally. +# They have some commonalities with the API request and response +# types as well as the database models, but are distinct. For +# example, usually we try to avoid optional values in these types. + +from __future__ import annotations + +from enum import Enum +from typing import Literal +from urllib.parse import urlparse + +from pydantic import BaseModel + + +class RequestType(str, Enum): + """Types of requests that can be processed.""" + + CREATE_MODEL = "create_model" + FORWARD_BACKWARD = "forward_backward" + FORWARD = "forward" + OPTIM_STEP = "optim_step" + SAVE_WEIGHTS_FOR_SAMPLER = "save_weights_for_sampler" + SAVE_WEIGHTS = "save_weights" + LOAD_WEIGHTS = "load_weights" + SAMPLE = "sample" + UNLOAD_MODEL = "unload_model" + + # External request that should not be processed by the engine + EXTERNAL = "external" + + +class CheckpointType(str, Enum): + """Type of checkpoint.""" + + TRAINING = "training" + SAMPLER = "sampler" + + +class TinkerPath(BaseModel): + primary_id: str + kind: str + secondary_id: str + + @classmethod + def parse(cls, url: str) -> TinkerPath | None: + """Parse a URL string into a TinkerPath object.""" + parsed = urlparse(url) + + match (parsed.scheme, *parsed.path.split("/")): + case ("tinker", "", secondary_id): + return cls(primary_id=parsed.netloc, kind="", secondary_id=secondary_id) + case ("tinker", "", kind, secondary_id): + return cls(primary_id=parsed.netloc, kind=kind, secondary_id=secondary_id) + case _: + return None + + +class AdamParams(BaseModel): + learning_rate: float + beta1: float + beta2: float + eps: float + weight_decay: float + + +class LoraConfig(BaseModel): + rank: int + alpha: float + seed: int + train_attn: bool = True + train_mlp: bool = True + train_unembed: bool = False + + +class CreateModelInput(BaseModel): + lora_config: LoraConfig + + +class CreateModelOutput(BaseModel): + model_id: str + base_model: str + lora_config: LoraConfig + + +class UnloadModelInput(BaseModel): + pass + + +class UnloadModelOutput(BaseModel): + model_id: str + status: str + type: str = "unload_model" + + +class ModelInputChunk(BaseModel): + tokens: list[int] + + +class ModelInput(BaseModel): + chunks: list[ModelInputChunk] + + +class TensorData(BaseModel): + data: list[int] | list[float] + + +class LossFnInputs(BaseModel): + target_tokens: TensorData + weights: TensorData + advantages: 
TensorData + logprobs: TensorData + + +class Datum(BaseModel): + loss_fn_inputs: LossFnInputs + model_input: ModelInput + + +class ForwardBackwardInput(BaseModel): + data: list[Datum] + loss_fn: Literal["cross_entropy", "importance_sampling", "ppo"] + + +class ForwardBackwardOutput(BaseModel): + loss_fn_output_type: str + loss_fn_outputs: list[dict] + metrics: dict + + +class ErrorResponse(BaseModel): + error: str + status: str + + +class OptimStepInput(BaseModel): + adam_params: AdamParams + + +class OptimStepOutput(BaseModel): + pass + + +class SaveWeightsForSamplerInput(BaseModel): + path: str | None = None + sampling_session_seq_id: int | None = None + seq_id: int | None = None + sampling_session_id: str | None = None + + +class SaveWeightsForSamplerOutput(BaseModel): + path: str | None = None + type: str + sampling_session_id: str | None = None + + +class SaveWeightsInput(BaseModel): + path: str + + +class SaveWeightsOutput(BaseModel): + path: str + type: str + + +class LoadWeightsInput(BaseModel): + source_model_id: str + checkpoint_id: str + + +class LoadWeightsOutput(BaseModel): + type: str + + +class SamplingParams(BaseModel): + temperature: float + max_tokens: int + seed: int + stop_tokens: list[int] | None = None + stop_strings: list[str] | None = None + top_k: int = -1 # -1 for no limit + top_p: float = 1.0 # 1.0 for no filtering + + +class ModelMetadata(BaseModel): + adapter_index: int + lora_config: LoraConfig + loaded_checkpoint_id: str | None = None + + +class SampleInput(BaseModel): + base_model: str | None = None + prompt: ModelInput + sampling_params: SamplingParams + num_samples: int + checkpoint_id: str + prompt_logprobs: bool + + +class GeneratedSequence(BaseModel): + stop_reason: Literal["length", "stop"] + tokens: list[int] + logprobs: list[float] + + +class SampleOutput(BaseModel): + sequences: list[GeneratedSequence] + prompt_logprobs: list[float] | None = None + + +# Metrics tracked in the engine +class EngineMetrics(BaseModel): + train_seq_len_jit_times: dict[int, float] = {} + sample_seq_len_jit_times: dict[int, float] = {} + + +# Prepared batch data for backend processing +# These are prepared by the engine and passed to the backend + + +class PreparedModelPassBatch(BaseModel): + """Prepared batch data for forward/forward_backward operations. + + Engine extracts this from requests, backend converts to JAX arrays and computes. + """ + + # Per-example data (list of lists) + all_input_ids: list[list[int]] + all_targets: list[list[int]] + all_token_weights: list[list[float]] + all_sampling_logprobs: list[list[float]] + all_advantages: list[list[float]] + + # Per-example scalars + all_model_ids: list[str] + all_loss_fn_types: list[int] + + # Mapping from examples back to requests: (request_id, model_id, start_idx, end_idx) + request_batch_slices: list[tuple[str, str, int, int]] + + +class PreparedSampleBatch(BaseModel): + """Prepared batch data for sample operations. + + Engine extracts this from requests, backend converts to JAX arrays and computes. 
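request_batch_slices records how the flattened per-example lists map back to the originating requests. The helper below is a hypothetical sketch (not part of the patch) of how a backend could split its outputs using those slices:

```python
# Hypothetical sketch: split flattened per-example outputs back into per-request
# lists using (request_id, model_id, start_idx, end_idx) slices.
from typing import Any

def split_by_request(
    outputs: list[Any],
    request_batch_slices: list[tuple[str, str, int, int]],
) -> dict[str, list[Any]]:
    per_request: dict[str, list[Any]] = {}
    for request_id, _model_id, start_idx, end_idx in request_batch_slices:
        per_request[request_id] = outputs[start_idx:end_idx]
    return per_request

slices = [("req-a", "model-0", 0, 2), ("req-b", "model-0", 2, 3)]
print(split_by_request(["o0", "o1", "o2"], slices))
# {'req-a': ['o0', 'o1'], 'req-b': ['o2']}
```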
+ """ + + # Per-sample data + all_prompts: list[list[int]] + all_sampling_params: list[SamplingParams] + all_model_ids: list[str] + all_checkpoint_ids: list[str] + all_checkpoint_paths: list[str] + + # Whether any request needs prompt logprobs + needs_prompt_logprobs: bool + + # Mapping from samples back to requests: (request_id, model_id, start_idx, end_idx, prompt_logprobs_requested) + request_batch_slices: list[tuple[str, str, int, int, bool]] diff --git a/skyrl-train/skyrl_train/tx_utils/generator.py b/skyrl-train/skyrl_train/tx_utils/generator.py new file mode 100644 index 000000000..f461a5613 --- /dev/null +++ b/skyrl-train/skyrl_train/tx_utils/generator.py @@ -0,0 +1,449 @@ +"""Generator mixin for autoregressive text generation with KV caching.""" + +from __future__ import annotations +from dataclasses import dataclass +import functools + +import jax +import jax.numpy as jnp +from tokenizers.decoders import DecodeStream +import tx.utils.models +from tx.tinker import types + + +@jax.tree_util.register_dataclass +@dataclass +class KVCache: + """Key-value cache for all layers, each entry in the list corresponds to one layer.""" + + keys: list[jax.Array] + values: list[jax.Array] + cache_position: jax.Array # Per-sequence positions of shape [B] for left-aligned decoding + + @staticmethod + def update( + kv_cache: KVCache | None, + keys: list[jax.Array], + values: list[jax.Array], + positions: jax.Array, + attention_mask: jax.Array, + ) -> KVCache: + """Create an updated KVCache with computed cache positions for left-aligned decoding. + + Args: + kv_cache: Existing KVCache (None during prefill). + keys: List of key arrays per layer. + values: List of value arrays per layer. + positions: Position indices with shape [B, seq_len]. + attention_mask: Attention mask with shape [B, seq_len]. + + Returns: + New KVCache with computed cache_position. + """ + if kv_cache is not None: + # Decode: next position is current position + 1 + cache_position = positions[:, 0] + 1 + else: + # Prefill: next position is the sequence length (number of real tokens) + cache_position = attention_mask.sum(axis=1) + return KVCache(keys=keys, values=values, cache_position=cache_position) + + @staticmethod + def update_layer(kv_cache, k, v, positions): + """Update a single layer's KV cache at the given positions (for left-aligned decoding). + + Args: + kv_cache: Tuple of (k_cache, v_cache) arrays for this layer. + k: New key values with shape [B, seq_len, num_heads, head_dim]. + v: New value values with shape [B, seq_len, num_heads, head_dim]. + positions: Position indices with shape [B, seq_len]. + """ + k_cache, v_cache = kv_cache + + def update_at_pos(cache_slice, new_val_slice, pos): + return jax.lax.dynamic_update_slice(cache_slice, new_val_slice, (pos, 0, 0)) + + k = jax.vmap(update_at_pos)(k_cache, k, positions[:, 0]) + v = jax.vmap(update_at_pos)(v_cache, v, positions[:, 0]) + return k, v + + def pad_to_length(self, max_length: int) -> KVCache: + """Pad KV cache to a specified maximum length. + + Args: + max_length: Target length to pad the cache to. + + Returns: + New KVCache with padded keys and values. 
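pad_to_length grows only the time axis of each per-layer key/value array. A toy shape check with made-up sizes:

```python
# Illustrative shapes only: pad a [B, T, num_heads, head_dim] cache along T.
import jax.numpy as jnp

k = jnp.ones((2, 5, 4, 8))
max_length = 12
pad_spec = ((0, 0), (0, max_length - k.shape[1]), (0, 0), (0, 0))
print(jnp.pad(k, pad_spec).shape)  # (2, 12, 4, 8)
```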
+ """ + # k and v have shape [B, T, num_heads, head_dim] + cache_pad_length = max_length - self.keys[0].shape[1] + pad_spec = ((0, 0), (0, cache_pad_length), (0, 0), (0, 0)) + return KVCache( + keys=[jnp.pad(k, pad_spec) for k in self.keys], + values=[jnp.pad(v, pad_spec) for v in self.values], + cache_position=self.cache_position, + ) + + +@jax.tree_util.register_dataclass +@dataclass +class DecodeState: + """State of the decode loop.""" + + kv_cache: KVCache + rngs: jax.Array # of shape [B, key_dim] + attention_mask: jax.Array + last_positions: jax.Array + logits: jax.Array + stop_pos: jax.Array # Position where stop token was found + + +@dataclass +class GenerateOutput: + """Result from autoregressive text generation. + + Attributes: + generated_ids: List of token ID lists, one for each request (excluding the prompt). + stop_reasons: Reason for stopping generation for each sequence ('stop' or 'length'). + logprobs: Log probabilities for each sampled token. + """ + + generated_ids: list[list[int]] + stop_reasons: list[str] + logprobs: list[list[float]] + prompt_logprobs: list[list[float]] | None = None + + +def find_string_stop_position( + tokens: list[int], + tokenizer, + stop_strings: list[str], +) -> int | None: + """Find the token position where a stop string first appears. + + Incrementally decodes tokens and checks for stop string matches. + Uses the tokenizers DecodeStream for efficient incremental decoding. + + Args: + tokens: List of generated token IDs + tokenizer: HuggingFace tokenizer instance + stop_strings: List of stop strings to search for + + Returns: + Token index to truncate to (exclusive), or None if no stop found. + """ + if not stop_strings or not tokens: + return None + + # Incremental decode using DecodeStream + stream = DecodeStream(skip_special_tokens=False) + text = "" + for i, token in enumerate(tokens): + chunk = stream.step(tokenizer._tokenizer, token) + if chunk is not None: + text += chunk + for stop_string in stop_strings: + if stop_string in text: + return i + 1 + + return None + + +class GeneratorMixin: + """Adds autoregressive generation with KV caching to causal language models.""" + + @staticmethod + @functools.partial( + jax.jit, static_argnames=("max_length", "max_new_tokens", "max_top_k", "use_top_p", "prompt_logprobs") + ) + def _prefill_and_decode( + model, + input_ids: jax.Array, + attention_mask: jax.Array, + max_length: int, + max_new_tokens: int, + adapter_indices: jax.Array | None, + temperatures: jax.Array, + rngs: jax.Array, + stop_tokens: jax.Array, + top_k_values: jax.Array, + top_p_values: jax.Array, + max_top_k: int, + use_top_p: bool, + prompt_logprobs: bool = False, + ): + """JIT-compiled prefill + decode loop. 
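The jit decorator above marks max_length, max_new_tokens, max_top_k, use_top_p, and prompt_logprobs as static, so each distinct combination is compiled once with those values baked into the program. A minimal standalone sketch of the same pattern, using a hypothetical toy_decode:

```python
# Sketch of static_argnames: the Python int is a compile-time constant, so the
# loop below unrolls to a fixed length; a new value triggers a recompile.
import functools
import jax
import jax.numpy as jnp

@functools.partial(jax.jit, static_argnames=("max_new_tokens",))
def toy_decode(logits: jax.Array, max_new_tokens: int) -> jax.Array:
    return jnp.stack([logits + i for i in range(max_new_tokens)])

print(toy_decode(jnp.zeros((2, 4)), max_new_tokens=3).shape)  # (3, 2, 4)
```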
Fuses everything for maximum efficiency.""" + # Prefill: process full prompt (left-aligned, so positions start at 0) + outputs = model( + input_ids, + attention_mask=attention_mask, + adapter_indices=adapter_indices, + ) + + # For left-aligned sequences, find the last real token position for each sequence + last_token_idx = attention_mask.sum(axis=1) - 1 # Shape: [B] + batch_idx = jnp.arange(input_ids.shape[0]) + + # Compute logits for sampling and optionally for prompt logprobs + if prompt_logprobs: + # Compute all logits for prompt logprobs and sampling the first token + all_logits = model.compute_logits(outputs.last_hidden_state, adapter_indices) + last_logits = all_logits[batch_idx, last_token_idx, :] # Shape: [B, vocab_size] + prompt_logprobs_array = model.logits_to_logprobs(all_logits[:, :-1, :], input_ids[:, 1:]) + else: + # Only compute logits for the last position for sampling + last_hidden = outputs.last_hidden_state[batch_idx, last_token_idx][:, None, :] # Shape: [B, 1, H] + last_logits = model.compute_logits(last_hidden, adapter_indices)[:, 0, :] + prompt_logprobs_array = None + + # Pad KV cache and attention mask + kv_cache = outputs.kv_cache.pad_to_length(max_length) + + # Pad KV cache and attention mask to max_length + kv_cache = kv_cache.pad_to_length(max_length) + decode_attention_mask = jnp.pad(attention_mask, ((0, 0), (0, max_length - attention_mask.shape[1]))) + + def decode_fn(s: DecodeState, step: jax.Array) -> tuple[DecodeState, tuple[jax.Array, jax.Array]]: + """Decode one token step. Returns (state, (token, logprob)) for scan accumulation.""" + # Sample next token + split_keys = jax.vmap(jax.random.split)(s.rngs) + rngs, sample_keys = split_keys[:, 0], split_keys[:, 1] + + zero_temp_mask = temperatures == 0.0 + scaled_logits = s.logits / jnp.where(zero_temp_mask, 1.0, temperatures)[:, None] + + # Apply top_k and top_p filtering + if max_top_k > 0: + scaled_logits = apply_top_k_batch(scaled_logits, top_k_values, max_top_k) + if use_top_p: + scaled_logits = apply_top_p_batch(scaled_logits, top_p_values) + + sampled = jax.vmap(lambda key, logit: jax.random.categorical(key, logit, axis=-1))( + sample_keys, scaled_logits + ) + greedy = jnp.argmax(s.logits, axis=-1) + next_token = jnp.where(zero_temp_mask[:, None], greedy[:, None], sampled[:, None]) + sampled_logprob = model.logits_to_logprobs(s.logits, next_token[:, 0])[:, None] + + # Track first stop token position (-1 means not stopped yet) + is_stop = jnp.any(next_token == stop_tokens, axis=1) + stop_pos = jnp.where((s.stop_pos == -1) & is_stop, step + 1, s.stop_pos) + + # Update attention mask at per-sequence positions (for left-aligned sequences) + batch_idx = jnp.arange(s.attention_mask.shape[0]) + next_attention_mask = s.attention_mask.at[batch_idx, s.kv_cache.cache_position].set(1) + + outputs = model( + next_token, + attention_mask=next_attention_mask, + positions=s.last_positions + 1, + kv_cache=s.kv_cache, + adapter_indices=adapter_indices, + ) + # Compute logits for the next token + next_logits = model.compute_logits(outputs.last_hidden_state, adapter_indices)[:, 0, :] + next_state = DecodeState( + kv_cache=outputs.kv_cache, + rngs=rngs, + attention_mask=next_attention_mask, + last_positions=s.last_positions + 1, + logits=next_logits, + stop_pos=stop_pos, + ) + return next_state, (next_token, sampled_logprob) + + initial_state = DecodeState( + kv_cache=kv_cache, + rngs=rngs, + attention_mask=decode_attention_mask, + last_positions=last_token_idx[:, None], + logits=last_logits, + 
stop_pos=jnp.full((input_ids.shape[0],), -1), + ) + + final_state, (tokens_stacked, logprobs_stacked) = jax.lax.scan( + decode_fn, initial_state, xs=jnp.arange(max_new_tokens) + ) + + # Post-process: transpose scan outputs from [Steps, Batch, 1] to [Batch, Steps] + new_tokens = jnp.swapaxes(tokens_stacked, 0, 1).squeeze(-1) + new_logprobs = jnp.swapaxes(logprobs_stacked, 0, 1).squeeze(-1) + + return new_tokens, new_logprobs, final_state.stop_pos, prompt_logprobs_array + + def generate( + self, + input_ids: jax.Array, + attention_mask: jax.Array, + *, + sampling_params: list[types.SamplingParams], + adapter_indices: jax.Array | None = None, + prompt_logprobs: bool = False, + tokenizer=None, + ) -> GenerateOutput: + """Generate text autoregressively with KV caching. + + Args: + tokenizer: Optional tokenizer for string stop sequence detection. + Required if any sampling_params has stop_strings set. + + Returns: + GenerateOutput containing generated_ids, stop_reasons, and optionally logprobs. + """ + batch_size, prompt_length = input_ids.shape + assert len(sampling_params) == batch_size + max_new_tokens = max(sampling_param.max_tokens for sampling_param in sampling_params) + max_length = tx.utils.models.round_up_seq_len(prompt_length + max_new_tokens) + temperatures = jnp.array([sampling_param.temperature for sampling_param in sampling_params]) + top_k_values = jnp.array([sampling_param.top_k for sampling_param in sampling_params], dtype=jnp.int32) + top_p_values = jnp.array([sampling_param.top_p for sampling_param in sampling_params], dtype=jnp.float32) + + # One PRNGKey per provided seed + seeds = [sampling_param.seed for sampling_param in sampling_params] + rngs = jax.vmap(jax.random.PRNGKey)(jnp.array(seeds)) + + # Extract stop tokens and pad to same length + max_stop_tokens = max(len(sp.stop_tokens) if sp.stop_tokens else 0 for sp in sampling_params) + stop_tokens = [] + for sp in sampling_params: + stop = sp.stop_tokens or [] + stop_tokens.append(stop + [-1] * (max_stop_tokens - len(stop))) + stop_tokens = jnp.array(stop_tokens, dtype=jnp.int32) + + # Capture prompt lengths for prompt_logprobs if requested + prompt_lengths = attention_mask.sum(axis=1) if prompt_logprobs else None + + # Compute static flags for top_k and top_p filtering + max_top_k = max((sp.top_k for sp in sampling_params if sp.top_k > 0), default=0) + use_top_p = any(sp.top_p < 1.0 for sp in sampling_params) + + new_tokens, new_logprobs, stop_pos, prompt_logprobs_array = self._prefill_and_decode( + self, + input_ids, + attention_mask, + max_length, + max_new_tokens, + adapter_indices, + temperatures, + rngs, + stop_tokens, + top_k_values, + top_p_values, + max_top_k, + use_top_p, + prompt_logprobs=prompt_logprobs, + ) + + max_tokens = jnp.array([sp.max_tokens for sp in sampling_params]) + # stop_pos is -1 if no stop token found; has_stop is true only if found within limit + has_stop = (stop_pos != -1) & (stop_pos <= max_tokens) + end_positions = jnp.where(has_stop, stop_pos, max_tokens) + + # In multi-host mode, gather all shards before device_get + if jax.process_count() > 1: + from jax.experimental import multihost_utils + + (new_tokens, has_stop, new_logprobs, end_positions, prompt_logprobs_array, prompt_lengths) = jax.tree.map( + lambda x: multihost_utils.process_allgather(x, tiled=True), + (new_tokens, has_stop, new_logprobs, end_positions, prompt_logprobs_array, prompt_lengths), + ) + + # Single device-to-host transfer + ( + new_tokens_host, + has_stop_host, + new_logprobs_host, + end_positions_host, + 
prompt_logprobs_host, + prompt_lengths_host, + ) = jax.device_get((new_tokens, has_stop, new_logprobs, end_positions, prompt_logprobs_array, prompt_lengths)) + + # Build output lists, applying string stop detection where needed + generated_ids = [] + stop_reasons = [] + logprobs_out = [] + + for i in range(batch_size): + tokens = new_tokens_host[i][: end_positions_host[i]].tolist() + token_logprobs = new_logprobs_host[i][: end_positions_host[i]].tolist() + stop_reason = "stop" if has_stop_host[i] else "length" + + # Apply string stop detection if stop_strings specified + if sampling_params[i].stop_strings: + assert tokenizer is not None, "tokenizer is required when stop_strings is specified" + assert stop_reason == "length", "stop_tokens cannot be specified when stop_strings is specified" + string_stop_pos = find_string_stop_position(tokens, tokenizer, sampling_params[i].stop_strings) + if string_stop_pos is not None: + tokens = tokens[:string_stop_pos] + token_logprobs = token_logprobs[:string_stop_pos] + stop_reason = "stop" + + generated_ids.append(tokens) + stop_reasons.append(stop_reason) + logprobs_out.append(token_logprobs) + + return GenerateOutput( + generated_ids=generated_ids, + stop_reasons=stop_reasons, + logprobs=logprobs_out, + prompt_logprobs=( + [prompt_logprobs_host[i, : prompt_lengths_host[i] - 1].tolist() for i in range(batch_size)] + if prompt_logprobs + else None + ), + ) + + +def apply_top_k_batch(logits: jax.Array, k_values: jax.Array, max_k: int) -> jax.Array: + """Keep only top-k logits per example, set rest to -inf. + + Args: + logits: Logits tensor of shape [batch_size, vocab_size] + k_values: Per-example k values of shape [batch_size]. If k <= 0, no filtering. + max_k: Static maximum k value, must be > 0. + + Returns: + Filtered logits with the same shape. + """ + assert max_k > 0 + top_values, top_indices = jax.lax.top_k(logits, max_k) + + # Keep only first k values per example + keep = jnp.arange(max_k) < k_values[:, None] + top_values = jnp.where(keep, top_values, -jnp.inf) + + # Scatter back to original positions + batch_idx = jnp.arange(logits.shape[0])[:, None] + result = jnp.full_like(logits, -jnp.inf).at[batch_idx, top_indices].set(top_values) + + return jnp.where(k_values[:, None] <= 0, logits, result) + + +def apply_top_p_batch(logits: jax.Array, p_values: jax.Array) -> jax.Array: + """Keep only tokens with cumulative probability up to p, set rest to -inf. + + Args: + logits: Logits tensor of shape [batch_size, vocab_size] + p_values: Per-example p values of shape [batch_size]. If p >= 1.0, no filtering. + + Returns: + Filtered logits with the same shape. 
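A standalone numeric check of the masking core of apply_top_k_batch above (toy logits; row 0 keeps its top two tokens, row 1 only its top one; the k <= 0 passthrough is omitted here):

```python
# Mirrors the top-k masking in apply_top_k_batch with illustrative values.
import jax.numpy as jnp
from jax import lax

logits = jnp.array([[0.1, 2.0, 1.0, -1.0],
                    [3.0, 0.5, 0.2, 0.1]])
k_values = jnp.array([2, 1])
max_k = 2

top_values, top_indices = lax.top_k(logits, max_k)
keep = jnp.arange(max_k) < k_values[:, None]
top_values = jnp.where(keep, top_values, -jnp.inf)
batch_idx = jnp.arange(logits.shape[0])[:, None]
print(jnp.full_like(logits, -jnp.inf).at[batch_idx, top_indices].set(top_values))
# row 0 keeps 2.0 and 1.0, row 1 keeps only 3.0; everything else is -inf
```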
+ """ + # Sort by logits (equivalent to sorting by probs since softmax is monotonic) + sorted_indices = jnp.argsort(-logits, axis=-1) + sorted_logits = jnp.take_along_axis(logits, sorted_indices, axis=-1) + sorted_probs = jax.nn.softmax(sorted_logits, axis=-1) + + # Exclusive cumsum: cumsum[i] - prob[i] gives sum of probs *before* position i + cumsum_exclusive = jnp.cumsum(sorted_probs, axis=-1) - sorted_probs + + keep_mask = cumsum_exclusive < p_values[:, None] + keep_mask = keep_mask.at[:, 0].set(True) # Always keep top token + filtered_sorted_logits = jnp.where(keep_mask, sorted_logits, -jnp.inf) + + # Scatter back to original positions + batch_idx = jnp.arange(logits.shape[0])[:, None] + result = jnp.empty_like(logits).at[batch_idx, sorted_indices].set(filtered_sorted_logits) + + return jnp.where(p_values[:, None] >= 1.0, logits, result) diff --git a/skyrl-train/skyrl_train/tx_utils/log.py b/skyrl-train/skyrl_train/tx_utils/log.py new file mode 100644 index 000000000..0bedc7547 --- /dev/null +++ b/skyrl-train/skyrl_train/tx_utils/log.py @@ -0,0 +1,94 @@ +from enum import Enum +import logging +import os +from pathlib import Path +from typing import Any + +from rich.console import Console + +try: + import wandb # type: ignore[import-not-found] +except ImportError: + wandb = None # type: ignore[assignment] + + +def _setup_root_logger() -> None: + logger = logging.getLogger("tx") + logger.setLevel(logging.DEBUG) + logger.propagate = False # Prevent propagation to root logger + console = Console(highlight=True, markup=True) + + class RichStreamHandler(logging.Handler): + def emit(self, record): + msg = self.format(record) + console.print(msg, highlight=True) + + handler = RichStreamHandler() + handler.setFormatter(logging.Formatter("%(levelname)s: %(message)s")) + logger.addHandler(handler) + + +def add_file_handler(path: Path | str, level: int = logging.DEBUG, *, print_path: bool = True) -> None: + logger = logging.getLogger("tx") + handler = logging.FileHandler(path) + handler.setLevel(level) + formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") + handler.setFormatter(formatter) + logger.addHandler(handler) + if print_path: + print(f"Logging to '{path}'") + + +_setup_root_logger() +logger = logging.getLogger("tx") + + +class ExperimentTracker(str, Enum): + wandb = "wandb" + + +class Tracker: + + def __init__(self, config: dict[str, Any], **kwargs): + logger.info(f"model config: {config}") + + def log(self, metrics: dict[str, Any], step: int | None = None) -> None: + data = metrics if step is None else {"step": step, **metrics} + logger.info( + ", ".join( + f"{key}: {value:.3e}" if isinstance(value, float) else f"{key}: {value}" for key, value in data.items() + ) + ) + + +class WandbTracker(Tracker): + + def __init__(self, config: dict[str, Any], **kwargs): + super().__init__(config, **kwargs) + if wandb is None: + raise RuntimeError("wandb not installed") + if not os.environ.get("WANDB_API_KEY"): + raise ValueError("WANDB_API_KEY environment variable not set") + self.run = wandb.init(config=config, **kwargs) # type: ignore[union-attr] + + def log(self, metrics: dict[str, Any], step: int | None = None) -> None: + super().log(metrics, step) + if wandb is not None: + wandb.log(metrics, step=step) # type: ignore[union-attr] + + def __del__(self): + if wandb is not None: + wandb.finish() # type: ignore[union-attr] + + +def get_tracker(tracker: ExperimentTracker | None, config: dict[str, Any], **kwargs) -> Tracker: + match tracker: + case None: + return 
Tracker(config, **kwargs) + case ExperimentTracker.wandb: + return WandbTracker(config, **kwargs) + case _: + raise ValueError(f"Unsupported experiment tracker: {tracker}") + + +__all__ = ["logger"] diff --git a/skyrl-train/skyrl_train/tx_utils/logits_processor.py b/skyrl-train/skyrl_train/tx_utils/logits_processor.py new file mode 100644 index 000000000..bf55e429c --- /dev/null +++ b/skyrl-train/skyrl_train/tx_utils/logits_processor.py @@ -0,0 +1,133 @@ +"""Mixin for logits computation in causal language models.""" + +from abc import ABC, abstractmethod +from typing import Callable + +import jax +import jax.numpy as jnp +from tx.models.configs import ModelConfig + + +# lm_head: (hidden_states, adapter_indices) -> logits +LMHead = Callable[[jax.Array, jax.Array | None], jax.Array] + + +class LogitsProcessorMixin(ABC): + """Mixin providing logits/logprobs computation for causal language models.""" + + @abstractmethod + def get_model_config(self) -> ModelConfig: + """Return the model configuration.""" + ... + + @abstractmethod + def get_lm_head(self) -> LMHead: + """Return the lm_head callable for logits computation.""" + ... + + def compute_logits( + self, + hidden_states: jax.Array, + adapter_indices: jax.Array | None = None, + ) -> jax.Array: + """Compute logits from hidden states. For sampling. + + Args: + hidden_states: Hidden states from model forward [B, T, H]. + adapter_indices: Optional adapter indices for LoRA. + + Returns: + Logits [B, T, V]. + """ + return self.get_lm_head()(hidden_states, adapter_indices) + + def compute_logprobs( + self, + hidden_states: jax.Array, + target_ids: jax.Array, + adapter_indices: jax.Array | None = None, + ) -> jax.Array: + """Compute logprobs from hidden states. For training and prompt logprobs. + + Args: + hidden_states: Hidden states [B, T, H]. + target_ids: Target token IDs [B, T]. + adapter_indices: Adapter indices for LoRA on lm_head. + + Returns: + Log probabilities for target tokens [B, T]. + """ + chunk_size = self.get_model_config().loss_chunk_size + if chunk_size > 0: + return self._compute_chunked_logprobs(hidden_states, target_ids, chunk_size, adapter_indices) + else: + logits = self.compute_logits(hidden_states, adapter_indices) + return self.logits_to_logprobs(logits, target_ids) + + @staticmethod + def logits_to_logprobs(logits: jax.Array, target_ids: jax.Array) -> jax.Array: + """Convert logits to logprobs. For decode logprobs when logits already computed. + + Args: + logits: Logits [B, T, V] or [B, V]. + target_ids: Target token IDs [B, T] or [B]. + + Returns: + Log probabilities for target tokens [B, T] or [B]. + """ + log_sum_exp = jax.nn.logsumexp(logits, axis=-1, keepdims=True) + target_logits = jnp.take_along_axis(logits, target_ids[..., None], axis=-1) + return (target_logits - log_sum_exp).squeeze(-1) + + def _compute_chunked_logprobs( + self, + hidden_states: jax.Array, + target_ids: jax.Array, + chunk_size: int, + adapter_indices: jax.Array | None, + ) -> jax.Array: + """Compute log probabilities using chunked lm_head computation. + + This avoids materializing the full [B*T, V] logits tensor by computing + lm_head and log probabilities for each chunk sequentially. 
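The chunked path pads the flattened B*T tokens up to a multiple of loss_chunk_size and reshapes them into [num_chunks, chunk_size, ...] before mapping over chunks. A toy check of that arithmetic with illustrative sizes:

```python
# Pad-and-chunk arithmetic used before lax.map over chunks (sizes are made up).
import jax.numpy as jnp

B, T, H, chunk_size = 2, 7, 4, 8
total_tokens = B * T                                         # 14
num_chunks = (total_tokens + chunk_size - 1) // chunk_size   # 2
pad_amount = num_chunks * chunk_size - total_tokens          # 2

flat_hidden = jnp.pad(jnp.zeros((total_tokens, H)), ((0, pad_amount), (0, 0)))
print(flat_hidden.reshape(num_chunks, chunk_size, H).shape)  # (2, 8, 4)
```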
+ """ + B, T, H = hidden_states.shape + total_tokens = B * T + lm_head = self.get_lm_head() + + # Flatten batch and sequence dimensions + flat_hidden = hidden_states.reshape(-1, H) # [B*T, H] + flat_target_ids = target_ids.reshape(-1) # [B*T] + + # Flatten and chunk adapter indices like hidden states and targets + if adapter_indices is None: + flat_adapter_indices = jnp.zeros(total_tokens, dtype=jnp.int32) + else: + flat_adapter_indices = jnp.repeat(adapter_indices, T) # [B*T] + + # Pad to multiple of chunk_size for clean slicing + num_chunks = (total_tokens + chunk_size - 1) // chunk_size + pad_amount = num_chunks * chunk_size - total_tokens + flat_hidden = jnp.pad(flat_hidden, ((0, pad_amount), (0, 0))) + flat_target_ids = jnp.pad(flat_target_ids, (0, pad_amount)) + flat_adapter_indices = jnp.pad(flat_adapter_indices, (0, pad_amount)) + + # Reshape into chunks: [num_chunks, chunk_size, ...] + chunked_hidden = flat_hidden.reshape(num_chunks, chunk_size, H) + chunked_targets = flat_target_ids.reshape(num_chunks, chunk_size) + chunked_adapters = flat_adapter_indices.reshape(num_chunks, chunk_size) + + def compute_chunk_logprobs(args): + """Compute lm_head and log probabilities for a chunk of tokens.""" + chunk_hidden, chunk_targets, chunk_adapters = args + # Reshape chunk_hidden to [chunk_size, 1, H] for lm_head + # and compute chunk_logits: [chunk_size, 1, H] -> [chunk_size, 1, V] -> [chunk_size, V] + # TODO: Remove the conversion and make it so lm_head can operate on 2d tensors directly + chunk_logits = lm_head(chunk_hidden[:, None, :], chunk_adapters)[:, 0, :] + return LogitsProcessorMixin.logits_to_logprobs(chunk_logits, chunk_targets) + + if self.get_model_config().gradient_checkpointing: + compute_chunk_logprobs = jax.checkpoint(compute_chunk_logprobs, policy=None) + + all_logprobs = jax.lax.map(compute_chunk_logprobs, (chunked_hidden, chunked_targets, chunked_adapters)) + return all_logprobs.reshape(-1)[:total_tokens].reshape(B, T) diff --git a/skyrl-train/skyrl_train/tx_utils/models.py b/skyrl-train/skyrl_train/tx_utils/models.py new file mode 100644 index 000000000..6e840febf --- /dev/null +++ b/skyrl-train/skyrl_train/tx_utils/models.py @@ -0,0 +1,306 @@ +from __future__ import annotations + +from enum import Enum +import os +from pathlib import Path +from typing import Callable, TYPE_CHECKING + +from cloudpathlib import CloudPath +from flax import nnx +from huggingface_hub import snapshot_download +import jax +import jax.numpy as jnp +import numpy as np +import optax +import safetensors.numpy +from transformers import PretrainedConfig +import peft + +from tx.models.configs import ModelConfig +from tx.utils.log import logger +from tx.utils.storage import download_and_unpack, pack_and_upload +from tx.tinker.types import LoraConfig + +if TYPE_CHECKING: + import torch + + +def resolve_model_path(model_name_or_path: str) -> str: + """Resolve a model name or path to a local directory path. + + If the model_name_or_path points to an existing local directory, it will be + used directly. Otherwise, the model will be downloaded from HuggingFace Hub. + + Args: + model_name_or_path: Either a local path to a model directory or a + HuggingFace model identifier (e.g., "Qwen/Qwen3-0.6B"). + + Returns: + Path to the local directory containing model config and weights. 
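A usage sketch of resolve_model_path; the import path is assumed from where this patch places the file, and the Hub call needs network access:

```python
# Assumed import path; "Qwen/Qwen3-0.6B" is the example id from the docstring above.
from skyrl_train.tx_utils.models import resolve_model_path

# An existing local directory is returned unchanged; anything else is treated as a
# HuggingFace Hub id and only *.safetensors / *.json files are downloaded.
local_dir = resolve_model_path("Qwen/Qwen3-0.6B")
print(local_dir)
```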
+ """ + local_path = Path(model_name_or_path).expanduser() + if local_path.is_dir(): + logger.info(f"Using local model at {local_path}") + return str(local_path) + return snapshot_download(model_name_or_path, allow_patterns=["*.safetensors", "*.json"]) + + +def get_dtype(dtype: str | torch.dtype) -> jnp.dtype: + "Convert torch dtype to jax dtype." + + match str(dtype): + case "torch.float32" | "float32": + return jnp.float32 + case "torch.bfloat16" | "bfloat16": + return jnp.bfloat16 + case "torch.float16" | "float16": + return jnp.float16 + case _: + raise ValueError(f"Unsupported torch dtype: {dtype}") + + +def get_model_class(config: PretrainedConfig) -> Callable[..., nnx.Module]: + "Get the correct model class based on the config." + import tx.models.llama3 + import tx.models.qwen3 + import tx.models.deepseekv3 + + for architecture in config.architectures or []: + if hasattr(tx.models.llama3, architecture): + return getattr(tx.models.llama3, architecture) + if hasattr(tx.models.qwen3, architecture): + return getattr(tx.models.qwen3, architecture) + if hasattr(tx.models.deepseekv3, architecture): + return getattr(tx.models.deepseekv3, architecture) + + raise ValueError(f"None of the architectures {config.architectures} is currently supported.") + + +def get_param_key(path: tuple, prefix: str = "") -> str: + "Get the safetensors key for a given model path." + if path[-1] in {"embedding", "kernel"}: + path = (*path[:-1], "weight") + elif path[-1] in {"lora_A", "lora_B"}: + path = (*path, "weight") + return prefix + ".".join(map(str, path)) + + +def get_expert_key(path: tuple, expert_idx: int) -> str: + "Get the safetensors key for an expert weight model path." + path = tuple(s if s != "experts" else f"experts.{expert_idx}" for s in path) + return ".".join(map(str, path)) + + +def load_safetensors( + checkpoint_dir: str | os.PathLike, + config: ModelConfig, + model: nnx.Module, + skip_lora: bool = True, + prefix: str = "", + filter_fn: Callable[[tuple], bool] | None = None, +) -> None: + tensors = {} + for file in Path(checkpoint_dir).glob("*.safetensors"): + tensors.update(safetensors.numpy.load_file(file)) + tensors = {k.removeprefix(prefix): v for k, v in tensors.items()} + + model_params = nnx.to_flat_state(nnx.state(model)) + updates = [] + for path, param in model_params: + if filter_fn is not None and not filter_fn(path): + continue + key = get_param_key(path) + # Skip LoRA parameters if requested + if skip_lora and ("lora_A" in path or "lora_B" in path or "lora_scaling" in path or "lora_ranks" in path): + continue + if "experts" in path: + tensors[key] = np.stack( + [tensors[get_expert_key(path, i)].T for i in range(config.get_num_experts())], axis=0 + ) + else: + tensors[key] = tensors[key] if "embed_tokens" in path else tensors[key].T + if path[-2] in {"q_proj", "k_proj", "v_proj", "o_proj"}: + tensors[key] = tensors[key].reshape(param.shape) + assert param.shape == tensors[key].shape, f"shape mismatch for {key}" + sharded_tensor = jax.device_put(tensors[key].astype(param.dtype), param.sharding) + updates.append((path, sharded_tensor)) + nnx.update(model, nnx.from_flat_state(updates)) + + +def save_safetensors( + config: ModelConfig, + model: nnx.Module, + filename: Path, + prefix: str = "", + filter_fn: Callable[[tuple], bool] | None = None, +) -> None: + model_params = nnx.to_flat_state(nnx.state(model)) + tensors = {} + for path, param in model_params: + if "rngs" in path: + continue + if filter_fn is not None and not filter_fn(path): + continue + key = get_param_key(path, 
prefix=prefix) + if "experts" in path: + for i in range(config.get_num_experts()): + tensors[get_expert_key(path, i)] = param[i, :, :].T + continue + if "q_proj" in path or "k_proj" in path or "v_proj" in path: + param = param.reshape(param.shape[0], -1) + elif "o_proj" in path: + param = param.reshape(-1, param.shape[-1]) + tensors[key] = param if "embed_tokens" in path else param.T + + # In multi-host mode, gather all shards and only save from rank 0 + if jax.process_count() > 1: + from jax.experimental import multihost_utils + + tensors = {k: multihost_utils.process_allgather(v, tiled=True) for k, v in tensors.items()} + + if jax.process_index() == 0: + safetensors.numpy.save_file({k: np.asarray(v) for k, v in tensors.items()}, filename) + + +def filter_lora(adapter_config: LoraConfig, path: tuple[str, ...]) -> bool: + if not adapter_config.train_attn and "self_attn" in path: + return False + if not adapter_config.train_mlp and ("mlp" in path or "experts" in path): + return False + if not adapter_config.train_unembed and ("embed_tokens" in path or "lm_head" in path): + return False + return True + + +def load_lora_checkpoint( + model: nnx.Module, adapter_config: LoraConfig, adapter_index: int, checkpoint_path: Path | CloudPath +) -> None: + """Load LoRA adapter weights from a sampling checkpoint into the model. + + Args: + model: The Qwen3ForCausalLM model to load the adapter into + adapter_config: LoRA adapter configuration + adapter_index: Index of the adapter to load into + checkpoint_path: Path to the checkpoint tar.gz file + """ + _, lora_params, _ = nnx.split(model, model.is_lora_param, ...) + + adapter_lora_params = extract_adapter_state(adapter_index, lora_params, adapter_config.rank) + + with download_and_unpack(checkpoint_path) as temp_dir: + load_safetensors( + temp_dir, + model.config, + adapter_lora_params, + skip_lora=False, + prefix="base_model.model.", + filter_fn=lambda path: filter_lora(adapter_config, path), + ) + insert_adapter_state(adapter_index, lora_params, adapter_lora_params, adapter_config.rank) + + +def save_lora_checkpoint( + model: nnx.Module, + base_model_name: str, + adapter_config: LoraConfig, + adapter_index: int, + output_path: Path | CloudPath, +): + """Save a LoRA checkpoint as a tar.gz archive. + + Args: + model: The Qwen3ForCausalLM model to extract LoRA parameters from + adapter_config: LoRA adapter configuration + adapter_index: Index of the adapter to save + output_path: Path to save the checkpoint tar.gz file + """ + _, lora_params, _ = nnx.split(model, model.is_lora_param, ...) 
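filter_lora decides which parameter paths participate in LoRA save/load based on the adapter config. The sketch below mirrors its logic with a stand-in config object so it runs without the package installed:

```python
# Mirrors filter_lora with a stand-in config: attention is trained,
# MLP/experts and embeddings/lm_head are not.
from dataclasses import dataclass

@dataclass
class StubLoraConfig:
    train_attn: bool = True
    train_mlp: bool = False
    train_unembed: bool = False

def filter_lora(cfg: StubLoraConfig, path: tuple[str, ...]) -> bool:
    if not cfg.train_attn and "self_attn" in path:
        return False
    if not cfg.train_mlp and ("mlp" in path or "experts" in path):
        return False
    if not cfg.train_unembed and ("embed_tokens" in path or "lm_head" in path):
        return False
    return True

cfg = StubLoraConfig()
print(filter_lora(cfg, ("model", "layers", "0", "self_attn", "q_proj", "lora_A")))  # True
print(filter_lora(cfg, ("model", "layers", "0", "mlp", "up_proj", "lora_A")))       # False
print(filter_lora(cfg, ("lm_head", "lora_B")))                                       # False
```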
+ + adapter_lora_params = extract_adapter_state(adapter_index, lora_params, adapter_config.rank) + + peft_config = peft.LoraConfig( + base_model_name_or_path=base_model_name, r=adapter_config.rank, lora_alpha=adapter_config.alpha + ) + + with pack_and_upload(output_path) as temp_dir: + save_safetensors( + model.config, + adapter_lora_params, + temp_dir / "adapter_model.safetensors", + prefix="base_model.model.", + filter_fn=lambda path: filter_lora(adapter_config, path), + ) + peft_config.save_pretrained(temp_dir) + + +class OptimizerName(str, Enum): + adamw = "adamw" + + +def get_optimizer(optimizer_name: OptimizerName, optimizer_args: dict) -> optax.GradientTransformation: + match (optimizer_name, optimizer_args): + case (OptimizerName.adamw, {"learning_rate": lr, **kwargs}): + return optax.adamw(lr, **kwargs) + case (_, {"learning_rate": _}): + raise ValueError(f"Unsupported optimizer: {optimizer_name}") + case _: + raise ValueError("The 'learning_rate' key must be provided in optimizer_args.") + + +@nnx.jit(static_argnames=("adapter_index", "rank")) +def extract_adapter_state(adapter_index: int, lora_params: nnx.GraphState, rank: int) -> nnx.GraphState: + "Helper function to extract the adapter parameters for a specific adapter index." + + def extract_state(path: tuple, p: jnp.ndarray): + if path[-2].key not in {"lora_A", "lora_B"}: + return p + assert p.ndim in {3, 4}, f"LoRA parameters must have 3 or 4 dimensions, got shape {p.shape}" + if path[-2].key == "lora_A": + return p[adapter_index, ..., :, :rank] + if path[-2].key == "lora_B": + return p[adapter_index, ..., :rank, :] + + return jax.tree.map_with_path(extract_state, lora_params) + + +# We need to use nnx.jit here instead of jax.jit so the nnx.update will be handled correctly +@nnx.jit(static_argnames=("adapter_index", "rank")) +def insert_adapter_state( + adapter_index: int, lora_params: nnx.GraphState, new_params: nnx.GraphState, rank: int +) -> None: + "Helper function to insert the adapter parameters for a specific adapter index (inverse of extract_adapter_state)." + + def insert_state(path: tuple, p: jax.Array, new: jax.Array): + if path[-2].key not in {"lora_A", "lora_B"}: + return new + assert p.ndim in {3, 4}, f"LoRA parameters must have 3 or 4 dimensions, got shape {p.shape}" + if path[-2].key == "lora_A": + return p.at[adapter_index, ..., :, :rank].set(new) + elif path[-2].key == "lora_B": + return p.at[adapter_index, ..., :rank, :].set(new) + + updated = jax.tree.map_with_path(insert_state, lora_params, new_params) + nnx.update(lora_params, updated) + + +def round_up_seq_len(seq_len: int) -> int: + """ + Rounds a sequence length up to roughly two significant binary digits. + We do this to pad sequences, so the Jax JIT compiler needs to + compile fewer different shapes. + """ + if seq_len <= 32: + return 32 + + # Find the position of the most significant bit. + msb_pos = seq_len.bit_length() - 1 + # Create a mask for the two most significant bits. + mask = (1 << msb_pos) | (1 << (msb_pos - 1)) + # Round down to the nearest value with at most two significant bits. + result = seq_len & mask + + # If we rounded down, round up to the next bucket boundary. 
+ if result < seq_len: + result += 1 << (msb_pos - 1) + + return result diff --git a/skyrl-train/skyrl_train/tx_utils/storage.py b/skyrl-train/skyrl_train/tx_utils/storage.py new file mode 100644 index 000000000..74d6e7ee7 --- /dev/null +++ b/skyrl-train/skyrl_train/tx_utils/storage.py @@ -0,0 +1,63 @@ +from contextlib import contextmanager +import io +import gzip +from pathlib import Path +import tarfile +from tempfile import TemporaryDirectory +from typing import Generator +from cloudpathlib import AnyPath + + +@contextmanager +def pack_and_upload(dest: AnyPath) -> Generator[Path, None, None]: + """Give the caller a temp directory that gets uploaded as a tar.gz archive on exit. + + Args: + dest: Destination path for the tar.gz file + """ + with TemporaryDirectory() as tmp: + tmp_path = Path(tmp) + + yield tmp_path + + dest.parent.mkdir(parents=True, exist_ok=True) + + with dest.open("wb") as f: + # Use compresslevel=0 to prioritize speed, as checkpoint files don't compress well. + with gzip.GzipFile(fileobj=f, mode="wb", compresslevel=0) as gz_stream: + with tarfile.open(fileobj=gz_stream, mode="w:") as tar: + tar.add(tmp_path, arcname="") + + +@contextmanager +def download_and_unpack(source: AnyPath) -> Generator[Path, None, None]: + """Download and extract a tar.gz archive and give the content to the caller in a temp directory. + + Args: + source: Source path for the tar.gz file + """ + with TemporaryDirectory() as tmp: + tmp_path = Path(tmp) + + # Download and extract tar archive (handles both local and cloud storage) + with source.open("rb") as f: + with tarfile.open(fileobj=f, mode="r:gz") as tar: + tar.extractall(tmp_path, filter="data") + + yield tmp_path + + +def download_file(source: AnyPath) -> io.BytesIO: + """Download a file from storage and return it as a BytesIO object. + + Args: + source: Source path for the file (local or cloud) + + Returns: + BytesIO object containing the file contents + """ + buffer = io.BytesIO() + with source.open("rb") as f: + buffer.write(f.read()) + buffer.seek(0) + return buffer diff --git a/skyrl-train/skyrl_train/workers/worker.py b/skyrl-train/skyrl_train/workers/worker.py index 8167dbb28..147a5b619 100644 --- a/skyrl-train/skyrl_train/workers/worker.py +++ b/skyrl-train/skyrl_train/workers/worker.py @@ -900,9 +900,19 @@ def all_reduce_metrics(self, status: Dict[str, float]) -> Dict[str, float]: def get_lr(self) -> float: """ - Get current learning rate from scheduler. + Get current learning rate from optimizer. """ - return self.scheduler.get_last_lr()[0] + return self.optimizer.param_groups[0]["lr"] + + def set_lr(self, learning_rate: float) -> None: + """ + Set learning rate for the optimizer. + + This directly updates the optimizer's param_groups, bypassing the scheduler. + Useful for external learning rate schedules (e.g., from Tinker). + """ + for param_group in self.optimizer.param_groups: + param_group["lr"] = learning_rate def barrier(self) -> None: """ @@ -1107,9 +1117,19 @@ def all_reduce_metrics(self, status: Dict[str, float]) -> Dict[str, float]: def get_lr(self) -> float: """ - Get current learning rate from scheduler. + Get current learning rate from optimizer. + """ + return self.optimizer.param_groups[0]["lr"] + + def set_lr(self, learning_rate: float) -> None: + """ + Set learning rate for the optimizer. + + This directly updates the optimizer's param_groups, bypassing the scheduler. + Useful for external learning rate schedules (e.g., from Tinker). 
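On the PyTorch side, set_lr rewrites lr in every param_group, so the next optimizer.step() uses the externally supplied rate regardless of any scheduler. A minimal standalone sketch of that effect:

```python
# What set_lr does to a plain PyTorch optimizer (illustrative model and values).
import torch

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

new_lr = 1e-5
for param_group in optimizer.param_groups:
    param_group["lr"] = new_lr

print(optimizer.param_groups[0]["lr"])  # 1e-05
```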
""" - return self.scheduler.get_last_lr()[0] + for param_group in self.optimizer.param_groups: + param_group["lr"] = learning_rate def barrier(self) -> None: """ diff --git a/skyrl-train/skyrl_train/workers/worker_dispatch.py b/skyrl-train/skyrl_train/workers/worker_dispatch.py index fba225bf8..8f59f4681 100644 --- a/skyrl-train/skyrl_train/workers/worker_dispatch.py +++ b/skyrl-train/skyrl_train/workers/worker_dispatch.py @@ -208,6 +208,15 @@ def optim_step(self, model: str) -> Optional[float]: self._save_memory_snapshot(model, "optim_step") return grad_norms[0] + def set_lr(self, model: str, learning_rate: float) -> None: + """Set learning rate for model's optimizer. + + This directly updates the optimizer's param_groups on all workers, + bypassing the scheduler. Useful for external learning rate schedules. + """ + self._ensure_on_gpu(model, need_optimizer=True, need_model=False) + ray.get(self._actor_groups[model].async_run_ray_method("pass_through", "set_lr", learning_rate=learning_rate)) + # TODO(tgriggs): Remove this when Megatron supports forward_backward and optim_step. def ppo_train(self, model: str, data: TrainingInputBatch) -> Dict[str, float]: """Run full PPO training loop (for Megatron).""" diff --git a/skyrl-train/tests/gpu/gpu_ci/test_training_step.py b/skyrl-train/tests/gpu/gpu_ci/test_training_step.py index 05eac3671..c5118f853 100644 --- a/skyrl-train/tests/gpu/gpu_ci/test_training_step.py +++ b/skyrl-train/tests/gpu/gpu_ci/test_training_step.py @@ -134,6 +134,43 @@ async def test_critic_forward_backward_and_optim_step(ray_init_fixture, cfg, pac ray.shutdown() +@pytest.mark.asyncio +async def test_set_lr_updates_optimizer(ray_init_fixture, cfg): + """ + Test that set_lr updates the optimizer's learning rate. + """ + cfg.trainer.use_sample_packing = False + cfg.trainer.strategy = "fsdp2" + validate_cfg(cfg) + + try: + actor_group = init_worker_with_type( + "policy", + shared_pg=None, + colocate_all=False, + num_gpus_per_node=cfg.trainer.placement.policy_num_gpus_per_node, + cfg=cfg, + ) + + # Get initial learning rate + initial_lrs = ray.get(actor_group.async_run_ray_method("pass_through", "get_lr")) + initial_lr = initial_lrs[0] + + # Set a new learning rate + new_lr = 1e-5 + assert new_lr != initial_lr, "New LR should differ from initial for valid test" + + ray.get(actor_group.async_run_ray_method("pass_through", "set_lr", learning_rate=new_lr)) + + # Verify the learning rate was updated + updated_lrs = ray.get(actor_group.async_run_ray_method("pass_through", "get_lr")) + for updated_lr in updated_lrs: + assert updated_lr == new_lr, f"Expected LR {new_lr}, got {updated_lr}" + + finally: + ray.shutdown() + + @pytest.mark.asyncio async def test_sft_forward_backward_with_cross_entropy(ray_init_fixture, cfg): """ diff --git a/skyrl-train/tests/gpu/gpu_ci/test_worker_dispatch_offload.py b/skyrl-train/tests/gpu/gpu_ci/test_worker_dispatch_offload.py index 7418d5c01..2dad21fcc 100644 --- a/skyrl-train/tests/gpu/gpu_ci/test_worker_dispatch_offload.py +++ b/skyrl-train/tests/gpu/gpu_ci/test_worker_dispatch_offload.py @@ -291,3 +291,39 @@ async def test_colocate_policy_critic_training_switch(ray_init_fixture): finally: ray.shutdown() + + +@pytest.mark.asyncio +async def test_dispatch_set_lr(ray_init_fixture): + """ + Test that WorkerDispatch.set_lr updates the optimizer's learning rate. 
+ """ + cfg = get_test_config() + + try: + # Create placement group and policy actor + pg = placement_group([{"GPU": 1, "CPU": 2}], strategy="PACK") + get_ray_pg_ready_with_timeout(pg, timeout=30) + + policy_group = init_colocated_actor_group(PolicyWorker, pg, cfg) + ray.get(policy_group.async_init_model(MODEL_NAME)) + + dispatch = WorkerDispatch(cfg, policy_actor_group=policy_group) + + # Get initial learning rate + initial_lrs = ray.get(policy_group.async_run_ray_method("pass_through", "get_lr")) + initial_lr = initial_lrs[0] + + # Set a new learning rate via dispatch + new_lr = 1e-5 + assert new_lr != initial_lr, "New LR should differ from initial for valid test" + + dispatch.set_lr("policy", new_lr) + + # Verify the learning rate was updated + updated_lrs = ray.get(policy_group.async_run_ray_method("pass_through", "get_lr")) + for updated_lr in updated_lrs: + assert updated_lr == new_lr, f"Expected LR {new_lr}, got {updated_lr}" + + finally: + ray.shutdown() diff --git a/skyrl-tx/tx/tinker/backends/skyrl_train.py b/skyrl-tx/tx/tinker/backends/skyrl_train.py index 5e0e70fd4..5751484e7 100644 --- a/skyrl-tx/tx/tinker/backends/skyrl_train.py +++ b/skyrl-tx/tx/tinker/backends/skyrl_train.py @@ -49,6 +49,11 @@ def _build_config(base_model: str, config: SkyRLTrainBackendConfig, lora_config: """Build config for SkyRL-Train workers using default config.""" cfg = get_default_config() cfg.trainer.policy.model.path = base_model + + # Disable scheduler - Tinker manages learning rate externally via set_lr() + cfg.trainer.policy.optimizer_config.scheduler = "constant" + cfg.trainer.policy.optimizer_config.num_warmup_steps = 0 + return cfg @@ -200,8 +205,14 @@ def forward( def optim_step(self, model_id: str, request_data: types.OptimStepInput) -> types.OptimStepOutput: if model_id != self._model_id: raise ValueError(f"Model {model_id} not found") + + # Apply learning rate from AdamParams before optimizer step + # Note: beta1, beta2, eps are fixed at optimizer creation and cannot be changed dynamically + adam_params = request_data.adam_params + self._dispatch.set_lr("policy", adam_params.learning_rate) + grad_norm = self._dispatch.optim_step("policy") - logger.info(f"grad_norm: {grad_norm}") + logger.info(f"optim_step: lr={adam_params.learning_rate}, grad_norm={grad_norm}") return types.OptimStepOutput() def sample(