diff --git a/mosaic/cli/clusters/pbs.py b/mosaic/cli/clusters/pbs.py index 5c43a357..dfbd8fa9 100644 --- a/mosaic/cli/clusters/pbs.py +++ b/mosaic/cli/clusters/pbs.py @@ -78,6 +78,7 @@ def submission_script(name, num_nodes, num_workers, num_threads, node_memory): # use $(ppn) to use one worker per node and as many threads pr worker as cores in the node export OMP_NUM_THREADS=$num_threads_per_worker export OMP_PLACES=cores +export OMP_PROC_BIND=true # set any environment variables # for example: diff --git a/mosaic/cli/clusters/sge.py b/mosaic/cli/clusters/sge.py index 1b481ab2..2ba812c8 100644 --- a/mosaic/cli/clusters/sge.py +++ b/mosaic/cli/clusters/sge.py @@ -98,6 +98,7 @@ def submission_script(name, num_nodes, num_workers, num_threads, node_memory): # use $(ppn) to use one worker per node and as many threads pr worker as cores in the node export OMP_NUM_THREADS=$num_workers_per_node \\* $num_threads_per_worker export OMP_PLACES=cores +export OMP_PROC_BIND=true # set any environment variables # for example: diff --git a/mosaic/cli/clusters/slurm.py b/mosaic/cli/clusters/slurm.py index 9761040e..82b36ca6 100644 --- a/mosaic/cli/clusters/slurm.py +++ b/mosaic/cli/clusters/slurm.py @@ -64,6 +64,7 @@ def submission_script(name, num_nodes, num_workers, num_threads, node_memory): #SBATCH --account= #SBATCH --partition= #SBATCH --qos= +#SBATCH --exclusive name={name} num_nodes={num_nodes} @@ -81,6 +82,7 @@ def submission_script(name, num_nodes, num_workers, num_threads, node_memory): # use $(ppn) to use one worker per node and as many threads pr worker as cores in the node export OMP_NUM_THREADS=$num_threads_per_worker export OMP_PLACES=cores +export OMP_PROC_BIND=true # set any environment variables # for example: diff --git a/mosaic/runtime/monitor.py b/mosaic/runtime/monitor.py index 3bd1b31a..8a28780b 100644 --- a/mosaic/runtime/monitor.py +++ b/mosaic/runtime/monitor.py @@ -215,6 +215,8 @@ async def init_cluster(self, **kwargs): cmd = (f'srun {ssh_flags} --nodes=1 --ntasks=1 --tasks-per-node={num_cpus} ' f'--cpu-bind=mask_cpu:{cpu_mask} ' f'--oversubscribe ' + f'--distribution=block:block ' + f'--hint=nomultithread ' f'--nodelist={node_address} ' f'{remote_cmd}') diff --git a/mosaic/runtime/node.py b/mosaic/runtime/node.py index 2afe4485..fce8dc60 100644 --- a/mosaic/runtime/node.py +++ b/mosaic/runtime/node.py @@ -82,6 +82,7 @@ async def init_workers(self, **kwargs): num_threads = num_threads or num_cpus // num_workers self._num_threads = num_threads + worker_cpus = {} if self.mode == 'cluster': if num_workers*num_threads > num_cpus: raise ValueError('Requested number of CPUs per node (%d - num_workers*num_threads) ' @@ -97,15 +98,40 @@ async def init_workers(self, **kwargs): if numa_available: available_cpus = numa.info.numa_hardware_info()['node_cpu_info'] else: - available_cpus = {0: list(range(num_cpus))} + available_cpus = {worker_index: list(range(num_threads*worker_index, + num_threads*(worker_index+1))) + for worker_index in range(self._num_workers)} # Eliminate cores corresponding to hyperthreading for node_index, node_cpus in available_cpus.items(): node_cpus = [each for each in node_cpus if each < num_cpus] available_cpus[node_index] = node_cpus - available_cpus = sum(list(available_cpus.values()), []) - available_cpus.remove(num_cpus-1) + node_ids = list(available_cpus.keys()) + num_nodes = len(available_cpus) + num_cpus_per_node = min([len(cpus) for cpus in available_cpus.values()]) + + # Distribute cores across workers + if num_nodes >= self._num_workers: + nodes_per_worker = num_nodes // self._num_workers + for worker_index in range(self._num_workers): + node_s = worker_index*nodes_per_worker + node_e = min((worker_index+1)*nodes_per_worker, num_nodes) + worker_cpus[worker_index] = sum([available_cpus[node_index] + for node_index in node_ids[node_s:node_e]], []) + + else: + workers_per_node = self._num_workers // num_nodes + cpus_per_worker = num_cpus_per_node // workers_per_node + for node_index, node_cpus in available_cpus.items(): + worker_s = node_index*workers_per_node + worker_e = min((node_index+1)*workers_per_node, self._num_workers) + worker_chunk = {} + for worker_index in range(worker_s, worker_e): + cpu_s = worker_index*cpus_per_worker + cpu_e = min((worker_index+1)*cpus_per_worker, len(node_cpus)) + worker_chunk[worker_index] = node_cpus[cpu_s:cpu_e] + worker_cpus.update(worker_chunk) for worker_index in range(self._num_workers): indices = self.indices + (worker_index,) @@ -118,17 +144,10 @@ def start_worker(*args, **extra_kwargs): mosaic.init('worker', *args, **kwargs, wait=True) - worker_cpus = None - if self.mode == 'cluster': - start_cpu = worker_index * num_threads - end_cpu = min((worker_index + 1) * num_threads, len(available_cpus)) - - worker_cpus = available_cpus[start_cpu:end_cpu] - worker_proxy = RuntimeProxy(name='worker', indices=indices) worker_subprocess = subprocess(start_worker)(name=worker_proxy.uid, daemon=False, - cpu_affinity=worker_cpus) + cpu_affinity=worker_cpus.get(worker_index, None)) worker_subprocess.start_process() worker_proxy.subprocess = worker_subprocess diff --git a/mosaic/types/struct.py b/mosaic/types/struct.py index 280d51c2..bb2cac14 100644 --- a/mosaic/types/struct.py +++ b/mosaic/types/struct.py @@ -93,6 +93,9 @@ def __getattr__(self, item): def __getitem__(self, item): return self._get(item) + def __delitem__(self, item): + self.__dict__['_content'].__delitem__(item) + def get(self, item, default=None): """ Returns an item from the Struct or a default value if it is not found. diff --git a/stride/__init__.py b/stride/__init__.py index 6b303a20..46724c00 100644 --- a/stride/__init__.py +++ b/stride/__init__.py @@ -301,7 +301,7 @@ async def loop(worker, shot_id): if dump: optimiser.variable.dump(path=problem.output_folder, project_name=problem.name, - version=iteration.abs_id) + version=iteration.abs_id+1) logger.perf('Done iteration %d (out of %d), ' 'block %d (out of %d) - Total loss %e' % diff --git a/stride/optimisation/pipelines/default_pipelines.py b/stride/optimisation/pipelines/default_pipelines.py index 34bc6d01..192953f4 100644 --- a/stride/optimisation/pipelines/default_pipelines.py +++ b/stride/optimisation/pipelines/default_pipelines.py @@ -24,6 +24,9 @@ class ProcessWavelets(Pipeline): def __init__(self, steps=None, no_grad=False, **kwargs): steps = steps or [] + if kwargs.pop('check_traces', True): + steps.append('check_traces') + super().__init__(steps, no_grad=no_grad, **kwargs) @@ -43,6 +46,9 @@ class ProcessTraces(Pipeline): def __init__(self, steps=None, no_grad=False, **kwargs): steps = steps or [] + if kwargs.pop('check_traces', True): + steps.append('check_traces') + if kwargs.pop('filter_offsets', False): steps.append(('filter_offsets', False)) # do not raise if not present diff --git a/stride/optimisation/pipelines/steps/__init__.py b/stride/optimisation/pipelines/steps/__init__.py index 75bb70a8..7e97d2ae 100644 --- a/stride/optimisation/pipelines/steps/__init__.py +++ b/stride/optimisation/pipelines/steps/__init__.py @@ -8,6 +8,7 @@ from .mask import Mask from .mute_traces import MuteTraces from .clip import Clip +from .check_traces import CheckTraces steps_registry = { @@ -17,7 +18,8 @@ 'norm_per_trace': NormPerTrace, 'norm_field': NormField, 'smooth_field': SmoothField, - 'mask' : Mask, + 'mask': Mask, 'mute_traces': MuteTraces, 'clip': Clip, + 'check_traces': CheckTraces, } diff --git a/stride/optimisation/pipelines/steps/check_traces.py b/stride/optimisation/pipelines/steps/check_traces.py new file mode 100644 index 00000000..b1f5d859 --- /dev/null +++ b/stride/optimisation/pipelines/steps/check_traces.py @@ -0,0 +1,82 @@ +import numpy as np + +import mosaic + +from ....core import Operator + + +class CheckTraces(Operator): + """ + Check a set of time traces for NaNs, Inf, etc. + + Parameters + ---------- + raise_incorrect : bool, optional + Whether to raise an exception if there are incorrect traces. + Defaults to True. + filter_incorrect : bool, optional + Whether to filter out traces that are incorrect. Defaults to False. + + """ + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + self.raise_incorrect = kwargs.pop('raise_incorrect', True) + self.filter_incorrect = kwargs.pop('filter_incorrect', False) + + self._num_traces = None + + def forward(self, *traces, **kwargs): + self._num_traces = len(traces) + + filtered = [] + for each in traces: + filtered.append(self._apply(each, **kwargs)) + + if len(traces) > 1: + return tuple(filtered) + + else: + return filtered[0] + + def adjoint(self, *d_traces, **kwargs): + d_traces = d_traces[:self._num_traces] + + self._num_traces = None + + if len(d_traces) > 1: + return d_traces + + else: + return d_traces[0] + + def _apply(self, traces, **kwargs): + raise_incorrect = kwargs.pop('raise_incorrect', self.raise_incorrect) + filter_incorrect = kwargs.pop('filter_incorrect', self.filter_incorrect) + + out_traces = traces.alike(name='checked_%s' % traces.name) + filtered = traces.extended_data.copy() + + is_nan = np.any(np.isnan(filtered), axis=-1) + is_inf = np.any(np.isinf(filtered), axis=-1) + + if np.any(is_nan) or np.any(is_inf): + msg = 'Nan or inf detected in %s' % traces.name + + problem = kwargs.pop('problem', None) + shot_id = problem.shot.id if problem is not None else kwargs.pop('shot_id', None) + if shot_id is not None: + msg = '(ShotID %d) ' % shot_id + msg + + if raise_incorrect: + raise RuntimeError(msg) + else: + mosaic.logger().warn(msg) + + if filter_incorrect: + filtered[is_nan | is_inf, :] = 0 + + out_traces.extended_data[:] = filtered + + return out_traces diff --git a/stride/physics/common/devito.py b/stride/physics/common/devito.py index a637353c..f9d02a69 100644 --- a/stride/physics/common/devito.py +++ b/stride/physics/common/devito.py @@ -134,12 +134,13 @@ def _cached(func): @functools.wraps(func) def cached_wrapper(self, *args, **kwargs): name = args[0] - cached = kwargs.pop('cached', True) + cached = kwargs.get('cached', True) + replace_cached = kwargs.get('replace_cached', True) if cached is True: fun = self.vars.get(name, None) if fun is not None: - _args, _kwargs = self._args[name] + _args, _kwargs = self.cached_args[name] same_args = True for arg, _arg in zip(args, _args): @@ -157,10 +158,18 @@ def cached_wrapper(self, *args, **kwargs): if same_args: return fun + elif replace_cached: + self.vars.pop(name, None) + self.cached_args.pop(name, None) + + if name not in self.cached_funcs: + self.cached_funcs[name] = cached_wrapper + fun = func(self, *args, **kwargs) - self.vars[name] = fun - self._args[name] = (args, kwargs) + if replace_cached: + self.vars[name] = fun + self.cached_args[name] = (args, kwargs) return fun @@ -196,7 +205,8 @@ def __init__(self, space_order, time_order, time_dim=None, **kwargs): super().__init__(**kwargs) self.vars = Struct() - self._args = Struct() + self.cached_args = Struct() + self.cached_funcs = Struct() self.space_order = space_order self.time_order = time_order @@ -286,7 +296,7 @@ def function(self, name, space_order=None, **kwargs): Generated function. """ - space_order = space_order or self.space_order + space_order = self.space_order if space_order is None else space_order fun = devito.Function(name=name, grid=kwargs.pop('grid', self.devito_grid), @@ -318,8 +328,8 @@ def time_function(self, name, space_order=None, time_order=None, **kwargs): Generated function. """ - space_order = space_order or self.space_order - time_order = time_order or self.time_order + space_order = self.space_order if space_order is None else space_order + time_order = self.time_order if time_order is None else time_order fun = devito.TimeFunction(name=name, grid=kwargs.pop('grid', self.devito_grid), @@ -350,7 +360,7 @@ def vector_function(self, name, space_order=None, **kwargs): Generated function. """ - space_order = space_order or self.space_order + space_order = self.space_order if space_order is None else space_order fun = devito.VectorFunction(name=name, grid=kwargs.pop('grid', self.devito_grid), @@ -382,8 +392,8 @@ def vector_time_function(self, name, space_order=None, time_order=None, **kwargs Generated function. """ - space_order = space_order or self.space_order - time_order = time_order or self.time_order + space_order = self.space_order if space_order is None else space_order + time_order = self.time_order if time_order is None else time_order fun = devito.VectorTimeFunction(name=name, grid=kwargs.pop('grid', self.devito_grid), @@ -416,8 +426,8 @@ def tensor_time_function(self, name, space_order=None, time_order=None, **kwargs Generated function. """ - space_order = space_order or self.space_order - time_order = time_order or self.time_order + space_order = self.space_order if space_order is None else space_order + time_order = self.time_order if time_order is None else time_order fun = devito.TensorTimeFunction(name=name, grid=kwargs.pop('grid', self.devito_grid), @@ -534,8 +544,8 @@ def sparse_time_function(self, name, num=1, space_order=None, time_order=None, Generated function. """ - space_order = space_order or self.space_order - time_order = time_order or self.time_order + space_order = self.space_order if space_order is None else space_order + time_order = self.time_order if time_order is None else time_order # Define variables p_dim = kwargs.pop('p_dim', devito.Dimension(name='p_%s' % name)) @@ -597,7 +607,7 @@ def sparse_function(self, name, num=1, space_order=None, Generated function. """ - space_order = space_order or self.space_order + space_order = self.space_order if space_order is None else space_order # Define variables p_dim = kwargs.pop('p_dim', devito.Dimension(name='p_%s' % name)) @@ -628,6 +638,26 @@ def sparse_function(self, name, num=1, space_order=None, return fun + def func(self, name, cached=False): + """ + Re-instantiate devito function, if ``name`` is cached. + + Parameters + ---------- + name : str + Name of the function. + cached : bool, optional + Whether to cache the result of the func call, defaults to ``False``. + + Returns + ------- + + """ + func = self.cached_funcs[name] + args, kwargs = self.cached_args[name] + return func(self, *args, **kwargs, + cached=cached, replace_cached=cached) + def deallocate(self, name, collect=False): """ Remove internal references to data buffers, if ``name`` is cached. @@ -655,6 +685,29 @@ def deallocate(self, name, collect=False): if collect: gc.collect() + def delete(self, name, collect=False): + """ + Remove internal references to devito function, if ``name`` is cached. + + Parameters + ---------- + name : str + Name of the function. + collect : bool, optional + Whether to garbage collect after deallocate, defaults to ``False``. + + Returns + ------- + + """ + if name in self.vars: + del self.vars[name] + del self.cached_funcs[name] + del self.cached_args[name] + + if collect: + devito.clear_cache(force=True) + def with_halo(self, data, value=None, time_dependent=False, is_vector=False): """ Pad ndarray with appropriate halo given the grid space order. @@ -839,6 +892,7 @@ def run(self, **kwargs): if arg.name in self.grid.vars: default_kwargs[arg.name] = self.grid.vars[arg.name] + autotune = kwargs.pop('autotune', None) default_kwargs.update(kwargs) if self.grid.time_dim: @@ -862,6 +916,14 @@ def run(self, **kwargs): runtime_kwargs[key] = value with devito.switchconfig(**self.devito_context, **runtime_context): + if autotune is None: + try: + tuned = self.devito_operator._state['autotuning'][-1]['tuned'] + runtime_kwargs.update(tuned) + runtime_kwargs['autotune'] = 'off' + except KeyError: + pass + self.devito_operator.apply(**runtime_kwargs) diff --git a/stride/physics/iso_acoustic/devito.py b/stride/physics/iso_acoustic/devito.py index feeb8d63..161d07f5 100644 --- a/stride/physics/iso_acoustic/devito.py +++ b/stride/physics/iso_acoustic/devito.py @@ -268,10 +268,12 @@ async def before_forward(self, wavelets, vp, rho=None, alpha=None, **kwargs): save_wavefield |= alpha.needs_grad platform = kwargs.get('platform', 'cpu') + is_nvidia = platform is not None and 'nvidia' in platform + diff_source = kwargs.pop('diff_source', False) save_compression = kwargs.get('save_compression', 'bitcomp' if self.space.dim > 2 else None) - save_compression = save_compression if platform and 'nvidia' in platform and devito.pro_available else None + save_compression = save_compression if is_nvidia and devito.pro_available else None # If there's no previous operator, generate one if self.state_operator.devito_operator is None: @@ -304,13 +306,20 @@ async def before_forward(self, wavelets, vp, rho=None, alpha=None, **kwargs): # Define the saving of the wavefield if save_wavefield is True: - layers = devito.HostDevice if platform and 'nvidia' in platform else devito.NoLayers + space_order = None if self._needs_grad(rho, alpha) else 0 + layers = devito.HostDevice if is_nvidia else devito.NoLayers p_saved = self.dev_grid.undersampled_time_function('p_saved', bounds=kwargs.pop('save_bounds', None), factor=self.undersampling_factor, + space_order=space_order, layers=layers, compression=save_compression) + if not is_nvidia: + self.logger.perf('(ShotID %d) Expected wavefield size %.4f GB' % + (problem.shot_id, + np.prod(p_saved.shape_allocated)*p_saved.dtype().itemsize/1024**3)) + if self._needs_grad(wavelets, rho, alpha): p_saved_expr = p else: @@ -335,15 +344,6 @@ async def before_forward(self, wavelets, vp, rho=None, alpha=None, **kwargs): self.state_operator.compile() else: - # If the wavefield is lazily streamed, re-create every time - if platform and 'nvidia' in platform and devito.pro_available: - self.dev_grid.undersampled_time_function('p_saved', - bounds=kwargs.pop('save_bounds', None), - factor=self.undersampling_factor, - layers=devito.HostDevice, - compression=save_compression, - cached=False) - # If the source/receiver size has changed, then create new functions for them if num_sources != self.dev_grid.vars.src.npoint: self.dev_grid.sparse_time_function('src', num=num_sources, cached=False) @@ -418,11 +418,13 @@ async def run_forward(self, wavelets, vp, rho=None, alpha=None, **kwargs): rec=self.dev_grid.vars.rec, ) - platform = kwargs.get('platform', 'cpu') devito_args = kwargs.get('devito_args', {}) if 'p_saved' in self.dev_grid.vars: - functions['p_saved'] = self.dev_grid.vars.p_saved + if self._wavefield is None: + self._wavefield = self.dev_grid.func('p_saved') + + functions['p_saved'] = self._wavefield if 'nbits_compression' in kwargs or 'nbits' in devito_args: devito_args['nbits'] = kwargs.get('nbits_compression', @@ -467,8 +469,6 @@ async def after_forward(self, wavelets, vp, rho=None, alpha=None, **kwargs): save_wavefield |= alpha.needs_grad if save_wavefield: - self._wavefield = self.dev_grid.vars.p_saved - if os.environ.get('STRIDE_DUMP_WAVEFIELD', None) == 'yes': self.wavefield.dump(path=problem.output_folder, project_name=problem.name) @@ -507,7 +507,6 @@ def _rm_tmpdir(): shutil.rmtree(self._cache_folder, ignore_errors=True) raise - del self._wavefield self._wavefield = None self.dev_grid.deallocate('p_saved') @@ -619,22 +618,23 @@ async def before_adjoint(self, adjoint_source, wavelets, vp, rho=None, alpha=Non slices = [slice(extra, -extra) for extra in self.space.extra] slices = (slice(0, None),) + tuple(slices) + wavefield = self.dev_grid.func('p_saved') + if cache_location is None: inner_wavefield = np.asarray([np.frombuffer(decompress(*each), dtype=np.float32) for each in self._wavefield]) inner_wavefield = inner_wavefield.reshape((inner_wavefield.shape[0],) + self.space.shape) - self.dev_grid.vars.p_saved.data[slices] = inner_wavefield - - del self._wavefield - self._wavefield = None + wavefield.data[slices] = inner_wavefield else: filename = os.path.join(self._cache_folder, '%s-%s-%05d.npy' % (problem.name, 'P', shot.id)) - self.dev_grid.vars.p_saved.data[slices] = np.load(filename) + wavefield.data[slices] = np.load(filename) os.remove(filename) + self._wavefield = wavefield + # Set medium parameters vp_with_halo = self.dev_grid.with_halo(vp.extended_data) self.dev_grid.vars.vp.data_with_halo[:] = vp_with_halo @@ -686,7 +686,7 @@ async def run_adjoint(self, adjoint_source, wavelets, vp, rho=None, alpha=None, functions = dict( vp=self.dev_grid.vars.vp, rec=self.dev_grid.vars.rec, - p_saved=self.dev_grid.vars.p_saved, + p_saved=self._wavefield, ) devito_args = kwargs.get('devito_args', {}) @@ -723,9 +723,13 @@ async def after_adjoint(self, adjoint_source, wavelets, vp, rho=None, alpha=None Tuple with the gradients of the variables that need them """ - self._wavefield = None - + platform = kwargs.get('platform', 'cpu') deallocate = kwargs.get('deallocate', False) + + if platform and 'nvidia' in platform or deallocate: + self._wavefield = None + devito.clear_cache(force=True) + if deallocate: self.boundary.deallocate() self.dev_grid.deallocate('p_a') @@ -768,17 +772,17 @@ async def prepare_grad_vp(self, vp, **kwargs): bounds=kwargs.pop('save_bounds', None), deriv_order=2, fd_order=2) - p_dt2_fun = self.dev_grid.function('p_dt2') + p_dt2_fun = self.dev_grid.function('p_dt2', space_order=0) p_dt2_update = (devito.Eq(p_dt2_fun, p_dt2, subdomain=interior),) else: p_dt2 = p p_dt2_fun = p_dt2 p_dt2_update = () - grad = self.dev_grid.function('grad_vp') + grad = self.dev_grid.function('grad_vp', space_order=0) grad_update = devito.Inc(grad, p_dt2_fun * p_a, subdomain=interior) - prec = self.dev_grid.function('prec_vp') + prec = self.dev_grid.function('prec_vp', space_order=0) prec_update = devito.Inc(prec, p_dt2_fun * p_dt2_fun, subdomain=interior) return p_dt2_update + (grad_update, prec_update) @@ -861,14 +865,16 @@ async def prepare_grad_rho(self, rho, **kwargs): grad_term = - devito.grad(buoy, shift=-0.5).dot(devito.grad(p, shift=-0.5)) \ - buoy * p.laplace + grad_rho_fun = self.dev_grid.function('grad_rho_fun', space_order=0) + grad_term_update = (devito.Eq(grad_rho_fun, grad_term, subdomain=interior),) - grad = self.dev_grid.function('grad_rho') - grad_update = devito.Inc(grad, grad_term * p_a, subdomain=interior) + grad = self.dev_grid.function('grad_rho', space_order=0) + grad_update = devito.Inc(grad, grad_rho_fun * p_a, subdomain=interior) - prec = self.dev_grid.function('prec_rho') - prec_update = devito.Inc(prec, grad_term * grad_term, subdomain=interior) + prec = self.dev_grid.function('prec_rho', space_order=0) + prec_update = devito.Inc(prec, grad_rho_fun * grad_rho_fun, subdomain=interior) - return grad_update, prec_update + return grad_term_update + (grad_update, prec_update) async def init_grad_rho(self, rho, **kwargs): """ @@ -1134,7 +1140,7 @@ def _stencil(self, field, wavelets, vp, rho=None, alpha=None, direction='forward if self.drp: extra_functions = () if rho_fun is not None: - extra_functions = (rho_fun,) + extra_functions = (rho_fun, buoy_fun,) subs = self._symbolic_coefficients(field, laplacian, vp_fun, *extra_functions) @@ -1215,7 +1221,9 @@ def _stencil(self, field, wavelets, vp, rho=None, alpha=None, direction='forward return sub_befores + eq_before + stencils + eq_after + sub_afters def _medium_functions(self, vp, rho=None, alpha=None, **kwargs): - _kwargs = dict(coefficients='symbolic' if self.drp else 'standard') + _kwargs = { + 'coefficients': 'symbolic' if self.drp else 'standard', + } vp_fun = self.dev_grid.function('vp', **_kwargs) vp2_fun = vp_fun**2