diff --git a/stride/__init__.py b/stride/__init__.py index d10ac7e..6b3123e 100644 --- a/stride/__init__.py +++ b/stride/__init__.py @@ -228,9 +228,6 @@ async def adjoint(problem, pde, loss, optimisation_loop, optimiser, *args, **kwa filter_traces = kwargs.pop('filter_traces', True) filter_wavelets = kwargs.pop('filter_wavelets', filter_traces) - # TODO - test_adjoint = kwargs.pop('test_adjoint', False) - fw3d_mode = kwargs.get('fw3d_mode', False) filter_wavelets_relaxation = kwargs.pop('filter_wavelets_relaxation', 0.75 if not fw3d_mode else 0.725) @@ -358,41 +355,22 @@ async def loop(worker, shot_id): modelled = traces.outputs[0] observed = traces.outputs[1] - if not test_adjoint: - # calculate loss - fun = await loss(modelled, observed, - keep_residual=keep_residual, - iteration=iteration, problem=sub_problem, - runtime=worker, **_kwargs).result() - - iteration.add_loss(fun) - logger.perf('Functional value for shot %d: %s' % (shot_id, fun)) - - # run adjoint - await fun.adjoint(**_kwargs) - iteration.add_completed(sub_problem.shot) - - logger.perf('Retrieved gradient for shot %d (%d out of %d)' - % (sub_problem.shot_id, - iteration.num_completed, num_shots)) + # calculate loss + fun = loss(modelled, observed, + keep_residual=keep_residual, + iteration=iteration, problem=sub_problem, + runtime=worker, **_kwargs) - else: - # calculate loss - fun = loss(modelled, observed, - keep_residual=keep_residual, - iteration=iteration, problem=sub_problem, - runtime=worker, **_kwargs) - - # run adjoint - fun_value = await fun.remote.adjoint(**_kwargs).result() + # run adjoint + fun_value = await fun.remote.adjoint(**_kwargs).result() - iteration.add_loss(fun_value) - logger.perf('Functional value for shot %d: %s' % (shot_id, fun_value)) + iteration.add_loss(fun_value) + logger.perf('Functional value for shot %d: %s' % (shot_id, fun_value)) - iteration.add_completed(sub_problem.shot) - logger.perf('Retrieved gradient for shot %d (%d out of %d)' - % (sub_problem.shot_id, - iteration.num_completed, num_shots)) + iteration.add_completed(sub_problem.shot) + logger.perf('Retrieved gradient for shot %d (%d out of %d)' + % (sub_problem.shot_id, + iteration.num_completed, num_shots)) await loop