@@ -752,8 +752,8 @@ def get_dummy_inputs(self, kv_offload: bool = False, **kwargs):
             seq_len=constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN,
         )

-        lang_inputs["past_key_values"] = [[] for _ in range(self.model.config.num_hidden_layers)]
-        for i in range(self.model.config.num_hidden_layers):
+        lang_inputs["past_key_values"] = [[] for _ in range(self.model.config.text_config.num_hidden_layers)]
+        for i in range(self.model.config.text_config.num_hidden_layers):
             for kv in ["key", "value"]:
                 lang_inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape, dtype=torch.float32))
@@ -779,10 +779,10 @@ def get_specializations(
         **compiler_options,
     ):
         if height is None or width is None:
-            height = 1365
-            width = 2048
+            height = constants.QWEN2_5_VL_HEIGHT
+            width = constants.QWEN2_5_VL_WIDTH
             logger.warning(
-                "Setting height and width to be 1365 and 2048 respectively, as it was neither passed nor found in vision_config"
+                f"Setting height and width to be {height} and {width} respectively, as it was neither passed nor found in vision_config"
             )
         prefill_seq_len = prefill_seq_len if prefill_seq_len else 128
         ctx_len = ctx_len if ctx_len else constants.INTERN_CTX_LEN
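The constants module is not shown in this diff; presumably QWEN2_5_VL_HEIGHT and QWEN2_5_VL_WIDTH are defined there to mirror the literals they replace. An assumed sketch of those definitions:

# Assumed definitions in QEfficient's constants module, mirroring the
# removed 1365/2048 literals (not part of this diff):
QWEN2_5_VL_HEIGHT = 1365
QWEN2_5_VL_WIDTH = 2048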
@@ -882,7 +882,7 @@ def smart_resize(

     def get_onnx_dynamic_axes(self, kv_offload: bool = False):
         # Define dynamic axes
-        num_layers = self.config.num_hidden_layers
+        num_layers = self.config.text_config.num_hidden_layers

         vision_dynamic_axes = {
             "pixel_values": {0: "grid_height", 1: "grid_width"},
@@ -900,6 +900,7 @@ def get_onnx_dynamic_axes(self, kv_offload: bool = False):
             lang_dynamic_axes[f"past_value.{i}"] = {0: "batch_size", 2: "ctx_len"}

         dynamic_axes = {}
+
         if kv_offload:
             dynamic_axes["vision"] = vision_dynamic_axes
             dynamic_axes["lang"] = lang_dynamic_axes
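A dynamic_axes mapping of this shape is what torch.onnx.export consumes to keep the named dimensions symbolic in the exported graph. A self-contained sketch with a toy stand-in model (all names here are illustrative, not QEfficient's export path):

import torch

# Toy model; only the dynamic_axes plumbing matters here.
model = torch.nn.Linear(16, 16)
example = torch.randn(1, 8, 16)  # (batch_size, seq_len, hidden)

torch.onnx.export(
    model,
    (example,),
    "toy.onnx",
    input_names=["input"],
    output_names=["output"],
    # Same dict shape as lang_dynamic_axes above: axis index -> symbolic name.
    dynamic_axes={
        "input": {0: "batch_size", 1: "seq_len"},
        "output": {0: "batch_size", 1: "seq_len"},
    },
)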
@@ -911,7 +912,7 @@ def get_onnx_dynamic_axes(self, kv_offload: bool = False):
     def get_output_names(self, kv_offload: bool = False):
         vision_output_names = ["vision_embeds"]
         lang_output_names = ["logits"]
-        for i in range(self.model.config.num_hidden_layers):
+        for i in range(self.model.config.text_config.num_hidden_layers):
             for kv in ["key", "value"]:
                 lang_output_names.append(f"past_{kv}.{i}_RetainedState")

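Run in isolation with a hypothetical 2-layer text_config, the loop above yields the interleaved key/value retained-state names:

lang_output_names = ["logits"]
for i in range(2):  # hypothetical 2-layer text_config
    for kv in ["key", "value"]:
        lang_output_names.append(f"past_{kv}.{i}_RetainedState")
print(lang_output_names)
# ['logits', 'past_key.0_RetainedState', 'past_value.0_RetainedState',
#  'past_key.1_RetainedState', 'past_value.1_RetainedState']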
@@ -927,6 +928,32 @@ def get_output_names(self, kv_offload: bool = False):
             return lang_output_names
         return output_names

+    def prepare_inputs_for_generation(self, inputs, prefill_seq_len=128, batch_size=1):
+        input_ids_length = inputs["input_ids"].shape[1]
+
+        inputs["position_ids"] = torch.arange(input_ids_length).view(1, 1, input_ids_length).expand(-1, batch_size, -1)
+
+        pos_ids, rope_deltas = self.model.get_rope_index(
+            inputs["input_ids"],
+            None if "image_grid_thw" not in inputs else inputs["image_grid_thw"],
+            video_grid_thw=None,
+            second_per_grid_ts=None,
+            attention_mask=inputs["attention_mask"],
+        )
+
+        inputs["position_ids"] = torch.cat((inputs["position_ids"], pos_ids), dim=0)
+
+        num_chunks = -(input_ids_length // -prefill_seq_len)  # ceil divide without float
+        padded_len = num_chunks * prefill_seq_len  # Convert to a multiple of prompt_len
+
+        inputs["position_ids"] = F.pad(
+            inputs["position_ids"], pad=(0, padded_len - input_ids_length), mode="constant", value=-1
+        )
+
+        inputs.pop("image_grid_thw", None)
+
+        return inputs
+
     def get_inputs_info(self):
         return [
             IOInfo(name="input_ids", datatype=torch.int64, shape=("batch_size", "seq_len")),
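The padding arithmetic in the new prepare_inputs_for_generation is easy to verify standalone: -(a // -b) is integer ceil division, and F.pad right-pads position_ids with -1 up to a multiple of prefill_seq_len. A minimal sketch with illustrative values:

import torch
import torch.nn.functional as F

prefill_seq_len = 128
input_ids_length = 300  # illustrative prompt length

# -(a // -b) is ceil division without floats: ceil(300 / 128) == 3.
num_chunks = -(input_ids_length // -prefill_seq_len)
padded_len = num_chunks * prefill_seq_len  # 3 * 128 == 384

position_ids = torch.arange(input_ids_length).view(1, 1, -1)
# Right-pad the last axis with -1 so padded positions can be masked downstream.
position_ids = F.pad(position_ids, pad=(0, padded_len - input_ids_length), value=-1)
print(position_ids.shape)  # torch.Size([1, 1, 384]); last 84 entries are -1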