diff --git a/CHANGELOG.md b/CHANGELOG.md index 1416ba2d..1648f91a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Moved `Example` code into the `examples` feature flag (on by default) - Replaced instances of `once_cell::sync::OnceCell` with `syd::sync::OnceLock` - Renamed all files/methods with the name `feather` to `arrow` +- Renamed `Builder` to `EngineBuilder` + +### Fixed + +- Fixed typo `UpdateHandler::finialize` is now `UpdateHandler::finalize` ## [python-0.4.1] - 2023-10-19 diff --git a/README.md b/README.md index 97fff6c9..4d56352d 100644 --- a/README.md +++ b/README.md @@ -28,7 +28,6 @@ User guide | Rust API | Python API | - CLI
Installation: diff --git a/book/src/workflow/workflow.md b/book/src/workflow/workflow.md index 8dd4c7d5..224f4a1a 100644 --- a/book/src/workflow/workflow.md +++ b/book/src/workflow/workflow.md @@ -9,32 +9,55 @@ The typical workflow consists of two or three steps: Step 1 is optional in many cases as Lace usually does a good job of inferring the types of your data. The condensed workflow looks like this. - -Create an optional codebook using the CLI. - -```console -$ lace codebook --csv data.csv codebook.yaml -``` - -Run a model. - -```console -$ lace run --csv data.csv --codebook codebook.yaml -n 5000 metadata.lace -``` - -Open the model in lace -
```python +import pandas as pd import lace -engine = lace.Engine.load('metadata.lace') +df = pd.read_csv("mydata.csv", index_col=0) + +# 1. Create a codebook (optional) +codebook = lace.Codebook.from_df(df) + +# 2. Initialize a new Engine from the prior. If no codebook is provided, a +# default will be generated +engine = lace.Engine.from_df(df, codebook=codebook) + +# 3. Run inference +engine.run(5000) ``` ```rust,noplayground -use lace::Engine; - -let engine = Engine::load("metadata.lace")?; +use polars::prelude::{SerReader, CsvReader}; +use lace::prelude::*; + +let df = CsvReader::from_path("mydata.csv") + .unwrap() + .has_header(true) + .finish() + .unwrap(); + +// 1. Create a codebook (optional) +let codebook = Codebook::from_df(&df, None, None, False).unwrap(); + +// 2. Build an engine +let mut engine = EngineBuilder::new(DataSource::Polars(df)) + .with_codebook(codebook) + .build() + .unwrap(); + +// 3. Run inference +// Use `run` to fit with the default transition set and update handlers; use +// `update` for more control. +engine.run(5_000); ``` +
+ +You can also use the CLI to create codebooks and run inference. Creating a default YAML codebook with the CLI, and then manually editing is good way to fine tune models. + +```console +$ lace codebook --csv mydata.csv codebook.yaml +$ lace run --csv data.csv --codebook codebook.yaml -n 5000 metadata.lace +``` diff --git a/cli/Cargo.lock b/cli/Cargo.lock index 95ef3b73..605cf155 100644 --- a/cli/Cargo.lock +++ b/cli/Cargo.lock @@ -920,7 +920,7 @@ dependencies = [ [[package]] name = "lace-cli" -version = "0.4.2" +version = "0.5.0" dependencies = [ "approx", "clap", diff --git a/cli/src/routes.rs b/cli/src/routes.rs index 4e9809b8..0bd09137 100644 --- a/cli/src/routes.rs +++ b/cli/src/routes.rs @@ -5,7 +5,7 @@ use lace::codebook::Codebook; use lace::metadata::{deserialize_file, serialize_obj}; use lace::stats::rv::dist::Gamma; use lace::update_handler::{CtrlC, ProgressBar, Timeout}; -use lace::{Builder, Engine}; +use lace::{Engine, EngineBuilder}; use crate::opt; @@ -84,7 +84,7 @@ fn new_engine(cmd: opt::RunArgs) -> i32 { return 1; }; - let mut builder = Builder::new(data_source) + let mut builder = EngineBuilder::new(data_source) .with_nstates(cmd.nstates) .id_offset(cmd.id_offset); diff --git a/lace/examples/count_model.rs b/lace/examples/count_model.rs index 0a7c85e1..18e9b9fe 100644 --- a/lace/examples/count_model.rs +++ b/lace/examples/count_model.rs @@ -24,7 +24,7 @@ fn main() { writeln!(file, "{},{}", ix, x).unwrap(); }); - Builder::new(DataSource::Csv(file.path().into())) + EngineBuilder::new(DataSource::Csv(file.path().into())) .with_nstates(2) .seed_from_u64(1337) .build() diff --git a/lace/src/interface/engine/builder.rs b/lace/src/interface/engine/builder.rs index 42ade3e6..f8a8141e 100644 --- a/lace/src/interface/engine/builder.rs +++ b/lace/src/interface/engine/builder.rs @@ -11,7 +11,7 @@ const DEFAULT_NSTATES: usize = 8; const DEFAULT_ID_OFFSET: usize = 0; /// Builds `Engine`s -pub struct Builder { +pub struct EngineBuilder { n_states: Option, codebook: Option, data_source: DataSource, @@ -28,7 +28,7 @@ pub enum BuildEngineError { DefaultCodebookError(#[from] DefaultCodebookError), } -impl Builder { +impl EngineBuilder { #[must_use] pub fn new(data_source: DataSource) -> Self { Self { @@ -41,7 +41,7 @@ impl Builder { } } - /// Eith a certain number of states + /// With a certain number of states #[must_use] pub fn with_nstates(mut self, n_states: usize) -> Self { self.n_states = Some(n_states); @@ -132,7 +132,7 @@ mod tests { #[test] fn default_build_settings() { - let engine = Builder::new(animals_csv()).build().unwrap(); + let engine = EngineBuilder::new(animals_csv()).build().unwrap(); let state_ids: BTreeSet = engine.state_ids.iter().copied().collect(); let target_ids: BTreeSet = btreeset! {0, 1, 2, 3, 4, 5, 6, 7}; @@ -142,7 +142,10 @@ mod tests { #[test] fn with_id_offet_3() { - let engine = Builder::new(animals_csv()).id_offset(3).build().unwrap(); + let engine = EngineBuilder::new(animals_csv()) + .id_offset(3) + .build() + .unwrap(); let state_ids: BTreeSet = engine.state_ids.iter().copied().collect(); let target_ids: BTreeSet = btreeset! {3, 4, 5, 6, 7, 8, 9, 10}; @@ -152,8 +155,10 @@ mod tests { #[test] fn with_nstates_3() { - let engine = - Builder::new(animals_csv()).with_nstates(3).build().unwrap(); + let engine = EngineBuilder::new(animals_csv()) + .with_nstates(3) + .build() + .unwrap(); let state_ids: BTreeSet = engine.state_ids.iter().copied().collect(); let target_ids: BTreeSet = btreeset! {0, 1, 2}; @@ -163,7 +168,7 @@ mod tests { #[test] fn with_nstates_0_causes_error() { - let result = Builder::new(animals_csv()).with_nstates(0).build(); + let result = EngineBuilder::new(animals_csv()).with_nstates(0).build(); assert!(result.is_err()); } @@ -172,13 +177,13 @@ mod tests { fn seeding_engine_works() { let seed: u64 = 8_675_309; let nstates = 4; - let mut engine_1 = Builder::new(animals_csv()) + let mut engine_1 = EngineBuilder::new(animals_csv()) .with_nstates(nstates) .seed_from_u64(seed) .build() .unwrap(); - let mut engine_2 = Builder::new(animals_csv()) + let mut engine_2 = EngineBuilder::new(animals_csv()) .with_nstates(nstates) .seed_from_u64(seed) .build() diff --git a/lace/src/interface/engine/mod.rs b/lace/src/interface/engine/mod.rs index e5a47fc7..015620bd 100644 --- a/lace/src/interface/engine/mod.rs +++ b/lace/src/interface/engine/mod.rs @@ -3,7 +3,7 @@ mod data; pub mod error; pub mod update_handler; -pub use builder::{BuildEngineError, Builder}; +pub use builder::{BuildEngineError, EngineBuilder}; pub use data::{ AppendStrategy, InsertDataActions, InsertMode, OverwriteMode, Row, SupportExtension, Value, WriteMode, @@ -1030,7 +1030,7 @@ impl Engine { .collect::, _>>()?; } std::mem::drop(update_handlers); - update_handler.finialize(); + update_handler.finalize(); Ok(()) } @@ -1119,12 +1119,12 @@ mod tests { false } - fn finialize(&mut self) { + fn finalize(&mut self) { self.0.write().unwrap().insert("finalize".to_string()); } } - let mut engine = Builder::new(animals_csv()).build().unwrap(); + let mut engine = EngineBuilder::new(animals_csv()).build().unwrap(); let called_methods = Arc::new(RwLock::new(HashSet::new())); let update_handler = TestingHandler(called_methods.clone()); @@ -1158,7 +1158,7 @@ mod tests { // It does not test that the StateTimeout successfully ends states that have gone over the duration #[test] fn state_timeout_update_handler() { - let mut engine = Builder::new(animals_csv()).build().unwrap(); + let mut engine = EngineBuilder::new(animals_csv()).build().unwrap(); let config = EngineUpdateConfig::new().default_transitions().n_iters(1); diff --git a/lace/src/interface/engine/update_handler.rs b/lace/src/interface/engine/update_handler.rs index 6c6be7e5..7f709d6e 100644 --- a/lace/src/interface/engine/update_handler.rs +++ b/lace/src/interface/engine/update_handler.rs @@ -45,7 +45,7 @@ use crate::EngineUpdateConfig; /// self.timings.lock().unwrap().push(Instant::now()); /// } /// -/// fn finialize(&mut self) { +/// fn finalize(&mut self) { /// let timings = self.timings.lock().unwrap(); /// let mean_time_between_updates = /// timings.iter().zip(timings.iter().skip(1)) @@ -106,7 +106,7 @@ pub trait UpdateHandler: Clone + Send + Sync { /// /// This method is called when all updating is complete. /// Uses for this method include cleanup, report generation, etc. - fn finialize(&mut self) {} + fn finalize(&mut self) {} } macro_rules! impl_tuple { @@ -151,9 +151,9 @@ macro_rules! impl_tuple { )||+ } - fn finialize(&mut self) { + fn finalize(&mut self) { $( - self.$idx.finialize(); + self.$idx.finalize(); )+ } @@ -204,8 +204,8 @@ where false } - fn finialize(&mut self) { - self.iter_mut().for_each(|handler| handler.finialize()); + fn finalize(&mut self) { + self.iter_mut().for_each(|handler| handler.finalize()); } } @@ -285,7 +285,7 @@ impl UpdateHandler for Timeout { } } - fn finialize(&mut self) {} + fn finalize(&mut self) {} } /// Limit the time each state can run for during an `Engine::update`. @@ -427,7 +427,7 @@ impl UpdateHandler for ProgressBar { false } - fn finialize(&mut self) { + fn finalize(&mut self) { if let Self::Initialized { sender, handle } = std::mem::take(self) { std::mem::drop(sender); diff --git a/lace/src/interface/mod.rs b/lace/src/interface/mod.rs index 49f905e3..992c7f3f 100644 --- a/lace/src/interface/mod.rs +++ b/lace/src/interface/mod.rs @@ -5,7 +5,7 @@ mod metadata; mod oracle; pub use engine::{ - update_handler, AppendStrategy, BuildEngineError, Builder, Engine, + update_handler, AppendStrategy, BuildEngineError, Engine, EngineBuilder, InsertDataActions, InsertMode, OverwriteMode, Row, SupportExtension, Value, WriteMode, }; diff --git a/lace/src/lib.rs b/lace/src/lib.rs index c4375580..d8e817b2 100644 --- a/lace/src/lib.rs +++ b/lace/src/lib.rs @@ -188,10 +188,10 @@ pub use index::*; pub use config::EngineUpdateConfig; pub use interface::{ - update_handler, utils, AppendStrategy, BuildEngineError, Builder, - ConditionalEntropyType, DatalessOracle, Engine, Given, HasData, HasStates, - ImputeUncertaintyType, InsertDataActions, InsertMode, Metadata, - MiComponents, MiType, Oracle, OracleT, OverwriteMode, + update_handler, utils, AppendStrategy, BuildEngineError, + ConditionalEntropyType, DatalessOracle, Engine, EngineBuilder, Given, + HasData, HasStates, ImputeUncertaintyType, InsertDataActions, InsertMode, + Metadata, MiComponents, MiType, Oracle, OracleT, OverwriteMode, PredictUncertaintyType, Row, RowSimilarityVariant, SupportExtension, Value, WriteMode, }; diff --git a/lace/src/prelude.rs b/lace/src/prelude.rs index 324f0578..e7389119 100644 --- a/lace/src/prelude.rs +++ b/lace/src/prelude.rs @@ -1,10 +1,10 @@ //! Common import for general use. pub use crate::{ - update_handler, AppendStrategy, Builder, Datum, Engine, EngineUpdateConfig, - Given, ImputeUncertaintyType, InsertMode, MiType, OracleT, OverwriteMode, - PredictUncertaintyType, Row, RowSimilarityVariant, SupportExtension, Value, - WriteMode, + update_handler, AppendStrategy, Datum, Engine, EngineBuilder, + EngineUpdateConfig, Given, ImputeUncertaintyType, InsertMode, MiType, + OracleT, OverwriteMode, PredictUncertaintyType, Row, RowSimilarityVariant, + SupportExtension, Value, WriteMode, }; pub use crate::data::DataSource; diff --git a/lace/tests/engine.rs b/lace/tests/engine.rs index 62148f9b..9a91bf65 100644 --- a/lace/tests/engine.rs +++ b/lace/tests/engine.rs @@ -7,7 +7,7 @@ use lace::config::EngineUpdateConfig; use lace::data::DataSource; use lace::examples::Example; use lace::{ - AppendStrategy, Builder, Engine, HasStates, InsertDataActions, + AppendStrategy, Engine, EngineBuilder, HasStates, InsertDataActions, SupportExtension, }; use lace_codebook::{Codebook, ValueMap}; @@ -33,7 +33,7 @@ fn animals_codebook_path() -> PathBuf { // tempfiles. #[cfg(feature = "formats")] fn engine_from_csv>(path: P) -> Engine { - Builder::new(DataSource::Csv(path.into())) + EngineBuilder::new(DataSource::Csv(path.into())) .with_nstates(2) .build() .unwrap() @@ -228,7 +228,7 @@ fn cell_gibbs_smoke() { fn engine_build_without_flat_col_is_not_flat() { let path = animals_data_path(); let df = lace_codebook::data::read_csv(path).unwrap(); - let engine = Builder::new(DataSource::Polars(df)) + let engine = EngineBuilder::new(DataSource::Polars(df)) .with_nstates(8) .build() .unwrap(); @@ -1344,7 +1344,7 @@ mod insert_data { ..Default::default() }; - let mut engine = Builder::new(DataSource::Empty).build().unwrap(); + let mut engine = EngineBuilder::new(DataSource::Empty).build().unwrap(); assert_eq!(engine.n_rows(), 0); assert_eq!(engine.n_cols(), 0); @@ -1440,7 +1440,7 @@ mod insert_data { ..Default::default() }; - let mut engine = Builder::new(DataSource::Empty).build().unwrap(); + let mut engine = EngineBuilder::new(DataSource::Empty).build().unwrap(); assert_eq!(engine.n_rows(), 0); assert_eq!(engine.n_cols(), 0); @@ -1894,7 +1894,7 @@ mod insert_data { #[test] fn $fn_name() { let mut engine = - Builder::new(DataSource::Empty).build().unwrap(); + EngineBuilder::new(DataSource::Empty).build().unwrap(); let new_metadata = ColMetadataList::new(vec![ continuous_md("one".to_string()), continuous_md("two".to_string()), @@ -1972,7 +1972,7 @@ mod insert_data { #[test] fn $fn_name() { let mut engine = - Builder::new(DataSource::Empty).build().unwrap(); + EngineBuilder::new(DataSource::Empty).build().unwrap(); let new_metadata = ColMetadataList::new(vec![ continuous_md("one".to_string()), diff --git a/lace/tests/workflow.rs b/lace/tests/workflow.rs index 5e02df78..af7d5e2d 100644 --- a/lace/tests/workflow.rs +++ b/lace/tests/workflow.rs @@ -1,8 +1,8 @@ use lace::config::EngineUpdateConfig; use lace::data::DataSource; use lace::update_handler::Timeout; -use lace::Builder; use lace::Engine; +use lace::EngineBuilder; use lace_codebook::data::codebook_from_csv; use rand::SeedableRng; use std::io::Write; @@ -58,7 +58,7 @@ fn satellites_csv_workflow() { let codebook = codebook_from_csv(path.as_path(), None, None, false).unwrap(); - let mut engine: Engine = Builder::new(DataSource::Csv(path)) + let mut engine: Engine = EngineBuilder::new(DataSource::Csv(path)) .codebook(codebook) .with_nstates(4) .seed_from_u64(1776)