diff --git a/libs/models/unsupervised_kernel_regression.py b/libs/models/unsupervised_kernel_regression.py index 2df2689..8a13333 100644 --- a/libs/models/unsupervised_kernel_regression.py +++ b/libs/models/unsupervised_kernel_regression.py @@ -176,7 +176,7 @@ def inverse_transform(self, Znew): def visualize(self, n_grid_points=30, cmap=None, label_data=None, label_feature=None, title_latent_space=None, title_feature_bars=None, is_show_all_label_data=False, - fig=None, fig_size=None, ax_latent_space=None, ax_feature_bars=None): + interpolation=None, fig=None, fig_size=None, ax_latent_space=None, ax_feature_bars=None): """Visualize fit model interactively. The dataset can be visualized in an exploratory way using the latent variables and the mapping estimated by UKR. When an arbitrary coordinate on the latent space is specified, the corresponding feature is displayed as a bar. @@ -199,6 +199,8 @@ def visualize(self, n_grid_points=30, cmap=None, label_data=None, label_feature= :param is_show_all_label_data: bool, optional, default = False When True the labels of the data is always shown. When False the label is only shown when the cursor overlaps the corresponding latent variable. + :param interpolation: str, optional, default = None + Interpolation method by imshow. :param fig: matplotlib.figure.Figure, default = True The figure to visualize. It is assigned only when you want to specify a figure to visualize. @@ -218,7 +220,7 @@ def visualize(self, n_grid_points=30, cmap=None, label_data=None, label_feature= self._initialize_to_visualize(n_grid_points, cmap, label_data, label_feature, title_latent_space, title_feature_bars, is_show_all_label_data, - fig, fig_size, ax_latent_space, ax_feature_bars) + interpolation, fig, fig_size, ax_latent_space, ax_feature_bars) self._draw_latent_space() self._draw_feature_bars() @@ -259,7 +261,7 @@ def __mouse_over_fig(self, event): def _initialize_to_visualize(self, n_grid_points, cmap, label_data, label_feature, title_latent_space, title_feature_bars, is_show_all_label_data, - fig, fig_size, ax_latent_space, ax_feature_bars): + interpolation, fig, fig_size, ax_latent_space, ax_feature_bars): # invalid check if self.n_components != 2: raise ValueError('Now support only n_components = 2') @@ -328,6 +330,7 @@ def _initialize_to_visualize(self, n_grid_points, cmap, label_data, label_featur self.ax_feature_bars = ax_feature_bars self.cmap = cmap + self.interpolation = interpolation self.click_point_latent_space = None # index of the clicked representative point self.clicked_mapping = self.X.mean(axis=0) self.is_initial_view = True @@ -388,10 +391,30 @@ def _draw_latent_space(self): # To draw by pcolormesh and contour, reshape arrays like grid grid_values_to_draw_3d = self.__unflatten_grid_array(self.grid_values_to_draw) grid_points_3d = self.__unflatten_grid_array(self.grid_points) - self.ax_latent_space.pcolormesh(grid_points_3d[:, :, 0], - grid_points_3d[:, :, 1], - grid_values_to_draw_3d, - cmap=self.cmap) + # set coordinate of axis + any_index = 0 + if grid_points_3d[any_index, 0, 0] < grid_points_3d[any_index, -1, 0]: + coordinate_ax_left = grid_points_3d[any_index, 0, 0] + coordinate_ax_right = grid_points_3d[any_index, -1, 0] + else: + coordinate_ax_left = grid_points_3d[any_index, -1, 0] + coordinate_ax_right = grid_points_3d[any_index, 0, 0] + grid_values_to_draw_3d = np.flip(grid_values_to_draw_3d, axis=1).copy() + + if grid_points_3d[-1, any_index, 1] < grid_points_3d[0, any_index, 1]: + coordinate_ax_bottom = grid_points_3d[-1, any_index, 1] + coordinate_ax_top = grid_points_3d[0, any_index, 1] + else: + coordinate_ax_bottom = grid_points_3d[0, any_index, 1] + coordinate_ax_top = grid_points_3d[-1, any_index, 1] + grid_values_to_draw_3d = np.flip(grid_values_to_draw_3d, axis=0).copy() + self.ax_latent_space.imshow(grid_values_to_draw_3d, + extent=[coordinate_ax_left, + coordinate_ax_right, + coordinate_ax_bottom, + coordinate_ax_top], + interpolation=self.interpolation, + cmap=self.cmap) ctr = self.ax_latent_space.contour(grid_points_3d[:, :, 0], grid_points_3d[:, :, 1], grid_values_to_draw_3d, 6, colors='k')