
Commit 723c585

Levi-JQ authored and zzhx1 committed
[Feat] flashcomm_v2 optim solution (vllm-project#3232)
### What this PR does / why we need it?

Supports generalized FlashComm2 optimization, which reduces communication overhead, decreases RmsNorm computation, and saves one AllGather step by replacing AllReduce operations in the Attention module with pre-AlltoAll and post-AllGather operations (used in combination with FlashComm1). The feature is enabled during the prefill phase and is recommended to be used together with FlashComm1; it delivers broad performance improvements, especially in long-sequence scenarios with large tensor parallelism (TP) configurations. Benchmark tests show that under a TP16DP1 configuration it improves DeepSeek prefill performance by 8% on top of FlashComm1.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.11.0
- vLLM main: vllm-project/vllm@83f478b

---------

Signed-off-by: zzhxx <2783294813@qq.com>
Signed-off-by: Levi-JQ <yujinqi2@huawei.com>
Co-authored-by: Levi-JQ <yujinqi2@huawei.com>
Co-authored-by: zzhxx <2783294813@qq.com>
Signed-off-by: nsdie <yeyifan@huawei.com>
1 parent 5dca8ba commit 723c585
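For orientation, the sketch below illustrates the communication-pattern change described in the PR message, using plain torch.distributed collectives. It is a conceptual sketch only, not code from this commit: it shows the simplified case where every rank holds the full O-projection weight (an oproj TP group size of 1), and the function names are illustrative.

```python
# Conceptual sketch of the FlashComm2 pattern (not code from this commit).
# Baseline TP attention sums partial o_proj outputs with an AllReduce;
# FlashComm2 instead re-shards tokens with an AlltoAll, runs o_proj on a
# local token slice, and restores the token dimension with an AllGather.
# Assumes torch.distributed is initialized and num_tokens is already padded
# to a multiple of tp_size.
import torch
import torch.distributed as dist


def o_proj_baseline(partial_out: torch.Tensor) -> torch.Tensor:
    # Every rank holds a partial o_proj output for all tokens; AllReduce sums.
    dist.all_reduce(partial_out)
    return partial_out


def o_proj_flashcomm2(attn_out_shard: torch.Tensor, o_proj: torch.nn.Linear,
                      tp_size: int) -> torch.Tensor:
    # attn_out_shard: [num_tokens, hidden // tp_size], the head-sharded
    # attention output on this rank.
    num_tokens = attn_out_shard.shape[0]
    assert num_tokens % tp_size == 0, "pad num_tokens to tp_size first"

    # Pre-AlltoAll: afterwards this rank owns all head shards for its own
    # num_tokens // tp_size token slice.
    redistributed = torch.empty_like(attn_out_shard)
    dist.all_to_all_single(redistributed, attn_out_shard)
    token_slice = (redistributed.reshape(tp_size, num_tokens // tp_size,
                                         -1).transpose(0, 1).reshape(
                                             num_tokens // tp_size, -1))

    # o_proj runs on the small token slice instead of the full batch.
    local_out = o_proj(token_slice)

    # Post-AllGather: restore the full (padded) token dimension.
    gathered = [torch.empty_like(local_out) for _ in range(tp_size)]
    dist.all_gather(gathered, local_out)
    return torch.cat(gathered, dim=0)
```

The diffs below gate this path accordingly: it is enabled only for prefill with tp_world_size > 1, and the token dimension is padded to a multiple of the TP size before the collectives run.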

File tree

12 files changed: +380 −24 lines

.github/workflows/_e2e_test.yaml

Lines changed: 1 addition & 0 deletions
@@ -195,6 +195,7 @@ jobs:
   pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_Qwen3_W4A8DYNAMIC_old_version
   pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_W4A8DYNAMIC
   pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_sp_for_qwen3_moe
+  pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_fc2_for_qwen3_moe
   pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_Qwen_Dense_with_flashcomm_v1
   pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_Qwen_Dense_with_prefetch_mlp_weight

tests/e2e/multicard/test_offline_inference_distributed.py

Lines changed: 20 additions & 0 deletions
@@ -189,6 +189,26 @@ def test_sp_for_qwen3_moe() -> None:
         vllm_model.generate(example_prompts, sampling_params)


+@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_FLASHCOMM1": "1"})
+@patch.dict(os.environ, {"VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE": "1"})
+def test_fc2_for_qwen3_moe() -> None:
+    example_prompts = [
+        "Hello, my name is",
+    ]
+    sampling_params = SamplingParams(max_tokens=5,
+                                     temperature=0.0,
+                                     top_k=50,
+                                     top_p=0.9)
+
+    with VllmRunner(snapshot_download("Qwen/Qwen3-30B-A3B"),
+                    dtype="auto",
+                    tensor_parallel_size=2,
+                    distributed_executor_backend="mp",
+                    enable_expert_parallel=True,
+                    enforce_eager=True) as vllm_model:
+        vllm_model.generate(example_prompts, sampling_params)
+
+
 @patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_FLASHCOMM1": "1"})
 def test_models_distributed_deepseek_v2_lite_with_flashcomm_v1() -> None:
     example_prompts = [

tests/e2e/singlecard/test_aclgraph_mem.py

Lines changed: 1 addition & 0 deletions
@@ -34,6 +34,7 @@
                     reason="aclgraph only support on v1")
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("max_tokens", [4])
+@patch.dict(os.environ, {"VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE": "0"})
 @patch.dict(os.environ, {"ASCEND_RT_VISIBLE_DEVICES": "0,1"})
 def test_aclgraph_mem_use(model: str, max_tokens: int) -> None:
     del os.environ["VLLM_WORKER_MULTIPROC_METHOD"]

tests/ut/distributed/test_parallel_state.py

Lines changed: 22 additions & 5 deletions
@@ -4,9 +4,10 @@
 from vllm.config import ParallelConfig

 from vllm_ascend.distributed.parallel_state import (
-    _LMTP, _MC2, _OTP, _P_TP, destroy_ascend_model_parallel,
-    get_lmhead_tp_group, get_mc2_group, get_otp_group, get_p_tp_group,
-    init_ascend_model_parallel)
+    _FLASHCOMM2_ODP, _FLASHCOMM2_OTP, _LMTP, _MC2, _OTP, _P_TP,
+    destroy_ascend_model_parallel, get_flashcomm2_odp_group,
+    get_flashcomm2_otp_group, get_lmhead_tp_group, get_mc2_group,
+    get_otp_group, get_p_tp_group, init_ascend_model_parallel)


 @pytest.fixture
@@ -21,38 +22,54 @@ def mock_distributed():
     with patch('torch.distributed.is_initialized', return_value=True), \
         patch('torch.distributed.get_world_size', return_value=8), \
         patch('torch.distributed.get_backend', return_value='nccl'), \
-        patch('vllm_ascend.distributed.parallel_state.get_world_group') as mock_group:
+        patch('vllm_ascend.distributed.parallel_state.get_world_group') as mock_group, \
+        patch('vllm_ascend.distributed.parallel_state.get_tp_group') as mock_tp_group, \
+        patch('vllm_ascend.distributed.parallel_state.get_dp_group') as mock_dp_group:
         mock_group.return_value.local_rank = 0
         mock_group.return_value.device_group = MagicMock()
+        mock_tp_group.return_value.world_size = 4
+        mock_dp_group.return_value.world_size = 2
         yield


 def test_init_ascend_model_parallel(mock_distributed, parallel_config):
     mock_ascend_config = MagicMock()
     mock_ascend_config.lmhead_tensor_parallel_size = 2
     mock_ascend_config.oproj_tensor_parallel_size = 2
+    mock_ascend_config.flashcomm2_oproj_tensor_parallel_size = 2
     mock_ascend_config.pd_tp_ratio = 2
     mock_ascend_config.num_head_replica = 0
     mock_ascend_config.pd_head_ratio = 2
     mock_vllm_config = MagicMock()
     mock_vllm_config.kv_transfer_config.is_kv_producer = True
+    mock_envs_ascend = MagicMock()
+    mock_envs_ascend.VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE = 2
+    mock_envs_ascend.VLLM_ASCEND_ENABLE_CONTEXT_PARALLEL = 0
     with patch('vllm_ascend.distributed.parallel_state.model_parallel_initialized', return_value=False), \
         patch('vllm_ascend.distributed.parallel_state.init_model_parallel_group'), \
         patch('vllm_ascend.distributed.parallel_state.get_current_vllm_config', return_value=mock_vllm_config), \
-        patch('vllm_ascend.distributed.parallel_state.get_ascend_config', return_value=mock_ascend_config):
+        patch('vllm_ascend.distributed.parallel_state.get_ascend_config', return_value=mock_ascend_config), \
+        patch('vllm_ascend.utils.envs_ascend', new=mock_envs_ascend), \
+        patch('vllm_ascend.utils.get_ascend_config', return_value=mock_ascend_config):
         init_ascend_model_parallel(parallel_config)

         mc2_group = get_mc2_group()
         lmheadtp_group = get_lmhead_tp_group()
         otp_group = get_otp_group()
+        flashcomm2_otp_group = get_flashcomm2_otp_group()
+        flashcomm2_odp_group = get_flashcomm2_odp_group()
         p_tp_group = get_p_tp_group()
         assert mc2_group is not None
         assert otp_group is not None
+        assert flashcomm2_otp_group is not None
+        assert flashcomm2_odp_group is not None
         assert lmheadtp_group is not None
         assert p_tp_group is not None

         destroy_ascend_model_parallel()
         assert _MC2 is None
         assert _LMTP is None
         assert _OTP is None
+        assert _FLASHCOMM2_OTP is None
+        assert _FLASHCOMM2_ODP is None
         assert _P_TP is None

vllm_ascend/ascend_config.py

Lines changed: 4 additions & 0 deletions
@@ -130,6 +130,10 @@ def __init__(self, vllm_config):
                 "Only support P node tp size lagger then D node tp size")
         self.SLO_limits_for_dynamic_batch = additional_config.get(
             "SLO_limits_for_dynamic_batch", -1)
+        from vllm_ascend.utils import \
+            get_flashcomm2_oproj_tp_size_and_validate_config
+        self.flashcomm2_oproj_tensor_parallel_size = get_flashcomm2_oproj_tp_size_and_validate_config(
+            self, vllm_config)


 class TorchairGraphConfig:

vllm_ascend/ascend_forward_context.py

Lines changed: 11 additions & 5 deletions
@@ -11,7 +11,8 @@
                                 set_forward_context)

 import vllm_ascend.envs as envs_ascend
-from vllm_ascend.utils import enable_sp, has_layer_idx, is_moe_model
+from vllm_ascend.utils import (enable_sp, flashcomm2_enable, has_layer_idx,
+                               is_moe_model)

 if TYPE_CHECKING:
     from vllm_ascend.ops.weight_prefetch import WeightPrefetchMethod
@@ -121,13 +122,17 @@ def set_ascend_forward_context(
         sp_enabled = enable_sp(vllm_config) and \
             num_tokens is not None and num_tokens > 1000
         forward_context.mmrs_fusion = mmrs_fusion
+        forward_context.num_tokens = num_tokens
+        forward_context.sp_enabled = sp_enabled
+        #TODO(Levi-JQ): another PR to normalize the enabling logic for sp/fc2
+        forward_context.flashcomm_v2_enabled = flashcomm2_enable(
+        ) and tp_world_size > 1 and num_tokens is not None

-        if sp_enabled:
+        if (forward_context.sp_enabled
+                or forward_context.flashcomm_v2_enabled):
             pad_size = (tp_world_size -
                         (num_tokens % tp_world_size)) % tp_world_size
             forward_context.pad_size = pad_size
-        forward_context.sp_enabled = sp_enabled
-        forward_context.num_tokens = num_tokens

         # set this for rope forward_oot using
         forward_context.is_first_layer = True
@@ -179,7 +184,8 @@ def set_ascend_forward_context(
         if dp_world_size > 1 and forward_context.dp_metadata is not None:
             max_tokens_across_dp = \
                 forward_context.dp_metadata.max_tokens_across_dp_cpu.item()
-            if sp_enabled:
+            if (forward_context.sp_enabled
+                    or forward_context.flashcomm_v2_enabled):
                 padded_length = (max_tokens_across_dp + tp_world_size -
                                  1) // tp_world_size * tp_world_size
                 pad_size = padded_length - num_tokens
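As a quick sanity check on the padding arithmetic in this file, the snippet below works through the two formulas with illustrative numbers (not values taken from the commit):

```python
# Tokens are rounded up to a multiple of tp_world_size so the token dimension
# splits evenly across TP ranks for the AlltoAll/AllGather path.
tp_world_size = 16
num_tokens = 1000
pad_size = (tp_world_size - (num_tokens % tp_world_size)) % tp_world_size
assert pad_size == 8 and (num_tokens + pad_size) % tp_world_size == 0

# With data parallelism, every rank pads to the DP-wide maximum so collective
# shapes stay aligned across DP groups.
max_tokens_across_dp = 1187
padded_length = (max_tokens_across_dp + tp_world_size -
                 1) // tp_world_size * tp_world_size
assert padded_length == 1200 and padded_length - num_tokens == 200
```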

vllm_ascend/distributed/parallel_state.py

Lines changed: 70 additions & 2 deletions
@@ -2,19 +2,23 @@

 import torch
 from vllm.config import ParallelConfig, get_current_vllm_config
-from vllm.distributed.parallel_state import (GroupCoordinator, get_world_group,
+from vllm.distributed.parallel_state import (GroupCoordinator, get_dp_group,
+                                             get_tp_group, get_world_group,
                                              init_model_parallel_group)

 import vllm_ascend.envs as envs_ascend
 from vllm_ascend.ascend_config import get_ascend_config
-from vllm_ascend.utils import prefill_context_parallel_enable
+from vllm_ascend.utils import (flashcomm2_enable,
+                               prefill_context_parallel_enable)

 # Currently, mc2 op need their own group coordinator.
 _MC2: Optional[GroupCoordinator] = None
 _MLP_TP: Optional[GroupCoordinator] = None
 _OTP: Optional[GroupCoordinator] = None
 _LMTP: Optional[GroupCoordinator] = None
 _P_TP: Optional[GroupCoordinator] = None
+_FLASHCOMM2_OTP: Optional[GroupCoordinator] = None
+_FLASHCOMM2_ODP: Optional[GroupCoordinator] = None


 def get_mc2_group() -> GroupCoordinator:
@@ -34,6 +38,16 @@ def get_lmhead_tp_group() -> GroupCoordinator:
     return _LMTP


+def get_flashcomm2_otp_group() -> GroupCoordinator:
+    return _FLASHCOMM2_OTP
+
+
+def get_flashcomm2_odp_group() -> GroupCoordinator:
+    assert _FLASHCOMM2_ODP is not None, (
+        "output data parallel group for flashcomm2 is not initialized")
+    return _FLASHCOMM2_ODP
+
+
 def get_mlp_tp_group() -> GroupCoordinator:
     assert _MLP_TP is not None, ("mlp group is not initialized")
     return _MLP_TP
@@ -165,6 +179,48 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
             backend,
             group_name="lmheadtp")

+    # TODO: Extract and unify the logic across different communication group.
+    if flashcomm2_enable():
+        flashcomm2_otp_size = get_ascend_config(
+        ).flashcomm2_oproj_tensor_parallel_size
+        global_tp_size = get_tp_group().world_size
+        global_dp_size = get_dp_group().world_size
+        num_fc2_oproj_tensor_parallel_groups: int = (global_tp_size //
+                                                     flashcomm2_otp_size)
+
+        global _FLASHCOMM2_OTP
+        global _FLASHCOMM2_ODP
+
+        _FLASHCOMM2_OTP = None
+        _FLASHCOMM2_ODP = get_tp_group()
+
+        if flashcomm2_otp_size > 1:
+            otp_group_ranks = []
+            odp_group_ranks: list[list[int]] = [
+                [] for _ in range(flashcomm2_otp_size * global_dp_size)
+            ]
+
+            for dp_group_index in range(global_dp_size):
+                for i in range(num_fc2_oproj_tensor_parallel_groups):
+                    ranks = []
+                    for j in range(flashcomm2_otp_size):
+                        rank_idx = dp_group_index * global_tp_size + i + j * num_fc2_oproj_tensor_parallel_groups
+                        ranks.append(rank_idx)
+                        odp_group_index = dp_group_index * flashcomm2_otp_size + j
+                        odp_group_ranks[odp_group_index].append(rank_idx)
+                    otp_group_ranks.append(ranks)
+
+            _FLASHCOMM2_OTP = init_model_parallel_group(
+                otp_group_ranks,
+                get_world_group().local_rank,
+                backend,
+                group_name="flashcomm2_otp")
+            _FLASHCOMM2_ODP = init_model_parallel_group(
+                odp_group_ranks,
+                get_world_group().local_rank,
+                backend,
+                group_name="flashcomm2_odp")
+

 def get_mlp_tensor_model_parallel_world_size():
     """Return world size for the tensor model parallel group."""
@@ -201,3 +257,15 @@ def destroy_ascend_model_parallel():
     if _P_TP:
         _P_TP.destroy()
     _P_TP = None
+
+    global _FLASHCOMM2_OTP
+    if _FLASHCOMM2_OTP and get_ascend_config(
+    ).flashcomm2_oproj_tensor_parallel_size != 1:
+        _FLASHCOMM2_OTP.destroy()
+    _FLASHCOMM2_OTP = None
+
+    global _FLASHCOMM2_ODP
+    if _FLASHCOMM2_ODP and get_ascend_config(
+    ).flashcomm2_oproj_tensor_parallel_size != 1:
+        _FLASHCOMM2_ODP.destroy()
+    _FLASHCOMM2_ODP = None
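To make the group construction above concrete, here is a standalone walk-through of the rank enumeration with illustrative sizes (TP=4, DP=2, flashcomm2_oproj_tensor_parallel_size=2); the sizes are examples, not values required by the commit.

```python
# Reproduces only the rank enumeration added in init_ascend_model_parallel,
# with small illustrative sizes, and checks the resulting group layout.
global_tp_size, global_dp_size, flashcomm2_otp_size = 4, 2, 2
num_fc2_oproj_tensor_parallel_groups = global_tp_size // flashcomm2_otp_size

otp_group_ranks = []
odp_group_ranks = [[] for _ in range(flashcomm2_otp_size * global_dp_size)]
for dp_group_index in range(global_dp_size):
    for i in range(num_fc2_oproj_tensor_parallel_groups):
        ranks = []
        for j in range(flashcomm2_otp_size):
            rank_idx = (dp_group_index * global_tp_size + i +
                        j * num_fc2_oproj_tensor_parallel_groups)
            ranks.append(rank_idx)
            odp_group_ranks[dp_group_index * flashcomm2_otp_size +
                            j].append(rank_idx)
        otp_group_ranks.append(ranks)

# Each o-proj TP group strides across a TP group, while each o-proj DP group
# is a contiguous slice of it; together they tile every DP replica once.
assert otp_group_ranks == [[0, 2], [1, 3], [4, 6], [5, 7]]
assert odp_group_ranks == [[0, 1], [2, 3], [4, 5], [6, 7]]
```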

vllm_ascend/envs.py

Lines changed: 7 additions & 1 deletion
@@ -132,6 +132,12 @@
     # This feature will get better performance when concurrency is large.
     "VLLM_ASCEND_ENABLE_FLASHCOMM1":
     lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM1", '0'))),
+    # Whether to enable FLASHCOMM2. Setting it to 0 disables the feature, while setting it to 1 or above enables it.
+    # The specific value set will be used as the O-matrix TP group size for flashcomm2.
+    # For a detailed introduction to the parameters and the differences and applicable scenarios
+    # between this feature and FLASHCOMM1, please refer to the feature guide in the documentation.
+    "VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE":
+    lambda: int(os.getenv("VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE", 0)),
     # Whether to enable MLP weight prefetch, only used in small concurrency.
     "VLLM_ASCEND_ENABLE_PREFETCH_MLP":
     lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_PREFETCH_MLP", '0'))),
@@ -185,4 +191,4 @@ def __getattr__(name: str):


 def __dir__():
-    return list(env_variables.keys())
+    return list(env_variables.keys())
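Putting the knobs together, a minimal enabling sketch mirrors the new e2e test above: FlashComm2 is switched on purely through environment variables before engine start-up. The model and engine arguments below are illustrative (they follow the test, not a mandated configuration); any VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE greater than zero enables the feature, and the value is used as the o-proj TP group size.

```python
# Illustrative enabling sketch, assuming a vLLM + vllm-ascend install.
import os

os.environ["VLLM_ASCEND_ENABLE_FLASHCOMM1"] = "1"         # recommended companion
os.environ["VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE"] = "1"  # >0 enables FlashComm2

from vllm import LLM, SamplingParams  # import after the env vars are set

llm = LLM(model="Qwen/Qwen3-30B-A3B",
          tensor_parallel_size=2,
          enable_expert_parallel=True,
          enforce_eager=True)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(max_tokens=5, temperature=0.0))
print(outputs[0].outputs[0].text)
```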
