Skip to content

Commit

Permalink
Merge pull request #176 from promised-ai/feature/pylace-update-handler
Browse files Browse the repository at this point in the history
feat(python): Added tqdm handler for `Engine.update` progress bar.
  • Loading branch information
schmidmt authored Feb 1, 2024
2 parents 965d522 + 60e9a7d commit 9135670
Show file tree
Hide file tree
Showing 8 changed files with 193 additions and 12 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added
- `DataParseError::CodebookAndDataRowsMismatch` variant for when the number of rows in the codebook and the number of rows in the data do not match.
- `DataParseError::DataFrameMissingColumn` variant for when a column is in the codebook but not in the initial dataframe.
- Python's `Engine.update` uses `tqdm.auto` for progress bar reporting.

### Fixed
- Initializing an engine with a codebook that has a different number of rows than the data will result in an error instead of printing a bunch on nonsense.
Expand Down
2 changes: 1 addition & 1 deletion pylace/docs/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ commonmark==0.9.1
hypothesis==6.65.2
maturin==0.14.10
mock==1.0.1
pillow==10.1.0
pillow==10.2.0
pytest-cov==4.0.0
pytest-xdist==3.1.0
pytest==7.2.0
Expand Down
1 change: 0 additions & 1 deletion pylace/lace/analysis.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""Tools for analysis of probabilistic cross-categorization results in Lace."""


import enum
import itertools as it
from copy import deepcopy
Expand Down
32 changes: 31 additions & 1 deletion pylace/lace/engine.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
"""The main interface to Lace models."""

import itertools as it
from os import PathLike
from typing import TYPE_CHECKING, Dict, List, Optional, Union

import pandas as pd
import plotly.express as px
import polars as pl
from tqdm.auto import tqdm

from lace import core, utils
from lace.codebook import Codebook
Expand Down Expand Up @@ -1010,13 +1012,15 @@ def update(
if isinstance(transitions, str):
transitions = utils._get_common_transitions(transitions)

update_handler = None if quiet else TqdmUpdateHandler()

return self.engine.update(
n_iters,
timeout=timeout,
checkpoint=checkpoint,
transitions=transitions,
save_path=save_path,
quiet=quiet,
update_handler=update_handler,
)

def entropy(self, cols, n_mc_samples: int = 1000):
Expand Down Expand Up @@ -2336,3 +2340,29 @@ def clustermap(
return ClusterMap(df, linkage, fig)
else:
return ClusterMap(df, linkage)


class TqdmUpdateHandler:
def __init__(self):
self._t = tqdm()

def global_init(self, config):
self._t.reset(config.n_iters * config.n_states)

def new_state_init(self, state_id):
pass

def state_updated(self, state_id):
self._t.update(1)

def state_complete(self, state_id):
pass

def stop_engine(self):
return False

def stop_state(self):
return False

def finalize(self):
self._t.close()
28 changes: 19 additions & 9 deletions pylace/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ mod component;
mod df;
mod metadata;
mod transition;
mod update_handler;
mod utils;

use std::cmp::Ordering;
Expand All @@ -22,6 +23,7 @@ use rand_xoshiro::Xoshiro256Plus;

use metadata::{Codebook, CodebookBuilder};

use crate::update_handler::PyUpdateHandler;
use crate::utils::*;

#[derive(Clone)]
Expand Down Expand Up @@ -1027,19 +1029,20 @@ impl CoreEngine {
checkpoint=None,
transitions=None,
save_path=None,
quiet=false,
update_handler=None,
)
)]
fn update(
&mut self,
py: Python<'_>,
n_iters: usize,
timeout: Option<u64>,
checkpoint: Option<usize>,
transitions: Option<Vec<transition::StateTransition>>,
save_path: Option<PathBuf>,
quiet: bool,
update_handler: Option<PyObject>,
) {
use lace::update_handler::{ProgressBar, Timeout};
use lace::update_handler::Timeout;
use std::time::Duration;

let config = match transitions {
Expand Down Expand Up @@ -1068,12 +1071,18 @@ impl CoreEngine {
Timeout::new(Duration::from_secs(secs))
};

if quiet {
self.engine.update(config, timeout).unwrap();
} else {
let pbar = ProgressBar::new();
self.engine.update(config, (timeout, pbar)).unwrap();
}
py.allow_threads(|| {
if let Some(update_handler) = update_handler {
self.engine
.update(
config,
(timeout, PyUpdateHandler::new(update_handler)),
)
.unwrap();
} else {
self.engine.update(config, timeout).unwrap();
}
});
}

/// Append new rows to the table.
Expand Down Expand Up @@ -1316,6 +1325,7 @@ fn core(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_class::<metadata::CategoricalPrior>()?;
m.add_class::<metadata::CountHyper>()?;
m.add_class::<metadata::CountPrior>()?;
m.add_class::<update_handler::PyEngineUpdateConfig>()?;
m.add_function(wrap_pyfunction!(infer_srs_metadata, m)?)?;
m.add_function(wrap_pyfunction!(metadata::codebook_from_df, m)?)?;
Ok(())
Expand Down
139 changes: 139 additions & 0 deletions pylace/src/update_handler.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
use std::io::Write;
/// Update Handler and associated tooling for `CoreEngine.update` in Python.
use std::sync::{Arc, Mutex};

use lace::cc::state::State;
use lace::update_handler::UpdateHandler;
use lace::EngineUpdateConfig;
use pyo3::{pyclass, IntoPy, Py, PyAny};

/// Python version of `EngineUpdateConfig`.
#[derive(Clone, Debug)]
#[pyclass(frozen, get_all)]
pub struct PyEngineUpdateConfig {
/// Maximum number of iterations to run.
pub n_iters: usize,
/// Number of iterations after which each state should be saved
pub checkpoint: Option<usize>,
/// Number of states
pub n_states: usize,
}

/// An `UpdateHandler` which wraps a Python Object.
#[derive(Debug, Clone)]
pub struct PyUpdateHandler {
handler: Arc<Mutex<Py<PyAny>>>,
}

impl PyUpdateHandler {
/// Create a new `PyUpdateHandler` from a Python Object
pub fn new(handler: Py<PyAny>) -> Self {
Self {
handler: Arc::new(Mutex::new(handler)),
}
}
}

macro_rules! pydict {
($py: expr, $($key:tt : $val:expr),* $(,)?) => {{

let map = pyo3::types::PyDict::new($py);
$(
let _ = map.set_item($key, $val.into_py($py))
.expect("Should be able to set item in PyDict");
)*
map
}};
}

macro_rules! call_pyhandler_noret {
($self: ident, $func_name: tt, $($key: tt : $val: expr),* $(,)?) => {{
let handler = $self
.handler
.lock()
.expect("Should be able to get a lock for the PyUpdateHandler");

::pyo3::Python::with_gil(|py| {
let kwargs = pydict!(
py,
$($key: $val),*
);

handler
.call_method(py, $func_name, (), kwargs.into())
.expect("Expected python call_method to return successfully");
})
}};
}

macro_rules! call_pyhandler_ret {
($self: ident, $func_name: tt, $($key: tt : $val: expr),* $(,)?) => {{
let handler = $self
.handler
.lock()
.expect("Should be able to get a lock for the PyUpdateHandler");

::pyo3::Python::with_gil(|py| {
let kwargs = pydict!(
py,
$($key: $val),*
);

handler
.call_method(py, $func_name, (), kwargs.into())
.expect("Expected python call_method to return successfully")
.extract(py)
.expect("Failed to extract expected type")
})
}};
}

impl UpdateHandler for PyUpdateHandler {
fn global_init(&mut self, config: &EngineUpdateConfig, states: &[State]) {
call_pyhandler_noret!(
self,
"global_init",
"config": PyEngineUpdateConfig {
n_iters: config.n_iters,
checkpoint: config.checkpoint,
n_states: states.len(),
}
);
}

fn new_state_init(&mut self, state_id: usize, _state: &State) {
call_pyhandler_noret!(
self,
"new_state_init",
"state_id": state_id,
);
}

fn state_updated(&mut self, state_id: usize, _state: &State) {
call_pyhandler_noret!(
self,
"state_updated",
"state_id": state_id,
);
}

fn state_complete(&mut self, state_id: usize, _state: &State) {
call_pyhandler_noret!(
self,
"state_complete",
"state_id": state_id,
);
}

fn stop_engine(&self) -> bool {
call_pyhandler_ret!(self, "stop_engine",)
}

fn stop_state(&self, _state_id: usize) -> bool {
call_pyhandler_ret!(self, "stop_state",)
}

fn finalize(&mut self) {
call_pyhandler_noret!(self, "finalize",)
}
}
1 change: 1 addition & 0 deletions pylace/tests/example_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Check the basics of the example datasets."""

from numpy.testing import assert_almost_equal

from lace import examples
Expand Down
1 change: 1 addition & 0 deletions pylace/tests/test_self_referencing.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Tests whether engine functions work with various engine outputs."""

import random

import polars as pl
Expand Down

0 comments on commit 9135670

Please sign in to comment.