Skip to content

Commit

Permalink
cleanup reqs
Browse files Browse the repository at this point in the history
  • Loading branch information
Awni Hannun committed May 26, 2018
1 parent 0198bb9 commit 8bd093b
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 91 deletions.
24 changes: 12 additions & 12 deletions ecg/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@ def add_conv_weight(
num_filters,
subsample_length=1,
**params):
from keras.layers.convolutional import Convolution1D
layer = Convolution1D(
nb_filter=num_filters,
filter_length=filter_length,
border_mode='same',
subsample_length=subsample_length,
init=params["conv_init"])(layer)
from keras.layers import Conv1D
layer = Conv1D(
filters=num_filters,
kernel_size=filter_length,
strides=subsample_length,
padding='same',
kernel_initializer=params["conv_init"])(layer)
return layer


Expand All @@ -45,8 +45,8 @@ def resnet_block(
subsample_length,
block_index,
**params):
from keras.layers import merge
from keras.layers.pooling import MaxPooling1D
from keras.layers import Add
from keras.layers import MaxPooling1D
from keras.layers.core import Lambda

def zeropad(x):
Expand All @@ -59,7 +59,7 @@ def zeropad_output_shape(input_shape):
shape[2] *= 2
return tuple(shape)

shortcut = MaxPooling1D(pool_length=subsample_length)(layer)
shortcut = MaxPooling1D(pool_size=subsample_length)(layer)
zero_pad = (block_index % params["conv_increase_channels_at"]) == 0 \
and block_index > 0
if zero_pad is True:
Expand All @@ -77,7 +77,7 @@ def zeropad_output_shape(input_shape):
num_filters,
subsample_length if i == 0 else 1,
**params)
layer = merge([shortcut, layer], mode="sum")
layer = Add()([shortcut, layer])
return layer

def get_num_filters_at_index(index, num_start_filters, **params):
Expand Down Expand Up @@ -142,6 +142,6 @@ def build_network(**params):
layer = add_resnet_layers(inputs, **params)

output = add_output_layer(layer, **params)
model = Model(input=[inputs], output=[output])
model = Model(inputs=[inputs], outputs=[output])
add_compile(model, **params)
return model
6 changes: 3 additions & 3 deletions ecg/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,12 +142,12 @@ def train(args, params):
batch_size=batch_size,
augmenter=get_augment_fn(params))

samples_per_epoch = batch_size * int(x_train.shape[0] / batch_size)
steps_per_epoch = int(x_train.shape[0] / batch_size)

model.fit_generator(
train_data,
samples_per_epoch=samples_per_epoch,
nb_epoch=MAX_EPOCHS,
steps_per_epoch=steps_per_epoch,
epochs=MAX_EPOCHS,
validation_data=(x_test, y_test),
callbacks=[checkpointer, reduce_lr, stopping],
verbose=args.verbose)
Expand Down
79 changes: 4 additions & 75 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,76 +1,5 @@
appdirs==1.4.3
backports-abc==0.5
backports.shutil-get-terminal-size==1.0.0
backports.ssl-match-hostname==3.5.0.1
bleach==2.0.0
certifi==2017.4.17
configparser==3.5.0
cycler==0.10.0
decorator==4.0.11
docopt==0.6.2
entrypoints==0.2.3
enum34==1.1.6
funcsigs==1.0.2
functools32==3.2.3.post2
future==0.16.0
h5py==2.6.0
html5lib==0.999999999
ipykernel==4.6.1
ipython==5.4.1
ipython-genutils==0.2.0
ipywidgets==6.0.0
Jinja2==2.9.6
joblib==0.10.3
jsonschema==2.6.0
jupyter==1.0.0
jupyter-client==5.1.0
jupyter-console==5.1.0
jupyter-core==4.3.0
Keras==1.2.2
MarkupSafe==1.0
matplotlib==2.0.0
mistune==0.7.4
mock==2.0.0
nbconvert==5.2.1
nbformat==4.3.0
notebook==5.0.0
numpy==1.12.1
packaging==16.8
pandas==0.19.2
pandocfilters==1.4.1
pathlib2==2.3.0
pbr==3.0.0
pexpect==4.2.1
pickleshare==0.7.4
pprint==0.1
prompt-toolkit==1.0.14
protobuf==3.2.0
ptyprocess==0.5.2
pydot-ng==1.0.0
Pygments==2.2.0
pyparsing==2.2.0
python-dateutil==2.6.0
pytz==2016.10
PyWavelets==0.5.1
PyYAML==3.12
pyzmq==16.0.2
qtconsole==4.3.0
scandir==1.5
scikit-learn==0.18.1
scipy==0.18.1
simplegeneric==0.8.1
singledispatch==3.4.0.3
six==1.10.0
h5py==2.8.0rc1
Keras==2.1.6
sklearn==0.0
subprocess32==3.2.7
tabulate==0.7.7
terminado==0.6
testpath==0.3.1
Theano==0.8.2
tornado==4.5.1
tqdm==4.11.0
traitlets==4.3.2
wcwidth==0.1.7
webencodings==0.5.1
Werkzeug==0.12.1
widgetsnbextension==2.0.0
tqdm==4.23.4
unittest2==1.1.0
2 changes: 1 addition & 1 deletion setup.sh
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,5 @@ fi


pip install -r requirements.txt
pip install --upgrade $TF==1.0.1
pip install --upgrade $TF==1.8.0

0 comments on commit 8bd093b

Please sign in to comment.