WIP: Fixes to config-config: Renamed opts and hpo #38

Closed
wants to merge 1 commit into from
120 changes: 75 additions & 45 deletions tools/config_config.py
@@ -102,6 +102,14 @@ def add_complete_language_pairs_args(parser):
'Can use the variables {src_lang}, {tgt_lang}, and {sorted_pair}.'
'If unset, autoencoder pairs will use src_path and tgt_path.'
)
+    parser.add_argument(
+        '--valid_src_path', type=str,
+        help='path template to source dev set. Can use variables {src_lang}, {tgt_lang}, and {sorted_pair}.'
+    )
+    parser.add_argument(
+        '--valid_tgt_path', type=str,
+        help='path template to target dev set. Can use variables {src_lang}, {tgt_lang}, and {sorted_pair}.'
+    )
parser.add_argument(
'--autoencoder',
action='store_true',
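The two new options take the same template syntax as the existing train-path options. For context, a minimal sketch of how such a template expands; the directory layout here is hypothetical, not taken from this PR:

# Hypothetical dev-set layout; str.format fills the variables the help text names.
template = 'data/dev/{sorted_pair}/dev.{src_lang}-{tgt_lang}.{src_lang}'
src_lang, tgt_lang = 'en', 'fi'
sorted_pair = '-'.join(sorted((src_lang, tgt_lang)))
print(template.format(src_lang=src_lang, tgt_lang=tgt_lang, sorted_pair=sorted_pair))
# -> data/dev/en-fi/dev.en-fi.en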
@@ -219,11 +227,13 @@ def get_opts():
add_complete_language_pairs_args(parser_config_all)
add_adapter_config_args(parser_config_all)
add_translation_configs_args(parser_config_all)
+    parser_extra_cpu = subparsers.add_parser('extra_cpu')
+    add_configs_args(parser_extra_cpu)
return parser.parse_args()


def _split_large_language_pairs(opts, corpora_weights, split_treshold):
-    corpora_out = deepcopy(opts.in_config[0]['data'])
+    corpora_out = deepcopy(opts.in_config[0]['tasks'])
corpora_weights_out = dict()
for cname, weight in corpora_weights.items():
if weight > split_treshold:
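Most of the remaining hunks are a mechanical rename of the top-level config section from 'data' to 'tasks'. A before/after sketch of the affected structure; task contents are abbreviated and illustrative, and the migration line is not part of the PR:

old_config = {'data': {'train_en-fi': {'src_tgt': 'en-fi'}}}
# After this PR the same section is keyed 'tasks'; migrating an old config
# is a one-line pop-and-reinsert.
new_config = dict(old_config)
new_config['tasks'] = new_config.pop('data')
assert 'data' not in new_config and 'train_en-fi' in new_config['tasks']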
@@ -240,7 +250,7 @@ def _split_large_language_pairs(opts, corpora_weights, split_treshold):
del corpora_out[cname]
else:
corpora_weights_out[cname] = weight
-    opts.in_config[0]['data'] = corpora_out
+    opts.in_config[0]['tasks'] = corpora_out
return corpora_weights_out


@@ -261,7 +271,7 @@ def corpora_schedule(opts):
for path, len in corpora_lens_cache.items():
logger.info(f'{path}:\t{len}')
corpora_lens = {}
-    for cname, corpus in opts.in_config[0]['data'].items():
+    for cname, corpus in opts.in_config[0]['tasks'].items():
if corpus['path_src'] in corpora_lens_cache:
length = corpora_lens_cache[corpus['path_src']]
corpora_lens[cname] = length
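For readers skimming the hunk: corpus lengths are computed once, cached by source path, and looked up per task when building the schedule. A reduced sketch of that lookup; paths and lengths are invented:

corpora_lens_cache = {'data/train.en-fi.en': 1_000_000}
tasks = {'train_en-fi': {'path_src': 'data/train.en-fi.en'}}
corpora_lens = {cname: corpora_lens_cache[corpus['path_src']]
                for cname, corpus in tasks.items()
                if corpus['path_src'] in corpora_lens_cache}
print(corpora_lens)  # {'train_en-fi': 1000000}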
@@ -285,7 +295,7 @@ def corpora_schedule(opts):
corpora_weights = _split_large_language_pairs(opts, corpora_weights, split_treshold)

min_introduce_at_training_step = opts.in_config[0].get('train_steps', 100_000)
-    for cname, corpus in opts.in_config[0]['data'].items():
+    for cname, corpus in opts.in_config[0]['tasks'].items():
src_lang, tgt_lang = corpus['src_tgt'].split('-')
weight = corpora_weights[cname]
if use_weight and use_introduce_at_training_step:
@@ -307,7 +317,7 @@ def corpora_schedule(opts):
min_introduce_at_training_step = min(min_introduce_at_training_step, introduce_at_training_step)
if use_introduce_at_training_step and min_introduce_at_training_step > 0:
# With a single very large task that gets split, it is possible that no task can start
-        for cname, corpus in opts.in_config[0]['data'].items():
+        for cname, corpus in opts.in_config[0]['tasks'].items():
if 'introduce_at_training_step' in corpus:
corpus['introduce_at_training_step'] -= min_introduce_at_training_step
duration = time.time() - start
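The adjustment above guards against the corner case named in the comment: if even the earliest task would start late (e.g. after a single very large task is split), every start step is shifted down so that something starts at step 0. A standalone sketch of the normalization; task names and steps are invented:

tasks = {'train_en-fi': {'introduce_at_training_step': 2000},
         'train_fi-en': {'introduce_at_training_step': 5000}}
min_step = min(t['introduce_at_training_step'] for t in tasks.values())
for task in tasks.values():
    task['introduce_at_training_step'] -= min_step
# At least one task now starts immediately.
assert min(t['introduce_at_training_step'] for t in tasks.values()) == 0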
@@ -337,7 +347,7 @@ def cluster_languages(opts):

sim_langs = set(distance_matrix['header'])
corpus_langs = set()
-    for cname, corpus in opts.in_config[0]['data'].items():
+    for cname, corpus in opts.in_config[0]['tasks'].items():
assert all([(lng in sim_langs) for lng in corpus['src_tgt'].split('-')]), \
f'corpus {cname}: one language (either {" or ".join(corpus["src_tgt"].split("-"))} ' \
f'was not found in the distance matrix (supports {" ".join(sim_langs)})'
@@ -391,7 +401,7 @@ def sharing_groups(opts):
raise Exception('Must set --dec_sharing_groups')
assert len(enc_sharing_groups) == len(opts.in_config[0]['enc_layers'])
assert len(dec_sharing_groups) == len(opts.in_config[0]['dec_layers'])
-    for cname, corpus in opts.in_config[0]['data'].items():
+    for cname, corpus in opts.in_config[0]['tasks'].items():
src, tgt = corpus['src_tgt'].split('-')
mapping_src = {
'LANGUAGE': src,
@@ -429,7 +439,7 @@ def set_transforms(opts):
ae_transforms = opts.ae_transforms if opts.ae_transforms else cc_opts.get('ae_transforms', [])
transforms = opts.transforms if opts.transforms else cc_opts.get('transforms', [])

-    for cname, corpus in opts.in_config[0]['data'].items():
+    for cname, corpus in opts.in_config[0]['tasks'].items():
src, tgt = corpus['src_tgt'].split('-')
if src == tgt:
corpus['transforms'] = list(ae_transforms)
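Autoencoder tasks (src == tgt) receive the dedicated ae_transforms list; presumably translation tasks fall through to the regular transforms list, though that else branch is outside this hunk. A compact sketch under that assumption; transform names are placeholders:

ae_transforms = ['denoising']       # placeholder
transforms = ['sentencepiece']      # placeholder
tasks = {'train_en-en': {'src_tgt': 'en-en'},
         'train_en-fi': {'src_tgt': 'en-fi'}}
for corpus in tasks.values():
    src, tgt = corpus['src_tgt'].split('-')
    corpus['transforms'] = list(ae_transforms) if src == tgt else list(transforms)
print(tasks['train_en-en']['transforms'])   # ['denoising']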
@@ -459,9 +469,9 @@ def allocate_devices(opts):
lang_pairs = []
lps_ready_to_start = []
lp_to_key = defaultdict(list)
-    for key, data_config in opts.in_config[0]['data'].items():
-        src_lang, tgt_lang = data_config['src_tgt'].split('-')
-        ready_to_start = data_config.get('introduce_at_training_step', 0) == 0
+    for key, tasks_config in opts.in_config[0]['tasks'].items():
+        src_lang, tgt_lang = tasks_config['src_tgt'].split('-')
+        ready_to_start = tasks_config.get('introduce_at_training_step', 0) == 0

lang_pairs.append((src_lang, tgt_lang))
if ready_to_start:
@@ -495,7 +505,7 @@ def allocate_devices(opts):
if lp is None:
continue
key = lp_to_key[lp].pop()
-        opts.in_config[0]['data'][key]['node_gpu'] = f'{gpu_slot.node}:{gpu_slot.gpu}'
+        opts.in_config[0]['tasks'][key]['node_gpu'] = f'{gpu_slot.node}:{gpu_slot.gpu}'

opts.in_config[0]['n_nodes'] = n_nodes
opts.in_config[0]['world_size'] = n_gpus_tot
@@ -504,14 +514,14 @@ def allocate_devices(opts):
# Ensure that all devices can start training (will crash otherwise)
train_steps = opts.in_config[0].get('train_steps', 100_000)
min_introduce_at_training_step = defaultdict(lambda: train_steps)
-    for cname, corpus in opts.in_config[0]['data'].items():
+    for cname, corpus in opts.in_config[0]['tasks'].items():
if 'introduce_at_training_step' not in corpus:
continue
min_introduce_at_training_step[corpus['node_gpu']] = min(
corpus['introduce_at_training_step'],
min_introduce_at_training_step[corpus['node_gpu']]
)
-    for cname, corpus in opts.in_config[0]['data'].items():
+    for cname, corpus in opts.in_config[0]['tasks'].items():
if 'introduce_at_training_step' not in corpus:
continue
adjust = min_introduce_at_training_step[corpus['node_gpu']]
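This second pass applies the same idea per device: within each node_gpu slot, the earliest scheduled task is pulled to step 0 so no GPU waits (and crashes) with nothing to train. A self-contained sketch; device ids and steps are invented:

from collections import defaultdict

train_steps = 100_000
tasks = {'a': {'node_gpu': '0:0', 'introduce_at_training_step': 3000},
         'b': {'node_gpu': '0:1', 'introduce_at_training_step': 7000}}
min_step = defaultdict(lambda: train_steps)
for corpus in tasks.values():
    min_step[corpus['node_gpu']] = min(corpus['introduce_at_training_step'],
                                       min_step[corpus['node_gpu']])
for corpus in tasks.values():
    corpus['introduce_at_training_step'] -= min_step[corpus['node_gpu']]
# With one task per device in this toy example, every task starts at 0.
assert all(c['introduce_at_training_step'] == 0 for c in tasks.values())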
@@ -535,40 +545,40 @@ def adapter_config(opts):
tgt_groups = list(sorted(set(cc_opts['groups'][tgt] for tgt in tgt_langs)))
encoder_adapters = opts.in_config[0]['adapters'].get('encoder', [])
decoder_adapters = opts.in_config[0]['adapters'].get('decoder', [])
-    for data_key, data_config in opts.in_config[0]['data'].items():
-        if 'adapters' not in data_config:
-            data_config['adapters'] = {'encoder': [], 'decoder': []}
+    for task_key, task_config in opts.in_config[0]['tasks'].items():
+        if 'adapters' not in task_config:
+            task_config['adapters'] = {'encoder': [], 'decoder': []}
# TODO: refactor and add support for {SRC|TGT}_{LANGUAGE|GROUP} also to adapters
for adapter_name, adapter_config in sorted(encoder_adapters.items()):
if adapter_config['ids'] == 'LANGUAGE':
adapter_config['ids'] = list(src_langs)
-            for data_key, data_config in opts.in_config[0]['data'].items():
-                data_src, data_tgt = data_config['src_tgt'].split('-')
-                data_config['adapters']['encoder'].append([adapter_name, data_src])
+            for task_key, task_config in opts.in_config[0]['tasks'].items():
+                task_src, task_tgt = task_config['src_tgt'].split('-')
+                task_config['adapters']['encoder'].append([adapter_name, task_src])
elif adapter_config['ids'] == 'GROUP':
adapter_config['ids'] = list(src_groups)
-            for data_key, data_config in opts.in_config[0]['data'].items():
-                data_src, data_tgt = data_config['src_tgt'].split('-')
-                data_config['adapters']['encoder'].append([adapter_name, cc_opts['groups'][data_src]])
+            for task_key, task_config in opts.in_config[0]['tasks'].items():
+                task_src, task_tgt = task_config['src_tgt'].split('-')
+                task_config['adapters']['encoder'].append([adapter_name, cc_opts['groups'][task_src]])
elif adapter_config['ids'] == 'FULL':
adapter_config['ids'] = ['full']
-            for data_key, data_config in opts.in_config[0]['data'].items():
-                data_config['adapters']['encoder'].append([adapter_name, 'full'])
+            for task_key, task_config in opts.in_config[0]['tasks'].items():
+                task_config['adapters']['encoder'].append([adapter_name, 'full'])
for adapter_name, adapter_config in sorted(decoder_adapters.items()):
if adapter_config['ids'] == 'LANGUAGE':
adapter_config['ids'] = list(tgt_langs)
-            for data_key, data_config in opts.in_config[0]['data'].items():
-                data_src, data_tgt = data_config['src_tgt'].split('-')
-                data_config['adapters']['decoder'].append([adapter_name, data_tgt])
+            for task_key, task_config in opts.in_config[0]['tasks'].items():
+                task_src, task_tgt = task_config['src_tgt'].split('-')
+                task_config['adapters']['decoder'].append([adapter_name, task_tgt])
elif adapter_config['ids'] == 'GROUP':
adapter_config['ids'] = list(tgt_groups)
-            for data_key, data_config in opts.in_config[0]['data'].items():
-                data_src, data_tgt = data_config['src_tgt'].split('-')
-                data_config['adapters']['decoder'].append([adapter_name, cc_opts['groups'][data_tgt]])
+            for task_key, task_config in opts.in_config[0]['tasks'].items():
+                task_src, task_tgt = task_config['src_tgt'].split('-')
+                task_config['adapters']['decoder'].append([adapter_name, cc_opts['groups'][task_tgt]])
elif adapter_config['ids'] == 'FULL':
adapter_config['ids'] = ['full']
-            for data_key, data_config in opts.in_config[0]['data'].items():
-                data_config['adapters']['decoder'].append([adapter_name, 'full'])
+            for task_key, task_config in opts.in_config[0]['tasks'].items():
+                task_config['adapters']['decoder'].append([adapter_name, 'full'])
opts.in_config[0]['adapters']['encoder'] = encoder_adapters
opts.in_config[0]['adapters']['decoder'] = decoder_adapters
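The expansion logic above, stripped to one encoder adapter with ids: LANGUAGE — each task attaches the adapter instance for its own source language. A reduced sketch; adapter and task names are invented:

tasks = {'train_en-fi': {'src_tgt': 'en-fi', 'adapters': {'encoder': [], 'decoder': []}},
         'train_fi-en': {'src_tgt': 'fi-en', 'adapters': {'encoder': [], 'decoder': []}}}
encoder_adapters = {'lang_adapter': {'ids': 'LANGUAGE'}}
for adapter_name, adapter_config in sorted(encoder_adapters.items()):
    if adapter_config['ids'] == 'LANGUAGE':
        adapter_config['ids'] = sorted({t['src_tgt'].split('-')[0] for t in tasks.values()})
        for task_config in tasks.values():
            task_src, _ = task_config['src_tgt'].split('-')
            task_config['adapters']['encoder'].append([adapter_name, task_src])
print(tasks['train_en-fi']['adapters']['encoder'])  # [['lang_adapter', 'en']]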

@@ -604,7 +614,7 @@ def translation_configs(opts):
decoder_stacks = defaultdict(dict)
transforms_by_lang = defaultdict(dict)
supervised_pairs = set()
-    for task_opts in opts.in_config[0]['data'].values():
+    for task_opts in opts.in_config[0]['tasks'].values():
src_lang, tgt_lang = task_opts['src_tgt'].split('-')
if src_lang == tgt_lang:
continue
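This loop collects the supervised (non-autoencoder) pairs by skipping tasks whose source and target language coincide; what exactly is recorded per pair sits outside this hunk, so the sketch below keeps only the pair set:

tasks = {'train_en-en': {'src_tgt': 'en-en'},
         'train_en-fi': {'src_tgt': 'en-fi'}}
supervised_pairs = set()
for task_opts in tasks.values():
    src_lang, tgt_lang = task_opts['src_tgt'].split('-')
    if src_lang == tgt_lang:
        continue
    supervised_pairs.add((src_lang, tgt_lang))
print(supervised_pairs)  # {('en', 'fi')}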
@@ -749,6 +759,8 @@ def complete_language_pairs(opts):
else:
ae_src_path_template = src_path_template
ae_tgt_path_template = tgt_path_template
+    valid_src_path_template = opts.valid_src_path if opts.valid_src_path else cc_opts['valid_src_path']
+    valid_tgt_path_template = opts.valid_tgt_path if opts.valid_tgt_path else cc_opts['valid_tgt_path']

src_langs, tgt_langs = _get_langs(opts)
for src_lang in src_langs:
@@ -784,15 +796,19 @@ def complete_language_pairs(opts):
continue
src_path = ae_src_path_template.format(**template_variables)
tgt_path = ae_tgt_path_template.format(**template_variables)
+                valid_src_path = None
+                valid_tgt_path = None
else:
# translation task
src_path = src_path_template.format(**template_variables)
tgt_path = tgt_path_template.format(**template_variables)
+                valid_src_path = valid_src_path_template.format(**template_variables)
+                valid_tgt_path = valid_tgt_path_template.format(**template_variables)
if os.path.exists(src_path) and os.path.exists(tgt_path):
-                _add_language_pair(opts, src_lang, tgt_lang, src_path, tgt_path)
+                _add_language_pair(opts, src_lang, tgt_lang, src_path, tgt_path, valid_src_path, valid_tgt_path)
else:
logger.warning(f'Paths do NOT exist, omitting language pair: {src_path} {tgt_path}')
-    if len(opts.in_config[0].get('data', [])) == 0:
+    if len(opts.in_config[0].get('tasks', [])) == 0:
raise Exception('No language pairs were added. Check your path templates.')
# Allow using language variables for vocabulary definitions
for src_lang in src_langs:
@@ -804,16 +820,19 @@ def complete_language_pairs(opts):
logger.info(f'step took {duration} s')


-def _add_language_pair(opts, src_lang, tgt_lang, src_path, tgt_path):
-    if 'data' not in opts.in_config[0]:
-        opts.in_config[0]['data'] = dict()
-    data_section = opts.in_config[0]['data']
+def _add_language_pair(opts, src_lang, tgt_lang, src_path, tgt_path, valid_src_path, valid_tgt_path):
+    if 'tasks' not in opts.in_config[0]:
+        opts.in_config[0]['tasks'] = dict()
+    tasks_section = opts.in_config[0]['tasks']
key = f'train_{src_lang}-{tgt_lang}'
-    if key not in data_section:
-        data_section[key] = dict()
-    data_section[key]['src_tgt'] = f'{src_lang}-{tgt_lang}'
-    data_section[key]['path_src'] = src_path
-    data_section[key]['path_tgt'] = tgt_path
+    if key not in tasks_section:
+        tasks_section[key] = dict()
+    tasks_section[key]['src_tgt'] = f'{src_lang}-{tgt_lang}'
+    tasks_section[key]['path_src'] = src_path
+    tasks_section[key]['path_tgt'] = tgt_path
+    if valid_src_path is not None and os.path.exists(valid_src_path):
+        tasks_section[key]['path_valid_src'] = valid_src_path
+        tasks_section[key]['path_valid_tgt'] = valid_tgt_path


def remove_temporary_keys(opts):
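With the extended signature, a completed entry in the tasks section now carries the dev-set paths next to the training paths whenever the dev files exist on disk. The resulting shape for one pair; all paths are placeholders:

task_entry = {
    'src_tgt': 'en-fi',
    'path_src': 'data/train.en-fi.en',
    'path_tgt': 'data/train.en-fi.fi',
    # Present only when the valid_src_path file exists:
    'path_valid_src': 'data/dev.en-fi.en',
    'path_valid_tgt': 'data/dev.en-fi.fi',
}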
@@ -839,6 +858,16 @@ def config_all(opts):
logger.info(f'total took {duration} s')


+def extra_cpu(opts):
+    # Extra step: not included in config_all
+    # Modifies config to run on a single CPU
+    del opts.in_config[0]['gpu_ranks']
+    del opts.in_config[0]['world_size']
+    opts.in_config[0]['n_nodes'] = 1
+    for task_opts in opts.in_config[0]['tasks'].values():
+        del task_opts['node_gpu']


if __name__ == '__main__':
init_logging()
opts = get_opts()
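The new extra_cpu step, registered as its own subcommand in get_opts above, strips the device placement that allocate_devices added, leaving a config runnable on a single CPU. Its effect in isolation; config contents are invented:

config = {'gpu_ranks': [0, 1], 'world_size': 2, 'n_nodes': 1,
          'tasks': {'train_en-fi': {'src_tgt': 'en-fi', 'node_gpu': '0:0'}}}
del config['gpu_ranks']
del config['world_size']
config['n_nodes'] = 1
for task_opts in config['tasks'].values():
    del task_opts['node_gpu']
assert 'node_gpu' not in config['tasks']['train_en-fi']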
@@ -857,6 +886,7 @@ def config_all(opts):
translation_configs,
remove_temporary_keys,
config_all,
+            extra_cpu,
)
}[opts.command]
main(opts)