Skip to content

Commit

Permalink
Changes to speed up multiplication of sparse matrices by vectors
Browse files Browse the repository at this point in the history
  • Loading branch information
Gareth Aneurin Tribello authored and Gareth Aneurin Tribello committed Jul 15, 2024
1 parent ce48b26 commit 767169b
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 14 deletions.
10 changes: 2 additions & 8 deletions src/core/ActionWithMatrix.h
Original file line number Diff line number Diff line change
Expand Up @@ -134,14 +134,8 @@ inline
void ActionWithMatrix::addDerivativeOnMatrixArgument( const bool& inchain, const unsigned& ival, const unsigned& jarg, const unsigned& irow, const unsigned& jcol, const double& der, MultiValue& myvals ) const {
plumed_dbg_assert( jarg<getNumberOfArguments() && getPntrToArgument(jarg)->getRank()==2 && !getPntrToArgument(jarg)->hasDerivatives() );
unsigned ostrn = getConstPntrToComponent(ival)->getPositionInStream(), vstart=arg_deriv_starts[jarg];
if( !inchain && getPntrToArgument(jarg)->getNumberOfColumns()<getPntrToArgument(jarg)->getShape()[1] ) {
unsigned dloc = vstart + irow*getPntrToArgument(jarg)->getNumberOfColumns(); Value* myarg=getPntrToArgument(jarg);
for(unsigned i=0; i<myarg->getRowLength(irow); ++i) {
if( myarg->getRowIndex(irow,i)==jcol ) { myvals.addDerivative( ostrn, dloc+i, der ); myvals.updateIndex( ostrn, dloc+i ); return; }
}
plumed_merror("could not find element of sparse matrix to add derivative to");
} else if( !inchain ) {
unsigned dloc = vstart + irow*getPntrToArgument(jarg)->getShape()[1] + jcol;
if( !inchain ) {
unsigned dloc = vstart + irow*getPntrToArgument(jarg)->getNumberOfColumns() + jcol;
myvals.addDerivative( ostrn, dloc, der ); myvals.updateIndex( ostrn, dloc );
} else {
unsigned istrn = getPntrToArgument(jarg)->getPositionInStream();
Expand Down
24 changes: 18 additions & 6 deletions src/matrixtools/MatrixTimesVector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,32 +152,44 @@ void MatrixTimesVector::prepare() {
void MatrixTimesVector::setupForTask( const unsigned& task_index, std::vector<unsigned>& indices, MultiValue& myvals ) const {
unsigned start_n = getPntrToArgument(0)->getShape()[0], size_v = getPntrToArgument(0)->getRowLength(task_index);
if( indices.size()!=size_v+1 ) indices.resize( size_v + 1 );
for(unsigned i=0; i<size_v; ++i) indices[i+1] = start_n + getPntrToArgument(0)->getRowIndex( task_index, i );
for(unsigned i=0; i<size_v; ++i) indices[i+1] = start_n + i;
myvals.setSplitIndex( size_v + 1 );
}

void MatrixTimesVector::performTask( const std::string& controller, const unsigned& index1, const unsigned& index2, MultiValue& myvals ) const {
unsigned ind2 = index2; if( index2>=getPntrToArgument(0)->getShape()[0] ) ind2 = index2 - getPntrToArgument(0)->getShape()[0];
if( getPntrToArgument(1)->getRank()==1 ) {
double matval = 0; Value* myarg = getPntrToArgument(0); unsigned vcol = ind2;
if( !myarg->valueHasBeenSet() ) matval = myvals.get( myarg->getPositionInStream() );
else {
matval = myarg->get( index1*myarg->getNumberOfColumns() + ind2, false );
vcol = getPntrToArgument(0)->getRowIndex( index1, ind2 );
}
for(unsigned i=0; i<getNumberOfArguments()-1; ++i) {
unsigned ostrn = getConstPntrToComponent(i)->getPositionInStream();
double matval = getElementOfMatrixArgument( 0, index1, ind2, myvals ), vecval=getArgumentElement( i+1, ind2, myvals );
double vecval=getArgumentElement( i+1, vcol, myvals );
// And add this part of the product
myvals.addValue( ostrn, matval*vecval );
// Now lets work out the derivatives
if( doNotCalculateDerivatives() ) continue;
addDerivativeOnMatrixArgument( stored_arg[0], i, 0, index1, ind2, vecval, myvals ); addDerivativeOnVectorArgument( stored_arg[i+1], i, i+1, ind2, matval, myvals );
addDerivativeOnMatrixArgument( stored_arg[0], i, 0, index1, ind2, vecval, myvals ); addDerivativeOnVectorArgument( stored_arg[i+1], i, i+1, vcol, matval, myvals );
}
} else {
unsigned n=getNumberOfArguments()-1;
unsigned n=getNumberOfArguments()-1; double matval = 0; unsigned vcol = ind2;
for(unsigned i=0; i<getNumberOfArguments()-1; ++i) {
unsigned ostrn = getConstPntrToComponent(i)->getPositionInStream();
double matval = getElementOfMatrixArgument( i, index1, ind2, myvals ), vecval=getArgumentElement( n, ind2, myvals );
Value* myarg = getPntrToArgument(i);
if( !myarg->valueHasBeenSet() ) matval = myvals.get( myarg->getPositionInStream() );
else {
matval = myarg->get( index1*myarg->getNumberOfColumns() + ind2, false );
vcol = getPntrToArgument(i)->getRowIndex( index1, ind2 );
}
double vecval=getArgumentElement( n, vcol, myvals );
// And add this part of the product
myvals.addValue( ostrn, matval*vecval );
// Now lets work out the derivatives
if( doNotCalculateDerivatives() ) continue;
addDerivativeOnMatrixArgument( stored_arg[i], i, i, index1, ind2, vecval, myvals ); addDerivativeOnVectorArgument( stored_arg[n], i, n, ind2, matval, myvals );
addDerivativeOnMatrixArgument( stored_arg[i], i, i, index1, ind2, vecval, myvals ); addDerivativeOnVectorArgument( stored_arg[n], i, n, vcol, matval, myvals );
}
}
}
Expand Down

1 comment on commit 767169b

@PlumedBot
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Found broken examples in automatic/ANGLES.tmp
Found broken examples in automatic/ANN.tmp
Found broken examples in automatic/CAVITY.tmp
Found broken examples in automatic/CLASSICAL_MDS.tmp
Found broken examples in automatic/CLUSTER_DIAMETER.tmp
Found broken examples in automatic/CLUSTER_DISTRIBUTION.tmp
Found broken examples in automatic/CLUSTER_PROPERTIES.tmp
Found broken examples in automatic/CONSTANT.tmp
Found broken examples in automatic/CONTACT_MATRIX.tmp
Found broken examples in automatic/CONTACT_MATRIX_PROPER.tmp
Found broken examples in automatic/COORDINATIONNUMBER.tmp
Found broken examples in automatic/DFSCLUSTERING.tmp
Found broken examples in automatic/DISTANCE_FROM_CONTOUR.tmp
Found broken examples in automatic/EDS.tmp
Found broken examples in automatic/EMMI.tmp
Found broken examples in automatic/ENVIRONMENTSIMILARITY.tmp
Found broken examples in automatic/FIND_CONTOUR.tmp
Found broken examples in automatic/FIND_CONTOUR_SURFACE.tmp
Found broken examples in automatic/FIND_SPHERICAL_CONTOUR.tmp
Found broken examples in automatic/FOURIER_TRANSFORM.tmp
Found broken examples in automatic/FUNCPATHGENERAL.tmp
Found broken examples in automatic/FUNCPATHMSD.tmp
Found broken examples in automatic/FUNNEL.tmp
Found broken examples in automatic/FUNNEL_PS.tmp
Found broken examples in automatic/GHBFIX.tmp
Found broken examples in automatic/GPROPERTYMAP.tmp
Found broken examples in automatic/HBOND_MATRIX.tmp
Found broken examples in automatic/INCLUDE.tmp
Found broken examples in automatic/INCYLINDER.tmp
Found broken examples in automatic/INENVELOPE.tmp
Found broken examples in automatic/INTERPOLATE_GRID.tmp
Found broken examples in automatic/LOCAL_AVERAGE.tmp
Found broken examples in automatic/MAZE_OPTIMIZER_BIAS.tmp
Found broken examples in automatic/MAZE_RANDOM_ACCELERATION_MD.tmp
Found broken examples in automatic/MAZE_SIMULATED_ANNEALING.tmp
Found broken examples in automatic/MAZE_STEERED_MD.tmp
Found broken examples in automatic/METATENSOR.tmp
Found broken examples in automatic/MULTICOLVARDENS.tmp
Found broken examples in automatic/OUTPUT_CLUSTER.tmp
Found broken examples in automatic/PAMM.tmp
Found broken examples in automatic/PCA.tmp
Found broken examples in automatic/PCAVARS.tmp
Found broken examples in automatic/PIV.tmp
Found broken examples in automatic/PLUMED.tmp
Found broken examples in automatic/PYCVINTERFACE.tmp
Found broken examples in automatic/PYTHONFUNCTION.tmp
Found broken examples in automatic/Q3.tmp
Found broken examples in automatic/Q4.tmp
Found broken examples in automatic/Q6.tmp
Found broken examples in automatic/QUATERNION.tmp
Found broken examples in automatic/SIZESHAPE_POSITION_LINEAR_PROJ.tmp
Found broken examples in automatic/SIZESHAPE_POSITION_MAHA_DIST.tmp
Found broken examples in automatic/SPRINT.tmp
Found broken examples in automatic/TETRAHEDRALPORE.tmp
Found broken examples in automatic/TORSIONS.tmp
Found broken examples in automatic/WHAM_WEIGHTS.tmp
Found broken examples in AnalysisPP.md
Found broken examples in CollectiveVariablesPP.md
Found broken examples in MiscelaneousPP.md

Please sign in to comment.