From 98647c6ff5a3fae2d138a0ff5767d61d76569b69 Mon Sep 17 00:00:00 2001 From: Gareth Aneurin Tribello <garethtribello@Gareths-MacBook-Pro.local> Date: Sat, 24 Aug 2024 12:42:57 +0100 Subject: [PATCH] Tidy up the way that errors with mask arguments are handled --- src/adjmat/TorsionsMatrix.cpp | 18 ++++++++---------- src/core/ActionWithArguments.cpp | 5 +++++ src/core/ActionWithVector.cpp | 16 ++++++++++++++++ src/core/ActionWithVector.h | 16 ++++++++++++++++ src/function/Custom.cpp | 19 +++++++++++++++++++ src/function/Custom.h | 1 + src/function/FunctionOfMatrix.h | 3 +++ src/function/FunctionOfVector.h | 2 ++ src/function/FunctionTemplateBase.h | 1 + src/function/Sum.h | 1 + src/matrixtools/MatrixTimesMatrix.cpp | 18 ++++++++---------- src/matrixtools/MatrixTimesVector.cpp | 9 +++++++-- 12 files changed, 87 insertions(+), 22 deletions(-) diff --git a/src/adjmat/TorsionsMatrix.cpp b/src/adjmat/TorsionsMatrix.cpp index 8ef66965ef..27c4d9677a 100644 --- a/src/adjmat/TorsionsMatrix.cpp +++ b/src/adjmat/TorsionsMatrix.cpp @@ -52,10 +52,9 @@ class TorsionsMatrix : public ActionWithMatrix { PLUMED_REGISTER_ACTION(TorsionsMatrix,"TORSIONS_MATRIX") void TorsionsMatrix::registerKeywords( Keywords& keys ) { - ActionWithMatrix::registerKeywords(keys); keys.use("ARG"); + ActionWithMatrix::registerKeywords(keys); keys.use("ARG"); keys.use("MASK"); keys.add("atoms","POSITIONS1","the positions to use for the molecules specified using the first argument"); keys.add("atoms","POSITIONS2","the positions to use for the molecules specified using the second argument"); - keys.add("optional","MASK","the label for a sparse matrix that should be used to determine which elements of the matrix should be computed"); keys.setValueDescription("the matrix of torsions between the two vectors of input directors"); } @@ -63,7 +62,8 @@ TorsionsMatrix::TorsionsMatrix(const ActionOptions&ao): Action(ao), ActionWithMatrix(ao) { - if( getNumberOfArguments()!=2 ) error("should be two arguments to this action, a matrix and a vector"); + unsigned nmask = 0; if( hasMask() ) nmask = 1; + if( getNumberOfArguments()-nmask!=2 ) error("should be two arguments to this action, a matrix and a vector"); if( getPntrToArgument(0)->getRank()!=2 || getPntrToArgument(0)->hasDerivatives() ) error("first argument to this action should be a matrix"); if( getPntrToArgument(1)->getRank()!=2 || getPntrToArgument(1)->hasDerivatives() ) error("second argument to this action should be a matrix"); if( getPntrToArgument(0)->getShape()[1]!=3 || getPntrToArgument(1)->getShape()[0]!=3 ) error("number of columns in first matrix and number of rows in second matrix should equal 3"); @@ -90,13 +90,11 @@ TorsionsMatrix::TorsionsMatrix(const ActionOptions&ao): stored_matrix1 = getPntrToArgument(0)->ignoreStoredValue( headstr ); stored_matrix2 = getPntrToArgument(1)->ignoreStoredValue( headstr ); - std::vector<Value*> mask; parseArgumentList("MASK",mask); - if( mask.size()==1 ) { - if( mask[0]->getRank()!=2 || getPntrToArgument(0)->hasDerivatives() ) error("argument passed to MASK keyword should be a matrix"); - if( mask[0]->getShape()[0]!=shape[0] || mask[0]->getShape()[1]!=shape[1] ) error("argument passed to MASK keyword has the wrong shape"); - log.printf(" only computing elements of matrix product that correspond to non-zero elements of matrix %s \n", mask[0]->getName().c_str() ); - std::vector<Value*> allargs( getArguments() ); allargs.push_back( mask[0] ); requestArguments( allargs ); - } else if( mask.size()!=0 ) error("MASK should only have one argument"); + if( hasMask() ) { + unsigned iarg = getNumberOfArguments()-1; + if( getPntrToArgument(iarg)->getRank()!=2 || getPntrToArgument(0)->hasDerivatives() ) error("argument passed to MASK keyword should be a matrix"); + if( getPntrToArgument(iarg)->getShape()[0]!=shape[0] || getPntrToArgument(iarg)->getShape()[1]!=shape[1] ) error("argument passed to MASK keyword has the wrong shape"); + } } unsigned TorsionsMatrix::getNumberOfDerivatives() { diff --git a/src/core/ActionWithArguments.cpp b/src/core/ActionWithArguments.cpp index a1b9208e6a..3f8ec1a322 100644 --- a/src/core/ActionWithArguments.cpp +++ b/src/core/ActionWithArguments.cpp @@ -189,6 +189,11 @@ void ActionWithArguments::interpretArgumentList(const std::vector<std::string>& } } } + if( readact->keywords.exists("MASKED_INPUT_ALLOWED") ) return; + for(unsigned i=0; i<arg.size(); ++i) { + ActionWithVector* av=dynamic_cast<ActionWithVector*>( arg[i]->getPntrToAction() ); + if( av && av->hasMask() ) readact->error("cannot use argument " + arg[i]->getName() + " in input as not all elements are computed"); + } } void ActionWithArguments::expandArgKeywordInPDB( const PDB& pdb ) { diff --git a/src/core/ActionWithVector.cpp b/src/core/ActionWithVector.cpp index 5cfe5e1d64..21c8f71d51 100644 --- a/src/core/ActionWithVector.cpp +++ b/src/core/ActionWithVector.cpp @@ -52,6 +52,7 @@ void ActionWithVector::registerKeywords( Keywords& keys ) { ActionWithValue::registerKeywords( keys ); keys.remove("NUMERICAL_DERIVATIVES"); ActionWithArguments::registerKeywords( keys ); keys.addFlag("SERIAL",false,"do the calculation in serial. Do not parallelize"); + keys.reserve("optional","MASK","the label for a sparse matrix that should be used to determine which elements of the matrix should be computed"); } ActionWithVector::ActionWithVector(const ActionOptions&ao): @@ -59,6 +60,7 @@ ActionWithVector::ActionWithVector(const ActionOptions&ao): ActionAtomistic(ao), ActionWithValue(ao), ActionWithArguments(ao), + hasmask(false), serial(false), forwardPass(false), action_to_do_before(NULL), @@ -68,7 +70,21 @@ ActionWithVector::ActionWithVector(const ActionOptions&ao): atomsWereRetrieved(false), done_in_chain(false) { + for(unsigned i=0; i<getNumberOfArguments(); ++i) { + ActionWithVector* av = dynamic_cast<ActionWithVector*>( getPntrToArgument(i)->getPntrToAction() ); + if( av && av->hasMask() ) hasmask=true; + } + if( keywords.exists("SERIAL") ) parseFlag("SERIAL",serial); + if( keywords.exists("MASK") ) { + std::vector<Value*> mask; parseArgumentList("MASK",mask); + if( mask.size()==1 ) { + if( getPntrToArgument(0)->hasDerivatives() ) error("input for mask should be vector or matrix"); + else if( mask[0]->getRank()==2 ) log.printf(" only computing elements of matrix that correspond to non-zero elements of matrix %s \n", mask[0]->getName().c_str() ); + else if( mask[0]->getRank()==1 ) log.printf(" only computing elements of vector that correspond to non-zero elements of vector %s \n", mask[0]->getName().c_str() ); + std::vector<Value*> allargs( getArguments() ); allargs.push_back( mask[0] ); requestArguments( allargs ); hasmask=true; + } else if( mask.size()!=0 ) error("MASK should only have one argument"); + } } ActionWithVector::~ActionWithVector() { diff --git a/src/core/ActionWithVector.h b/src/core/ActionWithVector.h index 6c4f0036de..627762ce5e 100644 --- a/src/core/ActionWithVector.h +++ b/src/core/ActionWithVector.h @@ -37,6 +37,8 @@ class ActionWithVector: { friend class Value; private: +/// Check if there is a mask value + bool hasmask; /// Is the calculation to be done in serial bool serial; /// Are we in the forward pass through the calculation @@ -102,6 +104,8 @@ class ActionWithVector: std::vector<unsigned> arg_deriv_starts; /// Assert if this action is part of a chain bool done_in_chain; +/// Turn off the flag that says this action has a masked input + void ignoreMaskArguments(); /// This updates whether or not we are using all the task reduction stuff void updateTaskListReductionStatus(); /// Run all calculations in serial @@ -124,6 +128,8 @@ class ActionWithVector: void unlockRequests() override; virtual void prepare() override; void retrieveAtoms( const bool& force=false ) override; +/// Check if a mask has been set + bool hasMask() const ; void calculateNumericalDerivatives(ActionWithValue* av) override; /// Turn off the calculation of the derivatives during the forward pass through a calculation bool doNotCalculateDerivatives() const override ; @@ -190,6 +196,16 @@ bool ActionWithVector::runInSerial() const { return serial; } +inline +bool ActionWithVector::hasMask() const { + return hasmask; +} + +inline +void ActionWithVector::ignoreMaskArguments() { + hasmask=false; +} + } #endif diff --git a/src/function/Custom.cpp b/src/function/Custom.cpp index 85e20a3579..b4f1262fc9 100644 --- a/src/function/Custom.cpp +++ b/src/function/Custom.cpp @@ -333,6 +333,25 @@ bool Custom::getDerivativeZeroIfValueIsZero() const { return check_multiplication_vars.size()>0; } +bool Custom::checkIfMaskAllowed( const std::vector<Value*>& args ) const { + bool nomask=true; + for(unsigned i=0; i<args.size(); ++i) { + bool found=false; + for(unsigned j=0; j<check_multiplication_vars.size(); ++j) { + if( i==check_multiplication_vars[j] ) { found=true; break; } + } + if( found ) continue; + ActionWithVector* av=dynamic_cast<ActionWithVector*>( args[i]->getPntrToAction() ); + if( av && av->hasMask() ) { + nomask=false; Value* maskarg = av->getPntrToArgument( av->getNumberOfArguments()-1 ); + for(unsigned j=0; j<check_multiplication_vars.size(); ++j) { + if( maskarg==args[check_multiplication_vars[j]] ) return true; + } + } + } + return nomask; +} + std::vector<Value*> Custom::getArgumentsToCheck( const std::vector<Value*>& args ) { std::vector<Value*> fargs( check_multiplication_vars.size() ); for(unsigned i=0; i<check_multiplication_vars.size(); ++i) fargs[i] = args[check_multiplication_vars[i]]; diff --git a/src/function/Custom.h b/src/function/Custom.h index cd9d752305..e4037ba915 100644 --- a/src/function/Custom.h +++ b/src/function/Custom.h @@ -39,6 +39,7 @@ class Custom : public FunctionTemplateBase { void registerKeywords( Keywords& keys ) override; std::string getGraphInfo( const std::string& lab ) const override; void read( ActionWithArguments* action ) override; + bool checkIfMaskAllowed( const std::vector<Value*>& args ) const override ; bool getDerivativeZeroIfValueIsZero() const override; std::vector<Value*> getArgumentsToCheck( const std::vector<Value*>& args ) override; void calc( const ActionWithArguments* action, const std::vector<double>& args, std::vector<double>& vals, Matrix<double>& derivatives ) const override; diff --git a/src/function/FunctionOfMatrix.h b/src/function/FunctionOfMatrix.h index dbd6a2fb29..77c92cf54c 100644 --- a/src/function/FunctionOfMatrix.h +++ b/src/function/FunctionOfMatrix.h @@ -77,6 +77,7 @@ void FunctionOfMatrix<T>::registerKeywords(Keywords& keys ) { keys.add("hidden","NO_ACTION_LOG","suppresses printing from action on the log"); keys.reserve("compulsory","PERIODIC","if the output of your function is periodic then you should specify the periodicity of the function. If the output is not periodic you must state this using PERIODIC=NO"); T tfunc; tfunc.registerKeywords( keys ); + if( keys.getDisplayName()=="CUSTOM" || keys.getDisplayName()=="MATHEVAL" || keys.getDisplayName()=="SUM" ) keys.add("hidden","MASKED_INPUT_ALLOWED","turns on that you are allowed to use masked inputs "); if( keys.getDisplayName()=="SUM" ) { keys.setValueDescription("the sum of all the elements in the input matrix"); } else if( keys.getDisplayName()=="HIGHEST" ) { @@ -154,6 +155,8 @@ FunctionOfMatrix<T>::FunctionOfMatrix(const ActionOptions&ao): if( argname=="NEIGHBORS" ) { foundneigh=true; break; } ActionWithVector* av=dynamic_cast<ActionWithVector*>( getPntrToArgument(i)->getPntrToAction() ); if( !av ) done_in_chain=false; + else if( av->hasMask() && !myfunc.checkIfMaskAllowed( getArguments() ) ) error("cannot use argument masks in input as not all elements are computed"); + if( getPntrToArgument(i)->getRank()==0 ) { function::FunctionOfVector<function::Sum>* as = dynamic_cast<function::FunctionOfVector<function::Sum>*>( getPntrToArgument(i)->getPntrToAction() ); if(as) done_in_chain=false; diff --git a/src/function/FunctionOfVector.h b/src/function/FunctionOfVector.h index 98780b8f3b..bf6ef0e00d 100644 --- a/src/function/FunctionOfVector.h +++ b/src/function/FunctionOfVector.h @@ -79,6 +79,7 @@ void FunctionOfVector<T>::registerKeywords(Keywords& keys ) { keys.reserve("compulsory","PERIODIC","if the output of your function is periodic then you should specify the periodicity of the function. If the output is not periodic you must state this using PERIODIC=NO"); keys.add("hidden","NO_ACTION_LOG","suppresses printing from action on the log"); T tfunc; tfunc.registerKeywords( keys ); + if( keys.getDisplayName()=="CUSTOM" || keys.getDisplayName()=="MATHEVAL" || keys.getDisplayName()=="SUM" ) keys.add("hidden","MASKED_INPUT_ALLOWED","turns on that you are allowed to use masked inputs "); if( keys.getDisplayName()=="SUM" ) { keys.setValueDescription("the sum of all the elements in the input vector"); } else if( keys.getDisplayName()=="MEAN" ) { @@ -161,6 +162,7 @@ FunctionOfVector<T>::FunctionOfVector(const ActionOptions&ao): } else { ActionWithVector* av=dynamic_cast<ActionWithVector*>( getPntrToArgument(i)->getPntrToAction() ); if( !av ) done_in_chain=false; + else if( av->hasMask() && !myfunc.checkIfMaskAllowed( getArguments() ) ) error("cannot use argument masks in input as not all elements are computed"); } } // Don't need to do the calculation in a chain if the input is constant diff --git a/src/function/FunctionTemplateBase.h b/src/function/FunctionTemplateBase.h index 27bd657feb..1cf46d0499 100644 --- a/src/function/FunctionTemplateBase.h +++ b/src/function/FunctionTemplateBase.h @@ -52,6 +52,7 @@ class FunctionTemplateBase { virtual void registerKeywords( Keywords& keys ) = 0; virtual void read( ActionWithArguments* action ) = 0; virtual bool doWithTasks() const { return true; } + virtual bool checkIfMaskAllowed( const std::vector<Value*>& args ) const { return false; } virtual std::vector<Value*> getArgumentsToCheck( const std::vector<Value*>& args ); bool allComponentsRequired( const std::vector<Value*>& args, const std::vector<ActionWithVector*>& actions ); virtual bool zeroRank() const { return false; } diff --git a/src/function/Sum.h b/src/function/Sum.h index 0a28d37956..8980899ff9 100644 --- a/src/function/Sum.h +++ b/src/function/Sum.h @@ -34,6 +34,7 @@ class Sum : public FunctionTemplateBase { void registerKeywords( Keywords& keys ) override; void read( ActionWithArguments* action ) override; bool zeroRank() const override; + bool checkIfMaskAllowed( const std::vector<Value*>& args ) const override { return true; } void setPrefactor( ActionWithArguments* action, const double pref ) override; void calc( const ActionWithArguments* action, const std::vector<double>& args, std::vector<double>& vals, Matrix<double>& derivatives ) const override; }; diff --git a/src/matrixtools/MatrixTimesMatrix.cpp b/src/matrixtools/MatrixTimesMatrix.cpp index ea9c2ee743..8bfcc43b93 100644 --- a/src/matrixtools/MatrixTimesMatrix.cpp +++ b/src/matrixtools/MatrixTimesMatrix.cpp @@ -64,8 +64,7 @@ PLUMED_REGISTER_ACTION(MatrixTimesMatrix,"MATRIX_PRODUCT") PLUMED_REGISTER_ACTION(MatrixTimesMatrix,"DISSIMILARITIES") void MatrixTimesMatrix::registerKeywords( Keywords& keys ) { - ActionWithMatrix::registerKeywords(keys); keys.use("ARG"); - keys.add("optional","MASK","the label for a sparse matrix that should be used to determine which elements of the matrix should be computed"); + ActionWithMatrix::registerKeywords(keys); keys.use("ARG"); keys.use("MASK"); keys.addFlag("SQUARED",false,"calculate the squares of the dissimilarities (this option cannot be used with MATRIX_PRODUCT)"); keys.setValueDescription("the product of the two input matrices"); } @@ -74,7 +73,8 @@ MatrixTimesMatrix::MatrixTimesMatrix(const ActionOptions&ao): Action(ao), ActionWithMatrix(ao) { - if( getNumberOfArguments()!=2 ) error("should be two arguments to this action, a matrix and a vector"); + unsigned nmask = 0; if( hasMask() ) nmask = 1; + if( getNumberOfArguments()-nmask!=2 ) error("should be two arguments to this action, a matrix and a vector"); if( getPntrToArgument(0)->getRank()!=2 || getPntrToArgument(0)->hasDerivatives() ) error("first argument to this action should be a matrix"); if( getPntrToArgument(1)->getRank()!=2 || getPntrToArgument(1)->hasDerivatives() ) error("second argument to this action should be a matrix"); if( getPntrToArgument(0)->getShape()[1]!=getPntrToArgument(1)->getShape()[0] ) error("number of columns in first matrix does not equal number of rows in second matrix"); @@ -88,13 +88,11 @@ MatrixTimesMatrix::MatrixTimesMatrix(const ActionOptions&ao): if( squared ) log.printf(" calculating the squares of the dissimilarities \n"); } else squared=true; - std::vector<Value*> mask; parseArgumentList("MASK",mask); - if( mask.size()==1 ) { - if( mask[0]->getRank()!=2 || getPntrToArgument(0)->hasDerivatives() ) error("argument passed to MASK keyword should be a matrix"); - if( mask[0]->getShape()[0]!=shape[0] || mask[0]->getShape()[1]!=shape[1] ) error("argument passed to MASK keyword has the wrong shape"); - log.printf(" only computing elements of matrix product that correspond to non-zero elements of matrix %s \n", mask[0]->getName().c_str() ); - std::vector<Value*> allargs( getArguments() ); allargs.push_back( mask[0] ); requestArguments( allargs ); - } else if( mask.size()!=0 ) error("MASK should only have one argument"); + if( hasMask() ) { + unsigned iarg = getNumberOfArguments()-1; + if( getPntrToArgument(iarg)->getRank()!=2 || getPntrToArgument(0)->hasDerivatives() ) error("argument passed to MASK keyword should be a matrix"); + if( getPntrToArgument(iarg)->getShape()[0]!=shape[0] || getPntrToArgument(iarg)->getShape()[1]!=shape[1] ) error("argument passed to MASK keyword has the wrong shape"); + } } unsigned MatrixTimesMatrix::getNumberOfDerivatives() { diff --git a/src/matrixtools/MatrixTimesVector.cpp b/src/matrixtools/MatrixTimesVector.cpp index a1eb361c93..866e0735a4 100644 --- a/src/matrixtools/MatrixTimesVector.cpp +++ b/src/matrixtools/MatrixTimesVector.cpp @@ -60,6 +60,7 @@ PLUMED_REGISTER_ACTION(MatrixTimesVector,"MATRIX_VECTOR_PRODUCT") void MatrixTimesVector::registerKeywords( Keywords& keys ) { ActionWithMatrix::registerKeywords(keys); keys.use("ARG"); keys.setValueDescription("the vector that is obtained by taking the product between the matrix and the vector that were input"); + keys.add("hidden","MASKED_INPUT_ALLOWED","turns on that you are allowed to use masked inputs "); ActionWithValue::useCustomisableComponents(keys); } @@ -85,12 +86,16 @@ MatrixTimesVector::MatrixTimesVector(const ActionOptions&ao): sumrows(false) { if( getNumberOfArguments()<2 ) error("Not enough arguments specified"); - unsigned nvectors=0, nmatrices=0; + unsigned nvectors=0, nmatrices=0; bool vectormask=false; for(unsigned i=0; i<getNumberOfArguments(); ++i) { if( getPntrToArgument(i)->hasDerivatives() ) error("arguments should be vectors or matrices"); - if( getPntrToArgument(i)->getRank()<=1 ) nvectors++; + if( getPntrToArgument(i)->getRank()<=1 ) { + nvectors++; ActionWithVector* av=dynamic_cast<ActionWithVector*>( getPntrToArgument(i)->getPntrToAction() ); + if( av && av->hasMask() ) vectormask=true; + } if( getPntrToArgument(i)->getRank()==2 ) nmatrices++; } + if( !vectormask ) ignoreMaskArguments(); std::vector<unsigned> shape(1); shape[0]=getPntrToArgument(0)->getShape()[0]; if( nvectors==1 ) {