From 4621795700f15fad8f10cefcf62bb994b8691032 Mon Sep 17 00:00:00 2001 From: perdaug <2019828p@student.gla.ac.uk> Date: Thu, 7 Sep 2017 14:14:04 +0300 Subject: [PATCH] [MaxVar split, Part 2] Added the visualisation improvements. --- CHANGELOG.rst | 2 + elfi/methods/bo/gpy_regression.py | 10 ++ elfi/methods/parameter_inference.py | 174 +++++++++++----------- elfi/visualization/interactive.py | 62 ++++---- elfi/visualization/visualization.py | 208 +++++++++++++++++++++++---- tests/unit/test_document_examples.py | 2 +- 6 files changed, 309 insertions(+), 149 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index a832c562b..c49cb39d3 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -10,6 +10,8 @@ Changelog - Improved performance when rerunning inference using stored data - Change SMC to use ModelPrior, use to immediately reject invalid proposals - Added the general Gaussian noise example model (fixed covariance) +- Improved the interactive plotting (customised for the MaxVar-based acquisition methods) +- Added a pair-wise plotting to plot_state() (a way to visualise n-dimensional parameters) 0.6.1 (2017-07-21) ------------------ diff --git a/elfi/methods/bo/gpy_regression.py b/elfi/methods/bo/gpy_regression.py index 19b0a5fe8..8999a45f0 100644 --- a/elfi/methods/bo/gpy_regression.py +++ b/elfi/methods/bo/gpy_regression.py @@ -338,6 +338,16 @@ def Y(self): """Return output evidence.""" return self._gp.Y + @property + def noise(self): + """Return the noise.""" + return self._gp.Gaussian_noise.variance[0] + + @property + def instance(self): + """Return the gp instance.""" + return self._gp + def copy(self): """Return a copy of current instance.""" kopy = copy.copy(self) diff --git a/elfi/methods/parameter_inference.py b/elfi/methods/parameter_inference.py index 8c70a233d..a27b8679d 100644 --- a/elfi/methods/parameter_inference.py +++ b/elfi/methods/parameter_inference.py @@ -3,9 +3,9 @@ __all__ = ['Rejection', 'SMC', 'BayesianOptimization', 'BOLFI'] import logging +from collections import OrderedDict from math import ceil -import matplotlib.pyplot as plt import numpy as np import elfi.client @@ -89,7 +89,6 @@ def __init__(self, model = model.model if isinstance(model, NodeReference) else model if not model.parameter_names: raise ValueError('Model {} defines no parameters'.format(model)) - self.model = model.copy() self.output_names = self._check_outputs(output_names) @@ -161,7 +160,7 @@ def extract_result(self): """ raise NotImplementedError - def update(self, batch, batch_index): + def update(self, batch, batch_index, vis=None): """Update the inference state with a new batch. ELFI calls this method when a new batch has been computed and the state of @@ -174,10 +173,8 @@ def update(self, batch, batch_index): dict with `self.outputs` as keys and the corresponding outputs for the batch as values batch_index : int - - Returns - ------- - None + vis : bool, optional + Interactive visualisation of the iterations. """ self.state['n_batches'] += 1 @@ -231,7 +228,7 @@ def plot_state(self, **kwargs): """ raise NotImplementedError - def infer(self, *args, vis=None, **kwargs): + def infer(self, *args, **opts): """Set the objective and start the iterate loop until the inference is finished. See the other arguments from the `set_objective` method. @@ -241,23 +238,16 @@ def infer(self, *args, vis=None, **kwargs): result : Sample """ - vis_opt = vis if isinstance(vis, dict) else {} - - self.set_objective(*args, **kwargs) - + vis = opts.pop('vis', None) + self.set_objective(*args, **opts) while not self.finished: - self.iterate() - if vis: - self.plot_state(interactive=True, **vis_opt) - + self.iterate(vis=vis) self.batches.cancel_pending() - if vis: - self.plot_state(close=True, **vis_opt) return self.extract_result() - def iterate(self): - """Advance the inference by one iteration. + def iterate(self, vis=None): + """Forward the inference one iteration. This is a way to manually progress the inference. One iteration consists of waiting and processing the result of the next batch in succession and possibly @@ -272,6 +262,11 @@ def iterate(self): will never be more batches submitted in parallel than the `max_parallel_batches` setting allows. + Parameters + ---------- + vis : bool, optional + Interactive visualisation of the iterations. + Returns ------- None @@ -286,7 +281,7 @@ def iterate(self): # Handle the next ready batch in succession batch, batch_index = self.batches.wait_next() logger.debug('Received batch %d' % batch_index) - self.update(batch, batch_index) + self.update(batch, batch_index, vis=vis) @property def finished(self): @@ -466,17 +461,21 @@ def set_objective(self, n_samples, threshold=None, quantile=None, n_sim=None): # Reset the inference self.batches.reset() - def update(self, batch, batch_index): + def update(self, batch, batch_index, vis=None): """Update the inference state with a new batch. Parameters ---------- batch : dict - dict with `self.outputs` as keys and the corresponding outputs for the batch - as values + dict with `self.outputs` as keys and the corresponding outputs for the batch as values + vis : bool, optional + Interactive visualisation of the iterations. batch_index : int """ + if vis and self.state['samples'] is not None: + self.plot_state(interactive=True, **vis) + super(Rejection, self).update(batch, batch_index) if self.state['samples'] is None: # Lazy initialization of the outputs dict @@ -584,8 +583,8 @@ def plot_state(self, **options): displays = [] if options.get('interactive'): from IPython import display - displays.append( - display.HTML('Threshold: {}'.format(self.state['threshold']))) + html_display = 'Threshold: {}'.format(self.state['threshold']) + displays.append(display.HTML(html_display)) visin.plot_sample( self.state['samples'], @@ -651,14 +650,15 @@ def extract_result(self): threshold=pop.threshold, **self._extract_result_kwargs()) - def update(self, batch, batch_index): + def update(self, batch, batch_index, vis=None): """Update the inference state with a new batch. Parameters ---------- batch : dict - dict with `self.outputs` as keys and the corresponding outputs for the batch - as values + dict with `self.outputs` as keys and the corresponding outputs for the batch as values + vis : bool, optional + Interactive visualisation of the iterations. batch_index : int """ @@ -833,7 +833,6 @@ def __init__(self, output_names = [target_name] + model.parameter_names super(BayesianOptimization, self).__init__( model, output_names, batch_size=batch_size, **kwargs) - target_model = target_model or \ GPyRegression(self.model.parameter_names, bounds=bounds) @@ -942,7 +941,7 @@ def extract_result(self): return OptimizationResult( x_min=batch_min, outputs=outputs, **self._extract_result_kwargs()) - def update(self, batch, batch_index): + def update(self, batch, batch_index, vis=None): """Update the GP regression model of the target node with a new batch. Parameters @@ -950,6 +949,8 @@ def update(self, batch, batch_index): batch : dict dict with `self.outputs` as keys and the corresponding outputs for the batch as values + vis : bool, optional + Interactive visualisation of the iterations. batch_index : int """ @@ -958,12 +959,21 @@ def update(self, batch, batch_index): params = batch_to_arr2d(batch, self.parameter_names) self._report_batch(batch_index, params, batch[self.target_name]) + # Adding the acquisition plots. + if vis and self.batches.next_index * self.batch_size > self.n_initial_evidence: + opts = {} + opts['point_acq'] = {'x': params, 'd': batch[self.target_name]} + arr_ax = self.plot_state(interactive=True, **opts) optimize = self._should_optimize() self.target_model.update(params, batch[self.target_name], optimize) if optimize: self.state['last_GP_update'] = self.target_model.n_evidence + # Adding the updated gp plots. + if vis and self.batches.next_index * self.batch_size > self.n_initial_evidence: + self.plot_state(interactive=True, arr_ax=arr_ax, **opts) + def prepare_new_batch(self, batch_index): """Prepare values for a new batch. @@ -980,7 +990,6 @@ def prepare_new_batch(self, batch_index): """ t = self._get_acquisition_index(batch_index) - # Check if we still should take initial points from the prior if t < 0: return @@ -1040,60 +1049,40 @@ def _report_batch(self, batch_index, params, distances): str += "{}{} at {}\n".format(fill, distances[i].item(), params[i]) logger.debug(str) - def plot_state(self, **options): - """Plot the GP surface. - - This feature is still experimental and currently supports only 2D cases. - """ - f = plt.gcf() - if len(f.axes) < 2: - f, _ = plt.subplots(1, 2, figsize=(13, 6), sharex='row', sharey='row') - - gp = self.target_model - - # Draw the GP surface - visin.draw_contour( - gp.predict_mean, - gp.bounds, - self.parameter_names, - title='GP target surface', - points=gp.X, - axes=f.axes[0], - **options) - - # Draw the latest acquisitions - if options.get('interactive'): - point = gp.X[-1, :] - if len(gp.X) > 1: - f.axes[1].scatter(*point, color='red') + def plot_state(self, plot_acq_pairwise=False, arr_ax=None, **opts): + """Plot the GP surface and the acquisition space. - displays = [gp._gp] + Notes + ----- + - The plots of the GP surface and the acquisition space work for the + cases when dim < 3; + - The method is experimental. - if options.get('interactive'): - from IPython import display - displays.insert( - 0, - display.HTML('Iteration {}: Acquired {} at {}'.format( - len(gp.Y), gp.Y[-1][0], point))) - - # Update - visin._update_interactive(displays, options) - - def acq(x): - return self.acquisition_method.evaluate(x, len(gp.X)) - - # Draw the acquisition surface - visin.draw_contour( - acq, - gp.bounds, - self.parameter_names, - title='Acquisition surface', - points=None, - axes=f.axes[1], - **options) + Parameters + ---------- + plot_acq_pairwise : bool, optional + The option to plot the pair-wise acquisition point relationships. - if options.get('close'): - plt.close() + """ + if plot_acq_pairwise: + if len(self.parameter_names) == 1: + logger.info('Can not plot the pair-wise comparison for 1 parameter.') + return + # Transform the acquisition points in the acceptable format. + dict_pts_acq = OrderedDict() + for idx_param, name_param in enumerate(self.parameter_names): + dict_pts_acq[name_param] = self.target_model.X[:, idx_param] + vis.plot_pairs(dict_pts_acq, **opts) + else: + if len(self.parameter_names) == 1: + arr_ax = vis.plot_state_1d(self, arr_ax, **opts) + return arr_ax + elif len(self.parameter_names) == 2: + arr_ax = vis.plot_state_2d(self, arr_ax, **opts) + return arr_ax + else: + logger.info('The method is supported for 1- or 2-dimensions.') + return def plot_discrepancy(self, axes=None, **kwargs): """Plot acquired parameters vs. resulting discrepancy. @@ -1133,7 +1122,7 @@ class BOLFI(BayesianOptimization): """ - def fit(self, n_evidence, threshold=None): + def fit(self, n_evidence, threshold=None, **opts): """Fit the surrogate model. Generates a regression model for the discrepancy given the parameters. @@ -1150,9 +1139,8 @@ def fit(self, n_evidence, threshold=None): if n_evidence is None: raise ValueError( - 'You must specify the number of evidence (n_evidence) for the fitting') - - self.infer(n_evidence) + 'You must specify the number of evidence( n_evidence) for the fitting') + self.infer(n_evidence, **opts) return self.extract_posterior(threshold) def extract_posterior(self, threshold=None): @@ -1235,12 +1223,10 @@ def sample(self, else: inds = np.argsort(self.target_model.Y[:, 0]) initials = np.asarray(self.target_model.X[inds]) - self.target_model.is_sampling = True # enables caching for default RBF kernel tasks_ids = [] ii_initial = 0 - # sampling is embarrassingly parallel, so depending on self.client this may parallelize for ii in range(n_chains): seed = get_sub_seed(self.seed, ii) @@ -1270,12 +1256,12 @@ def sample(self, chains = np.asarray(chains) - print( - "{} chains of {} iterations acquired. Effective sample size and Rhat for each " - "parameter:".format(n_chains, n_samples)) + logger.info( + "%d chains of %d iterations acquired. Effective sample size and Rhat for each " + "parameter:" % (n_chains, n_samples)) for ii, node in enumerate(self.parameter_names): - print(node, mcmc.eff_sample_size(chains[:, :, ii]), - mcmc.gelman_rubin(chains[:, :, ii])) + chain = chains[:, :, ii] + logger.info("%s %d %d" % (node, mcmc.eff_sample_size(chain), mcmc.gelman_rubin(chain))) self.target_model.is_sampling = False diff --git a/elfi/visualization/interactive.py b/elfi/visualization/interactive.py index 2b28df105..afd599f72 100644 --- a/elfi/visualization/interactive.py +++ b/elfi/visualization/interactive.py @@ -4,14 +4,17 @@ import matplotlib.pyplot as plt import numpy as np +import scipy.interpolate logger = logging.getLogger(__name__) def plot_sample(samples, nodes=None, n=-1, displays=None, **options): - """Plot a scatterplot of samples. + """Plot a scatter-plot of samples. - Experimental, only dims 1-2 supported. + Notes + ----- + - Experimental, only dims 1-2 supported. Parameters ---------- @@ -23,7 +26,8 @@ def plot_sample(samples, nodes=None, n=-1, displays=None, **options): """ axes = _prepare_axes(options) - + if samples is None: + return nodes = nodes or sorted(samples.keys())[:2] if isinstance(nodes, str): nodes = [nodes] @@ -39,9 +43,8 @@ def plot_sample(samples, nodes=None, n=-1, displays=None, **options): axes.set_ylabel(nodes[1]) axes.scatter(samples[nodes[0]][:n], samples[nodes[1]][:n]) - _update_interactive(displays, options) - - if options.get('close'): + if options.get('interactive'): + update_interactive(displays, options) plt.close() @@ -52,7 +55,8 @@ def get_axes(**options): return plt.gca() -def _update_interactive(displays, options): +def update_interactive(displays, options): + """Update the interactive plot.""" displays = displays or [] if options.get('interactive'): from IPython import display @@ -67,7 +71,6 @@ def _prepare_axes(options): if ion: axes.clear() - if options.get('xlim'): axes.set_xlim(options.get('xlim')) if options.get('ylim'): @@ -76,7 +79,7 @@ def _prepare_axes(options): return axes -def draw_contour(fn, bounds, nodes=None, points=None, title=None, **options): +def draw_contour(fn, bounds, params=None, points=None, title=None, label=None, **options): """Plot a contour of a function. Experimental, only 2D supported. @@ -92,29 +95,34 @@ def draw_contour(fn, bounds, nodes=None, points=None, title=None, **options): title : str, optional """ - ax = get_axes(**options) - + # Preparing the contour plot settings. + if options.get('axes'): + axes = options['axes'] + plt.sca(axes) x, y = np.meshgrid(np.linspace(*bounds[0]), np.linspace(*bounds[1])) z = fn(np.c_[x.reshape(-1), y.reshape(-1)]) - if ax: - plt.sca(ax) - plt.cla() + # Plotting the contour. + CS = plt.contourf(x, y, z.reshape(x.shape), 25) + CB = plt.colorbar(CS, orientation='horizontal', format='%.1e') + CB.set_label(label) + rbf = scipy.interpolate.Rbf(x, y, z, function='linear') + zi = rbf(x, y) + plt.imshow(zi, + vmin=z.min(), + vmax=z.max(), + origin='lower', + extent=[x.min(), x.max(), y.min(), y.max()]) + + # Adding the acquisition points. + if points is not None: + plt.scatter(points[:, 0], points[:, 1], color='k') + # Adding the labels. if title: plt.title(title) - try: - plt.contour(x, y, z.reshape(x.shape)) - except ValueError: - logger.warning('Could not draw a contour plot') - if points is not None: - plt.scatter(points[:-1, 0], points[:-1, 1]) - if options.get('interactive'): - plt.scatter(points[-1, 0], points[-1, 1], color='r') - plt.xlim(bounds[0]) plt.ylim(bounds[1]) - - if nodes: - plt.xlabel(nodes[0]) - plt.ylabel(nodes[1]) + if params: + plt.xlabel(params[0]) + plt.ylabel(params[1]) diff --git a/elfi/visualization/visualization.py b/elfi/visualization/visualization.py index d3736e5a4..1b9123619 100644 --- a/elfi/visualization/visualization.py +++ b/elfi/visualization/visualization.py @@ -5,6 +5,7 @@ import matplotlib.pyplot as plt import numpy as np +import elfi.visualization.interactive as visin from elfi.model.elfi_model import Constant, ElfiModel, NodeReference @@ -99,6 +100,7 @@ def _create_axes(axes, shape, **kwargs): else: fig, axes = plt.subplots(ncols=shape[1], nrows=shape[0], **fig_kwargs) axes = np.atleast_1d(axes) + fig.tight_layout(pad=2.0) return axes, kwargs @@ -156,12 +158,157 @@ def plot_marginals(samples, selector=None, bins=20, axes=None, **kwargs): return axes -def plot_pairs(samples, selector=None, bins=20, axes=None, **kwargs): - """Plot pairwise relationships as a matrix with marginals on the diagonal. +def plot_state_1d(model_bo, arr_ax=None, **options): + """Plot the GP surface and the acquisition function in 1D. - The y-axis of marginal histograms are scaled. + Notes + ----- + The method is experimental. + + Parameters + ---------- + model_bo : elfi.methods.parameter_inference.BOLFI + + """ + gp = model_bo.target_model + pts_eval = np.linspace(*gp.bounds[0]) + + if arr_ax is None: + fig, arr_ax = plt.subplots(nrows=1, + ncols=2, + figsize=(12, 4), + sharex=True) + plt.ticklabel_format(style='sci', axis='y', scilimits=(-3, 4)) + fig.tight_layout(pad=2.0) + + # Plotting the acquisition space and the recent acquisition. + arr_ax[1].set_title('Acquisition surface') + arr_ax[1].set_xlabel(model_bo.parameter_names[0]) + arr_ax[1].set_ylabel('Acquisition score') + score_acq = model_bo.acquisition_method.evaluate(pts_eval) + arr_ax[1].plot(pts_eval, + score_acq, + color='k', + label='acquisition function') + # Plotting the confidence interval and the mean. + mean, var = gp.predict(pts_eval, noiseless=False) + sigma = np.sqrt(var) + z_95 = 1.96 + lb_ci = mean - z_95 * (sigma) + ub_ci = mean + z_95 * (sigma) + arr_ax[0].fill(np.concatenate([pts_eval, pts_eval[::-1]]), + np.concatenate([lb_ci, ub_ci[::-1]]), + alpha=.1, + fc='k', + ec='None', + label='95% confidence interval') + arr_ax[0].plot(pts_eval, mean, color='k', label='mean') + # Plotting the acquisition threshold. + if model_bo.acquisition_method.name in ['max_var', 'rand_max_var', 'exp_int_var']: + thresh_acq = np.repeat(model_bo.acquisition_method.eps, + len(pts_eval)) + arr_ax[0].plot(pts_eval, + thresh_acq, + color='g', + label='acquisition threshold') + # Plotting the acquired points. + arr_ax[0].scatter(gp.X, gp.Y, color='k') + + arr_ax[0].legend(loc='upper right') + arr_ax[0].set_title('GP target surface') + arr_ax[0].set_xlabel(model_bo.parameter_names[0]) + arr_ax[0].set_ylabel('Discrepancy') + + return arr_ax + else: + if options.get('interactive'): + from IPython import display + pt_last = options.pop('point_acq') + arr_ax[0].scatter(pt_last['x'], pt_last['d'], color='r') + ymin, ymax = arr_ax[1].get_ylim() + arr_ax[1].vlines(x=pt_last['x'], ymin=ymin, ymax=ymax, + color='r', linestyle='--', + label='latest acquisition') + + arr_ax[1].legend(loc='upper right') + displays = [] + displays.append(gp.instance) + n_it = int(len(gp.Y) / model_bo.batch_size) + html_disp = 'Iteration {}: Acquired {} at {}' \ + .format(n_it, pt_last['d'], pt_last['x']) + displays.append(display.HTML(html_disp)) + visin.update_interactive(displays, options=options) + + plt.close() + + +def plot_state_2d(model_bo, arr_ax=None, pre=False, post=False, **options): + """Plot the GP surface and the acquisition function in 2D. + + Notes + ----- + The method is experimental. + + Parameters + ---------- + model_bo : elfi.methods.parameter_inference.BOLFI - Parameters + """ + gp = model_bo.target_model + + if arr_ax is None: + # Defining the plotting settings. + _, arr_ax = plt.subplots(nrows=1, + ncols=2, + figsize=(16, 10), + sharex='row', + sharey='row') + + # Plotting the acquisition space and the recent acquisition. + def fn_acq(x): + return model_bo.acquisition_method.evaluate(x, len(gp.X)) + visin.draw_contour(fn_acq, + gp.bounds, + model_bo.parameter_names, + title='Acquisition surface', + axes=arr_ax[1], + label='Acquisition score', + **options) + # Plotting the GP target surface and the acquired points. + visin.draw_contour(gp.predict_mean, + gp.bounds, + model_bo.parameter_names, + title='GP target surface', + points=gp.X, + axes=arr_ax[0], + label='Discrepancy', + **options) + return arr_ax + else: + if options.get('interactive'): + from IPython import display + pt_last = options.pop('point_acq') + arr_ax[0].scatter(pt_last['x'][:, 0], pt_last['x'][:, 1], color='r') + arr_ax[1].scatter(pt_last['x'][:, 0], pt_last['x'][:, 1], color='r') + + displays = [] + displays.append(gp.instance) + n_it = int(len(gp.Y) / model_bo.batch_size) + html_disp = 'Iteration {}: Acquired {} at {}' \ + .format(n_it, pt_last['d'], pt_last['x']) + displays.append(display.HTML(html_disp)) + visin.update_interactive(displays, options=options) + plt.close() + + +def plot_pairs(data, selector=None, bins=20, axes=None, **kwargs): + """Plot pair-wise relationships in a grid with marginals on the diagonal. + + Notes + ----- + Removed: The y-axis of marginal histograms are scaled. + + Parameters ---------- samples : OrderedDict of np.arrays selector : iterable of ints or strings, optional @@ -175,31 +322,38 @@ def plot_pairs(samples, selector=None, bins=20, axes=None, **kwargs): axes : np.array of plt.Axes """ - samples = _limit_params(samples, selector) - shape = (len(samples), len(samples)) + # Pop the target kwargs. edgecolor = kwargs.pop('edgecolor', 'none') - dot_size = kwargs.pop('s', 2) - kwargs['sharex'] = kwargs.get('sharex', 'col') - kwargs['sharey'] = kwargs.get('sharey', 'row') - axes, kwargs = _create_axes(axes, shape, **kwargs) - - for i1, k1 in enumerate(samples): - min_samples = samples[k1].min() - max_samples = samples[k1].max() - for i2, k2 in enumerate(samples): - if i1 == i2: - # create a histogram with scaled y-axis - hist, bin_edges = np.histogram(samples[k1], bins=bins) - bar_width = bin_edges[1] - bin_edges[0] - hist = (hist - hist.min()) * (max_samples - min_samples) / ( - hist.max() - hist.min()) - axes[i1, i2].bar(bin_edges[:-1], hist, bar_width, bottom=min_samples, **kwargs) + dot_size = kwargs.pop('s', 25) + + # Filter the data. + data_selected = _limit_params(data, selector) + + # Initialise the figure. + shape_fig = (len(data_selected), len(data_selected)) + axes, kwargs = _create_axes(axes, shape_fig, **kwargs) + + # Populate the grid of figures. + for idx_row, key_row in enumerate(data_selected): + for idx_col, key_col in enumerate(data_selected): + if idx_row == idx_col: + # Plot the marginals. + axes[idx_row, idx_col].hist(data_selected[key_row], bins=bins, **kwargs) + axes[idx_row, idx_col].set_xlabel(key_row) + # Experimental: Calculate the bin length. + x_min, x_max = axes[idx_row, idx_col].get_xlim() + length_bin = (x_max - x_min) / bins + axes[idx_row, idx_col].set_ylabel( + 'Count (bin length: {0:.2f})'.format(length_bin)) else: - axes[i1, i2].scatter( - samples[k2], samples[k1], s=dot_size, edgecolor=edgecolor, **kwargs) - - axes[i1, 0].set_ylabel(k1) - axes[-1, i1].set_xlabel(k1) + # Plot the pairs. + axes[idx_row, idx_col].scatter(data_selected[key_row], + data_selected[key_col], + edgecolor=edgecolor, + s=dot_size, + **kwargs) + axes[idx_row, idx_col].set_xlabel(key_row) + axes[idx_row, idx_col].set_ylabel(key_col) return axes diff --git a/tests/unit/test_document_examples.py b/tests/unit/test_document_examples.py index f873bb04b..7fabef3af 100644 --- a/tests/unit/test_document_examples.py +++ b/tests/unit/test_document_examples.py @@ -31,7 +31,7 @@ def __init__(self, model, discrepancy_name, threshold, **kwargs): def set_objective(self, n_sim): self.objective['n_sim'] = n_sim - def update(self, batch, batch_index): + def update(self, batch, batch_index, vis=None): super(CustomMethod, self).update(batch, batch_index) # Make a filter mask (logical numpy array) from the distance array