Skip to content

Commit

Permalink
Style fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
rajpurkar committed Jan 21, 2017
1 parent 62b4066 commit 2440972
Showing 1 changed file with 11 additions and 3 deletions.
14 changes: 11 additions & 3 deletions ecg/data/irhythm/extract_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,14 @@
from .dataset_tools.db_constants import ECG_EXT, EPI_EXT
from .dataset_tools.extract_episodes import _find_all_files, qa


def get_all_records(src):
    """
    Collect all the ECG files located under `src`.

    The search is delegated to `_find_all_files`, called with `src`, an
    empty string as its second argument, and `ECG_EXT` as the extension
    to match. Whatever that helper returns is passed back unchanged.
    """
    ecg_files = _find_all_files(src, '', ECG_EXT)
    return ecg_files


def stratify(records, val_frac):
"""
Splits the data by patient into train and validation.
Expand All @@ -46,6 +48,7 @@ def patient_id(record):
val = [record for patient in val for record in patient]
return train, val


def round_to_second(n):
rate = int(ECG_SAMP_RATE)
diff = (n - 1) % rate
Expand All @@ -54,6 +57,7 @@ def round_to_second(n):
else:
return n + (rate - diff)


def load_episodes(record):
base = os.path.splitext(record)[0]
ep_json = base + EPI_EXT
Expand All @@ -74,6 +78,7 @@ def load_episodes(record):

return episodes


def make_labels(episodes, duration):
labels = []
for episode in episodes:
Expand All @@ -82,9 +87,10 @@ def make_labels(episodes, duration):
rhythm = [episode['rhythm_name']] * rhythm_secs
labels.extend(rhythm)
labels = [labels[i:i+duration]
for i in range(0, len(labels) - duration + 1, duration)]
for i in range(0, len(labels) - duration + 1, duration)]
return labels


def load_ecg(record, duration):
with open(record, 'r') as fid:
ecg = np.fromfile(fid, dtype=np.int16)
Expand All @@ -98,9 +104,10 @@ def load_ecg(record, duration):
ecg = ecg.reshape((-1, n_per_win))
n_segments = ecg.shape[0]
segments = [arr.squeeze()
for arr in np.vsplit(ecg, range(1, n_segments))]
for arr in np.vsplit(ecg, range(1, n_segments))]
return segments


def construct_dataset(records, duration):
"""
List of ecg records, duration to segment them into.
Expand All @@ -113,6 +120,7 @@ def construct_dataset(records, duration):
data.extend(zip(segments, labels))
return data


def load_all_data(data_path, duration, val_frac):
print('Stratifying records...')
train, val = stratify(get_all_records(data_path), val_frac=val_frac)
Expand All @@ -137,7 +145,7 @@ def load_all_data(data_path, duration, val_frac):
# Some tests
for n, m in [(401, 401), (1, 1), (7, 1), (199, 201),
(200, 201), (101, 201), (100, 1)]:
msg = "Bad round: {} didn't round to {} ."
msg = "Bad round: {} didn't round to {} ."
assert round_to_second(n) == m, msg.format(n, m)

print("Tests passed!")

0 comments on commit 2440972

Please sign in to comment.