Skip to content

Commit

Permalink
start combine visualization part 3
Browse files Browse the repository at this point in the history
  • Loading branch information
bda82 committed Oct 18, 2023
1 parent 906a4dc commit 7f5df9a
Show file tree
Hide file tree
Showing 10 changed files with 202 additions and 28 deletions.
13 changes: 13 additions & 0 deletions stable_gnn/visualization/config/parameters/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,16 @@ class Defaults(ReferenceBase):
vertex_coord_min: float = -5.0
vertex_coord_multiplier: float = 0.8
vertex_coord_modifier: float = 0.1
# figure
figure_size: tuple = (6, 6)
x_limits: tuple = (0, 1.0)
y_limits: tuple = (0, 1.0)
axes_on_off: str = "off"
# core physical model
node_attraction_key: int = 0
node_repulsion_key: int = 1
edge_repulsion_key: int = 2
center_of_gravity_key: int = 3
max_iterations: int = 400
epsilon: float = 0.001
delta: float = 2.0
17 changes: 9 additions & 8 deletions stable_gnn/visualization/constructors/layout_constructor.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import numpy as np

from stable_gnn.visualization.contracts.core_model_contract import CoreModelContract
from stable_gnn.visualization.contracts.layout_contract import LayoutContract
from stable_gnn.visualization.config.parameters.defaults import Defaults
from stable_gnn.visualization.equations.core_physical_model import CorePhysicalModel
from stable_gnn.visualization.equations.edge_list_to_incidence_matrix import edge_list_to_incidence_matrix
from stable_gnn.visualization.equations.init_position import init_position
from stable_gnn.visualization.exceptions.exceptions_classes import ParamsValidationException
Expand All @@ -16,20 +18,19 @@ def __call__(self, contract: LayoutContract):

centers = [np.array([0, 0])]

sim = Simulator(
core_model_contract: CoreModelContract = CoreModelContract(
nums=contract.vertex_num,
forces={
Simulator.NODE_ATTRACTION: contract.pull_edge_strength,
Simulator.NODE_REPULSION: contract.push_vertex_strength,
Simulator.EDGE_REPULSION: contract.push_edge_strength,
Simulator.CENTER_GRAVITY: contract.pull_center_strength,
Defaults.node_attraction_key: contract.pull_edge_strength,
Defaults.node_repulsion_key: contract.push_vertex_strength,
Defaults.edge_repulsion_key: contract.push_edge_strength,
Defaults.center_of_gravity_key: contract.pull_center_strength,
},
centers=centers,
)
model: CorePhysicalModel = CorePhysicalModel(core_model_contract)

vertex_coord = sim.simulate(vertex_coord,
edge_list_to_incidence_matrix(contract.vertex_num,
contract.edge_list))
vertex_coord = model.simulate(vertex_coord, edge_list_to_incidence_matrix(contract.vertex_num, contract.edge_list))
vertex_coord = ((vertex_coord - vertex_coord.min(0)) /
(vertex_coord.max(0) - vertex_coord.min(0)) *
Defaults.vertex_coord_multiplier + Defaults.vertex_coord_modifier)
Expand Down
9 changes: 9 additions & 0 deletions stable_gnn/visualization/contracts/core_model_contract.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from dataclasses import dataclass


@dataclass
class CoreModelContract:
nums: int | list
forces: dict
centers: list
damping_factor: float = 0.9999
13 changes: 13 additions & 0 deletions stable_gnn/visualization/contracts/draw_circle_edges_contract.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import matplotlib
from dataclasses import dataclass


@dataclass
class DrawEdgesContract:
axes: matplotlib.axes.Axes
vertex_coordinates: list[tuple[float, float]]
vertex_size: list
edge_list: list[tuple] | list[list[int]]
edge_color: list
edge_fill_color: list
edge_line_width: list
14 changes: 14 additions & 0 deletions stable_gnn/visualization/contracts/draw_line_edges_contract.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import matplotlib
import numpy as np
from dataclasses import dataclass


@dataclass
class DrawLineEdgesContract:
axes: matplotlib.axes.Axes
vertex_coordinates: np.array
vertex_size: list
edge_list: list[tuple] | list[list[int]]
show_arrow: bool
edge_color: list
edge_line_width: list
14 changes: 14 additions & 0 deletions stable_gnn/visualization/contracts/draw_vertex_contract.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import matplotlib
from dataclasses import dataclass


@dataclass
class DrawVertexContract:
axes: matplotlib.axes.Axes
vertex_coordinates: list[tuple[float, float]]
vertex_label: list[str] | None
font_size: int
font_family: str
vertex_size: list
vertex_color: list
vertex_line_width: list
1 change: 1 addition & 0 deletions stable_gnn/visualization/contracts/graph_contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
@dataclass
class GraphContract:
vertex_num: int
edges: tuple[list[list[int]], list[float]]
edge_num: int
edge_list: list[int] | list[list[int]] | None = None
edge_weights: float | list[float] | None = None
10 changes: 5 additions & 5 deletions stable_gnn/visualization/contracts/layout_contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
@dataclass
class LayoutContract:
vertex_num: int
edge_list: list[tuple]
push_vertex_strength: float
push_edge_strength: float
pull_edge_strength: float
pull_center_strength: float
edge_list: list[tuple] | list[list[int]]
push_vertex_strength: float | None
push_edge_strength: float | None
pull_edge_strength: float | None
pull_center_strength: float | None
16 changes: 16 additions & 0 deletions stable_gnn/visualization/equations/core_physical_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from stable_gnn.visualization.config.parameters.defaults import Defaults
from stable_gnn.visualization.contracts.core_model_contract import CoreModelContract


class CorePhysicalModel:
__node_attraction = Defaults.node_attraction_key
__node_repulsion = Defaults.node_repulsion_key
__edge_repulsion = Defaults.edge_repulsion_key
__center_of_gravity = Defaults.center_of_gravity_key

def __init__(self, contract: CoreModelContract):
pass

def simulate(self, init_position, H, max_iter=Defaults.max_iterations, epsilon=Defaults.epsilon,
delta=Defaults.delta):
pass
123 changes: 108 additions & 15 deletions stable_gnn/visualization/graph_visualization.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,24 @@
from copy import deepcopy

import numpy as np
import matplotlib
import matplotlib.pyplot as plt

from stable_gnn.visualization.config.parameters.edge_styles import EdgeStyles
from stable_gnn.visualization.constructors.layout_constructor import LayoutConstructor
from stable_gnn.visualization.constructors.size_constructor import SizeConstructor
from stable_gnn.visualization.constructors.strength_constructor import StrengthConstructor
from stable_gnn.visualization.constructors.style_constructor import StyleConstructor
from stable_gnn.visualization.contracts.draw_circle_edges_contract import DrawEdgesContract
from stable_gnn.visualization.contracts.draw_line_edges_contract import DrawLineEdgesContract
from stable_gnn.visualization.contracts.draw_vertex_contract import DrawVertexContract
from stable_gnn.visualization.contracts.graph_visualization_contract import GraphVisualizationContract
from stable_gnn.visualization.contracts.layout_contract import LayoutContract
from stable_gnn.visualization.contracts.size_constructor_contract import SizeConstructorContract
from stable_gnn.visualization.contracts.strength_constructor_contract import StrengthConstructorContract
from stable_gnn.visualization.contracts.style_constructor_contract import StyleConstructorContract
from stable_gnn.visualization.exceptions.exceptions_classes import ParamsValidationException
from stable_gnn.visualization.config.parameters.defaults import Defaults

from stable_gnn.graph import Graph

Expand All @@ -21,27 +30,111 @@ def __init__(self, contract: GraphVisualizationContract):

self._validate()

default_style_contract: StyleConstructorContract = StyleConstructorContract(self.contract.graph.vertex_num,
self.contract.graph.edge_num,
self.contract.vertex_color,
self.contract.edge_color,
self.contract.edge_fill_color)
def draw(self):
fig, __axes = plt.subplots(figsize=Defaults.figure_size)
__vertex_num, __edge_list = self.contract.graph.vertex_num, deepcopy(self.contract.graph.edges[0])

default_style_contract: StyleConstructorContract = StyleConstructorContract(
vertex_num=self.contract.graph.vertex_num,
edges_num=self.contract.graph.edge_num,
vertex_color=self.contract.vertex_color,
edge_color=self.contract.edge_color,
edge_fill_color=self.contract.edge_fill_color
)
default_style_constructor: StyleConstructor = StyleConstructor()

v_color, e_color, e_fill_color = default_style_constructor(default_style_contract)
__vertex_color, __edge_color, __edge_fill_color = default_style_constructor(default_style_contract)

default_size_contract: SizeConstructorContract = SizeConstructorContract(self.contract.graph.vertex_num,
self.contract.graph.edge_list,
self.contract.vertex_size,
self.contract.vertex_line_width,
self.contract.edge_line_width)
default_size_contract: SizeConstructorContract = SizeConstructorContract(
vertex_num=__vertex_num,
edges_list=__edge_list,
vertex_size=self.contract.vertex_size,
vertex_line_width=self.contract.vertex_line_width,
edge_line_width=self.contract.edge_line_width,
font_size=self.contract.font_size
)
default_size_constructor: SizeConstructor = SizeConstructor()

v_size, v_line_width, e_line_width, font_size = default_size_constructor(default_size_contract)
__vertex_size, __vertex_line_width, __edge_line_width, __font_size = default_size_constructor(
default_size_contract)

def draw(self):
fig, ax = plt.subplots(figsize=(6, 6))
num_v, e_list = g.num_v, deepcopy(g.e[0])
default_strength_contract: StrengthConstructorContract = StrengthConstructorContract(
self.contract.push_vertex_strength,
self.contract.push_edge_strength,
self.contract.pull_edge_strength,
self.contract.pull_center_strength
)

default_strength_constructor: StrengthConstructor = StrengthConstructor()

(
__push_v_strength, __push_e_strength, __pull_e_strength,
__pull_center_strength,
) = default_strength_constructor(default_strength_contract)

layout_contract: LayoutContract = LayoutContract(
vertex_num=__vertex_num,
edge_list=__edge_list,
push_vertex_strength=__push_v_strength,
push_edge_strength=None,
pull_edge_strength=__pull_e_strength,
pull_center_strength=__pull_center_strength
)

layout_constructor: LayoutConstructor = LayoutConstructor()

__vertex_coordinates = layout_constructor(layout_contract)

if self.contract.edge_style == EdgeStyles.line:
draw_line_edges_contract: DrawLineEdgesContract = DrawLineEdgesContract(
axes=__axes,
vertex_coordinates=__vertex_coordinates,
vertex_size=__vertex_size,
edge_list=__edge_list,
show_arrow=False,
edge_color=__edge_color,
edge_line_width=__edge_line_width
)
self.__draw_line_edges(draw_line_edges_contract)
elif self.contract.edge_style == EdgeStyles.circle:
draw_edges_contract: DrawEdgesContract = DrawEdgesContract(
axes=__axes,
vertex_coordinates=__vertex_coordinates,
vertex_size=__vertex_size,
edge_list=__edge_list,
edge_color=__edge_color,
edge_fill_color=__edge_fill_color,
edge_line_width=__edge_line_width
)
self.__draw_circle_edges(draw_edges_contract)
else:
raise ParamsValidationException

draw_vertex_contract: DrawVertexContract = DrawVertexContract(
axes=__axes,
vertex_coordinates=__vertex_coordinates,
vertex_label=self.contract.vertex_label,
font_size=__font_size,
font_family=self.contract.font_family,
vertex_size=__vertex_size,
vertex_color=__vertex_color,
vertex_line_width=__vertex_line_width
)
self.__draw_vertex(draw_vertex_contract)

plt.xlim(Defaults.x_limits)
plt.ylim(Defaults.y_limits)
plt.axis(Defaults.axes_on_off)
fig.tight_layout()

def __draw_line_edges(self, contract: DrawLineEdgesContract):
pass

def __draw_circle_edges(self, contract: DrawEdgesContract):
pass

def __draw_vertex(self, contract: DrawVertexContract):
pass

def _validate(self):
graph_type_is_correct = isinstance(self.contract.graph, Graph)
Expand Down

0 comments on commit 7f5df9a

Please sign in to comment.