-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_extractors.py
61 lines (34 loc) · 1.83 KB
/
train_extractors.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import argparse
import os
import numpy as np
from pathlib import Path
from ctl.trainer_extractors import AISTExtractorMotionTrainer
from configs.config import cfg, get_cfg_defaults
from core.models.eval_modules import AISTEncoderBiGRUCo
def main(config=None):
    """Build the motion and music extractors and run their training.

    Args:
        config: Optional config node providing the ``extractor``, ``train``,
            ``dataset`` and ``eval_model`` sections plus
            ``extractors_model_name``. When omitted, the module-level ``cfg``
            is used (the ``__main__`` guard rebinds it to the config merged
            from file). Passing it explicitly avoids depending on that global,
            e.g. when ``main`` is called from another module.
    """
    if config is None:
        config = cfg
    extractor_cfg = config.extractor

    # Both encoders share hidden_size and output_size; only the input feature
    # size differs between the motion and music branches.
    motion_extractor = AISTEncoderBiGRUCo(
        extractor_cfg.motion_input_size,
        extractor_cfg.hidden_size,
        extractor_cfg.output_size,
    )
    music_extractor = AISTEncoderBiGRUCo(
        extractor_cfg.music_input_size,
        extractor_cfg.hidden_size,
        extractor_cfg.output_size,
    )

    # NOTE(review): `.cuda()` presumably moves the trainer (an nn.Module?) to
    # the default GPU, so this script requires CUDA — confirm.
    trainer = AISTExtractorMotionTrainer(
        motion_extractor=motion_extractor,
        music_extractor=music_extractor,
        args=extractor_cfg,
        training_args=config.train,
        dataset_args=config.dataset,
        eval_args=config.eval_model,
        model_name=config.extractors_model_name,
    ).cuda()

    # The resume setting comes from the training section of the config.
    trainer.train(config.train.resume)
if __name__ == '__main__':
    # Config merged over the defaults when no --config flag is given; this is
    # the path the script has always loaded.
    default_config_path = "/srv/scratch/sanisetty3/music_motion/motion_vqvae/checkpoints/extractors/aist_extractor.yaml"
    # Alternative config previously used here (pass it via --config):
    #   /srv/scratch/sanisetty3/music_motion/motion_vqvae/configs/var_len_768_768_aist_vq.yaml

    parser = argparse.ArgumentParser(
        description="Train the AIST motion/music feature extractors."
    )
    parser.add_argument(
        "--config",
        default=default_config_path,
        help="Path to the YAML config merged over the defaults.",
    )
    # parse_known_args: the script used to ignore all command-line arguments,
    # so keep tolerating unrecognized ones instead of exiting with an error.
    args, _unknown = parser.parse_known_args()

    # Rebinds the module-level `cfg` that main() reads by default.
    cfg = get_cfg_defaults()
    print("loading config from:" , args.config)
    cfg.merge_from_file(args.config)
    cfg.freeze()
    print("output_dir: ", cfg.output_dir , cfg.train.output_dir)
    main()

# Multi-GPU launch command noted previously (it names train.py, not this
# script — confirm before reusing):
#   accelerate launch --multi_gpu --mixed_precision=fp16 --num_processes=2 train.py
# accelerate configuration saved at /nethome/sanisetty3/.cache/huggingface/accelerate/default_config.yaml