-
Notifications
You must be signed in to change notification settings - Fork 16
/
util.py
49 lines (34 loc) · 1.18 KB
/
util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import tensorflow as tf
import numpy as np
@tf.custom_gradient
def spike_function(v_scaled):
    """Heaviside spike nonlinearity with a triangular surrogate gradient.

    Forward pass: emits 1.0 where ``v_scaled > 0`` and 0.0 elsewhere, returned
    as a ``tf.float32`` tensor named "SpikeFunction".

    Backward pass: the step function has zero derivative almost everywhere, so
    the incoming gradient is instead scaled by the pseudo-derivative
    ``max(1 - |v_scaled|, 0)``. It is largest at the threshold (0) and is zero
    for ``|v_scaled| >= 1``.

    Args:
        v_scaled: Real-valued tensor. Presumably a membrane potential already
            shifted/scaled so the firing threshold sits at 0 (TODO confirm
            against callers).

    Returns:
        A ``(spikes, grad_fn)`` pair, as required by ``tf.custom_gradient``.
    """
    z_ = tf.where(v_scaled > 0, tf.ones_like(v_scaled), tf.zeros_like(v_scaled))
    z_ = tf.cast(z_, dtype=tf.float32)

    def grad(dy):
        # dy has the forward output's dtype (always float32), while the
        # surrogate below keeps v_scaled's dtype. Cast dy so the product is
        # well-typed (TF does not auto-promote mixed float dtypes) and the
        # returned gradient matches the input's dtype. No-op for float32.
        dE_dz = tf.cast(dy, v_scaled.dtype)
        dz_dv_scaled = tf.maximum(1 - tf.abs(v_scaled), 0)
        dE_dv_scaled = dE_dz * dz_dv_scaled
        return [dE_dv_scaled]

    return tf.identity(z_, name="SpikeFunction"), grad
def test_print(x):
    """Return a pass-through copy of ``x`` that prints ``x`` when evaluated.

    The ``tf.print`` op is attached as a control dependency of the returned
    identity tensor, so computing the result also runs the print.

    Args:
        x: Tensor to print and pass through unchanged.

    Returns:
        ``tf.identity(x)`` created under a control dependency on the print op.
    """
    op = tf.print(x)
    with tf.control_dependencies([op]):
        # The identity must be of ``x``: the print op is an Operation (or None
        # when executing eagerly), not a tensor, so ``tf.identity(op)`` fails.
        return tf.identity(x)
def print_tensors(model, image_path="automobile.png"):
    """Debug helper: print the max of every tensor in the "i_in" collection.

    Opens a TF1 ``tf.Session``. Reads one image with OpenCV and scales it to
    [0, 1], builds ``model`` on it as a batch of one, initializes all global
    variables, and runs the model output once. It then evaluates the graph
    collection "i_in" while feeding an all-ones array into the tensor named
    "input_1:0", and prints ``np.max`` of each evaluated tensor.

    Args:
        model: Callable applied to a constant image-batch tensor. Presumably a
            Keras model (TODO confirm against callers).
        image_path: Image file passed to ``cv2.imread``. Defaults to the
            previously hard-coded "automobile.png".

    Raises:
        FileNotFoundError: If OpenCV cannot read ``image_path``.
    """
    with tf.Session() as sess:
        import cv2  # local import: only this debug helper needs OpenCV

        raw = cv2.imread(image_path)
        # cv2.imread reports failure by returning None instead of raising;
        # fail with a clear message rather than a "None / float" TypeError.
        if raw is None:
            raise FileNotFoundError(
                "cv2.imread could not read image: {!r}".format(image_path))
        im = np.float32(raw / 255.)
        out = model(tf.constant(im[None, ...]))  # prepend a batch axis of size 1
        sess.run(tf.global_variables_initializer())
        sess.run(out)
        in_ten = tf.get_default_graph().get_tensor_by_name("input_1:0")
        # NOTE(review): the "i_in" collection is populated elsewhere; this
        # function only reads it.
        coll = tf.get_collection("i_in")
        # NOTE(review): the fed shape (1, 32, 32, 3) is hard-coded; presumably
        # it matches the model's input tensor (TODO confirm).
        ret = sess.run(coll, feed_dict={in_ten: np.ones(shape=(1, 32, 32, 3))})
        for r in ret:
            print(np.max(r))
def relu(x):
    """Rectified linear unit: elementwise ``max(0, x)`` via ``np.maximum``."""
    floor = 0
    return np.maximum(floor, x)
def sigmoid(x):
    """Logistic function ``1 / (1 + exp(-x))``, evaluated elementwise."""
    exp_neg = np.exp(-x)
    return 1 / (1 + exp_neg)
def swish(x):
    """Swish activation ``x * sigmoid(x)``, written as ``x / (1 + exp(-x))``."""
    denominator = 1 + np.exp(-x)
    return x / denominator