2021/9/18 nezha-fgm1.0: removed rdrop; offline f=0.89, somewhat worse
louishsu committed Sep 18, 2021
1 parent 04c371c commit b4d6f83
Showing 4 changed files with 74 additions and 42 deletions.
2 changes: 1 addition & 1 deletion evaluate.py
@@ -135,7 +135,6 @@ def analyze_error(ground_truth_path, output_path):
y_pred = [prediction_positions_label_map[k] for k in itersection_positions]
labels = list(LABEL_MEANING_MAP.values())
cm = confusion_matrix(y_true, y_pred, labels=labels)
disp = ConfusionMatrixDisplay(cm, display_labels=labels).plot()
label_error_type_count_map = Counter(["-".join(e.split("-")[-2:]) for e in label_error])
# save
for list_name in ["ground_truth_entities", "prediction_entities",
f.write(line + "\n")
print(label_error_type_count_map)
print(labels)
disp = ConfusionMatrixDisplay(cm, display_labels=labels).plot()
plt.savefig(os.path.join("tmp", "cm.jpg"))

if __name__ == '__main__':
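For reference, the change above moves the ConfusionMatrixDisplay call so the figure is created immediately before plt.savefig, ensuring the saved image is the confusion matrix rather than an earlier figure. A minimal standalone sketch of that pattern (toy labels and output path, not the repository's data):

import matplotlib.pyplot as plt
from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix

y_true = ["time", "location", "time"]   # toy ground-truth labels
y_pred = ["time", "time", "time"]       # toy predictions
labels = ["time", "location"]

cm = confusion_matrix(y_true, y_pred, labels=labels)
# Build and plot the display right before saving, so savefig captures this figure.
ConfusionMatrixDisplay(cm, display_labels=labels).plot()
plt.savefig("cm.jpg")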
6 changes: 6 additions & 0 deletions main_local.py
@@ -35,6 +35,12 @@ def main():
dataset_name = "cail_ner"
n_splits = 5
seed=42
# --------------------------
# version = "nezha-fgm1.0"
# model_type = "nezha_span"
# dataset_name = "cail_ner"
# n_splits = 5
# seed=42

test_examples = []
test_batches = []
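The n_splits = 5 and seed = 42 constants above (and the ner-ctx0-5fold-seed42 data directory used in scripts/run_span.sh below) point to a standard 5-fold split. As a hedged illustration only, fold files named train.{k}.json / dev.{k}.json could be produced with something like the sketch below; write_fold_files and the file layout are assumptions, not the repository's actual preprocessing code.

import json
from sklearn.model_selection import KFold

def write_fold_files(examples, n_splits=5, seed=42):
    # One (train, dev) pair per fold, mirroring train.0.json / dev.0.json etc.
    kf = KFold(n_splits=n_splits, shuffle=True, random_state=seed)
    for k, (train_idx, dev_idx) in enumerate(kf.split(examples)):
        for name, indices in (("train", train_idx), ("dev", dev_idx)):
            with open(f"{name}.{k}.json", "w", encoding="utf-8") as f:
                for i in indices:
                    f.write(json.dumps(examples[i], ensure_ascii=False) + "\n")

# usage (toy data): write_fold_files([{"text": f"sample {i}"} for i in range(10)])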
52 changes: 26 additions & 26 deletions run_span.py
@@ -267,44 +267,44 @@ def forward(
attentions=outputs.attentions,
)

# def compute_kl_loss(p, q, pad_mask=None):

# batch_size, num_spans, num_labels = p.size()
# if pad_mask is None:
# pad_mask = torch.ones(batch_size, num_spans, dtype=torch.bool, device=p.device)
# pad_mask = pad_mask.unsqueeze(-1).expand(batch_size, num_spans, num_labels)

# p_loss = F.kl_div(F.log_softmax(p, dim=-1), F.softmax(q, dim=-1), reduction='none')
# q_loss = F.kl_div(F.log_softmax(q, dim=-1), F.softmax(p, dim=-1), reduction='none')

# # pad_mask is for seq-level tasks
# p_loss.masked_fill_(pad_mask, 0.)
# q_loss.masked_fill_(pad_mask, 0.)

# # You can choose whether to use function "sum" and "mean" depending on your task
# p_loss = p_loss.mean()
# q_loss = q_loss.mean()

# loss = (p_loss + q_loss) / 2
# return loss

def compute_kl_loss(p, q, pad_mask=None):
    # Symmetric KL term used for R-Drop: average of KL(p||q) and KL(q||p) over spans.
    batch_size, num_spans, num_labels = p.size()
    if pad_mask is None:
        # No mask given: no span is treated as padding.
        pad_mask = torch.zeros(batch_size, num_spans, dtype=torch.bool, device=p.device)
    pad_mask = pad_mask.unsqueeze(-1).expand(batch_size, num_spans, num_labels)

    p_loss = F.kl_div(F.log_softmax(p, dim=-1), F.softmax(q, dim=-1), reduction='none')
    q_loss = F.kl_div(F.log_softmax(q, dim=-1), F.softmax(p, dim=-1), reduction='none')

    # pad_mask is for seq-level tasks: True marks padded spans, which are zeroed out.
    p_loss.masked_fill_(pad_mask, 0.)
    q_loss.masked_fill_(pad_mask, 0.)

    # You can choose "sum" or "mean" depending on your task.
    p_loss = p_loss.mean()
    q_loss = q_loss.mean()

    loss = (p_loss + q_loss) / 2
    return loss

# def compute_kl_loss(p, q, pad_mask=None):

# batch_size, num_spans, num_labels = p.size()
# if pad_mask is None:
# pad_mask = torch.ones(batch_size, num_spans, dtype=torch.bool, device=p.device)
# pad_mask = pad_mask.unsqueeze(-1).expand(batch_size, num_spans, num_labels)

# p_loss = F.kl_div(F.log_softmax(p, dim=-1), F.softmax(q, dim=-1), reduction='none')
# q_loss = F.kl_div(F.log_softmax(q, dim=-1), F.softmax(p, dim=-1), reduction='none')

# mask_valid = ~pad_mask
# p_loss = p_loss[mask_valid].mean()
# q_loss = q_loss[mask_valid].mean()
# loss = (p_loss + q_loss) / 2

# return loss

def forward_rdrop(cls, alpha, **kwargs):
outputs1 = forward(cls, **kwargs)
if outputs1.loss is None or alpha <= 0.: return outputs1
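forward_rdrop is truncated in this view. As a rough sketch only (the wrapper name, the HF-style .loss/.logits attributes, and the exact way the losses are combined are assumptions, not the repository's implementation), an R-Drop style objective typically runs two stochastic forward passes over the same batch and adds the symmetric KL term from compute_kl_loss above, scaled by alpha:

def rdrop_objective(model, batch, alpha, pad_mask=None):
    # model must be in train() mode so dropout makes the two passes differ
    out1 = model(**batch)
    out2 = model(**batch)
    task_loss = (out1.loss + out2.loss) / 2
    # symmetric KL between the two predictive distributions, see compute_kl_loss above
    kl = compute_kl_loss(out1.logits, out2.logits, pad_mask=pad_mask)
    return task_loss + alpha * kl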
56 changes: 41 additions & 15 deletions scripts/run_span.sh
@@ -172,28 +172,29 @@ python run_span.py \
--do_fgm --fgm_epsilon=1.0 \
--seed=42
done
# main_local
# avg
# {'p': 0.9452107279693487, 'r': 0.9515911282545805, 'f': 0.9483901970206632}
# {'p': 0.8833097595473833, 'r': 0.8901016465999024, 'f': 0.8866926971434976}
# 犯罪嫌疑人 (criminal suspect)
# {'p': 0.9821573398215734, 'r': 0.9813614262560778, 'f': 0.9817592217267938}
# {'p': 0.951840490797546, 'r': 0.9602351848986539, 'f': 0.9560194099976893}
# 受害人 (victim)
# {'p': 0.93026941362916, 'r': 0.9882154882154882, 'f': 0.9583673469387756}
# {'p': 0.9128682170542636, 'r': 0.9472329472329473, 'f': 0.9297331438496764}
# 被盗货币 (stolen currency)
# {'p': 0.8461538461538461, 'r': 0.953757225433526, 'f': 0.8967391304347825}
# {'p': 0.8006134969325154, 'r': 0.8557377049180328, 'f': 0.8272583201267829}
# 物品价值 (item value)
# {'p': 0.9901960784313726, 'r': 0.9901960784313726, 'f': 0.9901960784313726}
# {'p': 0.967896502156205, 'r': 0.9665071770334929, 'f': 0.9672013406751256}
# 盗窃获利 (theft proceeds)
# {'p': 0.9652173913043478, 'r': 0.940677966101695, 'f': 0.9527896995708155}
# {'p': 0.8346007604562737, 'r': 0.9126819126819127, 'f': 0.8718967229394241}
# 被盗物品 (stolen items)
# {'p': 0.9335106382978723, 'r': 0.9196506550218341, 'f': 0.9265288165420149}
# {'p': 0.7937117903930131, 'r': 0.78602317938073, 'f': 0.7898487745524074}
# 作案工具 (crime instrument)
# {'p': 0.8472222222222222, 'r': 0.9606299212598425, 'f': 0.9003690036900368}
# {'p': 0.7383059418457648, 'r': 0.7945578231292517, 'f': 0.765399737876802}
# 时间 (time)
# {'p': 0.9455252918287937, 'r': 0.9257142857142857, 'f': 0.9355149181905678}
# {'p': 0.9410688140556369, 'r': 0.9298372513562387, 'f': 0.9354193196288886}
# 地点 (location)
# {'p': 0.9415121255349501, 'r': 0.9179415855354659, 'f': 0.9295774647887323}
# {'p': 0.8487467588591184, 'r': 0.8376457207847597, 'f': 0.8431597023468804}
# 组织机构 (organization)
# {'p': 0.8940397350993378, 'r': 0.9507042253521126, 'f': 0.9215017064846417}
# {'p': 0.8557336621454994, 'r': 0.8610421836228288, 'f': 0.8583797155225728}
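The --do_fgm --fgm_epsilon=1.0 flags in these runs enable FGM adversarial training on the embedding layer. A generic, hedged sketch of the usual FGM helper follows (the class, the emb_name default, and the training-loop comment are the standard recipe, not necessarily this repository's exact code):

import torch

class FGM:
    """Fast Gradient Method: perturb embedding weights along the loss gradient."""
    def __init__(self, model, epsilon=1.0, emb_name="word_embeddings"):
        self.model = model
        self.epsilon = epsilon
        self.emb_name = emb_name
        self.backup = {}

    def attack(self):
        for name, param in self.model.named_parameters():
            if param.requires_grad and self.emb_name in name and param.grad is not None:
                self.backup[name] = param.data.clone()
                norm = torch.norm(param.grad)
                if norm != 0 and not torch.isnan(norm):
                    param.data.add_(self.epsilon * param.grad / norm)

    def restore(self):
        for name, param in self.model.named_parameters():
            if name in self.backup:
                param.data = self.backup[name]
        self.backup = {}

# typical loop: loss.backward(); fgm.attack(); adv_loss.backward(); fgm.restore(); optimizer.step()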

for k in 0 1 2 3 4
do
@@ -250,7 +251,7 @@ done
# 组织机构 (organization)
# {'p': 0.9203860072376358, 'r': 0.9466501240694789, 'f': 0.9333333333333333}

# TODO: remove rdrop
# rdrop removed
for k in 0 1 2 3 4
do
python run_span.py \
Expand Down Expand Up @@ -281,12 +282,35 @@ python run_span.py \
--do_fgm --fgm_epsilon=1.0 \
--seed=42
done
# main_local
# avg
# {'p': 0.8900258630383447, 'r': 0.8906267581861146, 'f': 0.890326209223847}
# 犯罪嫌疑人 (criminal suspect)
# {'p': 0.9531777709548664, 'r': 0.9606993656196813, 'f': 0.9569237882407337}
# 受害人 (victim)
# {'p': 0.9161769254562326, 'r': 0.953024453024453, 'f': 0.9342375019712978}
# 被盗货币 (stolen currency)
# {'p': 0.8041666666666667, 'r': 0.8437158469945355, 'f': 0.8234666666666667}
# 物品价值 (item value)
# {'p': 0.9702352376380221, 'r': 0.9669856459330144, 'f': 0.9686077162712677}
# 盗窃获利 (theft proceeds)
# {'p': 0.8529980657640233, 'r': 0.9168399168399168, 'f': 0.8837675350701403}
# 被盗物品 (stolen items)
# {'p': 0.8070051300194587, 'r': 0.7891368275384881, 'f': 0.7979709637921986}
# 作案工具 (crime instrument)
# {'p': 0.7707774798927614, 'r': 0.782312925170068, 'f': 0.7765023632680621}
# 时间 (time)
# {'p': 0.9405504587155963, 'r': 0.9269439421338156, 'f': 0.9336976320582878}
# 地点 (location)
# {'p': 0.8563938246431693, 'r': 0.8359397213534262, 'f': 0.8460431654676259}
# 组织机构 (organization)
# {'p': 0.8588957055214724, 'r': 0.8684863523573201, 'f': 0.8636644046884641}

# TODO: rdrop undecided, label smooth 0.1
# TODO: label smooth 0.1
for k in 0 1 2 3 4
do
python run_span.py \
--version=nezha-fgm1.0-lsr0.1-fold${k} \
--version=nezha-rdrop0.1-fgm1.0-lsr0.1-fold${k} \
--data_dir=./data/ner-ctx0-5fold-seed42/ \
--train_file=train.${k}.json \
--dev_file=dev.${k}.json \
--other_learning_rate=1e-3 \
--num_train_epochs=4.0 \
--warmup_proportion=0.1 \
--rdrop_alpha=0.1 \
--do_fgm --fgm_epsilon=1.0 \
--loss_type=lsr --label_smooth_eps=0.1 \
--seed=42
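The --loss_type=lsr --label_smooth_eps=0.1 run above swaps in label-smoothed cross entropy. A minimal hedged sketch of that loss in its generic form (not necessarily the repository's exact implementation; logits is (N, C), target is (N,)):

import torch.nn.functional as F

def label_smoothing_ce(logits, target, eps=0.1):
    # Mix the one-hot target with a uniform distribution over all classes.
    log_probs = F.log_softmax(logits, dim=-1)
    nll = -log_probs.gather(dim=-1, index=target.unsqueeze(-1)).squeeze(-1)
    uniform = -log_probs.mean(dim=-1)
    return ((1.0 - eps) * nll + eps * uniform).mean()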
for k in 0 1 2 3 4
do
python run_span.py \
--version=nezha-fgm1.0-focalg2.0a0.25-fold${k} \
--version=nezha-rdrop0.1-fgm1.0-focalg2.0a0.25-fold${k} \
--data_dir=./data/ner-ctx0-5fold-seed42/ \
--train_file=train.${k}.json \
--dev_file=dev.${k}.json \
--other_learning_rate=1e-3 \
--num_train_epochs=4.0 \
--warmup_proportion=0.1 \
--rdrop_alpha=0.1 \
--do_fgm --fgm_epsilon=1.0 \
--loss_type=focal --focal_gamma=2.0 --focal_alpha=0.25 \
--seed=42
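The --loss_type=focal --focal_gamma=2.0 --focal_alpha=0.25 run switches to focal loss. A minimal hedged sketch of a multi-class focal loss with those defaults (a generic formulation, not the repository's exact code):

import torch
import torch.nn.functional as F

def focal_loss(logits, target, gamma=2.0, alpha=0.25):
    # Down-weight easy examples: scale the per-sample CE by alpha * (1 - p_t) ** gamma.
    log_probs = F.log_softmax(logits, dim=-1)
    ce = F.nll_loss(log_probs, target, reduction="none")
    p_t = torch.exp(-ce)  # model probability of the true class
    return (alpha * (1.0 - p_t) ** gamma * ce).mean()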
