
Issue in stratification of the data for valid and testing data #3

Open
SaqibMamoon opened this issue Sep 9, 2024 · 0 comments

Hi @ubc-tea @AnushreeBannadabhavi,

Thank you for sharing the code online. I found an inconsistency in the number of samples assigned to the validation and test splits.

For 1009 subjects with 100% data usage, here's how the split calculation works:

Total subjects = 1009.
Data usage = 100% (i.e., all 1009 subjects are used).
Now, applying train_length and val_length:

Train length = 70% of 1009 = 1009 × 0.7 = 706.3, which rounds to 706 subjects.
Validation length = 10% of 1009 = 1009 × 0.1 = 100.9, which rounds to 101 subjects.
Test length = remaining subjects = 1009 − 706 − 101 = 202 subjects.

Final split:
Train set: 706 subjects.
Validation set: 101 subjects.
Test set: 202 subjects.
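
The arithmetic above can be checked with a short sketch. The helper name `split_sizes` is hypothetical, and it assumes the lengths are rounded to the nearest integer as described above; if dataloader.py truncates with `int()` instead, the validation length would be 100 and the test length 203.

```python
def split_sizes(n_subjects, train_frac=0.7, val_frac=0.1):
    """Return (train, val, test) counts; the test set takes whatever remains.

    Hypothetical helper: assumes nearest-integer rounding, which may differ
    from what dataloader.py actually does.
    """
    train = round(n_subjects * train_frac)
    val = round(n_subjects * val_frac)
    test = n_subjects - train - val
    return train, val, test


sizes = split_sizes(1009)
print(sizes)  # (706, 101, 202)
```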

However, the following code in dataloader.py performs the validation/test split:

    split2 = StratifiedShuffleSplit(n_splits=1, test_size=test_length)
    for test_index, valid_index in split2.split(final_timeseires_val_test, stratified):
        final_timeseires_test, final_pearson_test, labels_test = final_timeseires_val_test[
            test_index], final_pearson_val_test[test_index], labels_val_test[test_index]
        final_timeseires_val, final_pearson_val, labels_val = final_timeseires_val_test[
            valid_index], final_pearson_val_test[valid_index], labels_val_test[valid_index]

With this code, the test set ends up with 101 subjects and the validation set with 202. I think the two are swapped, which leaves the test set short of the 202 subjects calculated above.

Please correct me if I am mistaken. I think the cause is the order of the loop variables: split2.split yields the remaining indices first and the test_size portion second, so ( for test_index, valid_index in split2.split) should be ( for valid_index, test_index in split2.split).
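
A minimal sketch of the ordering, using placeholder data rather than the real dataset. The labels, the feature array, and the value `test_length = 202` (treated as an absolute count) are assumptions for illustration; only the `StratifiedShuffleSplit` behaviour is taken from scikit-learn.

```python
import numpy as np
from sklearn.model_selection import StratifiedShuffleSplit

# Placeholder for the 303 held-out subjects (101 + 202) with two classes.
n_val_test = 303
labels = np.array([0, 1] * 151 + [0])   # 152 zeros, 151 ones
features = np.zeros((n_val_test, 1))    # feature values are irrelevant here

test_length = 202  # assumed to be an absolute count, not a fraction
splitter = StratifiedShuffleSplit(n_splits=1, test_size=test_length, random_state=0)

# split() yields (train_index, test_index): the test_size portion comes SECOND.
first_index, second_index = next(splitter.split(features, labels))
print(len(first_index), len(second_index))  # 101 202
```

Under these assumptions, unpacking as `valid_index, test_index` gives 101 validation subjects and 202 test subjects, whereas `test_index, valid_index` gives the reverse.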

Thanks a lot. If there is a reason for the current ordering and you have a more detailed explanation, you could email me: [email protected]
