Replace LegacyGridWorld with SimpleGridWorld #89

Merged 3 commits on Feb 12, 2022
2 changes: 2 additions & 0 deletions .gitignore
@@ -3,3 +3,5 @@
 docs/build/
 docs/site/
 *.swp
+*Manifest.toml
+*ipynb_checkpoints/
9 changes: 5 additions & 4 deletions bench/non-terminal_gw.jl
@@ -1,13 +1,14 @@
 using POMDPs
 using POMDPModels
 using MCTS
-using POMDPToolbox
+using POMDPSimulators
 using ProgressMeter
-using ProfileView
+using Random
+# using ProfileView

 sim = RolloutSimulator(max_steps=100, rng=MersenneTwister(7))

-mdp = GridWorld(terminals=[])
+mdp = SimpleGridWorld()

 d=20; n=100; c=10.
 @show d, n, c
@@ -16,7 +17,7 @@ solver = MCTSSolver(depth=d, n_iterations=n, exploration_constant=c, rng=Mersenn
 planner = solve(solver, mdp)
 simulate(sim, mdp, planner)

-# @code_warntype MCTS.simulate(planner, GridWorldState(1,1,false), 10)
+# @code_warntype MCTS.simulate(planner, GWPos(1,1), 10)

 # Profile.clear()
 # @profile for i in 1:100
31 changes: 16 additions & 15 deletions bench/non-terminal_gw_dpw.jl
@@ -1,13 +1,14 @@
 using POMDPs
 using POMDPModels
 using MCTS
-using POMDPToolbox
+using POMDPSimulators
 using ProgressMeter
-using ProfileView
+using Random
+# using ProfileView

 sim = RolloutSimulator(max_steps=100, rng=MersenneTwister(7))

-mdp = GridWorld(terminals=[])
+mdp = SimpleGridWorld()

 d=20; n=1000; c=10.
 @show d, n, c
@@ -23,17 +24,17 @@ solver = DPWSolver(depth=d,
 planner = solve(solver, mdp)
 simulate(sim, mdp, planner)

-# @code_warntype MCTS.simulate(planner, GridWorldState(1,1,false), 10)
+# @code_warntype MCTS.simulate(planner, GWPos(1,1), 10)

-Profile.clear()
-@profile for i in 1:1
-simulate(sim, mdp, planner)
-end
-ProfileView.view()
-
-# @show N=100
-# rewards = Array(Float64, N)
-# @time @showprogress for i = 1:N
-# rewards[i] = simulate(sim, mdp, planner)
+# Profile.clear()
+# @profile for i in 1:1
+# simulate(sim, mdp, planner)
 # end
-# @show mean(rewards)
+# ProfileView.view()
+
+@show N=100
+rewards = Array{Float64}(undef, N)
+@time @showprogress for i = 1:N
+rewards[i] = simulate(sim, mdp, planner)
+end
+@show mean(rewards)
2 changes: 1 addition & 1 deletion docs/src/index.md
@@ -72,7 +72,7 @@ solver = MCTSSolver(estimate_value=RolloutEstimator(rollout_policy)) # default s
 Since Monte-Carlo Tree Search is an online method, the solve function simply specifies the mdp model to the solver (which is embedded in the policy object). (Note that an MCTSPlanner can also be constructed directly without calling `solve()`.) The computation is done during calls to the action function. To extract the policy for a given state, simply call the action function:

 ```julia
-s = create_state(mdp) # this can be any valid state
+s = rand(states(mdp)) # this can be any valid state
 a = action(planner, s) # returns the action for state s
 ```
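
For readers migrating their own scripts, the following is a minimal end-to-end sketch of the updated usage, not code from this PR. It assumes current releases of POMDPs, POMDPModels and MCTS, and that POMDPModels exports `GWPos`. The solver keywords are the ones used in the benchmark scripts above; the seed and the start state `GWPos(1, 1)` are arbitrary choices for illustration.

```julia
using POMDPs
using POMDPModels   # assumed to provide SimpleGridWorld and GWPos
using MCTS
using Random

# Default 10x10 grid world; states are GWPos values rather than GridWorldState.
mdp = SimpleGridWorld()

# Same keyword arguments as in bench/non-terminal_gw.jl; the seed is arbitrary.
solver = MCTSSolver(depth=20, n_iterations=100, exploration_constant=10.0,
                    rng=MersenneTwister(1))

# solve() only binds the model to the solver; no tree search runs here.
planner = solve(solver, mdp)

# The tree search runs inside action(), once per queried state.
s = GWPos(1, 1)
a = action(planner, s)
@show a
```

If I recall the POMDPModels defaults correctly, `SimpleGridWorld()` terminates from its reward cells, so it may not be an exact stand-in for `GridWorld(terminals=[])`. This is worth checking if the benchmarks are meant to stay non-terminal.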
