@@ -63,8 +63,6 @@ def random_sample(data, random_val, topp, topk, voc, temperature, torch_device):
6363 else :
6464 end = topk
6565
66-
67-
6866 sum_s = 0
6967 for i in range (end ):
7068 sum_s += dataNp [i ]
@@ -78,12 +76,14 @@ def random_sample(data, random_val, topp, topk, voc, temperature, torch_device):
7876
7977def random_sample_0 (data ):
8078 return torch .argmax (data )
79+
8180def test (lib , handle , torch_device , voc , random_val , topp , topk , temperature , x_dtype = torch .float16 ):
8281 print (
8382 f"Testing RandomSample on { torch_device } with voc:{ voc } dtype:{ x_dtype } "
8483 )
85-
86- data = torch .rand ((voc ), dtype = x_dtype ).to (torch_device )
84+ data = torch .arange (voc ).float () * 0.0001
85+ _perm = torch .randperm (voc )
86+ data = data [_perm ].to (x_dtype ).to (torch_device )
8787 if (topp > 0 and topk > 1 ):
8888 ans = random_sample (data .to ("cpu" ), random_val , topp , topk , voc , temperature , "cpu" )
8989 else :
@@ -130,12 +130,9 @@ def test(lib, handle, torch_device, voc, random_val, topp, topk, temperature, x_
130130 if torch_device == "npu" :
131131 torch .npu .synchronize ()
132132
133- assert indices [0 ].type (ans .dtype ) == ans or abs (data [indices [0 ]] - data [ans ]) == 0.0 , "compute error"
134-
135-
136-
133+ assert indices [0 ].type (ans .dtype ) == ans or data [ans ] == data [indices [0 ]]
137134 check_error (lib .infiniopDestroyRandomSampleDescriptor (descriptor ))
138-
135+ print ( "Test passed!" )
139136
140137def test_cpu (lib , test_cases ):
141138 device = DeviceEnum .DEVICE_CPU
@@ -176,15 +173,16 @@ def test_ascend(lib, test_cases):
176173if __name__ == "__main__" :
177174 test_cases = [
178175 # voc, random_val, topp, topk, temperature
179- (512 , 0.92 , 0.8 , 3 , 0.5 ),
180- (4096 , 0.95 , 0.9 , 5 , 1.0 ),
181- (16384 , 0.85 , 0.85 , 10 , 2.0 ),
182- (512 , 0.92 , 0 , 3 , 0.5 ),
183- (4096 , 0.95 , 0.9 , 1 , 1.0 ),
184- (16384 , 0.85 , 0 , 1 , 2.0 ),
185- (16384 , 0.85 , 0 , 1 , 2.0 ),
186- (32000 , 0.8 , 0.8 , 50 , 1.0 ),
187- (32000 , 0.8 , 1.0 , 25 , 1.0 ),
176+ (512 , 0.8 , 0.8 , 3 , 0.5 ),
177+ (4096 , 0.05 , 0.9 , 5 , 1.0 ),
178+ (16384 , 0.15 , 0.85 , 10 , 2.0 ),
179+ (512 , 0.08 , 0 , 3 , 0.5 ),
180+ (4096 , 0.5 , 0.9 , 1 , 1.0 ),
181+ (16384 , 0.15 , 0 , 1 , 2.0 ),
182+ (16384 , 0.15 , 0 , 1 , 2.0 ),
183+ (32000 , 0.08 , 0.8 , 50 , 1.0 ),
184+ (32000 , 0.08 , 1.0 , 25 , 1.0 ),
185+ # (119696, 0.01, 1.0, 100, 1.0),
188186 ]
189187
190188 args = get_args ()
@@ -228,4 +226,4 @@ def test_ascend(lib, test_cases):
228226 test_ascend (lib , test_cases )
229227 if not (args .cpu or args .cuda or args .bang or args .ascend ):
230228 test_cpu (lib , test_cases )
231- print ("Test passed!" )
229+ print ("\033 [92mTest passed!\033 [0m " )
0 commit comments