main_iczs.py
import logging
import json

import torch
from torch.utils.data import DataLoader
import hydra
from omegaconf import DictConfig, OmegaConf

from meffi_prompt.utils import resolve_relative_path, seed_everything, aggregate_batch
from meffi_prompt.data import SmilerDataset
from meffi_prompt.prompt import SmilerPrompt, get_max_decode_length
from meffi_prompt.model import T5Model
from meffi_prompt.tokenizer import BatchTokenizer
from meffi_prompt.eval import eval

logger = logging.getLogger(__name__)

# in-order template (relation mask between head and tail entities),
# without the reversed variant; see the rendered example below
template = {
    "input": ["x", "eh", "<extra_id_0>", "et"],
    "target": ["<extra_id_0>", "r", "<extra_id_1>"],
}

# # post-order template (relation mask appended after both entities)
# template = {
#     "input": ["x", "eh", "et", "<extra_id_0>"],
#     "target": ["<extra_id_0>", "r", "<extra_id_1>"],
# }
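
# Illustrative rendering (an assumption for clarity; the exact surface form is
# produced by SmilerPrompt): with the in-order template, a sentence
# x = "Alan Turing was born in London." with head entity eh = "Alan Turing"
# and tail entity et = "London" yields an input roughly like
#   "Alan Turing was born in London. Alan Turing <extra_id_0> London"
# and a target of the form "<extra_id_0> <verbalized relation> <extra_id_1>",
# so the model is asked to fill the sentinel span with a relation phrase.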


@hydra.main(config_name="config_iczs", config_path="configs")
def main(cfg: DictConfig) -> None:
    """Conduct evaluation with the given configuration.

    Args:
        cfg: Hydra configuration, provided as a DictConfig.
    """
    resolve_relative_path(cfg)
    print(OmegaConf.to_yaml(cfg))
    seed_everything(cfg.seed)

    device = (
        torch.device("cuda", cfg.cuda_device)
        if cfg.cuda_device > -1
        else torch.device("cpu")
    )
    # get the raw dataset and do simple pre-processing such as converting
    # special tokens
    eval_dataset = SmilerDataset(cfg.eval_file)

    # transform to a prompted dataset, with appended inputs and verbalized labels
    prompt = SmilerPrompt(
        template=template,
        model_name=cfg.model,
        soft_token_length=0,
    )
    eval_dataset, verbalizer = prompt(
        eval_dataset, translate=True, return_verbalizer=True
    )
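    # Note on the verbalizer (illustrative, not the exact mapping): it maps
    # each relation label to the natural-language phrase the model should
    # decode, e.g. a label such as "birth-place" might verbalize to
    # "place of birth"; the actual mapping is constructed by SmilerPrompt.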
    # set dataloader
    eval_loader = DataLoader(
        eval_dataset,
        batch_size=cfg.batch_size,
        shuffle=False,
        pin_memory=True,
        collate_fn=aggregate_batch,
    )
    # instantiate the tokenizer and model
    batch_processor = BatchTokenizer(
        tokenizer_name_or_path=cfg.model,
        max_length=cfg.max_length,
        num_soft_tokens=0,
    )
    tokenized_verbalizer = {
        k: batch_processor.tokenizer(v, add_special_tokens=False)["input_ids"]
        for k, v in verbalizer.items()
    }
    max_relation_length = max(len(v) for v in tokenized_verbalizer.values())
    max_decode_length = get_max_decode_length(template, max_relation_length)
    logger.info("Max decode length: {}.".format(max_decode_length))
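    # For intuition (assumed arithmetic; the authoritative computation is
    # get_max_decode_length): with the target template
    # ["<extra_id_0>", "r", "<extra_id_1>"], the decode budget is roughly the
    # two sentinel tokens plus max_relation_length tokens for the longest
    # verbalized relation, plus room for EOS.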
    model = T5Model(
        cfg.model,
        max_decode_length=max_decode_length,
        tokenizer=batch_processor.tokenizer,
    )

    micro_f1, macro_f1 = eval(
        model=model,
        eval_loader=eval_loader,
        batch_processor=batch_processor,
        device=device,
        label_column_name=eval_dataset.label_column_name,
        tokenized_verbalizer=tokenized_verbalizer,
    )
    logger.info(
        "Evaluation micro-F1: {:.4f}, macro-F1: {:.4f}.".format(micro_f1, macro_f1)
    )

    # save evaluation results to JSON
    with open("./results.json", "w") as f:
        json.dump({"micro_f1": micro_f1, "macro_f1": macro_f1}, f, indent=4)


if __name__ == "__main__":
    main()
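

# Usage sketch (assumed invocation; the actual keys live in
# configs/config_iczs.yaml):
#   python main_iczs.py cuda_device=0 batch_size=32
# Hydra composes configs/config_iczs.yaml with any command-line overrides and
# writes outputs, including results.json, into its run directory.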