# examples/models/llama/__init__.py
#
# Lazily expose ``Llama2Model`` (PEP 562 module-level ``__getattr__``) so that
# importing this package does not eagerly pull in the heavy model module.

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Static type checkers still resolve the real symbol.
    from .model import Llama2Model

__all__ = ["Llama2Model"]


def __getattr__(name):
    """Resolve ``Llama2Model`` on first attribute access; reject anything else."""
    if name != "Llama2Model":
        raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
    from .model import Llama2Model

    return Llama2Model
@register_attention("mha") class AttentionMHA(Attention): def __init__( @@ -347,18 +348,28 @@ def __init__( self.attention_qkv_bias = args.attention_qkv_bias self.use_qk_norm = args.use_qk_norm self.qk_norm_before_rope = args.qk_norm_before_rope + self.use_q_gate = args.use_q_gate self.enable_dynamic_shape = args.enable_dynamic_shape + q_out_dim = self.n_heads * self.head_dim * (2 if self.use_q_gate else 1) if self.use_qk_norm: q_norm_dim = self.head_dim k_norm_dim = self.head_dim - self.q_norm_fn = RMSNorm(q_norm_dim, eps=args.norm_eps) - self.k_norm_fn = RMSNorm(k_norm_dim, eps=args.norm_eps) + self.q_norm_fn = RMSNorm( + q_norm_dim, + eps=args.norm_eps, + add_unit_offset=args.rms_norm_add_unit_offset, + ) + self.k_norm_fn = RMSNorm( + k_norm_dim, + eps=args.norm_eps, + add_unit_offset=args.rms_norm_add_unit_offset, + ) self.wq = ( LoRALinear( in_dim=args.dim, - out_dim=args.n_heads * args.head_dim, + out_dim=q_out_dim, rank=args.r, alpha=args.lora_alpha, dropout=0.0, @@ -366,7 +377,7 @@ def __init__( ) if args.target_modules is not None and "q_proj" in args.target_modules else nn.Linear( - self.dim, self.n_heads * self.head_dim, bias=self.attention_qkv_bias + self.dim, q_out_dim, bias=self.attention_qkv_bias ) ) self.wk = ( @@ -452,10 +463,17 @@ def forward( input_pos = kwargs.get("input_pos") bsz, seqlen, _ = x.shape - # QKV - q, k, v = self.wq(x), self.wk(x), self.wv(x) - # We need view_copy elimination - q = q.view(bsz, seqlen, self.n_local_heads, self.head_dim) + if self.use_q_gate: + q_and_gate = self.wq(x).view( + bsz, seqlen, self.n_local_heads, self.head_dim * 2 + ) + q, gate = torch.chunk(q_and_gate, 2, dim=-1) + gate = gate.reshape(bsz, seqlen, -1) + else: + q = self.wq(x).view(bsz, seqlen, self.n_local_heads, self.head_dim) + gate = None + + k, v = self.wk(x), self.wv(x) k = k.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) v = v.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) @@ -492,6 +510,8 @@ def forward( input_pos[0].item(), 
def _l2norm(x: torch.Tensor, dim: int = -1, eps: float = 1e-6) -> torch.Tensor:
    # Unit-normalize along `dim` via rsqrt of the squared sum; `eps` keeps the
    # inverse norm finite for all-zero vectors.
    inv_norm = torch.rsqrt((x * x).sum(dim=dim, keepdim=True) + eps)
    return x * inv_norm


@register_attention("gated_deltanet")
class AttentionGatedDeltaNet(Attention):
    """Qwen3.5 linear-attention (Gated DeltaNet) block with internal state.

    Unlike softmax attention, this layer carries its context in two mutable
    buffers that persist across forward calls:

    - ``conv_state``: the trailing ``conv_kernel_size`` inputs of the causal
      depthwise convolution, shape (max_batch_size, conv_dim, kernel_size).
    - ``recurrent_state``: per-head delta-rule memory, shape
      (max_batch_size, num_v_heads, head_k_dim, head_v_dim).

    Both buffers are zeroed whenever ``input_pos`` starts at 0 (new sequence).
    NOTE(review): when ``input_pos`` is None the state is never reset, so the
    layer assumes kv-cache-style positional stepping — confirm for non-cache use.
    """

    def __init__(
        self,
        args: ModelArgs,
        layer_id: int,
        rope: Rope,
        **_kwargs: Any,
    ):
        super().__init__()
        del rope  # DeltaNet layers do not use RoPE.

        self.hidden_size = args.dim
        self.max_batch_size = args.max_batch_size
        self.layer_id = layer_id

        # These four are Optional on ModelArgs; a DeltaNet layer cannot be
        # constructed without them.
        assert args.linear_num_key_heads is not None
        assert args.linear_num_value_heads is not None
        assert args.linear_key_head_dim is not None
        assert args.linear_value_head_dim is not None

        self.num_k_heads = args.linear_num_key_heads
        self.num_v_heads = args.linear_num_value_heads
        self.head_k_dim = args.linear_key_head_dim
        self.head_v_dim = args.linear_value_head_dim
        self.key_dim = self.head_k_dim * self.num_k_heads
        self.value_dim = self.head_v_dim * self.num_v_heads
        self.conv_kernel_size = args.linear_conv_kernel_dim

        assert (
            self.num_v_heads % self.num_k_heads == 0
        ), "linear_num_value_heads must be divisible by linear_num_key_heads."
        # k/q heads are repeat_interleave'd by this factor to match v heads.
        self.head_repeat = self.num_v_heads // self.num_k_heads

        # q and k each occupy key_dim channels, v occupies value_dim; all three
        # share one fused projection and one depthwise causal conv.
        self.conv_dim = self.key_dim * 2 + self.value_dim
        self.in_proj_qkv = nn.Linear(self.hidden_size, self.conv_dim, bias=False)
        self.in_proj_z = nn.Linear(self.hidden_size, self.value_dim, bias=False)
        self.in_proj_b = nn.Linear(self.hidden_size, self.num_v_heads, bias=False)
        self.in_proj_a = nn.Linear(self.hidden_size, self.num_v_heads, bias=False)

        # Depthwise (groups == channels) causal conv; padding handled manually
        # in _apply_causal_conv via conv_state. bias=False, so conv1d.bias is
        # None where it is passed to F.conv1d below.
        self.conv1d = nn.Conv1d(
            in_channels=self.conv_dim,
            out_channels=self.conv_dim,
            kernel_size=self.conv_kernel_size,
            groups=self.conv_dim,
            bias=False,
            padding=0,
        )

        self.dt_bias = nn.Parameter(torch.ones(self.num_v_heads))
        # A is sampled uniform(0, 16) and stored in log space; its negated exp
        # scales the per-step decay gate g in forward().
        A = torch.empty(self.num_v_heads).uniform_(0, 16)
        self.A_log = nn.Parameter(torch.log(A))
        self.norm = RMSNormGated(self.head_v_dim, eps=args.norm_eps)
        self.out_proj = nn.Linear(self.value_dim, self.hidden_size, bias=False)

        # Mutable buffers — these names are matched by the export-time
        # InitializedMutableBufferPass patterns ("conv_state", "recurrent_state").
        self.register_buffer(
            "conv_state",
            torch.zeros(
                self.max_batch_size,
                self.conv_dim,
                self.conv_kernel_size,
                dtype=torch.float32,
                device="cpu",
            ),
        )
        self.register_buffer(
            "recurrent_state",
            torch.zeros(
                self.max_batch_size,
                self.num_v_heads,
                self.head_k_dim,
                self.head_v_dim,
                dtype=torch.float32,
                device="cpu",
            ),
        )

    def _maybe_reset_state(
        self, input_pos: Optional[torch.Tensor], batch_size: int
    ) -> None:
        """Zero both state buffers when a new sequence begins (input_pos == 0).

        Implemented as multiply-by-{0,1} rather than a data-dependent branch,
        which keeps the graph export-friendly.
        """
        if input_pos is None:
            return
        reset = (input_pos[0] == 0).to(self.conv_state.dtype)
        keep = 1.0 - reset
        self.conv_state[:batch_size].mul_(keep)
        self.recurrent_state[:batch_size].mul_(keep)

    def _apply_causal_conv(self, mixed_qkv: torch.Tensor) -> torch.Tensor:
        # mixed_qkv: (batch, seq_len, conv_dim)
        batch_size, seq_len, _ = mixed_qkv.shape
        mixed_qkv = mixed_qkv.transpose(1, 2)
        state_len = self.conv_state.shape[-1]
        # Prepend the cached trailing inputs so every output position sees a
        # full causal window; then persist the new trailing window.
        hidden_states_new = torch.cat([self.conv_state[:batch_size], mixed_qkv], dim=-1)
        new_conv_state = hidden_states_new[:, :, -state_len:]
        with torch.no_grad():
            self.conv_state[:batch_size].copy_(new_conv_state.to(self.conv_state.dtype))
        out = F.conv1d(
            hidden_states_new,
            self.conv1d.weight,
            self.conv1d.bias,
            padding=0,
            groups=self.conv_dim,
        )
        # Keep only the last seq_len outputs: output i corresponds to the
        # causal window ending at input token i.
        out = F.silu(out[:, :, -seq_len:]).to(mixed_qkv.dtype)
        return out.transpose(1, 2).contiguous()

    def _recurrent_gated_delta_rule(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        g: torch.Tensor,
        beta: torch.Tensor,
    ) -> torch.Tensor:
        """Sequentially apply the gated delta rule and persist the final state.

        query/key/value: (batch, seq_len, num_heads, head_dim)
        g/beta: (batch, seq_len, num_heads); g is a log-space decay, beta a
        per-head write strength in (0, 1).
        Returns (batch, seq_len, num_heads, head_v_dim) in the input dtype.
        """
        # query/key/value: (batch, seq_len, num_heads, head_dim)
        # g/beta: (batch, seq_len, num_heads)
        initial_dtype = query.dtype
        query = _l2norm(query, dim=-1, eps=1e-6)
        key = _l2norm(key, dim=-1, eps=1e-6)
        # Move heads before time and run the recurrence in float32 for
        # numerical stability.
        query, key, value, beta, g = [
            x.transpose(1, 2).contiguous().to(torch.float32)
            for x in (query, key, value, beta, g)
        ]

        batch_size, num_heads, sequence_length, k_head_dim = key.shape
        v_head_dim = value.shape[-1]
        scale = 1.0 / (query.shape[-1] ** 0.5)
        query = query * scale

        core_attn_out = torch.zeros(
            batch_size,
            num_heads,
            sequence_length,
            v_head_dim,
            device=value.device,
            dtype=value.dtype,
        )
        last_recurrent_state = self.recurrent_state[:batch_size].to(value.dtype)

        for i in range(sequence_length):
            q_t = query[:, :, i]
            k_t = key[:, :, i]
            v_t = value[:, :, i]
            g_t = g[:, :, i].exp().unsqueeze(-1).unsqueeze(-1)
            beta_t = beta[:, :, i].unsqueeze(-1)

            # 1) decay the memory, 2) read what it currently stores for k_t,
            # 3) write the beta-scaled prediction error back (delta rule),
            # 4) read the output for q_t.
            last_recurrent_state = last_recurrent_state * g_t
            kv_mem = (last_recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2)
            delta = (v_t - kv_mem) * beta_t
            last_recurrent_state = last_recurrent_state + k_t.unsqueeze(
                -1
            ) * delta.unsqueeze(-2)
            core_attn_out[:, :, i] = (last_recurrent_state * q_t.unsqueeze(-1)).sum(
                dim=-2
            )

        with torch.no_grad():
            self.recurrent_state[:batch_size].copy_(
                last_recurrent_state.to(self.recurrent_state.dtype)
            )

        return core_attn_out.transpose(1, 2).contiguous().to(initial_dtype)

    def forward(
        self,
        x: torch.Tensor,
        freqs_cos: torch.Tensor,
        freqs_sin: torch.Tensor,
        **kwargs: ForwardOptions,
    ) -> Tuple[torch.Tensor, Optional[Any]]:
        """Run the DeltaNet layer on x: (batch, seq_len, dim) -> same shape.

        RoPE frequencies are accepted for interface parity with softmax
        attention but ignored. Returns (output, None) like other Attention
        implementations.
        """
        del freqs_cos
        del freqs_sin
        input_pos = kwargs.get("input_pos")
        batch_size, seq_len, _ = x.shape
        assert (
            batch_size <= self.max_batch_size
        ), f"batch_size ({batch_size}) exceeds max_batch_size ({self.max_batch_size})"

        self._maybe_reset_state(input_pos, batch_size)

        mixed_qkv = self.in_proj_qkv(x)
        # z is the SiLU gate fed to RMSNormGated below.
        z = self.in_proj_z(x).reshape(batch_size, seq_len, -1, self.head_v_dim)
        b = self.in_proj_b(x)
        a = self.in_proj_a(x)

        mixed_qkv = self._apply_causal_conv(mixed_qkv)
        query, key, value = torch.split(
            mixed_qkv,
            [self.key_dim, self.key_dim, self.value_dim],
            dim=-1,
        )
        query = query.reshape(batch_size, seq_len, -1, self.head_k_dim)
        key = key.reshape(batch_size, seq_len, -1, self.head_k_dim)
        value = value.reshape(batch_size, seq_len, -1, self.head_v_dim)

        # Expand k/q heads to match the (larger) number of value heads.
        if self.head_repeat > 1:
            query = query.repeat_interleave(self.head_repeat, dim=2)
            key = key.repeat_interleave(self.head_repeat, dim=2)

        # beta in (0,1); g <= 0 in log space so exp(g) in (0,1] is a decay.
        beta = b.sigmoid()
        g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
        core_attn_out = self._recurrent_gated_delta_rule(query, key, value, g, beta)

        # Gated RMSNorm is applied per head-vector, hence the flatten/unflatten.
        core_attn_out = core_attn_out.reshape(-1, self.head_v_dim)
        z = z.reshape(-1, self.head_v_dim)
        core_attn_out = self.norm(core_attn_out, z)
        core_attn_out = core_attn_out.reshape(batch_size, seq_len, -1)

        return self.out_proj(core_attn_out), None
"Qwen/Qwen3-0.6B", "qwen3_1_7b": "Qwen/Qwen3-1.7B", "qwen3_4b": "Qwen/Qwen3-4B", + "qwen3_5_0_8b": "Qwen/Qwen3.5-0.8B", + "qwen3_5_2b": "Qwen/Qwen3.5-2B", + "qwen3_5_4b": "Qwen/Qwen3.5-4B", "lfm2_350m": "LiquidAI/LFM2-350M", "lfm2_700m": "LiquidAI/LFM2-700M", "lfm2_1_2b": "LiquidAI/LFM2-1.2B", @@ -129,6 +135,27 @@ } +def _get_additional_export_passes(model_class: str) -> List[InitializedMutableBufferPass]: + patterns = [] + + if model_class in TORCHTUNE_DEFINED_MODELS: + patterns.append("kv_cache_pos") + + # Qwen3.5 uses internal mutable buffers for both the hybrid KV path and + # DeltaNet recurrent/conv states. + if model_class.startswith("qwen3_5"): + patterns.extend( + [ + "k_cache", + "v_cache", + "conv_state", + "recurrent_state", + ] + ) + + return [InitializedMutableBufferPass(patterns)] if patterns else [] + + def set_pkg_name(name: str) -> None: global pkg_name pkg_name = name @@ -643,6 +670,8 @@ def export_llama( repo_id = HUGGING_FACE_REPO_IDS[model_name] if model_name.startswith("qwen2_5"): from executorch.examples.models.qwen2_5 import convert_weights + elif model_name.startswith("qwen3_5"): + from executorch.examples.models.qwen3_5 import convert_weights elif model_name.startswith("qwen3"): from executorch.examples.models.qwen3 import convert_weights elif model_name == "phi_4_mini": @@ -1260,9 +1289,9 @@ def _export_llama_multimethod(llm_config: LlmConfig) -> LLMEdgeManager: "Each method requires separate model instantiation and export." 
) - additional_passes = [] - if llm_config.base.model_class.value in TORCHTUNE_DEFINED_MODELS: - additional_passes = [InitializedMutableBufferPass(["kv_cache_pos"])] + additional_passes = _get_additional_export_passes( + llm_config.base.model_class.value + ) # Build dict of exported programs method_to_program: Dict[str, ExportedProgram] = {} @@ -1333,9 +1362,9 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager: # noqa: C901 llm_config ) - additional_passes = [] - if llm_config.base.model_class.value in TORCHTUNE_DEFINED_MODELS: - additional_passes = [InitializedMutableBufferPass(["kv_cache_pos"])] + additional_passes = _get_additional_export_passes( + llm_config.base.model_class.value + ) # export_to_edge builder_manager = _prepare_for_llama_export(llm_config) diff --git a/examples/models/llama/llama_transformer.py b/examples/models/llama/llama_transformer.py index a4ca52aa608..c79f9cb6ce0 100644 --- a/examples/models/llama/llama_transformer.py +++ b/examples/models/llama/llama_transformer.py @@ -105,8 +105,16 @@ def __init__(self, args: ModelArgs, attention: Attention): if isinstance(self.attention, AttentionSkip): self.attention_norm = nn.Identity() else: - self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps) - self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps) + self.attention_norm = RMSNorm( + args.dim, + eps=args.norm_eps, + add_unit_offset=args.rms_norm_add_unit_offset, + ) + self.ffn_norm = RMSNorm( + args.dim, + eps=args.norm_eps, + add_unit_offset=args.rms_norm_add_unit_offset, + ) @classmethod def from_type(cls, layer_id, args, rope) -> "TransformerBlock": @@ -164,7 +172,11 @@ def __init__(self, params: ModelArgs, layers: nn.ModuleList, rope: Rope): ) self.layers = layers self.rope = rope - self.norm = RMSNorm(params.dim, eps=params.norm_eps) + self.norm = RMSNorm( + params.dim, + eps=params.norm_eps, + add_unit_offset=params.rms_norm_add_unit_offset, + ) self.output = ( nn.Linear(params.dim, params.vocab_size, bias=False) if 
self.apply_output @@ -279,6 +291,21 @@ def construct_transformer(model_args: ModelArgs) -> Transformer: attention = AttentionSkip() transformer_block = TransformerBlock(model_args, attention) layers.append(transformer_block) + elif ( + model_args.layer_types + and model_args.layer_types[layer_id] == "linear_attention" + ): + linear_cls = ATTENTION_REGISTRY.get("gated_deltanet") + if linear_cls is None: + raise ValueError( + "Unknown attention type: gated_deltanet. " + f"Available: {list(ATTENTION_REGISTRY.keys())}" + ) + attention = linear_cls( + model_args, layer_id, rope, **model_args.attention_kwargs + ) + transformer_block = TransformerBlock(model_args, attention) + layers.append(transformer_block) else: attention = cls( model_args, layer_id, rope, **model_args.attention_kwargs diff --git a/examples/models/llama/model_args.py b/examples/models/llama/model_args.py index c225a1ee9b2..05e9ea62a8a 100644 --- a/examples/models/llama/model_args.py +++ b/examples/models/llama/model_args.py @@ -44,6 +44,14 @@ class ModelArgs: vocab_size: int = 512 # Arbitrary value, should be defined later by tokenizer. hidden_dim: Optional[int] = None head_dim: Optional[int] = None # Optional customized head_dim + # Qwen3.5 linear-attention dimensions. + linear_conv_kernel_dim: int = 4 + linear_key_head_dim: Optional[int] = None + linear_value_head_dim: Optional[int] = None + linear_num_key_heads: Optional[int] = None + linear_num_value_heads: Optional[int] = None + # Qwen3.5 RMSNorm uses (1 + weight) scaling. 
+ rms_norm_add_unit_offset: bool = False multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2 ffn_dim_multiplier: Optional[float] = None model_architecture: str = ( @@ -64,6 +72,7 @@ class ModelArgs: num_experts: int = 8 # Number of experts num_activated_experts: int = 2 # Number of experts to activate attention_type: str = "mha" # Attention type, registered in attention.py + use_q_gate: bool = False # Use q-gated projection in attention (Qwen3.5 full attention) norm_type: str = "rmsnorm" # Normalization type, registered in norm.py act_fn: ActFn = dataclasses.field(default=ActFn.SILU) # Activation function type attention_qkv_bias: bool = False @@ -148,6 +157,10 @@ def __post_init__(self): if self.n_kv_heads is None: self.n_kv_heads = self.n_heads + # Backward compatibility: qwen3_5_full attention name implies q-gated MHA. + if self.attention_type == "qwen3_5_full": + self.use_q_gate = True + # rope_theta overrides rope_freq_base since it's the official name. if self.rope_theta is not None: self.rope_freq_base = self.rope_theta @@ -174,6 +187,15 @@ def find_multiple(n: int, k: int) -> int: if self.head_dim is None: self.head_dim = self.dim // self.n_heads + if self.linear_key_head_dim is None: + self.linear_key_head_dim = self.head_dim + if self.linear_value_head_dim is None: + self.linear_value_head_dim = self.head_dim + if self.linear_num_key_heads is None: + self.linear_num_key_heads = self.n_heads + if self.linear_num_value_heads is None: + self.linear_num_value_heads = self.n_heads + # Convert string act_fn to enum if needed if isinstance(self.act_fn, str): self.act_fn = ActFn.from_string(self.act_fn) diff --git a/examples/models/llama/norm.py b/examples/models/llama/norm.py index 3786e61cd05..6a8217b5a24 100644 --- a/examples/models/llama/norm.py +++ b/examples/models/llama/norm.py @@ -5,11 +5,12 @@ # LICENSE file in the root directory of this source tree. 
class RMSNormGated(nn.Module):
    """RMSNorm followed by a SiLU gate.

    The normalization and the gate are evaluated in float32 for stability,
    and the result is cast back to the input dtype.
    """

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor, gate: torch.Tensor) -> torch.Tensor:
        orig_dtype = hidden_states.dtype
        h32 = hidden_states.to(torch.float32)
        # Classic RMS normalization over the last dimension.
        inv_rms = torch.rsqrt(h32.pow(2).mean(-1, keepdim=True) + self.variance_epsilon)
        normed = self.weight * (h32 * inv_rms).to(orig_dtype)
        # SiLU-gated output, computed in float32 like the reference impl.
        gated = normed * F.silu(gate.to(torch.float32))
        return gated.to(orig_dtype)
2baa8f5cd14..95fe47147db 100644 --- a/examples/models/llama/runner/generation.py +++ b/examples/models/llama/runner/generation.py @@ -72,6 +72,7 @@ def __init__( self.max_seq_len = max_seq_len self.max_batch_size = max_batch_size self.use_kv_cache = use_kv_cache + self.enable_dynamic_shape = True self.tokenizer = get_tokenizer(tokenizer_path, tokenizer_config_path) self.device = device # For some models like qwen, mismatch is acceptable: https://github.com/QwenLM/Qwen2.5/issues/466#issuecomment-2146759706 @@ -88,6 +89,45 @@ def forward( ) -> torch.Tensor: pass + def _prefill_with_kv_cache( + self, + prompt_tokens: List[int], + pos_base: int, + ) -> torch.Tensor: + if not self.enable_dynamic_shape and len(prompt_tokens) > 1: + return self._sequential_kv_prefill(prompt_tokens, pos_base) + + try: + return self.forward( + tokens=torch.tensor( + [prompt_tokens], dtype=torch.long, device=self.device + ), + input_pos=torch.tensor([pos_base], dtype=torch.long, device=self.device), + ) + except RuntimeError: + # Some exported models use a static single-token shape for kv-cache mode. + # Fall back to sequential token prefill so multi-token prompts still work. 
+ if self.enable_dynamic_shape or len(prompt_tokens) <= 1: + raise + + return self._sequential_kv_prefill(prompt_tokens, pos_base) + + def _sequential_kv_prefill( + self, + prompt_tokens: List[int], + pos_base: int, + ) -> torch.Tensor: + logits = None + for offset, token in enumerate(prompt_tokens): + logits = self.forward( + tokens=torch.tensor([[token]], dtype=torch.long, device=self.device), + input_pos=torch.tensor( + [pos_base + offset], dtype=torch.long, device=self.device + ), + ) + assert logits is not None + return logits + def generate( # noqa: C901 self, prompt_tokens: List[int], @@ -99,14 +139,14 @@ def generate( # noqa: C901 ) -> List[int]: # Prefill prefill_start = time.time() - logits = self.forward( - tokens=torch.tensor([prompt_tokens], dtype=torch.long, device=self.device), - input_pos=( - torch.tensor([pos_base], dtype=torch.long, device=self.device) - if self.use_kv_cache - else None - ), - ) + if self.use_kv_cache: + logits = self._prefill_with_kv_cache(prompt_tokens, pos_base) + else: + logits = self.forward( + tokens=torch.tensor( + [prompt_tokens], dtype=torch.long, device=self.device + ), + ) prefill_time = time.time() - prefill_start current_token = next_token(logits, temperature, top_p) diff --git a/examples/models/llama/runner/native.py b/examples/models/llama/runner/native.py index 6d5d4730844..ffb19ab3c08 100644 --- a/examples/models/llama/runner/native.py +++ b/examples/models/llama/runner/native.py @@ -44,6 +44,13 @@ def __init__(self, args): vocab_size=params["vocab_size"], ) self.model = _load_for_executorch(args.pte) + try: + self.enable_dynamic_shape = bool( + self.model.run_method("enable_dynamic_shape")[0] + ) + except Exception: + # Keep default behavior when metadata method is unavailable. 
+ pass def forward( self, diff --git a/examples/models/llama/tests/BUCK b/examples/models/llama/tests/BUCK index 8f4dec2237b..e0dae1147aa 100644 --- a/examples/models/llama/tests/BUCK +++ b/examples/models/llama/tests/BUCK @@ -15,6 +15,17 @@ fbcode_target(_kind = python_unittest, ], ) +fbcode_target(_kind = python_unittest, + name = "test_qwen3_5_attention", + srcs = [ + "test_qwen3_5_attention.py", + ], + deps = [ + "//caffe2:torch", + "//executorch/examples/models/llama:llama_transformer", + ], +) + fbcode_target(_kind = python_unittest, name = "test_pre_quantization_transforms", srcs = [ @@ -104,3 +115,14 @@ fbcode_target(_kind = python_unittest, "//executorch/extension/pybindings:portable_lib", ], ) + +fbcode_target(_kind = python_unittest, + name = "test_generation_prefill", + srcs = [ + "test_generation_prefill.py", + ], + deps = [ + "//caffe2:torch", + "//executorch/examples/models/llama/runner:eager_runner_library", + ], +) diff --git a/examples/models/llama/tests/test_export_llama_lib.py b/examples/models/llama/tests/test_export_llama_lib.py index 243c186cccc..3d1b6e85a81 100644 --- a/examples/models/llama/tests/test_export_llama_lib.py +++ b/examples/models/llama/tests/test_export_llama_lib.py @@ -22,10 +22,12 @@ TOSAQuantizer = None from executorch.examples.models.llama.export_llama_lib import ( + _get_additional_export_passes, _export_llama, build_args_parser, get_quantizer_and_quant_params, ) +from executorch.exir.passes.init_mutable_pass import InitializedMutableBufferPass from executorch.extension.llm.export.config.llm_config import LlmConfig, Pt2eQuantize UNWANTED_OPS = [ @@ -35,6 +37,24 @@ class ExportLlamaLibTest(unittest.TestCase): + def test_qwen3_5_mutable_buffer_passes(self): + passes = _get_additional_export_passes("qwen3_5_0_8b") + self.assertEqual(len(passes), 1) + self.assertIsInstance(passes[0], InitializedMutableBufferPass) + self.assertEqual( + passes[0].patterns, + ["k_cache", "v_cache", "conv_state", "recurrent_state"], + ) + + def 
class _DummyTokenizer:
    # Minimal stand-in for the runner's tokenizer interface: fixed vocab size,
    # EOS id, and stop-token list; encode/decode are deterministic stubs.
    n_words = 100
    eos_id = 2
    stop_tokens = [2]

    def encode(self, _text, bos=False, eos=False):
        # Always return the same two-token encoding; bos/eos are ignored.
        del bos
        del eos
        return [10, 11]

    def decode_token(self, token_id):
        return str(token_id)


class _DummyRunner(LlamaRunner):
    # LlamaRunner subclass whose forward() records every call (tokens and
    # input_pos) instead of running a model, optionally raising on multi-token
    # inputs to simulate a static-shape export rejecting parallel prefill.
    def __init__(self, raise_on_parallel_prefill=False):
        self.calls = []
        self.raise_on_parallel_prefill = raise_on_parallel_prefill
        super().__init__(
            tokenizer_path="unused",
            tokenizer_config_path=None,
            max_seq_len=16,
            max_batch_size=1,
            use_kv_cache=True,
            vocab_size=100,
            device="cpu",
        )

    def forward(self, tokens: torch.Tensor, input_pos=None) -> torch.Tensor:
        # Clone arguments so later mutation by the caller cannot corrupt the
        # recorded call history.
        self.calls.append(
            (tokens.clone(), input_pos.clone() if input_pos is not None else None)
        )
        if self.raise_on_parallel_prefill and tokens.shape[1] > 1:
            raise RuntimeError("parallel prefill failure")
        return torch.zeros((1, 8), dtype=torch.float32)


class TestGenerationPrefill(unittest.TestCase):
    # get_tokenizer is patched in every test so LlamaRunner.__init__ does not
    # touch the filesystem.
    @patch(
        "executorch.examples.models.llama.runner.generation.get_tokenizer",
        return_value=_DummyTokenizer(),
    )
    def test_static_prefill_uses_sequential_tokens(self, _mock_get_tokenizer):
        # Static-shape models must receive one token per forward call, with
        # input_pos advancing from pos_base.
        runner = _DummyRunner()
        runner.enable_dynamic_shape = False

        runner._prefill_with_kv_cache([5, 6, 7], pos_base=3)

        self.assertEqual(len(runner.calls), 3)
        for i, (tokens, input_pos) in enumerate(runner.calls):
            self.assertEqual(tuple(tokens.shape), (1, 1))
            self.assertEqual(tokens.item(), 5 + i)
            self.assertEqual(input_pos.item(), 3 + i)

    @patch(
        "executorch.examples.models.llama.runner.generation.get_tokenizer",
        return_value=_DummyTokenizer(),
    )
    def test_dynamic_prefill_uses_batched_prompt(self, _mock_get_tokenizer):
        # Dynamic-shape models take the whole prompt in a single forward call.
        runner = _DummyRunner()
        runner.enable_dynamic_shape = True

        runner._prefill_with_kv_cache([5, 6, 7], pos_base=4)

        self.assertEqual(len(runner.calls), 1)
        tokens, input_pos = runner.calls[0]
        self.assertEqual(tuple(tokens.shape), (1, 3))
        self.assertEqual(input_pos.item(), 4)

    @patch(
        "executorch.examples.models.llama.runner.generation.get_tokenizer",
        return_value=_DummyTokenizer(),
    )
    def test_dynamic_prefill_does_not_mask_runtime_errors(self, _mock_get_tokenizer):
        # When dynamic shapes are enabled, a RuntimeError from batched prefill
        # must propagate (no silent sequential fallback).
        runner = _DummyRunner(raise_on_parallel_prefill=True)
        runner.enable_dynamic_shape = True

        with self.assertRaisesRegex(RuntimeError, "parallel prefill failure"):
            runner._prefill_with_kv_cache([5, 6], pos_base=0)

        self.assertEqual(len(runner.calls), 1)


if __name__ == "__main__":
    unittest.main()
class Qwen35AttentionTest(unittest.TestCase):
    """Unit tests for Qwen3.5 attention variants (gated MHA and DeltaNet)."""

    def test_qwen35_full_attention_output_proj_is_bias_free(self):
        # Even with attention_qkv_bias=True, the output projection must stay
        # bias-free.
        args = ModelArgs(
            dim=32,
            n_layers=1,
            n_heads=4,
            n_kv_heads=2,
            head_dim=8,
            hidden_dim=64,
            max_seq_len=16,
            max_context_len=16,
            use_kv_cache=False,
            use_qk_norm=False,
            qk_norm_before_rope=True,
            attention_type="mha",
            use_q_gate=True,
            attention_qkv_bias=True,
        )
        rope = Rope(args)
        attn = ATTENTION_REGISTRY["mha"](args, 0, rope)
        self.assertIsNone(attn.wo.bias)

    def test_rmsnorm_preserves_input_dtype_without_unit_offset(self):
        # RMSNorm must cast back to the input dtype (bf16 here) on the
        # default (no unit-offset) path.
        norm = RMSNorm(dim=8, add_unit_offset=False)
        x = torch.randn(2, 3, 8, dtype=torch.bfloat16)
        y = norm(x)
        self.assertEqual(y.dtype, x.dtype)

    def test_qwen35_full_attention_forward_shape(self):
        # Gated full attention with qk-norm, unit-offset RMSNorm, and partial
        # rotary embeddings must preserve the input shape.
        torch.manual_seed(0)
        args = ModelArgs(
            dim=32,
            n_layers=1,
            n_heads=4,
            n_kv_heads=2,
            head_dim=8,
            hidden_dim=64,
            max_seq_len=16,
            max_context_len=16,
            use_kv_cache=False,
            use_hf_rope=True,
            partial_rotary_factor=0.5,
            use_qk_norm=True,
            qk_norm_before_rope=True,
            attention_type="mha",
            use_q_gate=True,
            rms_norm_add_unit_offset=True,
        )
        rope = Rope(args)
        attn = ATTENTION_REGISTRY["mha"](args, 0, rope)
        x = torch.randn(1, 3, args.dim)
        freqs_cos, freqs_sin = rope.get_freqs(None, x.shape[1])
        y, _ = attn(x, freqs_cos, freqs_sin)
        self.assertEqual(y.shape, x.shape)

    def test_qwen35_full_attention_legacy_name_maps_to_gated_mha(self):
        # attention_type="qwen3_5_full" must imply use_q_gate (see
        # ModelArgs.__post_init__) and resolve to the gated MHA class.
        args = ModelArgs(
            dim=32,
            n_layers=1,
            n_heads=4,
            n_kv_heads=2,
            head_dim=8,
            hidden_dim=64,
            attention_type="qwen3_5_full",
        )
        self.assertTrue(args.use_q_gate)
        rope = Rope(args)
        attn = ATTENTION_REGISTRY["qwen3_5_full"](args, 0, rope)
        self.assertTrue(attn.use_q_gate)

    def test_gated_deltanet_resets_state_on_new_sequence(self):
        # DeltaNet keeps recurrent state across calls; input_pos==0 must reset
        # it so a replayed first token reproduces the original state.
        torch.manual_seed(0)
        args = ModelArgs(
            dim=32,
            n_layers=1,
            n_heads=4,
            n_kv_heads=2,
            head_dim=8,
            hidden_dim=64,
            max_seq_len=16,
            max_context_len=16,
            use_kv_cache=True,
            attention_type="mha",
            use_q_gate=True,
            linear_conv_kernel_dim=4,
            linear_key_head_dim=4,
            linear_value_head_dim=4,
            linear_num_key_heads=2,
            linear_num_value_heads=4,
        )
        rope = Rope(args)
        attn = ATTENTION_REGISTRY["gated_deltanet"](args, 0, rope)

        x = torch.randn(1, 1, args.dim)
        # RoPE frequencies are ignored by DeltaNet; zeros suffice.
        dummy_freq = torch.zeros(1, 1)

        # First token of sequence.
        attn(x, dummy_freq, dummy_freq, input_pos=torch.tensor([0], dtype=torch.long))
        state_after_first = attn.recurrent_state.clone()

        # Decode continuation updates state.
        attn(x, dummy_freq, dummy_freq, input_pos=torch.tensor([1], dtype=torch.long))
        state_after_second = attn.recurrent_state.clone()
        self.assertFalse(torch.allclose(state_after_first, state_after_second))

        # New sequence (input_pos=0) should reset internal state.
        attn(x, dummy_freq, dummy_freq, input_pos=torch.tensor([0], dtype=torch.long))
        state_after_reset = attn.recurrent_state.clone()
        self.assertTrue(torch.allclose(state_after_first, state_after_reset, atol=1e-5))


if __name__ == "__main__":
    unittest.main()
+ +load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") + +oncall("executorch") + +fbcode_target(_kind = runtime.python_library, + name = "qwen3_5", + srcs = [ + "__init__.py", + "convert_weights.py", + ], + base_module = "executorch.examples.models.qwen3_5", + visibility = ["PUBLIC"], + deps = [ + "//caffe2:torch", + "//executorch/examples/models/llama:llama2_model", + "//executorch/examples/models:checkpoint", + "fbsource//third-party/pypi/safetensors:safetensors", + ], +) diff --git a/examples/models/qwen3_5/README.md b/examples/models/qwen3_5/README.md new file mode 100644 index 00000000000..e439f256f2a --- /dev/null +++ b/examples/models/qwen3_5/README.md @@ -0,0 +1,59 @@ +## Summary +[Qwen3.5](https://huggingface.co/collections/Qwen/qwen35-684357f8543f83de2f09f998) support in ExecuTorch is exported through the Llama example pipeline with a hybrid layer layout: +- `full_attention` layers use gated full attention. +- `linear_attention` layers use Gated DeltaNet with internal recurrent state. + +This first bring-up is **fp32 + static shape** (`enable_dynamic_shape=False`). +Currently supported text model sizes: `0.8B`, `2B`, `4B`. 
+ +## Export +```bash +python -m extension.llm.export.export_llm \ + --config examples/models/qwen3_5/config/qwen3_5_xnnpack_fp32.yaml \ + +base.model_class="qwen3_5_0_8b" \ + +base.params="examples/models/qwen3_5/config/0_8b_config.json" \ + +export.output_name="qwen3_5_0_8b_fp32.pte" +``` + +```bash +python -m extension.llm.export.export_llm \ + --config examples/models/qwen3_5/config/qwen3_5_xnnpack_fp32.yaml \ + +base.model_class="qwen3_5_2b" \ + +base.params="examples/models/qwen3_5/config/2b_config.json" \ + +export.output_name="qwen3_5_2b_fp32.pte" +``` + +```bash +python -m extension.llm.export.export_llm \ + --config examples/models/qwen3_5/config/qwen3_5_xnnpack_fp32.yaml \ + +base.model_class="qwen3_5_4b" \ + +base.params="examples/models/qwen3_5/config/4b_config.json" \ + +export.output_name="qwen3_5_4b_fp32.pte" +``` + +The exporter will download and convert HF weights automatically when `+base.checkpoint` is not provided. +Install `safetensors` in your environment if it is missing: +```bash +python -m pip install safetensors +``` + +## Run (Python Runner) +```bash +python -m executorch.examples.models.llama.runner.native \ + --model qwen3_5_0_8b \ + --pte qwen3_5_0_8b_fp32.pte \ + --tokenizer /path/to/tokenizer.json \ + --tokenizer_config /path/to/tokenizer_config.json \ + --prompt "<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n" \ + --params examples/models/qwen3_5/config/0_8b_config.json \ + --max_len 128 \ + -kv \ + --temperature 0.3 +``` + +## Notes +- Current path targets CPU/XNNPACK export validation and runner compatibility. +- `q8da4w` quantization for Qwen3.5 is intentionally deferred to a follow-up. +- Dynamic-shape export is not enabled yet for Qwen3.5 DeltaNet layers in this path; keep `enable_dynamic_shape=False`. +- For static-shape exports, `runner.native` falls back to sequential token prefill for multi-token prompts. +- Default metadata uses Qwen3.5 special token ids: `get_bos_id=248045`, `get_eos_ids=[248046,248044]`. 
diff --git a/examples/models/qwen3_5/__init__.py b/examples/models/qwen3_5/__init__.py new file mode 100644 index 00000000000..b336832bb4b --- /dev/null +++ b/examples/models/qwen3_5/__init__.py @@ -0,0 +1,19 @@ +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from executorch.examples.models.qwen3_5.convert_weights import convert_weights + +__all__ = ["Qwen3_5Model", "convert_weights"] + + +def __getattr__(name): + if name == "Qwen3_5Model": + from executorch.examples.models.llama.model import Llama2Model + + class Qwen3_5Model(Llama2Model): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + globals()["Qwen3_5Model"] = Qwen3_5Model + return Qwen3_5Model + raise AttributeError(f"module '{__name__}' has no attribute '{name}'") diff --git a/examples/models/qwen3_5/config/0_8b_config.json b/examples/models/qwen3_5/config/0_8b_config.json new file mode 100644 index 00000000000..89799cf9f53 --- /dev/null +++ b/examples/models/qwen3_5/config/0_8b_config.json @@ -0,0 +1,51 @@ +{ + "dim": 1024, + "hidden_dim": 3584, + "n_heads": 8, + "head_dim": 256, + "n_kv_heads": 2, + "n_layers": 24, + "norm_eps": 1e-6, + "rope_theta": 10000000.0, + "use_scaled_rope": false, + "vocab_size": 248320, + "use_hf_rope": true, + "partial_rotary_factor": 0.25, + "attention_qkv_bias": false, + "use_qk_norm": true, + "qk_norm_before_rope": true, + "attention_type": "mha", + "use_q_gate": true, + "rms_norm_add_unit_offset": true, + "linear_conv_kernel_dim": 4, + "linear_key_head_dim": 128, + "linear_value_head_dim": 128, + "linear_num_key_heads": 16, + "linear_num_value_heads": 16, + "layer_types": [ + "linear_attention", + "linear_attention", + "linear_attention", + "full_attention", + "linear_attention", + "linear_attention", + "linear_attention", + "full_attention", + "linear_attention", + "linear_attention", + "linear_attention", + "full_attention", + "linear_attention", + "linear_attention", + 
"linear_attention", + "full_attention", + "linear_attention", + "linear_attention", + "linear_attention", + "full_attention", + "linear_attention", + "linear_attention", + "linear_attention", + "full_attention" + ] +} diff --git a/examples/models/qwen3_5/config/2b_config.json b/examples/models/qwen3_5/config/2b_config.json new file mode 100644 index 00000000000..15f62ccca74 --- /dev/null +++ b/examples/models/qwen3_5/config/2b_config.json @@ -0,0 +1,51 @@ +{ + "dim": 2048, + "hidden_dim": 6144, + "n_heads": 8, + "head_dim": 256, + "n_kv_heads": 2, + "n_layers": 24, + "norm_eps": 1e-6, + "rope_theta": 10000000.0, + "use_scaled_rope": false, + "vocab_size": 248320, + "use_hf_rope": true, + "partial_rotary_factor": 0.25, + "attention_qkv_bias": false, + "use_qk_norm": true, + "qk_norm_before_rope": true, + "attention_type": "mha", + "use_q_gate": true, + "rms_norm_add_unit_offset": true, + "linear_conv_kernel_dim": 4, + "linear_key_head_dim": 128, + "linear_value_head_dim": 128, + "linear_num_key_heads": 16, + "linear_num_value_heads": 16, + "layer_types": [ + "linear_attention", + "linear_attention", + "linear_attention", + "full_attention", + "linear_attention", + "linear_attention", + "linear_attention", + "full_attention", + "linear_attention", + "linear_attention", + "linear_attention", + "full_attention", + "linear_attention", + "linear_attention", + "linear_attention", + "full_attention", + "linear_attention", + "linear_attention", + "linear_attention", + "full_attention", + "linear_attention", + "linear_attention", + "linear_attention", + "full_attention" + ] +} diff --git a/examples/models/qwen3_5/config/4b_config.json b/examples/models/qwen3_5/config/4b_config.json new file mode 100644 index 00000000000..068653a8d5c --- /dev/null +++ b/examples/models/qwen3_5/config/4b_config.json @@ -0,0 +1,59 @@ +{ + "dim": 2560, + "hidden_dim": 9216, + "n_heads": 16, + "head_dim": 256, + "n_kv_heads": 4, + "n_layers": 32, + "norm_eps": 1e-6, + "rope_theta": 10000000.0, + 
"use_scaled_rope": false, + "vocab_size": 248320, + "use_hf_rope": true, + "partial_rotary_factor": 0.25, + "attention_qkv_bias": false, + "use_qk_norm": true, + "qk_norm_before_rope": true, + "attention_type": "mha", + "use_q_gate": true, + "rms_norm_add_unit_offset": true, + "linear_conv_kernel_dim": 4, + "linear_key_head_dim": 128, + "linear_value_head_dim": 128, + "linear_num_key_heads": 16, + "linear_num_value_heads": 32, + "layer_types": [ + "linear_attention", + "linear_attention", + "linear_attention", + "full_attention", + "linear_attention", + "linear_attention", + "linear_attention", + "full_attention", + "linear_attention", + "linear_attention", + "linear_attention", + "full_attention", + "linear_attention", + "linear_attention", + "linear_attention", + "full_attention", + "linear_attention", + "linear_attention", + "linear_attention", + "full_attention", + "linear_attention", + "linear_attention", + "linear_attention", + "full_attention", + "linear_attention", + "linear_attention", + "linear_attention", + "full_attention", + "linear_attention", + "linear_attention", + "linear_attention", + "full_attention" + ] +} diff --git a/examples/models/qwen3_5/config/qwen3_5_xnnpack_fp32.yaml b/examples/models/qwen3_5/config/qwen3_5_xnnpack_fp32.yaml new file mode 100644 index 00000000000..93882f67a6e --- /dev/null +++ b/examples/models/qwen3_5/config/qwen3_5_xnnpack_fp32.yaml @@ -0,0 +1,21 @@ +base: + # Qwen3.5 tokenizer special ids: + # <|im_start|> 248045 + # <|im_end|> 248046 + # <|endoftext|> 248044 + metadata: '{"get_bos_id": 248045, "get_eos_ids":[248046,248044]}' + +model: + use_kv_cache: True + use_sdpa_with_kv_cache: False + enable_dynamic_shape: False + dtype_override: fp32 + +export: + max_seq_length: 2048 + max_context_length: 2048 + +backend: + xnnpack: + enabled: True + extended_ops: True diff --git a/examples/models/qwen3_5/convert_weights.py b/examples/models/qwen3_5/convert_weights.py new file mode 100644 index 00000000000..b0a8fc47305 --- 
/dev/null +++ b/examples/models/qwen3_5/convert_weights.py @@ -0,0 +1,230 @@ +import argparse +import json +import os +import re +from typing import Dict + +import torch +from executorch.examples.models.checkpoint import ( + get_mapped_key, + load_checkpoint_from_pytorch_model, +) + +_QWEN_3_5_TO_META = { + "model.embed_tokens.weight": "tok_embeddings.weight", + "model.norm.weight": "norm.weight", + "lm_head.weight": "output.weight", + "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight", + "model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight", + "model.layers.{}.mlp.gate_proj.weight": "layers.{}.feed_forward.w1.weight", + "model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight", + "model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight", + # Full-attention layers. + "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight", + "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight", + "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight", + "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight", + "model.layers.{}.self_attn.q_norm.weight": "layers.{}.attention.q_norm_fn.weight", + "model.layers.{}.self_attn.k_norm.weight": "layers.{}.attention.k_norm_fn.weight", + # Linear-attention layers. 
+ "model.layers.{}.linear_attn.in_proj_qkv.weight": "layers.{}.attention.in_proj_qkv.weight", + "model.layers.{}.linear_attn.in_proj_z.weight": "layers.{}.attention.in_proj_z.weight", + "model.layers.{}.linear_attn.in_proj_b.weight": "layers.{}.attention.in_proj_b.weight", + "model.layers.{}.linear_attn.in_proj_a.weight": "layers.{}.attention.in_proj_a.weight", + "model.layers.{}.linear_attn.conv1d.weight": "layers.{}.attention.conv1d.weight", + "model.layers.{}.linear_attn.conv1d.bias": "layers.{}.attention.conv1d.bias", + "model.layers.{}.linear_attn.dt_bias": "layers.{}.attention.dt_bias", + "model.layers.{}.linear_attn.A_log": "layers.{}.attention.A_log", + "model.layers.{}.linear_attn.norm.weight": "layers.{}.attention.norm.weight", + "model.layers.{}.linear_attn.out_proj.weight": "layers.{}.attention.out_proj.weight", +} + + +_IGNORED_UNMAPPED_PREFIXES = ( + "mtp.", + "model.visual.", + "visual.", +) + +_IGNORED_UNMAPPED_SUBSTRINGS = ( + ".vision_", + ".visual.", +) + +_IGNORED_UNMAPPED_SUFFIXES = ( + "rotary_emb.inv_freq", +) + + +def _should_ignore_unmapped_key(key: str, normalized_key: str) -> bool: + candidates = (key, normalized_key) + for candidate in candidates: + if any(candidate.startswith(prefix) for prefix in _IGNORED_UNMAPPED_PREFIXES): + return True + if any(token in candidate for token in _IGNORED_UNMAPPED_SUBSTRINGS): + return True + if any(candidate.endswith(suffix) for suffix in _IGNORED_UNMAPPED_SUFFIXES): + return True + return False + + +def _load_checkpoint_from_safetensors(input_dir: str) -> Dict: + from safetensors.torch import load_file + + index_path = os.path.join(input_dir, "model.safetensors.index.json") + if os.path.exists(index_path): + with open(index_path, "r") as f: + index = json.load(f) + weight_map = index["weight_map"] + checkpoint_shards = sorted(set(weight_map.values())) + + shard_to_weights = {} + for shard in checkpoint_shards: + shard_to_weights[shard] = load_file(os.path.join(input_dir, shard)) + + merged_state_dict 
= {} + for weight_name, shard in weight_map.items(): + merged_state_dict[weight_name] = shard_to_weights[shard][weight_name] + return merged_state_dict + + model_path = os.path.join(input_dir, "model.safetensors") + if os.path.exists(model_path): + return load_file(model_path) + + raise FileNotFoundError(f"Could not find safetensors checkpoint in {input_dir}") + + +def load_checkpoint(input_dir: str) -> Dict: + try: + print("Loading checkpoint from pytorch_model directory") + return load_checkpoint_from_pytorch_model(input_dir) + except FileNotFoundError: + print( + "Could not find pytorch_model checkpoints in directory, trying safetensors" + ) + + try: + print("Loading checkpoint from safetensors directory") + return _load_checkpoint_from_safetensors(input_dir) + except FileNotFoundError: + pass + + raise FileNotFoundError(f"Could not find checkpoint in {input_dir}") + + +def qwen_3_5_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + converted_state_dict = {} + pending_qkvz = {} + pending_ba = {} + + for key, value in state_dict.items(): + normalized_key = key + # HF multimodal Qwen3.5 checkpoints store text weights under + # `model.language_model.*`. Normalize to `model.*`. + if normalized_key.startswith("model.language_model."): + normalized_key = normalized_key.replace( + "model.language_model.", "model.", 1 + ) + + # Ignore non-language-model keys up front. 
+        if not (
+            normalized_key.startswith("model.") or normalized_key.startswith("lm_head.")
+        ):
+            if _should_ignore_unmapped_key(key, normalized_key):
+                continue
+            continue  # NOTE(review): dead code above — *all* non-text keys are dropped, not just ignorable ones; confirm unknown keys shouldn't raise like the model.* path
+
+        # Legacy packed tensors (older checkpoints):
+        # in_proj_qkvz -> split into in_proj_qkv and in_proj_z
+        # in_proj_ba -> split into in_proj_b and in_proj_a
+        if normalized_key.endswith(".linear_attn.in_proj_qkvz.weight"):
+            pending_qkvz[normalized_key] = value
+            continue
+        if normalized_key.endswith(".linear_attn.in_proj_ba.weight"):
+            pending_ba[normalized_key] = value
+            continue
+
+        try:
+            new_key = get_mapped_key(normalized_key, _QWEN_3_5_TO_META)
+        except Exception as err:
+            if _should_ignore_unmapped_key(key, normalized_key):
+                continue
+            raise ValueError(
+                f"Unexpected checkpoint key not mapped for Qwen3.5 export: {key}"
+            ) from err
+        converted_state_dict[new_key] = value
+
+    for key, value in pending_qkvz.items():
+        layer_match = re.search(r"model\.layers\.(\d+)\.", key)
+        if layer_match is None:
+            raise ValueError(f"Failed to parse layer id from key: {key}")
+        layer_id = layer_match.group(1)
+        out_proj_key = f"layers.{layer_id}.attention.out_proj.weight"
+        if out_proj_key not in converted_state_dict:
+            raise ValueError(
+                f"Cannot split {key}: missing {out_proj_key} to infer value dimension."
+ ) + + value_dim = converted_state_dict[out_proj_key].shape[1] + total_dim = value.shape[0] + conv_dim = total_dim - value_dim + if conv_dim <= 0 or (conv_dim - value_dim) % 2 != 0: + raise ValueError( + f"Invalid packed in_proj_qkvz shape for {key}: {tuple(value.shape)}" + ) + qkv, z = torch.split(value, [conv_dim, value_dim], dim=0) + converted_state_dict[f"layers.{layer_id}.attention.in_proj_qkv.weight"] = qkv + converted_state_dict[f"layers.{layer_id}.attention.in_proj_z.weight"] = z + print(f"Split legacy packed key {key} -> in_proj_qkv + in_proj_z") + + for key, value in pending_ba.items(): + layer_match = re.search(r"model\.layers\.(\d+)\.", key) + if layer_match is None: + raise ValueError(f"Failed to parse layer id from key: {key}") + layer_id = layer_match.group(1) + if value.shape[0] % 2 != 0: + raise ValueError( + f"Invalid packed in_proj_ba shape for {key}: {tuple(value.shape)}" + ) + half = value.shape[0] // 2 + b, a = torch.split(value, [half, half], dim=0) + converted_state_dict[f"layers.{layer_id}.attention.in_proj_b.weight"] = b + converted_state_dict[f"layers.{layer_id}.attention.in_proj_a.weight"] = a + print(f"Split legacy packed key {key} -> in_proj_b + in_proj_a") + + # Handle tied embeddings. + if "lm_head.weight" not in state_dict: + converted_state_dict["output.weight"] = converted_state_dict[ + "tok_embeddings.weight" + ] + + return converted_state_dict + + +def convert_weights(input_dir: str, output_file: str) -> None: + print("Loading checkpoint...") + state_dict = load_checkpoint(input_dir) + print("Converting checkpoint...") + state_dict = qwen_3_5_to_meta(state_dict) + print("Saving checkpoint...") + torch.save(state_dict, output_file) + print("Done.") + + +def main(): + parser = argparse.ArgumentParser( + description="Convert Qwen3.5 weights to ExecuTorch meta format." 
+ ) + parser.add_argument( + "input_dir", + type=str, + help="Path to directory containing safetensor or PyTorch checkpoint files.", + ) + parser.add_argument("output", type=str, help="Path to the output checkpoint") + + args = parser.parse_args() + convert_weights(args.input_dir, args.output) + + +if __name__ == "__main__": + main() diff --git a/examples/models/qwen3_5/tests/__init__.py b/examples/models/qwen3_5/tests/__init__.py new file mode 100644 index 00000000000..5e827e62ea1 --- /dev/null +++ b/examples/models/qwen3_5/tests/__init__.py @@ -0,0 +1,2 @@ +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/examples/models/qwen3_5/tests/test_convert_weights.py b/examples/models/qwen3_5/tests/test_convert_weights.py new file mode 100644 index 00000000000..7fa2fb667c7 --- /dev/null +++ b/examples/models/qwen3_5/tests/test_convert_weights.py @@ -0,0 +1,69 @@ +import unittest + +import torch +from executorch.examples.models.qwen3_5.convert_weights import qwen_3_5_to_meta + + +class Qwen35ConvertWeightsTest(unittest.TestCase): + def test_maps_full_and_linear_attention_weights(self): + state_dict = { + "model.embed_tokens.weight": torch.randn(16, 8), + "model.norm.weight": torch.randn(8), + "lm_head.weight": torch.randn(16, 8), + "model.layers.0.input_layernorm.weight": torch.randn(8), + "model.layers.0.post_attention_layernorm.weight": torch.randn(8), + "model.layers.0.mlp.gate_proj.weight": torch.randn(12, 8), + "model.layers.0.mlp.down_proj.weight": torch.randn(8, 12), + "model.layers.0.mlp.up_proj.weight": torch.randn(12, 8), + "model.layers.0.self_attn.q_proj.weight": torch.randn(16, 8), + "model.layers.0.self_attn.k_proj.weight": torch.randn(8, 8), + "model.layers.0.self_attn.v_proj.weight": torch.randn(8, 8), + "model.layers.0.self_attn.o_proj.weight": torch.randn(8, 8), + "model.layers.0.self_attn.q_norm.weight": torch.randn(4), + 
"model.layers.0.self_attn.k_norm.weight": torch.randn(4), + "model.layers.1.linear_attn.in_proj_qkv.weight": torch.randn(24, 8), + "model.layers.1.linear_attn.in_proj_z.weight": torch.randn(8, 8), + "model.layers.1.linear_attn.in_proj_b.weight": torch.randn(2, 8), + "model.layers.1.linear_attn.in_proj_a.weight": torch.randn(2, 8), + "model.layers.1.linear_attn.conv1d.weight": torch.randn(24, 1, 4), + "model.layers.1.linear_attn.dt_bias": torch.randn(2), + "model.layers.1.linear_attn.A_log": torch.randn(2), + "model.layers.1.linear_attn.norm.weight": torch.randn(4), + "model.layers.1.linear_attn.out_proj.weight": torch.randn(8, 8), + } + + converted = qwen_3_5_to_meta(state_dict) + self.assertIn("layers.0.attention.wq.weight", converted) + self.assertIn("layers.0.attention.q_norm_fn.weight", converted) + self.assertIn("layers.1.attention.in_proj_qkv.weight", converted) + self.assertIn("layers.1.attention.out_proj.weight", converted) + self.assertIn("layers.1.attention.dt_bias", converted) + + def test_raises_on_unexpected_text_key(self): + state_dict = { + "model.embed_tokens.weight": torch.randn(16, 8), + "model.norm.weight": torch.randn(8), + "model.layers.0.unknown.weight": torch.randn(8, 8), + } + + with self.assertRaisesRegex( + ValueError, "Unexpected checkpoint key not mapped for Qwen3.5 export" + ): + qwen_3_5_to_meta(state_dict) + + def test_ignores_known_non_text_keys(self): + state_dict = { + "model.embed_tokens.weight": torch.randn(16, 8), + "model.norm.weight": torch.randn(8), + "mtp.proj.weight": torch.randn(8, 8), + "model.visual.patch_embed.weight": torch.randn(8, 8), + } + + converted = qwen_3_5_to_meta(state_dict) + self.assertIn("tok_embeddings.weight", converted) + self.assertIn("output.weight", converted) + self.assertNotIn("mtp.proj.weight", converted) + + +if __name__ == "__main__": + unittest.main() diff --git a/extension/llm/export/config/llm_config.py b/extension/llm/export/config/llm_config.py index eea99ea9b2b..06df3aa38be 100644 --- 
a/extension/llm/export/config/llm_config.py +++ b/extension/llm/export/config/llm_config.py @@ -44,6 +44,9 @@ class ModelType(str, Enum): qwen3_0_6b = "qwen3_0_6b" qwen3_1_7b = "qwen3_1_7b" qwen3_4b = "qwen3_4b" + qwen3_5_0_8b = "qwen3_5_0_8b" + qwen3_5_2b = "qwen3_5_2b" + qwen3_5_4b = "qwen3_5_4b" phi_4_mini = "phi_4_mini" smollm2 = "smollm2" lfm2_350m = "lfm2_350m"