Skip to content

Commit

Permalink
Merge pull request #13692 from trilinos/bartgol/piro-tempus-solver-mods
Browse files Browse the repository at this point in the history
  • Loading branch information
bartgol authored Jan 9, 2025
2 parents 4afdea6 + 060b901 commit b9eb229
Show file tree
Hide file tree
Showing 12 changed files with 341 additions and 136 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -40,31 +40,13 @@ class ObserverToTempusIntegrationObserverAdapter : public Tempus::IntegratorObse
//@{
/// Destructor

virtual ~ObserverToTempusIntegrationObserverAdapter();
virtual ~ObserverToTempusIntegrationObserverAdapter() = default;

/// Observe the beginning of the time integrator.
virtual void observeStartIntegrator(const Tempus::Integrator<Scalar>& integrator) override;

/// Observe the beginning of the time step loop.
virtual void observeStartTimeStep(const Tempus::Integrator<Scalar>& integrator) override;

/// Observe after the next time step size is selected.
virtual void observeNextTimeStep(const Tempus::Integrator<Scalar>& integrator) override;

/// Observe before Stepper takes step.
virtual void observeBeforeTakeStep(const Tempus::Integrator<Scalar>& integrator) override;

/// Observe after Stepper takes step.
virtual void observeAfterTakeStep(const Tempus::Integrator<Scalar>& integrator) override;

/// Observe after checking time step. Observer can still fail the time step here.
virtual void observeAfterCheckTimeStep(const Tempus::Integrator<Scalar>& integrator) override;

/// Observe the end of the time step loop.
virtual void observeEndTimeStep(const Tempus::Integrator<Scalar>& integrator) override;

/// Observe the end of the time integrator.
virtual void observeEndIntegrator(const Tempus::Integrator<Scalar>& integrator) override;
//@}

private:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,6 @@ Piro::ObserverToTempusIntegrationObserverAdapter<Scalar>::ObserverToTempusIntegr
previous_dt_ = 0.0;
}

template <typename Scalar>
Piro::ObserverToTempusIntegrationObserverAdapter<Scalar>::~ObserverToTempusIntegrationObserverAdapter()
{
//Nothing to do
}

template <typename Scalar>
void
Piro::ObserverToTempusIntegrationObserverAdapter<Scalar>::
Expand All @@ -64,49 +58,6 @@ observeStartIntegrator(const Tempus::Integrator<Scalar>& integrator)
this->observeTimeStep();
}

template <typename Scalar>
void
Piro::ObserverToTempusIntegrationObserverAdapter<Scalar>::
observeStartTimeStep(const Tempus::Integrator<Scalar>& )
{
//Nothing to do
}

template <typename Scalar>
void
Piro::ObserverToTempusIntegrationObserverAdapter<Scalar>::
observeNextTimeStep(const Tempus::Integrator<Scalar>& )
{
//Nothing to do
}

template <typename Scalar>
void
Piro::ObserverToTempusIntegrationObserverAdapter<Scalar>::
observeBeforeTakeStep(const Tempus::Integrator<Scalar>& )
{
//Nothing to do
}


template <typename Scalar>
void
Piro::ObserverToTempusIntegrationObserverAdapter<Scalar>::
observeAfterTakeStep(const Tempus::Integrator<Scalar>& )
{
//Nothing to do
}


template<class Scalar>
void
Piro::ObserverToTempusIntegrationObserverAdapter<Scalar>::
observeAfterCheckTimeStep(const Tempus::Integrator<Scalar>& integrator)
{
//Nothing to do
}


template<class Scalar>
void
Piro::ObserverToTempusIntegrationObserverAdapter<Scalar>::
Expand Down Expand Up @@ -139,35 +90,6 @@ observeEndTimeStep(const Tempus::Integrator<Scalar>& integrator)
this->observeTimeStep();
}


template <typename Scalar>
void
Piro::ObserverToTempusIntegrationObserverAdapter<Scalar>::
observeEndIntegrator(const Tempus::Integrator<Scalar>& integrator)
{
//this->observeTimeStep();

std::string exitStatus;
//const Scalar runtime = integrator.getIntegratorTimer()->totalElapsedTime();
if (integrator.getSolutionHistory()->getCurrentState()->getSolutionStatus() ==
Tempus::Status::FAILED or integrator.getStatus() == Tempus::Status::FAILED) {
exitStatus = "Time integration FAILURE!";
} else {
exitStatus = "Time integration complete.";
}
std::time_t end = std::time(nullptr);
const Scalar runtime = integrator.getIntegratorTimer()->totalElapsedTime();
const Teuchos::RCP<Teuchos::FancyOStream> out = integrator.getOStream();
Teuchos::OSTab ostab(out,0,"ScreenOutput");
*out << "============================================================================\n"
<< " Total runtime = " << runtime << " sec = "
<< runtime/60.0 << " min\n"
<< std::asctime(std::localtime(&end))
<< exitStatus << "\n"
<< std::endl;

}

template <typename Scalar>
void
Piro::ObserverToTempusIntegrationObserverAdapter<Scalar>::observeTimeStep()
Expand Down
1 change: 1 addition & 0 deletions packages/piro/src/Piro_TempusIntegrator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ class TempusIntegrator
//The following routine is only for adjoint sensitivities
Teuchos::RCP<const Thyra::MultiVectorBase<Scalar>> getDgDp() const;

Teuchos::RCP<const Tempus::Integrator<Scalar> > getIntegrator () const;
private:

Teuchos::RCP<Tempus::IntegratorBasic<Scalar> > basicIntegrator_;
Expand Down
19 changes: 19 additions & 0 deletions packages/piro/src/Piro_TempusIntegrator_Def.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -363,3 +363,22 @@ Piro::TempusIntegrator<Scalar>::getDgDp() const
"which is not of type Tempus::IntegratorAdjointSensitivity!\n");
}
}


template <typename Scalar>
Teuchos::RCP<const Tempus::Integrator<Scalar> >
Piro::TempusIntegrator<Scalar>:: getIntegrator () const
{
Teuchos::RCP<const Tempus::Integrator<Scalar> > out;
if (basicIntegrator_ != Teuchos::null) {
out = basicIntegrator_;
} else if (fwdSensIntegrator_ != Teuchos::null) {
out = fwdSensIntegrator_;
} else if (adjSensIntegrator_ != Teuchos::null) {
out = adjSensIntegrator_;
} else {
TEUCHOS_TEST_FOR_EXCEPTION(true, std::runtime_error,
"Error in Piro::TempusIntegrator: no integrator stored!\n");
}
return out;
}
1 change: 0 additions & 1 deletion packages/piro/src/Piro_TempusSolver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,6 @@ class TempusSolver

Teuchos::RCP<Thyra::ModelEvaluator<Scalar> > model_;
Teuchos::RCP<Thyra::ModelEvaluator<Scalar> > adjointModel_;
Teuchos::RCP<Thyra::ModelEvaluatorDefaultBase<double> > thyraModel_;

Scalar t_initial_;
Scalar t_final_;
Expand Down
28 changes: 7 additions & 21 deletions packages/piro/src/Piro_TempusSolver_Def.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -167,24 +167,10 @@ void Piro::TempusSolver<Scalar>::initialize(
const std::string stepperType = stepperPL->get<std::string>("Stepper Type", "Backward Euler");
//*out_ << "Stepper Type = " << stepperType << "\n";

//
// *out_ << "\nB) Create the Stratimikos linear solver factory ...\n";
//
// This is the linear solve strategy that will be used to solve for the
// linear system with the W.
//
Stratimikos::DefaultLinearSolverBuilder linearSolverBuilder;

#ifdef HAVE_PIRO_MUELU
Stratimikos::enableMueLu(linearSolverBuilder);
#endif

linearSolverBuilder.setParameterList(sublist(tempusPL, "Stratimikos", true));
tempusPL->validateParameters(*getValidTempusParameters(
tempusPL->get<std::string>("Integrator Name", "Tempus Integrator"),
integratorPL->get<std::string>("Stepper Name", "Tempus Stepper")
), 0);
RCP<Thyra::LinearOpWithSolveFactoryBase<double> > lowsFactory = createLinearSolveStrategy(linearSolverBuilder);

//
*out_ << "\nC) Create and initalize the forward model ...\n";
Expand Down Expand Up @@ -287,11 +273,6 @@ void Piro::TempusSolver<Scalar>::initialize(
}
}

// C.2) Create the Thyra-wrapped ModelEvaluator

thyraModel_ = rcp(new Thyra::DefaultModelEvaluatorWithSolveFactory<Scalar>(model_, lowsFactory));
const RCP<const Thyra::VectorSpaceBase<double> > x_space = thyraModel_->get_x_space();

//
*out_ << "\nD) Create the stepper and integrator for the forward problem ...\n";

Expand Down Expand Up @@ -650,15 +631,20 @@ template <typename Scalar>
void Piro::TempusSolver<Scalar>::
setObserver() const
{
Teuchos::RCP<Tempus::IntegratorObserverBasic<Scalar> > observer = Teuchos::null;
if (Teuchos::nonnull(piroObserver_)) {
// Do not create the tempus observer adapter if the user-provided Piro observer is already a tempus observer
Teuchos::RCP<Tempus::IntegratorObserver<Scalar> > observer;
observer = Teuchos::rcp_dynamic_cast<Tempus::IntegratorObserver<Scalar>>(piroObserver_);
if (Teuchos::is_null(observer) and Teuchos::nonnull(piroObserver_)) {
// The user did not provide a Tempus observer, so create an adapter one

//Get solutionHistory from integrator
const Teuchos::RCP<const Tempus::SolutionHistory<Scalar> > solutionHistory = piroTempusIntegrator_->getSolutionHistory();
const Teuchos::RCP<const Tempus::TimeStepControl<Scalar> > timeStepControl = piroTempusIntegrator_->getTimeStepControl();
//Create Tempus::IntegratorObserverBasic object
observer = Teuchos::rcp(new ObserverToTempusIntegrationObserverAdapter<Scalar>(solutionHistory,
timeStepControl, piroObserver_, supports_x_dotdot_, abort_on_fail_at_min_dt_, sens_method_));
}

if (Teuchos::nonnull(observer)) {
//Set observer in integrator
piroTempusIntegrator_->clearObservers();
Expand Down
12 changes: 11 additions & 1 deletion packages/piro/src/Piro_TransientSolver_Def.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,17 @@ Piro::TransientSolver<Scalar>::evalConvergedModelResponsesAndSensitivities(
// Solution at convergence is the response at index num_g_
RCP<Thyra::VectorBase<Scalar> > gx_out = outArgs.get_g(num_g_);
if (Teuchos::nonnull(gx_out)) {
Thyra::copy(*modelInArgs.get_x(), gx_out.ptr());
auto x = modelInArgs.get_x();
if (x->space()->isCompatible(*gx_out->space())) {
Thyra::copy(*modelInArgs.get_x(), gx_out.ptr());
} else {
*out_ << " WARNING [PIRO::TransientSolver::evalConvergedModelResponsesAndSensitivities]\n"
" The solution from inArgs (x) is incompatible with the last response in outArgs (gx),\n"
" which was created as a vector with the same vector-space of x. Since the responses\n"
" were created BEFORE calling the transient solver, the most likely explanation for this\n"
" incompatibility is that during the time integration the solution vector space changed,\n"
" for instance because the underlying spatial mesh was adapted (with topological changes).\n";
}
}

// Setup output for final evalution of underlying model
Expand Down
16 changes: 15 additions & 1 deletion packages/piro/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,19 @@ IF (Piro_ENABLE_NOX)
ENDIF(Piro_ENABLE_NOX)

IF (Piro_ENABLE_Tempus)
TRIBITS_ADD_EXECUTABLE_AND_TEST(
TempusSolver_UnitTests_Tpetra
SOURCES
Piro_TempusSolver_UnitTests.cpp
${TEUCHOS_STD_UNIT_TEST_MAIN}
Piro_Test_ThyraSupport.hpp
MockModelEval_A_Tpetra.hpp
MockModelEval_A_Tpetra.cpp
MatrixBased_LOWS.cpp
NUM_MPI_PROCS 1-4
STANDARD_PASS_OUTPUT
)

TRIBITS_ADD_EXECUTABLE_AND_TEST(
TempusSolver_SensitivitySinCos_Combined_FSA_UnitTests
SOURCES
Expand Down Expand Up @@ -251,6 +264,7 @@ TRIBITS_COPY_FILES_TO_BINARY_DIR(copyTestInputFiles
input_Analysis_ROL_ReducedSpace_TrustRegion_BoundConstrained_ExplicitAdjointME_NOXSolver.xml
input_Analysis_ROL_FullSpace_AugmentedLagrangian_BoundConstrained.xml
input_Tempus_BackwardEuler_SinCos.xml
input_tempus_be_nox_solver.yaml
SOURCE_DIR ${PACKAGE_SOURCE_DIR}/test
SOURCE_PREFIX "_"
EXEDEPS ${ThyraSolverTpetra_EXENAME} ${AnalysisDriverTpetra_EXENAME}
Expand Down Expand Up @@ -380,6 +394,7 @@ IF (PIRO_HAVE_EPETRA_STACK)
MockModelEval_A.cpp
NUM_MPI_PROCS 1-4
STANDARD_PASS_OUTPUT
TARGET_DEFINES TEST_USE_EPETRA
)
TRIBITS_ADD_EXECUTABLE_AND_TEST(
TempusSolverForwardOnly_UnitTests
Expand Down Expand Up @@ -408,7 +423,6 @@ IF (PIRO_HAVE_EPETRA_STACK)
NUM_MPI_PROCS 4
STANDARD_PASS_OUTPUT
)

ENDIF (Piro_ENABLE_Tempus)

TRIBITS_COPY_FILES_TO_BINARY_DIR(copyEpetraTestInputFiles
Expand Down
1 change: 1 addition & 0 deletions packages/piro/test/MockModelEval_A_Tpetra.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,7 @@ MockModelEval_A_Tpetra::createOutArgsImpl() const

result.setSupports(Thyra::ModelEvaluatorBase::OUT_ARG_f, true);
result.setSupports(Thyra::ModelEvaluatorBase::OUT_ARG_W_op, true);
result.setSupports(Thyra::ModelEvaluatorBase::OUT_ARG_W, true);
result.set_W_properties(Thyra::ModelEvaluatorBase::DerivativeProperties(
Thyra::ModelEvaluatorBase::DERIV_LINEARITY_UNKNOWN,
Thyra::ModelEvaluatorBase::DERIV_RANK_FULL,
Expand Down
Loading

0 comments on commit b9eb229

Please sign in to comment.