Skip to content

Commit

Permalink
Merge branch 'master' of https://github.com/JuliaPOMDP/POMDPs.jl
Browse files Browse the repository at this point in the history
  • Loading branch information
zsunberg committed Jul 12, 2024
2 parents cf62e5b + d18d9d5 commit cbb9c89
Show file tree
Hide file tree
Showing 8 changed files with 148 additions and 41 deletions.
7 changes: 6 additions & 1 deletion .github/workflows/CI.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
name: CI

on: [push, pull_request, workflow_dispatch]
on:
push:
pull_request:
workflow_dispatch:
schedule:
- cron: '30 23 * * 0' # Runs every Sunday at 23:30 UTC (GitHub Actions cron schedules are evaluated in UTC)

jobs:
test:
Expand Down
10 changes: 5 additions & 5 deletions docs/src/example_defining_problems.md
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ quick_crying_baby_pomdp = QuickPOMDP(
using POMDPs
using POMDPTools

struct CryingBabyState
struct CryingBabyState # Alternatively, you could just use a Bool or Symbol for the state.
hungry::Bool
end

Expand Down Expand Up @@ -208,11 +208,11 @@ using POMDPs
using POMDPTools
using Random

struct GenCryingBabyState
struct GenCryingBabyState # Alternatively, you could just use a Bool or Symbol for the state.
hungry::Bool
end

struct GenCryingBabyPOMDP <: POMDP{CryingBabyState, Symbol, Symbol}
struct GenCryingBabyPOMDP <: POMDP{GenCryingBabyState, Symbol, Symbol}
p_sated_to_hungry::Float64
p_cry_feed_hungry::Float64
p_cry_sing_hungry::Float64
Expand All @@ -228,7 +228,7 @@ struct GenCryingBabyPOMDP <: POMDP{CryingBabyState, Symbol, Symbol}
GenCryingBabyPOMDP() = new(0.1, 0.8, 0.9, 0.8, 0.1, 0.0, 0.1, -10.0, -5.0, -0.5, 0.9)
end

function POMDPs.gen(pomdp::GenCryingBabyPOMDP, s::CryingBabyState, a::Symbol, rng::AbstractRNG)
function POMDPs.gen(pomdp::GenCryingBabyPOMDP, s::GenCryingBabyState, a::Symbol, rng::AbstractRNG)

if a == :feed
sp = GenCryingBabyState(false)
Expand Down Expand Up @@ -311,4 +311,4 @@ R = [-5.0 -0.5 0.0;
discount = 0.9

tabular_crying_baby_pomdp = TabularPOMDP(T, R, O, discount)
```
```
9 changes: 8 additions & 1 deletion lib/POMDPTools/src/ModelTools/ModelTools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,10 @@ export
terminalstate
include("terminal_state.jl")

export GenerativeBeliefMDP
export GenerativeBeliefMDP,
DefaultGBMDPTerminalBehavior,
ContinueTerminalBehavior,
TerminalStateTerminalBehavior
include("generative_belief_mdp.jl")

export FullyObservablePOMDP
Expand Down Expand Up @@ -78,4 +81,8 @@ export
reward_vectors
include("matrices.jl")

export
gbmdp_handle_terminal
include("deprecated.jl")

end
1 change: 1 addition & 0 deletions lib/POMDPTools/src/ModelTools/deprecated.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Deprecated hook. This generic fallback returns `nothing` for any arguments.
# `BackwardCompatibleTerminalBehavior` calls it and treats a non-`nothing` result as
# "the user has implemented gbmdp_handle_terminal".
gbmdp_handle_terminal(pomdp, updater, b, s, a, rng) = nothing
90 changes: 63 additions & 27 deletions lib/POMDPTools/src/ModelTools/generative_belief_mdp.jl
Original file line number Diff line number Diff line change
@@ -1,63 +1,99 @@
"""
GenerativeBeliefMDP(pomdp, updater)
GenerativeBeliefMDP(pomdp, updater; terminal_behavior=TerminalStateTerminalBehavior())

Create a generative model of the belief MDP corresponding to POMDP `pomdp` with belief updates performed by `updater`. Each step is performed by sampling a state from the current belief, generating an observation from that state and action, and then using `updater` to update the belief.

A belief is considered terminal when _all_ POMDP states in the support with nonzero probability are terminal.

The default behavior when a terminal POMDP state is sampled from the belief is to transition to [`terminalstate`](@ref). This can be controlled by the `terminal_behavior` keyword argument. Using `terminal_behavior=ContinueTerminalBehavior(pomdp, updater)` will cause the MDP to keep attempting a belief update even when the sampled state is terminal. This can be further customized by providing `terminal_behavior` with a `Function` or callable object that takes arguments `b, s, a, rng` and returns a new belief (see the implementation of `ContinueTerminalBehavior` for an example). `determine_gbmdp_state_type` can be used to further customize behavior.
"""
struct GenerativeBeliefMDP{P<:POMDP, U<:Updater, B, A} <: MDP{B, A}
struct GenerativeBeliefMDP{P<:POMDP, U<:Updater, T, B, A} <: MDP{B, A}
pomdp::P
updater::U
terminal_behavior::T
end

function GenerativeBeliefMDP(pomdp::P, up::U) where {P<:POMDP, U<:Updater}
# XXX hack to determine belief type
b0 = initialize_belief(up, initialstate(pomdp))
GenerativeBeliefMDP{P, U, typeof(b0), actiontype(pomdp)}(pomdp, up)
# Outer constructor: derives every type parameter from the arguments.
# The belief (state) type comes from `determine_gbmdp_state_type`, and the action type
# from `actiontype(pomdp)`.
function GenerativeBeliefMDP(pomdp, updater; terminal_behavior=BackwardCompatibleTerminalBehavior(pomdp, updater))
    P = typeof(pomdp)
    U = typeof(updater)
    T = typeof(terminal_behavior)
    B = determine_gbmdp_state_type(pomdp, updater, terminal_behavior)
    A = actiontype(pomdp)
    return GenerativeBeliefMDP{P, U, T, B, A}(pomdp, updater, terminal_behavior)
end

# The initial state of the belief MDP is a deterministic distribution over the belief
# obtained by initializing the updater with the POMDP's initial state distribution.
function initialstate(bmdp::GenerativeBeliefMDP)
    b0 = initialize_belief(bmdp.updater, initialstate(bmdp.pomdp))
    return Deterministic(b0)
end

function POMDPs.gen(bmdp::GenerativeBeliefMDP, b, a, rng::AbstractRNG)
s = rand(rng, b)
if isterminal(bmdp.pomdp, s)
bp = gbmdp_handle_terminal(bmdp.pomdp, bmdp.updater, b, s, a, rng::AbstractRNG)::typeof(b)
bp = bmdp.terminal_behavior(b, s, a, rng)
return (sp=bp, r=0.0)
end
sp, o, r = @gen(:sp, :o, :r)(bmdp.pomdp, s, a, rng) # maybe this should have been generate_or?
o, r = @gen(:o, :r)(bmdp.pomdp, s, a, rng)
bp = update(bmdp.updater, b, a, o)
return (sp=bp, r=r)
end

# Belief-dependent actions: forward the belief to the underlying POMDP.
# All five type parameters of `GenerativeBeliefMDP{P,U,T,B,A}` must be listed here.
# With only four, `B` would bind to the terminal-behavior parameter `T`, and this
# method would never match an actual belief argument.
actions(bmdp::GenerativeBeliefMDP{P,U,T,B,A}, b::B) where {P,U,T,B,A} = actions(bmdp.pomdp, b)
# Belief-independent action space: that of the underlying POMDP.
actions(bmdp::GenerativeBeliefMDP) = actions(bmdp.pomdp)

isterminal(bmdp::GenerativeBeliefMDP, b) = all(isterminal(bmdp.pomdp, s) for s in support(b))
isterminal(bmdp::GenerativeBeliefMDP, b) = all(s -> isterminal(bmdp.pomdp, s) || pdf(b, s) == 0.0, support(b))
isterminal(bmdp::GenerativeBeliefMDP, ts::TerminalState) = true

discount(bmdp::GenerativeBeliefMDP) = discount(bmdp.pomdp)

# override this if you want to handle it in a special way
function gbmdp_handle_terminal(pomdp::POMDP, updater::Updater, b, s, a, rng)
@warn("""
Sampled a terminal state for a GenerativeBeliefMDP transition - not sure how to proceed, but will try.
"""
determine_gbmdp_state_type(pomdp, updater, [terminal_behavior])
This function is called to determine the state type for a GenerativeBeliefMDP. By default, it will return typeof(initialize_belief(updater, initialstate(pomdp))).
See $(@__FILE__) and implement a new method of POMDPToolbox.gbmdp_handle_terminal if you want special behavior in this case.
If a belief updater may use a belief type different from the output of initialize_belief, for example if the belief type can change after an update, override `determine_gbmdp_state_type(pomdp, updater)`.
""", maxlog=1)
sp, o, r = @gen(:sp, :o, :r)(pomdp, s, a, rng)
bp = update(updater, b, a, o)
return bp
If the terminal behavior adds a new possible state type, override `determine_gbmdp_state_type(pomdp, updater, terminal_behavior)` to return the `Union` of the new state type and the output of `determine_gbmdp_state_type(pomdp, updater)`
"""
function determine_gbmdp_state_type end # for documentation

# Two-argument default: the state type is the type of the initial belief produced by
# `initialize_belief` for the POMDP's initial state distribution.
determine_gbmdp_state_type(pomdp, updater) = typeof(initialize_belief(updater, initialstate(pomdp)))

function initialstate(bmdp::GenerativeBeliefMDP)
return Deterministic(initialize_belief(bmdp.updater, initialstate(bmdp.pomdp)))
determine_gbmdp_state_type(pomdp, updater, terminal_behavior) = determine_gbmdp_state_type(pomdp, updater)

struct BackwardCompatibleTerminalBehavior{M, U}
pomdp::M
updater::U
end

# deprecated in POMDPs v0.9
function initialstate(bmdp::GenerativeBeliefMDP, rng::AbstractRNG)
return initialize_belief(bmdp.updater, initialstate(bmdp.pomdp))
# Default terminal behavior. It first consults the deprecated `gbmdp_handle_terminal`
# hook; if that returns something other than `nothing`, the result is used (with a
# deprecation warning). Otherwise it delegates to `TerminalStateTerminalBehavior`.
function (tb::BackwardCompatibleTerminalBehavior)(b, s, a, rng)

    # This code block is only to handle backwards compatibility for the deprecated gbmdp_handle_terminal function
    bp = gbmdp_handle_terminal(tb.pomdp, tb.updater, b, s, a, rng)
    # Use identity comparison: `!=` would dispatch to `==` on the (user-defined) belief
    # type, which may be overloaded or may not return a `Bool`.
    if bp !== nothing # user has implemented gbmdp_handle_terminal
        Base.depwarn("Using gbmdp_handle_terminal to specify terminal behavior for a GenerativeBeliefMDP is deprecated. Use the terminal_behavior keyword argument instead.", :gbmdp_handle_terminal)
        return bp
    end

    return TerminalStateTerminalBehavior()(b, s, a, rng)
end

determine_gbmdp_state_type(pomdp, updater, tb::BackwardCompatibleTerminalBehavior) = determine_gbmdp_state_type(pomdp, updater, TerminalStateTerminalBehavior())

struct ContinueTerminalBehavior{M, U}
pomdp::M
updater::U
end

# Keep updating the belief even though the sampled state `s` is terminal: generate an
# observation from `s` and `a`, then run the updater on it. The generated reward is
# discarded.
function (tb::ContinueTerminalBehavior)(b, s, a, rng)
    o, _ = @gen(:o, :r)(tb.pomdp, s, a, rng)
    bp = update(tb.updater, b, a, o)
    return bp
end

# Terminal behavior that ignores all of its arguments and always returns `terminalstate`.
struct TerminalStateTerminalBehavior end
(tb::TerminalStateTerminalBehavior)(args...) = terminalstate
# With this behavior the belief-MDP state type is the `promote_type` of the updater's
# belief type and `TerminalState`.
determine_gbmdp_state_type(pomdp, updater, tb::TerminalStateTerminalBehavior) = promote_type(determine_gbmdp_state_type(pomdp, updater), TerminalState)
4 changes: 2 additions & 2 deletions lib/POMDPTools/src/ModelTools/terminal_state.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ struct TerminalState end
"""
terminalstate
The singleton instance of type `TerminalState` representing a terminal state.
The singleton instance of type [`TerminalState`](@ref) representing a terminal state.
"""
const terminalstate = TerminalState()

isterminal(m::Union{MDP,POMDP}, ts::TerminalState) = true
isterminal(m::Union{MDP,POMDP}, ts::TerminalState) = true
Base.promote_rule(::Type{TerminalState}, T::Type) = Union{TerminalState, T}
66 changes: 62 additions & 4 deletions lib/POMDPTools/test/model_tools/test_generative_belief_mdp.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,65 @@
let
pomdp = BabyPOMDP()
up = updater(pomdp)
@testset "GenerativeBeliefMDP" begin
@testset "Baby" begin
pomdp = BabyPOMDP()
up = updater(pomdp)

bmdp = GenerativeBeliefMDP(pomdp, up)
b = initialstate(bmdp, Random.default_rng())
bmdp = GenerativeBeliefMDP(pomdp, up)
b = rand(initialstate(bmdp))
@test rand(b) isa statetype(pomdp)

@test simulate(RolloutSimulator(max_steps=10), bmdp, RandomPolicy(bmdp)) <= 0
end

terminal_test_m = QuickPOMDP(
states = 1:2,
actions = 1:2,
observations = 1:2,
transition = (s, a) -> Deterministic(1),
observation = (a, sp) -> Deterministic(sp),
reward = s -> 0.0,
isterminal = ==(1),
initialstate = Deterministic(2)
)

@testset "Terminal Default" begin
up = DiscreteUpdater(terminal_test_m)
bm = GenerativeBeliefMDP(terminal_test_m, up)

hist = collect(stepthrough(bm, RandomPolicy(bm), "s,sp", max_steps=10))
@test length(hist) == 1
@test only(hist).s == DiscreteBelief(terminal_test_m, [0.0, 1.0])
@test only(hist).sp == DiscreteBelief(terminal_test_m, [1.0, 0.0])
@test !isterminal(bm, only(hist).s)
@test isterminal(bm, only(hist).sp)
end

@testset "Terminal Uninformative Update" begin
    # Updater that discards all information: every update returns a uniform belief
    # over the states of the model.
    struct UninformativeUpdater{M} <: Updater
        m::M
    end

    POMDPs.update(up::UninformativeUpdater, b, a, o) = Uniform(states(up.m))
    POMDPs.initialize_belief(up::UninformativeUpdater, d::Deterministic) = Uniform(rand(d))

    up = UninformativeUpdater(terminal_test_m)

    # default terminal behavior
    bm = GenerativeBeliefMDP(terminal_test_m, up)
    hist = collect(stepthrough(bm, RandomPolicy(bm), "s,sp"))
    @test isterminal(bm, last(hist).sp)

    # explicitly requested terminal-state behavior; the behavior object must actually
    # be passed to the constructor, otherwise this only re-tests the default
    behavior = TerminalStateTerminalBehavior()
    bm = GenerativeBeliefMDP(terminal_test_m, up, terminal_behavior=behavior)
    hist = collect(stepthrough(bm, RandomPolicy(bm), "s,sp"))
    @test last(hist).sp === terminalstate
    @test isterminal(bm, last(hist).sp)

    # continuing past terminal samples never reaches `terminalstate`, so the run
    # lasts the full step budget
    behavior = ContinueTerminalBehavior(terminal_test_m, up)
    bm = GenerativeBeliefMDP(terminal_test_m, up, terminal_behavior=behavior)
    hist = collect(stepthrough(bm, RandomPolicy(bm), "s,sp", max_steps=10))
    @test length(hist) == 10
end

end
end
2 changes: 1 addition & 1 deletion src/simulator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,6 @@ abstract type Simulator end
Run a simulation using the specified policy.
The return type is flexible and depends on the simulator. Simulations should adhere to the [Simulation Standard](http://juliapomdp.github.io/POMDPs.jl/latest/simulation.html#Simulation-Standard-1).
The return type is flexible and depends on the simulator. Simulations should adhere to the [Simulation Standard](https://juliapomdp.github.io/POMDPs.jl/stable/simulation/#Simulation-Standard).
"""
function simulate end

0 comments on commit cbb9c89

Please sign in to comment.