
Commit

Fixing torch dependency issue
MohamedKISSI committed Jul 5, 2024
1 parent b8c6768 commit 12c3512
Showing 1 changed file with 24 additions and 12 deletions.
@@ -83,13 +83,14 @@ namespace ttk {
const int64_t &i,
std::vector<int64_t> &neighborsIndices) const;

/*
This function allows you to copy the values of a pytorch tensor
to a vector in an optimized way.
*/
#ifdef TTK_ENABLE_TORCH
int tensorToVectorFast(const torch::Tensor &tensor,
std::vector<double> &result) const;

#endif
/*
Given a coordinate vector this function returns the value of maximum
and minimum for each axis and the number of coordinates per axis.
@@ -230,6 +231,7 @@ ttk::BackendTopologicalOptimization::BackendTopologicalOptimization() {
this->setDebugMsgPrefix("BackendTopologicalOptimization");
}

#ifdef TTK_ENABLE_TORCH
class PersistenceGradientDescent : public torch::nn::Module,
public ttk::BackendTopologicalOptimization {
public:
@@ -239,6 +241,8 @@ class PersistenceGradientDescent : public torch::nn::Module,
torch::Tensor X;
};

#endif
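
Note: the body of PersistenceGradientDescent is collapsed in this hunk. As a rough, hedged sketch of the pattern such a wrapper relies on (the names, includes and shapes below are assumptions for illustration, not the committed code), a torch::nn::Module can register the scalar field as a learnable parameter so that autograd tracks a gradient for every vertex value:

    #include <torch/torch.h>
    #include <vector>

    struct ScalarFieldModule : torch::nn::Module {
      explicit ScalarFieldModule(const std::vector<double> &scalars) {
        // Copy the scalar field into an owning 1-D tensor and register it as a
        // parameter, so an optimizer can update every vertex value.
        X = register_parameter("X", torch::tensor(scalars));
      }
      torch::Tensor X;
    };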

/*
Find all neighbors of a vertex i.
Variable :
@@ -1079,6 +1083,7 @@ void ttk::BackendTopologicalOptimization::getIndices(
This function allows you to copy the values of a pytorch tensor
to a vector in an optimized way.
*/
#ifdef TTK_ENABLE_TORCH
int ttk::BackendTopologicalOptimization::tensorToVectorFast(
const torch::Tensor &tensor, std::vector<double> &result) const {
TORCH_CHECK(
@@ -1088,6 +1093,7 @@ int ttk::BackendTopologicalOptimization::tensorToVectorFast(

return 0;
}
#endif
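
Note: only the TORCH_CHECK guard of tensorToVectorFast is visible above; the rest of the body is collapsed. A minimal, hedged sketch of the kind of bulk copy such a helper typically performs (an illustration under assumptions, not the committed implementation):

    #include <torch/torch.h>
    #include <vector>

    // Copy a 1-D double tensor into a std::vector without per-element indexing.
    int tensorToVectorSketch(const torch::Tensor &tensor,
                             std::vector<double> &result) {
      // Ensure a contiguous double buffer, then assign the whole range at once.
      const torch::Tensor flat = tensor.contiguous().to(torch::kFloat64);
      const double *data = flat.data_ptr<double>();
      result.assign(data, data + flat.numel());
      return 0;
    }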

/*
Given a coordinate vector this function returns the value of maximum
@@ -1161,7 +1167,6 @@ std::vector<std::vector<double>>
return resultat;
}
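
Note: the body of this routine is also collapsed. A self-contained, hypothetical illustration of what the comment above describes — per-axis minimum, maximum and number of distinct coordinate values, assuming non-empty, interleaved x,y,z coordinates (none of these names come from the commit):

    #include <set>
    #include <vector>

    std::vector<std::vector<double>>
    axisExtremaSketch(const std::vector<double> &coords) { // x0,y0,z0,x1,y1,z1,...
      std::vector<std::vector<double>> resultat;
      for(std::size_t axis = 0; axis < 3; ++axis) {
        std::set<double> values;
        for(std::size_t i = axis; i < coords.size(); i += 3)
          values.insert(coords[i]);
        // {min, max, number of distinct coordinates} for this axis
        resultat.push_back({*values.begin(), *values.rbegin(),
                            static_cast<double>(values.size())});
      }
      return resultat;
    }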

#ifdef TTK_ENABLE_TORCH
template <typename dataType, typename triangulationType>
int ttk::BackendTopologicalOptimization::execute(
const dataType *const inputScalars,
@@ -1172,6 +1177,13 @@ int ttk::BackendTopologicalOptimization::execute(

Timer t;
double stoppingCondition = 0;
bool enableTorch = true;

#ifndef TTK_ENABLE_TORCH
this->printWrn("Adam cannot be used because Torch hasn't been found. The "
"code will now default to direct gradient descent.");
enableTorch = false;
#endif

//=======================
// Copy input data
@@ -1195,7 +1207,7 @@
//========================================
// Direct gradient descent
//========================================
if(methodOptimization_ == 0) {
if((methodOptimization_ == 0) || !(enableTorch)) {
std::vector<double> smoothedScalars = dataVector;
ttk::DiagramType currentConstraintDiagram = constraintDiagram;
std::vector<int64_t> listAllIndicesToChangeSmoothing(vertexNumber_, 0);
@@ -1548,9 +1560,10 @@
}
}
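
Note: most of the direct gradient descent branch is collapsed in this view. As a rough sketch of the per-vertex update it boils down to (the gradient, step size and index list below are placeholders, not the committed logic):

    #include <cstdint>
    #include <vector>

    // One plain gradient-descent step, applied only to the selected vertices.
    void descentStepSketch(std::vector<double> &smoothedScalars,
                           const std::vector<double> &gradient,
                           const std::vector<int64_t> &indicesToChange,
                           const double stepSize) {
      for(const int64_t v : indicesToChange)
        smoothedScalars[v] -= stepSize * gradient[v];
    }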

//=======================================
// Adam Optimization
//=======================================
#ifdef TTK_ENABLE_TORCH
else if(methodOptimization_ == 1) {
//=====================================================
// Initialization of model parameters
@@ -1739,7 +1752,7 @@ int ttk::BackendTopologicalOptimization::execute(
outputScalars[k] = 0;
}
}

#endif
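
Note: the Adam branch is largely collapsed as well. A standalone, hedged example of the libtorch optimization loop it is built around — applied here to a toy tensor with an MSE loss as a stand-in for the topological loss (all names and values are assumptions, not the committed code):

    #include <torch/torch.h>

    int main() {
      // Learnable 1-D "scalar field" and an arbitrary target to pull it toward.
      torch::Tensor x
        = torch::zeros({8}, torch::dtype(torch::kFloat64).requires_grad(true));
      const torch::Tensor target = torch::linspace(0.0, 1.0, 8, torch::kFloat64);

      torch::optim::Adam optimizer({x}, torch::optim::AdamOptions(0.05));
      for(int epoch = 0; epoch < 200; ++epoch) {
        optimizer.zero_grad();
        torch::Tensor loss = torch::mse_loss(x, target); // stand-in loss
        loss.backward();
        optimizer.step();
      }
      return 0;
    }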
//========================================
// Information display
//========================================
@@ -1762,4 +1775,3 @@ int ttk::BackendTopologicalOptimization::execute(

return 0;
}
#endif

