
Commit 0bb417b

Author: sangchengmeng
Message: Merge branch 'main' into visual_only3
Parents: d9cb8c3 + 2d95b73

File tree: 51 files changed, +854 -247 lines (only a subset of the changed files is shown below)

.pre-commit-config.yaml

Lines changed: 2 additions & 3 deletions
@@ -7,8 +7,7 @@ repos:
         args: [--line-length=120]
         additional_dependencies: ['click==8.0.4']
   - repo: https://github.com/pycqa/flake8
-    rev: 3.9.0
+    rev: 6.1.0
     hooks:
       - id: flake8
-        additional_dependencies: [flake8-typing-imports==1.9.0]
-        args: ['--config=.flake8', '--max-line-length=120', '--ignore=TYP001, E722, C901, E203, E266, E402, E302, E241, E902, E731, F403, E701, F405, F401, W292, W293, W503, W606, E231']
+        args: ['--max-line-length=120', '--ignore=TYP001, E722, C901, E203, E266, E402, E302, E241, E902, E731, F403, E701, F405, F401, W292, W293, W503, W606, E231']

README.md

Lines changed: 3 additions & 2 deletions
@@ -47,14 +47,15 @@ Please refer to the [FAQ](https://lightllm-en.readthedocs.io/en/latest/faq.html)
 We welcome any coopoeration and contribution. If there is a project requires LightLLM's support, please contact us via email or create a pull request.
 
 Projects based on LightLLM or referenced LightLLM components:
-- [LazyLLM](https://github.com/LazyAGI/LazyLLM)
 - [LoongServe, Peking University](https://github.com/LoongServe/LoongServe)
-- [OmniKV, Ant Group](https://github.com/antgroup/OmniKV)
 - [vLLM](https://github.com/vllm-project/vllm) (some LightLLM's kernel used)
 - [SGLang](https://github.com/sgl-project/sglang) (some LightLLM's kernel used)
 - [ParrotServe](https://github.com/microsoft/ParrotServe), Microsoft
 - [Aphrodite](https://github.com/aphrodite-engine/aphrodite-engine) (some LightLLM's kernel used)
 - [S-LoRA](https://github.com/S-LoRA/S-LoRA)
+- [OmniKV, Ant Group](https://github.com/antgroup/OmniKV)
+- [Lab4AI LightLLM+LlamaIndex](https://www.lab4ai.cn/project/detail?utm_source=LLM1&id=b417085ae8cd4dd0bef7161c3d583b15&type=project), [Lab4AI LightLLM+Qwen3-8B](https://www.lab4ai.cn/project/detail?utm_source=lightllmcapp&id=c98ff5d09528423d8dd06f5a063cb2a6&type=project)
+- [LazyLLM](https://github.com/LazyAGI/LazyLLM)
 
 Also, LightLLM's pure-python design and token-level KC Cache management make it easy to use as the basis for research projects.
 

lightllm/common/basemodel/basemodel.py

Lines changed: 0 additions & 4 deletions
@@ -290,10 +290,6 @@ def _create_inferstate(self, model_input: ModelInput, microbatch_index: int = 0)
         infer_state.req_manager = self.req_manager
 
         infer_state.mem_index = model_input.mem_indexes
-        infer_state.kv_buffer_shapedtype = (
-            (model_input.input_ids.shape[0], self.tp_k_head_num_ + self.tp_v_head_num_, self.head_dim_),
-            self.data_type,
-        )
         infer_state.microbatch_index = microbatch_index
         infer_state.dist_group = dist_group_manager.get_group(microbatch_index)
 

lightllm/common/basemodel/infer_struct.py

Lines changed: 167 additions & 2 deletions
@@ -1,12 +1,16 @@
 import torch
+import triton
+import collections
 from lightllm.common.mem_manager import MemoryManager
 from lightllm.common.req_manager import ReqManager
 from lightllm.distributed import CustomProcessGroup
-from typing import Tuple, Any, Optional
+from typing import Tuple, Any, Optional, List
 from .triton_kernel.gen_prefill_params import gen_prefill_params
 from .triton_kernel.gen_decode_params import gen_decode_params
 from .triton_kernel.multimodal_emb import mark_multimodal_obj
 from .batch_objs import ModelInput
+from lightllm.utils.envs_utils import get_env_start_args
+from lightllm.utils.dist_utils import get_global_dp_rank
 
 
 class InferStateInfo:
@@ -36,7 +40,6 @@ def __init__(self):
         self.req_manager: ReqManager = None
 
         self.mem_index: torch.Tensor = None
-        self.kv_buffer_shapedtype: Tuple[Any, Any] = None
 
         self.is_token_healing: bool = False
         self.return_all_prompt_logics: bool = False
@@ -69,6 +72,18 @@ def __init__(self):
         # input will use it; other models and scenarios never use it
         self.deepseekv3_mtp_draft_input_hiddens: Optional[torch.Tensor] = None
 
+        # In single-node multi-dp mode, if the prefill workload becomes unbalanced across dp ranks,
+        # the inference data can be redistributed across the dp ranks: before the attention step it
+        # is all-to-all'ed to the assigned dp ranks, and after the computation it is all-to-all'ed
+        # back. This keeps the amount of data each dp handles roughly even and improves prefill
+        # efficiency. The variables below are only used in that scenario and stay unused otherwise.
+        self.need_dp_prefill_balance: bool = False
+        self.dp_origin_lens: List[int] = None
+        self.dp_handle_lens: List[int] = None
+        # self.dp_input_lens: torch.Tensor = None
+        self.dp_output_split_sizes: List[List[int]] = None
+        self.dp_input_split_sizes: List[List[int]] = None
+
     def init_some_extra_state(self, model, input_ids: torch.Tensor):
         if self.is_prefill:
             (
@@ -123,3 +138,153 @@ def mark_multimodal_objs_for_prefill(self, input_ids: torch.Tensor):
         for mark, obj in zip(marks_array, multi_objs):
             obj["_prefill_"] = mark > 0
         return
+
+    def prefill_dp_balance(self, input_ids: torch.Tensor):
+        """
+        During prefill, when running in dp mode, re-adjust and redistribute the input data to reduce
+        the prefill performance drop caused by heavily unbalanced per-dp workloads.
+        """
+        assert self.is_prefill
+        import torch.distributed as dist
+
+        self.need_dp_prefill_balance = True
+
+        args = get_env_start_args()
+
+        dp_input_lens = torch.empty(size=(args.dp,), device="cuda", dtype=torch.int32)
+        input_len = torch.empty(size=(1,), device="cuda", dtype=torch.int32)
+        input_len.fill_(len(input_ids))
+        dist.all_gather_into_tensor(
+            output_tensor=dp_input_lens,
+            input_tensor=input_len,
+            group=self.dist_group.dp_prefill_balance_group,
+            async_op=False,
+        )
+        dp_input_lens = dp_input_lens.detach().cpu()
+        self.dp_origin_lens = dp_input_lens.tolist()
+        sum_input_len = dp_input_lens.sum().item()
+        dp_handle_lens = [sum_input_len // args.dp for _ in range(args.dp)]
+        for i in range(sum_input_len % args.dp):
+            dp_handle_lens[i] += 1
+
+        self.dp_handle_lens = dp_handle_lens.copy()
+
+        dest_dp_inputs = [[] for _ in range(args.dp)]
+        # Record each dp's original input and how that input is assigned after rebalancing.
+        origin_datas = collections.deque()
+        for origin_dp_index, origin_dp_input_len in enumerate(dp_input_lens.numpy()):
+            handle_len = dp_handle_lens[origin_dp_index]
+            if origin_dp_input_len > handle_len:
+                origin_datas.append((origin_dp_index, handle_len, origin_dp_input_len))
+                dp_handle_lens[origin_dp_index] = 0
+                dest_dp_inputs[origin_dp_index].append((origin_dp_index, 0, handle_len))
+            else:
+                dp_handle_lens[origin_dp_index] -= origin_dp_input_len
+                dest_dp_inputs[origin_dp_index].append((origin_dp_index, 0, origin_dp_input_len))
+
+        for dest_dp_index in range(args.dp):
+            need_size = dp_handle_lens[dest_dp_index]
+            if need_size == 0:
+                continue
+            while len(origin_datas) != 0:
+                origin_data = origin_datas.popleft()
+                origin_dp_index, start, end = origin_data
+                if end - start > need_size:
+                    dest_dp_inputs[dest_dp_index].append((origin_dp_index, start, start + need_size))
+                    origin_datas.appendleft((origin_dp_index, start + need_size, end))
+                    break
+                else:
+                    dest_dp_inputs[dest_dp_index].append((origin_dp_index, start, end))
+                    need_size -= end - start
+                    if need_size == 0:
+                        break
+
+        dp_output_split_sizes = [[0 for _ in range(args.dp)] for _ in range(args.dp)]
+        for dest_dp_index, dest_dp_data in enumerate(dest_dp_inputs):
+            for origin_dp_index, start, end in dest_dp_data:
+                dp_output_split_sizes[dest_dp_index][origin_dp_index] += end - start
+        dp_input_split_sizes = [[0 for _ in range(args.dp)] for _ in range(args.dp)]
+        for dest_dp_index, dest_dp_data in enumerate(dest_dp_inputs):
+            for origin_dp_index, start, end in dest_dp_data:
+                dp_input_split_sizes[origin_dp_index][dest_dp_index] += end - start
+
+        self.dp_input_split_sizes = dp_input_split_sizes
+        self.dp_output_split_sizes = dp_output_split_sizes
+
+        new_input_ids = self._all_to_all_balance_get(input_ids)
+        if hasattr(self, "position_ids") and self.position_ids is not None:
+            # The special deepseekv2 mla models need to keep the original position_ids to reduce communication volume.
+            self._unbalance_position_ids = self.position_ids
+
+            self.position_ids = self._all_to_all_balance_get(self.position_ids)
+        if hasattr(self, "position_cos") and self.position_cos is not None:
+            # The special deepseekv2 mla models need to keep the original position_cos to reduce communication volume.
+            self._unbalance_position_cos = self.position_cos
+
+            self.position_cos = self._all_to_all_balance_get(self.position_cos)
+        if hasattr(self, "position_sin") and self.position_sin is not None:
+            # The special deepseekv2 mla models need to keep the original position_sin to reduce communication volume.
+            self._unbalance_position_sin = self.position_sin
+
+            self.position_sin = self._all_to_all_balance_get(self.position_sin)
+
+        return new_input_ids
+
+    def _all_to_all_balance_get(self, data: torch.Tensor):
+        dp_rank = get_global_dp_rank()
+        import torch.distributed as dist
+        from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager
+
+        old_shape = data.shape
+        data = data.view(-1)
+
+        origin_len = self.dp_origin_lens[dp_rank]
+        assert data.shape[0] % origin_len == 0
+        scale_size = data.shape[0] // origin_len
+        handle_len = self.dp_handle_lens[dp_rank]
+
+        dest_data = g_cache_manager.alloc_tensor(
+            shape=(handle_len * scale_size,),
+            data_type=data.dtype,
+            device="cuda",
+            is_graph_out=False,
+            microbatch_index=self.microbatch_index,
+        )
+        dist.all_to_all_single(
+            output=dest_data.view(-1),
+            input=data.view(-1),
+            output_split_sizes=[e * scale_size for e in self.dp_output_split_sizes[dp_rank]],
+            input_split_sizes=[e * scale_size for e in self.dp_input_split_sizes[dp_rank]],
+            group=self.dist_group.dp_prefill_balance_group,
+            async_op=False,
+        )
+        return dest_data.view(-1, *old_shape[1:])
+
+    def _all_to_all_unbalance_get(self, data: torch.Tensor):
+        dp_rank = get_global_dp_rank()
+        import torch.distributed as dist
+        from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager
+
+        old_shape = data.shape
+        data = data.view(-1)
+
+        handle_len = self.dp_handle_lens[dp_rank]
+        scale_size = data.shape[0] // handle_len
+        assert data.shape[0] % handle_len == 0
+        origin_len = self.dp_origin_lens[dp_rank]
+        origin_data = g_cache_manager.alloc_tensor(
+            shape=(origin_len * scale_size,),
+            data_type=data.dtype,
+            device="cuda",
+            is_graph_out=False,
+            microbatch_index=self.microbatch_index,
+        )
+        dist.all_to_all_single(
+            output=origin_data.view(-1),
+            input=data,
+            output_split_sizes=[e * scale_size for e in self.dp_input_split_sizes[dp_rank]],
+            input_split_sizes=[e * scale_size for e in self.dp_output_split_sizes[dp_rank]],
+            group=self.dist_group.dp_prefill_balance_group,
+            async_op=False,
+        )
+        return origin_data.view(-1, *old_shape[1:])
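
The rebalancing above reduces to a bookkeeping problem: gather every dp rank's prefill length, compute a near-equal target per rank, and derive the per-rank split-size matrices consumed by dist.all_to_all_single. Below is a minimal standalone sketch of that bookkeeping, not the committed code: the function name and example lengths are made up, the distributed calls are omitted, and the assignment loop is written slightly differently while producing the same kind of split matrices.

import collections

def compute_balance_splits(dp_input_lens):
    """Given per-dp prefill token counts, return the balanced per-dp target
    lengths plus the output/input split-size matrices for all_to_all_single."""
    dp = len(dp_input_lens)
    total = sum(dp_input_lens)
    # Near-equal target number of tokens for every dp rank.
    handle_lens = [total // dp] * dp
    for i in range(total % dp):
        handle_lens[i] += 1

    # Each rank first keeps as much of its own data as it is allowed to handle;
    # any surplus is queued as (src_rank, start, end) slices for other ranks.
    dest_inputs = [[] for _ in range(dp)]
    remaining = handle_lens.copy()
    surplus = collections.deque()
    for rank, in_len in enumerate(dp_input_lens):
        keep = min(in_len, remaining[rank])
        dest_inputs[rank].append((rank, 0, keep))
        remaining[rank] -= keep
        if in_len > keep:
            surplus.append((rank, keep, in_len))

    # Under-loaded ranks are filled from the surplus queue.
    for rank in range(dp):
        need = remaining[rank]
        while need > 0 and surplus:
            src, start, end = surplus.popleft()
            take = min(need, end - start)
            dest_inputs[rank].append((src, start, start + take))
            if start + take < end:
                surplus.appendleft((src, start + take, end))
            need -= take

    # Per-rank split sizes, row r being the list handed to rank r.
    output_splits = [[0] * dp for _ in range(dp)]
    input_splits = [[0] * dp for _ in range(dp)]
    for dest, pieces in enumerate(dest_inputs):
        for src, start, end in pieces:
            output_splits[dest][src] += end - start
            input_splits[src][dest] += end - start
    return handle_lens, output_splits, input_splits

# Four dp ranks with unbalanced prefill lengths: every rank ends up handling 5 tokens.
print(compute_balance_splits([10, 2, 7, 1]))

_all_to_all_balance_get then feeds row dp_rank of the two matrices to all_to_all_single, and _all_to_all_unbalance_get swaps the two rows to route the results back to their original ranks.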

lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_cohere_template.py

Lines changed: 3 additions & 4 deletions
@@ -44,13 +44,12 @@ def _bind_rotary_emb_fwd(self):
     def _get_qkv(
         self, input, infer_state: InferStateInfo, layer_weight
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        cache_kv = self._pre_cache_kv(infer_state=infer_state, layer_weight=layer_weight)
         q = torch.mm(input.view(-1, self.embed_dim_), layer_weight.q_weight_)
-        torch.mm(
+        cache_kv = torch.mm(
             input.view(-1, self.embed_dim_),
             layer_weight.kv_weight_,
-            out=cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_),
-        )
+        ).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_)
+
         if self.use_qk_norm_:
             q = q.view(-1, self.tp_q_head_num_, self.head_dim_)
             k = cache_kv[:, 0 : self.tp_k_head_num_, :]
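
The change above drops the preallocated cache_kv buffer (previously obtained from _pre_cache_kv) and instead reshapes the matmul result directly. A toy check of that equivalence, outside any LightLLM class and with made-up shapes:

import torch

tokens, embed_dim, k_heads, v_heads, head_dim = 4, 16, 2, 2, 8
x = torch.randn(tokens, embed_dim)
kv_weight = torch.randn(embed_dim, (k_heads + v_heads) * head_dim)

# Old style: preallocate the kv buffer, then write the matmul result into it.
buf = torch.empty(tokens, k_heads + v_heads, head_dim)
torch.mm(x, kv_weight, out=buf.view(-1, (k_heads + v_heads) * head_dim))

# New style: take the matmul result and just reinterpret its shape.
cache_kv = torch.mm(x, kv_weight).view(-1, k_heads + v_heads, head_dim)

assert torch.allclose(buf, cache_kv)

The values are identical; only the ownership of the output allocation changes, which is what lets this commit delete _pre_cache_kv and kv_buffer_shapedtype elsewhere.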

lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py

Lines changed: 0 additions & 10 deletions
@@ -30,16 +30,6 @@ def _att_norm(self, input, infer_state: InferStateInfo, layer_weight) -> torch.T
     def _ffn_norm(self, input, infer_state: InferStateInfo, layer_weight) -> torch.Tensor:
         raise Exception("need to impl")
 
-    def _pre_cache_kv(self, infer_state: InferStateInfo, layer_weight) -> torch.Tensor:
-        cache_kv = self.alloc_tensor(
-            shape=infer_state.kv_buffer_shapedtype[0],
-            dtype=infer_state.kv_buffer_shapedtype[1],
-            device="cuda",
-            is_graph_out=False,
-            microbatch_index=infer_state.microbatch_index,
-        )
-        return cache_kv
-
     def _get_qkv(self, input, infer_state: InferStateInfo, layer_weight) -> Tuple[torch.Tensor, torch.Tensor]:
         raise Exception("need to impl")
 

lightllm/common/fused_moe/grouped_fused_moe.py

Lines changed: 3 additions & 3 deletions
@@ -1000,13 +1000,13 @@ def outplace_fused_experts_impl_fake(
     hidden_states: torch.Tensor,
     w1: torch.Tensor,
     w2: torch.Tensor,
-    # optional bias for w1 and w2
-    w1_bias: Optional[torch.Tensor],
-    w2_bias: Optional[torch.Tensor],
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
     use_fp8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
+    # optional bias for w1 and w2
+    w1_bias: Optional[torch.Tensor] = None,
+    w2_bias: Optional[torch.Tensor] = None,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
     a1_scale: Optional[torch.Tensor] = None,
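
For context, a toy sketch (not the real LightLLM signature) of what the reorder buys: once w1_bias / w2_bias default to None and sit after the other arguments, callers that never pass biases keep working with the shorter positional call, and bias users pass them by keyword.

from typing import Optional

def fused_experts_fake(hidden, w1, w2, topk_weights, topk_ids,
                       use_fp8_w8a8: bool = False,
                       use_int8_w8a16: bool = False,
                       w1_bias: Optional[object] = None,
                       w2_bias: Optional[object] = None):
    # The fake/meta implementation only has to mirror the signature, not compute.
    return w1_bias is None and w2_bias is None

# No biases supplied: the call still binds correctly and both biases default to None.
assert fused_experts_fake("h", "w1", "w2", "tw", "ti")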

lightllm/distributed/communication_op.py

Lines changed: 6 additions & 0 deletions
@@ -36,6 +36,7 @@
     get_global_rank,
     get_current_rank_in_dp,
     create_new_group_for_current_dp,
+    create_dp_special_inter_group,
 )
 from lightllm.utils.device_utils import get_device_sm_count
 from lightllm.utils.sgl_utils import HAS_SGL_KERNEL
@@ -62,6 +63,11 @@ def __init__(self):
         self.custom_gather = None
         self.dp_world_size = get_dp_world_size()
         self.device_group = create_new_group_for_current_dp("nccl")
+        if get_env_start_args().enable_dp_prefill_balance:
+            self.dp_prefill_balance_group = create_dp_special_inter_group("nccl")
+        else:
+            self.dp_prefill_balance_group = None
+
         self.autotune_group = dist.new_group([i for i in range(get_global_world_size())], backend="gloo")
 
     def init_custom_reduce(self) -> None:
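
Reduced to a hedged standalone sketch (the helper below and its rank list are made up; create_dp_special_inter_group is the real helper imported above), the pattern is: build the extra NCCL group only when the feature flag is on, and store None otherwise so callers can simply test for its presence.

import torch.distributed as dist

def build_prefill_balance_group(enable_dp_prefill_balance: bool, ranks):
    # Skip the extra group entirely when dp prefill balancing is disabled.
    if not enable_dp_prefill_balance:
        return None
    # dist.new_group creates a subgroup over the given global ranks.
    return dist.new_group(ranks=ranks, backend="nccl")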

lightllm/models/bloom/layer_infer/transformer_layer_infer.py

Lines changed: 1 addition & 4 deletions
@@ -47,10 +47,7 @@ def _get_qkv(
         self, input, infer_state: InferStateInfo, layer_weight: BloomTransformerLayerWeight
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         q = layer_weight.q_proj.mm(input.view(-1, self.embed_dim_))
-        cache_kv = self._pre_cache_kv(infer_state=infer_state, layer_weight=layer_weight)
-        cache_kv = layer_weight.kv_proj.mm(
-            input, out=cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_)
-        ).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_)
+        cache_kv = layer_weight.kv_proj.mm(input).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_)
         return q, cache_kv
 
     def _context_attention_kernel(
