Skip to content

Commit

Permalink
Add wavelet transformer
Browse files Browse the repository at this point in the history
  • Loading branch information
rajpurkar committed Jan 22, 2017
1 parent b270e90 commit 2a76a41
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 9 deletions.
29 changes: 28 additions & 1 deletion ecg/feature.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,39 @@
from sklearn import preprocessing
import pywt
import numpy as np
from builtins import zip


class Normalizer(object):
    """Per-channel standardization (zero mean, unit variance) for ECG arrays.

    Statistics are learned in ``fit`` and reused in ``transform`` so train
    and test data are scaled identically.
    """

    def __init__(self):
        # sklearn StandardScaler, created lazily in fit().
        self.scaler = None

    def _dim_fix(self, x):
        """Ensure x is 3-D (examples, timesteps, channels).

        A 2-D input is treated as single-channel and gains a trailing axis.
        """
        if len(x.shape) == 2:
            x = np.expand_dims(x, axis=-1)
        assert len(x.shape) == 3
        return x

    def fit(self, x):
        """Learn per-channel mean/std from x.

        The (examples, timesteps) axes are flattened so every sample of a
        channel contributes to that channel's statistics.
        """
        x = self._dim_fix(x)
        x = x.reshape((x.shape[0] * x.shape[1], x.shape[2]))
        self.scaler = preprocessing.StandardScaler().fit(x)

    def transform(self, x):
        """Standardize x with the statistics learned in fit().

        Returns an array with the same shape as the (dim-fixed) input.
        """
        # BUG FIX: a stale `return self.scaler.transform(x)` preceded this
        # logic, making the reshape path unreachable and feeding the scaler
        # a layout different from the one it was fit on.
        x = self._dim_fix(x)
        original_shape = x.shape
        new_shape = (x.shape[0] * x.shape[1], x.shape[2])
        return self.scaler.transform(
            x.reshape(new_shape)).reshape(original_shape)


class WaveletTransformer(object):
    """Single-level discrete wavelet transform applied per example.

    Each 1-D example is replaced by its (approximation, detail) DWT
    coefficient pair, stacked as columns of a 2-D array.
    """

    def __init__(self, wavelet_type='db1'):
        # Wavelet family name, forwarded verbatim to pywt.dwt.
        self.wavelet_type = wavelet_type

    def fit(self, x):
        """No-op: the transform is stateless, nothing is learned."""
        pass

    def transform(self, x):
        """Return an array of per-example DWT coefficient matrices."""
        transformed = []
        for example in x:
            coeffs = pywt.dwt(example, self.wavelet_type)  # (cA, cD) pair
            transformed.append(np.array(coeffs).T)
        return np.array(transformed)
2 changes: 1 addition & 1 deletion ecg/keras_models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def add_recurrent_layers(model, **params):
Recurrent = LSTM
rec_layer = Recurrent(
params["recurrent_hidden"],
consume_less="gpu",
consume_less="mem",
dropout_W=params["recurrent_dropout"],
dropout_U=params["recurrent_dropout"],
return_sequences=True)
Expand Down
10 changes: 5 additions & 5 deletions ecg/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,11 @@ def _postprocess(self, use_one_hot):
self.x_train = np.array(self.x_train)
self.x_test = np.array(self.x_test)

transformer = feature.Normalizer()
transformer.fit(self.x_train)
self.x_train = transformer.transform(self.x_train)
self.x_test = transformer.transform(self.x_test)
for transformer_fn in [feature.WaveletTransformer, feature.Normalizer]:
transformer = transformer_fn()
transformer.fit(self.x_train)
self.x_train = transformer.transform(self.x_train)
self.x_test = transformer.transform(self.x_test)

label_counter = collections.Counter(l for labels in self.y_train
for l in labels)
Expand Down Expand Up @@ -158,5 +159,4 @@ def output_dim(self):
count += 1
assert len(ecgs) == len(labels) == batch_size, \
"Invalid number of examples."
assert len(ecgs[0].shape) == 1, "ECG array should be 1D"
assert count == len(ldr.x_train) // batch_size, "Wrong number of batches."
4 changes: 2 additions & 2 deletions ecg/train-keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,11 @@ def save_params(params, start_time, net_type):
seed=2016,
use_cached_if_available=not args.refresh)

x_train = dl.x_train[:, :, np.newaxis]
x_train = dl.x_train
y_train = dl.y_train
print("Training size: " + str(len(x_train)) + " examples.")

x_val = dl.x_test[:, :, np.newaxis]
x_val = dl.x_test
y_val = dl.y_test
print("Validation size: " + str(len(x_val)) + " examples.")

Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ pydot-ng==1.0.0
pyparsing==2.1.10
python-dateutil==2.6.0
pytz==2016.10
PyWavelets==0.5.1
PyYAML==3.12
scikit-learn==0.18.1
scipy==0.18.1
Expand Down

0 comments on commit 2a76a41

Please sign in to comment.