From bbbda11fc321c93468def1bdb0ddc1107ce5c42e Mon Sep 17 00:00:00 2001 From: Nikhil Iyer Date: Thu, 2 Nov 2023 22:44:32 -0400 Subject: [PATCH 01/19] allow user-defined layout functions for Graph + fixup type annotations --- manim/mobject/graph.py | 112 +++++++++++++++++++++-------------------- 1 file changed, 57 insertions(+), 55 deletions(-) diff --git a/manim/mobject/graph.py b/manim/mobject/graph.py index 12082d5ad4..3b6e1818e2 100644 --- a/manim/mobject/graph.py +++ b/manim/mobject/graph.py @@ -9,7 +9,7 @@ import itertools as it from copy import copy -from typing import Hashable, Iterable +from typing import Any, Callable, Hashable, Iterable, Protocol import networkx as nx import numpy as np @@ -25,85 +25,86 @@ from manim.mobject.types.vectorized_mobject import VMobject from manim.utils.color import BLACK +class LayoutFunction(Protocol): + def __call__(self, graph: nx.classes.graph.Graph | nx.classes.digraph.DiGraph, *args: Any, **kwargs: Any) -> dict[Hashable, np.ndarray]: + ... def _determine_graph_layout( nx_graph: nx.classes.graph.Graph | nx.classes.digraph.DiGraph, - layout: str | dict = "spring", + layout: str | dict[Hashable, np.ndarray] | LayoutFunction = "spring", layout_scale: float = 2, layout_config: dict | None = None, - partitions: list[list[Hashable]] | None = None, - root_vertex: Hashable | None = None, -) -> dict: - automatic_layouts = { +) -> dict[Hashable, np.ndarray]: + + layouts = { "circular": nx.layout.circular_layout, "kamada_kawai": nx.layout.kamada_kawai_layout, + "partite": _partite_layout, "planar": nx.layout.planar_layout, - "random": nx.layout.random_layout, + "random": _random_layout, "shell": nx.layout.shell_layout, "spectral": nx.layout.spectral_layout, - "partite": nx.layout.multipartite_layout, - "tree": _tree_layout, "spiral": nx.layout.spiral_layout, "spring": nx.layout.spring_layout, + "tree": _tree_layout, } - custom_layouts = ["random", "partite", "tree"] - if layout_config is None: layout_config = {} + if layout_config.get("scale") is None: + layout_config["scale"] = layout_scale if isinstance(layout, dict): return layout - elif layout in automatic_layouts and layout not in custom_layouts: - auto_layout = automatic_layouts[layout]( - nx_graph, scale=layout_scale, **layout_config + elif layout in layouts: + layout_f, prepare = layouts[layout] + prepare(layout_config) + auto_layout = layout_f( + nx_graph, **layout_config ) # NetworkX returns a dictionary of 3D points if the dimension # is specified to be 3. Otherwise, it returns a dictionary of # 2D points, so adjusting is required. - if layout_config.get("dim") == 3: + if layout_config.get("dim") == 3 or auto_layout[next(auto_layout.__iter__())].shape[0] == 3: return auto_layout else: return {k: np.append(v, [0]) for k, v in auto_layout.items()} - elif layout == "tree": - return _tree_layout( - nx_graph, root_vertex=root_vertex, scale=layout_scale, **layout_config - ) - elif layout == "partite": - if partitions is None or len(partitions) == 0: + else: + try: + return layout(nx_graph, **layout_config) + except TypeError as e: raise ValueError( - "The partite layout requires the 'partitions' parameter to contain the partition of the vertices", + f"The layout '{layout}' is neither a recognized layout, a layout function," + "nor a vertex placement dictionary.", ) - partition_count = len(partitions) - for i in range(partition_count): - for v in partitions[i]: - if nx_graph.nodes[v] is None: - raise ValueError( - "The partition must contain arrays of vertices in the graph", - ) - nx_graph.nodes[v]["subset"] = i - # Add missing vertices to their own side - for v in nx_graph.nodes: - if "subset" not in nx_graph.nodes[v]: - nx_graph.nodes[v]["subset"] = partition_count - - auto_layout = automatic_layouts["partite"]( - nx_graph, scale=layout_scale, **layout_config - ) - return {k: np.append(v, [0]) for k, v in auto_layout.items()} - elif layout == "random": - # the random layout places coordinates in [0, 1) - # we need to rescale manually afterwards... - auto_layout = automatic_layouts["random"](nx_graph, **layout_config) - for k, v in auto_layout.items(): - auto_layout[k] = 2 * layout_scale * (v - np.array([0.5, 0.5])) - return {k: np.append(v, [0]) for k, v in auto_layout.items()} - else: + +def _partite_layout(nx_graph: nx.classes.graph.Graph, partitions: list[list[Hashable]] | None = None, **kwargs) -> dict[Hashable, np.ndarray]: + if partitions is None or len(partitions) == 0: raise ValueError( - f"The layout '{layout}' is neither a recognized automatic layout, " - "nor a vertex placement dictionary.", + "The partite layout requires `layout_config['partitions']` parameter to contain the partition of the vertices", ) - + partition_count = len(partitions) + for i in range(partition_count): + for v in partitions[i]: + if nx_graph.nodes[v] is None: + raise ValueError( + "The partition must contain arrays of vertices in the graph", + ) + nx_graph.nodes[v]["subset"] = i + # Add missing vertices to their own side + for v in nx_graph.nodes: + if "subset" not in nx_graph.nodes[v]: + nx_graph.nodes[v]["subset"] = partition_count + + return nx.layout.multipartite_layout(nx_graph, **kwargs) + +def _random_layout(nx_graph, scale, **kwargs): + # the random layout places coordinates in [0, 1) + # we need to rescale manually afterwards... + auto_layout = nx.layout.random_layout(nx_graph, **kwargs) + for k, v in auto_layout.items(): + auto_layout[k] = 2 * scale * (v - np.array([0.5, 0.5])) + return {k: np.append(v, [0]) for k, v in auto_layout.items()} def _tree_layout( T: nx.classes.graph.Graph | nx.classes.digraph.DiGraph, @@ -113,7 +114,7 @@ def _tree_layout( orientation: str = "down", ): if root_vertex is None: - raise ValueError("The tree layout requires the root_vertex parameter") + raise ValueError("The tree layout requires the layout_config['root_vertex'] parameter") if not nx.is_tree(T): raise ValueError("The tree layout must be used with trees") @@ -301,7 +302,7 @@ def __init__( edges: list[tuple[Hashable, Hashable]], labels: bool | dict = False, label_fill_color: str = BLACK, - layout: str | dict = "spring", + layout: str | dict[Hashable, np.ndarray] | LayoutFunction = "spring", layout_scale: float | tuple = 2, layout_config: dict | None = None, vertex_type: type[Mobject] = Dot, @@ -318,14 +319,15 @@ def __init__( nx_graph.add_nodes_from(vertices) nx_graph.add_edges_from(edges) self._graph = nx_graph + + layout_config['partitions'] = partitions + layout_config['root_vertex'] = root_vertex self._layout = _determine_graph_layout( nx_graph, layout=layout, layout_scale=layout_scale, layout_config=layout_config, - partitions=partitions, - root_vertex=root_vertex, ) if isinstance(labels, dict): @@ -944,7 +946,7 @@ def construct(self): def change_layout( self, - layout: str | dict = "spring", + layout: str | dict[Hashable, np.ndarray] | LayoutFunction = "spring", layout_scale: float = 2, layout_config: dict | None = None, partitions: list[list[Hashable]] | None = None, From bf09935b2e9aa81066a94f00fb25e9e48aa78af7 Mon Sep 17 00:00:00 2001 From: Nikhil Iyer Date: Fri, 3 Nov 2023 00:48:54 -0400 Subject: [PATCH 02/19] only pass relevant args --- manim/mobject/graph.py | 120 +++++++++++++++++++++-------------------- 1 file changed, 61 insertions(+), 59 deletions(-) diff --git a/manim/mobject/graph.py b/manim/mobject/graph.py index 3b6e1818e2..04fad4f17d 100644 --- a/manim/mobject/graph.py +++ b/manim/mobject/graph.py @@ -26,62 +26,14 @@ from manim.utils.color import BLACK class LayoutFunction(Protocol): - def __call__(self, graph: nx.classes.graph.Graph | nx.classes.digraph.DiGraph, *args: Any, **kwargs: Any) -> dict[Hashable, np.ndarray]: + def __call__(self, graph: nx.classes.graph.Graph | nx.classes.digraph.DiGraph, scale: float = 2, *args: Any, **kwargs: Any) -> dict[Hashable, np.ndarray]: ... - -def _determine_graph_layout( - nx_graph: nx.classes.graph.Graph | nx.classes.digraph.DiGraph, - layout: str | dict[Hashable, np.ndarray] | LayoutFunction = "spring", - layout_scale: float = 2, - layout_config: dict | None = None, -) -> dict[Hashable, np.ndarray]: - - layouts = { - "circular": nx.layout.circular_layout, - "kamada_kawai": nx.layout.kamada_kawai_layout, - "partite": _partite_layout, - "planar": nx.layout.planar_layout, - "random": _random_layout, - "shell": nx.layout.shell_layout, - "spectral": nx.layout.spectral_layout, - "spiral": nx.layout.spiral_layout, - "spring": nx.layout.spring_layout, - "tree": _tree_layout, - } - - if layout_config is None: - layout_config = {} - if layout_config.get("scale") is None: - layout_config["scale"] = layout_scale - - if isinstance(layout, dict): - return layout - elif layout in layouts: - layout_f, prepare = layouts[layout] - prepare(layout_config) - auto_layout = layout_f( - nx_graph, **layout_config - ) - # NetworkX returns a dictionary of 3D points if the dimension - # is specified to be 3. Otherwise, it returns a dictionary of - # 2D points, so adjusting is required. - if layout_config.get("dim") == 3 or auto_layout[next(auto_layout.__iter__())].shape[0] == 3: - return auto_layout - else: - return {k: np.append(v, [0]) for k, v in auto_layout.items()} - else: - try: - return layout(nx_graph, **layout_config) - except TypeError as e: - raise ValueError( - f"The layout '{layout}' is neither a recognized layout, a layout function," - "nor a vertex placement dictionary.", - ) - -def _partite_layout(nx_graph: nx.classes.graph.Graph, partitions: list[list[Hashable]] | None = None, **kwargs) -> dict[Hashable, np.ndarray]: + + +def _partite_layout(nx_graph: nx.classes.graph.Graph, scale: float=2, partitions: list[list[Hashable]] | None = None, **kwargs) -> dict[Hashable, np.ndarray]: if partitions is None or len(partitions) == 0: raise ValueError( - "The partite layout requires `layout_config['partitions']` parameter to contain the partition of the vertices", + "The partite layout requires partitions parameter to contain the partition of the vertices", ) partition_count = len(partitions) for i in range(partition_count): @@ -96,9 +48,10 @@ def _partite_layout(nx_graph: nx.classes.graph.Graph, partitions: list[list[Hash if "subset" not in nx_graph.nodes[v]: nx_graph.nodes[v]["subset"] = partition_count - return nx.layout.multipartite_layout(nx_graph, **kwargs) + return nx.layout.multipartite_layout(nx_graph, scale=scale, **kwargs) + -def _random_layout(nx_graph, scale, **kwargs): +def _random_layout(nx_graph, scale: float=2, **kwargs): # the random layout places coordinates in [0, 1) # we need to rescale manually afterwards... auto_layout = nx.layout.random_layout(nx_graph, **kwargs) @@ -106,15 +59,16 @@ def _random_layout(nx_graph, scale, **kwargs): auto_layout[k] = 2 * scale * (v - np.array([0.5, 0.5])) return {k: np.append(v, [0]) for k, v in auto_layout.items()} + def _tree_layout( T: nx.classes.graph.Graph | nx.classes.digraph.DiGraph, - root_vertex: Hashable | None, + root_vertex: Hashable | None = None, scale: float | tuple | None = 2, vertex_spacing: tuple | None = None, orientation: str = "down", ): if root_vertex is None: - raise ValueError("The tree layout requires the layout_config['root_vertex'] parameter") + raise ValueError("The tree layout requires the root_vertex parameter") if not nx.is_tree(T): raise ValueError("The tree layout must be used with trees") @@ -213,6 +167,51 @@ def slide(v, dx): return {v: (np.array([x, y, 0]) - center) * sf for v, (x, y) in pos.items()} +_layouts = { + "circular": nx.layout.circular_layout, + "kamada_kawai": nx.layout.kamada_kawai_layout, + "partite": _partite_layout, + "planar": nx.layout.planar_layout, + "random": _random_layout, + "shell": nx.layout.shell_layout, + "spectral": nx.layout.spectral_layout, + "spiral": nx.layout.spiral_layout, + "spring": nx.layout.spring_layout, + "tree": _tree_layout, +} + + +def _determine_graph_layout( + nx_graph: nx.classes.graph.Graph | nx.classes.digraph.DiGraph, + layout: str | dict[Hashable, np.ndarray] | LayoutFunction = "spring", + layout_scale: float = 2, + layout_config: dict | None = None, +) -> dict[Hashable, np.ndarray]: + if layout_config is None: + layout_config = {} + + if isinstance(layout, dict): + return layout + elif layout in _layouts: + auto_layout = _layouts[layout]( + nx_graph, scale=layout_scale, **layout_config + ) + # NetworkX returns a dictionary of 3D points if the dimension + # is specified to be 3. Otherwise, it returns a dictionary of + # 2D points, so adjusting is required. + if layout_config.get("dim") == 3 or auto_layout[next(auto_layout.__iter__())].shape[0] == 3: + return auto_layout + else: + return {k: np.append(v, [0]) for k, v in auto_layout.items()} + else: + try: + return layout(nx_graph, scale=layout_scale, **layout_config) + except TypeError as e: + raise ValueError( + f"The layout '{layout}' is neither a recognized layout, a layout function," + "nor a vertex placement dictionary.", + ) + class GenericGraph(VMobject, metaclass=ConvertToOpenGL): """Abstract base class for graphs (that is, a collection of vertices connected with edges). @@ -320,8 +319,11 @@ def __init__( nx_graph.add_edges_from(edges) self._graph = nx_graph - layout_config['partitions'] = partitions - layout_config['root_vertex'] = root_vertex + layout_config = {} if layout_config is None else layout_config + if partitions is not None and 'partitions' not in layout_config: + layout_config['partitions'] = partitions + if root_vertex is not None and 'root_vertex' not in layout_config: + layout_config['root_vertex'] = root_vertex self._layout = _determine_graph_layout( nx_graph, From 2c4a83afecf4ab6c8e5e13e2108eab5901178675 Mon Sep 17 00:00:00 2001 From: Nikhil Iyer Date: Fri, 3 Nov 2023 00:49:05 -0400 Subject: [PATCH 03/19] write tests --- tests/module/mobject/test_graph.py | 46 ++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/tests/module/mobject/test_graph.py b/tests/module/mobject/test_graph.py index b05d6b21d9..6bbdd51a5b 100644 --- a/tests/module/mobject/test_graph.py +++ b/tests/module/mobject/test_graph.py @@ -3,6 +3,7 @@ import pytest from manim import DiGraph, Graph, Scene, Text, tempconfig +from manim.mobject.graph import _layouts def test_graph_creation(): @@ -104,6 +105,50 @@ def test_custom_animation_mobject_list(): assert scene.mobjects == [G] +def test_custom_graph_layout_dict(): + G = Graph([1, 2, 3], [(1, 2), (2, 3)], layout={1: [0, 0, 0], 2: [1, 1, 0], 3: [1, -1, 0]}) + assert str(G) == "Undirected graph on 3 vertices and 2 edges" + assert all(G.vertices[1].get_center() == [0, 0, 0]) + assert all(G.vertices[2].get_center() == [1, 1, 0]) + assert all(G.vertices[3].get_center() == [1, -1, 0]) + + +def test_graph_layouts(): + for layout in (layout for layout in _layouts if layout != 'tree' and layout != 'partite'): + G = Graph([1, 2, 3], [(1, 2), (2, 3)], layout=layout) + assert str(G) == "Undirected graph on 3 vertices and 2 edges" + + +def test_tree_layout(): + G = Graph([1, 2, 3], [(1, 2), (2, 3)], layout="tree", root_vertex=1) + assert str(G) == "Undirected graph on 3 vertices and 2 edges" + + +def test_partite_layout(): + G = Graph([1, 2, 3, 4, 5], [(1, 2), (2, 3), (3, 4), (4, 5)], layout="partite", partitions=[[1, 2], [3, 4, 5]]) + assert str(G) == "Undirected graph on 5 vertices and 4 edges" + + +def test_custom_graph_layout_function(): + def layout_func(graph, scale): + return {vertex: [vertex, vertex, 0] for vertex in graph} + + G = Graph([1, 2, 3], [(1, 2), (2, 3)], layout=layout_func) + assert all(G.vertices[1].get_center() == [1, 1, 0]) + assert all(G.vertices[2].get_center() == [2, 2, 0]) + assert all(G.vertices[3].get_center() == [3, 3, 0]) + + +def test_custom_graph_layout_function_with_kwargs(): + def layout_func(graph, scale, offset): + return {vertex: [vertex * scale + offset, vertex * scale + offset, 0] for vertex in graph} + + G = Graph([1, 2, 3], [(1, 2), (2, 3)], layout=layout_func, layout_config={'offset': 1}) + assert all(G.vertices[1].get_center() == [3, 3, 0]) + assert all(G.vertices[2].get_center() == [5, 5, 0]) + assert all(G.vertices[3].get_center() == [7, 7, 0]) + + def test_tree_layout_no_root_error(): with pytest.raises(ValueError) as excinfo: G = Graph([1, 2, 3], [(1, 2), (2, 3)], layout="tree") @@ -114,3 +159,4 @@ def test_tree_layout_not_tree_error(): with pytest.raises(ValueError) as excinfo: G = Graph([1, 2, 3], [(1, 2), (2, 3), (3, 1)], layout="tree", root_vertex=1) assert str(excinfo.value) == "The tree layout must be used with trees" + From dcec2520d17d6f0819d5beef4a63165208910911 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 3 Nov 2023 05:04:07 +0000 Subject: [PATCH 04/19] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- manim/mobject/graph.py | 46 +++++++++++++++++++----------- tests/module/mobject/test_graph.py | 27 +++++++++++++----- 2 files changed, 50 insertions(+), 23 deletions(-) diff --git a/manim/mobject/graph.py b/manim/mobject/graph.py index 04fad4f17d..e423a4bc88 100644 --- a/manim/mobject/graph.py +++ b/manim/mobject/graph.py @@ -25,12 +25,24 @@ from manim.mobject.types.vectorized_mobject import VMobject from manim.utils.color import BLACK + class LayoutFunction(Protocol): - def __call__(self, graph: nx.classes.graph.Graph | nx.classes.digraph.DiGraph, scale: float = 2, *args: Any, **kwargs: Any) -> dict[Hashable, np.ndarray]: + def __call__( + self, + graph: nx.classes.graph.Graph | nx.classes.digraph.DiGraph, + scale: float = 2, + *args: Any, + **kwargs: Any, + ) -> dict[Hashable, np.ndarray]: ... - - -def _partite_layout(nx_graph: nx.classes.graph.Graph, scale: float=2, partitions: list[list[Hashable]] | None = None, **kwargs) -> dict[Hashable, np.ndarray]: + + +def _partite_layout( + nx_graph: nx.classes.graph.Graph, + scale: float = 2, + partitions: list[list[Hashable]] | None = None, + **kwargs, +) -> dict[Hashable, np.ndarray]: if partitions is None or len(partitions) == 0: raise ValueError( "The partite layout requires partitions parameter to contain the partition of the vertices", @@ -47,11 +59,11 @@ def _partite_layout(nx_graph: nx.classes.graph.Graph, scale: float=2, partitions for v in nx_graph.nodes: if "subset" not in nx_graph.nodes[v]: nx_graph.nodes[v]["subset"] = partition_count - + return nx.layout.multipartite_layout(nx_graph, scale=scale, **kwargs) -def _random_layout(nx_graph, scale: float=2, **kwargs): +def _random_layout(nx_graph, scale: float = 2, **kwargs): # the random layout places coordinates in [0, 1) # we need to rescale manually afterwards... auto_layout = nx.layout.random_layout(nx_graph, **kwargs) @@ -186,20 +198,21 @@ def _determine_graph_layout( layout: str | dict[Hashable, np.ndarray] | LayoutFunction = "spring", layout_scale: float = 2, layout_config: dict | None = None, -) -> dict[Hashable, np.ndarray]: +) -> dict[Hashable, np.ndarray]: if layout_config is None: layout_config = {} if isinstance(layout, dict): return layout elif layout in _layouts: - auto_layout = _layouts[layout]( - nx_graph, scale=layout_scale, **layout_config - ) + auto_layout = _layouts[layout](nx_graph, scale=layout_scale, **layout_config) # NetworkX returns a dictionary of 3D points if the dimension # is specified to be 3. Otherwise, it returns a dictionary of # 2D points, so adjusting is required. - if layout_config.get("dim") == 3 or auto_layout[next(auto_layout.__iter__())].shape[0] == 3: + if ( + layout_config.get("dim") == 3 + or auto_layout[next(auto_layout.__iter__())].shape[0] == 3 + ): return auto_layout else: return {k: np.append(v, [0]) for k, v in auto_layout.items()} @@ -212,6 +225,7 @@ def _determine_graph_layout( "nor a vertex placement dictionary.", ) + class GenericGraph(VMobject, metaclass=ConvertToOpenGL): """Abstract base class for graphs (that is, a collection of vertices connected with edges). @@ -318,12 +332,12 @@ def __init__( nx_graph.add_nodes_from(vertices) nx_graph.add_edges_from(edges) self._graph = nx_graph - + layout_config = {} if layout_config is None else layout_config - if partitions is not None and 'partitions' not in layout_config: - layout_config['partitions'] = partitions - if root_vertex is not None and 'root_vertex' not in layout_config: - layout_config['root_vertex'] = root_vertex + if partitions is not None and "partitions" not in layout_config: + layout_config["partitions"] = partitions + if root_vertex is not None and "root_vertex" not in layout_config: + layout_config["root_vertex"] = root_vertex self._layout = _determine_graph_layout( nx_graph, diff --git a/tests/module/mobject/test_graph.py b/tests/module/mobject/test_graph.py index 6bbdd51a5b..07759a843b 100644 --- a/tests/module/mobject/test_graph.py +++ b/tests/module/mobject/test_graph.py @@ -106,7 +106,9 @@ def test_custom_animation_mobject_list(): def test_custom_graph_layout_dict(): - G = Graph([1, 2, 3], [(1, 2), (2, 3)], layout={1: [0, 0, 0], 2: [1, 1, 0], 3: [1, -1, 0]}) + G = Graph( + [1, 2, 3], [(1, 2), (2, 3)], layout={1: [0, 0, 0], 2: [1, 1, 0], 3: [1, -1, 0]} + ) assert str(G) == "Undirected graph on 3 vertices and 2 edges" assert all(G.vertices[1].get_center() == [0, 0, 0]) assert all(G.vertices[2].get_center() == [1, 1, 0]) @@ -114,7 +116,9 @@ def test_custom_graph_layout_dict(): def test_graph_layouts(): - for layout in (layout for layout in _layouts if layout != 'tree' and layout != 'partite'): + for layout in ( + layout for layout in _layouts if layout != "tree" and layout != "partite" + ): G = Graph([1, 2, 3], [(1, 2), (2, 3)], layout=layout) assert str(G) == "Undirected graph on 3 vertices and 2 edges" @@ -125,7 +129,12 @@ def test_tree_layout(): def test_partite_layout(): - G = Graph([1, 2, 3, 4, 5], [(1, 2), (2, 3), (3, 4), (4, 5)], layout="partite", partitions=[[1, 2], [3, 4, 5]]) + G = Graph( + [1, 2, 3, 4, 5], + [(1, 2), (2, 3), (3, 4), (4, 5)], + layout="partite", + partitions=[[1, 2], [3, 4, 5]], + ) assert str(G) == "Undirected graph on 5 vertices and 4 edges" @@ -141,9 +150,14 @@ def layout_func(graph, scale): def test_custom_graph_layout_function_with_kwargs(): def layout_func(graph, scale, offset): - return {vertex: [vertex * scale + offset, vertex * scale + offset, 0] for vertex in graph} - - G = Graph([1, 2, 3], [(1, 2), (2, 3)], layout=layout_func, layout_config={'offset': 1}) + return { + vertex: [vertex * scale + offset, vertex * scale + offset, 0] + for vertex in graph + } + + G = Graph( + [1, 2, 3], [(1, 2), (2, 3)], layout=layout_func, layout_config={"offset": 1} + ) assert all(G.vertices[1].get_center() == [3, 3, 0]) assert all(G.vertices[2].get_center() == [5, 5, 0]) assert all(G.vertices[3].get_center() == [7, 7, 0]) @@ -159,4 +173,3 @@ def test_tree_layout_not_tree_error(): with pytest.raises(ValueError) as excinfo: G = Graph([1, 2, 3], [(1, 2), (2, 3), (3, 1)], layout="tree", root_vertex=1) assert str(excinfo.value) == "The tree layout must be used with trees" - From 71694e751d0f02117006ce8526225d9b1c665623 Mon Sep 17 00:00:00 2001 From: Nikhil Iyer Date: Sat, 2 Dec 2023 11:14:59 -0500 Subject: [PATCH 05/19] change_layout forward root_vertex and partitions - deduplicated layout code in __init__ and change_layout - fixed change_layout backwards compatibility --- manim/mobject/graph.py | 29 +++++++++++------------------ 1 file changed, 11 insertions(+), 18 deletions(-) diff --git a/manim/mobject/graph.py b/manim/mobject/graph.py index e423a4bc88..004ad55439 100644 --- a/manim/mobject/graph.py +++ b/manim/mobject/graph.py @@ -333,19 +333,6 @@ def __init__( nx_graph.add_edges_from(edges) self._graph = nx_graph - layout_config = {} if layout_config is None else layout_config - if partitions is not None and "partitions" not in layout_config: - layout_config["partitions"] = partitions - if root_vertex is not None and "root_vertex" not in layout_config: - layout_config["root_vertex"] = root_vertex - - self._layout = _determine_graph_layout( - nx_graph, - layout=layout, - layout_scale=layout_scale, - layout_config=layout_config, - ) - if isinstance(labels, dict): self._labels = labels elif isinstance(labels, bool): @@ -379,8 +366,8 @@ def __init__( self.vertices = {v: vertex_type(**self._vertex_config[v]) for v in vertices} self.vertices.update(vertex_mobjects) - for v in self.vertices: - self[v].move_to(self._layout[v]) + + self.change_layout(layout=layout, layout_scale=layout_scale, layout_config=layout_config, partitions=partitions, root_vertex=root_vertex) # build edge_config if edge_config is None: @@ -415,6 +402,7 @@ def __init__( self.add(*self.edges.values()) self.add_updater(self.update_edges) + self.cha @staticmethod def _empty_networkx_graph(): @@ -988,14 +976,19 @@ def construct(self): self.play(G.animate.change_layout("circular")) self.wait() """ + layout_config = {} if layout_config is None else layout_config + if partitions is not None and "partitions" not in layout_config: + layout_config["partitions"] = partitions + if root_vertex is not None and "root_vertex" not in layout_config: + layout_config["root_vertex"] = root_vertex + self._layout = _determine_graph_layout( - self._graph, + nx_graph, layout=layout, layout_scale=layout_scale, layout_config=layout_config, - partitions=partitions, - root_vertex=root_vertex, ) + for v in self.vertices: self[v].move_to(self._layout[v]) return self From 8ea4dd453e91bcdb8c8daf6981244c9ec3e618b9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 2 Dec 2023 16:16:28 +0000 Subject: [PATCH 06/19] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- manim/mobject/graph.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/manim/mobject/graph.py b/manim/mobject/graph.py index 004ad55439..fd264ee166 100644 --- a/manim/mobject/graph.py +++ b/manim/mobject/graph.py @@ -366,8 +366,14 @@ def __init__( self.vertices = {v: vertex_type(**self._vertex_config[v]) for v in vertices} self.vertices.update(vertex_mobjects) - - self.change_layout(layout=layout, layout_scale=layout_scale, layout_config=layout_config, partitions=partitions, root_vertex=root_vertex) + + self.change_layout( + layout=layout, + layout_scale=layout_scale, + layout_config=layout_config, + partitions=partitions, + root_vertex=root_vertex, + ) # build edge_config if edge_config is None: From de9dc4b7345924ebb066a9cfafb1be3c3fc35507 Mon Sep 17 00:00:00 2001 From: Nikhil Iyer Date: Sat, 2 Dec 2023 11:20:29 -0500 Subject: [PATCH 07/19] add test for change_layout --- tests/module/mobject/test_graph.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tests/module/mobject/test_graph.py b/tests/module/mobject/test_graph.py index 07759a843b..1ef4bb81fc 100644 --- a/tests/module/mobject/test_graph.py +++ b/tests/module/mobject/test_graph.py @@ -162,6 +162,13 @@ def layout_func(graph, scale, offset): assert all(G.vertices[2].get_center() == [5, 5, 0]) assert all(G.vertices[3].get_center() == [7, 7, 0]) +def test_graph_change_layout(): + for layout in ( + layout for layout in _layouts if layout != "tree" and layout != "partite" + ): + G = Graph([1, 2, 3], [(1, 2), (2, 3)]) + G.change_layout(layout=layout) + assert str(G) == "Undirected graph on 3 vertices and 2 edges" def test_tree_layout_no_root_error(): with pytest.raises(ValueError) as excinfo: From 021ef7e32145bb2e2dfc69fdc0b6845c1c28088b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 2 Dec 2023 16:22:19 +0000 Subject: [PATCH 08/19] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/module/mobject/test_graph.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/module/mobject/test_graph.py b/tests/module/mobject/test_graph.py index 1ef4bb81fc..d5d0efe08b 100644 --- a/tests/module/mobject/test_graph.py +++ b/tests/module/mobject/test_graph.py @@ -162,6 +162,7 @@ def layout_func(graph, scale, offset): assert all(G.vertices[2].get_center() == [5, 5, 0]) assert all(G.vertices[3].get_center() == [7, 7, 0]) + def test_graph_change_layout(): for layout in ( layout for layout in _layouts if layout != "tree" and layout != "partite" @@ -170,6 +171,7 @@ def test_graph_change_layout(): G.change_layout(layout=layout) assert str(G) == "Undirected graph on 3 vertices and 2 edges" + def test_tree_layout_no_root_error(): with pytest.raises(ValueError) as excinfo: G = Graph([1, 2, 3], [(1, 2), (2, 3)], layout="tree") From d3bfeb9d51fc72326bb90f30622ff29e08b09a3c Mon Sep 17 00:00:00 2001 From: Nikhil Iyer Date: Sat, 2 Dec 2023 11:26:37 -0500 Subject: [PATCH 09/19] fix copy/paste error --- manim/mobject/graph.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/manim/mobject/graph.py b/manim/mobject/graph.py index fd264ee166..b849b53fee 100644 --- a/manim/mobject/graph.py +++ b/manim/mobject/graph.py @@ -989,7 +989,7 @@ def construct(self): layout_config["root_vertex"] = root_vertex self._layout = _determine_graph_layout( - nx_graph, + self._graph, layout=layout, layout_scale=layout_scale, layout_config=layout_config, From 0434fbd4d6dccdc17456f7f40436f269874d5aaf Mon Sep 17 00:00:00 2001 From: Nikhil Iyer Date: Sat, 2 Dec 2023 11:34:54 -0500 Subject: [PATCH 10/19] fix --- manim/mobject/graph.py | 1 - 1 file changed, 1 deletion(-) diff --git a/manim/mobject/graph.py b/manim/mobject/graph.py index b849b53fee..08cd02b42f 100644 --- a/manim/mobject/graph.py +++ b/manim/mobject/graph.py @@ -408,7 +408,6 @@ def __init__( self.add(*self.edges.values()) self.add_updater(self.update_edges) - self.cha @staticmethod def _empty_networkx_graph(): From ee9b2e58dea370bacce109ae8f95fbc99466dcc9 Mon Sep 17 00:00:00 2001 From: Nikhil Iyer Date: Sat, 2 Dec 2023 13:10:30 -0500 Subject: [PATCH 11/19] fixup types for CodeQL --- manim/mobject/graph.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/manim/mobject/graph.py b/manim/mobject/graph.py index 08cd02b42f..f55efe965d 100644 --- a/manim/mobject/graph.py +++ b/manim/mobject/graph.py @@ -9,7 +9,7 @@ import itertools as it from copy import copy -from typing import Any, Callable, Hashable, Iterable, Protocol +from typing import Any, Hashable, Iterable, Protocol, cast import networkx as nx import numpy as np @@ -30,7 +30,7 @@ class LayoutFunction(Protocol): def __call__( self, graph: nx.classes.graph.Graph | nx.classes.digraph.DiGraph, - scale: float = 2, + scale: float | tuple[float, float, float] = 2, *args: Any, **kwargs: Any, ) -> dict[Hashable, np.ndarray]: @@ -196,7 +196,7 @@ def slide(v, dx): def _determine_graph_layout( nx_graph: nx.classes.graph.Graph | nx.classes.digraph.DiGraph, layout: str | dict[Hashable, np.ndarray] | LayoutFunction = "spring", - layout_scale: float = 2, + layout_scale: float | tuple[float, float, float] = 2, layout_config: dict | None = None, ) -> dict[Hashable, np.ndarray]: if layout_config is None: @@ -218,7 +218,9 @@ def _determine_graph_layout( return {k: np.append(v, [0]) for k, v in auto_layout.items()} else: try: - return layout(nx_graph, scale=layout_scale, **layout_config) + return cast(LayoutFunction, layout)( + nx_graph, scale=layout_scale, **layout_config + ) except TypeError as e: raise ValueError( f"The layout '{layout}' is neither a recognized layout, a layout function," @@ -270,8 +272,8 @@ class GenericGraph(VMobject, metaclass=ConvertToOpenGL): ``"planar"``, ``"random"``, ``"shell"``, ``"spectral"``, ``"spiral"``, ``"tree"``, and ``"partite"`` for automatic vertex positioning using ``networkx`` (see `their documentation `_ - for more details), or a dictionary specifying a coordinate (value) - for each vertex (key) for manual positioning. + for more details), a dictionary specifying a coordinate (value) + for each vertex (key) for manual positioning, or a .:class:`~.LayoutFunction` with a user-defined automatic layout. layout_config Only for automatically generated layouts. A dictionary whose entries are passed as keyword arguments to the automatic layout algorithm @@ -316,7 +318,7 @@ def __init__( labels: bool | dict = False, label_fill_color: str = BLACK, layout: str | dict[Hashable, np.ndarray] | LayoutFunction = "spring", - layout_scale: float | tuple = 2, + layout_scale: float | tuple[float, float, float] = 2, layout_config: dict | None = None, vertex_type: type[Mobject] = Dot, vertex_config: dict | None = None, @@ -956,7 +958,7 @@ def construct(self): def change_layout( self, layout: str | dict[Hashable, np.ndarray] | LayoutFunction = "spring", - layout_scale: float = 2, + layout_scale: float | tuple[float, float, float] = 2, layout_config: dict | None = None, partitions: list[list[Hashable]] | None = None, root_vertex: Hashable | None = None, From 19b7d16906ca3f481a733c96a772c089e7a5528f Mon Sep 17 00:00:00 2001 From: Nikhil Iyer Date: Sun, 24 Dec 2023 23:01:03 -0500 Subject: [PATCH 12/19] static type the Layout Names --- manim/mobject/graph.py | 61 +++++++++++++++++++++--------------------- 1 file changed, 31 insertions(+), 30 deletions(-) diff --git a/manim/mobject/graph.py b/manim/mobject/graph.py index f55efe965d..1784b95aa0 100644 --- a/manim/mobject/graph.py +++ b/manim/mobject/graph.py @@ -9,7 +9,7 @@ import itertools as it from copy import copy -from typing import Any, Hashable, Iterable, Protocol, cast +from typing import Any, Hashable, Iterable, Literal, Protocol, cast import networkx as nx import numpy as np @@ -25,24 +25,24 @@ from manim.mobject.types.vectorized_mobject import VMobject from manim.utils.color import BLACK - +NxGraph = nx.classes.graph.Graph | nx.classes.digraph.DiGraph class LayoutFunction(Protocol): def __call__( self, - graph: nx.classes.graph.Graph | nx.classes.digraph.DiGraph, + graph: NxGraph, scale: float | tuple[float, float, float] = 2, - *args: Any, - **kwargs: Any, - ) -> dict[Hashable, np.ndarray]: + *args: tuple[Any, ...], + **kwargs: dict[str, Any] + ) -> dict[Hashable, np.ndarray[np.float64]]: ... def _partite_layout( - nx_graph: nx.classes.graph.Graph, + nx_graph: NxGraph, scale: float = 2, partitions: list[list[Hashable]] | None = None, - **kwargs, -) -> dict[Hashable, np.ndarray]: + **kwargs: dict[str, Any] +) -> dict[Hashable, np.ndarray[np.float64]]: if partitions is None or len(partitions) == 0: raise ValueError( "The partite layout requires partitions parameter to contain the partition of the vertices", @@ -63,7 +63,7 @@ def _partite_layout( return nx.layout.multipartite_layout(nx_graph, scale=scale, **kwargs) -def _random_layout(nx_graph, scale: float = 2, **kwargs): +def _random_layout(nx_graph: NxGraph, scale: float = 2, **kwargs: dict[str, Any]): # the random layout places coordinates in [0, 1) # we need to rescale manually afterwards... auto_layout = nx.layout.random_layout(nx_graph, **kwargs) @@ -73,7 +73,7 @@ def _random_layout(nx_graph, scale: float = 2, **kwargs): def _tree_layout( - T: nx.classes.graph.Graph | nx.classes.digraph.DiGraph, + T: NxGraph, root_vertex: Hashable | None = None, scale: float | tuple | None = 2, vertex_spacing: tuple | None = None, @@ -179,26 +179,27 @@ def slide(v, dx): return {v: (np.array([x, y, 0]) - center) * sf for v, (x, y) in pos.items()} -_layouts = { - "circular": nx.layout.circular_layout, - "kamada_kawai": nx.layout.kamada_kawai_layout, - "partite": _partite_layout, - "planar": nx.layout.planar_layout, - "random": _random_layout, - "shell": nx.layout.shell_layout, - "spectral": nx.layout.spectral_layout, - "spiral": nx.layout.spiral_layout, - "spring": nx.layout.spring_layout, - "tree": _tree_layout, -} +LayoutName = Literal["circular", "kamada_kawai", "partite", "planar", "random", "shell", "spectral", "spiral", "spring", "tree"] +_layouts: dict[LayoutName, LayoutFunction] = { + "circular": cast(LayoutFunction, nx.layout.circular_layout), + "kamada_kawai": cast(LayoutFunction, nx.layout.kamada_kawai_layout), + "partite": cast(LayoutFunction, _partite_layout), + "planar": cast(LayoutFunction, nx.layout.planar_layout), + "random": cast(LayoutFunction, _random_layout), + "shell": cast(LayoutFunction, nx.layout.shell_layout), + "spectral": cast(LayoutFunction, nx.layout.spectral_layout), + "spiral": cast(LayoutFunction, nx.layout.spiral_layout), + "spring": cast(LayoutFunction, nx.layout.spring_layout), + "tree": cast(LayoutFunction, _tree_layout), +} def _determine_graph_layout( nx_graph: nx.classes.graph.Graph | nx.classes.digraph.DiGraph, - layout: str | dict[Hashable, np.ndarray] | LayoutFunction = "spring", + layout: LayoutName | dict[Hashable, np.ndarray[np.float64]] | LayoutFunction = "spring", layout_scale: float | tuple[float, float, float] = 2, - layout_config: dict | None = None, -) -> dict[Hashable, np.ndarray]: + layout_config: dict[str, Any] | None = None, +) -> dict[Hashable, np.ndarray[np.float64]]: if layout_config is None: layout_config = {} @@ -317,7 +318,7 @@ def __init__( edges: list[tuple[Hashable, Hashable]], labels: bool | dict = False, label_fill_color: str = BLACK, - layout: str | dict[Hashable, np.ndarray] | LayoutFunction = "spring", + layout: LayoutName | dict[Hashable, np.ndarray[np.float64]] | LayoutFunction = "spring", layout_scale: float | tuple[float, float, float] = 2, layout_config: dict | None = None, vertex_type: type[Mobject] = Dot, @@ -412,7 +413,7 @@ def __init__( self.add_updater(self.update_edges) @staticmethod - def _empty_networkx_graph(): + def _empty_networkx_graph() -> nx.classes.graph.Graph: """Return an empty networkx graph for the given graph type.""" raise NotImplementedError("To be implemented in concrete subclasses") @@ -957,9 +958,9 @@ def construct(self): def change_layout( self, - layout: str | dict[Hashable, np.ndarray] | LayoutFunction = "spring", + layout: LayoutName | dict[Hashable, np.ndarray[np.float64]] | LayoutFunction = "spring", layout_scale: float | tuple[float, float, float] = 2, - layout_config: dict | None = None, + layout_config: dict[str, Any] | None = None, partitions: list[list[Hashable]] | None = None, root_vertex: Hashable | None = None, ) -> Graph: From 5da7f8c5789b5ba6bbf8379bcfc66301095edc5a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 25 Dec 2023 04:01:51 +0000 Subject: [PATCH 13/19] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- manim/mobject/graph.py | 32 ++++++++++++++++++++++++++------ 1 file changed, 26 insertions(+), 6 deletions(-) diff --git a/manim/mobject/graph.py b/manim/mobject/graph.py index 1784b95aa0..23143679b5 100644 --- a/manim/mobject/graph.py +++ b/manim/mobject/graph.py @@ -26,13 +26,15 @@ from manim.utils.color import BLACK NxGraph = nx.classes.graph.Graph | nx.classes.digraph.DiGraph + + class LayoutFunction(Protocol): def __call__( self, graph: NxGraph, scale: float | tuple[float, float, float] = 2, *args: tuple[Any, ...], - **kwargs: dict[str, Any] + **kwargs: dict[str, Any], ) -> dict[Hashable, np.ndarray[np.float64]]: ... @@ -41,7 +43,7 @@ def _partite_layout( nx_graph: NxGraph, scale: float = 2, partitions: list[list[Hashable]] | None = None, - **kwargs: dict[str, Any] + **kwargs: dict[str, Any], ) -> dict[Hashable, np.ndarray[np.float64]]: if partitions is None or len(partitions) == 0: raise ValueError( @@ -179,7 +181,18 @@ def slide(v, dx): return {v: (np.array([x, y, 0]) - center) * sf for v, (x, y) in pos.items()} -LayoutName = Literal["circular", "kamada_kawai", "partite", "planar", "random", "shell", "spectral", "spiral", "spring", "tree"] +LayoutName = Literal[ + "circular", + "kamada_kawai", + "partite", + "planar", + "random", + "shell", + "spectral", + "spiral", + "spring", + "tree", +] _layouts: dict[LayoutName, LayoutFunction] = { "circular": cast(LayoutFunction, nx.layout.circular_layout), @@ -194,9 +207,12 @@ def slide(v, dx): "tree": cast(LayoutFunction, _tree_layout), } + def _determine_graph_layout( nx_graph: nx.classes.graph.Graph | nx.classes.digraph.DiGraph, - layout: LayoutName | dict[Hashable, np.ndarray[np.float64]] | LayoutFunction = "spring", + layout: LayoutName + | dict[Hashable, np.ndarray[np.float64]] + | LayoutFunction = "spring", layout_scale: float | tuple[float, float, float] = 2, layout_config: dict[str, Any] | None = None, ) -> dict[Hashable, np.ndarray[np.float64]]: @@ -318,7 +334,9 @@ def __init__( edges: list[tuple[Hashable, Hashable]], labels: bool | dict = False, label_fill_color: str = BLACK, - layout: LayoutName | dict[Hashable, np.ndarray[np.float64]] | LayoutFunction = "spring", + layout: LayoutName + | dict[Hashable, np.ndarray[np.float64]] + | LayoutFunction = "spring", layout_scale: float | tuple[float, float, float] = 2, layout_config: dict | None = None, vertex_type: type[Mobject] = Dot, @@ -958,7 +976,9 @@ def construct(self): def change_layout( self, - layout: LayoutName | dict[Hashable, np.ndarray[np.float64]] | LayoutFunction = "spring", + layout: LayoutName + | dict[Hashable, np.ndarray[np.float64]] + | LayoutFunction = "spring", layout_scale: float | tuple[float, float, float] = 2, layout_config: dict[str, Any] | None = None, partitions: list[list[Hashable]] | None = None, From 3a88ae29b1ee792c2af14ce676e60955cea0afdb Mon Sep 17 00:00:00 2001 From: Nikhil Iyer Date: Mon, 25 Dec 2023 04:06:55 -0500 Subject: [PATCH 14/19] fix dynamic union type for Python 3.9 --- manim/mobject/graph.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/manim/mobject/graph.py b/manim/mobject/graph.py index 23143679b5..c95dfc20b2 100644 --- a/manim/mobject/graph.py +++ b/manim/mobject/graph.py @@ -9,7 +9,7 @@ import itertools as it from copy import copy -from typing import Any, Hashable, Iterable, Literal, Protocol, cast +from typing import Any, Hashable, Iterable, Literal, Protocol, Union, cast import networkx as nx import numpy as np @@ -25,7 +25,7 @@ from manim.mobject.types.vectorized_mobject import VMobject from manim.utils.color import BLACK -NxGraph = nx.classes.graph.Graph | nx.classes.digraph.DiGraph +NxGraph = Union[nx.classes.graph.Graph, nx.classes.digraph.DiGraph] class LayoutFunction(Protocol): From a1b354df59f36dd640b47d2ead518c10a8817232 Mon Sep 17 00:00:00 2001 From: Nikhil Iyer Date: Tue, 2 Jan 2024 23:42:39 -0500 Subject: [PATCH 15/19] add example scenes to LayoutFunction protocol documentation --- manim/mobject/graph.py | 243 ++++++++++++++++++++++++++++++++++++++++- 1 file changed, 239 insertions(+), 4 deletions(-) diff --git a/manim/mobject/graph.py b/manim/mobject/graph.py index c95dfc20b2..6c4a0f9c80 100644 --- a/manim/mobject/graph.py +++ b/manim/mobject/graph.py @@ -29,6 +29,227 @@ class LayoutFunction(Protocol): + """A protocol for automatic layout functions that compute a layout for a graph to be used in :meth:`~.Graph.change_layout`. + + .. note:: The layout function must be a pure function, i.e., it must not modify the graph passed to it. + + Examples + -------- + + Here is an example that arranges nodes in an n x m grid in sorted order. + + .. manim:: CustomLayoutExample + :save_last_frame: + + class CustomLayoutExample(Scene): + def construct(self): + import numpy as np + import networkx as nx + + # create custom layout + def custom_layout( + graph: nx.Graph, + scale: float | tuple[float, float, float] = 2, + n: int | None = None, + *args: tuple[Any, ...], + **kwargs: dict[str, Any], + ): + nodes = sorted(list(graph)) + height = len(nodes) // n + return { + node: (scale * np.array([ + (i % n) - (n-1)/2, + -(i // n) + height/2, + 0 + ])) for i, node in enumerate(graph) + } + + # draw graph + n = 4 + graph = Graph( + [i for i in range(4 * 2 - 1)], + [(0, 1), (0, 4), (1, 2), (1, 5), (2, 3), (2, 6), (4, 5), (5, 6)], + labels=True, + layout=custom_layout, + layout_config={'n': n} + ) + self.add(graph) + + Several automatic layouts are provided by manim, and can be used by passing their name as the ``layout`` parameter to :meth:`~.Graph.change_layout`. + Alternatively, a custom layout function can be passed to :meth:`~.Graph.change_layout` as the ``layout`` parameter. Such a function must adhere to the :class:`~.LayoutFunction` protocol. + + The :class:`~.LayoutFunction` s provided by manim are illustrated below: + + - Circular Layout: places the vertices on a circle + + .. manim:: CircularLayout + :save_last_frame: + + class CircularLayout(Scene): + def construct(self): + graph = Graph( + [1, 2, 3, 4, 5, 6], + [(1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 1), (5, 1), (1, 3), (3, 5)], + layout="circular", + labels=True + ) + self.add(graph) + + - Kamada Kawai Layout: tries to place the vertices such that the given distances between them are respected + + .. manim:: KamadaKawaiLayout + :save_last_frame: + + class KamadaKawaiLayout(Scene): + def construct(self): + from collections import defaultdict + distances: dict[int, dict[int, float]] = defaultdict(dict) + + # set desired distances + distances[1][2] = 1 # distance between vertices 1 and 2 is 1 + distances[2][3] = 1 # distance between vertices 2 and 3 is 1 + distances[3][4] = 2 # etc + distances[4][5] = 3 + distances[5][6] = 5 + distances[6][1] = 8 + + graph = Graph( + [1, 2, 3, 4, 5, 6], + [(1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 1)], + layout="kamada_kawai", + layout_config={"dist": distances}, + layout_scale=4, + labels=True + ) + self.add(graph) + + - Partite Layout: places vertices into distinct partitions + + .. manim:: PartiteLayout + :save_last_frame: + + class PartiteLayout(Scene): + def construct(self): + graph = Graph( + [1, 2, 3, 4, 5, 6], + [(1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 1), (5, 1), (1, 3), (3, 5)], + layout="partite", + layout_config={"partitions": [[1,2],[3,4],[5,6]]}, + labels=True + ) + self.add(graph) + + - Planar Layout: places vertices such that edges do not cross + + .. manim:: PlanarLayout + :save_last_frame: + + class PlanarLayout(Scene): + def construct(self): + graph = Graph( + [1, 2, 3, 4, 5, 6], + [(1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 1), (5, 1), (1, 3), (3, 5)], + layout="planar", + layout_scale=4, + labels=True + ) + self.add(graph) + + - Random Layout: randomly places vertices + + .. manim:: RandomLayout + :save_last_frame: + + class RandomLayout(Scene): + def construct(self): + graph = Graph( + [1, 2, 3, 4, 5, 6], + [(1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 1), (5, 1), (1, 3), (3, 5)], + layout="random", + labels=True + ) + self.add(graph) + + - Shell Layout: places vertices in concentric circles + + .. manim:: ShellLayout + :save_last_frame: + + class ShellLayout(Scene): + def construct(self): + nlist = [[1, 2, 3], [4, 5, 6, 7, 8, 9]] + graph = Graph( + [1, 2, 3, 4, 5, 6, 7, 8, 9], + [(1, 2), (2, 3), (3, 1), (4, 1), (4, 2), (5, 2), (6, 2), (6, 3), (7, 3), (8, 3), (8, 1), (9, 1)], + layout="shell", + layout_config={"nlist": nlist}, + labels=True + ) + self.add(graph) + + - Spectral Layout: places vertices using the eigenvectors of the graph Laplacian (clusters nodes which are an approximation of the ratio cut) + + .. manim:: SpectralLayout + :save_last_frame: + + class SpectralLayout(Scene): + def construct(self): + graph = Graph( + [1, 2, 3, 4, 5, 6], + [(1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 1), (5, 1), (1, 3), (3, 5)], + layout="spectral", + labels=True + ) + self.add(graph) + + - Sprial Layout: places vertices in a spiraling pattern + + .. manim:: SpiralLayout + :save_last_frame: + + class SpiralLayout(Scene): + def construct(self): + graph = Graph( + [1, 2, 3, 4, 5, 6], + [(1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 1), (5, 1), (1, 3), (3, 5)], + layout="spiral", + labels=True + ) + self.add(graph) + + - Spring Layout: places nodes according to the Fruchterman-Reingold force-directed algorithm (attempts to minimize edge length while maximizing node separation) + + .. manim:: SpringLayout + :save_last_frame: + + class SpringLayout(Scene): + def construct(self): + graph = Graph( + [1, 2, 3, 4, 5, 6], + [(1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 1), (5, 1), (1, 3), (3, 5)], + layout="spring", + labels=True + ) + self.add(graph) + + - Tree Layout: places vertices into a tree with a root node and branches (can only be used with legal trees) + + .. manim:: TreeLayout + :save_last_frame: + + class TreeLayout(Scene): + def construct(self): + graph = Graph( + [1, 2, 3, 4, 5, 6, 7], + [(1, 2), (1, 3), (2, 4), (2, 5), (3, 6), (3, 7)], + layout="tree", + layout_config={"root_vertex": 1}, + labels=True + ) + self.add(graph) + + """ + def __call__( self, graph: NxGraph, @@ -36,6 +257,20 @@ def __call__( *args: tuple[Any, ...], **kwargs: dict[str, Any], ) -> dict[Hashable, np.ndarray[np.float64]]: + """Given a graph and a scale, return a dictionary of coordinates. + + Parameters + ---------- + graph : NxGraph + The underlying NetworkX graph to be laid out. DO NOT MODIFY. + scale : float | tuple[float, float, float], optional + Either a single float value, or a tuple of three float values specifying the scale along each axis. + + Returns + ------- + dict[Hashable, np.ndarray[np.float64]] + A dictionary mapping vertices to their positions. + """ ... @@ -287,14 +522,14 @@ class GenericGraph(VMobject, metaclass=ConvertToOpenGL): layout Either one of ``"spring"`` (the default), ``"circular"``, ``"kamada_kawai"``, ``"planar"``, ``"random"``, ``"shell"``, ``"spectral"``, ``"spiral"``, ``"tree"``, and ``"partite"`` - for automatic vertex positioning using ``networkx`` + for automatic vertex positioning primarily using ``networkx`` (see `their documentation `_ for more details), a dictionary specifying a coordinate (value) for each vertex (key) for manual positioning, or a .:class:`~.LayoutFunction` with a user-defined automatic layout. layout_config - Only for automatically generated layouts. A dictionary whose entries - are passed as keyword arguments to the automatic layout algorithm - specified via ``layout`` of``networkx``. + Only for automatic layouts. A dictionary whose entries + are passed as keyword arguments to the named layout or automatic layout function + specified via ``layout``. The ``tree`` layout also accepts a special parameter ``vertex_spacing`` passed as a keyword argument inside the ``layout_config`` dictionary. Passing a tuple ``(space_x, space_y)`` as this argument overrides From 11912401c6d4d7516babe3dae7b2cef01b7744e7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 3 Jan 2024 04:43:30 +0000 Subject: [PATCH 16/19] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- manim/mobject/graph.py | 50 +++++++++++++++++++++--------------------- 1 file changed, 25 insertions(+), 25 deletions(-) diff --git a/manim/mobject/graph.py b/manim/mobject/graph.py index 6c4a0f9c80..1b123e3ec9 100644 --- a/manim/mobject/graph.py +++ b/manim/mobject/graph.py @@ -30,22 +30,22 @@ class LayoutFunction(Protocol): """A protocol for automatic layout functions that compute a layout for a graph to be used in :meth:`~.Graph.change_layout`. - + .. note:: The layout function must be a pure function, i.e., it must not modify the graph passed to it. Examples -------- - + Here is an example that arranges nodes in an n x m grid in sorted order. - + .. manim:: CustomLayoutExample :save_last_frame: - + class CustomLayoutExample(Scene): def construct(self): import numpy as np import networkx as nx - + # create custom layout def custom_layout( graph: nx.Graph, @@ -63,7 +63,7 @@ def custom_layout( 0 ])) for i, node in enumerate(graph) } - + # draw graph n = 4 graph = Graph( @@ -77,11 +77,11 @@ def custom_layout( Several automatic layouts are provided by manim, and can be used by passing their name as the ``layout`` parameter to :meth:`~.Graph.change_layout`. Alternatively, a custom layout function can be passed to :meth:`~.Graph.change_layout` as the ``layout`` parameter. Such a function must adhere to the :class:`~.LayoutFunction` protocol. - + The :class:`~.LayoutFunction` s provided by manim are illustrated below: - + - Circular Layout: places the vertices on a circle - + .. manim:: CircularLayout :save_last_frame: @@ -96,7 +96,7 @@ def construct(self): self.add(graph) - Kamada Kawai Layout: tries to place the vertices such that the given distances between them are respected - + .. manim:: KamadaKawaiLayout :save_last_frame: @@ -104,7 +104,7 @@ class KamadaKawaiLayout(Scene): def construct(self): from collections import defaultdict distances: dict[int, dict[int, float]] = defaultdict(dict) - + # set desired distances distances[1][2] = 1 # distance between vertices 1 and 2 is 1 distances[2][3] = 1 # distance between vertices 2 and 3 is 1 @@ -112,7 +112,7 @@ def construct(self): distances[4][5] = 3 distances[5][6] = 5 distances[6][1] = 8 - + graph = Graph( [1, 2, 3, 4, 5, 6], [(1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 1)], @@ -123,8 +123,8 @@ def construct(self): ) self.add(graph) - - Partite Layout: places vertices into distinct partitions - + - Partite Layout: places vertices into distinct partitions + .. manim:: PartiteLayout :save_last_frame: @@ -140,7 +140,7 @@ def construct(self): self.add(graph) - Planar Layout: places vertices such that edges do not cross - + .. manim:: PlanarLayout :save_last_frame: @@ -154,9 +154,9 @@ def construct(self): labels=True ) self.add(graph) - + - Random Layout: randomly places vertices - + .. manim:: RandomLayout :save_last_frame: @@ -169,9 +169,9 @@ def construct(self): labels=True ) self.add(graph) - + - Shell Layout: places vertices in concentric circles - + .. manim:: ShellLayout :save_last_frame: @@ -188,7 +188,7 @@ def construct(self): self.add(graph) - Spectral Layout: places vertices using the eigenvectors of the graph Laplacian (clusters nodes which are an approximation of the ratio cut) - + .. manim:: SpectralLayout :save_last_frame: @@ -203,7 +203,7 @@ def construct(self): self.add(graph) - Sprial Layout: places vertices in a spiraling pattern - + .. manim:: SpiralLayout :save_last_frame: @@ -217,8 +217,8 @@ def construct(self): ) self.add(graph) - - Spring Layout: places nodes according to the Fruchterman-Reingold force-directed algorithm (attempts to minimize edge length while maximizing node separation) - + - Spring Layout: places nodes according to the Fruchterman-Reingold force-directed algorithm (attempts to minimize edge length while maximizing node separation) + .. manim:: SpringLayout :save_last_frame: @@ -233,7 +233,7 @@ def construct(self): self.add(graph) - Tree Layout: places vertices into a tree with a root node and branches (can only be used with legal trees) - + .. manim:: TreeLayout :save_last_frame: @@ -249,7 +249,7 @@ def construct(self): self.add(graph) """ - + def __call__( self, graph: NxGraph, From f32985d82d7d67580d50a8327e37970c575a11be Mon Sep 17 00:00:00 2001 From: Nikhil Iyer Date: Wed, 31 Jan 2024 01:17:19 -0500 Subject: [PATCH 17/19] Replace references to np.ndarray with standard Manim types --- manim/mobject/graph.py | 31 +++++++++++++------------------ 1 file changed, 13 insertions(+), 18 deletions(-) diff --git a/manim/mobject/graph.py b/manim/mobject/graph.py index 1b123e3ec9..569bf20aae 100644 --- a/manim/mobject/graph.py +++ b/manim/mobject/graph.py @@ -23,6 +23,7 @@ from manim.mobject.opengl.opengl_mobject import OpenGLMobject from manim.mobject.text.tex_mobject import MathTex from manim.mobject.types.vectorized_mobject import VMobject +from manim.typing import Point3D from manim.utils.color import BLACK NxGraph = Union[nx.classes.graph.Graph, nx.classes.digraph.DiGraph] @@ -256,7 +257,7 @@ def __call__( scale: float | tuple[float, float, float] = 2, *args: tuple[Any, ...], **kwargs: dict[str, Any], - ) -> dict[Hashable, np.ndarray[np.float64]]: + ) -> dict[Hashable, Point3D]: """Given a graph and a scale, return a dictionary of coordinates. Parameters @@ -268,7 +269,7 @@ def __call__( Returns ------- - dict[Hashable, np.ndarray[np.float64]] + dict[Hashable, Point3D] A dictionary mapping vertices to their positions. """ ... @@ -279,7 +280,7 @@ def _partite_layout( scale: float = 2, partitions: list[list[Hashable]] | None = None, **kwargs: dict[str, Any], -) -> dict[Hashable, np.ndarray[np.float64]]: +) -> dict[Hashable, Point3D]: if partitions is None or len(partitions) == 0: raise ValueError( "The partite layout requires partitions parameter to contain the partition of the vertices", @@ -445,12 +446,10 @@ def slide(v, dx): def _determine_graph_layout( nx_graph: nx.classes.graph.Graph | nx.classes.digraph.DiGraph, - layout: LayoutName - | dict[Hashable, np.ndarray[np.float64]] - | LayoutFunction = "spring", + layout: LayoutName | dict[Hashable, Point3D] | LayoutFunction = "spring", layout_scale: float | tuple[float, float, float] = 2, layout_config: dict[str, Any] | None = None, -) -> dict[Hashable, np.ndarray[np.float64]]: +) -> dict[Hashable, Point3D]: if layout_config is None: layout_config = {} @@ -569,9 +568,7 @@ def __init__( edges: list[tuple[Hashable, Hashable]], labels: bool | dict = False, label_fill_color: str = BLACK, - layout: LayoutName - | dict[Hashable, np.ndarray[np.float64]] - | LayoutFunction = "spring", + layout: LayoutName | dict[Hashable, Point3D] | LayoutFunction = "spring", layout_scale: float | tuple[float, float, float] = 2, layout_config: dict | None = None, vertex_type: type[Mobject] = Dot, @@ -682,13 +679,13 @@ def __getitem__(self: Graph, v: Hashable) -> Mobject: def _create_vertex( self, vertex: Hashable, - position: np.ndarray | None = None, + position: Point3D | None = None, label: bool = False, label_fill_color: str = BLACK, vertex_type: type[Mobject] = Dot, vertex_config: dict | None = None, vertex_mobject: dict | None = None, - ) -> tuple[Hashable, np.ndarray, dict, Mobject]: + ) -> tuple[Hashable, Point3D, dict, Mobject]: if position is None: position = self.get_center() @@ -726,7 +723,7 @@ def _create_vertex( def _add_created_vertex( self, vertex: Hashable, - position: np.ndarray, + position: Point3D, vertex_config: dict, vertex_mobject: Mobject, ) -> Mobject: @@ -752,7 +749,7 @@ def _add_created_vertex( def _add_vertex( self, vertex: Hashable, - position: np.ndarray | None = None, + position: Point3D | None = None, label: bool = False, label_fill_color: str = BLACK, vertex_type: type[Mobject] = Dot, @@ -807,7 +804,7 @@ def _create_vertices( vertex_type: type[Mobject] = Dot, vertex_config: dict | None = None, vertex_mobjects: dict | None = None, - ) -> Iterable[tuple[Hashable, np.ndarray, dict, Mobject]]: + ) -> Iterable[tuple[Hashable, Point3D, dict, Mobject]]: if positions is None: positions = {} if vertex_mobjects is None: @@ -1211,9 +1208,7 @@ def construct(self): def change_layout( self, - layout: LayoutName - | dict[Hashable, np.ndarray[np.float64]] - | LayoutFunction = "spring", + layout: LayoutName | dict[Hashable, Point3D] | LayoutFunction = "spring", layout_scale: float | tuple[float, float, float] = 2, layout_config: dict[str, Any] | None = None, partitions: list[list[Hashable]] | None = None, From 4ce3fe0a84740944f463fa9f15a77e33020c0f7f Mon Sep 17 00:00:00 2001 From: Nikhil Iyer Date: Mon, 5 Feb 2024 07:11:44 -0600 Subject: [PATCH 18/19] Label NxGraph as a TypeAlias --- manim/mobject/graph.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/manim/mobject/graph.py b/manim/mobject/graph.py index 569bf20aae..a947e0f769 100644 --- a/manim/mobject/graph.py +++ b/manim/mobject/graph.py @@ -10,6 +10,7 @@ import itertools as it from copy import copy from typing import Any, Hashable, Iterable, Literal, Protocol, Union, cast +from typing_extensions import TypeAlias import networkx as nx import numpy as np @@ -26,7 +27,7 @@ from manim.typing import Point3D from manim.utils.color import BLACK -NxGraph = Union[nx.classes.graph.Graph, nx.classes.digraph.DiGraph] +NxGraph: TypeAlias = Union[nx.classes.graph.Graph, nx.classes.digraph.DiGraph] class LayoutFunction(Protocol): From 8b469229ff7576e77c24f7545593e3e72ed15a0d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 5 Feb 2024 13:12:35 +0000 Subject: [PATCH 19/19] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- manim/mobject/graph.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/manim/mobject/graph.py b/manim/mobject/graph.py index a947e0f769..dda52de770 100644 --- a/manim/mobject/graph.py +++ b/manim/mobject/graph.py @@ -10,10 +10,10 @@ import itertools as it from copy import copy from typing import Any, Hashable, Iterable, Literal, Protocol, Union, cast -from typing_extensions import TypeAlias import networkx as nx import numpy as np +from typing_extensions import TypeAlias from manim.animation.composition import AnimationGroup from manim.animation.creation import Create, Uncreate