From 98647c6ff5a3fae2d138a0ff5767d61d76569b69 Mon Sep 17 00:00:00 2001
From: Gareth Aneurin Tribello <garethtribello@Gareths-MacBook-Pro.local>
Date: Sat, 24 Aug 2024 12:42:57 +0100
Subject: [PATCH] Tidy up the way that errors with mask arguments are handled

---
 src/adjmat/TorsionsMatrix.cpp         | 18 ++++++++----------
 src/core/ActionWithArguments.cpp      |  5 +++++
 src/core/ActionWithVector.cpp         | 16 ++++++++++++++++
 src/core/ActionWithVector.h           | 16 ++++++++++++++++
 src/function/Custom.cpp               | 19 +++++++++++++++++++
 src/function/Custom.h                 |  1 +
 src/function/FunctionOfMatrix.h       |  3 +++
 src/function/FunctionOfVector.h       |  2 ++
 src/function/FunctionTemplateBase.h   |  1 +
 src/function/Sum.h                    |  1 +
 src/matrixtools/MatrixTimesMatrix.cpp | 18 ++++++++----------
 src/matrixtools/MatrixTimesVector.cpp |  9 +++++++--
 12 files changed, 87 insertions(+), 22 deletions(-)

diff --git a/src/adjmat/TorsionsMatrix.cpp b/src/adjmat/TorsionsMatrix.cpp
index 8ef66965ef..27c4d9677a 100644
--- a/src/adjmat/TorsionsMatrix.cpp
+++ b/src/adjmat/TorsionsMatrix.cpp
@@ -52,10 +52,9 @@ class TorsionsMatrix : public ActionWithMatrix {
 PLUMED_REGISTER_ACTION(TorsionsMatrix,"TORSIONS_MATRIX")
 
 void TorsionsMatrix::registerKeywords( Keywords& keys ) {
-  ActionWithMatrix::registerKeywords(keys); keys.use("ARG");
+  ActionWithMatrix::registerKeywords(keys); keys.use("ARG"); keys.use("MASK");
   keys.add("atoms","POSITIONS1","the positions to use for the molecules specified using the first argument");
   keys.add("atoms","POSITIONS2","the positions to use for the molecules specified using the second argument");
-  keys.add("optional","MASK","the label for a sparse matrix that should be used to determine which elements of the matrix should be computed");
   keys.setValueDescription("the matrix of torsions between the two vectors of input directors");
 }
 
@@ -63,7 +62,8 @@ TorsionsMatrix::TorsionsMatrix(const ActionOptions&ao):
   Action(ao),
   ActionWithMatrix(ao)
 {
-  if( getNumberOfArguments()!=2 ) error("should be two arguments to this action, a matrix and a vector");
+  unsigned nmask = 0; if( hasMask() ) nmask = 1;
+  if( getNumberOfArguments()-nmask!=2 ) error("should be two arguments to this action, a matrix and a vector");
   if( getPntrToArgument(0)->getRank()!=2 || getPntrToArgument(0)->hasDerivatives() ) error("first argument to this action should be a matrix");
   if( getPntrToArgument(1)->getRank()!=2 || getPntrToArgument(1)->hasDerivatives() ) error("second argument to this action should be a matrix");
   if( getPntrToArgument(0)->getShape()[1]!=3 || getPntrToArgument(1)->getShape()[0]!=3 ) error("number of columns in first matrix and number of rows in second matrix should equal 3");
@@ -90,13 +90,11 @@ TorsionsMatrix::TorsionsMatrix(const ActionOptions&ao):
   stored_matrix1 = getPntrToArgument(0)->ignoreStoredValue( headstr );
   stored_matrix2 = getPntrToArgument(1)->ignoreStoredValue( headstr );
 
-  std::vector<Value*> mask; parseArgumentList("MASK",mask);
-  if( mask.size()==1 ) {
-    if( mask[0]->getRank()!=2 || getPntrToArgument(0)->hasDerivatives() ) error("argument passed to MASK keyword should be a matrix");
-    if( mask[0]->getShape()[0]!=shape[0] || mask[0]->getShape()[1]!=shape[1] ) error("argument passed to MASK keyword has the wrong shape");
-    log.printf("  only computing elements of matrix product that correspond to non-zero elements of matrix %s \n", mask[0]->getName().c_str() );
-    std::vector<Value*> allargs( getArguments() ); allargs.push_back( mask[0] ); requestArguments( allargs );
-  } else if( mask.size()!=0 ) error("MASK should only have one argument");
+  if( hasMask() ) {
+    unsigned iarg = getNumberOfArguments()-1;
+    if( getPntrToArgument(iarg)->getRank()!=2 || getPntrToArgument(0)->hasDerivatives() ) error("argument passed to MASK keyword should be a matrix");
+    if( getPntrToArgument(iarg)->getShape()[0]!=shape[0] || getPntrToArgument(iarg)->getShape()[1]!=shape[1] ) error("argument passed to MASK keyword has the wrong shape");
+  }
 }
 
 unsigned TorsionsMatrix::getNumberOfDerivatives() {
diff --git a/src/core/ActionWithArguments.cpp b/src/core/ActionWithArguments.cpp
index a1b9208e6a..3f8ec1a322 100644
--- a/src/core/ActionWithArguments.cpp
+++ b/src/core/ActionWithArguments.cpp
@@ -189,6 +189,11 @@ void ActionWithArguments::interpretArgumentList(const std::vector<std::string>&
       }
     }
   }
+  if( readact->keywords.exists("MASKED_INPUT_ALLOWED") ) return;
+  for(unsigned i=0; i<arg.size(); ++i) {
+    ActionWithVector* av=dynamic_cast<ActionWithVector*>( arg[i]->getPntrToAction() );
+    if( av && av->hasMask() ) readact->error("cannot use argument " + arg[i]->getName() + " in input as not all elements are computed");
+  }
 }
 
 void ActionWithArguments::expandArgKeywordInPDB( const PDB& pdb ) {
diff --git a/src/core/ActionWithVector.cpp b/src/core/ActionWithVector.cpp
index 5cfe5e1d64..21c8f71d51 100644
--- a/src/core/ActionWithVector.cpp
+++ b/src/core/ActionWithVector.cpp
@@ -52,6 +52,7 @@ void ActionWithVector::registerKeywords( Keywords& keys ) {
   ActionWithValue::registerKeywords( keys ); keys.remove("NUMERICAL_DERIVATIVES");
   ActionWithArguments::registerKeywords( keys );
   keys.addFlag("SERIAL",false,"do the calculation in serial.  Do not parallelize");
+  keys.reserve("optional","MASK","the label for a sparse matrix that should be used to determine which elements of the matrix should be computed");
 }
 
 ActionWithVector::ActionWithVector(const ActionOptions&ao):
@@ -59,6 +60,7 @@ ActionWithVector::ActionWithVector(const ActionOptions&ao):
   ActionAtomistic(ao),
   ActionWithValue(ao),
   ActionWithArguments(ao),
+  hasmask(false),
   serial(false),
   forwardPass(false),
   action_to_do_before(NULL),
@@ -68,7 +70,21 @@ ActionWithVector::ActionWithVector(const ActionOptions&ao):
   atomsWereRetrieved(false),
   done_in_chain(false)
 {
+  for(unsigned i=0; i<getNumberOfArguments(); ++i) {
+    ActionWithVector* av = dynamic_cast<ActionWithVector*>( getPntrToArgument(i)->getPntrToAction() );
+    if( av && av->hasMask() ) hasmask=true;
+  }
+
   if( keywords.exists("SERIAL") ) parseFlag("SERIAL",serial);
+  if( keywords.exists("MASK") ) {
+    std::vector<Value*> mask; parseArgumentList("MASK",mask);
+    if( mask.size()==1 ) {
+      if( getPntrToArgument(0)->hasDerivatives() ) error("input for mask should be vector or matrix");
+      else if( mask[0]->getRank()==2 ) log.printf("  only computing elements of matrix that correspond to non-zero elements of matrix %s \n", mask[0]->getName().c_str() );
+      else if( mask[0]->getRank()==1 ) log.printf("  only computing elements of vector that correspond to non-zero elements of vector %s \n", mask[0]->getName().c_str() );
+      std::vector<Value*> allargs( getArguments() ); allargs.push_back( mask[0] ); requestArguments( allargs ); hasmask=true;
+    } else if( mask.size()!=0 ) error("MASK should only have one argument");
+  }
 }
 
 ActionWithVector::~ActionWithVector() {
diff --git a/src/core/ActionWithVector.h b/src/core/ActionWithVector.h
index 6c4f0036de..627762ce5e 100644
--- a/src/core/ActionWithVector.h
+++ b/src/core/ActionWithVector.h
@@ -37,6 +37,8 @@ class ActionWithVector:
 {
   friend class Value;
 private:
+/// Check if there is a mask value
+  bool hasmask;
 /// Is the calculation to be done in serial
   bool serial;
 /// Are we in the forward pass through the calculation
@@ -102,6 +104,8 @@ class ActionWithVector:
   std::vector<unsigned> arg_deriv_starts;
 /// Assert if this action is part of a chain
   bool done_in_chain;
+/// Turn off the flag that says this action has a masked input
+  void ignoreMaskArguments();
 /// This updates whether or not we are using all the task reduction stuff
   void updateTaskListReductionStatus();
 /// Run all calculations in serial
@@ -124,6 +128,8 @@ class ActionWithVector:
   void unlockRequests() override;
   virtual void prepare() override;
   void retrieveAtoms( const bool& force=false ) override;
+/// Check if a mask has been set
+  bool hasMask() const ;
   void calculateNumericalDerivatives(ActionWithValue* av) override;
 /// Turn off the calculation of the derivatives during the forward pass through a calculation
   bool doNotCalculateDerivatives() const override ;
@@ -190,6 +196,16 @@ bool ActionWithVector::runInSerial() const {
   return serial;
 }
 
+inline
+bool ActionWithVector::hasMask() const {
+  return hasmask;
+}
+
+inline
+void ActionWithVector::ignoreMaskArguments() {
+  hasmask=false;
+}
+
 }
 
 #endif
diff --git a/src/function/Custom.cpp b/src/function/Custom.cpp
index 85e20a3579..b4f1262fc9 100644
--- a/src/function/Custom.cpp
+++ b/src/function/Custom.cpp
@@ -333,6 +333,25 @@ bool Custom::getDerivativeZeroIfValueIsZero() const {
   return check_multiplication_vars.size()>0;
 }
 
+bool Custom::checkIfMaskAllowed( const std::vector<Value*>& args ) const {
+  bool nomask=true;
+  for(unsigned i=0; i<args.size(); ++i) {
+    bool found=false;
+    for(unsigned j=0; j<check_multiplication_vars.size(); ++j) {
+      if( i==check_multiplication_vars[j] ) { found=true; break; }
+    }
+    if( found ) continue;
+    ActionWithVector* av=dynamic_cast<ActionWithVector*>( args[i]->getPntrToAction() );
+    if( av && av->hasMask() ) {
+      nomask=false; Value* maskarg = av->getPntrToArgument( av->getNumberOfArguments()-1 );
+      for(unsigned j=0; j<check_multiplication_vars.size(); ++j) {
+        if( maskarg==args[check_multiplication_vars[j]] ) return true;
+      }
+    }
+  }
+  return nomask;
+}
+
 std::vector<Value*> Custom::getArgumentsToCheck( const std::vector<Value*>& args ) {
   std::vector<Value*> fargs( check_multiplication_vars.size() );
   for(unsigned i=0; i<check_multiplication_vars.size(); ++i) fargs[i] = args[check_multiplication_vars[i]];
diff --git a/src/function/Custom.h b/src/function/Custom.h
index cd9d752305..e4037ba915 100644
--- a/src/function/Custom.h
+++ b/src/function/Custom.h
@@ -39,6 +39,7 @@ class Custom : public FunctionTemplateBase {
   void registerKeywords( Keywords& keys ) override;
   std::string getGraphInfo( const std::string& lab ) const override;
   void read( ActionWithArguments* action ) override;
+  bool checkIfMaskAllowed( const std::vector<Value*>& args ) const override ;
   bool getDerivativeZeroIfValueIsZero() const override;
   std::vector<Value*> getArgumentsToCheck( const std::vector<Value*>& args ) override;
   void calc( const ActionWithArguments* action, const std::vector<double>& args, std::vector<double>& vals, Matrix<double>& derivatives ) const override;
diff --git a/src/function/FunctionOfMatrix.h b/src/function/FunctionOfMatrix.h
index dbd6a2fb29..77c92cf54c 100644
--- a/src/function/FunctionOfMatrix.h
+++ b/src/function/FunctionOfMatrix.h
@@ -77,6 +77,7 @@ void FunctionOfMatrix<T>::registerKeywords(Keywords& keys ) {
   keys.add("hidden","NO_ACTION_LOG","suppresses printing from action on the log");
   keys.reserve("compulsory","PERIODIC","if the output of your function is periodic then you should specify the periodicity of the function.  If the output is not periodic you must state this using PERIODIC=NO");
   T tfunc; tfunc.registerKeywords( keys );
+  if( keys.getDisplayName()=="CUSTOM" || keys.getDisplayName()=="MATHEVAL" || keys.getDisplayName()=="SUM" ) keys.add("hidden","MASKED_INPUT_ALLOWED","turns on that you are allowed to use masked inputs ");
   if( keys.getDisplayName()=="SUM" ) {
     keys.setValueDescription("the sum of all the elements in the input matrix");
   } else if( keys.getDisplayName()=="HIGHEST" ) {
@@ -154,6 +155,8 @@ FunctionOfMatrix<T>::FunctionOfMatrix(const ActionOptions&ao):
     if( argname=="NEIGHBORS" ) { foundneigh=true; break; }
     ActionWithVector* av=dynamic_cast<ActionWithVector*>( getPntrToArgument(i)->getPntrToAction() );
     if( !av ) done_in_chain=false;
+    else if( av->hasMask() && !myfunc.checkIfMaskAllowed( getArguments() ) ) error("cannot use argument masks in input as not all elements are computed");
+
     if( getPntrToArgument(i)->getRank()==0 ) {
       function::FunctionOfVector<function::Sum>* as = dynamic_cast<function::FunctionOfVector<function::Sum>*>( getPntrToArgument(i)->getPntrToAction() );
       if(as) done_in_chain=false;
diff --git a/src/function/FunctionOfVector.h b/src/function/FunctionOfVector.h
index 98780b8f3b..bf6ef0e00d 100644
--- a/src/function/FunctionOfVector.h
+++ b/src/function/FunctionOfVector.h
@@ -79,6 +79,7 @@ void FunctionOfVector<T>::registerKeywords(Keywords& keys ) {
   keys.reserve("compulsory","PERIODIC","if the output of your function is periodic then you should specify the periodicity of the function.  If the output is not periodic you must state this using PERIODIC=NO");
   keys.add("hidden","NO_ACTION_LOG","suppresses printing from action on the log");
   T tfunc; tfunc.registerKeywords( keys );
+  if( keys.getDisplayName()=="CUSTOM" || keys.getDisplayName()=="MATHEVAL" || keys.getDisplayName()=="SUM" ) keys.add("hidden","MASKED_INPUT_ALLOWED","turns on that you are allowed to use masked inputs ");
   if( keys.getDisplayName()=="SUM" ) {
     keys.setValueDescription("the sum of all the elements in the input vector");
   } else if( keys.getDisplayName()=="MEAN" ) {
@@ -161,6 +162,7 @@ FunctionOfVector<T>::FunctionOfVector(const ActionOptions&ao):
     } else {
       ActionWithVector* av=dynamic_cast<ActionWithVector*>( getPntrToArgument(i)->getPntrToAction() );
       if( !av ) done_in_chain=false;
+      else if( av->hasMask() && !myfunc.checkIfMaskAllowed( getArguments() ) ) error("cannot use argument masks in input as not all elements are computed");
     }
   }
   // Don't need to do the calculation in a chain if the input is constant
diff --git a/src/function/FunctionTemplateBase.h b/src/function/FunctionTemplateBase.h
index 27bd657feb..1cf46d0499 100644
--- a/src/function/FunctionTemplateBase.h
+++ b/src/function/FunctionTemplateBase.h
@@ -52,6 +52,7 @@ class FunctionTemplateBase {
   virtual void registerKeywords( Keywords& keys ) = 0;
   virtual void read( ActionWithArguments* action ) = 0;
   virtual bool doWithTasks() const { return true; }
+  virtual bool checkIfMaskAllowed( const std::vector<Value*>& args ) const { return false; }
   virtual std::vector<Value*> getArgumentsToCheck( const std::vector<Value*>& args );
   bool allComponentsRequired( const std::vector<Value*>& args, const std::vector<ActionWithVector*>& actions );
   virtual bool zeroRank() const { return false; }
diff --git a/src/function/Sum.h b/src/function/Sum.h
index 0a28d37956..8980899ff9 100644
--- a/src/function/Sum.h
+++ b/src/function/Sum.h
@@ -34,6 +34,7 @@ class Sum : public FunctionTemplateBase {
   void registerKeywords( Keywords& keys ) override;
   void read( ActionWithArguments* action ) override;
   bool zeroRank() const override;
+  bool checkIfMaskAllowed( const std::vector<Value*>& args ) const override { return true; }
   void setPrefactor( ActionWithArguments* action, const double pref ) override;
   void calc( const ActionWithArguments* action, const std::vector<double>& args, std::vector<double>& vals, Matrix<double>& derivatives ) const override;
 };
diff --git a/src/matrixtools/MatrixTimesMatrix.cpp b/src/matrixtools/MatrixTimesMatrix.cpp
index ea9c2ee743..8bfcc43b93 100644
--- a/src/matrixtools/MatrixTimesMatrix.cpp
+++ b/src/matrixtools/MatrixTimesMatrix.cpp
@@ -64,8 +64,7 @@ PLUMED_REGISTER_ACTION(MatrixTimesMatrix,"MATRIX_PRODUCT")
 PLUMED_REGISTER_ACTION(MatrixTimesMatrix,"DISSIMILARITIES")
 
 void MatrixTimesMatrix::registerKeywords( Keywords& keys ) {
-  ActionWithMatrix::registerKeywords(keys); keys.use("ARG");
-  keys.add("optional","MASK","the label for a sparse matrix that should be used to determine which elements of the matrix should be computed");
+  ActionWithMatrix::registerKeywords(keys); keys.use("ARG"); keys.use("MASK");
   keys.addFlag("SQUARED",false,"calculate the squares of the dissimilarities (this option cannot be used with MATRIX_PRODUCT)");
   keys.setValueDescription("the product of the two input matrices");
 }
@@ -74,7 +73,8 @@ MatrixTimesMatrix::MatrixTimesMatrix(const ActionOptions&ao):
   Action(ao),
   ActionWithMatrix(ao)
 {
-  if( getNumberOfArguments()!=2 ) error("should be two arguments to this action, a matrix and a vector");
+  unsigned nmask = 0; if( hasMask() ) nmask = 1;
+  if( getNumberOfArguments()-nmask!=2 ) error("should be two arguments to this action, a matrix and a vector");
   if( getPntrToArgument(0)->getRank()!=2 || getPntrToArgument(0)->hasDerivatives() ) error("first argument to this action should be a matrix");
   if( getPntrToArgument(1)->getRank()!=2 || getPntrToArgument(1)->hasDerivatives() ) error("second argument to this action should be a matrix");
   if( getPntrToArgument(0)->getShape()[1]!=getPntrToArgument(1)->getShape()[0] ) error("number of columns in first matrix does not equal number of rows in second matrix");
@@ -88,13 +88,11 @@ MatrixTimesMatrix::MatrixTimesMatrix(const ActionOptions&ao):
     if( squared ) log.printf("  calculating the squares of the dissimilarities \n");
   } else squared=true;
 
-  std::vector<Value*> mask; parseArgumentList("MASK",mask);
-  if( mask.size()==1 ) {
-    if( mask[0]->getRank()!=2 || getPntrToArgument(0)->hasDerivatives() ) error("argument passed to MASK keyword should be a matrix");
-    if( mask[0]->getShape()[0]!=shape[0] || mask[0]->getShape()[1]!=shape[1] ) error("argument passed to MASK keyword has the wrong shape");
-    log.printf("  only computing elements of matrix product that correspond to non-zero elements of matrix %s \n", mask[0]->getName().c_str() );
-    std::vector<Value*> allargs( getArguments() ); allargs.push_back( mask[0] ); requestArguments( allargs );
-  } else if( mask.size()!=0 ) error("MASK should only have one argument");
+  if( hasMask() ) {
+    unsigned iarg = getNumberOfArguments()-1;
+    if( getPntrToArgument(iarg)->getRank()!=2 || getPntrToArgument(0)->hasDerivatives() ) error("argument passed to MASK keyword should be a matrix");
+    if( getPntrToArgument(iarg)->getShape()[0]!=shape[0] || getPntrToArgument(iarg)->getShape()[1]!=shape[1] ) error("argument passed to MASK keyword has the wrong shape");
+  }
 }
 
 unsigned MatrixTimesMatrix::getNumberOfDerivatives() {
diff --git a/src/matrixtools/MatrixTimesVector.cpp b/src/matrixtools/MatrixTimesVector.cpp
index a1eb361c93..866e0735a4 100644
--- a/src/matrixtools/MatrixTimesVector.cpp
+++ b/src/matrixtools/MatrixTimesVector.cpp
@@ -60,6 +60,7 @@ PLUMED_REGISTER_ACTION(MatrixTimesVector,"MATRIX_VECTOR_PRODUCT")
 void MatrixTimesVector::registerKeywords( Keywords& keys ) {
   ActionWithMatrix::registerKeywords(keys); keys.use("ARG");
   keys.setValueDescription("the vector that is obtained by taking the product between the matrix and the vector that were input");
+  keys.add("hidden","MASKED_INPUT_ALLOWED","turns on that you are allowed to use masked inputs ");
   ActionWithValue::useCustomisableComponents(keys);
 }
 
@@ -85,12 +86,16 @@ MatrixTimesVector::MatrixTimesVector(const ActionOptions&ao):
   sumrows(false)
 {
   if( getNumberOfArguments()<2 ) error("Not enough arguments specified");
-  unsigned nvectors=0, nmatrices=0;
+  unsigned nvectors=0, nmatrices=0; bool vectormask=false;
   for(unsigned i=0; i<getNumberOfArguments(); ++i) {
     if( getPntrToArgument(i)->hasDerivatives() ) error("arguments should be vectors or matrices");
-    if( getPntrToArgument(i)->getRank()<=1 ) nvectors++;
+    if( getPntrToArgument(i)->getRank()<=1 ) {
+      nvectors++; ActionWithVector* av=dynamic_cast<ActionWithVector*>( getPntrToArgument(i)->getPntrToAction() );
+      if( av && av->hasMask() ) vectormask=true;
+    }
     if( getPntrToArgument(i)->getRank()==2 ) nmatrices++;
   }
+  if( !vectormask ) ignoreMaskArguments();
 
   std::vector<unsigned> shape(1); shape[0]=getPntrToArgument(0)->getShape()[0];
   if( nvectors==1 ) {