-
Notifications
You must be signed in to change notification settings - Fork 230
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement snapshotting for the acoustic wave equation #2474
base: master
Are you sure you want to change the base?
Changes from all commits
44694f5
02be5b3
d909155
c0791d9
71a99cc
46bbf56
0353a78
810ddd3
6c1bc75
7f12068
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,6 @@ | ||
from devito import Eq, Operator, Function, TimeFunction, Inc, solve, sign | ||
from devito import ( | ||
Eq, Operator, Function, TimeFunction, Inc, solve, sign, ConditionalDimension | ||
) | ||
from devito.symbolics import retrieve_functions, INT, retrieve_derivatives | ||
|
||
|
||
|
@@ -107,8 +109,60 @@ def iso_stencil(field, model, kernel, **kwargs): | |
return eqns | ||
|
||
|
||
def create_snapshot_time_function(model, name, geometry, space_order, factor=None): | ||
""" | ||
Create a TimeFunction to store snapshots of the wavefield during simulation. | ||
|
||
Parameters | ||
---------- | ||
model : Model | ||
Object containing the physical parameters. | ||
name : str | ||
The name of the snapshot TimeFunction. | ||
geometry : AcquisitionGeometry | ||
Geometry object that contains the source (SparseTimeFunction) and | ||
receivers (SparseTimeFunction) and their position. | ||
space_order : int | ||
Space discretization order. | ||
factor : int, optional | ||
Downsampling factor for snapshot storage. If provided, snapshots are saved | ||
every `factor` time steps. Defaults to None, which disables snapshot saving. | ||
|
||
Returns | ||
------- | ||
TimeFunction configured to store snapshots of the wavefield. | ||
The snapshots are saved based on the provided downsampling factor. | ||
|
||
Notes | ||
----- | ||
- If `factor` is provided, snapshots will be saved every `factor` time steps. | ||
The number of snapshots (`nsnaps`) is calculated as: | ||
`(geometry.nt + factor - 1) // factor`. | ||
- If `factor` is None, the snapshot storage is disabled (`save=None`). | ||
- The `time_dim` of the TimeFunction is subsampled using a ConditionalDimension. | ||
""" | ||
if factor is not None: | ||
nsnaps = (geometry.nt + factor - 1) // factor | ||
else: | ||
nsnaps = None | ||
|
||
time_subsampled = ConditionalDimension( | ||
't_sub', parent=model.grid.time_dim, factor=factor | ||
) | ||
|
||
usnaps = TimeFunction( | ||
name=name, | ||
grid=model.grid, | ||
time_order=2, | ||
space_order=space_order, | ||
save=nsnaps, | ||
time_dim=time_subsampled | ||
) | ||
return usnaps | ||
|
||
|
||
def ForwardOperator(model, geometry, space_order=4, | ||
save=False, kernel='OT2', **kwargs): | ||
save=False, kernel='OT2', factor=None, **kwargs): | ||
""" | ||
Construct a forward modelling operator in an acoustic medium. | ||
|
||
|
@@ -126,13 +180,17 @@ def ForwardOperator(model, geometry, space_order=4, | |
Defaults to False. | ||
kernel : str, optional | ||
Type of discretization, 'OT2' or 'OT4'. | ||
factor : int, optional | ||
Downsampling factor to save snapshots of the wavefield. | ||
""" | ||
m = model.m | ||
|
||
# Create symbols for forward wavefield, source and receivers | ||
save_value = geometry.nt if save and factor is None else None | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. To make the code more homogeneous and easier to read, what we could do, imho, is the following:
In my opinion, this will dramatically clean up the code here. At the moment, the proliferation of @mloubout is on vacation ATM, but would be useful to hear from his thoughts about this matter |
||
u = TimeFunction(name='u', grid=model.grid, | ||
save=geometry.nt if save else None, | ||
save=save_value, | ||
time_order=2, space_order=space_order) | ||
|
||
src = geometry.src | ||
rec = geometry.rec | ||
|
||
|
@@ -145,9 +203,18 @@ def ForwardOperator(model, geometry, space_order=4, | |
# Create interpolation expression for receivers | ||
rec_term = rec.interpolate(expr=u) | ||
|
||
# Build operator equations | ||
equations = eqn + src_term + rec_term | ||
|
||
usnaps = create_snapshot_time_function(model, 'usnaps', geometry, space_order, factor) | ||
|
||
if factor: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This needs to be wrapped into a utility function as it's duplicated below There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I've created a function to construct usnaps. |
||
# Add equation to save snapshots | ||
snapshot_eq = Eq(usnaps, u) | ||
equations += [snapshot_eq] | ||
|
||
# Substitute spacing terms to reduce flops | ||
return Operator(eqn + src_term + rec_term, subs=model.spacing_map, | ||
name='Forward', **kwargs) | ||
return Operator(equations, subs=model.spacing_map, name='Forward', **kwargs) | ||
|
||
|
||
def AdjointOperator(model, geometry, space_order=4, | ||
|
@@ -189,9 +256,9 @@ def AdjointOperator(model, geometry, space_order=4, | |
|
||
|
||
def GradientOperator(model, geometry, space_order=4, save=True, | ||
kernel='OT2', **kwargs): | ||
kernel='OT2', factor=None, **kwargs): | ||
""" | ||
Construct a gradient operator in an acoustic media. | ||
Construct a gradient operator in an acoustic medium. | ||
|
||
Parameters | ||
---------- | ||
|
@@ -206,13 +273,23 @@ def GradientOperator(model, geometry, space_order=4, save=True, | |
Option to store the entire (unrolled) wavefield. | ||
kernel : str, optional | ||
Type of discretization, centered or shifted. | ||
factor : int, optional | ||
Downsampling factor to save snapshots of the wavefield. | ||
""" | ||
m = model.m | ||
|
||
# Gradient symbol and wavefield symbols | ||
grad = Function(name='grad', grid=model.grid) | ||
u = TimeFunction(name='u', grid=model.grid, save=geometry.nt if save | ||
else None, time_order=2, space_order=space_order) | ||
|
||
if factor: | ||
# Apply the imaging condition at the snapshots of the full wavefield | ||
u = create_snapshot_time_function(model, 'u', geometry, space_order, factor) | ||
else: | ||
# Apply the imaging condition at every time step of the full wavefield | ||
u = TimeFunction(name='u', grid=model.grid, | ||
save=geometry.nt if save else None, | ||
time_order=2, space_order=space_order) | ||
|
||
v = TimeFunction(name='v', grid=model.grid, save=None, | ||
time_order=2, space_order=space_order) | ||
rec = geometry.rec | ||
|
@@ -224,6 +301,7 @@ def GradientOperator(model, geometry, space_order=4, save=True, | |
gradient_update = Inc(grad, - u * v.dt2) | ||
elif kernel == 'OT4': | ||
gradient_update = Inc(grad, - u * v.dt2 - s**2 / 12.0 * u.biharmonic(m**(-2)) * v) | ||
|
||
# Add expression for receiver injection | ||
receivers = rec.inject(field=v.backward, expr=rec * s**2 / m) | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,8 @@ | ||
from devito import Function, TimeFunction, DevitoCheckpoint, CheckpointOperator, Revolver | ||
from devito.tools import memoized_meth | ||
from examples.seismic.acoustic.operators import ( | ||
ForwardOperator, AdjointOperator, GradientOperator, BornOperator | ||
ForwardOperator, AdjointOperator, GradientOperator, BornOperator, | ||
create_snapshot_time_function | ||
) | ||
|
||
|
||
|
@@ -23,6 +24,7 @@ class AcousticWaveSolver: | |
space_order: int, optional | ||
Order of the spatial stencil discretisation. Defaults to 4. | ||
""" | ||
|
||
def __init__(self, model, geometry, kernel='OT2', space_order=4, **kwargs): | ||
self.model = model | ||
self.model._initialize_bcs(bcs="damp") | ||
|
@@ -44,11 +46,11 @@ def dt(self): | |
return self.model.critical_dt | ||
|
||
@memoized_meth | ||
def op_fwd(self, save=None): | ||
def op_fwd(self, save=None, factor=None): | ||
"""Cached operator for forward runs with buffered wavefield""" | ||
return ForwardOperator(self.model, save=save, geometry=self.geometry, | ||
kernel=self.kernel, space_order=self.space_order, | ||
**self._kwargs) | ||
factor=factor, **self._kwargs) | ||
|
||
@memoized_meth | ||
def op_adj(self): | ||
|
@@ -58,11 +60,11 @@ def op_adj(self): | |
**self._kwargs) | ||
|
||
@memoized_meth | ||
def op_grad(self, save=True): | ||
def op_grad(self, save=True, factor=None): | ||
"""Cached operator for gradient runs""" | ||
return GradientOperator(self.model, save=save, geometry=self.geometry, | ||
kernel=self.kernel, space_order=self.space_order, | ||
**self._kwargs) | ||
factor=factor, **self._kwargs) | ||
|
||
@memoized_meth | ||
def op_born(self): | ||
|
@@ -71,7 +73,11 @@ def op_born(self): | |
kernel=self.kernel, space_order=self.space_order, | ||
**self._kwargs) | ||
|
||
def forward(self, src=None, rec=None, u=None, model=None, save=None, **kwargs): | ||
def forward( | ||
self, src=None, rec=None, u=None, usnaps=None, | ||
model=None, save=None, factor=None, | ||
**kwargs | ||
): | ||
""" | ||
Forward modelling function that creates the necessary | ||
data objects for running a forward modelling operator. | ||
|
@@ -90,6 +96,8 @@ def forward(self, src=None, rec=None, u=None, model=None, save=None, **kwargs): | |
The time-constant velocity. | ||
save : bool, optional | ||
Whether or not to save the entire (unrolled) wavefield. | ||
factor : int, optional | ||
Downsampling factor to save snapshots of the wavefield. | ||
|
||
Returns | ||
------- | ||
|
@@ -101,19 +109,35 @@ def forward(self, src=None, rec=None, u=None, model=None, save=None, **kwargs): | |
rec = rec or self.geometry.rec | ||
|
||
# Create the forward wavefield if not provided | ||
save_value = self.geometry.nt if save and factor is None else None | ||
u = u or TimeFunction(name='u', grid=self.model.grid, | ||
save=self.geometry.nt if save else None, | ||
save=save_value, | ||
time_order=2, space_order=self.space_order) | ||
|
||
# Create snapshots of the forward wavefield | ||
usnaps = usnaps or create_snapshot_time_function( | ||
self.model, 'usnaps', self.geometry, self.space_order, factor | ||
) | ||
|
||
model = model or self.model | ||
# Pick vp from model unless explicitly provided | ||
kwargs.update(model.physical_params(**kwargs)) | ||
|
||
# Execute operator and return wavefield and receiver data | ||
summary = self.op_fwd(save).apply(src=src, rec=rec, u=u, | ||
dt=kwargs.pop('dt', self.dt), **kwargs) | ||
|
||
return rec, u, summary | ||
if factor: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No, usnap needs to be create here like There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fixed. usnaps is created now. |
||
# Return snapshots of the forward wavefield | ||
summary = self.op_fwd(save, factor).apply( | ||
src=src, rec=rec, u=u, usnaps=usnaps, | ||
dt=kwargs.pop('dt', self.dt), **kwargs | ||
) | ||
return rec, usnaps, summary | ||
else: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. shouldn't need if else only kwargs There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fixed |
||
# Return the full forward wavefield | ||
summary = self.op_fwd(save, factor).apply( | ||
src=src, rec=rec, u=u, | ||
dt=kwargs.pop('dt', self.dt), **kwargs | ||
) | ||
return rec, u, summary | ||
|
||
def adjoint(self, rec, srca=None, v=None, model=None, **kwargs): | ||
""" | ||
|
@@ -156,7 +180,7 @@ def adjoint(self, rec, srca=None, v=None, model=None, **kwargs): | |
return srca, v, summary | ||
|
||
def jacobian_adjoint(self, rec, u, src=None, v=None, grad=None, model=None, | ||
checkpointing=False, **kwargs): | ||
factor=None, checkpointing=False, **kwargs): | ||
""" | ||
Gradient modelling function for computing the adjoint of the | ||
Linearized Born modelling function, ie. the action of the | ||
|
@@ -168,6 +192,8 @@ def jacobian_adjoint(self, rec, u, src=None, v=None, grad=None, model=None, | |
Receiver data. | ||
u : TimeFunction | ||
Full wavefield `u` (created with save=True). | ||
usnaps : TimeFunction | ||
Snapshots of the wavefield `u`. | ||
v : TimeFunction, optional | ||
Stores the computed wavefield. | ||
grad : Function, optional | ||
|
@@ -176,12 +202,25 @@ def jacobian_adjoint(self, rec, u, src=None, v=None, grad=None, model=None, | |
Object containing the physical parameters. | ||
vp : Function or float, optional | ||
The time-constant velocity. | ||
checkpointing : boolean, optional | ||
Flag to enable checkpointing (default False). | ||
Cannot be used with snapshotting. | ||
factor : int, optional | ||
Downsampling factor for the saved snapshots of the wavefield `u`. | ||
Cannot be used with checkpointing. | ||
|
||
Returns | ||
------- | ||
Gradient field and performance summary. | ||
""" | ||
dt = kwargs.pop('dt', self.dt) | ||
# Check that snapshotting and checkpointing are not used together | ||
if factor is not None and checkpointing: | ||
raise ValueError( | ||
"Cannot use snapshotting (factor) and " | ||
"checkpointing simultaneously." | ||
) | ||
|
||
# Gradient symbol | ||
grad = grad or Function(name='grad', grid=self.model.grid) | ||
|
||
|
@@ -209,8 +248,10 @@ def jacobian_adjoint(self, rec, u, src=None, v=None, grad=None, model=None, | |
wrp.apply_forward() | ||
summary = wrp.apply_reverse() | ||
else: | ||
summary = self.op_grad().apply(rec=rec, grad=grad, v=v, u=u, dt=dt, | ||
**kwargs) | ||
summary = self.op_grad(factor=factor).apply( | ||
rec=rec, grad=grad, v=v, u=u, dt=dt, **kwargs | ||
) | ||
|
||
return grad, summary | ||
|
||
def jacobian(self, dmin, src=None, rec=None, u=None, U=None, model=None, **kwargs): | ||
|
Large diffs are not rendered by default.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would lift that into
seismic/utils.py
or some new file so it can be easily reused from other pdes. It can probably be done directly from the wavefield too i,eMakes it also decoupled from those geometry/model. THis would also allow to directly return the equation
Eq(u,usave)
as well