Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 32 additions & 2 deletions lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,11 +381,24 @@ def get_moe_group_name(group):
AscendKVQuantMeta.set_value(step_context.block_offsets.device, step_context.model_config.dtype,
record_file, total_layers)

cu_seqlens = None
has_initial_state = None

is_gated_delta = step_context.model_config.is_gated_delta
if is_gated_delta:
q_start_loc = step_context.q_start_loc.to(dtype=step_context.q_seqlens.dtype,
device=step_context.q_seqlens.device)
cu_seqlens = torch.cat((q_start_loc, step_context.q_seqlens.sum().unsqueeze(0))).int()
if not step_context.is_decoding:
has_initial_state = ~(step_context.q_seqlens == step_context.kv_seqlens)

attn_meta_cls = cls.get_attention_metadata_cls()
attn_metadata = attn_meta_cls(
step_context.is_decoding,
step_context.block_offsets,
q_start_loc=None,
# cu_seqlens is only built for gated delta (GDN) models and is passed down via q_start_loc.
# For all other models it stays None, so q_start_loc is None.
q_start_loc=cu_seqlens,
q_seqlens=q_seqlens_cpu,
# kv_seqlens_expanded is only expanded in paged prefill,
# otherwise it equals kv_seqlens_cpu
Expand All @@ -398,6 +411,7 @@ def get_moe_group_name(group):
max_kv_seq_len=max_kv_seq_len,
quant_policy=step_context.kv_quant_policy,
quant_meta=AscendKVQuantMeta.quant_meta,
has_initial_state=has_initial_state,
)
step_context.attn_metadata = attn_metadata

Expand Down Expand Up @@ -438,7 +452,13 @@ def build_graph_runner(model: torch.nn.Module, model_config: ModelConfig, cache_

@staticmethod
def init():
"""Initialize Ascend backend."""
"""Initialize Ascend backend.

Note: triton device properties initialization is only required for models that use
linear attention (e.g. Qwen3.5 35B). If triton-ascend is not installed, a warning
is emitted but non-linear-attention models are unaffected.
"""

try:
from torch_npu.contrib import transfer_to_npu # noqa: F401
except ImportError:
Expand All @@ -448,6 +468,16 @@ def init():
logger.warning(f'Error during Ascend initialization: {str(e)}. '
'Please check your Ascend environment configuration.')

try:
from dlinfer.vendor.ascend.triton_ops.triton_utils import init_device_properties_triton
init_device_properties_triton()
except ImportError:
logger.warning('triton-ascend is not installed. Only linear attention models (e.g. Qwen3.5 35B) '
'require triton-ascend. Please install it with: pip install triton-ascend==3.2.0')
Comment thread
wanfengcxz marked this conversation as resolved.
except Exception as e:
logger.warning(f'Error during Ascend initialization: {str(e)}. '
'Please check your Ascend environment configuration.')

@staticmethod
def ccl_backend():
    """Return the name of the CCL backend for this op backend: ``'hccl'``."""
    return 'hccl'
Expand Down
1 change: 1 addition & 0 deletions lmdeploy/pytorch/backends/dlinfer/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ class DlinferAttentionMetadata(AttentionMetadata):
max_kv_seq_len: int = 1
quant_meta: dict = None
cu_seq_lens_kv: Tensor | None = None
has_initial_state: Tensor | None = None
Comment thread
wanfengcxz marked this conversation as resolved.


class DlinferAttentionImpl(AttentionImpl[DlinferAttentionMetadata]):
Expand Down
1 change: 1 addition & 0 deletions lmdeploy/pytorch/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,6 +451,7 @@ def from_hf_config(
is_draft_model=is_draft_model,
spec_method=spec_method,
num_spec_tokens=num_spec_tokens,
device_type=device_type,
)

if model_config.k_head_dim is None:
Expand Down
3 changes: 2 additions & 1 deletion lmdeploy/pytorch/configurations/qwen3_5.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@ def build(cls,
else:
recurrent_state_shape = (num_delta_layers, num_v_heads, head_k_dim, head_v_dim)

if is_bf16_supported():
device_type = kwargs.get('device_type', 'auto')
if is_bf16_supported(device_type):
dtype = torch.bfloat16
else:
dtype = torch.float16
Expand Down
Loading