@@ -179,18 +179,28 @@ def create_ray_wrapped_inference_engines(
     }

     rope_engine_kwargs = {}
-    if rope_scaling:
-        rope_engine_kwargs["rope_scaling"] = rope_scaling
-        if "max_model_len" not in engine_init_kwargs:
+    # rope_scaling and rope_theta must be passed via hf_overrides.rope_parameters
+    # for vLLM >= 0.8.3. See: https://docs.vllm.ai/en/latest/examples/offline_inference/context_extension/
+    # Use .get() since OmegaConf DictConfig in struct mode doesn't support .pop()
+    hf_overrides = dict(engine_init_kwargs.get("hf_overrides", {}) or {})
+    if rope_scaling or rope_theta is not None:
+        # Convert to regular dict to avoid OmegaConf struct mode issues in vLLM
+        # vLLM expects rope_parameters, not rope_scaling
+        rope_parameters = dict(rope_scaling) if rope_scaling else {}
+        if rope_theta is not None:
+            rope_parameters["rope_theta"] = rope_theta
+        hf_overrides["rope_parameters"] = rope_parameters
+
+    if rope_scaling and "max_model_len" not in engine_init_kwargs:
         rope_factor = rope_scaling.get("factor", None)
         rope_max_pos = rope_scaling.get("original_max_position_embeddings", None)
         assert rope_factor is not None, "Please provide rope scaling `factor` to compute model max length"
         assert (
             rope_max_pos is not None
         ), "Please provide rope `original_max_position_embeddings` to compute model max length"
         rope_engine_kwargs["max_model_len"] = int(rope_factor * rope_max_pos)
-    if rope_theta is not None:
-        rope_engine_kwargs["rope_theta"] = rope_theta
+    if hf_overrides:
+        rope_engine_kwargs["hf_overrides"] = hf_overrides

     other_kwargs = {}
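For reference, the kwargs assembled above are ultimately consumed by vLLM's engine constructor. Below is a minimal sketch of the equivalent direct call, assuming vLLM >= 0.8.3 as the comment in the hunk requires; the model name and YaRN values are illustrative assumptions, not taken from this PR.

# Sketch only: model name and RoPE values are made up for illustration.
from vllm import LLM

rope_parameters = {
    "rope_type": "yarn",  # assumed scaling type; any vLLM-supported type works
    "factor": 4.0,
    "original_max_position_embeddings": 32768,
    "rope_theta": 1000000.0,  # merged into rope_parameters, as the diff does
}

llm = LLM(
    model="Qwen/Qwen2.5-7B-Instruct",
    hf_overrides={"rope_parameters": rope_parameters},
    # Same formula as the diff: int(factor * original_max_position_embeddings)
    max_model_len=int(4.0 * 32768),
)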
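The .get() and dict() conversions in the hunk above exist because OmegaConf configs in struct mode are deliberately restrictive. A minimal sketch of that behavior, assuming omegaconf is installed (the config keys are illustrative):

# Sketch of the OmegaConf struct-mode behavior referenced by the comments above.
from omegaconf import OmegaConf

cfg = OmegaConf.create({"hf_overrides": {}, "rope_scaling": {"factor": 4.0}})
OmegaConf.set_struct(cfg, True)  # struct mode, typical for training configs

overrides = cfg.get("hf_overrides", {})  # reading with .get() works
plain = dict(cfg.rope_scaling)           # plain-dict copy that vLLM can mutate

try:
    cfg.pop("rope_scaling")  # struct mode rejects pop()
except Exception as exc:     # OmegaConf raises a config error here
    print(f"pop rejected: {exc}")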
skyrl-train/skyrl_train/model_wrapper.py (8 changes: 4 additions & 4 deletions)
@@ -97,11 +97,12 @@ def __init__(

         model_config = AutoConfig.from_pretrained(pretrain_or_model, trust_remote_code=True, **model_config_kwargs)

-        rope_scaling_kwargs = {}
+        # Set rope_scaling on config (not as kwarg to from_pretrained)
+        # HuggingFace models don't accept rope_scaling as a constructor kwarg
         if rope_scaling:
-            rope_scaling_kwargs["rope_scaling"] = rope_scaling
+            model_config.rope_scaling = dict(rope_scaling) if hasattr(rope_scaling, "keys") else rope_scaling
         if rope_theta:
-            rope_scaling_kwargs["rope_theta"] = rope_theta
+            model_config.rope_theta = rope_theta

         self.model = model_class.from_pretrained(
             pretrain_or_model,
@@ -111,7 +112,6 @@ def __init__(
             quantization_config=nf4_config,
             torch_dtype=torch.bfloat16 if bf16 else torch.float32,
             device_map=device_map,
-            **rope_scaling_kwargs,
         )

         # gpt oss
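On the training side, the same settings now go onto the HuggingFace config object rather than into from_pretrained kwargs. A standalone sketch of the pattern, with an assumed model name and illustrative RoPE values (the config= argument is how this sketch applies the overrides, and presumably how model_wrapper.py wires its model_config):

# Sketch only: model name and RoPE values are illustrative assumptions.
import torch
from transformers import AutoConfig, AutoModelForCausalLM

model_name = "Qwen/Qwen2.5-7B-Instruct"
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)

# RoPE overrides go on the config object; per the change above, HF models
# do not accept rope_scaling as a constructor kwarg.
config.rope_scaling = {
    "rope_type": "yarn",
    "factor": 4.0,
    "original_max_position_embeddings": 32768,
}
config.rope_theta = 1000000.0

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    config=config,
    torch_dtype=torch.bfloat16,
)

A side benefit of this approach is that the overrides live on model.config, so a later save_pretrained() call writes the extended-context settings into the saved config.json.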