From eaa9f62ae8c73fb05786236d7bb340a142836bcd Mon Sep 17 00:00:00 2001 From: Allen Goodman Date: Tue, 4 Jun 2024 16:48:36 -0400 Subject: [PATCH] cleanup --- src/beignet/special/_erf.py | 6 ++++-- src/beignet/special/_erfc.py | 5 +++-- src/beignet/special/_faddeeva_w.py | 27 ++++++++++++++------------- 3 files changed, 21 insertions(+), 17 deletions(-) diff --git a/src/beignet/special/_erf.py b/src/beignet/special/_erf.py index 2f00617a3f..36cbba4372 100644 --- a/src/beignet/special/_erf.py +++ b/src/beignet/special/_erf.py @@ -1,5 +1,7 @@ +from torch import Tensor + from ._erfc import erfc -def erf(z): - return 1 - erfc(z) +def erf(input: Tensor) -> Tensor: + return 1.0 - erfc(input) diff --git a/src/beignet/special/_erfc.py b/src/beignet/special/_erfc.py index 0fde39f680..03f7b732cc 100644 --- a/src/beignet/special/_erfc.py +++ b/src/beignet/special/_erfc.py @@ -1,7 +1,8 @@ import torch +from torch import Tensor from ._faddeeva_w import faddeeva_w -def erfc(z): - return torch.exp(-z.pow(2)) * faddeeva_w(1j * z) +def erfc(input: Tensor) -> Tensor: + return torch.exp(-torch.pow(input, 2)) * faddeeva_w(1.0j * input) diff --git a/src/beignet/special/_faddeeva_w.py b/src/beignet/special/_faddeeva_w.py index 778687d522..1fb9f09ea8 100644 --- a/src/beignet/special/_faddeeva_w.py +++ b/src/beignet/special/_faddeeva_w.py @@ -1,4 +1,5 @@ import torch +from torch import Tensor def _voigt_v_impl(x, y): @@ -199,12 +200,12 @@ def _faddeeva_w_impl(z): # ) -def faddeeva_w(z: torch.Tensor): - """Compute faddeeva w function using method described in [1]. +def faddeeva_w(input: Tensor): + r"""Compute faddeeva w function using method described in [1]. - Parameterz + Parameters ---------- - z: torch.Tensor + input: Tensor complex input References @@ -214,14 +215,14 @@ def faddeeva_w(z: torch.Tensor): SIAM Journal on Numerical Analysis 59.5 (2021): 2346-2367. """ # use symmetries to map to upper right quadrant of complex plane - imag_negative = z.imag < 0.0 - z = torch.where(z.imag < 0.0, -z, z) - real_negative = z.real < 0.0 - z = torch.where(z.real < 0.0, -z.conj(), z) - assert (z.real >= 0.0).all() - assert (z.imag >= 0.0).all() - - out = _faddeeva_w_impl(z) - out = torch.where(imag_negative, 2 * torch.exp(-z.pow(2)) - out, out) + imag_negative = input.imag < 0.0 + input = torch.where(input.imag < 0.0, -input, input) + real_negative = input.real < 0.0 + input = torch.where(input.real < 0.0, -input.conj(), input) + assert (input.real >= 0.0).all() + assert (input.imag >= 0.0).all() + + out = _faddeeva_w_impl(input) + out = torch.where(imag_negative, 2 * torch.exp(-input.pow(2)) - out, out) out = torch.where(real_negative, out.conj(), out) return out