135 changes: 130 additions & 5 deletions QEfficient/transformers/models/modeling_auto.py
@@ -718,20 +718,32 @@ class QEffCausalLMForTextImageToTextModel(QEFFBaseModel):
]
_onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]

def __init__(self, model, **kwargs):
def __init__(self, model, qaic_config: Optional[dict] = None, **kwargs):
"""
Initializes the language decoder component for multimodal models.

Parameters
----------
model : nn.Module
The full HuggingFace multimodal model from which the language decoder is extracted.
qaic_config : dict, optional
A dictionary for QAIC-specific configurations.
Only the following keys are supported by the text model of the dual QPC multimodal model:
- **include_sampler** (bool): If True, enables on-device sampling of next tokens.
- **max_top_k_ids** (int): Maximum number of top K tokens (<= vocab size) to consider during sampling.
Additional keys will be ignored.
**kwargs :
Additional keyword arguments passed to the base class constructor.
"""
super().__init__(model, **kwargs)
self.model = model.get_qeff_language_decoder()
self.hash_params["qeff_auto_class"] = self.__class__.__name__
self.model.qaic_config = qaic_config
# ---Sampling---
# Note: SamplerTransform should be applied after all other transforms
# are done. The role of the sampler is to just add nodes at the output of the
# previous transform function.
self.model, _ = SamplerTransform.apply(self.model, qaic_config, **kwargs)
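For context, a hedged sketch of the gating `SamplerTransform.apply` performs here (simplified from `pytorch_transforms.py`, shown further down; the helper name is illustrative): it is a no-op unless `qaic_config` requests sampling, and it swaps `forward()` for `sampler_forward` while keeping the original as `old_forward`.

```python
from types import MethodType

from QEfficient.transformers.models.pytorch_transforms import SamplerTransform
from QEfficient.transformers.sampler.sampler import sampler_forward


def apply_sampler_transform(model, qaic_config):
    # No-op unless on-device sampling was requested in qaic_config.
    if not (qaic_config or {}).get("include_sampler", False):
        return model, False
    if type(model) not in SamplerTransform._module_mapping:
        raise NotImplementedError(f"Sampler not supported for {type(model).__name__}")
    model.old_forward = model.forward              # sampler_forward calls this internally
    model.forward = MethodType(sampler_forward, model)
    return model, True
```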

def export(self, inputs, output_names, dynamic_axes, export_dir=None, offload_pt_weights=True):
"""
@@ -755,10 +767,97 @@ def export(self, inputs, output_names, dynamic_axes, export_dir=None, offload_pt
str
Path to the generated ONNX graph file for the language decoder.
"""
if self.model.qaic_config is not None and self.model.qaic_config.get("include_sampler", False):
inputs, output_names, dynamic_axes = self.get_sampling_inputs_and_outputs(
inputs, output_names, dynamic_axes
)
return self._export(
inputs, output_names, dynamic_axes, export_dir=export_dir, offload_pt_weights=offload_pt_weights
)

def get_sampling_inputs_and_outputs(
self,
example_inputs: Dict[str, torch.Tensor],
output_names: List[str],
dynamic_axes: Dict[str, Dict[int, str]],
):
"""
Updates the example inputs, output names, and dynamic axes to include
parameters relevant for on-device sampling during ONNX export.

Parameters
----------
example_inputs : Dict[str, torch.Tensor]
Current dictionary of example inputs.
output_names : List[str]
Current list of output names.
dynamic_axes : Dict[str, Dict[int, str]]
Current dictionary of dynamic axes configurations.

Returns
-------
Tuple[Dict[str, torch.Tensor], List[str], Dict[str, Dict[int, str]]]
Updated example inputs, output names, and dynamic axes including
sampling-related parameters.
"""
bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE

assert "logits" in output_names, "logits must be part of the output names to suport on-device sampling"

logits_index = output_names.index("logits")
output_names[logits_index] = "next_tokens"

example_inputs["last_accepted_output_tokens"] = torch.zeros(
(bs, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN), dtype=torch.int64
)
dynamic_axes["last_accepted_output_tokens"] = {0: "batch_size", 1: "seq_len"}

example_inputs["past_repetition_penalty_buffer"] = torch.zeros(
(bs, self.model.language_model.config.vocab_size), dtype=torch.bool
)
dynamic_axes["past_repetition_penalty_buffer"] = {
0: "batch_size",
}
output_names.append("past_repetition_penalty_buffer_RetainedState")

example_inputs["repetition_penalties"] = (
torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_REPETITION_PENALTIES
)
dynamic_axes["repetition_penalties"] = {0: "batch_size"}

example_inputs["past_presence_penalty_buffer"] = torch.zeros(
(bs, self.model.language_model.config.vocab_size), dtype=torch.bool
)
dynamic_axes["past_presence_penalty_buffer"] = {
0: "batch_size",
}
output_names.append("past_presence_penalty_buffer_RetainedState")

example_inputs["presence_penalties"] = (
torch.zeros((bs, 1), dtype=torch.float) + constants.ONNX_EXPORT_EXAMPLE_PRESENCE_PENALTIES
)
dynamic_axes["presence_penalties"] = {0: "batch_size"}

example_inputs["temperatures"] = (
torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_TEMPERATURES
)
dynamic_axes["temperatures"] = {0: "batch_size"}

max_top_k_ids = self.model.qaic_config.get("max_top_k_ids", constants.ONNX_EXPORT_EXAMPLE_MAX_TOP_K_IDS)
example_inputs["top_ks"] = torch.randint(1, max_top_k_ids, size=(bs, 1)).to(torch.int32)
dynamic_axes["top_ks"] = {0: "batch_size"}

example_inputs["top_ps"] = torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_TOP_PS
dynamic_axes["top_ps"] = {0: "batch_size"}

example_inputs["min_ps"] = torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_MIN_PS
dynamic_axes["min_ps"] = {0: "batch_size"}

example_inputs["random_numbers"] = torch.rand((bs, max_top_k_ids), dtype=torch.float)
dynamic_axes["random_numbers"] = {0: "batch_size"}

return example_inputs, output_names, dynamic_axes
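For concreteness, a standalone sketch of the extra tensors this method registers, with illustrative values standing in for the `ONNX_EXPORT_EXAMPLE_*` constants (`bs`, `vocab_size`, `seq_len`, and `max_top_k_ids` below are assumptions):

```python
import torch

bs, vocab_size, seq_len, max_top_k_ids = 1, 32000, 32, 512  # illustrative only
sampling_inputs = {
    "last_accepted_output_tokens": torch.zeros((bs, seq_len), dtype=torch.int64),
    "past_repetition_penalty_buffer": torch.zeros((bs, vocab_size), dtype=torch.bool),
    "repetition_penalties": torch.full((bs, 1), 1.2),   # > 1.0 discourages repeats
    "past_presence_penalty_buffer": torch.zeros((bs, vocab_size), dtype=torch.bool),
    "presence_penalties": torch.full((bs, 1), 0.5),
    "temperatures": torch.ones((bs, 1)),
    "top_ks": torch.randint(1, max_top_k_ids, (bs, 1), dtype=torch.int32),
    "top_ps": torch.full((bs, 1), 0.9),
    "min_ps": torch.full((bs, 1), 0.05),
    "random_numbers": torch.rand(bs, max_top_k_ids),    # one uniform draw per top-k slot
}
# Every tensor gets {0: "batch_size"} as its dynamic axis, and "logits" is
# renamed to "next_tokens" in the output names, as implemented above.
```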

def compile(
self,
compile_dir,
@@ -1438,6 +1537,8 @@ def __init__(
"""
if kwargs.pop("full_batch_size", None):
raise NotImplementedError("Continuous batching is not supported for image-text-to-text models yet.")
if kwargs.pop("qaic_config", None):
raise NotImplementedError("On-device sampling is not supported for single QPC multimodal models yet.")
super().__init__(model, **kwargs)

# to handle internvl models
@@ -1957,7 +2058,13 @@ def __new__(self, model: nn.Module, kv_offload: Optional[bool] = True, **kwargs)

@classmethod
@with_replaced_quantizers
def from_pretrained(cls, pretrained_model_name_or_path: str, kv_offload: Optional[bool] = None, **kwargs):
def from_pretrained(
cls,
pretrained_model_name_or_path: str,
kv_offload: Optional[bool] = None,
qaic_config: Optional[dict] = None,
**kwargs,
):
"""
Load a QEfficient image-text-to-text model from a pretrained HuggingFace model or local path.

@@ -1969,6 +2076,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, kv_offload: Optiona
If True, uses the dual QPC approach (vision encoder KV offloaded).
If False, uses the single QPC approach (entire model in one QPC).
If None, the default behavior of the internal classes is used (typically dual QPC).
qaic_config : dict, optional
A dictionary for QAIC-specific configurations.
Only the following keys are supported by the text model of the dual QPC multimodal model:
- **include_sampler** (bool): If True, enables on-device sampling of next tokens.
- **max_top_k_ids** (int): Maximum number of top K tokens (<= vocab size) to consider during sampling.
Additional keys will be ignored.
**kwargs :
Additional arguments passed to HuggingFace's ``from_pretrained``.

@@ -1996,8 +2109,16 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, kv_offload: Optiona
NotImplementedError("Continuous batching is not supported for image-text-to-text models yet.")

kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False})
if qaic_config is not None:
qaic_config["pretrained_model_name_or_path"] = pretrained_model_name_or_path
model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
return cls(model, kv_offload=kv_offload, pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs)
return cls(
model,
kv_offload=kv_offload,
pretrained_model_name_or_path=pretrained_model_name_or_path,
qaic_config=qaic_config,
**kwargs,
)
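A hypothetical end-to-end usage of the new argument (the checkpoint name is illustrative, not prescribed by this PR):

```python
from QEfficient import QEFFAutoModelForImageTextToText

model = QEFFAutoModelForImageTextToText.from_pretrained(
    "llava-hf/llava-1.5-7b-hf",        # illustrative VLM checkpoint
    kv_offload=True,                   # dual QPC; the sampler applies to the text QPC
    qaic_config={
        "include_sampler": True,       # sample next tokens on-device
        "max_top_k_ids": 512,          # upper bound for top-k (<= vocab size)
    },
)
```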


MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP = {
@@ -2199,7 +2320,11 @@ def from_pretrained(

if model.__class__.__name__ in MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP:
return MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP[model.__class__.__name__](
model, kv_offload=kv_offload, pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs
model,
kv_offload=kv_offload,
pretrained_model_name_or_path=pretrained_model_name_or_path,
qaic_config=qaic_config,
**kwargs,
)
return cls(
model,
@@ -2391,7 +2516,7 @@ def get_sampling_inputs_and_outputs(
example_inputs["min_ps"] = torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_MIN_PS
dynamic_axes["min_ps"] = {0: "batch_size"}

example_inputs["random_numbers"] = torch.rand((bs, 1), dtype=torch.float)
example_inputs["random_numbers"] = torch.rand((bs, max_top_k_ids), dtype=torch.float)
dynamic_axes["random_numbers"] = {0: "batch_size"}

return example_inputs, output_names, dynamic_axes
4 changes: 4 additions & 0 deletions QEfficient/transformers/models/pytorch_transforms.py
@@ -272,6 +272,7 @@
QEffGrok1MultiHeadAttention,
)
from QEfficient.transformers.models.internvl.modeling_internvl import (
QEffInternDecoderWrapper,
QEffInternVisionEmbeddings,
QEffInternVLModel,
)
@@ -375,6 +376,7 @@
QEffQwen2_5_VLModel,
QEffQwen2_5_VLTextModel,
QEffQwen2_5_VLVisionAttention,
QEffQwen_2_5_vl_DecoderWrapper,
QEffQwen_2_5_vl_ForConditionalGeneration,
)
from QEfficient.transformers.models.qwen3.modeling_qwen3 import (
@@ -678,6 +680,8 @@ class SamplerTransform:
_module_mapping = {
# Llama
QEffLlamaForCausalLM,
QEffInternDecoderWrapper,
QEffQwen_2_5_vl_DecoderWrapper,
}

@classmethod
54 changes: 36 additions & 18 deletions QEfficient/transformers/sampler/sampler.py
@@ -24,6 +24,8 @@ class SamplerOutput(ModelOutput):

probs: torch.FloatTensor = None
next_tokens: torch.IntTensor = None
vision_embeds: Optional[torch.FloatTensor] = None # For VLMs
image_idx: Optional[torch.IntTensor] = None # For VLMs
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
past_repetition_penalty_buffer: Optional[torch.Tensor] = None
past_presence_penalty_buffer: Optional[torch.Tensor] = None
@@ -122,6 +124,8 @@ def sampler_forward(
top_ps: Optional[torch.Tensor] = None,
min_ps: Optional[torch.Tensor] = None,
random_numbers: Optional[torch.Tensor] = None,
vision_embeds: Optional[torch.Tensor] = None,
image_idx: Optional[torch.Tensor] = None,
) -> Union[Tuple, SamplerOutput]:
r"""
Perform the sampling of next tokens on the QAIC device (instead of the host)
@@ -170,20 +174,31 @@
Sampling parameter that represents the random seeds to use for random sampling.
Must be in [-1, 1].
"""

outputs = self.old_forward(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
batch_index=batch_index,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
)
if vision_embeds is not None:
logits, vision_embeds, image_idx, past_key_values = self.old_forward(
input_ids=input_ids,
vision_embeds=vision_embeds,
position_ids=position_ids,
image_idx=image_idx,
past_key_values=past_key_values,
)
outputs = dict(logits=logits, vision_embeds=vision_embeds, image_idx=image_idx, past_key_values=past_key_values)
if position_ids.dim() == 3: # For models using m-rope
position_ids = position_ids[0]
else:
outputs = self.old_forward(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
batch_index=batch_index,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
)

logits = outputs.get("logits", None)
assert logits is not None, f"{self.model.__class__.__name__} does not return logits."
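A small shape illustration for the VLM branch above; the (3, batch, seq) layout is an assumption matching m-rope models such as Qwen2.5-VL:

```python
import torch

position_ids = torch.arange(8).expand(3, 2, 8)   # (rope sections, batch, seq)
if position_ids.dim() == 3:                      # m-rope: keep one section only;
    position_ids = position_ids[0]               # the sampler needs a (batch, seq) view
assert position_ids.shape == (2, 8)
```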
@@ -230,7 +245,9 @@ def sampler_forward(
return SamplerOutput(
probs=None,
next_tokens=greedy_samples.reshape(-1, spec_length, 1), # Return sampled next tokens instead of logits
past_key_values=outputs.past_key_values,
vision_embeds=outputs.get("vision_embeds", None),
image_idx=outputs.get("image_idx", None),
past_key_values=outputs.get("past_key_values", None),
past_repetition_penalty_buffer=past_repetition_penalty_buffer,
past_presence_penalty_buffer=past_presence_penalty_buffer,
)
@@ -300,9 +317,8 @@ def sampler_forward(
) # (batch_size, spec_length, vocab_size)

# Random Sampling
topk_probs_asc = torch.softmax(topk_values_asc, dim=1) # (batch_size * spec_length, max_top_k_ids)
gumbel_noise = -torch.log(-torch.log(random_numbers.repeat(spec_length, 1))) # Gumbel-Max Trick
y = topk_probs_asc + gumbel_noise
y = topk_values_asc + gumbel_noise # (batch_size * spec_length, max_top_k_ids)
random_samples_indices = torch.argmax(y, dim=1, keepdim=True)
random_samples = torch.gather(topk_indices_asc, 1, random_samples_indices) # (batch_size * spec_length, 1)
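The change above moves the Gumbel noise from softmaxed probabilities onto the raw top-k logits, which is the standard Gumbel-Max formulation: argmax(logits + Gumbel(0, 1)) draws exactly from softmax(logits), so no explicit softmax is needed before sampling. A minimal numerical check (values illustrative):

```python
import torch

torch.manual_seed(0)
logits = torch.tensor([2.0, 0.5, 0.1])
u = torch.rand(200_000, 3)                       # uniform (0, 1) draws
gumbel = -torch.log(-torch.log(u))               # Gumbel(0, 1) noise
samples = torch.argmax(logits + gumbel, dim=1)
empirical = torch.bincount(samples, minlength=3).float() / samples.numel()
print(empirical)                                 # both ≈ [0.73, 0.16, 0.11]
print(torch.softmax(logits, dim=0))
```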

@@ -314,7 +330,9 @@ def sampler_forward(
return SamplerOutput(
probs=probs,
next_tokens=next_tokens, # Return sampled next tokens instead of logits
past_key_values=outputs.past_key_values,
vision_embeds=outputs.get("vision_embeds", None),
image_idx=outputs.get("image_idx", None),
past_key_values=outputs.get("past_key_values", None),
past_repetition_penalty_buffer=past_repetition_penalty_buffer,
past_presence_penalty_buffer=past_presence_penalty_buffer,
)