From 3e23b694fdd123903e85940dea8bff32d604d607 Mon Sep 17 00:00:00 2001 From: Nathan Young Date: Fri, 25 Oct 2024 14:28:03 -0400 Subject: [PATCH] ENH: Compute NeighborList Stats Updates (#1118) Added: - ComputeNeighborListStatistics execution in algorithm file - new parameter linking - parallel progress reporting - code clean up - removed empty statement from CalculateArrayStatistics.hpp Patch: - removed empty statement from CalculateArrayStatistics.hpp --- src/Plugins/SimplnxCore/CMakeLists.txt | 1 + .../Algorithms/ComputeArrayStatistics.hpp | 5 +- .../ComputeNeighborListStatistics.cpp | 254 +++++++++++++++++ .../ComputeNeighborListStatistics.hpp | 70 +++++ .../ComputeNeighborListStatisticsFilter.cpp | 255 ++++-------------- 5 files changed, 375 insertions(+), 210 deletions(-) create mode 100644 src/Plugins/SimplnxCore/src/SimplnxCore/Filters/Algorithms/ComputeNeighborListStatistics.cpp create mode 100644 src/Plugins/SimplnxCore/src/SimplnxCore/Filters/Algorithms/ComputeNeighborListStatistics.hpp diff --git a/src/Plugins/SimplnxCore/CMakeLists.txt b/src/Plugins/SimplnxCore/CMakeLists.txt index 8874f2eb02..894f2cc119 100644 --- a/src/Plugins/SimplnxCore/CMakeLists.txt +++ b/src/Plugins/SimplnxCore/CMakeLists.txt @@ -166,6 +166,7 @@ set(AlgorithmList ComputeLargestCrossSections ComputeMomentInvariants2D ComputeNeighborhoods + ComputeNeighborListStatistics ComputeSurfaceAreaToVolume ComputeTriangleGeomCentroids ComputeVectorColors diff --git a/src/Plugins/SimplnxCore/src/SimplnxCore/Filters/Algorithms/ComputeArrayStatistics.hpp b/src/Plugins/SimplnxCore/src/SimplnxCore/Filters/Algorithms/ComputeArrayStatistics.hpp index 3802632fe7..6f0119a1f6 100644 --- a/src/Plugins/SimplnxCore/src/SimplnxCore/Filters/Algorithms/ComputeArrayStatistics.hpp +++ b/src/Plugins/SimplnxCore/src/SimplnxCore/Filters/Algorithms/ComputeArrayStatistics.hpp @@ -15,11 +15,11 @@ namespace nx::core struct SIMPLNXCORE_EXPORT ComputeArrayStatisticsInputValues { - bool FindHistogram; float64 MinRange; float64 MaxRange; - bool UseFullRange; int32 NumBins; + bool FindHistogram; + bool UseFullRange; bool FindLength; bool FindMin; bool FindMax; @@ -30,7 +30,6 @@ struct SIMPLNXCORE_EXPORT ComputeArrayStatisticsInputValues bool FindStdDeviation; bool FindSummation; bool UseMask; - ; bool ComputeByIndex; bool StandardizeData; bool FindNumUniqueValues; diff --git a/src/Plugins/SimplnxCore/src/SimplnxCore/Filters/Algorithms/ComputeNeighborListStatistics.cpp b/src/Plugins/SimplnxCore/src/SimplnxCore/Filters/Algorithms/ComputeNeighborListStatistics.cpp new file mode 100644 index 0000000000..38ee20dc53 --- /dev/null +++ b/src/Plugins/SimplnxCore/src/SimplnxCore/Filters/Algorithms/ComputeNeighborListStatistics.cpp @@ -0,0 +1,254 @@ +#include "ComputeNeighborListStatistics.hpp" + +#include "simplnx/Utilities/DataArrayUtilities.hpp" +#include "simplnx/Utilities/FilterUtilities.hpp" +#include "simplnx/Utilities/Math/StatisticsCalculations.hpp" +#include "simplnx/Utilities/ParallelDataAlgorithm.hpp" + +using namespace nx::core; + +namespace +{ +constexpr int64 k_BoolTypeNeighborList = -6802; +constexpr int64 k_EmptyNeighborList = -6803; + +template +class ComputeNeighborListStatisticsImpl +{ +public: + using NeighborListType = NeighborList; + using DataArrayType = DataArray; + using StoreType = AbstractDataStore; + + ComputeNeighborListStatisticsImpl(ComputeNeighborListStatistics* filter, const INeighborList& source, bool length, bool min, bool max, bool mean, bool median, bool stdDeviation, bool summation, + std::vector& arrays, const std::atomic_bool& shouldCancel) + : m_Filter(filter) + , m_Source(source) + , m_Length(length) + , m_Min(min) + , m_Max(max) + , m_Mean(mean) + , m_Median(median) + , m_StdDeviation(stdDeviation) + , m_Summation(summation) + , m_Arrays(arrays) + , m_ShouldCancel(shouldCancel) + { + } + + virtual ~ComputeNeighborListStatisticsImpl() = default; + + void compute(usize start, usize end) const + { + auto* array0 = m_Length ? m_Arrays[0]->template getIDataStoreAs>() : nullptr; + if(m_Length && array0 == nullptr) + { + throw std::invalid_argument("ComputeNeighborListStatisticsFilter::compute() could not dynamic_cast 'Length' array to needed type. Check input array selection."); + } + auto* array1 = m_Min ? m_Arrays[1]->template getIDataStoreAs() : nullptr; + if(m_Min && array1 == nullptr) + { + throw std::invalid_argument("ComputeNeighborListStatisticsFilter::compute() could not dynamic_cast 'Min' array to needed type. Check input array selection."); + } + auto* array2 = m_Max ? m_Arrays[2]->template getIDataStoreAs() : nullptr; + if(m_Max && array2 == nullptr) + { + throw std::invalid_argument("ComputeNeighborListStatisticsFilter::compute() could not dynamic_cast 'Max' array to needed type. Check input array selection."); + } + auto* array3 = m_Mean ? m_Arrays[3]->template getIDataStoreAs>() : nullptr; + if(m_Mean && array3 == nullptr) + { + throw std::invalid_argument("ComputeNeighborListStatisticsFilter::compute() could not dynamic_cast 'Mean' array to needed type. Check input array selection."); + } + auto* array4 = m_Median ? m_Arrays[4]->template getIDataStoreAs>() : nullptr; + if(m_Median && array4 == nullptr) + { + throw std::invalid_argument("ComputeNeighborListStatisticsFilter::compute() could not dynamic_cast 'Median' array to needed type. Check input array selection."); + } + auto* array5 = m_StdDeviation ? m_Arrays[5]->template getIDataStoreAs>() : nullptr; + if(m_StdDeviation && array5 == nullptr) + { + throw std::invalid_argument("ComputeNeighborListStatisticsFilter::compute() could not dynamic_cast 'StdDev' array to needed type. Check input array selection."); + } + auto* array6 = m_Summation ? m_Arrays[6]->template getIDataStoreAs>() : nullptr; + if(m_Summation && array6 == nullptr) + { + throw std::invalid_argument("ComputeNeighborListStatisticsFilter::compute() could not dynamic_cast 'Summation' array to needed type. Check input array selection."); + } + + const auto& sourceList = dynamic_cast(m_Source); + + auto tStart = std::chrono::steady_clock::now(); + usize counter = 0; + for(usize i = start; i < end; i++) + { + if(m_ShouldCancel) + { + return; + } + + const std::vector& tmpList = sourceList.at(i); + + if(m_Length) + { + auto val = static_cast(tmpList.size()); + array0->setValue(i, val); + } + if(m_Min) + { + T val = StatisticsCalculations::findMin(tmpList); + array1->setValue(i, val); + } + if(m_Max) + { + T val = StatisticsCalculations::findMax(tmpList); + array2->setValue(i, val); + } + if(m_Mean) + { + float val = StatisticsCalculations::findMean(tmpList); + array3->setValue(i, val); + } + if(m_Median) + { + float val = StatisticsCalculations::findMedian(tmpList); + array4->setValue(i, val); + } + if(m_StdDeviation) + { + float val = StatisticsCalculations::findStdDeviation(tmpList); + array5->setValue(i, val); + } + if(m_Summation) + { + float val = StatisticsCalculations::findSummation(tmpList); + array6->setValue(i, val); + } + + counter++; + if(std::chrono::duration_cast(std::chrono::steady_clock::now() - tStart).count() > 1000) + { + m_Filter->sendThreadSafeProgressMessage(counter); + counter = 0; + } + } + m_Filter->sendThreadSafeProgressMessage(counter); + } + + void operator()(const Range& range) const + { + compute(range.min(), range.max()); + } + +private: + ComputeNeighborListStatistics* m_Filter = nullptr; + const std::atomic_bool& m_ShouldCancel; + + const INeighborList& m_Source; + bool m_Length = false; + bool m_Min = false; + bool m_Max = false; + bool m_Mean = false; + bool m_Median = false; + bool m_StdDeviation = false; + bool m_Summation = false; + + std::vector& m_Arrays; +}; +} // namespace + +// ----------------------------------------------------------------------------- +ComputeNeighborListStatistics::ComputeNeighborListStatistics(DataStructure& dataStructure, const IFilter::MessageHandler& msgHandler, const std::atomic_bool& shouldCancel, + ComputeNeighborListStatisticsInputValues* inputValues) +: m_DataStructure(dataStructure) +, m_InputValues(inputValues) +, m_ShouldCancel(shouldCancel) +, m_MessageHandler(msgHandler) +{ +} + +// ----------------------------------------------------------------------------- +ComputeNeighborListStatistics::~ComputeNeighborListStatistics() noexcept = default; + +// ----------------------------------------------------------------------------- +Result<> ComputeNeighborListStatistics::operator()() +{ + const auto& inputINeighborList = m_DataStructure.getDataRefAs(m_InputValues->TargetNeighborListPath); + + DataType type = inputINeighborList.getDataType(); + if(type == DataType::boolean) + { + return MakeErrorResult(k_BoolTypeNeighborList, fmt::format("ComputeNeighborListStatisticsFilter::NeighborList {} was of type boolean, and thus cannot be processed", inputINeighborList.getName())); + } + + usize numTuples = inputINeighborList.getNumberOfTuples(); + if(numTuples == 0) + { + return MakeErrorResult(k_EmptyNeighborList, fmt::format("ComputeNeighborListStatisticsFilter::NeighborList {} was empty", inputINeighborList.getName())); + } + + std::vector arrays(7, nullptr); + + if(m_InputValues->FindLength) + { + arrays[0] = m_DataStructure.getDataAs(m_InputValues->LengthPath); + } + if(m_InputValues->FindMin) + { + arrays[1] = m_DataStructure.getDataAs(m_InputValues->MinPath); + } + if(m_InputValues->FindMax) + { + arrays[2] = m_DataStructure.getDataAs(m_InputValues->MaxPath); + } + if(m_InputValues->FindMean) + { + arrays[3] = m_DataStructure.getDataAs(m_InputValues->MeanPath); + } + if(m_InputValues->FindMedian) + { + arrays[4] = m_DataStructure.getDataAs(m_InputValues->MedianPath); + } + if(m_InputValues->FindStdDeviation) + { + arrays[5] = m_DataStructure.getDataAs(m_InputValues->StdDeviationPath); + } + if(m_InputValues->FindSummation) + { + arrays[6] = m_DataStructure.getDataAs(m_InputValues->SummationPath); + } + + // Fill progress counters for parallel updates + m_ProgressCounter = 0; + m_TotalElements = numTuples; + + // Allow data-based parallelization + ParallelDataAlgorithm dataAlg; + dataAlg.setRange(0, numTuples); + ExecuteParallelFunction(type, dataAlg, this, inputINeighborList, m_InputValues->FindLength, m_InputValues->FindMin, m_InputValues->FindMax, + m_InputValues->FindMean, m_InputValues->FindMedian, m_InputValues->FindStdDeviation, m_InputValues->FindSummation, arrays, + m_ShouldCancel); + + return {}; +} + +// ----------------------------------------------------------------------------- +const std::atomic_bool& ComputeNeighborListStatistics::getCancel() +{ + return m_ShouldCancel; +} + +// ----------------------------------------------------------------------------- +void ComputeNeighborListStatistics::sendThreadSafeProgressMessage(usize counter) +{ + const std::lock_guard guard(m_ProgressMessage_Mutex); + + m_ProgressCounter += counter; + + if(std::chrono::duration_cast(std::chrono::steady_clock::now() - m_InitialTime).count() > 1000) + { + auto progressInt = static_cast((static_cast(m_ProgressCounter) / static_cast(m_TotalElements)) * 100.0); + m_MessageHandler(IFilter::Message::Type::Info, fmt::format("Finding Statistics || {}% Completed", progressInt)); + m_InitialTime = std::chrono::steady_clock::now(); + } +} diff --git a/src/Plugins/SimplnxCore/src/SimplnxCore/Filters/Algorithms/ComputeNeighborListStatistics.hpp b/src/Plugins/SimplnxCore/src/SimplnxCore/Filters/Algorithms/ComputeNeighborListStatistics.hpp new file mode 100644 index 0000000000..8b3e5a73c0 --- /dev/null +++ b/src/Plugins/SimplnxCore/src/SimplnxCore/Filters/Algorithms/ComputeNeighborListStatistics.hpp @@ -0,0 +1,70 @@ +#pragma once + +#include "SimplnxCore/SimplnxCore_export.hpp" + +#include "simplnx/DataStructure/DataArray.hpp" +#include "simplnx/DataStructure/DataPath.hpp" +#include "simplnx/DataStructure/DataStructure.hpp" +#include "simplnx/Filter/IFilter.hpp" + +#include +#include + +namespace nx::core +{ + +struct SIMPLNXCORE_EXPORT ComputeNeighborListStatisticsInputValues +{ + bool FindLength = false; + bool FindMin = false; + bool FindMax = false; + bool FindMean = false; + bool FindMedian = false; + bool FindStdDeviation = false; + bool FindSummation = false; + + DataPath TargetNeighborListPath = {}; + + DataPath LengthPath = {}; + DataPath MinPath = {}; + DataPath MaxPath = {}; + DataPath MeanPath = {}; + DataPath MedianPath = {}; + DataPath StdDeviationPath = {}; + DataPath SummationPath = {}; +}; + +/** + * @class + */ +class SIMPLNXCORE_EXPORT ComputeNeighborListStatistics +{ +public: + ComputeNeighborListStatistics(DataStructure& dataStructure, const IFilter::MessageHandler& msgHandler, const std::atomic_bool& shouldCancel, ComputeNeighborListStatisticsInputValues* inputValues); + ~ComputeNeighborListStatistics() noexcept; + + ComputeNeighborListStatistics(const ComputeNeighborListStatistics&) = delete; + ComputeNeighborListStatistics(ComputeNeighborListStatistics&&) noexcept = delete; + ComputeNeighborListStatistics& operator=(const ComputeNeighborListStatistics&) = delete; + ComputeNeighborListStatistics& operator=(ComputeNeighborListStatistics&&) noexcept = delete; + + Result<> operator()(); + + const std::atomic_bool& getCancel(); + + void sendThreadSafeProgressMessage(usize counter); + +private: + DataStructure& m_DataStructure; + const ComputeNeighborListStatisticsInputValues* m_InputValues = nullptr; + const std::atomic_bool& m_ShouldCancel; + const IFilter::MessageHandler& m_MessageHandler; + + // Thread safe Progress Message + mutable std::mutex m_ProgressMessage_Mutex; + size_t m_TotalElements = 0; + size_t m_ProgressCounter = 0; + std::chrono::steady_clock::time_point m_InitialTime = std::chrono::steady_clock::now(); +}; + +} // namespace nx::core diff --git a/src/Plugins/SimplnxCore/src/SimplnxCore/Filters/ComputeNeighborListStatisticsFilter.cpp b/src/Plugins/SimplnxCore/src/SimplnxCore/Filters/ComputeNeighborListStatisticsFilter.cpp index 67ca3ab552..f802c6d254 100644 --- a/src/Plugins/SimplnxCore/src/SimplnxCore/Filters/ComputeNeighborListStatisticsFilter.cpp +++ b/src/Plugins/SimplnxCore/src/SimplnxCore/Filters/ComputeNeighborListStatisticsFilter.cpp @@ -1,6 +1,7 @@ #include "ComputeNeighborListStatisticsFilter.hpp" -#include "simplnx/DataStructure/DataArray.hpp" +#include "SimplnxCore/Filters/Algorithms/ComputeNeighborListStatistics.hpp" + #include "simplnx/DataStructure/INeighborList.hpp" #include "simplnx/DataStructure/NeighborList.hpp" #include "simplnx/Filter/Actions/CreateArrayAction.hpp" @@ -8,153 +9,14 @@ #include "simplnx/Parameters/BoolParameter.hpp" #include "simplnx/Parameters/DataObjectNameParameter.hpp" #include "simplnx/Parameters/NeighborListSelectionParameter.hpp" -#include "simplnx/Utilities/Math/StatisticsCalculations.hpp" -#include "simplnx/Utilities/ParallelAlgorithmUtilities.hpp" - #include "simplnx/Utilities/SIMPLConversion.hpp" -#include "simplnx/Utilities/ParallelDataAlgorithm.hpp" - namespace nx::core { namespace { constexpr int64 k_NoAction = -6800; constexpr int64 k_MissingInputArray = -6801; -constexpr int64 k_BoolTypeNeighborList = -6802; -constexpr int64 k_EmptyNeighborList = -6803; - -template -class ComputeNeighborListStatisticsImpl -{ -public: - using NeighborListType = NeighborList; - - ComputeNeighborListStatisticsImpl(const IFilter* filter, INeighborList& source, bool length, bool min, bool max, bool mean, bool median, bool stdDeviation, bool summation, - std::vector& arrays) - : m_Filter(filter) - , m_Source(source) - , m_Length(length) - , m_Min(min) - , m_Max(max) - , m_Mean(mean) - , m_Median(median) - , m_StdDeviation(stdDeviation) - , m_Summation(summation) - , m_Arrays(arrays) - { - } - - virtual ~ComputeNeighborListStatisticsImpl() = default; - - void compute(usize start, usize end) const - { - if constexpr(std::is_same_v) - { - return; - } - - using DataArrayType = DataArray; - using StoreType = AbstractDataStore; - - auto* array0 = m_Length ? m_Arrays[0]->template getIDataStoreAs>() : nullptr; - if(m_Length && array0 == nullptr) - { - throw std::invalid_argument("ComputeNeighborListStatisticsFilter::compute() could not dynamic_cast 'Length' array to needed type. Check input array selection."); - } - auto* array1 = m_Min ? m_Arrays[1]->template getIDataStoreAs() : nullptr; - if(m_Min && array1 == nullptr) - { - throw std::invalid_argument("ComputeNeighborListStatisticsFilter::compute() could not dynamic_cast 'Min' array to needed type. Check input array selection."); - } - auto* array2 = m_Max ? m_Arrays[2]->template getIDataStoreAs() : nullptr; - if(m_Max && array2 == nullptr) - { - throw std::invalid_argument("ComputeNeighborListStatisticsFilter::compute() could not dynamic_cast 'Max' array to needed type. Check input array selection."); - } - auto* array3 = m_Mean ? m_Arrays[3]->template getIDataStoreAs>() : nullptr; - if(m_Mean && array3 == nullptr) - { - throw std::invalid_argument("ComputeNeighborListStatisticsFilter::compute() could not dynamic_cast 'Mean' array to needed type. Check input array selection."); - } - auto* array4 = m_Median ? m_Arrays[4]->template getIDataStoreAs>() : nullptr; - if(m_Median && array4 == nullptr) - { - throw std::invalid_argument("ComputeNeighborListStatisticsFilter::compute() could not dynamic_cast 'Median' array to needed type. Check input array selection."); - } - auto* array5 = m_StdDeviation ? m_Arrays[5]->template getIDataStoreAs>() : nullptr; - if(m_StdDeviation && array5 == nullptr) - { - throw std::invalid_argument("ComputeNeighborListStatisticsFilter::compute() could not dynamic_cast 'StdDev' array to needed type. Check input array selection."); - } - auto* array6 = m_Summation ? m_Arrays[6]->template getIDataStoreAs>() : nullptr; - if(m_Summation && array6 == nullptr) - { - throw std::invalid_argument("ComputeNeighborListStatisticsFilter::compute() could not dynamic_cast 'Summation' array to needed type. Check input array selection."); - } - - auto& sourceList = dynamic_cast(m_Source); - - for(usize i = start; i < end; i++) - { - const std::vector& tmpList = sourceList[i]; - - if(m_Length) - { - auto val = static_cast(tmpList.size()); - array0->setValue(i, val); - } - if(m_Min) - { - T val = StatisticsCalculations::findMin(tmpList); - array1->setValue(i, val); - } - if(m_Max) - { - T val = StatisticsCalculations::findMax(tmpList); - array2->setValue(i, val); - } - if(m_Mean) - { - float val = StatisticsCalculations::findMean(tmpList); - array3->setValue(i, val); - } - if(m_Median) - { - float val = StatisticsCalculations::findMedian(tmpList); - array4->setValue(i, val); - } - if(m_StdDeviation) - { - float val = StatisticsCalculations::findStdDeviation(tmpList); - array5->setValue(i, val); - } - if(m_Summation) - { - float val = StatisticsCalculations::findSummation(tmpList); - array6->setValue(i, val); - } - } - } - - void operator()(const Range& range) const - { - compute(range.min(), range.max()); - } - -private: - const IFilter* m_Filter = nullptr; - INeighborList& m_Source; - bool m_Length = false; - bool m_Min = false; - bool m_Max = false; - bool m_Mean = false; - bool m_Median = false; - bool m_StdDeviation = false; - bool m_Summation = false; - - std::vector& m_Arrays; -}; //------------------------------------------------------------------------------ OutputActions CreateCompatibleArrays(const DataStructure& dataStructure, const Arguments& args) @@ -257,14 +119,14 @@ Parameters ComputeNeighborListStatisticsFilter::parameters() const Parameters params; params.insertSeparator(Parameters::Separator{"Input Parameter(s)"}); - params.insert(std::make_unique(k_FindLength_Key, "Find Length", "Specifies whether or not the filter creates the Length array during calculations", true)); - params.insert(std::make_unique(k_FindMinimum_Key, "Find Minimum", "Specifies whether or not the filter creates the Minimum array during calculations", true)); - params.insert(std::make_unique(k_FindMaximum_Key, "Find Maximum", "Specifies whether or not the filter creates the Maximum array during calculations", true)); - params.insert(std::make_unique(k_FindMean_Key, "Find Mean", "Specifies whether or not the filter creates the Mean array during calculations", true)); - params.insert(std::make_unique(k_FindMedian_Key, "Find Median", "Specifies whether or not the filter creates the Median array during calculations", true)); - params.insert( + params.insertLinkableParameter(std::make_unique(k_FindLength_Key, "Find Length", "Specifies whether or not the filter creates the Length array during calculations", true)); + params.insertLinkableParameter(std::make_unique(k_FindMinimum_Key, "Find Minimum", "Specifies whether or not the filter creates the Minimum array during calculations", true)); + params.insertLinkableParameter(std::make_unique(k_FindMaximum_Key, "Find Maximum", "Specifies whether or not the filter creates the Maximum array during calculations", true)); + params.insertLinkableParameter(std::make_unique(k_FindMean_Key, "Find Mean", "Specifies whether or not the filter creates the Mean array during calculations", true)); + params.insertLinkableParameter(std::make_unique(k_FindMedian_Key, "Find Median", "Specifies whether or not the filter creates the Median array during calculations", true)); + params.insertLinkableParameter( std::make_unique(k_FindStandardDeviation_Key, "Find Standard Deviation", "Specifies whether or not the filter creates the Standard Deviation array during calculations", true)); - params.insert(std::make_unique(k_FindSummation_Key, "Find Summation", "Specifies whether or not the filter creates the Summation array during calculations", true)); + params.insertLinkableParameter(std::make_unique(k_FindSummation_Key, "Find Summation", "Specifies whether or not the filter creates the Summation array during calculations", true)); params.insertSeparator(Parameters::Separator{"Input Data Objects"}); params.insert(std::make_unique(k_InputNeighborListPath_Key, "NeighborList to Compute Statistics", "Input Data Array to compute statistics", DataPath(), @@ -278,6 +140,16 @@ Parameters ComputeNeighborListStatisticsFilter::parameters() const params.insert(std::make_unique(k_MedianName_Key, "Median", "Path to create the Median array during calculations", "Median")); params.insert(std::make_unique(k_StandardDeviationName_Key, "Standard Deviation", "Path to create the Standard Deviation array during calculations", "StandardDeviation")); params.insert(std::make_unique(k_SummationName_Key, "Summation", "Path to create the Summation array during calculations", "Summation")); + + // Associate the Linkable Parameter(s) to the children parameters that they control + params.linkParameters(k_FindLength_Key, k_LengthName_Key, true); + params.linkParameters(k_FindMinimum_Key, k_MinimumName_Key, true); + params.linkParameters(k_FindMaximum_Key, k_MaximumName_Key, true); + params.linkParameters(k_FindMean_Key, k_MeanName_Key, true); + params.linkParameters(k_FindMedian_Key, k_MedianName_Key, true); + params.linkParameters(k_FindStandardDeviation_Key, k_StandardDeviationName_Key, true); + params.linkParameters(k_FindSummation_Key, k_SummationName_Key, true); + return params; } @@ -307,11 +179,9 @@ IFilter::PreflightResult ComputeNeighborListStatisticsFilter::preflightImpl(cons auto inputArrayPath = args.value(k_InputNeighborListPath_Key); - // if(!findMin && !findMax && !findMean && !findMedian && !findStdDeviation && !findSummation && !findLength) { - std::string ss = fmt::format("No statistics have been selected"); - return {nonstd::make_unexpected(std::vector{Error{k_NoAction, ss}})}; + return MakePreflightErrorResult(k_NoAction, "No statistics have been selected"); } std::vector dataArrayPaths; @@ -319,8 +189,7 @@ IFilter::PreflightResult ComputeNeighborListStatisticsFilter::preflightImpl(cons auto inputArray = dataStructure.getDataAs(inputArrayPath); if(inputArray == nullptr) { - std::string ss = fmt::format("Missing input array"); - return {nonstd::make_unexpected(std::vector{Error{k_MissingInputArray, ss}})}; + return MakePreflightErrorResult(k_MissingInputArray, "Missing input array"); } dataArrayPaths.push_back(inputArrayPath); @@ -332,82 +201,54 @@ IFilter::PreflightResult ComputeNeighborListStatisticsFilter::preflightImpl(cons Result<> ComputeNeighborListStatisticsFilter::executeImpl(DataStructure& dataStructure, const Arguments& args, const PipelineFilter* pipelineNode, const MessageHandler& messageHandler, const std::atomic_bool& shouldCancel) const { - auto findLength = args.value(k_FindLength_Key); - auto findMin = args.value(k_FindMinimum_Key); - auto findMax = args.value(k_FindMaximum_Key); - auto findMean = args.value(k_FindMean_Key); - auto findMedian = args.value(k_FindMedian_Key); - auto findStdDeviation = args.value(k_FindStandardDeviation_Key); - auto findSummation = args.value(k_FindSummation_Key); + ComputeNeighborListStatisticsInputValues inputValues; - if(!findMin && !findMax && !findMean && !findMedian && !findStdDeviation && findSummation && !findLength) + inputValues.FindLength = args.value(k_FindLength_Key); + inputValues.FindMin = args.value(k_FindMinimum_Key); + inputValues.FindMax = args.value(k_FindMaximum_Key); + inputValues.FindMean = args.value(k_FindMean_Key); + inputValues.FindMedian = args.value(k_FindMedian_Key); + inputValues.FindStdDeviation = args.value(k_FindStandardDeviation_Key); + inputValues.FindSummation = args.value(k_FindSummation_Key); + + if(!inputValues.FindMin && !inputValues.FindMax && !inputValues.FindMean && !inputValues.FindMedian && !inputValues.FindStdDeviation && !inputValues.FindSummation && !inputValues.FindLength) { return {}; } - auto inputArrayPath = args.value(k_InputNeighborListPath_Key); - auto& inputArray = dataStructure.getDataRefAs(inputArrayPath); - const DataPath outputGroupPath = inputArrayPath.getParent(); + inputValues.TargetNeighborListPath = args.value(k_InputNeighborListPath_Key); + const DataPath outputGroupPath = inputValues.TargetNeighborListPath.getParent(); - std::vector arrays(7, nullptr); - - if(findLength) + if(inputValues.FindLength) { - auto lengthPath = outputGroupPath.createChildPath(args.value(k_LengthName_Key)); - arrays[0] = dataStructure.getDataAs(lengthPath); + inputValues.LengthPath = outputGroupPath.createChildPath(args.value(k_LengthName_Key)); } - if(findMin) + if(inputValues.FindMin) { - auto minPath = outputGroupPath.createChildPath(args.value(k_MinimumName_Key)); - arrays[1] = dataStructure.getDataAs(minPath); + inputValues.MinPath = outputGroupPath.createChildPath(args.value(k_MinimumName_Key)); } - if(findMax) + if(inputValues.FindMax) { - auto maxPath = outputGroupPath.createChildPath(args.value(k_MaximumName_Key)); - arrays[2] = dataStructure.getDataAs(maxPath); + inputValues.MaxPath = outputGroupPath.createChildPath(args.value(k_MaximumName_Key)); } - if(findMean) + if(inputValues.FindMean) { - auto meanPath = outputGroupPath.createChildPath(args.value(k_MeanName_Key)); - arrays[3] = dataStructure.getDataAs(meanPath); + inputValues.MeanPath = outputGroupPath.createChildPath(args.value(k_MeanName_Key)); } - if(findMedian) - { - auto medianPath = outputGroupPath.createChildPath(args.value(k_MedianName_Key)); - arrays[4] = dataStructure.getDataAs(medianPath); - } - if(findStdDeviation) - { - auto stdDeviationPath = outputGroupPath.createChildPath(args.value(k_StandardDeviationName_Key)); - arrays[5] = dataStructure.getDataAs(stdDeviationPath); - } - if(findSummation) + if(inputValues.FindMedian) { - auto summationPath = outputGroupPath.createChildPath(args.value(k_SummationName_Key)); - arrays[6] = dataStructure.getDataAs(summationPath); + inputValues.MedianPath = outputGroupPath.createChildPath(args.value(k_MedianName_Key)); } - - DataType type = inputArray.getDataType(); - if(type == DataType::boolean) + if(inputValues.FindStdDeviation) { - std::string ss = fmt::format("ComputeNeighborListStatisticsFilter::NeighborList {} was of type boolean, and thus cannot be processed", inputArray.getName()); - return {nonstd::make_unexpected(std::vector{Error{k_BoolTypeNeighborList, ss}})}; + inputValues.StdDeviationPath = outputGroupPath.createChildPath(args.value(k_StandardDeviationName_Key)); } - - usize numTuples = inputArray.getNumberOfTuples(); - if(numTuples == 0) + if(inputValues.FindSummation) { - std::string ss = fmt::format("ComputeNeighborListStatisticsFilter::NeighborList {} was empty", inputArray.getName()); - return {nonstd::make_unexpected(std::vector{Error{k_EmptyNeighborList, ss}})}; + inputValues.SummationPath = outputGroupPath.createChildPath(args.value(k_SummationName_Key)); } - // Allow data-based parallelization - ParallelDataAlgorithm dataAlg; - dataAlg.setRange(0, numTuples); - ExecuteParallelFunction(type, dataAlg, this, inputArray, findLength, findMin, findMax, findMean, findMedian, findStdDeviation, findSummation, - arrays); - - return {}; + return ComputeNeighborListStatistics(dataStructure, messageHandler, shouldCancel, &inputValues)(); } namespace