Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Add trace2d plot #108

Merged
merged 15 commits into from
Sep 29, 2017
3 changes: 3 additions & 0 deletions docs/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ New Features
- ``plot_free_energy`` now accepts a ``return_data`` flag that will return
the data used for the free energy plot(#78).

- Added a new function ``plot_trace2d`` that plots the time evolution of a 2D numpy array
using a colorbar to map the time (#108).

Improvements
~~~~~~~~~~~~

Expand Down
50 changes: 50 additions & 0 deletions examples/plot_trace2d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
"""
Two dimensional trace plot
===============
"""
from msmbuilder.example_datasets import FsPeptide
from msmbuilder.featurizer import DihedralFeaturizer
from msmbuilder.decomposition import tICA
from msmbuilder.cluster import MiniBatchKMeans
from msmbuilder.msm import MarkovStateModel
from matplotlib import pyplot as pp
import numpy as np

import msmexplorer as msme

rs = np.random.RandomState(42)

# Load Fs Peptide Data
trajs = FsPeptide().get().trajectories

# Extract Backbone Dihedrals
featurizer = DihedralFeaturizer(types=['phi', 'psi'])
diheds = featurizer.fit_transform(trajs)

# Perform Dimensionality Reduction
tica_model = tICA(lag_time=2, n_components=2)
tica_trajs = tica_model.fit_transform(diheds)

# Plot free 2D free energy (optional)
txx = np.concatenate(tica_trajs, axis=0)
ax = msme.plot_free_energy(
txx, obs=(0, 1), n_samples=100000,
random_state=rs,
shade=True,
clabel=True,
clabel_kwargs={'fmt': '%.1f'},
cbar=True,
cbar_kwargs={'format': '%.1f', 'label': 'Free energy (kcal/mol)'}
)
# Now plot the first trajectory on top of it to inspect it's movement
msme.plot_trace2d(
data=tica_trajs[0], ts=0.2, ax=ax,
scatter_kwargs={'s': 2},
cbar_kwargs={'format': '%d', 'label': 'Time (ns)',
'orientation': 'horizontal'},
xlabel='tIC 1', ylabel='tIC 2'
)
# Finally, let's plot every trajectory to see the individual sampled regions
f, ax = pp.subplots()
msme.plot_trace2d(tica_trajs, ax=ax, xlabel='tIC 1', ylabel='tIC 2')
pp.show()
79 changes: 78 additions & 1 deletion msmexplorer/plots/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from ..utils import msme_colors
from .. import palettes

__all__ = ['plot_chord', 'plot_stackdist', 'plot_trace']
__all__ = ['plot_chord', 'plot_stackdist', 'plot_trace', 'plot_trace2d']


def plot_chord(data, ax=None, cmap=None, labels=None, labelsize=12, norm=True,
Expand Down Expand Up @@ -303,3 +303,80 @@ def plot_trace(data, label=None, window=1, ax=None, side_ax=None,
side_ax.set_title('')

return ax, side_ax


@msme_colors
def plot_trace2d(data, obs=(0, 1), ts=1.0, cbar=True, ax=None, xlabel=None,
ylabel=None, labelsize=14,
cbar_kwargs=None, scatter_kwargs=None, plot_kwargs=None):
"""
Plot a 2D trace of time-series data.

Parameters
----------
data : array-like (nsamples, 2) or list thereof
The samples. This should be a single 2-D time-series array or a list of 2-D
time-series arrays.
If it is a single 2D np.array, the elements will be scatter plotted and
color mapped to their values.
If it is a list of 2D np.arrays, each will be plotted with a single color on
the same axis.
obs: tuple, optional (default: (0,1))
Observables to plot.
ts: float, optional (default: 1.0)
Step in units of time between each data point in data
cbar: bool, optional (default: True)
Adds a colorbar that maps the evolution of points in data
ax : matplotlib axis, optional
main matplotlib figure axis for trace.
xlabel : str, optional
x-axis label
ylabel : str, optional
y-axis label
labelsize : int, optional (default: 14)
Font side for axes labels.
cbar_kwargs: dict, optional
Arguments to pass to matplotlib cbar
scatter_kwargs: dict, optional
Arguments to pass to matplotlib scatter
plot_kwargs: dict, optional
Arguments to pass to matplotlib plot
Returns
-------
ax : matplotlib axis
main matplotlib figure axis for 2D trace.
"""

if ax is None:
ax = pp.gca()
if scatter_kwargs is None:
scatter_kwargs = {}
if plot_kwargs is None:
plot_kwargs = {}

if not isinstance(obs, tuple):
raise ValueError('obs must be a tuple')

if isinstance(data, list):
# Plot each item in the list with a single color and join with lines
for item in data:
prune = item[:, obs]
ax.plot(prune[:, 0], prune[:, 1], **plot_kwargs)
else:
# A single array of data is passed, so we scatter plot
prune = data[:, obs]
c = ax.scatter(prune[:, 0], prune[:, 1],
c=np.linspace(0, data.shape[0] * ts, data.shape[0]),
**scatter_kwargs)
if cbar:
# Map the time evolution between the data points to a colorbar
if cbar_kwargs is None:
cbar_kwargs = {}
pp.colorbar(c, **cbar_kwargs)

if xlabel:
ax.set_xlabel(xlabel, size=labelsize)
if ylabel:
ax.set_ylabel(ylabel, size=labelsize)

return ax
10 changes: 9 additions & 1 deletion msmexplorer/tests/test_misc_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@
from matplotlib.axes import SubplotBase
from seaborn.apionly import FacetGrid

from ..plots import plot_chord, plot_stackdist, plot_trace
from ..plots import plot_chord, plot_stackdist, plot_trace, plot_trace2d
from . import PlotTestCase

rs = np.random.RandomState(42)
data = rs.rand(12, 12)
ts = rs.rand(100000, 1)
ts2 = rs.rand(100000, 2)


class TestChordPlot(PlotTestCase):
Expand Down Expand Up @@ -38,3 +39,10 @@ def test_plot_trace(self):

assert isinstance(ax, SubplotBase)
assert isinstance(side_ax, SubplotBase)

def test_plot_trace2d(self):
ax1 = plot_trace2d(ts2)
ax2 = plot_trace2d([ts2, ts2])

assert isinstance(ax1, SubplotBase)
assert isinstance(ax2, SubplotBase)