From f4442de48bece613fe7365b1a9ca7c2d897887d4 Mon Sep 17 00:00:00 2001 From: Phineas1500 <41450967+Phineas1500@users.noreply.github.com> Date: Mon, 2 Mar 2026 18:37:04 -0500 Subject: [PATCH 01/11] Add Qwen3.5 export support for 0.8B/2B/4B --- examples/models/BUCK | 1 + examples/models/llama/__init__.py | 17 +- examples/models/llama/attention.py | 385 +++++++++++++++++- examples/models/llama/export_llama_lib.py | 8 + examples/models/llama/llama_transformer.py | 33 +- examples/models/llama/model_args.py | 17 + examples/models/llama/norm.py | 7 +- examples/models/llama/tests/BUCK | 11 + .../llama/tests/test_qwen3_5_attention.py | 83 ++++ examples/models/qwen3_5/BUCK | 23 ++ examples/models/qwen3_5/README.md | 52 +++ examples/models/qwen3_5/__init__.py | 24 ++ .../models/qwen3_5/config/0_8b_config.json | 50 +++ examples/models/qwen3_5/config/2b_config.json | 50 +++ examples/models/qwen3_5/config/4b_config.json | 58 +++ .../qwen3_5/config/qwen3_5_xnnpack_fp32.yaml | 17 + examples/models/qwen3_5/convert_weights.py | 201 +++++++++ examples/models/qwen3_5/tests/__init__.py | 2 + .../qwen3_5/tests/test_convert_weights.py | 44 ++ extension/llm/export/config/llm_config.py | 3 + 20 files changed, 1076 insertions(+), 10 deletions(-) create mode 100644 examples/models/llama/tests/test_qwen3_5_attention.py create mode 100644 examples/models/qwen3_5/BUCK create mode 100644 examples/models/qwen3_5/README.md create mode 100644 examples/models/qwen3_5/__init__.py create mode 100644 examples/models/qwen3_5/config/0_8b_config.json create mode 100644 examples/models/qwen3_5/config/2b_config.json create mode 100644 examples/models/qwen3_5/config/4b_config.json create mode 100644 examples/models/qwen3_5/config/qwen3_5_xnnpack_fp32.yaml create mode 100644 examples/models/qwen3_5/convert_weights.py create mode 100644 examples/models/qwen3_5/tests/__init__.py create mode 100644 examples/models/qwen3_5/tests/test_convert_weights.py diff --git a/examples/models/BUCK b/examples/models/BUCK index 94cc88360bf..a2b6789a95e 100644 --- a/examples/models/BUCK +++ b/examples/models/BUCK @@ -29,6 +29,7 @@ fbcode_target(_kind = python_library, "//executorch/examples/models/gemma3:gemma3", # @manual "//executorch/examples/models/qwen2_5:qwen2_5", # @manual "//executorch/examples/models/qwen3:qwen3", # @manual + "//executorch/examples/models/qwen3_5:qwen3_5", # @manual "//executorch/examples/models/phi_4_mini:phi_4_mini", # @manual "//executorch/examples/models/smollm2:smollm2", # @manual "//executorch/examples/models/smollm3:smollm3", # @manual diff --git a/examples/models/llama/__init__.py b/examples/models/llama/__init__.py index db6124ecc71..51f2a5c916f 100644 --- a/examples/models/llama/__init__.py +++ b/examples/models/llama/__init__.py @@ -4,8 +4,17 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from .model import Llama2Model +from typing import TYPE_CHECKING -__all__ = [ - Llama2Model, -] +if TYPE_CHECKING: + from .model import Llama2Model + +__all__ = ["Llama2Model"] + + +def __getattr__(name): + if name == "Llama2Model": + from .model import Llama2Model + + return Llama2Model + raise AttributeError(f"module '{__name__}' has no attribute '{name}'") diff --git a/examples/models/llama/attention.py b/examples/models/llama/attention.py index 922bbbb37fa..da5f1b46b29 100644 --- a/examples/models/llama/attention.py +++ b/examples/models/llama/attention.py @@ -352,8 +352,16 @@ def __init__( if self.use_qk_norm: q_norm_dim = self.head_dim k_norm_dim = self.head_dim - self.q_norm_fn = RMSNorm(q_norm_dim, eps=args.norm_eps) - self.k_norm_fn = RMSNorm(k_norm_dim, eps=args.norm_eps) + self.q_norm_fn = RMSNorm( + q_norm_dim, + eps=args.norm_eps, + add_unit_offset=args.rms_norm_add_unit_offset, + ) + self.k_norm_fn = RMSNorm( + k_norm_dim, + eps=args.norm_eps, + add_unit_offset=args.rms_norm_add_unit_offset, + ) self.wq = ( LoRALinear( @@ -511,6 +519,379 @@ def forward( return output, None +def _l2norm(x: torch.Tensor, dim: int = -1, eps: float = 1e-6) -> torch.Tensor: + inv_norm = torch.rsqrt((x * x).sum(dim=dim, keepdim=True) + eps) + return x * inv_norm + + +class RMSNormGated(nn.Module): + def __init__(self, hidden_size: int, eps: float = 1e-6): + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states: torch.Tensor, gate: torch.Tensor) -> torch.Tensor: + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + hidden_states = self.weight * hidden_states.to(input_dtype) + hidden_states = hidden_states * F.silu(gate.to(torch.float32)) + return hidden_states.to(input_dtype) + + +@register_attention("qwen3_5_full") +class AttentionQwen3_5Full(Attention): + """Qwen3.5 full-attention block with q-gating.""" + + def __init__( + self, + args: ModelArgs, + layer_id: int, + rope: Rope, + **_kwargs: Any, + ): + super().__init__() + self.use_kv_cache = args.use_kv_cache + self.n_heads = args.n_heads + self.n_kv_heads = self.n_heads if args.n_kv_heads is None else args.n_kv_heads + assert self.n_heads % self.n_kv_heads == 0 + self.n_local_heads = self.n_heads + self.n_local_kv_heads = self.n_kv_heads + self.n_rep = self.n_local_heads // self.n_local_kv_heads + self.head_dim = args.head_dim + self.max_context_len = args.max_context_len + self.dim = args.dim + self.attention_qkv_bias = args.attention_qkv_bias + self.use_qk_norm = args.use_qk_norm + self.qk_norm_before_rope = args.qk_norm_before_rope + self.enable_dynamic_shape = args.enable_dynamic_shape + + if self.use_qk_norm: + self.q_norm_fn = RMSNorm( + self.head_dim, + eps=args.norm_eps, + add_unit_offset=args.rms_norm_add_unit_offset, + ) + self.k_norm_fn = RMSNorm( + self.head_dim, + eps=args.norm_eps, + add_unit_offset=args.rms_norm_add_unit_offset, + ) + + self.wq = nn.Linear( + self.dim, self.n_heads * self.head_dim * 2, bias=self.attention_qkv_bias + ) + self.wk = nn.Linear( + self.dim, self.n_kv_heads * self.head_dim, bias=self.attention_qkv_bias + ) + self.wv = nn.Linear( + self.dim, self.n_kv_heads * self.head_dim, bias=self.attention_qkv_bias + ) + self.wo = nn.Linear( + self.n_heads * self.head_dim, + self.dim, + bias=self.attention_qkv_bias, + ) + + self.layer_id = layer_id + self.rope = rope + + causal_mask = torch.tril( + torch.ones( + self.max_context_len, + self.max_context_len, + dtype=torch.bool, + device="cpu", + ) + ) + self.register_buffer("mask", causal_mask, persistent=False) + + if self.use_kv_cache: + self.kv_cache = KVCache( + args.max_batch_size, + args.max_context_len, + self.n_kv_heads, + self.head_dim, + args.enable_dynamic_shape, + ) + self.SDPA = SDPA( + dim=self.n_local_heads * self.head_dim, + head_dim=self.head_dim, + n_rep=self.n_rep, + max_context_len=self.max_context_len, + ) + + def forward( + self, + x: torch.Tensor, + freqs_cos: torch.Tensor, + freqs_sin: torch.Tensor, + **kwargs: ForwardOptions, + ) -> Tuple[torch.Tensor, Optional[Any]]: + input_pos = kwargs.get("input_pos") + bsz, seqlen, _ = x.shape + + # Q and gate are packed in q_proj output. + q_and_gate = self.wq(x).view(bsz, seqlen, self.n_local_heads, self.head_dim * 2) + q, gate = torch.chunk(q_and_gate, 2, dim=-1) + gate = gate.reshape(bsz, seqlen, -1) + + k = self.wk(x).view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) + v = self.wv(x).view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) + + if self.use_qk_norm and self.qk_norm_before_rope: + q = self.q_norm_fn(q) + k = self.k_norm_fn(k) + + q, k = self.rope.forward(q, k, freqs_cos, freqs_sin) + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + if self.use_qk_norm and not self.qk_norm_before_rope: + q = self.q_norm_fn(q) + k = self.k_norm_fn(k) + + if self.use_kv_cache: + assert input_pos is not None + if self.enable_dynamic_shape: + start_pos = input_pos[-1].item() + torch._check_is_size(start_pos) + torch._check(start_pos < self.max_context_len) + seq_length = q.size(2) + attn_mask = self.mask.narrow(0, start_pos, seq_length) + else: + attn_mask = self.mask[input_pos] + k, v = self.kv_cache.update(input_pos, k, v) + if getattr(self.kv_cache, "is_ring_buffer", False): + attn_mask = self.kv_cache.create_causal_mask_for_ring_buffer( + input_pos[0].item(), seqlen + ) + output = self.SDPA(input_pos, q, k, v, bsz, seqlen, attn_mask) + output = output * torch.sigmoid(gate) + return self.wo(output), None + + k = k.repeat_interleave(self.n_rep, dim=1) + v = v.repeat_interleave(self.n_rep, dim=1) + mask = self.mask[:seqlen, :seqlen] + output = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) + output = output.transpose(1, 2).reshape(bsz, seqlen, -1) + output = output * torch.sigmoid(gate) + return self.wo(output), None + + +@register_attention("gated_deltanet") +class AttentionGatedDeltaNet(Attention): + """Qwen3.5 linear-attention (Gated DeltaNet) block with internal state.""" + + def __init__( + self, + args: ModelArgs, + layer_id: int, + rope: Rope, + **_kwargs: Any, + ): + super().__init__() + del rope # DeltaNet layers do not use RoPE. + + self.hidden_size = args.dim + self.max_batch_size = args.max_batch_size + self.layer_id = layer_id + + assert args.linear_num_key_heads is not None + assert args.linear_num_value_heads is not None + assert args.linear_key_head_dim is not None + assert args.linear_value_head_dim is not None + + self.num_k_heads = args.linear_num_key_heads + self.num_v_heads = args.linear_num_value_heads + self.head_k_dim = args.linear_key_head_dim + self.head_v_dim = args.linear_value_head_dim + self.key_dim = self.head_k_dim * self.num_k_heads + self.value_dim = self.head_v_dim * self.num_v_heads + self.conv_kernel_size = args.linear_conv_kernel_dim + + assert ( + self.num_v_heads % self.num_k_heads == 0 + ), "linear_num_value_heads must be divisible by linear_num_key_heads." + self.head_repeat = self.num_v_heads // self.num_k_heads + + self.conv_dim = self.key_dim * 2 + self.value_dim + self.in_proj_qkv = nn.Linear(self.hidden_size, self.conv_dim, bias=False) + self.in_proj_z = nn.Linear(self.hidden_size, self.value_dim, bias=False) + self.in_proj_b = nn.Linear(self.hidden_size, self.num_v_heads, bias=False) + self.in_proj_a = nn.Linear(self.hidden_size, self.num_v_heads, bias=False) + + self.conv1d = nn.Conv1d( + in_channels=self.conv_dim, + out_channels=self.conv_dim, + kernel_size=self.conv_kernel_size, + groups=self.conv_dim, + bias=False, + padding=0, + ) + + self.dt_bias = nn.Parameter(torch.ones(self.num_v_heads)) + A = torch.empty(self.num_v_heads).uniform_(0, 16) + self.A_log = nn.Parameter(torch.log(A)) + self.norm = RMSNormGated(self.head_v_dim, eps=args.norm_eps) + self.out_proj = nn.Linear(self.value_dim, self.hidden_size, bias=False) + + self.register_buffer( + "conv_state", + torch.zeros( + self.max_batch_size, + self.conv_dim, + self.conv_kernel_size, + dtype=torch.float32, + device="cpu", + ), + ) + self.register_buffer( + "recurrent_state", + torch.zeros( + self.max_batch_size, + self.num_v_heads, + self.head_k_dim, + self.head_v_dim, + dtype=torch.float32, + device="cpu", + ), + ) + + def _maybe_reset_state(self, input_pos: Optional[torch.Tensor], batch_size: int) -> None: + if input_pos is None: + return + reset = (input_pos[0] == 0).to(self.conv_state.dtype) + keep = 1.0 - reset + self.conv_state[:batch_size].mul_(keep) + self.recurrent_state[:batch_size].mul_(keep) + + def _apply_causal_conv(self, mixed_qkv: torch.Tensor) -> torch.Tensor: + # mixed_qkv: (batch, seq_len, conv_dim) + batch_size, seq_len, _ = mixed_qkv.shape + mixed_qkv = mixed_qkv.transpose(1, 2) + state_len = self.conv_state.shape[-1] + hidden_states_new = torch.cat([self.conv_state[:batch_size], mixed_qkv], dim=-1) + new_conv_state = hidden_states_new[:, :, -state_len:] + with torch.no_grad(): + self.conv_state[:batch_size].copy_(new_conv_state.to(self.conv_state.dtype)) + out = F.conv1d( + hidden_states_new, + self.conv1d.weight, + self.conv1d.bias, + padding=0, + groups=self.conv_dim, + ) + out = F.silu(out[:, :, -seq_len:]).to(mixed_qkv.dtype) + return out.transpose(1, 2).contiguous() + + def _recurrent_gated_delta_rule( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + ) -> torch.Tensor: + # query/key/value: (batch, seq_len, num_heads, head_dim) + # g/beta: (batch, seq_len, num_heads) + initial_dtype = query.dtype + query = _l2norm(query, dim=-1, eps=1e-6) + key = _l2norm(key, dim=-1, eps=1e-6) + query, key, value, beta, g = [ + x.transpose(1, 2).contiguous().to(torch.float32) + for x in (query, key, value, beta, g) + ] + + batch_size, num_heads, sequence_length, k_head_dim = key.shape + v_head_dim = value.shape[-1] + scale = 1.0 / (query.shape[-1] ** 0.5) + query = query * scale + + core_attn_out = torch.zeros( + batch_size, + num_heads, + sequence_length, + v_head_dim, + device=value.device, + dtype=value.dtype, + ) + last_recurrent_state = self.recurrent_state[:batch_size].to(value.dtype) + + for i in range(sequence_length): + q_t = query[:, :, i] + k_t = key[:, :, i] + v_t = value[:, :, i] + g_t = g[:, :, i].exp().unsqueeze(-1).unsqueeze(-1) + beta_t = beta[:, :, i].unsqueeze(-1) + + last_recurrent_state = last_recurrent_state * g_t + kv_mem = (last_recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2) + delta = (v_t - kv_mem) * beta_t + last_recurrent_state = ( + last_recurrent_state + k_t.unsqueeze(-1) * delta.unsqueeze(-2) + ) + core_attn_out[:, :, i] = (last_recurrent_state * q_t.unsqueeze(-1)).sum( + dim=-2 + ) + + with torch.no_grad(): + self.recurrent_state[:batch_size].copy_( + last_recurrent_state.to(self.recurrent_state.dtype) + ) + + return core_attn_out.transpose(1, 2).contiguous().to(initial_dtype) + + def forward( + self, + x: torch.Tensor, + freqs_cos: torch.Tensor, + freqs_sin: torch.Tensor, + **kwargs: ForwardOptions, + ) -> Tuple[torch.Tensor, Optional[Any]]: + del freqs_cos + del freqs_sin + input_pos = kwargs.get("input_pos") + batch_size, seq_len, _ = x.shape + assert ( + batch_size <= self.max_batch_size + ), f"batch_size ({batch_size}) exceeds max_batch_size ({self.max_batch_size})" + + self._maybe_reset_state(input_pos, batch_size) + + mixed_qkv = self.in_proj_qkv(x) + z = self.in_proj_z(x).reshape(batch_size, seq_len, -1, self.head_v_dim) + b = self.in_proj_b(x) + a = self.in_proj_a(x) + + mixed_qkv = self._apply_causal_conv(mixed_qkv) + query, key, value = torch.split( + mixed_qkv, + [self.key_dim, self.key_dim, self.value_dim], + dim=-1, + ) + query = query.reshape(batch_size, seq_len, -1, self.head_k_dim) + key = key.reshape(batch_size, seq_len, -1, self.head_k_dim) + value = value.reshape(batch_size, seq_len, -1, self.head_v_dim) + + if self.head_repeat > 1: + query = query.repeat_interleave(self.head_repeat, dim=2) + key = key.repeat_interleave(self.head_repeat, dim=2) + + beta = b.sigmoid() + g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias) + core_attn_out = self._recurrent_gated_delta_rule(query, key, value, g, beta) + + core_attn_out = core_attn_out.reshape(-1, self.head_v_dim) + z = z.reshape(-1, self.head_v_dim) + core_attn_out = self.norm(core_attn_out, z) + core_attn_out = core_attn_out.reshape(batch_size, seq_len, -1) + + return self.out_proj(core_attn_out), None + + @register_attention("skip") class AttentionSkip(Attention): def __init__(self, *args, **kwargs): diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py index ee81139dedf..94984565e10 100644 --- a/examples/models/llama/export_llama_lib.py +++ b/examples/models/llama/export_llama_lib.py @@ -105,6 +105,9 @@ "qwen3_0_6b", "qwen3_1_7b", "qwen3_4b", + "qwen3_5_0_8b", + "qwen3_5_2b", + "qwen3_5_4b", "phi_4_mini", "smollm2", "lfm2_350m", # hybrid @@ -122,6 +125,9 @@ "qwen3_0_6b": "Qwen/Qwen3-0.6B", "qwen3_1_7b": "Qwen/Qwen3-1.7B", "qwen3_4b": "Qwen/Qwen3-4B", + "qwen3_5_0_8b": "Qwen/Qwen3.5-0.8B", + "qwen3_5_2b": "Qwen/Qwen3.5-2B", + "qwen3_5_4b": "Qwen/Qwen3.5-4B", "lfm2_350m": "LiquidAI/LFM2-350M", "lfm2_700m": "LiquidAI/LFM2-700M", "lfm2_1_2b": "LiquidAI/LFM2-1.2B", @@ -643,6 +649,8 @@ def export_llama( repo_id = HUGGING_FACE_REPO_IDS[model_name] if model_name.startswith("qwen2_5"): from executorch.examples.models.qwen2_5 import convert_weights + elif model_name.startswith("qwen3_5"): + from executorch.examples.models.qwen3_5 import convert_weights elif model_name.startswith("qwen3"): from executorch.examples.models.qwen3 import convert_weights elif model_name == "phi_4_mini": diff --git a/examples/models/llama/llama_transformer.py b/examples/models/llama/llama_transformer.py index a4ca52aa608..c79f9cb6ce0 100644 --- a/examples/models/llama/llama_transformer.py +++ b/examples/models/llama/llama_transformer.py @@ -105,8 +105,16 @@ def __init__(self, args: ModelArgs, attention: Attention): if isinstance(self.attention, AttentionSkip): self.attention_norm = nn.Identity() else: - self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps) - self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps) + self.attention_norm = RMSNorm( + args.dim, + eps=args.norm_eps, + add_unit_offset=args.rms_norm_add_unit_offset, + ) + self.ffn_norm = RMSNorm( + args.dim, + eps=args.norm_eps, + add_unit_offset=args.rms_norm_add_unit_offset, + ) @classmethod def from_type(cls, layer_id, args, rope) -> "TransformerBlock": @@ -164,7 +172,11 @@ def __init__(self, params: ModelArgs, layers: nn.ModuleList, rope: Rope): ) self.layers = layers self.rope = rope - self.norm = RMSNorm(params.dim, eps=params.norm_eps) + self.norm = RMSNorm( + params.dim, + eps=params.norm_eps, + add_unit_offset=params.rms_norm_add_unit_offset, + ) self.output = ( nn.Linear(params.dim, params.vocab_size, bias=False) if self.apply_output @@ -279,6 +291,21 @@ def construct_transformer(model_args: ModelArgs) -> Transformer: attention = AttentionSkip() transformer_block = TransformerBlock(model_args, attention) layers.append(transformer_block) + elif ( + model_args.layer_types + and model_args.layer_types[layer_id] == "linear_attention" + ): + linear_cls = ATTENTION_REGISTRY.get("gated_deltanet") + if linear_cls is None: + raise ValueError( + "Unknown attention type: gated_deltanet. " + f"Available: {list(ATTENTION_REGISTRY.keys())}" + ) + attention = linear_cls( + model_args, layer_id, rope, **model_args.attention_kwargs + ) + transformer_block = TransformerBlock(model_args, attention) + layers.append(transformer_block) else: attention = cls( model_args, layer_id, rope, **model_args.attention_kwargs diff --git a/examples/models/llama/model_args.py b/examples/models/llama/model_args.py index c225a1ee9b2..eb2a83db6ed 100644 --- a/examples/models/llama/model_args.py +++ b/examples/models/llama/model_args.py @@ -44,6 +44,14 @@ class ModelArgs: vocab_size: int = 512 # Arbitrary value, should be defined later by tokenizer. hidden_dim: Optional[int] = None head_dim: Optional[int] = None # Optional customized head_dim + # Qwen3.5 linear-attention dimensions. + linear_conv_kernel_dim: int = 4 + linear_key_head_dim: Optional[int] = None + linear_value_head_dim: Optional[int] = None + linear_num_key_heads: Optional[int] = None + linear_num_value_heads: Optional[int] = None + # Qwen3.5 RMSNorm uses (1 + weight) scaling. + rms_norm_add_unit_offset: bool = False multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2 ffn_dim_multiplier: Optional[float] = None model_architecture: str = ( @@ -174,6 +182,15 @@ def find_multiple(n: int, k: int) -> int: if self.head_dim is None: self.head_dim = self.dim // self.n_heads + if self.linear_key_head_dim is None: + self.linear_key_head_dim = self.head_dim + if self.linear_value_head_dim is None: + self.linear_value_head_dim = self.head_dim + if self.linear_num_key_heads is None: + self.linear_num_key_heads = self.n_heads + if self.linear_num_value_heads is None: + self.linear_num_value_heads = self.n_heads + # Convert string act_fn to enum if needed if isinstance(self.act_fn, str): self.act_fn = ActFn.from_string(self.act_fn) diff --git a/examples/models/llama/norm.py b/examples/models/llama/norm.py index 3786e61cd05..52ae30c1697 100644 --- a/examples/models/llama/norm.py +++ b/examples/models/llama/norm.py @@ -9,7 +9,9 @@ class RMSNorm(torch.nn.Module): - def __init__(self, dim: int, eps: float = 1e-6): + def __init__( + self, dim: int, eps: float = 1e-6, add_unit_offset: bool = False + ): """ Initialize the RMSNorm normalization layer. @@ -25,6 +27,7 @@ def __init__(self, dim: int, eps: float = 1e-6): super().__init__() self.dim = dim self.eps = eps + self.add_unit_offset = add_unit_offset self.weight = nn.Parameter(torch.ones(dim)) def _norm(self, x): @@ -52,4 +55,6 @@ def forward(self, x): """ output = self._norm(x.float()).type_as(x) + if self.add_unit_offset: + return output * (1.0 + self.weight.float()).type_as(x) return output * self.weight diff --git a/examples/models/llama/tests/BUCK b/examples/models/llama/tests/BUCK index 8f4dec2237b..431c3c92814 100644 --- a/examples/models/llama/tests/BUCK +++ b/examples/models/llama/tests/BUCK @@ -15,6 +15,17 @@ fbcode_target(_kind = python_unittest, ], ) +fbcode_target(_kind = python_unittest, + name = "test_qwen3_5_attention", + srcs = [ + "test_qwen3_5_attention.py", + ], + deps = [ + "//caffe2:torch", + "//executorch/examples/models/llama:llama_transformer", + ], +) + fbcode_target(_kind = python_unittest, name = "test_pre_quantization_transforms", srcs = [ diff --git a/examples/models/llama/tests/test_qwen3_5_attention.py b/examples/models/llama/tests/test_qwen3_5_attention.py new file mode 100644 index 00000000000..b076e69d121 --- /dev/null +++ b/examples/models/llama/tests/test_qwen3_5_attention.py @@ -0,0 +1,83 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import unittest + +import torch +from executorch.examples.models.llama.attention import ATTENTION_REGISTRY +from executorch.examples.models.llama.model_args import ModelArgs +from executorch.examples.models.llama.rope import Rope + + +class Qwen35AttentionTest(unittest.TestCase): + def test_qwen35_full_attention_forward_shape(self): + torch.manual_seed(0) + args = ModelArgs( + dim=32, + n_layers=1, + n_heads=4, + n_kv_heads=2, + head_dim=8, + hidden_dim=64, + max_seq_len=16, + max_context_len=16, + use_kv_cache=False, + use_hf_rope=True, + partial_rotary_factor=0.5, + use_qk_norm=True, + qk_norm_before_rope=True, + attention_type="qwen3_5_full", + rms_norm_add_unit_offset=True, + ) + rope = Rope(args) + attn = ATTENTION_REGISTRY["qwen3_5_full"](args, 0, rope) + x = torch.randn(1, 3, args.dim) + freqs_cos, freqs_sin = rope.get_freqs(None, x.shape[1]) + y, _ = attn(x, freqs_cos, freqs_sin) + self.assertEqual(y.shape, x.shape) + + def test_gated_deltanet_resets_state_on_new_sequence(self): + torch.manual_seed(0) + args = ModelArgs( + dim=32, + n_layers=1, + n_heads=4, + n_kv_heads=2, + head_dim=8, + hidden_dim=64, + max_seq_len=16, + max_context_len=16, + use_kv_cache=True, + attention_type="qwen3_5_full", + linear_conv_kernel_dim=4, + linear_key_head_dim=4, + linear_value_head_dim=4, + linear_num_key_heads=2, + linear_num_value_heads=4, + ) + rope = Rope(args) + attn = ATTENTION_REGISTRY["gated_deltanet"](args, 0, rope) + + x = torch.randn(1, 1, args.dim) + dummy_freq = torch.zeros(1, 1) + + # First token of sequence. + attn(x, dummy_freq, dummy_freq, input_pos=torch.tensor([0], dtype=torch.long)) + state_after_first = attn.recurrent_state.clone() + + # Decode continuation updates state. + attn(x, dummy_freq, dummy_freq, input_pos=torch.tensor([1], dtype=torch.long)) + state_after_second = attn.recurrent_state.clone() + self.assertFalse(torch.allclose(state_after_first, state_after_second)) + + # New sequence (input_pos=0) should reset internal state. + attn(x, dummy_freq, dummy_freq, input_pos=torch.tensor([0], dtype=torch.long)) + state_after_reset = attn.recurrent_state.clone() + self.assertTrue(torch.allclose(state_after_first, state_after_reset, atol=1e-5)) + + +if __name__ == "__main__": + unittest.main() diff --git a/examples/models/qwen3_5/BUCK b/examples/models/qwen3_5/BUCK new file mode 100644 index 00000000000..9ee84a0ec5e --- /dev/null +++ b/examples/models/qwen3_5/BUCK @@ -0,0 +1,23 @@ +load("@fbcode_macros//build_defs:build_file_migration.bzl", "fbcode_target", "non_fbcode_target") +# Any targets that should be shared between fbcode and xplat must be defined in +# targets.bzl. This file can contain fbcode-only targets. + +load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") + +oncall("executorch") + +fbcode_target(_kind = runtime.python_library, + name = "qwen3_5", + srcs = [ + "__init__.py", + "convert_weights.py", + ], + base_module = "executorch.examples.models.qwen3_5", + visibility = ["PUBLIC"], + deps = [ + "//caffe2:torch", + "//executorch/examples/models/llama:llama2_model", + "//executorch/examples/models:checkpoint", + "fbsource//third-party/pypi/safetensors:safetensors", + ], +) diff --git a/examples/models/qwen3_5/README.md b/examples/models/qwen3_5/README.md new file mode 100644 index 00000000000..315d603d655 --- /dev/null +++ b/examples/models/qwen3_5/README.md @@ -0,0 +1,52 @@ +## Summary +[Qwen3.5](https://huggingface.co/collections/Qwen/qwen35-684357f8543f83de2f09f998) support in ExecuTorch is exported through the Llama example pipeline with a hybrid layer layout: +- `full_attention` layers use gated full attention. +- `linear_attention` layers use Gated DeltaNet with internal recurrent state. + +This first bring-up is **fp32 + static shape** (`enable_dynamic_shape=False`). +Currently supported text model sizes: `0.8B`, `2B`, `4B`. + +## Export +```bash +python -m extension.llm.export.export_llm \ + --config examples/models/qwen3_5/config/qwen3_5_xnnpack_fp32.yaml \ + +base.model_class="qwen3_5_0_8b" \ + +base.params="examples/models/qwen3_5/config/0_8b_config.json" \ + +export.output_name="qwen3_5_0_8b_fp32.pte" +``` + +```bash +python -m extension.llm.export.export_llm \ + --config examples/models/qwen3_5/config/qwen3_5_xnnpack_fp32.yaml \ + +base.model_class="qwen3_5_2b" \ + +base.params="examples/models/qwen3_5/config/2b_config.json" \ + +export.output_name="qwen3_5_2b_fp32.pte" +``` + +```bash +python -m extension.llm.export.export_llm \ + --config examples/models/qwen3_5/config/qwen3_5_xnnpack_fp32.yaml \ + +base.model_class="qwen3_5_4b" \ + +base.params="examples/models/qwen3_5/config/4b_config.json" \ + +export.output_name="qwen3_5_4b_fp32.pte" +``` + +The exporter will download and convert HF weights automatically when `+base.checkpoint` is not provided. + +## Run (Python Runner) +```bash +python -m examples.models.llama.runner.native \ + --model qwen3_5_0_8b \ + --pte qwen3_5_0_8b_fp32.pte \ + --tokenizer /path/to/tokenizer.json \ + --tokenizer_config /path/to/tokenizer_config.json \ + --prompt "<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n" \ + --params examples/models/qwen3_5/config/0_8b_config.json \ + --max_len 128 \ + -kv \ + --temperature 0.3 +``` + +## Notes +- Current path targets CPU/XNNPACK export validation and runner compatibility. +- `q8da4w` quantization for Qwen3.5 is intentionally deferred to a follow-up. diff --git a/examples/models/qwen3_5/__init__.py b/examples/models/qwen3_5/__init__.py new file mode 100644 index 00000000000..3177b5c4958 --- /dev/null +++ b/examples/models/qwen3_5/__init__.py @@ -0,0 +1,24 @@ +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import TYPE_CHECKING + +from executorch.examples.models.qwen3_5.convert_weights import convert_weights + +if TYPE_CHECKING: + from executorch.examples.models.llama.model import Llama2Model + +__all__ = ["Qwen3_5Model", "convert_weights"] + + +def __getattr__(name): + if name == "Qwen3_5Model": + from executorch.examples.models.llama.model import Llama2Model + + class Qwen3_5Model(Llama2Model): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + globals()["Qwen3_5Model"] = Qwen3_5Model + return Qwen3_5Model + raise AttributeError(f"module '{__name__}' has no attribute '{name}'") diff --git a/examples/models/qwen3_5/config/0_8b_config.json b/examples/models/qwen3_5/config/0_8b_config.json new file mode 100644 index 00000000000..e532988410d --- /dev/null +++ b/examples/models/qwen3_5/config/0_8b_config.json @@ -0,0 +1,50 @@ +{ + "dim": 1024, + "hidden_dim": 3584, + "n_heads": 8, + "head_dim": 256, + "n_kv_heads": 2, + "n_layers": 24, + "norm_eps": 1e-6, + "rope_theta": 10000000.0, + "use_scaled_rope": false, + "vocab_size": 248320, + "use_hf_rope": true, + "partial_rotary_factor": 0.25, + "attention_qkv_bias": false, + "use_qk_norm": true, + "qk_norm_before_rope": true, + "attention_type": "qwen3_5_full", + "rms_norm_add_unit_offset": true, + "linear_conv_kernel_dim": 4, + "linear_key_head_dim": 128, + "linear_value_head_dim": 128, + "linear_num_key_heads": 16, + "linear_num_value_heads": 16, + "layer_types": [ + "linear_attention", + "linear_attention", + "linear_attention", + "full_attention", + "linear_attention", + "linear_attention", + "linear_attention", + "full_attention", + "linear_attention", + "linear_attention", + "linear_attention", + "full_attention", + "linear_attention", + "linear_attention", + "linear_attention", + "full_attention", + "linear_attention", + "linear_attention", + "linear_attention", + "full_attention", + "linear_attention", + "linear_attention", + "linear_attention", + "full_attention" + ] +} diff --git a/examples/models/qwen3_5/config/2b_config.json b/examples/models/qwen3_5/config/2b_config.json new file mode 100644 index 00000000000..e397bf6fd4a --- /dev/null +++ b/examples/models/qwen3_5/config/2b_config.json @@ -0,0 +1,50 @@ +{ + "dim": 2048, + "hidden_dim": 6144, + "n_heads": 8, + "head_dim": 256, + "n_kv_heads": 2, + "n_layers": 24, + "norm_eps": 1e-6, + "rope_theta": 10000000.0, + "use_scaled_rope": false, + "vocab_size": 248320, + "use_hf_rope": true, + "partial_rotary_factor": 0.25, + "attention_qkv_bias": false, + "use_qk_norm": true, + "qk_norm_before_rope": true, + "attention_type": "qwen3_5_full", + "rms_norm_add_unit_offset": true, + "linear_conv_kernel_dim": 4, + "linear_key_head_dim": 128, + "linear_value_head_dim": 128, + "linear_num_key_heads": 16, + "linear_num_value_heads": 16, + "layer_types": [ + "linear_attention", + "linear_attention", + "linear_attention", + "full_attention", + "linear_attention", + "linear_attention", + "linear_attention", + "full_attention", + "linear_attention", + "linear_attention", + "linear_attention", + "full_attention", + "linear_attention", + "linear_attention", + "linear_attention", + "full_attention", + "linear_attention", + "linear_attention", + "linear_attention", + "full_attention", + "linear_attention", + "linear_attention", + "linear_attention", + "full_attention" + ] +} diff --git a/examples/models/qwen3_5/config/4b_config.json b/examples/models/qwen3_5/config/4b_config.json new file mode 100644 index 00000000000..75c4313b83a --- /dev/null +++ b/examples/models/qwen3_5/config/4b_config.json @@ -0,0 +1,58 @@ +{ + "dim": 2560, + "hidden_dim": 9216, + "n_heads": 16, + "head_dim": 256, + "n_kv_heads": 4, + "n_layers": 32, + "norm_eps": 1e-6, + "rope_theta": 10000000.0, + "use_scaled_rope": false, + "vocab_size": 248320, + "use_hf_rope": true, + "partial_rotary_factor": 0.25, + "attention_qkv_bias": false, + "use_qk_norm": true, + "qk_norm_before_rope": true, + "attention_type": "qwen3_5_full", + "rms_norm_add_unit_offset": true, + "linear_conv_kernel_dim": 4, + "linear_key_head_dim": 128, + "linear_value_head_dim": 128, + "linear_num_key_heads": 16, + "linear_num_value_heads": 32, + "layer_types": [ + "linear_attention", + "linear_attention", + "linear_attention", + "full_attention", + "linear_attention", + "linear_attention", + "linear_attention", + "full_attention", + "linear_attention", + "linear_attention", + "linear_attention", + "full_attention", + "linear_attention", + "linear_attention", + "linear_attention", + "full_attention", + "linear_attention", + "linear_attention", + "linear_attention", + "full_attention", + "linear_attention", + "linear_attention", + "linear_attention", + "full_attention", + "linear_attention", + "linear_attention", + "linear_attention", + "full_attention", + "linear_attention", + "linear_attention", + "linear_attention", + "full_attention" + ] +} diff --git a/examples/models/qwen3_5/config/qwen3_5_xnnpack_fp32.yaml b/examples/models/qwen3_5/config/qwen3_5_xnnpack_fp32.yaml new file mode 100644 index 00000000000..496f2da6e2b --- /dev/null +++ b/examples/models/qwen3_5/config/qwen3_5_xnnpack_fp32.yaml @@ -0,0 +1,17 @@ +base: + metadata: '{"get_bos_id": 151644, "get_eos_ids":[151645]}' + +model: + use_kv_cache: True + use_sdpa_with_kv_cache: False + enable_dynamic_shape: False + dtype_override: fp32 + +export: + max_seq_length: 2048 + max_context_length: 2048 + +backend: + xnnpack: + enabled: True + extended_ops: True diff --git a/examples/models/qwen3_5/convert_weights.py b/examples/models/qwen3_5/convert_weights.py new file mode 100644 index 00000000000..1963961fdfc --- /dev/null +++ b/examples/models/qwen3_5/convert_weights.py @@ -0,0 +1,201 @@ +import argparse +import json +import os +import re +from typing import Dict + +import torch +from executorch.examples.models.checkpoint import ( + get_mapped_key, + load_checkpoint_from_pytorch_model, +) + +_QWEN_3_5_TO_META = { + "model.embed_tokens.weight": "tok_embeddings.weight", + "model.norm.weight": "norm.weight", + "lm_head.weight": "output.weight", + "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight", + "model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight", + "model.layers.{}.mlp.gate_proj.weight": "layers.{}.feed_forward.w1.weight", + "model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight", + "model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight", + # Full-attention layers. + "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight", + "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight", + "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight", + "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight", + "model.layers.{}.self_attn.q_norm.weight": "layers.{}.attention.q_norm_fn.weight", + "model.layers.{}.self_attn.k_norm.weight": "layers.{}.attention.k_norm_fn.weight", + # Linear-attention layers. + "model.layers.{}.linear_attn.in_proj_qkv.weight": "layers.{}.attention.in_proj_qkv.weight", + "model.layers.{}.linear_attn.in_proj_z.weight": "layers.{}.attention.in_proj_z.weight", + "model.layers.{}.linear_attn.in_proj_b.weight": "layers.{}.attention.in_proj_b.weight", + "model.layers.{}.linear_attn.in_proj_a.weight": "layers.{}.attention.in_proj_a.weight", + "model.layers.{}.linear_attn.conv1d.weight": "layers.{}.attention.conv1d.weight", + "model.layers.{}.linear_attn.conv1d.bias": "layers.{}.attention.conv1d.bias", + "model.layers.{}.linear_attn.dt_bias": "layers.{}.attention.dt_bias", + "model.layers.{}.linear_attn.A_log": "layers.{}.attention.A_log", + "model.layers.{}.linear_attn.norm.weight": "layers.{}.attention.norm.weight", + "model.layers.{}.linear_attn.out_proj.weight": "layers.{}.attention.out_proj.weight", +} + + +def _load_checkpoint_from_safetensors(input_dir: str) -> Dict: + from safetensors.torch import load_file + + index_path = os.path.join(input_dir, "model.safetensors.index.json") + if os.path.exists(index_path): + with open(index_path, "r") as f: + index = json.load(f) + weight_map = index["weight_map"] + checkpoint_shards = sorted(set(weight_map.values())) + + shard_to_weights = {} + for shard in checkpoint_shards: + shard_to_weights[shard] = load_file(os.path.join(input_dir, shard)) + + merged_state_dict = {} + for weight_name, shard in weight_map.items(): + merged_state_dict[weight_name] = shard_to_weights[shard][weight_name] + return merged_state_dict + + model_path = os.path.join(input_dir, "model.safetensors") + if os.path.exists(model_path): + return load_file(model_path) + + raise FileNotFoundError(f"Could not find safetensors checkpoint in {input_dir}") + + +def load_checkpoint(input_dir: str) -> Dict: + try: + print("Loading checkpoint from pytorch_model directory") + return load_checkpoint_from_pytorch_model(input_dir) + except FileNotFoundError: + print( + "Could not find pytorch_model checkpoints in directory, trying safetensors" + ) + + try: + print("Loading checkpoint from safetensors directory") + return _load_checkpoint_from_safetensors(input_dir) + except FileNotFoundError: + pass + + raise FileNotFoundError(f"Could not find checkpoint in {input_dir}") + + +def qwen_3_5_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + converted_state_dict = {} + pending_qkvz = {} + pending_ba = {} + + for key, value in state_dict.items(): + normalized_key = key + # HF multimodal Qwen3.5 checkpoints store text weights under + # `model.language_model.*`. Normalize to `model.*`. + if normalized_key.startswith("model.language_model."): + normalized_key = normalized_key.replace( + "model.language_model.", "model.", 1 + ) + + # Legacy packed tensors (older checkpoints): + # in_proj_qkvz -> split into in_proj_qkv and in_proj_z + # in_proj_ba -> split into in_proj_b and in_proj_a + if normalized_key.endswith(".linear_attn.in_proj_qkvz.weight"): + pending_qkvz[normalized_key] = value + continue + if normalized_key.endswith(".linear_attn.in_proj_ba.weight"): + pending_ba[normalized_key] = value + continue + + try: + new_key = get_mapped_key(normalized_key, _QWEN_3_5_TO_META) + except Exception: + # Ignore non-text weights and training-only extras (e.g., MTP). + if ( + key.startswith("mtp.") + or key.startswith("model.visual.") + or ".vision_" in key + or key.startswith("visual.") + ): + continue + # Ignore unsupported keys that are not required by the export model. + continue + converted_state_dict[new_key] = value + + for key, value in pending_qkvz.items(): + layer_match = re.search(r"model\.layers\.(\d+)\.", key) + if layer_match is None: + raise ValueError(f"Failed to parse layer id from key: {key}") + layer_id = layer_match.group(1) + out_proj_key = f"layers.{layer_id}.attention.out_proj.weight" + if out_proj_key not in converted_state_dict: + raise ValueError( + f"Cannot split {key}: missing {out_proj_key} to infer value dimension." + ) + + value_dim = converted_state_dict[out_proj_key].shape[1] + total_dim = value.shape[0] + conv_dim = total_dim - value_dim + if conv_dim <= 0 or (conv_dim - value_dim) % 2 != 0: + raise ValueError( + f"Invalid packed in_proj_qkvz shape for {key}: {tuple(value.shape)}" + ) + key_dim = (conv_dim - value_dim) // 2 + + qkv, z = torch.split(value, [conv_dim, value_dim], dim=0) + converted_state_dict[f"layers.{layer_id}.attention.in_proj_qkv.weight"] = qkv + converted_state_dict[f"layers.{layer_id}.attention.in_proj_z.weight"] = z + print(f"Split legacy packed key {key} -> in_proj_qkv + in_proj_z") + + for key, value in pending_ba.items(): + layer_match = re.search(r"model\.layers\.(\d+)\.", key) + if layer_match is None: + raise ValueError(f"Failed to parse layer id from key: {key}") + layer_id = layer_match.group(1) + if value.shape[0] % 2 != 0: + raise ValueError( + f"Invalid packed in_proj_ba shape for {key}: {tuple(value.shape)}" + ) + half = value.shape[0] // 2 + b, a = torch.split(value, [half, half], dim=0) + converted_state_dict[f"layers.{layer_id}.attention.in_proj_b.weight"] = b + converted_state_dict[f"layers.{layer_id}.attention.in_proj_a.weight"] = a + print(f"Split legacy packed key {key} -> in_proj_b + in_proj_a") + + # Handle tied embeddings. + if "lm_head.weight" not in state_dict: + converted_state_dict["output.weight"] = converted_state_dict[ + "tok_embeddings.weight" + ] + + return converted_state_dict + + +def convert_weights(input_dir: str, output_file: str) -> None: + print("Loading checkpoint...") + state_dict = load_checkpoint(input_dir) + print("Converting checkpoint...") + state_dict = qwen_3_5_to_meta(state_dict) + print("Saving checkpoint...") + torch.save(state_dict, output_file) + print("Done.") + + +def main(): + parser = argparse.ArgumentParser( + description="Convert Qwen3.5 weights to ExecuTorch meta format." + ) + parser.add_argument( + "input_dir", + type=str, + help="Path to directory containing safetensor or PyTorch checkpoint files.", + ) + parser.add_argument("output", type=str, help="Path to the output checkpoint") + + args = parser.parse_args() + convert_weights(args.input_dir, args.output) + + +if __name__ == "__main__": + main() diff --git a/examples/models/qwen3_5/tests/__init__.py b/examples/models/qwen3_5/tests/__init__.py new file mode 100644 index 00000000000..5e827e62ea1 --- /dev/null +++ b/examples/models/qwen3_5/tests/__init__.py @@ -0,0 +1,2 @@ +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/examples/models/qwen3_5/tests/test_convert_weights.py b/examples/models/qwen3_5/tests/test_convert_weights.py new file mode 100644 index 00000000000..0e7d1d8c8d3 --- /dev/null +++ b/examples/models/qwen3_5/tests/test_convert_weights.py @@ -0,0 +1,44 @@ +import unittest + +import torch +from executorch.examples.models.qwen3_5.convert_weights import qwen_3_5_to_meta + + +class Qwen35ConvertWeightsTest(unittest.TestCase): + def test_maps_full_and_linear_attention_weights(self): + state_dict = { + "model.embed_tokens.weight": torch.randn(16, 8), + "model.norm.weight": torch.randn(8), + "lm_head.weight": torch.randn(16, 8), + "model.layers.0.input_layernorm.weight": torch.randn(8), + "model.layers.0.post_attention_layernorm.weight": torch.randn(8), + "model.layers.0.mlp.gate_proj.weight": torch.randn(12, 8), + "model.layers.0.mlp.down_proj.weight": torch.randn(8, 12), + "model.layers.0.mlp.up_proj.weight": torch.randn(12, 8), + "model.layers.0.self_attn.q_proj.weight": torch.randn(16, 8), + "model.layers.0.self_attn.k_proj.weight": torch.randn(8, 8), + "model.layers.0.self_attn.v_proj.weight": torch.randn(8, 8), + "model.layers.0.self_attn.o_proj.weight": torch.randn(8, 8), + "model.layers.0.self_attn.q_norm.weight": torch.randn(4), + "model.layers.0.self_attn.k_norm.weight": torch.randn(4), + "model.layers.1.linear_attn.in_proj_qkv.weight": torch.randn(24, 8), + "model.layers.1.linear_attn.in_proj_z.weight": torch.randn(8, 8), + "model.layers.1.linear_attn.in_proj_b.weight": torch.randn(2, 8), + "model.layers.1.linear_attn.in_proj_a.weight": torch.randn(2, 8), + "model.layers.1.linear_attn.conv1d.weight": torch.randn(24, 1, 4), + "model.layers.1.linear_attn.dt_bias": torch.randn(2), + "model.layers.1.linear_attn.A_log": torch.randn(2), + "model.layers.1.linear_attn.norm.weight": torch.randn(4), + "model.layers.1.linear_attn.out_proj.weight": torch.randn(8, 8), + } + + converted = qwen_3_5_to_meta(state_dict) + self.assertIn("layers.0.attention.wq.weight", converted) + self.assertIn("layers.0.attention.q_norm_fn.weight", converted) + self.assertIn("layers.1.attention.in_proj_qkv.weight", converted) + self.assertIn("layers.1.attention.out_proj.weight", converted) + self.assertIn("layers.1.attention.dt_bias", converted) + + +if __name__ == "__main__": + unittest.main() diff --git a/extension/llm/export/config/llm_config.py b/extension/llm/export/config/llm_config.py index eea99ea9b2b..06df3aa38be 100644 --- a/extension/llm/export/config/llm_config.py +++ b/extension/llm/export/config/llm_config.py @@ -44,6 +44,9 @@ class ModelType(str, Enum): qwen3_0_6b = "qwen3_0_6b" qwen3_1_7b = "qwen3_1_7b" qwen3_4b = "qwen3_4b" + qwen3_5_0_8b = "qwen3_5_0_8b" + qwen3_5_2b = "qwen3_5_2b" + qwen3_5_4b = "qwen3_5_4b" phi_4_mini = "phi_4_mini" smollm2 = "smollm2" lfm2_350m = "lfm2_350m" From b4754f20ba3a8a9323c9b85d2df4cffbebc2e85a Mon Sep 17 00:00:00 2001 From: Phineas1500 <41450967+Phineas1500@users.noreply.github.com> Date: Mon, 2 Mar 2026 18:49:07 -0500 Subject: [PATCH 02/11] Initialize Qwen3.5 mutable buffers during export --- examples/models/llama/export_llama_lib.py | 33 +++++++++++++++---- .../llama/tests/test_export_llama_lib.py | 20 +++++++++++ 2 files changed, 47 insertions(+), 6 deletions(-) diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py index 94984565e10..05a08fc5a5b 100644 --- a/examples/models/llama/export_llama_lib.py +++ b/examples/models/llama/export_llama_lib.py @@ -135,6 +135,27 @@ } +def _get_additional_export_passes(model_class: str) -> List[InitializedMutableBufferPass]: + patterns = [] + + if model_class in TORCHTUNE_DEFINED_MODELS: + patterns.append("kv_cache_pos") + + # Qwen3.5 uses internal mutable buffers for both the hybrid KV path and + # DeltaNet recurrent/conv states. + if model_class.startswith("qwen3_5"): + patterns.extend( + [ + "k_cache", + "v_cache", + "conv_state", + "recurrent_state", + ] + ) + + return [InitializedMutableBufferPass(patterns)] if patterns else [] + + def set_pkg_name(name: str) -> None: global pkg_name pkg_name = name @@ -1268,9 +1289,9 @@ def _export_llama_multimethod(llm_config: LlmConfig) -> LLMEdgeManager: "Each method requires separate model instantiation and export." ) - additional_passes = [] - if llm_config.base.model_class.value in TORCHTUNE_DEFINED_MODELS: - additional_passes = [InitializedMutableBufferPass(["kv_cache_pos"])] + additional_passes = _get_additional_export_passes( + llm_config.base.model_class.value + ) # Build dict of exported programs method_to_program: Dict[str, ExportedProgram] = {} @@ -1341,9 +1362,9 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager: # noqa: C901 llm_config ) - additional_passes = [] - if llm_config.base.model_class.value in TORCHTUNE_DEFINED_MODELS: - additional_passes = [InitializedMutableBufferPass(["kv_cache_pos"])] + additional_passes = _get_additional_export_passes( + llm_config.base.model_class.value + ) # export_to_edge builder_manager = _prepare_for_llama_export(llm_config) diff --git a/examples/models/llama/tests/test_export_llama_lib.py b/examples/models/llama/tests/test_export_llama_lib.py index 243c186cccc..3d1b6e85a81 100644 --- a/examples/models/llama/tests/test_export_llama_lib.py +++ b/examples/models/llama/tests/test_export_llama_lib.py @@ -22,10 +22,12 @@ TOSAQuantizer = None from executorch.examples.models.llama.export_llama_lib import ( + _get_additional_export_passes, _export_llama, build_args_parser, get_quantizer_and_quant_params, ) +from executorch.exir.passes.init_mutable_pass import InitializedMutableBufferPass from executorch.extension.llm.export.config.llm_config import LlmConfig, Pt2eQuantize UNWANTED_OPS = [ @@ -35,6 +37,24 @@ class ExportLlamaLibTest(unittest.TestCase): + def test_qwen3_5_mutable_buffer_passes(self): + passes = _get_additional_export_passes("qwen3_5_0_8b") + self.assertEqual(len(passes), 1) + self.assertIsInstance(passes[0], InitializedMutableBufferPass) + self.assertEqual( + passes[0].patterns, + ["k_cache", "v_cache", "conv_state", "recurrent_state"], + ) + + def test_torchtune_mutable_buffer_passes(self): + passes = _get_additional_export_passes("llama3_2_vision") + self.assertEqual(len(passes), 1) + self.assertIsInstance(passes[0], InitializedMutableBufferPass) + self.assertEqual(passes[0].patterns, ["kv_cache_pos"]) + + def test_llama3_has_no_extra_mutable_buffer_passes(self): + self.assertEqual(_get_additional_export_passes("llama3"), []) + def test_has_expected_ops_and_op_counts(self): """ Checks the presence of unwanted expensive ops. From 3bfc5ce1cb96b54f4a2c366b6ecbc4d255ba11dd Mon Sep 17 00:00:00 2001 From: Phineas1500 <41450967+Phineas1500@users.noreply.github.com> Date: Mon, 2 Mar 2026 18:56:10 -0500 Subject: [PATCH 03/11] Fix lint and URL checks for Qwen3.5 files --- examples/models/llama/attention.py | 10 ++++++---- examples/models/llama/norm.py | 4 +--- examples/models/qwen3_5/README.md | 2 +- examples/models/qwen3_5/__init__.py | 5 ----- examples/models/qwen3_5/convert_weights.py | 2 -- 5 files changed, 8 insertions(+), 15 deletions(-) diff --git a/examples/models/llama/attention.py b/examples/models/llama/attention.py index da5f1b46b29..315d4ccddbf 100644 --- a/examples/models/llama/attention.py +++ b/examples/models/llama/attention.py @@ -760,7 +760,9 @@ def __init__( ), ) - def _maybe_reset_state(self, input_pos: Optional[torch.Tensor], batch_size: int) -> None: + def _maybe_reset_state( + self, input_pos: Optional[torch.Tensor], batch_size: int + ) -> None: if input_pos is None: return reset = (input_pos[0] == 0).to(self.conv_state.dtype) @@ -830,9 +832,9 @@ def _recurrent_gated_delta_rule( last_recurrent_state = last_recurrent_state * g_t kv_mem = (last_recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2) delta = (v_t - kv_mem) * beta_t - last_recurrent_state = ( - last_recurrent_state + k_t.unsqueeze(-1) * delta.unsqueeze(-2) - ) + last_recurrent_state = last_recurrent_state + k_t.unsqueeze( + -1 + ) * delta.unsqueeze(-2) core_attn_out[:, :, i] = (last_recurrent_state * q_t.unsqueeze(-1)).sum( dim=-2 ) diff --git a/examples/models/llama/norm.py b/examples/models/llama/norm.py index 52ae30c1697..c56cc61c7a5 100644 --- a/examples/models/llama/norm.py +++ b/examples/models/llama/norm.py @@ -9,9 +9,7 @@ class RMSNorm(torch.nn.Module): - def __init__( - self, dim: int, eps: float = 1e-6, add_unit_offset: bool = False - ): + def __init__(self, dim: int, eps: float = 1e-6, add_unit_offset: bool = False): """ Initialize the RMSNorm normalization layer. diff --git a/examples/models/qwen3_5/README.md b/examples/models/qwen3_5/README.md index 315d603d655..f1de77b438a 100644 --- a/examples/models/qwen3_5/README.md +++ b/examples/models/qwen3_5/README.md @@ -1,5 +1,5 @@ ## Summary -[Qwen3.5](https://huggingface.co/collections/Qwen/qwen35-684357f8543f83de2f09f998) support in ExecuTorch is exported through the Llama example pipeline with a hybrid layer layout: +[Qwen3.5](https://huggingface.co/Qwen/Qwen3.5-4B) support in ExecuTorch is exported through the Llama example pipeline with a hybrid layer layout: - `full_attention` layers use gated full attention. - `linear_attention` layers use Gated DeltaNet with internal recurrent state. diff --git a/examples/models/qwen3_5/__init__.py b/examples/models/qwen3_5/__init__.py index 3177b5c4958..b336832bb4b 100644 --- a/examples/models/qwen3_5/__init__.py +++ b/examples/models/qwen3_5/__init__.py @@ -1,13 +1,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import TYPE_CHECKING - from executorch.examples.models.qwen3_5.convert_weights import convert_weights -if TYPE_CHECKING: - from executorch.examples.models.llama.model import Llama2Model - __all__ = ["Qwen3_5Model", "convert_weights"] diff --git a/examples/models/qwen3_5/convert_weights.py b/examples/models/qwen3_5/convert_weights.py index 1963961fdfc..3566d4264f5 100644 --- a/examples/models/qwen3_5/convert_weights.py +++ b/examples/models/qwen3_5/convert_weights.py @@ -141,8 +141,6 @@ def qwen_3_5_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Ten raise ValueError( f"Invalid packed in_proj_qkvz shape for {key}: {tuple(value.shape)}" ) - key_dim = (conv_dim - value_dim) // 2 - qkv, z = torch.split(value, [conv_dim, value_dim], dim=0) converted_state_dict[f"layers.{layer_id}.attention.in_proj_qkv.weight"] = qkv converted_state_dict[f"layers.{layer_id}.attention.in_proj_z.weight"] = z From e6e6286ff4686f0ba88e9ca4457c7987226f2ca9 Mon Sep 17 00:00:00 2001 From: Phineas1500 <41450967+Phineas1500@users.noreply.github.com> Date: Mon, 2 Mar 2026 19:02:35 -0500 Subject: [PATCH 04/11] Make Qwen3.5 full-attention output projection bias-free --- examples/models/llama/attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/models/llama/attention.py b/examples/models/llama/attention.py index 315d4ccddbf..2d84e4cc5ce 100644 --- a/examples/models/llama/attention.py +++ b/examples/models/llama/attention.py @@ -591,7 +591,7 @@ def __init__( self.wo = nn.Linear( self.n_heads * self.head_dim, self.dim, - bias=self.attention_qkv_bias, + bias=False, ) self.layer_id = layer_id From 3466f745700fd440da3b6fe6ea4d81482abf1f74 Mon Sep 17 00:00:00 2001 From: Phineas1500 <41450967+Phineas1500@users.noreply.github.com> Date: Mon, 2 Mar 2026 19:39:38 -0500 Subject: [PATCH 05/11] Fix Qwen3.5 metadata and static prefill handling --- examples/models/llama/runner/eager.py | 1 + examples/models/llama/runner/generation.py | 56 +++++++++-- examples/models/llama/runner/native.py | 7 ++ examples/models/llama/tests/BUCK | 11 +++ .../llama/tests/test_generation_prefill.py | 99 +++++++++++++++++++ examples/models/qwen3_5/README.md | 11 ++- .../qwen3_5/config/qwen3_5_xnnpack_fp32.yaml | 6 +- 7 files changed, 180 insertions(+), 11 deletions(-) create mode 100644 examples/models/llama/tests/test_generation_prefill.py diff --git a/examples/models/llama/runner/eager.py b/examples/models/llama/runner/eager.py index 7e662317509..9a92a08ae82 100644 --- a/examples/models/llama/runner/eager.py +++ b/examples/models/llama/runner/eager.py @@ -43,6 +43,7 @@ def __init__( ) manager: LLMEdgeManager = _prepare_for_llama_export(llm_config) self.model = manager.model.eval().to(device=self.device) + self.enable_dynamic_shape = llm_config.model.enable_dynamic_shape def forward( self, diff --git a/examples/models/llama/runner/generation.py b/examples/models/llama/runner/generation.py index 2baa8f5cd14..95fe47147db 100644 --- a/examples/models/llama/runner/generation.py +++ b/examples/models/llama/runner/generation.py @@ -72,6 +72,7 @@ def __init__( self.max_seq_len = max_seq_len self.max_batch_size = max_batch_size self.use_kv_cache = use_kv_cache + self.enable_dynamic_shape = True self.tokenizer = get_tokenizer(tokenizer_path, tokenizer_config_path) self.device = device # For some models like qwen, mismatch is acceptable: https://github.com/QwenLM/Qwen2.5/issues/466#issuecomment-2146759706 @@ -88,6 +89,45 @@ def forward( ) -> torch.Tensor: pass + def _prefill_with_kv_cache( + self, + prompt_tokens: List[int], + pos_base: int, + ) -> torch.Tensor: + if not self.enable_dynamic_shape and len(prompt_tokens) > 1: + return self._sequential_kv_prefill(prompt_tokens, pos_base) + + try: + return self.forward( + tokens=torch.tensor( + [prompt_tokens], dtype=torch.long, device=self.device + ), + input_pos=torch.tensor([pos_base], dtype=torch.long, device=self.device), + ) + except RuntimeError: + # Some exported models use a static single-token shape for kv-cache mode. + # Fall back to sequential token prefill so multi-token prompts still work. + if self.enable_dynamic_shape or len(prompt_tokens) <= 1: + raise + + return self._sequential_kv_prefill(prompt_tokens, pos_base) + + def _sequential_kv_prefill( + self, + prompt_tokens: List[int], + pos_base: int, + ) -> torch.Tensor: + logits = None + for offset, token in enumerate(prompt_tokens): + logits = self.forward( + tokens=torch.tensor([[token]], dtype=torch.long, device=self.device), + input_pos=torch.tensor( + [pos_base + offset], dtype=torch.long, device=self.device + ), + ) + assert logits is not None + return logits + def generate( # noqa: C901 self, prompt_tokens: List[int], @@ -99,14 +139,14 @@ def generate( # noqa: C901 ) -> List[int]: # Prefill prefill_start = time.time() - logits = self.forward( - tokens=torch.tensor([prompt_tokens], dtype=torch.long, device=self.device), - input_pos=( - torch.tensor([pos_base], dtype=torch.long, device=self.device) - if self.use_kv_cache - else None - ), - ) + if self.use_kv_cache: + logits = self._prefill_with_kv_cache(prompt_tokens, pos_base) + else: + logits = self.forward( + tokens=torch.tensor( + [prompt_tokens], dtype=torch.long, device=self.device + ), + ) prefill_time = time.time() - prefill_start current_token = next_token(logits, temperature, top_p) diff --git a/examples/models/llama/runner/native.py b/examples/models/llama/runner/native.py index 6d5d4730844..ffb19ab3c08 100644 --- a/examples/models/llama/runner/native.py +++ b/examples/models/llama/runner/native.py @@ -44,6 +44,13 @@ def __init__(self, args): vocab_size=params["vocab_size"], ) self.model = _load_for_executorch(args.pte) + try: + self.enable_dynamic_shape = bool( + self.model.run_method("enable_dynamic_shape")[0] + ) + except Exception: + # Keep default behavior when metadata method is unavailable. + pass def forward( self, diff --git a/examples/models/llama/tests/BUCK b/examples/models/llama/tests/BUCK index 431c3c92814..e0dae1147aa 100644 --- a/examples/models/llama/tests/BUCK +++ b/examples/models/llama/tests/BUCK @@ -115,3 +115,14 @@ fbcode_target(_kind = python_unittest, "//executorch/extension/pybindings:portable_lib", ], ) + +fbcode_target(_kind = python_unittest, + name = "test_generation_prefill", + srcs = [ + "test_generation_prefill.py", + ], + deps = [ + "//caffe2:torch", + "//executorch/examples/models/llama/runner:eager_runner_library", + ], +) diff --git a/examples/models/llama/tests/test_generation_prefill.py b/examples/models/llama/tests/test_generation_prefill.py new file mode 100644 index 00000000000..b3cd3f68eb1 --- /dev/null +++ b/examples/models/llama/tests/test_generation_prefill.py @@ -0,0 +1,99 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import unittest +from unittest.mock import patch + +import torch + +from executorch.examples.models.llama.runner.generation import LlamaRunner + + +class _DummyTokenizer: + n_words = 100 + eos_id = 2 + stop_tokens = [2] + + def encode(self, _text, bos=False, eos=False): + del bos + del eos + return [10, 11] + + def decode_token(self, token_id): + return str(token_id) + + +class _DummyRunner(LlamaRunner): + def __init__(self, raise_on_parallel_prefill=False): + self.calls = [] + self.raise_on_parallel_prefill = raise_on_parallel_prefill + super().__init__( + tokenizer_path="unused", + tokenizer_config_path=None, + max_seq_len=16, + max_batch_size=1, + use_kv_cache=True, + vocab_size=100, + device="cpu", + ) + + def forward(self, tokens: torch.Tensor, input_pos=None) -> torch.Tensor: + self.calls.append( + (tokens.clone(), input_pos.clone() if input_pos is not None else None) + ) + if self.raise_on_parallel_prefill and tokens.shape[1] > 1: + raise RuntimeError("parallel prefill failure") + return torch.zeros((1, 8), dtype=torch.float32) + + +class TestGenerationPrefill(unittest.TestCase): + @patch( + "executorch.examples.models.llama.runner.generation.get_tokenizer", + return_value=_DummyTokenizer(), + ) + def test_static_prefill_uses_sequential_tokens(self, _mock_get_tokenizer): + runner = _DummyRunner() + runner.enable_dynamic_shape = False + + runner._prefill_with_kv_cache([5, 6, 7], pos_base=3) + + self.assertEqual(len(runner.calls), 3) + for i, (tokens, input_pos) in enumerate(runner.calls): + self.assertEqual(tuple(tokens.shape), (1, 1)) + self.assertEqual(tokens.item(), 5 + i) + self.assertEqual(input_pos.item(), 3 + i) + + @patch( + "executorch.examples.models.llama.runner.generation.get_tokenizer", + return_value=_DummyTokenizer(), + ) + def test_dynamic_prefill_uses_batched_prompt(self, _mock_get_tokenizer): + runner = _DummyRunner() + runner.enable_dynamic_shape = True + + runner._prefill_with_kv_cache([5, 6, 7], pos_base=4) + + self.assertEqual(len(runner.calls), 1) + tokens, input_pos = runner.calls[0] + self.assertEqual(tuple(tokens.shape), (1, 3)) + self.assertEqual(input_pos.item(), 4) + + @patch( + "executorch.examples.models.llama.runner.generation.get_tokenizer", + return_value=_DummyTokenizer(), + ) + def test_dynamic_prefill_does_not_mask_runtime_errors(self, _mock_get_tokenizer): + runner = _DummyRunner(raise_on_parallel_prefill=True) + runner.enable_dynamic_shape = True + + with self.assertRaisesRegex(RuntimeError, "parallel prefill failure"): + runner._prefill_with_kv_cache([5, 6], pos_base=0) + + self.assertEqual(len(runner.calls), 1) + + +if __name__ == "__main__": + unittest.main() diff --git a/examples/models/qwen3_5/README.md b/examples/models/qwen3_5/README.md index f1de77b438a..e439f256f2a 100644 --- a/examples/models/qwen3_5/README.md +++ b/examples/models/qwen3_5/README.md @@ -1,5 +1,5 @@ ## Summary -[Qwen3.5](https://huggingface.co/Qwen/Qwen3.5-4B) support in ExecuTorch is exported through the Llama example pipeline with a hybrid layer layout: +[Qwen3.5](https://huggingface.co/collections/Qwen/qwen35-684357f8543f83de2f09f998) support in ExecuTorch is exported through the Llama example pipeline with a hybrid layer layout: - `full_attention` layers use gated full attention. - `linear_attention` layers use Gated DeltaNet with internal recurrent state. @@ -32,10 +32,14 @@ python -m extension.llm.export.export_llm \ ``` The exporter will download and convert HF weights automatically when `+base.checkpoint` is not provided. +Install `safetensors` in your environment if it is missing: +```bash +python -m pip install safetensors +``` ## Run (Python Runner) ```bash -python -m examples.models.llama.runner.native \ +python -m executorch.examples.models.llama.runner.native \ --model qwen3_5_0_8b \ --pte qwen3_5_0_8b_fp32.pte \ --tokenizer /path/to/tokenizer.json \ @@ -50,3 +54,6 @@ python -m examples.models.llama.runner.native \ ## Notes - Current path targets CPU/XNNPACK export validation and runner compatibility. - `q8da4w` quantization for Qwen3.5 is intentionally deferred to a follow-up. +- Dynamic-shape export is not enabled yet for Qwen3.5 DeltaNet layers in this path; keep `enable_dynamic_shape=False`. +- For static-shape exports, `runner.native` falls back to sequential token prefill for multi-token prompts. +- Default metadata uses Qwen3.5 special token ids: `get_bos_id=248045`, `get_eos_ids=[248046,248044]`. diff --git a/examples/models/qwen3_5/config/qwen3_5_xnnpack_fp32.yaml b/examples/models/qwen3_5/config/qwen3_5_xnnpack_fp32.yaml index 496f2da6e2b..93882f67a6e 100644 --- a/examples/models/qwen3_5/config/qwen3_5_xnnpack_fp32.yaml +++ b/examples/models/qwen3_5/config/qwen3_5_xnnpack_fp32.yaml @@ -1,5 +1,9 @@ base: - metadata: '{"get_bos_id": 151644, "get_eos_ids":[151645]}' + # Qwen3.5 tokenizer special ids: + # <|im_start|> 248045 + # <|im_end|> 248046 + # <|endoftext|> 248044 + metadata: '{"get_bos_id": 248045, "get_eos_ids":[248046,248044]}' model: use_kv_cache: True From f686ef5f3527e311c9438629ed629adf31774188 Mon Sep 17 00:00:00 2001 From: Phineas1500 <41450967+Phineas1500@users.noreply.github.com> Date: Mon, 2 Mar 2026 19:55:35 -0500 Subject: [PATCH 06/11] Harden Qwen3.5 conversion and RMSNorm dtype behavior --- examples/models/llama/norm.py | 2 +- .../llama/tests/test_qwen3_5_attention.py | 27 ++++++++++ examples/models/qwen3_5/convert_weights.py | 51 +++++++++++++++---- .../qwen3_5/tests/test_convert_weights.py | 25 +++++++++ 4 files changed, 94 insertions(+), 11 deletions(-) diff --git a/examples/models/llama/norm.py b/examples/models/llama/norm.py index c56cc61c7a5..a9944db7a00 100644 --- a/examples/models/llama/norm.py +++ b/examples/models/llama/norm.py @@ -55,4 +55,4 @@ def forward(self, x): output = self._norm(x.float()).type_as(x) if self.add_unit_offset: return output * (1.0 + self.weight.float()).type_as(x) - return output * self.weight + return output * self.weight.type_as(x) diff --git a/examples/models/llama/tests/test_qwen3_5_attention.py b/examples/models/llama/tests/test_qwen3_5_attention.py index b076e69d121..9de9244f960 100644 --- a/examples/models/llama/tests/test_qwen3_5_attention.py +++ b/examples/models/llama/tests/test_qwen3_5_attention.py @@ -9,10 +9,37 @@ import torch from executorch.examples.models.llama.attention import ATTENTION_REGISTRY from executorch.examples.models.llama.model_args import ModelArgs +from executorch.examples.models.llama.norm import RMSNorm from executorch.examples.models.llama.rope import Rope class Qwen35AttentionTest(unittest.TestCase): + def test_qwen35_full_attention_output_proj_is_bias_free(self): + args = ModelArgs( + dim=32, + n_layers=1, + n_heads=4, + n_kv_heads=2, + head_dim=8, + hidden_dim=64, + max_seq_len=16, + max_context_len=16, + use_kv_cache=False, + use_qk_norm=False, + qk_norm_before_rope=True, + attention_type="qwen3_5_full", + attention_qkv_bias=True, + ) + rope = Rope(args) + attn = ATTENTION_REGISTRY["qwen3_5_full"](args, 0, rope) + self.assertIsNone(attn.wo.bias) + + def test_rmsnorm_preserves_input_dtype_without_unit_offset(self): + norm = RMSNorm(dim=8, add_unit_offset=False) + x = torch.randn(2, 3, 8, dtype=torch.bfloat16) + y = norm(x) + self.assertEqual(y.dtype, x.dtype) + def test_qwen35_full_attention_forward_shape(self): torch.manual_seed(0) args = ModelArgs( diff --git a/examples/models/qwen3_5/convert_weights.py b/examples/models/qwen3_5/convert_weights.py index 3566d4264f5..b0a8fc47305 100644 --- a/examples/models/qwen3_5/convert_weights.py +++ b/examples/models/qwen3_5/convert_weights.py @@ -40,6 +40,34 @@ } +_IGNORED_UNMAPPED_PREFIXES = ( + "mtp.", + "model.visual.", + "visual.", +) + +_IGNORED_UNMAPPED_SUBSTRINGS = ( + ".vision_", + ".visual.", +) + +_IGNORED_UNMAPPED_SUFFIXES = ( + "rotary_emb.inv_freq", +) + + +def _should_ignore_unmapped_key(key: str, normalized_key: str) -> bool: + candidates = (key, normalized_key) + for candidate in candidates: + if any(candidate.startswith(prefix) for prefix in _IGNORED_UNMAPPED_PREFIXES): + return True + if any(token in candidate for token in _IGNORED_UNMAPPED_SUBSTRINGS): + return True + if any(candidate.endswith(suffix) for suffix in _IGNORED_UNMAPPED_SUFFIXES): + return True + return False + + def _load_checkpoint_from_safetensors(input_dir: str) -> Dict: from safetensors.torch import load_file @@ -98,6 +126,14 @@ def qwen_3_5_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Ten "model.language_model.", "model.", 1 ) + # Ignore non-language-model keys up front. + if not ( + normalized_key.startswith("model.") or normalized_key.startswith("lm_head.") + ): + if _should_ignore_unmapped_key(key, normalized_key): + continue + continue + # Legacy packed tensors (older checkpoints): # in_proj_qkvz -> split into in_proj_qkv and in_proj_z # in_proj_ba -> split into in_proj_b and in_proj_a @@ -110,17 +146,12 @@ def qwen_3_5_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Ten try: new_key = get_mapped_key(normalized_key, _QWEN_3_5_TO_META) - except Exception: - # Ignore non-text weights and training-only extras (e.g., MTP). - if ( - key.startswith("mtp.") - or key.startswith("model.visual.") - or ".vision_" in key - or key.startswith("visual.") - ): + except Exception as err: + if _should_ignore_unmapped_key(key, normalized_key): continue - # Ignore unsupported keys that are not required by the export model. - continue + raise ValueError( + f"Unexpected checkpoint key not mapped for Qwen3.5 export: {key}" + ) from err converted_state_dict[new_key] = value for key, value in pending_qkvz.items(): diff --git a/examples/models/qwen3_5/tests/test_convert_weights.py b/examples/models/qwen3_5/tests/test_convert_weights.py index 0e7d1d8c8d3..7fa2fb667c7 100644 --- a/examples/models/qwen3_5/tests/test_convert_weights.py +++ b/examples/models/qwen3_5/tests/test_convert_weights.py @@ -39,6 +39,31 @@ def test_maps_full_and_linear_attention_weights(self): self.assertIn("layers.1.attention.out_proj.weight", converted) self.assertIn("layers.1.attention.dt_bias", converted) + def test_raises_on_unexpected_text_key(self): + state_dict = { + "model.embed_tokens.weight": torch.randn(16, 8), + "model.norm.weight": torch.randn(8), + "model.layers.0.unknown.weight": torch.randn(8, 8), + } + + with self.assertRaisesRegex( + ValueError, "Unexpected checkpoint key not mapped for Qwen3.5 export" + ): + qwen_3_5_to_meta(state_dict) + + def test_ignores_known_non_text_keys(self): + state_dict = { + "model.embed_tokens.weight": torch.randn(16, 8), + "model.norm.weight": torch.randn(8), + "mtp.proj.weight": torch.randn(8, 8), + "model.visual.patch_embed.weight": torch.randn(8, 8), + } + + converted = qwen_3_5_to_meta(state_dict) + self.assertIn("tok_embeddings.weight", converted) + self.assertIn("output.weight", converted) + self.assertNotIn("mtp.proj.weight", converted) + if __name__ == "__main__": unittest.main() From fa6ae158a5bf16ddb07fd16f1a91e9d66a9d7787 Mon Sep 17 00:00:00 2001 From: Phineas1500 <41450967+Phineas1500@users.noreply.github.com> Date: Tue, 3 Mar 2026 17:04:18 -0500 Subject: [PATCH 07/11] Refactor Qwen3.5 full attention into MHA and move RMSNormGated --- examples/models/llama/attention.py | 185 ++---------------- examples/models/llama/model_args.py | 5 + examples/models/llama/norm.py | 17 ++ .../llama/tests/test_qwen3_5_attention.py | 28 ++- .../models/qwen3_5/config/0_8b_config.json | 3 +- examples/models/qwen3_5/config/2b_config.json | 3 +- examples/models/qwen3_5/config/4b_config.json | 3 +- 7 files changed, 72 insertions(+), 172 deletions(-) diff --git a/examples/models/llama/attention.py b/examples/models/llama/attention.py index 2d84e4cc5ce..a7b7104ef8e 100644 --- a/examples/models/llama/attention.py +++ b/examples/models/llama/attention.py @@ -7,7 +7,7 @@ import torch.nn.functional as F from executorch.examples.models.llama.lora import LoRALinear from executorch.examples.models.llama.model_args import ModelArgs -from executorch.examples.models.llama.norm import RMSNorm +from executorch.examples.models.llama.norm import RMSNorm, RMSNormGated from executorch.examples.models.llama.rope import Rope @@ -314,6 +314,7 @@ def update( return self.k_cache, self.v_cache +@register_attention("qwen3_5_full") @register_attention("mha") class AttentionMHA(Attention): def __init__( @@ -347,7 +348,9 @@ def __init__( self.attention_qkv_bias = args.attention_qkv_bias self.use_qk_norm = args.use_qk_norm self.qk_norm_before_rope = args.qk_norm_before_rope + self.use_q_gate = args.use_q_gate self.enable_dynamic_shape = args.enable_dynamic_shape + q_out_dim = self.n_heads * self.head_dim * (2 if self.use_q_gate else 1) if self.use_qk_norm: q_norm_dim = self.head_dim @@ -366,7 +369,7 @@ def __init__( self.wq = ( LoRALinear( in_dim=args.dim, - out_dim=args.n_heads * args.head_dim, + out_dim=q_out_dim, rank=args.r, alpha=args.lora_alpha, dropout=0.0, @@ -374,7 +377,7 @@ def __init__( ) if args.target_modules is not None and "q_proj" in args.target_modules else nn.Linear( - self.dim, self.n_heads * self.head_dim, bias=self.attention_qkv_bias + self.dim, q_out_dim, bias=self.attention_qkv_bias ) ) self.wk = ( @@ -460,10 +463,17 @@ def forward( input_pos = kwargs.get("input_pos") bsz, seqlen, _ = x.shape - # QKV - q, k, v = self.wq(x), self.wk(x), self.wv(x) - # We need view_copy elimination - q = q.view(bsz, seqlen, self.n_local_heads, self.head_dim) + if self.use_q_gate: + q_and_gate = self.wq(x).view( + bsz, seqlen, self.n_local_heads, self.head_dim * 2 + ) + q, gate = torch.chunk(q_and_gate, 2, dim=-1) + gate = gate.reshape(bsz, seqlen, -1) + else: + q = self.wq(x).view(bsz, seqlen, self.n_local_heads, self.head_dim) + gate = None + + k, v = self.wk(x), self.wv(x) k = k.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) v = v.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) @@ -500,6 +510,8 @@ def forward( input_pos[0].item(), seqlen ) output = self.SDPA(input_pos, q, k, v, bsz, seqlen, attn_mask) + if gate is not None: + output = output * torch.sigmoid(gate) return self.wo(output), None # grouped multiquery attention: expand out keys and values @@ -513,6 +525,8 @@ def forward( output = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) output = output.transpose(1, 2).reshape(bsz, seqlen, -1) + if gate is not None: + output = output * torch.sigmoid(gate) output = self.wo(output) @@ -524,163 +538,6 @@ def _l2norm(x: torch.Tensor, dim: int = -1, eps: float = 1e-6) -> torch.Tensor: return x * inv_norm -class RMSNormGated(nn.Module): - def __init__(self, hidden_size: int, eps: float = 1e-6): - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states: torch.Tensor, gate: torch.Tensor) -> torch.Tensor: - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - hidden_states = self.weight * hidden_states.to(input_dtype) - hidden_states = hidden_states * F.silu(gate.to(torch.float32)) - return hidden_states.to(input_dtype) - - -@register_attention("qwen3_5_full") -class AttentionQwen3_5Full(Attention): - """Qwen3.5 full-attention block with q-gating.""" - - def __init__( - self, - args: ModelArgs, - layer_id: int, - rope: Rope, - **_kwargs: Any, - ): - super().__init__() - self.use_kv_cache = args.use_kv_cache - self.n_heads = args.n_heads - self.n_kv_heads = self.n_heads if args.n_kv_heads is None else args.n_kv_heads - assert self.n_heads % self.n_kv_heads == 0 - self.n_local_heads = self.n_heads - self.n_local_kv_heads = self.n_kv_heads - self.n_rep = self.n_local_heads // self.n_local_kv_heads - self.head_dim = args.head_dim - self.max_context_len = args.max_context_len - self.dim = args.dim - self.attention_qkv_bias = args.attention_qkv_bias - self.use_qk_norm = args.use_qk_norm - self.qk_norm_before_rope = args.qk_norm_before_rope - self.enable_dynamic_shape = args.enable_dynamic_shape - - if self.use_qk_norm: - self.q_norm_fn = RMSNorm( - self.head_dim, - eps=args.norm_eps, - add_unit_offset=args.rms_norm_add_unit_offset, - ) - self.k_norm_fn = RMSNorm( - self.head_dim, - eps=args.norm_eps, - add_unit_offset=args.rms_norm_add_unit_offset, - ) - - self.wq = nn.Linear( - self.dim, self.n_heads * self.head_dim * 2, bias=self.attention_qkv_bias - ) - self.wk = nn.Linear( - self.dim, self.n_kv_heads * self.head_dim, bias=self.attention_qkv_bias - ) - self.wv = nn.Linear( - self.dim, self.n_kv_heads * self.head_dim, bias=self.attention_qkv_bias - ) - self.wo = nn.Linear( - self.n_heads * self.head_dim, - self.dim, - bias=False, - ) - - self.layer_id = layer_id - self.rope = rope - - causal_mask = torch.tril( - torch.ones( - self.max_context_len, - self.max_context_len, - dtype=torch.bool, - device="cpu", - ) - ) - self.register_buffer("mask", causal_mask, persistent=False) - - if self.use_kv_cache: - self.kv_cache = KVCache( - args.max_batch_size, - args.max_context_len, - self.n_kv_heads, - self.head_dim, - args.enable_dynamic_shape, - ) - self.SDPA = SDPA( - dim=self.n_local_heads * self.head_dim, - head_dim=self.head_dim, - n_rep=self.n_rep, - max_context_len=self.max_context_len, - ) - - def forward( - self, - x: torch.Tensor, - freqs_cos: torch.Tensor, - freqs_sin: torch.Tensor, - **kwargs: ForwardOptions, - ) -> Tuple[torch.Tensor, Optional[Any]]: - input_pos = kwargs.get("input_pos") - bsz, seqlen, _ = x.shape - - # Q and gate are packed in q_proj output. - q_and_gate = self.wq(x).view(bsz, seqlen, self.n_local_heads, self.head_dim * 2) - q, gate = torch.chunk(q_and_gate, 2, dim=-1) - gate = gate.reshape(bsz, seqlen, -1) - - k = self.wk(x).view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) - v = self.wv(x).view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) - - if self.use_qk_norm and self.qk_norm_before_rope: - q = self.q_norm_fn(q) - k = self.k_norm_fn(k) - - q, k = self.rope.forward(q, k, freqs_cos, freqs_sin) - q = q.transpose(1, 2) - k = k.transpose(1, 2) - v = v.transpose(1, 2) - - if self.use_qk_norm and not self.qk_norm_before_rope: - q = self.q_norm_fn(q) - k = self.k_norm_fn(k) - - if self.use_kv_cache: - assert input_pos is not None - if self.enable_dynamic_shape: - start_pos = input_pos[-1].item() - torch._check_is_size(start_pos) - torch._check(start_pos < self.max_context_len) - seq_length = q.size(2) - attn_mask = self.mask.narrow(0, start_pos, seq_length) - else: - attn_mask = self.mask[input_pos] - k, v = self.kv_cache.update(input_pos, k, v) - if getattr(self.kv_cache, "is_ring_buffer", False): - attn_mask = self.kv_cache.create_causal_mask_for_ring_buffer( - input_pos[0].item(), seqlen - ) - output = self.SDPA(input_pos, q, k, v, bsz, seqlen, attn_mask) - output = output * torch.sigmoid(gate) - return self.wo(output), None - - k = k.repeat_interleave(self.n_rep, dim=1) - v = v.repeat_interleave(self.n_rep, dim=1) - mask = self.mask[:seqlen, :seqlen] - output = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) - output = output.transpose(1, 2).reshape(bsz, seqlen, -1) - output = output * torch.sigmoid(gate) - return self.wo(output), None - - @register_attention("gated_deltanet") class AttentionGatedDeltaNet(Attention): """Qwen3.5 linear-attention (Gated DeltaNet) block with internal state.""" diff --git a/examples/models/llama/model_args.py b/examples/models/llama/model_args.py index eb2a83db6ed..05e9ea62a8a 100644 --- a/examples/models/llama/model_args.py +++ b/examples/models/llama/model_args.py @@ -72,6 +72,7 @@ class ModelArgs: num_experts: int = 8 # Number of experts num_activated_experts: int = 2 # Number of experts to activate attention_type: str = "mha" # Attention type, registered in attention.py + use_q_gate: bool = False # Use q-gated projection in attention (Qwen3.5 full attention) norm_type: str = "rmsnorm" # Normalization type, registered in norm.py act_fn: ActFn = dataclasses.field(default=ActFn.SILU) # Activation function type attention_qkv_bias: bool = False @@ -156,6 +157,10 @@ def __post_init__(self): if self.n_kv_heads is None: self.n_kv_heads = self.n_heads + # Backward compatibility: qwen3_5_full attention name implies q-gated MHA. + if self.attention_type == "qwen3_5_full": + self.use_q_gate = True + # rope_theta overrides rope_freq_base since it's the official name. if self.rope_theta is not None: self.rope_freq_base = self.rope_theta diff --git a/examples/models/llama/norm.py b/examples/models/llama/norm.py index a9944db7a00..6a8217b5a24 100644 --- a/examples/models/llama/norm.py +++ b/examples/models/llama/norm.py @@ -5,6 +5,7 @@ # LICENSE file in the root directory of this source tree. import torch +import torch.nn.functional as F from torch import nn @@ -56,3 +57,19 @@ def forward(self, x): if self.add_unit_offset: return output * (1.0 + self.weight.float()).type_as(x) return output * self.weight.type_as(x) + + +class RMSNormGated(nn.Module): + def __init__(self, hidden_size: int, eps: float = 1e-6): + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states: torch.Tensor, gate: torch.Tensor) -> torch.Tensor: + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + hidden_states = self.weight * hidden_states.to(input_dtype) + hidden_states = hidden_states * F.silu(gate.to(torch.float32)) + return hidden_states.to(input_dtype) diff --git a/examples/models/llama/tests/test_qwen3_5_attention.py b/examples/models/llama/tests/test_qwen3_5_attention.py index 9de9244f960..d5598eb4196 100644 --- a/examples/models/llama/tests/test_qwen3_5_attention.py +++ b/examples/models/llama/tests/test_qwen3_5_attention.py @@ -27,11 +27,12 @@ def test_qwen35_full_attention_output_proj_is_bias_free(self): use_kv_cache=False, use_qk_norm=False, qk_norm_before_rope=True, - attention_type="qwen3_5_full", + attention_type="mha", + use_q_gate=True, attention_qkv_bias=True, ) rope = Rope(args) - attn = ATTENTION_REGISTRY["qwen3_5_full"](args, 0, rope) + attn = ATTENTION_REGISTRY["mha"](args, 0, rope) self.assertIsNone(attn.wo.bias) def test_rmsnorm_preserves_input_dtype_without_unit_offset(self): @@ -56,16 +57,32 @@ def test_qwen35_full_attention_forward_shape(self): partial_rotary_factor=0.5, use_qk_norm=True, qk_norm_before_rope=True, - attention_type="qwen3_5_full", + attention_type="mha", + use_q_gate=True, rms_norm_add_unit_offset=True, ) rope = Rope(args) - attn = ATTENTION_REGISTRY["qwen3_5_full"](args, 0, rope) + attn = ATTENTION_REGISTRY["mha"](args, 0, rope) x = torch.randn(1, 3, args.dim) freqs_cos, freqs_sin = rope.get_freqs(None, x.shape[1]) y, _ = attn(x, freqs_cos, freqs_sin) self.assertEqual(y.shape, x.shape) + def test_qwen35_full_attention_legacy_name_maps_to_gated_mha(self): + args = ModelArgs( + dim=32, + n_layers=1, + n_heads=4, + n_kv_heads=2, + head_dim=8, + hidden_dim=64, + attention_type="qwen3_5_full", + ) + self.assertTrue(args.use_q_gate) + rope = Rope(args) + attn = ATTENTION_REGISTRY["qwen3_5_full"](args, 0, rope) + self.assertTrue(attn.use_q_gate) + def test_gated_deltanet_resets_state_on_new_sequence(self): torch.manual_seed(0) args = ModelArgs( @@ -78,7 +95,8 @@ def test_gated_deltanet_resets_state_on_new_sequence(self): max_seq_len=16, max_context_len=16, use_kv_cache=True, - attention_type="qwen3_5_full", + attention_type="mha", + use_q_gate=True, linear_conv_kernel_dim=4, linear_key_head_dim=4, linear_value_head_dim=4, diff --git a/examples/models/qwen3_5/config/0_8b_config.json b/examples/models/qwen3_5/config/0_8b_config.json index e532988410d..89799cf9f53 100644 --- a/examples/models/qwen3_5/config/0_8b_config.json +++ b/examples/models/qwen3_5/config/0_8b_config.json @@ -14,7 +14,8 @@ "attention_qkv_bias": false, "use_qk_norm": true, "qk_norm_before_rope": true, - "attention_type": "qwen3_5_full", + "attention_type": "mha", + "use_q_gate": true, "rms_norm_add_unit_offset": true, "linear_conv_kernel_dim": 4, "linear_key_head_dim": 128, diff --git a/examples/models/qwen3_5/config/2b_config.json b/examples/models/qwen3_5/config/2b_config.json index e397bf6fd4a..15f62ccca74 100644 --- a/examples/models/qwen3_5/config/2b_config.json +++ b/examples/models/qwen3_5/config/2b_config.json @@ -14,7 +14,8 @@ "attention_qkv_bias": false, "use_qk_norm": true, "qk_norm_before_rope": true, - "attention_type": "qwen3_5_full", + "attention_type": "mha", + "use_q_gate": true, "rms_norm_add_unit_offset": true, "linear_conv_kernel_dim": 4, "linear_key_head_dim": 128, diff --git a/examples/models/qwen3_5/config/4b_config.json b/examples/models/qwen3_5/config/4b_config.json index 75c4313b83a..068653a8d5c 100644 --- a/examples/models/qwen3_5/config/4b_config.json +++ b/examples/models/qwen3_5/config/4b_config.json @@ -14,7 +14,8 @@ "attention_qkv_bias": false, "use_qk_norm": true, "qk_norm_before_rope": true, - "attention_type": "qwen3_5_full", + "attention_type": "mha", + "use_q_gate": true, "rms_norm_add_unit_offset": true, "linear_conv_kernel_dim": 4, "linear_key_head_dim": 128, From da085adda6dce76094b4cd14ecd08007b30cc72a Mon Sep 17 00:00:00 2001 From: Sriram Kiron Date: Wed, 4 Mar 2026 18:46:03 -0500 Subject: [PATCH 08/11] Harden Qwen3.5 conversion key handling and tests --- examples/models/llama/norm.py | 2 + examples/models/qwen3_5/convert_weights.py | 20 ++++--- .../qwen3_5/tests/test_convert_weights.py | 54 +++++++++++++++++++ 3 files changed, 69 insertions(+), 7 deletions(-) diff --git a/examples/models/llama/norm.py b/examples/models/llama/norm.py index 6a8217b5a24..0189c88b13b 100644 --- a/examples/models/llama/norm.py +++ b/examples/models/llama/norm.py @@ -17,6 +17,8 @@ def __init__(self, dim: int, eps: float = 1e-6, add_unit_offset: bool = False): Args: dim (int): The dimension of the input tensor. eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. + add_unit_offset (bool, optional): Whether to scale normalized output by + `(1 + weight)` instead of `weight`. Default is False. Attributes: eps (float): A small value added to the denominator for numerical stability. diff --git a/examples/models/qwen3_5/convert_weights.py b/examples/models/qwen3_5/convert_weights.py index b0a8fc47305..d08ec441f0a 100644 --- a/examples/models/qwen3_5/convert_weights.py +++ b/examples/models/qwen3_5/convert_weights.py @@ -32,7 +32,6 @@ "model.layers.{}.linear_attn.in_proj_b.weight": "layers.{}.attention.in_proj_b.weight", "model.layers.{}.linear_attn.in_proj_a.weight": "layers.{}.attention.in_proj_a.weight", "model.layers.{}.linear_attn.conv1d.weight": "layers.{}.attention.conv1d.weight", - "model.layers.{}.linear_attn.conv1d.bias": "layers.{}.attention.conv1d.bias", "model.layers.{}.linear_attn.dt_bias": "layers.{}.attention.dt_bias", "model.layers.{}.linear_attn.A_log": "layers.{}.attention.A_log", "model.layers.{}.linear_attn.norm.weight": "layers.{}.attention.norm.weight", @@ -53,6 +52,7 @@ _IGNORED_UNMAPPED_SUFFIXES = ( "rotary_emb.inv_freq", + "linear_attn.conv1d.bias", ) @@ -78,13 +78,17 @@ def _load_checkpoint_from_safetensors(input_dir: str) -> Dict: weight_map = index["weight_map"] checkpoint_shards = sorted(set(weight_map.values())) - shard_to_weights = {} - for shard in checkpoint_shards: - shard_to_weights[shard] = load_file(os.path.join(input_dir, shard)) - merged_state_dict = {} + shard_to_weight_names = {} for weight_name, shard in weight_map.items(): - merged_state_dict[weight_name] = shard_to_weights[shard][weight_name] + shard_to_weight_names.setdefault(shard, []).append(weight_name) + + # Load each shard once and copy only the tensor names mapped to that shard. + # This avoids holding all shard tensors in memory at the same time. + for shard in checkpoint_shards: + shard_weights = load_file(os.path.join(input_dir, shard)) + for weight_name in shard_to_weight_names[shard]: + merged_state_dict[weight_name] = shard_weights[weight_name] return merged_state_dict model_path = os.path.join(input_dir, "model.safetensors") @@ -132,7 +136,9 @@ def qwen_3_5_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Ten ): if _should_ignore_unmapped_key(key, normalized_key): continue - continue + raise ValueError( + f"Unexpected non-text checkpoint key not mapped for Qwen3.5 export: {key}" + ) # Legacy packed tensors (older checkpoints): # in_proj_qkvz -> split into in_proj_qkv and in_proj_z diff --git a/examples/models/qwen3_5/tests/test_convert_weights.py b/examples/models/qwen3_5/tests/test_convert_weights.py index 7fa2fb667c7..52b6981c7fb 100644 --- a/examples/models/qwen3_5/tests/test_convert_weights.py +++ b/examples/models/qwen3_5/tests/test_convert_weights.py @@ -64,6 +64,60 @@ def test_ignores_known_non_text_keys(self): self.assertIn("output.weight", converted) self.assertNotIn("mtp.proj.weight", converted) + def test_raises_on_unexpected_non_text_key(self): + state_dict = { + "model.embed_tokens.weight": torch.randn(16, 8), + "model.norm.weight": torch.randn(8), + "vision_tower.blocks.0.weight": torch.randn(8, 8), + } + + with self.assertRaisesRegex( + ValueError, + "Unexpected non-text checkpoint key not mapped for Qwen3.5 export", + ): + qwen_3_5_to_meta(state_dict) + + def test_ignores_linear_attention_conv1d_bias(self): + state_dict = { + "model.embed_tokens.weight": torch.randn(16, 8), + "model.norm.weight": torch.randn(8), + "model.layers.1.linear_attn.conv1d.weight": torch.randn(24, 1, 4), + "model.layers.1.linear_attn.conv1d.bias": torch.randn(24), + "model.layers.1.linear_attn.out_proj.weight": torch.randn(8, 8), + } + + converted = qwen_3_5_to_meta(state_dict) + self.assertIn("layers.1.attention.conv1d.weight", converted) + self.assertIn("layers.1.attention.out_proj.weight", converted) + self.assertNotIn("layers.1.attention.conv1d.bias", converted) + + def test_splits_legacy_packed_linear_attention_weights(self): + qkvz = torch.arange(32 * 8, dtype=torch.float32).reshape(32, 8) + ba = torch.arange(4 * 8, dtype=torch.float32).reshape(4, 8) + out_proj = torch.randn(8, 8) + state_dict = { + "model.embed_tokens.weight": torch.randn(16, 8), + "model.norm.weight": torch.randn(8), + "model.layers.1.linear_attn.out_proj.weight": out_proj, + "model.layers.1.linear_attn.in_proj_qkvz.weight": qkvz, + "model.layers.1.linear_attn.in_proj_ba.weight": ba, + } + + converted = qwen_3_5_to_meta(state_dict) + + self.assertTrue( + torch.equal(converted["layers.1.attention.in_proj_qkv.weight"], qkvz[:24]) + ) + self.assertTrue( + torch.equal(converted["layers.1.attention.in_proj_z.weight"], qkvz[24:]) + ) + self.assertTrue( + torch.equal(converted["layers.1.attention.in_proj_b.weight"], ba[:2]) + ) + self.assertTrue( + torch.equal(converted["layers.1.attention.in_proj_a.weight"], ba[2:]) + ) + if __name__ == "__main__": unittest.main() From 38e993a7a90adee72e1141c7c1dc0f114199ff2d Mon Sep 17 00:00:00 2001 From: Sriram Kiron Date: Wed, 4 Mar 2026 18:47:17 -0500 Subject: [PATCH 09/11] Fix prefill fallback and narrow dynamic-shape metadata errors --- examples/models/llama/runner/generation.py | 16 ++++++++-- examples/models/llama/runner/native.py | 15 ++++++++- .../llama/tests/test_generation_prefill.py | 32 +++++++++++++++++-- 3 files changed, 58 insertions(+), 5 deletions(-) diff --git a/examples/models/llama/runner/generation.py b/examples/models/llama/runner/generation.py index 95fe47147db..730cc6e070c 100644 --- a/examples/models/llama/runner/generation.py +++ b/examples/models/llama/runner/generation.py @@ -104,14 +104,26 @@ def _prefill_with_kv_cache( ), input_pos=torch.tensor([pos_base], dtype=torch.long, device=self.device), ) - except RuntimeError: + except RuntimeError as err: # Some exported models use a static single-token shape for kv-cache mode. # Fall back to sequential token prefill so multi-token prompts still work. - if self.enable_dynamic_shape or len(prompt_tokens) <= 1: + if ( + len(prompt_tokens) <= 1 + or not self._looks_like_static_prefill_shape_error(err) + ): raise return self._sequential_kv_prefill(prompt_tokens, pos_base) + @staticmethod + def _looks_like_static_prefill_shape_error(err: RuntimeError) -> bool: + err_msg = str(err).lower() + shape_tokens = ("shape", "size", "dimension", "dim") + prefill_tokens = ("seq", "sequence", "token", "input") + return any(token in err_msg for token in shape_tokens) and any( + token in err_msg for token in prefill_tokens + ) + def _sequential_kv_prefill( self, prompt_tokens: List[int], diff --git a/examples/models/llama/runner/native.py b/examples/models/llama/runner/native.py index ffb19ab3c08..9cff422f09b 100644 --- a/examples/models/llama/runner/native.py +++ b/examples/models/llama/runner/native.py @@ -27,6 +27,13 @@ from executorch.kernels import quantized # noqa +def _is_missing_method_error(err: RuntimeError, method_name: str) -> bool: + err_msg = str(err).lower() + method_name = method_name.lower() + missing_tokens = ("not found", "does not exist", "missing") + return method_name in err_msg and any(token in err_msg for token in missing_tokens) + + class NativeLlamaRunner(LlamaRunner): """ Runs llama via ExecuTorch with provided pte file. @@ -48,9 +55,15 @@ def __init__(self, args): self.enable_dynamic_shape = bool( self.model.run_method("enable_dynamic_shape")[0] ) - except Exception: + except AttributeError: # Keep default behavior when metadata method is unavailable. pass + except RuntimeError as err: + if _is_missing_method_error(err, "enable_dynamic_shape"): + # Keep default behavior when metadata method is unavailable. + pass + else: + raise def forward( self, diff --git a/examples/models/llama/tests/test_generation_prefill.py b/examples/models/llama/tests/test_generation_prefill.py index b3cd3f68eb1..34b5dd185ea 100644 --- a/examples/models/llama/tests/test_generation_prefill.py +++ b/examples/models/llama/tests/test_generation_prefill.py @@ -27,9 +27,14 @@ def decode_token(self, token_id): class _DummyRunner(LlamaRunner): - def __init__(self, raise_on_parallel_prefill=False): + def __init__( + self, + raise_on_parallel_prefill: bool = False, + parallel_prefill_error: str = "parallel prefill failure", + ): self.calls = [] self.raise_on_parallel_prefill = raise_on_parallel_prefill + self.parallel_prefill_error = parallel_prefill_error super().__init__( tokenizer_path="unused", tokenizer_config_path=None, @@ -45,7 +50,7 @@ def forward(self, tokens: torch.Tensor, input_pos=None) -> torch.Tensor: (tokens.clone(), input_pos.clone() if input_pos is not None else None) ) if self.raise_on_parallel_prefill and tokens.shape[1] > 1: - raise RuntimeError("parallel prefill failure") + raise RuntimeError(self.parallel_prefill_error) return torch.zeros((1, 8), dtype=torch.float32) @@ -94,6 +99,29 @@ def test_dynamic_prefill_does_not_mask_runtime_errors(self, _mock_get_tokenizer) self.assertEqual(len(runner.calls), 1) + @patch( + "executorch.examples.models.llama.runner.generation.get_tokenizer", + return_value=_DummyTokenizer(), + ) + def test_dynamic_prefill_falls_back_on_shape_error(self, _mock_get_tokenizer): + runner = _DummyRunner( + raise_on_parallel_prefill=True, + parallel_prefill_error="input token shape mismatch for prefill", + ) + runner.enable_dynamic_shape = True + + runner._prefill_with_kv_cache([5, 6, 7], pos_base=3) + + # First call is attempted batched prefill, then sequential fallback. + self.assertEqual(len(runner.calls), 4) + first_tokens, first_input_pos = runner.calls[0] + self.assertEqual(tuple(first_tokens.shape), (1, 3)) + self.assertEqual(first_input_pos.item(), 3) + for i, (tokens, input_pos) in enumerate(runner.calls[1:]): + self.assertEqual(tuple(tokens.shape), (1, 1)) + self.assertEqual(tokens.item(), 5 + i) + self.assertEqual(input_pos.item(), 3 + i) + if __name__ == "__main__": unittest.main() From d7d38d611bef64e4da79202d4173693455a9976d Mon Sep 17 00:00:00 2001 From: Sriram Kiron Date: Wed, 4 Mar 2026 19:08:27 -0500 Subject: [PATCH 10/11] Reset gated DeltaNet state when input_pos is omitted --- examples/models/llama/attention.py | 2 ++ .../llama/tests/test_qwen3_5_attention.py | 34 +++++++++++++++++++ 2 files changed, 36 insertions(+) diff --git a/examples/models/llama/attention.py b/examples/models/llama/attention.py index a7b7104ef8e..ed536f88b3d 100644 --- a/examples/models/llama/attention.py +++ b/examples/models/llama/attention.py @@ -621,6 +621,8 @@ def _maybe_reset_state( self, input_pos: Optional[torch.Tensor], batch_size: int ) -> None: if input_pos is None: + self.conv_state[:batch_size].zero_() + self.recurrent_state[:batch_size].zero_() return reset = (input_pos[0] == 0).to(self.conv_state.dtype) keep = 1.0 - reset diff --git a/examples/models/llama/tests/test_qwen3_5_attention.py b/examples/models/llama/tests/test_qwen3_5_attention.py index d5598eb4196..71e4a14634e 100644 --- a/examples/models/llama/tests/test_qwen3_5_attention.py +++ b/examples/models/llama/tests/test_qwen3_5_attention.py @@ -123,6 +123,40 @@ def test_gated_deltanet_resets_state_on_new_sequence(self): state_after_reset = attn.recurrent_state.clone() self.assertTrue(torch.allclose(state_after_first, state_after_reset, atol=1e-5)) + def test_gated_deltanet_no_input_pos_does_not_leak_state(self): + torch.manual_seed(0) + args = ModelArgs( + dim=32, + n_layers=1, + n_heads=4, + n_kv_heads=2, + head_dim=8, + hidden_dim=64, + max_seq_len=16, + max_context_len=16, + use_kv_cache=True, + attention_type="mha", + use_q_gate=True, + linear_conv_kernel_dim=4, + linear_key_head_dim=4, + linear_value_head_dim=4, + linear_num_key_heads=2, + linear_num_value_heads=4, + ) + rope = Rope(args) + attn = ATTENTION_REGISTRY["gated_deltanet"](args, 0, rope) + + x = torch.randn(1, 1, args.dim) + dummy_freq = torch.zeros(1, 1) + + attn(x, dummy_freq, dummy_freq) + state_after_first = attn.recurrent_state.clone() + + attn(x, dummy_freq, dummy_freq) + state_after_second = attn.recurrent_state.clone() + + self.assertTrue(torch.allclose(state_after_first, state_after_second, atol=1e-5)) + if __name__ == "__main__": unittest.main() From ca341288b5c2d4ddd49ea8a6a2348ef960198685 Mon Sep 17 00:00:00 2001 From: Sriram Kiron Date: Wed, 4 Mar 2026 19:10:06 -0500 Subject: [PATCH 11/11] Mark existing complex export/config helpers for lint --- examples/models/llama/export_llama_lib.py | 2 +- examples/models/llama/model_args.py | 2 +- examples/models/qwen3_5/convert_weights.py | 4 +++- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py index 05a08fc5a5b..c95682bb986 100644 --- a/examples/models/llama/export_llama_lib.py +++ b/examples/models/llama/export_llama_lib.py @@ -647,7 +647,7 @@ def canonical_path(path: Union[str, Path], *, dir: bool = False) -> str: return return_val -def export_llama( +def export_llama( # noqa: C901 export_options: Union[argparse.Namespace, LlmConfig, DictConfig], ) -> str: if isinstance(export_options, argparse.Namespace): diff --git a/examples/models/llama/model_args.py b/examples/models/llama/model_args.py index 05e9ea62a8a..3bfb095be3d 100644 --- a/examples/models/llama/model_args.py +++ b/examples/models/llama/model_args.py @@ -153,7 +153,7 @@ class ModelArgs: final_logit_softcapping: Optional[float] = None attn_logit_softcapping: Optional[float] = None - def __post_init__(self): + def __post_init__(self): # noqa: C901 if self.n_kv_heads is None: self.n_kv_heads = self.n_heads diff --git a/examples/models/qwen3_5/convert_weights.py b/examples/models/qwen3_5/convert_weights.py index d08ec441f0a..afd3f1a3117 100644 --- a/examples/models/qwen3_5/convert_weights.py +++ b/examples/models/qwen3_5/convert_weights.py @@ -116,7 +116,9 @@ def load_checkpoint(input_dir: str) -> Dict: raise FileNotFoundError(f"Could not find checkpoint in {input_dir}") -def qwen_3_5_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: +def qwen_3_5_to_meta( # noqa: C901 + state_dict: Dict[str, torch.Tensor], +) -> Dict[str, torch.Tensor]: converted_state_dict = {} pending_qkvz = {} pending_ba = {}