Skip to content

Commit

Permalink
fix aliasing
Browse files Browse the repository at this point in the history
  • Loading branch information
kleinhenz committed Feb 17, 2025
1 parent fcd1a80 commit ccb99cb
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 8 deletions.
20 changes: 15 additions & 5 deletions src/beignet/_chandrupatla.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,11 +75,11 @@ def chandrupatla(
eps = torch.finfo(dtype).eps
a, b, *args = (x.to(dtype=dtype).contiguous() for x in (a, b, *args))

c = a
c = a.clone()

fa = func(a, *args)
fb = func(b, *args)
fc = fa
fc = fa.clone()

# root estimate
xm = torch.where(torch.abs(fa) < torch.abs(fb), a, b)
Expand Down Expand Up @@ -131,15 +131,25 @@ def loop_body(a, b, c, fa, fb, fc, xm, converged, iterations):
fc = torch.where(cond, fa, fb)
b = torch.where(cond, b, a)
fb = torch.where(cond, fb, fa)
a = xt
fa = ft
a = xt.clone()
fa = ft.clone()

xm = torch.where(
converged, xm, torch.where(torch.abs(fa) < torch.abs(fb), a, b)
)

iterations = iterations + ~converged
return a, b, c, fa, fb, fc, xm, converged, iterations
return (
a,
b,
c,
fa,
fb,
fc,
xm,
converged,
iterations,
)

a, b, c, fa, fb, fc, xm, converged, iterations = while_loop(
condition, loop_body, (a, b, c, fa, fb, fc, xm, converged, iterations)
Expand Down
14 changes: 11 additions & 3 deletions tests/beignet/test__root_scalar.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from functools import partial

import beignet
import pytest
import torch
Expand All @@ -12,8 +14,9 @@ def xstar(c):
return c.pow(0.5)


@pytest.mark.parametrize("compile", [True, False])
@pytest.mark.parametrize("method", ["bisect", "chandrupatla"])
def test_root_scalar(method):
def test_root_scalar(compile, method):
c = torch.linspace(1.0, 10.0, 101, dtype=torch.float64)

lower = 0.0
Expand All @@ -28,9 +31,14 @@ def test_root_scalar(method):
"maxiter": maxiter,
}

root, info = beignet.root_scalar(
f, c, method=method, implicit_diff=True, options=options
solver = partial(
beignet.root_scalar, method=method, implicit_diff=True, options=options
)
if compile:
solver = torch.compile(solver, fullgraph=False)

root, info = solver(f, c)

expected = xstar(c)

assert info.converged.all()
Expand Down

0 comments on commit ccb99cb

Please sign in to comment.