From 8bd093bce50dcb5c794bfc7309024888b0750df2 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Sat, 26 May 2018 17:17:45 -0400 Subject: [PATCH] cleanup reqs --- ecg/network.py | 24 +++++++-------- ecg/train.py | 6 ++-- requirements.txt | 79 +++--------------------------------------------- setup.sh | 2 +- 4 files changed, 20 insertions(+), 91 deletions(-) diff --git a/ecg/network.py b/ecg/network.py index c009ca12..b17c8ad0 100644 --- a/ecg/network.py +++ b/ecg/network.py @@ -18,13 +18,13 @@ def add_conv_weight( num_filters, subsample_length=1, **params): - from keras.layers.convolutional import Convolution1D - layer = Convolution1D( - nb_filter=num_filters, - filter_length=filter_length, - border_mode='same', - subsample_length=subsample_length, - init=params["conv_init"])(layer) + from keras.layers import Conv1D + layer = Conv1D( + filters=num_filters, + kernel_size=filter_length, + strides=subsample_length, + padding='same', + kernel_initializer=params["conv_init"])(layer) return layer @@ -45,8 +45,8 @@ def resnet_block( subsample_length, block_index, **params): - from keras.layers import merge - from keras.layers.pooling import MaxPooling1D + from keras.layers import Add + from keras.layers import MaxPooling1D from keras.layers.core import Lambda def zeropad(x): @@ -59,7 +59,7 @@ def zeropad_output_shape(input_shape): shape[2] *= 2 return tuple(shape) - shortcut = MaxPooling1D(pool_length=subsample_length)(layer) + shortcut = MaxPooling1D(pool_size=subsample_length)(layer) zero_pad = (block_index % params["conv_increase_channels_at"]) == 0 \ and block_index > 0 if zero_pad is True: @@ -77,7 +77,7 @@ def zeropad_output_shape(input_shape): num_filters, subsample_length if i == 0 else 1, **params) - layer = merge([shortcut, layer], mode="sum") + layer = Add()([shortcut, layer]) return layer def get_num_filters_at_index(index, num_start_filters, **params): @@ -142,6 +142,6 @@ def build_network(**params): layer = add_resnet_layers(inputs, **params) output = add_output_layer(layer, **params) - model = Model(input=[inputs], output=[output]) + model = Model(inputs=[inputs], outputs=[output]) add_compile(model, **params) return model diff --git a/ecg/train.py b/ecg/train.py index 8b2372ad..d48c5d1e 100644 --- a/ecg/train.py +++ b/ecg/train.py @@ -142,12 +142,12 @@ def train(args, params): batch_size=batch_size, augmenter=get_augment_fn(params)) - samples_per_epoch = batch_size * int(x_train.shape[0] / batch_size) + steps_per_epoch = int(x_train.shape[0] / batch_size) model.fit_generator( train_data, - samples_per_epoch=samples_per_epoch, - nb_epoch=MAX_EPOCHS, + steps_per_epoch=steps_per_epoch, + epochs=MAX_EPOCHS, validation_data=(x_test, y_test), callbacks=[checkpointer, reduce_lr, stopping], verbose=args.verbose) diff --git a/requirements.txt b/requirements.txt index 58de79a9..803389a9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,76 +1,5 @@ -appdirs==1.4.3 -backports-abc==0.5 -backports.shutil-get-terminal-size==1.0.0 -backports.ssl-match-hostname==3.5.0.1 -bleach==2.0.0 -certifi==2017.4.17 -configparser==3.5.0 -cycler==0.10.0 -decorator==4.0.11 -docopt==0.6.2 -entrypoints==0.2.3 -enum34==1.1.6 -funcsigs==1.0.2 -functools32==3.2.3.post2 -future==0.16.0 -h5py==2.6.0 -html5lib==0.999999999 -ipykernel==4.6.1 -ipython==5.4.1 -ipython-genutils==0.2.0 -ipywidgets==6.0.0 -Jinja2==2.9.6 -joblib==0.10.3 -jsonschema==2.6.0 -jupyter==1.0.0 -jupyter-client==5.1.0 -jupyter-console==5.1.0 -jupyter-core==4.3.0 -Keras==1.2.2 -MarkupSafe==1.0 -matplotlib==2.0.0 -mistune==0.7.4 -mock==2.0.0 -nbconvert==5.2.1 -nbformat==4.3.0 -notebook==5.0.0 -numpy==1.12.1 -packaging==16.8 -pandas==0.19.2 -pandocfilters==1.4.1 -pathlib2==2.3.0 -pbr==3.0.0 -pexpect==4.2.1 -pickleshare==0.7.4 -pprint==0.1 -prompt-toolkit==1.0.14 -protobuf==3.2.0 -ptyprocess==0.5.2 -pydot-ng==1.0.0 -Pygments==2.2.0 -pyparsing==2.2.0 -python-dateutil==2.6.0 -pytz==2016.10 -PyWavelets==0.5.1 -PyYAML==3.12 -pyzmq==16.0.2 -qtconsole==4.3.0 -scandir==1.5 -scikit-learn==0.18.1 -scipy==0.18.1 -simplegeneric==0.8.1 -singledispatch==3.4.0.3 -six==1.10.0 +h5py==2.8.0rc1 +Keras==2.1.6 sklearn==0.0 -subprocess32==3.2.7 -tabulate==0.7.7 -terminado==0.6 -testpath==0.3.1 -Theano==0.8.2 -tornado==4.5.1 -tqdm==4.11.0 -traitlets==4.3.2 -wcwidth==0.1.7 -webencodings==0.5.1 -Werkzeug==0.12.1 -widgetsnbextension==2.0.0 +tqdm==4.23.4 +unittest2==1.1.0 diff --git a/setup.sh b/setup.sh index 915dc1a2..d45b1de2 100755 --- a/setup.sh +++ b/setup.sh @@ -9,5 +9,5 @@ fi pip install -r requirements.txt -pip install --upgrade $TF==1.0.1 +pip install --upgrade $TF==1.8.0