diff --git a/mammoth/opts.py b/mammoth/opts.py index ef0b6514..1ff4524b 100644 --- a/mammoth/opts.py +++ b/mammoth/opts.py @@ -903,7 +903,7 @@ def translate_opts(parser, dynamic=False): _add_logging_opts(parser, is_train=False) group = parser.add_argument_group('Efficiency') - group.add('--batch_size', '-batch_size', type=int, default=300, help='Batch size') + group.add('--batch_size', '-batch_size', type=int, default=200, help='Batch size') group.add( '--batch_type', '-batch_type', @@ -911,7 +911,7 @@ def translate_opts(parser, dynamic=False): choices=["sents", "tokens"], help="Batch grouping for batch_size. Standard is tokens (max of src and tgt). Sents is unimplemented.", ) - group.add('--gpu', '-gpu', type=int, default=-1, help="Device to run on") + group.add('--gpu_rank', '-gpu_rank', type=int, default=-1, help="Device to run on") group.add( "--output_model", diff --git a/mammoth/translate/translation_server.py b/mammoth/translate/translation_server.py index d96bc784..541ea262 100644 --- a/mammoth/translate/translation_server.py +++ b/mammoth/translate/translation_server.py @@ -119,7 +119,7 @@ def setdefault_if_exists_must_match(obj, name, value): onmt_for_translator = { "device": "cuda" if opts.cuda else "cpu", - "device_index": opts.gpu if opts.cuda else 0, + "device_index": opts.gpu_rank if opts.cuda else 0, } for name, value in onmt_for_translator.items(): setdefault_if_exists_must_match(ct2_translator_args, name, value) @@ -417,7 +417,7 @@ def parse_opt(self, opts): ArgumentParser.validate_prepare_opts(opts) ArgumentParser.validate_translate_opts(opts) ArgumentParser.validate_translate_opts_dynamic(opts) - opts.cuda = opts.gpu > -1 + opts.cuda = opts.gpu_rank > -1 sys.argv = prec_argv return opts @@ -727,7 +727,7 @@ def to_gpu(self): if isinstance(self.translator, CTranslate2Translator): self.translator.to_gpu() else: - torch.cuda.set_device(self.opts.gpu) + torch.cuda.set_device(self.opts.gpu_rank) self.translator.model.cuda() def maybe_preprocess(self, sequence): diff --git a/mammoth/translate/translator.py b/mammoth/translate/translator.py index 295ade6a..9ac54abc 100644 --- a/mammoth/translate/translator.py +++ b/mammoth/translate/translator.py @@ -28,7 +28,7 @@ def build_translator(opts, task_queue_manager, task, report_score=True, logger=N if out_file is None: outdir = os.path.dirname(opts.output) if outdir and not os.path.isdir(outdir): - warnings.warning(f'output file directory "{outdir}" does not exist... creating it.') + warnings.warn(f'output file directory "{outdir}" does not exist... creating it.') os.makedirs(os.path.dirname(opts.output), exist_ok=True) out_file = codecs.open(opts.output, "w+", "utf-8") @@ -308,7 +308,7 @@ def from_opts( vocabs, opts.src, tgt_file_path=opts.tgt, - gpu=opts.gpu, + gpu=opts.gpu_rank, n_best=opts.n_best, min_length=opts.min_length, max_length=opts.max_length, @@ -831,6 +831,8 @@ def _translate_batch_with_strategy(self, batch, src_vocabs, decode_strategy): task_id=metadata.corpus_id, adapter_ids=metadata.decoder_adapter_ids, ) + active_encoder.to(self._device) + active_decoder.to(self._device) # (2) Run the encoder on the src encoder_output, src_mask = self._run_encoder(active_encoder, batch)