export MonteCarloTreeSearch, MCTS, MCTSTreeSolution
export MCTSNodeSelector, MaxUCBSelector, BoltzmannUCBSelector
export MCTSLeafEstimator, RandomRolloutEstimator, ConstantEstimator
## MCTS solution type ##
"MCTS policy solution."
@auto_hash_equals struct MCTSTreeSolution <: PolicySolution
    Q::Dict{UInt64,Dict{Term,Float64}}
    s_visits::Dict{UInt64,Int}
    a_visits::Dict{UInt64,Dict{Term,Int}}
end
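# The tree is keyed by `hash(state)`: `Q` maps each state to per-action value
# estimates, `s_visits` counts visits to each state node, and `a_visits`
# counts per-action visits (used for the UCB exploration bonus below).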
MCTSTreeSolution() = MCTSTreeSolution(Dict(), Dict(), Dict())
function Base.copy(sol::MCTSTreeSolution)
    Q = Dict(s => copy(qs) for (s, qs) in sol.Q)
    s_visits = copy(sol.s_visits)
    a_visits = Dict(s => copy(visits) for (s, visits) in sol.a_visits)
    return MCTSTreeSolution(Q, s_visits, a_visits)
end
get_action(sol::MCTSTreeSolution, state::State) =
    best_action(sol, state)
best_action(sol::MCTSTreeSolution, state::State) =
    argmax(sol.Q[hash(state)])
rand_action(sol::MCTSTreeSolution, state::State) =
    best_action(sol, state)
has_values(sol::MCTSTreeSolution) =
    true
get_value(sol::MCTSTreeSolution, state::State) =
    maximum(values(sol.Q[hash(state)]))
get_value(sol::MCTSTreeSolution, state::State, action::Term) =
    sol.Q[hash(state)][action]
get_action_values(sol::MCTSTreeSolution, state::State) =
    pairs(sol.Q[hash(state)])
has_cached_value(sol::MCTSTreeSolution, state::State) =
    has_cached_value(sol, hash(state))
has_cached_value(sol::MCTSTreeSolution, state_id::UInt) =
    haskey(sol.Q, state_id)
has_cached_value(sol::MCTSTreeSolution, state::State, action::Term) =
    has_cached_value(sol, hash(state), action)
has_cached_value(sol::MCTSTreeSolution, state_id::UInt, action::Term) =
    haskey(sol.Q, state_id) && haskey(sol.Q[state_id], action)
has_cached_action_values(sol::MCTSTreeSolution, state::State) =
    haskey(sol.Q, hash(state))
has_state_node(sol::MCTSTreeSolution, state::State) =
    haskey(sol.s_visits, hash(state))
## Node selection strategies ##
"Max upper-confidence bound (UCB) action policy."
@auto_hash_equals struct MaxUCBPolicy <: PolicySolution
    tree::MCTSTreeSolution
    confidence::Float64
end
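# `get_action_values` below scores each action `a` at state `s` as
#     Q(s, a) + confidence * sqrt(log(N(s)) / N(s, a)),
# i.e. the stored Q-value plus a UCB1-style exploration bonus that favors
# rarely tried actions (untried actions get an infinite bonus once N(s) > 1).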
Base.copy(sol::MaxUCBPolicy) = MaxUCBPolicy(copy(sol.tree), sol.confidence)
function get_action_values(sol::MaxUCBPolicy, state::State)
    state_id = hash(state)
    s_visits = sol.tree.s_visits[state_id]
    vals = Base.Generator(sol.tree.Q[state_id]) do (act, val)
        a_visits = sol.tree.a_visits[state_id][act]
        if s_visits != 0 && !(s_visits == 1 && a_visits == 0)
            val += sol.confidence * sqrt(log(s_visits) / a_visits)
        end
        return act => val
    end
    return Dict(vals)
end
function best_action(sol::MaxUCBPolicy, state::State)
    return argmax(get_action_values(sol, state))
end
get_action(sol::MaxUCBPolicy, state::State) =
    best_action(sol, state)
rand_action(sol::MaxUCBPolicy, state::State) =
    best_action(sol, state)
has_values(sol::MaxUCBPolicy) =
    true
get_value(sol::MaxUCBPolicy, state::State) =
    maximum(values(get_action_values(sol, state)))
get_value(sol::MaxUCBPolicy, state::State, action::Term) =
    get_action_values(sol, state)[action]
has_cached_value(sol::MaxUCBPolicy, state::State) =
    has_cached_value(sol.tree, state)
has_cached_value(sol::MaxUCBPolicy, state::State, action::Term) =
    has_cached_value(sol.tree, state, action)
has_cached_action_values(sol::MaxUCBPolicy, state::State) =
    has_cached_action_values(sol.tree, state)
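# The Boltzmann variant wraps the UCB-adjusted values in a `BoltzmannPolicy`,
# sampling actions via a softmax over those scores at the given `temperature`
# instead of always taking the maximizer.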
BoltzmannUCBPolicy(tree::MCTSTreeSolution, confidence, temperature) =
    BoltzmannPolicy(MaxUCBPolicy(tree, confidence), temperature)
"Abstract type for MCTS node selection strategies."
abstract type MCTSNodeSelector end
(sel::MCTSNodeSelector)(tree::MCTSTreeSolution, domain::Domain, state::State) =
    error("Not implemented.")
"Max UCB selection strategy."
@kwdef struct MaxUCBSelector <: MCTSNodeSelector
    confidence::Float64 = 2.0
end
(sel::MaxUCBSelector)(tree::MCTSTreeSolution, ::Domain, state::State) =
    get_action(MaxUCBPolicy(tree, sel.confidence), state)
"Boltzmann UCB selection strategy."
@kwdef struct BoltzmannUCBSelector <: MCTSNodeSelector
    confidence::Float64 = 2.0
    temperature::Float64 = 1.0
end
(sel::BoltzmannUCBSelector)(tree::MCTSTreeSolution, ::Domain, state::State) =
    get_action(BoltzmannUCBPolicy(tree, sel.confidence, sel.temperature), state)
## Leaf node estimators ##
"Abstract type for MCTS leaf value estimators."
abstract type MCTSLeafEstimator end
(e::MCTSLeafEstimator)(::Domain, ::State, ::Specification, depth::Int) =
    error("Not implemented.")
"Estimates value as a constant."
struct ConstantEstimator{C <: Real} <: MCTSLeafEstimator
    value::C
end
(e::ConstantEstimator)(::Domain, ::State, ::Specification, ::Int) = e.value
"Estimates value via uniform random rollouts."
struct RandomRolloutEstimator <: MCTSLeafEstimator end
function (e::RandomRolloutEstimator)(domain::Domain, state::State,
                                     spec::Specification, depth::Int)
    sim = RewardAccumulator(depth)
    policy = RandomPolicy(domain)
    return sim(policy, domain, state, spec)
end
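# `RewardAccumulator(depth)` is used here as a simulator that runs the given
# policy for at most `depth` steps and returns the reward accumulated along
# the way, which serves as the leaf node's value estimate.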
"Estimates value via policy rollouts."
struct PolicyRolloutEstimator{P <: PolicySolution} <: MCTSLeafEstimator
    policy::P
end
function (e::PolicyRolloutEstimator)(domain::Domain, state::State,
                                     spec::Specification, depth::Int)
    sim = RewardAccumulator(depth)
    return sim(e.policy, domain, state, spec)
end
## Main algorithm ##
"""
MonteCarloTreeSearch(
n_rollouts::Int64 = 50,
max_depth::Int64 = 50,
heuristic::Heuristic = NullHeuristic(),
selector::MCTSNodeSelector = BoltzmannUCBSelector(),
estimator::MCTSLeafEstimator = RandomRolloutEstimator()
end
Planner that uses Monte Carlo Tree Search (`MCTS` for short) [1], with a
customizable initial value `heuristic`, node `selector` strategy, and leaf
node value `estimator`.
# Arguments
$(FIELDS)
"""
@kwdef mutable struct MonteCarloTreeSearch{
    S <: MCTSNodeSelector, E <: MCTSLeafEstimator
} <: Planner
    "Number of search rollouts to perform."
    n_rollouts::Int64 = 50
    "Maximum depth of rollout (including the selection and estimation phases)."
    max_depth::Int64 = 50
    "Initial value heuristic for newly encountered states / leaf nodes."
    heuristic::Heuristic = NullHeuristic()
    "Node selection strategy for previously visited nodes (e.g. MaxUCB)."
    selector::S = BoltzmannUCBSelector()
    "Estimator for leaf node values (e.g. random or policy-based rollouts)."
    estimator::E = RandomRolloutEstimator()
end
@auto_hash MonteCarloTreeSearch
@auto_equals MonteCarloTreeSearch
const MCTS = MonteCarloTreeSearch
@doc (@doc MonteCarloTreeSearch) MCTS
function Base.copy(p::MonteCarloTreeSearch)
    return MonteCarloTreeSearch(p.n_rollouts, p.max_depth, p.heuristic,
                                p.selector, p.estimator)
end
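# A minimal usage sketch, assuming a PDDL `domain`, initial `state`, and goal
# `spec` (a `Specification`) have been loaded elsewhere:
#
#   planner = MCTS(n_rollouts=100, max_depth=50,
#                  selector=MaxUCBSelector(confidence=2.0))
#   sol = solve(planner, domain, state, spec)
#   act = best_action(sol, state)   # best action at the initial state
#   val = get_value(sol, state)     # estimated value of the initial state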
function solve(planner::MonteCarloTreeSearch,
               domain::Domain, state::State, spec::Specification)
    @unpack n_rollouts, max_depth = planner
    @unpack heuristic, selector, estimator = planner
    discount = get_discount(spec)
    # Simplify goal specification
    spec = simplify_goal(spec, domain, state)
    # Precompute heuristic information
    precompute!(heuristic, domain, state, spec)
    # Initialize solution
    sol = MCTSTreeSolution()
    sol = insert_node!(planner, sol, domain, state, spec)
    # Perform rollouts from initial state
    a_visited, s_visited = Term[], State[]
    initial_state = state
    for n in 1:n_rollouts
        state = initial_state
        act = PDDL.no_op
        value = 0.0
        # Rollout until maximum depth
        for t in 1:max_depth
            # Terminate if rollout reaches goal
            if is_goal(spec, domain, state, act) break end
            # Select action
            act = selector(sol, domain, state)
            push!(s_visited, state)
            push!(a_visited, act)
            # Transition to next state
            state = transition(domain, state, act)
            # Insert leaf node and evaluate
            if !has_state_node(sol, state)
                insert_node!(planner, sol, domain, state, spec)
                value = get_discount(spec)^t *
                    estimator(domain, state, spec, max_depth-t)
                break
            end
        end
        # Backpropagate value estimates
        next_state = state
        while length(s_visited) > 0
            state, act = pop!(s_visited), pop!(a_visited)
            state_id = hash(state)
            # Update visitation counts
            sol.s_visits[state_id] += 1
            sol.a_visits[state_id][act] += 1
            # Compute value (accumulated discounted reward)
            reward = get_reward(spec, domain, state, act, next_state)
            value = get_discount(spec) * value + reward
            # Take weighted average with existing Q value
            sol.Q[state_id][act] +=
                (value - sol.Q[state_id][act]) / sol.a_visits[state_id][act]
            next_state = state
        end
    end
    return sol
end
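# New nodes are initialized with Q(s, a) = r(s, a, s') - γ * h(s'), where h is
# the planner's heuristic (treated as an estimate of remaining cost) and
# γ = get_discount(spec); state and action visit counts start at zero.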
function insert_node!(planner::MonteCarloTreeSearch, sol::MCTSTreeSolution,
                      domain::Domain, state::State, spec::Specification)
    actions = available(domain, state)
    qs = map(actions) do act
        next_state = transition(domain, state, act)
        r = get_reward(spec, domain, state, act, next_state)
        h_val = compute(planner.heuristic, domain, next_state, spec)
        return get_discount(spec) * (-h_val) + r
    end
    state_id = hash(state)
    sol.Q[state_id] = Dict{Term,Float64}(zip(actions, qs))
    sol.a_visits[state_id] = Dict{Term,Int}(a => 0 for a in actions)
    sol.s_visits[state_id] = 0
    return sol
return sol
end