From e0cd2a8dde9d90712c70e2209cc0ecb085600660 Mon Sep 17 00:00:00 2001 From: Felix Koehler Date: Thu, 16 May 2024 13:33:16 +0200 Subject: [PATCH] Rewrite derivative routines to use wrapped FFT calls --- exponax/_spectral.py | 214 +++++++++++++++++++++---------------------- 1 file changed, 107 insertions(+), 107 deletions(-) diff --git a/exponax/_spectral.py b/exponax/_spectral.py index 02531d0..2028bcb 100644 --- a/exponax/_spectral.py +++ b/exponax/_spectral.py @@ -71,113 +71,6 @@ def build_scaled_wavenumbers( return scale * wavenumbers -def derivative( - field: Float[Array, "C ... N"], - domain_extent: float, - *, - order: int = 1, - indexing: str = "ij", -) -> Union[Float[Array, "C D ... (N//2)+1"], Float[Array, "D ... (N//2)+1"]]: - """ - Perform the spectral derivative of a field. In higher dimensions, this - defaults to the gradient (the collection of all partial derivatives). In 1d, - the resulting channel dimension holds the derivative. If the function is - called with an d-dimensional field which has 1 channel, the result will be a - d-dimensional field with d channels (one per partial derivative). If the - field originally had C channels, the result will be a matrix field with C - rows and d columns. - - Note that applying this operator twice will produce issues at the Nyquist if - the number of degrees of freedom N is even. For this, consider also using - the order option. - - **Arguments:** - - `field`: The field to differentiate, shape `(C, ..., N,)`. `C` can be - `1` for a scalar field or `D` for a vector field. - - `L`: The domain extent. - - `order`: The order of the derivative. Default is `1`. - - `indexing`: The indexing scheme to use for `jax.numpy.meshgrid`. - Either `"ij"` or `"xy"`. Default is `"ij"`. - - **Returns:** - - `field_der`: The derivative of the field, shape `(C, D, ..., - (N//2)+1)` or `(D, ..., (N//2)+1)`. - """ - channel_shape = field.shape[0] - spatial_shape = field.shape[1:] - D = len(spatial_shape) - N = spatial_shape[0] - derivative_operator = build_derivative_operator( - D, domain_extent, N, indexing=indexing - ) - # # I decided to not use this fix - - # # Required for even N, no effect for odd N - # derivative_operator_fixed = ( - # derivative_operator * nyquist_filter_mask(D, N) - # ) - derivative_operator_fixed = derivative_operator**order - - field_hat = jnp.fft.rfftn(field, axes=space_indices(D)) - if channel_shape == 1: - # Do not introduce another channel axis - field_der_hat = derivative_operator_fixed * field_hat - else: - # Create a "derivative axis" right after the channel axis - field_der_hat = field_hat[:, None] * derivative_operator_fixed[None, ...] - - field_der = jnp.fft.irfftn(field_der_hat, s=spatial_shape, axes=space_indices(D)) - - return field_der - - -def make_incompressible( - field: Float[Array, "D ... N"], - *, - indexing: str = "ij", -): - channel_shape = field.shape[0] - spatial_shape = field.shape[1:] - num_spatial_dims = len(spatial_shape) - if channel_shape != num_spatial_dims: - raise ValueError( - f"Expected the number of channels to be {num_spatial_dims}, got {channel_shape}." - ) - num_points = spatial_shape[0] - - derivative_operator = build_derivative_operator( - num_spatial_dims, 1.0, num_points, indexing=indexing - ) # domain_extent does not matter because it will cancel out - - incompressible_field_hat = jnp.fft.rfftn( - field, axes=space_indices(num_spatial_dims) - ) - - divergence = jnp.sum( - derivative_operator * incompressible_field_hat, axis=0, keepdims=True - ) - - laplace_operator = build_laplace_operator(derivative_operator) - - inv_laplace_operator = jnp.where( - laplace_operator == 0, - 1.0, - 1.0 / laplace_operator, - ) - - pseudo_pressure = -inv_laplace_operator * divergence - - pseudo_pressure_garadient = derivative_operator * pseudo_pressure - - incompressible_field_hat = incompressible_field_hat - pseudo_pressure_garadient - - incompressible_field = jnp.fft.irfftn( - incompressible_field_hat, s=spatial_shape, axes=space_indices(num_spatial_dims) - ) - - return incompressible_field - - def build_derivative_operator( num_spatial_dims: int, domain_extent: float, @@ -535,3 +428,110 @@ def ifft( s=spatial_shape(num_spatial_dims, num_points), axes=space_indices(num_spatial_dims), ) + + +def derivative( + field: Float[Array, "C ... N"], + domain_extent: float, + *, + order: int = 1, + indexing: str = "ij", +) -> Union[Float[Array, "C D ... (N//2)+1"], Float[Array, "D ... (N//2)+1"]]: + """ + Perform the spectral derivative of a field. In higher dimensions, this + defaults to the gradient (the collection of all partial derivatives). In 1d, + the resulting channel dimension holds the derivative. If the function is + called with an d-dimensional field which has 1 channel, the result will be a + d-dimensional field with d channels (one per partial derivative). If the + field originally had C channels, the result will be a matrix field with C + rows and d columns. + + Note that applying this operator twice will produce issues at the Nyquist if + the number of degrees of freedom N is even. For this, consider also using + the order option. + + **Arguments:** + - `field`: The field to differentiate, shape `(C, ..., N,)`. `C` can be + `1` for a scalar field or `D` for a vector field. + - `L`: The domain extent. + - `order`: The order of the derivative. Default is `1`. + - `indexing`: The indexing scheme to use for `jax.numpy.meshgrid`. + Either `"ij"` or `"xy"`. Default is `"ij"`. + + **Returns:** + - `field_der`: The derivative of the field, shape `(C, D, ..., + (N//2)+1)` or `(D, ..., (N//2)+1)`. + """ + channel_shape = field.shape[0] + spatial_shape = field.shape[1:] + D = len(spatial_shape) + N = spatial_shape[0] + derivative_operator = build_derivative_operator( + D, domain_extent, N, indexing=indexing + ) + # # I decided to not use this fix + + # # Required for even N, no effect for odd N + # derivative_operator_fixed = ( + # derivative_operator * nyquist_filter_mask(D, N) + # ) + derivative_operator_fixed = derivative_operator**order + + field_hat = fft(field, num_spatial_dims=D) + if channel_shape == 1: + # Do not introduce another channel axis + field_der_hat = derivative_operator_fixed * field_hat + else: + # Create a "derivative axis" right after the channel axis + field_der_hat = field_hat[:, None] * derivative_operator_fixed[None, ...] + + field_der = ifft(field_der_hat, num_spatial_dims=D, num_points=N) + + return field_der + + +def make_incompressible( + field: Float[Array, "D ... N"], + *, + indexing: str = "ij", +): + channel_shape = field.shape[0] + spatial_shape = field.shape[1:] + num_spatial_dims = len(spatial_shape) + if channel_shape != num_spatial_dims: + raise ValueError( + f"Expected the number of channels to be {num_spatial_dims}, got {channel_shape}." + ) + num_points = spatial_shape[0] + + derivative_operator = build_derivative_operator( + num_spatial_dims, 1.0, num_points, indexing=indexing + ) # domain_extent does not matter because it will cancel out + + incompressible_field_hat = fft(field, num_spatial_dims=num_spatial_dims) + + divergence = jnp.sum( + derivative_operator * incompressible_field_hat, axis=0, keepdims=True + ) + + laplace_operator = build_laplace_operator(derivative_operator) + + inv_laplace_operator = jnp.where( + laplace_operator == 0, + 1.0, + 1.0 / laplace_operator, + ) + + pseudo_pressure = -inv_laplace_operator * divergence + + pseudo_pressure_garadient = derivative_operator * pseudo_pressure + + incompressible_field_hat = incompressible_field_hat - pseudo_pressure_garadient + + incompressible_field = ifft( + incompressible_field_hat, + num_spatial_dims=num_spatial_dims, + num_points=num_points, + ) + + return incompressible_field