Skip to content

Commit

Permalink
converted mask variable from bool to unsigned
Browse files Browse the repository at this point in the history
  • Loading branch information
Gareth Aneurin Tribello authored and Gareth Aneurin Tribello committed Aug 24, 2024
1 parent 98647c6 commit 4607ff5
Show file tree
Hide file tree
Showing 9 changed files with 34 additions and 24 deletions.
5 changes: 2 additions & 3 deletions src/adjmat/TorsionsMatrix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,7 @@ TorsionsMatrix::TorsionsMatrix(const ActionOptions&ao):
Action(ao),
ActionWithMatrix(ao)
{
unsigned nmask = 0; if( hasMask() ) nmask = 1;
if( getNumberOfArguments()-nmask!=2 ) error("should be two arguments to this action, a matrix and a vector");
if( getNumberOfArguments()-getNumberOfMasks()!=2 ) error("should be two arguments to this action, a matrix and a vector");
if( getPntrToArgument(0)->getRank()!=2 || getPntrToArgument(0)->hasDerivatives() ) error("first argument to this action should be a matrix");
if( getPntrToArgument(1)->getRank()!=2 || getPntrToArgument(1)->hasDerivatives() ) error("second argument to this action should be a matrix");
if( getPntrToArgument(0)->getShape()[1]!=3 || getPntrToArgument(1)->getShape()[0]!=3 ) error("number of columns in first matrix and number of rows in second matrix should equal 3");
Expand All @@ -90,7 +89,7 @@ TorsionsMatrix::TorsionsMatrix(const ActionOptions&ao):
stored_matrix1 = getPntrToArgument(0)->ignoreStoredValue( headstr );
stored_matrix2 = getPntrToArgument(1)->ignoreStoredValue( headstr );

if( hasMask() ) {
if( getNumberOfMasks()>0 ) {
unsigned iarg = getNumberOfArguments()-1;
if( getPntrToArgument(iarg)->getRank()!=2 || getPntrToArgument(0)->hasDerivatives() ) error("argument passed to MASK keyword should be a matrix");
if( getPntrToArgument(iarg)->getShape()[0]!=shape[0] || getPntrToArgument(iarg)->getShape()[1]!=shape[1] ) error("argument passed to MASK keyword has the wrong shape");
Expand Down
4 changes: 2 additions & 2 deletions src/core/ActionWithArguments.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -189,10 +189,10 @@ void ActionWithArguments::interpretArgumentList(const std::vector<std::string>&
}
}
}
if( readact->keywords.exists("MASKED_INPUT_ALLOWED") ) return;
if( readact->keywords.exists("MASKED_INPUT_ALLOWED") || readact->keywords.exists("IS_SHORTCUT") ) return;
for(unsigned i=0; i<arg.size(); ++i) {
ActionWithVector* av=dynamic_cast<ActionWithVector*>( arg[i]->getPntrToAction() );
if( av && av->hasMask() ) readact->error("cannot use argument " + arg[i]->getName() + " in input as not all elements are computed");
if( av && av->getNumberOfMasks()>0 ) readact->error("cannot use argument " + arg[i]->getName() + " in input as not all elements are computed");
}
}

Expand Down
26 changes: 19 additions & 7 deletions src/core/ActionWithVector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ ActionWithVector::ActionWithVector(const ActionOptions&ao):
ActionAtomistic(ao),
ActionWithValue(ao),
ActionWithArguments(ao),
hasmask(false),
nmask(0),
serial(false),
forwardPass(false),
action_to_do_before(NULL),
Expand All @@ -72,18 +72,30 @@ ActionWithVector::ActionWithVector(const ActionOptions&ao):
{
for(unsigned i=0; i<getNumberOfArguments(); ++i) {
ActionWithVector* av = dynamic_cast<ActionWithVector*>( getPntrToArgument(i)->getPntrToAction() );
if( av && av->hasMask() ) hasmask=true;
if( av && av->getNumberOfMasks()>0 ) nmask=1;
}

if( keywords.exists("SERIAL") ) parseFlag("SERIAL",serial);
if( keywords.exists("MASK") ) {
std::vector<Value*> mask; parseArgumentList("MASK",mask);
if( mask.size()==1 ) {
if( mask.size()>0 ) {
if( nmask>0 ) error("should not have a mask if you have read the mask keyword");
if( getPntrToArgument(0)->hasDerivatives() ) error("input for mask should be vector or matrix");
else if( mask[0]->getRank()==2 ) log.printf(" only computing elements of matrix that correspond to non-zero elements of matrix %s \n", mask[0]->getName().c_str() );
else if( mask[0]->getRank()==1 ) log.printf(" only computing elements of vector that correspond to non-zero elements of vector %s \n", mask[0]->getName().c_str() );
std::vector<Value*> allargs( getArguments() ); allargs.push_back( mask[0] ); requestArguments( allargs ); hasmask=true;
} else if( mask.size()!=0 ) error("MASK should only have one argument");
else if( mask[0]->getRank()==2 ) {
if( mask.size()>1 ) error("MASK should only have one argument");
log.printf(" only computing elements of matrix that correspond to non-zero elements of matrix %s \n", mask[0]->getName().c_str() );
} else if( mask[0]->getRank()==1 ) {
log.printf(" only computing elements of vector that correspond to non-zero elements of vectors %s", mask[0]->getName().c_str() );
for(unsigned i=1; i<mask.size();++i) {
if( mask[i]->getRank()!=1 ) { log.printf("\n"); error("input to mask should be vector"); }
log.printf(", %s", mask[i]->getName().c_str() );
}
log.printf("\n");
}
std::vector<Value*> allargs( getArguments() ); nmask=mask.size();
for(unsigned i=0; i<mask.size(); ++i) allargs.push_back( mask[i] );
requestArguments( allargs );
}
}
}

Expand Down
10 changes: 5 additions & 5 deletions src/core/ActionWithVector.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class ActionWithVector:
friend class Value;
private:
/// Check if there is a mask value
bool hasmask;
unsigned nmask;
/// Is the calculation to be done in serial
bool serial;
/// Are we in the forward pass through the calculation
Expand Down Expand Up @@ -129,7 +129,7 @@ class ActionWithVector:
virtual void prepare() override;
void retrieveAtoms( const bool& force=false ) override;
/// Check if a mask has been set
bool hasMask() const ;
unsigned getNumberOfMasks() const ;
void calculateNumericalDerivatives(ActionWithValue* av) override;
/// Turn off the calculation of the derivatives during the forward pass through a calculation
bool doNotCalculateDerivatives() const override ;
Expand Down Expand Up @@ -197,13 +197,13 @@ bool ActionWithVector::runInSerial() const {
}

inline
bool ActionWithVector::hasMask() const {
return hasmask;
unsigned ActionWithVector::getNumberOfMasks() const {
return nmask;
}

inline
void ActionWithVector::ignoreMaskArguments() {
hasmask=false;
nmask=0;
}

}
Expand Down
2 changes: 1 addition & 1 deletion src/function/Custom.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,7 @@ bool Custom::checkIfMaskAllowed( const std::vector<Value*>& args ) const {
}
if( found ) continue;
ActionWithVector* av=dynamic_cast<ActionWithVector*>( args[i]->getPntrToAction() );
if( av && av->hasMask() ) {
if( av && av->getNumberOfMasks()>0 ) {
nomask=false; Value* maskarg = av->getPntrToArgument( av->getNumberOfArguments()-1 );
for(unsigned j=0; j<check_multiplication_vars.size(); ++j) {
if( maskarg==args[check_multiplication_vars[j]] ) return true;
Expand Down
2 changes: 1 addition & 1 deletion src/function/FunctionOfMatrix.h
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ FunctionOfMatrix<T>::FunctionOfMatrix(const ActionOptions&ao):
if( argname=="NEIGHBORS" ) { foundneigh=true; break; }
ActionWithVector* av=dynamic_cast<ActionWithVector*>( getPntrToArgument(i)->getPntrToAction() );
if( !av ) done_in_chain=false;
else if( av->hasMask() && !myfunc.checkIfMaskAllowed( getArguments() ) ) error("cannot use argument masks in input as not all elements are computed");
else if( av->getNumberOfMasks()>0 && !myfunc.checkIfMaskAllowed( getArguments() ) ) error("cannot use argument masks in input as not all elements are computed");

if( getPntrToArgument(i)->getRank()==0 ) {
function::FunctionOfVector<function::Sum>* as = dynamic_cast<function::FunctionOfVector<function::Sum>*>( getPntrToArgument(i)->getPntrToAction() );
Expand Down
2 changes: 1 addition & 1 deletion src/function/FunctionOfVector.h
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ FunctionOfVector<T>::FunctionOfVector(const ActionOptions&ao):
} else {
ActionWithVector* av=dynamic_cast<ActionWithVector*>( getPntrToArgument(i)->getPntrToAction() );
if( !av ) done_in_chain=false;
else if( av->hasMask() && !myfunc.checkIfMaskAllowed( getArguments() ) ) error("cannot use argument masks in input as not all elements are computed");
else if( av->getNumberOfMasks()>0 && !myfunc.checkIfMaskAllowed( getArguments() ) ) error("cannot use argument masks in input as not all elements are computed");
}
}
// Don't need to do the calculation in a chain if the input is constant
Expand Down
5 changes: 2 additions & 3 deletions src/matrixtools/MatrixTimesMatrix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,7 @@ MatrixTimesMatrix::MatrixTimesMatrix(const ActionOptions&ao):
Action(ao),
ActionWithMatrix(ao)
{
unsigned nmask = 0; if( hasMask() ) nmask = 1;
if( getNumberOfArguments()-nmask!=2 ) error("should be two arguments to this action, a matrix and a vector");
if( getNumberOfArguments()-getNumberOfMasks()!=2 ) error("should be two arguments to this action, a matrix and a vector");
if( getPntrToArgument(0)->getRank()!=2 || getPntrToArgument(0)->hasDerivatives() ) error("first argument to this action should be a matrix");
if( getPntrToArgument(1)->getRank()!=2 || getPntrToArgument(1)->hasDerivatives() ) error("second argument to this action should be a matrix");
if( getPntrToArgument(0)->getShape()[1]!=getPntrToArgument(1)->getShape()[0] ) error("number of columns in first matrix does not equal number of rows in second matrix");
Expand All @@ -88,7 +87,7 @@ MatrixTimesMatrix::MatrixTimesMatrix(const ActionOptions&ao):
if( squared ) log.printf(" calculating the squares of the dissimilarities \n");
} else squared=true;

if( hasMask() ) {
if( getNumberOfMasks()>0 ) {
unsigned iarg = getNumberOfArguments()-1;
if( getPntrToArgument(iarg)->getRank()!=2 || getPntrToArgument(0)->hasDerivatives() ) error("argument passed to MASK keyword should be a matrix");
if( getPntrToArgument(iarg)->getShape()[0]!=shape[0] || getPntrToArgument(iarg)->getShape()[1]!=shape[1] ) error("argument passed to MASK keyword has the wrong shape");
Expand Down
2 changes: 1 addition & 1 deletion src/matrixtools/MatrixTimesVector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ MatrixTimesVector::MatrixTimesVector(const ActionOptions&ao):
if( getPntrToArgument(i)->hasDerivatives() ) error("arguments should be vectors or matrices");
if( getPntrToArgument(i)->getRank()<=1 ) {
nvectors++; ActionWithVector* av=dynamic_cast<ActionWithVector*>( getPntrToArgument(i)->getPntrToAction() );
if( av && av->hasMask() ) vectormask=true;
if( av && av->getNumberOfMasks()>0 ) vectormask=true;
}
if( getPntrToArgument(i)->getRank()==2 ) nmatrices++;
}
Expand Down

0 comments on commit 4607ff5

Please sign in to comment.