From 639e810472d15943e53f5cabf8b75b9c1208b712 Mon Sep 17 00:00:00 2001 From: Gareth Aneurin Tribello Date: Thu, 4 Jul 2024 09:36:48 +0100 Subject: [PATCH] Fixed derivatives for matrix vector multiplication when matrix is sparse and stored This commit fixes a bug with the derivatives for matrix-vector multiplication that occurs when you store the derivatives and compute the forces in the back-propegation (apply) step instead of calculating the derivatives in the forward (calculate) loop by using the chain. The problems occur when the matrix that is being used is sparse. The feature that is fixed here is not currently used as in all the cases it could be used the derivatives are calculated during the calculate step. I think it is better to fix it here, however, as it may be used in the future. Notice that chnages were required to KDE because this does apply forces to a sparse matrix. In the old version of the code when applying forces in KDE the matrix was assumed to be full even if it was sparse. The changes to KDE are thus using the sparsity more effectively when doing the back propegation for the forces. --- src/core/ActionWithArguments.cpp | 4 ++-- src/core/ActionWithMatrix.h | 8 +++++++- src/core/Value.cpp | 11 +++++++++++ src/core/Value.h | 10 ++++++++++ src/gridtools/KDE.cpp | 19 ++++++++++--------- src/matrixtools/MatrixTimesVector.cpp | 4 ++-- 6 files changed, 42 insertions(+), 14 deletions(-) diff --git a/src/core/ActionWithArguments.cpp b/src/core/ActionWithArguments.cpp index 8e84498f9a..a1b9208e6a 100644 --- a/src/core/ActionWithArguments.cpp +++ b/src/core/ActionWithArguments.cpp @@ -299,8 +299,8 @@ void ActionWithArguments::addForcesOnArguments( const unsigned& argstart, const for(unsigned i=0; iignoreStoredValue(c) || arguments[i]->getRank()==0 || (arguments[i]->getRank()>0 && arguments[i]->hasDerivatives()) ) { - unsigned nvals = arguments[i]->getNumberOfValues(); - for(unsigned j=0; jaddForce( j, forces[ind] ); ind++; } + unsigned nvals = arguments[i]->getNumberOfStoredValues(); + for(unsigned j=0; jaddForce( j, forces[ind], false ); ind++; } } } } diff --git a/src/core/ActionWithMatrix.h b/src/core/ActionWithMatrix.h index cb1eeeccf7..1d5d7ecd53 100644 --- a/src/core/ActionWithMatrix.h +++ b/src/core/ActionWithMatrix.h @@ -132,7 +132,13 @@ inline void ActionWithMatrix::addDerivativeOnMatrixArgument( const bool& inchain, const unsigned& ival, const unsigned& jarg, const unsigned& irow, const unsigned& jcol, const double& der, MultiValue& myvals ) const { plumed_dbg_assert( jarggetRank()==2 && !getPntrToArgument(jarg)->hasDerivatives() ); unsigned ostrn = getConstPntrToComponent(ival)->getPositionInStream(), vstart=arg_deriv_starts[jarg]; - if( !inchain ) { + if( !inchain && getPntrToArgument(jarg)->getNumberOfColumns()getShape()[1] ) { + unsigned dloc = vstart + irow*getPntrToArgument(jarg)->getNumberOfColumns(); Value* myarg=getPntrToArgument(jarg); + for(unsigned i=0; igetRowLength(irow); ++i) { + if( myarg->getRowIndex(irow,i)==jcol ) { myvals.addDerivative( ostrn, dloc+i, der ); myvals.updateIndex( ostrn, dloc+i ); return; } + } + plumed_merror("could not find element of sparse matrix to add derivative to"); + } else if( !inchain ) { unsigned dloc = vstart + irow*getPntrToArgument(jarg)->getShape()[1] + jcol; myvals.addDerivative( ostrn, dloc, der ); myvals.updateIndex( ostrn, dloc ); } else { diff --git a/src/core/Value.cpp b/src/core/Value.cpp index a9b18a6d7f..6f4466bc3f 100644 --- a/src/core/Value.cpp +++ b/src/core/Value.cpp @@ -237,6 +237,17 @@ void Value::push_back( const double& v ) { } } +std::size_t Value::getIndexInStore( const std::size_t& ival ) const { + if( shape.size()==2 && ncols neighbors; gridobject.getNeighbors( args, nneigh, num_neigh, neighbors ); std::vector der( args.size() ), gpoint( args.size() ); + unsigned hforce_start = 0; for(unsigned j=0; jgetNumberOfStoredValues(); if( fabs(height)>epsilon ) { if( getName()=="KDE" ) { if( kerneltype.find("bin")!=std::string::npos ) { @@ -547,17 +548,17 @@ void KDE::gatherForcesOnStoredValue( const Value* myval, const unsigned& itask, for(unsigned i=0; igetForce( neighbors[i] ); - if( hasheight && getPntrToArgument(args.size())->getRank()==0 ) forces[ args.size()*numberOfKernels ] += val*fforce / height; - else if( hasheight ) forces[ args.size()*numberOfKernels + itask ] += val*fforce / height; - unsigned n=itask; for(unsigned j=0; jgetRank()==0 ) forces[ hforce_start ] += val*fforce / height; + else if( hasheight ) forces[ hforce_start + getPntrToArgument(args.size())->getIndexInStore(itask) ] += val*fforce / height; + unsigned n=0; for(unsigned j=0; jgetIndexInStore(itask)] += der[j]*fforce; n += getPntrToArgument(j)->getNumberOfStoredValues(); } } } else { for(unsigned i=0; igetForce( neighbors[i] ); - if( hasheight && getPntrToArgument(args.size())->getRank()==0 ) forces[ args.size()*numberOfKernels ] += val*fforce / height; - else if( hasheight ) forces[ args.size()*numberOfKernels + itask ] += val*fforce / height; - unsigned n=itask; for(unsigned j=0; jgetRank()==0 ) forces[ hforce_start ] += val*fforce / height; + else if( hasheight ) forces[ hforce_start + getPntrToArgument(args.size())->getIndexInStore(itask) ] += val*fforce / height; + unsigned n=0; for(unsigned j=0; jgetIndexInStore(itask)] += -der[j]*fforce; n += getPntrToArgument(j)->getNumberOfStoredValues(); } } } } else { @@ -565,9 +566,9 @@ void KDE::gatherForcesOnStoredValue( const Value* myval, const unsigned& itask, gridobject.getGridPointCoordinates( neighbors[i], gpoint ); double dot=0; for(unsigned j=0; jgetForce( neighbors[i] ); double newval = height*von_misses_norm*exp( von_misses_concentration*dot ); - if( hasheight && getPntrToArgument(args.size())->getRank()==0 ) forces[ args.size()*numberOfKernels ] += newval*fforce / height; - else if( hasheight ) forces[ args.size()*numberOfKernels + itask ] += newval*fforce / height; - unsigned n=itask; for(unsigned j=0; jgetRank()==0 ) forces[ hforce_start ] += newval*fforce / height; + else if( hasheight ) forces[ hforce_start + getPntrToArgument(args.size())->getIndexInStore(itask) ] += newval*fforce / height; + unsigned n=0; for(unsigned j=0; jgetIndexInStore(itask)] += von_misses_concentration*newval*gpoint[j]*fforce; n += getPntrToArgument(j)->getNumberOfStoredValues(); } } } } diff --git a/src/matrixtools/MatrixTimesVector.cpp b/src/matrixtools/MatrixTimesVector.cpp index df77f39876..f9ea71ca65 100644 --- a/src/matrixtools/MatrixTimesVector.cpp +++ b/src/matrixtools/MatrixTimesVector.cpp @@ -150,9 +150,9 @@ void MatrixTimesVector::prepare() { } void MatrixTimesVector::setupForTask( const unsigned& task_index, std::vector& indices, MultiValue& myvals ) const { - unsigned start_n = getPntrToArgument(0)->getShape()[0], size_v = getPntrToArgument(0)->getShape()[1]; + unsigned start_n = getPntrToArgument(0)->getShape()[0], size_v = getPntrToArgument(0)->getRowLength(task_index); if( indices.size()!=size_v+1 ) indices.resize( size_v + 1 ); - for(unsigned i=0; igetRowIndex( task_index, i ); myvals.setSplitIndex( size_v + 1 ); }