diff --git a/ecg/data/irhythm/extract_data.py b/ecg/data/irhythm/extract_data.py index 928ccb9e..9139a2be 100644 --- a/ecg/data/irhythm/extract_data.py +++ b/ecg/data/irhythm/extract_data.py @@ -19,12 +19,14 @@ from .dataset_tools.db_constants import ECG_EXT, EPI_EXT from .dataset_tools.extract_episodes import _find_all_files, qa + def get_all_records(src): """ Find all the ECG files. """ return _find_all_files(src, '', ECG_EXT) + def stratify(records, val_frac): """ Splits the data by patient into train and validation. @@ -46,6 +48,7 @@ def patient_id(record): val = [record for patient in val for record in patient] return train, val + def round_to_second(n): rate = int(ECG_SAMP_RATE) diff = (n - 1) % rate @@ -54,6 +57,7 @@ def round_to_second(n): else: return n + (rate - diff) + def load_episodes(record): base = os.path.splitext(record)[0] ep_json = base + EPI_EXT @@ -74,6 +78,7 @@ def load_episodes(record): return episodes + def make_labels(episodes, duration): labels = [] for episode in episodes: @@ -82,9 +87,10 @@ def make_labels(episodes, duration): rhythm = [episode['rhythm_name']] * rhythm_secs labels.extend(rhythm) labels = [labels[i:i+duration] - for i in range(0, len(labels) - duration + 1, duration)] + for i in range(0, len(labels) - duration + 1, duration)] return labels + def load_ecg(record, duration): with open(record, 'r') as fid: ecg = np.fromfile(fid, dtype=np.int16) @@ -98,9 +104,10 @@ def load_ecg(record, duration): ecg = ecg.reshape((-1, n_per_win)) n_segments = ecg.shape[0] segments = [arr.squeeze() - for arr in np.vsplit(ecg, range(1, n_segments))] + for arr in np.vsplit(ecg, range(1, n_segments))] return segments + def construct_dataset(records, duration): """ List of ecg records, duration to segment them into. @@ -113,6 +120,7 @@ def construct_dataset(records, duration): data.extend(zip(segments, labels)) return data + def load_all_data(data_path, duration, val_frac): print('Stratifying records...') train, val = stratify(get_all_records(data_path), val_frac=val_frac) @@ -137,7 +145,7 @@ def load_all_data(data_path, duration, val_frac): # Some tests for n, m in [(401, 401), (1, 1), (7, 1), (199, 201), (200, 201), (101, 201), (100, 1)]: - msg = "Bad round: {} didn't round to {} ." + msg = "Bad round: {} didn't round to {} ." assert round_to_second(n) == m, msg.format(n, m) print("Tests passed!")