From 087543abbbf6cf3723ac113063d354a906214e4f Mon Sep 17 00:00:00 2001 From: Baxter Eaves Date: Thu, 30 Nov 2023 17:02:57 -0600 Subject: [PATCH 1/7] can plot all-numeric state --- pylace/lace/plot.py | 96 ++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 94 insertions(+), 2 deletions(-) diff --git a/pylace/lace/plot.py b/pylace/lace/plot.py index b4a7466b..6ffebc3e 100644 --- a/pylace/lace/plot.py +++ b/pylace/lace/plot.py @@ -7,6 +7,7 @@ import plotly.express as px import plotly.graph_objects as go import polars as pl +import matplotlib.pyplot as plt from lace import Engine @@ -209,7 +210,98 @@ def prediction_uncertainty( return fig +def _get_xs(engine, cat_rows, view_cols, compute_ps=False): + xs = engine[cat_rows, view_cols] + if isinstance(xs, (int, float)): + xs = np.array(xs) + if compute_ps: + ps = [0.0] + else: + xs = xs[:, 1:] + if compute_ps: + ps = engine.logp(xs) + xs = xs.to_numpy() + + if compute_ps: + return xs, ps + else: + return xs + +def state( + engine: Engine, + state_ix: int, + missing_color = None, + cat_gap:int = 1, + view_gap:int = 1, + ax = None, +): + if ax is None: + ax = plt.gca() + + n_rows, n_cols = engine.shape + col_asgn = engine.column_assignment(state_ix) + row_asgns = engine.row_assignments(state_ix) + + n_views = len(row_asgns) + max_cats = max(max(asgn) + 1 for asgn in row_asgns) + + dim_col = n_cols + (n_views - 1) * view_gap + dim_row = n_rows + (max_cats - 1) * cat_gap + + zs = np.zeros((dim_row, dim_col)) + 0.1 + + row_names = engine.index + col_names = engine.columns + + col_start = 0 + for view_ix, row_asgn in enumerate(row_asgns): + row_start = 0 + n_cats = max(row_asgn) + 1 + view_len = sum(z == view_ix for z in col_asgn) + view_cols = [i for i, z in enumerate(col_asgn) if z == view_ix] + + # sort columns within each view + ps = [] + for col in view_cols: + ps.append(engine.logp(engine[col][:, 1:]).sum()) + ixs = np.argsort(ps)[::-1] + view_cols = [view_cols[ix] for ix in ixs] + + for i, col in enumerate(view_cols): + ax.text(col_start + i, -1, col_names[col], ha='center', va='bottom', rotation=90) + + for cat_ix in range(n_cats): + cat_len = sum(z == cat_ix for z in row_asgn) + cat_rows = [i for i, z in enumerate(row_asgn) if z == cat_ix] + + xs, ps = _get_xs(engine, cat_rows, view_cols, compute_ps=True) + ixs = np.argsort(ps)[::-1] + cat_rows = [cat_rows[ix] for ix in ixs] + xs = _get_xs(engine, cat_rows, view_cols) + zs[row_start:row_start+cat_len, col_start:col_start + view_len] = xs + + # label rows + for i, row in enumerate(cat_rows): + ax.text(col_start-1, i + row_start, row_names[row], ha='right', va='center') + + row_start += cat_len + cat_gap + + col_start += view_len + view_gap + + + ax.matshow(zs, cmap='gray_r') + + + if __name__ == "__main__": - import doctest + from lace.examples import Animals + + eng = Animals() + plt.figure(tight_layout=True, facecolor='#e8e8e8') + ax = plt.gca() + state(eng, 1, view_gap=15, ax=ax) + plt.axis('off') + plt.show() + # import doctest - doctest.testmod() + # doctest.testmod() From 8b35406639c868d1cb0e50cd5312130e765f49fe Mon Sep 17 00:00:00 2001 From: Baxter Eaves Date: Fri, 1 Dec 2023 08:47:56 -0600 Subject: [PATCH 2/7] Filling out state plot (partial progress) --- pylace/lace/plot.py | 146 +++++++++++++++++++++++++++++++++++++------- 1 file changed, 123 insertions(+), 23 deletions(-) diff --git a/pylace/lace/plot.py b/pylace/lace/plot.py index 6ffebc3e..9c04c724 100644 --- a/pylace/lace/plot.py +++ b/pylace/lace/plot.py @@ -2,12 +2,12 @@ from typing import Dict, Optional, Union +import matplotlib.pyplot as plt import numpy as np import pandas as pd import plotly.express as px import plotly.graph_objects as go import polars as pl -import matplotlib.pyplot as plt from lace import Engine @@ -227,17 +227,65 @@ def _get_xs(engine, cat_rows, view_cols, compute_ps=False): else: return xs + +def _get_colors(engine: Engine, *, cmap: str = "gray_r", missing_color=None): + codebook = engine.codebook + + if missing_color is None: + missing_color = (1.0, 0.0, 0.2, 1.0) + + colors = {} + for col in engine.columns: + ftype = engine.ftype(col) + xs = engine[col][col] + if ftype == "Categorical": + valmap = codebook.value_map(col) + mapper = {v: k for k, v in enumerate(valmap.values())} + k = len(mapper) - 1 + + def _inner_norm(val): + return mapper[val] / k + + else: + xmin = xs.min() + xmax = xs.max() + + def _inner_norm(val): + return (val - xmin) / (xmax - xmin) + + colormap = plt.cm.get_cmap(cmap) + + def _norm(val): + if pd.isnull(val): + return missing_color + else: + return np.array(colormap(_inner_norm(val))) + + colors[col] = [_norm(x) for x in xs] + + df = pd.DataFrame(colors, index=engine.index) + return df[engine.columns] + + def state( engine: Engine, state_ix: int, - missing_color = None, - cat_gap:int = 1, - view_gap:int = 1, - ax = None, + *, + cmap: Optional[str] = None, + missing_color=None, + cat_gap: int = 1, + view_gap: int = 1, + show_index: bool = True, + show_columns: bool = True, + aspect=None, + ax=None, ): if ax is None: ax = plt.gca() + if cmap is None: + cmap = "gray_r" + n_rows, n_cols = engine.shape col_asgn = engine.column_assignment(state_ix) row_asgns = engine.row_assignments(state_ix) @@ -248,15 +296,21 @@ def state( dim_col = n_cols + (n_views - 1) * view_gap dim_row = n_rows + (max_cats - 1) * cat_gap - zs = np.zeros((dim_row, dim_col)) + 0.1 + zs = np.zeros((dim_row, dim_col, 4)) row_names = engine.index - col_names = engine.columns + col_names = engine.columns + + view_counts = np.bincount(col_asgn) + view_ixs = np.argsort(view_counts)[::-1] + + colors = _get_colors(engine, cmap=cmap, missing_color=missing_color) col_start = 0 - for view_ix, row_asgn in enumerate(row_asgns): + for view_ix in view_ixs: + row_asgn = row_asgns[view_ix] row_start = 0 - n_cats = max(row_asgn) + 1 + max(row_asgn) + 1 view_len = sum(z == view_ix for z in col_asgn) view_cols = [i for i, z in enumerate(col_asgn) if z == view_ix] @@ -267,40 +321,86 @@ def state( ixs = np.argsort(ps)[::-1] view_cols = [view_cols[ix] for ix in ixs] - for i, col in enumerate(view_cols): - ax.text(col_start + i, -1, col_names[col], ha='center', va='bottom', rotation=90) - - for cat_ix in range(n_cats): + if show_columns: + for i, col in enumerate(view_cols): + ax.text( + col_start + i, + -1, + col_names[col], + ha="center", + va="bottom", + rotation=90, + ) + + cat_counts = np.bincount(row_asgn) + cat_ixs = np.argsort(cat_counts)[::-1] + + for cat_ix in cat_ixs: cat_len = sum(z == cat_ix for z in row_asgn) cat_rows = [i for i, z in enumerate(row_asgn) if z == cat_ix] xs, ps = _get_xs(engine, cat_rows, view_cols, compute_ps=True) ixs = np.argsort(ps)[::-1] cat_rows = [cat_rows[ix] for ix in ixs] - xs = _get_xs(engine, cat_rows, view_cols) - zs[row_start:row_start+cat_len, col_start:col_start + view_len] = xs + + cs = colors.iloc[cat_rows, view_cols].values.tolist() + cs = np.asarray(cs, dtype=float) + zs[ + row_start : row_start + cat_len, + col_start : col_start + view_len, + ] = cs # label rows - for i, row in enumerate(cat_rows): - ax.text(col_start-1, i + row_start, row_names[row], ha='right', va='center') + if show_index: + for i, row in enumerate(cat_rows): + ax.text( + col_start - 1, + i + row_start, + row_names[row], + ha="right", + va="center", + ) + + ax.text( + col_start + view_counts[view_ix] / 2, + row_start + cat_counts[cat_ix], + f"$C_{{{cat_ix}}}$", + ha="center", + va="top", + ) row_start += cat_len + cat_gap + ax.text( + col_start + view_counts[view_ix] / 2, + dim_row + cat_gap, + f"$V_{{{view_ix}}}$", + ha="center", + va="top", + ) col_start += view_len + view_gap - - ax.matshow(zs, cmap='gray_r') - + ax.matshow(zs, cmap="gray_r", aspect=aspect) if __name__ == "__main__": from lace.examples import Animals eng = Animals() - plt.figure(tight_layout=True, facecolor='#e8e8e8') + # eng = Satellites() + plt.figure(tight_layout=True, facecolor="Gainsboro") ax = plt.gca() - state(eng, 1, view_gap=15, ax=ax) - plt.axis('off') + state( + eng, + 1, + view_gap=15, + cat_gap=2, + ax=ax, + show_index=True, + show_columns=True, + cmap="cubehelix", + ) + plt.axis("off") plt.show() # import doctest From b1c32ccb1c3c3023cf85847e566d438fece7dc46 Mon Sep 17 00:00:00 2001 From: Baxter Eaves Date: Mon, 4 Dec 2023 09:38:29 -0600 Subject: [PATCH 3/7] Fix random order in pylace column indexing the order of rows when single indexing, e.g., `engine[column]` would be differente each time you initialized the engine, and would not correspond to the order of rows in the row index. Fixed it. --- pylace/src/utils.rs | 15 ++++++++++----- pylace/tests/test_indexing.py | 13 +++++++++++++ 2 files changed, 23 insertions(+), 5 deletions(-) diff --git a/pylace/src/utils.rs b/pylace/src/utils.rs index 55e063c6..fd81916f 100644 --- a/pylace/src/utils.rs +++ b/pylace/src/utils.rs @@ -152,14 +152,19 @@ impl<'s> TableIndex<'s> { ) -> PyResult<(Vec<(usize, String)>, Vec<(usize, String)>)> { match self { Self::Single(ixs) => { - let row_ixs = codebook - .row_names - .iter() - .map(|(a, &b)| (b, a.clone())) + // let row_ixs = codebook + // .row_names + // .iter() + // .map(|(a, &b)| (b, a.clone())) + // .collect(); + let row_ixs: PyResult> = (0..codebook + .n_rows()) + .map(|ix| IntOrString::Int(ix as isize)) + .map(|ix| ix.row_ix(codebook)) .collect(); let col_ixs = ixs.col_ixs(codebook)?; - Ok((row_ixs, col_ixs)) + Ok((row_ixs?, col_ixs)) } Self::Tuple(row_ixs, col_ixs) => { col_ixs.col_ixs(codebook).and_then(|cixs| { diff --git a/pylace/tests/test_indexing.py b/pylace/tests/test_indexing.py index ca31084d..9d5f3bbd 100644 --- a/pylace/tests/test_indexing.py +++ b/pylace/tests/test_indexing.py @@ -22,6 +22,19 @@ def _random_index(n, strs): return random.choice(strs) +@pytest.mark.parametrize("target", ["black", "swims", 12, 45]) +def test_single_index_consistency(target): + # for some reason, the order of the row indeices iterator was different + # each time we read in the metadata -- probably because of some random + # state initialization in a hashmap. Not sure why that would happen, but + # we fixed that particular issue. + a1 = Animals() + a2 = Animals() + xs = a1[target][:, 1] + ys = a2[target][:, 1] + assert all(x == y for x, y in zip(xs, ys)) + + def tests_index_positive_int_tuple(animals): assert animals[0, 0] == 0 assert animals[2, 3] == 0 From 128100bdc10579e7651f8648a8135030f7d53811 Mon Sep 17 00:00:00 2001 From: Baxter Eaves Date: Mon, 4 Dec 2023 10:37:38 -0600 Subject: [PATCH 4/7] Add plot.state fn to plot PCC states in pylace --- pylace/lace/plot.py | 224 +++++++++++++++++++++++++++++++++----------- 1 file changed, 167 insertions(+), 57 deletions(-) diff --git a/pylace/lace/plot.py b/pylace/lace/plot.py index 9c04c724..525f94ae 100644 --- a/pylace/lace/plot.py +++ b/pylace/lace/plot.py @@ -1,4 +1,4 @@ -"""Plottling utilities.""" +"""Plotting utilities""" from typing import Dict, Optional, Union @@ -212,14 +212,36 @@ def prediction_uncertainty( def _get_xs(engine, cat_rows, view_cols, compute_ps=False): xs = engine[cat_rows, view_cols] + + if xs is None: + if compute_ps: + return np.array([None]), [0.0] + else: + return np.array([None]) + if isinstance(xs, (int, float)): xs = np.array(xs) if compute_ps: ps = [0.0] else: xs = xs[:, 1:] + if compute_ps: - ps = engine.logp(xs) + n_rows, n_cols = xs.shape + ps = [] + for row_ix in range(n_rows): + row = xs[row_ix, :] + null_count = row.null_count() + to_drop = [c for c in row.columns if null_count[0, c] > 0] + k = n_cols - len(to_drop) + if k == 0: + ps.append(float("-inf")) + continue + + row = row.drop(to_drop) + p = engine.logp(row) / k + ps.append(p) + xs = xs.to_numpy() if compute_ps: @@ -228,43 +250,52 @@ def _get_xs(engine, cat_rows, view_cols, compute_ps=False): return xs +def _makenorm(cmap, missing_color, *, mapper=None, xlim=None): + if mapper is not None: + + def _fn(val): + k = len(mapper) - 1 + return mapper[val] / k + + else: + xmin, xmax = xlim + + def _fn(val): + return (val - xmin) / (xmax - xmin) + + def _norm(val): + colormap = plt.cm._colormaps[cmap] + if pd.isnull(val): + return missing_color + else: + return np.array(colormap(_fn(val))) + + return _norm + + def _get_colors(engine: Engine, *, cmap: str = "gray_r", missing_color=None): codebook = engine.codebook if missing_color is None: - missing_color = (1.0, 0.0, 0.2, 1.0) + missing_color = np.array([1.0, 0.0, 0.2, 1.0]) - colors = {} - for col in engine.columns: + n_rows, n_cols = engine.shape + colors = np.zeros((n_rows, n_cols, 4)) + + for i, col in enumerate(engine.columns): ftype = engine.ftype(col) xs = engine[col][col] if ftype == "Categorical": valmap = codebook.value_map(col) mapper = {v: k for k, v in enumerate(valmap.values())} - k = len(mapper) - 1 - - def _inner_norm(val): - return mapper[val] / k + _norm = _makenorm(cmap, missing_color, mapper=mapper) else: - xmin = xs.min() - xmax = xs.max() - - def _inner_norm(val): - return (val - xmin) / (xmax - xmin) - - colormap = plt.cm.get_cmap(cmap) + _norm = _makenorm(cmap, missing_color, xlim=(xs.min(), xs.max())) - def _norm(val): - if pd.isnull(val): - return missing_color - else: - return np.array(colormap(_inner_norm(val))) + colors[:, i] = np.array([_norm(x) for x in xs]) - colors[col] = [_norm(x) for x in xs] - - df = pd.DataFrame(colors, index=engine.index) - return df[engine.columns] + return colors def state( @@ -277,9 +308,92 @@ def state( view_gap: int = 1, show_index: bool = True, show_columns: bool = True, + min_height: int = 0, + min_width: int = 0, aspect=None, ax=None, ): + """ + Plot a Lace state. + + View are sorted from largest (most columns) to smallest. Within views, + columns are sorted from highest (left) to lowest total likelihood. + Categories are sorted from largest (most rows) to smallest. Within + categories, rows are sorted from highest (top) to lowest log likelihood. + + Parameters + ---------- + engine: Engine + The engine containing the states to plot + state_ix: int + The index of the state to plot + cmap: str, optional, default: gray_r + The color map to use for present data + missing_color: optional, default: red + The RGBA array representation ([float, float, float, float]) of the + color to use to represent missing data + cat_gap: int, optional, default: 1 + The vertical spacing (in cells) between categories + view_gap: int, optional, default: 1 + The horizontal spacing (in cell) between views + show_index: bool, default: True + If True (default), will show row names next to rows in each view + show_columns: bool, default: True + If True (default), will show columns names above each column + min_height: int, default: 0 + The minimum height in cells of the state render. Padding will be added + to the lower part of the image. + min_width: int (default: 0) + The minimum width in cells of the state render. Padding will be added + to the right of the image. + aspect: {'equal', 'auto'} or float or None, default: None + matplotlib imshow aspect + ax: matplotlib.Axis, optional + The axis on which to plot + + Examples + -------- + + Render an animals state + + >>> import matplotlib.pyplot as plt + >>> from lace.examples import Animals + >>> from lace import plot + >>> engine = Animals() + >>> fig = plt.figure(tight_layout=True, facecolor="#00000000") + >>> ax = plt.gca() + >>> plot.state( + ... engine, + ... 7, + ... view_gap=13, + ... cat_gap=3, + ... ax=ax, + ... ) + >>> _ = plt.axis("off") + >>> plt.show() + + + Render a satellites State, which has continuous, categorial and + missing data + + >>> from lace.examples import Satellites + >>> engine = Satellites() + >>> fig = plt.figure(tight_layout=True, facecolor="#00000000") + >>> ax = plt.gca() + >>> plot.state( + ... engine, + ... 1, + ... view_gap=2, + ... cat_gap=100, + ... show_index=False, + ... show_columns=False, + ... ax=ax, + ... cmap="YlGnBu", + ... aspect="auto" + ... ) + >>> _ = plt.axis("off") + >>> plt.show() + """ if ax is None: ax = plt.gca() @@ -310,14 +424,16 @@ def state( for view_ix in view_ixs: row_asgn = row_asgns[view_ix] row_start = 0 - max(row_asgn) + 1 - view_len = sum(z == view_ix for z in col_asgn) view_cols = [i for i, z in enumerate(col_asgn) if z == view_ix] + view_len = len(view_cols) # sort columns within each view ps = [] for col in view_cols: - ps.append(engine.logp(engine[col][:, 1:]).sum()) + xs = engine[col][:, 1:].drop_nulls() + if xs.shape[0] == 0: + ps.append(float("-inf")) + ps.append(engine.logp(xs).sum() / xs.shape[0]) ixs = np.argsort(ps)[::-1] view_cols = [view_cols[ix] for ix in ixs] @@ -332,19 +448,23 @@ def state( rotation=90, ) + max(row_asgn) + 1 cat_counts = np.bincount(row_asgn) cat_ixs = np.argsort(cat_counts)[::-1] for cat_ix in cat_ixs: - cat_len = sum(z == cat_ix for z in row_asgn) cat_rows = [i for i, z in enumerate(row_asgn) if z == cat_ix] + cat_len = len(cat_rows) xs, ps = _get_xs(engine, cat_rows, view_cols, compute_ps=True) ixs = np.argsort(ps)[::-1] cat_rows = [cat_rows[ix] for ix in ixs] - cs = colors.iloc[cat_rows, view_cols].values.tolist() - cs = np.asarray(cs, dtype=float) + cs = np.zeros((len(cat_rows), len(view_cols), 4)) + for iix, i in enumerate(cat_rows): + for jix, j in enumerate(view_cols): + cs[iix, jix] = colors[i, j] + # cs = colors[cat_rows, view_cols, :] zs[ row_start : row_start + cat_len, col_start : col_start + view_len, @@ -362,46 +482,36 @@ def state( ) ax.text( - col_start + view_counts[view_ix] / 2, - row_start + cat_counts[cat_ix], + col_start + view_counts[view_ix] / 2.0 - 0.5, + row_start + cat_counts[cat_ix] + cat_gap * 0.15, f"$C_{{{cat_ix}}}$", ha="center", - va="top", + va="center", ) row_start += cat_len + cat_gap ax.text( - col_start + view_counts[view_ix] / 2, + col_start + view_counts[view_ix] / 2.0 - 0.5, dim_row + cat_gap, f"$V_{{{view_ix}}}$", - ha="center", + ha="left", va="top", ) col_start += view_len + view_gap - ax.matshow(zs, cmap="gray_r", aspect=aspect) + if min_height > zs.shape[0]: + margin = min_height - zs.shape[0] + zs = np.vstack((zs, np.zeros((margin, zs.shape[1], 4)))) + + if min_width > zs.shape[1]: + margin = min_width - zs.shape[1] + zs = np.hstack((zs, np.zeros((zs.shape[0], margin, 4)))) + + ax.matshow(zs, aspect=aspect) if __name__ == "__main__": - from lace.examples import Animals - - eng = Animals() - # eng = Satellites() - plt.figure(tight_layout=True, facecolor="Gainsboro") - ax = plt.gca() - state( - eng, - 1, - view_gap=15, - cat_gap=2, - ax=ax, - show_index=True, - show_columns=True, - cmap="cubehelix", - ) - plt.axis("off") - plt.show() - # import doctest + import doctest - # doctest.testmod() + doctest.testmod() From 2cad074b533522bef7469761cc554b4437f99c9c Mon Sep 17 00:00:00 2001 From: Baxter Eaves Date: Mon, 4 Dec 2023 10:40:16 -0600 Subject: [PATCH 5/7] Updated changelog --- CHANGELOG.md | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0767fdcb..bae04729 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Added + +- Added `plot.state` function in pylace to render PCC states + +### Fixed + +- Fixed issue that would cause random row order when indexing pylace Engines by a single (column) index, e.g., engine['column'] would return the columns in a different order every time the engine was loaded + ## [python-0.5.0] - 2023-11-20 ### Added From feb9c8feff5895d0b86216f82444522de5c64f99 Mon Sep 17 00:00:00 2001 From: Baxter Eaves Date: Mon, 4 Dec 2023 10:58:27 -0600 Subject: [PATCH 6/7] Fix root issue of indexing bug --- lace/lace_codebook/src/codebook.rs | 4 ++-- pylace/src/utils.rs | 17 ++++++----------- 2 files changed, 8 insertions(+), 13 deletions(-) diff --git a/lace/lace_codebook/src/codebook.rs b/lace/lace_codebook/src/codebook.rs index 129ae166..11f2c798 100644 --- a/lace/lace_codebook/src/codebook.rs +++ b/lace/lace_codebook/src/codebook.rs @@ -137,8 +137,8 @@ impl RowNameList { .map(|row_name| self.row_names.push(row_name)) } - pub fn iter(&self) -> std::collections::hash_map::Iter { - self.index_lookup.iter() + pub fn iter(&self) -> std::iter::Enumerate> { + self.row_names.iter().enumerate() } pub fn remove(&mut self, row_name: &str) -> bool { diff --git a/pylace/src/utils.rs b/pylace/src/utils.rs index fd81916f..aa17ddaf 100644 --- a/pylace/src/utils.rs +++ b/pylace/src/utils.rs @@ -152,19 +152,14 @@ impl<'s> TableIndex<'s> { ) -> PyResult<(Vec<(usize, String)>, Vec<(usize, String)>)> { match self { Self::Single(ixs) => { - // let row_ixs = codebook - // .row_names - // .iter() - // .map(|(a, &b)| (b, a.clone())) - // .collect(); - let row_ixs: PyResult> = (0..codebook - .n_rows()) - .map(|ix| IntOrString::Int(ix as isize)) - .map(|ix| ix.row_ix(codebook)) + let row_ixs = codebook + .row_names + .iter() + .map(|(a, b)| (a, b.clone())) .collect(); let col_ixs = ixs.col_ixs(codebook)?; - Ok((row_ixs?, col_ixs)) + Ok((row_ixs, col_ixs)) } Self::Tuple(row_ixs, col_ixs) => { col_ixs.col_ixs(codebook).and_then(|cixs| { @@ -411,7 +406,7 @@ impl Indexer { pub(crate) fn rows(codebook: &Codebook) -> Self { let mut to_ix: HashMap = HashMap::new(); let mut to_name: HashMap = HashMap::new(); - codebook.row_names.iter().for_each(|(name, &ix)| { + codebook.row_names.iter().for_each(|(ix, name)| { to_ix.insert(name.clone(), ix); to_name.insert(ix, name.clone()); }); From 08a0522a559e561ad78f1e853765e671fa5e9dbe Mon Sep 17 00:00:00 2001 From: Baxter Eaves Date: Mon, 4 Dec 2023 11:04:15 -0600 Subject: [PATCH 7/7] lints --- pylace/lace/plot.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pylace/lace/plot.py b/pylace/lace/plot.py index 525f94ae..4049f59c 100644 --- a/pylace/lace/plot.py +++ b/pylace/lace/plot.py @@ -1,4 +1,4 @@ -"""Plotting utilities""" +"""Plotting utilities.""" from typing import Dict, Optional, Union @@ -353,7 +353,6 @@ def state( Examples -------- - Render an animals state >>> import matplotlib.pyplot as plt