diff --git a/docs/source/contributing/docs/typings.rst b/docs/source/contributing/docs/typings.rst index 748cff1e02..72891b0298 100644 --- a/docs/source/contributing/docs/typings.rst +++ b/docs/source/contributing/docs/typings.rst @@ -115,8 +115,8 @@ Typing guidelines from typing import TYPE_CHECKING if TYPE_CHECKING: - from manim.typing import Vector3 - # type stuff with Vector3 + from manim.typing import Vector3D + # type stuff with Vector3D Missing Sections for typehints are: ----------------------------------- diff --git a/manim/animation/animation.py b/manim/animation/animation.py index 90290eee3e..4ebded2d1c 100644 --- a/manim/animation/animation.py +++ b/manim/animation/animation.py @@ -3,6 +3,10 @@ from __future__ import annotations +from typing import Any, TypeVar + +from typing_extensions import Self + from manim.mobject.opengl.opengl_mobject import OpenGLMobject from .. import config, logger @@ -10,13 +14,14 @@ from ..mobject import mobject from ..mobject.mobject import Mobject from ..mobject.opengl import opengl_mobject +from ..typing import RateFunc from ..utils.rate_functions import linear, smooth __all__ = ["Animation", "Wait", "override_animation"] from copy import deepcopy -from typing import TYPE_CHECKING, Callable, Iterable, Sequence +from typing import TYPE_CHECKING, Callable, Iterable from typing_extensions import Self @@ -24,8 +29,10 @@ from manim.scene.scene import Scene -DEFAULT_ANIMATION_RUN_TIME: float = 1.0 -DEFAULT_ANIMATION_LAG_RATIO: float = 0.0 +AnyCallableT = TypeVar("AnyCallableT", bound=Callable[..., Any]) + +DEFAULT_ANIMATION_RUN_TIME = 1.0 +DEFAULT_ANIMATION_LAG_RATIO = 0.0 class Animation: @@ -110,10 +117,10 @@ def construct(self): def __new__( cls, - mobject=None, - *args, - use_override=True, - **kwargs, + mobject: Mobject | None = None, + *args: Any, + use_override: bool = True, + **kwargs: Any, ) -> Self: if isinstance(mobject, Mobject) and use_override: func = mobject.animation_override_for(cls) @@ -132,34 +139,32 @@ def __init__( mobject: Mobject | None, lag_ratio: float = DEFAULT_ANIMATION_LAG_RATIO, run_time: float = DEFAULT_ANIMATION_RUN_TIME, - rate_func: Callable[[float], float] = smooth, + rate_func: RateFunc = smooth, reverse_rate_function: bool = False, - name: str = None, + name: str | None = None, remover: bool = False, # remove a mobject from the screen? suspend_mobject_updating: bool = True, introducer: bool = False, *, - _on_finish: Callable[[], None] = lambda _: None, - **kwargs, + _on_finish: Callable[[Any], None] = lambda _: None, + **kwargs: Any, ) -> None: self._typecheck_input(mobject) - self.run_time: float = run_time - self.rate_func: Callable[[float], float] = rate_func - self.reverse_rate_function: bool = reverse_rate_function - self.name: str | None = name - self.remover: bool = remover - self.introducer: bool = introducer - self.suspend_mobject_updating: bool = suspend_mobject_updating - self.lag_ratio: float = lag_ratio - self._on_finish: Callable[[Scene], None] = _on_finish + self.run_time = run_time + self.rate_func = rate_func + self.reverse_rate_function = reverse_rate_function + self.name = name + self.remover = remover + self.introducer = introducer + self.suspend_mobject_updating = suspend_mobject_updating + self.lag_ratio = lag_ratio + self._on_finish = _on_finish if config["renderer"] == RendererType.OPENGL: - self.starting_mobject: OpenGLMobject = OpenGLMobject() - self.mobject: OpenGLMobject = ( - mobject if mobject is not None else OpenGLMobject() - ) + self.starting_mobject = OpenGLMobject() + self.mobject = mobject if mobject is not None else OpenGLMobject() else: - self.starting_mobject: Mobject = Mobject() - self.mobject: Mobject = mobject if mobject is not None else Mobject() + self.starting_mobject = Mobject() + self.mobject = mobject if mobject is not None else Mobject() if kwargs: logger.debug("Animation received extra kwargs: %s", kwargs) @@ -237,7 +242,7 @@ def clean_up_from_scene(self, scene: Scene) -> None: if self.is_remover(): scene.remove(self.mobject) - def _setup_scene(self, scene: Scene) -> None: + def _setup_scene(self, scene: Scene | None) -> None: """Setup up the :class:`~.Scene` before starting the animation. This includes to :meth:`~.Scene.add` the Animation's @@ -260,7 +265,7 @@ def create_starting_mobject(self) -> Mobject: # Keep track of where the mobject starts return self.mobject.copy() - def get_all_mobjects(self) -> Sequence[Mobject]: + def get_all_mobjects(self) -> list[Mobject]: """Get all mobjects involved in the animation. Ordering must match the ordering of arguments to interpolate_submobject @@ -270,7 +275,7 @@ def get_all_mobjects(self) -> Sequence[Mobject]: Sequence[Mobject] The sequence of mobjects. """ - return self.mobject, self.starting_mobject + return [self.mobject, self.starting_mobject] def get_all_families_zipped(self) -> Iterable[tuple]: if config["renderer"] == RendererType.OPENGL: @@ -301,9 +306,9 @@ def get_all_mobjects_to_update(self) -> list[Mobject]: # The surrounding scene typically handles # updating of self.mobject. Besides, in # most cases its updating is suspended anyway - return list(filter(lambda m: m is not self.mobject, self.get_all_mobjects())) + return [m for m in self.get_all_mobjects() if m is not self.mobject] - def copy(self) -> Animation: + def copy(self) -> Self: """Create a copy of the animation. Returns @@ -350,7 +355,7 @@ def interpolate_submobject( starting_submobject: Mobject, # target_copy: Mobject, #Todo: fix - signature of interpolate_submobject differs in Transform(). alpha: float, - ) -> Animation: + ) -> None: # Typically implemented by subclass pass @@ -384,7 +389,7 @@ def get_sub_alpha(self, alpha: float, index: int, num_submobjects: int) -> float return self.rate_func(value - lower) # Getters and setters - def set_run_time(self, run_time: float) -> Animation: + def set_run_time(self, run_time: float) -> Self: """Set the run time of the animation. Parameters @@ -417,8 +422,8 @@ def get_run_time(self) -> float: def set_rate_func( self, - rate_func: Callable[[float], float], - ) -> Animation: + rate_func: RateFunc, + ) -> Self: """Set the rate function of the animation. Parameters @@ -437,7 +442,7 @@ def set_rate_func( def get_rate_func( self, - ) -> Callable[[float], float]: + ) -> RateFunc: """Get the rate function of the animation. Returns @@ -447,7 +452,7 @@ def get_rate_func( """ return self.rate_func - def set_name(self, name: str) -> Animation: + def set_name(self, name: str) -> Self: """Set the name of the animation. Parameters @@ -513,10 +518,7 @@ def prepare_animation( TypeError: Object 42 cannot be converted to an animation """ - if isinstance(anim, mobject._AnimationBuilder): - return anim.build() - - if isinstance(anim, opengl_mobject._AnimationBuilder): + if isinstance(anim, (mobject._AnimationBuilder, opengl_mobject._AnimationBuilder)): return anim.build() if isinstance(anim, Animation): @@ -553,9 +555,9 @@ def __init__( run_time: float = 1, stop_condition: Callable[[], bool] | None = None, frozen_frame: bool | None = None, - rate_func: Callable[[float], float] = linear, - **kwargs, - ): + rate_func: RateFunc = linear, + **kwargs: Any, + ) -> None: if stop_condition and frozen_frame: raise ValueError("A static Wait animation cannot have a stop condition.") @@ -584,7 +586,7 @@ def interpolate(self, alpha: float) -> None: def override_animation( animation_class: type[Animation], -) -> Callable[[Callable], Callable]: +) -> Callable[[AnyCallableT], AnyCallableT]: """Decorator used to mark methods as overrides for specific :class:`~.Animation` types. Should only be used to decorate methods of classes derived from :class:`~.Mobject`. @@ -621,7 +623,7 @@ def construct(self): """ - def decorator(func): + def decorator(func: AnyCallableT) -> AnyCallableT: func._override_animation = animation_class return func diff --git a/manim/animation/changing.py b/manim/animation/changing.py index bb11cfc0a4..fd940eb8e2 100644 --- a/manim/animation/changing.py +++ b/manim/animation/changing.py @@ -4,10 +4,13 @@ __all__ = ["AnimatedBoundary", "TracedPath"] -from typing import Callable +from typing import TYPE_CHECKING, Any, Callable, Sequence + +from typing_extensions import Self from manim.mobject.opengl.opengl_compatibility import ConvertToOpenGL from manim.mobject.types.vectorized_mobject import VGroup, VMobject +from manim.typing import Point3D, RateFunc from manim.utils.color import ( BLUE_B, BLUE_D, @@ -18,6 +21,9 @@ ) from manim.utils.rate_functions import smooth +if TYPE_CHECKING: + from manim import Mobject + class AnimatedBoundary(VGroup): """Boundary of a :class:`.VMobject` with animated color change. @@ -38,15 +44,15 @@ def construct(self): def __init__( self, - vmobject, - colors=[BLUE_D, BLUE_B, BLUE_E, GREY_BROWN], - max_stroke_width=3, - cycle_rate=0.5, - back_and_forth=True, - draw_rate_func=smooth, - fade_rate_func=smooth, - **kwargs, - ): + vmobject: VMobject, + colors: Sequence[ParsableManimColor] = [BLUE_D, BLUE_B, BLUE_E, GREY_BROWN], + max_stroke_width: float = 3, + cycle_rate: float = 0.5, + back_and_forth: bool = True, + draw_rate_func: RateFunc = smooth, + fade_rate_func: RateFunc = smooth, + **kwargs: Any, + ) -> None: super().__init__(**kwargs) self.colors = colors self.max_stroke_width = max_stroke_width @@ -56,13 +62,13 @@ def __init__( self.fade_rate_func = fade_rate_func self.vmobject = vmobject self.boundary_copies = [ - vmobject.copy().set_style(stroke_width=0, fill_opacity=0) for x in range(2) + vmobject.copy().set_style(stroke_width=0, fill_opacity=0) for _ in range(2) ] self.add(*self.boundary_copies) self.total_time = 0 self.add_updater(lambda m, dt: self.update_boundary_copies(dt)) - def update_boundary_copies(self, dt): + def update_boundary_copies(self, dt: float) -> None: # Not actual time, but something which passes at # an altered rate to make the implementation below # cleaner @@ -90,7 +96,9 @@ def update_boundary_copies(self, dt): self.total_time += dt - def full_family_become_partial(self, mob1, mob2, a, b): + def full_family_become_partial( + self, mob1: VMobject, mob2: VMobject, a: float, b: float + ) -> Self: family1 = mob1.family_members_with_points() family2 = mob2.family_members_with_points() for sm1, sm2 in zip(family1, family2): @@ -142,19 +150,19 @@ def construct(self): def __init__( self, - traced_point_func: Callable, + traced_point_func: Callable[[], Point3D], stroke_width: float = 2, stroke_color: ParsableManimColor | None = WHITE, dissipating_time: float | None = None, - **kwargs, - ): + **kwargs: Any, + ) -> None: super().__init__(stroke_color=stroke_color, stroke_width=stroke_width, **kwargs) self.traced_point_func = traced_point_func self.dissipating_time = dissipating_time self.time = 1 if self.dissipating_time else None self.add_updater(self.update_path) - def update_path(self, mob, dt): + def update_path(self, mob: Mobject, dt: float) -> None: new_point = self.traced_point_func() if not self.has_points(): self.start_new_path(new_point) diff --git a/manim/animation/composition.py b/manim/animation/composition.py index ff219405f2..544eeab987 100644 --- a/manim/animation/composition.py +++ b/manim/animation/composition.py @@ -4,7 +4,7 @@ from __future__ import annotations import types -from typing import TYPE_CHECKING, Callable, Iterable, Sequence +from typing import TYPE_CHECKING, Any, Callable, Iterable import numpy as np @@ -16,6 +16,7 @@ from ..constants import RendererType from ..mobject.mobject import Group, Mobject from ..scene.scene import Scene +from ..typing import RateFunc from ..utils.iterables import remove_list_redundancies from ..utils.rate_functions import linear @@ -27,7 +28,7 @@ __all__ = ["AnimationGroup", "Succession", "LaggedStart", "LaggedStartMap"] -DEFAULT_LAGGED_START_LAG_RATIO: float = 0.05 +DEFAULT_LAGGED_START_LAG_RATIO = 0.05 class AnimationGroup(Animation): @@ -59,9 +60,9 @@ def __init__( *animations: Animation | Iterable[Animation] | types.GeneratorType[Animation], group: Group | VGroup | OpenGLGroup | OpenGLVGroup = None, run_time: float | None = None, - rate_func: Callable[[float], float] = linear, + rate_func: RateFunc = linear, lag_ratio: float = 0, - **kwargs, + **kwargs: Any, ) -> None: arg_anim = flatten_iterable_parameters(animations) self.animations = [prepare_animation(anim) for anim in arg_anim] @@ -80,7 +81,7 @@ def __init__( ) self.run_time: float = self.init_run_time(run_time) - def get_all_mobjects(self) -> Sequence[Mobject]: + def get_all_mobjects(self) -> list[Mobject]: return list(self.group) def begin(self) -> None: @@ -99,7 +100,7 @@ def begin(self) -> None: for anim in self.animations: anim.begin() - def _setup_scene(self, scene) -> None: + def _setup_scene(self, scene: Scene | None) -> None: for anim in self.animations: anim._setup_scene(scene) @@ -120,7 +121,7 @@ def update_mobjects(self, dt: float) -> None: for anim in self.animations: anim.update_mobjects(dt) - def init_run_time(self, run_time) -> float: + def init_run_time(self, run_time: float | None) -> float: """Calculates the run time of the animation, if different from ``run_time``. Parameters @@ -204,7 +205,9 @@ def construct(self): )) """ - def __init__(self, *animations: Animation, lag_ratio: float = 1, **kwargs) -> None: + def __init__( + self, *animations: Animation, lag_ratio: float = 1, **kwargs: Any + ) -> None: super().__init__(*animations, lag_ratio=lag_ratio, **kwargs) def begin(self) -> None: @@ -219,7 +222,7 @@ def update_mobjects(self, dt: float) -> None: if self.active_animation: self.active_animation.update_mobjects(dt) - def _setup_scene(self, scene) -> None: + def _setup_scene(self, scene: Scene | None) -> None: if scene is None: return if self.is_introducer(): @@ -311,8 +314,8 @@ def __init__( self, *animations: Animation, lag_ratio: float = DEFAULT_LAGGED_START_LAG_RATIO, - **kwargs, - ): + **kwargs: Any, + ) -> None: super().__init__(*animations, lag_ratio=lag_ratio, **kwargs) @@ -358,9 +361,9 @@ def __init__( self, AnimationClass: Callable[..., Animation], mobject: Mobject, - arg_creator: Callable[[Mobject], str] = None, + arg_creator: Callable[[Mobject], str] | None = None, run_time: float = 2, - **kwargs, + **kwargs: Any, ) -> None: args_list = [] for submob in mobject: diff --git a/manim/animation/creation.py b/manim/animation/creation.py index 6f8173e35a..2b371b2d7d 100644 --- a/manim/animation/creation.py +++ b/manim/animation/creation.py @@ -74,15 +74,17 @@ def construct(self): import itertools as it -from typing import TYPE_CHECKING, Callable, Iterable, Sequence +from typing import TYPE_CHECKING, Any, Callable, Iterable import numpy as np +import numpy.typing as npt if TYPE_CHECKING: from manim.mobject.text.text_mobject import Text from manim.mobject.opengl.opengl_surface import OpenGLSurface from manim.mobject.opengl.opengl_vectorized_mobject import OpenGLVMobject +from manim.typing import ManimFloat, RateFunc from manim.utils.color import ManimColor from .. import config @@ -112,8 +114,8 @@ class ShowPartial(Animation): def __init__( self, mobject: VMobject | OpenGLVMobject | OpenGLSurface | None, - **kwargs, - ): + **kwargs: Any, + ) -> None: pointwise = getattr(mobject, "pointwise_become_partial", None) if not callable(pointwise): raise NotImplementedError("This animation is not defined for this Mobject.") @@ -165,7 +167,7 @@ def __init__( mobject: VMobject | OpenGLVMobject | OpenGLSurface, lag_ratio: float = 1.0, introducer: bool = True, - **kwargs, + **kwargs: Any, ) -> None: super().__init__(mobject, lag_ratio=lag_ratio, introducer=introducer, **kwargs) @@ -266,14 +268,14 @@ def get_stroke_color(self, vmobject: VMobject | OpenGLVMobject) -> ManimColor: return vmobject.get_stroke_color() return vmobject.get_color() - def get_all_mobjects(self) -> Sequence[Mobject]: + def get_all_mobjects(self) -> list[Mobject]: return [*super().get_all_mobjects(), self.outline] def interpolate_submobject( self, submobject: Mobject, starting_submobject: Mobject, - outline, + outline: Mobject, alpha: float, ) -> None: # Fixme: not matching the parent class? What is outline doing here? index: int @@ -316,9 +318,9 @@ def construct(self): def __init__( self, vmobject: VMobject | OpenGLVMobject, - rate_func: Callable[[float], float] = linear, + rate_func: RateFunc = linear, reverse: bool = False, - **kwargs, + **kwargs: Any, ) -> None: run_time: float | None = kwargs.pop("run_time", None) lag_ratio: float | None = kwargs.pop("lag_ratio", None) @@ -400,9 +402,9 @@ def construct(self): def __init__( self, vmobject: VMobject, - rate_func: Callable[[float], float] = linear, + rate_func: RateFunc = linear, reverse: bool = True, - **kwargs, + **kwargs: Any, ) -> None: run_time: float | None = kwargs.pop("run_time", None) lag_ratio: float | None = kwargs.pop("lag_ratio", None) @@ -453,8 +455,8 @@ def __init__( self, shapes: Mobject, scale_factor: float = 8, - fade_in_fraction=0.3, - **kwargs, + fade_in_fraction: float = 0.3, + **kwargs: Any, ) -> None: self.shapes = shapes self.scale_factor = scale_factor @@ -504,9 +506,9 @@ def __init__( self, group: Mobject, suspend_mobject_updating: bool = False, - int_func: Callable[[np.ndarray], np.ndarray] = np.floor, - reverse_rate_function=False, - **kwargs, + int_func: Callable[[npt.NDArray[ManimFloat]], ManimFloat] = np.floor, + reverse_rate_function: bool = False, + **kwargs: Any, ) -> None: self.all_submobs = list(group.submobjects) self.int_func = int_func @@ -554,13 +556,15 @@ def __init__( self, text: Text, suspend_mobject_updating: bool = False, - int_func: Callable[[np.ndarray], np.ndarray] = np.ceil, - rate_func: Callable[[float], float] = linear, + int_func: Callable[ + [npt.NDArray[ManimFloat]], npt.NDArray[ManimFloat] + ] = np.ceil, + rate_func: RateFunc = linear, time_per_char: float = 0.1, run_time: float | None = None, - reverse_rate_function=False, - introducer=True, - **kwargs, + reverse_rate_function: bool = False, + introducer: bool = True, + **kwargs: Any, ) -> None: self.time_per_char = time_per_char # Check for empty text using family_members_with_points() @@ -602,14 +606,16 @@ def __init__( self, text: Text, suspend_mobject_updating: bool = False, - int_func: Callable[[np.ndarray], np.ndarray] = np.ceil, - rate_func: Callable[[float], float] = linear, + int_func: Callable[ + [npt.NDArray[ManimFloat]], npt.NDArray[ManimFloat] + ] = np.ceil, + rate_func: RateFunc = linear, time_per_char: float = 0.1, run_time: float | None = None, - reverse_rate_function=True, - introducer=False, - remover=True, - **kwargs, + reverse_rate_function: bool = True, + introducer: bool = False, + remover: bool = True, + **kwargs: Any, ) -> None: super().__init__( text, @@ -631,8 +637,10 @@ class ShowSubmobjectsOneByOne(ShowIncreasingSubsets): def __init__( self, group: Iterable[Mobject], - int_func: Callable[[np.ndarray], np.ndarray] = np.ceil, - **kwargs, + int_func: Callable[ + [npt.NDArray[ManimFloat]], npt.NDArray[ManimFloat] + ] = np.ceil, + **kwargs: Any, ) -> None: new_group = Group(*group) super().__init__(new_group, int_func=int_func, **kwargs) @@ -652,9 +660,9 @@ class AddTextWordByWord(Succession): def __init__( self, text_mobject: Text, - run_time: float = None, + run_time: float | None = None, time_per_char: float = 0.06, - **kwargs, + **kwargs: Any, ) -> None: self.time_per_char = time_per_char tpc = self.time_per_char diff --git a/manim/animation/fading.py b/manim/animation/fading.py index 33f38a5027..c0e3655a8c 100644 --- a/manim/animation/fading.py +++ b/manim/animation/fading.py @@ -19,10 +19,10 @@ def construct(self): "FadeOut", "FadeIn", ] - -import numpy as np +from typing import Any from manim.mobject.opengl.opengl_mobject import OpenGLMobject +from manim.typing import Vector3D from ..animation.transform import Transform from ..constants import ORIGIN @@ -51,10 +51,10 @@ class _Fade(Transform): def __init__( self, *mobjects: Mobject, - shift: np.ndarray | None = None, - target_position: np.ndarray | Mobject | None = None, + shift: Vector3D | None = None, + target_position: Vector3D | Mobject | None = None, scale: float = 1, - **kwargs, + **kwargs: Any, ) -> None: if not mobjects: raise ValueError("At least one mobject must be passed.") @@ -76,7 +76,7 @@ def __init__( self.scale_factor = scale super().__init__(mobject, **kwargs) - def _create_faded_mobject(self, fadeIn: bool) -> Mobject: + def _create_faded_mobject(self, fade_in: bool) -> Mobject: """Create a faded, shifted and scaled copy of the mobject. Parameters @@ -91,7 +91,7 @@ def _create_faded_mobject(self, fadeIn: bool) -> Mobject: """ faded_mobject = self.mobject.copy() faded_mobject.fade(1) - direction_modifier = -1 if fadeIn and not self.point_target else 1 + direction_modifier = -1 if fade_in and not self.point_target else 1 faded_mobject.shift(self.shift_vector * direction_modifier) faded_mobject.scale(self.scale_factor) return faded_mobject @@ -135,14 +135,14 @@ def construct(self): """ - def __init__(self, *mobjects: Mobject, **kwargs) -> None: + def __init__(self, *mobjects: Mobject, **kwargs: Any) -> None: super().__init__(*mobjects, introducer=True, **kwargs) - def create_target(self): + def create_target(self) -> Mobject: return self.mobject def create_starting_mobject(self): - return self._create_faded_mobject(fadeIn=True) + return self._create_faded_mobject(fade_in=True) class FadeOut(_Fade): @@ -183,12 +183,12 @@ def construct(self): """ - def __init__(self, *mobjects: Mobject, **kwargs) -> None: + def __init__(self, *mobjects: Mobject, **kwargs: Any) -> None: super().__init__(*mobjects, remover=True, **kwargs) - def create_target(self): - return self._create_faded_mobject(fadeIn=False) + def create_target(self) -> Mobject: + return self._create_faded_mobject(fade_in=False) - def clean_up_from_scene(self, scene: Scene = None) -> None: + def clean_up_from_scene(self, scene: Scene | None = None) -> None: super().clean_up_from_scene(scene) self.interpolate(0) diff --git a/manim/animation/growing.py b/manim/animation/growing.py index d9f526c136..d8d12aef8c 100644 --- a/manim/animation/growing.py +++ b/manim/animation/growing.py @@ -31,15 +31,16 @@ def construct(self): "SpinInFromNothing", ] -import typing +from typing import TYPE_CHECKING, Any -import numpy as np +from manim.typing import Point3D, Vector3D +from manim.utils.color import ParsableManimColor from ..animation.transform import Transform from ..constants import PI from ..utils.paths import spiral_path -if typing.TYPE_CHECKING: +if TYPE_CHECKING: from manim.mobject.geometry.line import Arrow from ..mobject.mobject import Mobject @@ -76,7 +77,11 @@ def construct(self): """ def __init__( - self, mobject: Mobject, point: np.ndarray, point_color: str = None, **kwargs + self, + mobject: Mobject, + point: Point3D, + point_color: ParsableManimColor | None = None, + **kwargs: Any, ) -> None: self.point = point self.point_color = point_color @@ -118,7 +123,12 @@ def construct(self): """ - def __init__(self, mobject: Mobject, point_color: str = None, **kwargs) -> None: + def __init__( + self, + mobject: Mobject, + point_color: ParsableManimColor | None = None, + **kwargs: Any, + ) -> None: point = mobject.get_center() super().__init__(mobject, point, point_color=point_color, **kwargs) @@ -153,7 +163,11 @@ def construct(self): """ def __init__( - self, mobject: Mobject, edge: np.ndarray, point_color: str = None, **kwargs + self, + mobject: Mobject, + edge: Vector3D, + point_color: ParsableManimColor | None = None, + **kwargs: Any, ) -> None: point = mobject.get_critical_point(edge) super().__init__(mobject, point, point_color=point_color, **kwargs) @@ -183,7 +197,9 @@ def construct(self): """ - def __init__(self, arrow: Arrow, point_color: str = None, **kwargs) -> None: + def __init__( + self, arrow: Arrow, point_color: ParsableManimColor | None = None, **kwargs: Any + ) -> None: point = arrow.get_start() super().__init__(arrow, point, point_color=point_color, **kwargs) @@ -224,7 +240,11 @@ def construct(self): """ def __init__( - self, mobject: Mobject, angle: float = PI / 2, point_color: str = None, **kwargs + self, + mobject: Mobject, + angle: float = PI / 2, + point_color: ParsableManimColor | None = None, + **kwargs: Any, ) -> None: self.angle = angle super().__init__( diff --git a/manim/animation/indication.py b/manim/animation/indication.py index db2640d5a5..e69e7b236b 100644 --- a/manim/animation/indication.py +++ b/manim/animation/indication.py @@ -24,6 +24,7 @@ def construct(self): )) """ +from __future__ import annotations __all__ = [ "FocusOn", @@ -37,7 +38,7 @@ def construct(self): "Wiggle", ] -from typing import Callable, Iterable, Optional, Tuple, Type, Union +from typing import Any import numpy as np @@ -46,6 +47,7 @@ def construct(self): from manim.mobject.geometry.polygram import Rectangle from manim.mobject.geometry.shape_matchers import SurroundingRectangle from manim.scene.scene import Scene +from manim.typing import Point3D, RateFunc, Vector3D from .. import config from ..animation.animation import Animation @@ -94,11 +96,11 @@ def construct(self): def __init__( self, - focus_point: Union[np.ndarray, Mobject], + focus_point: Point3D | Mobject, opacity: float = 0.2, - color: str = GREY, + color: ParsableManimColor = GREY, run_time: float = 2, - **kwargs + **kwargs: Any, ) -> None: self.focus_point = focus_point self.color = color @@ -148,11 +150,11 @@ def construct(self): def __init__( self, - mobject: Mobject, + mobject: Mobject | None, scale_factor: float = 1.2, color: str = YELLOW, - rate_func: Callable[[float, Optional[float]], np.ndarray] = there_and_back, - **kwargs + rate_func: RateFunc = there_and_back, + **kwargs: Any, ) -> None: self.color = color self.scale_factor = scale_factor @@ -218,15 +220,15 @@ def construct(self): def __init__( self, - point: Union[np.ndarray, Mobject], + point: Point3D | Mobject, line_length: float = 0.2, num_lines: int = 12, flash_radius: float = 0.1, line_stroke_width: int = 3, - color: str = YELLOW, + color: ParsableManimColor = YELLOW, time_width: float = 1, run_time: float = 1.0, - **kwargs + **kwargs: Any, ) -> None: if isinstance(point, Mobject): self.point = point.get_center() @@ -256,7 +258,7 @@ def create_lines(self) -> VGroup: lines.set_stroke(width=self.line_stroke_width) return lines - def create_line_anims(self) -> Iterable["ShowPassingFlash"]: + def create_line_anims(self) -> list[ShowPassingFlash]: return [ ShowPassingFlash( line, @@ -302,11 +304,13 @@ def construct(self): """ - def __init__(self, mobject: "VMobject", time_width: float = 0.1, **kwargs) -> None: + def __init__( + self, mobject: VMobject, time_width: float = 0.1, **kwargs: Any + ) -> None: self.time_width = time_width super().__init__(mobject, remover=True, introducer=True, **kwargs) - def _get_bounds(self, alpha: float) -> Tuple[float]: + def _get_bounds(self, alpha: float) -> tuple[float, float]: tw = self.time_width upper = interpolate(0, 1 + tw, alpha) lower = upper - tw @@ -321,7 +325,14 @@ def clean_up_from_scene(self, scene: Scene) -> None: class ShowPassingFlashWithThinningStrokeWidth(AnimationGroup): - def __init__(self, vmobject, n_segments=10, time_width=0.1, remover=True, **kwargs): + def __init__( + self, + vmobject: VMobject, + n_segments: int = 10, + time_width: float = 0.1, + remover: bool = True, + **kwargs: Any, + ) -> None: self.n_segments = n_segments self.time_width = time_width self.remover = remover @@ -348,7 +359,7 @@ def __init__(self, vmobject, n_segments=10, time_width=0.1, remover=True, **kwar message="Use Create then FadeOut to achieve this effect.", ) class ShowCreationThenFadeOut(Succession): - def __init__(self, mobject: Mobject, remover: bool = True, **kwargs) -> None: + def __init__(self, mobject: Mobject, remover: bool = True, **kwargs: Any) -> None: super().__init__(Create(mobject), FadeOut(mobject), remover=remover, **kwargs) @@ -398,19 +409,19 @@ def construct(self): def __init__( self, mobject: Mobject, - direction: np.ndarray = UP, + direction: Vector3D = UP, amplitude: float = 0.2, - wave_func: Callable[[float], float] = smooth, + wave_func: RateFunc = smooth, time_width: float = 1, ripples: int = 1, run_time: float = 2, - **kwargs + **kwargs: Any, ) -> None: x_min = mobject.get_left()[0] x_max = mobject.get_right()[0] vect = amplitude * normalize(direction) - def wave(t): + def wave(t: float) -> float: # Creates a wave with n ripples from a simple rate_func # This wave is build up as follows: # The time is split into 2*ripples phases. In every phase the amplitude @@ -470,7 +481,7 @@ def homotopy( y: float, z: float, t: float, - ) -> Tuple[float, float, float]: + ) -> tuple[float, float, float]: upper = interpolate(0, 1 + time_width, t) lower = upper - time_width relative_x = inverse_interpolate(x_min, x_max, x) @@ -520,10 +531,10 @@ def __init__( scale_value: float = 1.1, rotation_angle: float = 0.01 * TAU, n_wiggles: int = 6, - scale_about_point: Optional[np.ndarray] = None, - rotate_about_point: Optional[np.ndarray] = None, + scale_about_point: Point3D | None = None, + rotate_about_point: Point3D | None = None, run_time: float = 2, - **kwargs + **kwargs: Any, ) -> None: self.scale_value = scale_value self.rotation_angle = rotation_angle @@ -532,12 +543,12 @@ def __init__( self.rotate_about_point = rotate_about_point super().__init__(mobject, run_time=run_time, **kwargs) - def get_scale_about_point(self) -> np.ndarray: + def get_scale_about_point(self) -> Point3D: if self.scale_about_point is None: return self.mobject.get_center() return self.scale_about_point - def get_rotate_about_point(self) -> np.ndarray: + def get_rotate_about_point(self) -> Point3D: if self.rotate_about_point is None: return self.mobject.get_center() return self.rotate_about_point @@ -604,16 +615,16 @@ def construct(self): def __init__( self, mobject: Mobject, - shape: Type = Rectangle, - fade_in=False, - fade_out=False, - time_width=0.3, + shape: type[Rectangle | Circle] = Rectangle, + fade_in: bool = False, + fade_out: bool = False, + time_width: float = 0.3, buff: float = SMALL_BUFF, color: ParsableManimColor = YELLOW, - run_time=1, - stroke_width=DEFAULT_STROKE_WIDTH, - **kwargs - ): + run_time: int = 1, + stroke_width: int = DEFAULT_STROKE_WIDTH, + **kwargs: Any, + ) -> None: if shape is Rectangle: frame = SurroundingRectangle( mobject, diff --git a/manim/animation/movement.py b/manim/animation/movement.py index 4533eeeb70..526908afa2 100644 --- a/manim/animation/movement.py +++ b/manim/animation/movement.py @@ -12,13 +12,14 @@ from typing import TYPE_CHECKING, Any, Callable -import numpy as np +from manim.typing import Point3D, RateFunc from ..animation.animation import Animation from ..utils.rate_functions import linear if TYPE_CHECKING: - from ..mobject.mobject import Mobject, VMobject + from ..mobject.mobject import Mobject + from ..mobject.types.vectorized_mobject import VMobject class Homotopy(Animation): @@ -52,12 +53,10 @@ def __init__( mobject: Mobject, run_time: float = 3, apply_function_kwargs: dict[str, Any] | None = None, - **kwargs, + **kwargs: Any, ) -> None: self.homotopy = homotopy - self.apply_function_kwargs = ( - apply_function_kwargs if apply_function_kwargs is not None else {} - ) + self.apply_function_kwargs = apply_function_kwargs or {} super().__init__(mobject, run_time=run_time, **kwargs) def function_at_time_t(self, t: float) -> tuple[float, float, float]: @@ -88,7 +87,10 @@ def interpolate_submobject( class ComplexHomotopy(Homotopy): def __init__( - self, complex_homotopy: Callable[[complex], float], mobject: Mobject, **kwargs + self, + complex_homotopy: Callable[[complex, float], float], + mobject: Mobject, + **kwargs, ) -> None: """ Complex Homotopy a function Cx[0, 1] to C @@ -109,12 +111,12 @@ def homotopy( class PhaseFlow(Animation): def __init__( self, - function: Callable[[np.ndarray], np.ndarray], + function: Callable[[Point3D], Point3D], mobject: Mobject, virtual_time: float = 1, suspend_mobject_updating: bool = False, - rate_func: Callable[[float], float] = linear, - **kwargs, + rate_func: RateFunc = linear, + **kwargs: Any, ) -> None: self.virtual_time = virtual_time self.function = function @@ -153,8 +155,8 @@ def __init__( self, mobject: Mobject, path: VMobject, - suspend_mobject_updating: bool | None = False, - **kwargs, + suspend_mobject_updating: bool = False, + **kwargs: Any, ) -> None: self.path = path super().__init__( diff --git a/manim/animation/numbers.py b/manim/animation/numbers.py index 86bfe7154b..b90708433c 100644 --- a/manim/animation/numbers.py +++ b/manim/animation/numbers.py @@ -5,7 +5,7 @@ __all__ = ["ChangingDecimal", "ChangeDecimalToValue"] -import typing +from typing import Any, Callable from manim.mobject.text.numbers import DecimalNumber @@ -17,9 +17,9 @@ class ChangingDecimal(Animation): def __init__( self, decimal_mob: DecimalNumber, - number_update_func: typing.Callable[[float], float], - suspend_mobject_updating: bool | None = False, - **kwargs, + number_update_func: Callable[[float], float], + suspend_mobject_updating: bool = False, + **kwargs: Any, ) -> None: self.check_validity_of_input(decimal_mob) self.number_update_func = number_update_func @@ -37,7 +37,7 @@ def interpolate_mobject(self, alpha: float) -> None: class ChangeDecimalToValue(ChangingDecimal): def __init__( - self, decimal_mob: DecimalNumber, target_number: int, **kwargs + self, decimal_mob: DecimalNumber, target_number: int, **kwargs: Any ) -> None: start_number = decimal_mob.number super().__init__( diff --git a/manim/animation/rotation.py b/manim/animation/rotation.py index faff411c21..f3e8c4facb 100644 --- a/manim/animation/rotation.py +++ b/manim/animation/rotation.py @@ -4,9 +4,9 @@ __all__ = ["Rotating", "Rotate"] -from typing import TYPE_CHECKING, Callable, Sequence +from typing import TYPE_CHECKING, Any -import numpy as np +from manim.typing import Point3D, RateFunc, Vector3D from ..animation.animation import Animation from ..animation.transform import Transform @@ -21,13 +21,13 @@ class Rotating(Animation): def __init__( self, mobject: Mobject, - axis: np.ndarray = OUT, - radians: np.ndarray = TAU, - about_point: np.ndarray | None = None, - about_edge: np.ndarray | None = None, + axis: Vector3D = OUT, + radians: float = TAU, + about_point: Point3D | None = None, + about_edge: Point3D | None = None, run_time: float = 5, - rate_func: Callable[[float], float] = linear, - **kwargs, + rate_func: RateFunc = linear, + **kwargs: Any, ) -> None: self.axis = axis self.radians = radians @@ -85,10 +85,10 @@ def __init__( self, mobject: Mobject, angle: float = PI, - axis: np.ndarray = OUT, - about_point: Sequence[float] | None = None, - about_edge: Sequence[float] | None = None, - **kwargs, + axis: Vector3D = OUT, + about_point: Point3D | None = None, + about_edge: Point3D | None = None, + **kwargs: Any, ) -> None: if "path_arc" not in kwargs: kwargs["path_arc"] = angle diff --git a/manim/animation/specialized.py b/manim/animation/specialized.py index adc44ea1f1..07ca4b5d23 100644 --- a/manim/animation/specialized.py +++ b/manim/animation/specialized.py @@ -2,13 +2,17 @@ __all__ = ["Broadcast"] -from typing import Any, Sequence +from typing import TYPE_CHECKING, Any from manim.animation.transform import Restore +from manim.typing import Point3D from ..constants import * from .composition import LaggedStart +if TYPE_CHECKING: + from ..mobject.mobject import Mobject + class Broadcast(LaggedStart): """Broadcast a mobject starting from an ``initial_width``, up to the actual size of the mobject. @@ -49,8 +53,8 @@ def construct(self): def __init__( self, - mobject, - focal_point: Sequence[float] = ORIGIN, + mobject: Mobject, + focal_point: Point3D = ORIGIN, n_mobs: int = 5, initial_opacity: float = 1, final_opacity: float = 0, @@ -59,7 +63,7 @@ def __init__( lag_ratio: float = 0.2, run_time: float = 3, **kwargs: Any, - ): + ) -> None: self.focal_point = focal_point self.n_mobs = n_mobs self.initial_opacity = initial_opacity diff --git a/manim/animation/speedmodifier.py b/manim/animation/speedmodifier.py index 9df1c9f018..2f0dc6f881 100644 --- a/manim/animation/speedmodifier.py +++ b/manim/animation/speedmodifier.py @@ -4,10 +4,12 @@ import inspect import types -from typing import Callable +from typing import Any, Callable from numpy import piecewise +from manim.typing import RateFunc + from ..animation.animation import Animation, Wait, prepare_animation from ..animation.composition import AnimationGroup from ..mobject.mobject import Mobject, Updater, _AnimationBuilder @@ -94,11 +96,11 @@ def __init__( self, anim: Animation | _AnimationBuilder, speedinfo: dict[float, float], - rate_func: Callable[[float], float] | None = None, + rate_func: RateFunc | None = None, affects_speed_updaters: bool = True, - **kwargs, + **kwargs: Any, ) -> None: - if issubclass(type(anim), AnimationGroup): + if isinstance(anim, AnimationGroup): self.anim = type(anim)( *map(self.setup, anim.animations), group=anim.group, @@ -121,7 +123,9 @@ def __init__( # A function where, f(0) = 0, f'(0) = initial speed, f'( f-1(1) ) = final speed # Following function obtained when conditions applied to vertical parabola - self.speed_modifier = lambda x, init_speed, final_speed: ( + self.speed_modifier: Callable[ + [float, float, float], float + ] = lambda x, init_speed, final_speed: ( (final_speed**2 - init_speed**2) * x**2 / 4 + init_speed * x ) @@ -149,12 +153,12 @@ def __init__( dur = node - prevnode def condition( - t, - curr_time=curr_time, - init_speed=init_speed, - final_speed=final_speed, - dur=dur, - ): + t: float, + curr_time: float = curr_time, + init_speed: float = init_speed, + final_speed: float = final_speed, + dur: float = dur, + ) -> bool: lower_bound = curr_time / scaled_total_time upper_bound = ( curr_time + self.f_inv_1(init_speed, final_speed) * dur @@ -164,13 +168,13 @@ def condition( self.conditions.append(condition) def function( - t, - curr_time=curr_time, - init_speed=init_speed, - final_speed=final_speed, - dur=dur, - prevnode=prevnode, - ): + t: float, + curr_time: float = curr_time, + init_speed: float = init_speed, + final_speed: float = final_speed, + dur: float = dur, + prevnode: float = prevnode, + ) -> float: return ( self.speed_modifier( (scaled_total_time * t - curr_time) / dur, @@ -187,7 +191,7 @@ def function( prevnode = node init_speed = final_speed - def func(t): + def func(t: float) -> float: if t == 1: ChangeSpeed.is_changing_dt = False new_t = piecewise( @@ -209,7 +213,7 @@ def func(t): **kwargs, ) - def setup(self, anim): + def setup(self, anim: Animation) -> Animation: if type(anim) is Wait: anim.interpolate = types.MethodType( lambda self, alpha: self.rate_func(alpha), anim @@ -235,7 +239,7 @@ def add_updater( update_function: Updater, index: int | None = None, call_updater: bool = False, - ): + ) -> None: """This static method can be used to apply speed change to updaters. This updater will follow speed and rate function of any :class:`.ChangeSpeed` diff --git a/manim/animation/transform.py b/manim/animation/transform.py index 7607199d99..1dfba5e00c 100644 --- a/manim/animation/transform.py +++ b/manim/animation/transform.py @@ -28,11 +28,14 @@ import inspect import types -from typing import TYPE_CHECKING, Any, Callable, Iterable, Sequence +from typing import TYPE_CHECKING, Any, Iterable import numpy as np +from typing_extensions import Self +from manim.constants import PI from manim.mobject.opengl.opengl_mobject import OpenGLGroup, OpenGLMobject +from manim.typing import MatrixMN, PathFuncType, Point3D, RateFunc, Vector3D from .. import config from ..animation.animation import Animation @@ -44,6 +47,7 @@ RendererType, ) from ..mobject.mobject import Group, Mobject +from ..utils.color import ParsableManimColor from ..utils.paths import path_along_arc, path_along_circles from ..utils.rate_functions import smooth, squish_rate_func @@ -128,20 +132,20 @@ def __init__( self, mobject: Mobject | None, target_mobject: Mobject | None = None, - path_func: Callable | None = None, + path_func: PathFuncType | None = None, path_arc: float = 0, - path_arc_axis: np.ndarray = OUT, - path_arc_centers: np.ndarray = None, + path_arc_axis: Vector3D = OUT, + path_arc_centers: Point3D | None = None, replace_mobject_with_target_in_scene: bool = False, - **kwargs, + **kwargs: Any, ) -> None: - self.path_arc_axis: np.ndarray = path_arc_axis - self.path_arc_centers: np.ndarray = path_arc_centers - self.path_arc: float = path_arc + self.path_arc_axis = path_arc_axis + self.path_arc_centers = path_arc_centers + self.path_arc = path_arc # path_func is a property a few lines below so it doesn't need to be set in any case if path_func is not None: - self.path_func: Callable = path_func + self.path_func = path_func elif self.path_arc_centers is not None: self.path_func = path_along_circles( path_arc, @@ -149,10 +153,8 @@ def __init__( self.path_arc_axis, ) - self.replace_mobject_with_target_in_scene: bool = ( - replace_mobject_with_target_in_scene - ) - self.target_mobject: Mobject = ( + self.replace_mobject_with_target_in_scene = replace_mobject_with_target_in_scene + self.target_mobject = ( target_mobject if target_mobject is not None else Mobject() ) super().__init__(mobject, **kwargs) @@ -170,22 +172,11 @@ def path_arc(self, path_arc: float) -> None: ) @property - def path_func( - self, - ) -> Callable[ - [Iterable[np.ndarray], Iterable[np.ndarray], float], - Iterable[np.ndarray], - ]: + def path_func(self) -> PathFuncType: return self._path_func @path_func.setter - def path_func( - self, - path_func: Callable[ - [Iterable[np.ndarray], Iterable[np.ndarray], float], - Iterable[np.ndarray], - ], - ) -> None: + def path_func(self, path_func: PathFuncType) -> None: if path_func is not None: self._path_func = path_func @@ -213,7 +204,7 @@ def clean_up_from_scene(self, scene: Scene) -> None: if self.replace_mobject_with_target_in_scene: scene.replace(self.mobject, self.target_mobject) - def get_all_mobjects(self) -> Sequence[Mobject]: + def get_all_mobjects(self) -> list[Mobject]: return [ self.mobject, self.starting_mobject, @@ -222,7 +213,7 @@ def get_all_mobjects(self) -> Sequence[Mobject]: ] def get_all_families_zipped(self) -> Iterable[tuple]: # more precise typing? - mobs = [ + mobs: list[Mobject] = [ self.mobject, self.starting_mobject, self.target_copy, @@ -237,7 +228,7 @@ def interpolate_submobject( starting_submobject: Mobject, target_copy: Mobject, alpha: float, - ) -> Transform: + ) -> Self: submobject.interpolate(starting_submobject, target_copy, alpha, self.path_func) return self @@ -290,7 +281,9 @@ def construct(self): """ - def __init__(self, mobject: Mobject, target_mobject: Mobject, **kwargs) -> None: + def __init__( + self, mobject: Mobject, target_mobject: Mobject, **kwargs: Any + ) -> None: super().__init__( mobject, target_mobject, replace_mobject_with_target_in_scene=True, **kwargs ) @@ -301,7 +294,9 @@ class TransformFromCopy(Transform): Performs a reversed Transform """ - def __init__(self, mobject: Mobject, target_mobject: Mobject, **kwargs) -> None: + def __init__( + self, mobject: Mobject, target_mobject: Mobject, **kwargs: Any + ) -> None: super().__init__(target_mobject, mobject, **kwargs) def interpolate(self, alpha: float) -> None: @@ -342,8 +337,8 @@ def __init__( self, mobject: Mobject, target_mobject: Mobject, - path_arc: float = -np.pi, - **kwargs, + path_arc: float = -PI, + **kwargs: Any, ) -> None: super().__init__(mobject, target_mobject, path_arc=path_arc, **kwargs) @@ -391,8 +386,8 @@ def __init__( self, mobject: Mobject, target_mobject: Mobject, - path_arc: float = np.pi, - **kwargs, + path_arc: float = PI, + **kwargs: Any, ) -> None: super().__init__(mobject, target_mobject, path_arc=path_arc, **kwargs) @@ -423,7 +418,7 @@ def construct(self): """ - def __init__(self, mobject: Mobject, **kwargs) -> None: + def __init__(self, mobject: Mobject, **kwargs: Any) -> None: self.check_validity_of_input(mobject) super().__init__(mobject, mobject.target, **kwargs) @@ -464,15 +459,13 @@ class ApplyMethod(Transform): """ - def __init__( - self, method: Callable, *args, **kwargs - ) -> None: # method typing (we want to specify Mobject method)? for args? + def __init__(self, method: types.MethodType, *args: Any, **kwargs: Any) -> None: self.check_validity_of_input(method) self.method = method self.method_args = args super().__init__(method.__self__, **kwargs) - def check_validity_of_input(self, method: Callable) -> None: + def check_validity_of_input(self, method: types.MethodType) -> None: if not inspect.ismethod(method): raise ValueError( "Whoops, looks like you accidentally invoked " @@ -520,13 +513,15 @@ def __init__( function: types.MethodType, mobject: Mobject, run_time: float = DEFAULT_POINTWISE_FUNCTION_RUN_TIME, - **kwargs, + **kwargs: Any, ) -> None: super().__init__(mobject.apply_function, function, run_time=run_time, **kwargs) class ApplyPointwiseFunctionToCenter(ApplyPointwiseFunction): - def __init__(self, function: types.MethodType, mobject: Mobject, **kwargs) -> None: + def __init__( + self, function: types.MethodType, mobject: Mobject, **kwargs: Any + ) -> None: self.function = function super().__init__(mobject.move_to, **kwargs) @@ -549,7 +544,9 @@ def construct(self): """ - def __init__(self, mobject: Mobject, color: str, **kwargs) -> None: + def __init__( + self, mobject: Mobject, color: ParsableManimColor, **kwargs: Any + ) -> None: super().__init__(mobject.set_color, color, **kwargs) @@ -567,7 +564,7 @@ def construct(self): """ - def __init__(self, mobject: Mobject, scale_factor: float, **kwargs) -> None: + def __init__(self, mobject: Mobject, scale_factor: float, **kwargs: Any) -> None: super().__init__(mobject.scale, scale_factor, **kwargs) @@ -585,7 +582,7 @@ def construct(self): """ - def __init__(self, mobject: Mobject, **kwargs) -> None: + def __init__(self, mobject: Mobject, **kwargs: Any) -> None: super().__init__(mobject, 0, **kwargs) @@ -611,12 +608,14 @@ def construct(self): """ - def __init__(self, mobject: Mobject, **kwargs) -> None: + def __init__(self, mobject: Mobject, **kwargs: Any) -> None: super().__init__(mobject.restore, **kwargs) class ApplyFunction(Transform): - def __init__(self, function: types.MethodType, mobject: Mobject, **kwargs) -> None: + def __init__( + self, function: types.MethodType, mobject: Mobject, **kwargs: Any + ) -> None: self.function = function super().__init__(mobject, **kwargs) @@ -657,10 +656,10 @@ def construct(self): def __init__( self, - matrix: np.ndarray, + matrix: MatrixMN, mobject: Mobject, - about_point: np.ndarray = ORIGIN, - **kwargs, + about_point: Point3D = ORIGIN, + **kwargs: Any, ) -> None: matrix = self.initialize_matrix(matrix) @@ -669,7 +668,7 @@ def func(p): super().__init__(func, mobject, **kwargs) - def initialize_matrix(self, matrix: np.ndarray) -> np.ndarray: + def initialize_matrix(self, matrix: MatrixMN) -> MatrixMN: matrix = np.array(matrix) if matrix.shape == (2, 2): new_matrix = np.identity(3) @@ -681,7 +680,9 @@ def initialize_matrix(self, matrix: np.ndarray) -> np.ndarray: class ApplyComplexFunction(ApplyMethod): - def __init__(self, function: types.MethodType, mobject: Mobject, **kwargs) -> None: + def __init__( + self, function: types.MethodType, mobject: Mobject, **kwargs: Any + ) -> None: self.function = function method = mobject.apply_complex_function super().__init__(method, function, **kwargs) @@ -692,9 +693,6 @@ def _init_path_func(self) -> None: super()._init_path_func() -### - - class CyclicReplace(Transform): """An animation moving mobjects cyclically. @@ -728,7 +726,7 @@ def construct(self): """ def __init__( - self, *mobjects: Mobject, path_arc: float = 90 * DEGREES, **kwargs + self, *mobjects: Mobject, path_arc: float = 90 * DEGREES, **kwargs: Any ) -> None: self.group = Group(*mobjects) super().__init__(self.group, path_arc=path_arc, **kwargs) @@ -751,8 +749,8 @@ def __init__( self, start_anim: Animation, end_anim: Animation, - rate_func: Callable = squish_rate_func(smooth), - **kwargs, + rate_func: RateFunc = squish_rate_func(smooth), + **kwargs: Any, ) -> None: self.start_anim = start_anim self.end_anim = end_anim @@ -830,7 +828,14 @@ def construct(self): """ - def __init__(self, mobject, target_mobject, stretch=True, dim_to_match=1, **kwargs): + def __init__( + self, + mobject: Mobject, + target_mobject: Mobject, + stretch: bool = True, + dim_to_match: int = 1, + **kwargs: Any, + ) -> None: self.to_add_on_completion = target_mobject self.stretch = stretch self.dim_to_match = dim_to_match @@ -841,7 +846,7 @@ def __init__(self, mobject, target_mobject, stretch=True, dim_to_match=1, **kwar group = Group(mobject, target_mobject.copy()) super().__init__(group, **kwargs) - def begin(self): + def begin(self) -> None: """Initial setup for the animation. The mobject to which this animation is bound is a group consisting of @@ -858,7 +863,7 @@ def begin(self): for m0, m1 in ((start[1], start[0]), (end[0], end[1])): self.ghost_to(m0, m1) - def ghost_to(self, source, target): + def ghost_to(self, source: Mobject, target: Mobject) -> None: """Replaces the source by the target and sets the opacity to 0. If the provided target has no points, and thus a location of [0, 0, 0] @@ -869,7 +874,7 @@ def ghost_to(self, source, target): source.replace(target, stretch=self.stretch, dim_to_match=self.dim_to_match) source.set_opacity(0) - def get_all_mobjects(self) -> Sequence[Mobject]: + def get_all_mobjects(self) -> list[Mobject]: return [ self.mobject, self.starting_mobject, @@ -879,7 +884,7 @@ def get_all_mobjects(self) -> Sequence[Mobject]: def get_all_families_zipped(self): return Animation.get_all_families_zipped(self) - def clean_up_from_scene(self, scene): + def clean_up_from_scene(self, scene: Scene) -> None: Animation.clean_up_from_scene(self, scene) scene.remove(self.mobject) self.mobject[0].restore() @@ -916,11 +921,11 @@ def construct(self): """ - def begin(self): + def begin(self) -> None: self.mobject[0].align_submobjects(self.mobject[1]) super().begin() - def ghost_to(self, source, target): + def ghost_to(self, source: Mobject, target: Mobject) -> None: """Replaces the source submobjects by the target submobjects and sets the opacity to 0. """ diff --git a/manim/animation/transform_matching_parts.py b/manim/animation/transform_matching_parts.py index dbf5dd294e..2428eb7a1f 100644 --- a/manim/animation/transform_matching_parts.py +++ b/manim/animation/transform_matching_parts.py @@ -4,7 +4,7 @@ __all__ = ["TransformMatchingShapes", "TransformMatchingTex"] -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any import numpy as np @@ -74,8 +74,8 @@ def __init__( transform_mismatches: bool = False, fade_transform_mismatches: bool = False, key_map: dict | None = None, - **kwargs, - ): + **kwargs: Any, + ) -> None: if isinstance(mobject, OpenGLVMobject): group_type = OpenGLVGroup elif isinstance(mobject, OpenGLMobject): @@ -206,8 +206,8 @@ def __init__( transform_mismatches: bool = False, fade_transform_mismatches: bool = False, key_map: dict | None = None, - **kwargs, - ): + **kwargs: Any, + ) -> None: super().__init__( mobject, target_mobject, @@ -269,8 +269,8 @@ def __init__( transform_mismatches: bool = False, fade_transform_mismatches: bool = False, key_map: dict | None = None, - **kwargs, - ): + **kwargs: Any, + ) -> None: super().__init__( mobject, target_mobject, diff --git a/manim/animation/updaters/mobject_update_utils.py b/manim/animation/updaters/mobject_update_utils.py index dee27ff398..5aadc586d1 100644 --- a/manim/animation/updaters/mobject_update_utils.py +++ b/manim/animation/updaters/mobject_update_utils.py @@ -15,26 +15,27 @@ import inspect -from typing import TYPE_CHECKING, Callable +import types +from typing import TYPE_CHECKING, Any, Callable, TypeVar import numpy as np from manim.constants import DEGREES, RIGHT from manim.mobject.mobject import Mobject from manim.opengl import OpenGLMobject -from manim.utils.space_ops import normalize +from manim.typing import Vector3D if TYPE_CHECKING: from manim.animation.animation import Animation -def assert_is_mobject_method(method: Callable) -> None: +def assert_is_mobject_method(method: types.MethodType) -> None: assert inspect.ismethod(method) mobject = method.__self__ assert isinstance(mobject, (Mobject, OpenGLMobject)) -def always(method: Callable, *args, **kwargs) -> Mobject: +def always(method: types.MethodType, *args: Any, **kwargs: Any) -> Mobject: assert_is_mobject_method(method) mobject = method.__self__ func = method.__func__ @@ -60,7 +61,10 @@ def updater(mob): return mobject -def always_redraw(func: Callable[[], Mobject]) -> Mobject: +MobjectT = TypeVar("MobjectT", bound=Mobject) + + +def always_redraw(func: Callable[[], MobjectT]) -> MobjectT: """Redraw the mobject constructed by a function every frame. This function returns a mobject with an attached updater that @@ -106,8 +110,8 @@ def construct(self): def always_shift( - mobject: Mobject, direction: np.ndarray[np.float64] = RIGHT, rate: float = 0.1 -) -> Mobject: + mobject: MobjectT, direction: Vector3D = RIGHT, rate: float = 0.1 +) -> MobjectT: """A mobject which is continuously shifted along some direction at a certain rate. @@ -144,7 +148,9 @@ def construct(self): return mobject -def always_rotate(mobject: Mobject, rate: float = 20 * DEGREES, **kwargs) -> Mobject: +def always_rotate( + mobject: MobjectT, rate: float = 20 * DEGREES, **kwargs: Any +) -> MobjectT: """A mobject which is continuously rotated at a certain rate. Parameters @@ -178,7 +184,7 @@ def construct(self): def turn_animation_into_updater( - animation: Animation, cycle: bool = False, **kwargs + animation: Animation, cycle: bool = False, **kwargs: Any ) -> Mobject: """ Add an updater to the animation's mobject which applies @@ -227,5 +233,5 @@ def update(m: Mobject, dt: float): return mobject -def cycle_animation(animation: Animation, **kwargs) -> Mobject: +def cycle_animation(animation: Animation, **kwargs: Any) -> Mobject: return turn_animation_into_updater(animation, cycle=True, **kwargs) diff --git a/manim/animation/updaters/update.py b/manim/animation/updaters/update.py index ded160cff7..bac0063011 100644 --- a/manim/animation/updaters/update.py +++ b/manim/animation/updaters/update.py @@ -6,11 +6,11 @@ import operator as op -import typing +from typing import TYPE_CHECKING, Any, Callable from manim.animation.animation import Animation -if typing.TYPE_CHECKING: +if TYPE_CHECKING: from manim.mobject.mobject import Mobject @@ -24,9 +24,9 @@ class UpdateFromFunc(Animation): def __init__( self, mobject: Mobject, - update_function: typing.Callable[[Mobject], typing.Any], + update_function: Callable[[Mobject], Any], suspend_mobject_updating: bool = False, - **kwargs, + **kwargs: Any, ) -> None: self.update_function = update_function super().__init__( @@ -43,7 +43,9 @@ def interpolate_mobject(self, alpha: float) -> None: class MaintainPositionRelativeTo(Animation): - def __init__(self, mobject: Mobject, tracked_mobject: Mobject, **kwargs) -> None: + def __init__( + self, mobject: Mobject, tracked_mobject: Mobject, **kwargs: Any + ) -> None: self.tracked_mobject = tracked_mobject self.diff = op.sub( mobject.get_center(), diff --git a/manim/typing.py b/manim/typing.py index 8111ca7398..5523b6c212 100644 --- a/manim/typing.py +++ b/manim/typing.py @@ -576,6 +576,8 @@ MappingFunction: TypeAlias = Callable[[Point3D], Point3D] """A function mapping a `Point3D` to another `Point3D`.""" +RateFunc: TypeAlias = Callable[[float], float] +"""An animation rate function.""" """ [CATEGORY] diff --git a/poetry.lock b/poetry.lock index b5fd7d925a..be5ac06a7c 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1981,14 +1981,6 @@ files = [ {file = "mapbox_earcut-1.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:9af9369266bf0ca32f4d401152217c46c699392513f22639c6b1be32bde9c1cc"}, {file = "mapbox_earcut-1.0.1-cp311-cp311-win32.whl", hash = "sha256:ff9a13be4364625697b0e0e04ba6a0f77300148b871bba0a85bfa67e972e85c4"}, {file = "mapbox_earcut-1.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:5e736557539c74fa969e866889c2b0149fc12668f35e3ae33667d837ff2880d3"}, - {file = "mapbox_earcut-1.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:4fe92174410e4120022393013705d77cb856ead5bdf6c81bec614a70df4feb5d"}, - {file = "mapbox_earcut-1.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:082f70a865c6164a60af039aa1c377073901cf1f94fd37b1c5610dfbae2a7369"}, - {file = "mapbox_earcut-1.0.1-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:43d268ece49d0c9e22cb4f92cd54c2cc64f71bf1c5e10800c189880d923e1292"}, - {file = "mapbox_earcut-1.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7748f1730fd36dd1fcf0809d8f872d7e1ddaa945f66a6a466ad37ef3c552ae93"}, - {file = "mapbox_earcut-1.0.1-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:5a82d10c8dec2a0bd9a6a6c90aca7044017c8dad79f7e209fd0667826f842325"}, - {file = "mapbox_earcut-1.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:01b292588cd3f6bad7d76ee31c004ed1b557a92bbd9602a72d2be15513b755be"}, - {file = "mapbox_earcut-1.0.1-cp312-cp312-win32.whl", hash = "sha256:fce236ddc3a56ea7260acc94601a832c260e6ac5619374bb2cec2e73e7414ff0"}, - {file = "mapbox_earcut-1.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:1ce86407353b4f09f5778c436518bbbc6f258f46c5736446f25074fe3d3a3bd8"}, {file = "mapbox_earcut-1.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:aa6111a18efacb79c081f3d3cdd7d25d0585bb0e9f28896b207ebe1d56efa40e"}, {file = "mapbox_earcut-1.0.1-cp36-cp36m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f2911829d1e6e5e1282fbe2840fadf578f606580f02ed436346c2d51c92f810b"}, {file = "mapbox_earcut-1.0.1-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:01ff909a7b8405a923abedd701b53633c997cc2b5dc9d5b78462f51c25ec2c33"}, @@ -2363,6 +2355,53 @@ files = [ {file = "multipledispatch-1.0.0.tar.gz", hash = "sha256:5c839915465c68206c3e9c473357908216c28383b425361e5d144594bf85a7e0"}, ] +[[package]] +name = "mypy" +version = "1.8.0" +description = "Optional static typing for Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "mypy-1.8.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:485a8942f671120f76afffff70f259e1cd0f0cfe08f81c05d8816d958d4577d3"}, + {file = "mypy-1.8.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:df9824ac11deaf007443e7ed2a4a26bebff98d2bc43c6da21b2b64185da011c4"}, + {file = "mypy-1.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2afecd6354bbfb6e0160f4e4ad9ba6e4e003b767dd80d85516e71f2e955ab50d"}, + {file = "mypy-1.8.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:8963b83d53ee733a6e4196954502b33567ad07dfd74851f32be18eb932fb1cb9"}, + {file = "mypy-1.8.0-cp310-cp310-win_amd64.whl", hash = "sha256:e46f44b54ebddbeedbd3d5b289a893219065ef805d95094d16a0af6630f5d410"}, + {file = "mypy-1.8.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:855fe27b80375e5c5878492f0729540db47b186509c98dae341254c8f45f42ae"}, + {file = "mypy-1.8.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:4c886c6cce2d070bd7df4ec4a05a13ee20c0aa60cb587e8d1265b6c03cf91da3"}, + {file = "mypy-1.8.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d19c413b3c07cbecf1f991e2221746b0d2a9410b59cb3f4fb9557f0365a1a817"}, + {file = "mypy-1.8.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:9261ed810972061388918c83c3f5cd46079d875026ba97380f3e3978a72f503d"}, + {file = "mypy-1.8.0-cp311-cp311-win_amd64.whl", hash = "sha256:51720c776d148bad2372ca21ca29256ed483aa9a4cdefefcef49006dff2a6835"}, + {file = "mypy-1.8.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:52825b01f5c4c1c4eb0db253ec09c7aa17e1a7304d247c48b6f3599ef40db8bd"}, + {file = "mypy-1.8.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:f5ac9a4eeb1ec0f1ccdc6f326bcdb464de5f80eb07fb38b5ddd7b0de6bc61e55"}, + {file = "mypy-1.8.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:afe3fe972c645b4632c563d3f3eff1cdca2fa058f730df2b93a35e3b0c538218"}, + {file = "mypy-1.8.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:42c6680d256ab35637ef88891c6bd02514ccb7e1122133ac96055ff458f93fc3"}, + {file = "mypy-1.8.0-cp312-cp312-win_amd64.whl", hash = "sha256:720a5ca70e136b675af3af63db533c1c8c9181314d207568bbe79051f122669e"}, + {file = "mypy-1.8.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:028cf9f2cae89e202d7b6593cd98db6759379f17a319b5faf4f9978d7084cdc6"}, + {file = "mypy-1.8.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:4e6d97288757e1ddba10dd9549ac27982e3e74a49d8d0179fc14d4365c7add66"}, + {file = "mypy-1.8.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7f1478736fcebb90f97e40aff11a5f253af890c845ee0c850fe80aa060a267c6"}, + {file = "mypy-1.8.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:42419861b43e6962a649068a61f4a4839205a3ef525b858377a960b9e2de6e0d"}, + {file = "mypy-1.8.0-cp38-cp38-win_amd64.whl", hash = "sha256:2b5b6c721bd4aabaadead3a5e6fa85c11c6c795e0c81a7215776ef8afc66de02"}, + {file = "mypy-1.8.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:5c1538c38584029352878a0466f03a8ee7547d7bd9f641f57a0f3017a7c905b8"}, + {file = "mypy-1.8.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:4ef4be7baf08a203170f29e89d79064463b7fc7a0908b9d0d5114e8009c3a259"}, + {file = "mypy-1.8.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7178def594014aa6c35a8ff411cf37d682f428b3b5617ca79029d8ae72f5402b"}, + {file = "mypy-1.8.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:ab3c84fa13c04aeeeabb2a7f67a25ef5d77ac9d6486ff33ded762ef353aa5592"}, + {file = "mypy-1.8.0-cp39-cp39-win_amd64.whl", hash = "sha256:99b00bc72855812a60d253420d8a2eae839b0afa4938f09f4d2aa9bb4654263a"}, + {file = "mypy-1.8.0-py3-none-any.whl", hash = "sha256:538fd81bb5e430cc1381a443971c0475582ff9f434c16cd46d2c66763ce85d9d"}, + {file = "mypy-1.8.0.tar.gz", hash = "sha256:6ff8b244d7085a0b425b56d327b480c3b29cafbd2eff27316a004f9a7391ae07"}, +] + +[package.dependencies] +mypy-extensions = ">=1.0.0" +tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""} +typing-extensions = ">=4.1.0" + +[package.extras] +dmypy = ["psutil (>=4.0)"] +install-types = ["pip"] +mypyc = ["setuptools (>=50)"] +reports = ["lxml"] + [[package]] name = "mypy-extensions" version = "1.0.0" @@ -3337,7 +3376,6 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, - {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, @@ -4533,4 +4571,4 @@ jupyterlab = ["jupyterlab", "notebook"] [metadata] lock-version = "2.0" python-versions = ">=3.9,<3.13" -content-hash = "f15fa632919381a9b5b2cebc3e89aa307fa8735db2e5cde7a408765b46a3b00f" +content-hash = "d70198678bdc54c339c95c7217c27ba2227cce60b92d2e27ecbdc20ba2110391" diff --git a/pyproject.toml b/pyproject.toml index 1ba36a47b1..46224c7539 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -73,6 +73,7 @@ furo = "^2023.09.10" gitpython = "^3" isort = "^5.12.0" matplotlib = "^3.8.2" +mypy = "^1.8.0" myst-parser = "^2.0.0" pre-commit = "^3.5.0" psutil = {version = "^5.8.0", python = "<3.10"}