Skip to content

Commit

Permalink
Removed old style derivatives from MatrixVectorProduct and Quaternion…
Browse files Browse the repository at this point in the history
…BondProductMatrix
  • Loading branch information
Gareth Aneurin Tribello committed Sep 12, 2024
1 parent 6ca7a9a commit f8e4519
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 146 deletions.
8 changes: 6 additions & 2 deletions src/core/ActionWithVector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
along with plumed. If not, see <http://www.gnu.org/licenses/>.
+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ */
#include "ActionWithVector.h"
#include "ActionWithMatrix.h"
#include "PlumedMain.h"
#include "ActionSet.h"
#include "tools/OpenMP.h"
Expand Down Expand Up @@ -721,11 +722,14 @@ bool ActionWithVector::getNumberOfStoredValues( Value* startat, unsigned& nvals,
}

void ActionWithVector::runTask( const unsigned& current, MultiValue& myvals ) const {
const ActionWithMatrix* am = dynamic_cast<const ActionWithMatrix*>(this);
myvals.setTaskIndex(current); myvals.vector_call=true; performTask( current, myvals );
for(unsigned i=0; i<getNumberOfComponents(); ++i) {
const Value* myval = getConstPntrToComponent(i);
if( myval->getRank()!=1 || myval->hasDerivatives() || !myval->valueIsStored() ) continue;
Value* myv = const_cast<Value*>( myval ); myv->set( current, myvals.get( myval->getPositionInStream() ) );
if( am || myval->hasDerivatives() || !myval->valueIsStored() ) continue;
Value* myv = const_cast<Value*>( myval );
if( getName()=="RMSD_VECTOR" && myv->getRank()==2 ) continue;
myv->set( current, myvals.get( myval->getPositionInStream() ) );
}
if( action_to_do_after ) action_to_do_after->runTask( current, myvals );
}
Expand Down
7 changes: 7 additions & 0 deletions src/core/Value.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,13 @@ void Value::reshapeMatrixStore( const unsigned& n ) {
}
}

void Value::copyBookeepingArrayFromArgument( Value* myarg ) {
plumed_dbg_assert( shape.size()==2 && !hasDeriv );
ncols = myarg->getNumberOfColumns(); matrix_bookeeping.resize( myarg->matrix_bookeeping.size() );
for(unsigned i=0; i<matrix_bookeeping.size(); ++i) matrix_bookeeping[i] = myarg->matrix_bookeeping[i];
data.resize( shape[0]*ncols ); inputForce.resize( shape[0]*ncols );
}

void Value::setPositionInMatrixStash( const unsigned& p ) {
plumed_dbg_assert( shape.size()==2 && !hasDeriv );
matpos=p;
Expand Down
2 changes: 2 additions & 0 deletions src/core/Value.h
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,8 @@ class Value {
void buildDataStore( const bool forprint=false );
/// Reshape the storage for sparse matrices
void reshapeMatrixStore( const unsigned& n );
/// Copy the matrix bookeeping stuff
void copyBookeepingArrayFromArgument( Value* myarg );
/// Set the symmetric flag equal true for this matrix
void setSymmetric( const bool& sym );
/// Get the total number of scalars that are stored here
Expand Down
42 changes: 17 additions & 25 deletions src/crystdistrib/QuaternionBondProductMatrix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -143,11 +143,12 @@ void QuaternionBondProductMatrix::performTask( const std::string& controller, co
//[dit/dw1 dit/di1 dit/dj1 dit/dk1] etc, and dqt[1] is w.r.t the vector-turned-quaternion called bond

// Retrieve the quaternion
for(unsigned i=0; i<4; ++i) quat[i] = getArgumentElement( i, index1, myvals );
for(unsigned i=0; i<4; ++i) quat[i] = getPntrToArgument(i)->get(index1);

// Retrieve the components of the matrix
double weight = getElementOfMatrixArgument( 4, index1, ind2, myvals );
for(unsigned i=1; i<4; ++i) bond[i] = getElementOfMatrixArgument( 4+i, index1, ind2, myvals );
unsigned find = getPntrToArgument(4)->getIndexInStore( index1*getPntrToArgument(4)->getShape()[1] + ind2 );
double weight = getPntrToArgument(4)->get(find, false );
for(unsigned i=1; i<4; ++i) bond[i] = getPntrToArgument(4+i)->get(find, false );

// calculate normalization factor
bond[0]=0.0;
Expand All @@ -156,16 +157,11 @@ void QuaternionBondProductMatrix::performTask( const std::string& controller, co
double normFac3 = normFac*normFac*normFac;
//I hold off on normalizing because this can be done at the very end, and it makes the derivatives with respect to 'bond' more simple



std::vector<double> quat_conj(4);
quat_conj[0] = quat[0]; quat_conj[1] = -1*quat[1]; quat_conj[2] = -1*quat[2]; quat_conj[3] = -1*quat[3];
//make a conjugate of q1 my own sanity




//q1_conj * r first, while keep track of derivs
//q1_conj * r first, while keep track of derivs
double pref=1;
double conj=1;
double pref2=1;
Expand Down Expand Up @@ -236,21 +232,20 @@ void QuaternionBondProductMatrix::performTask( const std::string& controller, co
pref2=1; unsigned base=0;
for(unsigned i=0; i<4; ++i) {
if( i>0 ) {pref=-1; pref2=-1;}
myvals.addValue( getConstPntrToComponent(0)->getPositionInStream(), normFac*pref*quatTemp[i]*quat[i] );
myvals.addValue( 0, normFac*pref*quatTemp[i]*quat[i] );
wf+=normFac*pref*quatTemp[i]*quat[i];
if( doNotCalculateDerivatives() ) continue ;
tempDot=(dotProduct(Vector4d(quat[0],-quat[1],-quat[2],-quat[3]), dqt[0].getCol(i)) + pref2*quatTemp[i])*normFac;
addDerivativeOnVectorArgument( stored[i], 0, i, index1, tempDot, myvals);
myvals.addDerivative( 0, base + index1, tempDot ); myvals.updateIndex( 0, base + index1 );
base += getPntrToArgument(i)->getNumberOfStoredValues();
}
//had to split because bond's derivatives depend on the value of the overall quaternion component
if( !doNotCalculateDerivatives() ) {
unsigned ostrn = getConstPntrToComponent(0)->getPositionInStream();
for(unsigned i=0; i<4; ++i) {
tempDot=dotProduct(Vector4d(quat[0],-quat[1],-quat[2],-quat[3]), dqt[1].getCol(i))*normFac;
if( i>0 ) {
plumed_assert( !stored[4+i] ); unsigned find = getPntrToArgument(4+i)->getIndexInStore( index1*getPntrToArgument(4+i)->getShape()[1] + ind2 );
myvals.addDerivative( ostrn, base + find, tempDot ); myvals.updateIndex( ostrn, base + find );
myvals.addDerivative( 0, base + find, tempDot ); myvals.updateIndex( 0, base + find );
}
base += getPntrToArgument(4+i)->getNumberOfStoredValues();
}
Expand All @@ -262,23 +257,22 @@ void QuaternionBondProductMatrix::performTask( const std::string& controller, co
for (unsigned i=0; i<4; i++) {
if(i==3) pref=-1;
else pref=1;
myvals.addValue( getConstPntrToComponent(1)->getPositionInStream(), normFac*pref*quatTemp[i]*quat[(5-i)%4]);
myvals.addValue( 1, normFac*pref*quatTemp[i]*quat[(5-i)%4]);
xf+=normFac*pref*quatTemp[i]*quat[(5-i)%4];
if(i==2) pref2=-1;
else pref2=1;
if( doNotCalculateDerivatives() ) continue ;
tempDot=(dotProduct(Vector4d(quat[1],quat[0],quat[3],-quat[2]), dqt[0].getCol(i)) + pref2*quatTemp[(5-i)%4])*normFac;
addDerivativeOnVectorArgument( stored[i], 1, i, index1, tempDot, myvals);
myvals.addDerivative( 1, base + index1, tempDot ); myvals.updateIndex( 1, base + index1 );
base += getPntrToArgument(i)->getNumberOfStoredValues();
}

if( !doNotCalculateDerivatives() ) {
unsigned ostrn = getConstPntrToComponent(1)->getPositionInStream();
for(unsigned i=0; i<4; ++i) {
tempDot=dotProduct(Vector4d(quat[1],quat[0],quat[3],-quat[2]), dqt[1].getCol(i))*normFac;
if( i>0 ) {
plumed_assert( !stored[4+i] ); unsigned find = getPntrToArgument(4+i)->getIndexInStore( index1*getPntrToArgument(4+i)->getShape()[1] + ind2 );
myvals.addDerivative( ostrn, base + find, tempDot+(-bond[i]*normFac*normFac*xf) ); myvals.updateIndex( ostrn, base + find );
myvals.addDerivative( 1, base + find, tempDot+(-bond[i]*normFac*normFac*xf) ); myvals.updateIndex( 1, base + find );
}
base += getPntrToArgument(4+i)->getNumberOfStoredValues();
}
Expand All @@ -294,21 +288,20 @@ void QuaternionBondProductMatrix::performTask( const std::string& controller, co
if (i==3) pref2=-1;
else pref2=1;

myvals.addValue( getConstPntrToComponent(2)->getPositionInStream(), normFac*pref*quatTemp[i]*quat[(i+2)%4]);
myvals.addValue( 2, normFac*pref*quatTemp[i]*quat[(i+2)%4]);
yf+=normFac*pref*quatTemp[i]*quat[(i+2)%4];
if( doNotCalculateDerivatives() ) continue ;
tempDot=(dotProduct(Vector4d(quat[2],-quat[3],quat[0],quat[1]), dqt[0].getCol(i)) + pref2*quatTemp[(i+2)%4])*normFac;
addDerivativeOnVectorArgument( stored[i], 2, i, index1, tempDot, myvals);
myvals.addDerivative( 2, base + index1, tempDot ); myvals.updateIndex( 2, base + index1 );
base += getPntrToArgument(i)->getNumberOfStoredValues();
}

if( !doNotCalculateDerivatives() ) {
unsigned ostrn = getConstPntrToComponent(2)->getPositionInStream();
for(unsigned i=0; i<4; ++i) {
tempDot=dotProduct(Vector4d(quat[2],-quat[3],quat[0],quat[1]), dqt[1].getCol(i))*normFac;
if( i>0 ) {
plumed_assert( !stored[4+i] ); unsigned find = getPntrToArgument(4+i)->getIndexInStore( index1*getPntrToArgument(4+i)->getShape()[1] + ind2 );
myvals.addDerivative( ostrn, base + find, tempDot+(-bond[i]*normFac*normFac*yf) ); myvals.updateIndex( ostrn, base + find );
myvals.addDerivative( 2, base + find, tempDot+(-bond[i]*normFac*normFac*yf) ); myvals.updateIndex( 2, base + find );
}
base += getPntrToArgument(4+i)->getNumberOfStoredValues();
}
Expand All @@ -323,22 +316,21 @@ void QuaternionBondProductMatrix::performTask( const std::string& controller, co
if(i==1) pref2=-1;
else pref2=1;

myvals.addValue( getConstPntrToComponent(3)->getPositionInStream(), normFac*pref*quatTemp[i]*quat[(3-i)]);
myvals.addValue( 3, normFac*pref*quatTemp[i]*quat[(3-i)]);
zf+=normFac*pref*quatTemp[i]*quat[(3-i)];
if( doNotCalculateDerivatives() ) continue ;
tempDot=(dotProduct(Vector4d(quat[3],quat[2],-quat[1],quat[0]), dqt[0].getCol(i)) + pref2*quatTemp[(3-i)])*normFac;
addDerivativeOnVectorArgument( stored[i], 3, i, index1, tempDot, myvals);
myvals.addDerivative( 3, base + index1, tempDot ); myvals.updateIndex( 3, base + index1 );
base += getPntrToArgument(i)->getNumberOfStoredValues();
}

if( doNotCalculateDerivatives() ) return ;

unsigned ostrn = getConstPntrToComponent(3)->getPositionInStream();
for(unsigned i=0; i<4; ++i) {
tempDot=dotProduct(Vector4d(quat[3],quat[2],-quat[1],quat[0]), dqt[1].getCol(i))*normFac;
if( i>0 ) {
plumed_assert( !stored[4+i] ); unsigned find = getPntrToArgument(4+i)->getIndexInStore( index1*getPntrToArgument(4+i)->getShape()[1] + ind2 );
myvals.addDerivative( ostrn, base + find, tempDot+(-bond[i]*normFac*normFac*zf) ); myvals.updateIndex( ostrn, base + find );
myvals.addDerivative( 3, base + find, tempDot+(-bond[i]*normFac*normFac*zf) ); myvals.updateIndex( 3, base + find );
}
base += getPntrToArgument(4+i)->getNumberOfStoredValues();
}
Expand Down
Loading

0 comments on commit f8e4519

Please sign in to comment.