diff --git a/ppqi/inference.py b/ppqi/inference.py
index 75dbb6c..63d3170 100644
--- a/ppqi/inference.py
+++ b/ppqi/inference.py
@@ -64,8 +64,10 @@ def load_config(self, modelpath, use_gpu, gpu_id, use_mkldnn, cpu_threads):
                 print(
                     '''Error! Unable to use GPU. Please set the environment variables "CUDA_VISIBLE_DEVICES=GPU_id" to use GPU. Now switch to CPU to continue...''')
                 use_gpu = False
-
-        if os.path.isdir(modelpath):
+
+        if isinstance(modelpath, Config):
+            config = modelpath
+        elif os.path.isdir(modelpath):
             if os.path.exists(os.path.join(modelpath, "__params__")):
                 # __model__ + __params__
                 model = os.path.join(modelpath, "__model__")
@@ -87,8 +89,6 @@ def load_config(self, modelpath, use_gpu, gpu_id, use_mkldnn, cpu_threads):
             model = modelpath + ".pdmodel"
             params = modelpath + ".pdiparams"
             config = Config(model, params)
-        elif isinstance(modelpath, Config):
-            config = modelpath
         else:
             raise Exception(
                 "Error! Can\'t find the model in: %s. Please check your model path." % os.path.abspath(modelpath))
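
A minimal usage sketch of what this change enables, assuming `Config` is `paddle.inference.Config` and that `load_config` is driven through ppqi's `InferenceModel` wrapper; the constructor keywords and model paths below are illustrative assumptions, not taken from this diff:

```python
# Hedged sketch: the two forms of `modelpath` accepted after this change.
# Assumes InferenceModel forwards its arguments to load_config and that
# `Config` is paddle.inference.Config; file names are placeholders.
from paddle.inference import Config
from ppqi import InferenceModel

# 1) Path-based loading, unchanged: a directory containing
#    __model__/__params__ (or model/params files), or a file prefix
#    resolved to <prefix>.pdmodel / <prefix>.pdiparams.
model = InferenceModel(
    modelpath="mobilenet_v2",
    use_gpu=False,
    use_mkldnn=False,
    cpu_threads=4
)

# 2) New: pass a pre-built Config directly, e.g. to tune the
#    predictor before the model is loaded.
config = Config("mobilenet_v2.pdmodel", "mobilenet_v2.pdiparams")
config.enable_memory_optim()
model = InferenceModel(modelpath=config)
```

Moving the `isinstance(modelpath, Config)` check ahead of the path checks matters: `os.path.isdir` raises a `TypeError` when given a non-path object such as a `Config`, so the old trailing `elif isinstance(...)` branch could never be reached.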