Skip to content

Commit

Permalink
Added faster version of matrix vector multiply that is used when not …
Browse files Browse the repository at this point in the history
…employing the chain
  • Loading branch information
Gareth Aneurin Tribello committed Jul 16, 2024
1 parent ead56cd commit cd6f553
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/core/ActionWithMatrix.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ class ActionWithMatrix : public ActionWithVector {
//// This does some setup before we run over the row of the matrix
virtual void setupForTask( const unsigned& task_index, std::vector<unsigned>& indices, MultiValue& myvals ) const = 0;
/// Run over one row of the matrix
void performTask( const unsigned& task_index, MultiValue& myvals ) const override ;
virtual void performTask( const unsigned& task_index, MultiValue& myvals ) const ;
/// Gather a row of the matrix
void gatherStoredValue( const unsigned& valindex, const unsigned& code, const MultiValue& myvals, const unsigned& bufstart, std::vector<double>& buffer ) const override;
/// Gather all the data from the threads
Expand Down
69 changes: 69 additions & 0 deletions src/matrixtools/MatrixTimesVector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class MatrixTimesVector : public ActionWithMatrix {
unsigned getNumberOfColumns() const override { plumed_error(); }
unsigned getNumberOfDerivatives();
void prepare() override ;
void performTask( const unsigned& task_index, MultiValue& myvals ) const override ;
bool isInSubChain( unsigned& nder ) override { nder = arg_deriv_starts[0]; return true; }
void setupForTask( const unsigned& task_index, std::vector<unsigned>& indices, MultiValue& myvals ) const ;
void performTask( const std::string& controller, const unsigned& index1, const unsigned& index2, MultiValue& myvals ) const override;
Expand Down Expand Up @@ -161,6 +162,74 @@ void MatrixTimesVector::prepare() {
std::vector<unsigned> shape(1); shape[0] = getPntrToArgument(0)->getShape()[0]; myval->setShape(shape);
}

void MatrixTimesVector::performTask( const unsigned& task_index, MultiValue& myvals ) const {
if( actionInChain() ) { ActionWithMatrix::performTask( task_index, myvals ); return; }

if( sumrows ) {
unsigned n=getNumberOfArguments()-1; Value* myvec = getPntrToArgument(n);

Check notice

Code scanning / CodeQL

Unused local variable Note

Variable myvec is not used.
for(unsigned i=0; i<n; ++i) {
Value* mymat = getPntrToArgument(i);
unsigned ncol = mymat->getNumberOfColumns();
unsigned nmat = mymat->getRowLength(task_index);
double val=0; for(unsigned j=0; j<nmat; ++j) val += mymat->get( task_index*ncol + j, false );
unsigned ostrn = getConstPntrToComponent(i)->getPositionInStream();
myvals.setValue( ostrn, val );

// And the derivatives
if( doNotCalculateDerivatives() ) continue;

unsigned dloc = arg_deriv_starts[i] + task_index*ncol;
for(unsigned j=0; j<nmat; ++j) {
myvals.addDerivative( ostrn, dloc + j, 1.0 ); myvals.updateIndex( ostrn, dloc + j );
}
}
} else if( getPntrToArgument(1)->getRank()==1 ) {
Value* mymat = getPntrToArgument(0);
unsigned ncol = mymat->getNumberOfColumns();
unsigned nmat = mymat->getRowLength(task_index);
unsigned dloc = arg_deriv_starts[0] + task_index*ncol;
for(unsigned i=0; i<getNumberOfArguments()-1; ++i) {
Value* myvec = getPntrToArgument(i+1);
double val=0; for(unsigned j=0; j<nmat; ++j) val += mymat->get( task_index*ncol + j, false )*myvec->get( mymat->getRowIndex( task_index, j ) );
unsigned ostrn = getConstPntrToComponent(i)->getPositionInStream();
myvals.setValue( ostrn, val );

// And the derivatives
if( doNotCalculateDerivatives() ) continue;

for(unsigned j=0; j<nmat; ++j) {
unsigned kind = mymat->getRowIndex( task_index, j );
double vecval = myvec->get( kind );
double matval = mymat->get( task_index*ncol + j, false );
myvals.addDerivative( ostrn, dloc + j, vecval ); myvals.updateIndex( ostrn, dloc + j );
myvals.addDerivative( ostrn, arg_deriv_starts[i+1] + kind, matval ); myvals.updateIndex( ostrn, arg_deriv_starts[i+1] + kind );
}
}
} else {
unsigned n=getNumberOfArguments()-1; Value* myvec = getPntrToArgument(n);
for(unsigned i=0; i<n; ++i) {
Value* mymat = getPntrToArgument(i);
unsigned ncol = mymat->getNumberOfColumns();
unsigned nmat = mymat->getRowLength(task_index);
double val=0; for(unsigned j=0; j<nmat; ++j) val += mymat->get( task_index*ncol + j, false )*myvec->get( mymat->getRowIndex( task_index, j ) );
unsigned ostrn = getConstPntrToComponent(i)->getPositionInStream();
myvals.setValue( ostrn, val );

// And the derivatives
if( doNotCalculateDerivatives() ) continue;

unsigned dloc = arg_deriv_starts[i] + task_index*ncol;
for(unsigned j=0; j<nmat; ++j) {
unsigned kind = mymat->getRowIndex( task_index, j );
double vecval = myvec->get( kind );
double matval = mymat->get( task_index*ncol + j, false );
myvals.addDerivative( ostrn, dloc + j, vecval ); myvals.updateIndex( ostrn, dloc + j );
myvals.addDerivative( ostrn, arg_deriv_starts[n] + kind, matval ); myvals.updateIndex( ostrn, arg_deriv_starts[n] + kind );
}
}
}
}

void MatrixTimesVector::setupForTask( const unsigned& task_index, std::vector<unsigned>& indices, MultiValue& myvals ) const {
unsigned start_n = getPntrToArgument(0)->getShape()[0], size_v = getPntrToArgument(0)->getRowLength(task_index);
if( indices.size()!=size_v+1 ) indices.resize( size_v + 1 );
Expand Down

1 comment on commit cd6f553

@PlumedBot
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Found broken examples in automatic/ANGLES.tmp
Found broken examples in automatic/ANN.tmp
Found broken examples in automatic/CAVITY.tmp
Found broken examples in automatic/CLASSICAL_MDS.tmp
Found broken examples in automatic/CLUSTER_DIAMETER.tmp
Found broken examples in automatic/CLUSTER_DISTRIBUTION.tmp
Found broken examples in automatic/CLUSTER_PROPERTIES.tmp
Found broken examples in automatic/CONSTANT.tmp
Found broken examples in automatic/CONTACT_MATRIX.tmp
Found broken examples in automatic/CONTACT_MATRIX_PROPER.tmp
Found broken examples in automatic/COORDINATIONNUMBER.tmp
Found broken examples in automatic/DFSCLUSTERING.tmp
Found broken examples in automatic/DISTANCE_FROM_CONTOUR.tmp
Found broken examples in automatic/EDS.tmp
Found broken examples in automatic/EMMI.tmp
Found broken examples in automatic/ENVIRONMENTSIMILARITY.tmp
Found broken examples in automatic/FIND_CONTOUR.tmp
Found broken examples in automatic/FIND_CONTOUR_SURFACE.tmp
Found broken examples in automatic/FIND_SPHERICAL_CONTOUR.tmp
Found broken examples in automatic/FOURIER_TRANSFORM.tmp
Found broken examples in automatic/FUNCPATHGENERAL.tmp
Found broken examples in automatic/FUNCPATHMSD.tmp
Found broken examples in automatic/FUNNEL.tmp
Found broken examples in automatic/FUNNEL_PS.tmp
Found broken examples in automatic/GHBFIX.tmp
Found broken examples in automatic/GPROPERTYMAP.tmp
Found broken examples in automatic/HBOND_MATRIX.tmp
Found broken examples in automatic/INCLUDE.tmp
Found broken examples in automatic/INCYLINDER.tmp
Found broken examples in automatic/INENVELOPE.tmp
Found broken examples in automatic/INTERPOLATE_GRID.tmp
Found broken examples in automatic/LOCAL_AVERAGE.tmp
Found broken examples in automatic/MAZE_OPTIMIZER_BIAS.tmp
Found broken examples in automatic/MAZE_RANDOM_ACCELERATION_MD.tmp
Found broken examples in automatic/MAZE_SIMULATED_ANNEALING.tmp
Found broken examples in automatic/MAZE_STEERED_MD.tmp
Found broken examples in automatic/METATENSOR.tmp
Found broken examples in automatic/MULTICOLVARDENS.tmp
Found broken examples in automatic/OUTPUT_CLUSTER.tmp
Found broken examples in automatic/PAMM.tmp
Found broken examples in automatic/PCA.tmp
Found broken examples in automatic/PCAVARS.tmp
Found broken examples in automatic/PIV.tmp
Found broken examples in automatic/PLUMED.tmp
Found broken examples in automatic/PYCVINTERFACE.tmp
Found broken examples in automatic/PYTHONFUNCTION.tmp
Found broken examples in automatic/Q3.tmp
Found broken examples in automatic/Q4.tmp
Found broken examples in automatic/Q6.tmp
Found broken examples in automatic/QUATERNION.tmp
Found broken examples in automatic/SIZESHAPE_POSITION_LINEAR_PROJ.tmp
Found broken examples in automatic/SIZESHAPE_POSITION_MAHA_DIST.tmp
Found broken examples in automatic/SPRINT.tmp
Found broken examples in automatic/TETRAHEDRALPORE.tmp
Found broken examples in automatic/TORSIONS.tmp
Found broken examples in automatic/WHAM_WEIGHTS.tmp
Found broken examples in AnalysisPP.md
Found broken examples in CollectiveVariablesPP.md
Found broken examples in MiscelaneousPP.md

Please sign in to comment.