Skip to content

Commit 29fb61c

Browse files
committed
[Model][Qwen3VL] Add torch.compile support for Qwen3VL
Signed-off-by: Lukas Geiger <lukas.geiger94@gmail.com>
1 parent 77f8001 commit 29fb61c

File tree

2 files changed

+88
-62
lines changed

2 files changed

+88
-62
lines changed

vllm/model_executor/models/qwen3_vl.py

Lines changed: 87 additions & 62 deletions
Original file line number | Diff line number | Diff line change
@@ -55,6 +55,7 @@
5555
from vllm.config import VllmConfig
5656
from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions
5757
from vllm.distributed import get_pp_group
58+
from vllm.forward_context import set_forward_context
5859
from vllm.logger import init_logger
5960
from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY
6061
from vllm.model_executor.layers.linear import (
@@ -124,6 +125,7 @@
124125
_MAX_FRAMES_PER_VIDEO = 24576
125126

126127

128+
@support_torch_compile(dynamic_arg_dims={"x": 0})
127129
class Qwen3_VisionPatchEmbed(nn.Module):
128130
def __init__(
129131
self,
@@ -187,6 +189,10 @@ def forward(self, x: torch.Tensor):
187189
return mlp_output
188190

189191

192+
@support_torch_compile(
193+
dynamic_arg_dims={"x": 0, "cu_seqlens": 0, "rotary_pos_emb": 0, "seqlens": 0},
194+
mark_unbacked_dims={"seqlens": 0},
195+
)
190196
class Qwen3_VisionBlock(nn.Module):
191197
def __init__(
192198
self,
@@ -246,6 +252,7 @@ def forward(
246252
return x
247253

248254

255+
@support_torch_compile(dynamic_arg_dims={"x": 0})
249256
class Qwen3_VisionPatchMerger(nn.Module):
250257
def __init__(
251258
self,
@@ -275,6 +282,7 @@ def __init__(
275282
quant_config=quant_config,
276283
prefix=f"{prefix}.linear_fc1",
277284
disable_tp=use_data_parallel,
285+
return_bias=False,
278286
)
279287
self.act_fn = nn.GELU()
280288
self.linear_fc2 = RowParallelLinear(
@@ -284,6 +292,7 @@ def __init__(
284292
quant_config=quant_config,
285293
prefix=f"{prefix}.linear_fc2",
286294
disable_tp=use_data_parallel,
295+
return_bias=False,
287296
)
288297

289298
def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -292,9 +301,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
292301
else:
293302
x = self.norm(x).view(-1, self.hidden_size)
294303

295-
x_parallel, _ = self.linear_fc1(x)
304+
x_parallel = self.linear_fc1(x)
296305
x_parallel = self.act_fn(x_parallel)
297-
out, _ = self.linear_fc2(x_parallel)
306+
out = self.linear_fc2(x_parallel)
298307
return out
299308

300309

@@ -325,45 +334,52 @@ def __init__(
325334
self.out_hidden_size = vision_config.out_hidden_size * (
326335
1 + len(self.deepstack_visual_indexes)
327336
)
328-
329-
self.patch_embed = Qwen3_VisionPatchEmbed(
330-
patch_size=self.patch_size,
331-
temporal_patch_size=self.temporal_patch_size,
332-
in_channels=vision_config.in_channels,
333-
hidden_size=self.hidden_size,
334-
)
337+
# TODO[@lucaskabela]: Investigate fixing this usage
338+
# see https://github.com/vllm-project/vllm/issues/27044
339+
# DO NOT MOVE THIS IMPORT
340+
from vllm.compilation.backends import set_model_tag
341+
342+
with set_model_tag("Qwen3_VisionPatchEmbed"):
343+
self.patch_embed = Qwen3_VisionPatchEmbed(
344+
patch_size=self.patch_size,
345+
temporal_patch_size=self.temporal_patch_size,
346+
in_channels=vision_config.in_channels,
347+
hidden_size=self.hidden_size,
348+
)
335349

336350
self.pos_embed = nn.Embedding(self.num_position_embeddings, self.hidden_size)
337351

338352
norm_layer = partial(nn.LayerNorm, eps=norm_eps)
339353
head_dim = self.hidden_size // self.num_heads
340354
self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)
341355

342-
self.merger = Qwen3_VisionPatchMerger(
343-
d_model=vision_config.out_hidden_size,
344-
context_dim=self.hidden_size,
345-
norm_layer=norm_layer,
346-
spatial_merge_size=self.spatial_merge_size,
347-
quant_config=quant_config,
348-
prefix=f"{prefix}.merger",
349-
use_data_parallel=use_data_parallel,
350-
)
356+
with set_model_tag("Qwen3_VisionPatchMerger"):
357+
self.merger = Qwen3_VisionPatchMerger(
358+
d_model=vision_config.out_hidden_size,
359+
context_dim=self.hidden_size,
360+
norm_layer=norm_layer,
361+
spatial_merge_size=self.spatial_merge_size,
362+
quant_config=quant_config,
363+
prefix=f"{prefix}.merger",
364+
use_data_parallel=use_data_parallel,
365+
)
351366

352-
self.deepstack_merger_list = nn.ModuleList(
353-
[
354-
Qwen3_VisionPatchMerger(
355-
d_model=vision_config.out_hidden_size,
356-
context_dim=self.hidden_size,
357-
spatial_merge_size=self.spatial_merge_size,
358-
use_postshuffle_norm=True,
359-
norm_layer=norm_layer,
360-
quant_config=quant_config,
361-
prefix=f"{prefix}.deepstack_merger_list.{layer_idx}",
362-
use_data_parallel=use_data_parallel,
363-
)
364-
for layer_idx in range(len(self.deepstack_visual_indexes))
365-
]
366-
)
367+
with set_model_tag("Qwen3_VisionPatchMerger_postshuffle_norm"):
368+
self.deepstack_merger_list = nn.ModuleList(
369+
[
370+
Qwen3_VisionPatchMerger(
371+
d_model=vision_config.out_hidden_size,
372+
context_dim=self.hidden_size,
373+
spatial_merge_size=self.spatial_merge_size,
374+
use_postshuffle_norm=True,
375+
norm_layer=norm_layer,
376+
quant_config=quant_config,
377+
prefix=f"{prefix}.deepstack_merger_list.{layer_idx}",
378+
use_data_parallel=use_data_parallel,
379+
)
380+
for layer_idx in range(len(self.deepstack_visual_indexes))
381+
]
382+
)
367383

368384
self.attn_backend = get_vit_attn_backend(
369385
head_size=head_dim,
@@ -388,23 +404,24 @@ def __init__(
388404
raise RuntimeError(
389405
f"Qwen3-VL does not support {self.attn_backend} backend now."
390406
)
391-
self.blocks = nn.ModuleList(
392-
[
393-
Qwen3_VisionBlock(
394-
dim=self.hidden_size,
395-
num_heads=self.num_heads,
396-
mlp_hidden_dim=vision_config.intermediate_size,
397-
act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act],
398-
norm_layer=norm_layer,
399-
quant_config=quant_config,
400-
prefix=f"{prefix}.blocks.{layer_idx}",
401-
use_data_parallel=use_data_parallel,
402-
attn_backend=self.attn_backend,
403-
use_upstream_fa=use_upstream_fa,
404-
)
405-
for layer_idx in range(vision_config.depth)
406-
]
407-
)
407+
with set_model_tag("Qwen3_VisionBlock"):
408+
self.blocks = nn.ModuleList(
409+
[
410+
Qwen3_VisionBlock(
411+
dim=self.hidden_size,
412+
num_heads=self.num_heads,
413+
mlp_hidden_dim=vision_config.intermediate_size,
414+
act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act],
415+
norm_layer=norm_layer,
416+
quant_config=quant_config,
417+
prefix=f"{prefix}.blocks.{layer_idx}",
418+
use_data_parallel=use_data_parallel,
419+
attn_backend=self.attn_backend,
420+
use_upstream_fa=use_upstream_fa,
421+
)
422+
for layer_idx in range(vision_config.depth)
423+
]
424+
)
408425

409426
@property
410427
def dtype(self) -> torch.dtype:
@@ -1217,6 +1234,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "model"):
12171234
multimodal_config = vllm_config.model_config.multimodal_config
12181235

12191236
self.config = config
1237+
self.vllm_config = vllm_config
12201238
self.multimodal_config = multimodal_config
12211239
self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
12221240
if not multimodal_config.get_limit_per_prompt(
@@ -1362,12 +1380,13 @@ def _process_image_input(
13621380
image_embeds = image_input["image_embeds"].type(self.visual.dtype)
13631381
else:
13641382
pixel_values = image_input["pixel_values"].type(self.visual.dtype)
1365-
if self.use_data_parallel:
1366-
return run_dp_sharded_mrope_vision_model(
1367-
self.visual, pixel_values, grid_thw_list, rope_type="rope_3d"
1368-
)
1369-
else:
1370-
image_embeds = self.visual(pixel_values, grid_thw=grid_thw_list)
1383+
with set_forward_context(None, self.vllm_config):
1384+
if self.use_data_parallel:
1385+
return run_dp_sharded_mrope_vision_model(
1386+
self.visual, pixel_values, grid_thw_list, rope_type="rope_3d"
1387+
)
1388+
else:
1389+
image_embeds = self.visual(pixel_values, grid_thw=grid_thw_list)
13711390

13721391
# Split concatenated embeddings for each image item.
13731392
# Using prod on grid_thw_list instead of grid_thw.prod avoids CUDA sync
@@ -1391,12 +1410,18 @@ def _process_video_input(
13911410
pixel_values_videos = video_input["pixel_values_videos"].type(
13921411
self.visual.dtype
13931412
)
1394-
if self.use_data_parallel:
1395-
return run_dp_sharded_mrope_vision_model(
1396-
self.visual, pixel_values_videos, grid_thw_list, rope_type="rope_3d"
1397-
)
1398-
else:
1399-
video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw_list)
1413+
with set_forward_context(None, self.vllm_config):
1414+
if self.use_data_parallel:
1415+
return run_dp_sharded_mrope_vision_model(
1416+
self.visual,
1417+
pixel_values_videos,
1418+
grid_thw_list,
1419+
rope_type="rope_3d",
1420+
)
1421+
else:
1422+
video_embeds = self.visual(
1423+
pixel_values_videos, grid_thw=grid_thw_list
1424+
)
14001425

14011426
# Split concatenated embeddings for each video item.
14021427
# Using prod on grid_thw_list instead of grid_thw.prod avoids CUDA sync

vllm/model_executor/models/qwen3_vl_moe.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -365,6 +365,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
365365
multimodal_config = vllm_config.model_config.multimodal_config
366366

367367
self.config = config
368+
self.vllm_config = vllm_config
368369
self.multimodal_config = multimodal_config
369370
self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
370371

0 commit comments

Comments
 (0)