Skip to content

Commit

Permalink
Merge pull request #179 from promised-ai/feature/flat-columns-in-pyla…
Browse files Browse the repository at this point in the history
…ce-ctor

Feature/flat columns in pylace ctor
  • Loading branch information
BaxterEaves authored Feb 5, 2024
2 parents 67148b3 + d75b7a7 commit 06fbf54
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 7 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `DataParseError::CodebookAndDataRowsMismatch` variant for when the number of rows in the codebook and the number of rows in the data do not match.
- `DataParseError::DataFrameMissingColumn` variant for when a column is in the codebook but not in the initial dataframe.
- Python's `Engine.update` uses `tqdm.auto` for progress bar reporting.
- Added `flat_columns` option to pylace `Engine` constructor to enable creating engines with one view

### Changed
- Added parallelism to `Slice` row reassignment kernel. Run time is ~6x faster.
Expand Down
27 changes: 21 additions & 6 deletions pylace/lace/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def from_df(
n_states: int = 8,
id_offset: int = 0,
rng_seed: Optional[int] = None,
flat_columns: bool = False,
) -> "Engine":
"""
Create a new ``Engine`` from a DataFrame.
Expand All @@ -85,7 +86,12 @@ def from_df(
files may be merged by copying without name collisions.
rng_seed: int, optional
Random number generator seed.
flat_columns: bool
Initialize all states with one view. Use when you do not want to
do inference over the assignment of columns to views. Note that to
keep the states flat you will have to either use the `flat`
transition set or manually create a transition set that does not
update the column assignments when updating.
Examples
--------
Expand All @@ -106,10 +112,18 @@ def from_df(
... "ID": [1, 2, 3, 4],
... "list_b": [2.0, 4.0, 6.0, 8.0],
... })
>>> engine = Engine.from_df(df, CodebookBuilder.infer(
>>> engine = Engine.from_df(df, codebook=CodebookBuilder.infer(
... cat_cutoff=2,
... ))
Create an engine with flat column structure (one view)
>>> from lace.examples import Animals
>>> df = Animals().df
>>> n_states = 8
>>> engine = Engine.from_df(df, n_states=n_states, flat_columns=True)
>>> [max(engine.column_assignment(i)) for i in range(n_states)]
[0, 0, 0, 0, 0, 0, 0, 0]
"""
if isinstance(df, pd.DataFrame):
df.index.rename("ID", inplace=True)
Expand All @@ -128,6 +142,7 @@ def from_df(
n_states,
id_offset,
rng_seed,
flat_columns,
)
)

Expand Down Expand Up @@ -1012,7 +1027,7 @@ def update(
if isinstance(transitions, str):
transitions = utils._get_common_transitions(transitions)

update_handler = None if quiet else TqdmUpdateHandler()
update_handler = None if quiet else _TqdmUpdateHandler()

return self.engine.update(
n_iters,
Expand All @@ -1027,8 +1042,8 @@ def entropy(self, cols, n_mc_samples: int = 1000):
"""
Estimate the entropy or joint entropy of one or more features.
Prameters
---------
Parameters
----------
col: column indices
The columns for which to compute entropy
n_mc_samples: int
Expand Down Expand Up @@ -2342,7 +2357,7 @@ def clustermap(
return ClusterMap(df, linkage)


class TqdmUpdateHandler:
class _TqdmUpdateHandler:
def __init__(self):
self._t = tqdm()

Expand Down
8 changes: 7 additions & 1 deletion pylace/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ impl CoreEngine {
n_states=16,
id_offset=0,
rng_seed=None,
flat_columns=false,
)
)]
fn new(
Expand All @@ -74,6 +75,7 @@ impl CoreEngine {
n_states: usize,
id_offset: usize,
rng_seed: Option<u64>,
flat_columns: bool,
) -> PyResult<CoreEngine> {
let dataframe = dataframe.0;
let codebook =
Expand All @@ -86,7 +88,7 @@ impl CoreEngine {
Xoshiro256Plus::from_entropy()
};

let engine = lace::Engine::new(
let mut engine = lace::Engine::new(
n_states,
codebook,
data_source,
Expand All @@ -96,6 +98,10 @@ impl CoreEngine {
.map_err(|err| err.to_string())
.map_err(PyErr::new::<PyValueError, _>)?;

if flat_columns {
engine.flatten_cols();
}

Ok(Self {
col_indexer: Indexer::columns(&engine.codebook),
row_indexer: Indexer::rows(&engine.codebook),
Expand Down

0 comments on commit 06fbf54

Please sign in to comment.