diff --git a/plotnine/_mpl/layout_manager/_engine.py b/plotnine/_mpl/layout_manager/_engine.py index 08080576a..f3c20fc9f 100644 --- a/plotnine/_mpl/layout_manager/_engine.py +++ b/plotnine/_mpl/layout_manager/_engine.py @@ -5,7 +5,7 @@ from matplotlib.layout_engine import LayoutEngine -from ._layout_pack import LayoutPack +from ._spaces import LayoutSpaces if TYPE_CHECKING: from matplotlib.figure import Figure @@ -33,12 +33,10 @@ def __init__(self, plot: ggplot): def execute(self, fig: Figure): from contextlib import nullcontext - from ._tight_layout import adjust_figure_artists, compute_layout + renderer = fig._get_renderer() # pyright: ignore[reportAttributeAccessIssue] - pack = LayoutPack(self.plot) - - with getattr(pack.renderer, "_draw_disabled", nullcontext)(): - spaces = compute_layout(pack) + with getattr(renderer, "_draw_disabled", nullcontext)(): + spaces = LayoutSpaces(self.plot) fig.subplots_adjust(**asdict(spaces.gsparams)) - adjust_figure_artists(pack, spaces) + spaces.items._adjust_positions(spaces) diff --git a/plotnine/_mpl/layout_manager/_layout_pack.py b/plotnine/_mpl/layout_manager/_layout_items.py similarity index 68% rename from plotnine/_mpl/layout_manager/_layout_pack.py rename to plotnine/_mpl/layout_manager/_layout_items.py index e2dfa2c06..23ef4a18e 100644 --- a/plotnine/_mpl/layout_manager/_layout_pack.py +++ b/plotnine/_mpl/layout_manager/_layout_items.py @@ -10,6 +10,7 @@ from ..utils import ( bbox_in_figure_space, + get_transPanels, tight_bbox_in_figure_space, ) @@ -25,14 +26,16 @@ from matplotlib.artist import Artist from matplotlib.axes import Axes from matplotlib.axis import Tick - from matplotlib.figure import Figure - from matplotlib.transforms import Bbox + from matplotlib.transforms import Bbox, Transform from plotnine import ggplot + from plotnine._mpl.offsetbox import FlexibleAnchoredOffsetbox from plotnine._mpl.text import StripText from plotnine.iapi import legend_artists from plotnine.typing import StripPosition + from ._spaces import LayoutSpaces + AxesLocation: TypeAlias = Literal[ "all", "first_row", "last_row", "first_col", "last_col" ] @@ -40,20 +43,29 @@ @dataclass class Calc: - fig: Figure - renderer: RendererBase + """ + Calculate space taken up by an artist + """ + + # fig: Figure + # renderer: RendererBase + plot: ggplot + + def __post_init__(self): + self.figure = self.plot.figure + self.renderer = cast(RendererBase, self.plot.figure._get_renderer()) # pyright: ignore def bbox(self, artist: Artist) -> Bbox: """ Bounding box of artist in figure coordinates """ - return bbox_in_figure_space(artist, self.fig, self.renderer) + return bbox_in_figure_space(artist, self.figure, self.renderer) def tight_bbox(self, artist: Artist) -> Bbox: """ Bounding box of artist and its children in figure coordinates """ - return tight_bbox_in_figure_space(artist, self.fig, self.renderer) + return tight_bbox_in_figure_space(artist, self.figure, self.renderer) def width(self, artist: Artist) -> float: """ @@ -138,7 +150,7 @@ def max_width(self, artists: Sequence[Artist]) -> float: Return the maximum width of list of artists """ widths = [ - bbox_in_figure_space(a, self.fig, self.renderer).width + bbox_in_figure_space(a, self.figure, self.renderer).width for a in artists ] return max(widths) if len(widths) else 0 @@ -148,14 +160,14 @@ def max_height(self, artists: Sequence[Artist]) -> float: Return the maximum height of list of artists """ heights = [ - bbox_in_figure_space(a, self.fig, self.renderer).height + bbox_in_figure_space(a, self.figure, self.renderer).height for a in artists ] return max(heights) if len(heights) else 0 @dataclass -class LayoutPack: +class LayoutItems: """ Objects required to compute the layout """ @@ -167,20 +179,15 @@ def get(name: str) -> Any: """ Return themeable target or None """ - if self.theme.T.is_blank(name): + if self._is_blank(name): return None else: - t = getattr(self.theme.targets, name) + t = getattr(self.plot.theme.targets, name) if isinstance(t, Text) and t.get_text() == "": return None return t - self.axs = self.plot.axs - self.theme = self.plot.theme - self.figure = self.plot.figure - self.facet = self.plot.facet - self.renderer = cast(RendererBase, self.plot.figure._get_renderer()) # pyright: ignore - self.calc = Calc(self.figure, self.renderer) + self.calc = Calc(self.plot) self.axis_title_x: Text | None = get("axis_title_x") self.axis_title_y: Text | None = get("axis_title_y") @@ -195,20 +202,22 @@ def get(name: str) -> Any: self.strip_text_y: list[StripText] | None = get("strip_text_y") def _is_blank(self, name: str) -> bool: - return self.theme.T.is_blank(name) + return self.plot.theme.T.is_blank(name) def _filter_axes(self, location: AxesLocation = "all") -> list[Axes]: """ Return subset of axes """ + axs = self.plot.axs + if location == "all": - return self.axs + return axs # e.g. is_first_row, is_last_row, .. pred_method = f"is_{location}" return [ ax - for spec, ax in zip(get_subplotspec_list(self.axs), self.axs) + for spec, ax in zip(get_subplotspec_list(axs), axs) if getattr(spec, pred_method)() ] @@ -280,7 +289,7 @@ def axis_ticks_pad_x(self, ax: Axes) -> Iterator[float]: # the axis_text. major, minor = [], [] if not self._is_blank("axis_text_x"): - h = self.figure.get_figheight() * 72 + h = self.plot.figure.get_figheight() * 72 major = [ (t.get_pad() or 0) / h for t in ax.xaxis.get_major_ticks() ] @@ -297,7 +306,7 @@ def axis_ticks_pad_y(self, ax: Axes) -> Iterator[float]: # the axis_text. major, minor = [], [] if not self._is_blank("axis_text_y"): - w = self.figure.get_figwidth() * 72 + w = self.plot.figure.get_figwidth() * 72 major = [ (t.get_pad() or 0) / w for t in ax.yaxis.get_major_ticks() ] @@ -436,9 +445,165 @@ def axis_text_x_right_protrusion(self, location: AxesLocation) -> float: return max(extras) if len(extras) else 0 + def _adjust_positions(self, spaces: LayoutSpaces): + """ + Set the x,y position of the artists around the panels + """ + theme = self.plot.theme + + if self.plot_title: + ha = theme.getp(("plot_title", "ha")) + self.plot_title.set_y(spaces.t.edge("plot_title")) + horizontally_align_text_with_panels(self.plot_title, ha, spaces) + + if self.plot_subtitle: + ha = theme.getp(("plot_subtitle", "ha")) + self.plot_subtitle.set_y(spaces.t.edge("plot_subtitle")) + horizontally_align_text_with_panels(self.plot_subtitle, ha, spaces) + + if self.plot_caption: + ha = theme.getp(("plot_caption", "ha"), "right") + self.plot_caption.set_y(spaces.b.edge("plot_caption")) + horizontally_align_text_with_panels(self.plot_caption, ha, spaces) + + if self.axis_title_x: + ha = theme.getp(("axis_title_x", "ha"), "center") + self.axis_title_x.set_y(spaces.b.edge("axis_title_x")) + horizontally_align_text_with_panels(self.axis_title_x, ha, spaces) + + if self.axis_title_y: + va = theme.getp(("axis_title_y", "va"), "center") + self.axis_title_y.set_x(spaces.l.edge("axis_title_y")) + vertically_align_text_with_panels(self.axis_title_y, va, spaces) + + if self.legends: + set_legends_position(self.legends, spaces) + def _text_is_visible(text: Text) -> bool: """ Return True if text is visible and is not empty """ return text.get_visible() and text._text # type: ignore + + +def horizontally_align_text_with_panels( + text: Text, ha: str | float, spaces: LayoutSpaces +): + """ + Horizontal justification + + Reinterpret horizontal alignment to be justification about the panels. + """ + if isinstance(ha, str): + lookup = { + "left": 0.0, + "center": 0.5, + "right": 1.0, + } + f = lookup[ha] + else: + f = ha + + params = spaces.gsparams + width = spaces.items.calc.width(text) + x = params.left * (1 - f) + (params.right - width) * f + text.set_x(x) + text.set_horizontalalignment("left") + + +def vertically_align_text_with_panels( + text: Text, va: str | float, spaces: LayoutSpaces +): + """ + Vertical justification + + Reinterpret vertical alignment to be justification about the panels. + """ + if isinstance(va, str): + lookup = { + "top": 1.0, + "center": 0.5, + "baseline": 0.5, + "center_baseline": 0.5, + "bottom": 0.0, + } + f = lookup[va] + else: + f = va + + params = spaces.gsparams + height = spaces.items.calc.height(text) + y = params.bottom * (1 - f) + (params.top - height) * f + text.set_y(y) + text.set_verticalalignment("bottom") + + +def set_legends_position(legends: legend_artists, spaces: LayoutSpaces): + """ + Place legend on the figure and justify is a required + """ + figure = spaces.plot.figure + params = figure.subplotpars + + def set_position( + aob: FlexibleAnchoredOffsetbox, + anchor_point: tuple[float, float], + xy_loc: tuple[float, float], + transform: Transform = figure.transFigure, + ): + """ + Place box (by the anchor point) at given xy location + + Parameters + ---------- + aob : + Offsetbox to place + anchor_point : + Point on the Offsefbox. + xy_loc : + Point where to place the offsetbox. + transform : + Transformation + """ + aob.xy_loc = xy_loc + aob.set_bbox_to_anchor(anchor_point, transform) # type: ignore + + def func(a, b, length, f): + return a * (1 - f) + (b - length) * f + + if legends.right: + j = legends.right.justification + y = ( + params.bottom * (1 - j) + + (params.top - spaces.r._legend_height) * j + ) + x = spaces.r.edge("legend") + set_position(legends.right.box, (x, y), (1, 0)) + + if legends.left: + j = legends.left.justification + y = ( + params.bottom * (1 - j) + + (params.top - spaces.l._legend_height) * j + ) + x = spaces.l.edge("legend") + set_position(legends.left.box, (x, y), (0, 0)) + + if legends.top: + j = legends.top.justification + x = params.left * (1 - j) + (params.right - spaces.t._legend_width) * j + y = spaces.t.edge("legend") + set_position(legends.top.box, (x, y), (0, 1)) + + if legends.bottom: + j = legends.bottom.justification + x = params.left * (1 - j) + (params.right - spaces.b._legend_width) * j + y = spaces.b.edge("legend") + set_position(legends.bottom.box, (x, y), (0, 0)) + + # Inside legends are placed using the panels coordinate system + if legends.inside: + transPanels = get_transPanels(figure) + for l in legends.inside: + set_position(l.box, l.position, l.justification, transPanels) diff --git a/plotnine/_mpl/layout_manager/_spaces.py b/plotnine/_mpl/layout_manager/_spaces.py index 3bbdc8460..bed7a7d65 100644 --- a/plotnine/_mpl/layout_manager/_spaces.py +++ b/plotnine/_mpl/layout_manager/_spaces.py @@ -18,11 +18,13 @@ from plotnine.facets import facet_grid, facet_null, facet_wrap +from ._layout_items import LayoutItems + if TYPE_CHECKING: from dataclasses import Field from typing import Generator - from ._layout_pack import LayoutPack + from plotnine import ggplot # Note @@ -56,7 +58,7 @@ class _side_spaces(ABC): side classes (e.g. legend). """ - pack: LayoutPack + items: LayoutItems def __post_init__(self): self._calculate() @@ -131,39 +133,39 @@ class left_spaces(_side_spaces): axis_ticks_y: float = 0 def _calculate(self): - theme = self.pack.theme - calc = self.pack.calc - pack = self.pack + theme = self.items.plot.theme + calc = self.items.calc + items = self.items self.plot_margin = theme.getp("plot_margin_left") - if pack.legends and pack.legends.left: + if items.legends and items.legends.left: self.legend = self._legend_width self.legend_box_spacing = theme.getp("legend_box_spacing") - if pack.axis_title_y: + if items.axis_title_y: self.axis_title_y_margin_right = theme.getp( ("axis_title_y", "margin") ).get_as("r", "fig") - self.axis_title_y = calc.width(pack.axis_title_y) + self.axis_title_y = calc.width(items.axis_title_y) # Account for the space consumed by the axis - self.axis_text_y = pack.axis_text_y_max_width("first_col") - self.axis_ticks_y = pack.axis_ticks_y_max_width("first_col") + self.axis_text_y = items.axis_text_y_max_width("first_col") + self.axis_ticks_y = items.axis_ticks_y_max_width("first_col") # Adjust plot_margin to make room for ylabels that protude well # beyond the axes # NOTE: This adjustment breaks down when the protrusion is large - protrusion = pack.axis_text_x_left_protrusion("all") + protrusion = items.axis_text_x_left_protrusion("all") adjustment = protrusion - (self.total - self.plot_margin) if adjustment > 0: self.plot_margin += adjustment @cached_property def _legend_size(self) -> tuple[float, float]: - if not (self.pack.legends and self.pack.legends.left): + if not (self.items.legends and self.items.legends.left): return (0, 0) - return self.pack.calc.size(self.pack.legends.left.box) + return self.items.calc.size(self.items.legends.left.box) def edge(self, item: str) -> float: """ @@ -186,30 +188,30 @@ class right_spaces(_side_spaces): strip_text_y_width_right: float = 0 def _calculate(self): - pack = self.pack - theme = self.pack.theme + items = self.items + theme = self.items.plot.theme self.plot_margin = theme.getp("plot_margin_right") - if pack.legends and pack.legends.right: + if items.legends and items.legends.right: self.legend = self._legend_width self.legend_box_spacing = theme.getp("legend_box_spacing") - self.strip_text_y_width_right = pack.strip_text_y_width("right") + self.strip_text_y_width_right = items.strip_text_y_width("right") # Adjust plot_margin to make room for ylabels that protude well # beyond the axes # NOTE: This adjustment breaks down when the protrusion is large - protrusion = pack.axis_text_x_right_protrusion("all") + protrusion = items.axis_text_x_right_protrusion("all") adjustment = protrusion - (self.total - self.plot_margin) if adjustment > 0: self.plot_margin += adjustment @cached_property def _legend_size(self) -> tuple[float, float]: - if not (self.pack.legends and self.pack.legends.right): + if not (self.items.legends and self.items.legends.right): return (0, 0) - return self.pack.calc.size(self.pack.legends.right.box) + return self.items.calc.size(self.items.legends.right.box) def edge(self, item: str) -> float: """ @@ -236,45 +238,45 @@ class top_spaces(_side_spaces): strip_text_x_height_top: float = 0 def _calculate(self): - pack = self.pack - theme = self.pack.theme - calc = self.pack.calc + items = self.items + theme = self.items.plot.theme + calc = self.items.calc W, H = theme.getp("figure_size") F = W / H self.plot_margin = theme.getp("plot_margin_top") * F - if pack.plot_title: - self.plot_title = calc.height(pack.plot_title) + if items.plot_title: + self.plot_title = calc.height(items.plot_title) self.plot_title_margin_bottom = ( theme.getp(("plot_title", "margin")).get_as("b", "fig") * F ) - if pack.plot_subtitle: - self.plot_subtitle = calc.height(pack.plot_subtitle) + if items.plot_subtitle: + self.plot_subtitle = calc.height(items.plot_subtitle) self.plot_subtitle_margin_bottom = ( theme.getp(("plot_subtitle", "margin")).get_as("b", "fig") * F ) - if pack.legends and pack.legends.top: + if items.legends and items.legends.top: self.legend = self._legend_height self.legend_box_spacing = theme.getp("legend_box_spacing") * F - self.strip_text_x_height_top = pack.strip_text_x_height("top") + self.strip_text_x_height_top = items.strip_text_x_height("top") # Adjust plot_margin to make room for ylabels that protude well # beyond the axes # NOTE: This adjustment breaks down when the protrusion is large - protrusion = pack.axis_text_y_top_protrusion("all") + protrusion = items.axis_text_y_top_protrusion("all") adjustment = protrusion - (self.total - self.plot_margin) if adjustment > 0: self.plot_margin += adjustment @cached_property def _legend_size(self) -> tuple[float, float]: - if not (self.pack.legends and self.pack.legends.top): + if not (self.items.legends and self.items.legends.top): return (0, 0) - return self.pack.calc.size(self.pack.legends.top.box) + return self.items.calc.size(self.items.legends.top.box) def edge(self, item: str) -> float: """ @@ -302,48 +304,48 @@ class bottom_spaces(_side_spaces): axis_ticks_x: float = 0 def _calculate(self): - pack = self.pack - theme = self.pack.theme - calc = self.pack.calc + items = self.items + theme = self.items.plot.theme + calc = self.items.calc W, H = theme.getp("figure_size") F = W / H self.plot_margin = theme.getp("plot_margin_bottom") * F - if pack.plot_caption: - self.plot_caption = calc.height(pack.plot_caption) + if items.plot_caption: + self.plot_caption = calc.height(items.plot_caption) self.plot_caption_margin_top = ( theme.getp(("plot_caption", "margin")).get_as("t", "fig") * F ) - if pack.legends and pack.legends.bottom: + if items.legends and items.legends.bottom: self.legend = self._legend_height self.legend_box_spacing = theme.getp("legend_box_spacing") * F - if pack.axis_title_x: - self.axis_title_x = calc.height(pack.axis_title_x) + if items.axis_title_x: + self.axis_title_x = calc.height(items.axis_title_x) self.axis_title_x_margin_top = ( theme.getp(("axis_title_x", "margin")).get_as("t", "fig") * F ) # Account for the space consumed by the axis - self.axis_ticks_x = pack.axis_ticks_x_max_height("last_row") - self.axis_text_x = pack.axis_text_x_max_height("last_row") + self.axis_ticks_x = items.axis_ticks_x_max_height("last_row") + self.axis_text_x = items.axis_text_x_max_height("last_row") # Adjust plot_margin to make room for ylabels that protude well # beyond the axes # NOTE: This adjustment breaks down when the protrusion is large - protrusion = pack.axis_text_y_bottom_protrusion("all") + protrusion = items.axis_text_y_bottom_protrusion("all") adjustment = protrusion - (self.total - self.plot_margin) if adjustment > 0: self.plot_margin += adjustment @cached_property def _legend_size(self) -> tuple[float, float]: - if not (self.pack.legends and self.pack.legends.bottom): + if not (self.items.legends and self.items.legends.bottom): return (0, 0) - return self.pack.calc.size(self.pack.legends.bottom.box) + return self.items.calc.size(self.items.legends.bottom.box) def edge(self, item: str) -> float: """ @@ -355,10 +357,21 @@ def edge(self, item: str) -> float: @dataclass class LayoutSpaces: """ - Space created by the layout management + Compute the all the spaces required in the layout + + These are: + + 1. The space of each artist between the panel and the edge of the + figure. + 2. The space in-between the panels + + From these values, we put together the grid-spec parameters required + by matplotblib to position the axes. We also use the values to adjust + the coordinates of all the artists that occupy these spaces, placing + them in their final positions. """ - pack: LayoutPack + plot: ggplot l: left_spaces = field(init=False) """All subspaces to the left of the panels""" @@ -394,14 +407,15 @@ class LayoutSpaces: """Grid spacing btn panels w.r.t figure""" def __post_init__(self): - self.W, self.H = self.pack.theme.getp("figure_size") + self.items = LayoutItems(self.plot) + self.W, self.H = self.plot.theme.getp("figure_size") # Calculate the spacing along the edges of the panel area # (spacing required by plotnine) - self.l = left_spaces(self.pack) - self.r = right_spaces(self.pack) - self.t = top_spaces(self.pack) - self.b = bottom_spaces(self.pack) + self.l = left_spaces(self.items) + self.r = right_spaces(self.items) + self.t = top_spaces(self.items) + self.b = bottom_spaces(self.items) # Calculate the gridspec params # (spacing required by mpl) @@ -411,7 +425,7 @@ def __post_init__(self): # It is simpler to adjust for the aspect ratio than to calculate # the final parameters that are true to the aspect ratio in # one-short - if (ratio := self.pack.facet._aspect_ratio()) is not None: + if (ratio := self.plot.facet._aspect_ratio()) is not None: current_ratio = self.aspect_ratio if ratio > current_ratio: # Increase aspect ratio, taller panels @@ -470,14 +484,14 @@ def _calculate_panel_spacing(self) -> GridSpecParams: This ensures that the same fraction gives equals space in both directions. """ - if isinstance(self.pack.facet, facet_wrap): + if isinstance(self.plot.facet, facet_wrap): wspace, hspace = self._calculate_panel_spacing_facet_wrap() - elif isinstance(self.pack.facet, facet_grid): + elif isinstance(self.plot.facet, facet_grid): wspace, hspace = self._calculate_panel_spacing_facet_grid() - elif isinstance(self.pack.facet, facet_null): + elif isinstance(self.plot.facet, facet_null): wspace, hspace = self._calculate_panel_spacing_facet_null() else: - raise TypeError(f"Unknown type of facet: {type(self.pack.facet)}") + raise TypeError(f"Unknown type of facet: {type(self.plot.facet)}") return GridSpecParams( self.left, self.right, self.top, self.bottom, wspace, hspace @@ -487,18 +501,16 @@ def _calculate_panel_spacing_facet_grid(self) -> tuple[float, float]: """ Calculate spacing parts for facet_grid """ - theme = self.pack.theme + theme = self.plot.theme - ncol = self.pack.facet.ncol - nrow = self.pack.facet.nrow - - W, H = theme.getp("figure_size") + ncol = self.plot.facet.ncol + nrow = self.plot.facet.nrow # Both spacings are specified as fractions of the figure width # Multiply the vertical by (W/H) so that the gullies along both # directions are equally spaced. self.sw = theme.getp("panel_spacing_x") - self.sh = theme.getp("panel_spacing_y") * W / H + self.sh = theme.getp("panel_spacing_y") * self.W / self.H # width and height of axes as fraction of figure width & height self.w = ((self.right - self.left) - self.sw * (ncol - 1)) / ncol @@ -513,16 +525,15 @@ def _calculate_panel_spacing_facet_wrap(self) -> tuple[float, float]: """ Calculate spacing parts for facet_wrap """ - facet = self.pack.facet - theme = self.pack.theme + facet = self.plot.facet + theme = self.plot.theme ncol = facet.ncol nrow = facet.nrow - W, H = theme.getp("figure_size") # Both spacings are specified as fractions of the figure width self.sw = theme.getp("panel_spacing_x") - self.sh = theme.getp("panel_spacing_y") * W / H + self.sh = theme.getp("panel_spacing_y") * self.W / self.H # A fraction of the strip height # Effectively slides the strip @@ -539,13 +550,13 @@ def _calculate_panel_spacing_facet_wrap(self) -> tuple[float, float]: self.sh += self.t.strip_text_x_height_top * (1 + strip_align_x) if facet.free["x"]: - self.sh += self.pack.axis_text_x_max_height( + self.sh += self.items.axis_text_x_max_height( "all" - ) + self.pack.axis_ticks_x_max_height("all") + ) + self.items.axis_ticks_x_max_height("all") if facet.free["y"]: - self.sw += self.pack.axis_text_y_max_width( + self.sw += self.items.axis_text_y_max_width( "all" - ) + self.pack.axis_ticks_y_max_width("all") + ) + self.items.axis_ticks_y_max_width("all") # width and height of axes as fraction of figure width & height self.w = ((self.right - self.left) - self.sw * (ncol - 1)) / ncol @@ -574,7 +585,7 @@ def _reduce_height(self, ratio: float): h1 = ratio * self.w * (self.W / self.H) # Half of the total vertical reduction w.r.t figure height - dh = (self.h - h1) * self.pack.facet.nrow / 2 + dh = (self.h - h1) * self.plot.facet.nrow / 2 # Reduce plot area height self.gsparams.top -= dh @@ -592,7 +603,7 @@ def _reduce_width(self, ratio: float): w1 = (self.h * self.H) / (ratio * self.W) # Half of the total horizontal reduction w.r.t figure width - dw = (self.w - w1) * self.pack.facet.ncol / 2 + dw = (self.w - w1) * self.plot.facet.ncol / 2 # Reduce width self.gsparams.left += dw diff --git a/plotnine/_mpl/layout_manager/_tight_layout.py b/plotnine/_mpl/layout_manager/_tight_layout.py deleted file mode 100644 index ed730fff3..000000000 --- a/plotnine/_mpl/layout_manager/_tight_layout.py +++ /dev/null @@ -1,196 +0,0 @@ -""" -Routines to adjust subplot params so that subplots are -nicely fit in the figure. In doing so, only axis labels, tick labels, axes -titles and offsetboxes that are anchored to axes are currently considered. - -Internally, this module assumes that the margins (left margin, etc.) which are -differences between `Axes.get_tightbbox` and `Axes.bbox` are independent of -Axes position. This may fail if `Axes.adjustable` is `datalim` as well as -such cases as when left or right margin are affected by xlabel. -""" - -from __future__ import annotations - -from typing import TYPE_CHECKING - -from ..utils import get_transPanels -from ._plot_side_space import GridSpecParams, LayoutSpaces - -if TYPE_CHECKING: - from matplotlib.figure import Figure - from matplotlib.text import Text - from matplotlib.transforms import Transform - - from plotnine._mpl.offsetbox import FlexibleAnchoredOffsetbox - from plotnine.iapi import legend_artists - - from ._layout_pack import LayoutPack - - -def compute_layout(pack: LayoutPack) -> LayoutSpaces: - """ - Compute tight layout parameters - """ - return LayoutSpaces(pack) - - -def adjust_figure_artists(pack: LayoutPack, spaces: LayoutSpaces): - """ - Set the x,y position of the artists around the panels - """ - theme = pack.theme - params = spaces.gsparams - - if pack.plot_title: - ha = theme.getp(("plot_title", "ha")) - pack.plot_title.set_y(spaces.t.edge("plot_title")) - horizontally_align_text_with_panels(pack.plot_title, params, ha, pack) - - if pack.plot_subtitle: - ha = theme.getp(("plot_subtitle", "ha")) - pack.plot_subtitle.set_y(spaces.t.edge("plot_subtitle")) - horizontally_align_text_with_panels( - pack.plot_subtitle, params, ha, pack - ) - - if pack.plot_caption: - ha = theme.getp(("plot_caption", "ha"), "right") - pack.plot_caption.set_y(spaces.b.edge("plot_caption")) - horizontally_align_text_with_panels( - pack.plot_caption, params, ha, pack - ) - - if pack.axis_title_x: - ha = theme.getp(("axis_title_x", "ha"), "center") - pack.axis_title_x.set_y(spaces.b.edge("axis_title_x")) - horizontally_align_text_with_panels( - pack.axis_title_x, params, ha, pack - ) - - if pack.axis_title_y: - va = theme.getp(("axis_title_y", "va"), "center") - pack.axis_title_y.set_x(spaces.l.edge("axis_title_y")) - vertically_align_text_with_panels(pack.axis_title_y, params, va, pack) - - if pack.legends: - set_legends_position(pack.legends, spaces, pack.figure) - - -def horizontally_align_text_with_panels( - text: Text, params: GridSpecParams, ha: str | float, pack: LayoutPack -): - """ - Horizontal justification - - Reinterpret horizontal alignment to be justification about the panels. - """ - if isinstance(ha, str): - lookup = { - "left": 0.0, - "center": 0.5, - "right": 1.0, - } - f = lookup[ha] - else: - f = ha - - width = pack.calc.width(text) - x = params.left * (1 - f) + (params.right - width) * f - text.set_x(x) - text.set_horizontalalignment("left") - - -def vertically_align_text_with_panels( - text: Text, params: GridSpecParams, va: str | float, pack: LayoutPack -): - """ - Vertical justification - - Reinterpret vertical alignment to be justification about the panels. - """ - if isinstance(va, str): - lookup = { - "top": 1.0, - "center": 0.5, - "baseline": 0.5, - "center_baseline": 0.5, - "bottom": 0.0, - } - f = lookup[va] - else: - f = va - - height = pack.calc.height(text) - y = params.bottom * (1 - f) + (params.top - height) * f - text.set_y(y) - text.set_verticalalignment("bottom") - - -def set_legends_position( - legends: legend_artists, - spaces: LayoutSpaces, - fig: Figure, -): - """ - Place legend on the figure and justify is a required - """ - - def set_position( - aob: FlexibleAnchoredOffsetbox, - anchor_point: tuple[float, float], - xy_loc: tuple[float, float], - transform: Transform = fig.transFigure, - ): - """ - Place box (by the anchor point) at given xy location - - Parameters - ---------- - aob : - Offsetbox to place - anchor_point : - Point on the Offsefbox. - xy_loc : - Point where to place the offsetbox. - transform : - Transformation - """ - aob.xy_loc = xy_loc - aob.set_bbox_to_anchor(anchor_point, transform) # type: ignore - - params = fig.subplotpars - if legends.right: - j = legends.right.justification - y = ( - params.bottom * (1 - j) - + (params.top - spaces.r._legend_height) * j - ) - x = spaces.r.edge("legend") - set_position(legends.right.box, (x, y), (1, 0)) - - if legends.left: - j = legends.left.justification - y = ( - params.bottom * (1 - j) - + (params.top - spaces.l._legend_height) * j - ) - x = spaces.l.edge("legend") - set_position(legends.left.box, (x, y), (0, 0)) - - if legends.top: - j = legends.top.justification - x = params.left * (1 - j) + (params.right - spaces.t._legend_width) * j - y = spaces.t.edge("legend") - set_position(legends.top.box, (x, y), (0, 1)) - - if legends.bottom: - j = legends.bottom.justification - x = params.left * (1 - j) + (params.right - spaces.b._legend_width) * j - y = spaces.b.edge("legend") - set_position(legends.bottom.box, (x, y), (0, 0)) - - # Inside legends are placed using the panels coordinate system - if legends.inside: - transPanels = get_transPanels(fig) - for l in legends.inside: - set_position(l.box, l.position, l.justification, transPanels)