Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

clean up opts #2545

Merged
merged 1 commit into from
Jan 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions docs/source/examples/replicate_vicuna/chatbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,7 @@ def prune_history(user_messages_sizes, bot_messages_sizes, max_history_size):

def _get_parser():
parser = ArgumentParser(description="chatbot.py")
opts.config_opts(parser)
opts.translate_opts(parser, dynamic=True)
opts.translate_opts(parser)
opts.model_opts(parser)
return parser

Expand Down
3 changes: 1 addition & 2 deletions docs/source/examples/replicate_vicuna/simple_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,7 @@

def _get_parser():
parser = ArgumentParser(description="simple_inference_engine_py.py")
opts.config_opts(parser)
opts.translate_opts(parser, dynamic=True)
opts.translate_opts(parser)
opts.model_opts(parser)
return parser

Expand Down
4 changes: 1 addition & 3 deletions eval_llm/MMLU-FR/run_mmlu_opennmt_fr.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,9 +196,7 @@ def evaluate(opt):

def _get_parser():
parser = ArgumentParser(description="run_mmlu_opennmt_fr.py")

opts.config_opts(parser)
opts.translate_opts(parser, dynamic=True)
opts.translate_opts(parser)
return parser


Expand Down
4 changes: 1 addition & 3 deletions eval_llm/MMLU/run_mmlu_opennmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,9 +198,7 @@ def evaluate(opt):

def _get_parser():
parser = ArgumentParser(description="run_mmlu_opennmt.py")

opts.config_opts(parser)
opts.translate_opts(parser, dynamic=True)
opts.translate_opts(parser)
return parser


Expand Down
4 changes: 2 additions & 2 deletions onmt/bin/build_vocab.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from onmt.utils.logging import init_logger, logger
from onmt.utils.misc import set_random_seed, check_path
from onmt.utils.parse import ArgumentParser
from onmt.opts import dynamic_prepare_opts
from onmt.opts import data_prepare_opts
from onmt.inputters.text_corpus import build_corpora_iters, get_corpora
from onmt.inputters.text_utils import process, append_features_to_text
from onmt.transforms import make_transforms, get_transforms_cls
Expand Down Expand Up @@ -273,7 +273,7 @@ def save_counter(counter, save_path):

def _get_parser():
parser = ArgumentParser(description="build_vocab.py")
dynamic_prepare_opts(parser, build_vocab_only=True)
data_prepare_opts(parser, build_vocab_only=True)
return parser


Expand Down
4 changes: 2 additions & 2 deletions onmt/bin/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def _get_parser():
parser.add_argument("--url_root", type=str, default="/translator")
parser.add_argument("--debug", "-d", action="store_true")
parser.add_argument(
"--config", "-c", type=str, default="./available_models/conf.json"
"--model_config", "-m", type=str, default="./available_models/conf.json"
)
return parser

Expand All @@ -155,7 +155,7 @@ def main():
parser = _get_parser()
args = parser.parse_args()
start(
args.config,
args.model_config,
url_root=args.url_root,
host=args.ip,
port=args.port,
Expand Down
5 changes: 2 additions & 3 deletions onmt/bin/translate.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from onmt.inference_engine import InferenceEnginePY
from onmt.opts import config_opts, translate_opts
from onmt.opts import translate_opts
from onmt.utils.parse import ArgumentParser
from onmt.utils.misc import use_gpu, set_random_seed
from torch.profiler import profile, record_function, ProfilerActivity
Expand All @@ -23,8 +23,7 @@ def translate(opt):

def _get_parser():
parser = ArgumentParser(description="translate.py")
config_opts(parser)
translate_opts(parser, dynamic=True)
translate_opts(parser)
return parser


Expand Down
115 changes: 54 additions & 61 deletions onmt/opts.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,8 +164,8 @@ def _add_reproducibility_opts(parser):
)


def _add_dynamic_corpus_opts(parser, build_vocab_only=False):
"""Options related to training corpus, type: a list of dictionary."""
def _add_dataset_opts(parser, build_vocab_only=False):
"""Options related to training datasets, type: a list of dictionary."""
group = parser.add_argument_group("Data")
group.add(
"-data",
Expand Down Expand Up @@ -278,7 +278,7 @@ def _add_features_opts(parser):
)


def _add_dynamic_vocab_opts(parser, build_vocab_only=False):
def _add_vocab_opts(parser, build_vocab_only=False):
"""Options related to vocabulary and features.

Add all options related to vocabulary or features to parser.
Expand Down Expand Up @@ -412,7 +412,7 @@ def _add_dynamic_vocab_opts(parser, build_vocab_only=False):
)


def _add_dynamic_transform_opts(parser):
def _add_transform_opts(parser):
"""Options related to transforms.

Options that are specified in the definitions of each transform class
Expand All @@ -422,17 +422,17 @@ def _add_dynamic_transform_opts(parser):
transform_cls.add_options(parser)


def dynamic_prepare_opts(parser, build_vocab_only=False):
def data_prepare_opts(parser, build_vocab_only=False):
"""Options related to data preparation in dynamic mode.

Add all dynamic data preparation options to parser.
If `build_vocab_only` is set to True, only the options that
will be used in `onmt/bin/build_vocab.py` are included.
"""
config_opts(parser)
_add_dynamic_corpus_opts(parser, build_vocab_only=build_vocab_only)
_add_dynamic_vocab_opts(parser, build_vocab_only=build_vocab_only)
_add_dynamic_transform_opts(parser)
_add_dataset_opts(parser, build_vocab_only=build_vocab_only)
_add_vocab_opts(parser, build_vocab_only=build_vocab_only)
_add_transform_opts(parser)

if build_vocab_only:
_add_reproducibility_opts(parser)
Expand Down Expand Up @@ -1125,6 +1125,39 @@ def _add_train_general_opts(parser):
help="Type of the source input. " "Options are [text].",
)

group.add(
"-bucket_size",
"--bucket_size",
type=int,
default=262144,
help="""A bucket is a buffer of bucket_size examples to pick
from the various Corpora. The dynamic iterator builds
batches of batch_size from the bucket and shuffles them.""",
)
group.add(
"-bucket_size_init",
"--bucket_size_init",
type=int,
default=-1,
help="""The bucket is initialized with this
amount of examples (optional)""",
)
group.add(
"-bucket_size_increment",
"--bucket_size_increment",
type=int,
default=0,
help="""The bucket size is incremented with this
amount of examples (optional)""",
)
group.add(
"-prefetch_factor",
"--prefetch_factor",
type=int,
default=200,
help="""number of mini-batches loaded in advance to avoid the
GPU waiting during the refilling of the bucket.""",
)
group.add(
"--save_model",
"-save_model",
Expand Down Expand Up @@ -1541,43 +1574,6 @@ def _add_train_general_opts(parser):
_add_logging_opts(parser, is_train=True)


def _add_train_dynamic_data(parser):
group = parser.add_argument_group("Dynamic data")
group.add(
"-bucket_size",
"--bucket_size",
type=int,
default=262144,
help="""A bucket is a buffer of bucket_size examples to pick
from the various Corpora. The dynamic iterator builds
batches of batch_size from the bucket and shuffles them.""",
)
group.add(
"-bucket_size_init",
"--bucket_size_init",
type=int,
default=-1,
help="""The bucket is initialized with this
amount of examples (optional)""",
)
group.add(
"-bucket_size_increment",
"--bucket_size_increment",
type=int,
default=0,
help="""The bucket size is incremented with this
amount of examples (optional)""",
)
group.add(
"-prefetch_factor",
"--prefetch_factor",
type=int,
default=200,
help="""number of mini-batches loaded in advance to avoid the
GPU waiting during the refilling of the bucket.""",
)


def _add_quant_opts(parser):
group = parser.add_argument_group("Quant options")
group.add(
Expand Down Expand Up @@ -1624,13 +1620,10 @@ def _add_quant_opts(parser):

def train_opts(parser):
"""All options used in train."""
# options related to data preparation
dynamic_prepare_opts(parser, build_vocab_only=False)
data_prepare_opts(parser, build_vocab_only=False)
distributed_opts(parser)
# options related to training
model_opts(parser)
_add_train_general_opts(parser)
_add_train_dynamic_data(parser)
_add_quant_opts(parser)


Expand Down Expand Up @@ -1796,8 +1789,9 @@ def _add_decoding_opts(parser):
)


def translate_opts(parser, dynamic=False):
def translate_opts(parser):
"""Translation / inference options"""
config_opts(parser)
group = parser.add_argument_group("Model")
group.add(
"--model",
Expand Down Expand Up @@ -1929,18 +1923,17 @@ def translate_opts(parser, dynamic=False):
)
group.add("--gpu", "-gpu", type=int, default=-1, help="Device to run on")

if dynamic:
group.add(
"-transforms",
"--transforms",
default=[],
nargs="+",
choices=AVAILABLE_TRANSFORMS.keys(),
help="Default transform pipeline to apply to data.",
)
group.add(
"-transforms",
"--transforms",
default=[],
nargs="+",
choices=AVAILABLE_TRANSFORMS.keys(),
help="Default transform pipeline to apply to data.",
)

# Adding options related to Transforms
_add_dynamic_transform_opts(parser)
# Adding options related to Transforms
_add_transform_opts(parser)

_add_quant_opts(parser)

Expand Down
4 changes: 2 additions & 2 deletions onmt/tests/test_data_prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import os

from onmt.utils.parse import ArgumentParser
from onmt.opts import dynamic_prepare_opts
from onmt.opts import data_prepare_opts
from onmt.train_single import prepare_transforms_vocabs
from onmt.constants import CorpusName

Expand All @@ -17,7 +17,7 @@

def get_default_opts():
parser = ArgumentParser(description="data sample prepare")
dynamic_prepare_opts(parser)
data_prepare_opts(parser)

default_opts = [
"-config",
Expand Down
3 changes: 1 addition & 2 deletions onmt/tests/test_inference_engines.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@

def _get_parser():
parser = ArgumentParser(description="simple_inference_engine_py.py")
opts.config_opts(parser)
opts.translate_opts(parser, dynamic=True)
opts.translate_opts(parser)
return parser


Expand Down
3 changes: 1 addition & 2 deletions onmt/utils/scoring_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import os
from onmt.utils.parse import ArgumentParser
from onmt.translate import GNMTGlobalScorer, Translator
from onmt.opts import config_opts, translate_opts
from onmt.opts import translate_opts
from onmt.constants import CorpusTask
from onmt.inputters.dynamic_iterator import build_dynamic_dataset_iter
from onmt.transforms import get_transforms_cls, make_transforms, TransformPipe
Expand Down Expand Up @@ -51,7 +51,6 @@ def translate(self, model, gpu_rank, step):

# Set "default" translation options on empty cfgfile
parser = ArgumentParser()
config_opts(parser)
translate_opts(parser)
base_args = ["-model", "dummy"] + ["-src", "dummy"]
opt = parser.parse_args(base_args)
Expand Down
1 change: 0 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@
"onmt_server=onmt.bin.server:main",
"onmt_train=onmt.bin.train:main",
"onmt_translate=onmt.bin.translate:main",
"onmt_translate_dynamic=onmt.bin.translate_dynamic:main",
"onmt_release_model=onmt.bin.release_model:main",
"onmt_average_models=onmt.bin.average_models:main",
"onmt_build_vocab=onmt.bin.build_vocab:main",
Expand Down
3 changes: 1 addition & 2 deletions tools/LM_scoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,7 @@

def _get_parser():
parser = ArgumentParser(description="LM_scoring.py")
opts.config_opts(parser)
opts.translate_opts(parser, dynamic=True)
opts.translate_opts(parser)
return parser


Expand Down
Loading