Skip to content

Commit

Permalink
Test set and external test set evaluation for simple logistic regression
Browse files Browse the repository at this point in the history
  • Loading branch information
JulianKlug committed Jul 4, 2024
1 parent ee85170 commit 8eb29d0
Show file tree
Hide file tree
Showing 2 changed files with 702 additions and 3 deletions.
14 changes: 11 additions & 3 deletions prediction/outcome_prediction/data_loading/data_splits_to_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,13 +67,18 @@ def save_train_splits(features_path: str, labels_path:str, outcome:str, output_d
os.path.join(output_dir, 'pid_test.tsv'),
sep='\t', index=False)

test_X = X[X.patient_id.isin(pid_test)]
# Here test data is not needed anymore, but for reference should be loaded as such: test_y = y[y.patient_id.isin(pid_test)]
test_X_df = X.loc[X.patient_id.isin(pid_test)]
test_y_df = y.loc[y.patient_id.isin(pid_test)]
test_X = features_to_numpy(test_X_df,
['case_admission_id', 'relative_sample_date_hourly_cat', 'sample_label', 'value'])
test_y = np.array([test_y_df[test_y_df.case_admission_id == cid].outcome.values[0] for cid in
test_X[:, 0, 0, 0]]).astype('float32')
test_X = test_X[:, :, :, -1].astype('float32')


# define K fold
kfold = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=seed)


### TRAIN MODEL USING K-FOLD CROSS-VALIDATION
splits = []
for fold_pid_train_idx, fold_pid_val_idx in kfold.split(pid_train, y_pid_train):
Expand All @@ -97,6 +102,7 @@ def save_train_splits(features_path: str, labels_path:str, outcome:str, output_d
splits.append((fold_X_train, fold_X_val, fold_y_train, fold_y_val))

torch.save(splits, os.path.join(output_dir, f'train_data_splits_{outcome.replace(" ", "_")}_ts{1-test_size}_rs{seed}_ns{n_splits}.pth'))
torch.save((test_X, test_y), os.path.join(output_dir, f'test_data_{outcome.replace(" ", "_")}_ts{1-test_size}_rs{seed}.pth'))
return splits


Expand Down Expand Up @@ -126,6 +132,8 @@ def save_train_splits(features_path: str, labels_path:str, outcome:str, output_d

if args.output_dir is None:
output_dir = os.path.join(os.path.dirname(args.features_path), f'{"_".join(args.outcome.split(" "))}_train_data_splits')
else:
output_dir = args.output_dir

save_train_splits(features_path=args.features_path, labels_path=args.labels_path, outcome=args.outcome,
output_dir=output_dir, test_pids_path=args.pid_test_path, train_pids_path=args.pid_train_path,
Expand Down
Loading

0 comments on commit 8eb29d0

Please sign in to comment.