Skip to content

Commit

Permalink
Fixes to config-config: Renamed opts and hpo
Browse files Browse the repository at this point in the history
Some opts were recently renamed, requiring changes to config_config.
HPO also requires some fixes.
  • Loading branch information
Waino committed Oct 30, 2023
1 parent 03dd5b1 commit 22d03f0
Showing 1 changed file with 75 additions and 45 deletions.
120 changes: 75 additions & 45 deletions tools/config_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,14 @@ def add_complete_language_pairs_args(parser):
'Can use the variables {src_lang}, {tgt_lang}, and {sorted_pair}.'
'If unset, autoencoder pairs will use src_path and tgt_path.'
)
parser.add_argument(
'--valid_src_path', type=str,
help='path template to source dev set. Can use variables {src_lang}, {tgt_lang}, and {sorted_pair}.'
)
parser.add_argument(
'--valid_tgt_path', type=str,
help='path template to target dev set. Can use variables {src_lang}, {tgt_lang}, and {sorted_pair}.'
)
parser.add_argument(
'--autoencoder',
action='store_true',
Expand Down Expand Up @@ -219,11 +227,13 @@ def get_opts():
add_complete_language_pairs_args(parser_config_all)
add_adapter_config_args(parser_config_all)
add_translation_configs_args(parser_config_all)
parser_extra_cpu = subparsers.add_parser('extra_cpu')
add_configs_args(parser_extra_cpu)
return parser.parse_args()


def _split_large_language_pairs(opts, corpora_weights, split_treshold):
corpora_out = deepcopy(opts.in_config[0]['data'])
corpora_out = deepcopy(opts.in_config[0]['tasks'])
corpora_weights_out = dict()
for cname, weight in corpora_weights.items():
if weight > split_treshold:
Expand All @@ -240,7 +250,7 @@ def _split_large_language_pairs(opts, corpora_weights, split_treshold):
del corpora_out[cname]
else:
corpora_weights_out[cname] = weight
opts.in_config[0]['data'] = corpora_out
opts.in_config[0]['tasks'] = corpora_out
return corpora_weights_out


Expand All @@ -261,7 +271,7 @@ def corpora_schedule(opts):
for path, len in corpora_lens_cache.items():
logger.info(f'{path}:\t{len}')
corpora_lens = {}
for cname, corpus in opts.in_config[0]['data'].items():
for cname, corpus in opts.in_config[0]['tasks'].items():
if corpus['path_src'] in corpora_lens_cache:
length = corpora_lens_cache[corpus['path_src']]
corpora_lens[cname] = length
Expand All @@ -285,7 +295,7 @@ def corpora_schedule(opts):
corpora_weights = _split_large_language_pairs(opts, corpora_weights, split_treshold)

min_introduce_at_training_step = opts.in_config[0].get('train_steps', 100_000)
for cname, corpus in opts.in_config[0]['data'].items():
for cname, corpus in opts.in_config[0]['tasks'].items():
src_lang, tgt_lang = corpus['src_tgt'].split('-')
weight = corpora_weights[cname]
if use_weight and use_introduce_at_training_step:
Expand All @@ -307,7 +317,7 @@ def corpora_schedule(opts):
min_introduce_at_training_step = min(min_introduce_at_training_step, introduce_at_training_step)
if use_introduce_at_training_step and min_introduce_at_training_step > 0:
# With a single very large task that gets split, it is possible that no task can start
for cname, corpus in opts.in_config[0]['data'].items():
for cname, corpus in opts.in_config[0]['tasks'].items():
if 'introduce_at_training_step' in corpus:
corpus['introduce_at_training_step'] -= min_introduce_at_training_step
duration = time.time() - start
Expand Down Expand Up @@ -337,7 +347,7 @@ def cluster_languages(opts):

sim_langs = set(distance_matrix['header'])
corpus_langs = set()
for cname, corpus in opts.in_config[0]['data'].items():
for cname, corpus in opts.in_config[0]['tasks'].items():
assert all([(lng in sim_langs) for lng in corpus['src_tgt'].split('-')]), \
f'corpus {cname}: one language (either {" or ".join(corpus["src_tgt"].split("-"))} ' \
f'was not found in the distance matrix (supports {" ".join(sim_langs)})'
Expand Down Expand Up @@ -391,7 +401,7 @@ def sharing_groups(opts):
raise Exception('Must set --dec_sharing_groups')
assert len(enc_sharing_groups) == len(opts.in_config[0]['enc_layers'])
assert len(dec_sharing_groups) == len(opts.in_config[0]['dec_layers'])
for cname, corpus in opts.in_config[0]['data'].items():
for cname, corpus in opts.in_config[0]['tasks'].items():
src, tgt = corpus['src_tgt'].split('-')
mapping_src = {
'LANGUAGE': src,
Expand Down Expand Up @@ -429,7 +439,7 @@ def set_transforms(opts):
ae_transforms = opts.ae_transforms if opts.ae_transforms else cc_opts.get('ae_transforms', [])
transforms = opts.transforms if opts.transforms else cc_opts.get('transforms', [])

for cname, corpus in opts.in_config[0]['data'].items():
for cname, corpus in opts.in_config[0]['tasks'].items():
src, tgt = corpus['src_tgt'].split('-')
if src == tgt:
corpus['transforms'] = list(ae_transforms)
Expand Down Expand Up @@ -459,9 +469,9 @@ def allocate_devices(opts):
lang_pairs = []
lps_ready_to_start = []
lp_to_key = defaultdict(list)
for key, data_config in opts.in_config[0]['data'].items():
src_lang, tgt_lang = data_config['src_tgt'].split('-')
ready_to_start = data_config.get('introduce_at_training_step', 0) == 0
for key, tasks_config in opts.in_config[0]['tasks'].items():
src_lang, tgt_lang = tasks_config['src_tgt'].split('-')
ready_to_start = tasks_config.get('introduce_at_training_step', 0) == 0

lang_pairs.append((src_lang, tgt_lang))
if ready_to_start:
Expand Down Expand Up @@ -495,7 +505,7 @@ def allocate_devices(opts):
if lp is None:
continue
key = lp_to_key[lp].pop()
opts.in_config[0]['data'][key]['node_gpu'] = f'{gpu_slot.node}:{gpu_slot.gpu}'
opts.in_config[0]['tasks'][key]['node_gpu'] = f'{gpu_slot.node}:{gpu_slot.gpu}'

opts.in_config[0]['n_nodes'] = n_nodes
opts.in_config[0]['world_size'] = n_gpus_tot
Expand All @@ -504,14 +514,14 @@ def allocate_devices(opts):
# Ensure that all devices can start training (will crash otherwise)
train_steps = opts.in_config[0].get('train_steps', 100_000)
min_introduce_at_training_step = defaultdict(lambda: train_steps)
for cname, corpus in opts.in_config[0]['data'].items():
for cname, corpus in opts.in_config[0]['tasks'].items():
if 'introduce_at_training_step' not in corpus:
continue
min_introduce_at_training_step[corpus['node_gpu']] = min(
corpus['introduce_at_training_step'],
min_introduce_at_training_step[corpus['node_gpu']]
)
for cname, corpus in opts.in_config[0]['data'].items():
for cname, corpus in opts.in_config[0]['tasks'].items():
if 'introduce_at_training_step' not in corpus:
continue
adjust = min_introduce_at_training_step[corpus['node_gpu']]
Expand All @@ -535,40 +545,40 @@ def adapter_config(opts):
tgt_groups = list(sorted(set(cc_opts['groups'][tgt] for tgt in tgt_langs)))
encoder_adapters = opts.in_config[0]['adapters'].get('encoder', [])
decoder_adapters = opts.in_config[0]['adapters'].get('decoder', [])
for data_key, data_config in opts.in_config[0]['data'].items():
if 'adapters' not in data_config:
data_config['adapters'] = {'encoder': [], 'decoder': []}
for task_key, task_config in opts.in_config[0]['tasks'].items():
if 'adapters' not in task_config:
task_config['adapters'] = {'encoder': [], 'decoder': []}
# TODO: refactor and add support for {SRC|TGT}_{LANGUAGE|GROUP} also to adapters
for adapter_name, adapter_config in sorted(encoder_adapters.items()):
if adapter_config['ids'] == 'LANGUAGE':
adapter_config['ids'] = list(src_langs)
for data_key, data_config in opts.in_config[0]['data'].items():
data_src, data_tgt = data_config['src_tgt'].split('-')
data_config['adapters']['encoder'].append([adapter_name, data_src])
for task_key, task_config in opts.in_config[0]['tasks'].items():
task_src, task_tgt = task_config['src_tgt'].split('-')
task_config['adapters']['encoder'].append([adapter_name, task_src])
elif adapter_config['ids'] == 'GROUP':
adapter_config['ids'] = list(src_groups)
for data_key, data_config in opts.in_config[0]['data'].items():
data_src, data_tgt = data_config['src_tgt'].split('-')
data_config['adapters']['encoder'].append([adapter_name, cc_opts['groups'][data_src]])
for task_key, task_config in opts.in_config[0]['tasks'].items():
task_src, task_tgt = task_config['src_tgt'].split('-')
task_config['adapters']['encoder'].append([adapter_name, cc_opts['groups'][task_src]])
elif adapter_config['ids'] == 'FULL':
adapter_config['ids'] = ['full']
for data_key, data_config in opts.in_config[0]['data'].items():
data_config['adapters']['encoder'].append([adapter_name, 'full'])
for task_key, task_config in opts.in_config[0]['tasks'].items():
task_config['adapters']['encoder'].append([adapter_name, 'full'])
for adapter_name, adapter_config in sorted(decoder_adapters.items()):
if adapter_config['ids'] == 'LANGUAGE':
adapter_config['ids'] = list(tgt_langs)
for data_key, data_config in opts.in_config[0]['data'].items():
data_src, data_tgt = data_config['src_tgt'].split('-')
data_config['adapters']['decoder'].append([adapter_name, data_tgt])
for task_key, task_config in opts.in_config[0]['tasks'].items():
task_src, task_tgt = task_config['src_tgt'].split('-')
task_config['adapters']['decoder'].append([adapter_name, task_tgt])
elif adapter_config['ids'] == 'GROUP':
adapter_config['ids'] = list(tgt_groups)
for data_key, data_config in opts.in_config[0]['data'].items():
data_src, data_tgt = data_config['src_tgt'].split('-')
data_config['adapters']['decoder'].append([adapter_name, cc_opts['groups'][data_tgt]])
for task_key, task_config in opts.in_config[0]['tasks'].items():
task_src, task_tgt = task_config['src_tgt'].split('-')
task_config['adapters']['decoder'].append([adapter_name, cc_opts['groups'][task_tgt]])
elif adapter_config['ids'] == 'FULL':
adapter_config['ids'] = ['full']
for data_key, data_config in opts.in_config[0]['data'].items():
data_config['adapters']['decoder'].append([adapter_name, 'full'])
for task_key, task_config in opts.in_config[0]['tasks'].items():
task_config['adapters']['decoder'].append([adapter_name, 'full'])
opts.in_config[0]['adapters']['encoder'] = encoder_adapters
opts.in_config[0]['adapters']['decoder'] = decoder_adapters

Expand Down Expand Up @@ -604,7 +614,7 @@ def translation_configs(opts):
decoder_stacks = defaultdict(dict)
transforms_by_lang = defaultdict(dict)
supervised_pairs = set()
for task_opts in opts.in_config[0]['data'].values():
for task_opts in opts.in_config[0]['tasks'].values():
src_lang, tgt_lang = task_opts['src_tgt'].split('-')
if src_lang == tgt_lang:
continue
Expand Down Expand Up @@ -749,6 +759,8 @@ def complete_language_pairs(opts):
else:
ae_src_path_template = src_path_template
ae_tgt_path_template = tgt_path_template
valid_src_path_template = opts.valid_src_path if opts.valid_src_path else cc_opts['valid_src_path']
valid_tgt_path_template = opts.valid_tgt_path if opts.valid_tgt_path else cc_opts['valid_tgt_path']

src_langs, tgt_langs = _get_langs(opts)
for src_lang in src_langs:
Expand Down Expand Up @@ -784,15 +796,19 @@ def complete_language_pairs(opts):
continue
src_path = ae_src_path_template.format(**template_variables)
tgt_path = ae_tgt_path_template.format(**template_variables)
valid_src_path = None
valid_tgt_path = None
else:
# translation task
src_path = src_path_template.format(**template_variables)
tgt_path = tgt_path_template.format(**template_variables)
valid_src_path = valid_src_path_template.format(**template_variables)
valid_tgt_path = valid_tgt_path_template.format(**template_variables)
if os.path.exists(src_path) and os.path.exists(tgt_path):
_add_language_pair(opts, src_lang, tgt_lang, src_path, tgt_path)
_add_language_pair(opts, src_lang, tgt_lang, src_path, tgt_path, valid_src_path, valid_tgt_path)
else:
logger.warning(f'Paths do NOT exist, omitting language pair: {src_path} {tgt_path}')
if len(opts.in_config[0].get('data', [])) == 0:
if len(opts.in_config[0].get('tasks', [])) == 0:
raise Exception('No language pairs were added. Check your path templates.')
# Allow using language variables for vocabulary definitions
for src_lang in src_langs:
Expand All @@ -804,16 +820,19 @@ def complete_language_pairs(opts):
logger.info(f'step took {duration} s')


def _add_language_pair(opts, src_lang, tgt_lang, src_path, tgt_path):
if 'data' not in opts.in_config[0]:
opts.in_config[0]['data'] = dict()
data_section = opts.in_config[0]['data']
def _add_language_pair(opts, src_lang, tgt_lang, src_path, tgt_path, valid_src_path, valid_tgt_path):
if 'tasks' not in opts.in_config[0]:
opts.in_config[0]['tasks'] = dict()
tasks_section = opts.in_config[0]['tasks']
key = f'train_{src_lang}-{tgt_lang}'
if key not in data_section:
data_section[key] = dict()
data_section[key]['src_tgt'] = f'{src_lang}-{tgt_lang}'
data_section[key]['path_src'] = src_path
data_section[key]['path_tgt'] = tgt_path
if key not in tasks_section:
tasks_section[key] = dict()
tasks_section[key]['src_tgt'] = f'{src_lang}-{tgt_lang}'
tasks_section[key]['path_src'] = src_path
tasks_section[key]['path_tgt'] = tgt_path
if valid_src_path is not None and os.path.exists(valid_src_path):
tasks_section[key]['path_valid_src'] = valid_src_path
tasks_section[key]['path_valid_tgt'] = valid_tgt_path


def remove_temporary_keys(opts):
Expand All @@ -839,6 +858,16 @@ def config_all(opts):
logger.info(f'total took {duration} s')


def extra_cpu(opts):
    """Strip all GPU / multi-node settings so the config runs on a single CPU.

    Extra step: not included in config_all. Removes 'gpu_ranks' and
    'world_size' from the top level, forces a single node, and drops each
    task's 'node_gpu' device assignment. Raises KeyError if any of these
    keys is absent (i.e. device allocation was never run).
    """
    config = opts.in_config[0]
    for top_level_key in ('gpu_ranks', 'world_size'):
        del config[top_level_key]
    config['n_nodes'] = 1
    for task_opts in config['tasks'].values():
        del task_opts['node_gpu']


if __name__ == '__main__':
init_logging()
opts = get_opts()
Expand All @@ -857,6 +886,7 @@ def config_all(opts):
translation_configs,
remove_temporary_keys,
config_all,
extra_cpu,
)
}[opts.command]
main(opts)
Expand Down

0 comments on commit 22d03f0

Please sign in to comment.