-
Notifications
You must be signed in to change notification settings - Fork 0
/
extract_features.py
142 lines (88 loc) · 3.69 KB
/
extract_features.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
from flim.experiments import utils
from model.model import get_device, _layers_before_downscale, maybe_resize2d
import torch
import click
import matplotlib.pyplot as plt
from collections import OrderedDict
import functools
import os
import shutil
def rgetattr(obj, attr, *args):
    """Fetch a dotted attribute path from *obj* (e.g. ``"layer1.conv"``).

    Behaves like :func:`getattr` applied level by level; an optional
    default (``*args``) is used at every level, exactly as in getattr.
    Raises AttributeError if a segment is missing and no default given.
    """
    current = obj
    for segment in attr.split('.'):
        current = getattr(current, segment, *args)
    return current
class IntermediateLayerGetter:
    """Capture the outputs of selected sub-modules during a forward pass.

    ``return_layers`` maps dotted attribute paths into *model* (resolved
    with ``rgetattr``) to the key under which that layer's output is
    stored in the returned ``OrderedDict``.
    """
    def __init__(self, model, return_layers):
        self._model = model
        self._return_layers = return_layers
    def __call__(self, x):
        """Run ``model(x)`` and return an OrderedDict of captured outputs."""
        outputs = OrderedDict()
        handles = []
        try:
            for name, out_name in self._return_layers.items():
                layer = rgetattr(self._model, name)
                # out_name bound as a default argument so each hook keeps
                # its own name (avoids the late-binding closure pitfall).
                def hook(module, input, output, out_name=out_name):
                    outputs[out_name] = output
                handle = layer.register_forward_hook(hook)
                handles.append(handle)
            self._model(x)
        finally:
            # Always detach the hooks — even if the forward pass raises —
            # so the model is not left permanently instrumented.
            for handle in handles:
                handle.remove()
        return outputs
def save_features(out_channels, output_dir, imshape):
    """Save per-block feature maps under ``output_dir``.

    For each entry in ``out_channels`` (dicts with keys ``"block"`` and
    ``"out"``, as produced by ``forward_encoder``) a sub-directory named
    after the block is created containing one grayscale PNG per channel,
    plus a ``<block>.mimg`` multi-band image next to it (best effort).
    An existing ``output_dir`` is wiped and recreated.
    """
    # Start from an empty output directory.
    if os.path.isdir(output_dir):
        shutil.rmtree(output_dir, ignore_errors=True)
    # makedirs (vs mkdir) also creates any missing parent directories.
    os.makedirs(output_dir)
    for block in out_channels:
        blockname = block['block']
        feats = block['out']
        innerdir = os.path.join(output_dir, blockname)
        os.makedirs(innerdir)
        # Resize features to the input image's spatial size and drop the
        # batch dimension -> (channels, H, W) numpy array.
        feats = maybe_resize2d(feats, imshape).squeeze(0).numpy()
        # Best-effort save of the multi-band .mimg; the per-channel PNG
        # export below still runs if this fails.
        try:
            utils.save_mimage(innerdir + ".mimg", feats.transpose(1, 2, 0))
        except Exception as e:
            print('An exception occurred when saving mimage: {}'.format(e))
        nfeats = feats.shape[0]
        for i in range(nfeats):
            output_img = os.path.join(innerdir, str(i) + ".png")
            plt.imsave(output_img, feats[i, :, :], cmap='gray')
def forward_encoder(encoder, x):
    """Run ``x`` through ``encoder`` and collect each block's output.

    Returns a list of ``{"block": short_name, "out": cpu_tensor}`` dicts,
    ordered from the deepest block (labelled ``"bottleneck"``) outward.
    """
    block_names, _ = _layers_before_downscale(encoder)
    # Keep each block under its own name, but label the last (deepest)
    # one "bottleneck".
    rename = {n: n for n in block_names[:-1]}
    rename[block_names[-1]] = "bottleneck"
    captured = IntermediateLayerGetter(encoder, rename)(x)
    collected = []
    for name in reversed(captured.keys()):
        parts = name.split('.')
        # Use the second path segment when the layer sits inside a
        # container (e.g. "features.block1" -> "block1").
        short = parts[1] if len(parts) > 1 else parts[0]
        collected.append({"block": short, "out": captured[name].detach().cpu()})
    return collected
@click.command()
@click.option('--arch-path', '-a', required=True, type=str, help='Architecture json description file')
@click.option('--input-image', '-i', required=True, type=str, help='Input .png image')
@click.option('--output-dir', '-o', required=True, type=str, help='Output features dir')
@click.option('--model', '-m', default='encoder.pt', type=str, help='Input encoder model, default=encoder.pt')
def main(arch_path, model, input_image, output_dir):
    """Extract intermediate encoder features for one image and save them."""
    arch = utils.load_architecture(arch_path)
    encoder = utils.build_model(arch, input_shape=[3])
    # map_location='cpu' so checkpoints saved on a GPU machine still load
    # on CPU-only hosts; the pipeline below runs on CPU tensors anyway.
    checkpoint = torch.load(model, map_location='cpu')
    encoder.load_state_dict(checkpoint['encoder_state_dict'])
    # HWC image -> NCHW float tensor, the layout the encoder expects.
    image = torch.tensor(utils.load_image(input_image))
    image = image.unsqueeze(0).permute(0, 3, 1, 2).float()
    out_channels = forward_encoder(encoder, image)
    save_features(out_channels, output_dir, image.shape)
    print(f"Done. All images saved to {output_dir}")


if __name__ == '__main__':
    main()