Skip to content

Commit

Permalink
Changed forcing to an nn.Module template class
Browse files Browse the repository at this point in the history
  • Loading branch information
scaomath committed Apr 29, 2024
1 parent 20608ac commit 8dbb1d5
Show file tree
Hide file tree
Showing 5 changed files with 268 additions and 64 deletions.
38 changes: 21 additions & 17 deletions example_Kolmogrov2d_rk4_cn_forced_turbulence.ipynb

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
numpy==1.24.4
torch>=2.0.1
torch>=2.2.0
xarray>=2023.1.0
tqdm>=4.62.0
51 changes: 36 additions & 15 deletions torch_cfd/equations.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# Modifications copyright (C) 2024 S.Cao
# ported Google's Jax-CFD functional template to PyTorch's tensor ops

from typing import Callable, Dict, Optional
from typing import Tuple, Callable, Dict, Optional

import torch
import torch.nn as nn
Expand Down Expand Up @@ -148,7 +148,7 @@ def crank_nicolson_rk4(
)


class NavierStokes2D(nn.Module):
class NavierStokes2DSpectral(nn.Module):
"""Breaks the Navier-Stokes equation into implicit and explicit parts.
Implicit parts are the linear terms and explicit parts are the non-linear
Expand Down Expand Up @@ -201,14 +201,24 @@ def brick_wall_filter_2d(grid: Grid):
return filter_

@staticmethod
def spectral_curl_2d(mesh, velocity_hat):
"""Computes the 2D curl in the Fourier basis."""
kx, ky = mesh
uhat, vhat = velocity_hat
def spectral_curl_2d(vhat, rfft_mesh):
r"""
Computes the 2D curl in the Fourier basis.
det [d_x d_y \\ u v]
"""
uhat, vhat = vhat
kx, ky = rfft_mesh
return 2j * torch.pi * (vhat * kx - uhat * ky)

@staticmethod
def vorticity_to_velocity(grid: Grid, w_hat: Array):
def spectral_grad_2d(vhat, rfft_mesh):
kx, ky = rfft_mesh
return 2j * torch.pi * kx * vhat, 2j * torch.pi * ky * vhat

@staticmethod
def vorticity_to_velocity(
grid: Grid, w_hat: Array, rfft_mesh: Optional[Tuple[Array, Array]] = None
):
"""Constructs a function for converting vorticity to velocity, both in Fourier domain.
Solves for the stream function and then uses the stream function to compute
Expand All @@ -229,19 +239,24 @@ def vorticity_to_velocity(grid: Grid, w_hat: Array):
Pages 509-520, ISSN 0045-7930,
https://doi.org/10.1016/j.compfluid.2003.06.003.
"""
device = w_hat.device
kx, ky = grid.rfft_mesh()
kx, ky = kx.to(device), ky.to(device)
kx, ky = rfft_mesh if rfft_mesh is not None else grid.rfft_mesh()
two_pi_i = 2 * torch.pi * 1j
laplace = two_pi_i**2 * (abs(kx) ** 2 + abs(ky) ** 2)
laplace[0, 0] = 1
psi_hat = -1 / laplace * w_hat
vxhat = two_pi_i * ky * psi_hat
vyhat = -two_pi_i * kx * psi_hat
return vxhat, vyhat

def residual(self,
vort_hat: Array,
vort_t_hat: Array,
):
residual = vort_t_hat - self.explicit_terms(vort_hat) - self.viscosity * self.implicit_terms(vort_hat)
return residual

def _explicit_terms(self, vort_hat):
vxhat, vyhat = self.vorticity_to_velocity(self.grid, vort_hat)
vxhat, vyhat = self.vorticity_to_velocity(self.grid, vort_hat, (self.kx, self.ky))
vx, vy = fft.irfft2(vxhat), fft.irfft2(vyhat)

grad_x_hat = 2j * torch.pi * self.kx * vort_hat
Expand All @@ -251,15 +266,15 @@ def _explicit_terms(self, vort_hat):
advection = -(grad_x * vx + grad_y * vy)
advection_hat = fft.rfft2(advection)

if self.smooth is not None:
if self.smooth:
advection_hat *= self.filter

terms = advection_hat

if self.forcing_fn is not None:
fx, fy = self.forcing_fn(self.grid, (vx, vy))
fx_hat, fy_hat = fft.rfft2(fx.data), fft.rfft2(fy.data)
terms += self.spectral_curl_2d((self.kx, self.ky), (fx_hat, fy_hat))
terms += self.spectral_curl_2d((fx_hat, fy_hat), (self.kx, self.ky))

return terms

Expand All @@ -276,7 +291,7 @@ def get_trajectory(
self,
w0: Array,
dt: float,
time_steps: int,
T: float,
record_every_steps=1,
pbar=False,
pbar_desc="",
Expand All @@ -288,7 +303,9 @@ def get_trajectory(
w_all = []
v_all = []
dwdt_all = []
res_all = []
w = w0
time_steps = int(T / dt)
update_iters = time_steps // TQDM_ITERS
with tqdm(total=time_steps) as pbar:
for t in range(time_steps):
Expand All @@ -304,14 +321,18 @@ def get_trajectory(
w_ = w.detach().clone()
dwdt_ = dwdt.detach().clone()
v = self.vorticity_to_velocity(self.grid, w_)
res = self.residual(w_, dwdt_)

v = torch.stack(v, dim=0)
w_all.append(w_)
v_all.append(v)
dwdt_all.append(dwdt_)
res_all.append(res)

result = {
var_name: torch.stack(var, dim=0)
for var_name, var in zip(
["vorticity", "velocity", "vort_t"], [w_all, v_all, dwdt_all]
["vorticity", "velocity", "vort_t", "residual"], [w_all, v_all, dwdt_all, res_all]
)
}
return result
Expand Down
229 changes: 203 additions & 26 deletions torch_cfd/forcings.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,36 +15,213 @@
# Modifications copyright (C) 2024 S.Cao
# ported Google's Jax-CFD functional template to PyTorch's tensor ops

from typing import Tuple, Optional
from typing import Optional, Tuple

import torch
import torch.nn as nn

from . import grids

Array = torch.Tensor
Grid = grids.Grid
GridArray = grids.GridArray

def kolmogorov_forcing(
grid: Grid,
v: Tuple[Array, Array],
scale: float = 1,
k: int = 2,
swap_xy: bool = False,
offsets: Optional[Tuple[Tuple[float, ...], ...]] = None,
device: Optional[torch.device] = None,
) -> Array:
"""Returns the Kolmogorov forcing function for turbulence in 2D."""
if offsets is None:
offsets = grid.cell_faces
if grid.device is None and device is not None:
grid.device = device
if swap_xy:
x = grid.mesh(offsets[1])[0]
v = GridArray(scale * torch.sin(k * x), offsets[1], grid)
u = GridArray(torch.zeros_like(v.data), (1, 1 / 2), grid)
f = (u, v)
else:
y = grid.mesh(offsets[0])[1]
u = GridArray(scale * torch.sin(k * y), offsets[0], grid)
v = GridArray(torch.zeros_like(u.data), (1 / 2, 1), grid)
f = (u, v)
return f

class ForcingFn(nn.Module):
"""
A meta class for forcing functions
"""

def __init__(
self,
grid: Grid,
scale: float = 1,
k: int = 1,
diam: float = 1.0,
swap_xy: bool = False,
offsets: Optional[Tuple[Tuple[float, ...], ...]] = None,
device: Optional[torch.device] = None,
**kwargs,
):
super().__init__()
self.grid = grid
self.scale = scale
self.k = k
self.diam = diam
self.swap_xy = swap_xy
self.offsets = grid.cell_faces if offsets is None else offsets
self.device = grid.device if device is None else device


class KolmogorovForcing(ForcingFn):
"""
The Kolmogorov forcing function used in
Sets up the flow that is used in Kochkov et al. [1].
which is based on Boffetta et al. [2].
Note in the port: this forcing belongs a larger class
of isotropic turbulence. See [3].
References:
[1] Machine learning-accelerated computational fluid dynamics. Dmitrii
Kochkov, Jamie A. Smith, Ayya Alieva, Qing Wang, Michael P. Brenner, Stephan
Hoyer Proceedings of the National Academy of Sciences May 2021, 118 (21)
e2101784118; DOI: 10.1073/pnas.2101784118.
https://doi.org/10.1073/pnas.2101784118
[2] Boffetta, Guido, and Robert E. Ecke. "Two-dimensional turbulence."
Annual review of fluid mechanics 44 (2012): 427-451.
https://doi.org/10.1146/annurev-fluid-120710-101240
[3] McWilliams, J. C. (1984). "The emergence of isolated coherent vortices
in turbulent flow". Journal of Fluid Mechanics, 146, 21-43.
"""

def __init__(
self,
diam=2 * torch.pi,
offsets=((0, 0), (0, 0)),
*args,
**kwargs,
):
super().__init__(
*args,
diam=diam,
offsets=offsets,
**kwargs,
)

def forward(
self,
grid: Optional[Grid],
velocity: Optional[Tuple[Array, Array]] = None,
) -> Tuple[Array, Array]:
offsets = self.offsets
grid = self.grid if grid is None else grid
domain_factor = 2 * torch.pi / self.diam

if self.swap_xy:
x = grid.mesh(offsets[1])[0]
v = GridArray(
self.scale * torch.sin(self.k * domain_factor * x), offsets[1], grid
)
u = GridArray(torch.zeros_like(v.data), (1, 1 / 2), grid)
f = (u, v)
else:
y = grid.mesh(offsets[0])[1]
u = GridArray(
self.scale * torch.sin(self.k * domain_factor * y), offsets[0], grid
)
v = GridArray(torch.zeros_like(u.data), (1 / 2, 1), grid)
f = (u, v)
return f

def potential_template(potential_func):
def wrapper(cls, x: Array, y: Array, s: float, k: float) -> Array:
return potential_func(x, y, s, k)
return wrapper


class SimpleSolenoidalForcing(ForcingFn):
"""
A simple solenoidal (rotating, divergence free) forcing function template.
The template forcing is F = (-psi, psi) such that
Args:
grid: grid on which to simulate the flow
scale: a in the equation above, amplitude of the forcing
k: k in the equation above, wavenumber of the forcing
"""

def __init__(
self,
scale=1,
diam=1.0,
k=1.0,
offsets=((0, 0), (0, 0)),
*args,
**kwargs,
):
super().__init__(
*args,
scale=scale,
diam=diam,
k=k,
offsets=offsets,
**kwargs,
)


@potential_template
def potential(*args, **kwargs) -> Array:
raise NotImplementedError

def forward(
self,
grid: Optional[Grid],
velocity: Optional[Tuple[Array, Array]] = None,
) -> Tuple[Array, Array]:
offsets = self.offsets
grid = self.grid if grid is None else grid
domain_factor = 2 * torch.pi / self.diam
k = self.k * domain_factor
scale = 0.5 * self.scale / (2 * torch.pi) / self.k

if self.swap_xy:
x = grid.mesh(offsets[1])[0]
y = grid.mesh(offsets[0])[1]
rot = self.potential(x, y, scale, k)
v = GridArray(rot, offsets[1], grid)
u = GridArray(-rot, (1, 1 / 2), grid)
f = (u, v)
else:
x = grid.mesh(offsets[0])[0]
y = grid.mesh(offsets[1])[1]
rot = self.potential(x, y, scale, k)
u = GridArray(rot, offsets[0], grid)
v = GridArray(-rot, (1 / 2, 1), grid)
f = (u, v)
return f


class SinCosForcing(SimpleSolenoidalForcing):
"""
The solenoidal (divergence free) forcing function used in [4].
Note: in the vorticity-streamfunction formulation, the forcing
is actually the curl of the velocity field, which
is a*(sin(2*pi*k*(x+y)) + cos(2*pi*k*(x+y)))
a=0.1, k=1 in [4]
References:
[4] Li, Zongyi, et al. "Fourier Neural Operator for
Parametric Partial Differential Equations."
ICLR. 2020.
Args:
grid: grid on which to simulate the flow
scale: a in the equation above, amplitude of the forcing
k: k in the equation above, wavenumber of the forcing
"""

def __init__(
self,
scale=0.1,
diam=1.0,
k=1.0,
offsets=((0, 0), (0, 0)),
*args,
**kwargs,
):
super().__init__(
*args,
scale=scale,
diam=diam,
k=k,
offsets=offsets,
**kwargs,
)

@potential_template
def potential(x: Array, y: Array, s: float, k: float) -> Array:
return s * (torch.sin(k * (x + y)) - torch.cos(k * (x + y)))
Loading

0 comments on commit 8dbb1d5

Please sign in to comment.