-
Notifications
You must be signed in to change notification settings - Fork 230
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement snapshotting for the acoustic wave equation #2474
base: master
Are you sure you want to change the base?
Changes from 7 commits
44694f5
02be5b3
d909155
c0791d9
71a99cc
46bbf56
0353a78
810ddd3
6c1bc75
7f12068
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
from devito import Eq, Operator, Function, TimeFunction, Inc, solve, sign | ||
from devito import Eq, Operator, Function, TimeFunction, Inc, solve, sign, ConditionalDimension | ||
from devito.symbolics import retrieve_functions, INT, retrieve_derivatives | ||
|
||
|
||
|
@@ -106,9 +106,17 @@ def iso_stencil(field, model, kernel, **kwargs): | |
eqns.append(freesurface(model, Eq(unext, eq_time))) | ||
return eqns | ||
|
||
def create_snapshot_time_function(model, name, geometry, space_order, factor, save=True): | ||
nsnaps = (geometry.nt + factor - 1) // factor | ||
time_subsampled = ConditionalDimension('t_sub', | ||
parent=model.grid.time_dim, factor=factor) | ||
u_ = TimeFunction(name=name, grid=model.grid, | ||
time_order=2, space_order=space_order, | ||
save=nsnaps if save else None, time_dim=time_subsampled) | ||
return u_ | ||
|
||
def ForwardOperator(model, geometry, space_order=4, | ||
save=False, kernel='OT2', **kwargs): | ||
save=False, kernel='OT2', factor=None, **kwargs): | ||
""" | ||
Construct a forward modelling operator in an acoustic medium. | ||
|
||
|
@@ -126,13 +134,16 @@ def ForwardOperator(model, geometry, space_order=4, | |
Defaults to False. | ||
kernel : str, optional | ||
Type of discretization, 'OT2' or 'OT4'. | ||
factor : int, optional | ||
Downsampling factor to save snapshots of the wavefield. | ||
""" | ||
m = model.m | ||
|
||
# Create symbols for forward wavefield, source and receivers | ||
u = TimeFunction(name='u', grid=model.grid, | ||
save=geometry.nt if save else None, | ||
save=geometry.nt if save and factor is None else None, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not a big fan of this composite conditional involving both There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Moved the conditional statement outside of the |
||
time_order=2, space_order=space_order) | ||
|
||
src = geometry.src | ||
rec = geometry.rec | ||
|
||
|
@@ -145,9 +156,17 @@ def ForwardOperator(model, geometry, space_order=4, | |
# Create interpolation expression for receivers | ||
rec_term = rec.interpolate(expr=u) | ||
|
||
# Build operator equations | ||
equations = eqn + src_term + rec_term | ||
|
||
if factor: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This needs to be wrapped into a utility function as it's duplicated below There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I've created a function to construct usnaps. |
||
# Add equation to save snapshots | ||
usnaps = create_snapshot_time_function(model, 'usnaps', geometry, space_order, factor) | ||
snapshot_eq = Eq(usnaps, u) | ||
equations += [snapshot_eq] | ||
|
||
# Substitute spacing terms to reduce flops | ||
return Operator(eqn + src_term + rec_term, subs=model.spacing_map, | ||
name='Forward', **kwargs) | ||
return Operator(equations, subs=model.spacing_map, name='Forward', **kwargs) | ||
|
||
|
||
def AdjointOperator(model, geometry, space_order=4, | ||
|
@@ -189,9 +208,9 @@ def AdjointOperator(model, geometry, space_order=4, | |
|
||
|
||
def GradientOperator(model, geometry, space_order=4, save=True, | ||
kernel='OT2', **kwargs): | ||
kernel='OT2', factor=None, **kwargs): | ||
""" | ||
Construct a gradient operator in an acoustic media. | ||
Construct a gradient operator in an acoustic medium. | ||
|
||
Parameters | ||
---------- | ||
|
@@ -206,30 +225,38 @@ def GradientOperator(model, geometry, space_order=4, save=True, | |
Option to store the entire (unrolled) wavefield. | ||
kernel : str, optional | ||
Type of discretization, centered or shifted. | ||
factor : int, optional | ||
Downsampling factor to save snapshots of the wavefield. | ||
""" | ||
m = model.m | ||
|
||
# Gradient symbol and wavefield symbols | ||
grad = Function(name='grad', grid=model.grid) | ||
u = TimeFunction(name='u', grid=model.grid, save=geometry.nt if save | ||
else None, time_order=2, space_order=space_order) | ||
v = TimeFunction(name='v', grid=model.grid, save=None, | ||
if factor: # Apply the imaging condition at the snapshots of the full wavefield | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Leave a blank line between the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Move the comment inside the body of the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fixed |
||
u = create_snapshot_time_function(model, 'u', geometry, space_order, factor) | ||
else:# Apply the imaging condition at every time step of the full wavefield | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Move the comment inside the body of the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fixed |
||
u = TimeFunction(name='u', grid=model.grid, | ||
save=geometry.nt if save else None, | ||
time_order=2, space_order=space_order) | ||
v = TimeFunction(name='v', grid=model.grid, save=None, | ||
time_order=2, space_order=space_order) | ||
rec = geometry.rec | ||
|
||
s = model.grid.stepping_dim.spacing | ||
eqn = iso_stencil(v, model, kernel, forward=False) | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. revert change, pep8 violation There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fixed |
||
if kernel == 'OT2': | ||
gradient_update = Inc(grad, - u * v.dt2) | ||
elif kernel == 'OT4': | ||
gradient_update = Inc(grad, - u * v.dt2 - s**2 / 12.0 * u.biharmonic(m**(-2)) * v) | ||
# Add expression for receiver injection | ||
gradient_update = Inc(grad, - u * v.dt2 | ||
- s**2 / 12.0 * u.biharmonic(m**(-2)) * v) | ||
|
||
# Add expression for receiver injection | ||
receivers = rec.inject(field=v.backward, expr=rec * s**2 / m) | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. same as above There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fixed |
||
# Substitute spacing terms to reduce flops | ||
return Operator(eqn + receivers + [gradient_update], subs=model.spacing_map, | ||
name='Gradient', **kwargs) | ||
name='Gradient', **kwargs) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. re-indent There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fixed |
||
|
||
|
||
def BornOperator(model, geometry, space_order=4, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,7 @@ | ||
from devito import Function, TimeFunction, DevitoCheckpoint, CheckpointOperator, Revolver | ||
from devito.tools import memoized_meth | ||
from examples.seismic.acoustic.operators import ( | ||
ForwardOperator, AdjointOperator, GradientOperator, BornOperator | ||
ForwardOperator, AdjointOperator, GradientOperator, BornOperator, create_snapshot_time_function | ||
) | ||
|
||
|
||
|
@@ -23,6 +23,7 @@ class AcousticWaveSolver: | |
space_order: int, optional | ||
Order of the spatial stencil discretisation. Defaults to 4. | ||
""" | ||
|
||
def __init__(self, model, geometry, kernel='OT2', space_order=4, **kwargs): | ||
self.model = model | ||
self.model._initialize_bcs(bcs="damp") | ||
|
@@ -44,11 +45,11 @@ def dt(self): | |
return self.model.critical_dt | ||
|
||
@memoized_meth | ||
def op_fwd(self, save=None): | ||
def op_fwd(self, save=None, factor=None): | ||
"""Cached operator for forward runs with buffered wavefield""" | ||
return ForwardOperator(self.model, save=save, geometry=self.geometry, | ||
kernel=self.kernel, space_order=self.space_order, | ||
**self._kwargs) | ||
factor=factor, **self._kwargs) | ||
|
||
@memoized_meth | ||
def op_adj(self): | ||
|
@@ -58,11 +59,11 @@ def op_adj(self): | |
**self._kwargs) | ||
|
||
@memoized_meth | ||
def op_grad(self, save=True): | ||
def op_grad(self, save=True, factor=None): | ||
"""Cached operator for gradient runs""" | ||
return GradientOperator(self.model, save=save, geometry=self.geometry, | ||
kernel=self.kernel, space_order=self.space_order, | ||
**self._kwargs) | ||
factor=factor, **self._kwargs) | ||
|
||
@memoized_meth | ||
def op_born(self): | ||
|
@@ -71,7 +72,7 @@ def op_born(self): | |
kernel=self.kernel, space_order=self.space_order, | ||
**self._kwargs) | ||
|
||
def forward(self, src=None, rec=None, u=None, model=None, save=None, **kwargs): | ||
def forward(self, src=None, rec=None, u=None, model=None, save=None, factor=None, **kwargs): | ||
""" | ||
Forward modelling function that creates the necessary | ||
data objects for running a forward modelling operator. | ||
|
@@ -90,30 +91,38 @@ def forward(self, src=None, rec=None, u=None, model=None, save=None, **kwargs): | |
The time-constant velocity. | ||
save : bool, optional | ||
Whether or not to save the entire (unrolled) wavefield. | ||
factor : int, optional | ||
Downsampling factor to save snapshots of the wavefield. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Indent There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fixed |
||
|
||
Returns | ||
------- | ||
Receiver, wavefield and performance summary | ||
""" | ||
# Source term is read-only, so re-use the default | ||
src = src or self.geometry.src | ||
# Create a new receiver object to store the result | ||
# Create a new receiver object to store the result | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. drop trailing space There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fixed |
||
rec = rec or self.geometry.rec | ||
|
||
# Create the forward wavefield if not provided | ||
u = u or TimeFunction(name='u', grid=self.model.grid, | ||
save=self.geometry.nt if save else None, | ||
time_order=2, space_order=self.space_order) | ||
|
||
save=self.geometry.nt if save and factor is None else None, | ||
time_order=2, space_order=self.space_order) | ||
if factor: # Create snapshots of the forward wavefield | ||
usnaps = create_snapshot_time_function(self.model, 'usnaps', self.geometry, self.space_order, factor) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. variable definitions should not be conditional-dependent There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Now |
||
|
||
model = model or self.model | ||
# Pick vp from model unless explicitly provided | ||
kwargs.update(model.physical_params(**kwargs)) | ||
|
||
# Execute operator and return wavefield and receiver data | ||
summary = self.op_fwd(save).apply(src=src, rec=rec, u=u, | ||
dt=kwargs.pop('dt', self.dt), **kwargs) | ||
|
||
return rec, u, summary | ||
if factor: # Return snapshots of the forward wavefield | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. since factor is passed down to There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is there a better way to do this, as the code did not run correctly without the condition? I made the return statement conditional so as not to break people's code, so I kept the number of returned objects at three. |
||
summary = self.op_fwd(save, factor).apply(src=src, rec=rec, u=u, usnaps=usnaps, | ||
dt=kwargs.pop('dt', self.dt), **kwargs) | ||
return rec, usnaps, summary | ||
else:# Return the full forward wavefield | ||
summary = self.op_fwd(save, factor).apply(src=src, rec=rec, u=u, | ||
dt=kwargs.pop('dt', self.dt), **kwargs) | ||
return rec, u, summary | ||
|
||
def adjoint(self, rec, srca=None, v=None, model=None, **kwargs): | ||
""" | ||
|
@@ -145,7 +154,7 @@ def adjoint(self, rec, srca=None, v=None, model=None, **kwargs): | |
# Create the adjoint wavefield if not provided | ||
v = v or TimeFunction(name='v', grid=self.model.grid, | ||
time_order=2, space_order=self.space_order) | ||
|
||
model = model or self.model | ||
# Pick vp from model unless explicitly provided | ||
kwargs.update(model.physical_params(**kwargs)) | ||
|
@@ -156,7 +165,7 @@ def adjoint(self, rec, srca=None, v=None, model=None, **kwargs): | |
return srca, v, summary | ||
|
||
def jacobian_adjoint(self, rec, u, src=None, v=None, grad=None, model=None, | ||
checkpointing=False, **kwargs): | ||
factor=None, checkpointing=False, **kwargs): | ||
""" | ||
Gradient modelling function for computing the adjoint of the | ||
Linearized Born modelling function, ie. the action of the | ||
|
@@ -168,6 +177,8 @@ def jacobian_adjoint(self, rec, u, src=None, v=None, grad=None, model=None, | |
Receiver data. | ||
u : TimeFunction | ||
Full wavefield `u` (created with save=True). | ||
usnaps : TimeFunction | ||
Snapshots of the wavefield `u`. | ||
v : TimeFunction, optional | ||
Stores the computed wavefield. | ||
grad : Function, optional | ||
|
@@ -176,12 +187,22 @@ def jacobian_adjoint(self, rec, u, src=None, v=None, grad=None, model=None, | |
Object containing the physical parameters. | ||
vp : Function or float, optional | ||
The time-constant velocity. | ||
checkpointing : boolean, optional | ||
Flag to enable checkpointing (default False). | ||
Cannot be used with snapshotting. | ||
factor : int, optional | ||
Downsampling factor for the saved snapshots of the wavefield `u`. | ||
Cannot be used with checkpointing. | ||
|
||
Returns | ||
------- | ||
Gradient field and performance summary. | ||
""" | ||
dt = kwargs.pop('dt', self.dt) | ||
# Check that snapshotting and checkpointing are not used together | ||
if factor is not None and checkpointing: | ||
raise ValueError("Cannot use snapshotting (factor) and checkpointing simultaneously.") | ||
|
||
# Gradient symbol | ||
grad = grad or Function(name='grad', grid=self.model.grid) | ||
|
||
|
@@ -209,8 +230,9 @@ def jacobian_adjoint(self, rec, u, src=None, v=None, grad=None, model=None, | |
wrp.apply_forward() | ||
summary = wrp.apply_reverse() | ||
else: | ||
summary = self.op_grad().apply(rec=rec, grad=grad, v=v, u=u, dt=dt, | ||
summary = self.op_grad(factor=factor).apply(rec=rec, grad=grad, v=v, u=u, dt=dt, | ||
**kwargs) | ||
|
||
return grad, summary | ||
|
||
def jacobian(self, dmin, src=None, rec=None, u=None, U=None, model=None, **kwargs): | ||
|
Large diffs are not rendered by default.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
"usnaps" for homogeneity
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed