diff --git a/src/adjmat/AdjacencyMatrixBase.cpp b/src/adjmat/AdjacencyMatrixBase.cpp index 8b569829da..5409397ea5 100644 --- a/src/adjmat/AdjacencyMatrixBase.cpp +++ b/src/adjmat/AdjacencyMatrixBase.cpp @@ -309,10 +309,10 @@ void AdjacencyMatrixBase::performTask( const std::string& controller, const unsi unsigned base = 3*getNumberOfAtoms(); for(unsigned j=0; j<9; ++j) myvals.updateIndex( w_ind, base+j ); // And the indices for the derivatives of the row of the matrix if( chainContinuesAfterThisAction() ) { - unsigned nmat = getConstPntrToComponent(0)->getPositionInMatrixStash(), nmat_ind = myvals.getNumberOfMatrixRowDerivatives( nmat ); - std::vector& matrix_indices( myvals.getMatrixRowDerivativeIndices( nmat ) ); - matrix_indices[nmat_ind+0]=3*index2+0; matrix_indices[nmat_ind+1]=3*index2+1; matrix_indices[nmat_ind+2]=3*index2+2; - myvals.setNumberOfMatrixRowDerivatives( nmat, nmat_ind+3 ); + unsigned nmat = getConstPntrToComponent(0)->getPositionInMatrixStash(), nmat_ind = myvals.getNumberOfMatrixRowDerivatives( nmat ); + std::vector& matrix_indices( myvals.getMatrixRowDerivativeIndices( nmat ) ); + matrix_indices[nmat_ind+0]=3*index2+0; matrix_indices[nmat_ind+1]=3*index2+1; matrix_indices[nmat_ind+2]=3*index2+2; + myvals.setNumberOfMatrixRowDerivatives( nmat, nmat_ind+3 ); } } @@ -355,12 +355,12 @@ void AdjacencyMatrixBase::performTask( const std::string& controller, const unsi myvals.addDerivative( z_index, base+2, -atom[0] ); myvals.addDerivative( z_index, base+5, -atom[1] ); myvals.addDerivative( z_index, base+8, -atom[2] ); for(unsigned k=0; k<9; ++k) { myvals.updateIndex( x_index, base+k ); myvals.updateIndex( y_index, base+k ); myvals.updateIndex( z_index, base+k ); } if( chainContinuesAfterThisAction() ) { - for(unsigned k=1; k<4; ++k) { - unsigned nmat = getConstPntrToComponent(k)->getPositionInMatrixStash(), nmat_ind = myvals.getNumberOfMatrixRowDerivatives( nmat ); - std::vector& matrix_indices( myvals.getMatrixRowDerivativeIndices( nmat ) ); - matrix_indices[nmat_ind+0]=3*index2+0; matrix_indices[nmat_ind+1]=3*index2+1; matrix_indices[nmat_ind+2]=3*index2+2; - myvals.setNumberOfMatrixRowDerivatives( nmat, nmat_ind+3 ); - } + for(unsigned k=1; k<4; ++k) { + unsigned nmat = getConstPntrToComponent(k)->getPositionInMatrixStash(), nmat_ind = myvals.getNumberOfMatrixRowDerivatives( nmat ); + std::vector& matrix_indices( myvals.getMatrixRowDerivativeIndices( nmat ) ); + matrix_indices[nmat_ind+0]=3*index2+0; matrix_indices[nmat_ind+1]=3*index2+1; matrix_indices[nmat_ind+2]=3*index2+2; + myvals.setNumberOfMatrixRowDerivatives( nmat, nmat_ind+3 ); + } } } } diff --git a/src/core/ActionWithVector.h b/src/core/ActionWithVector.h index c8daa8980b..648e8ddfb6 100644 --- a/src/core/ActionWithVector.h +++ b/src/core/ActionWithVector.h @@ -136,7 +136,7 @@ class ActionWithVector: /// This is overridden in ActionWithMatrix virtual void getAllActionLabelsInMatrixChain( std::vector& matchain ) const {} /// Get the number of derivatives in the stream - void getNumberOfStreamedDerivatives( unsigned& nderivatives, Value* stopat ); + virtual void getNumberOfStreamedDerivatives( unsigned& nderivatives, Value* stopat ); /// Get every the label of every value that is calculated in this chain void getAllActionLabelsInChain( std::vector& mylabels ) const ; /// We override clearInputForces here to ensure that forces are deleted from all values @@ -186,7 +186,7 @@ bool ActionWithVector::actionInChain() const { return (action_to_do_before!=NULL); } -inline +inline bool ActionWithVector::chainContinuesAfterThisAction() const { return (action_to_do_after!=NULL); } diff --git a/src/matrixtools/MatrixTimesVector.cpp b/src/matrixtools/MatrixTimesVector.cpp index 1e1cd0d64b..6e54bf65f3 100644 --- a/src/matrixtools/MatrixTimesVector.cpp +++ b/src/matrixtools/MatrixTimesVector.cpp @@ -44,7 +44,8 @@ class MatrixTimesVector : public ActionWithMatrix { explicit MatrixTimesVector(const ActionOptions&); std::string getOutputComponentDescription( const std::string& cname, const Keywords& keys ) const override ; unsigned getNumberOfColumns() const override { plumed_error(); } - unsigned getNumberOfDerivatives(); + void getNumberOfStreamedDerivatives( unsigned& nderivatives, Value* stopat ) override ; + unsigned getNumberOfDerivatives() override ; void prepare() override ; void performTask( const unsigned& task_index, MultiValue& myvals ) const override ; bool isInSubChain( unsigned& nder ) override { nder = arg_deriv_starts[0]; return true; } @@ -162,6 +163,16 @@ void MatrixTimesVector::prepare() { std::vector shape(1); shape[0] = getPntrToArgument(0)->getShape()[0]; myval->setShape(shape); } +void MatrixTimesVector::getNumberOfStreamedDerivatives( unsigned& nderivatives, Value* stopat ) { + if( actionInChain() ) { ActionWithVector::getNumberOfStreamedDerivatives( nderivatives, stopat ); return; } + + nderivatives = 0; + for(unsigned i=0; igetNumberOfStoredValues(); + } +} + void MatrixTimesVector::performTask( const unsigned& task_index, MultiValue& myvals ) const { if( actionInChain() ) { ActionWithMatrix::performTask( task_index, myvals ); return; }