Skip to content

Commit 767169b

Browse files
Gareth Aneurin TribelloGareth Aneurin Tribello
Gareth Aneurin Tribello
authored and
Gareth Aneurin Tribello
committed
Changes to speed up multiplication of sparse matrices by vectors
1 parent ce48b26 commit 767169b

File tree

2 files changed

+20
-14
lines changed

2 files changed

+20
-14
lines changed

src/core/ActionWithMatrix.h

+2-8
Original file line numberDiff line numberDiff line change
@@ -134,14 +134,8 @@ inline
134134
void ActionWithMatrix::addDerivativeOnMatrixArgument( const bool& inchain, const unsigned& ival, const unsigned& jarg, const unsigned& irow, const unsigned& jcol, const double& der, MultiValue& myvals ) const {
135135
plumed_dbg_assert( jarg<getNumberOfArguments() && getPntrToArgument(jarg)->getRank()==2 && !getPntrToArgument(jarg)->hasDerivatives() );
136136
unsigned ostrn = getConstPntrToComponent(ival)->getPositionInStream(), vstart=arg_deriv_starts[jarg];
137-
if( !inchain && getPntrToArgument(jarg)->getNumberOfColumns()<getPntrToArgument(jarg)->getShape()[1] ) {
138-
unsigned dloc = vstart + irow*getPntrToArgument(jarg)->getNumberOfColumns(); Value* myarg=getPntrToArgument(jarg);
139-
for(unsigned i=0; i<myarg->getRowLength(irow); ++i) {
140-
if( myarg->getRowIndex(irow,i)==jcol ) { myvals.addDerivative( ostrn, dloc+i, der ); myvals.updateIndex( ostrn, dloc+i ); return; }
141-
}
142-
plumed_merror("could not find element of sparse matrix to add derivative to");
143-
} else if( !inchain ) {
144-
unsigned dloc = vstart + irow*getPntrToArgument(jarg)->getShape()[1] + jcol;
137+
if( !inchain ) {
138+
unsigned dloc = vstart + irow*getPntrToArgument(jarg)->getNumberOfColumns() + jcol;
145139
myvals.addDerivative( ostrn, dloc, der ); myvals.updateIndex( ostrn, dloc );
146140
} else {
147141
unsigned istrn = getPntrToArgument(jarg)->getPositionInStream();

src/matrixtools/MatrixTimesVector.cpp

+18-6
Original file line numberDiff line numberDiff line change
@@ -152,32 +152,44 @@ void MatrixTimesVector::prepare() {
152152
void MatrixTimesVector::setupForTask( const unsigned& task_index, std::vector<unsigned>& indices, MultiValue& myvals ) const {
153153
unsigned start_n = getPntrToArgument(0)->getShape()[0], size_v = getPntrToArgument(0)->getRowLength(task_index);
154154
if( indices.size()!=size_v+1 ) indices.resize( size_v + 1 );
155-
for(unsigned i=0; i<size_v; ++i) indices[i+1] = start_n + getPntrToArgument(0)->getRowIndex( task_index, i );
155+
for(unsigned i=0; i<size_v; ++i) indices[i+1] = start_n + i;
156156
myvals.setSplitIndex( size_v + 1 );
157157
}
158158

159159
void MatrixTimesVector::performTask( const std::string& controller, const unsigned& index1, const unsigned& index2, MultiValue& myvals ) const {
160160
unsigned ind2 = index2; if( index2>=getPntrToArgument(0)->getShape()[0] ) ind2 = index2 - getPntrToArgument(0)->getShape()[0];
161161
if( getPntrToArgument(1)->getRank()==1 ) {
162+
double matval = 0; Value* myarg = getPntrToArgument(0); unsigned vcol = ind2;
163+
if( !myarg->valueHasBeenSet() ) matval = myvals.get( myarg->getPositionInStream() );
164+
else {
165+
matval = myarg->get( index1*myarg->getNumberOfColumns() + ind2, false );
166+
vcol = getPntrToArgument(0)->getRowIndex( index1, ind2 );
167+
}
162168
for(unsigned i=0; i<getNumberOfArguments()-1; ++i) {
163169
unsigned ostrn = getConstPntrToComponent(i)->getPositionInStream();
164-
double matval = getElementOfMatrixArgument( 0, index1, ind2, myvals ), vecval=getArgumentElement( i+1, ind2, myvals );
170+
double vecval=getArgumentElement( i+1, vcol, myvals );
165171
// And add this part of the product
166172
myvals.addValue( ostrn, matval*vecval );
167173
// Now lets work out the derivatives
168174
if( doNotCalculateDerivatives() ) continue;
169-
addDerivativeOnMatrixArgument( stored_arg[0], i, 0, index1, ind2, vecval, myvals ); addDerivativeOnVectorArgument( stored_arg[i+1], i, i+1, ind2, matval, myvals );
175+
addDerivativeOnMatrixArgument( stored_arg[0], i, 0, index1, ind2, vecval, myvals ); addDerivativeOnVectorArgument( stored_arg[i+1], i, i+1, vcol, matval, myvals );
170176
}
171177
} else {
172-
unsigned n=getNumberOfArguments()-1;
178+
unsigned n=getNumberOfArguments()-1; double matval = 0; unsigned vcol = ind2;
173179
for(unsigned i=0; i<getNumberOfArguments()-1; ++i) {
174180
unsigned ostrn = getConstPntrToComponent(i)->getPositionInStream();
175-
double matval = getElementOfMatrixArgument( i, index1, ind2, myvals ), vecval=getArgumentElement( n, ind2, myvals );
181+
Value* myarg = getPntrToArgument(i);
182+
if( !myarg->valueHasBeenSet() ) matval = myvals.get( myarg->getPositionInStream() );
183+
else {
184+
matval = myarg->get( index1*myarg->getNumberOfColumns() + ind2, false );
185+
vcol = getPntrToArgument(i)->getRowIndex( index1, ind2 );
186+
}
187+
double vecval=getArgumentElement( n, vcol, myvals );
176188
// And add this part of the product
177189
myvals.addValue( ostrn, matval*vecval );
178190
// Now lets work out the derivatives
179191
if( doNotCalculateDerivatives() ) continue;
180-
addDerivativeOnMatrixArgument( stored_arg[i], i, i, index1, ind2, vecval, myvals ); addDerivativeOnVectorArgument( stored_arg[n], i, n, ind2, matval, myvals );
192+
addDerivativeOnMatrixArgument( stored_arg[i], i, i, index1, ind2, vecval, myvals ); addDerivativeOnVectorArgument( stored_arg[n], i, n, vcol, matval, myvals );
181193
}
182194
}
183195
}

0 commit comments

Comments
 (0)