Skip to content

Commit

Permalink
Make test default run method
Browse files Browse the repository at this point in the history
  • Loading branch information
ccuetom committed Nov 14, 2024
1 parent 353d099 commit c58a7ba
Showing 1 changed file with 13 additions and 35 deletions.
48 changes: 13 additions & 35 deletions stride/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,9 +228,6 @@ async def adjoint(problem, pde, loss, optimisation_loop, optimiser, *args, **kwa
filter_traces = kwargs.pop('filter_traces', True)
filter_wavelets = kwargs.pop('filter_wavelets', filter_traces)

# TODO
test_adjoint = kwargs.pop('test_adjoint', False)

fw3d_mode = kwargs.get('fw3d_mode', False)
filter_wavelets_relaxation = kwargs.pop('filter_wavelets_relaxation',
0.75 if not fw3d_mode else 0.725)
Expand Down Expand Up @@ -358,41 +355,22 @@ async def loop(worker, shot_id):
modelled = traces.outputs[0]
observed = traces.outputs[1]

if not test_adjoint:
# calculate loss
fun = await loss(modelled, observed,
keep_residual=keep_residual,
iteration=iteration, problem=sub_problem,
runtime=worker, **_kwargs).result()

iteration.add_loss(fun)
logger.perf('Functional value for shot %d: %s' % (shot_id, fun))

# run adjoint
await fun.adjoint(**_kwargs)
iteration.add_completed(sub_problem.shot)

logger.perf('Retrieved gradient for shot %d (%d out of %d)'
% (sub_problem.shot_id,
iteration.num_completed, num_shots))
# calculate loss
fun = loss(modelled, observed,
keep_residual=keep_residual,
iteration=iteration, problem=sub_problem,
runtime=worker, **_kwargs)

else:
# calculate loss
fun = loss(modelled, observed,
keep_residual=keep_residual,
iteration=iteration, problem=sub_problem,
runtime=worker, **_kwargs)

# run adjoint
fun_value = await fun.remote.adjoint(**_kwargs).result()
# run adjoint
fun_value = await fun.remote.adjoint(**_kwargs).result()

iteration.add_loss(fun_value)
logger.perf('Functional value for shot %d: %s' % (shot_id, fun_value))
iteration.add_loss(fun_value)
logger.perf('Functional value for shot %d: %s' % (shot_id, fun_value))

iteration.add_completed(sub_problem.shot)
logger.perf('Retrieved gradient for shot %d (%d out of %d)'
% (sub_problem.shot_id,
iteration.num_completed, num_shots))
iteration.add_completed(sub_problem.shot)
logger.perf('Retrieved gradient for shot %d (%d out of %d)'
% (sub_problem.shot_id,
iteration.num_completed, num_shots))

await loop

Expand Down

0 comments on commit c58a7ba

Please sign in to comment.