diff --git a/docs/changelog.rst b/docs/changelog.rst index 43523d0..8c64f9b 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -18,6 +18,9 @@ New Features - ``plot_free_energy`` now accepts a ``return_data`` flag that will return the data used for the free energy plot(#78). +- Added a new function ``plot_trace2d`` that plots the time evolution of a 2D numpy array + using a colorbar to map the time (#108). + Improvements ~~~~~~~~~~~~ diff --git a/examples/plot_trace2d.py b/examples/plot_trace2d.py new file mode 100644 index 0000000..31b3234 --- /dev/null +++ b/examples/plot_trace2d.py @@ -0,0 +1,50 @@ +""" +Two dimensional trace plot +=============== +""" +from msmbuilder.example_datasets import FsPeptide +from msmbuilder.featurizer import DihedralFeaturizer +from msmbuilder.decomposition import tICA +from msmbuilder.cluster import MiniBatchKMeans +from msmbuilder.msm import MarkovStateModel +from matplotlib import pyplot as pp +import numpy as np + +import msmexplorer as msme + +rs = np.random.RandomState(42) + +# Load Fs Peptide Data +trajs = FsPeptide().get().trajectories + +# Extract Backbone Dihedrals +featurizer = DihedralFeaturizer(types=['phi', 'psi']) +diheds = featurizer.fit_transform(trajs) + +# Perform Dimensionality Reduction +tica_model = tICA(lag_time=2, n_components=2) +tica_trajs = tica_model.fit_transform(diheds) + +# Plot free 2D free energy (optional) +txx = np.concatenate(tica_trajs, axis=0) +ax = msme.plot_free_energy( + txx, obs=(0, 1), n_samples=100000, + random_state=rs, + shade=True, + clabel=True, + clabel_kwargs={'fmt': '%.1f'}, + cbar=True, + cbar_kwargs={'format': '%.1f', 'label': 'Free energy (kcal/mol)'} +) +# Now plot the first trajectory on top of it to inspect it's movement +msme.plot_trace2d( + data=tica_trajs[0], ts=0.2, ax=ax, + scatter_kwargs={'s': 2}, + cbar_kwargs={'format': '%d', 'label': 'Time (ns)', + 'orientation': 'horizontal'}, + xlabel='tIC 1', ylabel='tIC 2' +) +# Finally, let's plot every trajectory to see the individual sampled regions +f, ax = pp.subplots() +msme.plot_trace2d(tica_trajs, ax=ax, xlabel='tIC 1', ylabel='tIC 2') +pp.show() diff --git a/msmexplorer/plots/misc.py b/msmexplorer/plots/misc.py index 4c81625..56a2dca 100644 --- a/msmexplorer/plots/misc.py +++ b/msmexplorer/plots/misc.py @@ -11,7 +11,7 @@ from ..utils import msme_colors from .. import palettes -__all__ = ['plot_chord', 'plot_stackdist', 'plot_trace'] +__all__ = ['plot_chord', 'plot_stackdist', 'plot_trace', 'plot_trace2d'] def plot_chord(data, ax=None, cmap=None, labels=None, labelsize=12, norm=True, @@ -303,3 +303,80 @@ def plot_trace(data, label=None, window=1, ax=None, side_ax=None, side_ax.set_title('') return ax, side_ax + + +@msme_colors +def plot_trace2d(data, obs=(0, 1), ts=1.0, cbar=True, ax=None, xlabel=None, + ylabel=None, labelsize=14, + cbar_kwargs=None, scatter_kwargs=None, plot_kwargs=None): + """ + Plot a 2D trace of time-series data. + + Parameters + ---------- + data : array-like (nsamples, 2) or list thereof + The samples. This should be a single 2-D time-series array or a list of 2-D + time-series arrays. + If it is a single 2D np.array, the elements will be scatter plotted and + color mapped to their values. + If it is a list of 2D np.arrays, each will be plotted with a single color on + the same axis. + obs: tuple, optional (default: (0,1)) + Observables to plot. + ts: float, optional (default: 1.0) + Step in units of time between each data point in data + cbar: bool, optional (default: True) + Adds a colorbar that maps the evolution of points in data + ax : matplotlib axis, optional + main matplotlib figure axis for trace. + xlabel : str, optional + x-axis label + ylabel : str, optional + y-axis label + labelsize : int, optional (default: 14) + Font side for axes labels. + cbar_kwargs: dict, optional + Arguments to pass to matplotlib cbar + scatter_kwargs: dict, optional + Arguments to pass to matplotlib scatter + plot_kwargs: dict, optional + Arguments to pass to matplotlib plot + Returns + ------- + ax : matplotlib axis + main matplotlib figure axis for 2D trace. + """ + + if ax is None: + ax = pp.gca() + if scatter_kwargs is None: + scatter_kwargs = {} + if plot_kwargs is None: + plot_kwargs = {} + + if not isinstance(obs, tuple): + raise ValueError('obs must be a tuple') + + if isinstance(data, list): + # Plot each item in the list with a single color and join with lines + for item in data: + prune = item[:, obs] + ax.plot(prune[:, 0], prune[:, 1], **plot_kwargs) + else: + # A single array of data is passed, so we scatter plot + prune = data[:, obs] + c = ax.scatter(prune[:, 0], prune[:, 1], + c=np.linspace(0, data.shape[0] * ts, data.shape[0]), + **scatter_kwargs) + if cbar: + # Map the time evolution between the data points to a colorbar + if cbar_kwargs is None: + cbar_kwargs = {} + pp.colorbar(c, **cbar_kwargs) + + if xlabel: + ax.set_xlabel(xlabel, size=labelsize) + if ylabel: + ax.set_ylabel(ylabel, size=labelsize) + + return ax diff --git a/msmexplorer/tests/test_misc_plot.py b/msmexplorer/tests/test_misc_plot.py index b7a3658..a781eaa 100644 --- a/msmexplorer/tests/test_misc_plot.py +++ b/msmexplorer/tests/test_misc_plot.py @@ -2,12 +2,13 @@ from matplotlib.axes import SubplotBase from seaborn.apionly import FacetGrid -from ..plots import plot_chord, plot_stackdist, plot_trace +from ..plots import plot_chord, plot_stackdist, plot_trace, plot_trace2d from . import PlotTestCase rs = np.random.RandomState(42) data = rs.rand(12, 12) ts = rs.rand(100000, 1) +ts2 = rs.rand(100000, 2) class TestChordPlot(PlotTestCase): @@ -38,3 +39,10 @@ def test_plot_trace(self): assert isinstance(ax, SubplotBase) assert isinstance(side_ax, SubplotBase) + + def test_plot_trace2d(self): + ax1 = plot_trace2d(ts2) + ax2 = plot_trace2d([ts2, ts2]) + + assert isinstance(ax1, SubplotBase) + assert isinstance(ax2, SubplotBase)