-
Notifications
You must be signed in to change notification settings - Fork 1
/
Graph.py
212 lines (170 loc) · 6.11 KB
/
Graph.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
# VirtualGraphPrimitives.py -- A lazy/virtual graph
from abc import ABC, abstractmethod
from typing import Union, List, Tuple, Dict, Set, FrozenSet, Iterable, Any, \
NewType, Type, ClassVar, Sequence, Callable, Hashable, Collection, \
Sequence, Literal
from dataclasses import dataclass
import math
from itertools import chain
from util import empty_set, first_non_none, unique_everseen
# Type alias: any hashable value may serve as a node.
Node = Hashable


@dataclass(frozen=True)
class Hop:
    '''A directed, weighted edge running from from_node to to_node.

    Frozen, so instances are immutable and (being eq-comparable) hashable.'''
    from_node: Node  # source end of the edge
    to_node: Node    # target end of the edge
    weight: float    # weight carried by the edge
class Graph(ABC):
    '''Abstract base class for graphs.

    Concrete subclasses must provide __len__, all_nodes, has_node,
    find_hop, successors_of and predecessors_of. nodes, hops_from_node,
    hops_to_node and hop_weight have default implementations built on
    those.'''

    # TODO (probably): add a neighbors() method.

    @abstractmethod
    def __len__(self) -> Union[int, Literal[math.inf]]:  # type: ignore
        '''How many nodes the graph has; math.inf if it is unbounded.'''

    @abstractmethod
    def all_nodes(self) -> Iterable[Node]:
        '''Every node of this graph, as an iterable.'''

    def nodes(self, ns: Iterable[Any]) -> Iterable[Node]:
        '''Lazily keep only those values of ns that this graph contains,
        as judged by .has_node(). Order and any duplicates in ns are
        preserved.'''
        # TODO: de-duplicate the output.
        # TODO: allow querying by user-specified criteria like IsAlphabetic.
        yield from filter(self.has_node, ns)

    @abstractmethod
    def has_node(self, x: Any) -> bool:
        '''True if x is a node of this graph.'''

    def hops_from_node(self, x: Any) -> Iterable[Hop]:
        '''Hops leaving x. By default, one Hop of weight 1.0 to each node
        reported by .successors_of(x).'''
        yield from (Hop(x, succ, 1.0) for succ in self.successors_of(x))

    def hops_to_node(self, x: Any) -> Iterable[Hop]:
        '''Hops entering x. By default, one Hop of weight 1.0 from each
        node reported by .predecessors_of(x).'''
        yield from (Hop(pred, x, 1.0) for pred in self.predecessors_of(x))

    def hop_weight(self, from_node: Any, to_node: Any) -> float:
        '''Weight of the hop from from_node to to_node as returned by
        .find_hop(); 0.0 when there is no such hop.'''
        found = self.find_hop(from_node, to_node)
        return 0.0 if found is None else found.weight

    @abstractmethod
    def find_hop(self, from_node: Any, to_node: Any) -> Union[Hop, None]:
        '''The Hop from from_node to to_node, or None if there is none.'''

    # Possible future aliases:
    #   has_hop = find_hop
    #   has_edge = find_hop

    @abstractmethod
    def successors_of(self, x: Any) -> Iterable[Node]:
        '''Every node that x has an edge to.'''

    @abstractmethod
    def predecessors_of(self, x: Any) -> Iterable[Node]:
        '''Every node that has an edge to x.'''

    @classmethod
    def augment(cls, *graphs: 'Graph') -> 'Graph':
        '''Combine the given graphs, in the order given, into a single
        GraphSeries.'''
        return GraphSeries(graphs)
class NoEdges(Graph):
    '''Base for edgeless graphs: overrides Graph's edge queries so that
    they report no edges, with no computation.'''

    def successors_of(self, x: Any) -> Iterable[Node]:
        '''Nothing follows any node, so this is always a fresh empty list.'''
        return list()

    def predecessors_of(self, x: Any) -> Iterable[Node]:
        '''Nothing precedes any node, so this is always a fresh empty list.'''
        return list()

    def find_hop(self, from_node: Any, to_node: Any) -> Union[Hop, None]:
        '''There is never a hop between any two nodes.'''
        return None
@dataclass
class LiteralGraph(NoEdges):
    '''Edgeless graph whose nodes are exactly the values passed in at
    construction time.'''
    literals: Sequence[Node]

    def __init__(self, literals, **kwargs):
        # Defined explicitly, so @dataclass keeps this __init__ rather than
        # generating one; any extra keyword arguments go to super().__init__.
        self.literals = literals
        super().__init__(**kwargs)

    def __len__(self):
        '''Length of the literals sequence.'''
        return len(self.literals)

    def has_node(self, x):
        '''Membership test against the literals sequence.'''
        return x in self.literals

    def all_nodes(self):
        '''The literals sequence itself (not a copy).'''
        return self.literals
# NOTE(review): the WantEdges class below is disabled. It sits inside a
# module-level string literal, so it is never defined at runtime.
# If it is ever re-enabled, check the following first:
#  - @dataclass(frozen=True) makes the instance frozen, so the assignment
#    self.want_edges = ... in __init__ would raise FrozenInstanceError.
#    The comment above the decorator also says @dataclass is commented out,
#    but it is not.
#  - find_hop() contains a debug print() that runs on every call.
#  - predecessors_of() raises NotImplementedError.
#  - __len__, all_nodes and has_node are not defined here, so the class
#    would still be abstract on its own.
"""
#@dataclass commented out due to mypy bug https://stackoverflow.com/q/69330256/1393162
@dataclass(frozen=True)
class WantEdges(Graph):
#TODO docstring
want_edges: Dict[Node, FrozenSet[Node]]
def __init__(self, want_edges, **kwargs):
self.want_edges = want_edges
super().__init__(**kwargs)
def find_hop(self, from_node, to_node):
print('WAE', from_node, to_node,
self.has_node(from_node),
to_node in self.want_edges.get(from_node, empty_set),
self.has_node(to_node),
self.__class__,
self.all_nodes(),
)
if (
self.has_node(from_node)
and
to_node in self.want_edges.get(from_node, empty_set)
and
self.has_node(to_node)
):
return Hop(from_node, to_node, 1.0)
def successors_of(self, from_node):
for to_node in self.nodes(self.want_edges.get(from_node, empty_set)):
yield to_node
def predecessors_of(self, to_node):
raise NotImplementedError
"""
@dataclass
class GraphSeries(Graph):
    '''A graph made by layering several component graphs (see
    Graph.augment()).

    Its nodes are the union of the components' nodes. Successor and
    predecessor queries merge the components' answers, and find_hop()
    returns the first hop any component reports.'''
    graphs: Sequence[Graph]

    def __len__(self):
        # Not supported for a series yet.
        raise NotImplementedError

    def all_nodes(self):
        '''Every component's nodes, chained in component order and
        de-duplicated with util.unique_everseen.'''
        return unique_everseen(chain.from_iterable(
            g.all_nodes() for g in self.graphs
        ))

    def has_node(self, x):
        '''True if any component graph contains x.'''
        return any(g.has_node(x) for g in self.graphs)

    def hops_from_node(self, x):
        raise NotImplementedError

    def hops_to_node(self, x):
        raise NotImplementedError

    def _leaf_graphs(self) -> Iterable[Graph]:
        '''Yield the component graphs, recursively expanding any nested
        GraphSeries into its own components.'''
        for g in self.graphs:
            if isinstance(g, GraphSeries):
                yield from g._leaf_graphs()
            else:
                yield g

    def find_hop(self, from_node, to_node):
        '''Return the first hop from from_node to to_node reported by any
        component graph, or None if none reports one.

        Each component's find_hop is invoked unbound, with self set to this
        series. Any self.has_node() call inside it therefore sees the nodes
        of the whole series, not only that component's. A plain
        g.find_hop() call would not see them.

        Nested GraphSeries are flattened first via _leaf_graphs(). Invoking
        GraphSeries.find_hop with the outer series as self would re-scan
        the outer series, reach the nested series again, and recurse
        without end.

        Caveat: a component whose find_hop reads attributes that only its
        own class defines will likely raise AttributeError here, because
        self is the series rather than that component.'''
        for g in self._leaf_graphs():
            hop = type(g).find_hop(self, from_node, to_node)
            if hop is not None:
                return hop
        return None

    def successors_of(self, x):
        '''Successors of x across all components, de-duplicated with
        util.unique_everseen.'''
        return unique_everseen(chain.from_iterable(
            g.successors_of(x) for g in self.graphs
        ))

    def predecessors_of(self, x):
        '''Predecessors of x across all components, de-duplicated with
        util.unique_everseen.'''
        return unique_everseen(chain.from_iterable(
            g.predecessors_of(x) for g in self.graphs
        ))