diff --git a/CHANGELOG.md b/CHANGELOG.md index 0ab854d6..3351ac1b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,10 +7,18 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Added + +- Added `plot.state` function in pylace to render PCC states + ### Changed - Updated all packages to have the correct SPDX for the Business Source License +### Fixed + +- Fixed issue that would cause random row order when indexing pylace Engines by a single (column) index, e.g., engine['column'] would return the columns in a different order every time the engine was loaded + ## [python-0.5.0] - 2023-11-20 ### Added diff --git a/lace/lace_codebook/src/codebook.rs b/lace/lace_codebook/src/codebook.rs index 129ae166..11f2c798 100644 --- a/lace/lace_codebook/src/codebook.rs +++ b/lace/lace_codebook/src/codebook.rs @@ -137,8 +137,8 @@ impl RowNameList { .map(|row_name| self.row_names.push(row_name)) } - pub fn iter(&self) -> std::collections::hash_map::Iter { - self.index_lookup.iter() + pub fn iter(&self) -> std::iter::Enumerate> { + self.row_names.iter().enumerate() } pub fn remove(&mut self, row_name: &str) -> bool { diff --git a/pylace/lace/plot.py b/pylace/lace/plot.py index b4a7466b..4049f59c 100644 --- a/pylace/lace/plot.py +++ b/pylace/lace/plot.py @@ -1,7 +1,8 @@ -"""Plottling utilities.""" +"""Plotting utilities.""" from typing import Dict, Optional, Union +import matplotlib.pyplot as plt import numpy as np import pandas as pd import plotly.express as px @@ -209,6 +210,306 @@ def prediction_uncertainty( return fig +def _get_xs(engine, cat_rows, view_cols, compute_ps=False): + xs = engine[cat_rows, view_cols] + + if xs is None: + if compute_ps: + return np.array([None]), [0.0] + else: + return np.array([None]) + + if isinstance(xs, (int, float)): + xs = np.array(xs) + if compute_ps: + ps = [0.0] + else: + xs = xs[:, 1:] + + if compute_ps: + n_rows, n_cols = xs.shape + ps = [] + for row_ix in range(n_rows): + row = xs[row_ix, :] + null_count = row.null_count() + to_drop = [c for c in row.columns if null_count[0, c] > 0] + k = n_cols - len(to_drop) + if k == 0: + ps.append(float("-inf")) + continue + + row = row.drop(to_drop) + p = engine.logp(row) / k + ps.append(p) + + xs = xs.to_numpy() + + if compute_ps: + return xs, ps + else: + return xs + + +def _makenorm(cmap, missing_color, *, mapper=None, xlim=None): + if mapper is not None: + + def _fn(val): + k = len(mapper) - 1 + return mapper[val] / k + + else: + xmin, xmax = xlim + + def _fn(val): + return (val - xmin) / (xmax - xmin) + + def _norm(val): + colormap = plt.cm._colormaps[cmap] + if pd.isnull(val): + return missing_color + else: + return np.array(colormap(_fn(val))) + + return _norm + + +def _get_colors(engine: Engine, *, cmap: str = "gray_r", missing_color=None): + codebook = engine.codebook + + if missing_color is None: + missing_color = np.array([1.0, 0.0, 0.2, 1.0]) + + n_rows, n_cols = engine.shape + colors = np.zeros((n_rows, n_cols, 4)) + + for i, col in enumerate(engine.columns): + ftype = engine.ftype(col) + xs = engine[col][col] + if ftype == "Categorical": + valmap = codebook.value_map(col) + mapper = {v: k for k, v in enumerate(valmap.values())} + + _norm = _makenorm(cmap, missing_color, mapper=mapper) + else: + _norm = _makenorm(cmap, missing_color, xlim=(xs.min(), xs.max())) + + colors[:, i] = np.array([_norm(x) for x in xs]) + + return colors + + +def state( + engine: Engine, + state_ix: int, + *, + cmap: Optional[str] = None, + missing_color=None, + cat_gap: int = 1, + view_gap: int = 1, + show_index: bool = True, + show_columns: bool = True, + min_height: int = 0, + min_width: int = 0, + aspect=None, + ax=None, +): + """ + Plot a Lace state. + + View are sorted from largest (most columns) to smallest. Within views, + columns are sorted from highest (left) to lowest total likelihood. + Categories are sorted from largest (most rows) to smallest. Within + categories, rows are sorted from highest (top) to lowest log likelihood. + + Parameters + ---------- + engine: Engine + The engine containing the states to plot + state_ix: int + The index of the state to plot + cmap: str, optional, default: gray_r + The color map to use for present data + missing_color: optional, default: red + The RGBA array representation ([float, float, float, float]) of the + color to use to represent missing data + cat_gap: int, optional, default: 1 + The vertical spacing (in cells) between categories + view_gap: int, optional, default: 1 + The horizontal spacing (in cell) between views + show_index: bool, default: True + If True (default), will show row names next to rows in each view + show_columns: bool, default: True + If True (default), will show columns names above each column + min_height: int, default: 0 + The minimum height in cells of the state render. Padding will be added + to the lower part of the image. + min_width: int (default: 0) + The minimum width in cells of the state render. Padding will be added + to the right of the image. + aspect: {'equal', 'auto'} or float or None, default: None + matplotlib imshow aspect + ax: matplotlib.Axis, optional + The axis on which to plot + + Examples + -------- + Render an animals state + + >>> import matplotlib.pyplot as plt + >>> from lace.examples import Animals + >>> from lace import plot + >>> engine = Animals() + >>> fig = plt.figure(tight_layout=True, facecolor="#00000000") + >>> ax = plt.gca() + >>> plot.state( + ... engine, + ... 7, + ... view_gap=13, + ... cat_gap=3, + ... ax=ax, + ... ) + >>> _ = plt.axis("off") + >>> plt.show() + + + Render a satellites State, which has continuous, categorial and + missing data + + >>> from lace.examples import Satellites + >>> engine = Satellites() + >>> fig = plt.figure(tight_layout=True, facecolor="#00000000") + >>> ax = plt.gca() + >>> plot.state( + ... engine, + ... 1, + ... view_gap=2, + ... cat_gap=100, + ... show_index=False, + ... show_columns=False, + ... ax=ax, + ... cmap="YlGnBu", + ... aspect="auto" + ... ) + >>> _ = plt.axis("off") + >>> plt.show() + """ + if ax is None: + ax = plt.gca() + + if cmap is None: + cmap = "gray_r" + + n_rows, n_cols = engine.shape + col_asgn = engine.column_assignment(state_ix) + row_asgns = engine.row_assignments(state_ix) + + n_views = len(row_asgns) + max_cats = max(max(asgn) + 1 for asgn in row_asgns) + + dim_col = n_cols + (n_views - 1) * view_gap + dim_row = n_rows + (max_cats - 1) * cat_gap + + zs = np.zeros((dim_row, dim_col, 4)) + + row_names = engine.index + col_names = engine.columns + + view_counts = np.bincount(col_asgn) + view_ixs = np.argsort(view_counts)[::-1] + + colors = _get_colors(engine, cmap=cmap, missing_color=missing_color) + + col_start = 0 + for view_ix in view_ixs: + row_asgn = row_asgns[view_ix] + row_start = 0 + view_cols = [i for i, z in enumerate(col_asgn) if z == view_ix] + view_len = len(view_cols) + + # sort columns within each view + ps = [] + for col in view_cols: + xs = engine[col][:, 1:].drop_nulls() + if xs.shape[0] == 0: + ps.append(float("-inf")) + ps.append(engine.logp(xs).sum() / xs.shape[0]) + ixs = np.argsort(ps)[::-1] + view_cols = [view_cols[ix] for ix in ixs] + + if show_columns: + for i, col in enumerate(view_cols): + ax.text( + col_start + i, + -1, + col_names[col], + ha="center", + va="bottom", + rotation=90, + ) + + max(row_asgn) + 1 + cat_counts = np.bincount(row_asgn) + cat_ixs = np.argsort(cat_counts)[::-1] + + for cat_ix in cat_ixs: + cat_rows = [i for i, z in enumerate(row_asgn) if z == cat_ix] + cat_len = len(cat_rows) + + xs, ps = _get_xs(engine, cat_rows, view_cols, compute_ps=True) + ixs = np.argsort(ps)[::-1] + cat_rows = [cat_rows[ix] for ix in ixs] + + cs = np.zeros((len(cat_rows), len(view_cols), 4)) + for iix, i in enumerate(cat_rows): + for jix, j in enumerate(view_cols): + cs[iix, jix] = colors[i, j] + # cs = colors[cat_rows, view_cols, :] + zs[ + row_start : row_start + cat_len, + col_start : col_start + view_len, + ] = cs + + # label rows + if show_index: + for i, row in enumerate(cat_rows): + ax.text( + col_start - 1, + i + row_start, + row_names[row], + ha="right", + va="center", + ) + + ax.text( + col_start + view_counts[view_ix] / 2.0 - 0.5, + row_start + cat_counts[cat_ix] + cat_gap * 0.15, + f"$C_{{{cat_ix}}}$", + ha="center", + va="center", + ) + + row_start += cat_len + cat_gap + + ax.text( + col_start + view_counts[view_ix] / 2.0 - 0.5, + dim_row + cat_gap, + f"$V_{{{view_ix}}}$", + ha="left", + va="top", + ) + col_start += view_len + view_gap + + if min_height > zs.shape[0]: + margin = min_height - zs.shape[0] + zs = np.vstack((zs, np.zeros((margin, zs.shape[1], 4)))) + + if min_width > zs.shape[1]: + margin = min_width - zs.shape[1] + zs = np.hstack((zs, np.zeros((zs.shape[0], margin, 4)))) + + ax.matshow(zs, aspect=aspect) + + if __name__ == "__main__": import doctest diff --git a/pylace/src/utils.rs b/pylace/src/utils.rs index 55e063c6..aa17ddaf 100644 --- a/pylace/src/utils.rs +++ b/pylace/src/utils.rs @@ -155,7 +155,7 @@ impl<'s> TableIndex<'s> { let row_ixs = codebook .row_names .iter() - .map(|(a, &b)| (b, a.clone())) + .map(|(a, b)| (a, b.clone())) .collect(); let col_ixs = ixs.col_ixs(codebook)?; @@ -406,7 +406,7 @@ impl Indexer { pub(crate) fn rows(codebook: &Codebook) -> Self { let mut to_ix: HashMap = HashMap::new(); let mut to_name: HashMap = HashMap::new(); - codebook.row_names.iter().for_each(|(name, &ix)| { + codebook.row_names.iter().for_each(|(ix, name)| { to_ix.insert(name.clone(), ix); to_name.insert(ix, name.clone()); }); diff --git a/pylace/tests/test_indexing.py b/pylace/tests/test_indexing.py index ca31084d..9d5f3bbd 100644 --- a/pylace/tests/test_indexing.py +++ b/pylace/tests/test_indexing.py @@ -22,6 +22,19 @@ def _random_index(n, strs): return random.choice(strs) +@pytest.mark.parametrize("target", ["black", "swims", 12, 45]) +def test_single_index_consistency(target): + # for some reason, the order of the row indeices iterator was different + # each time we read in the metadata -- probably because of some random + # state initialization in a hashmap. Not sure why that would happen, but + # we fixed that particular issue. + a1 = Animals() + a2 = Animals() + xs = a1[target][:, 1] + ys = a2[target][:, 1] + assert all(x == y for x, y in zip(xs, ys)) + + def tests_index_positive_int_tuple(animals): assert animals[0, 0] == 0 assert animals[2, 3] == 0