# sparse.jl (forked from JuliaPOMDP/DiscreteValueIteration.jl)
struct SparseValueIterationSolver <: Solver
    max_iterations::Int64       # maximum number of Bellman backup sweeps
    belres::Float64             # the Bellman residual: stop when the largest value change falls below this
    include_Q::Bool             # store the full Q-matrix in the returned policy
    verbose::Bool               # print the residual and timing at each iteration
    init_util::Vector{Float64}  # optional initial value function (empty means start from zeros)
end

function SparseValueIterationSolver(;max_iterations=500,
                                    belres::Float64=1e-3,
                                    include_Q::Bool=true,
                                    verbose::Bool=false,
                                    init_util::Vector{Float64}=Vector{Float64}(undef, 0))
    return SparseValueIterationSolver(max_iterations, belres, include_Q, verbose, init_util)
end
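
# A minimal usage sketch (not part of the solver itself). It assumes POMDPModels.jl's
# SimpleGridWorld as an example problem; any POMDPs.jl MDP that satisfies the
# requirements declared below can be used instead.
#=
using POMDPs, POMDPModels, DiscreteValueIteration

solver = SparseValueIterationSolver(max_iterations=1000, belres=1e-6, verbose=true)
mdp = SimpleGridWorld()          # stand-in for any MDP implementing the required interface
policy = solve(solver, mdp)      # returns a ValueIterationPolicy
=#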
# Requirements declaration: the interface an MDP must implement for this solver.
@POMDP_require solve(solver::SparseValueIterationSolver, mdp::MDP) begin
    P = typeof(mdp)
    S = statetype(P)
    A = actiontype(P)
    @req discount(::P)
    @subreq ordered_states(mdp)
    @subreq ordered_actions(mdp)
    @req transition(::P,::S,::A)
    @req reward(::P,::S,::A,::S)
    @req stateindex(::P,::S)
    @req actionindex(::P, ::A)
    @req actions(::P, ::S)
    as = actions(mdp)
    ss = states(mdp)
    @req length(::typeof(ss))
    @req length(::typeof(as))
    a = first(as)
    s = first(ss)
    dist = transition(mdp, s, a)
    D = typeof(dist)
    @req support(::D)
    @req pdf(::D,::S)
    @subreq SparseTabularMDP(mdp)
end
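
# A sketch of how the requirements above can be checked for a custom problem before
# solving. `mymdp` is a hypothetical user-defined MDP; `@requirements_info` is provided
# by POMDPs.jl and reports which declared requirements are not yet implemented.
#=
@requirements_info SparseValueIterationSolver() mymdp
=#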
# Bellman backup of the Q-values for every action:
# Q[s, a] = R[s, a] + γ * Σ_{s'} T_a[s, s'] * V[s']
function qvalue!(m::Union{MDP,POMDP}, transition_A_S_S2, reward_S_A::AbstractMatrix{F}, value_S::AbstractVector{F}, out_qvals_S_A) where {F}
    @assert size(out_qvals_S_A) == (length(states(m)), length(actions(m)))
    for a in 1:length(actions(m))
        out_qvals_S_A[:, a] = view(reward_S_A, :, a) + discount(m) * transition_A_S_S2[a] * value_S
    end
end
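
# A self-contained numerical sketch of the per-action backup performed above, with
# made-up numbers for a two-state problem (T_a, r_a, v, and γ are illustrative only).
#=
using SparseArrays
T_a = sparse([0.9 0.1; 0.0 1.0])   # T_a[s, s']: transition probabilities under action a
r_a = [1.0, 0.0]                   # r_a[s]: reward for taking a in state s
v   = [0.0, 10.0]                  # current value estimates
γ   = 0.95
q_a = r_a + γ * (T_a * v)          # [1.95, 9.5], matching the line inside the loop above
=#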
function solve(solver::SparseValueIterationSolver, mdp::SparseTabularMDP)
    nS = length(states(mdp))
    nA = length(actions(mdp))

    # Initialize the value function, either from the user-supplied guess or from zeros.
    if isempty(solver.init_util)
        v_S = zeros(nS)
    else
        @assert length(solver.init_util) == nS "Input utility dimension mismatch"
        v_S = solver.init_util
    end

    transition_A_S_S2 = transition_matrices(mdp)   # one sparse S×S matrix per action
    reward_S_A = reward_matrix(mdp)                # S×A reward matrix
    qvals_S_A = zeros(nS, nA)
    maxchanges_T = zeros(solver.max_iterations)
    total_time = 0.0

    # Value iteration: repeat Bellman backups until the residual drops below belres
    # or max_iterations is reached.
    for i in 1:solver.max_iterations
        iter_time = @elapsed begin
            qvalue!(mdp, transition_A_S_S2, reward_S_A, v_S, qvals_S_A)
            new_v_S = dropdims(maximum(qvals_S_A, dims=2), dims=2)
            @assert size(v_S) == size(new_v_S)
            maxchanges_T[i] = maximum(abs.(new_v_S .- v_S))
            v_S = new_v_S
        end
        total_time += iter_time
        if solver.verbose
            @info "residual: $(maxchanges_T[i]), time: $(iter_time), total time: $(total_time)" i
        end
        if maxchanges_T[i] < solver.belres
            break
        end
    end

    # Final backup to make the Q-values consistent with the converged value function,
    # then extract the greedy policy.
    qvalue!(mdp, transition_A_S_S2, reward_S_A, v_S, qvals_S_A)
    # Rounding to avoid floating point error noise
    policy_S = dropdims(getindex.(argmax(round.(qvals_S_A, digits=20), dims=2), 2), dims=2)

    if solver.include_Q
        policy = ValueIterationPolicy(mdp, qvals_S_A, v_S, policy_S)
    else
        policy = ValueIterationPolicy(mdp, utility=v_S, policy=policy_S, include_Q=false)
    end
    return policy
end
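
# A sketch showing that the sparse tabular representation can be built once and reused,
# e.g. when solving the same problem with several solver configurations. `mymdp` is a
# hypothetical MDP; SparseTabularMDP(mymdp) is the same conversion used by the generic
# solve method below.
#=
smdp = SparseTabularMDP(mymdp)
p1 = solve(SparseValueIterationSolver(belres=1e-3), smdp)
p2 = solve(SparseValueIterationSolver(belres=1e-6), smdp)
=#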
function solve(solver::SparseValueIterationSolver, mdp::MDP)
    # Convert to the sparse tabular representation, solve, then rebuild the policy
    # around the original MDP so that its actions are returned rather than indices.
    p = solve(solver, SparseTabularMDP(mdp))
    return ValueIterationPolicy(p.qmat, p.util, p.policy, ordered_actions(mdp), p.include_Q, mdp)
end
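
# A sketch of querying the returned policy through the standard POMDPs.jl interface.
# `mdp` and `s` are placeholders for a user's MDP and one of its states.
#=
policy = solve(SparseValueIterationSolver(), mdp)
a = action(policy, s)   # greedy action at state s
v = value(policy, s)    # value estimate at state s
=#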
function solve(::SparseValueIterationSolver, ::POMDP)
    throw("""
          ValueIterationError: `solve(::SparseValueIterationSolver, ::POMDP)` is not supported.
          `SparseValueIterationSolver` supports MDP models only; see QMDP.jl for a POMDP solver that assumes full observability.
          If you still wish to use the transition and reward functions of your POMDP model, you can wrap it with `UnderlyingMDP` from POMDPModelTools.jl:
          ```
          solver = SparseValueIterationSolver()
          mdp = UnderlyingMDP(pomdp)
          solve(solver, mdp)
          ```
          """)
end