From cd6f553b7cb536b70130922cbef4d22a9536eaf5 Mon Sep 17 00:00:00 2001 From: Gareth Aneurin Tribello Date: Tue, 16 Jul 2024 14:50:16 +0100 Subject: [PATCH] Added faster version of matrix vector multiply that is used when not employing the chain --- src/core/ActionWithMatrix.h | 2 +- src/matrixtools/MatrixTimesVector.cpp | 69 +++++++++++++++++++++++++++ 2 files changed, 70 insertions(+), 1 deletion(-) diff --git a/src/core/ActionWithMatrix.h b/src/core/ActionWithMatrix.h index 4cc4de4ee3..ae27d7f8bc 100644 --- a/src/core/ActionWithMatrix.h +++ b/src/core/ActionWithMatrix.h @@ -75,7 +75,7 @@ class ActionWithMatrix : public ActionWithVector { //// This does some setup before we run over the row of the matrix virtual void setupForTask( const unsigned& task_index, std::vector& indices, MultiValue& myvals ) const = 0; /// Run over one row of the matrix - void performTask( const unsigned& task_index, MultiValue& myvals ) const override ; + virtual void performTask( const unsigned& task_index, MultiValue& myvals ) const ; /// Gather a row of the matrix void gatherStoredValue( const unsigned& valindex, const unsigned& code, const MultiValue& myvals, const unsigned& bufstart, std::vector& buffer ) const override; /// Gather all the data from the threads diff --git a/src/matrixtools/MatrixTimesVector.cpp b/src/matrixtools/MatrixTimesVector.cpp index 4725b5131c..1e1cd0d64b 100644 --- a/src/matrixtools/MatrixTimesVector.cpp +++ b/src/matrixtools/MatrixTimesVector.cpp @@ -46,6 +46,7 @@ class MatrixTimesVector : public ActionWithMatrix { unsigned getNumberOfColumns() const override { plumed_error(); } unsigned getNumberOfDerivatives(); void prepare() override ; + void performTask( const unsigned& task_index, MultiValue& myvals ) const override ; bool isInSubChain( unsigned& nder ) override { nder = arg_deriv_starts[0]; return true; } void setupForTask( const unsigned& task_index, std::vector& indices, MultiValue& myvals ) const ; void performTask( const std::string& controller, const unsigned& index1, const unsigned& index2, MultiValue& myvals ) const override; @@ -161,6 +162,74 @@ void MatrixTimesVector::prepare() { std::vector shape(1); shape[0] = getPntrToArgument(0)->getShape()[0]; myval->setShape(shape); } +void MatrixTimesVector::performTask( const unsigned& task_index, MultiValue& myvals ) const { + if( actionInChain() ) { ActionWithMatrix::performTask( task_index, myvals ); return; } + + if( sumrows ) { + unsigned n=getNumberOfArguments()-1; Value* myvec = getPntrToArgument(n); + for(unsigned i=0; igetNumberOfColumns(); + unsigned nmat = mymat->getRowLength(task_index); + double val=0; for(unsigned j=0; jget( task_index*ncol + j, false ); + unsigned ostrn = getConstPntrToComponent(i)->getPositionInStream(); + myvals.setValue( ostrn, val ); + + // And the derivatives + if( doNotCalculateDerivatives() ) continue; + + unsigned dloc = arg_deriv_starts[i] + task_index*ncol; + for(unsigned j=0; jgetRank()==1 ) { + Value* mymat = getPntrToArgument(0); + unsigned ncol = mymat->getNumberOfColumns(); + unsigned nmat = mymat->getRowLength(task_index); + unsigned dloc = arg_deriv_starts[0] + task_index*ncol; + for(unsigned i=0; iget( task_index*ncol + j, false )*myvec->get( mymat->getRowIndex( task_index, j ) ); + unsigned ostrn = getConstPntrToComponent(i)->getPositionInStream(); + myvals.setValue( ostrn, val ); + + // And the derivatives + if( doNotCalculateDerivatives() ) continue; + + for(unsigned j=0; jgetRowIndex( task_index, j ); + double vecval = myvec->get( kind ); + double matval = mymat->get( task_index*ncol + j, false ); + myvals.addDerivative( ostrn, dloc + j, vecval ); myvals.updateIndex( ostrn, dloc + j ); + myvals.addDerivative( ostrn, arg_deriv_starts[i+1] + kind, matval ); myvals.updateIndex( ostrn, arg_deriv_starts[i+1] + kind ); + } + } + } else { + unsigned n=getNumberOfArguments()-1; Value* myvec = getPntrToArgument(n); + for(unsigned i=0; igetNumberOfColumns(); + unsigned nmat = mymat->getRowLength(task_index); + double val=0; for(unsigned j=0; jget( task_index*ncol + j, false )*myvec->get( mymat->getRowIndex( task_index, j ) ); + unsigned ostrn = getConstPntrToComponent(i)->getPositionInStream(); + myvals.setValue( ostrn, val ); + + // And the derivatives + if( doNotCalculateDerivatives() ) continue; + + unsigned dloc = arg_deriv_starts[i] + task_index*ncol; + for(unsigned j=0; jgetRowIndex( task_index, j ); + double vecval = myvec->get( kind ); + double matval = mymat->get( task_index*ncol + j, false ); + myvals.addDerivative( ostrn, dloc + j, vecval ); myvals.updateIndex( ostrn, dloc + j ); + myvals.addDerivative( ostrn, arg_deriv_starts[n] + kind, matval ); myvals.updateIndex( ostrn, arg_deriv_starts[n] + kind ); + } + } + } +} + void MatrixTimesVector::setupForTask( const unsigned& task_index, std::vector& indices, MultiValue& myvals ) const { unsigned start_n = getPntrToArgument(0)->getShape()[0], size_v = getPntrToArgument(0)->getRowLength(task_index); if( indices.size()!=size_v+1 ) indices.resize( size_v + 1 );