-
Notifications
You must be signed in to change notification settings - Fork 1
/
encode_gtr-base-st.py
100 lines (83 loc) · 3.55 KB
/
encode_gtr-base-st.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import logging
import os
import pickle
import sys
from contextlib import nullcontext
import numpy as np
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
from transformers import AutoConfig, AutoTokenizer
from transformers import (
HfArgumentParser,
AutoModel
)
import vec2text
from tevatron.arguments import ModelArguments, DataArguments, \
TevatronTrainingArguments as TrainingArguments
from tevatron.data import EncodeDataset, EncodeCollator
from tevatron.datasets import HFQueryDataset, HFCorpusDataset
from sentence_transformers import SentenceTransformer
# Module-level logger named after this module, per the stdlib logging convention.
logger = logging.getLogger(__name__)
def main():
    """Encode one shard of a Tevatron query/corpus dataset with GTR-T5.

    Parses Tevatron-style arguments (from a single JSON file or from the
    command line), loads this worker's dataset shard, embeds every example
    in batches with ``sentence-transformers/gtr-t5-base``, and pickles the
    resulting ``(embeddings, lookup_indices)`` pair to
    ``data_args.encoded_save_path``.

    Raises:
        NotImplementedError: when invoked with multiple GPUs / ranks.
        RuntimeError: when the selected shard yields no examples.
    """
    parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # A single *.json argument supplies every dataclass field at once.
        model_args, data_args, training_args = parser.parse_json_file(
            json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    model_args: ModelArguments
    data_args: DataArguments
    training_args: TrainingArguments

    if training_args.local_rank > 0 or training_args.n_gpu > 1:
        raise NotImplementedError('Multi-GPU encoding is not supported.')

    # Log at INFO on the main process only; other ranks stay at WARN.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
    )

    # This tokenizer is used only for dataset pre-processing/collation; the
    # SentenceTransformer below carries its own tokenizer internally.
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
    )
    # NOTE(review): the encoder checkpoint is hard-coded and ignores
    # model_args.model_name_or_path — confirm this is intentional.
    model = SentenceTransformer('sentence-transformers/gtr-t5-base')

    text_max_length = data_args.q_max_len if data_args.encode_is_qry else data_args.p_max_len
    if data_args.encode_is_qry:
        encode_dataset = HFQueryDataset(tokenizer=tokenizer, data_args=data_args,
                                        cache_dir=data_args.data_cache_dir or model_args.cache_dir)
    else:
        encode_dataset = HFCorpusDataset(tokenizer=tokenizer, data_args=data_args,
                                         cache_dir=data_args.data_cache_dir or model_args.cache_dir)
    # Keep only this worker's shard of the full dataset.
    encode_dataset = EncodeDataset(
        encode_dataset.process(data_args.encode_num_shard, data_args.encode_shard_index),
        tokenizer, max_len=text_max_length)

    encode_loader = DataLoader(
        encode_dataset,
        batch_size=training_args.per_device_eval_batch_size,
        collate_fn=EncodeCollator(
            tokenizer,
            max_length=text_max_length,
            padding='max_length',
        ),
        shuffle=False,    # preserve example order so ids align with vectors
        drop_last=False,  # encode every example, including a ragged tail batch
        num_workers=training_args.dataloader_num_workers,
    )

    encoded = []
    lookup_indices = []
    model = model.to(training_args.device)
    model.eval()

    for batch_ids, batch in tqdm(encode_loader):
        lookup_indices.extend(batch_ids)
        # Mixed precision only when fp16 was requested; no-op context otherwise.
        with torch.cuda.amp.autocast() if training_args.fp16 else nullcontext():
            with torch.no_grad():
                batch = {k: v.to(training_args.device) for k, v in batch.items()}
                out_features = model(batch)
                # Fix: detach+cpu exactly once per batch (the original
                # called .detach() twice on the same tensor).
                encoded.append(
                    out_features['sentence_embedding'].detach().cpu().numpy())

    if not encoded:
        # np.concatenate([]) raises an opaque ValueError; fail with a clear message.
        raise RuntimeError('No examples were encoded; the selected shard appears to be empty.')
    encoded = np.concatenate(encoded)

    with open(data_args.encoded_save_path, 'wb') as f:
        pickle.dump((encoded, lookup_indices), f)
# Run the encoder only when executed as a script, not on import.
if __name__ == "__main__":
    main()