Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(python): Added tqdm handler for Engine.update progress bar. #176

Merged
merged 2 commits into from
Feb 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added
- `DataParseError::CodebookAndDataRowsMismatch` variant for when the number of rows in the codebook and the number of rows in the data do not match.
- `DataParseError::DataFrameMissingColumn` variant for when a column is in the codebook but not in the initial dataframe.
- Python's `Engine.update` uses `tqdm.auto` for progress bar reporting.

### Fixed
- Initializing an engine with a codebook that has a different number of rows than the data will result in an error instead of printing a bunch on nonsense.
Expand Down
2 changes: 1 addition & 1 deletion pylace/docs/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ commonmark==0.9.1
hypothesis==6.65.2
maturin==0.14.10
mock==1.0.1
pillow==10.1.0
pillow==10.2.0
pytest-cov==4.0.0
pytest-xdist==3.1.0
pytest==7.2.0
Expand Down
1 change: 0 additions & 1 deletion pylace/lace/analysis.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""Tools for analysis of probabilistic cross-categorization results in Lace."""


import enum
import itertools as it
from copy import deepcopy
Expand Down
32 changes: 31 additions & 1 deletion pylace/lace/engine.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
"""The main interface to Lace models."""

import itertools as it
from os import PathLike
from typing import TYPE_CHECKING, Dict, List, Optional, Union

import pandas as pd
import plotly.express as px
import polars as pl
from tqdm.auto import tqdm

from lace import core, utils
from lace.codebook import Codebook
Expand Down Expand Up @@ -1010,13 +1012,15 @@ def update(
if isinstance(transitions, str):
transitions = utils._get_common_transitions(transitions)

update_handler = None if quiet else TqdmUpdateHandler()

return self.engine.update(
n_iters,
timeout=timeout,
checkpoint=checkpoint,
transitions=transitions,
save_path=save_path,
quiet=quiet,
update_handler=update_handler,
)

def entropy(self, cols, n_mc_samples: int = 1000):
Expand Down Expand Up @@ -2336,3 +2340,29 @@ def clustermap(
return ClusterMap(df, linkage, fig)
else:
return ClusterMap(df, linkage)


class TqdmUpdateHandler:
def __init__(self):
self._t = tqdm()

def global_init(self, config):
self._t.reset(config.n_iters * config.n_states)

def new_state_init(self, state_id):
pass

def state_updated(self, state_id):
self._t.update(1)

def state_complete(self, state_id):
pass

def stop_engine(self):
return False

def stop_state(self):
return False

def finalize(self):
self._t.close()
28 changes: 19 additions & 9 deletions pylace/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ mod component;
mod df;
mod metadata;
mod transition;
mod update_handler;
mod utils;

use std::cmp::Ordering;
Expand All @@ -22,6 +23,7 @@ use rand_xoshiro::Xoshiro256Plus;

use metadata::{Codebook, CodebookBuilder};

use crate::update_handler::PyUpdateHandler;
use crate::utils::*;

#[derive(Clone)]
Expand Down Expand Up @@ -1027,19 +1029,20 @@ impl CoreEngine {
checkpoint=None,
transitions=None,
save_path=None,
quiet=false,
update_handler=None,
)
)]
fn update(
&mut self,
py: Python<'_>,
Swandog marked this conversation as resolved.
Show resolved Hide resolved
n_iters: usize,
timeout: Option<u64>,
checkpoint: Option<usize>,
transitions: Option<Vec<transition::StateTransition>>,
save_path: Option<PathBuf>,
quiet: bool,
update_handler: Option<PyObject>,
) {
use lace::update_handler::{ProgressBar, Timeout};
use lace::update_handler::Timeout;
use std::time::Duration;

let config = match transitions {
Expand Down Expand Up @@ -1068,12 +1071,18 @@ impl CoreEngine {
Timeout::new(Duration::from_secs(secs))
};

if quiet {
self.engine.update(config, timeout).unwrap();
} else {
let pbar = ProgressBar::new();
self.engine.update(config, (timeout, pbar)).unwrap();
}
py.allow_threads(|| {
if let Some(update_handler) = update_handler {
self.engine
.update(
config,
(timeout, PyUpdateHandler::new(update_handler)),
)
.unwrap();
} else {
self.engine.update(config, timeout).unwrap();
}
});
}

/// Append new rows to the table.
Expand Down Expand Up @@ -1316,6 +1325,7 @@ fn core(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_class::<metadata::CategoricalPrior>()?;
m.add_class::<metadata::CountHyper>()?;
m.add_class::<metadata::CountPrior>()?;
m.add_class::<update_handler::PyEngineUpdateConfig>()?;
m.add_function(wrap_pyfunction!(infer_srs_metadata, m)?)?;
m.add_function(wrap_pyfunction!(metadata::codebook_from_df, m)?)?;
Ok(())
Expand Down
139 changes: 139 additions & 0 deletions pylace/src/update_handler.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
use std::io::Write;
/// Update Handler and associated tooling for `CoreEngine.update` in Python.
use std::sync::{Arc, Mutex};

use lace::cc::state::State;
use lace::update_handler::UpdateHandler;
use lace::EngineUpdateConfig;
use pyo3::{pyclass, IntoPy, Py, PyAny};

/// Python version of `EngineUpdateConfig`.
#[derive(Clone, Debug)]
#[pyclass(frozen, get_all)]
pub struct PyEngineUpdateConfig {
/// Maximum number of iterations to run.
pub n_iters: usize,
/// Number of iterations after which each state should be saved
pub checkpoint: Option<usize>,
/// Number of states
pub n_states: usize,
}

/// An `UpdateHandler` which wraps a Python Object.
#[derive(Debug, Clone)]
pub struct PyUpdateHandler {
handler: Arc<Mutex<Py<PyAny>>>,
}

impl PyUpdateHandler {
/// Create a new `PyUpdateHandler` from a Python Object
pub fn new(handler: Py<PyAny>) -> Self {
Self {
handler: Arc::new(Mutex::new(handler)),
}
}
}

macro_rules! pydict {
($py: expr, $($key:tt : $val:expr),* $(,)?) => {{

let map = pyo3::types::PyDict::new($py);
$(
let _ = map.set_item($key, $val.into_py($py))
.expect("Should be able to set item in PyDict");
)*
map
}};
}

macro_rules! call_pyhandler_noret {
($self: ident, $func_name: tt, $($key: tt : $val: expr),* $(,)?) => {{
let handler = $self
.handler
.lock()
.expect("Should be able to get a lock for the PyUpdateHandler");
Swandog marked this conversation as resolved.
Show resolved Hide resolved

::pyo3::Python::with_gil(|py| {
let kwargs = pydict!(
py,
$($key: $val),*
);

handler
.call_method(py, $func_name, (), kwargs.into())
.expect("Expected python call_method to return successfully");
})
}};
}

macro_rules! call_pyhandler_ret {
($self: ident, $func_name: tt, $($key: tt : $val: expr),* $(,)?) => {{
let handler = $self
.handler
.lock()
.expect("Should be able to get a lock for the PyUpdateHandler");

::pyo3::Python::with_gil(|py| {
let kwargs = pydict!(
py,
$($key: $val),*
);

handler
.call_method(py, $func_name, (), kwargs.into())
.expect("Expected python call_method to return successfully")
.extract(py)
.expect("Failed to extract expected type")
})
}};
}

impl UpdateHandler for PyUpdateHandler {
fn global_init(&mut self, config: &EngineUpdateConfig, states: &[State]) {
call_pyhandler_noret!(
self,
"global_init",
"config": PyEngineUpdateConfig {
n_iters: config.n_iters,
checkpoint: config.checkpoint,
n_states: states.len(),
}
);
}

fn new_state_init(&mut self, state_id: usize, _state: &State) {
call_pyhandler_noret!(
self,
"new_state_init",
"state_id": state_id,
);
}

fn state_updated(&mut self, state_id: usize, _state: &State) {
call_pyhandler_noret!(
self,
"state_updated",
"state_id": state_id,
);
}

fn state_complete(&mut self, state_id: usize, _state: &State) {
call_pyhandler_noret!(
self,
"state_complete",
"state_id": state_id,
);
}

fn stop_engine(&self) -> bool {
call_pyhandler_ret!(self, "stop_engine",)
}

fn stop_state(&self, _state_id: usize) -> bool {
call_pyhandler_ret!(self, "stop_state",)
}

fn finalize(&mut self) {
call_pyhandler_noret!(self, "finalize",)
}
}
1 change: 1 addition & 0 deletions pylace/tests/example_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Check the basics of the example datasets."""

from numpy.testing import assert_almost_equal

from lace import examples
Expand Down
1 change: 1 addition & 0 deletions pylace/tests/test_self_referencing.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Tests whether engine functions work with various engine outputs."""

import random

import polars as pl
Expand Down
Loading