1 change: 1 addition & 0 deletions examples/models/BUCK
@@ -29,6 +29,7 @@ fbcode_target(_kind = python_library,
"//executorch/examples/models/gemma3:gemma3", # @manual
"//executorch/examples/models/qwen2_5:qwen2_5", # @manual
"//executorch/examples/models/qwen3:qwen3", # @manual
"//executorch/examples/models/qwen3_5:qwen3_5", # @manual
"//executorch/examples/models/phi_4_mini:phi_4_mini", # @manual
"//executorch/examples/models/smollm2:smollm2", # @manual
"//executorch/examples/models/smollm3:smollm3", # @manual
17 changes: 13 additions & 4 deletions examples/models/llama/__init__.py
@@ -4,8 +4,17 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from .model import Llama2Model
from typing import TYPE_CHECKING

__all__ = [
Llama2Model,
]
if TYPE_CHECKING:
from .model import Llama2Model

__all__ = ["Llama2Model"]


def __getattr__(name):
if name == "Llama2Model":
from .model import Llama2Model

return Llama2Model
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
258 changes: 249 additions & 9 deletions examples/models/llama/attention.py
@@ -1,3 +1,3 @@
from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, Dict, Optional, Tuple, Type, TypedDict
@@ -7,7 +7,7 @@
import torch.nn.functional as F
from executorch.examples.models.llama.lora import LoRALinear
from executorch.examples.models.llama.model_args import ModelArgs
from executorch.examples.models.llama.norm import RMSNorm
from executorch.examples.models.llama.norm import RMSNorm, RMSNormGated
from executorch.examples.models.llama.rope import Rope


@@ -314,6 +314,7 @@
return self.k_cache, self.v_cache


@register_attention("qwen3_5_full")
@register_attention("mha")
class AttentionMHA(Attention):
def __init__(
@@ -347,26 +348,36 @@
self.attention_qkv_bias = args.attention_qkv_bias
self.use_qk_norm = args.use_qk_norm
self.qk_norm_before_rope = args.qk_norm_before_rope
self.use_q_gate = args.use_q_gate
self.enable_dynamic_shape = args.enable_dynamic_shape
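# With q-gating enabled, wq emits the query and a per-head output gate side by side, so its output width doubles.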
q_out_dim = self.n_heads * self.head_dim * (2 if self.use_q_gate else 1)

if self.use_qk_norm:
q_norm_dim = self.head_dim
k_norm_dim = self.head_dim
self.q_norm_fn = RMSNorm(q_norm_dim, eps=args.norm_eps)
self.k_norm_fn = RMSNorm(k_norm_dim, eps=args.norm_eps)
self.q_norm_fn = RMSNorm(
q_norm_dim,
eps=args.norm_eps,
add_unit_offset=args.rms_norm_add_unit_offset,
)
self.k_norm_fn = RMSNorm(
k_norm_dim,
eps=args.norm_eps,
add_unit_offset=args.rms_norm_add_unit_offset,
)

self.wq = (
LoRALinear(
in_dim=args.dim,
out_dim=args.n_heads * args.head_dim,
out_dim=q_out_dim,
rank=args.r,
alpha=args.lora_alpha,
dropout=0.0,
use_bias=args.attention_qkv_bias,
)
if args.target_modules is not None and "q_proj" in args.target_modules
else nn.Linear(
self.dim, self.n_heads * self.head_dim, bias=self.attention_qkv_bias
self.dim, q_out_dim, bias=self.attention_qkv_bias
)
)
self.wk = (
@@ -452,10 +463,17 @@
input_pos = kwargs.get("input_pos")
bsz, seqlen, _ = x.shape

# QKV
q, k, v = self.wq(x), self.wk(x), self.wv(x)
# We need view_copy elimination
q = q.view(bsz, seqlen, self.n_local_heads, self.head_dim)
if self.use_q_gate:
q_and_gate = self.wq(x).view(
bsz, seqlen, self.n_local_heads, self.head_dim * 2
)
q, gate = torch.chunk(q_and_gate, 2, dim=-1)
gate = gate.reshape(bsz, seqlen, -1)
else:
q = self.wq(x).view(bsz, seqlen, self.n_local_heads, self.head_dim)
gate = None
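# When present, the gate is applied to the attention output through a sigmoid before the output projection (see below).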

k, v = self.wk(x), self.wv(x)
k = k.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
v = v.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

@@ -492,6 +510,8 @@
input_pos[0].item(), seqlen
)
output = self.SDPA(input_pos, q, k, v, bsz, seqlen, attn_mask)
if gate is not None:
output = output * torch.sigmoid(gate)
return self.wo(output), None

# grouped multiquery attention: expand out keys and values
@@ -505,12 +525,232 @@
output = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)

output = output.transpose(1, 2).reshape(bsz, seqlen, -1)
if gate is not None:
output = output * torch.sigmoid(gate)

output = self.wo(output)

return output, None


def _l2norm(x: torch.Tensor, dim: int = -1, eps: float = 1e-6) -> torch.Tensor:
inv_norm = torch.rsqrt((x * x).sum(dim=dim, keepdim=True) + eps)
return x * inv_norm


@register_attention("gated_deltanet")
class AttentionGatedDeltaNet(Attention):
"""Qwen3.5 linear-attention (Gated DeltaNet) block with internal state."""

def __init__(
self,
args: ModelArgs,
layer_id: int,
rope: Rope,
**_kwargs: Any,
):
super().__init__()
del rope # DeltaNet layers do not use RoPE.

self.hidden_size = args.dim
self.max_batch_size = args.max_batch_size
self.layer_id = layer_id

assert args.linear_num_key_heads is not None
assert args.linear_num_value_heads is not None
assert args.linear_key_head_dim is not None
assert args.linear_value_head_dim is not None

self.num_k_heads = args.linear_num_key_heads
self.num_v_heads = args.linear_num_value_heads
self.head_k_dim = args.linear_key_head_dim
self.head_v_dim = args.linear_value_head_dim
self.key_dim = self.head_k_dim * self.num_k_heads
self.value_dim = self.head_v_dim * self.num_v_heads
self.conv_kernel_size = args.linear_conv_kernel_dim

assert (
self.num_v_heads % self.num_k_heads == 0
), "linear_num_value_heads must be divisible by linear_num_key_heads."
self.head_repeat = self.num_v_heads // self.num_k_heads

self.conv_dim = self.key_dim * 2 + self.value_dim
self.in_proj_qkv = nn.Linear(self.hidden_size, self.conv_dim, bias=False)
self.in_proj_z = nn.Linear(self.hidden_size, self.value_dim, bias=False)
self.in_proj_b = nn.Linear(self.hidden_size, self.num_v_heads, bias=False)
self.in_proj_a = nn.Linear(self.hidden_size, self.num_v_heads, bias=False)

self.conv1d = nn.Conv1d(
in_channels=self.conv_dim,
out_channels=self.conv_dim,
kernel_size=self.conv_kernel_size,
groups=self.conv_dim,
bias=False,
padding=0,
)

self.dt_bias = nn.Parameter(torch.ones(self.num_v_heads))
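# A_log stores log(A) per value head; together with dt_bias it parameterizes the decay gate
# g = -exp(A_log) * softplus(a + dt_bias) computed in forward().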
A = torch.empty(self.num_v_heads).uniform_(0, 16)
self.A_log = nn.Parameter(torch.log(A))
self.norm = RMSNormGated(self.head_v_dim, eps=args.norm_eps)
self.out_proj = nn.Linear(self.value_dim, self.hidden_size, bias=False)
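# Decode-time state lives in buffers: conv_state caches the last conv_kernel_size inputs for the causal conv,
# and recurrent_state is the per-head (head_k_dim x head_v_dim) associative memory updated by the delta rule.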

self.register_buffer(
"conv_state",
torch.zeros(
self.max_batch_size,
self.conv_dim,
self.conv_kernel_size,
dtype=torch.float32,
device="cpu",
),
)
self.register_buffer(
"recurrent_state",
torch.zeros(
self.max_batch_size,
self.num_v_heads,
self.head_k_dim,
self.head_v_dim,
dtype=torch.float32,
device="cpu",
),
)

def _maybe_reset_state(
self, input_pos: Optional[torch.Tensor], batch_size: int
) -> None:
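# A sequence that restarts at position 0 invalidates any cached state for the active batch rows.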
if input_pos is None:
return
reset = (input_pos[0] == 0).to(self.conv_state.dtype)
keep = 1.0 - reset
self.conv_state[:batch_size].mul_(keep)
self.recurrent_state[:batch_size].mul_(keep)

def _apply_causal_conv(self, mixed_qkv: torch.Tensor) -> torch.Tensor:
# mixed_qkv: (batch, seq_len, conv_dim)
batch_size, seq_len, _ = mixed_qkv.shape
mixed_qkv = mixed_qkv.transpose(1, 2)
state_len = self.conv_state.shape[-1]
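# Prepend the cached inputs so the depthwise conv stays causal across calls, and keep the trailing
# state_len columns as the new cache.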
hidden_states_new = torch.cat([self.conv_state[:batch_size], mixed_qkv], dim=-1)
new_conv_state = hidden_states_new[:, :, -state_len:]
with torch.no_grad():
self.conv_state[:batch_size].copy_(new_conv_state.to(self.conv_state.dtype))
out = F.conv1d(
hidden_states_new,
self.conv1d.weight,
self.conv1d.bias,
padding=0,
groups=self.conv_dim,
)
out = F.silu(out[:, :, -seq_len:]).to(mixed_qkv.dtype)
return out.transpose(1, 2).contiguous()

def _recurrent_gated_delta_rule(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
) -> torch.Tensor:
# query/key/value: (batch, seq_len, num_heads, head_dim)
# g/beta: (batch, seq_len, num_heads)
initial_dtype = query.dtype
query = _l2norm(query, dim=-1, eps=1e-6)
key = _l2norm(key, dim=-1, eps=1e-6)
query, key, value, beta, g = [
x.transpose(1, 2).contiguous().to(torch.float32)
for x in (query, key, value, beta, g)
]

batch_size, num_heads, sequence_length, k_head_dim = key.shape
v_head_dim = value.shape[-1]
scale = 1.0 / (query.shape[-1] ** 0.5)
query = query * scale

core_attn_out = torch.zeros(
batch_size,
num_heads,
sequence_length,
v_head_dim,
device=value.device,
dtype=value.dtype,
)
last_recurrent_state = self.recurrent_state[:batch_size].to(value.dtype)

for i in range(sequence_length):
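# Each step: decay the memory, read its prediction for k_t, write back the beta-scaled prediction
# error as a rank-1 (k_t outer delta) update, then read the output with q_t.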
q_t = query[:, :, i]
k_t = key[:, :, i]
v_t = value[:, :, i]
g_t = g[:, :, i].exp().unsqueeze(-1).unsqueeze(-1)
beta_t = beta[:, :, i].unsqueeze(-1)

last_recurrent_state = last_recurrent_state * g_t
kv_mem = (last_recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2)
delta = (v_t - kv_mem) * beta_t
last_recurrent_state = last_recurrent_state + k_t.unsqueeze(
-1
) * delta.unsqueeze(-2)
core_attn_out[:, :, i] = (last_recurrent_state * q_t.unsqueeze(-1)).sum(
dim=-2
)

with torch.no_grad():
self.recurrent_state[:batch_size].copy_(
last_recurrent_state.to(self.recurrent_state.dtype)
)

return core_attn_out.transpose(1, 2).contiguous().to(initial_dtype)

def forward(
self,
x: torch.Tensor,
freqs_cos: torch.Tensor,
freqs_sin: torch.Tensor,
**kwargs: ForwardOptions,
) -> Tuple[torch.Tensor, Optional[Any]]:
del freqs_cos
del freqs_sin
input_pos = kwargs.get("input_pos")
batch_size, seq_len, _ = x.shape
assert (
batch_size <= self.max_batch_size
), f"batch_size ({batch_size}) exceeds max_batch_size ({self.max_batch_size})"

self._maybe_reset_state(input_pos, batch_size)

mixed_qkv = self.in_proj_qkv(x)
z = self.in_proj_z(x).reshape(batch_size, seq_len, -1, self.head_v_dim)
b = self.in_proj_b(x)
a = self.in_proj_a(x)

mixed_qkv = self._apply_causal_conv(mixed_qkv)
query, key, value = torch.split(
mixed_qkv,
[self.key_dim, self.key_dim, self.value_dim],
dim=-1,
)
query = query.reshape(batch_size, seq_len, -1, self.head_k_dim)
key = key.reshape(batch_size, seq_len, -1, self.head_k_dim)
value = value.reshape(batch_size, seq_len, -1, self.head_v_dim)

if self.head_repeat > 1:
query = query.repeat_interleave(self.head_repeat, dim=2)
key = key.repeat_interleave(self.head_repeat, dim=2)

beta = b.sigmoid()
g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
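# beta in (0, 1) is the per-head write strength; g <= 0 is the log of the state decay factor
# (exponentiated inside the recurrence).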
core_attn_out = self._recurrent_gated_delta_rule(query, key, value, g, beta)

core_attn_out = core_attn_out.reshape(-1, self.head_v_dim)
z = z.reshape(-1, self.head_v_dim)
core_attn_out = self.norm(core_attn_out, z)
core_attn_out = core_attn_out.reshape(batch_size, seq_len, -1)

return self.out_proj(core_attn_out), None


@register_attention("skip")
class AttentionSkip(Attention):
def __init__(self, *args, **kwargs):