
normalize or die
Mickus Timothee committed Sep 25, 2023
1 parent f6b5773 commit 14b9fed
Showing 48 changed files with 929 additions and 941 deletions.
12 changes: 6 additions & 6 deletions mammoth/bin/average_models.py
@@ -5,7 +5,7 @@
 
 def average_models(model_files, fp32=False):
     vocab = None
-    opt = None
+    opts = None
     avg_model = None
     avg_generator = None
 
@@ -21,7 +21,7 @@ def average_models(model_files, fp32=False):
                 generator_weights[k] = v.float()
 
         if i == 0:
-            vocab, opt = m['vocab'], m['opt']
+            vocab, opts = m['vocab'], m['opts']
             avg_model = model_weights
             avg_generator = generator_weights
         else:
@@ -31,7 +31,7 @@
             for (k, v) in avg_generator.items():
                 avg_generator[k].mul_(i).add_(generator_weights[k]).div_(i + 1)
 
-    final = {"vocab": vocab, "opt": opt, "optim": None, "generator": avg_generator, "model": avg_model}
+    final = {"vocab": vocab, "opts": opts, "optim": None, "generator": avg_generator, "model": avg_model}
     return final
 
 
@@ -40,10 +40,10 @@ def main():
     parser.add_argument("-models", "-m", nargs="+", required=True, help="List of models")
     parser.add_argument("-output", "-o", required=True, help="Output file")
     parser.add_argument("-fp32", "-f", action="store_true", help="Cast params to float32")
-    opt = parser.parse_args()
+    opts = parser.parse_args()
 
-    final = average_models(opt.models, opt.fp32)
-    torch.save(final, opt.output)
+    final = average_models(opts.models, opts.fp32)
+    torch.save(final, opts.output)
 
 
 if __name__ == "__main__":
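The dictionary saved by average_models.py keeps the same top-level keys as before, with only "opt" renamed to "opts". A minimal sketch of reading such an averaged checkpoint downstream; the file name here is illustrative, not part of this commit:

    import torch

    # Hypothetical usage: checkpoints written after this commit store their
    # training options under the "opts" key (previously "opt").
    ckpt = torch.load("averaged_model.pt", map_location="cpu")
    opts = ckpt["opts"]                # training options (renamed key)
    avg_weights = ckpt["model"]        # averaged model parameters
    avg_generator = ckpt["generator"]  # averaged generator parameters
    vocab = ckpt["vocab"]
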
4 changes: 2 additions & 2 deletions mammoth/bin/build_vocab.py
@@ -32,8 +32,8 @@ def build_vocab_main(opts):
     src_counters_by_lang = defaultdict(Counter)
     tgt_counters_by_lang = defaultdict(Counter)
 
-    for corpus_id in opts.data:
-        lang_pair = opts.data[corpus_id]['src_tgt']
+    for corpus_id in opts.tasks:
+        lang_pair = opts.tasks[corpus_id]['src_tgt']
         src_lang, tgt_lang = lang_pair.split('-')
         task = TaskSpecs(
             node_rank=None,
14 changes: 7 additions & 7 deletions mammoth/bin/release_model.py
@@ -17,19 +17,19 @@ def main():
         default=None,
         help="Quantization type for CT2 model.",
     )
-    opt = parser.parse_args()
+    opts = parser.parse_args()
 
-    model = torch.load(opt.model, map_location=torch.device("cpu"))
-    if opt.format == "pytorch":
+    model = torch.load(opts.model, map_location=torch.device("cpu"))
+    if opts.format == "pytorch":
         model["optim"] = None
-        torch.save(model, opt.output)
-    elif opt.format == "ctranslate2":
+        torch.save(model, opts.output)
+    elif opts.format == "ctranslate2":
         import ctranslate2
 
         if not hasattr(ctranslate2, "__version__"):
             raise RuntimeError("onmt_release_model script requires ctranslate2 >= 2.0.0")
-        converter = ctranslate2.converters.OpenNMTPyConverter(opt.model)
-        converter.convert(opt.output, force=True, quantization=opt.quantization)
+        converter = ctranslate2.converters.OpenNMTPyConverter(opts.model)
+        converter.convert(opts.output, force=True, quantization=opts.quantization)
 
 
 if __name__ == "__main__":
4 changes: 2 additions & 2 deletions mammoth/bin/server.py
@@ -50,9 +50,9 @@ def clone_model(model_id):
         timeout = data['timeout']
         del data['timeout']
 
-    opt = data.get('opt', None)
+    opts = data.get('opts', None)
     try:
-        model_id, load_time = translation_server.clone_model(model_id, opt, timeout)
+        model_id, load_time = translation_server.clone_model(model_id, opts, timeout)
     except ServerModelError as e:
         out['status'] = STATUS_ERROR
         out['error'] = str(e)
146 changes: 73 additions & 73 deletions mammoth/bin/train.py
@@ -32,60 +32,60 @@
 torch.multiprocessing.set_sharing_strategy('file_system')
 
 
-# def prepare_fields_transforms(opt):
+# def prepare_fields_transforms(opts):
 #     """Prepare or dump fields & transforms before training."""
-#     transforms_cls = get_transforms_cls(opt._all_transform)
-#     specials = get_specials(opt, transforms_cls)
+#     transforms_cls = get_transforms_cls(opts._all_transform)
+#     specials = get_specials(opts, transforms_cls)
 #
-#     fields = build_dynamic_fields(opt, src_specials=specials['src'], tgt_specials=specials['tgt'])
+#     fields = build_dynamic_fields(opts, src_specials=specials['src'], tgt_specials=specials['tgt'])
 #
 #     # maybe prepare pretrained embeddings, if any
-#     prepare_pretrained_embeddings(opt, fields)
+#     prepare_pretrained_embeddings(opts, fields)
 #
-#     if opt.dump_fields:
-#         save_fields(fields, opt.save_data, overwrite=opt.overwrite)
-#     if opt.dump_transforms or opt.n_sample != 0:
-#         transforms = make_transforms(opt, transforms_cls, fields)
-#         if opt.dump_transforms:
-#             save_transforms(transforms, opt.save_data, overwrite=opt.overwrite)
-#         if opt.n_sample != 0:
+#     if opts.dump_fields:
+#         save_fields(fields, opts.save_data, overwrite=opts.overwrite)
+#     if opts.dump_transforms or opts.n_sample != 0:
+#         transforms = make_transforms(opts, transforms_cls, fields)
+#         if opts.dump_transforms:
+#             save_transforms(transforms, opts.save_data, overwrite=opts.overwrite)
+#         if opts.n_sample != 0:
 #             logger.warning(
-#                 f"`-n_sample` != 0: Training will not be started. Stop after saving {opt.n_sample} samples/corpus."
+#                 f"`-n_sample` != 0: Training will not be started. Stop after saving {opts.n_sample} samples/corpus."
 #             )
-#             save_transformed_sample(opt, transforms, n_sample=opt.n_sample)
+#             save_transformed_sample(opts, transforms, n_sample=opts.n_sample)
 #             logger.info("Sample saved, please check it before restart training.")
 #             sys.exit()
 #     return fields, transforms_cls
 
 # TODO: reimplement save_transformed_sample
 
-def _init_train(opt):
+def _init_train(opts):
     """Common initilization stuff for all training process."""
-    ArgumentParser.validate_prepare_opts(opt)
+    ArgumentParser.validate_prepare_opts(opts)
 
-    if opt.train_from:
+    if opts.train_from:
         # Load checkpoint if we resume from a previous training.
-        checkpoint = load_checkpoint(ckpt_path=opt.train_from)
-        # fields = load_fields(opt.save_data, checkpoint)
-        transforms_cls = get_transforms_cls(opt._all_transform)
+        checkpoint = load_checkpoint(ckpt_path=opts.train_from)
+        # fields = load_fields(opts.save_data, checkpoint)
+        transforms_cls = get_transforms_cls(opts._all_transform)
         if (
-            hasattr(checkpoint["opt"], '_all_transform')
-            and len(opt._all_transform.symmetric_difference(checkpoint["opt"]._all_transform)) != 0
+            hasattr(checkpoint["opts"], '_all_transform')
+            and len(opts._all_transform.symmetric_difference(checkpoint["opts"]._all_transform)) != 0
         ):
             _msg = "configured transforms is different from checkpoint:"
-            new_transf = opt._all_transform.difference(checkpoint["opt"]._all_transform)
-            old_transf = checkpoint["opt"]._all_transform.difference(opt._all_transform)
+            new_transf = opts._all_transform.difference(checkpoint["opts"]._all_transform)
+            old_transf = checkpoint["opts"]._all_transform.difference(opts._all_transform)
             if len(new_transf) != 0:
                 _msg += f" +{new_transf}"
             if len(old_transf) != 0:
                 _msg += f" -{old_transf}."
             logger.warning(_msg)
-        if opt.update_vocab:
+        if opts.update_vocab:
             logger.info("Updating checkpoint vocabulary with new vocabulary")
-            # fields, transforms_cls = prepare_fields_transforms(opt)
+            # fields, transforms_cls = prepare_fields_transforms(opts)
     else:
         checkpoint = None
-        # fields, transforms_cls = prepare_fields_transforms(opt)
+        # fields, transforms_cls = prepare_fields_transforms(opts)
 
     # Report src and tgt vocab sizes
     # for side in ['src', 'tgt']:
@@ -100,24 +100,24 @@ def _init_train(opt):
     return checkpoint, None, transforms_cls
 
 
-# def init_train_prepare_fields_transforms(opt, vocab_path, side):
+# def init_train_prepare_fields_transforms(opts, vocab_path, side):
 #     """Prepare or dump fields & transforms before training."""
 #
-#     fields = None  # build_dynamic_fields_langspec(opt, vocab_path, side)
-#     transforms_cls = get_transforms_cls(opt._all_transform)
-#     # TODO: maybe prepare pretrained embeddings, if any, with `prepare_pretrained_embeddings(opt, fields)`
+#     fields = None  # build_dynamic_fields_langspec(opts, vocab_path, side)
+#     transforms_cls = get_transforms_cls(opts._all_transform)
+#     # TODO: maybe prepare pretrained embeddings, if any, with `prepare_pretrained_embeddings(opts, fields)`
 #
-#     # if opt.dump_fields:
-#     #     save_fields(fields, opt.save_data, overwrite=opt.overwrite)
-#     if opt.dump_transforms or opt.n_sample != 0:
-#         transforms = make_transforms(opt, transforms_cls, fields)
-#         if opt.dump_transforms:
-#             save_transforms(transforms, opt.save_data, overwrite=opt.overwrite)
-#         if opt.n_sample != 0:
+#     # if opts.dump_fields:
+#     #     save_fields(fields, opts.save_data, overwrite=opts.overwrite)
+#     if opts.dump_transforms or opts.n_sample != 0:
+#         transforms = make_transforms(opts, transforms_cls, fields)
+#         if opts.dump_transforms:
+#             save_transforms(transforms, opts.save_data, overwrite=opts.overwrite)
+#         if opts.n_sample != 0:
 #             logger.warning(
-#                 f"`-n_sample` != 0: Training will not be started. Stop after saving {opt.n_sample} samples/corpus."
+#                 f"`-n_sample` != 0: Training will not be started. Stop after saving {opts.n_sample} samples/corpus."
 #             )
-#             save_transformed_sample(opt, transforms, n_sample=opt.n_sample)
+#             save_transformed_sample(opts, transforms, n_sample=opts.n_sample)
 #             logger.info("Sample saved, please check it before restart training.")
 #             sys.exit()
 #
@@ -127,7 +127,7 @@ def _init_train(opt):
 #     return fields
 
 
-def validate_slurm_node_opts(current_env, world_context, opt):
+def validate_slurm_node_opts(current_env, world_context, opts):
     """If you are using slurm, confirm that opts match slurm environment variables"""
     slurm_n_nodes = int(current_env['SLURM_NNODES'])
     if slurm_n_nodes != world_context.n_nodes:
@@ -136,35 +136,35 @@ def validate_slurm_node_opts(current_env, world_context, opt):
             f'but set n_nodes to {world_context.n_nodes} in the conf'
         )
     slurm_node_id = int(current_env['SLURM_NODEID'])
-    if slurm_node_id != opt.node_rank:
+    if slurm_node_id != opts.node_rank:
         raise ValueError(
             f'Looks like you are running on slurm node {slurm_node_id}, '
-            f'but set node_rank to {opt.node_rank} on the command line'
+            f'but set node_rank to {opts.node_rank} on the command line'
         )
 
 
-def train(opt):
-    init_logger(opt.log_file)
-    ArgumentParser.validate_train_opts(opt)
-    ArgumentParser.update_model_opts(opt)
-    ArgumentParser.validate_model_opts(opt)
-    ArgumentParser.validate_prepare_opts(opt)
-    set_random_seed(opt.seed, False)
+def train(opts):
+    init_logger(opts.log_file)
+    ArgumentParser.validate_train_opts(opts)
+    ArgumentParser.update_model_opts(opts)
+    ArgumentParser.validate_model_opts(opts)
+    ArgumentParser.validate_prepare_opts(opts)
+    set_random_seed(opts.seed, False)
 
     # set PyTorch distributed related environment variables
     current_env = os.environ
-    current_env["WORLD_SIZE"] = str(opt.world_size)
-    world_context = WorldContext.from_opt(opt)
+    current_env["WORLD_SIZE"] = str(opts.world_size)
+    world_context = WorldContext.from_opts(opts)
     if 'SLURM_NNODES' in current_env:
-        validate_slurm_node_opts(current_env, world_context, opt)
+        validate_slurm_node_opts(current_env, world_context, opts)
     logger.info(f'Training on {world_context}')
 
-    opt.data_task = ModelTask.SEQ2SEQ
+    opts.data_task = ModelTask.SEQ2SEQ
 
-    transforms_cls = get_transforms_cls(opt._all_transform)
+    transforms_cls = get_transforms_cls(opts._all_transform)
     if transforms_cls:
         logger.info(f'All transforms: {transforms_cls}')
-        src_specials, tgt_specials = zip(*(cls.get_specials(opt) for cls in transforms_cls.values()))
+        src_specials, tgt_specials = zip(*(cls.get_specials(opts) for cls in transforms_cls.values()))
         all_specials = set(DEFAULT_SPECIALS)
         for special_group in src_specials + tgt_specials:
             all_specials = all_specials | special_group
@@ -177,12 +177,12 @@ def train(opt):
 
     vocabs_dict = OrderedDict()
     # For creating fields, we use a task_queue_manager that doesn't filter by node and gpu
-    global_task_queue_manager = TaskQueueManager.from_opt(opt, world_context)
+    global_task_queue_manager = TaskQueueManager.from_opts(opts, world_context)
 
-    vocab_size = {'src': opt.src_vocab_size or None, 'tgt': opt.tgt_vocab_size or None}
+    vocab_size = {'src': opts.src_vocab_size or None, 'tgt': opts.tgt_vocab_size or None}
     for side in ('src', 'tgt'):
         for lang in global_task_queue_manager.get_langs(side):
-            vocab_path = opt.__getattribute__(f'{side}_vocab')[lang]
+            vocab_path = opts.__getattribute__(f'{side}_vocab')[lang]
             # FIXME: for now, all specials are passed to all vocabs, this could be finer-grained
             vocabs_dict[(side, lang)] = get_vocab(vocab_path, lang, vocab_size[side], specials=all_specials)
     # for key, val in fields_dict:
@@ -193,14 +193,14 @@ def train(opt):
     logger.debug(f"[{os.getpid()}] Initializing process group with: {current_env}")
 
     if world_context.context == DeviceContextEnum.MULTI_GPU:
-        current_env["MASTER_ADDR"] = opt.master_ip
-        current_env["MASTER_PORT"] = str(opt.master_port)
-        node_rank = opt.node_rank
+        current_env["MASTER_ADDR"] = opts.master_ip
+        current_env["MASTER_PORT"] = str(opts.master_port)
+        node_rank = opts.node_rank
 
         queues = []
         semaphores = []
         mp = torch.multiprocessing.get_context('spawn')
-        logger.info("world_size = {}, queue_size = {}".format(opt.world_size, opt.queue_size))
+        logger.info("world_size = {}, queue_size = {}".format(opts.world_size, opts.queue_size))
         # Create a thread to listen for errors in the child processes.
         error_queue = mp.SimpleQueue()
         error_handler = ErrorHandler(error_queue)
@@ -217,21 +217,21 @@ def train(opt):
             task_queue_manager = global_task_queue_manager.global_to_local(
                 node_rank=node_rank,
                 local_rank=local_rank,
-                opt=opt
+                opts=opts
             )
 
             # store rank in env (FIXME: is this obsolete?)
             current_env["RANK"] = str(device_context.global_rank)
             current_env["LOCAL_RANK"] = str(device_context.local_rank)
 
-            q = mp.Queue(opt.queue_size)
-            semaphore = mp.Semaphore(opt.queue_size)
+            q = mp.Queue(opts.queue_size)
+            semaphore = mp.Semaphore(opts.queue_size)
             queues.append(q)
             semaphores.append(semaphore)
             procs.append(
                 mp.Process(
                     target=consumer,
-                    args=(train_process, opt, device_context, error_queue, q, semaphore, task_queue_manager),
+                    args=(train_process, opts, device_context, error_queue, q, semaphore, task_queue_manager),
                     daemon=True,
                 )
             )
@@ -244,12 +244,12 @@ def train(opt):
                 task_queue_manager=task_queue_manager,
                 transforms_cls=transforms_cls,
                 vocabs_dict=vocabs_dict,
-                opts=opt,
+                opts=opts,
                 is_train=True,
             )
 
             producer = mp.Process(
-                target=batch_producer, args=(train_iter, q, semaphore, opt, local_rank), daemon=True
+                target=batch_producer, args=(train_iter, q, semaphore, opts, local_rank), daemon=True
             )
             producers.append(producer)
             producers[local_rank].start()
@@ -272,9 +272,9 @@ def train(opt):
         task_queue_manager = global_task_queue_manager.global_to_local(
             node_rank=0,
             local_rank=0,
-            opt=opt
+            opts=opts
         )
-        train_process(opt, device_context=device_context, task_queue_manager=task_queue_manager)
+        train_process(opts, device_context=device_context, task_queue_manager=task_queue_manager)
 
 
 def _get_parser():
@@ -286,8 +286,8 @@ def _get_parser():
 def main():
     parser = _get_parser()
 
-    opt, unknown = parser.parse_known_args()
-    train(opt)
+    opts, unknown = parser.parse_known_args()
+    train(opts)
 
 
 if __name__ == "__main__":