diff --git a/api/mrcpp_declarations.h b/api/mrcpp_declarations.h index f6501b726..d21058409 100644 --- a/api/mrcpp_declarations.h +++ b/api/mrcpp_declarations.h @@ -34,7 +34,7 @@ namespace mrcpp { class Timer; class Printer; -template class Plotter; +template class Plotter; template class Gaussian; template class GaussFunc; @@ -42,26 +42,26 @@ template class GaussPoly; template class GaussExp; template class BoundingBox; -template class NodeBox; +template class NodeBox; template class NodeIndex; template class NodeIndexComp; -class SharedMemory; +template class SharedMemory; class ScalingBasis; class LegendreBasis; class InterpolatingBasis; -template class RepresentableFunction; +template class RepresentableFunction; template class MultiResolutionAnalysis; -template class MWTree; -template class FunctionTree; +template class MWTree; +template class FunctionTree; class OperatorTree; -template class NodeAllocator; +template class NodeAllocator; -template class MWNode; -template class FunctionNode; +template class MWNode; +template class FunctionNode; class OperatorNode; template class IdentityConvolution; @@ -79,31 +79,30 @@ template class DerivativeKernel; class PoissonKernel; class HelmholtzKernel; -template class TreeBuilder; -template class TreeCalculator; -template class DefaultCalculator; -template class ProjectionCalculator; -template class AdditionCalculator; -template class MultiplicationCalculator; -template class ConvolutionCalculator; -template class DerivativeCalculator; +template class TreeBuilder; +template class TreeCalculator; +template class DefaultCalculator; +template class ProjectionCalculator; +template class AdditionCalculator; +template class MultiplicationCalculator; +template class ConvolutionCalculator; +template class DerivativeCalculator; class CrossCorrelationCalculator; -template class TreeAdaptor; -template class AnalyticAdaptor; -template class WaveletAdaptor; -template class CopyAdaptor; +template class TreeAdaptor; +template class AnalyticAdaptor; +template class WaveletAdaptor; +template class CopyAdaptor; -template class TreeIterator; -template class IteratorNode; +template class TreeIterator; +template class IteratorNode; class BandWidth; -template class OperatorState; +template class OperatorState; template using Coord = std::array; -template using MWNodeVector = std::vector *>; +template using MWNodeVector = std::vector *>; -template using FMap_ = std::function; -typedef FMap_ FMap; +template using FMap = std::function; } // namespace mrcpp diff --git a/examples/derivative.cpp b/examples/derivative.cpp index bc0475db9..bd33bcfe7 100644 --- a/examples/derivative.cpp +++ b/examples/derivative.cpp @@ -51,8 +51,8 @@ int main(int argc, char **argv) { mrcpp::FunctionTree err_tree(MRA); // Projecting functions - mrcpp::project(prec, f_tree, f); - mrcpp::project(prec, df_tree, df); + mrcpp::project(prec, f_tree, f); + mrcpp::project(prec, df_tree, df); // Applying derivative operator mrcpp::apply(dg_tree, D_00, f_tree, 0); diff --git a/examples/mpi_matrix.cpp b/examples/mpi_matrix.cpp index f3580d158..69c370a70 100644 --- a/examples/mpi_matrix.cpp +++ b/examples/mpi_matrix.cpp @@ -54,7 +54,7 @@ int main(int argc, char **argv) { }; mrcpp::FunctionTree<3> *tree = new mrcpp::FunctionTree<3>(MRA); if (i % wsize == wrank) { - mrcpp::project<3>(prec, *tree, f); + mrcpp::project<3, double>(prec, *tree, f); tree->normalize(); } f_vec.push_back(std::make_tuple(1.0, tree)); diff --git a/examples/mpi_send_tree.cpp b/examples/mpi_send_tree.cpp index 44ffd0dec..aff8be379 100644 --- a/examples/mpi_send_tree.cpp +++ b/examples/mpi_send_tree.cpp @@ -55,7 +55,7 @@ int main(int argc, char **argv) { mrcpp::FunctionTree f_tree(MRA); // Only rank 0 projects the function - if (wrank == 0) mrcpp::project(prec, f_tree, f); + if (wrank == 0) mrcpp::project(prec, f_tree, f); { // Print data before send auto integral = f_tree.integrate(); diff --git a/examples/mpi_shared_tree.cpp b/examples/mpi_shared_tree.cpp index aa59d8204..ba7f7db5e 100644 --- a/examples/mpi_shared_tree.cpp +++ b/examples/mpi_shared_tree.cpp @@ -63,12 +63,12 @@ int main(int argc, char **argv) { }; // Initialize a shared memory tree, max 100MB - auto shared_mem = new mrcpp::SharedMemory(scomm, 100); + auto shared_mem = new mrcpp::SharedMemory(scomm, 100); mrcpp::FunctionTree f_tree(MRA, shared_mem); // Only first rank projects auto frank = 0; - if (srank == frank) mrcpp::project(prec, f_tree, f); + if (srank == frank) mrcpp::project(prec, f_tree, f); mrcpp::share_tree(f_tree, frank, 0, scomm); { // Print data after share diff --git a/examples/projection.cpp b/examples/projection.cpp index 92f1a7b53..9243485fb 100644 --- a/examples/projection.cpp +++ b/examples/projection.cpp @@ -37,7 +37,7 @@ int main(int argc, char **argv) { // Projecting function mrcpp::FunctionTree f_tree(MRA); - mrcpp::project(prec, f_tree, f, -1); + mrcpp::project(prec, f_tree, f, -1); auto integral = f_tree.integrate(); mrcpp::print::header(0, "Projecting analytic function"); diff --git a/examples/scf.cpp b/examples/scf.cpp index f91058f14..fe34d936b 100644 --- a/examples/scf.cpp +++ b/examples/scf.cpp @@ -31,7 +31,7 @@ void setupNuclearPotential(double Z, FunctionTree &V) { }; // Projecting function - project(prec, V, f); + project(prec, V, f); print::footer(0, timer, 2); Printer::setPrintLevel(oldlevel); @@ -48,7 +48,7 @@ void setupInitialGuess(FunctionTree &phi) { }; // Projecting and normalizing function - project(prec, phi, f); + project(prec, phi, f); phi.normalize(); print::footer(0, timer, 2); diff --git a/examples/schrodinger_semigroup1d.cpp b/examples/schrodinger_semigroup1d.cpp index 2c1de2fa8..6035aa3c3 100644 --- a/examples/schrodinger_semigroup1d.cpp +++ b/examples/schrodinger_semigroup1d.cpp @@ -93,13 +93,13 @@ int main(int argc, char **argv) // Projecting functions mrcpp::FunctionTree<1> Re_f_tree(MRA); - mrcpp::project<1>(prec, Re_f_tree, Re_f); + mrcpp::project<1, double>(prec, Re_f_tree, Re_f); mrcpp::FunctionTree<1> Im_f_tree(MRA); - mrcpp::project<1>(prec, Im_f_tree, Im_f); + mrcpp::project<1, double>(prec, Im_f_tree, Im_f); mrcpp::FunctionTree<1> Re_g_tree(MRA); - mrcpp::project<1>(prec, Re_g_tree, Re_g); + mrcpp::project<1, double>(prec, Re_g_tree, Re_g); mrcpp::FunctionTree<1> Im_g_tree(MRA); - mrcpp::project<1>(prec, Im_g_tree, Im_g); + mrcpp::project<1, double>(prec, Im_g_tree, Im_g); // Output function trees mrcpp::FunctionTree<1> Re_fout_tree(MRA); diff --git a/examples/tree_cleaner.cpp b/examples/tree_cleaner.cpp index 6d970c5e3..dd4d85a05 100644 --- a/examples/tree_cleaner.cpp +++ b/examples/tree_cleaner.cpp @@ -9,6 +9,7 @@ const auto order = 7; const auto prec = 1.0e-5; const auto D = 3; + int main(int argc, char **argv) { auto timer = mrcpp::Timer(); @@ -42,14 +43,14 @@ int main(int argc, char **argv) { auto iter = 0; auto n_nodes = 1; while (n_nodes > 0) { - mrcpp::project(-1.0, f_tree, f); // Projecting on fixed grid + mrcpp::project(-1.0, f_tree, f); // Projecting on fixed grid n_nodes = mrcpp::refine_grid(f_tree, prec); // Refine grid mrcpp::clear_grid(f_tree); // Clear MW coefs printout(0, " iter " << std::setw(3) << iter++ << std::setw(45)); printout(0, " n_nodes " << std::setw(5) << n_nodes << std::endl); } // Projecting on final converged grid - mrcpp::project(-1.0, f_tree, f); + mrcpp::project(-1.0, f_tree, f); auto integral = f_tree.integrate(); auto sq_norm = f_tree.getSquareNorm(); diff --git a/src/functions/AnalyticFunction.h b/src/functions/AnalyticFunction.h index abf0fcbd6..7043d7fe6 100644 --- a/src/functions/AnalyticFunction.h +++ b/src/functions/AnalyticFunction.h @@ -32,29 +32,29 @@ namespace mrcpp { -template class AnalyticFunction : public RepresentableFunction { +template class AnalyticFunction : public RepresentableFunction { public: AnalyticFunction() = default; ~AnalyticFunction() override = default; - AnalyticFunction(std::function &r)> f, const double *a = nullptr, const double *b = nullptr) - : RepresentableFunction(a, b) + AnalyticFunction(std::function &r)> f, const double *a = nullptr, const double *b = nullptr) + : RepresentableFunction(a, b) , func(f) {} - AnalyticFunction(std::function &r)> f, + AnalyticFunction(std::function &r)> f, const std::vector &a, const std::vector &b) : AnalyticFunction(f, a.data(), b.data()) {} - void set(std::function &r)> f) { this->func = f; } + void set(std::function &r)> f) { this->func = f; } - double evalf(const Coord &r) const override { - double val = 0.0; + T evalf(const Coord &r) const override { + T val = 0.0; if (not this->outOfBounds(r)) val = this->func(r); return val; } protected: - std::function &r)> func; + std::function &r)> func; }; } // namespace mrcpp diff --git a/src/functions/BoysFunction.cpp b/src/functions/BoysFunction.cpp index 71b705139..0a3364845 100644 --- a/src/functions/BoysFunction.cpp +++ b/src/functions/BoysFunction.cpp @@ -32,7 +32,7 @@ namespace mrcpp { BoysFunction::BoysFunction(int n, double p) - : RepresentableFunction<1>() + : RepresentableFunction<1, double>() , order(n) , prec(p) , MRA(BoundingBox<1>(), InterpolatingBasis(13)) {} @@ -50,8 +50,8 @@ double BoysFunction::evalf(const Coord<1> &r) const { return std::exp(-xt_2) * t_2n; }; - FunctionTree<1> tree(this->MRA); - mrcpp::project<1>(this->prec, tree, f); + FunctionTree<1, double> tree(this->MRA); + mrcpp::project<1, double>(this->prec, tree, f); double result = tree.integrate(); Printer::setPrintLevel(oldlevel); diff --git a/src/functions/BoysFunction.h b/src/functions/BoysFunction.h index 4dc76bd72..cc5cc1916 100644 --- a/src/functions/BoysFunction.h +++ b/src/functions/BoysFunction.h @@ -30,7 +30,7 @@ namespace mrcpp { -class BoysFunction final : public RepresentableFunction<1> { + class BoysFunction final : public RepresentableFunction<1, double> { public: BoysFunction(int n, double prec = 1.0e-10); diff --git a/src/functions/GaussExp.h b/src/functions/GaussExp.h index aa6ad4da3..f33549ec1 100644 --- a/src/functions/GaussExp.h +++ b/src/functions/GaussExp.h @@ -51,7 +51,7 @@ namespace mrcpp { * */ -template class GaussExp : public RepresentableFunction { + template class GaussExp : public RepresentableFunction { public: GaussExp(int nTerms = 0, double prec = GAUSS_EXP_PREC); GaussExp(const GaussExp &gExp); diff --git a/src/functions/Gaussian.h b/src/functions/Gaussian.h index d02cc43b1..7e79e052a 100644 --- a/src/functions/Gaussian.h +++ b/src/functions/Gaussian.h @@ -40,7 +40,7 @@ namespace mrcpp { -template class Gaussian : public RepresentableFunction { + template class Gaussian : public RepresentableFunction { public: Gaussian(double a, double c, const Coord &r, const std::array &p); Gaussian(const std::array &a, double c, const Coord &r, const std::array &p); diff --git a/src/functions/Polynomial.cpp b/src/functions/Polynomial.cpp index 397b4e268..964fe687b 100644 --- a/src/functions/Polynomial.cpp +++ b/src/functions/Polynomial.cpp @@ -45,7 +45,7 @@ namespace mrcpp { /** Construct polynomial of order zero with given size and bounds. * Includes default constructor. */ Polynomial::Polynomial(int k, const double *a, const double *b) - : RepresentableFunction<1>(a, b) { + : RepresentableFunction<1, double>(a, b) { assert(k >= 0); this->N = 1.0; this->L = 0.0; diff --git a/src/functions/Polynomial.h b/src/functions/Polynomial.h index e1c23e4a6..fadc2c988 100644 --- a/src/functions/Polynomial.h +++ b/src/functions/Polynomial.h @@ -44,7 +44,7 @@ namespace mrcpp { -class Polynomial : public RepresentableFunction<1> { + class Polynomial : public RepresentableFunction<1, double> { public: Polynomial(int k = 0, const double *a = nullptr, const double *b = nullptr); Polynomial(int k, const std::vector &a, const std::vector &b) diff --git a/src/functions/RepresentableFunction.cpp b/src/functions/RepresentableFunction.cpp index 8687297c7..3c55ac92b 100644 --- a/src/functions/RepresentableFunction.cpp +++ b/src/functions/RepresentableFunction.cpp @@ -38,7 +38,7 @@ namespace mrcpp { -template RepresentableFunction::RepresentableFunction(const double *a, const double *b) { +template RepresentableFunction::RepresentableFunction(const double *a, const double *b) { if (a == nullptr or b == nullptr) { this->bounded = false; this->A = nullptr; @@ -56,7 +56,7 @@ template RepresentableFunction::RepresentableFunction(const double *a } /** Constructs a new function with same bounds as the input function */ -template RepresentableFunction::RepresentableFunction(const RepresentableFunction &func) { +template RepresentableFunction::RepresentableFunction(const RepresentableFunction &func) { if (func.isBounded()) { this->bounded = true; this->A = new double[D]; @@ -74,11 +74,11 @@ template RepresentableFunction::RepresentableFunction(const Represent /** Copies function, not bounds. Use copy constructor if you want an * identical function. */ -template RepresentableFunction &RepresentableFunction::operator=(const RepresentableFunction &func) { +template RepresentableFunction &RepresentableFunction::operator=(const RepresentableFunction &func) { return *this; } -template RepresentableFunction::~RepresentableFunction() { +template RepresentableFunction::~RepresentableFunction() { if (this->isBounded()) { delete[] this->A; delete[] this->B; @@ -87,7 +87,7 @@ template RepresentableFunction::~RepresentableFunction() { this->B = nullptr; } -template void RepresentableFunction::setBounds(const double *a, const double *b) { +template void RepresentableFunction::setBounds(const double *a, const double *b) { if (a == nullptr or b == nullptr) { MSG_ERROR("Invalid arguments"); } if (not isBounded()) { this->bounded = true; @@ -101,7 +101,7 @@ template void RepresentableFunction::setBounds(const double *a, const } } -template bool RepresentableFunction::outOfBounds(const Coord &r) const { +template bool RepresentableFunction::outOfBounds(const Coord &r) const { if (not isBounded()) { return false; } for (int d = 0; d < D; d++) { if (r[d] < getLowerBound(d)) return true; @@ -110,8 +110,11 @@ template bool RepresentableFunction::outOfBounds(const Coord &r) c return false; } -template class RepresentableFunction<1>; -template class RepresentableFunction<2>; -template class RepresentableFunction<3>; +template class RepresentableFunction<1, double>; +template class RepresentableFunction<2, double>; +template class RepresentableFunction<3, double>; +template class RepresentableFunction<1, ComplexDouble>; +template class RepresentableFunction<2, ComplexDouble>; +template class RepresentableFunction<3, ComplexDouble>; } // namespace mrcpp diff --git a/src/functions/RepresentableFunction.h b/src/functions/RepresentableFunction.h index 2d6998812..82381beaa 100644 --- a/src/functions/RepresentableFunction.h +++ b/src/functions/RepresentableFunction.h @@ -38,20 +38,21 @@ #include "MRCPP/constants.h" #include "MRCPP/mrcpp_declarations.h" #include "trees/NodeIndex.h" +#include "MRCPP/utils/math_utils.h" namespace mrcpp { -template class RepresentableFunction { +template class RepresentableFunction { public: RepresentableFunction(const double *a = nullptr, const double *b = nullptr); RepresentableFunction(const std::vector &a, const std::vector &b) : RepresentableFunction(a.data(), b.data()) {} - RepresentableFunction(const RepresentableFunction &func); - RepresentableFunction &operator=(const RepresentableFunction &func); + RepresentableFunction(const RepresentableFunction &func); + RepresentableFunction &operator=(const RepresentableFunction &func); virtual ~RepresentableFunction(); /** @returns Function value in a point @param[in] r: Cartesian coordinate */ - virtual double evalf(const Coord &r) const = 0; + virtual T evalf(const Coord &r) const = 0; void setBounds(const double *a, const double *b); void clearBounds(); @@ -65,7 +66,7 @@ template class RepresentableFunction { const double *getLowerBounds() const { return this->A; } const double *getUpperBounds() const { return this->B; } - friend class AnalyticAdaptor; + friend class AnalyticAdaptor; protected: bool bounded; diff --git a/src/functions/function_utils.cpp b/src/functions/function_utils.cpp index 60b287e1c..39f30a938 100644 --- a/src/functions/function_utils.cpp +++ b/src/functions/function_utils.cpp @@ -117,4 +117,5 @@ double function_utils::ObaraSaika_ab(int power_a, int power_b, double pos_a, dou template double function_utils::calc_overlap<1>(const GaussFunc<1> &a, const GaussFunc<1> &b); template double function_utils::calc_overlap<2>(const GaussFunc<2> &a, const GaussFunc<2> &b); template double function_utils::calc_overlap<3>(const GaussFunc<3> &a, const GaussFunc<3> &b); + } // namespace mrcpp diff --git a/src/operators/OperatorState.h b/src/operators/OperatorState.h index 245d9f70f..855f53060 100644 --- a/src/operators/OperatorState.h +++ b/src/operators/OperatorState.h @@ -42,9 +42,9 @@ namespace mrcpp { #define GET_OP_IDX(FT, GT, ID) (2 * ((GT >> ID) & 1) + ((FT >> ID) & 1)) -template class OperatorState final { +template class OperatorState final { public: - OperatorState(MWNode &gn, double *scr1) + OperatorState(MWNode &gn, T *scr1) : gNode(&gn) { this->kp1 = this->gNode->getKp1(); this->kp1_d = this->gNode->getKp1_d(); @@ -53,7 +53,7 @@ template class OperatorState final { this->gData = this->gNode->getCoefs(); this->maxDeltaL = -1; - double *scr2 = scr1 + this->kp1_d; + T *scr2 = scr1 + this->kp1_d; for (int i = 1; i < D; i++) { if (IS_ODD(i)) { @@ -64,9 +64,9 @@ template class OperatorState final { } } - OperatorState(MWNode &gn, std::vector scr1) + OperatorState(MWNode &gn, std::vector scr1) : OperatorState(gn, scr1.data()) {} - void setFNode(MWNode &fn) { + void setFNode(MWNode &fn) { this->fNode = &fn; this->fData = this->fNode->getCoefs(); } @@ -86,15 +86,16 @@ template class OperatorState final { int getMaxDeltaL() const { return this->maxDeltaL; } int getOperIndex(int i) const { return GET_OP_IDX(this->ft, this->gt, i); } - double **getAuxData() { return this->aux; } + T **getAuxData() { return this->aux; } double **getOperData() { return this->oData; } - friend class ConvolutionCalculator; - friend class DerivativeCalculator; + friend class ConvolutionCalculator; + friend class DerivativeCalculator; private: int ft; int gt; + int maxDeltaL; double fThreshold; double gThreshold; @@ -104,13 +105,13 @@ template class OperatorState final { int kp1_d; int kp1_dm1; - MWNode *gNode; - MWNode *fNode; + MWNode *gNode; + MWNode *fNode; NodeIndex *fIdx; - double *aux[D + 1]; - double *gData; - double *fData; + T *aux[D + 1]; + T *gData; + T *fData; double *oData[D]; void calcMaxDeltaL() { diff --git a/src/operators/OperatorStatistics.cpp b/src/operators/OperatorStatistics.cpp index d542e88f5..4ed0263cc 100644 --- a/src/operators/OperatorStatistics.cpp +++ b/src/operators/OperatorStatistics.cpp @@ -30,8 +30,8 @@ using namespace Eigen; namespace mrcpp { -template -OperatorStatistics::OperatorStatistics() +template +OperatorStatistics::OperatorStatistics() : nThreads(mrcpp_get_max_threads()) , totFCount(0) , totGCount(0) @@ -58,7 +58,7 @@ OperatorStatistics::OperatorStatistics() } } -template OperatorStatistics::~OperatorStatistics() { +template OperatorStatistics::~OperatorStatistics() { for (int i = 0; i < this->nThreads; i++) { delete this->compCount[i]; } delete[] this->compCount; delete[] this->fCount; @@ -68,7 +68,7 @@ template OperatorStatistics::~OperatorStatistics() { } /** Sum all node counters from all threads. */ -template void OperatorStatistics::flushNodeCounters() { +template void OperatorStatistics::flushNodeCounters() { for (int i = 0; i < this->nThreads; i++) { this->totFCount += this->fCount[i]; this->totGCount += this->gCount[i]; @@ -82,20 +82,20 @@ template void OperatorStatistics::flushNodeCounters() { } /** Increment g-node usage counter. Needed for load balancing. */ -template void OperatorStatistics::incrementGNodeCounters(const MWNode &gNode) { +template void OperatorStatistics::incrementGNodeCounters(const MWNode &gNode) { int thread = mrcpp_get_thread_num(); this->gCount[thread]++; } /** Increment operator application counter. */ -template void OperatorStatistics::incrementFNodeCounters(const MWNode &fNode, int ft, int gt) { +template void OperatorStatistics::incrementFNodeCounters(const MWNode &fNode, int ft, int gt) { int thread = mrcpp_get_thread_num(); this->fCount[thread]++; (*this->compCount[thread])(ft, gt) += 1; if (fNode.isGenNode()) { this->genCount[thread]++; } } -template std::ostream &OperatorStatistics::print(std::ostream &o) const { +template std::ostream &OperatorStatistics::print(std::ostream &o) const { o << std::setw(8); o << "*OperatorFunc statistics: " << std::endl << std::endl; o << " Total calculated gNodes : " << this->totGCount << std::endl; @@ -105,8 +105,12 @@ template std::ostream &OperatorStatistics::print(std::ostream &o) con return o; } -template class OperatorStatistics<1>; -template class OperatorStatistics<2>; -template class OperatorStatistics<3>; +template class OperatorStatistics<1, double>; +template class OperatorStatistics<2, double>; +template class OperatorStatistics<3, double>; + +template class OperatorStatistics<1, ComplexDouble>; +template class OperatorStatistics<2, ComplexDouble>; +template class OperatorStatistics<3, ComplexDouble>; } // namespace mrcpp diff --git a/src/operators/OperatorStatistics.h b/src/operators/OperatorStatistics.h index 395a5d62a..9de97f8e0 100644 --- a/src/operators/OperatorStatistics.h +++ b/src/operators/OperatorStatistics.h @@ -32,14 +32,14 @@ namespace mrcpp { -template class OperatorStatistics final { + template class OperatorStatistics final { public: OperatorStatistics(); ~OperatorStatistics(); void flushNodeCounters(); - void incrementFNodeCounters(const MWNode &fNode, int ft, int gt); - void incrementGNodeCounters(const MWNode &gNode); + void incrementFNodeCounters(const MWNode &fNode, int ft, int gt); + void incrementGNodeCounters(const MWNode &gNode); friend std::ostream &operator<<(std::ostream &o, const OperatorStatistics &os) { return os.print(o); } diff --git a/src/treebuilders/AdditionCalculator.h b/src/treebuilders/AdditionCalculator.h index 431600192..a7804a761 100644 --- a/src/treebuilders/AdditionCalculator.h +++ b/src/treebuilders/AdditionCalculator.h @@ -30,24 +30,24 @@ namespace mrcpp { -template class AdditionCalculator final : public TreeCalculator { +template class AdditionCalculator final : public TreeCalculator { public: - AdditionCalculator(const FunctionTreeVector &inp) + AdditionCalculator(const FunctionTreeVector &inp) : sum_vec(inp) {} private: - FunctionTreeVector sum_vec; + FunctionTreeVector sum_vec; - void calcNode(MWNode &node_o) override { + void calcNode(MWNode &node_o) override { node_o.zeroCoefs(); const NodeIndex &idx = node_o.getNodeIndex(); - double *coefs_o = node_o.getCoefs(); + T *coefs_o = node_o.getCoefs(); for (int i = 0; i < this->sum_vec.size(); i++) { - double c_i = get_coef(this->sum_vec, i); - FunctionTree &func_i = get_func(this->sum_vec, i); + T c_i = get_coef(this->sum_vec, i); + FunctionTree &func_i = get_func(this->sum_vec, i); // This generates missing nodes - const MWNode &node_i = func_i.getNode(idx); - const double *coefs_i = node_i.getCoefs(); + const MWNode &node_i = func_i.getNode(idx); + const T *coefs_i = node_i.getCoefs(); int n_coefs = node_i.getNCoefs(); for (int j = 0; j < n_coefs; j++) { coefs_o[j] += c_i * coefs_i[j]; } } diff --git a/src/treebuilders/AnalyticAdaptor.h b/src/treebuilders/AnalyticAdaptor.h index 45f73b4cd..d735933ec 100644 --- a/src/treebuilders/AnalyticAdaptor.h +++ b/src/treebuilders/AnalyticAdaptor.h @@ -30,16 +30,16 @@ namespace mrcpp { -template class AnalyticAdaptor final : public TreeAdaptor { + template class AnalyticAdaptor final : public TreeAdaptor { public: - AnalyticAdaptor(const RepresentableFunction &f, int ms) - : TreeAdaptor(ms) + AnalyticAdaptor(const RepresentableFunction &f, int ms) + : TreeAdaptor(ms) , func(&f) {} private: - const RepresentableFunction *func; + const RepresentableFunction *func; - bool splitNode(const MWNode &node) const override { + bool splitNode(const MWNode &node) const override { int scale = node.getScale(); int nQuadPts = node.getKp1(); if (this->func->isVisibleAtScale(scale, nQuadPts)) return false; diff --git a/src/treebuilders/ConvolutionCalculator.cpp b/src/treebuilders/ConvolutionCalculator.cpp index 668c86dbf..7da95813b 100644 --- a/src/treebuilders/ConvolutionCalculator.cpp +++ b/src/treebuilders/ConvolutionCalculator.cpp @@ -46,8 +46,8 @@ using Eigen::MatrixXi; namespace mrcpp { -template -ConvolutionCalculator::ConvolutionCalculator(double p, ConvolutionOperator &o, FunctionTree &f, int depth) +template +ConvolutionCalculator::ConvolutionCalculator(double p, ConvolutionOperator &o, FunctionTree &f, int depth) : maxDepth(depth) , prec(p) , oper(&o) @@ -57,14 +57,14 @@ ConvolutionCalculator::ConvolutionCalculator(double p, ConvolutionOperator initTimers(); } -template ConvolutionCalculator::~ConvolutionCalculator() { +template ConvolutionCalculator::~ConvolutionCalculator() { clearTimers(); this->operStat.flushNodeCounters(); println(10, this->operStat); for (int i = 0; i < this->bandSizes.size(); i++) { delete this->bandSizes[i]; } } -template void ConvolutionCalculator::initTimers() { +template void ConvolutionCalculator::initTimers() { int nThreads = mrcpp_get_max_threads(); for (int i = 0; i < nThreads; i++) { this->band_t.push_back(new Timer(false)); @@ -73,7 +73,7 @@ template void ConvolutionCalculator::initTimers() { } } -template void ConvolutionCalculator::clearTimers() { +template void ConvolutionCalculator::clearTimers() { int nThreads = mrcpp_get_max_threads(); for (int i = 0; i < nThreads; i++) { delete this->band_t[i]; @@ -85,7 +85,7 @@ template void ConvolutionCalculator::clearTimers() { this->norm_t.clear(); } -template void ConvolutionCalculator::printTimers() const { +template void ConvolutionCalculator::printTimers() const { int oldprec = Printer::setPrecision(1); int nThreads = mrcpp_get_max_threads(); printout(20, "\n\nthread "); @@ -102,7 +102,7 @@ template void ConvolutionCalculator::printTimers() const { /** Initialize the number of nodes formally within the bandwidth of an operator. The band size is used for thresholding. */ -template void ConvolutionCalculator::initBandSizes() { +template void ConvolutionCalculator::initBandSizes() { for (int i = 0; i < this->oper->size(); i++) { // IMPORTANT: only 0-th dimension! const OperatorTree &oTree = this->oper->getComponent(i, 0); @@ -118,7 +118,7 @@ template void ConvolutionCalculator::initBandSizes() { * of an operator. Currently this routine ignores the fact that * there are edges on the world box, and thus over estimates * the number of nodes. This is different from the previous version. */ -template void ConvolutionCalculator::calcBandSizeFactor(MatrixXi &bs, int depth, const BandWidth &bw) { +template void ConvolutionCalculator::calcBandSizeFactor(MatrixXi &bs, int depth, const BandWidth &bw) { for (int gt = 0; gt < this->nComp; gt++) { for (int ft = 0; ft < this->nComp; ft++) { int k = gt * this->nComp + ft; @@ -139,8 +139,8 @@ template void ConvolutionCalculator::calcBandSizeFactor(MatrixXi &bs, } /** Return a vector of nodes in F affected by O, given a node in G */ -template MWNodeVector *ConvolutionCalculator::makeOperBand(const MWNode &gNode, std::vector> &idx_band) { - auto *band = new MWNodeVector; +template MWNodeVector *ConvolutionCalculator::makeOperBand(const MWNode &gNode, std::vector> &idx_band) { + auto *band = new MWNodeVector; int o_depth = gNode.getScale() - this->oper->getOperatorRoot(); int g_depth = gNode.getDepth(); @@ -150,7 +150,7 @@ template MWNodeVector *ConvolutionCalculator::makeOperBand(const M int reach = this->oper->getOperatorReach(); if (width >= 0) { - const NodeBox &fWorld = this->fTree->getRootBox(); + const NodeBox &fWorld = this->fTree->getRootBox(); const NodeIndex &cIdx = fWorld.getCornerIndex(); const NodeIndex &gIdx = gNode.getNodeIndex(); @@ -180,7 +180,7 @@ template MWNodeVector *ConvolutionCalculator::makeOperBand(const M } /** Recursively retrieve all reachable f-nodes within the bandwidth. */ -template void ConvolutionCalculator::fillOperBand(MWNodeVector *band, std::vector> &idx_band, NodeIndex &idx, const int *nbox, int dim) { +template void ConvolutionCalculator::fillOperBand(MWNodeVector *band, std::vector> &idx_band, NodeIndex &idx, const int *nbox, int dim) { int l_start = idx[dim]; for (int j = 0; j < nbox[dim]; j++) { // Recurse until dim == 0 @@ -190,7 +190,7 @@ template void ConvolutionCalculator::fillOperBand(MWNodeVector *ba continue; } if (not manipulateOperator) { - MWNode &fNode = this->fTree->getNode(idx); + MWNode &fNode = this->fTree->getNode(idx); idx_band.push_back(idx); band->push_back(&fNode); @@ -198,18 +198,18 @@ template void ConvolutionCalculator::fillOperBand(MWNodeVector *ba const auto oper_scale = this->oper->getOperatorRoot(); if (oper_scale == 0) { if (periodic::in_unit_cell(idx) and onUnitcell) { - MWNode &fNode = this->fTree->getNode(idx); + MWNode &fNode = this->fTree->getNode(idx); idx_band.push_back(idx); band->push_back(&fNode); } if (not periodic::in_unit_cell(idx) and not onUnitcell) { - MWNode &fNode = this->fTree->getNode(idx); + MWNode &fNode = this->fTree->getNode(idx); idx_band.push_back(idx); band->push_back(&fNode); } } else if (oper_scale < 0) { if (periodic::in_unit_cell(idx) and onUnitcell) { - MWNode &fNode = this->fTree->getNode(idx); + MWNode &fNode = this->fTree->getNode(idx); idx_band.push_back(idx); band->push_back(&fNode); } @@ -222,23 +222,23 @@ template void ConvolutionCalculator::fillOperBand(MWNodeVector *ba idx[dim] = l_start; } -template void ConvolutionCalculator::calcNode(MWNode &node) { - auto &gNode = static_cast &>(node); +template void ConvolutionCalculator::calcNode(MWNode &node) { + auto &gNode = static_cast &>(node); gNode.zeroCoefs(); int o_depth = gNode.getScale() - this->oper->getOperatorRoot(); if (manipulateOperator and this->oper->getOperatorRoot() < 0) o_depth = gNode.getDepth(); - double tmpCoefs[gNode.getNCoefs()]; - OperatorState os(gNode, tmpCoefs); + T tmpCoefs[gNode.getNCoefs()]; + OperatorState os(gNode, tmpCoefs); this->operStat.incrementGNodeCounters(gNode); // Get all nodes in f within the bandwith of O in g this->band_t[mrcpp_get_thread_num()]->resume(); std::vector> idx_band; - MWNodeVector *fBand = makeOperBand(gNode, idx_band); + MWNodeVector *fBand = makeOperBand(gNode, idx_band); this->band_t[mrcpp_get_thread_num()]->stop(); - MWTree &gTree = gNode.getMWTree(); + MWTree &gTree = gNode.getMWTree(); double gThrs = gTree.getSquareNorm(); if (gThrs > 0.0) { auto nTerms = static_cast(this->oper->size()); @@ -250,7 +250,7 @@ template void ConvolutionCalculator::calcNode(MWNode &node) { this->calc_t[mrcpp_get_thread_num()]->resume(); for (int n = 0; n < fBand->size(); n++) { - MWNode &fNode = *(*fBand)[n]; + MWNode &fNode = *(*fBand)[n]; NodeIndex &fIdx = idx_band[n]; os.setFNode(fNode); os.setFIndex(fIdx); @@ -275,7 +275,7 @@ template void ConvolutionCalculator::calcNode(MWNode &node) { } /** Apply each component (term) of the operator expansion to a node in f */ -template void ConvolutionCalculator::applyOperComp(OperatorState &os) { + template void ConvolutionCalculator::applyOperComp(OperatorState &os) { double fNorm = os.fNode->getComponentNorm(os.ft); int o_depth = os.fNode->getScale() - this->oper->getOperatorRoot(); for (int i = 0; i < this->oper->size(); i++) { @@ -288,7 +288,6 @@ template void ConvolutionCalculator::applyOperComp(OperatorState & } } - /** @brief Apply a single operator component (term) to a single f-node. * * @details Apply a single operator component (term) to a single f-node. @@ -296,9 +295,10 @@ template void ConvolutionCalculator::applyOperComp(OperatorState & * Here we make use of the sparcity of matrices \f$ A, B, C \f$. * */ -template void ConvolutionCalculator::applyOperator(int i, OperatorState &os) { - MWNode &gNode = *os.gNode; - MWNode &fNode = *os.fNode; + template void ConvolutionCalculator::applyOperator(int i, OperatorState &os) { + MWNode &gNode = *os.gNode; + MWNode &fNode = *os.fNode; + const NodeIndex &fIdx = *os.fIdx; const NodeIndex &gIdx = gNode.getNodeIndex(); int o_depth = gNode.getScale() - this->oper->getOperatorRoot(); @@ -331,9 +331,10 @@ template void ConvolutionCalculator::applyOperator(int i, OperatorSta /** Perorm the required linear algebra operations in order to apply an operator component to a f-node in a n-dimensional tesor space. */ -template void ConvolutionCalculator::tensorApplyOperComp(OperatorState &os) { - double **aux = os.getAuxData(); + template void ConvolutionCalculator::tensorApplyOperComp(OperatorState &os) { + T **aux = os.getAuxData(); double **oData = os.getOperData(); + /* #ifdef HAVE_BLAS double mult = 0.0; for (int i = 0; i < D; i++) { @@ -358,9 +359,10 @@ template void ConvolutionCalculator::tensorApplyOperComp(OperatorStat } } #else + */ for (int i = 0; i < D; i++) { - Eigen::Map f(aux[i], os.kp1, os.kp1_dm1); - Eigen::Map g(aux[i + 1], os.kp1_dm1, os.kp1); + Eigen::Map> f(aux[i], os.kp1, os.kp1_dm1); + Eigen::Map> g(aux[i + 1], os.kp1_dm1, os.kp1); if (oData[i] != nullptr) { Eigen::Map op(oData[i], os.kp1, os.kp1); if (i == D - 1) { // Last dir: Add up into g @@ -377,10 +379,10 @@ template void ConvolutionCalculator::tensorApplyOperComp(OperatorStat } } } -#endif + //#endif } -template void ConvolutionCalculator::touchParentNodes(MWTree &tree) const { +template void ConvolutionCalculator::touchParentNodes(MWTree &tree) const { if (not manipulateOperator) { const auto oper_scale = this->oper->getOperatorRoot(); auto car_prod = math_utils::cartesian_product(std::vector{-1, 0}, D); @@ -396,15 +398,19 @@ template void ConvolutionCalculator::touchParentNodes(MWTree &tree } } -template MWNodeVector *ConvolutionCalculator::getInitialWorkVector(MWTree &tree) const { - auto *nodeVec = new MWNodeVector; +template MWNodeVector *ConvolutionCalculator::getInitialWorkVector(MWTree &tree) const { + auto *nodeVec = new MWNodeVector; if (tree.isPeriodic()) touchParentNodes(tree); tree_utils::make_node_table(tree, *nodeVec); return nodeVec; } -template class ConvolutionCalculator<1>; -template class ConvolutionCalculator<2>; -template class ConvolutionCalculator<3>; +template class ConvolutionCalculator<1, double>; +template class ConvolutionCalculator<2, double>; +template class ConvolutionCalculator<3, double>; + +template class ConvolutionCalculator<1, ComplexDouble>; +template class ConvolutionCalculator<2, ComplexDouble>; +template class ConvolutionCalculator<3, ComplexDouble>; } // namespace mrcpp diff --git a/src/treebuilders/ConvolutionCalculator.h b/src/treebuilders/ConvolutionCalculator.h index 3b88cb9b1..f114ba976 100644 --- a/src/treebuilders/ConvolutionCalculator.h +++ b/src/treebuilders/ConvolutionCalculator.h @@ -33,12 +33,12 @@ namespace mrcpp { -template class ConvolutionCalculator final : public TreeCalculator { +template class ConvolutionCalculator final : public TreeCalculator { public: - ConvolutionCalculator(double p, ConvolutionOperator &o, FunctionTree &f, int depth = MaxDepth); + ConvolutionCalculator(double p, ConvolutionOperator &o, FunctionTree &f, int depth = MaxDepth); ~ConvolutionCalculator() override; - MWNodeVector *getInitialWorkVector(MWTree &tree) const override; + MWNodeVector *getInitialWorkVector(MWTree &tree) const override; void setPrecFunction(const std::function &idx)> &prec_func) { this->precFunc = prec_func; } void startManipulateOperator(bool excUnit) { @@ -52,45 +52,45 @@ template class ConvolutionCalculator final : public TreeCalculator { bool manipulateOperator{false}; bool onUnitcell{false}; ConvolutionOperator *oper; - FunctionTree *fTree; + FunctionTree *fTree; std::vector band_t; std::vector calc_t; std::vector norm_t; - OperatorStatistics operStat; + OperatorStatistics operStat; std::vector bandSizes; std::function &idx)> precFunc = [](const NodeIndex &idx) { return 1.0; }; static const int nComp = (1 << D); static const int nComp2 = (1 << D) * (1 << D); - MWNodeVector *makeOperBand(const MWNode &gNode, std::vector> &idx_band); - void fillOperBand(MWNodeVector *band, std::vector> &idx_band, NodeIndex &idx, const int *nbox, int dim); + MWNodeVector *makeOperBand(const MWNode &gNode, std::vector> &idx_band); + void fillOperBand(MWNodeVector *band, std::vector> &idx_band, NodeIndex &idx, const int *nbox, int dim); void initTimers(); void clearTimers(); void printTimers() const; void initBandSizes(); - int getBandSizeFactor(int i, int depth, const OperatorState &os) const { + int getBandSizeFactor(int i, int depth, const OperatorState &os) const { int k = os.gt * this->nComp + os.ft; return (*this->bandSizes[i])(depth, k); } void calcBandSizeFactor(Eigen::MatrixXi &bs, int depth, const BandWidth &bw); - void calcNode(MWNode &node) override; + void calcNode(MWNode &node) override; void postProcess() override { printTimers(); clearTimers(); initTimers(); } - void applyOperComp(OperatorState &os); - void applyOperator(int i, OperatorState &os); - void tensorApplyOperComp(OperatorState &os); + void applyOperComp(OperatorState &os); + void applyOperator(int i, OperatorState &os); + void tensorApplyOperComp(OperatorState &os); - void touchParentNodes(MWTree &tree) const; + void touchParentNodes(MWTree &tree) const; }; } // namespace mrcpp diff --git a/src/treebuilders/CopyAdaptor.cpp b/src/treebuilders/CopyAdaptor.cpp index 4017c6e5e..8312ebb0f 100644 --- a/src/treebuilders/CopyAdaptor.cpp +++ b/src/treebuilders/CopyAdaptor.cpp @@ -29,21 +29,21 @@ namespace mrcpp { -template -CopyAdaptor::CopyAdaptor(FunctionTree &t, int ms, int *bw) - : TreeAdaptor(ms) { +template +CopyAdaptor::CopyAdaptor(FunctionTree &t, int ms, int *bw) + : TreeAdaptor(ms) { setBandWidth(bw); tree_vec.push_back(std::make_tuple(1.0, &t)); } -template -CopyAdaptor::CopyAdaptor(FunctionTreeVector &t, int ms, int *bw) - : TreeAdaptor(ms) +template +CopyAdaptor::CopyAdaptor(FunctionTreeVector &t, int ms, int *bw) + : TreeAdaptor(ms) , tree_vec(t) { setBandWidth(bw); } -template void CopyAdaptor::setBandWidth(int *bw) { +template void CopyAdaptor::setBandWidth(int *bw) { for (int d = 0; d < D; d++) { if (bw != nullptr) { this->bandWidth[d] = bw[d]; @@ -53,7 +53,7 @@ template void CopyAdaptor::setBandWidth(int *bw) { } } -template bool CopyAdaptor::splitNode(const MWNode &node) const { +template bool CopyAdaptor::splitNode(const MWNode &node) const { const NodeIndex &idx = node.getNodeIndex(); for (int c = 0; c < node.getTDim(); c++) { for (int d = 0; d < D; d++) { @@ -61,8 +61,8 @@ template bool CopyAdaptor::splitNode(const MWNode &node) const { NodeIndex bwIdx = idx.child(c); bwIdx[d] += bw; for (int i = 0; i < this->tree_vec.size(); i++) { - const FunctionTree &func_i = get_func(tree_vec, i); - const MWNode *node_i = func_i.findNode(bwIdx); + const FunctionTree &func_i = get_func(tree_vec, i); + const MWNode *node_i = func_i.findNode(bwIdx); if (node_i != nullptr) return true; } } @@ -71,8 +71,12 @@ template bool CopyAdaptor::splitNode(const MWNode &node) const { return false; } -template class CopyAdaptor<1>; -template class CopyAdaptor<2>; -template class CopyAdaptor<3>; +template class CopyAdaptor<1, double>; +template class CopyAdaptor<2, double>; +template class CopyAdaptor<3, double>; + +template class CopyAdaptor<1, ComplexDouble>; +template class CopyAdaptor<2, ComplexDouble>; +template class CopyAdaptor<3, ComplexDouble>; } // namespace mrcpp diff --git a/src/treebuilders/CopyAdaptor.h b/src/treebuilders/CopyAdaptor.h index c9451e599..adeeb6766 100644 --- a/src/treebuilders/CopyAdaptor.h +++ b/src/treebuilders/CopyAdaptor.h @@ -30,17 +30,17 @@ namespace mrcpp { -template class CopyAdaptor final : public TreeAdaptor { + template class CopyAdaptor final : public TreeAdaptor { public: - CopyAdaptor(FunctionTree &t, int ms, int *bw); - CopyAdaptor(FunctionTreeVector &t, int ms, int *bw); + CopyAdaptor(FunctionTree &t, int ms, int *bw); + CopyAdaptor(FunctionTreeVector &t, int ms, int *bw); private: int bandWidth[D]; - FunctionTreeVector tree_vec; + FunctionTreeVector tree_vec; void setBandWidth(int *bw); - bool splitNode(const MWNode &node) const override; + bool splitNode(const MWNode &node) const override; }; } // namespace mrcpp diff --git a/src/treebuilders/CrossCorrelationCalculator.cpp b/src/treebuilders/CrossCorrelationCalculator.cpp index efe9a3390..b4c2fc3ad 100644 --- a/src/treebuilders/CrossCorrelationCalculator.cpp +++ b/src/treebuilders/CrossCorrelationCalculator.cpp @@ -77,7 +77,7 @@ template void CrossCorrelationCalculator::applyCcc(MWNode<2> &node, Cros const MWNode<1> &node_a = this->kernel->getNode(idx_a); const MWNode<1> &node_b = this->kernel->getNode(idx_b); - VectorXd vec_a; + Eigen::Matrix vec_a; VectorXd vec_b; node_a.getCoefs(vec_a); node_b.getCoefs(vec_b); diff --git a/src/treebuilders/DefaultCalculator.h b/src/treebuilders/DefaultCalculator.h index 13f698162..4a1a4ce54 100644 --- a/src/treebuilders/DefaultCalculator.h +++ b/src/treebuilders/DefaultCalculator.h @@ -29,16 +29,16 @@ namespace mrcpp { -template class DefaultCalculator final : public TreeCalculator { +template class DefaultCalculator final : public TreeCalculator { public: // Reimplementation without OpenMP, the default is faster this way - void calcNodeVector(MWNodeVector &nodeVec) override { + void calcNodeVector(MWNodeVector &nodeVec) override { int nNodes = nodeVec.size(); for (int n = 0; n < nNodes; n++) { calcNode(*nodeVec[n]); } } private: - void calcNode(MWNode &node) override { + void calcNode(MWNode &node) override { node.clearHasCoefs(); node.clearNorms(); } diff --git a/src/treebuilders/DerivativeCalculator.cpp b/src/treebuilders/DerivativeCalculator.cpp index a5acdc297..9f384013e 100644 --- a/src/treebuilders/DerivativeCalculator.cpp +++ b/src/treebuilders/DerivativeCalculator.cpp @@ -42,8 +42,8 @@ using Eigen::MatrixXd; namespace mrcpp { -template -DerivativeCalculator::DerivativeCalculator(int dir, DerivativeOperator &o, FunctionTree &f) +template +DerivativeCalculator::DerivativeCalculator(int dir, DerivativeOperator &o, FunctionTree &f) : applyDir(dir) , fTree(&f) , oper(&o) { @@ -51,12 +51,12 @@ DerivativeCalculator::DerivativeCalculator(int dir, DerivativeOperator &o, initTimers(); } -template DerivativeCalculator::~DerivativeCalculator() { +template DerivativeCalculator::~DerivativeCalculator() { this->operStat.flushNodeCounters(); println(10, this->operStat); } -template void DerivativeCalculator::initTimers() { +template void DerivativeCalculator::initTimers() { int nThreads = mrcpp_get_max_threads(); for (int i = 0; i < nThreads; i++) { this->band_t.push_back(Timer(false)); @@ -65,13 +65,13 @@ template void DerivativeCalculator::initTimers() { } } -template void DerivativeCalculator::clearTimers() { +template void DerivativeCalculator::clearTimers() { this->band_t.clear(); this->calc_t.clear(); this->norm_t.clear(); } -template void DerivativeCalculator::printTimers() const { +template void DerivativeCalculator::printTimers() const { int oldprec = Printer::setPrecision(1); int nThreads = mrcpp_get_max_threads(); printout(20, "\n\nthread "); @@ -86,12 +86,12 @@ template void DerivativeCalculator::printTimers() const { Printer::setPrecision(oldprec); } -template void DerivativeCalculator::calcNode(MWNode &inpNode, MWNode &outNode) { + template void DerivativeCalculator::calcNode(MWNode &inpNode, MWNode &outNode) { //if (this->oper->getMaxBandWidth() > 1) MSG_ABORT("Only implemented for zero bw"); outNode.zeroCoefs(); int nComp = (1 << D); - double tmpCoefs[outNode.getNCoefs()]; - OperatorState os(outNode, tmpCoefs); + T tmpCoefs[outNode.getNCoefs()]; + OperatorState os(outNode, tmpCoefs); os.setFNode(inpNode); os.setFIndex(inpNode.nodeIndex); @@ -114,24 +114,24 @@ template void DerivativeCalculator::calcNode(MWNode &inpNode, MWNo } -template void DerivativeCalculator::calcNode(MWNode &gNode) { +template void DerivativeCalculator::calcNode(MWNode &gNode) { gNode.zeroCoefs(); int nComp = (1 << D); - double tmpCoefs[gNode.getNCoefs()]; - OperatorState os(gNode, tmpCoefs); + T tmpCoefs[gNode.getNCoefs()]; + OperatorState os(gNode, tmpCoefs); this->operStat.incrementGNodeCounters(gNode); // Get all nodes in f within the bandwith of O in g this->band_t[mrcpp_get_thread_num()].resume(); std::vector> idx_band; - MWNodeVector fBand = makeOperBand(gNode, idx_band); + MWNodeVector fBand = makeOperBand(gNode, idx_band); this->band_t[mrcpp_get_thread_num()].stop(); this->calc_t[mrcpp_get_thread_num()].resume(); for (int n = 0; n < fBand.size(); n++) { - MWNode &fNode = *fBand[n]; + MWNode &fNode = *fBand[n]; NodeIndex &fIdx = idx_band[n]; os.setFNode(fNode); os.setFIndex(fIdx); @@ -157,12 +157,12 @@ template void DerivativeCalculator::calcNode(MWNode &gNode) { } /** Return a vector of nodes in F affected by O, given a node in G */ -template -MWNodeVector DerivativeCalculator::makeOperBand(const MWNode &gNode, std::vector> &idx_band) { +template +MWNodeVector DerivativeCalculator::makeOperBand(const MWNode &gNode, std::vector> &idx_band) { assert(this->applyDir >= 0); assert(this->applyDir < D); - MWNodeVector band; + MWNodeVector band; const NodeIndex &idx_0 = gNode.getNodeIndex(); // Assumes given width only in applyDir, otherwise width = 0 @@ -182,10 +182,10 @@ MWNodeVector DerivativeCalculator::makeOperBand(const MWNode &gNode, st } /** Apply a single operator component (term) to a single f-node assuming zero bandwidth */ -template void DerivativeCalculator::applyOperator_bw0(OperatorState &os) { +template void DerivativeCalculator::applyOperator_bw0(OperatorState &os) { //cout<<" applyOperator "< &gNode = *os.gNode; - MWNode &fNode = *os.fNode; + MWNode &gNode = *os.gNode; + MWNode &fNode = *os.fNode; const NodeIndex &fIdx = *os.fIdx; const NodeIndex &gIdx = gNode.getNodeIndex(); int depth = gNode.getDepth(); @@ -216,9 +216,9 @@ template void DerivativeCalculator::applyOperator_bw0(OperatorState void DerivativeCalculator::applyOperator(OperatorState &os) { - MWNode &gNode = *os.gNode; - MWNode &fNode = *os.fNode; +template void DerivativeCalculator::applyOperator(OperatorState &os) { + MWNode &gNode = *os.gNode; + MWNode &fNode = *os.fNode; const NodeIndex &fIdx = *os.fIdx; const NodeIndex &gIdx = gNode.getNodeIndex(); int depth = gNode.getDepth(); @@ -261,9 +261,10 @@ template void DerivativeCalculator::applyOperator(OperatorState &o /** Perform the required linear algebra operations in order to apply an operator component to a f-node in a n-dimensional tensor space. */ -template void DerivativeCalculator::tensorApplyOperComp(OperatorState &os) { - double **aux = os.getAuxData(); +template void DerivativeCalculator::tensorApplyOperComp(OperatorState &os) { + T **aux = os.getAuxData(); double **oData = os.getOperData(); + /* #ifdef HAVE_BLAS double mult = 0.0; for (int i = 0; i < D; i++) { @@ -271,8 +272,8 @@ template void DerivativeCalculator::tensorApplyOperComp(OperatorState if (i == D - 1) { // Last dir: Add up into g mult = 1.0; } - const double *f = aux[i]; - double *g = const_cast(aux[i + 1]); + const T *f = aux[i]; + T *g = const_cast(aux[i + 1]); cblas_dgemm(CblasColMajor, CblasTrans, CblasNoTrans, @@ -301,9 +302,10 @@ template void DerivativeCalculator::tensorApplyOperComp(OperatorState } } #else + */ for (int i = 0; i < D; i++) { - Eigen::Map f(aux[i], os.kp1, os.kp1_dm1); - Eigen::Map g(aux[i + 1], os.kp1_dm1, os.kp1); + Eigen::Map> f(aux[i], os.kp1, os.kp1_dm1); + Eigen::Map> g(aux[i + 1], os.kp1_dm1, os.kp1); if (oData[i] != nullptr) { Eigen::Map op(oData[i], os.kp1, os.kp1); if (i == D - 1) { // Last dir: Add up into g @@ -320,15 +322,19 @@ template void DerivativeCalculator::tensorApplyOperComp(OperatorState } } } -#endif + //#endif } -template MWNodeVector *DerivativeCalculator::getInitialWorkVector(MWTree &tree) const { +template MWNodeVector *DerivativeCalculator::getInitialWorkVector(MWTree &tree) const { return tree.copyEndNodeTable(); } -template class DerivativeCalculator<1>; -template class DerivativeCalculator<2>; -template class DerivativeCalculator<3>; +template class DerivativeCalculator<1, double>; +template class DerivativeCalculator<2, double>; +template class DerivativeCalculator<3, double>; + +template class DerivativeCalculator<1, ComplexDouble>; +template class DerivativeCalculator<2, ComplexDouble>; +template class DerivativeCalculator<3, ComplexDouble>; } // namespace mrcpp diff --git a/src/treebuilders/DerivativeCalculator.h b/src/treebuilders/DerivativeCalculator.h index 5d4d28716..9adc48046 100644 --- a/src/treebuilders/DerivativeCalculator.h +++ b/src/treebuilders/DerivativeCalculator.h @@ -30,40 +30,40 @@ namespace mrcpp { -template class DerivativeCalculator final : public TreeCalculator { +template class DerivativeCalculator final : public TreeCalculator { public: - DerivativeCalculator(int dir, DerivativeOperator &o, FunctionTree &f); + DerivativeCalculator(int dir, DerivativeOperator &o, FunctionTree &f); ~DerivativeCalculator() override; - MWNodeVector *getInitialWorkVector(MWTree &tree) const override; - void calcNode(MWNode &fNode, MWNode &gNode); + MWNodeVector *getInitialWorkVector(MWTree &tree) const override; + void calcNode(MWNode &fNode, MWNode &gNode); private: int applyDir; - FunctionTree *fTree; + FunctionTree *fTree; DerivativeOperator *oper; std::vector band_t; std::vector calc_t; std::vector norm_t; - OperatorStatistics operStat; + OperatorStatistics operStat; - MWNodeVector makeOperBand(const MWNode &gNode, std::vector> &idx_band); + MWNodeVector makeOperBand(const MWNode &gNode, std::vector> &idx_band); void initTimers(); void clearTimers(); void printTimers() const; - void calcNode(MWNode &node) override; + void calcNode(MWNode &node) override; void postProcess() override { printTimers(); clearTimers(); initTimers(); } - void applyOperator(OperatorState &os); - void applyOperator_bw0(OperatorState &os); - void tensorApplyOperComp(OperatorState &os); + void applyOperator(OperatorState &os); + void applyOperator_bw0(OperatorState &os); + void tensorApplyOperComp(OperatorState &os); }; } // namespace mrcpp diff --git a/src/treebuilders/MapCalculator.h b/src/treebuilders/MapCalculator.h index 492c1f440..33f799ee9 100644 --- a/src/treebuilders/MapCalculator.h +++ b/src/treebuilders/MapCalculator.h @@ -29,24 +29,24 @@ namespace mrcpp { -template class MapCalculator final : public TreeCalculator { +template class MapCalculator final : public TreeCalculator { public: - MapCalculator(FMap fm, FunctionTree &inp) + MapCalculator(FMap fm, FunctionTree &inp) : func(&inp) , fmap(std::move(fm)) {} private: - FunctionTree *func; - FMap fmap; - void calcNode(MWNode &node_o) override { + FunctionTree *func; + FMap fmap; + void calcNode(MWNode &node_o) override { const NodeIndex &idx = node_o.getNodeIndex(); int n_coefs = node_o.getNCoefs(); - double *coefs_o = node_o.getCoefs(); + T *coefs_o = node_o.getCoefs(); // This generates missing nodes - MWNode node_i = func->getNode(idx); // Copy node + MWNode node_i = func->getNode(idx); // Copy node node_i.mwTransform(Reconstruction); node_i.cvTransform(Forward); - const double *coefs_i = node_i.getCoefs(); + const T *coefs_i = node_i.getCoefs(); for (int j = 0; j < n_coefs; j++) { coefs_o[j] = fmap(coefs_i[j]); } node_o.cvTransform(Backward); node_o.mwTransform(Compression); diff --git a/src/treebuilders/MultiplicationAdaptor.h b/src/treebuilders/MultiplicationAdaptor.h index ff0fe992d..9637ac055 100644 --- a/src/treebuilders/MultiplicationAdaptor.h +++ b/src/treebuilders/MultiplicationAdaptor.h @@ -31,19 +31,19 @@ namespace mrcpp { -template class MultiplicationAdaptor : public TreeAdaptor { +template class MultiplicationAdaptor : public TreeAdaptor { public: - MultiplicationAdaptor(double pr, int ms, FunctionTreeVector &t) - : TreeAdaptor(ms) + MultiplicationAdaptor(double pr, int ms, FunctionTreeVector &t) + : TreeAdaptor(ms) , prec(pr) , trees(t) {} ~MultiplicationAdaptor() override = default; protected: double prec; - mutable FunctionTreeVector trees; + mutable FunctionTreeVector trees; - bool splitNode(const MWNode &node) const override { + bool splitNode(const MWNode &node) const override { if (this->trees.size() != 2) MSG_ERROR("Invalid tree vec size: " << this->trees.size()); auto &pNode0 = get_func(trees, 0).getNode(node.getNodeIndex()); auto &pNode1 = get_func(trees, 1).getNode(node.getNodeIndex()); diff --git a/src/treebuilders/MultiplicationCalculator.h b/src/treebuilders/MultiplicationCalculator.h index ba5669f4d..4f82756c2 100644 --- a/src/treebuilders/MultiplicationCalculator.h +++ b/src/treebuilders/MultiplicationCalculator.h @@ -30,26 +30,26 @@ namespace mrcpp { -template class MultiplicationCalculator final : public TreeCalculator { +template class MultiplicationCalculator final : public TreeCalculator { public: - MultiplicationCalculator(const FunctionTreeVector &inp) + MultiplicationCalculator(const FunctionTreeVector &inp) : prod_vec(inp) {} private: - FunctionTreeVector prod_vec; + FunctionTreeVector prod_vec; - void calcNode(MWNode &node_o) override { + void calcNode(MWNode &node_o) override { const NodeIndex &idx = node_o.getNodeIndex(); - double *coefs_o = node_o.getCoefs(); + T *coefs_o = node_o.getCoefs(); for (int j = 0; j < node_o.getNCoefs(); j++) { coefs_o[j] = 1.0; } for (int i = 0; i < this->prod_vec.size(); i++) { - double c_i = get_coef(this->prod_vec, i); - FunctionTree &func_i = get_func(this->prod_vec, i); + T c_i = get_coef(this->prod_vec, i); + FunctionTree &func_i = get_func(this->prod_vec, i); // This generates missing nodes - MWNode node_i = func_i.getNode(idx); // Copy node + MWNode node_i = func_i.getNode(idx); // Copy node node_i.mwTransform(Reconstruction); node_i.cvTransform(Forward); - const double *coefs_i = node_i.getCoefs(); + const T *coefs_i = node_i.getCoefs(); int n_coefs = node_i.getNCoefs(); for (int j = 0; j < n_coefs; j++) { coefs_o[j] *= c_i * coefs_i[j]; } } diff --git a/src/treebuilders/PowerCalculator.h b/src/treebuilders/PowerCalculator.h index bb2124b73..79147fc4b 100644 --- a/src/treebuilders/PowerCalculator.h +++ b/src/treebuilders/PowerCalculator.h @@ -29,25 +29,25 @@ namespace mrcpp { -template class PowerCalculator final : public TreeCalculator { +template class PowerCalculator final : public TreeCalculator { public: - PowerCalculator(FunctionTree &inp, double pow) + PowerCalculator(FunctionTree &inp, double pow) : power(pow) , func(&inp) {} private: double power; - FunctionTree *func; + FunctionTree *func; - void calcNode(MWNode &node_o) override { + void calcNode(MWNode &node_o) override { const NodeIndex &idx = node_o.getNodeIndex(); int n_coefs = node_o.getNCoefs(); - double *coefs_o = node_o.getCoefs(); + T *coefs_o = node_o.getCoefs(); // This generates missing nodes - MWNode node_i = func->getNode(idx); // Copy node + MWNode node_i = func->getNode(idx); // Copy node node_i.mwTransform(Reconstruction); node_i.cvTransform(Forward); - const double *coefs_i = node_i.getCoefs(); + const T *coefs_i = node_i.getCoefs(); for (int j = 0; j < n_coefs; j++) { coefs_o[j] = std::pow(coefs_i[j], this->power); } node_o.cvTransform(Backward); node_o.mwTransform(Compression); diff --git a/src/treebuilders/ProjectionCalculator.cpp b/src/treebuilders/ProjectionCalculator.cpp index 46335d092..e451ea69e 100644 --- a/src/treebuilders/ProjectionCalculator.cpp +++ b/src/treebuilders/ProjectionCalculator.cpp @@ -30,18 +30,19 @@ using Eigen::MatrixXd; namespace mrcpp { -template void ProjectionCalculator::calcNode(MWNode &node) { +template void ProjectionCalculator::calcNode(MWNode &node) { MatrixXd exp_pts; node.getExpandedChildPts(exp_pts); assert(exp_pts.cols() == node.getNCoefs()); Coord r; - double *coefs = node.getCoefs(); + T *coefs = node.getCoefs(); for (int i = 0; i < node.getNCoefs(); i++) { - for (int d = 0; d < D; d++) { r[d] = scaling_factor[d] * exp_pts(d, i); } - coefs[i] = this->func->evalf(r); + for (int d = 0; d < D; d++) { r[d] = scaling_factor[d] * exp_pts(d, i); } + coefs[i] = this->func->evalf(r); } + node.cvTransform(Backward); node.mwTransform(Compression); node.setHasCoefs(); @@ -50,7 +51,7 @@ template void ProjectionCalculator::calcNode(MWNode &node) { /* Old interpolating version, somewhat faster template -void ProjectionCalculator::calcNode(MWNode &node) { +void ProjectionCalculator::calcNode(MWNode &node) { const ScalingBasis &sf = node.getMWTree().getMRA().getScalingBasis(); if (sf.getScalingType() != Interpol) { NOT_IMPLEMENTED_ABORT; @@ -104,8 +105,12 @@ void ProjectionCalculator::calcNode(MWNode &node) { } */ -template class ProjectionCalculator<1>; -template class ProjectionCalculator<2>; -template class ProjectionCalculator<3>; +template class ProjectionCalculator<1, double>; +template class ProjectionCalculator<2, double>; +template class ProjectionCalculator<3, double>; + +template class ProjectionCalculator<1, ComplexDouble>; +template class ProjectionCalculator<2, ComplexDouble>; +template class ProjectionCalculator<3, ComplexDouble>; } // namespace mrcpp diff --git a/src/treebuilders/ProjectionCalculator.h b/src/treebuilders/ProjectionCalculator.h index 2fbbb09fe..067c41422 100644 --- a/src/treebuilders/ProjectionCalculator.h +++ b/src/treebuilders/ProjectionCalculator.h @@ -29,16 +29,16 @@ namespace mrcpp { -template class ProjectionCalculator final : public TreeCalculator { +template class ProjectionCalculator final : public TreeCalculator { public: - ProjectionCalculator(const RepresentableFunction &inp_func, const std::array &sf) + ProjectionCalculator(const RepresentableFunction &inp_func, const std::array &sf) : func(&inp_func) , scaling_factor(sf) {} private: - const RepresentableFunction *func; + const RepresentableFunction *func; const std::array scaling_factor; - void calcNode(MWNode &node) override; + void calcNode(MWNode &node) override; }; } // namespace mrcpp diff --git a/src/treebuilders/SplitAdaptor.h b/src/treebuilders/SplitAdaptor.h index b9d50fe8b..7e81bbe8b 100644 --- a/src/treebuilders/SplitAdaptor.h +++ b/src/treebuilders/SplitAdaptor.h @@ -29,16 +29,16 @@ namespace mrcpp { -template class SplitAdaptor final : public TreeAdaptor { +template class SplitAdaptor final : public TreeAdaptor { public: SplitAdaptor(int ms, bool sp) - : TreeAdaptor(ms) + : TreeAdaptor(ms) , split(sp) {} private: bool split; - bool splitNode(const MWNode &node) const override { return this->split; } + bool splitNode(const MWNode &node) const override { return this->split; } }; } // namespace mrcpp diff --git a/src/treebuilders/SquareCalculator.h b/src/treebuilders/SquareCalculator.h index e9bb0f8d3..e56b41cf0 100644 --- a/src/treebuilders/SquareCalculator.h +++ b/src/treebuilders/SquareCalculator.h @@ -29,23 +29,23 @@ namespace mrcpp { -template class SquareCalculator final : public TreeCalculator { +template class SquareCalculator final : public TreeCalculator { public: - SquareCalculator(FunctionTree &inp) + SquareCalculator(FunctionTree &inp) : func(&inp) {} private: - FunctionTree *func; + FunctionTree *func; - void calcNode(MWNode &node_o) override { + void calcNode(MWNode &node_o) override { const NodeIndex &idx = node_o.getNodeIndex(); int n_coefs = node_o.getNCoefs(); - double *coefs_o = node_o.getCoefs(); + T *coefs_o = node_o.getCoefs(); // This generates missing nodes - MWNode node_i = func->getNode(idx); // Copy node + MWNode node_i = func->getNode(idx); // Copy node node_i.mwTransform(Reconstruction); node_i.cvTransform(Forward); - const double *coefs_i = node_i.getCoefs(); + const T *coefs_i = node_i.getCoefs(); for (int j = 0; j < n_coefs; j++) { coefs_o[j] = coefs_i[j] * coefs_i[j]; } node_o.cvTransform(Backward); node_o.mwTransform(Compression); diff --git a/src/treebuilders/TreeAdaptor.h b/src/treebuilders/TreeAdaptor.h index a46bab648..80cecb09e 100644 --- a/src/treebuilders/TreeAdaptor.h +++ b/src/treebuilders/TreeAdaptor.h @@ -30,7 +30,7 @@ namespace mrcpp { -template class TreeAdaptor { +template class TreeAdaptor { public: TreeAdaptor(int ms) : maxScale(ms) {} @@ -38,9 +38,9 @@ template class TreeAdaptor { void setMaxScale(int ms) { this->maxScale = ms; } - void splitNodeVector(MWNodeVector &out, MWNodeVector &inp) const { + void splitNodeVector(MWNodeVector &out, MWNodeVector &inp) const { for (int n = 0; n < inp.size(); n++) { - MWNode &node = *inp[n]; + MWNode &node = *inp[n]; // Can be BranchNode in operator application if (node.isBranchNode()) continue; if (node.getScale() + 2 > this->maxScale) continue; @@ -54,7 +54,7 @@ template class TreeAdaptor { protected: int maxScale; - virtual bool splitNode(const MWNode &node) const = 0; + virtual bool splitNode(const MWNode &node) const = 0; }; } // namespace mrcpp diff --git a/src/treebuilders/TreeBuilder.cpp b/src/treebuilders/TreeBuilder.cpp index 223d94794..225b55cb5 100644 --- a/src/treebuilders/TreeBuilder.cpp +++ b/src/treebuilders/TreeBuilder.cpp @@ -35,13 +35,13 @@ namespace mrcpp { -template -void TreeBuilder::build(MWTree &tree, TreeCalculator &calculator, TreeAdaptor &adaptor, int maxIter) const { +template +void TreeBuilder::build(MWTree &tree, TreeCalculator &calculator, TreeAdaptor &adaptor, int maxIter) const { Timer calc_t(false), split_t(false), norm_t(false); println(10, " == Building tree"); - MWNodeVector *newVec = nullptr; - MWNodeVector *workVec = calculator.getInitialWorkVector(tree); + MWNodeVector *newVec = nullptr; + MWNodeVector *workVec = calculator.getInitialWorkVector(tree); double sNorm = 0.0; double wNorm = 0.0; @@ -69,7 +69,7 @@ void TreeBuilder::build(MWTree &tree, TreeCalculator &calculator, TreeA norm_t.stop(); split_t.resume(); - newVec = new MWNodeVector; + newVec = new MWNodeVector; if (iter >= maxIter and maxIter >= 0) workVec->clear(); adaptor.splitNodeVector(*newVec, *workVec); split_t.stop(); @@ -87,11 +87,11 @@ void TreeBuilder::build(MWTree &tree, TreeCalculator &calculator, TreeA print::time(10, "Time split", split_t); } -template void TreeBuilder::clear(MWTree &tree, TreeCalculator &calculator) const { +template void TreeBuilder::clear(MWTree &tree, TreeCalculator &calculator) const { println(10, " == Clearing tree"); Timer clean_t; - MWNodeVector nodeVec; + MWNodeVector nodeVec; tree_utils::make_node_table(tree, nodeVec); calculator.calcNodeVector(nodeVec); // clear all coefficients clean_t.stop(); @@ -104,16 +104,16 @@ template void TreeBuilder::clear(MWTree &tree, TreeCalculator & print::separator(10, ' '); } -template int TreeBuilder::split(MWTree &tree, TreeAdaptor &adaptor, bool passCoefs) const { +template int TreeBuilder::split(MWTree &tree, TreeAdaptor &adaptor, bool passCoefs) const { println(10, " == Refining tree"); Timer split_t; - MWNodeVector newVec; - MWNodeVector *workVec = tree.copyEndNodeTable(); + MWNodeVector newVec; + MWNodeVector *workVec = tree.copyEndNodeTable(); adaptor.splitNodeVector(newVec, *workVec); if (passCoefs) { for (int i = 0; i < workVec->size(); i++) { - MWNode &node = *(*workVec)[i]; + MWNode &node = *(*workVec)[i]; if (node.isBranchNode()) { node.giveChildrenCoefs(true); } } } @@ -131,11 +131,11 @@ template int TreeBuilder::split(MWTree &tree, TreeAdaptor &adap return newVec.size(); } -template void TreeBuilder::calc(MWTree &tree, TreeCalculator &calculator) const { +template void TreeBuilder::calc(MWTree &tree, TreeCalculator &calculator) const { println(10, " == Calculating tree"); Timer calc_t; - MWNodeVector *workVec = calculator.getInitialWorkVector(tree); + MWNodeVector *workVec = calculator.getInitialWorkVector(tree); calculator.calcNodeVector(*workVec); printout(10, " -- #" << std::setw(3) << 0 << ": Calculated "); printout(10, std::setw(6) << workVec->size() << " nodes "); @@ -148,26 +148,31 @@ template void TreeBuilder::calc(MWTree &tree, TreeCalculator &c print::time(10, "Time calc", calc_t); } -template double TreeBuilder::calcScalingNorm(const MWNodeVector &vec) const { +template double TreeBuilder::calcScalingNorm(const MWNodeVector &vec) const { double sNorm = 0.0; for (int i = 0; i < vec.size(); i++) { - const MWNode &node = *vec[i]; + const MWNode &node = *vec[i]; if (node.getDepth() >= 0) sNorm += node.getScalingNorm(); } return sNorm; } -template double TreeBuilder::calcWaveletNorm(const MWNodeVector &vec) const { +template double TreeBuilder::calcWaveletNorm(const MWNodeVector &vec) const { double wNorm = 0.0; for (int i = 0; i < vec.size(); i++) { - const MWNode &node = *vec[i]; + const MWNode &node = *vec[i]; if (node.getDepth() >= 0) wNorm += node.getWaveletNorm(); } return wNorm; } -template class TreeBuilder<1>; -template class TreeBuilder<2>; -template class TreeBuilder<3>; +template class TreeBuilder<1, double>; +template class TreeBuilder<2, double>; +template class TreeBuilder<3, double>; + + +template class TreeBuilder<1, ComplexDouble>; +template class TreeBuilder<2, ComplexDouble>; +template class TreeBuilder<3, ComplexDouble>; } // namespace mrcpp diff --git a/src/treebuilders/TreeBuilder.h b/src/treebuilders/TreeBuilder.h index 313e9f4f4..81c32afe6 100644 --- a/src/treebuilders/TreeBuilder.h +++ b/src/treebuilders/TreeBuilder.h @@ -29,16 +29,16 @@ namespace mrcpp { -template class TreeBuilder final { +template class TreeBuilder final { public: - void build(MWTree &tree, TreeCalculator &calculator, TreeAdaptor &adaptor, int maxIter) const; - void clear(MWTree &tree, TreeCalculator &calculator) const; - void calc(MWTree &tree, TreeCalculator &calculator) const; - int split(MWTree &tree, TreeAdaptor &adaptor, bool passCoefs) const; + void build(MWTree &tree, TreeCalculator &calculator, TreeAdaptor &adaptor, int maxIter) const; + void clear(MWTree &tree, TreeCalculator &calculator) const; + void calc(MWTree &tree, TreeCalculator &calculator) const; + int split(MWTree &tree, TreeAdaptor &adaptor, bool passCoefs) const; private: - double calcScalingNorm(const MWNodeVector &vec) const; - double calcWaveletNorm(const MWNodeVector &vec) const; + double calcScalingNorm(const MWNodeVector &vec) const; + double calcWaveletNorm(const MWNodeVector &vec) const; }; } // namespace mrcpp diff --git a/src/treebuilders/TreeCalculator.h b/src/treebuilders/TreeCalculator.h index acd9f00f8..4e171d91c 100644 --- a/src/treebuilders/TreeCalculator.h +++ b/src/treebuilders/TreeCalculator.h @@ -29,20 +29,20 @@ namespace mrcpp { -template class TreeCalculator { + template class TreeCalculator { public: TreeCalculator() = default; virtual ~TreeCalculator() = default; - virtual MWNodeVector *getInitialWorkVector(MWTree &tree) const { return tree.copyEndNodeTable(); } + virtual MWNodeVector *getInitialWorkVector(MWTree &tree) const { return tree.copyEndNodeTable(); } - virtual void calcNodeVector(MWNodeVector &nodeVec) { + virtual void calcNodeVector(MWNodeVector &nodeVec) { #pragma omp parallel shared(nodeVec) num_threads(mrcpp_get_num_threads()) { int nNodes = nodeVec.size(); #pragma omp for schedule(guided) for (int n = 0; n < nNodes; n++) { - MWNode &node = *nodeVec[n]; + MWNode &node = *nodeVec[n]; calcNode(node); } } @@ -50,7 +50,7 @@ template class TreeCalculator { } protected: - virtual void calcNode(MWNode &node) = 0; + virtual void calcNode(MWNode &node) = 0; virtual void postProcess() {} }; diff --git a/src/treebuilders/WaveletAdaptor.h b/src/treebuilders/WaveletAdaptor.h index 15da130e1..759f6b7ee 100644 --- a/src/treebuilders/WaveletAdaptor.h +++ b/src/treebuilders/WaveletAdaptor.h @@ -31,10 +31,10 @@ namespace mrcpp { -template class WaveletAdaptor : public TreeAdaptor { + template class WaveletAdaptor : public TreeAdaptor { public: WaveletAdaptor(double pr, int ms, bool ap = false, double sf = 1.0) - : TreeAdaptor(ms) + : TreeAdaptor(ms) , absPrec(ap) , prec(pr) , splitFac(sf) {} @@ -50,7 +50,7 @@ template class WaveletAdaptor : public TreeAdaptor { double splitFac; std::function &idx)> precFunc = [](const NodeIndex &idx) { return 1.0; }; - bool splitNode(const MWNode &node) const override { + bool splitNode(const MWNode &node) const override { auto precFac = this->precFunc(node.getNodeIndex()); // returns 1.0 by default return tree_utils::split_check(node, this->prec * precFac, this->splitFac, this->absPrec); } diff --git a/src/treebuilders/add.cpp b/src/treebuilders/add.cpp index 86b7f30a7..584e61e68 100644 --- a/src/treebuilders/add.cpp +++ b/src/treebuilders/add.cpp @@ -61,16 +61,16 @@ namespace mrcpp { * no coefs). * */ -template +template void add(double prec, - FunctionTree &out, - double a, - FunctionTree &inp_a, - double b, - FunctionTree &inp_b, + FunctionTree &out, + T a, + FunctionTree &inp_a, + T b, + FunctionTree &inp_b, int maxIter, bool absPrec) { - FunctionTreeVector tmp_vec; + FunctionTreeVector tmp_vec; tmp_vec.push_back(std::make_tuple(a, &inp_a)); tmp_vec.push_back(std::make_tuple(b, &inp_b)); add(prec, out, tmp_vec, maxIter, absPrec); @@ -98,14 +98,14 @@ void add(double prec, * no coefs). * */ -template void add(double prec, FunctionTree &out, FunctionTreeVector &inp, int maxIter, bool absPrec) { +template void add(double prec, FunctionTree &out, FunctionTreeVector &inp, int maxIter, bool absPrec) { for (auto i = 0; i < inp.size(); i++) if (out.getMRA() != get_func(inp, i).getMRA()) MSG_ABORT("Incompatible MRA"); int maxScale = out.getMRA().getMaxScale(); - TreeBuilder builder; - WaveletAdaptor adaptor(prec, maxScale, absPrec); - AdditionCalculator calculator(inp); + TreeBuilder builder; + WaveletAdaptor adaptor(prec, maxScale, absPrec); + AdditionCalculator calculator(inp); builder.build(out, calculator, adaptor, maxIter); @@ -116,7 +116,7 @@ template void add(double prec, FunctionTree &out, FunctionTreeVector< Timer clean_t; for (int i = 0; i < inp.size(); i++) { - FunctionTree &tree = get_func(inp, i); + FunctionTree &tree = get_func(inp, i); tree.deleteGenerated(); } clean_t.stop(); @@ -126,66 +126,124 @@ template void add(double prec, FunctionTree &out, FunctionTreeVector< print::separator(10, ' '); } -template void add(double prec, FunctionTree &out, std::vector *> &inp, int maxIter, bool absPrec) { - FunctionTreeVector inp_vec; +template void add(double prec, FunctionTree &out, std::vector *> &inp, int maxIter, bool absPrec) { + FunctionTreeVector inp_vec; for (auto &t : inp) inp_vec.push_back({1.0, t}); add(prec, out, inp_vec, maxIter, absPrec); } -template void add<1>(double prec, - FunctionTree<1> &out, +template void add<1, double>(double prec, + FunctionTree<1, double> &out, double a, - FunctionTree<1> &tree_a, + FunctionTree<1, double> &tree_a, double b, - FunctionTree<1> &tree_b, + FunctionTree<1, double> &tree_b, int maxIter, bool absPrec); -template void add<2>(double prec, - FunctionTree<2> &out, +template void add<2, double>(double prec, + FunctionTree<2, double> &out, double a, - FunctionTree<2> &tree_a, + FunctionTree<2, double> &tree_a, double b, - FunctionTree<2> &tree_b, + FunctionTree<2, double> &tree_b, int maxIter, bool absPrec); -template void add<3>(double prec, - FunctionTree<3> &out, +template void add<3, double>(double prec, + FunctionTree<3, double> &out, double a, - FunctionTree<3> &tree_a, + FunctionTree<3, double> &tree_a, double b, - FunctionTree<3> &tree_b, + FunctionTree<3, double> &tree_b, int maxIter, bool absPrec); -template void add<1>(double prec, - FunctionTree<1> &out, - FunctionTreeVector<1> &inp, +template void add<1, double>(double prec, + FunctionTree<1, double> &out, + FunctionTreeVector<1, double> &inp, int maxIter, bool absPrec); -template void add<2>(double prec, - FunctionTree<2> &out, - FunctionTreeVector<2> &inp, +template void add<2, double>(double prec, + FunctionTree<2, double> &out, + FunctionTreeVector<2, double> &inp, int maxIter, bool absPrec); -template void add<3>(double prec, - FunctionTree<3> &out, - FunctionTreeVector<3> &inp, +template void add<3, double>(double prec, + FunctionTree<3, double> &out, + FunctionTreeVector<3, double> &inp, int maxIter, bool absPrec); -template void add<1>(double prec, - FunctionTree<1> &out, - std::vector *> &inp, +template void add<1, double>(double prec, + FunctionTree<1, double> &out, + std::vector *> &inp, int maxIter, bool absPrec); -template void add<2>(double prec, - FunctionTree<2> &out, - std::vector *> &inp, +template void add<2, double>(double prec, + FunctionTree<2, double> &out, + std::vector *> &inp, int maxIter, bool absPrec); -template void add<3>(double prec, - FunctionTree<3> &out, - std::vector *> &inp, +template void add<3, double>(double prec, + FunctionTree<3, double> &out, + std::vector *> &inp, + int maxIter, + bool absPrec); + + +template void add<1, ComplexDouble>(double prec, + FunctionTree<1, ComplexDouble> &out, + ComplexDouble a, + FunctionTree<1, ComplexDouble> &tree_a, + ComplexDouble b, + FunctionTree<1, ComplexDouble> &tree_b, + int maxIter, + bool absPrec); +template void add<2, ComplexDouble>(double prec, + FunctionTree<2, ComplexDouble> &out, + ComplexDouble a, + FunctionTree<2, ComplexDouble> &tree_a, + ComplexDouble b, + FunctionTree<2, ComplexDouble> &tree_b, + int maxIter, + bool absPrec); +template void add<3, ComplexDouble>(double prec, + FunctionTree<3, ComplexDouble> &out, + ComplexDouble a, + FunctionTree<3, ComplexDouble> &tree_a, + ComplexDouble b, + FunctionTree<3, ComplexDouble> &tree_b, + int maxIter, + bool absPrec); + +template void add<1, ComplexDouble>(double prec, + FunctionTree<1, ComplexDouble> &out, + FunctionTreeVector<1, ComplexDouble> &inp, + int maxIter, + bool absPrec); +template void add<2, ComplexDouble>(double prec, + FunctionTree<2, ComplexDouble> &out, + FunctionTreeVector<2, ComplexDouble> &inp, + int maxIter, + bool absPrec); +template void add<3, ComplexDouble>(double prec, + FunctionTree<3, ComplexDouble> &out, + FunctionTreeVector<3, ComplexDouble> &inp, + int maxIter, + bool absPrec); + +template void add<1, ComplexDouble>(double prec, + FunctionTree<1, ComplexDouble> &out, + std::vector *> &inp, + int maxIter, + bool absPrec); +template void add<2, ComplexDouble>(double prec, + FunctionTree<2, ComplexDouble> &out, + std::vector *> &inp, + int maxIter, + bool absPrec); +template void add<3, ComplexDouble>(double prec, + FunctionTree<3, ComplexDouble> &out, + std::vector *> &inp, int maxIter, bool absPrec); diff --git a/src/treebuilders/add.h b/src/treebuilders/add.h index 105245829..dae1b366a 100644 --- a/src/treebuilders/add.h +++ b/src/treebuilders/add.h @@ -28,22 +28,22 @@ namespace mrcpp { -template void add(double prec, - FunctionTree &out, - double a, - FunctionTree &tree_a, - double b, - FunctionTree &tree_b, +template void add(double prec, + FunctionTree &out, + T a, + FunctionTree &tree_a, + T b, + FunctionTree &tree_b, int maxIter = -1, bool absPrec = false); -template void add(double prec, - FunctionTree &out, - FunctionTreeVector &inp, +template void add(double prec, + FunctionTree &out, + FunctionTreeVector &inp, int maxIter = -1, bool absPrec = false); -template void add(double prec, - FunctionTree &out, - std::vector *> &inp, +template void add(double prec, + FunctionTree &out, + std::vector *> &inp, int maxIter = -1, bool absPrec = false); diff --git a/src/treebuilders/apply.cpp b/src/treebuilders/apply.cpp index 3dc49de3c..4a072f694 100644 --- a/src/treebuilders/apply.cpp +++ b/src/treebuilders/apply.cpp @@ -41,7 +41,7 @@ namespace mrcpp { -template void apply_on_unit_cell(bool inside, double prec, FunctionTree &out, ConvolutionOperator &oper, FunctionTree &inp, int maxIter, bool absPrec); +template void apply_on_unit_cell(bool inside, double prec, FunctionTree &out, ConvolutionOperator &oper, FunctionTree &inp, int maxIter, bool absPrec); /** @brief Application of MW integral convolution operator * @@ -64,16 +64,16 @@ template void apply_on_unit_cell(bool inside, double prec, FunctionTree< * no coefs). * */ -template void apply(double prec, FunctionTree &out, ConvolutionOperator &oper, FunctionTree &inp, int maxIter, bool absPrec) { +template void apply(double prec, FunctionTree &out, ConvolutionOperator &oper, FunctionTree &inp, int maxIter, bool absPrec) { if (out.getMRA() != inp.getMRA()) MSG_ABORT("Incompatible MRA"); Timer pre_t; oper.calcBandWidths(prec); int maxScale = out.getMRA().getMaxScale(); - WaveletAdaptor adaptor(prec, maxScale, absPrec); - ConvolutionCalculator calculator(prec, oper, inp); + WaveletAdaptor adaptor(prec, maxScale, absPrec); + ConvolutionCalculator calculator(prec, oper, inp); pre_t.stop(); - TreeBuilder builder; + TreeBuilder builder; builder.build(out, calculator, adaptor, maxIter); Timer post_t; @@ -113,18 +113,18 @@ template void apply(double prec, FunctionTree &out, ConvolutionOperat * no coefs). * */ -template void apply_on_unit_cell(bool inside, double prec, FunctionTree &out, ConvolutionOperator &oper, FunctionTree &inp, int maxIter, bool absPrec) { +template void apply_on_unit_cell(bool inside, double prec, FunctionTree &out, ConvolutionOperator &oper, FunctionTree &inp, int maxIter, bool absPrec) { if (out.getMRA() != inp.getMRA()) MSG_ABORT("Incompatible MRA"); Timer pre_t; oper.calcBandWidths(prec); int maxScale = out.getMRA().getMaxScale(); - WaveletAdaptor adaptor(prec, maxScale, absPrec); - ConvolutionCalculator calculator(prec, oper, inp); + WaveletAdaptor adaptor(prec, maxScale, absPrec); + ConvolutionCalculator calculator(prec, oper, inp); calculator.startManipulateOperator(inside); pre_t.stop(); - TreeBuilder builder; + TreeBuilder builder; builder.build(out, calculator, adaptor, maxIter); Timer post_t; @@ -166,7 +166,7 @@ template void apply_on_unit_cell(bool inside, double prec, FunctionTree< * no coefs). * */ -template void apply(double prec, FunctionTree &out, ConvolutionOperator &oper, FunctionTree &inp, FunctionTreeVector &precTrees, int maxIter, bool absPrec) { +template void apply(double prec, FunctionTree &out, ConvolutionOperator &oper, FunctionTree &inp, FunctionTreeVector &precTrees, int maxIter, bool absPrec) { Timer pre_t; oper.calcBandWidths(prec); int maxScale = out.getMRA().getMaxScale(); @@ -183,13 +183,13 @@ template void apply(double prec, FunctionTree &out, ConvolutionOperat return 1.0 / maxNorm; }; - WaveletAdaptor adaptor(prec, maxScale, absPrec); + WaveletAdaptor adaptor(prec, maxScale, absPrec); adaptor.setPrecFunction(precFunc); - ConvolutionCalculator calculator(prec, oper, inp); + ConvolutionCalculator calculator(prec, oper, inp); calculator.setPrecFunction(precFunc); pre_t.stop(); - TreeBuilder builder; + TreeBuilder builder; builder.build(out, calculator, adaptor, maxIter); Timer post_t; @@ -227,7 +227,7 @@ template void apply(double prec, FunctionTree &out, ConvolutionOperat * no coefs). * */ -template void apply_far_field(double prec, FunctionTree &out, ConvolutionOperator &oper, FunctionTree &inp, int maxIter, bool absPrec) { +template void apply_far_field(double prec, FunctionTree &out, ConvolutionOperator &oper, FunctionTree &inp, int maxIter, bool absPrec) { apply_on_unit_cell(false, prec, out, oper, inp, maxIter, absPrec); } @@ -253,7 +253,7 @@ template void apply_far_field(double prec, FunctionTree &out, Convolu * no coefs). * */ -template void apply_near_field(double prec, FunctionTree &out, ConvolutionOperator &oper, FunctionTree &inp, int maxIter, bool absPrec) { +template void apply_near_field(double prec, FunctionTree &out, ConvolutionOperator &oper, FunctionTree &inp, int maxIter, bool absPrec) { apply_on_unit_cell(true, prec, out, oper, inp, maxIter, absPrec); } @@ -273,9 +273,9 @@ template void apply_near_field(double prec, FunctionTree &out, Convol * @note The output function should contain only empty root nodes at entry. * */ -template void apply(FunctionTree &out, DerivativeOperator &oper, FunctionTree &inp, int dir) { +template void apply(FunctionTree &out, DerivativeOperator &oper, FunctionTree &inp, int dir) { if (out.getMRA() != inp.getMRA()) MSG_ABORT("Incompatible MRA"); - TreeBuilder builder; + TreeBuilder builder; int maxScale = out.getMRA().getMaxScale(); int bw[D]; // Operator bandwidth in [x,y,z] @@ -285,14 +285,14 @@ template void apply(FunctionTree &out, DerivativeOperator &oper, F Timer pre_t; oper.calcBandWidths(1.0); // Fixed 0 or 1 for derivatives bw[dir] = oper.getMaxBandWidth(); - CopyAdaptor pre_adaptor(inp, maxScale, bw); - DefaultCalculator pre_calculator; + CopyAdaptor pre_adaptor(inp, maxScale, bw); + DefaultCalculator pre_calculator; builder.build(out, pre_calculator, pre_adaptor, -1); pre_t.stop(); // Apply operator on fixed expanded grid - SplitAdaptor apply_adaptor(maxScale, false); // Splits no nodes - DerivativeCalculator apply_calculator(dir, oper, inp); + SplitAdaptor apply_adaptor(maxScale, false); // Splits no nodes + DerivativeCalculator apply_calculator(dir, oper, inp); builder.build(out, apply_calculator, apply_adaptor, 0); if (out.isPeriodic()) out.rescale(std::pow(2.0, -oper.getOperatorRoot())); @@ -320,10 +320,10 @@ template void apply(FunctionTree &out, DerivativeOperator &oper, F * @note The length of the output vector will be the template dimension D. * */ -template FunctionTreeVector gradient(DerivativeOperator &oper, FunctionTree &inp) { - FunctionTreeVector out; +template FunctionTreeVector gradient(DerivativeOperator &oper, FunctionTree &inp) { + FunctionTreeVector out; for (int d = 0; d < D; d++) { - auto *grad_d = new FunctionTree(inp.getMRA()); + auto *grad_d = new FunctionTree(inp.getMRA()); apply(*grad_d, oper, inp, d); out.push_back({1.0, grad_d}); } @@ -346,16 +346,16 @@ template FunctionTreeVector gradient(DerivativeOperator &oper, Fun * - The output function should contain only empty root nodes at entry. * */ -template void divergence(FunctionTree &out, DerivativeOperator &oper, FunctionTreeVector &inp) { +template void divergence(FunctionTree &out, DerivativeOperator &oper, FunctionTreeVector &inp) { if (inp.size() != D) MSG_ABORT("Dimension mismatch"); for (auto i = 0; i < inp.size(); i++) if (out.getMRA() != get_func(inp, i).getMRA()) MSG_ABORT("Incompatible MRA"); - FunctionTreeVector tmp_vec; + FunctionTreeVector tmp_vec; for (int d = 0; d < D; d++) { - double coef_d = get_coef(inp, d); - FunctionTree &func_d = get_func(inp, d); - auto *out_d = new FunctionTree(func_d.getMRA()); + T coef_d = get_coef(inp, d); + FunctionTree &func_d = get_func(inp, d); + auto *out_d = new FunctionTree(func_d.getMRA()); apply(*out_d, oper, func_d, d); tmp_vec.push_back(std::make_tuple(coef_d, out_d)); } @@ -364,35 +364,62 @@ template void divergence(FunctionTree &out, DerivativeOperator &op clear(tmp_vec, true); } -template void divergence(FunctionTree &out, DerivativeOperator &oper, std::vector *> &inp) { - FunctionTreeVector inp_vec; +template void divergence(FunctionTree &out, DerivativeOperator &oper, std::vector *> &inp) { + FunctionTreeVector inp_vec; for (auto &t : inp) inp_vec.push_back({1.0, t}); divergence(out, oper, inp_vec); } -template void apply<1>(double prec, FunctionTree<1> &out, ConvolutionOperator<1> &oper, FunctionTree<1> &inp, int maxIter, bool absPrec); -template void apply<2>(double prec, FunctionTree<2> &out, ConvolutionOperator<2> &oper, FunctionTree<2> &inp, int maxIter, bool absPrec); -template void apply<3>(double prec, FunctionTree<3> &out, ConvolutionOperator<3> &oper, FunctionTree<3> &inp, int maxIter, bool absPrec); -template void apply<1>(double prec, FunctionTree<1> &out, ConvolutionOperator<1> &oper, FunctionTree<1> &inp, FunctionTreeVector<1> &precTrees, int maxIter, bool absPrec); -template void apply<2>(double prec, FunctionTree<2> &out, ConvolutionOperator<2> &oper, FunctionTree<2> &inp, FunctionTreeVector<2> &precTrees, int maxIter, bool absPrec); -template void apply<3>(double prec, FunctionTree<3> &out, ConvolutionOperator<3> &oper, FunctionTree<3> &inp, FunctionTreeVector<3> &precTrees, int maxIter, bool absPrec); -template void apply_far_field<1>(double prec, FunctionTree<1> &out, ConvolutionOperator<1> &oper, FunctionTree<1> &inp, int maxIter, bool absPrec); -template void apply_far_field<2>(double prec, FunctionTree<2> &out, ConvolutionOperator<2> &oper, FunctionTree<2> &inp, int maxIter, bool absPrec); -template void apply_far_field<3>(double prec, FunctionTree<3> &out, ConvolutionOperator<3> &oper, FunctionTree<3> &inp, int maxIter, bool absPrec); -template void apply_near_field<1>(double prec, FunctionTree<1> &out, ConvolutionOperator<1> &oper, FunctionTree<1> &inp, int maxIter, bool absPrec); -template void apply_near_field<2>(double prec, FunctionTree<2> &out, ConvolutionOperator<2> &oper, FunctionTree<2> &inp, int maxIter, bool absPrec); -template void apply_near_field<3>(double prec, FunctionTree<3> &out, ConvolutionOperator<3> &oper, FunctionTree<3> &inp, int maxIter, bool absPrec); -template void apply<1>(FunctionTree<1> &out, DerivativeOperator<1> &oper, FunctionTree<1> &inp, int dir); -template void apply<2>(FunctionTree<2> &out, DerivativeOperator<2> &oper, FunctionTree<2> &inp, int dir); -template void apply<3>(FunctionTree<3> &out, DerivativeOperator<3> &oper, FunctionTree<3> &inp, int dir); -template void divergence<1>(FunctionTree<1> &out, DerivativeOperator<1> &oper, FunctionTreeVector<1> &inp); -template void divergence<2>(FunctionTree<2> &out, DerivativeOperator<2> &oper, FunctionTreeVector<2> &inp); -template void divergence<3>(FunctionTree<3> &out, DerivativeOperator<3> &oper, FunctionTreeVector<3> &inp); -template void divergence<1>(FunctionTree<1> &out, DerivativeOperator<1> &oper, std::vector *> &inp); -template void divergence<2>(FunctionTree<2> &out, DerivativeOperator<2> &oper, std::vector *> &inp); -template void divergence<3>(FunctionTree<3> &out, DerivativeOperator<3> &oper, std::vector *> &inp); -template FunctionTreeVector<1> gradient<1>(DerivativeOperator<1> &oper, FunctionTree<1> &inp); -template FunctionTreeVector<2> gradient<2>(DerivativeOperator<2> &oper, FunctionTree<2> &inp); -template FunctionTreeVector<3> gradient<3>(DerivativeOperator<3> &oper, FunctionTree<3> &inp); +template void apply<1, double>(double prec, FunctionTree<1, double> &out, ConvolutionOperator<1> &oper, FunctionTree<1, double> &inp, int maxIter, bool absPrec); +template void apply<2, double>(double prec, FunctionTree<2, double> &out, ConvolutionOperator<2> &oper, FunctionTree<2, double> &inp, int maxIter, bool absPrec); +template void apply<3, double>(double prec, FunctionTree<3, double> &out, ConvolutionOperator<3> &oper, FunctionTree<3, double> &inp, int maxIter, bool absPrec); +template void apply<1, double>(double prec, FunctionTree<1, double> &out, ConvolutionOperator<1> &oper, FunctionTree<1, double> &inp, FunctionTreeVector<1, double> &precTrees, int maxIter, bool absPrec); +template void apply<2, double>(double prec, FunctionTree<2, double> &out, ConvolutionOperator<2> &oper, FunctionTree<2, double> &inp, FunctionTreeVector<2, double> &precTrees, int maxIter, bool absPrec); +template void apply<3, double>(double prec, FunctionTree<3, double> &out, ConvolutionOperator<3> &oper, FunctionTree<3, double> &inp, FunctionTreeVector<3, double> &precTrees, int maxIter, bool absPrec); +template void apply_far_field<1, double>(double prec, FunctionTree<1, double> &out, ConvolutionOperator<1> &oper, FunctionTree<1, double> &inp, int maxIter, bool absPrec); +template void apply_far_field<2, double>(double prec, FunctionTree<2, double> &out, ConvolutionOperator<2> &oper, FunctionTree<2, double> &inp, int maxIter, bool absPrec); +template void apply_far_field<3, double>(double prec, FunctionTree<3, double> &out, ConvolutionOperator<3> &oper, FunctionTree<3, double> &inp, int maxIter, bool absPrec); +template void apply_near_field<1, double>(double prec, FunctionTree<1, double> &out, ConvolutionOperator<1> &oper, FunctionTree<1, double> &inp, int maxIter, bool absPrec); +template void apply_near_field<2, double>(double prec, FunctionTree<2, double> &out, ConvolutionOperator<2> &oper, FunctionTree<2, double> &inp, int maxIter, bool absPrec); +template void apply_near_field<3, double>(double prec, FunctionTree<3, double> &out, ConvolutionOperator<3> &oper, FunctionTree<3, double> &inp, int maxIter, bool absPrec); +template void apply<1, double>(FunctionTree<1, double> &out, DerivativeOperator<1> &oper, FunctionTree<1, double> &inp, int dir); +template void apply<2, double>(FunctionTree<2, double> &out, DerivativeOperator<2> &oper, FunctionTree<2, double> &inp, int dir); +template void apply<3, double>(FunctionTree<3, double> &out, DerivativeOperator<3> &oper, FunctionTree<3, double> &inp, int dir); +template void divergence<1, double>(FunctionTree<1, double> &out, DerivativeOperator<1> &oper, FunctionTreeVector<1, double> &inp); +template void divergence<2, double>(FunctionTree<2, double> &out, DerivativeOperator<2> &oper, FunctionTreeVector<2, double> &inp); +template void divergence<3, double>(FunctionTree<3, double> &out, DerivativeOperator<3> &oper, FunctionTreeVector<3, double> &inp); +template void divergence<1, double>(FunctionTree<1, double> &out, DerivativeOperator<1> &oper, std::vector *> &inp); +template void divergence<2, double>(FunctionTree<2, double> &out, DerivativeOperator<2> &oper, std::vector *> &inp); +template void divergence<3, double>(FunctionTree<3, double> &out, DerivativeOperator<3> &oper, std::vector *> &inp); +template FunctionTreeVector<1, double> gradient<1>(DerivativeOperator<1> &oper, FunctionTree<1, double> &inp); +template FunctionTreeVector<2, double> gradient<2>(DerivativeOperator<2> &oper, FunctionTree<2, double> &inp); +template FunctionTreeVector<3, double> gradient<3>(DerivativeOperator<3> &oper, FunctionTree<3, double> &inp); + + + +template void apply<1, ComplexDouble>(double prec, FunctionTree<1, ComplexDouble> &out, ConvolutionOperator<1> &oper, FunctionTree<1, ComplexDouble> &inp, int maxIter, bool absPrec); +template void apply<2, ComplexDouble>(double prec, FunctionTree<2, ComplexDouble> &out, ConvolutionOperator<2> &oper, FunctionTree<2, ComplexDouble> &inp, int maxIter, bool absPrec); +template void apply<3, ComplexDouble>(double prec, FunctionTree<3, ComplexDouble> &out, ConvolutionOperator<3> &oper, FunctionTree<3, ComplexDouble> &inp, int maxIter, bool absPrec); +template void apply<1, ComplexDouble>(double prec, FunctionTree<1, ComplexDouble> &out, ConvolutionOperator<1> &oper, FunctionTree<1, ComplexDouble> &inp, FunctionTreeVector<1, ComplexDouble> &precTrees, int maxIter, bool absPrec); +template void apply<2, ComplexDouble>(double prec, FunctionTree<2, ComplexDouble> &out, ConvolutionOperator<2> &oper, FunctionTree<2, ComplexDouble> &inp, FunctionTreeVector<2, ComplexDouble> &precTrees, int maxIter, bool absPrec); +template void apply<3, ComplexDouble>(double prec, FunctionTree<3, ComplexDouble> &out, ConvolutionOperator<3> &oper, FunctionTree<3, ComplexDouble> &inp, FunctionTreeVector<3, ComplexDouble> &precTrees, int maxIter, bool absPrec); +template void apply_far_field<1, ComplexDouble>(double prec, FunctionTree<1, ComplexDouble> &out, ConvolutionOperator<1> &oper, FunctionTree<1, ComplexDouble> &inp, int maxIter, bool absPrec); +template void apply_far_field<2, ComplexDouble>(double prec, FunctionTree<2, ComplexDouble> &out, ConvolutionOperator<2> &oper, FunctionTree<2, ComplexDouble> &inp, int maxIter, bool absPrec); +template void apply_far_field<3, ComplexDouble>(double prec, FunctionTree<3, ComplexDouble> &out, ConvolutionOperator<3> &oper, FunctionTree<3, ComplexDouble> &inp, int maxIter, bool absPrec); +template void apply_near_field<1, ComplexDouble>(double prec, FunctionTree<1, ComplexDouble> &out, ConvolutionOperator<1> &oper, FunctionTree<1, ComplexDouble> &inp, int maxIter, bool absPrec); +template void apply_near_field<2, ComplexDouble>(double prec, FunctionTree<2, ComplexDouble> &out, ConvolutionOperator<2> &oper, FunctionTree<2, ComplexDouble> &inp, int maxIter, bool absPrec); +template void apply_near_field<3, ComplexDouble>(double prec, FunctionTree<3, ComplexDouble> &out, ConvolutionOperator<3> &oper, FunctionTree<3, ComplexDouble> &inp, int maxIter, bool absPrec); +template void apply<1, ComplexDouble>(FunctionTree<1, ComplexDouble> &out, DerivativeOperator<1> &oper, FunctionTree<1, ComplexDouble> &inp, int dir); +template void apply<2, ComplexDouble>(FunctionTree<2, ComplexDouble> &out, DerivativeOperator<2> &oper, FunctionTree<2, ComplexDouble> &inp, int dir); +template void apply<3, ComplexDouble>(FunctionTree<3, ComplexDouble> &out, DerivativeOperator<3> &oper, FunctionTree<3, ComplexDouble> &inp, int dir); +template void divergence<1, ComplexDouble>(FunctionTree<1, ComplexDouble> &out, DerivativeOperator<1> &oper, FunctionTreeVector<1, ComplexDouble> &inp); +template void divergence<2, ComplexDouble>(FunctionTree<2, ComplexDouble> &out, DerivativeOperator<2> &oper, FunctionTreeVector<2, ComplexDouble> &inp); +template void divergence<3, ComplexDouble>(FunctionTree<3, ComplexDouble> &out, DerivativeOperator<3> &oper, FunctionTreeVector<3, ComplexDouble> &inp); +template void divergence<1, ComplexDouble>(FunctionTree<1, ComplexDouble> &out, DerivativeOperator<1> &oper, std::vector *> &inp); +template void divergence<2, ComplexDouble>(FunctionTree<2, ComplexDouble> &out, DerivativeOperator<2> &oper, std::vector *> &inp); +template void divergence<3, ComplexDouble>(FunctionTree<3, ComplexDouble> &out, DerivativeOperator<3> &oper, std::vector *> &inp); +template FunctionTreeVector<1, ComplexDouble> gradient<1>(DerivativeOperator<1> &oper, FunctionTree<1, ComplexDouble> &inp); +template FunctionTreeVector<2, ComplexDouble> gradient<2>(DerivativeOperator<2> &oper, FunctionTree<2, ComplexDouble> &inp); +template FunctionTreeVector<3, ComplexDouble> gradient<3>(DerivativeOperator<3> &oper, FunctionTree<3, ComplexDouble> &inp); } // namespace mrcpp diff --git a/src/treebuilders/apply.h b/src/treebuilders/apply.h index ae38e96ad..f6217e381 100644 --- a/src/treebuilders/apply.h +++ b/src/treebuilders/apply.h @@ -30,18 +30,18 @@ namespace mrcpp { // clang-format off -template class FunctionTree; +template class FunctionTree; template class DerivativeOperator; template class ConvolutionOperator; -template void apply(double prec, FunctionTree &out, ConvolutionOperator &oper, FunctionTree &inp, int maxIter = -1, bool absPrec = false); -template void apply(double prec, FunctionTree &out, ConvolutionOperator &oper, FunctionTree &inp, FunctionTreeVector &precTrees, int maxIter = -1, bool absPrec = false); -template void apply_far_field(double prec, FunctionTree &out, ConvolutionOperator &oper, FunctionTree &inp, int maxIter = -1, bool absPrec = false); -template void apply_near_field(double prec, FunctionTree &out, ConvolutionOperator &oper, FunctionTree &inp, int maxIter = -1, bool absPrec = false); -template void apply(FunctionTree &out, DerivativeOperator &oper, FunctionTree &inp, int dir = -1); -template void divergence(FunctionTree &out, DerivativeOperator &oper, FunctionTreeVector &inp); -template void divergence(FunctionTree &out, DerivativeOperator &oper, std::vector *> &inp); -template FunctionTreeVector gradient(DerivativeOperator &oper, FunctionTree &inp); +template void apply(double prec, FunctionTree &out, ConvolutionOperator &oper, FunctionTree &inp, int maxIter = -1, bool absPrec = false); +template void apply(double prec, FunctionTree &out, ConvolutionOperator &oper, FunctionTree &inp, FunctionTreeVector &precTrees, int maxIter = -1, bool absPrec = false); +template void apply_far_field(double prec, FunctionTree &out, ConvolutionOperator &oper, FunctionTree &inp, int maxIter = -1, bool absPrec = false); +template void apply_near_field(double prec, FunctionTree &out, ConvolutionOperator &oper, FunctionTree &inp, int maxIter = -1, bool absPrec = false); +template void apply(FunctionTree &out, DerivativeOperator &oper, FunctionTree &inp, int dir = -1); +template void divergence(FunctionTree &out, DerivativeOperator &oper, FunctionTreeVector &inp); +template void divergence(FunctionTree &out, DerivativeOperator &oper, std::vector *> &inp); +template FunctionTreeVector gradient(DerivativeOperator &oper, FunctionTree &inp); // clang-format on } // namespace mrcpp diff --git a/src/treebuilders/complex_apply.cpp b/src/treebuilders/complex_apply.cpp index ab410244e..8b22d19f6 100644 --- a/src/treebuilders/complex_apply.cpp +++ b/src/treebuilders/complex_apply.cpp @@ -101,9 +101,6 @@ void apply } - - - template void apply <1> ( diff --git a/src/treebuilders/grid.cpp b/src/treebuilders/grid.cpp index fb9e65a91..86f71f41c 100644 --- a/src/treebuilders/grid.cpp +++ b/src/treebuilders/grid.cpp @@ -48,11 +48,11 @@ namespace mrcpp { * @note This algorithm will start at whatever grid is present in the `out` * tree when the function is called. */ -template void build_grid(FunctionTree &out, int scales) { +template void build_grid(FunctionTree &out, int scales) { auto maxScale = out.getMRA().getMaxScale(); - TreeBuilder builder; - DefaultCalculator calculator; - SplitAdaptor adaptor(maxScale, true); // Splits all nodes + TreeBuilder builder; + DefaultCalculator calculator; + SplitAdaptor adaptor(maxScale, true); // Splits all nodes for (auto n = 0; n < scales; n++) builder.build(out, calculator, adaptor, 1); } @@ -75,11 +75,11 @@ template void build_grid(FunctionTree &out, int scales) { * particular `RepresentableFunction`. * */ -template void build_grid(FunctionTree &out, const RepresentableFunction &inp, int maxIter) { +template void build_grid(FunctionTree &out, const RepresentableFunction &inp, int maxIter) { auto maxScale = out.getMRA().getMaxScale(); - TreeBuilder builder; - AnalyticAdaptor adaptor(inp, maxScale); - DefaultCalculator calculator; + TreeBuilder builder; + AnalyticAdaptor adaptor(inp, maxScale); + DefaultCalculator calculator; builder.build(out, calculator, adaptor, maxIter); print::separator(10, ' '); } @@ -109,7 +109,7 @@ template void build_grid(FunctionTree &out, const GaussExp &inp, i TreeBuilder builder; DefaultCalculator calculator; for (auto i = 0; i < inp.size(); i++) { - AnalyticAdaptor adaptor(inp.getFunc(i), maxScale); + AnalyticAdaptor adaptor(inp.getFunc(i), maxScale); builder.build(out, calculator, adaptor, maxIter); } } else { @@ -142,12 +142,12 @@ template void build_grid(FunctionTree &out, const GaussExp &inp, i * but NOT vice versa. * */ -template void build_grid(FunctionTree &out, FunctionTree &inp, int maxIter) { +template void build_grid(FunctionTree &out, FunctionTree &inp, int maxIter) { if (out.getMRA() != inp.getMRA()) MSG_ABORT("Incompatible MRA"); auto maxScale = out.getMRA().getMaxScale(); - TreeBuilder builder; - CopyAdaptor adaptor(inp, maxScale, nullptr); - DefaultCalculator calculator; + TreeBuilder builder; + CopyAdaptor adaptor(inp, maxScale, nullptr); + DefaultCalculator calculator; builder.build(out, calculator, adaptor, maxIter); print::separator(10, ' '); } @@ -171,20 +171,20 @@ template void build_grid(FunctionTree &out, FunctionTree &inp, int * `maxIter` is reached). * */ -template void build_grid(FunctionTree &out, FunctionTreeVector &inp, int maxIter) { +template void build_grid(FunctionTree &out, FunctionTreeVector &inp, int maxIter) { for (auto i = 0; i < inp.size(); i++) if (out.getMRA() != get_func(inp, i).getMRA()) MSG_ABORT("Incompatible MRA"); auto maxScale = out.getMRA().getMaxScale(); - TreeBuilder builder; - CopyAdaptor adaptor(inp, maxScale, nullptr); - DefaultCalculator calculator; + TreeBuilder builder; + CopyAdaptor adaptor(inp, maxScale, nullptr); + DefaultCalculator calculator; builder.build(out, calculator, adaptor, maxIter); print::separator(10, ' '); } -template void build_grid(FunctionTree &out, std::vector *> &inp, int maxIter) { - FunctionTreeVector inp_vec; +template void build_grid(FunctionTree &out, std::vector *> &inp, int maxIter) { + FunctionTreeVector inp_vec; for (auto *t : inp) inp_vec.push_back({1.0, t}); build_grid(out, inp_vec, maxIter); } @@ -202,8 +202,8 @@ template void build_grid(FunctionTree &out, std::vector void copy_func(FunctionTree &out, FunctionTree &inp) { - FunctionTreeVector tmp_vec; +template void copy_func(FunctionTree &out, FunctionTree &inp) { + FunctionTreeVector tmp_vec; tmp_vec.push_back(std::make_tuple(1.0, &inp)); add(-1.0, out, tmp_vec); } @@ -218,7 +218,7 @@ template void copy_func(FunctionTree &out, FunctionTree &inp) { * will _extend_ the existing grid. * */ -template void copy_grid(FunctionTree &out, FunctionTree &inp) { +template void copy_grid(FunctionTree &out, FunctionTree &inp) { if (out.getMRA() != inp.getMRA()) MSG_ABORT("Incompatible MRA") out.clear(); build_grid(out, inp); @@ -233,9 +233,9 @@ template void copy_grid(FunctionTree &out, FunctionTree &inp) { * grid refinement as well. * */ -template void clear_grid(FunctionTree &out) { - TreeBuilder builder; - DefaultCalculator calculator; +template void clear_grid(FunctionTree &out) { + TreeBuilder builder; + DefaultCalculator calculator; builder.clear(out, calculator); } @@ -250,11 +250,11 @@ template void clear_grid(FunctionTree &out) { * the function representation unchanged, but on a larger grid. * */ -template int refine_grid(FunctionTree &out, int scales) { +template int refine_grid(FunctionTree &out, int scales) { auto nSplit = 0; auto maxScale = out.getMRA().getMaxScale(); - TreeBuilder builder; - SplitAdaptor adaptor(maxScale, true); // Splits all nodes + TreeBuilder builder; + SplitAdaptor adaptor(maxScale, true); // Splits all nodes for (auto n = 0; n < scales; n++) { nSplit += builder.split(out, adaptor, true); // Transfers coefs to children } @@ -274,10 +274,10 @@ template int refine_grid(FunctionTree &out, int scales) { * unchanged, but (possibly) on a larger grid. * */ -template int refine_grid(FunctionTree &out, double prec, bool absPrec) { +template int refine_grid(FunctionTree &out, double prec, bool absPrec) { int maxScale = out.getMRA().getMaxScale(); - TreeBuilder builder; - WaveletAdaptor adaptor(prec, maxScale, absPrec); + TreeBuilder builder; + WaveletAdaptor adaptor(prec, maxScale, absPrec); int nSplit = builder.split(out, adaptor, true); return nSplit; } @@ -294,11 +294,11 @@ template int refine_grid(FunctionTree &out, double prec, bool absPrec * leaving the function representation unchanged, but on a larger grid. * */ -template int refine_grid(FunctionTree &out, FunctionTree &inp) { +template int refine_grid(FunctionTree &out, FunctionTree &inp) { if (out.getMRA() != inp.getMRA()) MSG_ABORT("Incompatible MRA") auto maxScale = out.getMRA().getMaxScale(); - TreeBuilder builder; - CopyAdaptor adaptor(inp, maxScale, nullptr); + TreeBuilder builder; + CopyAdaptor adaptor(inp, maxScale, nullptr); auto nSplit = builder.split(out, adaptor, true); return nSplit; } @@ -316,52 +316,90 @@ template int refine_grid(FunctionTree &out, FunctionTree &inp) { * is implemented in the particular `RepresentableFunction`. * */ -template int refine_grid(FunctionTree &out, const RepresentableFunction &inp) { +template int refine_grid(FunctionTree &out, const RepresentableFunction &inp) { auto maxScale = out.getMRA().getMaxScale(); - TreeBuilder builder; - AnalyticAdaptor adaptor(inp, maxScale); + TreeBuilder builder; + AnalyticAdaptor adaptor(inp, maxScale); int nSplit = builder.split(out, adaptor, true); return nSplit; } -template void build_grid<1>(FunctionTree<1> &out, int scales); -template void build_grid<2>(FunctionTree<2> &out, int scales); -template void build_grid<3>(FunctionTree<3> &out, int scales); +template void build_grid<1, double>(FunctionTree<1, double> &out, int scales); +template void build_grid<2, double>(FunctionTree<2, double> &out, int scales); +template void build_grid<3, double>(FunctionTree<3, double> &out, int scales); template void build_grid<1>(FunctionTree<1> &out, const GaussExp<1> &inp, int maxIter); template void build_grid<2>(FunctionTree<2> &out, const GaussExp<2> &inp, int maxIter); template void build_grid<3>(FunctionTree<3> &out, const GaussExp<3> &inp, int maxIter); -template void build_grid<1>(FunctionTree<1> &out, const RepresentableFunction<1> &inp, int maxIter); -template void build_grid<2>(FunctionTree<2> &out, const RepresentableFunction<2> &inp, int maxIter); -template void build_grid<3>(FunctionTree<3> &out, const RepresentableFunction<3> &inp, int maxIter); -template void build_grid<1>(FunctionTree<1> &out, FunctionTree<1> &inp, int maxIter); -template void build_grid<2>(FunctionTree<2> &out, FunctionTree<2> &inp, int maxIter); -template void build_grid<3>(FunctionTree<3> &out, FunctionTree<3> &inp, int maxIter); -template void build_grid<1>(FunctionTree<1> &out, FunctionTreeVector<1> &inp, int maxIter); -template void build_grid<2>(FunctionTree<2> &out, FunctionTreeVector<2> &inp, int maxIter); -template void build_grid<3>(FunctionTree<3> &out, FunctionTreeVector<3> &inp, int maxIter); -template void build_grid<1>(FunctionTree<1> &out, std::vector *> &inp, int maxIter); -template void build_grid<2>(FunctionTree<2> &out, std::vector *> &inp, int maxIter); -template void build_grid<3>(FunctionTree<3> &out, std::vector *> &inp, int maxIter); -template void copy_func<1>(FunctionTree<1> &out, FunctionTree<1> &inp); -template void copy_func<2>(FunctionTree<2> &out, FunctionTree<2> &inp); -template void copy_func<3>(FunctionTree<3> &out, FunctionTree<3> &inp); -template void copy_grid<1>(FunctionTree<1> &out, FunctionTree<1> &inp); -template void copy_grid<2>(FunctionTree<2> &out, FunctionTree<2> &inp); -template void copy_grid<3>(FunctionTree<3> &out, FunctionTree<3> &inp); -template void clear_grid<1>(FunctionTree<1> &out); -template void clear_grid<2>(FunctionTree<2> &out); -template void clear_grid<3>(FunctionTree<3> &out); -template int refine_grid<1>(FunctionTree<1> &out, int scales); -template int refine_grid<2>(FunctionTree<2> &out, int scales); -template int refine_grid<3>(FunctionTree<3> &out, int scales); -template int refine_grid<1>(FunctionTree<1> &out, double prec, bool absPrec); -template int refine_grid<2>(FunctionTree<2> &out, double prec, bool absPrec); -template int refine_grid<3>(FunctionTree<3> &out, double prec, bool absPrec); -template int refine_grid<1>(FunctionTree<1> &out, FunctionTree<1> &inp); -template int refine_grid<2>(FunctionTree<2> &out, FunctionTree<2> &inp); -template int refine_grid<3>(FunctionTree<3> &out, FunctionTree<3> &inp); -template int refine_grid<1>(FunctionTree<1> &out, const RepresentableFunction<1> &inp); -template int refine_grid<2>(FunctionTree<2> &out, const RepresentableFunction<2> &inp); -template int refine_grid<3>(FunctionTree<3> &out, const RepresentableFunction<3> &inp); +template void build_grid<1, double>(FunctionTree<1, double> &out, const RepresentableFunction<1, double> &inp, int maxIter); +template void build_grid<2, double>(FunctionTree<2, double> &out, const RepresentableFunction<2, double> &inp, int maxIter); +template void build_grid<3, double>(FunctionTree<3, double> &out, const RepresentableFunction<3, double> &inp, int maxIter); +template void build_grid<1, double>(FunctionTree<1, double> &out, FunctionTree<1, double> &inp, int maxIter); +template void build_grid<2, double>(FunctionTree<2, double> &out, FunctionTree<2, double> &inp, int maxIter); +template void build_grid<3, double>(FunctionTree<3, double> &out, FunctionTree<3, double> &inp, int maxIter); +template void build_grid<1, double>(FunctionTree<1, double> &out, FunctionTreeVector<1, double> &inp, int maxIter); +template void build_grid<2, double>(FunctionTree<2, double> &out, FunctionTreeVector<2, double> &inp, int maxIter); +template void build_grid<3, double>(FunctionTree<3, double> &out, FunctionTreeVector<3, double> &inp, int maxIter); +template void build_grid<1, double>(FunctionTree<1, double> &out, std::vector *> &inp, int maxIter); +template void build_grid<2, double>(FunctionTree<2, double> &out, std::vector *> &inp, int maxIter); +template void build_grid<3, double>(FunctionTree<3, double> &out, std::vector *> &inp, int maxIter); +template void copy_func<1, double>(FunctionTree<1, double> &out, FunctionTree<1, double> &inp); +template void copy_func<2, double>(FunctionTree<2, double> &out, FunctionTree<2, double> &inp); +template void copy_func<3, double>(FunctionTree<3, double> &out, FunctionTree<3, double> &inp); +template void copy_grid<1, double>(FunctionTree<1, double> &out, FunctionTree<1, double> &inp); +template void copy_grid<2, double>(FunctionTree<2, double> &out, FunctionTree<2, double> &inp); +template void copy_grid<3, double>(FunctionTree<3, double> &out, FunctionTree<3, double> &inp); +template void clear_grid<1, double>(FunctionTree<1, double> &out); +template void clear_grid<2, double>(FunctionTree<2, double> &out); +template void clear_grid<3, double>(FunctionTree<3, double> &out); +template int refine_grid<1, double>(FunctionTree<1, double> &out, int scales); +template int refine_grid<2, double>(FunctionTree<2, double> &out, int scales); +template int refine_grid<3, double>(FunctionTree<3, double> &out, int scales); +template int refine_grid<1, double>(FunctionTree<1, double> &out, double prec, bool absPrec); +template int refine_grid<2, double>(FunctionTree<2, double> &out, double prec, bool absPrec); +template int refine_grid<3, double>(FunctionTree<3, double> &out, double prec, bool absPrec); +template int refine_grid<1, double>(FunctionTree<1, double> &out, FunctionTree<1, double> &inp); +template int refine_grid<2, double>(FunctionTree<2, double> &out, FunctionTree<2, double> &inp); +template int refine_grid<3, double>(FunctionTree<3, double> &out, FunctionTree<3, double> &inp); +template int refine_grid<1, double>(FunctionTree<1, double> &out, const RepresentableFunction<1, double> &inp); +template int refine_grid<2, double>(FunctionTree<2, double> &out, const RepresentableFunction<2, double> &inp); +template int refine_grid<3, double>(FunctionTree<3, double> &out, const RepresentableFunction<3, double> &inp); + + +template void build_grid<1, ComplexDouble>(FunctionTree<1, ComplexDouble> &out, int scales); +template void build_grid<2, ComplexDouble>(FunctionTree<2, ComplexDouble> &out, int scales); +template void build_grid<3, ComplexDouble>(FunctionTree<3, ComplexDouble> &out, int scales); +template void build_grid<1, ComplexDouble>(FunctionTree<1, ComplexDouble> &out, const RepresentableFunction<1, ComplexDouble> &inp, int maxIter); +template void build_grid<2, ComplexDouble>(FunctionTree<2, ComplexDouble> &out, const RepresentableFunction<2, ComplexDouble> &inp, int maxIter); +template void build_grid<3, ComplexDouble>(FunctionTree<3, ComplexDouble> &out, const RepresentableFunction<3, ComplexDouble> &inp, int maxIter); +template void build_grid<1, ComplexDouble>(FunctionTree<1, ComplexDouble> &out, FunctionTree<1, ComplexDouble> &inp, int maxIter); +template void build_grid<2, ComplexDouble>(FunctionTree<2, ComplexDouble> &out, FunctionTree<2, ComplexDouble> &inp, int maxIter); +template void build_grid<3, ComplexDouble>(FunctionTree<3, ComplexDouble> &out, FunctionTree<3, ComplexDouble> &inp, int maxIter); +template void build_grid<1, ComplexDouble>(FunctionTree<1, ComplexDouble> &out, FunctionTreeVector<1, ComplexDouble> &inp, int maxIter); +template void build_grid<2, ComplexDouble>(FunctionTree<2, ComplexDouble> &out, FunctionTreeVector<2, ComplexDouble> &inp, int maxIter); +template void build_grid<3, ComplexDouble>(FunctionTree<3, ComplexDouble> &out, FunctionTreeVector<3, ComplexDouble> &inp, int maxIter); +template void build_grid<1, ComplexDouble>(FunctionTree<1, ComplexDouble> &out, std::vector *> &inp, int maxIter); +template void build_grid<2, ComplexDouble>(FunctionTree<2, ComplexDouble> &out, std::vector *> &inp, int maxIter); +template void build_grid<3, ComplexDouble>(FunctionTree<3, ComplexDouble> &out, std::vector *> &inp, int maxIter); +template void copy_func<1, ComplexDouble>(FunctionTree<1, ComplexDouble> &out, FunctionTree<1, ComplexDouble> &inp); +template void copy_func<2, ComplexDouble>(FunctionTree<2, ComplexDouble> &out, FunctionTree<2, ComplexDouble> &inp); +template void copy_func<3, ComplexDouble>(FunctionTree<3, ComplexDouble> &out, FunctionTree<3, ComplexDouble> &inp); +template void copy_grid<1, ComplexDouble>(FunctionTree<1, ComplexDouble> &out, FunctionTree<1, ComplexDouble> &inp); +template void copy_grid<2, ComplexDouble>(FunctionTree<2, ComplexDouble> &out, FunctionTree<2, ComplexDouble> &inp); +template void copy_grid<3, ComplexDouble>(FunctionTree<3, ComplexDouble> &out, FunctionTree<3, ComplexDouble> &inp); +template void clear_grid<1, ComplexDouble>(FunctionTree<1, ComplexDouble> &out); +template void clear_grid<2, ComplexDouble>(FunctionTree<2, ComplexDouble> &out); +template void clear_grid<3, ComplexDouble>(FunctionTree<3, ComplexDouble> &out); +template int refine_grid<1, ComplexDouble>(FunctionTree<1, ComplexDouble> &out, int scales); +template int refine_grid<2, ComplexDouble>(FunctionTree<2, ComplexDouble> &out, int scales); +template int refine_grid<3, ComplexDouble>(FunctionTree<3, ComplexDouble> &out, int scales); +template int refine_grid<1, ComplexDouble>(FunctionTree<1, ComplexDouble> &out, double prec, bool absPrec); +template int refine_grid<2, ComplexDouble>(FunctionTree<2, ComplexDouble> &out, double prec, bool absPrec); +template int refine_grid<3, ComplexDouble>(FunctionTree<3, ComplexDouble> &out, double prec, bool absPrec); +template int refine_grid<1, ComplexDouble>(FunctionTree<1, ComplexDouble> &out, FunctionTree<1, ComplexDouble> &inp); +template int refine_grid<2, ComplexDouble>(FunctionTree<2, ComplexDouble> &out, FunctionTree<2, ComplexDouble> &inp); +template int refine_grid<3, ComplexDouble>(FunctionTree<3, ComplexDouble> &out, FunctionTree<3, ComplexDouble> &inp); +template int refine_grid<1, ComplexDouble>(FunctionTree<1, ComplexDouble> &out, const RepresentableFunction<1, ComplexDouble> &inp); +template int refine_grid<2, ComplexDouble>(FunctionTree<2, ComplexDouble> &out, const RepresentableFunction<2, ComplexDouble> &inp); +template int refine_grid<3, ComplexDouble>(FunctionTree<3, ComplexDouble> &out, const RepresentableFunction<3, ComplexDouble> &inp); } // namespace mrcpp diff --git a/src/treebuilders/grid.h b/src/treebuilders/grid.h index 1f4c3e4f5..42f54aa0a 100644 --- a/src/treebuilders/grid.h +++ b/src/treebuilders/grid.h @@ -30,17 +30,17 @@ #include "trees/FunctionTreeVector.h" namespace mrcpp { -template void build_grid(FunctionTree &out, int scales); +template void build_grid(FunctionTree &out, int scales); template void build_grid(FunctionTree &out, const GaussExp &inp, int maxIter = -1); -template void build_grid(FunctionTree &out, const RepresentableFunction &inp, int maxIter = -1); -template void build_grid(FunctionTree &out, FunctionTree &inp, int maxIter = -1); -template void build_grid(FunctionTree &out, FunctionTreeVector &inp, int maxIter = -1); -template void build_grid(FunctionTree &out, std::vector *> &inp, int maxIter = -1); -template void copy_func(FunctionTree &out, FunctionTree &inp); -template void copy_grid(FunctionTree &out, FunctionTree &inp); -template void clear_grid(FunctionTree &out); -template int refine_grid(FunctionTree &out, int scales); -template int refine_grid(FunctionTree &out, double prec, bool absPrec = false); -template int refine_grid(FunctionTree &out, FunctionTree &inp); -template int refine_grid(FunctionTree &out, const RepresentableFunction &inp); +template void build_grid(FunctionTree &out, const RepresentableFunction &inp, int maxIter = -1); +template void build_grid(FunctionTree &out, FunctionTree &inp, int maxIter = -1); +template void build_grid(FunctionTree &out, FunctionTreeVector &inp, int maxIter = -1); +template void build_grid(FunctionTree &out, std::vector *> &inp, int maxIter = -1); +template void copy_func(FunctionTree &out, FunctionTree &inp); +template void copy_grid(FunctionTree &out, FunctionTree &inp); +template void clear_grid(FunctionTree &out); +template int refine_grid(FunctionTree &out, int scales); +template int refine_grid(FunctionTree &out, double prec, bool absPrec = false); +template int refine_grid(FunctionTree &out, FunctionTree &inp); +template int refine_grid(FunctionTree &out, const RepresentableFunction &inp); } // namespace mrcpp diff --git a/src/treebuilders/map.cpp b/src/treebuilders/map.cpp index 98824d002..ba064ca39 100644 --- a/src/treebuilders/map.cpp +++ b/src/treebuilders/map.cpp @@ -66,12 +66,12 @@ namespace mrcpp { * */ template -void map(double prec, FunctionTree &out, FunctionTree &inp, FMap fmap, int maxIter, bool absPrec) { +void map(double prec, FunctionTree &out, FunctionTree &inp, FMap fmap, int maxIter, bool absPrec) { int maxScale = out.getMRA().getMaxScale(); - TreeBuilder builder; - WaveletAdaptor adaptor(prec, maxScale, absPrec); - MapCalculator calculator(fmap, inp); + TreeBuilder builder; + WaveletAdaptor adaptor(prec, maxScale, absPrec); + MapCalculator calculator(fmap, inp); builder.build(out, calculator, adaptor, maxIter); @@ -89,8 +89,8 @@ void map(double prec, FunctionTree &out, FunctionTree &inp, FMap fmap, int print::separator(10, ' '); } -template void map<1>(double prec, FunctionTree<1> &out, FunctionTree<1> &inp, FMap fmap, int maxIter, bool absPrec); -template void map<2>(double prec, FunctionTree<2> &out, FunctionTree<2> &inp, FMap fmap, int maxIter, bool absPrec); -template void map<3>(double prec, FunctionTree<3> &out, FunctionTree<3> &inp, FMap fmap, int maxIter, bool absPrec); +template void map<1>(double prec, FunctionTree<1, double> &out, FunctionTree<1, double> &inp, FMap fmap, int maxIter, bool absPrec); +template void map<2>(double prec, FunctionTree<2, double> &out, FunctionTree<2, double> &inp, FMap fmap, int maxIter, bool absPrec); +template void map<3>(double prec, FunctionTree<3, double> &out, FunctionTree<3, double> &inp, FMap fmap, int maxIter, bool absPrec); } // Namespace mrcpp diff --git a/src/treebuilders/map.h b/src/treebuilders/map.h index 1c54dac32..4fe3cf72d 100644 --- a/src/treebuilders/map.h +++ b/src/treebuilders/map.h @@ -28,10 +28,9 @@ #include "trees/FunctionTreeVector.h" namespace mrcpp { -template class RepresentableFunction; -template class FunctionTree; +template class FunctionTree; template -void map(double prec, FunctionTree &out, FunctionTree &inp, FMap fmap, int maxIter = -1, bool absPrec = false); +void map(double prec, FunctionTree &out, FunctionTree &inp, FMap fmap, int maxIter = -1, bool absPrec = false); } // namespace mrcpp diff --git a/src/treebuilders/multiply.cpp b/src/treebuilders/multiply.cpp index a21e539ab..6cbf58b72 100644 --- a/src/treebuilders/multiply.cpp +++ b/src/treebuilders/multiply.cpp @@ -68,16 +68,16 @@ namespace mrcpp { * no coefs). * */ -template +template void multiply(double prec, - FunctionTree &out, - double c, - FunctionTree &inp_a, - FunctionTree &inp_b, + FunctionTree &out, + T c, + FunctionTree &inp_a, + FunctionTree &inp_b, int maxIter, bool absPrec, bool useMaxNorms) { - FunctionTreeVector tmp_vec; + FunctionTreeVector tmp_vec; tmp_vec.push_back({c, &inp_a}); tmp_vec.push_back({1.0, &inp_b}); multiply(prec, out, tmp_vec, maxIter, absPrec, useMaxNorms); @@ -106,10 +106,10 @@ void multiply(double prec, * no coefs). * */ -template +template void multiply(double prec, - FunctionTree &out, - FunctionTreeVector &inp, + FunctionTree &out, + FunctionTreeVector &inp, int maxIter, bool absPrec, bool useMaxNorms) { @@ -117,15 +117,15 @@ void multiply(double prec, if (out.getMRA() != get_func(inp, i).getMRA()) MSG_ABORT("Incompatible MRA"); int maxScale = out.getMRA().getMaxScale(); - TreeBuilder builder; - MultiplicationCalculator calculator(inp); + TreeBuilder builder; + MultiplicationCalculator calculator(inp); if (useMaxNorms) { for (int i = 0; i < inp.size(); i++) get_func(inp, i).makeMaxSquareNorms(); - MultiplicationAdaptor adaptor(prec, maxScale, inp); + MultiplicationAdaptor adaptor(prec, maxScale, inp); builder.build(out, calculator, adaptor, maxIter); } else { - WaveletAdaptor adaptor(prec, maxScale, absPrec); + WaveletAdaptor adaptor(prec, maxScale, absPrec); builder.build(out, calculator, adaptor, maxIter); } @@ -136,7 +136,7 @@ void multiply(double prec, Timer clean_t; for (int i = 0; i < inp.size(); i++) { - FunctionTree &tree = get_func(inp, i); + FunctionTree &tree = get_func(inp, i); tree.deleteGenerated(); } clean_t.stop(); @@ -146,14 +146,14 @@ void multiply(double prec, print::separator(10, ' '); } -template +template void multiply(double prec, - FunctionTree &out, - std::vector *> &inp, + FunctionTree &out, + std::vector *> &inp, int maxIter, bool absPrec, bool useMaxNorms) { - FunctionTreeVector inp_vec; + FunctionTreeVector inp_vec; for (auto &t : inp) inp_vec.push_back({1.0, t}); multiply(prec, out, inp_vec, maxIter, absPrec, useMaxNorms); } @@ -179,13 +179,13 @@ void multiply(double prec, * no coefs). * */ -template void square(double prec, FunctionTree &out, FunctionTree &inp, int maxIter, bool absPrec) { +template void square(double prec, FunctionTree &out, FunctionTree &inp, int maxIter, bool absPrec) { if (out.getMRA() != inp.getMRA()) MSG_ABORT("Incompatible MRA"); int maxScale = out.getMRA().getMaxScale(); - TreeBuilder builder; - WaveletAdaptor adaptor(prec, maxScale, absPrec); - SquareCalculator calculator(inp); + TreeBuilder builder; + WaveletAdaptor adaptor(prec, maxScale, absPrec); + SquareCalculator calculator(inp); builder.build(out, calculator, adaptor, maxIter); @@ -225,14 +225,14 @@ template void square(double prec, FunctionTree &out, FunctionTree * no coefs). * */ -template -void power(double prec, FunctionTree &out, FunctionTree &inp, double p, int maxIter, bool absPrec) { +template +void power(double prec, FunctionTree &out, FunctionTree &inp, double p, int maxIter, bool absPrec) { if (out.getMRA() != inp.getMRA()) MSG_ABORT("Incompatible MRA"); int maxScale = out.getMRA().getMaxScale(); - TreeBuilder builder; - WaveletAdaptor adaptor(prec, maxScale, absPrec); - PowerCalculator calculator(inp, p); + TreeBuilder builder; + WaveletAdaptor adaptor(prec, maxScale, absPrec); + PowerCalculator calculator(inp, p); builder.build(out, calculator, adaptor, maxIter); @@ -267,24 +267,25 @@ void power(double prec, FunctionTree &out, FunctionTree &inp, double p, in * @note The length of the input vectors must be the same. * */ -template +template void dot(double prec, - FunctionTree &out, - FunctionTreeVector &inp_a, - FunctionTreeVector &inp_b, + FunctionTree &out, + FunctionTreeVector &inp_a, + FunctionTreeVector &inp_b, int maxIter, bool absPrec) { if (inp_a.size() != inp_b.size()) MSG_ABORT("Input length mismatch"); - FunctionTreeVector tmp_vec; + FunctionTreeVector tmp_vec; for (int d = 0; d < inp_a.size(); d++) { - double coef_a = get_coef(inp_a, d); - double coef_b = get_coef(inp_b, d); - FunctionTree &tree_a = get_func(inp_a, d); - FunctionTree &tree_b = get_func(inp_b, d); - auto *out_d = new FunctionTree(out.getMRA()); + T coef_a = get_coef(inp_a, d); + T coef_b = get_coef(inp_b, d); + FunctionTree &tree_a = get_func(inp_a, d); + FunctionTree &tree_b = get_func(inp_b, d); + auto *out_d = new FunctionTree(out.getMRA()); build_grid(*out_d, out); - multiply(prec, *out_d, 1.0, tree_a, tree_b, maxIter, absPrec); + T One = 1.0; + multiply(prec, *out_d, One, tree_a, tree_b, maxIter, absPrec); tmp_vec.push_back({coef_a * coef_b, out_d}); } build_grid(out, tmp_vec); @@ -305,19 +306,19 @@ void dot(double prec, * grids overlap. * */ -template double dot(FunctionTree &bra, FunctionTree &ket) { +template T dot(FunctionTree &bra, FunctionTree &ket) { if (bra.getMRA() != ket.getMRA()) MSG_ABORT("Trees not compatible"); - MWNodeVector nodeTable; - TreeIterator it(bra); + MWNodeVector nodeTable; + TreeIterator it(bra); it.setReturnGenNodes(false); while (it.next()) { - MWNode &node = it.getNode(); + MWNode &node = it.getNode(); nodeTable.push_back(&node); } int nNodes = nodeTable.size(); - double result = 0.0; - double locResult = 0.0; + T result = 0.0; + T locResult = 0.0; // OMP is disabled in order to get EXACT results (to the very last digit), the // order of summation makes the result different beyond the 14th digit or so. // OMP does improve the performace, but its not worth it for the time being. @@ -326,11 +327,11 @@ template double dot(FunctionTree &bra, FunctionTree &ket) { // { //#pragma omp for schedule(guided) for (int n = 0; n < nNodes; n++) { - const auto &braNode = static_cast &>(*nodeTable[n]); - const MWNode *mwNode = ket.findNode(braNode.getNodeIndex()); + const auto &braNode = static_cast &>(*nodeTable[n]); + const MWNode *mwNode = ket.findNode(braNode.getNodeIndex()); if (mwNode == nullptr) continue; - const auto &ketNode = static_cast &>(*mwNode); + const auto &ketNode = static_cast &>(*mwNode); if (braNode.isRootNode()) locResult += dot_scaling(braNode, ketNode); locResult += dot_wavelet(braNode, ketNode); } @@ -352,30 +353,30 @@ template double dot(FunctionTree &bra, FunctionTree &ket) { * distribution within the node. * If the product is zero, the functions are disjoints. */ -template double node_norm_dot(FunctionTree &bra, FunctionTree &ket, bool exact) { +template double node_norm_dot(FunctionTree &bra, FunctionTree &ket, bool exact) { if (bra.getMRA() != ket.getMRA()) MSG_ABORT("Incompatible MRA"); double result = 0.0; int ncoef = bra.getKp1_d() * bra.getTDim(); - double valA[ncoef]; - double valB[ncoef]; + T valA[ncoef]; + T valB[ncoef]; int nNodes = bra.getNEndNodes(); for (int n = 0; n < nNodes; n++) { - FunctionNode &node = bra.getEndFuncNode(n); + FunctionNode &node = bra.getEndFuncNode(n); const NodeIndex idx = node.getNodeIndex(); if (exact) { // convert to interpolating coef, take abs, convert back - FunctionNode *mwNode = static_cast *>(ket.findNode(idx)); + FunctionNode *mwNode = static_cast *>(ket.findNode(idx)); if (mwNode == nullptr) MSG_ABORT("Trees must have same grid"); node.getAbsCoefs(valA); mwNode->getAbsCoefs(valB); - for (int i = 0; i < ncoef; i++) result += valA[i] * valB[i]; + for (int i = 0; i < ncoef; i++) result += std::abs(valA[i] * valB[i]); } else { // approximate by product of node norms int rIdx = ket.getRootBox().getBoxIndex(idx); assert(rIdx >= 0); - const MWNode &root = ket.getRootBox().getNode(rIdx); + const MWNode &root = ket.getRootBox().getNode(rIdx); result += std::sqrt(node.getSquareNorm()) * root.getNodeNorm(idx); } } @@ -383,124 +384,249 @@ template double node_norm_dot(FunctionTree &bra, FunctionTree &ket return result; } -template void multiply<1>(double prec, - FunctionTree<1> &out, +template void multiply<1, double>(double prec, + FunctionTree<1, double> &out, double c, - FunctionTree<1> &tree_a, - FunctionTree<1> &tree_b, + FunctionTree<1, double> &tree_a, + FunctionTree<1, double> &tree_b, int maxIter, bool absPrec, bool useMaxNorms); -template void multiply<2>(double prec, - FunctionTree<2> &out, +template void multiply<2, double>(double prec, + FunctionTree<2, double> &out, double c, - FunctionTree<2> &tree_a, - FunctionTree<2> &tree_b, + FunctionTree<2, double> &tree_a, + FunctionTree<2, double> &tree_b, int maxIter, bool absPrec, bool useMaxNorms); -template void multiply<3>(double prec, - FunctionTree<3> &out, +template void multiply<3, double>(double prec, + FunctionTree<3, double> &out, double c, - FunctionTree<3> &tree_a, - FunctionTree<3> &tree_b, + FunctionTree<3, double> &tree_a, + FunctionTree<3, double> &tree_b, int maxIter, bool absPrec, bool useMaxNorms); -template void multiply<1>(double prec, - FunctionTree<1> &out, - FunctionTreeVector<1> &inp, +template void multiply<1, double>(double prec, + FunctionTree<1, double> &out, + FunctionTreeVector<1, double> &inp, int maxIter, bool absPrec, bool useMaxNorms); -template void multiply<2>(double prec, - FunctionTree<2> &out, - FunctionTreeVector<2> &inp, +template void multiply<2, double>(double prec, + FunctionTree<2, double> &out, + FunctionTreeVector<2, double> &inp, int maxIter, bool absPrec, bool useMaxNorms); -template void multiply<3>(double prec, - FunctionTree<3> &out, - FunctionTreeVector<3> &inp, +template void multiply<3, double>(double prec, + FunctionTree<3, double> &out, + FunctionTreeVector<3, double> &inp, int maxIter, bool absPrec, bool useMaxNorms); -template void multiply<1>(double prec, - FunctionTree<1> &out, - std::vector *> &inp, +template void multiply<1, double>(double prec, + FunctionTree<1, double> &out, + std::vector *> &inp, int maxIter, bool absPrec, bool useMaxNorms); -template void multiply<2>(double prec, - FunctionTree<2> &out, - std::vector *> &inp, +template void multiply<2, double>(double prec, + FunctionTree<2, double> &out, + std::vector *> &inp, int maxIter, bool absPrec, bool useMaxNorms); -template void multiply<3>(double prec, - FunctionTree<3> &out, - std::vector *> &inp, +template void multiply<3, double>(double prec, + FunctionTree<3, double> &out, + std::vector *> &inp, int maxIter, bool absPrec, bool useMaxNorms); -template void power<1>(double prec, - FunctionTree<1> &out, - FunctionTree<1> &tree, +template void power<1, double>(double prec, + FunctionTree<1, double> &out, + FunctionTree<1, double> &tree, double pow, int maxIter, bool absPrec); -template void power<2>(double prec, - FunctionTree<2> &out, - FunctionTree<2> &tree, +template void power<2, double>(double prec, + FunctionTree<2, double> &out, + FunctionTree<2, double> &tree, double pow, int maxIter, bool absPrec); -template void power<3>(double prec, - FunctionTree<3> &out, - FunctionTree<3> &tree, +template void power<3, double>(double prec, + FunctionTree<3, double> &out, + FunctionTree<3, double> &tree, double pow, int maxIter, bool absPrec); -template void square<1>(double prec, - FunctionTree<1> &out, - FunctionTree<1> &tree, +template void square<1, double>(double prec, + FunctionTree<1, double> &out, + FunctionTree<1, double> &tree, int maxIter, bool absPrec); -template void square<2>(double prec, - FunctionTree<2> &out, - FunctionTree<2> &tree, +template void square<2, double>(double prec, + FunctionTree<2, double> &out, + FunctionTree<2, double> &tree, int maxIter, bool absPrec); -template void square<3>(double prec, - FunctionTree<3> &out, - FunctionTree<3> &tree, +template void square<3, double>(double prec, + FunctionTree<3, double> &out, + FunctionTree<3, double> &tree, int maxIter, bool absPrec); -template void dot<1>(double prec, - FunctionTree<1> &out, - FunctionTreeVector<1> &inp_a, - FunctionTreeVector<1> &inp_b, +template void dot<1, double>(double prec, + FunctionTree<1, double> &out, + FunctionTreeVector<1, double> &inp_a, + FunctionTreeVector<1, double> &inp_b, int maxIter, bool absPrec); -template void dot<2>(double prec, - FunctionTree<2> &out, - FunctionTreeVector<2> &inp_a, - FunctionTreeVector<2> &inp_b, +template void dot<2, double>(double prec, + FunctionTree<2, double> &out, + FunctionTreeVector<2, double> &inp_a, + FunctionTreeVector<2, double> &inp_b, int maxIter, bool absPrec); -template void dot<3>(double prec, - FunctionTree<3> &out, - FunctionTreeVector<3> &inp_a, - FunctionTreeVector<3> &inp_b, +template void dot<3, double>(double prec, + FunctionTree<3, double> &out, + FunctionTreeVector<3, double> &inp_a, + FunctionTreeVector<3, double> &inp_b, int maxIter, bool absPrec); -template double dot<1>(FunctionTree<1> &bra, FunctionTree<1> &ket); -template double dot<2>(FunctionTree<2> &bra, FunctionTree<2> &ket); -template double dot<3>(FunctionTree<3> &bra, FunctionTree<3> &ket); +template double dot<1, double>(FunctionTree<1, double> &bra, FunctionTree<1, double> &ket); +template double dot<2, double>(FunctionTree<2, double> &bra, FunctionTree<2, double> &ket); +template double dot<3, double>(FunctionTree<3, double> &bra, FunctionTree<3, double> &ket); + +template double node_norm_dot<1, double>(FunctionTree<1, double> &bra, FunctionTree<1, double> &ket, bool exact); +template double node_norm_dot<2, double>(FunctionTree<2, double> &bra, FunctionTree<2, double> &ket, bool exact); +template double node_norm_dot<3, double>(FunctionTree<3, double> &bra, FunctionTree<3, double> &ket, bool exact); + + + + + +template void multiply<1, ComplexDouble>(double prec, + FunctionTree<1, ComplexDouble> &out, + ComplexDouble c, + FunctionTree<1, ComplexDouble> &tree_a, + FunctionTree<1, ComplexDouble> &tree_b, + int maxIter, + bool absPrec, + bool useMaxNorms); +template void multiply<2, ComplexDouble>(double prec, + FunctionTree<2, ComplexDouble> &out, + ComplexDouble c, + FunctionTree<2, ComplexDouble> &tree_a, + FunctionTree<2, ComplexDouble> &tree_b, + int maxIter, + bool absPrec, + bool useMaxNorms); +template void multiply<3, ComplexDouble>(double prec, + FunctionTree<3, ComplexDouble> &out, + ComplexDouble c, + FunctionTree<3, ComplexDouble> &tree_a, + FunctionTree<3, ComplexDouble> &tree_b, + int maxIter, + bool absPrec, + bool useMaxNorms); +template void multiply<1, ComplexDouble>(double prec, + FunctionTree<1, ComplexDouble> &out, + FunctionTreeVector<1, ComplexDouble> &inp, + int maxIter, + bool absPrec, + bool useMaxNorms); +template void multiply<2, ComplexDouble>(double prec, + FunctionTree<2, ComplexDouble> &out, + FunctionTreeVector<2, ComplexDouble> &inp, + int maxIter, + bool absPrec, + bool useMaxNorms); +template void multiply<3, ComplexDouble>(double prec, + FunctionTree<3, ComplexDouble> &out, + FunctionTreeVector<3, ComplexDouble> &inp, + int maxIter, + bool absPrec, + bool useMaxNorms); +template void multiply<1, ComplexDouble>(double prec, + FunctionTree<1, ComplexDouble> &out, + std::vector *> &inp, + int maxIter, + bool absPrec, + bool useMaxNorms); +template void multiply<2, ComplexDouble>(double prec, + FunctionTree<2, ComplexDouble> &out, + std::vector *> &inp, + int maxIter, + bool absPrec, + bool useMaxNorms); +template void multiply<3, ComplexDouble>(double prec, + FunctionTree<3, ComplexDouble> &out, + std::vector *> &inp, + int maxIter, + bool absPrec, + bool useMaxNorms); +template void power<1, ComplexDouble>(double prec, + FunctionTree<1, ComplexDouble> &out, + FunctionTree<1, ComplexDouble> &tree, + double pow, + int maxIter, + bool absPrec); +template void power<2, ComplexDouble>(double prec, + FunctionTree<2, ComplexDouble> &out, + FunctionTree<2, ComplexDouble> &tree, + double pow, + int maxIter, + bool absPrec); +template void power<3, ComplexDouble>(double prec, + FunctionTree<3, ComplexDouble> &out, + FunctionTree<3, ComplexDouble> &tree, + double pow, + int maxIter, + bool absPrec); +template void square<1, ComplexDouble>(double prec, + FunctionTree<1, ComplexDouble> &out, + FunctionTree<1, ComplexDouble> &tree, + int maxIter, + bool absPrec); +template void square<2, ComplexDouble>(double prec, + FunctionTree<2, ComplexDouble> &out, + FunctionTree<2, ComplexDouble> &tree, + int maxIter, + bool absPrec); +template void square<3, ComplexDouble>(double prec, + FunctionTree<3, ComplexDouble> &out, + FunctionTree<3, ComplexDouble> &tree, + int maxIter, + bool absPrec); +template void dot<1, ComplexDouble>(double prec, + FunctionTree<1, ComplexDouble> &out, + FunctionTreeVector<1, ComplexDouble> &inp_a, + FunctionTreeVector<1, ComplexDouble> &inp_b, + int maxIter, + bool absPrec); +template void dot<2, ComplexDouble>(double prec, + FunctionTree<2, ComplexDouble> &out, + FunctionTreeVector<2, ComplexDouble> &inp_a, + FunctionTreeVector<2, ComplexDouble> &inp_b, + int maxIter, + bool absPrec); +template void dot<3, ComplexDouble>(double prec, + FunctionTree<3, ComplexDouble> &out, + FunctionTreeVector<3, ComplexDouble> &inp_a, + FunctionTreeVector<3, ComplexDouble> &inp_b, + int maxIter, + bool absPrec); + +template ComplexDouble dot<1, ComplexDouble>(FunctionTree<1, ComplexDouble> &bra, FunctionTree<1, ComplexDouble> &ket); +template ComplexDouble dot<2, ComplexDouble>(FunctionTree<2, ComplexDouble> &bra, FunctionTree<2, ComplexDouble> &ket); +template ComplexDouble dot<3, ComplexDouble>(FunctionTree<3, ComplexDouble> &bra, FunctionTree<3, ComplexDouble> &ket); + +template double node_norm_dot<1, ComplexDouble>(FunctionTree<1, ComplexDouble> &bra, FunctionTree<1, ComplexDouble> &ket, bool exact); +template double node_norm_dot<2, ComplexDouble>(FunctionTree<2, ComplexDouble> &bra, FunctionTree<2, ComplexDouble> &ket, bool exact); +template double node_norm_dot<3, ComplexDouble>(FunctionTree<3, ComplexDouble> &bra, FunctionTree<3, ComplexDouble> &ket, bool exact); -template double node_norm_dot<1>(FunctionTree<1> &bra, FunctionTree<1> &ket, bool exact); -template double node_norm_dot<2>(FunctionTree<2> &bra, FunctionTree<2> &ket, bool exact); -template double node_norm_dot<3>(FunctionTree<3> &bra, FunctionTree<3> &ket, bool exact); } // namespace mrcpp diff --git a/src/treebuilders/multiply.h b/src/treebuilders/multiply.h index 54947bf78..96a956f3b 100644 --- a/src/treebuilders/multiply.h +++ b/src/treebuilders/multiply.h @@ -28,56 +28,56 @@ #include "trees/FunctionTreeVector.h" namespace mrcpp { -template class RepresentableFunction; -template class FunctionTree; +template class RepresentableFunction; +template class FunctionTree; -template void dot(double prec, - FunctionTree &out, - FunctionTreeVector &inp_a, - FunctionTreeVector &inp_b, +template void dot(double prec, + FunctionTree &out, + FunctionTreeVector &inp_a, + FunctionTreeVector &inp_b, int maxIter = -1, bool absPrec = false); -template double dot(FunctionTree &bra, - FunctionTree &ket); +template T dot(FunctionTree &bra, + FunctionTree &ket); -template double node_norm_dot(FunctionTree &bra, - FunctionTree &ket, +template double node_norm_dot(FunctionTree &bra, + FunctionTree &ket, bool exact = false); -template void multiply(double prec, - FunctionTree &out, - double c, - FunctionTree &inp_a, - FunctionTree &inp_b, +template void multiply(double prec, + FunctionTree &out, + T c, + FunctionTree &inp_a, + FunctionTree &inp_b, int maxIter = -1, bool absPrec = false, bool useMaxNorms = false); -template void multiply(double prec, - FunctionTree &out, - std::vector *> &inp, +template void multiply(double prec, + FunctionTree &out, + std::vector *> &inp, int maxIter = -1, bool absPrec = false, bool useMaxNorms = false); -template void multiply(double prec, - FunctionTree &out, - FunctionTreeVector &inp, +template void multiply(double prec, + FunctionTree &out, + FunctionTreeVector &inp, int maxIter = -1, bool absPrec = false, bool useMaxNorms = false); -template void power(double prec, - FunctionTree &out, - FunctionTree &inp, +template void power(double prec, + FunctionTree &out, + FunctionTree &inp, double p, int maxIter = -1, bool absPrec = false); -template void square(double prec, - FunctionTree &out, - FunctionTree &inp, +template void square(double prec, + FunctionTree &out, + FunctionTree &inp, int maxIter = -1, bool absPrec = false); diff --git a/src/treebuilders/project.cpp b/src/treebuilders/project.cpp index c22f22ec8..65d17fd16 100644 --- a/src/treebuilders/project.cpp +++ b/src/treebuilders/project.cpp @@ -56,8 +56,8 @@ namespace mrcpp { * no coefs). * */ -template void project(double prec, FunctionTree &out, std::function &r)> func, int maxIter, bool absPrec) { - AnalyticFunction inp(func); +template void project(double prec, FunctionTree &out, std::function &r)> func, int maxIter, bool absPrec) { + AnalyticFunction inp(func); mrcpp::project(prec, out, inp, maxIter, absPrec); } @@ -81,14 +81,14 @@ template void project(double prec, FunctionTree &out, std::function void project(double prec, FunctionTree &out, RepresentableFunction &inp, int maxIter, bool absPrec) { +template void project(double prec, FunctionTree &out, RepresentableFunction &inp, int maxIter, bool absPrec) { int maxScale = out.getMRA().getMaxScale(); const auto scaling_factor = out.getMRA().getWorldBox().getScalingFactors(); - TreeBuilder builder; - WaveletAdaptor adaptor(prec, maxScale, absPrec); + TreeBuilder builder; + WaveletAdaptor adaptor(prec, maxScale, absPrec); - ProjectionCalculator calculator(inp, scaling_factor); + ProjectionCalculator calculator(inp, scaling_factor); builder.build(out, calculator, adaptor, maxIter); @@ -121,19 +121,33 @@ template void project(double prec, FunctionTree &out, RepresentableFu * no coefs). * */ -template void project(double prec, FunctionTreeVector &out, std::vector &r)>> func, int maxIter, bool absPrec) { +template void project(double prec, FunctionTreeVector &out, std::vector &r)>> func, int maxIter, bool absPrec) { if (out.size() != func.size()) MSG_ABORT("Size mismatch"); for (auto j = 0; j < D; j++) mrcpp::project(prec, get_func(out, j), func[j], maxIter, absPrec); } -template void project<1>(double prec, FunctionTree<1> &out, RepresentableFunction<1> &inp, int maxIter, bool absPrec); -template void project<2>(double prec, FunctionTree<2> &out, RepresentableFunction<2> &inp, int maxIter, bool absPrec); -template void project<3>(double prec, FunctionTree<3> &out, RepresentableFunction<3> &inp, int maxIter, bool absPrec); +template void project<1, double>(double prec, FunctionTree<1, double> &out, RepresentableFunction<1, double> &inp, int maxIter, bool absPrec); +template void project<2, double>(double prec, FunctionTree<2, double> &out, RepresentableFunction<2, double> &inp, int maxIter, bool absPrec); +template void project<3, double>(double prec, FunctionTree<3, double> &out, RepresentableFunction<3, double> &inp, int maxIter, bool absPrec); + +template void project<1, double>(double prec, FunctionTree<1, double> &out, std::function &r)> func, int maxIter, bool absPrec); +template void project<2, double>(double prec, FunctionTree<2, double> &out, std::function &r)> func, int maxIter, bool absPrec); +template void project<3, double>(double prec, FunctionTree<3, double> &out, std::function &r)> func, int maxIter, bool absPrec); +template void project<1, double>(double prec, FunctionTreeVector<1, double> &out, std::vector &r)>> inp, int maxIter, bool absPrec); +template void project<2, double>(double prec, FunctionTreeVector<2, double> &out, std::vector &r)>> inp, int maxIter, bool absPrec); +template void project<3, double>(double prec, FunctionTreeVector<3, double> &out, std::vector &r)>> inp, int maxIter, bool absPrec); + + +template void project<1, ComplexDouble>(double prec, FunctionTree<1, ComplexDouble> &out, RepresentableFunction<1, ComplexDouble> &inp, int maxIter, bool absPrec); +template void project<2, ComplexDouble>(double prec, FunctionTree<2, ComplexDouble> &out, RepresentableFunction<2, ComplexDouble> &inp, int maxIter, bool absPrec); +template void project<3, ComplexDouble>(double prec, FunctionTree<3, ComplexDouble> &out, RepresentableFunction<3, ComplexDouble> &inp, int maxIter, bool absPrec); + +template void project<1, ComplexDouble>(double prec, FunctionTree<1, ComplexDouble> &out, std::function &r)> func, int maxIter, bool absPrec); +template void project<2, ComplexDouble>(double prec, FunctionTree<2, ComplexDouble> &out, std::function &r)> func, int maxIter, bool absPrec); +template void project<3, ComplexDouble>(double prec, FunctionTree<3, ComplexDouble> &out, std::function &r)> func, int maxIter, bool absPrec); +template void project<1, ComplexDouble>(double prec, FunctionTreeVector<1, ComplexDouble> &out, std::vector &r)>> inp, int maxIter, bool absPrec); +template void project<2, ComplexDouble>(double prec, FunctionTreeVector<2, ComplexDouble> &out, std::vector &r)>> inp, int maxIter, bool absPrec); +template void project<3, ComplexDouble>(double prec, FunctionTreeVector<3, ComplexDouble> &out, std::vector &r)>> inp, int maxIter, bool absPrec); + -template void project<1>(double prec, FunctionTree<1> &out, std::function &r)> func, int maxIter, bool absPrec); -template void project<2>(double prec, FunctionTree<2> &out, std::function &r)> func, int maxIter, bool absPrec); -template void project<3>(double prec, FunctionTree<3> &out, std::function &r)> func, int maxIter, bool absPrec); -template void project<1>(double prec, FunctionTreeVector<1> &out, std::vector &r)>> inp, int maxIter, bool absPrec); -template void project<2>(double prec, FunctionTreeVector<2> &out, std::vector &r)>> inp, int maxIter, bool absPrec); -template void project<3>(double prec, FunctionTreeVector<3> &out, std::vector &r)>> inp, int maxIter, bool absPrec); } // namespace mrcpp diff --git a/src/treebuilders/project.h b/src/treebuilders/project.h index 790914a4b..f9e070ef2 100644 --- a/src/treebuilders/project.h +++ b/src/treebuilders/project.h @@ -30,7 +30,7 @@ #include namespace mrcpp { -template void project(double prec, FunctionTree &out, RepresentableFunction &inp, int maxIter = -1, bool absPrec = false); -template void project(double prec, FunctionTree &out, std::function &r)> func, int maxIter = -1, bool absPrec = false); -template void project(double prec, FunctionTreeVector &out, std::vector &r)>> func, int maxIter = -1, bool absPrec = false); +template void project(double prec, FunctionTree &out, RepresentableFunction &inp, int maxIter = -1, bool absPrec = false); +template void project(double prec, FunctionTree &out, std::function &r)> func, int maxIter = -1, bool absPrec = false); +template void project(double prec, FunctionTreeVector &out, std::vector &r)>> func, int maxIter = -1, bool absPrec = false); } // namespace mrcpp diff --git a/src/trees/FunctionNode.cpp b/src/trees/FunctionNode.cpp index c839e2b57..98858e503 100644 --- a/src/trees/FunctionNode.cpp +++ b/src/trees/FunctionNode.cpp @@ -44,7 +44,7 @@ namespace mrcpp { /** Function evaluation. * Evaluate all polynomials defined on the node. */ -template double FunctionNode::evalf(Coord r) { +template T FunctionNode::evalf(Coord r) { if (not this->hasCoefs()) MSG_ERROR("Evaluating node without coefs"); // The 1.0 appearing in the if tests comes from the period is always 1.0 @@ -57,7 +57,7 @@ template double FunctionNode::evalf(Coord r) { return getFuncChild(cIdx).evalScaling(r); } -template double FunctionNode::evalScaling(const Coord &r) const { +template T FunctionNode::evalScaling(const Coord &r) const { if (not this->hasCoefs()) MSG_ERROR("Evaluating node without coefs"); double arg[D]; @@ -72,10 +72,10 @@ template double FunctionNode::evalScaling(const Coord &r) const { const ScalingBasis &basis = this->getMWTree().getMRA().getScalingBasis(); basis.evalf(arg, val); - double result = 0.0; + T result = 0.0; //#pragma omp parallel for shared(fact) reduction(+:result) num_threads(mrcpp_get_num_threads()) for (int i = 0; i < this->getKp1_d(); i++) { - double temp = this->coefs[i]; + T temp = this->coefs[i]; for (int j = 0; j < D; j++) { int k = (i % fact[j + 1]) / fact[j]; temp *= val(k, j); @@ -92,7 +92,7 @@ template double FunctionNode::evalScaling(const Coord &r) const { * Wrapper for function integration, that requires different methods depending * on scaling type. Integrates the function represented on the node on the * full support of the node. */ -template double FunctionNode::integrate() const { +template T FunctionNode::integrate() const { if (not this->hasCoefs()) { return 0.0; } switch (this->getScalingType()) { case Legendre: @@ -115,7 +115,7 @@ template double FunctionNode::integrate() const { * s_i = int f(x)phi_i(x)dx * and since the first Legendre function is the constant 1, the first * coefficient is simply the integral of f(x). */ -template double FunctionNode::integrateLegendre() const { +template T FunctionNode::integrateLegendre() const { double n = (D * this->getScale()) / 2.0; double two_n = std::pow(2.0, -n); return two_n * this->getCoefs()[0]; @@ -126,7 +126,7 @@ template double FunctionNode::integrateLegendre() const { * Integrates the function represented on the node on the full support of the * node. A bit more involved than in the Legendre basis, as is requires some * coupling of quadrature weights. */ -template double FunctionNode::integrateInterpolating() const { +template T FunctionNode::integrateInterpolating() const { int qOrder = this->getKp1(); getQuadratureCache(qc); const VectorXd &weights = qc.getWeights(qOrder); @@ -136,7 +136,7 @@ template double FunctionNode::integrateInterpolating() const { int kp1_p[D]; for (int i = 0; i < D; i++) kp1_p[i] = math_utils::ipow(qOrder, i); - VectorXd coefs; + Eigen::Matrix coefs; this->getCoefs(coefs); for (int p = 0; p < D; p++) { @@ -152,7 +152,7 @@ template double FunctionNode::integrateInterpolating() const { } double n = (D * this->getScale()) / 2.0; double two_n = std::pow(2.0, -n); - double sum = coefs.segment(0, this->getKp1_d()).sum(); + T sum = coefs.segment(0, this->getKp1_d()).sum(); return two_n * sum; } @@ -162,27 +162,27 @@ template double FunctionNode::integrateInterpolating() const { * Integrates the function represented on the node on the full support of the * node. A bit more involved than in the Legendre basis, as is requires some * coupling of quadrature weights. */ -template double FunctionNode::integrateValues() const { +template T FunctionNode::integrateValues() const { int qOrder = this->getKp1(); getQuadratureCache(qc); const VectorXd &weights = qc.getWeights(qOrder); - VectorXd coefs; + Eigen::Matrix coefs; this->getCoefs(coefs); int ncoefs = coefs.size(); int ncoefChild = ncoefs/(1< 3) MSG_ABORT("Not Implemented") else if (D == 3) { for (int i = 0; i < qOrder; i++) { - double sumj = 0.0; + T sumj = 0.0; for (int j = 0; j < qOrder; j++) { - double sumk = 0.0; + T sumk = 0.0; for (int k = 0; k < qOrder; k++) sumk += cc[nc++] * weights[k]; sumj += sumk * weights[j]; } @@ -190,7 +190,7 @@ template double FunctionNode::integrateValues() const { } } else if (D==2) { for (int j = 0; j < qOrder; j++) { - double sumk = 0.0; + T sumk = 0.0; for (int k = 0; k < qOrder; k++) sumk += cc[nc++] * weights[k]; sum += sumk * weights[j]; } @@ -203,7 +203,7 @@ template double FunctionNode::integrateValues() const { return sum; } -template void FunctionNode::setValues(const VectorXd &vec) { +template void FunctionNode::setValues(const Matrix &vec) { this->zeroCoefs(); this->setCoefBlock(0, vec.size(), vec.data()); this->cvTransform(Backward); @@ -212,15 +212,15 @@ template void FunctionNode::setValues(const VectorXd &vec) { this->calcNorms(); } -template void FunctionNode::getValues(VectorXd &vec) { + template void FunctionNode::getValues(Matrix &vec) { if (this->isGenNode()) { - MWNode copy(*this); - vec = Eigen::VectorXd::Zero(copy.getNCoefs()); + MWNode copy(*this); + vec = Eigen::Matrix::Zero(copy.getNCoefs()); copy.mwTransform(Reconstruction); copy.cvTransform(Forward); for (int i = 0; i < this->n_coefs; i++) vec(i) = copy.getCoefs()[i]; } else { - vec = VectorXd::Zero(this->n_coefs); + vec = Eigen::Matrix::Zero(this->n_coefs); this->mwTransform(Reconstruction); this->cvTransform(Forward); for (int i = 0; i < this->n_coefs; i++) vec(i) = this->coefs[i]; @@ -232,8 +232,8 @@ template void FunctionNode::getValues(VectorXd &vec) { /** get coefficients corresponding to absolute value of function * * Leaves the original coefficients unchanged. */ -template void FunctionNode::getAbsCoefs(double *absCoefs) { - double *coefsTmp = this->coefs; +template void FunctionNode::getAbsCoefs(T *absCoefs) { + T *coefsTmp = this->coefs; for (int i = 0; i < this->n_coefs; i++) absCoefs[i] = coefsTmp[i]; // copy this->coefs = absCoefs; // swap coefs this->mwTransform(Reconstruction); @@ -244,7 +244,7 @@ template void FunctionNode::getAbsCoefs(double *absCoefs) { this->coefs = coefsTmp; // restore original array (same address) } -template void FunctionNode::createChildren(bool coefs) { +template void FunctionNode::createChildren(bool coefs) { if (this->isBranchNode()) MSG_ABORT("Node already has children"); auto &allocator = this->getFuncTree().getNodeAllocator(); @@ -258,7 +258,7 @@ template void FunctionNode::createChildren(bool coefs) { this->childSerialIx = sIdx; for (int cIdx = 0; cIdx < nChildren; cIdx++) { // construct into allocator memory - new (child_p) FunctionNode(this, cIdx); + new (child_p) FunctionNode(this, cIdx); this->children[cIdx] = child_p; child_p->serialIx = sIdx; @@ -282,7 +282,7 @@ template void FunctionNode::createChildren(bool coefs) { this->clearIsEndNode(); } -template void FunctionNode::genChildren() { +template void FunctionNode::genChildren() { if (this->isBranchNode()) MSG_ABORT("Node already has children"); auto &allocator = this->getFuncTree().getGenNodeAllocator(); @@ -296,7 +296,7 @@ template void FunctionNode::genChildren() { this->childSerialIx = sIdx; for (int cIdx = 0; cIdx < nChildren; cIdx++) { // construct into allocator memory - new (child_p) FunctionNode(this, cIdx); + new (child_p) FunctionNode(this, cIdx); this->children[cIdx] = child_p; child_p->serialIx = sIdx; @@ -319,7 +319,7 @@ template void FunctionNode::genChildren() { this->setIsBranchNode(); } -template void FunctionNode::genParent() { +template void FunctionNode::genParent() { if (this->parent != nullptr) MSG_ABORT("Node is not an orphan"); auto &allocator = this->getFuncTree().getNodeAllocator(); @@ -332,7 +332,7 @@ template void FunctionNode::genParent() { this->parentSerialIx = sIdx; // construct into allocator memory - new (parent_p) FunctionNode(this->tree, this->getNodeIndex().parent()); + new (parent_p) FunctionNode(this->tree, this->getNodeIndex().parent()); this->parent = parent_p; @@ -351,12 +351,12 @@ template void FunctionNode::genParent() { this->getMWTree().incrementNodeCount(parent_p->getScale()); } -template void FunctionNode::deleteChildren() { - MWNode::deleteChildren(); +template void FunctionNode::deleteChildren() { + MWNode::deleteChildren(); this->setIsEndNode(); } -template void FunctionNode::dealloc() { +template void FunctionNode::dealloc() { int sIdx = this->serialIx; this->serialIx = -1; this->parentSerialIx = -1; @@ -376,8 +376,8 @@ template void FunctionNode::dealloc() { /** Update the coefficients of the node by a mw transform of the scaling * coefficients of the children. Option to overwrite or add up existing * coefficients. Specialized for D=3 below. */ -template void FunctionNode::reCompress() { - MWNode::reCompress(); +template void FunctionNode::reCompress() { + MWNode::reCompress(); } template <> void FunctionNode<3>::reCompress() { @@ -406,18 +406,18 @@ template <> void FunctionNode<3>::reCompress() { * the node on the full support of the nodes. The scaling basis is fully * orthonormal, and the inner product is simply the dot product of the * coefficient vectors. Assumes the nodes have identical support. */ -template double dot_scaling(const FunctionNode &bra, const FunctionNode &ket) { +template T dot_scaling(const FunctionNode &bra, const FunctionNode &ket) { assert(bra.hasCoefs()); assert(ket.hasCoefs()); - const double *a = bra.getCoefs(); - const double *b = ket.getCoefs(); + const T *a = bra.getCoefs(); + const T *b = ket.getCoefs(); int size = bra.getKp1_d(); #ifdef HAVE_BLAS return cblas_ddot(size, a, 1, b, 1); #else - double result = 0.0; + T result = 0.0; for (int i = 0; i < size; i++) result += a[i] * b[i]; return result; #endif @@ -429,35 +429,46 @@ template double dot_scaling(const FunctionNode &bra, const FunctionNo * the node on the full support of the nodes. The wavelet basis is fully * orthonormal, and the inner product is simply the dot product of the * coefficient vectors. Assumes the nodes have identical support. */ -template double dot_wavelet(const FunctionNode &bra, const FunctionNode &ket) { +template T dot_wavelet(const FunctionNode &bra, const FunctionNode &ket) { if (bra.isGenNode() or ket.isGenNode()) return 0.0; assert(bra.hasCoefs()); assert(ket.hasCoefs()); - const double *a = bra.getCoefs(); - const double *b = ket.getCoefs(); + const T *a = bra.getCoefs(); + const T *b = ket.getCoefs(); int start = bra.getKp1_d(); int size = (bra.getTDim() - 1) * start; #ifdef HAVE_BLAS return cblas_ddot(size, &a[start], 1, &b[start], 1); #else - double result = 0.0; + T result = 0.0; for (int i = 0; i < size; i++) result += a[start + i] * b[start + i]; return result; #endif } -template double dot_scaling(const FunctionNode<1> &bra, const FunctionNode<1> &ket); -template double dot_scaling(const FunctionNode<2> &bra, const FunctionNode<2> &ket); -template double dot_scaling(const FunctionNode<3> &bra, const FunctionNode<3> &ket); -template double dot_wavelet(const FunctionNode<1> &bra, const FunctionNode<1> &ket); -template double dot_wavelet(const FunctionNode<2> &bra, const FunctionNode<2> &ket); -template double dot_wavelet(const FunctionNode<3> &bra, const FunctionNode<3> &ket); - -template class FunctionNode<1>; -template class FunctionNode<2>; -template class FunctionNode<3>; +template double dot_scaling(const FunctionNode<1, double> &bra, const FunctionNode<1, double> &ket); +template double dot_scaling(const FunctionNode<2, double> &bra, const FunctionNode<2, double> &ket); +template double dot_scaling(const FunctionNode<3, double> &bra, const FunctionNode<3, double> &ket); +template double dot_wavelet(const FunctionNode<1, double> &bra, const FunctionNode<1, double> &ket); +template double dot_wavelet(const FunctionNode<2, double> &bra, const FunctionNode<2, double> &ket); +template double dot_wavelet(const FunctionNode<3, double> &bra, const FunctionNode<3, double> &ket); + +template class FunctionNode<1, double>; +template class FunctionNode<2, double>; +template class FunctionNode<3, double>; + +template class FunctionNode<1, ComplexDouble>; +template class FunctionNode<2, ComplexDouble>; +template class FunctionNode<3, ComplexDouble>; + +template ComplexDouble dot_scaling(const FunctionNode<1, ComplexDouble> &bra, const FunctionNode<1, ComplexDouble> &ket); +template ComplexDouble dot_scaling(const FunctionNode<2, ComplexDouble> &bra, const FunctionNode<2, ComplexDouble> &ket); +template ComplexDouble dot_scaling(const FunctionNode<3, ComplexDouble> &bra, const FunctionNode<3, ComplexDouble> &ket); +template ComplexDouble dot_wavelet(const FunctionNode<1, ComplexDouble> &bra, const FunctionNode<1, ComplexDouble> &ket); +template ComplexDouble dot_wavelet(const FunctionNode<2, ComplexDouble> &bra, const FunctionNode<2, ComplexDouble> &ket); +template ComplexDouble dot_wavelet(const FunctionNode<3, ComplexDouble> &bra, const FunctionNode<3, ComplexDouble> &ket); } // namespace mrcpp diff --git a/src/trees/FunctionNode.h b/src/trees/FunctionNode.h index 97a3d74d3..14c44fb7e 100644 --- a/src/trees/FunctionNode.h +++ b/src/trees/FunctionNode.h @@ -32,55 +32,55 @@ namespace mrcpp { -template class FunctionNode final : public MWNode { +template class FunctionNode final : public MWNode { public: - FunctionTree &getFuncTree() { return static_cast &>(*this->tree); } - FunctionNode &getFuncParent() { return static_cast &>(*this->parent); } - FunctionNode &getFuncChild(int i) { return static_cast &>(*this->children[i]); } + FunctionTree &getFuncTree() { return static_cast &>(*this->tree); } + FunctionNode &getFuncParent() { return static_cast &>(*this->parent); } + FunctionNode &getFuncChild(int i) { return static_cast &>(*this->children[i]); } - const FunctionTree &getFuncTree() const { return static_cast &>(*this->tree); } - const FunctionNode &getFuncParent() const { return static_cast &>(*this->parent); } - const FunctionNode &getFuncChild(int i) const { return static_cast &>(*this->children[i]); } + const FunctionTree &getFuncTree() const { return static_cast &>(*this->tree); } + const FunctionNode &getFuncParent() const { return static_cast &>(*this->parent); } + const FunctionNode &getFuncChild(int i) const { return static_cast &>(*this->children[i]); } void createChildren(bool coefs) override; void genChildren() override; void genParent() override; void deleteChildren() override; - double integrate() const; + T integrate() const; - void setValues(const Eigen::VectorXd &vec); - void getValues(Eigen::VectorXd &vec); - void getAbsCoefs(double *absCoefs); + void setValues(const Eigen::Matrix &vec); + void getValues(Eigen::Matrix &vec); + void getAbsCoefs(T *absCoefs); - friend class FunctionTree; - friend class NodeAllocator; + friend class FunctionTree; + friend class NodeAllocator; protected: FunctionNode() - : MWNode() {} - FunctionNode(MWTree *tree, int rIdx) - : MWNode(tree, rIdx) {} - FunctionNode(MWNode *parent, int cIdx) - : MWNode(parent, cIdx) {} - FunctionNode(MWTree *tree, const NodeIndex &idx) - : MWNode(tree, idx) {} - FunctionNode(const FunctionNode &node) = delete; - FunctionNode &operator=(const FunctionNode &node) = delete; + : MWNode() {} + FunctionNode(MWTree *tree, int rIdx) + : MWNode(tree, rIdx) {} + FunctionNode(MWNode *parent, int cIdx) + : MWNode(parent, cIdx) {} + FunctionNode(MWTree *tree, const NodeIndex &idx) + : MWNode(tree, idx) {} + FunctionNode(const FunctionNode &node) = delete; + FunctionNode &operator=(const FunctionNode &node) = delete; ~FunctionNode() = default; - double evalf(Coord r); - double evalScaling(const Coord &r) const; + T evalf(Coord r); + T evalScaling(const Coord &r) const; void dealloc() override; void reCompress() override; - double integrateLegendre() const; - double integrateInterpolating() const; - double integrateValues() const; + T integrateLegendre() const; + T integrateInterpolating() const; + T integrateValues() const; }; -template double dot_scaling(const FunctionNode &bra, const FunctionNode &ket); -template double dot_wavelet(const FunctionNode &bra, const FunctionNode &ket); +template T dot_scaling(const FunctionNode &bra, const FunctionNode &ket); +template T dot_wavelet(const FunctionNode &bra, const FunctionNode &ket); } // namespace mrcpp diff --git a/src/trees/FunctionTree.cpp b/src/trees/FunctionTree.cpp index 47614a933..1c91cf2cd 100644 --- a/src/trees/FunctionTree.cpp +++ b/src/trees/FunctionTree.cpp @@ -50,20 +50,20 @@ namespace mrcpp { * If a shared memory pointer is provided the tree will be allocated in this * shared memory window, otherwise it will be local to each MPI process. */ -template -FunctionTree::FunctionTree(const MultiResolutionAnalysis &mra, SharedMemory *sh_mem, const std::string &name) - : MWTree(mra, name) - , RepresentableFunction(mra.getWorldBox().getLowerBounds().data(), mra.getWorldBox().getUpperBounds().data()) { +template +FunctionTree::FunctionTree(const MultiResolutionAnalysis &mra, SharedMemory *sh_mem, const std::string &name) + : MWTree(mra, name) + , RepresentableFunction(mra.getWorldBox().getLowerBounds().data(), mra.getWorldBox().getUpperBounds().data()) { int nodesPerChunk = 2048; // Large chunks are required for not leading to memory fragmentation (32 MB on "Betzy" 2023) int coefsGenNodes = this->getKp1_d(); int coefsRegNodes = this->getTDim() * this->getKp1_d(); - this->nodeAllocator_p = std::make_unique>(this, sh_mem, coefsRegNodes, nodesPerChunk); - this->genNodeAllocator_p = std::make_unique>(this, nullptr, coefsGenNodes, nodesPerChunk); + this->nodeAllocator_p = std::make_unique>(this, sh_mem, coefsRegNodes, nodesPerChunk); + this->genNodeAllocator_p = std::make_unique>(this, nullptr, coefsGenNodes, nodesPerChunk); this->allocRootNodes(); this->resetEndNodeTable(); } -template void FunctionTree::allocRootNodes() { +template void FunctionTree::allocRootNodes() { auto &allocator = this->getNodeAllocator(); auto &rootbox = this->getRootBox(); @@ -74,10 +74,10 @@ template void FunctionTree::allocRootNodes() { auto *coef_p = allocator.getCoef_p(sIdx); auto *root_p = allocator.getNode_p(sIdx); - MWNode **roots = rootbox.getNodes(); + MWNode **roots = rootbox.getNodes(); for (int rIdx = 0; rIdx < nRoots; rIdx++) { // construct into allocator memory - new (root_p) FunctionNode(this, rIdx); + new (root_p) FunctionNode(this, rIdx); roots[rIdx] = root_p; root_p->serialIx = sIdx; @@ -101,14 +101,14 @@ template void FunctionTree::allocRootNodes() { } // FunctionTree destructor -template FunctionTree::~FunctionTree() { +template FunctionTree::~FunctionTree() { this->deleteRootNodes(); } /** @brief Write the tree structure to disk, for later use * @param[in] file: File name, will get ".tree" extension */ -template void FunctionTree::saveTree(const std::string &file) { +template void FunctionTree::saveTree(const std::string &file) { Timer t1; this->deleteGenerated(); auto &allocator = this->getNodeAllocator(); @@ -137,7 +137,7 @@ template void FunctionTree::saveTree(const std::string &file) { * @param[in] file: File name, will get ".tree" extension * @note This tree must have the exact same MRA the one that was saved */ -template void FunctionTree::loadTree(const std::string &file) { +template void FunctionTree::loadTree(const std::string &file) { Timer t1; std::stringstream fname; fname << file << ".tree"; @@ -168,11 +168,11 @@ template void FunctionTree::loadTree(const std::string &file) { } /** @returns Integral of the function over the entire computational domain */ -template double FunctionTree::integrate() const { +template T FunctionTree::integrate() const { - double result = 0.0; + T result = 0.0; for (int i = 0; i < this->rootBox.size(); i++) { - const FunctionNode &fNode = getRootFuncNode(i); + const FunctionNode &fNode = getRootFuncNode(i); result += fNode.integrate(); } @@ -188,7 +188,7 @@ template double FunctionTree::integrate() const { /** @returns Integral of a representable function over the grid given by the tree */ -template <> double FunctionTree<3>::integrateEndNodes(RepresentableFunction_M &f) { + template <> double FunctionTree<3, double>::integrateEndNodes(RepresentableFunction_M &f) { //traverse tree, and treat end nodes only std::vector *> stack; // node from this for (int i = 0; i < this->getRootBox().size(); i++) stack.push_back(&(this->getRootFuncNode(i))); @@ -236,7 +236,7 @@ template <> double FunctionTree<3>::integrateEndNodes(RepresentableFunction_M &f * the MW grid by one level before evaluating, using * `mrcpp::refine_grid(tree, 1)` */ -template double FunctionTree::evalf(const Coord &r) const { +template T FunctionTree::evalf(const Coord &r) const { // Handle potential scaling const auto scaling_factor = this->getMRA().getWorldBox().getScalingFactors(); auto arg = r; @@ -249,8 +249,8 @@ template double FunctionTree::evalf(const Coord &r) const { // Function is zero outside the domain for non-periodic functions if (this->outOfBounds(arg) and not this->getRootBox().isPeriodic()) return 0.0; - const MWNode &mw_node = this->getNodeOrEndNode(arg); - auto &f_node = static_cast &>(mw_node); + const MWNode &mw_node = this->getNodeOrEndNode(arg); + auto &f_node = static_cast &>(mw_node); auto result = f_node.evalScaling(arg); // Adjust for scaling factor included in basis @@ -270,7 +270,7 @@ template double FunctionTree::evalf(const Coord &r) const { * need fast evaluation, use refine_grid(tree, 1) first, and then * evalf. */ -template double FunctionTree::evalf_precise(const Coord &r) { +template T FunctionTree::evalf_precise(const Coord &r) { // Handle potential scaling const auto scaling_factor = this->getMRA().getWorldBox().getScalingFactors(); auto arg = r; @@ -283,8 +283,8 @@ template double FunctionTree::evalf_precise(const Coord &r) { // Function is zero outside the domain for non-periodic functions if (this->outOfBounds(arg) and not this->getRootBox().isPeriodic()) return 0.0; - MWNode &mw_node = this->getNodeOrEndNode(arg); - auto &f_node = static_cast &>(mw_node); + MWNode &mw_node = this->getNodeOrEndNode(arg); + auto &f_node = static_cast &>(mw_node); auto result = f_node.evalf(arg); this->deleteGenerated(); @@ -301,7 +301,7 @@ template double FunctionTree::evalf_precise(const Coord &r) { * squared, no grid refinement. * */ -template void FunctionTree::square() { +template void FunctionTree::square() { if (this->getNGenNodes() != 0) MSG_ABORT("GenNodes not cleared"); #pragma omp parallel num_threads(mrcpp_get_num_threads()) @@ -310,10 +310,10 @@ template void FunctionTree::square() { int nCoefs = this->getTDim() * this->getKp1_d(); #pragma omp for schedule(guided) for (int n = 0; n < nNodes; n++) { - MWNode &node = *this->endNodeTable[n]; + MWNode &node = *this->endNodeTable[n]; node.mwTransform(Reconstruction); node.cvTransform(Forward); - double *coefs = node.getCoefs(); + T *coefs = node.getCoefs(); for (int i = 0; i < nCoefs; i++) { coefs[i] *= coefs[i]; } node.cvTransform(Backward); node.mwTransform(Compression); @@ -332,7 +332,7 @@ template void FunctionTree::square() { * to the given power, no grid refinement. * */ -template void FunctionTree::power(double p) { +template void FunctionTree::power(double p) { if (this->getNGenNodes() != 0) MSG_ABORT("GenNodes not cleared"); #pragma omp parallel num_threads(mrcpp_get_num_threads()) @@ -341,10 +341,10 @@ template void FunctionTree::power(double p) { int nCoefs = this->getTDim() * this->getKp1_d(); #pragma omp for schedule(guided) for (int n = 0; n < nNodes; n++) { - MWNode &node = *this->endNodeTable[n]; + MWNode &node = *this->endNodeTable[n]; node.mwTransform(Reconstruction); node.cvTransform(Forward); - double *coefs = node.getCoefs(); + T *coefs = node.getCoefs(); for (int i = 0; i < nCoefs; i++) { coefs[i] = std::pow(coefs[i], p); } node.cvTransform(Backward); node.mwTransform(Compression); @@ -363,7 +363,7 @@ template void FunctionTree::power(double p) { * in-place multiplied by the given coefficient, no grid refinement. * */ -template void FunctionTree::rescale(double c) { +template void FunctionTree::rescale(T c) { if (this->getNGenNodes() != 0) MSG_ABORT("GenNodes not cleared"); #pragma omp parallel firstprivate(c) num_threads(mrcpp_get_num_threads()) { @@ -371,9 +371,9 @@ template void FunctionTree::rescale(double c) { int nCoefs = this->getTDim() * this->getKp1_d(); #pragma omp for schedule(guided) for (int i = 0; i < nNodes; i++) { - MWNode &node = *this->endNodeTable[i]; + MWNode &node = *this->endNodeTable[i]; if (not node.hasCoefs()) MSG_ABORT("No coefs"); - double *coefs = node.getCoefs(); + T *coefs = node.getCoefs(); for (int j = 0; j < nCoefs; j++) { coefs[j] *= c; } node.calcNorms(); } @@ -383,7 +383,7 @@ template void FunctionTree::rescale(double c) { } /** @brief In-place rescaling by a function norm \f$ ||f||^{-1} \f$, fixed grid */ -template void FunctionTree::normalize() { +template void FunctionTree::normalize() { if (this->getNGenNodes() != 0) MSG_ABORT("GenNodes not cleared"); double sq_norm = this->getSquareNorm(); if (sq_norm < 0.0) MSG_ERROR("Normalizing uninitialized function"); @@ -399,7 +399,7 @@ template void FunctionTree::normalize() { * the function, i.e. no further grid refinement. * */ -template void FunctionTree::add(double c, FunctionTree &inp) { +template void FunctionTree::add(T c, FunctionTree &inp) { if (this->getMRA() != inp.getMRA()) MSG_ABORT("Incompatible MRA"); if (this->getNGenNodes() != 0) MSG_ABORT("GenNodes not cleared"); #pragma omp parallel firstprivate(c) shared(inp) num_threads(mrcpp_get_num_threads()) @@ -407,10 +407,10 @@ template void FunctionTree::add(double c, FunctionTree &inp) { int nNodes = this->getNEndNodes(); #pragma omp for schedule(guided) for (int n = 0; n < nNodes; n++) { - MWNode &out_node = *this->endNodeTable[n]; - MWNode &inp_node = inp.getNode(out_node.getNodeIndex()); - double *out_coefs = out_node.getCoefs(); - const double *inp_coefs = inp_node.getCoefs(); + MWNode &out_node = *this->endNodeTable[n]; + MWNode &inp_node = inp.getNode(out_node.getNodeIndex()); + T *out_coefs = out_node.getCoefs(); + const T *inp_coefs = inp_node.getCoefs(); for (int i = 0; i < inp_node.getNCoefs(); i++) { out_coefs[i] += c * inp_coefs[i]; } out_node.calcNorms(); } @@ -428,22 +428,22 @@ template void FunctionTree::add(double c, FunctionTree &inp) { * function, i.e. no further grid refinement. * */ -template void FunctionTree::absadd(double c, FunctionTree &inp) { +template void FunctionTree::absadd (T c, FunctionTree &inp) { if (this->getNGenNodes() != 0) MSG_ABORT("GenNodes not cleared"); #pragma omp parallel firstprivate(c) shared(inp) num_threads(mrcpp_get_num_threads()) { int nNodes = this->getNEndNodes(); #pragma omp for schedule(guided) for (int n = 0; n < nNodes; n++) { - MWNode &out_node = *this->endNodeTable[n]; - MWNode inp_node = inp.getNode(out_node.getNodeIndex()); // Full copy + MWNode &out_node = *this->endNodeTable[n]; + MWNode inp_node = inp.getNode(out_node.getNodeIndex()); // Full copy out_node.mwTransform(Reconstruction); out_node.cvTransform(Forward); inp_node.mwTransform(Reconstruction); inp_node.cvTransform(Forward); - double *out_coefs = out_node.getCoefs(); - const double *inp_coefs = inp_node.getCoefs(); - for (int i = 0; i < inp_node.getNCoefs(); i++) { out_coefs[i] = abs(out_coefs[i]) + c * abs(inp_coefs[i]); } + T *out_coefs = out_node.getCoefs(); + const T *inp_coefs = inp_node.getCoefs(); + for (int i = 0; i < inp_node.getNCoefs(); i++) { out_coefs[i] = std::norm(out_coefs[i]) + std::norm(c * inp_coefs[i]); } out_node.cvTransform(Backward); out_node.mwTransform(Compression); out_node.calcNorms(); @@ -463,7 +463,7 @@ template void FunctionTree::absadd(double c, FunctionTree &inp) { * of the function, i.e. no further grid refinement. * */ -template void FunctionTree::multiply(double c, FunctionTree &inp) { +template void FunctionTree::multiply(T c, FunctionTree &inp) { if (this->getMRA() != inp.getMRA()) MSG_ABORT("Incompatible MRA"); if (this->getNGenNodes() != 0) MSG_ABORT("GenNodes not cleared"); #pragma omp parallel firstprivate(c) shared(inp) num_threads(mrcpp_get_num_threads()) @@ -471,14 +471,14 @@ template void FunctionTree::multiply(double c, FunctionTree &inp) int nNodes = this->getNEndNodes(); #pragma omp for schedule(guided) for (int n = 0; n < nNodes; n++) { - MWNode &out_node = *this->endNodeTable[n]; - MWNode inp_node = inp.getNode(out_node.getNodeIndex()); // Full copy + MWNode &out_node = *this->endNodeTable[n]; + MWNode inp_node = inp.getNode(out_node.getNodeIndex()); // Full copy out_node.mwTransform(Reconstruction); out_node.cvTransform(Forward); inp_node.mwTransform(Reconstruction); inp_node.cvTransform(Forward); - double *out_coefs = out_node.getCoefs(); - const double *inp_coefs = inp_node.getCoefs(); + T *out_coefs = out_node.getCoefs(); + const T *inp_coefs = inp_node.getCoefs(); for (int i = 0; i < inp_node.getNCoefs(); i++) { out_coefs[i] *= c * inp_coefs[i]; } out_node.cvTransform(Backward); out_node.mwTransform(Compression); @@ -498,16 +498,16 @@ template void FunctionTree::multiply(double c, FunctionTree &inp) * of the function, i.e. no further grid refinement. * */ -template void FunctionTree::map(FMap fmap) { +template void FunctionTree::map(FMap fmap) { if (this->getNGenNodes() != 0) MSG_ABORT("GenNodes not cleared"); { int nNodes = this->getNEndNodes(); #pragma omp parallel for schedule(guided) num_threads(mrcpp_get_num_threads()) for (int n = 0; n < nNodes; n++) { - MWNode &node = *this->endNodeTable[n]; + MWNode &node = *this->endNodeTable[n]; node.mwTransform(Reconstruction); node.cvTransform(Forward); - double *coefs = node.getCoefs(); + T *coefs = node.getCoefs(); for (int i = 0; i < node.getNCoefs(); i++) { coefs[i] = fmap(coefs[i]); } node.cvTransform(Backward); node.mwTransform(Compression); @@ -518,29 +518,29 @@ template void FunctionTree::map(FMap fmap) { this->calcSquareNorm(); } -template void FunctionTree::getEndValues(VectorXd &data) { +template void FunctionTree::getEndValues(Eigen::Matrix &data) { if (this->getNGenNodes() != 0) MSG_ABORT("GenNodes not cleared"); int nNodes = this->getNEndNodes(); int nCoefs = this->getTDim() * this->getKp1_d(); data = VectorXd::Zero(nNodes * nCoefs); for (int n = 0; n < nNodes; n++) { - MWNode &node = getEndFuncNode(n); + MWNode &node = getEndFuncNode(n); node.mwTransform(Reconstruction); node.cvTransform(Forward); - const double *c = node.getCoefs(); + const T *c = node.getCoefs(); for (int i = 0; i < nCoefs; i++) { data(n * nCoefs + i) = c[i]; } node.cvTransform(Backward); node.mwTransform(Compression); } } -template void FunctionTree::setEndValues(VectorXd &data) { +template void FunctionTree::setEndValues(Eigen::Matrix &data) { if (this->getNGenNodes() != 0) MSG_ABORT("GenNodes not cleared"); int nNodes = this->getNEndNodes(); int nCoefs = this->getTDim() * this->getKp1_d(); for (int i = 0; i < nNodes; i++) { - MWNode &node = getEndFuncNode(i); - const double *c = data.segment(i * nCoefs, nCoefs).data(); + MWNode &node = getEndFuncNode(i); + const T *c = data.segment(i * nCoefs, nCoefs).data(); node.setCoefBlock(0, nCoefs, c); node.cvTransform(Backward); node.mwTransform(Compression); @@ -551,10 +551,10 @@ template void FunctionTree::setEndValues(VectorXd &data) { this->calcSquareNorm(); } -template std::ostream &FunctionTree::print(std::ostream &o) const { +template std::ostream &FunctionTree::print(std::ostream &o) const { o << std::endl << "*FunctionTree: " << this->name << std::endl; o << " genNodes: " << getNGenNodes() << std::endl; - return MWTree::print(o); + return MWTree::print(o); } /** @brief Reduce the precision of the tree by deleting nodes @@ -571,9 +571,9 @@ template std::ostream &FunctionTree::print(std::ostream &o) const { * \f$ ||w|| < 2^{-sn/2} ||f|| \epsilon \f$. In principal, `s` should be equal * to the dimension; in practice, it is set to `s=1`. */ -template int FunctionTree::crop(double prec, double splitFac, bool absPrec) { +template int FunctionTree::crop(double prec, double splitFac, bool absPrec) { for (int i = 0; i < this->rootBox.size(); i++) { - MWNode &root = this->getRootMWNode(i); + MWNode &root = this->getRootMWNode(i); root.crop(prec, splitFac, absPrec); } int nChunks = this->getNodeAllocator().compress(); @@ -586,22 +586,22 @@ template int FunctionTree::crop(double prec, double splitFac, bool ab * Also returns an array with the corresponding indices defined as the * values of serialIx in refTree, and an array with the indices of the parent. * Set index -1 for nodes that are not present in refTree */ -template -void FunctionTree::makeCoeffVector(std::vector &coefs, +template +void FunctionTree::makeCoeffVector(std::vector &coefs, std::vector &indices, std::vector &parent_indices, std::vector &scalefac, int &max_index, - MWTree &refTree, - std::vector *> *refNodes) { + MWTree &refTree, + std::vector *> *refNodes) { coefs.clear(); indices.clear(); parent_indices.clear(); max_index = 0; int sizecoeff = (1 << refTree.getDim()) * refTree.getKp1_d(); int sizecoeffW = ((1 << refTree.getDim()) - 1) * refTree.getKp1_d(); - std::vector *> refstack; // nodes from refTree - std::vector *> thisstack; // nodes from this Tree + std::vector *> refstack; // nodes from refTree + std::vector *> thisstack; // nodes from this Tree for (int rIdx = 0; rIdx < this->getRootBox().size(); rIdx++) { refstack.push_back(refTree.getRootBox().getNodes()[rIdx]); thisstack.push_back(this->getRootBox().getNodes()[rIdx]); @@ -609,8 +609,8 @@ void FunctionTree::makeCoeffVector(std::vector &coefs, int stack_p = 0; while (thisstack.size() > stack_p) { // refNode and thisNode are the same node in space, but on different trees - MWNode *thisNode = thisstack[stack_p]; - MWNode *refNode = refstack[stack_p++]; + MWNode *thisNode = thisstack[stack_p]; + MWNode *refNode = refstack[stack_p++]; coefs.push_back(thisNode->getCoefs()); if (refNodes != nullptr) refNodes->push_back(refNode); if (refNode != nullptr) { @@ -640,26 +640,26 @@ void FunctionTree::makeCoeffVector(std::vector &coefs, * reference tree and a list of coefficients. * It is the reference tree (refTree) which is traversed, but one does not descend * into children if the norm of the tree is smaller than absPrec. */ -template void FunctionTree::makeTreefromCoeff(MWTree &refTree, std::vector coefpVec, std::map &ix2coef, double absPrec, const std::string &mode) { - std::vector *> stack; - std::map *> ix2node; // gives the nodes in this tree for a given ix +template void FunctionTree::makeTreefromCoeff(MWTree &refTree, std::vector coefpVec, std::map &ix2coef, double absPrec, const std::string &mode) { + std::vector *> stack; + std::map *> ix2node; // gives the nodes in this tree for a given ix int sizecoef = (1 << this->getDim()) * this->getKp1_d(); int sizecoefW = ((1 << this->getDim()) - 1) * this->getKp1_d(); this->squareNorm = 0.0; this->clearEndNodeTable(); for (int rIdx = 0; rIdx < refTree.getRootBox().size(); rIdx++) { - MWNode *refNode = refTree.getRootBox().getNodes()[rIdx]; + MWNode *refNode = refTree.getRootBox().getNodes()[rIdx]; stack.push_back(refNode); int ix = ix2coef[refNode->getSerialIx()]; ix2node[ix] = this->getRootBox().getNodes()[rIdx]; } while (stack.size() > 0) { - MWNode *refNode = stack.back(); // node in the reference tree refTree + MWNode *refNode = stack.back(); // node in the reference tree refTree stack.pop_back(); assert(ix2coef.count(refNode->getSerialIx()) > 0); int ix = ix2coef[refNode->getSerialIx()]; - MWNode *node = ix2node[ix]; // corresponding node in this tree + MWNode *node = ix2node[ix]; // corresponding node in this tree // copy coefficients into this tree int size = sizecoefW; if (refNode->isRootNode() or mode == "copy") { @@ -701,8 +701,8 @@ template void FunctionTree::makeTreefromCoeff(MWTree &refTree, std } else if ((absPrec < 0 or tree_utils::split_check(*node, absPrec, 1.0, true)) and refNode->getNChildren() > 0) { // include children in tree node->createChildren(true); - double *inp = node->getCoefs(); - double *out = node->getMWChild(0).getCoefs(); + T *inp = node->getCoefs(); + T *out = node->getMWChild(0).getCoefs(); tree_utils::mw_transform(*this, inp, out, false, sizecoef, true); // make the scaling part for (int i = 0; i < refNode->getNChildren(); i++) { stack.push_back(refNode->children[i]); // means we continue to traverse the reference tree @@ -717,9 +717,9 @@ template void FunctionTree::makeTreefromCoeff(MWTree &refTree, std } /** Traverse tree using DFS and append same nodes as another tree, without coefficients */ -template void FunctionTree::appendTreeNoCoeff(MWTree &inTree) { - std::vector *> instack; // node from inTree - std::vector *> thisstack; // node from this Tree +template void FunctionTree::appendTreeNoCoeff(MWTree &inTree) { + std::vector *> instack; // node from inTree + std::vector *> thisstack; // node from this Tree this->clearEndNodeTable(); for (int rIdx = 0; rIdx < inTree.getRootBox().size(); rIdx++) { instack.push_back(inTree.getRootBox().getNodes()[rIdx]); @@ -727,9 +727,9 @@ template void FunctionTree::appendTreeNoCoeff(MWTree &inTree) { } while (thisstack.size() > 0) { // inNode and thisNode are the same node in space, but on different trees - MWNode *thisNode = thisstack.back(); + MWNode *thisNode = thisstack.back(); thisstack.pop_back(); - MWNode *inNode = instack.back(); + MWNode *inNode = instack.back(); instack.pop_back(); if (inNode->getNChildren() > 0) { thisNode->clearIsEndNode(); @@ -741,10 +741,10 @@ template void FunctionTree::appendTreeNoCoeff(MWTree &inTree) { } else { // construct EndNodeTable for "This", starting from this branch // This could be done more efficiently, if it proves to be time consuming - std::vector *> branchstack; // local stack starting from this branch + std::vector *> branchstack; // local stack starting from this branch branchstack.push_back(thisNode); while (branchstack.size() > 0) { - MWNode *branchNode = branchstack.back(); + MWNode *branchNode = branchstack.back(); branchstack.pop_back(); if (branchNode->getNChildren() > 0) { for (int i = 0; i < branchNode->getNChildren(); i++) { branchstack.push_back(branchNode->children[i]); } @@ -755,24 +755,24 @@ template void FunctionTree::appendTreeNoCoeff(MWTree &inTree) { } } -template void FunctionTree::deleteGenerated() { +template void FunctionTree::deleteGenerated() { for (int n = 0; n < this->getNEndNodes(); n++) this->getEndMWNode(n).deleteGenerated(); } -template void FunctionTree::deleteGeneratedParents() { +template void FunctionTree::deleteGeneratedParents() { for (int n = 0; n < this->getRootBox().size(); n++) this->getRootMWNode(n).deleteParent(); } -template <> int FunctionTree<3>::saveNodesAndRmCoeff() { +template <> int FunctionTree<3, double>::saveNodesAndRmCoeff() { if (this->isLocal) MSG_INFO("Tree is already in local representation"); NodesCoeff = new BankAccount; // NB: must be a collective call! int stack_p = 0; if (mpi::wrk_rank == 0) { int sizecoeff = (1 << 3) * this->getKp1_d(); - std::vector *> stack; // nodes from this Tree + std::vector *> stack; // nodes from this Tree for (int rIdx = 0; rIdx < this->getRootBox().size(); rIdx++) { stack.push_back(this->getRootBox().getNodes()[rIdx]); } while (stack.size() > stack_p) { - MWNode<3> *Node = stack[stack_p++]; + MWNode<3, double> *Node = stack[stack_p++]; int id = 0; NodesCoeff->put_data(Node->getNodeIndex(), sizecoeff, Node->getCoefs()); for (int i = 0; i < Node->getNChildren(); i++) { stack.push_back(Node->children[i]); } @@ -785,8 +785,35 @@ template <> int FunctionTree<3>::saveNodesAndRmCoeff() { return this->NodeIndex2serialIx.size(); } -template class FunctionTree<1>; -template class FunctionTree<2>; -template class FunctionTree<3>; +template <> int FunctionTree<3, ComplexDouble>::saveNodesAndRmCoeff() { + if (this->isLocal) MSG_INFO("Tree is already in local representation"); + NodesCoeff = new BankAccount; // NB: must be a collective call! + int stack_p = 0; + if (mpi::wrk_rank == 0) { + int sizecoeff = (1 << 3) * this->getKp1_d(); + sizecoeff *= 2; // double->ComplexDouble. Saved as twice as many doubles + std::vector *> stack; // nodes from this Tree + for (int rIdx = 0; rIdx < this->getRootBox().size(); rIdx++) { stack.push_back(this->getRootBox().getNodes()[rIdx]); } + while (stack.size() > stack_p) { + MWNode<3, ComplexDouble> *Node = stack[stack_p++]; + int id = 0; + NodesCoeff->put_data(Node->getNodeIndex(), sizecoeff, Node->getCoefs()); + for (int i = 0; i < Node->getNChildren(); i++) { stack.push_back(Node->children[i]); } + } + } + this->nodeAllocator_p->deallocAllCoeff(); + mpi::broadcast_Tree_noCoeff(*this, mpi::comm_wrk); + this->isLocal = true; + assert(this->NodeIndex2serialIx.size() == getNNodes()); + return this->NodeIndex2serialIx.size(); +} + +template class FunctionTree<1, double>; +template class FunctionTree<2, double>; +template class FunctionTree<3, double>; + +template class FunctionTree<1, ComplexDouble>; +template class FunctionTree<2, ComplexDouble>; +template class FunctionTree<3, ComplexDouble>; } // namespace mrcpp diff --git a/src/trees/FunctionTree.h b/src/trees/FunctionTree.h index 0be9563ea..c9e8ecde8 100644 --- a/src/trees/FunctionTree.h +++ b/src/trees/FunctionTree.h @@ -52,24 +52,24 @@ namespace mrcpp { * uninitialized, and its square norm will be negative (minus one). */ -template class FunctionTree final : public MWTree, public RepresentableFunction { +template class FunctionTree final : public MWTree, public RepresentableFunction { public: FunctionTree(const MultiResolutionAnalysis &mra, const std::string &name) : FunctionTree(mra, nullptr, name) {} - FunctionTree(const MultiResolutionAnalysis &mra, SharedMemory *sh_mem = nullptr, const std::string &name = "nn"); - FunctionTree(const FunctionTree &tree) = delete; - FunctionTree &operator=(const FunctionTree &tree) = delete; + FunctionTree(const MultiResolutionAnalysis &mra, SharedMemory *sh_mem = nullptr, const std::string &name = "nn"); + FunctionTree(const FunctionTree &tree) = delete; + FunctionTree &operator=(const FunctionTree &tree) = delete; ~FunctionTree() override; - double integrate() const; + T integrate() const; double integrateEndNodes(RepresentableFunction_M &f); - double evalf_precise(const Coord &r); - double evalf(const Coord &r) const override; + T evalf_precise(const Coord &r); + T evalf(const Coord &r) const override; int getNGenNodes() const { return getGenNodeAllocator().getNNodes(); } - void getEndValues(Eigen::VectorXd &data); - void setEndValues(Eigen::VectorXd &data); + void getEndValues(Eigen::Matrix &data); + void setEndValues(Eigen::Matrix &data); void saveTree(const std::string &file); void loadTree(const std::string &file); @@ -77,44 +77,44 @@ template class FunctionTree final : public MWTree, public Representab // In place operations void square(); void power(double p); - void rescale(double c); + void rescale(T c); void normalize(); - void add(double c, FunctionTree &inp); - void absadd(double c, FunctionTree &inp); - void multiply(double c, FunctionTree &inp); - void map(FMap fmap); + void add(T c, FunctionTree &inp); + void absadd(T c, FunctionTree &inp); + void multiply(T c, FunctionTree &inp); + void map(FMap fmap); int getNChunks() { return this->getNodeAllocator().getNChunks(); } int getNChunksUsed() { return this->getNodeAllocator().getNChunksUsed(); } int crop(double prec, double splitFac = 1.0, bool absPrec = true); - FunctionNode &getEndFuncNode(int i) { return static_cast &>(this->getEndMWNode(i)); } - FunctionNode &getRootFuncNode(int i) { return static_cast &>(this->rootBox.getNode(i)); } + FunctionNode &getEndFuncNode(int i) { return static_cast &>(this->getEndMWNode(i)); } + FunctionNode &getRootFuncNode(int i) { return static_cast &>(this->rootBox.getNode(i)); } - NodeAllocator &getGenNodeAllocator() { return *this->genNodeAllocator_p; } - const NodeAllocator &getGenNodeAllocator() const { return *this->genNodeAllocator_p; } + NodeAllocator &getGenNodeAllocator() { return *this->genNodeAllocator_p; } + const NodeAllocator &getGenNodeAllocator() const { return *this->genNodeAllocator_p; } - const FunctionNode &getEndFuncNode(int i) const { return static_cast &>(this->getEndMWNode(i)); } - const FunctionNode &getRootFuncNode(int i) const { return static_cast &>(this->rootBox.getNode(i)); } + const FunctionNode &getEndFuncNode(int i) const { return static_cast &>(this->getEndMWNode(i)); } + const FunctionNode &getRootFuncNode(int i) const { return static_cast &>(this->rootBox.getNode(i)); } void deleteGenerated(); void deleteGeneratedParents(); - void makeCoeffVector(std::vector &coefs, + void makeCoeffVector(std::vector &coefs, std::vector &indices, std::vector &parent_indices, std::vector &scalefac, int &max_index, - MWTree &refTree, - std::vector *> *refNodes = nullptr); - void makeTreefromCoeff(MWTree &refTree, std::vector coefpVec, std::map &ix2coef, double absPrec, const std::string &mode = "adaptive"); - void appendTreeNoCoeff(MWTree &inTree); + MWTree &refTree, + std::vector *> *refNodes = nullptr); + void makeTreefromCoeff(MWTree &refTree, std::vector coefpVec, std::map &ix2coef, double absPrec, const std::string &mode = "adaptive"); + void appendTreeNoCoeff(MWTree &inTree); // tools for use of local (nodes are stored in Bank) representation int saveNodesAndRmCoeff(); // put all nodes coefficients in Bank and delete all coefficients protected: - std::unique_ptr> genNodeAllocator_p{nullptr}; + std::unique_ptr> genNodeAllocator_p{nullptr}; std::ostream &print(std::ostream &o) const override; void allocRootNodes(); diff --git a/src/trees/FunctionTreeVector.h b/src/trees/FunctionTreeVector.h index d73005cd8..a9ed84d91 100644 --- a/src/trees/FunctionTreeVector.h +++ b/src/trees/FunctionTreeVector.h @@ -32,14 +32,14 @@ namespace mrcpp { -template using CoefsFunctionTree = std::tuple *>; -template using FunctionTreeVector = std::vector>; +template using CoefsFunctionTree = std::tuple *>; +template using FunctionTreeVector = std::vector>; /** @brief Remove all entries in the vector * @param[in] fs: Vector to clear * @param[in] dealloc: Option to free FunctionTree pointer before clearing */ -template void clear(FunctionTreeVector &fs, bool dealloc = false) { + template void clear(FunctionTreeVector &fs, bool dealloc = false) { if (dealloc) { for (auto &t : fs) { auto f = std::get<1>(t); @@ -52,7 +52,7 @@ template void clear(FunctionTreeVector &fs, bool dealloc = false) { /** @returns Total number of nodes of all trees in the vector * @param[in] fs: Vector to fetch from */ -template int get_n_nodes(const FunctionTreeVector &fs) { +template int get_n_nodes(const FunctionTreeVector &fs) { int nNodes = 0; for (const auto &t : fs) { auto f = std::get<1>(t); @@ -64,7 +64,7 @@ template int get_n_nodes(const FunctionTreeVector &fs) { /** @returns Total size of all trees in the vector, in kB * @param[in] fs: Vector to fetch from */ -template int get_size_nodes(const FunctionTreeVector &fs) { +template int get_size_nodes(const FunctionTreeVector &fs) { int sNodes = 0; for (const auto &t : fs) { auto f = std::get<1>(t); @@ -77,7 +77,7 @@ template int get_size_nodes(const FunctionTreeVector &fs) { * @param[in] fs: Vector to fetch from * @param[in] i: Position in vector */ -template double get_coef(const FunctionTreeVector &fs, int i) { +template T get_coef(const FunctionTreeVector &fs, int i) { return std::get<0>(fs[i]); } @@ -85,7 +85,7 @@ template double get_coef(const FunctionTreeVector &fs, int i) { * @param[in] fs: Vector to fetch from * @param[in] i: Position in vector */ -template FunctionTree &get_func(FunctionTreeVector &fs, int i) { +template FunctionTree &get_func(FunctionTreeVector &fs, int i) { return *(std::get<1>(fs[i])); } @@ -93,7 +93,7 @@ template FunctionTree &get_func(FunctionTreeVector &fs, int i) { * @param[in] fs: Vector to fetch from * @param[in] i: Position in vector */ -template const FunctionTree &get_func(const FunctionTreeVector &fs, int i) { +template const FunctionTree &get_func(const FunctionTreeVector &fs, int i) { return *(std::get<1>(fs[i])); } } // namespace mrcpp diff --git a/src/trees/MWNode.cpp b/src/trees/MWNode.cpp index d15c0939f..7bfa510de 100644 --- a/src/trees/MWNode.cpp +++ b/src/trees/MWNode.cpp @@ -45,8 +45,8 @@ namespace mrcpp { * * @details Should be used only by NodeAllocator to obtain * virtual table pointers for the derived classes. */ -template -MWNode::MWNode() + template + MWNode::MWNode() : tree(nullptr) , parent(nullptr) , nodeIndex() @@ -66,8 +66,8 @@ MWNode::MWNode() * * @details Constructor for an empty node, given the corresponding MWTree and NodeIndex */ -template -MWNode::MWNode(MWTree *tree, const NodeIndex &idx) +template +MWNode::MWNode(MWTree *tree, const NodeIndex &idx) : tree(tree) , parent(nullptr) , nodeIndex(idx) @@ -87,8 +87,8 @@ MWNode::MWNode(MWTree *tree, const NodeIndex &idx) * @details Constructor for root nodes. It requires the corresponding * MWTree and an integer to fetch the right NodeIndex */ -template -MWNode::MWNode(MWTree *tree, int rIdx) +template +MWNode::MWNode(MWTree *tree, int rIdx) : tree(tree) , parent(nullptr) , nodeIndex(tree->getRootBox().getNodeIndex(rIdx)) @@ -108,8 +108,8 @@ MWNode::MWNode(MWTree *tree, int rIdx) * @details Constructor for leaf nodes. It requires the corresponding * parent and an integer to identify the correct child. */ -template -MWNode::MWNode(MWNode *parent, int cIdx) +template +MWNode::MWNode(MWNode *parent, int cIdx) : tree(parent->tree) , parent(parent) , nodeIndex(parent->getNodeIndex().child(cIdx)) @@ -130,8 +130,8 @@ MWNode::MWNode(MWNode *parent, int cIdx) * does not "belong" to the tree: it cannot be accessed by traversing * the tree. */ -template -MWNode::MWNode(const MWNode &node, bool allocCoef, bool SetCoef) +template +MWNode::MWNode(const MWNode &node, bool allocCoef, bool SetCoef) : tree(node.tree) , parent(nullptr) , nodeIndex(node.nodeIndex) @@ -163,7 +163,7 @@ MWNode::MWNode(const MWNode &node, bool allocCoef, bool SetCoef) * * @details Recursive deallocation of a node and all its decendants */ -template MWNode::~MWNode() { + template MWNode::~MWNode() { if (this->isLooseNode()) this->freeCoefs(); MRCPP_DESTROY_OMP_LOCK(); } @@ -174,7 +174,7 @@ template MWNode::~MWNode() { * called (derived classes must implement their own version). This was * to avoid having pure virtual methods in the base class. */ -template void MWNode::dealloc() { + template void MWNode::dealloc() { NOT_REACHED_ABORT; } @@ -184,13 +184,13 @@ template void MWNode::dealloc() { * are not treated by the NodeAllocator class. * */ -template void MWNode::allocCoefs(int n_blocks, int block_size) { + template void MWNode::allocCoefs(int n_blocks, int block_size) { if (this->n_coefs != 0) MSG_ABORT("n_coefs should be zero"); if (this->isAllocated()) MSG_ABORT("Coefs already allocated"); if (not this->isLooseNode()) MSG_ABORT("Only loose nodes here!"); this->n_coefs = n_blocks * block_size; - this->coefs = new double[this->n_coefs]; + this->coefs = new T[this->n_coefs]; this->clearHasCoefs(); this->setIsAllocated(); @@ -202,7 +202,7 @@ template void MWNode::allocCoefs(int n_blocks, int block_size) { * are not treated by the NodeAllocator class. * */ -template void MWNode::freeCoefs() { + template void MWNode::freeCoefs() { if (not this->isLooseNode()) MSG_ABORT("Only loose nodes here!"); if (this->coefs != nullptr) delete[] this->coefs; @@ -216,7 +216,7 @@ template void MWNode::freeCoefs() { /** @brief Printout of node coefficients */ -template void MWNode::printCoefs() const { + template void MWNode::printCoefs() const { if (not this->isAllocated()) MSG_ABORT("Node is not allocated"); println(0, "\nMW coefs"); int kp1_d = this->getKp1_d(); @@ -228,18 +228,18 @@ template void MWNode::printCoefs() const { /** @brief wraps the MW coefficients into an eigen vector object */ -template void MWNode::getCoefs(Eigen::VectorXd &c) const { + template void MWNode::getCoefs(Eigen::Matrix &c) const { if (not this->isAllocated()) MSG_ABORT("Node is not allocated"); if (not this->hasCoefs()) MSG_ABORT("Node has no coefs"); if (this->n_coefs == 0) MSG_ABORT("ncoefs == 0"); - c = VectorXd::Map(this->coefs, this->n_coefs); + c = Eigen::Matrix::Map(this->coefs, this->n_coefs); } /** @brief sets all MW coefficients and the norms to zero * */ -template void MWNode::zeroCoefs() { + template void MWNode::zeroCoefs() { if (not this->isAllocated()) MSG_ABORT("Coefs not allocated " << *this); for (int i = 0; i < this->n_coefs; i++) { this->coefs[i] = 0.0; } @@ -249,7 +249,7 @@ template void MWNode::zeroCoefs() { /** @brief Attach a set of coefs to this node. Only used locally (the tree is not aware of this). */ -template void MWNode::attachCoefs(double *coefs) { + template void MWNode::attachCoefs(T *coefs) { this->coefs = coefs; this->setHasCoefs(); } @@ -264,7 +264,7 @@ template void MWNode::attachCoefs(double *coefs) { * (given scaling/wavelet in each direction). Its size is then \f$ * (k+1)^D \f$ and the index is between 0 and \f$ 2^D-1 \f$. */ -template void MWNode::setCoefBlock(int block, int block_size, const double *c) { + template void MWNode::setCoefBlock(int block, int block_size, const T *c) { if (not this->isAllocated()) MSG_ABORT("Coefs not allocated"); for (int i = 0; i < block_size; i++) { this->coefs[block * block_size + i] = c[i]; } } @@ -279,7 +279,7 @@ template void MWNode::setCoefBlock(int block, int block_size, const d * (given scaling/wavelet in each direction). Its size is then \f$ * (k+1)^D \f$ and the index is between 0 and \f$ 2^D-1 \f$. */ -template void MWNode::addCoefBlock(int block, int block_size, const double *c) { + template void MWNode::addCoefBlock(int block, int block_size, const T *c) { if (not this->isAllocated()) MSG_ABORT("Coefs not allocated"); for (int i = 0; i < block_size; i++) { this->coefs[block * block_size + i] += c[i]; } } @@ -293,7 +293,7 @@ template void MWNode::addCoefBlock(int block, int block_size, const d * (given scaling/wavelet in each direction). Its size is then \f$ * (k+1)^D \f$ and the index is between 0 and \f$ 2^D-1 \f$. */ -template void MWNode::zeroCoefBlock(int block, int block_size) { + template void MWNode::zeroCoefBlock(int block, int block_size) { if (not this->isAllocated()) MSG_ABORT("Coefs not allocated"); for (int i = 0; i < block_size; i++) { this->coefs[block * block_size + i] = 0.0; } } @@ -309,7 +309,7 @@ template void MWNode::zeroCoefBlock(int block, int block_size) { * already be present and its memory allocated for this to work * properly. */ -template void MWNode::giveChildrenCoefs(bool overwrite) { + template void MWNode::giveChildrenCoefs(bool overwrite) { assert(this->isBranchNode()); if (not this->isAllocated()) MSG_ABORT("Not allocated!"); if (not this->hasCoefs()) MSG_ABORT("No coefficients!"); @@ -320,8 +320,8 @@ template void MWNode::giveChildrenCoefs(bool overwrite) { // coeff of child should be have been allocated already here int stride = getMWChild(0).getNCoefs(); - double *inp = getCoefs(); - double *out = getMWChild(0).getCoefs(); + T *inp = getCoefs(); + T *out = getMWChild(0).getCoefs(); bool readOnlyScaling = false; if (this->isGenNode()) readOnlyScaling = true; @@ -345,9 +345,9 @@ template void MWNode::giveChildrenCoefs(bool overwrite) { * node. The scaling coefficients of the selected child are then * copied/summed in the correct child node. */ -template void MWNode::giveChildCoefs(int cIdx, bool overwrite) { + template void MWNode::giveChildCoefs(int cIdx, bool overwrite) { - MWNode node_i = *this; + MWNode node_i = *this; node_i.mwTransform(Reconstruction); @@ -355,7 +355,7 @@ template void MWNode::giveChildCoefs(int cIdx, bool overwrite) { int nChildren = this->getTDim(); if (this->children[cIdx] == nullptr) MSG_ABORT("Child does not exist!"); - MWNode &child = getMWChild(cIdx); + MWNode &child = getMWChild(cIdx); if (overwrite) { child.setCoefBlock(0, kp1_d, &node_i.getCoefs()[cIdx * kp1_d]); } else { @@ -371,12 +371,12 @@ template void MWNode::giveChildCoefs(int cIdx, bool overwrite) { * * \warning This routine is only used in connection with Periodic Boundary Conditions */ -template void MWNode::giveParentCoefs(bool overwrite) { - MWNode node = *this; - MWNode &parent = getMWParent(); + template void MWNode::giveParentCoefs(bool overwrite) { + MWNode node = *this; + MWNode &parent = getMWParent(); int kp1_d = this->getKp1_d(); if (node.getScale() == 0) { - NodeBox &box = this->getMWTree().getRootBox(); + NodeBox &box = this->getMWTree().getRootBox(); auto reverse = getTDim() - 1; for (auto i = 0; i < getTDim(); i++) { parent.setCoefBlock(i, kp1_d, &box.getNode(reverse - i).getCoefs()[0]); } } else { @@ -393,11 +393,11 @@ template void MWNode::giveParentCoefs(bool overwrite) { * them consecutively in the corresponding block of the parent, * following the usual bitwise notation. */ -template void MWNode::copyCoefsFromChildren() { + template void MWNode::copyCoefsFromChildren() { int kp1_d = this->getKp1_d(); int nChildren = this->getTDim(); for (int cIdx = 0; cIdx < nChildren; cIdx++) { - MWNode &child = getMWChild(cIdx); + MWNode &child = getMWChild(cIdx); if (not child.hasCoefs()) MSG_ABORT("Child has no coefs"); setCoefBlock(cIdx, kp1_d, child.getCoefs()); } @@ -411,7 +411,7 @@ template void MWNode::copyCoefsFromChildren() { * them consecutively in the corresponding block of the parent, * following the usual bitwise notation. */ -template void MWNode::threadSafeGenChildren() { + template void MWNode::threadSafeGenChildren() { if (tree->isLocal) { NOT_IMPLEMENTED_ABORT; } MRCPP_SET_OMP_LOCK(); if (isLeafNode()) { @@ -431,7 +431,7 @@ template void MWNode::threadSafeGenChildren() { * NOTE: this routine assumes a 0/1 (scaling on child 0 and 1) * representation, instead of s/d (scaling and wavelet). */ -template void MWNode::cvTransform(int operation) { + template void MWNode::cvTransform(int operation) { int kp1 = this->getKp1(); int kp1_dm1 = math_utils::ipow(kp1, D - 1); int kp1_d = this->getKp1_d(); @@ -439,17 +439,17 @@ template void MWNode::cvTransform(int operation) { auto sb = this->getMWTree().getMRA().getScalingBasis(); const MatrixXd &S = sb.getCVMap(operation); - double o_vec[nCoefs]; - double *out_vec = o_vec; - double *in_vec = this->coefs; + T o_vec[nCoefs]; + T *out_vec = o_vec; + T *in_vec = this->coefs; for (int i = 0; i < D; i++) { for (int t = 0; t < this->getTDim(); t++) { - double *out = out_vec + t * kp1_d; - double *in = in_vec + t * kp1_d; + T *out = out_vec + t * kp1_d; + T *in = in_vec + t * kp1_d; math_utils::apply_filter(out, in, S, kp1, kp1_dm1, 0.0); } - double *tmp = in_vec; + T *tmp = in_vec; in_vec = out_vec; out_vec = tmp; } @@ -473,8 +473,8 @@ template void MWNode::cvTransform(int operation) { } } /* Old interpolating version, somewhat faster -template -void MWNode::cvTransform(int operation) { +template +void MWNode::cvTransform(int operation) { const ScalingBasis &sf = this->getMWTree().getMRA().getScalingBasis(); if (sf.getScalingType() != Interpol) { NOT_IMPLEMENTED_ABORT; @@ -538,7 +538,7 @@ void MWNode::cvTransform(int operation) { * * * @param[in] operation: compression (s0,s1->s,d) or reconstruction (s,d->s0,s1). */ -template void MWNode::mwTransform(int operation) { + template void MWNode::mwTransform(int operation) { int kp1 = this->getKp1(); int kp1_dm1 = math_utils::ipow(kp1, D - 1); int kp1_d = this->getKp1_d(); @@ -546,20 +546,20 @@ template void MWNode::mwTransform(int operation) { const MWFilter &filter = getMWTree().getMRA().getFilter(); double overwrite = 0.0; - double o_vec[nCoefs]; - double *out_vec = o_vec; - double *in_vec = this->coefs; + T o_vec[nCoefs]; + T *out_vec = o_vec; + T *in_vec = this->coefs; for (int i = 0; i < D; i++) { int mask = 1 << i; for (int gt = 0; gt < this->getTDim(); gt++) { - double *out = out_vec + gt * kp1_d; + T *out = out_vec + gt * kp1_d; for (int ft = 0; ft < this->getTDim(); ft++) { /* Operate in direction i only if the bits along other * directions are identical. The bit of the direction we * operate on determines the appropriate filter/operator */ if ((gt | mask) == (ft | mask)) { - double *in = in_vec + ft * kp1_d; + T *in = in_vec + ft * kp1_d; int fIdx = 2 * ((gt >> i) & 1) + ((ft >> i) & 1); const MatrixXd &oper = filter.getSubFilter(fIdx, operation); math_utils::apply_filter(out, in, oper, kp1, kp1_dm1, overwrite); @@ -568,7 +568,7 @@ template void MWNode::mwTransform(int operation) { } overwrite = 0.0; } - double *tmp = in_vec; + T *tmp = in_vec; in_vec = out_vec; out_vec = tmp; } @@ -578,19 +578,19 @@ template void MWNode::mwTransform(int operation) { } /** @brief Set all norms to Undefined. */ -template void MWNode::clearNorms() { + template void MWNode::clearNorms() { this->squareNorm = -1.0; for (int i = 0; i < this->getTDim(); i++) { this->componentNorms[i] = -1.0; } } /** @brief Set all norms to zero. */ -template void MWNode::zeroNorms() { + template void MWNode::zeroNorms() { this->squareNorm = 0.0; for (int i = 0; i < this->getTDim(); i++) { this->componentNorms[i] = 0.0; } } /** @brief Calculate and store square norm and component norms, if allocated. */ -template void MWNode::calcNorms() { + template void MWNode::calcNorms() { this->squareNorm = 0.0; for (int i = 0; i < this->getTDim(); i++) { double norm_i = calcComponentNorm(i); @@ -600,7 +600,7 @@ template void MWNode::calcNorms() { } /** @brief Calculate and return the squared scaling norm. */ -template double MWNode::getScalingNorm() const { + template double MWNode::getScalingNorm() const { double sNorm = this->getComponentNorm(0); if (sNorm >= 0.0) { return sNorm * sNorm; @@ -610,7 +610,7 @@ template double MWNode::getScalingNorm() const { } /** @brief Calculate and return the squared wavelet norm. */ -template double MWNode::getWaveletNorm() const { + template double MWNode::getWaveletNorm() const { double wNorm = 0.0; for (int i = 1; i < this->getTDim(); i++) { double norm_i = this->getComponentNorm(i); @@ -624,28 +624,28 @@ template double MWNode::getWaveletNorm() const { } /** @brief Calculate the norm of one component (NOT the squared norm!). */ -template double MWNode::calcComponentNorm(int i) const { + template double MWNode::calcComponentNorm(int i) const { if (this->isGenNode() and i != 0) return 0.0; assert(this->isAllocated()); assert(this->hasCoefs()); - const double *c = this->getCoefs(); + const T *c = this->getCoefs(); int size = this->getKp1_d(); int start = i * size; double sq_norm = 0.0; -#ifdef HAVE_BLAS - sq_norm = cblas_ddot(size, &c[start], 1, &c[start], 1); -#else - for (int i = start; i < start + size; i++) { sq_norm += c[i] * c[i]; } -#endif +//#ifdef HAVE_BLAS +// sq_norm = cblas_ddot(size, &c[start], 1, &c[start], 1); +//#else + for (int i = start; i < start + size; i++) { sq_norm += std::norm(c[i]); } +//#endif return std::sqrt(sq_norm); } /** @brief Update the coefficients of the node by a mw transform of the scaling * coefficients of the children. */ -template void MWNode::reCompress() { + template void MWNode::reCompress() { if (this->isGenNode()) NOT_IMPLEMENTED_ABORT; if (this->isBranchNode()) { if (not this->isAllocated()) MSG_ABORT("Coefs not allocated"); @@ -662,12 +662,12 @@ template void MWNode::reCompress() { * @param[in] splitFac: factor used in the split check (larger factor means tighter threshold for finer nodes) * @param[in] absPrec: flag to switch from relative (false) to absolute (true) precision. */ -template bool MWNode::crop(double prec, double splitFac, bool absPrec) { + template bool MWNode::crop(double prec, double splitFac, bool absPrec) { if (this->isEndNode()) { return true; } else { for (int i = 0; i < this->getTDim(); i++) { - MWNode &child = *this->children[i]; + MWNode &child = *this->children[i]; if (child.crop(prec, splitFac, absPrec)) { if (tree_utils::split_check(*this, prec, splitFac, absPrec) == false) { this->deleteChildren(); @@ -679,15 +679,15 @@ template bool MWNode::crop(double prec, double splitFac, bool absPrec return false; } -template void MWNode::createChildren(bool coefs) { + template void MWNode::createChildren(bool coefs) { NOT_REACHED_ABORT; } -template void MWNode::genChildren() { + template void MWNode::genChildren() { NOT_REACHED_ABORT; } -template void MWNode::genParent() { + template void MWNode::genParent() { NOT_REACHED_ABORT; } @@ -696,11 +696,11 @@ template void MWNode::genParent() { * @details * Leaves node as LeafNode and children[] as null pointer. */ -template void MWNode::deleteChildren() { + template void MWNode::deleteChildren() { if (this->isLeafNode()) return; for (int cIdx = 0; cIdx < getTDim(); cIdx++) { if (this->children[cIdx] != nullptr) { - MWNode &child = getMWChild(cIdx); + MWNode &child = getMWChild(cIdx); child.deleteChildren(); child.dealloc(); } @@ -711,9 +711,9 @@ template void MWNode::deleteChildren() { } /** @brief Recursive deallocation of parent and all their forefathers. */ -template void MWNode::deleteParent() { + template void MWNode::deleteParent() { if (this->parent == nullptr) return; - MWNode &parent = getMWParent(); + MWNode &parent = getMWParent(); parent.deleteParent(); parent.dealloc(); this->parentSerialIx = -1; @@ -722,7 +722,7 @@ template void MWNode::deleteParent() { /** @brief Deallocation of all generated nodes . */ -template void MWNode::deleteGenerated() { + template void MWNode::deleteGenerated() { if (this->isBranchNode()) { if (this->isEndNode()) { this->deleteChildren(); @@ -733,7 +733,7 @@ template void MWNode::deleteGenerated() { } /** @brief returns the coordinates of the centre of the node */ -template Coord MWNode::getCenter() const { + template Coord MWNode::getCenter() const { auto two_n = std::pow(2.0, -getScale()); auto scaling_factor = getMWTree().getMRA().getWorldBox().getScalingFactors(); auto &l = getNodeIndex(); @@ -743,7 +743,7 @@ template Coord MWNode::getCenter() const { } /** @brief returns the upper bounds of the D-interval defining the node */ -template Coord MWNode::getUpperBounds() const { + template Coord MWNode::getUpperBounds() const { auto two_n = std::pow(2.0, -getScale()); auto scaling_factor = getMWTree().getMRA().getWorldBox().getScalingFactors(); auto &l = getNodeIndex(); @@ -753,7 +753,7 @@ template Coord MWNode::getUpperBounds() const { } /** @brief returns the lower bounds of the D-interval defining the node */ -template Coord MWNode::getLowerBounds() const { + template Coord MWNode::getLowerBounds() const { auto two_n = std::pow(2.0, -getScale()); auto scaling_factor = getMWTree().getMRA().getWorldBox().getScalingFactors(); auto &l = getNodeIndex(); @@ -770,7 +770,7 @@ template Coord MWNode::getLowerBounds() const { * to be followed at the current scale in oder to get to the requested * node at the final scale. The result is the index of the child needed. * The index is obtained by bit manipulation of of the translation indices. */ -template int MWNode::getChildIndex(const NodeIndex &nIdx) const { + template int MWNode::getChildIndex(const NodeIndex &nIdx) const { assert(isAncestor(nIdx)); int cIdx = 0; int diffScale = nIdx.getScale() - getScale() - 1; @@ -790,7 +790,7 @@ template int MWNode::getChildIndex(const NodeIndex &nIdx) const { * * @detailsGiven a point in space, determines which child should be followed * to get to the corresponding terminal node. */ -template int MWNode::getChildIndex(const Coord &r) const { + template int MWNode::getChildIndex(const Coord &r) const { assert(hasCoord(r)); int cIdx = 0; double sFac = std::pow(2.0, -getScale()); @@ -815,7 +815,7 @@ template int MWNode::getChildIndex(const Coord &r) const { * grid of quadrature points. * */ -template void MWNode::getPrimitiveQuadPts(MatrixXd &pts) const { + template void MWNode::getPrimitiveQuadPts(MatrixXd &pts) const { int kp1 = this->getKp1(); pts = MatrixXd::Zero(D, kp1); @@ -840,7 +840,7 @@ template void MWNode::getPrimitiveQuadPts(MatrixXd &pts) const { * nodes. * */ -template void MWNode::getPrimitiveChildPts(MatrixXd &pts) const { + template void MWNode::getPrimitiveChildPts(MatrixXd &pts) const { int kp1 = this->getKp1(); pts = MatrixXd::Zero(D, 2 * kp1); @@ -865,7 +865,7 @@ template void MWNode::getPrimitiveChildPts(MatrixXd &pts) const { * vectors of quadrature points. * */ -template void MWNode::getExpandedQuadPts(Eigen::MatrixXd &pts) const { + template void MWNode::getExpandedQuadPts(Eigen::MatrixXd &pts) const { MatrixXd prim_pts; getPrimitiveQuadPts(prim_pts); @@ -889,7 +889,7 @@ template void MWNode::getExpandedQuadPts(Eigen::MatrixXd &pts) const * vectors of quadrature points. * */ -template void MWNode::getExpandedChildPts(MatrixXd &pts) const { + template void MWNode::getExpandedChildPts(MatrixXd &pts) const { MatrixXd prim_pts; getPrimitiveChildPts(prim_pts); @@ -923,7 +923,7 @@ template void MWNode::getExpandedChildPts(MatrixXd &pts) const { * the node does not exist, or if it is a GenNode. Recursion starts at at this * node and ASSUMES the requested node is in fact decending from this node. */ -template const MWNode *MWNode::retrieveNodeNoGen(const NodeIndex &idx) const { + template const MWNode *MWNode::retrieveNodeNoGen(const NodeIndex &idx) const { if (getScale() == idx.getScale()) { // we're done assert(getNodeIndex() == idx); return this; @@ -947,7 +947,7 @@ template const MWNode *MWNode::retrieveNodeNoGen(const NodeIndex MWNode *MWNode::retrieveNodeNoGen(const NodeIndex &idx) { + template MWNode *MWNode::retrieveNodeNoGen(const NodeIndex &idx) { if (getScale() == idx.getScale()) { // we're done assert(getNodeIndex() == idx); return this; @@ -973,7 +973,7 @@ template MWNode *MWNode::retrieveNodeNoGen(const NodeIndex &idx * this node and ASSUMES the requested node is in fact decending from * this node. */ -template const MWNode *MWNode::retrieveNodeOrEndNode(const Coord &r, int depth) const { + template const MWNode *MWNode::retrieveNodeOrEndNode(const Coord &r, int depth) const { if (getDepth() == depth or this->isEndNode()) { return this; } int cIdx = getChildIndex(r); assert(this->children[cIdx] != nullptr); @@ -992,7 +992,7 @@ template const MWNode *MWNode::retrieveNodeOrEndNode(const Coord MWNode *MWNode::retrieveNodeOrEndNode(const Coord &r, int depth) { + template MWNode *MWNode::retrieveNodeOrEndNode(const Coord &r, int depth) { if (getDepth() == depth or this->isEndNode()) { return this; } int cIdx = getChildIndex(r); assert(this->children[cIdx] != nullptr); @@ -1010,7 +1010,7 @@ template MWNode *MWNode::retrieveNodeOrEndNode(const Coord &r, * this node and ASSUMES the requested node is in fact decending from * this node. */ -template const MWNode *MWNode::retrieveNodeOrEndNode(const NodeIndex &idx) const { + template const MWNode *MWNode::retrieveNodeOrEndNode(const NodeIndex &idx) const { if (getScale() == idx.getScale()) { // we're done assert(getNodeIndex() == idx); return this; @@ -1036,7 +1036,7 @@ template const MWNode *MWNode::retrieveNodeOrEndNode(const NodeInd * this node and ASSUMES the requested node is in fact decending from * this node. */ -template MWNode *MWNode::retrieveNodeOrEndNode(const NodeIndex &idx) { + template MWNode *MWNode::retrieveNodeOrEndNode(const NodeIndex &idx) { if (getScale() == idx.getScale()) { // we're done assert(getNodeIndex() == idx); return this; @@ -1061,7 +1061,7 @@ template MWNode *MWNode::retrieveNodeOrEndNode(const NodeIndex * that does not exist. Recursion starts at this node and ASSUMES the * requested node is in fact decending from this node. */ -template MWNode *MWNode::retrieveNode(const Coord &r, int depth) { + template MWNode *MWNode::retrieveNode(const Coord &r, int depth) { if (depth < 0) MSG_ABORT("Invalid argument"); if (getDepth() == depth) { return this; } @@ -1082,13 +1082,15 @@ template MWNode *MWNode::retrieveNode(const Coord &r, int depth * does not exist. Recursion starts at this node and ASSUMES the requested * node is in fact descending from this node. */ -template MWNode *MWNode::retrieveNode(const NodeIndex &idx) { + template MWNode *MWNode::retrieveNode(const NodeIndex &idx) { if (getScale() == idx.getScale()) { // we're done if (tree->isLocal) { + NOT_IMPLEMENTED_ABORT; // has to fetch coeff in Bank. NOT USED YET - int ncoefs = (1 << D) * this->getKp1_d(); - coefs = new double[ncoefs]; // TODO must be cleaned at some stage - tree->getNodeCoeff(idx, coefs); + //int ncoefs = (1 << D) * this->getKp1_d(); + //coefs = new double[ncoefs]; // TODO must be cleaned at some stage + //coefs = new double[ncoefs]; // TODO must be cleaned at some stage + //tree->getNodeCoeff(idx, coefs); } assert(getNodeIndex() == idx); return this; @@ -1113,7 +1115,7 @@ template MWNode *MWNode::retrieveNode(const NodeIndex &idx) { * does not exist. Recursion starts at this node and ASSUMES the requested * node is in fact related to this node. */ -template MWNode *MWNode::retrieveParent(const NodeIndex &idx) { + template MWNode *MWNode::retrieveParent(const NodeIndex &idx) { if (getScale() < idx.getScale()) MSG_ABORT("Scale error") if (getScale() == idx.getScale()) return this; if (this->parent == nullptr) { @@ -1132,7 +1134,7 @@ template MWNode *MWNode::retrieveParent(const NodeIndex &idx) { * found, do not generate any new node, but rather give the value of the norm * assuming the function is uniformly distributed within the node. */ -template double MWNode::getNodeNorm(const NodeIndex &idx) const { + template double MWNode::getNodeNorm(const NodeIndex &idx) const { if (this->getScale() == idx.getScale()) { // we're done assert(getNodeIndex() == idx); return std::sqrt(this->squareNorm); @@ -1150,7 +1152,7 @@ template double MWNode::getNodeNorm(const NodeIndex &idx) const { * * @param[in] r: point coordinates */ -template bool MWNode::hasCoord(const Coord &r) const { + template bool MWNode::hasCoord(const Coord &r) const { double sFac = std::pow(2.0, -getScale()); const NodeIndex &l = getNodeIndex(); // println(1, "[" << r[0] << "," << r[1] << "," << r[2] << "]"); @@ -1168,7 +1170,7 @@ template bool MWNode::hasCoord(const Coord &r) const { /** Testing if nodes are compatible wrt NodeIndex and Tree (order, rootScale, * relPrec, etc). */ -template bool MWNode::isCompatible(const MWNode &node) { + template bool MWNode::isCompatible(const MWNode &node) { NOT_IMPLEMENTED_ABORT; // if (nodeIndex != node.nodeIndex) { // println(0, "nodeIndex mismatch" << std::endl); @@ -1186,7 +1188,7 @@ template bool MWNode::isCompatible(const MWNode &node) { * * @param[in] idx: the NodeIndex of the requested node */ -template bool MWNode::isAncestor(const NodeIndex &idx) const { + template bool MWNode::isAncestor(const NodeIndex &idx) const { int relScale = idx.getScale() - getScale(); if (relScale < 0) return false; const NodeIndex &l = getNodeIndex(); @@ -1197,7 +1199,7 @@ template bool MWNode::isAncestor(const NodeIndex &idx) const { return true; } -template bool MWNode::isDecendant(const NodeIndex &idx) const { + template bool MWNode::isDecendant(const NodeIndex &idx) const { NOT_IMPLEMENTED_ABORT; } @@ -1205,7 +1207,7 @@ template bool MWNode::isDecendant(const NodeIndex &idx) const { * * @param[in] o: the output stream */ -template std::ostream &MWNode::print(std::ostream &o) const { + template std::ostream &MWNode::print(std::ostream &o) const { std::string flags = " "; o << getNodeIndex(); if (isRootNode()) flags[0] = 'R'; @@ -1236,14 +1238,14 @@ template std::ostream &MWNode::print(std::ostream &o) const { * normalization is such that a constant function gives constant value, * i.e. *not* same normalization as a squareNorm */ -template void MWNode::setMaxSquareNorm() { + template void MWNode::setMaxSquareNorm() { auto n = this->getScale(); this->maxWSquareNorm = calcScaledWSquareNorm(); this->maxSquareNorm = calcScaledSquareNorm(); if (not this->isEndNode()) { for (int i = 0; i < this->getTDim(); i++) { - MWNode &child = *this->children[i]; + MWNode &child = *this->children[i]; child.setMaxSquareNorm(); this->maxSquareNorm = std::max(this->maxSquareNorm, child.maxSquareNorm); this->maxWSquareNorm = std::max(this->maxWSquareNorm, child.maxWSquareNorm); @@ -1252,20 +1254,23 @@ template void MWNode::setMaxSquareNorm() { } /** @brief recursively reset maxSquaredNorm and maxWSquareNorm of parent and descendants to value -1 */ -template void MWNode::resetMaxSquareNorm() { + template void MWNode::resetMaxSquareNorm() { auto n = this->getScale(); this->maxSquareNorm = -1.0; this->maxWSquareNorm = -1.0; if (not this->isEndNode()) { for (int i = 0; i < this->getTDim(); i++) { - MWNode &child = *this->children[i]; + MWNode &child = *this->children[i]; child.resetMaxSquareNorm(); } } } -template class MWNode<1>; -template class MWNode<2>; -template class MWNode<3>; + template class MWNode<1, double>; + template class MWNode<2, double>; + template class MWNode<3, double>; + template class MWNode<1, ComplexDouble>; + template class MWNode<2, ComplexDouble>; + template class MWNode<3, ComplexDouble>; } // namespace mrcpp diff --git a/src/trees/MWNode.h b/src/trees/MWNode.h index d7d7a18a7..ed25762b4 100644 --- a/src/trees/MWNode.h +++ b/src/trees/MWNode.h @@ -30,6 +30,7 @@ #include "MRCPP/macros.h" #include "utils/omp_utils.h" +#include "utils/math_utils.h" #include "HilbertPath.h" #include "MWTree.h" @@ -49,12 +50,12 @@ namespace mrcpp { * translation index, the norm, pointers to parent node and child * nodes, pointer to the corresponding MWTree etc... See member and * data descriptions for details. - * + * */ -template class MWNode { +template class MWNode { public: - MWNode(const MWNode &node, bool allocCoef = true, bool SetCoef = true); - MWNode &operator=(const MWNode &node) = delete; + MWNode(const MWNode &node, bool allocCoef = true, bool SetCoef = true); + MWNode &operator=(const MWNode &node) = delete; virtual ~MWNode(); int getKp1() const { return getMWTree().getKp1(); } @@ -76,7 +77,7 @@ template class MWNode { Coord getLowerBounds() const; bool hasCoord(const Coord &r) const; - bool isCompatible(const MWNode &node); + bool isCompatible(const MWNode &node); bool isAncestor(const NodeIndex &idx) const; bool isDecendant(const NodeIndex &idx) const; @@ -89,30 +90,30 @@ template class MWNode { double getComponentNorm(int i) const { return this->componentNorms[i]; } int getNCoefs() const { return this->n_coefs; } - void getCoefs(Eigen::VectorXd &c) const; + void getCoefs(Eigen::Matrix &c) const; void printCoefs() const; - double *getCoefs() { return this->coefs; } - const double *getCoefs() const { return this->coefs; } + T *getCoefs() { return this->coefs; } + const T *getCoefs() const { return this->coefs; } void getPrimitiveQuadPts(Eigen::MatrixXd &pts) const; void getPrimitiveChildPts(Eigen::MatrixXd &pts) const; void getExpandedQuadPts(Eigen::MatrixXd &pts) const; void getExpandedChildPts(Eigen::MatrixXd &pts) const; - MWTree &getMWTree() { return static_cast &>(*this->tree); } - MWNode &getMWParent() { return static_cast &>(*this->parent); } - MWNode &getMWChild(int i) { return static_cast &>(*this->children[i]); } + MWTree &getMWTree() { return static_cast &>(*this->tree); } + MWNode &getMWParent() { return static_cast &>(*this->parent); } + MWNode &getMWChild(int i) { return static_cast &>(*this->children[i]); } - const MWTree &getMWTree() const { return static_cast &>(*this->tree); } - const MWNode &getMWParent() const { return static_cast &>(*this->parent); } - const MWNode &getMWChild(int i) const { return static_cast &>(*this->children[i]); } + const MWTree &getMWTree() const { return static_cast &>(*this->tree); } + const MWNode &getMWParent() const { return static_cast &>(*this->parent); } + const MWNode &getMWChild(int i) const { return static_cast &>(*this->children[i]); } void zeroCoefs(); - void setCoefBlock(int block, int block_size, const double *c); - void addCoefBlock(int block, int block_size, const double *c); + void setCoefBlock(int block, int block_size, const T *c); + void addCoefBlock(int block, int block_size, const T *c); void zeroCoefBlock(int block, int block_size); - void attachCoefs(double *coefs); + void attachCoefs(T *coefs); void calcNorms(); void zeroNorms(); @@ -154,34 +155,35 @@ template class MWNode { void clearIsRootNode() { CLEAR_BITS(status, FlagRootNode); } void clearIsAllocated() { CLEAR_BITS(status, FlagAllocated); } - friend std::ostream &operator<<(std::ostream &o, const MWNode &nd) { return nd.print(o); } + friend std::ostream &operator<<(std::ostream &o, const MWNode &nd) { return nd.print(o); } - friend class TreeBuilder; - friend class MultiplicationCalculator; - friend class NodeAllocator; - friend class MWTree; - friend class FunctionTree; + friend class TreeBuilder; + friend class MultiplicationCalculator; + friend class NodeAllocator; + friend class MWTree; + friend class FunctionTree; friend class OperatorTree; - friend class FunctionNode; + friend class FunctionNode; friend class OperatorNode; - friend class DerivativeCalculator; + friend class DerivativeCalculator; + bool isComplex = false; //TODO put as one of the flags protected: - MWTree *tree{nullptr}; ///< Tree the node belongs to - MWNode *parent{nullptr}; ///< Parent node - MWNode *children[1 << D]; ///< 2^D children + MWTree *tree{nullptr}; ///< Tree the node belongs to + MWNode *parent{nullptr}; ///< Parent node + MWNode *children[1 << D]; ///< 2^D children double squareNorm{-1.0}; ///< Squared norm of all 2^D (k+1)^D coefficients double componentNorms[1 << D]; ///< Squared norms of the separeted 2^D components double maxSquareNorm{-1.0}; ///< Largest squared norm among itself and descendants. double maxWSquareNorm{-1.0}; ///< Largest wavelet squared norm among itself and descendants. ///< NB: must be set before used. - double *coefs{nullptr}; ///< the 2^D (k+1)^D MW coefficients - ///< For example, in case of a one dimensional function \f$ f \f$ - ///< this array equals \f$ s_0, \ldots, s_k, d_0, \ldots, d_k \f$, - ///< where scaling coefficients \f$ s_j = s_{jl}^n(f) \f$ - ///< and wavelet coefficients \f$ d_j = d_{jl}^n(f) \f$. - ///< Here \f$ n, l \f$ are unique for every node. + T *coefs{nullptr}; ///< the 2^D (k+1)^D MW coefficients + ///< For example, in case of a one dimensional function \f$ f \f$ + ///< this array equals \f$ s_0, \ldots, s_k, d_0, \ldots, d_k \f$, + ///< where scaling coefficients \f$ s_j = s_{jl}^n(f) \f$ + ///< and wavelet coefficients \f$ d_j = d_{jl}^n(f) \f$. + ///< Here \f$ n, l \f$ are unique for every node. int n_coefs{0}; int serialIx{-1}; ///< index in serial Tree @@ -192,9 +194,9 @@ template class MWNode { HilbertPath hilbertPath; ///< To be documented MWNode(); - MWNode(MWTree *tree, int rIdx); - MWNode(MWTree *tree, const NodeIndex &idx); - MWNode(MWNode *parent, int cIdx); + MWNode(MWTree *tree, int rIdx); + MWNode(MWTree *tree, const NodeIndex &idx); + MWNode(MWNode *parent, int cIdx); virtual void dealloc(); @@ -219,20 +221,20 @@ template class MWNode { int getChildIndex(const NodeIndex &nIdx) const; int getChildIndex(const Coord &r) const; - bool diffBranch(const MWNode &rhs) const; + bool diffBranch(const MWNode &rhs) const; - MWNode *retrieveNode(const Coord &r, int depth); - MWNode *retrieveNode(const NodeIndex &idx); - MWNode *retrieveParent(const NodeIndex &idx); + MWNode *retrieveNode(const Coord &r, int depth); + MWNode *retrieveNode(const NodeIndex &idx); + MWNode *retrieveParent(const NodeIndex &idx); - const MWNode *retrieveNodeNoGen(const NodeIndex &idx) const; - MWNode *retrieveNodeNoGen(const NodeIndex &idx); + const MWNode *retrieveNodeNoGen(const NodeIndex &idx) const; + MWNode *retrieveNodeNoGen(const NodeIndex &idx); - const MWNode *retrieveNodeOrEndNode(const Coord &r, int depth) const; - MWNode *retrieveNodeOrEndNode(const Coord &r, int depth); + const MWNode *retrieveNodeOrEndNode(const Coord &r, int depth) const; + MWNode *retrieveNodeOrEndNode(const Coord &r, int depth); - const MWNode *retrieveNodeOrEndNode(const NodeIndex &idx) const; - MWNode *retrieveNodeOrEndNode(const NodeIndex &idx); + const MWNode *retrieveNodeOrEndNode(const NodeIndex &idx) const; + MWNode *retrieveNodeOrEndNode(const NodeIndex &idx); void threadSafeGenChildren(); void deleteGenerated(); diff --git a/src/trees/MWTree.cpp b/src/trees/MWTree.cpp index 583fb1fc1..c849517da 100644 --- a/src/trees/MWTree.cpp +++ b/src/trees/MWTree.cpp @@ -49,8 +49,8 @@ namespace mrcpp { * root nodes. The information for the root node configuration to use * is in the mra object which is passed to the constructor. */ -template -MWTree::MWTree(const MultiResolutionAnalysis &mra, const std::string &n) + template +MWTree::MWTree(const MultiResolutionAnalysis &mra, const std::string &n) : MRA(mra) , order(mra.getOrder()) /// polynomial order , kp1_d(math_utils::ipow(mra.getOrder() + 1, D)) ///nr of scaling coefficients \f$ (k+1)^D \f$ @@ -61,7 +61,7 @@ MWTree::MWTree(const MultiResolutionAnalysis &mra, const std::string &n) } /** @brief MWTree destructor. */ -template MWTree::~MWTree() { +template MWTree::~MWTree() { this->endNodeTable.clear(); if (this->nodesAtDepth.size() != 1) MSG_ERROR("Nodes at depth != 1 -> " << this->nodesAtDepth.size()); if (this->nodesAtDepth[0] != 0) MSG_ERROR("Nodes at depth 0 != 0 -> " << this->nodesAtDepth[0]); @@ -73,9 +73,9 @@ template MWTree::~MWTree() { * including the root nodes. Derived classes will call this method * when the object is deleted. */ -template void MWTree::deleteRootNodes() { +template void MWTree::deleteRootNodes() { for (int i = 0; i < this->rootBox.size(); i++) { - MWNode &root = this->getRootMWNode(i); + MWNode &root = this->getRootMWNode(i); root.deleteChildren(); root.dealloc(); this->rootBox.clearNode(i); @@ -90,9 +90,9 @@ template void MWTree::deleteRootNodes() { * nodes, (nodeChunks in NodeAllocator) is NOT released, but is * immediately available to the new function. */ -template void MWTree::clear() { +template void MWTree::clear() { for (int i = 0; i < this->rootBox.size(); i++) { - MWNode &root = this->getRootMWNode(i); + MWNode &root = this->getRootMWNode(i); root.deleteChildren(); root.clearHasCoefs(); root.clearNorms(); @@ -106,10 +106,10 @@ template void MWTree::clear() { * @details The norm is calculated using endNodes only. The specific * type of norm which is computed will depend on the derived class */ -template void MWTree::calcSquareNorm() { +template void MWTree::calcSquareNorm() { double treeNorm = 0.0; for (int n = 0; n < this->getNEndNodes(); n++) { - const MWNode &node = getEndMWNode(n); + const MWNode &node = getEndMWNode(n); assert(node.hasCoefs()); treeNorm += node.getSquareNorm(); } @@ -126,9 +126,9 @@ template void MWTree::calcSquareNorm() { * @details It performs a Multiwavlet transform of the whole tree. The * input parameters will specify the direction (upwards or downwards) * and whether the result is added to the coefficients or it - * overwrites them. See the documentation for the #mwTransformUp + * overwrites them. See the documentation for the #mwTransformUp * and #mwTransformDown for details. - * \f[ + * \f[ * \pmatrix{ * s_{nl}\\ * d_{nl} @@ -139,7 +139,7 @@ template void MWTree::calcSquareNorm() { * } * \f] */ -template void MWTree::mwTransform(int type, bool overwrite) { +template void MWTree::mwTransform(int type, bool overwrite) { switch (type) { case TopDown: mwTransformDown(overwrite); @@ -162,8 +162,8 @@ template void MWTree::mwTransform(int type, bool overwrite) { * projection to purify the coefficients obtained by quadrature at * coarser scales which are therefore not precise enough. */ -template void MWTree::mwTransformUp() { - std::vector> nodeTable; +template void MWTree::mwTransformUp() { + std::vector> nodeTable; tree_utils::make_node_table(*this, nodeTable); #pragma omp parallel shared(nodeTable) num_threads(mrcpp_get_num_threads()) { @@ -172,7 +172,7 @@ template void MWTree::mwTransformUp() { int nNodes = nodeTable[n].size(); #pragma omp for schedule(guided) for (int i = 0; i < nNodes; i++) { - MWNode &node = *nodeTable[n][i]; + MWNode &node = *nodeTable[n][i]; if (node.isBranchNode()) { node.reCompress(); } } } @@ -190,8 +190,8 @@ template void MWTree::mwTransformUp() { * operation is generally used after the operator application. * */ -template void MWTree::mwTransformDown(bool overwrite) { - std::vector> nodeTable; +template void MWTree::mwTransformDown(bool overwrite) { + std::vector> nodeTable; tree_utils::make_node_table(*this, nodeTable); #pragma omp parallel shared(nodeTable) num_threads(mrcpp_get_num_threads()) { @@ -199,7 +199,7 @@ template void MWTree::mwTransformDown(bool overwrite) { int n_nodes = nodeTable[n].size(); #pragma omp for schedule(guided) for (int i = 0; i < n_nodes; i++) { - MWNode &node = *nodeTable[n][i]; + MWNode &node = *nodeTable[n][i]; if (node.isBranchNode()) { if (this->getRootScale() > node.getScale()) { int reverse = n_nodes - 1; @@ -215,15 +215,15 @@ template void MWTree::mwTransformDown(bool overwrite) { } /** @brief Set the MW coefficients to zero, keeping the same tree structure - * + * * @details Keeps the node structure of the tree, even though the zero * function is representable at depth zero. One should then use \ref cropTree to remove * unnecessary nodes. */ -template void MWTree::setZero() { - TreeIterator it(*this); +template void MWTree::setZero() { + TreeIterator it(*this); while (it.next()) { - MWNode &node = it.getNode(); + MWNode &node = it.getNode(); node.zeroCoefs(); } this->squareNorm = 0.0; @@ -236,7 +236,7 @@ template void MWTree::setZero() { * safe, and must NEVER be called outside a critical region in parallel. * It's way. way too expensive to lock the tree, so don't even think * about it. */ -template void MWTree::incrementNodeCount(int scale) { +template void MWTree::incrementNodeCount(int scale) { int depth = scale - getRootScale(); if (depth < 0) { int n = this->nodesAtNegativeDepth.size(); @@ -261,7 +261,7 @@ template void MWTree::incrementNodeCount(int scale) { * It's way. way too expensive to lock the tree, so don't even think * about it. */ -template void MWTree::decrementNodeCount(int scale) { +template void MWTree::decrementNodeCount(int scale) { int depth = scale - getRootScale(); if (depth < 0) { assert(-depth - 1 < this->nodesAtNegativeDepth.size()); @@ -280,7 +280,7 @@ template void MWTree::decrementNodeCount(int scale) { * * @param[in] depth: Tree depth (0 depth is the coarsest scale) to count. */ -template int MWTree::getNNodesAtDepth(int depth) const { +template int MWTree::getNNodesAtDepth(int depth) const { int N = 0; if (depth < 0) { if (this->nodesAtNegativeDepth.size() >= -depth) N = this->nodesAtNegativeDepth[-depth]; @@ -291,9 +291,9 @@ template int MWTree::getNNodesAtDepth(int depth) const { } /** @returns Size of all MW coefs in the tree, in kB */ -template int MWTree::getSizeNodes() const { +template int MWTree::getSizeNodes() const { auto nCoefs = 1ll * getNNodes() * getTDim() * getKp1_d(); - return sizeof(double) * nCoefs / 1024; + return sizeof(T) * nCoefs / 1024; } /** @brief Finds and returns the node pointer with the given \ref NodeIndex, const version. @@ -303,11 +303,11 @@ template int MWTree::getSizeNodes() const { * pointer if the node does not exist, or if it is a * GenNode. Recursion starts at the appropriate rootNode. */ -template const MWNode *MWTree::findNode(NodeIndex idx) const { +template const MWNode *MWTree::findNode(NodeIndex idx) const { if (getRootBox().isPeriodic()) { periodic::index_manipulation(idx, getRootBox().getPeriodic()); } int rIdx = getRootBox().getBoxIndex(idx); if (rIdx < 0) return nullptr; - const MWNode &root = this->rootBox.getNode(rIdx); + const MWNode &root = this->rootBox.getNode(rIdx); assert(root.isAncestor(idx)); return root.retrieveNodeNoGen(idx); } @@ -319,11 +319,11 @@ template const MWNode *MWTree::findNode(NodeIndex idx) const { * pointer if the node does not exist, or if it is a * GenNode. Recursion starts at the appropriate rootNode. */ -template MWNode *MWTree::findNode(NodeIndex idx) { +template MWNode *MWTree::findNode(NodeIndex idx) { if (getRootBox().isPeriodic()) { periodic::index_manipulation(idx, getRootBox().getPeriodic()); } int rIdx = getRootBox().getBoxIndex(idx); if (rIdx < 0) return nullptr; - MWNode &root = this->rootBox.getNode(rIdx); + MWNode &root = this->rootBox.getNode(rIdx); assert(root.isAncestor(idx)); return root.retrieveNodeNoGen(idx); } @@ -335,11 +335,11 @@ template MWNode *MWTree::findNode(NodeIndex idx) { * transform. Recursion starts at the appropriate rootNode and descends * from this. */ -template MWNode &MWTree::getNode(NodeIndex idx) { +template MWNode &MWTree::getNode(NodeIndex idx) { if (getRootBox().isPeriodic()) periodic::index_manipulation(idx, getRootBox().getPeriodic()); - MWNode *out = nullptr; - MWNode &root = getRootBox().getNode(idx); + MWNode *out = nullptr; + MWNode &root = getRootBox().getNode(idx); if (idx.getScale() < getRootScale()) { #pragma omp critical(gen_parent) out = root.retrieveParent(idx); @@ -357,9 +357,9 @@ template MWNode &MWTree::getNode(NodeIndex idx) { * GenNodes. Recursion starts at the appropriate rootNode and decends * from this. */ -template MWNode &MWTree::getNodeOrEndNode(NodeIndex idx) { +template MWNode &MWTree::getNodeOrEndNode(NodeIndex idx) { if (getRootBox().isPeriodic()) { periodic::index_manipulation(idx, getRootBox().getPeriodic()); } - MWNode &root = getRootBox().getNode(idx); + MWNode &root = getRootBox().getNode(idx); assert(root.isAncestor(idx)); return *root.retrieveNodeOrEndNode(idx); } @@ -371,9 +371,9 @@ template MWNode &MWTree::getNodeOrEndNode(NodeIndex idx) { * transform. Recursion starts at the appropriate rootNode and decends * from this. */ -template const MWNode &MWTree::getNodeOrEndNode(NodeIndex idx) const { +template const MWNode &MWTree::getNodeOrEndNode(NodeIndex idx) const { if (getRootBox().isPeriodic()) { periodic::index_manipulation(idx, getRootBox().getPeriodic()); } - const MWNode &root = getRootBox().getNode(idx); + const MWNode &root = getRootBox().getNode(idx); assert(root.isAncestor(idx)); return *root.retrieveNodeOrEndNode(idx); } @@ -387,8 +387,8 @@ template const MWNode &MWTree::getNodeOrEndNode(NodeIndex idx) * generate nodes that do not exist. Recursion starts at the * appropriate rootNode and decends from this. */ -template MWNode &MWTree::getNode(Coord r, int depth) { - MWNode &root = getRootBox().getNode(r); +template MWNode &MWTree::getNode(Coord r, int depth) { + MWNode &root = getRootBox().getNode(r); if (depth >= 0) { return *root.retrieveNode(r, depth); } else { @@ -405,11 +405,11 @@ template MWNode &MWTree::getNode(Coord r, int depth) { * the path to the requested node, and will never create or return GenNodes. * Recursion starts at the appropriate rootNode and decends from this. */ -template MWNode &MWTree::getNodeOrEndNode(Coord r, int depth) { +template MWNode &MWTree::getNodeOrEndNode(Coord r, int depth) { if (getRootBox().isPeriodic()) { periodic::coord_manipulation(r, getRootBox().getPeriodic()); } - MWNode &root = getRootBox().getNode(r); + MWNode &root = getRootBox().getNode(r); return *root.retrieveNodeOrEndNode(r, depth); } @@ -422,10 +422,10 @@ template MWNode &MWTree::getNodeOrEndNode(Coord r, int depth) { * the path to the requested node, and will never create or return GenNodes. * Recursion starts at the appropriate rootNode and decends from this. */ -template const MWNode &MWTree::getNodeOrEndNode(Coord r, int depth) const { +template const MWNode &MWTree::getNodeOrEndNode(Coord r, int depth) const { if (getRootBox().isPeriodic()) { periodic::coord_manipulation(r, getRootBox().getPeriodic()); } - const MWNode &root = getRootBox().getNode(r); + const MWNode &root = getRootBox().getNode(r); return *root.retrieveNodeOrEndNode(r, depth); } @@ -434,10 +434,10 @@ template const MWNode &MWTree::getNodeOrEndNode(Coord r, int de * @details copies the list of all EndNode pointers into a new vector * and retunrs it. */ -template MWNodeVector *MWTree::copyEndNodeTable() { - auto *nVec = new MWNodeVector; +template MWNodeVector *MWTree::copyEndNodeTable() { + auto *nVec = new MWNodeVector; for (int n = 0; n < getNEndNodes(); n++) { - MWNode &node = getEndMWNode(n); + MWNode &node = getEndMWNode(n); nVec->push_back(&node); } return nVec; @@ -447,29 +447,29 @@ template MWNodeVector *MWTree::copyEndNodeTable() { * * @details the endNodeTable is first deleted and then rebuilt from * scratch. It makes use of the TreeIterator to traverse the tree. - * + * */ -template void MWTree::resetEndNodeTable() { +template void MWTree::resetEndNodeTable() { clearEndNodeTable(); - TreeIterator it(*this, TopDown, Hilbert); + TreeIterator it(*this, TopDown, Hilbert); it.setReturnGenNodes(false); while (it.next()) { - MWNode &node = it.getNode(); + MWNode &node = it.getNode(); if (node.isEndNode()) { this->endNodeTable.push_back(&node); } } } -template int MWTree::countBranchNodes(int depth) { +template int MWTree::countBranchNodes(int depth) { NOT_IMPLEMENTED_ABORT; } -template int MWTree::countLeafNodes(int depth) { +template int MWTree::countLeafNodes(int depth) { NOT_IMPLEMENTED_ABORT; // int nNodes = 0; - // TreeIterator it(*this); + // TreeIterator it(*this); // while (it.next()) { - // MWNode &node = it.getNode(); + // MWNode &node = it.getNode(); // if (node.getDepth() == depth or depth < 0) { // if (node.isLeafNode()) { // nNodes++; @@ -480,12 +480,12 @@ template int MWTree::countLeafNodes(int depth) { } /* Traverse tree and count nodes belonging to this rank. */ -template int MWTree::countNodes(int depth) { +template int MWTree::countNodes(int depth) { NOT_IMPLEMENTED_ABORT; - // TreeIterator it(*this); + // TreeIterator it(*this); // int count = 0; // while (it.next()) { - // MWNode &node = it.getNode(); + // MWNode &node = it.getNode(); // if (node.isGenNode()) { // continue; // } @@ -497,12 +497,12 @@ template int MWTree::countNodes(int depth) { } /* Traverse tree and count nodes with allocated coefficients. */ -template int MWTree::countAllocNodes(int depth) { +template int MWTree::countAllocNodes(int depth) { NOT_IMPLEMENTED_ABORT; - // TreeIterator it(*this); + // TreeIterator it(*this); // int count = 0; // while (it.next()) { - // MWNode &node = it.getNode(); + // MWNode &node = it.getNode(); // if (node.isGenNode()) { // continue; // } @@ -515,7 +515,7 @@ template int MWTree::countAllocNodes(int depth) { /** @brief Prints a summary of the tree structure on the output file */ -template std::ostream &MWTree::print(std::ostream &o) const { +template std::ostream &MWTree::print(std::ostream &o) const { o << " square norm: " << this->squareNorm << std::endl; o << " root scale: " << this->getRootScale() << std::endl; o << " order: " << this->order << std::endl; @@ -532,9 +532,9 @@ template std::ostream &MWTree::print(std::ostream &o) const { * @details it defines the upper bound of the squared norm \f$ * ||f||^2_{\ldots} \f$ in this node or its descendents */ -template void MWTree::makeMaxSquareNorms() { - NodeBox &rBox = this->getRootBox(); - MWNode **roots = rBox.getNodes(); +template void MWTree::makeMaxSquareNorms() { + NodeBox &rBox = this->getRootBox(); + MWNode **roots = rBox.getNodes(); for (int rIdx = 0; rIdx < rBox.size(); rIdx++) { // recursively set value of children and descendants roots[rIdx]->setMaxSquareNorm(); @@ -543,15 +543,16 @@ template void MWTree::makeMaxSquareNorms() { /** @brief gives serialIx of a node from its NodeIndex * - * @details Peter will document this! + * @details gives a unique integer for each nodes corresponding to the position + * of the node in the serialized representation */ -template int MWTree::getIx(NodeIndex nIdx) { +template int MWTree::getIx(NodeIndex nIdx) { if (this->isLocal == false) MSG_ERROR("getIx only implemented in local representation"); if(NodeIndex2serialIx.count(nIdx) == 0) return -1; else return NodeIndex2serialIx[nIdx]; } -template void MWTree::getNodeCoeff(NodeIndex nIdx, double *data) { +template void MWTree::getNodeCoeff(NodeIndex nIdx, T *data) { assert(this->isLocal); int size = (1 << D) * kp1_d; int id = 0; @@ -559,8 +560,13 @@ template void MWTree::getNodeCoeff(NodeIndex nIdx, double *data) { this->NodesCoeff->get_data(id, size, data); } -template class MWTree<1>; -template class MWTree<2>; -template class MWTree<3>; +template class MWTree<1, double>; +template class MWTree<2, double>; +template class MWTree<3, double>; + + +template class MWTree<1, ComplexDouble>; +template class MWTree<2, ComplexDouble>; +template class MWTree<3, ComplexDouble>; } // namespace mrcpp diff --git a/src/trees/MWTree.h b/src/trees/MWTree.h index 51cfe3eed..c2f231ccf 100644 --- a/src/trees/MWTree.h +++ b/src/trees/MWTree.h @@ -61,11 +61,11 @@ class BankAccount; * present. See specific methods for details. * */ -template class MWTree { + template class MWTree { public: MWTree(const MultiResolutionAnalysis &mra, const std::string &n); - MWTree(const MWTree &tree) = delete; - MWTree &operator=(const MWTree &tree) = delete; + MWTree(const MWTree &tree) = delete; + MWTree &operator=(const MWTree &tree) = delete; virtual ~MWTree(); void setZero(); @@ -90,8 +90,8 @@ template class MWTree { int getSizeNodes() const; /** @returns */ - NodeBox &getRootBox() { return this->rootBox; } - const NodeBox &getRootBox() const { return this->rootBox; } + NodeBox &getRootBox() { return this->rootBox; } + const NodeBox &getRootBox() const { return this->rootBox; } const MultiResolutionAnalysis &getMRA() const { return this->MRA; } void mwTransform(int type, bool overwrite = true); @@ -102,28 +102,28 @@ template class MWTree { int getRootIndex(Coord r) const { return this->rootBox.getBoxIndex(r); } int getRootIndex(NodeIndex nIdx) const { return this->rootBox.getBoxIndex(nIdx); } - MWNode *findNode(NodeIndex nIdx); - const MWNode *findNode(NodeIndex nIdx) const; + MWNode *findNode(NodeIndex nIdx); + const MWNode *findNode(NodeIndex nIdx) const; - MWNode &getNode(NodeIndex nIdx); - MWNode &getNodeOrEndNode(NodeIndex nIdx); - const MWNode &getNodeOrEndNode(NodeIndex nIdx) const; + MWNode &getNode(NodeIndex nIdx); + MWNode &getNodeOrEndNode(NodeIndex nIdx); + const MWNode &getNodeOrEndNode(NodeIndex nIdx) const; - MWNode &getNode(Coord r, int depth = -1); - MWNode &getNodeOrEndNode(Coord r, int depth = -1); - const MWNode &getNodeOrEndNode(Coord r, int depth = -1) const; + MWNode &getNode(Coord r, int depth = -1); + MWNode &getNodeOrEndNode(Coord r, int depth = -1); + const MWNode &getNodeOrEndNode(Coord r, int depth = -1) const; int getNEndNodes() const { return this->endNodeTable.size(); } int getNRootNodes() const { return this->rootBox.size(); } - MWNode &getEndMWNode(int i) { return *this->endNodeTable[i]; } - MWNode &getRootMWNode(int i) { return this->rootBox.getNode(i); } - const MWNode &getEndMWNode(int i) const { return *this->endNodeTable[i]; } - const MWNode &getRootMWNode(int i) const { return this->rootBox.getNode(i); } + MWNode &getEndMWNode(int i) { return *this->endNodeTable[i]; } + MWNode &getRootMWNode(int i) { return this->rootBox.getNode(i); } + const MWNode &getEndMWNode(int i) const { return *this->endNodeTable[i]; } + const MWNode &getRootMWNode(int i) const { return this->rootBox.getNode(i); } bool isPeriodic() const { return this->MRA.getWorldBox().isPeriodic(); } - MWNodeVector *copyEndNodeTable(); - MWNodeVector *getEndNodeTable() { return &this->endNodeTable; } + MWNodeVector *copyEndNodeTable(); + MWNodeVector *getEndNodeTable() { return &this->endNodeTable; } void deleteRootNodes(); void resetEndNodeTable(); @@ -138,19 +138,19 @@ template class MWTree { void makeMaxSquareNorms(); // sets values for maxSquareNorm and maxWSquareNorm in all nodes - NodeAllocator &getNodeAllocator() { return *this->nodeAllocator_p; } - const NodeAllocator &getNodeAllocator() const { return *this->nodeAllocator_p; } - MWNodeVector endNodeTable; ///< Final projected nodes + NodeAllocator &getNodeAllocator() { return *this->nodeAllocator_p; } + const NodeAllocator &getNodeAllocator() const { return *this->nodeAllocator_p; } + MWNodeVector endNodeTable; ///< Final projected nodes - void getNodeCoeff(NodeIndex nIdx, double *data); // fetch coefficient from a specific node stored in Bank + void getNodeCoeff(NodeIndex nIdx, T *data); // fetch coefficient from a specific node stored in Bank - friend std::ostream &operator<<(std::ostream &o, const MWTree &tree) { return tree.print(o); } + friend std::ostream &operator<<(std::ostream &o, const MWTree &tree) { return tree.print(o); } - friend class MWNode; - friend class FunctionNode; + friend class MWNode; + friend class FunctionNode; friend class OperatorNode; - friend class TreeBuilder; - friend class NodeAllocator; + friend class TreeBuilder; + friend class NodeAllocator; protected: // Parameters that are set in construction and should never change @@ -165,11 +165,11 @@ template class MWTree { // Parameters that are dynamic and can be set by user std::string name; - std::unique_ptr> nodeAllocator_p{nullptr}; + std::unique_ptr> nodeAllocator_p{nullptr}; // Tree data double squareNorm; - NodeBox rootBox; ///< The actual container of nodes + NodeBox rootBox; ///< The actual container of nodes std::vector nodesAtDepth; ///< Node counter std::vector nodesAtNegativeDepth; ///< Node counter diff --git a/src/trees/NodeAllocator.cpp b/src/trees/NodeAllocator.cpp index 9ca79f0b4..b33d5ccf1 100644 --- a/src/trees/NodeAllocator.cpp +++ b/src/trees/NodeAllocator.cpp @@ -38,7 +38,7 @@ namespace mrcpp { -template NodeAllocator::NodeAllocator(FunctionTree *tree, SharedMemory *mem, int coefsPerNode, int nodesPerChunk) +template NodeAllocator::NodeAllocator(FunctionTree *tree, SharedMemory *mem, int coefsPerNode, int nodesPerChunk) : coefsPerNode(coefsPerNode) , maxNodesPerChunk(nodesPerChunk) , tree_p(tree) @@ -47,14 +47,14 @@ template NodeAllocator::NodeAllocator(FunctionTree *tree, SharedMe this->nodeChunks.reserve(100); this->coefChunks.reserve(100); - FunctionNode tmp; + FunctionNode tmp; this->cvptr = *(char **)(&tmp); - this->sizeOfNode = sizeof(FunctionNode); + this->sizeOfNode = sizeof(FunctionNode); MRCPP_INIT_OMP_LOCK(); } -template <> NodeAllocator<2>::NodeAllocator(OperatorTree *tree, SharedMemory *mem, int coefsPerNode, int nodesPerChunk) +template <> NodeAllocator<2>::NodeAllocator(OperatorTree *tree, SharedMemory *mem, int coefsPerNode, int nodesPerChunk) : coefsPerNode(coefsPerNode) , maxNodesPerChunk(nodesPerChunk) , tree_p(tree) @@ -70,11 +70,11 @@ template <> NodeAllocator<2>::NodeAllocator(OperatorTree *tree, SharedMemory *me MRCPP_INIT_OMP_LOCK(); } -template NodeAllocator::NodeAllocator(OperatorTree *tree, SharedMemory *mem, int coefsPerNode, int nodesPerChunk) { +template NodeAllocator::NodeAllocator(OperatorTree *tree, SharedMemory *mem, int coefsPerNode, int nodesPerChunk) { NOT_REACHED_ABORT; } -template NodeAllocator::~NodeAllocator() { +template NodeAllocator::~NodeAllocator() { for (auto &chunk : this->nodeChunks) delete[](char *) chunk; if (not isShared()) // if the data is shared, it must be freed by MPI_Win_free for (auto &chunk : this->coefChunks) delete[] chunk; @@ -82,35 +82,35 @@ template NodeAllocator::~NodeAllocator() { MRCPP_DESTROY_OMP_LOCK(); } -template MWNode * NodeAllocator::getNode_p(int sIdx) { +template MWNode * NodeAllocator::getNode_p(int sIdx) { MRCPP_SET_OMP_LOCK(); auto *node = getNodeNoLock(sIdx); MRCPP_UNSET_OMP_LOCK(); return node; } -template double * NodeAllocator::getCoef_p(int sIdx) { +template T * NodeAllocator::getCoef_p(int sIdx) { MRCPP_SET_OMP_LOCK(); auto *coefs = getCoefNoLock(sIdx); MRCPP_UNSET_OMP_LOCK(); return coefs; } -template MWNode * NodeAllocator::getNodeNoLock(int sIdx) { +template MWNode * NodeAllocator::getNodeNoLock(int sIdx) { if (sIdx < 0 or sIdx >= this->stackStatus.size()) return nullptr; int chunk = sIdx / this->maxNodesPerChunk; // which chunk int cIdx = sIdx % this->maxNodesPerChunk; // position in chunk return this->nodeChunks[chunk] + cIdx; } -template double * NodeAllocator::getCoefNoLock(int sIdx) { +template T * NodeAllocator::getCoefNoLock(int sIdx) { if (sIdx < 0 or sIdx >= this->stackStatus.size()) return nullptr; int chunk = sIdx / this->maxNodesPerChunk; // which chunk int idx = sIdx % this->maxNodesPerChunk; // position in chunk return this->coefChunks[chunk] + idx * this->coefsPerNode; } -template int NodeAllocator::alloc(int nNodes, bool coefs) { +template int NodeAllocator::alloc(int nNodes, bool coefs) { MRCPP_SET_OMP_LOCK(); if (nNodes <= 0 or nNodes > this->maxNodesPerChunk) MSG_ABORT("Cannot allocate " << nNodes << " nodes"); @@ -143,7 +143,7 @@ template int NodeAllocator::alloc(int nNodes, bool coefs) { return sIdx; } -template void NodeAllocator::dealloc(int sIdx) { +template void NodeAllocator::dealloc(int sIdx) { MRCPP_SET_OMP_LOCK(); if (sIdx < 0 or sIdx >= this->stackStatus.size()) MSG_ABORT("Invalid serial index: " << sIdx); auto *node_p = getNodeNoLock(sIdx); @@ -161,7 +161,7 @@ template void NodeAllocator::dealloc(int sIdx) { MRCPP_UNSET_OMP_LOCK(); } -template void NodeAllocator::deallocAllCoeff() { +template void NodeAllocator::deallocAllCoeff() { if (not this->isShared()) for (auto &chunk : this->coefChunks) delete[] chunk; else delete this->shmem_p; @@ -170,7 +170,7 @@ template void NodeAllocator::deallocAllCoeff() { } -template void NodeAllocator::init(int nChunks, bool coefs) { +template void NodeAllocator::init(int nChunks, bool coefs) { MRCPP_SET_OMP_LOCK(); if (nChunks <= 0) MSG_ABORT("Invalid number of chunks: " << nChunks); for (int i = getNChunks(); i < nChunks; i++) appendChunk(coefs); @@ -182,10 +182,10 @@ template void NodeAllocator::init(int nChunks, bool coefs) { MRCPP_UNSET_OMP_LOCK(); } -template void NodeAllocator::appendChunk(bool coefs) { +template void NodeAllocator::appendChunk(bool coefs) { // make coeff chunk if (coefs) { - double *c_chunk = nullptr; + T *c_chunk = nullptr; if (this->isShared()) { // for coefficients, take from the shared memory block c_chunk = this->shmem_p->sh_end_ptr; @@ -193,13 +193,13 @@ template void NodeAllocator::appendChunk(bool coefs) { // may increase size dynamically in the future if (this->shmem_p->sh_max_ptr < this->shmem_p->sh_end_ptr) MSG_ABORT("Shared block too small"); } else { - c_chunk = new double[getCoefChunkSize() / sizeof(double)]; + c_chunk = new T[getCoefChunkSize() / sizeof(T)]; } this->coefChunks.push_back(c_chunk); } // make node chunk - auto n_chunk = (MWNode *)new char[getNodeChunkSize()]; + auto n_chunk = (MWNode *)new char[getNodeChunkSize()]; for (int i = 0; i < this->maxNodesPerChunk; i++) { n_chunk[i].serialIx = -1; n_chunk[i].parentSerialIx = -1; @@ -215,7 +215,7 @@ template void NodeAllocator::appendChunk(bool coefs) { } /** Fill all holes in the chunks with occupied nodes, then remove all empty chunks */ -template int NodeAllocator::compress() { +template int NodeAllocator::compress() { MRCPP_SET_OMP_LOCK(); int nNodes = (1 << D); if (this->maxNodesPerChunk * this->nodeChunks.size() <= @@ -249,7 +249,7 @@ template int NodeAllocator::compress() { return nChunksDeleted; } -template int NodeAllocator::deleteUnusedChunks() { +template int NodeAllocator::deleteUnusedChunks() { // number of occupied chunks int nChunksTotal = getNChunks(); int nChunksUsed = getNChunksUsed(); @@ -271,7 +271,7 @@ template int NodeAllocator::deleteUnusedChunks() { return nChunksTotal - nChunksUsed; } -template void NodeAllocator::moveNodes(int nNodes, int srcIdx, int dstIdx) { +template void NodeAllocator::moveNodes(int nNodes, int srcIdx, int dstIdx) { assert(nNodes > 0); assert(nNodes <= this->maxNodesPerChunk); @@ -288,7 +288,7 @@ template void NodeAllocator::moveNodes(int nNodes, int srcIdx, int ds for (int i = 0; i < nNodes * this->sizeOfNode; i++) ((char *)dstNode)[i] = ((char *)srcNode)[i]; // coefs have new adresses - double *coefs_p = getCoefNoLock(dstIdx); + T *coefs_p = getCoefNoLock(dstIdx); if (coefs_p == nullptr) NOT_IMPLEMENTED_ABORT; // Nodes without coefs not handled atm for (int i = 0; i < nNodes; i++) (dstNode + i)->coefs = coefs_p + i * getNCoefs(); @@ -325,7 +325,7 @@ template void NodeAllocator::moveNodes(int nNodes, int srcIdx, int ds } // Last positions on a chunk cannot be used if there is no place for nNodes siblings on the same chunk -template int NodeAllocator::findNextAvailable(int sIdx, int nNodes) const { +template int NodeAllocator::findNextAvailable(int sIdx, int nNodes) const { assert(sIdx >= 0); assert(sIdx < this->stackStatus.size()); assert(nNodes >= 0); @@ -343,7 +343,7 @@ template int NodeAllocator::findNextAvailable(int sIdx, int nNodes) c return sIdx; } -template int NodeAllocator::findNextOccupied(int sIdx) const { +template int NodeAllocator::findNextOccupied(int sIdx) const { assert(sIdx >= 0); assert(sIdx < this->stackStatus.size()); bool endOfStack = (sIdx >= this->topStack); @@ -359,17 +359,17 @@ template int NodeAllocator::findNextOccupied(int sIdx) const { } /** Traverse tree and redefine pointer, counter and tables. */ -template void NodeAllocator::reassemble() { +template void NodeAllocator::reassemble() { MRCPP_SET_OMP_LOCK(); this->nNodes = 0; getTree().nodesAtDepth.clear(); getTree().squareNorm = 0.0; getTree().clearEndNodeTable(); - NodeBox &rootbox = getTree().getRootBox(); - MWNode **roots = rootbox.getNodes(); + NodeBox &rootbox = getTree().getRootBox(); + MWNode **roots = rootbox.getNodes(); - std::stack *> stack; + std::stack *> stack; for (int rIdx = 0; rIdx < rootbox.size(); rIdx++) { auto *root_p = getNodeNoLock(rIdx); assert(root_p != nullptr); @@ -414,7 +414,7 @@ template void NodeAllocator::reassemble() { MRCPP_UNSET_OMP_LOCK(); } -template void NodeAllocator::print() const { +template void NodeAllocator::print() const { int n = 0; for (int iChunk = 0; iChunk < getNChunks(); iChunk++) { int iShift = iChunk * this->maxNodesPerChunk; @@ -436,8 +436,12 @@ template void NodeAllocator::print() const { } } -template class NodeAllocator<1>; -template class NodeAllocator<2>; -template class NodeAllocator<3>; +template class NodeAllocator<1, double>; +template class NodeAllocator<2, double>; +template class NodeAllocator<3, double>; + +template class NodeAllocator<1, ComplexDouble>; +template class NodeAllocator<2, ComplexDouble>; +template class NodeAllocator<3, ComplexDouble>; } // namespace mrcpp diff --git a/src/trees/NodeAllocator.h b/src/trees/NodeAllocator.h index 38e4ba7eb..b426d0021 100644 --- a/src/trees/NodeAllocator.h +++ b/src/trees/NodeAllocator.h @@ -40,12 +40,12 @@ namespace mrcpp { -template class NodeAllocator final { + template class NodeAllocator final { public: - NodeAllocator(OperatorTree *tree, SharedMemory *mem, int coefsPerNode, int nodesPerChunk); - NodeAllocator(FunctionTree *tree, SharedMemory *mem, int coefsPerNode, int nodesPerChunk); - NodeAllocator(const NodeAllocator &tree) = delete; - NodeAllocator &operator=(const NodeAllocator &tree) = delete; + NodeAllocator(OperatorTree *tree, SharedMemory *mem, int coefsPerNode, int nodesPerChunk); + NodeAllocator(FunctionTree *tree, SharedMemory *mem, int coefsPerNode, int nodesPerChunk); + NodeAllocator(const NodeAllocator &tree) = delete; + NodeAllocator &operator=(const NodeAllocator &tree) = delete; ~NodeAllocator(); int alloc(int nNodes, bool coefs = true); @@ -63,13 +63,13 @@ template class NodeAllocator final { int getNChunks() const { return this->nodeChunks.size(); } int getNChunksUsed() const { return (this->topStack + this->maxNodesPerChunk - 1) / this->maxNodesPerChunk; } int getNodeChunkSize() const { return this->maxNodesPerChunk * this->sizeOfNode; } - int getCoefChunkSize() const { return this->maxNodesPerChunk * this->coefsPerNode * sizeof(double); } + int getCoefChunkSize() const { return this->maxNodesPerChunk * this->coefsPerNode * sizeof(T); } - double * getCoef_p(int sIdx); - MWNode * getNode_p(int sIdx); + T * getCoef_p(int sIdx); + MWNode * getNode_p(int sIdx); - double * getCoefChunk(int i) { return this->coefChunks[i]; } - MWNode * getNodeChunk(int i) { return this->nodeChunks[i]; } + T * getCoefChunk(int i) { return this->coefChunks[i]; } + MWNode * getNodeChunk(int i) { return this->nodeChunks[i]; } void print() const; @@ -81,20 +81,20 @@ template class NodeAllocator final { int maxNodesPerChunk{0}; // max number of nodes per allocation std::vector stackStatus{}; - std::vector coefChunks{}; - std::vector *> nodeChunks{}; + std::vector coefChunks{}; + std::vector *> nodeChunks{}; char *cvptr{nullptr}; // pointer to virtual table - MWNode *last_p{nullptr}; // pointer just after the last active node, i.e. where to put next node - MWTree *tree_p{nullptr}; // pointer to external object - SharedMemory *shmem_p{nullptr}; // pointer to external object + MWNode *last_p{nullptr}; // pointer just after the last active node, i.e. where to put next node + MWTree *tree_p{nullptr}; // pointer to external object + SharedMemory *shmem_p{nullptr}; // pointer to external object bool isShared() const { return (this->shmem_p != nullptr); } - MWTree &getTree() { return *this->tree_p; } - SharedMemory &getMemory() { return *this->shmem_p; } + MWTree &getTree() { return *this->tree_p; } + SharedMemory &getMemory() { return *this->shmem_p; } - double * getCoefNoLock(int sIdx); - MWNode * getNodeNoLock(int sIdx); + T * getCoefNoLock(int sIdx); + MWNode * getNodeNoLock(int sIdx); void moveNodes(int nNodes, int srcIdx, int dstIdx); void appendChunk(bool coefs); diff --git a/src/trees/NodeBox.cpp b/src/trees/NodeBox.cpp index cc247d58e..bf747c4fc 100644 --- a/src/trees/NodeBox.cpp +++ b/src/trees/NodeBox.cpp @@ -36,50 +36,50 @@ namespace mrcpp { -template -NodeBox::NodeBox(const NodeIndex &idx, const std::array &nb) +template +NodeBox::NodeBox(const NodeIndex &idx, const std::array &nb) : BoundingBox(idx, nb) , nOccupied(0) , nodes(nullptr) { allocNodePointers(); } -template -NodeBox::NodeBox(const BoundingBox &box) +template +NodeBox::NodeBox(const BoundingBox &box) : BoundingBox(box) , nOccupied(0) , nodes(nullptr) { allocNodePointers(); } -template -NodeBox::NodeBox(const NodeBox &box) +template +NodeBox::NodeBox(const NodeBox &box) : BoundingBox(box) , nOccupied(0) , nodes(nullptr) { allocNodePointers(); } -template void NodeBox::allocNodePointers() { +template void NodeBox::allocNodePointers() { assert(this->nodes == nullptr); int nNodes = this->size(); - this->nodes = new MWNode *[nNodes]; + this->nodes = new MWNode *[nNodes]; for (int n = 0; n < nNodes; n++) { this->nodes[n] = nullptr; } this->nOccupied = 0; } -template NodeBox::~NodeBox() { +template NodeBox::~NodeBox() { deleteNodes(); } -template void NodeBox::deleteNodes() { +template void NodeBox::deleteNodes() { if (this->nodes == nullptr) { return; } for (int n = 0; n < this->size(); n++) { clearNode(n); } delete[] this->nodes; this->nodes = nullptr; } -template void NodeBox::setNode(int bIdx, MWNode **node) { +template void NodeBox::setNode(int bIdx, MWNode **node) { assert(bIdx >= 0); assert(bIdx < this->totBoxes); clearNode(bIdx); @@ -89,44 +89,48 @@ template void NodeBox::setNode(int bIdx, MWNode **node) { *node = nullptr; } -template MWNode &NodeBox::getNode(NodeIndex nIdx) { +template MWNode &NodeBox::getNode(NodeIndex nIdx) { int bIdx = this->getBoxIndex(nIdx); return getNode(bIdx); } -template MWNode &NodeBox::getNode(Coord r) { +template MWNode &NodeBox::getNode(Coord r) { int bIdx = this->getBoxIndex(r); if (bIdx < 0) MSG_ERROR("Coord out of bounds"); return getNode(bIdx); } -template MWNode &NodeBox::getNode(int bIdx) { +template MWNode &NodeBox::getNode(int bIdx) { assert(bIdx >= 0); assert(bIdx < this->totBoxes); assert(this->nodes[bIdx] != nullptr); return *this->nodes[bIdx]; } -template const MWNode &NodeBox::getNode(NodeIndex nIdx) const { +template const MWNode &NodeBox::getNode(NodeIndex nIdx) const { int bIdx = this->getBoxIndex(nIdx); return getNode(bIdx); } -template const MWNode &NodeBox::getNode(Coord r) const { +template const MWNode &NodeBox::getNode(Coord r) const { int bIdx = this->getBoxIndex(r); if (bIdx < 0) MSG_ERROR("Coord out of bounds"); return getNode(bIdx); } -template const MWNode &NodeBox::getNode(int bIdx) const { +template const MWNode &NodeBox::getNode(int bIdx) const { assert(bIdx >= 0); assert(bIdx < this->totBoxes); assert(this->nodes[bIdx] != nullptr); return *this->nodes[bIdx]; } -template class NodeBox<1>; -template class NodeBox<2>; -template class NodeBox<3>; +template class NodeBox<1, double>; +template class NodeBox<2, double>; +template class NodeBox<3, double>; + +template class NodeBox<1, ComplexDouble>; +template class NodeBox<2, ComplexDouble>; +template class NodeBox<3, ComplexDouble>; } // namespace mrcpp diff --git a/src/trees/NodeBox.h b/src/trees/NodeBox.h index dfb0dc20c..3b53da538 100644 --- a/src/trees/NodeBox.h +++ b/src/trees/NodeBox.h @@ -30,31 +30,31 @@ namespace mrcpp { -template class NodeBox final : public BoundingBox { + template class NodeBox final : public BoundingBox { public: NodeBox(const NodeIndex &idx, const std::array &nb = {}); - NodeBox(const NodeBox &box); + NodeBox(const NodeBox &box); NodeBox(const BoundingBox &box); - NodeBox &operator=(const NodeBox &box) = delete; + NodeBox &operator=(const NodeBox &box) = delete; ~NodeBox() override; - void setNode(int idx, MWNode **node); + void setNode(int idx, MWNode **node); void clearNode(int idx) { this->nodes[idx] = nullptr; } - MWNode &getNode(NodeIndex idx); - MWNode &getNode(Coord r); - MWNode &getNode(int i = 0); + MWNode &getNode(NodeIndex idx); + MWNode &getNode(Coord r); + MWNode &getNode(int i = 0); - const MWNode &getNode(NodeIndex idx) const; - const MWNode &getNode(Coord r) const; - const MWNode &getNode(int i = 0) const; + const MWNode &getNode(NodeIndex idx) const; + const MWNode &getNode(Coord r) const; + const MWNode &getNode(int i = 0) const; int getNOccupied() const { return this->nOccupied; } - MWNode **getNodes() { return this->nodes; } + MWNode **getNodes() { return this->nodes; } protected: int nOccupied; ///< Number of non-zero pointers in box - MWNode **nodes; ///< Container of nodes + MWNode **nodes; ///< Container of nodes void allocNodePointers(); void deleteNodes(); diff --git a/src/trees/TreeIterator.cpp b/src/trees/TreeIterator.cpp index 9bf9fb054..f7f88e03b 100644 --- a/src/trees/TreeIterator.cpp +++ b/src/trees/TreeIterator.cpp @@ -29,7 +29,7 @@ namespace mrcpp { -template TreeIterator::TreeIterator(int traverse, int iterator) +template TreeIterator::TreeIterator(int traverse, int iterator) : root(0) , nRoots(0) , mode(traverse) @@ -38,7 +38,7 @@ template TreeIterator::TreeIterator(int traverse, int iterator) , state(nullptr) , initialState(nullptr) {} -template TreeIterator::TreeIterator(MWTree &tree, int traverse, int iterator) +template TreeIterator::TreeIterator(MWTree &tree, int traverse, int iterator) : root(0) , nRoots(0) , mode(traverse) @@ -49,23 +49,23 @@ template TreeIterator::TreeIterator(MWTree &tree, int traverse, in init(tree); } -template TreeIterator::~TreeIterator() { +template TreeIterator::~TreeIterator() { if (this->initialState != nullptr) delete this->initialState; } -template int TreeIterator::getChildIndex(int i) const { - const MWNode &node = *this->state->node; + template int TreeIterator::getChildIndex(int i) const { + const MWNode &node = *this->state->node; const HilbertPath &h = node.getHilbertPath(); // Legesgue type returns i, Hilbert type returns Hilbert index return (this->type == Hilbert) ? h.getZIndex(i) : i; } -template bool TreeIterator::next() { +template bool TreeIterator::next() { if (not this->state) return false; if (this->mode == TopDown) { if (this->tryNode()) return true; } - MWNode &node = *this->state->node; + MWNode &node = *this->state->node; if (checkDepth(node) and checkGenerated(node)) { const int nChildren = 1 << D; for (int i = 0; i < nChildren; i++) { @@ -80,12 +80,12 @@ template bool TreeIterator::next() { this->removeState(); return next(); } -template bool TreeIterator::nextParent() { +template bool TreeIterator::nextParent() { if (not this->state) return false; if (this->mode == BottomUp) { if (this->tryNode()) return true; } - MWNode &node = *this->state->node; + MWNode &node = *this->state->node; if (this->tryNextRootParent()) return true; if (checkDepth(node)) { if (this->tryParent()) return true; @@ -97,73 +97,73 @@ template bool TreeIterator::nextParent() { return nextParent(); } -template void TreeIterator::init(MWTree &tree) { +template void TreeIterator::init(MWTree &tree) { this->root = 0; this->maxDepth = -1; this->nRoots = tree.getRootBox().size(); - this->state = new IteratorNode(&tree.getRootBox().getNode(this->root)); + this->state = new IteratorNode(&tree.getRootBox().getNode(this->root)); // Save the first state so it can be properly deleted later this->initialState = this->state; } -template bool TreeIterator::tryNode() { +template bool TreeIterator::tryNode() { if (not this->state) { return false; } if (this->state->doneNode) { return false; } this->state->doneNode = true; return true; } -template bool TreeIterator::tryChild(int i) { +template bool TreeIterator::tryChild(int i) { if (not this->state) { return false; } if (this->state->doneChild[i]) { return false; } this->state->doneChild[i] = true; if (this->state->node->isLeafNode()) { return false; } - MWNode *child = &this->state->node->getMWChild(i); - this->state = new IteratorNode(child, this->state); + MWNode *child = &this->state->node->getMWChild(i); + this->state = new IteratorNode(child, this->state); return next(); } -template bool TreeIterator::tryParent() { +template bool TreeIterator::tryParent() { if (not this->state) return false; if (this->state->doneParent) return false; this->state->doneParent = true; if (not this->state->node->hasParent()) return false; - MWNode *parent = &this->state->node->getMWParent(); - this->state = new IteratorNode(parent, this->state); + MWNode *parent = &this->state->node->getMWParent(); + this->state = new IteratorNode(parent, this->state); return nextParent(); } -template bool TreeIterator::tryNextRoot() { +template bool TreeIterator::tryNextRoot() { if (not this->state) { return false; } if (not this->state->node->isRootNode()) { return false; } this->root++; if (this->root >= this->nRoots) { return false; } - MWNode *nextRoot = &state->node->getMWTree().getRootBox().getNode(root); - this->state = new IteratorNode(nextRoot, this->state); + MWNode *nextRoot = &state->node->getMWTree().getRootBox().getNode(root); + this->state = new IteratorNode(nextRoot, this->state); return next(); } -template bool TreeIterator::tryNextRootParent() { +template bool TreeIterator::tryNextRootParent() { if (not this->state) { return false; } if (not this->state->node->isRootNode()) { return false; } this->root++; if (this->root >= this->nRoots) { return false; } - MWNode *nextRoot = &state->node->getMWTree().getRootBox().getNode(root); - this->state = new IteratorNode(nextRoot, this->state); + MWNode *nextRoot = &state->node->getMWTree().getRootBox().getNode(root); + this->state = new IteratorNode(nextRoot, this->state); return nextParent(); } -template void TreeIterator::removeState() { +template void TreeIterator::removeState() { if (this->state == this->initialState) { this->initialState = nullptr; } if (this->state != nullptr) { - IteratorNode *spare = this->state; + IteratorNode *spare = this->state; this->state = spare->next; spare->next = nullptr; delete spare; } } -template void TreeIterator::setTraverse(int traverse) { +template void TreeIterator::setTraverse(int traverse) { switch (traverse) { case TopDown: this->mode = TopDown; @@ -177,7 +177,7 @@ template void TreeIterator::setTraverse(int traverse) { } } -template void TreeIterator::setIterator(int iterator) { +template void TreeIterator::setIterator(int iterator) { switch (iterator) { case Lebesgue: this->type = Lebesgue; @@ -191,7 +191,7 @@ template void TreeIterator::setIterator(int iterator) { } } -template bool TreeIterator::checkDepth(const MWNode &node) const { +template bool TreeIterator::checkDepth(const MWNode &node) const { if (this->maxDepth < 0) { return true; } else if (node.getDepth() < this->maxDepth) { @@ -201,7 +201,7 @@ template bool TreeIterator::checkDepth(const MWNode &node) const { } } -template bool TreeIterator::checkGenerated(const MWNode &node) const { +template bool TreeIterator::checkGenerated(const MWNode &node) const { if (node.isEndNode() and not this->returnGenNodes) { return false; } else { @@ -209,8 +209,8 @@ template bool TreeIterator::checkGenerated(const MWNode &node) con } } -template -IteratorNode::IteratorNode(MWNode *nd, IteratorNode *nx) +template +IteratorNode::IteratorNode(MWNode *nd, IteratorNode *nx) : node(nd) , next(nx) , doneNode(false) @@ -219,8 +219,12 @@ IteratorNode::IteratorNode(MWNode *nd, IteratorNode *nx) for (int i = 0; i < nChildren; i++) { this->doneChild[i] = false; } } -template class TreeIterator<1>; -template class TreeIterator<2>; -template class TreeIterator<3>; +template class TreeIterator<1, double>; +template class TreeIterator<2, double>; +template class TreeIterator<3, double>; + +template class TreeIterator<1, ComplexDouble>; +template class TreeIterator<2, ComplexDouble>; +template class TreeIterator<3, ComplexDouble>; } // namespace mrcpp diff --git a/src/trees/TreeIterator.h b/src/trees/TreeIterator.h index d82db82a0..a79fe8412 100644 --- a/src/trees/TreeIterator.h +++ b/src/trees/TreeIterator.h @@ -30,10 +30,10 @@ namespace mrcpp { -template class TreeIterator { +template class TreeIterator { public: TreeIterator(int traverse = TopDown, int iterator = Lebesgue); - TreeIterator(MWTree &tree, int traverse = TopDown, int iterator = Lebesgue); + TreeIterator(MWTree &tree, int traverse = TopDown, int iterator = Lebesgue); virtual ~TreeIterator(); void setReturnGenNodes(bool i = true) { this->returnGenNodes = i; } @@ -41,12 +41,12 @@ template class TreeIterator { void setTraverse(int traverse); void setIterator(int iterator); - void init(MWTree &tree); + void init(MWTree &tree); bool next(); bool nextParent(); - MWNode &getNode() { return *this->state->node; } + MWNode &getNode() { return *this->state->node; } - friend class IteratorNode; + friend class IteratorNode; protected: int root; @@ -55,8 +55,8 @@ template class TreeIterator { int type; int maxDepth; bool returnGenNodes{true}; - IteratorNode *state; - IteratorNode *initialState; + IteratorNode *state; + IteratorNode *initialState; int getChildIndex(int i) const; @@ -66,19 +66,19 @@ template class TreeIterator { bool tryNextRoot(); bool tryNextRootParent(); void removeState(); - bool checkDepth(const MWNode &node) const; - bool checkGenerated(const MWNode &node) const; + bool checkDepth(const MWNode &node) const; + bool checkGenerated(const MWNode &node) const; }; -template class IteratorNode final { +template class IteratorNode final { public: - MWNode *node; - IteratorNode *next; + MWNode *node; + IteratorNode *next; bool doneNode; bool doneParent; bool doneChild[1 << D]; - IteratorNode(MWNode *nd, IteratorNode *nx = nullptr); + IteratorNode(MWNode *nd, IteratorNode *nx = nullptr); ~IteratorNode() { delete this->next; } }; diff --git a/src/utils/Bank.cpp b/src/utils/Bank.cpp index a774c44ff..c00338a9c 100644 --- a/src/utils/Bank.cpp +++ b/src/utils/Bank.cpp @@ -385,7 +385,6 @@ void Bank::open() { deposits[ix].source = status.MPI_SOURCE; if (message == SAVE_FUNCTION) { recv_function(*deposits[ix].orb, deposits[ix].source, 1, comm_bank); - cout<<"recv ORB size "<getSizeNodes(NUMBER::Total)<getSizeNodes(NUMBER::Total); totcurrentsize += deposits[ix].orb->getSizeNodes(NUMBER::Total); @@ -721,6 +720,23 @@ int BankAccount::put_data(int id, int size, double *data) { return 1; } +// save data in Bank with identity id . datasize MUST have been set already. NB:not tested +int BankAccount::put_data(int id, int size, ComplexDouble *data) { +#ifdef MRCPP_HAS_MPI + // for now we distribute according to id + int messages[message_size]; + + messages[0] = SAVE_DATA; + messages[1] = account_id; + messages[2] = id; + messages[3] = size * 2;//save as twice as many doubles + messages[4] = MIN_SCALE; // to indicate that it is defined by id + MPI_Send(messages, 5, MPI_INT, bankmaster[id % bank_size], 0, comm_bank); + MPI_Send(data, size, MPI_DOUBLE, bankmaster[id % bank_size], 1, comm_bank); +#endif + return 1; +} + // save data in Bank with identity nIdx. datasize MUST have been set already. NB:not tested int BankAccount::put_data(NodeIndex<3> nIdx, int size, double *data) { #ifdef MRCPP_HAS_MPI @@ -740,6 +756,26 @@ int BankAccount::put_data(NodeIndex<3> nIdx, int size, double *data) { return 1; } +// save data in Bank with identity nIdx. datasize MUST have been set already. NB:not tested +int BankAccount::put_data(NodeIndex<3> nIdx, int size, ComplexDouble *data) { +#ifdef MRCPP_HAS_MPI + // for now we distribute according to id + int messages[message_size]; + messages[0] = SAVE_DATA; + messages[1] = account_id; + messages[2] = nIdx.getTranslation(0); + messages[3] = size * 2; //save as twice as many doubles + messages[4] = nIdx.getScale(); + messages[5] = nIdx.getTranslation(1); + messages[6] = nIdx.getTranslation(2); + int id = std::abs(nIdx.getTranslation(0) + nIdx.getTranslation(1) + nIdx.getTranslation(2)); + MPI_Send(messages, 7, MPI_INT, bankmaster[id % bank_size], 0, comm_bank); + MPI_Send(data, size, MPI_DOUBLE, bankmaster[id % bank_size], 1, comm_bank); +#endif + return 1; +} + + // get data with identity id int BankAccount::get_data(int id, int size, double *data) { #ifdef MRCPP_HAS_MPI @@ -755,6 +791,23 @@ int BankAccount::get_data(int id, int size, double *data) { return 1; } + +// get data with identity id +int BankAccount::get_data(int id, int size, ComplexDouble *data) { +#ifdef MRCPP_HAS_MPI + MPI_Status status; + int messages[message_size]; + messages[0] = GET_DATA; + messages[1] = account_id; + messages[2] = id; + messages[3] = MIN_SCALE; + MPI_Send(messages, 4, MPI_INT, bankmaster[id % bank_size], 0, comm_bank); + //fetch as twice as many doubles + MPI_Recv(data, size*2, MPI_DOUBLE, bankmaster[id % bank_size], 1, comm_bank, &status); +#endif + return 1; +} + // get data with identity id int BankAccount::get_data(NodeIndex<3> nIdx, int size, double *data) { #ifdef MRCPP_HAS_MPI @@ -774,6 +827,26 @@ int BankAccount::get_data(NodeIndex<3> nIdx, int size, double *data) { return 1; } +// get data with identity id +int BankAccount::get_data(NodeIndex<3> nIdx, int size, ComplexDouble *data) { +#ifdef MRCPP_HAS_MPI + MPI_Status status; + int messages[message_size]; + int id = std::abs(nIdx.getTranslation(0) + nIdx.getTranslation(1) + nIdx.getTranslation(2)); + messages[0] = GET_DATA; + messages[1] = account_id; + messages[2] = id; + messages[3] = nIdx.getScale(); + messages[4] = nIdx.getTranslation(0); + messages[5] = nIdx.getTranslation(1); + messages[6] = nIdx.getTranslation(2); + MPI_Send(messages, 7, MPI_INT, bankmaster[id % bank_size], 0, comm_bank); + //fetch as twice as many doubles + MPI_Recv(data, size*2, MPI_DOUBLE, bankmaster[id % bank_size], 1, comm_bank, &status); +#endif + return 1; +} + // save data in Bank with identity id as part of block with identity nodeid. int BankAccount::put_nodedata(int id, int nodeid, int size, double *data) { #ifdef MRCPP_HAS_MPI diff --git a/src/utils/Bank.h b/src/utils/Bank.h index 501faa7a0..7293d73ab 100644 --- a/src/utils/Bank.h +++ b/src/utils/Bank.h @@ -100,9 +100,13 @@ class BankAccount { int put_func(int id, ComplexFunction &func); int get_func(int id, ComplexFunction &func, int wait = 0); int put_data(int id, int size, double *data); + int put_data(int id, int size, ComplexDouble *data); int get_data(int id, int size, double *data); + int get_data(int id, int size, ComplexDouble *data); int put_data(NodeIndex<3> nIdx, int size, double *data); + int put_data(NodeIndex<3> nIdx, int size, ComplexDouble *data); int get_data(NodeIndex<3> nIdx, int size, double *data); + int get_data(NodeIndex<3> nIdx, int size, ComplexDouble *data); int put_nodedata(int id, int nodeid, int size, double *data); int get_nodedata(int id, int nodeid, int size, double *data, std::vector &idVec); int get_nodeblock(int nodeid, double *data, std::vector &idVec); diff --git a/src/utils/ComplexFunction.h b/src/utils/ComplexFunction.h index cf33cadfb..c43d3475c 100644 --- a/src/utils/ComplexFunction.h +++ b/src/utils/ComplexFunction.h @@ -22,7 +22,7 @@ class MPI_FuncVector; namespace mrcpp { class BankAccount; -template class FunctionTree; + template class FunctionTree; template class MultiResolutionAnalysis; using ComplexDouble = std::complex; @@ -58,8 +58,8 @@ class TreePtr final { if (this->func_data.is_shared and mpi::share_size > 1) { // Memory size in MB defined in input. Virtual memory, does not cost anything if not used. #ifdef MRCPP_HAS_MPI - this->shared_mem_re = new mrcpp::SharedMemory(mpi::comm_share, mpi::shared_memory_size); - this->shared_mem_im = new mrcpp::SharedMemory(mpi::comm_share, mpi::shared_memory_size); + this->shared_mem_re = new mrcpp::SharedMemory(mpi::comm_share, mpi::shared_memory_size); + this->shared_mem_im = new mrcpp::SharedMemory(mpi::comm_share, mpi::shared_memory_size); #endif } } @@ -75,10 +75,10 @@ class TreePtr final { private: FunctionData func_data; - mrcpp::SharedMemory *shared_mem_re; - mrcpp::SharedMemory *shared_mem_im; - mrcpp::FunctionTree<3> *re; ///< Real part of function - mrcpp::FunctionTree<3> *im; ///< Imaginary part of function + mrcpp::SharedMemory *shared_mem_re; + mrcpp::SharedMemory *shared_mem_im; + mrcpp::FunctionTree<3, double> *re; ///< Real part of function + mrcpp::FunctionTree<3, double> *im; ///< Imaginary part of function void flushFuncData() { this->func_data.real_size = 0; @@ -121,10 +121,10 @@ class ComplexFunction { FunctionData &getFunctionData(); int occ() const { return this->func_ptr->func_data.occ; } int spin() const { return this->func_ptr->func_data.spin; } - FunctionTree<3> &real() { return *this->func_ptr->re; } - FunctionTree<3> &imag() { return *this->func_ptr->im; } - const FunctionTree<3> &real() const { return *this->func_ptr->re; } - const FunctionTree<3> &imag() const { return *this->func_ptr->im; } + FunctionTree<3, double> &real() { return *this->func_ptr->re; } + FunctionTree<3, double> &imag() { return *this->func_ptr->im; } + const FunctionTree<3, double> &real() const { return *this->func_ptr->re; } + const FunctionTree<3, double> &imag() const { return *this->func_ptr->im; } void release() { this->func_ptr.reset(); } bool conjugate() const { return this->conj; } MultiResolutionAnalysis<3> *funcMRA = nullptr; @@ -141,8 +141,8 @@ class ComplexFunction { int getSizeNodes(int type) const; int getNNodes(int type) const; - void setReal(mrcpp::FunctionTree<3> *tree); - void setImag(mrcpp::FunctionTree<3> *tree); + void setReal(mrcpp::FunctionTree<3, double> *tree); + void setImag(mrcpp::FunctionTree<3, double> *tree); double norm() const; double squaredNorm() const; @@ -172,8 +172,8 @@ void project(ComplexFunction &out, RepresentableFunction<3> &f, int type, double void multiply(ComplexFunction &out, ComplexFunction inp_a, ComplexFunction inp_b, double prec, bool absPrec = false, bool useMaxNorms = false); void multiply_real(ComplexFunction &out, ComplexFunction inp_a, ComplexFunction inp_b, double prec, bool absPrec = false, bool useMaxNorms = false); void multiply_imag(ComplexFunction &out, ComplexFunction inp_a, ComplexFunction inp_b, double prec, bool absPrec = false, bool useMaxNorms = false); -void multiply(ComplexFunction &out, ComplexFunction &inp_a, RepresentableFunction<3> &f, double prec, int nrefine = 0); -void multiply(ComplexFunction &out, FunctionTree<3> &inp_a, RepresentableFunction<3> &f, double prec, int nrefine = 0); +void multiply(ComplexFunction &out, ComplexFunction &inp_a, RepresentableFunction<3, double> &f, double prec, int nrefine = 0); +void multiply(ComplexFunction &out, FunctionTree<3, double> &inp_a, RepresentableFunction<3, double> &f, double prec, int nrefine = 0); void linear_combination(ComplexFunction &out, const ComplexVector &c, std::vector &inp, double prec); } // namespace cplxfunc @@ -187,7 +187,7 @@ class MPI_FuncVector : public std::vector { namespace mpifuncvec { void rotate(MPI_FuncVector &Phi, const ComplexMatrix &U, double prec = -1.0); void rotate(MPI_FuncVector &Phi, const ComplexMatrix &U, MPI_FuncVector &Psi, double prec = -1.0); -void save_nodes(MPI_FuncVector &Phi, mrcpp::FunctionTree<3> &refTree, BankAccount &account, int sizes = -1); +void save_nodes(MPI_FuncVector &Phi, mrcpp::FunctionTree<3, double> &refTree, BankAccount &account, int sizes = -1); MPI_FuncVector multiply(MPI_FuncVector &Phi, RepresentableFunction<3> &f, double prec = -1.0, ComplexFunction *Func = nullptr, int nrefine = 1, bool all = false); ComplexVector dot(MPI_FuncVector &Bra, MPI_FuncVector &Ket); ComplexMatrix calc_lowdin_matrix(MPI_FuncVector &Phi); diff --git a/src/utils/Plotter.cpp b/src/utils/Plotter.cpp index 455bb57e6..b24f2a643 100644 --- a/src/utils/Plotter.cpp +++ b/src/utils/Plotter.cpp @@ -37,24 +37,24 @@ namespace mrcpp { * * @param[in] o: Plot origin, default `(0, 0, ... , 0)` */ -template -Plotter::Plotter(const Coord &o) +template +Plotter::Plotter(const Coord &o) : O(o) { - setSuffix(Plotter::Line, ".line"); - setSuffix(Plotter::Surface, ".surf"); - setSuffix(Plotter::Cube, ".cube"); - setSuffix(Plotter::Grid, ".grid"); + setSuffix(Plotter::Line, ".line"); + setSuffix(Plotter::Surface, ".surf"); + setSuffix(Plotter::Cube, ".cube"); + setSuffix(Plotter::Grid, ".grid"); } /** @brief Set file extension for output file * - * @param[in] t: Plot type (`Plotter::Line`, `::Surface`, `::Cube`, `::Grid`) + * @param[in] t: Plot type (`Plotter::Line`, `::Surface`, `::Cube`, `::Grid`) * @param[in] s: Extension string, default `.line`, `.surf`, `.cube`, `.grid` * * @details The file name you decide for the output will get a predefined * suffix that differentiates between different types of plot. */ -template void Plotter::setSuffix(int t, const std::string &s) { +template void Plotter::setSuffix(int t, const std::string &s) { this->suffix.insert(std::pair(t, s)); } @@ -62,7 +62,7 @@ template void Plotter::setSuffix(int t, const std::string &s) { * * @param[in] o: Plot origin, default `(0, 0, ... , 0)` */ -template void Plotter::setOrigin(const Coord &o) { +template void Plotter::setOrigin(const Coord &o) { this->O = o; } @@ -72,7 +72,7 @@ template void Plotter::setOrigin(const Coord &o) { * @param[in] b: B vector * @param[in] c: C vector */ -template void Plotter::setRange(const Coord &a, const Coord &b, const Coord &c) { +template void Plotter::setRange(const Coord &a, const Coord &b, const Coord &c) { this->A = a; this->B = b; this->C = c; @@ -89,10 +89,10 @@ template void Plotter::setRange(const Coord &a, const Coord &b, * separate file, and will print only nodes owned by itself (pluss the * rootNodes). */ -template void Plotter::gridPlot(const MWTree &tree, const std::string &fname) { +template void Plotter::gridPlot(const MWTree &tree, const std::string &fname) { println(20, "----------Grid Plot-----------"); std::stringstream file; - file << fname << this->suffix[Plotter::Grid]; + file << fname << this->suffix[Plotter::Grid]; openPlot(file.str()); writeGrid(tree); closePlot(); @@ -109,16 +109,16 @@ template void Plotter::gridPlot(const MWTree &tree, const std::str * vector A starting from the origin O to a file named fname + file extension * (".line" as default). */ -template -void Plotter::linePlot(const std::array &npts, - const RepresentableFunction &func, +template +void Plotter::linePlot(const std::array &npts, + const RepresentableFunction &func, const std::string &fname) { println(20, "----------Line Plot-----------"); std::stringstream file; - file << fname << this->suffix[Plotter::Line]; + file << fname << this->suffix[Plotter::Line]; if (verifyRange(1)) { // Verifies only A vector Eigen::MatrixXd coords = calcLineCoordinates(npts[0]); - Eigen::VectorXd values = evaluateFunction(func, coords); + Eigen::Matrix< T, Eigen::Dynamic, 1 > values = evaluateFunction(func, coords); openPlot(file.str()); writeData(coords, values); closePlot(); @@ -138,16 +138,16 @@ void Plotter::linePlot(const std::array &npts, * vectors A (npts[0] points) and B (npts[1] points), starting from the * origin O, to a file named fname + file extension (".surf" as default). */ -template -void Plotter::surfPlot(const std::array &npts, - const RepresentableFunction &func, +template +void Plotter::surfPlot(const std::array &npts, + const RepresentableFunction &func, const std::string &fname) { println(20, "--------Surface Plot----------"); std::stringstream file; - file << fname << this->suffix[Plotter::Surface]; + file << fname << this->suffix[Plotter::Surface]; if (verifyRange(2)) { // Verifies A and B vectors Eigen::MatrixXd coords = calcSurfCoordinates(npts[0], npts[1]); - Eigen::VectorXd values = evaluateFunction(func, coords); + Eigen::Matrix< T, Eigen::Dynamic, 1 > values = evaluateFunction(func, coords); openPlot(file.str()); writeData(coords, values); closePlot(); @@ -168,16 +168,16 @@ void Plotter::surfPlot(const std::array &npts, * starting from the origin O, to a file named fname + file extension * (".cube" as default). */ -template -void Plotter::cubePlot(const std::array &npts, - const RepresentableFunction &func, +template +void Plotter::cubePlot(const std::array &npts, + const RepresentableFunction &func, const std::string &fname) { println(20, "----------Cube Plot-----------"); std::stringstream file; - file << fname << this->suffix[Plotter::Cube]; + file << fname << this->suffix[Plotter::Cube]; if (verifyRange(3)) { // Verifies A, B and C vectors Eigen::MatrixXd coords = calcCubeCoordinates(npts[0], npts[1], npts[2]); - Eigen::VectorXd values = evaluateFunction(func, coords); + Eigen::Matrix< T, Eigen::Dynamic, 1 > values = evaluateFunction(func, coords); openPlot(file.str()); writeCube(npts, values); closePlot(); @@ -192,7 +192,7 @@ void Plotter::cubePlot(const std::array &npts, * @details Generating a vector of pts_a equidistant coordinates that makes * up the vector A in D dimensions, starting from the origin O. */ -template Eigen::MatrixXd Plotter::calcLineCoordinates(int pts_a) const { +template Eigen::MatrixXd Plotter::calcLineCoordinates(int pts_a) const { MatrixXd coords; if (pts_a > 0) { Coord a = calcStep(this->A, pts_a); @@ -211,7 +211,7 @@ template Eigen::MatrixXd Plotter::calcLineCoordinates(int pts_a) cons * @details Generating a vector of equidistant coordinates that makes up the * area spanned by vectors A and B in D dimensions, starting from the origin O. */ -template Eigen::MatrixXd Plotter::calcSurfCoordinates(int pts_a, int pts_b) const { +template Eigen::MatrixXd Plotter::calcSurfCoordinates(int pts_a, int pts_b) const { if (D < 2) MSG_ERROR("Cannot surfPlot less than 2D"); MatrixXd coords; @@ -240,7 +240,7 @@ template Eigen::MatrixXd Plotter::calcSurfCoordinates(int pts_a, int * volume spanned by vectors A, B and C in D dimensions, starting from * the origin O. */ -template Eigen::MatrixXd Plotter::calcCubeCoordinates(int pts_a, int pts_b, int pts_c) const { +template Eigen::MatrixXd Plotter::calcCubeCoordinates(int pts_a, int pts_b, int pts_c) const { if (D < 3) MSG_ERROR("Cannot cubePlot less than 3D function"); MatrixXd coords; @@ -272,12 +272,12 @@ template Eigen::MatrixXd Plotter::calcCubeCoordinates(int pts_a, int * this routine evaluates the function in these points and stores the results * in the vector "values". */ -template -Eigen::VectorXd Plotter::evaluateFunction(const RepresentableFunction &func, +template +Eigen::Matrix< T, Eigen::Dynamic, 1 > Plotter::evaluateFunction(const RepresentableFunction &func, const Eigen::MatrixXd &coords) const { auto npts = coords.rows(); if (npts == 0) MSG_ERROR("Empty coordinates"); - Eigen::VectorXd values = VectorXd::Zero(npts); + Eigen::Matrix< T, Eigen::Dynamic, 1 > values = Eigen::Matrix< T, Eigen::Dynamic, 1 >::Zero(npts); #pragma omp parallel for schedule(static) num_threads(mrcpp_get_num_threads()) for (auto i = 0; i < npts; i++) { Coord r{}; @@ -294,7 +294,7 @@ Eigen::VectorXd Plotter::evaluateFunction(const RepresentableFunction &fun * point number (between 0 and nPoints), coordinates 1 through D and the * function value. */ -template void Plotter::writeData(const Eigen::MatrixXd &coords, const Eigen::VectorXd &values) { + template void Plotter::writeData(const Eigen::MatrixXd &coords, const Eigen::Matrix< T, Eigen::Dynamic, 1 > &values) { if (coords.rows() != values.size()) INVALID_ARG_ABORT; std::ofstream &o = *this->fout; for (auto i = 0; i < values.size(); i++) { @@ -308,17 +308,17 @@ template void Plotter::writeData(const Eigen::MatrixXd &coords, const } // Specialized for D=3 below -template void Plotter::writeCube(const std::array &npts, const Eigen::VectorXd &values) { + template void Plotter::writeCube(const std::array &npts, const Eigen::Matrix< T, Eigen::Dynamic, 1 > &values) { NOT_IMPLEMENTED_ABORT } // Specialized for D=3 below -template void Plotter::writeNodeGrid(const MWNode &node, const std::string &color) { +template void Plotter::writeNodeGrid(const MWNode &node, const std::string &color) { NOT_IMPLEMENTED_ABORT } // Specialized for D=3 below -template void Plotter::writeGrid(const MWTree &tree) { +template void Plotter::writeGrid(const MWTree &tree) { NOT_IMPLEMENTED_ABORT } @@ -326,7 +326,7 @@ template void Plotter::writeGrid(const MWTree &tree) { * * @details Opens a file output stream fout for file named fname. */ -template void Plotter::openPlot(const std::string &fname) { +template void Plotter::openPlot(const std::string &fname) { if (fname.empty()) { if (this->fout == nullptr) { MSG_ERROR("Plot file not set!"); @@ -350,7 +350,7 @@ template void Plotter::openPlot(const std::string &fname) { * * @details Closes the file output stream fout. */ -template void Plotter::closePlot() { +template void Plotter::closePlot() { if (this->fout != nullptr) this->fout->close(); this->fout = nullptr; } @@ -462,7 +462,7 @@ template <> void Plotter<3>::writeGrid(const MWTree<3> &tree) { } /** @brief Checks the validity of the plotting range */ -template bool Plotter::verifyRange(int dim) const { +template bool Plotter::verifyRange(int dim) const { auto is_len_zero = [](Coord vec) { double vec_sq = 0.0; @@ -483,14 +483,18 @@ template bool Plotter::verifyRange(int dim) const { } /** @brief Compute step length to cover vector with `pts` points, including edges */ -template Coord Plotter::calcStep(const Coord &vec, int pts) const { +template Coord Plotter::calcStep(const Coord &vec, int pts) const { Coord step; for (auto d = 0; d < D; d++) step[d] = vec[d] / (pts - 1.0); return step; } -template class Plotter<1>; -template class Plotter<2>; -template class Plotter<3>; +template class Plotter<1, double>; +template class Plotter<2, double>; +template class Plotter<3, double>; + +template class Plotter<1, ComplexDouble>; +template class Plotter<2, ComplexDouble>; +template class Plotter<3, ComplexDouble>; } // namespace mrcpp diff --git a/src/utils/Plotter.h b/src/utils/Plotter.h index d38941b27..547150197 100644 --- a/src/utils/Plotter.h +++ b/src/utils/Plotter.h @@ -56,7 +56,7 @@ namespace mrcpp { * */ -template class Plotter { +template class Plotter { public: explicit Plotter(const Coord &o = {}); virtual ~Plotter() = default; @@ -65,10 +65,10 @@ template class Plotter { void setOrigin(const Coord &o); void setRange(const Coord &a, const Coord &b = {}, const Coord &c = {}); - void gridPlot(const MWTree &tree, const std::string &fname); - void linePlot(const std::array &npts, const RepresentableFunction &func, const std::string &fname); - void surfPlot(const std::array &npts, const RepresentableFunction &func, const std::string &fname); - void cubePlot(const std::array &npts, const RepresentableFunction &func, const std::string &fname); + void gridPlot(const MWTree &tree, const std::string &fname); + void linePlot(const std::array &npts, const RepresentableFunction &func, const std::string &fname); + void surfPlot(const std::array &npts, const RepresentableFunction &func, const std::string &fname); + void cubePlot(const std::array &npts, const RepresentableFunction &func, const std::string &fname); enum type { Line, Surface, Cube, Grid }; @@ -86,13 +86,13 @@ template class Plotter { Eigen::MatrixXd calcSurfCoordinates(int pts_a, int pts_b) const; Eigen::MatrixXd calcCubeCoordinates(int pts_a, int pts_b, int pts_c) const; - Eigen::VectorXd evaluateFunction(const RepresentableFunction &func, const Eigen::MatrixXd &coords) const; + Eigen::Matrix< T, Eigen::Dynamic, 1 > evaluateFunction(const RepresentableFunction &func, const Eigen::MatrixXd &coords) const; - void writeData(const Eigen::MatrixXd &coords, const Eigen::VectorXd &values); - virtual void writeCube(const std::array &npts, const Eigen::VectorXd &values); + void writeData(const Eigen::MatrixXd &coords, const Eigen::Matrix< T, Eigen::Dynamic, 1 > &values); + virtual void writeCube(const std::array &npts, const Eigen::Matrix< T, Eigen::Dynamic, 1 > &values); - void writeGrid(const MWTree &tree); - void writeNodeGrid(const MWNode &node, const std::string &color); + void writeGrid(const MWTree &tree); + void writeNodeGrid(const MWNode &node, const std::string &color); private: bool verifyRange(int dim) const; diff --git a/src/utils/Printer.cpp b/src/utils/Printer.cpp index d9d04f4bd..957d7322b 100644 --- a/src/utils/Printer.cpp +++ b/src/utils/Printer.cpp @@ -265,7 +265,7 @@ void print::tree(int level, const std::string &txt, int n, int m, double t) { * @param[in] tree: Tree to be printed * @param[in] timer: Timer to be evaluated */ -template void print::tree(int level, const std::string &txt, const MWTree &tree, const Timer &timer) { + template void print::tree(int level, const std::string &txt, const MWTree &tree, const Timer &timer) { if (level > Printer::getPrintLevel()) return; auto n = tree.getNNodes(); diff --git a/src/utils/Printer.h b/src/utils/Printer.h index dc4935aa8..c021155e8 100644 --- a/src/utils/Printer.h +++ b/src/utils/Printer.h @@ -39,7 +39,7 @@ namespace mrcpp { class Timer; -template class MWTree; +template class MWTree; /** @class Printer * @@ -128,7 +128,7 @@ void memory(int level, const std::string &txt); void value(int level, const std::string &txt, double v, const std::string &unit = "", int p = -1, bool sci = true); void time(int level, const std::string &txt, const Timer &timer); void tree(int level, const std::string &txt, int n, int m, double t); -template void tree(int level, const std::string &txt, const MWTree &tree, const Timer &timer); +template void tree(int level, const std::string &txt, const MWTree &tree, const Timer &timer); } // namespace print // clang-format off diff --git a/src/utils/math_utils.cpp b/src/utils/math_utils.cpp index 8506be298..5ee9294f6 100644 --- a/src/utils/math_utils.cpp +++ b/src/utils/math_utils.cpp @@ -185,6 +185,20 @@ void math_utils::apply_filter(double *out, double *in, const MatrixXd &filter, i #endif } +void math_utils::apply_filter(ComplexDouble *out, ComplexDouble *in, const MatrixXd &filter, int kp1, int kp1_dm1, double fac) { + //#ifdef HAVE_BLAS +// cblas_dgemm(CblasColMajor, CblasTrans, CblasNoTrans, kp1_dm1, kp1, kp1, 1.0, in, kp1, filter.data(), kp1, fac, out, kp1_dm1); +//#else + Map f(in, kp1, kp1_dm1); + Map g(out, kp1_dm1, kp1); + if (fac < MachineZero) { + g.noalias() = f.transpose() * filter; + } else { + g.noalias() += f.transpose() * filter; + } +//#endif +} + /** Make a nD-representation from 1D-representations of separable functions. * * This method uses the "output" vector as initial input, in order to diff --git a/src/utils/math_utils.h b/src/utils/math_utils.h index 9c371aa51..9dcdb6956 100644 --- a/src/utils/math_utils.h +++ b/src/utils/math_utils.h @@ -67,6 +67,7 @@ double matrix_norm_1(const Eigen::MatrixXd &M); double matrix_norm_2(const Eigen::MatrixXd &M); void apply_filter(double *out, double *in, const Eigen::MatrixXd &filter, int kp1, int kp1_dm1, double fac); +void apply_filter(ComplexDouble *out, ComplexDouble *in, const Eigen::MatrixXd &filter, int kp1, int kp1_dm1, double fac); void tensor_expand_coefs(int dim, int dir, int kp1, int kp1_d, const Eigen::MatrixXd &primitive, Eigen::VectorXd &expanded); diff --git a/src/utils/mpi_utils.cpp b/src/utils/mpi_utils.cpp index d61f2bd23..e193aea3a 100644 --- a/src/utils/mpi_utils.cpp +++ b/src/utils/mpi_utils.cpp @@ -36,7 +36,7 @@ namespace mrcpp { * @param[in] comm: Communicator sharing resources * @param[in] sh_size: Memory size, in MB */ -SharedMemory::SharedMemory(mrcpp::mpi_comm comm, int sh_size) +template SharedMemory::SharedMemory(mrcpp::mpi_comm comm, int sh_size) : sh_start_ptr(nullptr) , sh_end_ptr(nullptr) , sh_max_ptr(nullptr) @@ -57,18 +57,18 @@ SharedMemory::SharedMemory(mrcpp::mpi_comm comm, int sh_size) int qdisp = 0; MPI_Win_shared_query(this->sh_win, 0, &qsize, &qdisp, &this->sh_start_ptr); MPI_Win_fence(0, this->sh_win); - this->sh_max_ptr = this->sh_start_ptr + qsize / sizeof(double); + this->sh_max_ptr = this->sh_start_ptr + qsize / sizeof(T); this->sh_end_ptr = this->sh_start_ptr; #endif } -void SharedMemory::clear() { +template void SharedMemory::clear() { #ifdef MRCPP_HAS_MPI this->sh_end_ptr = this->sh_start_ptr; #endif } -SharedMemory::~SharedMemory() { +template SharedMemory::~SharedMemory() { #ifdef MRCPP_HAS_MPI // deallocates the memory block MPI_Win_free(&this->sh_win); @@ -88,7 +88,7 @@ SharedMemory::~SharedMemory() { * to speed up communication, otherwise it will be communicated in a separate * step before the main communication. */ -template void send_tree(FunctionTree &tree, int dst, int tag, mrcpp::mpi_comm comm, int nChunks, bool coeff) { +template void send_tree(FunctionTree &tree, int dst, int tag, mrcpp::mpi_comm comm, int nChunks, bool coeff) { #ifdef MRCPP_HAS_MPI auto &allocator = tree.getNodeAllocator(); @@ -121,7 +121,7 @@ template void send_tree(FunctionTree &tree, int dst, int tag, mrcpp:: * to speed up communication, otherwise it will be communicated in a separate * step before the main communication. */ -template void recv_tree(FunctionTree &tree, int src, int tag, mrcpp::mpi_comm comm, int nChunks, bool coeff) { +template void recv_tree(FunctionTree &tree, int src, int tag, mrcpp::mpi_comm comm, int nChunks, bool coeff) { #ifdef MRCPP_HAS_MPI MPI_Status status; auto &allocator = tree.getNodeAllocator(); @@ -157,7 +157,7 @@ template void recv_tree(FunctionTree &tree, int src, int tag, mrcpp:: * @details This function should be called every time a shared function is * updated, in order to update the local memory of each MPI process. */ -template void share_tree(FunctionTree &tree, int src, int tag, mrcpp::mpi_comm comm) { +template void share_tree(FunctionTree &tree, int src, int tag, mrcpp::mpi_comm comm) { #ifdef MRCPP_HAS_MPI Timer t1; auto &allocator = tree.getNodeAllocator(); @@ -197,15 +197,27 @@ template void share_tree(FunctionTree &tree, int src, int tag, mrcpp: println(10, " Time share " << std::setw(30) << t1.elapsed()); #endif } - -template void send_tree<1>(FunctionTree<1> &tree, int dst, int tag, mrcpp::mpi_comm comm, int nChunks, bool coeff); -template void send_tree<2>(FunctionTree<2> &tree, int dst, int tag, mrcpp::mpi_comm comm, int nChunks, bool coeff); -template void send_tree<3>(FunctionTree<3> &tree, int dst, int tag, mrcpp::mpi_comm comm, int nChunks, bool coeff); -template void recv_tree<1>(FunctionTree<1> &tree, int src, int tag, mrcpp::mpi_comm comm, int nChunks, bool coeff); -template void recv_tree<2>(FunctionTree<2> &tree, int src, int tag, mrcpp::mpi_comm comm, int nChunks, bool coeff); -template void recv_tree<3>(FunctionTree<3> &tree, int src, int tag, mrcpp::mpi_comm comm, int nChunks, bool coeff); -template void share_tree<1>(FunctionTree<1> &tree, int src, int tag, mrcpp::mpi_comm comm); -template void share_tree<2>(FunctionTree<2> &tree, int src, int tag, mrcpp::mpi_comm comm); -template void share_tree<3>(FunctionTree<3> &tree, int src, int tag, mrcpp::mpi_comm comm); +template class SharedMemory; +template class SharedMemory; + +template void send_tree<1>(FunctionTree<1, double> &tree, int dst, int tag, mrcpp::mpi_comm comm, int nChunks, bool coeff); +template void send_tree<2>(FunctionTree<2, double> &tree, int dst, int tag, mrcpp::mpi_comm comm, int nChunks, bool coeff); +template void send_tree<3>(FunctionTree<3, double> &tree, int dst, int tag, mrcpp::mpi_comm comm, int nChunks, bool coeff); +template void recv_tree<1>(FunctionTree<1, double> &tree, int src, int tag, mrcpp::mpi_comm comm, int nChunks, bool coeff); +template void recv_tree<2>(FunctionTree<2, double> &tree, int src, int tag, mrcpp::mpi_comm comm, int nChunks, bool coeff); +template void recv_tree<3>(FunctionTree<3, double> &tree, int src, int tag, mrcpp::mpi_comm comm, int nChunks, bool coeff); +template void share_tree<1>(FunctionTree<1, double> &tree, int src, int tag, mrcpp::mpi_comm comm); +template void share_tree<2>(FunctionTree<2, double> &tree, int src, int tag, mrcpp::mpi_comm comm); +template void share_tree<3>(FunctionTree<3, double> &tree, int src, int tag, mrcpp::mpi_comm comm); + +template void send_tree<1>(FunctionTree<1, ComplexDouble> &tree, int dst, int tag, mrcpp::mpi_comm comm, int nChunks, bool coeff); +template void send_tree<2>(FunctionTree<2, ComplexDouble> &tree, int dst, int tag, mrcpp::mpi_comm comm, int nChunks, bool coeff); +template void send_tree<3>(FunctionTree<3, ComplexDouble> &tree, int dst, int tag, mrcpp::mpi_comm comm, int nChunks, bool coeff); +template void recv_tree<1>(FunctionTree<1, ComplexDouble> &tree, int src, int tag, mrcpp::mpi_comm comm, int nChunks, bool coeff); +template void recv_tree<2>(FunctionTree<2, ComplexDouble> &tree, int src, int tag, mrcpp::mpi_comm comm, int nChunks, bool coeff); +template void recv_tree<3>(FunctionTree<3, ComplexDouble> &tree, int src, int tag, mrcpp::mpi_comm comm, int nChunks, bool coeff); +template void share_tree<1>(FunctionTree<1, ComplexDouble> &tree, int src, int tag, mrcpp::mpi_comm comm); +template void share_tree<2>(FunctionTree<2, ComplexDouble> &tree, int src, int tag, mrcpp::mpi_comm comm); +template void share_tree<3>(FunctionTree<3, ComplexDouble> &tree, int src, int tag, mrcpp::mpi_comm comm); } // namespace mrcpp diff --git a/src/utils/mpi_utils.h b/src/utils/mpi_utils.h index 1b94b0dc9..93211fd1a 100644 --- a/src/utils/mpi_utils.h +++ b/src/utils/mpi_utils.h @@ -74,26 +74,26 @@ namespace mrcpp { * communicator. In order to allocate a FunctionTree in shared memory, * simply pass a SharedMemory object to the FunctionTree constructor. */ -class SharedMemory { +template class SharedMemory { public: SharedMemory(mrcpp::mpi_comm comm, int sh_size); SharedMemory(const SharedMemory &mem) = delete; - SharedMemory &operator=(const SharedMemory &mem) = delete; + SharedMemory &operator=(const SharedMemory &mem) = delete; ~SharedMemory(); void clear(); // show shared memory as entirely available - double *sh_start_ptr; // start of shared block - double *sh_end_ptr; // end of used part - double *sh_max_ptr; // end of shared block + T *sh_start_ptr; // start of shared block + T *sh_end_ptr; // end of used part + T *sh_max_ptr; // end of shared block mrcpp::mpi_win sh_win; // MPI window object int rank; // rank among shared group }; -template class FunctionTree; +template class FunctionTree; -template void send_tree(FunctionTree &tree, int dst, int tag, mrcpp::mpi_comm comm, int nChunks = -1, bool coeff = true); -template void recv_tree(FunctionTree &tree, int src, int tag, mrcpp::mpi_comm comm, int nChunks = -1, bool coeff = true); -template void share_tree(FunctionTree &tree, int src, int tag, mrcpp::mpi_comm comm); +template void send_tree(FunctionTree &tree, int dst, int tag, mrcpp::mpi_comm comm, int nChunks = -1, bool coeff = true); +template void recv_tree(FunctionTree &tree, int src, int tag, mrcpp::mpi_comm comm, int nChunks = -1, bool coeff = true); +template void share_tree(FunctionTree &tree, int src, int tag, mrcpp::mpi_comm comm); } // namespace mrcpp diff --git a/src/utils/parallel.cpp b/src/utils/parallel.cpp index 2877d12e9..85afe2426 100644 --- a/src/utils/parallel.cpp +++ b/src/utils/parallel.cpp @@ -406,7 +406,7 @@ void mpi::reduce_function(double prec, ComplexFunction &func, MPI_Comm comm) { } /** @brief make union tree and send into rank zero */ -void mpi::reduce_Tree_noCoeff(mrcpp::FunctionTree<3> &tree, MPI_Comm comm) { +void mpi::reduce_Tree_noCoeff(mrcpp::FunctionTree<3, double> &tree, MPI_Comm comm) { /* 1) Each odd rank send to the left rank 2) All odd ranks are "deleted" (can exit routine) 3) new "effective" ranks are defined within the non-deleted ranks @@ -446,10 +446,51 @@ void mpi::reduce_Tree_noCoeff(mrcpp::FunctionTree<3> &tree, MPI_Comm comm) { #endif } +/** @brief make union tree and send into rank zero */ +void mpi::reduce_Tree_noCoeff(mrcpp::FunctionTree<3, ComplexDouble> &tree, MPI_Comm comm) { +/* 1) Each odd rank send to the left rank + 2) All odd ranks are "deleted" (can exit routine) + 3) new "effective" ranks are defined within the non-deleted ranks + effective rank = rank/fac , where fac are powers of 2 + 4) repeat + */ +#ifdef MRCPP_HAS_MPI + int comm_size, comm_rank; + MPI_Comm_rank(comm, &comm_rank); + MPI_Comm_size(comm, &comm_size); + if (comm_size == 1) return; + + int fac = 1; // powers of 2 + while (fac < comm_size) { + if ((comm_rank / fac) % 2 == 0) { + // receive + int src = comm_rank + fac; + if (src < comm_size) { + int tag = 3333 + src; + mrcpp::FunctionTree<3, ComplexDouble> tree_i(tree.getMRA()); + mrcpp::recv_tree(tree_i, src, tag, comm, -1, false); + tree.appendTreeNoCoeff(tree_i); // make union grid + } + } + if ((comm_rank / fac) % 2 == 1) { + // send + int dest = comm_rank - fac; + if (dest >= 0) { + int tag = 3333 + comm_rank; + mrcpp::send_tree(tree, dest, tag, comm, -1, false); + break; // once data is sent we are done + } + } + fac *= 2; + } + MPI_Barrier(comm); +#endif +} + /** @brief make union tree without coeff and send to all * Include both real and imaginary parts */ -void mpi::allreduce_Tree_noCoeff(mrcpp::FunctionTree<3> &tree, vector &Phi, MPI_Comm comm) { +void mpi::allreduce_Tree_noCoeff(mrcpp::FunctionTree<3, double> &tree, vector &Phi, MPI_Comm comm) { /* 1) make union grid of own orbitals 2) make union grid with others orbitals (sent to rank zero) 3) rank zero broadcast func to everybody @@ -465,6 +506,25 @@ void mpi::allreduce_Tree_noCoeff(mrcpp::FunctionTree<3> &tree, vector &tree, vector> &Phi, MPI_Comm comm) { + /* 1) make union grid of own orbitals + 2) make union grid with others orbitals (sent to rank zero) + 3) rank zero broadcast func to everybody + */ + + int N = Phi.size(); + for (int j = 0; j < N; j++) { + if (not mpi::my_orb(j)) continue; + tree.appendTreeNoCoeff(Phi[j]); + } + mpi::reduce_Tree_noCoeff(tree, mpi::comm_wrk); + mpi::broadcast_Tree_noCoeff(tree, mpi::comm_wrk); +} + /** @brief Distribute rank zero function to all ranks */ void mpi::broadcast_function(ComplexFunction &func, MPI_Comm comm) { /* use same strategy as a reduce, but in reverse order */ @@ -498,7 +558,39 @@ void mpi::broadcast_function(ComplexFunction &func, MPI_Comm comm) { } /** @brief Distribute rank zero function to all ranks */ -void mpi::broadcast_Tree_noCoeff(mrcpp::FunctionTree<3> &tree, MPI_Comm comm) { +void mpi::broadcast_Tree_noCoeff(mrcpp::FunctionTree<3, double> &tree, MPI_Comm comm) { +/* use same strategy as a reduce, but in reverse order */ +#ifdef MRCPP_HAS_MPI + int comm_size, comm_rank; + MPI_Comm_rank(comm, &comm_rank); + MPI_Comm_size(comm, &comm_size); + if (comm_size == 1) return; + + int fac = 1; // powers of 2 + while (fac < comm_size) fac *= 2; + fac /= 2; + + while (fac > 0) { + if (comm_rank % fac == 0 and (comm_rank / fac) % 2 == 1) { + // receive + int src = comm_rank - fac; + int tag = 4334 + comm_rank; + mrcpp::recv_tree(tree, src, tag, comm, -1, false); + } + if (comm_rank % fac == 0 and (comm_rank / fac) % 2 == 0) { + // send + int dst = comm_rank + fac; + int tag = 4334 + dst; + if (dst < comm_size) mrcpp::send_tree(tree, dst, tag, comm, -1, false); + } + fac /= 2; + } + MPI_Barrier(comm); +#endif +} + +/** @brief Distribute rank zero function to all ranks */ +void mpi::broadcast_Tree_noCoeff(mrcpp::FunctionTree<3, ComplexDouble> &tree, MPI_Comm comm) { /* use same strategy as a reduce, but in reverse order */ #ifdef MRCPP_HAS_MPI int comm_size, comm_rank; diff --git a/src/utils/parallel.h b/src/utils/parallel.h index 78a3e2fd9..50c5ad581 100644 --- a/src/utils/parallel.h +++ b/src/utils/parallel.h @@ -54,9 +54,12 @@ void share_function(ComplexFunction &func, int src, int tag, MPI_Comm comm); void reduce_function(double prec, ComplexFunction &func, MPI_Comm comm); void broadcast_function(ComplexFunction &func, MPI_Comm comm); -void reduce_Tree_noCoeff(mrcpp::FunctionTree<3> &tree, MPI_Comm comm); -void allreduce_Tree_noCoeff(mrcpp::FunctionTree<3> &tree, std::vector &Phi, MPI_Comm comm); -void broadcast_Tree_noCoeff(mrcpp::FunctionTree<3> &tree, MPI_Comm comm); +void reduce_Tree_noCoeff(mrcpp::FunctionTree<3, double> &tree, MPI_Comm comm); +void allreduce_Tree_noCoeff(mrcpp::FunctionTree<3, double> &tree, std::vector &Phi, MPI_Comm comm); +void broadcast_Tree_noCoeff(mrcpp::FunctionTree<3, double> &tree, MPI_Comm comm); +void reduce_Tree_noCoeff(mrcpp::FunctionTree<3, ComplexDouble> &tree, MPI_Comm comm); +void allreduce_Tree_noCoeff(mrcpp::FunctionTree<3, ComplexDouble> &tree, std::vector> &Phi, MPI_Comm comm); +void broadcast_Tree_noCoeff(mrcpp::FunctionTree<3, ComplexDouble> &tree, MPI_Comm comm); void allreduce_vector(IntVector &vec, MPI_Comm comm); void allreduce_vector(DoubleVector &vec, MPI_Comm comm); diff --git a/src/utils/tree_utils.cpp b/src/utils/tree_utils.cpp index 523d3e263..f45fcc158 100644 --- a/src/utils/tree_utils.cpp +++ b/src/utils/tree_utils.cpp @@ -44,7 +44,7 @@ namespace mrcpp { * Calculates the threshold that has to be met in the wavelet norm in order to * guarantee the precision in the function representation. Depends on the * square norm of the function and the requested relative accuracy. */ -template bool tree_utils::split_check(const MWNode &node, double prec, double split_fac, bool abs_prec) { +template bool tree_utils::split_check(const MWNode &node, double prec, double split_fac, bool abs_prec) { bool split = false; if (prec > 0.0) { double t_norm = 1.0; @@ -66,40 +66,40 @@ template bool tree_utils::split_check(const MWNode &node, double prec /** Traverse tree along the Hilbert path and find nodes of any rankId. * Returns one nodeVector for the whole tree. GenNodes disregarded. */ -template void tree_utils::make_node_table(MWTree &tree, MWNodeVector &table) { - TreeIterator it(tree, TopDown, Hilbert); +template void tree_utils::make_node_table(MWTree &tree, MWNodeVector &table) { + TreeIterator it(tree, TopDown, Hilbert); it.setReturnGenNodes(false); while (it.nextParent()) { - MWNode &node = it.getNode(); + MWNode &node = it.getNode(); if (node.getDepth() == 0) continue; table.push_back(&node); } it.init(tree); while (it.next()) { - MWNode &node = it.getNode(); + MWNode &node = it.getNode(); table.push_back(&node); } } /** Traverse tree along the Hilbert path and find nodes of any rankId. * Returns one nodeVector per scale. GenNodes disregarded. */ -template void tree_utils::make_node_table(MWTree &tree, std::vector> &table) { - TreeIterator it(tree, TopDown, Hilbert); +template void tree_utils::make_node_table(MWTree &tree, std::vector> &table) { + TreeIterator it(tree, TopDown, Hilbert); it.setReturnGenNodes(false); while (it.nextParent()) { - MWNode &node = it.getNode(); + MWNode &node = it.getNode(); if (node.getDepth() == 0) continue; int depth = node.getDepth() + tree.getNNegScales(); // Add one more element - if (depth + 1 > table.size()) table.push_back(MWNodeVector()); + if (depth + 1 > table.size()) table.push_back(MWNodeVector()); table[depth].push_back(&node); } it.init(tree); while (it.next()) { - MWNode &node = it.getNode(); + MWNode &node = it.getNode(); int depth = node.getDepth() + tree.getNNegScales(); // Add one more element - if (depth + 1 > table.size()) table.push_back(MWNodeVector()); + if (depth + 1 > table.size()) table.push_back(MWNodeVector()); table[depth].push_back(&node); } } @@ -110,7 +110,7 @@ template void tree_utils::make_node_table(MWTree &tree, std::vector void tree_utils::mw_transform(const MWTree &tree, double *coeff_in, double *coeff_out, bool readOnlyScaling, int stride, bool b_overwrite) { +template void tree_utils::mw_transform(const MWTree &tree, T *coeff_in, T *coeff_out, bool readOnlyScaling, int stride, bool b_overwrite) { int operation = Reconstruction; int kp1 = tree.getKp1(); int kp1_d = tree.getKp1_d(); @@ -118,8 +118,8 @@ template void tree_utils::mw_transform(const MWTree &tree, double *co int kp1_dm1 = math_utils::ipow(kp1, D - 1); const MWFilter &filter = tree.getMRA().getFilter(); double overwrite = 0.0; - double tmpcoeff[kp1_d * tDim]; - double tmpcoeff2[kp1_d * tDim]; + T tmpcoeff[kp1_d * tDim]; + T tmpcoeff2[kp1_d * tDim]; int ftlim = tDim; int ftlim2 = tDim; int ftlim3 = tDim; @@ -135,13 +135,13 @@ template void tree_utils::mw_transform(const MWTree &tree, double *co int i = 0; int mask = 1; for (int gt = 0; gt < tDim; gt++) { - double *out = tmpcoeff + gt * kp1_d; + T *out = tmpcoeff + gt * kp1_d; for (int ft = 0; ft < ftlim; ft++) { // Operate in direction i only if the bits along other // directions are identical. The bit of the direction we // operate on determines the appropriate filter/operator if ((gt | mask) == (ft | mask)) { - double *in = coeff_in + ft * kp1_d; + T *in = coeff_in + ft * kp1_d; int filter_index = 2 * ((gt >> i) & 1) + ((ft >> i) & 1); const Eigen::MatrixXd &oper = filter.getSubFilter(filter_index, operation); @@ -155,13 +155,13 @@ template void tree_utils::mw_transform(const MWTree &tree, double *co i++; mask = 2; // 1 << i; for (int gt = 0; gt < tDim; gt++) { - double *out = tmpcoeff2 + gt * kp1_d; + T *out = tmpcoeff2 + gt * kp1_d; for (int ft = 0; ft < ftlim2; ft++) { // Operate in direction i only if the bits along other // directions are identical. The bit of the direction we // operate on determines the appropriate filter/operator if ((gt | mask) == (ft | mask)) { - double *in = tmpcoeff + ft * kp1_d; + T *in = tmpcoeff + ft * kp1_d; int filter_index = 2 * ((gt >> i) & 1) + ((ft >> i) & 1); const Eigen::MatrixXd &oper = filter.getSubFilter(filter_index, operation); @@ -178,13 +178,13 @@ template void tree_utils::mw_transform(const MWTree &tree, double *co i++; mask = 4; // 1 << i; for (int gt = 0; gt < tDim; gt++) { - double *out = coeff_out + gt * stride; // write right into children + T *out = coeff_out + gt * stride; // write right into children for (int ft = 0; ft < ftlim3; ft++) { // Operate in direction i only if the bits along other // directions are identical. The bit of the direction we // operate on determines the appropriate filter/operator if ((gt | mask) == (ft | mask)) { - double *in = tmpcoeff2 + ft * kp1_d; + T *in = tmpcoeff2 + ft * kp1_d; int filter_index = 2 * ((gt >> i) & 1) + ((ft >> i) & 1); const Eigen::MatrixXd &oper = filter.getSubFilter(filter_index, operation); @@ -200,7 +200,7 @@ template void tree_utils::mw_transform(const MWTree &tree, double *co if (D > 3) MSG_ABORT("D>3 NOT IMPLEMENTED for S_mwtransform"); if (D < 3) { - double *out; + T *out; if (D == 1) out = tmpcoeff; if (D == 2) out = tmpcoeff2; if (b_overwrite) { @@ -216,9 +216,9 @@ template void tree_utils::mw_transform(const MWTree &tree, double *co } // Specialized for D=3 below. -template void tree_utils::mw_transform_back(MWTree &tree, double *coeff_in, double *coeff_out, int stride) { - NOT_IMPLEMENTED_ABORT; -} +//template void tree_utils::mw_transform_back(MWTree &tree, double *coeff_in, double *coeff_out, int stride) { +// NOT_IMPLEMENTED_ABORT; +//} /** Make parent from children scaling coefficients * Other node info are not used/set @@ -226,7 +226,7 @@ template void tree_utils::mw_transform_back(MWTree &tree, double *coe * The output is read directly from the 8 children scaling coefficients. * NB: ASSUMES that the children coefficients are separated by Children_Stride! */ -template <> void tree_utils::mw_transform_back<3>(MWTree<3> &tree, double *coeff_in, double *coeff_out, int stride) { +template void tree_utils::mw_transform_back(MWTree<3, T> &tree, T *coeff_in, T *coeff_out, int stride) { int operation = Compression; int kp1 = tree.getKp1(); int kp1_d = tree.getKp1_d(); @@ -234,7 +234,7 @@ template <> void tree_utils::mw_transform_back<3>(MWTree<3> &tree, double *coeff int kp1_dm1 = math_utils::ipow(kp1, 2); const MWFilter &filter = tree.getMRA().getFilter(); double overwrite = 0.0; - double tmpcoeff[kp1_d * tDim]; + T tmpcoeff[kp1_d * tDim]; int ftlim = tDim; int ftlim2 = tDim; @@ -243,13 +243,13 @@ template <> void tree_utils::mw_transform_back<3>(MWTree<3> &tree, double *coeff int i = 0; int mask = 1; for (int gt = 0; gt < tDim; gt++) { - double *out = coeff_out + gt * kp1_d; + T *out = coeff_out + gt * kp1_d; for (int ft = 0; ft < ftlim; ft++) { // Operate in direction i only if the bits along other // directions are identical. The bit of the direction we // operate on determines the appropriate filter/operator if ((gt | mask) == (ft | mask)) { - double *in = coeff_in + ft * stride; + T *in = coeff_in + ft * stride; int filter_index = 2 * ((gt >> i) & 1) + ((ft >> i) & 1); const Eigen::MatrixXd &oper = filter.getSubFilter(filter_index, operation); @@ -262,13 +262,13 @@ template <> void tree_utils::mw_transform_back<3>(MWTree<3> &tree, double *coeff i++; mask = 2; // 1 << i; for (int gt = 0; gt < tDim; gt++) { - double *out = tmpcoeff + gt * kp1_d; + T *out = tmpcoeff + gt * kp1_d; for (int ft = 0; ft < ftlim2; ft++) { // Operate in direction i only if the bits along other // directions are identical. The bit of the direction we // operate on determines the appropriate filter/operator if ((gt | mask) == (ft | mask)) { - double *in = coeff_out + ft * kp1_d; + T *in = coeff_out + ft * kp1_d; int filter_index = 2 * ((gt >> i) & 1) + ((ft >> i) & 1); const Eigen::MatrixXd &oper = filter.getSubFilter(filter_index, operation); @@ -281,14 +281,14 @@ template <> void tree_utils::mw_transform_back<3>(MWTree<3> &tree, double *coeff i++; mask = 4; // 1 << i; for (int gt = 0; gt < tDim; gt++) { - double *out = coeff_out + gt * kp1_d; - // double *out = coeff_out + gt * N_coeff; + T *out = coeff_out + gt * kp1_d; + // T *out = coeff_out + gt * N_coeff; for (int ft = 0; ft < ftlim3; ft++) { // Operate in direction i only if the bits along other // directions are identical. The bit of the direction we // operate on determines the appropriate filter/operator if ((gt | mask) == (ft | mask)) { - double *in = tmpcoeff + ft * kp1_d; + T *in = tmpcoeff + ft * kp1_d; int filter_index = 2 * ((gt >> i) & 1) + ((ft >> i) & 1); const Eigen::MatrixXd &oper = filter.getSubFilter(filter_index, operation); @@ -300,24 +300,46 @@ template <> void tree_utils::mw_transform_back<3>(MWTree<3> &tree, double *coeff } } -template bool tree_utils::split_check<1>(const MWNode<1> &node, double prec, double split_fac, bool abs_prec); -template bool tree_utils::split_check<2>(const MWNode<2> &node, double prec, double split_fac, bool abs_prec); -template bool tree_utils::split_check<3>(const MWNode<3> &node, double prec, double split_fac, bool abs_prec); -template void tree_utils::make_node_table<1>(MWTree<1> &tree, MWNodeVector<1> &table); -template void tree_utils::make_node_table<2>(MWTree<2> &tree, MWNodeVector<2> &table); -template void tree_utils::make_node_table<3>(MWTree<3> &tree, MWNodeVector<3> &table); +template void tree_utils::make_node_table<1, double>(MWTree<1, double> &tree, MWNodeVector<1, double> &table); +template void tree_utils::make_node_table<2, double>(MWTree<2, double> &tree, MWNodeVector<2, double> &table); +template void tree_utils::make_node_table<3, double>(MWTree<3, double> &tree, MWNodeVector<3, double> &table); + +template void tree_utils::make_node_table<1, double>(MWTree<1, double> &tree, std::vector> &table); +template void tree_utils::make_node_table<2, double>(MWTree<2, double> &tree, std::vector> &table); +template void tree_utils::make_node_table<3, double>(MWTree<3, double> &tree, std::vector> &table); + +template bool tree_utils::split_check<1, double>(const MWNode<1, double> &node, double prec, double split_fac, bool abs_prec); +template bool tree_utils::split_check<2, double>(const MWNode<2, double> &node, double prec, double split_fac, bool abs_prec); +template bool tree_utils::split_check<3, double>(const MWNode<3, double> &node, double prec, double split_fac, bool abs_prec); + +template void tree_utils::mw_transform<1, double>(const MWTree<1, double> &tree, double *coeff_in, double *coeff_out, bool readOnlyScaling, int stride, bool b_overwrite); +template void tree_utils::mw_transform<2, double>(const MWTree<2, double> &tree, double *coeff_in, double *coeff_out, bool readOnlyScaling, int stride, bool b_overwrite); +template void tree_utils::mw_transform<3, double>(const MWTree<3, double> &tree, double *coeff_in, double *coeff_out, bool readOnlyScaling, int stride, bool b_overwrite); + +//template void tree_utils::mw_transform_back<1, double>(MWTree<1, double> &tree, double *coeff_in, double *coeff_out, int stride); +//template void tree_utils::mw_transform_back<2, double>(MWTree<2, double> &tree, double *coeff_in, double *coeff_out, int stride); +template void tree_utils::mw_transform_back(MWTree<3, double> &tree, double *coeff_in, double *coeff_out, int stride); + + +template void tree_utils::make_node_table<1, ComplexDouble>(MWTree<1, ComplexDouble> &tree, MWNodeVector<1, ComplexDouble> &table); +template void tree_utils::make_node_table<2, ComplexDouble>(MWTree<2, ComplexDouble> &tree, MWNodeVector<2, ComplexDouble> &table); +template void tree_utils::make_node_table<3, ComplexDouble>(MWTree<3, ComplexDouble> &tree, MWNodeVector<3, ComplexDouble> &table); + +template void tree_utils::make_node_table<1, ComplexDouble>(MWTree<1, ComplexDouble> &tree, std::vector> &table); +template void tree_utils::make_node_table<2, ComplexDouble>(MWTree<2, ComplexDouble> &tree, std::vector> &table); +template void tree_utils::make_node_table<3, ComplexDouble>(MWTree<3, ComplexDouble> &tree, std::vector> &table); -template void tree_utils::make_node_table<1>(MWTree<1> &tree, std::vector> &table); -template void tree_utils::make_node_table<2>(MWTree<2> &tree, std::vector> &table); -template void tree_utils::make_node_table<3>(MWTree<3> &tree, std::vector> &table); +template bool tree_utils::split_check<1, ComplexDouble>(const MWNode<1, ComplexDouble> &node, double prec, double split_fac, bool abs_prec); +template bool tree_utils::split_check<2, ComplexDouble>(const MWNode<2, ComplexDouble> &node, double prec, double split_fac, bool abs_prec); +template bool tree_utils::split_check<3, ComplexDouble>(const MWNode<3, ComplexDouble> &node, double prec, double split_fac, bool abs_prec); -template void tree_utils::mw_transform<1>(const MWTree<1> &tree, double *coeff_in, double *coeff_out, bool readOnlyScaling, int stride, bool b_overwrite); -template void tree_utils::mw_transform<2>(const MWTree<2> &tree, double *coeff_in, double *coeff_out, bool readOnlyScaling, int stride, bool b_overwrite); -template void tree_utils::mw_transform<3>(const MWTree<3> &tree, double *coeff_in, double *coeff_out, bool readOnlyScaling, int stride, bool b_overwrite); +template void tree_utils::mw_transform<1, ComplexDouble>(const MWTree<1, ComplexDouble> &tree, ComplexDouble *coeff_in, ComplexDouble *coeff_out, bool readOnlyScaling, int stride, bool b_overwrite); +template void tree_utils::mw_transform<2, ComplexDouble>(const MWTree<2, ComplexDouble> &tree, ComplexDouble *coeff_in, ComplexDouble *coeff_out, bool readOnlyScaling, int stride, bool b_overwrite); +template void tree_utils::mw_transform<3, ComplexDouble>(const MWTree<3, ComplexDouble> &tree, ComplexDouble *coeff_in, ComplexDouble *coeff_out, bool readOnlyScaling, int stride, bool b_overwrite); -template void tree_utils::mw_transform_back<1>(MWTree<1> &tree, double *coeff_in, double *coeff_out, int stride); -template void tree_utils::mw_transform_back<2>(MWTree<2> &tree, double *coeff_in, double *coeff_out, int stride); -template void tree_utils::mw_transform_back<3>(MWTree<3> &tree, double *coeff_in, double *coeff_out, int stride); +//template void tree_utils::mw_transform_back<1, ComplexDouble>(MWTree<1, ComplexDouble &tree, ComplexDouble *coeff_in, ComplexDouble *coeff_out, int stride); +//template void tree_utils::mw_transform_back<2, ComplexDouble>(MWTree<2, ComplexDouble &tree, ComplexDouble *coeff_in, ComplexDouble *coeff_out, int stride); +template void tree_utils::mw_transform_back(MWTree<3, ComplexDouble> &tree, ComplexDouble *coeff_in, ComplexDouble *coeff_out, int stride); } // namespace mrcpp diff --git a/src/utils/tree_utils.h b/src/utils/tree_utils.h index 8f2c4220a..90ff2a418 100644 --- a/src/utils/tree_utils.h +++ b/src/utils/tree_utils.h @@ -25,18 +25,21 @@ #pragma once +#include "utils/math_utils.h" #include "MRCPP/mrcpp_declarations.h" namespace mrcpp { namespace tree_utils { -template bool split_check(const MWNode &node, double prec, double split_fac, bool abs_prec); +template bool split_check(const MWNode &node, double prec, double split_fac, bool abs_prec); -template void make_node_table(MWTree &tree, MWNodeVector &table); -template void make_node_table(MWTree &tree, std::vector> &table); +template void make_node_table(MWTree &tree, MWNodeVector &table); +template void make_node_table(MWTree &tree, std::vector> &table); + +template void mw_transform(const MWTree &tree, T *coeff_in, T *coeff_out, bool readOnlyScaling, int stride, bool overwrite = true); +//template void mw_transform_back(MWTree &tree, T *coeff_in, T *coeff_out, int stride); +template void mw_transform_back(MWTree<3, T> &tree, T *coeff_in, T *coeff_out, int stride); -template void mw_transform(const MWTree &tree, double *coeff_in, double *coeff_out, bool readOnlyScaling, int stride, bool overwrite = true); -template void mw_transform_back(MWTree &tree, double *coeff_in, double *coeff_out, int stride); } // namespace tree_utils } // namespace mrcpp diff --git a/tests/operators/derivative_operator.cpp b/tests/operators/derivative_operator.cpp index 5f4400d1b..773c12ec9 100644 --- a/tests/operators/derivative_operator.cpp +++ b/tests/operators/derivative_operator.cpp @@ -102,10 +102,10 @@ template void testDifferentiationABGV(double a, double b) { }; FunctionTree f_tree(*mra); - project(prec / 10, f_tree, f); + project(prec / 10, f_tree, f); FunctionTree df_tree(*mra); - project(prec / 10, df_tree, df); + project(prec / 10, df_tree, df); FunctionTree dg_tree(*mra); apply(dg_tree, diff, f_tree, 0); @@ -122,6 +122,47 @@ template void testDifferentiationABGV(double a, double b) { delete mra; } +template void testDifferentiationCplxABGV(double a, double b) { + MultiResolutionAnalysis *mra = initializeMRA(); + + double prec = 1.0e-3; + ABGVOperator diff(*mra, a, b); + ComplexDouble s = {1.1, 1.3}; + + Coord r_0; + for (auto &x : r_0) x = pi; + + auto f = [r_0, s](const Coord &r) { + double R = math_utils::calc_distance(r, r_0); + return std::exp(-R * R * s); + }; + + auto df = [r_0, s](const Coord &r) { + double R = math_utils::calc_distance(r, r_0); + return -2.0 * s * std::exp(-R * R * s) * (r[0] - r_0[0]); + }; + + FunctionTree f_tree(*mra); + project(prec / 10, f_tree, f); + + FunctionTree df_tree(*mra); + project(prec / 10, df_tree, df); + + FunctionTree dg_tree(*mra); + apply(dg_tree, diff, f_tree, 0); + + FunctionTree err_tree(*mra); + add(-1.0, err_tree, {1.0, 0.0}, df_tree, {-1.0, 0.0}, dg_tree); + + double df_norm = std::sqrt(df_tree.getSquareNorm()); + double abs_err = std::sqrt(err_tree.getSquareNorm()); + double rel_err = abs_err / df_norm; + + REQUIRE(rel_err == Catch::Approx(0.0).margin(prec)); + + delete mra; +} + template void testDifferentiationPH(int order) { MultiResolutionAnalysis *mra = initializeMRA(); @@ -143,10 +184,10 @@ template void testDifferentiationPH(int order) { }; FunctionTree f_tree(*mra); - project(prec / 10, f_tree, f); + project(prec / 10, f_tree, f); FunctionTree df_tree(*mra); - project(prec / 10, df_tree, df); + project(prec / 10, df_tree, df); FunctionTree dg_tree(*mra); apply(dg_tree, diff, f_tree, 0); @@ -174,7 +215,7 @@ template void testDifferentiationPeriodicABGV(double a, double b) { FunctionTree g_tree(*mra); FunctionTree dg_tree(*mra); - project(prec, g_tree, g_func); + project(prec, g_tree, g_func); apply(dg_tree, diff, g_tree, 0); refine_grid(dg_tree, 1); // for accurate evalf @@ -202,7 +243,7 @@ template void testDifferentiationPeriodicPH(int order) { FunctionTree g_tree(*mra); FunctionTree dg_tree(*mra); - project(prec, g_tree, g_func); + project(prec, g_tree, g_func); apply(dg_tree, diff, g_tree, 0); refine_grid(dg_tree, 1); // for accurate evalf @@ -237,10 +278,10 @@ template void testDifferentiationBS(int order) { }; FunctionTree f_tree(*mra); - project(prec / 10, f_tree, f); + project(prec / 10, f_tree, f); FunctionTree df_tree(*mra); - project(prec / 10, df_tree, df); + project(prec / 10, df_tree, df); FunctionTree dg_tree(*mra); apply(dg_tree, diff, f_tree, 0); @@ -271,6 +312,14 @@ TEST_CASE("ABGV differentiantion center difference", "[derivative_operator], [ce SECTION("3D derivative test") { testDifferentiationABGV<3>(0, 0); } } + +TEST_CASE("ABGV differentiantion of Complex function", "[derivative_operator], [Complex]") { + // 0.5,0.5 specifies central difference + SECTION("1D derivative test") { testDifferentiationCplxABGV<1>(0.5, 0.5); } + SECTION("2D derivative test") { testDifferentiationCplxABGV<2>(0.5, 0.5); } + SECTION("3D derivative test") { testDifferentiationCplxABGV<3>(0.5, 0.5); } +} + TEST_CASE("PH differentiantion first order", "[derivative_operator], [PH_first_order]") { SECTION("1D derivative test") { testDifferentiationPH<1>(1); } SECTION("2D derivative test") { testDifferentiationPH<2>(1); } @@ -335,7 +384,7 @@ TEST_CASE("Gradient operator", "[derivative_operator], [gradient_operator]") { }; FunctionTree<3> f_tree(*mra); - project<3>(prec, f_tree, f); + project<3, double>(prec, f_tree, f); auto grad_f = gradient(diff, f_tree); REQUIRE(grad_f.size() == 3); @@ -373,7 +422,7 @@ TEST_CASE("Divergence operator", "[derivative_operator], [divergence_operator]") }; FunctionTree<3> f_tree(*mra); - project<3>(prec, f_tree, f); + project<3, double>(prec, f_tree, f); FunctionTreeVector<3> f_vec; f_vec.push_back(std::make_tuple(1.0, &f_tree)); f_vec.push_back(std::make_tuple(2.0, &f_tree)); @@ -389,6 +438,6 @@ TEST_CASE("Divergence operator", "[derivative_operator], [divergence_operator]") } delete mra; -} + } } // namespace derivative_operator diff --git a/tests/operators/helmholtz_operator.cpp b/tests/operators/helmholtz_operator.cpp index 7a0dc0243..8f570691d 100644 --- a/tests/operators/helmholtz_operator.cpp +++ b/tests/operators/helmholtz_operator.cpp @@ -169,14 +169,14 @@ TEST_CASE("Apply Helmholtz' operator", "[apply_helmholtz], [helmholtz_operator], return R_0 * Y_00; }; FunctionTree<3> psi_n(MRA); - project<3>(proj_prec, psi_n, hFunc); + project<3, double>(proj_prec, psi_n, hFunc); auto f = [Z](const Coord<3> &r) -> double { double x = std::sqrt(r[0] * r[0] + r[1] * r[1] + r[2] * r[2]); return -Z / x; }; FunctionTree<3> V(MRA); - project<3>(proj_prec, V, f); + project<3, double>(proj_prec, V, f); FunctionTree<3> Vpsi(MRA); copy_grid(Vpsi, psi_n); @@ -222,7 +222,7 @@ TEST_CASE("Apply Periodic Helmholtz' operator", "[apply_periodic_helmholtz], [he auto source = [mu](const mrcpp::Coord<3> &r) { return 3.0 * cos(r[0]) * cos(r[1]) * cos(r[2]) / (4.0 * pi) + mu * mu * cos(r[0]) * cos(r[1]) * cos(r[2]) / (4.0 * pi); }; FunctionTree<3> source_tree(MRA); - project<3>(proj_prec, source_tree, source); + project<3, double>(proj_prec, source_tree, source); FunctionTree<3> sol_tree(MRA); FunctionTree<3> in_tree(MRA); @@ -265,7 +265,7 @@ TEST_CASE("Apply negative scale Helmholtz' operator", "[apply_periodic_helmholtz auto source = [mu](const mrcpp::Coord<3> &r) { return 3.0 * cos(r[0]) * cos(r[1]) * cos(r[2]) / (4.0 * pi) + mu * mu * cos(r[0]) * cos(r[1]) * cos(r[2]) / (4.0 * pi); }; FunctionTree<3> source_tree(MRA); - project<3>(proj_prec, source_tree, source); + project<3, double>(proj_prec, source_tree, source); FunctionTree<3> sol_tree(MRA); @@ -274,4 +274,5 @@ TEST_CASE("Apply negative scale Helmholtz' operator", "[apply_periodic_helmholtz REQUIRE(sol_tree.evalf({0.0, 0.0, 0.0}) == Catch::Approx(1.0).epsilon(apply_prec)); REQUIRE(sol_tree.evalf({pi, 0.0, 0.0}) == Catch::Approx(-1.0).epsilon(apply_prec)); } + } // namespace helmholtz_operator diff --git a/tests/operators/poisson_operator.cpp b/tests/operators/poisson_operator.cpp index df841a625..eab986886 100644 --- a/tests/operators/poisson_operator.cpp +++ b/tests/operators/poisson_operator.cpp @@ -187,7 +187,7 @@ TEST_CASE("Apply Periodic Poisson' operator", "[apply_periodic_Poisson], [poisso auto source = [](const mrcpp::Coord<3> &r) { return 3.0 * cos(r[0]) * cos(r[1]) * cos(r[2]) / (4.0 * pi); }; FunctionTree<3> source_tree(MRA); - project<3>(proj_prec, source_tree, source); + project<3, double>(proj_prec, source_tree, source); FunctionTree<3> sol_tree(MRA); @@ -195,6 +195,6 @@ TEST_CASE("Apply Periodic Poisson' operator", "[apply_periodic_Poisson], [poisso REQUIRE(sol_tree.evalf({0.0, 0.0, 0.0}) == Catch::Approx(1.0).epsilon(apply_prec)); REQUIRE(sol_tree.evalf({pi, 0.0, 0.0}) == Catch::Approx(-1.0).epsilon(apply_prec)); -} + } } // namespace poisson_operator diff --git a/tests/operators/schrodinger_evolution_operator.cpp b/tests/operators/schrodinger_evolution_operator.cpp index c986ec756..477065062 100644 --- a/tests/operators/schrodinger_evolution_operator.cpp +++ b/tests/operators/schrodinger_evolution_operator.cpp @@ -38,7 +38,6 @@ namespace schrodinger_evolution_operator { - TEST_CASE("Apply Schrodinger's evolution operator", "[apply_schrodinger_evolution], [schrodinger_evolution_operator], [mw_operator]") { const auto min_scale = 0; const auto max_depth = 25; @@ -87,13 +86,13 @@ TEST_CASE("Apply Schrodinger's evolution operator", "[apply_schrodinger_evolutio // Projecting functions mrcpp::FunctionTree<1> Re_f_tree(MRA); - mrcpp::project<1>(prec, Re_f_tree, Re_f); + mrcpp::project<1, double>(prec, Re_f_tree, Re_f); mrcpp::FunctionTree<1> Im_f_tree(MRA); - mrcpp::project<1>(prec, Im_f_tree, Im_f); + mrcpp::project<1, double>(prec, Im_f_tree, Im_f); mrcpp::FunctionTree<1> Re_g_tree(MRA); - mrcpp::project<1>(prec, Re_g_tree, Re_g); + mrcpp::project<1, double>(prec, Re_g_tree, Re_g); mrcpp::FunctionTree<1> Im_g_tree(MRA); - mrcpp::project<1>(prec, Im_g_tree, Im_g); + mrcpp::project<1, double>(prec, Im_g_tree, Im_g); // Output function trees mrcpp::FunctionTree<1> Re_fout_tree(MRA); @@ -129,5 +128,4 @@ TEST_CASE("Apply Schrodinger's evolution operator", "[apply_schrodinger_evolutio REQUIRE(Im_sq_norm == Catch::Approx(0.0).margin(tolerance)); } - -} // namespace schrodinger_evolution_operator \ No newline at end of file +} // namespace schrodinger_evolution_operator diff --git a/tests/treebuilders/map.cpp b/tests/treebuilders/map.cpp index 0745db2e6..c3c333bba 100644 --- a/tests/treebuilders/map.cpp +++ b/tests/treebuilders/map.cpp @@ -77,7 +77,7 @@ template void testMapping() { const double inp_int = inp_tree.integrate(); const double inp_norm = inp_tree.getSquareNorm(); - auto fmap = [](double val) { return val * val; }; + FMap fmap = [](double val) { return val * val; }; WHEN("the function is mapped") { FunctionTree out_tree(*mra); diff --git a/tests/treebuilders/multiplication.cpp b/tests/treebuilders/multiplication.cpp index 5fa2178e1..7e8437f77 100644 --- a/tests/treebuilders/multiplication.cpp +++ b/tests/treebuilders/multiplication.cpp @@ -202,7 +202,7 @@ template void testSquare() { } finalize(&mra); } - + TEST_CASE("Dot product FunctionTreeVectors", "[multiplication], [tree_vector_dot]") { MultiResolutionAnalysis<3> *mra = nullptr; initialize<3>(&mra); @@ -221,14 +221,14 @@ TEST_CASE("Dot product FunctionTreeVectors", "[multiplication], [tree_vector_dot double r2 = (r[0] * r[0] + r[1] * r[1] + r[2] * r[2]); return r[0] * r[1] * std::exp(-2.0 * r2); }; - + FunctionTree<3> fx_tree(*mra); FunctionTree<3> fy_tree(*mra); FunctionTree<3> fz_tree(*mra); - project<3>(prec, fx_tree, fx); - project<3>(prec, fy_tree, fy); - project<3>(prec, fz_tree, fz); + project<3, double>(prec, fx_tree, fx); + project<3, double>(prec, fy_tree, fy); + project<3, double>(prec, fz_tree, fz); FunctionTreeVector<3> vec_a; vec_a.push_back(std::make_tuple(1.0, &fx_tree)); @@ -252,6 +252,6 @@ TEST_CASE("Dot product FunctionTreeVectors", "[multiplication], [tree_vector_dot } finalize(&mra); -} + } } // namespace multiplication