Initialize Qwen3.5 mutable buffers during export #17801
Phineas1500 wants to merge 6 commits into pytorch:main from
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/17801
@pytorchbot label "release notes: none"
Pull request overview
This PR adds Qwen3.5 support to the Llama export pipeline with deterministic initialization of Qwen3.5’s internal mutable buffers (KV cache + DeltaNet recurrent/conv state) during export, and introduces the Qwen3.5 attention implementations/configs needed to run/export the hybrid layer layout.
Changes:
- Add Qwen3.5 model types/configs and HF weight conversion utilities for ExecuTorch “meta” format.
- Implement Qwen3.5 hybrid attention blocks (full attention + Gated DeltaNet linear attention) and wire hybrid layer construction into the Llama transformer.
- Factor export-time mutable-buffer initialization pass selection (torchtune + Qwen3.5) into a shared helper and add unit tests for pass selection and attention state reset.
Reviewed changes
Copilot reviewed 21 out of 21 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| extension/llm/export/config/llm_config.py | Adds Qwen3.5 model types to the export config enum. |
| examples/models/qwen3_5/tests/test_convert_weights.py | Unit test for Qwen3.5 HF→meta key mapping. |
| examples/models/qwen3_5/tests/__init__.py | Package marker/license header for Qwen3.5 tests. |
| examples/models/qwen3_5/convert_weights.py | Implements Qwen3.5 checkpoint loading and key conversion (incl. legacy packed tensor splitting). |
| examples/models/qwen3_5/config/qwen3_5_xnnpack_fp32.yaml | Adds an fp32/static-shape XNNPACK export config for Qwen3.5. |
| examples/models/qwen3_5/config/4b_config.json | Adds model args for Qwen3.5 4B (hybrid layer_types etc.). |
| examples/models/qwen3_5/config/2b_config.json | Adds model args for Qwen3.5 2B. |
| examples/models/qwen3_5/config/0_8b_config.json | Adds model args for Qwen3.5 0.8B. |
| examples/models/qwen3_5/__init__.py | Adds a Qwen3.5 model entrypoint (lazy subclass of Llama2Model) and exports convert_weights. |
| examples/models/qwen3_5/README.md | Documents export/run instructions for Qwen3.5 models. |
| examples/models/qwen3_5/BUCK | Adds Buck target for the Qwen3.5 Python library + deps. |
| examples/models/llama/tests/test_qwen3_5_attention.py | Adds tests for Qwen3.5 full-attn shape and DeltaNet state reset behavior. |
| examples/models/llama/tests/test_export_llama_lib.py | Adds tests covering export-pass selection for Qwen3.5/torchtune/llama3. |
| examples/models/llama/tests/BUCK | Registers the new Qwen3.5 attention unittest target. |
| examples/models/llama/norm.py | Extends RMSNorm to support Qwen3.5 “(1 + weight)” scaling. |
| examples/models/llama/model_args.py | Adds Qwen3.5 linear-attention dims + RMSNorm scaling flag to ModelArgs with defaults. |
| examples/models/llama/llama_transformer.py | Wires RMSNorm scaling flag and constructs DeltaNet layers when layer_types specify linear_attention. |
| examples/models/llama/export_llama_lib.py | Adds Qwen3.5 model ids, hooks Qwen3.5 weight conversion, and factors mutable-buffer init pass selection into helper. |
| examples/models/llama/attention.py | Adds Qwen3.5 full attention and Gated DeltaNet attention implementations (+ required buffers). |
| examples/models/llama/__init__.py | Switches llama package export to lazy import pattern for Llama2Model. |
| examples/models/BUCK | Adds the Qwen3.5 model package to the umbrella models BUCK target. |
Comments suppressed due to low confidence (1)
examples/models/llama/norm.py:60
RMSNorm currently returns `output * self.weight` when `add_unit_offset` is False. Since `output` is cast back to `type_as(x)` but `self.weight` stays fp32, this multiplication will promote the result to fp32 for fp16/bf16 inputs. The new `add_unit_offset` branch explicitly casts the weight to `type_as(x)`, so the dtype behavior is now inconsistent between the two paths. Consider casting `self.weight` to `type_as(x)` (or otherwise ensuring the output dtype matches the input) in the non-offset path as well.
```python
    return output * (1.0 + self.weight.float()).type_as(x)
return output * self.weight
```
```python
    raise ValueError(
        f"Invalid packed in_proj_qkvz shape for {key}: {tuple(value.shape)}"
    )
qkv, z = torch.split(value, [conv_dim, value_dim], dim=0)
```
`key_dim` is computed when splitting the legacy packed `in_proj_qkvz.weight` but is never used afterward. Please remove it or use it for an explicit shape validation to avoid dead code.
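A minimal sketch of one way to act on this, assuming the packed layout implies a leading dimension of `conv_dim + value_dim` (that relationship, and the helper name `split_packed_qkvz`, are assumptions, not the PR's code):

```python
import torch

def split_packed_qkvz(key: str, value: torch.Tensor,
                      conv_dim: int, key_dim: int, value_dim: int):
    """Validate the legacy packed in_proj_qkvz tensor before splitting,
    surfacing key_dim in the error instead of leaving it unused."""
    expected = conv_dim + value_dim
    if value.ndim < 1 or value.shape[0] != expected:
        raise ValueError(
            f"Invalid packed in_proj_qkvz shape for {key}: {tuple(value.shape)}; "
            f"expected leading dim {expected} "
            f"(conv_dim={conv_dim}, key_dim={key_dim}, value_dim={value_dim})"
        )
    # Split off the z (gate) projection from the packed q/k/v block.
    qkv, z = torch.split(value, [conv_dim, value_dim], dim=0)
    return qkv, z
```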
Validated export and runtime with the XNNPACK recipe. Set max_seq_len and max_context_len to 128, generated the .pte, and ran executorch.examples.models.llama.runner.native with a multi-token prompt. The model currently uses static-shape export in this path, but I added sequential token prefill.
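For context, a minimal sketch of what sequential prefill can look like for a static single-token export (the PR's actual `_sequential_kv_prefill` may differ; `sequential_kv_prefill` here is a hypothetical standalone form):

```python
import torch

def sequential_kv_prefill(model, prompt_tokens, pos_base=0, device="cpu"):
    """Feed the prompt one token at a time so a model exported with a
    static single-token shape can still consume multi-token prompts."""
    logits = None
    for i, tok in enumerate(prompt_tokens):
        logits = model.forward(
            tokens=torch.tensor([[tok]], dtype=torch.long, device=device),
            input_pos=torch.tensor([pos_base + i], dtype=torch.long, device=device),
        )
    # Only the logits for the final prompt token are needed to sample the
    # first generated token.
    return logits
```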
Pull request overview
Copilot reviewed 25 out of 25 changed files in this pull request and generated 5 comments.
```python
try:
    return self.forward(
        tokens=torch.tensor(
            [prompt_tokens], dtype=torch.long, device=self.device
        ),
        input_pos=torch.tensor([pos_base], dtype=torch.long, device=self.device),
    )
except RuntimeError:
    # Some exported models use a static single-token shape for kv-cache mode.
    # Fall back to sequential token prefill so multi-token prompts still work.
    if self.enable_dynamic_shape or len(prompt_tokens) <= 1:
        raise

    return self._sequential_kv_prefill(prompt_tokens, pos_base)
```
In `_prefill_with_kv_cache`, the early return for `not self.enable_dynamic_shape and len(prompt_tokens) > 1` means the subsequent `try/except RuntimeError` never executes for the "static single-token shape" case described below. As written, the fallback logic is effectively dead code for the static-shape scenario; consider simplifying to a single path (either always sequential when static, or always try batched then fall back).
Suggested change:

```diff
-try:
-    return self.forward(
-        tokens=torch.tensor(
-            [prompt_tokens], dtype=torch.long, device=self.device
-        ),
-        input_pos=torch.tensor([pos_base], dtype=torch.long, device=self.device),
-    )
-except RuntimeError:
-    # Some exported models use a static single-token shape for kv-cache mode.
-    # Fall back to sequential token prefill so multi-token prompts still work.
-    if self.enable_dynamic_shape or len(prompt_tokens) <= 1:
-        raise
-    return self._sequential_kv_prefill(prompt_tokens, pos_base)
+return self.forward(
+    tokens=torch.tensor(
+        [prompt_tokens], dtype=torch.long, device=self.device
+    ),
+    input_pos=torch.tensor([pos_base], dtype=torch.long, device=self.device),
+)
```
```python
if self.enable_dynamic_shape or len(prompt_tokens) <= 1:
    raise
```
The `except RuntimeError` fallback is currently unreachable when `self.enable_dynamic_shape` is True (the default): the handler re-raises whenever `self.enable_dynamic_shape` is true, so a static exported model that throws on batched prefill (and lacks the `enable_dynamic_shape` metadata method) will fail instead of falling back to sequential prefill. Consider falling back based on `len(prompt_tokens) > 1` (and/or the specific error) rather than the `enable_dynamic_shape` flag, or updating the flag when the batched call fails.
Suggested change:

```diff
-if self.enable_dynamic_shape or len(prompt_tokens) <= 1:
-    raise
+# If the batched prefill fails for a multi-token prompt, disable dynamic
+# shape support and retry using sequential prefill. For single-token
+# prompts, propagate the error.
+if len(prompt_tokens) <= 1:
+    raise
+# Avoid retrying batched dynamic-shape prefill after a failure.
+self.enable_dynamic_shape = False
```
examples/models/llama/norm.py (Outdated)
```python
output = self._norm(x.float()).type_as(x)
if self.add_unit_offset:
    return output * (1.0 + self.weight.float()).type_as(x)
return output * self.weight
```
`RMSNorm.forward` returns `output * self.weight` when `add_unit_offset` is false, which will promote dtypes (e.g., fp16 input → fp32 output) because `self.weight` is fp32. In the new `add_unit_offset` branch you explicitly cast the scale to `type_as(x)`, so the output dtype now depends on the flag. Consider casting `self.weight` (or the final product) to `type_as(x)` in both branches to keep the output dtype consistent with the input.
Suggested change:

```diff
-return output * self.weight
+return output * self.weight.type_as(x)
```
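For reference, a dtype-consistent version of the whole forward could look like the sketch below. It mirrors the reviewed code but is not the PR's exact implementation, and the ones-initialization of `weight` is an assumption:

```python
import torch
import torch.nn as nn

class RMSNormSketch(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6, add_unit_offset: bool = False):
        super().__init__()
        self.eps = eps
        self.add_unit_offset = add_unit_offset
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        # Root-mean-square normalization, computed in fp32 for stability.
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        output = self._norm(x.float()).type_as(x)
        if self.add_unit_offset:
            # Qwen3.5-style "(1 + weight)" scaling, cast back to the input dtype.
            return output * (1.0 + self.weight.float()).type_as(x)
        # Cast the scale so fp16/bf16 inputs are not promoted to fp32.
        return output * self.weight.type_as(x)
```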
```python
try:
    self.enable_dynamic_shape = bool(
        self.model.run_method("enable_dynamic_shape")[0]
    )
except Exception:
    # Keep default behavior when metadata method is unavailable.
    pass
```
Catching a bare `Exception` around `run_method("enable_dynamic_shape")` can also hide real runtime issues (e.g., model load/ABI problems) and silently keep `enable_dynamic_shape=True`. It would be safer to catch the specific "method missing"/runtime exceptions raised by `run_method` (and optionally log at debug level) so unexpected failures don't get swallowed.
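One possible shape for that, sketched as a standalone helper. It assumes the missing-method failure surfaces as a `RuntimeError` from the runtime binding, and the helper name is hypothetical:

```python
import logging

logger = logging.getLogger(__name__)

def read_enable_dynamic_shape(model, default: bool = True) -> bool:
    """Read the optional enable_dynamic_shape metadata method, keeping the
    default only for the expected "method unavailable" failure mode."""
    try:
        return bool(model.run_method("enable_dynamic_shape")[0])
    except RuntimeError as e:
        logger.debug("enable_dynamic_shape metadata unavailable: %s", e)
        return default
```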
```python
try:
    new_key = get_mapped_key(normalized_key, _QWEN_3_5_TO_META)
except Exception:
    # Ignore non-text weights and training-only extras (e.g., MTP).
    if (
        key.startswith("mtp.")
        or key.startswith("model.visual.")
        or ".vision_" in key
        or key.startswith("visual.")
    ):
        continue
    # Ignore unsupported keys that are not required by the export model.
    continue
```
The `except Exception: ... continue` around `get_mapped_key` will silently drop any unexpected keys (including genuinely required text weights if the mapping is incomplete or the checkpoint format changes). This makes conversion failures hard to detect. Consider only ignoring a well-defined allowlist of optional prefixes (vision/MTP/etc.) and re-raising for other `model.*` keys, or at least logging the first few unmapped keys at warning level.
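A sketch of the allowlist approach (the optional prefixes are taken from the existing branch; `map_checkpoint_key` is a hypothetical helper, and raising `RuntimeError` is a choice for illustration, not the PR's behavior):

```python
# Known-optional weights that are not part of the exported text model.
_OPTIONAL_KEY_PREFIXES = ("mtp.", "model.visual.", "visual.")

def map_checkpoint_key(key, normalized_key):
    """Return the mapped meta key, None for known-optional weights, and raise
    for anything else that fails to map so conversion gaps stay visible."""
    try:
        return get_mapped_key(normalized_key, _QWEN_3_5_TO_META)
    except Exception as e:
        if key.startswith(_OPTIONAL_KEY_PREFIXES) or ".vision_" in key:
            return None  # caller skips vision / multi-token-prediction extras
        raise RuntimeError(f"Unexpected unmapped checkpoint key: {key}") from e
```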
Summary
Why
Qwen3.5 uses internal mutable state (KV + DeltaNet recurrent/conv buffers). Initializing these buffers at export time avoids uninitialized mutable-buffer state and makes startup behavior deterministic.
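The PR implements this as an export pass selected in export_llama_lib; as a rough illustration of the intent only (the buffer-name prefixes below are assumptions, and this is not the actual pass):

```python
import torch

def zero_init_mutable_buffers(
    module: torch.nn.Module,
    prefixes=("kv_cache", "recurrent_state", "conv_state"),
) -> None:
    """Zero-fill mutable state buffers so the exported program starts from a
    deterministic state rather than whatever values the buffers held."""
    with torch.no_grad():
        for name, buf in module.named_buffers():
            if any(p in name for p in prefixes):
                buf.zero_()
```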
Test Plan
Stacking