Skip to content

Commit

Permalink
mrchem compatibility
Browse files Browse the repository at this point in the history
  • Loading branch information
gitpeterwind committed Jul 10, 2024
1 parent 365bd0c commit fa478e8
Show file tree
Hide file tree
Showing 6 changed files with 26 additions and 32 deletions.
2 changes: 1 addition & 1 deletion src/functions/AnalyticFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@

namespace mrcpp {

template <int D, typename T> class AnalyticFunction : public RepresentableFunction<D, T> {
template <int D, typename T = double> class AnalyticFunction : public RepresentableFunction<D, T> {
public:
AnalyticFunction() = default;
~AnalyticFunction() override = default;
Expand Down
2 changes: 1 addition & 1 deletion src/functions/RepresentableFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
#include "MRCPP/constants.h"
#include "MRCPP/mrcpp_declarations.h"
#include "trees/NodeIndex.h"
#include "utils/math_utils.h"
#include "MRCPP/utils/math_utils.h"

namespace mrcpp {

Expand Down
20 changes: 8 additions & 12 deletions src/treebuilders/map.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,13 +65,13 @@ namespace mrcpp {
* no coefs).
*
*/
template <int D, typename T>
void map(double prec, FunctionTree<D, T> &out, FunctionTree<D, T> &inp, FMap<T, T> fmap, int maxIter, bool absPrec) {
template <int D>
void map(double prec, FunctionTree<D, double> &out, FunctionTree<D, double> &inp, FMap<double, double> fmap, int maxIter, bool absPrec) {

int maxScale = out.getMRA().getMaxScale();
TreeBuilder<D, T> builder;
WaveletAdaptor<D, T> adaptor(prec, maxScale, absPrec);
MapCalculator<D, T> calculator(fmap, inp);
TreeBuilder<D, double> builder;
WaveletAdaptor<D, double> adaptor(prec, maxScale, absPrec);
MapCalculator<D, double> calculator(fmap, inp);

builder.build(out, calculator, adaptor, maxIter);

Expand All @@ -89,12 +89,8 @@ void map(double prec, FunctionTree<D, T> &out, FunctionTree<D, T> &inp, FMap<T,
print::separator(10, ' ');
}

template void map<1, double>(double prec, FunctionTree<1, double> &out, FunctionTree<1, double> &inp, FMap<double, double> fmap, int maxIter, bool absPrec);
template void map<2, double>(double prec, FunctionTree<2, double> &out, FunctionTree<2, double> &inp, FMap<double, double> fmap, int maxIter, bool absPrec);
template void map<3, double>(double prec, FunctionTree<3, double> &out, FunctionTree<3, double> &inp, FMap<double, double> fmap, int maxIter, bool absPrec);

template void map<1, ComplexDouble>(double prec, FunctionTree<1, ComplexDouble> &out, FunctionTree<1, ComplexDouble> &inp, FMap<ComplexDouble, ComplexDouble> fmap, int maxIter, bool absPrec);
template void map<2, ComplexDouble>(double prec, FunctionTree<2, ComplexDouble> &out, FunctionTree<2, ComplexDouble> &inp, FMap<ComplexDouble, ComplexDouble> fmap, int maxIter, bool absPrec);
template void map<3, ComplexDouble>(double prec, FunctionTree<3, ComplexDouble> &out, FunctionTree<3, ComplexDouble> &inp, FMap<ComplexDouble, ComplexDouble> fmap, int maxIter, bool absPrec);
template void map<1>(double prec, FunctionTree<1, double> &out, FunctionTree<1, double> &inp, FMap<double, double> fmap, int maxIter, bool absPrec);
template void map<2>(double prec, FunctionTree<2, double> &out, FunctionTree<2, double> &inp, FMap<double, double> fmap, int maxIter, bool absPrec);
template void map<3>(double prec, FunctionTree<3, double> &out, FunctionTree<3, double> &inp, FMap<double, double> fmap, int maxIter, bool absPrec);

} // Namespace mrcpp
5 changes: 2 additions & 3 deletions src/treebuilders/map.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,9 @@
#include "trees/FunctionTreeVector.h"

namespace mrcpp {
template <int D, typename T> class RepresentableFunction;
template <int D, typename T> class FunctionTree;

template <int D, typename T>
void map(double prec, FunctionTree<D, T> &out, FunctionTree<D, T> &inp, FMap<T, T> fmap, int maxIter = -1, bool absPrec = false);
template <int D>
void map(double prec, FunctionTree<D, double> &out, FunctionTree<D, double> &inp, FMap<double, double> fmap, int maxIter = -1, bool absPrec = false);

} // namespace mrcpp
22 changes: 11 additions & 11 deletions src/utils/ComplexFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,8 @@ class TreePtr final {
FunctionData func_data;
mrcpp::SharedMemory<double> *shared_mem_re;
mrcpp::SharedMemory<double> *shared_mem_im;
mrcpp::FunctionTree<3> *re; ///< Real part of function
mrcpp::FunctionTree<3> *im; ///< Imaginary part of function
mrcpp::FunctionTree<3, double> *re; ///< Real part of function
mrcpp::FunctionTree<3, double> *im; ///< Imaginary part of function

void flushFuncData() {
this->func_data.real_size = 0;
Expand Down Expand Up @@ -121,10 +121,10 @@ class ComplexFunction {
FunctionData &getFunctionData();
int occ() const { return this->func_ptr->func_data.occ; }
int spin() const { return this->func_ptr->func_data.spin; }
FunctionTree<3> &real() { return *this->func_ptr->re; }
FunctionTree<3> &imag() { return *this->func_ptr->im; }
const FunctionTree<3> &real() const { return *this->func_ptr->re; }
const FunctionTree<3> &imag() const { return *this->func_ptr->im; }
FunctionTree<3, double> &real() { return *this->func_ptr->re; }
FunctionTree<3, double> &imag() { return *this->func_ptr->im; }
const FunctionTree<3, double> &real() const { return *this->func_ptr->re; }
const FunctionTree<3, double> &imag() const { return *this->func_ptr->im; }
void release() { this->func_ptr.reset(); }
bool conjugate() const { return this->conj; }
MultiResolutionAnalysis<3> *funcMRA = nullptr;
Expand All @@ -141,8 +141,8 @@ class ComplexFunction {
int getSizeNodes(int type) const;
int getNNodes(int type) const;

void setReal(mrcpp::FunctionTree<3> *tree);
void setImag(mrcpp::FunctionTree<3> *tree);
void setReal(mrcpp::FunctionTree<3, double> *tree);
void setImag(mrcpp::FunctionTree<3, double> *tree);

double norm() const;
double squaredNorm() const;
Expand Down Expand Up @@ -172,8 +172,8 @@ void project(ComplexFunction &out, RepresentableFunction<3> &f, int type, double
void multiply(ComplexFunction &out, ComplexFunction inp_a, ComplexFunction inp_b, double prec, bool absPrec = false, bool useMaxNorms = false);
void multiply_real(ComplexFunction &out, ComplexFunction inp_a, ComplexFunction inp_b, double prec, bool absPrec = false, bool useMaxNorms = false);
void multiply_imag(ComplexFunction &out, ComplexFunction inp_a, ComplexFunction inp_b, double prec, bool absPrec = false, bool useMaxNorms = false);
void multiply(ComplexFunction &out, ComplexFunction &inp_a, RepresentableFunction<3> &f, double prec, int nrefine = 0);
void multiply(ComplexFunction &out, FunctionTree<3> &inp_a, RepresentableFunction<3> &f, double prec, int nrefine = 0);
void multiply(ComplexFunction &out, ComplexFunction &inp_a, RepresentableFunction<3, double> &f, double prec, int nrefine = 0);
void multiply(ComplexFunction &out, FunctionTree<3, double> &inp_a, RepresentableFunction<3, double> &f, double prec, int nrefine = 0);
void linear_combination(ComplexFunction &out, const ComplexVector &c, std::vector<ComplexFunction> &inp, double prec);
} // namespace cplxfunc

Expand All @@ -187,7 +187,7 @@ class MPI_FuncVector : public std::vector<ComplexFunction> {
namespace mpifuncvec {
void rotate(MPI_FuncVector &Phi, const ComplexMatrix &U, double prec = -1.0);
void rotate(MPI_FuncVector &Phi, const ComplexMatrix &U, MPI_FuncVector &Psi, double prec = -1.0);
void save_nodes(MPI_FuncVector &Phi, mrcpp::FunctionTree<3> &refTree, BankAccount &account, int sizes = -1);
void save_nodes(MPI_FuncVector &Phi, mrcpp::FunctionTree<3, double> &refTree, BankAccount &account, int sizes = -1);
MPI_FuncVector multiply(MPI_FuncVector &Phi, RepresentableFunction<3> &f, double prec = -1.0, ComplexFunction *Func = nullptr, int nrefine = 1, bool all = false);
ComplexVector dot(MPI_FuncVector &Bra, MPI_FuncVector &Ket);
ComplexMatrix calc_lowdin_matrix(MPI_FuncVector &Phi);
Expand Down
7 changes: 3 additions & 4 deletions tests/operators/derivative_operator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,18 +127,17 @@ template <int D> void testDifferentiationCplxABGV(double a, double b) {

double prec = 1.0e-3;
ABGVOperator<D> diff(*mra, a, b);
ComplexDouble s = {1.1, 1.3};

Coord<D> r_0;
for (auto &x : r_0) x = pi;

auto f = [r_0](const Coord<D> &r) {
ComplexDouble s = {1.1, 1.3};
auto f = [r_0, s](const Coord<D> &r) {
double R = math_utils::calc_distance<D>(r, r_0);
return std::exp(-R * R * s);
};

auto df = [r_0](const Coord<D> &r) {
ComplexDouble s = {1.1, 1.3};
auto df = [r_0, s](const Coord<D> &r) {
double R = math_utils::calc_distance<D>(r, r_0);
return -2.0 * s * std::exp(-R * R * s) * (r[0] - r_0[0]);
};
Expand Down

0 comments on commit fa478e8

Please sign in to comment.