diff --git a/skyrl-train/skyrl_train/inference_engines/ray_wrapped_inference_engine.py b/skyrl-train/skyrl_train/inference_engines/ray_wrapped_inference_engine.py
index c19531cc9..d00de619f 100644
--- a/skyrl-train/skyrl_train/inference_engines/ray_wrapped_inference_engine.py
+++ b/skyrl-train/skyrl_train/inference_engines/ray_wrapped_inference_engine.py
@@ -179,9 +179,19 @@ def create_ray_wrapped_inference_engines(
     }
 
     rope_engine_kwargs = {}
-    if rope_scaling:
-        rope_engine_kwargs["rope_scaling"] = rope_scaling
-    if "max_model_len" not in engine_init_kwargs:
+    # rope_scaling and rope_theta must be passed via hf_overrides.rope_parameters
+    # for vLLM >= 0.8.3. See: https://docs.vllm.ai/en/latest/examples/offline_inference/context_extension/
+    # Use .get() since OmegaConf DictConfig in struct mode doesn't support .pop()
+    hf_overrides = dict(engine_init_kwargs.get("hf_overrides", {}) or {})
+    if rope_scaling or rope_theta is not None:
+        # Convert to regular dict to avoid OmegaConf struct mode issues in vLLM
+        # vLLM expects rope_parameters, not rope_scaling
+        rope_parameters = dict(rope_scaling) if rope_scaling else {}
+        if rope_theta is not None:
+            rope_parameters["rope_theta"] = rope_theta
+        hf_overrides["rope_parameters"] = rope_parameters
+
+    if rope_scaling and "max_model_len" not in engine_init_kwargs:
         rope_factor = rope_scaling.get("factor", None)
         rope_max_pos = rope_scaling.get("original_max_position_embeddings", None)
         assert rope_factor is not None, "Please provide rope scaling `factor` to compute model max length"
@@ -189,8 +199,8 @@ def create_ray_wrapped_inference_engines(
             rope_max_pos is not None
         ), "Please provide rope `original_max_position_embeddings` to compute model max length"
         rope_engine_kwargs["max_model_len"] = int(rope_factor * rope_max_pos)
-    if rope_theta is not None:
-        rope_engine_kwargs["rope_theta"] = rope_theta
+    if hf_overrides:
+        rope_engine_kwargs["hf_overrides"] = hf_overrides
 
     other_kwargs = {}
diff --git a/skyrl-train/skyrl_train/model_wrapper.py b/skyrl-train/skyrl_train/model_wrapper.py
index 847dc6102..20ad451ea 100644
--- a/skyrl-train/skyrl_train/model_wrapper.py
+++ b/skyrl-train/skyrl_train/model_wrapper.py
@@ -97,11 +97,12 @@ def __init__(
         model_config = AutoConfig.from_pretrained(pretrain_or_model, trust_remote_code=True, **model_config_kwargs)
 
-        rope_scaling_kwargs = {}
+        # Set rope_scaling on config (not as kwarg to from_pretrained)
+        # HuggingFace models don't accept rope_scaling as a constructor kwarg
         if rope_scaling:
-            rope_scaling_kwargs["rope_scaling"] = rope_scaling
+            model_config.rope_scaling = dict(rope_scaling) if hasattr(rope_scaling, "keys") else rope_scaling
         if rope_theta:
-            rope_scaling_kwargs["rope_theta"] = rope_theta
+            model_config.rope_theta = rope_theta
 
         self.model = model_class.from_pretrained(
             pretrain_or_model,
@@ -111,7 +112,6 @@ def __init__(
             quantization_config=nf4_config,
             torch_dtype=torch.bfloat16 if bf16 else torch.float32,
             device_map=device_map,
-            **rope_scaling_kwargs,
         )
 
         # gpt oss
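
For reference, a minimal standalone sketch (not part of the patch) of what the rewritten engine setup ends up handing to vLLM, following the patch's comments and the linked context-extension example. The model name, rope factor, and rope_theta values below are hypothetical, and the "rope_parameters" key is assumed to match the vLLM version noted in the comments (>= 0.8.3).

# Sketch of the hf_overrides path introduced by this patch.
# All concrete values (model, factor, theta) are illustrative only.
from vllm import LLM

rope_scaling = {"rope_type": "yarn", "factor": 4.0, "original_max_position_embeddings": 32768}
rope_theta = 1000000.0

rope_parameters = dict(rope_scaling)  # plain dict, as the patch does for OmegaConf inputs
rope_parameters["rope_theta"] = rope_theta

llm = LLM(
    model="Qwen/Qwen2.5-7B-Instruct",  # hypothetical model
    hf_overrides={"rope_parameters": rope_parameters},
    # same derivation as the patch: factor * original_max_position_embeddings
    max_model_len=int(rope_scaling["factor"] * rope_scaling["original_max_position_embeddings"]),
)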