diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index a36c6d47..0bcbfdb3 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -1,6 +1,11 @@ name: CI -on: [push, pull_request, workflow_dispatch] +on: + push: + pull_request: + workflow_dispatch: + schedule: + - cron: '30 23 * * 0' # Runs every Sunday at 11:30 PM jobs: test: diff --git a/docs/src/example_defining_problems.md b/docs/src/example_defining_problems.md index 85eb3019..e7f2d173 100644 --- a/docs/src/example_defining_problems.md +++ b/docs/src/example_defining_problems.md @@ -106,7 +106,7 @@ quick_crying_baby_pomdp = QuickPOMDP( using POMDPs using POMDPTools -struct CryingBabyState +struct CryingBabyState # Alternatively, you could just use a Bool or Symbol for the state. hungry::Bool end @@ -208,11 +208,11 @@ using POMDPs using POMDPTools using Random -struct GenCryingBabyState +struct GenCryingBabyState # Alternatively, you could just use a Bool or Symbol for the state. hungry::Bool end -struct GenCryingBabyPOMDP <: POMDP{CryingBabyState, Symbol, Symbol} +struct GenCryingBabyPOMDP <: POMDP{GenCryingBabyState, Symbol, Symbol} p_sated_to_hungry::Float64 p_cry_feed_hungry::Float64 p_cry_sing_hungry::Float64 @@ -228,7 +228,7 @@ struct GenCryingBabyPOMDP <: POMDP{CryingBabyState, Symbol, Symbol} GenCryingBabyPOMDP() = new(0.1, 0.8, 0.9, 0.8, 0.1, 0.0, 0.1, -10.0, -5.0, -0.5, 0.9) end -function POMDPs.gen(pomdp::GenCryingBabyPOMDP, s::CryingBabyState, a::Symbol, rng::AbstractRNG) +function POMDPs.gen(pomdp::GenCryingBabyPOMDP, s::GenCryingBabyState, a::Symbol, rng::AbstractRNG) if a == :feed sp = GenCryingBabyState(false) @@ -311,4 +311,4 @@ R = [-5.0 -0.5 0.0; discount = 0.9 tabular_crying_baby_pomdp = TabularPOMDP(T, R, O, discount) -``` \ No newline at end of file +``` diff --git a/lib/POMDPTools/src/ModelTools/ModelTools.jl b/lib/POMDPTools/src/ModelTools/ModelTools.jl index f8d6340d..2524f352 100644 --- a/lib/POMDPTools/src/ModelTools/ModelTools.jl +++ 
b/lib/POMDPTools/src/ModelTools/ModelTools.jl @@ -40,7 +40,10 @@ export terminalstate include("terminal_state.jl") -export GenerativeBeliefMDP +export GenerativeBeliefMDP, + DefaultGBMDPTerminalBehavior, + ContinueTerminalBehavior, + TerminalStateTerminalBehavior include("generative_belief_mdp.jl") export FullyObservablePOMDP @@ -78,4 +81,8 @@ export reward_vectors include("matrices.jl") +export + gbmdp_handle_terminal +include("deprecated.jl") + end diff --git a/lib/POMDPTools/src/ModelTools/deprecated.jl b/lib/POMDPTools/src/ModelTools/deprecated.jl new file mode 100644 index 00000000..f2d08480 --- /dev/null +++ b/lib/POMDPTools/src/ModelTools/deprecated.jl @@ -0,0 +1 @@ +gbmdp_handle_terminal(pomdp, updater, b, s, a, rng) = nothing diff --git a/lib/POMDPTools/src/ModelTools/generative_belief_mdp.jl b/lib/POMDPTools/src/ModelTools/generative_belief_mdp.jl index 89c5af4a..f610781e 100644 --- a/lib/POMDPTools/src/ModelTools/generative_belief_mdp.jl +++ b/lib/POMDPTools/src/ModelTools/generative_belief_mdp.jl @@ -1,34 +1,40 @@ """ GenerativeBeliefMDP(pomdp, updater) + GenerativeBeliefMDP(pomdp, updater; terminal_behavior=TerminalStateTerminalBehavior()) -<<<<<<< Updated upstream -Create a generative model of the belief MDP corresponding to POMDP `pomdp` with belief updates performed by `updater`. -======= Create a generative model of the belief MDP corresponding to POMDP `pomdp` with belief updates performed by `updater`. Each step is performed by sampling a state from the current belief, generating an observation from that state and action, and then using `updater` to update the belief. A belief is considered terminal when _all_ POMDP states in the support with nonzero probability are terminal. -The default behavior when a terminal POMDP state is sampled from the belief is to transition to [`terminalstate`](@ref). This can be controlled by the `terminal_behavior` keyword argument. 
Using `terminal_behavior=ContinueTerminalBehavior(pomdp, updater)` will cause the MDP to keep attempting a belief update even when the sampled state is terminal. This can be further customized by providing `terminal_behavior` with a `Function` or callable object that takes arguments `b, s, a, rng` and returns a new belief (see the implementation of `ContinueTerminalBehavior` for an example). `determine_gbmdp_state_type` can be used to further customize behavior. ->>>>>>> Stashed changes +The default behavior when a terminal POMDP state is sampled from the belief is to transition to [`terminalstate`](@ref). This can be controlled by the `terminal_behavior` keyword argument. Using `terminal_behavior=ContinueTerminalBehavior(pomdp, updater)` will cause the MDP to keep attempting a belief update even when the sampled state is terminal. This can be further customized by providing `terminal_behavior` with a `Function` or callable object that takes arguments `b, s, a, rng` and returns a new belief (see the implementation of `ContinueTerminalBehavior` for an example). 
""" -struct GenerativeBeliefMDP{P<:POMDP, U<:Updater, B, A} <: MDP{B, A} +struct GenerativeBeliefMDP{P<:POMDP, U<:Updater, T, B, A} <: MDP{B, A} pomdp::P updater::U + terminal_behavior::T end -function GenerativeBeliefMDP(pomdp::P, up::U) where {P<:POMDP, U<:Updater} - # XXX hack to determine belief type - b0 = initialize_belief(up, initialstate(pomdp)) - GenerativeBeliefMDP{P, U, typeof(b0), actiontype(pomdp)}(pomdp, up) +function GenerativeBeliefMDP(pomdp, updater; terminal_behavior=BackwardCompatibleTerminalBehavior(pomdp, updater)) + B = determine_gbmdp_state_type(pomdp, updater, terminal_behavior) + GenerativeBeliefMDP{typeof(pomdp), + typeof(updater), + typeof(terminal_behavior), + B, + actiontype(pomdp) + }(pomdp, updater, terminal_behavior) +end + +function initialstate(bmdp::GenerativeBeliefMDP) + return Deterministic(initialize_belief(bmdp.updater, initialstate(bmdp.pomdp))) end function POMDPs.gen(bmdp::GenerativeBeliefMDP, b, a, rng::AbstractRNG) s = rand(rng, b) if isterminal(bmdp.pomdp, s) - bp = gbmdp_handle_terminal(bmdp.pomdp, bmdp.updater, b, s, a, rng::AbstractRNG)::typeof(b) + bp = bmdp.terminal_behavior(b, s, a, rng) return (sp=bp, r=0.0) end - sp, o, r = @gen(:sp, :o, :r)(bmdp.pomdp, s, a, rng) # maybe this should have been generate_or? 
+ o, r = @gen(:o, :r)(bmdp.pomdp, s, a, rng) bp = update(bmdp.updater, b, a, o) return (sp=bp, r=r) end @@ -36,28 +42,58 @@ end actions(bmdp::GenerativeBeliefMDP{P,U,B,A}, b::B) where {P,U,B,A} = actions(bmdp.pomdp, b) actions(bmdp::GenerativeBeliefMDP) = actions(bmdp.pomdp) -isterminal(bmdp::GenerativeBeliefMDP, b) = all(isterminal(bmdp.pomdp, s) for s in support(b)) +isterminal(bmdp::GenerativeBeliefMDP, b) = all(s -> isterminal(bmdp.pomdp, s) || pdf(b, s) == 0.0, support(b)) +isterminal(bmdp::GenerativeBeliefMDP, ts::TerminalState) = true discount(bmdp::GenerativeBeliefMDP) = discount(bmdp.pomdp) -# override this if you want to handle it in a special way -function gbmdp_handle_terminal(pomdp::POMDP, updater::Updater, b, s, a, rng) - @warn(""" - Sampled a terminal state for a GenerativeBeliefMDP transition - not sure how to proceed, but will try. +""" + determine_gbmdp_state_type(pomdp, updater, [terminal_behavior]) + +This function is called to determine the state type for a GenerativeBeliefMDP. By default, it will return typeof(initialize_belief(updater, initialstate(pomdp))). - See $(@__FILE__) and implement a new method of POMDPToolbox.gbmdp_handle_terminal if you want special behavior in this case. +If a belief updater may use a belief type different from the output of initialize_belief, for example if the belief type can change after an update, override `determine_gbmdp_state_type(pomdp, updater)`. 
- """, maxlog=1) - sp, o, r = @gen(:sp, :o, :r)(pomdp, s, a, rng) - bp = update(updater, b, a, o) - return bp +If the terminal behavior adds a new possible state type, override `determine_gbmdp_state_type(pomdp, updater, terminal_behavior)` to return the `Union` of the new state type and the output of `determine_gbmdp_state_type(pomdp, updater)` +""" +function determine_gbmdp_state_type end # for documentation + +function determine_gbmdp_state_type(pomdp, updater) + b0 = initialize_belief(updater, initialstate(pomdp)) + return typeof(b0) end -function initialstate(bmdp::GenerativeBeliefMDP) - return Deterministic(initialize_belief(bmdp.updater, initialstate(bmdp.pomdp))) +determine_gbmdp_state_type(pomdp, updater, terminal_behavior) = determine_gbmdp_state_type(pomdp, updater) + +struct BackwardCompatibleTerminalBehavior{M, U} + pomdp::M + updater::U end -# deprecated in POMDPs v0.9 -function initialstate(bmdp::GenerativeBeliefMDP, rng::AbstractRNG) - return initialize_belief(bmdp.updater, initialstate(bmdp.pomdp)) +function (tb::BackwardCompatibleTerminalBehavior)(b, s, a, rng) + + # This code block is only to handle backwards compatibility for the deprecated gbmdp_handle_terminal function + bp = gbmdp_handle_terminal(tb.pomdp, tb.updater, b, s, a, rng) + if bp != nothing # user has implemented gbmdp_handle_terminal + Base.depwarn("Using gbmdp_handle_terminal to specify terminal behavior for a GenerativeBeliefMDP is deprecated. 
Use the terminal_behavior keyword argument instead.", :gbmdp_handle_terminal) + return bp + end + + return TerminalStateTerminalBehavior()(b, s, a, rng) end + +determine_gbmdp_state_type(pomdp, updater, tb::BackwardCompatibleTerminalBehavior) = determine_gbmdp_state_type(pomdp, updater, TerminalStateTerminalBehavior()) + +struct ContinueTerminalBehavior{M, U} + pomdp::M + updater::U +end + +function (tb::ContinueTerminalBehavior)(b, s, a, rng) + o, r = @gen(:o, :r)(tb.pomdp, s, a, rng) + return update(tb.updater, b, a, o) +end + +struct TerminalStateTerminalBehavior end +(tb::TerminalStateTerminalBehavior)(args...) = terminalstate +determine_gbmdp_state_type(pomdp, updater, tb::TerminalStateTerminalBehavior) = promote_type(determine_gbmdp_state_type(pomdp, updater), TerminalState) diff --git a/lib/POMDPTools/src/ModelTools/terminal_state.jl b/lib/POMDPTools/src/ModelTools/terminal_state.jl index 49e044ac..6113d956 100644 --- a/lib/POMDPTools/src/ModelTools/terminal_state.jl +++ b/lib/POMDPTools/src/ModelTools/terminal_state.jl @@ -12,9 +12,9 @@ struct TerminalState end """ terminalstate -The singleton instance of type `TerminalState` representing a terminal state. +The singleton instance of type [`TerminalState`](@ref) representing a terminal state. 
""" const terminalstate = TerminalState() -isterminal(m::Union{MDP,POMDP}, ts::TerminalState) = true +isterminal(m::Union{MDP,POMDP}, ts::TerminalState) = true Base.promote_rule(::Type{TerminalState}, T::Type) = Union{TerminalState, T} diff --git a/lib/POMDPTools/test/model_tools/test_generative_belief_mdp.jl b/lib/POMDPTools/test/model_tools/test_generative_belief_mdp.jl index d02ce0c0..464f3a14 100644 --- a/lib/POMDPTools/test/model_tools/test_generative_belief_mdp.jl +++ b/lib/POMDPTools/test/model_tools/test_generative_belief_mdp.jl @@ -1,7 +1,65 @@ let - pomdp = BabyPOMDP() - up = updater(pomdp) + @testset "GenerativeBeliefMDP" begin + @testset "Baby" begin + pomdp = BabyPOMDP() + up = updater(pomdp) - bmdp = GenerativeBeliefMDP(pomdp, up) - b = initialstate(bmdp, Random.default_rng()) + bmdp = GenerativeBeliefMDP(pomdp, up) + b = rand(initialstate(bmdp)) + @test rand(b) isa statetype(pomdp) + + @test simulate(RolloutSimulator(max_steps=10), bmdp, RandomPolicy(bmdp)) <= 0 + end + + terminal_test_m = QuickPOMDP( + states = 1:2, + actions = 1:2, + observations = 1:2, + transition = (s, a) -> Deterministic(1), + observation = (a, sp) -> Deterministic(sp), + reward = s -> 0.0, + isterminal = ==(1), + initialstate = Deterministic(2) + ) + + @testset "Terminal Default" begin + up = DiscreteUpdater(terminal_test_m) + bm = GenerativeBeliefMDP(terminal_test_m, up) + + hist = collect(stepthrough(bm, RandomPolicy(bm), "s,sp", max_steps=10)) + @test length(hist) == 1 + @test only(hist).s == DiscreteBelief(terminal_test_m, [0.0, 1.0]) + @test only(hist).sp == DiscreteBelief(terminal_test_m, [1.0, 0.0]) + @test !isterminal(bm, only(hist).s) + @test isterminal(bm, only(hist).sp) + end + + @testset "Terminal Uninformative Update" begin + struct UninformativeUpdater{M} <: Updater + m::M + end + + POMDPs.update(up::UninformativeUpdater, b, a, o) = Uniform(states(up.m)) + POMDPs.initialize_belief(up::UninformativeUpdater, d::Deterministic) = Uniform(rand(d)) + + up = 
UninformativeUpdater(terminal_test_m) + + # default terminal behavior + bm = GenerativeBeliefMDP(terminal_test_m, up) + hist = collect(stepthrough(bm, RandomPolicy(bm), "s,sp")) + @test isterminal(bm, last(hist).sp) + + behavior = TerminalStateTerminalBehavior() + bm = GenerativeBeliefMDP(terminal_test_m, up) + hist = collect(stepthrough(bm, RandomPolicy(bm), "s,sp")) + @test last(hist).sp === terminalstate + @test isterminal(bm, last(hist).sp) + + behavior = ContinueTerminalBehavior(terminal_test_m, up) + bm = GenerativeBeliefMDP(terminal_test_m, up, terminal_behavior=behavior) + hist = collect(stepthrough(bm, RandomPolicy(bm), "s,sp", max_steps=10)) + @test length(hist) == 10 + end + + end end diff --git a/src/simulator.jl b/src/simulator.jl index 1fb4a440..bb9248e5 100644 --- a/src/simulator.jl +++ b/src/simulator.jl @@ -9,6 +9,6 @@ abstract type Simulator end Run a simulation using the specified policy. -The return type is flexible and depends on the simulator. Simulations should adhere to the [Simulation Standard](http://juliapomdp.github.io/POMDPs.jl/latest/simulation.html#Simulation-Standard-1). +The return type is flexible and depends on the simulator. Simulations should adhere to the [Simulation Standard](https://juliapomdp.github.io/POMDPs.jl/stable/simulation/#Simulation-Standard). """ function simulate end