Gradient Cache策略 DPR
Gradient Cache
的实验结果如下,使用的评估指标是Accuracy
:
DPR method | TOP-5 | TOP-10 | TOP-50 | 说明 |
---|---|---|---|---|
Gradient_cache | 68.1 | 79.4 | 86.2 | DPR结合GC策略训练 |
GC_Batch_size_512 | 67.3 | 79.6 | 86.3 | DPR结合GC策略训练,且batch_size设置为512 |
实验对应的超参数如下:
Hyper Parameter | batch_size | learning_rate | warmup_steps | epoches | chunk_size | max_grad_norm |
---|---|---|---|---|---|---|
\ | 128/512 | 2e-05 | 1237 | 40 | 2 | 16/8 |
我们使用Dense Passage Retrieval的原始仓库 中提供的数据集进行训练和评估。可以使用download_data.py 脚本下载所需数据集。 数据集详细介绍见原仓库 。
[
{
"question": "....",
"answers": ["...", "...", "..."],
"positive_ctxs": [{
"title": "...",
"text": "...."
}],
"negative_ctxs": ["..."],
"hard_negative_ctxs": ["..."]
},
...
]
在原始仓库 下使用命令
python data/download_data.py --resource data.wikipedia_split.psgs_w100
python data/download_data.py --resource data.retriever.nq
python data/download_data.py --resource data.retriever.qas.nq
data.retriever.nq-train data.retriever.nq-dev data.retriever.qas.nq-dev data.retriever.qas.nq-test data.retriever.qas.nq-train psgs_w100.tsv
|—— train_gradient_cache_DPR.py # gradient_cache实现dense passage retrieval训练脚本
|—— train_gradient_cache.py # gradient_cache算法简单实现
|—— NQdataset.py # NQ数据集封装
|—— generate_dense_embeddings.py # 生成文本的稠密表示
|—— faiss_indexer.py # faiss相关indexer封装
|—— dense_retriever.py # 召回,指标检测
|—— qa_validation.py # 相关计算匹配函数
|—— tokenizers.py # tokenizer封装
基于 Dense Passage Retriever 策略训练
python train_gradient_cache_DPR.py \
--batch_size 128 \
--learning_rate 2e-05 \
--save_dir save_biencoder
--warmup_steps 1237 \
--epoches 40 \
--max_grad_norm 2 \
--train_data_path ./dataset_dir/biencoder-nq-train.json \
--chunk_size 16 \
参数含义说明
batch_size
: 批次大小learning_rate
: 学习率save_dir
: 模型保存位置warmupsteps
: 预热学习率参数epoches
: 训练批次大小max_grad_norm
: 详见ClipGradByGlobalNormtrain_data_path
: 训练数据存放地址chunk_size
: chunk的大小
python generate_dense_embeddings.py \
--ctx_file ./dataset_dir/psgs_w100.tsv \
--out_file test_generate \
--que_model_path ./save_dir/question_model_40 \
--con_model_path ./save_dir/context_model_40
参数含义说明
ctx_file
: ctx文件读取地址out_file
: 生成后的文件输出地址que_model_path
: question model pathcon_model_path
: context model path
python dense_retriever.py --hnsw_index \
--out_file out_file \
--encoded_ctx_file ./test_generate \
--ctx_file ./dataset_dir/psgs_w100.tsv \
--qa_file ./dataset_dir/nq.qa.csv \
--que_model_path ./save_dir/question_model_40 \
--con_model_path ./save_dir/context_model_40
参数含义说明
hnsw_index
:使用hnsw_indexoutfile
: 输出文件地址encoded_ctx_file
: 编码后的ctx文件ctx_file
: ctx文件qa_file
: qa_file文件que_model_path
: question encoder modelcon_model_path
: context encoder model