Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add type hints to animation #3457

Draft
wants to merge 10 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 46 additions & 44 deletions manim/animation/animation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,29 +3,36 @@

from __future__ import annotations

from typing import Any, TypeVar

from typing_extensions import Self

from manim.mobject.opengl.opengl_mobject import OpenGLMobject

from .. import config, logger
from ..constants import RendererType
from ..mobject import mobject
from ..mobject.mobject import Mobject
from ..mobject.opengl import opengl_mobject
from ..typing import RateFunc
from ..utils.rate_functions import linear, smooth

__all__ = ["Animation", "Wait", "override_animation"]


from copy import deepcopy
from typing import TYPE_CHECKING, Callable, Iterable, Sequence
from typing import TYPE_CHECKING, Callable, Iterable

from typing_extensions import Self

if TYPE_CHECKING:
from manim.scene.scene import Scene


DEFAULT_ANIMATION_RUN_TIME: float = 1.0
DEFAULT_ANIMATION_LAG_RATIO: float = 0.0
AnyCallableT = TypeVar("AnyCallableT", bound=Callable[..., Any])

DEFAULT_ANIMATION_RUN_TIME = 1.0
DEFAULT_ANIMATION_LAG_RATIO = 0.0


class Animation:
Expand Down Expand Up @@ -110,10 +117,10 @@

def __new__(
cls,
mobject=None,
*args,
use_override=True,
**kwargs,
mobject: Mobject | None = None,
*args: Any,
use_override: bool = True,
**kwargs: Any,
) -> Self:
if isinstance(mobject, Mobject) and use_override:
func = mobject.animation_override_for(cls)
Expand All @@ -132,34 +139,32 @@
mobject: Mobject | None,
lag_ratio: float = DEFAULT_ANIMATION_LAG_RATIO,
run_time: float = DEFAULT_ANIMATION_RUN_TIME,
rate_func: Callable[[float], float] = smooth,
rate_func: RateFunc = smooth,
reverse_rate_function: bool = False,
name: str = None,
remover: bool = False, # remove a mobject from the screen?
suspend_mobject_updating: bool = True,
introducer: bool = False,
*,
_on_finish: Callable[[], None] = lambda _: None,
**kwargs,
**kwargs: Any,
) -> None:
self._typecheck_input(mobject)
self.run_time: float = run_time
self.rate_func: Callable[[float], float] = rate_func
self.reverse_rate_function: bool = reverse_rate_function
self.name: str | None = name
self.remover: bool = remover
self.introducer: bool = introducer
self.suspend_mobject_updating: bool = suspend_mobject_updating
self.lag_ratio: float = lag_ratio
self._on_finish: Callable[[Scene], None] = _on_finish
self.run_time = run_time

Check warning

Code scanning / CodeQL

Overwriting attribute in super-class or sub-class Warning

Assignment overwrites attribute run_time, which was previously defined in subclass
Flash
.
Assignment overwrites attribute run_time, which was previously defined in subclass
TransformAnimations
.
Assignment overwrites attribute run_time, which was previously defined in subclass
TransformAnimations
.
self.rate_func = rate_func

Check warning

Code scanning / CodeQL

Overwriting attribute in super-class or sub-class Warning

Assignment overwrites attribute rate_func, which was previously defined in subclass
AnimationGroup
.
Assignment overwrites attribute rate_func, which was previously defined in subclass
ChangeSpeed
.
self.reverse_rate_function = reverse_rate_function
self.name = name
self.remover = remover

Check warning

Code scanning / CodeQL

Overwriting attribute in super-class or sub-class Warning

Assignment overwrites attribute remover, which was previously defined in subclass
ShowPassingFlashWithThinningStrokeWidth
.
self.introducer = introducer
self.suspend_mobject_updating = suspend_mobject_updating
self.lag_ratio = lag_ratio
self._on_finish = _on_finish
if config["renderer"] == RendererType.OPENGL:
self.starting_mobject: OpenGLMobject = OpenGLMobject()
self.mobject: OpenGLMobject = (
mobject if mobject is not None else OpenGLMobject()
)
self.starting_mobject = OpenGLMobject()
self.mobject = mobject if mobject is not None else OpenGLMobject()
else:
self.starting_mobject: Mobject = Mobject()
self.mobject: Mobject = mobject if mobject is not None else Mobject()
self.starting_mobject = Mobject()
self.mobject = mobject if mobject is not None else Mobject()
if kwargs:
logger.debug("Animation received extra kwargs: %s", kwargs)

Expand Down Expand Up @@ -237,7 +242,7 @@
if self.is_remover():
scene.remove(self.mobject)

def _setup_scene(self, scene: Scene) -> None:
def _setup_scene(self, scene: Scene | None) -> None:
"""Setup up the :class:`~.Scene` before starting the animation.

This includes to :meth:`~.Scene.add` the Animation's
Expand All @@ -260,7 +265,7 @@
# Keep track of where the mobject starts
return self.mobject.copy()

def get_all_mobjects(self) -> Sequence[Mobject]:
def get_all_mobjects(self) -> list[Mobject]:
"""Get all mobjects involved in the animation.

Ordering must match the ordering of arguments to interpolate_submobject
Expand All @@ -270,7 +275,7 @@
Sequence[Mobject]
The sequence of mobjects.
"""
return self.mobject, self.starting_mobject
return [self.mobject, self.starting_mobject]

def get_all_families_zipped(self) -> Iterable[tuple]:
if config["renderer"] == RendererType.OPENGL:
Expand Down Expand Up @@ -301,9 +306,9 @@
# The surrounding scene typically handles
# updating of self.mobject. Besides, in
# most cases its updating is suspended anyway
return list(filter(lambda m: m is not self.mobject, self.get_all_mobjects()))
return [m for m in self.get_all_mobjects() if m is not self.mobject]

def copy(self) -> Animation:
def copy(self) -> Self:
"""Create a copy of the animation.

Returns
Expand Down Expand Up @@ -350,7 +355,7 @@
starting_submobject: Mobject,
# target_copy: Mobject, #Todo: fix - signature of interpolate_submobject differs in Transform().
alpha: float,
) -> Animation:
) -> None:
# Typically implemented by subclass
pass

Expand Down Expand Up @@ -384,7 +389,7 @@
return self.rate_func(value - lower)

# Getters and setters
def set_run_time(self, run_time: float) -> Animation:
def set_run_time(self, run_time: float) -> Self:
"""Set the run time of the animation.

Parameters
Expand Down Expand Up @@ -417,8 +422,8 @@

def set_rate_func(
self,
rate_func: Callable[[float], float],
) -> Animation:
rate_func: RateFunc,
) -> Self:
"""Set the rate function of the animation.

Parameters
Expand All @@ -437,7 +442,7 @@

def get_rate_func(
self,
) -> Callable[[float], float]:
) -> RateFunc:
"""Get the rate function of the animation.

Returns
Expand All @@ -447,7 +452,7 @@
"""
return self.rate_func

def set_name(self, name: str) -> Animation:
def set_name(self, name: str) -> Self:
"""Set the name of the animation.

Parameters
Expand Down Expand Up @@ -513,10 +518,7 @@
TypeError: Object 42 cannot be converted to an animation

"""
if isinstance(anim, mobject._AnimationBuilder):
return anim.build()

if isinstance(anim, opengl_mobject._AnimationBuilder):
if isinstance(anim, (mobject._AnimationBuilder, opengl_mobject._AnimationBuilder)):
return anim.build()

if isinstance(anim, Animation):
Expand Down Expand Up @@ -553,9 +555,9 @@
run_time: float = 1,
stop_condition: Callable[[], bool] | None = None,
frozen_frame: bool | None = None,
rate_func: Callable[[float], float] = linear,
**kwargs,
):
rate_func: RateFunc = linear,
**kwargs: Any,
) -> None:
if stop_condition and frozen_frame:
raise ValueError("A static Wait animation cannot have a stop condition.")

Expand Down Expand Up @@ -584,7 +586,7 @@

def override_animation(
animation_class: type[Animation],
) -> Callable[[Callable], Callable]:
) -> Callable[[AnyCallableT], AnyCallableT]:
"""Decorator used to mark methods as overrides for specific :class:`~.Animation` types.

Should only be used to decorate methods of classes derived from :class:`~.Mobject`.
Expand Down Expand Up @@ -621,7 +623,7 @@

"""

def decorator(func):
def decorator(func: AnyCallableT) -> AnyCallableT:
func._override_animation = animation_class
return func

Expand Down
40 changes: 23 additions & 17 deletions manim/animation/changing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,14 @@

__all__ = ["AnimatedBoundary", "TracedPath"]

from typing import Callable
from typing import Any, Callable, Sequence

from typing_extensions import Self

from manim import Mobject
from manim.mobject.opengl.opengl_compatibility import ConvertToOpenGL
from manim.mobject.types.vectorized_mobject import VGroup, VMobject
from manim.typing import Point3D, RateFunc
from manim.utils.color import (
BLUE_B,
BLUE_D,
Expand Down Expand Up @@ -38,15 +42,15 @@ def construct(self):

def __init__(
self,
vmobject,
colors=[BLUE_D, BLUE_B, BLUE_E, GREY_BROWN],
max_stroke_width=3,
cycle_rate=0.5,
back_and_forth=True,
draw_rate_func=smooth,
fade_rate_func=smooth,
**kwargs,
):
vmobject: VMobject,
colors: Sequence[ParsableManimColor] = [BLUE_D, BLUE_B, BLUE_E, GREY_BROWN],
max_stroke_width: float = 3,
cycle_rate: float = 0.5,
back_and_forth: bool = True,
draw_rate_func: RateFunc = smooth,
fade_rate_func: RateFunc = smooth,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
self.colors = colors
self.max_stroke_width = max_stroke_width
Expand All @@ -56,13 +60,13 @@ def __init__(
self.fade_rate_func = fade_rate_func
self.vmobject = vmobject
self.boundary_copies = [
vmobject.copy().set_style(stroke_width=0, fill_opacity=0) for x in range(2)
vmobject.copy().set_style(stroke_width=0, fill_opacity=0) for _ in range(2)
]
self.add(*self.boundary_copies)
self.total_time = 0
self.add_updater(lambda m, dt: self.update_boundary_copies(dt))

def update_boundary_copies(self, dt):
def update_boundary_copies(self, dt: float) -> None:
# Not actual time, but something which passes at
# an altered rate to make the implementation below
# cleaner
Expand Down Expand Up @@ -90,7 +94,9 @@ def update_boundary_copies(self, dt):

self.total_time += dt

def full_family_become_partial(self, mob1, mob2, a, b):
def full_family_become_partial(
self, mob1: VMobject, mob2: VMobject, a: float, b: float
) -> Self:
family1 = mob1.family_members_with_points()
family2 = mob2.family_members_with_points()
for sm1, sm2 in zip(family1, family2):
Expand Down Expand Up @@ -142,19 +148,19 @@ def construct(self):

def __init__(
self,
traced_point_func: Callable,
traced_point_func: Callable[[], Point3D],
stroke_width: float = 2,
stroke_color: ParsableManimColor | None = WHITE,
dissipating_time: float | None = None,
**kwargs,
):
**kwargs: Any,
) -> None:
super().__init__(stroke_color=stroke_color, stroke_width=stroke_width, **kwargs)
self.traced_point_func = traced_point_func
self.dissipating_time = dissipating_time
self.time = 1 if self.dissipating_time else None
self.add_updater(self.update_path)

def update_path(self, mob, dt):
def update_path(self, mob: Mobject, dt: float) -> None:
new_point = self.traced_point_func()
if not self.has_points():
self.start_new_path(new_point)
Expand Down
Loading
Loading