diff --git a/pylace/lace/plot.py b/pylace/lace/plot.py index b4cf5fd3..0bc431e8 100644 --- a/pylace/lace/plot.py +++ b/pylace/lace/plot.py @@ -1,5 +1,6 @@ """Plotting utilities.""" +from math import ceil from typing import Any, Dict, Optional, Union import matplotlib.pyplot as plt @@ -292,8 +293,8 @@ def state( *, cmap: Optional[str] = None, missing_color=None, - cat_gap: int = 1, - view_gap: int = 1, + cat_gap: Union[float, int] = 0.1, + view_gap: Union[float, int] = 0.2, show_index: bool = True, show_columns: bool = True, min_height: int = 0, @@ -320,10 +321,12 @@ def state( missing_color: optional, default: red The RGBA array representation ([float, float, float, float]) of the color to use to represent missing data - cat_gap: int, optional, default: 1 - The vertical spacing (in cells) between categories - view_gap: int, optional, default: 1 - The horizontal spacing (in cell) between views + cat_gap: int or float, default: 0.1 + The vertical spacing (in cells if int or fraction of table height if + float) between categories + view_gap: int or float, default: 0.2 + The horizontal spacing (in cells if int or fraction of table width if + float) between views show_index: bool, default: True If True (default), will show row names next to rows in each view show_columns: bool, default: True @@ -349,17 +352,10 @@ def state( >>> engine = Animals() >>> fig = plt.figure(tight_layout=True, facecolor="#00000000") >>> ax = plt.gca() - >>> plot.state( - ... engine, - ... 7, - ... view_gap=13, - ... cat_gap=3, - ... ax=ax, - ... ) + >>> plot.state(engine, 7, ax=ax) >>> _ = plt.axis("off") >>> plt.show() - Render a satellites State, which has continuous, categorial and missing data @@ -388,6 +384,17 @@ def state( cmap = "gray_r" n_rows, n_cols = engine.shape + + if isinstance(cat_gap, float): + assert cat_gap >= 0.0 + assert cat_gap <= 1.0 + cat_gap = ceil(n_rows * cat_gap) + + if isinstance(view_gap, float): + assert view_gap >= 0.0 + assert view_gap <= 1.0 + view_gap = ceil(n_cols * view_gap) + col_asgn = engine.column_assignment(state_ix) row_asgns = engine.row_assignments(state_ix) @@ -470,7 +477,7 @@ def state( ax.text( col_start + view_counts[view_ix] / 2.0 - 0.5, - row_start + cat_counts[cat_ix] + cat_gap * 0.15, + row_start + cat_counts[cat_ix] + cat_gap * 0.25, f"$C_{{{cat_ix}}}$", ha="center", va="center",