Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Transpose interpolation -> adjoint interpolation #3965

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions firedrake/__future__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ class CrossMeshInterpolator(Interpolator, CrossMeshInterpolatorOld):
def interpolate(expr, V, *args, **kwargs):
default_missing_val = kwargs.pop("default_missing_val", None)
if isinstance(V, Cofunction):
transpose = bool(extract_arguments(expr))
adjoint = bool(extract_arguments(expr))
return Interpolator(
expr, V.function_space().dual(), *args, **kwargs
).interpolate(V, transpose=transpose, default_missing_val=default_missing_val)
).interpolate(V, adjoint=adjoint, default_missing_val=default_missing_val)
return Interpolator(expr, V, *args, **kwargs).interpolate(default_missing_val=default_missing_val)
2 changes: 1 addition & 1 deletion firedrake/assemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,7 +542,7 @@ def base_form_assembly_visitor(self, expr, tensor, *args):
# Assembling the action of the Jacobian adjoint.
if is_adjoint:
output = tensor or firedrake.Cofunction(arg_expression[0].function_space().dual())
return interpolator._interpolate(v, output=output, transpose=True, default_missing_val=default_missing_val)
return interpolator._interpolate(v, output=output, adjoint=True, default_missing_val=default_missing_val)
# Assembling the Jacobian action.
if interpolator.nargs:
return interpolator._interpolate(expression, output=tensor, default_missing_val=default_missing_val)
Expand Down
108 changes: 63 additions & 45 deletions firedrake/interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
import tempfile
import abc
import warnings

import FIAT
import ufl
Expand All @@ -23,7 +24,7 @@

import firedrake
from firedrake import tsfc_interface, utils, functionspaceimpl
from firedrake.ufl_expr import Argument, action, adjoint
from firedrake.ufl_expr import Argument, action, adjoint as expr_adjoint
from firedrake.mesh import MissingPointsBehaviour, VertexOnlyMeshMissingPointsError
from firedrake.petsc import PETSc
from firedrake.halo import _get_mtype as get_dat_mpi_type
Expand Down Expand Up @@ -75,7 +76,7 @@ def __init__(self, expr, v,
(a) unchanged if some ``output`` is given to the :meth:`interpolate` method
or (b) set to zero.
Can be overwritten with the ``default_missing_val`` kwarg of :meth:`interpolate`.
This does not affect transpose interpolation. Ignored if interpolating within
This does not affect adjoint interpolation. Ignored if interpolating within
the same mesh or onto a :func:`.VertexOnlyMesh` (the behaviour of a
:func:`.VertexOnlyMesh` in this scenario is, at present, set when it is created).
default_missing_val : float
Expand Down Expand Up @@ -139,7 +140,7 @@ def _ufl_expr_reconstruct_(self, expr, v=None, **interp_data):
# - v is a function in V (NOT an expression).
# - w is a function in W.
# - Maths: v = Bw
# - v_star = B.interpolate(w_star, transpose = True)
# - v_star = B.interpolate(w_star, adjoint=True)
# - w_star is a cofunction in W^* (such as an assembled 1-form).
# - v_star is a cofunction in V^*.
# - Maths: v^* = B^* w^*
Expand Down Expand Up @@ -175,7 +176,7 @@ def interpolate(
``True`` the corresponding values are either (a) unchanged if
some ``output`` is given to the :meth:`interpolate` method or (b) set
to zero. In either case, if ``default_missing_val`` is specified, that
value is used. This does not affect transpose interpolation. Ignored if
value is used. This does not affect adjoint interpolation. Ignored if
interpolating within the same mesh or onto a :func:`.VertexOnlyMesh`
(the behaviour of a :func:`.VertexOnlyMesh` in this scenario is, at
present, set when it is created).
Expand Down Expand Up @@ -242,7 +243,7 @@ class Interpolator(abc.ABC):
``True`` the corresponding values are either (a) unchanged if
some ``output`` is given to the :meth:`interpolate` method or (b) set
to zero. Can be overwritten with the ``default_missing_val`` kwarg
of :meth:`interpolate`. This does not affect transpose interpolation.
of :meth:`interpolate`. This does not affect adjoint interpolation.
Ignored if interpolating within the same mesh or onto a
:func:`.VertexOnlyMesh` (the behaviour of a :func:`.VertexOnlyMesh` in
this scenario is, at present, set when it is created).
Expand Down Expand Up @@ -298,24 +299,26 @@ def __init__(
part=v.part())})
self.expr_renumbered = expr

def _interpolate_future(self, *function, transpose=False, default_missing_val=None):
def _interpolate_future(self, *function, transpose=None, adjoint=False, default_missing_val=None):
"""Define the :class:`Interpolate` object corresponding to the interpolation operation of interest.

Parameters
----------
*function: firedrake.function.Function or firedrake.cofunction.Cofunction
If the expression being interpolated contains an argument,
then the function value to interpolate.
transpose: bool
Set to true to apply the transpose (adjoint) of the
interpolation operator.
transpose : bool
Deprecated, use adjoint instead.
adjoint: bool
Set to true to apply the adjoint of the interpolation
operator.
default_missing_val: bool
For interpolation across meshes: the
optional value to assign to DoFs in the target mesh that are
outside the source mesh. If this is not set then the values are
either (a) unchanged if some ``output`` is specified to the
:meth:`interpolate` method or (b) set to zero. This does not affect
transpose interpolation. Ignored if interpolating within the same
adjoint interpolation. Ignored if interpolating within the same
mesh or onto a :func:`.VertexOnlyMesh`.

Returns
Expand All @@ -339,7 +342,10 @@ def _interpolate_future(self, *function, transpose=False, default_missing_val=No
allow_missing_dofs=self._allow_missing_dofs,
default_missing_val=default_missing_val)
if transpose:
interp = adjoint(interp)
warnings.warn("'transpose' argument is deprecated, use 'adjoint' instead", FutureWarning)
adjoint = transpose or adjoint
if adjoint:
interp = expr_adjoint(interp)

if function:
f, = function
Expand All @@ -349,7 +355,7 @@ def _interpolate_future(self, *function, transpose=False, default_missing_val=No
return interp

@PETSc.Log.EventDecorator()
def interpolate(self, *function, output=None, transpose=False, default_missing_val=None,
def interpolate(self, *function, output=None, transpose=None, adjoint=False, default_missing_val=None,
ad_block_tag=None):
"""Compute the interpolation by assembling the appropriate :class:`Interpolate` object.

Expand All @@ -360,16 +366,18 @@ def interpolate(self, *function, output=None, transpose=False, default_missing_v
then the function value to interpolate.
output: firedrake.function.Function or firedrake.cofunction.Cofunction
A function to contain the output.
transpose: bool
Set to true to apply the transpose (adjoint) of the
interpolation operator.
transpose : bool
Deprecated, use adjoint instead.
adjoint: bool
Set to true to apply the adjoint of the interpolation
operator.
default_missing_val: bool
For interpolation across meshes: the
optional value to assign to DoFs in the target mesh that are
outside the source mesh. If this is not set then the values are
either (a) unchanged if some ``output`` is specified to the
:meth:`interpolate` method or (b) set to zero. This does not affect
transpose interpolation. Ignored if interpolating within the same
adjoint interpolation. Ignored if interpolating within the same
mesh or onto a :func:`.VertexOnlyMesh`.
ad_block_tag: str
An optional string for tagging the resulting assemble block on the Pyadjoint tape.
Expand All @@ -379,7 +387,6 @@ def interpolate(self, *function, output=None, transpose=False, default_missing_v
firedrake.function.Function or firedrake.cofunction.Cofunction
The resulting interpolated function.
"""
import warnings
from firedrake.assemble import assemble

warnings.warn("""The use of `interpolate` to perform the numerical interpolation is deprecated.
Expand All @@ -402,9 +409,12 @@ def interpolate(self, *function, output=None, transpose=False, default_missing_v
Alternatively, you can also perform other symbolic operations on the interpolation operator, such as taking
the derivative, and then assemble the resulting form.
""", FutureWarning)
if transpose:
warnings.warn("'transpose' argument is deprecated, use 'adjoint' instead", FutureWarning)
adjoint = transpose or adjoint

# Get the Interpolate object
interp = self._interpolate_future(*function, transpose=transpose,
interp = self._interpolate_future(*function, adjoint=adjoint,
default_missing_val=default_missing_val)

if isinstance(self.V, firedrake.Function) and not output:
Expand Down Expand Up @@ -643,7 +653,8 @@ def _interpolate(
self,
*function,
output=None,
transpose=False,
transpose=None,
adjoint=False,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using an adjoint argument unfortunately hides the adjoint function. I've worked around this with expr_adjoint but that doesn't seem ideal.

default_missing_val=None,
**kwargs,
):
Expand All @@ -653,9 +664,12 @@ def _interpolate(
"""
from firedrake.assemble import assemble

if transpose and not self.nargs:
if transpose is not None:
warnings.warn("'transpose' argument is deprecated, use 'adjoint' instead", FutureWarning)
adjoint = transpose or adjoint
if adjoint and not self.nargs:
raise ValueError(
"Can currently only apply transpose interpolation with arguments."
"Can currently only apply adjoint interpolation with arguments."
)
if self.nargs != len(function):
raise ValueError(
Expand All @@ -672,7 +686,7 @@ def _interpolate(
else:
f_src = self.expr

if transpose:
if adjoint:
try:
V_dest = self.expr.function_space().dual()
except AttributeError:
Expand All @@ -684,7 +698,7 @@ def _interpolate(
V_dest = coeffs[0].function_space().dual()
else:
raise ValueError(
"Can't transpose interpolate an expression with no coefficients or arguments."
"Can't adjoint interpolate an expression with no coefficients or arguments."
)
else:
if isinstance(self.V, (firedrake.Function, firedrake.Cofunction)):
Expand All @@ -710,14 +724,14 @@ def _interpolate(
# so the sub_interpolators are already prepared to interpolate
# without needing to be given a Function
assert not self.nargs
interp = sub_interpolator._interpolate_future(transpose=transpose, **kwargs)
interp = sub_interpolator._interpolate_future(adjoint=adjoint, **kwargs)
assemble(interp, tensor=output_sub_func)
else:
interp = sub_interpolator._interpolate_future(transpose=transpose, **kwargs)
interp = sub_interpolator._interpolate_future(adjoint=adjoint, **kwargs)
assemble(action(interp, f_src_sub_func), tensor=output_sub_func)
return output

if not transpose:
if not adjoint:
if f_src is self.expr:
# f_src is already contained in self.point_eval_interpolate
assert not self.nargs
Expand Down Expand Up @@ -764,7 +778,7 @@ def _interpolate(
] = f_src_at_dest_node_coords_dest_mesh_decomp.dat.data_ro[:]

else:
# adjoint/transpose interpolation
# adjoint interpolation

# f_src is a cofunction on V_dest.dual as originally specified when
# creating the interpolator. Our first adjoint operation is to
Expand All @@ -780,25 +794,25 @@ def _interpolate(
:
] = f_src.dat.data_ro[:]

# The rest of the transpose interpolation is merely the composition
# of the transpose interpolators in the reverse direction. NOTE: I
# The rest of the adjoint interpolation is merely the composition
# of the adjoint interpolators in the reverse direction. NOTE: I
# don't have to worry about skipping over missing points here
# because I'm going from the input ordering VOM to the original VOM
# and all points from the input ordering VOM are in the original.
interp = action(adjoint(self.to_input_ordering_interpolate), f_src_at_dest_node_coords_dest_mesh_decomp)
interp = action(expr_adjoint(self.to_input_ordering_interpolate), f_src_at_dest_node_coords_dest_mesh_decomp)
f_src_at_src_node_coords = assemble(interp)
# NOTE: if I wanted the default missing value to be applied to
# transpose interpolation I would have to do it here. However,
# adjoint interpolation I would have to do it here. However,
# this would require me to implement default missing values for
# transpose interpolation from a point evaluation interpolator
# adjoint interpolation from a point evaluation interpolator
# which I haven't done. I wonder if it is necessary - perhaps the
# adjoint operator always sets all the values of the resulting
# cofunction? My initial attempt to insert setting the dat values
# prior to performing the multTranspose operation in
# prior to performing the multHermitian operation in
# SameMeshInterpolator.interpolate did not effect the result. For
# now, I say in the docstring that it only applies to forward
# interpolation.
interp = action(adjoint(self.point_eval_interpolate), f_src_at_src_node_coords)
interp = action(expr_adjoint(self.point_eval_interpolate), f_src_at_src_node_coords)
assemble(interp, tensor=output)

return output
Expand All @@ -823,13 +837,17 @@ def __init__(self, expr, V, subset=None, freeze_expr=False, access=op2.WRITE, bc
self.nargs = len(arguments)

@PETSc.Log.EventDecorator()
def _interpolate(self, *function, output=None, transpose=False, **kwargs):
def _interpolate(self, *function, output=None, transpose=None, adjoint=False, **kwargs):
"""Compute the interpolation.

For arguments, see :class:`.Interpolator`.
"""
if transpose and not self.nargs:
raise ValueError("Can currently only apply transpose interpolation with arguments.")

if transpose is not None:
warnings.warn("'transpose' argument is deprecated, use 'adjoint' instead", FutureWarning)
adjoint = transpose or adjoint
if adjoint and not self.nargs:
raise ValueError("Can currently only apply adjoint interpolation with arguments.")
if self.nargs != len(function):
raise ValueError("Passed %d Functions to interpolate, expected %d"
% (len(function), self.nargs))
Expand All @@ -851,8 +869,8 @@ def _interpolate(self, *function, output=None, transpose=False, **kwargs):
function, = function
if not hasattr(function, "dat"):
raise ValueError("The expression had arguments: we therefore need to be given a Function (not an expression) to interpolate!")
if transpose:
mul = assembled_interpolator.handle.multTranspose
if adjoint:
mul = assembled_interpolator.handle.multHermitian
V = self.arguments[0].function_space()
else:
mul = assembled_interpolator.handle.mult
Expand Down Expand Up @@ -1520,7 +1538,7 @@ class VomOntoVomDummyMat(object):
forward_reduce : bool
If ``True``, the action of the operator (accessed via the `mult`
method) is to perform a SF reduce from the source vec to the target
vec, whilst the adjoint action (accessed via the `multTranspose`
vec, whilst the adjoint action (accessed via the `multHermitian`
method) is to perform a SF broadcast from the source vec to the target
vec. If ``False``, the opposite is true.
V : `.FunctionSpace`
Expand Down Expand Up @@ -1630,23 +1648,23 @@ def mult(self, source_vec, target_vec):
else:
self.broadcast(coeff_vec, target_vec)

def multTranspose(self, source_vec, target_vec):
# can only do transpose if our expression exclusively contains a
def multHermitian(self, source_vec, target_vec):
# can only do adjoint if our expression exclusively contains a
# single argument, making the application of the adjoint operator
# straightforward (haven't worked out how to do this otherwise!)
if not len(self.arguments) == 1:
raise NotImplementedError(
"Can only apply transpose to expressions with one argument!"
"Can only apply adjoint to expressions with one argument!"
)
if self.arguments[0] is not self.expr:
raise NotImplementedError(
"Can only apply transpose to expressions consisting of a single argument at the moment."
"Can only apply adjoint to expressions consisting of a single argument at the moment."
)
if self.forward_reduce:
self.broadcast(source_vec, target_vec)
else:
# We need to ensure the target vec is zeroed for SF Reduce to
# represent multTranspose in case the interpolation matrix is not
# represent multHermitian in case the interpolation matrix is not
# square (in which case it will have columns which are zero). This
# happens when we interpolate from an input-ordering vertex-only
# mesh to an immersed vertex-only mesh where the input ordering
Expand Down
15 changes: 15 additions & 0 deletions tests/firedrake/regression/test_interp_dual.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import numpy as np
from firedrake import *
from firedrake.__future__ import *
from firedrake.utils import complex_mode
import ufl


Expand Down Expand Up @@ -100,6 +101,20 @@ def test_assemble_interp_adjoint_model(V1, V2):
assert np.allclose(res.dat.data, Ivfstar.dat.data)


def test_assemble_interp_adjoint_complex(mesh, V1, f1):
if complex_mode:
f1 = Constant(3 - 5.j) * f1

a = assemble(conj(TestFunction(V1)) * dx)
b = assemble(action(adjoint(Interpolate(f1 * TestFunction(V1), V1)), a))

x, y = SpatialCoordinate(mesh)
f2 = Function(V1, name="f2").interpolate(
exp(x) * y)

assert np.allclose(assemble(b(f2)), assemble(Function(V1).interpolate(conj(f1 * f2)) * dx))


def test_assemble_interp_rank0(V1, V2, f1):
# -- Interpolate(f1, u2) (rank 0) -- #
v2 = TestFunction(V2)
Expand Down
Loading
Loading