diff --git a/metaretriever/.gitignore b/metaretriever/.gitignore new file mode 100644 index 00000000..f889be3e --- /dev/null +++ b/metaretriever/.gitignore @@ -0,0 +1,18 @@ +# UIE +/data +/pretrain_data +/hf_models +/pd_models +/runs +/models +*.lock + +# mac +.DS_Store + +# env +/.vscode +/.idea +**/__pycache__ +*.pyc +.pytest_cache diff --git a/metaretriever/README.md b/metaretriever/README.md new file mode 100644 index 00000000..afec225d --- /dev/null +++ b/metaretriever/README.md @@ -0,0 +1,90 @@ +# Universal Information Extraction with Meta-Pretrained Self-Retrieval + +This code is for the ACL 2023 Findings paper "Universal Information Extraction with Meta-Pretrained Self-Retrieval". + +## Overview + +![](img/MetaRetriever.png) + +Universal Information Extraction (Universal IE) aims to solve different extraction tasks in a uniform text-to-structure generation manner. Such a generation procedure tends to struggle when there exist complex information structures to be extracted. Retrieving knowledge from external knowledge bases may help models to overcome this problem but it is impossible to construct a knowledge base suitable for various IE tasks. Inspired by the fact that a large amount of knowledge is stored in pretrained language models (PLMs) and can be retrieved explicitly, in this paper, we propose MetaRetriever to retrieve task-specific knowledge from PLMs to enhance universal IE. As different IE tasks need different knowledge, we further propose a Meta-Pretraining Algorithm which allows MetaRetriever to quickly achieve maximum task-specific retrieval performance when fine-tuning on downstream IE tasks. Experimental results show that MetaRetriever achieves the new state-of-the-art on 4 IE tasks, 12 datasets under fully-supervised, low-resource and few-shot scenarios. 
+ +## Requirements + +General + +- Python (verified on 3.8) +- CUDA (verified on 10.2) + +Python Packages +``` bash +conda create -n metaretriever python=3.8 && conda activate metaretriever +conda install -y pytorch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0 cudatoolkit=10.2 -c pytorch +pip install -r requirements.txt +``` + +**NOTE**: Different versions of packages (such as `pytorch`, `transformers`, etc.) may lead to different results from the paper. However, the trend should still hold no matter what versions of packages you use. + +## Usage + +### Data Preprocessing + +``` bash +cd ./dataset_processing/ours +bash download_and_preprocess_data_clean.sh > clean_log.txt +``` + +### Model Preparation + +Please refer to [UIE](https://github.com/universal-ie/UIE) to download the UIE model checkpoint (e.g. `uie-base-en`) and put it under the `models` directory. + +### Meta-Pretraining + +``` bash + +bash run_seq2seq_pretrain.bash -v -d 0,1,2,3,4,5,6,7 -b 64 -k 1 --lr 1e-4 --warmup_ratio 0.06 -i relation/ours_clean --spot_noise 0.0 --asoc_noise 0.0 -f spotasoc --map_config config/offset_map/closest_offset_en.yaml -m ./models/uie-base-en --random_prompt --epoch 4 --trainer_type meta_pretrain_v2 --use_prompt_tuning_model False --output_dir output/meta-pretrained-model +``` + +### Meta-Finetuning + +1. Full Supervision Scenario +``` bash +. config/exp_conf/large_model_conf.ini && trainer_type=meta_finetune_v2 model_name=meta-pretrained-model dataset_name=relation/conll04 selected_gpus=0,1,2,3,4,5,6,7 BATCH_SIZE=4 use_prompt_tuning_model=False run_time=1 bash scripts_exp/run_exp.bash +``` + +2. Few-Shot Scenario +``` bash +. config/exp_conf/base_model_conf_sa_shot.ini && trainer_type=meta_finetune_v2 model_name=meta-pretrained-model dataset_name=relation/conll04 selected_gpus=0,1,2,3,4,5,6,7 BATCH_SIZE=16 use_prompt_tuning_model=False bash scripts_exp/run_exp_shot.bash +``` + +3. Low-Resource Scenario +``` bash +. 
config/exp_conf/base_model_conf_sa_ratio.ini && trainer_type=meta_finetune_v2 model_name=meta-pretrained-model dataset_name=relation/conll04 selected_gpus=0,1,2,3,4,5,6,7 BATCH_SIZE=16 use_prompt_tuning_model=False bash scripts_exp/run_exp_ratio.bash +``` + +## Citation + +If this repository helps you, please cite this paper: +``` +@inproceedings{cong-etal-2023-universal, + title = "Universal Information Extraction with Meta-Pretrained Self-Retrieval", + author = "Cong, Xin and + Yu, Bowen and + Fang, Mengcheng and + Liu, Tingwen and + Yu, Haiyang and + Hu, Zhongkai and + Huang, Fei and + Li, Yongbin and + Wang, Bin", + editor = "Rogers, Anna and + Boyd-Graber, Jordan and + Okazaki, Naoaki", + booktitle = "Findings of the Association for Computational Linguistics: ACL 2023", + month = jul, + year = "2023", + address = "Toronto, Canada", + publisher = "Association for Computational Linguistics", + url = "https://aclanthology.org/2023.findings-acl.251", + doi = "10.18653/v1/2023.findings-acl.251", +} +``` diff --git a/metaretriever/config/data_conf/base_model_conf_absa.ini b/metaretriever/config/data_conf/base_model_conf_absa.ini new file mode 100644 index 00000000..8539c921 --- /dev/null +++ b/metaretriever/config/data_conf/base_model_conf_absa.ini @@ -0,0 +1,16 @@ + +export k8s_gpu_cards=1 +export gpu_node=${k8s_gpu_cards} + +export eval_steps=0 +export epoch=50 +export run_time=5 +export max_source_length=384 + +export BATCH_SIZE="16" +export LR_RATE="1e-4 3e-4 5e-4" +export WARMUP_PROP="0.06" +export LABEL_SMOOTHING="0" +export NEGATIVE="-1" +export NOISE="0.1" +export map_config='config/offset_map/closest_offset_en.yaml' diff --git a/metaretriever/config/data_conf/base_model_conf_nyt.ini b/metaretriever/config/data_conf/base_model_conf_nyt.ini new file mode 100644 index 00000000..79b63d26 --- /dev/null +++ b/metaretriever/config/data_conf/base_model_conf_nyt.ini @@ -0,0 +1,21 @@ +export job_name=FT_Multi + +export k8s_gpu_cards=4 +export 
gpu_node=${k8s_gpu_cards} + +export eval_steps=0 +export epoch=50 +export run_time=3 +export max_source_length=384 + +export BATCH_SIZE="16" +export LR_RATE="3e-4" +export WARMUP_PROP="0.06" +export LABEL_SMOOTHING="0" +export NEGATIVE="-1" +export NOISE="0.1" +export map_config='config/offset_map/closest_offset_en.yaml' + +export job_tags="" +export job_remark="3e-4,0.1" +export eval_match_mode="set" diff --git a/metaretriever/config/data_conf/base_scierc_conf.ini b/metaretriever/config/data_conf/base_scierc_conf.ini new file mode 100644 index 00000000..a3e7236e --- /dev/null +++ b/metaretriever/config/data_conf/base_scierc_conf.ini @@ -0,0 +1,23 @@ + +export k8s_gpu_cards=1 +export gpu_node=${k8s_gpu_cards} + +export eval_steps=0 +export epoch=50 +export run_time=5 +export max_source_length=384 + +export job_tags="${dataset_name},${model_name}_rp" +export job_remark="d${dataset_name},m${model_name}" + +export BATCH_SIZE="16" +export LR_RATE="5e-4" +export WARMUP_PROP="0.06" +export LABEL_SMOOTHING="0" +export NEGATIVE="-1" +export NOISE="0.1" +export map_config='config/offset_map/closest_offset_en.yaml' + +export job_tags="${dataset_name},${model_name}" +export job_remark="5e-4,0.1" +export start_eval_step=3000 diff --git a/metaretriever/config/data_conf/large_ace04ent_conf.ini b/metaretriever/config/data_conf/large_ace04ent_conf.ini new file mode 100644 index 00000000..e1bc7d73 --- /dev/null +++ b/metaretriever/config/data_conf/large_ace04ent_conf.ini @@ -0,0 +1,22 @@ + +export k8s_gpu_cards=4 +export gpu_node=${k8s_gpu_cards} + +export eval_steps=0 +export epoch=50 +export run_time=3 +export max_source_length=384 + +export job_tags="${dataset_name},${model_name}_rp" +export job_remark="d${dataset_name},m${model_name}" + +export BATCH_SIZE="8" +export LR_RATE="1e-4" +export WARMUP_PROP="0.06" +export LABEL_SMOOTHING="0" +export NEGATIVE="-1" +export NOISE="0.1 0.2" +export map_config='config/offset_map/closest_offset_en.yaml' + +export 
job_tags="${dataset_name},${model_name}" +export job_remark="1e-4,0.1,0.2" diff --git a/metaretriever/config/data_conf/large_ace05ent_conf.ini b/metaretriever/config/data_conf/large_ace05ent_conf.ini new file mode 100644 index 00000000..e1bc7d73 --- /dev/null +++ b/metaretriever/config/data_conf/large_ace05ent_conf.ini @@ -0,0 +1,22 @@ + +export k8s_gpu_cards=4 +export gpu_node=${k8s_gpu_cards} + +export eval_steps=0 +export epoch=50 +export run_time=3 +export max_source_length=384 + +export job_tags="${dataset_name},${model_name}_rp" +export job_remark="d${dataset_name},m${model_name}" + +export BATCH_SIZE="8" +export LR_RATE="1e-4" +export WARMUP_PROP="0.06" +export LABEL_SMOOTHING="0" +export NEGATIVE="-1" +export NOISE="0.1 0.2" +export map_config='config/offset_map/closest_offset_en.yaml' + +export job_tags="${dataset_name},${model_name}" +export job_remark="1e-4,0.1,0.2" diff --git a/metaretriever/config/data_conf/large_ace05evt_conf.ini b/metaretriever/config/data_conf/large_ace05evt_conf.ini new file mode 100644 index 00000000..bc4f830c --- /dev/null +++ b/metaretriever/config/data_conf/large_ace05evt_conf.ini @@ -0,0 +1,22 @@ +export job_name=FT_spotasocname + +export k8s_gpu_cards=4 +export gpu_node=${k8s_gpu_cards} + +export eval_steps=2000 +export epoch=50 +export run_time=3 +export max_source_length=256 + +export job_tags="${dataset_name},${model_name}_rp" +export job_remark="d${dataset_name},m${model_name}" + +export BATCH_SIZE="8" +export LR_RATE="1e-4" +export WARMUP_PROP="0.06" +export LABEL_SMOOTHING="0" +export NEGATIVE="-1" +export NOISE="0.1" +export map_config='config/offset_map/closest_offset_en.yaml' +export start_eval_step=15000 +export job_remark="1e-4,0.1" diff --git a/metaretriever/config/data_conf/large_ace05rel_conf.ini b/metaretriever/config/data_conf/large_ace05rel_conf.ini new file mode 100644 index 00000000..19b19fe7 --- /dev/null +++ b/metaretriever/config/data_conf/large_ace05rel_conf.ini @@ -0,0 +1,22 @@ + +export 
k8s_gpu_cards=4 +export gpu_node=${k8s_gpu_cards} + +export eval_steps=0 +export epoch=50 +export run_time=3 +export max_source_length=384 + +export job_tags="${dataset_name},${model_name}_rp" +export job_remark="d${dataset_name},m${model_name}" + +export BATCH_SIZE="8" +export LR_RATE="1e-4 3e-4" +export WARMUP_PROP="0.06" +export LABEL_SMOOTHING="0" +export NEGATIVE="-1" +export NOISE="0.2" +export map_config='config/offset_map/closest_offset_en.yaml' + +export job_tags="${dataset_name},${model_name}" +export job_remark="1e-4,3e-4,0.2" diff --git a/metaretriever/config/data_conf/large_casie_conf.ini b/metaretriever/config/data_conf/large_casie_conf.ini new file mode 100644 index 00000000..740d4d6a --- /dev/null +++ b/metaretriever/config/data_conf/large_casie_conf.ini @@ -0,0 +1,22 @@ + +export k8s_gpu_cards=4 +export gpu_node=${k8s_gpu_cards} + +export eval_steps=0 +export epoch=50 +export run_time=3 +export max_source_length=256 + +export job_tags="${dataset_name},${model_name}_rp" +export job_remark="d${dataset_name},m${model_name}" + +export BATCH_SIZE="8" +export LR_RATE="3e-4" +export WARMUP_PROP="0.06" +export LABEL_SMOOTHING="0" +export NEGATIVE="-1" +export NOISE="0.2" +export map_config='config/offset_map/closest_offset_en.yaml' + +export job_tags="${dataset_name},${model_name}" +export job_remark="3e-4,0.2" diff --git a/metaretriever/config/data_conf/large_conll03_conf.ini b/metaretriever/config/data_conf/large_conll03_conf.ini new file mode 100644 index 00000000..ad662420 --- /dev/null +++ b/metaretriever/config/data_conf/large_conll03_conf.ini @@ -0,0 +1,22 @@ + +export k8s_gpu_cards=4 +export gpu_node=${k8s_gpu_cards} + +export eval_steps=0 +export epoch=50 +export run_time=3 +export max_source_length=256 + +export job_tags="${dataset_name},${model_name}_rp" +export job_remark="d${dataset_name},m${model_name}" + +export BATCH_SIZE="8" +export LR_RATE="1e-4" +export WARMUP_PROP="0.06" +export LABEL_SMOOTHING="0" +export NEGATIVE="-1" +export 
NOISE="0.1" +export map_config='config/offset_map/first_offset_en.yaml' + +export job_tags="${dataset_name},${model_name}" +export job_remark="1e-4,0.1" diff --git a/metaretriever/config/data_conf/large_conll03_conf_b8.ini b/metaretriever/config/data_conf/large_conll03_conf_b8.ini new file mode 100644 index 00000000..509cb3a9 --- /dev/null +++ b/metaretriever/config/data_conf/large_conll03_conf_b8.ini @@ -0,0 +1,22 @@ + +export k8s_gpu_cards=1 +export gpu_node=${k8s_gpu_cards} + +export eval_steps=0 +export epoch=50 +export run_time=3 +export max_source_length=256 + +export job_tags="${dataset_name},${model_name}_rp" +export job_remark="d${dataset_name},m${model_name}" + +export BATCH_SIZE="8" +export LR_RATE="5e-5" +export WARMUP_PROP="0.06" +export LABEL_SMOOTHING="0" +export NEGATIVE="-1" +export NOISE="0.1" +export map_config='config/offset_map/first_offset_en.yaml' + +export job_tags="${dataset_name},${model_name}" +export job_remark="5e-5,0.1" diff --git a/metaretriever/config/data_conf/large_conll04_conf.ini b/metaretriever/config/data_conf/large_conll04_conf.ini new file mode 100644 index 00000000..6036c812 --- /dev/null +++ b/metaretriever/config/data_conf/large_conll04_conf.ini @@ -0,0 +1,22 @@ + +export k8s_gpu_cards=4 +export gpu_node=${k8s_gpu_cards} + +export eval_steps=0 +export epoch=50 +export run_time=3 +export max_source_length=384 + +export job_tags="${dataset_name},${model_name}_rp" +export job_remark="d${dataset_name},m${model_name}" + +export BATCH_SIZE="8" +export LR_RATE="3e-4" +export WARMUP_PROP="0.06" +export LABEL_SMOOTHING="0" +export NEGATIVE="-1" +export NOISE="0.2 0.1" +export map_config='config/offset_map/closest_offset_en.yaml' + +export job_tags="${dataset_name},${model_name}" +export job_remark="3e-4,0.1,0.2" diff --git a/metaretriever/config/data_conf/large_model_conf_absa.ini b/metaretriever/config/data_conf/large_model_conf_absa.ini new file mode 100644 index 00000000..a876a767 --- /dev/null +++ 
b/metaretriever/config/data_conf/large_model_conf_absa.ini @@ -0,0 +1,16 @@ + +export k8s_gpu_cards=1 +export gpu_node=${k8s_gpu_cards} + +export eval_steps=0 +export epoch=50 +export run_time=3 +export max_source_length=384 + +export BATCH_SIZE="8" +export LR_RATE="1e-4 3e-5 5e-5" +export WARMUP_PROP="0.06" +export LABEL_SMOOTHING="0" +export NEGATIVE="-1" +export NOISE="0.1" +export map_config='config/offset_map/closest_offset_en.yaml' diff --git a/metaretriever/config/data_conf/large_model_conf_nyt.ini b/metaretriever/config/data_conf/large_model_conf_nyt.ini new file mode 100644 index 00000000..7748be7a --- /dev/null +++ b/metaretriever/config/data_conf/large_model_conf_nyt.ini @@ -0,0 +1,23 @@ + +export k8s_gpu_cards=4 +export gpu_node=${k8s_gpu_cards} + +export eval_steps=0 +export epoch=50 +export run_time=3 +export max_source_length=384 + +export job_tags="${dataset_name},${model_name}_rp" +export job_remark="d${dataset_name},m${model_name}" + +export BATCH_SIZE="8" +export LR_RATE="5e-5" +export WARMUP_PROP="0.06" +export LABEL_SMOOTHING="0" +export NEGATIVE="-1" +export NOISE="0.2" +export map_config='config/offset_map/closest_offset_en.yaml' + +export job_tags="${dataset_name},${model_name}" +export job_remark="5e-5,0.2" +export eval_match_mode="set" diff --git a/metaretriever/config/data_conf/large_scierc_conf.ini b/metaretriever/config/data_conf/large_scierc_conf.ini new file mode 100644 index 00000000..87802f22 --- /dev/null +++ b/metaretriever/config/data_conf/large_scierc_conf.ini @@ -0,0 +1,22 @@ + +export k8s_gpu_cards=4 +export gpu_node=${k8s_gpu_cards} + +export eval_steps=0 +export epoch=50 +export run_time=3 +export max_source_length=384 + +export job_tags="${dataset_name},${model_name}_rp" +export job_remark="d${dataset_name},m${model_name}" + +export BATCH_SIZE="8" +export LR_RATE="3e-4" +export WARMUP_PROP="0.06" +export LABEL_SMOOTHING="0" +export NEGATIVE="-1" +export NOISE="0.2 0.1 0" +export 
map_config='config/offset_map/closest_offset_en.yaml' + +export job_tags="${dataset_name},${model_name}" +export job_remark="3e-4,0,0.1,0.2" diff --git a/metaretriever/config/det_conf/det.conf b/metaretriever/config/det_conf/det.conf new file mode 100644 index 00000000..20092fa4 --- /dev/null +++ b/metaretriever/config/det_conf/det.conf @@ -0,0 +1,22 @@ +description: uie + +environment: + image: docker.cipsup.cn/uie/uie:transformers4.6.2 + environment_variables: + - DET_TASK_OWNER=luyaojie + +resources: + slots: 4 + +bind_mounts: + # Data Folder Bind Mount + - host_path: /shared_home/luyaojie/uie/data + container_path: /run/determined/workdir/data + + # Pre-trained Model Folder Bind Mount + - host_path: /shared_home/luyaojie/uie/model + container_path: /run/determined/workdir/hf_models + + # Output Folder Bind Mount + - host_path: /shared_home/luyaojie/uie/output + container_path: /run/determined/workdir/output diff --git a/metaretriever/config/exp_conf/base_model_conf.ini b/metaretriever/config/exp_conf/base_model_conf.ini new file mode 100644 index 00000000..c6123ac1 --- /dev/null +++ b/metaretriever/config/exp_conf/base_model_conf.ini @@ -0,0 +1,22 @@ + +export k8s_gpu_cards=4 +export gpu_node=${k8s_gpu_cards} + +export eval_steps=0 +export epoch=50 +export run_time=3 +export max_source_length=384 + +export job_tags="${dataset_name},${model_name}_rp" +export job_remark="d${dataset_name},m${model_name}" + +export BATCH_SIZE="16" +export LR_RATE="1e-4 3e-4 5e-4" +export WARMUP_PROP="0.06" +export LABEL_SMOOTHING="0" +export NEGATIVE="-1" +export NOISE="0 0.1 0.2" +export map_config='config/offset_map/closest_offset_en.yaml' + +export job_tags="${dataset_name},${model_name}" +export job_remark="1e-4,3e-4,5e-5,0,0.1,0.2" diff --git a/metaretriever/config/exp_conf/base_model_conf_b16.ini b/metaretriever/config/exp_conf/base_model_conf_b16.ini new file mode 100644 index 00000000..880196db --- /dev/null +++ b/metaretriever/config/exp_conf/base_model_conf_b16.ini @@ 
-0,0 +1,21 @@ + +export gpu_node=1 + +export eval_steps=0 +export epoch=50 +export run_time=3 +export max_source_length=384 + +export job_tags="${dataset_name},${model_name}_rp" +export job_remark="d${dataset_name},m${model_name}" + +export BATCH_SIZE="16" +export LR_RATE="1e-4 3e-4 5e-4" +export WARMUP_PROP="0.06" +export LABEL_SMOOTHING="0" +export NEGATIVE="-1" +export NOISE="0.1 0 0.2" +export map_config='config/offset_map/closest_offset_en.yaml' + +export job_tags="${dataset_name},${model_name}" +export job_remark="1e-4,3e-4,5e-5,0,0.1,0.2" diff --git a/metaretriever/config/exp_conf/base_model_conf_sa_ratio.ini b/metaretriever/config/exp_conf/base_model_conf_sa_ratio.ini new file mode 100644 index 00000000..afa3d80c --- /dev/null +++ b/metaretriever/config/exp_conf/base_model_conf_sa_ratio.ini @@ -0,0 +1,15 @@ +export gpu_node=1 + +export eval_steps=0 +export epoch=200 +export run_time=10 +export max_source_length=256 +export decoding_format="spotasoc" + +export BATCH_SIZE="16" +export LR_RATE="1e-4" +export WARMUP_PROP="0.06" +export LABEL_SMOOTHING="0" +export NEGATIVE="-1" +export NOISE="0.1" +export map_config='config/offset_map/closest_offset_en.yaml' diff --git a/metaretriever/config/exp_conf/base_model_conf_sa_shot.ini b/metaretriever/config/exp_conf/base_model_conf_sa_shot.ini new file mode 100644 index 00000000..afa3d80c --- /dev/null +++ b/metaretriever/config/exp_conf/base_model_conf_sa_shot.ini @@ -0,0 +1,15 @@ +export gpu_node=1 + +export eval_steps=0 +export epoch=200 +export run_time=10 +export max_source_length=256 +export decoding_format="spotasoc" + +export BATCH_SIZE="16" +export LR_RATE="1e-4" +export WARMUP_PROP="0.06" +export LABEL_SMOOTHING="0" +export NEGATIVE="-1" +export NOISE="0.1" +export map_config='config/offset_map/closest_offset_en.yaml' diff --git a/metaretriever/config/exp_conf/large_model_conf.ini b/metaretriever/config/exp_conf/large_model_conf.ini new file mode 100644 index 00000000..6bef72b9 --- /dev/null +++ 
b/metaretriever/config/exp_conf/large_model_conf.ini @@ -0,0 +1,22 @@ + +export k8s_gpu_cards=4 +export gpu_node=${k8s_gpu_cards} + +export eval_steps=0 +export epoch=50 +export run_time=3 +export max_source_length=384 + +export job_tags="${dataset_name},${model_name}_rp" +export job_remark="d${dataset_name},m${model_name}" + +export BATCH_SIZE="8" +export LR_RATE="5e-5 1e-4 3e-4" +export WARMUP_PROP="0.06" +export LABEL_SMOOTHING="0" +export NEGATIVE="-1" +export NOISE="0.2 0.1 0" +export map_config='config/offset_map/closest_offset_en.yaml' + +export job_tags="${dataset_name},${model_name}" +export job_remark="5e-5,1e-4,3e-4,0,0.1,0.2" diff --git a/metaretriever/config/exp_conf/large_model_conf_b8.ini b/metaretriever/config/exp_conf/large_model_conf_b8.ini new file mode 100644 index 00000000..6bef72b9 --- /dev/null +++ b/metaretriever/config/exp_conf/large_model_conf_b8.ini @@ -0,0 +1,22 @@ + +export k8s_gpu_cards=4 +export gpu_node=${k8s_gpu_cards} + +export eval_steps=0 +export epoch=50 +export run_time=3 +export max_source_length=384 + +export job_tags="${dataset_name},${model_name}_rp" +export job_remark="d${dataset_name},m${model_name}" + +export BATCH_SIZE="8" +export LR_RATE="5e-5 1e-4 3e-4" +export WARMUP_PROP="0.06" +export LABEL_SMOOTHING="0" +export NEGATIVE="-1" +export NOISE="0.2 0.1 0" +export map_config='config/offset_map/closest_offset_en.yaml' + +export job_tags="${dataset_name},${model_name}" +export job_remark="5e-5,1e-4,3e-4,0,0.1,0.2" diff --git a/metaretriever/config/offset_map/closest_offset_en.yaml b/metaretriever/config/offset_map/closest_offset_en.yaml new file mode 100644 index 00000000..645c8928 --- /dev/null +++ b/metaretriever/config/offset_map/closest_offset_en.yaml @@ -0,0 +1,3 @@ +map_strategy: "closest" +de_duplicate: True +span_to_token: "space" diff --git a/metaretriever/config/offset_map/closest_offset_zh.yaml b/metaretriever/config/offset_map/closest_offset_zh.yaml new file mode 100644 index 00000000..0aa18a44 --- /dev/null 
+++ b/metaretriever/config/offset_map/closest_offset_zh.yaml @@ -0,0 +1,3 @@ +map_strategy: "closest" +de_duplicate: True +span_to_token: "list" diff --git a/metaretriever/config/offset_map/first_offset_en.yaml b/metaretriever/config/offset_map/first_offset_en.yaml new file mode 100644 index 00000000..80079b95 --- /dev/null +++ b/metaretriever/config/offset_map/first_offset_en.yaml @@ -0,0 +1,3 @@ +map_strategy: "first" +de_duplicate: True +span_to_token: "space" diff --git a/metaretriever/config/offset_map/first_offset_zh.yaml b/metaretriever/config/offset_map/first_offset_zh.yaml new file mode 100644 index 00000000..414b05d8 --- /dev/null +++ b/metaretriever/config/offset_map/first_offset_zh.yaml @@ -0,0 +1,3 @@ +map_strategy: "first" +de_duplicate: True +span_to_token: "list" diff --git a/metaretriever/config/offset_map/longer_first_offset_zh.yaml b/metaretriever/config/offset_map/longer_first_offset_zh.yaml new file mode 100644 index 00000000..a3d35084 --- /dev/null +++ b/metaretriever/config/offset_map/longer_first_offset_zh.yaml @@ -0,0 +1,3 @@ +map_strategy: "longer_first" +de_duplicate: True +span_to_token: "list" diff --git a/metaretriever/dataset_processing/.gitignore b/metaretriever/dataset_processing/.gitignore new file mode 100644 index 00000000..a6707f27 --- /dev/null +++ b/metaretriever/dataset_processing/.gitignore @@ -0,0 +1,25 @@ +/data +/converted_data +/lightning_logs +/model +/models +/*log +/thirdparty +/tmp + +.lock +# mac +.DS_Store + +# env +/.vscode +/.idea +**/__pycache__ +*.pyc +.pytest_cache + +# doc +docs/build +docs/.vscode +docs/source/_build + diff --git a/metaretriever/dataset_processing/README.md b/metaretriever/dataset_processing/README.md new file mode 100644 index 00000000..e98f4070 --- /dev/null +++ b/metaretriever/dataset_processing/README.md @@ -0,0 +1,3 @@ +# Universal IE Dataset Preparation + +Please refer to [UIE](https://github.com/universal-ie/UIE). 
\ No newline at end of file diff --git a/metaretriever/dataset_processing/data_config/absa/pengb_14lap.yaml b/metaretriever/dataset_processing/data_config/absa/pengb_14lap.yaml new file mode 100644 index 00000000..3a73562f --- /dev/null +++ b/metaretriever/dataset_processing/data_config/absa/pengb_14lap.yaml @@ -0,0 +1,15 @@ +name: 14lap +path: data/absa/pengb/14lap +data_class: ABSA +split: + train: train_convert.json + val: dev_convert.json + test: test_convert.json +language: en + +mapper: + POS: positive + NEG: negative + NEU: neutral + aspect: aspect + opinion: opinion diff --git a/metaretriever/dataset_processing/data_config/absa/pengb_14res.yaml b/metaretriever/dataset_processing/data_config/absa/pengb_14res.yaml new file mode 100644 index 00000000..e4c7f35a --- /dev/null +++ b/metaretriever/dataset_processing/data_config/absa/pengb_14res.yaml @@ -0,0 +1,15 @@ +name: 14res +path: data/absa/pengb/14res +data_class: ABSA +split: + train: train_convert.json + val: dev_convert.json + test: test_convert.json +language: en + +mapper: + POS: positive + NEG: negative + NEU: neutral + aspect: aspect + opinion: opinion diff --git a/metaretriever/dataset_processing/data_config/absa/pengb_15res.yaml b/metaretriever/dataset_processing/data_config/absa/pengb_15res.yaml new file mode 100644 index 00000000..cffab0ff --- /dev/null +++ b/metaretriever/dataset_processing/data_config/absa/pengb_15res.yaml @@ -0,0 +1,15 @@ +name: 15res +path: data/absa/pengb/15res +data_class: ABSA +split: + train: train_convert.json + val: dev_convert.json + test: test_convert.json +language: en + +mapper: + POS: positive + NEG: negative + NEU: neutral + aspect: aspect + opinion: opinion diff --git a/metaretriever/dataset_processing/data_config/absa/pengb_16res.yaml b/metaretriever/dataset_processing/data_config/absa/pengb_16res.yaml new file mode 100644 index 00000000..82468279 --- /dev/null +++ b/metaretriever/dataset_processing/data_config/absa/pengb_16res.yaml @@ -0,0 +1,15 @@ +name: 16res 
+path: data/absa/pengb/16res +data_class: ABSA +split: + train: train_convert.json + val: dev_convert.json + test: test_convert.json +language: en + +mapper: + POS: positive + NEG: negative + NEU: neutral + aspect: aspect + opinion: opinion diff --git a/metaretriever/dataset_processing/data_config/entity/conll03.yaml b/metaretriever/dataset_processing/data_config/entity/conll03.yaml new file mode 100644 index 00000000..bc55de71 --- /dev/null +++ b/metaretriever/dataset_processing/data_config/entity/conll03.yaml @@ -0,0 +1,13 @@ +name: conll03 +path: data/conll03/conll03 +data_class: CoNLL03 +split: + train: eng.train + val: eng.testa + test: eng.testb +language: en +mapper: + LOC: location + ORG: organization + PER: person + MISC: miscellaneous \ No newline at end of file diff --git a/metaretriever/dataset_processing/data_config/entity/mrc_ace2004.yaml b/metaretriever/dataset_processing/data_config/entity/mrc_ace2004.yaml new file mode 100644 index 00000000..56dcc579 --- /dev/null +++ b/metaretriever/dataset_processing/data_config/entity/mrc_ace2004.yaml @@ -0,0 +1,17 @@ +name: mrc_ace04 +path: data/mrc_ner/ace2004 +data_class: MRCNER +split: + train: mrc-ner.train + val: mrc-ner.dev + test: mrc-ner.test +language: en + +mapper: + FAC: facility + GPE: geographical social political + LOC: location + ORG: organization + PER: person + VEH: vehicle + WEA: weapon diff --git a/metaretriever/dataset_processing/data_config/entity/mrc_ace2005.yaml b/metaretriever/dataset_processing/data_config/entity/mrc_ace2005.yaml new file mode 100644 index 00000000..c221e36a --- /dev/null +++ b/metaretriever/dataset_processing/data_config/entity/mrc_ace2005.yaml @@ -0,0 +1,17 @@ +name: mrc_ace05 +path: data/mrc_ner/ace2005 +data_class: MRCNER +split: + train: mrc-ner.train + val: mrc-ner.dev + test: mrc-ner.test +language: en + +mapper: + FAC: facility + GPE: geographical social political + LOC: location + ORG: organization + PER: person + VEH: vehicle + WEA: weapon diff --git 
a/metaretriever/dataset_processing/data_config/event/casie.yaml b/metaretriever/dataset_processing/data_config/event/casie.yaml new file mode 100644 index 00000000..3ae06081 --- /dev/null +++ b/metaretriever/dataset_processing/data_config/event/casie.yaml @@ -0,0 +1,62 @@ +name: casie +path: data/casie +data_class: CASIE +split: + train: train.jsonlines + val: dev.jsonlines + test: test.jsonlines +language: en +# https://github.com/Ebiquity/CASIE +# https://ieeexplore.ieee.org/stamp/stamp.jsp?tp=&arnumber=9006444 +mapper: + File: file + System: system + Person: person + Phishing: phishing + Data: data + Purpose: purpose + Website: website + Organization: organization + Capabilities: capabilities + Malware: malware + Software: software + PII: personally identifiable information + Databreach: databreach + Time: time + Number: number + GPE: geopolitical entity + Ransom: ransom + Money: money + Device: device + Vulnerability: vulnerability + DiscoverVulnerability: discover vulnerability + Patch: patch + PatchVulnerability: patch vulnerability + Version: version + PaymentMethod: payment method + CVE: common vulnerabilities and exposures + Issues-Addressed: issues addressed + Vulnerable_System: vulnerable system + Number-of-Data: number of data + Capabilities: capabilities + Patch: patch + Time: time + Vulnerable_System_Version: vulnerable system version + Releaser: releaser + Damage-Amount: damage amount + Number-of-Victim: number of victim + Tool: tool + Attack-Pattern: attack pattern + Compromised-Data: compromised data + Attacker: attacker + Price: price + Discoverer: discoverer + Patch-Number: patch number + Payment-Method: payment method + Supported_Platform: supported platform + Vulnerability: vulnerability + Place: place + Vulnerable_System_Owner: vulnerable system owner + Victim: victim + Trusted-Entity: trusted entity + Purpose: purpose diff --git a/metaretriever/dataset_processing/data_config/event/oneie_ace05_en_event.yaml 
b/metaretriever/dataset_processing/data_config/event/oneie_ace05_en_event.yaml new file mode 100644 index 00000000..5e25453f --- /dev/null +++ b/metaretriever/dataset_processing/data_config/event/oneie_ace05_en_event.yaml @@ -0,0 +1,78 @@ +name: oneie_ace05_en_event +path: data/oneie/ace05-EN +data_class: OneIEEvent +split: + train: train.oneie.json + val: dev.oneie.json + test: test.oneie.json +language: en + +mapper: + FAC: facility + GPE: geographical social political + LOC: location + ORG: organization + PER: person + VEH: vehicle + WEA: weapon + ORG-AFF: organization affiliation + GEN-AFF: general affiliation + PHYS: physical + PART-WHOLE: part whole + PER-SOC: personal social + ART: agent artifact + Personnel:Elect: elect + Life:Be-Born: born + Movement:Transport: transport + Contact:Phone-Write: phone write + Life:Marry: marry + Life:Die: die + Personnel:Start-Position: start position + Life:Injure: injure + Transaction:Transfer-Ownership: transfer ownership + Contact:Meet: meet + Personnel:Nominate: nominate + Conflict:Attack: attack + Business:Start-Org: start organization + Justice:Trial-Hearing: trial hearing + Justice:Convict: convict + Justice:Sentence: sentence + Personnel:End-Position: end position + Life:Divorce: divorce + Justice:Acquit: acquit + Justice:Charge-Indict: charge indict + Transaction:Transfer-Money: transfer money + Justice:Appeal: appeal + Justice:Sue: sue + Business:Merge-Org: merge organization + Business:Declare-Bankruptcy: declare bankruptcy + Justice:Execute: execute + Justice:Arrest-Jail: arrest jail + Justice:Extradite: extradite + Conflict:Demonstrate: demonstrate + Business:End-Org: end organization + Justice:Release-Parole: release parole + Justice:Fine: fine + Justice:Pardon: pardon + Defendant: defendant + Prosecutor: prosecutor + Person: person + Origin: origin + Buyer: buyer + Plaintiff: plaintiff + Victim: victim + Org: organization + Adjudicator: adjudicator + Seller: seller + Beneficiary: beneficiary + Giver: giver + 
Target: target + Agent: agent + Instrument: instrument + Vehicle: vehicle + Entity: entity + Destination: destination + Recipient: recipient + Attacker: attacker + Artifact: artifact + Place: place diff --git a/metaretriever/dataset_processing/data_config/relation/NYT-multi.yaml b/metaretriever/dataset_processing/data_config/relation/NYT-multi.yaml new file mode 100644 index 00000000..c3174e12 --- /dev/null +++ b/metaretriever/dataset_processing/data_config/relation/NYT-multi.yaml @@ -0,0 +1,36 @@ +name: NYT +path: data/NYT-multi +data_class: JointER +split: + train: train.json + val: dev.json + test: test.json +language: en +mapper: + ORGANIZATION: organization + LOCATION: location + PERSON: person + /location/location/contains: contains + /people/person/place_of_birth: place of birth + /business/person/company: company + /people/person/place_lived: place lived + /location/administrative_division/country: country + /location/country/administrative_divisions: administrative divisions + /people/person/religion: religion + /people/person/nationality: nationality + /people/person/children: children + /location/country/capital: capital + /business/company/place_founded: place founded + /people/deceased_person/place_of_death: place of death + /business/company/founders: founders + /location/neighborhood/neighborhood_of: neighborhood of + /business/company/advisors: advisors + /people/ethnicity/geographic_distribution: geographic distribution + /sports/sports_team_location/teams: teams + /sports/sports_team/location: location + /business/company_shareholder/major_shareholder_of: major shareholder of + /business/company/major_shareholders: major shareholders + /people/person/ethnicity: ethnicity + /people/ethnicity/people: people + /people/person/profession: profession + /business/company/industry: industry \ No newline at end of file diff --git a/metaretriever/dataset_processing/data_config/relation/ace05-rel.yaml 
b/metaretriever/dataset_processing/data_config/relation/ace05-rel.yaml new file mode 100644 index 00000000..2bfc301f --- /dev/null +++ b/metaretriever/dataset_processing/data_config/relation/ace05-rel.yaml @@ -0,0 +1,23 @@ +name: ace05-rel +path: data/spannet_data/relation/ace05 +data_class: Spannet +split: + train: train.jsonlines + val: dev.jsonlines + test: test.jsonlines +language: en + +mapper: + FAC: facility + GPE: geographical social political + LOC: location + ORG: organization + PER: person + VEH: vehicle + WEA: weapon + ORG-AFF: organization affiliation + GEN-AFF: general affiliation + PHYS: physical + PART-WHOLE: part whole + PER-SOC: personal social + ART: agent artifact diff --git a/metaretriever/dataset_processing/data_config/relation/conll04.yaml b/metaretriever/dataset_processing/data_config/relation/conll04.yaml new file mode 100644 index 00000000..1d744c2b --- /dev/null +++ b/metaretriever/dataset_processing/data_config/relation/conll04.yaml @@ -0,0 +1,19 @@ +name: conll04 +path: data/spannet_data/relation/conll04 +data_class: Spannet +split: + train: train.jsonlines + val: dev.jsonlines + test: test.jsonlines +language: en + +mapper: + Loc: location + Org: organization + Other: other + Peop: people + OrgBased_In: organization in + Work_For: work for + Located_In: located in + Live_In: live in + Kill: kill diff --git a/metaretriever/dataset_processing/data_config/relation/scierc.yaml b/metaretriever/dataset_processing/data_config/relation/scierc.yaml new file mode 100644 index 00000000..dd19480a --- /dev/null +++ b/metaretriever/dataset_processing/data_config/relation/scierc.yaml @@ -0,0 +1,22 @@ +name: scierc +path: data/spannet_data/relation/dygiepp/scierc +data_class: Spannet +split: + train: train.jsonlines + val: dev.jsonlines + test: test.jsonlines +language: en +mapper: + Method: method + Generic: generic + Material: material + Task: task + Metric: metric + OtherScientificTerm: other scientific term + USED-FOR: used for + FEATURE-OF: 
- data: 目标文件夹,遍历文件夹下包含 record.schema 的子文件夹,跳过所有命名中包含 shot 和 ratio 的文件夹
- f: 输出的表格形式,常见的有 simple(默认),latex,html
根据每个类别采样 1 5 10 个样例合成最终数据 diff --git a/metaretriever/dataset_processing/ours/categorized_store.py b/metaretriever/dataset_processing/ours/categorized_store.py new file mode 100644 index 00000000..ce532117 --- /dev/null +++ b/metaretriever/dataset_processing/ours/categorized_store.py @@ -0,0 +1,131 @@ +import json +import os +import random +import argparse + +from collections import OrderedDict + +from tqdm import tqdm + +import pdb + +parser = argparse.ArgumentParser() +parser.add_argument("-d", "--data_dir", default="./final_data5/data_1", type=str) +parser.add_argument("-o", "--output_dir", default="./output_fewshot", type=str) +parser.add_argument("--entity_category_dir", default="entity_category", type=str) +parser.add_argument("--relation_category_dir", default="relation_category", type=str) +parser.add_argument("--step", default=1, type=int) +opt = parser.parse_args() + +data_dir = opt.data_dir +output_dir = opt.output_dir +entity_category_dir = opt.entity_category_dir +relation_category_dir = opt.relation_category_dir +step = opt.step + +all_file = os.path.join(output_dir, "all.json") + +entity_stat_file = os.path.join(output_dir, "entity_stat.json") +relation_stat_file = os.path.join(output_dir, "relation_stat.json") + +target_entity_category_dir = os.path.join(output_dir, entity_category_dir) +target_relation_category_dir = os.path.join(output_dir, relation_category_dir) + +if not os.path.exists(target_entity_category_dir): + os.makedirs(target_entity_category_dir) +if not os.path.exists(target_relation_category_dir): + os.makedirs(target_relation_category_dir) + +entity_instance_dict_file = os.path.join(output_dir, "entity_instance_dict.json") +relation_instance_dict_file = os.path.join(output_dir, "relation_instance_dict.json") + +metainfo_file = relation_instance_dict_file = os.path.join(output_dir, "metainfo.json") + +# %% load all instance line + +print("Reading all data...") +instance_list = [] +with open(all_file) as all: + for idx, line in 
tqdm(enumerate(all)): + if len(line) == 0: + continue + instance_list.append(line) +print("All data read.") + +# %% rearrange instance by class + +print("Stat entity type and relation type...") +entity_type_instance_dict = {} +relation_type_instance_dict = {} +for line in tqdm(instance_list): + if len(line) == 0: + continue + + record = json.loads(line) + + entity_type_list = record["spot"] + relation_type_list = record["asoc"] + + for entity_type in entity_type_list: + if entity_type not in entity_type_instance_dict: + entity_type_instance_dict[entity_type] = { + "type_id": len(entity_type_instance_dict), + "instance_list": [] + } + entity_type_instance_dict[entity_type]["instance_list"].append(line) + + for relation_type in relation_type_list: + if relation_type not in relation_type_instance_dict: + relation_type_instance_dict[relation_type] = { + "type_id": len(relation_type_instance_dict), + "instance_list": [] + } + relation_type_instance_dict[relation_type]["instance_list"].append(line) + +print("Stat over.") + +# %% save data by category + +metainfo = { + "entity": [], + "relation": [], +} + +print("Saving entity by category...") +for entity_type, data in tqdm(entity_type_instance_dict.items()): + type_id = data["type_id"] + instance_list = data["instance_list"] + + current_metainfo = { + "entity_type": entity_type, + "type_id": type_id + } + metainfo["entity"].append(current_metainfo) + + entity_type_file = os.path.join(target_entity_category_dir, str(type_id)+".json") + with open(entity_type_file, "w") as f: + for instance in instance_list: + f.write(instance) +print("Entity saved.") + +print("Saving relation by category...") +for relation_type, data in tqdm(relation_type_instance_dict.items()): + type_id = data["type_id"] + instance_list = data["instance_list"] + + current_metainfo = { + "relation_type": relation_type, + "type_id": type_id + } + metainfo["relation"].append(current_metainfo) + + relation_type_file = 
# Convert the "relation" field of every instance from a list of JSON objects
# into a list of JSON-encoded strings, writing the result to a sibling file.

# Input file names and the output file each one is converted into; the two
# lists are paired positionally.
in_files = [
    'original_train.json',
]
out_files = [
    'train.json',
]

parser = argparse.ArgumentParser()
parser.add_argument("-d", "--dir", default="output", type=str)
opt = parser.parse_args()

dir_path = opt.dir


for in_file, out_file in zip(in_files, out_files):
    in_file_path = os.path.join(dir_path, in_file)
    out_file_path = os.path.join(dir_path, out_file)

    print(f"{in_file_path} -> {out_file_path}")

    # Context managers guarantee both handles are closed even when a line
    # fails to parse or lacks the "relation" key.
    with open(in_file_path) as fin, open(out_file_path, 'w') as fout:
        for line in tqdm(fin):
            obj = json.loads(line)
            # Each relation object is serialised to its own JSON string.
            obj['relation'] = [json.dumps(relation) for relation in obj['relation']]
            fout.write(json.dumps(obj) + "\n")
+fi + +if [ ! -d out_clean ]; +then + echo "unziping out_clean" + unzip out_clean.zip +else + echo "out_clean has been unzipped" +fi + +# preprocess +python explore.py --data_dir ./out_clean --output_dir ./output --max_instance_num -1 + +# fewshot sampling +python stat_category.py --source_dir ./output --output_dir ./output +python partition.py --source_dir ./output --output_dir ./output +python match.py --source_dir ./output --output_dir ./output --step 100 +python rearrange_dataset.py --source_dir ./output --output_dir ./output + +# generate dataset +python noise.py --output_dir ./output --all_file rearrange_all.json +python change_data_format_for_relation.py -d ./output +ln -s ../../../ours/output ../converted_data/text2spotasoc/relation/ours diff --git a/metaretriever/dataset_processing/ours/explore.py b/metaretriever/dataset_processing/ours/explore.py new file mode 100644 index 00000000..2c5d1876 --- /dev/null +++ b/metaretriever/dataset_processing/ours/explore.py @@ -0,0 +1,323 @@ +import json +import os +import random +import argparse + +from tqdm import tqdm + +from nltk.tokenize import WordPunctTokenizer +word_tokenizer = WordPunctTokenizer() + +import numpy as np +np.set_printoptions(suppress=True) + +import pdb + +parser = argparse.ArgumentParser() +parser.add_argument("-d", "--data_dir", default="./final_data5/data_1", type=str) +parser.add_argument("-o", "--output_dir", default="./output", type=str) +parser.add_argument("-n", "--max_instance_num", default=-1, type=int) +opt = parser.parse_args() + +data_dir = opt.data_dir +output_dir = opt.output_dir +max_instance_num = opt.max_instance_num + +entity_schema_file = os.path.join(output_dir, "entity.schema") +relation_schema_file = os.path.join(output_dir, "relation.schema") +event_schema_file = os.path.join(output_dir, "event.schema") +record_schema_file = os.path.join(output_dir, "record.schema") + +all_file = os.path.join(output_dir, "all.json") +train_file = os.path.join(output_dir, 
def record2instance(record):
    """Convert one raw sentence record into a text-to-structure instance.

    Reads ``record["sentence_value"]``, ``record["sentence_entities"]`` and
    ``record["sentence_triples"]`` and builds a dict with the keys ``text``,
    ``tokens``, ``record``, ``entity``, ``relation``, ``event``, ``spot``,
    ``asoc`` and ``spot_asoc``.

    Side effects on module-level state: appends to ``text_length_list``,
    ``entity_type_list``, ``triple_type_list`` and ``relation_list``, and
    updates the ``ALL_ENTITY_CNT``, ``NOMATCH_ENTITY_CNT`` and
    ``NON_OFFSET_ENTITY_CNT`` counters.

    Returns:
        The instance dict, or ``None`` when an entity with a non-empty
        surface form has no usable token boundaries (the whole record is
        dropped in that case).
    """
    global ALL_ENTITY_CNT, NOMATCH_ENTITY_CNT, NON_OFFSET_ENTITY_CNT

    # create text field
    text = record["sentence_value"]

    # create tokens field
    tokens = word_tokenize(text)
    # Recorded before any early return, so dropped records are counted too.
    text_length_list.append(len(tokens))

    # create entity field
    instance_entity_list = []
    for entity in record["sentence_entities"]:
        entity_uri = entity["uri"]
        entity_mention = entity["surfaceform"]
        entity_type = entity["tag"]
        entity_offset = entity["boundaries_token"]

        # Normalise the special / missing tags to readable type names.
        if entity_type == "#dateTime":
            entity_type = "date time"
        elif entity_type == "#decimal":
            entity_type = "decimal"
        elif entity_type == "":
            entity_type = "other"

        if entity_mention == "":
            continue

        try:
            start_index, end_index = entity_offset[0], entity_offset[-1]
        except (LookupError, TypeError):
            # Empty or non-indexable boundaries: discard the whole record.
            # Narrowed from a bare ``except`` so that KeyboardInterrupt and
            # SystemExit are no longer swallowed here.
            NON_OFFSET_ENTITY_CNT += 1
            return None

        # Compare the token span against the tokenised surface form purely
        # for statistics; a mismatching entity is still kept.
        current_mention = " ".join(tokens[start_index:end_index+1])
        original_mention = " ".join(word_tokenize(entity_mention))
        if current_mention != original_mention:
            NOMATCH_ENTITY_CNT += 1
        ALL_ENTITY_CNT += 1

        instance_entity_list.append({
            "type": entity_type,
            # Expand the boundaries into the full contiguous index range.
            "offset": list(range(start_index, end_index+1)),
            "text": entity_mention,
            "uri": entity_uri
        })

    # create spot field
    instance_entity_type_list = [i["type"] for i in instance_entity_list]
    instance_spot = list(set(instance_entity_type_list))
    entity_type_list.extend(instance_entity_type_list)

    # create relation field
    instance_relation_list = []
    for triple in record["sentence_triples"]:
        subj = triple["subject"]
        obj = triple["object"]
        predicate = triple["predicate"]
        relation_type = predicate["surfaceform"]

        # A triple is kept only when both arguments resolve to an entity of
        # this sentence; the first entity with a matching URI is used.
        head_entity = next(
            (i for i in instance_entity_list if i["uri"] == subj["uri"]), None
        )
        if head_entity is None:
            continue

        tail_entity = next(
            (i for i in instance_entity_list if i["uri"] == obj["uri"]), None
        )
        if tail_entity is None:
            continue

        triple_type_list.append(
            (head_entity["type"], relation_type, tail_entity["type"])
        )

        instance_relation_list.append({
            "type": relation_type,
            "args": [
                head_entity,
                tail_entity
            ]
        })

    # create asoc field
    instance_asoc_list = [i["type"] for i in instance_relation_list]
    instance_asoc = list(set(instance_asoc_list))
    relation_list.extend(instance_asoc_list)

    # create spot_asoc field
    instance_spot_asoc_list = []
    for entity in instance_entity_list:
        instance_spot_asoc = {
            "span": entity["text"],
            "label": entity["type"],
            "asoc": []
        }

        # Attach every relation whose head entity shares this entity's URI.
        for triple in instance_relation_list:
            if triple["args"][0]["uri"] == entity["uri"]:
                asoc_record = [triple["type"], triple["args"][1]["text"]]
                instance_spot_asoc["asoc"].append(asoc_record)

        instance_spot_asoc_list.append(instance_spot_asoc)

    # create record field
    # NOTE(review): the structure-marker literals below are plain spaces /
    # an empty string in this copy of the file; presumably they were special
    # marker tokens that got stripped -- confirm against the upstream source.
    instance_record = " "
    for instance_spot_asoc in instance_spot_asoc_list:
        instance_record += " "

        instance_record += instance_spot_asoc["label"] + " "
        instance_record += " "
        instance_record += instance_spot_asoc["span"] + " "

        for asoc in instance_spot_asoc["asoc"]:
            instance_record += " "

            instance_record += asoc[0] + " "
            instance_record += " "
            instance_record += asoc[1] + " "

            instance_record += " "

        instance_record += " "
    instance_record += ""

    # Key order matches the serialised layout of the original instance dict.
    return {
        "text": text,
        "tokens": tokens,
        "record": instance_record,
        "entity": instance_entity_list,
        "relation": instance_relation_list,
        "event": [],
        "spot": instance_spot,
        "asoc": instance_asoc,
        "spot_asoc": instance_spot_asoc_list,
    }
+print() + +# %% stat of json string length +max_json_len = max(json_str_length_list) +min_json_len = min(json_str_length_list) + +print(f"Max json length: {max_json_len}, Min json length: {min_json_len}") + +bins = 20 + +json_hist, json_bin_edges = np.histogram(json_str_length_list, bins=bins, density=False) +print("Hist:", json_hist) +print("Edge:", json_bin_edges) + +satisfied_json_length_cnt = len([i for i in json_str_length_list if i <= 4096]) +print(f"Satisfied json length cnt: {satisfied_json_length_cnt} ({satisfied_json_length_cnt/len(json_str_length_list)})") + +print() + +# %% create schema + +entity_type_list = list(set(entity_type_list)) +relation_list = list(set(relation_list)) + +print(f"Num of entity type: {len(entity_type_list)}") +print(f"Num of relation type: {len(relation_list)}") + +record_type_list = {} +for head_entity_type, realtion_type, tail_entity_type in triple_type_list: + if record_type_list.get(head_entity_type) is None: + record_type_list[head_entity_type] = [] + record_type_list[head_entity_type].append(realtion_type) +for head_entity_type, record_relation_list in record_type_list.items(): + record_type_list[head_entity_type] = list(set(record_relation_list)) + +with open(entity_schema_file, "w") as f: + f.write(json.dumps(entity_type_list) + "\n") + f.write(json.dumps([]) + "\n") + f.write(json.dumps({}) + "\n") +print("entity.schema saved") + +with open(relation_schema_file, "w") as f: + f.write(json.dumps(relation_list) + "\n") + f.write(json.dumps(entity_type_list) + "\n") + f.write(json.dumps({i: [] for i in relation_list}) + "\n") +print("relation.schema saved") + +with open(event_schema_file, "w") as f: + f.write(json.dumps([]) + "\n") + f.write(json.dumps([]) + "\n") + f.write(json.dumps({}) + "\n") +print("event.schema saved") + +with open(record_schema_file, "w") as f: + f.write(json.dumps(entity_type_list) + "\n") + f.write(json.dumps(relation_list) + "\n") + f.write(json.dumps(record_type_list) + "\n") 
def score(x_label, y_label, add_coef=True):
    """Score how well two label collections cover each other.

    Both arguments are converted to sets. Two directional scores are
    computed from the size of their intersection:

    * ``y2x_score`` = overlap / len(x), plus ``1 / len(y)`` when ``add_coef``
    * ``x2y_score`` = overlap / len(y), plus ``1 / len(x)`` when ``add_coef``

    Args:
        x_label: iterable of hashable labels of the first instance.
        y_label: iterable of hashable labels of the second instance.
        add_coef: when true, add the reciprocal of the other side's label
            count to each directional score.

    Returns:
        ``(final_score, flag)`` where ``final_score`` is the larger of the
        two directional scores and ``flag`` is ``True`` only when
        ``x2y_score`` is strictly greater than ``y2x_score``. If either
        label set is empty there is no overlap to measure and
        ``(0.0, False)`` is returned instead of dividing by zero.
    """
    x_label = set(x_label)
    y_label = set(y_label)

    # Guard against ZeroDivisionError for instances without labels.
    if not x_label or not y_label:
        return 0.0, False

    # The intersection is shared by both directions; compute it once.
    overlap = len(x_label & y_label)

    y2x_score = overlap / len(x_label)
    if add_coef:
        y2x_score += 1 / len(y_label)
    x2y_score = overlap / len(y_label)
    if add_coef:
        x2y_score += 1 / len(x_label)

    # Ties resolve to the y2x direction (flag False).
    if x2y_score > y2x_score:
        return x2y_score, True
    return y2x_score, False
if len(instance_list) == 1: + match_group.append((instance_list[0], instance_list[0], 1.0)) + else: + # pdb.set_trace() + total_epoch = math.ceil(len(instance_list) / step) + + for epoch in tqdm(range(total_epoch), leave=False): + batch = instance_list[epoch*step:(epoch+1)*step] + + edges = [] + for i in range(len(batch)): + for j in range(i+1, len(batch)): + x_id, y_id = batch[i], batch[j] + + x_label = instance_label_dict[x_id] + y_label = instance_label_dict[y_id] + + edge_weight, _ = score(x_label, y_label) + + edges.append((x_id, y_id, edge_weight)) + + G = nx.Graph() + G.add_weighted_edges_from(edges) + + match_result = nx.max_weight_matching(G) + + for edge in match_result: + x_id, y_id = edge + x_label = instance_label_dict[x_id] + y_label = instance_label_dict[y_id] + match_score, flag = score(x_label, y_label, add_coef=False) + + if flag: + match_group.append((x_id, y_id, match_score)) + else: + match_group.append((y_id, x_id, match_score)) + +scores = [i[-1] for i in match_group] +average_score = sum(scores) / len(scores) +print(f"Average match score: {average_score}") + +print("Saving match group...") +with open(match_group_file, "w") as f: + for record in match_group: + f.write(json.dumps(record)+"\n") diff --git a/metaretriever/dataset_processing/ours/noise.py b/metaretriever/dataset_processing/ours/noise.py new file mode 100644 index 00000000..a2f98bb5 --- /dev/null +++ b/metaretriever/dataset_processing/ours/noise.py @@ -0,0 +1,242 @@ +import json +import os +import random +import argparse +from tqdm import tqdm +from copy import deepcopy +import numpy as np + +import pdb + +seed = 0 +random.seed(seed) +np.random.seed(seed) + +parser = argparse.ArgumentParser() +parser.add_argument("-o", "--output_dir", default="./output", type=str) +parser.add_argument("-a", "--all_file", default="all.json", type=str) +parser.add_argument("-n", "--noise", default=4, type=int) +opt = parser.parse_args() + +output_dir = opt.output_dir +all_file = opt.all_file +noise 
def noise_entity_type(entity_list):
    """Return deep copies of the entities, some with a resampled type.

    The candidate types are the distinct ``"type"`` values found in
    ``entity_list`` itself. For each entity one ``np.random.rand()`` draw is
    made; when it exceeds ``THRESHOLD`` the copy's ``"type"`` is replaced by
    ``random.choice`` over the candidates (which may pick the same type
    again). The input entities are never modified.
    """
    candidate_types = list(set([item["type"] for item in entity_list]))

    corrupted_entities = []
    for item in entity_list:
        corrupted = deepcopy(item)
        # One draw per entity, taken after the copy, decides whether the
        # type is resampled.
        if np.random.rand() > THRESHOLD:
            corrupted["type"] = random.choice(candidate_types)
        corrupted_entities.append(corrupted)
    return corrupted_entities
def build_entity_dict(entity_list):
    """Index entities by their ``"uri"`` value.

    Returns a dict mapping each URI to the entity object itself (not a
    copy). When several entities share a URI, the last one in
    ``entity_list`` wins.
    """
    return {item["uri"]: item for item in entity_list}
def create_spot_asoc_field(instance_entity_list, instance_triple_list):
    """Build the ``spot_asoc`` structure from entities and relation triples.

    Args:
        instance_entity_list: entity dicts with ``"text"``, ``"type"`` and
            ``"uri"`` keys.
        instance_triple_list: triple dicts with a ``"type"`` key and an
            ``"args"`` pair ``[head_entity, tail_entity]``.

    Returns:
        One dict per entity, in input order, of the form
        ``{"span": text, "label": type, "asoc": [[relation_type, tail_text], ...]}``
        where ``asoc`` lists, in triple order, every triple whose head
        entity has the same URI as the entity.
    """
    # Group the triples by head URI in a single pass so that the triple
    # list is not rescanned once per entity.
    asoc_by_head_uri = {}
    for triple in instance_triple_list:
        head, tail = triple["args"][0], triple["args"][1]
        asoc_by_head_uri.setdefault(head["uri"], []).append(
            (triple["type"], tail["text"])
        )

    return [
        {
            "span": entity["text"],
            "label": entity["type"],
            # Fresh lists per entity: entities sharing a URI must not share
            # mutable asoc records.
            "asoc": [
                [relation_type, tail_text]
                for relation_type, tail_text in asoc_by_head_uri.get(entity["uri"], ())
            ],
        }
        for entity in instance_entity_list
    ]
noised_entity_list = noise_entity_type(noised_entity_list) + + noised_entity_dict = build_entity_dict(noised_entity_list) + + # noise triple + noised_triple_list = update_relation_triple_by_noised_entity(triple_list, noised_entity_dict) + + noised_triple_list = noise_relation_type(noised_triple_list) + noised_triple_list = noise_triple_num(noised_triple_list, noised_entity_list) + + # create noised record + noised_spot_asoc_list = create_spot_asoc_field(noised_entity_list, noised_triple_list) + noised_record = create_record_field(noised_spot_asoc_list) + noised_record_list.append(noised_record) + + # remove uir field + for entity in entity_list: + del entity["uri"] + + instance["noised_record"] = noised_record_list + + json_str = json.dumps(instance) + tgt.write(json_str + "\n") + +# %% create train/dev/test data + +with open(noised_all_file) as all, open(train_file, "w") as train, open(dev_file, "w") as dev, open(test_file, "w") as test: + for i, line in tqdm(enumerate(all)): + train.write(line) +print("train/dev/test saved.") diff --git a/metaretriever/dataset_processing/ours/output/.placeholder b/metaretriever/dataset_processing/ours/output/.placeholder new file mode 100644 index 00000000..e69de29b diff --git a/metaretriever/dataset_processing/ours/partition.py b/metaretriever/dataset_processing/ours/partition.py new file mode 100644 index 00000000..68b07593 --- /dev/null +++ b/metaretriever/dataset_processing/ours/partition.py @@ -0,0 +1,96 @@ +import json +import os +import random +import argparse + +from collections import OrderedDict + +from tqdm import tqdm + +import pdb + +parser = argparse.ArgumentParser() +parser.add_argument("-s", "--source_dir", default="./", type=str) +parser.add_argument("-o", "--output_dir", default="./", type=str) +opt = parser.parse_args() + +source_dir = opt.source_dir +output_dir = opt.output_dir + +all_file = os.path.join(source_dir, "all.json") + +entity_stat_file = os.path.join(output_dir, "entity_stat.json") 
def get_visited_type(instance_id_list, instance_type_dict):
    """Return the types shared by every instance in ``instance_id_list``.

    ``instance_type_dict`` maps an instance id to the set of types it
    belongs to. The result is a new set holding the intersection of those
    sets over all listed ids; an empty id list yields an empty set. The
    sets stored in ``instance_type_dict`` are not modified.
    """
    if not instance_id_list:
        return set()

    first_id, *remaining_ids = instance_id_list
    # Union with an empty set copies the first instance's types.
    shared_types = set() | instance_type_dict[first_id]
    for other_id in remaining_ids:
        shared_types = shared_types & instance_type_dict[other_id]
    return shared_types
# Command-line options: directory holding all.json, and directory holding
# match_group.json (input) / rearrange_all.json (output).
parser = argparse.ArgumentParser()
parser.add_argument("-s", "--source_dir", default="./", type=str)
parser.add_argument("-o", "--output_dir", default="./", type=str)
opt = parser.parse_args()

source_dir = opt.source_dir
output_dir = opt.output_dir

all_file = os.path.join(source_dir, "all.json")
match_group_file = os.path.join(output_dir, "match_group.json")
rearrange_all_file = os.path.join(output_dir, "rearrange_all.json")

# %%

# Each line of the match-group file is a JSON triple; only the first two
# items (support id, query id) are used below.
print("Loading match group...")
with open(match_group_file) as f:
    match_group = [json.loads(line) for line in tqdm(f)]

# %%

# Raw lines are kept as-is (newline included) so they can be copied verbatim.
print("Loading instance...")
with open(all_file) as f:
    instance_list = [line for line in tqdm(f)]

# %%

print("Rearrange dataset...")
with open(rearrange_all_file, "w") as f:
    for support_id, query_id, _ in tqdm(match_group):
        # Both lookups happen before anything is written for this pair.
        f.writelines((instance_list[support_id], instance_list[query_id]))
# %% read instance dict

print("Reading metainfo...")
with open(metainfo_file) as f:
    metainfo = json.load(f)
print("Metainfo read.")

# Map every entity type to the raw JSON lines stored in its
# "<type_id>.json" file under the entity category directory.
print("Loading entity instance dict...")
entity_type_instance_dict = {}
for current_metainfo in tqdm(metainfo["entity"]):
    entity_type = current_metainfo["entity_type"]
    type_id = current_metainfo["type_id"]

    entity_type_file = os.path.join(target_entity_category_dir, f"{type_id}.json")
    with open(entity_type_file) as f:
        entity_type_instance_dict[entity_type] = f.readlines()
entity_type_list = list(entity_type_instance_dict)
print("Entity instance dict loaded")

# Same layout for relation types, read from the relation category directory.
print("Loading relation instance dict...")
relation_type_instance_dict = {}
for current_metainfo in tqdm(metainfo["relation"]):
    relation_type = current_metainfo["relation_type"]
    type_id = current_metainfo["type_id"]

    relation_type_file = os.path.join(target_relation_category_dir, f"{type_id}.json")
    with open(relation_type_file) as f:
        relation_type_instance_dict[relation_type] = f.readlines()
relation_type_list = list(relation_type_instance_dict)
print("Relation instance dict loaded.")
parser = argparse.ArgumentParser()
parser.add_argument("-d", "--data_dir", default="./final_data5/data_1", type=str)
parser.add_argument("-s", "--source_dir", default="./output", type=str)
parser.add_argument("-o", "--output_dir", default="./output_fewshot", type=str)
opt = parser.parse_args()

data_dir = opt.data_dir
source_dir = opt.source_dir
output_dir = opt.output_dir

all_file = os.path.join(source_dir, "all.json")

entity_stat_file = os.path.join(output_dir, "entity_stat.json")
relation_stat_file = os.path.join(output_dir, "relation_stat.json")

# %% read data and stat

# type name -> {"type_id": rank of first appearance,
#               "instance_id_list": ids of the records mentioning the type}
entity_stat_dict = {}
relation_stat_dict = {}

# Running id of the records actually parsed (blank lines are not counted).
record_cnt = 0
# The handle is not called ``all`` so the builtin stays usable.
with open(all_file) as all_records:
    for line in tqdm(all_records):
        # Lines read from a file keep their trailing newline, so a blank
        # line never has length 0; strip before testing for emptiness,
        # otherwise json.loads() fails on it.
        if not line.strip():
            continue

        record = json.loads(line)

        entity_type_list = record["spot"]
        relation_type_list = record["asoc"]

        for entity_type in entity_type_list:
            if entity_type not in entity_stat_dict:
                entity_stat_dict[entity_type] = {
                    "type_id": len(entity_stat_dict),
                    "instance_id_list": []
                }
            entity_stat_dict[entity_type]["instance_id_list"].append(record_cnt)

        for relation_type in relation_type_list:
            if relation_type not in relation_stat_dict:
                relation_stat_dict[relation_type] = {
                    "type_id": len(relation_stat_dict),
                    "instance_id_list": []
                }
            relation_stat_dict[relation_type]["instance_id_list"].append(record_cnt)

        record_cnt += 1

# Each statistic is written as a single JSON object.
with open(entity_stat_file, "w") as f:
    json.dump(entity_stat_dict, f)
with open(relation_stat_file, "w") as f:
    json.dump(relation_stat_dict, f)
parser = argparse.ArgumentParser()
parser.add_argument("-d", "--data_dir", default="./final_data5/data_1", type=str)
# ``opt.output_dir`` is read just below, so the option must be declared;
# without it argparse's Namespace has no such attribute and the script
# stops with AttributeError before doing any work. The default mirrors the
# one used by sample_task.py, which writes the sampled_task.json read here.
parser.add_argument("-o", "--output_dir", default="./output_fewshot", type=str)
opt = parser.parse_args()

data_dir = opt.data_dir
output_dir = opt.output_dir

# Input: one sampled task (JSON) per line.
task_file = os.path.join(output_dir, "sampled_task.json")

# Output: the flattened support/query instances, one JSON object per line.
sampled_all_file = os.path.join(output_dir, "sampled_all.json")
def create_record_field(instance_spot_asoc_list, *, start_marker="<extra_id_0>",
                        end_marker="<extra_id_1>", target_marker="<extra_id_5>"):
    """Linearize spot/asoc structures into one bracketed record string.

    The layout, with tokens separated by single spaces, is::

        START (START label TARGET span (START asoc_label TARGET asoc_span END)* END)* END

    NOTE(review): in the text under review the marker literals are blank or
    empty (one statement appends ``""``), which leaves the nesting
    unrecoverable from the output. The defaults here assume the T5 sentinel
    tokens used by UIE-style structured records -- confirm against the
    repository before merging.

    Args:
        instance_spot_asoc_list: list of dicts with keys ``"label"``,
            ``"span"`` and ``"asoc"``; ``"asoc"`` is a list of
            ``[asoc_label, asoc_span]`` pairs.
        start_marker: token opening the record, each spot and each asoc.
        end_marker: token closing the record, each spot and each asoc.
        target_marker: token separating a label from its span text.

    Returns:
        The linearized record string.
    """
    # Collect tokens and join once instead of growing a string with ``+=``.
    tokens = [start_marker]
    for spot_asoc in instance_spot_asoc_list:
        tokens += [start_marker, spot_asoc["label"], target_marker, spot_asoc["span"]]
        for asoc in spot_asoc["asoc"]:
            tokens += [start_marker, asoc[0], target_marker, asoc[1], end_marker]
        tokens.append(end_marker)
    tokens.append(end_marker)
    return " ".join(tokens)
def filter_entity_by_triple_list(entity_list, filtered_triple_list):
    """Collect the argument entities of the kept triples, one per URI.

    Args:
        entity_list: not used by this function; kept for call compatibility.
        filtered_triple_list: triples whose ``"args"`` entry is a
            ``[head, tail]`` pair of entity dicts carrying a ``"uri"`` key.

    Returns:
        Entity dicts in first-seen order (head before tail, triple by
        triple), keeping only the first entity met for each URI.
    """
    seen_uris = set()
    unique_entities = []
    for triple in filtered_triple_list:
        head, tail = triple["args"]
        for candidate in (head, tail):
            uri = candidate["uri"]
            if uri in seen_uris:
                continue
            seen_uris.add(uri)
            unique_entities.append(candidate)
    return unique_entities
# %% read task

# Keep the raw task lines; each one is parsed when it is converted below.
print("Reading task...")
with open(task_file) as f:
    task_list = [line for line in tqdm(f)]
print("Task read.")

# %% write to sampled all

print("Changing task format...")
with open(sampled_all_file, "w") as f:
    for task_line in tqdm(task_list):
        task = json.loads(task_line)

        target_entity_type_list = task["target_entity_type_list"]
        target_relation_type_list = task["target_relation_type_list"]
        support = task["support"]
        query = task["query"]

        # Restrict every support / query instance to the task's target types.
        support_instance_list = [
            create_instance(instance_line, target_entity_type_list, target_relation_type_list)
            for instance_line in support
        ]
        query_instance_list = [
            create_instance(instance_line, target_entity_type_list, target_relation_type_list)
            for instance_line in query
        ]

        # Shuffle within each group; the support block is still written
        # before the query block.
        random.shuffle(support_instance_list)
        random.shuffle(query_instance_list)
        for instance in support_instance_list + query_instance_list:
            f.write(json.dumps(instance) + "\n")
print("Task format changed.")
#!/usr/bin/env bash
# -*- coding:utf-8 -*-

# Convert each task family listed below with uie_convert.py (spotasoc
# format, config read from data_config/<family>), then run the statistics
# script over the converted output.
data_formats=(entity relation event absa)

for data_format in "${data_formats[@]}"; do
  python uie_convert.py -format spotasoc -config data_config/${data_format} -output ${data_format}
done

python scripts/data_statistics.py -data converted_data/text2spotasoc/
def count_record_in_file(filename, key):
    """Count the annotations stored in a JSON-lines data file.

    Every line is parsed as a JSON object with ``'entity'``, ``'relation'``
    and ``'event'`` lists; each event carries an ``'args'`` list.

    Args:
        filename: path of the JSON-lines file.
        key: prefix for the counter keys (the caller passes the split name).

    Returns:
        A ``Counter`` with the totals under ``"<key> entity"``,
        ``"<key> relation"``, ``"<key> event"`` and ``"<key> role"``.
        Categories whose total is zero are left out.
    """
    totals = {'entity': 0, 'relation': 0, 'event': 0, 'role': 0}
    # ``with`` closes the handle; the bare open() used before leaked it.
    with open(filename) as source:
        for line in source:
            instance = json.loads(line)
            totals['entity'] += len(instance['entity'])
            totals['relation'] += len(instance['relation'])
            totals['event'] += len(instance['event'])
            totals['role'] += sum(len(event['args']) for event in instance['event'])
    # Zero totals are dropped so the result matches what Counter.update()
    # produced when fed an empty list.
    return Counter({f"{key} {name}": total for name, total in totals.items() if total})
def walk_dir(folder_name):
    """Yield each directory below ``folder_name`` that holds a record.schema.

    Only sub-directories are examined; ``folder_name`` itself is never
    yielded, even when it contains a ``record.schema`` file.

    Args:
        folder_name: root directory to search recursively.

    Yields:
        Paths of the matching directories, in ``os.walk`` order.
    """
    for root, dirs, _files in os.walk(folder_name):
        for dir_name in dirs:
            # Build the path once; the earlier version joined it three
            # times and rebound the ``folder_name`` parameter for no use.
            candidate = os.path.join(root, dir_name)
            if os.path.exists(os.path.join(candidate, "record.schema")):
                yield candidate
def n_shot_smaple(source_filename, target_filename, record_schema,
                  spot_asoc_key='spot', num_shot=5, min_len=None, seed=None):
    """Sample up to ``num_shot`` training instances per label and save them.

    Args:
        source_filename: JSON-lines file with the full training data.
        target_filename: path the sampled instances are written to, one
            JSON object per line.
        record_schema: object exposing ``type_list``; labels that are not
            in that list are ignored.
        spot_asoc_key: instance key holding the labels to sample on.
        num_shot: number of instances drawn for each label.
        min_len: if given, instances with fewer ``'tokens'`` are skipped.
        seed: if given, seeds ``random`` and shuffles the data before
            sampling; ``None`` leaves the global random state untouched.

    Returns:
        The list of sampled instances. An instance appears once for every
        label it was drawn for, so duplicates are possible.
    """
    # ``with`` closes the source file; blank lines are skipped rather than
    # handed to json.loads().
    with open(source_filename) as source:
        train_data = [json.loads(line) for line in source if line.strip()]

    # ``is not None`` so that a seed of 0 is honoured as well.
    if seed is not None:
        random.seed(seed)
        random.shuffle(train_data)

    known_types = record_schema.type_list

    # Record, for each label, the indexes of the sentences that carry it.
    type_to_sentence_dict = defaultdict(list)
    for index, instance in enumerate(train_data):
        # The length filter depends on the instance only, so test it once
        # instead of once per label.
        if min_len is not None and len(instance['tokens']) < min_len:
            continue
        for spot in instance[spot_asoc_key]:
            if spot in known_types:
                type_to_sentence_dict[spot].append(index)

    sampled_data = list()
    for entity, candidates in type_to_sentence_dict.items():
        if len(candidates) < num_shot:
            sys.stderr.write(
                f'[WARN] {entity} in {source_filename} is less than shot num {num_shot}\n'
            )
            sampled = candidates
        else:
            sampled = random.sample(candidates, num_shot)

        sampled_data += [train_data[index] for index in sampled]

    with open(target_filename, 'w') as output:
        for instance in sampled_data:
            output.write(json.dumps(instance) + '\n')

    return sampled_data
def main():
    """Print how a tokenizer splits a dataset's schema labels and samples.

    Loads the tokenizer named by ``--model``, reads
    ``<data>/<schema>.schema`` and the first ten lines of
    ``<data>/val.json``, and prints three tables: tokenized schema type
    names, tokenized instance fields, and the two special symbols.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-m', '--model', default='t5-base')
    parser.add_argument('-d', '--data', required=True)
    parser.add_argument('-s', '--schema', default='event')
    options = parser.parse_args()

    if "chinese_t5_pegasus" in options.model:
        # NOTE(review): nothing visible in this script imports
        # T5PegasusTokenizer, so this branch used to fail with a bare
        # NameError. The exception type is kept; the message now says why.
        tokenizer_class = globals().get("T5PegasusTokenizer")
        if tokenizer_class is None:
            raise NameError(
                "T5PegasusTokenizer is not defined in this script; "
                "import it before using a chinese_t5_pegasus model"
            )
        tokenizer = tokenizer_class.from_pretrained(options.model)
        tokenizer.bos_token = tokenizer.cls_token
        tokenizer.eos_token = tokenizer.sep_token
    else:
        tokenizer = AutoTokenizer.from_pretrained(
            options.model,
            use_fast=False,
            mirror='tuna',
        )

    # NOTE(review): the two literals were empty strings in the text under
    # review, which registers nothing useful. The T5 sentinel tokens are
    # assumed here -- confirm against the repository.
    special_symbols = ["<extra_id_0>", "<extra_id_1>"]
    tokenizer.add_special_tokens(
        {"additional_special_tokens": special_symbols}
    )

    folder_path = options.data

    schema_file = f"{folder_path}/{options.schema}.schema"

    event_schema = RecordSchema.read_from_file(schema_file)

    table = list()
    for typename in event_schema.type_list:
        typename = typename.replace('_', ' ')
        after_tokenzied = tokenizer.encode(typename, add_special_tokens=False)
        table += [[typename,
                   after_tokenzied,
                   tokenizer.convert_ids_to_tokens(after_tokenzied)]]

    print(tokenizer)
    print(type(tokenizer))

    print("===============Event Schema=================")
    print(tabulate.tabulate(
        table,
        headers=['type', 'token id', 'tokenized'],
        tablefmt='grid',
    ))

    print("===============Instance=================")

    table = list()
    # Read only the first ten lines, and close the file afterwards.
    with open(folder_path + "/val.json") as val_file:
        for index, line in enumerate(val_file):
            if index >= 10:
                break
            instance = json.loads(line)
            table += [["Text %s" % index] + [instance['text']]]
            table += [["Token %s" % index] +
                      ['|'.join(tokenizer.tokenize(instance['text']))]]
            # NOTE(review): tokenize() is given the field value as-is;
            # verify these fields are strings in the data this is run on.
            if 'entity' in instance:
                # The entity row previously tokenized instance['event'].
                table += [["Entity %s" % index] +
                          ['|'.join(tokenizer.tokenize(instance['entity']))]]
            if 'relation' in instance:
                table += [["Relation %s" % index] +
                          ['|'.join(tokenizer.tokenize(instance['relation']))]]
            if 'event' in instance:
                table += [["Event %s" % index] +
                          ['|'.join(tokenizer.tokenize(instance['event']))]]
    print(tabulate.tabulate(table, headers=['text', 'event'], tablefmt='grid'))

    print("===============Specical Symbol=================")
    table = list()

    for name in special_symbols:
        table += [[name, tokenizer.encode(name), tokenizer.tokenize(name)]]
    print(tabulate.tabulate(
        table,
        headers=['specical symbol', 'token id', 'tokenized'],
        tablefmt='grid'
    ))
def convert_to_oneie(output_folder: str, datasets: Dict[str, List[Sentence]]):
    """Write each data split as a ``<split>.oneie.json`` JSON-lines file.

    Every sentence becomes one JSON object with its tokens, sentence id and
    the entity / relation / event mentions rebuilt from the sentence's
    annotations. ``end`` offsets are exclusive (last index + 1).

    Args:
        output_folder: directory to write into; created when missing.
        datasets: mapping from split name to the sentences of that split.
    """
    os.makedirs(output_folder, exist_ok=True)
    counter = Counter()

    for data_type, instance_list in datasets.items():
        # json.dumps(..., ensure_ascii=False) emits raw non-ASCII text, so
        # pin the file encoding instead of relying on the locale default.
        with open(
            os.path.join(output_folder, f"{data_type}.oneie.json"),
            "w",
            encoding="utf-8",
        ) as output:
            for instance in tqdm(instance_list):
                counter.update([f"{data_type} sent"])
                entity_mentions = [
                    {
                        "id": entity.record_id,
                        "entity_type": str(entity.label),
                        "text": entity.span.text,
                        "start": entity.span.indexes[0],
                        "end": entity.span.indexes[-1] + 1,
                    }
                    for entity in instance.entities
                ]
                relation_mentions = [
                    {
                        "id": relation.record_id,
                        "relation_type": str(relation.label),
                        "argument": [
                            {
                                "entity_id": relation.arg1.record_id,
                                "text": relation.arg1.span.text,
                                "role": "Arg-1",
                            },
                            {
                                "entity_id": relation.arg2.record_id,
                                "text": relation.arg2.span.text,
                                "role": "Arg-2",
                            },
                        ],
                    }
                    for relation in instance.relations
                ]
                event_mentions = [
                    {
                        "id": event.record_id,
                        "event_type": str(event.label),
                        "trigger": {
                            "text": event.span.text,
                            "start": event.span.indexes[0],
                            "end": event.span.indexes[-1] + 1,
                        },
                        # Each event argument is a (role, mention) pair.
                        "argument": [
                            {
                                "id": arg[1].record_id,
                                "text": arg[1].span.text,
                                "role": str(arg[0]),
                            }
                            for arg in event.args
                        ],
                    }
                    for event in instance.events
                ]

                instance_dict = {
                    "tokens": instance.tokens,
                    "sent_id": instance.text_id,
                    "entity_mentions": entity_mentions,
                    "relation_mentions": relation_mentions,
                    "event_mentions": event_mentions,
                }
                instance_str = json.dumps(instance_dict, ensure_ascii=False)
                output.write(f"{instance_str}\n")

    print(counter)
    print(output_folder)
    print("==========================")
def load_dataset(self):
    """Load every configured split with the dataset's task-format loader.

    Returns:
        dict: split name -> whatever ``data_class.load_from_file`` returns
        for the file ``path/filename`` of that split.
    """
    # ``other`` is declared with a default of None; unpacking None with ``**``
    # raises TypeError, so fall back to "no extra keyword arguments".
    extra_kwargs = self.other or {}
    datasets = {}
    for split_name, filename in self.split_dict.items():
        datasets[split_name] = self.data_class.load_from_file(
            filename=os.path.join(self.path, filename),
            language=self.language,
            **extra_kwargs,
        )
    return datasets
class GenerationFormat:
    """Base class for converters from IE records to a text generation format.

    Subclasses implement the ``annotate_*`` methods; this class keeps the
    label mapping, collects the record/role schema seen in the data and
    exports schema files.
    """

    # NOTE(review): ``__metaclass__`` is the Python 2 spelling and is ignored
    # by Python 3, so the ``abc.abstractmethod`` markers below are not enforced
    # at instantiation time. Left unchanged to keep existing subclasses working.
    __metaclass__ = abc.ABCMeta

    def __init__(self,
                 structure_maker: StructureMarker,
                 label_mapper: Dict = None,
                 language: str = 'en') -> None:
        self.structure_maker = structure_maker
        self.language = language
        self.label_mapper = {} if label_mapper is None else label_mapper

        # Record label -> set of role labels, collected from the data so the
        # schema can be exported afterwards.
        self.record_role_map = defaultdict(set)

    def get_label_str(self, label: Label):
        """Return the mapped name of ``label`` (its repr when unmapped)."""
        label_repr = label.__repr__()
        return self.label_mapper.get(label_repr, label_repr)

    @abc.abstractmethod
    def annotate_entities(
            self, tokens: List[str], entities: List[Entity]): pass

    @abc.abstractmethod
    def annotate_given_entities(self, tokens: List[str], entities: Union[List[Entity], Entity]): pass

    @abc.abstractmethod
    def annotate_events(self, tokens: List[str], events: List[Event]): pass

    @abc.abstractmethod
    def annotate_event_given_predicate(self, tokens: List[str], event: Event): pass

    @abc.abstractmethod
    def annotate_relation_extraction(self, tokens: List[str],
                                     relations: List[Relation]): pass

    @staticmethod
    def _schema_from_role_map(record_role_map):
        """Build a RecordSchema from a ``{record: iterable of roles}`` mapping."""
        role_set = set()
        type_role_dict = {}
        for record, roles in record_role_map.items():
            role_set.update(roles)
            type_role_dict[record] = list(roles)
        return RecordSchema(
            type_list=list(type_role_dict.keys()),
            role_list=list(role_set),
            type_role_dict=type_role_dict
        )

    def output_schema(self, filename: str):
        """Export the schema collected so far to ``filename``.

        The schema file has three lines:
        - line 1: list of record type names
        - line 2: list of role names
        - line 3: record -> roles mapping

        ``self.record_role_map`` is only read; its sets are copied into lists
        for serialization, so annotation can continue after an export.

        Args:
            filename (str): path of the schema file to write.
        """
        record_schema = self._schema_from_role_map(self.record_role_map)
        record_schema.write_to_file(filename)

    def get_entity_schema(self, entities: List[Entity]):
        """Return a schema whose types are the entity labels in ``entities``."""
        entity_types = set()
        for entity in entities:
            entity_types.add(self.get_label_str(entity.label))
        return RecordSchema(
            type_list=list(entity_types),
            role_list=list(),
            type_role_dict=dict()
        )

    def get_relation_schema(self, relations: List[Relation]):
        """Return a schema mapping each relation label to its argument labels."""
        record_role_map = defaultdict(set)

        for relation in relations:
            relation_label = self.get_label_str(relation.label)
            record_role_map[relation_label].add(self.get_label_str(relation.arg1.label))
            record_role_map[relation_label].add(self.get_label_str(relation.arg2.label))

        # Build from the LOCAL map; reading ``self.record_role_map`` here would
        # return unrelated (or empty) roles and add keys to the instance map.
        return self._schema_from_role_map(record_role_map)

    def get_event_schema(self, events: List[Event]):
        """Return a schema mapping each event label to its argument role labels."""
        record_role_map = defaultdict(set)

        for event in events:
            event_label = self.get_label_str(event.label)
            for role, _ in event.args:
                record_role_map[event_label].add(self.get_label_str(role))

        # Same as above: the roles come from the events passed in, not from
        # the instance-wide ``self.record_role_map``.
        return self._schema_from_role_map(record_role_map)
def convert_spot_asoc(spot_asoc_instance, structure_maker):
    """Linearize a spot-association structure into the target string.

    Args:
        spot_asoc_instance: iterable of dicts with keys ``'label'`` and
            ``'span'`` and an optional ``'asoc'`` list of
            ``(asoc_label, asoc_span)`` pairs.
        structure_maker: object providing the marker strings ``sent_start``,
            ``sent_end``, ``record_start``, ``record_end``, ``span_start``,
            ``span_end`` and ``target_span_start``.

    Returns:
        str: all pieces joined by single spaces, e.g.
        ``{ [ Person : Trump ( Visit : China ) ] }`` with visual markers.
    """
    def render_asoc(asoc_label, asoc_span):
        # One association: <span_start> label <target_span_start> span <span_end>
        return ' '.join((
            structure_maker.span_start,
            asoc_label,
            structure_maker.target_span_start,
            asoc_span,
            structure_maker.span_end,
        ))

    def render_spot(spot):
        # One record: the spot itself followed by its associations.
        pieces = [
            spot['label'],
            structure_maker.target_span_start,
            spot['span'],
        ]
        for asoc_label, asoc_span in spot.get('asoc', list()):
            pieces.append(render_asoc(asoc_label, asoc_span))
        return ' '.join((
            structure_maker.record_start,
            ' '.join(pieces),
            structure_maker.record_end,
        ))

    rendered_records = [render_spot(spot) for spot in spot_asoc_instance]
    return ' '.join((
        structure_maker.sent_start,
        ' '.join(rendered_records),
        structure_maker.sent_end,
    ))
def annotate_given_entities(self, tokens: List[str], entities):
    """Build source/target text with the given entities marked in the source.

    When ``entities`` is a collection of entities, every entity is prepended
    to the sentence:
        tokens:   ['Trump', 'visits', 'China', '.']
        entities: [Trump, China]
        source:   ( Trump ) ( China ) : Trump visits China .
        target:   { [ Person : Trump ] [ Geo-political : China ] }

    When ``entities`` is a single ``Entity``, its span is marked in place:
        tokens:   ['Trump', 'visits', 'China', '.']
        entities: Trump
        source:   < Trump > visits China .
        target:   { [ Person : Trump ] }

    Args:
        tokens: tokens of the sentence.
        entities: a single ``Entity`` or any iterable of ``Entity``.

    Returns:
        tuple: ``(source_text, target_text)``.

    Raises:
        TypeError: if ``entities`` is neither an ``Entity`` nor iterable.
    """
    if isinstance(entities, Entity):
        marked_tokens = self.augment_source_span(tokens=tokens, span=entities.span)
        source_text = tokens_to_str(marked_tokens, language=self.language)
        _, target_text = self.annonote_graph(tokens=tokens, entities=[entities])[:2]
        return source_text, target_text

    # Anything else is treated as a collection of entities. Materializing it
    # lets tuples, dict views and generators work, where the previous
    # ``isinstance(entities, list)`` test fell through to an UnboundLocalError.
    entity_list = list(entities)
    entitytokens = []
    for entity in entity_list:
        entitytokens += [self.structure_maker.span_start]
        entitytokens += entity.span.tokens
        entitytokens += [self.structure_maker.span_end]
    source_text = tokens_to_str(
        entitytokens + [self.structure_maker.sep_marker] + tokens,
        language=self.language,
    )
    _, target_text = self.annonote_graph(tokens=tokens, entities=entity_list)[:2]
    return source_text, target_text
def annonote_graph(self,
                   tokens: List[str],
                   entities: List[Entity] = None,
                   relations: List[Relation] = None,
                   events: List[Event] = None):
    """Convert entities, relations and events to a spot-association graph.

    Every entity, every relation head (``arg1``) and every event becomes a
    spot; relation tails and event arguments become associations of their
    spot. Labels seen here are also recorded in ``self.record_role_map``.

    Args:
        tokens (List[str]): Token list of the sentence.
        entities (List[Entity], optional): Entity list. Defaults to none.
        relations (List[Relation], optional): Relation list. Defaults to none.
        events (List[Event], optional): Event list. Defaults to none.

    Returns:
        tuple: ``(source_text, target_text, spot_labels, asoc_labels,
        spot_asoc_instance)`` where
        - ``source_text`` is the sentence rendered by ``tokens_to_str``,
        - ``target_text`` is the linearized structure, e.g.
          { [ Person : Trump ( Visit : China ) ]
            [ Visit : visits ( Person : Trump ) ( Location : China ) ]
            [ Geo-political : China ] }
        - ``spot_labels`` / ``asoc_labels`` are sets of label strings,
        - ``spot_asoc_instance`` is the list of spot dicts that was linearized.
    """
    # None instead of ``[]`` as default avoids shared mutable default values.
    entities = [] if entities is None else entities
    relations = [] if relations is None else relations
    events = [] if events is None else events

    spot_dict = dict()
    asoc_dict = defaultdict(list)

    def add_spot(spot):
        spot_key = (tuple(spot.span.indexes), self.get_label_str(spot.label))
        spot_dict[spot_key] = spot

        if self.get_label_str(spot.label) not in self.record_role_map:
            self.record_role_map[self.get_label_str(spot.label)] = set()

    def add_asoc(spot, asoc: Label, tail):
        spot_key = (tuple(spot.span.indexes), self.get_label_str(spot.label))
        asoc_dict[spot_key] += [(tail.span.indexes, tail, self.get_label_str(asoc))]

        self.record_role_map[self.get_label_str(spot.label)].add(self.get_label_str(asoc))

    for entity in entities:
        add_spot(spot=entity)

    for relation in relations:
        add_spot(spot=relation.arg1)
        add_asoc(spot=relation.arg1, asoc=relation.label, tail=relation.arg2)

    for event in events:
        add_spot(spot=event)
        for arg_role, argument in event.args:
            add_asoc(spot=event, asoc=arg_role, tail=argument)

    spot_asoc_instance = list()
    # Spots are emitted ordered by (token indexes, label).
    for spot_key in sorted(spot_dict.keys()):
        _, label = spot_key
        spot = spot_dict[spot_key]

        # Spots without a text span cannot be rendered; skip them.
        if spot.span.is_empty_span():
            continue

        spot_instance = {'span': spot.span.text,
                         'label': label,
                         'asoc': list(),
                         }
        # Sorted on the full (indexes, tail, label) tuples, exactly as stored.
        for _, tail, asoc in sorted(asoc_dict.get(spot_key, [])):

            if tail.span.is_empty_span():
                continue

            spot_instance['asoc'] += [(asoc, tail.span.text)]
        spot_asoc_instance += [spot_instance]

    target_text = convert_spot_asoc(
        spot_asoc_instance,
        structure_maker=self.structure_maker,
    )

    source_text = tokens_to_str(tokens, language=self.language)
    spot_labels = {label for _, label in spot_dict}
    asoc_labels = {
        asoc
        for asoc_list in asoc_dict.values()
        for _, _, asoc in asoc_list
    }
    return source_text, target_text, spot_labels, asoc_labels, spot_asoc_instance
class Label:
    """Name of a label; orders by ``label_name`` against other labels."""

    def __init__(self, label_name: Union[str, List[str]]) -> None:
        self.label_name = label_name

    def __repr__(self) -> str:
        # The representation is the stored name itself, unchanged.
        return self.label_name

    def __lt__(self, other):
        # Only labels are comparable with each other; anything else is left
        # to Python's fallback handling via NotImplemented.
        if isinstance(other, Label):
            return self.label_name < other.label_name
        return NotImplemented
# Event, Trigger-Multi-Argument
# An n-ary (predicate-argument) structure centred on a trigger span.
class Event(Record):
    def __init__(self,
                 span: Span,
                 label: Label,
                 args: List[Tuple[Label, Entity]],
                 text_id: Union[str, None] = None,
                 record_id: Union[str, None] = None,
                 ) -> None:
        super().__init__(text_id=text_id, record_id=record_id)
        self.span = span
        self.label = label
        self.args = args

    def __repr__(self) -> str:
        trigger_repr = self.span.__repr__()
        label_repr = self.label.__repr__()
        return trigger_repr + label_repr

    def to_offset(self, evt_label_mapper=None):
        """Return the event as an offset dict, or ``{}`` for an empty trigger.

        Arguments whose span is empty are left out. Both the event type and
        the argument roles are renamed through ``evt_label_mapper``.
        """
        # An event without a trigger span is skipped entirely.
        if self.span.is_empty_span():
            return {}

        argument_offsets = [
            {
                'type': change_name_using_label_mapper(
                    role.label_name,
                    evt_label_mapper,
                ),
                'offset': argument.span.indexes,
                'text': argument.span.text
            }
            for role, argument in self.args
            if not argument.span.is_empty_span()
        ]

        event_type = change_name_using_label_mapper(self.label.label_name,
                                                    evt_label_mapper)
        return {'type': event_type,
                'offset': self.span.indexes,
                'text': self.span.text,
                'args': argument_offsets}
class RecordSchema:
    """Schema of record types, role names and the roles of each type.

    On disk a schema is three JSON lines: the type list, the role list and
    the type -> roles mapping.
    """

    def __init__(self, type_list, role_list, type_role_dict):
        self.type_list = type_list
        self.role_list = role_list
        self.type_role_dict = type_role_dict

    @staticmethod
    def read_from_file(filename):
        """Load a schema from the three-line JSON file ``filename``."""
        # The context manager closes the handle; UTF-8 matches write_to_file.
        with open(filename, encoding='utf-8') as schema_file:
            lines = schema_file.readlines()
        type_list = json.loads(lines[0])
        role_list = json.loads(lines[1])
        type_role_dict = json.loads(lines[2])
        return RecordSchema(type_list, role_list, type_role_dict)

    def write_to_file(self, filename):
        """Write the schema to ``filename`` as three JSON lines."""
        # ensure_ascii=False keeps non-ASCII label names readable, so the
        # encoding is pinned rather than taken from the locale.
        with open(filename, 'w', encoding='utf-8') as output:
            output.write(json.dumps(self.type_list, ensure_ascii=False) + '\n')
            output.write(json.dumps(self.role_list, ensure_ascii=False) + '\n')
            output.write(json.dumps(self.type_role_dict, ensure_ascii=False) + '\n')
class ABSA(TaskFormat):
    """ Aspect-Based Sentiment Analysis Data format at https://github.com/yhcc/BARTABSA."""

    def __init__(self, sentence_json, language='en'):
        super().__init__(
            language=language
        )
        words = sentence_json['words']
        # Check before touching the tokens: iterating None would raise first.
        if words is None:
            print('[sentence without tokens]:', sentence_json)
            raise SystemExit(1)
        self.tokens = [change_ptb_token_back(word) for word in words]
        self.aspects = sentence_json['aspects']
        self.opinions = sentence_json['opinions']

    def _get_or_create_entity(self, entities, role, span):
        """Return the entity for ``(role, span)``, creating it on first use."""
        key = (role, span)
        if key not in entities:
            start, end = span
            tokens = self.tokens[start:end]
            entities[key] = Entity(
                span=Span(
                    tokens=tokens,
                    indexes=list(range(start, end)),
                    text=tokens_to_str(tokens, language=self.language),
                ),
                label=Label(role)
            )
        return entities[key]

    def generate_instance(self):
        """Build a ``Sentence`` with aspect/opinion entities and their relations.

        Aspects and opinions are paired positionally; each pair yields one
        relation labelled with the aspect's polarity.
        """
        # Keyed by (role, span) so a repeated span reuses its entity and an
        # aspect and an opinion sharing a span do not overwrite each other.
        entities = dict()
        relations = list()

        for aspect, opinion in zip(self.aspects, self.opinions):
            aspect_entity = self._get_or_create_entity(
                entities, 'aspect', (aspect['from'], aspect['to'])
            )
            opinion_entity = self._get_or_create_entity(
                entities, 'opinion', (opinion['from'], opinion['to'])
            )

            relations += [Relation(
                arg1=aspect_entity,
                arg2=opinion_entity,
                label=Label(aspect['polarity']),
            )]

        return Sentence(
            tokens=self.tokens,
            # A real list rather than a dict view.
            entities=list(entities.values()),
            relations=relations,
        )

    @staticmethod
    def load_from_file(filename, language='en') -> List[Sentence]:
        """Load a JSON list of raw sentences and convert each to a ``Sentence``."""
        with open(filename) as raw_file:
            raw_instance_list = json.load(raw_file)
        print(f"{filename}: {len(raw_instance_list)}")
        return [
            ABSA(sentence_json=raw_instance, language=language).generate_instance()
            for raw_instance in raw_instance_list
        ]
class CASIE(TaskFormat):
    """Sentence-level view of a CASIE document (entities and events)."""

    def __init__(self, sentence_dict, language="en"):
        super().__init__(language=language)
        self.sent_id = sentence_dict["sent_id"]
        self.tokens = sentence_dict["tokens"]
        self.entities = sentence_dict["entity_mentions"]
        self.events = sentence_dict["event_mentions"]

    def _make_span(self, indexes):
        """Build the ``Span`` covering the token positions ``indexes``."""
        tokens = [self.tokens[index] for index in indexes]
        return Span(
            tokens=tokens,
            indexes=indexes,
            text=tokens_to_str(tokens, language=self.language),
            text_id=self.sent_id,
        )

    def generate_instance(self):
        """Convert the raw mention dicts of this sentence into a ``Sentence``."""
        # Keyed by mention id, so a mention listed more than once is kept once.
        entities = {}
        events = {}

        for entity in self.entities:
            entities[entity["id"]] = Entity(
                span=self._make_span(entity["indexes"]),
                label=Label(entity["type"]),
                text_id=self.sent_id,
                record_id=entity["id"],
            )

        for event in self.events:
            events[event["id"]] = Event(
                span=self._make_span(event["trigger"]["indexes"]),
                label=Label(event["type"]),
                args=[
                    (Label(x["role"]), entities[x["id"]])
                    for x in event["arguments"]
                ],
                text_id=self.sent_id,
                record_id=event["id"],
            )

        return Sentence(
            tokens=self.tokens,
            entities=list(entities.values()),
            events=list(events.values()),
            text_id=self.sent_id,
        )

    @staticmethod
    def load_from_file(filename, language="en") -> List[Sentence]:
        """Read a JSON-lines CASIE file and return one ``Sentence`` per sentence.

        Mentions are grouped by the sentence index stored as the first item
        of their first token. An argument is attached to its event only when
        it lies in the same sentence as the trigger; other arguments are
        only counted as ``cross_sentence_cnt``.
        """
        sentence_list = []
        counter = Counter()

        with open(filename) as fin:
            for line in fin:
                doc = json.loads(line.strip())

                entity_mentions = defaultdict(list)
                event_mentions = defaultdict(list)

                for event in doc["event"]:
                    for mention in event["mentions"]:
                        nugget = mention["nugget"]
                        # Each token is indexed as x[0] = sentence, x[1] = position.
                        sent_id = nugget["tokens"][0][0]

                        event_mention = {
                            "id": mention["id"],
                            "type": mention["subtype"],
                            "trigger": {"indexes": [x[1] for x in nugget["tokens"]]},
                            "arguments": [],
                        }
                        counter.update(['event mention'])

                        for argument in mention["arguments"]:
                            arg_sent_id = argument["tokens"][0][0]
                            entity_mention = {
                                "id": argument["id"],
                                "indexes": [x[1] for x in argument["tokens"]],
                                "type": argument["filler_type"],
                            }
                            entity_mentions[arg_sent_id].append(entity_mention)
                            counter.update(['entity'])
                            if arg_sent_id == sent_id:
                                # Only id and role are read back when the
                                # Event is built in generate_instance.
                                event_mention["arguments"].append(
                                    {
                                        "id": argument["id"],
                                        "role": argument["role"],
                                    }
                                )
                                counter.update(['argument'])
                            else:
                                counter.update(['cross_sentence_cnt'])

                        event_mentions[sent_id].append(event_mention)

                for sent_id, sentence in enumerate(doc["sentences"]):
                    tokens = [token["word"] for token in sentence["tokens"]]

                    sentence_dict = {
                        "sent_id": sent_id,
                        "tokens": tokens,
                        "entity_mentions": entity_mentions[sent_id],
                        "event_mentions": event_mentions[sent_id],
                    }
                    instance = CASIE(
                        sentence_dict, language=language
                    ).generate_instance()

                    sentence_list.append(instance)
                    counter.update(['sentence'])

        print(filename, counter)
        return sentence_list
"originalText": "Netflix", + }, + { + "characterOffsetBegin": 35, + "characterOffsetEnd": 47, + "word": "notification", + "originalText": "notification", + }, + { + "characterOffsetBegin": 48, + "characterOffsetEnd": 51, + "word": "has", + "originalText": "has", + }, + { + "characterOffsetBegin": 52, + "characterOffsetEnd": 56, + "word": "been", + "originalText": "been", + }, + { + "characterOffsetBegin": 57, + "characterOffsetEnd": 66, + "word": "targeting", + "originalText": "targeting", + }, + { + "characterOffsetBegin": 67, + "characterOffsetEnd": 78, + "word": "subscribers", + "originalText": "subscribers", + }, + { + "characterOffsetBegin": 79, + "characterOffsetEnd": 81, + "word": "of", + "originalText": "of", + }, + { + "characterOffsetBegin": 82, + "characterOffsetEnd": 85, + "word": "the", + "originalText": "the", + }, + { + "characterOffsetBegin": 86, + "characterOffsetEnd": 95, + "word": "streaming", + "originalText": "streaming", + }, + { + "characterOffsetBegin": 96, + "characterOffsetEnd": 103, + "word": "service", + "originalText": "service", + }, + { + "characterOffsetBegin": 103, + "characterOffsetEnd": 104, + "word": ".", + "originalText": ".", + }, + ], + "span": [0, 104], + }, + { + "tokens": [ + { + "characterOffsetBegin": 105, + "characterOffsetEnd": 108, + "word": "The", + "originalText": "The", + }, + { + "characterOffsetBegin": 109, + "characterOffsetEnd": 110, + "word": "“", + "originalText": "“", + }, + { + "characterOffsetBegin": 110, + "characterOffsetEnd": 120, + "word": "suspension", + "originalText": "suspension", + }, + { + "characterOffsetBegin": 121, + "characterOffsetEnd": 133, + "word": "notification", + "originalText": "notification", + }, + { + "characterOffsetBegin": 133, + "characterOffsetEnd": 134, + "word": "”", + "originalText": "”", + }, + { + "characterOffsetBegin": 135, + "characterOffsetEnd": 140, + "word": "looks", + "originalText": "looks", + }, + { + "characterOffsetBegin": 141, + "characterOffsetEnd": 148, + 
"word": "similar", + "originalText": "similar", + }, + { + "characterOffsetBegin": 149, + "characterOffsetEnd": 151, + "word": "in", + "originalText": "in", + }, + { + "characterOffsetBegin": 152, + "characterOffsetEnd": 158, + "word": "design", + "originalText": "design", + }, + { + "characterOffsetBegin": 159, + "characterOffsetEnd": 162, + "word": "and", + "originalText": "and", + }, + { + "characterOffsetBegin": 163, + "characterOffsetEnd": 169, + "word": "format", + "originalText": "format", + }, + { + "characterOffsetBegin": 170, + "characterOffsetEnd": 172, + "word": "to", + "originalText": "to", + }, + { + "characterOffsetBegin": 173, + "characterOffsetEnd": 178, + "word": "other", + "originalText": "other", + }, + { + "characterOffsetBegin": 179, + "characterOffsetEnd": 186, + "word": "Netflix", + "originalText": "Netflix", + }, + { + "characterOffsetBegin": 187, + "characterOffsetEnd": 193, + "word": "emails", + "originalText": "emails", + }, + { + "characterOffsetBegin": 193, + "characterOffsetEnd": 194, + "word": ".", + "originalText": ".", + }, + ], + "span": [105, 194], + }, + { + "tokens": [ + { + "characterOffsetBegin": 195, + "characterOffsetEnd": 197, + "word": "It", + "originalText": "It", + }, + { + "characterOffsetBegin": 198, + "characterOffsetEnd": 206, + "word": "notifies", + "originalText": "notifies", + }, + { + "characterOffsetBegin": 207, + "characterOffsetEnd": 210, + "word": "and", + "originalText": "and", + }, + { + "characterOffsetBegin": 211, + "characterOffsetEnd": 216, + "word": "urges", + "originalText": "urges", + }, + { + "characterOffsetBegin": 217, + "characterOffsetEnd": 222, + "word": "users", + "originalText": "users", + }, + { + "characterOffsetBegin": 223, + "characterOffsetEnd": 225, + "word": "to", + "originalText": "to", + }, + { + "characterOffsetBegin": 226, + "characterOffsetEnd": 232, + "word": "update", + "originalText": "update", + }, + { + "characterOffsetBegin": 233, + "characterOffsetEnd": 238, + "word": 
"their", + "originalText": "their", + }, + { + "characterOffsetBegin": 239, + "characterOffsetEnd": 250, + "word": "information", + "originalText": "information", + }, + { + "characterOffsetBegin": 251, + "characterOffsetEnd": 253, + "word": "to", + "originalText": "to", + }, + { + "characterOffsetBegin": 254, + "characterOffsetEnd": 259, + "word": "avoid", + "originalText": "avoid", + }, + { + "characterOffsetBegin": 260, + "characterOffsetEnd": 263, + "word": "the", + "originalText": "the", + }, + { + "characterOffsetBegin": 264, + "characterOffsetEnd": 274, + "word": "suspension", + "originalText": "suspension", + }, + { + "characterOffsetBegin": 275, + "characterOffsetEnd": 277, + "word": "of", + "originalText": "of", + }, + { + "characterOffsetBegin": 278, + "characterOffsetEnd": 283, + "word": "their", + "originalText": "their", + }, + { + "characterOffsetBegin": 284, + "characterOffsetEnd": 291, + "word": "account", + "originalText": "account", + }, + { + "characterOffsetBegin": 291, + "characterOffsetEnd": 292, + "word": ".", + "originalText": ".", + }, + ], + "span": [195, 292], + }, + { + "tokens": [ + { + "characterOffsetBegin": 293, + "characterOffsetEnd": 296, + "word": "The", + "originalText": "The", + }, + { + "characterOffsetBegin": 297, + "characterOffsetEnd": 301, + "word": "goal", + "originalText": "goal", + }, + { + "characterOffsetBegin": 302, + "characterOffsetEnd": 304, + "word": "of", + "originalText": "of", + }, + { + "characterOffsetBegin": 305, + "characterOffsetEnd": 308, + "word": "the", + "originalText": "the", + }, + { + "characterOffsetBegin": 309, + "characterOffsetEnd": 313, + "word": "scam", + "originalText": "scam", + }, + { + "characterOffsetBegin": 314, + "characterOffsetEnd": 316, + "word": "is", + "originalText": "is", + }, + { + "characterOffsetBegin": 317, + "characterOffsetEnd": 319, + "word": "to", + "originalText": "to", + }, + { + "characterOffsetBegin": 320, + "characterOffsetEnd": 325, + "word": "steal", + 
"originalText": "steal", + }, + { + "characterOffsetBegin": 326, + "characterOffsetEnd": 334, + "word": "personal", + "originalText": "personal", + }, + { + "characterOffsetBegin": 335, + "characterOffsetEnd": 338, + "word": "and", + "originalText": "and", + }, + { + "characterOffsetBegin": 339, + "characterOffsetEnd": 345, + "word": "credit", + "originalText": "credit", + }, + { + "characterOffsetBegin": 346, + "characterOffsetEnd": 350, + "word": "card", + "originalText": "card", + }, + { + "characterOffsetBegin": 351, + "characterOffsetEnd": 362, + "word": "information", + "originalText": "information", + }, + { + "characterOffsetBegin": 362, + "characterOffsetEnd": 363, + "word": ",", + "originalText": ",", + }, + { + "characterOffsetBegin": 364, + "characterOffsetEnd": 373, + "word": "according", + "originalText": "according", + }, + { + "characterOffsetBegin": 374, + "characterOffsetEnd": 376, + "word": "to", + "originalText": "to", + }, + { + "characterOffsetBegin": 377, + "characterOffsetEnd": 378, + "word": "a", + "originalText": "a", + }, + { + "characterOffsetBegin": 379, + "characterOffsetEnd": 385, + "word": "report", + "originalText": "report", + }, + { + "characterOffsetBegin": 386, + "characterOffsetEnd": 390, + "word": "from", + "originalText": "from", + }, + { + "characterOffsetBegin": 391, + "characterOffsetEnd": 400, + "word": "Mailguard", + "originalText": "Mailguard", + }, + { + "characterOffsetBegin": 400, + "characterOffsetEnd": 401, + "word": ".", + "originalText": ".", + }, + ], + "span": [293, 401], + }, + { + "tokens": [ + { + "characterOffsetBegin": 402, + "characterOffsetEnd": 405, + "word": "The", + "originalText": "The", + }, + { + "characterOffsetBegin": 406, + "characterOffsetEnd": 411, + "word": "email", + "originalText": "email", + }, + { + "characterOffsetBegin": 412, + "characterOffsetEnd": 420, + "word": "contains", + "originalText": "contains", + }, + { + "characterOffsetBegin": 421, + "characterOffsetEnd": 422, + "word": 
"a", + "originalText": "a", + }, + { + "characterOffsetBegin": 423, + "characterOffsetEnd": 427, + "word": "link", + "originalText": "link", + }, + { + "characterOffsetBegin": 428, + "characterOffsetEnd": 430, + "word": "to", + "originalText": "to", + }, + { + "characterOffsetBegin": 431, + "characterOffsetEnd": 432, + "word": "a", + "originalText": "a", + }, + { + "characterOffsetBegin": 433, + "characterOffsetEnd": 437, + "word": "fake", + "originalText": "fake", + }, + { + "characterOffsetBegin": 438, + "characterOffsetEnd": 445, + "word": "Netflix", + "originalText": "Netflix", + }, + { + "characterOffsetBegin": 446, + "characterOffsetEnd": 453, + "word": "website", + "originalText": "website", + }, + { + "characterOffsetBegin": 454, + "characterOffsetEnd": 459, + "word": "where", + "originalText": "where", + }, + { + "characterOffsetBegin": 460, + "characterOffsetEnd": 465, + "word": "users", + "originalText": "users", + }, + { + "characterOffsetBegin": 466, + "characterOffsetEnd": 469, + "word": "are", + "originalText": "are", + }, + { + "characterOffsetBegin": 470, + "characterOffsetEnd": 478, + "word": "required", + "originalText": "required", + }, + { + "characterOffsetBegin": 479, + "characterOffsetEnd": 481, + "word": "to", + "originalText": "to", + }, + { + "characterOffsetBegin": 482, + "characterOffsetEnd": 487, + "word": "enter", + "originalText": "enter", + }, + { + "characterOffsetBegin": 488, + "characterOffsetEnd": 491, + "word": "log", + "originalText": "log", + }, + { + "characterOffsetBegin": 491, + "characterOffsetEnd": 492, + "word": "-", + "originalText": "-", + }, + { + "characterOffsetBegin": 492, + "characterOffsetEnd": 494, + "word": "in", + "originalText": "in", + }, + { + "characterOffsetBegin": 495, + "characterOffsetEnd": 506, + "word": "information", + "originalText": "information", + }, + { + "characterOffsetBegin": 507, + "characterOffsetEnd": 510, + "word": "and", + "originalText": "and", + }, + { + "characterOffsetBegin": 511, 
+ "characterOffsetEnd": 512, + "word": "a", + "originalText": "a", + }, + { + "characterOffsetBegin": 513, + "characterOffsetEnd": 519, + "word": "credit", + "originalText": "credit", + }, + { + "characterOffsetBegin": 520, + "characterOffsetEnd": 524, + "word": "card", + "originalText": "card", + }, + { + "characterOffsetBegin": 525, + "characterOffsetEnd": 531, + "word": "number", + "originalText": "number", + }, + { + "characterOffsetBegin": 531, + "characterOffsetEnd": 532, + "word": ".", + "originalText": ".", + }, + ], + "span": [402, 532], + }, + { + "tokens": [ + { + "characterOffsetBegin": 533, + "characterOffsetEnd": 536, + "word": "The", + "originalText": "The", + }, + { + "characterOffsetBegin": 537, + "characterOffsetEnd": 541, + "word": "faux", + "originalText": "faux", + }, + { + "characterOffsetBegin": 542, + "characterOffsetEnd": 549, + "word": "website", + "originalText": "website", + }, + { + "characterOffsetBegin": 550, + "characterOffsetEnd": 553, + "word": "has", + "originalText": "has", + }, + { + "characterOffsetBegin": 554, + "characterOffsetEnd": 557, + "word": "the", + "originalText": "the", + }, + { + "characterOffsetBegin": 558, + "characterOffsetEnd": 565, + "word": "Netflix", + "originalText": "Netflix", + }, + { + "characterOffsetBegin": 566, + "characterOffsetEnd": 570, + "word": "logo", + "originalText": "logo", + }, + { + "characterOffsetBegin": 571, + "characterOffsetEnd": 573, + "word": "on", + "originalText": "on", + }, + { + "characterOffsetBegin": 574, + "characterOffsetEnd": 581, + "word": "display", + "originalText": "display", + }, + { + "characterOffsetBegin": 582, + "characterOffsetEnd": 586, + "word": "plus", + "originalText": "plus", + }, + { + "characterOffsetBegin": 587, + "characterOffsetEnd": 590, + "word": "The", + "originalText": "The", + }, + { + "characterOffsetBegin": 591, + "characterOffsetEnd": 596, + "word": "Crown", + "originalText": "Crown", + }, + { + "characterOffsetBegin": 597, + "characterOffsetEnd": 
600, + "word": "and", + "originalText": "and", + }, + { + "characterOffsetBegin": 601, + "characterOffsetEnd": 606, + "word": "House", + "originalText": "House", + }, + { + "characterOffsetBegin": 607, + "characterOffsetEnd": 609, + "word": "of", + "originalText": "of", + }, + { + "characterOffsetBegin": 610, + "characterOffsetEnd": 615, + "word": "Cards", + "originalText": "Cards", + }, + { + "characterOffsetBegin": 616, + "characterOffsetEnd": 623, + "word": "banners", + "originalText": "banners", + }, + { + "characterOffsetBegin": 624, + "characterOffsetEnd": 626, + "word": "to", + "originalText": "to", + }, + { + "characterOffsetBegin": 627, + "characterOffsetEnd": 634, + "word": "further", + "originalText": "further", + }, + { + "characterOffsetBegin": 635, + "characterOffsetEnd": 640, + "word": "trick", + "originalText": "trick", + }, + { + "characterOffsetBegin": 641, + "characterOffsetEnd": 649, + "word": "visitors", + "originalText": "visitors", + }, + { + "characterOffsetBegin": 649, + "characterOffsetEnd": 650, + "word": ".", + "originalText": ".", + }, + ], + "span": [533, 650], + }, + { + "tokens": [ + { + "characterOffsetBegin": 651, + "characterOffsetEnd": 653, + "word": "In", + "originalText": "In", + }, + { + "characterOffsetBegin": 654, + "characterOffsetEnd": 655, + "word": "a", + "originalText": "a", + }, + { + "characterOffsetBegin": 656, + "characterOffsetEnd": 663, + "word": "stament", + "originalText": "stament", + }, + { + "characterOffsetBegin": 664, + "characterOffsetEnd": 668, + "word": "sent", + "originalText": "sent", + }, + { + "characterOffsetBegin": 669, + "characterOffsetEnd": 671, + "word": "to", + "originalText": "to", + }, + { + "characterOffsetBegin": 672, + "characterOffsetEnd": 674, + "word": "EW", + "originalText": "EW", + }, + { + "characterOffsetBegin": 674, + "characterOffsetEnd": 675, + "word": ",", + "originalText": ",", + }, + { + "characterOffsetBegin": 676, + "characterOffsetEnd": 677, + "word": "a", + 
"originalText": "a", + }, + { + "characterOffsetBegin": 678, + "characterOffsetEnd": 685, + "word": "Netflix", + "originalText": "Netflix", + }, + { + "characterOffsetBegin": 686, + "characterOffsetEnd": 698, + "word": "spokesperson", + "originalText": "spokesperson", + }, + { + "characterOffsetBegin": 699, + "characterOffsetEnd": 706, + "word": "assured", + "originalText": "assured", + }, + { + "characterOffsetBegin": 707, + "characterOffsetEnd": 718, + "word": "subscribers", + "originalText": "subscribers", + }, + { + "characterOffsetBegin": 719, + "characterOffsetEnd": 723, + "word": "that", + "originalText": "that", + }, + { + "characterOffsetBegin": 724, + "characterOffsetEnd": 727, + "word": "the", + "originalText": "the", + }, + { + "characterOffsetBegin": 728, + "characterOffsetEnd": 735, + "word": "company", + "originalText": "company", + }, + { + "characterOffsetBegin": 736, + "characterOffsetEnd": 741, + "word": "takes", + "originalText": "takes", + }, + { + "characterOffsetBegin": 742, + "characterOffsetEnd": 745, + "word": "the", + "originalText": "the", + }, + { + "characterOffsetBegin": 746, + "characterOffsetEnd": 747, + "word": "“", + "originalText": "“", + }, + { + "characterOffsetBegin": 747, + "characterOffsetEnd": 755, + "word": "security", + "originalText": "security", + }, + { + "characterOffsetBegin": 756, + "characterOffsetEnd": 758, + "word": "of", + "originalText": "of", + }, + { + "characterOffsetBegin": 759, + "characterOffsetEnd": 762, + "word": "our", + "originalText": "our", + }, + { + "characterOffsetBegin": 763, + "characterOffsetEnd": 770, + "word": "members", + "originalText": "members", + }, + { + "characterOffsetBegin": 770, + "characterOffsetEnd": 771, + "word": "’", + "originalText": "’", + }, + { + "characterOffsetBegin": 772, + "characterOffsetEnd": 780, + "word": "accounts", + "originalText": "accounts", + }, + { + "characterOffsetBegin": 781, + "characterOffsetEnd": 790, + "word": "seriously", + "originalText": 
"seriously", + }, + { + "characterOffsetBegin": 790, + "characterOffsetEnd": 791, + "word": ",", + "originalText": ",", + }, + { + "characterOffsetBegin": 791, + "characterOffsetEnd": 792, + "word": "”", + "originalText": "”", + }, + { + "characterOffsetBegin": 793, + "characterOffsetEnd": 797, + "word": "also", + "originalText": "also", + }, + { + "characterOffsetBegin": 798, + "characterOffsetEnd": 805, + "word": "stating", + "originalText": "stating", + }, + { + "characterOffsetBegin": 806, + "characterOffsetEnd": 810, + "word": "that", + "originalText": "that", + }, + { + "characterOffsetBegin": 811, + "characterOffsetEnd": 816, + "word": "these", + "originalText": "these", + }, + { + "characterOffsetBegin": 817, + "characterOffsetEnd": 821, + "word": "type", + "originalText": "type", + }, + { + "characterOffsetBegin": 822, + "characterOffsetEnd": 824, + "word": "of", + "originalText": "of", + }, + { + "characterOffsetBegin": 825, + "characterOffsetEnd": 830, + "word": "scams", + "originalText": "scams", + }, + { + "characterOffsetBegin": 831, + "characterOffsetEnd": 834, + "word": "are", + "originalText": "are", + }, + { + "characterOffsetBegin": 834, + "characterOffsetEnd": 837, + "word": "n’t", + "originalText": "n’t", + }, + { + "characterOffsetBegin": 838, + "characterOffsetEnd": 846, + "word": "uncommon", + "originalText": "uncommon", + }, + { + "characterOffsetBegin": 846, + "characterOffsetEnd": 847, + "word": ":", + "originalText": ":", + }, + { + "characterOffsetBegin": 848, + "characterOffsetEnd": 849, + "word": "“", + "originalText": "“", + }, + { + "characterOffsetBegin": 849, + "characterOffsetEnd": 856, + "word": "Netflix", + "originalText": "Netflix", + }, + { + "characterOffsetBegin": 857, + "characterOffsetEnd": 864, + "word": "employs", + "originalText": "employs", + }, + { + "characterOffsetBegin": 865, + "characterOffsetEnd": 873, + "word": "numerous", + "originalText": "numerous", + }, + { + "characterOffsetBegin": 874, + 
"characterOffsetEnd": 883, + "word": "proactive", + "originalText": "proactive", + }, + { + "characterOffsetBegin": 884, + "characterOffsetEnd": 892, + "word": "measures", + "originalText": "measures", + }, + { + "characterOffsetBegin": 893, + "characterOffsetEnd": 895, + "word": "to", + "originalText": "to", + }, + { + "characterOffsetBegin": 896, + "characterOffsetEnd": 902, + "word": "detect", + "originalText": "detect", + }, + { + "characterOffsetBegin": 903, + "characterOffsetEnd": 913, + "word": "fraudulent", + "originalText": "fraudulent", + }, + { + "characterOffsetBegin": 914, + "characterOffsetEnd": 922, + "word": "activity", + "originalText": "activity", + }, + { + "characterOffsetBegin": 923, + "characterOffsetEnd": 925, + "word": "to", + "originalText": "to", + }, + { + "characterOffsetBegin": 926, + "characterOffsetEnd": 930, + "word": "keep", + "originalText": "keep", + }, + { + "characterOffsetBegin": 931, + "characterOffsetEnd": 934, + "word": "the", + "originalText": "the", + }, + { + "characterOffsetBegin": 935, + "characterOffsetEnd": 942, + "word": "Netflix", + "originalText": "Netflix", + }, + { + "characterOffsetBegin": 943, + "characterOffsetEnd": 950, + "word": "service", + "originalText": "service", + }, + { + "characterOffsetBegin": 951, + "characterOffsetEnd": 954, + "word": "and", + "originalText": "and", + }, + { + "characterOffsetBegin": 955, + "characterOffsetEnd": 958, + "word": "our", + "originalText": "our", + }, + { + "characterOffsetBegin": 959, + "characterOffsetEnd": 966, + "word": "members", + "originalText": "members", + }, + { + "characterOffsetBegin": 966, + "characterOffsetEnd": 967, + "word": "’", + "originalText": "’", + }, + { + "characterOffsetBegin": 968, + "characterOffsetEnd": 976, + "word": "accounts", + "originalText": "accounts", + }, + { + "characterOffsetBegin": 977, + "characterOffsetEnd": 983, + "word": "secure", + "originalText": "secure", + }, + { + "characterOffsetBegin": 983, + "characterOffsetEnd": 984, 
+ "word": ".", + "originalText": ".", + }, + ], + "span": [651, 984], + }, + { + "tokens": [ + { + "characterOffsetBegin": 985, + "characterOffsetEnd": 998, + "word": "Unfortunately", + "originalText": "Unfortunately", + }, + { + "characterOffsetBegin": 998, + "characterOffsetEnd": 999, + "word": ",", + "originalText": ",", + }, + { + "characterOffsetBegin": 1000, + "characterOffsetEnd": 1005, + "word": "these", + "originalText": "these", + }, + { + "characterOffsetBegin": 1006, + "characterOffsetEnd": 1011, + "word": "scams", + "originalText": "scams", + }, + { + "characterOffsetBegin": 1012, + "characterOffsetEnd": 1015, + "word": "are", + "originalText": "are", + }, + { + "characterOffsetBegin": 1016, + "characterOffsetEnd": 1022, + "word": "common", + "originalText": "common", + }, + { + "characterOffsetBegin": 1023, + "characterOffsetEnd": 1025, + "word": "on", + "originalText": "on", + }, + { + "characterOffsetBegin": 1026, + "characterOffsetEnd": 1029, + "word": "the", + "originalText": "the", + }, + { + "characterOffsetBegin": 1030, + "characterOffsetEnd": 1038, + "word": "internet", + "originalText": "internet", + }, + { + "characterOffsetBegin": 1039, + "characterOffsetEnd": 1042, + "word": "and", + "originalText": "and", + }, + { + "characterOffsetBegin": 1043, + "characterOffsetEnd": 1049, + "word": "target", + "originalText": "target", + }, + { + "characterOffsetBegin": 1050, + "characterOffsetEnd": 1057, + "word": "popular", + "originalText": "popular", + }, + { + "characterOffsetBegin": 1058, + "characterOffsetEnd": 1064, + "word": "brands", + "originalText": "brands", + }, + { + "characterOffsetBegin": 1065, + "characterOffsetEnd": 1069, + "word": "such", + "originalText": "such", + }, + { + "characterOffsetBegin": 1070, + "characterOffsetEnd": 1072, + "word": "as", + "originalText": "as", + }, + { + "characterOffsetBegin": 1073, + "characterOffsetEnd": 1080, + "word": "Netflix", + "originalText": "Netflix", + }, + { + "characterOffsetBegin": 1081, 
+ "characterOffsetEnd": 1084, + "word": "and", + "originalText": "and", + }, + { + "characterOffsetBegin": 1085, + "characterOffsetEnd": 1090, + "word": "other", + "originalText": "other", + }, + { + "characterOffsetBegin": 1091, + "characterOffsetEnd": 1100, + "word": "companies", + "originalText": "companies", + }, + { + "characterOffsetBegin": 1101, + "characterOffsetEnd": 1105, + "word": "with", + "originalText": "with", + }, + { + "characterOffsetBegin": 1106, + "characterOffsetEnd": 1111, + "word": "large", + "originalText": "large", + }, + { + "characterOffsetBegin": 1112, + "characterOffsetEnd": 1120, + "word": "customer", + "originalText": "customer", + }, + { + "characterOffsetBegin": 1121, + "characterOffsetEnd": 1126, + "word": "bases", + "originalText": "bases", + }, + { + "characterOffsetBegin": 1127, + "characterOffsetEnd": 1129, + "word": "to", + "originalText": "to", + }, + { + "characterOffsetBegin": 1130, + "characterOffsetEnd": 1134, + "word": "lure", + "originalText": "lure", + }, + { + "characterOffsetBegin": 1135, + "characterOffsetEnd": 1140, + "word": "users", + "originalText": "users", + }, + { + "characterOffsetBegin": 1141, + "characterOffsetEnd": 1145, + "word": "into", + "originalText": "into", + }, + { + "characterOffsetBegin": 1146, + "characterOffsetEnd": 1152, + "word": "giving", + "originalText": "giving", + }, + { + "characterOffsetBegin": 1153, + "characterOffsetEnd": 1156, + "word": "out", + "originalText": "out", + }, + { + "characterOffsetBegin": 1157, + "characterOffsetEnd": 1165, + "word": "personal", + "originalText": "personal", + }, + { + "characterOffsetBegin": 1166, + "characterOffsetEnd": 1177, + "word": "information", + "originalText": "information", + }, + { + "characterOffsetBegin": 1177, + "characterOffsetEnd": 1178, + "word": ".", + "originalText": ".", + }, + { + "characterOffsetBegin": 1178, + "characterOffsetEnd": 1179, + "word": "”", + "originalText": "”", + }, + ], + "span": [985, 1179], + }, + { + "tokens": 
[ + { + "characterOffsetBegin": 1180, + "characterOffsetEnd": 1189, + "word": "According", + "originalText": "According", + }, + { + "characterOffsetBegin": 1190, + "characterOffsetEnd": 1192, + "word": "to", + "originalText": "to", + }, + { + "characterOffsetBegin": 1193, + "characterOffsetEnd": 1202, + "word": "Mailguard", + "originalText": "Mailguard", + }, + { + "characterOffsetBegin": 1202, + "characterOffsetEnd": 1204, + "word": "’s", + "originalText": "’s", + }, + { + "characterOffsetBegin": 1205, + "characterOffsetEnd": 1211, + "word": "report", + "originalText": "report", + }, + { + "characterOffsetBegin": 1211, + "characterOffsetEnd": 1212, + "word": ",", + "originalText": ",", + }, + { + "characterOffsetBegin": 1213, + "characterOffsetEnd": 1216, + "word": "the", + "originalText": "the", + }, + { + "characterOffsetBegin": 1217, + "characterOffsetEnd": 1221, + "word": "scam", + "originalText": "scam", + }, + { + "characterOffsetBegin": 1222, + "characterOffsetEnd": 1225, + "word": "has", + "originalText": "has", + }, + { + "characterOffsetBegin": 1226, + "characterOffsetEnd": 1234, + "word": "targeted", + "originalText": "targeted", + }, + { + "characterOffsetBegin": 1235, + "characterOffsetEnd": 1241, + "word": "almost", + "originalText": "almost", + }, + { + "characterOffsetBegin": 1242, + "characterOffsetEnd": 1245, + "word": "110", + "originalText": "110", + }, + { + "characterOffsetBegin": 1246, + "characterOffsetEnd": 1253, + "word": "million", + "originalText": "million", + }, + { + "characterOffsetBegin": 1254, + "characterOffsetEnd": 1265, + "word": "subscribers", + "originalText": "subscribers", + }, + { + "characterOffsetBegin": 1265, + "characterOffsetEnd": 1266, + "word": ".", + "originalText": ".", + }, + ], + "span": [1180, 1266], + }, + { + "tokens": [ + { + "characterOffsetBegin": 1267, + "characterOffsetEnd": 1270, + "word": "One", + "originalText": "One", + }, + { + "characterOffsetBegin": 1271, + "characterOffsetEnd": 1280, + "word": 
"important", + "originalText": "important", + }, + { + "characterOffsetBegin": 1281, + "characterOffsetEnd": 1286, + "word": "thing", + "originalText": "thing", + }, + { + "characterOffsetBegin": 1287, + "characterOffsetEnd": 1289, + "word": "to", + "originalText": "to", + }, + { + "characterOffsetBegin": 1290, + "characterOffsetEnd": 1294, + "word": "note", + "originalText": "note", + }, + { + "characterOffsetBegin": 1295, + "characterOffsetEnd": 1297, + "word": "is", + "originalText": "is", + }, + { + "characterOffsetBegin": 1298, + "characterOffsetEnd": 1302, + "word": "that", + "originalText": "that", + }, + { + "characterOffsetBegin": 1303, + "characterOffsetEnd": 1306, + "word": "the", + "originalText": "the", + }, + { + "characterOffsetBegin": 1307, + "characterOffsetEnd": 1312, + "word": "email", + "originalText": "email", + }, + { + "characterOffsetBegin": 1312, + "characterOffsetEnd": 1314, + "word": "’s", + "originalText": "’s", + }, + { + "characterOffsetBegin": 1315, + "characterOffsetEnd": 1324, + "word": "recipient", + "originalText": "recipient", + }, + { + "characterOffsetBegin": 1325, + "characterOffsetEnd": 1332, + "word": "appears", + "originalText": "appears", + }, + { + "characterOffsetBegin": 1333, + "characterOffsetEnd": 1335, + "word": "as", + "originalText": "as", + }, + { + "characterOffsetBegin": 1336, + "characterOffsetEnd": 1337, + "word": "“", + "originalText": "“", + }, + { + "characterOffsetBegin": 1337, + "characterOffsetEnd": 1339, + "word": "no", + "originalText": "no", + }, + { + "characterOffsetBegin": 1340, + "characterOffsetEnd": 1346, + "word": "sender", + "originalText": "sender", + }, + { + "characterOffsetBegin": 1346, + "characterOffsetEnd": 1347, + "word": ",", + "originalText": ",", + }, + { + "characterOffsetBegin": 1347, + "characterOffsetEnd": 1348, + "word": "”", + "originalText": "”", + }, + { + "characterOffsetBegin": 1349, + "characterOffsetEnd": 1353, + "word": "plus", + "originalText": "plus", + }, + { + 
"characterOffsetBegin": 1354, + "characterOffsetEnd": 1357, + "word": "the", + "originalText": "the", + }, + { + "characterOffsetBegin": 1358, + "characterOffsetEnd": 1364, + "word": "victim", + "originalText": "victim", + }, + { + "characterOffsetBegin": 1364, + "characterOffsetEnd": 1366, + "word": "’s", + "originalText": "’s", + }, + { + "characterOffsetBegin": 1367, + "characterOffsetEnd": 1371, + "word": "name", + "originalText": "name", + }, + { + "characterOffsetBegin": 1372, + "characterOffsetEnd": 1379, + "word": "appears", + "originalText": "appears", + }, + { + "characterOffsetBegin": 1380, + "characterOffsetEnd": 1382, + "word": "as", + "originalText": "as", + }, + { + "characterOffsetBegin": 1383, + "characterOffsetEnd": 1384, + "word": "“", + "originalText": "“", + }, + { + "characterOffsetBegin": 1384, + "characterOffsetEnd": 1389, + "word": "#name", + "originalText": "#name", + }, + { + "characterOffsetBegin": 1389, + "characterOffsetEnd": 1390, + "word": "#", + "originalText": "#", + }, + { + "characterOffsetBegin": 1390, + "characterOffsetEnd": 1391, + "word": ",", + "originalText": ",", + }, + { + "characterOffsetBegin": 1391, + "characterOffsetEnd": 1392, + "word": "”", + "originalText": "”", + }, + { + "characterOffsetBegin": 1393, + "characterOffsetEnd": 1395, + "word": "as", + "originalText": "as", + }, + { + "characterOffsetBegin": 1396, + "characterOffsetEnd": 1401, + "word": "shown", + "originalText": "shown", + }, + { + "characterOffsetBegin": 1402, + "characterOffsetEnd": 1404, + "word": "in", + "originalText": "in", + }, + { + "characterOffsetBegin": 1405, + "characterOffsetEnd": 1408, + "word": "the", + "originalText": "the", + }, + { + "characterOffsetBegin": 1409, + "characterOffsetEnd": 1419, + "word": "screenshot", + "originalText": "screenshot", + }, + { + "characterOffsetBegin": 1419, + "characterOffsetEnd": 1420, + "word": ".", + "originalText": ".", + }, + ], + "span": [1267, 1420], + }, + { + "tokens": [ + { + 
"characterOffsetBegin": 1421, + "characterOffsetEnd": 1428, + "word": "Netflix", + "originalText": "Netflix", + }, + { + "characterOffsetBegin": 1429, + "characterOffsetEnd": 1438, + "word": "customers", + "originalText": "customers", + }, + { + "characterOffsetBegin": 1439, + "characterOffsetEnd": 1442, + "word": "who", + "originalText": "who", + }, + { + "characterOffsetBegin": 1443, + "characterOffsetEnd": 1450, + "word": "receive", + "originalText": "receive", + }, + { + "characterOffsetBegin": 1451, + "characterOffsetEnd": 1455, + "word": "this", + "originalText": "this", + }, + { + "characterOffsetBegin": 1456, + "characterOffsetEnd": 1461, + "word": "email", + "originalText": "email", + }, + { + "characterOffsetBegin": 1462, + "characterOffsetEnd": 1465, + "word": "are", + "originalText": "are", + }, + { + "characterOffsetBegin": 1466, + "characterOffsetEnd": 1473, + "word": "advised", + "originalText": "advised", + }, + { + "characterOffsetBegin": 1474, + "characterOffsetEnd": 1476, + "word": "to", + "originalText": "to", + }, + { + "characterOffsetBegin": 1477, + "characterOffsetEnd": 1484, + "word": "abstain", + "originalText": "abstain", + }, + { + "characterOffsetBegin": 1485, + "characterOffsetEnd": 1489, + "word": "from", + "originalText": "from", + }, + { + "characterOffsetBegin": 1490, + "characterOffsetEnd": 1497, + "word": "filling", + "originalText": "filling", + }, + { + "characterOffsetBegin": 1498, + "characterOffsetEnd": 1501, + "word": "out", + "originalText": "out", + }, + { + "characterOffsetBegin": 1502, + "characterOffsetEnd": 1505, + "word": "any", + "originalText": "any", + }, + { + "characterOffsetBegin": 1506, + "characterOffsetEnd": 1517, + "word": "information", + "originalText": "information", + }, + { + "characterOffsetBegin": 1518, + "characterOffsetEnd": 1526, + "word": "prompted", + "originalText": "prompted", + }, + { + "characterOffsetBegin": 1527, + "characterOffsetEnd": 1529, + "word": "by", + "originalText": "by", + }, + 
{ + "characterOffsetBegin": 1530, + "characterOffsetEnd": 1533, + "word": "the", + "originalText": "the", + }, + { + "characterOffsetBegin": 1534, + "characterOffsetEnd": 1541, + "word": "website", + "originalText": "website", + }, + { + "characterOffsetBegin": 1541, + "characterOffsetEnd": 1542, + "word": ".", + "originalText": ".", + }, + ], + "span": [1421, 1542], + }, + { + "tokens": [ + { + "characterOffsetBegin": 1543, + "characterOffsetEnd": 1550, + "word": "Netflix", + "originalText": "Netflix", + }, + { + "characterOffsetBegin": 1550, + "characterOffsetEnd": 1552, + "word": "’s", + "originalText": "’s", + }, + { + "characterOffsetBegin": 1553, + "characterOffsetEnd": 1565, + "word": "spokesperson", + "originalText": "spokesperson", + }, + { + "characterOffsetBegin": 1566, + "characterOffsetEnd": 1570, + "word": "also", + "originalText": "also", + }, + { + "characterOffsetBegin": 1571, + "characterOffsetEnd": 1580, + "word": "suggested", + "originalText": "suggested", + }, + { + "characterOffsetBegin": 1581, + "characterOffsetEnd": 1585, + "word": "that", + "originalText": "that", + }, + { + "characterOffsetBegin": 1586, + "characterOffsetEnd": 1593, + "word": "members", + "originalText": "members", + }, + { + "characterOffsetBegin": 1594, + "characterOffsetEnd": 1596, + "word": "of", + "originalText": "of", + }, + { + "characterOffsetBegin": 1597, + "characterOffsetEnd": 1600, + "word": "the", + "originalText": "the", + }, + { + "characterOffsetBegin": 1601, + "characterOffsetEnd": 1610, + "word": "streaming", + "originalText": "streaming", + }, + { + "characterOffsetBegin": 1611, + "characterOffsetEnd": 1618, + "word": "service", + "originalText": "service", + }, + { + "characterOffsetBegin": 1619, + "characterOffsetEnd": 1624, + "word": "visit", + "originalText": "visit", + }, + { + "characterOffsetBegin": 1625, + "characterOffsetEnd": 1645, + "word": "netflix.com/security", + "originalText": "netflix.com/security", + }, + { + "characterOffsetBegin": 
1646, + "characterOffsetEnd": 1648, + "word": "or", + "originalText": "or", + }, + { + "characterOffsetBegin": 1649, + "characterOffsetEnd": 1656, + "word": "contact", + "originalText": "contact", + }, + { + "characterOffsetBegin": 1657, + "characterOffsetEnd": 1665, + "word": "Customer", + "originalText": "Customer", + }, + { + "characterOffsetBegin": 1666, + "characterOffsetEnd": 1673, + "word": "Service", + "originalText": "Service", + }, + { + "characterOffsetBegin": 1674, + "characterOffsetEnd": 1682, + "word": "directly", + "originalText": "directly", + }, + { + "characterOffsetBegin": 1683, + "characterOffsetEnd": 1685, + "word": "to", + "originalText": "to", + }, + { + "characterOffsetBegin": 1686, + "characterOffsetEnd": 1691, + "word": "learn", + "originalText": "learn", + }, + { + "characterOffsetBegin": 1692, + "characterOffsetEnd": 1696, + "word": "more", + "originalText": "more", + }, + { + "characterOffsetBegin": 1697, + "characterOffsetEnd": 1708, + "word": "information", + "originalText": "information", + }, + { + "characterOffsetBegin": 1709, + "characterOffsetEnd": 1714, + "word": "about", + "originalText": "about", + }, + { + "characterOffsetBegin": 1715, + "characterOffsetEnd": 1720, + "word": "scams", + "originalText": "scams", + }, + { + "characterOffsetBegin": 1721, + "characterOffsetEnd": 1724, + "word": "and", + "originalText": "and", + }, + { + "characterOffsetBegin": 1725, + "characterOffsetEnd": 1730, + "word": "other", + "originalText": "other", + }, + { + "characterOffsetBegin": 1731, + "characterOffsetEnd": 1740, + "word": "malicious", + "originalText": "malicious", + }, + { + "characterOffsetBegin": 1741, + "characterOffsetEnd": 1749, + "word": "activity", + "originalText": "activity", + }, + { + "characterOffsetBegin": 1749, + "characterOffsetEnd": 1750, + "word": ".", + "originalText": ".", + }, + ], + "span": [1543, 1750], + }, + ], + "text": "An email scam passing as a Netflix notification has been targeting subscribers of the 
streaming service. The “suspension notification” looks similar in design and format to other Netflix emails. It notifies and urges users to update their information to avoid the suspension of their account. The goal of the scam is to steal personal and credit card information, according to a report from Mailguard. The email contains a link to a fake Netflix website where users are required to enter log-in information and a credit card number. The faux website has the Netflix logo on display plus The Crown and House of Cards banners to further trick visitors. In a stament sent to EW, a Netflix spokesperson assured subscribers that the company takes the “security of our members’ accounts seriously,” also stating that these type of scams aren’t uncommon: “Netflix employs numerous proactive measures to detect fraudulent activity to keep the Netflix service and our members’ accounts secure. Unfortunately, these scams are common on the internet and target popular brands such as Netflix and other companies with large customer bases to lure users into giving out personal information.” According to Mailguard’s report, the scam has targeted almost 110 million subscribers. One important thing to note is that the email’s recipient appears as “no sender,” plus the victim’s name appears as “#name#,” as shown in the screenshot. Netflix customers who receive this email are advised to abstain from filling out any information prompted by the website. 
Netflix’s spokesperson also suggested that members of the streaming service visit netflix.com/security or contact Customer Service directly to learn more information about scams and other malicious activity.", + "stanford_coref": {}, + "event": [ + { + "id": "10231-0", + "mentions": [ + { + "id": "10231-0-0", + "type": "Attack", + "subtype": "Phishing", + "realis": "Actual", + "nugget": { + "text": "email scam", + "span": [3, 12], + "tokens": [[0, 1], [0, 2]], + }, + "arguments": [ + { + "id": "10231-0-0-0", + "role": "Trusted-Entity", + "filler_type": "File", + "text": "a Netflix notification", + "span": [25, 46], + "tokens": [[0, 5], [0, 6], [0, 7]], + }, + { + "id": "10231-0-0-1", + "role": "Trusted-Entity", + "filler_type": "System", + "text": "the streaming service", + "span": [82, 102], + "tokens": [[0, 13], [0, 14], [0, 15]], + }, + { + "id": "10231-0-0-2", + "role": "Tool", + "filler_type": "File", + "text": "The “suspension notification”", + "span": [105, 133], + "tokens": [[1, 0], [1, 1], [1, 2], [1, 3], [1, 4]], + }, + { + "id": "10231-0-0-3", + "role": "Trusted-Entity", + "filler_type": "Data", + "text": "Netflix emails", + "span": [179, 192], + "tokens": [[1, 13], [1, 14]], + }, + { + "id": "10231-0-0-4", + "role": "Victim", + "filler_type": "Person", + "text": "subscribers", + "span": [67, 77], + "tokens": [[0, 11]], + }, + ], + }, + { + "id": "10231-0-1", + "type": "Attack", + "subtype": "Phishing", + "realis": "Actual", + "nugget": { + "text": "the scam", + "span": [305, 312], + "tokens": [[3, 3], [3, 4]], + }, + "arguments": [ + { + "id": "10231-0-1-0", + "role": "Purpose", + "filler_type": "Purpose", + "text": "steal personal and credit card information", + "span": [320, 361], + "tokens": [ + [3, 7], + [3, 8], + [3, 9], + [3, 10], + [3, 11], + [3, 12], + ], + }, + { + "id": "10231-0-1-1", + "role": "Tool", + "filler_type": "File", + "text": "The email", + "span": [402, 410], + "tokens": [[4, 0], [4, 1]], + }, + { + "id": "10231-0-1-2", + "role": 
"Tool", + "filler_type": "Website", + "text": "a fake Netflix website", + "span": [431, 452], + "tokens": [[4, 6], [4, 7], [4, 8], [4, 9]], + }, + { + "id": "10231-0-1-3", + "role": "Purpose", + "filler_type": "Purpose", + "text": "are required to enter log-in information and a credit card number", + "span": [466, 530], + "tokens": [ + [4, 12], + [4, 13], + [4, 14], + [4, 15], + [4, 16], + [4, 17], + [4, 18], + [4, 19], + [4, 20], + [4, 21], + [4, 22], + [4, 23], + [4, 24], + ], + }, + { + "id": "10231-0-1-4", + "role": "Victim", + "filler_type": "Person", + "text": "users", + "span": [460, 464], + "tokens": [[4, 11]], + }, + ], + }, + ], + }, + { + "id": "10231-1", + "mentions": [ + { + "id": "10231-1-0", + "type": "Attack", + "subtype": "Phishing", + "realis": "Actual", + "nugget": { + "text": "further trick", + "span": [627, 639], + "tokens": [[5, 18], [5, 19]], + }, + "arguments": [ + { + "id": "10231-1-0-0", + "role": "Victim", + "filler_type": "Person", + "text": "visitors", + "span": [641, 648], + "tokens": [[5, 20]], + }, + { + "id": "10231-1-0-1", + "role": "Tool", + "filler_type": "File", + "text": "The Crown and House of Cards banners", + "span": [587, 622], + "tokens": [ + [5, 10], + [5, 11], + [5, 12], + [5, 13], + [5, 14], + [5, 15], + [5, 16], + ], + }, + { + "id": "10231-1-0-2", + "role": "Trusted-Entity", + "filler_type": "File", + "text": "the Netflix logo", + "span": [554, 569], + "tokens": [[5, 4], [5, 5], [5, 6]], + }, + { + "id": "10231-1-0-3", + "role": "Tool", + "filler_type": "Website", + "text": "The faux website", + "span": [533, 548], + "tokens": [[5, 0], [5, 1], [5, 2]], + }, + ], + } + ], + }, + { + "id": "10231-2", + "mentions": [ + { + "id": "10231-2-0", + "type": "Attack", + "subtype": "Phishing", + "realis": "Generic", + "nugget": { + "text": "lure", + "span": [1130, 1133], + "tokens": [[7, 24]], + }, + "arguments": [ + { + "id": "10231-2-0-0", + "role": "Trusted-Entity", + "filler_type": "Organization", + "text": "other 
def bio_tags_to_spans(
    tag_sequence: List[str], classes_to_ignore: List[str] = None
) -> List[Tuple[str, Tuple[int, int]]]:
    """
    Extract labeled spans from a BIO tag sequence.

    Every tag is read as a one-character marker (``B``, ``I`` or ``O``) followed,
    from the third character on, by the label (e.g. ``"B-PER"`` -> ``"PER"``).
    Span boundaries are inclusive token indexes, so a one-token span has
    ``start == end``. An ``I`` tag that does not continue the currently open
    span (no span open, or a different label) opens a new span instead of
    being dropped, so ill-formed predictions still show up as spans. Bare
    ``B``/``I`` tags give spans whose label is the empty string.

    # Parameters
    tag_sequence : `List[str]`, required.
        The BIO tags of one sentence, one per token.
    classes_to_ignore : `List[str]`, optional (default = `None`).
        Labels (without the BIO marker) whose tags are treated like ``O``.
    # Returns
    spans : `List[Tuple[str, Tuple[int, int]]]`
        ``(label, (span_start, span_end))`` tuples. They are collected in a set,
        so the list holds no duplicates and its order is not meaningful.
    # Raises
    RuntimeError
        If a tag starts with anything other than ``B``, ``I`` or ``O``.
    """
    ignored = classes_to_ignore or []
    found: Set[Tuple[str, Tuple[int, int]]] = set()
    first = 0
    last = 0
    open_label = None

    for position, tag in enumerate(tag_sequence):
        marker = tag[0]
        if marker not in ["B", "I", "O"]:
            raise RuntimeError('Invalid tag sequence %s' % tag_sequence)
        label = tag[2:]

        if marker == "O" or label in ignored:
            # Outside any span (or inside an ignored class): close what is open.
            if open_label is not None:
                found.add((open_label, (first, last)))
            open_label = None
        elif marker == "I" and label == open_label:
            # Well-formed continuation of the open span.
            last += 1
        else:
            # Either a "B" tag, or an "I" tag that cannot continue the open
            # span. Both close the open span (if any) and start a new one here.
            if open_label is not None:
                found.add((open_label, (first, last)))
            open_label = label
            first = position
            last = position

    # The sequence may end while a span is still open.
    if open_label is not None:
        found.add((open_label, (first, last)))
    return list(found)
def bmes_tags_to_spans(
    tag_sequence: List[str], classes_to_ignore: List[str] = None
) -> List[Tuple[str, Tuple[int, int]]]:
    """
    Extract labeled spans from a BMES tag sequence.

    Each tag is split into its first character (the BMES marker) and the label
    that starts at the third character. Span boundaries are inclusive token
    indexes. A span is only extended by an ``M``/``E`` tag that directly follows
    a ``B``/``M`` tag and carries the same label as the span being built; every
    other tag starts a span of its own (best-effort handling of ill-formed
    sequences). Note that this includes tags with any other marker, e.g. ``"O"``,
    which yield a single-token span with the empty label.

    # Parameters
    tag_sequence : `List[str]`, required.
        The BMES tags of one sentence, one per token.
    classes_to_ignore : `List[str]`, optional (default = `None`).
        Labels (without the marker) whose spans are removed from the result.
    # Returns
    spans : `List[Tuple[str, Tuple[int, int]]]`
        ``(label, (span_start, span_end))`` tuples in order of appearance.
    """
    # Spans are kept as (label, [start, end]) while scanning so that the end
    # index of the most recent span can be updated in place.
    collected: List[Tuple[str, List[int]]] = []
    previous_marker: Optional[str] = None

    for position, raw_tag in enumerate(tag_sequence):
        marker = raw_tag[0]
        label = raw_tag[2:]
        # Valid transition (B/M -> M/E) with a matching label. The last clause
        # is only evaluated when a previous B/M tag exists, so `collected` is
        # never empty there.
        extends_last = (
            marker in ("M", "E")
            and previous_marker in ("B", "M")
            and collected[-1][0] == label
        )
        if extends_last:
            collected[-1][1][1] = position
        else:
            collected.append((label, [position, position]))
        previous_marker = marker

    skipped = classes_to_ignore or []
    result = []
    for label, bounds in collected:
        if label not in skipped:
            result.append((label, (bounds[0], bounds[1])))
    return result
def bmeso_tags_to_spans(
    tag_sequence: List[str], classes_to_ignore: List[str] = None
) -> List[Tuple[str, Tuple[int, int]]]:
    """
    Extract spans from a BMESO tag sequence by rewriting it as BIOUL.

    The marker (first character of each tag) is translated as follows and the
    text after the first ``-`` is kept as the label:

        M (Middle)    -> I (Inside)
        E (End)       -> L (Last)
        S (Singleton) -> U (Unit-length)

    Tags with any other marker (``B``, ``O``, ...) are passed through
    unchanged. The rewritten sequence is handed to `bioul_tags_to_spans`,
    which produces the returned spans.
    """
    bioul_prefix = {'M': 'I-', 'E': 'L-', 'S': 'U-'}
    converted = [
        bioul_prefix[tag[0]] + tag.partition("-")[2] if tag[0] in bioul_prefix else tag
        for tag in tag_sequence
    ]
    return bioul_tags_to_spans(tag_sequence=converted, classes_to_ignore=classes_to_ignore)
class Cols(TaskFormat):
    """One sentence read from a column-formatted (CoNLL-like) tagging file.

    `spans` is consumed by `generate_instance` as a sequence of mappings with
    the keys ``'start'``, ``'end'`` (inclusive token indexes) and ``'type'``
    (the entity label).
    """

    def __init__(self, tokens: List[str], spans: List[Tuple[Tuple[int, int], str]], language='en', instance_id=None) -> None:
        super().__init__(language=language)
        self.instance_id = instance_id
        self.tokens = tokens
        self.spans = spans

    def generate_instance(self):
        """Build a `Sentence` holding one `Entity` per entry of `self.spans`."""
        entities = []
        for span_index, span in enumerate(self.spans):
            # 'end' is inclusive, hence the + 1 for slicing and ranging.
            span_tokens = self.tokens[span['start']: span['end'] + 1]
            span_indexes = list(range(span['start'], span['end'] + 1))
            entity_span = Span(
                tokens=span_tokens,
                indexes=span_indexes,
                text=tokens_to_str(span_tokens, language=self.language),
                text_id=self.instance_id
            )
            entity_label = Label(span['type'])
            # Record ids are "<instance_id>#<span_index>"; without an instance
            # id there is no record id either.
            if self.instance_id:
                record_id = self.instance_id + "#%s" % span_index
            else:
                record_id = None
            entities.append(
                Entity(
                    span=entity_span,
                    label=entity_label,
                    text_id=self.instance_id,
                    record_id=record_id)
            )
        return Sentence(tokens=self.tokens,
                        entities=entities,
                        text_id=self.instance_id)

    @staticmethod
    def generate_sentence(filename):
        """Yield the sentences of a column file, one list of rows at a time.

        Sentences are separated by blank (whitespace-only) lines. Each row is
        the whitespace-split list of columns of one non-blank line. Consecutive
        blank lines do not produce empty sentences.
        """
        rows = []
        with open(filename) as fin:
            for raw_line in fin:
                stripped = raw_line.strip()
                if stripped:
                    rows.append(stripped.split())
                elif rows:
                    yield rows
                    rows = []

            # The file may not end with a blank line.
            if rows:
                yield rows
class TokenTagJson(Cols):
    """Loader for JSON-lines tagging files.

    Every non-blank line is a JSON object with a ``'tokens'`` list and a
    parallel ``'ner_tags'`` list of tag strings.
    """

    @staticmethod
    def load_from_file(filename, language='en', tagging='bio') -> List[Sentence]:
        """Read `filename` and convert every line into a `Sentence`.

        Args:
            filename: Path of the JSON-lines file.
            language: Language code forwarded to `Cols`.
            tagging: Key into `_tagging_span_function` selecting the tagging
                scheme used by ``'ner_tags'``.

        Returns:
            One `Sentence` per non-blank line, in file order.

        Raises:
            KeyError: If `tagging` is not a known scheme, or a line lacks
                ``'tokens'`` / ``'ner_tags'``.
        """
        sentence_list = list()
        counter = Counter()
        tags_to_spans = _tagging_span_function[tagging]
        # Use a context manager so the file handle is closed deterministically
        # (the bare `open()` in a for-loop left it to the garbage collector).
        with open(filename) as fin:
            for line in fin:
                line = line.strip()
                if not line:
                    # A blank line carries no instance; json.loads('') would raise.
                    continue
                instance = json.loads(line)
                tokens = instance['tokens']
                ner = instance['ner_tags']
                # Spans with an empty label come from tags without a class
                # (e.g. a bare marker) and are dropped.
                spans = [
                    {'start': span[1][0], 'end': span[1][1], 'type': span[0]}
                    for span in tags_to_spans(ner)
                    if span[0] != ""
                ]
                sentence = Cols(
                    tokens=tokens,
                    spans=spans,
                    language=language,
                )
                counter.update(['token'] * len(tokens))
                counter.update(['sentence'])
                counter.update(['span'] * len(spans))
                sentence_list += [sentence.generate_instance()]
        print(filename, counter)
        return sentence_list
class JointER(TaskFormat):
    """ Joint Entity Relation Data format at https://github.com/yubowen-ph/JointER

    Each entry of ``spo_details`` is read positionally as
    ``[subj_start, subj_end, subj_type, relation, obj_start, obj_end, obj_type]``
    where the start/end pairs are token offsets with an exclusive end.
    """

    def __init__(self, sentence_json, language='en'):
        super().__init__(
            language=language
        )
        raw_tokens = sentence_json['tokens']
        # Validate before touching the tokens: the check used to come after
        # len(self.tokens) and could therefore never trigger.
        if raw_tokens is None:
            print('[sentence without tokens]:', sentence_json)
            raise SystemExit(1)
        # Build a new list instead of rewriting the caller's list in place.
        self.tokens = [change_ptb_token_back(token) for token in raw_tokens]
        self.spo_list = sentence_json['spo_list']
        self.spo_details = sentence_json['spo_details']
        self.pos_tags = sentence_json['pos_tags']

    def generate_instance(self):
        """Convert the sentence into a `Sentence` with entities and relations.

        Entities are de-duplicated on ``(start, end, type)``, so every relation
        that mentions the same span refers to the same `Entity` object that is
        listed in the sentence's entities.
        """
        entities = dict()
        relations = dict()

        def get_entity(start, end, entity_type):
            # Return the entity for this span/type, creating it on first use.
            key = (start, end, entity_type)
            if key not in entities:
                tokens = self.tokens[start: end]
                entities[key] = Entity(
                    span=Span(
                        tokens=tokens,
                        indexes=list(range(start, end)),
                        text=tokens_to_str(tokens, language=self.language),
                    ),
                    label=Label(entity_type)
                )
            return entities[key]

        for spo_index, spo in enumerate(self.spo_details):
            subject = get_entity(spo[0], spo[1], spo[2])
            obj = get_entity(spo[4], spo[5], spo[6])
            relations[spo_index] = Relation(
                arg1=subject,
                arg2=obj,
                label=Label(spo[3]),
            )

        return Sentence(
            tokens=self.tokens,
            entities=entities.values(),
            relations=relations.values(),
        )

    @staticmethod
    def load_from_file(filename, language='en') -> List[Sentence]:
        """Load a JSON file holding a list of JointER sentences.

        Args:
            filename: Path of the JSON file.
            language: Language code forwarded to every `JointER` instance.

        Returns:
            One `Sentence` per entry of the file, in file order.
        """
        # Close the file as soon as it has been parsed.
        with open(filename) as fin:
            raw_instance_list = json.load(fin)
        print(f"{filename}: {len(raw_instance_list)}")
        return [
            JointER(
                sentence_json=raw_instance,
                language=language
            ).generate_instance()
            for raw_instance in raw_instance_list
        ]
+= [mrc_instance.generate_instance()] + + print(filename, counter) + + return sentence_list diff --git a/metaretriever/dataset_processing/universal_ie/task_format/oneie.py b/metaretriever/dataset_processing/universal_ie/task_format/oneie.py new file mode 100644 index 00000000..6c355f3d --- /dev/null +++ b/metaretriever/dataset_processing/universal_ie/task_format/oneie.py @@ -0,0 +1,128 @@ +#!/usr/bin/env python +# -*- coding:utf-8 -*- +import json +from typing import List +from universal_ie.task_format.task_format import TaskFormat +from universal_ie.utils import tokens_to_str +from universal_ie.ie_format import Entity, Event, Label, Sentence, Span + + +""" +{ + "doc_id": "AFP_ENG_20030427.0118", + "sent_id": "AFP_ENG_20030427.0118-1", + "tokens": ["A", "Pakistani", "court", "in", "central", "Punjab", "province", "has", "sentenced", "a", "Christian", "man", "to", "life", "imprisonment", "for", "a", "blasphemy", "conviction", ",", "police", "said", "Sunday", "."], "pieces": ["A", "Pakistani", "court", "in", "central", "Punjab", "province", "has", "sentenced", "a", "Christian", "man", "to", "life", "imprisonment", "for", "a", "b", "##lasp", "##hem", "##y", "conviction", ",", "police", "said", "Sunday", "."], + "token_lens": [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 4, 1, 1, 1, 1, 1, 1], + "sentence": "A Pakistani court in central Punjab province has sentenced a Christian man to life imprisonment for a blasphemy conviction, police said Sunday.", + "entity_mentions": [ + {"id": "AFP_ENG_20030427.0118-E15-53", "text": "Pakistani", "entity_type": "GPE", "mention_type": "NAM", "entity_subtype": "Nation", "start": 1, "end": 2}, + {"id": "AFP_ENG_20030427.0118-E35-52", "text": "court", "entity_type": "ORG", "mention_type": "NOM", "entity_subtype": "Government", "start": 2, "end": 3}, + {"id": "AFP_ENG_20030427.0118-E37-54", "text": "province", "entity_type": "LOC", "mention_type": "NOM", "entity_subtype": "Region-General", "start": 6, "end": 7}, + {"id": 
"AFP_ENG_20030427.0118-E27-48", "text": "Christian", "entity_type": "PER", "mention_type": "NOM", "entity_subtype": "Group", "start": 10, "end": 11}, + {"id": "AFP_ENG_20030427.0118-E38-55", "text": "man", "entity_type": "PER", "mention_type": "NOM", "entity_subtype": "Individual", "start": 11, "end": 12}, + {"id": "AFP_ENG_20030427.0118-E39-56", "text": "police", "entity_type": "PER", "mention_type": "NOM", "entity_subtype": "Group", "start": 20, "end": 21}], + "relation_mentions": [ + {"id": "AFP_ENG_20030427.0118-R1-1", "relation_type": "GEN-AFF", "relation_subtype": "GEN-AFF:Citizen-Resident-Religion-Ethnicity", + "arguments": [ + {"entity_id": "AFP_ENG_20030427.0118-E38-55", "text": "man", "role": "Arg-1"}, + {"entity_id": "AFP_ENG_20030427.0118-E27-48", "text": "Christian", "role": "Arg-2"} + ] + }, + {"id": "AFP_ENG_20030427.0118-R3-1", "relation_type": "PART-WHOLE", "relation_subtype": "PART-WHOLE:Subsidiary", + "arguments": [ + {"entity_id": "AFP_ENG_20030427.0118-E35-52", "text": "court", "role": "Arg-1"}, + {"entity_id": "AFP_ENG_20030427.0118-E15-53", "text": "Pakistani", "role": "Arg-2"} + ] + }, + {"id": "AFP_ENG_20030427.0118-R4-1", "relation_type": "GEN-AFF", "relation_subtype": "GEN-AFF:Org-Location", + "arguments": [ + {"entity_id": "AFP_ENG_20030427.0118-E35-52", "text": "court", "role": "Arg-1"}, + {"entity_id": "AFP_ENG_20030427.0118-E37-54", "text": "province", "role": "Arg-2"} + ] + } + ], + "event_mentions": [ + {"id": "AFP_ENG_20030427.0118-EV1-1", "event_type": "Justice:Sentence", + "trigger": {"text": "sentenced", "start": 8, "end": 9}, + "arguments": [ + {"entity_id": "AFP_ENG_20030427.0118-E35-52", "text": "court", "role": "Adjudicator"}, + {"entity_id": "AFP_ENG_20030427.0118-E38-55", "text": "man", "role": "Defendant"}, + {"entity_id": "AFP_ENG_20030427.0118-E37-54", "text": "province", "role": "Place"} + ]}, + {"id": "AFP_ENG_20030427.0118-EV2-1", "event_type": "Justice:Convict", + "trigger": {"text": "conviction", "start": 18, 
"end": 19}, + "arguments": [{"entity_id": "AFP_ENG_20030427.0118-E38-55", "text": "man", "role": "Defendant"} + ]} +]} +""" + + +class OneIEEvent(TaskFormat): + def __init__(self, doc_json, language='en'): + super().__init__( + language=language + ) + self.doc_id = doc_json['doc_id'] + self.sent_id = doc_json['sent_id'] + self.tokens = doc_json['tokens'] + self.entities = doc_json['entity_mentions'] + self.relations = doc_json['relation_mentions'] + self.events = doc_json['event_mentions'] + + def generate_instance(self): + events = dict() + entities = dict() + + for span_index, span in enumerate(self.entities): + tokens = self.tokens[span['start']: span['end']] + indexes = list(range(span['start'], span['end'])) + entities[span['id']] = Entity( + span=Span( + tokens=tokens, + indexes=indexes, + text=tokens_to_str(tokens, language=self.language), + text_id=self.sent_id + ), + label=Label(span['entity_type']), + text_id=self.sent_id, + record_id=span['id'] + ) + + for event_index, event in enumerate(self.events): + start = event['trigger']['start'] + end = event['trigger']['end'] + tokens = self.tokens[start:end] + indexes = list(range(start, end)) + events[event['id']] = Event( + span=Span( + tokens=tokens, + indexes=indexes, + text=tokens_to_str(tokens, language=self.language), + text_id=self.sent_id + ), + label=Label(event['event_type']), + args=[(Label(x['role']), entities[x['entity_id']]) + for x in event['arguments']], + text_id=self.sent_id, + record_id=event['id'] + ) + + return Sentence( + tokens=self.tokens, + entities=list(), + relations=list(), + events=events.values(), + text_id=self.sent_id + ) + + @staticmethod + def load_from_file(filename, language='en') -> List[Sentence]: + sentence_list = list() + with open(filename) as fin: + for line in fin: + instance = OneIEEvent( + json.loads(line.strip()), + language=language + ).generate_instance() + sentence_list += [instance] + return sentence_list diff --git 
class Spannet(TaskFormat):
    """Turn one Spannet-style instance (an already parsed JSON object) into a ``Sentence``.

    Example input. Span ``end`` offsets are inclusive, and ``head`` / ``tail``
    are positions in ``span_list``::

        {
            "tokens": ["An", "art", "exhibit", "at", "the", "Hakawati", "Theatre",
                       "in", "Arab", "east", "Jerusalem", "was", "a", "series",
                       "of", "portraits", "of", "Palestinians", "killed", "in",
                       "the", "rebellion", "."],
            "span_pair_list": [
                {"type": "OrgBased_In", "head": 0, "tail": 2}
            ],
            "span_list": [
                {"type": "Org", "start": 5, "end": 6},
                {"type": "Other", "start": 8, "end": 8},
                {"type": "Loc", "start": 10, "end": 10},
                {"type": "Other", "start": 17, "end": 17}
            ]
        }
    """

    def __init__(self, instance_json: Dict, language='en') -> None:
        super().__init__(language=language)
        # NOTE(review): the whole token list is handed to change_ptb_token_back
        # in one call -- confirm the helper is meant to receive a list.
        self.tokens = change_ptb_token_back(instance_json['tokens'])
        self.span_list = instance_json.get('span_list', [])
        self.span_pair_list = instance_json.get('span_pair_list', [])
        self.instance_id = instance_json.get('id', None)

    def _record_id(self, marker, position):
        """Return ``<instance_id><marker><position>``, or ``None`` without an instance id."""
        if not self.instance_id:
            return None
        return self.instance_id + "%s%s" % (marker, position)

    def _build_entity(self, position, span):
        """Create the ``Entity`` for ``span``; its ``end`` offset is inclusive."""
        first, last = span['start'], span['end'] + 1
        covered = self.tokens[first:last]
        return Entity(
            span=Span(
                tokens=covered,
                indexes=list(range(first, last)),
                text=tokens_to_str(covered, language=self.language),
                text_id=self.instance_id,
            ),
            label=Label(span['type']),
            text_id=self.instance_id,
            record_id=self._record_id("#", position),
        )

    def generate_instance(self):
        """Build the ``Sentence`` with one entity per span and one relation per span pair."""
        entities = [
            self._build_entity(position, span)
            for position, span in enumerate(self.span_list)
        ]
        # Relation arguments are looked up by their position in ``entities``.
        relations = [
            Relation(
                arg1=entities[pair['head']],
                arg2=entities[pair['tail']],
                label=Label(pair['type']),
                text_id=self.instance_id,
                record_id=self._record_id("##", position),
            )
            for position, pair in enumerate(self.span_pair_list)
        ]
        return Sentence(
            tokens=self.tokens,
            entities=entities,
            relations=relations,
            text_id=self.instance_id,
        )

    @staticmethod
    def load_from_file(filename, language='en') -> List[Sentence]:
        """Read a JSON-lines file, convert every line and print sentence/span counts."""
        sentences = list()
        stats = Counter()
        with open(filename) as fin:
            for raw_line in tqdm(fin):
                converter = Spannet(json.loads(raw_line.strip()), language=language)
                sentences.append(converter.generate_instance())
                stats['sentence'] += 1
                # A list update keeps 'span' out of the counter when a line has no spans.
                stats.update(['span'] * len(converter.span_list))
        print(filename, stats)
        return sentences
# Label names that were looked up in a label mapper but not found in it.
global_mislabel_log = set()


def tokens_to_str(tokens: List[str], language: str = 'en') -> str:
    """Join tokens into text: space-separated for 'en', concatenated for 'zh'.

    Raises:
        NotImplementedError: for any other language code.
    """
    if language == 'en':
        return ' '.join(tokens)
    elif language == 'zh':
        return ''.join(tokens)
    else:
        raise NotImplementedError('Language %s not supported' % language)


def label_format(s):
    """Normalize a label name into lower-case, space-separated words.

    CamelCase boundaries are split, ``_``, ``-`` and ``.`` become spaces, and
    a duplicated leading word is dropped (``"Part-Part"`` -> ``"part"``).
    """
    import re

    def uncamelize(s):
        # Space between a non-capital character and the capital that follows it.
        re_outer = re.compile(r'([^A-Z ])([A-Z])')
        # Space after an acronym that is directly followed by a capitalized word.
        re_inner = re.compile(r'\b[A-Z]+(?=[A-Z][a-z])')
        sub = re_inner.sub(r'\g<0> ', re_outer.sub(r'\1 \2', s)).lower()
        return sub

    def remove(s):
        return s.replace("_", " ").replace("-", " ").replace(".", " ")

    s = remove(uncamelize(s)).split()
    if len(s) > 1 and s[0] == s[1]:
        s = s[1:]
    return " ".join(s)


def load_dict_ini_file(filename):
    """Load a ``key=value`` label mapping file (deprecated).

    Values are normalized with ``label_format``. Blank lines are skipped and
    only the first ``=`` separates key from value, so values may contain
    ``=`` themselves.

    Returns:
        dict: key -> formatted label; empty if ``filename`` does not exist.

    Raises:
        ValueError: if a non-blank line contains no ``=``.
    """
    print("Warning: `load_dict_ini_file` is deprecated.")
    if not os.path.exists(filename):
        sys.stderr.write(f'[warning] cannot load label mapper from {filename}\n')
        return {}
    mapper = dict()
    # The context manager closes the handle even if a line is malformed.
    with open(filename) as fin:
        for line in fin:
            line = line.strip()
            if not line:
                continue
            key, separator, value = line.partition('=')
            if not separator:
                raise ValueError(f'expected `key=value` in {filename}, got {line!r}')
            mapper[key] = label_format(value)
    return mapper


def change_ptb_token_back(token):
    """Convert PTB-tokenizer escape tokens back to the original characters.

    Args:
        token: a single token string, or a list of token strings.

    Returns:
        For a string, the raw character if it is a known PTB escape (for
        example ``-LRB-`` -> ``(``), otherwise the string itself. For a list,
        a new list with every element converted. Any other value is returned
        unchanged.
    """
    ptb_token_map = {
        '``': '"',
        "''": '"',
        '-LRB-': '(',
        '-RRB-': ')',
        '-LSB-': '[',
        '-RSB-': ']',
        '-LCB-': '{',
        '-RCB-': '}',
    }
    # Callers pass whole token lists; previously a list never equalled any
    # key and was silently returned unconverted.
    if isinstance(token, list):
        return [change_ptb_token_back(item) for item in token]
    if not isinstance(token, str):
        return token
    return ptb_token_map.get(token, token)


def change_name_using_label_mapper(label_name, label_mapper):
    """Map ``label_name`` through ``label_mapper``, falling back to the name itself.

    With no mapper (``None`` or empty) the name is returned as is. A name
    missing from a non-empty mapper is reported on stdout and recorded in
    ``global_mislabel_log``.
    """
    if label_mapper is None or len(label_mapper) == 0:
        return label_name
    if label_name not in label_mapper:
        print(f"{label_name} not found in mapper")
        global_mislabel_log.add(label_name)
    return label_mapper.get(label_name, label_name)
00000000..52ccc060 --- /dev/null +++ b/metaretriever/docker/Dockerfile @@ -0,0 +1,29 @@ +FROM nvidia/cuda:11.0.3-cudnn8-devel-ubuntu18.04 +LABEL maintainer="Yaojie Lu" +LABEL repository="uie" + +RUN apt update && \ + apt install -y bash \ + build-essential \ + git \ + curl \ + ca-certificates \ + python3 \ + python3-pip && \ + rm -rf /var/lib/apt/lists + +WORKDIR /pre_env + +RUN python3 -m pip install --no-cache-dir --upgrade pip && \ + python3 -m pip install --no-cache-dir mkl && \ + python3 -m pip install --no-cache-dir torch==1.7.1+cu110 -f https://download.pytorch.org/whl/torch_stable.html + +RUN git clone https://github.com/NVIDIA/apex +RUN cd apex && \ + python3 setup.py install && \ + python3 -m pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./ + +COPY ./requirements.txt . +RUN python3 -m pip install -r ./requirements.txt + +CMD ["/bin/bash"] diff --git a/metaretriever/docs/DATASETS.md b/metaretriever/docs/DATASETS.md new file mode 100644 index 00000000..0db8dcd0 --- /dev/null +++ b/metaretriever/docs/DATASETS.md @@ -0,0 +1,53 @@ +# 数据说明 + +``` json +{ + "text": "MULTAN , Pakistan , April 27 ( AFP )", + "tokens": ["MULTAN", ",", "Pakistan", ",", "April", "27", "(", "AFP", ")"], + "record": " geographical social political MULTAN part whole Pakistan geographical social political Pakistan organization AFP ", + "entity": [ + {"type": "geographical social political", "offset": [0], "text": "MULTAN"}, + {"type": "geographical social political", "offset": [2], "text": "Pakistan"}, + {"type": "organization", "offset": [7], "text": "AFP"} + ], + "relation": [ + { + "type": "part whole", + "args": [ + {"type": "geographical social political", "offset": [0], "text": "MULTAN"}, + {"type": "geographical social political", "offset": [2], "text": "Pakistan"} + ] + } + ], + "event": [], + "spot": ["geographical social political", "organization"], + "asoc": ["part whole"], + "spot_asoc": [ + { + "span": "MULTAN", + "label": 
"geographical social political", + "asoc": [["part whole", "Pakistan"]] + }, + { + "span": "Pakistan", + "label": "geographical social political", "asoc": [] + }, + { + "span": "AFP", "label": "organization", "asoc": [] + } + ], + "task": 'record' +} +``` + +- task: `seq`, `record`, `t5mlm` + - mlm 只要求有 Text + - seq 只要求有 Record + - record 要求有 Text-record 数据 + - 若无,默认为 Text-record 数据 +- spot、asoc + - 文本中的正例类别 +- spot_asoc + - record 结构表示 +- entity relation event + - Offset 标准答案,用于模型验证。 diff --git a/metaretriever/docs/TOOLS.md b/metaretriever/docs/TOOLS.md new file mode 100644 index 00000000..693d9b9b --- /dev/null +++ b/metaretriever/docs/TOOLS.md @@ -0,0 +1,41 @@ +# Tools for UIE + +### Evaluate Model Performance +验证模型性能 (eval_extraction.py) +```text + $ python scripts/eval_extraction.py -h +usage: eval_extraction.py [-h] [-g GOLD_FOLDER] [-p PRED_FOLDER [PRED_FOLDER ...]] [-v] [-w] [-m] [-case] + +optional arguments: + -h, --help show this help message and exit + -g GOLD_FOLDER Golden Dataset folder + -p PRED_FOLDER [PRED_FOLDER ...] + Predicted model folder + -v Show more information during running + -w Write evaluation results to predicted folder + -m Match predicted result multiple times + -case Show case study +``` + +### Check Offset Mapping Performance +验证回标的性能 (check_offset_map_gold_as_pred.bash) +``` bash +bash scripts/check_offset_map_gold_as_pred.bash +``` + +### Convert SEL to Record +将结构化表达式转换成 Record 结构 (sel2record.py) +``` text + $ python scripts/sel2record.py -h +usage: sel2record.py [-h] [-g GOLD_FOLDER] [-p PRED_FOLDER [PRED_FOLDER ...]] [-c MAP_CONFIG] [-d DECODING] [-v] + +optional arguments: + -h, --help show this help message and exit + -g GOLD_FOLDER 标准答案(Gold)文件夹 + -p PRED_FOLDER [PRED_FOLDER ...] 
+ 多个不同的预测(Pred)文件夹 + -c MAP_CONFIG, --config MAP_CONFIG + Offset 匹配策略的配置文件 + -d DECODING 使用 SpotAsoc 结构的解析器进行结构表达式解析 + -v, --verbose 打印更详细的日志信息 +``` diff --git a/metaretriever/etc/record.dataload.schema b/metaretriever/etc/record.dataload.schema new file mode 100644 index 00000000..be119d33 Binary files /dev/null and b/metaretriever/etc/record.dataload.schema differ diff --git a/metaretriever/img/MetaRetriever.png b/metaretriever/img/MetaRetriever.png new file mode 100644 index 00000000..6da003fd Binary files /dev/null and b/metaretriever/img/MetaRetriever.png differ diff --git a/metaretriever/inference.py b/metaretriever/inference.py new file mode 100644 index 00000000..714dab7c --- /dev/null +++ b/metaretriever/inference.py @@ -0,0 +1,163 @@ +#!/usr/bin/env python +# -*- coding:utf-8 -*- +import json +import re +from tqdm import tqdm +import transformers as huggingface_transformers +from uie.extraction.record_schema import RecordSchema +from uie.sel2record.record import MapConfig +from uie.extraction.scorer import * +from uie.sel2record.sel2record import SEL2Record +import math +import os + + +split_bracket = re.compile(r"\s*\s*") +special_to_remove = {'', ''} + + +def read_json_file(file_name): + return [json.loads(line) for line in open(file_name)] + + +def schema_to_ssi(schema: RecordSchema): + ssi = " " + " ".join(sorted(schema.type_list)) + ssi += " " + " ".join(sorted(schema.role_list)) + ssi += " " + return ssi + + +def post_processing(x): + for special in special_to_remove: + x = x.replace(special, '') + return x.strip() + + +class HuggingfacePredictor: + def __init__(self, model_path, schema_file, max_source_length=256, max_target_length=192) -> None: + self._tokenizer = huggingface_transformers.T5TokenizerFast.from_pretrained( + model_path) + self._model = huggingface_transformers.T5ForConditionalGeneration.from_pretrained( + model_path) + self._model.cuda() + self._schema = RecordSchema.read_from_file(schema_file) + self._ssi = 
# Scorer class used for each record type during evaluation.
task_dict = {
    'entity': EntityScorer,
    'relation': RelationScorer,
    'event': EventScorer,
}


def main():
    """Predict, convert and score the ``val`` and ``test`` splits of a data folder.

    For each split the texts are decoded in batches, every generated sequence
    is converted into a record, the records are scored per task, and three
    files are written into the model directory:
    ``<split>_preds_record.txt``, ``<split>_preds_seq2seq.txt`` and
    ``<split>_results.txt`` (the ``val`` split is written under the name
    ``eval``).
    """
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--data', '-d', default='data/text2spotasoc/absa/14lap')
    parser.add_argument(
        '--model', '-m', default='./models/uie_n10_21_50w_absa_14lap')
    parser.add_argument('--max_source_length', default=256, type=int)
    parser.add_argument('--max_target_length', default=192, type=int)
    parser.add_argument('--batch_size', default=16, type=int)
    parser.add_argument('-c', '--config', dest='map_config',
                        help='Offset Re-mapping Config',
                        default='config/offset_map/closest_offset_en.yaml')
    parser.add_argument('--decoding', default='spotasoc')
    parser.add_argument('--verbose', action='store_true')
    parser.add_argument('--match_mode', default='normal',
                        choices=['set', 'normal', 'multimatch'])
    options = parser.parse_args()

    data_folder = options.data
    model_path = options.model

    predictor = HuggingfacePredictor(
        model_path=model_path,
        schema_file=f"{data_folder}/record.schema",
        max_source_length=options.max_source_length,
        max_target_length=options.max_target_length,
    )

    map_config = MapConfig.load_from_yaml(options.map_config)
    schema_dict = SEL2Record.load_schema_dict(data_folder)
    sel2record = SEL2Record(
        schema_dict=schema_dict,
        decoding_schema=options.decoding,
        map_config=map_config,
    )

    for split, split_name in [('val', 'eval'), ('test', 'test')]:
        gold_filename = f"{data_folder}/{split}.json"

        # Parse the gold file once per split; it used to be re-read for the
        # texts, for the tokens and again for every scored task.
        gold_instances = read_json_file(gold_filename)
        text_list = [x['text'] for x in gold_instances]
        token_list = [x['tokens'] for x in gold_instances]

        batch_num = math.ceil(len(text_list) / options.batch_size)

        predict = list()
        for index in tqdm(range(batch_num)):
            start = index * options.batch_size
            end = start + options.batch_size

            pred_seq2seq = predictor.predict(text_list[start: end])
            predict += [post_processing(x) for x in pred_seq2seq]

        records = [
            sel2record.sel2record(pred=p, text=text, tokens=tokens)
            for p, text, tokens in zip(predict, text_list, token_list)
        ]

        results = dict()
        for task, scorer in task_dict.items():
            gold_list = [x[task] for x in gold_instances]
            pred_list = [x[task] for x in records]

            gold_instance_list = scorer.load_gold_list(gold_list)
            pred_instance_list = scorer.load_pred_list(pred_list)

            sub_results = scorer.eval_instance_list(
                gold_instance_list=gold_instance_list,
                pred_instance_list=pred_instance_list,
                verbose=options.verbose,
                match_mode=options.match_mode,
            )
            results.update(sub_results)

        with open(os.path.join(options.model, f'{split_name}_preds_record.txt'), 'w') as output:
            for record in records:
                output.write(f'{json.dumps(record)}\n')

        with open(os.path.join(options.model, f'{split_name}_preds_seq2seq.txt'), 'w') as output:
            for pred in predict:
                output.write(f'{pred}\n')

        with open(os.path.join(options.model, f'{split_name}_results.txt'), 'w') as output:
            for key, value in results.items():
                output.write(f'{split_name}_{key}={value}\n')


if __name__ == "__main__":
    main()
def read_tensorboard_data(tensorboard_log_path, val_name):
    """Load a TensorBoard event file and return the scalar events stored under ``val_name``.

    The available scalar tags are printed first to help pick a valid name.
    """
    ea = event_accumulator.EventAccumulator(tensorboard_log_path)
    ea.Reload()

    print("All scalers:")
    print(ea.scalars.Keys())

    return ea.scalars.Items(val_name)


def plot(vals, val_names, max_step=None):
    """Draw one curve per scalar series on a shared step/loss figure and show it.

    Args:
        vals: scalar series; each item must expose ``.step`` and ``.value``.
        val_names: legend label for each series, in the same order as ``vals``.
        max_step: if given, only points with ``step < max_step`` are drawn.
    """
    plt.figure()

    for val, val_name in zip(vals, val_names):
        # Filter step and value together. Trimming the values by the number
        # of surviving steps pairs the wrong points whenever steps are not
        # strictly increasing.
        points = [
            (event.step, event.value)
            for event in val
            if max_step is None or event.step < max_step
        ]
        x = [step for step, _ in points]
        y = [value for _, value in points]

        plt.plot(x, y, label=val_name)

    plt.xlabel("step")
    plt.ylabel("loss")
    plt.legend()
    plt.show()


if __name__ == "__main__":
    # Event files of the three training runs being compared.
    refine_uie_tensorboard_log_path = "tensorboard_logs/events.out.tfevents.1654419004.dsw32050-7df697f45c-6bwkm.44438.0"
    refine_t5_tensorboard_log_path = "tensorboard_logs/events.out.tfevents.1654361305.g64h07153.cloud.sqa.nt12.129194.0"
    uie_t5_tensorboard_log_path = "tensorboard_logs/events.out.tfevents.1654275965.eflops-common033255085104.NT12.106708.0"

    val_name = "train/loss"

    refine_uie_val = read_tensorboard_data(refine_uie_tensorboard_log_path, val_name)
    refine_t5_val = read_tensorboard_data(refine_t5_tensorboard_log_path, val_name)
    uie_t5_val = read_tensorboard_data(uie_t5_tensorboard_log_path, val_name)

    vals = [refine_uie_val, refine_t5_val, uie_t5_val]
    val_names = ["refine_uie_loss", "refine_t5_loss", "uie_t5_loss"]
    max_step = 20000
    plot(vals, val_names, max_step=max_step)
100644 index 00000000..479ac575 --- /dev/null +++ b/metaretriever/requirements.txt @@ -0,0 +1,173 @@ +absl-py==1.0.0 +altair==4.2.0 +anyio==3.5.0 +anytree==2.8.0 +argon2-cffi==21.3.0 +argon2-cffi-bindings==21.2.0 +asgiref==3.5.0 +astor==0.8.1 +asttokens==2.0.5 +attrs==21.4.0 +autopep8==1.6.0 +backcall==0.2.0 +backports.zoneinfo==0.2.1 +base58==2.1.1 +black==21.12b0 +bleach==4.1.0 +blinker==1.4 +cachetools==4.2.4 +certifi==2021.10.8 +cffi==1.15.0 +charset-normalizer==2.0.10 +click==8.0.3 +colorama==0.4.4 +colorlog==6.6.0 +commonmark==0.9.1 +conllu==4.4.1 +cycler==0.11.0 +dataclasses==0.6 +datasets==1.9.0 +debugpy==1.5.1 +decorator==5.1.1 +defusedxml==0.7.1 +dill==0.3.4 +elasticsearch==7.16.3 +entrypoints==0.3 +executing==0.8.2 +faiss-cpu==1.7.2 +fastapi==0.74.1 +filelock==3.0.12 +fire==0.4.0 +fonttools==4.28.5 +fsspec==2022.1.0 +future==0.18.2 +git-python==1.0.3 +gitdb==4.0.9 +GitPython==3.1.26 +google-auth==2.3.3 +google-auth-oauthlib==0.4.6 +googleapis-common-protos==1.54.0 +grpcio==1.43.0 +h11==0.13.0 +h5py==3.6.0 +huggingface-hub==0.0.8 +idna==3.3 +importlib-metadata==4.10.1 +importlib-resources==5.4.0 +iniconfig==1.1.1 +ipykernel==6.7.0 +ipython +ipython-genutils==0.2.0 +ipywidgets==7.6.5 +jedi==0.18.1 +jieba==0.42.1 +Jinja2==3.0.3 +joblib==1.1.0 +jsonschema==4.4.0 +jupyter-client==7.1.1 +jupyter-core==4.9.1 +jupyterlab-pygments==0.1.2 +jupyterlab-widgets==1.0.2 +kiwisolver==1.3.2 +Markdown==3.3.6 +MarkupSafe==2.0.1 +matplotlib==3.5.1 +matplotlib-inline==0.1.3 +mistune==0.8.4 +multiprocess==0.70.12.2 +mypy-extensions==0.4.3 +nbclient==0.5.10 +nbconvert==6.4.0 +nbformat==5.1.3 +nest-asyncio==1.5.4 +nltk==3.6.7 +notebook==6.4.7 +numpy==1.19.5 +oauthlib==3.1.1 +packaging==21.3 +pandas==1.3.5 +pandocfilters==1.5.0 +parso==0.8.3 +pathspec==0.9.0 +pexpect==4.8.0 +pickleshare==0.7.5 +Pillow==9.0.0 +platformdirs==2.4.1 +pluggy==1.0.0 +portalocker==2.3.2 +prometheus-client==0.12.0 +promise==2.3 +prompt-toolkit==3.0.24 +protobuf==3.19.3 +psutil==5.9.0 +ptyprocess==0.7.0 
# Evaluate a trained model on the val/test splits of one dataset folder:
# run seq2seq prediction, convert the generated sequences to records, then
# score them.
#
# Options:
#   -b, --batch N               per-device eval batch size (default 16)
#   -d, --device IDS            value for CUDA_VISIBLE_DEVICES (default "0")
#   -m, --model PATH            model directory to evaluate
#   -i, --data DIR              dataset folder (val.json, test.json, record.schema)
#   -t, --task NAME             task name used as the source prefix (default "meta")
#   -c, --constraint_decoding   enable constrained decoding
#   -o, --output DIR            output directory (default: <model>_eval[_CD])
#   -f, --format NAME           decoding format (default "spotasoc")
#   -e, --extra_cmd ARGS        extra arguments forwarded to run_seq2seq.py
#       --beam N                beam size (default 1)
#       --map_config FILE       offset mapping config
#
# max_source_length, max_target_length and eval_match_mode are read from the
# environment when set.
device="0"
model_path=""
data_folder=data/text2spotasoc/absa/14lap
task_name="meta"
batch=16
decoding_format='spotasoc'
beam_size=1
map_config=config/offset_map/closest_offset_en.yaml

export PYTHONPATH="${PYTHONPATH}:./"

# Every long option must be comma-separated. "task:" and "constraint_decoding"
# used to be fused into one bogus name, and "beam:" was missing although the
# case statement below handles --beam.
OPTS=$(getopt -o b:d:m:i:t:co:f:e: --long batch:,device:,model:,data:,task:,constraint_decoding,output:,format:,beam:,map_config:,extra_cmd: -n 'parse-options' -- "$@")

if [ $? != 0 ]; then
  echo "Failed parsing options." >&2
  exit 1
fi

eval set -- "$OPTS"

while true; do
  case "$1" in
  -b | --batch) batch="$2"
    shift 2 ;;
  -d | --device) device="$2"
    shift 2 ;;
  -m | --model) model_path="$2"
    shift 2 ;;
  -i | --data) data_folder="$2"
    shift 2 ;;
  -t | --task) task_name="$2"
    shift 2 ;;
  -c | --constraint_decoding) constraint_decoding="--constraint_decoding"
    shift ;;
  -o | --output) output_dir="$2"
    shift 2 ;;
  -f | --format) decoding_format="$2"
    shift 2 ;;
  -e | --extra_cmd) extra_cmd="$2"
    shift 2 ;;
  --beam) beam_size="$2"
    shift 2 ;;
  --map_config) map_config="$2"
    shift 2 ;;
  --)
    shift
    break
    ;;
  *)
    echo "$1" not recognize.
    # Exit non-zero so callers notice the failure.
    exit 1
    ;;
  esac
done

echo "Extra CMD: " "${extra_cmd}"

# Derive the output directory from the model path unless one was given.
if [[ ${output_dir} == "" ]]
then
  output_dir=${model_path}_eval
  if [[ ${constraint_decoding} != "" ]]
  then
    output_dir=${output_dir}_CD
  fi
fi

# ${constraint_decoding} and ${extra_cmd} stay unquoted on purpose: they may
# be empty or expand to several arguments.
CUDA_VISIBLE_DEVICES=${device} python3 run_seq2seq.py \
  --use_fast_tokenizer=True \
  --max_source_length=${max_source_length:-"256"} \
  --max_target_length=${max_target_length:-"192"} \
  --do_eval --do_predict --task=record --predict_with_generate \
  --validation_file=${data_folder}/val.json \
  --test_file=${data_folder}/test.json \
  --record_schema=${data_folder}/record.schema \
  --model_name_or_path=${model_path} \
  --output_dir=${output_dir} \
  --source_prefix="${task_name}: " \
  --no_remove_unused_columns \
  --num_beams=${beam_size} \
  ${constraint_decoding} ${extra_cmd} \
  --per_device_eval_batch_size=${batch} \
  --decoding_format ${decoding_format}

python3 scripts/sel2record.py -p ${output_dir} -g ${data_folder} -v -d ${decoding_format} -c ${map_config}
python3 scripts/eval_extraction.py -p ${output_dir} -g ${data_folder} -w -m ${eval_match_mode:-"normal"}

echo "Output Dir:" ${output_dir}
+#!/usr/bin/env python +# coding=utf-8 +# Copyright The HuggingFace Team and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Fine-tuning the library models for sequence to sequence. +""" +# You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments. + +import logging +import os +import sys +from dataclasses import dataclass, field +os.environ['CUDA_LAUNCH_BLOCKING'] = '1' + +from typing import Optional + +import numpy as np +from datasets import load_dataset + +import transformers +from transformers import ( + AutoConfig, + AutoModelForSeq2SeqLM, + AutoTokenizer, + DataCollatorForSeq2Seq, + HfArgumentParser, + default_data_collator, + set_seed +) +from transformers.trainer_utils import get_last_checkpoint, is_main_process + +from uie.extraction import constants +from uie.extraction.record_schema import RecordSchema +from uie.extraction.predict_parser import decoding_format_dict +from uie.extraction.extraction_metrics import get_extract_metrics +from uie.extraction.noiser.spot_asoc_noiser import SpotAsocNoiser +from uie.extraction.dataset_processer import PrefixGenerator +from uie.seq2seq.constrained_seq2seq import ( + ConstraintSeq2SeqTrainingArguments, + ConstraintSeq2SeqTrainer, + OriginalConstraintSeq2SeqTrainer, + UIEPretrainConstraintSeq2SeqTrainer, + UIEFinetuneConstraintSeq2SeqTrainer, + MetaPretrainConstraintSeq2SeqTrainer, + 
@dataclass
class ModelArguments:
    """
    Command-line arguments selecting the model, config and tokenizer to load.

    Only ``model_name_or_path`` is required; every other field has a default.
    Each field's ``help`` metadata holds its user-facing description.
    """

    # --- what to load -------------------------------------------------
    model_name_or_path: str = field(
        metadata=dict(
            help="Path to pretrained model or model identifier from huggingface.co/models",
        ),
    )
    config_name: Optional[str] = field(
        default=None,
        metadata=dict(
            help="Pretrained config name or path if not the same as model_name",
        ),
    )
    tokenizer_name: Optional[str] = field(
        default=None,
        metadata=dict(
            help="Pretrained tokenizer name or path if not the same as model_name",
        ),
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata=dict(
            help="Where to store the pretrained models downloaded from huggingface.co",
        ),
    )
    use_fast_tokenizer: bool = field(
        default=False,
        metadata=dict(
            help="Whether to use one of the fast tokenizer (backed by the tokenizers library) or not.",
        ),
    )
    model_revision: str = field(
        default="main",
        metadata=dict(
            help="The specific model version to use (can be a branch name, tag name or commit id).",
        ),
    )
    use_auth_token: bool = field(
        default=False,
        metadata=dict(
            help=(
                "Will use the token generated when running `transformers-cli login` "
                "(necessary to use this script with private models)."
            ),
        ),
    )

    # --- checkpoint / model-variant switches --------------------------
    from_checkpoint: bool = field(
        default=False,
        metadata=dict(help="Whether load from checkpoint to continue learning"),
    )
    load_config_only: bool = field(
        default=False,
        metadata=dict(help="Whether load model config only from checkpoint"),
    )
    use_prompt_tuning_model: bool = field(
        default=False,
        metadata=dict(help="Whether use prompt tuning model"),
    )
+ }, + ) + overwrite_cache: bool = field( + default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} + ) + preprocessing_num_workers: Optional[int] = field( + default=None, + metadata={"help": "The number of processes to use for the preprocessing."}, + ) + max_source_length: Optional[int] = field( + default=1024, + metadata={ + "help": "The maximum total input sequence length after tokenization. Sequences longer " + "than this will be truncated, sequences shorter will be padded." + }, + ) + max_target_length: Optional[int] = field( + default=128, + metadata={ + "help": "The maximum total sequence length for target text after tokenization. Sequences longer " + "than this will be truncated, sequences shorter will be padded." + }, + ) + max_prefix_length: Optional[int] = field( + default=None, + metadata={ + "help": "The maximum prefix length." + }, + ) + val_max_target_length: Optional[int] = field( + default=None, + metadata={ + "help": "The maximum total sequence length for validation target text after tokenization. Sequences longer " + "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`." + "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used " + "during ``evaluate`` and ``predict``." + }, + ) + pad_to_max_length: bool = field( + default=False, + metadata={ + "help": "Whether to pad all samples to model maximum sentence length. " + "If False, will pad the samples dynamically when batching to the maximum length in the batch. More " + "efficient on GPU but very bad for TPU." + }, + ) + max_train_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." 
+ }, + ) + max_val_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of validation examples to this " + "value if set." + }, + ) + max_test_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of test examples to this " + "value if set." + }, + ) + num_beams: Optional[int] = field( + default=None, + metadata={ + "help": "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, " + "which is used during ``evaluate`` and ``predict``." + }, + ) + ignore_pad_token_for_loss: bool = field( + default=True, + metadata={ + "help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not." + }, + ) + source_prefix: Optional[str] = field( + default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."} + ) + meta_negative: int = field( + default=-1, metadata={"help": "Negative Schema Number in Training."} + ) + ordered_prompt: bool = field( + default=True, + metadata={ + "help": "Whether to sort the spot prompt and asoc prompt or not." + }, + ) + + def __post_init__(self): + if self.dataset_name is None and self.train_file is None and self.validation_file is None: + raise ValueError("Need either a dataset name or a training/validation file.") + else: + if self.train_file is not None: + extension = self.train_file.split(".")[-1] + assert extension in ["csv", "json"], "`train_file` should be a csv or a json file." + if self.validation_file is not None: + extension = self.validation_file.split(".")[-1] + assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file." 
+ if self.val_max_target_length is None: + self.val_max_target_length = self.max_target_length + + decoding_format: str = field( + default='tree', + metadata={"help": "Decoding Format, valid in %s" % decoding_format_dict.keys()} + ) + record_schema: str = field( + default=None, metadata={"help": "The input event schema file."} + ) + spot_noise: float = field( + default=0., metadata={"help": "The noise rate of null spot."} + ) + asoc_noise: float = field( + default=0., metadata={"help": "The noise rate of null asoc."} + ) + meta_positive_rate: float = field( + default=1., metadata={"help": "The keep rate of positive spot."} + ) + + +def main(): + # See all possible arguments in src/transformers/training_args.py + # or by passing the --help flag to this script. + # We now keep distinct sets of args, for a cleaner separation of concerns. + + parser = HfArgumentParser((ModelArguments, DataTrainingArguments, ConstraintSeq2SeqTrainingArguments)) + if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): + # If we pass only one argument to the script and it's the path to a json file, + # let's parse it to get our arguments. + model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) + else: + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + + # Detecting last checkpoint. + last_checkpoint = None + if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: + last_checkpoint = get_last_checkpoint(training_args.output_dir) + if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: + raise ValueError( + f"Output directory ({training_args.output_dir}) already exists and is not empty. " + "Use --overwrite_output_dir to overcome." + ) + elif last_checkpoint is not None: + logger.info( + f"Checkpoint detected, resuming training at {last_checkpoint}. 
To avoid this behavior, change " + "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." + ) + + # Setup logging + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + handlers=[logging.StreamHandler(sys.stdout)], + ) + logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN) + + logger.info("Options:") + logger.info(model_args) + logger.info(data_args) + logger.info(training_args) + + # Log on each process the small summary: + logger.warning( + f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" + + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" + ) + # Set the verbosity to info of the Transformers logger (on main process only): + if is_main_process(training_args.local_rank): + transformers.utils.logging.set_verbosity_info() + logger.info("Training/evaluation parameters %s", training_args) + + # Set seed before initializing model. + set_seed(training_args.seed) + + # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below) + # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ + # (the dataset will be downloaded automatically from the datasets Hub). + # + # For CSV/JSON files in the summarization task, this script will use the first column for the full texts and the + # second column for the summaries (unless you specify column names for this with the `text_column` and + # `record_column` arguments). + # For translation, only JSON files are supported, with one field named "translation" containing two keys for the + # source and target languages (unless you adapt what follows). + # + # In distributed training, the load_dataset function guarantee that only one local process can concurrently + # download the dataset. 
+ if data_args.dataset_name is not None: + # Downloading and loading a dataset from the hub. + datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name) + else: + data_files = {} + if data_args.train_file is not None: + data_files["train"] = data_args.train_file + extension = data_args.train_file.split(".")[-1] + if training_args.do_eval and data_args.validation_file is not None: + data_files["validation"] = data_args.validation_file + extension = data_args.validation_file.split(".")[-1] + if training_args.do_predict and data_args.test_file is not None: + data_files["test"] = data_args.test_file + extension = data_args.test_file.split(".")[-1] + logger.info(data_files) + + datasets = load_dataset("uie_json.py", data_files=data_files, block_size=(10<<22)) + # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at + # https://huggingface.co/docs/datasets/loading_datasets.html. + logger.info(datasets) + # Load pretrained model and tokenizer + # + # Distributed training: + # The .from_pretrained methods guarantee that only one local process can concurrently + # download model & vocab. 
+ logger.info("Load Config: %s" % model_args.config_name if model_args.config_name else model_args.model_name_or_path) + + config = AutoConfig.from_pretrained( + model_args.config_name if model_args.config_name else model_args.model_name_or_path, + cache_dir=model_args.cache_dir, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + ) + + config.max_length = data_args.max_target_length + + tokenizer = AutoTokenizer.from_pretrained( + model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, + cache_dir=model_args.cache_dir, + use_fast=model_args.use_fast_tokenizer, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + ) + + to_remove_token_list = list() + if tokenizer.bos_token: + to_remove_token_list += [tokenizer.bos_token] + if tokenizer.eos_token: + to_remove_token_list += [tokenizer.eos_token] + if tokenizer.pad_token: + to_remove_token_list += [tokenizer.pad_token] + + if model_args.use_prompt_tuning_model: + MODEL = PromptSeq2SeqTransformer + else: + MODEL = AutoModelForSeq2SeqLM + + if model_args.load_config_only: + model = MODEL.from_config(config) + else: + model = MODEL.from_pretrained( + model_args.model_name_or_path, + from_tf=bool(".ckpt" in model_args.model_name_or_path), + config=config, + cache_dir=model_args.cache_dir, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + mirror='tuna', + ) + + if training_args.do_train: + to_add_special_token = list() + for special_token in [constants.type_start, constants.type_end, constants.text_start, constants.span_start, constants.spot_prompt, constants.asoc_prompt]: + if special_token not in tokenizer.get_vocab(): + to_add_special_token += [special_token] + + tokenizer.add_special_tokens( + {"additional_special_tokens": tokenizer.special_tokens_map_extended['additional_special_tokens'] + to_add_special_token} + ) + + 
model.resize_token_embeddings(len(tokenizer)) + + logger.info(tokenizer) + + # Set decoder_start_token_id + if model.config.decoder_start_token_id is None: + raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined") + + if data_args.record_schema and os.path.exists(data_args.record_schema): + record_schema = RecordSchema.read_from_file(data_args.record_schema) + else: + record_schema = None + + if data_args.source_prefix is not None: + if data_args.source_prefix == 'schema': + prefix = PrefixGenerator.get_schema_prefix(schema=record_schema) + elif data_args.source_prefix.startswith('meta'): + prefix = "" + else: + prefix = data_args.source_prefix + else: + prefix = "" + logger.info(f"Prefix: {prefix}") + logger.info(f"Prefix Length: {len(tokenizer.tokenize(prefix))}") + + # Preprocessing the datasets. + # We need to tokenize inputs and targets. + if training_args.do_train: + column_names = datasets["train"].column_names + elif training_args.do_eval: + column_names = datasets["validation"].column_names + elif training_args.do_predict: + column_names = datasets["test"].column_names + else: + logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.") + return + + # To serialize preprocess_function below, each of those four variables needs to be defined (even if we won't use + # them all). + + text_column = data_args.text_column + record_column = data_args.record_column + logger.info('Using src: %s and tgt: %s' % (text_column, record_column)) + + # Temporarily set max_target_length for training. + max_target_length = data_args.max_target_length + padding = "max_length" if data_args.pad_to_max_length else False + + if training_args.label_smoothing_factor > 0 and not hasattr(model, "prepare_decoder_input_ids_from_labels"): + logger.error( + "label_smoothing is enabled but the `prepare_decoder_input_ids_from_labels` method is not defined for" + f"`{model.__class__.__name__}`. 
This will lead to loss being calculated twice and will take up more memory" + ) + + def preprocess_function(examples): + inputs = examples[text_column] + targets = examples[record_column] + inputs = [prefix + inp for inp in inputs] + model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding=padding, truncation=True) + + model_inputs["text"] = inputs + + # Setup the tokenizer for targets + with tokenizer.as_target_tokenizer(): + labels = tokenizer(targets, max_length=max_target_length, padding=padding, truncation=True) + + # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore + # padding in the loss. + if padding == "max_length" and data_args.ignore_pad_token_for_loss: + labels["input_ids"] = [ + [(_label if _label != tokenizer.pad_token_id else -100) for _label in label] for label in labels["input_ids"] + ] + + model_inputs["labels"] = labels["input_ids"] + + # set noised record inputs + noised_record_list = [] + for idx, noised_record in enumerate(examples["noised_record"]): + if noised_record is None: + tokens = examples["tokens"][idx] + entity_list = examples["entity"][idx] + triple_list = examples["relation"][idx] + event_list = examples["event"][idx] + + noised_record = create_noised_record(tokens, entity_list, triple_list, event_list) + noised_record_list.append(noised_record) + model_inputs["noised_record"] = noised_record_list + # model_inputs["noised_record"] = examples["noised_record"] + + # others + model_inputs['sample_prompt'] = [False] * len(model_inputs['input_ids']) + if data_args.source_prefix is not None and data_args.source_prefix.startswith('meta'): + model_inputs['spots'] = examples['spot'] + model_inputs['asocs'] = examples['asoc'] + model_inputs['spot_asoc'] = examples['spot_asoc'] + # sample_prompt=True for Finetune and Pretrain + model_inputs['sample_prompt'] = [True] * len(model_inputs['input_ids']) + + return model_inputs + + def preprocess_function_eval(examples): 
+ model_inputs = preprocess_function(examples) + # sample_prompt=False for evaluation + model_inputs['sample_prompt'] = [False] * len(model_inputs['input_ids']) + return model_inputs + + def postprocess_text(x_str): + # Clean `bos` `eos` `pad` for cleaned text + for to_remove_token in to_remove_token_list: + x_str = x_str.replace(to_remove_token, '') + + return x_str.strip() + + logger.info("Start Data Preprocessing ...") + + if training_args.do_train: + train_dataset = datasets["train"] + if data_args.max_train_samples is not None: + train_dataset = train_dataset.select(range(data_args.max_train_samples)) + train_dataset = train_dataset.map( + preprocess_function, + batched=True, + num_proc=data_args.preprocessing_num_workers, + remove_columns=column_names, + load_from_cache_file=not data_args.overwrite_cache, + features=RecordFeature, + ) + + if training_args.do_eval: + max_target_length = data_args.val_max_target_length + eval_dataset = datasets["validation"] + if data_args.max_val_samples is not None: + eval_dataset = eval_dataset.select(range(data_args.max_val_samples)) + eval_dataset = eval_dataset.map( + preprocess_function_eval, + batched=True, + num_proc=data_args.preprocessing_num_workers, + remove_columns=column_names, + load_from_cache_file=not data_args.overwrite_cache, + features=RecordFeature, + ) + + if training_args.do_predict: + max_target_length = data_args.val_max_target_length + test_dataset = datasets["test"] + if data_args.max_test_samples is not None: + test_dataset = test_dataset.select(range(data_args.max_test_samples)) + test_dataset = test_dataset.map( + preprocess_function_eval, + batched=True, + num_proc=data_args.preprocessing_num_workers, + remove_columns=column_names, + load_from_cache_file=not data_args.overwrite_cache, + features=RecordFeature, + ) + + logger.info("End Data Preprocessing ...") + + # Data collator + label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id + if 
data_args.pad_to_max_length: + data_collator = default_data_collator + elif data_args.source_prefix.startswith('meta'): + + if data_args.spot_noise > 0 or data_args.asoc_noise > 0: + if data_args.decoding_format == 'spotasoc': + spot_asoc_nosier = SpotAsocNoiser( + spot_noise_ratio=data_args.spot_noise, + asoc_noise_ratio=data_args.asoc_noise, + null_span=constants.null_span, + ) + else: + raise NotImplementedError( + "decoding_format `spotasoc` is not implemented." + ) + else: + spot_asoc_nosier = None + + data_collator = DataCollatorForMetaSeq2Seq( + tokenizer, + model=model, + label_pad_token_id=label_pad_token_id, + pad_to_multiple_of=8 if training_args.fp16 else None, + max_length=data_args.max_source_length, + max_prefix_length=data_args.max_prefix_length, + max_target_length=data_args.max_target_length, + negative_sampler=DynamicSSIGenerator( + tokenizer=tokenizer, + schema=record_schema, + positive_rate=data_args.meta_positive_rate, + negative=data_args.meta_negative, + ordered_prompt=data_args.ordered_prompt, + ), + spot_asoc_nosier=spot_asoc_nosier, + decoding_format=data_args.decoding_format, + ) + else: + data_collator = DataCollatorForSeq2Seq( + tokenizer, + model=model, + label_pad_token_id=label_pad_token_id, + pad_to_multiple_of=8 if training_args.fp16 else None, + ) + + def compute_metrics(eval_preds): + preds, labels = eval_preds + if isinstance(preds, tuple): + preds = preds[0] + decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=False, clean_up_tokenization_spaces=False) + if data_args.ignore_pad_token_for_loss: + # Replace -100 in the labels as we can't decode them. 
+ labels = np.where(labels != -100, labels, tokenizer.pad_token_id) + decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=False, clean_up_tokenization_spaces=False) + + decoded_preds = [postprocess_text(x) for x in decoded_preds] + decoded_labels = [postprocess_text(x) for x in decoded_labels] + + result = get_extract_metrics( + pred_lns=decoded_preds, + tgt_lns=decoded_labels, + label_constraint=record_schema, + decoding_format=data_args.decoding_format, + ) + + prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds] + result["gen_len"] = np.mean(prediction_lens) + result = {k: round(v, 4) for k, v in result.items()} + return result + + # Initialize our Trainer + if training_args.trainer_type == "uie_pretrain": + TRAINER = UIEPretrainConstraintSeq2SeqTrainer + elif training_args.trainer_type == "uie_finetune": + TRAINER = UIEFinetuneConstraintSeq2SeqTrainer + elif training_args.trainer_type == "meta_pretrain": + TRAINER = MetaPretrainConstraintSeq2SeqTrainer + elif training_args.trainer_type == "meta_finetune": + TRAINER = MetaFinetuneConstraintSeq2SeqTrainer + else: + TRAINER = OriginalConstraintSeq2SeqTrainer + + trainer = TRAINER( + model=model, + args=training_args, + train_dataset=train_dataset if training_args.do_train else None, + eval_dataset=eval_dataset if training_args.do_eval else None, + tokenizer=tokenizer, + data_collator=data_collator, + compute_metrics=compute_metrics if training_args.predict_with_generate else None, + decoding_type_schema=record_schema, + decoding_format=data_args.decoding_format, + source_prefix=prefix, + task=data_args.task, + ) + + # Training + if training_args.do_train: + if model_args.from_checkpoint: + if last_checkpoint is not None: + checkpoint = last_checkpoint + elif os.path.isdir(model_args.model_name_or_path): + checkpoint = model_args.model_name_or_path + else: + checkpoint = None + else: + checkpoint = None + + train_result = 
trainer.train(resume_from_checkpoint=checkpoint) + trainer.save_model() # Saves the tokenizer too for easy upload + + output_train_file = os.path.join(training_args.output_dir, "train_results.txt") + if trainer.is_world_process_zero(): + with open(output_train_file, "w") as writer: + logger.info("***** Train results *****") + for key, value in sorted(train_result.metrics.items()): + logger.info(f" {key} = {value}") + writer.write(f"{key} = {value}\n") + + # Need to save the state, since Trainer.save_model saves only the tokenizer with the model + trainer.state.save_to_json(os.path.join(training_args.output_dir, "trainer_state.json")) + + # Evaluation + results = {} + if training_args.do_eval: + logger.info("*** Evaluate ***") + + results = trainer.evaluate(max_length=data_args.val_max_target_length, num_beams=data_args.num_beams) + results = {k: round(v, 4) for k, v in results.items()} + + eval_results = trainer.predict( + eval_dataset, + metric_key_prefix="eval", + max_length=data_args.val_max_target_length, + num_beams=data_args.num_beams, + ) + + output_eval_file = os.path.join(training_args.output_dir, "eval_results_seq2seq.txt") + if trainer.is_world_process_zero(): + with open(output_eval_file, "w") as writer: + logger.info("***** Eval results *****") + for key, value in sorted(results.items()): + logger.info(f" {key} = {value}") + writer.write(f"{key} = {value}\n") + + if training_args.predict_with_generate: + eval_preds = tokenizer.batch_decode( + eval_results.predictions, skip_special_tokens=False, clean_up_tokenization_spaces=False + ) + eval_preds = [postprocess_text(pred) for pred in eval_preds] + output_test_preds_file = os.path.join(training_args.output_dir, "eval_preds_seq2seq.txt") + with open(output_test_preds_file, "w") as writer: + writer.write("\n".join(eval_preds)) + + if training_args.do_predict: + logger.info("*** Test ***") + + test_results = trainer.predict( + test_dataset, + metric_key_prefix="test", + 
max_length=data_args.val_max_target_length, + num_beams=data_args.num_beams, + ) + test_metrics = test_results.metrics + test_metrics["test_loss"] = round(test_metrics["test_loss"], 4) + + output_test_result_file = os.path.join(training_args.output_dir, "test_results_seq2seq.txt") + if trainer.is_world_process_zero(): + with open(output_test_result_file, "w") as writer: + logger.info("***** Test results *****") + for key, value in sorted(test_metrics.items()): + logger.info(f" {key} = {value}") + writer.write(f"{key} = {value}\n") + + if training_args.predict_with_generate: + test_preds = tokenizer.batch_decode( + test_results.predictions, skip_special_tokens=False, clean_up_tokenization_spaces=False + ) + test_preds = [postprocess_text(pred) for pred in test_preds] + output_test_preds_file = os.path.join(training_args.output_dir, "test_preds_seq2seq.txt") + with open(output_test_preds_file, "w") as writer: + writer.write("\n".join(test_preds)) + + return results + + +def _mp_fn(index): + # For xla_spawn (TPUs) + main() + + +if __name__ == "__main__": + main() diff --git a/metaretriever/run_seq2seq_pretrain.bash b/metaretriever/run_seq2seq_pretrain.bash new file mode 100644 index 00000000..89fb91b4 --- /dev/null +++ b/metaretriever/run_seq2seq_pretrain.bash @@ -0,0 +1,104 @@ +#!/usr/bin/env bash +# -*- coding:utf-8 -*- +export batch_size="16" +export model_name=uie-base-en +export data_name=absa/14lap +export task_name="meta" +export decoding_format='spotasoc' + +source scripts/function_code.bash + +for index in $(seq 1 ${run_time}); do + + if [[ ! 
${output_dir} ]] + then + output_dir=${model_folder}_run${index} + echo "output_dir is not provided so create it automatically: ${output_dir}" + else + echo "output_dir is provided: ${output_dir}" + fi + + if [[ ${verbose} == true ]] + then + stdout_file=/dev/stdout + stderr_file=/dev/stderr + disable_tqdm=False + else + stdout_file=${output_dir}.log + stderr_file=${output_dir}.err + disable_tqdm=True + fi + + # CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES} gdb --args ${run_command} run_seq2seq.py \ + CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES} ${run_command} run_seq2seq.py \ + --do_train ${constraint_decoding} ${fp16} \ + --trainer_type=${trainer_type} \ + --load_config_only=False \ + --use_fast_tokenizer=True \ + --ddp_find_unused_parameters=False \ + --predict_with_generate \ + --evaluation_strategy="no" \ + --metric_for_best_model eval_overall-F1 \ + --save_strategy="steps" \ + --save_steps=10000 \ + --save_total_limit 9999999 \ + --load_best_model_at_end=False \ + --max_source_length="128" \ + --max_prefix_length="-1" \ + --max_target_length="128" \ + --num_train_epochs=${epoch} \ + --task=${task_name} \ + --train_file=${data_folder}/train.json \ + --validation_file=${data_folder}/val.json \ + --test_file=${data_folder}/test.json \ + --record_schema=${data_folder}/record.schema \ + --per_device_train_batch_size=${batch_size} \ + --per_device_eval_batch_size=$((batch_size * 4)) \ + --output_dir=${output_dir} \ + --from_checkpoint=True \ + --logging_dir=${output_dir}_log \ + --logging_strategy="steps" \ + --logging_first_step=True \ + --logging_steps=100 \ + --model_name_or_path=${model_name} \ + --learning_rate=${lr} \ + --source_prefix="${task_name}: " \ + --lr_scheduler_type=${lr_scheduler} \ + --label_smoothing_factor=${label_smoothing} \ + --eval_steps ${eval_steps} \ + --decoding_format ${decoding_format} \ + --warmup_ratio ${warmup_ratio} \ + --preprocessing_num_workers=32 \ + --dataloader_num_workers=32 \ + --meta_negative=10 \ + 
--meta_positive_rate=${positive} \ + --skip_memory_metrics \ + --no_remove_unused_columns \ + --ordered_prompt=${ordered_prompt} \ + --save_better_checkpoint=False \ + --start_eval_step=${start_eval_step:-"0"} \ + --spot_noise=${spot_noise} \ + --asoc_noise=${asoc_noise} \ + --seed=${seed}${index} --disable_tqdm=${disable_tqdm} >${stdout_file} 2>${stderr_file} + echo "exit code:" $? + + # --max_source_length=${max_source_length:-"128"} \ + # --max_prefix_length=${max_prefix_length:-"-1"} \ + # --max_target_length=${max_target_length:-"128"} \ + # --save_strategy=${evaluation_strategy} \ + # --save_total_limit 1 \ + # --load_best_model_at_end \ + + if [[ ${verbose} != true ]] + then + tail -n 200 ${stderr_file} + fi + + # echo "Map Config" ${map_config} + # python3 scripts/sel2record.py -p ${output_dir} -g ${data_folder} -v -d ${decoding_format} -c ${map_config} + # python3 scripts/eval_extraction.py -p ${output_dir} -g ${data_folder} -w -m ${eval_match_mode:-"normal"} + + # delete all optimizer.pt for saving disk + find ${output_dir}/ | grep -P "optimizer.pt" | xargs rm -rf + +done diff --git a/metaretriever/run_seq2seq_record.bash b/metaretriever/run_seq2seq_record.bash new file mode 100644 index 00000000..9a7b1bb3 --- /dev/null +++ b/metaretriever/run_seq2seq_record.bash @@ -0,0 +1,107 @@ +#!/usr/bin/env bash +# -*- coding:utf-8 -*- +export batch_size="16" +export model_name=uie-base-en +export data_name=absa/14lap +export task_name="meta" +export decoding_format='spotasoc' + +source scripts/function_code.bash + +for index in $(seq 1 ${run_time}); do + + if [[ ! 
${output_dir} ]] + then + output_dir=${model_folder}_run${index} + echo "output_dir is not provided so create it automatically: ${output_dir}" + else + echo "output_dir is provided: ${output_dir}" + fi + + # output_dir=${model_folder}_run${index} + + if [[ ${verbose} == true ]] + then + stdout_file=/dev/stdout + stderr_file=/dev/stderr + disable_tqdm=False + else + stdout_file=${output_dir}.log + stderr_file=${output_dir}.err + disable_tqdm=True + fi + + # CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES} gdb --args ${run_command} run_seq2seq.py \ + CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES} ${run_command} run_seq2seq.py \ + --do_train --do_eval --do_predict ${constraint_decoding} ${fp16} \ + --use_prompt_tuning_model=${use_prompt_tuning_model} \ + --trainer_type=${trainer_type} \ + --load_config_only=False \ + --use_fast_tokenizer=True \ + --ddp_find_unused_parameters=False \ + --predict_with_generate \ + --evaluation_strategy=${evaluation_strategy} \ + --save_strategy=${evaluation_strategy} \ + --metric_for_best_model eval_overall-F1 \ + --save_total_limit 1 \ + --load_best_model_at_end \ + --max_source_length=${max_source_length:-"256"} \ + --max_prefix_length=${max_prefix_length:-"-1"} \ + --max_target_length=${max_target_length:-"192"} \ + --num_train_epochs=${epoch} \ + --task=${task_name} \ + --train_file=${data_folder}/train.json \ + --validation_file=${data_folder}/val.json \ + --test_file=${data_folder}/test.json \ + --record_schema=${data_folder}/record.schema \ + --per_device_train_batch_size=${batch_size} \ + --per_device_eval_batch_size=$((batch_size * 4)) \ + --output_dir=${output_dir} \ + --logging_dir=${output_dir}_log \ + --logging_strategy="steps" \ + --logging_first_step=True \ + --logging_steps=100 \ + --model_name_or_path=${model_name} \ + --learning_rate=${lr} \ + --source_prefix="${task_name}: " \ + --lr_scheduler_type=${lr_scheduler} \ + --label_smoothing_factor=${label_smoothing} \ + --eval_steps ${eval_steps} \ + --decoding_format 
${decoding_format} \ + --warmup_ratio ${warmup_ratio} \ + --preprocessing_num_workers=32 \ + --dataloader_num_workers=32 \ + --meta_negative=${negative} \ + --meta_positive_rate=${positive} \ + --skip_memory_metrics \ + --no_remove_unused_columns \ + --ordered_prompt=${ordered_prompt} \ + --save_better_checkpoint=False \ + --start_eval_step=${start_eval_step:-"0"} \ + --spot_noise=${spot_noise} \ + --asoc_noise=${asoc_noise} \ + --seed=${seed}${index} --disable_tqdm=${disable_tqdm} >${stdout_file} 2>${stderr_file} + echo "exit code:" $? + + # --save_strategy=${evaluation_strategy} \ + # --save_total_limit 1 \ + # --load_best_model_at_end \ + + # --save_strategy="steps" \ + # --save_steps=5000 \ + # --save_total_limit 9999999 \ + # --load_best_model_at_end=True \ + + if [[ ${verbose} != true ]] + then + tail -n 200 ${stderr_file} + fi + + echo "Map Config" ${map_config} + python3 scripts/sel2record.py -p ${output_dir} -g ${data_folder} -v -d ${decoding_format} -c ${map_config} + python3 scripts/eval_extraction.py -p ${output_dir} -g ${data_folder} -w -m ${eval_match_mode:-"normal"} + + # delete all optimizer.pt for saving disk + # find ${output_dir}/ | grep -P "optimizer.pt" | xargs rm -rf + +done diff --git a/metaretriever/run_seq2seq_record_ratio.bash b/metaretriever/run_seq2seq_record_ratio.bash new file mode 100644 index 00000000..f55e9e69 --- /dev/null +++ b/metaretriever/run_seq2seq_record_ratio.bash @@ -0,0 +1,96 @@ +#!/usr/bin/env bash +# -*- coding:utf-8 -*- +export batch_size="16" +export model_name=uie-base-en +export data_name=absa/14lap +export task_name="meta" +export decoding_format='spotasoc' + +source scripts/function_code.bash + +for index in $(seq 1 ${run_time}); do + output_dir=${model_folder}_run${index} + + if [[ ${verbose} == true ]] + then + stdout_file=/dev/stdout + stderr_file=/dev/stderr + disable_tqdm=False + else + stdout_file=${output_dir}.log + stderr_file=${output_dir}.err + disable_tqdm=True + fi + + 
ratio_data_folder=${data_folder}_ratio/seed${index} + + for ratio in $(ls ${ratio_data_folder}) + do + run_data_folder=${ratio_data_folder}/${ratio} + run_output_folder=${output_dir}_${ratio} + + if [[ ${max_prefix_length} == 0 ]] + then + run_output_folder=${run_output_folder}_noprefix + fi + + eval_steps=$(python scripts/get_eval_batch_num.py ${run_data_folder}/train.json ${batch_size} 20) + echo Eval each ${eval_steps} batch + + CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES} ${run_command} run_seq2seq.py \ + --do_train --do_eval --do_predict ${constraint_decoding} ${fp16} \ + --trainer_type=${trainer_type} \ + --load_config_only=False \ + --use_fast_tokenizer=True \ + --ddp_find_unused_parameters=False \ + --predict_with_generate \ + --evaluation_strategy=steps \ + --save_strategy=steps \ + --load_best_model_at_end \ + --metric_for_best_model eval_overall-F1 \ + --save_total_limit 1 \ + --max_source_length=${max_source_length:-"256"} \ + --max_prefix_length=${max_prefix_length:-"-1"} \ + --max_target_length=${max_target_length:-"192"} \ + --num_train_epochs=${epoch} \ + --task=${task_name} \ + --train_file=${run_data_folder}/train.json \ + --validation_file=${run_data_folder}/val.json \ + --test_file=${run_data_folder}/test.json \ + --record_schema=${run_data_folder}/record.schema \ + --per_device_train_batch_size=${batch_size} \ + --per_device_eval_batch_size=$((batch_size * 4)) \ + --output_dir=${run_output_folder} \ + --logging_dir=${run_output_folder}_log \ + --model_name_or_path=${model_name} \ + --learning_rate=${lr} \ + --source_prefix="${task_name}: " \ + --lr_scheduler_type=${lr_scheduler} \ + --label_smoothing_factor=${label_smoothing} \ + --eval_steps ${eval_steps} \ + --decoding_format ${decoding_format} \ + --warmup_ratio ${warmup_ratio} \ + --preprocessing_num_workers=4 \ + --dataloader_num_workers=0 \ + --meta_negative=${negative} \ + --meta_positive_rate=${positive} \ + --skip_memory_metrics \ + --no_remove_unused_columns \ + 
--ordered_prompt=${ordered_prompt} \ + --save_better_checkpoint=True \ + --spot_noise=${spot_noise} \ + --asoc_noise=${asoc_noise} \ + --seed=${seed} --disable_tqdm=${disable_tqdm} >${stdout_file} 2>${stderr_file} + + echo "Map Config" ${map_config} + python3 scripts/sel2record.py -p ${run_output_folder} -g ${run_data_folder} -v -d ${decoding_format} -c ${map_config} + python3 scripts/eval_extraction.py -p ${run_output_folder} -g ${run_data_folder} -w -m ${eval_match_mode:-"normal"} + + # delete all pytorch_model.bin of checkpoints in low-resource exps for saving disk + # find ${run_output_folder}/ | grep -P "checkpoint-\d+/pytorch_model.bin" | xargs rm -rf + # delete all optimizer.pt in low-resource exps for saving disk + # find ${run_output_folder}/ | grep -P "optimizer.pt" | xargs rm -rf + + done + +done diff --git a/metaretriever/run_seq2seq_record_shot.bash b/metaretriever/run_seq2seq_record_shot.bash new file mode 100644 index 00000000..44afd53f --- /dev/null +++ b/metaretriever/run_seq2seq_record_shot.bash @@ -0,0 +1,99 @@ +#!/usr/bin/env bash +# -*- coding:utf-8 -*- +export batch_size="16" +export model_name=uie-base-en +export data_name=absa/14lap +export task_name="meta" +export decoding_format='spotasoc' + +source scripts/function_code.bash + +for index in $(seq 1 ${run_time}); do + output_dir=${model_folder}_run${index} + + if [[ ${verbose} == true ]] + then + stdout_file=/dev/stdout + stderr_file=/dev/stderr + disable_tqdm=False + else + stdout_file=${output_dir}.log + stderr_file=${output_dir}.err + disable_tqdm=True + fi + + shot_data_folder=${data_folder}_shot/seed${index} + + for shot in $(ls ${shot_data_folder}) + do + + run_data_folder=${shot_data_folder}/${shot} + run_output_folder=${output_dir}_${shot} + + if [[ ${max_prefix_length} == 0 ]] + then + run_output_folder=${run_output_folder}_noprefix + fi + + echo ${run_data_folder} + + eval_steps=$(python scripts/get_eval_batch_num.py ${run_data_folder}/train.json ${batch_size} 20) + echo Eval 
each ${eval_steps} batch + + CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES} ${run_command} run_seq2seq.py \ + --do_train --do_eval --do_predict ${constraint_decoding} ${fp16} \ + --trainer_type=${trainer_type} \ + --load_config_only=False \ + --use_fast_tokenizer=True \ + --ddp_find_unused_parameters=False \ + --predict_with_generate \ + --evaluation_strategy=steps \ + --save_strategy=steps \ + --load_best_model_at_end \ + --metric_for_best_model eval_overall-F1 \ + --save_total_limit 1 \ + --max_source_length=${max_source_length:-"256"} \ + --max_prefix_length=${max_prefix_length:-"-1"} \ + --max_target_length=${max_target_length:-"192"} \ + --num_train_epochs=${epoch} \ + --task=${task_name} \ + --train_file=${run_data_folder}/train.json \ + --validation_file=${run_data_folder}/val.json \ + --test_file=${run_data_folder}/test.json \ + --record_schema=${run_data_folder}/record.schema \ + --per_device_train_batch_size=${batch_size} \ + --per_device_eval_batch_size=$((batch_size * 4)) \ + --output_dir=${run_output_folder} \ + --logging_dir=${run_output_folder}_log \ + --model_name_or_path=${model_name} \ + --learning_rate=${lr} \ + --source_prefix="${task_name}: " \ + --lr_scheduler_type=${lr_scheduler} \ + --label_smoothing_factor=${label_smoothing} \ + --eval_steps ${eval_steps} \ + --decoding_format ${decoding_format} \ + --warmup_ratio ${warmup_ratio} \ + --preprocessing_num_workers=4 \ + --dataloader_num_workers=0 \ + --meta_negative=${negative} \ + --meta_positive_rate=${positive} \ + --skip_memory_metrics \ + --no_remove_unused_columns \ + --ordered_prompt=${ordered_prompt} \ + --save_better_checkpoint=True \ + --spot_noise=${spot_noise} \ + --asoc_noise=${asoc_noise} \ + --seed=${seed} --disable_tqdm=${disable_tqdm} >${stdout_file} 2>${stderr_file} + + echo "Map Config" ${map_config} + python3 scripts/sel2record.py -p ${run_output_folder} -g ${run_data_folder} -v -d ${decoding_format} -c ${map_config} + python3 scripts/eval_extraction.py -p 
def read_file(file_name):
    """Return every line of ``file_name`` as a list of strings.

    Lines keep their trailing newline, exactly as ``readlines()`` yields
    them. The file is opened inside a ``with`` block so the handle is closed
    before returning instead of being left open until garbage collection.
    """
    with open(file_name) as fin:
        return fin.readlines()
def main():
    """Score predicted records against gold records for one or more model folders.

    For every folder given with ``-p`` and for each split (``eval``/``test``)
    the record file is compared with the gold file from ``-g`` using the
    entity, relation and event scorers. Results are pretty-printed, optionally
    written to ``<pred_folder>/<split>_results.txt`` (``-w``), and finally the
    per-metric mean over all prediction folders is printed.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-g', dest='gold_folder', help="Golden Dataset folder")
    parser.add_argument('-p', dest='pred_folder', nargs='+', help="Predicted model folder")
    parser.add_argument('-v', dest='verbose', action='store_true', help='Show more information during running')
    parser.add_argument('-w', dest='write_to_file', action='store_true', help="Write evaluation results to predicted folder")
    parser.add_argument('-m', dest='match_mode', default='normal', choices=['set', 'normal', 'multimatch'])
    parser.add_argument('-case', dest='case', action='store_true', help='Show case study')
    options = parser.parse_args()

    # split name -> (prediction record file, gold file)
    data_dict = {
        'eval': ['eval_preds_record.txt', 'val.json'],
        'test': ['test_preds_record.txt', 'test.json'],
    }

    task_dict = {
        'entity': EntityScorer,
        'relation': RelationScorer,
        'event': EventScorer,
    }

    result_list = {'eval': list(), 'test': list()}
    for pred_folder in options.pred_folder:
        gold_folder = options.gold_folder

        for data_key, (generation, gold_file) in data_dict.items():

            gold_filename = os.path.join(gold_folder, gold_file)
            pred_filename = os.path.join(pred_folder, generation)

            if not os.path.exists(pred_filename):
                sys.stderr.write("%s not found.\n" % pred_filename)
                continue

            print("pred:", pred_filename)
            print("gold:", gold_filename)

            # Parse both files once; they were previously re-read and
            # re-parsed for every task (and again in case-study mode).
            pred_instances = [json.loads(line) for line in read_file(pred_filename)]
            gold_instances = [json.loads(line) for line in read_file(gold_filename)]

            if not pred_instances:
                # Nothing to score; indexing the first prediction below would fail.
                sys.stderr.write("%s is empty.\n" % pred_filename)
                continue

            if options.case:
                for pred_instance, gold_instance in zip(pred_instances, gold_instances):
                    print('=========================')
                    print(gold_instance['text'])
                    for task, scorer in task_dict.items():
                        gold = scorer.load_gold_list([gold_instance[task]])[0]
                        pred = scorer.load_pred_list([pred_instance[task]])[0]
                        # Largest number of items on either side; zero means
                        # there is nothing to show for this task.
                        max_length = max(
                            len(gold['string']),
                            len(pred['string']),
                            len(gold.get('string_trigger', [])),
                            len(pred.get('string_trigger', [])),
                            len(gold.get('string_role', [])),
                            len(pred.get('string_role', [])),
                        )
                        if max_length == 0:
                            continue
                        if task == 'entity':
                            print("Entity Gold:", sorted(gold['string']))
                            print("Entity Pred:", sorted(pred['string']))
                        if task == 'relation':
                            print("Relation Gold:", sorted(gold['string']))
                            print("Relation Pred:", sorted(pred['string']))
                        if task == 'event':
                            print("Event Gold Trigger:", sorted(gold['string_trigger']))
                            print("Event Pred Trigger:", sorted(pred['string_trigger']))
                            print("Event Gold Role :", sorted(gold['string_role']))
                            print("Event Pred Role :", sorted(pred['string_role']))

            results = dict()
            for task, scorer in task_dict.items():
                # A task is scored only when the first prediction carries it.
                if task not in pred_instances[0]:
                    continue
                gold_list = [instance[task] for instance in gold_instances]
                pred_list = [instance[task] for instance in pred_instances]

                assert len(pred_list) == len(gold_list), \
                    "%s and %s differ in length" % (pred_filename, gold_filename)
                gold_instance_list = scorer.load_gold_list(gold_list)
                pred_instance_list = scorer.load_pred_list(pred_list)
                assert len(pred_instance_list) == len(gold_instance_list)
                sub_results = scorer.eval_instance_list(
                    gold_instance_list=gold_instance_list,
                    pred_instance_list=pred_instance_list,
                    verbose=options.verbose,
                    match_mode=options.match_mode,
                )
                results.update(sub_results)

            pprint(results)
            result_list[data_key] += [results]

            if options.write_to_file:
                output_filename = "%s/%s_results.txt" % (pred_folder, data_key)
                write_to_file(
                    result=results,
                    output_filename=output_filename,
                    prefix=data_key,
                )

    print("===========> AVG <===========")

    for data_key in data_dict:
        if len(result_list[data_key]) < 1:
            continue
        # Metric names are taken from the first evaluated folder.
        for key in result_list[data_key][0]:
            ave = np.mean([result[key] for result in result_list[data_key]])
            print(data_key, key, ave)


if __name__ == "__main__":
    main()
#!/usr/bin/env bash
# -*- coding:utf-8 -*-
# Shared option parsing for the run_seq2seq*.bash launchers.
# This file is meant to be sourced: it sets default hyper-parameters, parses
# the command line of the sourcing script, and derives run_command,
# model_folder, data_folder and the fp16 flags from the result.
export CUDA_VISIBLE_DEVICES="0"
export lr=1e-4
export run_time="1"
export seed="42"
export lr_scheduler=linear
export label_smoothing="0"
export epoch=30
export eval_steps=0
export warmup_ratio=0
export constraint_decoding=''
export verbose=false
export fp16=''
export negative=-1
export positive=1
export ordered_prompt=True
export max_source_length=256
export spot_noise=0
export asoc_noise=0
export map_config=config/offset_map/closest_offset_en.yaml

# The short-option string must list every short flag handled in the case
# statement below; "p:" was missing, so "-p <value>" was rejected by getopt.
OPTS=$(getopt -o b:d:m:i:t:k:s:l:f:n:p:v --long batch:,device:,model:,data:,task:,run-time:,seed:,lr:,lr_scheduler:,label_smoothing:,epoch:,format:,eval_steps:,warmup_ratio:,constraint_decoding,verbose,preprocess,fp16:,negative:,random_prompt,max_source_length:,max_target_length:,spot_noise:,asoc_noise:,positive:,map_config:,trainer_type:,output_dir:,use_prompt_tuning_model:, -n 'parse-options' -- "$@")

if [ $? != 0 ]; then
  echo "Failed parsing options." >&2
  exit 1
fi

eval set -- "$OPTS"

while true; do
  case "$1" in
  -b | --batch)
    batch_size="$2"
    shift
    shift
    ;;
  -d | --device)
    CUDA_VISIBLE_DEVICES="$2"
    shift
    shift
    ;;
  -m | --model)
    model_name="$2"
    shift
    shift
    ;;
  -i | --data)
    data_name="$2"
    shift
    shift
    ;;
  -t | --task)
    task_name="$2"
    shift
    shift
    ;;
  -k | --run-time)
    run_time="$2"
    shift
    shift
    ;;
  -s | --seed)
    seed="$2"
    shift
    shift
    ;;
  -l | --lr)
    lr="$2"
    shift
    shift
    ;;
  -f | --format)
    decoding_format="$2"
    shift
    shift
    ;;
  -n | --negative)
    negative="$2"
    shift
    shift
    ;;
  -p | --positive)
    positive="$2"
    shift
    shift
    ;;
  --lr_scheduler)
    lr_scheduler="$2"
    shift
    shift
    ;;
  --label_smoothing)
    label_smoothing="$2"
    shift
    shift
    ;;
  --epoch)
    epoch="$2"
    shift
    shift
    ;;
  --eval_steps)
    eval_steps="$2"
    shift
    shift
    ;;
  --warmup_ratio)
    warmup_ratio="$2"
    shift
    shift
    ;;
  --max_source_length)
    max_source_length="$2"
    shift
    shift
    ;;
  --max_target_length)
    max_target_length="$2"
    shift
    shift
    ;;
  --spot_noise)
    spot_noise="$2"
    shift
    shift
    ;;
  --asoc_noise)
    asoc_noise="$2"
    shift
    shift
    ;;
  --fp16)
    fp16="$2"
    shift
    shift
    ;;
  --map_config)
    map_config="$2"
    shift
    shift
    ;;
  --trainer_type)
    trainer_type="$2"
    shift
    shift
    ;;
  --output_dir)
    output_dir="$2"
    shift
    shift
    ;;
  --constraint_decoding)
    constraint_decoding="--constraint_decoding"
    shift
    ;;
  --preprocess)
    preprocess=True
    shift
    ;;
  --random_prompt)
    ordered_prompt=False
    shift
    ;;
  --use_prompt_tuning_model)
    use_prompt_tuning_model="$2"
    shift
    shift
    ;;
  -v | --verbose)
    verbose=true
    shift
    ;;
  --)
    shift
    break
    ;;
  *)
    # Unknown option: report on stderr and fail instead of exiting with 0.
    echo "$1" not recognize. >&2
    exit 1
    ;;
  esac
done


# Print the number of comma-separated entries in CUDA_VISIBLE_DEVICES.
# IFS is changed only locally, and the count is reported on stdout only: it is
# no longer returned as the exit status, which was non-zero for any GPU.
get_gpu_num() {
  local IFS=,
  local num=0
  local i
  for i in ${CUDA_VISIBLE_DEVICES}
  do
    num=$((num + 1))
  done
  echo ${num}
}

# Print a pseudo-random integer in [$1, $2], derived from $RANDOM.
function rand(){
  local min=$1
  local max=$(($2-$min+1))
  local num=$(($RANDOM+1000000000))
  echo $(($num%$max+$min))
}

gpu_num=$(get_gpu_num)
# For multiple GPUs, use the distributed launcher of PyTorch
if [[ ${gpu_num} == 1 ]]
then
  run_command=python3
else
  master_port=$(rand 10000 50000)
  echo "Master Port: ${master_port}"
  run_command="python3 -m torch.distributed.launch --nproc_per_node ${gpu_num} --master_port ${master_port}"
fi

echo "Map Config" ${map_config}

# Without specifying eval_steps, model validation is performed once for each epoch
if [[ ${eval_steps} == 0 ]]
then
  evaluation_strategy='epoch'
else
  evaluation_strategy='steps'
fi

# google/mt5-base -> google_mt5-base
model_name_log=$(echo ${model_name} | sed -s "s/\//_/g")
data_name_log=$(echo ${data_name} | sed -s "s/\//_/g")
batch_log=$((gpu_num * batch_size))

EXP_ID=$(date +%F-%H-%M-$RANDOM)

# The output folder name encodes the main hyper-parameters of the run;
# non-default settings are appended as suffixes below.
model_folder=output/${task_name}_${EXP_ID}_${model_name_log}_${decoding_format}_${data_name_log}_e${epoch}_${lr_scheduler}_lr${lr}_ls${label_smoothing}_b${batch_log}_wu${warmup_ratio}_n${negative}
if [[ ${constraint_decoding} != "" ]]
then
  model_folder=${model_folder}_CD
fi
if [[ ${ordered_prompt} == False ]]
then
  model_folder=${model_folder}_RP
fi
if [[ ${spot_noise} != 0 ]]
then
  model_folder=${model_folder}_sn${spot_noise}
fi
if [[ ${asoc_noise} != 0 ]]
then
  model_folder=${model_folder}_an${asoc_noise}
fi
if [[ ${positive} != 1 ]]
then
  model_folder=${model_folder}_p${positive}
fi

data_folder=data/text2${decoding_format}/${data_name}

export TOKENIZERS_PARALLELISM=false

# --fp16 takes the apex opt level; expand it into the full trainer flags.
if [[ ${fp16} != "" ]]
then
  fp16="--fp16 --fp16_backend apex --fp16_opt_level ${fp16}"
fi

export PYTHONPATH="${PYTHONPATH}:./"
#!/usr/bin/env python
# -*- coding:utf-8 -*-
import math
import sys


def count_eval_steps(file_name, batch_size, eval_epoch):
    """Return ``ceil(number of lines / batch_size) * eval_epoch`` as an int.

    ``file_name`` is read line by line inside a ``with`` block, so the handle
    is closed as soon as the count is known.
    """
    with open(file_name) as fin:
        line_num = sum(1 for _ in fin)
    return int(math.ceil(line_num / float(batch_size)) * eval_epoch)


def main():
    """Print the step count for ``<file> <batch_size> <eval_epoch>`` from argv."""
    file_name = sys.argv[1]
    batch_size = int(sys.argv[2])
    eval_epoch = int(sys.argv[3])
    print(count_eval_steps(file_name, batch_size, eval_epoch))


if __name__ == "__main__":
    main()

# python scripts/get_eval_batch_num.py ${run_data_folder}/train.json ${batch_size} 20
data/text2spotasoc/absa/16res --model ${model_path}/absa_16res_base_73.23 --batch_size 64 +CUDA_VISIBLE_DEVICES=${DEVICE} python inference.py --data data/text2spotasoc/entity/mrc_ace04 --model ${model_path}/ent_ace04ent_86.87 --batch_size 64 +CUDA_VISIBLE_DEVICES=${DEVICE} python inference.py --data data/text2spotasoc/entity/mrc_ace05 --model ${model_path}/ent_ace05ent_85.89 --batch_size 64 +CUDA_VISIBLE_DEVICES=${DEVICE} python inference.py --data data/text2spotasoc/relation/ace05-rel --model ${model_path}/rel_ace05-rel_66.22 --batch_size 64 +CUDA_VISIBLE_DEVICES=${DEVICE} python inference.py --data data/text2spotasoc/relation/conll04 --model ${model_path}/rel_conll04_large_74.97 --batch_size 64 +CUDA_VISIBLE_DEVICES=${DEVICE} python inference.py --data data/text2spotasoc/relation/NYT --model ${model_path}/rel_nyt_93.53 --batch_size 64 --match_mode set +CUDA_VISIBLE_DEVICES=${DEVICE} python inference.py --data data/text2spotasoc/relation/scierc --model ${model_path}/rel_scierc_large_37.05 --batch_size 64 +CUDA_VISIBLE_DEVICES=${DEVICE} python inference.py --data data/text2spotasoc/event/oneie_ace05_en_event --model ${model_path}/evt_ace05evt_74.06_55.97 --batch_size 64 +CUDA_VISIBLE_DEVICES=${DEVICE} python inference.py --data data/text2spotasoc/event/casie --model ${model_path}/evt_casie_69.97_61.24 --batch_size 64 +CUDA_VISIBLE_DEVICES=${DEVICE} python inference.py --data data/text2spotasoc/entity/conll03 --model ${model_path}/ent_conll03_92.97 --batch_size 64 +CUDA_VISIBLE_DEVICES=${DEVICE} python inference.py --data data/text2spotasoc/relation/NYT --model ${model_path}/rel_nyt_base_92.46 --batch_size 64 --match_mode set diff --git a/metaretriever/scripts/sel2record.py b/metaretriever/scripts/sel2record.py new file mode 100644 index 00000000..10eb4af2 --- /dev/null +++ b/metaretriever/scripts/sel2record.py @@ -0,0 +1,75 @@ +#!/usr/bin/env python +# -*- coding:utf-8 -*- +import argparse +import os +import logging +import json + +from uie.sel2record.record import 
MapConfig +from uie.sel2record.sel2record import SEL2Record + +logger = logging.getLogger(__name__) + + +def main(): + parser = argparse.ArgumentParser() + + parser.add_argument('-g', dest='gold_folder', help='Gold Folder') + parser.add_argument('-p', dest='pred_folder', nargs='+', help='Pred Folder') + + parser.add_argument('-c', '--config', dest='map_config', help='Offset Mapping Config') + parser.add_argument('-d', dest='decoding', default='spotasoc') + parser.add_argument('-v', '--verbose', dest='verbose', + action='store_true', help='More details information.') + options = parser.parse_args() + + map_config = MapConfig.load_from_yaml(options.map_config) + schema_dict = SEL2Record.load_schema_dict(options.gold_folder) + sel2record = SEL2Record( + schema_dict=schema_dict, + decoding_schema=options.decoding, + map_config=map_config, + ) + + data_dict = { + 'eval': ['eval_preds_seq2seq.txt', 'val.json', 'eval_preds_record.txt'], + 'test': ['test_preds_seq2seq.txt', 'test.json', 'test_preds_record.txt'], + } + + for pred_folder in options.pred_folder: + gold_folder = options.gold_folder + + for data_key, (generation, gold_file, record_file) in data_dict.items(): + + pred_filename = os.path.join(pred_folder, generation) + + if not os.path.exists(pred_filename): + logger.warning("%s not found.\n" % pred_filename) + continue + + gold_filename = os.path.join(gold_folder, gold_file) + + print("pred:", pred_filename) if options.verbose else None + print("gold:", gold_filename) if options.verbose else None + + # Only using text and tokens in Gold file + gold_list = [json.loads(line) for line in open(gold_filename)] + gold_text_list = [gold['text'] for gold in gold_list] + gold_token_list = [gold['tokens'] for gold in gold_list] + + pred_list = [line.strip() for line in open(pred_filename).readlines()] + + assert len(gold_text_list) == len(pred_list) + + pred_records = list() + for pred, text, tokens in zip(pred_list, gold_text_list, gold_token_list): + pred_record = 
def find_key(count):
    """Map a token count to the label of the length bucket it falls in.

    A count of exactly zero has its own bucket. Otherwise the label belongs
    to the highest lower bound that ``count`` strictly exceeds, and anything
    that exceeds none of the bounds falls into the lowest bucket.
    """
    if count == 0:
        return "8. =0"

    # (exclusive lower bound, label), highest bound first.
    buckets = (
        (512, '7.>512'),
        (384, "6.384-512"),
        (320, "5.320-384"),
        (256, "4.256-320"),
        (192, "3.192-256"),
        (128, "2.128-192"),
        (64, "1. 64-128"),
    )
    for lower_bound, label in buckets:
        if count > lower_bound:
            return label
    return "0. <64"
instance[options.key] + counter['Text'].update([len(tokenizer.tokenize(text))]) + counter['Record'].update([len(tokenizer.tokenize(record))]) + counter['Text + Schema'].update([len(tokenizer.tokenize(text)) + len_schema_prefix]) + counter['Record + Schema Prompt'].update([len(tokenizer.tokenize(record)) + len_schema_prefix]) + if len(tokenizer.tokenize(record)) > 512: + print("[Length > 512 Text ]:", text) + print("[Length > 512 Record]:", record) + + for k, v in counter.items(): + print(file_type, k) + table = get_acc_list(v) + print(tabulate(table)) + print(f"Min: {min(v.keys())}") + print(f"Max: {max(v.keys())}") + + +if __name__ == "__main__": + main() diff --git a/metaretriever/scripts/summary_performance.bash b/metaretriever/scripts/summary_performance.bash new file mode 100644 index 00000000..521f79e9 --- /dev/null +++ b/metaretriever/scripts/summary_performance.bash @@ -0,0 +1,40 @@ +#!/usr/bin/env bash +# -*- coding:utf-8 -*- + +for record_type in entity relation relation-boundary event record +do + + echo -e "\n==============>" String ${record_type} "<==============" + python3 scripts/summary_result.py -record ${record_type} -string -model output/* | grep checkpoint- + +done + +for record_type in entity relation relation-boundary event record +do + + echo -e "\n==============>" Offset ${record_type} "<==============" + python3 scripts/summary_result.py -record ${record_type} -model output/* | grep checkpoint- +done + + +for record_type in entity relation relation-boundary event record +do + + echo -e "\n==============>" Mean String ${record_type} "<==============" + python3 scripts/summary_result.py -mean -reduce run -record ${record_type} -string + + echo -e "\n==============>" String ${record_type} "<==============" + python3 scripts/summary_result.py -record ${record_type} -string + +done + +for record_type in entity relation relation-boundary event record +do + + echo -e "\n==============>" Mean Offset ${record_type} "<==============" + python3 
def align_float(x):
    """Format floats with two decimals; return any other value unchanged."""
    if isinstance(x, float):
        return '%.2f' % x
    return x
def parse_trainer_state(filename):
    """Return the step suffix of the best checkpoint recorded in a trainer state file.

    ``filename`` is a JSON file with a ``best_model_checkpoint`` entry. When
    that entry is a path, its last component is returned with the
    ``checkpoint-`` prefix removed; when it is ``null`` the string ``'last'``
    is returned. The file is read inside a ``with`` block so the handle is
    closed instead of leaked.
    """
    with open(filename) as fin:
        trainer_state = json.load(fin)

    best_checkpoint = trainer_state['best_model_checkpoint']
    if best_checkpoint is None:
        return 'last'
    return best_checkpoint.split('/')[-1].replace('checkpoint-', '')
self.result_valid_keys: + if key not in result: + result[key] = default_key + + return result + + def get_valid_folder(self, model_folders, file_map, span_pretrain=False): + all_result = list() + + for model_folder in model_folders: + print(model_folder) + sub_folder_list = sorted(os.listdir(model_folder)) + for sub_folder_name in sub_folder_list: + if sub_folder_name.endswith('log') or sub_folder_name.endswith('err'): + continue + + sub_folder = os.path.join(model_folder, sub_folder_name) + log_filename = sub_folder + '.log' + + if span_pretrain: + if os.path.exists(os.path.join(sub_folder, 'span_pretrain')): + default_key = 'running' + trained_folder = os.path.join(sub_folder, 'span_pretrain') + log_filename = os.path.join( + sub_folder, 'span_pretrain.log') + state_filename = os.path.join( + sub_folder, 'span_pretrain', 'trainer_state.json') + else: + print('Unused folder: %s' % sub_folder) + continue + else: + + if os.path.exists(os.path.join(sub_folder, 'event_finetune')): + default_key = 'finetune' + trained_folder = os.path.join(sub_folder, 'event_finetune') + log_filename = os.path.join( + sub_folder, 'event_finetune.log') + state_filename = os.path.join( + sub_folder, 'event_finetune', 'trainer_state.json') + + elif os.path.exists(os.path.join(sub_folder, 'span_pretrain')): + default_key = 'pretrain' + trained_folder = os.path.join(sub_folder, 'span_pretrain') + log_filename = os.path.join( + sub_folder, 'span_pretrain.log') + state_filename = os.path.join( + sub_folder, 'span_pretrain', 'trainer_state.json') + + else: + default_key = 'running' + state_filename = os.path.join( + sub_folder, 'trainer_state.json') + trained_folder = sub_folder + + if os.path.exists(log_filename): + out_of_memory = check_out_of_memory(log_filename) + else: + out_of_memory = False + + if out_of_memory: + result = {key: 'OOM' for key in self.result_valid_keys} + checkpoint = 'OOM' + else: + result = self.parse_best_log( + trained_folder, file_map, default_key) + checkpoint = 
parse_trainer_state(state_filename) if os.path.exists( + state_filename) else default_key + global_step = parse_global_step(state_filename) if os.path.exists( + state_filename) else default_key + checkpoint = checkpoint + '/' + global_step + + all_result += [[sub_folder, checkpoint, result]] + return all_result + + def result_to_table(self, all_result, sort_key=0): + table = list() + for sub_folder_name, checkpoint, result in all_result: + table += [[sub_folder_name, checkpoint] + + [result.get(key, 'running') for key in self.result_valid_keys]] + + table = [[align_float(x) for x in y] for y in table] + + table.sort() + table.sort(key=lambda x: x[sort_key]) + + print(tabulate(table, headers=[ + 'exp', 'checkpoint'] + self.header_result_valid_keys)) + + def result_to_table_reduce(self, all_result, sort_key=0, reduce_function=np.mean, reduce_key='run'): + table = list() + sub_run = OrderedDict() + for sub_folder_name, checkpoint, result in all_result: + + sub_run_name = get_run_name(sub_folder_name, reduce_key) + if sub_run_name not in sub_run: + sub_run[sub_run_name] = list() + + sub_run_result = [result.get(key, 'running') + for key in self.result_valid_keys] + if 'running' in sub_run_result or 'OOM' in sub_run_result: + continue + + sub_run[sub_run_name] += [sub_run_result] + + for sub_run_name, sub_run_results in sub_run.items(): + if len(sub_run_results) == 0: + table += [[sub_run_name, 0] + ['-']] + else: + table += [[sub_run_name, len(sub_run_results)] + + list(reduce_function(sub_run_results, 0))] + + table = [[align_float(x) for x in y] for y in table] + + table.sort() + table.sort(key=lambda x: x[sort_key]) + + print(tabulate(table, headers=['exp', 'num'] + + self.header_result_valid_keys)) + + +def main(): + record_valid_keys_map = { + 'entity': span_record_result_valid_keys, + 'relation': relation_strict_record_result_valid_keys, + 'relation-boundary': relation_boundary_record_result_valid_keys, + 'event': event_record_result_valid_keys, + 'record': 
record_result_valid_keys, + } + + import argparse + parser = argparse.ArgumentParser( + description='Summary Multi-run Result' + ) + parser.add_argument('-model', dest='model', default=['output'], nargs='+', + help='Output Model Folder Path') + parser.add_argument('-sort', dest='sort', default=0, + type=int, help='Sort Column Index') + parser.add_argument('-mean', dest='mean', action='store_true', + help='Reduce by mean Function') + parser.add_argument('-std', dest='std', action='store_true', + help='Reduce by std Function') + parser.add_argument('-span-pretrain', dest='span_pretrain', + action='store_true', + help='Load Span Pretrain Result for Text2Event') + parser.add_argument('-record', dest='record', default='record', + choices=record_valid_keys_map.keys(), + help='Record Type') + parser.add_argument('-string', dest='offset', action='store_false', + help='Report String Match Result') + parser.add_argument('-offset', dest='offset', action='store_true', + help='Report Offset Match Result (default)') + parser.set_defaults(offset=True) + parser.add_argument('-reduce', dest='reduce', default='run', + help='Reduce Key, default is `run`') + options = parser.parse_args() + + if options.record in record_valid_keys_map: + file_map = { + 'eval': 'eval_results.txt', + 'test': 'test_results.txt', + } + else: + raise NotImplementedError('Invalid Record Type: %s' % options.record) + + result_valid_keys = record_valid_keys_map[options.record] + + if not options.offset: + result_valid_keys = [key.replace('offset', 'string') + for key in result_valid_keys] + + result_summary = ResultSummary( + result_valid_keys=result_valid_keys + ) + print(options.model) + + def check_valid_model(x): + return not (os.path.isfile(x) or x.endswith('_log')) + + valid_model_paths = filter(check_valid_model, options.model) + + all_result = result_summary.get_valid_folder( + model_folders=valid_model_paths, + file_map=file_map, + span_pretrain=options.span_pretrain + ) + + if options.mean: + 
# Print a comma-separated list of the first N GPU indices: 0,1,...,N-1.
#
# Arguments:
#   $1 - number of GPUs to select (a missing argument is treated as 0)
# Outputs:
#   the index list on stdout (an empty line when N is 0)
# Globals:
#   gpu_node and selected_gpus are assigned, as before, so callers that read
#   them after a direct (non-subshell) call keep working.
#
# The indices are generated from the loop counter instead of a fixed
# 16-entry lookup table, so counts above 16 no longer yield empty entries
# and dangling commas.
function get_gpu_id() {
  gpu_node=${1:-0}
  selected_gpus=""
  local i
  for (( i = 0; i < gpu_node; i++ )); do
    if [[ -z ${selected_gpus} ]]; then
      selected_gpus=${i}
    else
      selected_gpus="${selected_gpus},${i}"
    fi
  done
  echo "${selected_gpus}"
}
smoothing " ${label_smoothing} + for negative in "${NEGATIVE[@]}"; do + echo "negative " ${negative} + + bash run_seq2seq_record.bash -k ${run_time} \ + -m uie_models/${model_name} \ + -d ${selected_gpus} \ + -i ${dataset_name} \ + --trainer_type ${trainer_type} \ + --use_prompt_tuning_model ${use_prompt_tuning_model} \ + --lr_scheduler linear \ + --epoch ${epoch} \ + --eval_steps ${eval_steps} \ + --batch ${batch_size} \ + --label_smoothing ${label_smoothing} \ + --lr ${learning_rate} \ + --warmup_ratio ${warmup_ratio} \ + --max_source_length ${max_source_length} \ + --spot_noise ${noise} --asoc_noise ${noise} \ + --negative ${negative} --random_prompt --map_config ${map_config} + + bash scripts/summary_performance.bash > output/best.performance.now + + done + done + done + done + done +done + +bash scripts/summary_performance.bash + +exit 0 diff --git a/metaretriever/scripts_exp/run_exp_ratio.bash b/metaretriever/scripts_exp/run_exp_ratio.bash new file mode 100644 index 00000000..159abc73 --- /dev/null +++ b/metaretriever/scripts_exp/run_exp_ratio.bash @@ -0,0 +1,59 @@ +#!/bin/bash + +source scripts_exp/meta_run.bash +# selected_gpus=${GPU:-"`get_gpu_id $gpu_node`"} +export CUDA_VISIBLE_DEVICES=$selected_gpus + +# Load Hyper-parameters +IFS=' ' +read -ra BATCH_SIZE <<<"${BATCH_SIZE}" +read -ra LR_RATE <<<"${LR_RATE}" +read -ra WARMUP_PROP <<<"${WARMUP_PROP}" +read -ra LABEL_SMOOTHING <<<"${LABEL_SMOOTHING}" +read -ra NEGATIVE <<<"${NEGATIVE}" +read -ra NOISE <<<"${NOISE}" + +for batch_size in "${BATCH_SIZE[@]}"; do + echo "batch " ${batch_size} + + for noise in "${NOISE[@]}"; do + echo "noise " ${noise} + for learning_rate in "${LR_RATE[@]}"; do + echo "learning rate " ${learning_rate} + for warmup_ratio in "${WARMUP_PROP[@]}"; do + echo "warmup ratio " ${warmup_ratio} + for label_smoothing in "${LABEL_SMOOTHING[@]}"; do + echo "label smoothing " ${label_smoothing} + for negative in "${NEGATIVE[@]}"; do + echo "negative " ${negative} + + bash 
run_seq2seq_record_ratio.bash -k ${run_time} \ + -m uie_models/${model_name} \ + -d ${selected_gpus} \ + -i ${dataset_name} \ + -f ${decoding_format} \ + --trainer_type ${trainer_type} \ + --use_prompt_tuning_model ${use_prompt_tuning_model} \ + --lr_scheduler constant \ + --epoch ${epoch} \ + --eval_steps ${eval_steps} \ + --batch ${batch_size} \ + --label_smoothing ${label_smoothing} \ + --lr ${learning_rate} \ + --warmup_ratio ${warmup_ratio} \ + --max_source_length ${max_source_length} \ + --spot_noise ${noise} --asoc_noise ${noise} \ + --negative ${negative} --random_prompt --map_config ${map_config} + + bash scripts/summary_performance.bash > output/best.performance.now + + done + done + done + done + done +done + +bash scripts/summary_performance.bash + +exit 0 diff --git a/metaretriever/scripts_exp/run_exp_shot.bash b/metaretriever/scripts_exp/run_exp_shot.bash new file mode 100644 index 00000000..3035f889 --- /dev/null +++ b/metaretriever/scripts_exp/run_exp_shot.bash @@ -0,0 +1,59 @@ +#!/bin/bash + +source scripts_exp/meta_run.bash +# selected_gpus=${GPU:-"`get_gpu_id $gpu_node`"} +export CUDA_VISIBLE_DEVICES=$selected_gpus + +# Load Hyper-parameters +IFS=' ' +read -ra BATCH_SIZE <<<"${BATCH_SIZE}" +read -ra LR_RATE <<<"${LR_RATE}" +read -ra WARMUP_PROP <<<"${WARMUP_PROP}" +read -ra LABEL_SMOOTHING <<<"${LABEL_SMOOTHING}" +read -ra NEGATIVE <<<"${NEGATIVE}" +read -ra NOISE <<<"${NOISE}" + +for batch_size in "${BATCH_SIZE[@]}"; do + echo "batch " ${batch_size} + + for noise in "${NOISE[@]}"; do + echo "noise " ${noise} + for learning_rate in "${LR_RATE[@]}"; do + echo "learning rate " ${learning_rate} + for warmup_ratio in "${WARMUP_PROP[@]}"; do + echo "warmup ratio " ${warmup_ratio} + for label_smoothing in "${LABEL_SMOOTHING[@]}"; do + echo "label smoothing " ${label_smoothing} + for negative in "${NEGATIVE[@]}"; do + echo "negative " ${negative} + + bash run_seq2seq_record_shot.bash -k ${run_time} \ + -m uie_models/${model_name} \ + -d ${selected_gpus} 
\ + -i ${dataset_name} \ + -f ${decoding_format} \ + --trainer_type ${trainer_type} \ + --use_prompt_tuning_model ${use_prompt_tuning_model} \ + --lr_scheduler constant \ + --epoch ${epoch} \ + --eval_steps ${eval_steps} \ + --batch ${batch_size} \ + --label_smoothing ${label_smoothing} \ + --lr ${learning_rate} \ + --warmup_ratio ${warmup_ratio} \ + --max_source_length ${max_source_length} \ + --spot_noise ${noise} --asoc_noise ${noise} \ + --negative ${negative} --random_prompt --map_config ${map_config} + + bash scripts/summary_performance.bash > output/best.performance.now + + done + done + done + done + done +done + +bash scripts/summary_performance.bash + +exit 0 diff --git a/metaretriever/uie/__init__.py b/metaretriever/uie/__init__.py new file mode 100644 index 00000000..5bfd17ea --- /dev/null +++ b/metaretriever/uie/__init__.py @@ -0,0 +1,2 @@ +#!/usr/bin/env python +# -*- coding:utf-8 -*- diff --git a/metaretriever/uie/extraction/__init__.py b/metaretriever/uie/extraction/__init__.py new file mode 100644 index 00000000..5bfd17ea --- /dev/null +++ b/metaretriever/uie/extraction/__init__.py @@ -0,0 +1,2 @@ +#!/usr/bin/env python +# -*- coding:utf-8 -*- diff --git a/metaretriever/uie/extraction/constants.py b/metaretriever/uie/extraction/constants.py new file mode 100644 index 00000000..a4fd08ac --- /dev/null +++ b/metaretriever/uie/extraction/constants.py @@ -0,0 +1,34 @@ +#!/usr/bin/env python +# -*- coding:utf-8 -*- + +spot_prompt = '' +asoc_prompt = '' + +type_start = '' +type_end = '' +text_start = '' +span_start = '' +null_span = '' +null_label = '' + + +class StructureMarker: + def __init__(self) -> None: + pass + + +class BaseStructureMarker(StructureMarker): + def __init__(self) -> None: + super().__init__() + self.sent_start = '' + self.sent_end = '' + self.record_start = '' + self.record_end = '' + self.span_start = '' + self.span_end = '' + self.text_start = '' + self.source_span_start = '' + self.source_span_end = '' + self.target_span_start = 
class TaskConfig:
    """Settings of one task/dataset entry read from a task configuration dict.

    Missing keys fall back to the defaults visible in ``__init__``; ``weight``
    is converted with ``int``. The record schema is loaded eagerly from
    ``<path>/<task>.schema`` via ``RecordSchema.read_from_file``.
    """

    def __init__(self, task_dict) -> None:
        self.dataset_name = task_dict.get('name', '')
        self.task_name = task_dict.get('task', '')
        self.data_path = task_dict.get('path', '')
        self.decoding_format = task_dict.get('decoding_format', '')
        self.weight = int(task_dict.get('weight', 0))
        self.sel2record = task_dict.get('sel2record', '')
        self.metrics = task_dict.get('metrics', [])
        self.eval_match_mode = task_dict.get('eval_match_mode', 'normal')
        self.schema = RecordSchema.read_from_file(f"{self.data_path}/{self.task_name}.schema")

    def __repr__(self) -> str:
        return f"dataset: {self.dataset_name}\n" \
               f"task : {self.task_name}\n" \
               f"format : {self.decoding_format}\n" \
               f"path : {self.data_path}\n" \
               f"schema : {self.schema}\n" \
               f"metrics: {self.metrics}\n" \
               f"eval_match_mode : {self.eval_match_mode}"

    @staticmethod
    def load_list_from_yaml(task_config):
        """Yield a ``TaskConfig`` for every top-level YAML key starting with 'T'.

        Args:
            task_config: path of the YAML file to read.

        Yields:
            TaskConfig: one per matching key, in the order the keys appear in
            the loaded mapping.
        """
        import yaml
        # Close the file as soon as it is parsed instead of leaking the
        # handle until garbage collection.
        # NOTE(review): FullLoader is kept for compatibility; prefer
        # yaml.safe_load if this file can come from an untrusted source.
        with open(task_config) as config_file:
            configs = yaml.load(config_file, Loader=yaml.FullLoader)
        for task_key in configs:
            if task_key.startswith('T'):
                yield TaskConfig(configs[task_key])
: {self.type_list}\n" \ + f"Position: {self.position}\n" + + @staticmethod + def load_from_yaml(dataset_config): + import yaml + configs = yaml.load(open(dataset_config), Loader=yaml.FullLoader) + return PrefixGenerator(configs['Prefix']) + + @staticmethod + def get_schema_prefix(schema: RecordSchema, add_split=True): + prefix_list = list() + for spot_label in sorted(schema.type_list): + prefix_list += [spot_prompt, spot_label] + for asoc_label in sorted(schema.role_list): + prefix_list += [asoc_prompt, asoc_label] + prefix = ' '.join(prefix_list) + if add_split: + return prefix + f' {text_start} ' + else: + return prefix + + @staticmethod + def get_dataset_name_prefix(dataset: TaskConfig, add_split=True): + if add_split: + return dataset.dataset_name + f' {text_start}' + else: + return dataset.dataset_name + + @staticmethod + def get_task_name_prefix(dataset: TaskConfig, add_split=True): + if add_split: + return dataset.task_name + f' {text_start}' + else: + return dataset.task_name + + def get_prefix_by_dataset(self, dataset: TaskConfig): + prefix_list = list() + for prefix_type in self.type_list: + if prefix_type == 'task': + prefix = self.get_task_name_prefix(dataset, add_split=False) + elif prefix_type == 'dataset': + prefix = self.get_dataset_name_prefix(dataset, add_split=False) + elif prefix_type == 'schema': + prefix = self.get_schema_prefix(dataset.schema, add_split=False) + elif prefix_type == 'meta': + # Meta 使用 Schema 的 Prefix + prefix = self.get_schema_prefix(dataset.schema, add_split=False) + else: + raise NotImplementedError( + "Prefix Type %s is not supported" % prefix_type + ) + prefix_list += [prefix] + return ' '.join(prefix_list) + f' {text_start}' diff --git a/metaretriever/uie/extraction/extraction_metrics.py b/metaretriever/uie/extraction/extraction_metrics.py new file mode 100644 index 00000000..f4815136 --- /dev/null +++ b/metaretriever/uie/extraction/extraction_metrics.py @@ -0,0 +1,47 @@ +#!/usr/bin/env python +# -*- coding:utf-8 -*- 
+from typing import List +from uie.extraction.record_schema import RecordSchema +from uie.extraction.predict_parser import get_predict_parser, PredictParser +from uie.extraction.scorer import Metric, RecordMetric, OrderedRecordMetric + + +def eval_pred(predict_parser: PredictParser, gold_list, pred_list, text_list=None, raw_list=None): + well_formed_list, counter = predict_parser.decode( + gold_list, pred_list, text_list, raw_list + ) + + spot_metric = Metric() + asoc_metric = Metric() + record_metric = RecordMetric() + ordered_record_metric = OrderedRecordMetric() + + for instance in well_formed_list: + spot_metric.count_instance(instance['gold_event'], instance['pred_event']) + asoc_metric.count_instance(instance['gold_role'], instance['pred_role']) + record_metric.count_instance(instance['gold_record'], instance['pred_record']) + ordered_record_metric.count_instance(instance['gold_record'], instance['pred_record']) + + spot_result = spot_metric.compute_f1(prefix='spot-') + asoc_result = asoc_metric.compute_f1(prefix='asoc-') + record_result = record_metric.compute_f1(prefix='record-') + ordered_record_result = ordered_record_metric.compute_f1(prefix='ordered-record-') + + overall_f1 = spot_result.get('spot-F1', 0.) + asoc_result.get('asoc-F1', 0.) 
def get_label_name_tree(label_name_list, tokenizer, end_symbol=''):
    """Build a prefix tree (nested dicts) over the tokenized label names.

    Each label is encoded with
    ``tokenizer.encode(label, add_special_tokens=False)``. Every token id
    becomes one level of nesting, and the node reached after the last token
    of a label gets an ``end_symbol`` key mapped to ``None``.

    Args:
        label_name_list: iterable of label strings.
        tokenizer: object exposing ``encode(text, add_special_tokens=False)``
            that returns a sequence of hashable token ids.
        end_symbol: key marking the end of a complete label.

    Returns:
        dict: the root of the prefix tree; empty when no labels are given.
    """
    sub_token_tree = dict()
    for typename in label_name_list:
        node = sub_token_tree
        for token_id in tokenizer.encode(typename, add_special_tokens=False):
            # Descend, creating the child on first visit.
            node = node.setdefault(token_id, dict())
        node[end_symbol] = None
    return sub_token_tree
@dataclass
class SpotAsocNoiser:
    """Insert placeholder ("null span") spots and associations as noise.

    All methods mutate the ``spot_asoc`` list they receive and return that
    same list. Randomness comes from the global ``np.random`` state.
    """
    # Success probability of the per-spot Bernoulli draws that decide how
    # many noise spots are inserted.
    spot_noise_ratio: float = 0.1
    # Success probability of the per-spot Bernoulli draws that decide how
    # many noise associations are inserted.
    asoc_noise_ratio: float = 0.1
    # Span text given to every inserted noise item.
    null_span: str = constants.null_span

    def random_insert_spot(self, spot_asoc, spot_label_list=None):
        """Randomly insert null-span spots with labels drawn from ``spot_label_list``.

        The number of insertions is the number of successes among
        ``len(spot_asoc)`` Bernoulli draws with probability
        ``spot_noise_ratio``, so an empty ``spot_asoc`` receives no noise.

        Args:
            spot_asoc (list): spot dicts with "span", "label" and "asoc" keys;
                modified in place.
            spot_label_list (optional): candidate labels for inserted spots.
                When ``None`` or empty, nothing is inserted.

        Returns:
            list: the same ``spot_asoc`` object.
        """
        if spot_label_list is None or len(spot_label_list) == 0:
            return spot_asoc
        random_num = sum(np.random.binomial(1, self.spot_noise_ratio, len(spot_asoc)))
        for _ in range(random_num):
            # `high` is exclusive: a noise spot is never placed after the
            # current last element.
            random_position = np.random.randint(low=0, high=len(spot_asoc))
            random_label = np.random.choice(spot_label_list)
            spot_asoc.insert(
                random_position,
                {"span": self.null_span, "label": random_label, 'asoc': list()}
            )
        return spot_asoc

    def random_insert_asoc(self, spot_asoc, asoc_label_list=None):
        """Randomly insert null-span associations with labels from ``asoc_label_list``.

        The number of insertions is the number of successes among
        ``len(spot_asoc)`` Bernoulli draws with probability
        ``asoc_noise_ratio``. Each one is added to the "asoc" list of a
        randomly chosen spot as a ``(label, null_span)`` tuple.

        Args:
            spot_asoc (list): spot dicts whose "asoc" lists are modified in
                place.
            asoc_label_list (optional): candidate association labels. When
                ``None`` or empty, nothing is inserted.

        Returns:
            list: the same ``spot_asoc`` object.
        """
        if asoc_label_list is None or len(asoc_label_list) == 0:
            return spot_asoc
        # The draw count is based on the number of spots, not on the number
        # of existing associations (alternative kept for reference):
        # asoc_sum = sum([len(x['asoc']) for x in spot_asoc])
        spot_sum = len(spot_asoc)
        random_num = sum(np.random.binomial(1, self.asoc_noise_ratio, spot_sum))
        for _ in range(random_num):
            random_label = np.random.choice(asoc_label_list)
            spot_position = np.random.randint(low=0, high=len(spot_asoc))
            # `+ 1` allows insertion after the last existing association.
            asoc_position = np.random.randint(low=0, high=len(spot_asoc[spot_position]['asoc']) + 1)
            spot_asoc[spot_position]['asoc'].insert(
                asoc_position,
                (random_label, self.null_span)
            )
        return spot_asoc

    def add_noise(self, spot_asoc, spot_label_list, asoc_label_list):
        """Apply association noise first, then spot noise, and return ``spot_asoc``."""
        spot_asoc = self.random_insert_asoc(
            spot_asoc=spot_asoc,
            asoc_label_list=asoc_label_list,
        )
        spot_asoc = self.random_insert_spot(
            spot_asoc=spot_asoc,
            spot_label_list=spot_label_list,
        )
        return spot_asoc
class PredictParser:
    """Base class for parsers that turn generated sequences into records.

    Args:
        label_constraint: optional object exposing ``type_list`` and
            ``role_list``. When it is missing or falsy, both label sets are
            empty lists.
    """

    def __init__(self, label_constraint=None):
        self.predicate_set = label_constraint.type_list if label_constraint else list()
        self.role_set = label_constraint.role_list if label_constraint else list()

    def decode(self, gold_list, pred_list, text_list=None, raw_list=None) -> Tuple[List, Counter]:
        """Parse gold and predicted sequences; must be overridden by subclasses.

        :param gold_list: gold target sequences
        :param pred_list: predicted sequences
        :param text_list: optional source texts
        :param raw_list: optional raw data items
        :return: a pair ``(instances, counter)``, where every instance is a
            dict with the keys
                pred_event / gold_event   -> [(type, trigger), ...]
                pred_role / gold_role     -> [(type, role, argument), ...]
                pred_record / gold_record -> [{'type': ..., 'trigger': ...,
                                               'roles': [...]}, ...]
            and ``counter`` is a Counter of parsing statistics.
        :raises NotImplementedError: always, in this base class. Returning
            ``None`` here would only surface later as an opaque unpacking
            error in callers that expect the documented pair.
        """
        raise NotImplementedError(
            "%s must implement decode()" % type(self).__name__
        )
class SpotAsocPredictParser(PredictParser):
    """Parser for bracketed spot-asoc sequences, built on ``ParentedTree``."""

    def decode(self, gold_list, pred_list, text_list=None, raw_list=None
               ) -> Tuple[List[Dict], Counter]:
        """Parse gold/predicted sequences into spots, associations and records.

        :param gold_list: gold sequences; when ``None`` or empty, an empty
            structure (``type_start + type_end``) is used for every prediction
        :param pred_list: predicted sequences
        :param text_list: optional source texts, forwarded to
            ``get_record_list`` for span checking
        :param raw_list: optional raw data items, stored unchanged
        :return: ``(instances, counter)``.
            Each instance is a dict with the keys
                gold, pred, text, raw_data, gold_tree, pred_tree
                pred_event / gold_event   -> [(type, trigger), ...]
                pred_role / gold_role     -> [(type, role, argument), ...]
                pred_record / gold_record -> [{'type': type,
                                               'trigger': trigger,
                                               'roles': [(role, argument), ...]},
                                              ...]
            ``counter`` tallies 'gold_tree' and 'pred_tree' (one per top-level
            tree child), 'gold_tree add_bracket', 'fixed', 'well-formed' and
            'ill-formed'.

        The four lists are consumed with ``zip``, so processing stops at the
        shortest one.
        """
        counter = Counter()
        well_formed_list = []

        if gold_list is None or len(gold_list) == 0:
            gold_list = ["%s%s" % (type_start, type_end)] * len(pred_list)

        if text_list is None:
            text_list = [None] * len(pred_list)

        if raw_list is None:
            raw_list = [None] * len(pred_list)

        for gold, pred, text, raw_data in zip(gold_list, pred_list, text_list, raw_list):
            gold = convert_bracket(gold)
            pred = convert_bracket(pred)

            pred = clean_text(pred)

            try:
                gold_tree = ParentedTree.fromstring(gold, brackets=brackets)
            except ValueError:
                # Unbalanced gold sequence: close the open brackets and retry.
                logger.warning(f"Ill gold: {gold}")
                logger.warning(f"Fix gold: {add_bracket(gold)}")
                gold_tree = ParentedTree.fromstring(
                    add_bracket(gold), brackets=brackets)
                counter.update(['gold_tree add_bracket'])

            # 'pred' stores the cleaned prediction as generated, i.e. before
            # any bracket completion applied below.
            instance = {'gold': gold,
                        'pred': pred,
                        'gold_tree': gold_tree,
                        'text': text,
                        'raw_data': raw_data
                        }

            counter.update(['gold_tree' for _ in gold_tree])

            instance['gold_event'], instance['gold_role'], instance['gold_record'] = self.get_record_list(
                tree=instance["gold_tree"],
                text=instance['text']
            )

            try:
                if not check_well_form(pred):
                    pred = add_bracket(pred)
                    counter.update(['fixed'])

                pred_tree = ParentedTree.fromstring(pred, brackets=brackets)
                counter.update(['pred_tree' for _ in pred_tree])

                instance['pred_tree'] = pred_tree
                counter.update(['well-formed'])

            except ValueError:
                counter.update(['ill-formed'])
                # The message needs a placeholder for `pred`; without one the
                # logging module reports a formatting error instead of the
                # offending sequence.
                logger.debug('ill-formed: %s', pred)
                # Fall back to an empty tree so the instance is still scored.
                instance['pred_tree'] = ParentedTree.fromstring(
                    left_bracket + right_bracket,
                    brackets=brackets
                )

            instance['pred_event'], instance['pred_role'], instance['pred_record'] = self.get_record_list(
                tree=instance["pred_tree"],
                text=instance['text']
            )

            well_formed_list += [instance]

        return well_formed_list, counter

    def get_record_list(self, tree, text=None):
        """Collect spots, associations and records from a parsed tree.

        String children and empty sub-trees are skipped. Labels and spans are
        normalised with ``resplit_label_span`` and validated with
        ``rewrite_label_span`` against ``self.predicate_set`` (spots) and
        ``self.role_set`` (associations). Items whose span equals
        ``null_span`` or that fail validation are dropped; dropping a spot
        also drops its associations.

        :param tree: parsed tree whose children are spot sub-trees
        :param text: optional source text used for span validation
        :return: ``(spot_list, asoc_list, record_list)`` with
            spot_list   -> [(type, trigger), ...]
            asoc_list   -> [(type, role, argument), ...]
            record_list -> [{'roles': [(role, argument), ...],
                             'type': type, 'trigger': trigger}, ...]
        """
        spot_list = list()
        asoc_list = list()
        record_list = list()

        for spot_tree in tree:

            if isinstance(spot_tree, str):
                continue

            if len(spot_tree) == 0:
                continue

            spot_type = spot_tree.label()
            spot_trigger = get_tree_str(spot_tree)
            spot_type, spot_trigger = resplit_label_span(
                spot_type, spot_trigger)
            spot_type, spot_trigger = rewrite_label_span(
                label=spot_type,
                span=spot_trigger,
                label_set=self.predicate_set,
                text=text
            )

            if spot_trigger == null_span:
                continue

            if spot_type is None or spot_trigger is None:
                continue

            record = {'roles': list(),
                      'type': spot_type,
                      'trigger': spot_trigger}

            for asoc_tree in spot_tree:
                if isinstance(asoc_tree, str) or len(asoc_tree) < 1:
                    continue

                asoc_label = asoc_tree.label()
                asoc_text = get_tree_str(asoc_tree)
                asoc_label, asoc_text = resplit_label_span(
                    asoc_label, asoc_text)
                asoc_label, asoc_text = rewrite_label_span(
                    label=asoc_label,
                    span=asoc_text,
                    label_set=self.role_set,
                    text=text
                )

                if asoc_text == null_span:
                    continue
                if asoc_label is None or asoc_text is None:
                    continue

                asoc_list += [(spot_type, asoc_label, asoc_text)]
                record['roles'] += [(asoc_label, asoc_text)]

            spot_list += [(spot_type, spot_trigger)]
            record_list += [record]

        return spot_list, asoc_list, record_list
def fix_unk_from_text(span, text, unk=''):
    """Repair unknown-token placeholders in ``span`` by looking it up in ``text``.

    ``span`` is split on ``unk``. The pieces are matched literally in
    ``text``, and every placeholder may stand for one run of non-space
    characters with optional surrounding whitespace. The first match,
    stripped, replaces the span.

    Example (with ``unk='<unk>'``):
        span = "<unk> colo e Bengo"
        text = "... is located at Ícolo e Bengo , part of Luanda Province ..."
        -> "Ícolo e Bengo"

        span = "Tar<unk> As<unk>"
        text = "The leader of Japan is Tarō Asō ."
        -> "Tarō Asō"

    Args:
        span: generated span that may contain ``unk``.
        text: source text to search.
        unk: placeholder marker.
            NOTE(review): the default is an empty string in this file;
            presumably the tokenizer's unknown-token literal was intended -
            confirm against the tokenizer configuration.

    Returns:
        str: the repaired span, or ``span`` unchanged when the marker is
        empty, absent from ``span``, or no match is found in ``text``.
    """
    # An empty marker cannot be located, and ``str.split('')`` raises
    # ValueError, so there is nothing to repair in that case.
    if not unk or unk not in span:
        return span

    # re.escape covers every regex metacharacter ('|', '^', '$', '{', '\\',
    # ...), not just a hand-picked subset, so spans are always matched
    # literally.
    pattern = r'\s*\S+\s*'.join(
        re.escape(piece.strip()) for piece in span.split(unk)
    )

    found = re.search(pattern, text)
    if not found:
        return span
    return found.group().strip()
coloÍ", + "coloÍ"), + ("Tarō As", "The leader of Japan is Tarō Asō .", "Tarō Asō"), + ("Tar As", "The leader of Japan is Tarō Asō .", "Tarō Asō"), + ("Tar As", "The leader of Japan is ōTar Asō .", "ōTar Asō"), + ("Atatürk Monument ( zmir )", + "The Atatürk Monument ( İzmir ) can be found in Turkey .", + "Atatürk Monument ( İzmir )"), + ("The Atatürk Monument [ zmir ]", + "The Atatürk Monument [ İzmir ] can be found in Turkey .", + "The Atatürk Monument [ İzmir ]") + ] + + for span, text, gold in span_text_list: + print(span, '|', fix_unk_from_text(span, text)) + assert fix_unk_from_text(span, text) == gold + + +if __name__ == "__main__": + test_fix_unk_from_text() diff --git a/metaretriever/uie/extraction/record_extractor.py b/metaretriever/uie/extraction/record_extractor.py new file mode 100644 index 00000000..fd4aec9b --- /dev/null +++ b/metaretriever/uie/extraction/record_extractor.py @@ -0,0 +1,194 @@ +#!/usr/bin/env python +# -*- coding:utf-8 -*- +import sys +import math +from typing import List, Union +import torch +from tqdm.auto import tqdm +from transformers import T5TokenizerFast, T5ForConditionalGeneration +from uie.extraction.predict_parser import get_predict_parser +from uie.extraction.record_schema import RecordSchema +from uie.seq2seq.constraint_decoder import get_constraint_decoder +from uie.seq2seq.t5_bert_tokenizer import T5BertTokenizer + + +class RecordExtractor: + """记录抽取器,动态更换解码的 Schema + """ + + def __init__(self, + tokenizer: T5TokenizerFast, + model: T5ForConditionalGeneration, + tree_parser=None, constraint_decoder=None, device=None): + self.tokenizer = tokenizer + self.model = model + self.tree_parser = tree_parser + self.constraint_decoder = constraint_decoder + + self.device = device + + self.to_remove_token_list = list() + if tokenizer.bos_token: + self.to_remove_token_list += [tokenizer.bos_token] + if tokenizer.eos_token: + self.to_remove_token_list += [tokenizer.eos_token] + if tokenizer.pad_token: + self.to_remove_token_list += 
def postprocess_text(self, x_str):
    """Remove the configured special tokens from one decoded sequence.

    Args:
        x_str: Decoded text that may still contain special tokens.

    Returns:
        ``x_str`` with every entry of ``self.to_remove_token_list`` deleted
        and leading/trailing whitespace stripped.
    """
    stripped = x_str
    for special_token in self.to_remove_token_list:
        stripped = stripped.replace(special_token, '')
    return stripped.strip()
renew_record_schema(self, record_schema, decoding_schema='spotasoc', prefix='meta: ', task_name='record'): + + """ 使用新的解码框架 """ + sys.stdout.write(f"Renew schema: \n`{record_schema}`\n") + sys.stdout.write(f"Renew decoding: `{decoding_schema}`\n") + sys.stdout.write(f"Renew prefix: `{prefix}`\n") + + tree_parser, constraint_decoder = self.load_record_schema( + tokenizer=self.tokenizer, + record_schema=record_schema, + decoding_schema=decoding_schema, + prefix=prefix, + task_name=task_name, + ) + self.tree_parser = tree_parser + self.constraint_decoder = constraint_decoder + + def extract_record(self, text_list, constrained_decoding=False, batch_size=32, max_length=512): + text_list = [self.constraint_decoder.source_prefix + text for text in text_list] + + model_inputs = self.tokenizer( + text_list, + max_length=max_length, + padding=False, + truncation=True + ) + + num_batch = math.ceil(len(text_list) / batch_size) + outputs = list() + self.model.eval() + + for index in tqdm(range(num_batch)): + batch_model_inputs = { + k: v[index * batch_size: (index + 1) * batch_size] + for k, v in model_inputs.items() + } + + batch_model_inputs = self.tokenizer.pad( + batch_model_inputs, + padding=True, + return_tensors="pt", + ) + + if self.device is not None: + batch_model_inputs = batch_model_inputs.to(self.device) + + def prefix_allowed_tokens_fn(batch_id, sent): + src_sentence = batch_model_inputs['input_ids'][batch_id] + return self.constraint_decoder.constraint_decoding( + src_sentence=src_sentence, + tgt_generated=sent + ) + + with torch.no_grad(): + batch_outputs = self.model.generate( + input_ids=batch_model_inputs["input_ids"], + attention_mask=batch_model_inputs["attention_mask"], + prefix_allowed_tokens_fn=prefix_allowed_tokens_fn if constrained_decoding else None, + max_length=max_length, + ) + + outputs += batch_outputs + + assert len(outputs) == len(text_list) + + # Index -> Str + sequence_list = self.preds_to_sequence_texts(outputs) + + # Str -> Record + 
class RecordSchema:
    """Label inventory of an extraction task.

    Attributes set by ``__init__``:
        type_list: the type names.
        role_list: the role names.
        type_role_dict: mapping from a type name to its role names.
    """

    def __init__(self, type_list, role_list, type_role_dict):
        self.type_list = type_list
        self.role_list = role_list
        self.type_role_dict = type_role_dict

    def __repr__(self) -> str:
        return f"Type: {self.type_list}\n" \
               f"Role: {self.role_list}\n" \
               f"Map: {self.type_role_dict}"

    @staticmethod
    def get_empty_schema():
        """Return a schema with no types, no roles and an empty mapping."""
        return RecordSchema(type_list=list(), role_list=list(), type_role_dict=dict())

    @staticmethod
    def read_from_file(filename):
        """Load a schema from the three-line JSON format of ``write_to_file``.

        Line 1 holds the type list, line 2 the role list and line 3 the
        type-to-role mapping; any further lines are ignored.

        Raises:
            IndexError: if the file has fewer than three lines.
            json.JSONDecodeError: if one of the lines is not valid JSON.
        """
        # Use a context manager so the handle is closed deterministically
        # instead of being left to the garbage collector.
        with open(filename) as schema_file:
            lines = schema_file.readlines()
        type_list = json.loads(lines[0])
        role_list = json.loads(lines[1])
        type_role_dict = json.loads(lines[2])
        return RecordSchema(type_list, role_list, type_role_dict)

    def write_to_file(self, filename):
        """Write the schema as three JSON lines: types, roles, mapping."""
        with open(filename, 'w') as output:
            output.write(json.dumps(self.type_list) + '\n')
            output.write(json.dumps(self.role_list) + '\n')
            output.write(json.dumps(self.type_role_dict) + '\n')
class Metric:
    """Running true-positive / gold / predicted counters over tuple lists."""

    def __init__(self, verbose=False, match_mode='normal'):
        self.tp = 0.
        self.gold_num = 0.
        self.pred_num = 0.
        self.verbose = verbose
        self.match_mode = match_mode
        assert self.match_mode in {'set', 'normal', 'multimatch'}

    def __repr__(self) -> str:
        return f"tp: {self.tp}, gold: {self.gold_num}, pred: {self.pred_num}"

    @staticmethod
    def safe_div(a, b):
        """Return ``a / b``, or ``0.`` when ``b`` equals zero."""
        if b == 0.:
            return 0.
        return a / b

    def compute_f1(self, prefix=''):
        """Return counts plus P/R/F1 (as percentages), keys prefixed by ``prefix``."""
        precision = self.safe_div(self.tp, self.pred_num)
        recall = self.safe_div(self.tp, self.gold_num)
        f1 = self.safe_div(2 * precision * recall, precision + recall)
        stats = {
            'tp': self.tp,
            'gold': self.gold_num,
            'pred': self.pred_num,
            'P': precision * 100,
            'R': recall * 100,
            'F1': f1 * 100,
        }
        return {prefix + name: value for name, value in stats.items()}

    def count_instance(self, gold_list, pred_list):
        """Accumulate the counters for one instance.

        * ``set``: both sides are de-duplicated before counting.
        * ``normal``: each gold item can be matched by at most one prediction.
        * ``multimatch``: a gold item may be matched by several predictions.
        """
        set_mode = self.match_mode == 'set'
        if set_mode:
            gold_list = set(gold_list)
            pred_list = set(pred_list)

        if self.verbose:
            print("Gold:", gold_list)
            print("Pred:", pred_list)

        self.gold_num += len(gold_list)
        self.pred_num += len(pred_list)

        if set_mode:
            self.tp += len(gold_list & pred_list)
            return

        if len(gold_list) > 0 and len(pred_list) > 0:
            # guarantee length same
            assert len(gold_list[0]) == len(pred_list[0])

        # Work on a copy so consuming matches never touches the caller's list.
        unmatched_gold = deepcopy(gold_list)
        consume_on_match = self.match_mode == 'normal'
        for candidate in pred_list:
            if candidate not in unmatched_gold:
                continue
            self.tp += 1
            if consume_on_match:
                unmatched_gold.remove(candidate)

    def count_batch_instance(self, batch_gold_list, batch_pred_list):
        """Call ``count_instance`` on each aligned (gold, pred) pair."""
        for golds, preds in zip(batch_gold_list, batch_pred_list):
            self.count_instance(gold_list=golds, pred_list=preds)
def warning_tp_increment(gold, pred, prefix):
    """Write the gold/pred offsets and strings of one instance to stderr.

    Args:
        gold: Mapping with ``'offset'`` and ``'string'`` entries.
        pred: Mapping with ``'offset'`` and ``'string'`` entries.
        prefix: Label put in front of every warning line.
    """
    emit = sys.stderr.write
    emit(f"{prefix} TP Increment Warning, Gold Offset: {gold['offset']}\n")
    emit(f"{prefix} TP Increment Warning, Pred Offset: {pred['offset']}\n")
    emit(f"{prefix} TP Increment Warning, Gold String: {gold['string']}\n")
    emit(f"{prefix} TP Increment Warning, Pred String: {pred['string']}\n")
    emit("===============\n")
@staticmethod
def load_pred_list(pred_list: List[Dict]):
    """Normalise predicted entities to tuples, in place.

    Args:
        pred_list: One dict per sentence, e.g.
            {
                'offset': [['Geo-political', [7]], ['Geo-political', [14]]],
                'string': [['Geo-political', 'seattle'], ['Geo-political', 'city']]
            }

    Returns:
        List[Dict]: the same dict objects, whose ``'offset'`` and
        ``'string'`` lists now hold tuples, e.g.
            {
                'offset': [('Geo-political', (7,)), ('Geo-political', (14,))],
                'string': [('Geo-political', 'seattle'), ('Geo-political', 'city')]
            }
    """
    loaded = list()
    for sentence_pred in pred_list:
        # First make the inner offset hashable, then the record itself.
        for record in sentence_pred['offset']:
            span = record[1]
            if isinstance(span, tuple):
                continue
            record[1] = tuple_offset(span)
        for key in ('offset', 'string'):
            sentence_pred[key] = [tuple_offset(item) for item in sentence_pred[key]]
        loaded.append(sentence_pred)
    return loaded
@staticmethod
def load_gold_list(gold_list: List[List[Dict]]):
    """Turn gold relation records into offset and string tuples.

    Args:
        gold_list: One list of relation dicts per sentence, e.g.
            {
                'type': 'Part-whole',
                'args': [{'type': 'Location', 'offset': [11], 'text': 'lot'},
                         {'type': 'Geo-political', 'offset': [14], 'text': 'city'}]
            }
            Every relation must have exactly two arguments.

    Returns:
        List[Dict]: one mapping per sentence with
            'offset': [(rel_type, arg0_type, arg0_offset, arg1_type, arg1_offset), ...]
            'string': [(rel_type, arg0_type, arg0_text, arg1_type, arg1_text), ...]
        A key is only present when the sentence has at least one relation.
    """
    loaded = []
    for sentence in gold_list:
        instance = defaultdict(list)
        for record in sentence:
            assert len(record['args']) == 2
            relation = record['type']
            head, tail = record['args'][0], record['args'][1]
            instance['offset'].append((
                relation,
                head['type'],
                tuple_offset(head['offset']),
                tail['type'],
                tuple_offset(tail['offset']),
            ))
            instance['string'].append((
                relation,
                head['type'],
                head['text'],
                tail['type'],
                tail['text'],
            ))
        loaded.append(instance)

    return loaded
+ ] + """ + pred_instance_list = list() + for pred in pred_list: + for offset_pred in pred['offset']: + + if not isinstance(offset_pred[2], tuple): + offset_pred[2] = tuple_offset(offset_pred[2]) + + if not isinstance(offset_pred[4], tuple): + offset_pred[4] = tuple_offset(offset_pred[4]) + + pred['offset'] = [tuple_offset(p) for p in pred['offset']] + pred['string'] = [tuple_offset(p) for p in pred['string']] + pred_instance_list += [pred] + return pred_instance_list + + @staticmethod + def eval_instance_list(gold_instance_list, pred_instance_list, verbose=False, match_mode='normal'): + """[summary] + + Args: + gold_instance_list (List[Dict]): List of Sentece, each sentence contains two List (offset, string) of Relation Tuple + [ + { + 'offset': [('Part-whole', 'Geo-political', (0,), 'Geo-political', (2,)), ... ], + 'string': [('Part-whole', 'Geo-political', 'MULTAN', 'Geo-political', 'Pakistan'), ...] + } + ] + pred_instance_list ([type]): List of Sentece, each sentence contains two List (offset, string) of Relation Tuple + [ + { + 'offset': [('Part-whole', 'Geo-political', (0,), 'Geo-political', (2,))], + 'string': [('Part-whole', 'Geo-political', 'MULTAN', 'Geo-political', 'Pakistan')] + }, ... + ] + verbose (bool, optional): Defaults to False. + match_mode (string, optional): [description]. Defaults to `normal` . 
+ + Returns: + Dict: Result of Evaluation + (offset, string) X (boundary, strict) X (gold, pred, tp, P, R, F1) + """ + # Span Boundary and Type + metrics = { + 'offset': Metric(verbose=verbose, match_mode=match_mode), + 'string': Metric(verbose=verbose, match_mode=match_mode), + } + # Span Boundary Only + boundary_metrics = { + 'offset': Metric(verbose=verbose, match_mode=match_mode), + 'string': Metric(verbose=verbose, match_mode=match_mode), + } + for pred, gold in zip(pred_instance_list, gold_instance_list): + + pre_string_tp, pre_offset_tp = metrics['string'].tp, metrics['offset'].tp + + for eval_key in metrics: + # Span Boundary and Type + metrics[eval_key].count_instance( + gold_list=gold.get(eval_key, []), + pred_list=pred.get(eval_key, []), + ) + + post_string_tp, post_offset_tp = metrics['string'].tp, metrics['offset'].tp + if verbose and (post_offset_tp - pre_offset_tp != post_string_tp - pre_string_tp): + warning_tp_increment(gold=gold, pred=pred, prefix='Relation Strict') + + pre_string_tp, pre_offset_tp = boundary_metrics['string'].tp, boundary_metrics['offset'].tp + + for eval_key in boundary_metrics: + # Span Boundary Only + boundary_metrics[eval_key].count_instance( + gold_list=[(x[0], x[2], x[4]) for x in gold.get(eval_key, [])], + pred_list=[(x[0], x[2], x[4]) for x in pred.get(eval_key, [])], + ) + post_string_tp, post_offset_tp = boundary_metrics['string'].tp, boundary_metrics['offset'].tp + if verbose and post_offset_tp - pre_offset_tp != post_string_tp - pre_string_tp: + warning_tp_increment(gold=gold, pred=pred, prefix='Relation Boundary') + + results = dict() + for eval_key in metrics: + results.update(metrics[eval_key].compute_f1(prefix=eval_key + '-rel-strict-')) + for eval_key in boundary_metrics: + results.update(boundary_metrics[eval_key].compute_f1(prefix=eval_key + '-rel-boundary-')) + return results + + +class EventScorer(Scorer): + @staticmethod + def load_gold_list(gold_list): + """[summary] + + Args: + gold_list 
@staticmethod
def load_pred_list(pred_list):
    """Flatten predicted event records into trigger and role tuples.

    Args:
        pred_list (List[Dict]): one dict per sentence, e.g.
            {
                'offset': [{'type': 'Attack', 'roles': [['Attacker', [5, 6]], ['Place', [23]]], 'trigger': [16]}],
                'string': [{'type': 'Attack', 'roles': [['Attacker', 'John Joseph'], ['Place', 'court']], 'trigger': 'shot'}],
            }

    Returns:
        List[Dict]: one mapping per sentence with up to four lists:
            'offset_trigger': [('Attack', (16,))],
            'offset_role': [('Attack', 'Attacker', (5, 6)), ('Attack', 'Place', (23,))],
            'string_trigger': [('Attack', 'shot')],
            'string_role': [('Attack', 'Attacker', 'John Joseph'), ('Attack', 'Place', 'court')],
        A key is only present when at least one item was added to it.
    """
    loaded = list()
    for sentence_pred in pred_list:
        instance = defaultdict(list)

        for event in sentence_pred['offset']:
            event_type = event['type']
            instance['offset_trigger'].append((event_type, tuple_offset(event['trigger'])))
            for role_type, role_offset in event['roles']:
                instance['offset_role'].append((event_type, role_type, tuple_offset(role_offset)))

        for event in sentence_pred['string']:
            event_type = event['type']
            instance['string_trigger'].append((event_type, event['trigger']))
            for role_type, role_string in event['roles']:
                instance['string_role'].append((event_type, role_type, role_string))

        loaded.append(instance)
    return loaded
+ ] + pred_instance_list (List[Dict]): List of Sentece, each sentence contains four List (offset, string) X (trigger, role) of Event List + [ + { + 'offset_trigger': [('Attack', (16,))], + 'offset_role': [('Attack', 'Attacker', (5, 6)), ('Attack', 'Place', (23,)), ('Attack', 'Target', (17,))], + 'string_trigger': [('Attack', 'shot')], + 'string_role': [('Attack', 'Attacker', 'John Joseph'), ('Attack', 'Place', 'court'), ('Attack', 'Target', 'himself')], + }, + ... + ] + verbose (bool, optional): [description]. Defaults to False. + match_mode (string, optional): [description]. Defaults to `normal`. + + Returns: + Dict: Result of Evaluation + (offset, string) X (trigger, role) X (gold, pred, tp, P, R, F1) + """ + trigger_metrics = { + 'offset': Metric(verbose=verbose, match_mode=match_mode), + 'string': Metric(verbose=verbose, match_mode=match_mode), + } + role_metrics = { + 'offset': Metric(verbose=verbose, match_mode=match_mode), + 'string': Metric(verbose=verbose, match_mode=match_mode), + } + + for pred, gold in zip(pred_instance_list, gold_instance_list): + + pre_string_tp, pre_offset_tp = trigger_metrics['string'].tp, trigger_metrics['offset'].tp + + for eval_key in trigger_metrics: + trigger_metrics[eval_key].count_instance( + gold_list=gold.get(eval_key + '_trigger', []), + pred_list=pred.get(eval_key + '_trigger', []) + ) + + post_string_tp, post_offset_tp = trigger_metrics['string'].tp, trigger_metrics['offset'].tp + if verbose and post_offset_tp - pre_offset_tp != post_string_tp - pre_string_tp: + warning_tp_increment(gold=gold, pred=pred, prefix='Trigger') + + pre_string_tp, pre_offset_tp = role_metrics['string'].tp, role_metrics['offset'].tp + + for eval_key in role_metrics: + role_metrics[eval_key].count_instance( + gold_list=gold.get(eval_key + '_role', []), + pred_list=pred.get(eval_key + '_role', []) + ) + + post_string_tp, post_offset_tp = role_metrics['string'].tp, role_metrics['offset'].tp + if verbose and post_offset_tp - pre_offset_tp != 
def convert_spot_asoc(spot_asoc_instance, structure_maker):
    """Convert one Spot-Asoc instance into its target string.

    Each spot is rendered as ``label <target_span_start> span`` followed by
    one bracketed ``asoc_label <target_span_start> asoc_span`` group per
    association, wrapped in the record markers; all records are then wrapped
    in the sentence markers.  Pieces are joined with single spaces.

    Args:
        spot_asoc_instance: Iterable of dicts with ``'label'``, ``'span'``
            and an optional ``'asoc'`` list of ``(label, span)`` pairs.
        structure_maker: Object providing the marker strings ``sent_start``,
            ``sent_end``, ``record_start``, ``record_end``, ``span_start``,
            ``span_end`` and ``target_span_start``.

    Returns:
        str: the linearised target text.
    """
    records = []
    for spot in spot_asoc_instance:
        pieces = [spot['label'], structure_maker.target_span_start, spot['span']]
        pieces.extend(
            ' '.join([
                structure_maker.span_start,
                asoc_label,
                structure_maker.target_span_start,
                asoc_span,
                structure_maker.span_end,
            ])
            for asoc_label, asoc_span in spot.get('asoc', list())
        )
        records.append(' '.join([
            structure_maker.record_start,
            ' '.join(pieces),
            structure_maker.record_end,
        ]))
    return ' '.join([
        structure_maker.sent_start,
        ' '.join(records),
        structure_maker.sent_end,
    ])
def match_sublist(the_list, to_match):
    """Find every occurrence of ``to_match`` as a contiguous run in ``the_list``.

    :param the_list: [1, 2, 3, 4, 5, 6, 1, 2, 4, 5]
    :param to_match: [1, 2]
    :return: inclusive (start, end) index pairs, in order of appearance:
        [(0, 1), (6, 7)]
    """
    width = len(to_match)
    last_start = len(the_list) - width
    return [
        (start, start + width - 1)
        for start in range(last_start + 1)
        if to_match == the_list[start:start + width]
    ]
class MapConfig:
    """Settings for mapping predicted spans back to token offsets."""

    def __init__(self,
                 map_strategy: str = 'first',
                 de_duplicate: bool = True,
                 span_to_token: str = 'space') -> None:
        self.map_strategy = map_strategy
        self.de_duplicate = de_duplicate
        self.span_to_token = span_to_token

    def __repr__(self) -> str:
        return ', '.join((
            f"map_strategy: {self.map_strategy}",
            f"de_duplicate: {self.de_duplicate}",
            f"span_to_token: {self.span_to_token}",
        ))

    @staticmethod
    def load_from_yaml(config_file):
        """Build a ``MapConfig`` from a YAML file.

        The file must define ``map_strategy``, ``de_duplicate`` and
        ``span_to_token``; a missing key raises ``KeyError``.
        """
        import yaml
        with open(config_file) as stream:
            loaded = yaml.load(stream, Loader=yaml.FullLoader)
        wanted = ('map_strategy', 'de_duplicate', 'span_to_token')
        return MapConfig(**{field: loaded[field] for field in wanted})
def record_to_offset_longer_first(self, instance, token_list):
    """
    Find Entity's offset, assigning longer mentions before shorter ones.

    Records are visited from the longest ``trigger`` string to the shortest
    (ties keep their input order). Each record takes its first occurrence in
    ``token_list`` that does not overlap a span assigned earlier.

    :param instance: list of dicts with ``type`` and ``trigger`` keys. The
        list and its dicts are left untouched.
    :param token_list: tokens of the sentence.
    :return: list of ``(type, token_index_tuple)`` in assignment order.
    """
    entity_list = list()

    # (type, (start, end)) pairs that already own a token span
    entity_matched_set = set()

    # sorted() is stable and returns a new list, so the caller's list is not
    # re-ordered and no helper 'length' key is written into its dicts.
    ordered_records = sorted(
        instance,
        key=lambda record: len(record['trigger']),
        reverse=True,
    )

    for pred_record in ordered_records:
        record_type, record_text = pred_record['type'], pred_record['trigger']
        if record_text == "":
            logger.warning(f"Empty Extraction {pred_record}")
            continue

        matched_list = match_sublist(
            token_list, self.span_to_token(record_text))
        for matched in matched_list:
            # A span that overlaps nothing cannot already be in the set, so
            # the overlap test is the only check needed before claiming it.
            if any(check_overlap(taken, matched)
                   for _, taken in entity_matched_set):
                continue
            entity_list += [(record_type,
                             tuple(range(matched[0], matched[1] + 1)))]
            entity_matched_set.add((record_type, matched))
            break

    return entity_list
def record_to_offset_closest_role(self, instance, token_list):
    """
    Find the offsets of a relation's two arguments, choosing the pair of
    occurrences whose start positions are closest to each other.

    Records with fewer than two roles, or with an argument that cannot be
    found in ``token_list``, are skipped; later records are still processed.

    :param instance: list of dicts with ``type`` and ``roles`` keys, where
        ``roles`` holds ``(role_type, text)`` pairs.
    :param token_list: tokens of the sentence.
    :return: list of ``(relation_type, arg1_type, arg1_offsets, arg2_type,
        arg2_offsets)`` tuples.
    """
    relation_list = list()

    for record in instance:
        relation_type = record['type']

        if len(record['roles']) < 2:
            continue

        arg1_type, arg1_text = record['roles'][0]
        arg2_type, arg2_text = record['roles'][1]
        arg1_matched_list = match_sublist(
            token_list, self.span_to_token(arg1_text))
        arg2_matched_list = match_sublist(
            token_list, self.span_to_token(arg2_text))

        # Skip only this record. A `break` here would silently drop every
        # relation that follows, unlike record_to_offset_first_role.
        if len(arg1_matched_list) == 0:
            sys.stderr.write("[Cannot reconstruct]: %s %s\n" %
                             (arg1_text, token_list))
            continue
        if len(arg2_matched_list) == 0:
            sys.stderr.write("[Cannot reconstruct]: %s %s\n" %
                             (arg2_text, token_list))
            continue

        # min() over (distance, arg1, arg2) picks the same tuple that sorting
        # the full list and taking its first element would.
        _, arg1_match, arg2_match = min(
            (abs(arg1_match[0] - arg2_match[0]), arg1_match, arg2_match)
            for arg1_match in arg1_matched_list
            for arg2_match in arg2_matched_list
        )

        relation = (
            relation_type,
            arg1_type, get_index_tuple(arg1_match),
            arg2_type, get_index_tuple(arg2_match),
        )
        if self._map_config.de_duplicate and relation in relation_list:
            continue
        relation_list += [relation]

    return relation_list
class EventRecord(Record):
    """Maps string-level event records (type, trigger, roles) to token offsets."""

    def to_offset(self, instance, tokens):
        """Map event records to offsets with the configured strategy.

        :param instance: list of event dicts (``type``, ``trigger``, ``roles``).
        :param tokens: tokens of the sentence.
        :return: list of event dicts whose trigger and role texts are replaced
            by token index tuples.
        :raises NotImplementedError: for an unknown map strategy.
        """
        map_strategy_dict = {
            'first': self.record_to_offset_first_role,
            'closest': self.record_to_offset_closest_role,
            # events have no dedicated longer-first mapping
            'longer_first': self.record_to_offset_closest_role,
        }
        if self._map_config.map_strategy in map_strategy_dict:
            map_function = map_strategy_dict[self._map_config.map_strategy]
            return map_function(
                instance=instance,
                token_list=tokens,
            )
        else:
            raise NotImplementedError(
                f"The map strategy {self._map_config.map_strategy} in {self.__class__} is not implemented.")

    @staticmethod
    def to_string(instance):
        """
        Return the string-level events unchanged, e.g.
        {'type': 'Justice:Appeal',
         'trigger': 'appeal',
         'roles': [
            ('Adjudicator', 'court'),
            ('Plaintiff', 'Anwar')
        ], }
        """
        return instance

    def _match_trigger(self, trigger, token_list, trigger_matched_set):
        """Return the offsets of the first unused occurrence of ``trigger``.

        The chosen span is added to ``trigger_matched_set``. Returns ``None``
        when the trigger does not occur in ``token_list`` (reported on stderr)
        or when every occurrence is already used by an earlier event.
        """
        matched_list = match_sublist(
            token_list, self.span_to_token(trigger))

        if len(matched_list) == 0:
            sys.stderr.write("[Cannot reconstruct]: %s %s\n" %
                             (trigger, token_list))
            return None

        for matched in matched_list:
            if matched not in trigger_matched_set:
                trigger_matched_set.add(matched)
                return get_index_tuple(matched)
        return None

    def record_to_offset_first_role(self, instance, token_list):
        """
        Find Role's offset using first matched in the sentence.

        :param instance: list of event dicts (``type``, ``trigger``, ``roles``).
        :param token_list: tokens of the sentence.
        :return: list of offset-level event dicts; events whose trigger cannot
            be located are left out.
        """
        record_list = list()

        trigger_matched_set = set()
        for record in instance:
            event_type = record['type']
            trigger_offset = self._match_trigger(
                record['trigger'], token_list, trigger_matched_set)

            # No trigger word, skip the record (and keep going: a `break`
            # here would also discard all events that follow).
            if trigger_offset is None:
                continue

            pred_record = {'type': event_type,
                           'roles': [],
                           'trigger': trigger_offset}

            for role_type, text_str in record['roles']:
                matched_list = match_sublist(
                    token_list, self.span_to_token(text_str))
                if len(matched_list) == 0:
                    sys.stderr.write(
                        "[Cannot reconstruct]: %s %s\n" % (text_str, token_list))
                    continue
                pred_record['roles'] += [(role_type,
                                          get_index_tuple(matched_list[0]))]

            record_list += [pred_record]

        return record_list

    def record_to_offset_closest_role(self, instance, token_list):
        """
        Find Role's offset using the occurrence closest to the trigger word.

        :param instance: list of event dicts (``type``, ``trigger``, ``roles``).
        :param token_list: tokens of the sentence.
        :return: list of offset-level event dicts; events whose trigger cannot
            be located are left out.
        """
        record_list = list()

        trigger_matched_set = set()
        for record in instance:
            event_type = record['type']
            trigger_offset = self._match_trigger(
                record['trigger'], token_list, trigger_matched_set)

            # No trigger word, skip the record (and keep going: a `break`
            # here would also discard all events that follow).
            if trigger_offset is None or len(trigger_offset) == 0:
                continue

            pred_record = {'type': event_type,
                           'roles': [],
                           'trigger': trigger_offset}

            for role_type, text_str in record['roles']:
                matched_list = match_sublist(
                    token_list, self.span_to_token(text_str))
                if len(matched_list) == 0:
                    sys.stderr.write(
                        "[Cannot reconstruct]: %s %s\n" % (text_str, token_list))
                else:
                    # distance is measured between start positions
                    abs_distances = [abs(match[0] - trigger_offset[0])
                                     for match in matched_list]
                    closest_index = numpy.argmin(abs_distances)
                    pred_record['roles'] += [(
                        role_type,
                        get_index_tuple(matched_list[closest_index])
                    )]

            record_list += [pred_record]
        return record_list
def proprocessing_graph_record(graph, schema_dict):
    """Map a generated spot-asoc result to Entity/Relation/Event records.

    :param graph: dict whose ``pred_record`` entry is a list of spot records,
        each a dict with ``type``, ``trigger`` and ``roles`` (a list of
        ``(asoc_label, text)`` pairs). The records are not modified.
    :param schema_dict: dict with ``entity`` and ``event`` schemas, each
        exposing a ``type_list``.
    :return: dict with the keys ``entity``, ``relation`` and ``event``.
    """
    records = {
        'entity': list(),
        'relation': list(),
        'event': list(),
    }

    # trigger text -> entity type, in prediction order
    entity_dict = OrderedDict()
    event_records = list()

    # Mapping generated spot result to Entity/Event according to the schema
    for record in graph['pred_record']:

        if record['type'] in schema_dict['entity'].type_list:
            records['entity'] += [{
                'trigger': record['trigger'],
                'type': record['type']
            }]
            entity_dict[record['trigger']] = record['type']

        elif record['type'] in schema_dict['event'].type_list:
            event_records += [record]

        else:
            logger.warning("Type `%s` invalid.", record['type'])

    # Mapping generated asoc result to Relation: every asoc of an entity spot
    # becomes a relation from that entity to the asoc text. This needs the
    # complete entity_dict, hence the second pass.
    for record in graph['pred_record']:
        if record['type'] in schema_dict['entity'].type_list:
            for role in record['roles']:
                records['relation'] += [{
                    'type': role[0],
                    'roles': [(record['type'], record['trigger']),
                              # tail type falls back to the head's type when
                              # the tail text is not a predicted entity
                              (entity_dict.get(
                                  role[1], record['type']), role[1]),
                              ]
                }]

    # When entities were predicted, keep only event arguments whose text is
    # one of them. A copy is filtered so the caller's graph stays intact.
    for record in event_records:
        if len(entity_dict) > 0:
            record = dict(
                record,
                roles=[role for role in record['roles']
                       if role[1] in entity_dict],
            )
        records['event'] += [record]

    return records
def sel2record(self, pred, text, tokens):
    """Turn one generated SEL string into string- and offset-level records.

    :param pred: generated structured expression.
    :param text: source text, handed to the predict parser.
    :param tokens: tokens of the source text, used for offset mapping.
    :return: mapping ``task -> {'offset': ..., 'string': ...}`` for every task
        in ``task_record_map``.
    """
    # Parse the generated structured expression into string-level records
    parsed_list, _ = self._predict_parser.decode(
        gold_list=[],
        pred_list=[pred],
        text_list=[text],
    )

    # Split the extracted spot-asoc records into Entity/Relation/Event
    # results according to the schema
    task_records = proprocessing_graph_record(
        parsed_list[0],
        self._schema_dict
    )

    # Map the string-level records of each task back to offset-level records
    result = defaultdict(dict)
    for task, record_class in task_record_map.items():
        mapper = record_class(map_config=self._map_config)
        instance = task_records.get(task, [])

        result[task]['offset'] = mapper.to_offset(
            instance=instance,
            tokens=tokens,
        )
        result[task]['string'] = mapper.to_string(instance)
    return result
@dataclass
class ConstraintSeq2SeqTrainingArguments(Seq2SeqTrainingArguments):
    """
    Seq2SeqTrainingArguments extended with the options read by the trainers
    in this file.

    Parameters:
        constraint_decoding (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether to use Constraint Decoding
        save_better_checkpoint (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether to save a checkpoint only when the evaluated metric is
            better than the previous best
        start_eval_step (:obj:`int`, `optional`, defaults to 0):
            Evaluation and checkpoint saving are skipped while the global
            step is below this value
        trainer_type (:obj:`str`, `optional`, defaults to :obj:`"meta_pretrain"`):
            Name of the trainer used for training the model
    """
    constraint_decoding: bool = field(default=False, metadata={"help": "Whether to Constraint Decoding or not."})
    save_better_checkpoint: bool = field(default=False,
                                         metadata={"help": "Whether to save better metric checkpoint"})
    start_eval_step: int = field(default=0, metadata={"help": "Start Evaluation after Eval Step"})
    trainer_type: str = field(default="meta_pretrain", metadata={"help": "Trainer for training model, containing meta_pretrain, meta_finetune, origin"})
def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch):
    """Log, evaluate and save according to the flags in ``self.control``.

    Differences from a plain log/evaluate/save sequence:

    * evaluation and saving are skipped while the global step is below
      ``args.start_eval_step``;
    * with ``args.save_better_checkpoint`` a checkpoint is written only when
      the freshly evaluated ``args.metric_for_best_model`` improves on
      ``state.best_metric`` in the direction given by
      ``args.greater_is_better``.

    :param tr_loss: tensor accumulating the training loss since the last
        log; reset to zero in place when logging.
    :param model: model handed to ``_save_checkpoint``.
    :param trial: hyper-parameter search trial (may be ``None``).
    :param epoch: current epoch, reported to the hyper-parameter search.
    """
    if self.control.should_log:
        logs: Dict[str, float] = {}
        tr_loss_scalar = tr_loss.item()
        # reset tr_loss to zero
        tr_loss -= tr_loss

        logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
        logs["learning_rate"] = self._get_learning_rate()

        self._total_loss_scalar += tr_loss_scalar
        self._globalstep_last_logged = self.state.global_step

        self.log(logs)

    if self.args.start_eval_step > 0 and self.state.global_step < self.args.start_eval_step:
        return

    # Read before evaluating so the comparison below is against the best
    # value known before this evaluation.
    previous_best_metric = self.state.best_metric
    metrics = None
    if self.control.should_evaluate:
        metrics = self.evaluate()
        self._report_to_hp_search(trial, epoch, metrics)

    # Only save the checkpoint better than previous_best_metric
    if self.args.save_better_checkpoint and self.args.metric_for_best_model is not None:
        if metrics is not None and previous_best_metric is not None:
            metric_name = self.args.metric_for_best_model
            # Accept the metric name with or without the "eval_" prefix
            # instead of raising KeyError on the bare name.
            if metric_name not in metrics and not metric_name.startswith("eval_"):
                metric_name = f"eval_{metric_name}"
            current_metric = metrics[metric_name]

            # For lower-is-better metrics (e.g. loss) the comparison flips;
            # an unset greater_is_better keeps the higher-is-better rule.
            if getattr(self.args, "greater_is_better", None) is False:
                not_better = current_metric >= previous_best_metric
            else:
                not_better = current_metric <= previous_best_metric
            if not_better:
                return

    if self.control.should_save:
        self._save_checkpoint(model, trial, metrics=metrics)
        self.control = self.callback_handler.on_save(self.args, self.state, self.control)
+ """ + + def prefix_allowed_tokens_fn(batch_id, sent): + # print(self.tokenizer.convert_ids_to_tokens(inputs['labels'][batch_id])) + src_sentence = inputs['input_ids'][batch_id] + return self.constraint_decoder.constraint_decoding(src_sentence=src_sentence, + tgt_generated=sent) + + if not self.args.predict_with_generate or prediction_loss_only: + return super().prediction_step( + model=model, + inputs=inputs, + prediction_loss_only=prediction_loss_only, + ignore_keys=ignore_keys, + ) + has_labels = "labels" in inputs + inputs = self._prepare_inputs(inputs) + + gen_kwargs = { + "max_length": self._max_length if self._max_length is not None else self.model.config.max_length, + "num_beams": self._num_beams if self._num_beams is not None else self.model.config.num_beams, + "prefix_allowed_tokens_fn": prefix_allowed_tokens_fn if self.constraint_decoder else None, + } + + generated_tokens = self.model.generate( + inputs["input_ids"], + attention_mask=inputs["attention_mask"], + **gen_kwargs, + ) + + # in case the batch is shorter than max length, the output should be padded + if generated_tokens.shape[-1] < gen_kwargs["max_length"]: + generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_length"]) + + with torch.no_grad(): + if self.use_amp: + with autocast(): + outputs = model(**inputs) + else: + outputs = model(**inputs) + if has_labels: + if self.label_smoother is not None: + loss = self.label_smoother(outputs, inputs["labels"]).mean().detach() + else: + loss = (outputs["loss"] if isinstance(outputs, dict) else outputs[0]).mean().detach() + else: + loss = None + + if self.args.prediction_loss_only: + return loss, None, None + + labels = inputs["labels"] + if labels.shape[-1] < gen_kwargs["max_length"]: + labels = self._pad_tensors_to_max_len(labels, gen_kwargs["max_length"]) + + return loss, generated_tokens, labels + +class UIEPretrainConstraintSeq2SeqTrainer(OriginalConstraintSeq2SeqTrainer): + + def pretrain(self, model: nn.Module, 
inputs: Dict[str, Union[torch.Tensor, Any]]) -> nn.Module: + def get_loss(model, inputs): + if is_sagemaker_mp_enabled(): + scaler = self.scaler if self.use_amp else None + loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps, scaler=scaler) + return loss_mb.reduce_mean().detach().to(self.args.device) + + if self.use_amp: + with autocast(): + loss = self.compute_loss(model, inputs) + else: + loss = self.compute_loss(model, inputs) + + if self.args.n_gpu > 1: + loss = loss.mean() # mean() to average on multi-gpu parallel training + + if self.args.gradient_accumulation_steps > 1 and not self.deepspeed: + # deepspeed handles loss scaling by gradient_accumulation_steps in its `backward` + loss = loss / self.args.gradient_accumulation_steps + + return loss + + model.train() + inputs = self._prepare_inputs(inputs) + + # record inputs + record_input_ids = inputs.pop("record_input_ids") + record_inputs = { + "input_ids": record_input_ids, + "labels": inputs["labels"], + "decoder_input_ids": inputs["decoder_input_ids"] + } + + # mlm inputs + mlm_input_ids = inputs.pop("mlm_input_ids") + mlm_target_ids = inputs.pop("mlm_target_ids") + mlm_decoder_input_ids = inputs.pop("mlm_decoder_input_ids") + mlm_inputs = { + "input_ids": mlm_input_ids, + "labels": mlm_target_ids, + "decoder_input_ids": mlm_decoder_input_ids, + } + + # remove other field + if "noised_input_ids" in inputs.keys(): + inputs.pop("noised_input_ids") + if "noised_att_mask" in inputs.keys(): + inputs.pop("noised_att_mask") + + # inner loop + loss = get_loss(model, inputs) + + # record loss + record_loss = get_loss(model, record_inputs) + + # mlm loss + mlm_loss = get_loss(model, mlm_inputs) + + loss = loss + record_loss + mlm_loss + + if self.use_amp: + self.scaler.scale(loss).backward() + elif self.use_apex: + with amp.scale_loss(loss, self.optimizer) as scaled_loss: + scaled_loss.backward() + elif self.deepspeed: + # loss gets scaled under gradient_accumulation_steps in deepspeed 
def finetune(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
    """Run one fine-tuning forward/backward pass and return the detached loss.

    The pre-training-only fields are removed from ``inputs`` before the model
    is called, so batches built for pre-training can be fed to this trainer.

    :param model: the model to train.
    :param inputs: the batch; passed through ``_prepare_inputs`` first.
    :return: the detached loss of this batch (already divided by
        ``gradient_accumulation_steps`` when that is > 1 and deepspeed is
        not in use).
    """
    model.train()
    inputs = self._prepare_inputs(inputs)

    # remove pretrain field
    for pretrain_key in (
        "record_input_ids",
        "mlm_input_ids",
        "mlm_target_ids",
        "mlm_decoder_input_ids",
        "noised_input_ids",
        "noised_att_mask",
    ):
        inputs.pop(pretrain_key, None)

    if is_sagemaker_mp_enabled():
        # NOTE(review): mirrors the upstream Trainer.training_step, which
        # returns straight after smp_forward_backward. The value produced
        # here is detached, so sending it on to .backward() below fails.
        scaler = self.scaler if self.use_amp else None
        loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps, scaler=scaler)
        return loss_mb.reduce_mean().detach().to(self.args.device)

    # finetune loss
    if self.use_amp:
        with autocast():
            loss = self.compute_loss(model, inputs)
    else:
        loss = self.compute_loss(model, inputs)

    if self.args.n_gpu > 1:
        loss = loss.mean()  # mean() to average on multi-gpu parallel training

    if self.args.gradient_accumulation_steps > 1 and not self.deepspeed:
        # deepspeed handles loss scaling by gradient_accumulation_steps in its `backward`
        loss = loss / self.args.gradient_accumulation_steps

    if self.use_amp:
        self.scaler.scale(loss).backward()
    elif self.use_apex:
        with amp.scale_loss(loss, self.optimizer) as scaled_loss:
            scaled_loss.backward()
    elif self.deepspeed:
        # loss gets scaled under gradient_accumulation_steps in deepspeed
        loss = self.deepspeed.backward(loss)
    else:
        loss.backward()

    return loss.detach()
+ """ + + def prefix_allowed_tokens_fn(batch_id, sent): + # print(self.tokenizer.convert_ids_to_tokens(inputs['labels'][batch_id])) + src_sentence = inputs['input_ids'][batch_id] + return self.constraint_decoder.constraint_decoding(src_sentence=src_sentence, + tgt_generated=sent) + + if not self.args.predict_with_generate or prediction_loss_only: + return super().prediction_step( + model=model, + inputs=inputs, + prediction_loss_only=prediction_loss_only, + ignore_keys=ignore_keys, + ) + has_labels = "labels" in inputs + inputs = self._prepare_inputs(inputs) + + # remove pretrain field + if "record_input_ids" in inputs.keys(): + inputs.pop("record_input_ids") + if "mlm_input_ids" in inputs.keys(): + inputs.pop("mlm_input_ids") + if "mlm_target_ids" in inputs.keys(): + inputs.pop("mlm_target_ids") + if "mlm_decoder_input_ids" in inputs.keys(): + inputs.pop("mlm_decoder_input_ids") + if "noised_input_ids" in inputs.keys(): + inputs.pop("noised_input_ids") + if "noised_att_mask" in inputs.keys(): + inputs.pop("noised_att_mask") + + gen_kwargs = { + "max_length": self._max_length if self._max_length is not None else self.model.config.max_length, + "num_beams": self._num_beams if self._num_beams is not None else self.model.config.num_beams, + "prefix_allowed_tokens_fn": prefix_allowed_tokens_fn if self.constraint_decoder else None, + } + + generated_tokens = self.model.generate( + inputs["input_ids"], + attention_mask=inputs["attention_mask"], + **gen_kwargs, + ) + + # in case the batch is shorter than max length, the output should be padded + if generated_tokens.shape[-1] < gen_kwargs["max_length"]: + generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_length"]) + + with torch.no_grad(): + if self.use_amp: + with autocast(): + outputs = model(**inputs) + else: + outputs = model(**inputs) + if has_labels: + if self.label_smoother is not None: + loss = self.label_smoother(outputs, inputs["labels"]).mean().detach() + else: + loss = 
(outputs["loss"] if isinstance(outputs, dict) else outputs[0]).mean().detach() + else: + loss = None + + if self.args.prediction_loss_only: + return loss, None, None + + labels = inputs["labels"] + if labels.shape[-1] < gen_kwargs["max_length"]: + labels = self._pad_tensors_to_max_len(labels, gen_kwargs["max_length"]) + + return loss, generated_tokens, labels + +class MetaPretrainConstraintSeq2SeqTrainer(OriginalConstraintSeq2SeqTrainer): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.model = l2l.algorithms.MAML(self.model, lr=1e-4, first_order=True) + self.model_wrapped = self.model + + def _save(self, output_dir: Optional[str] = None, state_dict=None): + output_dir = output_dir if output_dir is not None else self.args.output_dir + os.makedirs(output_dir, exist_ok=True) + logger.info(f"Saving model checkpoint to {output_dir}") + if not isinstance(self.model, PreTrainedModel): + if isinstance(unwrap_model(self.model), PreTrainedModel): + if state_dict is None: + state_dict = unwrap_model(self.model).state_dict() + unwrap_model(self.model).save_pretrained(output_dir, state_dict=state_dict) + else: + logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.") + if state_dict is None: + state_dict = self.model.state_dict() + torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME)) + else: + self.model.save_pretrained(output_dir, state_dict=state_dict) + if self.tokenizer is not None: + self.tokenizer.save_pretrained(output_dir) + torch.save(self.args, os.path.join(output_dir, "training_args.bin")) + + def cat_two_various_length_tensor(self, x, y, max_length=None): + batch_size, seq_len = x.shape + + def create_index(x, y, batch_size): + num = torch.bincount(((x!=0)&(x!=1)).nonzero()[:,0], minlength=batch_size) + + pad_ones = torch.ones(y.shape[0], y.shape[1]-1, dtype=torch.long, device=y.device) + + update_index = torch.cat([num.unsqueeze(-1), pad_ones], dim=-1) + index = torch.cumsum(update_index, 
def split_input(self, inputs):
    """Split a batch into interleaved support and query halves.

    Rows at even positions along the first dimension (0, 2, 4, ...) go to the
    support set and rows at odd positions (1, 3, 5, ...) to the query set, so
    with an odd batch size the support set is one row larger.

    :param inputs: dict mapping field names to tensors that share their
        first dimension.
    :return: ``(support, query)`` dicts with the same keys as ``inputs`` and
        contiguous tensors as values.
    """
    support = {key: value[::2].contiguous() for key, value in inputs.items()}
    query = {key: value[1::2].contiguous() for key, value in inputs.items()}
    return support, query
+ mlm_decoder_input_ids = inputs.pop("mlm_decoder_input_ids") + mlm_inputs = { + "input_ids": mlm_input_ids, + "labels": mlm_target_ids, + "decoder_input_ids": mlm_decoder_input_ids, + } + + # record loss + record_loss = get_loss(model, record_inputs) + + # mlm loss + mlm_loss = get_loss(model, mlm_inputs) + + ######################################################### + # meta loss + ######################################################### + + support_inputs, query_inputs = self.split_input(inputs) + + # refine inputs + support_refine_input_ids = support_inputs.pop("noised_input_ids") + support_refine_att_mask = support_inputs.pop("noised_att_mask") + support_refine_inputs = { + "input_ids": support_refine_input_ids, + "attention_mask": support_refine_att_mask, + "labels": support_inputs["labels"], + "decoder_input_ids": support_inputs["decoder_input_ids"] + } + + query_refine_input_ids = query_inputs.pop("noised_input_ids") + query_refine_att_mask = query_inputs.pop("noised_att_mask") + query_refine_inputs = { + "input_ids": query_refine_input_ids, + "attention_mask": query_refine_att_mask, + "labels": query_inputs["labels"], + "decoder_input_ids": query_inputs["decoder_input_ids"] + } + + # inner loss + support_text2struct_loss = get_loss(model, support_inputs) + support_refine_loss = get_loss(model, support_refine_inputs) + support_loss = support_text2struct_loss + support_refine_loss + + model.adapt(support_loss) + + # outer loss + # support_text2struct_loss = get_loss(model, support_inputs) + # support_refine_loss = get_loss(model, support_refine_inputs) + # support_loss = support_text2struct_loss + support_refine_loss + + query_text2struct_loss = get_loss(model, query_inputs) + query_refine_loss = get_loss(model, query_refine_inputs) + query_loss = query_text2struct_loss + query_refine_loss + + # meta_loss = (query_loss + support_loss) + meta_loss = query_loss + + loss = (meta_loss + record_loss + mlm_loss) / 4 + + if self.use_amp: + 
def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
    """Run one meta-learning step and return the loss from ``meta_learn``.

    A ``RuntimeError`` whose message contains ``'out of memory'`` is logged
    together with the size of every input, counted in ``self.oom_batch`` and
    re-raised as a ``RuntimeError`` carrying the same message. Any other
    ``RuntimeError`` propagates unchanged.

    :param model: model handed through to ``self.meta_learn``.
    :param inputs: batch handed through to ``self.meta_learn``.
    :return: whatever ``self.meta_learn`` returns.
    :raises RuntimeError: on out-of-memory (chained to the original error) or
        any other runtime failure of the step.
    """
    try:
        return self.meta_learn(model, inputs)
    except RuntimeError as e:
        if 'out of memory' not in str(e):
            raise

        logger.warning(f'ran out of memory {self.oom_batch} on {self.args.local_rank}')
        for k, v in inputs.items():
            # Batch values are not guaranteed to be tensors; a failing
            # diagnostic must not replace the out-of-memory error itself.
            size_fn = getattr(v, 'size', None)
            print(k, size_fn() if callable(size_fn) else type(v))

        self.oom_batch += 1
        # Chain the original exception so its traceback is not lost.
        raise RuntimeError(str(e)) from e
def prediction_step(
    self,
    model: nn.Module,
    inputs: Dict[str, Union[torch.Tensor, Any]],
    prediction_loss_only: bool,
    ignore_keys: Optional[List[str]] = None,
) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
    """Predict, then predict again with the first prediction appended to the input.

    The first call to ``prediction_in_one_step`` produces generated tokens.
    They are written into ``inputs["input_ids"]`` via
    ``cat_two_various_length_tensor`` (``inputs`` is updated in place, together
    with ``inputs["attention_mask"]``) and the step is run a second time; the
    result of that second pass is returned.

    When no generated token ids are available (loss-only evaluation, or
    ``predict_with_generate`` disabled), there is nothing to append, so the
    first-pass result is returned as is instead of failing in the
    concatenation.

    :return: ``(loss, generated_tokens, labels)`` of the final pass.
    """
    first_pass = self.prediction_in_one_step(model, inputs, prediction_loss_only, ignore_keys)
    generated_tokens = first_pass[1]

    # Without generation the second element is None or model logits, neither
    # of which can be scattered into the input ids.
    if prediction_loss_only or not self.args.predict_with_generate or generated_tokens is None:
        return first_pass

    new_input_ids, new_att_mask = self.cat_two_various_length_tensor(inputs["input_ids"], generated_tokens)
    inputs["input_ids"] = new_input_ids
    inputs["attention_mask"] = new_att_mask

    return self.prediction_in_one_step(model, inputs, prediction_loss_only, ignore_keys)
def cat_two_various_length_tensor(self, x, y, max_length=None, pad_token_id=0, eos_token_id=1):
    """Write each row of ``y`` into ``x`` directly after that row's content tokens.

    For every row, the number of tokens in ``x`` that are neither
    ``pad_token_id`` nor ``eos_token_id`` is counted, and ``y`` is scattered
    into a padded copy of ``x`` starting at that offset. Anything that was
    stored there before (e.g. the end-of-sequence token) is overwritten.

    NOTE(review): the offset is a plain count, so this presumably expects the
    content tokens of each row to form a contiguous prefix — confirm with the
    collator that produces ``x``.

    :param x: 2-D integer tensor ``(batch, seq_len)``.
    :param y: 2-D integer tensor ``(batch, y_len)`` of the same dtype/device.
    :param max_length: number of columns to keep; defaults to ``seq_len``,
        in which case tokens of ``y`` landing beyond ``seq_len`` are dropped.
    :param pad_token_id: id treated as padding (previously hard-coded to 0).
    :param eos_token_id: id treated as end-of-sequence (previously hard-coded to 1).
    :return: ``(xy, att_mask)`` where ``att_mask`` is a float tensor that is
        1.0 wherever ``xy`` differs from ``pad_token_id``.
    """
    _, seq_len = x.shape

    # Per-row count of real tokens == column where the first token of y goes.
    content_len = ((x != pad_token_id) & (x != eos_token_id)).sum(dim=1, keepdim=True)
    offsets = torch.arange(y.shape[1], dtype=torch.long, device=y.device).unsqueeze(0)
    index = content_len + offsets

    # Extend x by y's width so the scatter can never index out of bounds.
    xy = torch.cat([x, torch.full_like(y, pad_token_id)], dim=-1)
    xy.scatter_(1, index, y)

    xy = xy[:, :max_length] if max_length is not None else xy[:, :seq_len]
    att_mask = (xy != pad_token_id).float()

    return xy, att_mask
def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
    """Run one fine-tuning step and return the loss from ``finetune``.

    A ``RuntimeError`` whose message contains ``'out of memory'`` is logged
    together with the size of every input, counted in ``self.oom_batch`` and
    re-raised as a ``RuntimeError`` carrying the same message. Any other
    ``RuntimeError`` propagates unchanged.

    :param model: model handed through to ``self.finetune``.
    :param inputs: batch handed through to ``self.finetune``.
    :return: whatever ``self.finetune`` returns.
    :raises RuntimeError: on out-of-memory (chained to the original error) or
        any other runtime failure of the step.
    """
    try:
        return self.finetune(model, inputs)
    except RuntimeError as e:
        if 'out of memory' not in str(e):
            raise

        logger.warning(f'ran out of memory {self.oom_batch} on {self.args.local_rank}')
        for k, v in inputs.items():
            # Batch values are not guaranteed to be tensors; a failing
            # diagnostic must not replace the out-of-memory error itself.
            size_fn = getattr(v, 'size', None)
            print(k, size_fn() if callable(size_fn) else type(v))

        self.oom_batch += 1
        # Chain the original exception so its traceback is not lost.
        raise RuntimeError(str(e)) from e
"record_input_ids" in inputs.keys(): + inputs.pop("record_input_ids") + if "mlm_input_ids" in inputs.keys(): + inputs.pop("mlm_input_ids") + if "mlm_target_ids" in inputs.keys(): + inputs.pop("mlm_target_ids") + if "mlm_decoder_input_ids" in inputs.keys(): + inputs.pop("mlm_decoder_input_ids") + if "noised_input_ids" in inputs.keys(): + inputs.pop("noised_input_ids") + if "noised_att_mask" in inputs.keys(): + inputs.pop("noised_att_mask") + + gen_kwargs = { + "max_length": self._max_length if hasattr(self, "_max_length") and self._max_length is not None else self.model.config.max_length, + "num_beams": self._num_beams if hasattr(self, "_num_beams") and self._num_beams is not None else self.model.config.num_beams, + "prefix_allowed_tokens_fn": prefix_allowed_tokens_fn if self.constraint_decoder else None, + } + + generated_tokens = self.model.generate( + inputs["input_ids"], + attention_mask=inputs["attention_mask"], + **gen_kwargs, + ) + + # in case the batch is shorter than max length, the output should be padded + if generated_tokens.shape[-1] < gen_kwargs["max_length"]: + generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_length"]) + + with torch.no_grad(): + if self.use_amp: + with autocast(): + outputs = model(**inputs) + else: + outputs = model(**inputs) + if has_labels: + if self.label_smoother is not None: + loss = self.label_smoother(outputs, inputs["labels"]).mean().detach() + else: + loss = (outputs["loss"] if isinstance(outputs, dict) else outputs[0]).mean().detach() + else: + loss = None + + if self.args.prediction_loss_only: + return loss, None, None + + labels = inputs["labels"] + if labels.shape[-1] < gen_kwargs["max_length"]: + labels = self._pad_tensors_to_max_len(labels, gen_kwargs["max_length"]) + + return loss, generated_tokens, labels + + def prediction_step( + self, + model: nn.Module, + inputs: Dict[str, Union[torch.Tensor, Any]], + prediction_loss_only: bool, + ignore_keys: Optional[List[str]] = None, + ) -> 
Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: + inputs = self._prepare_inputs(inputs) + + # store original input_ids + original_input_ids = inputs["input_ids"] + original_att_mask = inputs["attention_mask"] + + # first prediction + loss, generated_tokens, labels = self.prediction_in_one_step(model, inputs, prediction_loss_only, ignore_keys) + + # refine + input_ids_with_refine_prompt, att_mask_with_refine_prompt = self.cat_prompt(original_input_ids) + new_input_ids, new_att_mask = self.cat_two_various_length_tensor(input_ids_with_refine_prompt, generated_tokens) + # new_input_ids, new_att_mask = self.cat_two_various_length_tensor(original_input_ids, generated_tokens) + inputs["input_ids"] = new_input_ids + inputs["attention_mask"] = new_att_mask + + loss, generated_tokens, labels = self.prediction_in_one_step(model, inputs, prediction_loss_only, ignore_keys) + + return loss, generated_tokens, labels + +########################################################################### + +ConstraintSeq2SeqTrainer = OriginalConstraintSeq2SeqTrainer + +def main(): pass + + +if __name__ == "__main__": + main() diff --git a/metaretriever/uie/seq2seq/constraint_decoder/__init__.py b/metaretriever/uie/seq2seq/constraint_decoder/__init__.py new file mode 100644 index 00000000..39a8fa58 --- /dev/null +++ b/metaretriever/uie/seq2seq/constraint_decoder/__init__.py @@ -0,0 +1,29 @@ +#!/usr/bin/env python +# -*- coding:utf-8 -*- +from uie.seq2seq.constraint_decoder.spotasoc_constraint_decoder import ( + SpotAsocConstraintDecoder, + SpotConstraintDecoder +) + + +def get_constraint_decoder(tokenizer, type_schema, decoding_schema, task_name='event', source_prefix=None): + if decoding_schema == 'spotasoc': + if len(type_schema.role_list) == 0: + task_map = { + 'entity': SpotConstraintDecoder, + 'relation': SpotConstraintDecoder, + 'event': SpotConstraintDecoder, + 'record': SpotConstraintDecoder, + } + else: + task_map = { + 'entity': SpotAsocConstraintDecoder, + 
def generated_search_prefix_tree(generated, prefix_tree, tokenizer):
    """Return the tokens that may follow ``generated`` according to ``prefix_tree``.

    ``prefix_tree`` is walked one generated token at a time. If a token is
    not a key of the current node, the walk stops and only the tokenizer's
    end-of-sequence id is allowed. Otherwise the keys of the node reached
    after the last token are returned.

    :param generated: sequence of already generated token ids.
    :param prefix_tree: nested mapping ``token -> subtree``.
    :param tokenizer: object exposing ``eos_token_id``.
    :return: list of allowed next token ids.
    """
    tree = prefix_tree
    # Leaf is KEY_VALUE_SPLIT
    for token in generated:
        if token not in tree:
            # The other constraint helpers return token *ids*; returning the
            # eos token string here would mix strings into an id list.
            return [tokenizer.eos_token_id]
        tree = tree[token]

    return list(tree)
class ConstraintDecoder:
    """Base class for decoders that restrict which token may be generated next.

    Subclasses implement :meth:`get_state_valid_tokens`;
    :meth:`constraint_decoding` is the entry point that strips the tokenized
    source prefix and converts tensors to lists before delegating to it.
    """

    def __init__(self, tokenizer, source_prefix):
        """
        :param tokenizer: tokenizer providing ``encode``.
        :param source_prefix: text prepended to every source sentence, or a
            falsy value when no prefix is used.
        """
        self.tokenizer = tokenizer
        self.source_prefix = source_prefix
        # Tokenized once so the prefix length is known when stripping it later.
        self.source_prefix_tokenized = tokenizer.encode(source_prefix,
                                                        add_special_tokens=False) if source_prefix else []

    def get_state_valid_tokens(self, src_sentence: List[int], tgt_generated: List[int]) -> List[int]:
        """Return the token ids allowed after ``tgt_generated``.

        Must be overridden; the base implementation raises instead of
        silently returning ``None`` to the generation loop.

        :raises NotImplementedError: always, in the base class.
        """
        raise NotImplementedError(
            '%s must implement get_state_valid_tokens' % self.__class__.__name__
        )

    def constraint_decoding(self, src_sentence, tgt_generated):
        """Compute the valid next token ids for one sequence.

        :param src_sentence: source token ids; must support slicing and ``tolist()``.
        :param tgt_generated: generated token ids; must support ``tolist()``.
        :return: the result of :meth:`get_state_valid_tokens`.
        """
        if self.source_prefix_tokenized:
            # Remove Source Prefix for Generation
            src_sentence = src_sentence[len(self.source_prefix_tokenized):]

        return self.get_state_valid_tokens(src_sentence.tolist(), tgt_generated.tolist())
def check_state(self, tgt_generated):
    """Classify the decoding state reached after ``tgt_generated``.

    The state is derived from the special tokens ``type_start``,
    ``type_end`` and ``span_start`` found in the generated ids, mainly from
    how many ``type_start`` tokens are not yet matched by a ``type_end``:

    * 0 unmatched: ``'end_generate'``
    * 1 unmatched: ``'start_first_generation'``
    * 2 unmatched: ``'generate_trigger'`` (``'generate_trigger_text'`` if the
      last special token is ``span_start``)
    * 3 unmatched: ``'generate_role'`` (``'generate_role_text'`` if the last
      special token is ``span_start``)
    * anything else: ``'error'``

    :param tgt_generated: list of generated token ids.
    :return: ``(state, index)``. ``index`` is the position of the last
        special token, except that ``'start'`` and ``'end_generate'`` return
        ``-1`` and the early ``'error'`` exits return ``0``.
    """
    # An empty prefix is treated like one that ends in the pad token:
    # nothing has been generated yet.
    if len(tgt_generated) == 0 or tgt_generated[-1] == self.tokenizer.pad_token_id:
        return 'start', -1

    # special_token_set = {EVENT_TYPE_LEFT, EVENT_TYPE_RIGHT}
    special_token_set = {self.type_start, self.type_end, self.span_start}
    special_index_token = [
        (position, token) for position, token in enumerate(tgt_generated) if token in special_token_set
    ]

    # A non-empty prefix without any special token cannot be produced by the
    # grammar; report it as an error instead of failing on [-1] below.
    if not special_index_token:
        return 'error', 0

    last_special_index, last_special_token = special_index_token[-1]

    if len(special_index_token) == 1 and last_special_token != self.type_start:
        return 'error', 0

    start_number = sum(1 for _, token in special_index_token if token == self.type_start)
    end_number = sum(1 for _, token in special_index_token if token == self.type_end)

    if start_number == end_number:
        return 'end_generate', -1

    if start_number == end_number + 1:
        state = 'start_first_generation'
    elif start_number == end_number + 2:
        state = 'generate_trigger'
        if last_special_token == self.span_start:
            state = 'generate_trigger_text'
    elif start_number == end_number + 3:
        state = 'generate_role'
        if last_special_token == self.span_start:
            state = 'generate_role_text'
    else:
        state = 'error'
    return state, last_special_index
tgt_generated[-1] == self.type_start: + # Start Event Label + return list(self.type_tree.keys()) + + elif tgt_generated[-1] == self.type_end: + # EVENT_TYPE_LEFT: Start a new role + # EVENT_TYPE_RIGHT: End this event + return [self.type_start, self.type_end] + else: + valid_tokens = self.search_prefix_tree( + generated=tgt_generated[index + 1:], + prefix_tree=self.type_tree, + end_search_tokens=[self.span_start] + ) + + elif state in {'generate_trigger_text'}: + generated = tgt_generated[index + 1:] + + if len(generated) > 0 and generated[-1] == self.null_span: + return [self.type_end, self.type_start] + + valid_tokens = generated_search_src_sequence( + generated=generated, + src_sequence=src_sentence + [self.null_span], + end_sequence_search_tokens=[self.type_end, self.type_start], + ) + + elif state in {'generate_role_text'}: + generated = tgt_generated[index + 1:] + + if len(generated) > 0 and generated[-1] == self.null_span: + return [self.type_end] + + valid_tokens = generated_search_src_sequence( + generated=generated, + src_sequence=src_sentence + [self.null_span], + end_sequence_search_tokens=[self.type_end], + ) + + elif state == 'generate_role': + + if tgt_generated[-1] == self.type_start: + # Start Role Label + return list(self.role_tree.keys()) + + generated = tgt_generated[index + 1:] + valid_tokens = self.search_prefix_tree( + generated=generated, + prefix_tree=self.role_tree, + end_search_tokens=[self.span_start] + ) + + elif state == 'end_generate': + valid_tokens = [self.tokenizer.eos_token_id] + + else: + raise NotImplementedError('State `%s` for %s is not implemented.' 
def search_prefix_tree(self, generated: List[int], prefix_tree: Dict,
                       end_search_tokens: List[int] = None):
    """Return the token ids that may follow ``generated`` inside ``prefix_tree``.

    The tree is walked along ``generated``. As soon as a node is reached
    whose only key is ``self.tree_end``, the name is complete and only
    ``end_search_tokens`` are allowed. Otherwise the keys of the final node
    are returned; if ``self.tree_end`` is among them it is replaced by
    ``end_search_tokens`` (appended at the end).

    :param generated: token ids generated since the last special token.
    :param prefix_tree: nested mapping ``token id -> subtree``.
    :param end_search_tokens: ids allowed once a complete name has been
        generated; ``None`` is treated as "no such tokens".
    :return: a new list of valid next token ids.
    :raises KeyError: if ``generated`` leaves the tree.
    """
    # Copy so callers never share (and accidentally mutate) the argument;
    # also makes the None default usable.
    end_tokens = list(end_search_tokens) if end_search_tokens else []

    tree = prefix_tree
    for token in generated:
        tree = tree[token]
        is_tree_end = len(tree) == 1 and self.tree_end in tree

        if is_tree_end:
            return end_tokens

    valid_token = [token for token in tree if token != self.tree_end]
    if self.tree_end in tree:
        valid_token += end_tokens
    return valid_token
tgt_generated): + """ + + :param src_sentence: + :param tgt_generated: + :return: + List[str], valid token list + """ + if self.tokenizer.eos_token_id in src_sentence: + src_sentence = src_sentence[:src_sentence.index(self.tokenizer.eos_token_id)] + + if self.text_start in src_sentence: + src_sentence = src_sentence[src_sentence.index(self.text_start) + 1:] + + state, index = self.check_state(tgt_generated) + + print("State: %s" % state) if debug else None + + if state == 'error': + print("Decode Error:") + print("Src:", self.tokenizer.convert_ids_to_tokens(src_sentence)) + print("Tgt:", self.tokenizer.convert_ids_to_tokens(tgt_generated)) + valid_tokens = [self.tokenizer.eos_token_id] + + elif state == 'start': + valid_tokens = [self.type_start] + + elif state == 'start_first_generation': + valid_tokens = [self.type_start, self.type_end] + + elif state == 'generate_span': + + if tgt_generated[-1] == self.type_start: + # Start Event Label + return list(self.type_tree.keys()) + + elif tgt_generated[-1] == self.type_end: + raise RuntimeError('Invalid %s in %s' % (self.type_end, tgt_generated)) + + else: + valid_tokens = self.search_prefix_tree( + generated=tgt_generated[index + 1:], + prefix_tree=self.type_tree, + end_search_tokens=[self.span_start] + ) + + elif state == 'generate_span_text': + generated = tgt_generated[index + 1:] + valid_tokens = generated_search_src_sequence( + generated=generated, + src_sequence=src_sentence + [self.null_span], + end_sequence_search_tokens=[self.type_end], + ) + + elif state == 'end_generate': + valid_tokens = [self.tokenizer.eos_token_id] + + else: + raise NotImplementedError('State `%s` for %s is not implemented.' 
% (state, self.__class__)) + + print("Valid: %s" % valid_tokens) if debug else None + return valid_tokens diff --git a/metaretriever/uie/seq2seq/data_collator/__init__.py b/metaretriever/uie/seq2seq/data_collator/__init__.py new file mode 100644 index 00000000..a1578439 --- /dev/null +++ b/metaretriever/uie/seq2seq/data_collator/__init__.py @@ -0,0 +1,13 @@ +#!/usr/bin/env python +# -*- coding:utf-8 -*- + + +from uie.seq2seq.data_collator.meta_data_collator import ( + DataCollatorForMetaSeq2Seq, + DynamicSSIGenerator, +) + + +__all__ = [ + 'DataCollatorForMetaSeq2Seq', 'DynamicSSIGenerator' +] diff --git a/metaretriever/uie/seq2seq/data_collator/meta_data_collator.py b/metaretriever/uie/seq2seq/data_collator/meta_data_collator.py new file mode 100644 index 00000000..8388d7f4 --- /dev/null +++ b/metaretriever/uie/seq2seq/data_collator/meta_data_collator.py @@ -0,0 +1,335 @@ +#!/usr/bin/env python +# -*- coding:utf-8 -*- +import os +from dataclasses import dataclass +import torch +import logging +import random +import math +from typing import Optional, Union +from collections import OrderedDict +from copy import deepcopy + +from transformers import PreTrainedTokenizerBase, PreTrainedModel +from transformers.file_utils import PaddingStrategy + +from uie.extraction.record_schema import RecordSchema +from uie.extraction.dataset_processer import spot_prompt, asoc_prompt +from uie.extraction.constants import BaseStructureMarker, text_start +from uie.extraction.utils import convert_to_record_function +from uie.extraction.noiser.spot_asoc_noiser import SpotAsocNoiser + +import pdb + +logger = logging.getLogger("__main__") + + +class DynamicSSIGenerator(): + """ + Sample negative spot and asoc to construct SSI + """ + def __init__(self, tokenizer: PreTrainedTokenizerBase, schema: RecordSchema, positive_rate=1, negative=5, ordered_prompt=False) -> None: + self.spot_dict = self.get_ordered_dict(schema.type_list, tokenizer) + self.asoc_dict = 
@staticmethod
def sample_negative(postive, candidates, k=5):
    """Draw up to ``k`` candidates that are not in ``postive``.

    ``k`` positions are taken from a random permutation of ``candidates``
    (all positions when ``k`` is negative) and every drawn candidate that
    also appears in ``postive`` is discarded, so fewer than ``k`` items may
    be returned.

    The result keeps the order of the random draw. It used to be built from
    a ``set``, whose iteration order for strings depends on the interpreter's
    hash seed, which made the output order differ between runs even with a
    fixed torch seed.

    :param postive: collection of hashable items to exclude (parameter name
        kept as is because callers pass it by keyword).
    :param candidates: indexable sequence of hashable items to draw from.
    :param k: number of positions to draw; negative means all.
    :return: list of distinct negatives in draw order.
    """
    if k < 0:
        k = len(candidates)

    excluded = set(postive)
    negatives = []
    for index in torch.randperm(len(candidates))[:k].tolist():
        negative = candidates[index]
        if negative not in excluded:
            # Adding to `excluded` also removes duplicates among candidates.
            excluded.add(negative)
            negatives.append(negative)
    return negatives
mapper=self.asoc_dict, + ordered_prompt=self.ordered_prompt, + ) + return converted_asoc_prefix, negative_asoc + + def full_spot(self, shuffle=False): + # Random Prompt + Shuffle + if not self.ordered_prompt and shuffle: + ordered_prompt = False + else: + ordered_prompt = True + return self.convert_prefix( + candidates=self.spot_list, + prompt=self.spot_prompt, + mapper=self.spot_dict, + ordered_prompt=ordered_prompt, + ) + + def full_asoc(self, shuffle=False): + # Random Prompt + Shuffle + if not self.ordered_prompt and shuffle: + ordered_prompt = False + else: + ordered_prompt = True + return self.convert_prefix( + candidates=self.asoc_list, + prompt=self.asoc_prompt, + mapper=self.asoc_dict, + ordered_prompt=ordered_prompt, + ) + + @staticmethod + def convert_prefix(candidates, prompt, mapper, ordered_prompt=True): + prefix = list() + + if ordered_prompt: + candidate_sorted = sorted([(candidate, index) for index, candidate in enumerate(candidates)]) + index_list = [index for _, index in candidate_sorted] + else: + index_list = torch.randperm(len(candidates)).tolist() + + for index in index_list: + prefix += [prompt] + prefix += mapper[candidates[index]] + return prefix + + +@dataclass +class DataCollatorForMetaSeq2Seq: + """ + Data collator that will dynamically pad the inputs received, as well as the labels. + + Args: + tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`): + The tokenizer used for encoding the data. + model (:class:`~transformers.PreTrainedModel`): + The model that is being trained. If set and has the `prepare_decoder_input_ids_from_labels`, use it to + prepare the `decoder_input_ids` + + This is useful when using `label_smoothing` to avoid calculating loss twice. 
+ padding (:obj:`bool`, :obj:`str` or :class:`~transformers.file_utils.PaddingStrategy`, `optional`, defaults to :obj:`True`): + Select a strategy to pad the returned sequences (according to the model's padding side and padding index) + among: + + * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single + sequence is provided). + * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the + maximum acceptable input length for the model if that argument is not provided. + * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of + different lengths). + max_length (:obj:`int`, `optional`): + Maximum length of the returned list and optionally padding length (see above). + max_target_length (:obj:`int`, `optional`): + Maximum length of target sequence length. + pad_to_multiple_of (:obj:`int`, `optional`): + If set will pad the sequence to a multiple of the provided value. + + This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >= + 7.5 (Volta). + label_pad_token_id (:obj:`int`, `optional`, defaults to -100): + The id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions). 
+ """ + + tokenizer: PreTrainedTokenizerBase + negative_sampler: DynamicSSIGenerator + model: Optional[PreTrainedModel] = None + padding: Union[bool, str, PaddingStrategy] = True + max_length: Optional[int] = None + max_target_length: Optional[int] = None + max_prefix_length: Optional[int] = None + pad_to_multiple_of: Optional[int] = None + label_pad_token_id: int = -100 + spot_asoc_nosier: SpotAsocNoiser = None + decoding_format: str = 'spotasoc' + + def __call__(self, features): + """ Make Meta Schema Batch + + Args: + features (Dict): [description] + - sample_prompt: indicates sample_prompt example, need pop after call + - spots (List[str]): List of spots in this sentence, need pop after call + - asocs (List[str]): List of asocs in this sentence, need pop after call + - input_ids + - attention_mask + - labels + + Returns: + """ + for feature in features: + + sample_prompt = feature['sample_prompt'] + + if not sample_prompt: + # Evaluation using Ordered SSI + converted_spot_prefix = self.negative_sampler.full_spot(shuffle=self.model.training) + converted_asoc_prefix = self.negative_sampler.full_asoc(shuffle=self.model.training) + else: + # Sample SSI + converted_spot_prefix, positive_spot, negative_spot = self.negative_sampler.sample_spot(positive=feature.get('spots', [])) + converted_asoc_prefix, negative_asoc = self.negative_sampler.sample_asoc(positive=feature.get('asocs', [])) + + # Dynamic generating spot-asoc during training + if 'spot_asoc' in feature: + + # Deleted positive example Spot in Target that was not sampled by Prefix + feature['spot_asoc'] = [spot_asoc for spot_asoc in feature['spot_asoc'] if spot_asoc["label"] in positive_spot] + + # Inject rejection noise + if self.spot_asoc_nosier is not None: + if isinstance(self.spot_asoc_nosier, SpotAsocNoiser): + feature['spot_asoc'] = self.spot_asoc_nosier.add_noise( + feature['spot_asoc'], + spot_label_list=negative_spot, + asoc_label_list=negative_asoc, + ) + else: + raise 
NotImplementedError(f'{self.spot_asoc_nosier} is not implemented.') + + # Generate new record + record = convert_to_record_function[self.decoding_format]( + feature['spot_asoc'], + structure_maker=BaseStructureMarker() + ) + feature["labels"] = self.tokenizer.encode(record) + + feature.pop('sample_prompt') if 'sample_prompt' in feature else None + feature.pop('spot_asoc') if 'spot_asoc' in feature else None + feature.pop('spots') if 'spots' in feature else None + feature.pop('asocs') if 'asocs' in feature else None + + # record input ids + feature['record_input_ids'] = [1] + + # mlm input ids and target ids + mlm_input_ids, mlm_target_ids = self.generate_target_ids([feature["input_ids"]]) + feature["mlm_input_ids"] = mlm_input_ids[0] + [0] * (self.max_length - len(mlm_input_ids[0])) + feature["mlm_target_ids"] = mlm_target_ids[0] + [1] + [-100] * (self.max_length - len(mlm_target_ids[0]) - 1) + + # noised record inputs + original_text = feature.pop("text") + noised_records = feature.pop("noised_record") + if noised_records is not None: + noised_record = random.choice(noised_records) + noised_input_text = original_text + " Predicted results: " + noised_record + noised_inputs = self.tokenizer(noised_input_text, max_length=self.max_length, padding="max_length", truncation=True) + feature["noised_input_ids"] = noised_inputs["input_ids"] + feature["noised_att_mask"] = noised_inputs["attention_mask"] + + # add prefix + prefix = converted_spot_prefix + converted_asoc_prefix + # truncate `prefix` to max length + if self.max_prefix_length is not None and self.max_prefix_length >= 0: + prefix = prefix[:self.max_prefix_length] + + feature['input_ids'] = prefix + [self.negative_sampler.text_start] + feature['input_ids'] + + # truncate `input_ids` to max length + if self.max_length: + feature['input_ids'] = feature['input_ids'][:self.max_length] + if self.max_target_length and 'labels' in feature: + feature['labels'] = feature['labels'][:self.max_target_length] + + 
def generate_target_ids(self, input_ids, mask_prob=0.15):
    """Build span-corruption (denoising) inputs and targets from token id lists.

    For every sequence, ``int(mask_prob * len(sequence))`` positions (capped at
    one less than the number of sentinel tokens) are sampled with ``random``.
    Each run of adjacent sampled positions is replaced in the input by one
    sentinel token, and the removed ids are appended to the target.

    Args:
        input_ids: list of token id lists; the lists are not modified.
        mask_prob: fraction of positions to mask in each sequence.

    Returns:
        ``(masked_input_ids, target_ids)``, two lists parallel to ``input_ids``.
    """
    # Sentinel tokens of the T5 vocabulary. The literal had degraded to f"",
    # which made all 100 "sentinels" the same empty token.
    extra_tokens = [f"<extra_id_{i}>" for i in range(0, 100)]
    mask_tokens = self.tokenizer.convert_tokens_to_ids(extra_tokens)

    masked_input_ids = []
    target_ids = []
    for _input_ids in input_ids:  # let's calculate masks for denoising pretraining
        _input_sent_embed = list(_input_ids)  # work on a copy, the caller's list stays intact
        _target_sent_embed = []
        masked_indexes = set(random.sample(
            range(0, len(_input_sent_embed)),  # sample token positions in the sentence
            min(int(mask_prob * len(_input_sent_embed)),  # number of tokens masked
                len(mask_tokens) - 1)))  # but never more than special tokens available
        mask = [(i in masked_indexes) for i in range(len(_input_sent_embed))]
        i = 0
        end = len(_input_sent_embed)
        masked_spans_counter = 0
        while i < end:
            if mask[i]:
                # Start of a masked span: put the next sentinel into the input ...
                current_words_masked = [_input_sent_embed[i]]
                _input_sent_embed[i] = mask_tokens[masked_spans_counter]
                masked_spans_counter += 1
                # ... and swallow every directly following masked position.
                while i + 1 < end and mask[i + 1]:
                    current_words_masked.append(_input_sent_embed[i + 1])
                    del _input_sent_embed[i + 1]
                    del mask[i + 1]
                    end -= 1
                _target_sent_embed.extend(current_words_masked)
            else:
                # Unmasked position: emit the sentinel of the upcoming span once.
                # NOTE(review): a span that starts at position 0 is therefore written
                # to the target without a leading sentinel; confirm this is intended.
                if len(_target_sent_embed) == 0 or _target_sent_embed[-1] != mask_tokens[masked_spans_counter]:
                    _target_sent_embed.append(mask_tokens[masked_spans_counter])
            i += 1
        masked_input_ids.append(_input_sent_embed)
        target_ids.append(_target_sent_embed)
    return masked_input_ids, target_ids
id=None), length=-1, id=None), + 'text': Value(dtype='string', id=None)}]}], + 'spot': Sequence(feature=Value(dtype='string', id=None), length=-1, id=None), + 'asoc': Sequence(feature=Value(dtype='string', id=None), length=-1, id=None), + 'spot_asoc': [{'span': Value(dtype='string', id=None), + 'label': Value(dtype='string', id=None), + 'asoc': Sequence(feature=Sequence(feature=Value(dtype='string', id=None), length=-1, id=None), length=-1, id=None)}], + 'task': Value(dtype='string', id=None), +}) + + +_processed_feature = { + 'input_ids': Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None), + 'attention_mask': Sequence(feature=Value(dtype='int8', id=None), length=-1, id=None), + 'labels': Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None), + 'spots': Sequence(feature=Value(dtype='string', id=None), length=-1, id=None), + 'asocs': Sequence(feature=Value(dtype='string', id=None), length=-1, id=None), + 'spot_asoc': [ + {'span': Value(dtype='string', id=None), + 'label': Value(dtype='string', id=None), + 'asoc': Sequence(feature=Sequence(feature=Value(dtype='string', id=None), length=-1, id=None), length=-1, id=None)} + ], + 'task': Value(dtype='string', id=None), + 'sample_prompt': Value(dtype='bool', id=None) +} + + +ProcessedFeature = Features(_processed_feature) + + +RecordFeature = Features({ + 'input_ids': Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None), + 'attention_mask': Sequence(feature=Value(dtype='int8', id=None), length=-1, id=None), + 'labels': Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None), + 'spots': Sequence(feature=Value(dtype='string', id=None), length=-1, id=None), + 'asocs': Sequence(feature=Value(dtype='string', id=None), length=-1, id=None), + 'spot_asoc': [ + { + 'span': Value(dtype='string', id=None), + 'label': Value(dtype='string', id=None), + 'asoc': Sequence(feature=Sequence(feature=Value(dtype='string', id=None), length=-1, id=None), length=-1, id=None) + } + ], + 
class PromptSeq2SeqTransformer(T5ForConditionalGeneration):
    """T5 model that prepends a learned soft prompt to the encoder input.

    ``prompt_length`` trainable embeddings are passed through a one-layer
    Transformer encoder and concatenated in front of the token embeddings;
    the attention mask is extended accordingly.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.prompt_length = 10
        # Take width and head count from the backbone config so the prompt can be
        # concatenated with the token embeddings for any model size
        # (the previous hard-coded 768 / 12 only matched a base-sized model).
        self.prompt_embedding_size = self.config.d_model
        self.prompt_embeddings = nn.Embedding(self.prompt_length, self.prompt_embedding_size)

        self.prompt_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=self.prompt_embedding_size,
                                       nhead=self.config.num_heads,
                                       dim_feedforward=self.prompt_embedding_size,
                                       batch_first=True),
            num_layers=1
        )

    def forward(self, **inputs):
        """Run T5 with the soft prompt prepended to ``input_ids``.

        Args:
            **inputs: keyword arguments of ``T5ForConditionalGeneration.forward``.

        Returns:
            Whatever the parent ``forward`` returns.
        """
        input_ids = inputs.get("input_ids")
        if input_ids is None:
            # Nothing to prepend to (e.g. the caller passes `inputs_embeds` or
            # precomputed `encoder_outputs`); defer to the plain T5 forward
            # instead of failing with a KeyError.
            return super().forward(**inputs)

        batch_size = input_ids.shape[0]
        raw_embed = self.shared(input_ids)
        raw_att_mask = inputs.get("attention_mask")
        if raw_att_mask is None:
            # No mask given: attend to every input token.
            raw_att_mask = torch.ones_like(input_ids)

        prompt_ids = torch.arange(self.prompt_length, device=input_ids.device)
        prompt_embed = self.prompt_embeddings(prompt_ids).unsqueeze(0)
        prompt_embed = self.prompt_encoder(prompt_embed)
        prompt_embed = prompt_embed.expand(batch_size, -1, -1)
        # new_ones keeps the dtype and device of the incoming mask.
        prompt_att_mask = raw_att_mask.new_ones(batch_size, self.prompt_length)

        model_inputs = {key: value for key, value in inputs.items() if key != "input_ids"}
        model_inputs["inputs_embeds"] = torch.cat([prompt_embed, raw_embed], dim=1)
        model_inputs["attention_mask"] = torch.cat([prompt_att_mask, raw_att_mask], dim=1)

        return super().forward(**model_inputs)
def noise_entity_with_other_entity(entity_list):
    """Randomly swap an entity's mention with that of another same-type entity.

    Every entity is deep-copied. With probability governed by ``THRESHOLD``
    (a draw above it triggers the swap) the copy takes the ``text`` and
    ``offset`` of a randomly chosen entity of the same type, which may be
    the entity itself.

    Args:
        entity_list: list of entity dicts with ``type``, ``text``, ``offset``.

    Returns:
        New list of (possibly noised) entity copies, in input order.
    """
    # Group the original entities by their type.
    same_type_pool = {}
    for item in entity_list:
        same_type_pool.setdefault(item["type"], []).append(item)

    result = []
    for item in entity_list:
        clone = deepcopy(item)
        if np.random.rand() > THRESHOLD:
            donor = random.choice(same_type_pool[clone["type"]])
            clone["text"] = donor["text"]
            clone["offset"] = donor["offset"]
        result.append(clone)
    return result
def noise_triple_num(triple_list, entity_list):
    """Randomly keep, duplicate-with-noise, or drop each triple.

    One uniform draw ``p`` is made per triple:

    * ``p < TRIPLE_THRESHOLD[0]``: the triple is kept unchanged;
    * ``p < TRIPLE_THRESHOLD[1]``: the triple is kept and a copy whose tail
      (``args[1]``) is replaced by a random entity is added after it;
    * otherwise: the triple is dropped.

    Args:
        triple_list: list of triple dicts with an ``args`` list ``[head, tail]``.
        entity_list: entities from which a replacement tail is drawn.

    Returns:
        New list of triples; input triples are never modified.
    """
    noised_triple_list = []
    for triple in triple_list:
        p = np.random.rand()
        if p < TRIPLE_THRESHOLD[0]:  # do nothing
            noised_triple_list.append(triple)
        elif p < TRIPLE_THRESHOLD[1]:  # add noised triple
            noised_triple_list.append(triple)
            # Without candidate entities there is nothing to substitute;
            # random.choice would raise IndexError on an empty list.
            if entity_list:
                noised_triple = deepcopy(triple)
                noised_triple["args"][1] = random.choice(entity_list)
                noised_triple_list.append(noised_triple)
        # else: remove triple

    return noised_triple_list
def noise_trigger_with_other_trigger(event_list, all_trigger_list):
    """Randomly replace an event's trigger mention with another trigger's.

    Every event is deep-copied. When a uniform draw exceeds ``THRESHOLD`` the
    copy takes the ``text`` and ``offset`` of a randomly chosen trigger from
    ``all_trigger_list``; the event ``type`` is left untouched.

    Args:
        event_list: list of event dicts.
        all_trigger_list: trigger dicts providing ``text`` and ``offset``.

    Returns:
        New list of (possibly noised) event copies, in input order.
    """
    mention_pool = [(item["text"], item["offset"]) for item in all_trigger_list]

    def _perturb(source_event):
        clone = deepcopy(source_event)
        if np.random.rand() > THRESHOLD:
            mention, offset = random.choice(mention_pool)
            clone["text"] = mention
            clone["offset"] = offset
        return clone

    return [_perturb(source_event) for source_event in event_list]
def noise_argument_type(event_list, all_argument_list):
    """Randomly replace the type (role) of event arguments.

    Every event is deep-copied. For each argument of the copy, a uniform draw
    above ``THRESHOLD`` replaces that argument's ``type`` with a type chosen at
    random from the types present in ``all_argument_list``.

    Args:
        event_list: list of event dicts, each with an ``args`` list.
        all_argument_list: argument dicts whose ``type`` values form the pool.

    Returns:
        New list of (possibly noised) event copies, in input order.
    """
    # Sorted so the pool order (and thus the seeded choice) does not depend on
    # set iteration order.
    argument_type_list = sorted({argument["type"] for argument in all_argument_list})

    noised_event_list = []
    for event in event_list:
        noised_event = deepcopy(event)
        for argument in noised_event["args"]:
            if np.random.rand() > THRESHOLD:
                # Noise the argument itself; the previous code overwrote the
                # event's own type with an argument role here.
                argument["type"] = random.choice(argument_type_list)
        noised_event_list.append(noised_event)
    return noised_event_list
def create_entity_uri(entity_list):
    """Give every entity a ``uri`` and return a lookup from entity JSON to uri.

    Entities without a ``uri`` receive their list position (as a string) and
    are modified in place; an existing ``uri`` is kept. The lookup key is the
    ``json.dumps`` of the entity as it was before any ``uri`` was added.

    Args:
        entity_list: list of entity dicts (mutated in place).

    Returns:
        Dict mapping serialized entity -> uri string.
    """
    uri_by_entity = {}
    for position, item in enumerate(entity_list):
        serialized = json.dumps(item)  # taken before the uri is attached
        if "uri" in item:
            uri_by_entity[serialized] = item["uri"]
            continue
        assigned = str(position)
        uri_by_entity[serialized] = assigned
        item["uri"] = assigned
    return uri_by_entity
def create_spot_asoc_field(instance_entity_list, instance_triple_list, instance_event_list):
    """Assemble the ``spot_asoc`` list of an instance.

    One entry is produced per entity, followed by one entry per event. An
    entity's ``asoc`` holds ``[relation type, tail text]`` for every triple
    whose head uri equals the entity uri; an event's ``asoc`` holds
    ``[argument type, argument text]`` for each of its arguments.

    Args:
        instance_entity_list: entity dicts with ``text``, ``type``, ``uri``.
        instance_triple_list: triple dicts with ``type`` and ``args`` (head, tail).
        instance_event_list: event dicts with ``text``, ``type``, ``args``.

    Returns:
        List of dicts with keys ``span``, ``label`` and ``asoc``.
    """
    entity_spots = [
        {
            "span": item["text"],
            "label": item["type"],
            "asoc": [
                [relation["type"], relation["args"][1]["text"]]
                for relation in instance_triple_list
                if relation["args"][0]["uri"] == item["uri"]
            ],
        }
        for item in instance_entity_list
    ]

    event_spots = [
        {
            "span": item["text"],
            "label": item["type"],
            "asoc": [[role["type"], role["text"]] for role in item["args"]],
        }
        for item in instance_event_list
    ]

    return entity_spots + event_spots
" + instance_record += " " + instance_record += instance_spot_asoc["span"] + " " + + if len(instance_spot_asoc["asoc"]) != 0: + for asoc in instance_spot_asoc["asoc"]: + instance_record += " " + + instance_record += asoc[0] + " " + instance_record += " " + instance_record += asoc[1] + " " + + instance_record += " " + + instance_record += " " + instance_record += "" + + return instance_record + +# %% aggregate + +def create_noised_record(tokens, entity_list, triple_list, event_list): + entity_uri_mapping = create_entity_uri(entity_list) + triple_list = update_entity_uri_in_triple(triple_list, entity_uri_mapping) + + all_trigger_list = build_trigger_list(event_list) + all_argument_list = build_argument_list(event_list) + + noised_record_list = [] + for _ in range(NOISE_NUM): + # noise entity + noised_entity_list = noise_entity_offset(entity_list, tokens) + noised_entity_list = noise_entity_with_other_entity(noised_entity_list) + noised_entity_list = noise_entity_type(noised_entity_list) + + noised_entity_dict = build_entity_dict(noised_entity_list) + + # noise triple + noised_triple_list = update_relation_triple_by_noised_entity(triple_list, noised_entity_dict) + + noised_triple_list = noise_relation_type(noised_triple_list) + noised_triple_list = noise_triple_num(noised_triple_list, noised_entity_list) + + # noise event + noised_event_list = noise_event_num(event_list, all_trigger_list) + + noised_event_list = noise_trigger_type(noised_event_list, all_trigger_list) + noised_event_list = noise_trigger_with_other_trigger(noised_event_list, all_trigger_list) + noised_event_list = noise_trigger_offset(noised_event_list, tokens) + + noised_event_list = noise_argument_num(noised_event_list, all_argument_list) + noised_event_list = noise_argument_type(noised_event_list, all_argument_list) + noised_event_list = noise_argument_with_other_argument(noised_event_list, all_argument_list) + noised_event_list = noise_argument_offset(noised_event_list, tokens) + + # create noised 
if __name__ == "__main__":
    # Script entry point: read `train.json` (one JSON instance per line), attach
    # a `noised_record` list to every instance and write the result to
    # `noised_train.json` in the same directory.
    seed = 0
    random.seed(seed)
    np.random.seed(seed)

    # Alternative dataset: "./data/text2spotasoc/relation/conll04/"
    output_dir = "./data/text2spotasoc/event/oneie_ace05_en_event/"

    original_all_file = os.path.join(output_dir, "train.json")
    noised_all_file = os.path.join(output_dir, "noised_train.json")

    with open(original_all_file, encoding="utf-8") as src, \
            open(noised_all_file, "w", encoding="utf-8") as tgt:
        for line in tqdm(src):
            instance = json.loads(line)

            instance["noised_record"] = create_noised_record(
                instance["tokens"],
                instance["entity"],
                instance["relation"],
                instance["event"],
            )

            # The write had been commented out (and a debugger breakpoint left at
            # the end), so the output file was truncated but never filled.
            tgt.write(json.dumps(instance) + "\n")
@dataclass
class JsonConfig(datasets.BuilderConfig):
    """BuilderConfig for JSON.

    Holds the options used to read JSON files with ``pyarrow.json`` and
    exposes them as ready-made pyarrow option objects.
    """

    # Target features; when set, ``schema`` is derived from them.
    features: Optional[datasets.Features] = None
    # Key of the top-level JSON object that holds the records, if any.
    field: Optional[str] = None
    # Forwarded to ``paj.ReadOptions``.
    use_threads: bool = True
    block_size: Optional[int] = None
    # Forwarded to ``paj.ParseOptions``.
    newlines_in_values: Optional[bool] = None

    @property
    def pa_read_options(self):
        """Return ``paj.ReadOptions`` built from ``use_threads`` and ``block_size``."""
        return paj.ReadOptions(use_threads=self.use_threads, block_size=self.block_size)

    @property
    def pa_parse_options(self):
        """Return ``paj.ParseOptions`` with an explicit table schema.

        The schema is unpickled from ``etc/record.dataload.schema`` (a path
        relative to the current working directory) and extended with a
        ``noised_record`` column of type list-of-strings.
        """
        import pickle

        # NOTE(review): pickle.load executes arbitrary code from the file;
        # the schema file must come from a trusted source.
        # The context manager closes the handle, which the previous bare
        # ``open(...)`` call never did.
        with open('etc/record.dataload.schema', 'rb') as schema_file:
            table_schema = pickle.load(schema_file)
        table_schema = table_schema.append(pa.field("noised_record", pa.list_(pa.string())))
        return paj.ParseOptions(explicit_schema=table_schema, newlines_in_values=self.newlines_in_values)

    @property
    def schema(self):
        """Return the pyarrow schema for ``features``, or ``None`` when unset."""
        return pa.schema(self.features.type) if self.features is not None else None
class Json(datasets.ArrowBasedBuilder):
    """Arrow-based builder that loads JSON files using ``JsonConfig`` options."""

    BUILDER_CONFIG_CLASS = JsonConfig

    def _info(self):
        """Return the dataset info carrying the configured features."""
        return datasets.DatasetInfo(features=self.config.features)

    def _split_generators(self, dl_manager):
        """We handle string, list and dicts in datafiles"""
        if not self.config.data_files:
            raise ValueError(f"At least one data file must be specified, but got data_files={self.config.data_files}")
        data_files = dl_manager.download_and_extract(self.config.data_files)
        if isinstance(data_files, (str, list, tuple)):
            # A bare path or a flat sequence of paths becomes a single TRAIN split.
            files = data_files
            if isinstance(files, str):
                files = [files]
            return [datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"files": files})]
        # Otherwise a mapping of split name -> path(s).
        splits = []
        for split_name, files in data_files.items():
            if isinstance(files, str):
                files = [files]
            splits.append(datasets.SplitGenerator(name=split_name, gen_kwargs={"files": files}))
        return splits

    def _generate_tables(self, files):
        """Yield ``(file_index, pyarrow.Table)`` for every file in ``files``.

        Raises:
            ValueError: if pyarrow cannot parse a file and no ``field`` was
                configured; the original ``ArrowInvalid`` is attached as the
                cause.
        """
        for i, file in enumerate(files):
            if self.config.field is not None:
                with open(file, encoding="utf-8") as f:
                    dataset = json.load(f)

                # We keep only the field we are interested in
                dataset = dataset[self.config.field]

                # We accept two format: a list of dicts or a dict of lists
                if isinstance(dataset, (list, tuple)):
                    pa_table = paj.read_json(
                        BytesIO("\n".join(json.dumps(row) for row in dataset).encode("utf-8")),
                        read_options=self.config.pa_read_options,
                        parse_options=self.config.pa_parse_options,
                    )
                else:
                    pa_table = pa.Table.from_pydict(mapping=dataset)
            else:
                try:
                    print(f"Reading file: {file}")
                    pa_table = paj.read_json(
                        file,
                        read_options=self.config.pa_read_options,
                        parse_options=self.config.pa_parse_options,
                    )
                except pa.ArrowInvalid as err:
                    # Re-read the file as plain JSON only to list its
                    # top-level keys in the error message.
                    with open(file, encoding="utf-8") as f:
                        dataset = json.load(f)
                    raise ValueError(
                        f"Not able to read records in the JSON file at {file}. "
                        f"You should probably indicate the field of the JSON file containing your records. "
                        f"This JSON file contain the following fields: {str(list(dataset.keys()))}. "
                        f"Select the correct one and provide it as `field='XXX'` to the dataset loading method. "
                    ) from err
            if self.config.features:
                # Encode column if ClassLabel.
                # The column index gets its own name so that it does not
                # overwrite the file index ``i`` yielded below.
                for col_idx, col in enumerate(self.config.features):
                    if isinstance(self.config.features[col], datasets.ClassLabel):
                        pa_table = pa_table.set_column(
                            col_idx,
                            self.config.schema.field(col),
                            [self.config.features[col].str2int(pa_table[col])],
                        )
                # Cast allows str <-> int/float, while parse_option explicit_schema does NOT
                # Before casting, rearrange JSON field names to match passed features schema field names order
                pa_table = pa.Table.from_arrays(
                    [pa_table[name] for name in self.config.features], schema=self.config.schema
                )
            yield i, pa_table