-
Notifications
You must be signed in to change notification settings - Fork 30
/
yolo_train.py
46 lines (36 loc) · 1.1 KB
/
yolo_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import torch.cuda
import yolo.yolov5.train as yolo_train
def train_yolo(**kwargs):
    """Launch YOLOv5 training.

    Every keyword argument is forwarded unchanged to ``get_yolo_params``;
    the dictionary it builds is then unpacked into ``yolo_train.run``.
    Nothing is returned.
    """
    yolo_train.run(**get_yolo_params(**kwargs))
def get_yolo_params(name, epochs=50, project='models', weights='yolov5s',
                    evolve=0, device='cuda:0', data='yolo/sp_dataset.yaml',
                    imgsz=256, batch_size=16, workers=4):
    """Build the keyword-argument dict passed to ``yolo_train.run``.

    Args:
        name: Run name, stored under the ``'name'`` key. ``'_hyp'`` is
            appended when ``evolve`` is truthy.
        epochs: Number of training epochs.
        project: Value for the ``'project'`` key.
        weights: Initial weights, with or without the ``.pt`` suffix
            (e.g. ``'yolov5s'`` or ``'models/run/weights/best.pt'``). The
            suffix is appended only when missing. An empty string is passed
            through unchanged.
        evolve: When truthy, stored under ``'evolve'`` (presumably the
            number of YOLOv5 hyperparameter-evolution generations; verify
            against yolov5 ``train.py``). When falsy, the key is omitted.
        device: Device string, e.g. ``'cuda:0'`` or ``'cpu'``.
        data: Path to the dataset YAML file.
        imgsz: Training image size.
        batch_size: Training batch size.
        workers: Number of dataloader workers.

    Returns:
        dict: Keyword arguments for ``yolo_train.run``.
    """
    # Previously a path already ending in '.pt' was replaced by '' (because
    # the conditional expression bound the whole `weights + '.pt'`), which
    # silently discarded the requested checkpoint.
    if weights and not weights.endswith('.pt'):
        weights += '.pt'

    yolo_parameters = {
        'data': data,
        'weights': weights,
        'imgsz': imgsz,
        'batch_size': batch_size,
        'workers': workers,
        'project': project,
        'name': name,
        'epochs': epochs,
        'device': device,
    }
    if evolve:
        yolo_parameters['evolve'] = evolve
        # Keep evolution runs distinguishable from plain training runs.
        yolo_parameters['name'] += '_hyp'
    return yolo_parameters
def main():
    """Run a short 10-epoch yolov5s test training.

    Uses ``cuda:0`` when ``torch.cuda.is_available()`` reports a GPU,
    otherwise ``cpu``. Results go to the ``models`` project under the run
    name ``yolov5s_10_test``.
    """
    run_device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    model = 'yolov5s'
    epochs = 10
    # Add evolve=300 to the call below to run hyperparameter evolution.
    train_yolo(
        name=f'{model}_{epochs}_test',
        epochs=epochs,
        project='models',
        weights=model,
        device=run_device,
    )


if __name__ == '__main__':
    main()