Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimized get_unit_normal() and replaced np.cross() with custom cross() in manim.utils.space_ops #3494

Merged
merged 3 commits into from
Dec 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 50 additions & 19 deletions manim/utils/space_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from __future__ import annotations

from manim.typing import Point3D_Array, Vector
from manim.typing import Point3D_Array, Vector, Vector3

__all__ = [
"quaternion_mult",
Expand Down Expand Up @@ -53,6 +53,16 @@ def norm_squared(v: float) -> float:
return np.dot(v, v)


def cross(v1: Vector3, v2: Vector3) -> Vector3:
return np.array(
[
v1[1] * v2[2] - v1[2] * v2[1],
v1[2] * v2[0] - v1[0] * v2[2],
v1[0] * v2[1] - v1[1] * v2[0],
]
)


# Quaternions
# TODO, implement quaternion type

Expand Down Expand Up @@ -273,12 +283,12 @@ def z_to_vector(vector: np.ndarray) -> np.ndarray:
(normalized) vector provided as an argument
"""
axis_z = normalize(vector)
axis_y = normalize(np.cross(axis_z, RIGHT))
axis_x = np.cross(axis_y, axis_z)
axis_y = normalize(cross(axis_z, RIGHT))
axis_x = cross(axis_y, axis_z)
if np.linalg.norm(axis_y) == 0:
# the vector passed just so happened to be in the x direction.
axis_x = normalize(np.cross(UP, axis_z))
axis_y = -np.cross(axis_x, axis_z)
axis_x = normalize(cross(UP, axis_z))
axis_y = -cross(axis_x, axis_z)

return np.array([axis_x, axis_y, axis_z]).T

Expand Down Expand Up @@ -359,7 +369,7 @@ def normalize_along_axis(array: np.ndarray, axis: np.ndarray) -> np.ndarray:
return array


def get_unit_normal(v1: np.ndarray, v2: np.ndarray, tol: float = 1e-6) -> np.ndarray:
def get_unit_normal(v1: Vector3, v2: Vector3, tol: float = 1e-6) -> Vector3:
"""Gets the unit normal of the vectors.

Parameters
Expand All @@ -376,16 +386,37 @@ def get_unit_normal(v1: np.ndarray, v2: np.ndarray, tol: float = 1e-6) -> np.nda
np.ndarray
The normal of the two vectors.
"""
v1, v2 = (normalize(i) for i in (v1, v2))
cp = np.cross(v1, v2)
cp_norm = np.linalg.norm(cp)
if cp_norm < tol:
# Vectors align, so find a normal to them in the plane shared with the z-axis
cp = np.cross(np.cross(v1, OUT), v1)
cp_norm = np.linalg.norm(cp)
if cp_norm < tol:
# Instead of normalizing v1 and v2, just divide by the greatest
# of all their absolute components, which is just enough
div1, div2 = max(np.abs(v1)), max(np.abs(v2))
if div1 == 0.0:
if div2 == 0.0:
return DOWN
return normalize(cp)
u = v2 / div2
elif div2 == 0.0:
u = v1 / div1
else:
# Normal scenario: v1 and v2 are both non-null
u1, u2 = v1 / div1, v2 / div2
cp = cross(u1, u2)
cp_norm = np.sqrt(norm_squared(cp))
if cp_norm > tol:
return cp / cp_norm
# Otherwise, v1 and v2 were aligned
u = u1

# If you are here, you have an "unique", non-zero, unit-ish vector u
# If it's also too aligned to the Z axis, just return DOWN
if abs(u[0]) < tol and abs(u[1]) < tol:
return DOWN
# Otherwise rotate u in the plane it shares with the Z axis,
# 90° TOWARDS the Z axis. This is done via (u x [0, 0, 1]) x u,
# which gives [-xz, -yz, x²+y²] (slightly scaled as well)
cp = np.array([-u[0] * u[2], -u[1] * u[2], u[0] * u[0] + u[1] * u[1]])
cp_norm = np.sqrt(norm_squared(cp))
# Because the norm(u) == 0 case was filtered in the beginning,
# there is no need to check if the norm of cp is 0
return cp / cp_norm


###
Expand Down Expand Up @@ -529,8 +560,8 @@ def line_intersection(
np.pad(np.array(i)[:, :2], ((0, 0), (0, 1)), constant_values=1)
for i in (line1, line2)
)
line1, line2 = (np.cross(*i) for i in padded)
x, y, z = np.cross(line1, line2)
line1, line2 = (cross(*i) for i in padded)
x, y, z = cross(line1, line2)

if z == 0:
raise ValueError(
Expand Down Expand Up @@ -558,7 +589,7 @@ def find_intersection(
result = []

for p0, v0, p1, v1 in zip(*[p0s, v0s, p1s, v1s]):
normal = np.cross(v1, np.cross(v0, v1))
normal = cross(v1, cross(v0, v1))
denom = max(np.dot(v0, normal), threshold)
result += [p0 + np.dot(p1 - p0, normal) / denom * v0]
return result
Expand Down Expand Up @@ -814,6 +845,6 @@ def perpendicular_bisector(
"""
p1 = line[0]
p2 = line[1]
direction = np.cross(p1 - p2, norm_vector)
direction = cross(p1 - p2, norm_vector)
m = midpoint(p1, p2)
return [m + direction, m - direction]
Binary file modified tests/test_graphical_units/control_data/threed/Sphere.npz
Binary file not shown.
Loading