Skip to content

Commit

Permalink
Fixed derivatives for matrix vector multiplication when matrix is spa…
Browse files Browse the repository at this point in the history
…rse and stored

This commit fixes a bug with the derivatives for matrix-vector multiplication that occurs when you store the derivatives
and compute the forces in the back-propegation (apply) step instead of calculating the derivatives in the forward (calculate) loop by using
the chain.  The problems occur when the matrix that is being used is sparse.

The feature that is fixed here is not currently used as in all the cases it could be used the derivatives are calculated during the calculate step.
I think it is better to fix it here, however, as it may be used in the future.

Notice that chnages were required to KDE because this does apply forces to a sparse matrix.  In the old version of the code when applying forces in KDE
the matrix was assumed to be full even if it was sparse.  The changes to KDE are thus using the sparsity more effectively when doing the back propegation
for the forces.
  • Loading branch information
Gareth Aneurin Tribello committed Jul 4, 2024
1 parent 120447e commit 639e810
Show file tree
Hide file tree
Showing 6 changed files with 42 additions and 14 deletions.
4 changes: 2 additions & 2 deletions src/core/ActionWithArguments.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -299,8 +299,8 @@ void ActionWithArguments::addForcesOnArguments( const unsigned& argstart, const
for(unsigned i=0; i<arguments.size(); ++i) {
if( i==0 && getName().find("EVALUATE_FUNCTION_FROM_GRID")!=std::string::npos ) continue ;
if( !arguments[i]->ignoreStoredValue(c) || arguments[i]->getRank()==0 || (arguments[i]->getRank()>0 && arguments[i]->hasDerivatives()) ) {
unsigned nvals = arguments[i]->getNumberOfValues();
for(unsigned j=0; j<nvals; ++j) { arguments[i]->addForce( j, forces[ind] ); ind++; }
unsigned nvals = arguments[i]->getNumberOfStoredValues();
for(unsigned j=0; j<nvals; ++j) { arguments[i]->addForce( j, forces[ind], false ); ind++; }
}
}
}
Expand Down
8 changes: 7 additions & 1 deletion src/core/ActionWithMatrix.h
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,13 @@ inline
void ActionWithMatrix::addDerivativeOnMatrixArgument( const bool& inchain, const unsigned& ival, const unsigned& jarg, const unsigned& irow, const unsigned& jcol, const double& der, MultiValue& myvals ) const {
plumed_dbg_assert( jarg<getNumberOfArguments() && getPntrToArgument(jarg)->getRank()==2 && !getPntrToArgument(jarg)->hasDerivatives() );
unsigned ostrn = getConstPntrToComponent(ival)->getPositionInStream(), vstart=arg_deriv_starts[jarg];
if( !inchain ) {
if( !inchain && getPntrToArgument(jarg)->getNumberOfColumns()<getPntrToArgument(jarg)->getShape()[1] ) {
unsigned dloc = vstart + irow*getPntrToArgument(jarg)->getNumberOfColumns(); Value* myarg=getPntrToArgument(jarg);
for(unsigned i=0; i<myarg->getRowLength(irow); ++i) {
if( myarg->getRowIndex(irow,i)==jcol ) { myvals.addDerivative( ostrn, dloc+i, der ); myvals.updateIndex( ostrn, dloc+i ); return; }
}
plumed_merror("could not find element of sparse matrix to add derivative to");
} else if( !inchain ) {
unsigned dloc = vstart + irow*getPntrToArgument(jarg)->getShape()[1] + jcol;
myvals.addDerivative( ostrn, dloc, der ); myvals.updateIndex( ostrn, dloc );
} else {
Expand Down
11 changes: 11 additions & 0 deletions src/core/Value.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,17 @@ void Value::push_back( const double& v ) {
}
}

std::size_t Value::getIndexInStore( const std::size_t& ival ) const {
if( shape.size()==2 && ncols<shape[1] ) {
unsigned irow = std::floor( ival / shape[1] ), jcol = ival%shape[1];
for(unsigned i=0; i<getRowLength(irow); ++i) {
if( getRowIndex(irow,i)==jcol ) return irow*ncols+i;
}
plumed_merror("cannot get store index");
}
return ival;
}

double Value::get(const std::size_t& ival, const bool trueind) const {
if( hasDeriv ) return data[ival*(1+ngrid_der)];
#ifdef DNDEBUG
Expand Down
10 changes: 10 additions & 0 deletions src/core/Value.h
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,8 @@ class Value {
void add(double);
/// Add something to the ith element of the data array
void add(const std::size_t& n, const double& v );
/// Get the location of this element of in the store
std::size_t getIndexInStore( const std::size_t& ival ) const ;
/// Get the value of the function
double get( const std::size_t& ival=0, const bool trueind=true ) const;
/// Find out if the value has been set
Expand Down Expand Up @@ -190,6 +192,8 @@ class Value {
void setSymmetric( const bool& sym );
/// Get the total number of scalars that are stored here
unsigned getNumberOfValues() const ;
/// Get the number of values that are actually stored here once sparse matrices are taken into account
unsigned getNumberOfStoredValues() const ;
/// Get the number of threads to use when assigning this value
unsigned getGoodNumThreads( const unsigned& j, const unsigned& k ) const ;
/// These are used for passing around the data in this value when we are doing replica exchange
Expand Down Expand Up @@ -401,6 +405,12 @@ unsigned Value::getNumberOfValues() const {
return size;
}

inline
unsigned Value::getNumberOfStoredValues() const {
if( getRank()==2 && !hasDeriv ) return shape[0]*ncols;
return getNumberOfValues();
}

inline
bool Value::isConstant() const {
return valtype==constant;
Expand Down
19 changes: 10 additions & 9 deletions src/gridtools/KDE.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -540,34 +540,35 @@ void KDE::gatherForcesOnStoredValue( const Value* myval, const unsigned& itask,
unsigned num_neigh; std::vector<unsigned> neighbors;
gridobject.getNeighbors( args, nneigh, num_neigh, neighbors );
std::vector<double> der( args.size() ), gpoint( args.size() );
unsigned hforce_start = 0; for(unsigned j=0; j<der.size(); ++j) hforce_start += getPntrToArgument(j)->getNumberOfStoredValues();
if( fabs(height)>epsilon ) {
if( getName()=="KDE" ) {
if( kerneltype.find("bin")!=std::string::npos ) {
std::vector<HistogramBead> bead( args.size() ); setupHistogramBeads( bead );
for(unsigned i=0; i<num_neigh; ++i) {
gridobject.getGridPointCoordinates( neighbors[i], gpoint );
double val = evaluateBeadValue( bead, gpoint, args, height, der ); double fforce = getConstPntrToComponent(0)->getForce( neighbors[i] );
if( hasheight && getPntrToArgument(args.size())->getRank()==0 ) forces[ args.size()*numberOfKernels ] += val*fforce / height;
else if( hasheight ) forces[ args.size()*numberOfKernels + itask ] += val*fforce / height;
unsigned n=itask; for(unsigned j=0; j<der.size(); ++j) { forces[n] += der[j]*fforce; n += numberOfKernels; }
if( hasheight && getPntrToArgument(args.size())->getRank()==0 ) forces[ hforce_start ] += val*fforce / height;
else if( hasheight ) forces[ hforce_start + getPntrToArgument(args.size())->getIndexInStore(itask) ] += val*fforce / height;
unsigned n=0; for(unsigned j=0; j<der.size(); ++j) { forces[n + getPntrToArgument(j)->getIndexInStore(itask)] += der[j]*fforce; n += getPntrToArgument(j)->getNumberOfStoredValues(); }
}
} else {
for(unsigned i=0; i<num_neigh; ++i) {
gridobject.getGridPointCoordinates( neighbors[i], gpoint );
double val = evaluateKernel( gpoint, args, height, der ), fforce = getConstPntrToComponent(0)->getForce( neighbors[i] );
if( hasheight && getPntrToArgument(args.size())->getRank()==0 ) forces[ args.size()*numberOfKernels ] += val*fforce / height;
else if( hasheight ) forces[ args.size()*numberOfKernels + itask ] += val*fforce / height;
unsigned n=itask; for(unsigned j=0; j<der.size(); ++j) { forces[n] += -der[j]*fforce; n += numberOfKernels; }
if( hasheight && getPntrToArgument(args.size())->getRank()==0 ) forces[ hforce_start ] += val*fforce / height;
else if( hasheight ) forces[ hforce_start + getPntrToArgument(args.size())->getIndexInStore(itask) ] += val*fforce / height;
unsigned n=0; for(unsigned j=0; j<der.size(); ++j) { forces[n + getPntrToArgument(j)->getIndexInStore(itask)] += -der[j]*fforce; n += getPntrToArgument(j)->getNumberOfStoredValues(); }
}
}
} else {
for(unsigned i=0; i<num_neigh; ++i) {
gridobject.getGridPointCoordinates( neighbors[i], gpoint );
double dot=0; for(unsigned j=0; j<gpoint.size(); ++j) dot += args[j]*gpoint[j];
double fforce = myval->getForce( neighbors[i] ); double newval = height*von_misses_norm*exp( von_misses_concentration*dot );
if( hasheight && getPntrToArgument(args.size())->getRank()==0 ) forces[ args.size()*numberOfKernels ] += newval*fforce / height;
else if( hasheight ) forces[ args.size()*numberOfKernels + itask ] += newval*fforce / height;
unsigned n=itask; for(unsigned j=0; j<gpoint.size(); ++j) { forces[n] += von_misses_concentration*newval*gpoint[j]*fforce; n += numberOfKernels; }
if( hasheight && getPntrToArgument(args.size())->getRank()==0 ) forces[ hforce_start ] += newval*fforce / height;
else if( hasheight ) forces[ hforce_start + getPntrToArgument(args.size())->getIndexInStore(itask) ] += newval*fforce / height;
unsigned n=0; for(unsigned j=0; j<gpoint.size(); ++j) { forces[n + getPntrToArgument(j)->getIndexInStore(itask)] += von_misses_concentration*newval*gpoint[j]*fforce; n += getPntrToArgument(j)->getNumberOfStoredValues(); }
}
}
}
Expand Down
4 changes: 2 additions & 2 deletions src/matrixtools/MatrixTimesVector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -150,9 +150,9 @@ void MatrixTimesVector::prepare() {
}

void MatrixTimesVector::setupForTask( const unsigned& task_index, std::vector<unsigned>& indices, MultiValue& myvals ) const {
unsigned start_n = getPntrToArgument(0)->getShape()[0], size_v = getPntrToArgument(0)->getShape()[1];
unsigned start_n = getPntrToArgument(0)->getShape()[0], size_v = getPntrToArgument(0)->getRowLength(task_index);
if( indices.size()!=size_v+1 ) indices.resize( size_v + 1 );
for(unsigned i=0; i<size_v; ++i) indices[i+1] = start_n + i;
for(unsigned i=0; i<size_v; ++i) indices[i+1] = start_n + getPntrToArgument(0)->getRowIndex( task_index, i );
myvals.setSplitIndex( size_v + 1 );
}

Expand Down

0 comments on commit 639e810

Please sign in to comment.