Skip to content

Commit

Permalink
added batch dimension in solver class
Browse files Browse the repository at this point in the history
  • Loading branch information
scaomath committed May 6, 2024
1 parent c5f842b commit 2f5dea8
Show file tree
Hide file tree
Showing 4 changed files with 378 additions and 199 deletions.
274 changes: 191 additions & 83 deletions example_Kolmogrov2d_rk4_cn_forced_turbulence.ipynb

Large diffs are not rendered by default.

61 changes: 40 additions & 21 deletions torch_cfd/equations.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,15 @@
# Modifications copyright (C) 2024 S.Cao
# ported Google's Jax-CFD functional template to PyTorch's tensor ops

from typing import Tuple, Callable, Dict, Optional
from typing import Callable, Dict, Optional, Tuple

import torch
import torch.nn as nn
import torch.fft as fft
from . import grids
import torch.nn as nn
from tqdm.auto import tqdm

from . import grids

TQDM_ITERS = 500

Array = torch.Tensor
Expand Down Expand Up @@ -55,7 +56,7 @@ def stable_time_step(
dt_diffusion = dx

if not implicit_diffusion:
dt_diffusion = dx ** 2 / (viscosity * 2 ** (ndim))
dt_diffusion = dx**2 / (viscosity * 2 ** (ndim))
dt_advection = max_courant_number * dx / max_velocity
dt = dt_advection if dt is None else dt
return min(dt_diffusion, dt_advection, dt)
Expand Down Expand Up @@ -264,7 +265,7 @@ def crank_nicolson_rk4(
)


class NavierStokes2DSpectral(nn.Module):
class NavierStokes2DSpectral(ImplicitExplicitODE):
"""Breaks the Navier-Stokes equation into implicit and explicit parts.
Implicit parts are the linear terms and explicit parts are the non-linear
Expand Down Expand Up @@ -357,12 +358,13 @@ def vorticity_to_velocity(
"""
kx, ky = rfft_mesh if rfft_mesh is not None else grid.rfft_mesh()
two_pi_i = 2 * torch.pi * 1j
assert kx.shape[-2:] == w_hat.shape[-2:]
laplace = two_pi_i**2 * (abs(kx) ** 2 + abs(ky) ** 2)
laplace[0, 0] = 1
laplace[..., 0, 0] = 1
psi_hat = -1 / laplace * w_hat
vxhat = two_pi_i * ky * psi_hat
vyhat = -two_pi_i * kx * psi_hat
return vxhat, vyhat
u_hat = two_pi_i * ky * psi_hat
v_hat = -two_pi_i * kx * psi_hat
return u_hat, v_hat

def residual(
self,
Expand Down Expand Up @@ -426,8 +428,15 @@ def get_trajectory(
):
"""
vorticity stacked in the time dimension
all inputs and outputs are in the frequency domain
input: w0 (*, n, n)
output:
vorticity (*, n_t, kx, ky)
velocity: tuple of (*, n_t, kx, ky)
"""
w_all = []
u_all = []
v_all = []
dwdt_all = []
res_all = []
Expand All @@ -445,30 +454,40 @@ def get_trajectory(
pbar.update(update_iters)

if t % record_every_steps == 0:
w_ = w.detach().clone()
dwdt_ = dwdt.detach().clone()
v = self.vorticity_to_velocity(self.grid, w_)
res = self.residual(w_, dwdt_)
u, v = self.vorticity_to_velocity(self.grid, w)
res = self.residual(w, dwdt)

w_, dwdt_, u, v, res = [
var.detach().cpu().clone() for var in [w, dwdt, u, v, res]
]

v = torch.stack(v, dim=0)
w_all.append(w_)
u_all.append(u)
v_all.append(v)
dwdt_all.append(dwdt_)
res_all.append(res)

result = {
var_name: torch.stack(var, dim=0)
var_name: torch.stack(var, dim=-3)
for var_name, var in zip(
["vorticity", "velocity", "vort_t", "residual"],
[w_all, v_all, dwdt_all, res_all],
["vorticity", "u", "v", "vort_t", "residual"],
[w_all, u_all, v_all, dwdt_all, res_all],
)
}
return result

def step(self, *args, **kwargs):
return self.forward(*args, **kwargs)

def forward(self, vort_hat, dt):
vort_hat_new = self.solver(vort_hat, dt, self)
dvortdt_hat = 1 / dt * (vort_hat_new - vort_hat)
return vort_hat_new, dvortdt_hat
def forward(self, vort_hat, dt, steps=1):
"""
vort_hat: (B, kx, ky) or (n_t, kx, ky) or (kx, ky)
- if rfft2 is used then the shape is (*, kx, ky//2+1)
- if (n_t, kx, ky), then the time step marches in the time
dimension in parallel.
"""
vort_old = vort_hat
for _ in range(steps):
vort_hat = self.solver(vort_hat, dt, self)
dvortdt_hat = 1 / (steps * dt) * (vort_hat - vort_old)
return vort_hat, dvortdt_hat
111 changes: 16 additions & 95 deletions torch_cfd/initial_conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@
# ported Google's Jax-CFD functional template to PyTorch's tensor ops

"""Prepare initial conditions for simulations."""
from typing import Callable, Optional, Sequence
import math
from typing import Callable, Optional, Sequence

import torch
import torch.fft as fft
from . import grids
from . import finite_differences as fd
from . import fast_diagonalization as solver

from . import grids, pressure

Array = torch.Tensor
GridArray = grids.GridArray
Expand All @@ -45,6 +45,7 @@ def wrap_velocities(
for u, offset, bc in zip(v, grid.cell_faces, bcs)
)


def wrap_vorticity(
w: Array,
grid: grids.Grid,
Expand All @@ -57,7 +58,11 @@ def wrap_vorticity(


def _log_normal_density(k, mode: float, variance=0.25):
"""Unscaled PDF for a log normal given `mode` and log variance 1."""
"""
Unscaled PDF for a log normal given `mode` and log variance 1.
"""
mean = math.log(mode) + variance
logk = torch.log(k)
return torch.exp(-((mean - logk) ** 2) / 2 / variance - logk)
Expand All @@ -74,6 +79,7 @@ def McWilliams_density(k, mode: float, tau: float = 1.0):
"""
return (k * (tau**2 + (k / mode) ** 4)) ** (-1)


def _angular_frequency_magnitude(grid: grids.Grid) -> Array:
frequencies = [
2 * torch.pi * fft.fftfreq(size, step)
Expand All @@ -95,103 +101,19 @@ def spectral_filter(
# real, because our spectral density only depends on norm(k).
return fft.ifftn(fft.fftn(v) * filters).real


def streamfunc_normalize(k, psi):
# only half the spectrum for real ffts, needs spectral normalisation
nx, ny = psi.shape
psih = fft.fft2(psi)
uh = k * psih
kinetic_energy = (2 * uh.abs() ** 2 / (nx * ny) ** 2).sum()
uh_mag = k * psih
kinetic_energy = (2 * uh_mag.abs() ** 2 / (nx * ny) ** 2).sum()
return psi / kinetic_energy.sqrt()

def _rhs_transform(
u: GridArray,
bc: BoundaryConditions,
) -> Array:
"""Transform the RHS of pressure projection equation for stability.
In case of poisson equation, the kernel is subtracted from RHS for stability.
Args:
u: a GridArray that solves ∇²x = u.
bc: specifies boundary of x.
Returns:
u' s.t. u = u' + kernel of the laplacian.
"""
u_data = u.data
for axis in range(u.grid.ndim):
if (
bc.types[axis][0] == grids.BCType.NEUMANN
and bc.types[axis][1] == grids.BCType.NEUMANN
):
# if all sides are neumann, poisson solution has a kernel of constant
# functions. We substact the mean to ensure consistency.
u_data = u_data - torch.mean(u_data)
return u_data


def solve_fast_diag(
v: GridVariableVector,
q0: Optional[GridVariable] = None,
pressure_bc: Optional[grids.ConstantBoundaryConditions] = None,
implementation: Optional[str] = None,
) -> GridArray:
"""Solve for pressure using the fast diagonalization approach."""
del q0 # unused
if pressure_bc is None:
pressure_bc = grids.get_pressure_bc_from_velocity(v)
if grids.has_all_periodic_boundary_conditions(*v):
circulant = True
else:
circulant = False
# only matmul implementation supports non-circulant matrices
implementation = "matmul"
grid = grids.consistent_grid(*v)
rhs = fd.divergence(v)
laplacians = list(map(fd.laplacian_matrix, grid.shape, grid.step))
laplacians = [lap.to(grid.device) for lap in laplacians]
rhs_transformed = _rhs_transform(rhs, pressure_bc)
pinv = solver.pseudoinverse(
rhs_transformed,
laplacians,
rhs_transformed.dtype,
hermitian=True,
circulant=circulant,
implementation=implementation,
)
# return applied(pinv)(rhs_transformed)
return GridArray(pinv, rhs.offset, rhs.grid)


def projection(
v: GridVariableVector,
solve: Callable = solve_fast_diag,
) -> GridVariableVector:
"""
Apply pressure projection (a discrete Helmholtz decomposition)
to make a velocity field divergence free.
Note by S.Cao: this was originally implemented by the jax-cfd team
but using FDM results having a non-negligible error in fp32.
One resolution is to use fp64 then cast back to fp32.
"""
grid = grids.consistent_grid(*v)
pressure_bc = grids.get_pressure_bc_from_velocity(v)

q0 = GridArray(torch.zeros(grid.shape), grid.cell_center, grid)
q0 = pressure_bc.impose_bc(q0)

q = solve(v, q0, pressure_bc)
q = pressure_bc.impose_bc(q)
q_grad = fd.forward_difference(q)
v_projected = tuple(u.bc.impose_bc(u.array - q_g) for u, q_g in zip(v, q_grad))
return v_projected


def project_and_normalize(
v: GridVariableVector, maximum_velocity: float = 1
) -> GridVariableVector:
v = projection(v)
v = pressure.projection(v)
vmax = torch.linalg.norm(torch.stack([u.data for u in v]), dim=0).max()
v = tuple(GridVariable(maximum_velocity * u.array / vmax, u.bc) for u in v)
return v
Expand Down Expand Up @@ -256,7 +178,6 @@ def vorticity_field(
Args:
rng_key: key for seeding the random initial vorticity field.
grid: the grid on which the vorticity field is defined.
maximum_velocity: the maximum speed in the velocity field.
peak_wavenumber: the velocity field will be filtered so that the largest
magnitudes are associated with this wavenumber.
Expand All @@ -277,4 +198,4 @@ def spectral_density(k):
boundary_condition = grids.periodic_boundary_conditions(grid.ndim)
vorticity = wrap_vorticity(vorticity, grid, boundary_condition)

return vorticity
return vorticity
Loading

0 comments on commit 2f5dea8

Please sign in to comment.