From 8a60d35a782d3045266ca045a348061e93072cbe Mon Sep 17 00:00:00 2001 From: lassepe Date: Mon, 11 Mar 2024 18:20:41 +0100 Subject: [PATCH] Receding-horizon demo --- README.md | 8 ++-- src/solve.jl | 7 ++- src/solver_setup.jl | 3 +- test/Demo.jl | 107 ++++++++++++++++++++++++++++++++++++++++---- test/Project.toml | 1 + test/runtests.jl | 3 +- 6 files changed, 113 insertions(+), 16 deletions(-) diff --git a/README.md b/README.md index 4a8a370..fd649c4 100644 --- a/README.md +++ b/README.md @@ -24,11 +24,11 @@ For a full example of how to use this package, please consult the demo in [`test Start `julia --project` *from the repository root* and run the following commands: ```julia -using Pkg, TestEnv # install globally with `] add TestEnv` if you don't have this +using TestEnv, Revise # install globally with `] add TestEnv, Revise` if you don't have this TestEnv.activate() -pkg"instantiate" # ensures that the test dependencies are installed, only needed once -include("test/Demo.jl") -Demo.demo() +Revise.includet("test/Demo.jl") +Demo.demo_model_predictive_game_play() # example of receding-horizon interaction +Demo.demo_inverse_game() # example of fitting game parameters via differentiation of the game solver ``` ## Citation diff --git a/src/solve.jl b/src/solve.jl index 20c1535..903479e 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -35,7 +35,7 @@ function strategy_from_raw_solution(; raw_solution, game, solver) number_of_players = num_players(game) z_iter = Iterators.Stateful(raw_solution.z) - map(1:number_of_players) do player_index + substrategies = map(1:number_of_players) do player_index private_state_dimension = solver.dimensions.state_blocks[player_index] private_control_dimension = solver.dimensions.control_blocks[player_index] number_of_primals = @@ -44,7 +44,10 @@ function strategy_from_raw_solution(; raw_solution, game, solver) trajectory = unflatten_trajectory(z_private, private_state_dimension, private_control_dimension) OpenLoopStrategy(trajectory.xs, trajectory.us) - end |> TrajectoryGamesBase.JointStrategy + end + + info = (; raw_solution) + TrajectoryGamesBase.JointStrategy(substrategies, info) end function generate_initial_guess(solver, game, initial_state) diff --git a/src/solver_setup.jl b/src/solver_setup.jl index 43a1418..456a0b1 100644 --- a/src/solver_setup.jl +++ b/src/solver_setup.jl @@ -1,4 +1,4 @@ -struct Solver{T1, T2} +struct Solver{T1,T2} "The problem representation of the game via ParametricMCPs.ParametricMCP" mcp_problem_representation::T1 "A named tuple collecting all the problem dimension infos" @@ -133,6 +133,7 @@ function Solver(game::TrajectoryGame, horizon; context_dimension = 0, compute_se inequality_constraints_symoblic coupling_constraints_symbolic ] + z_symbolic = [ private_variables_per_player_symbolic... μ_private_symbolic diff --git a/test/Demo.jl b/test/Demo.jl index 4adf1a0..aeff813 100644 --- a/test/Demo.jl +++ b/test/Demo.jl @@ -1,19 +1,28 @@ module Demo using TrajectoryGamesBase: - TrajectoryGame, TrajectoryGameCost, ProductDynamics, GeneralSumCostStructure, PolygonEnvironment, solve_trajectory_game! + TrajectoryGamesBase, + TrajectoryGame, + TrajectoryGameCost, + ProductDynamics, + GeneralSumCostStructure, + PolygonEnvironment, + solve_trajectory_game!, + RecedingHorizonStrategy, + rollout using TrajectoryGamesExamples: planar_double_integrator using BlockArrays: blocks, mortar using MCPTrajectoryGameSolver: Solver using GLMakie: GLMakie using Zygote: Zygote +using ParametricMCPs: ParametricMCPs """ Set up a simple two-player collision-avoidance game: - each player wants to reach their own goal position encoded by the `context` vector - both players want to avoid collisions """ -function simple_game(; collision_avoidance_radius=1) +function simple_game(; collision_avoidance_radius = 1) dynamics = let single_player_dynamics = planar_double_integrator() ProductDynamics([single_player_dynamics, single_player_dynamics]) @@ -47,10 +56,77 @@ function simple_game(; collision_avoidance_radius=1) TrajectoryGame(dynamics, cost, environment, coupling_constraint) end -function demo() +function demo_model_predictive_game_play() + simulation_horizon = 50 game = simple_game() - horizon = 10 - solver = Solver(game, horizon; context_dimension=4) + initial_state = mortar([[-1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0]]) + context = let + goal_p1 = [1.0, -0.1] # slightly offset goal to break symmetry + goal_p2 = -goal_p1 + [goal_p1; goal_p2] + end + planning_horizon = 10 + solver = Solver(game, planning_horizon; context_dimension = length(context)) + + receding_horizon_strategy = RecedingHorizonStrategy(; + solver, + game, + solve_kwargs = (; context), + turn_length = 2, + # TODO: we could also provide this as a more easy-to-use utility, maybe even via dispatch + # TODO: potentially allow the user to only warm-start the primals and or add noise + generate_initial_guess = function (last_strategy, state, time) + # only warm-start if the last strategy is converged / feasible + if !isnothing(last_strategy) && + last_strategy.info.raw_solution.status == ParametricMCPs.PATHSolver.MCP_Solved + initial_guess = last_strategy.info.raw_solution.z + else + nothing + end + end, + ) + + # Set up the visualization in terms of `GLMakie.Observable` objectives for reactive programming + figure = GLMakie.Figure() + GLMakie.plot( + figure[1, 1], + game.env; + color = :lightgrey, + axis = (; aspect = GLMakie.DataAspect(), title = "Model predictive game play demo"), + ) + joint_strategy = + GLMakie.Observable(solve_trajectory_game!(solver, game, initial_state; context)) + GLMakie.plot!(figure[1, 1], joint_strategy) + for (player, color) in enumerate([:red, :blue]) + GLMakie.scatter!( + figure[1, 1], + GLMakie.@lift(GLMakie.Point2f($joint_strategy.substrategies[player].xs[begin])); + color, + ) + end + display(figure) + + # visualization callback to update the observables on the fly + function get_info(strategy, state, time) + joint_strategy[] = strategy.receding_horizon_strategy + sleep(0.1) # so that there's some time to see the visualization + nothing # what ever we return here will be stored in `trajectory.infos` in case you need it for later inspection + end + + # simulate the receding horizon strategy + trajectory = rollout( + game.dynamics, + receding_horizon_strategy, + initial_state, + simulation_horizon; + get_info, + ) +end + +function demo_inverse_game() + game = simple_game() + planning_horizon = 10 + solver = Solver(game, planning_horizon; context_dimension = 4) initial_state = mortar([[-1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0]]) # both players want to reach a goal position at (0, 1)) context = [0.0, 1.0, 0.0, 1.0] @@ -80,15 +156,30 @@ function demo() context_estimate -= learning_rate * ∇context end - final_joint_strategy = solve_trajectory_game!(solver, game, initial_state; context=context_estimate) + final_joint_strategy = + solve_trajectory_game!(solver, game, initial_state; context = context_estimate) # visualize the solution... # ...for the initial context estimate figure = GLMakie.Figure() - GLMakie.plot(figure[1, 1], game.env; axis=(; aspect=GLMakie.DataAspect(), title="Game solution for initial context estimate")) + GLMakie.plot( + figure[1, 1], + game.env; + axis = (; + aspect = GLMakie.DataAspect(), + title = "Game solution for initial context estimate", + ), + ) GLMakie.plot!(figure[1, 1], initial_joint_strategy) # ...and the optimized context estimate - GLMakie.plot(figure[1, 2], game.env; axis=(; aspect=GLMakie.DataAspect(), title="Game solution for optimized context estimate")) + GLMakie.plot( + figure[1, 2], + game.env; + axis = (; + aspect = GLMakie.DataAspect(), + title = "Game solution for optimized context estimate", + ), + ) GLMakie.plot!(figure[1, 2], final_joint_strategy) display(figure) diff --git a/test/Project.toml b/test/Project.toml index 2d0f572..5c8a19c 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -3,6 +3,7 @@ BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e" FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" GLMakie = "e9467ef8-e4e7-5192-8a1a-b1aee30e663a" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +ParametricMCPs = "9b992ff8-05bb-4ea1-b9d2-5ef72d82f7ad" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/test/runtests.jl b/test/runtests.jl index 9801b33..f468c9c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -163,7 +163,8 @@ function main() end @testset "integration test" begin - Demo.demo() + Demo.demo_model_predictive_game_play() + Demo.demo_inverse_game() end end end