-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathwavenet_utils.py
42 lines (39 loc) · 2.04 KB
/
wavenet_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import keras.backend as K
from keras.layers import AtrousConvolution1D
# from keras.utils.np_utils import conv_output_length
def categorical_mean_squared_error(y_true, y_pred):
    """Return the mean squared error between argmax class indices.

    Each input is reduced to class indices by taking ``K.argmax`` over its
    last axis. The squared difference of those indices is then averaged
    over all remaining elements.

    Args:
        y_true: Tensor of target scores; the index of the largest entry
            along the last axis is used as the target class.
        y_pred: Tensor of predicted scores, reduced the same way.

    Returns:
        A scalar tensor of dtype ``K.floatx()``.
    """
    # argmax yields integer indices. Cast to float before reducing so the
    # mean is not computed with integer arithmetic, which truncates on
    # backends whose mean keeps the input dtype.
    # NOTE(review): argmax has no useful gradient, so this is presumably
    # meant as a metric rather than a training loss -- confirm with callers.
    float_dtype = K.floatx()
    true_idx = K.cast(K.argmax(y_true, axis=-1), float_dtype)
    pred_idx = K.cast(K.argmax(y_pred, axis=-1), float_dtype)
    return K.mean(K.square(true_idx - pred_idx))
# class CausalAtrousConvolution1D(AtrousConvolution1D):
# def __init__(self, nb_filter, filter_length, init='glorot_uniform', activation=None, weights=None,
# border_mode='valid', subsample_length=1, atrous_rate=1, W_regularizer=None, b_regularizer=None,
# activity_regularizer=None, W_constraint=None, b_constraint=None, bias=True, causal=False, **kwargs):
# super(CausalAtrousConvolution1D, self).__init__(nb_filter, filter_length, init, activation, weights,
# border_mode, subsample_length, atrous_rate, W_regularizer,
# b_regularizer, activity_regularizer, W_constraint, b_constraint,
# bias, **kwargs)
# self.causal = causal
# if self.causal and border_mode != 'valid':
# raise ValueError("Causal mode dictates border_mode=valid.")
#
# def get_output_shape_for(self, input_shape):
# input_length = input_shape[1]
#
# if self.causal:
# input_length += self.atrous_rate * (self.filter_length - 1)
#
# length = conv_output_length(input_length,
# self.filter_length,
# self.border_mode,
# self.subsample[0],
# dilation=self.atrous_rate)
#
# return (input_shape[0], length, self.nb_filter)
#
# def call(self, x, mask=None):
# if self.causal:
# x = K.asymmetric_temporal_padding(x, self.atrous_rate * (self.filter_length - 1), 0)
# return super(CausalAtrousConvolution1D, self).call(x, mask)
#
#