forked from yixinL7/BRIO
-
Notifications
You must be signed in to change notification settings - Fork 1
/
config.py
71 lines (70 loc) · 3.59 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
def cnndm_setting(args):
    """Populate ``args`` with the default CNN/DailyMail configuration.

    For every option below, an attribute already present on ``args`` keeps
    its current value (including ``None``); a missing attribute is set to the
    listed default. ``args`` is modified in place and the function returns
    ``None``.
    """
    # Listed in the original assignment order so that attributes missing
    # from ``args`` are added in the same sequence (vars(args) ordering).
    cnndm_defaults = {
        "batch_size": 1,
        "epoch": 100,
        "report_freq": 100,
        "accumulate_step": 8,
        "margin": 0.001,
        "gold_margin": 0,
        "gold_weight": 0,
        "mle_weight": 0.1,
        "rank_weight": 10,
        "model_type": "facebook/bart-large-cnn",
        "warmup_steps": 10000,
        "normalize": True,
        "grad_norm": 0,
        "seed": 970903,
        "no_gold": False,
        "pretrained": None,
        "max_lr": 2e-3,
        "scale": 1,
        "score_mode": "log",
        "datatype": "diverse",
        "dataset": "cnndm",
        "max_len": 120,
        "max_num": 16,
        "smooth": 0.1,
        "total_len": 1024,
        "length_penalty": 2.0,
        "do_sample": True,
        "gen_max_len": 140,
        "gen_min_len": 55,
        "is_pegasus": False,
        "adding": 0,
        "eval_interval": 1000,
        "num_beams": 4,
    }
    for option, fallback in cnndm_defaults.items():
        # Re-assign the existing value when present, otherwise the default.
        setattr(args, option, getattr(args, option, fallback))
def xsum_setting(args):
    """Populate ``args`` with the default XSum configuration.

    For every option below, an attribute already present on ``args`` keeps
    its current value (including ``None``); a missing attribute is set to the
    listed default. ``args`` is modified in place and the function returns
    ``None``.
    """
    # Listed in the original assignment order so that attributes missing
    # from ``args`` are added in the same sequence (vars(args) ordering).
    xsum_defaults = {
        "batch_size": 2,
        "epoch": 100,
        "report_freq": 100,
        "accumulate_step": 4,
        "margin": 0.001,
        "gold_margin": 0,
        "gold_weight": 0,
        "mle_weight": 0.1,
        "rank_weight": 10,
        "model_type": "google/pegasus-xsum",
        "warmup_steps": 10000,
        "normalize": True,
        "grad_norm": 0,
        "seed": 970903,
        "no_gold": False,
        "pretrained": None,
        "max_lr": 2e-3,
        "scale": 0.01,
        "score_mode": "log",
        "datatype": "diverse",
        "dataset": "xsum",
        "max_len": 80,
        "max_num": 16,
        "smooth": 0.1,
        "total_len": 512,
        "length_penalty": 0.6,
        "do_sample": True,
        "gen_max_len": 62,
        "gen_min_len": 11,
        "is_pegasus": True,
        "adding": 0,
        "eval_interval": 1000,
        "num_beams": 8,
    }
    for option, fallback in xsum_defaults.items():
        # Re-assign the existing value when present, otherwise the default.
        setattr(args, option, getattr(args, option, fallback))