diff --git a/onmt/inference_engine.py b/onmt/inference_engine.py index b088f497ef..1998abaa51 100755 --- a/onmt/inference_engine.py +++ b/onmt/inference_engine.py @@ -155,6 +155,13 @@ def __init__(self, opt): self.transforms_cls = get_transforms_cls(opt._all_transform) self.vocabs = self.translator.vocabs + def warm_up(self): + from onmt.translate.translator import build_translator + + self.translator = build_translator( + self.opt, self.device_id, logger=self.logger, report_score=True + ) + def _translate(self, infer_iter): scores, preds = self.translator._translate( infer_iter, infer_iter.transforms, self.opt.attn_debug, self.opt.align_debug