Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding the ability to pass lists and generators to .play() #3365

Merged
merged 12 commits into from
Dec 13, 2023
9 changes: 6 additions & 3 deletions manim/animation/composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@

from __future__ import annotations

from typing import TYPE_CHECKING, Callable, Sequence
import types
from typing import TYPE_CHECKING, Callable, Iterable, Sequence

import numpy as np

from manim.mobject.opengl.opengl_mobject import OpenGLGroup
from manim.utils.parameter_parsing import flatten_iterable_parameters

from .._config import config
from ..animation.animation import Animation, prepare_animation
Expand Down Expand Up @@ -54,14 +56,15 @@ class AnimationGroup(Animation):

def __init__(
self,
*animations: Animation,
*animations: Animation | Iterable[Animation] | types.GeneratorType[Animation],
group: Group | VGroup | OpenGLGroup | OpenGLVGroup = None,
run_time: float | None = None,
rate_func: Callable[[float], float] = linear,
lag_ratio: float = 0,
**kwargs,
) -> None:
self.animations = [prepare_animation(anim) for anim in animations]
arg_anim = flatten_iterable_parameters(animations)
self.animations = [prepare_animation(anim) for anim in arg_anim]
self.rate_func = rate_func
self.group = group
if self.group is None:
Expand Down
12 changes: 10 additions & 2 deletions manim/renderer/cairo_renderer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

import typing
from typing import Any

import numpy as np

Expand All @@ -15,6 +14,10 @@
from ..utils.iterables import list_update

if typing.TYPE_CHECKING:
import types
from typing import Any, Iterable

from manim.animation.animation import Animation
from manim.scene.scene import Scene


Expand Down Expand Up @@ -51,7 +54,12 @@ def init_scene(self, scene):
scene.__class__.__name__,
)

def play(self, scene, *args, **kwargs):
def play(
self,
scene: Scene,
*args: Animation | Iterable[Animation] | types.GeneratorType[Animation],
**kwargs,
):
# Reset skip_animations to the original state.
# Needed when rendering only some animations, and skipping others.
self.skip_animations = self._original_skipping_status
Expand Down
26 changes: 21 additions & 5 deletions manim/scene/scene.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from __future__ import annotations

from manim.utils.parameter_parsing import flatten_iterable_parameters

__all__ = ["Scene"]

import copy
Expand All @@ -13,7 +15,6 @@
import time
import types
from queue import Queue
from typing import Callable

import srt

Expand All @@ -25,6 +26,8 @@
dearpygui_imported = True
except ImportError:
dearpygui_imported = False
from typing import TYPE_CHECKING

import numpy as np
from tqdm import tqdm
from watchdog.events import FileSystemEventHandler
Expand All @@ -48,6 +51,9 @@
from ..utils.file_ops import open_media_file
from ..utils.iterables import list_difference_update, list_update

if TYPE_CHECKING:
from typing import Callable, Iterable


class RerunSceneHandler(FileSystemEventHandler):
"""A class to handle rerunning a Scene after the input file is modified."""
Expand Down Expand Up @@ -865,7 +871,11 @@ def get_moving_and_static_mobjects(self, animations):
)
return all_moving_mobject_families, static_mobjects

def compile_animations(self, *args: Animation, **kwargs):
def compile_animations(
self,
*args: Animation | Iterable[Animation] | types.GeneratorType[Animation],
**kwargs,
):
"""
Creates _MethodAnimations from any _AnimationBuilders and updates animation
kwargs with kwargs passed to play().
Expand All @@ -883,7 +893,9 @@ def compile_animations(self, *args: Animation, **kwargs):
Animations to be played.
"""
animations = []
for arg in args:
arg_anims = flatten_iterable_parameters(args)
# Allow passing a generator to self.play instead of comma separated arguments
for arg in arg_anims:
try:
animations.append(prepare_animation(arg))
except TypeError:
Expand Down Expand Up @@ -1027,7 +1039,7 @@ def get_run_time(self, animations: list[Animation]):

def play(
self,
*args,
*args: Animation | Iterable[Animation] | types.GeneratorType[Animation],
subcaption=None,
subcaption_duration=None,
subcaption_offset=0,
Expand Down Expand Up @@ -1157,7 +1169,11 @@ def wait_until(self, stop_condition: Callable[[], bool], max_time: float = 60):
"""
self.wait(max_time, stop_condition=stop_condition)

def compile_animation_data(self, *animations: Animation, **play_kwargs):
def compile_animation_data(
self,
*animations: Animation | Iterable[Animation] | types.GeneratorType[Animation],
**play_kwargs,
):
"""Given a list of animations, compile the corresponding
static and moving mobjects, and gather the animation durations.

Expand Down
31 changes: 31 additions & 0 deletions manim/utils/parameter_parsing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from __future__ import annotations

from types import GeneratorType
from typing import Iterable, TypeVar

T = TypeVar("T")


def flatten_iterable_parameters(
args: Iterable[T | Iterable[T] | GeneratorType],
) -> list[T]:
"""Flattens an iterable of parameters into a list of parameters.

Parameters
----------
args
The iterable of parameters to flatten.
[(generator), [], (), ...]

Returns
-------
:class:`list`
The flattened list of parameters.
"""
flattened_parameters = []
for arg in args:
if isinstance(arg, (Iterable, GeneratorType)):
flattened_parameters.extend(arg)
else:
flattened_parameters.append(arg)
return flattened_parameters
Loading