From 8f92c4d6ec7c56304260b193978911124304673f Mon Sep 17 00:00:00 2001 From: WangQing <2917021186@qq.com> Date: Wed, 4 Mar 2026 08:19:22 +0000 Subject: [PATCH 1/9] [Ascend] support qwen3.5 --- .../pytorch/backends/dlinfer/ascend/op_backend.py | 14 +++++++++++++- lmdeploy/pytorch/backends/dlinfer/attention.py | 1 + 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py b/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py index 169da9c150..8d63da1895 100644 --- a/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py +++ b/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py @@ -381,11 +381,20 @@ def get_moe_group_name(group): AscendKVQuantMeta.set_value(step_context.block_offsets.device, step_context.model_config.dtype, record_file, total_layers) + cu_seqlens = None + has_initial_state = None + + if step_context.state_offsets is not None: + q_start_loc = step_context.q_start_loc + cu_seqlens = torch.cat((q_start_loc, step_context.q_seqlens.sum().unsqueeze(0))).int() + if not step_context.is_decoding: + has_initial_state = ~(step_context.q_seqlens == step_context.kv_seqlens) + attn_meta_cls = cls.get_attention_metadata_cls() attn_metadata = attn_meta_cls( step_context.is_decoding, step_context.block_offsets, - q_start_loc=None, + q_start_loc=cu_seqlens, q_seqlens=q_seqlens_cpu, # kv_seqlens_expanded is only expanded in paged prefill, # otherwise it equals kv_seqlens_cpu @@ -398,6 +407,7 @@ def get_moe_group_name(group): max_kv_seq_len=max_kv_seq_len, quant_policy=step_context.kv_quant_policy, quant_meta=AscendKVQuantMeta.quant_meta, + has_initial_state=has_initial_state, ) step_context.attn_metadata = attn_metadata @@ -441,6 +451,8 @@ def init(): """Initialize Ascend backend.""" try: from torch_npu.contrib import transfer_to_npu # noqa: F401 + from dlinfer.vendor.ascend.triton_ops.fla.triton_utils import init_device_properties_triton + init_device_properties_triton() except ImportError: logger.warning('Failed to import torch_npu. Please make sure torch_npu is installed correctly. ' 'Ascend initialization skipped.') diff --git a/lmdeploy/pytorch/backends/dlinfer/attention.py b/lmdeploy/pytorch/backends/dlinfer/attention.py index 78afe49040..e0bead2889 100644 --- a/lmdeploy/pytorch/backends/dlinfer/attention.py +++ b/lmdeploy/pytorch/backends/dlinfer/attention.py @@ -18,6 +18,7 @@ class DlinferAttentionMetadata(AttentionMetadata): max_kv_seq_len: int = 1 quant_meta: dict = None cu_seq_lens_kv: Tensor | None = None + has_initial_state: Tensor | None = None class DlinferAttentionImpl(AttentionImpl[DlinferAttentionMetadata]): From 7950c134e8b5d51f943d7a452690093ca1389b3e Mon Sep 17 00:00:00 2001 From: WangQing <2917021186@qq.com> Date: Thu, 19 Mar 2026 08:39:56 +0000 Subject: [PATCH 2/9] fix: update import path for init_device_properties_triton Co-Authored-By: Claude Sonnet 4.6 --- lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py b/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py index 8d63da1895..0bf8d8f9e9 100644 --- a/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py +++ b/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py @@ -450,8 +450,8 @@ def build_graph_runner(model: torch.nn.Module, model_config: ModelConfig, cache_ def init(): """Initialize Ascend backend.""" try: + from dlinfer.vendor.ascend.triton_ops.triton_utils import init_device_properties_triton from torch_npu.contrib import transfer_to_npu # noqa: F401 - from dlinfer.vendor.ascend.triton_ops.fla.triton_utils import init_device_properties_triton init_device_properties_triton() except ImportError: logger.warning('Failed to import torch_npu. Please make sure torch_npu is installed correctly. ' From aed7e7ad789938ecd5de3aba746df6b4fd3097f8 Mon Sep 17 00:00:00 2001 From: WangQing <2917021186@qq.com> Date: Tue, 24 Mar 2026 08:19:27 +0000 Subject: [PATCH 3/9] [ascend] fix missing device_type --- lmdeploy/pytorch/config.py | 3 ++- lmdeploy/pytorch/configurations/qwen3_5.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/lmdeploy/pytorch/config.py b/lmdeploy/pytorch/config.py index e78b87e811..e3c2800e47 100644 --- a/lmdeploy/pytorch/config.py +++ b/lmdeploy/pytorch/config.py @@ -443,7 +443,8 @@ def from_hf_config( model_path, tp=tp, is_draft_model=is_draft_model, - spec_method=spec_method) + spec_method=spec_method, + device_type=device_type) if model_config.k_head_dim is None: assert model_config.head_dim is not None diff --git a/lmdeploy/pytorch/configurations/qwen3_5.py b/lmdeploy/pytorch/configurations/qwen3_5.py index 4c78705a25..0c12b7d93c 100644 --- a/lmdeploy/pytorch/configurations/qwen3_5.py +++ b/lmdeploy/pytorch/configurations/qwen3_5.py @@ -44,7 +44,8 @@ def build(cls, hf_config, model_path: str = None, tp: int = 1, **kwargs): conv_state_shape = (num_delta_layers, conv_dim, conv_kernel_size) recurrent_state_shape = (num_delta_layers, num_v_heads, head_k_dim, head_v_dim) - if is_bf16_supported(): + device_type = kwargs.get('device_type', 'cuda') + if is_bf16_supported(device_type): dtype = torch.bfloat16 else: dtype = torch.float16 From 446f859fd66b18025eb22a985ccad8722ee421e9 Mon Sep 17 00:00:00 2001 From: WangQing <2917021186@qq.com> Date: Tue, 7 Apr 2026 03:04:13 +0000 Subject: [PATCH 4/9] [ascend] refactor code --- lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py b/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py index 0bf8d8f9e9..319f5d8d5f 100644 --- a/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py +++ b/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py @@ -384,7 +384,8 @@ def get_moe_group_name(group): cu_seqlens = None has_initial_state = None - if step_context.state_offsets is not None: + is_gated_delta = step_context.model_config.is_gated_delta + if is_gated_delta: q_start_loc = step_context.q_start_loc cu_seqlens = torch.cat((q_start_loc, step_context.q_seqlens.sum().unsqueeze(0))).int() if not step_context.is_decoding: From d129b1498c98b792002ce01582031dc5be8e79da Mon Sep 17 00:00:00 2001 From: WangQing <2917021186@qq.com> Date: Tue, 7 Apr 2026 03:18:44 +0000 Subject: [PATCH 5/9] [ascend] add comment --- lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py b/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py index 319f5d8d5f..91019fdb29 100644 --- a/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py +++ b/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py @@ -395,6 +395,8 @@ def get_moe_group_name(group): attn_metadata = attn_meta_cls( step_context.is_decoding, step_context.block_offsets, + # cu_seqlens is only used in GDN and is passed down via q_start_loc. + # Otherwise, q_start_loc is None. q_start_loc=cu_seqlens, q_seqlens=q_seqlens_cpu, # kv_seqlens_expanded is only expanded in paged prefill, From e5db1536031a5cc92cdba1ecd939958f0653b9d1 Mon Sep 17 00:00:00 2001 From: WangQing <2917021186@qq.com> Date: Tue, 7 Apr 2026 04:48:38 +0000 Subject: [PATCH 6/9] [ascend] add comment --- .../backends/dlinfer/ascend/op_backend.py | 20 ++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py b/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py index 91019fdb29..62acb776d9 100644 --- a/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py +++ b/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py @@ -451,11 +451,15 @@ def build_graph_runner(model: torch.nn.Module, model_config: ModelConfig, cache_ @staticmethod def init(): - """Initialize Ascend backend.""" + """Initialize Ascend backend. + + Note: triton device properties initialization is only required for models that use + linear attention (e.g. Qwen3.5 35B). If triton-ascend is not installed, a warning + is emitted but non-linear-attention models are unaffected. + """ + try: - from dlinfer.vendor.ascend.triton_ops.triton_utils import init_device_properties_triton from torch_npu.contrib import transfer_to_npu # noqa: F401 - init_device_properties_triton() except ImportError: logger.warning('Failed to import torch_npu. Please make sure torch_npu is installed correctly. ' 'Ascend initialization skipped.') @@ -463,6 +467,16 @@ def init(): logger.warning(f'Error during Ascend initialization: {str(e)}. ' 'Please check your Ascend environment configuration.') + try: + from dlinfer.vendor.ascend.triton_ops.triton_utils import init_device_properties_triton + init_device_properties_triton() + except ImportError: + logger.warning('triton-ascend is not installed. Only linear attention models (e.g. Qwen3.5 35B) ' + 'require triton-ascend. Please install it with: pip install triton-ascend==3.2.0') + except Exception as e: + logger.warning(f'Error during Ascend initialization: {str(e)}. ' + 'Please check your Ascend environment configuration.') + @staticmethod def ccl_backend(): return 'hccl' From 9c3aae7f7776bbb4973d45e0a00828c60e517964 Mon Sep 17 00:00:00 2001 From: Qing Wang <49198408+wanfengcxz@users.noreply.github.com> Date: Tue, 7 Apr 2026 16:01:37 +0800 Subject: [PATCH 7/9] Fix syntax error in device_type assignment --- lmdeploy/pytorch/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lmdeploy/pytorch/config.py b/lmdeploy/pytorch/config.py index 56a9822e10..5fe5dc73f9 100644 --- a/lmdeploy/pytorch/config.py +++ b/lmdeploy/pytorch/config.py @@ -451,7 +451,7 @@ def from_hf_config( is_draft_model=is_draft_model, spec_method=spec_method, num_spec_tokens=num_spec_tokens, - device_type=device_type, + device_type=device_type, ) if model_config.k_head_dim is None: From c489d43861c878cf41b68fe4c25aa9436ded8f50 Mon Sep 17 00:00:00 2001 From: Qing Wang <49198408+wanfengcxz@users.noreply.github.com> Date: Tue, 7 Apr 2026 18:59:58 +0800 Subject: [PATCH 8/9] Apply suggestion from @Copilot Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- lmdeploy/pytorch/configurations/qwen3_5.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lmdeploy/pytorch/configurations/qwen3_5.py b/lmdeploy/pytorch/configurations/qwen3_5.py index 9f19fcd92f..ec4823cc5a 100644 --- a/lmdeploy/pytorch/configurations/qwen3_5.py +++ b/lmdeploy/pytorch/configurations/qwen3_5.py @@ -58,7 +58,7 @@ def build(cls, else: recurrent_state_shape = (num_delta_layers, num_v_heads, head_k_dim, head_v_dim) - device_type = kwargs.get('device_type', 'cuda') + device_type = kwargs.get('device_type', 'auto') if is_bf16_supported(device_type): dtype = torch.bfloat16 else: From 3f495215bfe24763394e58e8e2782d7d99ccb65a Mon Sep 17 00:00:00 2001 From: Qing Wang <49198408+wanfengcxz@users.noreply.github.com> Date: Tue, 7 Apr 2026 19:04:44 +0800 Subject: [PATCH 9/9] Apply suggestion from @Copilot Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py b/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py index 62acb776d9..8146760564 100644 --- a/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py +++ b/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py @@ -386,7 +386,8 @@ def get_moe_group_name(group): is_gated_delta = step_context.model_config.is_gated_delta if is_gated_delta: - q_start_loc = step_context.q_start_loc + q_start_loc = step_context.q_start_loc.to(dtype=step_context.q_seqlens.dtype, + device=step_context.q_seqlens.device) cu_seqlens = torch.cat((q_start_loc, step_context.q_seqlens.sum().unsqueeze(0))).int() if not step_context.is_decoding: has_initial_state = ~(step_context.q_seqlens == step_context.kv_seqlens)