diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index a24f9e5..1233b50 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -7,6 +7,7 @@ jobs: strategy: matrix: version: + - '1.7' # Earliest support listed in compat - '1' os: - ubuntu-latest diff --git a/Project.toml b/Project.toml index f285fbc..bf55f4d 100644 --- a/Project.toml +++ b/Project.toml @@ -7,9 +7,12 @@ version = "0.1.3" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" POMDPTools = "7588e00f-9cae-40de-98dc-e0c70c48cdd7" POMDPs = "a93abf59-7444-517b-a68a-c42f96afdd7d" +Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" [compat] +LinearAlgebra = "1" POMDPTools = "0.1, 1" POMDPs = "0.9, 1" +Printf = "1" julia = "1.7" diff --git a/src/NativeSARSOP.jl b/src/NativeSARSOP.jl index 0d7ccdc..5d3b586 100644 --- a/src/NativeSARSOP.jl +++ b/src/NativeSARSOP.jl @@ -4,6 +4,7 @@ using POMDPs using POMDPTools using SparseArrays using LinearAlgebra +using Printf export SARSOPSolver, SARSOPTree diff --git a/src/solver.jl b/src/solver.jl index 7542f2a..71fbd42 100644 --- a/src/solver.jl +++ b/src/solver.jl @@ -13,16 +13,29 @@ end function POMDPTools.solve_info(solver::SARSOPSolver, pomdp::POMDP) tree = SARSOPTree(solver, pomdp) - + + if solver.verbose + initialize_verbose_output() + end + t0 = time() iter = 0 while time()-t0 < solver.max_time && root_diff(tree) > solver.precision sample!(solver, tree) backup!(tree) prune!(solver, tree) + if solver.verbose && iter % 10 == 0 + log_verbose_info(t0, iter, tree) + end iter += 1 end + if solver.verbose + dashed_line() + log_verbose_info(t0, iter, tree) + dashed_line() + end + pol = AlphaVectorPolicy( pomdp, getproperty.(tree.Γ, :alpha), @@ -36,3 +49,20 @@ function POMDPTools.solve_info(solver::SARSOPSolver, pomdp::POMDP) end POMDPs.solve(solver::SARSOPSolver, pomdp::POMDP) = first(solve_info(solver, pomdp)) + +function initialize_verbose_output() + dashed_line() + @printf(" %-10s %-10s %-12s %-12s %-15s %-10s %-10s\n", + "Time", "Iter", "LB", "UB", "Precision", "# Alphas", "# Beliefs") + dashed_line() +end + +function log_verbose_info(t0::Float64, iter::Int, tree::SARSOPTree) + @printf(" %-10.2f %-10d %-12.7f %-12.7f %-15.10f %-10d %-10d\n", + time()-t0, iter, tree.V_lower[1], tree.V_upper[1], root_diff(tree), + length(tree.Γ), length(tree.b_pruned) - sum(tree.b_pruned)) +end + +function dashed_line(n=86) + @printf("%s\n", "-"^n) +end diff --git a/test/Project.toml b/test/Project.toml index e651abf..1d5f5b5 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -6,4 +6,5 @@ POMDPs = "a93abf59-7444-517b-a68a-c42f96afdd7d" RockSample = "de008ff0-c357-11e8-3329-7fe746fe836e" SARSOP = "cef570c6-3a94-5604-96b7-1a5e143043f2" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" +Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/test/runtests.jl b/test/runtests.jl index ae33388..806fca1 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -8,6 +8,7 @@ import SARSOP using SparseArrays using RockSample using Combinatorics +using Suppressor # lil bit of testing type piracy JSOP.SARSOPTree(pomdp::POMDP) = JSOP.SARSOPTree(SARSOPSolver(), pomdp) @@ -29,9 +30,9 @@ include("updater.jl") include("tree.jl") @testset "Tiger POMDP" begin - pomdp = TigerPOMDP(); - solver = SARSOPSolver(epsilon = 0.5, precision = 1e-3); - tree = SARSOPTree(pomdp); + pomdp = TigerPOMDP() + solver = SARSOPSolver(epsilon=0.5, precision=1e-3, verbose=false) + tree = SARSOPTree(pomdp) Γ = solve(solver, pomdp) iterations = 0 while JSOP.root_diff(tree) > solver.precision @@ -43,16 +44,16 @@ include("tree.jl") @test isapprox(tree.V_lower[1], 19.37; atol=1e-1) @test JSOP.root_diff(tree) < solver.precision - solverCPP = SARSOP.SARSOPSolver(trial_improvement_factor = 0.5, precision = 1e-3, verbose = false); - policyCPP = solve(solverCPP, pomdp); + solverCPP = SARSOP.SARSOPSolver(trial_improvement_factor=0.5, precision=1e-3, verbose=false) + policyCPP = solve(solverCPP, pomdp) @test abs(value(policyCPP, initialstate(pomdp)) - tree.V_lower[1]) < 0.01 @test abs(value(policyCPP, initialstate(pomdp)) - value(Γ, initialstate(pomdp))) < 0.01 end @testset "Baby POMDP" begin - pomdp = BabyPOMDP(); - solver = SARSOPSolver(epsilon = 0.1, delta = 0.1, precision = 1e-3); - tree = SARSOPTree(pomdp); + pomdp = BabyPOMDP() + solver = SARSOPSolver(epsilon=0.1, delta=0.1, precision=1e-3, verbose=false) + tree = SARSOPTree(pomdp) Γ = solve(solver, pomdp) iterations = 0 while JSOP.root_diff(tree) > solver.precision @@ -64,16 +65,16 @@ end @test isapprox(tree.V_lower[1], -16.3; atol=1e-2) @test JSOP.root_diff(tree) < solver.precision - solverCPP = SARSOP.SARSOPSolver(trial_improvement_factor = 0.5, precision = 1e-3, verbose = false); - policyCPP = solve(solverCPP, pomdp); + solverCPP = SARSOP.SARSOPSolver(trial_improvement_factor=0.5, precision=1e-3, verbose=false) + policyCPP = solve(solverCPP, pomdp) @test abs(value(policyCPP, initialstate(pomdp)) - tree.V_lower[1]) < 0.01 @test abs(value(policyCPP, initialstate(pomdp)) - value(Γ, initialstate(pomdp))) < 0.01 end @testset "RockSample POMDP" begin - pomdp = RockSamplePOMDP(); - solver = SARSOPSolver(epsilon = 0.1, delta = 0.1, precision = 1e-2); - tree = SARSOPTree(pomdp); + pomdp = RockSamplePOMDP() + solver = SARSOPSolver(epsilon=0.1, delta=0.1, precision=1e-2, verbose=false) + tree = SARSOPTree(pomdp) Γ = solve(solver, pomdp) iterations = 0 while JSOP.root_diff(tree) > solver.precision @@ -85,8 +86,22 @@ end # @test isapprox(tree.V_lower[1], -16.3; atol=1e-2) @test JSOP.root_diff(tree) < solver.precision - solverCPP = SARSOP.SARSOPSolver(trial_improvement_factor = 0.5, precision = 1e-2, verbose = false); - policyCPP = solve(solverCPP, pomdp); + solverCPP = SARSOP.SARSOPSolver(trial_improvement_factor=0.5, precision=1e-2, verbose=false) + policyCPP = solve(solverCPP, pomdp) @test abs(value(policyCPP, initialstate(pomdp)) - tree.V_lower[1]) < 0.1 @test abs(value(policyCPP, initialstate(pomdp)) - value(Γ, initialstate(pomdp))) < 0.1 end + +@testset "Verbose Tests" begin + pomdp = TigerPOMDP() + solver = SARSOPSolver(; max_time=10.0, verbose=true) + output = @capture_out solve(solver, pomdp) + @test occursin("Time", output) + @test occursin("Iter", output) + @test occursin("LB", output) + @test occursin("UB", output) + @test occursin("Precision", output) + @test occursin("# Alphas", output) + @test occursin("# Beliefs", output) + println(output) +end