diff --git a/Engine/Llama_modules.py b/Engine/Llama_modules.py index 4ca1f934..618d6e7e 100644 --- a/Engine/Llama_modules.py +++ b/Engine/Llama_modules.py @@ -1,18 +1,11 @@ from .Llama_KV import KV_Cache -from transformers.models.llama.modeling_llama import( - LlamaRMSNorm, - LlamaConfig, - LlamaMLP, - LlamaRotaryEmbedding, - apply_rotary_pos_emb, - repeat_kv, - ACT2FN -) +from transformers.models.llama.modeling_llama import LlamaConfig, ACT2FN from torch import nn import torch -import torch.nn.functional as F -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union # noqa: F401 import math + + class LlamaRotaryEmbedding_FI(nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() @@ -24,9 +17,7 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): self.register_buffer("inv_freq", inv_freq, persistent=False) # Build here to make `torch.jit.trace` work. - self._set_cos_sin_cache( - seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() - ) + self._set_cos_sin_cache(seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()) def _set_cos_sin_cache(self, seq_len, device, dtype): self.max_seq_len_cached = seq_len @@ -44,6 +35,7 @@ def forward(self, dtype, seq_len=None): self.sin_cached[:seq_len].to(dtype=dtype), ) + class LlamaAttention_FI(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -62,10 +54,7 @@ def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None): self.is_causal = True if (self.head_dim * self.num_heads) != self.hidden_size: - raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {self.num_heads})." 
- ) + raise ValueError(f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" f" and `num_heads`: {self.num_heads}).") self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) @@ -87,14 +76,14 @@ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): def forward( self, hidden_states: torch.Tensor, - max_length :int, + max_length: int, storage_ids: torch.LongTensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - kv_cache : KV_Cache = None, - debug :bool = False + kv_cache: KV_Cache = None, + debug: bool = False, ): - + bsz, q_len, _ = hidden_states.size() if debug: @@ -104,41 +93,33 @@ def forward( assert position_ids.shape[1] == q_len assert position_ids.shape[0] == bsz - - query_states :torch.Tensor= self.q_proj(hidden_states) - key_states :torch.Tensor= self.k_proj(hidden_states) - value_states :torch.Tensor= self.v_proj(hidden_states) + query_states: torch.Tensor = self.q_proj(hidden_states) + key_states: torch.Tensor = self.k_proj(hidden_states) + value_states: torch.Tensor = self.v_proj(hidden_states) query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - cos, sin = self.rotary_emb(value_states.dtype, seq_len=max_length) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) - - key_states, value_states = kv_cache.update_kv_cache(key_states, value_states, - self.layer_idx, storage_ids=storage_ids, debug=debug) - + key_states, value_states = kv_cache.update_kv_cache(key_states, value_states, self.layer_idx, storage_ids=storage_ids, 
debug=debug) + key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=attention_mask, - dropout_p=0.0, - is_causal=False + query_states, key_states, value_states, attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) attn_output = self.o_proj(attn_output) - + return attn_output - + + class LlamaAttention_TG(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -157,10 +138,7 @@ def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None): self.is_causal = True if (self.head_dim * self.num_heads) != self.hidden_size: - raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {self.num_heads})." 
- ) + raise ValueError(f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" f" and `num_heads`: {self.num_heads}).") self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) @@ -182,14 +160,14 @@ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): def forward( self, hidden_states: torch.Tensor, - max_length :int, + max_length: int, storage_ids: torch.LongTensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - kv_cache : KV_Cache = None, - debug :bool = False + kv_cache: KV_Cache = None, + debug: bool = False, ): - + bsz, q_len, _ = hidden_states.size() if debug: @@ -199,65 +177,51 @@ def forward( assert position_ids.shape[1] == q_len assert position_ids.shape[0] == bsz - - query_states :torch.Tensor= self.q_proj(hidden_states) - key_states :torch.Tensor= self.k_proj(hidden_states) - value_states :torch.Tensor= self.v_proj(hidden_states) - + query_states: torch.Tensor = self.q_proj(hidden_states) + key_states: torch.Tensor = self.k_proj(hidden_states) + value_states: torch.Tensor = self.v_proj(hidden_states) query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - + cos, sin = self.rotary_emb(value_states.dtype, seq_len=max_length) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) - - - key_states, value_states = kv_cache.update_kv_cache(key_states, value_states, - self.layer_idx, storage_ids=storage_ids, debug=debug) - + + key_states, value_states = kv_cache.update_kv_cache(key_states, value_states, self.layer_idx, 
storage_ids=storage_ids, debug=debug) + kv_len = kv_cache.get_usable_length(layer_idx=self.layer_idx, input_length=len(storage_ids)) key_states = key_states[..., :kv_len, :] value_states = value_states[..., :kv_len, :] - - key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) - + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - - + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_len): - raise ValueError( - f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_len)}, but is" - f" {attn_weights.size()}" - ) + raise ValueError(f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_len)}, but is" f" {attn_weights.size()}") if attention_mask is not None: if attention_mask.size() != (bsz, 1, q_len, kv_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_len)}, but is {attention_mask.size()}" - ) + raise ValueError(f"Attention mask should be of size {(bsz, 1, q_len, kv_len)}, but is {attention_mask.size()}") attn_weights = attn_weights + attention_mask # upcast attention to fp32 attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - + attn_output = torch.matmul(attn_weights, value_states) - + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) + raise ValueError(f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" f" {attn_output.size()}") attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) attn_output = self.o_proj(attn_output) return attn_output + + class LlamaMLP_FI(nn.Module): - def __init__(self, config:LlamaConfig): + def __init__(self, config: LlamaConfig): 
super().__init__() self.config = config self.hidden_size = config.hidden_size @@ -268,11 +232,12 @@ def __init__(self, config:LlamaConfig): self.act_fn = ACT2FN[config.hidden_act] def forward(self, x): - + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - + return down_proj + class LlamaRMSNorm_FI(nn.Module): def __init__(self, hidden_size, eps=1e-6): """ @@ -289,6 +254,7 @@ def forward(self, hidden_states): hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) return self.weight * hidden_states.to(input_dtype) + class LlamaDecoderLayer_FI(nn.Module): def __init__(self, config: LlamaConfig, layer_idx: int): super().__init__() @@ -303,12 +269,12 @@ def __init__(self, config: LlamaConfig, layer_idx: int): def forward( self, hidden_states: torch.Tensor, - max_length :int, - storage_ids :torch.LongTensor, + max_length: int, + storage_ids: torch.LongTensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, kv_cache: KV_Cache = None, - debug :bool = False + debug: bool = False, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: @@ -324,12 +290,11 @@ def forward( (see `past_key_values`). 
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states """ - - + residual = hidden_states hidden_states = self.input_layernorm(hidden_states) - + # Self Attention hidden_states = self.self_attn( hidden_states=hidden_states, @@ -338,7 +303,7 @@ def forward( position_ids=position_ids, max_length=max_length, kv_cache=kv_cache, - debug=debug + debug=debug, ) hidden_states = residual + hidden_states @@ -365,12 +330,12 @@ def __init__(self, config: LlamaConfig, layer_idx: int): def forward( self, hidden_states: torch.Tensor, - max_length :int, - storage_ids :torch.LongTensor, + max_length: int, + storage_ids: torch.LongTensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, kv_cache: KV_Cache = None, - debug :bool = False + debug: bool = False, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: @@ -386,12 +351,11 @@ def forward( (see `past_key_values`). past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states """ - - + residual = hidden_states hidden_states = self.input_layernorm(hidden_states) - + # Self Attention hidden_states = self.self_attn( hidden_states=hidden_states, @@ -400,17 +364,49 @@ def forward( position_ids=position_ids, max_length=max_length, kv_cache=kv_cache, - debug=debug + debug=debug, ) - + hidden_states = residual + hidden_states - + # Fully Connected residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) - + hidden_states = self.mlp(hidden_states) - + hidden_states = residual + hidden_states - - return hidden_states \ No newline at end of file + + return hidden_states + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): + """ + Applies Rotary Position Embedding to the query and key tensors. 
+ copied from modeling_llama.py @4.37.2; changed around 4.38 + """ + cos = cos[position_ids].unsqueeze(unsqueeze_dim) + sin = sin[position_ids].unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def rotate_half(x): + """Rotates half the hidden dims of the input. """ + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1)