Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

VGroup can now be initialized with VMobject iterables #3966

Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 53 additions & 6 deletions manim/mobject/types/vectorized_mobject.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from manim.constants import *
from manim.mobject.mobject import Mobject
from manim.mobject.opengl.opengl_compatibility import ConvertToOpenGL
from manim.mobject.opengl.opengl_mobject import OpenGLMobject

Check notice

Code scanning / CodeQL

Cyclic import Note

Import of module
manim.mobject.opengl.opengl_mobject
begins an import cycle.
from manim.mobject.opengl.opengl_vectorized_mobject import OpenGLVMobject
from manim.mobject.three_d.three_d_utils import (
get_3d_vmob_gradient_start_and_end_points,
Expand All @@ -47,6 +48,8 @@
from manim.utils.space_ops import rotate_vector, shoelace_direction

if TYPE_CHECKING:
from types import GeneratorType

import numpy.typing as npt
from typing_extensions import Self

Expand Down Expand Up @@ -2056,7 +2059,11 @@

"""

def __init__(self, *vmobjects, **kwargs):
def __init__(
self,
*vmobjects: VMobject | Iterable[VMobject] | GeneratorType[VMobject],
NikhilaGurusinghe marked this conversation as resolved.
Show resolved Hide resolved
**kwargs,
):
super().__init__(**kwargs)
self.add(*vmobjects)

Expand All @@ -2069,13 +2076,16 @@
f"submobject{'s' if len(self.submobjects) > 0 else ''}"
)

def add(self, *vmobjects: VMobject) -> Self:
"""Checks if all passed elements are an instance of VMobject and then add them to submobjects
def add(
self,
*vmobjects: VMobject | Iterable[VMobject] | GeneratorType[VMobject],
) -> Self:
"""Checks if all passed elements are an instance, or iterables of VMobject and then adds them to submobjects

Parameters
----------
vmobjects
List of VMobject to add
List or iterable of VMobjects to add

Returns
-------
Expand All @@ -2084,7 +2094,7 @@
Raises
------
TypeError
If one element of the list is not an instance of VMobject
If one element of the list, or iterable is not an instance of VMobject

Examples
--------
Expand Down Expand Up @@ -2117,7 +2127,44 @@
(gr-circle_red).animate.shift(RIGHT)
)
"""
return super().add(*vmobjects)

def get_type_error_message(invalid_obj, invalid_i):
return (
f"Only values of type {vmobject_render_type.__name__} can be added "
f"as submobjects of VGroup, but the value "
NikhilaGurusinghe marked this conversation as resolved.
Show resolved Hide resolved
f"{repr(invalid_obj)} (at index {invalid_i}) is of type "
f"{type(invalid_obj).__name__}."
)

vmobject_render_type = (
OpenGLVMobject if config.renderer == RendererType.OPENGL else VMobject
)
valid_vmobjects = []

for vmobject_i, vmobject in enumerate(vmobjects):
if isinstance(vmobject, vmobject_render_type):
valid_vmobjects.append(vmobject)
elif isinstance(vmobject, Iterable) and not isinstance(
vmobject, (Mobject, OpenGLMobject)
):
for subvmobject_i, subvmobject in enumerate(vmobject):
if not isinstance(subvmobject, vmobject_render_type):
raise TypeError(
get_type_error_message(subvmobject, subvmobject_i)
NikhilaGurusinghe marked this conversation as resolved.
Show resolved Hide resolved
)
valid_vmobjects.append(subvmobject)
elif isinstance(vmobject, Iterable) and isinstance(
vmobject, (Mobject, OpenGLMobject)
):
# This is if vmobject is an empty Mobject
NikhilaGurusinghe marked this conversation as resolved.
Show resolved Hide resolved
raise TypeError(
f"{get_type_error_message(vmobject, vmobject_i)} "
"You can try adding this value into a Group instead."
)
else:
raise TypeError(get_type_error_message(vmobject, vmobject_i))

return super().add(*valid_vmobjects)

def __add__(self, vmobject: VMobject) -> Self:
return VGroup(*self.submobjects, vmobject)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,52 @@ def test_vgroup_init():
)


def test_vgroup_init_with_iterable():
"""Test VGroup instantiation with an iterable type."""

def type_generator(type_to_generate, n):
return (type_to_generate() for _ in range(n))

def mixed_type_generator(major_type, minor_type, minor_type_positions, n):
return (
minor_type() if i in minor_type_positions else major_type()
for i in range(n)
)

obj = VGroup(VMobject())
assert len(obj.submobjects) == 1

obj = VGroup(type_generator(VMobject, 38))
assert len(obj.submobjects) == 38

obj = VGroup(VMobject(), [VMobject(), VMobject()], type_generator(VMobject, 38))
assert len(obj.submobjects) == 41

# A VGroup cannot be initialised with an iterable containing a Mobject
with pytest.raises(TypeError) as init_with_mob_iterable:
VGroup(type_generator(Mobject, 5))
assert str(init_with_mob_iterable.value) == (
"Only values of type VMobject can be added as submobjects of VGroup, "
"but the value Mobject (at index 0) is of type Mobject."
)

# A VGroup cannot be initialised with an iterable containing a Mobject in any position
with pytest.raises(TypeError) as init_with_mobs_and_vmobs_iterable:
VGroup(mixed_type_generator(VMobject, Mobject, [3, 5], 7))
assert str(init_with_mobs_and_vmobs_iterable.value) == (
"Only values of type VMobject can be added as submobjects of VGroup, "
"but the value Mobject (at index 3) is of type Mobject."
)

# A VGroup cannot be initialised with an iterable containing non VMobject's in any position
with pytest.raises(TypeError) as init_with_float_and_vmobs_iterable:
VGroup(mixed_type_generator(VMobject, float, [6, 7], 9))
assert str(init_with_float_and_vmobs_iterable.value) == (
"Only values of type VMobject can be added as submobjects of VGroup, "
"but the value 0.0 (at index 6) is of type float."
)


def test_vgroup_add():
"""Test the VGroup add method."""
obj = VGroup()
Expand Down
60 changes: 59 additions & 1 deletion tests/opengl/test_opengl_vectorized_mobject.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np
import pytest

from manim import Circle, Line, Square, VDict, VGroup
from manim import Circle, Line, Square, VDict, VGroup, VMobject
from manim.mobject.opengl.opengl_mobject import OpenGLMobject
from manim.mobject.opengl.opengl_vectorized_mobject import OpenGLVMobject

Expand Down Expand Up @@ -110,6 +110,64 @@ def test_vgroup_init(using_opengl_renderer):
)


def test_vgroup_init_with_iterable(using_opengl_renderer):
"""Test VGroup instantiation with an iterable type."""

def type_generator(type_to_generate, n):
return (type_to_generate() for _ in range(n))

def mixed_type_generator(major_type, minor_type, minor_type_positions, n):
return (
minor_type() if i in minor_type_positions else major_type()
for i in range(n)
)

obj = VGroup(OpenGLVMobject())
assert len(obj.submobjects) == 1

obj = VGroup(type_generator(OpenGLVMobject, 38))
assert len(obj.submobjects) == 38

obj = VGroup(
OpenGLVMobject(),
[OpenGLVMobject(), OpenGLVMobject()],
type_generator(OpenGLVMobject, 38),
)
assert len(obj.submobjects) == 41

# A VGroup cannot be initialised with an iterable containing a OpenGLMobject
with pytest.raises(TypeError) as init_with_mob_iterable:
VGroup(type_generator(OpenGLMobject, 5))
assert str(init_with_mob_iterable.value) == (
"Only values of type OpenGLVMobject can be added as submobjects of VGroup, "
"but the value OpenGLMobject (at index 0) is of type OpenGLMobject."
)

# A VGroup cannot be initialised with an iterable containing a OpenGLMobject in any position
with pytest.raises(TypeError) as init_with_mobs_and_vmobs_iterable:
VGroup(mixed_type_generator(OpenGLVMobject, OpenGLMobject, [3, 5], 7))
assert str(init_with_mobs_and_vmobs_iterable.value) == (
"Only values of type OpenGLVMobject can be added as submobjects of VGroup, "
"but the value OpenGLMobject (at index 3) is of type OpenGLMobject."
)

# A VGroup cannot be initialised with an iterable containing non OpenGLVMobject's in any position
with pytest.raises(TypeError) as init_with_float_and_vmobs_iterable:
VGroup(mixed_type_generator(OpenGLVMobject, float, [6, 7], 9))
assert str(init_with_float_and_vmobs_iterable.value) == (
"Only values of type OpenGLVMobject can be added as submobjects of VGroup, "
"but the value 0.0 (at index 6) is of type float."
)

# A VGroup cannot be initialised with an iterable containing both OpenGLVMobject's and VMobject's
with pytest.raises(TypeError) as init_with_mobs_and_vmobs_iterable:
VGroup(mixed_type_generator(OpenGLVMobject, VMobject, [3, 5], 7))
assert str(init_with_mobs_and_vmobs_iterable.value) == (
"Only values of type OpenGLVMobject can be added as submobjects of VGroup, "
"but the value VMobject (at index 3) is of type VMobject."
)


def test_vgroup_add(using_opengl_renderer):
"""Test the VGroup add method."""
obj = VGroup()
Expand Down
Loading