'''
This script runs trained CNN models to extract features from either the DHS or
LSMS satellite images.

Usage:
    python extract_features.py

Note: this script does not take any command-line options. Instead, set the
parameters in the "Parameters" section below.

Prerequisites:
1) Download TFRecords, process them, and create incountry folds. See
   `preprocessing/1_process_tfrecords.ipynb` and
   `preprocessing/2_create_incountry_folds.ipynb`.
2) Either train models (see README.md for instructions), or download model
   checkpoints into the outputs/ directory using the checkpoint download
   script `preprocessing/4_download_model_checkpoints.sh`.
'''
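
# Expected directory layout (illustrative sketch; the exact checkpoint file
# names depend on how the models were trained or downloaded):
#
#   outputs/
#     dhs_ooc/DHS_OOC_A_ms_samescaled_b64_fc01_conv01_lr0001/
#       params.json                  <- read by read_params_json() below
#       (TensorFlow checkpoint files)
#
# After a successful run, each model directory additionally contains a
# features.npz file (see save_filename in main() below).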
from __future__ import annotations

from collections import defaultdict
from collections.abc import Callable, Iterable
import json
import os
from typing import Optional

import numpy as np
import tensorflow as tf

from batchers import batcher, tfrecord_paths_utils
from models.resnet_model import Hyperspectral_Resnet
from utils.run import check_existing, run_extraction_on_models


OUTPUTS_ROOT_DIR = 'outputs'

# ====================
# Parameters
# ====================
BATCH_SIZE = 128
KEEP_FRAC = 1.0
IS_TRAINING = False

# set CACHE = True for faster feature extraction when running multiple models,
# but only if you have enough RAM (>= 50 GB)
CACHE = False
DHS_MODELS: list[str] = [
    # put paths to DHS models here (relative to OUTPUTS_ROOT_DIR)
    'dhs_ooc/DHS_OOC_A_ms_samescaled_b64_fc01_conv01_lr0001',
    'dhs_ooc/DHS_OOC_B_ms_samescaled_b64_fc001_conv001_lr0001',
    'dhs_ooc/DHS_OOC_C_ms_samescaled_b64_fc001_conv001_lr001',
    'dhs_ooc/DHS_OOC_D_ms_samescaled_b64_fc001_conv001_lr01',
    'dhs_ooc/DHS_OOC_E_ms_samescaled_b64_fc01_conv01_lr001',
    'dhs_ooc/DHS_OOC_A_nl_random_b64_fc1.0_conv1.0_lr0001',
    'dhs_ooc/DHS_OOC_B_nl_random_b64_fc1.0_conv1.0_lr0001',
    'dhs_ooc/DHS_OOC_C_nl_random_b64_fc1.0_conv1.0_lr0001',
    'dhs_ooc/DHS_OOC_D_nl_random_b64_fc1.0_conv1.0_lr01',
    'dhs_ooc/DHS_OOC_E_nl_random_b64_fc1.0_conv1.0_lr0001',
    'dhs_ooc/DHS_OOC_A_rgb_same_b64_fc001_conv001_lr01',
    'dhs_ooc/DHS_OOC_B_rgb_same_b64_fc001_conv001_lr0001',
    'dhs_ooc/DHS_OOC_C_rgb_same_b64_fc001_conv001_lr0001',
    'dhs_ooc/DHS_OOC_D_rgb_same_b64_fc1.0_conv1.0_lr01',
    'dhs_ooc/DHS_OOC_E_rgb_same_b64_fc001_conv001_lr0001',
    'dhs_incountry/DHS_Incountry_A_ms_samescaled_b64_fc01_conv01_lr001',
    'dhs_incountry/DHS_Incountry_A_nl_random_b64_fc1.0_conv1.0_lr0001',
    'dhs_incountry/DHS_Incountry_B_ms_samescaled_b64_fc1_conv1_lr001',
    'dhs_incountry/DHS_Incountry_B_nl_random_b64_fc1.0_conv1.0_lr0001',
    'dhs_incountry/DHS_Incountry_C_ms_samescaled_b64_fc1.0_conv1.0_lr0001',
    'dhs_incountry/DHS_Incountry_C_nl_random_b64_fc1.0_conv1.0_lr0001',
    'dhs_incountry/DHS_Incountry_D_ms_samescaled_b64_fc001_conv001_lr0001',
    'dhs_incountry/DHS_Incountry_D_nl_random_b64_fc1.0_conv1.0_lr0001',
    'dhs_incountry/DHS_Incountry_E_ms_samescaled_b64_fc001_conv001_lr0001',
    'dhs_incountry/DHS_Incountry_E_nl_random_b64_fc01_conv01_lr001',

    # put paths to DHSNL models here (for transfer learning)
    # - NOTE: when extracting features from transfer-learning models, set
    #   MODEL_PARAMS['num_outputs'] = 2, because these models output
    #   predictions for both DMSP and VIIRS nightlight intensities
    #   (see the sketch after MODEL_PARAMS below)
    'transfer/transfer_nlcenter_ms_b64_fc001_conv001_lr0001',
    'transfer/transfer_nlcenter_rgb_b64_fc001_conv001_lr0001',

    # get paths for DHS OOC keep-frac models
    # TODO
]
LSMS_MODELS: list[str] = [
    # put paths to LSMS models here (relative to OUTPUTS_ROOT_DIR)
    # TODO
]

# choose which GPU to run on
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

MODEL_PARAMS = {
    'fc_reg': 5e-3,    # value is irrelevant here: regularization only affects training
    'conv_reg': 5e-3,  # value is irrelevant here: regularization only affects training
    'num_layers': 18,
    'num_outputs': 1,  # set to 2 for transfer-learning (DHSNL) models
    'is_training': IS_TRAINING,
}
# ====================
# End Parameters
# ====================


def get_model_class(model_arch: str) -> Callable:
    if model_arch == 'resnet':
        model_class = Hyperspectral_Resnet
    else:
        raise ValueError('Unknown model_arch. Currently only "resnet" is supported.')
    return model_class


def get_batcher(dataset: str, ls_bands: Optional[str], nl_band: Optional[str],
                num_epochs: int, cache: bool
                ) -> tuple[batcher.Batcher, int, dict]:
    '''Gets the batcher for a given dataset.

    Args
    - dataset: str, one of ['dhs', 'lsms']  # TODO
    - ls_bands: one of [None, 'ms', 'rgb']
    - nl_band: one of [None, 'merge', 'split']
    - num_epochs: int, number of epochs to iterate through the dataset
    - cache: bool, whether to cache the dataset in memory if num_epochs > 1

    Returns
    - b: Batcher
    - size: int, number of TFRecord files in the dataset
    - feed_dict: dict, feed_dict for initializing the dataset iterator
    '''
    if dataset == 'dhs':
        tfrecord_paths = tfrecord_paths_utils.dhs()
    elif dataset == 'lsms':  # TODO
        tfrecord_paths = tfrecord_paths_utils.lsms()
    else:
        raise ValueError(f'dataset={dataset} is unsupported')

    size = len(tfrecord_paths)
    tfrecord_paths_ph = tf.placeholder(tf.string, shape=[size])
    feed_dict = {tfrecord_paths_ph: tfrecord_paths}

    if dataset == 'dhs':
        b = batcher.Batcher(
            tfrecord_files=tfrecord_paths_ph,
            label_name='wealthpooled',
            ls_bands=ls_bands,
            nl_band=nl_band,
            nl_label=None,
            batch_size=BATCH_SIZE,
            epochs=num_epochs,
            normalize='DHS',
            shuffle=False,
            augment=False,
            clipneg=True,
            cache=(num_epochs > 1) and cache,
            num_threads=5)
    else:  # LSMS, TODO
        raise NotImplementedError
        # b = delta_batcher.DeltaBatcher()

    return b, size, feed_dict
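
# Example call (illustration only), mirroring how main() uses this function
# for a single model with DHS multispectral imagery:
#
#   b, size, feed_dict = get_batcher(
#       dataset='dhs', ls_bands='ms', nl_band=None, num_epochs=1, cache=False)
#   batches_per_epoch = int(np.ceil(size / BATCH_SIZE))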


def read_params_json(model_dir: str, keys: Iterable[str]) -> tuple:
    '''Reads the requested keys from the JSON file at `model_dir/params.json`.

    Args
    - model_dir: str, path to a model output directory containing a params.json file
    - keys: list of str, keys to read from the JSON file

    Returns: tuple of values, one per key, with None for any missing key
    '''
    json_path = os.path.join(model_dir, 'params.json')
    with open(json_path, 'r') as f:
        params = json.load(f)
    for k in keys:
        if k not in params:
            print(f'Did not find key "{k}" in {model_dir}/params.json. Setting to None.')
    result = tuple(params.get(k, None) for k in keys)
    return result
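
# Example (hypothetical file contents): if outputs/some_model/params.json
# contains {"ls_bands": "ms", "nl_band": null, "model_name": "resnet"}, then
#   read_params_json('outputs/some_model', ['ls_bands', 'nl_band', 'model_name'])
# returns ('ms', None, 'resnet'). Missing keys are returned as None.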


def main() -> None:
    for model_dirs in [DHS_MODELS, LSMS_MODELS]:
        if not check_existing(model_dirs,
                              outputs_root_dir=OUTPUTS_ROOT_DIR,
                              test_filename='features.npz'):
            print('Stopping')
            return

    # group models by batcher configuration and model_arch, where
    # config = (dataset, ls_bands, nl_band, model_arch)
    all_models = {'dhs': DHS_MODELS, 'lsms': LSMS_MODELS}
    models_by_config: dict[
        tuple[str, Optional[str], Optional[str], str], list[str]
    ] = defaultdict(list)
    for dataset, model_dirs in all_models.items():
        for model_dir in model_dirs:
            ls_bands, nl_band, model_arch = read_params_json(
                model_dir=os.path.join(OUTPUTS_ROOT_DIR, model_dir),
                keys=['ls_bands', 'nl_band', 'model_name'])
            config = (dataset, ls_bands, nl_band, model_arch)
            models_by_config[config].append(model_dir)

    for config, model_dirs in models_by_config.items():
        dataset, ls_bands, nl_band, model_arch = config
        print('====== Current Config: ======')
        print('- dataset:', dataset)
        print('- ls_bands:', ls_bands)
        print('- nl_band:', nl_band)
        print('- model_arch:', model_arch)
        print('- number of models:', len(model_dirs))
        print()

        # one epoch through the data per model sharing this config
        b, size, feed_dict = get_batcher(
            dataset=dataset, ls_bands=ls_bands, nl_band=nl_band,
            num_epochs=len(model_dirs), cache=CACHE)
        batches_per_epoch = int(np.ceil(size / BATCH_SIZE))

        run_extraction_on_models(
            model_dirs,
            ModelClass=get_model_class(model_arch),
            model_params=MODEL_PARAMS,
            batcher=b,
            batches_per_epoch=batches_per_epoch,
            out_root_dir=OUTPUTS_ROOT_DIR,
            save_filename='features.npz',
            batch_keys=['labels', 'locs', 'years'],
            feed_dict=feed_dict)


if __name__ == '__main__':
    main()