-
Notifications
You must be signed in to change notification settings - Fork 6
/
util.py
46 lines (35 loc) · 938 Bytes
/
util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
def cumsum(l):
    """Return the running (prefix) sums of *l* as a list.

    ``cumsum([1, 2, 3]) == [1, 3, 6]``; an empty input gives ``[]``.
    Accepts any iterable (list, tuple, generator, ...) of addable values.
    """
    from itertools import accumulate

    # accumulate keeps a running total, so this is O(n) instead of
    # re-summing a fresh slice for every prefix (O(n^2)).
    # NOTE: for floats on Python 3.12+, ``sum`` uses compensated summation,
    # so last-bit rounding can differ slightly from this plain running total.
    return list(accumulate(l))
def unfold(d):
    """Expand a mapping of key -> iterable of values into ``(key, value)`` pairs.

    Pairs are grouped by key in the dict's iteration order, and each key's
    values keep their own order, e.g.
    ``unfold({'a': [1, 2], 'b': [3]}) == [('a', 1), ('a', 2), ('b', 3)]``.
    A key whose iterable is empty contributes no pairs.
    """
    return [(key, value) for key, values in d.items() for value in values]
def flatten(l):
    """Lazily flatten one level of nesting from an iterable of iterables.

    Returns a generator yielding the elements of each inner iterable in
    order, e.g. ``list(flatten([[1, 2], [], [3]])) == [1, 2, 3]``.
    ``iter(l)`` is taken immediately, so a non-iterable ``l`` raises
    ``TypeError`` at call time; all element retrieval is deferred until the
    generator is iterated.
    """
    outer = iter(l)  # evaluated eagerly, like a generator expression's source

    def _walk():
        for group in outer:
            for item in group:
                yield item

    return _walk()
class GGraph:
    """Thin wrapper around a ``graphviz.Digraph`` for drawing directed graphs.

    Node names and edge endpoints are converted with ``str()`` before being
    handed to graphviz. If per-node weights are attached with
    :meth:`add_weights`, every edge added afterwards is labelled with the
    weight of its tail (source) node, ``edge[0]``.

    graphviz is imported inside ``__init__``, so importing this module does
    not by itself require graphviz to be installed.
    """

    def __init__(self, name, fmt="svg"):
        """Create an empty directed graph.

        :param name: filename passed to graphviz when rendering.
        :param fmt: graphviz output format (default ``"svg"``).
        """
        import graphviz as gv
        self._g = gv.Digraph(format=fmt)
        self.name = name
        # Per-node weights, looked up by integer node index in get_weight;
        # None means "no weights attached".
        self.weights = None

    def add_weights(self, weights):
        """Attach per-node weights used to label subsequently added edges.

        :param weights: indexable, sized sequence; ``weights[n]`` is the
            weight reported for node ``n``.
        """
        self.weights = weights

    def get_weight(self, node):
        """Return the weight of integer *node*, or 0 if it has none.

        Negative indices (which Python would otherwise wrap around), indices
        past the end, and a graph with no weights attached all yield 0.
        """
        if self.weights is None or not 0 <= node < len(self.weights):
            return 0
        return self.weights[node]

    def add_nodes(self, nodes):
        """Add each item of *nodes* as a graph node named ``str(item)``."""
        for n in nodes:
            self._g.node(str(n))

    def add_edges(self, edges, color):
        """Add directed edges, all drawn in *color*.

        :param edges: iterable of ``(tail, head)`` pairs; each endpoint is
            passed through ``str()``.
        :param color: graphviz colour applied to every edge in this call.

        When a non-empty weight sequence is attached, each edge is labelled
        with ``str(get_weight(tail))``.
        """
        # Explicit None/len test rather than plain truthiness: ``if weights:``
        # raises ValueError for multi-element NumPy arrays. An empty weight
        # sequence still means "no labels", as before.
        labelled = self.weights is not None and len(self.weights) > 0
        for e in edges:
            attrs = {"color": color}
            if labelled:
                attrs["label"] = str(self.get_weight(e[0]))
            self._g.edge(*[str(x) for x in e], **attrs)

    def render(self, view=True):
        """Render the graph to ``self.name`` and return graphviz's output path.

        :param view: open the rendered file with the system viewer (default
            ``True``, the previous hard-coded behaviour).
        """
        return self._g.render(filename=self.name, view=view)