Skip to content

Commit

Permalink
merge from PaddlePaddle#13524
Browse files Browse the repository at this point in the history
  • Loading branch information
GreatV committed Nov 16, 2024
1 parent 7723cab commit 78fd51a
Showing 1 changed file with 17 additions and 18 deletions.
35 changes: 17 additions & 18 deletions tools/infer/utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,29 +232,28 @@ def create_predictor(args, mode, logger):
else:
file_names = ["model", "inference"]
for file_name in file_names:
model_file_path = "{}/{}.pdmodel".format(model_dir, file_name)
params_file_path = "{}/{}.pdiparams".format(model_dir, file_name)
model_json_file_path = "{}/{}.json".format(model_dir, file_name)
if (
os.path.exists(model_file_path) or os.path.exists(model_json_file_path)
) and os.path.exists(params_file_path):
params_file_path = f"{model_dir}/{file_name}.pdiparams"
if os.path.exists(params_file_path):
break
if not os.path.exists(model_file_path) and not os.path.exists(
model_json_file_path
):
raise ValueError(
"not find model.pdmodel or inference.pdmodel in {}".format(model_dir)
)
if not os.path.exists(model_file_path) and os.path.exists(model_json_file_path):
model_file_path = model_json_file_path

if not os.path.exists(params_file_path):
raise ValueError(f"not find {file_name}.pdiparams in {model_dir}")

if not (
os.path.exists(f"{model_dir}/{file_name}.pdmodel")
or os.path.exists(f"{model_dir}/{file_name}.json")
):
raise ValueError(
"not find model.pdiparams or inference.pdiparams in {}".format(
model_dir
)
f"neither {file_name}.json nor {file_name}.pdmodel was found in {model_dir}."
)

config = inference.Config(model_file_path, params_file_path)
if paddle.__version__ == "0.0.0" or paddle.__version__ >= "3.0.0":
model_path = model_dir
model_prefix = file_name
config = inference.Config(model_path, model_prefix)
else:
model_file_path = f"{model_dir}/{file_name}.pdmodel"
config = inference.Config(model_file_path, params_file_path)

if hasattr(args, "precision"):
if args.precision == "fp16" and args.use_tensorrt:
Expand Down

0 comments on commit 78fd51a

Please sign in to comment.