106 changes: 71 additions & 35 deletions student_test.py
@@ -9,6 +9,7 @@
import tvm
import torch
import numpy as np
import argparse

import config

@@ -566,7 +567,7 @@ def daemon(self, val):
proc.__class__ = NonDaemonProcess
return proc

def parallel_evaluate(parallel=1):
def parallel_evaluate(parallel=1, problem="both"):
"""evaluate process

student level : synchro
@@ -583,9 +584,27 @@ def parallel_evaluate(parallel=1):
# test coeffs; currently random
conv2d_shapes = config.conv_shapes.copy()
gemm_shapes = config.gemm_shapes.copy()
np.random.shuffle(conv2d_shapes)
np.random.shuffle(gemm_shapes)
score_item = ['gemm_' + str(s) for s in gemm_shapes] + ['conv2d_' + str(s) for s in conv2d_shapes]

# remove randomness
# np.random.shuffle(conv2d_shapes)
# np.random.shuffle(gemm_shapes)

if problem == "gemm":
is_gemm, is_conv = True, False
elif problem == "conv":
is_gemm, is_conv = False, True
else:
is_gemm, is_conv = True, True

problem_gemm = ['gemm_' + str(s) for s in gemm_shapes]
problem_conv = ['conv2d_' + str(s) for s in conv2d_shapes]

if problem == "gemm":
score_item = problem_gemm
elif problem == "conv":
score_item = problem_conv
else:
score_item = problem_gemm + problem_conv
target = 'llvm'

# for stdout logs
@@ -628,45 +647,55 @@ def pool_evaluate(shapes, veri_func, func, target="llvm"):
sys.stdout.flush()

# evaluate
num_gemms = len(gemm_shapes)
outer = ceil(num_gemms / parallel)
gemm_ret = []
gemm_error_count = 0
for i in range(outer):
part_gemm_ret, part_gemm_error = pool_evaluate(gemm_shapes[i * parallel:(i+1) * parallel], torch_gemm, batch_gemm, target)
gemm_ret.extend(part_gemm_ret)
gemm_error_count += part_gemm_error

num_convs = len(conv2d_shapes)
outer = ceil(num_convs / parallel)
conv_ret = []
conv_error_count = 0
for i in range(outer):
part_conv_ret, part_conv_error = pool_evaluate(conv2d_shapes[i * parallel:(i+1) * parallel], torch_conv2d, conv2d_nchw, target)
conv_ret.extend(part_conv_ret)
conv_error_count += part_conv_error

gemm_error_count, conv_error_count = 0, 0  # use 0 rather than None so the summed error count below never adds None
if is_gemm:
num_gemms = len(gemm_shapes)
outer = ceil(num_gemms / parallel)
gemm_ret = []
gemm_error_count = 0
for i in range(outer):
part_gemm_ret, part_gemm_error = pool_evaluate(gemm_shapes[i * parallel:(i+1) * parallel], torch_gemm, batch_gemm, target)
gemm_ret.extend(part_gemm_ret)
gemm_error_count += part_gemm_error

if is_conv:
num_convs = len(conv2d_shapes)
outer = ceil(num_convs / parallel)
conv_ret = []
conv_error_count = 0
for i in range(outer):
part_conv_ret, part_conv_error = pool_evaluate(conv2d_shapes[i * parallel:(i+1) * parallel], torch_conv2d, conv2d_nchw, target)
conv_ret.extend(part_conv_ret)
conv_error_count += part_conv_error

if gemm_error_count or conv_error_count:
exception_info = ' exceptions raised in {} cases'.format(gemm_error_count + conv_error_count)
else:
exception_info = ' No exceptions'

print()
print("#####################################################")
print("The results:\n")
string = "Time costs of GEMMs\n"
for shape, ret in zip(gemm_shapes, gemm_ret):
times = [ret[0] if ret[0] > 0 else "Timeout", ret[1] if ret[1] > 0 else "Not evaluated"]
string += "{}: yours: {}(ms), torch: {}(ms)\n".format(shape, times[0], times[1])
print(string)

string = "Time costs of Conv2ds\n"
for shape, ret in zip(conv2d_shapes, conv_ret):
times = [ret[0] if ret[0] > 0 else "Timeout", ret[1] if ret[1] > 0 else "Not evaluated"]
string += "{}: yours: {}(ms), torch: {}(ms)\n".format(shape, times[0], times[1])
print(string)
evaluate_return = []

if is_gemm:
string = "Time costs of GEMMs\n"
for shape, ret in zip(gemm_shapes, gemm_ret):
times = [ret[0] if ret[0] > 0 else "Timeout", ret[1] if ret[1] > 0 else "Not evaluated"]
string += "{}: yours: {}(ms), torch: {}(ms)\n".format(shape, times[0], times[1])
print(string)
evaluate_return += gemm_ret

score_list = list(map(score_calculate, gemm_ret + conv_ret))
if is_conv:
string = "Time costs of Conv2ds\n"
for shape, ret in zip(conv2d_shapes, conv_ret):
times = [ret[0] if ret[0] > 0 else "Timeout", ret[1] if ret[1] > 0 else "Not evaluated"]
string += "{}: yours: {}(ms), torch: {}(ms)\n".format(shape, times[0], times[1])
print(string)
evaluate_return += conv_ret

score_list = list(map(score_calculate, evaluate_return))

write_score(res_path, score_list, score_item, exception_info)

@@ -747,4 +776,11 @@ def score_calculate(time_tuple):
return 7.0

if __name__ == '__main__':
parallel_evaluate()
parser = argparse.ArgumentParser(description='evaluate gemm/conv2d schedules')
parser.add_argument('problem', type=str, nargs='?',
default='both',
help='which problem to evaluate: gemm, conv, or both (default); example: python student_test.py conv')
args = parser.parse_args()

parallel_evaluate(problem=args.problem)
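
With this change, the problem set can be selected either from the command line or by calling `parallel_evaluate` directly. A minimal usage sketch, assuming `student_test.py` is importable and has no import-time side effects:

```python
# Hypothetical calls illustrating the new `problem` argument.
from student_test import parallel_evaluate

# Evaluate only the GEMM shapes, equivalent to: python student_test.py gemm
parallel_evaluate(problem="gemm")

# Evaluate only the Conv2d shapes, equivalent to: python student_test.py conv
parallel_evaluate(problem="conv")

# Default behaviour: evaluate both problem sets
parallel_evaluate()
```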