-
Notifications
You must be signed in to change notification settings - Fork 105
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'master' of https://github.com/JuliaPOMDP/POMDPs.jl
- Loading branch information
Showing
8 changed files
with
148 additions
and
41 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
gbmdp_handle_terminal(pomdp, updater, b, s, a, rng) = nothing |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,63 +1,99 @@ | ||
""" | ||
GenerativeBeliefMDP(pomdp, updater) | ||
GenerativeBeliefMDP(pomdp, updater; terminal_behavior=TerminalStateTerminalBehavior()) | ||
<<<<<<< Updated upstream | ||
Create a generative model of the belief MDP corresponding to POMDP `pomdp` with belief updates performed by `updater`. | ||
======= | ||
Create a generative model of the belief MDP corresponding to POMDP `pomdp` with belief updates performed by `updater`. Each step is performed by sampling a state from the current belief, generating an observation from that state and action, and then using `updater` to update the belief. | ||
A belief is considered terminal when _all_ POMDP states in the support with nonzero probability are terminal. | ||
The default behavior when a terminal POMDP state is sampled from the belief is to transition to [`terminalstate`](@ref). This can be controlled by the `terminal_behavior` keyword argument. Using `terminal_behavior=ContinueTerminalBehavior(pomdp, updater)` will cause the MDP to keep attempting a belief update even when the sampled state is terminal. This can be further customized by providing `terminal_behavior` with a `Function` or callable object that takes arguments `b, s, a, rng` and returns a new belief (see the implementation of `ContinueTerminalBehavior` for an example). `determine_gbmdp_state_type` can be used to further customize behavior. | ||
>>>>>>> Stashed changes | ||
The default behavior when a terminal POMDP state is sampled from the belief is to transition to [`terminalstate`](@ref). This can be controlled by the `terminal_behavior` keyword argument. Using `terminal_behavior=ContinueTerminalBehavior(pomdp, updater)` will cause the MDP to keep attempting a belief update even when the sampled state is terminal. This can be further customized by providing `terminal_behavior` with a `Function` or callable object that takes arguments b, s, a, rng and returns a new belief (see the implementation of `ContinueTerminalBehavior` for an example). | ||
""" | ||
struct GenerativeBeliefMDP{P<:POMDP, U<:Updater, B, A} <: MDP{B, A} | ||
struct GenerativeBeliefMDP{P<:POMDP, U<:Updater, T, B, A} <: MDP{B, A} | ||
pomdp::P | ||
updater::U | ||
terminal_behavior::T | ||
end | ||
|
||
function GenerativeBeliefMDP(pomdp::P, up::U) where {P<:POMDP, U<:Updater} | ||
# XXX hack to determine belief type | ||
b0 = initialize_belief(up, initialstate(pomdp)) | ||
GenerativeBeliefMDP{P, U, typeof(b0), actiontype(pomdp)}(pomdp, up) | ||
function GenerativeBeliefMDP(pomdp, updater; terminal_behavior=BackwardCompatibleTerminalBehavior(pomdp, updater)) | ||
B = determine_gbmdp_state_type(pomdp, updater, terminal_behavior) | ||
GenerativeBeliefMDP{typeof(pomdp), | ||
typeof(updater), | ||
typeof(terminal_behavior), | ||
B, | ||
actiontype(pomdp) | ||
}(pomdp, updater, terminal_behavior) | ||
end | ||
|
||
function initialstate(bmdp::GenerativeBeliefMDP) | ||
return Deterministic(initialize_belief(bmdp.updater, initialstate(bmdp.pomdp))) | ||
end | ||
|
||
function POMDPs.gen(bmdp::GenerativeBeliefMDP, b, a, rng::AbstractRNG) | ||
s = rand(rng, b) | ||
if isterminal(bmdp.pomdp, s) | ||
bp = gbmdp_handle_terminal(bmdp.pomdp, bmdp.updater, b, s, a, rng::AbstractRNG)::typeof(b) | ||
bp = bmdp.terminal_behavior(b, s, a, rng) | ||
return (sp=bp, r=0.0) | ||
end | ||
sp, o, r = @gen(:sp, :o, :r)(bmdp.pomdp, s, a, rng) # maybe this should have been generate_or? | ||
o, r = @gen(:o, :r)(bmdp.pomdp, s, a, rng) | ||
bp = update(bmdp.updater, b, a, o) | ||
return (sp=bp, r=r) | ||
end | ||
|
||
actions(bmdp::GenerativeBeliefMDP{P,U,B,A}, b::B) where {P,U,B,A} = actions(bmdp.pomdp, b) | ||
actions(bmdp::GenerativeBeliefMDP) = actions(bmdp.pomdp) | ||
|
||
isterminal(bmdp::GenerativeBeliefMDP, b) = all(isterminal(bmdp.pomdp, s) for s in support(b)) | ||
isterminal(bmdp::GenerativeBeliefMDP, b) = all(s -> isterminal(bmdp.pomdp, s) || pdf(b, s) == 0.0, support(b)) | ||
isterminal(bmdp::GenerativeBeliefMDP, ts::TerminalState) = true | ||
|
||
discount(bmdp::GenerativeBeliefMDP) = discount(bmdp.pomdp) | ||
|
||
# override this if you want to handle it in a special way | ||
function gbmdp_handle_terminal(pomdp::POMDP, updater::Updater, b, s, a, rng) | ||
@warn(""" | ||
Sampled a terminal state for a GenerativeBeliefMDP transition - not sure how to proceed, but will try. | ||
""" | ||
determine_gbmdp_state_type(pomdp, updater, [terminal_behavior]) | ||
This function is called to determine the state type for a GenerativeBeliefMDP. By default, it will return typeof(initialize_belief(updater, initialstate(pomdp))). | ||
See $(@__FILE__) and implement a new method of POMDPToolbox.gbmdp_handle_terminal if you want special behavior in this case. | ||
If a belief updater may use a belief type different from the output of initialize_belief, for example if the belief type can change after an update, override `determine_gbmdp_state_type(pomdp, updater)`. | ||
""", maxlog=1) | ||
sp, o, r = @gen(:sp, :o, :r)(pomdp, s, a, rng) | ||
bp = update(updater, b, a, o) | ||
return bp | ||
If the terminal behavior adds a new possible state type, override `determine_gbmdp_state_type(pomdp, updater, terminal_behavior)` to return the `Union` of the new state type and the output of `determine_gbmdp_state_type(pomdp, updater)` | ||
""" | ||
function determine_gbmdp_state_type end # for documentation | ||
|
||
function determine_gbmdp_state_type(pomdp, updater) | ||
b0 = initialize_belief(updater, initialstate(pomdp)) | ||
return typeof(b0) | ||
end | ||
|
||
function initialstate(bmdp::GenerativeBeliefMDP) | ||
return Deterministic(initialize_belief(bmdp.updater, initialstate(bmdp.pomdp))) | ||
determine_gbmdp_state_type(pomdp, updater, terminal_behavior) = determine_gbmdp_state_type(pomdp, updater) | ||
|
||
struct BackwardCompatibleTerminalBehavior{M, U} | ||
pomdp::M | ||
updater::U | ||
end | ||
|
||
# deprecated in POMDPs v0.9 | ||
function initialstate(bmdp::GenerativeBeliefMDP, rng::AbstractRNG) | ||
return initialize_belief(bmdp.updater, initialstate(bmdp.pomdp)) | ||
function (tb::BackwardCompatibleTerminalBehavior)(b, s, a, rng) | ||
|
||
# This code block is only to handle backwards compatibility for the deprecated gbmdp_handle_terminal function | ||
bp = gbmdp_handle_terminal(tb.pomdp, tb.updater, b, s, a, rng) | ||
if bp != nothing # user has implemented gbmdp_handle_terminal | ||
Base.depwarn("Using gbmdp_handle_terminal to specify terminal behavior for a GenerativeBeliefMDP is deprecated. Use the terminal_behavior keyword argument instead.", :gbmdp_handle_terminal) | ||
return bp | ||
end | ||
|
||
return TerminalStateTerminalBehavior()(b, s, a, rng) | ||
end | ||
|
||
determine_gbmdp_state_type(pomdp, updater, tb::BackwardCompatibleTerminalBehavior) = determine_gbmdp_state_type(pomdp, updater, TerminalStateTerminalBehavior()) | ||
|
||
struct ContinueTerminalBehavior{M, U} | ||
pomdp::M | ||
updater::U | ||
end | ||
|
||
function (tb::ContinueTerminalBehavior)(b, s, a, rng) | ||
o, r = @gen(:o, :r)(tb.pomdp, s, a, rng) | ||
return update(tb.updater, b, a, o) | ||
end | ||
|
||
struct TerminalStateTerminalBehavior end | ||
(tb::TerminalStateTerminalBehavior)(args...) = terminalstate | ||
determine_gbmdp_state_type(pomdp, updater, tb::TerminalStateTerminalBehavior) = promote_type(determine_gbmdp_state_type(pomdp, updater), TerminalState) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
66 changes: 62 additions & 4 deletions
66
lib/POMDPTools/test/model_tools/test_generative_belief_mdp.jl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,65 @@ | ||
let | ||
pomdp = BabyPOMDP() | ||
up = updater(pomdp) | ||
@testset "GenerativeBeliefMDP" begin | ||
@testset "Baby" begin | ||
pomdp = BabyPOMDP() | ||
up = updater(pomdp) | ||
|
||
bmdp = GenerativeBeliefMDP(pomdp, up) | ||
b = initialstate(bmdp, Random.default_rng()) | ||
bmdp = GenerativeBeliefMDP(pomdp, up) | ||
b = rand(initialstate(bmdp)) | ||
@test rand(b) isa statetype(pomdp) | ||
|
||
@test simulate(RolloutSimulator(max_steps=10), bmdp, RandomPolicy(bmdp)) <= 0 | ||
end | ||
|
||
terminal_test_m = QuickPOMDP( | ||
states = 1:2, | ||
actions = 1:2, | ||
observations = 1:2, | ||
transition = (s, a) -> Deterministic(1), | ||
observation = (a, sp) -> Deterministic(sp), | ||
reward = s -> 0.0, | ||
isterminal = ==(1), | ||
initialstate = Deterministic(2) | ||
) | ||
|
||
@testset "Terminal Default" begin | ||
up = DiscreteUpdater(terminal_test_m) | ||
bm = GenerativeBeliefMDP(terminal_test_m, up) | ||
|
||
hist = collect(stepthrough(bm, RandomPolicy(bm), "s,sp", max_steps=10)) | ||
@test length(hist) == 1 | ||
@test only(hist).s == DiscreteBelief(terminal_test_m, [0.0, 1.0]) | ||
@test only(hist).sp == DiscreteBelief(terminal_test_m, [1.0, 0.0]) | ||
@test !isterminal(bm, only(hist).s) | ||
@test isterminal(bm, only(hist).sp) | ||
end | ||
|
||
@testset "Terminal Uninformative Update" begin | ||
struct UninformativeUpdater{M} <: Updater | ||
m::M | ||
end | ||
|
||
POMDPs.update(up::UninformativeUpdater, b, a, o) = Uniform(states(up.m)) | ||
POMDPs.initialize_belief(up::UninformativeUpdater, d::Deterministic) = Uniform(rand(d)) | ||
|
||
up = UninformativeUpdater(terminal_test_m) | ||
|
||
# default terminal behavior | ||
bm = GenerativeBeliefMDP(terminal_test_m, up) | ||
hist = collect(stepthrough(bm, RandomPolicy(bm), "s,sp")) | ||
@test isterminal(bm, last(hist).sp) | ||
|
||
behavior = TerminalStateTerminalBehavior() | ||
bm = GenerativeBeliefMDP(terminal_test_m, up) | ||
hist = collect(stepthrough(bm, RandomPolicy(bm), "s,sp")) | ||
@test last(hist).sp === terminalstate | ||
@test isterminal(bm, last(hist).sp) | ||
|
||
behavior = ContinueTerminalBehavior(terminal_test_m, up) | ||
bm = GenerativeBeliefMDP(terminal_test_m, up, terminal_behavior=behavior) | ||
hist = collect(stepthrough(bm, RandomPolicy(bm), "s,sp", max_steps=10)) | ||
@test length(hist) == 10 | ||
end | ||
|
||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters