Skip to content

Dense_Retrieval

PJHgh edited this page May 24, 2021 · 2 revisions

1️⃣ Passage Tokenizer - Max Length

1. Why

  • 주어진 dataset의 context(passage)의 길이가 매우 길어서 기존 tokenizer를 그대로 사용하게 되면 뒷 부분의 내용이 잘려서 사용되게된다.

  • 따라서 가능한 긴 내용을 사용하기 위해 max length를 1536으로 늘려서 사용하였다.

2. code

model_checkpoint = 'bert-base-multilingual-cased'
p_tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
p_tokenizer.model_max_length = 1536

3. result

Before

After

2️⃣ Training & Infernece

1. Why

  • Passage tokenizer의 max length를 늘려주면서 tokenize된 input의 길이가 길어지게 된다. 이렇게 길어진 input을 model에 그대로 넣어주려하면 error가 나게 된다. 왜냐하면 dense retireval를 학습하기 위해 사용하는 pre-trained model의 position ids가 512로 사전 학습되었기 때문이다. 만약 position ids에 해당하는 embedding layer를 max length에 맞춰서 늘려주면 pre-trained model을 사용할 수 없게 된다.

  • pre-trained model의 max length는 512로 되어 있고, 이 길이에 맞춰서 training과 inference 방법을 바꿔주어야 했다.

  • 따라서 training 시에는 passage의 길이가 512보다 긴 경우 전체 passage 중 random 하게 512만큼 선택하여 사용하였다.

  • 그리고 inference 시에는 passage를 512의 window를 가지고 50% overlap하여 여러 passage로 잘라서 사용하였고, 만약 이 중 하나라도 question과의 similarity가 높게 나오면 찾아낸 것으로 판단하였다.

2. Code

Training

class TrainRetrievalDataset(torch.utils.data.Dataset):
    ...
    def _select_range(self, attention_mask):
        sent_len = len([i for i in attention_mask if i != 0])
        if sent_len <= 512:
            return 1, 511
        else:
            start_idx = random.randint(1, sent_len-511)
            end_idx = start_idx + 510
            return start_idx, end_idx

Inference

class ValidRetrievalDataset(torch.utils.data.Dataset):
    ...
    def _select_range(self, attention_mask):
        sent_len = len([i for i in attention_mask if i != 0])
        if sent_len <= 512:
            return [(1,511)]
        else:
            num = sent_len // 255
            res = sent_len % 255
            if res == 0:
                num -= 1
            ids_list = []
            for n in range(num):
                if res > 0 and n == num-1:
                    end_idx = sent_len-1
                    start_idx = end_idx - 510
                else:
                    start_idx = n*255+1
                    end_idx = start_idx + 510
                ids_list.append((start_idx, end_idx))
            return ids_list

3️⃣ Model

1. Why

  • multilingual bert, facebook/dpr, koelectra, xlm-roberta 중 multilingual bert 선택

  • facebook/dpr : 자소 단위 tokenizing을 사용하고 있어 context의 길이가 매우 길어지기 때문에 다른 모델들과 비교하여 동일한 token 수 대비 적은 정보를 가진다고 판단하였고, context의 앞부분 내용만으로 학습한 결과도 낮은 성능을 보였다.

    Before

    "이순신은 조선 중기의 무신이다."

    After

    ['ᄋ', '# # ᅵ', '# # ᄉ', '# # ᅮ', '# # ᆫ', '# # ᄉ', '# # ᅵ', '# # ᆫ', '# # ᄋ', '# # ᅳ', '# # ᆫ', 'ᄌ', '# # ᅩ', '# # ᄉ', '# # ᅥ', '# # ᆫ', 'ᄌ', '# # ᅮ', '# # ᆼ', '# # ᄀ', '# # ᅵ', '# # ᄋ', '# # ᅴ', 'ᄆ', '# # ᅮ', '# # ᄉ', '# # ᅵ', '# # ᆫ', '# # ᄋ', '# # ᅵ', '# # ᄃ', '# # ᅡ', '.']

  • xlm-roberta : question과 context에 해당하는 모델을 각각 사용해야 한다. 하지만 xlm-roberta의 경우엔 모델이 굉장히 크기 때문에 메모리에 2개의 모델을 올릴 수가 없었다. 따라서 1개의 모델로 question과 context에 대한 embedding vector을 만드는 방식으로 성능을 비교하였다.

2. code

retrieval_model.py

from torch import nn
from transformers import AutoModel, AutoConfig

class BertPooler(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output

class Encoder(nn.Module):
    def __init__(self, model_checkpoint):
        super(Encoder, self).__init__()
        self.model_checkpoint = model_checkpoint
        config = AutoConfig.from_pretrained(self.model_checkpoint)
        
        if self.model_checkpoint == 'monologg/koelectra-base-v3-discriminator':
            self.pooler = BertPooler(config)
        config = AutoConfig.from_pretrained(self.model_checkpoint)
        self.model = AutoModel.from_pretrained(self.model_checkpoint, config=config)
    
    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None):
        outputs = self.model(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, position_ids=position_ids)
        if self.model_checkpoint == 'monologg/koelectra-base-v3-discriminator':
            sequence_output = outputs[0]
            pooled_output = self.pooler(sequence_output)
        else:
            pooled_output = outputs[1]
        return pooled_output

facebook/dpr

from transformers import (DPRContextEncoder,
                          DPRContextEncoderTokenizer,
                          DPRQuestionEncoder,
                          DPRQuestionEncoderTokenizer)

p_model = DPRContextEncoder.from_pretrained('facebook/dpr-ctx_encoder-single-nq-base')
q_model = DPRQuestionEncoder.from_pretrained('facebook/dpr-question_encoder-single-nq-base')

3. result

Top 1 accuracy

4️⃣ Elastic search + Dense Retrieval

1. why

  • Elastic Search의 성능을 뛰어 넘는 dense retrieval를 만드는 것보다 elastic search의 성능을 더 개선하는 점이 구현 가능성이 높다고 판단하였다.

  • Elastic Search를 통해 score가 가장 높은 20개의 context를 고른 후 dense retrieval를 통해 20개의 context 중 좀 더 유사도가 높은 context를 찾는 방식

    How to use Elastic Search : Elastic Search for Beginners