Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions skyrl-tx/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ tpu = [
"jax[tpu]>=0.7.2",
]

ray = [
"ray[default]>=2.53.0",
]

tinker = [
"tinker>=0.3.0",
"fastapi[standard]",
Expand Down
81 changes: 79 additions & 2 deletions skyrl-tx/tests/tinker/test_engine.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from unittest.mock import MagicMock

from cloudpathlib import AnyPath
from datetime import datetime, timedelta, timezone

Expand All @@ -24,7 +26,9 @@ def test_process_unload_model():

model_id = "test_model"
_ = engine.process_single_request(
types.RequestType.CREATE_MODEL, model_id, {"lora_config": {"rank": 8, "alpha": 16, "seed": 0}}
types.RequestType.CREATE_MODEL,
model_id,
{"lora_config": {"rank": 8, "alpha": 16, "seed": 0}},
)
assert engine.backend.has_model(model_id)

Expand All @@ -50,7 +54,9 @@ def test_cleanup_stale_sessions():

# Create model in backend
_ = engine.process_single_request(
types.RequestType.CREATE_MODEL, model_id, {"lora_config": {"rank": 8, "alpha": 16, "seed": 0}}
types.RequestType.CREATE_MODEL,
model_id,
{"lora_config": {"rank": 8, "alpha": 16, "seed": 0}},
)
assert engine.backend.has_model(model_id)

Expand Down Expand Up @@ -80,3 +86,74 @@ def test_cleanup_stale_sessions():
# Run cleanup and assert one model was unloaded
assert engine.cleanup_stale_sessions() == 1
assert not engine.backend.has_model(model_id)


def test_shutdown_without_ray():
    """Check that shutdown() is harmless when no Ray process manager is set.

    The engine is built from a config that passes only the base model,
    checkpoint path, backend settings and an in-memory SQLite URL. The test
    expects ``_ray_process_manager`` to be ``None`` before and after
    ``shutdown()``, and expects the call itself not to raise.
    """
    engine_config = EngineConfig(
        base_model=BASE_MODEL,
        checkpoints_base=AnyPath(""),
        backend_config={"max_lora_adapters": 4, "max_lora_rank": 32},
        database_url="sqlite:///:memory:",
    )
    engine = TinkerEngine(engine_config)
    SQLModel.metadata.create_all(engine.db_engine)

    # Precondition: no Ray process manager is attached to this engine.
    assert engine._ray_process_manager is None

    # The call must complete without raising.
    engine.shutdown()

    # Postcondition: shutdown() left the attribute untouched.
    assert engine._ray_process_manager is None


def test_shutdown_with_ray_process_manager():
    """Check that shutdown() delegates to the attached Ray process manager.

    A ``MagicMock`` stands in for the manager. After ``engine.shutdown()``
    the mock's ``shutdown`` must have been called exactly once and the
    engine's ``_ray_process_manager`` attribute must be reset to ``None``.
    """
    engine_config = EngineConfig(
        base_model=BASE_MODEL,
        checkpoints_base=AnyPath(""),
        backend_config={"max_lora_adapters": 4, "max_lora_rank": 32},
        database_url="sqlite:///:memory:",
    )
    engine = TinkerEngine(engine_config)
    SQLModel.metadata.create_all(engine.db_engine)

    # Inject a stand-in manager so its shutdown() call can be observed.
    manager_mock = MagicMock()
    engine._ray_process_manager = manager_mock

    engine.shutdown()

    # The engine must forward the shutdown to the manager exactly once...
    manager_mock.shutdown.assert_called_once()

    # ...and drop its reference to the manager afterwards.
    assert engine._ray_process_manager is None


def test_shutdown_idempotent():
    """Check that repeated shutdown() calls are safe.

    A ``MagicMock`` stands in for the Ray process manager. ``shutdown()`` is
    invoked three times on the engine; the mock's ``shutdown`` must still
    have been called exactly once.
    """
    engine_config = EngineConfig(
        base_model=BASE_MODEL,
        checkpoints_base=AnyPath(""),
        backend_config={"max_lora_adapters": 4, "max_lora_rank": 32},
        database_url="sqlite:///:memory:",
    )
    engine = TinkerEngine(engine_config)
    SQLModel.metadata.create_all(engine.db_engine)

    # Inject a stand-in manager so its shutdown() calls can be counted.
    manager_mock = MagicMock()
    engine._ray_process_manager = manager_mock

    # Shut the engine down three times in a row.
    for _ in range(3):
        engine.shutdown()

    # Only the first call should reach the manager; the later ones are
    # expected to be no-ops because _ray_process_manager is None by then.
    manager_mock.shutdown.assert_called_once()
Loading