diff --git a/CHANGELOG.md b/CHANGELOG.md
index 1416ba2d..1648f91a 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -20,6 +20,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Moved `Example` code into the `examples` feature flag (on by default)
- Replaced instances of `once_cell::sync::OnceCell` with `syd::sync::OnceLock`
- Renamed all files/methods with the name `feather` to `arrow`
+- Renamed `Builder` to `EngineBuilder`
+
+### Fixed
+
+- Fixed typo: `UpdateHandler::finialize` is now `UpdateHandler::finalize`
## [python-0.4.1] - 2023-10-19
diff --git a/README.md b/README.md
index 97fff6c9..4d56352d 100644
--- a/README.md
+++ b/README.md
@@ -28,7 +28,6 @@
User guide |
Rust API |
Python API |
- CLI
Installation:
diff --git a/book/src/workflow/workflow.md b/book/src/workflow/workflow.md
index 8dd4c7d5..224f4a1a 100644
--- a/book/src/workflow/workflow.md
+++ b/book/src/workflow/workflow.md
@@ -9,32 +9,55 @@ The typical workflow consists of two or three steps:
Step 1 is optional in many cases as Lace usually does a good job of inferring
the types of your data. The condensed workflow looks like this.
-
-Create an optional codebook using the CLI.
-
-```console
-$ lace codebook --csv data.csv codebook.yaml
-```
-
-Run a model.
-
-```console
-$ lace run --csv data.csv --codebook codebook.yaml -n 5000 metadata.lace
-```
-
-Open the model in lace
-
```python
+import pandas as pd
import lace
-engine = lace.Engine.load('metadata.lace')
+df = pd.read_csv("mydata.csv", index_col=0)
+
+# 1. Create a codebook (optional)
+codebook = lace.Codebook.from_df(df)
+
+# 2. Initialize a new Engine from the prior. If no codebook is provided, a
+# default will be generated
+engine = lace.Engine.from_df(df, codebook=codebook)
+
+# 3. Run inference
+engine.run(5000)
```
```rust,noplayground
-use lace::Engine;
-
-let engine = Engine::load("metadata.lace")?;
+use polars::prelude::{SerReader, CsvReader};
+use lace::prelude::*;
+
+let df = CsvReader::from_path("mydata.csv")
+ .unwrap()
+ .has_header(true)
+ .finish()
+ .unwrap();
+
+// 1. Create a codebook (optional)
+let codebook = Codebook::from_df(&df, None, None, false).unwrap();
+
+// 2. Build an engine
+let mut engine = EngineBuilder::new(DataSource::Polars(df))
+ .codebook(codebook)
+ .build()
+ .unwrap();
+
+// 3. Run inference
+// Use `run` to fit with the default transition set and update handlers; use
+// `update` for more control.
+engine.run(5_000);
```
+
+
+You can also use the CLI to create codebooks and run inference. Creating a default YAML codebook with the CLI and then manually editing it is a good way to fine-tune models.
+
+```console
+$ lace codebook --csv mydata.csv codebook.yaml
+$ lace run --csv mydata.csv --codebook codebook.yaml -n 5000 metadata.lace
+```
diff --git a/cli/Cargo.lock b/cli/Cargo.lock
index 95ef3b73..605cf155 100644
--- a/cli/Cargo.lock
+++ b/cli/Cargo.lock
@@ -920,7 +920,7 @@ dependencies = [
[[package]]
name = "lace-cli"
-version = "0.4.2"
+version = "0.5.0"
dependencies = [
"approx",
"clap",
diff --git a/cli/src/routes.rs b/cli/src/routes.rs
index 4e9809b8..0bd09137 100644
--- a/cli/src/routes.rs
+++ b/cli/src/routes.rs
@@ -5,7 +5,7 @@ use lace::codebook::Codebook;
use lace::metadata::{deserialize_file, serialize_obj};
use lace::stats::rv::dist::Gamma;
use lace::update_handler::{CtrlC, ProgressBar, Timeout};
-use lace::{Builder, Engine};
+use lace::{Engine, EngineBuilder};
use crate::opt;
@@ -84,7 +84,7 @@ fn new_engine(cmd: opt::RunArgs) -> i32 {
return 1;
};
- let mut builder = Builder::new(data_source)
+ let mut builder = EngineBuilder::new(data_source)
.with_nstates(cmd.nstates)
.id_offset(cmd.id_offset);
diff --git a/lace/examples/count_model.rs b/lace/examples/count_model.rs
index 0a7c85e1..18e9b9fe 100644
--- a/lace/examples/count_model.rs
+++ b/lace/examples/count_model.rs
@@ -24,7 +24,7 @@ fn main() {
writeln!(file, "{},{}", ix, x).unwrap();
});
- Builder::new(DataSource::Csv(file.path().into()))
+ EngineBuilder::new(DataSource::Csv(file.path().into()))
.with_nstates(2)
.seed_from_u64(1337)
.build()
diff --git a/lace/src/interface/engine/builder.rs b/lace/src/interface/engine/builder.rs
index 42ade3e6..f8a8141e 100644
--- a/lace/src/interface/engine/builder.rs
+++ b/lace/src/interface/engine/builder.rs
@@ -11,7 +11,7 @@ const DEFAULT_NSTATES: usize = 8;
const DEFAULT_ID_OFFSET: usize = 0;
/// Builds `Engine`s
-pub struct Builder {
+pub struct EngineBuilder {
n_states: Option
,
codebook: Option,
data_source: DataSource,
@@ -28,7 +28,7 @@ pub enum BuildEngineError {
DefaultCodebookError(#[from] DefaultCodebookError),
}
-impl Builder {
+impl EngineBuilder {
#[must_use]
pub fn new(data_source: DataSource) -> Self {
Self {
@@ -41,7 +41,7 @@ impl Builder {
}
}
- /// Eith a certain number of states
+ /// With a certain number of states
#[must_use]
pub fn with_nstates(mut self, n_states: usize) -> Self {
self.n_states = Some(n_states);
@@ -132,7 +132,7 @@ mod tests {
#[test]
fn default_build_settings() {
- let engine = Builder::new(animals_csv()).build().unwrap();
+ let engine = EngineBuilder::new(animals_csv()).build().unwrap();
let state_ids: BTreeSet =
engine.state_ids.iter().copied().collect();
let target_ids: BTreeSet = btreeset! {0, 1, 2, 3, 4, 5, 6, 7};
@@ -142,7 +142,10 @@ mod tests {
#[test]
fn with_id_offet_3() {
- let engine = Builder::new(animals_csv()).id_offset(3).build().unwrap();
+ let engine = EngineBuilder::new(animals_csv())
+ .id_offset(3)
+ .build()
+ .unwrap();
let state_ids: BTreeSet =
engine.state_ids.iter().copied().collect();
let target_ids: BTreeSet = btreeset! {3, 4, 5, 6, 7, 8, 9, 10};
@@ -152,8 +155,10 @@ mod tests {
#[test]
fn with_nstates_3() {
- let engine =
- Builder::new(animals_csv()).with_nstates(3).build().unwrap();
+ let engine = EngineBuilder::new(animals_csv())
+ .with_nstates(3)
+ .build()
+ .unwrap();
let state_ids: BTreeSet =
engine.state_ids.iter().copied().collect();
let target_ids: BTreeSet = btreeset! {0, 1, 2};
@@ -163,7 +168,7 @@ mod tests {
#[test]
fn with_nstates_0_causes_error() {
- let result = Builder::new(animals_csv()).with_nstates(0).build();
+ let result = EngineBuilder::new(animals_csv()).with_nstates(0).build();
assert!(result.is_err());
}
@@ -172,13 +177,13 @@ mod tests {
fn seeding_engine_works() {
let seed: u64 = 8_675_309;
let nstates = 4;
- let mut engine_1 = Builder::new(animals_csv())
+ let mut engine_1 = EngineBuilder::new(animals_csv())
.with_nstates(nstates)
.seed_from_u64(seed)
.build()
.unwrap();
- let mut engine_2 = Builder::new(animals_csv())
+ let mut engine_2 = EngineBuilder::new(animals_csv())
.with_nstates(nstates)
.seed_from_u64(seed)
.build()
diff --git a/lace/src/interface/engine/mod.rs b/lace/src/interface/engine/mod.rs
index e5a47fc7..015620bd 100644
--- a/lace/src/interface/engine/mod.rs
+++ b/lace/src/interface/engine/mod.rs
@@ -3,7 +3,7 @@ mod data;
pub mod error;
pub mod update_handler;
-pub use builder::{BuildEngineError, Builder};
+pub use builder::{BuildEngineError, EngineBuilder};
pub use data::{
AppendStrategy, InsertDataActions, InsertMode, OverwriteMode, Row,
SupportExtension, Value, WriteMode,
@@ -1030,7 +1030,7 @@ impl Engine {
.collect::, _>>()?;
}
std::mem::drop(update_handlers);
- update_handler.finialize();
+ update_handler.finalize();
Ok(())
}
@@ -1119,12 +1119,12 @@ mod tests {
false
}
- fn finialize(&mut self) {
+ fn finalize(&mut self) {
self.0.write().unwrap().insert("finalize".to_string());
}
}
- let mut engine = Builder::new(animals_csv()).build().unwrap();
+ let mut engine = EngineBuilder::new(animals_csv()).build().unwrap();
let called_methods = Arc::new(RwLock::new(HashSet::new()));
let update_handler = TestingHandler(called_methods.clone());
@@ -1158,7 +1158,7 @@ mod tests {
// It does not test that the StateTimeout successfully ends states that have gone over the duration
#[test]
fn state_timeout_update_handler() {
- let mut engine = Builder::new(animals_csv()).build().unwrap();
+ let mut engine = EngineBuilder::new(animals_csv()).build().unwrap();
let config = EngineUpdateConfig::new().default_transitions().n_iters(1);
diff --git a/lace/src/interface/engine/update_handler.rs b/lace/src/interface/engine/update_handler.rs
index 6c6be7e5..7f709d6e 100644
--- a/lace/src/interface/engine/update_handler.rs
+++ b/lace/src/interface/engine/update_handler.rs
@@ -45,7 +45,7 @@ use crate::EngineUpdateConfig;
/// self.timings.lock().unwrap().push(Instant::now());
/// }
///
-/// fn finialize(&mut self) {
+/// fn finalize(&mut self) {
/// let timings = self.timings.lock().unwrap();
/// let mean_time_between_updates =
/// timings.iter().zip(timings.iter().skip(1))
@@ -106,7 +106,7 @@ pub trait UpdateHandler: Clone + Send + Sync {
///
/// This method is called when all updating is complete.
/// Uses for this method include cleanup, report generation, etc.
- fn finialize(&mut self) {}
+ fn finalize(&mut self) {}
}
macro_rules! impl_tuple {
@@ -151,9 +151,9 @@ macro_rules! impl_tuple {
)||+
}
- fn finialize(&mut self) {
+ fn finalize(&mut self) {
$(
- self.$idx.finialize();
+ self.$idx.finalize();
)+
}
@@ -204,8 +204,8 @@ where
false
}
- fn finialize(&mut self) {
- self.iter_mut().for_each(|handler| handler.finialize());
+ fn finalize(&mut self) {
+ self.iter_mut().for_each(|handler| handler.finalize());
}
}
@@ -285,7 +285,7 @@ impl UpdateHandler for Timeout {
}
}
- fn finialize(&mut self) {}
+ fn finalize(&mut self) {}
}
/// Limit the time each state can run for during an `Engine::update`.
@@ -427,7 +427,7 @@ impl UpdateHandler for ProgressBar {
false
}
- fn finialize(&mut self) {
+ fn finalize(&mut self) {
if let Self::Initialized { sender, handle } = std::mem::take(self) {
std::mem::drop(sender);
diff --git a/lace/src/interface/mod.rs b/lace/src/interface/mod.rs
index 49f905e3..992c7f3f 100644
--- a/lace/src/interface/mod.rs
+++ b/lace/src/interface/mod.rs
@@ -5,7 +5,7 @@ mod metadata;
mod oracle;
pub use engine::{
- update_handler, AppendStrategy, BuildEngineError, Builder, Engine,
+ update_handler, AppendStrategy, BuildEngineError, Engine, EngineBuilder,
InsertDataActions, InsertMode, OverwriteMode, Row, SupportExtension, Value,
WriteMode,
};
diff --git a/lace/src/lib.rs b/lace/src/lib.rs
index c4375580..d8e817b2 100644
--- a/lace/src/lib.rs
+++ b/lace/src/lib.rs
@@ -188,10 +188,10 @@ pub use index::*;
pub use config::EngineUpdateConfig;
pub use interface::{
- update_handler, utils, AppendStrategy, BuildEngineError, Builder,
- ConditionalEntropyType, DatalessOracle, Engine, Given, HasData, HasStates,
- ImputeUncertaintyType, InsertDataActions, InsertMode, Metadata,
- MiComponents, MiType, Oracle, OracleT, OverwriteMode,
+ update_handler, utils, AppendStrategy, BuildEngineError,
+ ConditionalEntropyType, DatalessOracle, Engine, EngineBuilder, Given,
+ HasData, HasStates, ImputeUncertaintyType, InsertDataActions, InsertMode,
+ Metadata, MiComponents, MiType, Oracle, OracleT, OverwriteMode,
PredictUncertaintyType, Row, RowSimilarityVariant, SupportExtension, Value,
WriteMode,
};
diff --git a/lace/src/prelude.rs b/lace/src/prelude.rs
index 324f0578..e7389119 100644
--- a/lace/src/prelude.rs
+++ b/lace/src/prelude.rs
@@ -1,10 +1,10 @@
//! Common import for general use.
pub use crate::{
- update_handler, AppendStrategy, Builder, Datum, Engine, EngineUpdateConfig,
- Given, ImputeUncertaintyType, InsertMode, MiType, OracleT, OverwriteMode,
- PredictUncertaintyType, Row, RowSimilarityVariant, SupportExtension, Value,
- WriteMode,
+ update_handler, AppendStrategy, Datum, Engine, EngineBuilder,
+ EngineUpdateConfig, Given, ImputeUncertaintyType, InsertMode, MiType,
+ OracleT, OverwriteMode, PredictUncertaintyType, Row, RowSimilarityVariant,
+ SupportExtension, Value, WriteMode,
};
pub use crate::data::DataSource;
diff --git a/lace/tests/engine.rs b/lace/tests/engine.rs
index 62148f9b..9a91bf65 100644
--- a/lace/tests/engine.rs
+++ b/lace/tests/engine.rs
@@ -7,7 +7,7 @@ use lace::config::EngineUpdateConfig;
use lace::data::DataSource;
use lace::examples::Example;
use lace::{
- AppendStrategy, Builder, Engine, HasStates, InsertDataActions,
+ AppendStrategy, Engine, EngineBuilder, HasStates, InsertDataActions,
SupportExtension,
};
use lace_codebook::{Codebook, ValueMap};
@@ -33,7 +33,7 @@ fn animals_codebook_path() -> PathBuf {
// tempfiles.
#[cfg(feature = "formats")]
fn engine_from_csv>(path: P) -> Engine {
- Builder::new(DataSource::Csv(path.into()))
+ EngineBuilder::new(DataSource::Csv(path.into()))
.with_nstates(2)
.build()
.unwrap()
@@ -228,7 +228,7 @@ fn cell_gibbs_smoke() {
fn engine_build_without_flat_col_is_not_flat() {
let path = animals_data_path();
let df = lace_codebook::data::read_csv(path).unwrap();
- let engine = Builder::new(DataSource::Polars(df))
+ let engine = EngineBuilder::new(DataSource::Polars(df))
.with_nstates(8)
.build()
.unwrap();
@@ -1344,7 +1344,7 @@ mod insert_data {
..Default::default()
};
- let mut engine = Builder::new(DataSource::Empty).build().unwrap();
+ let mut engine = EngineBuilder::new(DataSource::Empty).build().unwrap();
assert_eq!(engine.n_rows(), 0);
assert_eq!(engine.n_cols(), 0);
@@ -1440,7 +1440,7 @@ mod insert_data {
..Default::default()
};
- let mut engine = Builder::new(DataSource::Empty).build().unwrap();
+ let mut engine = EngineBuilder::new(DataSource::Empty).build().unwrap();
assert_eq!(engine.n_rows(), 0);
assert_eq!(engine.n_cols(), 0);
@@ -1894,7 +1894,7 @@ mod insert_data {
#[test]
fn $fn_name() {
let mut engine =
- Builder::new(DataSource::Empty).build().unwrap();
+ EngineBuilder::new(DataSource::Empty).build().unwrap();
let new_metadata = ColMetadataList::new(vec![
continuous_md("one".to_string()),
continuous_md("two".to_string()),
@@ -1972,7 +1972,7 @@ mod insert_data {
#[test]
fn $fn_name() {
let mut engine =
- Builder::new(DataSource::Empty).build().unwrap();
+ EngineBuilder::new(DataSource::Empty).build().unwrap();
let new_metadata = ColMetadataList::new(vec![
continuous_md("one".to_string()),
diff --git a/lace/tests/workflow.rs b/lace/tests/workflow.rs
index 5e02df78..af7d5e2d 100644
--- a/lace/tests/workflow.rs
+++ b/lace/tests/workflow.rs
@@ -1,8 +1,8 @@
use lace::config::EngineUpdateConfig;
use lace::data::DataSource;
use lace::update_handler::Timeout;
-use lace::Builder;
use lace::Engine;
+use lace::EngineBuilder;
use lace_codebook::data::codebook_from_csv;
use rand::SeedableRng;
use std::io::Write;
@@ -58,7 +58,7 @@ fn satellites_csv_workflow() {
let codebook =
codebook_from_csv(path.as_path(), None, None, false).unwrap();
- let mut engine: Engine = Builder::new(DataSource::Csv(path))
+ let mut engine: Engine = EngineBuilder::new(DataSource::Csv(path))
.codebook(codebook)
.with_nstates(4)
.seed_from_u64(1776)