Skip to content

Commit

Permalink
Improve cat/view spacing in pylace plot.state
Browse files Browse the repository at this point in the history
  • Loading branch information
Baxter Eaves committed Jan 12, 2024
1 parent d3a45a2 commit ad5843b
Showing 1 changed file with 22 additions and 15 deletions.
37 changes: 22 additions & 15 deletions pylace/lace/plot.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Plotting utilities."""

from math import ceil
from typing import Any, Dict, Optional, Union

import matplotlib.pyplot as plt
Expand Down Expand Up @@ -292,8 +293,8 @@ def state(
*,
cmap: Optional[str] = None,
missing_color=None,
cat_gap: int = 1,
view_gap: int = 1,
cat_gap: Union[float, int] = 0.1,
view_gap: Union[float, int] = 0.2,
show_index: bool = True,
show_columns: bool = True,
min_height: int = 0,
Expand All @@ -320,10 +321,12 @@ def state(
missing_color: optional, default: red
The RGBA array representation ([float, float, float, float]) of the
color to use to represent missing data
cat_gap: int, optional, default: 1
The vertical spacing (in cells) between categories
view_gap: int, optional, default: 1
The horizontal spacing (in cell) between views
cat_gap: int or float, default: 0.1
The vertical spacing (in cells if int or fraction of table height if
float) between categories
view_gap: int or float, default: 0.2
The horizontal spacing (in cells if int or fraction of table width if
float) between views
show_index: bool, default: True
If True (default), will show row names next to rows in each view
show_columns: bool, default: True
Expand All @@ -349,17 +352,10 @@ def state(
>>> engine = Animals()
>>> fig = plt.figure(tight_layout=True, facecolor="#00000000")
>>> ax = plt.gca()
>>> plot.state(
... engine,
... 7,
... view_gap=13,
... cat_gap=3,
... ax=ax,
... )
>>> plot.state(engine, 7, ax=ax)
>>> _ = plt.axis("off")
>>> plt.show()
Render a satellites State, which has continuous, categorial and
missing data
Expand Down Expand Up @@ -388,6 +384,17 @@ def state(
cmap = "gray_r"

n_rows, n_cols = engine.shape

if isinstance(cat_gap, float):
assert cat_gap >= 0.0
assert cat_gap <= 1.0
cat_gap = ceil(n_rows * cat_gap)

if isinstance(view_gap, float):
assert view_gap >= 0.0
assert view_gap <= 1.0
view_gap = ceil(n_cols * view_gap)

col_asgn = engine.column_assignment(state_ix)
row_asgns = engine.row_assignments(state_ix)

Expand Down Expand Up @@ -470,7 +477,7 @@ def state(

ax.text(
col_start + view_counts[view_ix] / 2.0 - 0.5,
row_start + cat_counts[cat_ix] + cat_gap * 0.15,
row_start + cat_counts[cat_ix] + cat_gap * 0.25,
f"$C_{{{cat_ix}}}$",
ha="center",
va="center",
Expand Down

0 comments on commit ad5843b

Please sign in to comment.