diff --git a/main_local.py b/main_local.py index 75c6447..2d96f55 100644 --- a/main_local.py +++ b/main_local.py @@ -42,7 +42,13 @@ def main(): # n_splits = 5 # seed=42 # -------------------------- - version = "nezha-rdrop0.1-fgm1.0-focalg2.0a0.25" + # version = "nezha-rdrop0.1-fgm1.0-focalg2.0a0.25" + # model_type = "nezha_span" + # dataset_name = "cail_ner" + # n_splits = 5 + # seed=42 + # -------------------------- + version = "nezha-rdrop0.1-fgm1.0-aug_ctx0.15" model_type = "nezha_span" dataset_name = "cail_ner" n_splits = 5 diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..8ba2d04 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,253 @@ +absl-py==0.13.0 +aiohttp==3.7.4.post0 +allennlp==2.6.0 +anykeystore==0.2 +apex @ file:///home/louishsu/Downloads/apex +argon2-cffi==20.1.0 +astor==0.8.1 +astunparse==1.6.3 +async-generator==1.10 +async-timeout==3.0.1 +attrs==19.3.0 +backcall==0.1.0 +backports.csv==1.0.7 +backports.functools-lru-cache==1.6.1 +backports.tempfile==1.0 +backports.weakref==1.0.post1 +beautifulsoup4==4.9.3 +bert-serving-client==1.10.0 +bert-serving-server==1.10.0 +bleach==3.1.5 +blinker==1.4 +blis==0.4.1 +boto3==1.18.18 +botocore==1.21.18 +brotlipy==0.7.0 +bs4==0.0.1 +bz2file==0.98 +cachetools==4.2.2 +catalogue==2.0.4 +certifi==2020.6.20 +cffi==1.14.3 +chardet==3.0.4 +charset-normalizer==2.0.4 +checklist==0.0.11 +cheroot==8.5.2 +CherryPy==18.6.1 +click==7.1.2 +configparser==5.0.2 +coverage==5.5 +cryptography==3.2.1 +cycler==0.10.0 +cymem==2.0.3 +Cython==0.29.24 +debugpy==1.4.1 +decorator==4.4.2 +defusedxml==0.6.0 +dill==0.3.4 +docker-pycreds==0.4.0 +docutils==0.17.1 +entrypoints==0.3 +fastprogress==0.2.3 +fasttext==0.9.2 +feedparser==6.0.8 +filelock==3.0.12 +Flask==1.1.2 +flatbuffers==1.12 +fsspec==2021.7.0 +funcsigs==1.0.2 +future==0.18.2 +gast==0.4.0 +gensim==4.0.1 +gitdb==4.0.7 +GitPython==3.1.20 +google-api-core==1.31.1 +google-auth==1.34.0 +google-auth-oauthlib==0.4.5 +google-cloud-core==1.7.2 +google-cloud-storage==1.41.1 +google-crc32c==1.1.2 +google-pasta==0.2.0 +google-resumable-media==1.3.3 +googleapis-common-protos==1.53.0 +GPUtil==1.4.0 +graphql-core==3.1.5 +grpcio==1.34.1 +h5py==3.1.0 +huggingface-hub==0.0.12 +idna==2.10 +importlib-metadata==1.6.0 +iniconfig==1.1.1 +ipykernel==6.0.3 +ipython==7.26.0 +ipython-genutils==0.2.0 +ipywidgets==7.6.3 +iso-639==0.4.5 +itsdangerous==2.0.1 +jaraco.classes==3.2.1 +jaraco.collections==3.4.0 +jaraco.functools==3.3.0 +jaraco.text==3.5.1 +jedi==0.18.0 +jieba==0.42.1 +Jinja2==3.0.1 +jmespath==0.9.5 +joblib==0.14.1 +JPype1==0.7.0 +jsonnet==0.17.0 +jsonschema==3.2.0 +jupyter==1.0.0 +jupyter-client==6.2.0 +jupyter-console==6.4.0 +jupyter-core==4.7.1 +jupyterlab-pygments==0.1.2 +jupyterlab-widgets==1.0.0 +keras==2.6.0 +Keras-Applications==1.0.8 +keras-bert==0.88.0 +keras-embed-sim==0.7.0 +keras-layer-normalization==0.14.0 +keras-multi-head==0.22.0 +keras-nightly==2.5.0.dev2021032900 +keras-pos-embd==0.11.0 +keras-position-wise-feed-forward==0.6.0 +Keras-Preprocessing==1.1.2 +keras-self-attention==0.41.0 +keras-transformer==0.33.0 +kiwisolver==1.3.1 +lightgbm==2.3.1 +lmdb==1.2.1 +lxml==4.6.3 +Markdown==3.3.4 +MarkupSafe==2.0.1 +matplotlib==3.2.1 +matplotlib-inline==0.1.2 +mistune==0.8.4 +more-itertools==8.8.0 +multidict==5.1.0 +munch==2.5.0 +murmurhash==1.0.5 +nbclient==0.5.3 +nbconvert==6.1.0 +nbformat==5.1.3 +nest-asyncio==1.5.1 +nltk==3.5 +notebook==6.4.3 +numexpr==2.7.1 +numpy==1.19.5 +oauthlib==3.1.1 +opt-einsum==3.3.0 +overrides==3.1.0 +packaging==21.0 +pandas==1.3.1 +pandocfilters==1.4.3 +parso==0.8.2 +pathtools==0.1.2 +pathy==0.6.0 +patternfork-nosql==3.6 +pbr==3.1.1 +pdfminer.six==20201018 +pexpect==4.8.0 +pickleshare==0.7.5 +Pillow==8.3.1 +pluggy==0.13.1 +portend==2.7.1 +preshed==3.0.5 +prometheus-client==0.11.0 +promise==2.3 +prompt-toolkit==3.0.19 +protobuf==3.17.3 +psutil==5.8.0 +ptyprocess==0.7.0 +py==1.10.0 +pyasn1==0.4.8 +pyasn1-modules==0.2.8 +pybind11==2.7.1 +pycosat==0.6.3 +pycparser==2.20 +pydantic==1.8.2 +pyDeprecate==0.3.1 +Pygments==2.9.0 +pyhanlp==0.1.64 +pyOpenSSL==19.1.0 +pyparsing==2.4.7 +pyrsistent==0.18.0 +PySocks==1.7.1 +pytest==6.2.4 +python-crfsuite==0.9.7 +python-dateutil==2.8.2 +python-docx==0.8.11 +pytorch-lightning==1.4.2 +pytorch-pretrained-bert==0.6.2 +pytorch-transformers==1.2.0 +pytz==2021.1 +PyYAML==5.4.1 +pyzmq==22.2.1 +qtconsole==5.1.1 +QtPy==1.9.0 +regex==2021.8.3 +requests==2.24.0 +requests-oauthlib==1.3.0 +rsa==4.7.2 +ruamel.yaml==0.15.87 +s3transfer==0.5.0 +sacremoses==0.0.45 +scikit-learn==0.24.0 +scikit-multilearn==0.2.0 +scipy==1.7.1 +Send2Trash==1.8.0 +sentencepiece==0.1.96 +sentry-sdk==1.3.1 +seqeval==1.2.2 +sgmllib3k==1.0.0 +shortuuid==1.0.1 +six==1.15.0 +sklearn==0.0 +sklearn-crfsuite==0.3.6 +smart-open==5.1.0 +smmap==4.0.0 +sortedcontainers==2.4.0 +soupsieve==2.2.1 +spacy==3.0.7 +spacy-legacy==3.0.8 +srsly==2.4.1 +subprocess32==3.5.4 +tabulate==0.8.9 +tempora==4.1.1 +tensorboard==2.6.0 +tensorboard-data-server==0.6.1 +tensorboard-plugin-wit==1.8.0 +tensorboardX==2.4 +tensorflow==2.5.0 +tensorflow-estimator==2.5.0 +termcolor==1.1.0 +terminado==0.10.1 +testpath==0.5.0 +tflearn==0.3.2 +Theano==1.0.4 +thinc==8.0.8 +threadpoolctl==2.2.0 +tokenizers==0.10.3 +toml==0.10.2 +torch==1.9.0+cu111 +torchaudio==0.9.0 +torchmetrics==0.5.0 +torchvision==0.10.0+cu111 +tornado==6.1 +tqdm==4.51.0 +traitlets==5.0.5 +transformers==4.8.2 +typer==0.3.2 +typing-extensions==3.7.4.3 +urllib3==1.25.11 +wandb==0.11.1 +wasabi==0.8.2 +wcwidth==0.2.5 +webencodings==0.5.1 +Werkzeug==2.0.1 +widgetsnbextension==3.5.1 +wrapt==1.12.1 +xgboost==1.4.2 +yarl==1.6.3 +zc.lockfile==2.0 +zipp==3.5.0 diff --git a/run_span.py b/run_span.py index 4984ed4..b50cd2e 100644 --- a/run_span.py +++ b/run_span.py @@ -580,8 +580,10 @@ def __init__(self, p): self.p = p self.augment_entity_meanings = [ - "物品价值", "被盗货币", "盗窃获利", - "受害人", "犯罪嫌疑人" + # "物品价值", "被盗货币", "盗窃获利", + # "被盗物品", "作案工具", + "受害人", "犯罪嫌疑人", + # "地点", "组织机构", ] def __call__(self, example): @@ -1064,7 +1066,7 @@ def predict_decode_batch(example, batch, id2label, post_process=True): # print() is_intersect = lambda a, b: min(a[1], b[1]) - max(a[0], b[0]) > 0 is_a_included_by_b = lambda a, b: min(a[1], b[1]) - max(a[0], b[0]) == a[1] - a[0] - is_contain_special_char = lambda x: any([c in text[x[0]: x[1]] for c in [",", ",", "、"]]) + is_contain_special_char = lambda x: any([c in text[x[0]: x[1]] for c in [",", "。", "、", ",", ".",]]) is_length_le_n = lambda x, n: x[1] - x[0] < n entities2spans = lambda entities: [(int(e.split(";")[0]), int(e.split(";")[1])) for e in entities] spans2entities = lambda spans: [f"{b};{e}" for b, e in spans] @@ -1108,8 +1110,10 @@ def merge_spans(spans, keep_type="short"): label_entities_map = {label: [] for label in LABEL_MEANING_MAP.keys()} for t, b, e in pred: label_entities_map[t].append(f"{b};{e+1}") if post_process: - # 若存在时间、地点实体重叠,则保留较长的 - for meaning in ["时间", "地点"]: + # 若存在以下实体重叠,则保留较长的 + for meaning in [ + "时间", "地点", + ]: label = MEANING_LABEL_MAP[meaning] entities = label_entities_map[label] # 左闭右开 if entities: @@ -1141,10 +1145,51 @@ def merge_spans(spans, keep_type="short"): is_todel[j] = True spans = [span for flag, span in zip(is_todel, spans) if not flag] # <<< 姓名处理 <<< + # # TODO: >>> 地点处理 >>> + # entities_name = label_entities_map[MEANING_LABEL_MAP["地点"]] + # spans_name = entities2spans(entities_name) + # # 加入`地点+被盗物品`的组合 + # spans.extend([(a[0], b[1]) for a, b in itertools.product( + # spans_name, spans) if a[1] - b[0] in [-1, 0]]) + # # `地点+被盗物品`、`被盗物品`,优先保留`地点+被盗物品` + # is_todel = [False] * len(spans) + # for i, a in enumerate(spans_name): + # for j, b in enumerate(spans): + # u = (a[0], b[1]) + # if u in spans and u != b: + # is_todel[j] = True + # spans = [span for flag, span in zip(is_todel, spans) if not flag] + # # <<< 地点处理 <<< spans = merge_spans(spans, keep_type="short") entities = spans2entities(spans) label_entities_map[label] = entities + # # TODO: 1. 若存在被盗货币实体重叠,保留最长的;2. 被盗货币要和人名联系 + # meaning = "被盗货币" + # label = MEANING_LABEL_MAP[meaning] + # entities = label_entities_map[label] # 左闭右开 + # if entities: + # spans = entities2spans(entities) + # spans = list(filter(lambda x: not is_contain_special_char(x), spans)) + # # # TODO: >>> 姓名处理 >>> + # # entities_name = label_entities_map[MEANING_LABEL_MAP["受害人"]] + # # spans_name = entities2spans(entities_name) + # # # 加入`受害人+被盗货币`的组合 + # # spans.extend([(a[0], b[1]) for a, b in itertools.product( + # # spans_name, spans) if a[1] - b[0] in [-1, 0]]) + # # # `受害人+被盗货币`、`被盗货币`,优先保留`受害人+被盗货币` + # # is_todel = [False] * len(spans) + # # for i, a in enumerate(spans_name): + # # for j, b in enumerate(spans): + # # u = (a[0], b[1]) + # # if u in spans and u != b: + # # is_todel[j] = True + # # spans = [span for flag, span in zip(is_todel, spans) if not flag] + # # # <<< 姓名处理 <<< + # spans = merge_spans(spans, keep_type="long") + # entities = spans2entities(spans) + # label_entities_map[label] = entities + # 受害人和犯罪嫌疑人设置最长实体限制(10) for meaning in ["受害人", "犯罪嫌疑人"]: label = MEANING_LABEL_MAP[meaning] @@ -1154,7 +1199,23 @@ def merge_spans(spans, keep_type="short"): spans = list(filter(lambda x: (not is_contain_special_char(x)) and is_length_le_n(x, 10), spans)) entities = spans2entities(spans) label_entities_map[label] = entities - + + # # TODO: 元现金 + # for meaning in [ + # "被盗货币", + # "物品价值", + # "盗窃获利", + # ]: + # label = MEANING_LABEL_MAP[meaning] + # entities = label_entities_map[label] + # if entities: + # spans = entities2spans(entities) + # for i, (l, r) in enumerate(spans): + # if text[r - 1] == "元" and text[r: r + 2] == "现金": + # spans[i] = (l, r + 2) + # entities = spans2entities(spans) + # label_entities_map[label] = entities + entities = [{"label": label, "span": label_entities_map[label]} \ for label in LABEL_MEANING_MAP.keys()] # 预测结果文件为一个json格式的文件,包含两个字段,分别为``id``和``entities`` @@ -1246,7 +1307,7 @@ def load_dataset(args, processor, tokenizer, data_type='train'): args = parser.parse_args_from_json(json_file=os.path.abspath(sys.argv[1])) else: args = parser.build_arguments().parse_args() - # args = parser.parse_args_from_json(json_file="output/ner-cail_ner-bert_span-rdrop0.1-fgm1.0-fold3-42/training_args.json") + # args = parser.parse_args_from_json(json_file="args/pred.1.json") # Set seed before initializing model. seed_everything(args.seed) @@ -1324,6 +1385,17 @@ def load_dataset(args, processor, tokenizer, data_type='train'): logger.info(" global_step = %s, average loss = %s", global_step, tr_loss) # Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained() if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0): + # 将early stop模型保存到输出目录下 + if args.save_best_checkpoints: + best_checkpoints = os.path.join(args.output_dir, "checkpoint-999999") + config = config_class.from_pretrained(best_checkpoints, + num_labels=num_labels, max_span_length=args.max_span_length, + cache_dir=args.cache_dir if args.cache_dir else None, ) + tokenizer = tokenizer_class.from_pretrained(best_checkpoints, + do_lower_case=args.do_lower_case, + cache_dir=args.cache_dir if args.cache_dir else None, ) + model = model_class.from_pretrained(best_checkpoints, config=config, + cache_dir=args.cache_dir if args.cache_dir else None) # Create output directory if needed if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]: os.makedirs(args.output_dir) diff --git a/scripts/run_span.sh b/scripts/run_span.sh index 4921a41..3226fd0 100644 --- a/scripts/run_span.sh +++ b/scripts/run_span.sh @@ -363,11 +363,11 @@ done # 组织机构 # {'p': 0.8795336787564767, 'r': 0.8424317617866005, 'f': 0.8605830164765527} -# TODO: context-aware +# TODO: context-aware,仅增强分类性能较差的“受害人、犯罪嫌疑人” for k in 0 1 2 3 4 do python run_span.py \ - --version=nezha-rdrop0.1-fgm1.0-aug_ctx0.1-fold${k} \ + --version=nezha-rdrop0.1-fgm1.0-aug_ctx0.15-fold${k} \ --data_dir=./data/ner-ctx0-5fold-seed42/ \ --train_file=train.${k}.json \ --dev_file=dev.${k}.json \ @@ -391,11 +391,33 @@ python run_span.py \ --other_learning_rate=1e-3 \ --num_train_epochs=4.0 \ --warmup_proportion=0.1 \ - --rdrop_alpha=0.1 \ --do_fgm --fgm_epsilon=1.0 \ - --augment_context_aware_p=0.1 \ + --augment_context_aware_p=0.15 \ --seed=42 done +# main_local +# avg +# {'p': 0.8945034116755117, 'r': 0.885075578560444, 'f': 0.8897645217850342} +# 犯罪嫌疑人 +# {'p': 0.9578052550231839, 'r': 0.9588426427355717, 'f': 0.9583236681357767} +# 受害人 +# {'p': 0.9202144433932513, 'r': 0.9388674388674388, 'f': 0.929447364229973} +# 被盗货币 +# {'p': 0.8020304568527918, 'r': 0.8633879781420765, 'f': 0.831578947368421} +# 物品价值 +# {'p': 0.9693192713326941, 'r': 0.9674641148325359, 'f': 0.9683908045977011} +# 盗窃获利 +# {'p': 0.8552123552123552, 'r': 0.920997920997921, 'f': 0.8868868868868868} +# 被盗物品 +# {'p': 0.8162181951308805, 'r': 0.771319840857983, 'f': 0.7931341159729633} +# 作案工具 +# {'p': 0.7618421052631579, 'r': 0.7877551020408163, 'f': 0.774581939799331} +# 时间 +# {'p': 0.9468796433878157, 'r': 0.9218806509945751, 'f': 0.9342129375114532} +# 地点 +# {'p': 0.863689776733255, 'r': 0.8359397213534262, 'f': 0.8495882097962723} +# 组织机构 +# {'p': 0.8423586040914561, 'r': 0.8684863523573201, 'f': 0.855222968845449} # TODO: albert # TODO: distill