diff --git a/QEfficient/transformers/cache_utils.py b/QEfficient/transformers/cache_utils.py index bbd937d52..5cb5e36e3 100644 --- a/QEfficient/transformers/cache_utils.py +++ b/QEfficient/transformers/cache_utils.py @@ -44,7 +44,7 @@ def read_only(self, cache_kwargs): ctx_indices = torch.arange(ctx_len)[None, None, ...] gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1) invalid_mask = ctx_indices > gather_limit - + breakpoint() if torch.onnx.is_in_onnx_export(): invalid_idx_value = torch.iinfo(torch.int32).max else: @@ -113,6 +113,7 @@ def update( Return: A tuple containing the updated key and value states. """ + # breakpoint() # Update the cache if self.keys is None: self.keys = key_states @@ -237,15 +238,41 @@ class QEffDynamicCache(DynamicCache): """ - def __init__(self, ddp_cache_data: Optional[Iterable[tuple[torch.Tensor, torch.Tensor]]] = None, *args, **kwargs): + def __init__( + self, + ddp_cache_data: Optional[Iterable[tuple[torch.Tensor, torch.Tensor]]] = None, + config=None, + offloading: bool = False, + offload_only_non_sliding: bool = False, + *args, + **kwargs, + ): # Remove layer_classes if present to avoid duplicate argument - kwargs.pop("layer_classes", None) + kwargs.pop("layers", None) from transformers.cache_utils import Cache # Import here to avoid circular import - Cache.__init__(self, layer_classes=QEffDynamicLayer, *args, **kwargs) + layers = [] + if len(layers) == 0: + Cache.__init__( + self, + layer_class_to_replicate=QEffDynamicLayer, + offloading=offloading, + offload_only_non_sliding=offload_only_non_sliding, + ) + else: + Cache.__init__( + self, + layers=layers, + offloading=offloading, + offload_only_non_sliding=offload_only_non_sliding, + ) + if ddp_cache_data is not None: - for key_states, value_states in ddp_cache_data: - self.layers.append(QEffDynamicLayer.from_tensors(key_states, value_states)) + for layer_idx, (key_states, value_states) in enumerate(ddp_cache_data): + # If the config was not passed above, initialize a DynamicLayer for each entry of the ddp_data + layers.append(QEffDynamicLayer()) + # Update the layer with the data + _, _ = layers[layer_idx].update(key_states, value_states) def read_only(self, layer_idx, cache_kwargs): """ @@ -260,6 +287,7 @@ def read_only(self, layer_idx, cache_kwargs): Return: A tuple containing the updated key and value states. 
""" + # breakpoint() return self.layers[layer_idx].read_only(cache_kwargs) def write_only(self, key_states, value_states, layer_idx, cache_kwargs): diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index c910ab387..19c0f418a 100644 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -102,11 +102,6 @@ MistralModel, MistralRMSNorm, ) -from transformers.models.mistral3.modeling_mistral3 import ( - Mistral3ForConditionalGeneration, - Mistral3Model, - Mistral3RMSNorm, -) from transformers.models.mixtral.modeling_mixtral import ( MixtralAttention, MixtralDecoderLayer, @@ -129,13 +124,6 @@ MllamaVisionModel, ) from transformers.models.mpt.modeling_mpt import MptAttention, MptBlock, MptForCausalLM, MptModel -from transformers.models.olmo2.modeling_olmo2 import ( - Olmo2Attention, - Olmo2DecoderLayer, - Olmo2ForCausalLM, - Olmo2Model, - Olmo2RMSNorm, -) from transformers.models.phi.modeling_phi import PhiAttention, PhiDecoderLayer, PhiForCausalLM, PhiModel from transformers.models.phi3.modeling_phi3 import ( Phi3Attention, @@ -144,7 +132,6 @@ Phi3Model, Phi3RMSNorm, ) -from transformers.models.pixtral.modeling_pixtral import PixtralRMSNorm, PixtralVisionModel from transformers.models.qwen2.modeling_qwen2 import ( Qwen2Attention, Qwen2DecoderLayer, @@ -152,21 +139,26 @@ Qwen2Model, Qwen2RMSNorm, ) -from transformers.models.qwen3.modeling_qwen3 import ( - Qwen3Attention, - Qwen3DecoderLayer, - Qwen3ForCausalLM, - Qwen3Model, - Qwen3RMSNorm, -) -from transformers.models.qwen3_moe.modeling_qwen3_moe import ( - Qwen3MoeAttention, - Qwen3MoeDecoderLayer, - Qwen3MoeForCausalLM, - Qwen3MoeModel, - Qwen3MoeRMSNorm, - Qwen3MoeRotaryEmbedding, - Qwen3MoeSparseMoeBlock, +from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import ( + Qwen2_5_VisionTransformerPretrainedModel, + Qwen2_5_VLAttention, + Qwen2_5_VLDecoderLayer, + Qwen2_5_VLForConditionalGeneration, + Qwen2_5_VLTextModel, + Qwen2_5_VLVisionAttention, +) +from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import ( + Qwen2RMSNorm as Qwen2_5RMSNorm, +) +from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import ( + Qwen3VLMoeForConditionalGeneration, + Qwen3VLMoeModel, + Qwen3VLMoeTextAttention, + Qwen3VLMoeTextDecoderLayer, + Qwen3VLMoeTextModel, + Qwen3VLMoeTextRMSNorm, + Qwen3VLMoeVisionAttention, + Qwen3VLMoeVisionModel, ) from transformers.models.starcoder2.modeling_starcoder2 import ( Starcoder2Attention, @@ -294,11 +286,6 @@ QEffMistralForCausalLM, QEffMistralModel, ) -from QEfficient.transformers.models.mistral3.modeling_mistral3 import ( - QEffMistral3ForConditionalGeneration, - QEffMistral3Model, - QEffPixtralVisionModel, -) from QEfficient.transformers.models.mixtral_moe.modeling_mixtral import ( QEffMixtralAttention, QeffMixtralDecoderLayer, @@ -319,25 +306,12 @@ QEffMllamaTextSelfAttention, QEffMllamaVisionModel, ) -from QEfficient.transformers.models.molmo.modeling_molmo import ( - QEffMolmo, - QEffMolmoBlock, - QEffMolmoModel, - QEffMolmoSequentialBlock, - QEffMultiHeadDotProductAttention, -) from QEfficient.transformers.models.mpt.modeling_mpt import ( QEffMptAttention, QEffMptBlock, QEffMptForCausalLM, QEFfMptModel, ) -from QEfficient.transformers.models.olmo2.modeling_olmo2 import ( - QEffOlmo2Attention, - QEffOlmo2DecoderLayer, - QEffOlmo2ForCausalLM, - QEffOlmo2Model, -) from QEfficient.transformers.models.phi.modeling_phi import ( QEffPhiAttention, QEffPhiDecoderLayer, @@ 
-356,19 +330,23 @@ QEffQwen2ForCausalLM, QEffQwen2Model, ) -from QEfficient.transformers.models.qwen3.modeling_qwen3 import ( - QEffQwen3Attention, - QEffQwen3DecoderLayer, - QEffQwen3ForCausalLM, - QEffQwen3Model, -) -from QEfficient.transformers.models.qwen3_moe.modeling_qwen3_moe import ( - QEffQwen3MoeAttention, - QEffQwen3MoeDecoderLayer, - QEffQwen3MoeForCausalLM, - QEffQwen3MoeModel, - QEffQwen3MoeRotaryEmbedding, - QEffQwen3MoeSparseMoeBlock, +from QEfficient.transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import ( + QEffQwen2_5_VisionTransformerPretrainedModel, + QEffQwen2_5_VLAttention, + QEffQwen2_5_VLDecoderLayer, + QEffQwen2_5_VLTextModel, + # QEffQwen2_5_VLModel, + QEffQwen2_5_VLVisionAttention, + QEffQwen_2_5_vl_ForConditionalGeneration, +) +from QEfficient.transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import ( + QEffQwen3VLMoeForConditionalGeneration, + QEffQwen3VLMoeModel, + QEffQwen3VLMoeTextAttention, + QEffQwen3VLMoeTextDecoderLayer, + QEffQwen3VLMoeTextModel, + QEffQwen3VLMoeVisionAttention, + QEffQwen3VLMoeVisionModel, ) from QEfficient.transformers.models.starcoder2.modeling_starcoder2 import ( QEffStarcoder2Attention, @@ -399,18 +377,16 @@ class CustomOpsTransform(ModuleMappingTransform): LlamaRMSNorm: CustomRMSNormAIC, Llama4TextRMSNorm: CustomRMSNormAIC, MistralRMSNorm: CustomRMSNormAIC, - Mistral3RMSNorm: CustomRMSNormAIC, MixtralRMSNorm: CustomRMSNormAIC, Phi3RMSNorm: CustomRMSNormAIC, Qwen2RMSNorm: CustomRMSNormAIC, - Qwen3RMSNorm: CustomRMSNormAIC, + Qwen2_5RMSNorm: CustomRMSNormAIC, MllamaTextRMSNorm: CustomRMSNormAIC, GraniteRMSNorm: CustomRMSNormAIC, - PixtralRMSNorm: CustomRMSNormAIC, GraniteMoeRMSNorm: CustomRMSNormAIC, - Qwen3MoeRMSNorm: CustomRMSNormAIC, + Qwen3VLMoeTextRMSNorm: CustomRMSNormAIC, Gemma3RMSNorm: QEffGemma3CustomRMSNormAIC, - Olmo2RMSNorm: CustomRMSNormAIC, + # Qwen3VLMoeTextRMSNorm: CustomRMSNormAIC, } @@ -463,12 +439,12 @@ class KVCacheTransform(ModuleMappingTransform): GemmaModel: QEffGemmaModel, GemmaForCausalLM: QEffGemmaForCausalLM, # Qwen3Moe - Qwen3MoeForCausalLM: QEffQwen3MoeForCausalLM, - Qwen3MoeModel: QEffQwen3MoeModel, - Qwen3MoeDecoderLayer: QEffQwen3MoeDecoderLayer, - Qwen3MoeAttention: QEffQwen3MoeAttention, - Qwen3MoeRotaryEmbedding: QEffQwen3MoeRotaryEmbedding, - Qwen3MoeSparseMoeBlock: QEffQwen3MoeSparseMoeBlock, + # Qwen3MoeForCausalLM: QEffQwen3MoeForCausalLM, + # Qwen3MoeModel: QEffQwen3MoeModel, + # Qwen3MoeDecoderLayer: QEffQwen3MoeDecoderLayer, + # Qwen3MoeAttention: QEffQwen3MoeAttention, + # Qwen3MoeRotaryEmbedding: QEffQwen3MoeRotaryEmbedding, + # Qwen3MoeSparseMoeBlock: QEffQwen3MoeSparseMoeBlock, # Gemma2 Gemma2Attention: QEffGemma2Attention, Gemma2DecoderLayer: QEffGemma2DecoderLayer, @@ -508,9 +484,6 @@ class KVCacheTransform(ModuleMappingTransform): MistralDecoderLayer: QEffMistralDecoderLayer, MistralModel: QEffMistralModel, MistralForCausalLM: QEffMistralForCausalLM, - # Mistral3 - Mistral3ForConditionalGeneration: QEffMistral3ForConditionalGeneration, - Mistral3Model: QEffMistral3Model, # Mixtral MixtralAttention: QEffMixtralAttention, MixtralSparseMoeBlock: QEffMixtralSparseMoeBlock, @@ -532,18 +505,28 @@ class KVCacheTransform(ModuleMappingTransform): PhiDecoderLayer: QEffPhiDecoderLayer, PhiModel: QEffPhiModel, PhiForCausalLM: QEffPhiForCausalLM, - # Pixtral - PixtralVisionModel: QEffPixtralVisionModel, # Qwen2 Qwen2Attention: QEffQwen2Attention, Qwen2DecoderLayer: QEffQwen2DecoderLayer, Qwen2Model: QEffQwen2Model, Qwen2ForCausalLM: QEffQwen2ForCausalLM, - # Qwen3 - Qwen3Attention: 
QEffQwen3Attention, - Qwen3DecoderLayer: QEffQwen3DecoderLayer, - Qwen3Model: QEffQwen3Model, - Qwen3ForCausalLM: QEffQwen3ForCausalLM, + # Qwen2.5 VL + Qwen2_5_VLForConditionalGeneration: QEffQwen_2_5_vl_ForConditionalGeneration, + # Qwen2_5_VLModel: QEffQwen2_5_VLModel, + Qwen2_5_VLTextModel: QEffQwen2_5_VLTextModel, + Qwen2_5_VLAttention: QEffQwen2_5_VLAttention, + Qwen2_5_VLDecoderLayer: QEffQwen2_5_VLDecoderLayer, + Qwen2_5_VisionTransformerPretrainedModel: QEffQwen2_5_VisionTransformerPretrainedModel, + Qwen2_5_VLVisionAttention: QEffQwen2_5_VLVisionAttention, + # Qwen3vlmoe + Qwen3VLMoeForConditionalGeneration: QEffQwen3VLMoeForConditionalGeneration, + Qwen3VLMoeModel: QEffQwen3VLMoeModel, + Qwen3VLMoeTextAttention: QEffQwen3VLMoeTextAttention, + Qwen3VLMoeTextDecoderLayer: QEffQwen3VLMoeTextDecoderLayer, + Qwen3VLMoeVisionAttention: QEffQwen3VLMoeVisionAttention, + Qwen3VLMoeVisionModel: QEffQwen3VLMoeVisionModel, + Qwen3VLMoeTextModel: QEffQwen3VLMoeTextModel, + # Grok1 # Starcoder2 Starcoder2Attention: QEffStarcoder2Attention, Starcoder2DecoderLayer: QEFFStarcoder2DecoderLayer, @@ -554,11 +537,6 @@ class KVCacheTransform(ModuleMappingTransform): GPTBigCodeBlock: QEffGPTBigCodeBlock, GPTBigCodeModel: QEffGPTBigCodeModel, GPTBigCodeForCausalLM: QEffGPTBigCodeForCausalLM, - # Olmo2 - Olmo2Attention: QEffOlmo2Attention, - Olmo2DecoderLayer: QEffOlmo2DecoderLayer, - Olmo2Model: QEffOlmo2Model, - Olmo2ForCausalLM: QEffOlmo2ForCausalLM, # Whisper encoder and decoder layers WhisperPositionalEmbedding: QEffWhisperPositionalEmbedding, WhisperAttention: QEffWhisperAttention, @@ -595,7 +573,7 @@ class SpDTransform: # Llama QEffLlamaForCausalLM, QEffQwen2ForCausalLM, - QEffQwen3ForCausalLM, + # QEffQwen3ForCausalLM, } @classmethod @@ -693,32 +671,6 @@ class KVCacheExternalModuleMapperTransform(ExternalModuleMapperTransform): "get_qeff_language_decoder": QEffInternVLModel.get_qeff_language_decoder, }, "InternVisionEmbeddings": {"forward": QEffInternVisionEmbeddings.forward}, - # Mapping for Molmo - "MolmoForCausalLM": { - "forward": QEffMolmoModel.forward, - "get_qeff_vision_encoder": QEffMolmoModel.get_qeff_vision_encoder, - "get_qeff_language_decoder": QEffMolmoModel.get_qeff_language_decoder, - "get_specializations": QEffMolmoModel.get_specializations, - "get_onnx_dynamic_axes": QEffMolmoModel.get_onnx_dynamic_axes, - "get_output_names": QEffMolmoModel.get_output_names, - "get_dummy_inputs": QEffMolmoModel.get_dummy_inputs, - "get_inputs_info": QEffMolmoModel.get_inputs_info, - }, - "RMSLayerNorm": {"forward": CustomRMSNormAIC.forward}, - # "MolmoForCausalLM": {"forward": QEffMolmoForCausalLM.forward}, - "Molmo": {"forward": QEffMolmo.forward}, - "MolmoSequentialBlock": { - "forward": QEffMolmoSequentialBlock.forward, - "attention": QEffMolmoBlock.attention, - "__qeff_init__": QEffMolmoBlock.__qeff_init__, - }, - "MolmoBlock": { - "attention": QEffMolmoBlock.attention, - "__qeff_init__": QEffMolmoBlock.__qeff_init__, - }, - "MultiHeadDotProductAttention": { - "forward": QEffMultiHeadDotProductAttention.forward, - }, # Mapping for grok1 model "Grok1ModelForCausalLM": {"forward": QEffGrok1ModelForCausalLM.forward}, "Grok1Model": { diff --git a/QEfficient/transformers/models/qwen3_vl_moe/__init__.py b/QEfficient/transformers/models/qwen3_vl_moe/__init__.py new file mode 100644 index 000000000..d647b73a6 --- /dev/null +++ b/QEfficient/transformers/models/qwen3_vl_moe/__init__.py @@ -0,0 +1,6 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) 
Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- diff --git a/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py b/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py new file mode 100644 index 000000000..2c9b01d8f --- /dev/null +++ b/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py @@ -0,0 +1,988 @@ +import math +from typing import Callable, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers.cache_utils import Cache +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, +) +from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import ( + Qwen3VLMoeForConditionalGeneration, + Qwen3VLMoeModel, + Qwen3VLMoeModelOutputWithPast, + Qwen3VLMoeTextAttention, + Qwen3VLMoeTextConfig, + Qwen3VLMoeTextDecoderLayer, + Qwen3VLMoeTextModel, + Qwen3VLMoeTextRotaryEmbedding, + Qwen3VLMoeVisionAttention, + Qwen3VLMoeVisionModel, + apply_rotary_pos_emb_vision, + repeat_kv, + rotate_half, +) + +from QEfficient.transformers.cache_utils import QEffDynamicCache + +# from transformers import Qw +from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask +from QEfficient.utils import constants +from QEfficient.utils._utils import IOInfo, get_padding_shape_from_config +from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE +from QEfficient.utils.logging_utils import logger + + +def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, mrope_section, unsqueeze_dim=1): + """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/). + + Explanation: + Multimodal 3D rotary position embedding is an extension to 1D rotary position embedding. The input embedding + sequence contains vision (images / videos) embedding and text embedding or just contains text embedding. For + vision embedding part, we apply rotary position embedding on temporal, height and width dimension seperately. + Here we split the channel dimension to 3 chunks for the temporal, height and width rotary position embedding. + For text embedding part, we just apply 1D rotary position embedding. The three rotary position index (temporal, + height and width) of text embedding is always the same, so the text embedding rotary position embedding has no + difference with modern LLMs. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. + mrope_section(`List(int)`): + Multimodal rope section is for channel dimension of temporal, height and width in rope calculation. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. 
Then, if q and
+            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+    Returns:
+        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+    """
+
+    mrope_section = mrope_section * 2
+    cos = cos[position_ids]
+    sin = sin[position_ids]
+
+    cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(unsqueeze_dim)
+    sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(unsqueeze_dim)
+
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    return q_embed.to(q.dtype), k_embed.to(k.dtype)
+
+
+class QEffQwen3VLMoeVisionAttention(Qwen3VLMoeVisionAttention):
+    def __init__(self, dim: int, num_heads: int = 16) -> None:
+        super().__init__()
+        self.num_heads = num_heads
+        self.head_dim = dim // num_heads
+        self.qkv = nn.Linear(dim, dim * 3, bias=True)
+        self.proj = nn.Linear(dim, dim)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        cu_seqlens: torch.Tensor,
+        rotary_pos_emb: Optional[torch.Tensor] = None,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+    ) -> torch.Tensor:
+        seq_length = hidden_states.shape[0]
+        q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
+        if position_embeddings is None:
+            logger.warning_once(
+                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+                "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed "
+                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be "
+                "removed and `position_embeddings` will be mandatory."
+ ) + emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) + cos = emb.cos() + sin = emb.sin() + else: + cos, sin = position_embeddings + q, k = apply_rotary_pos_emb_vision(q, k, cos, sin) + + attention_mask = torch.full( + [1, seq_length, seq_length], torch.finfo(q.dtype).min, device=q.device, dtype=q.dtype + ) + + # Create index grids + seq_len = attention_mask.shape[-1] + rows = torch.arange(seq_len).view(1, -1) + cols = torch.arange(seq_len).view(-1, 1) + + # Prepare start and end indices + start = cu_seqlens[:-1].view(-1, 1, 1) + end = cu_seqlens[1:].view(-1, 1, 1) + + # Create block masks using broadcasting + row_mask = (rows >= start) & (rows < end) + col_mask = (cols >= start) & (cols < end) + block_mask = row_mask & col_mask # shape: (num_blocks, seq_len, seq_len) + + # Combine all blocks into one mask + final_mask = torch.ones((seq_len, seq_len), dtype=torch.float32) + final_mask[block_mask.any(dim=0)] = 0 + + final_mask = torch.where(final_mask == 1.0, torch.finfo(q.dtype).min, final_mask) + + attention_mask[0] = final_mask + + q = q.transpose(0, 1) + k = k.transpose(0, 1) + v = v.transpose(0, 1) + attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim) + attn_weights = attn_weights + attention_mask + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype) + attn_output = torch.matmul(attn_weights, v) + attn_output = attn_output.transpose(0, 1) + attn_output = attn_output.reshape(seq_length, -1) + attn_output = self.proj(attn_output) + return attn_output + + +class QEffQwen3VLMoeTextRotaryEmbedding(Qwen3VLMoeTextRotaryEmbedding): + """ + Copied from LlamaForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py + The only differences are: + - Add static sin/cos computations. + """ + + def __init__(self, config: Qwen3VLMoeTextConfig, device=None): + super().__init__(config=config) + # Build here to make `torch.jit.trace` work. 
+ self._set_cos_sin_cache( + seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=torch.get_default_dtype() + ) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) + + freqs = torch.outer(t, self.inv_freq) + + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + + return ( + self.cos_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling, + self.sin_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling, + ) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) / math.sqrt(module.head_dim) + if attention_mask is not None: + attn_weights = torch.where( + attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights + ) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class QEffQwen3VLMoeTextAttention(Qwen3VLMoeTextAttention): + """ + Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer + and "Generating Long Sequences with Sparse Transformers". 
+ """ + + def __qeff_init__(self): + self.rotary_emb = QEffQwen3VLMoeTextRotaryEmbedding(config=self.config) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + batch_index: Optional[torch.LongTensor] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + + # kv_seq_len = key_states.shape[-2] + # kv_seq_len = past_key_values.get_usable_length(kv_seq_len, self.layer_idx) + # breakpoint() + kv_seq_len = past_key_values.get_seq_length(self.layer_idx) + + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + # breakpoint() + query_states, key_states = qeff_apply_rotary_pos_emb( + query_states, key_states, cos, sin, position_ids[1:], self.config.rope_scaling["mrope_section"] + ) + + if past_key_values is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "batch_index": batch_index, "position_ids": position_ids[0]} + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + **kwargs, + ) + + attn_output = attn_output.reshape(bsz, q_len, -1) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + # breakpoint() + return attn_output, attn_weights, past_key_values + + +class QEffQwen3VLMoeTextDecoderLayer(Qwen3VLMoeTextDecoderLayer): + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[torch.Tensor]] = None, + batch_index: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + # position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). 
past_key_values (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+                Indices depicting the position of the input sequence tokens in the sequence.
+            position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
+                Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
+                with `head_dim` being the embedding dimension of each attention head.
+            kwargs (`dict`, *optional*):
+                Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
+                into the model
+        """
+
+        residual = hidden_states
+
+        hidden_states = self.input_layernorm(hidden_states)
+
+        # Self Attention
+        hidden_states, self_attn_weights, present_key_value = self.self_attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            batch_index=batch_index,
+            output_attentions=output_attentions,
+            use_cache=use_cache,
+            cache_position=cache_position,
+            # position_embeddings=position_embeddings,
+        )
+        hidden_states = residual + hidden_states
+
+        # Fully Connected
+        residual = hidden_states
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + hidden_states[0]
+
+        outputs = (hidden_states,)
+
+        if output_attentions:
+            outputs += (self_attn_weights,)
+
+        if use_cache:
+            outputs += (present_key_value,)
+        return outputs
+
+
+class QEffQwen3VLMoeTextModel(Qwen3VLMoeTextModel):
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        batch_index: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        visual_pos_masks: Optional[torch.Tensor] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        deepstack_visual_embeds: Optional[list[torch.Tensor]] = None,
+        **kwargs,
+    ) -> Union[Tuple, BaseModelOutputWithPast]:
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+
+        return_legacy_cache = False
+        if use_cache and not isinstance(past_key_values, Cache):
+            return_legacy_cache = True
+            past_key_values = QEffDynamicCache.from_legacy_cache(past_key_values)
+
+        if
cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + # the hard coded `3` is for temporal, height and width. + if position_ids is None: + position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1) + elif position_ids.dim() == 2: + position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) + + target_length = attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else past_seen_tokens + causal_mask = _create_causal_mask( + position_ids=position_ids[0], target_length=target_length, sliding_window=None + ) + + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + # all_self_attns = () if output_attentions else None + + # for decoder_layer in self.layers: + # if output_hidden_states: + # all_hidden_states += (hidden_states,) + # breakpoint() + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_values=past_key_values, + batch_index=batch_index, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = layer_outputs[0] + layer_idx = 0 + if deepstack_visual_embeds is not None and layer_idx in range(len(deepstack_visual_embeds)): + hidden_states = self._deepstack_process( + hidden_states, + visual_pos_masks, + deepstack_visual_embeds[layer_idx], + ) + + hidden_states = self.norm(hidden_states) + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if return_legacy_cache: + past_key_values = past_key_values.to_legacy_cache() + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + hidden_states=all_hidden_states, + ) + + +class QEffQwen3VLMoeVisionModel(Qwen3VLMoeVisionModel): + def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: + # merge_size = self.spatial_merge_size + + # max_hw = int(grid_thw[:, 1:].max().item()) + # freq_table = self.rotary_pos_emb(max_hw) # (max_hw, dim // 2) + # device = freq_table.device + + # total_tokens = int(torch.prod(grid_thw, dim=1).sum().item()) + # pos_ids = torch.empty((total_tokens, 2), dtype=torch.long, device=device) + + # offset = 0 + # for num_frames, height, width in grid_thw: + # merged_h, merged_w = height // merge_size, width // merge_size + + # block_rows = torch.arange(merged_h, device=device) # block row indices + # block_cols = torch.arange(merged_w, device=device) # block col indices + # intra_row = torch.arange(merge_size, device=device) # intra-block row offsets + # intra_col = torch.arange(merge_size, device=device) # intra-block col offsets + + # # Compute full-resolution positions + # row_idx = block_rows[:, None, None, None] * merge_size + intra_row[None, None, :, None] + # col_idx = block_cols[None, :, None, None] * merge_size + intra_col[None, None, None, :] + + # row_idx = row_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1) + # col_idx = col_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1) + + # coords = torch.stack((row_idx, col_idx), dim=-1) + + # if num_frames > 1: + # coords = coords.repeat(num_frames, 1) + + # num_tokens = coords.shape[0] + # 
pos_ids[offset : offset + num_tokens] = coords + # offset += num_tokens + + # embeddings = freq_table[pos_ids] # lookup rotary embeddings + # embeddings = embeddings.flatten(1) + # return embeddings + pos_ids = [] + # breakpoint() + bs, t, h, w = grid_thw.shape + + hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) + hpos_ids = hpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + hpos_ids = hpos_ids.permute(0, 2, 1, 3) + hpos_ids = hpos_ids.flatten() + + wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) + wpos_ids = wpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + wpos_ids = wpos_ids.permute(0, 2, 1, 3) + wpos_ids = wpos_ids.flatten() + pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + pos_ids = torch.cat(pos_ids, dim=0) + + x_expanded = pos_ids.unsqueeze(0) + x_expanded = x_expanded.expand(bs, -1, -1) + pos_ids = x_expanded.reshape(-1, pos_ids.size(1)) + + max_grid_size = max(grid_thw.shape) + rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) + rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) + return rotary_pos_emb + + def fast_pos_embed_interpolate(self, grid_thw): + # breakpoint() + # grid_ts, grid_hs, grid_ws = grid_thw[:, 0], grid_thw[:, 1], grid_thw[:, 2] + gridbs, grid_ts, grid_hs, grid_ws = grid_thw.shape + grid_ts = torch.tensor([grid_ts], device=grid_thw.device) + grid_hs = torch.tensor([grid_hs], device=grid_thw.device) + grid_ws = torch.tensor([grid_ws], device=grid_thw.device) + # breakpoint() + idx_list = [[] for _ in range(4)] + weight_list = [[] for _ in range(4)] + + for t, h, w in zip(grid_ts, grid_hs, grid_ws): + h_idxs = torch.linspace(0, self.num_grid_per_side - 1, h) + w_idxs = torch.linspace(0, self.num_grid_per_side - 1, w) + + h_idxs_floor = h_idxs.int() + w_idxs_floor = w_idxs.int() + h_idxs_ceil = (h_idxs.int() + 1).clip(max=self.num_grid_per_side - 1) + w_idxs_ceil = (w_idxs.int() + 1).clip(max=self.num_grid_per_side - 1) + + dh = h_idxs - h_idxs_floor + dw = w_idxs - w_idxs_floor + + base_h = h_idxs_floor * self.num_grid_per_side + base_h_ceil = h_idxs_ceil * self.num_grid_per_side + + indices = [ + (base_h[None].T + w_idxs_floor[None]).flatten(), + (base_h[None].T + w_idxs_ceil[None]).flatten(), + (base_h_ceil[None].T + w_idxs_floor[None]).flatten(), + (base_h_ceil[None].T + w_idxs_ceil[None]).flatten(), + ] + + weights = [ + ((1 - dh)[None].T * (1 - dw)[None]).flatten(), + ((1 - dh)[None].T * dw[None]).flatten(), + (dh[None].T * (1 - dw)[None]).flatten(), + (dh[None].T * dw[None]).flatten(), + ] + + for i in range(4): + idx_list[i].extend(indices[i].tolist()) + weight_list[i].extend(weights[i].tolist()) + # breakpoint() + idx_tensor = torch.tensor(idx_list, dtype=torch.long, device=self.pos_embed.weight.device) + weight_tensor = torch.tensor( + weight_list, dtype=self.pos_embed.weight.dtype, device=self.pos_embed.weight.device + ) + pos_embeds = self.pos_embed(idx_tensor) * weight_tensor[:, :, None] + patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3] + + patch_pos_embeds = patch_pos_embeds.split([h * w for h, w in zip(grid_hs, grid_ws)]) + + patch_pos_embeds_permute = [] + merge_size = self.config.spatial_merge_size + # breakpoint() + for pos_embed, t, h, w in zip(patch_pos_embeds, grid_ts, grid_hs, grid_ws): + pos_embed = pos_embed.repeat(t, 1) + pos_embed = ( + pos_embed.view(t, h // merge_size, 
merge_size, w // merge_size, merge_size, -1) + .permute(0, 1, 3, 2, 4, 5) + .flatten(0, 4) + ) + patch_pos_embeds_permute.append(pos_embed) + patch_pos_embeds = torch.cat(patch_pos_embeds_permute) + return patch_pos_embeds + + def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor: + # breakpoint() + hidden_states = self.patch_embed(hidden_states) + + pos_embeds = self.fast_pos_embed_interpolate(grid_thw) + hidden_states = hidden_states + pos_embeds + + rotary_pos_emb = self.rot_pos_emb(grid_thw) + + seq_len, _ = hidden_states.size() + hidden_states = hidden_states.reshape(seq_len, -1) + rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) + emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) + position_embeddings = (emb.cos(), emb.sin()) + + bs, t, h, w = grid_thw.shape + + t = torch.arange(t, t + 1).squeeze().expand(bs) + h = torch.arange(h, h + 1).squeeze().expand(bs) + w = torch.arange(w, w + 1).squeeze().expand(bs) + # cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( + # dim=0, + # # Select dtype based on the following factors: + # # - FA2 requires that cu_seqlens_q must have dtype int32 + # # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw + # # See https://github.com/huggingface/transformers/pull/34852 for more information + # dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, + # ) + cu_seqlens = (h * w).cumsum( + dim=0, + # Select dtype based on the following factors: + # - FA2 requires that cu_seqlens_q must have dtype int32 + # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw + # See https://github.com/huggingface/transformers/pull/34852 for more information + dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, + ) + cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) + + deepstack_feature_lists = [] + for layer_num, blk in enumerate(self.blocks): + hidden_states = blk( + hidden_states, + cu_seqlens=cu_seqlens, + position_embeddings=position_embeddings, + ) + if layer_num in self.deepstack_visual_indexes: + deepstack_feature = self.deepstack_merger_list[self.deepstack_visual_indexes.index(layer_num)]( + hidden_states + ) + deepstack_feature_lists.append(deepstack_feature) + + hidden_states = self.merger(hidden_states) + + return hidden_states, deepstack_feature_lists + + +class QEffQwen3VLMoeModel(Qwen3VLMoeModel): + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + batch_index: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + outputs = self.language_model( + input_ids=None, + position_ids=position_ids, + 
attention_mask=attention_mask,
+            past_key_values=past_key_values,
+            batch_index=batch_index,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=True,
+            cache_position=cache_position,
+            **kwargs,
+        )
+
+        output = Qwen3VLMoeModelOutputWithPast(
+            last_hidden_state=outputs.last_hidden_state,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+            rope_deltas=self.rope_deltas,
+        )
+        return output if return_dict else output.to_tuple()
+
+
+class QEffQwen3VLEncoderWrapper(nn.Module):
+    def __init__(self, model):
+        super().__init__()
+        self.model = model
+        self.model.vision_model = self.model.visual
+
+    def forward(self, pixel_values, image_grid_thw):
+        image_embeds = self.model.visual(pixel_values, grid_thw=image_grid_thw)[0]
+        bs = image_grid_thw.shape[0]
+        split_size = torch.floor_divide(torch.tensor(image_embeds.size(0)), bs)
+        image_embeds = image_embeds.reshape(bs, split_size, image_embeds.size(1))
+        return image_embeds
+
+
+class QEffQwen3VLDecoderWrapper(nn.Module):
+    def __init__(self, model):
+        super().__init__()
+        self.model = model
+        self.language_model = self.model.model
+
+    def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_values):
+        inputs_embeds = self.model.get_input_embeddings()(input_ids)
+        B, N, C = inputs_embeds.shape
+        selected = input_ids == self.model.config.image_token_id
+        indices1 = selected.to(torch.int64).cumsum(1) - 1
+        indices1 = torch.where(indices1 != -1, indices1 + image_idx, indices1)
+        indices0 = torch.arange(selected.unsqueeze(0).shape[0]).view(-1, 1)
+        image_features_expanded = vision_embeds.reshape(-1, C).unsqueeze(0)[indices0, indices1]
+        image_input_embeds = torch.where(selected.unsqueeze(-1), image_features_expanded, inputs_embeds)
+        inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), inputs_embeds, image_input_embeds)
+        outputs = self.model.model(
+            inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, use_cache=True
+        )
+        logits = self.model.lm_head(outputs[0])
+        image_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0)
+
+        return logits, vision_embeds, image_idx, outputs.past_key_values
+
+
+class QEffQwen3VLMoeForConditionalGeneration(Qwen3VLMoeForConditionalGeneration):
+    def get_qeff_vision_encoder(self):
+        return QEffQwen3VLEncoderWrapper(self)
+
+    def get_qeff_language_decoder(self):
+        return QEffQwen3VLDecoderWrapper(self)
+
+    def get_dummy_inputs(self, kv_offload: bool = False, **kwargs):
+        inputs_shapes = {}
+        inputs_shapes["input_ids"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN)
+
+        vision_size = 181
+        inputs_shapes["vision_embeds"] = (
+            constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE,
+            vision_size,
+            self.model.config.vision_config.out_hidden_size,
+        )
+        inputs_shapes["image_grid_thw"] = (1, 1, 22, 33)
+        inputs_shapes["position_ids"] = (
+            3,
+            constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE,
+            constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN,
+        )
+        inputs_shapes["pixel_values"] = (726, 1536)
+        inputs_shapes["image_idx"] = (1, 1)
+        inputs_shapes["image_sizes"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, 2)
+        # Define inputs
+        vision_inputs = {}
+        lang_inputs = {}
+        vision_inputs["pixel_values"] = torch.zeros((inputs_shapes["pixel_values"]), dtype=torch.float32)
+        vision_inputs["image_grid_thw"] =
torch.zeros((inputs_shapes["image_grid_thw"]), dtype=torch.int64)
+        lang_inputs["input_ids"] = torch.zeros((inputs_shapes["input_ids"]), dtype=torch.int64)
+        lang_inputs["vision_embeds"] = torch.zeros((inputs_shapes["vision_embeds"]), dtype=torch.float32)
+        lang_inputs["position_ids"] = (
+            (
+                torch.arange(constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, dtype=torch.int64)
+                .view(1, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN)
+                .repeat(constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, 1)
+            )
+            .unsqueeze(0)
+            .repeat(4, 1, 1)
+        )
+        lang_inputs["image_idx"] = torch.zeros((inputs_shapes["image_idx"]), dtype=torch.int64)
+        # Add data for KV
+        kv_cache_shape = get_padding_shape_from_config(
+            config=self.model.config.text_config,
+            batch_size=constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE,
+            seq_len=constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN,
+        )
+
+        lang_inputs["past_key_values"] = [[] for _ in range(self.model.config.text_config.num_hidden_layers)]
+        for i in range(self.model.config.text_config.num_hidden_layers):
+            for kv in ["key", "value"]:
+                lang_inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape, dtype=torch.float32))
+
+        inputs = {}
+        if kv_offload:
+            inputs["vision"] = vision_inputs
+            inputs["lang"] = lang_inputs
+        else:
+            lang_inputs.pop("vision_embeds")
+            inputs = {**vision_inputs, **lang_inputs}
+        return inputs
+
+    def get_specializations(
+        self,
+        batch_size: int,
+        prefill_seq_len: int,
+        ctx_len: int,
+        img_size: None,
+        height: int = None,
+        width: int = None,
+        kv_offload: bool = False,
+        **compiler_options,
+    ):
+        if height is None or width is None:
+            height = 1365
+            width = 2048
+            logger.warning(
+                "Setting height and width to be 1365 and 2048 respectively, as they were neither passed nor found in vision_config"
+            )
+        prefill_seq_len = prefill_seq_len if prefill_seq_len else 128
+        ctx_len = ctx_len if ctx_len else constants.INTERN_CTX_LEN
+        # channel = 3
+        patch_size = self.config.vision_config.patch_size
+        # temporal_patch_size = self.config.vision_config.temporal_patch_size
+
+        IMAGE_FACTOR = 28
+        MIN_PIXELS = 4 * 28 * 28
+        MAX_PIXELS = 16384 * 28 * 28
+        MAX_RATIO = 200
+
+        def round_by_factor(number: int, factor: int) -> int:
+            """Returns the closest integer to 'number' that is divisible by 'factor'."""
+            return round(number / factor) * factor
+
+        def ceil_by_factor(number: int, factor: int) -> int:
+            """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
+            return math.ceil(number / factor) * factor
+
+        def floor_by_factor(number: int, factor: int) -> int:
+            """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
+            return math.floor(number / factor) * factor
+
+        def smart_resize(
+            height: int,
+            width: int,
+            factor: int = IMAGE_FACTOR,
+            min_pixels: int = MIN_PIXELS,
+            max_pixels: int = MAX_PIXELS,
+        ) -> tuple[int, int]:
+            """
+            Rescales the image so that the following conditions are met:
+
+            1. Both dimensions (height and width) are divisible by 'factor'.
+
+            2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
+
+            3. The aspect ratio of the image is maintained as closely as possible.
+ """ + if max(height, width) / min(height, width) > MAX_RATIO: + raise ValueError( + f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}" + ) + h_bar = max(factor, round_by_factor(height, factor)) + w_bar = max(factor, round_by_factor(width, factor)) + if h_bar * w_bar > max_pixels: + beta = math.sqrt((height * width) / max_pixels) + h_bar = floor_by_factor(height / beta, factor) + w_bar = floor_by_factor(width / beta, factor) + elif h_bar * w_bar < min_pixels: + beta = math.sqrt(min_pixels / (height * width)) + h_bar = ceil_by_factor(height * beta, factor) + w_bar = ceil_by_factor(width * beta, factor) + return h_bar, w_bar + + breakpoint() + resized_height, resized_width = smart_resize(height=height, width=width) + grid_h, grid_w = resized_height // patch_size, resized_width // patch_size + grid_height = grid_h * grid_w + # grid_width = patch_size * patch_size * temporal_patch_size * channel + vision_size = grid_height // 4 + grid_height = grid_height * batch_size + + vision = [ + { + "batch_size": 1, + "vision_size": 181, + "grid_height": 22, + "grid_width": 33, + "grid_h": 726, + "grid_w": 1536, + } + ] + lang = [ + { + "batch_size": batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + "vision_size": vision_size, + }, + { + "batch_size": batch_size, + "seq_len": "1", + "ctx_len": ctx_len, + "vision_size": vision_size, + }, + ] + + specializations = {} + + if kv_offload: + specializations["vision"] = vision + specializations["lang"] = lang + return specializations, compiler_options + else: + return lang, compiler_options + + def get_onnx_dynamic_axes(self, kv_offload: bool = False): + # Define dynamic axes + num_layers = self.config.text_config.num_hidden_layers + # breakpoint() + vision_dynamic_axes = { + "pixel_values": {0: "grid_height", 1: "grid_width"}, + "image_grid_thw": {0: "batch_size", 2: "grid_h", 3: "grid_w"}, + } + + lang_dynamic_axes = { + "input_ids": {0: "batch_size", 1: "seq_len"}, + "position_ids": {1: "batch_size", 2: "seq_len"}, + "vision_embeds": {0: "batch_size", 1: "vision_size"}, + } + + for i in range(num_layers): + lang_dynamic_axes[f"past_key.{i}"] = {0: "batch_size", 2: "ctx_len"} + lang_dynamic_axes[f"past_value.{i}"] = {0: "batch_size", 2: "ctx_len"} + + dynamic_axes = {} + if kv_offload: + dynamic_axes["vision"] = vision_dynamic_axes + dynamic_axes["lang"] = lang_dynamic_axes + else: + lang_dynamic_axes.pop("vision_embeds") + dynamic_axes = {**vision_dynamic_axes, **lang_dynamic_axes} + return dynamic_axes + + def get_output_names(self, kv_offload: bool = False): + vision_output_names = ["vision_embeds"] + lang_output_names = ["logits"] + for i in range(self.model.config.text_config.num_hidden_layers): + for kv in ["key", "value"]: + lang_output_names.append(f"past_{kv}.{i}_RetainedState") + + output_names = {} + if kv_offload: + lang_output_names.insert(1, "vision_embeds_RetainedState") + lang_output_names.insert(2, "image_idx_output") + output_names["vision"] = vision_output_names + output_names["lang"] = lang_output_names + else: + lang_output_names.insert(1, "pixel_values_RetainedState") + lang_output_names.insert(2, "image_idx_output") + return lang_output_names + return output_names + + def get_inputs_info(self): + return [ + IOInfo(name="input_ids", datatype=torch.int64, shape=("batch_size", "seq_len")), + IOInfo(name="attention_mask", datatype=torch.int64, shape=("batch_size", "seq_len")), + IOInfo(name="pixel_values", datatype=torch.float32, shape=("batch_size", 3, "image_size", 
"image_size")), + ] diff --git a/examples/qwen3_vl_moe.py b/examples/qwen3_vl_moe.py new file mode 100644 index 000000000..08c31a937 --- /dev/null +++ b/examples/qwen3_vl_moe.py @@ -0,0 +1,183 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import requests +import torch +import torch.nn.functional as F +import transformers +from PIL import Image +from qwen_vl_utils import process_vision_info +from transformers import AutoConfig, AutoProcessor, TextStreamer + +from QEfficient import QEFFAutoModelForImageTextToText + +# model_id = "Qwen/Qwen2.5-VL-32B-Instruct" +model_id = "Qwen/Qwen3-VL-30B-A3B-Instruct" +config = AutoConfig.from_pretrained(model_id) + +# For Testing Purpose Only +# config.vision_config.num_hidden_layers = 1 +config.text_config.num_hidden_layers = 1 + +qeff_model = QEFFAutoModelForImageTextToText.from_pretrained( + model_id, attn_implementation="eager", kv_offload=True, config=config +) +tokenizer = transformers.AutoTokenizer.from_pretrained(model_id) +processor = AutoProcessor.from_pretrained(model_id) +# breakpoint() + +### use skip_vision=Ture, if want to run only text, ow false ### +skip_vision = False + +if skip_vision: + ## Only Text ## + + ## Set Batch_Size ## + batch_size = 2 + qeff_model.compile( + batch_size=batch_size, + prefill_seq_len=128, + ctx_len=4096, + num_cores=16, + num_devices=4, + height=354, + width=536, + mxfp6_matmul=False, + aic_enable_depth_first=True, + skip_vision=True, + mos=1, + ) + + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Tell me about yourself."}, + ], + }, + ] + + messages = [messages] * batch_size + + inputs = processor.apply_chat_template( + messages, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + return_tensors="pt", + ) + # breakpoint() + pos_ids, rope_deltas = qeff_model.model.get_rope_index( + inputs["input_ids"], + image_grid_thw=None, + video_grid_thw=None, + second_per_grid_ts=None, + attention_mask=inputs["attention_mask"], + ) + + input_ids_length = inputs["input_ids"].shape[1] + + inputs["position_ids"] = torch.cat([pos_ids, pos_ids[0].unsqueeze(0)], dim=0) + + prefill_seq_len = 128 + num_chunks = -(input_ids_length // -prefill_seq_len) # ceil divide without float + padded_len = num_chunks * prefill_seq_len # Convert to a multiple of prompt_len + + inputs["position_ids"] = F.pad( + inputs["position_ids"], pad=(0, padded_len - input_ids_length), mode="constant", value=-1 + ) + + streamer = TextStreamer(tokenizer) + output = qeff_model.generate(inputs=inputs, generation_len=100) + print(output.generated_ids) + print(tokenizer.batch_decode(output.generated_ids)) + print(output) + +else: + batch_size = 1 + ## Vision + Text ## + # breakpoint() + qeff_model.compile( + batch_size=batch_size, + prefill_seq_len=128, + ctx_len=4096, + num_cores=16, + num_devices=4, + height=354, + width=536, + mxfp6_matmul=True, + mxint8_kv_cache=True, + aic_enable_depth_first=True, + mos=1, + ) + + ### IMAGE + TEXT ### + image_url = "https://picsum.photos/id/237/536/354" + + image = Image.open(requests.get(image_url, stream=True).raw) + + messages_1 = [ + { + "role": "user", + "content": [ + {"type": "image", "image": image}, + {"type": "text", "text": "Describe this image."}, + ], + }, + ] + + messages_2 = [ + { + "role": "user", + "content": [ + {"type": 
"image", "image": image}, + {"type": "text", "text": "Describe about the color of the dog."}, + ], + }, + ] + + messages = [messages_2] * batch_size + + texts = [processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True) for msg in messages] + + image_inputs, video_inputs = process_vision_info(messages) + inputs = processor( + text=texts, + images=image_inputs, + videos=video_inputs, + padding=True, + return_tensors="pt", + ) + breakpoint() + input_ids_length = inputs["input_ids"].shape[1] + + inputs["position_ids"] = torch.arange(input_ids_length).view(1, 1, input_ids_length).expand(-1, batch_size, -1) + + pos_ids, rope_deltas = qeff_model.model.model.get_rope_index( + inputs["input_ids"], + inputs["image_grid_thw"], + video_grid_thw=None, + second_per_grid_ts=None, + attention_mask=inputs["attention_mask"], + ) + + inputs["position_ids"] = torch.cat((inputs["position_ids"], pos_ids), dim=0) + + prefill_seq_len = 128 + num_chunks = -(input_ids_length // -prefill_seq_len) # ceil divide without float + padded_len = num_chunks * prefill_seq_len # Convert to a multiple of prompt_len + + inputs["position_ids"] = F.pad( + inputs["position_ids"], pad=(0, padded_len - input_ids_length), mode="constant", value=-1 + ) + + inputs.pop("image_grid_thw") + streamer = TextStreamer(tokenizer) + output = qeff_model.generate(inputs=inputs, generation_len=100) + print(output.generated_ids) + print(tokenizer.batch_decode(output.generated_ids)) + print(output)