Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement snapshotting for the acoustic wave equation #2474

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
89 changes: 69 additions & 20 deletions examples/seismic/acoustic/operators.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from devito import Eq, Operator, Function, TimeFunction, Inc, solve, sign
from devito import Eq, Operator, Function, TimeFunction, Inc, solve, sign, ConditionalDimension
from devito.symbolics import retrieve_functions, INT, retrieve_derivatives


Expand Down Expand Up @@ -108,7 +108,7 @@ def iso_stencil(field, model, kernel, **kwargs):


def ForwardOperator(model, geometry, space_order=4,
save=False, kernel='OT2', **kwargs):
save=False, kernel='OT2', factor=None, **kwargs):
"""
Construct a forward modelling operator in an acoustic medium.

Expand All @@ -126,6 +126,8 @@ def ForwardOperator(model, geometry, space_order=4,
Defaults to False.
kernel : str, optional
Type of discretization, 'OT2' or 'OT4'.
factor : int, optional
Downsampling factor to save snapshots of the wavefield.
"""
m = model.m

Expand All @@ -144,10 +146,28 @@ def ForwardOperator(model, geometry, space_order=4,

# Create interpolation expression for receivers
rec_term = rec.interpolate(expr=u)

# Build operator equations
equations = eqn + src_term + rec_term

if factor:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This needs to be wrapped into a utility function as it's duplicated below

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've created a function to construct usnaps.

# Implement snapshotting
nsnaps = (geometry.nt + factor - 1) // factor
time_subsampled = ConditionalDimension(
't_sub', parent=model.grid.time_dim, factor=factor)
usnaps = TimeFunction(name='usnaps', grid=model.grid,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You still have u with full time saved line 135 you can't have both

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed.

time_order=2, space_order=space_order,
save=nsnaps, time_dim=time_subsampled)
# Add equation to save snapshots
snapshot_eq = Eq(usnaps, u)
equations += [snapshot_eq]
else:
usnaps = None
# Substitute spacing terms to reduce flops
return Operator(eqn + src_term + rec_term, subs=model.spacing_map,
name='Forward', **kwargs)
op = Operator(equations, subs=model.spacing_map, name='Forward', **kwargs)
if usnaps is not None:
return op, usnaps
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No the operator build cannot return objects like that. This is an abstract operator with placeholders that might not be correct for runtime.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed. The operator build only returns op now.

else:
return op


def AdjointOperator(model, geometry, space_order=4,
Expand Down Expand Up @@ -189,8 +209,8 @@ def AdjointOperator(model, geometry, space_order=4,


def GradientOperator(model, geometry, space_order=4, save=True,
kernel='OT2', **kwargs):
"""
kernel='OT2', factor=None, **kwargs):
"""
Construct a gradient operator in an acoustic media.

Parameters
Expand All @@ -206,30 +226,59 @@ def GradientOperator(model, geometry, space_order=4, save=True,
Option to store the entire (unrolled) wavefield.
kernel : str, optional
Type of discretization, centered or shifted.
factor : int, optional
Downsampling factor to save snapshots of the wavefield.
"""
m = model.m

# Gradient symbol and wavefield symbols
# Gradient symbol
grad = Function(name='grad', grid=model.grid)
u = TimeFunction(name='u', grid=model.grid, save=geometry.nt if save
else None, time_order=2, space_order=space_order)
v = TimeFunction(name='v', grid=model.grid, save=None,
time_order=2, space_order=space_order)
rec = geometry.rec

# Create the adjoint wavefield
v = TimeFunction(name='v', grid=model.grid, time_order=2, space_order=space_order)

s = model.grid.stepping_dim.spacing
eqn = iso_stencil(v, model, kernel, forward=False)

if kernel == 'OT2':
gradient_update = Inc(grad, - u * v.dt2)
elif kernel == 'OT4':
gradient_update = Inc(grad, - u * v.dt2 - s**2 / 12.0 * u.biharmonic(m**(-2)) * v)
# Add expression for receiver injection
rec = geometry.rec
receivers = rec.inject(field=v.backward, expr=rec * s**2 / m)

time = model.grid.time_dim

if factor is not None:
# Condition to apply gradient update only at snapshot times
condition = Eq(time % factor, 0)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No you don't need that usnap already contains the conditon

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed.

# Create the ConditionalDimension for subsampling
time_subsampled = ConditionalDimension('t_sub', parent=time, factor=factor)
# Define usnaps with time_subsampled as its time dimension
nsnaps = (geometry.nt + factor - 1) // factor
usnaps = TimeFunction(name='usnaps', grid=model.grid,
time_order=2, space_order=space_order,
save=nsnaps, time_dim=time_subsampled)
# Gradient update without indexing usnaps
if kernel == 'OT2':
gradient_update = Inc(grad, - usnaps * v.dt2, implicit_dims=[time_subsampled],
condition=condition)
elif kernel == 'OT4':
gradient_update = Inc(grad, - usnaps * v.dt2
- s**2 / 12.0 * usnaps.biharmonic(m**(-2)) * v,
implicit_dims=[time_subsampled],
condition=condition)
else:
u = TimeFunction(name='u', grid=model.grid,
save=geometry.nt if save else None,
time_order=2, space_order=space_order)
if kernel == 'OT2':
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unnecessary duplicate, u contains the information you should not need separate cases for gradient_update

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed. No cases are used.

gradient_update = Inc(grad, - u * v.dt2)
elif kernel == 'OT4':
gradient_update = Inc(grad, - u * v.dt2
- s**2 / 12.0 * u.biharmonic(m**(-2)) * v)

# Substitute spacing terms to reduce flops
return Operator(eqn + receivers + [gradient_update], subs=model.spacing_map,
name='Gradient', **kwargs)
op = Operator(eqn + receivers + [gradient_update], subs=model.spacing_map,
name='Gradient', **kwargs)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

re-indent

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed

return op


def BornOperator(model, geometry, space_order=4,
Expand Down Expand Up @@ -274,4 +323,4 @@ def BornOperator(model, geometry, space_order=4,

# Substitute spacing terms to reduce flops
return Operator(eqn1 + source + eqn2 + receivers, subs=model.spacing_map,
name='Born', **kwargs)
name='Born', **kwargs)
66 changes: 51 additions & 15 deletions examples/seismic/acoustic/wavesolver.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from devito import Function, TimeFunction, DevitoCheckpoint, CheckpointOperator, Revolver
from devito.tools import memoized_meth
from examples.seismic.acoustic.operators import (
from devitofwi.devito.acoustic.operators import (
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

leftover?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You are right. I did not catch it.

ForwardOperator, AdjointOperator, GradientOperator, BornOperator
)

Expand All @@ -23,6 +23,7 @@ class AcousticWaveSolver:
space_order: int, optional
Order of the spatial stencil discretisation. Defaults to 4.
"""

def __init__(self, model, geometry, kernel='OT2', space_order=4, **kwargs):
self.model = model
self.model._initialize_bcs(bcs="damp")
Expand All @@ -44,11 +45,11 @@ def dt(self):
return self.model.critical_dt

@memoized_meth
def op_fwd(self, save=None):
def op_fwd(self, save=None, factor=None):
"""Cached operator for forward runs with buffered wavefield"""
return ForwardOperator(self.model, save=save, geometry=self.geometry,
kernel=self.kernel, space_order=self.space_order,
**self._kwargs)
factor=factor, **self._kwargs)

@memoized_meth
def op_adj(self):
Expand All @@ -58,11 +59,11 @@ def op_adj(self):
**self._kwargs)

@memoized_meth
def op_grad(self, save=True):
def op_grad(self, save=True, factor=None):
"""Cached operator for gradient runs"""
return GradientOperator(self.model, save=save, geometry=self.geometry,
kernel=self.kernel, space_order=self.space_order,
**self._kwargs)
factor=factor, **self._kwargs)

@memoized_meth
def op_born(self):
Expand All @@ -71,7 +72,7 @@ def op_born(self):
kernel=self.kernel, space_order=self.space_order,
**self._kwargs)

def forward(self, src=None, rec=None, u=None, model=None, save=None, **kwargs):
def forward(self, src=None, rec=None, u=None, model=None, save=None, factor=None, **kwargs):
"""
Forward modelling function that creates the necessary
data objects for running a forward modelling operator.
Expand All @@ -90,6 +91,8 @@ def forward(self, src=None, rec=None, u=None, model=None, save=None, **kwargs):
The time-constant velocity.
save : bool, optional
Whether or not to save the entire (unrolled) wavefield.
factor : int, optional
Downsampling factor to save snapshots of the wavefield.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indent

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed


Returns
-------
Expand All @@ -108,12 +111,24 @@ def forward(self, src=None, rec=None, u=None, model=None, save=None, **kwargs):
model = model or self.model
# Pick vp from model unless explicitly provided
kwargs.update(model.physical_params(**kwargs))
# Get the operator
op_fwd = self.op_fwd(save=save, factor=factor)
# Prepare parameters for operator apply
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't know what this is for.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed.

op_args = {'src': src, 'rec': rec, 'u': u, 'dt': kwargs.pop('dt', self.dt)}
op_args.update(kwargs)

# Execute operator and return wavefield and receiver data
summary = self.op_fwd(save).apply(src=src, rec=rec, u=u,
dt=kwargs.pop('dt', self.dt), **kwargs)

return rec, u, summary
if factor:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, usnap needs to be create here like u then passed as argument

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed. usnaps is created now.

# Operator returned is op, usnaps
op, usnaps = op_fwd
op_args['usnaps'] = usnaps
summary = op.apply(**op_args)

else:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shouldn't need if else only kwargs

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed

op = op_fwd
usnaps = None
summary = op.apply(**op_args)
return rec, u, usnaps, summary

def adjoint(self, rec, srca=None, v=None, model=None, **kwargs):
"""
Expand Down Expand Up @@ -155,8 +170,8 @@ def adjoint(self, rec, srca=None, v=None, model=None, **kwargs):
dt=kwargs.pop('dt', self.dt), **kwargs)
return srca, v, summary

def jacobian_adjoint(self, rec, u, src=None, v=None, grad=None, model=None,
checkpointing=False, **kwargs):
def jacobian_adjoint(self, rec, u=None, usnaps=None, src=None, v=None, grad=None, model=None,
factor=None, checkpointing=False, **kwargs):
"""
Gradient modelling function for computing the adjoint of the
Linearized Born modelling function, ie. the action of the
Expand All @@ -168,6 +183,8 @@ def jacobian_adjoint(self, rec, u, src=None, v=None, grad=None, model=None,
Receiver data.
u : TimeFunction
Full wavefield `u` (created with save=True).
usnaps : TimeFunction
Snapshots of the wavefield `u`.
v : TimeFunction, optional
Stores the computed wavefield.
grad : Function, optional
Expand All @@ -176,12 +193,22 @@ def jacobian_adjoint(self, rec, u, src=None, v=None, grad=None, model=None,
Object containing the physical parameters.
vp : Function or float, optional
The time-constant velocity.
checkpointing : boolean, optional
Flag to enable checkpointing (default False).
Cannot be used with snapshotting.
factor : int, optional
Downsampling factor for the saved snapshots of the wavefield `u`.
Cannot be used with checkpointing.

Returns
-------
Gradient field and performance summary.
"""
dt = kwargs.pop('dt', self.dt)
# Check that snapshotting and checkpointing are not used together
if factor is not None and checkpointing:
raise ValueError("Cannot use snapshotting (factor) and checkpointing simultaneously.")

# Gradient symbol
grad = grad or Function(name='grad', grid=self.model.grid)

Expand Down Expand Up @@ -209,8 +236,17 @@ def jacobian_adjoint(self, rec, u, src=None, v=None, grad=None, model=None,
wrp.apply_forward()
summary = wrp.apply_reverse()
else:
summary = self.op_grad().apply(rec=rec, grad=grad, v=v, u=u, dt=dt,
**kwargs)
if factor is not None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Again, not needed, input u should contain all metada needed

Copy link
Author

@malfarhan7 malfarhan7 Nov 5, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed.

# Get the gradient operator
op = self.op_grad(save=False, factor=factor)
op_args = {'rec': rec, 'grad': grad, 'v': v, 'dt': dt, 'usnaps': usnaps}
else:
op = self.op_grad(save=True, factor=None)
op_args = {'rec': rec, 'grad': grad, 'v': v, 'dt': dt, 'u': u}

op_args.update(kwargs)
summary = op.apply(**op_args)

return grad, summary

def jacobian(self, dmin, src=None, rec=None, u=None, U=None, model=None, **kwargs):
Expand Down Expand Up @@ -255,4 +291,4 @@ def jacobian(self, dmin, src=None, rec=None, u=None, U=None, model=None, **kwarg

# Backward compatibility
born = jacobian
gradient = jacobian_adjoint
gradient = jacobian_adjoint
2 changes: 2 additions & 0 deletions examples/seismic/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,4 +264,6 @@ def __call__(self, parser, args, values, option_string=None):
choices=['float32', 'float64'])
parser.add_argument("-interp", dest="interp", default="linear",
choices=['linear', 'sinc'])
parser.add_argument("--factor", type=int, default=None,
help="Downsampling factor to use snapshotting, default is None")
return parser