Skip to content

Commit

Permalink
BUG: import sometimes needed host update instead of device
Browse files Browse the repository at this point in the history
  • Loading branch information
Adrian-Diaz committed Dec 17, 2024
1 parent 21c4ffe commit bc508b4
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 4 deletions.
2 changes: 0 additions & 2 deletions examples/ann_distributed.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,6 @@ int main(int argc, char* argv[])
inputs.update_device(); // copy inputs to device
TpetraCommunicationPlan<real_t> input_comms(inputs_row, inputs);
input_comms.execute_comms(); //distribute to full map for row-vector product
inputs_row.update_device();
//inputs.print();

// for (size_t i=0; i<num_nodes_in_layer[0]; i++) {
Expand Down Expand Up @@ -353,7 +352,6 @@ int main(int argc, char* argv[])

ANNLayers(layer-1).distributed_outputs.update_host();
ANNLayers(layer-1).output_comms.execute_comms(); //distribute to full map for row-vector product
ANNLayers(layer-1).distributed_output_row.update_device();
// go through this layer, the fcn takes(inputs, outputs, weights)
forward_propagate_layer(ANNLayers(layer-1).distributed_output_row,
ANNLayers(layer).distributed_outputs,
Expand Down
35 changes: 33 additions & 2 deletions src/include/tpetra_wrapper_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,12 @@ using tpetra_global_size_t = Tpetra::global_size_t;
namespace mtr
{

//forward declarations for friendship
template <typename T, typename Layout, typename ExecSpace, typename MemoryTraits>
class TpetraCommunicationPlan;
template <typename T, typename Layout, typename ExecSpace, typename MemoryTraits>
class TpetraLRCommunicationPlan;

/////////////////////////
// TpetraPartitionMap: Container storing global indices corresponding to local indices that belong on this process/rank as well as comms related data/functions.
/////////////////////////
Expand Down Expand Up @@ -360,6 +366,7 @@ class TpetraDCArray {

public:

friend class TpetraLRCommunicationPlan<T,Layout,ExecSpace,MemoryTraits>;
//data for arrays that own both shared and local data and aren't intended to communicate with another MATAR type
//This is simplifying for cases such as a local + ghost storage vector where you need to update the ghost entries
bool own_comms; //This Mapped MPI Array contains its own communication plan; just call array_comms()
Expand Down Expand Up @@ -1582,7 +1589,7 @@ class TpetraDFArray {


public:

friend class TpetraCommunicationPlan<T,Layout,ExecSpace,MemoryTraits>;
//data for arrays that own both shared and local data and aren't intended to communicate with another MATAR type
//This is simplifying for cases such as a local + ghost storage vector where you need to update the ghost entries
bool own_comms; //This Mapped MPI Array contains its own communication plan; just call array_comms()
Expand Down Expand Up @@ -3349,10 +3356,22 @@ TpetraCommunicationPlan<T,Layout,ExecSpace,MemoryTraits>& TpetraCommunicationPla
template <typename T, typename Layout, typename ExecSpace, typename MemoryTraits>
void TpetraCommunicationPlan<T,Layout,ExecSpace,MemoryTraits>::execute_comms(){
if(reverse_comms_flag){
destination_vector_.tpetra_vector->doExport(*(source_vector_.tpetra_vector), *exporter, Tpetra::INSERT, true);
destination_vector_.tpetra_vector->doExport(*(source_vector_.tpetra_vector), *exporter, Tpetra::INSERT, true);\
if(destination_vector_.this_array_.template need_sync<typename decltype(destination_vector_)::TArray1D::execution_space>()){
destination_vector_.this_array_.template sync<typename decltype(destination_vector_)::TArray1D::execution_space>();
}
else{
destination_vector_.this_array_.template sync<typename decltype(destination_vector_)::TArray1D::host_mirror_space>();
}
}
else{
destination_vector_.tpetra_vector->doImport(*(source_vector_.tpetra_vector), *importer, Tpetra::INSERT);
if(destination_vector_.this_array_.template need_sync<typename decltype(destination_vector_)::TArray1D::execution_space>()){
destination_vector_.this_array_.template sync<typename decltype(destination_vector_)::TArray1D::execution_space>();
}
else{
destination_vector_.this_array_.template sync<typename decltype(destination_vector_)::TArray1D::host_mirror_space>();
}
}
}

Expand Down Expand Up @@ -3469,9 +3488,21 @@ template <typename T, typename Layout, typename ExecSpace, typename MemoryTraits
void TpetraLRCommunicationPlan<T,Layout,ExecSpace,MemoryTraits>::execute_comms(){
if(reverse_comms_flag){
destination_vector_.tpetra_vector->doExport(*(source_vector_.tpetra_vector), *exporter, Tpetra::INSERT, true);
if(destination_vector_.this_array_.template need_sync<typename decltype(destination_vector_)::TArray1D::execution_space>()){
destination_vector_.this_array_.template sync<typename decltype(destination_vector_)::TArray1D::execution_space>();
}
else{
destination_vector_.this_array_.template sync<typename decltype(destination_vector_)::TArray1D::host_mirror_space>();
}
}
else{
destination_vector_.tpetra_vector->doImport(*(source_vector_.tpetra_vector), *importer, Tpetra::INSERT);
if(destination_vector_.this_array_.template need_sync<typename decltype(destination_vector_)::TArray1D::execution_space>()){
destination_vector_.this_array_.template sync<typename decltype(destination_vector_)::TArray1D::execution_space>();
}
else{
destination_vector_.this_array_.template sync<typename decltype(destination_vector_)::TArray1D::host_mirror_space>();
}
}
}

Expand Down

0 comments on commit bc508b4

Please sign in to comment.