Skip to content

Commit

Permalink
witness pruning with sets in alpha vectors
Browse files Browse the repository at this point in the history
  • Loading branch information
qhho committed Jul 27, 2024
1 parent a7178b7 commit ba8c452
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 31 deletions.
1 change: 1 addition & 0 deletions src/alpha.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
struct AlphaVec{A} <: AbstractVector{Float64}
alpha::Vector{Float64}
action::A
witnesses::Set{Int}
end

@inline Base.length(v::AlphaVec) = length(v.alpha)
Expand Down
2 changes: 1 addition & 1 deletion src/backup.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ function backup!(tree, b_idx)
end
end

α = AlphaVec(best_α, best_action)
α = AlphaVec(best_α, best_action, Set(b_idx))
push!(Γ, α)
tree.V_lower[b_idx] = V
end
Expand Down
57 changes: 28 additions & 29 deletions src/prune.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,16 +48,27 @@ function prune!(tree::SARSOPTree)
end
end

function belief_space_domination(α1, α2, B, δ)
a1_dominant = true
a2_dominant = true
for b B
!a1_dominant && !a2_dominant && return (false, false)
δV = intersection_distance(α1, α2, b)
δV δ && (a1_dominant = false)
δV -δ && (a2_dominant = false)
function recertify_witnesses!(tree, α1, α2, δ)

if α1 == α2
union!(α2.witnesses, α1.witnesses)
empty!(α1.witnesses)
return
end

for b_idx in α1.witnesses
if tree.b_pruned[b_idx]
delete!(α1.witnesses, b_idx)
continue
end

δV = intersection_distance(α2, α1, tree.b[b_idx])

if δV > δ
delete!(α1.witnesses, b_idx)
push!(α2.witnesses, b_idx)
end
end
return a1_dominant, a2_dominant
end

@inline function intersection_distance(α1, α2, b)
Expand All @@ -75,33 +86,21 @@ end

function prune_alpha!(tree::SARSOPTree, δ)
Γ = tree.Γ
B_valid = tree.b[map(!,tree.b_pruned)]
pruned = falses(length(Γ))

# checking if α_i dominates α_j
for (i,α_i) enumerate(Γ)
for (i, α_i) enumerate(Γ)
pruned[i] && continue
for (j,α_j) enumerate(Γ)
(j i || pruned[j]) && continue
a1_dominant,a2_dominant = belief_space_domination(α_i, α_j, B_valid, δ)
#=
NOTE: α1 and α2 shouldn't technically be able to mutually dominate
i.e. a1_dominant and a2_dominant should never both be true.
But this does happen when α1 == α2 because intersection_distance returns NaN.
Current impl prunes α2 without doing an equality check, removing
the duplicate α. Could do equality check to short-circuit
belief_space_domination which would speed things up if we have
a lot of duplicates, but the equality check can slow things down
if α's are sufficiently diverse.
=#
if a1_dominant
pruned[j] = true
elseif a2_dominant
for (j, α_j) enumerate(Γ)
pruned[j] || j == i && continue
recertify_witnesses!(tree, α_i, α_j, δ)
if isempty(α_i.witnesses)
pruned[i] = true
break
elseif isempty(α_j.witnesses)
pruned[j] = true
end
end
end
deleteat!(Γ, pruned)
tree.prune_data.last_Γ_size = length(Γ)
end
end
2 changes: 1 addition & 1 deletion src/tree.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ function insert_root!(solver, tree::SARSOPTree, b)
Γ_lower = solve(solver.init_lower, pomdp)
for (α,a) alphapairs(Γ_lower)
new_val = dot(α, b)
push!(tree.Γ, AlphaVec(α, a))
push!(tree.Γ, AlphaVec(α, a, Set(1)))
end
tree.prune_data.last_Γ_size = length(tree.Γ)

Expand Down

0 comments on commit ba8c452

Please sign in to comment.