Skip to content

Commit

Permalink
formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
FelixBenning committed May 13, 2024
1 parent d321feb commit c7af432
Showing 1 changed file with 6 additions and 8 deletions.
14 changes: 6 additions & 8 deletions pyrfd/covariance.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,12 +374,11 @@ class RationalQuadratic(IsotropicCovariance):

def __init__(self, *, beta, **kwargs):
self.beta = beta
super().__init__(**kwargs) # calls _is_fitted, beta needs to be there
super().__init__(**kwargs) # calls _is_fitted, beta needs to be there

def _is_fitted(self):
return super()._is_fitted() and self.beta is not None


def __repr__(self) -> str:
return (
f"{self.__class__.__name__}(beta={self.beta}, " + self._repr_helper() + ")"
Expand Down Expand Up @@ -426,17 +425,15 @@ def learning_rate(self, loss, grad_norm, b_size_inv=0):
polynomial = [-1, tmp, (1 + self.beta), tmp]
# careful opposite order than np.root expects!

minimum = self.bisection_root_finder(polynomial) # confer paper for uniqueness
minimum = self.bisection_root_finder(polynomial) # confer paper for uniqueness

# learning rate, not step size!
return self.scale * np.sqrt(self.beta) * minimum / grad_norm



def bisection_root_finder(self, polynomial):
""" find the root of an increasing polynomial on the interval [0, 1/sqrt(1+beta)] """
"""find the root of an increasing polynomial on the interval [0, 1/sqrt(1+beta)]"""
left = 0
right = 1/ np.sqrt(1+ self.beta)
right = 1 / np.sqrt(1 + self.beta)
while right - left > 1e-10:
mid = (left + right) / 2
if evaluate_polynomial(polynomial, mid) == 0:
Expand All @@ -448,8 +445,9 @@ def bisection_root_finder(self, polynomial):
left = mid
return (left + right) / 2


def evaluate_polynomial(polynomial, x):
return sum(coeff * (x ** idx) for idx, coeff in enumerate(polynomial))
return sum(coeff * (x**idx) for idx, coeff in enumerate(polynomial))


if __name__ == "__main__":
Expand Down

0 comments on commit c7af432

Please sign in to comment.