-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtuner.py
98 lines (80 loc) · 2.45 KB
/
tuner.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
# Basics
import pandas as pd
import numpy as np
import os
import sys
# tf
import tensorflow as tf
# Models
from transformers import (AlbertConfig, TFAlbertModel,
DistilBertConfig, TFDistilBertModel,
RobertaConfig, TFRobertaModel)
import sentencepiece as spm
from preprocess.preprocess import Preprocess
from models.baseline import BaselineModel
from models.polytuplet import PolytupletModel
# Patch
from mock import patch
# Tuning
import keras_tuner as kt
# Get system arguments
# Map the CLI model-name argument to the integer index expected by the
# preprocess and model modules.
model_name_map = {
    "albert": 0,
    "distilbert": 1,
    "roberta": 2
}

# Validate CLI arguments up front so a missing or misspelled argument fails
# with a clear usage message instead of a raw IndexError/KeyError traceback.
if len(sys.argv) < 4:
    sys.exit(f"Usage: {sys.argv[0]} <baseline|polytuplet> <albert|distilbert|roberta> <mixed|unmixed>")
if sys.argv[2] not in model_name_map:
    sys.exit(f"Unknown model '{sys.argv[2]}'; expected one of: {', '.join(model_name_map)}")

IS_BASELINE = sys.argv[1] == "baseline"          # any other value selects the polytuplet arch
MODEL_INDEX = model_name_map[sys.argv[2]]        # validated above
USE_MIXED = sys.argv[3] == "mixed"               # any other value disables mixing
# Clear the terminal (Windows uses `cls`, everything else `clear`)
os.system('cls' if os.name == 'nt' else 'clear')

# Print a banner summarizing this tuning run's configuration
divider = "=" * 40
print(divider)
print("Tuner Config")
print(divider)
print(f"Model Information\n\tBaseline Arch.: {IS_BASELINE}\n\tModel Name: {sys.argv[2]}\n\tIndex: {MODEL_INDEX}\n")
print(f"Data Information\n\tMixing: {USE_MIXED}")
print(divider)
# Ensure the processed-data output directory exists
os.makedirs("dataset/processed", exist_ok=True)

# Load cleaned data
print("Loading data...")
df = pd.read_csv("dataset/cleaned/dev.csv")
print("Data loaded")

# Run preprocessing and obtain dataset-ready (features, labels) tuples
print("Preprocessing...")
preprocessor = Preprocess(df=df, model_index=MODEL_INDEX)
train_data, val_data = preprocessor.get_datasets(mixed=USE_MIXED)

if IS_BASELINE:
    # The baseline model consumes only the result ("r_") encodings, re-keyed
    # to the generic names its input layer expects.
    def _result_only(split):
        features, labels = split
        return (
            {"input_ids": features["r_input_ids"],
             "attention_mask": features["r_attention_mask"]},
            labels,
        )

    train_data = _result_only(train_data)
    val_data = _result_only(val_data)

# Wrap the tuples as TensorFlow datasets for training
train_data = tf.data.Dataset.from_tensor_slices(train_data)
val_data = tf.data.Dataset.from_tensor_slices(val_data)
print("Preprocessing complete")
print("Building model")

# The original code branched on IS_BASELINE twice in a row (once to build the
# model, once to tune it); a single if/else keeps each architecture's build
# and search configuration together and avoids the duplicated condition.
if IS_BASELINE:
    # Baseline: single-input model over the result encodings only
    model = BaselineModel(preprocessor.RESULT_LEN, model_index=MODEL_INDEX)
    model.tune_hyperparams(
        train_data=train_data,
        validation_data=val_data,
        dropout_range=(0.0, 0.5),
        learning_rate_range=(1e-7, 1e-5)
    )
else:
    # Polytuplet: takes both context and result lengths and exposes extra
    # loss-related hyperparameters (alpha, m, hard_w) to the search.
    model = PolytupletModel(preprocessor.CONTEXT_LEN, preprocessor.RESULT_LEN, model_index=MODEL_INDEX)
    model.tune_hyperparams(
        train_data=train_data,
        validation_data=val_data,
        dropout_range=(0.0, 0.3),
        learning_rate_range=(1e-7, 1e-5),
        final_learning_rate_range=(1e-7, 1e-5),
        alpha_range=(0.0, 10.0),
        m_range=(0.0, 2.0),
        hard_w_range=(0.0, 1.0)
    )