Skip to content

Commit

Permalink
Address PR comments
Browse files: browse the repository at this point in the history
- rename sacrebleu to casedbleu
- introduce bootleg_data_splits argument
  • Loading branch information
Mehrad0711 committed Jun 28, 2021
1 parent dbe744a commit d804ba7
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 12 deletions.
10 changes: 5 additions & 5 deletions genienlp/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ def computeBLEU(outputs, targets):
return sacrebleu.corpus_bleu(outputs, targets, lowercase=True).score


def computeSacreBLEU(outputs, targets):
def computeCasedBLEU(outputs, targets):
# lowercase is false
sacrebleu_metric = load_metric("sacrebleu")
return sacrebleu_metric.compute(predictions=outputs, references=targets, lowercase=False)['score']
Expand Down Expand Up @@ -540,10 +540,10 @@ def compute_metrics(greedy, answer, requested_metrics: Iterable, lang):
bertscore = computeBERTScore(greedy, answer, lang)
metric_keys.append('bertscore')
metric_values.append(bertscore)
if 'sacrebleu' in requested_metrics:
sacrebleu = computeSacreBLEU(greedy, answer)
metric_keys.append('sacrebleu')
metric_values.append(sacrebleu)
if 'casedbleu' in requested_metrics:
casedbleu = computeCasedBLEU(greedy, answer)
metric_keys.append('casedbleu')
metric_values.append(casedbleu)
if 'bleu' in requested_metrics:
bleu = computeBLEU(greedy, answer)
metric_keys.append('bleu')
Expand Down
13 changes: 11 additions & 2 deletions genienlp/run_bootleg.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,13 @@ def parse_argv(parser):
'--almond_domains', nargs='+', default=[], help='Domains used for almond dataset; e.g. music, books, ...'
)

parser.add_argument('--prep_test_too', action='store_true', help='Prepare bootleg features for test set too')
parser.add_argument(
'--bootleg_data_splits',
nargs='+',
type=str,
default=['train', 'eval'],
help='Data splits to prepare bootleg features for. train and eval should be included by default; test set is optional',
)

parser.add_argument(
'--ned_features',
Expand Down Expand Up @@ -335,6 +341,9 @@ def dump_bootleg_features(args, logger):

bootleg = Bootleg(args)

if 'train' not in args.bootleg_data_splits or 'eval' not in args.bootleg_data_splits:
raise ValueError('Make sure bootleg\'s data_splits contain at least train and eval set')

train_eval_shared_kwargs = {
'subsample': args.subsample,
'skip_cache': args.skip_cache,
Expand Down Expand Up @@ -414,7 +423,7 @@ def dump_bootleg_features(args, logger):

eval_examples = splits.eval.examples

if args.prep_test_too:
if 'test' in args.bootleg_data_splits:
# process test split
logger.info(f'Loading {val_task.name}')
kwargs = {'train': None, 'validation': None}
Expand Down
2 changes: 1 addition & 1 deletion genienlp/tasks/almond_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,7 +613,7 @@ class Translate(NaturalSeq2Seq):

def __init__(self, name, args):
super().__init__(name, args)
self._metrics = ['sacrebleu']
self._metrics = ['casedbleu']

def postprocess_prediction(self, example_id, prediction):
return super().postprocess_prediction(example_id, prediction)
Expand Down
2 changes: 0 additions & 2 deletions genienlp/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -757,7 +757,6 @@ def load_config_json(args):
]
# these are true/ false arguments
overwrite_actions = [
'plot_heatmaps',
'replace_qp',
'force_replace_qp',
]
Expand All @@ -779,7 +778,6 @@ def load_config_json(args):
'almond_lang_as_question',
'preprocess_special_tokens',
'almond_thingtalk_version',
'plot_heatmaps',
'replace_qp',
'force_replace_qp',
'no_fast_tokenizer',
Expand Down
4 changes: 2 additions & 2 deletions tests/test_translation.sh
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@ for model in "Helsinki-NLP/opus-mt-en-de" "sshleifer/tiny-mbart" ; do

if [[ $model == Helsinki-NLP* ]] ; then
base_model="marian"
expected_result='{"sacrebleu": 88.04086004116694}'
expected_result='{"casedbleu": 88.04086004116694}'
elif [[ $model == *mbart* ]] ; then
base_model="mbart"
expected_result='{"sacrebleu": 0}'
expected_result='{"casedbleu": 0}'
fi

mv $workdir/translation/en-de/dev_"$base_model"_aligned.tsv $workdir/translation/almond/train.tsv
Expand Down

0 comments on commit d804ba7

Please sign in to comment.