From f4442de48bece613fe7365b1a9ca7c2d897887d4 Mon Sep 17 00:00:00 2001 From: Phineas1500 <41450967+Phineas1500@users.noreply.github.com> Date: Mon, 2 Mar 2026 18:37:04 -0500 Subject: [PATCH 01/15] Add Qwen3.5 export support for 0.8B/2B/4B --- examples/models/BUCK | 1 + examples/models/llama/__init__.py | 17 +- examples/models/llama/attention.py | 385 +++++++++++++++++- examples/models/llama/export_llama_lib.py | 8 + examples/models/llama/llama_transformer.py | 33 +- examples/models/llama/model_args.py | 17 + examples/models/llama/norm.py | 7 +- examples/models/llama/tests/BUCK | 11 + .../llama/tests/test_qwen3_5_attention.py | 83 ++++ examples/models/qwen3_5/BUCK | 23 ++ examples/models/qwen3_5/README.md | 52 +++ examples/models/qwen3_5/__init__.py | 24 ++ .../models/qwen3_5/config/0_8b_config.json | 50 +++ examples/models/qwen3_5/config/2b_config.json | 50 +++ examples/models/qwen3_5/config/4b_config.json | 58 +++ .../qwen3_5/config/qwen3_5_xnnpack_fp32.yaml | 17 + examples/models/qwen3_5/convert_weights.py | 201 +++++++++ examples/models/qwen3_5/tests/__init__.py | 2 + .../qwen3_5/tests/test_convert_weights.py | 44 ++ extension/llm/export/config/llm_config.py | 3 + 20 files changed, 1076 insertions(+), 10 deletions(-) create mode 100644 examples/models/llama/tests/test_qwen3_5_attention.py create mode 100644 examples/models/qwen3_5/BUCK create mode 100644 examples/models/qwen3_5/README.md create mode 100644 examples/models/qwen3_5/__init__.py create mode 100644 examples/models/qwen3_5/config/0_8b_config.json create mode 100644 examples/models/qwen3_5/config/2b_config.json create mode 100644 examples/models/qwen3_5/config/4b_config.json create mode 100644 examples/models/qwen3_5/config/qwen3_5_xnnpack_fp32.yaml create mode 100644 examples/models/qwen3_5/convert_weights.py create mode 100644 examples/models/qwen3_5/tests/__init__.py create mode 100644 examples/models/qwen3_5/tests/test_convert_weights.py diff --git a/examples/models/BUCK b/examples/models/BUCK index 94cc88360bf..a2b6789a95e 100644 --- a/examples/models/BUCK +++ b/examples/models/BUCK @@ -29,6 +29,7 @@ fbcode_target(_kind = python_library, "//executorch/examples/models/gemma3:gemma3", # @manual "//executorch/examples/models/qwen2_5:qwen2_5", # @manual "//executorch/examples/models/qwen3:qwen3", # @manual + "//executorch/examples/models/qwen3_5:qwen3_5", # @manual "//executorch/examples/models/phi_4_mini:phi_4_mini", # @manual "//executorch/examples/models/smollm2:smollm2", # @manual "//executorch/examples/models/smollm3:smollm3", # @manual diff --git a/examples/models/llama/__init__.py b/examples/models/llama/__init__.py index db6124ecc71..51f2a5c916f 100644 --- a/examples/models/llama/__init__.py +++ b/examples/models/llama/__init__.py @@ -4,8 +4,17 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from .model import Llama2Model +from typing import TYPE_CHECKING -__all__ = [ - Llama2Model, -] +if TYPE_CHECKING: + from .model import Llama2Model + +__all__ = ["Llama2Model"] + + +def __getattr__(name): + if name == "Llama2Model": + from .model import Llama2Model + + return Llama2Model + raise AttributeError(f"module '{__name__}' has no attribute '{name}'") diff --git a/examples/models/llama/attention.py b/examples/models/llama/attention.py index 922bbbb37fa..da5f1b46b29 100644 --- a/examples/models/llama/attention.py +++ b/examples/models/llama/attention.py @@ -352,8 +352,16 @@ def __init__( if self.use_qk_norm: q_norm_dim = self.head_dim k_norm_dim = self.head_dim - self.q_norm_fn = RMSNorm(q_norm_dim, eps=args.norm_eps) - self.k_norm_fn = RMSNorm(k_norm_dim, eps=args.norm_eps) + self.q_norm_fn = RMSNorm( + q_norm_dim, + eps=args.norm_eps, + add_unit_offset=args.rms_norm_add_unit_offset, + ) + self.k_norm_fn = RMSNorm( + k_norm_dim, + eps=args.norm_eps, + add_unit_offset=args.rms_norm_add_unit_offset, + ) self.wq = ( LoRALinear( @@ -511,6 +519,379 @@ def forward( return output, None +def _l2norm(x: torch.Tensor, dim: int = -1, eps: float = 1e-6) -> torch.Tensor: + inv_norm = torch.rsqrt((x * x).sum(dim=dim, keepdim=True) + eps) + return x * inv_norm + + +class RMSNormGated(nn.Module): + def __init__(self, hidden_size: int, eps: float = 1e-6): + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states: torch.Tensor, gate: torch.Tensor) -> torch.Tensor: + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + hidden_states = self.weight * hidden_states.to(input_dtype) + hidden_states = hidden_states * F.silu(gate.to(torch.float32)) + return hidden_states.to(input_dtype) + + +@register_attention("qwen3_5_full") +class AttentionQwen3_5Full(Attention): + """Qwen3.5 full-attention block with q-gating.""" + + def __init__( + self, + args: ModelArgs, + layer_id: int, + rope: Rope, + **_kwargs: Any, + ): + super().__init__() + self.use_kv_cache = args.use_kv_cache + self.n_heads = args.n_heads + self.n_kv_heads = self.n_heads if args.n_kv_heads is None else args.n_kv_heads + assert self.n_heads % self.n_kv_heads == 0 + self.n_local_heads = self.n_heads + self.n_local_kv_heads = self.n_kv_heads + self.n_rep = self.n_local_heads // self.n_local_kv_heads + self.head_dim = args.head_dim + self.max_context_len = args.max_context_len + self.dim = args.dim + self.attention_qkv_bias = args.attention_qkv_bias + self.use_qk_norm = args.use_qk_norm + self.qk_norm_before_rope = args.qk_norm_before_rope + self.enable_dynamic_shape = args.enable_dynamic_shape + + if self.use_qk_norm: + self.q_norm_fn = RMSNorm( + self.head_dim, + eps=args.norm_eps, + add_unit_offset=args.rms_norm_add_unit_offset, + ) + self.k_norm_fn = RMSNorm( + self.head_dim, + eps=args.norm_eps, + add_unit_offset=args.rms_norm_add_unit_offset, + ) + + self.wq = nn.Linear( + self.dim, self.n_heads * self.head_dim * 2, bias=self.attention_qkv_bias + ) + self.wk = nn.Linear( + self.dim, self.n_kv_heads * self.head_dim, bias=self.attention_qkv_bias + ) + self.wv = nn.Linear( + self.dim, self.n_kv_heads * self.head_dim, bias=self.attention_qkv_bias + ) + self.wo = nn.Linear( + self.n_heads * self.head_dim, + self.dim, + bias=self.attention_qkv_bias, + ) + + self.layer_id = layer_id + self.rope = rope + + causal_mask = torch.tril( + torch.ones( + self.max_context_len, + self.max_context_len, + dtype=torch.bool, + device="cpu", + ) + ) + self.register_buffer("mask", causal_mask, persistent=False) + + if self.use_kv_cache: + self.kv_cache = KVCache( + args.max_batch_size, + args.max_context_len, + self.n_kv_heads, + self.head_dim, + args.enable_dynamic_shape, + ) + self.SDPA = SDPA( + dim=self.n_local_heads * self.head_dim, + head_dim=self.head_dim, + n_rep=self.n_rep, + max_context_len=self.max_context_len, + ) + + def forward( + self, + x: torch.Tensor, + freqs_cos: torch.Tensor, + freqs_sin: torch.Tensor, + **kwargs: ForwardOptions, + ) -> Tuple[torch.Tensor, Optional[Any]]: + input_pos = kwargs.get("input_pos") + bsz, seqlen, _ = x.shape + + # Q and gate are packed in q_proj output. + q_and_gate = self.wq(x).view(bsz, seqlen, self.n_local_heads, self.head_dim * 2) + q, gate = torch.chunk(q_and_gate, 2, dim=-1) + gate = gate.reshape(bsz, seqlen, -1) + + k = self.wk(x).view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) + v = self.wv(x).view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) + + if self.use_qk_norm and self.qk_norm_before_rope: + q = self.q_norm_fn(q) + k = self.k_norm_fn(k) + + q, k = self.rope.forward(q, k, freqs_cos, freqs_sin) + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + if self.use_qk_norm and not self.qk_norm_before_rope: + q = self.q_norm_fn(q) + k = self.k_norm_fn(k) + + if self.use_kv_cache: + assert input_pos is not None + if self.enable_dynamic_shape: + start_pos = input_pos[-1].item() + torch._check_is_size(start_pos) + torch._check(start_pos < self.max_context_len) + seq_length = q.size(2) + attn_mask = self.mask.narrow(0, start_pos, seq_length) + else: + attn_mask = self.mask[input_pos] + k, v = self.kv_cache.update(input_pos, k, v) + if getattr(self.kv_cache, "is_ring_buffer", False): + attn_mask = self.kv_cache.create_causal_mask_for_ring_buffer( + input_pos[0].item(), seqlen + ) + output = self.SDPA(input_pos, q, k, v, bsz, seqlen, attn_mask) + output = output * torch.sigmoid(gate) + return self.wo(output), None + + k = k.repeat_interleave(self.n_rep, dim=1) + v = v.repeat_interleave(self.n_rep, dim=1) + mask = self.mask[:seqlen, :seqlen] + output = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) + output = output.transpose(1, 2).reshape(bsz, seqlen, -1) + output = output * torch.sigmoid(gate) + return self.wo(output), None + + +@register_attention("gated_deltanet") +class AttentionGatedDeltaNet(Attention): + """Qwen3.5 linear-attention (Gated DeltaNet) block with internal state.""" + + def __init__( + self, + args: ModelArgs, + layer_id: int, + rope: Rope, + **_kwargs: Any, + ): + super().__init__() + del rope # DeltaNet layers do not use RoPE. + + self.hidden_size = args.dim + self.max_batch_size = args.max_batch_size + self.layer_id = layer_id + + assert args.linear_num_key_heads is not None + assert args.linear_num_value_heads is not None + assert args.linear_key_head_dim is not None + assert args.linear_value_head_dim is not None + + self.num_k_heads = args.linear_num_key_heads + self.num_v_heads = args.linear_num_value_heads + self.head_k_dim = args.linear_key_head_dim + self.head_v_dim = args.linear_value_head_dim + self.key_dim = self.head_k_dim * self.num_k_heads + self.value_dim = self.head_v_dim * self.num_v_heads + self.conv_kernel_size = args.linear_conv_kernel_dim + + assert ( + self.num_v_heads % self.num_k_heads == 0 + ), "linear_num_value_heads must be divisible by linear_num_key_heads." + self.head_repeat = self.num_v_heads // self.num_k_heads + + self.conv_dim = self.key_dim * 2 + self.value_dim + self.in_proj_qkv = nn.Linear(self.hidden_size, self.conv_dim, bias=False) + self.in_proj_z = nn.Linear(self.hidden_size, self.value_dim, bias=False) + self.in_proj_b = nn.Linear(self.hidden_size, self.num_v_heads, bias=False) + self.in_proj_a = nn.Linear(self.hidden_size, self.num_v_heads, bias=False) + + self.conv1d = nn.Conv1d( + in_channels=self.conv_dim, + out_channels=self.conv_dim, + kernel_size=self.conv_kernel_size, + groups=self.conv_dim, + bias=False, + padding=0, + ) + + self.dt_bias = nn.Parameter(torch.ones(self.num_v_heads)) + A = torch.empty(self.num_v_heads).uniform_(0, 16) + self.A_log = nn.Parameter(torch.log(A)) + self.norm = RMSNormGated(self.head_v_dim, eps=args.norm_eps) + self.out_proj = nn.Linear(self.value_dim, self.hidden_size, bias=False) + + self.register_buffer( + "conv_state", + torch.zeros( + self.max_batch_size, + self.conv_dim, + self.conv_kernel_size, + dtype=torch.float32, + device="cpu", + ), + ) + self.register_buffer( + "recurrent_state", + torch.zeros( + self.max_batch_size, + self.num_v_heads, + self.head_k_dim, + self.head_v_dim, + dtype=torch.float32, + device="cpu", + ), + ) + + def _maybe_reset_state(self, input_pos: Optional[torch.Tensor], batch_size: int) -> None: + if input_pos is None: + return + reset = (input_pos[0] == 0).to(self.conv_state.dtype) + keep = 1.0 - reset + self.conv_state[:batch_size].mul_(keep) + self.recurrent_state[:batch_size].mul_(keep) + + def _apply_causal_conv(self, mixed_qkv: torch.Tensor) -> torch.Tensor: + # mixed_qkv: (batch, seq_len, conv_dim) + batch_size, seq_len, _ = mixed_qkv.shape + mixed_qkv = mixed_qkv.transpose(1, 2) + state_len = self.conv_state.shape[-1] + hidden_states_new = torch.cat([self.conv_state[:batch_size], mixed_qkv], dim=-1) + new_conv_state = hidden_states_new[:, :, -state_len:] + with torch.no_grad(): + self.conv_state[:batch_size].copy_(new_conv_state.to(self.conv_state.dtype)) + out = F.conv1d( + hidden_states_new, + self.conv1d.weight, + self.conv1d.bias, + padding=0, + groups=self.conv_dim, + ) + out = F.silu(out[:, :, -seq_len:]).to(mixed_qkv.dtype) + return out.transpose(1, 2).contiguous() + + def _recurrent_gated_delta_rule( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + ) -> torch.Tensor: + # query/key/value: (batch, seq_len, num_heads, head_dim) + # g/beta: (batch, seq_len, num_heads) + initial_dtype = query.dtype + query = _l2norm(query, dim=-1, eps=1e-6) + key = _l2norm(key, dim=-1, eps=1e-6) + query, key, value, beta, g = [ + x.transpose(1, 2).contiguous().to(torch.float32) + for x in (query, key, value, beta, g) + ] + + batch_size, num_heads, sequence_length, k_head_dim = key.shape + v_head_dim = value.shape[-1] + scale = 1.0 / (query.shape[-1] ** 0.5) + query = query * scale + + core_attn_out = torch.zeros( + batch_size, + num_heads, + sequence_length, + v_head_dim, + device=value.device, + dtype=value.dtype, + ) + last_recurrent_state = self.recurrent_state[:batch_size].to(value.dtype) + + for i in range(sequence_length): + q_t = query[:, :, i] + k_t = key[:, :, i] + v_t = value[:, :, i] + g_t = g[:, :, i].exp().unsqueeze(-1).unsqueeze(-1) + beta_t = beta[:, :, i].unsqueeze(-1) + + last_recurrent_state = last_recurrent_state * g_t + kv_mem = (last_recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2) + delta = (v_t - kv_mem) * beta_t + last_recurrent_state = ( + last_recurrent_state + k_t.unsqueeze(-1) * delta.unsqueeze(-2) + ) + core_attn_out[:, :, i] = (last_recurrent_state * q_t.unsqueeze(-1)).sum( + dim=-2 + ) + + with torch.no_grad(): + self.recurrent_state[:batch_size].copy_( + last_recurrent_state.to(self.recurrent_state.dtype) + ) + + return core_attn_out.transpose(1, 2).contiguous().to(initial_dtype) + + def forward( + self, + x: torch.Tensor, + freqs_cos: torch.Tensor, + freqs_sin: torch.Tensor, + **kwargs: ForwardOptions, + ) -> Tuple[torch.Tensor, Optional[Any]]: + del freqs_cos + del freqs_sin + input_pos = kwargs.get("input_pos") + batch_size, seq_len, _ = x.shape + assert ( + batch_size <= self.max_batch_size + ), f"batch_size ({batch_size}) exceeds max_batch_size ({self.max_batch_size})" + + self._maybe_reset_state(input_pos, batch_size) + + mixed_qkv = self.in_proj_qkv(x) + z = self.in_proj_z(x).reshape(batch_size, seq_len, -1, self.head_v_dim) + b = self.in_proj_b(x) + a = self.in_proj_a(x) + + mixed_qkv = self._apply_causal_conv(mixed_qkv) + query, key, value = torch.split( + mixed_qkv, + [self.key_dim, self.key_dim, self.value_dim], + dim=-1, + ) + query = query.reshape(batch_size, seq_len, -1, self.head_k_dim) + key = key.reshape(batch_size, seq_len, -1, self.head_k_dim) + value = value.reshape(batch_size, seq_len, -1, self.head_v_dim) + + if self.head_repeat > 1: + query = query.repeat_interleave(self.head_repeat, dim=2) + key = key.repeat_interleave(self.head_repeat, dim=2) + + beta = b.sigmoid() + g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias) + core_attn_out = self._recurrent_gated_delta_rule(query, key, value, g, beta) + + core_attn_out = core_attn_out.reshape(-1, self.head_v_dim) + z = z.reshape(-1, self.head_v_dim) + core_attn_out = self.norm(core_attn_out, z) + core_attn_out = core_attn_out.reshape(batch_size, seq_len, -1) + + return self.out_proj(core_attn_out), None + + @register_attention("skip") class AttentionSkip(Attention): def __init__(self, *args, **kwargs): diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py index ee81139dedf..94984565e10 100644 --- a/examples/models/llama/export_llama_lib.py +++ b/examples/models/llama/export_llama_lib.py @@ -105,6 +105,9 @@ "qwen3_0_6b", "qwen3_1_7b", "qwen3_4b", + "qwen3_5_0_8b", + "qwen3_5_2b", + "qwen3_5_4b", "phi_4_mini", "smollm2", "lfm2_350m", # hybrid @@ -122,6 +125,9 @@ "qwen3_0_6b": "Qwen/Qwen3-0.6B", "qwen3_1_7b": "Qwen/Qwen3-1.7B", "qwen3_4b": "Qwen/Qwen3-4B", + "qwen3_5_0_8b": "Qwen/Qwen3.5-0.8B", + "qwen3_5_2b": "Qwen/Qwen3.5-2B", + "qwen3_5_4b": "Qwen/Qwen3.5-4B", "lfm2_350m": "LiquidAI/LFM2-350M", "lfm2_700m": "LiquidAI/LFM2-700M", "lfm2_1_2b": "LiquidAI/LFM2-1.2B", @@ -643,6 +649,8 @@ def export_llama( repo_id = HUGGING_FACE_REPO_IDS[model_name] if model_name.startswith("qwen2_5"): from executorch.examples.models.qwen2_5 import convert_weights + elif model_name.startswith("qwen3_5"): + from executorch.examples.models.qwen3_5 import convert_weights elif model_name.startswith("qwen3"): from executorch.examples.models.qwen3 import convert_weights elif model_name == "phi_4_mini": diff --git a/examples/models/llama/llama_transformer.py b/examples/models/llama/llama_transformer.py index a4ca52aa608..c79f9cb6ce0 100644 --- a/examples/models/llama/llama_transformer.py +++ b/examples/models/llama/llama_transformer.py @@ -105,8 +105,16 @@ def __init__(self, args: ModelArgs, attention: Attention): if isinstance(self.attention, AttentionSkip): self.attention_norm = nn.Identity() else: - self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps) - self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps) + self.attention_norm = RMSNorm( + args.dim, + eps=args.norm_eps, + add_unit_offset=args.rms_norm_add_unit_offset, + ) + self.ffn_norm = RMSNorm( + args.dim, + eps=args.norm_eps, + add_unit_offset=args.rms_norm_add_unit_offset, + ) @classmethod def from_type(cls, layer_id, args, rope) -> "TransformerBlock": @@ -164,7 +172,11 @@ def __init__(self, params: ModelArgs, layers: nn.ModuleList, rope: Rope): ) self.layers = layers self.rope = rope - self.norm = RMSNorm(params.dim, eps=params.norm_eps) + self.norm = RMSNorm( + params.dim, + eps=params.norm_eps, + add_unit_offset=params.rms_norm_add_unit_offset, + ) self.output = ( nn.Linear(params.dim, params.vocab_size, bias=False) if self.apply_output @@ -279,6 +291,21 @@ def construct_transformer(model_args: ModelArgs) -> Transformer: attention = AttentionSkip() transformer_block = TransformerBlock(model_args, attention) layers.append(transformer_block) + elif ( + model_args.layer_types + and model_args.layer_types[layer_id] == "linear_attention" + ): + linear_cls = ATTENTION_REGISTRY.get("gated_deltanet") + if linear_cls is None: + raise ValueError( + "Unknown attention type: gated_deltanet. " + f"Available: {list(ATTENTION_REGISTRY.keys())}" + ) + attention = linear_cls( + model_args, layer_id, rope, **model_args.attention_kwargs + ) + transformer_block = TransformerBlock(model_args, attention) + layers.append(transformer_block) else: attention = cls( model_args, layer_id, rope, **model_args.attention_kwargs diff --git a/examples/models/llama/model_args.py b/examples/models/llama/model_args.py index c225a1ee9b2..eb2a83db6ed 100644 --- a/examples/models/llama/model_args.py +++ b/examples/models/llama/model_args.py @@ -44,6 +44,14 @@ class ModelArgs: vocab_size: int = 512 # Arbitrary value, should be defined later by tokenizer. hidden_dim: Optional[int] = None head_dim: Optional[int] = None # Optional customized head_dim + # Qwen3.5 linear-attention dimensions. + linear_conv_kernel_dim: int = 4 + linear_key_head_dim: Optional[int] = None + linear_value_head_dim: Optional[int] = None + linear_num_key_heads: Optional[int] = None + linear_num_value_heads: Optional[int] = None + # Qwen3.5 RMSNorm uses (1 + weight) scaling. + rms_norm_add_unit_offset: bool = False multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2 ffn_dim_multiplier: Optional[float] = None model_architecture: str = ( @@ -174,6 +182,15 @@ def find_multiple(n: int, k: int) -> int: if self.head_dim is None: self.head_dim = self.dim // self.n_heads + if self.linear_key_head_dim is None: + self.linear_key_head_dim = self.head_dim + if self.linear_value_head_dim is None: + self.linear_value_head_dim = self.head_dim + if self.linear_num_key_heads is None: + self.linear_num_key_heads = self.n_heads + if self.linear_num_value_heads is None: + self.linear_num_value_heads = self.n_heads + # Convert string act_fn to enum if needed if isinstance(self.act_fn, str): self.act_fn = ActFn.from_string(self.act_fn) diff --git a/examples/models/llama/norm.py b/examples/models/llama/norm.py index 3786e61cd05..52ae30c1697 100644 --- a/examples/models/llama/norm.py +++ b/examples/models/llama/norm.py @@ -9,7 +9,9 @@ class RMSNorm(torch.nn.Module): - def __init__(self, dim: int, eps: float = 1e-6): + def __init__( + self, dim: int, eps: float = 1e-6, add_unit_offset: bool = False + ): """ Initialize the RMSNorm normalization layer. @@ -25,6 +27,7 @@ def __init__(self, dim: int, eps: float = 1e-6): super().__init__() self.dim = dim self.eps = eps + self.add_unit_offset = add_unit_offset self.weight = nn.Parameter(torch.ones(dim)) def _norm(self, x): @@ -52,4 +55,6 @@ def forward(self, x): """ output = self._norm(x.float()).type_as(x) + if self.add_unit_offset: + return output * (1.0 + self.weight.float()).type_as(x) return output * self.weight diff --git a/examples/models/llama/tests/BUCK b/examples/models/llama/tests/BUCK index 8f4dec2237b..431c3c92814 100644 --- a/examples/models/llama/tests/BUCK +++ b/examples/models/llama/tests/BUCK @@ -15,6 +15,17 @@ fbcode_target(_kind = python_unittest, ], ) +fbcode_target(_kind = python_unittest, + name = "test_qwen3_5_attention", + srcs = [ + "test_qwen3_5_attention.py", + ], + deps = [ + "//caffe2:torch", + "//executorch/examples/models/llama:llama_transformer", + ], +) + fbcode_target(_kind = python_unittest, name = "test_pre_quantization_transforms", srcs = [ diff --git a/examples/models/llama/tests/test_qwen3_5_attention.py b/examples/models/llama/tests/test_qwen3_5_attention.py new file mode 100644 index 00000000000..b076e69d121 --- /dev/null +++ b/examples/models/llama/tests/test_qwen3_5_attention.py @@ -0,0 +1,83 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import unittest + +import torch +from executorch.examples.models.llama.attention import ATTENTION_REGISTRY +from executorch.examples.models.llama.model_args import ModelArgs +from executorch.examples.models.llama.rope import Rope + + +class Qwen35AttentionTest(unittest.TestCase): + def test_qwen35_full_attention_forward_shape(self): + torch.manual_seed(0) + args = ModelArgs( + dim=32, + n_layers=1, + n_heads=4, + n_kv_heads=2, + head_dim=8, + hidden_dim=64, + max_seq_len=16, + max_context_len=16, + use_kv_cache=False, + use_hf_rope=True, + partial_rotary_factor=0.5, + use_qk_norm=True, + qk_norm_before_rope=True, + attention_type="qwen3_5_full", + rms_norm_add_unit_offset=True, + ) + rope = Rope(args) + attn = ATTENTION_REGISTRY["qwen3_5_full"](args, 0, rope) + x = torch.randn(1, 3, args.dim) + freqs_cos, freqs_sin = rope.get_freqs(None, x.shape[1]) + y, _ = attn(x, freqs_cos, freqs_sin) + self.assertEqual(y.shape, x.shape) + + def test_gated_deltanet_resets_state_on_new_sequence(self): + torch.manual_seed(0) + args = ModelArgs( + dim=32, + n_layers=1, + n_heads=4, + n_kv_heads=2, + head_dim=8, + hidden_dim=64, + max_seq_len=16, + max_context_len=16, + use_kv_cache=True, + attention_type="qwen3_5_full", + linear_conv_kernel_dim=4, + linear_key_head_dim=4, + linear_value_head_dim=4, + linear_num_key_heads=2, + linear_num_value_heads=4, + ) + rope = Rope(args) + attn = ATTENTION_REGISTRY["gated_deltanet"](args, 0, rope) + + x = torch.randn(1, 1, args.dim) + dummy_freq = torch.zeros(1, 1) + + # First token of sequence. + attn(x, dummy_freq, dummy_freq, input_pos=torch.tensor([0], dtype=torch.long)) + state_after_first = attn.recurrent_state.clone() + + # Decode continuation updates state. + attn(x, dummy_freq, dummy_freq, input_pos=torch.tensor([1], dtype=torch.long)) + state_after_second = attn.recurrent_state.clone() + self.assertFalse(torch.allclose(state_after_first, state_after_second)) + + # New sequence (input_pos=0) should reset internal state. + attn(x, dummy_freq, dummy_freq, input_pos=torch.tensor([0], dtype=torch.long)) + state_after_reset = attn.recurrent_state.clone() + self.assertTrue(torch.allclose(state_after_first, state_after_reset, atol=1e-5)) + + +if __name__ == "__main__": + unittest.main() diff --git a/examples/models/qwen3_5/BUCK b/examples/models/qwen3_5/BUCK new file mode 100644 index 00000000000..9ee84a0ec5e --- /dev/null +++ b/examples/models/qwen3_5/BUCK @@ -0,0 +1,23 @@ +load("@fbcode_macros//build_defs:build_file_migration.bzl", "fbcode_target", "non_fbcode_target") +# Any targets that should be shared between fbcode and xplat must be defined in +# targets.bzl. This file can contain fbcode-only targets. + +load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") + +oncall("executorch") + +fbcode_target(_kind = runtime.python_library, + name = "qwen3_5", + srcs = [ + "__init__.py", + "convert_weights.py", + ], + base_module = "executorch.examples.models.qwen3_5", + visibility = ["PUBLIC"], + deps = [ + "//caffe2:torch", + "//executorch/examples/models/llama:llama2_model", + "//executorch/examples/models:checkpoint", + "fbsource//third-party/pypi/safetensors:safetensors", + ], +) diff --git a/examples/models/qwen3_5/README.md b/examples/models/qwen3_5/README.md new file mode 100644 index 00000000000..315d603d655 --- /dev/null +++ b/examples/models/qwen3_5/README.md @@ -0,0 +1,52 @@ +## Summary +[Qwen3.5](https://huggingface.co/collections/Qwen/qwen35-684357f8543f83de2f09f998) support in ExecuTorch is exported through the Llama example pipeline with a hybrid layer layout: +- `full_attention` layers use gated full attention. +- `linear_attention` layers use Gated DeltaNet with internal recurrent state. + +This first bring-up is **fp32 + static shape** (`enable_dynamic_shape=False`). +Currently supported text model sizes: `0.8B`, `2B`, `4B`. + +## Export +```bash +python -m extension.llm.export.export_llm \ + --config examples/models/qwen3_5/config/qwen3_5_xnnpack_fp32.yaml \ + +base.model_class="qwen3_5_0_8b" \ + +base.params="examples/models/qwen3_5/config/0_8b_config.json" \ + +export.output_name="qwen3_5_0_8b_fp32.pte" +``` + +```bash +python -m extension.llm.export.export_llm \ + --config examples/models/qwen3_5/config/qwen3_5_xnnpack_fp32.yaml \ + +base.model_class="qwen3_5_2b" \ + +base.params="examples/models/qwen3_5/config/2b_config.json" \ + +export.output_name="qwen3_5_2b_fp32.pte" +``` + +```bash +python -m extension.llm.export.export_llm \ + --config examples/models/qwen3_5/config/qwen3_5_xnnpack_fp32.yaml \ + +base.model_class="qwen3_5_4b" \ + +base.params="examples/models/qwen3_5/config/4b_config.json" \ + +export.output_name="qwen3_5_4b_fp32.pte" +``` + +The exporter will download and convert HF weights automatically when `+base.checkpoint` is not provided. + +## Run (Python Runner) +```bash +python -m examples.models.llama.runner.native \ + --model qwen3_5_0_8b \ + --pte qwen3_5_0_8b_fp32.pte \ + --tokenizer /path/to/tokenizer.json \ + --tokenizer_config /path/to/tokenizer_config.json \ + --prompt "<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n" \ + --params examples/models/qwen3_5/config/0_8b_config.json \ + --max_len 128 \ + -kv \ + --temperature 0.3 +``` + +## Notes +- Current path targets CPU/XNNPACK export validation and runner compatibility. +- `q8da4w` quantization for Qwen3.5 is intentionally deferred to a follow-up. diff --git a/examples/models/qwen3_5/__init__.py b/examples/models/qwen3_5/__init__.py new file mode 100644 index 00000000000..3177b5c4958 --- /dev/null +++ b/examples/models/qwen3_5/__init__.py @@ -0,0 +1,24 @@ +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import TYPE_CHECKING + +from executorch.examples.models.qwen3_5.convert_weights import convert_weights + +if TYPE_CHECKING: + from executorch.examples.models.llama.model import Llama2Model + +__all__ = ["Qwen3_5Model", "convert_weights"] + + +def __getattr__(name): + if name == "Qwen3_5Model": + from executorch.examples.models.llama.model import Llama2Model + + class Qwen3_5Model(Llama2Model): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + globals()["Qwen3_5Model"] = Qwen3_5Model + return Qwen3_5Model + raise AttributeError(f"module '{__name__}' has no attribute '{name}'") diff --git a/examples/models/qwen3_5/config/0_8b_config.json b/examples/models/qwen3_5/config/0_8b_config.json new file mode 100644 index 00000000000..e532988410d --- /dev/null +++ b/examples/models/qwen3_5/config/0_8b_config.json @@ -0,0 +1,50 @@ +{ + "dim": 1024, + "hidden_dim": 3584, + "n_heads": 8, + "head_dim": 256, + "n_kv_heads": 2, + "n_layers": 24, + "norm_eps": 1e-6, + "rope_theta": 10000000.0, + "use_scaled_rope": false, + "vocab_size": 248320, + "use_hf_rope": true, + "partial_rotary_factor": 0.25, + "attention_qkv_bias": false, + "use_qk_norm": true, + "qk_norm_before_rope": true, + "attention_type": "qwen3_5_full", + "rms_norm_add_unit_offset": true, + "linear_conv_kernel_dim": 4, + "linear_key_head_dim": 128, + "linear_value_head_dim": 128, + "linear_num_key_heads": 16, + "linear_num_value_heads": 16, + "layer_types": [ + "linear_attention", + "linear_attention", + "linear_attention", + "full_attention", + "linear_attention", + "linear_attention", + "linear_attention", + "full_attention", + "linear_attention", + "linear_attention", + "linear_attention", + "full_attention", + "linear_attention", + "linear_attention", + "linear_attention", + "full_attention", + "linear_attention", + "linear_attention", + "linear_attention", + "full_attention", + "linear_attention", + "linear_attention", + "linear_attention", + "full_attention" + ] +} diff --git a/examples/models/qwen3_5/config/2b_config.json b/examples/models/qwen3_5/config/2b_config.json new file mode 100644 index 00000000000..e397bf6fd4a --- /dev/null +++ b/examples/models/qwen3_5/config/2b_config.json @@ -0,0 +1,50 @@ +{ + "dim": 2048, + "hidden_dim": 6144, + "n_heads": 8, + "head_dim": 256, + "n_kv_heads": 2, + "n_layers": 24, + "norm_eps": 1e-6, + "rope_theta": 10000000.0, + "use_scaled_rope": false, + "vocab_size": 248320, + "use_hf_rope": true, + "partial_rotary_factor": 0.25, + "attention_qkv_bias": false, + "use_qk_norm": true, + "qk_norm_before_rope": true, + "attention_type": "qwen3_5_full", + "rms_norm_add_unit_offset": true, + "linear_conv_kernel_dim": 4, + "linear_key_head_dim": 128, + "linear_value_head_dim": 128, + "linear_num_key_heads": 16, + "linear_num_value_heads": 16, + "layer_types": [ + "linear_attention", + "linear_attention", + "linear_attention", + "full_attention", + "linear_attention", + "linear_attention", + "linear_attention", + "full_attention", + "linear_attention", + "linear_attention", + "linear_attention", + "full_attention", + "linear_attention", + "linear_attention", + "linear_attention", + "full_attention", + "linear_attention", + "linear_attention", + "linear_attention", + "full_attention", + "linear_attention", + "linear_attention", + "linear_attention", + "full_attention" + ] +} diff --git a/examples/models/qwen3_5/config/4b_config.json b/examples/models/qwen3_5/config/4b_config.json new file mode 100644 index 00000000000..75c4313b83a --- /dev/null +++ b/examples/models/qwen3_5/config/4b_config.json @@ -0,0 +1,58 @@ +{ + "dim": 2560, + "hidden_dim": 9216, + "n_heads": 16, + "head_dim": 256, + "n_kv_heads": 4, + "n_layers": 32, + "norm_eps": 1e-6, + "rope_theta": 10000000.0, + "use_scaled_rope": false, + "vocab_size": 248320, + "use_hf_rope": true, + "partial_rotary_factor": 0.25, + "attention_qkv_bias": false, + "use_qk_norm": true, + "qk_norm_before_rope": true, + "attention_type": "qwen3_5_full", + "rms_norm_add_unit_offset": true, + "linear_conv_kernel_dim": 4, + "linear_key_head_dim": 128, + "linear_value_head_dim": 128, + "linear_num_key_heads": 16, + "linear_num_value_heads": 32, + "layer_types": [ + "linear_attention", + "linear_attention", + "linear_attention", + "full_attention", + "linear_attention", + "linear_attention", + "linear_attention", + "full_attention", + "linear_attention", + "linear_attention", + "linear_attention", + "full_attention", + "linear_attention", + "linear_attention", + "linear_attention", + "full_attention", + "linear_attention", + "linear_attention", + "linear_attention", + "full_attention", + "linear_attention", + "linear_attention", + "linear_attention", + "full_attention", + "linear_attention", + "linear_attention", + "linear_attention", + "full_attention", + "linear_attention", + "linear_attention", + "linear_attention", + "full_attention" + ] +} diff --git a/examples/models/qwen3_5/config/qwen3_5_xnnpack_fp32.yaml b/examples/models/qwen3_5/config/qwen3_5_xnnpack_fp32.yaml new file mode 100644 index 00000000000..496f2da6e2b --- /dev/null +++ b/examples/models/qwen3_5/config/qwen3_5_xnnpack_fp32.yaml @@ -0,0 +1,17 @@ +base: + metadata: '{"get_bos_id": 151644, "get_eos_ids":[151645]}' + +model: + use_kv_cache: True + use_sdpa_with_kv_cache: False + enable_dynamic_shape: False + dtype_override: fp32 + +export: + max_seq_length: 2048 + max_context_length: 2048 + +backend: + xnnpack: + enabled: True + extended_ops: True diff --git a/examples/models/qwen3_5/convert_weights.py b/examples/models/qwen3_5/convert_weights.py new file mode 100644 index 00000000000..1963961fdfc --- /dev/null +++ b/examples/models/qwen3_5/convert_weights.py @@ -0,0 +1,201 @@ +import argparse +import json +import os +import re +from typing import Dict + +import torch +from executorch.examples.models.checkpoint import ( + get_mapped_key, + load_checkpoint_from_pytorch_model, +) + +_QWEN_3_5_TO_META = { + "model.embed_tokens.weight": "tok_embeddings.weight", + "model.norm.weight": "norm.weight", + "lm_head.weight": "output.weight", + "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight", + "model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight", + "model.layers.{}.mlp.gate_proj.weight": "layers.{}.feed_forward.w1.weight", + "model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight", + "model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight", + # Full-attention layers. + "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight", + "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight", + "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight", + "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight", + "model.layers.{}.self_attn.q_norm.weight": "layers.{}.attention.q_norm_fn.weight", + "model.layers.{}.self_attn.k_norm.weight": "layers.{}.attention.k_norm_fn.weight", + # Linear-attention layers. + "model.layers.{}.linear_attn.in_proj_qkv.weight": "layers.{}.attention.in_proj_qkv.weight", + "model.layers.{}.linear_attn.in_proj_z.weight": "layers.{}.attention.in_proj_z.weight", + "model.layers.{}.linear_attn.in_proj_b.weight": "layers.{}.attention.in_proj_b.weight", + "model.layers.{}.linear_attn.in_proj_a.weight": "layers.{}.attention.in_proj_a.weight", + "model.layers.{}.linear_attn.conv1d.weight": "layers.{}.attention.conv1d.weight", + "model.layers.{}.linear_attn.conv1d.bias": "layers.{}.attention.conv1d.bias", + "model.layers.{}.linear_attn.dt_bias": "layers.{}.attention.dt_bias", + "model.layers.{}.linear_attn.A_log": "layers.{}.attention.A_log", + "model.layers.{}.linear_attn.norm.weight": "layers.{}.attention.norm.weight", + "model.layers.{}.linear_attn.out_proj.weight": "layers.{}.attention.out_proj.weight", +} + + +def _load_checkpoint_from_safetensors(input_dir: str) -> Dict: + from safetensors.torch import load_file + + index_path = os.path.join(input_dir, "model.safetensors.index.json") + if os.path.exists(index_path): + with open(index_path, "r") as f: + index = json.load(f) + weight_map = index["weight_map"] + checkpoint_shards = sorted(set(weight_map.values())) + + shard_to_weights = {} + for shard in checkpoint_shards: + shard_to_weights[shard] = load_file(os.path.join(input_dir, shard)) + + merged_state_dict = {} + for weight_name, shard in weight_map.items(): + merged_state_dict[weight_name] = shard_to_weights[shard][weight_name] + return merged_state_dict + + model_path = os.path.join(input_dir, "model.safetensors") + if os.path.exists(model_path): + return load_file(model_path) + + raise FileNotFoundError(f"Could not find safetensors checkpoint in {input_dir}") + + +def load_checkpoint(input_dir: str) -> Dict: + try: + print("Loading checkpoint from pytorch_model directory") + return load_checkpoint_from_pytorch_model(input_dir) + except FileNotFoundError: + print( + "Could not find pytorch_model checkpoints in directory, trying safetensors" + ) + + try: + print("Loading checkpoint from safetensors directory") + return _load_checkpoint_from_safetensors(input_dir) + except FileNotFoundError: + pass + + raise FileNotFoundError(f"Could not find checkpoint in {input_dir}") + + +def qwen_3_5_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + converted_state_dict = {} + pending_qkvz = {} + pending_ba = {} + + for key, value in state_dict.items(): + normalized_key = key + # HF multimodal Qwen3.5 checkpoints store text weights under + # `model.language_model.*`. Normalize to `model.*`. + if normalized_key.startswith("model.language_model."): + normalized_key = normalized_key.replace( + "model.language_model.", "model.", 1 + ) + + # Legacy packed tensors (older checkpoints): + # in_proj_qkvz -> split into in_proj_qkv and in_proj_z + # in_proj_ba -> split into in_proj_b and in_proj_a + if normalized_key.endswith(".linear_attn.in_proj_qkvz.weight"): + pending_qkvz[normalized_key] = value + continue + if normalized_key.endswith(".linear_attn.in_proj_ba.weight"): + pending_ba[normalized_key] = value + continue + + try: + new_key = get_mapped_key(normalized_key, _QWEN_3_5_TO_META) + except Exception: + # Ignore non-text weights and training-only extras (e.g., MTP). + if ( + key.startswith("mtp.") + or key.startswith("model.visual.") + or ".vision_" in key + or key.startswith("visual.") + ): + continue + # Ignore unsupported keys that are not required by the export model. + continue + converted_state_dict[new_key] = value + + for key, value in pending_qkvz.items(): + layer_match = re.search(r"model\.layers\.(\d+)\.", key) + if layer_match is None: + raise ValueError(f"Failed to parse layer id from key: {key}") + layer_id = layer_match.group(1) + out_proj_key = f"layers.{layer_id}.attention.out_proj.weight" + if out_proj_key not in converted_state_dict: + raise ValueError( + f"Cannot split {key}: missing {out_proj_key} to infer value dimension." + ) + + value_dim = converted_state_dict[out_proj_key].shape[1] + total_dim = value.shape[0] + conv_dim = total_dim - value_dim + if conv_dim <= 0 or (conv_dim - value_dim) % 2 != 0: + raise ValueError( + f"Invalid packed in_proj_qkvz shape for {key}: {tuple(value.shape)}" + ) + key_dim = (conv_dim - value_dim) // 2 + + qkv, z = torch.split(value, [conv_dim, value_dim], dim=0) + converted_state_dict[f"layers.{layer_id}.attention.in_proj_qkv.weight"] = qkv + converted_state_dict[f"layers.{layer_id}.attention.in_proj_z.weight"] = z + print(f"Split legacy packed key {key} -> in_proj_qkv + in_proj_z") + + for key, value in pending_ba.items(): + layer_match = re.search(r"model\.layers\.(\d+)\.", key) + if layer_match is None: + raise ValueError(f"Failed to parse layer id from key: {key}") + layer_id = layer_match.group(1) + if value.shape[0] % 2 != 0: + raise ValueError( + f"Invalid packed in_proj_ba shape for {key}: {tuple(value.shape)}" + ) + half = value.shape[0] // 2 + b, a = torch.split(value, [half, half], dim=0) + converted_state_dict[f"layers.{layer_id}.attention.in_proj_b.weight"] = b + converted_state_dict[f"layers.{layer_id}.attention.in_proj_a.weight"] = a + print(f"Split legacy packed key {key} -> in_proj_b + in_proj_a") + + # Handle tied embeddings. + if "lm_head.weight" not in state_dict: + converted_state_dict["output.weight"] = converted_state_dict[ + "tok_embeddings.weight" + ] + + return converted_state_dict + + +def convert_weights(input_dir: str, output_file: str) -> None: + print("Loading checkpoint...") + state_dict = load_checkpoint(input_dir) + print("Converting checkpoint...") + state_dict = qwen_3_5_to_meta(state_dict) + print("Saving checkpoint...") + torch.save(state_dict, output_file) + print("Done.") + + +def main(): + parser = argparse.ArgumentParser( + description="Convert Qwen3.5 weights to ExecuTorch meta format." + ) + parser.add_argument( + "input_dir", + type=str, + help="Path to directory containing safetensor or PyTorch checkpoint files.", + ) + parser.add_argument("output", type=str, help="Path to the output checkpoint") + + args = parser.parse_args() + convert_weights(args.input_dir, args.output) + + +if __name__ == "__main__": + main() diff --git a/examples/models/qwen3_5/tests/__init__.py b/examples/models/qwen3_5/tests/__init__.py new file mode 100644 index 00000000000..5e827e62ea1 --- /dev/null +++ b/examples/models/qwen3_5/tests/__init__.py @@ -0,0 +1,2 @@ +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/examples/models/qwen3_5/tests/test_convert_weights.py b/examples/models/qwen3_5/tests/test_convert_weights.py new file mode 100644 index 00000000000..0e7d1d8c8d3 --- /dev/null +++ b/examples/models/qwen3_5/tests/test_convert_weights.py @@ -0,0 +1,44 @@ +import unittest + +import torch +from executorch.examples.models.qwen3_5.convert_weights import qwen_3_5_to_meta + + +class Qwen35ConvertWeightsTest(unittest.TestCase): + def test_maps_full_and_linear_attention_weights(self): + state_dict = { + "model.embed_tokens.weight": torch.randn(16, 8), + "model.norm.weight": torch.randn(8), + "lm_head.weight": torch.randn(16, 8), + "model.layers.0.input_layernorm.weight": torch.randn(8), + "model.layers.0.post_attention_layernorm.weight": torch.randn(8), + "model.layers.0.mlp.gate_proj.weight": torch.randn(12, 8), + "model.layers.0.mlp.down_proj.weight": torch.randn(8, 12), + "model.layers.0.mlp.up_proj.weight": torch.randn(12, 8), + "model.layers.0.self_attn.q_proj.weight": torch.randn(16, 8), + "model.layers.0.self_attn.k_proj.weight": torch.randn(8, 8), + "model.layers.0.self_attn.v_proj.weight": torch.randn(8, 8), + "model.layers.0.self_attn.o_proj.weight": torch.randn(8, 8), + "model.layers.0.self_attn.q_norm.weight": torch.randn(4), + "model.layers.0.self_attn.k_norm.weight": torch.randn(4), + "model.layers.1.linear_attn.in_proj_qkv.weight": torch.randn(24, 8), + "model.layers.1.linear_attn.in_proj_z.weight": torch.randn(8, 8), + "model.layers.1.linear_attn.in_proj_b.weight": torch.randn(2, 8), + "model.layers.1.linear_attn.in_proj_a.weight": torch.randn(2, 8), + "model.layers.1.linear_attn.conv1d.weight": torch.randn(24, 1, 4), + "model.layers.1.linear_attn.dt_bias": torch.randn(2), + "model.layers.1.linear_attn.A_log": torch.randn(2), + "model.layers.1.linear_attn.norm.weight": torch.randn(4), + "model.layers.1.linear_attn.out_proj.weight": torch.randn(8, 8), + } + + converted = qwen_3_5_to_meta(state_dict) + self.assertIn("layers.0.attention.wq.weight", converted) + self.assertIn("layers.0.attention.q_norm_fn.weight", converted) + self.assertIn("layers.1.attention.in_proj_qkv.weight", converted) + self.assertIn("layers.1.attention.out_proj.weight", converted) + self.assertIn("layers.1.attention.dt_bias", converted) + + +if __name__ == "__main__": + unittest.main() diff --git a/extension/llm/export/config/llm_config.py b/extension/llm/export/config/llm_config.py index eea99ea9b2b..06df3aa38be 100644 --- a/extension/llm/export/config/llm_config.py +++ b/extension/llm/export/config/llm_config.py @@ -44,6 +44,9 @@ class ModelType(str, Enum): qwen3_0_6b = "qwen3_0_6b" qwen3_1_7b = "qwen3_1_7b" qwen3_4b = "qwen3_4b" + qwen3_5_0_8b = "qwen3_5_0_8b" + qwen3_5_2b = "qwen3_5_2b" + qwen3_5_4b = "qwen3_5_4b" phi_4_mini = "phi_4_mini" smollm2 = "smollm2" lfm2_350m = "lfm2_350m" From 8d0e7f21fb7614bf4e7822e9c7713e239846a3cd Mon Sep 17 00:00:00 2001 From: Phineas1500 <41450967+Phineas1500@users.noreply.github.com> Date: Mon, 2 Mar 2026 18:56:10 -0500 Subject: [PATCH 02/15] Fix lint and URL checks for Qwen3.5 files --- examples/models/llama/attention.py | 10 ++++++---- examples/models/llama/norm.py | 4 +--- examples/models/qwen3_5/README.md | 2 +- examples/models/qwen3_5/__init__.py | 5 ----- examples/models/qwen3_5/convert_weights.py | 2 -- 5 files changed, 8 insertions(+), 15 deletions(-) diff --git a/examples/models/llama/attention.py b/examples/models/llama/attention.py index da5f1b46b29..315d4ccddbf 100644 --- a/examples/models/llama/attention.py +++ b/examples/models/llama/attention.py @@ -760,7 +760,9 @@ def __init__( ), ) - def _maybe_reset_state(self, input_pos: Optional[torch.Tensor], batch_size: int) -> None: + def _maybe_reset_state( + self, input_pos: Optional[torch.Tensor], batch_size: int + ) -> None: if input_pos is None: return reset = (input_pos[0] == 0).to(self.conv_state.dtype) @@ -830,9 +832,9 @@ def _recurrent_gated_delta_rule( last_recurrent_state = last_recurrent_state * g_t kv_mem = (last_recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2) delta = (v_t - kv_mem) * beta_t - last_recurrent_state = ( - last_recurrent_state + k_t.unsqueeze(-1) * delta.unsqueeze(-2) - ) + last_recurrent_state = last_recurrent_state + k_t.unsqueeze( + -1 + ) * delta.unsqueeze(-2) core_attn_out[:, :, i] = (last_recurrent_state * q_t.unsqueeze(-1)).sum( dim=-2 ) diff --git a/examples/models/llama/norm.py b/examples/models/llama/norm.py index 52ae30c1697..c56cc61c7a5 100644 --- a/examples/models/llama/norm.py +++ b/examples/models/llama/norm.py @@ -9,9 +9,7 @@ class RMSNorm(torch.nn.Module): - def __init__( - self, dim: int, eps: float = 1e-6, add_unit_offset: bool = False - ): + def __init__(self, dim: int, eps: float = 1e-6, add_unit_offset: bool = False): """ Initialize the RMSNorm normalization layer. diff --git a/examples/models/qwen3_5/README.md b/examples/models/qwen3_5/README.md index 315d603d655..f1de77b438a 100644 --- a/examples/models/qwen3_5/README.md +++ b/examples/models/qwen3_5/README.md @@ -1,5 +1,5 @@ ## Summary -[Qwen3.5](https://huggingface.co/collections/Qwen/qwen35-684357f8543f83de2f09f998) support in ExecuTorch is exported through the Llama example pipeline with a hybrid layer layout: +[Qwen3.5](https://huggingface.co/Qwen/Qwen3.5-4B) support in ExecuTorch is exported through the Llama example pipeline with a hybrid layer layout: - `full_attention` layers use gated full attention. - `linear_attention` layers use Gated DeltaNet with internal recurrent state. diff --git a/examples/models/qwen3_5/__init__.py b/examples/models/qwen3_5/__init__.py index 3177b5c4958..b336832bb4b 100644 --- a/examples/models/qwen3_5/__init__.py +++ b/examples/models/qwen3_5/__init__.py @@ -1,13 +1,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import TYPE_CHECKING - from executorch.examples.models.qwen3_5.convert_weights import convert_weights -if TYPE_CHECKING: - from executorch.examples.models.llama.model import Llama2Model - __all__ = ["Qwen3_5Model", "convert_weights"] diff --git a/examples/models/qwen3_5/convert_weights.py b/examples/models/qwen3_5/convert_weights.py index 1963961fdfc..3566d4264f5 100644 --- a/examples/models/qwen3_5/convert_weights.py +++ b/examples/models/qwen3_5/convert_weights.py @@ -141,8 +141,6 @@ def qwen_3_5_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Ten raise ValueError( f"Invalid packed in_proj_qkvz shape for {key}: {tuple(value.shape)}" ) - key_dim = (conv_dim - value_dim) // 2 - qkv, z = torch.split(value, [conv_dim, value_dim], dim=0) converted_state_dict[f"layers.{layer_id}.attention.in_proj_qkv.weight"] = qkv converted_state_dict[f"layers.{layer_id}.attention.in_proj_z.weight"] = z From 998718455b547f9b114defcfaef9e648ecaebfc3 Mon Sep 17 00:00:00 2001 From: Phineas1500 <41450967+Phineas1500@users.noreply.github.com> Date: Mon, 2 Mar 2026 19:02:35 -0500 Subject: [PATCH 03/15] Make Qwen3.5 full-attention output projection bias-free --- examples/models/llama/attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/models/llama/attention.py b/examples/models/llama/attention.py index 315d4ccddbf..2d84e4cc5ce 100644 --- a/examples/models/llama/attention.py +++ b/examples/models/llama/attention.py @@ -591,7 +591,7 @@ def __init__( self.wo = nn.Linear( self.n_heads * self.head_dim, self.dim, - bias=self.attention_qkv_bias, + bias=False, ) self.layer_id = layer_id From 1e646e4e250ea10ea8ab6aed4b3a6edf249d7368 Mon Sep 17 00:00:00 2001 From: Phineas1500 <41450967+Phineas1500@users.noreply.github.com> Date: Mon, 2 Mar 2026 19:41:39 -0500 Subject: [PATCH 04/15] Fix Qwen3.5 metadata ids and README details --- examples/models/qwen3_5/README.md | 11 +++++++++-- .../models/qwen3_5/config/qwen3_5_xnnpack_fp32.yaml | 6 +++++- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/examples/models/qwen3_5/README.md b/examples/models/qwen3_5/README.md index f1de77b438a..e439f256f2a 100644 --- a/examples/models/qwen3_5/README.md +++ b/examples/models/qwen3_5/README.md @@ -1,5 +1,5 @@ ## Summary -[Qwen3.5](https://huggingface.co/Qwen/Qwen3.5-4B) support in ExecuTorch is exported through the Llama example pipeline with a hybrid layer layout: +[Qwen3.5](https://huggingface.co/collections/Qwen/qwen35-684357f8543f83de2f09f998) support in ExecuTorch is exported through the Llama example pipeline with a hybrid layer layout: - `full_attention` layers use gated full attention. - `linear_attention` layers use Gated DeltaNet with internal recurrent state. @@ -32,10 +32,14 @@ python -m extension.llm.export.export_llm \ ``` The exporter will download and convert HF weights automatically when `+base.checkpoint` is not provided. +Install `safetensors` in your environment if it is missing: +```bash +python -m pip install safetensors +``` ## Run (Python Runner) ```bash -python -m examples.models.llama.runner.native \ +python -m executorch.examples.models.llama.runner.native \ --model qwen3_5_0_8b \ --pte qwen3_5_0_8b_fp32.pte \ --tokenizer /path/to/tokenizer.json \ @@ -50,3 +54,6 @@ python -m examples.models.llama.runner.native \ ## Notes - Current path targets CPU/XNNPACK export validation and runner compatibility. - `q8da4w` quantization for Qwen3.5 is intentionally deferred to a follow-up. +- Dynamic-shape export is not enabled yet for Qwen3.5 DeltaNet layers in this path; keep `enable_dynamic_shape=False`. +- For static-shape exports, `runner.native` falls back to sequential token prefill for multi-token prompts. +- Default metadata uses Qwen3.5 special token ids: `get_bos_id=248045`, `get_eos_ids=[248046,248044]`. diff --git a/examples/models/qwen3_5/config/qwen3_5_xnnpack_fp32.yaml b/examples/models/qwen3_5/config/qwen3_5_xnnpack_fp32.yaml index 496f2da6e2b..93882f67a6e 100644 --- a/examples/models/qwen3_5/config/qwen3_5_xnnpack_fp32.yaml +++ b/examples/models/qwen3_5/config/qwen3_5_xnnpack_fp32.yaml @@ -1,5 +1,9 @@ base: - metadata: '{"get_bos_id": 151644, "get_eos_ids":[151645]}' + # Qwen3.5 tokenizer special ids: + # <|im_start|> 248045 + # <|im_end|> 248046 + # <|endoftext|> 248044 + metadata: '{"get_bos_id": 248045, "get_eos_ids":[248046,248044]}' model: use_kv_cache: True From 939445a38cb1da3ad1138a006e414a3f18fda1cf Mon Sep 17 00:00:00 2001 From: Phineas1500 <41450967+Phineas1500@users.noreply.github.com> Date: Mon, 2 Mar 2026 19:55:35 -0500 Subject: [PATCH 05/15] Harden Qwen3.5 conversion and RMSNorm dtype behavior --- examples/models/llama/norm.py | 2 +- .../llama/tests/test_qwen3_5_attention.py | 27 ++++++++++ examples/models/qwen3_5/convert_weights.py | 51 +++++++++++++++---- .../qwen3_5/tests/test_convert_weights.py | 25 +++++++++ 4 files changed, 94 insertions(+), 11 deletions(-) diff --git a/examples/models/llama/norm.py b/examples/models/llama/norm.py index c56cc61c7a5..a9944db7a00 100644 --- a/examples/models/llama/norm.py +++ b/examples/models/llama/norm.py @@ -55,4 +55,4 @@ def forward(self, x): output = self._norm(x.float()).type_as(x) if self.add_unit_offset: return output * (1.0 + self.weight.float()).type_as(x) - return output * self.weight + return output * self.weight.type_as(x) diff --git a/examples/models/llama/tests/test_qwen3_5_attention.py b/examples/models/llama/tests/test_qwen3_5_attention.py index b076e69d121..9de9244f960 100644 --- a/examples/models/llama/tests/test_qwen3_5_attention.py +++ b/examples/models/llama/tests/test_qwen3_5_attention.py @@ -9,10 +9,37 @@ import torch from executorch.examples.models.llama.attention import ATTENTION_REGISTRY from executorch.examples.models.llama.model_args import ModelArgs +from executorch.examples.models.llama.norm import RMSNorm from executorch.examples.models.llama.rope import Rope class Qwen35AttentionTest(unittest.TestCase): + def test_qwen35_full_attention_output_proj_is_bias_free(self): + args = ModelArgs( + dim=32, + n_layers=1, + n_heads=4, + n_kv_heads=2, + head_dim=8, + hidden_dim=64, + max_seq_len=16, + max_context_len=16, + use_kv_cache=False, + use_qk_norm=False, + qk_norm_before_rope=True, + attention_type="qwen3_5_full", + attention_qkv_bias=True, + ) + rope = Rope(args) + attn = ATTENTION_REGISTRY["qwen3_5_full"](args, 0, rope) + self.assertIsNone(attn.wo.bias) + + def test_rmsnorm_preserves_input_dtype_without_unit_offset(self): + norm = RMSNorm(dim=8, add_unit_offset=False) + x = torch.randn(2, 3, 8, dtype=torch.bfloat16) + y = norm(x) + self.assertEqual(y.dtype, x.dtype) + def test_qwen35_full_attention_forward_shape(self): torch.manual_seed(0) args = ModelArgs( diff --git a/examples/models/qwen3_5/convert_weights.py b/examples/models/qwen3_5/convert_weights.py index 3566d4264f5..b0a8fc47305 100644 --- a/examples/models/qwen3_5/convert_weights.py +++ b/examples/models/qwen3_5/convert_weights.py @@ -40,6 +40,34 @@ } +_IGNORED_UNMAPPED_PREFIXES = ( + "mtp.", + "model.visual.", + "visual.", +) + +_IGNORED_UNMAPPED_SUBSTRINGS = ( + ".vision_", + ".visual.", +) + +_IGNORED_UNMAPPED_SUFFIXES = ( + "rotary_emb.inv_freq", +) + + +def _should_ignore_unmapped_key(key: str, normalized_key: str) -> bool: + candidates = (key, normalized_key) + for candidate in candidates: + if any(candidate.startswith(prefix) for prefix in _IGNORED_UNMAPPED_PREFIXES): + return True + if any(token in candidate for token in _IGNORED_UNMAPPED_SUBSTRINGS): + return True + if any(candidate.endswith(suffix) for suffix in _IGNORED_UNMAPPED_SUFFIXES): + return True + return False + + def _load_checkpoint_from_safetensors(input_dir: str) -> Dict: from safetensors.torch import load_file @@ -98,6 +126,14 @@ def qwen_3_5_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Ten "model.language_model.", "model.", 1 ) + # Ignore non-language-model keys up front. + if not ( + normalized_key.startswith("model.") or normalized_key.startswith("lm_head.") + ): + if _should_ignore_unmapped_key(key, normalized_key): + continue + continue + # Legacy packed tensors (older checkpoints): # in_proj_qkvz -> split into in_proj_qkv and in_proj_z # in_proj_ba -> split into in_proj_b and in_proj_a @@ -110,17 +146,12 @@ def qwen_3_5_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Ten try: new_key = get_mapped_key(normalized_key, _QWEN_3_5_TO_META) - except Exception: - # Ignore non-text weights and training-only extras (e.g., MTP). - if ( - key.startswith("mtp.") - or key.startswith("model.visual.") - or ".vision_" in key - or key.startswith("visual.") - ): + except Exception as err: + if _should_ignore_unmapped_key(key, normalized_key): continue - # Ignore unsupported keys that are not required by the export model. - continue + raise ValueError( + f"Unexpected checkpoint key not mapped for Qwen3.5 export: {key}" + ) from err converted_state_dict[new_key] = value for key, value in pending_qkvz.items(): diff --git a/examples/models/qwen3_5/tests/test_convert_weights.py b/examples/models/qwen3_5/tests/test_convert_weights.py index 0e7d1d8c8d3..7fa2fb667c7 100644 --- a/examples/models/qwen3_5/tests/test_convert_weights.py +++ b/examples/models/qwen3_5/tests/test_convert_weights.py @@ -39,6 +39,31 @@ def test_maps_full_and_linear_attention_weights(self): self.assertIn("layers.1.attention.out_proj.weight", converted) self.assertIn("layers.1.attention.dt_bias", converted) + def test_raises_on_unexpected_text_key(self): + state_dict = { + "model.embed_tokens.weight": torch.randn(16, 8), + "model.norm.weight": torch.randn(8), + "model.layers.0.unknown.weight": torch.randn(8, 8), + } + + with self.assertRaisesRegex( + ValueError, "Unexpected checkpoint key not mapped for Qwen3.5 export" + ): + qwen_3_5_to_meta(state_dict) + + def test_ignores_known_non_text_keys(self): + state_dict = { + "model.embed_tokens.weight": torch.randn(16, 8), + "model.norm.weight": torch.randn(8), + "mtp.proj.weight": torch.randn(8, 8), + "model.visual.patch_embed.weight": torch.randn(8, 8), + } + + converted = qwen_3_5_to_meta(state_dict) + self.assertIn("tok_embeddings.weight", converted) + self.assertIn("output.weight", converted) + self.assertNotIn("mtp.proj.weight", converted) + if __name__ == "__main__": unittest.main() From 108469552aef025b5303b471ca09abe4f39430f0 Mon Sep 17 00:00:00 2001 From: Phineas1500 <41450967+Phineas1500@users.noreply.github.com> Date: Tue, 3 Mar 2026 17:04:18 -0500 Subject: [PATCH 06/15] Refactor Qwen3.5 full attention into MHA and move RMSNormGated --- examples/models/llama/attention.py | 185 ++---------------- examples/models/llama/model_args.py | 5 + examples/models/llama/norm.py | 17 ++ .../llama/tests/test_qwen3_5_attention.py | 28 ++- .../models/qwen3_5/config/0_8b_config.json | 3 +- examples/models/qwen3_5/config/2b_config.json | 3 +- examples/models/qwen3_5/config/4b_config.json | 3 +- 7 files changed, 72 insertions(+), 172 deletions(-) diff --git a/examples/models/llama/attention.py b/examples/models/llama/attention.py index 2d84e4cc5ce..a7b7104ef8e 100644 --- a/examples/models/llama/attention.py +++ b/examples/models/llama/attention.py @@ -7,7 +7,7 @@ import torch.nn.functional as F from executorch.examples.models.llama.lora import LoRALinear from executorch.examples.models.llama.model_args import ModelArgs -from executorch.examples.models.llama.norm import RMSNorm +from executorch.examples.models.llama.norm import RMSNorm, RMSNormGated from executorch.examples.models.llama.rope import Rope @@ -314,6 +314,7 @@ def update( return self.k_cache, self.v_cache +@register_attention("qwen3_5_full") @register_attention("mha") class AttentionMHA(Attention): def __init__( @@ -347,7 +348,9 @@ def __init__( self.attention_qkv_bias = args.attention_qkv_bias self.use_qk_norm = args.use_qk_norm self.qk_norm_before_rope = args.qk_norm_before_rope + self.use_q_gate = args.use_q_gate self.enable_dynamic_shape = args.enable_dynamic_shape + q_out_dim = self.n_heads * self.head_dim * (2 if self.use_q_gate else 1) if self.use_qk_norm: q_norm_dim = self.head_dim @@ -366,7 +369,7 @@ def __init__( self.wq = ( LoRALinear( in_dim=args.dim, - out_dim=args.n_heads * args.head_dim, + out_dim=q_out_dim, rank=args.r, alpha=args.lora_alpha, dropout=0.0, @@ -374,7 +377,7 @@ def __init__( ) if args.target_modules is not None and "q_proj" in args.target_modules else nn.Linear( - self.dim, self.n_heads * self.head_dim, bias=self.attention_qkv_bias + self.dim, q_out_dim, bias=self.attention_qkv_bias ) ) self.wk = ( @@ -460,10 +463,17 @@ def forward( input_pos = kwargs.get("input_pos") bsz, seqlen, _ = x.shape - # QKV - q, k, v = self.wq(x), self.wk(x), self.wv(x) - # We need view_copy elimination - q = q.view(bsz, seqlen, self.n_local_heads, self.head_dim) + if self.use_q_gate: + q_and_gate = self.wq(x).view( + bsz, seqlen, self.n_local_heads, self.head_dim * 2 + ) + q, gate = torch.chunk(q_and_gate, 2, dim=-1) + gate = gate.reshape(bsz, seqlen, -1) + else: + q = self.wq(x).view(bsz, seqlen, self.n_local_heads, self.head_dim) + gate = None + + k, v = self.wk(x), self.wv(x) k = k.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) v = v.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) @@ -500,6 +510,8 @@ def forward( input_pos[0].item(), seqlen ) output = self.SDPA(input_pos, q, k, v, bsz, seqlen, attn_mask) + if gate is not None: + output = output * torch.sigmoid(gate) return self.wo(output), None # grouped multiquery attention: expand out keys and values @@ -513,6 +525,8 @@ def forward( output = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) output = output.transpose(1, 2).reshape(bsz, seqlen, -1) + if gate is not None: + output = output * torch.sigmoid(gate) output = self.wo(output) @@ -524,163 +538,6 @@ def _l2norm(x: torch.Tensor, dim: int = -1, eps: float = 1e-6) -> torch.Tensor: return x * inv_norm -class RMSNormGated(nn.Module): - def __init__(self, hidden_size: int, eps: float = 1e-6): - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states: torch.Tensor, gate: torch.Tensor) -> torch.Tensor: - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - hidden_states = self.weight * hidden_states.to(input_dtype) - hidden_states = hidden_states * F.silu(gate.to(torch.float32)) - return hidden_states.to(input_dtype) - - -@register_attention("qwen3_5_full") -class AttentionQwen3_5Full(Attention): - """Qwen3.5 full-attention block with q-gating.""" - - def __init__( - self, - args: ModelArgs, - layer_id: int, - rope: Rope, - **_kwargs: Any, - ): - super().__init__() - self.use_kv_cache = args.use_kv_cache - self.n_heads = args.n_heads - self.n_kv_heads = self.n_heads if args.n_kv_heads is None else args.n_kv_heads - assert self.n_heads % self.n_kv_heads == 0 - self.n_local_heads = self.n_heads - self.n_local_kv_heads = self.n_kv_heads - self.n_rep = self.n_local_heads // self.n_local_kv_heads - self.head_dim = args.head_dim - self.max_context_len = args.max_context_len - self.dim = args.dim - self.attention_qkv_bias = args.attention_qkv_bias - self.use_qk_norm = args.use_qk_norm - self.qk_norm_before_rope = args.qk_norm_before_rope - self.enable_dynamic_shape = args.enable_dynamic_shape - - if self.use_qk_norm: - self.q_norm_fn = RMSNorm( - self.head_dim, - eps=args.norm_eps, - add_unit_offset=args.rms_norm_add_unit_offset, - ) - self.k_norm_fn = RMSNorm( - self.head_dim, - eps=args.norm_eps, - add_unit_offset=args.rms_norm_add_unit_offset, - ) - - self.wq = nn.Linear( - self.dim, self.n_heads * self.head_dim * 2, bias=self.attention_qkv_bias - ) - self.wk = nn.Linear( - self.dim, self.n_kv_heads * self.head_dim, bias=self.attention_qkv_bias - ) - self.wv = nn.Linear( - self.dim, self.n_kv_heads * self.head_dim, bias=self.attention_qkv_bias - ) - self.wo = nn.Linear( - self.n_heads * self.head_dim, - self.dim, - bias=False, - ) - - self.layer_id = layer_id - self.rope = rope - - causal_mask = torch.tril( - torch.ones( - self.max_context_len, - self.max_context_len, - dtype=torch.bool, - device="cpu", - ) - ) - self.register_buffer("mask", causal_mask, persistent=False) - - if self.use_kv_cache: - self.kv_cache = KVCache( - args.max_batch_size, - args.max_context_len, - self.n_kv_heads, - self.head_dim, - args.enable_dynamic_shape, - ) - self.SDPA = SDPA( - dim=self.n_local_heads * self.head_dim, - head_dim=self.head_dim, - n_rep=self.n_rep, - max_context_len=self.max_context_len, - ) - - def forward( - self, - x: torch.Tensor, - freqs_cos: torch.Tensor, - freqs_sin: torch.Tensor, - **kwargs: ForwardOptions, - ) -> Tuple[torch.Tensor, Optional[Any]]: - input_pos = kwargs.get("input_pos") - bsz, seqlen, _ = x.shape - - # Q and gate are packed in q_proj output. - q_and_gate = self.wq(x).view(bsz, seqlen, self.n_local_heads, self.head_dim * 2) - q, gate = torch.chunk(q_and_gate, 2, dim=-1) - gate = gate.reshape(bsz, seqlen, -1) - - k = self.wk(x).view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) - v = self.wv(x).view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) - - if self.use_qk_norm and self.qk_norm_before_rope: - q = self.q_norm_fn(q) - k = self.k_norm_fn(k) - - q, k = self.rope.forward(q, k, freqs_cos, freqs_sin) - q = q.transpose(1, 2) - k = k.transpose(1, 2) - v = v.transpose(1, 2) - - if self.use_qk_norm and not self.qk_norm_before_rope: - q = self.q_norm_fn(q) - k = self.k_norm_fn(k) - - if self.use_kv_cache: - assert input_pos is not None - if self.enable_dynamic_shape: - start_pos = input_pos[-1].item() - torch._check_is_size(start_pos) - torch._check(start_pos < self.max_context_len) - seq_length = q.size(2) - attn_mask = self.mask.narrow(0, start_pos, seq_length) - else: - attn_mask = self.mask[input_pos] - k, v = self.kv_cache.update(input_pos, k, v) - if getattr(self.kv_cache, "is_ring_buffer", False): - attn_mask = self.kv_cache.create_causal_mask_for_ring_buffer( - input_pos[0].item(), seqlen - ) - output = self.SDPA(input_pos, q, k, v, bsz, seqlen, attn_mask) - output = output * torch.sigmoid(gate) - return self.wo(output), None - - k = k.repeat_interleave(self.n_rep, dim=1) - v = v.repeat_interleave(self.n_rep, dim=1) - mask = self.mask[:seqlen, :seqlen] - output = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) - output = output.transpose(1, 2).reshape(bsz, seqlen, -1) - output = output * torch.sigmoid(gate) - return self.wo(output), None - - @register_attention("gated_deltanet") class AttentionGatedDeltaNet(Attention): """Qwen3.5 linear-attention (Gated DeltaNet) block with internal state.""" diff --git a/examples/models/llama/model_args.py b/examples/models/llama/model_args.py index eb2a83db6ed..05e9ea62a8a 100644 --- a/examples/models/llama/model_args.py +++ b/examples/models/llama/model_args.py @@ -72,6 +72,7 @@ class ModelArgs: num_experts: int = 8 # Number of experts num_activated_experts: int = 2 # Number of experts to activate attention_type: str = "mha" # Attention type, registered in attention.py + use_q_gate: bool = False # Use q-gated projection in attention (Qwen3.5 full attention) norm_type: str = "rmsnorm" # Normalization type, registered in norm.py act_fn: ActFn = dataclasses.field(default=ActFn.SILU) # Activation function type attention_qkv_bias: bool = False @@ -156,6 +157,10 @@ def __post_init__(self): if self.n_kv_heads is None: self.n_kv_heads = self.n_heads + # Backward compatibility: qwen3_5_full attention name implies q-gated MHA. + if self.attention_type == "qwen3_5_full": + self.use_q_gate = True + # rope_theta overrides rope_freq_base since it's the official name. if self.rope_theta is not None: self.rope_freq_base = self.rope_theta diff --git a/examples/models/llama/norm.py b/examples/models/llama/norm.py index a9944db7a00..6a8217b5a24 100644 --- a/examples/models/llama/norm.py +++ b/examples/models/llama/norm.py @@ -5,6 +5,7 @@ # LICENSE file in the root directory of this source tree. import torch +import torch.nn.functional as F from torch import nn @@ -56,3 +57,19 @@ def forward(self, x): if self.add_unit_offset: return output * (1.0 + self.weight.float()).type_as(x) return output * self.weight.type_as(x) + + +class RMSNormGated(nn.Module): + def __init__(self, hidden_size: int, eps: float = 1e-6): + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states: torch.Tensor, gate: torch.Tensor) -> torch.Tensor: + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + hidden_states = self.weight * hidden_states.to(input_dtype) + hidden_states = hidden_states * F.silu(gate.to(torch.float32)) + return hidden_states.to(input_dtype) diff --git a/examples/models/llama/tests/test_qwen3_5_attention.py b/examples/models/llama/tests/test_qwen3_5_attention.py index 9de9244f960..d5598eb4196 100644 --- a/examples/models/llama/tests/test_qwen3_5_attention.py +++ b/examples/models/llama/tests/test_qwen3_5_attention.py @@ -27,11 +27,12 @@ def test_qwen35_full_attention_output_proj_is_bias_free(self): use_kv_cache=False, use_qk_norm=False, qk_norm_before_rope=True, - attention_type="qwen3_5_full", + attention_type="mha", + use_q_gate=True, attention_qkv_bias=True, ) rope = Rope(args) - attn = ATTENTION_REGISTRY["qwen3_5_full"](args, 0, rope) + attn = ATTENTION_REGISTRY["mha"](args, 0, rope) self.assertIsNone(attn.wo.bias) def test_rmsnorm_preserves_input_dtype_without_unit_offset(self): @@ -56,16 +57,32 @@ def test_qwen35_full_attention_forward_shape(self): partial_rotary_factor=0.5, use_qk_norm=True, qk_norm_before_rope=True, - attention_type="qwen3_5_full", + attention_type="mha", + use_q_gate=True, rms_norm_add_unit_offset=True, ) rope = Rope(args) - attn = ATTENTION_REGISTRY["qwen3_5_full"](args, 0, rope) + attn = ATTENTION_REGISTRY["mha"](args, 0, rope) x = torch.randn(1, 3, args.dim) freqs_cos, freqs_sin = rope.get_freqs(None, x.shape[1]) y, _ = attn(x, freqs_cos, freqs_sin) self.assertEqual(y.shape, x.shape) + def test_qwen35_full_attention_legacy_name_maps_to_gated_mha(self): + args = ModelArgs( + dim=32, + n_layers=1, + n_heads=4, + n_kv_heads=2, + head_dim=8, + hidden_dim=64, + attention_type="qwen3_5_full", + ) + self.assertTrue(args.use_q_gate) + rope = Rope(args) + attn = ATTENTION_REGISTRY["qwen3_5_full"](args, 0, rope) + self.assertTrue(attn.use_q_gate) + def test_gated_deltanet_resets_state_on_new_sequence(self): torch.manual_seed(0) args = ModelArgs( @@ -78,7 +95,8 @@ def test_gated_deltanet_resets_state_on_new_sequence(self): max_seq_len=16, max_context_len=16, use_kv_cache=True, - attention_type="qwen3_5_full", + attention_type="mha", + use_q_gate=True, linear_conv_kernel_dim=4, linear_key_head_dim=4, linear_value_head_dim=4, diff --git a/examples/models/qwen3_5/config/0_8b_config.json b/examples/models/qwen3_5/config/0_8b_config.json index e532988410d..89799cf9f53 100644 --- a/examples/models/qwen3_5/config/0_8b_config.json +++ b/examples/models/qwen3_5/config/0_8b_config.json @@ -14,7 +14,8 @@ "attention_qkv_bias": false, "use_qk_norm": true, "qk_norm_before_rope": true, - "attention_type": "qwen3_5_full", + "attention_type": "mha", + "use_q_gate": true, "rms_norm_add_unit_offset": true, "linear_conv_kernel_dim": 4, "linear_key_head_dim": 128, diff --git a/examples/models/qwen3_5/config/2b_config.json b/examples/models/qwen3_5/config/2b_config.json index e397bf6fd4a..15f62ccca74 100644 --- a/examples/models/qwen3_5/config/2b_config.json +++ b/examples/models/qwen3_5/config/2b_config.json @@ -14,7 +14,8 @@ "attention_qkv_bias": false, "use_qk_norm": true, "qk_norm_before_rope": true, - "attention_type": "qwen3_5_full", + "attention_type": "mha", + "use_q_gate": true, "rms_norm_add_unit_offset": true, "linear_conv_kernel_dim": 4, "linear_key_head_dim": 128, diff --git a/examples/models/qwen3_5/config/4b_config.json b/examples/models/qwen3_5/config/4b_config.json index 75c4313b83a..068653a8d5c 100644 --- a/examples/models/qwen3_5/config/4b_config.json +++ b/examples/models/qwen3_5/config/4b_config.json @@ -14,7 +14,8 @@ "attention_qkv_bias": false, "use_qk_norm": true, "qk_norm_before_rope": true, - "attention_type": "qwen3_5_full", + "attention_type": "mha", + "use_q_gate": true, "rms_norm_add_unit_offset": true, "linear_conv_kernel_dim": 4, "linear_key_head_dim": 128, From a8b93cfe9c6d7c8ad5fd9443ff7ab3bed11e0315 Mon Sep 17 00:00:00 2001 From: Sriram Kiron Date: Wed, 4 Mar 2026 18:46:03 -0500 Subject: [PATCH 07/15] Harden Qwen3.5 conversion key handling and tests --- examples/models/llama/norm.py | 2 + examples/models/qwen3_5/convert_weights.py | 20 ++++--- .../qwen3_5/tests/test_convert_weights.py | 54 +++++++++++++++++++ 3 files changed, 69 insertions(+), 7 deletions(-) diff --git a/examples/models/llama/norm.py b/examples/models/llama/norm.py index 6a8217b5a24..0189c88b13b 100644 --- a/examples/models/llama/norm.py +++ b/examples/models/llama/norm.py @@ -17,6 +17,8 @@ def __init__(self, dim: int, eps: float = 1e-6, add_unit_offset: bool = False): Args: dim (int): The dimension of the input tensor. eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. + add_unit_offset (bool, optional): Whether to scale normalized output by + `(1 + weight)` instead of `weight`. Default is False. Attributes: eps (float): A small value added to the denominator for numerical stability. diff --git a/examples/models/qwen3_5/convert_weights.py b/examples/models/qwen3_5/convert_weights.py index b0a8fc47305..d08ec441f0a 100644 --- a/examples/models/qwen3_5/convert_weights.py +++ b/examples/models/qwen3_5/convert_weights.py @@ -32,7 +32,6 @@ "model.layers.{}.linear_attn.in_proj_b.weight": "layers.{}.attention.in_proj_b.weight", "model.layers.{}.linear_attn.in_proj_a.weight": "layers.{}.attention.in_proj_a.weight", "model.layers.{}.linear_attn.conv1d.weight": "layers.{}.attention.conv1d.weight", - "model.layers.{}.linear_attn.conv1d.bias": "layers.{}.attention.conv1d.bias", "model.layers.{}.linear_attn.dt_bias": "layers.{}.attention.dt_bias", "model.layers.{}.linear_attn.A_log": "layers.{}.attention.A_log", "model.layers.{}.linear_attn.norm.weight": "layers.{}.attention.norm.weight", @@ -53,6 +52,7 @@ _IGNORED_UNMAPPED_SUFFIXES = ( "rotary_emb.inv_freq", + "linear_attn.conv1d.bias", ) @@ -78,13 +78,17 @@ def _load_checkpoint_from_safetensors(input_dir: str) -> Dict: weight_map = index["weight_map"] checkpoint_shards = sorted(set(weight_map.values())) - shard_to_weights = {} - for shard in checkpoint_shards: - shard_to_weights[shard] = load_file(os.path.join(input_dir, shard)) - merged_state_dict = {} + shard_to_weight_names = {} for weight_name, shard in weight_map.items(): - merged_state_dict[weight_name] = shard_to_weights[shard][weight_name] + shard_to_weight_names.setdefault(shard, []).append(weight_name) + + # Load each shard once and copy only the tensor names mapped to that shard. + # This avoids holding all shard tensors in memory at the same time. + for shard in checkpoint_shards: + shard_weights = load_file(os.path.join(input_dir, shard)) + for weight_name in shard_to_weight_names[shard]: + merged_state_dict[weight_name] = shard_weights[weight_name] return merged_state_dict model_path = os.path.join(input_dir, "model.safetensors") @@ -132,7 +136,9 @@ def qwen_3_5_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Ten ): if _should_ignore_unmapped_key(key, normalized_key): continue - continue + raise ValueError( + f"Unexpected non-text checkpoint key not mapped for Qwen3.5 export: {key}" + ) # Legacy packed tensors (older checkpoints): # in_proj_qkvz -> split into in_proj_qkv and in_proj_z diff --git a/examples/models/qwen3_5/tests/test_convert_weights.py b/examples/models/qwen3_5/tests/test_convert_weights.py index 7fa2fb667c7..52b6981c7fb 100644 --- a/examples/models/qwen3_5/tests/test_convert_weights.py +++ b/examples/models/qwen3_5/tests/test_convert_weights.py @@ -64,6 +64,60 @@ def test_ignores_known_non_text_keys(self): self.assertIn("output.weight", converted) self.assertNotIn("mtp.proj.weight", converted) + def test_raises_on_unexpected_non_text_key(self): + state_dict = { + "model.embed_tokens.weight": torch.randn(16, 8), + "model.norm.weight": torch.randn(8), + "vision_tower.blocks.0.weight": torch.randn(8, 8), + } + + with self.assertRaisesRegex( + ValueError, + "Unexpected non-text checkpoint key not mapped for Qwen3.5 export", + ): + qwen_3_5_to_meta(state_dict) + + def test_ignores_linear_attention_conv1d_bias(self): + state_dict = { + "model.embed_tokens.weight": torch.randn(16, 8), + "model.norm.weight": torch.randn(8), + "model.layers.1.linear_attn.conv1d.weight": torch.randn(24, 1, 4), + "model.layers.1.linear_attn.conv1d.bias": torch.randn(24), + "model.layers.1.linear_attn.out_proj.weight": torch.randn(8, 8), + } + + converted = qwen_3_5_to_meta(state_dict) + self.assertIn("layers.1.attention.conv1d.weight", converted) + self.assertIn("layers.1.attention.out_proj.weight", converted) + self.assertNotIn("layers.1.attention.conv1d.bias", converted) + + def test_splits_legacy_packed_linear_attention_weights(self): + qkvz = torch.arange(32 * 8, dtype=torch.float32).reshape(32, 8) + ba = torch.arange(4 * 8, dtype=torch.float32).reshape(4, 8) + out_proj = torch.randn(8, 8) + state_dict = { + "model.embed_tokens.weight": torch.randn(16, 8), + "model.norm.weight": torch.randn(8), + "model.layers.1.linear_attn.out_proj.weight": out_proj, + "model.layers.1.linear_attn.in_proj_qkvz.weight": qkvz, + "model.layers.1.linear_attn.in_proj_ba.weight": ba, + } + + converted = qwen_3_5_to_meta(state_dict) + + self.assertTrue( + torch.equal(converted["layers.1.attention.in_proj_qkv.weight"], qkvz[:24]) + ) + self.assertTrue( + torch.equal(converted["layers.1.attention.in_proj_z.weight"], qkvz[24:]) + ) + self.assertTrue( + torch.equal(converted["layers.1.attention.in_proj_b.weight"], ba[:2]) + ) + self.assertTrue( + torch.equal(converted["layers.1.attention.in_proj_a.weight"], ba[2:]) + ) + if __name__ == "__main__": unittest.main() From 0d6af31d5a92715aca6bdb4d076affb25e82d5a8 Mon Sep 17 00:00:00 2001 From: Sriram Kiron Date: Wed, 4 Mar 2026 19:08:27 -0500 Subject: [PATCH 08/15] Reset gated DeltaNet state when input_pos is omitted --- examples/models/llama/attention.py | 2 ++ .../llama/tests/test_qwen3_5_attention.py | 34 +++++++++++++++++++ 2 files changed, 36 insertions(+) diff --git a/examples/models/llama/attention.py b/examples/models/llama/attention.py index a7b7104ef8e..ed536f88b3d 100644 --- a/examples/models/llama/attention.py +++ b/examples/models/llama/attention.py @@ -621,6 +621,8 @@ def _maybe_reset_state( self, input_pos: Optional[torch.Tensor], batch_size: int ) -> None: if input_pos is None: + self.conv_state[:batch_size].zero_() + self.recurrent_state[:batch_size].zero_() return reset = (input_pos[0] == 0).to(self.conv_state.dtype) keep = 1.0 - reset diff --git a/examples/models/llama/tests/test_qwen3_5_attention.py b/examples/models/llama/tests/test_qwen3_5_attention.py index d5598eb4196..71e4a14634e 100644 --- a/examples/models/llama/tests/test_qwen3_5_attention.py +++ b/examples/models/llama/tests/test_qwen3_5_attention.py @@ -123,6 +123,40 @@ def test_gated_deltanet_resets_state_on_new_sequence(self): state_after_reset = attn.recurrent_state.clone() self.assertTrue(torch.allclose(state_after_first, state_after_reset, atol=1e-5)) + def test_gated_deltanet_no_input_pos_does_not_leak_state(self): + torch.manual_seed(0) + args = ModelArgs( + dim=32, + n_layers=1, + n_heads=4, + n_kv_heads=2, + head_dim=8, + hidden_dim=64, + max_seq_len=16, + max_context_len=16, + use_kv_cache=True, + attention_type="mha", + use_q_gate=True, + linear_conv_kernel_dim=4, + linear_key_head_dim=4, + linear_value_head_dim=4, + linear_num_key_heads=2, + linear_num_value_heads=4, + ) + rope = Rope(args) + attn = ATTENTION_REGISTRY["gated_deltanet"](args, 0, rope) + + x = torch.randn(1, 1, args.dim) + dummy_freq = torch.zeros(1, 1) + + attn(x, dummy_freq, dummy_freq) + state_after_first = attn.recurrent_state.clone() + + attn(x, dummy_freq, dummy_freq) + state_after_second = attn.recurrent_state.clone() + + self.assertTrue(torch.allclose(state_after_first, state_after_second, atol=1e-5)) + if __name__ == "__main__": unittest.main() From c5627304d1ca1f90f662bfcd0f440d4072653899 Mon Sep 17 00:00:00 2001 From: Sriram Kiron Date: Wed, 4 Mar 2026 19:10:06 -0500 Subject: [PATCH 09/15] Mark existing complex export/config helpers for lint --- examples/models/llama/export_llama_lib.py | 2 +- examples/models/llama/model_args.py | 2 +- examples/models/qwen3_5/convert_weights.py | 4 +++- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py index 94984565e10..2226580a62f 100644 --- a/examples/models/llama/export_llama_lib.py +++ b/examples/models/llama/export_llama_lib.py @@ -626,7 +626,7 @@ def canonical_path(path: Union[str, Path], *, dir: bool = False) -> str: return return_val -def export_llama( +def export_llama( # noqa: C901 export_options: Union[argparse.Namespace, LlmConfig, DictConfig], ) -> str: if isinstance(export_options, argparse.Namespace): diff --git a/examples/models/llama/model_args.py b/examples/models/llama/model_args.py index 05e9ea62a8a..3bfb095be3d 100644 --- a/examples/models/llama/model_args.py +++ b/examples/models/llama/model_args.py @@ -153,7 +153,7 @@ class ModelArgs: final_logit_softcapping: Optional[float] = None attn_logit_softcapping: Optional[float] = None - def __post_init__(self): + def __post_init__(self): # noqa: C901 if self.n_kv_heads is None: self.n_kv_heads = self.n_heads diff --git a/examples/models/qwen3_5/convert_weights.py b/examples/models/qwen3_5/convert_weights.py index d08ec441f0a..afd3f1a3117 100644 --- a/examples/models/qwen3_5/convert_weights.py +++ b/examples/models/qwen3_5/convert_weights.py @@ -116,7 +116,9 @@ def load_checkpoint(input_dir: str) -> Dict: raise FileNotFoundError(f"Could not find checkpoint in {input_dir}") -def qwen_3_5_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: +def qwen_3_5_to_meta( # noqa: C901 + state_dict: Dict[str, torch.Tensor], +) -> Dict[str, torch.Tensor]: converted_state_dict = {} pending_qkvz = {} pending_ba = {} From b532cb09997b5b7296fb41b047e1d166ca0237a9 Mon Sep 17 00:00:00 2001 From: Phineas1500 <41450967+Phineas1500@users.noreply.github.com> Date: Thu, 5 Mar 2026 00:32:29 -0500 Subject: [PATCH 10/15] Address Qwen3.5 review nits on attention and weight conversion --- examples/models/llama/attention.py | 1 - examples/models/llama/model_args.py | 4 -- .../llama/tests/test_qwen3_5_attention.py | 60 ++++--------------- examples/models/qwen3_5/convert_weights.py | 56 ++--------------- .../qwen3_5/tests/test_convert_weights.py | 28 --------- 5 files changed, 16 insertions(+), 133 deletions(-) diff --git a/examples/models/llama/attention.py b/examples/models/llama/attention.py index ed536f88b3d..ff7ebc2637f 100644 --- a/examples/models/llama/attention.py +++ b/examples/models/llama/attention.py @@ -314,7 +314,6 @@ def update( return self.k_cache, self.v_cache -@register_attention("qwen3_5_full") @register_attention("mha") class AttentionMHA(Attention): def __init__( diff --git a/examples/models/llama/model_args.py b/examples/models/llama/model_args.py index 3bfb095be3d..a753f49f818 100644 --- a/examples/models/llama/model_args.py +++ b/examples/models/llama/model_args.py @@ -157,10 +157,6 @@ def __post_init__(self): # noqa: C901 if self.n_kv_heads is None: self.n_kv_heads = self.n_heads - # Backward compatibility: qwen3_5_full attention name implies q-gated MHA. - if self.attention_type == "qwen3_5_full": - self.use_q_gate = True - # rope_theta overrides rope_freq_base since it's the official name. if self.rope_theta is not None: self.rope_freq_base = self.rope_theta diff --git a/examples/models/llama/tests/test_qwen3_5_attention.py b/examples/models/llama/tests/test_qwen3_5_attention.py index 71e4a14634e..477f8424f9d 100644 --- a/examples/models/llama/tests/test_qwen3_5_attention.py +++ b/examples/models/llama/tests/test_qwen3_5_attention.py @@ -14,8 +14,8 @@ class Qwen35AttentionTest(unittest.TestCase): - def test_qwen35_full_attention_output_proj_is_bias_free(self): - args = ModelArgs( + def _make_args(self, **kwargs) -> ModelArgs: + defaults = dict( dim=32, n_layers=1, n_heads=4, @@ -24,10 +24,16 @@ def test_qwen35_full_attention_output_proj_is_bias_free(self): hidden_dim=64, max_seq_len=16, max_context_len=16, + attention_type="mha", + ) + defaults.update(kwargs) + return ModelArgs(**defaults) + + def test_qwen35_full_attention_output_proj_is_bias_free(self): + args = self._make_args( use_kv_cache=False, use_qk_norm=False, qk_norm_before_rope=True, - attention_type="mha", use_q_gate=True, attention_qkv_bias=True, ) @@ -43,21 +49,12 @@ def test_rmsnorm_preserves_input_dtype_without_unit_offset(self): def test_qwen35_full_attention_forward_shape(self): torch.manual_seed(0) - args = ModelArgs( - dim=32, - n_layers=1, - n_heads=4, - n_kv_heads=2, - head_dim=8, - hidden_dim=64, - max_seq_len=16, - max_context_len=16, + args = self._make_args( use_kv_cache=False, use_hf_rope=True, partial_rotary_factor=0.5, use_qk_norm=True, qk_norm_before_rope=True, - attention_type="mha", use_q_gate=True, rms_norm_add_unit_offset=True, ) @@ -68,34 +65,10 @@ def test_qwen35_full_attention_forward_shape(self): y, _ = attn(x, freqs_cos, freqs_sin) self.assertEqual(y.shape, x.shape) - def test_qwen35_full_attention_legacy_name_maps_to_gated_mha(self): - args = ModelArgs( - dim=32, - n_layers=1, - n_heads=4, - n_kv_heads=2, - head_dim=8, - hidden_dim=64, - attention_type="qwen3_5_full", - ) - self.assertTrue(args.use_q_gate) - rope = Rope(args) - attn = ATTENTION_REGISTRY["qwen3_5_full"](args, 0, rope) - self.assertTrue(attn.use_q_gate) - def test_gated_deltanet_resets_state_on_new_sequence(self): torch.manual_seed(0) - args = ModelArgs( - dim=32, - n_layers=1, - n_heads=4, - n_kv_heads=2, - head_dim=8, - hidden_dim=64, - max_seq_len=16, - max_context_len=16, + args = self._make_args( use_kv_cache=True, - attention_type="mha", use_q_gate=True, linear_conv_kernel_dim=4, linear_key_head_dim=4, @@ -125,17 +98,8 @@ def test_gated_deltanet_resets_state_on_new_sequence(self): def test_gated_deltanet_no_input_pos_does_not_leak_state(self): torch.manual_seed(0) - args = ModelArgs( - dim=32, - n_layers=1, - n_heads=4, - n_kv_heads=2, - head_dim=8, - hidden_dim=64, - max_seq_len=16, - max_context_len=16, + args = self._make_args( use_kv_cache=True, - attention_type="mha", use_q_gate=True, linear_conv_kernel_dim=4, linear_key_head_dim=4, diff --git a/examples/models/qwen3_5/convert_weights.py b/examples/models/qwen3_5/convert_weights.py index afd3f1a3117..d219e0000ad 100644 --- a/examples/models/qwen3_5/convert_weights.py +++ b/examples/models/qwen3_5/convert_weights.py @@ -1,7 +1,6 @@ import argparse import json import os -import re from typing import Dict import torch @@ -120,8 +119,6 @@ def qwen_3_5_to_meta( # noqa: C901 state_dict: Dict[str, torch.Tensor], ) -> Dict[str, torch.Tensor]: converted_state_dict = {} - pending_qkvz = {} - pending_ba = {} for key, value in state_dict.items(): normalized_key = key @@ -134,7 +131,10 @@ def qwen_3_5_to_meta( # noqa: C901 # Ignore non-language-model keys up front. if not ( - normalized_key.startswith("model.") or normalized_key.startswith("lm_head.") + normalized_key.startswith("model.layers.") + or normalized_key.startswith("model.embed") + or normalized_key.startswith("model.norm") + or normalized_key.startswith("lm_head") ): if _should_ignore_unmapped_key(key, normalized_key): continue @@ -142,16 +142,6 @@ def qwen_3_5_to_meta( # noqa: C901 f"Unexpected non-text checkpoint key not mapped for Qwen3.5 export: {key}" ) - # Legacy packed tensors (older checkpoints): - # in_proj_qkvz -> split into in_proj_qkv and in_proj_z - # in_proj_ba -> split into in_proj_b and in_proj_a - if normalized_key.endswith(".linear_attn.in_proj_qkvz.weight"): - pending_qkvz[normalized_key] = value - continue - if normalized_key.endswith(".linear_attn.in_proj_ba.weight"): - pending_ba[normalized_key] = value - continue - try: new_key = get_mapped_key(normalized_key, _QWEN_3_5_TO_META) except Exception as err: @@ -162,44 +152,6 @@ def qwen_3_5_to_meta( # noqa: C901 ) from err converted_state_dict[new_key] = value - for key, value in pending_qkvz.items(): - layer_match = re.search(r"model\.layers\.(\d+)\.", key) - if layer_match is None: - raise ValueError(f"Failed to parse layer id from key: {key}") - layer_id = layer_match.group(1) - out_proj_key = f"layers.{layer_id}.attention.out_proj.weight" - if out_proj_key not in converted_state_dict: - raise ValueError( - f"Cannot split {key}: missing {out_proj_key} to infer value dimension." - ) - - value_dim = converted_state_dict[out_proj_key].shape[1] - total_dim = value.shape[0] - conv_dim = total_dim - value_dim - if conv_dim <= 0 or (conv_dim - value_dim) % 2 != 0: - raise ValueError( - f"Invalid packed in_proj_qkvz shape for {key}: {tuple(value.shape)}" - ) - qkv, z = torch.split(value, [conv_dim, value_dim], dim=0) - converted_state_dict[f"layers.{layer_id}.attention.in_proj_qkv.weight"] = qkv - converted_state_dict[f"layers.{layer_id}.attention.in_proj_z.weight"] = z - print(f"Split legacy packed key {key} -> in_proj_qkv + in_proj_z") - - for key, value in pending_ba.items(): - layer_match = re.search(r"model\.layers\.(\d+)\.", key) - if layer_match is None: - raise ValueError(f"Failed to parse layer id from key: {key}") - layer_id = layer_match.group(1) - if value.shape[0] % 2 != 0: - raise ValueError( - f"Invalid packed in_proj_ba shape for {key}: {tuple(value.shape)}" - ) - half = value.shape[0] // 2 - b, a = torch.split(value, [half, half], dim=0) - converted_state_dict[f"layers.{layer_id}.attention.in_proj_b.weight"] = b - converted_state_dict[f"layers.{layer_id}.attention.in_proj_a.weight"] = a - print(f"Split legacy packed key {key} -> in_proj_b + in_proj_a") - # Handle tied embeddings. if "lm_head.weight" not in state_dict: converted_state_dict["output.weight"] = converted_state_dict[ diff --git a/examples/models/qwen3_5/tests/test_convert_weights.py b/examples/models/qwen3_5/tests/test_convert_weights.py index 52b6981c7fb..b4d2c0056d0 100644 --- a/examples/models/qwen3_5/tests/test_convert_weights.py +++ b/examples/models/qwen3_5/tests/test_convert_weights.py @@ -91,33 +91,5 @@ def test_ignores_linear_attention_conv1d_bias(self): self.assertIn("layers.1.attention.out_proj.weight", converted) self.assertNotIn("layers.1.attention.conv1d.bias", converted) - def test_splits_legacy_packed_linear_attention_weights(self): - qkvz = torch.arange(32 * 8, dtype=torch.float32).reshape(32, 8) - ba = torch.arange(4 * 8, dtype=torch.float32).reshape(4, 8) - out_proj = torch.randn(8, 8) - state_dict = { - "model.embed_tokens.weight": torch.randn(16, 8), - "model.norm.weight": torch.randn(8), - "model.layers.1.linear_attn.out_proj.weight": out_proj, - "model.layers.1.linear_attn.in_proj_qkvz.weight": qkvz, - "model.layers.1.linear_attn.in_proj_ba.weight": ba, - } - - converted = qwen_3_5_to_meta(state_dict) - - self.assertTrue( - torch.equal(converted["layers.1.attention.in_proj_qkv.weight"], qkvz[:24]) - ) - self.assertTrue( - torch.equal(converted["layers.1.attention.in_proj_z.weight"], qkvz[24:]) - ) - self.assertTrue( - torch.equal(converted["layers.1.attention.in_proj_b.weight"], ba[:2]) - ) - self.assertTrue( - torch.equal(converted["layers.1.attention.in_proj_a.weight"], ba[2:]) - ) - - if __name__ == "__main__": unittest.main() From 36ad247bf4b0ebcf20d2b104e79673d3b933f428 Mon Sep 17 00:00:00 2001 From: Phineas1500 <41450967+Phineas1500@users.noreply.github.com> Date: Thu, 5 Mar 2026 14:32:45 -0500 Subject: [PATCH 11/15] Simplify Qwen3.5 key filtering --- examples/models/qwen3_5/convert_weights.py | 45 +++++-------------- .../qwen3_5/tests/test_convert_weights.py | 32 +++++++++---- 2 files changed, 33 insertions(+), 44 deletions(-) diff --git a/examples/models/qwen3_5/convert_weights.py b/examples/models/qwen3_5/convert_weights.py index d219e0000ad..478fa5de6c7 100644 --- a/examples/models/qwen3_5/convert_weights.py +++ b/examples/models/qwen3_5/convert_weights.py @@ -38,35 +38,12 @@ } -_IGNORED_UNMAPPED_PREFIXES = ( - "mtp.", - "model.visual.", - "visual.", -) - -_IGNORED_UNMAPPED_SUBSTRINGS = ( - ".vision_", - ".visual.", -) - _IGNORED_UNMAPPED_SUFFIXES = ( "rotary_emb.inv_freq", "linear_attn.conv1d.bias", ) -def _should_ignore_unmapped_key(key: str, normalized_key: str) -> bool: - candidates = (key, normalized_key) - for candidate in candidates: - if any(candidate.startswith(prefix) for prefix in _IGNORED_UNMAPPED_PREFIXES): - return True - if any(token in candidate for token in _IGNORED_UNMAPPED_SUBSTRINGS): - return True - if any(candidate.endswith(suffix) for suffix in _IGNORED_UNMAPPED_SUFFIXES): - return True - return False - - def _load_checkpoint_from_safetensors(input_dir: str) -> Dict: from safetensors.torch import load_file @@ -129,23 +106,21 @@ def qwen_3_5_to_meta( # noqa: C901 "model.language_model.", "model.", 1 ) - # Ignore non-language-model keys up front. - if not ( - normalized_key.startswith("model.layers.") - or normalized_key.startswith("model.embed") - or normalized_key.startswith("model.norm") - or normalized_key.startswith("lm_head") - ): - if _should_ignore_unmapped_key(key, normalized_key): - continue - raise ValueError( - f"Unexpected non-text checkpoint key not mapped for Qwen3.5 export: {key}" + # Ignore non-text-model keys up front. + if not normalized_key.startswith( + ( + "model.layers.", + "model.embed_tokens.", + "model.norm.", + "lm_head.", ) + ): + continue try: new_key = get_mapped_key(normalized_key, _QWEN_3_5_TO_META) except Exception as err: - if _should_ignore_unmapped_key(key, normalized_key): + if normalized_key.endswith(_IGNORED_UNMAPPED_SUFFIXES): continue raise ValueError( f"Unexpected checkpoint key not mapped for Qwen3.5 export: {key}" diff --git a/examples/models/qwen3_5/tests/test_convert_weights.py b/examples/models/qwen3_5/tests/test_convert_weights.py index b4d2c0056d0..95acc99363d 100644 --- a/examples/models/qwen3_5/tests/test_convert_weights.py +++ b/examples/models/qwen3_5/tests/test_convert_weights.py @@ -64,18 +64,20 @@ def test_ignores_known_non_text_keys(self): self.assertIn("output.weight", converted) self.assertNotIn("mtp.proj.weight", converted) - def test_raises_on_unexpected_non_text_key(self): + def test_maps_multimodal_language_model_keys(self): state_dict = { - "model.embed_tokens.weight": torch.randn(16, 8), - "model.norm.weight": torch.randn(8), - "vision_tower.blocks.0.weight": torch.randn(8, 8), + "model.language_model.embed_tokens.weight": torch.randn(16, 8), + "model.language_model.norm.weight": torch.randn(8), + "model.language_model.layers.0.self_attn.q_proj.weight": torch.randn( + 16, 8 + ), } - with self.assertRaisesRegex( - ValueError, - "Unexpected non-text checkpoint key not mapped for Qwen3.5 export", - ): - qwen_3_5_to_meta(state_dict) + converted = qwen_3_5_to_meta(state_dict) + self.assertIn("tok_embeddings.weight", converted) + self.assertIn("norm.weight", converted) + self.assertIn("layers.0.attention.wq.weight", converted) + self.assertIn("output.weight", converted) def test_ignores_linear_attention_conv1d_bias(self): state_dict = { @@ -91,5 +93,17 @@ def test_ignores_linear_attention_conv1d_bias(self): self.assertIn("layers.1.attention.out_proj.weight", converted) self.assertNotIn("layers.1.attention.conv1d.bias", converted) + def test_ignores_rotary_emb_inv_freq(self): + state_dict = { + "model.embed_tokens.weight": torch.randn(16, 8), + "model.norm.weight": torch.randn(8), + "model.layers.0.self_attn.rotary_emb.inv_freq": torch.randn(4), + } + + converted = qwen_3_5_to_meta(state_dict) + self.assertIn("tok_embeddings.weight", converted) + self.assertIn("output.weight", converted) + self.assertNotIn("model.layers.0.self_attn.rotary_emb.inv_freq", converted) + if __name__ == "__main__": unittest.main() From 353ef7bff0c5cbab68bafca2caa5d1d9743a6231 Mon Sep 17 00:00:00 2001 From: Phineas1500 <41450967+Phineas1500@users.noreply.github.com> Date: Thu, 5 Mar 2026 16:23:05 -0500 Subject: [PATCH 12/15] Apply lintrunner formatting --- examples/models/qwen3_5/tests/test_convert_weights.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/examples/models/qwen3_5/tests/test_convert_weights.py b/examples/models/qwen3_5/tests/test_convert_weights.py index 95acc99363d..7c501b44de8 100644 --- a/examples/models/qwen3_5/tests/test_convert_weights.py +++ b/examples/models/qwen3_5/tests/test_convert_weights.py @@ -68,9 +68,7 @@ def test_maps_multimodal_language_model_keys(self): state_dict = { "model.language_model.embed_tokens.weight": torch.randn(16, 8), "model.language_model.norm.weight": torch.randn(8), - "model.language_model.layers.0.self_attn.q_proj.weight": torch.randn( - 16, 8 - ), + "model.language_model.layers.0.self_attn.q_proj.weight": torch.randn(16, 8), } converted = qwen_3_5_to_meta(state_dict) @@ -105,5 +103,6 @@ def test_ignores_rotary_emb_inv_freq(self): self.assertIn("output.weight", converted) self.assertNotIn("model.layers.0.self_attn.rotary_emb.inv_freq", converted) + if __name__ == "__main__": unittest.main() From fc193f05cd3c71e24d8d880f4fd0135feff5c385 Mon Sep 17 00:00:00 2001 From: Phineas1500 <41450967+Phineas1500@users.noreply.github.com> Date: Thu, 5 Mar 2026 16:27:44 -0500 Subject: [PATCH 13/15] Apply full lintrunner formatting --- examples/models/llama/attention.py | 4 +--- examples/models/llama/model_args.py | 4 +++- examples/models/llama/tests/test_qwen3_5_attention.py | 4 +++- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/examples/models/llama/attention.py b/examples/models/llama/attention.py index ff7ebc2637f..7f2b2d4e337 100644 --- a/examples/models/llama/attention.py +++ b/examples/models/llama/attention.py @@ -375,9 +375,7 @@ def __init__( use_bias=args.attention_qkv_bias, ) if args.target_modules is not None and "q_proj" in args.target_modules - else nn.Linear( - self.dim, q_out_dim, bias=self.attention_qkv_bias - ) + else nn.Linear(self.dim, q_out_dim, bias=self.attention_qkv_bias) ) self.wk = ( LoRALinear( diff --git a/examples/models/llama/model_args.py b/examples/models/llama/model_args.py index a753f49f818..a3380417316 100644 --- a/examples/models/llama/model_args.py +++ b/examples/models/llama/model_args.py @@ -72,7 +72,9 @@ class ModelArgs: num_experts: int = 8 # Number of experts num_activated_experts: int = 2 # Number of experts to activate attention_type: str = "mha" # Attention type, registered in attention.py - use_q_gate: bool = False # Use q-gated projection in attention (Qwen3.5 full attention) + use_q_gate: bool = ( + False # Use q-gated projection in attention (Qwen3.5 full attention) + ) norm_type: str = "rmsnorm" # Normalization type, registered in norm.py act_fn: ActFn = dataclasses.field(default=ActFn.SILU) # Activation function type attention_qkv_bias: bool = False diff --git a/examples/models/llama/tests/test_qwen3_5_attention.py b/examples/models/llama/tests/test_qwen3_5_attention.py index 477f8424f9d..83d9b0e5aaa 100644 --- a/examples/models/llama/tests/test_qwen3_5_attention.py +++ b/examples/models/llama/tests/test_qwen3_5_attention.py @@ -119,7 +119,9 @@ def test_gated_deltanet_no_input_pos_does_not_leak_state(self): attn(x, dummy_freq, dummy_freq) state_after_second = attn.recurrent_state.clone() - self.assertTrue(torch.allclose(state_after_first, state_after_second, atol=1e-5)) + self.assertTrue( + torch.allclose(state_after_first, state_after_second, atol=1e-5) + ) if __name__ == "__main__": From 2f1fabc071c2417cbe21b7b7533eefb1a0a08e16 Mon Sep 17 00:00:00 2001 From: Phineas1500 <41450967+Phineas1500@users.noreply.github.com> Date: Thu, 5 Mar 2026 17:11:59 -0500 Subject: [PATCH 14/15] Fix Qwen3.5 attention test lint nit --- .../llama/tests/test_qwen3_5_attention.py | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/examples/models/llama/tests/test_qwen3_5_attention.py b/examples/models/llama/tests/test_qwen3_5_attention.py index 83d9b0e5aaa..5a9f67d57cf 100644 --- a/examples/models/llama/tests/test_qwen3_5_attention.py +++ b/examples/models/llama/tests/test_qwen3_5_attention.py @@ -15,17 +15,17 @@ class Qwen35AttentionTest(unittest.TestCase): def _make_args(self, **kwargs) -> ModelArgs: - defaults = dict( - dim=32, - n_layers=1, - n_heads=4, - n_kv_heads=2, - head_dim=8, - hidden_dim=64, - max_seq_len=16, - max_context_len=16, - attention_type="mha", - ) + defaults = { + "dim": 32, + "n_layers": 1, + "n_heads": 4, + "n_kv_heads": 2, + "head_dim": 8, + "hidden_dim": 64, + "max_seq_len": 16, + "max_context_len": 16, + "attention_type": "mha", + } defaults.update(kwargs) return ModelArgs(**defaults) From 14a425a6afed907e637ea64af6beeeaa16e869be Mon Sep 17 00:00:00 2001 From: Phineas1500 <41450967+Phineas1500@users.noreply.github.com> Date: Thu, 5 Mar 2026 17:12:52 -0500 Subject: [PATCH 15/15] Fix Qwen3.5 README collection link --- examples/models/qwen3_5/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/models/qwen3_5/README.md b/examples/models/qwen3_5/README.md index e439f256f2a..34ee7134258 100644 --- a/examples/models/qwen3_5/README.md +++ b/examples/models/qwen3_5/README.md @@ -1,5 +1,5 @@ ## Summary -[Qwen3.5](https://huggingface.co/collections/Qwen/qwen35-684357f8543f83de2f09f998) support in ExecuTorch is exported through the Llama example pipeline with a hybrid layer layout: +[Qwen3.5](https://huggingface.co/collections/Qwen/qwen35) support in ExecuTorch is exported through the Llama example pipeline with a hybrid layer layout: - `full_attention` layers use gated full attention. - `linear_attention` layers use Gated DeltaNet with internal recurrent state.