Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions examples/models/BUCK
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ fbcode_target(_kind = python_library,
"//executorch/examples/models/gemma3:gemma3", # @manual
"//executorch/examples/models/qwen2_5:qwen2_5", # @manual
"//executorch/examples/models/qwen3:qwen3", # @manual
"//executorch/examples/models/qwen3_5:qwen3_5", # @manual
"//executorch/examples/models/phi_4_mini:phi_4_mini", # @manual
"//executorch/examples/models/smollm2:smollm2", # @manual
"//executorch/examples/models/smollm3:smollm3", # @manual
Expand Down
17 changes: 13 additions & 4 deletions examples/models/llama/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,17 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Lazily expose Llama2Model so importing this package does not pull in the
# (heavy) model module until the symbol is actually accessed (PEP 562).
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Import only for static type checkers; never executed at runtime.
    from .model import Llama2Model

__all__ = ["Llama2Model"]


def __getattr__(name):
    """Module-level lazy attribute hook (PEP 562).

    Resolves ``Llama2Model`` on first access by importing it from
    ``.model``; any other name raises AttributeError, matching normal
    module attribute semantics.
    """
    if name == "Llama2Model":
        from .model import Llama2Model

        return Llama2Model
    raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
Comment on lines +7 to +20
Copy link

Copilot AI Mar 5, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This changes the llama __init__.py from a direct import (from .model import Llama2Model) to a lazy __getattr__ pattern. All other model packages that import Llama2Model (including qwen2_5, qwen3, phi_4_mini, smollm2, gemma, granite, etc.) still use a direct top-level import from executorch.examples.models.llama.model. If there's no circular import or startup performance issue driving this change, it adds unnecessary complexity and inconsistency. Consider keeping the direct import pattern consistent with the rest of the codebase.

Copilot uses AI. Check for mistakes.
261 changes: 250 additions & 11 deletions examples/models/llama/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import torch.nn.functional as F
from executorch.examples.models.llama.lora import LoRALinear
from executorch.examples.models.llama.model_args import ModelArgs
from executorch.examples.models.llama.norm import RMSNorm
from executorch.examples.models.llama.norm import RMSNorm, RMSNormGated
from executorch.examples.models.llama.rope import Rope


Expand Down Expand Up @@ -347,27 +347,35 @@ def __init__(
self.attention_qkv_bias = args.attention_qkv_bias
self.use_qk_norm = args.use_qk_norm
self.qk_norm_before_rope = args.qk_norm_before_rope
self.use_q_gate = args.use_q_gate
self.enable_dynamic_shape = args.enable_dynamic_shape
q_out_dim = self.n_heads * self.head_dim * (2 if self.use_q_gate else 1)

if self.use_qk_norm:
q_norm_dim = self.head_dim
k_norm_dim = self.head_dim
self.q_norm_fn = RMSNorm(q_norm_dim, eps=args.norm_eps)
self.k_norm_fn = RMSNorm(k_norm_dim, eps=args.norm_eps)
self.q_norm_fn = RMSNorm(
q_norm_dim,
eps=args.norm_eps,
add_unit_offset=args.rms_norm_add_unit_offset,
)
self.k_norm_fn = RMSNorm(
k_norm_dim,
eps=args.norm_eps,
add_unit_offset=args.rms_norm_add_unit_offset,
)

self.wq = (
LoRALinear(
in_dim=args.dim,
out_dim=args.n_heads * args.head_dim,
out_dim=q_out_dim,
rank=args.r,
alpha=args.lora_alpha,
dropout=0.0,
use_bias=args.attention_qkv_bias,
)
if args.target_modules is not None and "q_proj" in args.target_modules
else nn.Linear(
self.dim, self.n_heads * self.head_dim, bias=self.attention_qkv_bias
)
else nn.Linear(self.dim, q_out_dim, bias=self.attention_qkv_bias)
)
self.wk = (
LoRALinear(
Expand Down Expand Up @@ -452,10 +460,17 @@ def forward(
input_pos = kwargs.get("input_pos")
bsz, seqlen, _ = x.shape

# QKV
q, k, v = self.wq(x), self.wk(x), self.wv(x)
# We need view_copy elimination
q = q.view(bsz, seqlen, self.n_local_heads, self.head_dim)
if self.use_q_gate:
q_and_gate = self.wq(x).view(
bsz, seqlen, self.n_local_heads, self.head_dim * 2
)
q, gate = torch.chunk(q_and_gate, 2, dim=-1)
gate = gate.reshape(bsz, seqlen, -1)
else:
q = self.wq(x).view(bsz, seqlen, self.n_local_heads, self.head_dim)
gate = None

k, v = self.wk(x), self.wv(x)
k = k.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
v = v.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

Expand Down Expand Up @@ -492,6 +507,8 @@ def forward(
input_pos[0].item(), seqlen
)
output = self.SDPA(input_pos, q, k, v, bsz, seqlen, attn_mask)
if gate is not None:
output = output * torch.sigmoid(gate)
return self.wo(output), None

# grouped multiquery attention: expand out keys and values
Expand All @@ -505,12 +522,234 @@ def forward(
output = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)

output = output.transpose(1, 2).reshape(bsz, seqlen, -1)
if gate is not None:
output = output * torch.sigmoid(gate)

output = self.wo(output)

return output, None


def _l2norm(x: torch.Tensor, dim: int = -1, eps: float = 1e-6) -> torch.Tensor:
inv_norm = torch.rsqrt((x * x).sum(dim=dim, keepdim=True) + eps)
return x * inv_norm


@register_attention("gated_deltanet")
class AttentionGatedDeltaNet(Attention):
"""Qwen3.5 linear-attention (Gated DeltaNet) block with internal state."""

def __init__(
    self,
    args: ModelArgs,
    layer_id: int,
    rope: Rope,
    **_kwargs: Any,
):
    """Build projections, depthwise causal conv, gates, and state buffers.

    Args:
        args: model hyperparameters; the ``linear_*`` fields must be set.
        layer_id: index of this layer in the transformer stack.
        rope: unused — DeltaNet layers do not apply rotary embeddings.
    """
    super().__init__()
    del rope  # DeltaNet layers do not use RoPE.

    self.hidden_size = args.dim
    self.max_batch_size = args.max_batch_size
    self.layer_id = layer_id

    # These are optional on ModelArgs; a gated_deltanet layer requires them.
    assert args.linear_num_key_heads is not None
    assert args.linear_num_value_heads is not None
    assert args.linear_key_head_dim is not None
    assert args.linear_value_head_dim is not None

    self.num_k_heads = args.linear_num_key_heads
    self.num_v_heads = args.linear_num_value_heads
    self.head_k_dim = args.linear_key_head_dim
    self.head_v_dim = args.linear_value_head_dim
    self.key_dim = self.head_k_dim * self.num_k_heads
    self.value_dim = self.head_v_dim * self.num_v_heads
    self.conv_kernel_size = args.linear_conv_kernel_dim

    assert (
        self.num_v_heads % self.num_k_heads == 0
    ), "linear_num_value_heads must be divisible by linear_num_key_heads."
    # How many times each q/k head is repeated to match the value heads
    # (GQA-style sharing; see repeat_interleave in forward()).
    self.head_repeat = self.num_v_heads // self.num_k_heads

    # q, k (key_dim each) and v (value_dim) are projected jointly, run
    # through one depthwise conv, then split apart in forward().
    self.conv_dim = self.key_dim * 2 + self.value_dim
    self.in_proj_qkv = nn.Linear(self.hidden_size, self.conv_dim, bias=False)
    self.in_proj_z = nn.Linear(self.hidden_size, self.value_dim, bias=False)
    self.in_proj_b = nn.Linear(self.hidden_size, self.num_v_heads, bias=False)
    self.in_proj_a = nn.Linear(self.hidden_size, self.num_v_heads, bias=False)

    # Depthwise (groups == channels) conv over the time axis; causality is
    # enforced manually in _apply_causal_conv via the cached left context.
    self.conv1d = nn.Conv1d(
        in_channels=self.conv_dim,
        out_channels=self.conv_dim,
        kernel_size=self.conv_kernel_size,
        groups=self.conv_dim,
        bias=False,
        padding=0,
    )

    self.dt_bias = nn.Parameter(torch.ones(self.num_v_heads))
    # Per-head decay rate, stored in log space; initialized uniform in (0, 16).
    A = torch.empty(self.num_v_heads).uniform_(0, 16)
    self.A_log = nn.Parameter(torch.log(A))
    self.norm = RMSNormGated(self.head_v_dim, eps=args.norm_eps)
    self.out_proj = nn.Linear(self.value_dim, self.hidden_size, bias=False)

    # Rolling window of the last conv_kernel_size inputs, one per batch row,
    # so the causal conv stays correct across chunked/decode calls.
    self.register_buffer(
        "conv_state",
        torch.zeros(
            self.max_batch_size,
            self.conv_dim,
            self.conv_kernel_size,
            dtype=torch.float32,
            device="cpu",
        ),
    )
    # Fast-weight memory of the delta rule: (batch, v_heads, d_k, d_v).
    self.register_buffer(
        "recurrent_state",
        torch.zeros(
            self.max_batch_size,
            self.num_v_heads,
            self.head_k_dim,
            self.head_v_dim,
            dtype=torch.float32,
            device="cpu",
        ),
    )

def _maybe_reset_state(
self, input_pos: Optional[torch.Tensor], batch_size: int
) -> None:
if input_pos is None:
self.conv_state[:batch_size].zero_()
self.recurrent_state[:batch_size].zero_()
return
reset = (input_pos[0] == 0).to(self.conv_state.dtype)
keep = 1.0 - reset
self.conv_state[:batch_size].mul_(keep)
self.recurrent_state[:batch_size].mul_(keep)

def _apply_causal_conv(self, mixed_qkv: torch.Tensor) -> torch.Tensor:
# mixed_qkv: (batch, seq_len, conv_dim)
batch_size, seq_len, _ = mixed_qkv.shape
mixed_qkv = mixed_qkv.transpose(1, 2)
state_len = self.conv_state.shape[-1]
hidden_states_new = torch.cat([self.conv_state[:batch_size], mixed_qkv], dim=-1)
new_conv_state = hidden_states_new[:, :, -state_len:]
with torch.no_grad():
self.conv_state[:batch_size].copy_(new_conv_state.to(self.conv_state.dtype))
out = F.conv1d(
hidden_states_new,
self.conv1d.weight,
self.conv1d.bias,
padding=0,
groups=self.conv_dim,
)
out = F.silu(out[:, :, -seq_len:]).to(mixed_qkv.dtype)
return out.transpose(1, 2).contiguous()

def _recurrent_gated_delta_rule(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor,
) -> torch.Tensor:
    """Sequential gated delta-rule linear attention with persistent state.

    Shapes:
        query/key/value: (batch, seq_len, num_heads, head_dim)
        g/beta:          (batch, seq_len, num_heads)
    Returns (batch, seq_len, num_heads, v_head_dim) in the input dtype,
    and updates ``self.recurrent_state`` in place.

    NOTE(review): the per-timestep Python loop makes graph size O(seq_len);
    acceptable for decode (seq_len == 1) but prefill should eventually move
    to a scan or a custom op for performance.
    """
    initial_dtype = query.dtype
    # L2-normalize q/k per head so the delta-rule update is scale-invariant.
    query = _l2norm(query, dim=-1, eps=1e-6)
    key = _l2norm(key, dim=-1, eps=1e-6)
    # Move to (batch, heads, seq, dim) and run the recurrence in fp32 to
    # avoid accumulating error in the state across many steps.
    query, key, value, beta, g = [
        x.transpose(1, 2).contiguous().to(torch.float32)
        for x in (query, key, value, beta, g)
    ]

    batch_size, num_heads, sequence_length, k_head_dim = key.shape
    v_head_dim = value.shape[-1]
    scale = 1.0 / (query.shape[-1] ** 0.5)
    query = query * scale

    core_attn_out = torch.zeros(
        batch_size,
        num_heads,
        sequence_length,
        v_head_dim,
        device=value.device,
        dtype=value.dtype,
    )
    last_recurrent_state = self.recurrent_state[:batch_size].to(value.dtype)

    for i in range(sequence_length):
        q_t = query[:, :, i]
        k_t = key[:, :, i]
        v_t = value[:, :, i]
        # Per-head decay gate in (0, 1]; g is a log-space decay.
        g_t = g[:, :, i].exp().unsqueeze(-1).unsqueeze(-1)
        beta_t = beta[:, :, i].unsqueeze(-1)

        last_recurrent_state = last_recurrent_state * g_t
        # Read the memory's current prediction for k_t, then write back the
        # beta-scaled prediction error (the "delta rule" update).
        kv_mem = (last_recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2)
        delta = (v_t - kv_mem) * beta_t
        last_recurrent_state = last_recurrent_state + k_t.unsqueeze(
            -1
        ) * delta.unsqueeze(-2)
        core_attn_out[:, :, i] = (last_recurrent_state * q_t.unsqueeze(-1)).sum(
            dim=-2
        )

    # Persist the final state without tracking gradients through the buffer.
    with torch.no_grad():
        self.recurrent_state[:batch_size].copy_(
            last_recurrent_state.to(self.recurrent_state.dtype)
        )

    return core_attn_out.transpose(1, 2).contiguous().to(initial_dtype)

def forward(
    self,
    x: torch.Tensor,
    freqs_cos: torch.Tensor,
    freqs_sin: torch.Tensor,
    **kwargs: ForwardOptions,
) -> Tuple[torch.Tensor, Optional[Any]]:
    """Gated DeltaNet forward pass.

    ``x`` is (batch, seq_len, hidden). ``freqs_cos``/``freqs_sin`` are
    accepted only for interface parity with the other attention classes
    and are unused (no RoPE). Returns ``(output, None)``; attention state
    lives in the internal conv/recurrent buffers, reset via ``input_pos``.
    """
    del freqs_cos
    del freqs_sin
    input_pos = kwargs.get("input_pos")
    batch_size, seq_len, _ = x.shape
    assert (
        batch_size <= self.max_batch_size
    ), f"batch_size ({batch_size}) exceeds max_batch_size ({self.max_batch_size})"

    # Wipe state at the start of a sequence (or when input_pos is absent)
    # so outputs cannot leak across unrelated forward calls.
    self._maybe_reset_state(input_pos, batch_size)

    mixed_qkv = self.in_proj_qkv(x)
    # z is the gating branch fed to the gated RMSNorm at the end.
    z = self.in_proj_z(x).reshape(batch_size, seq_len, -1, self.head_v_dim)
    b = self.in_proj_b(x)
    a = self.in_proj_a(x)

    mixed_qkv = self._apply_causal_conv(mixed_qkv)
    query, key, value = torch.split(
        mixed_qkv,
        [self.key_dim, self.key_dim, self.value_dim],
        dim=-1,
    )
    query = query.reshape(batch_size, seq_len, -1, self.head_k_dim)
    key = key.reshape(batch_size, seq_len, -1, self.head_k_dim)
    value = value.reshape(batch_size, seq_len, -1, self.head_v_dim)

    # GQA-style head sharing: repeat q/k heads to match the value heads.
    if self.head_repeat > 1:
        query = query.repeat_interleave(self.head_repeat, dim=2)
        key = key.repeat_interleave(self.head_repeat, dim=2)

    beta = b.sigmoid()
    # Log-space decay per head/timestep: -exp(A_log) * softplus(a + dt_bias).
    g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
    core_attn_out = self._recurrent_gated_delta_rule(query, key, value, g, beta)

    # Gated RMSNorm operates per head vector; flatten heads for it.
    core_attn_out = core_attn_out.reshape(-1, self.head_v_dim)
    z = z.reshape(-1, self.head_v_dim)
    core_attn_out = self.norm(core_attn_out, z)
    core_attn_out = core_attn_out.reshape(batch_size, seq_len, -1)

    return self.out_proj(core_attn_out), None


@register_attention("skip")
class AttentionSkip(Attention):
def __init__(self, *args, **kwargs):
Expand Down
10 changes: 9 additions & 1 deletion examples/models/llama/export_llama_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,9 @@
"qwen3_0_6b",
"qwen3_1_7b",
"qwen3_4b",
"qwen3_5_0_8b",
"qwen3_5_2b",
"qwen3_5_4b",
"phi_4_mini",
"smollm2",
"lfm2_350m", # hybrid
Expand All @@ -122,6 +125,9 @@
"qwen3_0_6b": "Qwen/Qwen3-0.6B",
"qwen3_1_7b": "Qwen/Qwen3-1.7B",
"qwen3_4b": "Qwen/Qwen3-4B",
"qwen3_5_0_8b": "Qwen/Qwen3.5-0.8B",
"qwen3_5_2b": "Qwen/Qwen3.5-2B",
"qwen3_5_4b": "Qwen/Qwen3.5-4B",
"lfm2_350m": "LiquidAI/LFM2-350M",
"lfm2_700m": "LiquidAI/LFM2-700M",
"lfm2_1_2b": "LiquidAI/LFM2-1.2B",
Expand Down Expand Up @@ -620,7 +626,7 @@ def canonical_path(path: Union[str, Path], *, dir: bool = False) -> str:
return return_val


def export_llama(
def export_llama( # noqa: C901
export_options: Union[argparse.Namespace, LlmConfig, DictConfig],
) -> str:
if isinstance(export_options, argparse.Namespace):
Expand All @@ -643,6 +649,8 @@ def export_llama(
repo_id = HUGGING_FACE_REPO_IDS[model_name]
if model_name.startswith("qwen2_5"):
from executorch.examples.models.qwen2_5 import convert_weights
elif model_name.startswith("qwen3_5"):
from executorch.examples.models.qwen3_5 import convert_weights
elif model_name.startswith("qwen3"):
from executorch.examples.models.qwen3 import convert_weights
elif model_name == "phi_4_mini":
Expand Down
Loading
Loading