Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement snapshotting for the acoustic wave equation #2474

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
Binary file added devito/data/Marm.bin
Binary file not shown.
96 changes: 87 additions & 9 deletions examples/seismic/acoustic/operators.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from devito import Eq, Operator, Function, TimeFunction, Inc, solve, sign
from devito import (
Eq, Operator, Function, TimeFunction, Inc, solve, sign, ConditionalDimension
)
from devito.symbolics import retrieve_functions, INT, retrieve_derivatives


Expand Down Expand Up @@ -107,8 +109,60 @@ def iso_stencil(field, model, kernel, **kwargs):
return eqns


def create_snapshot_time_function(model, name, geometry, space_order, factor=None):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would lift that into seismic/utils.py or some new file so it can be easily reused from other pdes. It can probably be done directly from the wavefield too i,e

def create_snapshot_time_function(u, nsnap):

    name = f"{u.name}_save"
    grid = u.grid
    
    ....
    

Makes it also decoupled from those geometry/model. THis would also allow to directly return the equation Eq(u,usave) as well

"""
Create a TimeFunction to store snapshots of the wavefield during simulation.

Parameters
----------
model : Model
Object containing the physical parameters.
name : str
The name of the snapshot TimeFunction.
geometry : AcquisitionGeometry
Geometry object that contains the source (SparseTimeFunction) and
receivers (SparseTimeFunction) and their position.
space_order : int
Space discretization order.
factor : int, optional
Downsampling factor for snapshot storage. If provided, snapshots are saved
every `factor` time steps. Defaults to None, which disables snapshot saving.

Returns
-------
TimeFunction configured to store snapshots of the wavefield.
The snapshots are saved based on the provided downsampling factor.

Notes
-----
- If `factor` is provided, snapshots will be saved every `factor` time steps.
The number of snapshots (`nsnaps`) is calculated as:
`(geometry.nt + factor - 1) // factor`.
- If `factor` is None, the snapshot storage is disabled (`save=None`).
- The `time_dim` of the TimeFunction is subsampled using a ConditionalDimension.
"""
if factor is not None:
nsnaps = (geometry.nt + factor - 1) // factor
else:
nsnaps = None

time_subsampled = ConditionalDimension(
't_sub', parent=model.grid.time_dim, factor=factor
)

usnaps = TimeFunction(
name=name,
grid=model.grid,
time_order=2,
space_order=space_order,
save=nsnaps,
time_dim=time_subsampled
)
return usnaps


def ForwardOperator(model, geometry, space_order=4,
save=False, kernel='OT2', **kwargs):
save=False, kernel='OT2', factor=None, **kwargs):
"""
Construct a forward modelling operator in an acoustic medium.

Expand All @@ -126,13 +180,17 @@ def ForwardOperator(model, geometry, space_order=4,
Defaults to False.
kernel : str, optional
Type of discretization, 'OT2' or 'OT4'.
factor : int, optional
Downsampling factor to save snapshots of the wavefield.
"""
m = model.m

# Create symbols for forward wavefield, source and receivers
save_value = geometry.nt if save and factor is None else None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To make the code more homogeneous and easier to read, what we could do, imho, is the following:

  • drop save from this TimeFunction (which is just legacy behaviour...)
  • systematically create usnaps and the corresponding equation
  • with an inf (aka sys.maxint ?) factor, no snapshots will be saved at runtime
  • Tweak here adding a line if self.size == 0: return, so that Devito avoids allocating memory entirely if pointless

In my opinion, this will dramatically clean up the code here.

At the moment, the proliferation of if factor logic is still affecting maintainability too much

@mloubout is on vacation ATM, but would be useful to hear from his thoughts about this matter

u = TimeFunction(name='u', grid=model.grid,
save=geometry.nt if save else None,
save=save_value,
time_order=2, space_order=space_order)

src = geometry.src
rec = geometry.rec

Expand All @@ -145,9 +203,18 @@ def ForwardOperator(model, geometry, space_order=4,
# Create interpolation expression for receivers
rec_term = rec.interpolate(expr=u)

# Build operator equations
equations = eqn + src_term + rec_term

usnaps = create_snapshot_time_function(model, 'usnaps', geometry, space_order, factor)

if factor:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This needs to be wrapped into a utility function as it's duplicated below

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've created a function to construct usnaps.

# Add equation to save snapshots
snapshot_eq = Eq(usnaps, u)
equations += [snapshot_eq]

# Substitute spacing terms to reduce flops
return Operator(eqn + src_term + rec_term, subs=model.spacing_map,
name='Forward', **kwargs)
return Operator(equations, subs=model.spacing_map, name='Forward', **kwargs)


def AdjointOperator(model, geometry, space_order=4,
Expand Down Expand Up @@ -189,9 +256,9 @@ def AdjointOperator(model, geometry, space_order=4,


def GradientOperator(model, geometry, space_order=4, save=True,
kernel='OT2', **kwargs):
kernel='OT2', factor=None, **kwargs):
"""
Construct a gradient operator in an acoustic media.
Construct a gradient operator in an acoustic medium.

Parameters
----------
Expand All @@ -206,13 +273,23 @@ def GradientOperator(model, geometry, space_order=4, save=True,
Option to store the entire (unrolled) wavefield.
kernel : str, optional
Type of discretization, centered or shifted.
factor : int, optional
Downsampling factor to save snapshots of the wavefield.
"""
m = model.m

# Gradient symbol and wavefield symbols
grad = Function(name='grad', grid=model.grid)
u = TimeFunction(name='u', grid=model.grid, save=geometry.nt if save
else None, time_order=2, space_order=space_order)

if factor:
# Apply the imaging condition at the snapshots of the full wavefield
u = create_snapshot_time_function(model, 'u', geometry, space_order, factor)
else:
# Apply the imaging condition at every time step of the full wavefield
u = TimeFunction(name='u', grid=model.grid,
save=geometry.nt if save else None,
time_order=2, space_order=space_order)

v = TimeFunction(name='v', grid=model.grid, save=None,
time_order=2, space_order=space_order)
rec = geometry.rec
Expand All @@ -224,6 +301,7 @@ def GradientOperator(model, geometry, space_order=4, save=True,
gradient_update = Inc(grad, - u * v.dt2)
elif kernel == 'OT4':
gradient_update = Inc(grad, - u * v.dt2 - s**2 / 12.0 * u.biharmonic(m**(-2)) * v)

# Add expression for receiver injection
receivers = rec.inject(field=v.backward, expr=rec * s**2 / m)

Expand Down
69 changes: 55 additions & 14 deletions examples/seismic/acoustic/wavesolver.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from devito import Function, TimeFunction, DevitoCheckpoint, CheckpointOperator, Revolver
from devito.tools import memoized_meth
from examples.seismic.acoustic.operators import (
ForwardOperator, AdjointOperator, GradientOperator, BornOperator
ForwardOperator, AdjointOperator, GradientOperator, BornOperator,
create_snapshot_time_function
)


Expand All @@ -23,6 +24,7 @@ class AcousticWaveSolver:
space_order: int, optional
Order of the spatial stencil discretisation. Defaults to 4.
"""

def __init__(self, model, geometry, kernel='OT2', space_order=4, **kwargs):
self.model = model
self.model._initialize_bcs(bcs="damp")
Expand All @@ -44,11 +46,11 @@ def dt(self):
return self.model.critical_dt

@memoized_meth
def op_fwd(self, save=None):
def op_fwd(self, save=None, factor=None):
"""Cached operator for forward runs with buffered wavefield"""
return ForwardOperator(self.model, save=save, geometry=self.geometry,
kernel=self.kernel, space_order=self.space_order,
**self._kwargs)
factor=factor, **self._kwargs)

@memoized_meth
def op_adj(self):
Expand All @@ -58,11 +60,11 @@ def op_adj(self):
**self._kwargs)

@memoized_meth
def op_grad(self, save=True):
def op_grad(self, save=True, factor=None):
"""Cached operator for gradient runs"""
return GradientOperator(self.model, save=save, geometry=self.geometry,
kernel=self.kernel, space_order=self.space_order,
**self._kwargs)
factor=factor, **self._kwargs)

@memoized_meth
def op_born(self):
Expand All @@ -71,7 +73,11 @@ def op_born(self):
kernel=self.kernel, space_order=self.space_order,
**self._kwargs)

def forward(self, src=None, rec=None, u=None, model=None, save=None, **kwargs):
def forward(
self, src=None, rec=None, u=None, usnaps=None,
model=None, save=None, factor=None,
**kwargs
):
"""
Forward modelling function that creates the necessary
data objects for running a forward modelling operator.
Expand All @@ -90,6 +96,8 @@ def forward(self, src=None, rec=None, u=None, model=None, save=None, **kwargs):
The time-constant velocity.
save : bool, optional
Whether or not to save the entire (unrolled) wavefield.
factor : int, optional
Downsampling factor to save snapshots of the wavefield.

Returns
-------
Expand All @@ -101,19 +109,35 @@ def forward(self, src=None, rec=None, u=None, model=None, save=None, **kwargs):
rec = rec or self.geometry.rec

# Create the forward wavefield if not provided
save_value = self.geometry.nt if save and factor is None else None
u = u or TimeFunction(name='u', grid=self.model.grid,
save=self.geometry.nt if save else None,
save=save_value,
time_order=2, space_order=self.space_order)

# Create snapshots of the forward wavefield
usnaps = usnaps or create_snapshot_time_function(
self.model, 'usnaps', self.geometry, self.space_order, factor
)

model = model or self.model
# Pick vp from model unless explicitly provided
kwargs.update(model.physical_params(**kwargs))

# Execute operator and return wavefield and receiver data
summary = self.op_fwd(save).apply(src=src, rec=rec, u=u,
dt=kwargs.pop('dt', self.dt), **kwargs)

return rec, u, summary
if factor:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, usnap needs to be create here like u then passed as argument

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed. usnaps is created now.

# Return snapshots of the forward wavefield
summary = self.op_fwd(save, factor).apply(
src=src, rec=rec, u=u, usnaps=usnaps,
dt=kwargs.pop('dt', self.dt), **kwargs
)
return rec, usnaps, summary
else:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shouldn't need if else only kwargs

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed

# Return the full forward wavefield
summary = self.op_fwd(save, factor).apply(
src=src, rec=rec, u=u,
dt=kwargs.pop('dt', self.dt), **kwargs
)
return rec, u, summary

def adjoint(self, rec, srca=None, v=None, model=None, **kwargs):
"""
Expand Down Expand Up @@ -156,7 +180,7 @@ def adjoint(self, rec, srca=None, v=None, model=None, **kwargs):
return srca, v, summary

def jacobian_adjoint(self, rec, u, src=None, v=None, grad=None, model=None,
checkpointing=False, **kwargs):
factor=None, checkpointing=False, **kwargs):
"""
Gradient modelling function for computing the adjoint of the
Linearized Born modelling function, ie. the action of the
Expand All @@ -168,6 +192,8 @@ def jacobian_adjoint(self, rec, u, src=None, v=None, grad=None, model=None,
Receiver data.
u : TimeFunction
Full wavefield `u` (created with save=True).
usnaps : TimeFunction
Snapshots of the wavefield `u`.
v : TimeFunction, optional
Stores the computed wavefield.
grad : Function, optional
Expand All @@ -176,12 +202,25 @@ def jacobian_adjoint(self, rec, u, src=None, v=None, grad=None, model=None,
Object containing the physical parameters.
vp : Function or float, optional
The time-constant velocity.
checkpointing : boolean, optional
Flag to enable checkpointing (default False).
Cannot be used with snapshotting.
factor : int, optional
Downsampling factor for the saved snapshots of the wavefield `u`.
Cannot be used with checkpointing.

Returns
-------
Gradient field and performance summary.
"""
dt = kwargs.pop('dt', self.dt)
# Check that snapshotting and checkpointing are not used together
if factor is not None and checkpointing:
raise ValueError(
"Cannot use snapshotting (factor) and "
"checkpointing simultaneously."
)

# Gradient symbol
grad = grad or Function(name='grad', grid=self.model.grid)

Expand Down Expand Up @@ -209,8 +248,10 @@ def jacobian_adjoint(self, rec, u, src=None, v=None, grad=None, model=None,
wrp.apply_forward()
summary = wrp.apply_reverse()
else:
summary = self.op_grad().apply(rec=rec, grad=grad, v=v, u=u, dt=dt,
**kwargs)
summary = self.op_grad(factor=factor).apply(
rec=rec, grad=grad, v=v, u=u, dt=dt, **kwargs
)

return grad, summary

def jacobian(self, dmin, src=None, rec=None, u=None, U=None, model=None, **kwargs):
Expand Down
641 changes: 641 additions & 0 deletions examples/seismic/tutorials/17_fwi_gradient_snapshotting.ipynb

Large diffs are not rendered by default.

Loading