Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
196 changes: 96 additions & 100 deletions Engine/Llama_modules.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,11 @@
from .Llama_KV import KV_Cache
from transformers.models.llama.modeling_llama import(
LlamaRMSNorm,
LlamaConfig,
LlamaMLP,
LlamaRotaryEmbedding,
apply_rotary_pos_emb,
repeat_kv,
ACT2FN
)
from transformers.models.llama.modeling_llama import LlamaConfig, ACT2FN
from torch import nn
import torch
import torch.nn.functional as F
from typing import List, Optional, Tuple, Union
from typing import List, Optional, Tuple, Union # noqa: F401
import math


class LlamaRotaryEmbedding_FI(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()
Expand All @@ -24,9 +17,7 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
self.register_buffer("inv_freq", inv_freq, persistent=False)

# Build here to make `torch.jit.trace` work.
self._set_cos_sin_cache(
seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
)
self._set_cos_sin_cache(seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype())

def _set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len
Expand All @@ -44,6 +35,7 @@ def forward(self, dtype, seq_len=None):
self.sin_cached[:seq_len].to(dtype=dtype),
)


class LlamaAttention_FI(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""

Expand All @@ -62,10 +54,7 @@ def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
self.is_causal = True

if (self.head_dim * self.num_heads) != self.hidden_size:
raise ValueError(
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
f" and `num_heads`: {self.num_heads})."
)
raise ValueError(f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" f" and `num_heads`: {self.num_heads}).")

self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
Expand All @@ -87,14 +76,14 @@ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
def forward(
self,
hidden_states: torch.Tensor,
max_length :int,
max_length: int,
storage_ids: torch.LongTensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
kv_cache : KV_Cache = None,
debug :bool = False
kv_cache: KV_Cache = None,
debug: bool = False,
):

bsz, q_len, _ = hidden_states.size()

if debug:
Expand All @@ -104,41 +93,33 @@ def forward(
assert position_ids.shape[1] == q_len
assert position_ids.shape[0] == bsz


query_states :torch.Tensor= self.q_proj(hidden_states)
key_states :torch.Tensor= self.k_proj(hidden_states)
value_states :torch.Tensor= self.v_proj(hidden_states)
query_states: torch.Tensor = self.q_proj(hidden_states)
key_states: torch.Tensor = self.k_proj(hidden_states)
value_states: torch.Tensor = self.v_proj(hidden_states)

query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)


cos, sin = self.rotary_emb(value_states.dtype, seq_len=max_length)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

key_states, value_states = kv_cache.update_kv_cache(key_states, value_states,
self.layer_idx, storage_ids=storage_ids, debug=debug)


key_states, value_states = kv_cache.update_kv_cache(key_states, value_states, self.layer_idx, storage_ids=storage_ids, debug=debug)

key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)

attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=attention_mask,
dropout_p=0.0,
is_causal=False
query_states, key_states, value_states, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)

attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)

return attn_output



class LlamaAttention_TG(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""

Expand All @@ -157,10 +138,7 @@ def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
self.is_causal = True

if (self.head_dim * self.num_heads) != self.hidden_size:
raise ValueError(
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
f" and `num_heads`: {self.num_heads})."
)
raise ValueError(f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" f" and `num_heads`: {self.num_heads}).")

self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
Expand All @@ -182,14 +160,14 @@ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
def forward(
self,
hidden_states: torch.Tensor,
max_length :int,
max_length: int,
storage_ids: torch.LongTensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
kv_cache : KV_Cache = None,
debug :bool = False
kv_cache: KV_Cache = None,
debug: bool = False,
):

bsz, q_len, _ = hidden_states.size()

if debug:
Expand All @@ -199,65 +177,51 @@ def forward(
assert position_ids.shape[1] == q_len
assert position_ids.shape[0] == bsz


query_states :torch.Tensor= self.q_proj(hidden_states)
key_states :torch.Tensor= self.k_proj(hidden_states)
value_states :torch.Tensor= self.v_proj(hidden_states)

query_states: torch.Tensor = self.q_proj(hidden_states)
key_states: torch.Tensor = self.k_proj(hidden_states)
value_states: torch.Tensor = self.v_proj(hidden_states)

query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)



cos, sin = self.rotary_emb(value_states.dtype, seq_len=max_length)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)


key_states, value_states = kv_cache.update_kv_cache(key_states, value_states,
self.layer_idx, storage_ids=storage_ids, debug=debug)


key_states, value_states = kv_cache.update_kv_cache(key_states, value_states, self.layer_idx, storage_ids=storage_ids, debug=debug)

kv_len = kv_cache.get_usable_length(layer_idx=self.layer_idx, input_length=len(storage_ids))
key_states = key_states[..., :kv_len, :]
value_states = value_states[..., :kv_len, :]



key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)

attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)



if attn_weights.size() != (bsz, self.num_heads, q_len, kv_len):
raise ValueError(
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_len)}, but is"
f" {attn_weights.size()}"
)
raise ValueError(f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_len)}, but is" f" {attn_weights.size()}")

if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_len)}, but is {attention_mask.size()}"
)
raise ValueError(f"Attention mask should be of size {(bsz, 1, q_len, kv_len)}, but is {attention_mask.size()}")
attn_weights = attn_weights + attention_mask

# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)

attn_output = torch.matmul(attn_weights, value_states)

if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
raise ValueError(f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" f" {attn_output.size()}")
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
return attn_output


class LlamaMLP_FI(nn.Module):
def __init__(self, config:LlamaConfig):
def __init__(self, config: LlamaConfig):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
Expand All @@ -268,11 +232,12 @@ def __init__(self, config:LlamaConfig):
self.act_fn = ACT2FN[config.hidden_act]

def forward(self, x):

down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

return down_proj


class LlamaRMSNorm_FI(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
Expand All @@ -289,6 +254,7 @@ def forward(self, hidden_states):
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)


class LlamaDecoderLayer_FI(nn.Module):
def __init__(self, config: LlamaConfig, layer_idx: int):
super().__init__()
Expand All @@ -303,12 +269,12 @@ def __init__(self, config: LlamaConfig, layer_idx: int):
def forward(
self,
hidden_states: torch.Tensor,
max_length :int,
storage_ids :torch.LongTensor,
max_length: int,
storage_ids: torch.LongTensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
kv_cache: KV_Cache = None,
debug :bool = False
debug: bool = False,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Args:
Expand All @@ -324,12 +290,11 @@ def forward(
(see `past_key_values`).
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
"""



residual = hidden_states

hidden_states = self.input_layernorm(hidden_states)

# Self Attention
hidden_states = self.self_attn(
hidden_states=hidden_states,
Expand All @@ -338,7 +303,7 @@ def forward(
position_ids=position_ids,
max_length=max_length,
kv_cache=kv_cache,
debug=debug
debug=debug,
)
hidden_states = residual + hidden_states

Expand All @@ -365,12 +330,12 @@ def __init__(self, config: LlamaConfig, layer_idx: int):
def forward(
self,
hidden_states: torch.Tensor,
max_length :int,
storage_ids :torch.LongTensor,
max_length: int,
storage_ids: torch.LongTensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
kv_cache: KV_Cache = None,
debug :bool = False
debug: bool = False,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Args:
Expand All @@ -386,12 +351,11 @@ def forward(
(see `past_key_values`).
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
"""



residual = hidden_states

hidden_states = self.input_layernorm(hidden_states)

# Self Attention
hidden_states = self.self_attn(
hidden_states=hidden_states,
Expand All @@ -400,17 +364,49 @@ def forward(
position_ids=position_ids,
max_length=max_length,
kv_cache=kv_cache,
debug=debug
debug=debug,
)

hidden_states = residual + hidden_states

# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)

hidden_states = self.mlp(hidden_states)

hidden_states = residual + hidden_states

return hidden_states

return hidden_states


def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    """Apply Rotary Position Embedding (RoPE) to query and key tensors.

    The cached ``cos``/``sin`` tables are gathered at ``position_ids`` and a
    singleton axis is inserted at ``unsqueeze_dim`` so they broadcast over the
    head dimension before rotating ``q`` and ``k``.
    copied from modelling_llama.py @4.37.2; changed around 4.38
    """
    # Gather per-position rotation terms and add a broadcast axis for the heads.
    cos_sel = cos[position_ids].unsqueeze(unsqueeze_dim)
    sin_sel = sin[position_ids].unsqueeze(unsqueeze_dim)
    # RoPE rotation: x * cos + rotate_half(x) * sin, applied to q and k alike.
    q_embed = q * cos_sel + rotate_half(q) * sin_sel
    k_embed = k * cos_sel + rotate_half(k) * sin_sel
    return q_embed, k_embed


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Duplicate grouped key/value heads so each attention head gets a copy.

    Equivalent to ``torch.repeat_interleave(hidden_states, dim=1, repeats=n_rep)``:
    the tensor goes from (batch, num_key_value_heads, seqlen, head_dim) to
    (batch, num_key_value_heads * n_rep, seqlen, head_dim).
    <copied from modelling_llama.py @4.37.2>
    """
    if n_rep == 1:
        # No grouping in effect — return the input unchanged.
        return hidden_states
    bsz, n_kv_heads, seq_len, hdim = hidden_states.shape
    # Insert a broadcast axis, expand it (view, no copy), then fold it into the
    # head axis with a reshape, which materializes the repeated tensor.
    expanded = hidden_states[:, :, None, :, :].expand(bsz, n_kv_heads, n_rep, seq_len, hdim)
    return expanded.reshape(bsz, n_kv_heads * n_rep, seq_len, hdim)


def rotate_half(x):
    """Swap the two halves of the last dim, negating the moved-to-front half.

    <copied from modelling_llama.py @4.37.2>
    """
    half = x.shape[-1] // 2
    front = x[..., :half]
    back = x[..., half:]
    return torch.cat((-back, front), dim=-1)