diff --git a/docs/source/_images/time_series/raincloud_1d.png b/docs/source/_images/time_series/raincloud_1d.png new file mode 100644 index 0000000..6dc646c Binary files /dev/null and b/docs/source/_images/time_series/raincloud_1d.png differ diff --git a/statista/time_series.py b/statista/time_series.py index ce60f18..3b8f67d 100644 --- a/statista/time_series.py +++ b/statista/time_series.py @@ -1,5 +1,6 @@ from typing import Union, List, Tuple, Literal from pandas import DataFrame +from matplotlib.collections import PolyCollection import numpy as np import matplotlib.pyplot as plt from matplotlib.figure import Figure @@ -194,14 +195,14 @@ def box_plot( >>> ts = TimeSeries(np.random.randn(100)) >>> fig, ax = ts.box_plot() - .. image:: /_images/box_plot_1d.png + .. image:: /_images/time_series/box_plot_1d.png :align: center - Plot the box plot for a multiple time series: >>> data_2d = np.random.randn(100, 4) >>> ts_2d = TimeSeries(data_2d, columns=['A', 'B', 'C', 'D']) - >>> fig, ax = ts_2d.box_plot(grid=True, mean=True) + >>> fig, ax = ts_2d.box_plot(mean=True, grid=True) .. image:: /_images/times_series/box_plot_2d.png :align: center @@ -354,6 +355,7 @@ def violin( .. image:: /_images/times_series/violin_labels_titles.png :align: center + - You can display the means, medians, and extrema using the respective parameters: >>> fig, ax = ts_2d.violin(mean=True, median=True, extrema=True) @@ -397,3 +399,172 @@ def violin( ax.grid(kwargs.get("grid"), axis="both", linestyle="-.", linewidth=0.3) plt.show() return fig, ax + + def raincloud( + self, + overlay: bool = True, + violin_width: float = 0.4, + scatter_offset: float = 0.15, + boxplot_width: float = 0.1, + order: List[str] = None, + **kwargs, + ): + """RainCloud plot. + + Parameters + ---------- + overlay: bool, optional, default is True. + Whether to overlay the plots or display them side-by-side. + violin_width: float, optional, default is 0.4. + Width of the violins. + scatter_offset: float, optional, default is 0.15. + Offset for the scatter plot. + boxplot_width: float, optional, default is + Width of the box plot. + order: list, optional, default is None. + Order of the plots. Default is ['violin', 'scatter', 'box']. + **kwargs: dict, optional + fig: matplotlib.figure.Figure, optional + Existing figure to plot on. If None, a new figure is created. + ax: matplotlib.axes.Axes, optional + Existing axes to plot on. If None, a new axes is created. + grid: bool, optional + Whether to show grid lines. Default is True. + color: dict, optional, default is None. + Colors to use for the plot elements. Default is None. + >>> color = {"boxes", "#27408B"} + title: str, optional + Title of the plot. Default is 'Box Plot'. + xlabel: str, optional + Label for the x-axis. Default is 'Index'. + ylabel: str, optional + Label for the y-axis. Default is 'Value'. + + Returns + ------- + fig: matplotlib.figure.Figure + The figure object containing the plot. + ax: matplotlib.axes.Axes + The axes object containing the plot. + + Examples + -------- + - Plot the raincloud plot for a 1D time series, and use the `overlay` parameter to overlay the plots: + + >>> ts = TimeSeries(np.random.randn(100)) + >>> fig, ax = ts.raincloud() + + .. image:: /_images/time_series/raincloud_1d.png + :align: center + + >>> fig, ax = ts.raincloud(overlay=False) + + .. image:: /_images/time_series/raincloud_1d.png + :align: center + + - Plot the box plot for a multiple time series: + + >>> data_2d = np.random.randn(100, 4) + >>> ts_2d = TimeSeries(data_2d, columns=['A', 'B', 'C', 'D']) + >>> fig, ax = ts_2d.box_plot(grid=True, mean=True) + + .. image:: /_images/times_series/box_plot_2d.png + :align: center + + >>> fig, ax = ts_2d.box_plot(grid=True, mean=True, color={"boxes": "#DC143C"}) + + .. image:: /_images/times_series/box_plot_color.png + :align: center + + >>> fig, ax = ts_2d.box_plot(xlabel='Custom X', ylabel='Custom Y', title='Custom Box Plot') + + .. image:: /_images/times_series/box_plot_axes-label.png + :align: center + + >>> fig, ax = ts_2d.box_plot(notch=True) + + .. image:: /_images/times_series/box_plot_notch.png + :align: center + """ + fig, ax = self._get_ax_fig(fig=kwargs.get("fig"), ax=kwargs.get("ax")) + if order is None: + order = ["violin", "scatter", "box"] + + n_groups = len(self.columns) + positions = np.arange(1, n_groups + 1) + + # Dictionary to map plot types to the functions + plot_funcs = { + "violin": lambda pos, d: ax.violinplot( + [d], + positions=[pos], + showmeans=False, + showmedians=False, + showextrema=False, + widths=violin_width, + ), + "scatter": lambda pos, d: ax.scatter( + np.random.normal(pos, 0.04, size=len(d)), + d, + alpha=0.6, + color="black", + s=10, + edgecolor="white", + linewidth=0.5, + ), + "box": lambda pos, d: ax.boxplot( + [d], + positions=[pos], + widths=boxplot_width, + vert=True, + patch_artist=True, + boxprops=dict(facecolor="lightblue", color="blue"), + medianprops=dict(color="red"), + ), + } + + # Plot elements according to the specified order and selected plots + # for i, d in enumerate(data): + for i in range(len(self.columns)): + if self.ndim == 1: + d = self.values + else: + d = self.values[:, i] + base_pos = positions[i] + if overlay: + for plot_type in order: + plot_funcs[plot_type](base_pos, d) + else: + for j, plot_type in enumerate(order): + offset = (j - 1) * scatter_offset + plot_funcs[plot_type](base_pos + offset, d) + + # Customize the appearance of violins if they are included + if "violin" in order: + for ( + pc + ) in ( + ax.collections + ): # all polygons created by violinplot are in ax.collections + if isinstance(pc, PolyCollection): + pc.set_facecolor("skyblue") + pc.set_edgecolor("blue") + pc.set_alpha(0.3) + pc.set_linewidth(1) + pc.set_linestyle("-") + + # Set x-tick labels + ax.set_xticks(positions) + ax.set_xticklabels(self.columns) + ax.set_title(kwargs.get("title")) + ax.set_xlabel(kwargs.get("xlabel")) + ax.set_ylabel(kwargs.get("ylabel")) + + ax.grid(kwargs.get("grid"), axis="both", linestyle="-.", linewidth=0.3) + + # Add grid lines for better readability + # ax.yaxis.grid(True) + + # Display the plot + plt.show() + return fig, ax diff --git a/tests/test_time_series.py b/tests/test_time_series.py index 3ed3d4f..86cefc9 100644 --- a/tests/test_time_series.py +++ b/tests/test_time_series.py @@ -91,3 +91,23 @@ def test_violin(self, ts: TimeSeries, request): assert ax2 is ax, "If ax is provided, plot_box should use it." if ts.shape[1] > 1: assert len(ax.get_xticklabels()) == 3 + + +class TestRainCloud: + + @pytest.mark.parametrize("ts", ["ts_1d", "ts_2d"]) + def test_raincloud(self, ts: TimeSeries, request): + """Test the plot_box method.""" + ts = request.getfixturevalue(ts) + fig, ax = ts.raincloud() + assert isinstance( + fig, plt.Figure + ), "plot_box should return a matplotlib Figure." + assert isinstance(ax, plt.Axes), "plot_box should return a matplotlib Axes." + + fig, ax = plt.subplots() + fig2, ax2 = ts.raincloud(fig=fig, ax=ax) + assert fig2 is fig, "If fig is provided, plot_box should use it." + assert ax2 is ax, "If ax is provided, plot_box should use it." + if ts.shape[1] > 1: + assert len(ax.get_xticklabels()) == 3