Force gradient regularization #331

Draft · wants to merge 5 commits into base: main
Changes from 1 commit
Add second gradient regularization
RaulPPelaez committed Jun 20, 2024
commit ffcacbed20f2c76839663aa9bfcd54ee4428c9dc
35 changes: 35 additions & 0 deletions torchmdnet/module.py
@@ -167,6 +167,32 @@ def _compute_losses(self, y, neg_y, batch, loss_fn, stage):
loss_y = self._update_loss_with_ema(stage, "y", loss_name, loss_y)
return {"y": loss_y, "neg_dy": loss_neg_y}

def _compute_second_derivative_regularization(self, y, neg_dy, batch):
        # Compute the gradient of the summed forces and add a hinge penalty to the loss:
        # max(0, ||grad(neg_dy.sum(), pos)|| - regularization_coefficient) * decay
# Args:
# y: predicted value
# neg_dy: predicted negative derivative
# batch: batch of data
# Returns:
# regularization: regularization term
assert "pos" in batch
force_sum = neg_dy.sum()
grad_outputs = [torch.ones_like(force_sum)]
assert batch.pos.requires_grad
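        # Differentiate the summed forces w.r.t. the positions; create_graph and retain_graph
        # keep the graphs alive so this penalty can be backpropagated alongside the main loss.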
ddy = torch.autograd.grad(
[force_sum],
[batch.pos],
grad_outputs=grad_outputs,
create_graph=True,
retain_graph=True,
)[0]
decay = self.hparams.regularization_decay / (self.current_epoch + 1)
        regularization = (
            torch.clamp(ddy.norm() - self.hparams.regularization_coefficient, min=0.0)
            * decay
        )
return regularization

def _update_loss_with_ema(self, stage, type, loss_name, loss):
# Update the loss using an exponential moving average when applicable
# Args:
@@ -235,6 +261,15 @@ def step(self, batch, loss_fn_list, stage):
step_losses["y"] * self.hparams.y_weight
+ step_losses["neg_dy"] * self.hparams.neg_dy_weight
)
if (
self.hparams.regularize_second_gradient
and self.hparams.derivative
and stage == "train"
):
total_loss = (
total_loss
+ self._compute_second_derivative_regularization(y, neg_dy, batch)
)
self.losses[stage]["total"][loss_name].append(total_loss.detach())
return total_loss

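Note (not part of the diff): the new method relies on a double backward through the positions. Below is a minimal standalone sketch of the same pattern on a toy energy function, assuming only plain PyTorch; eps and weight stand in for regularization_coefficient and the epoch-dependent decay.

import torch

def second_derivative_penalty(energy_fn, pos, eps=1.0, weight=1.0):
    # Hinge penalty on the norm of d(sum of forces)/d(positions).
    pos = pos.detach().requires_grad_(True)
    energy = energy_fn(pos).sum()
    # First derivative: forces are -dE/dpos (the sign does not change the norm below).
    neg_dy = -torch.autograd.grad(energy, pos, create_graph=True)[0]
    # Second derivative: differentiate the summed forces w.r.t. the positions again,
    # keeping the graph so the penalty itself remains differentiable.
    ddy = torch.autograd.grad(neg_dy.sum(), pos, create_graph=True)[0]
    # Only the part of the force-gradient norm above eps is penalized.
    return weight * torch.clamp(ddy.norm() - eps, min=0.0)

# Toy usage: a quartic "energy" of 8 atoms in 3D.
pos = torch.randn(8, 3)
penalty = second_derivative_penalty(lambda p: (p ** 4).sum(), pos)
penalty.backward()  # gradients flow back through both autograd.grad calls

The PR takes neg_dy from the model's force head instead of recomputing it, which is why batch.pos.requires_grad is asserted before the second autograd.grad call.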
4 changes: 4 additions & 0 deletions torchmdnet/scripts/train.py
@@ -59,6 +59,10 @@ def get_argparse():
parser.add_argument('--redirect', type=bool, default=False, help='Redirect stdout and stderr to log_dir/log')
parser.add_argument('--gradient-clipping', type=float, default=0.0, help='Gradient clipping norm')
parser.add_argument('--remove-ref-energy', action='store_true', help='If true, remove the reference energy from the dataset for delta-learning. Total energy can still be predicted by the model during inference by turning this flag off when loading. The dataset must be compatible with Atomref for this to be used.')

parser.add_argument('--regularize-second-gradient', action="store_true", help='If true, regularize the second derivative of the energy w.r.t. the coordinates')
    parser.add_argument('--regularization-coefficient', type=float, default=0.0, help='Threshold for the force-gradient norm; only the excess above this value is penalized')
parser.add_argument('--regularization-decay', type=float, default=0.0, help='Decay rate for the regularization term')
# dataset specific
parser.add_argument('--dataset', default=None, type=str, choices=datasets.__all__, help='Name of the torch_geometric dataset')
parser.add_argument('--dataset-root', default='~/data', type=str, help='Data storage directory (not used if dataset is "CG")')
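Note (not part of the diff): as a quick illustration of how the three new options fit together, here is a standalone sketch that mirrors, rather than imports, the argparse lines above and the per-epoch decay used in module.py; the values passed are placeholders.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--regularize-second-gradient', action='store_true')
parser.add_argument('--regularization-coefficient', type=float, default=0.0)
parser.add_argument('--regularization-decay', type=float, default=0.0)
args = parser.parse_args(['--regularize-second-gradient',
                          '--regularization-coefficient', '1.0',
                          '--regularization-decay', '0.5'])

# The penalty weight used in module.py shrinks as training progresses:
# decay / (current_epoch + 1).
for epoch in range(3):
    print(epoch, args.regularization_decay / (epoch + 1))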