Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] - Plotting updates #343

Merged
merged 18 commits into from
Feb 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -430,6 +430,7 @@ Spectral
:toctree: generated/

plot_power_spectra
plot_spectra_3d
plot_scv
plot_scv_rs_lines
plot_scv_rs_matrix
Expand Down Expand Up @@ -465,6 +466,15 @@ Time Frequency

plot_timefrequency

Aperiodic
~~~~~~~~~

.. currentmodule:: neurodsp.plts
.. autosummary::
:toctree: generated/

plot_autocorr

Combined
~~~~~~~~

Expand Down
3 changes: 2 additions & 1 deletion neurodsp/plts/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
plot_multi_time_series)
from .filt import plot_filter_properties, plot_frequency_response, plot_impulse_response
from .rhythm import plot_swm_pattern, plot_lagged_coherence
from .spectral import (plot_power_spectra, plot_spectral_hist,
from .spectral import (plot_power_spectra, plot_spectral_hist, plot_spectra_3d,
plot_scv, plot_scv_rs_lines, plot_scv_rs_matrix)
from .timefrequency import plot_timefrequency
from .aperiodic import plot_autocorr
from .combined import plot_timeseries_and_spectra
35 changes: 35 additions & 0 deletions neurodsp/plts/aperiodic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
"""Plotting functions for neurodsp.aperiodic."""

from neurodsp.plts.style import style_plot
from neurodsp.plts.utils import check_ax, savefig, prepare_multi_plot

####################################################################################################
####################################################################################################

@savefig
@style_plot
def plot_autocorr(timepoints, autocorrs, labels=None, colors=None, ax=None, **kwargs):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Only thought here is whether to include it in aperiodic.py or somewhere else, since acf isn't restricted to aperiodic signals.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeh.... agreed. I wasn't sure where to put it, but since we already have neurodsp/aperiodic/autocorr as the home of the function to compute autocorrelation, this seemed like the most consistent spot for the plot function. I don't know what a better name / place for these things is - so unless we want to re-org, move both I think this works best for now?

"""Plot autocorrelation results.

Parameters
----------
timepoints : 1d array
Time points, in samples, at which autocorrelations are computed.
autocorrs : array
Autocorrelation values, across time lags.
labels : str or list of str, optional
Labels for each time series.
colors : str or list of str
Colors to use to plot lines.
ax : matplotlib.Axes, optional
Figure axes upon which to plot.
**kwargs
Keyword arguments for customizing the plot.
"""

ax = check_ax(ax, figsize=kwargs.pop('figsize', (6, 5)))

for time, ac, label, color in zip(*prepare_multi_plot(timepoints, autocorrs, labels, colors)):
ax.plot(time, ac, label=label, color=color)

ax.set(xlabel='Lag (Samples)', ylabel='Autocorrelation')
88 changes: 73 additions & 15 deletions neurodsp/plts/spectral.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
"""Plotting functions for neurodsp.spectral."""

from itertools import repeat, cycle

import numpy as np
import matplotlib.pyplot as plt

from neurodsp.plts.style import style_plot
from neurodsp.plts.utils import check_ax, savefig
from neurodsp.plts.utils import check_ax, check_ax_3d, savefig, prepare_multi_plot

###################################################################################################
###################################################################################################
Expand Down Expand Up @@ -47,18 +45,8 @@ def plot_power_spectra(freqs, powers, labels=None, colors=None, ax=None, **kwarg

ax = check_ax(ax, figsize=kwargs.pop('figsize', (6, 6)))

freqs = repeat(freqs) if isinstance(freqs, np.ndarray) and freqs.ndim == 1 else freqs
powers = [powers] if isinstance(powers, np.ndarray) and powers.ndim == 1 else powers

if labels is not None:
labels = [labels] if not isinstance(labels, list) else labels
else:
labels = repeat(labels)

colors = repeat(colors) if not isinstance(colors, list) else cycle(colors)

for freq, power, color, label in zip(freqs, powers, colors, labels):
ax.loglog(freq, power, color=color, label=label)
for freq, power, label, color in zip(*prepare_multi_plot(freqs, powers, labels, colors)):
ax.loglog(freq, power, label=label, color=color)

ax.set_xlabel('Frequency (Hz)')
ax.set_ylabel('Power ($V^2/Hz$)')
Expand Down Expand Up @@ -235,3 +223,73 @@ def plot_spectral_hist(freqs, power_bins, spectral_hist, spectrum_freqs=None,
if spectrum is not None:
plt_inds = np.logical_and(spectrum_freqs >= freqs[0], spectrum_freqs <= freqs[-1])
ax.plot(spectrum_freqs[plt_inds], np.log10(spectrum[plt_inds]), color='w', alpha=0.8)


@savefig
@style_plot
def plot_spectra_3d(freqs, powers, log_freqs=False, log_powers=True, colors=None,
orientation=(20, -50), zoom=1.0, ax=None, **kwargs):
"""Plot a series of power spectra in a 3D plot.

Parameters
----------
freqs : 1d or 2d array or list of 1d array
Frequency vector.
powers : 2d array or list of 1d array
Power values.
log_freqs : bool, optional, default: False
Whether to plot the frequency values in log10 space.
log_powers : bool, optional, default: True
Whether to plot the power values in log10 space.
colors : str or list of str, optional
Colors to use to plot lines.
orientation : tuple of int, optional, default: (20, -50)
Orientation to set the 3D plot. See `Axes3D.view_init` for more information.
zoom : float, optional, default: 1.0
Zoom scaling for the figure axis. See `Axes3D.set_box_aspect` for more information.
ax : matplotlib.Axes, optional
Figure axes upon which to plot. Must be a 3D axis.
**kwargs
Keyword arguments for customizing the plot.

Examples
--------
Plot power spectra in 3D:

>>> from neurodsp.sim import sim_combined
>>> from neurodsp.spectral import compute_spectrum
>>> sig1 = sim_combined(n_seconds=10, fs=500,
... components={'sim_powerlaw': {'exponent' : -1},
... 'sim_bursty_oscillation' : {'freq': 10}})
>>> sig2 = sim_combined(n_seconds=10, fs=500,
... components={'sim_powerlaw': {'exponent' : -1.5},
... 'sim_bursty_oscillation' : {'freq': 10}})
>>> freqs1, powers1 = compute_spectrum(sig1, fs=500)
>>> freqs2, powers2 = compute_spectrum(sig2, fs=500)
>>> plot_spectra_3d([freqs1, freqs2], [powers1, powers2])
"""

ax = check_ax_3d(ax)

n_spectra = len(powers)

for ind, (freq, power, _, color) in \
enumerate(zip(*prepare_multi_plot(freqs, powers, None, colors))):
ax.plot(xs=np.log10(freq) if log_freqs else freq,
ys=[ind] * len(freq),
zs=np.log10(power) if log_powers else power,
color=color,
**kwargs)

ax.set(
xlabel='Frequency (Hz)',
ylabel='Channels',
zlabel='Power',
ylim=[0, n_spectra - 1],
)

yticks = list(range(n_spectra))
ax.set_yticks(yticks, yticks)

ax.view_init(*orientation)
ax.set_box_aspect(None, zoom=zoom)
4 changes: 3 additions & 1 deletion neurodsp/plts/style.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,10 +113,12 @@ def apply_custom_style(ax, **kwargs):
if ax.get_title():
ax.title.set_size(kwargs.pop('title_fontsize', TITLE_FONTSIZE))

# Settings for the axis labels
# Settings for the axis labels, including checking & setting for 3D axis
label_size = kwargs.pop('label_size', LABEL_SIZE)
ax.xaxis.label.set_size(label_size)
ax.yaxis.label.set_size(label_size)
if hasattr(ax, 'zaxis'):
ax.zaxis.label.set_size(label_size)

# Settings for the axis ticks
ax.tick_params(axis='both', which='major',
Expand Down
33 changes: 13 additions & 20 deletions neurodsp/plts/time_series.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
"""Plots for time series."""

from itertools import repeat, cycle
from itertools import repeat

import numpy as np
import matplotlib.pyplot as plt

from neurodsp.plts.style import style_plot
from neurodsp.plts.utils import check_ax, savefig
from neurodsp.plts.utils import check_ax, savefig, prepare_multi_plot
from neurodsp.utils.data import create_samples
from neurodsp.utils.checks import check_param_options

Expand Down Expand Up @@ -49,18 +49,12 @@

ax = check_ax(ax, kwargs.pop('figsize', (15, 3)))

sigs = [sigs] if (isinstance(sigs, np.ndarray) and sigs.ndim == 1) else sigs
times, xlabel = _check_times(times, sigs)

if labels is not None:
labels = [labels] if not isinstance(labels, list) else labels
else:
labels = repeat(labels)
times, sigs, colors, labels = prepare_multi_plot(times, sigs, colors, labels)

# If not provided, default colors for up to two signals to be black & red
if not colors and len(sigs) <= 2:
if isinstance(colors, repeat) and next(colors) is None and len(sigs) <= 2:
colors = ['k', 'r']
colors = repeat(colors) if not isinstance(colors, list) else cycle(colors)

for time, sig, color, label in zip(times, sigs, colors, labels):
ax.plot(time, sig, color=color, label=label)
Expand Down Expand Up @@ -174,32 +168,31 @@
Keyword arguments for customizing the plot.
"""

colors = 'black' if not colors else colors
colors = repeat(colors) if isinstance(colors, str) else iter(colors)

ax = check_ax(ax, figsize=plt_kwargs.pop('figsize', (15, 5)))

sigs = [sigs] if (isinstance(sigs, np.ndarray) and sigs.ndim == 1) else sigs
colors = 'black' if not colors else colors

times, xlabel = _check_times(times, sigs)
times, sigs, _, colors = prepare_multi_plot(times, sigs, None, colors)

step = 0.8 * np.ptp(sigs[0])

for ind, (time, sig) in enumerate(zip(times, sigs)):
ax.plot(time, sig+step*ind, color=next(colors), **plt_kwargs)

ax.set(yticks=[])
ax.set_xlabel(xlabel)
ax.set_ylabel('Channels')
ax.set(xlabel=xlabel, ylabel='Channels', yticks=[])


def _check_times(times, sigs):
"""Helper function to check a times definition passed into a time series plot function."""

xlabel = 'Time (s)'
if times is None:
times = create_samples(len(sigs[0]))
if isinstance(sigs, list) or (isinstance(sigs, np.ndarray) and sigs.ndim == 2):
n_samples = len(sigs[0])
else:
n_samples = len(sigs)

Check warning on line 194 in neurodsp/plts/time_series.py

View check run for this annotation

Codecov / codecov/patch

neurodsp/plts/time_series.py#L194

Added line #L194 was not covered by tests
times = create_samples(n_samples)
xlabel = 'Samples'

times = repeat(times) if (isinstance(times, np.ndarray) and times.ndim == 1) else times

return times, xlabel
75 changes: 75 additions & 0 deletions neurodsp/plts/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
from copy import deepcopy
from functools import wraps
from os.path import join as pjoin
from itertools import repeat, cycle

import numpy as np
import matplotlib.pyplot as plt

from neurodsp.plts.settings import SUPTITLE_FONTSIZE
Expand Down Expand Up @@ -60,6 +62,36 @@ def check_ax(ax, figsize=None):
return ax


def check_ax_3d(ax, figsize=None):
"""Check whether a 3D figure axes object is defined, define if not.

Parameters
----------
ax : matplotlib.Axes or None
Axes object to check if is defined. Must be 3D.

Returns
-------
ax : matplotlib.Axes
Figure axes object to use.

Raises
------
ValueError
If the ax input is a defined axis, but is not 3D.
"""

if ax and '3d' not in ax.name:
raise ValueError('Provided axis is not 3D.')

if not ax:

fig = plt.figure(figsize=figsize)
ax = fig.add_subplot(projection='3d')

return ax


def savefig(func):
"""Decorator function to save out figures."""

Expand Down Expand Up @@ -155,3 +187,46 @@ def make_axes(n_rows, n_cols, figsize=None, row_size=4, col_size=3.6,
**title_kwargs)

return axes


def prepare_multi_plot(xs, ys, labels=None, colors=None):
"""Prepare inputs for plotting one or more elements in a loop.

Parameters
----------
xs, ys : 1d or 2d array
Plot data.
labels : str or list
Label(s) for the plot input(s).
colors : str or iterable
Color(s) to plot input(s).

Returns
-------
xs, ys : iterable
Plot data.
labels : iterable
Label(s) for the plot input(s).
colors : iterable
Color(s) to plot input(s).

Notes
-----
This function takes inputs that can reflect one or more plot elements, and
prepares the inputs to be iterable for plotting in a loop.
"""

xs = repeat(xs) if isinstance(xs, np.ndarray) and xs.ndim == 1 else xs
ys = [ys] if isinstance(ys, np.ndarray) and ys.ndim == 1 else ys

# Collect definition of collection items considered iterables to check against
iterables = (list, tuple, np.ndarray)

if labels is not None:
labels = [labels] if not isinstance(labels, iterables) else labels
else:
labels = repeat(labels)

colors = repeat(colors) if not isinstance(colors, iterables) else cycle(colors)

return xs, ys, labels, colors
25 changes: 25 additions & 0 deletions neurodsp/tests/plts/test_aperiodic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
"""Tests for neurodsp.plts.aperiodic."""

from neurodsp.aperiodic.autocorr import compute_autocorr

from neurodsp.tests.settings import TEST_PLOTS_PATH, FS
from neurodsp.tests.tutils import plot_test

from neurodsp.plts.aperiodic import *

###################################################################################################
###################################################################################################

def tests_plot_autocorr(tsig, tsig_comb):

times1, acs1 = compute_autocorr(tsig, max_lag=150)
times2, acs2 = compute_autocorr(tsig_comb, max_lag=150)

plot_autocorr(times1, acs1,
save_fig=True, file_path=TEST_PLOTS_PATH,
file_name='test_plot_autocorr-1.png')

plot_autocorr([times1, times2], [acs1, acs2],
labels=['first', 'second'], colors=['k', 'r'],
save_fig=True, file_path=TEST_PLOTS_PATH,
file_name='test_plot_autocorr-2.png')
Loading