Skip to content

Commit

Permalink
Unify the opt for using gpu
Browse files Browse the repository at this point in the history
Now both training and translation select the device with the same `--gpu_rank` option (e.g. `--gpu_rank 0`), instead of translation using a separate `--gpu` flag.

Closes #82
  • Loading branch information
Waino committed Dec 9, 2024
1 parent 47beb04 commit 18458d5
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 7 deletions.
4 changes: 2 additions & 2 deletions mammoth/opts.py
Original file line number Diff line number Diff line change
Expand Up @@ -903,15 +903,15 @@ def translate_opts(parser, dynamic=False):
_add_logging_opts(parser, is_train=False)

group = parser.add_argument_group('Efficiency')
group.add('--batch_size', '-batch_size', type=int, default=300, help='Batch size')
group.add('--batch_size', '-batch_size', type=int, default=200, help='Batch size')
group.add(
'--batch_type',
'-batch_type',
default='tokens',
choices=["sents", "tokens"],
help="Batch grouping for batch_size. Standard is tokens (max of src and tgt). Sents is unimplemented.",
)
group.add('--gpu', '-gpu', type=int, default=-1, help="Device to run on")
group.add('--gpu_rank', '-gpu_rank', type=int, default=-1, help="Device to run on")

group.add(
"--output_model",
Expand Down
6 changes: 3 additions & 3 deletions mammoth/translate/translation_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def setdefault_if_exists_must_match(obj, name, value):

onmt_for_translator = {
"device": "cuda" if opts.cuda else "cpu",
"device_index": opts.gpu if opts.cuda else 0,
"device_index": opts.gpu_rank if opts.cuda else 0,
}
for name, value in onmt_for_translator.items():
setdefault_if_exists_must_match(ct2_translator_args, name, value)
Expand Down Expand Up @@ -417,7 +417,7 @@ def parse_opt(self, opts):
ArgumentParser.validate_prepare_opts(opts)
ArgumentParser.validate_translate_opts(opts)
ArgumentParser.validate_translate_opts_dynamic(opts)
opts.cuda = opts.gpu > -1
opts.cuda = opts.gpu_rank > -1

sys.argv = prec_argv
return opts
Expand Down Expand Up @@ -727,7 +727,7 @@ def to_gpu(self):
if isinstance(self.translator, CTranslate2Translator):
self.translator.to_gpu()
else:
torch.cuda.set_device(self.opts.gpu)
torch.cuda.set_device(self.opts.gpu_rank)
self.translator.model.cuda()

def maybe_preprocess(self, sequence):
Expand Down
6 changes: 4 additions & 2 deletions mammoth/translate/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def build_translator(opts, task_queue_manager, task, report_score=True, logger=N
if out_file is None:
outdir = os.path.dirname(opts.output)
if outdir and not os.path.isdir(outdir):
warnings.warning(f'output file directory "{outdir}" does not exist... creating it.')
warnings.warn(f'output file directory "{outdir}" does not exist... creating it.')
os.makedirs(os.path.dirname(opts.output), exist_ok=True)
out_file = codecs.open(opts.output, "w+", "utf-8")

Expand Down Expand Up @@ -308,7 +308,7 @@ def from_opts(
vocabs,
opts.src,
tgt_file_path=opts.tgt,
gpu=opts.gpu,
gpu=opts.gpu_rank,
n_best=opts.n_best,
min_length=opts.min_length,
max_length=opts.max_length,
Expand Down Expand Up @@ -831,6 +831,8 @@ def _translate_batch_with_strategy(self, batch, src_vocabs, decode_strategy):
task_id=metadata.corpus_id,
adapter_ids=metadata.decoder_adapter_ids,
)
active_encoder.to(self._device)
active_decoder.to(self._device)

# (2) Run the encoder on the src
encoder_output, src_mask = self._run_encoder(active_encoder, batch)
Expand Down

0 comments on commit 18458d5

Please sign in to comment.