5555from vllm .config import VllmConfig
5656from vllm .config .multimodal import BaseDummyOptions , VideoDummyOptions
5757from vllm .distributed import get_pp_group
58+ from vllm .forward_context import set_forward_context
5859from vllm .logger import init_logger
5960from vllm .model_executor .layers .activation import _ACTIVATION_REGISTRY
6061from vllm .model_executor .layers .linear import (
124125_MAX_FRAMES_PER_VIDEO = 24576
125126
126127
128+ @support_torch_compile (dynamic_arg_dims = {"x" : 0 })
127129class Qwen3_VisionPatchEmbed (nn .Module ):
128130 def __init__ (
129131 self ,
@@ -187,6 +189,10 @@ def forward(self, x: torch.Tensor):
187189 return mlp_output
188190
189191
192+ @support_torch_compile (
193+ dynamic_arg_dims = {"x" : 0 , "cu_seqlens" : 0 , "rotary_pos_emb" : 0 , "seqlens" : 0 },
194+ mark_unbacked_dims = {"seqlens" : 0 },
195+ )
190196class Qwen3_VisionBlock (nn .Module ):
191197 def __init__ (
192198 self ,
@@ -246,6 +252,7 @@ def forward(
246252 return x
247253
248254
255+ @support_torch_compile (dynamic_arg_dims = {"x" : 0 })
249256class Qwen3_VisionPatchMerger (nn .Module ):
250257 def __init__ (
251258 self ,
@@ -275,6 +282,7 @@ def __init__(
275282 quant_config = quant_config ,
276283 prefix = f"{ prefix } .linear_fc1" ,
277284 disable_tp = use_data_parallel ,
285+ return_bias = False ,
278286 )
279287 self .act_fn = nn .GELU ()
280288 self .linear_fc2 = RowParallelLinear (
@@ -284,6 +292,7 @@ def __init__(
284292 quant_config = quant_config ,
285293 prefix = f"{ prefix } .linear_fc2" ,
286294 disable_tp = use_data_parallel ,
295+ return_bias = False ,
287296 )
288297
289298 def forward (self , x : torch .Tensor ) -> torch .Tensor :
@@ -292,9 +301,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
292301 else :
293302 x = self .norm (x ).view (- 1 , self .hidden_size )
294303
295- x_parallel , _ = self .linear_fc1 (x )
304+ x_parallel = self .linear_fc1 (x )
296305 x_parallel = self .act_fn (x_parallel )
297- out , _ = self .linear_fc2 (x_parallel )
306+ out = self .linear_fc2 (x_parallel )
298307 return out
299308
300309
@@ -325,45 +334,52 @@ def __init__(
325334 self .out_hidden_size = vision_config .out_hidden_size * (
326335 1 + len (self .deepstack_visual_indexes )
327336 )
328-
329- self .patch_embed = Qwen3_VisionPatchEmbed (
330- patch_size = self .patch_size ,
331- temporal_patch_size = self .temporal_patch_size ,
332- in_channels = vision_config .in_channels ,
333- hidden_size = self .hidden_size ,
334- )
337+ # TODO[@lucaskabela]: Investigate fixing this usage
338+ # see https://github.com/vllm-project/vllm/issues/27044
339+ # DO NOT MOVE THIS IMPORT
340+ from vllm .compilation .backends import set_model_tag
341+
342+ with set_model_tag ("Qwen3_VisionPatchEmbed" ):
343+ self .patch_embed = Qwen3_VisionPatchEmbed (
344+ patch_size = self .patch_size ,
345+ temporal_patch_size = self .temporal_patch_size ,
346+ in_channels = vision_config .in_channels ,
347+ hidden_size = self .hidden_size ,
348+ )
335349
336350 self .pos_embed = nn .Embedding (self .num_position_embeddings , self .hidden_size )
337351
338352 norm_layer = partial (nn .LayerNorm , eps = norm_eps )
339353 head_dim = self .hidden_size // self .num_heads
340354 self .rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding (head_dim // 2 )
341355
342- self .merger = Qwen3_VisionPatchMerger (
343- d_model = vision_config .out_hidden_size ,
344- context_dim = self .hidden_size ,
345- norm_layer = norm_layer ,
346- spatial_merge_size = self .spatial_merge_size ,
347- quant_config = quant_config ,
348- prefix = f"{ prefix } .merger" ,
349- use_data_parallel = use_data_parallel ,
350- )
356+ with set_model_tag ("Qwen3_VisionPatchMerger" ):
357+ self .merger = Qwen3_VisionPatchMerger (
358+ d_model = vision_config .out_hidden_size ,
359+ context_dim = self .hidden_size ,
360+ norm_layer = norm_layer ,
361+ spatial_merge_size = self .spatial_merge_size ,
362+ quant_config = quant_config ,
363+ prefix = f"{ prefix } .merger" ,
364+ use_data_parallel = use_data_parallel ,
365+ )
351366
352- self .deepstack_merger_list = nn .ModuleList (
353- [
354- Qwen3_VisionPatchMerger (
355- d_model = vision_config .out_hidden_size ,
356- context_dim = self .hidden_size ,
357- spatial_merge_size = self .spatial_merge_size ,
358- use_postshuffle_norm = True ,
359- norm_layer = norm_layer ,
360- quant_config = quant_config ,
361- prefix = f"{ prefix } .deepstack_merger_list.{ layer_idx } " ,
362- use_data_parallel = use_data_parallel ,
363- )
364- for layer_idx in range (len (self .deepstack_visual_indexes ))
365- ]
366- )
367+ with set_model_tag ("Qwen3_VisionPatchMerger_postshuffle_norm" ):
368+ self .deepstack_merger_list = nn .ModuleList (
369+ [
370+ Qwen3_VisionPatchMerger (
371+ d_model = vision_config .out_hidden_size ,
372+ context_dim = self .hidden_size ,
373+ spatial_merge_size = self .spatial_merge_size ,
374+ use_postshuffle_norm = True ,
375+ norm_layer = norm_layer ,
376+ quant_config = quant_config ,
377+ prefix = f"{ prefix } .deepstack_merger_list.{ layer_idx } " ,
378+ use_data_parallel = use_data_parallel ,
379+ )
380+ for layer_idx in range (len (self .deepstack_visual_indexes ))
381+ ]
382+ )
367383
368384 self .attn_backend = get_vit_attn_backend (
369385 head_size = head_dim ,
@@ -388,23 +404,24 @@ def __init__(
388404 raise RuntimeError (
389405 f"Qwen3-VL does not support { self .attn_backend } backend now."
390406 )
391- self .blocks = nn .ModuleList (
392- [
393- Qwen3_VisionBlock (
394- dim = self .hidden_size ,
395- num_heads = self .num_heads ,
396- mlp_hidden_dim = vision_config .intermediate_size ,
397- act_fn = _ACTIVATION_REGISTRY [vision_config .hidden_act ],
398- norm_layer = norm_layer ,
399- quant_config = quant_config ,
400- prefix = f"{ prefix } .blocks.{ layer_idx } " ,
401- use_data_parallel = use_data_parallel ,
402- attn_backend = self .attn_backend ,
403- use_upstream_fa = use_upstream_fa ,
404- )
405- for layer_idx in range (vision_config .depth )
406- ]
407- )
407+ with set_model_tag ("Qwen3_VisionBlock" ):
408+ self .blocks = nn .ModuleList (
409+ [
410+ Qwen3_VisionBlock (
411+ dim = self .hidden_size ,
412+ num_heads = self .num_heads ,
413+ mlp_hidden_dim = vision_config .intermediate_size ,
414+ act_fn = _ACTIVATION_REGISTRY [vision_config .hidden_act ],
415+ norm_layer = norm_layer ,
416+ quant_config = quant_config ,
417+ prefix = f"{ prefix } .blocks.{ layer_idx } " ,
418+ use_data_parallel = use_data_parallel ,
419+ attn_backend = self .attn_backend ,
420+ use_upstream_fa = use_upstream_fa ,
421+ )
422+ for layer_idx in range (vision_config .depth )
423+ ]
424+ )
408425
409426 @property
410427 def dtype (self ) -> torch .dtype :
@@ -1217,6 +1234,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "model"):
12171234 multimodal_config = vllm_config .model_config .multimodal_config
12181235
12191236 self .config = config
1237+ self .vllm_config = vllm_config
12201238 self .multimodal_config = multimodal_config
12211239 self .use_data_parallel = multimodal_config .mm_encoder_tp_mode == "data"
12221240 if not multimodal_config .get_limit_per_prompt (
@@ -1362,12 +1380,13 @@ def _process_image_input(
13621380 image_embeds = image_input ["image_embeds" ].type (self .visual .dtype )
13631381 else :
13641382 pixel_values = image_input ["pixel_values" ].type (self .visual .dtype )
1365- if self .use_data_parallel :
1366- return run_dp_sharded_mrope_vision_model (
1367- self .visual , pixel_values , grid_thw_list , rope_type = "rope_3d"
1368- )
1369- else :
1370- image_embeds = self .visual (pixel_values , grid_thw = grid_thw_list )
1383+ with set_forward_context (None , self .vllm_config ):
1384+ if self .use_data_parallel :
1385+ return run_dp_sharded_mrope_vision_model (
1386+ self .visual , pixel_values , grid_thw_list , rope_type = "rope_3d"
1387+ )
1388+ else :
1389+ image_embeds = self .visual (pixel_values , grid_thw = grid_thw_list )
13711390
13721391 # Split concatenated embeddings for each image item.
13731392 # Using prod on grid_thw_list instead of grid_thw.prod avoids CUDA sync
@@ -1391,12 +1410,18 @@ def _process_video_input(
13911410 pixel_values_videos = video_input ["pixel_values_videos" ].type (
13921411 self .visual .dtype
13931412 )
1394- if self .use_data_parallel :
1395- return run_dp_sharded_mrope_vision_model (
1396- self .visual , pixel_values_videos , grid_thw_list , rope_type = "rope_3d"
1397- )
1398- else :
1399- video_embeds = self .visual (pixel_values_videos , grid_thw = grid_thw_list )
1413+ with set_forward_context (None , self .vllm_config ):
1414+ if self .use_data_parallel :
1415+ return run_dp_sharded_mrope_vision_model (
1416+ self .visual ,
1417+ pixel_values_videos ,
1418+ grid_thw_list ,
1419+ rope_type = "rope_3d" ,
1420+ )
1421+ else :
1422+ video_embeds = self .visual (
1423+ pixel_values_videos , grid_thw = grid_thw_list
1424+ )
14001425
14011426 # Split concatenated embeddings for each video item.
14021427 # Using prod on grid_thw_list instead of grid_thw.prod avoids CUDA sync