From f871e0164dc6a90271cfd78e7b9e1eb42d8e3b59 Mon Sep 17 00:00:00 2001
From: Deniz
Date: Tue, 27 Jan 2026 11:14:17 -0800
Subject: [PATCH] fix: rope_scaling for vLLM and HuggingFace context extension
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## vLLM Fix (ray_wrapped_inference_engine.py)

- Use `hf_overrides.rope_parameters` instead of a direct `rope_scaling` kwarg
- vLLM >= 0.8.3 requires the rope config to be passed via `hf_overrides`
- Convert the OmegaConf DictConfig to a regular dict to avoid struct-mode errors
- Reference: https://docs.vllm.ai/en/latest/examples/offline_inference/context_extension/

## HuggingFace Fix (model_wrapper.py)

- Set `rope_scaling` on the model config object instead of passing it as a kwarg
- With an explicit config, `from_pretrained()` forwards unused kwargs such as
  `rope_scaling` to the model constructor, which rejects them
- Fixes: TypeError: Qwen3ForCausalLM.__init__() got an unexpected keyword argument 'rope_scaling'

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5
---
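Notes (below the fold, dropped by git am):

Why the `.get()` plus `dict()` dance in the vLLM fix: a struct-mode DictConfig
rejects `.pop()`, so the code reads the value and detaches a plain copy
instead. A minimal sketch with illustrative values, not taken from this repo:

```python
from omegaconf import OmegaConf

cfg = OmegaConf.create({"rope_scaling": {"rope_type": "yarn", "factor": 4.0}})
OmegaConf.set_struct(cfg, True)

# cfg.pop("rope_scaling") raises in struct mode; this detaches a plain dict:
rope_parameters = dict(cfg.get("rope_scaling") or {})
```

Roughly the vLLM call shape this change produces, assuming a hypothetical 4x
YaRN extension of a 32K-context model (model name and numbers are
illustrative):

```python
from vllm import LLM

llm = LLM(
    model="Qwen/Qwen3-8B",
    hf_overrides={
        "rope_parameters": {
            "rope_type": "yarn",
            "factor": 4.0,
            "original_max_position_embeddings": 32768,
        }
    },
    # Mirrors the patch's max_model_len computation: int(4.0 * 32768) = 131072.
    max_model_len=int(4.0 * 32768),
)
```

And the config-mutation pattern the HuggingFace fix uses, again a sketch with
illustrative values:

```python
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("Qwen/Qwen3-8B", trust_remote_code=True)
# Mutate the config rather than passing rope_scaling/rope_theta to
# from_pretrained(); with an explicit config, the unused kwargs would reach
# the model constructor and raise the TypeError quoted above.
config.rope_scaling = {"rope_type": "yarn", "factor": 4.0, "original_max_position_embeddings": 32768}
config.rope_theta = 1000000.0

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen3-8B",
    config=config,
    trust_remote_code=True,
)
```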
 .../ray_wrapped_inference_engine.py      | 20 +++++++++++++++-----
 skyrl-train/skyrl_train/model_wrapper.py |  8 ++++----
 2 files changed, 19 insertions(+), 9 deletions(-)

diff --git a/skyrl-train/skyrl_train/inference_engines/ray_wrapped_inference_engine.py b/skyrl-train/skyrl_train/inference_engines/ray_wrapped_inference_engine.py
index c19531cc9..d00de619f 100644
--- a/skyrl-train/skyrl_train/inference_engines/ray_wrapped_inference_engine.py
+++ b/skyrl-train/skyrl_train/inference_engines/ray_wrapped_inference_engine.py
@@ -179,9 +179,19 @@ def create_ray_wrapped_inference_engines(
     }
     rope_engine_kwargs = {}
-    if rope_scaling:
-        rope_engine_kwargs["rope_scaling"] = rope_scaling
-    if "max_model_len" not in engine_init_kwargs:
+    # rope_scaling and rope_theta must be passed via hf_overrides.rope_parameters
+    # for vLLM >= 0.8.3. See: https://docs.vllm.ai/en/latest/examples/offline_inference/context_extension/
+    # Use .get() since OmegaConf DictConfig in struct mode doesn't support .pop()
+    hf_overrides = dict(engine_init_kwargs.get("hf_overrides", {}) or {})
+    if rope_scaling or rope_theta is not None:
+        # Convert to regular dict to avoid OmegaConf struct mode issues in vLLM
+        # vLLM expects rope_parameters, not rope_scaling
+        rope_parameters = dict(rope_scaling) if rope_scaling else {}
+        if rope_theta is not None:
+            rope_parameters["rope_theta"] = rope_theta
+        hf_overrides["rope_parameters"] = rope_parameters
+
+    if rope_scaling and "max_model_len" not in engine_init_kwargs:
         rope_factor = rope_scaling.get("factor", None)
         rope_max_pos = rope_scaling.get("original_max_position_embeddings", None)
         assert rope_factor is not None, "Please provide rope scaling `factor` to compute model max length"
         assert (
@@ -189,8 +199,8 @@ def create_ray_wrapped_inference_engines(
             rope_max_pos is not None
         ), "Please provide rope `original_max_position_embeddings` to compute model max length"
         rope_engine_kwargs["max_model_len"] = int(rope_factor * rope_max_pos)
-    if rope_theta is not None:
-        rope_engine_kwargs["rope_theta"] = rope_theta
+    if hf_overrides:
+        rope_engine_kwargs["hf_overrides"] = hf_overrides
 
     other_kwargs = {}
 

diff --git a/skyrl-train/skyrl_train/model_wrapper.py b/skyrl-train/skyrl_train/model_wrapper.py
index 847dc6102..20ad451ea 100644
--- a/skyrl-train/skyrl_train/model_wrapper.py
+++ b/skyrl-train/skyrl_train/model_wrapper.py
@@ -97,11 +97,12 @@ def __init__(
 
         model_config = AutoConfig.from_pretrained(pretrain_or_model, trust_remote_code=True, **model_config_kwargs)
 
-        rope_scaling_kwargs = {}
+        # Set rope_scaling on config (not as kwarg to from_pretrained)
+        # HuggingFace models don't accept rope_scaling as a constructor kwarg
         if rope_scaling:
-            rope_scaling_kwargs["rope_scaling"] = rope_scaling
+            model_config.rope_scaling = dict(rope_scaling) if hasattr(rope_scaling, "keys") else rope_scaling
         if rope_theta:
-            rope_scaling_kwargs["rope_theta"] = rope_theta
+            model_config.rope_theta = rope_theta
 
         self.model = model_class.from_pretrained(
             pretrain_or_model,
@@ -111,7 +112,6 @@ def __init__(
             quantization_config=nf4_config,
             torch_dtype=torch.bfloat16 if bf16 else torch.float32,
             device_map=device_map,
-            **rope_scaling_kwargs,
         )
 
         # gpt oss