diff --git a/vllm_ascend/spec_decode/mtp_proposer.py b/vllm_ascend/spec_decode/mtp_proposer.py
index 17274a51fb1..207deef160e 100644
--- a/vllm_ascend/spec_decode/mtp_proposer.py
+++ b/vllm_ascend/spec_decode/mtp_proposer.py
@@ -12,6 +12,7 @@
 from vllm.model_executor.model_loader.utils import \
     process_weights_after_loading
 from vllm.model_executor.models.deepseek_mtp import DeepSeekMTP
+from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache
 from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
 from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
                                               CommonAttentionMetadata)
@@ -129,6 +130,9 @@ def load_model(self, model) -> None:
         target_attn_layer_names = set(
             get_layers_from_vllm_config(self.vllm_config,
                                         AttentionLayerBase).keys())
+        target_indexer_layer_names = set(
+            get_layers_from_vllm_config(self.vllm_config,
+                                        DeepseekV32IndexerCache).keys())
         draft_model_config = \
             self.vllm_config.speculative_config.draft_model_config
         target_device = self.vllm_config.device_config.device
@@ -142,6 +146,13 @@ def load_model(self, model) -> None:
         draft_attn_layer_names = (get_layers_from_vllm_config(
             self.vllm_config, AttentionLayerBase).keys() -
                                   target_attn_layer_names)
+        indexer_layers = get_layers_from_vllm_config(self.vllm_config,
+                                                     DeepseekV32IndexerCache)
+        draft_indexer_layer_names = indexer_layers.keys(
+        ) - target_indexer_layer_names
+        # NOTE: There is no dedicated attention backend or attention metadata
+        # for the DeepSeek V3.2 indexer yet, so exclude the indexer layers here.
+        draft_attn_layer_names = draft_attn_layer_names - draft_indexer_layer_names

         assert len(draft_attn_layer_names) == 1
         self.attn_layer_name = list(draft_attn_layer_names)
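
The core of this change is plain set arithmetic over layer names: snapshot the
target model's attention-like layer names before the draft (MTP) model is
instantiated, subtract that snapshot afterwards to isolate the draft model's
own layers, and then remove the draft indexer layers, which register under
AttentionLayerBase but have no dedicated attention backend or metadata yet.
The sketch below illustrates just that filtering logic in isolation;
layers_of, Attention, and IndexerCache are illustrative stand-ins for vLLM's
get_layers_from_vllm_config registry, not the real API.

# Minimal sketch of the layer-name set arithmetic in load_model().
# Attention, IndexerCache, and layers_of are hypothetical stand-ins,
# not vLLM's actual classes or get_layers_from_vllm_config().
from typing import Dict, Type


class Attention:
    """Stand-in for vLLM's AttentionLayerBase."""


class IndexerCache(Attention):
    """Stand-in for DeepseekV32IndexerCache; it registers as an
    attention-like layer too, which is why the patch subtracts it."""


def layers_of(registry: Dict[str, Attention],
              kind: Type[Attention]) -> Dict[str, Attention]:
    """Stand-in for get_layers_from_vllm_config(): name -> layer of `kind`."""
    return {
        name: layer
        for name, layer in registry.items() if isinstance(layer, kind)
    }


# 1) Target-model layers exist first; snapshot their names.
registry: Dict[str, Attention] = {
    "model.layers.0.self_attn": Attention(),
    "model.layers.0.indexer.k_cache": IndexerCache(),
}
target_attn_layer_names = set(layers_of(registry, Attention))
target_indexer_layer_names = set(layers_of(registry, IndexerCache))

# 2) Loading the draft (MTP) model adds its layers to the same registry.
registry["model.layers.1.self_attn"] = Attention()
registry["model.layers.1.indexer.k_cache"] = IndexerCache()

# 3) Set differences isolate the draft-only layers; then drop the indexer
#    layers, which lack a dedicated attention backend/metadata for now.
draft_attn_layer_names = (layers_of(registry, Attention).keys() -
                          target_attn_layer_names)
draft_indexer_layer_names = (layers_of(registry, IndexerCache).keys() -
                             target_indexer_layer_names)
draft_attn_layer_names = draft_attn_layer_names - draft_indexer_layer_names

assert draft_attn_layer_names == {"model.layers.1.self_attn"}

Because the indexer cache layers are discovered through the same
AttentionLayerBase query as ordinary attention layers, skipping the second
subtraction would leave two names in draft_attn_layer_names and trip the
`assert len(draft_attn_layer_names) == 1` that follows in the patch.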