-
Notifications
You must be signed in to change notification settings - Fork 19
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Kliff DNN torch trainer #185
Conversation
@@ -82,33 +185,22 @@ def collate(self, batch: Any) -> dict: | |||
""" | |||
# get fingerprint and consistent properties | |||
config_0, property_dict_0 = batch[0] | |||
device = config_0.device | |||
ptr = torch.tensor([0], dtype=torch.int64, device=device) | |||
ptr = np.array([0], dtype=np.intc) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think the collated data will be provided as the input to a torch NN model. Any reason why use np.array
but torch.tensor
for all variables?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Update: NVM, now I understand that the batched data are passed to the descriptor in def _descriptor_eval_batch(self, batch) -> torch.Tensor
, which requires numpy array.
) | ||
self.torchscript_file = None | ||
self.train_dataloader = None | ||
self.validation_dataloader = None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we want to add test_dataloader and test the model at the end of training?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, I will add it uniformly across all trainers after finalizing this release. Idea is to have a base method test()
that can run some tests on the test datasets. It will do simple energy and forces test but will not be limited to it, but could also leverage openkim tests if user requests.
Looks great! Merged. |
Summary
Added tests and trainer for training generic dense neural networks using libdescriptor and pytorch (as opposed to lightning).
TODO (if any)
Checklist
Before a pull request can be merged, the following items must be checked:
Note that the CI system will run all the above checks. But it will be much more efficient if you already fix most errors prior to submitting the PR.