From 78fd51a674727f6bc077391e4bb9c90e0e3a2ed5 Mon Sep 17 00:00:00 2001 From: Wang Xin Date: Sat, 16 Nov 2024 00:40:14 +0000 Subject: [PATCH] merge from #13524 --- tools/infer/utility.py | 35 +++++++++++++++++------------------ 1 file changed, 17 insertions(+), 18 deletions(-) diff --git a/tools/infer/utility.py b/tools/infer/utility.py index c796f5b963..800d12d7b1 100644 --- a/tools/infer/utility.py +++ b/tools/infer/utility.py @@ -232,29 +232,28 @@ def create_predictor(args, mode, logger): else: file_names = ["model", "inference"] for file_name in file_names: - model_file_path = "{}/{}.pdmodel".format(model_dir, file_name) - params_file_path = "{}/{}.pdiparams".format(model_dir, file_name) - model_json_file_path = "{}/{}.json".format(model_dir, file_name) - if ( - os.path.exists(model_file_path) or os.path.exists(model_json_file_path) - ) and os.path.exists(params_file_path): + params_file_path = f"{model_dir}/{file_name}.pdiparams" + if os.path.exists(params_file_path): break - if not os.path.exists(model_file_path) and not os.path.exists( - model_json_file_path - ): - raise ValueError( - "not find model.pdmodel or inference.pdmodel in {}".format(model_dir) - ) - if not os.path.exists(model_file_path) and os.path.exists(model_json_file_path): - model_file_path = model_json_file_path + if not os.path.exists(params_file_path): + raise ValueError(f"not find {file_name}.pdiparams in {model_dir}") + + if not ( + os.path.exists(f"{model_dir}/{file_name}.pdmodel") + or os.path.exists(f"{model_dir}/{file_name}.json") + ): raise ValueError( - "not find model.pdiparams or inference.pdiparams in {}".format( - model_dir - ) + f"neither {file_name}.json nor {file_name}.pdmodel was found in {model_dir}." ) - config = inference.Config(model_file_path, params_file_path) + if paddle.__version__ == "0.0.0" or paddle.__version__ >= "3.0.0": + model_path = model_dir + model_prefix = file_name + config = inference.Config(model_path, model_prefix) + else: + model_file_path = f"{model_dir}/{file_name}.pdmodel" + config = inference.Config(model_file_path, params_file_path) if hasattr(args, "precision"): if args.precision == "fp16" and args.use_tensorrt: