Add files via upload
seq-to-mind authored Feb 20, 2023
1 parent deba9e4 commit 5e2f4dc
Showing 3 changed files with 19 additions and 29 deletions.
Model.py (20 changes: 6 additions & 14 deletions)
@@ -119,6 +119,7 @@ def __init__(self):
super(StyleClassifier, self).__init__()
""" Roberta and BART are using the same tokenizer """
print("Loading BERT model for style classification:", global_config.corpus_mode)
+self.style_classifier = RobertaForSequenceClassification.from_pretrained(global_config.pretrained_style_model, output_attentions=True)
self.style_classifier_tokenizer = BartTokenizer.from_pretrained(global_config.pretrained_tokenizer, use_fast=True)
self.style_classifier.cuda().eval()
self.style_classifier.load_state_dict(torch.load('saved_models/TextBERT_' + global_config.corpus_mode + '/TextBERT_best.chkpt'))
@@ -160,21 +161,12 @@ def __init__(self):
self.style_class_dict = {"informal": 0, "formal": 1, "negative": 0, "positive": 1}

if global_config.corpus_mode in ["amazon", "Yelp"]:
-self.optimizer_Supervised = torch.optim.RMSprop(params=self.agent.parameters(), lr=0.00002)
-self.optimizer_RL = torch.optim.RMSprop(params=self.agent.parameters(), lr=0.00002)
+self.optimizer_Supervised = torch.optim.RMSprop(params=self.agent.parameters(), lr=0.00001)
+self.optimizer_RL = torch.optim.RMSprop(params=self.agent.parameters(), lr=0.00001)

elif global_config.corpus_mode == "GYAFC":
-self.optimizer_Supervised = torch.optim.AdamW(params=self.agent.parameters(), lr=0.00002)
-self.optimizer_RL = torch.optim.AdamW(params=self.agent.parameters(), lr=0.00002)
-
-if global_config.freeze_some_LM_layer:
-    for name, param in self.agent.language_backbone.named_parameters():
-        layer_num = re.findall("layer\.(\d+)\.", name)
-        if len(layer_num) > 0 and int(layer_num[0]) > 4:
-            print("Unfreeze layer:", int(layer_num[0]))
-            param.requires_grad = True
-        else:
-            param.requires_grad = False
+self.optimizer_Supervised = torch.optim.AdamW(params=self.agent.parameters(), lr=0.00001)
+self.optimizer_RL = torch.optim.AdamW(params=self.agent.parameters(), lr=0.00001)

def infer_mask(self, generated_id_tensor):
tmp = generated_id_tensor.detach().cpu().numpy().tolist()
@@ -210,7 +202,7 @@ def masking_polarity_head(self, generated_id_tensor):
if tmp_count_dict[j] > global_config.batch_size:
self.diversity_dict[j] = 100

print("############", self.iter_step, self.agent.tokenizer.decode(self.diversity_dict.keys()), "############")
# print("############", self.iter_step, self.agent.tokenizer.decode(self.diversity_dict.keys()), "############")

tmp_mask_no_special_token = [[1 if k in [0, 1, 2, mask_idx[i], ] else 0 for k, v in enumerate(j)] for i, j in enumerate(tmp)]
tmp_mask_diversity_mask = [[0.5 if v in self.diversity_dict.keys() and k not in [0, 1, 2, mask_idx[i], ] else 1.0 for k, v in enumerate(j)] for i, j in enumerate(tmp)]
global_config.py (6 changes: 2 additions & 4 deletions)
@@ -17,10 +17,8 @@
sentence_seg_token = " </s> <s> "
# sentence_seg_token = " [SEP] [CLS] "

-freeze_some_LM_layer = False
-
using_label_smoothing = True
-smooth_epsilon = 0.1
+smooth_epsilon = 0.15

start_from_epoch = 0
supervised_epoch_num = 2
@@ -36,7 +34,7 @@
batch_loss_print_interval = 20
print_all_predictions = True

-batch_size = 32
+batch_size = 64
num_epochs = 10

if corpus_mode in ["Yelp", "amazon"]:
main.py (22 changes: 11 additions & 11 deletions)
@@ -31,7 +31,7 @@ def setup_seed(seed):

def process_lines(text_lines, class_type, prefix_A, prefix_B):
tmp_lines = [re.sub("\s+", " ", i.replace("\t", " ")).strip()[:200] for i in text_lines if len(i) > 5]
tmp_lines = ["<s>" + prefix_A + "# " + i + "</s>" if class_type == 1 else "<s>" + prefix_B + "# " + i + "</s>" for i in tmp_lines]
tmp_lines = ["<s>" + prefix_A + " # " + i + "</s>" if class_type == 1 else "<s>" + prefix_B + " # " + i + "</s>" for i in tmp_lines]
return tmp_lines


@@ -41,8 +41,8 @@ def process_lines(text_lines, class_type, prefix_A, prefix_B):
""" Reading Yelp data, and add the unsupervised pairs """
print("Reading Yelp data with pseudo parallel data")
train_sample, test_sample = [], []
list_a = open("./pseudo_paired_data_Yelp/" + global_config.pseudo_method + "/merged_0_A_0_", "r", encoding="utf-8").readlines()
list_b = open("./pseudo_paired_data_Yelp/" + global_config.pseudo_method + "/merged_0_B_1_", "r", encoding="utf-8").readlines()
list_a = open("./pseudo_paired_data_Yelp/" + global_config.pseudo_method + "_based/merged_0_A_0", "r", encoding="utf-8").readlines()
list_b = open("./pseudo_paired_data_Yelp/" + global_config.pseudo_method + "_based/merged_0_B_1", "r", encoding="utf-8").readlines()

assert len(list_a) == len(list_b)

@@ -51,7 +51,8 @@ def process_lines(text_lines, class_type, prefix_A, prefix_B):
tmp_b = [j for j in list_b[i].split() if len(j) > 1]

if len(set(tmp_a) & set(tmp_b)) > 1:
-train_sample.append((process_lines([list_a[i], ], class_type=0)[0], process_lines([list_b[i], ], class_type=1)[0]))
+train_sample.append((process_lines([list_a[i], ], class_type=0, prefix_A="positive", prefix_B="negative")[0],
+process_lines([list_b[i], ], class_type=1, prefix_A="positive", prefix_B="negative")[0]))

test_data_files = {"./pseudo_paired_data_Yelp/reference_0_0": "./pseudo_paired_data_Yelp/reference_0_1",
"./pseudo_paired_data_Yelp/reference_1_1": "./pseudo_paired_data_Yelp/reference_1_0"}
@@ -62,7 +63,9 @@ def process_lines(text_lines, class_type, prefix_A, prefix_B):
process_lines(open(test_data_files[i], encoding="utf-8").readlines(), class_type=int(test_data_files[i][-1]),
prefix_A="positive", prefix_B="negative")))

-train_sample = [(i[1], i[0]) for i in train_sample] + train_sample
+# train_sample = [(i[1], i[0]) for i in train_sample] + train_sample
+tmp_sample_size = 20000
+train_sample = [(i[1], i[0]) for i in train_sample[:tmp_sample_size]] + train_sample[tmp_sample_size:tmp_sample_size * 2]
test_sample = [(i[0], i[1]) for i in test_sample]

""" adding multiple human references """
@@ -88,8 +91,8 @@ def process_lines(text_lines, class_type, prefix_A, prefix_B):
train_sample, test_sample = [], []

""" reading amazon training sample version 2.0: add post processing """
list_a = open("amazon_data_paired/merged_0_A_0_" + global_config.pseudo_method, encoding="utf-8").readlines()
list_b = open("amazon_data_paired/merged_0_B_1_" + global_config.pseudo_method, encoding="utf-8").readlines()
list_a = open("./pseudo_paired_data_Amazon/" + global_config.pseudo_method + "_based/merged_0_A_0", "r", encoding="utf-8").readlines()
list_b = open("./pseudo_paired_data_Amazon/" + global_config.pseudo_method + "_based/merged_0_B_1", "r", encoding="utf-8").readlines()

assert len(list_a) == len(list_b)

@@ -178,9 +181,6 @@ def train_process(model, train_data):
model.supervised_loss_decay = 0.6 ** (train_epoch + 1)
print("Supervised loss decay: " + str(model.supervised_loss_decay))

-model.teacher_forcing_rate = global_config.MLE_teacher_forcing_anneal_rate ** train_epoch
-print("model.teacher_forcing_rate,", model.teacher_forcing_rate)
-
for batch in train_batches:
supervised_loss, cyclic_loss, GAN_dis_loss, GAN_gen_loss, transferred_sen_text, _ = model.batch_train(batch, train_epoch)
summary_steps += 1
@@ -258,7 +258,7 @@ def evaluate_process(model, data_test_collection, train_epoch, best_score=None,
bleu_score = corpus_bleu(list_of_references=all_gold_sentences, hypotheses=all_transferred_sentences)
all_eval_score += bleu_score

-[print(data_test_collection[tmp_i][-i], "\n", " ".join(all_transferred_sentences[-i])) for i in range(20)]
+# [print(data_test_collection[tmp_i][-i], "\n", " ".join(all_transferred_sentences[-i])) for i in range(20)]
# [print(all_gold_sentences[i], "\n", all_transferred_sentences[i], "\n\n") for i in range(10)]

if any(all_test_loss):
