-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgenotypes.py
103 lines (86 loc) · 3.11 KB
/
genotypes.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
""" Genotypes
- Genotype: normal/reduce gene + normal/reduce cell output connection (concat)
- gene: discrete ops information (w/o output connection)
- dag: real ops (can be mixed or discrete, but Genotype has only discrete information itself)
"""
from collections import namedtuple
import torch
import torch.nn as nn
from models import ops
# One record per searched architecture.
#   normal / reduce               : discrete genes of the normal and reduction
#                                   cells -- one list of (op_name, input_idx)
#                                   edges per node, in the format parse() returns.
#   normal_concat / reduce_concat : node indices (e.g. range(2, 6)) whose outputs
#                                   are concatenated into the cell output.
Genotype = namedtuple(
    'Genotype',
    ['normal', 'normal_concat', 'reduce', 'reduce_concat'],
)

# Candidate operations of the search space. Their order matches the last
# (op) axis of each alpha tensor that parse() reads.
PRIMITIVES = [
    'max_pool_3x3',
    'avg_pool_3x3',
    'skip_connect',  # identity
    'sep_conv_3x3',
    'sep_conv_5x5',
    'dil_conv_3x3',
    'dil_conv_5x5',
    'none',  # must stay last: parse() drops the final alpha column
]
def to_dag(C_in, gene, reduction):
    """ Build concrete (discrete) op modules from a gene.

    Args:
        C_in: channel count handed to every op factory in ``ops.OPS``.
        gene: list with one entry per node; each entry is a list of
            ``(op_name, s_idx)`` edges, where ``s_idx`` is the source index.
        reduction: True when building a reduction cell.

    Returns:
        nn.ModuleList holding one nn.ModuleList per node. Each edge module
        carries an ``s_idx`` attribute equal to its source index. Any op that
        is not an ``ops.Identity`` is wrapped as
        ``nn.Sequential(op, ops.DropPath_())``.
    """
    def build_edge(op_name, s_idx):
        # In a reduction cell, edges leaving the two input nodes (s_idx 0/1)
        # are built with stride 2; every other edge keeps stride 1.
        if reduction and s_idx < 2:
            stride = 2
        else:
            stride = 1
        # Third positional arg is passed as True as in the original
        # (presumably `affine` -- confirm in models.ops).
        module = ops.OPS[op_name](C_in, stride, True)
        # Identity edges are left unwrapped: they do not use drop path.
        if not isinstance(module, ops.Identity):
            module = nn.Sequential(module, ops.DropPath_())
        module.s_idx = s_idx
        return module

    return nn.ModuleList(
        nn.ModuleList(build_edge(op_name, s_idx) for op_name, s_idx in edges)
        for edges in gene
    )
def from_str(s):
    """ Generate a Genotype from its string representation.

    e.g. "Genotype(
        normal=[[('sep_conv_3x3', 0), ('sep_conv_3x3', 1)],
                [('sep_conv_3x3', 1), ('dil_conv_3x3', 2)],
                [('sep_conv_3x3', 1), ('sep_conv_3x3', 2)],
                [('sep_conv_3x3', 1), ('dil_conv_3x3', 4)]],
        normal_concat=range(2, 6),
        reduce=[[('max_pool_3x3', 0), ('max_pool_3x3', 1)],
                [('max_pool_3x3', 0), ('skip_connect', 2)],
                [('max_pool_3x3', 0), ('skip_connect', 2)],
                [('max_pool_3x3', 0), ('skip_connect', 2)]],
        reduce_concat=range(2, 6))"

    Args:
        s: text in the format produced by ``str(genotype)``.

    Returns:
        The evaluated Genotype.

    Only ``Genotype``, ``range``, ``list`` and ``tuple`` are visible while
    evaluating, and builtins are disabled. Arbitrary module globals and calls
    such as ``__import__(...)`` therefore raise ``NameError`` instead of
    executing.
    """
    # NOTE(security): this is still `eval`. Emptying `__builtins__` blocks the
    # obvious attacks but is NOT a real sandbox (attribute tricks on literals can
    # escape it). Never pass strings that come from an untrusted source.
    namespace = {
        '__builtins__': {},
        'Genotype': Genotype,
        'range': range,
        'list': list,
        'tuple': tuple,
    }
    genotype = eval(s, namespace)
    return genotype
def parse(alpha, k, primitives=None):
    """
    Parse continuous architecture weights (alpha) into a discrete gene.

    alpha is an iterable (e.g. ParameterList) of per-node weight tensors:
        ParameterList [
            Parameter(n_edges1, n_ops),
            Parameter(n_edges2, n_ops),
            ...
        ]
    gene is a list with one entry per node:
        [
            [('node1_ops_1', node_idx), ..., ('node1_ops_k', node_idx)],
            [('node2_ops_1', node_idx), ..., ('node2_ops_k', node_idx)],
            ...
        ]
    Each node keeps k edges (k=2 in CNN).

    Procedure:
      1) Reduce each mixed edge to a single op by taking its top-1 weight,
         ignoring the trailing 'none' column.
      2) Keep the k edges of each node with the highest such top-1 weight.

    Args:
        alpha: iterable of 2-D tensors shaped (n_edges, n_ops).
        k: number of edges kept per node. It must not exceed n_edges for any
            node, otherwise torch.topk raises.
        primitives: op names, indexed like the last axis of every alpha
            tensor. Defaults to PRIMITIVES. The last entry must be 'none'.

    Returns:
        list (one per node) of lists of ``(op_name, edge_idx)`` tuples, where
        ``op_name`` is a str and ``edge_idx`` a plain int (the row index of
        the edge in that node's tensor). Edges are ordered by descending score.

    Raises:
        ValueError: if the last primitive is not 'none'.
    """
    if primitives is None:
        primitives = PRIMITIVES
    # The final alpha column is dropped below, so it must be the 'none' op.
    if not primitives or primitives[-1] != 'none':
        raise ValueError("the last primitive must be 'none' (it is ignored when parsing)")

    gene = []
    for edges in alpha:
        # edges: Tensor(n_edges, n_ops); slice off the trailing 'none' column.
        edge_max, primitive_indices = torch.topk(edges[:, :-1], 1)
        # (n_edges, 1) -> plain ints, so primitives are indexed with ints
        # rather than 1-element tensors.
        best_prim = primitive_indices.view(-1).tolist()
        _, topk_edge_indices = torch.topk(edge_max.view(-1), k)
        node_gene = [
            (primitives[best_prim[edge_idx]], edge_idx)
            for edge_idx in topk_edge_indices.tolist()
        ]
        gene.append(node_gene)
    return gene