Skip to content

Commit

Permalink
[common] axes as args argument
Browse files Browse the repository at this point in the history
  • Loading branch information
janscience committed Nov 2, 2024
1 parent b307a8f commit c2580ab
Showing 1 changed file with 43 additions and 31 deletions.
74 changes: 43 additions & 31 deletions src/plottools/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import matplotlib.ticker as ticker


def common_xlabels(fig, axes=None):
def common_xlabels(fig, *axes):
""" Reduce common xlabels.
Remove all xlabels except for one that is centered at the bottommost axes.
Expand All @@ -36,14 +36,16 @@ def common_xlabels(fig, axes=None):
----------
fig: matplotlib figure
The figure containing the axes.
axes: None or sequence of matplotlib axes
axes: Sequence of matplotlib axes
Axes whose xlabels should be merged.
If None take all axes of the figure.
If not specified, take all axes of the figure.
"""
if axes is None:
if len(axes) == 0:
axes = fig.get_axes()
if isinstance(axes, np.ndarray):
axes = axes.ravel()
if len(axes) == 0:
return
if len(axes) == 1 and isinstance(axes[0], np.ndarray):
axes = axes[0].ravel()
coords = np.array([ax.get_position().get_points().ravel() for ax in axes])
minx = np.min(coords[:,0])
maxx = np.max(coords[:,2])
Expand All @@ -68,7 +70,7 @@ def common_xlabels(fig, axes=None):
done = True


def common_ylabels(fig, axes=None):
def common_ylabels(fig, *axes):
""" Reduce common ylabels.
Remove all ylabels except for one that is centered at the leftmost axes.
Expand All @@ -79,12 +81,14 @@ def common_ylabels(fig, axes=None):
The figure containing the axes.
axes: None or sequence of matplotlib axes
Axes whose ylabels should be merged.
If None take all axes of the figure.
If not specified, take all axes of the figure.
"""
if axes is None:
if len(axes) == 0:
axes = fig.get_axes()
if isinstance(axes, np.ndarray):
axes = axes.ravel()
if len(axes) == 0:
return
if len(axes) == 1 and isinstance(axes[0], np.ndarray):
axes = axes[0].ravel()
coords = np.array([ax.get_position().get_points().ravel() for ax in axes])
# center common ylabel:
minx = np.min(coords[:,0])
Expand All @@ -108,7 +112,7 @@ def common_ylabels(fig, axes=None):
done = True


def common_xticks(fig, axes=None):
def common_xticks(fig, *axes):
""" Reduce common xtick labels and xlabels.
Keep xtick labels only at the lowest axes and center the common xlabel.
Expand All @@ -119,12 +123,14 @@ def common_xticks(fig, axes=None):
The figure containing the axes.
axes: None or sequence of matplotlib axes
Axes whose xticks should be combined.
If None take all axes of the figure.
If not specified, take all axes of the figure.
"""
if axes is None:
if len(axes) == 0:
axes = fig.get_axes()
if isinstance(axes, np.ndarray):
axes = axes.ravel()
if len(axes) == 0:
return
if len(axes) == 1 and isinstance(axes[0], np.ndarray):
axes = axes[0].ravel()
coords = np.array([ax.get_position().get_points().ravel() for ax in axes])
minx = np.min(coords[:,0])
maxx = np.max(coords[:,2])
Expand All @@ -150,7 +156,7 @@ def common_xticks(fig, axes=None):
done = True


def common_yticks(fig, axes=None):
def common_yticks(fig, *axes):
""" Reduce common ytick labels and ylabels.
Keep ytick labels only at the leftmost axes and center the common ylabel.
Expand All @@ -161,12 +167,14 @@ def common_yticks(fig, axes=None):
The figure containing the axes.
axes: None or sequence of matplotlib axes
Axes whose yticks should be combined.
If None take all axes of the figure.
If not specified, take all axes of the figure.
"""
if axes is None:
if len(axes) == 0:
axes = fig.get_axes()
if isinstance(axes, np.ndarray):
axes = axes.ravel()
if len(axes) == 0:
return
if len(axes) == 1 and isinstance(axes[0], np.ndarray):
axes = axes[0].ravel()
coords = np.array([ax.get_position().get_points().ravel() for ax in axes])
minx = np.min(coords[:,0])
maxx = np.max(coords[:,2])
Expand All @@ -190,7 +198,7 @@ def common_yticks(fig, axes=None):
done = True


def common_xspines(fig, axes=None):
def common_xspines(fig, *axes):
""" Reduce common x-spines, xtick labels, and xlabels.
Keep spine and xtick labels only at the lowest axes and center the common xlabel.
Expand All @@ -201,12 +209,14 @@ def common_xspines(fig, axes=None):
The figure containing the axes.
axes: None or sequence of matplotlib axes
Axes whose xticks should be combined.
If None take all axes of the figure.
If not specified, take all axes of the figure.
"""
if axes is None:
if len(axes) == 0:
axes = fig.get_axes()
if isinstance(axes, np.ndarray):
axes = axes.ravel()
if len(axes) == 0:
return
if len(axes) == 1 and isinstance(axes[0], np.ndarray):
axes = axes[0].ravel()
coords = np.array([ax.get_position().get_points().ravel() for ax in axes])
minx = np.min(coords[:,0])
maxx = np.max(coords[:,2])
Expand All @@ -233,7 +243,7 @@ def common_xspines(fig, axes=None):
done = True


def common_yspines(fig, axes=None):
def common_yspines(fig, *axes):
""" Reduce common y-spines, ytick labels, and ylabels.
Keep spine and ytick labels only at the lowest axes and center the common ylabel.
Expand All @@ -244,12 +254,14 @@ def common_yspines(fig, axes=None):
The figure containing the axes.
axes: None or sequence of matplotlib axes
Axes whose yticks should be combined.
If None take all axes of the figure.
If not specified, take all axes of the figure.
"""
if axes is None:
if len(axes) == 0:
axes = fig.get_axes()
if isinstance(axes, np.ndarray):
axes = axes.ravel()
if len(axes) == 0:
return
if len(axes) == 1 and isinstance(axes[0], np.ndarray):
axes = axes[0].ravel()
coords = np.array([ax.get_position().get_points().ravel() for ax in axes])
minx = np.min(coords[:,0])
maxx = np.max(coords[:,2])
Expand Down

0 comments on commit c2580ab

Please sign in to comment.