-
Notifications
You must be signed in to change notification settings - Fork 2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Clean Graph layouts and increase flexibility #3434
Merged
Merged
Changes from 15 commits
Commits
Show all changes
25 commits
Select commit
Hold shift + click to select a range
bbbda11
allow user-defined layout functions for Graph
Nikhil-42 bf09935
only pass relevant args
Nikhil-42 2c4a83a
write tests
Nikhil-42 dcec252
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 51c35c3
Merge branch 'main' into main
MrDiver 71694e7
change_layout forward root_vertex and partitions
Nikhil-42 8ea4dd4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] de9dc4b
add test for change_layout
Nikhil-42 021ef7e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] d3bfeb9
fix copy/paste error
Nikhil-42 0434fbd
fix
Nikhil-42 ee9b2e5
fixup types for CodeQL
Nikhil-42 19b7d16
static type the Layout Names
Nikhil-42 1be461b
Merge branch 'main' of github.com:ManimCommunity/manim
Nikhil-42 5da7f8c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 3a88ae2
fix dynamic union type for Python 3.9
Nikhil-42 a1b354d
add example scenes to LayoutFunction protocol documentation
Nikhil-42 1191240
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] f7cfb6d
Merge branch 'main' of github.com:ManimCommunity/manim
Nikhil-42 d3088fe
Merge branch 'main' into main
Nikhil-42 f32985d
Replace references to np.ndarray with standard Manim types
Nikhil-42 4ce3fe0
Label NxGraph as a TypeAlias
Nikhil-42 8b46922
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 1c02060
Merge branch 'main' into main
Nikhil-42 71e3f51
Merge branch 'main' into main
behackl File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -9,7 +9,7 @@ | |||||
|
||||||
import itertools as it | ||||||
from copy import copy | ||||||
from typing import Hashable, Iterable | ||||||
from typing import Any, Hashable, Iterable, Literal, Protocol, cast | ||||||
|
||||||
import networkx as nx | ||||||
import numpy as np | ||||||
|
@@ -25,89 +25,58 @@ | |||||
from manim.mobject.types.vectorized_mobject import VMobject | ||||||
from manim.utils.color import BLACK | ||||||
|
||||||
NxGraph = nx.classes.graph.Graph | nx.classes.digraph.DiGraph | ||||||
|
||||||
def _determine_graph_layout( | ||||||
nx_graph: nx.classes.graph.Graph | nx.classes.digraph.DiGraph, | ||||||
layout: str | dict = "spring", | ||||||
layout_scale: float = 2, | ||||||
layout_config: dict | None = None, | ||||||
partitions: list[list[Hashable]] | None = None, | ||||||
root_vertex: Hashable | None = None, | ||||||
) -> dict: | ||||||
automatic_layouts = { | ||||||
"circular": nx.layout.circular_layout, | ||||||
"kamada_kawai": nx.layout.kamada_kawai_layout, | ||||||
"planar": nx.layout.planar_layout, | ||||||
"random": nx.layout.random_layout, | ||||||
"shell": nx.layout.shell_layout, | ||||||
"spectral": nx.layout.spectral_layout, | ||||||
"partite": nx.layout.multipartite_layout, | ||||||
"tree": _tree_layout, | ||||||
"spiral": nx.layout.spiral_layout, | ||||||
"spring": nx.layout.spring_layout, | ||||||
} | ||||||
|
||||||
custom_layouts = ["random", "partite", "tree"] | ||||||
|
||||||
if layout_config is None: | ||||||
layout_config = {} | ||||||
class LayoutFunction(Protocol): | ||||||
def __call__( | ||||||
self, | ||||||
graph: NxGraph, | ||||||
scale: float | tuple[float, float, float] = 2, | ||||||
*args: tuple[Any, ...], | ||||||
**kwargs: dict[str, Any], | ||||||
) -> dict[Hashable, np.ndarray[np.float64]]: | ||||||
... | ||||||
|
||||||
|
||||||
if isinstance(layout, dict): | ||||||
return layout | ||||||
elif layout in automatic_layouts and layout not in custom_layouts: | ||||||
auto_layout = automatic_layouts[layout]( | ||||||
nx_graph, scale=layout_scale, **layout_config | ||||||
) | ||||||
# NetworkX returns a dictionary of 3D points if the dimension | ||||||
# is specified to be 3. Otherwise, it returns a dictionary of | ||||||
# 2D points, so adjusting is required. | ||||||
if layout_config.get("dim") == 3: | ||||||
return auto_layout | ||||||
else: | ||||||
return {k: np.append(v, [0]) for k, v in auto_layout.items()} | ||||||
elif layout == "tree": | ||||||
return _tree_layout( | ||||||
nx_graph, root_vertex=root_vertex, scale=layout_scale, **layout_config | ||||||
) | ||||||
elif layout == "partite": | ||||||
if partitions is None or len(partitions) == 0: | ||||||
raise ValueError( | ||||||
"The partite layout requires the 'partitions' parameter to contain the partition of the vertices", | ||||||
) | ||||||
partition_count = len(partitions) | ||||||
for i in range(partition_count): | ||||||
for v in partitions[i]: | ||||||
if nx_graph.nodes[v] is None: | ||||||
raise ValueError( | ||||||
"The partition must contain arrays of vertices in the graph", | ||||||
) | ||||||
nx_graph.nodes[v]["subset"] = i | ||||||
# Add missing vertices to their own side | ||||||
for v in nx_graph.nodes: | ||||||
if "subset" not in nx_graph.nodes[v]: | ||||||
nx_graph.nodes[v]["subset"] = partition_count | ||||||
|
||||||
auto_layout = automatic_layouts["partite"]( | ||||||
nx_graph, scale=layout_scale, **layout_config | ||||||
) | ||||||
return {k: np.append(v, [0]) for k, v in auto_layout.items()} | ||||||
elif layout == "random": | ||||||
# the random layout places coordinates in [0, 1) | ||||||
# we need to rescale manually afterwards... | ||||||
auto_layout = automatic_layouts["random"](nx_graph, **layout_config) | ||||||
for k, v in auto_layout.items(): | ||||||
auto_layout[k] = 2 * layout_scale * (v - np.array([0.5, 0.5])) | ||||||
return {k: np.append(v, [0]) for k, v in auto_layout.items()} | ||||||
else: | ||||||
|
||||||
def _partite_layout( | ||||||
nx_graph: NxGraph, | ||||||
scale: float = 2, | ||||||
partitions: list[list[Hashable]] | None = None, | ||||||
**kwargs: dict[str, Any], | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
) -> dict[Hashable, np.ndarray[np.float64]]: | ||||||
if partitions is None or len(partitions) == 0: | ||||||
raise ValueError( | ||||||
f"The layout '{layout}' is neither a recognized automatic layout, " | ||||||
"nor a vertex placement dictionary.", | ||||||
"The partite layout requires partitions parameter to contain the partition of the vertices", | ||||||
) | ||||||
partition_count = len(partitions) | ||||||
for i in range(partition_count): | ||||||
for v in partitions[i]: | ||||||
if nx_graph.nodes[v] is None: | ||||||
raise ValueError( | ||||||
"The partition must contain arrays of vertices in the graph", | ||||||
) | ||||||
nx_graph.nodes[v]["subset"] = i | ||||||
# Add missing vertices to their own side | ||||||
for v in nx_graph.nodes: | ||||||
if "subset" not in nx_graph.nodes[v]: | ||||||
nx_graph.nodes[v]["subset"] = partition_count | ||||||
|
||||||
return nx.layout.multipartite_layout(nx_graph, scale=scale, **kwargs) | ||||||
|
||||||
|
||||||
def _random_layout(nx_graph: NxGraph, scale: float = 2, **kwargs: dict[str, Any]): | ||||||
# the random layout places coordinates in [0, 1) | ||||||
# we need to rescale manually afterwards... | ||||||
auto_layout = nx.layout.random_layout(nx_graph, **kwargs) | ||||||
for k, v in auto_layout.items(): | ||||||
auto_layout[k] = 2 * scale * (v - np.array([0.5, 0.5])) | ||||||
return {k: np.append(v, [0]) for k, v in auto_layout.items()} | ||||||
|
||||||
|
||||||
def _tree_layout( | ||||||
T: nx.classes.graph.Graph | nx.classes.digraph.DiGraph, | ||||||
root_vertex: Hashable | None, | ||||||
T: NxGraph, | ||||||
root_vertex: Hashable | None = None, | ||||||
scale: float | tuple | None = 2, | ||||||
vertex_spacing: tuple | None = None, | ||||||
orientation: str = "down", | ||||||
|
@@ -212,6 +181,70 @@ | |||||
return {v: (np.array([x, y, 0]) - center) * sf for v, (x, y) in pos.items()} | ||||||
|
||||||
|
||||||
LayoutName = Literal[ | ||||||
"circular", | ||||||
"kamada_kawai", | ||||||
"partite", | ||||||
"planar", | ||||||
"random", | ||||||
"shell", | ||||||
"spectral", | ||||||
"spiral", | ||||||
"spring", | ||||||
"tree", | ||||||
] | ||||||
|
||||||
_layouts: dict[LayoutName, LayoutFunction] = { | ||||||
"circular": cast(LayoutFunction, nx.layout.circular_layout), | ||||||
"kamada_kawai": cast(LayoutFunction, nx.layout.kamada_kawai_layout), | ||||||
"partite": cast(LayoutFunction, _partite_layout), | ||||||
"planar": cast(LayoutFunction, nx.layout.planar_layout), | ||||||
"random": cast(LayoutFunction, _random_layout), | ||||||
"shell": cast(LayoutFunction, nx.layout.shell_layout), | ||||||
"spectral": cast(LayoutFunction, nx.layout.spectral_layout), | ||||||
"spiral": cast(LayoutFunction, nx.layout.spiral_layout), | ||||||
"spring": cast(LayoutFunction, nx.layout.spring_layout), | ||||||
"tree": cast(LayoutFunction, _tree_layout), | ||||||
} | ||||||
|
||||||
|
||||||
def _determine_graph_layout( | ||||||
nx_graph: nx.classes.graph.Graph | nx.classes.digraph.DiGraph, | ||||||
layout: LayoutName | ||||||
| dict[Hashable, np.ndarray[np.float64]] | ||||||
| LayoutFunction = "spring", | ||||||
layout_scale: float | tuple[float, float, float] = 2, | ||||||
layout_config: dict[str, Any] | None = None, | ||||||
) -> dict[Hashable, np.ndarray[np.float64]]: | ||||||
if layout_config is None: | ||||||
layout_config = {} | ||||||
|
||||||
if isinstance(layout, dict): | ||||||
return layout | ||||||
elif layout in _layouts: | ||||||
auto_layout = _layouts[layout](nx_graph, scale=layout_scale, **layout_config) | ||||||
# NetworkX returns a dictionary of 3D points if the dimension | ||||||
# is specified to be 3. Otherwise, it returns a dictionary of | ||||||
# 2D points, so adjusting is required. | ||||||
if ( | ||||||
layout_config.get("dim") == 3 | ||||||
or auto_layout[next(auto_layout.__iter__())].shape[0] == 3 | ||||||
): | ||||||
return auto_layout | ||||||
else: | ||||||
return {k: np.append(v, [0]) for k, v in auto_layout.items()} | ||||||
else: | ||||||
try: | ||||||
return cast(LayoutFunction, layout)( | ||||||
nx_graph, scale=layout_scale, **layout_config | ||||||
) | ||||||
except TypeError as e: | ||||||
raise ValueError( | ||||||
f"The layout '{layout}' is neither a recognized layout, a layout function," | ||||||
"nor a vertex placement dictionary.", | ||||||
) | ||||||
|
||||||
|
||||||
class GenericGraph(VMobject, metaclass=ConvertToOpenGL): | ||||||
"""Abstract base class for graphs (that is, a collection of vertices | ||||||
connected with edges). | ||||||
|
@@ -256,8 +289,8 @@ | |||||
``"planar"``, ``"random"``, ``"shell"``, ``"spectral"``, ``"spiral"``, ``"tree"``, and ``"partite"`` | ||||||
for automatic vertex positioning using ``networkx`` | ||||||
(see `their documentation <https://networkx.org/documentation/stable/reference/drawing.html#module-networkx.drawing.layout>`_ | ||||||
for more details), or a dictionary specifying a coordinate (value) | ||||||
for each vertex (key) for manual positioning. | ||||||
for more details), a dictionary specifying a coordinate (value) | ||||||
for each vertex (key) for manual positioning, or a .:class:`~.LayoutFunction` with a user-defined automatic layout. | ||||||
layout_config | ||||||
Only for automatically generated layouts. A dictionary whose entries | ||||||
are passed as keyword arguments to the automatic layout algorithm | ||||||
|
@@ -301,8 +334,10 @@ | |||||
edges: list[tuple[Hashable, Hashable]], | ||||||
labels: bool | dict = False, | ||||||
label_fill_color: str = BLACK, | ||||||
layout: str | dict = "spring", | ||||||
layout_scale: float | tuple = 2, | ||||||
layout: LayoutName | ||||||
| dict[Hashable, np.ndarray[np.float64]] | ||||||
| LayoutFunction = "spring", | ||||||
layout_scale: float | tuple[float, float, float] = 2, | ||||||
layout_config: dict | None = None, | ||||||
vertex_type: type[Mobject] = Dot, | ||||||
vertex_config: dict | None = None, | ||||||
|
@@ -319,15 +354,6 @@ | |||||
nx_graph.add_edges_from(edges) | ||||||
self._graph = nx_graph | ||||||
|
||||||
self._layout = _determine_graph_layout( | ||||||
nx_graph, | ||||||
layout=layout, | ||||||
layout_scale=layout_scale, | ||||||
layout_config=layout_config, | ||||||
partitions=partitions, | ||||||
root_vertex=root_vertex, | ||||||
) | ||||||
|
||||||
if isinstance(labels, dict): | ||||||
self._labels = labels | ||||||
elif isinstance(labels, bool): | ||||||
|
@@ -361,8 +387,14 @@ | |||||
|
||||||
self.vertices = {v: vertex_type(**self._vertex_config[v]) for v in vertices} | ||||||
self.vertices.update(vertex_mobjects) | ||||||
for v in self.vertices: | ||||||
self[v].move_to(self._layout[v]) | ||||||
|
||||||
self.change_layout( | ||||||
layout=layout, | ||||||
layout_scale=layout_scale, | ||||||
layout_config=layout_config, | ||||||
partitions=partitions, | ||||||
root_vertex=root_vertex, | ||||||
) | ||||||
|
||||||
# build edge_config | ||||||
if edge_config is None: | ||||||
|
@@ -399,7 +431,7 @@ | |||||
self.add_updater(self.update_edges) | ||||||
|
||||||
@staticmethod | ||||||
def _empty_networkx_graph(): | ||||||
def _empty_networkx_graph() -> nx.classes.graph.Graph: | ||||||
"""Return an empty networkx graph for the given graph type.""" | ||||||
raise NotImplementedError("To be implemented in concrete subclasses") | ||||||
|
||||||
|
@@ -944,9 +976,11 @@ | |||||
|
||||||
def change_layout( | ||||||
self, | ||||||
layout: str | dict = "spring", | ||||||
layout_scale: float = 2, | ||||||
layout_config: dict | None = None, | ||||||
layout: LayoutName | ||||||
| dict[Hashable, np.ndarray[np.float64]] | ||||||
| LayoutFunction = "spring", | ||||||
layout_scale: float | tuple[float, float, float] = 2, | ||||||
layout_config: dict[str, Any] | None = None, | ||||||
partitions: list[list[Hashable]] | None = None, | ||||||
root_vertex: Hashable | None = None, | ||||||
) -> Graph: | ||||||
|
@@ -970,14 +1004,19 @@ | |||||
self.play(G.animate.change_layout("circular")) | ||||||
self.wait() | ||||||
""" | ||||||
layout_config = {} if layout_config is None else layout_config | ||||||
if partitions is not None and "partitions" not in layout_config: | ||||||
layout_config["partitions"] = partitions | ||||||
if root_vertex is not None and "root_vertex" not in layout_config: | ||||||
layout_config["root_vertex"] = root_vertex | ||||||
|
||||||
self._layout = _determine_graph_layout( | ||||||
self._graph, | ||||||
layout=layout, | ||||||
layout_scale=layout_scale, | ||||||
layout_config=layout_config, | ||||||
partitions=partitions, | ||||||
root_vertex=root_vertex, | ||||||
) | ||||||
|
||||||
for v in self.vertices: | ||||||
self[v].move_to(self._layout[v]) | ||||||
return self | ||||||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -3,6 +3,7 @@ | |||||
import pytest | ||||||
|
||||||
from manim import DiGraph, Graph, Scene, Text, tempconfig | ||||||
from manim.mobject.graph import _layouts | ||||||
|
||||||
|
||||||
def test_graph_creation(): | ||||||
|
@@ -104,6 +105,73 @@ def test_custom_animation_mobject_list(): | |||||
assert scene.mobjects == [G] | ||||||
|
||||||
|
||||||
def test_custom_graph_layout_dict(): | ||||||
G = Graph( | ||||||
[1, 2, 3], [(1, 2), (2, 3)], layout={1: [0, 0, 0], 2: [1, 1, 0], 3: [1, -1, 0]} | ||||||
) | ||||||
assert str(G) == "Undirected graph on 3 vertices and 2 edges" | ||||||
assert all(G.vertices[1].get_center() == [0, 0, 0]) | ||||||
assert all(G.vertices[2].get_center() == [1, 1, 0]) | ||||||
assert all(G.vertices[3].get_center() == [1, -1, 0]) | ||||||
|
||||||
|
||||||
def test_graph_layouts(): | ||||||
for layout in ( | ||||||
layout for layout in _layouts if layout != "tree" and layout != "partite" | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
): | ||||||
G = Graph([1, 2, 3], [(1, 2), (2, 3)], layout=layout) | ||||||
assert str(G) == "Undirected graph on 3 vertices and 2 edges" | ||||||
|
||||||
|
||||||
def test_tree_layout(): | ||||||
G = Graph([1, 2, 3], [(1, 2), (2, 3)], layout="tree", root_vertex=1) | ||||||
assert str(G) == "Undirected graph on 3 vertices and 2 edges" | ||||||
|
||||||
|
||||||
def test_partite_layout(): | ||||||
G = Graph( | ||||||
[1, 2, 3, 4, 5], | ||||||
[(1, 2), (2, 3), (3, 4), (4, 5)], | ||||||
layout="partite", | ||||||
partitions=[[1, 2], [3, 4, 5]], | ||||||
) | ||||||
assert str(G) == "Undirected graph on 5 vertices and 4 edges" | ||||||
|
||||||
|
||||||
def test_custom_graph_layout_function(): | ||||||
def layout_func(graph, scale): | ||||||
return {vertex: [vertex, vertex, 0] for vertex in graph} | ||||||
|
||||||
G = Graph([1, 2, 3], [(1, 2), (2, 3)], layout=layout_func) | ||||||
assert all(G.vertices[1].get_center() == [1, 1, 0]) | ||||||
assert all(G.vertices[2].get_center() == [2, 2, 0]) | ||||||
assert all(G.vertices[3].get_center() == [3, 3, 0]) | ||||||
|
||||||
|
||||||
def test_custom_graph_layout_function_with_kwargs(): | ||||||
def layout_func(graph, scale, offset): | ||||||
return { | ||||||
vertex: [vertex * scale + offset, vertex * scale + offset, 0] | ||||||
for vertex in graph | ||||||
} | ||||||
|
||||||
G = Graph( | ||||||
[1, 2, 3], [(1, 2), (2, 3)], layout=layout_func, layout_config={"offset": 1} | ||||||
) | ||||||
assert all(G.vertices[1].get_center() == [3, 3, 0]) | ||||||
assert all(G.vertices[2].get_center() == [5, 5, 0]) | ||||||
assert all(G.vertices[3].get_center() == [7, 7, 0]) | ||||||
|
||||||
|
||||||
def test_graph_change_layout(): | ||||||
for layout in ( | ||||||
layout for layout in _layouts if layout != "tree" and layout != "partite" | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
): | ||||||
G = Graph([1, 2, 3], [(1, 2), (2, 3)]) | ||||||
G.change_layout(layout=layout) | ||||||
assert str(G) == "Undirected graph on 3 vertices and 2 edges" | ||||||
|
||||||
|
||||||
def test_tree_layout_no_root_error(): | ||||||
with pytest.raises(ValueError) as excinfo: | ||||||
G = Graph([1, 2, 3], [(1, 2), (2, 3)], layout="tree") | ||||||
|
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.