Skip to content

Commit

Permalink
add raincloud method
Browse files Browse the repository at this point in the history
  • Loading branch information
MAfarrag committed Aug 26, 2024
1 parent 28a4ed5 commit 152f3f1
Show file tree
Hide file tree
Showing 3 changed files with 193 additions and 2 deletions.
Binary file added docs/source/_images/time_series/raincloud_1d.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
175 changes: 173 additions & 2 deletions statista/time_series.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Union, List, Tuple, Literal
from pandas import DataFrame
from matplotlib.collections import PolyCollection
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
Expand Down Expand Up @@ -194,14 +195,14 @@ def box_plot(
>>> ts = TimeSeries(np.random.randn(100))
>>> fig, ax = ts.box_plot()
.. image:: /_images/box_plot_1d.png
.. image:: /_images/time_series/box_plot_1d.png
:align: center
- Plot the box plot for a multiple time series:
>>> data_2d = np.random.randn(100, 4)
>>> ts_2d = TimeSeries(data_2d, columns=['A', 'B', 'C', 'D'])
>>> fig, ax = ts_2d.box_plot(grid=True, mean=True)
>>> fig, ax = ts_2d.box_plot(mean=True, grid=True)
.. image:: /_images/times_series/box_plot_2d.png
:align: center
Expand Down Expand Up @@ -354,6 +355,7 @@ def violin(
.. image:: /_images/times_series/violin_labels_titles.png
:align: center
- You can display the means, medians, and extrema using the respective parameters:
>>> fig, ax = ts_2d.violin(mean=True, median=True, extrema=True)
Expand Down Expand Up @@ -397,3 +399,172 @@ def violin(
ax.grid(kwargs.get("grid"), axis="both", linestyle="-.", linewidth=0.3)
plt.show()
return fig, ax

def raincloud(
    self,
    overlay: bool = True,
    violin_width: float = 0.4,
    scatter_offset: float = 0.15,
    boxplot_width: float = 0.1,
    order: Union[List[str], None] = None,
    **kwargs,
):
    """Draw a raincloud plot (violin + jittered scatter + box plot) per column.

    Parameters
    ----------
    overlay: bool, optional, default is True.
        If True, all elements of a column are drawn at the same x-position.
        If False, the elements are spread side-by-side around the column's
        x-position, `scatter_offset` apart.
    violin_width: float, optional, default is 0.4.
        Width of the violins.
    scatter_offset: float, optional, default is 0.15.
        Horizontal distance between consecutive elements of `order` when
        `overlay` is False. Ignored when `overlay` is True.
    boxplot_width: float, optional, default is 0.1.
        Width of the box plot.
    order: list, optional, default is None.
        Elements to draw, in drawing order. Allowed entries are 'violin',
        'scatter' and 'box'. Default is ['violin', 'scatter', 'box'].
    **kwargs: dict, optional
        fig: matplotlib.figure.Figure, optional
            Existing figure to plot on. If None, a new figure is created.
        ax: matplotlib.axes.Axes, optional
            Existing axes to plot on. If None, a new axes is created.
        grid: bool, optional
            Whether to show grid lines. Default is True.
        title: str, optional
            Title of the plot. No title is set if omitted.
        xlabel: str, optional
            Label for the x-axis. No label is set if omitted.
        ylabel: str, optional
            Label for the y-axis. No label is set if omitted.

    Returns
    -------
    fig: matplotlib.figure.Figure
        The figure object containing the plot.
    ax: matplotlib.axes.Axes
        The axes object containing the plot.

    Raises
    ------
    KeyError
        If `order` contains an entry other than 'violin', 'scatter' or 'box'.

    Notes
    -----
    The colors of the violins, points and boxes are currently fixed. The
    scatter points are jittered horizontally with `numpy.random.normal`, so
    their x-coordinates differ between calls unless the NumPy global seed is
    set.

    Examples
    --------
    - Plot the raincloud plot for a 1D time series:

        >>> ts = TimeSeries(np.random.randn(100))
        >>> fig, ax = ts.raincloud()

        .. image:: /_images/time_series/raincloud_1d.png
            :align: center

    - Use the `overlay` parameter to draw the elements side-by-side:

        >>> fig, ax = ts.raincloud(overlay=False)

    - Plot the raincloud plot for multiple time series:

        >>> data_2d = np.random.randn(100, 4)
        >>> ts_2d = TimeSeries(data_2d, columns=['A', 'B', 'C', 'D'])
        >>> fig, ax = ts_2d.raincloud(grid=True)

    - Choose which elements are drawn and label the axes:

        >>> fig, ax = ts_2d.raincloud(
        ...     order=['violin', 'box'],
        ...     xlabel='Custom X',
        ...     ylabel='Custom Y',
        ...     title='Custom Raincloud Plot',
        ... )
    """
    fig, ax = self._get_ax_fig(fig=kwargs.get("fig"), ax=kwargs.get("ax"))
    if order is None:
        order = ["violin", "scatter", "box"]

    positions = np.arange(1, len(self.columns) + 1)

    # Violin bodies created by THIS call. Styling only these (instead of every
    # PolyCollection in ax.collections) leaves artists that already exist on a
    # user-supplied axes untouched.
    violin_bodies = []

    def draw_violin(pos, d):
        parts = ax.violinplot(
            [d],
            positions=[pos],
            showmeans=False,
            showmedians=False,
            showextrema=False,
            widths=violin_width,
        )
        violin_bodies.extend(parts["bodies"])

    def draw_scatter(pos, d):
        # Horizontal jitter so overlapping points stay distinguishable.
        ax.scatter(
            np.random.normal(pos, 0.04, size=len(d)),
            d,
            alpha=0.6,
            color="black",
            s=10,
            edgecolor="white",
            linewidth=0.5,
        )

    def draw_box(pos, d):
        ax.boxplot(
            [d],
            positions=[pos],
            widths=boxplot_width,
            vert=True,
            patch_artist=True,
            boxprops=dict(facecolor="lightblue", color="blue"),
            medianprops=dict(color="red"),
        )

    # Map each element name to the function that draws it.
    plot_funcs = {
        "violin": draw_violin,
        "scatter": draw_scatter,
        "box": draw_box,
    }

    # Index of the middle element; offsets are measured from it so the group
    # stays centred on the tick for any number of elements (for the default
    # three elements this is exactly ``j - 1``).
    center = (len(order) - 1) / 2

    for i, base_pos in enumerate(positions):
        d = self.values if self.ndim == 1 else self.values[:, i]
        if overlay:
            for plot_type in order:
                plot_funcs[plot_type](base_pos, d)
        else:
            for j, plot_type in enumerate(order):
                offset = (j - center) * scatter_offset
                plot_funcs[plot_type](base_pos + offset, d)

    # Customize the appearance of the violins drawn above.
    for body in violin_bodies:
        body.set_facecolor("skyblue")
        body.set_edgecolor("blue")
        body.set_alpha(0.3)
        body.set_linewidth(1)
        body.set_linestyle("-")

    # Set x-tick labels
    ax.set_xticks(positions)
    ax.set_xticklabels(self.columns)
    ax.set_title(kwargs.get("title"))
    ax.set_xlabel(kwargs.get("xlabel"))
    ax.set_ylabel(kwargs.get("ylabel"))

    # matplotlib forces the grid on whenever line properties are passed to
    # ``grid``, so an explicit ``grid=False`` has to be handled separately.
    show_grid = kwargs.get("grid")
    if show_grid is None or show_grid:
        ax.grid(True, axis="both", linestyle="-.", linewidth=0.3)
    else:
        ax.grid(False, axis="both")

    # Display the plot
    plt.show()
    return fig, ax
20 changes: 20 additions & 0 deletions tests/test_time_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,3 +91,23 @@ def test_violin(self, ts: TimeSeries, request):
assert ax2 is ax, "If ax is provided, plot_box should use it."
if ts.shape[1] > 1:
assert len(ax.get_xticklabels()) == 3


class TestRainCloud:
    """Tests for the `raincloud` plotting method."""

    @pytest.mark.parametrize("ts", ["ts_1d", "ts_2d"])
    def test_raincloud(self, ts: TimeSeries, request):
        """Test the raincloud method."""
        # `ts` arrives as a fixture name and is resolved to the fixture value.
        ts = request.getfixturevalue(ts)
        fig, ax = ts.raincloud()
        assert isinstance(
            fig, plt.Figure
        ), "raincloud should return a matplotlib Figure."
        assert isinstance(ax, plt.Axes), "raincloud should return a matplotlib Axes."

        # A caller-supplied figure/axes must be reused, not replaced.
        fig, ax = plt.subplots()
        fig2, ax2 = ts.raincloud(fig=fig, ax=ax)
        assert fig2 is fig, "If fig is provided, raincloud should use it."
        assert ax2 is ax, "If ax is provided, raincloud should use it."
        # NOTE(review): presumably the multi-column fixture has three columns;
        # confirm against the fixture definition.
        if ts.shape[1] > 1:
            assert len(ax.get_xticklabels()) == 3

0 comments on commit 152f3f1

Please sign in to comment.