-
-
Notifications
You must be signed in to change notification settings - Fork 18
/
delete_bad_nets.py
151 lines (128 loc) · 5.41 KB
/
delete_bad_nets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import sys
import re
import os
import itertools
def parse_ordo(ordo_filename):
    """Collect (net_name, rating, error) tuples from an ordo results file.

    Only lines containing the substring 'nn-epoch' are used. Each such line
    is split on whitespace and read positionally: column 1 is the net name,
    column 3 the rating and column 4 the error (presumably ordo's
    "# PLAYER : RATING ERROR ..." table layout -- confirm for other ordo
    versions). Tuples are returned in file order.

    Raises:
        IndexError: a matching line has fewer than five columns.
        ValueError: the rating or error column is not a float.
    """
    with open(ordo_filename, 'r') as ordo_file:
        contents = ordo_file.readlines()
    scores = []
    for row in contents:
        if 'nn-epoch' not in row:
            continue
        columns = row.split()
        scores.append((columns[1], float(columns[3]), float(columns[4])))
    return scores
def find_ckpt_files(root_dir):
    """Return paths of all files under root_dir whose name matches '.*\\.ckpt'.

    The search is recursive (os.walk) and does not follow symlinked
    directories. The pattern is applied with re.match, so it is anchored at
    the start of the file name only; a name merely containing '.ckpt'
    followed by more text also matches. Each result is os.path.join of the
    walked directory and the file name. A missing root_dir yields [].
    """
    ckpt_pattern = re.compile('.*\\.ckpt')
    return [
        os.path.join(dirpath, name)
        for dirpath, _dirnames, filenames in os.walk(root_dir, followlinks=False)
        for name in filenames
        if ckpt_pattern.match(name)
    ]
def find_nnue_files(root_dir):
    """Return paths of all files under root_dir whose name matches '.*\\.nnue'.

    The search is recursive (os.walk) and does not follow symlinked
    directories. The pattern is applied with re.match, so it is anchored at
    the start of the file name only. Each result is os.path.join of the
    walked directory and the file name. A missing root_dir yields [].
    """
    nnue_pattern = re.compile('.*\\.nnue')
    return [
        os.path.join(dirpath, name)
        for dirpath, _dirnames, filenames in os.walk(root_dir, followlinks=False)
        for name in filenames
        if nnue_pattern.match(name)
    ]
def get_net_dir(net_path):
    """Return the directory component of net_path ('' if it has none)."""
    head, _tail = os.path.split(net_path)
    return head
def split_nets_by_strength(nets, split_point=16):
    """Partition rated nets into the strongest ``split_point`` and the rest.

    Args:
        nets: iterable of (net_name, rating, error) tuples, e.g. as returned
            by parse_ordo. The caller's sequence is not modified.
        split_point: number of highest-rated nets to put in ``best_nets``.
            Values larger than the number of nets keep all of them.

    Returns:
        (best_nets, worst_nets): two lists, each ordered by descending
        rating. Nets with equal rating keep their input order (stable sort).
    """
    # sorted() rather than list.sort() so the caller's list is left as-is.
    ranked = sorted(nets, key=lambda net: -net[1])
    cut = min(split_point, len(ranked))
    return ranked[:cut], ranked[cut:]
def get_nets_by_directory(best_nets, worst_nets, num_best_to_keep=16):
    """Group best and worst net names by their parent directory.

    Both returned dicts share the same keys: every directory (via
    get_net_dir) holding at least one net from either input, in first-seen
    order with best nets scanned before worst nets. Each value lists the
    names from the respective input that live in that directory; a
    directory with no net from that input maps to an empty list.

    Args:
        best_nets, worst_nets: sequences of (net_name, rating, error) tuples.
        num_best_to_keep: accepted for interface compatibility; unused here.

    Returns:
        (binned_best_nets, binned_worst_nets) as described above.
    """
    every_dir = [
        get_net_dir(name)
        for name, _rating, _error in itertools.chain(best_nets, worst_nets)
    ]
    # Pre-seed both dicts with every directory so lookups below never miss.
    by_dir_best = {directory: [] for directory in every_dir}
    by_dir_worst = {directory: [] for directory in every_dir}
    for name, _rating, _error in worst_nets:
        by_dir_worst[get_net_dir(name)].append(name)
    for name, _rating, _error in best_nets:
        by_dir_best[get_net_dir(name)].append(name)
    return by_dir_best, by_dir_worst
def _delete_files_for_epochs(file_paths, epoch_pattern, doomed_epochs):
    """Remove each file whose epoch (group 1 of epoch_pattern) is in doomed_epochs.

    Prints 'Delete <path>' before removing a file and 'Keep <path>' for
    every other file, including names that epoch_pattern does not match.
    A failed os.remove is reported and does not abort the loop.
    """
    for path in file_paths:
        match = epoch_pattern.match(path)
        if match is None or match[1] not in doomed_epochs:
            print('Keep {}'.format(path))
            continue
        print('Delete {}'.format(path))
        try:
            os.remove(path)
        except OSError as err:
            print('Failed to delete {}: {}'.format(path, err))


def delete_bad_nets(root_dir, num_best_to_keep=16):
    """Delete .ckpt and .nnue files of nets outside the top ``num_best_to_keep``.

    Ratings are read from ``root_dir/ordo.out`` via parse_ordo. Nets are
    ranked by rating; the weaker ones are grouped by the directory part of
    their name, and in each such directory (searched recursively) every
    .ckpt / .nnue file whose epoch number equals a weak net's epoch is
    removed. Files are matched by epoch number only.

    Args:
        root_dir: directory that must contain ``ordo.out``.
        num_best_to_keep: how many top-rated nets to keep; anything int()
            accepts (e.g. a string from the command line) is allowed.

    Prints 'No ordo file found. Exiting.' and deletes nothing when
    ``root_dir/ordo.out`` does not exist.
    """
    net_epoch_p = re.compile(".*epoch([0-9]*)\\.nnue")
    ckpt_epoch_p = re.compile(".*epoch=([0-9]*).*\\.ckpt")
    ordo_filename = os.path.join(root_dir, "ordo.out")
    if not os.path.exists(ordo_filename):
        print('No ordo file found. Exiting.')
        return

    # split_nets_by_strength compares this with len(nets), so it must be int.
    num_best_to_keep = int(num_best_to_keep)
    nets = parse_ordo(ordo_filename)
    best_nets, worst_nets = split_nets_by_strength(nets, num_best_to_keep)
    _, worst_nets_by_dir = get_nets_by_directory(best_nets, worst_nets, num_best_to_keep)

    for basedir, worst_nets_in_dir in worst_nets_by_dir.items():
        # NOTE(review): basedir is os.path.dirname of the name in ordo.out and
        # is resolved relative to the current working directory, not root_dir;
        # a name with no directory gives '' and os.walk('') finds nothing.
        # Confirm the script is run from where the ordo names are relative to.
        worst_epochs = set()
        for net_name in worst_nets_in_dir:
            match = net_epoch_p.match(net_name)
            if match is None:
                print('Cannot determine epoch of {}; skipping it.'.format(net_name))
                continue
            worst_epochs.add(match[1])

        _delete_files_for_epochs(find_ckpt_files(basedir), ckpt_epoch_p, worst_epochs)
        _delete_files_for_epochs(find_nnue_files(basedir), net_epoch_p, worst_epochs)
def show_help():
    """Print command-line usage and a description of how nets are selected.

    The text states that ordo.out must sit directly in root_dir, matching
    the os.path.join(root_dir, "ordo.out") lookup done by delete_bad_nets.
    """
    print('Usage: python delete_bad_nets.py root_dir [num_best_to_keep]')
    print('root_dir - the directory to "cleanup"')
    print('num_best_to_keep - the number of best nets to keep. Default: 16')
    print('')
    print('It expects to find ordo.out directly in root_dir (root_dir/ordo.out).')
    print('If the ordo.out is not found nothing is deleted.')
    print('It uses the ratings from the ordo file to determine which nets are best.')
    print('The engine names must contain the network name in the')
    print('following format: "nn-epoch[0-9]*\\.nnue". The network file')
    print('can be specified with a parent directory (for example')
    print('"run_0/nn-epoch100.nnue"), in which case the .ckpt file corresponding')
    print('to this .nnue file will only be searched for in the parent ("run_0") directory.')
    print('The .ckpt files must contain "epoch=([0-9]*).*\\.ckpt".')
    print('Both ckpt and nnue files are deleted. Only nets listed in the ordo')
    print('file can be deleted. Other nets are always kept.')
    print('The .nnue and .ckpt files are matched by epoch.')
    print('')
    print('The directory layout can be for example:')
    print('- root_dir')
    print(' - run_0')
    print(' - a/b/c/d.ckpt')
    print(' - *.nnue')
    print(' - run_1')
    print(' - a/b/c/d.ckpt')
    print(' - *.nnue')
    print(' - ordo.out')
    print(' (in this case only lines with engine name matching')
    print(' "run_[01]/nn-epoch[0-9]*\\.nnue" will be used.)')
def main():
    """Command-line entry point: delete_bad_nets.py root_dir [num_best_to_keep].

    Prints usage when root_dir is missing. Exits with an error message when
    num_best_to_keep is not a non-negative integer.
    """
    if len(sys.argv) < 2:
        show_help()
        return
    root_dir = sys.argv[1]
    num_best_to_keep = 16
    if len(sys.argv) >= 3:
        # sys.argv entries are strings; the count is compared with len(nets)
        # further down, so it has to be converted to int here.
        try:
            num_best_to_keep = int(sys.argv[2])
        except ValueError:
            sys.exit('num_best_to_keep must be an integer, got {!r}'.format(sys.argv[2]))
        if num_best_to_keep < 0:
            sys.exit('num_best_to_keep must be non-negative, got {}'.format(num_best_to_keep))
    delete_bad_nets(root_dir, num_best_to_keep)


if __name__ == '__main__':
    main()