@@ -69,6 +69,7 @@ def scaled_dot_product_attention(
     is_causal = True
     # implementation as described here: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
     use_fp32_acc = kwargs.get("use_fp32_acc", False)
+    use_fp8_quantize = kwargs.get("use_fp8_quantize", True)
     query_dtype = query.dtype
 
     if scale is None:
@@ -97,6 +98,30 @@ def scaled_dot_product_attention(
             key,
             scale,
         )
+    # fixed amax value for testing
+    amax = torch.tensor([0.6562])
+    if use_fp8_quantize:
+        key = impl.quantize.quantize(
+            ctx,
+            target,
+            SourceIR.ATEN,
+            name + "_quantize_key",
+            key,
+            amax,
+            8,
+            4,
+        )
+
+        query = impl.quantize.quantize(
+            ctx,
+            target,
+            SourceIR.ATEN,
+            name + "_quantize_query",
+            query,
+            amax,
+            8,
+            4,
+        )
 
     if use_fp32_acc and query_dtype == trt.float16:
         query = cast_trt_tensor(
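The trailing `8, 4` arguments to `impl.quantize.quantize` appear to be the bit width and exponent width, i.e. FP8 E4M3, and `amax` fixes the per-tensor dynamic range. A minimal eager-mode sketch of the quantize/dequantize round trip such a Q/DQ pair models; the helper name `fp8_fake_quant`, the E4M3 maximum of 448, and the use of `torch.float8_e4m3fn` are illustrative assumptions, not what the converter emits:

```python
import torch

# Illustrative sketch only: mimic the precision loss of an FP8 (E4M3) quantize ->
# dequantize pair with a per-tensor scale derived from amax.
def fp8_fake_quant(x: torch.Tensor, amax: torch.Tensor) -> torch.Tensor:
    E4M3_MAX = 448.0                                # largest finite E4M3 value (assumed)
    scale = (amax / E4M3_MAX).to(x.dtype)           # per-tensor scale from amax
    q = (x / scale).clamp(-E4M3_MAX, E4M3_MAX)      # saturate to the representable range
    q = q.to(torch.float8_e4m3fn)                   # quantize (needs PyTorch >= 2.1)
    return q.to(x.dtype) * scale                    # dequantize back to the input dtype

query = torch.randn(2, 8, 16, 64, dtype=torch.float16)
amax = torch.tensor([0.6562])
query_qdq = fp8_fake_quant(query, amax)
```

In the lowered engine the explicit Q/DQ pair is what is intended to let TensorRT run the adjacent matmul in FP8; the sketch only reproduces the rounding it introduces.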
@@ -173,6 +198,29 @@ def scaled_dot_product_attention(
     softmax = impl.normalization.softmax(
         ctx, target, source_ir, name + "_softmax", scaled_add_attn_bias, -1, False
     )
+    if use_fp8_quantize:
+        softmax = impl.quantize.quantize(
+            ctx,
+            target,
+            SourceIR.ATEN,
+            name + "_quantize_softmax",
+            softmax,
+            amax,
+            8,
+            4,
+        )
+
+        value = impl.quantize.quantize(
+            ctx,
+            target,
+            SourceIR.ATEN,
+            name + "_quantize_value",
+            value,
+            amax,
+            8,
+            4,
+        )
+
     if use_fp32_acc:
         softmax = cast_trt_tensor(
             ctx, softmax, trt.float32, name + "_softmax_cast_to_fp32", target, source_ir
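As in the earlier hunk, the hard-coded `amax` of 0.6562 is reused here for the softmax output and `value`; per the comment it is a fixed value for testing. Outside of a test, the per-tensor amax would presumably come from calibration statistics, along these lines (the `calibrate_amax` helper is hypothetical and not part of this PR or of Torch-TensorRT):

```python
import torch

# Hypothetical helper, not part of this PR: take the largest absolute value seen
# across a few calibration batches as the per-tensor amax.
def calibrate_amax(samples: list[torch.Tensor]) -> torch.Tensor:
    return torch.stack([s.detach().abs().amax() for s in samples]).amax().reshape(1)

amax = calibrate_amax([torch.randn(2, 8, 16, 64) for _ in range(4)])
```

A single shared amax for query, key, softmax, value, and the output is also a simplification; per-tensor ranges would normally differ (the softmax output, for instance, is bounded by 1).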
@@ -188,9 +236,21 @@ def scaled_dot_product_attention(
         softmax,
         value,
     )
+
     if use_fp32_acc:
         out = cast_trt_tensor(
             ctx, out, query_dtype, name + "_out_cast_to_fp16", target, source_ir
         )
+    if use_fp8_quantize:
+        out = impl.quantize.quantize(
+            ctx,
+            target,
+            SourceIR.ATEN,
+            name + "_quantize_out",
+            out,
+            amax,
+            8,
+            4,
+        )
 
     return out
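Taken together, the patch brackets both attention matmuls with Q/DQ pairs: query and key ahead of the first matmul, the softmax output and value ahead of the second, and finally the SDPA output itself. A rough eager-mode reference for the numerics this placement models, compared against unquantized SDPA; assumptions as before (`fp8_fake_quant` is an illustrative stand-in, the scale factor is folded into `key` before quantization as in the second hunk's surrounding context, and causal masking plus the fp32-accumulation casts are omitted):

```python
import math
import torch
import torch.nn.functional as F

def fp8_fake_quant(x, amax, e4m3_max=448.0):
    # same illustrative FP8 E4M3 fake-quant round trip as in the earlier sketch
    scale = (amax / e4m3_max).to(x.dtype)
    q = (x / scale).clamp(-e4m3_max, e4m3_max).to(torch.float8_e4m3fn)
    return q.to(x.dtype) * scale

def sdpa_fp8_reference(query, key, value, amax):
    # 1/sqrt(d) scaling happens before the Q/DQ pairs (applied to key here)
    key = key / math.sqrt(query.shape[-1])
    q, k = fp8_fake_quant(query, amax), fp8_fake_quant(key, amax)
    attn = torch.softmax(q @ k.transpose(-2, -1), dim=-1)
    attn, v = fp8_fake_quant(attn, amax), fp8_fake_quant(value, amax)
    return fp8_fake_quant(attn @ v, amax)

q = torch.randn(2, 8, 16, 64)
k = torch.randn(2, 8, 16, 64)
v = torch.randn(2, 8, 16, 64)
amax = torch.tensor([0.6562])

ref = F.scaled_dot_product_attention(q, k, v)
approx = sdpa_fp8_reference(q, k, v, amax)
print((ref - approx).abs().max())  # quantization error vs. full precision
```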