Skip to content

Commit

Permalink
Tidy up the way that errors with mask arguments are handled
Browse files Browse the repository at this point in the history
  • Loading branch information
Gareth Aneurin Tribello authored and Gareth Aneurin Tribello committed Aug 24, 2024
1 parent 12a0ecf commit 98647c6
Show file tree
Hide file tree
Showing 12 changed files with 87 additions and 22 deletions.
18 changes: 8 additions & 10 deletions src/adjmat/TorsionsMatrix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,18 +52,18 @@ class TorsionsMatrix : public ActionWithMatrix {
PLUMED_REGISTER_ACTION(TorsionsMatrix,"TORSIONS_MATRIX")

void TorsionsMatrix::registerKeywords( Keywords& keys ) {
ActionWithMatrix::registerKeywords(keys); keys.use("ARG");
ActionWithMatrix::registerKeywords(keys); keys.use("ARG"); keys.use("MASK");
keys.add("atoms","POSITIONS1","the positions to use for the molecules specified using the first argument");
keys.add("atoms","POSITIONS2","the positions to use for the molecules specified using the second argument");
keys.add("optional","MASK","the label for a sparse matrix that should be used to determine which elements of the matrix should be computed");
keys.setValueDescription("the matrix of torsions between the two vectors of input directors");
}

TorsionsMatrix::TorsionsMatrix(const ActionOptions&ao):
Action(ao),
ActionWithMatrix(ao)
{
if( getNumberOfArguments()!=2 ) error("should be two arguments to this action, a matrix and a vector");
unsigned nmask = 0; if( hasMask() ) nmask = 1;
if( getNumberOfArguments()-nmask!=2 ) error("should be two arguments to this action, a matrix and a vector");
if( getPntrToArgument(0)->getRank()!=2 || getPntrToArgument(0)->hasDerivatives() ) error("first argument to this action should be a matrix");
if( getPntrToArgument(1)->getRank()!=2 || getPntrToArgument(1)->hasDerivatives() ) error("second argument to this action should be a matrix");
if( getPntrToArgument(0)->getShape()[1]!=3 || getPntrToArgument(1)->getShape()[0]!=3 ) error("number of columns in first matrix and number of rows in second matrix should equal 3");
Expand All @@ -90,13 +90,11 @@ TorsionsMatrix::TorsionsMatrix(const ActionOptions&ao):
stored_matrix1 = getPntrToArgument(0)->ignoreStoredValue( headstr );
stored_matrix2 = getPntrToArgument(1)->ignoreStoredValue( headstr );

std::vector<Value*> mask; parseArgumentList("MASK",mask);
if( mask.size()==1 ) {
if( mask[0]->getRank()!=2 || getPntrToArgument(0)->hasDerivatives() ) error("argument passed to MASK keyword should be a matrix");
if( mask[0]->getShape()[0]!=shape[0] || mask[0]->getShape()[1]!=shape[1] ) error("argument passed to MASK keyword has the wrong shape");
log.printf(" only computing elements of matrix product that correspond to non-zero elements of matrix %s \n", mask[0]->getName().c_str() );
std::vector<Value*> allargs( getArguments() ); allargs.push_back( mask[0] ); requestArguments( allargs );
} else if( mask.size()!=0 ) error("MASK should only have one argument");
if( hasMask() ) {
unsigned iarg = getNumberOfArguments()-1;
if( getPntrToArgument(iarg)->getRank()!=2 || getPntrToArgument(0)->hasDerivatives() ) error("argument passed to MASK keyword should be a matrix");
if( getPntrToArgument(iarg)->getShape()[0]!=shape[0] || getPntrToArgument(iarg)->getShape()[1]!=shape[1] ) error("argument passed to MASK keyword has the wrong shape");
}
}

unsigned TorsionsMatrix::getNumberOfDerivatives() {
Expand Down
5 changes: 5 additions & 0 deletions src/core/ActionWithArguments.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,11 @@ void ActionWithArguments::interpretArgumentList(const std::vector<std::string>&
}
}
}
if( readact->keywords.exists("MASKED_INPUT_ALLOWED") ) return;
for(unsigned i=0; i<arg.size(); ++i) {
ActionWithVector* av=dynamic_cast<ActionWithVector*>( arg[i]->getPntrToAction() );
if( av && av->hasMask() ) readact->error("cannot use argument " + arg[i]->getName() + " in input as not all elements are computed");
}
}

void ActionWithArguments::expandArgKeywordInPDB( const PDB& pdb ) {
Expand Down
16 changes: 16 additions & 0 deletions src/core/ActionWithVector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,15 @@ void ActionWithVector::registerKeywords( Keywords& keys ) {
ActionWithValue::registerKeywords( keys ); keys.remove("NUMERICAL_DERIVATIVES");
ActionWithArguments::registerKeywords( keys );
keys.addFlag("SERIAL",false,"do the calculation in serial. Do not parallelize");
keys.reserve("optional","MASK","the label for a sparse matrix that should be used to determine which elements of the matrix should be computed");
}

ActionWithVector::ActionWithVector(const ActionOptions&ao):
Action(ao),
ActionAtomistic(ao),
ActionWithValue(ao),
ActionWithArguments(ao),
hasmask(false),
serial(false),
forwardPass(false),
action_to_do_before(NULL),
Expand All @@ -68,7 +70,21 @@ ActionWithVector::ActionWithVector(const ActionOptions&ao):
atomsWereRetrieved(false),
done_in_chain(false)
{
for(unsigned i=0; i<getNumberOfArguments(); ++i) {
ActionWithVector* av = dynamic_cast<ActionWithVector*>( getPntrToArgument(i)->getPntrToAction() );
if( av && av->hasMask() ) hasmask=true;
}

if( keywords.exists("SERIAL") ) parseFlag("SERIAL",serial);
if( keywords.exists("MASK") ) {
std::vector<Value*> mask; parseArgumentList("MASK",mask);
if( mask.size()==1 ) {
if( getPntrToArgument(0)->hasDerivatives() ) error("input for mask should be vector or matrix");
else if( mask[0]->getRank()==2 ) log.printf(" only computing elements of matrix that correspond to non-zero elements of matrix %s \n", mask[0]->getName().c_str() );
else if( mask[0]->getRank()==1 ) log.printf(" only computing elements of vector that correspond to non-zero elements of vector %s \n", mask[0]->getName().c_str() );
std::vector<Value*> allargs( getArguments() ); allargs.push_back( mask[0] ); requestArguments( allargs ); hasmask=true;
} else if( mask.size()!=0 ) error("MASK should only have one argument");
}
}

ActionWithVector::~ActionWithVector() {
Expand Down
16 changes: 16 additions & 0 deletions src/core/ActionWithVector.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ class ActionWithVector:
{
friend class Value;
private:
/// Check if there is a mask value
bool hasmask;
/// Is the calculation to be done in serial
bool serial;
/// Are we in the forward pass through the calculation
Expand Down Expand Up @@ -102,6 +104,8 @@ class ActionWithVector:
std::vector<unsigned> arg_deriv_starts;
/// Assert if this action is part of a chain
bool done_in_chain;
/// Turn off the flag that says this action has a masked input
void ignoreMaskArguments();
/// This updates whether or not we are using all the task reduction stuff
void updateTaskListReductionStatus();
/// Run all calculations in serial
Expand All @@ -124,6 +128,8 @@ class ActionWithVector:
void unlockRequests() override;
virtual void prepare() override;
void retrieveAtoms( const bool& force=false ) override;
/// Check if a mask has been set
bool hasMask() const ;
void calculateNumericalDerivatives(ActionWithValue* av) override;
/// Turn off the calculation of the derivatives during the forward pass through a calculation
bool doNotCalculateDerivatives() const override ;
Expand Down Expand Up @@ -190,6 +196,16 @@ bool ActionWithVector::runInSerial() const {
return serial;
}

inline
bool ActionWithVector::hasMask() const {
return hasmask;
}

inline
void ActionWithVector::ignoreMaskArguments() {
hasmask=false;
}

}

#endif
19 changes: 19 additions & 0 deletions src/function/Custom.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,25 @@ bool Custom::getDerivativeZeroIfValueIsZero() const {
return check_multiplication_vars.size()>0;
}

bool Custom::checkIfMaskAllowed( const std::vector<Value*>& args ) const {
bool nomask=true;
for(unsigned i=0; i<args.size(); ++i) {
bool found=false;
for(unsigned j=0; j<check_multiplication_vars.size(); ++j) {
if( i==check_multiplication_vars[j] ) { found=true; break; }
}
if( found ) continue;
ActionWithVector* av=dynamic_cast<ActionWithVector*>( args[i]->getPntrToAction() );
if( av && av->hasMask() ) {
nomask=false; Value* maskarg = av->getPntrToArgument( av->getNumberOfArguments()-1 );
for(unsigned j=0; j<check_multiplication_vars.size(); ++j) {
if( maskarg==args[check_multiplication_vars[j]] ) return true;
}
}
}
return nomask;
}

std::vector<Value*> Custom::getArgumentsToCheck( const std::vector<Value*>& args ) {
std::vector<Value*> fargs( check_multiplication_vars.size() );
for(unsigned i=0; i<check_multiplication_vars.size(); ++i) fargs[i] = args[check_multiplication_vars[i]];
Expand Down
1 change: 1 addition & 0 deletions src/function/Custom.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ class Custom : public FunctionTemplateBase {
void registerKeywords( Keywords& keys ) override;
std::string getGraphInfo( const std::string& lab ) const override;
void read( ActionWithArguments* action ) override;
bool checkIfMaskAllowed( const std::vector<Value*>& args ) const override ;
bool getDerivativeZeroIfValueIsZero() const override;
std::vector<Value*> getArgumentsToCheck( const std::vector<Value*>& args ) override;
void calc( const ActionWithArguments* action, const std::vector<double>& args, std::vector<double>& vals, Matrix<double>& derivatives ) const override;
Expand Down
3 changes: 3 additions & 0 deletions src/function/FunctionOfMatrix.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ void FunctionOfMatrix<T>::registerKeywords(Keywords& keys ) {
keys.add("hidden","NO_ACTION_LOG","suppresses printing from action on the log");
keys.reserve("compulsory","PERIODIC","if the output of your function is periodic then you should specify the periodicity of the function. If the output is not periodic you must state this using PERIODIC=NO");
T tfunc; tfunc.registerKeywords( keys );
if( keys.getDisplayName()=="CUSTOM" || keys.getDisplayName()=="MATHEVAL" || keys.getDisplayName()=="SUM" ) keys.add("hidden","MASKED_INPUT_ALLOWED","turns on that you are allowed to use masked inputs ");
if( keys.getDisplayName()=="SUM" ) {
keys.setValueDescription("the sum of all the elements in the input matrix");
} else if( keys.getDisplayName()=="HIGHEST" ) {
Expand Down Expand Up @@ -154,6 +155,8 @@ FunctionOfMatrix<T>::FunctionOfMatrix(const ActionOptions&ao):
if( argname=="NEIGHBORS" ) { foundneigh=true; break; }
ActionWithVector* av=dynamic_cast<ActionWithVector*>( getPntrToArgument(i)->getPntrToAction() );
if( !av ) done_in_chain=false;
else if( av->hasMask() && !myfunc.checkIfMaskAllowed( getArguments() ) ) error("cannot use argument masks in input as not all elements are computed");

if( getPntrToArgument(i)->getRank()==0 ) {
function::FunctionOfVector<function::Sum>* as = dynamic_cast<function::FunctionOfVector<function::Sum>*>( getPntrToArgument(i)->getPntrToAction() );
if(as) done_in_chain=false;
Expand Down
2 changes: 2 additions & 0 deletions src/function/FunctionOfVector.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ void FunctionOfVector<T>::registerKeywords(Keywords& keys ) {
keys.reserve("compulsory","PERIODIC","if the output of your function is periodic then you should specify the periodicity of the function. If the output is not periodic you must state this using PERIODIC=NO");
keys.add("hidden","NO_ACTION_LOG","suppresses printing from action on the log");
T tfunc; tfunc.registerKeywords( keys );
if( keys.getDisplayName()=="CUSTOM" || keys.getDisplayName()=="MATHEVAL" || keys.getDisplayName()=="SUM" ) keys.add("hidden","MASKED_INPUT_ALLOWED","turns on that you are allowed to use masked inputs ");
if( keys.getDisplayName()=="SUM" ) {
keys.setValueDescription("the sum of all the elements in the input vector");
} else if( keys.getDisplayName()=="MEAN" ) {
Expand Down Expand Up @@ -161,6 +162,7 @@ FunctionOfVector<T>::FunctionOfVector(const ActionOptions&ao):
} else {
ActionWithVector* av=dynamic_cast<ActionWithVector*>( getPntrToArgument(i)->getPntrToAction() );
if( !av ) done_in_chain=false;
else if( av->hasMask() && !myfunc.checkIfMaskAllowed( getArguments() ) ) error("cannot use argument masks in input as not all elements are computed");
}
}
// Don't need to do the calculation in a chain if the input is constant
Expand Down
1 change: 1 addition & 0 deletions src/function/FunctionTemplateBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ class FunctionTemplateBase {
virtual void registerKeywords( Keywords& keys ) = 0;
virtual void read( ActionWithArguments* action ) = 0;
virtual bool doWithTasks() const { return true; }
virtual bool checkIfMaskAllowed( const std::vector<Value*>& args ) const { return false; }
virtual std::vector<Value*> getArgumentsToCheck( const std::vector<Value*>& args );
bool allComponentsRequired( const std::vector<Value*>& args, const std::vector<ActionWithVector*>& actions );
virtual bool zeroRank() const { return false; }
Expand Down
1 change: 1 addition & 0 deletions src/function/Sum.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ class Sum : public FunctionTemplateBase {
void registerKeywords( Keywords& keys ) override;
void read( ActionWithArguments* action ) override;
bool zeroRank() const override;
bool checkIfMaskAllowed( const std::vector<Value*>& args ) const override { return true; }
void setPrefactor( ActionWithArguments* action, const double pref ) override;
void calc( const ActionWithArguments* action, const std::vector<double>& args, std::vector<double>& vals, Matrix<double>& derivatives ) const override;
};
Expand Down
18 changes: 8 additions & 10 deletions src/matrixtools/MatrixTimesMatrix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,7 @@ PLUMED_REGISTER_ACTION(MatrixTimesMatrix,"MATRIX_PRODUCT")
PLUMED_REGISTER_ACTION(MatrixTimesMatrix,"DISSIMILARITIES")

void MatrixTimesMatrix::registerKeywords( Keywords& keys ) {
ActionWithMatrix::registerKeywords(keys); keys.use("ARG");
keys.add("optional","MASK","the label for a sparse matrix that should be used to determine which elements of the matrix should be computed");
ActionWithMatrix::registerKeywords(keys); keys.use("ARG"); keys.use("MASK");
keys.addFlag("SQUARED",false,"calculate the squares of the dissimilarities (this option cannot be used with MATRIX_PRODUCT)");
keys.setValueDescription("the product of the two input matrices");
}
Expand All @@ -74,7 +73,8 @@ MatrixTimesMatrix::MatrixTimesMatrix(const ActionOptions&ao):
Action(ao),
ActionWithMatrix(ao)
{
if( getNumberOfArguments()!=2 ) error("should be two arguments to this action, a matrix and a vector");
unsigned nmask = 0; if( hasMask() ) nmask = 1;
if( getNumberOfArguments()-nmask!=2 ) error("should be two arguments to this action, a matrix and a vector");
if( getPntrToArgument(0)->getRank()!=2 || getPntrToArgument(0)->hasDerivatives() ) error("first argument to this action should be a matrix");
if( getPntrToArgument(1)->getRank()!=2 || getPntrToArgument(1)->hasDerivatives() ) error("second argument to this action should be a matrix");
if( getPntrToArgument(0)->getShape()[1]!=getPntrToArgument(1)->getShape()[0] ) error("number of columns in first matrix does not equal number of rows in second matrix");
Expand All @@ -88,13 +88,11 @@ MatrixTimesMatrix::MatrixTimesMatrix(const ActionOptions&ao):
if( squared ) log.printf(" calculating the squares of the dissimilarities \n");
} else squared=true;

std::vector<Value*> mask; parseArgumentList("MASK",mask);
if( mask.size()==1 ) {
if( mask[0]->getRank()!=2 || getPntrToArgument(0)->hasDerivatives() ) error("argument passed to MASK keyword should be a matrix");
if( mask[0]->getShape()[0]!=shape[0] || mask[0]->getShape()[1]!=shape[1] ) error("argument passed to MASK keyword has the wrong shape");
log.printf(" only computing elements of matrix product that correspond to non-zero elements of matrix %s \n", mask[0]->getName().c_str() );
std::vector<Value*> allargs( getArguments() ); allargs.push_back( mask[0] ); requestArguments( allargs );
} else if( mask.size()!=0 ) error("MASK should only have one argument");
if( hasMask() ) {
unsigned iarg = getNumberOfArguments()-1;
if( getPntrToArgument(iarg)->getRank()!=2 || getPntrToArgument(0)->hasDerivatives() ) error("argument passed to MASK keyword should be a matrix");
if( getPntrToArgument(iarg)->getShape()[0]!=shape[0] || getPntrToArgument(iarg)->getShape()[1]!=shape[1] ) error("argument passed to MASK keyword has the wrong shape");
}
}

unsigned MatrixTimesMatrix::getNumberOfDerivatives() {
Expand Down
9 changes: 7 additions & 2 deletions src/matrixtools/MatrixTimesVector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ PLUMED_REGISTER_ACTION(MatrixTimesVector,"MATRIX_VECTOR_PRODUCT")
void MatrixTimesVector::registerKeywords( Keywords& keys ) {
ActionWithMatrix::registerKeywords(keys); keys.use("ARG");
keys.setValueDescription("the vector that is obtained by taking the product between the matrix and the vector that were input");
keys.add("hidden","MASKED_INPUT_ALLOWED","turns on that you are allowed to use masked inputs ");
ActionWithValue::useCustomisableComponents(keys);
}

Expand All @@ -85,12 +86,16 @@ MatrixTimesVector::MatrixTimesVector(const ActionOptions&ao):
sumrows(false)
{
if( getNumberOfArguments()<2 ) error("Not enough arguments specified");
unsigned nvectors=0, nmatrices=0;
unsigned nvectors=0, nmatrices=0; bool vectormask=false;
for(unsigned i=0; i<getNumberOfArguments(); ++i) {
if( getPntrToArgument(i)->hasDerivatives() ) error("arguments should be vectors or matrices");
if( getPntrToArgument(i)->getRank()<=1 ) nvectors++;
if( getPntrToArgument(i)->getRank()<=1 ) {
nvectors++; ActionWithVector* av=dynamic_cast<ActionWithVector*>( getPntrToArgument(i)->getPntrToAction() );
if( av && av->hasMask() ) vectormask=true;
}
if( getPntrToArgument(i)->getRank()==2 ) nmatrices++;
}
if( !vectormask ) ignoreMaskArguments();

std::vector<unsigned> shape(1); shape[0]=getPntrToArgument(0)->getShape()[0];
if( nvectors==1 ) {
Expand Down

1 comment on commit 98647c6

@PlumedBot
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Found broken examples in automatic/ANGLES.tmp
Found broken examples in automatic/ANN.tmp
Found broken examples in automatic/CAVITY.tmp
Found broken examples in automatic/CLASSICAL_MDS.tmp
Found broken examples in automatic/CLUSTER_DIAMETER.tmp
Found broken examples in automatic/CLUSTER_DISTRIBUTION.tmp
Found broken examples in automatic/CLUSTER_PROPERTIES.tmp
Found broken examples in automatic/CONSTANT.tmp
Found broken examples in automatic/CONTACT_MATRIX.tmp
Found broken examples in automatic/CONTACT_MATRIX_PROPER.tmp
Found broken examples in automatic/COORDINATIONNUMBER.tmp
Found broken examples in automatic/DFSCLUSTERING.tmp
Found broken examples in automatic/DISTANCE_FROM_CONTOUR.tmp
Found broken examples in automatic/EDS.tmp
Found broken examples in automatic/EMMI.tmp
Found broken examples in automatic/ENVIRONMENTSIMILARITY.tmp
Found broken examples in automatic/FIND_CONTOUR.tmp
Found broken examples in automatic/FIND_CONTOUR_SURFACE.tmp
Found broken examples in automatic/FIND_SPHERICAL_CONTOUR.tmp
Found broken examples in automatic/FOURIER_TRANSFORM.tmp
Found broken examples in automatic/FUNCPATHGENERAL.tmp
Found broken examples in automatic/FUNCPATHMSD.tmp
Found broken examples in automatic/FUNNEL.tmp
Found broken examples in automatic/FUNNEL_PS.tmp
Found broken examples in automatic/GHBFIX.tmp
Found broken examples in automatic/GPROPERTYMAP.tmp
Found broken examples in automatic/HBOND_MATRIX.tmp
Found broken examples in automatic/INCLUDE.tmp
Found broken examples in automatic/INCYLINDER.tmp
Found broken examples in automatic/INENVELOPE.tmp
Found broken examples in automatic/INTERPOLATE_GRID.tmp
Found broken examples in automatic/LOCAL_AVERAGE.tmp
Found broken examples in automatic/MAZE_OPTIMIZER_BIAS.tmp
Found broken examples in automatic/MAZE_RANDOM_ACCELERATION_MD.tmp
Found broken examples in automatic/MAZE_SIMULATED_ANNEALING.tmp
Found broken examples in automatic/MAZE_STEERED_MD.tmp
Found broken examples in automatic/METATENSOR.tmp
Found broken examples in automatic/MULTICOLVARDENS.tmp
Found broken examples in automatic/OUTPUT_CLUSTER.tmp
Found broken examples in automatic/PAMM.tmp
Found broken examples in automatic/PCA.tmp
Found broken examples in automatic/PCAVARS.tmp
Found broken examples in automatic/PIV.tmp
Found broken examples in automatic/PLUMED.tmp
Found broken examples in automatic/PYCVINTERFACE.tmp
Found broken examples in automatic/PYTHONFUNCTION.tmp
Found broken examples in automatic/Q3.tmp
Found broken examples in automatic/Q4.tmp
Found broken examples in automatic/Q6.tmp
Found broken examples in automatic/QUATERNION.tmp
Found broken examples in automatic/SIZESHAPE_POSITION_LINEAR_PROJ.tmp
Found broken examples in automatic/SIZESHAPE_POSITION_MAHA_DIST.tmp
Found broken examples in automatic/SPRINT.tmp
Found broken examples in automatic/TETRAHEDRALPORE.tmp
Found broken examples in automatic/TORSIONS.tmp
Found broken examples in automatic/WHAM_WEIGHTS.tmp
Found broken examples in AnalysisPP.md
Found broken examples in CollectiveVariablesPP.md
Found broken examples in MiscelaneousPP.md

Please sign in to comment.