From 767169b3a1de91c4849435ffd95b3382812fa912 Mon Sep 17 00:00:00 2001 From: Gareth Aneurin Tribello Date: Mon, 15 Jul 2024 18:04:29 +0100 Subject: [PATCH] Changes to speed up multiplication of sparse matrices by vectors --- src/core/ActionWithMatrix.h | 10 ++-------- src/matrixtools/MatrixTimesVector.cpp | 24 ++++++++++++++++++------ 2 files changed, 20 insertions(+), 14 deletions(-) diff --git a/src/core/ActionWithMatrix.h b/src/core/ActionWithMatrix.h index c64224ccd5..4cc4de4ee3 100644 --- a/src/core/ActionWithMatrix.h +++ b/src/core/ActionWithMatrix.h @@ -134,14 +134,8 @@ inline void ActionWithMatrix::addDerivativeOnMatrixArgument( const bool& inchain, const unsigned& ival, const unsigned& jarg, const unsigned& irow, const unsigned& jcol, const double& der, MultiValue& myvals ) const { plumed_dbg_assert( jarggetRank()==2 && !getPntrToArgument(jarg)->hasDerivatives() ); unsigned ostrn = getConstPntrToComponent(ival)->getPositionInStream(), vstart=arg_deriv_starts[jarg]; - if( !inchain && getPntrToArgument(jarg)->getNumberOfColumns()getShape()[1] ) { - unsigned dloc = vstart + irow*getPntrToArgument(jarg)->getNumberOfColumns(); Value* myarg=getPntrToArgument(jarg); - for(unsigned i=0; igetRowLength(irow); ++i) { - if( myarg->getRowIndex(irow,i)==jcol ) { myvals.addDerivative( ostrn, dloc+i, der ); myvals.updateIndex( ostrn, dloc+i ); return; } - } - plumed_merror("could not find element of sparse matrix to add derivative to"); - } else if( !inchain ) { - unsigned dloc = vstart + irow*getPntrToArgument(jarg)->getShape()[1] + jcol; + if( !inchain ) { + unsigned dloc = vstart + irow*getPntrToArgument(jarg)->getNumberOfColumns() + jcol; myvals.addDerivative( ostrn, dloc, der ); myvals.updateIndex( ostrn, dloc ); } else { unsigned istrn = getPntrToArgument(jarg)->getPositionInStream(); diff --git a/src/matrixtools/MatrixTimesVector.cpp b/src/matrixtools/MatrixTimesVector.cpp index f9ea71ca65..7554c96a89 100644 --- a/src/matrixtools/MatrixTimesVector.cpp +++ b/src/matrixtools/MatrixTimesVector.cpp @@ -152,32 +152,44 @@ void MatrixTimesVector::prepare() { void MatrixTimesVector::setupForTask( const unsigned& task_index, std::vector& indices, MultiValue& myvals ) const { unsigned start_n = getPntrToArgument(0)->getShape()[0], size_v = getPntrToArgument(0)->getRowLength(task_index); if( indices.size()!=size_v+1 ) indices.resize( size_v + 1 ); - for(unsigned i=0; igetRowIndex( task_index, i ); + for(unsigned i=0; i=getPntrToArgument(0)->getShape()[0] ) ind2 = index2 - getPntrToArgument(0)->getShape()[0]; if( getPntrToArgument(1)->getRank()==1 ) { + double matval = 0; Value* myarg = getPntrToArgument(0); unsigned vcol = ind2; + if( !myarg->valueHasBeenSet() ) matval = myvals.get( myarg->getPositionInStream() ); + else { + matval = myarg->get( index1*myarg->getNumberOfColumns() + ind2, false ); + vcol = getPntrToArgument(0)->getRowIndex( index1, ind2 ); + } for(unsigned i=0; igetPositionInStream(); - double matval = getElementOfMatrixArgument( 0, index1, ind2, myvals ), vecval=getArgumentElement( i+1, ind2, myvals ); + double vecval=getArgumentElement( i+1, vcol, myvals ); // And add this part of the product myvals.addValue( ostrn, matval*vecval ); // Now lets work out the derivatives if( doNotCalculateDerivatives() ) continue; - addDerivativeOnMatrixArgument( stored_arg[0], i, 0, index1, ind2, vecval, myvals ); addDerivativeOnVectorArgument( stored_arg[i+1], i, i+1, ind2, matval, myvals ); + addDerivativeOnMatrixArgument( stored_arg[0], i, 0, index1, ind2, vecval, myvals ); addDerivativeOnVectorArgument( stored_arg[i+1], i, i+1, vcol, matval, myvals ); } } else { - unsigned n=getNumberOfArguments()-1; + unsigned n=getNumberOfArguments()-1; double matval = 0; unsigned vcol = ind2; for(unsigned i=0; igetPositionInStream(); - double matval = getElementOfMatrixArgument( i, index1, ind2, myvals ), vecval=getArgumentElement( n, ind2, myvals ); + Value* myarg = getPntrToArgument(i); + if( !myarg->valueHasBeenSet() ) matval = myvals.get( myarg->getPositionInStream() ); + else { + matval = myarg->get( index1*myarg->getNumberOfColumns() + ind2, false ); + vcol = getPntrToArgument(i)->getRowIndex( index1, ind2 ); + } + double vecval=getArgumentElement( n, vcol, myvals ); // And add this part of the product myvals.addValue( ostrn, matval*vecval ); // Now lets work out the derivatives if( doNotCalculateDerivatives() ) continue; - addDerivativeOnMatrixArgument( stored_arg[i], i, i, index1, ind2, vecval, myvals ); addDerivativeOnVectorArgument( stored_arg[n], i, n, ind2, matval, myvals ); + addDerivativeOnMatrixArgument( stored_arg[i], i, i, index1, ind2, vecval, myvals ); addDerivativeOnVectorArgument( stored_arg[n], i, n, vcol, matval, myvals ); } } }