
Commit

2021/9/18  1. Added VAT and changed the compute_kl_loss implementation; the results are not good. 2. Added LSR; FOCAL is still a TODO. 3. Whether removing rdrop helps remains to be tried.
louishsu committed Sep 17, 2021
1 parent c328067 commit 04c371c
Showing 2 changed files with 289 additions and 16 deletions.
136 changes: 121 additions & 15 deletions run_span.py
@@ -71,6 +71,74 @@ def batched_index_select(input, index):
    output = torch.bmm(index_onehot, input)
    return output


class LabelSmoothingCE(nn.Module):

    def __init__(self, eps=0.1, reduction='mean', ignore_index=-100):
        super().__init__()

        self.eps = eps
        self.reduction = reduction
        self.ignore_index = ignore_index

    def forward(self, input, target):
        c = input.size()[-1]
        log_preds = F.log_softmax(input, dim=-1)
        if self.reduction == 'sum':
            loss = -log_preds.sum()
        else:
            loss = -log_preds.sum(dim=-1)
            if self.reduction == 'mean':
                loss = loss.mean()
        loss_1 = loss * self.eps / c
        loss_2 = F.nll_loss(log_preds, target, reduction=self.reduction, ignore_index=self.ignore_index)
        return loss_1 + (1 - self.eps) * loss_2
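
As a quick sanity check (a sketch, not part of this commit): with eps=0 the uniform term loss_1 vanishes and the class should reduce to plain cross-entropy; with eps>0 the target distribution mixes (1 - eps) on the gold label with eps/num_labels on every class.

# Sanity-check sketch (not part of the diff); assumes LabelSmoothingCE above is importable.
import torch
import torch.nn as nn

logits = torch.randn(4, 9)             # 4 spans, 9 labels (illustrative sizes)
labels = torch.tensor([0, 3, 5, 8])
ce = nn.CrossEntropyLoss()(logits, labels)
lsr0 = LabelSmoothingCE(eps=0.0)(logits, labels)
assert torch.allclose(ce, lsr0)        # eps=0 recovers ordinary cross-entropy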


class FocalLoss(nn.Module):
    """
    Softmax and sigmoid focal loss
    """

    def __init__(self, num_labels, activation_type='softmax', reduction='mean',
                 gamma=2.0, alpha=0.25, epsilon=1.e-9):

        super(FocalLoss, self).__init__()
        self.num_labels = num_labels
        self.gamma = gamma
        self.alpha = alpha
        self.epsilon = epsilon
        self.activation_type = activation_type
        self.reduction = reduction

    def forward(self, input, target):
        """
        Args:
            input: model output logits, shape of [batch_size, num_cls]
            target: ground truth labels, shape of [batch_size]
        Returns:
            shape of [batch_size]
        """
        if self.activation_type == 'softmax':
            idx = target.view(-1, 1).long()
            one_hot_key = torch.zeros(idx.size(0), self.num_labels, dtype=torch.float32, device=idx.device)
            one_hot_key = one_hot_key.scatter_(1, idx, 1)
            logits = F.softmax(input, dim=-1)
            loss = -self.alpha * one_hot_key * torch.pow((1 - logits), self.gamma) * (logits + self.epsilon).log()
            loss = loss.sum(1)
        elif self.activation_type == 'sigmoid':
            multi_hot_key = target
            logits = F.sigmoid(input)
            zero_hot_key = 1 - multi_hot_key
            loss = -self.alpha * multi_hot_key * torch.pow((1 - logits), self.gamma) * (logits + self.epsilon).log()
            loss += -(1 - self.alpha) * zero_hot_key * torch.pow(logits, self.gamma) * (1 - logits + self.epsilon).log()
        if self.reduction == "mean":
            loss = loss.mean()
        elif self.reduction == "none":
            pass
        return loss
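
For intuition (a hypothetical usage sketch, not from the repository): the (1 - p)^gamma factor in the softmax branch shrinks the contribution of already well-classified spans, so hard or rare labels dominate the gradient. The shapes below assume the span logits are flattened before the call; the tensor sizes are illustrative.

# Hypothetical usage sketch for the FocalLoss class above.
import torch

num_labels = 9
focal = FocalLoss(num_labels=num_labels, activation_type='softmax',
                  reduction='none', gamma=2.0, alpha=0.25)
logits = torch.randn(2, 5, num_labels)            # (batch, num_spans, num_labels)
labels = torch.randint(0, num_labels, (2, 5))     # one gold label per span
loss = focal(logits.view(-1, num_labels), labels.view(-1))
print(loss.shape)                                 # torch.Size([10]), one value per span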


class SpanV2(nn.Module):

    def __init__(self, hidden_size, num_labels, max_span_length, width_embedding_dim):
@@ -127,7 +195,14 @@ class SpanV2Loss(nn.Module):

    def __init__(self):
        super().__init__()
        self.loss_fct = nn.CrossEntropyLoss(reduction='none')
        self.loss_fct = None
        if args.loss_type == "ce":
            self.loss_fct = nn.CrossEntropyLoss(reduction='none')
        elif args.loss_type == "lsr":
            self.loss_fct = LabelSmoothingCE(eps=args.label_smooth_eps, reduction='none')
        elif args.loss_type == "focal":
            self.loss_fct = FocalLoss(num_labels=..., reduction='none',
                                      gamma=args.focal_gamma, alpha=args.focal_alpha)  # TODO:

    def forward(self,
                logits=None,  # (batch_size, num_spans, num_labels)
@@ -192,21 +267,42 @@ def forward(
        attentions=outputs.attentions,
    )

def compute_kl_loss(p, q, pad_mask=None):
# def compute_kl_loss(p, q, pad_mask=None):

#     batch_size, num_spans, num_labels = p.size()
#     if pad_mask is None:
#         pad_mask = torch.ones(batch_size, num_spans, dtype=torch.bool, device=p.device)
#     pad_mask = pad_mask.unsqueeze(-1).expand(batch_size, num_spans, num_labels)

    p_loss = F.kl_div(F.log_softmax(p, dim=-1), F.softmax(q, dim=-1), reduction='none')
    q_loss = F.kl_div(F.log_softmax(q, dim=-1), F.softmax(p, dim=-1), reduction='none')
#     p_loss = F.kl_div(F.log_softmax(p, dim=-1), F.softmax(q, dim=-1), reduction='none')
#     q_loss = F.kl_div(F.log_softmax(q, dim=-1), F.softmax(p, dim=-1), reduction='none')

    # pad_mask is for seq-level tasks
    if pad_mask is not None:
        p_loss.masked_fill_(pad_mask, 0.)
        q_loss.masked_fill_(pad_mask, 0.)
#     # pad_mask is for seq-level tasks
#     p_loss.masked_fill_(pad_mask, 0.)
#     q_loss.masked_fill_(pad_mask, 0.)

#     # You can choose whether to use function "sum" and "mean" depending on your task
#     p_loss = p_loss.mean()
#     q_loss = q_loss.mean()

#     loss = (p_loss + q_loss) / 2
#     return loss

def compute_kl_loss(p, q, pad_mask=None):

    # You can choose whether to use function "sum" and "mean" depending on your task
    p_loss = p_loss.mean()
    q_loss = q_loss.mean()
    batch_size, num_spans, num_labels = p.size()
    if pad_mask is None:
        # pad_mask marks padded spans with True, so None means no span is padded
        pad_mask = torch.zeros(batch_size, num_spans, dtype=torch.bool, device=p.device)
    pad_mask = pad_mask.unsqueeze(-1).expand(batch_size, num_spans, num_labels)

    p_loss = F.kl_div(F.log_softmax(p, dim=-1), F.softmax(q, dim=-1), reduction='none')
    q_loss = F.kl_div(F.log_softmax(q, dim=-1), F.softmax(p, dim=-1), reduction='none')

    mask_valid = ~pad_mask
    p_loss = p_loss[mask_valid].mean()
    q_loss = q_loss[mask_valid].mean()
    loss = (p_loss + q_loss) / 2

    return loss
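
A small check of the symmetric KL above (a sketch, not part of the commit; it assumes pad_mask marks padded spans with True, as in the callers that pass span_mask == 0): identical logits should give zero loss, and a masked-out span should not contribute.

# Sketch (not in the diff): identical distributions give zero symmetric KL,
# and masking a perturbed span out leaves the result at zero.
import torch

p = torch.randn(2, 4, 9)                           # (batch, num_spans, num_labels)
pad = torch.zeros(2, 4, dtype=torch.bool)          # True = padded span
print(compute_kl_loss(p, p.clone(), pad_mask=pad)) # tensor(0.)

q = p.clone()
q[0, 0] += 1.0                                     # perturb one span...
pad[0, 0] = True                                   # ...then mask that span out
print(compute_kl_loss(p, q, pad_mask=pad))         # still tensor(0.)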

def forward_rdrop(cls, alpha, **kwargs):
@@ -216,7 +312,7 @@ def forward_rdrop(cls, alpha, **kwargs):
    outputs2 = forward(cls, **kwargs)
    rdrop_loss = compute_kl_loss(
        outputs1["logits"], outputs2["logits"],
        kwargs["span_mask"].unsqueeze(-1) == 0)
        kwargs["span_mask"] == 0)
    total_loss = (outputs1["loss"] + outputs2["loss"]) / 2. + alpha * rdrop_loss
    return TokenClassifierOutput(
        loss=total_loss,
Expand Down Expand Up @@ -301,12 +397,17 @@ def build_arguments(self):
self.add_argument("--augment_context_aware_p", default=None, type=float)
self.add_argument("--augment_entity_replace_p", default=None, type=float)
self.add_argument("--rdrop_alpha", default=None, type=float)
self.add_argument("--vat_alpha", default=None, type=float)

# Other parameters
self.add_argument('--scheme', default='IOB2', type=str,
choices=['IOB2', 'IOBES'])
self.add_argument('--loss_type', default='ce', type=str,
choices=['lsr', 'focal', 'ce'])
self.add_argument('--label_smooth_eps', default=0.1, type=float)
self.add_argument('--focal_gamma', default=2.0, type=float)
self.add_argument('--focal_alpha', default=0.25, type=float)

self.add_argument("--config_name", default="", type=str,
help="Pretrained config name or path if not the same as model_name")
self.add_argument("--tokenizer_name", default="", type=str,
@@ -802,12 +903,17 @@ def train(args, model, processor, tokenizer):
                loss = loss / args.gradient_accumulation_steps
            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
                    scaled_loss.backward(retain_graph=args.vat_alpha is not None)
            else:
                loss.backward()
                loss.backward(retain_graph=args.vat_alpha is not None)
            if args.do_fgm:
                fgm.attack()
                loss_adv = model(**batch)[0]
                outputs_adv = model(**batch)
                loss_adv = outputs_adv[0]
                if args.vat_alpha is not None:
                    loss_vat = compute_kl_loss(outputs["logits"], outputs_adv["logits"],
                                               pad_mask=batch["span_mask"] == 0)
                    loss_adv = loss_adv + args.vat_alpha * loss_vat
                if args.n_gpu > 1:
                    loss_adv = loss_adv.mean()
                if args.gradient_accumulation_steps > 1:
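The adversarial step above calls fgm.attack() before the second forward pass, but the FGM helper itself is not part of this diff. Below is a minimal sketch of the usual FGM recipe (perturb the word-embedding weights by epsilon * grad / ||grad||, then restore them before the optimizer step); the class and the emb_name argument are illustrative, not taken from this commit.

# Minimal FGM sketch (assumption: the repository's FGM class follows the common recipe).
import torch

class FGM:
    def __init__(self, model, epsilon=1.0, emb_name="word_embeddings"):
        self.model = model
        self.epsilon = epsilon
        self.emb_name = emb_name
        self.backup = {}

    def attack(self):
        # Step the embedding weights along the (normalized) gradient direction.
        for name, param in self.model.named_parameters():
            if param.requires_grad and self.emb_name in name and param.grad is not None:
                self.backup[name] = param.data.clone()
                norm = torch.norm(param.grad)
                if norm != 0 and not torch.isnan(norm):
                    param.data.add_(self.epsilon * param.grad / norm)

    def restore(self):
        # Undo the perturbation before the optimizer step.
        for name, param in self.model.named_parameters():
            if name in self.backup:
                param.data = self.backup[name]
        self.backup = {}

Under this reading, the VAT term added in the hunk above is the symmetric KL between the clean logits (outputs) and the perturbed logits (outputs_adv), scaled by --vat_alpha and added to the adversarial loss; the retain_graph flag on the clean backward is needed because the clean logits are reused inside that KL term.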
169 changes: 168 additions & 1 deletion scripts/run_span.sh
@@ -250,6 +250,108 @@ done
# 组织机构 (organization)
# {'p': 0.9203860072376358, 'r': 0.9466501240694789, 'f': 0.9333333333333333}

# TODO: remove rdrop
for k in 0 1 2 3 4
do
python run_span.py \
--version=nezha-fgm1.0-fold${k} \
--data_dir=./data/ner-ctx0-5fold-seed42/ \
--train_file=train.${k}.json \
--dev_file=dev.${k}.json \
--test_file=dev.${k}.json \
--model_type=nezha_span \
--model_name_or_path=/home/louishsu/NewDisk/Garage/weights/transformers/nezha-cn-base/ \
--do_train \
--overwrite_output_dir \
--evaluate_during_training \
--evaluate_each_epoch \
--save_best_checkpoints \
--max_span_length=40 \
--width_embedding_dim=128 \
--train_max_seq_length=512 \
--eval_max_seq_length=512 \
--do_lower_case \
--per_gpu_train_batch_size=8 \
--per_gpu_eval_batch_size=16 \
--gradient_accumulation_steps=2 \
--learning_rate=5e-5 \
--other_learning_rate=1e-3 \
--num_train_epochs=4.0 \
--warmup_proportion=0.1 \
--do_fgm --fgm_epsilon=1.0 \
--seed=42
done

# TODO: rdrop undecided; label smoothing 0.1
for k in 0 1 2 3 4
do
python run_span.py \
--version=nezha-fgm1.0-lsr0.1-fold${k} \
--data_dir=./data/ner-ctx0-5fold-seed42/ \
--train_file=train.${k}.json \
--dev_file=dev.${k}.json \
--test_file=dev.${k}.json \
--model_type=nezha_span \
--model_name_or_path=/home/louishsu/NewDisk/Garage/weights/transformers/nezha-cn-base/ \
--do_train \
--overwrite_output_dir \
--evaluate_during_training \
--evaluate_each_epoch \
--save_best_checkpoints \
--max_span_length=40 \
--width_embedding_dim=128 \
--train_max_seq_length=512 \
--eval_max_seq_length=512 \
--do_lower_case \
--per_gpu_train_batch_size=8 \
--per_gpu_eval_batch_size=16 \
--gradient_accumulation_steps=2 \
--learning_rate=5e-5 \
--other_learning_rate=1e-3 \
--num_train_epochs=4.0 \
--warmup_proportion=0.1 \
--do_fgm --fgm_epsilon=1.0 \
--loss_type=lsr --label_smooth_eps=0.1 \
--seed=42
done

# TODO: focal
for k in 0 1 2 3 4
do
python run_span.py \
--version=nezha-fgm1.0-focalg2.0a0.25-fold${k} \
--data_dir=./data/ner-ctx0-5fold-seed42/ \
--train_file=train.${k}.json \
--dev_file=dev.${k}.json \
--test_file=dev.${k}.json \
--model_type=nezha_span \
--model_name_or_path=/home/louishsu/NewDisk/Garage/weights/transformers/nezha-cn-base/ \
--do_train \
--overwrite_output_dir \
--evaluate_during_training \
--evaluate_each_epoch \
--save_best_checkpoints \
--max_span_length=40 \
--width_embedding_dim=128 \
--train_max_seq_length=512 \
--eval_max_seq_length=512 \
--do_lower_case \
--per_gpu_train_batch_size=8 \
--per_gpu_eval_batch_size=16 \
--gradient_accumulation_steps=2 \
--learning_rate=5e-5 \
--other_learning_rate=1e-3 \
--num_train_epochs=4.0 \
--warmup_proportion=0.1 \
--do_fgm --fgm_epsilon=1.0 \
--loss_type=focal --focal_gamma=2.0 --focal_alpha=0.25 \
--seed=42
done
# <<< Stage 2 <<<



# ==================================================================================================================
# for k in 0 1 2 3 4
# do
# python run_span.py \
Expand Down Expand Up @@ -284,5 +386,70 @@ done
# ./data/ner-ctx0-5fold-seed42/dev.gt.${k}.json \
# output/ner-cail_ner-bert_span-legal_electra_base-fold${k}-42/test_prediction.json
# done
# <<< Stage 2 <<<

# for k in 0 1 2 3 4
# do
# python run_span.py \
# --version=nezha-rdrop0.1-fgm1.0-fp16-fold${k} \
# --data_dir=./data/ner-ctx0-5fold-seed42/ \
# --train_file=train.${k}.json \
# --dev_file=dev.${k}.json \
# --test_file=dev.${k}.json \
# --model_type=nezha_span \
# --model_name_or_path=/home/louishsu/NewDisk/Garage/weights/transformers/nezha-cn-base/ \
# --do_train \
# --overwrite_output_dir \
# --evaluate_during_training \
# --evaluate_each_epoch \
# --save_best_checkpoints \
# --max_span_length=40 \
# --width_embedding_dim=128 \
# --train_max_seq_length=512 \
# --eval_max_seq_length=512 \
# --do_lower_case \
# --per_gpu_train_batch_size=6 \
# --per_gpu_eval_batch_size=12 \
# --gradient_accumulation_steps=2 \
# --learning_rate=5e-5 \
# --other_learning_rate=1e-3 \
# --num_train_epochs=4.0 \
# --warmup_proportion=0.1 \
# --rdrop_alpha=0.1 \
# --do_fgm --fgm_epsilon=1.0 \
# --seed=42 \
# --fp16
# done

# for k in 0 1 2 3 4
# do
# python run_span.py \
# --version=nezha-rdrop0.1-vat0.1-fgm1.0-fp16-fold${k} \
# --data_dir=./data/ner-ctx0-5fold-seed42/ \
# --train_file=train.${k}.json \
# --dev_file=dev.${k}.json \
# --test_file=dev.${k}.json \
# --model_type=nezha_span \
# --model_name_or_path=/home/louishsu/NewDisk/Garage/weights/transformers/nezha-cn-base/ \
# --do_train \
# --overwrite_output_dir \
# --evaluate_during_training \
# --evaluate_each_epoch \
# --save_best_checkpoints \
# --max_span_length=40 \
# --width_embedding_dim=128 \
# --train_max_seq_length=512 \
# --eval_max_seq_length=512 \
# --do_lower_case \
# --per_gpu_train_batch_size=6 \
# --per_gpu_eval_batch_size=12 \
# --gradient_accumulation_steps=2 \
# --learning_rate=5e-5 \
# --other_learning_rate=1e-3 \
# --num_train_epochs=4.0 \
# --warmup_proportion=0.1 \
# --rdrop_alpha=0.1 \
# --do_fgm --fgm_epsilon=1.0 \
# --vat_alpha=0.1 \
# --seed=42 \
# --fp16
# done
