main_fs_all.py
import logging
import json

import torch
from torch.utils.data import DataLoader
import hydra
from hydra.utils import to_absolute_path
from omegaconf import DictConfig, OmegaConf

from mtb.data import SmilerDataset, SmilerFewShotDataset
from mtb.model import MTBModel
from mtb.processor import BatchTokenizer, aggregate_batch
from mtb.train_eval import train_and_eval
from mtb.utils import resolve_relative_path, seed_everything

logger = logging.getLogger(__name__)
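
# language codes of the SMiLER corpora used for per-language few-shot runs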
_LANGUAGES = [
    "ar",
    "de",
    "en",
    "es",
    "fa",
    "fr",
    "it",
    "ko",
    "nl",
    "pl",
    "pt",
    "ru",
    "sv",
    "uk",
]


@hydra.main(config_name="fs_config", config_path="configs")
def main(cfg: DictConfig) -> None:
    """
    Conducts few-shot training and evaluation for every language, given the configuration.

    Args:
        cfg: Hydra-format configuration given in a dict.
    """
    resolve_relative_path(cfg)
    print(OmegaConf.to_yaml(cfg))

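    # run the full few-shot pipeline independently for each language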
    for language in _LANGUAGES:
        train_file = to_absolute_path(
            "./data/smiler/{}_corpora_train.json".format(language)
        )
        eval_file = to_absolute_path(
            "./data/smiler/{}_corpora_test.json".format(language)
        )
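        # re-seed before each language so runs stay reproducible and comparable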
        seed_everything(cfg.seed)

        # prepare dataset: parse the raw dataset and do some simple pre-processing, such as
        # converting special tokens and inserting entity markers
        entity_marker = True if cfg.variant in ["d", "e", "f"] else False
        train_dataset = SmilerFewShotDataset(
            train_file, cfg.kshot, entity_marker=entity_marker
        )
        eval_dataset = SmilerDataset(eval_file, entity_marker=entity_marker)
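        # label-to-id mapping is taken from the training split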
        layer_norm = True
        label_to_id = train_dataset.label_to_id

        # set dataloader
        train_loader = DataLoader(
            train_dataset,
            batch_size=cfg.batch_size,
            shuffle=True,
            pin_memory=True,
            collate_fn=aggregate_batch,
        )
        eval_loader = DataLoader(
            eval_dataset,
            batch_size=cfg.batch_size,
            shuffle=False,
            pin_memory=True,
            collate_fn=aggregate_batch,
        )

        # set a processor that tokenizes and aligns all the tokens in a batch
        batch_processor = BatchTokenizer(
            tokenizer_name_or_path=cfg.model,
            variant=cfg.variant,
            max_length=cfg.max_length,
        )
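        # the tokenizer may have been extended (e.g. with marker tokens), so take its current size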
        vocab_size = len(batch_processor.tokenizer)

        # set model and device
        model = MTBModel(
            encoder_name_or_path=cfg.model,
            variant=cfg.variant,
            layer_norm=layer_norm,
            vocab_size=vocab_size,
            num_classes=len(label_to_id),
            dropout=cfg.dropout,
        )
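        # use the configured GPU when cuda_device > -1, otherwise fall back to CPU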
        device = (
            torch.device("cuda", cfg.cuda_device)
            if cfg.cuda_device > -1
            else torch.device("cpu")
        )
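
        # fine-tune on the k-shot training set and evaluate on the full test split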
        micro_f1, macro_f1 = train_and_eval(
            model,
            train_loader,
            eval_loader,
            label_to_id,
            batch_processor,
            num_epochs=cfg.num_epochs,
            lr=cfg.lr,
            device=device,
        )
        logger.info(
            "{} evaluation micro-F1: {:.4f}, macro-F1: {:.4f}.".format(
                language, micro_f1, macro_f1
            )
        )

        # save evaluation results to json
        with open(
            "./{}_{}_{}_results.json".format(language, cfg.kshot, cfg.seed), "w"
        ) as f:
            json.dump({"micro_f1": micro_f1, "macro_f1": macro_f1}, f, indent=4)


if __name__ == "__main__":
    main()