Initial commit
ligen1 committed Apr 28, 2022
0 parents commit ccd4c46
Showing 27 changed files with 23,047 additions and 0 deletions.
10 changes: 10 additions & 0 deletions .gitignore
@@ -0,0 +1,10 @@
nogit
data/*
__pycache__
model/*
pretrained_model/*
logs/*
!.gitkeep
!data/example_input.json
!data/example_output.json
!data/example.txt
40 changes: 40 additions & 0 deletions README.md
@@ -0,0 +1,40 @@
# Competition

Baseline for the Intelligent Text Proofreading Competition

## Code structure
```
├── command
│   └── train.sh                # training script
├── data
├── logs
├── pretrained_model
└── src
    ├── __init__.py
    ├── baseline                # baseline system
    ├── corrector.py            # text correction entry point
    ├── evaluate.py             # metric evaluation
    ├── metric.py               # metric computation
    ├── prepare_for_upload.py   # generates the result file for submission
    └── train.py                # training entry point
```

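## Usage

- Dataset: register on the official competition website to obtain the dataset.
- A baseline for a basic proofreading system is provided; for the baseline model's training parameters, see src/baseline/trainer.py.
- The baseline supports BERT-style pretrained models, which can be downloaded from HuggingFace, e.g. [chinese-roberta-wwm-ext](https://huggingface.co/hfl/chinese-roberta-wwm-ext)
- The baseline is for reference only; participating teams may build on it or adopt a different approach.

## Start training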

```
cd command && sh train.sh
```

## Other public datasets

- CGED public datasets from previous years: http://www.cged.tech/
- NLPCC2018 grammatical error correction dataset: http://tcci.ccf.org.cn/conference/2018/taskdata.php
- SIGHAN and related training sets: http://ir.itc.ntnu.edu.tw/lre/sighan8csc.html

18 changes: 18 additions & 0 deletions command/train.sh
@@ -0,0 +1,18 @@
cd .. && CUDA_VISIBLE_DEVICES=0,1,2,3 python -m src.train \
    --in_model_dir "pretrained_model/chinese-roberta-wwm-ext" \
    --out_model_dir "model/ctc" \
    --epochs "50" \
    --batch_size "168" \
    --max_seq_len "128" \
    --learning_rate "5e-4" \
    --train_fp "data/example.txt" \
    --test_fp "data/example.txt" \
    --random_seed_num "22" \
    --check_val_every_n_epoch "1" \
    --early_stop_times "20" \
    --warmup_steps "-1" \
    --dev_data_ratio "0.1" \
    --training_mode "normal" \
    --amp true \
    --freeze_embedding false
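
The script passes numeric values as quoted strings and booleans as the bare words `true`/`false`, so the training entry point has to convert them. A minimal sketch of that conversion with the standard library `argparse` is shown below. It is an assumption for illustration only: requirements.txt lists `auto_argparse`, whose actual parsing in `src.train` may differ, and `str2bool` and `parse_train_args` are hypothetical names covering only a subset of the flags.

```python
import argparse


def str2bool(value: str) -> bool:
    """Convert 'true'/'false' style strings (as used in train.sh) to bool."""
    lowered = value.strip().lower()
    if lowered in ("true", "1", "yes"):
        return True
    if lowered in ("false", "0", "no"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")


def parse_train_args(argv=None):
    # Hypothetical parser for a subset of the flags in command/train.sh.
    parser = argparse.ArgumentParser(description="Sketch of train.sh flag parsing")
    parser.add_argument("--in_model_dir", type=str, required=True)
    parser.add_argument("--out_model_dir", type=str, required=True)
    parser.add_argument("--epochs", type=int, default=50)
    parser.add_argument("--batch_size", type=int, default=168)
    parser.add_argument("--max_seq_len", type=int, default=128)
    parser.add_argument("--learning_rate", type=float, default=5e-4)
    parser.add_argument("--warmup_steps", type=int, default=-1)
    parser.add_argument("--dev_data_ratio", type=float, default=0.1)
    parser.add_argument("--amp", type=str2bool, default=True)
    parser.add_argument("--freeze_embedding", type=str2bool, default=False)
    return parser.parse_args(argv)
```

Passing `type=str2bool` rather than `type=bool` matters here, because `bool("false")` is `True` in Python: any non-empty string is truthy.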

Empty file added data/.gitkeep
Empty file.
248 changes: 248 additions & 0 deletions data/example.txt

Large diffs are not rendered by default.

10 changes: 10 additions & 0 deletions data/example_input.json
@@ -0,0 +1,10 @@
[
    {
        "source": "领导的按排,我坚决服从",
        "id": 1
    },
    {
        "source": "今天的天气真错!",
        "id": 2
    }
]
10 changes: 10 additions & 0 deletions data/example_output.json
@@ -0,0 +1,10 @@
[
    {
        "inference": "领导的安排,我坚决服从",
        "id": 1
    },
    {
        "inference": "今天的天气真不错!",
        "id": 2
    }
]
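
The two example files share the same `id` values: each input record carries a `source` sentence, and each output record carries the corrected text under `inference`. A minimal sketch of producing a file in the output shape might look like the following. The names `build_inferences` and `convert_file` and the `correct` callable are hypothetical stand-ins, not the repository's `prepare_for_upload.py` or `corrector.py`.

```python
import json
from typing import Callable, Dict, List


def build_inferences(records: List[Dict], correct: Callable[[str], str]) -> List[Dict]:
    """Map each {"source", "id"} record to an {"inference", "id"} record."""
    return [{"inference": correct(rec["source"]), "id": rec["id"]} for rec in records]


def convert_file(in_path: str, out_path: str, correct: Callable[[str], str]) -> None:
    """Read an input JSON file, correct every sentence, write the output JSON."""
    with open(in_path, encoding="utf-8") as f:
        records = json.load(f)
    results = build_inferences(records, correct)
    with open(out_path, "w", encoding="utf-8") as f:
        # ensure_ascii=False keeps Chinese characters readable in the file.
        json.dump(results, f, ensure_ascii=False, indent=4)
```

Any function from string to string can be passed as `correct`, so the same code works with a trained model or with a trivial placeholder during testing.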
Empty file added logs/.gitkeep
Empty file.
Empty file added model/.gitkeep
Empty file.
Empty file added pretrained_model/.gitkeep
Empty file.
5 changes: 5 additions & 0 deletions requirements.txt
@@ -0,0 +1,5 @@
auto_argparse==0.0.7
numpy==1.19.5
rich==12.3.0
torch==1.9.0+cu111
transformers==4.6.0
30 changes: 30 additions & 0 deletions src/__init__.py
@@ -0,0 +1,30 @@
import logging
import os
import re
from logging.handlers import TimedRotatingFileHandler


def setup_log(log_name):
    """Create a logger that writes to logs/<log_name> and to the console."""
    logger = logging.getLogger(log_name)
    log_path = os.path.join("logs", log_name)
    logger.setLevel(logging.DEBUG)
    # Rotate the log file at midnight and keep up to 30 backups.
    file_handler = TimedRotatingFileHandler(
        filename=log_path, when="MIDNIGHT", interval=1, backupCount=30
    )
    file_handler.suffix = "%Y-%m-%d.log"
    # Escape the dot so that only a literal ".log" extension matches.
    file_handler.extMatch = re.compile(r"^\d{4}-\d{2}-\d{2}\.log$")
    stream_handler = logging.StreamHandler()
    formatter = logging.Formatter(
        "[%(asctime)s] [%(process)d] [%(levelname)s] - %(module)s.%(funcName)s (%(filename)s:%(lineno)d) - %(message)s"
    )

    # Both handlers share the same formatter (the stream handler was
    # previously given the file handler instead of the formatter).
    stream_handler.setFormatter(formatter)
    file_handler.setFormatter(formatter)
    logger.addHandler(stream_handler)
    logger.addHandler(file_handler)
    return logger


logger = setup_log("ctc.log")
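
`TimedRotatingFileHandler` names a rotated file by appending a dot and the `strftime`-formatted `suffix` to the base filename, and it uses `extMatch` to recognize old backups when pruning to `backupCount`. The sketch below only illustrates that naming with the suffix used in `setup_log`; `rotated_name` is a hypothetical helper, and the pattern is written with the dot escaped so that it matches a literal `.log`.

```python
import re
import time

SUFFIX = "%Y-%m-%d.log"
EXT_MATCH = re.compile(r"^\d{4}-\d{2}-\d{2}\.log$")


def rotated_name(base: str, when: float) -> str:
    """Return the backup filename the handler would use for a given timestamp."""
    return base + "." + time.strftime(SUFFIX, time.localtime(when))


if __name__ == "__main__":
    name = rotated_name("logs/ctc.log", time.time())
    extension = name[len("logs/ctc.log."):]
    # The date-stamped extension should be recognized by the pattern.
    print(name, bool(EXT_MATCH.match(extension)))
```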
Empty file added src/baseline/__init__.py
Empty file.
3 changes: 3 additions & 0 deletions src/baseline/ctc_vocab/config.py
@@ -0,0 +1,3 @@
class VocabConf:
    detect_vocab_size = 2
    correct_vocab_size = 20675