Skip to content

Commit

Permalink
ENH: Compute NeighborList Stats Updates (#1118)
Browse files Browse the repository at this point in the history
Added:
- ComputeNeighborListStatistics execution in algorithm file
- new parameter linking
- parallel progress reporting
- code clean up
- removed empty statement from CalculateArrayStatistics.hpp

Patch:
- removed empty statement from CalculateArrayStatistics.hpp
  • Loading branch information
nyoungbq authored Oct 25, 2024
1 parent 5cf455d commit 3e23b69
Show file tree
Hide file tree
Showing 5 changed files with 375 additions and 210 deletions.
1 change: 1 addition & 0 deletions src/Plugins/SimplnxCore/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ set(AlgorithmList
ComputeLargestCrossSections
ComputeMomentInvariants2D
ComputeNeighborhoods
ComputeNeighborListStatistics
ComputeSurfaceAreaToVolume
ComputeTriangleGeomCentroids
ComputeVectorColors
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@ namespace nx::core

struct SIMPLNXCORE_EXPORT ComputeArrayStatisticsInputValues
{
bool FindHistogram;
float64 MinRange;
float64 MaxRange;
bool UseFullRange;
int32 NumBins;
bool FindHistogram;
bool UseFullRange;
bool FindLength;
bool FindMin;
bool FindMax;
Expand All @@ -30,7 +30,6 @@ struct SIMPLNXCORE_EXPORT ComputeArrayStatisticsInputValues
bool FindStdDeviation;
bool FindSummation;
bool UseMask;
;
bool ComputeByIndex;
bool StandardizeData;
bool FindNumUniqueValues;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,254 @@
#include "ComputeNeighborListStatistics.hpp"

#include "simplnx/Utilities/DataArrayUtilities.hpp"
#include "simplnx/Utilities/FilterUtilities.hpp"
#include "simplnx/Utilities/Math/StatisticsCalculations.hpp"
#include "simplnx/Utilities/ParallelDataAlgorithm.hpp"

using namespace nx::core;

namespace
{
constexpr int64 k_BoolTypeNeighborList = -6802;
constexpr int64 k_EmptyNeighborList = -6803;

template <typename T>
class ComputeNeighborListStatisticsImpl
{
public:
using NeighborListType = NeighborList<T>;
using DataArrayType = DataArray<T>;
using StoreType = AbstractDataStore<T>;

ComputeNeighborListStatisticsImpl(ComputeNeighborListStatistics* filter, const INeighborList& source, bool length, bool min, bool max, bool mean, bool median, bool stdDeviation, bool summation,
std::vector<IDataArray*>& arrays, const std::atomic_bool& shouldCancel)
: m_Filter(filter)
, m_Source(source)
, m_Length(length)
, m_Min(min)
, m_Max(max)
, m_Mean(mean)
, m_Median(median)
, m_StdDeviation(stdDeviation)
, m_Summation(summation)
, m_Arrays(arrays)
, m_ShouldCancel(shouldCancel)
{
}

virtual ~ComputeNeighborListStatisticsImpl() = default;

void compute(usize start, usize end) const
{
auto* array0 = m_Length ? m_Arrays[0]->template getIDataStoreAs<AbstractDataStore<uint64>>() : nullptr;
if(m_Length && array0 == nullptr)
{
throw std::invalid_argument("ComputeNeighborListStatisticsFilter::compute() could not dynamic_cast 'Length' array to needed type. Check input array selection.");
}
auto* array1 = m_Min ? m_Arrays[1]->template getIDataStoreAs<StoreType>() : nullptr;
if(m_Min && array1 == nullptr)
{
throw std::invalid_argument("ComputeNeighborListStatisticsFilter::compute() could not dynamic_cast 'Min' array to needed type. Check input array selection.");
}
auto* array2 = m_Max ? m_Arrays[2]->template getIDataStoreAs<StoreType>() : nullptr;
if(m_Max && array2 == nullptr)
{
throw std::invalid_argument("ComputeNeighborListStatisticsFilter::compute() could not dynamic_cast 'Max' array to needed type. Check input array selection.");
}
auto* array3 = m_Mean ? m_Arrays[3]->template getIDataStoreAs<AbstractDataStore<float32>>() : nullptr;
if(m_Mean && array3 == nullptr)
{
throw std::invalid_argument("ComputeNeighborListStatisticsFilter::compute() could not dynamic_cast 'Mean' array to needed type. Check input array selection.");
}
auto* array4 = m_Median ? m_Arrays[4]->template getIDataStoreAs<AbstractDataStore<float32>>() : nullptr;
if(m_Median && array4 == nullptr)
{
throw std::invalid_argument("ComputeNeighborListStatisticsFilter::compute() could not dynamic_cast 'Median' array to needed type. Check input array selection.");
}
auto* array5 = m_StdDeviation ? m_Arrays[5]->template getIDataStoreAs<AbstractDataStore<float32>>() : nullptr;
if(m_StdDeviation && array5 == nullptr)
{
throw std::invalid_argument("ComputeNeighborListStatisticsFilter::compute() could not dynamic_cast 'StdDev' array to needed type. Check input array selection.");
}
auto* array6 = m_Summation ? m_Arrays[6]->template getIDataStoreAs<AbstractDataStore<float32>>() : nullptr;
if(m_Summation && array6 == nullptr)
{
throw std::invalid_argument("ComputeNeighborListStatisticsFilter::compute() could not dynamic_cast 'Summation' array to needed type. Check input array selection.");
}

const auto& sourceList = dynamic_cast<const NeighborListType&>(m_Source);

auto tStart = std::chrono::steady_clock::now();
usize counter = 0;
for(usize i = start; i < end; i++)
{
if(m_ShouldCancel)
{
return;
}

const std::vector<T>& tmpList = sourceList.at(i);

if(m_Length)
{
auto val = static_cast<int64_t>(tmpList.size());
array0->setValue(i, val);
}
if(m_Min)
{
T val = StatisticsCalculations::findMin(tmpList);
array1->setValue(i, val);
}
if(m_Max)
{
T val = StatisticsCalculations::findMax(tmpList);
array2->setValue(i, val);
}
if(m_Mean)
{
float val = StatisticsCalculations::findMean(tmpList);
array3->setValue(i, val);
}
if(m_Median)
{
float val = StatisticsCalculations::findMedian(tmpList);
array4->setValue(i, val);
}
if(m_StdDeviation)
{
float val = StatisticsCalculations::findStdDeviation(tmpList);
array5->setValue(i, val);
}
if(m_Summation)
{
float val = StatisticsCalculations::findSummation(tmpList);
array6->setValue(i, val);
}

counter++;
if(std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::steady_clock::now() - tStart).count() > 1000)
{
m_Filter->sendThreadSafeProgressMessage(counter);
counter = 0;
}
}
m_Filter->sendThreadSafeProgressMessage(counter);
}

void operator()(const Range& range) const
{
compute(range.min(), range.max());
}

private:
ComputeNeighborListStatistics* m_Filter = nullptr;
const std::atomic_bool& m_ShouldCancel;

const INeighborList& m_Source;
bool m_Length = false;
bool m_Min = false;
bool m_Max = false;
bool m_Mean = false;
bool m_Median = false;
bool m_StdDeviation = false;
bool m_Summation = false;

std::vector<IDataArray*>& m_Arrays;
};
} // namespace

// -----------------------------------------------------------------------------
ComputeNeighborListStatistics::ComputeNeighborListStatistics(DataStructure& dataStructure, const IFilter::MessageHandler& msgHandler, const std::atomic_bool& shouldCancel,
ComputeNeighborListStatisticsInputValues* inputValues)
: m_DataStructure(dataStructure)
, m_InputValues(inputValues)
, m_ShouldCancel(shouldCancel)
, m_MessageHandler(msgHandler)
{
}

// -----------------------------------------------------------------------------
ComputeNeighborListStatistics::~ComputeNeighborListStatistics() noexcept = default;

// -----------------------------------------------------------------------------
Result<> ComputeNeighborListStatistics::operator()()
{
const auto& inputINeighborList = m_DataStructure.getDataRefAs<INeighborList>(m_InputValues->TargetNeighborListPath);

DataType type = inputINeighborList.getDataType();
if(type == DataType::boolean)
{
return MakeErrorResult(k_BoolTypeNeighborList, fmt::format("ComputeNeighborListStatisticsFilter::NeighborList {} was of type boolean, and thus cannot be processed", inputINeighborList.getName()));
}

usize numTuples = inputINeighborList.getNumberOfTuples();
if(numTuples == 0)
{
return MakeErrorResult(k_EmptyNeighborList, fmt::format("ComputeNeighborListStatisticsFilter::NeighborList {} was empty", inputINeighborList.getName()));
}

std::vector<IDataArray*> arrays(7, nullptr);

if(m_InputValues->FindLength)
{
arrays[0] = m_DataStructure.getDataAs<IDataArray>(m_InputValues->LengthPath);
}
if(m_InputValues->FindMin)
{
arrays[1] = m_DataStructure.getDataAs<IDataArray>(m_InputValues->MinPath);
}
if(m_InputValues->FindMax)
{
arrays[2] = m_DataStructure.getDataAs<IDataArray>(m_InputValues->MaxPath);
}
if(m_InputValues->FindMean)
{
arrays[3] = m_DataStructure.getDataAs<IDataArray>(m_InputValues->MeanPath);
}
if(m_InputValues->FindMedian)
{
arrays[4] = m_DataStructure.getDataAs<IDataArray>(m_InputValues->MedianPath);
}
if(m_InputValues->FindStdDeviation)
{
arrays[5] = m_DataStructure.getDataAs<IDataArray>(m_InputValues->StdDeviationPath);
}
if(m_InputValues->FindSummation)
{
arrays[6] = m_DataStructure.getDataAs<IDataArray>(m_InputValues->SummationPath);
}

// Fill progress counters for parallel updates
m_ProgressCounter = 0;
m_TotalElements = numTuples;

// Allow data-based parallelization
ParallelDataAlgorithm dataAlg;
dataAlg.setRange(0, numTuples);
ExecuteParallelFunction<ComputeNeighborListStatisticsImpl, NoBooleanType>(type, dataAlg, this, inputINeighborList, m_InputValues->FindLength, m_InputValues->FindMin, m_InputValues->FindMax,
m_InputValues->FindMean, m_InputValues->FindMedian, m_InputValues->FindStdDeviation, m_InputValues->FindSummation, arrays,
m_ShouldCancel);

return {};
}

// -----------------------------------------------------------------------------
const std::atomic_bool& ComputeNeighborListStatistics::getCancel()
{
return m_ShouldCancel;
}

// -----------------------------------------------------------------------------
void ComputeNeighborListStatistics::sendThreadSafeProgressMessage(usize counter)
{
const std::lock_guard<std::mutex> guard(m_ProgressMessage_Mutex);

m_ProgressCounter += counter;

if(std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::steady_clock::now() - m_InitialTime).count() > 1000)
{
auto progressInt = static_cast<usize>((static_cast<float64>(m_ProgressCounter) / static_cast<float64>(m_TotalElements)) * 100.0);
m_MessageHandler(IFilter::Message::Type::Info, fmt::format("Finding Statistics || {}% Completed", progressInt));
m_InitialTime = std::chrono::steady_clock::now();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
#pragma once

#include "SimplnxCore/SimplnxCore_export.hpp"

#include "simplnx/DataStructure/DataArray.hpp"
#include "simplnx/DataStructure/DataPath.hpp"
#include "simplnx/DataStructure/DataStructure.hpp"
#include "simplnx/Filter/IFilter.hpp"

#include <chrono>
#include <mutex>

namespace nx::core
{

struct SIMPLNXCORE_EXPORT ComputeNeighborListStatisticsInputValues
{
bool FindLength = false;
bool FindMin = false;
bool FindMax = false;
bool FindMean = false;
bool FindMedian = false;
bool FindStdDeviation = false;
bool FindSummation = false;

DataPath TargetNeighborListPath = {};

DataPath LengthPath = {};
DataPath MinPath = {};
DataPath MaxPath = {};
DataPath MeanPath = {};
DataPath MedianPath = {};
DataPath StdDeviationPath = {};
DataPath SummationPath = {};
};

/**
* @class
*/
class SIMPLNXCORE_EXPORT ComputeNeighborListStatistics
{
public:
ComputeNeighborListStatistics(DataStructure& dataStructure, const IFilter::MessageHandler& msgHandler, const std::atomic_bool& shouldCancel, ComputeNeighborListStatisticsInputValues* inputValues);
~ComputeNeighborListStatistics() noexcept;

ComputeNeighborListStatistics(const ComputeNeighborListStatistics&) = delete;
ComputeNeighborListStatistics(ComputeNeighborListStatistics&&) noexcept = delete;
ComputeNeighborListStatistics& operator=(const ComputeNeighborListStatistics&) = delete;
ComputeNeighborListStatistics& operator=(ComputeNeighborListStatistics&&) noexcept = delete;

Result<> operator()();

const std::atomic_bool& getCancel();

void sendThreadSafeProgressMessage(usize counter);

private:
DataStructure& m_DataStructure;
const ComputeNeighborListStatisticsInputValues* m_InputValues = nullptr;
const std::atomic_bool& m_ShouldCancel;
const IFilter::MessageHandler& m_MessageHandler;

// Thread safe Progress Message
mutable std::mutex m_ProgressMessage_Mutex;
size_t m_TotalElements = 0;
size_t m_ProgressCounter = 0;
std::chrono::steady_clock::time_point m_InitialTime = std::chrono::steady_clock::now();
};

} // namespace nx::core
Loading

0 comments on commit 3e23b69

Please sign in to comment.