-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpredict.py
42 lines (34 loc) · 1.34 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import zipfile
import torch
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, accuracy_score
from transformers import (AutoTokenizer,
Trainer,
AutoModelForSequenceClassification,
TrainingArguments,
DataCollatorWithPadding)
from datasets import Dataset
def preprocess(example):
    """Tokenize one question pair as a two-segment input.

    Reads ``example["question1"]`` and ``example["question2"]`` and passes
    them to the module-level ``tokenizer`` with truncation enabled, returning
    whatever the tokenizer produces.
    """
    return tokenizer(
        example["question1"],
        example["question2"],
        truncation=True,
    )
# --- Inference script: score question pairs in data/test.csv and write submit.csv ---

# The tokenizer comes from the base checkpoint name; the classifier weights are
# loaded from the local "out" directory.
model_name = 'google-bert/bert-base-uncased'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained("out", num_labels=2)

# Load the test set. Missing cells are replaced with the placeholder string
# 'An' so every row still has text to hand to the tokenizer.
df = pd.read_csv('data/test.csv')
df.fillna('An', inplace=True)
# Strip stray single quotes from the ids.
# NOTE(review): `.str` requires the column to hold strings - confirm the dtype
# of `test_id` in the actual CSV.
df['test_id'] = df['test_id'].str.replace("'", '')

# Tokenize every row; the original columns are dropped so only the tokenizer
# output is passed on to the model.
dataset = Dataset.from_pandas(df)
dataset = dataset.map(preprocess, remove_columns=dataset.column_names)

# Pad dynamically per batch rather than to a fixed global length.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
trainer = Trainer(
    model=model,
    data_collator=data_collator,
)

predictions = trainer.predict(dataset)
print(predictions.predictions.shape)

# Predicted class = index of the largest logit along the last axis.
preds = np.argmax(predictions.predictions, axis=-1)
df['is_duplicate'] = preds

# Fix: the original referenced an undefined `test_df` here, which raised
# NameError after inference had already completed. `df` is the frame that
# carries the predictions.
test = df.drop(['question1', 'question2'], axis=1)
test.to_csv('submit.csv', index=False)