Skip to content

Commit

Permalink
Merge pull request #343 from neurodsp-tools/mlplts
Browse files Browse the repository at this point in the history
[ENH] - Plotting updates
  • Loading branch information
TomDonoghue authored Feb 4, 2025
2 parents dad9dae + 5470350 commit 1e7959a
Show file tree
Hide file tree
Showing 11 changed files with 305 additions and 52 deletions.
10 changes: 10 additions & 0 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -430,6 +430,7 @@ Spectral
:toctree: generated/

plot_power_spectra
plot_spectra_3d
plot_scv
plot_scv_rs_lines
plot_scv_rs_matrix
Expand Down Expand Up @@ -465,6 +466,15 @@ Time Frequency

plot_timefrequency

Aperiodic
~~~~~~~~~

.. currentmodule:: neurodsp.plts
.. autosummary::
:toctree: generated/

plot_autocorr

Combined
~~~~~~~~

Expand Down
3 changes: 2 additions & 1 deletion neurodsp/plts/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
plot_multi_time_series)
from .filt import plot_filter_properties, plot_frequency_response, plot_impulse_response
from .rhythm import plot_swm_pattern, plot_lagged_coherence
from .spectral import (plot_power_spectra, plot_spectral_hist,
from .spectral import (plot_power_spectra, plot_spectral_hist, plot_spectra_3d,
plot_scv, plot_scv_rs_lines, plot_scv_rs_matrix)
from .timefrequency import plot_timefrequency
from .aperiodic import plot_autocorr
from .combined import plot_timeseries_and_spectra
35 changes: 35 additions & 0 deletions neurodsp/plts/aperiodic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
"""Plotting functions for neurodsp.aperiodic."""

from neurodsp.plts.style import style_plot
from neurodsp.plts.utils import check_ax, savefig, prepare_multi_plot

####################################################################################################
####################################################################################################

@savefig
@style_plot
def plot_autocorr(timepoints, autocorrs, labels=None, colors=None, ax=None, **kwargs):
"""Plot autocorrelation results.
Parameters
----------
timepoints : 1d array
Time points, in samples, at which autocorrelations are computed.
autocorrs : array
Autocorrelation values, across time lags.
labels : str or list of str, optional
Labels for each time series.
colors : str or list of str
Colors to use to plot lines.
ax : matplotlib.Axes, optional
Figure axes upon which to plot.
**kwargs
Keyword arguments for customizing the plot.
"""

ax = check_ax(ax, figsize=kwargs.pop('figsize', (6, 5)))

for time, ac, label, color in zip(*prepare_multi_plot(timepoints, autocorrs, labels, colors)):
ax.plot(time, ac, label=label, color=color)

ax.set(xlabel='Lag (Samples)', ylabel='Autocorrelation')
88 changes: 73 additions & 15 deletions neurodsp/plts/spectral.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
"""Plotting functions for neurodsp.spectral."""

from itertools import repeat, cycle

import numpy as np
import matplotlib.pyplot as plt

from neurodsp.plts.style import style_plot
from neurodsp.plts.utils import check_ax, savefig
from neurodsp.plts.utils import check_ax, check_ax_3d, savefig, prepare_multi_plot

###################################################################################################
###################################################################################################
Expand Down Expand Up @@ -47,18 +45,8 @@ def plot_power_spectra(freqs, powers, labels=None, colors=None, ax=None, **kwarg

ax = check_ax(ax, figsize=kwargs.pop('figsize', (6, 6)))

freqs = repeat(freqs) if isinstance(freqs, np.ndarray) and freqs.ndim == 1 else freqs
powers = [powers] if isinstance(powers, np.ndarray) and powers.ndim == 1 else powers

if labels is not None:
labels = [labels] if not isinstance(labels, list) else labels
else:
labels = repeat(labels)

colors = repeat(colors) if not isinstance(colors, list) else cycle(colors)

for freq, power, color, label in zip(freqs, powers, colors, labels):
ax.loglog(freq, power, color=color, label=label)
for freq, power, label, color in zip(*prepare_multi_plot(freqs, powers, labels, colors)):
ax.loglog(freq, power, label=label, color=color)

ax.set_xlabel('Frequency (Hz)')
ax.set_ylabel('Power ($V^2/Hz$)')
Expand Down Expand Up @@ -235,3 +223,73 @@ def plot_spectral_hist(freqs, power_bins, spectral_hist, spectrum_freqs=None,
if spectrum is not None:
plt_inds = np.logical_and(spectrum_freqs >= freqs[0], spectrum_freqs <= freqs[-1])
ax.plot(spectrum_freqs[plt_inds], np.log10(spectrum[plt_inds]), color='w', alpha=0.8)


@savefig
@style_plot
def plot_spectra_3d(freqs, powers, log_freqs=False, log_powers=True, colors=None,
orientation=(20, -50), zoom=1.0, ax=None, **kwargs):
"""Plot a series of power spectra in a 3D plot.
Parameters
----------
freqs : 1d or 2d array or list of 1d array
Frequency vector.
powers : 2d array or list of 1d array
Power values.
log_freqs : bool, optional, default: False
Whether to plot the frequency values in log10 space.
log_powers : bool, optional, default: True
Whether to plot the power values in log10 space.
colors : str or list of str, optional
Colors to use to plot lines.
orientation : tuple of int, optional, default: (20, -50)
Orientation to set the 3D plot. See `Axes3D.view_init` for more information.
zoom : float, optional, default: 1.0
Zoom scaling for the figure axis. See `Axes3D.set_box_aspect` for more information.
ax : matplotlib.Axes, optional
Figure axes upon which to plot. Must be a 3D axis.
**kwargs
Keyword arguments for customizing the plot.
Examples
--------
Plot power spectra in 3D:
>>> from neurodsp.sim import sim_combined
>>> from neurodsp.spectral import compute_spectrum
>>> sig1 = sim_combined(n_seconds=10, fs=500,
... components={'sim_powerlaw': {'exponent' : -1},
... 'sim_bursty_oscillation' : {'freq': 10}})
>>> sig2 = sim_combined(n_seconds=10, fs=500,
... components={'sim_powerlaw': {'exponent' : -1.5},
... 'sim_bursty_oscillation' : {'freq': 10}})
>>> freqs1, powers1 = compute_spectrum(sig1, fs=500)
>>> freqs2, powers2 = compute_spectrum(sig2, fs=500)
>>> plot_spectra_3d([freqs1, freqs2], [powers1, powers2])
"""

ax = check_ax_3d(ax)

n_spectra = len(powers)

for ind, (freq, power, _, color) in \
enumerate(zip(*prepare_multi_plot(freqs, powers, None, colors))):
ax.plot(xs=np.log10(freq) if log_freqs else freq,
ys=[ind] * len(freq),
zs=np.log10(power) if log_powers else power,
color=color,
**kwargs)

ax.set(
xlabel='Frequency (Hz)',
ylabel='Channels',
zlabel='Power',
ylim=[0, n_spectra - 1],
)

yticks = list(range(n_spectra))
ax.set_yticks(yticks, yticks)

ax.view_init(*orientation)
ax.set_box_aspect(None, zoom=zoom)
4 changes: 3 additions & 1 deletion neurodsp/plts/style.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,10 +113,12 @@ def apply_custom_style(ax, **kwargs):
if ax.get_title():
ax.title.set_size(kwargs.pop('title_fontsize', TITLE_FONTSIZE))

# Settings for the axis labels
# Settings for the axis labels, including checking & setting for 3D axis
label_size = kwargs.pop('label_size', LABEL_SIZE)
ax.xaxis.label.set_size(label_size)
ax.yaxis.label.set_size(label_size)
if hasattr(ax, 'zaxis'):
ax.zaxis.label.set_size(label_size)

# Settings for the axis ticks
ax.tick_params(axis='both', which='major',
Expand Down
33 changes: 13 additions & 20 deletions neurodsp/plts/time_series.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
"""Plots for time series."""

from itertools import repeat, cycle
from itertools import repeat

import numpy as np
import matplotlib.pyplot as plt

from neurodsp.plts.style import style_plot
from neurodsp.plts.utils import check_ax, savefig
from neurodsp.plts.utils import check_ax, savefig, prepare_multi_plot
from neurodsp.utils.data import create_samples
from neurodsp.utils.checks import check_param_options

Expand Down Expand Up @@ -49,18 +49,12 @@ def plot_time_series(times, sigs, labels=None, colors=None, ax=None, **kwargs):

ax = check_ax(ax, kwargs.pop('figsize', (15, 3)))

sigs = [sigs] if (isinstance(sigs, np.ndarray) and sigs.ndim == 1) else sigs
times, xlabel = _check_times(times, sigs)

if labels is not None:
labels = [labels] if not isinstance(labels, list) else labels
else:
labels = repeat(labels)
times, sigs, colors, labels = prepare_multi_plot(times, sigs, colors, labels)

# If not provided, default colors for up to two signals to be black & red
if not colors and len(sigs) <= 2:
if isinstance(colors, repeat) and next(colors) is None and len(sigs) <= 2:
colors = ['k', 'r']
colors = repeat(colors) if not isinstance(colors, list) else cycle(colors)

for time, sig, color, label in zip(times, sigs, colors, labels):
ax.plot(time, sig, color=color, label=label)
Expand Down Expand Up @@ -174,32 +168,31 @@ def plot_multi_time_series(times, sigs, colors=None, ax=None, **plt_kwargs):
Keyword arguments for customizing the plot.
"""

colors = 'black' if not colors else colors
colors = repeat(colors) if isinstance(colors, str) else iter(colors)

ax = check_ax(ax, figsize=plt_kwargs.pop('figsize', (15, 5)))

sigs = [sigs] if (isinstance(sigs, np.ndarray) and sigs.ndim == 1) else sigs
colors = 'black' if not colors else colors

times, xlabel = _check_times(times, sigs)
times, sigs, _, colors = prepare_multi_plot(times, sigs, None, colors)

step = 0.8 * np.ptp(sigs[0])

for ind, (time, sig) in enumerate(zip(times, sigs)):
ax.plot(time, sig+step*ind, color=next(colors), **plt_kwargs)

ax.set(yticks=[])
ax.set_xlabel(xlabel)
ax.set_ylabel('Channels')
ax.set(xlabel=xlabel, ylabel='Channels', yticks=[])


def _check_times(times, sigs):
"""Helper function to check a times definition passed into a time series plot function."""

xlabel = 'Time (s)'
if times is None:
times = create_samples(len(sigs[0]))
if isinstance(sigs, list) or (isinstance(sigs, np.ndarray) and sigs.ndim == 2):
n_samples = len(sigs[0])
else:
n_samples = len(sigs)
times = create_samples(n_samples)
xlabel = 'Samples'

times = repeat(times) if (isinstance(times, np.ndarray) and times.ndim == 1) else times

return times, xlabel
75 changes: 75 additions & 0 deletions neurodsp/plts/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
from copy import deepcopy
from functools import wraps
from os.path import join as pjoin
from itertools import repeat, cycle

import numpy as np
import matplotlib.pyplot as plt

from neurodsp.plts.settings import SUPTITLE_FONTSIZE
Expand Down Expand Up @@ -60,6 +62,36 @@ def check_ax(ax, figsize=None):
return ax


def check_ax_3d(ax, figsize=None):
"""Check whether a 3D figure axes object is defined, define if not.
Parameters
----------
ax : matplotlib.Axes or None
Axes object to check if is defined. Must be 3D.
Returns
-------
ax : matplotlib.Axes
Figure axes object to use.
Raises
------
ValueError
If the ax input is a defined axis, but is not 3D.
"""

if ax and '3d' not in ax.name:
raise ValueError('Provided axis is not 3D.')

if not ax:

fig = plt.figure(figsize=figsize)
ax = fig.add_subplot(projection='3d')

return ax


def savefig(func):
"""Decorator function to save out figures."""

Expand Down Expand Up @@ -155,3 +187,46 @@ def make_axes(n_rows, n_cols, figsize=None, row_size=4, col_size=3.6,
**title_kwargs)

return axes


def prepare_multi_plot(xs, ys, labels=None, colors=None):
"""Prepare inputs for plotting one or more elements in a loop.
Parameters
----------
xs, ys : 1d or 2d array
Plot data.
labels : str or list
Label(s) for the plot input(s).
colors : str or iterable
Color(s) to plot input(s).
Returns
-------
xs, ys : iterable
Plot data.
labels : iterable
Label(s) for the plot input(s).
colors : iterable
Color(s) to plot input(s).
Notes
-----
This function takes inputs that can reflect one or more plot elements, and
prepares the inputs to be iterable for plotting in a loop.
"""

xs = repeat(xs) if isinstance(xs, np.ndarray) and xs.ndim == 1 else xs
ys = [ys] if isinstance(ys, np.ndarray) and ys.ndim == 1 else ys

# Collect definition of collection items considered iterables to check against
iterables = (list, tuple, np.ndarray)

if labels is not None:
labels = [labels] if not isinstance(labels, iterables) else labels
else:
labels = repeat(labels)

colors = repeat(colors) if not isinstance(colors, iterables) else cycle(colors)

return xs, ys, labels, colors
25 changes: 25 additions & 0 deletions neurodsp/tests/plts/test_aperiodic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
"""Tests for neurodsp.plts.aperiodic."""

from neurodsp.aperiodic.autocorr import compute_autocorr

from neurodsp.tests.settings import TEST_PLOTS_PATH, FS
from neurodsp.tests.tutils import plot_test

from neurodsp.plts.aperiodic import *

###################################################################################################
###################################################################################################

def tests_plot_autocorr(tsig, tsig_comb):

times1, acs1 = compute_autocorr(tsig, max_lag=150)
times2, acs2 = compute_autocorr(tsig_comb, max_lag=150)

plot_autocorr(times1, acs1,
save_fig=True, file_path=TEST_PLOTS_PATH,
file_name='test_plot_autocorr-1.png')

plot_autocorr([times1, times2], [acs1, acs2],
labels=['first', 'second'], colors=['k', 'r'],
save_fig=True, file_path=TEST_PLOTS_PATH,
file_name='test_plot_autocorr-2.png')
Loading

0 comments on commit 1e7959a

Please sign in to comment.