Skip to content

Commit

Permalink
Fix Tpetra vector compression state after add() into locally owned entry
Browse files Browse the repository at this point in the history
Fix Tpetra vector compression state based on the presence of non-local entries

add a test for Tpetra vector addition into both owned and nonlocal entries
  • Loading branch information
QY-Shi committed Aug 29, 2024
1 parent 6357de6 commit bbcdded
Show file tree
Hide file tree
Showing 4 changed files with 205 additions and 0 deletions.
5 changes: 5 additions & 0 deletions doc/news/changes/minor/20240828Shi
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
Fix: Add the missing `compressed = false` in TpetraWrappers::Vector::add()
function to ensure distributed Tpetra vector with writable nonlocal
entries will always be non-compressed after any entry addition.
<br>
(Qingyuan Shi, 2024/08/28)
12 changes: 12 additions & 0 deletions include/deal.II/lac/trilinos_tpetra_vector.h
Original file line number Diff line number Diff line change
Expand Up @@ -1167,6 +1167,12 @@ namespace LinearAlgebra
local_row != Teuchos::OrdinalTraits<int>::invalid())
{
vector_1d_local(local_row) += values[i];

// Set the compressed state to false only if there is nonlocal
// part in this distributed vector, otherwise it's always
// compressed.
if (nonlocal_vector.get() != nullptr)
compressed = false;
}
else
{
Expand Down Expand Up @@ -1285,6 +1291,12 @@ namespace LinearAlgebra
local_row != Teuchos::OrdinalTraits<int>::invalid())
{
vector_1d_local(local_row) = values[i];

// Set the compressed state to false only if there is nonlocal
// part in this distributed vector, otherwise it's always
// compressed.
if (nonlocal_vector.get() != nullptr)
compressed = false;
}
else
{
Expand Down
186 changes: 186 additions & 0 deletions tests/trilinos_tpetra/vector_nonlocal_addition.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
// ------------------------------------------------------------------------
//
// SPDX-License-Identifier: LGPL-2.1-or-later
// Copyright (C) 2004 - 2024 by the deal.II authors
//
// This file is part of the deal.II library.
//
// Part of the source code is dual licensed under Apache-2.0 WITH
// LLVM-exception OR LGPL-2.1-or-later. Detailed license information
// governing the source code and code contributions can be found in
// LICENSE.md and CONTRIBUTING.md at the top level directory of deal.II.
//
// ------------------------------------------------------------------------



// Test distributed vector operations across multiple processes with non-local
// entries to verify correct handling and compression states.

#include <deal.II/base/index_set.h>
#include <deal.II/base/utilities.h>

#include <deal.II/lac/trilinos_tpetra_vector.h>

#include <iostream>
#include <vector>

#include "../tests.h"


void
test()
{
unsigned int my_rank = Utilities::MPI::this_mpi_process(MPI_COMM_WORLD);
unsigned int n_procs = Utilities::MPI::n_mpi_processes(MPI_COMM_WORLD);


IndexSet locally_owned(8);
IndexSet locally_relevant(8);

// Distribution of vector entries across processes:
// Process 0 and Process 1 have non-local entries, while Process 2 does not.

// Entry Index | Process Distribution | Owned by
// ------------ | ----------------------| ----------
// 0 | 0 | 0
// 1 | 0 | 0
// 2 | 0, 1 | 0
// 3 | 0, 1 | 1
// 4 | 1 | 1
// 5 | 1, 2 | 2
// 6 | 2 | 2
// 7 | 2 | 2


if (my_rank == 0)
{
locally_owned.add_range(0, 3);
locally_relevant.add_range(0, 4);
}
if (my_rank == 1)
{
locally_owned.add_range(3, 5);
locally_relevant.add_range(2, 6);
}
if (my_rank == 2)
{
// no nonlocal entries on rank 2
locally_owned.add_range(5, 8);
locally_relevant.add_range(5, 8);
}
locally_owned.compress();
locally_relevant.compress();

// Create a vector with writable entries
LinearAlgebra::TpetraWrappers::Vector<double, MemorySpace::Default>
vector_with_nonlocal_entries;

vector_with_nonlocal_entries.reinit(locally_owned,
locally_relevant,
MPI_COMM_WORLD,
true);

vector_with_nonlocal_entries = 0.0;

if (my_rank == 0)
{
// Add to both local and non-local parts
vector_with_nonlocal_entries[0] += 1.0;
vector_with_nonlocal_entries[1] += 1.0;
vector_with_nonlocal_entries[2] += 1.0;
vector_with_nonlocal_entries[3] += 1.0;
}
if (my_rank == 1)
{
// Add to both local and non-local parts
vector_with_nonlocal_entries[2] += 2.0;
vector_with_nonlocal_entries[3] += 2.0;
vector_with_nonlocal_entries[4] += 2.0;
vector_with_nonlocal_entries[5] += 2.0;
}
if (my_rank == 2)
{
// Add to only local part, there's no nonlocal part on rank 2
vector_with_nonlocal_entries[5] += 3.0;
vector_with_nonlocal_entries[6] += 3.0;
vector_with_nonlocal_entries[7] += 3.0;
}
vector_with_nonlocal_entries.compress(VectorOperation::add);

// Expected Results:
// Entry Index | Additions | Result
// ------------ | -------------- | -------
// 0 | +1 | 1.0
// 1 | +1 | 1.0
// 2 | +1, +2 | 3.0
// 3 | +1, +2 | 3.0
// 4 | +2 | 2.0
// 5 | +2, +3 | 5.0
// 6 | +3 | 3.0
// 7 | +3 | 3.0

// Check the results on each process
if (my_rank == 0)
{
AssertThrow((vector_with_nonlocal_entries[0] == 1.0), ExcInternalError());
AssertThrow((vector_with_nonlocal_entries[1] == 1.0), ExcInternalError());
AssertThrow((vector_with_nonlocal_entries[2] == 3.0), ExcInternalError());
}
if (my_rank == 1)
{
AssertThrow((vector_with_nonlocal_entries[3] == 3.0), ExcInternalError());
AssertThrow((vector_with_nonlocal_entries[4] == 2.0), ExcInternalError());
}
if (my_rank == 2)
{
AssertThrow((vector_with_nonlocal_entries[5] == 5.0), ExcInternalError());
AssertThrow((vector_with_nonlocal_entries[6] == 3.0), ExcInternalError());
AssertThrow((vector_with_nonlocal_entries[7] == 3.0), ExcInternalError());
}

deallog << "OK" << std::endl;
}



int
main(int argc, char **argv)
{
initlog();

Utilities::MPI::MPI_InitFinalize mpi_initialization(
argc, argv, testing_max_num_threads());


try
{
test();
}
catch (const std::exception &exc)
{
std::cerr << std::endl
<< std::endl
<< "----------------------------------------------------"
<< std::endl;
std::cerr << "Exception on processing: " << std::endl
<< exc.what() << std::endl
<< "Aborting!" << std::endl
<< "----------------------------------------------------"
<< std::endl;

return 1;
}
catch (...)
{
std::cerr << std::endl
<< std::endl
<< "----------------------------------------------------"
<< std::endl;
std::cerr << "Unknown exception!" << std::endl
<< "Aborting!" << std::endl
<< "----------------------------------------------------"
<< std::endl;
return 1;
};
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@

DEAL::OK

0 comments on commit bbcdded

Please sign in to comment.