Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
0x00b1 committed Jun 10, 2024
1 parent feeb20b commit 46dcd83
Show file tree
Hide file tree
Showing 6 changed files with 205 additions and 147 deletions.
4 changes: 4 additions & 0 deletions src/beignet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
)
from ._apply_rotation_matrix import apply_rotation_matrix
from ._apply_rotation_vector import apply_rotation_vector
from ._chandrupatla import chandrupatla
from ._compose_euler_angle import compose_euler_angle
from ._compose_quaternion import compose_quaternion
from ._compose_rotation_matrix import compose_rotation_matrix
Expand All @@ -28,6 +29,7 @@
from ._invert_quaternion import invert_quaternion
from ._invert_rotation_matrix import invert_rotation_matrix
from ._invert_rotation_vector import invert_rotation_vector
from ._newton import newton
from ._quaternion_identity import quaternion_identity
from ._quaternion_magnitude import quaternion_magnitude
from ._quaternion_mean import quaternion_mean
Expand Down Expand Up @@ -73,6 +75,7 @@
"apply_quaternion",
"apply_rotation_matrix",
"apply_rotation_vector",
"chandrupatla",
"compose_euler_angle",
"compose_quaternion",
"compose_rotation_matrix",
Expand All @@ -87,6 +90,7 @@
"invert_quaternion",
"invert_rotation_matrix",
"invert_rotation_vector",
"newton",
"quaternion_identity",
"quaternion_magnitude",
"quaternion_mean",
Expand Down
99 changes: 99 additions & 0 deletions src/beignet/_chandrupatla.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
from typing import Callable, Optional

import torch
from torch import Tensor


def chandrupatla(
f: Callable,
x0: Tensor,
x1: Tensor,
*,
rtol: Optional[float] = None,
atol: Optional[float] = None,
maxiter: int = 100,
**_,
):
b = x0
a = x1
c = x1
fa = f(a)
fb = f(b)
fc = fa

assert (torch.sign(fa) * torch.sign(fb) <= 0).all()

t = 0.5 * torch.ones_like(fa)
xm = torch.zeros_like(a)

iterations = torch.zeros_like(fa, dtype=int)
converged = torch.zeros_like(fa, dtype=bool)

eps = torch.finfo(fa.dtype).eps
if rtol is None:
rtol = eps
if atol is None:
atol = 2 * eps

for _ in range(maxiter):
xt = a + t * (b - a)
ft = f(xt)
(
a,
b,
c,
fa,
fb,
fc,
t,
xt,
ft,
xm,
converged,
iterations,
) = _find_root_chandrupatla_iter(
a, b, c, fa, fb, fc, t, xt, ft, xm, converged, iterations, atol, rtol
)

if converged.all():
break

return xm, (converged, iterations)


def _find_root_chandrupatla_iter(
a, b, c, fa, fb, fc, t, xt, ft, xm, converged, iterations, atol, rtol
):
cond = torch.sign(ft) == torch.sign(fa)
c = torch.where(cond, a, b)
fc = torch.where(cond, fa, fb)
b = torch.where(cond, b, a)
fb = torch.where(cond, fb, fa)

a = xt
fa = ft

xm = torch.where(converged, xm, torch.where(torch.abs(fa) < torch.abs(fb), a, b))

tol = 2 * rtol * torch.abs(xm) + atol
tlim = tol / torch.abs(b - c)
converged = converged | (tlim > 0.5)

xi = (a - b) / (c - b)
phi = (fa - fb) / (fc - fb)

do_iqi = (phi.pow(2) < xi) & ((1 - phi).pow(2) < (1 - xi))

t = torch.where(
do_iqi,
fa / (fb - fa) * fc / (fb - fc)
+ (c - a) / (b - a) * fa / (fc - fa) * fb / (fc - fb),
0.5,
)

# limit to the range (tlim, 1-tlim)
t = torch.minimum(1 - tlim, torch.maximum(tlim, t))

iterations += ~converged

return a, b, c, fa, fb, fc, t, xt, ft, xm, converged, iterations
51 changes: 51 additions & 0 deletions src/beignet/_newton.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from typing import Callable

import torch
from torch import Tensor


def newton(
func: Callable,
a: Tensor | None = None,
*,
atol: float = 0.000001,
rtol: float = 0.000001,
maxiter: int = 50,
) -> (Tensor, (bool, int)):
r"""
Find the root of a function using Newton’s method.
Parameters
----------
func : Callable
The function for which to find the root.
a : Tensor, optional
Initial guess. If not provided, a zero tensor is used.
atol : float, optional
Absolute tolerance. Default is 1e-6.
rtol : float, optional
Relative tolerance. Default is 1e-6.
maxiter : int, optional
Maximum number of iterations. Default is 50.
Returns
-------
output : Tensor
Root of the function.
"""
if a is None:
a = torch.zeros([0])

for iteration in range(maxiter):
b = a - torch.linalg.solve(torch.func.jacfwd(func)(a), func(a))

if torch.linalg.norm(b - a) < atol + rtol * torch.linalg.norm(b):
return b, (True, iteration)

a = b

return b, (False, maxiter)
166 changes: 31 additions & 135 deletions src/beignet/_root.py
Original file line number Diff line number Diff line change
@@ -1,144 +1,40 @@
from typing import Callable, Optional
from typing import Callable, Literal, Optional

import torch
from torch import Tensor

import beignet


def root(
f: Callable[[torch.Tensor], torch.Tensor],
x0: torch.Tensor,
x1: torch.Tensor,
func: Callable,
x0: Tensor,
x1: Tensor,
*,
rtol: Optional[float] = None,
atol: Optional[float] = None,
max_iter: int = 100,
method: str = "chandrupatla",
maxiter: int = 100,
method: Literal["Chandrupatla", "Newton"] = "Chandrupatla",
**kwargs,
):
"""Find a root of a function.
Parameters
----------
f: Callable[[torch.Tensor], torch.Tensor]
Function to find root of.
x0: torch.Tensor
Left bracket of root.
x1: torch.Tensor
Right bracket of root.
rtol: float, optional
Relative tolerance. Defaults to eps for input dtype.
atol: float, optional
Absolve tolerance. Defaults to 2*eps for input dtype.
max_iter: int, optional
Maximum number of solver iterations.
method: str, optional
Solver method to use. Defaults to 'chandrupatla'.
"""
if method == "chandrupatla":
return _find_root_chandrupatla(
f, x0, x1, rtol=rtol, atol=atol, max_iter=max_iter, **kwargs
)
else:
raise ValueError(f"unknown method {method}")


@torch.compile(fullgraph=True, dynamic=True)
def _find_root_chandrupatla_iter(
a, b, c, fa, fb, fc, t, xt, ft, xm, converged, iterations, atol, rtol
):
cond = torch.sign(ft) == torch.sign(fa)
c = torch.where(cond, a, b)
fc = torch.where(cond, fa, fb)
b = torch.where(cond, b, a)
fb = torch.where(cond, fb, fa)

a = xt
fa = ft

xm = torch.where(converged, xm, torch.where(torch.abs(fa) < torch.abs(fb), a, b))

tol = 2 * rtol * torch.abs(xm) + atol
tlim = tol / torch.abs(b - c)
converged = converged | (tlim > 0.5)

xi = (a - b) / (c - b)
phi = (fa - fb) / (fc - fb)

do_iqi = (phi.pow(2) < xi) & ((1 - phi).pow(2) < (1 - xi))

t = torch.where(
do_iqi,
fa / (fb - fa) * fc / (fb - fc)
+ (c - a) / (b - a) * fa / (fc - fa) * fb / (fc - fb),
0.5,
)

# limit to the range (tlim, 1-tlim)
t = torch.minimum(1 - tlim, torch.maximum(tlim, t))

iterations += ~converged

return a, b, c, fa, fb, fc, t, xt, ft, xm, converged, iterations


# adapted from https://www.embeddedrelated.com/showarticle/855.php
def _find_root_chandrupatla(
f: Callable[[torch.Tensor], torch.Tensor],
x0: torch.Tensor,
x1: torch.Tensor,
*,
rtol: Optional[float] = None,
atol: Optional[float] = None,
max_iter: int = 100,
**_,
):
b = x0
a = x1
c = x1
fa = f(a)
fb = f(b)
fc = fa

assert (torch.sign(fa) * torch.sign(fb) <= 0).all()

t = 0.5 * torch.ones_like(fa)
xm = torch.zeros_like(a)

iterations = torch.zeros_like(fa, dtype=int)
converged = torch.zeros_like(fa, dtype=bool)

eps = torch.finfo(fa.dtype).eps
if rtol is None:
rtol = eps
if atol is None:
atol = 2 * eps

for _ in range(max_iter):
xt = a + t * (b - a)
ft = f(xt)
(
a,
b,
c,
fa,
fb,
fc,
t,
xt,
ft,
xm,
converged,
iterations,
) = _find_root_chandrupatla_iter(
a, b, c, fa, fb, fc, t, xt, ft, xm, converged, iterations, atol, rtol
)

if converged.all():
break

return xm, {"converged": converged, "iterations": iterations}
r""" """
match method:
case "Chandrupatla":
return beignet.chandrupatla(
func,
x0,
x1,
rtol=rtol,
atol=atol,
maxiter=maxiter,
**kwargs,
)
case "Newton":
return beignet.newton(
func,
x0,
rtol=rtol,
atol=atol,
maxiter=maxiter,
)
case _:
raise ValueError
13 changes: 13 additions & 0 deletions tests/beignet/test__newton.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import beignet
import scipy.optimize
import torch.testing


def test_newton():
def func(x):
return x**3 - 1

torch.testing.assert_close(
beignet.newton(func, torch.tensor([1.5], dtype=torch.float64))[0],
torch.tensor([scipy.optimize.newton(func, 1.5)], dtype=torch.float64),
)
19 changes: 7 additions & 12 deletions tests/beignet/test__root.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,13 @@
import torch


def test_find_root_chandrupatla():
def test_root():
c = torch.linspace(2, 100, 1001, dtype=torch.float64)

def f(x):
return x.pow(2) - c
output, _ = beignet.root(
lambda x: x**2 - c,
torch.sqrt(c) - 1.1,
torch.sqrt(c) + 1.0,
)

# we don't want to put the root in exactly the center of the interval
a = c.sqrt() - 1.1
b = c.sqrt() + 1.0

x, meta = beignet.root(f, a, b)

assert meta["converged"].all()

torch.testing.assert_close(x, c.sqrt(), atol=1e-12, rtol=5e-11)
torch.testing.assert_close(output, torch.sqrt(c))

0 comments on commit 46dcd83

Please sign in to comment.