-
Notifications
You must be signed in to change notification settings - Fork 2
Dense_Retrieval
-
주어진 dataset의 context(passage)의 길이가 매우 길어서 기존 tokenizer를 그대로 사용하게 되면 뒷 부분의 내용이 잘려서 사용되게된다.
-
따라서 가능한 긴 내용을 사용하기 위해 max length를 1536으로 늘려서 사용하였다.
model_checkpoint = 'bert-base-multilingual-cased'
p_tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
p_tokenizer.model_max_length = 1536
Before
After
-
Passage tokenizer의 max length를 늘려주면서 tokenize된 input의 길이가 길어지게 된다. 이렇게 길어진 input을 model에 그대로 넣어주려하면 error가 나게 된다. 왜냐하면 dense retireval를 학습하기 위해 사용하는 pre-trained model의 position ids가 512로 사전 학습되었기 때문이다. 만약 position ids에 해당하는 embedding layer를 max length에 맞춰서 늘려주면 pre-trained model을 사용할 수 없게 된다.
-
pre-trained model의 max length는 512로 되어 있고, 이 길이에 맞춰서 training과 inference 방법을 바꿔주어야 했다.
-
따라서 training 시에는 passage의 길이가 512보다 긴 경우 전체 passage 중 random 하게 512만큼 선택하여 사용하였다.
-
그리고 inference 시에는 passage를 512의 window를 가지고 50% overlap하여 여러 passage로 잘라서 사용하였고, 만약 이 중 하나라도 question과의 similarity가 높게 나오면 찾아낸 것으로 판단하였다.
Training
class TrainRetrievalDataset(torch.utils.data.Dataset):
...
def _select_range(self, attention_mask):
sent_len = len([i for i in attention_mask if i != 0])
if sent_len <= 512:
return 1, 511
else:
start_idx = random.randint(1, sent_len-511)
end_idx = start_idx + 510
return start_idx, end_idx
Inference
class ValidRetrievalDataset(torch.utils.data.Dataset):
...
def _select_range(self, attention_mask):
sent_len = len([i for i in attention_mask if i != 0])
if sent_len <= 512:
return [(1,511)]
else:
num = sent_len // 255
res = sent_len % 255
if res == 0:
num -= 1
ids_list = []
for n in range(num):
if res > 0 and n == num-1:
end_idx = sent_len-1
start_idx = end_idx - 510
else:
start_idx = n*255+1
end_idx = start_idx + 510
ids_list.append((start_idx, end_idx))
return ids_list
-
multilingual bert, facebook/dpr, koelectra, xlm-roberta 중 multilingual bert 선택
-
facebook/dpr : 자소 단위 tokenizing을 사용하고 있어 context의 길이가 매우 길어지기 때문에 다른 모델들과 비교하여 동일한 token 수 대비 적은 정보를 가진다고 판단하였고, context의 앞부분 내용만으로 학습한 결과도 낮은 성능을 보였다.
Before
"이순신은 조선 중기의 무신이다."
After
['ᄋ', '# # ᅵ', '# # ᄉ', '# # ᅮ', '# # ᆫ', '# # ᄉ', '# # ᅵ', '# # ᆫ', '# # ᄋ', '# # ᅳ', '# # ᆫ', 'ᄌ', '# # ᅩ', '# # ᄉ', '# # ᅥ', '# # ᆫ', 'ᄌ', '# # ᅮ', '# # ᆼ', '# # ᄀ', '# # ᅵ', '# # ᄋ', '# # ᅴ', 'ᄆ', '# # ᅮ', '# # ᄉ', '# # ᅵ', '# # ᆫ', '# # ᄋ', '# # ᅵ', '# # ᄃ', '# # ᅡ', '.']
-
xlm-roberta : question과 context에 해당하는 모델을 각각 사용해야 한다. 하지만 xlm-roberta의 경우엔 모델이 굉장히 크기 때문에 메모리에 2개의 모델을 올릴 수가 없었다. 따라서 1개의 모델로 question과 context에 대한 embedding vector을 만드는 방식으로 성능을 비교하였다.
retrieval_model.py
from torch import nn
from transformers import AutoModel, AutoConfig
class BertPooler(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh()
def forward(self, hidden_states):
# We "pool" the model by simply taking the hidden state corresponding
# to the first token.
first_token_tensor = hidden_states[:, 0]
pooled_output = self.dense(first_token_tensor)
pooled_output = self.activation(pooled_output)
return pooled_output
class Encoder(nn.Module):
def __init__(self, model_checkpoint):
super(Encoder, self).__init__()
self.model_checkpoint = model_checkpoint
config = AutoConfig.from_pretrained(self.model_checkpoint)
if self.model_checkpoint == 'monologg/koelectra-base-v3-discriminator':
self.pooler = BertPooler(config)
config = AutoConfig.from_pretrained(self.model_checkpoint)
self.model = AutoModel.from_pretrained(self.model_checkpoint, config=config)
def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None):
outputs = self.model(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, position_ids=position_ids)
if self.model_checkpoint == 'monologg/koelectra-base-v3-discriminator':
sequence_output = outputs[0]
pooled_output = self.pooler(sequence_output)
else:
pooled_output = outputs[1]
return pooled_output
facebook/dpr
from transformers import (DPRContextEncoder,
DPRContextEncoderTokenizer,
DPRQuestionEncoder,
DPRQuestionEncoderTokenizer)
p_model = DPRContextEncoder.from_pretrained('facebook/dpr-ctx_encoder-single-nq-base')
q_model = DPRQuestionEncoder.from_pretrained('facebook/dpr-question_encoder-single-nq-base')
Top 1 accuracy
-
Elastic Search의 성능을 뛰어 넘는 dense retrieval를 만드는 것보다 elastic search의 성능을 더 개선하는 점이 구현 가능성이 높다고 판단하였다.
-
Elastic Search를 통해 score가 가장 높은 20개의 context를 고른 후 dense retrieval를 통해 20개의 context 중 좀 더 유사도가 높은 context를 찾는 방식
How to use Elastic Search : Elastic Search for Beginners