From 7f5df9a86a404690da858a98922ad2eeebf75209 Mon Sep 17 00:00:00 2001 From: bda82 Date: Wed, 18 Oct 2023 22:58:58 +0300 Subject: [PATCH] start combine visualization part 3 --- .../config/parameters/defaults.py | 13 ++ .../constructors/layout_constructor.py | 17 +-- .../contracts/core_model_contract.py | 9 ++ .../contracts/draw_circle_edges_contract.py | 13 ++ .../contracts/draw_line_edges_contract.py | 14 ++ .../contracts/draw_vertex_contract.py | 14 ++ .../visualization/contracts/graph_contract.py | 1 + .../contracts/layout_contract.py | 10 +- .../equations/core_physical_model.py | 16 +++ .../visualization/graph_visualization.py | 123 +++++++++++++++--- 10 files changed, 202 insertions(+), 28 deletions(-) create mode 100644 stable_gnn/visualization/contracts/core_model_contract.py create mode 100644 stable_gnn/visualization/contracts/draw_circle_edges_contract.py create mode 100644 stable_gnn/visualization/contracts/draw_line_edges_contract.py create mode 100644 stable_gnn/visualization/contracts/draw_vertex_contract.py create mode 100644 stable_gnn/visualization/equations/core_physical_model.py diff --git a/stable_gnn/visualization/config/parameters/defaults.py b/stable_gnn/visualization/config/parameters/defaults.py index 070dbba..6724046 100644 --- a/stable_gnn/visualization/config/parameters/defaults.py +++ b/stable_gnn/visualization/config/parameters/defaults.py @@ -45,3 +45,16 @@ class Defaults(ReferenceBase): vertex_coord_min: float = -5.0 vertex_coord_multiplier: float = 0.8 vertex_coord_modifier: float = 0.1 + # figure + figure_size: tuple = (6, 6) + x_limits: tuple = (0, 1.0) + y_limits: tuple = (0, 1.0) + axes_on_off: str = "off" + # core physical model + node_attraction_key: int = 0 + node_repulsion_key: int = 1 + edge_repulsion_key: int = 2 + center_of_gravity_key: int = 3 + max_iterations: int = 400 + epsilon: float = 0.001 + delta: float = 2.0 diff --git a/stable_gnn/visualization/constructors/layout_constructor.py b/stable_gnn/visualization/constructors/layout_constructor.py index 4e062d4..1bd6928 100644 --- a/stable_gnn/visualization/constructors/layout_constructor.py +++ b/stable_gnn/visualization/constructors/layout_constructor.py @@ -1,7 +1,9 @@ import numpy as np +from stable_gnn.visualization.contracts.core_model_contract import CoreModelContract from stable_gnn.visualization.contracts.layout_contract import LayoutContract from stable_gnn.visualization.config.parameters.defaults import Defaults +from stable_gnn.visualization.equations.core_physical_model import CorePhysicalModel from stable_gnn.visualization.equations.edge_list_to_incidence_matrix import edge_list_to_incidence_matrix from stable_gnn.visualization.equations.init_position import init_position from stable_gnn.visualization.exceptions.exceptions_classes import ParamsValidationException @@ -16,20 +18,19 @@ def __call__(self, contract: LayoutContract): centers = [np.array([0, 0])] - sim = Simulator( + core_model_contract: CoreModelContract = CoreModelContract( nums=contract.vertex_num, forces={ - Simulator.NODE_ATTRACTION: contract.pull_edge_strength, - Simulator.NODE_REPULSION: contract.push_vertex_strength, - Simulator.EDGE_REPULSION: contract.push_edge_strength, - Simulator.CENTER_GRAVITY: contract.pull_center_strength, + Defaults.node_attraction_key: contract.pull_edge_strength, + Defaults.node_repulsion_key: contract.push_vertex_strength, + Defaults.edge_repulsion_key: contract.push_edge_strength, + Defaults.center_of_gravity_key: contract.pull_center_strength, }, centers=centers, ) + model: CorePhysicalModel = CorePhysicalModel(core_model_contract) - vertex_coord = sim.simulate(vertex_coord, - edge_list_to_incidence_matrix(contract.vertex_num, - contract.edge_list)) + vertex_coord = model.simulate(vertex_coord, edge_list_to_incidence_matrix(contract.vertex_num, contract.edge_list)) vertex_coord = ((vertex_coord - vertex_coord.min(0)) / (vertex_coord.max(0) - vertex_coord.min(0)) * Defaults.vertex_coord_multiplier + Defaults.vertex_coord_modifier) diff --git a/stable_gnn/visualization/contracts/core_model_contract.py b/stable_gnn/visualization/contracts/core_model_contract.py new file mode 100644 index 0000000..98402de --- /dev/null +++ b/stable_gnn/visualization/contracts/core_model_contract.py @@ -0,0 +1,9 @@ +from dataclasses import dataclass + + +@dataclass +class CoreModelContract: + nums: int | list + forces: dict + centers: list + damping_factor: float = 0.9999 diff --git a/stable_gnn/visualization/contracts/draw_circle_edges_contract.py b/stable_gnn/visualization/contracts/draw_circle_edges_contract.py new file mode 100644 index 0000000..2098fdd --- /dev/null +++ b/stable_gnn/visualization/contracts/draw_circle_edges_contract.py @@ -0,0 +1,13 @@ +import matplotlib +from dataclasses import dataclass + + +@dataclass +class DrawEdgesContract: + axes: matplotlib.axes.Axes + vertex_coordinates: list[tuple[float, float]] + vertex_size: list + edge_list: list[tuple] | list[list[int]] + edge_color: list + edge_fill_color: list + edge_line_width: list diff --git a/stable_gnn/visualization/contracts/draw_line_edges_contract.py b/stable_gnn/visualization/contracts/draw_line_edges_contract.py new file mode 100644 index 0000000..bb240d2 --- /dev/null +++ b/stable_gnn/visualization/contracts/draw_line_edges_contract.py @@ -0,0 +1,14 @@ +import matplotlib +import numpy as np +from dataclasses import dataclass + + +@dataclass +class DrawLineEdgesContract: + axes: matplotlib.axes.Axes + vertex_coordinates: np.array + vertex_size: list + edge_list: list[tuple] | list[list[int]] + show_arrow: bool + edge_color: list + edge_line_width: list diff --git a/stable_gnn/visualization/contracts/draw_vertex_contract.py b/stable_gnn/visualization/contracts/draw_vertex_contract.py new file mode 100644 index 0000000..b2db1a0 --- /dev/null +++ b/stable_gnn/visualization/contracts/draw_vertex_contract.py @@ -0,0 +1,14 @@ +import matplotlib +from dataclasses import dataclass + + +@dataclass +class DrawVertexContract: + axes: matplotlib.axes.Axes + vertex_coordinates: list[tuple[float, float]] + vertex_label: list[str] | None + font_size: int + font_family: str + vertex_size: list + vertex_color: list + vertex_line_width: list diff --git a/stable_gnn/visualization/contracts/graph_contract.py b/stable_gnn/visualization/contracts/graph_contract.py index 1840d10..d4e0b56 100644 --- a/stable_gnn/visualization/contracts/graph_contract.py +++ b/stable_gnn/visualization/contracts/graph_contract.py @@ -4,6 +4,7 @@ @dataclass class GraphContract: vertex_num: int + edges: tuple[list[list[int]], list[float]] edge_num: int edge_list: list[int] | list[list[int]] | None = None edge_weights: float | list[float] | None = None diff --git a/stable_gnn/visualization/contracts/layout_contract.py b/stable_gnn/visualization/contracts/layout_contract.py index 4332878..f6ca3f9 100644 --- a/stable_gnn/visualization/contracts/layout_contract.py +++ b/stable_gnn/visualization/contracts/layout_contract.py @@ -4,8 +4,8 @@ @dataclass class LayoutContract: vertex_num: int - edge_list: list[tuple] - push_vertex_strength: float - push_edge_strength: float - pull_edge_strength: float - pull_center_strength: float + edge_list: list[tuple] | list[list[int]] + push_vertex_strength: float | None + push_edge_strength: float | None + pull_edge_strength: float | None + pull_center_strength: float | None diff --git a/stable_gnn/visualization/equations/core_physical_model.py b/stable_gnn/visualization/equations/core_physical_model.py new file mode 100644 index 0000000..5fbb4dd --- /dev/null +++ b/stable_gnn/visualization/equations/core_physical_model.py @@ -0,0 +1,16 @@ +from stable_gnn.visualization.config.parameters.defaults import Defaults +from stable_gnn.visualization.contracts.core_model_contract import CoreModelContract + + +class CorePhysicalModel: + __node_attraction = Defaults.node_attraction_key + __node_repulsion = Defaults.node_repulsion_key + __edge_repulsion = Defaults.edge_repulsion_key + __center_of_gravity = Defaults.center_of_gravity_key + + def __init__(self, contract: CoreModelContract): + pass + + def simulate(self, init_position, H, max_iter=Defaults.max_iterations, epsilon=Defaults.epsilon, + delta=Defaults.delta): + pass diff --git a/stable_gnn/visualization/graph_visualization.py b/stable_gnn/visualization/graph_visualization.py index 8cb300b..f71eb5a 100644 --- a/stable_gnn/visualization/graph_visualization.py +++ b/stable_gnn/visualization/graph_visualization.py @@ -1,15 +1,24 @@ from copy import deepcopy import numpy as np +import matplotlib import matplotlib.pyplot as plt from stable_gnn.visualization.config.parameters.edge_styles import EdgeStyles +from stable_gnn.visualization.constructors.layout_constructor import LayoutConstructor from stable_gnn.visualization.constructors.size_constructor import SizeConstructor +from stable_gnn.visualization.constructors.strength_constructor import StrengthConstructor from stable_gnn.visualization.constructors.style_constructor import StyleConstructor +from stable_gnn.visualization.contracts.draw_circle_edges_contract import DrawEdgesContract +from stable_gnn.visualization.contracts.draw_line_edges_contract import DrawLineEdgesContract +from stable_gnn.visualization.contracts.draw_vertex_contract import DrawVertexContract from stable_gnn.visualization.contracts.graph_visualization_contract import GraphVisualizationContract +from stable_gnn.visualization.contracts.layout_contract import LayoutContract from stable_gnn.visualization.contracts.size_constructor_contract import SizeConstructorContract +from stable_gnn.visualization.contracts.strength_constructor_contract import StrengthConstructorContract from stable_gnn.visualization.contracts.style_constructor_contract import StyleConstructorContract from stable_gnn.visualization.exceptions.exceptions_classes import ParamsValidationException +from stable_gnn.visualization.config.parameters.defaults import Defaults from stable_gnn.graph import Graph @@ -21,27 +30,111 @@ def __init__(self, contract: GraphVisualizationContract): self._validate() - default_style_contract: StyleConstructorContract = StyleConstructorContract(self.contract.graph.vertex_num, - self.contract.graph.edge_num, - self.contract.vertex_color, - self.contract.edge_color, - self.contract.edge_fill_color) + def draw(self): + fig, __axes = plt.subplots(figsize=Defaults.figure_size) + __vertex_num, __edge_list = self.contract.graph.vertex_num, deepcopy(self.contract.graph.edges[0]) + + default_style_contract: StyleConstructorContract = StyleConstructorContract( + vertex_num=self.contract.graph.vertex_num, + edges_num=self.contract.graph.edge_num, + vertex_color=self.contract.vertex_color, + edge_color=self.contract.edge_color, + edge_fill_color=self.contract.edge_fill_color + ) default_style_constructor: StyleConstructor = StyleConstructor() - v_color, e_color, e_fill_color = default_style_constructor(default_style_contract) + __vertex_color, __edge_color, __edge_fill_color = default_style_constructor(default_style_contract) - default_size_contract: SizeConstructorContract = SizeConstructorContract(self.contract.graph.vertex_num, - self.contract.graph.edge_list, - self.contract.vertex_size, - self.contract.vertex_line_width, - self.contract.edge_line_width) + default_size_contract: SizeConstructorContract = SizeConstructorContract( + vertex_num=__vertex_num, + edges_list=__edge_list, + vertex_size=self.contract.vertex_size, + vertex_line_width=self.contract.vertex_line_width, + edge_line_width=self.contract.edge_line_width, + font_size=self.contract.font_size + ) default_size_constructor: SizeConstructor = SizeConstructor() - v_size, v_line_width, e_line_width, font_size = default_size_constructor(default_size_contract) + __vertex_size, __vertex_line_width, __edge_line_width, __font_size = default_size_constructor( + default_size_contract) - def draw(self): - fig, ax = plt.subplots(figsize=(6, 6)) - num_v, e_list = g.num_v, deepcopy(g.e[0]) + default_strength_contract: StrengthConstructorContract = StrengthConstructorContract( + self.contract.push_vertex_strength, + self.contract.push_edge_strength, + self.contract.pull_edge_strength, + self.contract.pull_center_strength + ) + + default_strength_constructor: StrengthConstructor = StrengthConstructor() + + ( + __push_v_strength, __push_e_strength, __pull_e_strength, + __pull_center_strength, + ) = default_strength_constructor(default_strength_contract) + + layout_contract: LayoutContract = LayoutContract( + vertex_num=__vertex_num, + edge_list=__edge_list, + push_vertex_strength=__push_v_strength, + push_edge_strength=None, + pull_edge_strength=__pull_e_strength, + pull_center_strength=__pull_center_strength + ) + + layout_constructor: LayoutConstructor = LayoutConstructor() + + __vertex_coordinates = layout_constructor(layout_contract) + + if self.contract.edge_style == EdgeStyles.line: + draw_line_edges_contract: DrawLineEdgesContract = DrawLineEdgesContract( + axes=__axes, + vertex_coordinates=__vertex_coordinates, + vertex_size=__vertex_size, + edge_list=__edge_list, + show_arrow=False, + edge_color=__edge_color, + edge_line_width=__edge_line_width + ) + self.__draw_line_edges(draw_line_edges_contract) + elif self.contract.edge_style == EdgeStyles.circle: + draw_edges_contract: DrawEdgesContract = DrawEdgesContract( + axes=__axes, + vertex_coordinates=__vertex_coordinates, + vertex_size=__vertex_size, + edge_list=__edge_list, + edge_color=__edge_color, + edge_fill_color=__edge_fill_color, + edge_line_width=__edge_line_width + ) + self.__draw_circle_edges(draw_edges_contract) + else: + raise ParamsValidationException + + draw_vertex_contract: DrawVertexContract = DrawVertexContract( + axes=__axes, + vertex_coordinates=__vertex_coordinates, + vertex_label=self.contract.vertex_label, + font_size=__font_size, + font_family=self.contract.font_family, + vertex_size=__vertex_size, + vertex_color=__vertex_color, + vertex_line_width=__vertex_line_width + ) + self.__draw_vertex(draw_vertex_contract) + + plt.xlim(Defaults.x_limits) + plt.ylim(Defaults.y_limits) + plt.axis(Defaults.axes_on_off) + fig.tight_layout() + + def __draw_line_edges(self, contract: DrawLineEdgesContract): + pass + + def __draw_circle_edges(self, contract: DrawEdgesContract): + pass + + def __draw_vertex(self, contract: DrawVertexContract): + pass def _validate(self): graph_type_is_correct = isinstance(self.contract.graph, Graph)