-
Notifications
You must be signed in to change notification settings - Fork 103
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
added new handling of terminal states in GenerativeBelief MDP (#559)
* added new handling of terminal states in GenerativeBelief MDP * fixed typo * worked on docs for GBMDP
- Loading branch information
Showing
5 changed files
with
140 additions
and
30 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
# Deprecated extension hook. This generic fallback returns `nothing`, meaning
# "no custom terminal handling has been supplied".
function gbmdp_handle_terminal(pomdp, updater, b, s, a, rng)
    return nothing
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,55 +1,99 @@ | ||
""" | ||
GenerativeBeliefMDP(pomdp, updater) | ||
GenerativeBeliefMDP(pomdp, updater; terminal_behavior=TerminalStateTerminalBehavior()) | ||
Create a generative model of the belief MDP corresponding to POMDP `pomdp` with belief updates performed by `updater`. | ||
Create a generative model of the belief MDP corresponding to POMDP `pomdp` with belief updates performed by `updater`. Each step is performed by sampling a state from the current belief, generating an observation from that state and action, and then using `updater` to update the belief. | ||
A belief is considered terminal when _all_ POMDP states in the support with nonzero probability are terminal. | ||
The default behavior when a terminal POMDP state is sampled from the belief is to transition to [`terminalstate`](@ref). This can be controlled by the `terminal_behavior` keyword argument. Using `terminal_behavior=ContinueTerminalBehavior(pomdp, updater)` will cause the MDP to keep attempting a belief update even when the sampled state is terminal. This can be further customized by passing a `Function` or callable object as `terminal_behavior` that takes the arguments `b, s, a, rng` and returns a new belief (see the implementation of `ContinueTerminalBehavior` for an example). | ||
""" | ||
struct GenerativeBeliefMDP{P<:POMDP, U<:Updater, T, B, A} <: MDP{B, A}
    pomdp::P              # wrapped POMDP; states are sampled from beliefs and stepped through it
    updater::U            # belief updater that produces the next belief from (b, a, o)
    terminal_behavior::T  # callable `(b, s, a, rng)` returning the next belief when `s` is terminal
end
|
||
# Keyword constructor: derives every type parameter from the arguments.
# The belief (MDP state) type comes from `determine_gbmdp_state_type`, which
# may widen it depending on the chosen terminal behavior.
function GenerativeBeliefMDP(pomdp, updater; terminal_behavior=BackwardCompatibleTerminalBehavior(pomdp, updater))
    belief_type = determine_gbmdp_state_type(pomdp, updater, terminal_behavior)
    P = typeof(pomdp)
    U = typeof(updater)
    T = typeof(terminal_behavior)
    A = actiontype(pomdp)
    return GenerativeBeliefMDP{P, U, T, belief_type, A}(pomdp, updater, terminal_behavior)
end
|
||
function GenerativeBeliefMDP(pomdp::P, up::U) where {P<:POMDP, U<:Updater} | ||
# XXX hack to determine belief type | ||
b0 = initialize_belief(up, initialstate(pomdp)) | ||
GenerativeBeliefMDP{P, U, typeof(b0), actiontype(pomdp)}(pomdp, up) | ||
# Initial state distribution of the belief MDP: a point mass on the belief
# that the updater builds from the POMDP's initial state distribution.
function initialstate(bmdp::GenerativeBeliefMDP)
    pomdp_initial = initialstate(bmdp.pomdp)
    b0 = initialize_belief(bmdp.updater, pomdp_initial)
    return Deterministic(b0)
end
|
||
# One generative step of the belief MDP.
#
# A POMDP state `s` is sampled from the belief `b`.
# - If `s` is terminal in the underlying POMDP, the next belief is whatever
#   `bmdp.terminal_behavior(b, s, a, rng)` returns and the reward is 0.0.
# - Otherwise an observation and reward are generated from `(s, a)` and the
#   belief is advanced with the updater.
#
# Returns a named tuple `(sp = next_belief, r = reward)`.
function POMDPs.gen(bmdp::GenerativeBeliefMDP, b, a, rng::AbstractRNG)
    s = rand(rng, b)
    if isterminal(bmdp.pomdp, s)
        # The configured terminal behavior alone decides the next belief. The
        # deprecated `gbmdp_handle_terminal` hook is not called here; it is only
        # reached through BackwardCompatibleTerminalBehavior.
        bp = bmdp.terminal_behavior(b, s, a, rng)
        return (sp=bp, r=0.0)
    end
    # Only the observation and reward are needed; the successor POMDP state is
    # never used, so it is not requested.
    o, r = @gen(:o, :r)(bmdp.pomdp, s, a, rng)
    bp = update(bmdp.updater, b, a, o)
    return (sp=bp, r=r)
end
|
||
# Belief-dependent action space: forward the belief to the POMDP's `actions`.
# All five type parameters {P, U, T, B, A} must be listed. With only four, `B`
# would line up with the terminal-behavior parameter and `b::B` would never
# match an actual belief.
actions(bmdp::GenerativeBeliefMDP{P,U,T,B,A}, b::B) where {P,U,T,B,A} = actions(bmdp.pomdp, b)

# Belief-independent action space of the underlying POMDP.
actions(bmdp::GenerativeBeliefMDP) = actions(bmdp.pomdp)
|
||
# A belief is terminal when every state in its support is either terminal in
# the underlying POMDP or carries zero probability under the belief.
isterminal(bmdp::GenerativeBeliefMDP, b) = all(s -> isterminal(bmdp.pomdp, s) || pdf(b, s) == 0.0, support(b))

# The dedicated terminal-state marker is always terminal.
isterminal(bmdp::GenerativeBeliefMDP, ts::TerminalState) = true
|
||
# The belief MDP discounts exactly like the POMDP it wraps.
function discount(bmdp::GenerativeBeliefMDP)
    return discount(bmdp.pomdp)
end
|
||
# override this if you want to handle it in a special way | ||
function gbmdp_handle_terminal(pomdp::POMDP, updater::Updater, b, s, a, rng) | ||
@warn(""" | ||
Sampled a terminal state for a GenerativeBeliefMDP transition - not sure how to proceed, but will try. | ||
""" | ||
determine_gbmdp_state_type(pomdp, updater, [terminal_behavior]) | ||
This function is called to determine the state type for a `GenerativeBeliefMDP`. By default, it will return `typeof(initialize_belief(updater, initialstate(pomdp)))`. | ||
If a belief updater may use a belief type different from the output of initialize_belief, for example if the belief type can change after an update, override `determine_gbmdp_state_type(pomdp, updater)`. | ||
See $(@__FILE__) and implement a new method of POMDPToolbox.gbmdp_handle_terminal if you want special behavior in this case. | ||
If the terminal behavior adds a new possible state type, override `determine_gbmdp_state_type(pomdp, updater, terminal_behavior)` to return the `Union` of the new state type and the output of `determine_gbmdp_state_type(pomdp, updater)`. | ||
""" | ||
function determine_gbmdp_state_type end # for documentation | ||
|
||
""", maxlog=1) | ||
sp, o, r = @gen(:sp, :o, :r)(pomdp, s, a, rng) | ||
bp = update(updater, b, a, o) | ||
return bp | ||
# Default belief-type inference: build the initial belief once and take its
# concrete type.
determine_gbmdp_state_type(pomdp, updater) =
    typeof(initialize_belief(updater, initialstate(pomdp)))
|
||
function initialstate(bmdp::GenerativeBeliefMDP) | ||
return Deterministic(initialize_belief(bmdp.updater, initialstate(bmdp.pomdp))) | ||
# Generic three-argument fallback: the terminal behavior is ignored and the
# state type is whatever the two-argument method reports.
function determine_gbmdp_state_type(pomdp, updater, terminal_behavior)
    return determine_gbmdp_state_type(pomdp, updater)
end
|
||
# Terminal behavior that keeps the deprecated `gbmdp_handle_terminal` hook
# working. It stores the model and updater so they can be passed to that hook.
struct BackwardCompatibleTerminalBehavior{M, U}
    pomdp::M    # model forwarded to the deprecated hook
    updater::U  # updater forwarded to the deprecated hook
end
|
||
# Called when a terminal POMDP state `s` was sampled from belief `b`.
#
# If a user-defined method of the deprecated `gbmdp_handle_terminal` returns a
# value, a deprecation warning is emitted and that value is returned as the
# next belief. Otherwise the call falls through to TerminalStateTerminalBehavior.
function (tb::BackwardCompatibleTerminalBehavior)(b, s, a, rng)

    # This code block is only to handle backwards compatibility for the deprecated gbmdp_handle_terminal function
    bp = gbmdp_handle_terminal(tb.pomdp, tb.updater, b, s, a, rng)
    # Identity comparison: `!=` would dispatch to a belief type's own `==`,
    # which may be expensive or may not return a Bool.
    if bp !== nothing # user has implemented gbmdp_handle_terminal
        Base.depwarn("Using gbmdp_handle_terminal to specify terminal behavior for a GenerativeBeliefMDP is deprecated. Use the terminal_behavior keyword argument instead.", :gbmdp_handle_terminal)
        return bp
    end

    return TerminalStateTerminalBehavior()(b, s, a, rng)
end
|
||
# The backward-compatible behavior reports the same state type as
# TerminalStateTerminalBehavior.
function determine_gbmdp_state_type(pomdp, updater, tb::BackwardCompatibleTerminalBehavior)
    return determine_gbmdp_state_type(pomdp, updater, TerminalStateTerminalBehavior())
end
|
||
# Terminal behavior that keeps performing belief updates even when the sampled
# POMDP state is terminal.
struct ContinueTerminalBehavior{M, U}
    pomdp::M    # model used to generate the observation
    updater::U  # updater applied to that observation
end
|
||
# deprecated in POMDPs v0.9 | ||
function initialstate(bmdp::GenerativeBeliefMDP, rng::AbstractRNG) | ||
return initialize_belief(bmdp.updater, initialstate(bmdp.pomdp)) | ||
# Generate an observation (and a reward, which is discarded) from the terminal
# state `s` and action `a`, then return the belief updated with it.
function (tb::ContinueTerminalBehavior)(b, s, a, rng)
    obs, _ = @gen(:o, :r)(tb.pomdp, s, a, rng)
    return update(tb.updater, b, a, obs)
end
|
||
# Terminal behavior that sends the belief MDP to the `terminalstate` marker.
struct TerminalStateTerminalBehavior end

# All arguments are ignored; the result is always `terminalstate`.
(tb::TerminalStateTerminalBehavior)(args...) = terminalstate

# State type is the promotion of the ordinary belief type with `TerminalState`.
# NOTE(review): presumably a promotion rule for `TerminalState` is defined
# elsewhere so that this yields a `Union` - confirm.
function determine_gbmdp_state_type(pomdp, updater, tb::TerminalStateTerminalBehavior)
    belief_type = determine_gbmdp_state_type(pomdp, updater)
    return promote_type(belief_type, TerminalState)
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
66 changes: 62 additions & 4 deletions
66
lib/POMDPTools/test/model_tools/test_generative_belief_mdp.jl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,65 @@ | ||
let | ||
pomdp = BabyPOMDP() | ||
up = updater(pomdp) | ||
@testset "GenerativeBeliefMDP" begin | ||
@testset "Baby" begin
    pomdp = BabyPOMDP()
    up = updater(pomdp)

    bmdp = GenerativeBeliefMDP(pomdp, up)
    # `initialstate(bmdp)` is a distribution over beliefs; sample one belief,
    # then check that sampling from it yields a POMDP state.
    b = rand(initialstate(bmdp))
    @test rand(b) isa statetype(pomdp)

    @test simulate(RolloutSimulator(max_steps=10), bmdp, RandomPolicy(bmdp)) <= 0
end
|
||
# Two-state model for the terminal-handling tests: it starts in state 2, every
# transition leads to state 1, and state 1 is terminal. The observation is the
# successor state itself.
terminal_test_m = QuickPOMDP(
    states = 1:2,
    actions = 1:2,
    observations = 1:2,
    transition = (state, action) -> Deterministic(1),
    observation = (action, next_state) -> Deterministic(next_state),
    reward = state -> 0.0,
    isterminal = state -> state == 1,
    initialstate = Deterministic(2)
)
|
||
@testset "Terminal Default" begin
    up = DiscreteUpdater(terminal_test_m)
    bm = GenerativeBeliefMDP(terminal_test_m, up)

    # Expected beliefs before and after the single step: all mass on state 2,
    # then all mass on the terminal state 1.
    expected_start = DiscreteBelief(terminal_test_m, [0.0, 1.0])
    expected_end = DiscreteBelief(terminal_test_m, [1.0, 0.0])

    hist = collect(stepthrough(bm, RandomPolicy(bm), "s,sp", max_steps=10))
    @test length(hist) == 1
    @test only(hist).s == expected_start
    @test only(hist).sp == expected_end
    @test !isterminal(bm, only(hist).s)
    @test isterminal(bm, only(hist).sp)
end
|
||
@testset "Terminal Uninformative Update" begin
    # Updater whose `update` ignores the action and observation and always
    # returns a uniform belief over all states of the model.
    struct UninformativeUpdater{M} <: Updater
        m::M
    end

    POMDPs.update(up::UninformativeUpdater, b, a, o) = Uniform(states(up.m))
    POMDPs.initialize_belief(up::UninformativeUpdater, d::Deterministic) = Uniform(rand(d))

    up = UninformativeUpdater(terminal_test_m)

    # default terminal behavior
    bm = GenerativeBeliefMDP(terminal_test_m, up)
    hist = collect(stepthrough(bm, RandomPolicy(bm), "s,sp"))
    @test isterminal(bm, last(hist).sp)

    # explicitly requested TerminalStateTerminalBehavior; it must be passed to
    # the constructor, otherwise this case only repeats the default one above
    behavior = TerminalStateTerminalBehavior()
    bm = GenerativeBeliefMDP(terminal_test_m, up, terminal_behavior=behavior)
    hist = collect(stepthrough(bm, RandomPolicy(bm), "s,sp"))
    @test last(hist).sp === terminalstate
    @test isterminal(bm, last(hist).sp)

    # ContinueTerminalBehavior never reaches a terminal belief with this
    # updater, so the run only stops at the step limit
    behavior = ContinueTerminalBehavior(terminal_test_m, up)
    bm = GenerativeBeliefMDP(terminal_test_m, up, terminal_behavior=behavior)
    hist = collect(stepthrough(bm, RandomPolicy(bm), "s,sp", max_steps=10))
    @test length(hist) == 10
end
|
||
end | ||
end |