faddeeva improvements (#31)
kleinhenz authored Jun 12, 2024
1 parent 31f2985 commit 9d2a54b
Showing 2 changed files with 86 additions and 117 deletions.
134 changes: 53 additions & 81 deletions src/beignet/special/_faddeeva_w.py
@@ -1,18 +1,18 @@
+import math
+
 import torch
 from torch import Tensor
 
 
-def _voigt_v(x, y):
-    # assumes x >= 0, y >= 0
-
-    N = 11
+def _voigt_v(x, y, n: int = 11):
+    if not (((x >= 0.0) & (y >= 0.0)) | torch.isnan(x) | torch.isnan(y)).all():
+        raise ValueError("_voigt_v only defined for x >= 0 and y >= 0")
 
-    # h = math.sqrt(math.pi / (N + 1))
-    h = 0.5116633539732443
+    h = math.sqrt(math.pi / (n + 1))
 
-    phi = (x / h) - (x / h).floor()
+    phi = (x / h) - torch.floor(x / h)
 
-    k = torch.arange(N + 1, dtype=x.dtype, device=x.device)
+    k = torch.arange(n + 1, dtype=x.dtype, device=x.device)
     t = (k + 0.5) * h
     tau = k[1:] * h
 
@@ -27,36 +27,20 @@ def _voigt_v(x, y):
     ).sum(dim=-1)
 
     # equation 13
-    w_mm = (
-        (
-            2
-            * torch.exp(-x.pow(2) + y.pow(2))
-            * (
-                torch.cos(2 * x * y)
-                + torch.exp(2 * torch.pi * y / h)
-                * torch.cos(2 * torch.pi * x / h - 2 * x * y)
-            )
-        )
-        / (
-            1
-            + torch.exp(4 * torch.pi * y / h)
-            + 2 * torch.exp(2 * torch.pi * y / h) * torch.cos(2 * torch.pi * x / h)
-        )
-    ) + w_m
+    expy = torch.exp(-2 * torch.pi * y / h)
+    w_mm_1 = (
+        2
+        * torch.exp(-x.pow(2) + y.pow(2))
+        * (torch.cos(2 * x * y) * expy + torch.cos(2 * torch.pi * x / h - 2 * x * y))
+    ) / (expy + 1 / expy + 2 * torch.cos(2 * torch.pi * x / h))
+
+    w_mm = w_mm_1 + w_m
 
     w_mt_1 = (
         2
         * torch.exp(-x.pow(2) + y.pow(2))
-        * (
-            torch.cos(2 * x * y)
-            - torch.exp(2 * torch.pi * y / h)
-            * torch.cos(2 * torch.pi * x / h - 2 * x * y)
-        )
-    ) / (
-        1
-        + torch.exp(4 * torch.pi * y / h)
-        - 2 * torch.exp(2 * torch.pi * y / h) * torch.cos(2 * torch.pi * x / h)
-    )
+        * (torch.cos(2 * x * y) * expy - torch.cos(2 * torch.pi * x / h - 2 * x * y))
+    ) / (expy + 1 / expy - 2 * torch.cos(2 * torch.pi * x / h))
 
     w_mt_2 = (h * y) / (torch.pi * (x.pow(2) + y.pow(2)))
 
@@ -79,17 +63,15 @@ def _voigt_v(x, y):
     )
 
 
-def _voigt_l(x, y):
-    # assumes x >= 0, y >= 0
+def _voigt_l(x, y, n: int = 11):
+    if not (((x >= 0.0) & (y >= 0.0)) | torch.isnan(x) | torch.isnan(y)).all():
+        raise ValueError("_voigt_l only defined for x >= 0 and y >= 0")
 
-    N = 11
+    h = math.sqrt(math.pi / (n + 1))
 
-    # h = math.sqrt(math.pi / (N + 1))
-    h = 0.5116633539732443
+    phi = (x / h) - torch.floor(x / h)
 
-    phi = (x / h) - (x / h).floor()
-
-    k = torch.arange(N + 1, dtype=x.dtype, device=x.device)
+    k = torch.arange(n + 1, dtype=x.dtype, device=x.device)
     t = (k + 0.5) * h
     tau = k[1:] * h
 
@@ -103,36 +85,20 @@ def _voigt_l(x, y):
     ).sum(dim=-1)
 
     # equation 13
-    w_mm = (
-        (
-            -2
-            * torch.exp(-x.pow(2) + y.pow(2))
-            * (
-                torch.sin(2 * x * y)
-                - torch.exp(2 * torch.pi * y / h)
-                * torch.sin(2 * torch.pi * x / h - 2 * x * y)
-            )
-        )
-        / (
-            1
-            + torch.exp(4 * torch.pi * y / h)
-            + 2 * torch.exp(2 * torch.pi * y / h) * torch.cos(2 * torch.pi * x / h)
-        )
-    ) + w_m
+    expy = torch.exp(-2 * torch.pi * y / h)
+    w_mm_1 = (
+        -2
+        * torch.exp(-x.pow(2) + y.pow(2))
+        * (torch.sin(2 * x * y) * expy - torch.sin(2 * torch.pi * x / h - 2 * x * y))
+    ) / (expy + 1 / expy + 2 * torch.cos(2 * torch.pi * x / h))
+
+    w_mm = w_mm_1 + w_m
 
     w_mt_1 = (
         -2
         * torch.exp(-x.pow(2) + y.pow(2))
-        * (
-            torch.sin(2 * x * y)
-            + torch.exp(2 * torch.pi * y / h)
-            * torch.sin(2 * torch.pi * x / h - 2 * x * y)
-        )
-    ) / (
-        1
-        + torch.exp(4 * torch.pi * y / h)
-        - 2 * torch.exp(2 * torch.pi * y / h) * torch.cos(2 * torch.pi * x / h)
-    )
+        * (torch.sin(2 * x * y) * expy + torch.sin(2 * torch.pi * x / h - 2 * x * y))
+    ) / (expy + 1 / expy - 2 * torch.cos(2 * torch.pi * x / h))
 
     w_mt_2 = (h * x) / (torch.pi * (x.pow(2) + y.pow(2)))
 
@@ -155,10 +121,6 @@ def _voigt_l(x, y):
     )
 
 
-def _faddeeva_w_impl(z):
-    return _voigt_v(z.real, z.imag) + 1j * _voigt_l(z.real, z.imag)
-
-
 def faddeeva_w(input: Tensor, *, out: Tensor | None = None) -> Tensor:
     r"""
     Faddeeva function.
@@ -175,25 +137,35 @@ def faddeeva_w(input: Tensor, *, out: Tensor | None = None) -> Tensor:
     -------
     Tensor
     """
 
     if not torch.is_complex(input):
         input = torch.complex(input, torch.zeros_like(input))
 
     # use symmetries to map to upper right quadrant of complex plane
     imag_negative = input.imag < 0.0
     input = torch.where(input.imag < 0.0, -input, input)
     real_negative = input.real < 0.0
     input = torch.where(input.real < 0.0, -input.conj(), input)
 
-    a = input.real
-    b = input.imag
+    x = input.real
+    y = input.imag
 
-    assert (a >= 0.0).all()
-    assert (b >= 0.0).all()
+    if not (((x >= 0.0) & (y >= 0.0)) | torch.isnan(x) | torch.isnan(y)).all():
+        raise ValueError("failed to map input to x >= 0, y >= 0")
 
-    output = _voigt_v(a, b) + 1j * _voigt_l(a, b)
+    output = _voigt_v(x, y, n=11) + 1j * _voigt_l(x, y, n=11)
 
-    output = torch.where(imag_negative, 2 * torch.exp(-input.pow(2)) - output, output)
+    # compute real and imaginary parts separately so that we handle infs
+    # without unnecessary nans
+    expz2 = torch.complex(
+        2 * torch.exp(-x.pow(2) + y.pow(2)) * torch.cos(-2 * x * y),
+        2 * torch.exp(-x.pow(2) + y.pow(2)) * torch.sin(-2 * x * y),
+    )
+    output = torch.where(imag_negative, expz2 - output, output)
+    output = torch.where(real_negative, output.conj(), output, out=out)
 
+    if out is not None:
+        out.copy_(output)
+
+        return out
 
-    return torch.where(real_negative, output.conj(), output, out=out)
+    else:
+        return output
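
A note on the source changes above: the hard-coded step size h = 0.5116633539732443 that this commit deletes is just sqrt(pi / 12), i.e. the n = 11 case of the new h = math.sqrt(math.pi / (n + 1)), so the default behavior is unchanged. The quadrant mapping in faddeeva_w rests on the standard Faddeeva identities w(-z) = 2 * exp(-z**2) - w(z) and w(conj(z)) = conj(w(-z)). The sketch below is illustrative only, not part of this commit; it assumes beignet and scipy are installed and checks those identities and the out= path against scipy.special.wofz.

# Illustrative sketch, not part of this commit: sanity-check the symmetry
# identities that faddeeva_w uses to fold inputs into Re >= 0, Im >= 0.
import scipy.special
import torch

import beignet.special

w = beignet.special.faddeeva_w

z = torch.complex(
    torch.empty(1000, dtype=torch.float64).uniform_(-4.0, 4.0),
    torch.empty(1000, dtype=torch.float64).uniform_(-4.0, 4.0),
)

# w(-z) == 2 * exp(-z**2) - w(z)
torch.testing.assert_close(w(-z), 2 * torch.exp(-z.pow(2)) - w(z))

# w(conj(z)) == conj(w(-z))
torch.testing.assert_close(w(z.conj()), w(-z).conj())

# agreement with the SciPy reference implementation
expected = torch.from_numpy(scipy.special.wofz(z.numpy()))
torch.testing.assert_close(w(z), expected)

# the out= keyword copies the result into a preallocated tensor
out = torch.empty_like(z)
w(z, out=out)
torch.testing.assert_close(out, expected)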
69 changes: 33 additions & 36 deletions tests/beignet/special/test__faddeeva_w.py
@@ -1,3 +1,5 @@
+import math
+
 import beignet.special
 import hypothesis
 import hypothesis.strategies
@@ -7,43 +9,23 @@
 
 @hypothesis.strategies.composite
 def _strategy(function):
-    x, y = torch.meshgrid(
-        torch.linspace(
-            function(
-                hypothesis.strategies.floats(
-                    min_value=-10,
-                    max_value=-10,
-                ),
-            ),
-            function(
-                hypothesis.strategies.floats(
-                    min_value=10,
-                    max_value=10,
-                ),
-            ),
-            steps=128,
-            dtype=torch.float64,
-        ),
-        torch.linspace(
-            function(
-                hypothesis.strategies.floats(
-                    min_value=-10,
-                    max_value=-10,
-                ),
-            ),
-            function(
-                hypothesis.strategies.floats(
-                    min_value=10,
-                    max_value=10,
-                ),
-            ),
-            steps=128,
-            dtype=torch.float64,
-        ),
-        indexing="xy",
-    )
+    dtype = function(hypothesis.strategies.sampled_from([torch.float64, torch.float32]))
+
+    # avoid overflow of exp(y^2)
+    limit = math.sqrt(math.log(torch.finfo(dtype).max)) - 0.2
+
+    x = function(
+        hypothesis.strategies.floats(
+            min_value=-limit, max_value=limit, allow_nan=False, allow_infinity=False
+        )
+    )
+    y = function(
+        hypothesis.strategies.floats(
+            min_value=-limit, max_value=limit, allow_nan=False, allow_infinity=False
+        )
+    )
 
-    input = x + 1.0j * y
+    input = torch.complex(torch.tensor(x, dtype=dtype), torch.tensor(y, dtype=dtype))
 
     return input, scipy.special.wofz(input)
 
@@ -52,4 +34,19 @@ def _strategy(function):
 def test_faddeeva_w(data):
     input, output = data
 
-    torch.testing.assert_close(beignet.special.faddeeva_w(input), output)
+    if input.dtype == torch.complex64:
+        rtol, atol = 1e-5, 1e-5
+    elif input.dtype == torch.complex128:
+        rtol, atol = 1e-10, 1e-10
+    else:
+        rtol, atol = None, None
+
+    torch.testing.assert_close(
+        beignet.special.faddeeva_w(input), output, rtol=rtol, atol=atol
+    )
+
+
+def test_faddeeva_w_propagates_nan():
+    input = torch.complex(torch.tensor(torch.nan), torch.tensor(torch.nan))
+    output = beignet.special.faddeeva_w(input)
+    torch.testing.assert_close(input, output, equal_nan=True)
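A note on the new test strategy above: the bound limit = sqrt(log(finfo.max)) - 0.2 exists because the kernels evaluate exp(y**2 - x**2), which overflows once y**2 exceeds log(finfo.max); the extra 0.2 is a small safety margin. The sketch below is illustrative only, not part of this commit, and just prints the resulting per-dtype bounds.

# Illustrative sketch, not part of this commit: the largest |y| the strategy
# will generate for each dtype, and why values just past it overflow.
import math

import torch

for dtype in (torch.float32, torch.float64):
    limit = math.sqrt(math.log(torch.finfo(dtype).max)) - 0.2
    print(dtype, round(limit, 2))  # torch.float32 ~ 9.22, torch.float64 ~ 26.44

    below = torch.exp(torch.tensor(limit, dtype=dtype).pow(2))        # large but finite
    above = torch.exp(torch.tensor(limit + 0.5, dtype=dtype).pow(2))  # overflows to inf
    print(below.isfinite().item(), above.isinf().item())  # True True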