Skip to content

Commit

Permalink
observation changes
Browse files Browse the repository at this point in the history
  • Loading branch information
FlyingWorkshop committed Jun 24, 2024
1 parent 97383f3 commit 31b4e6c
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 8 deletions.
2 changes: 1 addition & 1 deletion src/CompressedBeliefMDPs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ using Bijections
using NearestNeighbors
using StaticArrays
using Combinatorics
using IterTools
using IterTools # TODO: not sure if I need this anymore

using LinearAlgebra
using Random
Expand Down
17 changes: 10 additions & 7 deletions src/envs/circular.jl
Original file line number Diff line number Diff line change
Expand Up @@ -128,25 +128,25 @@ function POMDPs.observation(pomdp::CircularMaze, s::CircularMazeState, a::Intege
if a == CMAZE_SENSE_CORRIDOR
obs = Deterministic(s.corridor)
else
values = 1:pomdp.corridor_length
values = states(pomdp)
probabilities = _center_probabilities(pomdp, s.x)
probabilities = repeat(pomdp.probabilities, pomdp.n_corridors)
probabilities /= pomdp.n_corridors # normalize values to sum to 1
push!(probabilities, 0) # address OBOE from terminalstate
d = SparseCat(values, probabilities)
obs = d
end
return obs
end

function POMDPs.observations(pomdp::CircularMaze)
# NOTE: In JuliaPOMDPs, an observation space is NOT the set of possible distributions, but rather union of the support of all possible observations
corridors = 1:pomdp.n_corridors # from CMAZE_SENSE_CORRIDOR
perms = permutations(pomdp.probabilities)
space = Iterators.flatten(corridors, perms) # generator
corridors = 1:pomdp.n_corridors
space = IterTools.chain(states(pomdp), corridors)
return space
end

# TODO: maybe implement POMDPs.obsindex

# TODO: confirm that transitions are non-Deterministic
function POMDPs.transition(pomdp::CircularMaze, s::CircularMazeState, a::Integer)
@assert a in actions(pomdp) "Unrecognized action $a"
if a == CMAZE_DECLARE_GOAL
Expand All @@ -160,7 +160,10 @@ function POMDPs.transition(pomdp::CircularMaze, s::CircularMazeState, a::Integer
x = pomdp.corridor_length
end
elseif a == CMAZE_RIGHT
x = (s.x + 1) % pomdp.corridor_length
x = s.x + 1
if x > pomdp.corridor_length
x = 1
end
else
x = s.x
end
Expand Down

0 comments on commit 31b4e6c

Please sign in to comment.