diff --git a/test/saber/test_saber_base.h b/test/saber/test_saber_base.h index bd3ff87c8..10f9d7709 100644 --- a/test/saber/test_saber_base.h +++ b/test/saber/test_saber_base.h @@ -255,7 +255,7 @@ class TestSaberBase{ for(int input_index = 0; input_index < _inputs_dev.size(); ++input_index){ _base_op.init(_inputs_dev[input_index], _outputs_dev[input_index], _params[param_index], strategy, implenum, ctx); - for(int iter=0; iter<100; ++iter){ + for(int iter=0; iter<_gpu_iters; ++iter){ _outputs_dev[input_index][0]->copy_from(*_outputs_host[input_index][0]); status= _base_op(_inputs_dev[input_index], _outputs_dev[input_index], _params[param_index], ctx); @@ -325,6 +325,7 @@ class TestSaberBase{ std :: vector runtype{"STATIC", "RUNTIME", "SPECIFY"}; std :: vector impltype{"VENDER", "SABER"}; + get_cpu_result(CpuFunc);//first get cpu result for(auto strate : {SPECIFY, RUNTIME, STATIC}){ for(auto implenum : {VENDER_IMPL, SABER_IMPL}){ LOG(INFO) << "TESTING: strategy:" << runtype[strate-1] << ",impltype:" << impltype[(int)implenum]; @@ -332,7 +333,6 @@ class TestSaberBase{ LOG(INFO) << "Unimpl!!"; continue; } - get_cpu_result(CpuFunc); result_check_accuracy(succ_ratio); } } @@ -342,6 +342,9 @@ class TestSaberBase{ void set_random_output(bool random_output) { _use_random_output = random_output; } + void set_gpu_iters(int iters){ + _gpu_iters = iters; + } private: int _op_input_num; int _op_output_num; @@ -358,6 +361,7 @@ class TestSaberBase{ std :: vector> _input_shapes; std :: vector _params; bool _use_random_output{false}; + int _gpu_iters{1}; };//testsaberbase }//namespace saber }//namespace anakin