Skip to content

Commit

Permalink
Receding-horizon demo
Browse files Browse the repository at this point in the history
  • Loading branch information
lassepe committed Mar 11, 2024
1 parent 715ad1e commit 8a60d35
Show file tree
Hide file tree
Showing 6 changed files with 113 additions and 16 deletions.
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@ For a full example of how to use this package, please consult the demo in [`test

Start `julia --project` *from the repository root* and run the following commands:
```julia
using Pkg, TestEnv # install globally with `] add TestEnv` if you don't have this
using TestEnv, Revise # install globally with `] add TestEnv, Revise` if you don't have this
TestEnv.activate()
pkg"instantiate" # ensures that the test dependencies are installed, only needed once
include("test/Demo.jl")
Demo.demo()
Revise.includet("test/Demo.jl")
Demo.demo_model_predictive_game_play() # example of receding-horizon interaction
Demo.demo_inverse_game() # example of fitting game parameters via differentiation of the game solver
```

## Citation
Expand Down
7 changes: 5 additions & 2 deletions src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ function strategy_from_raw_solution(; raw_solution, game, solver)
number_of_players = num_players(game)
z_iter = Iterators.Stateful(raw_solution.z)

map(1:number_of_players) do player_index
substrategies = map(1:number_of_players) do player_index
private_state_dimension = solver.dimensions.state_blocks[player_index]
private_control_dimension = solver.dimensions.control_blocks[player_index]
number_of_primals =
Expand All @@ -44,7 +44,10 @@ function strategy_from_raw_solution(; raw_solution, game, solver)
trajectory =
unflatten_trajectory(z_private, private_state_dimension, private_control_dimension)
OpenLoopStrategy(trajectory.xs, trajectory.us)
end |> TrajectoryGamesBase.JointStrategy
end

info = (; raw_solution)
TrajectoryGamesBase.JointStrategy(substrategies, info)
end

function generate_initial_guess(solver, game, initial_state)
Expand Down
3 changes: 2 additions & 1 deletion src/solver_setup.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
struct Solver{T1, T2}
struct Solver{T1,T2}
"The problem representation of the game via ParametricMCPs.ParametricMCP"
mcp_problem_representation::T1
"A named tuple collecting all the problem dimension infos"
Expand Down Expand Up @@ -133,6 +133,7 @@ function Solver(game::TrajectoryGame, horizon; context_dimension = 0, compute_se
inequality_constraints_symoblic
coupling_constraints_symbolic
]

z_symbolic = [
private_variables_per_player_symbolic...
μ_private_symbolic
Expand Down
107 changes: 99 additions & 8 deletions test/Demo.jl
Original file line number Diff line number Diff line change
@@ -1,19 +1,28 @@
module Demo

using TrajectoryGamesBase:
TrajectoryGame, TrajectoryGameCost, ProductDynamics, GeneralSumCostStructure, PolygonEnvironment, solve_trajectory_game!
TrajectoryGamesBase,
TrajectoryGame,
TrajectoryGameCost,
ProductDynamics,
GeneralSumCostStructure,
PolygonEnvironment,
solve_trajectory_game!,
RecedingHorizonStrategy,
rollout
using TrajectoryGamesExamples: planar_double_integrator
using BlockArrays: blocks, mortar
using MCPTrajectoryGameSolver: Solver
using GLMakie: GLMakie
using Zygote: Zygote
using ParametricMCPs: ParametricMCPs

"""
Set up a simple two-player collision-avoidance game:
- each player wants to reach their own goal position encoded by the `context` vector
- both players want to avoid collisions
"""
function simple_game(; collision_avoidance_radius=1)
function simple_game(; collision_avoidance_radius = 1)
dynamics = let
single_player_dynamics = planar_double_integrator()
ProductDynamics([single_player_dynamics, single_player_dynamics])
Expand Down Expand Up @@ -47,10 +56,77 @@ function simple_game(; collision_avoidance_radius=1)
TrajectoryGame(dynamics, cost, environment, coupling_constraint)
end

function demo()
function demo_model_predictive_game_play()
simulation_horizon = 50
game = simple_game()
horizon = 10
solver = Solver(game, horizon; context_dimension=4)
initial_state = mortar([[-1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0]])
context = let
goal_p1 = [1.0, -0.1] # slightly offset goal to break symmetry
goal_p2 = -goal_p1
[goal_p1; goal_p2]
end
planning_horizon = 10
solver = Solver(game, planning_horizon; context_dimension = length(context))

receding_horizon_strategy = RecedingHorizonStrategy(;
solver,
game,
solve_kwargs = (; context),
turn_length = 2,
# TODO: we could also provide this as a more easy-to-use utility, maybe even via dispatch
# TODO: potentially allow the user to only warm-start the primals and or add noise
generate_initial_guess = function (last_strategy, state, time)
# only warm-start if the last strategy is converged / feasible
if !isnothing(last_strategy) &&
last_strategy.info.raw_solution.status == ParametricMCPs.PATHSolver.MCP_Solved
initial_guess = last_strategy.info.raw_solution.z
else
nothing
end
end,
)

# Set up the visualization in terms of `GLMakie.Observable` objectives for reactive programming
figure = GLMakie.Figure()
GLMakie.plot(
figure[1, 1],
game.env;
color = :lightgrey,
axis = (; aspect = GLMakie.DataAspect(), title = "Model predictive game play demo"),
)
joint_strategy =
GLMakie.Observable(solve_trajectory_game!(solver, game, initial_state; context))
GLMakie.plot!(figure[1, 1], joint_strategy)
for (player, color) in enumerate([:red, :blue])
GLMakie.scatter!(
figure[1, 1],
GLMakie.@lift(GLMakie.Point2f($joint_strategy.substrategies[player].xs[begin]));
color,
)
end
display(figure)

# visualization callback to update the observables on the fly
function get_info(strategy, state, time)
joint_strategy[] = strategy.receding_horizon_strategy
sleep(0.1) # so that there's some time to see the visualization
nothing # what ever we return here will be stored in `trajectory.infos` in case you need it for later inspection
end

# simulate the receding horizon strategy
trajectory = rollout(
game.dynamics,
receding_horizon_strategy,
initial_state,
simulation_horizon;
get_info,
)
end

function demo_inverse_game()
game = simple_game()
planning_horizon = 10
solver = Solver(game, planning_horizon; context_dimension = 4)
initial_state = mortar([[-1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0]])
# both players want to reach a goal position at (0, 1))
context = [0.0, 1.0, 0.0, 1.0]
Expand Down Expand Up @@ -80,15 +156,30 @@ function demo()
context_estimate -= learning_rate * ∇context
end

final_joint_strategy = solve_trajectory_game!(solver, game, initial_state; context=context_estimate)
final_joint_strategy =
solve_trajectory_game!(solver, game, initial_state; context = context_estimate)

# visualize the solution...
# ...for the initial context estimate
figure = GLMakie.Figure()
GLMakie.plot(figure[1, 1], game.env; axis=(; aspect=GLMakie.DataAspect(), title="Game solution for initial context estimate"))
GLMakie.plot(
figure[1, 1],
game.env;
axis = (;
aspect = GLMakie.DataAspect(),
title = "Game solution for initial context estimate",
),
)
GLMakie.plot!(figure[1, 1], initial_joint_strategy)
# ...and the optimized context estimate
GLMakie.plot(figure[1, 2], game.env; axis=(; aspect=GLMakie.DataAspect(), title="Game solution for optimized context estimate"))
GLMakie.plot(
figure[1, 2],
game.env;
axis = (;
aspect = GLMakie.DataAspect(),
title = "Game solution for optimized context estimate",
),
)
GLMakie.plot!(figure[1, 2], final_joint_strategy)
display(figure)

Expand Down
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
GLMakie = "e9467ef8-e4e7-5192-8a1a-b1aee30e663a"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
ParametricMCPs = "9b992ff8-05bb-4ea1-b9d2-5ef72d82f7ad"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Expand Down
3 changes: 2 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,8 @@ function main()
end

@testset "integration test" begin
Demo.demo()
Demo.demo_model_predictive_game_play()
Demo.demo_inverse_game()
end
end
end
Expand Down

0 comments on commit 8a60d35

Please sign in to comment.