
Commit 5520427

grad maker

1 parent 81da501 commit 5520427

3 files changed: +223 additions, -9 deletions


minimal_llama/hyper/prefix_llama.py

Lines changed: 1 addition & 1 deletion
@@ -813,7 +813,7 @@ def create_decoding_mask(orig_num_valid_tokens, max_seq_len, initial_max_len):
     mask = torch.ones([batch_size, 1, max_seq_len])
     for i, nvt in enumerate(orig_num_valid_tokens):
         mask[i, :, nvt:initial_max_len] = 0
-    return mask[:, None, -1:, ].bool()
+    return mask[:, None, :, :].bool()
 
 
 def create_prefix_train_attention_mask(input_ids, prefix_length):
Lines changed: 187 additions & 0 deletions (new file)
@@ -0,0 +1,187 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from minimal_llama.hypergrad.llama_simple_jvp_peft import (
    LLaMAConfig, RMSNorm, MLP, check_nan,
    NoInitLinear, RotaryEmbedding, rotate_half,
    create_rope_embed_ids,
)


class GradMakerLayer(nn.Module):
    def __init__(self, config: LLaMAConfig):
        super().__init__()
        self.config = config
        self.cross_attn = Attention(config=config)
        self.mlp = MLP(config=config)
        self.peft_input_layernorm = RMSNorm(dim=config.dim, dtype=config.dtype)
        self.model_input_layernorm = RMSNorm(dim=config.dim, dtype=config.dtype)
        self.post_attention_layernorm = RMSNorm(dim=config.dim, dtype=config.dtype)

    def forward(
        self,
        peft_hidden_states,
        model_hidden_states,
        cos, sin,
        attention_mask,
    ):
        normed_peft_hidden_states = self.peft_input_layernorm(peft_hidden_states).to(self.config.dtype)
        normed_model_hidden_states = self.model_input_layernorm(model_hidden_states).to(self.config.dtype)
        check_nan(normed_model_hidden_states)
        raw_self_attn_output = self.cross_attn(
            peft_hidden_states=normed_peft_hidden_states,
            model_hidden_states=normed_model_hidden_states,
            cos=cos, sin=sin,
            attention_mask=attention_mask,
        )
        # [batch_size, seq_len, hidden_dim]
        peft_hidden_states = peft_hidden_states + raw_self_attn_output["attn_output"]
        check_nan(peft_hidden_states)
        # 2) FFN
        # [batch_size, seq_len, hidden_dim]
        peft_hidden_states = peft_hidden_states + self.mlp(
            self.post_attention_layernorm(peft_hidden_states),
        )
        check_nan(peft_hidden_states)
        return peft_hidden_states


def apply_rotary_pos_emb(k, cos, sin):
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return k_embed


class Attention(nn.Module):
    def __init__(self, config: LLaMAConfig):
        super().__init__()
        self.config = config
        self.n_heads = config.n_heads
        self.head_dim = config.dim // config.n_heads

        self.q_proj = NoInitLinear(config.dim, config.dim, bias=False, dtype=config.dtype)
        self.k_proj = NoInitLinear(config.dim, config.dim, bias=False, dtype=config.dtype)
        self.v_proj = NoInitLinear(config.dim, config.dim, bias=False, dtype=config.dtype)
        self.o_proj = NoInitLinear(config.dim, config.dim, bias=False, dtype=config.dtype)
        self.rotary_emb = RotaryEmbedding(dim=self.head_dim, max_position_embeddings=config.max_seq_length)

    def forward(
        self,
        peft_hidden_states,
        model_hidden_states,
        cos, sin,
        attention_mask=None,
    ):
        _, p_seq_len, _ = peft_hidden_states.size()
        batch_size, m_seq_len, hidden_dim = model_hidden_states.size()

        # (batch_size, num_heads, q_seq_len, head_dim)
        query_states = self.q_proj(peft_hidden_states).view(
            batch_size, p_seq_len, self.n_heads, self.head_dim).transpose(1, 2)
        key_states = self.k_proj(model_hidden_states).view(
            batch_size, m_seq_len, self.n_heads, self.head_dim).transpose(1, 2)
        value_states = self.v_proj(model_hidden_states).view(
            batch_size, m_seq_len, self.n_heads, self.head_dim).transpose(1, 2)
        key_states = apply_rotary_pos_emb(key_states, cos=cos, sin=sin)
        # noinspection PyUnresolvedReferences
        with torch.backends.cuda.sdp_kernel(
            enable_math=True, enable_flash=True, enable_mem_efficient=True,
        ):
            attn_output = torch.nn.functional.scaled_dot_product_attention(
                query=query_states,
                key=key_states,
                value=value_states,
                attn_mask=attention_mask,
            )
        # (batch_size, q_seq_len, hidden_dim)
        attn_output = attn_output.transpose(1, 2).contiguous().view(
            batch_size, p_seq_len, hidden_dim,
        )
        attn_output = self.o_proj(attn_output)
        check_nan(attn_output)
        return {"attn_output": attn_output}

    @classmethod
    def append_to_kv_cache(cls, kv_cache, new_key_state, new_value_state):
        """

        :param kv_cache: {"key"/"value": [batch_size, num_heads, cache_seq_len, head_dim]}
        :param new_key_state: [batch_size, num_heads, seq_len=1, head_dim]
        :param new_value_state: [batch_size, num_heads, seq_len=1, head_dim]
        :return:
        """
        # We need to do some fancy indexing, because we are appending to a right-padded cache
        key_cache, value_cache = kv_cache["key"], kv_cache["value"]
        key_cache = torch.cat([key_cache, new_key_state], dim=2)
        value_cache = torch.cat([value_cache, new_value_state], dim=2)
        return key_cache, value_cache


class SimpleGradMaker(nn.Module):
    def __init__(self, config: LLaMAConfig, num_peft_layers: int = 1, return_diff: bool = True):
        super().__init__()
        self.config = config
        self.num_peft_layers = num_peft_layers
        self.num_scalers = config.n_layers
        self.scalers_proj = nn.Linear(config.n_layers, config.dim, dtype=self.config.dtype)
        self.layers = nn.ModuleList([
            GradMakerLayer(config)
            for _ in range(num_peft_layers)
        ])
        self.scalers_up_proj = nn.Linear(config.dim, config.n_layers, dtype=self.config.dtype)
        self.return_diff = return_diff

    def forward(self, input_ids, peft_params, model_hidden_states: list):
        batch_size, peft_len, _ = peft_params[0]["hidden_states"].shape
        scalers = torch.stack([
            layer_peft_params["scaler"]
            for layer_peft_params in peft_params
        ], dim=-1)
        scalers_token = self.scalers_proj(scalers)[:, None, :]
        params_tokens = torch.stack([
            layer_peft_params["hidden_states"]
            for layer_peft_params in peft_params
        ], dim=1).transpose(0, 1).reshape(batch_size, self.config.n_layers * peft_len, self.config.dim)
        peft_hidden_states = torch.cat([
            scalers_token,
            params_tokens,
        ], dim=1)
        attention_mask = create_cross_mask(input_ids)
        cos, sin = self.get_cos_sin(create_rope_embed_ids(input_ids=input_ids))
        for i, layer in enumerate(self.layers):
            peft_hidden_states = layer(
                peft_hidden_states=peft_hidden_states,
                model_hidden_states=model_hidden_states[i],
                cos=cos, sin=sin, attention_mask=attention_mask,
            )
        scalers = self.scalers_up_proj(peft_hidden_states[:, 0, :])
        peft_hidden_states = peft_hidden_states[:, 1:, :].view(
            batch_size, self.config.n_layers, peft_len, self.config.dim,
        )
        new_peft_params = []
        for i in range(self.config.n_layers):
            layer_peft = {
                "scaler": scalers[:, i],
                "hidden_states": peft_hidden_states[:, i, :, :],
            }
            if self.return_diff:
                layer_peft["scaler"] = layer_peft["scaler"] - peft_params[i]["scaler"]
                layer_peft["hidden_states"] = layer_peft["hidden_states"] - peft_params[i]["hidden_states"]
            new_peft_params.append(layer_peft)
        return new_peft_params

    def get_cos_sin(self, rope_embed_ids):
        cos = F.embedding(
            rope_embed_ids,
            self.layers[0].cross_attn.rotary_emb.cos_cached[0, 0].to(rope_embed_ids.device)
        ).to(self.config.dtype)
        sin = F.embedding(
            rope_embed_ids,
            self.layers[0].cross_attn.rotary_emb.sin_cached[0, 0].to(rope_embed_ids.device)
        ).to(self.config.dtype)
        cos, sin = cos[:, None, :, :], sin[:, None, :, :]
        return cos, sin


def create_cross_mask(input_ids, pad_token_id=0):
    is_valid = (input_ids != pad_token_id)
    return is_valid[:, None, None, :]
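For orientation, here is a small self-contained sketch (PyTorch only, made-up sizes) of the per-layer peft_params layout that SimpleGradMaker.forward consumes, and of how the scalers and PEFT hidden states are flattened into a single token sequence; note that the committed code additionally applies a transpose(0, 1) between the stack and the reshape, which this sketch leaves out.

import torch

# Hypothetical sizes, for illustration only.
batch_size, n_layers, peft_len, dim = 2, 4, 8, 16

# One dict per base-model layer, matching what SimpleGradMaker.forward expects.
peft_params = [
    {
        "scaler": torch.ones(batch_size),                          # [batch]
        "hidden_states": torch.randn(batch_size, peft_len, dim),   # [batch, peft_len, dim]
    }
    for _ in range(n_layers)
]

# Scalers are stacked to [batch, n_layers]; scalers_proj maps this to one extra "scaler token".
scalers = torch.stack([p["scaler"] for p in peft_params], dim=-1)

# PEFT hidden states are flattened into a sequence of n_layers * peft_len tokens.
params_tokens = torch.stack(
    [p["hidden_states"] for p in peft_params], dim=1,
).reshape(batch_size, n_layers * peft_len, dim)

print(scalers.shape, params_tokens.shape)
# torch.Size([2, 4]) torch.Size([2, 32, 16])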

minimal_llama/hypergrad/llama_simple_jvp_peft.py

Lines changed: 35 additions & 8 deletions
@@ -76,12 +76,18 @@ def __init__(self, config: LLaMAConfig):
 
     def init_kv_cache_from_prefix(self, prefix_h):
         kv_cache = []
-        for i, h in enumerate(prefix_h):
+        for i, layer_h in enumerate(prefix_h):
+            h = layer_h["hidden_states"]
+            if "scaler" in layer_h:
+                h = h * layer_h["scaler"][:, None, None]
             layer = self.model.layers[i]
-            h = layer.input_layernor(h)
+            h = layer.input_layernorm(h)
+            batch_size, seq_len, _ = h.shape
             kv_cache.append({
-                "key": layer.self_attn.key_proj(h),
-                "value": layer.self_attn.key_proj(h),
+                "key": layer.self_attn.k_proj(h).view(
+                    batch_size, seq_len, self.config.n_heads, self.config.head_dim).transpose(1, 2),
+                "value": layer.self_attn.v_proj(h).view(
+                    batch_size, seq_len, self.config.n_heads, self.config.head_dim).transpose(1, 2),
             })
         return kv_cache
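A minimal shape sketch of what the reworked init_kv_cache_from_prefix now produces per layer: the (optionally scaler-weighted) prefix hidden states are projected into per-head key/value tensors. The sizes are made up, a plain nn.Linear stands in for the model's k_proj, and the input_layernorm step is omitted.

import torch
import torch.nn as nn

# Hypothetical sizes, for illustration only.
batch_size, prefix_len, n_heads, head_dim = 2, 4, 8, 16
dim = n_heads * head_dim

layer_h = {
    "hidden_states": torch.randn(batch_size, prefix_len, dim),
    "scaler": torch.ones(batch_size),
}
h = layer_h["hidden_states"]
if "scaler" in layer_h:
    h = h * layer_h["scaler"][:, None, None]   # broadcast a per-example scale

k_proj = nn.Linear(dim, dim, bias=False)       # stand-in for layer.self_attn.k_proj
key = k_proj(h).view(batch_size, prefix_len, n_heads, head_dim).transpose(1, 2)
print(key.shape)  # torch.Size([2, 8, 4, 16]) -> [batch, n_heads, prefix_len, head_dim]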

@@ -90,13 +96,17 @@ def forward(
         input_ids,
         attention_mask=None,
         peft_params: Optional[dict] = None,
+        return_hidden_states: Optional[list] = None,
+        return_logits: bool = True,
     ):
         """Forward pass (with full decode sequence, intended for training or loss-scoring)
 
         :param input_ids: [batch_size, seq_len]
             - Always right-padded. Masks are generated based on padding tokens
         :param attention_mask
         :param peft_params
+        :param return_hidden_states
+        :param return_logits
         :return: logits [batch_size, seq_len]
         """
         if dict_get(peft_params, "prefix"):

@@ -113,10 +123,11 @@ def forward(
                 kv_cache=peft_params["prefix"],
                 attention_mask=attention_mask,
                 peft_params=peft_params,
+                return_hidden_states=return_hidden_states,
             )
         elif dict_get(peft_params, "prefix_h"):
             # [batch_size, num_heads=1, q_len=seq_len, kv_len=seq_len]
-            prefix_length = peft_params["prefix_h"][0]["key"].shape[1]
+            prefix_length = peft_params["prefix_h"][0]["hidden_states"].shape[1]
             kv_cache = self.init_kv_cache_from_prefix(peft_params["prefix_h"])
             rope_embed_ids = create_rope_embed_ids(input_ids=input_ids) + prefix_length
             cos, sin = self.get_cos_sin(rope_embed_ids)

@@ -129,6 +140,7 @@ def forward(
                 kv_cache=kv_cache,
                 attention_mask=attention_mask,
                 peft_params=peft_params,
+                return_hidden_states=return_hidden_states,
             )
         else:
             rope_embed_ids = create_rope_embed_ids(input_ids=input_ids)

@@ -139,10 +151,12 @@ def forward(
                 use_kv_cache=False,
                 attention_mask=attention_mask,
                 peft_params=peft_params,
+                return_hidden_states=return_hidden_states,
             )
         # [batch_size, seq_len, vocab_size]
-        logits = self.lm_head(model_out["hidden_states"])
-        return logits
+        if return_logits:
+            model_out["logits"] = self.lm_head(model_out["hidden_states"])
+        return model_out
 
     def init_kv_cache(self, batch_size):
         # noinspection GrazieInspection

@@ -304,6 +318,7 @@ def forward(
         num_valid_tokens=None,
         attention_mask=None,
         peft_params: Optional[dict] = None,
+        return_hidden_states: Optional[list] = None,
     ):
         """
         :param input_ids: [batch_size, seq_len]

@@ -320,6 +335,7 @@ def forward(
             Only used for decoding
         :param attention_mask: [batch_size, num_heads, q_len, kv_len]
         :param peft_params
+        :param return_hidden_states
         """
         if dict_get(peft_params, "embeds") is not None:
             hidden_states = dict_get(peft_params, "embeds")

@@ -328,8 +344,16 @@ def forward(
 
         hidden_states = hidden_states.to(self.config.dtype)
 
+        if return_hidden_states:
+            max_layer = max(return_hidden_states)
+        else:
+            max_layer = self.config.n_layers
+
         new_kv_cache = []
+        hidden_states_dict = {}
         for layer_i, layer in enumerate(self.layers):
+            if layer_i > max_layer:
+                break
             if kv_cache:
                 # dict(
                 # key = [batch_size, num_heads, kv_seq_len=decode_step+1, head_dim]

@@ -363,11 +387,14 @@ def forward(
             )
 
             hidden_states = layer_out["hidden_states"]
+            if return_hidden_states and layer_i in return_hidden_states:
+                hidden_states_dict[layer_i] = hidden_states
             if kv_cache:
                 new_kv_cache.append(layer_out["kv_cache"])
         hidden_states = self.norm(hidden_states)
         output = {
-            "hidden_states": hidden_states
+            "hidden_states": hidden_states,
+            "layer_hidden_states": hidden_states_dict,
         }
         if kv_cache:
             output["kv_cache"] = new_kv_cache
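Putting the two files together, a hedged sketch of how the new return_hidden_states / return_logits arguments might feed SimpleGradMaker. Here `model`, `grad_maker`, `input_ids`, and `prefix_peft` are assumed to exist and are not defined in this commit, and the choice of which base-model layers to tap is only one plausible wiring, not the repository's training loop.

# Assumed, not from this commit: `model` is the patched LLaMA wrapper, `grad_maker` is a
# SimpleGradMaker, `input_ids` is [batch, seq_len], and `prefix_peft` is the per-layer list
# of {"scaler", "hidden_states"} dicts used both as prefix_h and as the grad maker's input.
tap_layers = list(range(grad_maker.num_peft_layers))   # base-model layers to read out

model_out = model(
    input_ids=input_ids,
    peft_params={"prefix_h": prefix_peft},
    return_hidden_states=tap_layers,   # collected into model_out["layer_hidden_states"]
    return_logits=False,               # skip lm_head; only hidden states are needed here
)
model_hidden_states = [model_out["layer_hidden_states"][i] for i in tap_layers]

# With return_diff=True the outputs are deltas against the input PEFT parameters.
new_peft_params = grad_maker(
    input_ids=input_ids,
    peft_params=prefix_peft,
    model_hidden_states=model_hidden_states,
)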
