     _is_fbgemm_gpu_genai_available,
     is_sm_at_least_89,
     is_sm_at_least_90,
+    is_sm_at_least_100,
     torch_version_at_least,
 )
 
@@ -49,6 +50,30 @@ def forward(self, x):
         return x
 
 
+class ToyConvModel(torch.nn.Module):
+    def __init__(
+        self, dim, in_channels, out_channels, kernel_size, bias, padding, dtype, device
+    ):
+        super().__init__()
+        convs = {1: torch.nn.Conv1d, 2: torch.nn.Conv2d, 3: torch.nn.Conv3d}
+        self.conv = convs[dim](
+            in_channels,
+            out_channels,
+            kernel_size,
+            bias=bias,
+            padding=padding,
+            dtype=dtype,
+            device=device,
+        )
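+        # Assumption: the FP8 conv3d path expects channels-last-3d (NDHWC)
+        # layout, hence the memory-format conversion below.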
+        if dim == 3:
+            self.conv = self.conv.to(memory_format=torch.channels_last_3d)
+
+    def forward(self, x):
+        return self.conv(x)
+
+
 # TODO: move tests in test_affine_quantized_float.py here after we migrated all implementations
 @unittest.skipIf(not torch_version_at_least("2.8.0"), "Need pytorch 2.8+")
 @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@@ -148,6 +173,92 @@ def test_fp8_linear_variants(
148171 f"Quantization error is too high got a SQNR of { error } "
149172 )
150173
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(
+        not is_sm_at_least_100(), "Requires GPU with compute capability >= 10.0"
+    )
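+    # Assumption: the quantized conv3d kernels exercised here need SM 10.0+ hardware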
+    @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32])
+    @common_utils.parametrize("compile", [True, False])
+    @common_utils.parametrize("granularity", [PerTensor()])
+    @common_utils.parametrize(
+        "kernel_preference",
+        [KernelPreference.AUTO],
+    )
+    # only test 3D conv for now
+    # each sizes entry packs (N, C_in, C_out, D, H, W)
+    @common_utils.parametrize(
+        "sizes",
+        [
+            (4, 16, 64, 32, 32, 32),
+        ],
+    )
+    def test_fp8_conv_variants(
+        self,
+        dtype: torch.dtype,
+        compile: bool,
+        granularity,
+        kernel_preference: KernelPreference,
+        sizes: Tuple,
+    ):
+        if (
+            isinstance(granularity, PerTensor)
+            and kernel_preference == KernelPreference.FBGEMM
+        ):
+            self.skipTest(
+                "per tensor with fbgemm kernel preference does not work yet"
+            )
+
+        if kernel_preference == KernelPreference.FBGEMM and (
+            (not _is_fbgemm_gpu_genai_available()) or (not is_sm_at_least_90())
+        ):
+            self.skipTest(
+                "Requires fbgemm_gpu_genai to run fbgemm kernel preference test"
+            )
+
+        dim = 3
+        N, C_in, C_out, D, H, W = sizes
+        kernel_size = 3
+
+        # Note: the input is in channels-last (NDHWC) memory format
+        input_tensor = torch.randn(N, C_in, D, H, W, dtype=dtype, device="cuda")
+        input_tensor = input_tensor.to(memory_format=torch.channels_last_3d)
+
+        # Create a conv model in the original high-precision dtype
+        model = ToyConvModel(
+            dim,
+            C_in,
+            C_out,
+            kernel_size,
+            bias=False,
+            padding=0,
+            dtype=dtype,
+            device="cuda",
+        ).eval()
+        print("model weight shape:", model.conv.weight.shape)
+
+        quantized_model = copy.deepcopy(model)
+
+        config = Float8DynamicActivationFloat8WeightConfig(
+            granularity=granularity,
+            kernel_preference=kernel_preference,
+        )
+
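+        # quantize_ targets nn.Linear modules by default (assumption about the
+        # default filter), so pass a filter_fn that matches the Conv3d module instead.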
+        _is_conv3d = lambda m, fqn: isinstance(m, torch.nn.Conv3d)
+
+        quantize_(quantized_model, config, filter_fn=_is_conv3d)
+
+        if compile:
+            quantized_model = torch.compile(quantized_model, fullgraph=True)
+
+        output_original = model(input_tensor)
+        output_quantized = quantized_model(input_tensor)
+
+        error = compute_error(output_original, output_quantized)
+        assert error > 20, f"Quantization error is too high, got a SQNR of {error}"
+
     @common_utils.parametrize("granularity", [PerTensor(), PerRow()])
     @unittest.skipIf(
         not is_sm_at_least_90(),