Skip to content

Commit

Permalink
Disable obsolete transation config generation
Browse files Browse the repository at this point in the history
There is currently no way to automatically generate zero-shot configs.
Currently it must be done by hand: copypasting and editing a supervised
task definition.
  • Loading branch information
Waino committed Nov 6, 2023
1 parent 02dbf93 commit 10e85a3
Showing 1 changed file with 14 additions and 93 deletions.
107 changes: 14 additions & 93 deletions tools/config_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,6 @@ def add_adapter_config_args(parser):


def add_translation_configs_args(parser):
parser.add_argument('--translation_config_dir', type=str)
parser.add_argument('--zero_shot', action='store_true')


Expand Down Expand Up @@ -599,100 +598,22 @@ def translation_configs(opts):
start = time.time()

cc_opts = opts.in_config[0]['config_config']
translation_config_dir = (
opts.translation_config_dir if opts.translation_config_dir
else cc_opts.get('translation_config_dir', 'config/translation')
)
zero_shot = opts.zero_shot if opts.zero_shot else cc_opts.get('zero_shot', False)
if not zero_shot:
return

src_subword_model = opts.in_config[0].get('src_subword_model', None)
tgt_subword_model = opts.in_config[0].get('tgt_subword_model', None)
use_src_lang_token = cc_opts.get('use_src_lang_token', False)

os.makedirs(translation_config_dir, exist_ok=True)
encoder_stacks = defaultdict(dict)
decoder_stacks = defaultdict(dict)
transforms_by_lang = defaultdict(dict)
supervised_pairs = set()
for task_opts in opts.in_config[0]['tasks'].values():
src_lang, tgt_lang = task_opts['src_tgt'].split('-')
if src_lang == tgt_lang:
continue
# src / encoder
src_stack = [{'id': group} for group in task_opts['enc_sharing_group']]
if 'adapters' in task_opts:
adapters_by_stack = _adapters_to_stacks(task_opts['adapters']['encoder'], opts, 'enc')
assert len(src_stack) == len(adapters_by_stack)
for stack, adapters in zip(src_stack, adapters_by_stack):
stack['adapters'] = adapters
key = str(src_stack) # An ugly way to freeze the mutable structure
encoder_stacks[src_lang][key] = src_stack
# tgt / decoder
tgt_stack = [{'id': group} for group in task_opts['dec_sharing_group']]
if 'adapters' in task_opts:
adapters_by_stack = _adapters_to_stacks(task_opts['adapters']['decoder'], opts, 'dec')
assert len(tgt_stack) == len(adapters_by_stack)
for stack, adapters in zip(tgt_stack, adapters_by_stack):
stack['adapters'] = adapters
key = str(tgt_stack) # An ugly way to freeze the mutable structure
decoder_stacks[tgt_lang][key] = tgt_stack
# Transforms and subword models also need to be respecified during translation
if 'transforms' not in task_opts:
transforms = None
else:
transforms = [
transform for transform in task_opts['transforms']
if not transform == 'filtertoolong'
]
transforms_by_lang[src_lang] = transforms
# Write config for the supervised directions
_write_translation_config(
src_lang,
tgt_lang,
src_stack,
tgt_stack,
transforms,
src_subword_model,
tgt_subword_model,
'supervised',
translation_config_dir,
use_src_lang_token,
)
supervised_pairs.add((src_lang, tgt_lang))
if zero_shot:
src_langs = encoder_stacks.keys()
tgt_langs = decoder_stacks.keys()
# verify that there is an unique stack for each language
ambiguous_src = [src_lang for src_lang in encoder_stacks if len(encoder_stacks[src_lang]) > 1]
ambiguous_tgt = [tgt_lang for tgt_lang in decoder_stacks if len(decoder_stacks[tgt_lang]) > 1]
if len(ambiguous_src) > 0 or len(ambiguous_tgt) > 0:
raise Exception(
'Zero-shot translation configs can only be generated if each source (target) language '
'has an unambigous encoder (decoder) stack.\n'
'The following languages have more than one encoder/decoder stack:\n'
f'Source: {ambiguous_src}\nTarget: {ambiguous_tgt}'
)
for src_lang in src_langs:
for tgt_lang in tgt_langs:
if src_lang == tgt_lang:
continue
if (src_lang, tgt_lang) in supervised_pairs:
continue
src_stack = list(encoder_stacks[src_lang].values())[0]
tgt_stack = list(decoder_stacks[tgt_lang].values())[0]
transforms = transforms_by_lang[src_lang]
_write_translation_config(
src_lang,
tgt_lang,
src_stack,
tgt_stack,
transforms,
src_subword_model,
tgt_subword_model,
'zeroshot',
translation_config_dir,
use_src_lang_token,
)
# src_subword_model = opts.in_config[0].get('src_subword_model', None)
# tgt_subword_model = opts.in_config[0].get('tgt_subword_model', None)
# use_src_lang_token = cc_opts.get('use_src_lang_token', False)

# TODO: create zero-shot tasks using the same template as for the supervised tasks, except that:
# - no training set or validation set will be defined
# - no weighting/curriculum
# - no GPU allocation.
# However, these 3 are needed: sharing_groups, set_transforms, adapter_config.
# Because it would be nice to be able to add zero-shot tasks as a final extra step
# without completely regenerating the entire training config,
# these 3 should be modified to be rerunnable for a subset of tasks.

duration = time.time() - start
logger.info(f'step took {duration} s')
Expand Down

0 comments on commit 10e85a3

Please sign in to comment.