
Commit 68c87de: Increase number of channels

rajpurkar committed Feb 21, 2017
1 parent 760529e
Showing 2 changed files with 50 additions and 8 deletions.
configs/train.json (1 change: 1 addition & 0 deletions)

@@ -13,6 +13,7 @@
    "conv_dropout": 0.25,
    "num_skip": 2,
    "is_correct_resnet": false,
+    "conv_num_filters_start": 32,

    "learning_rate": 0.001,
    "optimizer": "adam",
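The new key seeds the channel width of the network's first convolution. A minimal sketch of reading it back (the path and value come from this commit; nothing else about the config is assumed):

import json

# conv_num_filters_start seeds the first conv layer's channel width.
with open("configs/train.json") as f:
    params = json.load(f)
print(params["conv_num_filters_start"])  # prints 32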
ecg/network.py (57 changes: 49 additions & 8 deletions)

@@ -23,28 +23,54 @@ def _bn_relu(layer, **params):
    return layer


-def add_conv_weight(layer, filter_length, subsample_length, **params):
+def add_conv_weight(
+        layer,
+        filter_length,
+        num_filters,
+        subsample_length=1,
+        **params):
    from keras.layers.convolutional import Convolution1D
    layer = Convolution1D(
-        nb_filter=params["conv_num_filters"],
+        nb_filter=num_filters,
        filter_length=filter_length,
        border_mode='same',
        subsample_length=subsample_length,
        init=params["conv_init"])(layer)
    return layer
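With this refactor the filter count becomes an explicit argument instead of being read from params["conv_num_filters"] inside the helper. A minimal sketch of a call under the Keras 1.x API used here; the input shape and initializer are assumptions for illustration:

from keras.layers import Input

inputs = Input(shape=(8192, 1))  # (timesteps, channels); shape is assumed
layer = add_conv_weight(
    inputs,
    filter_length=16,
    num_filters=32,
    subsample_length=1,
    conv_init="he_normal")  # consumed via **params as params["conv_init"]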


-def resnet_block(layer, subsample_length, **params):
+def resnet_block(
+        layer,
+        num_filters,
+        subsample_length,
+        zero_pad=False,
+        **params):
    from keras.layers import merge
    from keras.layers.pooling import MaxPooling1D

+    from keras.layers.core import Lambda
+    from keras import backend as K
+
+    def zeropad(x):
+        y = K.zeros_like(x)
+        return K.concatenate([x, y], axis=2)
+
+    def zeropad_output_shape(input_shape):
+        shape = list(input_shape)
+        assert len(shape) == 3
+        shape[2] *= 2
+        return tuple(shape)
+
    shortcut = MaxPooling1D(pool_length=subsample_length)(layer)
+    if zero_pad is True:
+        shortcut = Lambda(zeropad, output_shape=zeropad_output_shape)(shortcut)

    for i in range(params["num_skip"]):
        layer = _bn_relu(layer, **params)
        layer = add_conv_weight(
            layer,
            params["conv_filter_length"],
+            num_filters,
            subsample_length if i == 0 else 1,
            **params)

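The zeropad/zeropad_output_shape pair doubles the channel axis of the pooled shortcut so it can be summed with a main path whose filter count has doubled. The same operation in numpy terms, a sketch with assumed shapes:

import numpy as np

# Append all-zero channels along axis 2, doubling the channel count
# while leaving the original activations untouched.
x = np.ones((1, 8, 32))  # (batch, timesteps, channels); shape is assumed
padded = np.concatenate([x, np.zeros_like(x)], axis=2)
assert padded.shape == (1, 8, 64)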
@@ -53,19 +79,33 @@ def resnet_block(layer, subsample_length, **params):


def add_resnet_layers(layer, **params):
-    layer = _bn_relu(add_conv_weight(layer, 16, 1, **params), **params)
-    for subsample_length in params["conv_subsample_lengths"]:
-        layer = resnet_block(layer, subsample_length, **params)
+    layer = add_conv_weight(
+        layer,
+        params["conv_filter_length"],
+        params["conv_num_filters_start"],
+        subsample_length=1,
+        **params)
+    layer = _bn_relu(layer, **params)
+    for index, subsample_length in enumerate(params["conv_subsample_lengths"]):
+        num_filters = 2**int(index / 2) * params["conv_num_filters_start"]
+        zero_pad = (index % 2) == 0 and index > 0
+        layer = resnet_block(
+            layer,
+            num_filters,
+            subsample_length,
+            zero_pad=zero_pad,
+            **params)
    layer = _bn_relu(layer, **params)
    return layer
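This loop is where the channel count actually increases: widths double every second block, and zero_pad is set exactly on the blocks where the width steps up, so the MaxPooling1D shortcut gets padded to match. A sketch of the schedule, assuming eight blocks and the conv_num_filters_start of 32 from train.json:

start = 32  # conv_num_filters_start from configs/train.json
for index in range(8):  # eight blocks is an assumption for illustration
    num_filters = 2**int(index / 2) * start
    zero_pad = (index % 2) == 0 and index > 0
    print(index, num_filters, zero_pad)
# Widths: 32, 32, 64, 64, 128, 128, 256, 256. zero_pad is True at
# indices 2, 4 and 6: precisely the blocks where the width doubles.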


def add_conv_layers(layer, **params):
    from keras.layers import merge
-    for subsample_length in params["conv_subsample_lengths"]:
+    for index, subsample_length in enumerate(params["conv_subsample_lengths"]):
        shortcut = add_conv_weight(
            layer,
            params["conv_filter_length"],
+            params["conv_num_filters"],
            subsample_length,
            **params)
        layer = shortcut
@@ -74,7 +114,8 @@ def add_conv_layers(layer, **params):
        layer = add_conv_weight(
            layer,
            params["conv_filter_length"],
-            1,
+            params["conv_num_filters"],
+            subsample_length=1,
            **params)
        layer = merge([shortcut, layer], mode="sum")
    return layer
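Under the old signature the positional 1 here was the subsample length; with filters now explicit, the call passes params["conv_num_filters"] so both branches of the residual pair keep the same channel count, which the mode="sum" merge requires. A numpy sketch of that shape constraint, shapes assumed:

import numpy as np

shortcut = np.ones((1, 8, 32))
layer = np.ones((1, 8, 32))
out = shortcut + layer  # Keras's sum merge needs both branches to match
assert out.shape == shortcut.shape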
