2121from typing import Tuple
2222import numpy as np
2323
# When True, the warm-up loop in ``test`` runs NUM_PRERUN times instead of once.
PROFILE = True
NUM_PRERUN = 1
# Divisor used to average the measured library time into a per-call figure.
NUM_ITERATIONS = 50
2727
2828class ReducemaxDescriptor (Structure ):
2929 _fields_ = [("device" , c_int32 )]
@@ -113,7 +113,6 @@ def test(
113113 c_bool (noop_with_empty_axes ),
114114 )
115115 )
116- print (f"op desctiptor created" )
117116 x_tensor .descriptor .contents .invalidate ()
118117 y_tensor .descriptor .contents .invalidate ()
119118 for i in range (NUM_PRERUN if PROFILE else 1 ):
@@ -142,41 +141,57 @@ def test(
142141 )
143142 elapsed = (time .time () - start_time ) / NUM_ITERATIONS
144143 print (f"lib time: { elapsed :10f} " )
145- print (f"custom op output:{ y } " )
146- print (f"pytorch output:{ ans } " )
147- assert torch .allclose (y , ans , atol = 0 , rtol = 1e-3 )
148-
144+ # print(f"input : {x}")
145+ # print(f"custom op output:{y}")
146+ # print(f"pytorch output:{ans}")
149147 check_error (lib .infiniopDestroyReducemaxDescriptor (descriptor ))
148+ assert torch .allclose (y , ans , atol = 0 , rtol = 1e-3 )
150149
def test_cpu(lib, test_cases):
    """Run every reduce-max test case on the CPU device.

    Args:
        lib: Loaded operator library, forwarded to ``create_handle``,
            ``test`` and ``destroy_handle``.
        test_cases: Iterable of 6-tuples ``(x_shape, axes,
            noop_with_empty_axes, keepdims, dynamic_axes, tensor_dtype)``.

    Raises:
        Whatever ``test`` raises (e.g. ``AssertionError`` on a result
        mismatch); the handle is destroyed before the error propagates.
    """
    device = DeviceEnum.DEVICE_CPU
    handle = create_handle(lib, device)
    try:
        for (
            x_shape,
            axes,
            noop_with_empty_axes,
            keepdims,
            dynamic_axes,
            tensor_dtype,
        ) in test_cases:
            # ``test`` takes dynamic_axes before the two boolean attributes,
            # which differs from the tuple order.
            test(
                lib,
                handle,
                "cpu",
                x_shape,
                axes,
                dynamic_axes,
                noop_with_empty_axes,
                keepdims,
                tensor_dtype=tensor_dtype,
            )
            print("\n ")
    finally:
        # Release the handle even when a case fails, so a failing run
        # does not leak it.
        destroy_handle(lib, handle)
160158
def test_cuda(lib, test_cases):
    """Run every reduce-max test case on the CUDA device.

    Args:
        lib: Loaded operator library, forwarded to ``create_handle``,
            ``test`` and ``destroy_handle``.
        test_cases: Iterable of 6-tuples ``(x_shape, axes,
            noop_with_empty_axes, keepdims, dynamic_axes, tensor_dtype)``.

    Raises:
        Whatever ``test`` raises (e.g. ``AssertionError`` on a result
        mismatch); the handle is destroyed before the error propagates.
    """
    device = DeviceEnum.DEVICE_CUDA
    handle = create_handle(lib, device)
    try:
        for (
            x_shape,
            axes,
            noop_with_empty_axes,
            keepdims,
            dynamic_axes,
            tensor_dtype,
        ) in test_cases:
            # ``test`` takes dynamic_axes before the two boolean attributes,
            # which differs from the tuple order.
            test(
                lib,
                handle,
                "cuda",
                x_shape,
                axes,
                dynamic_axes,
                noop_with_empty_axes,
                keepdims,
                tensor_dtype=tensor_dtype,
            )
            print("\n ")
    finally:
        # Release the handle even when a case fails, so a failing run
        # does not leak it.
        destroy_handle(lib, handle)
161166
162167if __name__ == "__main__" :
163168 test_cases = [
164169 # dynamic calc test eg
165- ((2 , 3 , 4 , 5 ), [0 , 2 ], False , True , None ),
166- ((2 , 3 , 4 , 5 ), [0 , 2 ], False , True , None ),
167- #(input_shape, axis, noop_with_empty_axes, keepdims, dynamic_axes)
168- ((2 , 10 , 24 , 10 ), [0 , 2 ], False , True , None ),
169- # stride =
170- ((2 , 10 , 24 , 10 ), [0 , 1 ], False , True , None ),
171- ((2 , 10 , 24 , 10 ), [2 , 3 ], False , True , None ),
172- ((2 , 10 , 24 , 10 ), [0 , 1 , 2 , 3 ], False , True , None ),
173- # validate attribute noop_with_empty_axes and keepdims
174- ((2 , 10 , 24 , 10 ), None , True , True , None ),
175- ((2 , 10 , 24 , 10 ), None , True , False , None ),
176- ((2 , 10 , 24 , 10 ), None , False , True , None ),
177- ((2 , 10 , 24 , 10 ), None , False , False , None ),
178- ((2 , 3 , 4 ), [0 , 1 ], False , False , None ),
170+ # ((2, 3, 4, 5), [0, 2], False, True, None),
171+ # ((2, 3, 4, 5), [0, 2], False, True, None),
172+ # # (input_shape, axis, noop_with_empty_axes, keepdims, dynamic_axes)
173+ # ((2, 10, 24, 10), [0, 2], False, True, None),
174+ # # stride =
175+ # ((2, 10, 24, 10), [0, 1], False, True, None),
176+ # ((2, 10, 24, 10), [2, 3], False , True, None),
177+ # ((2, 10, 24, 10), [0, 1, 2, 3], False, True, None),
178+ # # validate attribute noop_with_empty_axes and keepdims
179+ # ((2, 10, 24, 10), None, True, True, None),
180+ # ((2, 10, 24, 10), None, True, False, None),
181+ # ((2, 10, 24, 10), None, False, True, None),
182+ # ((2, 10, 24, 10), None, False, False, None),
183+ # ((2, 3, 4), [0, 1], False, False, None),
179184 #((2, 10, 24, 10), [], True),
185+ #((4,), [0], False, False, None, torch.float32),
186+ ((1000 , 300 ), [0 , 1 ], False , False , None , torch .float16 ),
187+ ((50 , 3 ), [0 , 1 ], False , False , None , torch .float16 ),
188+ ((1000 , 300 ), [0 , 1 ], False , False , None , torch .float16 ),
189+ ((2000 , 200 , 50 ), [0 , 1 ], False , True , None , torch .float32 ),
190+ ((1000 , 200 , 500 ), [0 , 1 ], False , True , None , torch .float16 ),
191+ ((1000 , 200 , 50 ), [0 , 1 ], False , True , None , torch .float32 ),
192+ ((20 , 3 , 4 , 5 ), [0 , 2 ], False , False , None , torch .float32 ),
193+ ((20 , 30 , 40 , 5 ), [0 , 2 , 3 ], False , False , None , torch .float32 ),
194+ ((200 , 3 , 40 , 5 ), [0 , 3 ], False , False , None , torch .float32 ),
180195 ]
181196 args = get_args ()
182197 lib = open_lib ()
@@ -202,5 +217,8 @@ def test_cpu(lib, test_cases):
202217 ]
203218 lib .infiniopDestroyReducemaxDescriptor .restype = c_int32
204219 lib .infiniopDestroyReducemaxDescriptor .argtypes = [infiniopReducemaxDescriptor_t ]
205- test_cpu (lib , test_cases )
220+ if args .cpu :
221+ test_cpu (lib , test_cases )
222+ if args .cuda :
223+ test_cuda (lib , test_cases )
206224 print ("All tests passed!" )
0 commit comments