Skip to content

Commit

Permalink
Merge pull request #29 from aimclub/hypergraphs
Browse files Browse the repository at this point in the history
Hypergraphs
  • Loading branch information
bda82 authored Dec 15, 2024
2 parents b29ae56 + 559b81a commit f559af3
Show file tree
Hide file tree
Showing 80 changed files with 2,050 additions and 1 deletion.
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
bamt==1.1.44
optuna==2.10.1
pgmpy==0.1.20
pandas==1.5.2
pandas==1.5.2
Empty file.
246 changes: 246 additions & 0 deletions stable_gnn/visualization/base_visualization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,246 @@
from abc import ABC, abstractmethod

import matplotlib
from matplotlib.path import Path
from matplotlib.patches import Circle, PathPatch
from matplotlib.collections import PatchCollection
from scipy.spatial import ConvexHull

import numpy as np

from stable_gnn.visualization.contracts.draw_circle_edges_contract import DrawEdgesContract
from stable_gnn.visualization.contracts.draw_vertex_contract import DrawVertexContract
from stable_gnn.visualization.equations.calc_common_tangent_radian import common_tangent_radian
from stable_gnn.visualization.equations.calc_polar_position import polar_position
from stable_gnn.visualization.equations.calc_rad_to_deg import rad_to_deg
from stable_gnn.visualization.equations.radian_from_atan import radian_from_atan
from stable_gnn.visualization.equations.calc_vector_length import vector_length
from stable_gnn.visualization.config.parameters.defaults import Defaults


class BaseVisualization(ABC):
"""
Base visualization class with common functions.
"""
contract = None

@abstractmethod
def draw(self):
"""
Draw method to redefine.
"""
raise NotImplementedError

@abstractmethod
def validate(self):
"""
Base validator to redefine.
"""
raise NotImplementedError

@staticmethod
def draw_vertex(axes, contract: DrawVertexContract):
"""
Draw vertex based on contract DrawVertexContract
DrawVertexContract:
- vertex_coordinates - Coordinates of the vertexes
- vertex_label - Labels for vertexes
- font_size - Font size base
- font_family - Font family
- vertex_size - Sizes for vertexes
- vertex_color - Color for vertexes
- vertex_line_width - Widths for vertexes
"""
patches = []

vertex_label = contract.vertex_label

if contract.vertex_label is None:
vertex_label = [""] * contract.vertex_coordinates.shape[0] # noqa

# Create vertexes
for coordinates, label, size, width in zip(contract.vertex_coordinates.tolist(), # noqa
vertex_label,
contract.vertex_size,
contract.vertex_line_width):
circle = Circle(coordinates, size)
circle.lineWidth = width

if label != "":
# Get coordinates
x, y = coordinates[0], coordinates[1]
offset = 0, -1.3 * size
x += offset[0]
y += offset[1]
# Apply to plot exes
axes.text(x, y, label,
fontsize=contract.font_size,
fontfamily=contract.font_family,
ha='center',
va='top')

patches.append(circle)

# Make paths
p = PatchCollection(patches, facecolors=contract.vertex_color, edgecolors="black")

axes.add_collection(p)

def draw_circle_edges(self, axes, contract: DrawEdgesContract):
"""
Draw circled edge based on contract DrawEdgesContract
DrawEdgesContract:
- vertex_coordinates - Vertexes coordinates
- vertex_size - Sizes for vertexes
- edge_list - List of edges
- edge_color - Colors for edges
- edge_fill_color - Fill color for edges
- edge_line_width - Width for edge lines
"""
num_vertex = len(contract.vertex_coordinates)

line_paths, arc_paths, vertices = self.hull_layout(num_vertex,
contract.edge_list,
contract.vertex_coordinates,
contract.vertex_size)

# For every edge line
for edge_index, lines in enumerate(line_paths):
path_data = []

for line in lines:
if len(line) == 0:
continue

start_pos, end_pos = line

path_data.append((Path.MOVETO, start_pos.tolist()))
path_data.append((Path.LINETO, end_pos.tolist()))

if len(list(zip(*path_data))) == 0:
continue

codes, vertexes = zip(*path_data)

# Apply to plot
axes.add_patch(
PathPatch(Path(vertexes, codes),
linewidth=contract.edge_line_width[edge_index],
facecolor=contract.edge_fill_color[edge_index],
edgecolor=contract.edge_color[edge_index]))

# For every arc
for edge_index, arcs in enumerate(arc_paths):
for arc in arcs:
center, theta1, theta2, radius = arc

# Apply to plot
axes.add_patch(
matplotlib.patches.Arc((center[0], center[1]),
2 * radius,
2 * radius,
theta1=theta1,
theta2=theta2,
color=contract.edge_color[edge_index],
linewidth=contract.edge_line_width[edge_index],
edgecolor=contract.edge_color[edge_index],
facecolor=contract.edge_fill_color[edge_index]))

@staticmethod
def hull_layout(num_vertex, # noqa
edge_list,
position,
vertex_size,
radius_increment=Defaults.radius_increment):

# Make paths
line_paths = [None] * len(edge_list)
arc_paths = [None] * len(edge_list)

# Make polygons
polygons_vertices_index = []
vertices_radius = np.array(vertex_size)
vertices_increased_radius = vertices_radius * radius_increment
vertices_radius += vertices_increased_radius

# Define edge characteristics
edge_degree = [len(e) for e in edge_list]
edge_indexes = np.argsort(np.array(edge_degree))

# For every edge index
for edge_index in edge_indexes:
edge = list(edge_list[edge_index])

line_path_for_edges = []
arc_path_for_edges = []

if len(edge) == 1:
arc_path_for_edges.append([position[edge[0]], 0, 360, vertices_radius[edge[0]]])

vertices_radius[edge] += vertices_increased_radius[edge]

line_paths[edge_index] = line_path_for_edges
arc_paths[edge_index] = arc_path_for_edges

continue

pos_in_edge = position[edge]

if len(edge) == 2:
vertices_index = np.array((0, 1), dtype=np.int64)
else:
hull = ConvexHull(pos_in_edge)
vertices_index = hull.vertices

number_of_vertices = vertices_index.shape[0]

vertices_index = np.append(vertices_index, vertices_index[0]) # close the loop

thetas = []

# For all vertexes
for i in range(number_of_vertices):
# draw lines
i1 = edge[vertices_index[i]]
i2 = edge[vertices_index[i + 1]]

r1 = vertices_radius[i1]
r2 = vertices_radius[i2]

p1 = position[i1]
p2 = position[i2]

dp = p2 - p1
dp_len = vector_length(dp)

beta = radian_from_atan(dp[0], dp[1])
alpha = common_tangent_radian(r1, r2, dp_len)

theta = beta - alpha
start_point = polar_position(r1, theta, p1)
end_point = polar_position(r2, theta, p2)

line_path_for_edges.append((start_point, end_point))
thetas.append(theta)

for i in range(number_of_vertices):
# draw arcs
theta_1 = thetas[i - 1]
theta_2 = thetas[i]

arc_center = position[edge[vertices_index[i]]]
radius = vertices_radius[edge[vertices_index[i]]]

theta_1, theta_2 = rad_to_deg(theta_1), rad_to_deg(theta_2)
arc_path_for_edges.append((arc_center, theta_1, theta_2, radius))

vertices_radius[edge] += vertices_increased_radius[edge]

polygons_vertices_index.append(vertices_index.copy())

# line_paths.append(line_path_for_e)
# arc_paths.append(arc_path_for_e)
line_paths[edge_index] = line_path_for_edges
arc_paths[edge_index] = arc_path_for_edges

return line_paths, arc_paths, polygons_vertices_index
Empty file.
Empty file.
10 changes: 10 additions & 0 deletions stable_gnn/visualization/config/parameters/colors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from stable_gnn.visualization.utils.frozen_dataclass import reference
from stable_gnn.visualization.utils.reference_base import ReferenceBase


@reference
class Colors(ReferenceBase):
red: str = "r"
green: str = "g"
gray: str = "gray"
whitesmoke: str = "whitesmoke"
74 changes: 74 additions & 0 deletions stable_gnn/visualization/config/parameters/defaults.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
from stable_gnn.visualization.config.parameters.colors import Colors
from stable_gnn.visualization.config.parameters.edge_styles import EdgeStyles
from stable_gnn.visualization.config.parameters.fonts import Fonts
from stable_gnn.visualization.utils.frozen_dataclass import reference
from stable_gnn.visualization.utils.reference_base import ReferenceBase


@reference
class Defaults(ReferenceBase):
edge_style: str = EdgeStyles.line
edge_color: str = Colors.gray
edge_fill_color: str = Colors.whitesmoke
edge_line_width: float = 1.0
vertex_size: float = 1.0
vertex_color: str = Colors.red
vertex_line_width: float = 1.0
vertex_strength: float = 1.0
font_size: float = 1.0
font_family: str = Fonts.sans_serif
push_vertex_strength_vis: float = 1.0
push_edge_strength_vis: float = 1.0
pull_edge_strength_vis: float = 1.0
pull_center_strength_vis: float = 1.0
damping_factor: float = 0.9999
damping: float = 1
radius_increment: float = 0.3
force_modifier: float = -0.1
force_a_max: float = 0.1
axes_num: int = 2
# calculate_edge_line_width params
edge_line_width_multiplier: float = 1.0
edge_line_width_divider: float = 120.0
# calculate_font_size params
font_size_multiplier: float = 20
font_size_divider: float = 100.0
# calculate_vertex_line_width params
vertex_line_width_multiplier: float = 1.0
vertex_line_width_divider: float = 50.0
# calculate_vertex_size params
calculate_vertex_size_multiplier: float = 1.0
calculate_vertex_size_divider: float = 10.0
calculate_vertex_size_modifier: float = 0.1
# calc_arrow_head_width
arrow_multiplier: float = 0.015
# safe_div
jitter_scale: float = 0.000001
# calculate strength
push_vertex_strength_g: float = 0.006
push_edge_strength_g: float = 0.0
pull_edge_strength_g: float = 0.045
pull_center_strength_g: float = 0.01
push_vertex_strength_hg: float = 1.0
push_edge_strength_hg: float = 1.0
pull_edge_strength_hg: float = 1.0
pull_center_strength_hg: float = 1.0
# calculate layout
layout_scale_initial: int = 5
vertex_coord_max: float = 5.0
vertex_coord_min: float = -5.0
vertex_coord_multiplier: float = 0.8
vertex_coord_modifier: float = 0.1
# figure
figure_size: tuple = (6, 6)
x_limits: tuple = (0, 1.0)
y_limits: tuple = (0, 1.0)
axes_on_off: str = "off"
# core physical model
node_attraction_key: int = 0
node_repulsion_key: int = 1
edge_repulsion_key: int = 2
center_of_gravity_key: int = 3
max_iterations: int = 400
epsilon: float = 0.001
delta: float = 2.0
8 changes: 8 additions & 0 deletions stable_gnn/visualization/config/parameters/edge_styles.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from stable_gnn.visualization.utils.frozen_dataclass import reference
from stable_gnn.visualization.utils.reference_base import ReferenceBase


@reference
class EdgeStyles(ReferenceBase):
line: str = "line"
circle: str = "circle"
7 changes: 7 additions & 0 deletions stable_gnn/visualization/config/parameters/fonts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from stable_gnn.visualization.utils.frozen_dataclass import reference
from stable_gnn.visualization.utils.reference_base import ReferenceBase


@reference
class Fonts(ReferenceBase):
sans_serif: str = "sans-serif"
10 changes: 10 additions & 0 deletions stable_gnn/visualization/config/parameters/generator_methods.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from stable_gnn.visualization.utils.frozen_dataclass import reference
from stable_gnn.visualization.utils.reference_base import ReferenceBase


@reference
class GeneratorMethods(ReferenceBase):
custom: str = "custom"
uniform: str = "uniform"
low_order_first: str = "low_order_first"
high_order_first: str = "high_order_first"
5 changes: 5 additions & 0 deletions stable_gnn/visualization/config/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
import typing as t

TEdgeList = t.NewType("TEdgeList", list[tuple[int, int]])
TGraphEdgeList = t.NewType("TGraphEdgeList", tuple[TEdgeList, list[float]])
TVectorCoordinates = t.NewType("TVectorCoordinates", list[tuple[float, float]])
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from stable_gnn.visualization.config.parameters.defaults import Defaults
from stable_gnn.visualization.contracts.strength_constructor_contract import StrengthConstructorContract
from stable_gnn.visualization.utils.fill_strength import fill_strength


class GraphStrengthConstructor:
"""
Constructor (one action controller) for Graph strengths.
"""

def __call__(self, contract: StrengthConstructorContract) -> tuple:
_push_vertex_strength = Defaults.push_vertex_strength_g
_push_edge_strength = Defaults.push_edge_strength_g
_pull_edge_strength = Defaults.pull_edge_strength_g
_pull_center_strength = Defaults.pull_center_strength_g

push_vertex_strength = fill_strength(contract.push_vertex_strength, _push_vertex_strength)
push_edge_strength = fill_strength(contract.push_edge_strength, _push_edge_strength)
pull_edge_strength = fill_strength(contract.pull_edge_strength, _pull_edge_strength)
pull_center_strength = fill_strength(contract.pull_center_strength, _pull_center_strength)

return push_vertex_strength, push_edge_strength, pull_edge_strength, pull_center_strength
Loading

0 comments on commit f559af3

Please sign in to comment.