From 2a76a41822d220fe50e1139b8c5b01d3eec32da8 Mon Sep 17 00:00:00 2001 From: Pranav Rajpurkar Date: Sun, 22 Jan 2017 11:47:24 -0800 Subject: [PATCH] Add wavelet transformer --- ecg/feature.py | 29 ++++++++++++++++++++++++++++- ecg/keras_models/model.py | 2 +- ecg/loader.py | 10 +++++----- ecg/train-keras.py | 4 ++-- requirements.txt | 1 + 5 files changed, 37 insertions(+), 9 deletions(-) diff --git a/ecg/feature.py b/ecg/feature.py index 1fc1b961..e92c4d74 100644 --- a/ecg/feature.py +++ b/ecg/feature.py @@ -1,12 +1,39 @@ from sklearn import preprocessing +import pywt +import numpy as np +from builtins import zip class Normalizer(object): def __init__(self): self.scaler = None + def _dim_fix(self, x): + if (len(x.shape) == 2): + x = np.expand_dims(x, axis=-1) + assert(len(x.shape) == 3) + return x + def fit(self, x): + x = self._dim_fix(x) + x = x.reshape((x.shape[0]*x.shape[1], x.shape[2])) self.scaler = preprocessing.StandardScaler().fit(x) def transform(self, x): - return self.scaler.transform(x) + x = self._dim_fix(x) + original_shape = x.shape + new_shape = (x.shape[0]*x.shape[1], x.shape[2]) + return self.scaler.transform( + x.reshape(new_shape)).reshape(original_shape) + + +class WaveletTransformer(object): + def __init__(self, wavelet_type='db1'): + self.wavelet_type = wavelet_type + + def fit(self, x): + pass + + def transform(self, x): + x_new = np.array([np.array(pywt.dwt(x_indiv, self.wavelet_type)).T for x_indiv in x]) + return x_new diff --git a/ecg/keras_models/model.py b/ecg/keras_models/model.py index 18cfa0ac..7e46b42e 100644 --- a/ecg/keras_models/model.py +++ b/ecg/keras_models/model.py @@ -27,7 +27,7 @@ def add_recurrent_layers(model, **params): Recurrent = LSTM rec_layer = Recurrent( params["recurrent_hidden"], - consume_less="gpu", + consume_less="mem", dropout_W=params["recurrent_dropout"], dropout_U=params["recurrent_dropout"], return_sequences=True) diff --git a/ecg/loader.py b/ecg/loader.py index 6b22d8af..79acd3e5 100644 --- a/ecg/loader.py +++ b/ecg/loader.py @@ -51,10 +51,11 @@ def _postprocess(self, use_one_hot): self.x_train = np.array(self.x_train) self.x_test = np.array(self.x_test) - transformer = feature.Normalizer() - transformer.fit(self.x_train) - self.x_train = transformer.transform(self.x_train) - self.x_test = transformer.transform(self.x_test) + for transformer_fn in [feature.WaveletTransformer, feature.Normalizer]: + transformer = transformer_fn() + transformer.fit(self.x_train) + self.x_train = transformer.transform(self.x_train) + self.x_test = transformer.transform(self.x_test) label_counter = collections.Counter(l for labels in self.y_train for l in labels) @@ -158,5 +159,4 @@ def output_dim(self): count += 1 assert len(ecgs) == len(labels) == batch_size, \ "Invalid number of examples." - assert len(ecgs[0].shape) == 1, "ECG array should be 1D" assert count == len(ldr.x_train) // batch_size, "Wrong number of batches." diff --git a/ecg/train-keras.py b/ecg/train-keras.py index ba242289..1598ceb0 100644 --- a/ecg/train-keras.py +++ b/ecg/train-keras.py @@ -64,11 +64,11 @@ def save_params(params, start_time, net_type): seed=2016, use_cached_if_available=not args.refresh) - x_train = dl.x_train[:, :, np.newaxis] + x_train = dl.x_train y_train = dl.y_train print("Training size: " + str(len(x_train)) + " examples.") - x_val = dl.x_test[:, :, np.newaxis] + x_val = dl.x_test y_val = dl.y_test print("Validation size: " + str(len(x_val)) + " examples.") diff --git a/requirements.txt b/requirements.txt index a7429348..fb714113 100644 --- a/requirements.txt +++ b/requirements.txt @@ -15,6 +15,7 @@ pydot-ng==1.0.0 pyparsing==2.1.10 python-dateutil==2.6.0 pytz==2016.10 +PyWavelets==0.5.1 PyYAML==3.12 scikit-learn==0.18.1 scipy==0.18.1