Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Merge Tree (and Persistence Diagram) barycenter precision #1048

Merged
merged 22 commits into from
Sep 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
86d98de
[MergeTreeDistance] disable tree structure verification if input is p…
MatPont Dec 11, 2023
4b93f77
[MergeTreeBarycenter] add debug messages
MatPont Dec 11, 2023
f92c84d
[MergeTree] debug messages
MatPont Jun 20, 2024
fe9bf94
[MergeTree] debug messages
MatPont Jul 2, 2024
d4b25c1
[MergeTree] remove some debug messages
MatPont Aug 14, 2024
f3a1e80
[MergeTree] debug messages
MatPont Aug 14, 2024
15c7f7a
[MergeTree] debug messages
MatPont Aug 21, 2024
39e8749
[MergeTree] debug messages
MatPont Aug 21, 2024
79116d8
Merge https://github.com/topology-tool-kit/ttk into pdBaryFix
MatPont Aug 21, 2024
bdc8bc6
[MergeTree] debug messages
MatPont Aug 21, 2024
327945a
[MergeTree] first barycenter fix attempt
MatPont Aug 23, 2024
c0530ae
Merge https://github.com/topology-tool-kit/ttk into pdBaryFix
MatPont Aug 29, 2024
7b9a820
[MergeTreeBarycenter] remove debug messages
MatPont Aug 29, 2024
24a2b7b
[MergeTreeBarycenter] remove debug messages
MatPont Aug 29, 2024
29e819f
[MergeTreeBarycenter] remove double to template function
MatPont Aug 29, 2024
bbeabb3
Revert "[MergeTreeBarycenter] remove debug messages"
MatPont Aug 29, 2024
7e3bd13
[MergeTreeBarycenter] debug messages
MatPont Aug 30, 2024
53c0ded
[MergeTreeBarycenter] debug messages
MatPont Aug 30, 2024
003752e
[MergeTreeBarycenter] debug
MatPont Aug 30, 2024
fd701c0
[MergeTreeBarycenter] debug
MatPont Aug 30, 2024
872689b
[MergeTreeBarycenter] remove debug messages
MatPont Aug 30, 2024
569a8e2
[MergeTreeBarycenter] volatile update
MatPont Aug 30, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 22 additions & 16 deletions core/base/mergeTreeClustering/MergeTreeBarycenter.h
Original file line number Diff line number Diff line change
Expand Up @@ -506,7 +506,7 @@ namespace ttk {
std::tuple<dataType, dataType> birthDeath;
// Normalized Wasserstein
if(normalizedWasserstein_)
birthDeath = getNormalizedBirthDeathDouble<dataType>(tree1, nodeId1);
birthDeath = getNormalizedBirthDeath<dataType>(tree1, nodeId1);
// Classical Wasserstein
else
birthDeath = tree1->getBirthDeath<dataType>(nodeId1);
Expand All @@ -526,10 +526,10 @@ namespace ttk {
baryTree, nodeId, newScalarsVector, false);
dataType mu_min = getMinMaxLocalFromVector<dataType>(
baryTree, nodeId, newScalarsVector);
double newBirth = 0, newDeath = 0;
dataType newBirth = 0, newDeath = 0;

// Compute projection
double tempBirth = 0, tempDeath = 0;
dataType tempBirth = 0, tempDeath = 0;
double alphaSum = 0;
for(unsigned int i = 0; i < trees.size(); ++i)
if(nodes[i] != std::numeric_limits<ftm::idNode>::max())
Expand All @@ -539,18 +539,18 @@ namespace ttk {
if(nodes[i] != std::numeric_limits<ftm::idNode>::max()) {
auto iBirthDeath
= getParametrizedBirthDeath<dataType>(trees[i], nodes[i]);
double tTempBirth = 0, tTempDeath = 0;
dataType tTempBirth = 0, tTempDeath = 0;
tTempBirth += std::get<0>(iBirthDeath);
tTempDeath += std::get<1>(iBirthDeath);
tempBirth += tTempBirth * alphas[i] / alphaSum;
tempDeath += tTempDeath * alphas[i] / alphaSum;
}
}
double const projec = (tempBirth + tempDeath) / 2;
dataType const projec = (tempBirth + tempDeath) / 2;

// Compute newBirth and newDeath
for(unsigned int i = 0; i < trees.size(); ++i) {
double iBirth = projec, iDeath = projec;
dataType iBirth = projec, iDeath = projec;
// if node is matched in trees[i]
if(nodes[i] != std::numeric_limits<ftm::idNode>::max()) {
auto iBirthDeath
Expand All @@ -562,8 +562,12 @@ namespace ttk {
newDeath += alphas[i] * iDeath;
}
if(normalizedWasserstein_) {
newBirth = newBirth * (mu_max - mu_min) + mu_min;
newDeath = newDeath * (mu_max - mu_min) + mu_min;
// Forbid compiler optimization to have same results on different
// computers
volatile dataType tempBirthT = newBirth * (mu_max - mu_min);
volatile dataType tempDeathT = newDeath * (mu_max - mu_min);
newBirth = tempBirthT + mu_min;
newDeath = tempDeathT + mu_min;
}

return std::make_tuple(newBirth, newDeath);
Expand All @@ -584,21 +588,23 @@ namespace ttk {
= getMinMaxLocalFromVector<dataType>(baryTree, nodeB, newScalarsVector);

auto birthDeath = getParametrizedBirthDeath<dataType>(tree, nodeId);
double newBirth = std::get<0>(birthDeath);
double newDeath = std::get<1>(birthDeath);
double const projec = (newBirth + newDeath) / 2;
dataType newBirth = std::get<0>(birthDeath);
dataType newDeath = std::get<1>(birthDeath);
dataType const projec = (newBirth + newDeath) / 2;

newBirth = alpha * newBirth + (1 - alpha) * projec;
newDeath = alpha * newDeath + (1 - alpha) * projec;

if(normalizedWasserstein_) {
newBirth = newBirth * (mu_max - mu_min) + mu_min;
newDeath = newDeath * (mu_max - mu_min) + mu_min;
// Forbid compiler optimization to have same results on different
// computers
volatile dataType tempBirthT = newBirth * (mu_max - mu_min);
volatile dataType tempDeathT = newDeath * (mu_max - mu_min);
newBirth = tempBirthT + mu_min;
newDeath = tempDeathT + mu_min;
}

dataType newBirthT = newBirth;
dataType newDeathT = newDeath;
return std::make_tuple(newBirthT, newDeathT);
return std::make_tuple(newBirth, newDeath);
}

template <class dataType>
Expand Down
2 changes: 1 addition & 1 deletion core/base/mergeTreeClustering/MergeTreeDistance.h
Original file line number Diff line number Diff line change
Expand Up @@ -582,7 +582,7 @@ namespace ttk {
ftm::MergeTree<dataType> &mTree2Int = (saveTree_ ? mTree2Copy : mTree2);
ftm::FTMTree_MT *tree1 = &(mTree1Int.tree);
ftm::FTMTree_MT *tree2 = &(mTree2Int.tree);
if(not isCalled_) {
if(not isCalled_ and not isPersistenceDiagram_) {
verifyMergeTreeStructure<dataType>(tree1);
verifyMergeTreeStructure<dataType>(tree2);
}
Expand Down
20 changes: 0 additions & 20 deletions core/base/mergeTreeClustering/MergeTreeUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,26 +60,6 @@ namespace ttk {
return std::make_tuple(min, max);
}

template <class dataType>
std::tuple<double, double>
getNormalizedBirthDeathDouble(ftm::FTMTree_MT *tree,
ftm::idNode nodeId,
dataType newMin = 0.0,
dataType newMax = 1.0) {
auto birthDeath = tree->getBirthDeath<dataType>(nodeId);
double birth = std::get<0>(birthDeath);
double death = std::get<1>(birthDeath);
dataType shiftMin = getMinMaxLocal<dataType>(tree, nodeId);
dataType shiftMax = getMinMaxLocal<dataType>(tree, nodeId, false);
if((shiftMax - shiftMin) == 0)
return std::make_tuple(0, 0);
birth = (newMax - newMin) * (birth - shiftMin)
/ (shiftMax - shiftMin); // + newMin;
death = (newMax - newMin) * (death - shiftMin)
/ (shiftMax - shiftMin); // + newMin;
return std::make_tuple(birth, death);
}

template <class dataType>
std::tuple<dataType, dataType> getNormalizedBirthDeath(ftm::FTMTree_MT *tree,
ftm::idNode nodeId,
Expand Down
Loading