@@ -718,20 +718,37 @@ class QEffCausalLMForTextImageToTextModel(QEFFBaseModel):
718718 ]
719719 _onnx_transforms = [FP16ClipTransform , SplitTensorsTransform ]
720720
def __init__(
    self,
    model,
    qaic_config: Optional[dict] = None,
    **kwargs
):
    """
    Build the language-decoder wrapper used by multimodal models.

    Parameters
    ----------
    model : nn.Module
        The full HuggingFace multimodal model; its language decoder is
        obtained through ``model.get_qeff_language_decoder()``.
    qaic_config : dict, optional
        QAIC-specific configuration. It is stored on the extracted decoder
        as ``qaic_config`` and handed to ``SamplerTransform.apply``.
        Only the following keys are supported by the text model of the
        dual QPC multimodal model:
        - **include_sampler** (bool): If True, enables on-device sampling of next tokens.
        - **max_top_k_ids** (int): Maximum number of top K tokens (<= vocab size) to consider during sampling.
        Additional keys will be ignored.
    **kwargs :
        Extra keyword arguments. They are forwarded both to the base class
        constructor and to ``SamplerTransform.apply``.
    """
    super().__init__(model, **kwargs)

    # Replace the wrapped model with its language decoder only.
    self.model = model.get_qeff_language_decoder()
    self.hash_params["qeff_auto_class"] = self.__class__.__name__

    # Keep the configuration on the decoder so that export() can look it up.
    self.model.qaic_config = qaic_config

    # ---Sampling---
    # SamplerTransform must run after every other transform: it only appends
    # nodes at the output of whatever the previous transforms produced.
    sampled_model, _was_applied = SamplerTransform.apply(self.model, qaic_config, **kwargs)
    self.model = sampled_model
735752
736753 def export (self , inputs , output_names , dynamic_axes , export_dir = None , offload_pt_weights = True ):
737754 """
@@ -755,10 +772,95 @@ def export(self, inputs, output_names, dynamic_axes, export_dir=None, offload_pt
755772 str
756773 Path to the generated ONNX graph file for the language decoder.
757774 """
775+ if self .model .qaic_config is not None and self .model .qaic_config .get ("include_sampler" , False ):
776+ inputs , output_names , dynamic_axes = self .get_sampling_inputs_and_outputs (inputs , output_names , dynamic_axes )
758777 return self ._export (
759778 inputs , output_names , dynamic_axes , export_dir = export_dir , offload_pt_weights = offload_pt_weights
760779 )
761780
def get_sampling_inputs_and_outputs(
    self,
    example_inputs: Dict[str, torch.Tensor],
    output_names: List[str],
    dynamic_axes: Dict[str, Dict[int, str]],
):
    """
    Extend the ONNX export arguments with the on-device sampling parameters.

    The ``"logits"`` entry of ``output_names`` is renamed to ``"next_tokens"``,
    one example tensor per sampling parameter is added to ``example_inputs``
    together with its dynamic-axes entry, and the two penalty buffers get a
    matching ``*_RetainedState`` output name.

    Note that all three arguments are modified in place; the same objects
    are also returned.

    Parameters
    ----------
    example_inputs : Dict[str, torch.Tensor]
        Current dictionary of example inputs.
    output_names : List[str]
        Current list of output names. Must contain ``"logits"``.
    dynamic_axes : Dict[str, Dict[int, str]]
        Current dictionary of dynamic axes configurations.

    Returns
    -------
    Tuple[Dict[str, torch.Tensor], List[str], Dict[str, Dict[int, str]]]
        Updated example inputs, output names, and dynamic axes including
        sampling-related parameters.

    Raises
    ------
    AssertionError
        If ``"logits"`` is not present in ``output_names``.
    """
    bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE
    vocab_size = self.model.language_model.config.vocab_size

    assert "logits" in output_names, "logits must be part of the output names to support on-device sampling"

    # With the sampler enabled, the output in the logits slot is named next_tokens.
    logits_index = output_names.index("logits")
    output_names[logits_index] = "next_tokens"

    # NOTE: entries are inserted in a fixed order on purpose; do not reorder
    # them, since the dictionaries' insertion order is preserved for export.
    example_inputs["last_accepted_output_tokens"] = torch.zeros(
        (bs, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN), dtype=torch.int64
    )
    dynamic_axes["last_accepted_output_tokens"] = {0: "batch_size", 1: "seq_len"}

    # Repetition penalty: boolean buffer of shape (batch, vocab) carried as retained state.
    example_inputs["past_repetition_penalty_buffer"] = torch.zeros((bs, vocab_size), dtype=torch.bool)
    dynamic_axes["past_repetition_penalty_buffer"] = {
        0: "batch_size",
    }
    output_names.append("past_repetition_penalty_buffer_RetainedState")

    example_inputs["repetition_penalties"] = (
        torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_REPETITION_PENALTIES
    )
    dynamic_axes["repetition_penalties"] = {0: "batch_size"}

    # Presence penalty: same layout as the repetition penalty buffer.
    example_inputs["past_presence_penalty_buffer"] = torch.zeros((bs, vocab_size), dtype=torch.bool)
    dynamic_axes["past_presence_penalty_buffer"] = {
        0: "batch_size",
    }
    output_names.append("past_presence_penalty_buffer_RetainedState")

    example_inputs["presence_penalties"] = (
        torch.zeros((bs, 1), dtype=torch.float) + constants.ONNX_EXPORT_EXAMPLE_PRESENCE_PENALTIES
    )
    dynamic_axes["presence_penalties"] = {0: "batch_size"}

    example_inputs["temperatures"] = (
        torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_TEMPERATURES
    )
    dynamic_axes["temperatures"] = {0: "batch_size"}

    # Tolerate a missing config when this method is called directly.
    qaic_config = self.model.qaic_config or {}
    max_top_k_ids = qaic_config.get("max_top_k_ids", constants.ONNX_EXPORT_EXAMPLE_MAX_TOP_K_IDS)
    # torch.randint requires low < high (high is exclusive). Clamping the
    # upper bound to 2 keeps max_top_k_ids == 1 working (top_k is then 1)
    # and leaves the sampled range unchanged for every value >= 2.
    top_k_exclusive_upper = max(max_top_k_ids, 2)
    example_inputs["top_ks"] = torch.randint(1, top_k_exclusive_upper, size=(bs, 1)).to(torch.int32)
    dynamic_axes["top_ks"] = {0: "batch_size"}

    example_inputs["top_ps"] = torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_TOP_PS
    dynamic_axes["top_ps"] = {0: "batch_size"}

    example_inputs["min_ps"] = torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_MIN_PS
    dynamic_axes["min_ps"] = {0: "batch_size"}

    example_inputs["random_numbers"] = torch.rand((bs, 1), dtype=torch.float)
    dynamic_axes["random_numbers"] = {0: "batch_size"}

    return example_inputs, output_names, dynamic_axes
863+
762864 def compile (
763865 self ,
764866 compile_dir ,
@@ -1438,6 +1540,8 @@ def __init__(
14381540 """
14391541 if kwargs .pop ("full_batch_size" , None ):
14401542 raise NotImplementedError ("Continuous batching is not supported for image-text-to-text models yet." )
1543+ if kwargs .pop ("qaic_config" , None ):
1544+ raise NotImplementedError ("On-device sampling is not supported for single QPC multimodal models yet." )
14411545 super ().__init__ (model , ** kwargs )
14421546
14431547 # to handle internvl models
@@ -1957,7 +2061,13 @@ def __new__(self, model: nn.Module, kv_offload: Optional[bool] = True, **kwargs)
19572061
19582062 @classmethod
19592063 @with_replaced_quantizers
1960- def from_pretrained (cls , pretrained_model_name_or_path : str , kv_offload : Optional [bool ] = None , ** kwargs ):
2064+ def from_pretrained (
2065+ cls ,
2066+ pretrained_model_name_or_path : str ,
2067+ kv_offload : Optional [bool ] = None ,
2068+ qaic_config : Optional [dict ] = None ,
2069+ ** kwargs
2070+ ):
19612071 """
19622072 Load a QEfficient image-text-to-text model from a pretrained HuggingFace model or local path.
19632073
@@ -1969,6 +2079,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, kv_offload: Optiona
19692079 If True, uses the dual QPC approach (vision encoder KV offloaded).
19702080 If False, uses the single QPC approach (entire model in one QPC).
19712081 If None, the default behavior of the internal classes is used (typically dual QPC).
2082+ qaic_config : dict, optional
2083+ A dictionary for QAIC-specific configurations.
2084+ Only the following keys are supported by the text model of the dual QPC multimodal model:
2085+ - **include_sampler** (bool): If True, enables on-device sampling of next tokens.
2086+ - **max_top_k_ids** (int): Maximum number of top K tokens (<= vocab size) to consider during sampling.
2087+ Additional keys will be ignored.
19722088 **kwargs :
19732089 Additional arguments passed to HuggingFace's ``from_pretrained``.
19742090
@@ -1996,8 +2112,16 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, kv_offload: Optiona
19962112 NotImplementedError ("Continuous batching is not supported for image-text-to-text models yet." )
19972113
19982114 kwargs .update ({"attn_implementation" : "eager" , "low_cpu_mem_usage" : False })
2115+ if qaic_config is not None :
2116+ qaic_config ["pretrained_model_name_or_path" ] = pretrained_model_name_or_path
19992117 model = cls ._hf_auto_class .from_pretrained (pretrained_model_name_or_path , ** kwargs )
2000- return cls (model , kv_offload = kv_offload , pretrained_model_name_or_path = pretrained_model_name_or_path , ** kwargs )
2118+ return cls (
2119+ model ,
2120+ kv_offload = kv_offload ,
2121+ pretrained_model_name_or_path = pretrained_model_name_or_path ,
2122+ qaic_config = qaic_config ,
2123+ ** kwargs
2124+ )
20012125
20022126
20032127MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP = {
@@ -2199,7 +2323,11 @@ def from_pretrained(
21992323
22002324 if model .__class__ .__name__ in MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP :
22012325 return MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP [model .__class__ .__name__ ](
2202- model , kv_offload = kv_offload , pretrained_model_name_or_path = pretrained_model_name_or_path , ** kwargs
2326+ model ,
2327+ kv_offload = kv_offload ,
2328+ pretrained_model_name_or_path = pretrained_model_name_or_path ,
2329+ qaic_config = qaic_config ,
2330+ ** kwargs
22032331 )
22042332 return cls (
22052333 model ,
0 commit comments