40 changes: 34 additions & 6 deletions QEfficient/transformers/cache_utils.py
@@ -44,7 +44,7 @@ def read_only(self, cache_kwargs):
ctx_indices = torch.arange(ctx_len)[None, None, ...]
gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1)
invalid_mask = ctx_indices > gather_limit

breakpoint()
if torch.onnx.is_in_onnx_export():
invalid_idx_value = torch.iinfo(torch.int32).max
else:
@@ -113,6 +113,7 @@ def update(
Return:
A tuple containing the updated key and value states.
"""
# breakpoint()
# Update the cache
if self.keys is None:
self.keys = key_states
@@ -237,15 +238,41 @@ class QEffDynamicCache(DynamicCache):

"""

def __init__(self, ddp_cache_data: Optional[Iterable[tuple[torch.Tensor, torch.Tensor]]] = None, *args, **kwargs):
def __init__(
self,
ddp_cache_data: Optional[Iterable[tuple[torch.Tensor, torch.Tensor]]] = None,
config=None,
offloading: bool = False,
offload_only_non_sliding: bool = False,
*args,
**kwargs,
):
# Remove layer_classes if present to avoid duplicate argument
kwargs.pop("layer_classes", None)
kwargs.pop("layers", None)
from transformers.cache_utils import Cache # Import here to avoid circular import

Cache.__init__(self, layer_classes=QEffDynamicLayer, *args, **kwargs)
layers = []
if len(layers) == 0:
Cache.__init__(
self,
layer_class_to_replicate=QEffDynamicLayer,
offloading=offloading,
offload_only_non_sliding=offload_only_non_sliding,
)
else:
Cache.__init__(
self,
layers=layers,
offloading=offloading,
offload_only_non_sliding=offload_only_non_sliding,
)

if ddp_cache_data is not None:
for key_states, value_states in ddp_cache_data:
self.layers.append(QEffDynamicLayer.from_tensors(key_states, value_states))
for layer_idx, (key_states, value_states) in enumerate(ddp_cache_data):
# If the config was not passed above, initialize a DynamicLayer for each entry of the ddp_data
layers.append(QEffDynamicLayer())
# Update the layer with the data
_, _ = layers[layer_idx].update(key_states, value_states)

def read_only(self, layer_idx, cache_kwargs):
"""
@@ -260,6 +287,7 @@ def read_only(self, layer_idx, cache_kwargs):
Return:
A tuple containing the updated key and value states.
"""
# breakpoint()
return self.layers[layer_idx].read_only(cache_kwargs)

def write_only(self, key_states, value_states, layer_idx, cache_kwargs):
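
For context on the constructor rework above: recent transformers releases let Cache.__init__ take either a prebuilt list of layers or a layer_class_to_replicate that is instantiated lazily as update() is called per layer index. A minimal sketch of that upstream mechanism follows; the Cache and DynamicLayer names mirror transformers.cache_utils as used in the diff, but the tensor shapes and the exact lazy-growth behavior are assumptions that may vary across transformers versions.

import torch
from transformers.cache_utils import Cache, DynamicLayer

# A cache that replicates one DynamicLayer per layer index on demand,
# matching the layer_class_to_replicate branch in the diff above.
cache = Cache(layer_class_to_replicate=DynamicLayer)

# [batch, num_heads, seq_len, head_dim] -- illustrative shapes only.
key = value = torch.zeros(1, 8, 4, 64)
cache.update(key, value, layer_idx=0)  # lazily creates cache.layers[0]
assert len(cache.layers) == 1
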
176 changes: 64 additions & 112 deletions QEfficient/transformers/models/pytorch_transforms.py
@@ -102,11 +102,6 @@
MistralModel,
MistralRMSNorm,
)
from transformers.models.mistral3.modeling_mistral3 import (
Mistral3ForConditionalGeneration,
Mistral3Model,
Mistral3RMSNorm,
)
from transformers.models.mixtral.modeling_mixtral import (
MixtralAttention,
MixtralDecoderLayer,
@@ -129,13 +124,6 @@
MllamaVisionModel,
)
from transformers.models.mpt.modeling_mpt import MptAttention, MptBlock, MptForCausalLM, MptModel
from transformers.models.olmo2.modeling_olmo2 import (
Olmo2Attention,
Olmo2DecoderLayer,
Olmo2ForCausalLM,
Olmo2Model,
Olmo2RMSNorm,
)
from transformers.models.phi.modeling_phi import PhiAttention, PhiDecoderLayer, PhiForCausalLM, PhiModel
from transformers.models.phi3.modeling_phi3 import (
Phi3Attention,
@@ -144,29 +132,33 @@
Phi3Model,
Phi3RMSNorm,
)
from transformers.models.pixtral.modeling_pixtral import PixtralRMSNorm, PixtralVisionModel
from transformers.models.qwen2.modeling_qwen2 import (
Qwen2Attention,
Qwen2DecoderLayer,
Qwen2ForCausalLM,
Qwen2Model,
Qwen2RMSNorm,
)
from transformers.models.qwen3.modeling_qwen3 import (
Qwen3Attention,
Qwen3DecoderLayer,
Qwen3ForCausalLM,
Qwen3Model,
Qwen3RMSNorm,
)
from transformers.models.qwen3_moe.modeling_qwen3_moe import (
Qwen3MoeAttention,
Qwen3MoeDecoderLayer,
Qwen3MoeForCausalLM,
Qwen3MoeModel,
Qwen3MoeRMSNorm,
Qwen3MoeRotaryEmbedding,
Qwen3MoeSparseMoeBlock,
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
Qwen2_5_VisionTransformerPretrainedModel,
Qwen2_5_VLAttention,
Qwen2_5_VLDecoderLayer,
Qwen2_5_VLForConditionalGeneration,
Qwen2_5_VLTextModel,
Qwen2_5_VLVisionAttention,
)
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
Qwen2RMSNorm as Qwen2_5RMSNorm,
)
from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import (
Qwen3VLMoeForConditionalGeneration,
Qwen3VLMoeModel,
Qwen3VLMoeTextAttention,
Qwen3VLMoeTextDecoderLayer,
Qwen3VLMoeTextModel,
Qwen3VLMoeTextRMSNorm,
Qwen3VLMoeVisionAttention,
Qwen3VLMoeVisionModel,
)
from transformers.models.starcoder2.modeling_starcoder2 import (
Starcoder2Attention,
@@ -294,11 +286,6 @@
QEffMistralForCausalLM,
QEffMistralModel,
)
from QEfficient.transformers.models.mistral3.modeling_mistral3 import (
QEffMistral3ForConditionalGeneration,
QEffMistral3Model,
QEffPixtralVisionModel,
)
from QEfficient.transformers.models.mixtral_moe.modeling_mixtral import (
QEffMixtralAttention,
QeffMixtralDecoderLayer,
@@ -319,25 +306,12 @@
QEffMllamaTextSelfAttention,
QEffMllamaVisionModel,
)
from QEfficient.transformers.models.molmo.modeling_molmo import (
QEffMolmo,
QEffMolmoBlock,
QEffMolmoModel,
QEffMolmoSequentialBlock,
QEffMultiHeadDotProductAttention,
)
from QEfficient.transformers.models.mpt.modeling_mpt import (
QEffMptAttention,
QEffMptBlock,
QEffMptForCausalLM,
QEFfMptModel,
)
from QEfficient.transformers.models.olmo2.modeling_olmo2 import (
QEffOlmo2Attention,
QEffOlmo2DecoderLayer,
QEffOlmo2ForCausalLM,
QEffOlmo2Model,
)
from QEfficient.transformers.models.phi.modeling_phi import (
QEffPhiAttention,
QEffPhiDecoderLayer,
@@ -356,19 +330,23 @@
QEffQwen2ForCausalLM,
QEffQwen2Model,
)
from QEfficient.transformers.models.qwen3.modeling_qwen3 import (
QEffQwen3Attention,
QEffQwen3DecoderLayer,
QEffQwen3ForCausalLM,
QEffQwen3Model,
)
from QEfficient.transformers.models.qwen3_moe.modeling_qwen3_moe import (
QEffQwen3MoeAttention,
QEffQwen3MoeDecoderLayer,
QEffQwen3MoeForCausalLM,
QEffQwen3MoeModel,
QEffQwen3MoeRotaryEmbedding,
QEffQwen3MoeSparseMoeBlock,
from QEfficient.transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
QEffQwen2_5_VisionTransformerPretrainedModel,
QEffQwen2_5_VLAttention,
QEffQwen2_5_VLDecoderLayer,
QEffQwen2_5_VLTextModel,
# QEffQwen2_5_VLModel,
QEffQwen2_5_VLVisionAttention,
QEffQwen_2_5_vl_ForConditionalGeneration,
)
from QEfficient.transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import (
QEffQwen3VLMoeForConditionalGeneration,
QEffQwen3VLMoeModel,
QEffQwen3VLMoeTextAttention,
QEffQwen3VLMoeTextDecoderLayer,
QEffQwen3VLMoeTextModel,
QEffQwen3VLMoeVisionAttention,
QEffQwen3VLMoeVisionModel,
)
from QEfficient.transformers.models.starcoder2.modeling_starcoder2 import (
QEffStarcoder2Attention,
@@ -399,18 +377,16 @@ class CustomOpsTransform(ModuleMappingTransform):
LlamaRMSNorm: CustomRMSNormAIC,
Llama4TextRMSNorm: CustomRMSNormAIC,
MistralRMSNorm: CustomRMSNormAIC,
Mistral3RMSNorm: CustomRMSNormAIC,
MixtralRMSNorm: CustomRMSNormAIC,
Phi3RMSNorm: CustomRMSNormAIC,
Qwen2RMSNorm: CustomRMSNormAIC,
Qwen3RMSNorm: CustomRMSNormAIC,
Qwen2_5RMSNorm: CustomRMSNormAIC,
MllamaTextRMSNorm: CustomRMSNormAIC,
GraniteRMSNorm: CustomRMSNormAIC,
PixtralRMSNorm: CustomRMSNormAIC,
GraniteMoeRMSNorm: CustomRMSNormAIC,
Qwen3MoeRMSNorm: CustomRMSNormAIC,
Qwen3VLMoeTextRMSNorm: CustomRMSNormAIC,
Gemma3RMSNorm: QEffGemma3CustomRMSNormAIC,
Olmo2RMSNorm: CustomRMSNormAIC,
# Qwen3VLMoeTextRMSNorm: CustomRMSNormAIC,
}


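The RMSNorm table above feeds the module-mapping machinery. A rough sketch of that mechanism, assuming the transform re-points each matched module's class in place while keeping its parameters (illustrative, not the exact QEfficient implementation):

import torch.nn as nn

def apply_module_mapping(model: nn.Module, mapping: dict) -> tuple:
    # Walk every submodule; swap the instance's class so its forward()
    # changes while weights and buffers are preserved untouched.
    transformed = False
    for module in model.modules():
        replacement = mapping.get(type(module))
        if replacement is not None:
            module.__class__ = replacement
            transformed = True
    return model, transformed
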
@@ -463,12 +439,12 @@ class KVCacheTransform(ModuleMappingTransform):
GemmaModel: QEffGemmaModel,
GemmaForCausalLM: QEffGemmaForCausalLM,
# Qwen3Moe
Qwen3MoeForCausalLM: QEffQwen3MoeForCausalLM,
Qwen3MoeModel: QEffQwen3MoeModel,
Qwen3MoeDecoderLayer: QEffQwen3MoeDecoderLayer,
Qwen3MoeAttention: QEffQwen3MoeAttention,
Qwen3MoeRotaryEmbedding: QEffQwen3MoeRotaryEmbedding,
Qwen3MoeSparseMoeBlock: QEffQwen3MoeSparseMoeBlock,
# Qwen3MoeForCausalLM: QEffQwen3MoeForCausalLM,
# Qwen3MoeModel: QEffQwen3MoeModel,
# Qwen3MoeDecoderLayer: QEffQwen3MoeDecoderLayer,
# Qwen3MoeAttention: QEffQwen3MoeAttention,
# Qwen3MoeRotaryEmbedding: QEffQwen3MoeRotaryEmbedding,
# Qwen3MoeSparseMoeBlock: QEffQwen3MoeSparseMoeBlock,
# Gemma2
Gemma2Attention: QEffGemma2Attention,
Gemma2DecoderLayer: QEffGemma2DecoderLayer,
@@ -508,9 +484,6 @@ class KVCacheTransform(ModuleMappingTransform):
MistralDecoderLayer: QEffMistralDecoderLayer,
MistralModel: QEffMistralModel,
MistralForCausalLM: QEffMistralForCausalLM,
# Mistral3
Mistral3ForConditionalGeneration: QEffMistral3ForConditionalGeneration,
Mistral3Model: QEffMistral3Model,
# Mixtral
MixtralAttention: QEffMixtralAttention,
MixtralSparseMoeBlock: QEffMixtralSparseMoeBlock,
@@ -532,18 +505,28 @@ class KVCacheTransform(ModuleMappingTransform):
PhiDecoderLayer: QEffPhiDecoderLayer,
PhiModel: QEffPhiModel,
PhiForCausalLM: QEffPhiForCausalLM,
# Pixtral
PixtralVisionModel: QEffPixtralVisionModel,
# Qwen2
Qwen2Attention: QEffQwen2Attention,
Qwen2DecoderLayer: QEffQwen2DecoderLayer,
Qwen2Model: QEffQwen2Model,
Qwen2ForCausalLM: QEffQwen2ForCausalLM,
# Qwen3
Qwen3Attention: QEffQwen3Attention,
Qwen3DecoderLayer: QEffQwen3DecoderLayer,
Qwen3Model: QEffQwen3Model,
Qwen3ForCausalLM: QEffQwen3ForCausalLM,
# Qwen2.5 VL
Qwen2_5_VLForConditionalGeneration: QEffQwen_2_5_vl_ForConditionalGeneration,
# Qwen2_5_VLModel: QEffQwen2_5_VLModel,
Qwen2_5_VLTextModel: QEffQwen2_5_VLTextModel,
Qwen2_5_VLAttention: QEffQwen2_5_VLAttention,
Qwen2_5_VLDecoderLayer: QEffQwen2_5_VLDecoderLayer,
Qwen2_5_VisionTransformerPretrainedModel: QEffQwen2_5_VisionTransformerPretrainedModel,
Qwen2_5_VLVisionAttention: QEffQwen2_5_VLVisionAttention,
# Qwen3vlmoe
Qwen3VLMoeForConditionalGeneration: QEffQwen3VLMoeForConditionalGeneration,
Qwen3VLMoeModel: QEffQwen3VLMoeModel,
Qwen3VLMoeTextAttention: QEffQwen3VLMoeTextAttention,
Qwen3VLMoeTextDecoderLayer: QEffQwen3VLMoeTextDecoderLayer,
Qwen3VLMoeVisionAttention: QEffQwen3VLMoeVisionAttention,
Qwen3VLMoeVisionModel: QEffQwen3VLMoeVisionModel,
Qwen3VLMoeTextModel: QEffQwen3VLMoeTextModel,
# Grok1
# Starcoder2
Starcoder2Attention: QEffStarcoder2Attention,
Starcoder2DecoderLayer: QEFFStarcoder2DecoderLayer,
@@ -554,11 +537,6 @@ class KVCacheTransform(ModuleMappingTransform):
GPTBigCodeBlock: QEffGPTBigCodeBlock,
GPTBigCodeModel: QEffGPTBigCodeModel,
GPTBigCodeForCausalLM: QEffGPTBigCodeForCausalLM,
# Olmo2
Olmo2Attention: QEffOlmo2Attention,
Olmo2DecoderLayer: QEffOlmo2DecoderLayer,
Olmo2Model: QEffOlmo2Model,
Olmo2ForCausalLM: QEffOlmo2ForCausalLM,
# Whisper encoder and decoder layers
WhisperPositionalEmbedding: QEffWhisperPositionalEmbedding,
WhisperAttention: QEffWhisperAttention,
@@ -595,7 +573,7 @@ class SpDTransform:
# Llama
QEffLlamaForCausalLM,
QEffQwen2ForCausalLM,
QEffQwen3ForCausalLM,
# QEffQwen3ForCausalLM,
}

@classmethod
@@ -693,32 +671,6 @@ class KVCacheExternalModuleMapperTransform(ExternalModuleMapperTransform):
"get_qeff_language_decoder": QEffInternVLModel.get_qeff_language_decoder,
},
"InternVisionEmbeddings": {"forward": QEffInternVisionEmbeddings.forward},
# Mapping for Molmo
"MolmoForCausalLM": {
"forward": QEffMolmoModel.forward,
"get_qeff_vision_encoder": QEffMolmoModel.get_qeff_vision_encoder,
"get_qeff_language_decoder": QEffMolmoModel.get_qeff_language_decoder,
"get_specializations": QEffMolmoModel.get_specializations,
"get_onnx_dynamic_axes": QEffMolmoModel.get_onnx_dynamic_axes,
"get_output_names": QEffMolmoModel.get_output_names,
"get_dummy_inputs": QEffMolmoModel.get_dummy_inputs,
"get_inputs_info": QEffMolmoModel.get_inputs_info,
},
"RMSLayerNorm": {"forward": CustomRMSNormAIC.forward},
# "MolmoForCausalLM": {"forward": QEffMolmoForCausalLM.forward},
"Molmo": {"forward": QEffMolmo.forward},
"MolmoSequentialBlock": {
"forward": QEffMolmoSequentialBlock.forward,
"attention": QEffMolmoBlock.attention,
"__qeff_init__": QEffMolmoBlock.__qeff_init__,
},
"MolmoBlock": {
"attention": QEffMolmoBlock.attention,
"__qeff_init__": QEffMolmoBlock.__qeff_init__,
},
"MultiHeadDotProductAttention": {
"forward": QEffMultiHeadDotProductAttention.forward,
},
# Mapping for grok1 model
"Grok1ModelForCausalLM": {"forward": QEffGrok1ModelForCausalLM.forward},
"Grok1Model": {
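
Since Qwen3-VL-MoE is the main addition to these mappings, a hypothetical end-to-end use might look like the lines below. The model id, the AutoModelForImageTextToText entry point, and the classmethod-style apply() signature are all assumptions for illustration; the actual QEfficient workflow may differ.

from transformers import AutoModelForImageTextToText

from QEfficient.transformers.models.pytorch_transforms import CustomOpsTransform, KVCacheTransform

# Hypothetical checkpoint id; any Qwen3-VL-MoE checkpoint would do.
model = AutoModelForImageTextToText.from_pretrained("Qwen/Qwen3-VL-30B-A3B-Instruct")
model, _ = CustomOpsTransform.apply(model)  # e.g. Qwen3VLMoeTextRMSNorm -> CustomRMSNormAIC
model, _ = KVCacheTransform.apply(model)    # e.g. Qwen3VLMoeTextAttention -> QEffQwen3VLMoeTextAttention
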
6 changes: 6 additions & 0 deletions QEfficient/transformers/models/qwen3_vl_moe/__init__.py
@@ -0,0 +1,6 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------