Skip to content

Commit

Permalink
Make MaskCompare out of core compatible (marginal speed up)
Browse files Browse the repository at this point in the history
  • Loading branch information
nyoungbq authored and imikejackson committed May 25, 2024
1 parent 783900e commit edf77c4
Show file tree
Hide file tree
Showing 6 changed files with 29 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ Result<> AlignSectionsMutualInformation::findShifts(std::vector<int64>& xShifts,
}
}

if (m_InputValues->UseMask)
if(m_InputValues->UseMask)
{
try
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
#include "simplnx/DataStructure/DataArray.hpp"
#include "simplnx/DataStructure/DataGroup.hpp"
#include "simplnx/DataStructure/Geometry/TriangleGeom.hpp"
#include "simplnx/Utilities/ParallelDataAlgorithm.hpp"
#include "simplnx/Utilities/DataArrayUtilities.hpp"
#include "simplnx/Utilities/ParallelDataAlgorithm.hpp"

using namespace nx::core;

Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
#include "Silhouette.hpp"

#include "simplnx/DataStructure/DataArray.hpp"
#include "simplnx/Utilities/DataArrayUtilities.hpp"
#include "simplnx/Utilities/FilterUtilities.hpp"
#include "simplnx/Utilities/KUtilities.hpp"
#include "simplnx/Utilities/DataArrayUtilities.hpp"

#include <unordered_set>

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,9 @@ Parameters ComputeKMeansFilter::parameters() const

params.insertSeparator(Parameters::Separator{"Optional Data Mask"});
params.insertLinkableParameter(std::make_unique<BoolParameter>(k_UseMask_Key, "Use Mask Array", "Specifies whether or not to use a mask array", false));
params.insert(std::make_unique<ArraySelectionParameter>(k_MaskArrayPath_Key, "Cell Mask Array", "DataPath to the boolean or uint8 mask array. Values that are true will mark that cell/point as usable.",
DataPath{}, ArraySelectionParameter::AllowedTypes{DataType::boolean, DataType::uint8}));
params.insert(std::make_unique<ArraySelectionParameter>(k_MaskArrayPath_Key, "Cell Mask Array",
"DataPath to the boolean or uint8 mask array. Values that are true will mark that cell/point as usable.", DataPath{},
ArraySelectionParameter::AllowedTypes{DataType::boolean, DataType::uint8}));

params.insertSeparator(Parameters::Separator{"Input Data Objects"});
params.insert(std::make_unique<ArraySelectionParameter>(k_SelectedArrayPath_Key, "Attribute Array to Cluster", "The array to cluster from", DataPath{}, nx::core::GetAllNumericTypes()));
Expand Down
4 changes: 2 additions & 2 deletions src/simplnx/Utilities/DataArrayUtilities.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -285,10 +285,10 @@ std::unique_ptr<MaskCompare> InstantiateMaskCompare(IDataArray& maskArray)
switch(maskArray.getDataType())
{
case DataType::boolean: {
return std::make_unique<BoolMaskCompare>(dynamic_cast<BoolArray&>(maskArray));
return std::make_unique<BoolMaskCompare>(dynamic_cast<BoolArray&>(maskArray).getDataStoreRef());
}
case DataType::uint8: {
return std::make_unique<UInt8MaskCompare>(dynamic_cast<UInt8Array&>(maskArray));
return std::make_unique<UInt8MaskCompare>(dynamic_cast<UInt8Array&>(maskArray).getDataStoreRef());
}
default:
throw std::runtime_error("InstantiateMaskCompare: The Mask Array being used is NOT of type bool or uint8.");
Expand Down
41 changes: 21 additions & 20 deletions src/simplnx/Utilities/DataArrayUtilities.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -855,80 +855,81 @@ struct MaskCompare

struct BoolMaskCompare : public MaskCompare
{
BoolMaskCompare(BoolArray& array)
: m_Array(array)
BoolMaskCompare(AbstractDataStore<bool>& dataStore)
: m_DataStore(dataStore)
{
}
~BoolMaskCompare() noexcept override = default;

BoolArray& m_Array;
AbstractDataStore<bool>& m_DataStore;
bool bothTrue(usize indexA, usize indexB) const override
{
return m_Array.at(indexA) && m_Array.at(indexB);
return m_DataStore.at(indexA) && m_DataStore.at(indexB);
}
bool bothFalse(usize indexA, usize indexB) const override
{
return !m_Array.at(indexA) && !m_Array.at(indexB);
return !m_DataStore.at(indexA) && !m_DataStore.at(indexB);
}
bool isTrue(usize index) const override
{
return m_Array.at(index);
return m_DataStore.at(index);
}
void setValue(usize index, bool val) override
{
m_Array[index] = val;
m_DataStore[index] = val;
}
usize getNumberOfTuples() const override
{
return m_Array.getNumberOfTuples();
return m_DataStore.getNumberOfTuples();
}
usize getNumberOfComponents() const override
{
return m_Array.getNumberOfComponents();
return m_DataStore.getNumberOfComponents();
}

usize countTrueValues() const override
{
return std::count(m_Array.begin(), m_Array.end(), true);
return std::count(m_DataStore.begin(), m_DataStore.end(), true);
}
};

struct UInt8MaskCompare : public MaskCompare
{
UInt8MaskCompare(UInt8Array& array)
: m_Array(array)
UInt8MaskCompare(AbstractDataStore<uint8>& dataStore)
: m_DataStore(dataStore)
{
}
~UInt8MaskCompare() noexcept override = default;
UInt8Array& m_Array;

AbstractDataStore<uint8>& m_DataStore;
bool bothTrue(usize indexA, usize indexB) const override
{
return m_Array.at(indexA) != 0 && m_Array.at(indexB) != 0;
return m_DataStore.at(indexA) != 0 && m_DataStore.at(indexB) != 0;
}
bool bothFalse(usize indexA, usize indexB) const override
{
return m_Array.at(indexA) == 0 && m_Array.at(indexB) == 0;
return m_DataStore.at(indexA) == 0 && m_DataStore.at(indexB) == 0;
}
bool isTrue(usize index) const override
{
return m_Array.at(index) != 0;
return m_DataStore.at(index) != 0;
}
void setValue(usize index, bool val) override
{
m_Array[index] = static_cast<uint8>(val);
m_DataStore[index] = static_cast<uint8>(val);
}
usize getNumberOfTuples() const override
{
return m_Array.getNumberOfTuples();
return m_DataStore.getNumberOfTuples();
}
usize getNumberOfComponents() const override
{
return m_Array.getNumberOfComponents();
return m_DataStore.getNumberOfComponents();
}

usize countTrueValues() const override
{
const usize falseCount = std::count(m_Array.begin(), m_Array.end(), 0);
const usize falseCount = std::count(m_DataStore.begin(), m_DataStore.end(), 0);
return getNumberOfTuples() - falseCount;
}
};
Expand Down

0 comments on commit edf77c4

Please sign in to comment.