-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathsearch.py
154 lines (113 loc) · 3.85 KB
/
search.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
"""
Search algorithms.
A lot of code copied from:
https://www.redblobgames.com/pathfinding/a-star/implementation.html
"""
import heapq
import sys
from world import Action
class PriorityQueue:
    """
    Min-priority queue backed by `heapq`.

    Items with equal priority are returned in insertion (FIFO) order. An insertion counter
    breaks ties, so the items themselves are never compared with each other and do not need
    to be orderable.
    """

    def __init__(self):
        # Heap of (priority, insertion_index, item) entries.
        self.elements = []
        # Incremented on every `put`; used only as a tie-breaker between equal priorities.
        self._counter = 0

    def empty(self):
        """Return True if the queue holds no items."""
        return not self.elements

    def put(self, item, priority):
        """Insert `item` with the given `priority` (lower priorities are returned first)."""
        self._counter += 1
        heapq.heappush(self.elements, (priority, self._counter, item))

    def get(self):
        """
        Remove and return the item with the lowest priority.

        Raises IndexError (from `heapq.heappop`) if the queue is empty.
        """
        return heapq.heappop(self.elements)[-1]
class Graph:
    """
    Base class for the graphs explored by the search algorithms.

    Subclasses must override `neighbors`; `cost` defaults to a uniform cost of 1 per move.
    """

    def __init__(self, **kw):
        # Subclasses forward their leftover keyword arguments here; anything that reaches this
        # point was not consumed by a subclass and is rejected.
        assert len(kw) == 0, f'no additional argument expected, but found: {kw}'

    def cost(self, from_node, to_node):
        """
        Return the cost of moving from `from_node` to `to_node` (always 1 in this base class).
        """
        return 1

    def neighbors(self, node):
        """
        Return the neighbors of `node`. Must be implemented by subclasses.
        """
        raise NotImplementedError(type(self).__name__)
class WorldGraph(Graph):
    """
    Graph whose nodes are states of a `World`.

    The neighbors of a state are the states obtained by performing a single `Action` on it.
    """

    def __init__(self, world, **kw):
        """
        Constructor.

        :param world: The world this graph is based on.
        :param kw: Additional arguments forwarded to parent class.
        """
        super().__init__(**kw)
        self.world = world

    def cost(self, a, b):
        # Uniform cost: every transition costs 1.
        return 1

    def neighbors(self, node):
        """
        Return the states reachable from `node` in one action.

        Every member of `Action` is tried in iteration order. Actions for which
        `world.perform` returns None are treated as invalid and skipped.
        """
        candidates = (self.world.perform(node, action) for action in Action)
        return [state for state in candidates if state is not None]
def heuristic(from_node_def, to_node_def):
    """
    The A* heuristic: Manhattan (L1) distance between two node definitions.

    Definitions are sequences of numbers, typically positions. Components are paired with
    `zip`, so if the two sequences differ in length the extra trailing values are ignored.
    """
    diffs = (lhs - rhs for lhs, rhs in zip(from_node_def, to_node_def))
    return sum(map(abs, diffs))
def a_star_search(graph, start, exit_definition, extract_definition):
    """
    A* algorithm.

    :param graph: The graph to work on; must provide `neighbors(node)` and `cost(from_node, to_node)`.
    :param start: The start node. Nodes are used as dictionary keys and set members, so they must be
        hashable.
    :param exit_definition: How the exit node is defined: an iterable of values, typically integers
        representing a position. It is materialized into a tuple once, so any iterable (list, tuple,
        generator, ...) may be passed.
    :param extract_definition: A function that, when applied on a node, extracts its definition (an
        iterable of values). The definition is compared element-wise with `exit_definition` to check
        whether the exit is reached, and is passed to `heuristic` to estimate the remaining cost.
        Typically this function just extracts the position corresponding to the node.
    :return: A tuple `(came_from, cost_so_far, exit_node, n_steps)`:
        `came_from` maps each discovered node to its predecessor (`None` for `start`);
        `cost_so_far` maps each discovered node to the cheapest cost found from `start`;
        `exit_node` is the first popped node whose definition matches `exit_definition`;
        `n_steps` is the number of frontier pops performed (stale entries included).
    :raises OverflowError: If the frontier is exhausted without reaching the exit, i.e. no node
        reachable from `start` has a definition equal to `exit_definition`.
    """
    # Materialize once: a generator would otherwise be consumed by the first heuristic call, and
    # comparing tuples lets e.g. a list exit definition match a tuple extracted from a node.
    exit_def = tuple(exit_definition)
    frontier = PriorityQueue()
    frontier.put(start, 0)
    came_from = {start: None}
    cost_so_far = {start: 0}
    processed = set()
    n_steps = 0
    while not frontier.empty():
        n_steps += 1
        if n_steps % 100000 == 0:
            print(f'A* steps: {n_steps}')
        current = frontier.get()
        if tuple(extract_definition(current)) == exit_def:
            # Reached the exit!
            return came_from, cost_so_far, current, n_steps
        if current in processed:
            # Expected, not a bug: the heap has no decrease-key, so a node is pushed again each
            # time a cheaper path to it is found, and its older entries stay in the heap. Those
            # stale entries are popped later and skipped here ("lazy deletion").
            continue
        processed.add(current)
        for next_ in graph.neighbors(current):
            new_cost = cost_so_far[current] + graph.cost(current, next_)
            if next_ not in cost_so_far or new_cost < cost_so_far[next_]:
                cost_so_far[next_] = new_cost
                priority = new_cost + heuristic(extract_definition(next_), exit_def)
                frontier.put(next_, priority)
                came_from[next_] = current
    # Frontier exhausted: every node reachable from `start` was expanded and none matched the exit.
    raise OverflowError('A* failed')
def main():
    """
    Script entry point (placeholder for ad-hoc test code).

    :return: The process exit status; currently always 0.
    """
    exit_status = 0
    return exit_status
if __name__ == '__main__':
    # Run as a script: use main()'s return value as the process exit status.
    sys.exit(main())