2525from vllm .model_executor .models .module_mapping import MultiModelKeys
2626from vllm .multimodal import MULTIMODAL_REGISTRY
2727from vllm .multimodal .inputs import (MultiModalDataDict , MultiModalFieldConfig ,
28- MultiModalKwargsItems , NestedTensors )
28+ MultiModalKwargsItems )
2929from vllm .multimodal .parse import (ImageEmbeddingItems , ImageProcessorItems ,
3030 ImageSize , MultiModalDataItems )
3131from vllm .multimodal .processing import (BaseMultiModalProcessor ,
3939
4040from .interfaces import (MultiModalEmbeddings , SupportsLoRA ,
4141 SupportsMultiModal , SupportsPP )
42- from .utils import (AutoWeightsLoader , WeightsMapper , flatten_bn ,
42+ from .utils import (AutoWeightsLoader , WeightsMapper ,
4343 init_vllm_registered_model , maybe_prefix )
4444
4545
@@ -304,7 +304,7 @@ def _call_hf_processor(
304304 mm_data : Mapping [str , object ],
305305 mm_kwargs : Mapping [str , object ],
306306 tok_kwargs : Mapping [str , object ],
307- ) -> Mapping [ str , NestedTensors ] :
307+ ) -> BatchFeature :
308308 mm_data = dict (mm_data )
309309 videos = mm_data .pop ("videos" , [])
310310 images = mm_data .pop ("images" , [])
@@ -342,7 +342,7 @@ def _call_hf_processor(
342342 image_placeholder , 1 )
343343
344344 num_patches = [len (item ) for item in image_pixel_values ]
345- image_outputs : dict [ str , NestedTensors ] = {
345+ image_outputs = {
346346 "pixel_values" : torch .concat (image_pixel_values ),
347347 "image_num_patches" : torch .tensor (num_patches ),
348348 "image_token_id" : torch .tensor (hf_processor .image_token_id ),
@@ -370,7 +370,7 @@ def _call_hf_processor(
370370 video_placeholder , 1 )
371371
372372 num_frames = [len (item ) for item in video_pixel_values ]
373- video_outputs : dict [ str , NestedTensors ] = {
373+ video_outputs = {
374374 "pixel_values_videos" : torch .concat (video_pixel_values ),
375375 "video_num_patches" : torch .tensor (num_frames ),
376376 "video_token_id" : torch .tensor (video_token_id ),
@@ -382,16 +382,11 @@ def _call_hf_processor(
382382 prompt )
383383 text_outputs = tokenizer (prompt , ** tok_kwargs , return_tensors = "pt" )
384384
385- combined_outputs = dict (
386- ** text_outputs ,
387- ** image_outputs ,
388- ** video_outputs ,
389- )
390- return BatchFeature (combined_outputs )
385+ return BatchFeature ({** text_outputs , ** image_outputs , ** video_outputs })
391386
392387 def _get_mm_fields_config (
393388 self ,
394- hf_inputs : Mapping [ str , NestedTensors ] ,
389+ hf_inputs : BatchFeature ,
395390 hf_processor_mm_kwargs : Mapping [str , object ],
396391 ) -> Mapping [str , MultiModalFieldConfig ]:
397392
@@ -487,6 +482,7 @@ def get_replacement_interns1_video(item_idx: int):
487482 dummy_inputs = InternS1DummyInputsBuilder )
488483class InternS1ForConditionalGeneration (nn .Module , SupportsMultiModal ,
489484 SupportsPP , SupportsLoRA ):
485+ merge_by_field_config = True
490486
491487 # To ensure correct weight loading and mapping.
492488 hf_to_vllm_mapper = WeightsMapper (
@@ -561,7 +557,7 @@ def _init_vision_model(
561557 prefix = prefix ,
562558 )
563559
564- def _init_mlp1 (self , config : PretrainedConfig ) -> nn .Sequential :
560+ def _init_mlp1 (self , config : PretrainedConfig ) -> nn .Module :
565561 return InternS1MultiModalProjector (config )
566562
567563 def pixel_shuffle (self , x , scale_factor = 0.5 ):
@@ -599,31 +595,16 @@ def _parse_and_validate_image_input(
599595 return None
600596
601597 if image_embeds is not None :
602- if not isinstance (image_embeds , (torch .Tensor , list )):
603- raise ValueError ("Incorrect type of image embeddings. "
604- f"Got type: { type (image_embeds )} " )
605-
606598 return InternS1ImageEmbeddingInputs (
607599 type = "image_embeds" ,
608- data = flatten_bn ( image_embeds ) ,
600+ data = image_embeds ,
609601 )
610602
611603 image_token_id = kwargs ["image_token_id" ]
612604 assert isinstance (image_token_id , torch .Tensor )
613605 self .img_context_token_id = image_token_id .flatten ().unique ().item ()
614606
615607 if pixel_values is not None :
616- if not isinstance (pixel_values , (torch .Tensor , list )):
617- raise ValueError ("Incorrect type of pixel values. "
618- f"Got type: { type (pixel_values )} " )
619-
620- if not isinstance (image_num_patches , (torch .Tensor , list )):
621- raise ValueError ("Incorrect type of image_num_patches. "
622- f"Got type: { type (image_num_patches )} " )
623-
624- pixel_values = flatten_bn (pixel_values , concat = True )
625- image_num_patches = flatten_bn (image_num_patches , concat = True )
626-
627608 h , w = self .config .vision_config .image_size
628609 return InternS1ImagePixelInputs (
629610 type = "pixel_values" ,
@@ -638,7 +619,7 @@ def _parse_and_validate_image_input(
638619 raise AssertionError ("This line should be unreachable." )
639620
640621 def _parse_and_validate_video_input (
641- self , ** kwargs : object ) -> Optional [InternS1VideoPixelInputs ]:
622+ self , ** kwargs : object ) -> Optional [InternS1VideoInputs ]:
642623 pixel_values_flat_video = kwargs .pop ("pixel_values_videos" , None )
643624 video_num_patches = kwargs .pop ("video_num_patches" , None )
644625 video_embeds = kwargs .pop ("video_embeds" , None )
@@ -647,32 +628,16 @@ def _parse_and_validate_video_input(
647628 return None
648629
649630 if video_embeds is not None :
650- if not isinstance (video_embeds , (torch .Tensor , list )):
651- raise ValueError ("Incorrect type of video embeddings. "
652- f"Got type: { type (video_embeds )} " )
653-
654- return InternS1ImageEmbeddingInputs (
631+ return InternS1VideoEmbeddingInputs (
655632 type = "video_embeds" ,
656- data = flatten_bn ( video_embeds ) ,
633+ data = video_embeds ,
657634 )
658635
659636 video_token_id = kwargs ["video_token_id" ]
660637 assert isinstance (video_token_id , torch .Tensor )
661638 self .video_context_token_id = video_token_id .flatten ().unique ().item ()
662639
663640 if pixel_values_flat_video is not None :
664- if not isinstance (pixel_values_flat_video , (torch .Tensor , list )):
665- raise ValueError ("Incorrect type of pixel values. "
666- f"Got type: { type (pixel_values_flat_video )} " )
667-
668- if not isinstance (video_num_patches , (torch .Tensor , list )):
669- raise ValueError ("Incorrect type of image_num_patches. "
670- f"Got type: { type (video_num_patches )} " )
671-
672- pixel_values_flat_video = flatten_bn (pixel_values_flat_video ,
673- concat = True )
674- video_num_patches = flatten_bn (video_num_patches , concat = True )
675-
676641 h , w = self .config .vision_config .image_size
677642 return InternS1VideoPixelInputs (
678643 type = "pixel_values_videos" ,
@@ -686,11 +651,12 @@ def _parse_and_validate_video_input(
686651
687652 raise AssertionError ("This line should be unreachable." )
688653
689- def _process_image_input (
654+ def _process_vision_input (
690655 self ,
691- image_input : Union [InternS1ImageInputs , InternS1VideoPixelInputs ],
656+ image_input : Union [InternS1ImageInputs , InternS1VideoInputs ],
692657 ) -> tuple [torch .Tensor , ...]:
693- if image_input ["type" ] == "image_embeds" :
658+ if (image_input ["type" ] == "image_embeds"
659+ or image_input ["type" ] == "video_embeds" ):
694660 return image_input ["data" ]
695661
696662 assert self .vision_tower is not None
@@ -753,11 +719,11 @@ def get_multimodal_embeddings(self,
753719 for modality in modalities :
754720 if modality == "images" :
755721 image_input = modalities ["images" ]
756- vision_embeddings = self ._process_image_input (image_input )
722+ vision_embeddings = self ._process_vision_input (image_input )
757723 multimodal_embeddings += vision_embeddings
758724 if modality == "videos" :
759725 video_input = modalities ["videos" ]
760- video_embeddings = self ._process_image_input (video_input )
726+ video_embeddings = self ._process_vision_input (video_input )
761727 multimodal_embeddings += video_embeddings
762728
763729 return multimodal_embeddings
0 commit comments