-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathinltk_classifier.py
68 lines (49 loc) · 2.04 KB
/
inltk_classifier.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
"""
iNLTK public marathi headlines dataset
"""
import pandas as pd
import numpy as np
from datasets import load_metric
from datasets import Dataset
from datasets import ClassLabel
from transformers import TrainingArguments, Trainer, AutoConfig
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# Maximum token length for a headline; longer inputs are truncated,
# shorter ones padded to this length (see tokenize_function).
MAX_LEN = 128
# Marathi RoBERTa base checkpoint, published on the Hub as Flax weights
# (hence from_flax=True when the model is loaded below).
MODEL_NAME = "flax-community/roberta-base-mr"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# NOTE(review): datasets.load_metric is deprecated and removed in newer
# datasets releases (metrics moved to the `evaluate` package) — confirm
# the pinned datasets version still ships it.
metric = load_metric("accuracy")
def compute_metrics(eval_pred):
    """Return accuracy for one evaluation pass.

    The Trainer hands us an (logits, labels) pair; accuracy is computed
    over the argmax class per example via the module-level metric.
    """
    logits, labels = eval_pred
    predicted_ids = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predicted_ids, references=labels)
def tokenize_function(examples):
    """Tokenize a batch of headlines, padded/truncated to MAX_LEN tokens."""
    return tokenizer(
        examples["headline"],
        padding="max_length",
        truncation=True,
        max_length=MAX_LEN,
    )
train_df = pd.read_csv("train.csv")
valid_df = pd.read_csv("valid.csv")
label_names = train_df["label"].unique().tolist()
num_labels = len(label_names)
cl = ClassLabel(num_classes=num_labels, names=label_names)
valid_df["label"] = valid_df["label"].map(lambda x: cl.str2int(x))
train_df["label"] = train_df["label"].map(lambda x: cl.str2int(x))
print(label_names)
label2id = {label : cl.str2int(label) for label in label_names}
id2label = {cl.str2int(label) : label for label in label_names}
print(label2id)
config = AutoConfig.from_pretrained(MODEL_NAME, label2id=label2id, id2label=id2label)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, from_flax=True, config=config)
train_ds = Dataset.from_pandas(train_df)
valid_ds = Dataset.from_pandas(valid_df)
valid_tokenized_data = valid_ds.map(tokenize_function, batched=True)
train_tokenized_data = train_ds.map(tokenize_function, batched=True)
training_args = TrainingArguments("inltk_trainer", report_to=None)
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_tokenized_data,
eval_dataset=valid_tokenized_data,
compute_metrics=compute_metrics,
)
trainer.train()
model.save_pretrained("inltk-mr-classifier")
tokenizer.save_pretrained("inltk-mr-classifier")
trainer.evaluate()