-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathargs_helper.py
84 lines (81 loc) · 3.19 KB
/
args_helper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
from dataclasses import dataclass, field
from typing import Optional
from transformers import TrainingArguments
@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to utilize.

    Besides the checkpoint to load, this holds the span-masking
    hyper-parameters applied along the time axis (``mask_time_*``) and along
    the feature axis (``mask_feature_*``).
    """

    # Hub id or local path of the pretrained checkpoint to start from.
    model_name_or_path: Optional[str] = field(
        default="ydshieh/wav2vec2-large-xlsr-53-chinese-zh-cn-gpt", metadata={"help": "The path of the HuggingFace model."}
    )
    # NOTE: the help texts below are built from adjacent string literals, which
    # Python concatenates with no separator -- each piece must end in a space.
    mask_time_prob: float = field(
        default=0.065,
        metadata={
            "help": "Probability of each feature vector along the time axis to be chosen as the start of the vector "
            "span to be masked. Approximately ``mask_time_prob * sequence_length // mask_time_length`` feature "
            "vectors will be masked along the time axis."
        },
    )
    mask_time_length: int = field(
        default=2,
        metadata={"help": "Length of vector span to mask along the time axis."},
    )
    # Feature-axis masking draws spans over the hidden dimension, not over time,
    # so the help text refers to the feature axis.
    mask_feature_prob: float = field(
        default=0.004,
        metadata={
            "help": "Probability of each feature vector along the feature axis to be chosen as the start of the vector "
            "span to be masked. Approximately ``mask_feature_prob * hidden_size // mask_feature_length`` feature "
            "bins will be masked along the feature axis."
        },
    )
    mask_feature_length: int = field(
        default=2,
        metadata={"help": "Length of vector span to mask along the feature axis."},
    )
@dataclass
class DataArguments:
    """
    Options controlling data loading and preprocessing.

    Covers the train/validation/test manifest locations, preprocessing
    parallelism, the names of the audio-path and transcript columns, and the
    name of the cache directory.
    """

    # --- manifest locations (defaults point at CSV files under ``dataset/``) ---
    train_manifest_path: Optional[str] = field(
        metadata=dict(help="The path of the training dataset to use."),
        default="dataset/train_metadata.csv",
    )
    valid_manifest_path: Optional[str] = field(
        metadata=dict(help="The path of the validation dataset to use."),
        default="dataset/validation_metadata.csv",
    )
    test_manifest_path: Optional[str] = field(
        metadata=dict(help="The path of the testing dataset to use."),
        default="dataset/test_metadata.csv",
    )

    # --- preprocessing control ---
    preprocessing_num_workers: Optional[int] = field(
        metadata=dict(help="The number of processes to use for the preprocessing."),
        default=16,
    )
    preprocessing_only: Optional[bool] = field(
        metadata=dict(help="Whether to only run preprocessing."),
        default=False,
    )

    # --- column names inside the manifests ---
    audio_column_name: Optional[str] = field(
        metadata=dict(help="The name of the dataset column containing the audio path. Defaults to 'file_name'"),
        default="file_name",
    )
    text_column_name: Optional[str] = field(
        metadata=dict(help="The name of the dataset column containing the text data. Defaults to 'text'"),
        default="text",
    )

    # --- caching ---
    cache_dir_name: Optional[str] = field(
        metadata=dict(help="Name of cache directory"),
        default="cache",
    )
@dataclass
class TrainingArguments(TrainingArguments):
    """
    Arguments pertaining to the training pipeline.

    Subclasses the ``TrainingArguments`` imported from ``transformers`` and
    only changes the defaults of two inherited fields. Because this class
    reuses the base class's name, the module-level name ``TrainingArguments``
    refers to this subclass once the definition has run.
    """

    # Default output directory for this project.
    output_dir: Optional[str] = field(
        metadata=dict(help="Output directory"),
        default="./save",
    )
    # Project default for the inherited ``eval_accumulation_steps`` option.
    eval_accumulation_steps: Optional[int] = field(
        metadata=dict(help="Evaluation accumulation steps"),
        default=1,
    )