@@ -633,6 +633,220 @@ def _attn_fwd_persist(
        tile_idx += num_progs


@triton.jit
def _attn_bwd_preprocess(
    O,
    DO,  #
    Delta,  #
    Z,
    H,
    N_CTX,  #
    BLOCK_M: tl.constexpr,
    HEAD_DIM: tl.constexpr,  #
):
    off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
    off_hz = tl.program_id(1)
    off_n = tl.arange(0, HEAD_DIM)
    # load
    o = tl.load(
        O + off_hz * HEAD_DIM * N_CTX + off_m[:, None] * HEAD_DIM + off_n[None, :]
    )
    do = tl.load(
        DO + off_hz * HEAD_DIM * N_CTX + off_m[:, None] * HEAD_DIM + off_n[None, :]
    ).to(tl.float32)
    delta = tl.sum(o * do, axis=1)
    # write-back
    tl.store(Delta + off_hz * N_CTX + off_m, delta)

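
# For reference, the preprocess kernel above materializes, per query row i,
# delta_i = sum_d O[i, d] * dO[i, d], which the dK/dV loop reuses when forming dS.
# A minimal PyTorch sketch of the same quantity (illustrative only, not called by
# the kernels; assumes `o` and `do` are [Z, H, N_CTX, HEAD_DIM] tensors like the
# ones passed to the kernel above):
def _delta_reference(o, do):
    # rowsum(O * dO) in float32, shape [Z, H, N_CTX]
    return (o.float() * do.float()).sum(dim=-1)
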

# The main inner-loop logic for computing dK and dV.
@triton.jit
def _attn_bwd_dkdv(
    dk,
    dv,  #
    desc_q,
    k,
    v,
    sm_scale,  #
    desc_do,  #
    desc_dq,
    M,
    D,  #
    # shared by Q/K/V/DO.
    stride_tok,
    stride_d,  #
    off_bh,
    H,
    N_CTX,
    BLOCK_M1: tl.constexpr,  #
    BLOCK_N1: tl.constexpr,  #
    HEAD_DIM: tl.constexpr,  #
    # Filled in by the wrapper.
    start_n,
    start_m,
    num_steps,  #
    MASK: tl.constexpr,
    dtype: tl.constexpr,
):
    offs_m = start_m + tl.arange(0, BLOCK_M1)
    offs_n = start_n + tl.arange(0, BLOCK_N1)

    LN2: tl.constexpr = 0.6931471824645996  # = ln(2)

    # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
    tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
    curr_m = start_m
    step_m = BLOCK_M1
    for blk_idx in range(num_steps):
        q = desc_q.load([(off_bh + curr_m).to(tl.int32), 0])
        qT = tl.trans(q)
        # Load m before computing qk to reduce pipeline stall.
        offs_m = curr_m + tl.arange(0, BLOCK_M1)
        m = tl.load(M + offs_m)
        qkT = tl.dot(k, qT)
        pT = tl.math.exp2(qkT - m[None, :])
        # Autoregressive masking.
        if MASK:
            mask = offs_m[None, :] >= offs_n[:, None]
            pT = tl.where(mask, pT, 0.0)
        do = desc_do.load([(off_bh + curr_m).to(tl.int32), 0])
        # Compute dV.
        ppT = pT
        ppT = ppT.to(dtype)
        dv += tl.dot(ppT, do)
        # D (= delta) is pre-divided by ds_scale.
        Di = tl.load(D + offs_m)
        # Compute dP and dS.
        dpT = tl.dot(v, tl.trans(do)).to(tl.float32)
        dsT = pT * (dpT - Di[None, :])
        dsT = dsT.to(dtype)
        dk += tl.dot(dsT, tl.trans(qT))
        # Compute dQ for this query block (rescaled by ln(2)) and accumulate it
        # atomically, since several programs contribute to the same dQ rows.
        dq = tl.dot(tl.trans(dsT), k) * LN2
        desc_dq.atomic_add([(off_bh + curr_m).to(tl.int32), 0], dq)
        # Advance to the next query block.
        curr_m += step_m

    return dk, dv

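
# For reference, the per-tile math in the loop above written in plain PyTorch
# (a hedged, single-tile sketch; not called by the kernels). Here `p` is the
# exp2-domain probability tile of shape [BLOCK_M1, BLOCK_N1], `q`/`do` are
# [BLOCK_M1, HEAD_DIM], `k`/`v` are [BLOCK_N1, HEAD_DIM], and `delta` holds the
# per-row sums produced by _attn_bwd_preprocess:
def _dkdv_tile_reference(q, k, v, do, p, delta, LN2=0.6931471824645996):
    dv = p.t().float() @ do.float()         # dV = P^T @ dO
    dp = do.float() @ v.t().float()         # dP = dO @ V^T
    ds = p.float() * (dp - delta[:, None])  # dS = P * (dP - delta)
    dk = ds.t() @ q.float()                 # dK = dS^T @ Q
    dq = (ds @ k.float()) * LN2             # dQ = dS @ K, rescaled out of log2 space
    return dk, dv, dq
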

def _bwd_host_descriptor_pre_hook(nargs):
    BLOCK_M1 = nargs["BLOCK_M1"]
    BLOCK_N1 = nargs["BLOCK_N1"]
    HEAD_DIM = nargs["HEAD_DIM"]
    nargs["desc_q"].block_shape = [BLOCK_M1, HEAD_DIM]
    nargs["desc_do"].block_shape = [BLOCK_M1, HEAD_DIM]
    nargs["desc_dq"].block_shape = [BLOCK_M1, HEAD_DIM]
    nargs["desc_v"].block_shape = [BLOCK_N1, HEAD_DIM]
    nargs["desc_k"].block_shape = [BLOCK_N1, HEAD_DIM]
    nargs["desc_dv"].block_shape = [BLOCK_N1, HEAD_DIM]
    nargs["desc_dk"].block_shape = [BLOCK_N1, HEAD_DIM]


configs_bwd = [
    triton.Config(
        {
            "BLOCK_M1": 32,
            "BLOCK_N1": 128,
            "BLOCK_M2": 128,
            "BLOCK_N2": 32,
        },
        num_warps=4,
        num_stages=1,
        pre_hook=_bwd_host_descriptor_pre_hook,
    )
]

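
# Note: triton.autotune runs a config's pre_hook with the concrete launch
# arguments before each kernel launch, so _bwd_host_descriptor_pre_hook is what
# installs the real [BLOCK_*, HEAD_DIM] block shapes on the tensor descriptors
# (the host side below creates them with a placeholder block_shape).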

@triton.autotune(configs=configs_bwd, key=["N_CTX", "HEAD_DIM"])
@triton.jit
def _attn_bwd(
    desc_q,
    desc_k,
    desc_v,
    sm_scale,  #
    desc_do,  #
    desc_dq,
    desc_dk,
    desc_dv,  #
    M,
    D,
    # shared by Q/K/V/DO.
    stride_z,
    stride_h,
    stride_tok,
    stride_d,  #
    H,
    N_CTX,  #
    BLOCK_M1: tl.constexpr,  #
    BLOCK_N1: tl.constexpr,  #
    BLOCK_M2: tl.constexpr,  #
    BLOCK_N2: tl.constexpr,  #
    BLK_SLICE_FACTOR: tl.constexpr,  #
    HEAD_DIM: tl.constexpr,
    dtype: tl.constexpr,
):
    bhid = tl.program_id(2)
    off_chz = (bhid * N_CTX).to(tl.int64)
    off_bh = (
        (stride_h * (bhid % H) + stride_z * (bhid // H)).to(tl.int64)
    ) // stride_tok
    pid = tl.program_id(0)

    # offset pointers for batch/head
    M += off_chz
    D += off_chz

    dv = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32)
    dk = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32)

    start_n = pid * BLOCK_N1
    start_m = 0

    # load K and V: they stay in SRAM throughout the inner loop.
    k = desc_k.load([(off_bh + start_n).to(tl.int32), 0])
    v = desc_v.load([(off_bh + start_n).to(tl.int32), 0])
    # Compute dK and dV for non-masked blocks.
    num_steps = (N_CTX - start_m) // BLOCK_M1
    dk, dv = _attn_bwd_dkdv(  #
        dk,
        dv,  #
        desc_q,
        k,
        v,
        sm_scale,  #
        desc_do,  #
        desc_dq,
        M,
        D,  #
        stride_tok,
        stride_d,  #
        off_bh,
        H,
        N_CTX,  #
        BLOCK_M1,
        BLOCK_N1,
        HEAD_DIM,  #
        start_n,
        start_m,
        num_steps,  #
        MASK=False,  #
        dtype=dtype,
    )

    desc_dv.store(
        [(off_bh + start_n).to(tl.int32), 0],
        dv.to(dtype),
    )

    # Write back dK.
    dk *= sm_scale
    desc_dk.store(
        [(off_bh + start_n).to(tl.int32), 0],
        dk.to(dtype),
    )


def torch_dtype_to_triton(dtype):
    if dtype == torch.float8_e5m2:
        return tl.float8e5
@@ -745,5 +959,115 @@ def grid_debug(META):
        ctx.causal = causal
        return o

    @staticmethod
    def backward(ctx, do):
        q, k, v, o, M = ctx.saved_tensors
        assert do.is_contiguous()
        assert q.stride() == k.stride() == v.stride() == o.stride() == do.stride()
        dq = torch.zeros(q.shape, device=q.device, dtype=torch.float32)
        dk = torch.empty_like(k)
        dv = torch.empty_like(v)
        BATCH, N_HEAD, N_CTX = q.shape[:3]
        PRE_BLOCK = 128
        BLK_SLICE_FACTOR = 2
        RCP_LN2 = 1.4426950408889634  # = 1.0 / ln(2)
        arg_k = k * (ctx.sm_scale * RCP_LN2)
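        # K is pre-scaled by sm_scale / ln(2) so the kernel can work entirely in
        # base 2 (tl.math.exp2); dQ is multiplied back by LN2 inside
        # _attn_bwd_dkdv, and dK picks up sm_scale after the loop.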
        assert N_CTX % PRE_BLOCK == 0
        pre_grid = (N_CTX // PRE_BLOCK, BATCH * N_HEAD)
        delta = torch.empty_like(M)
        _attn_bwd_preprocess[pre_grid](
            o,
            do,  #
            delta,  #
            BATCH,
            N_HEAD,
            N_CTX,  #
            BLOCK_M=PRE_BLOCK,
            HEAD_DIM=ctx.HEAD_DIM,  #
        )

        dummy_block = [1, 1]
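        # Placeholder block shape; _bwd_host_descriptor_pre_hook swaps in the
        # tuned [BLOCK_*, HEAD_DIM] shapes before the kernel launches.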
        HEAD_DIM = ctx.HEAD_DIM
        desc_k = TensorDescriptor(
            arg_k,
            shape=[BATCH * N_HEAD * N_CTX, HEAD_DIM],
            strides=[HEAD_DIM, 1],
            block_shape=dummy_block,
        )
        desc_v = TensorDescriptor(
            v,
            shape=[BATCH * N_HEAD * N_CTX, HEAD_DIM],
            strides=[HEAD_DIM, 1],
            block_shape=dummy_block,
        )
        desc_q = TensorDescriptor(
            q,
            shape=[BATCH * N_HEAD * N_CTX, HEAD_DIM],
            strides=[HEAD_DIM, 1],
            block_shape=dummy_block,
        )
        desc_do = TensorDescriptor(
            do,
            shape=[BATCH * N_HEAD * N_CTX, HEAD_DIM],
            strides=[HEAD_DIM, 1],
            block_shape=dummy_block,
        )
        desc_dq = TensorDescriptor(
            dq,
            shape=[BATCH * N_HEAD * N_CTX, HEAD_DIM],
            strides=[HEAD_DIM, 1],
            block_shape=dummy_block,
        )
        desc_dk = TensorDescriptor(
            dk,
            shape=[BATCH * N_HEAD * N_CTX, HEAD_DIM],
            strides=[HEAD_DIM, 1],
            block_shape=dummy_block,
        )
        desc_dv = TensorDescriptor(
            dv,
            shape=[BATCH * N_HEAD * N_CTX, HEAD_DIM],
            strides=[HEAD_DIM, 1],
            block_shape=dummy_block,
        )

        def alloc_fn(size: int, align: int, _):
            return torch.empty(size, dtype=torch.int8, device="cuda")

        triton.set_allocator(alloc_fn)

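        # Grid: axis 0 walks the BLOCK_N1-wide K/V tiles, axis 2 walks
        # batch * heads. Each program owns one (K, V) tile and writes its dK/dV
        # directly; dQ is accumulated across programs with atomic adds in
        # _attn_bwd_dkdv.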
        def grid(meta):
            return (
                triton.cdiv(N_CTX, meta["BLOCK_N1"]),  # tiles along N (K/V)
                1,  # (or cdiv over M if needed)
                BATCH * N_HEAD,  # batch * heads
            )

        _attn_bwd[grid](
            desc_q,
            desc_k,
            desc_v,
            ctx.sm_scale,
            desc_do,
            desc_dq,
            desc_dk,
            desc_dv,  #
            M,
            delta,  #
            q.stride(0),
            q.stride(1),
            q.stride(2),
            q.stride(3),  #
            N_HEAD,
            N_CTX,  #
            BLK_SLICE_FACTOR=BLK_SLICE_FACTOR,  #
            HEAD_DIM=ctx.HEAD_DIM,  #
            dtype=torch_dtype_to_triton(q.dtype),
        )

        return dq, dk, dv, None, None, None, None


attention_opt = _attention_opt.apply
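
# A hedged usage sketch (the exact trailing arguments mirror whatever forward()
# above accepts after q, k, v, e.g. the causal flag and sm_scale):
#
#   out = attention_opt(q, k, v, causal, sm_scale, ...)
#   out.sum().backward()   # runs _attn_bwd_preprocess and _attn_bwd defined above
#   # q.grad, k.grad, v.grad now hold the gradients returned by backward()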