From 511514a451aabcd74f4c757f7e359781f9ad1978 Mon Sep 17 00:00:00 2001 From: Wang Chengke Date: Wed, 29 May 2019 08:39:05 +0800 Subject: [PATCH] feature: select the problem --- student_test.py | 106 ++++++++++++++++++++++++++++++++---------------- 1 file changed, 71 insertions(+), 35 deletions(-) diff --git a/student_test.py b/student_test.py index ec24022..7c2f339 100644 --- a/student_test.py +++ b/student_test.py @@ -9,6 +9,7 @@ import tvm import torch import numpy as np +import argparse import config @@ -566,7 +567,7 @@ def daemon(self, val): proc.__class__ = NonDaemonProcess return proc -def parallel_evaluate(parallel=1): +def parallel_evaluate(parallel=1, problem="both"): """evaluate process student level : synchro @@ -583,9 +584,27 @@ def parallel_evaluate(parallel=1): # test coeffs; currently random conv2d_shapes = config.conv_shapes.copy() gemm_shapes = config.gemm_shapes.copy() - np.random.shuffle(conv2d_shapes) - np.random.shuffle(gemm_shapes) - score_item = ['gemm_' + str(s) for s in gemm_shapes] + ['conv2d_' + str(s) for s in conv2d_shapes] + + # remove randomness + # np.random.shuffle(conv2d_shapes) + # np.random.shuffle(gemm_shapes) + + if problem == "gemm": + is_gemm, is_conv = True, False + elif problem == "conv": + is_gemm, is_conv = False, True + else: + is_gemm, is_conv = True, True + + problem_gemm = ['gemm_' + str(s) for s in gemm_shapes] + problem_conv = ['conv2d_' + str(s) for s in conv2d_shapes] + + if problem == "gemm": + score_item = problem_gemm + elif problem == "conv": + score_item = problem_conv + else: + score_item = problem_gemm + problem_conv target = 'llvm' # for stdout logs @@ -628,45 +647,55 @@ def pool_evaluate(shapes, veri_func, func, target="llvm"): sys.stdout.flush() # evaluate - num_gemms = len(gemm_shapes) - outer = ceil(num_gemms / parallel) - gemm_ret = [] - gemm_error_count = 0 - for i in range(outer): - part_gemm_ret, part_gemm_error = pool_evaluate(gemm_shapes[i * parallel:(i+1) * parallel], torch_gemm, batch_gemm, target) - gemm_ret.extend(part_gemm_ret) - gemm_error_count += part_gemm_error - - num_convs = len(conv2d_shapes) - outer = ceil(num_convs / parallel) - conv_ret = [] - conv_error_count = 0 - for i in range(outer): - part_conv_ret, part_conv_error = pool_evaluate(conv2d_shapes[i * parallel:(i+1) * parallel], torch_conv2d, conv2d_nchw, target) - conv_ret.extend(part_conv_ret) - conv_error_count += part_conv_error - + gemm_error_count, conv_error_count = None, None + if is_gemm: + num_gemms = len(gemm_shapes) + outer = ceil(num_gemms / parallel) + gemm_ret = [] + gemm_error_count = 0 + for i in range(outer): + part_gemm_ret, part_gemm_error = pool_evaluate(gemm_shapes[i * parallel:(i+1) * parallel], torch_gemm, batch_gemm, target) + gemm_ret.extend(part_gemm_ret) + gemm_error_count += part_gemm_error + + if is_conv: + num_convs = len(conv2d_shapes) + outer = ceil(num_convs / parallel) + conv_ret = [] + conv_error_count = 0 + for i in range(outer): + part_conv_ret, part_conv_error = pool_evaluate(conv2d_shapes[i * parallel:(i+1) * parallel], torch_conv2d, conv2d_nchw, target) + conv_ret.extend(part_conv_ret) + conv_error_count += part_conv_error + if gemm_error_count or conv_error_count: exception_info = ' exception raises in {} cases'.format(gemm_error_count + conv_error_count) else: exception_info = ' No exceptions' - + print() print("#####################################################") print("The results:\n") - string = "Time costs of GEMMs\n" - for shape, ret in zip(gemm_shapes, gemm_ret): - times = [ret[0] if ret[0] > 0 else "Timeout", ret[1] if ret[1] > 0 else "Not evaluted"] - string += "{}: yours: {}(ms), torch: {}(ms)\n".format(shape, times[0], times[1]) - print(string) - string = "Time costs of Conv2ds\n" - for shape, ret in zip(conv2d_shapes, conv_ret): - times = [ret[0] if ret[0] > 0 else "Timeout", ret[1] if ret[1] > 0 else "Not evaluted"] - string += "{}: yours: {}(ms), torch: {}(ms)\n".format(shape, times[0], times[1]) - print(string) + evaluate_return = [] + + if is_gemm: + string = "Time costs of GEMMs\n" + for shape, ret in zip(gemm_shapes, gemm_ret): + times = [ret[0] if ret[0] > 0 else "Timeout", ret[1] if ret[1] > 0 else "Not evaluted"] + string += "{}: yours: {}(ms), torch: {}(ms)\n".format(shape, times[0], times[1]) + print(string) + evaluate_return += gemm_ret - score_list = list(map(score_calculate, gemm_ret + conv_ret)) + if is_conv: + string = "Time costs of Conv2ds\n" + for shape, ret in zip(conv2d_shapes, conv_ret): + times = [ret[0] if ret[0] > 0 else "Timeout", ret[1] if ret[1] > 0 else "Not evaluted"] + string += "{}: yours: {}(ms), torch: {}(ms)\n".format(shape, times[0], times[1]) + print(string) + evaluate_return += conv_ret + + score_list = list(map(score_calculate, evaluate_return)) write_score(res_path, score_list, score_item, exception_info) @@ -747,4 +776,11 @@ def score_calculate(time_tuple): return 7.0 if __name__ == '__main__': - parallel_evaluate() \ No newline at end of file + parser = argparse.ArgumentParser(description='') + parser.add_argument('problem', type=str, nargs='?', + default='both', + help='the problem (gemm or conv) example: python student_test.py conv') + args = parser.parse_args() + + + parallel_evaluate(problem = args.problem)