forked from dsr-18/long-live-the-battery
-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.sh
executable file
·102 lines (88 loc) · 2.51 KB
/
train.sh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
#!/bin/bash
#
# Settings for submitting a training job with `gcloud ai-platform`.

# --- fixed job settings -------------------------------------------------
BUCKET="ion_age_bucket"
REGION="europe-west1"
PACKAGE_PATH="trainer/"
MODULE_NAME="trainer.task"
CONFIG_FILE="config.yaml"

# Job output and package staging both point at the bucket root.
JOB_DIR="gs://${BUCKET}"
PACKAGE_STAGING_PATH="gs://${BUCKET}"

# The job name carries a timestamp (YYYYmmdd_HHMMSS) taken when the script
# starts.
now="$(date '+%Y%m%d_%H%M%S')"
JOB_NAME="ion_age_${now}"

# --- input data ---------------------------------------------------------
# Glob patterns for the training and validation TFRecord files in the bucket.
TFRECORDS_DIR_TRAIN="gs://${BUCKET}/data/tfrecords/train/*tfrecord"
TFRECORDS_DIR_VALIDATE="gs://${BUCKET}/data/tfrecords/test/*tfrecord"

# --- per-run output -----------------------------------------------------
# Each run gets its own directory (used below as the TensorBoard dir).
JOB_RUN_DIR="${PACKAGE_STAGING_PATH}/jobs/${JOB_NAME}"
# parse command-line args
# Long-form options collected here are appended to the trainer's argument
# list when the job is submitted.
params=()

#######################################
# Translate the script's short flags into long-form trainer options.
# Globals:   params (appended to)
# Arguments: the command-line arguments to parse
# Outputs:   usage text / error messages on stderr
# Returns:   0 on success; exits the script with status 1 on -h, on an
#            unknown option, or on an option that is missing its argument
#######################################
parse_args() {
  local opt
  # Restart getopts from the first argument on every call.
  local OPTIND=1

  # The leading ':' selects silent error reporting: getopts sets opt to '?'
  # for unknown options and to ':' for a missing argument, with the
  # offending option letter in OPTARG.
  while getopts ":hw:e:b:s:t:l:o:v:z:f:m:" opt; do
    case "$opt" in
      h)
        {
          printf 'Options:\n'
          printf '\t -%s %s\n' \
            w window-size \
            e num-epochs \
            b batch-size \
            s shift \
            t stride \
            l loss \
            o optimizer \
            v verbosity \
            f save-from \
            m model \
            z shuffle-buffer
        } >&2
        exit 1
        ;;
      w) params+=(--window-size "$OPTARG") ;;
      e) params+=(--num-epochs "$OPTARG") ;;
      b) params+=(--batch-size "$OPTARG") ;;
      s) params+=(--shift "$OPTARG") ;;
      t) params+=(--stride "$OPTARG") ;;
      l) params+=(--loss "$OPTARG") ;;
      # -o was accepted by the option string and listed in the help text but
      # never forwarded. NOTE(review): the long flag name follows the pattern
      # of the other options; confirm the trainer accepts --optimizer.
      o) params+=(--optimizer "$OPTARG") ;;
      v) params+=(--verbosity "$OPTARG") ;;
      z) params+=(--shuffle-buffer "$OPTARG") ;;
      f) params+=(--save-from "$OPTARG") ;;
      m) params+=(--model "$OPTARG") ;;
      :)
        echo "Option -$OPTARG requires an argument" >&2
        exit 1
        ;;
      \?)
        echo "Invalid option: -$OPTARG" >&2
        exit 1
        ;;
    esac
  done
}

parse_args "$@"
# [*] joins the array into one word for display; the output is the same as
# before with the default IFS.
echo "PARAMS ${params[*]}"
# issue train command to gcloud
# Everything before the bare '--' is an option to gcloud itself; everything
# after it is passed as a user-defined argument.
# All expansions are quoted: the TFRecord values contain '*', and unquoted
# they would be subject to local pathname expansion and word splitting
# before gcloud ever saw them.
gcloud ai-platform jobs submit training "$JOB_NAME" \
  --job-dir "$JOB_DIR" \
  --staging-bucket "$PACKAGE_STAGING_PATH" \
  --package-path "$PACKAGE_PATH" \
  --module-name "$MODULE_NAME" \
  --region "$REGION" \
  --python-version 3.5 \
  --runtime-version 1.13 \
  --config "$CONFIG_FILE" \
  --stream-logs \
  -- \
  --data-dir-train "$TFRECORDS_DIR_TRAIN" \
  --data-dir-validate "$TFRECORDS_DIR_VALIDATE" \
  --tboard-dir "$JOB_RUN_DIR" \
  "${params[@]}"