Commit

Merge pull request #8 from guangxush/imdb
Imdb
guangxush authored Dec 16, 2018
2 parents 768fafc + 854fdbb commit 11605ac
Showing 156 changed files with 82,410 additions and 83,389 deletions.
adaptive_learning.py (75 changes: 29 additions & 46 deletions)
@@ -6,61 +6,43 @@
from util.data_load import make_err_dataset
from util import data_process
from model.model2 import mlp2
from util.data_load import generate_model2_data
from util.data_load import generate_imdb_model2_data2
from util.util import cal_err_ratio
import numpy as np
from model_use import model_use
import os
from sentiment_analysis import train_e2e_model
from model.model1 import lstm_attention_model, lstm_mul_model


# train model1
def model1(i):
results_flag = True
if i > 10:
i = i % 10

model2_file = './modfile/model2file/mlp.best_model.h5'
result_file = './data/err_data/news_'+str(i)+'.data'
data2_path = './data/model2_data/news_'+str(i)+'_data.csv'
model2_file = './modfile/model2file/imdb.mlp.best_model.h5'
result_file = './data/err_data/imdb_'+str(i)+'.data'
data2_path = './data/model2_data/imdb_'+str(i)+'_data.csv'
# pos_file = "./data/part_data/train_pos_" + str(i) + ".txt"
# neg_file = "./data/part_data/train_neg_" + str(i) + ".txt"
train_file = "./data/part_data_all/train_" + str(i) + ".txt"
# train model1
modelname = 'BiLSTM_Attention'
datafile = "./modfile/model1_data/data" + "_fold_" + str(i) + ".pkl"
modelfile = modelname + "_fold_" + str(i) + ".pkl"

trainfile = "./data/mix_data_train_data.json"
testfile = "./data/mix_data_test_data.json"
w2v_file = "./modfile/Word2Vec.mod"
char2v_file = "./modfile/Char2Vec.mod"
resultdir = "./result/"
print(modelname)

maxlen = 100
batch_size = 128
npochos = 100

if not os.path.exists(datafile):
print("Precess data " + str(i) + "....")
data_process.get_part_train_test_data(trainfile, testfile, w2v_file, char2v_file, datafile, w2v_k=100, c2v_k=100,
maxlen=maxlen, left=i)

if not os.path.exists("./modfile/model1file/" + modelfile):
print("data has existed: " + datafile)
print("Training EE " + str(i) + " model....")
train_e2e_model(modelname, datafile, modelfile, resultdir,
npochos=npochos, batch_size=batch_size, retrain=False)

monitor = 'val_acc'
filepath = "./modfile/model1file/lstm.best_model_"+str(i)+".h5"
check_pointer = ModelCheckpoint(filepath=filepath, monitor=monitor, verbose=1,
save_best_only=True, save_weights_only=True)
early_stopping = EarlyStopping(patience=5)
csv_logger = CSVLogger('logs/imdb_model2_mlp_' + str(i) + '.log')
Xtrain, Xtest, ytrain, ytest = data_process.get_imdb_part_data2(raw_file=train_file)
model = lstm_mul_model()
model.fit(Xtrain, ytrain, batch_size=32, epochs=50, validation_data=(Xtest, ytest), verbose=1, shuffle=True,
callbacks=[check_pointer, early_stopping, csv_logger])
if results_flag:
print('Generate model2 dataset ...')
data_file = "./modfile/model1_data/data_fold_"
result_path = './data/model2_data/news_' + str(i) + '_data.csv'
model_name = 'BiLSTM_Attention'
modle_file = "BiLSTM_Attention_fold_"
testfile = './data/mix_data_test_data.json'
# filepath = "./modfile/model2file/mlp.best_model.h5"
batch_size = 128
generate_model2_data(model_name=model_name, datafile=data_file, model_file=modle_file, testfile=testfile,
result_path=result_path, batch_size=batch_size, count=10)
result_path = './data/model2_data/imdb_' + str(i) + '_data.csv'
model_file = './modfile/model1file/lstm.best_model_'
# test_pos_file = './data/part_data/test_pos_0.txt'
# test_neg_file = './data/part_data/test_neg_0.txt'
test_file = './data/part_data_all/test_0.txt'
generate_imdb_model2_data2(model_file=model_file, result_path=result_path, test_file=test_file, count=10)
print('Load result ...')

X_test, Y_test = load_data3(data_path=data2_path)
@@ -81,8 +63,8 @@ def model1(i):
# train model2
def model2(i):
results_flag = True
data_path = './data/model2_data/news_'+str(i)+'_data.csv'
filepath = "./modfile/model2file/mlp.best_model.h5"
data_path = './data/model2_data/imdb_'+str(i)+'_data.csv'
filepath = "./modfile/model2file/imdb.mlp.best_model.h5"
print('***** Start Model2 Train *****')
print('Loading data ...')
x_train, y_train, x_test, y_test = load_data2(data_path=data_path)
@@ -92,15 +74,16 @@ def model2(i):
check_pointer = ModelCheckpoint(filepath=filepath, monitor=monitor, verbose=1,
save_best_only=True, save_weights_only=True)
early_stopping = EarlyStopping(patience=5)
csv_logger = CSVLogger('logs/model2_mlp_'+str(i)+'.log')
csv_logger = CSVLogger('logs/imdb_model2_mlp_'+str(i)+'.log')
mlp_model2 = mlp2(sample_dim=x_train.shape[1], class_count=2)
mlp_model2.fit(x_train, y_train, batch_size=128, epochs=100, verbose=1, shuffle=True, validation_split=0.2,
mlp_model2.fit(x_train, y_train, batch_size=128, epochs=100, verbose=1, shuffle=True, validation_data=(x_test, y_test),
callbacks=[check_pointer, early_stopping, csv_logger])
if results_flag:
print('Generate submission ...')
mlp_model2.load_weights(filepath=filepath)
results = mlp_model2.predict(x_test)
label = np.argmax(results, axis=1)
y_test = np.argmax(y_test, axis=1)
print("pred:", end='')
print(label)
print("true:", end='')
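
For readers unfamiliar with the training setup in the diff above, here is a minimal standalone sketch of the ModelCheckpoint / EarlyStopping / CSVLogger callback pattern that both model1 and model2 use. The toy model, data, and file names below are placeholders for illustration and are not part of this commit; note that 'val_acc' is the metric name in the Keras 2.x versions of this era, while newer tf.keras releases call it 'val_accuracy'.

import numpy as np
from keras.models import Sequential
from keras.layers import Dense
from keras.callbacks import ModelCheckpoint, EarlyStopping, CSVLogger

# Placeholder data: 200 samples, 16 features, binary labels.
x = np.random.rand(200, 16)
y = np.random.randint(0, 2, size=(200, 1))

# Placeholder model standing in for lstm_mul_model / mlp2.
model = Sequential()
model.add(Dense(8, activation='relu', input_shape=(16,)))
model.add(Dense(1, activation='sigmoid'))
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

callbacks = [
    # Keep only the weights of the epoch with the best validation accuracy.
    ModelCheckpoint(filepath='best_model.h5', monitor='val_acc', verbose=1,
                    save_best_only=True, save_weights_only=True),
    # Stop when the monitored quantity (val_loss by default) has not improved for 5 epochs.
    EarlyStopping(patience=5),
    # Append per-epoch metrics to a CSV log file.
    CSVLogger('train.log'),
]
model.fit(x, y, batch_size=32, epochs=50, validation_split=0.2,
          shuffle=True, callbacks=callbacks)

Saving weights only (save_weights_only=True) matches how model2 later reloads the best checkpoint with load_weights(filepath) before predicting.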