forked from tslund/spectralDNS
-
Notifications
You must be signed in to change notification settings - Fork 1
/
spectralDNS3D_short.py
113 lines (104 loc) · 3.33 KB
/
spectralDNS3D_short.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
__author__ = "Mikael Mortensen <[email protected]>"
__date__ = "2015-01-02"
__copyright__ = "Copyright (C) 2014-2018 " + __author__
__license__ = "GNU Lesser GPL version 3 or any later version"
#pylint: disable=reimported,redefined-outer-name
from time import time
from numpy import *
from numpy.fft import fftfreq, fft, ifft, irfft2, rfft2
from mpi4py import MPI
try:
from pyfftw.interfaces.numpy_fft import fft, ifft, irfft2, rfft2
import pyfftw
pyfftw.interfaces.cache.enable()
except ImportError:
pass
# ---------------------------------------------------------------------------
# Problem parameters
# ---------------------------------------------------------------------------
nu = 0.000625   # coefficient of the K2*U_hat (diffusion) term in ComputeRHS
T = 0.1         # end time of the integration
dt = 0.01       # time step
N = 2**6        # number of grid points in each direction

# ---------------------------------------------------------------------------
# MPI slab decomposition: every rank owns Np planes along the first axis of
# the physical mesh and Np planes along the second axis in spectral space.
# ---------------------------------------------------------------------------
comm = MPI.COMM_WORLD
num_processes = comm.Get_size()
rank = comm.Get_rank()
# The transposes in fftn_mpi/ifftn_mpi reshape an axis of length N into
# (num_processes, Np); that only works when N splits evenly across ranks.
# Fail early with a clear message instead of a reshape error deep inside
# the first transform.
if N % num_processes != 0:
    raise ValueError(
        "N = {} must be divisible by the number of MPI processes ({})".format(
            N, num_processes))
Np = N // num_processes

# Local part of the physical mesh, scaled so that N points span [0, 2*pi).
X = mgrid[rank*Np:(rank+1)*Np, :N, :N].astype(float)*2*pi/N

# Physical-space fields have local shape (Np, N, N); spectral-space fields
# have local shape (N, Np, N//2+1) because the last axis is transformed
# with a real-to-complex FFT.
U = empty((3, Np, N, N))
U_hat = empty((3, N, Np, N//2+1), dtype=complex)
P = empty((Np, N, N))
P_hat = empty((N, Np, N//2+1), dtype=complex)
U_hat0 = empty((3, N, Np, N//2+1), dtype=complex)   # solution at start of a step
U_hat1 = empty((3, N, Np, N//2+1), dtype=complex)   # accumulator for the new solution
dU = empty((3, N, Np, N//2+1), dtype=complex)       # right-hand side
Uc_hat = empty((N, Np, N//2+1), dtype=complex)      # workspace used by ifftn_mpi
Uc_hatT = empty((Np, N, N//2+1), dtype=complex)     # workspace used by both transforms
curl = empty((3, Np, N, N))

# Integer wavenumbers. fftfreq puts -N/2 at index N//2, so the half-spectrum
# axis ends on -N/2; flip its sign so that the last entry is +N/2.
kx = fftfreq(N, 1./N)
kz = kx[:(N//2+1)].copy()
kz[-1] *= -1
# Local wavenumber mesh: the second axis is the one distributed over ranks.
K = array(meshgrid(kx, kx[rank*Np:(rank+1)*Np], kz, indexing='ij'), dtype=int)
K2 = sum(K*K, 0, dtype=int)
# Replace K2 == 0 by 1 in the denominator to avoid dividing by zero at the
# zero mode (K itself is zero there, so the quotient is zero).
K_over_K2 = K.astype(float) / where(K2 == 0, 1, K2).astype(float)

# Dealiasing mask: keep only modes whose wavenumber magnitude is below
# 2/3*(N//2+1) in every direction.
kmax_dealias = 2./3.*(N//2+1)
dealias = array((abs(K[0]) < kmax_dealias)*(abs(K[1]) < kmax_dealias)*
                (abs(K[2]) < kmax_dealias), dtype=bool)

# Coefficients of the four-stage Runge-Kutta scheme used in the time loop:
# a[rk] weights each stage in the final update, b[rk] scales the
# intermediate solution fed to the next stage.
a = [1./6., 1./3., 1./3., 1./6.]
b = [0.5, 0.5, 1.]
def fftn_mpi(u, fu):
    """Forward 3D FFT of a slab-distributed real field.

    Parameters
    ----------
    u : real array holding the local physical-space slab; in this file it
        is always of shape (Np, N, N).
    fu : complex array that receives the result; its shape (N, Np, N//2+1
        in this file) is used as the target of the transposed reshape.

    Returns
    -------
    ``fu``, overwritten in place with the transform of ``u``.

    The module-level buffer ``Uc_hatT`` is used as workspace.
    """
    n_half = N//2 + 1

    # Transform the two axes that are fully available on this rank.
    Uc_hatT[:] = rfft2(u, axes=(1, 2))

    # Split axis 1 into one block of Np planes per rank and move the block
    # index to the front, so that the data destined for each rank forms one
    # contiguous chunk of fu.
    blocks = Uc_hatT.reshape(Np, num_processes, Np, n_half)
    fu[:] = blocks.transpose(1, 0, 2, 3).reshape(fu.shape)

    # Exchange the chunks between all ranks, reusing fu as send and
    # receive buffer.
    comm.Alltoall(MPI.IN_PLACE, [fu, MPI.DOUBLE_COMPLEX])

    # Axis 0 is now complete locally; finish with the remaining 1D FFT.
    fu[:] = fft(fu, axis=0)
    return fu
def ifftn_mpi(fu, u):
    """Inverse 3D FFT from a slab-distributed spectrum to a real field.

    Performs the steps of ``fftn_mpi`` in reverse order.

    Parameters
    ----------
    fu : complex array holding the local spectral-space slab; in this file
        it is always of shape (N, Np, N//2+1). It is not modified.
    u : real array that receives the result (shape (Np, N, N) in this file).

    Returns
    -------
    ``u``, overwritten in place with the inverse transform of ``fu``.

    The module-level buffers ``Uc_hat`` and ``Uc_hatT`` are used as
    workspace.
    """
    n_half = N//2 + 1

    # Undo the transform along axis 0, which is complete on this rank.
    Uc_hat[:] = ifft(fu, axis=0)

    # Exchange equal chunks of the buffer between all ranks in place.
    comm.Alltoall(MPI.IN_PLACE, [Uc_hat, MPI.DOUBLE_COMPLEX])

    # View the buffer as one block of Np planes per rank, then swap the
    # first two axes so the blocks line up along axis 1 of the workspace.
    blocks = Uc_hat.reshape((num_processes, Np, Np, n_half))
    Uc_hatT[:] = blocks.transpose(1, 0, 2, 3).reshape(Uc_hatT.shape)

    # Complex-to-real transform over the two now-complete axes.
    u[:] = irfft2(Uc_hatT, axes=(1, 2))
    return u
def Cross(a, b, c):
    """Compute the cross product of two physical-space vector fields and
    store its forward transform.

    Parameters
    ----------
    a, b : three-component fields indexed by component along axis 0.
    c : three-component spectral-space array; component ``p`` is
        overwritten with ``fftn_mpi`` applied to component ``p`` of
        ``a x b``.

    Returns
    -------
    ``c``, updated in place.
    """
    # Cyclic index triples: (a x b)[p] = a[q]*b[r] - a[r]*b[q].
    for p, q, r in ((0, 1, 2), (1, 2, 0), (2, 0, 1)):
        c[p] = fftn_mpi(a[q]*b[r] - a[r]*b[q], c[p])
    return c
#@profile
def Curl(a, c):
    """Compute the curl of a spectral-space vector field and store it in
    physical space.

    Parameters
    ----------
    a : three-component spectral-space field indexed by component along
        axis 0.
    c : three-component physical-space array; component ``p`` is
        overwritten with ``ifftn_mpi`` applied to
        ``1j*(K[q]*a[r] - K[r]*a[q])``, where ``K`` is the module-level
        wavenumber mesh.

    Returns
    -------
    ``c``, updated in place.
    """
    # Cyclic index triples, visited in the order z, y, x.
    for p, q, r in ((2, 0, 1), (1, 2, 0), (0, 1, 2)):
        c[p] = ifftn_mpi(1j*(K[q]*a[r] - K[r]*a[q]), c[p])
    return c
#@profile
def ComputeRHS(dU, rk):
    """Assemble the spectral-space right-hand side for one Runge-Kutta stage.

    Parameters
    ----------
    dU : complex array with three components that receives the result.
    rk : index of the current Runge-Kutta stage (0-based).

    Returns
    -------
    ``dU``, overwritten in place.

    Reads the module-level ``U_hat`` and updates the module-level ``U``,
    ``curl`` and ``P_hat`` as side effects.
    """
    # For the first stage U is already consistent with U_hat: the script
    # sets both up before the time loop and refreshes U at the end of every
    # step. Later stages have modified U_hat, so U must be recomputed.
    if rk > 0:
        for u_hat_i, u_i in zip(U_hat, U):
            u_i[:] = ifftn_mpi(u_hat_i, u_i)

    # Nonlinear term: transform of U x curl(U), then drop the modes
    # excluded by the dealias mask.
    curl[:] = Curl(U_hat, curl)
    dU = Cross(U, curl, dU)
    dU *= dealias

    # Pressure: P_hat = sum_i dU[i]*K[i]/K2; subtract its gradient P_hat*K.
    P_hat[:] = sum(dU*K_over_K2, axis=0, out=P_hat)
    pressure_gradient = P_hat*K
    dU -= pressure_gradient

    # Diffusion term.
    dU -= nu*K2*U_hat
    return dU
# ---------------------------------------------------------------------------
# Initial condition (Taylor-Green vortex) and its spectral representation.
# ---------------------------------------------------------------------------
U[0] = sin(X[0])*cos(X[1])*cos(X[2])
U[1] = -cos(X[0])*sin(X[1])*cos(X[2])
U[2] = 0
for i in range(3):
    U_hat[i] = fftn_mpi(U[i], U_hat[i])

# ---------------------------------------------------------------------------
# Time integration with a four-stage Runge-Kutta scheme.
# ---------------------------------------------------------------------------
t = 0.0
tstep = 0
t0 = time()
while t < T-1e-8:   # small tolerance guards against round-off in t += dt
    t += dt
    tstep += 1

    # Keep the solution at the start of the step (U_hat0) and start the
    # accumulator for the new solution (U_hat1) from the same state.
    U_hat0[:] = U_hat
    U_hat1[:] = U_hat

    for rk in range(4):
        dU = ComputeRHS(dU, rk)
        # b holds one coefficient per intermediate stage, so the last
        # stage only contributes to the accumulator.
        if rk < len(b):
            U_hat[:] = U_hat0 + b[rk]*dt*dU
        U_hat1[:] += a[rk]*dt*dU

    U_hat[:] = U_hat1

    # Bring U back in sync with U_hat; ComputeRHS relies on this for the
    # first stage of the next step.
    for i in range(3):
        U[i] = ifftn_mpi(U_hat[i], U[i])

# ---------------------------------------------------------------------------
# Report timing and check the final kinetic energy.
# ---------------------------------------------------------------------------
# comm.reduce sums the per-rank contributions; the result is only used on
# rank 0.
k = comm.reduce(0.5*sum(U*U)*(1./N)**3)
if rank == 0:
    print("Time = {}".format(time()-t0))
    # Regression check against the reference value for the parameters
    # hard-coded in this script.
    assert round(k - 0.124953117517, 7) == 0