forked from patrickmineault/spikefinder_submission
-
Notifications
You must be signed in to change notification settings - Fork 0
/
eval_resnet.py
70 lines (59 loc) · 2.33 KB
/
eval_resnet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
"""Predict spikes from calcium recordings using pretrained models.

Writes train and test predictions as CSV files under ``preds/``.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import collections
import functools
import numpy as np
import os
import pandas as pd
import shutil
import tensorflow as tf
from tensorflow.contrib import learn
from tensorflow.contrib.learn.python.learn.estimators import model_fn as model_fn_lib
from tensorflow.contrib.layers.python.layers import optimizers as optimizers_lib
from tensorflow import summary
from train_resnet import get_relevant_data, get_spike_classifier
from util.nnio import eval_and_save, load_all, split_data, pad_to_batch_size
from util.nncomponents import dense_batch_relu, summarize_layer
import resnet
config = resnet.get_config()


def _load_padded_split(split, cfg):
    """Load one data split and pad it with ``pad_to_batch_size``.

    Args:
        split: Split name forwarded to ``get_relevant_data`` (``main`` uses
            ``'train'`` and ``'test'``).
        cfg: Configuration dict; ``batch_size`` and ``N_neurons`` are read
            here, and the whole dict is passed to ``get_relevant_data``.

    Returns:
        ``(data, dataset_map)``: the padded data and the dataset map returned
        by ``get_relevant_data``. Labels are not needed for prediction, so
        they are discarded after padding.
    """
    data, labels, dataset_map = get_relevant_data(split, cfg)
    # Labels must be passed to pad_to_batch_size alongside the data, but only
    # the padded data is used afterwards.
    data, _ = pad_to_batch_size(data,
                                labels,
                                cfg['batch_size'],
                                cfg['N_neurons'])
    return data, dataset_map


def main(unused_argv):
    """Run the spike classifier on the train and test splits and save results.

    Predictions are written by ``eval_and_save`` into
    ``preds/<model_name with every '/' replaced by '_'>``. The suffix
    ``_refined`` is appended when ``config['refine_recording']`` is set.

    Args:
        unused_argv: Leftover command-line arguments passed in by
            ``tf.app.run``; ignored.
    """
    # Load the train split BEFORE building the classifier: get_relevant_data
    # receives `config` and may update it (e.g. N_neurons) -- TODO confirm
    # before reordering these calls.
    train_data, train_dataset_map = _load_padded_split('train', config)
    spike_classifier = get_spike_classifier(config)

    # Flatten the model name so that nested names map to a single directory.
    output_dir = os.path.join('preds', config['model_name'].replace('/', '_'))
    if config['refine_recording'] is not None:
        output_dir += '_refined'

    eval_and_save(spike_classifier,
                  train_data,
                  train_dataset_map,
                  output_dir,
                  '%d.train.spikes.csv',
                  config['batch_size'])

    test_data, test_dataset_map = _load_padded_split('test', config)
    eval_and_save(spike_classifier,
                  test_data,
                  test_dataset_map,
                  output_dir,
                  '%d.test.spikes.csv',
                  config['batch_size'])
if __name__ == "__main__":
    # Command-line entry point: optional flags override entries of the
    # module-level `config` before TensorFlow dispatches to `main`.
    cli = argparse.ArgumentParser()
    cli.add_argument('--model_name', type=str)
    cli.add_argument('--refine_recording', type=int)
    parsed = vars(cli.parse_args())
    # Flags left unset (None) keep whatever value `config` already holds.
    config.update({flag: value
                   for flag, value in parsed.items()
                   if value is not None})
    tf.app.run()