Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
0x00b1 committed Jun 4, 2024
1 parent 7d32a74 commit eaa9f62
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 17 deletions.
6 changes: 4 additions & 2 deletions src/beignet/special/_erf.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from torch import Tensor

from ._erfc import erfc


def erf(z):
return 1 - erfc(z)
def erf(input: Tensor) -> Tensor:
return 1.0 - erfc(input)
5 changes: 3 additions & 2 deletions src/beignet/special/_erfc.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import torch
from torch import Tensor

from ._faddeeva_w import faddeeva_w


def erfc(z):
return torch.exp(-z.pow(2)) * faddeeva_w(1j * z)
def erfc(input: Tensor) -> Tensor:
return torch.exp(-torch.pow(input, 2)) * faddeeva_w(1.0j * input)
27 changes: 14 additions & 13 deletions src/beignet/special/_faddeeva_w.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import torch
from torch import Tensor


def _voigt_v_impl(x, y):
Expand Down Expand Up @@ -199,12 +200,12 @@ def _faddeeva_w_impl(z):
# )


def faddeeva_w(z: torch.Tensor):
"""Compute faddeeva w function using method described in [1].
def faddeeva_w(input: Tensor):
r"""Compute faddeeva w function using method described in [1].
Parameterz
Parameters
----------
z: torch.Tensor
input: Tensor
complex input
References
Expand All @@ -214,14 +215,14 @@ def faddeeva_w(z: torch.Tensor):
SIAM Journal on Numerical Analysis 59.5 (2021): 2346-2367.
"""
# use symmetries to map to upper right quadrant of complex plane
imag_negative = z.imag < 0.0
z = torch.where(z.imag < 0.0, -z, z)
real_negative = z.real < 0.0
z = torch.where(z.real < 0.0, -z.conj(), z)
assert (z.real >= 0.0).all()
assert (z.imag >= 0.0).all()

out = _faddeeva_w_impl(z)
out = torch.where(imag_negative, 2 * torch.exp(-z.pow(2)) - out, out)
imag_negative = input.imag < 0.0
input = torch.where(input.imag < 0.0, -input, input)
real_negative = input.real < 0.0
input = torch.where(input.real < 0.0, -input.conj(), input)
assert (input.real >= 0.0).all()
assert (input.imag >= 0.0).all()

out = _faddeeva_w_impl(input)
out = torch.where(imag_negative, 2 * torch.exp(-input.pow(2)) - out, out)
out = torch.where(real_negative, out.conj(), out)
return out

0 comments on commit eaa9f62

Please sign in to comment.