Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generalize the classes to use complex or real numbers #21

Open
wants to merge 12 commits into
base: v2.x
Choose a base branch
from
2 changes: 1 addition & 1 deletion src/Integrator/BDHI/DoublyPeriodic/DPStokesSlab.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ else throw std::runtime_error("[DPStokesSlab] Can only average in direction X (0
real gw;
real tolerance;
WallMode mode;
shared_ptr<BVP::BatchedBVPHandler> bvpSolver;
shared_ptr<BVP::BatchedBVPHandler<real>> bvpSolver;

};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -506,7 +506,7 @@ namespace uammd{
}

auto computeZeroModeBoundaryConditions(int nz, real H, WallMode mode){
BVP::SchurBoundaryCondition bcs(nz, H);
BVP::SchurBoundaryCondition<real> bcs(nz, H);
if(mode == WallMode::bottom){
correction_ns::TopBoundaryConditions top(0, H);
correction_ns::BottomBoundaryConditions bot(0, H);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ namespace uammd{
auto botBC = thrust::make_transform_iterator(thrust::make_counting_iterator<int>(0), botdispatch);
int numberSystems = (nk.x/2+1)*nk.y;
int nz = grid.cellDim.z;
this->bvpSolver = std::make_shared<BVP::BatchedBVPHandler>(klist, topBC, botBC, numberSystems, halfH, nz);
this->bvpSolver = std::make_shared<BVP::BatchedBVPHandler<real>>(klist, topBC, botBC, numberSystems, halfH, nz);
CudaCheckError();
}

Expand Down
4 changes: 2 additions & 2 deletions src/Interactor/DoublyPeriodic/PoissonSlab/BVPPoisson.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ namespace uammd{
}

class BVPPoissonSlab{
std::shared_ptr<BVP::BatchedBVPHandler> bvpSolver;
std::shared_ptr<BVP::BatchedBVPHandler<real>> bvpSolver;
real2 Lxy;
real H;
int3 cellDim;
Expand Down Expand Up @@ -150,7 +150,7 @@ namespace uammd{
BoundaryConditionsDispatch<BottomBoundaryConditions, decltype(klist)>(klist, H));

int numberSystems = (nk.x/2+1)*nk.y;
this->bvpSolver = std::make_shared<BVP::BatchedBVPHandler>(klist, topBC, bottomBC, numberSystems, H, cellDim.z);
this->bvpSolver = std::make_shared<BVP::BatchedBVPHandler<real>>(klist, topBC, bottomBC, numberSystems, H, cellDim.z);
CudaCheckError();
}
};
Expand Down
25 changes: 14 additions & 11 deletions src/misc/BoundaryValueProblem/BVPSchurComplementMatrices.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ namespace uammd{
}
};

template <typename T>
class SchurBoundaryCondition{
int nz;
real H;
Expand All @@ -125,8 +126,8 @@ namespace uammd{
SchurBoundaryCondition(int nz, real H):nz(nz), H(H), bcs(nz){}

template<class TopBC, class BottomBC>
std::vector<real> computeBoundaryConditionMatrix(const TopBC &top, const BottomBC &bottom){
std::vector<real> CandD(2*nz+4, 0);
std::vector<T> computeBoundaryConditionMatrix(const TopBC &top, const BottomBC &bottom){
std::vector<T> CandD(2*nz+4, 0);
auto topRow = computeTopRow(top, bottom);
auto bottomRow = computeBottomRow(top, bottom);
std::copy(topRow.begin(), topRow.end()-2, CandD.begin());
Expand All @@ -141,8 +142,8 @@ namespace uammd{
private:

template<class TopBC, class BottomBC>
std::vector<real> computeTopRow(const TopBC &top, const BottomBC &bottom){
std::vector<real> topRow(nz+2, 0);
std::vector<T> computeTopRow(const TopBC &top, const BottomBC &bottom){
std::vector<T> topRow(nz+2, 0);
auto tfi = bcs.topFirstIntegral();
auto tsi = bcs.topSecondIntegral();
auto tfiFactor = top.getFirstIntegralFactor();
Expand All @@ -156,8 +157,8 @@ namespace uammd{
}

template<class TopBC, class BottomBC>
std::vector<real> computeBottomRow(const TopBC &top, const BottomBC &bottom){
std::vector<real> bottomRow(nz+2, 0);
std::vector<T> computeBottomRow(const TopBC &top, const BottomBC &bottom){
std::vector<T> bottomRow(nz+2, 0);
auto bfi = bcs.bottomFirstIntegral();
auto bsi = bcs.bottomSecondIntegral();
auto bfiFactor = bottom.getFirstIntegralFactor();
Expand All @@ -172,10 +173,11 @@ namespace uammd{

};

std::vector<real> computeSecondIntegralMatrix(real k, real H, int nz){
std::vector<real> A(nz*nz, 0);
template<class T>
std::vector<T> computeSecondIntegralMatrix(T k, real H, int nz){
std::vector<T> A(nz*nz, 0);
SecondIntegralMatrix sim(nz);
real kH2 = k*k*H*H;
T kH2 = k*k*H*H;
fori(0, nz){
forj(0,nz){
A[i+nz*j] = (i==j) - kH2*sim.getElement(i, j);
Expand All @@ -184,9 +186,10 @@ namespace uammd{
return std::move(A);
}

std::vector<real> computeInverseSecondIntegralMatrix(real k, real H, int nz){
template<class T>
std::vector<T> computeInverseSecondIntegralMatrix(T k, real H, int nz){
if(k==0){
std::vector<real> invA(nz*nz, 0);
std::vector<T> invA(nz*nz, 0);
fori(0, nz){
invA[i+nz*i] = 1;
}
Expand Down
85 changes: 47 additions & 38 deletions src/misc/BoundaryValueProblem/BVPSolver.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -42,39 +42,40 @@ namespace uammd{
namespace BVP{

namespace detail{
template <typename U>
class SubsystemSolver{
int nz;
real H;
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is leaving this as real a conscious decision?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, we believe that since H is associated with a size, it will always be real. Furthermore, in the complex_bvp branch (which is the branch we're trying to merge into the main one with these changes), H is treated as a real number, while all other entities are complex.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At first I though this PR was making some things compatible with either float or double, thus some of my comments -.-
Calling the template parameter U seems to signal that this class can be instantiated with any type (like float, float4, complex...). I guess this is why concepts were introduced in C++20, so you can restrict the type to be either thrust::complex or real and nothing else.
I now remember I did not merge the complex branch because I found no way to state this clearly via code.
The code for the complex and real versions is really similar but not quite so. I considered making everything complex, but many usecases are real and that would entail twice the work.
Maybe putting everything into a hidden namespace and just exposing two instantiated types (real and thrust::complex) for the public API is the best way to go. Then if you want to use another complex type or whatever you know you are on your own.

StorageHandle<real> CinvA_storage;
StorageHandle<real4> CinvABmD_storage;
StorageHandle<U> CinvA_storage;
StorageHandle<U> CinvABmD_storage;
RaulPPelaez marked this conversation as resolved.
Show resolved Hide resolved

public:

SubsystemSolver(int nz, real H):nz(nz), H(H){}

void registerRequiredStorage(StorageRegistration &memoryManager){
CinvA_storage = memoryManager.registerStorageRequirement<real>(2*nz+2);
CinvABmD_storage = memoryManager.registerStorageRequirement<real4>(1);
CinvA_storage = memoryManager.registerStorageRequirement<U>(2*nz+2);
CinvABmD_storage = memoryManager.registerStorageRequirement<U>(4);
}

template<class TopBC, class BottomBC>
void precompute(real k, const TopBC &top, const BottomBC &bottom, StorageRetriever &memoryManager){
void precompute(U k, const TopBC &top, const BottomBC &bottom, StorageRetriever &memoryManager){
auto CinvA_it = memoryManager.retrieveStorage(CinvA_storage);
auto CinvABmD_it = memoryManager.retrieveStorage(CinvABmD_storage);
auto invA = computeInverseSecondIntegralMatrix(k, H, nz);
SchurBoundaryCondition bcs(nz, H);
SchurBoundaryCondition<U> bcs(nz, H);
auto CandD = bcs.computeBoundaryConditionMatrix(top, bottom);
real4 D = make_real4(CandD[2*nz], CandD[2*nz+1], CandD[2*nz+2], CandD[2*nz+3]);
U D[4] = {CandD[2*nz], CandD[2*nz+1], CandD[2*nz+2], CandD[2*nz+3]};
auto CinvA = matmul(CandD, nz, 2, invA, nz, nz);
std::copy(CinvA.begin(), CinvA.end(), CinvA_it);
real4 CinvAB;
real B00 = -k*k*H*H;
real B11 = -k*k*H*H;
CinvAB.x = CinvA[0]*B00;
CinvAB.y = CinvA[1]*B11;
CinvAB.z = CinvA[0+nz]*B00;
CinvAB.w = CinvA[1+nz]*B11;
CinvABmD_it[0] = CinvAB - D;
U CinvAB[4];
U B00 = -k*k*H*H;
U B11 = -k*k*H*H;
CinvAB[0] = CinvA[0]*B00;
CinvAB[1] = CinvA[1]*B11;
CinvAB[2] = CinvA[0+nz]*B00;
CinvAB[3] = CinvA[1+nz]*B11;
fori(0, 4) CinvABmD_it[i] = CinvAB[i] - D[i];
}

template<class T, class FnIterator>
Expand All @@ -83,16 +84,19 @@ namespace uammd{
StorageRetriever &memoryManager){
const auto CinvA = memoryManager.retrieveStorage(CinvA_storage);
const auto CinvAfmab = computeRightHandSide(alpha, beta, fn, CinvA);
const real4 CinvABmD = *(memoryManager.retrieveStorage(CinvABmD_storage));
const auto CinvABmD_it = memoryManager.retrieveStorage(CinvABmD_storage);
U CinvABmD[4] = {CinvABmD_it[0], CinvABmD_it[1],
CinvABmD_it[2], CinvABmD_it[3]};

const auto c0d0 = solveSubsystem(CinvABmD, CinvAfmab);
return c0d0;
}

private:

template<class T>
__device__ thrust::pair<T,T> solveSubsystem(real4 CinvABmD, thrust::pair<T,T> CinvAfmab) const{
auto c0d0 = solve2x2System(CinvABmD, CinvAfmab);
template<class T, class T2>
__device__ thrust::pair<T,T> solveSubsystem(T2 CinvABmD[4], thrust::pair<T,T> CinvAfmab) const{
auto c0d0 = solve2x2System<T,T2>(CinvABmD, CinvAfmab);
return c0d0;
}

Expand All @@ -113,10 +117,11 @@ namespace uammd{

};

template<typename U>
class PentadiagonalSystemSolver{
int nz;
real H;
KBPENTA_mod pentasolve;
KBPENTA_mod<U> pentasolve;
public:

PentadiagonalSystemSolver(int nz, real H):
Expand All @@ -126,12 +131,12 @@ namespace uammd{
pentasolve.registerRequiredStorage(memoryManager);
}

void precompute(real k, StorageRetriever &memoryManager){
real diagonal[nz];
real diagonal_p2[nz]; diagonal_p2[nz-2] = diagonal_p2[nz-1] = 0;
real diagonal_m2[nz]; diagonal_m2[0] = diagonal_m2[1] = 0;
void precompute(U k, StorageRetriever &memoryManager){
U diagonal[nz];
U diagonal_p2[nz]; diagonal_p2[nz-2] = diagonal_p2[nz-1] = 0;
U diagonal_m2[nz]; diagonal_m2[0] = diagonal_m2[1] = 0;
SecondIntegralMatrix sim(nz);
const real kH2 = k*k*H*H;
const U kH2 = k*k*H*H;
for(int i = 0; i<nz; i++){
diagonal[i] = 1.0 - kH2*sim.getElement(i,i);
if(i<nz-2) diagonal_p2[i] = -kH2*sim.getElement(i+2, i);
Expand All @@ -148,23 +153,24 @@ namespace uammd{
};
}

template<typename U>
class BoundaryValueProblemSolver{
detail::PentadiagonalSystemSolver pent;
detail::SubsystemSolver sub;
StorageHandle<real> waveVector;
detail::PentadiagonalSystemSolver<U> pent;
detail::SubsystemSolver<U> sub;
StorageHandle<U> waveVector;
int nz;
real H;
public:
BoundaryValueProblemSolver(int nz, real H): nz(nz), H(H), sub(nz, H), pent(nz, H){}

void registerRequiredStorage(StorageRegistration &mem){
waveVector = mem.registerStorageRequirement<real>(1);
waveVector = mem.registerStorageRequirement<U>(1);
sub.registerRequiredStorage(mem);
pent.registerRequiredStorage(mem);
}

template<class TopBC, class BottomBC>
void precompute(StorageRetriever &mem, real k, const TopBC &top, const BottomBC &bot){
void precompute(StorageRetriever &mem, U k, const TopBC &top, const BottomBC &bot){
auto k_access = mem.retrieveStorage(waveVector);
k_access[0] = k;
pent.precompute(k, mem);
Expand All @@ -177,10 +183,10 @@ namespace uammd{
AnIterator& an,
CnIterator& cn,
StorageRetriever &mem){
const real k = *(mem.retrieveStorage(waveVector));
const auto k = *(mem.retrieveStorage(waveVector));
T c0, d0;
thrust::tie(c0, d0) = sub.solve(fn, alpha, beta, mem);
const real kH2 = k*k*H*H;
const auto kH2 = k*k*H*H;
fn[0] += kH2*c0;
fn[1] += kH2*d0;
pent.solve(fn, an, mem);
Expand All @@ -201,15 +207,17 @@ namespace uammd{

};

template<typename U>
class BatchedBVPHandler;

template<typename U>
struct BatchedBVPGPUSolver{
private:
int numberSystems;
BoundaryValueProblemSolver bvpSolver;
BoundaryValueProblemSolver<U> bvpSolver;
char* gpuMemory;
friend class BatchedBVPHandler;
BatchedBVPGPUSolver(int numberSystems, BoundaryValueProblemSolver bvpSolver, char *raw):
friend class BatchedBVPHandler<U>;
BatchedBVPGPUSolver(int numberSystems, BoundaryValueProblemSolver<U> bvpSolver, char *raw):
numberSystems(numberSystems), bvpSolver(bvpSolver), gpuMemory(raw){}
public:

Expand All @@ -225,9 +233,10 @@ namespace uammd{

};

template <typename U>
class BatchedBVPHandler{
int numberSystems;
BoundaryValueProblemSolver bvp;
BoundaryValueProblemSolver<U> bvp;
thrust::device_vector<char> gpuMemory;
public:

Expand All @@ -240,9 +249,9 @@ namespace uammd{
precompute(klist, top, bot);
}

BatchedBVPGPUSolver getGPUSolver(){
BatchedBVPGPUSolver<U> getGPUSolver(){
auto raw = thrust::raw_pointer_cast(gpuMemory.data());
BatchedBVPGPUSolver d_solver(numberSystems, bvp, raw);
BatchedBVPGPUSolver<U> d_solver(numberSystems, bvp, raw);
return d_solver;
}

Expand Down
9 changes: 5 additions & 4 deletions src/misc/BoundaryValueProblem/KBPENTA.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,21 @@ namespace uammd{
namespace BVP{

//Algorithm adapted from http://dx.doi.org/10.1080/00207160802326507 for a special case of only three diagonals being non zero
template<typename U>
class KBPENTA_mod{
StorageHandle<real> storageHandle;
StorageHandle<U> storageHandle;
int nz;
public:

KBPENTA_mod(int nz): nz(nz){}

void registerRequiredStorage(StorageRegistration &memoryManager){
storageHandle = memoryManager.registerStorageRequirement<real>(3*nz+2);
storageHandle = memoryManager.registerStorageRequirement<U>(3*nz+2);
}

void store(real *diagonal, real *diagonal_p2, real *diagonal_m2, StorageRetriever &memoryManager){
void store(U *diagonal, U *diagonal_p2, U *diagonal_m2, StorageRetriever &memoryManager){
auto storage = memoryManager.retrieveStorage(storageHandle);
std::vector<real> beta(nz+1, 0);
std::vector<U> beta(nz+1, 0);
beta[0] = 0;
beta[1] = diagonal[nz-nz];
beta[2] = diagonal[nz-(nz-1)];
Expand Down
Loading
Loading