-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #29 from aimclub/hypergraphs
Hypergraphs
- Loading branch information
Showing
80 changed files
with
2,050 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
bamt==1.1.44 | ||
optuna==2.10.1 | ||
pgmpy==0.1.20 | ||
pandas==1.5.2 | ||
pandas==1.5.2 |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,246 @@ | ||
from abc import ABC, abstractmethod | ||
|
||
import matplotlib | ||
from matplotlib.path import Path | ||
from matplotlib.patches import Circle, PathPatch | ||
from matplotlib.collections import PatchCollection | ||
from scipy.spatial import ConvexHull | ||
|
||
import numpy as np | ||
|
||
from stable_gnn.visualization.contracts.draw_circle_edges_contract import DrawEdgesContract | ||
from stable_gnn.visualization.contracts.draw_vertex_contract import DrawVertexContract | ||
from stable_gnn.visualization.equations.calc_common_tangent_radian import common_tangent_radian | ||
from stable_gnn.visualization.equations.calc_polar_position import polar_position | ||
from stable_gnn.visualization.equations.calc_rad_to_deg import rad_to_deg | ||
from stable_gnn.visualization.equations.radian_from_atan import radian_from_atan | ||
from stable_gnn.visualization.equations.calc_vector_length import vector_length | ||
from stable_gnn.visualization.config.parameters.defaults import Defaults | ||
|
||
|
||
class BaseVisualization(ABC): | ||
""" | ||
Base visualization class with common functions. | ||
""" | ||
contract = None | ||
|
||
@abstractmethod | ||
def draw(self): | ||
""" | ||
Draw method to redefine. | ||
""" | ||
raise NotImplementedError | ||
|
||
@abstractmethod | ||
def validate(self): | ||
""" | ||
Base validator to redefine. | ||
""" | ||
raise NotImplementedError | ||
|
||
@staticmethod | ||
def draw_vertex(axes, contract: DrawVertexContract): | ||
""" | ||
Draw vertex based on contract DrawVertexContract | ||
DrawVertexContract: | ||
- vertex_coordinates - Coordinates of the vertexes | ||
- vertex_label - Labels for vertexes | ||
- font_size - Font size base | ||
- font_family - Font family | ||
- vertex_size - Sizes for vertexes | ||
- vertex_color - Color for vertexes | ||
- vertex_line_width - Widths for vertexes | ||
""" | ||
patches = [] | ||
|
||
vertex_label = contract.vertex_label | ||
|
||
if contract.vertex_label is None: | ||
vertex_label = [""] * contract.vertex_coordinates.shape[0] # noqa | ||
|
||
# Create vertexes | ||
for coordinates, label, size, width in zip(contract.vertex_coordinates.tolist(), # noqa | ||
vertex_label, | ||
contract.vertex_size, | ||
contract.vertex_line_width): | ||
circle = Circle(coordinates, size) | ||
circle.lineWidth = width | ||
|
||
if label != "": | ||
# Get coordinates | ||
x, y = coordinates[0], coordinates[1] | ||
offset = 0, -1.3 * size | ||
x += offset[0] | ||
y += offset[1] | ||
# Apply to plot exes | ||
axes.text(x, y, label, | ||
fontsize=contract.font_size, | ||
fontfamily=contract.font_family, | ||
ha='center', | ||
va='top') | ||
|
||
patches.append(circle) | ||
|
||
# Make paths | ||
p = PatchCollection(patches, facecolors=contract.vertex_color, edgecolors="black") | ||
|
||
axes.add_collection(p) | ||
|
||
def draw_circle_edges(self, axes, contract: DrawEdgesContract): | ||
""" | ||
Draw circled edge based on contract DrawEdgesContract | ||
DrawEdgesContract: | ||
- vertex_coordinates - Vertexes coordinates | ||
- vertex_size - Sizes for vertexes | ||
- edge_list - List of edges | ||
- edge_color - Colors for edges | ||
- edge_fill_color - Fill color for edges | ||
- edge_line_width - Width for edge lines | ||
""" | ||
num_vertex = len(contract.vertex_coordinates) | ||
|
||
line_paths, arc_paths, vertices = self.hull_layout(num_vertex, | ||
contract.edge_list, | ||
contract.vertex_coordinates, | ||
contract.vertex_size) | ||
|
||
# For every edge line | ||
for edge_index, lines in enumerate(line_paths): | ||
path_data = [] | ||
|
||
for line in lines: | ||
if len(line) == 0: | ||
continue | ||
|
||
start_pos, end_pos = line | ||
|
||
path_data.append((Path.MOVETO, start_pos.tolist())) | ||
path_data.append((Path.LINETO, end_pos.tolist())) | ||
|
||
if len(list(zip(*path_data))) == 0: | ||
continue | ||
|
||
codes, vertexes = zip(*path_data) | ||
|
||
# Apply to plot | ||
axes.add_patch( | ||
PathPatch(Path(vertexes, codes), | ||
linewidth=contract.edge_line_width[edge_index], | ||
facecolor=contract.edge_fill_color[edge_index], | ||
edgecolor=contract.edge_color[edge_index])) | ||
|
||
# For every arc | ||
for edge_index, arcs in enumerate(arc_paths): | ||
for arc in arcs: | ||
center, theta1, theta2, radius = arc | ||
|
||
# Apply to plot | ||
axes.add_patch( | ||
matplotlib.patches.Arc((center[0], center[1]), | ||
2 * radius, | ||
2 * radius, | ||
theta1=theta1, | ||
theta2=theta2, | ||
color=contract.edge_color[edge_index], | ||
linewidth=contract.edge_line_width[edge_index], | ||
edgecolor=contract.edge_color[edge_index], | ||
facecolor=contract.edge_fill_color[edge_index])) | ||
|
||
@staticmethod | ||
def hull_layout(num_vertex, # noqa | ||
edge_list, | ||
position, | ||
vertex_size, | ||
radius_increment=Defaults.radius_increment): | ||
|
||
# Make paths | ||
line_paths = [None] * len(edge_list) | ||
arc_paths = [None] * len(edge_list) | ||
|
||
# Make polygons | ||
polygons_vertices_index = [] | ||
vertices_radius = np.array(vertex_size) | ||
vertices_increased_radius = vertices_radius * radius_increment | ||
vertices_radius += vertices_increased_radius | ||
|
||
# Define edge characteristics | ||
edge_degree = [len(e) for e in edge_list] | ||
edge_indexes = np.argsort(np.array(edge_degree)) | ||
|
||
# For every edge index | ||
for edge_index in edge_indexes: | ||
edge = list(edge_list[edge_index]) | ||
|
||
line_path_for_edges = [] | ||
arc_path_for_edges = [] | ||
|
||
if len(edge) == 1: | ||
arc_path_for_edges.append([position[edge[0]], 0, 360, vertices_radius[edge[0]]]) | ||
|
||
vertices_radius[edge] += vertices_increased_radius[edge] | ||
|
||
line_paths[edge_index] = line_path_for_edges | ||
arc_paths[edge_index] = arc_path_for_edges | ||
|
||
continue | ||
|
||
pos_in_edge = position[edge] | ||
|
||
if len(edge) == 2: | ||
vertices_index = np.array((0, 1), dtype=np.int64) | ||
else: | ||
hull = ConvexHull(pos_in_edge) | ||
vertices_index = hull.vertices | ||
|
||
number_of_vertices = vertices_index.shape[0] | ||
|
||
vertices_index = np.append(vertices_index, vertices_index[0]) # close the loop | ||
|
||
thetas = [] | ||
|
||
# For all vertexes | ||
for i in range(number_of_vertices): | ||
# draw lines | ||
i1 = edge[vertices_index[i]] | ||
i2 = edge[vertices_index[i + 1]] | ||
|
||
r1 = vertices_radius[i1] | ||
r2 = vertices_radius[i2] | ||
|
||
p1 = position[i1] | ||
p2 = position[i2] | ||
|
||
dp = p2 - p1 | ||
dp_len = vector_length(dp) | ||
|
||
beta = radian_from_atan(dp[0], dp[1]) | ||
alpha = common_tangent_radian(r1, r2, dp_len) | ||
|
||
theta = beta - alpha | ||
start_point = polar_position(r1, theta, p1) | ||
end_point = polar_position(r2, theta, p2) | ||
|
||
line_path_for_edges.append((start_point, end_point)) | ||
thetas.append(theta) | ||
|
||
for i in range(number_of_vertices): | ||
# draw arcs | ||
theta_1 = thetas[i - 1] | ||
theta_2 = thetas[i] | ||
|
||
arc_center = position[edge[vertices_index[i]]] | ||
radius = vertices_radius[edge[vertices_index[i]]] | ||
|
||
theta_1, theta_2 = rad_to_deg(theta_1), rad_to_deg(theta_2) | ||
arc_path_for_edges.append((arc_center, theta_1, theta_2, radius)) | ||
|
||
vertices_radius[edge] += vertices_increased_radius[edge] | ||
|
||
polygons_vertices_index.append(vertices_index.copy()) | ||
|
||
# line_paths.append(line_path_for_e) | ||
# arc_paths.append(arc_path_for_e) | ||
line_paths[edge_index] = line_path_for_edges | ||
arc_paths[edge_index] = arc_path_for_edges | ||
|
||
return line_paths, arc_paths, polygons_vertices_index |
Empty file.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
from stable_gnn.visualization.utils.frozen_dataclass import reference | ||
from stable_gnn.visualization.utils.reference_base import ReferenceBase | ||
|
||
|
||
@reference | ||
class Colors(ReferenceBase): | ||
red: str = "r" | ||
green: str = "g" | ||
gray: str = "gray" | ||
whitesmoke: str = "whitesmoke" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
from stable_gnn.visualization.config.parameters.colors import Colors | ||
from stable_gnn.visualization.config.parameters.edge_styles import EdgeStyles | ||
from stable_gnn.visualization.config.parameters.fonts import Fonts | ||
from stable_gnn.visualization.utils.frozen_dataclass import reference | ||
from stable_gnn.visualization.utils.reference_base import ReferenceBase | ||
|
||
|
||
@reference | ||
class Defaults(ReferenceBase): | ||
edge_style: str = EdgeStyles.line | ||
edge_color: str = Colors.gray | ||
edge_fill_color: str = Colors.whitesmoke | ||
edge_line_width: float = 1.0 | ||
vertex_size: float = 1.0 | ||
vertex_color: str = Colors.red | ||
vertex_line_width: float = 1.0 | ||
vertex_strength: float = 1.0 | ||
font_size: float = 1.0 | ||
font_family: str = Fonts.sans_serif | ||
push_vertex_strength_vis: float = 1.0 | ||
push_edge_strength_vis: float = 1.0 | ||
pull_edge_strength_vis: float = 1.0 | ||
pull_center_strength_vis: float = 1.0 | ||
damping_factor: float = 0.9999 | ||
damping: float = 1 | ||
radius_increment: float = 0.3 | ||
force_modifier: float = -0.1 | ||
force_a_max: float = 0.1 | ||
axes_num: int = 2 | ||
# calculate_edge_line_width params | ||
edge_line_width_multiplier: float = 1.0 | ||
edge_line_width_divider: float = 120.0 | ||
# calculate_font_size params | ||
font_size_multiplier: float = 20 | ||
font_size_divider: float = 100.0 | ||
# calculate_vertex_line_width params | ||
vertex_line_width_multiplier: float = 1.0 | ||
vertex_line_width_divider: float = 50.0 | ||
# calculate_vertex_size params | ||
calculate_vertex_size_multiplier: float = 1.0 | ||
calculate_vertex_size_divider: float = 10.0 | ||
calculate_vertex_size_modifier: float = 0.1 | ||
# calc_arrow_head_width | ||
arrow_multiplier: float = 0.015 | ||
# safe_div | ||
jitter_scale: float = 0.000001 | ||
# calculate strength | ||
push_vertex_strength_g: float = 0.006 | ||
push_edge_strength_g: float = 0.0 | ||
pull_edge_strength_g: float = 0.045 | ||
pull_center_strength_g: float = 0.01 | ||
push_vertex_strength_hg: float = 1.0 | ||
push_edge_strength_hg: float = 1.0 | ||
pull_edge_strength_hg: float = 1.0 | ||
pull_center_strength_hg: float = 1.0 | ||
# calculate layout | ||
layout_scale_initial: int = 5 | ||
vertex_coord_max: float = 5.0 | ||
vertex_coord_min: float = -5.0 | ||
vertex_coord_multiplier: float = 0.8 | ||
vertex_coord_modifier: float = 0.1 | ||
# figure | ||
figure_size: tuple = (6, 6) | ||
x_limits: tuple = (0, 1.0) | ||
y_limits: tuple = (0, 1.0) | ||
axes_on_off: str = "off" | ||
# core physical model | ||
node_attraction_key: int = 0 | ||
node_repulsion_key: int = 1 | ||
edge_repulsion_key: int = 2 | ||
center_of_gravity_key: int = 3 | ||
max_iterations: int = 400 | ||
epsilon: float = 0.001 | ||
delta: float = 2.0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
from stable_gnn.visualization.utils.frozen_dataclass import reference | ||
from stable_gnn.visualization.utils.reference_base import ReferenceBase | ||
|
||
|
||
@reference | ||
class EdgeStyles(ReferenceBase): | ||
line: str = "line" | ||
circle: str = "circle" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
from stable_gnn.visualization.utils.frozen_dataclass import reference | ||
from stable_gnn.visualization.utils.reference_base import ReferenceBase | ||
|
||
|
||
@reference | ||
class Fonts(ReferenceBase): | ||
sans_serif: str = "sans-serif" |
10 changes: 10 additions & 0 deletions
10
stable_gnn/visualization/config/parameters/generator_methods.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
from stable_gnn.visualization.utils.frozen_dataclass import reference | ||
from stable_gnn.visualization.utils.reference_base import ReferenceBase | ||
|
||
|
||
@reference | ||
class GeneratorMethods(ReferenceBase): | ||
custom: str = "custom" | ||
uniform: str = "uniform" | ||
low_order_first: str = "low_order_first" | ||
high_order_first: str = "high_order_first" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
import typing as t | ||
|
||
TEdgeList = t.NewType("TEdgeList", list[tuple[int, int]]) | ||
TGraphEdgeList = t.NewType("TGraphEdgeList", tuple[TEdgeList, list[float]]) | ||
TVectorCoordinates = t.NewType("TVectorCoordinates", list[tuple[float, float]]) |
Empty file.
22 changes: 22 additions & 0 deletions
22
stable_gnn/visualization/constructors/graph_strength_constructor.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
from stable_gnn.visualization.config.parameters.defaults import Defaults | ||
from stable_gnn.visualization.contracts.strength_constructor_contract import StrengthConstructorContract | ||
from stable_gnn.visualization.utils.fill_strength import fill_strength | ||
|
||
|
||
class GraphStrengthConstructor: | ||
""" | ||
Constructor (one action controller) for Graph strengths. | ||
""" | ||
|
||
def __call__(self, contract: StrengthConstructorContract) -> tuple: | ||
_push_vertex_strength = Defaults.push_vertex_strength_g | ||
_push_edge_strength = Defaults.push_edge_strength_g | ||
_pull_edge_strength = Defaults.pull_edge_strength_g | ||
_pull_center_strength = Defaults.pull_center_strength_g | ||
|
||
push_vertex_strength = fill_strength(contract.push_vertex_strength, _push_vertex_strength) | ||
push_edge_strength = fill_strength(contract.push_edge_strength, _push_edge_strength) | ||
pull_edge_strength = fill_strength(contract.pull_edge_strength, _pull_edge_strength) | ||
pull_center_strength = fill_strength(contract.pull_center_strength, _pull_center_strength) | ||
|
||
return push_vertex_strength, push_edge_strength, pull_edge_strength, pull_center_strength |
Oops, something went wrong.