Commit
Merge pull request #3 from JuliaPOMDP/initstates
implemented initial distributions and terminal states (#2)
zsunberg authored Jun 8, 2019
2 parents 546396d + 0ee69b4 commit 779be37
Showing 4 changed files with 45 additions and 25 deletions.
4 changes: 3 additions & 1 deletion Project.toml
@@ -10,12 +10,14 @@ POMDPTesting = "92e6a534-49c2-5324-9027-86e3c861ab81"
POMDPs = "a93abf59-7444-517b-a68a-c42f96afdd7d"

[compat]
POMDPModelTools = ">=0.1.6"
julia = "1"

[extras]
POMDPPolicies = "182e52fb-cfd0-5e46-8c26-fd0667c990f4"
POMDPSimulators = "e0d0a172-29c6-5d4e-96d0-f262df5d01fd"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

[targets]
test = ["Test", "POMDPPolicies", "POMDPSimulators"]
test = ["Test", "POMDPPolicies", "POMDPSimulators", "Random"]
50 changes: 27 additions & 23 deletions src/discrete_explicit.jl
@@ -1,4 +1,4 @@
struct DiscreteExplicitPOMDP{S,A,O,OF,RF} <: POMDP{S,A,O}
struct DiscreteExplicitPOMDP{S,A,O,OF,RF,D} <: POMDP{S,A,O}
s::Vector{S}
a::Vector{A}
o::Vector{O}
@@ -10,16 +10,20 @@ struct DiscreteExplicitPOMDP{S,A,O,OF,RF} <: POMDP{S,A,O}
amap::Dict{A,Int}
omap::Dict{O,Int}
discount::Float64
initial::D
terminals::Set{S}
end

struct DiscreteExplicitMDP{S,A,RF} <: MDP{S,A}
struct DiscreteExplicitMDP{S,A,RF,D} <: MDP{S,A}
s::Vector{S}
a::Vector{A}
tds::Dict{Tuple{S,A}, SparseCat{Vector{S}, Vector{Float64}}}
r::RF
smap::Dict{S,Int}
amap::Dict{A,Int}
discount::Float64
initial::D
terminals::Set{S}
end

const DEP = DiscreteExplicitPOMDP
@@ -42,38 +46,35 @@ POMDPs.transition(m::DE, s, a) = m.tds[s,a]
POMDPs.observation(m::DEP, a, sp) = m.ods[a,sp]
POMDPs.reward(m::DE, s, a) = m.r(s, a)

POMDPs.initialstate_distribution(m::DEP) = uniform_belief(m)
# XXX hack
POMDPs.initialstate_distribution(m::DiscreteExplicitMDP) = uniform_belief(FullyObservablePOMDP(m))
POMDPs.initialstate_distribution(m::DE) = m.initial

POMDPs.isterminal(m::DE,s) = s in m.terminals

POMDPModelTools.ordered_states(m::DE) = m.s
POMDPModelTools.ordered_actions(m::DE) = m.a
POMDPModelTools.ordered_observations(m::DEP) = m.o

# TODO reward(m, s, a)
# TODO support O(s, a, sp, o)
# TODO initial state distribution
# TODO convert_s, etc, dimensions
# TODO better errors if T or Z return something unexpected

"""
DiscreteExplicitPOMDP(S,A,O,T,Z,R,γ)
DiscreteExplicitPOMDP(S,A,O,T,Z,R,γ,[b₀],[terminal=Set()])
Create a POMDP defined by the tuple (S,A,O,T,Z,R,γ).
# Arguments
## Required
- `S`,`A`,`O`: State, action, and observation spaces (typically `Vector`s)
- `T::Function`: Transition probability distribution function; ``T(s,a,s')`` is the probability of transitioning to state ``s'`` from state ``s`` after taking action ``a``.
- `Z::Function`: Observation probability distribution function; ``O(a, s', o)`` is the probability of receiving observation ``o`` when state ``s'`` is reached after action ``a``.
- `R::Function`: Reward function; ``R(s,a)`` is the reward for taking action ``a`` in state ``s``.
- `γ::Float64`: Discount factor.
# Notes
- The default initial state distribution is uniform across all states. Changing this is not yet supported, but it can be overridden for simulations.
- Terminal states are not yet supported, but absorbing states with zero reward can be used.
## Optional
- `b₀=Uniform(S)`: Initial belief/state distribution (See `POMDPModelTools.Deterministic` and `POMDPModelTools.SparseCat` for other options).
## Keyword
- `terminals=Set()`: Set of terminal states. Once a terminal state is reached, no more actions can be taken or reward received.
"""
function DiscreteExplicitPOMDP(s, a, o, t, z, r, discount)
function DiscreteExplicitPOMDP(s, a, o, t, z, r, discount, b0=Uniform(s); terminals=Set())
ss = vec(collect(s))
as = vec(collect(a))
os = vec(collect(o))
@@ -107,7 +108,7 @@ function DiscreteExplicitPOMDP(s, a, o, t, z, r, discount)
Dict(ss[i]=>i for i in 1:length(ss)),
Dict(as[i]=>i for i in 1:length(as)),
Dict(os[i]=>i for i in 1:length(os)),
discount
discount, b0, convert(Set{eltype(ss)}, terminals)
)

probability_check(m)
@@ -116,22 +117,25 @@ function DiscreteExplicitPOMDP(s, a, o, t, z, r, discount)
end
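For reference, a minimal usage sketch of the new constructor arguments (not part of this commit). The tiger-style S, A, O, T, Z, and R below mirror the toy problem used in the tests; the specific initial belief and terminal set are chosen only for illustration:

using QuickPOMDPs
using POMDPModelTools: Deterministic, Uniform

S = [:left, :right]            # tiger-style toy POMDP
A = [:left, :right, :listen]
O = [:left, :right]
γ = 0.95

T(s, a, sp) = a == :listen ? float(s == sp) : 0.5            # listening keeps the state; opening a door resets
Z(a, sp, o) = a == :listen ? (o == sp ? 0.85 : 0.15) : 0.5   # observations are informative only after listening
R(s, a)     = a == :listen ? -1.0 : (s == a ? -100.0 : 10.0)

# New in this commit: an optional initial belief b₀ and a terminals keyword.
m = DiscreteExplicitPOMDP(S, A, O, T, Z, R, γ,
                          Deterministic(:left);   # b₀; defaults to Uniform(S)
                          terminals=Set([:left])) # illustrative terminal set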

"""
DiscreteExplicitMDP(S,A,T,R,γ)
DiscreteExplicitMDP(S,A,T,R,γ,[p₀])
Create an MDP defined by the tuple (S,A,T,R,γ).
# Arguments
## Required
- `S`,`A`: State and action spaces (typically `Vector`s)
- `T::Function`: Transition probability distribution function; ``T(s,a,s')`` is the probability of transitioning to state ``s'`` from state ``s`` after taking action ``a``.
- `R::Function`: Reward function; ``R(s,a)`` is the reward for taking action ``a`` in state ``s``.
- `γ::Float64`: Discount factor.
# Notes
- The default initial state distribution is uniform across all states. Changing this is not yet supported, but it can be overridden for simulations.
- Terminal states are not yet supported, but absorbing states with zero reward can be used.
## Optional
- `p₀=Uniform(S)`: Initial state distribution (See `POMDPModelTools.Deterministic` and `POMDPModelTools.SparseCat` for other options).
## Keyword
- `terminals=Set()`: Set of terminal states. Once a terminal state is reached, no more actions can be taken or reward received.
"""
function DiscreteExplicitMDP(s, a, t, r, discount)
function DiscreteExplicitMDP(s, a, t, r, discount, p0=Uniform(s); terminals=Set())
ss = vec(collect(s))
as = vec(collect(a))

@@ -141,7 +145,7 @@ function DiscreteExplicitMDP(s, a, t, r, discount)
ss, as, tds, r,
Dict(ss[i]=>i for i in 1:length(ss)),
Dict(as[i]=>i for i in 1:length(as)),
discount
discount, p0, convert(Set{eltype(ss)}, terminals)
)

trans_prob_consistency_check(m)
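Similarly, a hedged sketch of the new DiscreteExplicitMDP arguments (not part of this commit). The chain dynamics mirror the test file below; the reward shown here is only a placeholder:

using QuickPOMDPs
using POMDPModelTools: Deterministic

S = 1:5                 # five-state chain, as in the tests
A = [-1, 1]             # step left or right
γ = 0.95

T(s, a, sp) = sp == clamp(s + a, 1, 5) ? 1.0 : 0.0   # deterministic bounded moves
R(s, a)     = s == 5 ? 0.0 : -1.0                    # placeholder cost-per-step reward

m = DiscreteExplicitMDP(S, A, T, R, γ,
                        Deterministic(1);   # p₀: always start in state 1
                        terminals=Set(5))   # reaching state 5 ends the episode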
9 changes: 9 additions & 0 deletions test/discrete_explicit.jl
@@ -47,12 +47,18 @@
end
println("Undiscounted reward was $rsum.")
@test rsum == -10.0

dm = DiscreteExplicitPOMDP(S,A,O,T,Z,R,γ,Deterministic(:left))
@test initialstate(dm, Random.GLOBAL_RNG) == :left
tm = DiscreteExplicitPOMDP(S,A,O,T,Z,R,γ,terminals=Set(S))
@test isterminal(tm, initialstate(tm, Random.GLOBAL_RNG))
end

@testset "Discrete Explicit MDP" begin
S = 1:5
A = [-1, 1]
γ = 0.95
p₀ = Deterministic(1)

function T(s, a, sp)
if sp == clamp(s+a,1,5)
@@ -73,6 +79,9 @@ end
end

m = DiscreteExplicitMDP(S,A,T,R,γ)
m = DiscreteExplicitMDP(S,A,T,R,γ,p₀)
m = DiscreteExplicitMDP(S,A,T,R,γ,p₀,terminals=Set(5))
@test isterminal(m, 5)

solver = FunctionSolver(x->1)
policy = solve(solver, m)
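For orientation, a sketch (not in this commit) of how the new terminal set interacts with simulation: a rollout stops once `isterminal` returns true, so with `terminals=Set(5)` an episode ends when state 5 is reached. `RolloutSimulator` and `simulate` come from POMDPSimulators and POMDPs, which the tests already use; `max_steps=100` is an arbitrary cap:

using POMDPs: simulate
using POMDPSimulators: RolloutSimulator

r = simulate(RolloutSimulator(max_steps=100), m, policy)   # discounted reward of one rollout
# The rollout terminates early once the sampled trajectory reaches state 5.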
7 changes: 6 additions & 1 deletion test/runtests.jl
@@ -1,6 +1,11 @@
using QuickPOMDPs
using Test

using POMDPs, POMDPPolicies, POMDPSimulators, BeliefUpdaters
using POMDPs
using POMDPPolicies
using POMDPSimulators
using BeliefUpdaters
using POMDPModelTools
using Random

include("discrete_explicit.jl")
