diff --git a/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py b/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py
index 169da9c150..8146760564 100644
--- a/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py
+++ b/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py
@@ -381,11 +381,24 @@ def get_moe_group_name(group):
             AscendKVQuantMeta.set_value(step_context.block_offsets.device, step_context.model_config.dtype,
                                         record_file, total_layers)
 
+        cu_seqlens = None
+        has_initial_state = None
+
+        is_gated_delta = step_context.model_config.is_gated_delta
+        if is_gated_delta:
+            q_start_loc = step_context.q_start_loc.to(dtype=step_context.q_seqlens.dtype,
+                                                      device=step_context.q_seqlens.device)
+            cu_seqlens = torch.cat((q_start_loc, step_context.q_seqlens.sum().unsqueeze(0))).int()
+            if not step_context.is_decoding:
+                has_initial_state = ~(step_context.q_seqlens == step_context.kv_seqlens)
+
         attn_meta_cls = cls.get_attention_metadata_cls()
         attn_metadata = attn_meta_cls(
             step_context.is_decoding,
             step_context.block_offsets,
-            q_start_loc=None,
+            # cu_seqlens is only used in GDN and is passed down via q_start_loc.
+            # Otherwise, q_start_loc is None.
+            q_start_loc=cu_seqlens,
             q_seqlens=q_seqlens_cpu,
             # kv_seqlens_expanded is only expanded in paged prefill,
             # otherwise it equals kv_seqlens_cpu
@@ -398,6 +411,7 @@ def get_moe_group_name(group):
             max_kv_seq_len=max_kv_seq_len,
             quant_policy=step_context.kv_quant_policy,
             quant_meta=AscendKVQuantMeta.quant_meta,
+            has_initial_state=has_initial_state,
         )
 
         step_context.attn_metadata = attn_metadata
@@ -438,7 +452,13 @@ def build_graph_runner(model: torch.nn.Module, model_config: ModelConfig, cache_
 
     @staticmethod
     def init():
-        """Initialize Ascend backend."""
+        """Initialize Ascend backend.
+
+        Note: triton device properties initialization is only required for models that use
+        linear attention (e.g. Qwen3.5 35B). If triton-ascend is not installed, a warning
+        is emitted but non-linear-attention models are unaffected.
+        """
+
         try:
             from torch_npu.contrib import transfer_to_npu  # noqa: F401
         except ImportError:
@@ -448,6 +468,16 @@ def init():
             logger.warning(f'Error during Ascend initialization: {str(e)}. '
                            'Please check your Ascend environment configuration.')
 
+        try:
+            from dlinfer.vendor.ascend.triton_ops.triton_utils import init_device_properties_triton
+            init_device_properties_triton()
+        except ImportError:
+            logger.warning('triton-ascend is not installed. Only linear attention models (e.g. Qwen3.5 35B) '
+                           'require triton-ascend. Please install it with: pip install triton-ascend==3.2.0')
+        except Exception as e:
+            logger.warning(f'Error during Ascend initialization: {str(e)}. '
+                           'Please check your Ascend environment configuration.')
+
     @staticmethod
     def ccl_backend():
         return 'hccl'
diff --git a/lmdeploy/pytorch/backends/dlinfer/attention.py b/lmdeploy/pytorch/backends/dlinfer/attention.py
index 78afe49040..e0bead2889 100644
--- a/lmdeploy/pytorch/backends/dlinfer/attention.py
+++ b/lmdeploy/pytorch/backends/dlinfer/attention.py
@@ -18,6 +18,7 @@ class DlinferAttentionMetadata(AttentionMetadata):
     max_kv_seq_len: int = 1
     quant_meta: dict = None
     cu_seq_lens_kv: Tensor | None = None
+    has_initial_state: Tensor | None = None
 
 
 class DlinferAttentionImpl(AttentionImpl[DlinferAttentionMetadata]):
diff --git a/lmdeploy/pytorch/config.py b/lmdeploy/pytorch/config.py
index 776f6e0745..5fe5dc73f9 100644
--- a/lmdeploy/pytorch/config.py
+++ b/lmdeploy/pytorch/config.py
@@ -451,6 +451,7 @@ def from_hf_config(
             is_draft_model=is_draft_model,
             spec_method=spec_method,
             num_spec_tokens=num_spec_tokens,
+            device_type=device_type,
         )
 
         if model_config.k_head_dim is None:
diff --git a/lmdeploy/pytorch/configurations/qwen3_5.py b/lmdeploy/pytorch/configurations/qwen3_5.py
index 9746a7949b..ec4823cc5a 100644
--- a/lmdeploy/pytorch/configurations/qwen3_5.py
+++ b/lmdeploy/pytorch/configurations/qwen3_5.py
@@ -58,7 +58,8 @@ def build(cls,
         else:
             recurrent_state_shape = (num_delta_layers, num_v_heads,
                                      head_k_dim, head_v_dim)
-        if is_bf16_supported():
+        device_type = kwargs.get('device_type', 'auto')
+        if is_bf16_supported(device_type):
             dtype = torch.bfloat16
         else:
             dtype = torch.float16