Skip to content

Commit

Permalink
Added outputs for the verbose kwarg (#16)
Browse files Browse the repository at this point in the history
* Add output when verbose=true

* Added tests for earliest support version of julia listed in compat

* Formatting updates (BlueStyle)

* Update the calculation of the number of beliefs
  • Loading branch information
dylan-asmar authored Jun 22, 2024
1 parent dd819c3 commit da75251
Show file tree
Hide file tree
Showing 6 changed files with 67 additions and 16 deletions.
1 change: 1 addition & 0 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ jobs:
strategy:
matrix:
version:
- '1.7' # Earliest supported Julia version listed in [compat]
- '1'
os:
- ubuntu-latest
Expand Down
3 changes: 3 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,12 @@ version = "0.1.3"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
POMDPTools = "7588e00f-9cae-40de-98dc-e0c70c48cdd7"
POMDPs = "a93abf59-7444-517b-a68a-c42f96afdd7d"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

[compat]
LinearAlgebra = "1"
POMDPTools = "0.1, 1"
POMDPs = "0.9, 1"
Printf = "1"
julia = "1.7"
1 change: 1 addition & 0 deletions src/NativeSARSOP.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ using POMDPs
using POMDPTools
using SparseArrays
using LinearAlgebra
using Printf

export SARSOPSolver, SARSOPTree

Expand Down
32 changes: 31 additions & 1 deletion src/solver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,29 @@ end

function POMDPTools.solve_info(solver::SARSOPSolver, pomdp::POMDP)
tree = SARSOPTree(solver, pomdp)


if solver.verbose
initialize_verbose_output()
end

t0 = time()
iter = 0
while time()-t0 < solver.max_time && root_diff(tree) > solver.precision
sample!(solver, tree)
backup!(tree)
prune!(solver, tree)
if solver.verbose && iter % 10 == 0
log_verbose_info(t0, iter, tree)
end
iter += 1
end

if solver.verbose
dashed_line()
log_verbose_info(t0, iter, tree)
dashed_line()
end

pol = AlphaVectorPolicy(
pomdp,
getproperty.(tree.Γ, :alpha),
Expand All @@ -36,3 +49,20 @@ function POMDPTools.solve_info(solver::SARSOPSolver, pomdp::POMDP)
end

"""
    POMDPs.solve(solver::SARSOPSolver, pomdp::POMDP)

Run `solve_info(solver, pomdp)` and return only the first element of its result.
"""
function POMDPs.solve(solver::SARSOPSolver, pomdp::POMDP)
    result = solve_info(solver, pomdp)
    return first(result)
end

"""
    initialize_verbose_output()

Print the column-header row of the verbose progress table to standard output,
framed above and below by a dashed rule.
"""
function initialize_verbose_output()
    # Format the header first; the column widths mirror the ones used for the data rows.
    header = @sprintf(
        " %-10s %-10s %-12s %-12s %-15s %-10s %-10s\n",
        "Time", "Iter", "LB", "UB", "Precision", "# Alphas", "# Beliefs"
    )
    dashed_line()
    print(header)
    dashed_line()
    return nothing
end

"""
    log_verbose_info(t0::Float64, iter::Int, tree::SARSOPTree)

Print one row of the verbose progress table to standard output.

The row contains, in order: the seconds elapsed since `t0`, the iteration counter
`iter`, the first entries of `tree.V_lower` and `tree.V_upper`, the value of
`root_diff(tree)`, the number of entries in `tree.Γ`, and
`length(tree.b_pruned) - sum(tree.b_pruned)`.
"""
function log_verbose_info(t0::Float64, iter::Int, tree::SARSOPTree)
    elapsed = time() - t0
    lower = tree.V_lower[1]
    upper = tree.V_upper[1]
    gap = root_diff(tree)
    n_alphas = length(tree.Γ)
    # Presumably `b_pruned` flags pruned beliefs, making this the count of beliefs
    # still kept — TODO confirm against the SARSOPTree definition.
    n_beliefs = length(tree.b_pruned) - sum(tree.b_pruned)
    @printf(
        " %-10.2f %-10d %-12.7f %-12.7f %-15.10f %-10d %-10d\n",
        elapsed, iter, lower, upper, gap, n_alphas, n_beliefs
    )
    return nothing
end

"""
    dashed_line(n=86)

Print a line of `n` dash characters followed by a newline to standard output.
"""
function dashed_line(n=86)
    println(repeat("-", n))
    return nothing
end
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,5 @@ POMDPs = "a93abf59-7444-517b-a68a-c42f96afdd7d"
RockSample = "de008ff0-c357-11e8-3329-7fe746fe836e"
SARSOP = "cef570c6-3a94-5604-96b7-1a5e143043f2"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
45 changes: 30 additions & 15 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import SARSOP
using SparseArrays
using RockSample
using Combinatorics
using Suppressor

# Test-only convenience method (a small amount of type piracy): build a tree for
# `pomdp` using a `SARSOPSolver` constructed with its default arguments.
function JSOP.SARSOPTree(pomdp::POMDP)
    default_solver = SARSOPSolver()
    return JSOP.SARSOPTree(default_solver, pomdp)
end
Expand All @@ -29,9 +30,9 @@ include("updater.jl")
include("tree.jl")

@testset "Tiger POMDP" begin
pomdp = TigerPOMDP();
solver = SARSOPSolver(epsilon = 0.5, precision = 1e-3);
tree = SARSOPTree(pomdp);
pomdp = TigerPOMDP()
solver = SARSOPSolver(epsilon=0.5, precision=1e-3, verbose=false)
tree = SARSOPTree(pomdp)
Γ = solve(solver, pomdp)
iterations = 0
while JSOP.root_diff(tree) > solver.precision
Expand All @@ -43,16 +44,16 @@ include("tree.jl")
@test isapprox(tree.V_lower[1], 19.37; atol=1e-1)
@test JSOP.root_diff(tree) < solver.precision

solverCPP = SARSOP.SARSOPSolver(trial_improvement_factor = 0.5, precision = 1e-3, verbose = false);
policyCPP = solve(solverCPP, pomdp);
solverCPP = SARSOP.SARSOPSolver(trial_improvement_factor=0.5, precision=1e-3, verbose=false)
policyCPP = solve(solverCPP, pomdp)
@test abs(value(policyCPP, initialstate(pomdp)) - tree.V_lower[1]) < 0.01
@test abs(value(policyCPP, initialstate(pomdp)) - value(Γ, initialstate(pomdp))) < 0.01
end

@testset "Baby POMDP" begin
pomdp = BabyPOMDP();
solver = SARSOPSolver(epsilon = 0.1, delta = 0.1, precision = 1e-3);
tree = SARSOPTree(pomdp);
pomdp = BabyPOMDP()
solver = SARSOPSolver(epsilon=0.1, delta=0.1, precision=1e-3, verbose=false)
tree = SARSOPTree(pomdp)
Γ = solve(solver, pomdp)
iterations = 0
while JSOP.root_diff(tree) > solver.precision
Expand All @@ -64,16 +65,16 @@ end
@test isapprox(tree.V_lower[1], -16.3; atol=1e-2)
@test JSOP.root_diff(tree) < solver.precision

solverCPP = SARSOP.SARSOPSolver(trial_improvement_factor = 0.5, precision = 1e-3, verbose = false);
policyCPP = solve(solverCPP, pomdp);
solverCPP = SARSOP.SARSOPSolver(trial_improvement_factor=0.5, precision=1e-3, verbose=false)
policyCPP = solve(solverCPP, pomdp)
@test abs(value(policyCPP, initialstate(pomdp)) - tree.V_lower[1]) < 0.01
@test abs(value(policyCPP, initialstate(pomdp)) - value(Γ, initialstate(pomdp))) < 0.01
end

@testset "RockSample POMDP" begin
pomdp = RockSamplePOMDP();
solver = SARSOPSolver(epsilon = 0.1, delta = 0.1, precision = 1e-2);
tree = SARSOPTree(pomdp);
pomdp = RockSamplePOMDP()
solver = SARSOPSolver(epsilon=0.1, delta=0.1, precision=1e-2, verbose=false)
tree = SARSOPTree(pomdp)
Γ = solve(solver, pomdp)
iterations = 0
while JSOP.root_diff(tree) > solver.precision
Expand All @@ -85,8 +86,22 @@ end
# @test isapprox(tree.V_lower[1], -16.3; atol=1e-2)
@test JSOP.root_diff(tree) < solver.precision

solverCPP = SARSOP.SARSOPSolver(trial_improvement_factor = 0.5, precision = 1e-2, verbose = false);
policyCPP = solve(solverCPP, pomdp);
solverCPP = SARSOP.SARSOPSolver(trial_improvement_factor=0.5, precision=1e-2, verbose=false)
policyCPP = solve(solverCPP, pomdp)
@test abs(value(policyCPP, initialstate(pomdp)) - tree.V_lower[1]) < 0.1
@test abs(value(policyCPP, initialstate(pomdp)) - value(Γ, initialstate(pomdp))) < 0.1
end

@testset "Verbose Tests" begin
    tiger = TigerPOMDP()
    verbose_solver = SARSOPSolver(; max_time=10.0, verbose=true)
    # Capture everything the solver writes to stdout while verbose output is enabled.
    output = @capture_out solve(verbose_solver, tiger)
    # Every column label of the progress table must appear in the captured text.
    for label in ("Time", "Iter", "LB", "UB", "Precision", "# Alphas", "# Beliefs")
        @test occursin(label, output)
    end
    # Re-emit the captured text so it still shows up in the test log.
    println(output)
end

0 comments on commit da75251

Please sign in to comment.