From 04c371c6dd31ff337780c728605a361aff2b41d0 Mon Sep 17 00:00:00 2001 From: louishsu Date: Sat, 18 Sep 2021 00:22:04 +0800 Subject: [PATCH] =?UTF-8?q?2021/9/18=201.=E6=96=B0=E5=A2=9EVAT=E5=B9=B6?= =?UTF-8?q?=E6=9B=B4=E6=94=B9compute=5Fkl=5Floss=E5=AE=9E=E7=8E=B0?= =?UTF-8?q?=EF=BC=8C=E6=95=88=E6=9E=9C=E4=B8=8D=E4=BD=B3;=202.=E6=96=B0?= =?UTF-8?q?=E5=A2=9ELSR=EF=BC=8CFOCAL=E5=BE=85=E5=8A=9E;=203.=E5=8E=BB?= =?UTF-8?q?=E6=8E=89rdrop=E6=95=88=E6=9E=9C=E5=A6=82=E4=BD=95=E5=BE=85?= =?UTF-8?q?=E5=B0=9D=E8=AF=95=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- run_span.py | 136 +++++++++++++++++++++++++++++++---- scripts/run_span.sh | 169 +++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 289 insertions(+), 16 deletions(-) diff --git a/run_span.py b/run_span.py index 4970a1c..760b3db 100644 --- a/run_span.py +++ b/run_span.py @@ -71,6 +71,74 @@ def batched_index_select(input, index): output = torch.bmm(index_onehot, input) return output + +class LabelSmoothingCE(nn.Module): + + def __init__(self, eps=0.1, reduction='mean', ignore_index=-100): + super().__init__() + + self.eps = eps + self.reduction = reduction + self.ignore_index = ignore_index + + def forward(self, input, target): + c = input.size()[-1] + log_preds = F.log_softmax(input, dim=-1) + if self.reduction == 'sum': + loss = -log_preds.sum() + else: + loss = -log_preds.sum(dim=-1) + if self.reduction == 'mean': + loss = loss.mean() + loss_1 = loss * self.eps / c + loss_2 = F.nll_loss(log_preds, target, reduction=self.reduction, ignore_index=self.ignore_index) + return loss_1 + (1 - self.eps) * loss_2 + + +class FocalLoss(nn.Module): + """ + Softmax and sigmoid focal loss + """ + + def __init__(self, num_labels, activation_type='softmax', reduction='mean', + gamma=2.0, alpha=0.25, epsilon=1.e-9): + + super(FocalLoss, self).__init__() + self.num_labels = num_labels + self.gamma = gamma + self.alpha = alpha + self.epsilon = epsilon + self.activation_type = activation_type + self.reduction = reduction + + def forward(self, input, target): + """ + Args: + logits: pretrain_model's output, shape of [batch_size, num_cls] + target: ground truth labels, shape of [batch_size] + Returns: + shape of [batch_size] + """ + if self.activation_type == 'softmax': + idx = target.view(-1, 1).long() + one_hot_key = torch.zeros(idx.size(0), self.num_labels, dtype=torch.float32, device=idx.device) + one_hot_key = one_hot_key.scatter_(1, idx, 1) + logits = F.softmax(input, dim=-1) + loss = -self.alpha * one_hot_key * torch.pow((1 - logits), self.gamma) * (logits + self.epsilon).log() + loss = loss.sum(1) + elif self.activation_type == 'sigmoid': + multi_hot_key = target + logits = F.sigmoid(input) + zero_hot_key = 1 - multi_hot_key + loss = -self.alpha * multi_hot_key * torch.pow((1 - logits), self.gamma) * (logits + self.epsilon).log() + loss += -(1 - self.alpha) * zero_hot_key * torch.pow(logits, self.gamma) * (1 - logits + self.epsilon).log() + if self.reduction == "mean": + loss = loss.mean() + elif self.reduction == "none": + pass + return loss + + class SpanV2(nn.Module): def __init__(self, hidden_size, num_labels, max_span_length, width_embedding_dim): @@ -127,7 +195,14 @@ class SpanV2Loss(nn.Module): def __init__(self): super().__init__() - self.loss_fct = nn.CrossEntropyLoss(reduction='none') + self.loss_fct = None + if args.loss_type == "ce": + self.loss_fct = nn.CrossEntropyLoss(reduction='none') + elif args.loss_type == "lsr": + self.loss_fct = LabelSmoothingCE(eps=args.label_smooth_eps, reduction='none') + elif args.loss_type == "focal": + self.loss_fct = FocalLoss(num_labels=..., reduction='none', + gamma=args.focal_gamma, alpha=args.focal_alpha) # TODO: def forward(self, logits=None, # (batch_size, num_spans, num_labels) @@ -192,21 +267,42 @@ def forward( attentions=outputs.attentions, ) -def compute_kl_loss(p, q, pad_mask=None): +# def compute_kl_loss(p, q, pad_mask=None): + +# batch_size, num_spans, num_labels = p.size() +# if pad_mask is None: +# pad_mask = torch.ones(batch_size, num_spans, dtype=torch.bool, device=p.device) +# pad_mask = pad_mask.unsqueeze(-1).expand(batch_size, num_spans, num_labels) - p_loss = F.kl_div(F.log_softmax(p, dim=-1), F.softmax(q, dim=-1), reduction='none') - q_loss = F.kl_div(F.log_softmax(q, dim=-1), F.softmax(p, dim=-1), reduction='none') +# p_loss = F.kl_div(F.log_softmax(p, dim=-1), F.softmax(q, dim=-1), reduction='none') +# q_loss = F.kl_div(F.log_softmax(q, dim=-1), F.softmax(p, dim=-1), reduction='none') - # pad_mask is for seq-level tasks - if pad_mask is not None: - p_loss.masked_fill_(pad_mask, 0.) - q_loss.masked_fill_(pad_mask, 0.) +# # pad_mask is for seq-level tasks +# p_loss.masked_fill_(pad_mask, 0.) +# q_loss.masked_fill_(pad_mask, 0.) + +# # You can choose whether to use function "sum" and "mean" depending on your task +# p_loss = p_loss.mean() +# q_loss = q_loss.mean() + +# loss = (p_loss + q_loss) / 2 +# return loss + +def compute_kl_loss(p, q, pad_mask=None): - # You can choose whether to use function "sum" and "mean" depending on your task - p_loss = p_loss.mean() - q_loss = q_loss.mean() + batch_size, num_spans, num_labels = p.size() + if pad_mask is None: + pad_mask = torch.ones(batch_size, num_spans, dtype=torch.bool, device=p.device) + pad_mask = pad_mask.unsqueeze(-1).expand(batch_size, num_spans, num_labels) + p_loss = F.kl_div(F.log_softmax(p, dim=-1), F.softmax(q, dim=-1), reduction='none') + q_loss = F.kl_div(F.log_softmax(q, dim=-1), F.softmax(p, dim=-1), reduction='none') + + mask_valid = ~pad_mask + p_loss = p_loss[mask_valid].mean() + q_loss = q_loss[mask_valid].mean() loss = (p_loss + q_loss) / 2 + return loss def forward_rdrop(cls, alpha, **kwargs): @@ -216,7 +312,7 @@ def forward_rdrop(cls, alpha, **kwargs): outputs2 = forward(cls, **kwargs) rdrop_loss = compute_kl_loss( outputs1["logits"], outputs2["logits"], - kwargs["span_mask"].unsqueeze(-1) == 0) + kwargs["span_mask"] == 0) total_loss = (outputs1["loss"] + outputs2["loss"]) / 2. + alpha * rdrop_loss return TokenClassifierOutput( loss=total_loss, @@ -301,12 +397,17 @@ def build_arguments(self): self.add_argument("--augment_context_aware_p", default=None, type=float) self.add_argument("--augment_entity_replace_p", default=None, type=float) self.add_argument("--rdrop_alpha", default=None, type=float) + self.add_argument("--vat_alpha", default=None, type=float) # Other parameters self.add_argument('--scheme', default='IOB2', type=str, choices=['IOB2', 'IOBES']) self.add_argument('--loss_type', default='ce', type=str, choices=['lsr', 'focal', 'ce']) + self.add_argument('--label_smooth_eps', default=0.1, type=float) + self.add_argument('--focal_gamma', default=2.0, type=float) + self.add_argument('--focal_alpha', default=0.25, type=float) + self.add_argument("--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name") self.add_argument("--tokenizer_name", default="", type=str, @@ -802,12 +903,17 @@ def train(args, model, processor, tokenizer): loss = loss / args.gradient_accumulation_steps if args.fp16: with amp.scale_loss(loss, optimizer) as scaled_loss: - scaled_loss.backward() + scaled_loss.backward(retain_graph=args.vat_alpha is not None) else: - loss.backward() + loss.backward(retain_graph=args.vat_alpha is not None) if args.do_fgm: fgm.attack() - loss_adv = model(**batch)[0] + outputs_adv = model(**batch) + loss_adv = outputs_adv[0] + if args.vat_alpha is not None: + loss_vat = compute_kl_loss(outputs["logits"], outputs_adv["logits"], + pad_mask=batch["span_mask"] == 0) + loss_adv = loss_adv + args.vat_alpha * loss_vat if args.n_gpu > 1: loss_adv = loss_adv.mean() if args.gradient_accumulation_steps > 1: diff --git a/scripts/run_span.sh b/scripts/run_span.sh index 513ccf9..8ca6669 100644 --- a/scripts/run_span.sh +++ b/scripts/run_span.sh @@ -250,6 +250,108 @@ done # 组织机构 # {'p': 0.9203860072376358, 'r': 0.9466501240694789, 'f': 0.9333333333333333} +# TODO: 去掉rdrop +for k in 0 1 2 3 4 +do +python run_span.py \ + --version=nezha-fgm1.0-fold${k} \ + --data_dir=./data/ner-ctx0-5fold-seed42/ \ + --train_file=train.${k}.json \ + --dev_file=dev.${k}.json \ + --test_file=dev.${k}.json \ + --model_type=nezha_span \ + --model_name_or_path=/home/louishsu/NewDisk/Garage/weights/transformers/nezha-cn-base/ \ + --do_train \ + --overwrite_output_dir \ + --evaluate_during_training \ + --evaluate_each_epoch \ + --save_best_checkpoints \ + --max_span_length=40 \ + --width_embedding_dim=128 \ + --train_max_seq_length=512 \ + --eval_max_seq_length=512 \ + --do_lower_case \ + --per_gpu_train_batch_size=8 \ + --per_gpu_eval_batch_size=16 \ + --gradient_accumulation_steps=2 \ + --learning_rate=5e-5 \ + --other_learning_rate=1e-3 \ + --num_train_epochs=4.0 \ + --warmup_proportion=0.1 \ + --do_fgm --fgm_epsilon=1.0 \ + --seed=42 +done + +# TODO: rdrop待定,label smooth 0.1 +for k in 0 1 2 3 4 +do +python run_span.py \ + --version=nezha-fgm1.0-lsr0.1-fold${k} \ + --data_dir=./data/ner-ctx0-5fold-seed42/ \ + --train_file=train.${k}.json \ + --dev_file=dev.${k}.json \ + --test_file=dev.${k}.json \ + --model_type=nezha_span \ + --model_name_or_path=/home/louishsu/NewDisk/Garage/weights/transformers/nezha-cn-base/ \ + --do_train \ + --overwrite_output_dir \ + --evaluate_during_training \ + --evaluate_each_epoch \ + --save_best_checkpoints \ + --max_span_length=40 \ + --width_embedding_dim=128 \ + --train_max_seq_length=512 \ + --eval_max_seq_length=512 \ + --do_lower_case \ + --per_gpu_train_batch_size=8 \ + --per_gpu_eval_batch_size=16 \ + --gradient_accumulation_steps=2 \ + --learning_rate=5e-5 \ + --other_learning_rate=1e-3 \ + --num_train_epochs=4.0 \ + --warmup_proportion=0.1 \ + --do_fgm --fgm_epsilon=1.0 \ + --loss_type=lsr --label_smooth_eps=0.1 \ + --seed=42 +done + +# TODO: focal +for k in 0 1 2 3 4 +do +python run_span.py \ + --version=nezha-fgm1.0-focalg2.0a0.25-fold${k} \ + --data_dir=./data/ner-ctx0-5fold-seed42/ \ + --train_file=train.${k}.json \ + --dev_file=dev.${k}.json \ + --test_file=dev.${k}.json \ + --model_type=nezha_span \ + --model_name_or_path=/home/louishsu/NewDisk/Garage/weights/transformers/nezha-cn-base/ \ + --do_train \ + --overwrite_output_dir \ + --evaluate_during_training \ + --evaluate_each_epoch \ + --save_best_checkpoints \ + --max_span_length=40 \ + --width_embedding_dim=128 \ + --train_max_seq_length=512 \ + --eval_max_seq_length=512 \ + --do_lower_case \ + --per_gpu_train_batch_size=8 \ + --per_gpu_eval_batch_size=16 \ + --gradient_accumulation_steps=2 \ + --learning_rate=5e-5 \ + --other_learning_rate=1e-3 \ + --num_train_epochs=4.0 \ + --warmup_proportion=0.1 \ + --do_fgm --fgm_epsilon=1.0 \ + --loss_type=focal --focal_gamma=2.0 --focal_alpha=0.25 \ + --seed=42 +done +# <<< 第二阶段 <<< + + + +# ================================================================================================================== # for k in 0 1 2 3 4 # do # python run_span.py \ @@ -284,5 +386,70 @@ done # ./data/ner-ctx0-5fold-seed42/dev.gt.${k}.json \ # output/ner-cail_ner-bert_span-legal_electra_base-fold${k}-42/test_prediction.json # done -# <<< 第二阶段 <<< +# for k in 0 1 2 3 4 +# do +# python run_span.py \ +# --version=nezha-rdrop0.1-fgm1.0-fp16-fold${k} \ +# --data_dir=./data/ner-ctx0-5fold-seed42/ \ +# --train_file=train.${k}.json \ +# --dev_file=dev.${k}.json \ +# --test_file=dev.${k}.json \ +# --model_type=nezha_span \ +# --model_name_or_path=/home/louishsu/NewDisk/Garage/weights/transformers/nezha-cn-base/ \ +# --do_train \ +# --overwrite_output_dir \ +# --evaluate_during_training \ +# --evaluate_each_epoch \ +# --save_best_checkpoints \ +# --max_span_length=40 \ +# --width_embedding_dim=128 \ +# --train_max_seq_length=512 \ +# --eval_max_seq_length=512 \ +# --do_lower_case \ +# --per_gpu_train_batch_size=6 \ +# --per_gpu_eval_batch_size=12 \ +# --gradient_accumulation_steps=2 \ +# --learning_rate=5e-5 \ +# --other_learning_rate=1e-3 \ +# --num_train_epochs=4.0 \ +# --warmup_proportion=0.1 \ +# --rdrop_alpha=0.1 \ +# --do_fgm --fgm_epsilon=1.0 \ +# --seed=42 \ +# --fp16 +# done + +# for k in 0 1 2 3 4 +# do +# python run_span.py \ +# --version=nezha-rdrop0.1-vat0.1-fgm1.0-fp16-fold${k} \ +# --data_dir=./data/ner-ctx0-5fold-seed42/ \ +# --train_file=train.${k}.json \ +# --dev_file=dev.${k}.json \ +# --test_file=dev.${k}.json \ +# --model_type=nezha_span \ +# --model_name_or_path=/home/louishsu/NewDisk/Garage/weights/transformers/nezha-cn-base/ \ +# --do_train \ +# --overwrite_output_dir \ +# --evaluate_during_training \ +# --evaluate_each_epoch \ +# --save_best_checkpoints \ +# --max_span_length=40 \ +# --width_embedding_dim=128 \ +# --train_max_seq_length=512 \ +# --eval_max_seq_length=512 \ +# --do_lower_case \ +# --per_gpu_train_batch_size=6 \ +# --per_gpu_eval_batch_size=12 \ +# --gradient_accumulation_steps=2 \ +# --learning_rate=5e-5 \ +# --other_learning_rate=1e-3 \ +# --num_train_epochs=4.0 \ +# --warmup_proportion=0.1 \ +# --rdrop_alpha=0.1 \ +# --do_fgm --fgm_epsilon=1.0 \ +# --vat_alpha=0.1 \ +# --seed=42 \ +# --fp16 +# done \ No newline at end of file