Commit f5026e9
2021/10/1 pseudo label
louishsu committed Oct 1, 2021
1 parent 3d0c047 commit f5026e9
Showing 2 changed files with 25 additions and 21 deletions.
33 changes: 19 additions & 14 deletions run_span.py
@@ -217,15 +217,20 @@ def forward(self,
 if args.do_pseudo:
     proba = logits.softmax(dim=-1)
     proba, index = proba.max(dim=-1)
-    condition = (
-        label > args.pseudo_proba_ub
-    ) | (
-        label < args.pseudo_proba_lb
-    ) & (
-        label == PSEUDO_TAG
-    )
-    label = torch.where(condition, index, label)
-    loss_mask = loss_mask & (label.view(-1) != PSEUDO_TAG)
+
+    is_pseudo = label == PSEUDO_TAG
+    label = torch.where(is_pseudo, index, label)  # replace unlabeled tokens with predicted labels
+    pseudo_valid_mask = is_pseudo & (
+        proba > args.pseudo_proba_thresh
+    )  # valid pseudo labels: pseudo-labeled and above the confidence threshold
+    # pseudo_valid_mask = is_pseudo & (
+    #     proba > args.pseudo_proba_thresh
+    # ) & (
+    #     index != 0
+    # )  # valid pseudo labels: pseudo-labeled, above the threshold, and an entity
+    loss_mask = (mask == 1) & (~is_pseudo)  # re-initialize loss_mask from the gold labels
+    loss_mask = loss_mask | pseudo_valid_mask  # merge gold labels and valid pseudo labels
+    loss_mask = loss_mask.view(-1)

 loss = self.loss_fct(logits.view(-1, num_labels), label.view(-1))
 loss = loss[loss_mask].mean()
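As an aside, here is a minimal self-contained sketch of the thresholded pseudo-label masking introduced above. The tensor shapes, the PSEUDO_TAG sentinel value, and the data are illustrative assumptions, not the repository's actual values:

import torch
import torch.nn.functional as F

PSEUDO_TAG = -1            # assumed sentinel marking unlabeled tokens
pseudo_proba_thresh = 0.9  # confidence threshold, as in --pseudo_proba_thresh

num_labels = 5
logits = torch.randn(2, 4, num_labels)        # [batch, seq_len, num_labels]
label = torch.randint(0, num_labels, (2, 4))  # gold tags
label[0, :2] = PSEUDO_TAG                     # pretend two tokens are unlabeled
mask = torch.ones(2, 4, dtype=torch.long)     # attention mask (1 = real token)

proba = logits.softmax(dim=-1)
proba, index = proba.max(dim=-1)              # confidence and predicted tag

is_pseudo = label == PSEUDO_TAG
label = torch.where(is_pseudo, index, label)                # fill in predictions
pseudo_valid_mask = is_pseudo & (proba > pseudo_proba_thresh)
loss_mask = ((mask == 1) & ~is_pseudo) | pseudo_valid_mask  # gold + confident pseudo
loss_mask = loss_mask.view(-1)

loss = F.cross_entropy(logits.view(-1, num_labels), label.view(-1), reduction="none")
loss = loss[loss_mask].mean()  # NaN if no token survives the mask

The key difference from the replaced upper/lower-bound logic is that the mask is rebuilt from scratch: a pseudo-labeled token contributes to the loss only when the model's own prediction clears the single confidence threshold.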
@@ -474,8 +479,7 @@ def build_arguments(self):
self.add_argument("--pseudo_data_dir", default=None, type=str)
self.add_argument("--pseudo_data_file", default=None, type=str)
self.add_argument("--pseudo_num_sample", default=None, type=int)
self.add_argument("--pseudo_proba_ub", default=0.99, type=float)
self.add_argument("--pseudo_proba_lb", default=0.01, type=float)
self.add_argument("--pseudo_proba_thresh", default=0.99, type=float)

# Other parameters
self.add_argument('--scheme', default='IOB2', type=str,
@@ -606,6 +610,10 @@ def _create_examples(self, data_dir, data_file, mode):
logger.info(f"Creating examples from {data_path}...")
with open(data_path, encoding="utf-8") as f:
lines = [json.loads(line) for line in f.readlines()]
# 无标签数据数量限制
if mode == "pseudo" and args.pseudo_num_sample is not None:
random.shuffle(lines)
lines = lines[:args.pseudo_num_sample]
logger.info(f"Totally {len(lines)} examples.")
for sentence_counter, line in enumerate(lines):
sentence = (
@@ -1417,9 +1425,6 @@ def load_dataset(args, processor, tokenizer, data_type='train'):
 elif data_type == 'pseudo':
     examples = processor.get_train_examples(args.data_dir, args.train_file)
     examples_pseudo = processor.get_pseudo_examples(args.pseudo_data_dir, args.pseudo_data_file)
-    if args.pseudo_num_sample is not None:
-        random.shuffle(examples_pseudo)
-        examples_pseudo = examples_pseudo[:args.pseudo_num_sample]
     examples.extend(examples_pseudo)
 if args.local_rank == 0 and not evaluate:
     torch.distributed.barrier()  # Make sure only the first process in distributed training processes the dataset; the others will use the cache
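The shuffle-and-truncate subsampling now runs once in _create_examples rather than in load_dataset, presumably so the cap is applied before feature conversion and caching. A small sketch of the idiom, with placeholder data and seed (run_span.py presumably seeds the RNG via --seed):

import random

random.seed(42)  # fixing the seed keeps the sampled subset reproducible
lines = [{"text": f"sentence {i}"} for i in range(5000)]  # stand-in for the JSON lines
pseudo_num_sample = 1500
random.shuffle(lines)              # uniform random permutation of the unlabeled pool...
lines = lines[:pseudo_num_sample]  # ...then keep the first pseudo_num_sample items
assert len(lines) == pseudo_num_sample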
13 changes: 6 additions & 7 deletions scripts/run_span.sh
@@ -582,7 +582,7 @@ python prepare_data.py \
 for k in 0 1 2 3 4
 do
 python run_span.py \
-    --version=nezha-legal-fgm1.0-lsr0.1-pseudo_u0.99_l0.01-fold${k} \
+    --version=nezha-legal-fgm1.0-lsr0.1-pseudo_t0.9-fold${k} \
     --data_dir=./data/ner-ctx0-5fold-seed42/ \
     --train_file=train.${k}.json \
     --dev_file=dev.${k}.json \
@@ -602,18 +602,17 @@ python run_span.py \
     --per_gpu_train_batch_size=8 \
     --per_gpu_eval_batch_size=16 \
     --gradient_accumulation_steps=2 \
-    --learning_rate=1e-6 \
-    --other_learning_rate=1e-6 \
-    --num_train_epochs=2.0 \
+    --learning_rate=1e-5 \
+    --other_learning_rate=1e-5 \
+    --num_train_epochs=1.0 \
     --warmup_proportion=0.1 \
     --do_fgm --fgm_epsilon=1.0 \
     --loss_type=lsr --label_smooth_eps=0.1 \
     --do_pseudo \
     --pseudo_data_dir=../cail_processed_data/ner-ctx0-1fold-seed42/ \
     --pseudo_data_file=train.json \
-    --pseudo_num_sample=2000 \
-    --pseudo_proba_ub=0.99 \
-    --pseudo_proba_lb=0.01 \
+    --pseudo_num_sample=1500 \
+    --pseudo_proba_thresh=0.9 \
     --seed=42
 done
