
Commit f9168e4

[Blackwell] add non-causal bwd/FA with TMA and atomic_add (#603)
* add non-causal bwd/FA
* ufmt
1 parent 1b55a31 commit f9168e4

File tree

3 files changed, +326 -9 lines changed


tritonbench/kernels/blackwell_triton_fused_attention.py

Lines changed: 324 additions & 0 deletions
@@ -633,6 +633,220 @@ def _attn_fwd_persist(
         tile_idx += num_progs
 
 
+@triton.jit
+def _attn_bwd_preprocess(
+    O,
+    DO,  #
+    Delta,  #
+    Z,
+    H,
+    N_CTX,  #
+    BLOCK_M: tl.constexpr,
+    HEAD_DIM: tl.constexpr,  #
+):
+    off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
+    off_hz = tl.program_id(1)
+    off_n = tl.arange(0, HEAD_DIM)
+    # load
+    o = tl.load(
+        O + off_hz * HEAD_DIM * N_CTX + off_m[:, None] * HEAD_DIM + off_n[None, :]
+    )
+    do = tl.load(
+        DO + off_hz * HEAD_DIM * N_CTX + off_m[:, None] * HEAD_DIM + off_n[None, :]
+    ).to(tl.float32)
+    delta = tl.sum(o * do, axis=1)
+    # write-back
+    tl.store(Delta + off_hz * N_CTX + off_m, delta)
+
+
+# The main inner-loop logic for computing dK and dV.
+@triton.jit
+def _attn_bwd_dkdv(
+    dk,
+    dv,  #
+    desc_q,
+    k,
+    v,
+    sm_scale,  #
+    desc_do,  #
+    desc_dq,
+    M,
+    D,  #
+    # shared by Q/K/V/DO.
+    stride_tok,
+    stride_d,  #
+    off_bh,
+    H,
+    N_CTX,
+    BLOCK_M1: tl.constexpr,  #
+    BLOCK_N1: tl.constexpr,  #
+    HEAD_DIM: tl.constexpr,  #
+    # Filled in by the wrapper.
+    start_n,
+    start_m,
+    num_steps,  #
+    MASK: tl.constexpr,
+    dtype: tl.constexpr,
+):
+    offs_m = start_m + tl.arange(0, BLOCK_M1)
+    offs_n = start_n + tl.arange(0, BLOCK_N1)
+
+    LN2: tl.constexpr = 0.6931471824645996  # = ln(2)
+
+    # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
+    tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
+    curr_m = start_m
+    step_m = BLOCK_M1
+    for blk_idx in range(num_steps):
+        q = desc_q.load([(off_bh + curr_m).to(tl.int32), 0])
+        qT = tl.trans(q)
+        # Load m before computing qk to reduce pipeline stall.
+        offs_m = curr_m + tl.arange(0, BLOCK_M1)
+        m = tl.load(M + offs_m)
+        qkT = tl.dot(k, qT)
+        pT = tl.math.exp2(qkT - m[None, :])
+        # Autoregressive masking.
+        if MASK:
+            mask = offs_m[None, :] >= offs_n[:, None]
+            pT = tl.where(mask, pT, 0.0)
+        do = desc_do.load([(off_bh + curr_m).to(tl.int32), 0])
+        # Compute dV.
+        ppT = pT
+        ppT = ppT.to(dtype)
+        dv += tl.dot(ppT, do)
+        # D (= delta) is pre-divided by ds_scale.
+        Di = tl.load(D + offs_m)
+        # Compute dP and dS.
+        dpT = tl.dot(v, tl.trans(do)).to(tl.float32)
+        dsT = pT * (dpT - Di[None, :])
+        dsT = dsT.to(dtype)
+        dk += tl.dot(dsT, tl.trans(qT))
+        # Compute dq = tl.dot(tl.trans(dsT), k)
+        dq = tl.dot(tl.trans(dsT), k) * LN2
+        desc_dq.atomic_add([(off_bh + curr_m).to(tl.int32), 0], dq)
+        # Increment pointers.
+        curr_m += step_m
+
+    return dk, dv
+
+
+def _bwd_host_descriptor_pre_hook(nargs):
+    BLOCK_M1 = nargs["BLOCK_M1"]
+    BLOCK_N1 = nargs["BLOCK_N1"]
+    HEAD_DIM = nargs["HEAD_DIM"]
+    nargs["desc_q"].block_shape = [BLOCK_M1, HEAD_DIM]
+    nargs["desc_do"].block_shape = [BLOCK_M1, HEAD_DIM]
+    nargs["desc_dq"].block_shape = [BLOCK_M1, HEAD_DIM]
+    nargs["desc_v"].block_shape = [BLOCK_N1, HEAD_DIM]
+    nargs["desc_k"].block_shape = [BLOCK_N1, HEAD_DIM]
+    nargs["desc_dv"].block_shape = [BLOCK_N1, HEAD_DIM]
+    nargs["desc_dk"].block_shape = [BLOCK_N1, HEAD_DIM]
+
+
+configs_bwd = [
+    triton.Config(
+        {
+            "BLOCK_M1": 32,
+            "BLOCK_N1": 128,
+            "BLOCK_M2": 128,
+            "BLOCK_N2": 32,
+        },
+        num_warps=4,
+        num_stages=1,
+        pre_hook=_bwd_host_descriptor_pre_hook,
+    )
+]
+
+
+@triton.autotune(configs=configs_bwd, key=["N_CTX", "HEAD_DIM"])
+@triton.jit
+def _attn_bwd(
+    desc_q,
+    desc_k,
+    desc_v,
+    sm_scale,  #
+    desc_do,  #
+    desc_dq,
+    desc_dk,
+    desc_dv,  #
+    M,
+    D,
+    # shared by Q/K/V/DO.
+    stride_z,
+    stride_h,
+    stride_tok,
+    stride_d,  #
+    H,
+    N_CTX,  #
+    BLOCK_M1: tl.constexpr,  #
+    BLOCK_N1: tl.constexpr,  #
+    BLOCK_M2: tl.constexpr,  #
+    BLOCK_N2: tl.constexpr,  #
+    BLK_SLICE_FACTOR: tl.constexpr,  #
+    HEAD_DIM: tl.constexpr,
+    dtype: tl.constexpr,
+):
+    bhid = tl.program_id(2)
+    off_chz = (bhid * N_CTX).to(tl.int64)
+    off_bh = (
+        (stride_h * (bhid % H) + stride_z * (bhid // H)).to(tl.int64)
+    ) // stride_tok
+    pid = tl.program_id(0)
+
+    # offset pointers for batch/head
+    M += off_chz
+    D += off_chz
+
+    dv = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32)
+    dk = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32)
+
+    start_n = pid * BLOCK_N1
+    start_m = 0
+
+    # load K and V: they stay in SRAM throughout the inner loop.
+    k = desc_k.load([(off_bh + start_n).to(tl.int32), 0])
+    v = desc_v.load([(off_bh + start_n).to(tl.int32), 0])
+    # Compute dK and dV for non-masked blocks.
+    num_steps = (N_CTX - start_m) // BLOCK_M1
+    dk, dv = _attn_bwd_dkdv(  #
+        dk,
+        dv,  #
+        desc_q,
+        k,
+        v,
+        sm_scale,  #
+        desc_do,  #
+        desc_dq,
+        M,
+        D,  #
+        stride_tok,
+        stride_d,  #
+        off_bh,
+        H,
+        N_CTX,  #
+        BLOCK_M1,
+        BLOCK_N1,
+        HEAD_DIM,  #
+        start_n,
+        start_m,
+        num_steps,  #
+        MASK=False,  #
+        dtype=dtype,
+    )
+
+    desc_dv.store(
+        [(off_bh + start_n).to(tl.int32), 0],
+        dv.to(dtype),
+    )
+
+    # Write back dK.
+    dk *= sm_scale
+    desc_dk.store(
+        [(off_bh + start_n).to(tl.int32), 0],
+        dk.to(dtype),
+    )
+
+
 def torch_dtype_to_triton(dtype):
     if dtype == torch.float8_e5m2:
         return tl.float8e5
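
For reference, the dK/dV update inside _attn_bwd_dkdv is the standard attention backward written block-wise: with P the softmax probabilities and delta_i = sum_j O_ij * dO_ij (the per-row quantity _attn_bwd_preprocess stores), dV = P^T dO, dP = dO V^T, dS = P * (dP - delta), dK = sm_scale * dS^T Q, and dQ = sm_scale * dS K, where dQ is accumulated through desc_dq.atomic_add because every K/V program contributes to the same Q rows. Below is a minimal eager PyTorch sketch of that math for a single (batch, head) slice; the shapes are illustrative assumptions, and the sketch is only a reference for checking the formulas, not a reimplementation of the kernel.

import torch

# Illustrative shapes (assumptions, not taken from the commit).
N_CTX, HEAD_DIM = 256, 64
q = torch.randn(N_CTX, HEAD_DIM)
k = torch.randn(N_CTX, HEAD_DIM)
v = torch.randn(N_CTX, HEAD_DIM)
do = torch.randn(N_CTX, HEAD_DIM)
sm_scale = HEAD_DIM ** -0.5

# Forward quantities the backward relies on. The kernel works in base 2
# (exp2 with K pre-scaled by 1/ln2), which yields the same probabilities.
p = torch.softmax((q @ k.T) * sm_scale, dim=-1)
o = p @ v
delta = (o * do).sum(dim=-1)          # what _attn_bwd_preprocess writes per row

dv_ref = p.T @ do                     # dV = P^T dO
dp = do @ v.T                         # dP = dO V^T
ds = p * (dp - delta[:, None])        # dS = P * (dP - delta)
dk_ref = sm_scale * (ds.T @ q)        # kernel applies sm_scale at write-back
dq_ref = sm_scale * (ds @ k)          # kernel accumulates this via atomic_add

The kernel computes the same quantities tile by tile: each program owns one BLOCK_N1 slice of K/V, kept in SRAM for the whole loop, and iterates over BLOCK_M1 slices of Q and dO loaded through the TMA descriptors.
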
@@ -745,5 +959,115 @@ def grid_debug(META):
         ctx.causal = causal
         return o
 
+    @staticmethod
+    def backward(ctx, do):
+        q, k, v, o, M = ctx.saved_tensors
+        assert do.is_contiguous()
+        assert q.stride() == k.stride() == v.stride() == o.stride() == do.stride()
+        dq = torch.zeros(q.shape, device=q.device, dtype=torch.float32)
+        dk = torch.empty_like(k)
+        dv = torch.empty_like(v)
+        BATCH, N_HEAD, N_CTX = q.shape[:3]
+        PRE_BLOCK = 128
+        BLK_SLICE_FACTOR = 2
+        RCP_LN2 = 1.4426950408889634  # = 1.0 / ln(2)
+        arg_k = k
+        arg_k = arg_k * (ctx.sm_scale * RCP_LN2)
+        PRE_BLOCK = 128
+        assert N_CTX % PRE_BLOCK == 0
+        pre_grid = (N_CTX // PRE_BLOCK, BATCH * N_HEAD)
+        delta = torch.empty_like(M)
+        _attn_bwd_preprocess[pre_grid](
+            o,
+            do,  #
+            delta,  #
+            BATCH,
+            N_HEAD,
+            N_CTX,  #
+            BLOCK_M=PRE_BLOCK,
+            HEAD_DIM=ctx.HEAD_DIM,  #
+        )
+
+        dummy_block = [1, 1]
+        HEAD_DIM = ctx.HEAD_DIM
+        desc_k = TensorDescriptor(
+            arg_k,
+            shape=[BATCH * N_HEAD * N_CTX, HEAD_DIM],
+            strides=[HEAD_DIM, 1],
+            block_shape=dummy_block,
+        )
+        desc_v = TensorDescriptor(
+            v,
+            shape=[BATCH * N_HEAD * N_CTX, HEAD_DIM],
+            strides=[HEAD_DIM, 1],
+            block_shape=dummy_block,
+        )
+        desc_q = TensorDescriptor(
+            q,
+            shape=[BATCH * N_HEAD * N_CTX, HEAD_DIM],
+            strides=[HEAD_DIM, 1],
+            block_shape=dummy_block,
+        )
+        desc_do = TensorDescriptor(
+            do,
+            shape=[BATCH * N_HEAD * N_CTX, HEAD_DIM],
+            strides=[HEAD_DIM, 1],
+            block_shape=dummy_block,
+        )
+        desc_dq = TensorDescriptor(
+            dq,
+            shape=[BATCH * N_HEAD * N_CTX, HEAD_DIM],
+            strides=[HEAD_DIM, 1],
+            block_shape=dummy_block,
+        )
+        desc_dk = TensorDescriptor(
+            dk,
+            shape=[BATCH * N_HEAD * N_CTX, HEAD_DIM],
+            strides=[HEAD_DIM, 1],
+            block_shape=dummy_block,
+        )
+        desc_dv = TensorDescriptor(
+            dv,
+            shape=[BATCH * N_HEAD * N_CTX, HEAD_DIM],
+            strides=[HEAD_DIM, 1],
+            block_shape=dummy_block,
+        )
+
+        def alloc_fn(size: int, align: int, _):
+            return torch.empty(size, dtype=torch.int8, device="cuda")
+
+        triton.set_allocator(alloc_fn)
+
+        def grid(meta):
+            return (
+                triton.cdiv(N_CTX, meta["BLOCK_N1"]),  # tiles along N (K/V)
+                1,  # (or cdiv over M if you need)
+                BATCH * N_HEAD,
+            )  # batch*heads
+
+        _attn_bwd[grid](
+            desc_q,
+            desc_k,
+            desc_v,
+            ctx.sm_scale,
+            desc_do,
+            desc_dq,
+            desc_dk,
+            desc_dv,  #
+            M,
+            delta,  #
+            q.stride(0),
+            q.stride(1),
+            q.stride(2),
+            q.stride(3),  #
+            N_HEAD,
+            N_CTX,  #
+            BLK_SLICE_FACTOR=BLK_SLICE_FACTOR,  #
+            HEAD_DIM=ctx.HEAD_DIM,  #
+            dtype=torch_dtype_to_triton(q.dtype),
+        )
+
+        return dq, dk, dv, None, None, None, None
+
 
 attention_opt = _attention_opt.apply
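
On the host side, backward() views q, k, v, do, and the gradients as a single [BATCH * N_HEAD * N_CTX, HEAD_DIM] matrix when building the TensorDescriptors, and each kernel instance recovers the starting row of its (batch, head) block from the original 4-D strides as off_bh = (stride_h * (bhid % H) + stride_z * (bhid // H)) // stride_tok. For a contiguous [BATCH, N_HEAD, N_CTX, HEAD_DIM] tensor this reduces to batch * N_HEAD * N_CTX + head * N_CTX. A small sketch that checks this equivalence; the shapes are assumptions chosen for illustration, not values from the commit.

import torch

# Assumed illustrative shapes.
BATCH, N_HEAD, N_CTX, HEAD_DIM = 2, 3, 128, 64
q = torch.randn(BATCH, N_HEAD, N_CTX, HEAD_DIM)

stride_z, stride_h, stride_tok, _ = q.stride()       # strides passed to _attn_bwd
flat = q.reshape(BATCH * N_HEAD * N_CTX, HEAD_DIM)   # the descriptor's view

for bhid in range(BATCH * N_HEAD):                   # mirrors tl.program_id(2)
    head, batch = bhid % N_HEAD, bhid // N_HEAD
    off_bh = (stride_h * head + stride_z * batch) // stride_tok
    # Row off_bh of the flat view is the first token of this (batch, head).
    assert torch.equal(flat[off_bh], q[batch, head, 0])

With BLOCK_N1 = 128 in configs_bwd and a grid of (cdiv(N_CTX, BLOCK_N1), 1, BATCH * N_HEAD), each program owns one 128-row K/V tile and writes its dK/dV tile through desc_dk and desc_dv, while the dQ contributions from all programs of a (batch, head) meet in desc_dq.atomic_add.
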

tritonbench/operators/ragged_attention/hstu.py

Lines changed: 1 addition & 3 deletions
@@ -11,9 +11,7 @@
         set_use_runtime_max_seq_len,
     )
     from generative_recommenders.ops.triton.triton_hstu_attention import triton_hstu_mha
-    from hammer.ops.triton.triton_ragged_attn_interface import (
-        triton_ragged_hstu_mha,
-    )
+    from hammer.ops.triton.triton_ragged_attn_interface import triton_ragged_hstu_mha
 
     HAS_HAMMER = True
 else:

tritonbench/operators/ragged_attention/operator.py

Lines changed: 1 addition & 6 deletions
@@ -15,12 +15,7 @@
     register_metric,
 )
 
-from .hstu import (
-    get_test_inputs,
-    HAS_HAMMER,
-    triton_hstu_mha,
-    triton_ragged_hstu_mha,
-)
+from .hstu import get_test_inputs, HAS_HAMMER, triton_hstu_mha, triton_ragged_hstu_mha
 
 HAS_CUDA = False
 try:
