Skip to content

Commit

Permalink
Merge pull request #51 from HJReachability/fix/al_solver_problem_chan…
Browse files Browse the repository at this point in the history
…ging

Add support for solver resetting problem and constraints
  • Loading branch information
dfridovi authored Nov 7, 2020
2 parents ce8a677 + f47dd52 commit 9d14d1b
Show file tree
Hide file tree
Showing 6 changed files with 33 additions and 19 deletions.
2 changes: 1 addition & 1 deletion include/ilqgames/constraint/constraint.h
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ class Constraint : public Cost {
explicit Constraint(bool is_equality, const std::string& name)
: Cost(1.0, name),
is_equality_(is_equality),
lambdas_(time::kNumTimeSteps, 0.0) {}
lambdas_(time::kNumTimeSteps, constants::kDefaultLambda) {}

// Modify derivatives to account for the multipliers and the quadratic term in
// the augmented Lagrangian. The inputs are the derivatives of g in the
Expand Down
6 changes: 6 additions & 0 deletions include/ilqgames/solver/solver_params.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,12 @@ struct SolverParams {
float geometric_mu_downscaling = 0.5;
float geometric_lambda_downscaling = 0.5;
float constraint_error_tolerance = 1e-1;

// Should the solver reset problem/constraint params to their initial values.
// NOTE: defaults to true.
bool reset_problem = true;
bool reset_lambdas = true;
bool reset_mu = true;
}; // struct SolverParams

} // namespace ilqgames
Expand Down
3 changes: 2 additions & 1 deletion include/ilqgames/utils/solver_log.h
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,8 @@ class SolverLog : private Uncopyable {
// Get index corresponding to the time step immediately before the given time.
size_t TimeToIndex(Time t) const {
return static_cast<size_t>(
std::max(constants::kSmallNumber, t - InitialTime()) / time::kTimeStep);
std::max<Time>(constants::kSmallNumber, t - InitialTime()) /
time::kTimeStep);
}

// Get time stamp corresponding to a particular index.
Expand Down
17 changes: 4 additions & 13 deletions include/ilqgames/utils/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ struct Empty {};
// ------------------------------- CONSTANTS -------------------------------- //

namespace constants {
#ifdef __APPLE__

// Acceleration due to gravity (m/s/s).
static constexpr float kGravity = 9.81;

Expand All @@ -124,19 +124,10 @@ static constexpr float kInfinity = std::numeric_limits<float>::infinity();
// Constant for invalid values.
static constexpr float kInvalidValue = std::numeric_limits<float>::quiet_NaN();

#else
// Acceleration due to gravity (m/s/s).
static constexpr double kGravity = 9.81;

// Small number for use in approximate equality checking.
static constexpr double kSmallNumber = 1e-4;

// Float precision infinity.
static constexpr double kInfinity = std::numeric_limits<float>::infinity();
// Default multiplier values.
static constexpr float kDefaultLambda = 0.0;
static constexpr float kDefaultMu = 10.0;

// Constant for invalid values.
static constexpr double kInvalidValue = std::numeric_limits<float>::quiet_NaN();
#endif
} // namespace constants

namespace time {
Expand Down
22 changes: 19 additions & 3 deletions src/augmented_lagrangian_solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,10 @@ std::shared_ptr<SolverLog> AugmentedLagrangianSolver::Solve(bool* success,
Time max_runtime) {
if (success) *success = true;

// Cache initial problem solution so we can restore it at the end.
const auto& initial_op = problem_->CurrentOperatingPoint();
const auto& initial_strategies = problem_->CurrentStrategies();

// Create new log.
std::shared_ptr<SolverLog> log = CreateNewLog();

Expand Down Expand Up @@ -186,9 +190,21 @@ std::shared_ptr<SolverLog> AugmentedLagrangianSolver::Solve(bool* success,
if (success) *success = false;
}

// Update problem solution to make sure we get the final log output.
problem_->OverwriteSolution(log->FinalOperatingPoint(),
log->FinalStrategies());
// Maybe restore initial solution to this problem.
if (params_.reset_problem)
problem_->OverwriteSolution(initial_op, initial_strategies);

// Reset all multipliers.
if (params_.reset_lambdas) {
for (auto& pc : problem_->PlayerCosts()) {
for (const auto& constraint : pc.StateConstraints())
constraint->ScaleLambdas(constants::kDefaultLambda);
for (const auto& pair : pc.ControlConstraints())
pair.second->ScaleLambdas(constants::kDefaultLambda);
}
}

if (params_.reset_mu) Constraint::GlobalMu() = constants::kDefaultMu;

return log;
}
Expand Down
2 changes: 1 addition & 1 deletion src/constraint.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@

namespace ilqgames {

float Constraint::mu_ = 10.0;
float Constraint::mu_ = constants::kDefaultMu;

void Constraint::ModifyDerivatives(Time t, float g, float* dx, float* ddx,
float* dy, float* ddy, float* dxdy) const {
Expand Down

0 comments on commit 9d14d1b

Please sign in to comment.