-
Notifications
You must be signed in to change notification settings - Fork 0
/
same_analyze.py
35 lines (26 loc) · 1.13 KB
/
same_analyze.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import argparse
import sys
from sys import stdin
import torch
from transformers import BertTokenizer
import utils
from models import BERTQuestionAnalyzer
def analyze(model, tokenizer, question1, question2):
    """Return the model's probability that the two questions are duplicates.

    Args:
        model: BERTQuestionAnalyzer whose forward returns (loss, logits) —
            presumably a 2-class head; verify against the model definition.
        tokenizer: BertTokenizer used to encode the concatenated question pair.
        question1: first question text.
        question2: second question text.

    Returns:
        float in [0, 1] — softmax probability of class index 1 ("same").
    """
    # The two questions are joined with a '|||' separator before encoding.
    encoded = tokenizer.encode(question1 + '|||' + question2)
    question = torch.tensor(encoded, dtype=torch.long, device=utils.device).unsqueeze(0)
    # Dummy label: the model's forward signature requires one here, but only
    # the logits (second return value) are used.
    label = torch.tensor(0, dtype=torch.long, device=utils.device).unsqueeze(0)
    # Inference only — disable gradient tracking. The original relied on the
    # caller, which invoked `torch.no_grad()` as a bare statement (a no-op),
    # so gradients were silently tracked during inference.
    with torch.no_grad():
        _, output = model(question, label)
        output = torch.softmax(output, dim=1)
    return output[0, 1].item()
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('question1', type=str, help='Question 1 ^^')
    parser.add_argument('question2', type=str, help='Question 2 ^^')
    args, _ = parser.parse_known_args(sys.argv[1:])
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    best_model = BERTQuestionAnalyzer().to(utils.device)
    utils.load_checkpoint('logs/qqp/model_bert.pt', best_model)
    best_model.eval()
    # BUG FIX: the original called `torch.no_grad()` as a bare statement
    # before model construction, which is a no-op — it must be used as a
    # context manager (or decorator) to actually disable gradient tracking.
    with torch.no_grad():
        output = analyze(best_model, tokenizer, args.question1, args.question2)
    print('')
    print(f"Same at {round(output*100, 2)}%")