-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathdpr_inference.py
77 lines (61 loc) · 2.49 KB
/
dpr_inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
# Code refactoring by Tae min Kim
import os
import pandas as pd
from datasets import DatasetDict, load_from_disk
from transformers import (AutoConfig, AutoModelForQuestionAnswering,
AutoTokenizer, DataCollatorWithPadding,
HfArgumentParser, Trainer, TrainingArguments)
from arguments import DataTrainingArguments, ModelArguments
from data_preprocessing import Preprocess
from dataset import Dataset
from QA_trainer import QuestionAnsweringTrainer
from utils import config_parser
from utils_taemin import (compute_metrics, data_collators,
post_processing_function, run_sparse_retrieval)
from dev_dpr import DenseRetrieval
def main(model_name, data_path):
    """Run question-answering inference with a saved checkpoint.

    Loads the config, tokenizer and QA model from ``model_name``, builds the
    test set from ``test_data.csv`` under ``data_path``, and runs
    ``QuestionAnsweringTrainer.predict`` on it.

    Args:
        model_name: Path or identifier handed to the ``from_pretrained``
            loaders for the config, tokenizer and model.
        data_path: Directory that contains ``test_data.csv``; it is also
            forwarded to ``run_sparse_retrieval``.

    Returns:
        The object returned by ``trainer.predict``.
    """
    config = AutoConfig.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForQuestionAnswering.from_pretrained(
        model_name, config=config
    )

    # NOTE(review): this data preparation was commented out in the original,
    # which left `test_data` and `examples` undefined and made the function
    # fail with NameError. It is restored verbatim from those comments.
    # `DenseRetrieval` is imported by this file; if dense retrieval was the
    # intended source of passages, swap it in here - its API is not visible
    # from this file, so confirm against dev_dpr before changing.
    datasets = run_sparse_retrieval(
        tokenize_fn=tokenizer.tokenize,
        data_path=data_path,
        datasets=pd.read_csv(os.path.join(data_path, "test_data.csv")),
    )
    examples = datasets["validation"].to_pandas()
    test_data = Preprocess(
        tokenizer=tokenizer, dataset=datasets["validation"], state="val"
    ).output_data

    data_collator = data_collators(tokenizer)

    # Only `predict` is called below; the training-related values are kept
    # exactly as they were in the original configuration.
    args = TrainingArguments(
        output_dir=os.path.join(
            os.path.abspath(os.path.dirname(__file__)), "output"
        ),
        evaluation_strategy="epoch",
        save_strategy="epoch",
        learning_rate=2e-5,
        per_device_train_batch_size=8,
        per_device_eval_batch_size=8,
        num_train_epochs=3,
        weight_decay=0.1,
        dataloader_num_workers=0,
        logging_steps=50,
        seed=42,
        group_by_length=True,
        do_eval=False,
        do_predict=True,
    )

    trainer = QuestionAnsweringTrainer(
        model=model,
        args=args,
        train_dataset=None,
        eval_dataset=test_data,
        eval_examples=examples,
        tokenizer=tokenizer,
        data_collator=data_collator,
        post_process_function=post_processing_function,
        compute_metrics=compute_metrics,
    )

    predictions = trainer.predict(
        test_dataset=test_data, test_examples=examples
    )
    # Hand the result back to the caller instead of discarding it.
    return predictions
if __name__ == "__main__":
    # Resolve the checkpoint and CSV directories relative to this script's
    # own location, so the working directory does not matter.
    script_dir = os.path.abspath(os.path.dirname(__file__))
    model_name = os.path.join(script_dir, "checkpoint/checkpoint-2994")
    data_path = os.path.join(script_dir, "csv_data")
    main(model_name=model_name, data_path=data_path)