diff --git a/nmt/inference.py b/nmt/inference.py index 2cbef07c2..90fe5d287 100644 --- a/nmt/inference.py +++ b/nmt/inference.py @@ -95,10 +95,13 @@ def get_model_creator(hparams): return model_creator -def start_sess_and_load_model(infer_model, ckpt_path): +def start_sess_and_load_model(infer_model, ckpt_path, hparams): """Start session and load model.""" + print("intra inter is %d %d \n" %(hparams.num_intra_threads , hparams.num_inter_threads)) sess = tf.Session( - graph=infer_model.graph, config=utils.get_config_proto()) + graph=infer_model.graph, config=utils.get_config_proto( + num_intra_threads=hparams.num_intra_threads, + num_inter_threads=hparams.num_inter_threads)) with infer_model.graph.as_default(): loaded_infer_model = model_helper.load_model( infer_model.model, ckpt_path, sess, "infer") @@ -118,7 +121,7 @@ def inference(ckpt_path, model_creator = get_model_creator(hparams) infer_model = model_helper.create_infer_model(model_creator, hparams, scope) - sess, loaded_infer_model = start_sess_and_load_model(infer_model, ckpt_path) + sess, loaded_infer_model = start_sess_and_load_model(infer_model, ckpt_path, hparams) if num_workers == 1: single_worker_inference(