Skip to content

Commit

Permalink
Optimisations of derivatives by back propegation
Browse files Browse the repository at this point in the history
  • Loading branch information
Gareth Aneurin Tribello authored and Gareth Aneurin Tribello committed Jul 19, 2024
1 parent 033d957 commit 2115528
Show file tree
Hide file tree
Showing 9 changed files with 30 additions and 27 deletions.
24 changes: 10 additions & 14 deletions src/adjmat/AdjacencyMatrixBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -308,12 +308,10 @@ void AdjacencyMatrixBase::performTask( const std::string& controller, const unsi
// Update dynamic list indices for virial
unsigned base = 3*getNumberOfAtoms(); for(unsigned j=0; j<9; ++j) myvals.updateIndex( w_ind, base+j );
// And the indices for the derivatives of the row of the matrix
if( chainContinuesAfterThisAction() ) {
unsigned nmat = getConstPntrToComponent(0)->getPositionInMatrixStash(), nmat_ind = myvals.getNumberOfMatrixRowDerivatives( nmat );
std::vector<unsigned>& matrix_indices( myvals.getMatrixRowDerivativeIndices( nmat ) );
matrix_indices[nmat_ind+0]=3*index2+0; matrix_indices[nmat_ind+1]=3*index2+1; matrix_indices[nmat_ind+2]=3*index2+2;
myvals.setNumberOfMatrixRowDerivatives( nmat, nmat_ind+3 );
}
unsigned nmat = getConstPntrToComponent(0)->getPositionInMatrixStash(), nmat_ind = myvals.getNumberOfMatrixRowDerivatives( nmat );
std::vector<unsigned>& matrix_indices( myvals.getMatrixRowDerivativeIndices( nmat ) );
matrix_indices[nmat_ind+0]=3*index2+0; matrix_indices[nmat_ind+1]=3*index2+1; matrix_indices[nmat_ind+2]=3*index2+2;
myvals.setNumberOfMatrixRowDerivatives( nmat, nmat_ind+3 );
}

// Calculate the components if we need them
Expand Down Expand Up @@ -354,20 +352,18 @@ void AdjacencyMatrixBase::performTask( const std::string& controller, const unsi
myvals.addDerivative( z_index, base+1, 0 ); myvals.addDerivative( z_index, base+4, 0 ); myvals.addDerivative( z_index, base+7, 0 );
myvals.addDerivative( z_index, base+2, -atom[0] ); myvals.addDerivative( z_index, base+5, -atom[1] ); myvals.addDerivative( z_index, base+8, -atom[2] );
for(unsigned k=0; k<9; ++k) { myvals.updateIndex( x_index, base+k ); myvals.updateIndex( y_index, base+k ); myvals.updateIndex( z_index, base+k ); }
if( chainContinuesAfterThisAction() ) {
for(unsigned k=1; k<4; ++k) {
unsigned nmat = getConstPntrToComponent(k)->getPositionInMatrixStash(), nmat_ind = myvals.getNumberOfMatrixRowDerivatives( nmat );
std::vector<unsigned>& matrix_indices( myvals.getMatrixRowDerivativeIndices( nmat ) );
matrix_indices[nmat_ind+0]=3*index2+0; matrix_indices[nmat_ind+1]=3*index2+1; matrix_indices[nmat_ind+2]=3*index2+2;
myvals.setNumberOfMatrixRowDerivatives( nmat, nmat_ind+3 );
}
for(unsigned k=1; k<4; ++k) {
unsigned nmat = getConstPntrToComponent(k)->getPositionInMatrixStash(), nmat_ind = myvals.getNumberOfMatrixRowDerivatives( nmat );
std::vector<unsigned>& matrix_indices( myvals.getMatrixRowDerivativeIndices( nmat ) );
matrix_indices[nmat_ind+0]=3*index2+0; matrix_indices[nmat_ind+1]=3*index2+1; matrix_indices[nmat_ind+2]=3*index2+2;
myvals.setNumberOfMatrixRowDerivatives( nmat, nmat_ind+3 );
}
}
}
}

void AdjacencyMatrixBase::runEndOfRowJobs( const unsigned& ind, const std::vector<unsigned> & indices, MultiValue& myvals ) const {
if( doNotCalculateDerivatives() || !chainContinuesAfterThisAction() ) return;
if( doNotCalculateDerivatives() ) return;

for(int k=0; k<getNumberOfComponents(); ++k) {
unsigned nmat = getConstPntrToComponent(k)->getPositionInMatrixStash(), nmat_ind = myvals.getNumberOfMatrixRowDerivatives( nmat );
Expand Down
2 changes: 1 addition & 1 deletion src/adjmat/TorsionsMatrix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ void TorsionsMatrix::performTask( const std::string& controller, const unsigned&
}

void TorsionsMatrix::runEndOfRowJobs( const unsigned& ival, const std::vector<unsigned> & indices, MultiValue& myvals ) const {
if( doNotCalculateDerivatives() || !matrixChainContinues() ) return ;
if( doNotCalculateDerivatives() ) return ;

unsigned mat1s = 3*ival, ss = getPntrToArgument(1)->getShape()[1];
unsigned nmat = getConstPntrToComponent(0)->getPositionInMatrixStash(), nmat_ind = myvals.getNumberOfMatrixRowDerivatives( nmat );
Expand Down
4 changes: 2 additions & 2 deletions src/core/ActionWithMatrix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -226,8 +226,8 @@ bool ActionWithMatrix::checkForTaskForce( const unsigned& itask, const Value* my

void ActionWithMatrix::gatherForcesOnStoredValue( const Value* myval, const unsigned& itask, const MultiValue& myvals, std::vector<double>& forces ) const {
if( myval->getRank()==1 ) { ActionWithVector::gatherForcesOnStoredValue( myval, itask, myvals, forces ); return; }
unsigned matind = myval->getPositionInMatrixStash();
for(unsigned j=0; j<forces.size(); ++j) forces[j] += myvals.getStashedMatrixForce( matind, j );
unsigned matind = myval->getPositionInMatrixStash(); const std::vector<unsigned>& mat_indices( myvals.getMatrixRowDerivativeIndices( matind ) );
for(unsigned i=0; i<myvals.getNumberOfMatrixRowDerivatives(matind); ++i) { unsigned kind = mat_indices[i]; forces[kind] += myvals.getStashedMatrixForce( matind, kind ); }
}

void ActionWithMatrix::clearMatrixElements( MultiValue& myvals ) const {
Expand Down
2 changes: 1 addition & 1 deletion src/core/ActionWithMatrix.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ class ActionWithMatrix : public ActionWithVector {
/// Check if there are forces we need to account for on this task
bool checkForTaskForce( const unsigned& itask, const Value* myval ) const override ;
/// This gathers the force on a particular value
void gatherForcesOnStoredValue( const Value* myval, const unsigned& itask, const MultiValue& myvals, std::vector<double>& forces ) const override;
virtual void gatherForcesOnStoredValue( const Value* myval, const unsigned& itask, const MultiValue& myvals, std::vector<double>& forces ) const;
};

inline
Expand Down
6 changes: 0 additions & 6 deletions src/core/ActionWithVector.h
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,6 @@ class ActionWithVector:
bool doNotCalculateDerivatives() const override ;
/// Are we running this command in a chain
bool actionInChain() const ;
bool chainContinuesAfterThisAction() const ;
/// This is overwritten within ActionWithMatrix and is used to build the chain of just matrix actions
virtual void finishChainBuild( ActionWithVector* act );
/// Check if there are any stored values in arguments
Expand Down Expand Up @@ -186,11 +185,6 @@ bool ActionWithVector::actionInChain() const {
return (action_to_do_before!=NULL);
}

inline
bool ActionWithVector::chainContinuesAfterThisAction() const {
return (action_to_do_after!=NULL);
}

inline
bool ActionWithVector::runInSerial() const {
return serial;
Expand Down
2 changes: 1 addition & 1 deletion src/matrixtools/MatrixTimesMatrix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ void MatrixTimesMatrix::performTask( const std::string& controller, const unsign
}

void MatrixTimesMatrix::runEndOfRowJobs( const unsigned& ival, const std::vector<unsigned> & indices, MultiValue& myvals ) const {
if( doNotCalculateDerivatives() || !matrixChainContinues() ) return ;
if( doNotCalculateDerivatives() ) return ;

unsigned mat1s = ival*getPntrToArgument(0)->getShape()[1];
unsigned nmult = getPntrToArgument(0)->getShape()[1], ss = getPntrToArgument(1)->getShape()[1];
Expand Down
4 changes: 2 additions & 2 deletions src/matrixtools/OuterProduct.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,13 +146,13 @@ void OuterProduct::performTask( const std::string& controller, const unsigned& i
addDerivativeOnVectorArgument( stored_vector1, 0, 0, index1, function.evaluateDeriv( 0, args ), myvals );
addDerivativeOnVectorArgument( stored_vector2, 0, 1, ind2, function.evaluateDeriv( 1, args ), myvals );
}
if( doNotCalculateDerivatives() || !matrixChainContinues() ) return ;
if( doNotCalculateDerivatives() ) return ;
unsigned nmat = getConstPntrToComponent(0)->getPositionInMatrixStash(), nmat_ind = myvals.getNumberOfMatrixRowDerivatives( nmat );
myvals.getMatrixRowDerivativeIndices( nmat )[nmat_ind] = arg_deriv_starts[1] + ind2; myvals.setNumberOfMatrixRowDerivatives( nmat, nmat_ind+1 );
}

void OuterProduct::runEndOfRowJobs( const unsigned& ival, const std::vector<unsigned> & indices, MultiValue& myvals ) const {
if( doNotCalculateDerivatives() || !matrixChainContinues() ) return ;
if( doNotCalculateDerivatives() ) return ;
unsigned nmat = getConstPntrToComponent(0)->getPositionInMatrixStash(), nmat_ind = myvals.getNumberOfMatrixRowDerivatives( nmat );
myvals.getMatrixRowDerivativeIndices( nmat )[nmat_ind] = ival; myvals.setNumberOfMatrixRowDerivatives( nmat, nmat_ind+1 );
}
Expand Down
6 changes: 6 additions & 0 deletions src/tools/MultiValue.h
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ class MultiValue {
void setNumberOfMatrixRowDerivatives( const unsigned& nmat, const unsigned& nind );
unsigned getNumberOfMatrixRowDerivatives( const unsigned& nmat ) const ;
std::vector<unsigned>& getMatrixRowDerivativeIndices( const unsigned& nmat );
const std::vector<unsigned>& getMatrixRowDerivativeIndices( const unsigned& nmat ) const ;
/// Stash the forces on the matrix
void addMatrixForce( const unsigned& imat, const unsigned& jind, const double& f );
double getStashedMatrixForce( const unsigned& imat, const unsigned& jind ) const ;
Expand Down Expand Up @@ -335,6 +336,11 @@ std::vector<unsigned>& MultiValue::getMatrixRowDerivativeIndices( const unsigned
plumed_dbg_assert( nmat<matrix_row_nderivatives.size() ); return matrix_row_derivative_indices[nmat];
}

inline
const std::vector<unsigned>& MultiValue::getMatrixRowDerivativeIndices( const unsigned& nmat ) const {
plumed_dbg_assert( nmat<matrix_row_nderivatives.size() ); return matrix_row_derivative_indices[nmat];
}

inline
void MultiValue::addMatrixForce( const unsigned& imat, const unsigned& jind, const double& f ) {
matrix_force_stash[imat*nderivatives + jind]+=f;
Expand Down
7 changes: 7 additions & 0 deletions src/valtools/VStack.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ class VStack : public ActionWithMatrix {
void runEndOfRowJobs( const unsigned& ival, const std::vector<unsigned> & indices, MultiValue& myvals ) const override ;
///
void getMatrixColumnTitles( std::vector<std::string>& argnames ) const override ;
///
void gatherForcesOnStoredValue( const Value* myval, const unsigned& itask, const MultiValue& myvals, std::vector<double>& forces ) const override ;
};

PLUMED_REGISTER_ACTION(VStack,"VSTACK")
Expand Down Expand Up @@ -137,6 +139,11 @@ void VStack::performTask( const std::string& controller, const unsigned& index1,
addDerivativeOnVectorArgument( stored[ind2], 0, ind2, index1, 1.0, myvals );
}

void VStack::gatherForcesOnStoredValue( const Value* myval, const unsigned& itask, const MultiValue& myvals, std::vector<double>& forces ) const {
unsigned matind = myval->getPositionInMatrixStash(); const std::vector<unsigned>& mat_indices( myvals.getMatrixRowDerivativeIndices( matind ) );

Check notice

Code scanning / CodeQL

Unused local variable Note

Variable mat_indices is not used.
for(unsigned i=0; i<forces.size(); ++i) forces[i] += myvals.getStashedMatrixForce( matind, i );
}

void VStack::runEndOfRowJobs( const unsigned& ival, const std::vector<unsigned> & indices, MultiValue& myvals ) const {
if( doNotCalculateDerivatives() || !matrixChainContinues() ) return ;

Expand Down

1 comment on commit 2115528

@PlumedBot
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Found broken examples in automatic/ANGLES.tmp
Found broken examples in automatic/ANN.tmp
Found broken examples in automatic/CAVITY.tmp
Found broken examples in automatic/CLASSICAL_MDS.tmp
Found broken examples in automatic/CLUSTER_DIAMETER.tmp
Found broken examples in automatic/CLUSTER_DISTRIBUTION.tmp
Found broken examples in automatic/CLUSTER_PROPERTIES.tmp
Found broken examples in automatic/CONSTANT.tmp
Found broken examples in automatic/CONTACT_MATRIX.tmp
Found broken examples in automatic/CONTACT_MATRIX_PROPER.tmp
Found broken examples in automatic/COORDINATIONNUMBER.tmp
Found broken examples in automatic/DFSCLUSTERING.tmp
Found broken examples in automatic/DISTANCE_FROM_CONTOUR.tmp
Found broken examples in automatic/EDS.tmp
Found broken examples in automatic/EMMI.tmp
Found broken examples in automatic/ENVIRONMENTSIMILARITY.tmp
Found broken examples in automatic/FIND_CONTOUR.tmp
Found broken examples in automatic/FIND_CONTOUR_SURFACE.tmp
Found broken examples in automatic/FIND_SPHERICAL_CONTOUR.tmp
Found broken examples in automatic/FOURIER_TRANSFORM.tmp
Found broken examples in automatic/FUNCPATHGENERAL.tmp
Found broken examples in automatic/FUNCPATHMSD.tmp
Found broken examples in automatic/FUNNEL.tmp
Found broken examples in automatic/FUNNEL_PS.tmp
Found broken examples in automatic/GHBFIX.tmp
Found broken examples in automatic/GPROPERTYMAP.tmp
Found broken examples in automatic/HBOND_MATRIX.tmp
Found broken examples in automatic/INCLUDE.tmp
Found broken examples in automatic/INCYLINDER.tmp
Found broken examples in automatic/INENVELOPE.tmp
Found broken examples in automatic/INTERPOLATE_GRID.tmp
Found broken examples in automatic/LOCAL_AVERAGE.tmp
Found broken examples in automatic/MAZE_OPTIMIZER_BIAS.tmp
Found broken examples in automatic/MAZE_RANDOM_ACCELERATION_MD.tmp
Found broken examples in automatic/MAZE_SIMULATED_ANNEALING.tmp
Found broken examples in automatic/MAZE_STEERED_MD.tmp
Found broken examples in automatic/METATENSOR.tmp
Found broken examples in automatic/MULTICOLVARDENS.tmp
Found broken examples in automatic/OUTPUT_CLUSTER.tmp
Found broken examples in automatic/PAMM.tmp
Found broken examples in automatic/PCA.tmp
Found broken examples in automatic/PCAVARS.tmp
Found broken examples in automatic/PIV.tmp
Found broken examples in automatic/PLUMED.tmp
Found broken examples in automatic/PYCVINTERFACE.tmp
Found broken examples in automatic/PYTHONFUNCTION.tmp
Found broken examples in automatic/Q3.tmp
Found broken examples in automatic/Q4.tmp
Found broken examples in automatic/Q6.tmp
Found broken examples in automatic/QUATERNION.tmp
Found broken examples in automatic/SIZESHAPE_POSITION_LINEAR_PROJ.tmp
Found broken examples in automatic/SIZESHAPE_POSITION_MAHA_DIST.tmp
Found broken examples in automatic/SPRINT.tmp
Found broken examples in automatic/TETRAHEDRALPORE.tmp
Found broken examples in automatic/TORSIONS.tmp
Found broken examples in automatic/WHAM_WEIGHTS.tmp
Found broken examples in AnalysisPP.md
Found broken examples in CollectiveVariablesPP.md
Found broken examples in MiscelaneousPP.md

Please sign in to comment.