-
Notifications
You must be signed in to change notification settings - Fork 29
/
Copy patharguments.py
90 lines (81 loc) · 3.39 KB
/
arguments.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import os
from dataclasses import dataclass, field
from typing import Optional, Union, List
from transformers import TrainingArguments
@dataclass
class ModelArguments:
    """
    Options selecting the pretrained model, config and tokenizer to fine-tune,
    plus a set of model-level dimensions and switches.

    Only ``model_name_or_path`` is required; every other field has a default.
    """

    # --- Pretrained checkpoint selection ---
    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    config_name: Optional[str] = field(
        default=None,
        metadata={"help": "Pretrained config name or path if not the same as model_name"},
    )
    tokenizer_name: Optional[str] = field(
        default=None,
        metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"},
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where do you want to store the pretrained models downloaded from s3"},
    )

    # --- Model-level settings (no help text attached) ---
    # NOTE(review): presumably the sizes of the token-level and CLS-level
    # representations -- confirm against the model code that reads them.
    token_dim: int = 768
    cls_dim: int = 768
    token_rep_relu: bool = False
    token_norm_after: bool = False
    cls_norm_after: bool = False
    x_device_negatives: bool = False
    pooling: str = 'max'
    no_sep: bool = False
    no_cls: bool = False
    cls_only: bool = False
@dataclass
class DataArguments:
    """
    Arguments describing the training, prediction and encoding data.

    When ``train_dir`` or ``pred_dir`` is set, ``__post_init__`` overwrites
    ``train_path`` / ``pred_path`` with the list of files found in that
    directory, in sorted order so the result is deterministic.
    """

    train_dir: str = field(
        default=None, metadata={"help": "Path to train directory"}
    )
    # Annotated as Union[str] (i.e. plain str), but reassigned to a list of
    # file paths by __post_init__ whenever train_dir is set.
    train_path: Union[str] = field(
        default=None, metadata={"help": "Path to train data"}
    )
    train_group_size: int = field(default=8)

    pred_path: List[str] = field(default=None, metadata={"help": "Path to prediction data"})
    pred_dir: str = field(
        default=None, metadata={"help": "Path to prediction directory"}
    )
    pred_id_file: str = field(default=None)
    rank_score_path: str = field(default=None, metadata={"help": "where to save the match score"})

    encode_in_path: List[str] = field(default=None, metadata={"help": "Path to data to encode"})
    encoded_save_path: str = field(default=None, metadata={"help": "where to save the encode"})

    q_max_len: int = field(
        default=128,
        metadata={
            "help": "The maximum total input sequence length after tokenization for query. Sequences longer "
                    "than this will be truncated, sequences shorter will be padded."
        },
    )
    p_max_len: int = field(
        default=128,
        metadata={
            "help": "The maximum total input sequence length after tokenization for passage. Sequences longer "
                    "than this will be truncated, sequences shorter will be padded."
        },
    )
    document: bool = field(default=False)

    def __post_init__(self):
        """Expand ``train_dir`` / ``pred_dir`` into sorted lists of file paths.

        Raises whatever ``os.listdir`` raises (e.g. ``FileNotFoundError``) if
        a given directory cannot be listed.
        """
        if self.train_dir is not None:
            # os.listdir returns entries in arbitrary order; sort so the file
            # order is reproducible across runs and machines.
            self.train_path = [
                os.path.join(self.train_dir, name)
                for name in sorted(os.listdir(self.train_dir))
                # Suffix match only: keeps names ending in 'tsv' or 'json'.
                if name.endswith(('tsv', 'json'))
            ]
        if self.pred_dir is not None:
            # Every directory entry is kept here (no suffix filter).
            self.pred_path = [
                os.path.join(self.pred_dir, name)
                for name in sorted(os.listdir(self.pred_dir))
            ]
@dataclass
class COILTrainingArguments(TrainingArguments):
    """
    ``TrainingArguments`` subclass that declares ``warmup_ratio`` with a
    default of 0 and adds a ``do_encode`` flag.
    """

    warmup_ratio: float = 0
    do_encode: bool = field(
        default=False,
        metadata={"help": "Whether to run encoding on the test set."},
    )