diff --git a/tools/infer/utility.py b/tools/infer/utility.py index f019e97e86..4f5a501c3c 100644 --- a/tools/infer/utility.py +++ b/tools/infer/utility.py @@ -215,22 +215,29 @@ def create_predictor(args, mode, logger): else: file_names = ["model", "inference"] for file_name in file_names: - model_file_path = "{}/{}.pdmodel".format(model_dir, file_name) params_file_path = "{}/{}.pdiparams".format(model_dir, file_name) - if os.path.exists(model_file_path) and os.path.exists(params_file_path): + if os.path.exists(params_file_path): break - if not os.path.exists(model_file_path): + + if not os.path.exists(params_file_path): raise ValueError( - "not find model.pdmodel or inference.pdmodel in {}".format(model_dir) + f"not find {file_name}.pdiparams or {file_name}.pdiparams in {model_dir}" ) - if not os.path.exists(params_file_path): + + if not os.path.exists( + "{}/{}.pdmodel".format(model_dir, file_name) + ) and not os.path.exists("{}/{}.json".format(model_dir, file_name)): raise ValueError( - "not find model.pdiparams or inference.pdiparams in {}".format( - model_dir - ) + f"not find {file_name}.json or {file_name}.pdmodel in {model_dir}" ) - config = inference.Config(model_file_path, params_file_path) + if paddle.__version__ == "0.0.0" or paddle.__version__ >= "3.0.0": + model_path = model_dir + model_prefix = file_name + config = inference.Config(model_path, model_prefix) + else: + model_file_path = "{}/{}.pdmodel".format(model_dir, file_name) + config = inference.Config(model_file_path, params_file_path) if hasattr(args, "precision"): if args.precision == "fp16" and args.use_tensorrt: