Skip to content

Commit

Permalink
fix bug
Browse files Browse the repository at this point in the history
  • Loading branch information
FelixBenning committed May 21, 2024
1 parent 999d264 commit 959d953
Showing 1 changed file with 5 additions and 6 deletions.
11 changes: 5 additions & 6 deletions pyrfd/covariance.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,13 +429,12 @@ def learning_rate(self, loss, grad_norm, *, b_size_inv=0, conservatism=0):
var_adjust = var_reg.intercept / var_reg(b_size_inv)
var_g_adjust = g_var_reg.intercept / g_var_reg(b_size_inv)

tmp = var_adjust * (self.mean - loss) / 2
tmp = tmp if tmp > 0 else 0 # stability
return (
var_g_adjust
* (self.scale**2)
/ (torch.sqrt(tmp**2 + (self.scale * grad_norm * var_g_adjust) ** 2) + tmp)
t1 = var_adjust * (self.mean - loss) / 2
t1 = t1 if t1 > 0 else 0 # stability
t2 = torch.sqrt(
torch.as_tensor(t1**2 + (self.scale * grad_norm * var_g_adjust) ** 2)
)
return var_g_adjust * (self.scale**2) / (t2 + t1)


class RationalQuadratic(IsotropicCovariance):
Expand Down

0 comments on commit 959d953

Please sign in to comment.