diff --git a/src/CompressedBeliefMDPs.jl b/src/CompressedBeliefMDPs.jl index 1842d50..31ac8d8 100644 --- a/src/CompressedBeliefMDPs.jl +++ b/src/CompressedBeliefMDPs.jl @@ -14,7 +14,7 @@ using Bijections using NearestNeighbors using StaticArrays using Combinatorics -using IterTools +using IterTools # TODO: not sure if I need this anymore using LinearAlgebra using Random diff --git a/src/envs/circular.jl b/src/envs/circular.jl index afac476..a081e79 100644 --- a/src/envs/circular.jl +++ b/src/envs/circular.jl @@ -128,8 +128,11 @@ function POMDPs.observation(pomdp::CircularMaze, s::CircularMazeState, a::Intege if a == CMAZE_SENSE_CORRIDOR obs = Deterministic(s.corridor) else - values = 1:pomdp.corridor_length + values = states(pomdp) probabilities = _center_probabilities(pomdp, s.x) + probabilities = repeat(pomdp.probabilities, pomdp.n_corridors) + probabilities /= pomdp.n_corridors # normalize values to sum to 1 + push!(probabilities, 0) # address OBOE from terminalstate d = SparseCat(values, probabilities) obs = d end @@ -137,16 +140,13 @@ function POMDPs.observation(pomdp::CircularMaze, s::CircularMazeState, a::Intege end function POMDPs.observations(pomdp::CircularMaze) - # NOTE: In JuliaPOMDPs, an observation space is NOT the set of possible distributions, but rather union of the support of all possible observations - corridors = 1:pomdp.n_corridors # from CMAZE_SENSE_CORRIDOR - perms = permutations(pomdp.probabilities) - space = Iterators.flatten(corridors, perms) # generator + corridors = 1:pomdp.n_corridors + space = IterTools.chain(states(pomdp), corridors) return space end # TODO: maybe implement POMDPs.obsindex -# TODO: confirm that transitions are non-Deterministic function POMDPs.transition(pomdp::CircularMaze, s::CircularMazeState, a::Integer) @assert a in actions(pomdp) "Unrecognized action $a" if a == CMAZE_DECLARE_GOAL @@ -160,7 +160,10 @@ function POMDPs.transition(pomdp::CircularMaze, s::CircularMazeState, a::Integer x = pomdp.corridor_length end elseif a == CMAZE_RIGHT - x = (s.x + 1) % pomdp.corridor_length + x = s.x + 1 + if x > pomdp.corridor_length + x = 1 + end else x = s.x end