indexcache_vllm.patch
diff --git a/vllm/model_executor/layers/mla.py b/vllm/model_executor/layers/mla.py
index 1d3e987b7..20a31d47b 100644
--- a/vllm/model_executor/layers/mla.py
+++ b/vllm/model_executor/layers/mla.py
@@ -87,6 +87,8 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer):
         self.indexer_rope_emb = mla_modules.indexer_rotary_emb
         self.is_sparse = mla_modules.is_sparse
 
+        self.skip_topk = False
+        self.next_skip_topk = False
         if self.indexer is not None:
             assert hasattr(self.indexer, "topk_tokens")
             self.topk_tokens = self.indexer.topk_tokens
@@ -115,7 +117,8 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer):
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         llama_4_scaling: torch.Tensor | None = None,
-    ) -> torch.Tensor:
+        prev_topk_indices: torch.Tensor | None = None,
+    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor | None]:
         q_c = None
         kv_lora = None
 
@@ -160,9 +163,12 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer):
         )
 
         if self.indexer and self.is_sparse:
-            _topk_indices = self.indexer(
-                hidden_states, q_c, positions, self.indexer_rope_emb
-            )
+            if not self.skip_topk:
+                _topk_indices = self.indexer(
+                    hidden_states, q_c, positions, self.indexer_rope_emb
+                )
+            else:
+                _topk_indices = prev_topk_indices
 
         if llama_4_scaling is not None:
             q *= llama_4_scaling
@@ -174,4 +180,10 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer):
             output_shape=(hidden_states.shape[0], self.num_heads * self.v_head_dim),
         )
 
-        return self.o_proj(attn_out)[0]
+        output = self.o_proj(attn_out)[0]
+        if self.indexer and self.is_sparse:
+            if not self.next_skip_topk:
+                return output, None
+            else:
+                return output, _topk_indices
+        return output
diff --git a/vllm/model_executor/models/deepseek_mtp.py b/vllm/model_executor/models/deepseek_mtp.py
index c75ee1a1b..daaaf0c16 100644
--- a/vllm/model_executor/models/deepseek_mtp.py
+++ b/vllm/model_executor/models/deepseek_mtp.py
@@ -111,7 +111,7 @@ class DeepSeekMultiTokenPredictorLayer(nn.Module):
             torch.cat([inputs_embeds, previous_hidden_states], dim=-1)
         )
 
-        hidden_states, residual = self.mtp_block(
+        hidden_states, residual, _topk_indices = self.mtp_block(
            positions=positions, hidden_states=hidden_states, residual=None
        )
         hidden_states = residual + hidden_states
diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py
index a198f1a0b..c89ea0988 100644
--- a/vllm/model_executor/models/deepseek_v2.py
+++ b/vllm/model_executor/models/deepseek_v2.py
@@ -835,6 +835,7 @@ class DeepseekV2MLAAttention(nn.Module):
         prefix: str = "",
         topk_indices_buffer: torch.Tensor | None = None,
         input_size: int | None = None,
+        is_nextn: bool = False,
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
@@ -953,6 +954,26 @@ class DeepseekV2MLAAttention(nn.Module):
             self.indexer_rope_emb = None
             self.indexer = None
 
+        # IndexCache config
+        if is_nextn or not self.is_v32:
+            _skip_topk = False
+            _next_skip_topk = False
+        else:
+            _index_topk_freq = getattr(config, "index_topk_freq", 1)
+            _index_topk_pattern = getattr(config, "index_topk_pattern", None)
+            layer_id = int(prefix.split(".")[-1]) if "." in prefix else 0
+            if _index_topk_pattern is None:
+                _skip_topk = (max(layer_id - 1, 0) % _index_topk_freq != 0)
+                _next_skip_topk = (layer_id % _index_topk_freq != 0)
+            else:
+                _skip_topk = _index_topk_pattern[layer_id] == 'S'
+                if layer_id < len(_index_topk_pattern) - 1:
+                    _next_skip_topk = _index_topk_pattern[layer_id + 1] == 'S'
+                else:
+                    _next_skip_topk = False
+            print('layer_id {} DSA skip_topk {} next_skip_topk {} is_nextn {}'.format(
+                layer_id, _skip_topk, _next_skip_topk, is_nextn))
+
         mla_modules = MLAModules(
             kv_a_layernorm=self.kv_a_layernorm,
             kv_b_proj=self.kv_b_proj,
@@ -987,14 +1008,17 @@ class DeepseekV2MLAAttention(nn.Module):
             quant_config,
             prefix,
         )
+        self.mla_attn.skip_topk = _skip_topk
+        self.mla_attn.next_skip_topk = _next_skip_topk
 
     def forward(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         llama_4_scaling: torch.Tensor | None,
-    ) -> torch.Tensor:
-        return self.mla_attn(positions, hidden_states, llama_4_scaling)
+        prev_topk_indices: torch.Tensor | None = None,
+    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor | None]:
+        return self.mla_attn(positions, hidden_states, llama_4_scaling, prev_topk_indices=prev_topk_indices)
 
 
 class DeepseekV2DecoderLayer(nn.Module):
@@ -1087,6 +1111,7 @@ class DeepseekV2DecoderLayer(nn.Module):
         hidden_states: torch.Tensor,
         residual: torch.Tensor | None,
         llama_4_scaling: torch.Tensor | None = None,
+        prev_topk_indices: torch.Tensor | None = None,
     ) -> torch.Tensor:
         # Self Attention
         if residual is None:
@@ -1101,7 +1126,12 @@ class DeepseekV2DecoderLayer(nn.Module):
         }
         if not self.use_mha:
             attn_kwargs["llama_4_scaling"] = llama_4_scaling
+            attn_kwargs["prev_topk_indices"] = prev_topk_indices
         hidden_states = self.self_attn(**attn_kwargs)
+        if isinstance(hidden_states, tuple):
+            hidden_states, topk_indices = hidden_states
+        else:
+            topk_indices = None
 
         if (
             not isinstance(self.self_attn, DeepseekAttention)
@@ -1128,7 +1158,7 @@ class DeepseekV2DecoderLayer(nn.Module):
             # of DeepseekV2MOE
             hidden_states *= 1.0 / self.routed_scaling_factor
 
-        return hidden_states, residual
+        return hidden_states, residual, topk_indices
 
 
 @support_torch_compile
@@ -1219,14 +1249,16 @@ class DeepseekV2Model(nn.Module):
             llama_4_scaling = None
 
         aux_hidden_states = []
+        topk_indices = None
         for idx, layer in enumerate(
             islice(self.layers, self.start_layer, self.end_layer),
             start=self.start_layer,
         ):
             if idx in self.aux_hidden_state_layers:
                 aux_hidden_states.append(hidden_states + residual)
-            hidden_states, residual = layer(
-                positions, hidden_states, residual, llama_4_scaling
+            hidden_states, residual, topk_indices = layer(
+                positions, hidden_states, residual, llama_4_scaling,
+                prev_topk_indices=topk_indices,
             )
 
         if not get_pp_group().is_last_rank:
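
The code below is not part of the patch. It is a minimal standalone sketch that reproduces the skip/reuse schedule the "IndexCache config" hunk above derives from `index_topk_freq` or `index_topk_pattern`; the helper name `topk_schedule`, the 8-layer depth, and the example pattern string are illustrative assumptions, not values taken from the patch.

# Standalone sketch (assumptions noted above): mirrors the per-layer
# skip_topk / next_skip_topk logic the patch adds to DeepseekV2MLAAttention.__init__.
def topk_schedule(num_layers: int, freq: int = 1, pattern: str | None = None):
    schedule = []
    for layer_id in range(num_layers):
        if pattern is None:
            # Frequency mode: a layer recomputes top-k indices only when the
            # previous layer id is a multiple of `freq`; otherwise it reuses them.
            skip = max(layer_id - 1, 0) % freq != 0
            next_skip = layer_id % freq != 0
        else:
            # Pattern mode: 'S' marks a layer that skips its indexer and reuses
            # the previous layer's top-k indices.
            skip = pattern[layer_id] == "S"
            next_skip = (
                pattern[layer_id + 1] == "S" if layer_id < len(pattern) - 1 else False
            )
        schedule.append((skip, next_skip))
    return schedule

if __name__ == "__main__":
    for layer_id, (skip, next_skip) in enumerate(topk_schedule(8, freq=2)):
        print(f"layer {layer_id}: skip_topk={skip} next_skip_topk={next_skip}")
    print(topk_schedule(8, pattern="CSCSCSCS"))

With freq=2 this prints a schedule where layers 0 and 1 run the indexer and later layers alternate between recomputing and reusing cached indices, which is the behavior the patch threads through DeepseekV2Model.forward via prev_topk_indices.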