Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Clean Graph layouts and increase flexibility #3434

Merged
merged 25 commits into from
Apr 24, 2024
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
bbbda11
allow user-defined layout functions for Graph
Nikhil-42 Nov 3, 2023
bf09935
only pass relevant args
Nikhil-42 Nov 3, 2023
2c4a83a
write tests
Nikhil-42 Nov 3, 2023
dcec252
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 3, 2023
51c35c3
Merge branch 'main' into main
MrDiver Dec 2, 2023
71694e7
change_layout forward root_vertex and partitions
Nikhil-42 Dec 2, 2023
8ea4dd4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 2, 2023
de9dc4b
add test for change_layout
Nikhil-42 Dec 2, 2023
021ef7e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 2, 2023
d3bfeb9
fix copy/paste error
Nikhil-42 Dec 2, 2023
0434fbd
fix
Nikhil-42 Dec 2, 2023
ee9b2e5
fixup types for CodeQL
Nikhil-42 Dec 2, 2023
19b7d16
static type the Layout Names
Nikhil-42 Dec 25, 2023
1be461b
Merge branch 'main' of github.com:ManimCommunity/manim
Nikhil-42 Dec 25, 2023
5da7f8c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 25, 2023
3a88ae2
fix dynamic union type for Python 3.9
Nikhil-42 Dec 25, 2023
a1b354d
add example scenes to LayoutFunction protocol documentation
Nikhil-42 Jan 3, 2024
1191240
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 3, 2024
f7cfb6d
Merge branch 'main' of github.com:ManimCommunity/manim
Nikhil-42 Jan 3, 2024
d3088fe
Merge branch 'main' into main
Nikhil-42 Jan 30, 2024
f32985d
Replace references to np.ndarray with standard Manim types
Nikhil-42 Jan 31, 2024
4ce3fe0
Label NxGraph as a TypeAlias
Nikhil-42 Feb 5, 2024
8b46922
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 5, 2024
1c02060
Merge branch 'main' into main
Nikhil-42 Feb 5, 2024
71e3f51
Merge branch 'main' into main
behackl Apr 24, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
233 changes: 136 additions & 97 deletions manim/mobject/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import itertools as it
from copy import copy
from typing import Hashable, Iterable
from typing import Any, Hashable, Iterable, Literal, Protocol, cast

import networkx as nx
import numpy as np
Expand All @@ -25,89 +25,58 @@
from manim.mobject.types.vectorized_mobject import VMobject
from manim.utils.color import BLACK

NxGraph = nx.classes.graph.Graph | nx.classes.digraph.DiGraph

def _determine_graph_layout(
nx_graph: nx.classes.graph.Graph | nx.classes.digraph.DiGraph,
layout: str | dict = "spring",
layout_scale: float = 2,
layout_config: dict | None = None,
partitions: list[list[Hashable]] | None = None,
root_vertex: Hashable | None = None,
) -> dict:
automatic_layouts = {
"circular": nx.layout.circular_layout,
"kamada_kawai": nx.layout.kamada_kawai_layout,
"planar": nx.layout.planar_layout,
"random": nx.layout.random_layout,
"shell": nx.layout.shell_layout,
"spectral": nx.layout.spectral_layout,
"partite": nx.layout.multipartite_layout,
"tree": _tree_layout,
"spiral": nx.layout.spiral_layout,
"spring": nx.layout.spring_layout,
}

custom_layouts = ["random", "partite", "tree"]

if layout_config is None:
layout_config = {}
class LayoutFunction(Protocol):
def __call__(
self,
graph: NxGraph,
scale: float | tuple[float, float, float] = 2,
*args: tuple[Any, ...],
**kwargs: dict[str, Any],
Comment on lines +259 to +260
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
*args: tuple[Any, ...],
**kwargs: dict[str, Any],
*args: Any,
**kwargs: Any,

) -> dict[Hashable, np.ndarray[np.float64]]:
...
Fixed Show fixed Hide fixed
Dismissed Show dismissed Hide dismissed

if isinstance(layout, dict):
return layout
elif layout in automatic_layouts and layout not in custom_layouts:
auto_layout = automatic_layouts[layout](
nx_graph, scale=layout_scale, **layout_config
)
# NetworkX returns a dictionary of 3D points if the dimension
# is specified to be 3. Otherwise, it returns a dictionary of
# 2D points, so adjusting is required.
if layout_config.get("dim") == 3:
return auto_layout
else:
return {k: np.append(v, [0]) for k, v in auto_layout.items()}
elif layout == "tree":
return _tree_layout(
nx_graph, root_vertex=root_vertex, scale=layout_scale, **layout_config
)
elif layout == "partite":
if partitions is None or len(partitions) == 0:
raise ValueError(
"The partite layout requires the 'partitions' parameter to contain the partition of the vertices",
)
partition_count = len(partitions)
for i in range(partition_count):
for v in partitions[i]:
if nx_graph.nodes[v] is None:
raise ValueError(
"The partition must contain arrays of vertices in the graph",
)
nx_graph.nodes[v]["subset"] = i
# Add missing vertices to their own side
for v in nx_graph.nodes:
if "subset" not in nx_graph.nodes[v]:
nx_graph.nodes[v]["subset"] = partition_count

auto_layout = automatic_layouts["partite"](
nx_graph, scale=layout_scale, **layout_config
)
return {k: np.append(v, [0]) for k, v in auto_layout.items()}
elif layout == "random":
# the random layout places coordinates in [0, 1)
# we need to rescale manually afterwards...
auto_layout = automatic_layouts["random"](nx_graph, **layout_config)
for k, v in auto_layout.items():
auto_layout[k] = 2 * layout_scale * (v - np.array([0.5, 0.5]))
return {k: np.append(v, [0]) for k, v in auto_layout.items()}
else:

def _partite_layout(
nx_graph: NxGraph,
scale: float = 2,
partitions: list[list[Hashable]] | None = None,
**kwargs: dict[str, Any],
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
**kwargs: dict[str, Any],
**kwargs: Any,

) -> dict[Hashable, np.ndarray[np.float64]]:
if partitions is None or len(partitions) == 0:
raise ValueError(
f"The layout '{layout}' is neither a recognized automatic layout, "
"nor a vertex placement dictionary.",
"The partite layout requires partitions parameter to contain the partition of the vertices",
)
partition_count = len(partitions)
for i in range(partition_count):
for v in partitions[i]:
if nx_graph.nodes[v] is None:
raise ValueError(
"The partition must contain arrays of vertices in the graph",
)
nx_graph.nodes[v]["subset"] = i
# Add missing vertices to their own side
for v in nx_graph.nodes:
if "subset" not in nx_graph.nodes[v]:
nx_graph.nodes[v]["subset"] = partition_count

return nx.layout.multipartite_layout(nx_graph, scale=scale, **kwargs)


def _random_layout(nx_graph: NxGraph, scale: float = 2, **kwargs: dict[str, Any]):
# the random layout places coordinates in [0, 1)
# we need to rescale manually afterwards...
auto_layout = nx.layout.random_layout(nx_graph, **kwargs)
for k, v in auto_layout.items():
auto_layout[k] = 2 * scale * (v - np.array([0.5, 0.5]))
return {k: np.append(v, [0]) for k, v in auto_layout.items()}


def _tree_layout(
T: nx.classes.graph.Graph | nx.classes.digraph.DiGraph,
root_vertex: Hashable | None,
T: NxGraph,
root_vertex: Hashable | None = None,
scale: float | tuple | None = 2,
vertex_spacing: tuple | None = None,
orientation: str = "down",
Expand Down Expand Up @@ -212,6 +181,70 @@
return {v: (np.array([x, y, 0]) - center) * sf for v, (x, y) in pos.items()}


LayoutName = Literal[
"circular",
"kamada_kawai",
"partite",
"planar",
"random",
"shell",
"spectral",
"spiral",
"spring",
"tree",
]

_layouts: dict[LayoutName, LayoutFunction] = {
"circular": cast(LayoutFunction, nx.layout.circular_layout),
"kamada_kawai": cast(LayoutFunction, nx.layout.kamada_kawai_layout),
"partite": cast(LayoutFunction, _partite_layout),
"planar": cast(LayoutFunction, nx.layout.planar_layout),
"random": cast(LayoutFunction, _random_layout),
"shell": cast(LayoutFunction, nx.layout.shell_layout),
"spectral": cast(LayoutFunction, nx.layout.spectral_layout),
"spiral": cast(LayoutFunction, nx.layout.spiral_layout),
"spring": cast(LayoutFunction, nx.layout.spring_layout),
"tree": cast(LayoutFunction, _tree_layout),
}


def _determine_graph_layout(
nx_graph: nx.classes.graph.Graph | nx.classes.digraph.DiGraph,
layout: LayoutName
| dict[Hashable, np.ndarray[np.float64]]
| LayoutFunction = "spring",
layout_scale: float | tuple[float, float, float] = 2,
layout_config: dict[str, Any] | None = None,
) -> dict[Hashable, np.ndarray[np.float64]]:
if layout_config is None:
layout_config = {}

if isinstance(layout, dict):
return layout
elif layout in _layouts:
auto_layout = _layouts[layout](nx_graph, scale=layout_scale, **layout_config)
# NetworkX returns a dictionary of 3D points if the dimension
# is specified to be 3. Otherwise, it returns a dictionary of
# 2D points, so adjusting is required.
if (
layout_config.get("dim") == 3
or auto_layout[next(auto_layout.__iter__())].shape[0] == 3
):
return auto_layout
else:
return {k: np.append(v, [0]) for k, v in auto_layout.items()}
else:
try:
return cast(LayoutFunction, layout)(
nx_graph, scale=layout_scale, **layout_config
)
except TypeError as e:
raise ValueError(
f"The layout '{layout}' is neither a recognized layout, a layout function,"
"nor a vertex placement dictionary.",
)


class GenericGraph(VMobject, metaclass=ConvertToOpenGL):
"""Abstract base class for graphs (that is, a collection of vertices
connected with edges).
Expand Down Expand Up @@ -256,8 +289,8 @@
``"planar"``, ``"random"``, ``"shell"``, ``"spectral"``, ``"spiral"``, ``"tree"``, and ``"partite"``
for automatic vertex positioning using ``networkx``
(see `their documentation <https://networkx.org/documentation/stable/reference/drawing.html#module-networkx.drawing.layout>`_
for more details), or a dictionary specifying a coordinate (value)
for each vertex (key) for manual positioning.
for more details), a dictionary specifying a coordinate (value)
for each vertex (key) for manual positioning, or a .:class:`~.LayoutFunction` with a user-defined automatic layout.
layout_config
Only for automatically generated layouts. A dictionary whose entries
are passed as keyword arguments to the automatic layout algorithm
Expand Down Expand Up @@ -301,8 +334,10 @@
edges: list[tuple[Hashable, Hashable]],
labels: bool | dict = False,
label_fill_color: str = BLACK,
layout: str | dict = "spring",
layout_scale: float | tuple = 2,
layout: LayoutName
| dict[Hashable, np.ndarray[np.float64]]
| LayoutFunction = "spring",
layout_scale: float | tuple[float, float, float] = 2,
layout_config: dict | None = None,
vertex_type: type[Mobject] = Dot,
vertex_config: dict | None = None,
Expand All @@ -319,15 +354,6 @@
nx_graph.add_edges_from(edges)
self._graph = nx_graph

self._layout = _determine_graph_layout(
nx_graph,
layout=layout,
layout_scale=layout_scale,
layout_config=layout_config,
partitions=partitions,
root_vertex=root_vertex,
)

if isinstance(labels, dict):
self._labels = labels
elif isinstance(labels, bool):
Expand Down Expand Up @@ -361,8 +387,14 @@

self.vertices = {v: vertex_type(**self._vertex_config[v]) for v in vertices}
self.vertices.update(vertex_mobjects)
for v in self.vertices:
self[v].move_to(self._layout[v])

self.change_layout(
layout=layout,
layout_scale=layout_scale,
layout_config=layout_config,
partitions=partitions,
root_vertex=root_vertex,
)

# build edge_config
if edge_config is None:
Expand Down Expand Up @@ -399,7 +431,7 @@
self.add_updater(self.update_edges)

@staticmethod
def _empty_networkx_graph():
def _empty_networkx_graph() -> nx.classes.graph.Graph:
"""Return an empty networkx graph for the given graph type."""
raise NotImplementedError("To be implemented in concrete subclasses")

Expand Down Expand Up @@ -944,9 +976,11 @@

def change_layout(
self,
layout: str | dict = "spring",
layout_scale: float = 2,
layout_config: dict | None = None,
layout: LayoutName
| dict[Hashable, np.ndarray[np.float64]]
| LayoutFunction = "spring",
layout_scale: float | tuple[float, float, float] = 2,
layout_config: dict[str, Any] | None = None,
partitions: list[list[Hashable]] | None = None,
root_vertex: Hashable | None = None,
) -> Graph:
Expand All @@ -970,14 +1004,19 @@
self.play(G.animate.change_layout("circular"))
self.wait()
"""
layout_config = {} if layout_config is None else layout_config
if partitions is not None and "partitions" not in layout_config:
layout_config["partitions"] = partitions
if root_vertex is not None and "root_vertex" not in layout_config:
layout_config["root_vertex"] = root_vertex

self._layout = _determine_graph_layout(
self._graph,
layout=layout,
layout_scale=layout_scale,
layout_config=layout_config,
partitions=partitions,
root_vertex=root_vertex,
)

for v in self.vertices:
self[v].move_to(self._layout[v])
return self
Expand Down
68 changes: 68 additions & 0 deletions tests/module/mobject/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pytest

from manim import DiGraph, Graph, Scene, Text, tempconfig
from manim.mobject.graph import _layouts


def test_graph_creation():
Expand Down Expand Up @@ -104,6 +105,73 @@ def test_custom_animation_mobject_list():
assert scene.mobjects == [G]


def test_custom_graph_layout_dict():
G = Graph(
[1, 2, 3], [(1, 2), (2, 3)], layout={1: [0, 0, 0], 2: [1, 1, 0], 3: [1, -1, 0]}
)
assert str(G) == "Undirected graph on 3 vertices and 2 edges"
assert all(G.vertices[1].get_center() == [0, 0, 0])
assert all(G.vertices[2].get_center() == [1, 1, 0])
assert all(G.vertices[3].get_center() == [1, -1, 0])


def test_graph_layouts():
for layout in (
layout for layout in _layouts if layout != "tree" and layout != "partite"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
layout for layout in _layouts if layout != "tree" and layout != "partite"
layout for layout in _layouts if layout not in {"tree", "partite"}

):
G = Graph([1, 2, 3], [(1, 2), (2, 3)], layout=layout)
assert str(G) == "Undirected graph on 3 vertices and 2 edges"


def test_tree_layout():
G = Graph([1, 2, 3], [(1, 2), (2, 3)], layout="tree", root_vertex=1)
assert str(G) == "Undirected graph on 3 vertices and 2 edges"


def test_partite_layout():
G = Graph(
[1, 2, 3, 4, 5],
[(1, 2), (2, 3), (3, 4), (4, 5)],
layout="partite",
partitions=[[1, 2], [3, 4, 5]],
)
assert str(G) == "Undirected graph on 5 vertices and 4 edges"


def test_custom_graph_layout_function():
def layout_func(graph, scale):
return {vertex: [vertex, vertex, 0] for vertex in graph}

G = Graph([1, 2, 3], [(1, 2), (2, 3)], layout=layout_func)
assert all(G.vertices[1].get_center() == [1, 1, 0])
assert all(G.vertices[2].get_center() == [2, 2, 0])
assert all(G.vertices[3].get_center() == [3, 3, 0])


def test_custom_graph_layout_function_with_kwargs():
def layout_func(graph, scale, offset):
return {
vertex: [vertex * scale + offset, vertex * scale + offset, 0]
for vertex in graph
}

G = Graph(
[1, 2, 3], [(1, 2), (2, 3)], layout=layout_func, layout_config={"offset": 1}
)
assert all(G.vertices[1].get_center() == [3, 3, 0])
assert all(G.vertices[2].get_center() == [5, 5, 0])
assert all(G.vertices[3].get_center() == [7, 7, 0])


def test_graph_change_layout():
for layout in (
layout for layout in _layouts if layout != "tree" and layout != "partite"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
layout for layout in _layouts if layout != "tree" and layout != "partite"
layout for layout in _layouts if layout not in {"tree", "partite"}

):
G = Graph([1, 2, 3], [(1, 2), (2, 3)])
G.change_layout(layout=layout)
assert str(G) == "Undirected graph on 3 vertices and 2 edges"


def test_tree_layout_no_root_error():
with pytest.raises(ValueError) as excinfo:
G = Graph([1, 2, 3], [(1, 2), (2, 3)], layout="tree")
Expand Down
Loading