From 1242e76c19f2eb899fc6aabbe6ae98ed77ab520a Mon Sep 17 00:00:00 2001 From: Andrei-Carlo Papuc Date: Tue, 9 Apr 2024 13:46:38 +0200 Subject: [PATCH 1/2] Allow for parametric mcp options to be passed on to ParametricMCP --- src/solver_setup.jl | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/solver_setup.jl b/src/solver_setup.jl index 456a0b1..45f19c9 100644 --- a/src/solver_setup.jl +++ b/src/solver_setup.jl @@ -5,7 +5,13 @@ struct Solver{T1,T2} dimensions::T2 end -function Solver(game::TrajectoryGame, horizon; context_dimension = 0, compute_sensitivities = true) +function Solver( + game::TrajectoryGame, + horizon; + context_dimension = 0, + compute_sensitivities = true, + parametric_mcp_options = (;), +) dimensions = let state_blocks = [state_dim(game.dynamics, player_index) for player_index in 1:num_players(game)] @@ -161,6 +167,7 @@ function Solver(game::TrajectoryGame, horizon; context_dimension = 0, compute_se lower_bounds, upper_bounds; compute_sensitivities, + parametric_mcp_options..., ) Solver(mcp_problem_representation, dimensions) From 45479139d9dfc823dac6754fdb2cec82b342f4d0 Mon Sep 17 00:00:00 2001 From: lassepe Date: Tue, 16 Apr 2024 17:52:49 +0200 Subject: [PATCH 2/2] Add tests --- test/Project.toml | 1 + test/runtests.jl | 38 ++++++++++++++++++++++++-------------- 2 files changed, 25 insertions(+), 14 deletions(-) diff --git a/test/Project.toml b/test/Project.toml index 5c8a19c..ed89ce2 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -6,6 +6,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" ParametricMCPs = "9b992ff8-05bb-4ea1-b9d2-5ef72d82f7ad" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" +Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" TrajectoryGamesBase = "ac1ac542-73eb-4349-ae1b-660ab3609574" TrajectoryGamesExamples = "ff3fa34c-8d8f-519c-b5bc-31760c52507a" diff --git a/test/runtests.jl b/test/runtests.jl index f468c9c..37e4094 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -9,6 +9,7 @@ using StatsBase: mean using Zygote: Zygote using FiniteDiff: FiniteDiff using Random: Random +using Symbolics: Symbolics include("Demo.jl") @@ -55,25 +56,27 @@ function isfeasible(env::TrajectoryGamesBase.PolygonEnvironment, trajectory; tol end |> all end -function input_sanity(; solver, solver_wrong_context, game, initial_state, context) +function input_sanity(; solver, game, initial_state, context) @testset "input sanity" begin @test_throws ArgumentError TrajectoryGamesBase.solve_trajectory_game!( solver, game, initial_state, ) + context_with_wrong_size = [context; 0.0] @test_throws ArgumentError TrajectoryGamesBase.solve_trajectory_game!( - solver_wrong_context, + solver, game, initial_state; - context + context=context_with_wrong_size, ) + multipliers_despite_no_shared_constraints = [1] @test_throws ArgumentError TrajectoryGamesBase.solve_trajectory_game!( solver, game, initial_state; context, - shared_constraint_premultipliers=[1] + shared_constraint_premultipliers=multipliers_despite_no_shared_constraints, ) end end @@ -110,7 +113,7 @@ function backward_pass_sanity(; game, initial_state, rng=Random.MersenneTwister(1), - θs=[randn(rng, 4) for _ in 1:10] + θs=[randn(rng, 4) for _ in 1:10], ) @testset "backward pass sanity" begin function loss(θ) @@ -119,7 +122,7 @@ function backward_pass_sanity(; solver, game, initial_state; - context=θ + context=θ, ) sum(strategy.substrategies) do substrategy @@ -144,22 +147,29 @@ function main() context = [0.0, 1.0, 0.0, 1.0] initial_state = mortar([[1.0, 0, 0, 0], [-1.0, 0, 0, 0]]) - local solver, solver_wrong_context + local solver, solver_parallel @testset "Tests" begin @testset "solver setup" begin solver = MCPTrajectoryGameSolver.Solver(game, horizon; context_dimension=length(context)) - solver_wrong_context = - MCPTrajectoryGameSolver.Solver(game, horizon; context_dimension=(length(context) + 1)) + # exercise some inner solver options... + solver_parallel = MCPTrajectoryGameSolver.Solver( + game, + horizon; + context_dimension=length(context), + parametric_mcp_options=(; parallel=Symbolics.ShardedForm()), + ) end @testset "solve" begin - input_sanity(; solver, solver_wrong_context, game, initial_state, context) - strategy = - TrajectoryGamesBase.solve_trajectory_game!(solver, game, initial_state; context) - forward_pass_sanity(; solver, game, initial_state, context, horizon, strategy) - backward_pass_sanity(; solver, game, initial_state) + for solver in [solver, solver_parallel] + input_sanity(; solver, game, initial_state, context) + strategy = + TrajectoryGamesBase.solve_trajectory_game!(solver, game, initial_state; context) + forward_pass_sanity(; solver, game, initial_state, context, horizon, strategy) + backward_pass_sanity(; solver, game, initial_state) + end end @testset "integration test" begin