Commit 552db50

Extend on-device sampling support for dual QPC VLMs

1 parent 60c36bc

File tree: 3 files changed, +176 −20
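
Taken together, the three files thread a `qaic_config` dictionary from `from_pretrained` down to the language decoder of a dual QPC vision-language model, where `SamplerTransform` grafts sampling nodes onto the decoder's output. A minimal usage sketch (the auto class and checkpoint name below are illustrative assumptions, not part of this diff):

    # Hypothetical usage; the auto class and checkpoint name are assumptions.
    from QEfficient import QEFFAutoModelForImageTextToText

    model = QEFFAutoModelForImageTextToText.from_pretrained(
        "<vlm-model-card>",           # e.g. an InternVL or Qwen2.5-VL checkpoint
        kv_offload=True,              # dual QPC: vision encoder and text decoder split
        qaic_config={
            "include_sampler": True,  # sample next tokens on-device
            "max_top_k_ids": 512,     # cap (<= vocab size) on top-k candidates
        },
    )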

QEfficient/transformers/models/modeling_auto.py

Lines changed: 132 additions & 4 deletions
@@ -718,20 +718,37 @@ class QEffCausalLMForTextImageToTextModel(QEFFBaseModel):
     ]
     _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]
 
-    def __init__(self, model, **kwargs):
+    def __init__(
+        self,
+        model,
+        qaic_config: Optional[dict] = None,
+        **kwargs
+    ):
         """
         Initializes the language decoder component for multimodal models.
 
         Parameters
         ----------
         model : nn.Module
             The full HuggingFace multimodal model from which the language decoder is extracted.
+        qaic_config : dict, optional
+            A dictionary for QAIC-specific configurations.
+            Only the following keys are supported by the text model of the dual QPC multimodal model:
+                - **include_sampler** (bool): If True, enables on-device sampling of next tokens.
+                - **max_top_k_ids** (int): Maximum number of top K tokens (<= vocab size) to consider during sampling.
+            Additional keys will be ignored.
         **kwargs :
             Additional keyword arguments passed to the base class constructor.
         """
         super().__init__(model, **kwargs)
         self.model = model.get_qeff_language_decoder()
         self.hash_params["qeff_auto_class"] = self.__class__.__name__
+        self.model.qaic_config = qaic_config
+        # ---Sampling---
+        # Note: SamplerTransform should be applied after all other transforms
+        # are done. The role of the sampler is to just add nodes at the output of the
+        # previous transform function.
+        self.model, _ = SamplerTransform.apply(self.model, qaic_config, **kwargs)
 
     def export(self, inputs, output_names, dynamic_axes, export_dir=None, offload_pt_weights=True):
         """
@@ -755,10 +772,95 @@ def export(self, inputs, output_names, dynamic_axes, export_dir=None, offload_pt
         str
             Path to the generated ONNX graph file for the language decoder.
         """
+        if self.model.qaic_config is not None and self.model.qaic_config.get("include_sampler", False):
+            inputs, output_names, dynamic_axes = self.get_sampling_inputs_and_outputs(inputs, output_names, dynamic_axes)
         return self._export(
             inputs, output_names, dynamic_axes, export_dir=export_dir, offload_pt_weights=offload_pt_weights
         )
 
+    def get_sampling_inputs_and_outputs(
+        self,
+        example_inputs: Dict[str, torch.Tensor],
+        output_names: List[str],
+        dynamic_axes: Dict[str, Dict[int, str]],
+    ):
+        """
+        Updates the example inputs, output names, and dynamic axes to include
+        parameters relevant for on-device sampling during ONNX export.
+
+        Parameters
+        ----------
+        example_inputs : Dict[str, torch.Tensor]
+            Current dictionary of example inputs.
+        output_names : List[str]
+            Current list of output names.
+        dynamic_axes : Dict[str, Dict[int, str]]
+            Current dictionary of dynamic axes configurations.
+
+        Returns
+        -------
+        Tuple[Dict[str, torch.Tensor], List[str], Dict[str, Dict[int, str]]]
+            Updated example inputs, output names, and dynamic axes including
+            sampling-related parameters.
+        """
+        bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE
+
+        assert "logits" in output_names, "logits must be part of the output names to support on-device sampling"
+
+        logits_index = output_names.index("logits")
+        output_names[logits_index] = "next_tokens"
+
+        example_inputs["last_accepted_output_tokens"] = torch.zeros(
+            (bs, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN), dtype=torch.int64
+        )
+        dynamic_axes["last_accepted_output_tokens"] = {0: "batch_size", 1: "seq_len"}
+
+        example_inputs["past_repetition_penalty_buffer"] = torch.zeros(
+            (bs, self.model.language_model.config.vocab_size), dtype=torch.bool
+        )
+        dynamic_axes["past_repetition_penalty_buffer"] = {
+            0: "batch_size",
+        }
+        output_names.append("past_repetition_penalty_buffer_RetainedState")
+
+        example_inputs["repetition_penalties"] = (
+            torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_REPETITION_PENALTIES
+        )
+        dynamic_axes["repetition_penalties"] = {0: "batch_size"}
+
+        example_inputs["past_presence_penalty_buffer"] = torch.zeros(
+            (bs, self.model.language_model.config.vocab_size), dtype=torch.bool
+        )
+        dynamic_axes["past_presence_penalty_buffer"] = {
+            0: "batch_size",
+        }
+        output_names.append("past_presence_penalty_buffer_RetainedState")
+
+        example_inputs["presence_penalties"] = (
+            torch.zeros((bs, 1), dtype=torch.float) + constants.ONNX_EXPORT_EXAMPLE_PRESENCE_PENALTIES
+        )
+        dynamic_axes["presence_penalties"] = {0: "batch_size"}
+
+        example_inputs["temperatures"] = (
+            torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_TEMPERATURES
+        )
+        dynamic_axes["temperatures"] = {0: "batch_size"}
+
+        max_top_k_ids = self.model.qaic_config.get("max_top_k_ids", constants.ONNX_EXPORT_EXAMPLE_MAX_TOP_K_IDS)
+        example_inputs["top_ks"] = torch.randint(1, max_top_k_ids, size=(bs, 1)).to(torch.int32)
+        dynamic_axes["top_ks"] = {0: "batch_size"}
+
+        example_inputs["top_ps"] = torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_TOP_PS
+        dynamic_axes["top_ps"] = {0: "batch_size"}
+
+        example_inputs["min_ps"] = torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_MIN_PS
+        dynamic_axes["min_ps"] = {0: "batch_size"}
+
+        example_inputs["random_numbers"] = torch.rand((bs, 1), dtype=torch.float)
+        dynamic_axes["random_numbers"] = {0: "batch_size"}
+
+        return example_inputs, output_names, dynamic_axes
+
     def compile(
         self,
         compile_dir,
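
For reference, after this helper runs the "logits" output is renamed "next_tokens", the two vocab-sized boolean penalty buffers gain `_RetainedState` outputs, and every sampling parameter becomes a per-batch tensor. A rough sketch of the added example inputs (concrete sizes are assumptions standing in for the `constants.ONNX_EXPORT_EXAMPLE_*` values and the model's vocab size):

    # Illustrative shapes only; bs, seq_len, and vocab_size are assumed values.
    import torch

    bs, seq_len, vocab_size = 1, 32, 151936
    sampling_inputs = {
        "last_accepted_output_tokens": torch.zeros((bs, seq_len), dtype=torch.int64),
        "past_repetition_penalty_buffer": torch.zeros((bs, vocab_size), dtype=torch.bool),
        "repetition_penalties": torch.ones((bs, 1)),    # multiplicative penalty
        "past_presence_penalty_buffer": torch.zeros((bs, vocab_size), dtype=torch.bool),
        "presence_penalties": torch.zeros((bs, 1)),     # additive penalty
        "temperatures": torch.ones((bs, 1)),
        "top_ks": torch.full((bs, 1), 40, dtype=torch.int32),
        "top_ps": torch.full((bs, 1), 0.95),
        "min_ps": torch.zeros((bs, 1)),
        "random_numbers": torch.rand((bs, 1)),          # per-sequence sampling seeds
    }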
@@ -1438,6 +1540,8 @@ def __init__(
         """
         if kwargs.pop("full_batch_size", None):
             raise NotImplementedError("Continuous batching is not supported for image-text-to-text models yet.")
+        if kwargs.pop("qaic_config", None):
+            raise NotImplementedError("On-device sampling is not supported for single QPC multimodal models yet.")
         super().__init__(model, **kwargs)
 
         # to handle internvl models
@@ -1957,7 +2061,13 @@ def __new__(self, model: nn.Module, kv_offload: Optional[bool] = True, **kwargs)
 
     @classmethod
     @with_replaced_quantizers
-    def from_pretrained(cls, pretrained_model_name_or_path: str, kv_offload: Optional[bool] = None, **kwargs):
+    def from_pretrained(
+        cls,
+        pretrained_model_name_or_path: str,
+        kv_offload: Optional[bool] = None,
+        qaic_config: Optional[dict] = None,
+        **kwargs
+    ):
         """
         Load a QEfficient image-text-to-text model from a pretrained HuggingFace model or local path.
 
@@ -1969,6 +2079,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, kv_offload: Optional
             If True, uses the dual QPC approach (vision encoder KV offloaded).
             If False, uses the single QPC approach (entire model in one QPC).
             If None, the default behavior of the internal classes is used (typically dual QPC).
+        qaic_config : dict, optional
+            A dictionary for QAIC-specific configurations.
+            Only the following keys are supported by the text model of the dual QPC multimodal model:
+                - **include_sampler** (bool): If True, enables on-device sampling of next tokens.
+                - **max_top_k_ids** (int): Maximum number of top K tokens (<= vocab size) to consider during sampling.
+            Additional keys will be ignored.
         **kwargs :
             Additional arguments passed to HuggingFace's ``from_pretrained``.
 
@@ -1996,8 +2112,16 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, kv_offload: Optional
             NotImplementedError("Continuous batching is not supported for image-text-to-text models yet.")
 
         kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False})
+        if qaic_config is not None:
+            qaic_config["pretrained_model_name_or_path"] = pretrained_model_name_or_path
         model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
-        return cls(model, kv_offload=kv_offload, pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs)
+        return cls(
+            model,
+            kv_offload=kv_offload,
+            pretrained_model_name_or_path=pretrained_model_name_or_path,
+            qaic_config=qaic_config,
+            **kwargs
+        )
 
 
 MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP = {
@@ -2199,7 +2323,11 @@ def from_pretrained(
 
         if model.__class__.__name__ in MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP:
             return MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP[model.__class__.__name__](
-                model, kv_offload=kv_offload, pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs
+                model,
+                kv_offload=kv_offload,
+                pretrained_model_name_or_path=pretrained_model_name_or_path,
+                qaic_config=qaic_config,
+                **kwargs
             )
         return cls(
             model,

QEfficient/transformers/models/pytorch_transforms.py

Lines changed: 4 additions & 0 deletions
@@ -272,6 +272,7 @@
     QEffGrok1MultiHeadAttention,
 )
 from QEfficient.transformers.models.internvl.modeling_internvl import (
+    QEffInternDecoderWrapper,
     QEffInternVisionEmbeddings,
     QEffInternVLModel,
 )
@@ -375,6 +376,7 @@
     QEffQwen2_5_VLModel,
     QEffQwen2_5_VLTextModel,
     QEffQwen2_5_VLVisionAttention,
+    QEffQwen_2_5_vl_DecoderWrapper,
     QEffQwen_2_5_vl_ForConditionalGeneration,
 )
 from QEfficient.transformers.models.qwen3.modeling_qwen3 import (
@@ -678,6 +680,8 @@ class SamplerTransform:
     _module_mapping = {
         # Llama
         QEffLlamaForCausalLM,
+        QEffInternDecoderWrapper,
+        QEffQwen_2_5_vl_DecoderWrapper,
     }
 
     @classmethod
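
The transform is gated on this `_module_mapping` set: only decoder classes registered there get the sampler wired in, which is why the commit adds the InternVL and Qwen2.5-VL decoder wrappers. A simplified sketch of the dispatch pattern such a transform follows (an assumption for illustration, not the verbatim `apply` in this file):

    # Assumed shape of a module-mapping transform; details differ in the
    # real implementation in pytorch_transforms.py.
    from types import MethodType

    class SamplerTransformSketch:
        _module_mapping = {
            QEffLlamaForCausalLM,
            QEffInternDecoderWrapper,
            QEffQwen_2_5_vl_DecoderWrapper,
        }

        @classmethod
        def apply(cls, model, qaic_config, **kwargs):
            transformed = False
            if qaic_config is not None and qaic_config.get("include_sampler", False):
                if type(model) in cls._module_mapping:
                    # Keep the original forward reachable, then route calls
                    # through sampler_forward so sampling nodes land at the
                    # output of the previously applied transforms.
                    model.old_forward = model.forward
                    model.forward = MethodType(sampler_forward, model)
                    transformed = True
            return model, transformed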

QEfficient/transformers/sampler/sampler.py

Lines changed: 40 additions & 16 deletions
@@ -24,6 +24,8 @@ class SamplerOutput(ModelOutput):
 
     probs: torch.FloatTensor = None
     next_tokens: torch.IntTensor = None
+    vision_embeds: Optional[torch.FloatTensor] = None  # For VLMs
+    image_idx: Optional[torch.IntTensor] = None  # For VLMs
     past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
     past_repetition_penalty_buffer: Optional[torch.Tensor] = None
     past_presence_penalty_buffer: Optional[torch.Tensor] = None
@@ -122,6 +124,8 @@ def sampler_forward(
     top_ps: Optional[torch.Tensor] = None,
     min_ps: Optional[torch.Tensor] = None,
     random_numbers: Optional[torch.Tensor] = None,
+    vision_embeds: Optional[torch.Tensor] = None,
+    image_idx: Optional[torch.Tensor] = None,
 ) -> Union[Tuple, SamplerOutput]:
     r"""
     Perform the sampling of next tokens on the QAIC device (instead of the host)
@@ -170,20 +174,36 @@ def sampler_forward(
         Sampling parameter that represents the random seeds to use for random sampling.
         Must be in [-1, 1].
     """
-
-    outputs = self.old_forward(
-        input_ids=input_ids,
-        attention_mask=attention_mask,
-        position_ids=position_ids,
-        past_key_values=past_key_values,
-        batch_index=batch_index,
-        inputs_embeds=inputs_embeds,
-        use_cache=use_cache,
-        output_attentions=output_attentions,
-        output_hidden_states=output_hidden_states,
-        return_dict=return_dict,
-        cache_position=cache_position,
-    )
+    if vision_embeds is not None:
+        logits, vision_embeds, image_idx, past_key_values = self.old_forward(
+            input_ids=input_ids,
+            vision_embeds=vision_embeds,
+            position_ids=position_ids,
+            image_idx=image_idx,
+            past_key_values=past_key_values,
+        )
+        outputs = dict(
+            logits=logits,
+            vision_embeds=vision_embeds,
+            image_idx=image_idx,
+            past_key_values=past_key_values,
+        )
+        if position_ids.dim() == 3:  # For models using m-rope
+            position_ids = position_ids[0]
+    else:
+        outputs = self.old_forward(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            batch_index=batch_index,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            cache_position=cache_position,
+        )
 
     logits = outputs.get("logits", None)
     assert logits is not None, f"{self.model.__class__.__name__} does not return logits."
@@ -230,7 +250,9 @@ def sampler_forward(
         return SamplerOutput(
             probs=None,
             next_tokens=greedy_samples.reshape(-1, spec_length, 1),  # Return sampled next tokens instead of logits
-            past_key_values=outputs.past_key_values,
+            vision_embeds=outputs.get("vision_embeds", None),
+            image_idx=outputs.get("image_idx", None),
+            past_key_values=outputs.get("past_key_values", None),
             past_repetition_penalty_buffer=past_repetition_penalty_buffer,
             past_presence_penalty_buffer=past_presence_penalty_buffer,
         )
@@ -314,7 +336,9 @@ def sampler_forward(
         return SamplerOutput(
             probs=probs,
             next_tokens=next_tokens,  # Return sampled next tokens instead of logits
-            past_key_values=outputs.past_key_values,
+            vision_embeds=outputs.get("vision_embeds", None),
+            image_idx=outputs.get("image_idx", None),
+            past_key_values=outputs.get("past_key_values", None),
             past_repetition_penalty_buffer=past_repetition_penalty_buffer,
             past_presence_penalty_buffer=past_presence_penalty_buffer,
         )
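
In the VLM branch, the wrapped decoder is assumed to return `(logits, vision_embeds, image_idx, past_key_values)`, and models using m-rope (e.g. Qwen2.5-VL) carry 3-D position ids of shape `(3, batch, seq)`, of which the sampler needs only one plane. A toy illustration of that reduction:

    # Toy check of the m-rope position-id reduction; the (3, batch, seq)
    # shape is an assumption based on Qwen2-VL-style multimodal rope ids.
    import torch

    position_ids = torch.arange(8).expand(3, 2, 8)  # (3, batch=2, seq=8)
    if position_ids.dim() == 3:  # For models using m-rope
        position_ids = position_ids[0]
    assert position_ids.shape == (2, 8)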
