From 3ac9937637c90b15401dc4698932eea1574ac214 Mon Sep 17 00:00:00 2001 From: quic-xiyushi Date: Thu, 23 Oct 2025 16:35:26 -0700 Subject: [PATCH 1/2] Extend on-device sampling support for dual QPC VLMs Signed-off-by: quic-xiyushi --- .../transformers/models/modeling_auto.py | 136 +++++++++++++++++- .../transformers/models/pytorch_transforms.py | 4 + QEfficient/transformers/sampler/sampler.py | 56 +++++--- 3 files changed, 176 insertions(+), 20 deletions(-) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 633a0b29d..97ec74201 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -718,7 +718,12 @@ class QEffCausalLMForTextImageToTextModel(QEFFBaseModel): ] _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform] - def __init__(self, model, **kwargs): + def __init__( + self, + model, + qaic_config: Optional[dict] = None, + **kwargs + ): """ Initializes the language decoder component for multimodal models. @@ -726,12 +731,24 @@ def __init__(self, model, **kwargs): ---------- model : nn.Module The full HuggingFace multimodal model from which the language decoder is extracted. + qaic_config : dict, optional + A dictionary for QAIC-specific configurations. + Only the following keys are supported by the text model of the dual QPC multimodal model: + - **include_sampler** (bool): If True, enables on-device sampling of next tokens. + - **max_top_k_ids** (int): Maximum number of top K tokens (<= vocab size) to consider during sampling. + Additional keys will be ignored. **kwargs : Additional keyword arguments passed to the base class constructor. """ super().__init__(model, **kwargs) self.model = model.get_qeff_language_decoder() self.hash_params["qeff_auto_class"] = self.__class__.__name__ + self.model.qaic_config = qaic_config + # ---Sampling--- + # Note: SamplerTransform should be applied after all other transforms + # are done. The role of the sampler is to just add nodes at the output of the + # previous transform function. + self.model, _ = SamplerTransform.apply(self.model, qaic_config, **kwargs) def export(self, inputs, output_names, dynamic_axes, export_dir=None, offload_pt_weights=True): """ @@ -755,10 +772,95 @@ def export(self, inputs, output_names, dynamic_axes, export_dir=None, offload_pt str Path to the generated ONNX graph file for the language decoder. """ + if self.model.qaic_config is not None and self.model.qaic_config.get("include_sampler", False): + inputs, output_names, dynamic_axes = self.get_sampling_inputs_and_outputs(inputs, output_names, dynamic_axes) return self._export( inputs, output_names, dynamic_axes, export_dir=export_dir, offload_pt_weights=offload_pt_weights ) + def get_sampling_inputs_and_outputs( + self, + example_inputs: Dict[str, torch.Tensor], + output_names: List[str], + dynamic_axes: Dict[str, Dict[int, str]], + ): + """ + Updates the example inputs, output names, and dynamic axes to include + parameters relevant for on-device sampling during ONNX export. + + Parameters + ---------- + example_inputs : Dict[str, torch.Tensor] + Current dictionary of example inputs. + output_names : List[str] + Current list of output names. + dynamic_axes : Dict[str, Dict[int, str]] + Current dictionary of dynamic axes configurations. + + Returns + ------- + Tuple[Dict[str, torch.Tensor], List[str], Dict[str, Dict[int, str]]] + Updated example inputs, output names, and dynamic axes including + sampling-related parameters. + """ + bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE + + assert "logits" in output_names, "logits must be part of the output names to suport on-device sampling" + + logits_index = output_names.index("logits") + output_names[logits_index] = "next_tokens" + + example_inputs["last_accepted_output_tokens"] = torch.zeros( + (bs, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN), dtype=torch.int64 + ) + dynamic_axes["last_accepted_output_tokens"] = {0: "batch_size", 1: "seq_len"} + + example_inputs["past_repetition_penalty_buffer"] = torch.zeros( + (bs, self.model.language_model.config.vocab_size), dtype=torch.bool + ) + dynamic_axes["past_repetition_penalty_buffer"] = { + 0: "batch_size", + } + output_names.append("past_repetition_penalty_buffer_RetainedState") + + example_inputs["repetition_penalties"] = ( + torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_REPETITION_PENALTIES + ) + dynamic_axes["repetition_penalties"] = {0: "batch_size"} + + example_inputs["past_presence_penalty_buffer"] = torch.zeros( + (bs, self.model.language_model.config.vocab_size), dtype=torch.bool + ) + dynamic_axes["past_presence_penalty_buffer"] = { + 0: "batch_size", + } + output_names.append("past_presence_penalty_buffer_RetainedState") + + example_inputs["presence_penalties"] = ( + torch.zeros((bs, 1), dtype=torch.float) + constants.ONNX_EXPORT_EXAMPLE_PRESENCE_PENALTIES + ) + dynamic_axes["presence_penalties"] = {0: "batch_size"} + + example_inputs["temperatures"] = ( + torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_TEMPERATURES + ) + dynamic_axes["temperatures"] = {0: "batch_size"} + + max_top_k_ids = self.model.qaic_config.get("max_top_k_ids", constants.ONNX_EXPORT_EXAMPLE_MAX_TOP_K_IDS) + example_inputs["top_ks"] = torch.randint(1, max_top_k_ids, size=(bs, 1)).to(torch.int32) + dynamic_axes["top_ks"] = {0: "batch_size"} + + example_inputs["top_ps"] = torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_TOP_PS + dynamic_axes["top_ps"] = {0: "batch_size"} + + example_inputs["min_ps"] = torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_MIN_PS + dynamic_axes["min_ps"] = {0: "batch_size"} + + example_inputs["random_numbers"] = torch.rand((bs, 1), dtype=torch.float) + dynamic_axes["random_numbers"] = {0: "batch_size"} + + return example_inputs, output_names, dynamic_axes + def compile( self, compile_dir, @@ -1438,6 +1540,8 @@ def __init__( """ if kwargs.pop("full_batch_size", None): raise NotImplementedError("Continuous batching is not supported for image-text-to-text models yet.") + if kwargs.pop("qaic_config", None): + raise NotImplementedError("On-device sampling is not supported for single QPC multimodal models yet.") super().__init__(model, **kwargs) # to handle internvl models @@ -1957,7 +2061,13 @@ def __new__(self, model: nn.Module, kv_offload: Optional[bool] = True, **kwargs) @classmethod @with_replaced_quantizers - def from_pretrained(cls, pretrained_model_name_or_path: str, kv_offload: Optional[bool] = None, **kwargs): + def from_pretrained( + cls, + pretrained_model_name_or_path: str, + kv_offload: Optional[bool] = None, + qaic_config: Optional[dict] = None, + **kwargs + ): """ Load a QEfficient image-text-to-text model from a pretrained HuggingFace model or local path. @@ -1969,6 +2079,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, kv_offload: Optiona If True, uses the dual QPC approach (vision encoder KV offloaded). If False, uses the single QPC approach (entire model in one QPC). If None, the default behavior of the internal classes is used (typically dual QPC). + qaic_config : dict, optional + A dictionary for QAIC-specific configurations. + Only the following keys are supported by the text model of the dual QPC multimodal model: + - **include_sampler** (bool): If True, enables on-device sampling of next tokens. + - **max_top_k_ids** (int): Maximum number of top K tokens (<= vocab size) to consider during sampling. + Additional keys will be ignored. **kwargs : Additional arguments passed to HuggingFace's ``from_pretrained``. @@ -1996,8 +2112,16 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, kv_offload: Optiona NotImplementedError("Continuous batching is not supported for image-text-to-text models yet.") kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False}) + if qaic_config is not None: + qaic_config["pretrained_model_name_or_path"] = pretrained_model_name_or_path model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, **kwargs) - return cls(model, kv_offload=kv_offload, pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs) + return cls( + model, + kv_offload=kv_offload, + pretrained_model_name_or_path=pretrained_model_name_or_path, + qaic_config=qaic_config, + **kwargs + ) MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP = { @@ -2199,7 +2323,11 @@ def from_pretrained( if model.__class__.__name__ in MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP: return MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP[model.__class__.__name__]( - model, kv_offload=kv_offload, pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs + model, + kv_offload=kv_offload, + pretrained_model_name_or_path=pretrained_model_name_or_path, + qaic_config=qaic_config, + **kwargs ) return cls( model, diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index eeb7bd6e6..ffba24ec2 100644 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -272,6 +272,7 @@ QEffGrok1MultiHeadAttention, ) from QEfficient.transformers.models.internvl.modeling_internvl import ( + QEffInternDecoderWrapper, QEffInternVisionEmbeddings, QEffInternVLModel, ) @@ -375,6 +376,7 @@ QEffQwen2_5_VLModel, QEffQwen2_5_VLTextModel, QEffQwen2_5_VLVisionAttention, + QEffQwen_2_5_vl_DecoderWrapper, QEffQwen_2_5_vl_ForConditionalGeneration, ) from QEfficient.transformers.models.qwen3.modeling_qwen3 import ( @@ -678,6 +680,8 @@ class SamplerTransform: _module_mapping = { # Llama QEffLlamaForCausalLM, + QEffInternDecoderWrapper, + QEffQwen_2_5_vl_DecoderWrapper, } @classmethod diff --git a/QEfficient/transformers/sampler/sampler.py b/QEfficient/transformers/sampler/sampler.py index 96846e712..4a9aa6034 100644 --- a/QEfficient/transformers/sampler/sampler.py +++ b/QEfficient/transformers/sampler/sampler.py @@ -24,6 +24,8 @@ class SamplerOutput(ModelOutput): probs: torch.FloatTensor = None next_tokens: torch.IntTensor = None + vision_embeds: Optional[torch.FloatTensor] = None # For VLMs + image_idx: Optional[torch.IntTensor] = None # for VLMs past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None past_repetition_penalty_buffer: Optional[torch.Tensor] = None past_presence_penalty_buffer: Optional[torch.Tensor] = None @@ -122,6 +124,8 @@ def sampler_forward( top_ps: Optional[torch.Tensor] = None, min_ps: Optional[torch.Tensor] = None, random_numbers: Optional[torch.Tensor] = None, + vision_embeds: Optional[torch.Tensor] = None, + image_idx: Optional[torch.Tensor] = None, ) -> Union[Tuple, SamplerOutput]: r""" Perform the sampling of next tokens on the QAIC device (instead of the host) @@ -170,20 +174,36 @@ def sampler_forward( Sampling parameter that represents the random seeds to use for random sampling. Must be in [-1, 1]. """ - - outputs = self.old_forward( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - batch_index=batch_index, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - cache_position=cache_position, - ) + if vision_embeds is not None: + logits, vision_embeds, image_idx, past_key_values = self.old_forward( + input_ids=input_ids, + vision_embeds=vision_embeds, + position_ids=position_ids, + image_idx=image_idx, + past_key_values=past_key_values + ) + outputs = dict( + logits=logits, + vision_embeds=vision_embeds, + image_idx=image_idx, + past_key_values=past_key_values + ) + if position_ids.dim() == 3: # For models using m-rope + position_ids = position_ids[0] + else: + outputs = self.old_forward( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + batch_index=batch_index, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) logits = outputs.get("logits", None) assert logits is not None, f"{self.model.__class__.__name__} does not return logits." @@ -230,7 +250,9 @@ def sampler_forward( return SamplerOutput( probs=None, next_tokens=greedy_samples.reshape(-1, spec_length, 1), # Return sampled next tokens instead of logits - past_key_values=outputs.past_key_values, + vision_embeds=outputs.get("vision_embeds", None), + image_idx=outputs.get("image_idx", None), + past_key_values=outputs.get("past_key_values", None), past_repetition_penalty_buffer=past_repetition_penalty_buffer, past_presence_penalty_buffer=past_presence_penalty_buffer, ) @@ -314,7 +336,9 @@ def sampler_forward( return SamplerOutput( probs=probs, next_tokens=next_tokens, # Return sampled next tokens instead of logits - past_key_values=outputs.past_key_values, + vision_embeds=outputs.get("vision_embeds", None), + image_idx=outputs.get("image_idx", None), + past_key_values=outputs.get("past_key_values", None), past_repetition_penalty_buffer=past_repetition_penalty_buffer, past_presence_penalty_buffer=past_presence_penalty_buffer, ) From df3501a24d8c9f9ede434fed9a9a4a3cfe00ba88 Mon Sep 17 00:00:00 2001 From: quic-xiyushi Date: Thu, 30 Oct 2025 00:04:01 -0700 Subject: [PATCH 2/2] Fix random_numbers shape Signed-off-by: quic-xiyushi --- .../transformers/models/modeling_auto.py | 35 +++++++++---------- QEfficient/transformers/sampler/sampler.py | 22 +++++------- 2 files changed, 24 insertions(+), 33 deletions(-) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 97ec74201..c2c1ebcd2 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -718,12 +718,7 @@ class QEffCausalLMForTextImageToTextModel(QEFFBaseModel): ] _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform] - def __init__( - self, - model, - qaic_config: Optional[dict] = None, - **kwargs - ): + def __init__(self, model, qaic_config: Optional[dict] = None, **kwargs): """ Initializes the language decoder component for multimodal models. @@ -732,7 +727,7 @@ def __init__( model : nn.Module The full HuggingFace multimodal model from which the language decoder is extracted. qaic_config : dict, optional - A dictionary for QAIC-specific configurations. + A dictionary for QAIC-specific configurations. Only the following keys are supported by the text model of the dual QPC multimodal model: - **include_sampler** (bool): If True, enables on-device sampling of next tokens. - **max_top_k_ids** (int): Maximum number of top K tokens (<= vocab size) to consider during sampling. @@ -773,7 +768,9 @@ def export(self, inputs, output_names, dynamic_axes, export_dir=None, offload_pt Path to the generated ONNX graph file for the language decoder. """ if self.model.qaic_config is not None and self.model.qaic_config.get("include_sampler", False): - inputs, output_names, dynamic_axes = self.get_sampling_inputs_and_outputs(inputs, output_names, dynamic_axes) + inputs, output_names, dynamic_axes = self.get_sampling_inputs_and_outputs( + inputs, output_names, dynamic_axes + ) return self._export( inputs, output_names, dynamic_axes, export_dir=export_dir, offload_pt_weights=offload_pt_weights ) @@ -804,7 +801,7 @@ def get_sampling_inputs_and_outputs( sampling-related parameters. """ bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE - + assert "logits" in output_names, "logits must be part of the output names to suport on-device sampling" logits_index = output_names.index("logits") @@ -856,7 +853,7 @@ def get_sampling_inputs_and_outputs( example_inputs["min_ps"] = torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_MIN_PS dynamic_axes["min_ps"] = {0: "batch_size"} - example_inputs["random_numbers"] = torch.rand((bs, 1), dtype=torch.float) + example_inputs["random_numbers"] = torch.rand((bs, max_top_k_ids), dtype=torch.float) dynamic_axes["random_numbers"] = {0: "batch_size"} return example_inputs, output_names, dynamic_axes @@ -2066,7 +2063,7 @@ def from_pretrained( pretrained_model_name_or_path: str, kv_offload: Optional[bool] = None, qaic_config: Optional[dict] = None, - **kwargs + **kwargs, ): """ Load a QEfficient image-text-to-text model from a pretrained HuggingFace model or local path. @@ -2080,7 +2077,7 @@ def from_pretrained( If False, uses the single QPC approach (entire model in one QPC). If None, the default behavior of the internal classes is used (typically dual QPC). qaic_config : dict, optional - A dictionary for QAIC-specific configurations. + A dictionary for QAIC-specific configurations. Only the following keys are supported by the text model of the dual QPC multimodal model: - **include_sampler** (bool): If True, enables on-device sampling of next tokens. - **max_top_k_ids** (int): Maximum number of top K tokens (<= vocab size) to consider during sampling. @@ -2116,11 +2113,11 @@ def from_pretrained( qaic_config["pretrained_model_name_or_path"] = pretrained_model_name_or_path model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, **kwargs) return cls( - model, - kv_offload=kv_offload, - pretrained_model_name_or_path=pretrained_model_name_or_path, - qaic_config=qaic_config, - **kwargs + model, + kv_offload=kv_offload, + pretrained_model_name_or_path=pretrained_model_name_or_path, + qaic_config=qaic_config, + **kwargs, ) @@ -2327,7 +2324,7 @@ def from_pretrained( kv_offload=kv_offload, pretrained_model_name_or_path=pretrained_model_name_or_path, qaic_config=qaic_config, - **kwargs + **kwargs, ) return cls( model, @@ -2519,7 +2516,7 @@ def get_sampling_inputs_and_outputs( example_inputs["min_ps"] = torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_MIN_PS dynamic_axes["min_ps"] = {0: "batch_size"} - example_inputs["random_numbers"] = torch.rand((bs, 1), dtype=torch.float) + example_inputs["random_numbers"] = torch.rand((bs, max_top_k_ids), dtype=torch.float) dynamic_axes["random_numbers"] = {0: "batch_size"} return example_inputs, output_names, dynamic_axes diff --git a/QEfficient/transformers/sampler/sampler.py b/QEfficient/transformers/sampler/sampler.py index 4a9aa6034..a15e156ff 100644 --- a/QEfficient/transformers/sampler/sampler.py +++ b/QEfficient/transformers/sampler/sampler.py @@ -24,8 +24,8 @@ class SamplerOutput(ModelOutput): probs: torch.FloatTensor = None next_tokens: torch.IntTensor = None - vision_embeds: Optional[torch.FloatTensor] = None # For VLMs - image_idx: Optional[torch.IntTensor] = None # for VLMs + vision_embeds: Optional[torch.FloatTensor] = None # For VLMs + image_idx: Optional[torch.IntTensor] = None # for VLMs past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None past_repetition_penalty_buffer: Optional[torch.Tensor] = None past_presence_penalty_buffer: Optional[torch.Tensor] = None @@ -176,19 +176,14 @@ def sampler_forward( """ if vision_embeds is not None: logits, vision_embeds, image_idx, past_key_values = self.old_forward( - input_ids=input_ids, - vision_embeds=vision_embeds, - position_ids=position_ids, - image_idx=image_idx, - past_key_values=past_key_values - ) - outputs = dict( - logits=logits, + input_ids=input_ids, vision_embeds=vision_embeds, + position_ids=position_ids, image_idx=image_idx, - past_key_values=past_key_values + past_key_values=past_key_values, ) - if position_ids.dim() == 3: # For models using m-rope + outputs = dict(logits=logits, vision_embeds=vision_embeds, image_idx=image_idx, past_key_values=past_key_values) + if position_ids.dim() == 3: # For models using m-rope position_ids = position_ids[0] else: outputs = self.old_forward( @@ -322,9 +317,8 @@ def sampler_forward( ) # (batch_size, spec_length, vocab_size) # Random Sampling - topk_probs_asc = torch.softmax(topk_values_asc, dim=1) # (batch_size * spec_length, max_top_k_ids) gumbel_noise = -torch.log(-torch.log(random_numbers.repeat(spec_length, 1))) # Gumbel-Max Trick - y = topk_probs_asc + gumbel_noise + y = topk_values_asc + gumbel_noise # (batch_size * spec_length, max_top_k_ids) random_samples_indices = torch.argmax(y, dim=1, keepdim=True) random_samples = torch.gather(topk_indices_asc, 1, random_samples_indices) # (batch_size * spec_length, 1)