Skip to content

Commit

Permalink
Fixes RCG solver to be correct with Kokkos::DualView
Browse files: browse the repository at this point in the history
  • Loading branch information
hkthorn committed Sep 18, 2024
1 parent 61a3543 commit e054de7
Show file tree
Hide file tree
Showing 3 changed files with 247 additions and 21 deletions.
55 changes: 35 additions & 20 deletions packages/belos/src/BelosRCGSolMgr.hpp
Original file line number | Diff line number | Diff line change
Expand Up @@ -194,8 +194,8 @@ namespace Belos {
virtual ~RCGSolMgr() {};

//! clone for Inverted Injection (DII)
Teuchos::RCP<SolverManager<ScalarType, MV, OP> > clone () const override {
return Teuchos::rcp(new RCGSolMgr<ScalarType,MV,OP>);
Teuchos::RCP<SolverManager<ScalarType, MV, OP, DM> > clone () const override {
return Teuchos::rcp(new RCGSolMgr<ScalarType,MV,OP,DM>);
}
//@}

Expand Down Expand Up @@ -817,23 +817,23 @@ void RCGSolMgr<ScalarType,MV,OP,DM,true>::initializeStateStorage() {
if (Alpha_ == Teuchos::null)
Alpha_ = Teuchos::rcp( new std::vector<ScalarType>( numBlocks_, 1 ) );
else {
if ( Alpha_->size() != numBlocks_ )
if ( (int)Alpha_->size() != numBlocks_ )
Alpha_->resize( numBlocks_, 1 );
}

// Generate Beta_ only if it doesn't exist, otherwise resize it.
if (Beta_ == Teuchos::null)
Beta_ = Teuchos::rcp( new std::vector<ScalarType>( numBlocks_ + 1 ) );
else {
if ( (Beta_->size() != (numBlocks_+1)) )
if ( ((int)Beta_->size() != (numBlocks_+1)) )
Beta_->resize( numBlocks_ + 1 );
}

// Generate D_ only if it doesn't exist, otherwise resize it.
if (D_ == Teuchos::null)
D_ = Teuchos::rcp( new std::vector<ScalarType>( numBlocks_ ) );
else {
if ( D_->size() != numBlocks_ )
if ( (int)D_->size() != numBlocks_ )
D_->resize( numBlocks_ );
}

Expand Down Expand Up @@ -1158,8 +1158,9 @@ ReturnType RCGSolMgr<ScalarType,MV,OP,DM,true>::solve() {
Teuchos::RCP<DM> Utr = DMT::Create(recycleBlocks_,1);
Teuchos::RCP<const MV> Utmp = MVT::CloneView( *U_, rindex );
MVT::MvTransMv( one, *Utmp, *r_, *Utr );

DMT::SyncHostToDevice(*LUUTAU_);
DMT::Assign(*LUUTAU_,*UTAU_);

DMT::SyncDeviceToHost( *LUUTAU_ );
DMT::SyncDeviceToHost( *Utr );
int info = 0;
Expand Down Expand Up @@ -1192,7 +1193,7 @@ ReturnType RCGSolMgr<ScalarType,MV,OP,DM,true>::solve() {
Teuchos::RCP<const MV> AUtmp = MVT::CloneView( *AU_, rindex );
MVT::MvTransMv( one, *AUtmp, *z_, *mu );

DMT::SyncDeviceToHost( *mu );
DMT::SyncDeviceToHost( *Delta_ );
char TRANS = 'N';
int info;
lapack.GETRS( TRANS, recycleBlocks_, 1, DMT::GetConstRawHostPtr(*LUUTAU_), DMT::GetStride(*LUUTAU_),
Expand Down Expand Up @@ -1253,7 +1254,7 @@ ReturnType RCGSolMgr<ScalarType,MV,OP,DM,true>::solve() {
//
////////////////////////////////////////////////////////////////////////////////////
if ( convTest_->getStatus() == Passed ) {
// We have convergence
// We have convergence
break; // break from while(1){rcg_iter->iterate()}
}
////////////////////////////////////////////////////////////////////////////////////
Expand All @@ -1280,13 +1281,13 @@ ReturnType RCGSolMgr<ScalarType,MV,OP,DM,true>::solve() {
if (!existU_) {
if (cycle == 0) { // No U, no U1

DMT::SyncDeviceToHost( *F_ );
DMT::SyncDeviceToHost( *G_ );
Teuchos::RCP<DM> Ftmp = DMT::Subview( *F_, numBlocks_, numBlocks_ );
Teuchos::RCP<DM> Gtmp = DMT::Subview( *G_, numBlocks_, numBlocks_ );
DMT::PutScalar( *Ftmp, zero );
DMT::PutScalar( *Gtmp, zero );
for (int ii=0;ii<numBlocks_;ii++) {
DMT::SyncDeviceToHost( *F_ );
DMT::SyncDeviceToHost( *G_ );
for (int ii=0;ii<numBlocks_;ii++) {
DMT::Value(*Gtmp,ii,ii) = ((*D_)[ii] / (*Alpha_)[ii])*(1 + (*Beta_)[ii]);
if (ii > 0) {
DMT::Value(*Gtmp,ii-1,ii) = -(*D_)[ii]/(*Alpha_)[ii-1];
Expand All @@ -1298,6 +1299,7 @@ ReturnType RCGSolMgr<ScalarType,MV,OP,DM,true>::solve() {
DMT::SyncHostToDevice( *G_ );

// compute harmonic Ritz vectors
DMT::SyncDeviceToHost( *Y_ );
Teuchos::RCP<DM> Ytmp = DMT::Subview( *Y_, numBlocks_, recycleBlocks_ );
getHarmonicVecs(*Ftmp,*Gtmp,*Ytmp);
DMT::SyncHostToDevice( *Y_ );
Expand All @@ -1308,6 +1310,10 @@ ReturnType RCGSolMgr<ScalarType,MV,OP,DM,true>::solve() {
MVT::MvTimesMatAddMv( one, *Ptmp, *Ytmp, zero, *U1tmp );

// Precompute some variables for next cycle
DMT::SyncDeviceToHost(*GY_);
DMT::SyncDeviceToHost(*AU1TAU1_);
DMT::SyncDeviceToHost(*FY_);
DMT::SyncDeviceToHost(*AU1TU1_);

// AU1TAU1 = Y'*G*Y;
Teuchos::RCP<DM> GYtmp = DMT::Subview( *GY_, numBlocks_, recycleBlocks_ );
Expand Down Expand Up @@ -1344,11 +1350,12 @@ ReturnType RCGSolMgr<ScalarType,MV,OP,DM,true>::solve() {
// Must reinitialize AU1TAP; can become dense later
DMT::PutScalar( *AU1TAPtmp, zero );
// AU1TAP(:,1) = Y(end,:)' * (-1/Alpha(end));
ScalarType alphatmp = -1.0 / (*Alpha_)[numBlocks_-1];
DMT::SyncDeviceToHost( *AU1TAP_ );
ScalarType alphatmp = -1.0 / (*Alpha_)[numBlocks_-1];
for (int ii=0; ii<recycleBlocks_; ++ii) {
DMT::Value(*AU1TAPtmp,ii,0) = DMT::ValueConst(*Ytmp,numBlocks_-1,ii) * alphatmp;
}
DMT::SyncDeviceToHost(*AU1TAP_);
DMT::SyncHostToDevice(*AU1TAP_);

// indicate that updated recycle space now defined
existU1_ = true;
Expand Down Expand Up @@ -1376,7 +1383,7 @@ ReturnType RCGSolMgr<ScalarType,MV,OP,DM,true>::solve() {
DMT::SyncHostToDevice(*APTAP_);

// F = [AU1TU1 zeros(k,m); zeros(m,k) diag(D)];
F_->putScalar(zero);
DMT::PutScalar(*F_,zero);
Teuchos::RCP<DM> F11 = DMT::Subview( *F_, recycleBlocks_, recycleBlocks_ );
Teuchos::RCP<DM> F22 = DMT::Subview( *F_, numBlocks_, numBlocks_, recycleBlocks_, recycleBlocks_ );
DMT::Assign(*F11,*AU1TU1_);
Expand Down Expand Up @@ -1497,8 +1504,10 @@ ReturnType RCGSolMgr<ScalarType,MV,OP,DM,true>::solve() {
DMT::SyncHostToDevice(*L2_);

// AUTAP = UTAU*Delta*L2;
DMT::SyncDeviceToHost(*DeltaL2_);
DMT::SyncDeviceToHost(*Delta_);
DMT::SyncDeviceToHost(*DeltaL2_);
DMT::SyncDeviceToHost(*AUTAP_);
DMT::SyncDeviceToHost(*UTAU_);

//DeltaL2_->multiply(Teuchos::NO_TRANS,Teuchos::NO_TRANS,one,*Delta_,*L2_,zero);
blas.GEMM( Teuchos::NO_TRANS, Teuchos::NO_TRANS, recycleBlocks_, numBlocks_, numBlocks_+1,
Expand Down Expand Up @@ -1625,7 +1634,9 @@ ReturnType RCGSolMgr<ScalarType,MV,OP,DM,true>::solve() {
DMT::Value(*L2_,ii,ii) = 1./(*Alpha_)[ii];
DMT::Value(*L2_,ii+1,ii) = -1./(*Alpha_)[ii];
}
DMT::SyncHostToDevice(*L2_);

DMT::SyncDeviceToHost(*Delta_);
DMT::SyncDeviceToHost(*DeltaL2_);
DMT::SyncDeviceToHost(*AU1TUDeltaL2_);
DMT::SyncDeviceToHost(*AU1TAP_);
Expand All @@ -1642,13 +1653,16 @@ ReturnType RCGSolMgr<ScalarType,MV,OP,DM,true>::solve() {
one, DMT::GetConstRawHostPtr(*AU1TU_), DMT::GetStride(*AU1TU_),
DMT::GetConstRawHostPtr(*DeltaL2_), DMT::GetStride(*DeltaL2_),
zero, DMT::GetRawHostPtr(*AU1TUDeltaL2_), DMT::GetStride(*AU1TUDeltaL2_));

DMT::SyncDeviceToHost( *Y_);
Teuchos::RCP<const DM> Y1 = DMT::SubviewConst( *Y_, recycleBlocks_, recycleBlocks_ );
//AU1TAP_->multiply(Teuchos::TRANS,Teuchos::NO_TRANS,one,*Y1,*AU1TUDeltaL2_,zero);
Teuchos::RCP<const DM> Y2 = DMT::SubviewConst( *Y_, numBlocks_, recycleBlocks_, recycleBlocks_, 0 );

//AU1TAP_->multiply(Teuchos::TRANS,Teuchos::NO_TRANS,one,*Y1,*AU1TUDeltaL2_,zero);
blas.GEMM( Teuchos::TRANS, Teuchos::NO_TRANS, recycleBlocks_, numBlocks_, recycleBlocks_,
one, DMT::GetConstRawHostPtr(*Y1), DMT::GetStride(*Y1),
DMT::GetConstRawHostPtr(*AU1TUDeltaL2_), DMT::GetStride(*AU1TUDeltaL2_),
zero, DMT::GetRawHostPtr(*AU1TAP_), DMT::GetStride(*AU1TAP_));
Teuchos::RCP<const DM> Y2 = DMT::SubviewConst( *Y_, numBlocks_, recycleBlocks_, recycleBlocks_, 0 );
ScalarType val = dold * (-(*Beta_)[0]/(*Alpha_)[0]);
for(int ii=0;ii<recycleBlocks_;ii++) {
DMT::Value(*AU1TAP_,ii,0) += DMT::ValueConst(*Y2,numBlocks_-1,ii)*val;
Expand Down Expand Up @@ -1931,8 +1945,9 @@ template<class ScalarType, class MV, class OP, class DM>
void RCGSolMgr<ScalarType,MV,OP,DM,true>::getHarmonicVecs(const DM& F,
const DM& G,
DM& Y ) {

// order of F,G
int n = F.numCols();
int n = DMT::GetNumCols(F);

// The LAPACK interface
Teuchos::LAPACK<int,ScalarType> lapack;
Expand All @@ -1952,8 +1967,8 @@ void RCGSolMgr<ScalarType,MV,OP,DM,true>::getHarmonicVecs(const DM& F,
int lwork = -1;
int info = 0;
// since SYGV destroys workspace, create copies of F,G
Teuchos::RCP<DM> F2 = DMT::CreateCopy( *F_ );
Teuchos::RCP<DM> G2 = DMT::CreateCopy( *G_ );
Teuchos::RCP<DM> F2 = DMT::CreateCopy( F );
Teuchos::RCP<DM> G2 = DMT::CreateCopy( G );

DMT::SyncDeviceToHost(*F2);
DMT::SyncDeviceToHost(*G2);
Expand Down
10 changes: 9 additions & 1 deletion packages/belos/tpetra/test/RCG/CMakeLists.txt
Original file line number | Diff line number | Diff line change
Expand Up @@ -10,11 +10,19 @@ TRIBITS_ADD_EXECUTABLE_AND_TEST(
COMM serial mpi
)

# Builds test_kokkos_rcg_hb.cpp into the Tpetra_DenseKokkos_rcg_hb_test
# executable and registers it as a test for both serial and MPI
# configurations (COMM serial mpi).
# The test is run on the bcsstk14.hb matrix file with the ARGS string below.
# NOTE(review): the meaning of each flag (tolerance, number of right-hand
# sides, subspace/recycle sizes, iteration cap) is inferred from the flag
# names only; confirm against the option parsing in test_kokkos_rcg_hb.cpp.
TRIBITS_ADD_EXECUTABLE_AND_TEST(
Tpetra_DenseKokkos_rcg_hb_test
SOURCES test_kokkos_rcg_hb.cpp
ARGS
"--verbose --tol=1e-6 --filename=bcsstk14.hb --num-rhs=3 --max-subspace=100 --recycle=10 --max-iters=4000"
COMM serial mpi
)

ASSERT_DEFINED(Anasazi_SOURCE_DIR)
TRIBITS_COPY_FILES_TO_BINARY_DIR(Tpetra_CopyTestRCGFiles
SOURCE_DIR ${Anasazi_SOURCE_DIR}/testmatrices
SOURCE_FILES bcsstk14.hb
EXEDEPS Tpetra_rcg_hb_test
EXEDEPS Tpetra_rcg_hb_test Tpetra_DenseKokkos_rcg_hb_test
)

ENDIF()
Expand Down
Loading

0 comments on commit e054de7

Please sign in to comment.