diff --git a/ManifoldEM/GetDistancesS2.py b/ManifoldEM/GetDistancesS2.py index 38be153..604fcc3 100644 --- a/ManifoldEM/GetDistancesS2.py +++ b/ManifoldEM/GetDistancesS2.py @@ -82,7 +82,10 @@ def quats_to_unit_vecs(q: NDArray[Shape["4,*"], Float64]) -> NDArray[Shape["3,*" def get_psi(q: NDArray[Shape["4"], Float64], ref_vec: NDArray[Shape["3"], Float64]) -> float: s = -(1 + ref_vec[2]) * q[3] - ref_vec[0] * q[1] - ref_vec[1] * q[2] c = (1 + ref_vec[2]) * q[0] + ref_vec[1] * q[1] - ref_vec[0] * q[2] - psi = 2 * np.arctan(s / c) # note that the Psi are in the interval [-pi,pi] + if c == 0.0: + psi = np.sign(s) * np.pi + else: + psi = 2 * np.arctan(s / c) # note that the Psi are in the interval [-pi,pi] return psi