diff --git a/scgv/views/clone.py b/scgv/views/clone.py index f43900a..5f021f1 100644 --- a/scgv/views/clone.py +++ b/scgv/views/clone.py @@ -3,26 +3,48 @@ @author: lubo ''' +import numpy as np + +import matplotlib.pyplot as plt +from matplotlib.colors import ListedColormap from scgv.views.base import ViewerBase -from scgv.utils.color_map import ColorMap class CloneViewer(ViewerBase): def __init__(self, model): super(CloneViewer, self).__init__(model) - self.cmap = ColorMap.make_qualitative06_with_white() + self._select_colormap() + + def _select_colormap_size(self, size): + assert size <= 20 + if size > 12: + cmap = plt.get_cmap('tab20') + elif size > 7: + cmap = plt.get_cmap('Paired') + else: + cmap = plt.get_cmap('tab10') + colors = ['#FFFFFF'] + colors.extend(cmap.colors[:size]) + return ListedColormap(colors) + + def _select_colormap(self): + self.vmax = np.max(self.model.subclone) + size = int(self.vmax) + self.cmap = self._select_colormap_size(size) def draw_clone(self, ax): + print(self.model.clone) + if self.model.clone is not None: ax.imshow( [self.model.clone], aspect='auto', interpolation='nearest', - cmap=self.cmap.colors, + cmap=self.cmap, # norm=self.cmap.norm, vmin=0, - vmax=6, + vmax=self.vmax, extent=self.model.bar_extent) ax.set_xticks([]) @@ -31,15 +53,15 @@ def draw_clone(self, ax): ax.set_yticklabels(["Clone"]) def draw_subclone(self, ax): + print(self.model.subclone) if self.model.subclone is not None: ax.imshow( [self.model.subclone], aspect='auto', interpolation='nearest', - cmap=self.cmap.colors, - # norm=self.cmap.norm, + cmap=self.cmap, vmin=0, - vmax=6, + vmax=self.vmax, extent=self.model.bar_extent) ax.set_xticks([]) diff --git a/scgv/views/track.py b/scgv/views/track.py index ad6e226..e51029a 100644 --- a/scgv/views/track.py +++ b/scgv/views/track.py @@ -10,7 +10,10 @@ class TrackViewer(ViewerBase): @classmethod def select_colormap(self, track_mapping): size = len(track_mapping) + return TrackViewer.select_colormap_size(size) + @classmethod + def select_colormap_size(self, size): assert size <= 20 if size > 12: cmap = plt.get_cmap('tab20')