
The train loss cannot converge #7

Open
fengshikun opened this issue Jul 16, 2024 · 14 comments
Labels: question (Further information is requested)

Comments

fengshikun commented Jul 16, 2024

Hello, I've been attempting to train the score model using the command from the README file. However, I've noticed that the loss doesn't seem to converge. Could you please help me investigate which part might be going wrong?

[attached: two screenshots of the training loss curves]
plainerman (Owner)

We experienced similar behavior when training the model, which is why we optimized for rmsd_lt2 instead. You should see that metric increase. Do you?

plainerman added the "question" label on Jul 16, 2024
fengshikun (Author)

Sorry, I just checked the logs and noticed that Val inference rmsds_lt2 consistently remains zero. Additionally, the validation loss shows fluctuating and abnormal values such as:

Epoch 19: Validation loss 42793415.1111  tr 171172360.1270   rot 1286.8611   tor 0.9821   sc_tor 0.9899
Epoch 20: Validation loss 330176498.0680  tr 1320697075.8095   rot 8896.0845   tor 0.9705   sc_tor 0.9875

I followed the command exactly as specified in the README file, so I suspect there might be a configuration issue or perhaps a bug in the code.

plainerman (Owner)

It was common for us to see epochs with outliers and very large losses, but the values should not be consistently this large.

In our run, valinf_rmsds_lt2 only started to look promising after ~50 epochs. How long did you train for?

fengshikun (Author)

I have trained for approximately 100 epochs, and the latest results for Val inference rmsds_lt2 are consistently zero, as shown below:

Epoch 89: Val inference rmsds_lt2 0.000 rmsds_lt5 0.000 sc_rmsds_lt2 3.000 sc_rmsds_lt1 0.000, sc_rmsds_lt0.5 0.000 avg_improve 16.225 avg_worse 17.347  sc_rmsds_lt2_from_holo 3.000 sc_rmsds_lt1_from_holo 0.000, sc_rmsds_lt05_from_holo.5 0.000 sc_rmsds_avg_improvement_from_holo 15.128 sc_rmsds_avg_worsening_from_holo 19.187  
Storing best sc_rmsds_lt05_from_holo model
Run name:  big_score_model

Epoch 94: Val inference rmsds_lt2 0.000 rmsds_lt5 0.000 sc_rmsds_lt2 3.000 sc_rmsds_lt1 0.000, sc_rmsds_lt0.5 0.000 avg_improve 16.219 avg_worse 24.843  sc_rmsds_lt2_from_holo 2.000 sc_rmsds_lt1_from_holo 0.000, sc_rmsds_lt05_from_holo.5 0.000 sc_rmsds_avg_improvement_from_holo 16.306 sc_rmsds_avg_worsening_from_holo 18.885  
Storing best sc_rmsds_lt05_from_holo model
Run name:  big_score_model

fengshikun (Author)

Below is the complete training log file for your reference.
big_score_model_resume.log

plainerman (Owner)

Could you try --limit_complexes 100 and setting the train split to the validation split?

i.e.
--split_train data/splits/timesplit_no_lig_overlap_val_aligned --split_val data/splits/timesplit_no_lig_overlap_val_aligned --limit_complexes 100

Does the problem persist then? With this we can see whether the model can overfit on a small subset.

fengshikun (Author)

> Could you try --limit_complexes 100 and setting the train split to the validation split?
>
> i.e. --split_train data/splits/timesplit_no_lig_overlap_val_aligned --split_val data/splits/timesplit_no_lig_overlap_val_aligned --limit_complexes 100
>
> Does the problem persist then? With this we can see whether the model can overfit on a small subset.

Thank you. I'll try it out and will share the results later.

plainerman (Owner)

You have to comment out this line for it to work:

assert not bool(set(complexes_train) & set(complexes_val)), "Train and val splits have overlapping complexes"
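
For reference, a minimal sketch of what that change could look like (the exact file and surrounding variable names depend on your checkout, and the optional warning below is my own addition, not something from the repository):

```python
# Disable the train/val overlap check so both splits can point at the same file.
# assert not bool(set(complexes_train) & set(complexes_val)), \
#     "Train and val splits have overlapping complexes"

# Optional: warn instead of asserting, so the intentional overlap stays visible in the log.
overlap = set(complexes_train) & set(complexes_val)
if overlap:
    print(f"Warning: train and val share {len(overlap)} complexes (intentional for this overfitting test)")
```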

plainerman (Owner) commented Jul 16, 2024

Maybe related: #6 (which has been fixed), so it may be worth pulling the latest version again.

fengshikun (Author)

> Maybe related: #6 (which has been fixed), so it may be worth pulling the latest version again.

Got it, thanks for the reminder.

fengshikun (Author) commented Jul 17, 2024

> Maybe related: #6 (which has been fixed), so it may be worth pulling the latest version again.

I have pulled the newest version of the codebase and trained the score model using only 100 complex structures. However, the loss continues to fluctuate and has not converged. The training command used is as follows:

python -u train.py --run_name big_score_model --test_sigma_intervals --log_dir workdir --lr 1e-3 --tr_sigma_min 0.1 --tr_sigma_max 5 --rot_sigma_min 0.03 --rot_sigma_max 1.55 --tor_sigma_min 0.03 --sidechain_tor_sigma_min 0.03 --batch_size 32 --ns 60 --nv 10 --num_conv_layers 6 --distance_embed_dim 64 --cross_distance_embed_dim 64 --sigma_embed_dim 64 --dynamic_max_cross --scheduler plateau --scale_by_sigma --dropout 0.1 --sampling_alpha 1 --sampling_beta 1 --remove_hs --c_alpha_max_neighbors 24 --atom_max_neighbors 8 --receptor_radius 15 --num_dataloader_workers 1 --cudnn_benchmark --rot_alpha 1 --rot_beta 1 --tor_alpha 1 --tor_beta 1 --val_inference_freq 5 --use_ema --scheduler_patience 30 --n_epochs 750 --all_atom --sh_lmax 1 --sh_lmax 1 --split_train data/splits/timesplit_no_lig_overlap_val_aligned --split_val data/splits/timesplit_no_lig_overlap_val_aligned --limit_complexes 100 --pocket_reduction --pocket_buffer 10 --flexible_sidechains --flexdist 3.5 --flexdist_distance_metric prism --protein_file protein_esmfold_aligned_tr_fix --compare_true_protein --conformer_match_sidechains --conformer_match_score exp --match_max_rmsd 2 --use_original_conformer_fallback --use_original_conformer

The complete training log is provided below:
big_score_model.log

glukhove commented Aug 7, 2024

@fengshikun hi, were you able to train the model?

fengshikun (Author)

> @fengshikun hi, were you able to train the model?

The loss still does not converge.

plainerman (Owner)

Sorry for not getting back to you sooner. I don't have any concrete results yet, but I think an issue may have been introduced when I ported parts of our code base and changed things for CUDA.
On CPU, I am able to overfit on individual samples.
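
For example, hiding the GPUs from the process forces a CPU run (this assumes train.py falls back to CPU when torch.cuda.is_available() is False; the run name below is just a placeholder):

```sh
# Hide all GPUs so PyTorch reports no CUDA devices and the script runs on CPU.
# Only the overfitting-related flags are shown; keep the remaining flags from the README.
CUDA_VISIBLE_DEVICES="" python -u train.py --run_name cpu_overfit_check \
    --split_train data/splits/timesplit_no_lig_overlap_val_aligned \
    --split_val data/splits/timesplit_no_lig_overlap_val_aligned \
    --limit_complexes 100
```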

I will see if I can pinpoint this issue. Any help is much appreciated, as I don't have much time for this project nowadays.
