From b9b8969c9abb7476ab4f5ec375b887197bb32438 Mon Sep 17 00:00:00 2001 From: Gareth Aneurin Tribello Date: Thu, 4 Jul 2024 16:30:25 +0100 Subject: [PATCH] Now turning off derivatives on forward pass through task loop when storing values as they are only needed on the backward pass to calculate forces --- src/core/ActionWithVector.cpp | 18 +++++++++++++++--- src/core/ActionWithVector.h | 4 ++++ 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/src/core/ActionWithVector.cpp b/src/core/ActionWithVector.cpp index b137da52f9..f68c5b83f5 100644 --- a/src/core/ActionWithVector.cpp +++ b/src/core/ActionWithVector.cpp @@ -35,7 +35,7 @@ Option interpretEnvString(const char* env,const char* str) { if(!std::strcmp(str,"no"))return Option::no; plumed_error()<<"Cannot understand env var "<& ActionWithVector::getListOfActiveTasks( ActionWithVector* return active_tasks; } +bool ActionWithVector::doNotCalculateDerivatives() const { + if( forwardPass ) return true; + return ActionWithValue::doNotCalculateDerivatives(); +} + void ActionWithVector::runAllTasks() { // Skip this if this is done elsewhere if( action_to_do_before ) return; @@ -471,6 +477,12 @@ void ActionWithVector::runAllTasks() { // Now do all preparations required to run all the tasks // prepareForTaskLoop(); + if( !action_to_do_after ) { + forwardPass=true; + for(unsigned i=0; igetRank()==0 ) { forwardPass=false; break; } + } + } // Get the total number of streamed quantities that we need unsigned nquants=0, nmatrices=0, maxcol=0, nbooks=0; getNumberOfStreamedQuantities( getLabel(), nquants, nmatrices, maxcol, nbooks ); @@ -509,7 +521,7 @@ void ActionWithVector::runAllTasks() { // MPI Gather everything if( !serial && buffer.size()>0 ) gatherProcesses( buffer ); - finishComputations( buffer ); + finishComputations( buffer ); forwardPass=false; } void ActionWithVector::gatherThreads( const unsigned& nt, const unsigned& bufsize, const std::vector& omp_buffer, std::vector& buffer, MultiValue& myvals ) { diff --git a/src/core/ActionWithVector.h b/src/core/ActionWithVector.h index c72e7d09bb..6327715da5 100644 --- a/src/core/ActionWithVector.h +++ b/src/core/ActionWithVector.h @@ -39,6 +39,8 @@ class ActionWithVector: private: /// Is the calculation to be done in serial bool serial; +/// Are we in the forward pass through the calculation + bool forwardPass; /// The buffer that we use (we keep a copy here to avoid resizing) std::vector buffer; /// The list of active tasks @@ -119,6 +121,8 @@ class ActionWithVector: virtual void prepare() override; void retrieveAtoms( const bool& force=false ) override; void calculateNumericalDerivatives(ActionWithValue* av) override; +/// Turn off the calculation of the derivatives during the forward pass through a calculation + bool doNotCalculateDerivatives() const override ; /// Are we running this command in a chain bool actionInChain() const ; /// This is overwritten within ActionWithMatrix and is used to build the chain of just matrix actions