diff --git a/.github/scripts/find_compatible_wheel.py b/.github/scripts/find_compatible_wheel.py index af3a9a34..99343631 100644 --- a/.github/scripts/find_compatible_wheel.py +++ b/.github/scripts/find_compatible_wheel.py @@ -10,23 +10,29 @@ description="Program to find wheels in a directory compatible with the current version of Python" ) -parser.add_argument("package", help="The name of the package that you are searching for a wheel for") +parser.add_argument( + "package", help="The name of the package that you are searching for a wheel for" +) parser.add_argument("dir", help="the directory under which to search for the wheels") -args=parser.parse_args() +args = parser.parse_args() -wheel=None +wheel = None for tag in sys_tags(): print(f"Looking for file matching tag {tag}", file=sys.stderr) - matches=glob.glob(args.package + "*" + str(tag) + "*.whl", root_dir=args.dir) + matches = glob.glob(f"{args.package}*{tag}*.whl", root_dir=args.dir) if len(matches) == 1: - wheel=matches[0] + wheel = matches[0] break elif len(matches) > 1: - print("Found multiple matches for the same tag " + str(tag), matches, file=sys.stderr) + print( + f"Found multiple matches for the same tag `{tag}`", + matches, + file=sys.stderr, + ) -if wheel: +if wheel: print(os.path.join(args.dir, wheel)) else: sys.exit("Did not find compatible wheel") diff --git a/.github/workflows/changelog.yaml b/.github/workflows/changelog.yaml index e4ad53b7..7ed08df9 100644 --- a/.github/workflows/changelog.yaml +++ b/.github/workflows/changelog.yaml @@ -12,7 +12,7 @@ jobs: steps: - uses: actions/checkout@v4 - + - run: | eval "$(/home/linuxbrew/.linuxbrew/bin/brew shellenv)" brew install nbbrd/tap/heylogs diff --git a/.github/workflows/python-build-test.yaml b/.github/workflows/python-build-test.yaml index caecdb8a..0259ba33 100644 --- a/.github/workflows/python-build-test.yaml +++ b/.github/workflows/python-build-test.yaml @@ -20,7 +20,7 @@ jobs: defaults: run: working-directory: pylace - + steps: - uses: 
actions/checkout@v4 @@ -30,7 +30,7 @@ jobs: python-version: '3.12' cache: 'pip' cache-dependency-path: "pylace/requirements-lint.txt" - + - name: Install Python dependencies run: | pip install --upgrade pip @@ -102,10 +102,11 @@ jobs: - name: Build wheels uses: PyO3/maturin-action@v1 with: + maturin-version: 1.5.1 target: ${{ matrix.target }} args: --release --out dist -i python3.8 -i python3.9 -i python3.10 -i python3.11 -i python3.12 --manifest-path pylace/Cargo.toml manylinux: auto - + - name: Install dev dependencies run: | pip install --upgrade pip @@ -113,7 +114,6 @@ jobs: - name: Install pylace run: | - ls -l ./dist WHEEL_FILE=$(python3 .github/scripts/find_compatible_wheel.py pylace ./dist) echo "Installing $WHEEL_FILE" pip install $WHEEL_FILE @@ -147,9 +147,10 @@ jobs: - name: Build wheels uses: PyO3/maturin-action@v1 with: + maturin-version: 1.5.1 target: ${{ matrix.target }} args: --release --out dist -i python3.8 -i python3.9 -i python3.10 -i python3.11 -i python3.12 --manifest-path pylace/Cargo.toml - + - name: Install dev dependencies run: | pip install --upgrade pip @@ -157,7 +158,6 @@ jobs: - name: Install pylace run: | - ls -l ./dist $WHEEL_FILE = (python3 .github/scripts/find_compatible_wheel.py pylace ./dist) echo "Installing $WHEEL_FILE" pip install $WHEEL_FILE @@ -172,11 +172,15 @@ jobs: path: dist macos: - runs-on: macos-latest needs: ["lint-python", "lint-rust"] strategy: matrix: - target: [x86_64, aarch64] + include: + - os: macos-latest + target: aarch64 + - os: macos-13 + target: x86_64 + runs-on: ${{ matrix.os }} steps: - uses: actions/checkout@v4 - uses: actions/setup-python@v5 @@ -190,24 +194,22 @@ jobs: - name: Build wheels uses: PyO3/maturin-action@v1 with: + maturin-version: 1.5.1 target: ${{ matrix.target }} args: --release --out dist -i python3.8 -i python3.9 -i python3.10 -i python3.11 -i python3.12 --manifest-path pylace/Cargo.toml - name: Install dev dependencies - if: ${{ matrix.target != 'aarch64' }} run: | pip install 
--upgrade pip pip install -r pylace/requirements-dev.txt - name: Install pylace - if: ${{ matrix.target != 'aarch64' }} run: | WHEEL_FILE=$(python3 .github/scripts/find_compatible_wheel.py pylace ./dist) echo "Installing $WHEEL_FILE" pip install $WHEEL_FILE - name: Run Tests - if: ${{ matrix.target != 'aarch64' }} run: pytest pylace/tests - name: Upload wheels diff --git a/CHANGELOG.md b/CHANGELOG.md index 71e3698c..e35455b9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,7 +5,33 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). -## [Unreleased] +## [rust-0.8.0] - 2024-04-10 + +### Added + +- Added ability to use the Pitman-Yor prior process + - Users now specify the prior process (and its prior) in the codebook +- `lace_stats::prior_process::Builder` + +### Changed + +- Offloaded a lot of `Assignment` functionality to `PriorProcess` +- `StateAlpha` and `ViewAlpha` transitions are now `StatePriorProcessParams` and `ViewPriorProcessParams` +- Changed `SerializedType` default to `Bincode` +- Moved `Assignment` from `lace_cc` to `lace_stats` + +## [python-0.8.0] - 2024-04-10 + +### Added + +- Added ability to use the Pitman-Yor prior process. +- Added `remove_rows` to `Engine`. +- Added `with_index` to `CodeBook`. 
+ +### Changed + +- `StateAlpha` and `ViewAlpha` transitions are now `StatePriorProcessParams` and `ViewPriorProcessParams` +- Updated Pyo3 version to 0.21 ## [python-0.7.1] - 2024-02-27 @@ -278,7 +304,9 @@ Initial release on [PyPi](https://pypi.org/) Initial release on [crates.io](https://crates.io/) -[unreleased]: https://github.com/promised-ai/lace/compare/python-0.7.1...HEAD +[unreleased]: https://github.com/promised-ai/lace/compare/python-0.8.0...HEAD +[rust-0.8.0]: https://github.com/promised-ai/lace/compare/rust-0.7.0...rust-0.8.0 +[python-0.8.0]: https://github.com/promised-ai/lace/compare/python-0.7.1...python-0.8.0 [python-0.7.1]: https://github.com/promised-ai/lace/compare/python-0.7.0...python-0.7.1 [python-0.7.0]: https://github.com/promised-ai/lace/compare/python-0.6.0...python-0.7.0 [rust-0.7.0]: https://github.com/promised-ai/lace/compare/rust-0.6.0...rust-0.7.0 diff --git a/CITATION.cff b/CITATION.cff index 153a78d9..50abe654 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -34,5 +34,5 @@ keywords: - Bayesian - Machine Learning license: BUSL-1.1 -version: 0.7.0 +version: 0.8.0 date-released: '2024-02-07' diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index a6e2835e..d3f2da7a 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1,5 +1,8 @@ # Contributing to Lace +> :warning: Note that we are in the process of a major refactor and are unlikely +> to accept any contributions apart from simple bugfixes + ## General Guidelines - Don't use getters and setters if it causes indirection in a performance heavy diff --git a/book/lace_preprocess_mdbook_yaml/Cargo.lock b/book/lace_preprocess_mdbook_yaml/Cargo.lock index 7499276e..3197719b 100644 --- a/book/lace_preprocess_mdbook_yaml/Cargo.lock +++ b/book/lace_preprocess_mdbook_yaml/Cargo.lock @@ -2201,14 +2201,15 @@ checksum = "7ffc183a10b4478d04cbbbfc96d0873219d962dd5accaff2ffbd4ceb7df837f4" [[package]] name = "rv" -version = "0.16.4" +version = "0.16.5" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a650cc227f4eb01b043fb7580097d6830c688a8f33cd9eabbced2026d11abf5" +checksum = "c07e0a3b756794c7ea2f05d93760ffb946ff4f94b255d92444d94c19fd71f4ab" dependencies = [ "doc-comment", "lru", "nalgebra", "num", + "num-traits", "peroxide", "rand", "rand_distr", diff --git a/book/lace_preprocess_mdbook_yaml/Cargo.toml b/book/lace_preprocess_mdbook_yaml/Cargo.toml index b202e30a..7a3c0a9e 100644 --- a/book/lace_preprocess_mdbook_yaml/Cargo.toml +++ b/book/lace_preprocess_mdbook_yaml/Cargo.toml @@ -16,8 +16,8 @@ path = "src/main.rs" anyhow = "1.0" clap = "4.2" env_logger = "0.10" -lace_codebook = { path = "../../lace/lace_codebook", version = "0.6.0" } -lace_stats = { path = "../../lace/lace_stats", version = "0.3.0" } +lace_codebook = { path = "../../lace/lace_codebook", version = "0.7.0" } +lace_stats = { path = "../../lace/lace_stats", version = "0.4.0" } log = "0.4" mdbook = "0.4" pulldown-cmark = { version = "0.9", default-features = false } diff --git a/book/lace_preprocess_mdbook_yaml/src/lib.rs b/book/lace_preprocess_mdbook_yaml/src/lib.rs index 4fc562d4..04d1a8c8 100644 --- a/book/lace_preprocess_mdbook_yaml/src/lib.rs +++ b/book/lace_preprocess_mdbook_yaml/src/lib.rs @@ -1,5 +1,3 @@ -use std::collections::HashMap; - use anyhow::anyhow; use log::debug; use mdbook::{ @@ -12,7 +10,15 @@ use regex::Regex; use serde::Deserialize; -type GammaMap = HashMap; +#[derive(Deserialize)] +struct ViewPriorProcess { + pub view_prior_process: lace_codebook::PriorProcess, +} + +#[derive(Deserialize)] +struct StatePriorProcess { + pub state_prior_process: lace_codebook::PriorProcess, +} fn check_deserialize_yaml(input: &str) -> anyhow::Result<()> where @@ -54,17 +60,14 @@ macro_rules! 
check_deserialize_arm { } } -fn check_deserialize_dyn( - input: &str, - type_name: &str, - format: &str, -) -> anyhow::Result<()> { +fn check_deserialize_dyn(input: &str, type_name: &str, format: &str) -> anyhow::Result<()> { check_deserialize_arm!( input, type_name, format, [ - GammaMap, + ViewPriorProcess, + StatePriorProcess, lace_codebook::ColMetadata, lace_codebook::ColMetadataList ] @@ -80,35 +83,21 @@ impl YamlTester { YamlTester } - fn examine_chapter_content( - &self, - content: &str, - re: &Regex, - ) -> anyhow::Result<()> { + fn examine_chapter_content(&self, content: &str, re: &Regex) -> anyhow::Result<()> { let parser = Parser::new(content); let mut code_block = Some(String::new()); for event in parser { match event { - Event::Start(Tag::CodeBlock(CodeBlockKind::Fenced( - ref code_block_string, - ))) => { + Event::Start(Tag::CodeBlock(CodeBlockKind::Fenced(ref code_block_string))) => { if re.is_match(code_block_string) { - debug!( - "YAML Block Start, identifier string={}", - code_block_string - ); + debug!("YAML Block Start, identifier string={}", code_block_string); code_block = Some(String::new()); } } - Event::End(Tag::CodeBlock(CodeBlockKind::Fenced( - ref code_block_string, - ))) => { + Event::End(Tag::CodeBlock(CodeBlockKind::Fenced(ref code_block_string))) => { if let Some(captures) = re.captures(code_block_string) { - debug!( - "Code Block End, identifier string={}", - code_block_string - ); + debug!("Code Block End, identifier string={}", code_block_string); let serialization_format = captures .get(1) @@ -119,21 +108,13 @@ impl YamlTester { .get(2) .ok_or(anyhow!("No deserialize type found"))? 
.as_str(); - debug!( - "Target deserialization type is {}", - target_type - ); + debug!("Target deserialization type is {}", target_type); let final_block = code_block.take(); - let final_block = - final_block.ok_or(anyhow!("No YAML text found"))?; + let final_block = final_block.ok_or(anyhow!("No YAML text found"))?; debug!("Code block ended up as\n{}", final_block); - check_deserialize_dyn( - &final_block, - target_type, - serialization_format, - )?; + check_deserialize_dyn(&final_block, target_type, serialization_format)?; } } Event::Text(ref text) => { @@ -154,11 +135,7 @@ impl Preprocessor for YamlTester { "lace-yaml-tester" } - fn run( - &self, - _ctx: &PreprocessorContext, - book: Book, - ) -> anyhow::Result { + fn run(&self, _ctx: &PreprocessorContext, book: Book) -> anyhow::Result { debug!("Starting the run"); let re = Regex::new(r"^(yaml|json).*,deserializeTo=([^,]+)").unwrap(); for book_item in book.iter() { diff --git a/book/src/SUMMARY.md b/book/src/SUMMARY.md index 98ee2027..b2560037 100644 --- a/book/src/SUMMARY.md +++ b/book/src/SUMMARY.md @@ -16,6 +16,7 @@ - [Data Simulation](./pcc/simulate.md) - [In- and out-of-table operations](./pcc/inouttable.md) - [Adding data to a model](./pcc/add-data.md) + - [Prior processes](./pcc/prior-processes.md) - [Preparing your data](./data/basics.md) - [Codebook reference](./codebook-ref.md) - [Appendix](./appendix/appendix.md) diff --git a/book/src/codebook-ref.md b/book/src/codebook-ref.md index d9b33789..eaf2c8ca 100644 --- a/book/src/codebook-ref.md +++ b/book/src/codebook-ref.md @@ -16,30 +16,50 @@ information about String name of the table. For your reference. -### `state_alpha_prior` +### `state_prior_process` -A gamma prior on the Chinese Restaurant Process (CRP) alpha parameter assigning -columns to views. +The prior process used for assigning columns to views. 
Can either be a Dirichlet process with a Gamma prior on alpha: -Example with a gamma prior +```yaml,deserializeTo=StatePriorProcess +state_prior_process: !dirichlet + alpha_prior: + shape: 1.0 + rate: 1.0 +``` + +or a Pitman-Yor process with a Gamma prior on alpha and a Beta prior on d. -```yaml,deserializeTo=GammaMap -state_alpha_prior: - shape: 1.0 - rate: 1.0 +```yaml,deserializeTo=StatePriorProcess +state_prior_process: !pitman_yor + alpha_prior: + shape: 1.0 + rate: 1.0 + d_prior: + alpha: 0.5 + beta: 0.5 ``` -### `view_alpha_prior` +### `view_prior_process` -A gamma prior on the Chinese Restaurant Process (CRP) alpha parameter assigning -rows within views to categories. +The prior process used for assigning rows to categories. Can either be a Dirichlet process with a Gamma prior on alpha: + +```yaml,deserializeTo=ViewPriorProcess +view_prior_process: !dirichlet + alpha_prior: + shape: 1.0 + rate: 1.0 +``` -Example with a gamma prior +or a Pitman-Yor process with a Gamma prior on alpha and a Beta prior on d. -```yaml,deserializeTo=GammaMap -view_alpha_prior: - shape: 1.0 - rate: 1.0 +```yaml,deserializeTo=ViewPriorProcess +view_prior_process: !pitman_yor + alpha_prior: + shape: 1.0 + rate: 1.0 + d_prior: + alpha: 0.5 + beta: 0.5 ``` ### `col_metadata` diff --git a/book/src/pcc/img/crp.png b/book/src/pcc/img/crp.png new file mode 100644 index 00000000..20419171 Binary files /dev/null and b/book/src/pcc/img/crp.png differ diff --git a/book/src/pcc/img/pyp.png b/book/src/pcc/img/pyp.png new file mode 100644 index 00000000..a0867cfe Binary files /dev/null and b/book/src/pcc/img/pyp.png differ diff --git a/book/src/pcc/prior-processes.md b/book/src/pcc/prior-processes.md new file mode 100644 index 00000000..3a6fb8a7 --- /dev/null +++ b/book/src/pcc/prior-processes.md @@ -0,0 +1,21 @@ +# Prior Processes + +In Lace (and in Bayesian nonparametrics) we put a prior on the number of parameters. 
This *prior process* formalizes how instances are distributed to an unknown number of categories. Lace gives you two options: + +- The one-parameter Dirichlet process, `DP(α)` +- The two-parameter Pitman-Yor process, `PYP(α, d)` + +The Dirichlet process more heavily penalizes new categories with an exponential fall off while the Pitman-Yor process has a power law fall off in the number of categories. When d = 0, Pitman-Yor is equivalent to the Dirichlet process. + +![Dirichlet Process](img/crp.png) + +**Figure**: Category ID (y-axis) by instance number (x-axis) for Dirichlet process draws for various values of alpha. + +Pitman-Yor may fit the data better but (and because) it will create more parameters, which will cause model training to take longer. + +![Pitman-Yor Process](img/pyp.png) + +**Figure**: Category ID (y-axis) by instance number (x-axis) for Pitman-Yor process draws for various values of alpha and d. + + +For those looking for a good introduction to prior processes, [this slide deck](https://www.gatsby.ucl.ac.uk/~ywteh/teaching/probmodels/lecture5bnp.pdf) from Yee Whye Teh is a good resource. 
diff --git a/book/src/workflow/codebook.md b/book/src/workflow/codebook.md index 98205e0b..2788c219 100644 --- a/book/src/workflow/codebook.md +++ b/book/src/workflow/codebook.md @@ -49,7 +49,7 @@ let df = CsvReader::from_path(paths.data) .unwrap(); // Create the default codebook -let codebook = Codebook::from_df(&df, None, None, false).unwrap(); +let codebook = Codebook::from_df(&df, None, None, None, false).unwrap(); ``` diff --git a/book/src/workflow/model.md b/book/src/workflow/model.md index 47370cd1..e13299f5 100644 --- a/book/src/workflow/model.md +++ b/book/src/workflow/model.md @@ -125,7 +125,7 @@ let df = CsvReader::from_path(paths.data) .unwrap(); // Create the default codebook -let codebook = Codebook::from_df(&df, None, None, false).unwrap(); +let codebook = Codebook::from_df(&df, None, None, None, false).unwrap(); // Build an rng let rng = Xoshiro256Plus::from_entropy(); @@ -156,12 +156,12 @@ let run_config = EngineUpdateConfig::new() .n_iters(100) .transitions(vec![ StateTransition::ColumnAssignment(ColAssignAlg::Gibbs), - StateTransition::StateAlpha, + StateTransition::StatePriorProcessParams, StateTransition::RowAssignment(RowAssignAlg::Sams), StateTransition::ComponentParams, StateTransition::RowAssignment(RowAssignAlg::Slice), StateTransition::ComponentParams, - StateTransition::ViewAlphas, + StateTransition::ViewPriorProcessParams, StateTransition::FeaturePriors, ]); @@ -219,7 +219,7 @@ engine.update( save_path="mydata.lace", transitions=[ StateTransition.row_assignment(RowKernel.slice()), - StateTransition.view_alphas(), + StateTransition.view_prior_process_params(), ], ) ``` diff --git a/cli/Cargo.lock b/cli/Cargo.lock index 6873681a..5521fcef 100644 --- a/cli/Cargo.lock +++ b/cli/Cargo.lock @@ -881,7 +881,7 @@ dependencies = [ [[package]] name = "lace" -version = "0.7.0" +version = "0.8.0" dependencies = [ "ctrlc", "dirs", @@ -910,7 +910,7 @@ dependencies = [ [[package]] name = "lace-cli" -version = "0.7.0" +version = "0.8.0" dependencies = [ 
"approx", "clap", @@ -924,7 +924,7 @@ dependencies = [ [[package]] name = "lace_cc" -version = "0.6.0" +version = "0.7.0" dependencies = [ "enum_dispatch", "itertools", @@ -944,7 +944,7 @@ dependencies = [ [[package]] name = "lace_codebook" -version = "0.6.0" +version = "0.7.0" dependencies = [ "lace_consts", "lace_data", @@ -974,7 +974,7 @@ dependencies = [ [[package]] name = "lace_geweke" -version = "0.3.0" +version = "0.4.0" dependencies = [ "indicatif", "lace_stats", @@ -986,7 +986,7 @@ dependencies = [ [[package]] name = "lace_metadata" -version = "0.6.0" +version = "0.7.0" dependencies = [ "bincode", "hex", @@ -1005,7 +1005,7 @@ dependencies = [ [[package]] name = "lace_stats" -version = "0.3.0" +version = "0.4.0" dependencies = [ "itertools", "lace_consts", diff --git a/cli/Cargo.toml b/cli/Cargo.toml index 38d9689b..c4d5b1bf 100644 --- a/cli/Cargo.toml +++ b/cli/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "lace-cli" -version = "0.7.0" +version = "0.8.0" authors = ["Promised AI"] edition = "2021" rust-version = "1.62.0" @@ -17,7 +17,7 @@ name = "lace" path = "src/main.rs" [dependencies] -lace = { path = "../lace", version = "0.7.0", features = ["formats", "ctrlc_handler"]} +lace = { path = "../lace", version = "0.8.0", features = ["formats", "ctrlc_handler"]} clap = { version = "4.3.17", features = ["derive"] } env_logger = "0.10" serde_yaml = "0.9.4" diff --git a/cli/README.md b/cli/README.md index 7fe00563..25f4a47f 100644 --- a/cli/README.md +++ b/cli/README.md @@ -1 +1 @@ -# Lace CLI \ No newline at end of file +# Lace CLI diff --git a/cli/resources/datasets/animals/codebook.yaml b/cli/resources/datasets/animals/codebook.yaml index 9d856418..2c1463b3 100644 --- a/cli/resources/datasets/animals/codebook.yaml +++ b/cli/resources/datasets/animals/codebook.yaml @@ -1,10 +1,12 @@ -table_name: my_table -state_alpha_prior: - shape: 1.0 - rate: 1.0 -view_alpha_prior: - shape: 1.0 - rate: 1.0 +table_name: animals +state_prior_process: !dirichlet + alpha_prior: + 
shape: 1.0 + rate: 1.0 +view_prior_process: !dirichlet + alpha_prior: + shape: 1.0 + rate: 1.0 col_metadata: - name: black coltype: !Categorical diff --git a/cli/resources/datasets/satellites/codebook.yaml b/cli/resources/datasets/satellites/codebook.yaml index 7e750b04..5d7c2390 100644 --- a/cli/resources/datasets/satellites/codebook.yaml +++ b/cli/resources/datasets/satellites/codebook.yaml @@ -1,10 +1,12 @@ -table_name: my_table -state_alpha_prior: - shape: 1.0 - rate: 1.0 -view_alpha_prior: - shape: 1.0 - rate: 1.0 +table_name: satellites +state_prior_process: !dirichlet + alpha_prior: + shape: 1.0 + rate: 1.0 +view_prior_process: !dirichlet + alpha_prior: + shape: 1.0 + rate: 1.0 col_metadata: - name: Country_of_Operator coltype: !Categorical diff --git a/cli/src/opt.rs b/cli/src/opt.rs index 3b782028..53b4b843 100644 --- a/cli/src/opt.rs +++ b/cli/src/opt.rs @@ -21,8 +21,8 @@ pub enum Transition { ColumnAssignment, ComponentParams, RowAssignment, - StateAlpha, - ViewAlphas, + StatePriorProcessParams, + ViewPriorProcessParams, FeaturePriors, } @@ -33,8 +33,8 @@ impl std::str::FromStr for Transition { match s { "column_assignment" => Ok(Self::ColumnAssignment), "row_assignment" => Ok(Self::RowAssignment), - "state_alpha" => Ok(Self::StateAlpha), - "view_alphas" => Ok(Self::ViewAlphas), + "state_prior_process_params" => Ok(Self::StatePriorProcessParams), + "view_prior_process_params" => Ok(Self::ViewPriorProcessParams), "feature_priors" => Ok(Self::FeaturePriors), "component_params" => Ok(Self::ComponentParams), _ => Err(format!("cannot parse '{s}'")), @@ -142,17 +142,17 @@ impl RunArgs { let transitions = match self.transitions { None => vec![ StateTransition::ColumnAssignment(col_alg), - StateTransition::StateAlpha, + StateTransition::StatePriorProcessParams, StateTransition::RowAssignment(row_alg), - StateTransition::ViewAlphas, + StateTransition::ViewPriorProcessParams, StateTransition::FeaturePriors, ], Some(ref ts) => ts .iter() .map(|t| match t { 
Transition::FeaturePriors => StateTransition::FeaturePriors, - Transition::StateAlpha => StateTransition::StateAlpha, - Transition::ViewAlphas => StateTransition::ViewAlphas, + Transition::StatePriorProcessParams => StateTransition::StatePriorProcessParams, + Transition::ViewPriorProcessParams => StateTransition::ViewPriorProcessParams, Transition::ComponentParams => StateTransition::ComponentParams, Transition::RowAssignment => StateTransition::RowAssignment(row_alg), Transition::ColumnAssignment => StateTransition::ColumnAssignment(col_alg), @@ -264,9 +264,6 @@ pub struct CodebookArgs { /// Parquet input filename #[clap(long = "parquet", group = "src")] pub parquet_src: Option, - /// CRP alpha prior on columns and rows - #[clap(long, default_value = "1.0, 1.0")] - pub alpha_prior: GammaParameters, /// Codebook out. May be either json or yaml #[clap(name = "CODEBOOK_OUT")] pub output: PathBuf, @@ -343,16 +340,16 @@ mod tests { #[test] fn view_alphas_from_str() { assert_eq!( - Transition::from_str("view_alphas").unwrap(), - Transition::ViewAlphas + Transition::from_str("view_prior_process_params").unwrap(), + Transition::ViewPriorProcessParams ); } #[test] fn state_alpha_from_str() { assert_eq!( - Transition::from_str("state_alpha").unwrap(), - Transition::StateAlpha + Transition::from_str("state_prior_process_params").unwrap(), + Transition::StatePriorProcessParams ); } diff --git a/cli/src/routes.rs b/cli/src/routes.rs index 0bd09137..f557b2bb 100644 --- a/cli/src/routes.rs +++ b/cli/src/routes.rs @@ -3,7 +3,6 @@ use std::time::Duration; use lace::codebook::Codebook; use lace::metadata::{deserialize_file, serialize_obj}; -use lace::stats::rv::dist::Gamma; use lace::update_handler::{CtrlC, ProgressBar, Timeout}; use lace::{Engine, EngineBuilder}; @@ -23,7 +22,6 @@ pub fn summarize_engine(cmd: opt::SummarizeArgs) -> i32 { String::from("State"), String::from("Iters"), String::from("Views"), - String::from("Alpha"), String::from("Score"), ]; @@ -36,9 +34,8 @@ pub fn 
summarize_engine(cmd: opt::SummarizeArgs) -> i32 { vec![ format!("{id}"), format!("{}", n_iters), - format!("{}", state.asgn.n_cats), - format!("{:.6}", state.asgn.alpha), - format!("{:.6}", state.loglike), + format!("{}", state.asgn().n_cats), + format!("{:.6}", state.score.ln_likelihood + state.score.ln_prior), ] }) .collect(); @@ -202,13 +199,6 @@ pub fn run(cmd: opt::RunArgs) -> i32 { macro_rules! codebook_from { ($path: ident, $fn: ident, $cmd: ident) => {{ - let alpha_prior: Gamma = match $cmd.alpha_prior.try_into() { - Ok(gamma) => gamma, - Err(err) => { - eprint!("Invalid Gamma parameters to CRP prior: {err}"); - return 1; - } - }; if !$path.exists() { eprintln!("Input {:?} not found", $path); return 1; @@ -217,7 +207,8 @@ macro_rules! codebook_from { let codebook = match lace::codebook::formats::$fn( $path, Some($cmd.category_cutoff), - Some(alpha_prior), + None, + None, $cmd.no_hyper, ) { Ok(codebook) => codebook, diff --git a/cli/tests/cli.rs b/cli/tests/cli.rs index 121aece9..c5b21951 100644 --- a/cli/tests/cli.rs +++ b/cli/tests/cli.rs @@ -104,12 +104,12 @@ mod run { " --- table_name: my_data - state_alpha_prior: - !Gamma + state_prior_process: !dirichlet + alpha_prior: shape: 1.0 rate: 1.0 - view_alpha_prior: - !Gamma + view_prior_process: !dirichlet + alpha_prior: shape: 1.0 rate: 1.0 col_metadata: @@ -161,14 +161,17 @@ mod run { " --- table_name: my_data - state_alpha_prior: - !Gamma + state_prior_process: !dirichlet + alpha_prior: shape: 1.0 rate: 1.0 - view_alpha_prior: - !Gamma + view_prior_process: !pitman_yor + alpha_prior: shape: 1.0 rate: 1.0 + d_prior: + alpha: 0.5 + beta: 0.5 col_metadata: - name: z coltype: @@ -403,9 +406,9 @@ mod run { save_config: ~ transitions: - !row_assignment slice - - !view_alphas + - !view_prior_process_params - !column_assignment finite_cpu - - !state_alpha + - !state_prior_process_params - !feature_priors " ); @@ -551,7 +554,7 @@ mod run { .arg("--run-config") .arg(config.path()) .arg("--transitions") - 
.arg("state_alpha,row_assignment") + .arg("state_prior_process_params,row_assignment") .arg(dirname) .output() .expect("failed to execute process"); @@ -734,7 +737,7 @@ mod run { .arg("-q") .args(["--n-states", "4", "--n-iters", "10", "--flat-columns"]) .arg("--transitions") - .arg("state_alpha,view_alphas,component_params,row_assignment,feature_priors") + .arg("state_prior_process_params,view_prior_process_params,component_params,row_assignment,feature_priors") .arg("--csv") .arg(csv::animals()) .arg(dir.path().to_str().unwrap()) @@ -833,40 +836,6 @@ macro_rules! test_codebook_under_fmt { }); assert!(no_hypers); } - - #[test] - fn with_good_alpha_params() { - let fileout = tempfile::Builder::new().suffix(".yaml").tempfile().unwrap(); - let output = Command::new(LACE_CMD) - .arg("codebook") - .arg($flag) - .arg($crate::$mod::animals()) - .arg(fileout.path().to_str().unwrap()) - .arg("--alpha-prior") - .arg("2.3, 2.1") - .output() - .expect("failed to execute process"); - - assert!(output.status.success()); - } - - #[test] - fn with_bad_alpha_params() { - let fileout = tempfile::Builder::new().suffix(".yaml").tempfile().unwrap(); - let output = Command::new(LACE_CMD) - .arg("codebook") - .arg($flag) - .arg($crate::$mod::animals()) - .arg(fileout.path().to_str().unwrap()) - .arg("--alpha-prior") - .arg("2.3, -0.1") - .output() - .expect("failed to execute process"); - - assert!(!output.status.success()); - let err = String::from_utf8_lossy(output.stderr.as_slice()); - assert!(err.contains("must be greater than zero")); - } } }; } diff --git a/lace/Cargo.lock b/lace/Cargo.lock index e8143bab..b187c373 100644 --- a/lace/Cargo.lock +++ b/lace/Cargo.lock @@ -610,12 +610,12 @@ dependencies = [ [[package]] name = "darling" -version = "0.20.3" +version = "0.20.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0209d94da627ab5605dcccf08bb18afa5009cfbef48d8a8b7d7bdbc79be25c5e" +checksum = 
"54e36fcd13ed84ffdfda6f5be89b31287cbb80c439841fe69e04841435464391" dependencies = [ - "darling_core 0.20.3", - "darling_macro 0.20.3", + "darling_core 0.20.8", + "darling_macro 0.20.8", ] [[package]] @@ -634,9 +634,9 @@ dependencies = [ [[package]] name = "darling_core" -version = "0.20.3" +version = "0.20.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "177e3443818124b357d8e76f53be906d60937f0d3a90773a664fa63fa253e621" +checksum = "9c2cf1c23a687a1feeb728783b993c4e1ad83d99f351801977dd809b48d0a70f" dependencies = [ "fnv", "ident_case", @@ -659,11 +659,11 @@ dependencies = [ [[package]] name = "darling_macro" -version = "0.20.3" +version = "0.20.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "836a9bbc7ad63342d6d6e7b815ccab164bc77a2d95d84bc3117a8c0d5c98e2d5" +checksum = "a668eda54683121533a393014d8692171709ff57a7d61f187b6e782719f8933f" dependencies = [ - "darling_core 0.20.3", + "darling_core 0.20.8", "quote", "syn 2.0.48", ] @@ -1141,7 +1141,7 @@ dependencies = [ [[package]] name = "lace" -version = "0.7.0" +version = "0.8.0" dependencies = [ "approx", "clap", @@ -1176,9 +1176,10 @@ dependencies = [ [[package]] name = "lace_cc" -version = "0.6.0" +version = "0.7.0" dependencies = [ "approx", + "clap", "criterion", "enum_dispatch", "itertools 0.12.0", @@ -1188,17 +1189,19 @@ dependencies = [ "lace_geweke", "lace_stats", "lace_utils", + "plotly", "rand", "rand_xoshiro", "rayon", "serde", + "serde_json", "special", "thiserror", ] [[package]] name = "lace_codebook" -version = "0.6.0" +version = "0.7.0" dependencies = [ "indoc", "lace_consts", @@ -1232,7 +1235,7 @@ dependencies = [ [[package]] name = "lace_geweke" -version = "0.3.0" +version = "0.4.0" dependencies = [ "indicatif", "lace_stats", @@ -1244,7 +1247,7 @@ dependencies = [ [[package]] name = "lace_metadata" -version = "0.6.0" +version = "0.7.0" dependencies = [ "bincode", "hex", @@ -1264,7 +1267,7 @@ dependencies = [ [[package]] name = "lace_stats" 
-version = "0.3.0" +version = "0.4.0" dependencies = [ "approx", "criterion", @@ -1622,6 +1625,12 @@ dependencies = [ "serde", ] +[[package]] +name = "num-conv" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9" + [[package]] name = "num-integer" version = "0.1.45" @@ -2522,7 +2531,7 @@ version = "2.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "881b6f881b17d13214e5d494c939ebab463d01264ce1811e9d4ac3a882e7695f" dependencies = [ - "darling 0.20.3", + "darling 0.20.8", "proc-macro2", "quote", "syn 2.0.48", @@ -2764,12 +2773,13 @@ dependencies = [ [[package]] name = "time" -version = "0.3.31" +version = "0.3.34" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f657ba42c3f86e7680e53c8cd3af8abbe56b5491790b46e22e19c0d57463583e" +checksum = "c8248b6521bb14bc45b4067159b9b6ad792e2d6d754d6c41fb50e29fefe38749" dependencies = [ "deranged", "itoa", + "num-conv", "powerfmt", "serde", "time-core", @@ -2784,10 +2794,11 @@ checksum = "ef927ca75afb808a4d64dd374f00a2adf8d0fcff8e7b184af886c3c87ec4a3f3" [[package]] name = "time-macros" -version = "0.2.16" +version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26197e33420244aeb70c3e8c78376ca46571bc4e701e4791c2cd9f57dcb3a43f" +checksum = "7ba3a3ef41e6672a2f0f001392bb5dcd3ff0a9992d618ca761a11c3121547774" dependencies = [ + "num-conv", "time-core", ] diff --git a/lace/Cargo.toml b/lace/Cargo.toml index 62cd52e0..c88d3bbf 100644 --- a/lace/Cargo.toml +++ b/lace/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "lace" -version = "0.7.0" +version = "0.8.0" authors = ["Promised AI"] build = "build.rs" edition = "2021" @@ -31,14 +31,14 @@ name = "lace" path = "src/lib.rs" [dependencies] -lace_cc = { path = "lace_cc", version = "0.6.0" } +lace_cc = { path = "lace_cc", version = "0.7.0" } lace_utils = { path = "lace_utils", version = 
"0.3.0" } -lace_stats = { path = "lace_stats", version = "0.3.0" } -lace_codebook = { path = "lace_codebook", version = "0.6.0", default_features=false} -lace_geweke = { path = "lace_geweke", version = "0.3.0" } +lace_stats = { path = "lace_stats", version = "0.4.0" } +lace_codebook = { path = "lace_codebook", version = "0.7.0", default_features=false} +lace_geweke = { path = "lace_geweke", version = "0.4.0" } lace_consts = { path = "lace_consts", version = "0.2.1" } lace_data = { path = "lace_data", version = "0.3.0" } -lace_metadata = { path = "lace_metadata", version = "0.6.0" } +lace_metadata = { path = "lace_metadata", version = "0.7.0" } dirs = { version="5", optional = true} num = "0.4" rand_xoshiro = { version="0.6", features = ["serde1"] } diff --git a/lace/benches/insert_data.rs b/lace/benches/insert_data.rs index d6509305..0a3e04e3 100644 --- a/lace/benches/insert_data.rs +++ b/lace/benches/insert_data.rs @@ -42,8 +42,8 @@ fn build_engine(nrows: usize, ncols: usize) -> Engine { let codebook = Codebook { table_name: "table".into(), - state_alpha_prior: None, - view_alpha_prior: None, + state_prior_process: None, + view_prior_process: None, col_metadata, comments: None, row_names: (0..nrows) diff --git a/lace/examples/column_geweke.rs b/lace/examples/column_geweke.rs index 878622b7..10d3add4 100644 --- a/lace/examples/column_geweke.rs +++ b/lace/examples/column_geweke.rs @@ -1,5 +1,6 @@ use lace::prelude::*; use lace_geweke::*; +use lace_stats::prior_process::Builder as AssignmentBuilder; use lace_stats::rv::dist::{ Categorical, Gaussian, NormalInvChiSquared, SymmetricDirichlet, }; @@ -15,10 +16,10 @@ fn main() { // The column model uses an assignment as its setting. We'll draw a // 50-length assignment from the prior. 
let transitions = vec![ - ViewTransition::Alpha, + ViewTransition::PriorProcessParams, ViewTransition::RowAssignment(RowAssignAlg::Slice), ]; - let asgn = AssignmentBuilder::new(10).flat().build().unwrap(); + let asgn = AssignmentBuilder::new(10).flat().build().unwrap().asgn; let settings = ColumnGewekeSettings::new(asgn, transitions); diff --git a/lace/examples/shapes.rs b/lace/examples/shapes.rs index 366c3697..346b0c53 100644 --- a/lace/examples/shapes.rs +++ b/lace/examples/shapes.rs @@ -108,9 +108,14 @@ mod requires_formats { // generate codebook println!("Generating codebook"); - let codebook = - lace_codebook::data::codebook_from_csv(f.path(), None, None, false) - .unwrap(); + let codebook = lace_codebook::data::codebook_from_csv( + f.path(), + None, + None, + None, + false, + ) + .unwrap(); // generate engine println!("Constructing Engine"); @@ -143,7 +148,7 @@ mod requires_formats { engine .states .iter() - .for_each(|state| print!("{} ", state.views[0].asgn.n_cats)); + .for_each(|state| print!("{} ", state.views[0].asgn().n_cats)); println!("\nPlotting"); plot(xs_in, ys_in, xs_sim, ys_sim); diff --git a/lace/lace_cc/Cargo.toml b/lace/lace_cc/Cargo.toml index 08078830..5c092ca1 100644 --- a/lace/lace_cc/Cargo.toml +++ b/lace/lace_cc/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "lace_cc" -version = "0.6.0" +version = "0.7.0" authors = ["Promised AI"] edition = "2021" exclude = ["tests/*", "resources/test/*", "target/*"] @@ -11,11 +11,11 @@ description = "Core of the Lace cross-categorization engine library" [dependencies] lace_utils = { path = "../lace_utils", version = "0.3.0" } -lace_stats = { path = "../lace_stats", version = "0.3.0" } -lace_geweke = { path = "../lace_geweke", version = "0.3.0" } +lace_stats = { path = "../lace_stats", version = "0.4.0" } +lace_geweke = { path = "../lace_geweke", version = "0.4.0" } lace_consts = { path = "../lace_consts", version = "0.2.1" } lace_data = { path = "../lace_data", version = "0.3.0" } -lace_codebook = { path 
= "../lace_codebook", version = "0.6.0" } +lace_codebook = { path = "../lace_codebook", version = "0.7.0" } rand = {version="0.8", features=["serde1"]} rayon = "1.5" serde = { version = "1", features = ["derive"] } @@ -28,6 +28,9 @@ itertools = "0.12" [dev-dependencies] approx = "0.5.1" criterion = "0.5" +clap = { version = "4.3.17", features = ["derive"] } +plotly = "0.8" +serde_json = "1" [[bench]] name = "state_types" diff --git a/lace/examples/state_geweke.rs b/lace/lace_cc/examples/state_geweke.rs similarity index 75% rename from lace/examples/state_geweke.rs rename to lace/lace_cc/examples/state_geweke.rs index e15bd81c..57d7ced7 100644 --- a/lace/examples/state_geweke.rs +++ b/lace/lace_cc/examples/state_geweke.rs @@ -1,27 +1,25 @@ +use std::path::PathBuf; + use clap::Parser; -use lace::prelude::*; -use lace_cc::state::StateGewekeSettings; -use lace_geweke::GewekeTester; use plotly::common::Mode; use plotly::layout::Layout; use plotly::{Plot, Scatter}; use rand::SeedableRng; use rand_xoshiro::Xoshiro256Plus; +use lace_cc::alg::{ColAssignAlg, RowAssignAlg}; +use lace_cc::feature::FType; +use lace_cc::state::{State, StateGewekeSettings}; +use lace_cc::transition::StateTransition; +use lace_geweke::GewekeTester; +use lace_stats::prior_process::PriorProcessType; + #[derive(Parser, Debug)] #[clap(rename_all = "kebab")] struct Opt { - #[clap( - long, - default_value = "gibbs", - value_parser = ["finite_cpu", "gibbs", "slice", "sams"], - )] + #[clap(long, default_value = "gibbs")] pub row_alg: RowAssignAlg, - #[clap( - long, - default_value = "gibbs", - value_parser = ["finite_cpu", "gibbs", "slice"], - )] + #[clap(long, default_value = "gibbs")] pub col_alg: ColAssignAlg, #[clap(long, default_value = "50")] pub nrows: usize, @@ -39,6 +37,10 @@ struct Opt { pub no_priors: bool, #[clap(long)] pub plot_var: Option, + #[clap(long, short, default_value = "10000")] + pub niters: usize, + #[clap(long)] + dst: Option, } fn main() { @@ -59,7 +61,12 @@ fn main() { // The 
state's Geweke test settings require the number of rows in the // state (50), and the types of each column. Everything else is filled out // automatically. - let mut settings = StateGewekeSettings::new(opt.nrows, ftypes); + let mut settings = StateGewekeSettings::new( + opt.nrows, + ftypes, + PriorProcessType::Dirichlet, + PriorProcessType::Dirichlet, + ); let mut transitions: Vec = Vec::new(); if !opt.no_col_reassign { @@ -67,7 +74,7 @@ fn main() { } if !opt.no_state_alpha { - transitions.push(StateTransition::StateAlpha); + transitions.push(StateTransition::StatePriorProcessParams); } if !opt.no_row_reassign { @@ -75,7 +82,7 @@ fn main() { } if !opt.no_view_alpha { - transitions.push(StateTransition::ViewAlphas); + transitions.push(StateTransition::ViewPriorProcessParams); } if !opt.no_priors { @@ -88,7 +95,7 @@ fn main() { // Initialize a tester given the settings and run. let mut geweke: GewekeTester = GewekeTester::new(settings); - geweke.run(10_000, Some(1), &mut rng); + geweke.run(opt.niters, Some(5), &mut rng); // Reports the deviation from a perfect correspondence between the // forward and posterior CDFs. 
The best score is zero, the worst possible @@ -97,7 +104,7 @@ fn main() { res.report(); if let Some(ref key) = opt.plot_var { - use lace::stats::EmpiricalCdf; + use lace_stats::EmpiricalCdf; use lace_utils::minmax; let (min_f, max_f) = minmax(res.forward.get(key).unwrap()); let (min_p, max_p) = minmax(res.posterior.get(key).unwrap()); @@ -125,4 +132,14 @@ fn main() { plot.set_layout(Layout::new()); plot.show(); } + + if let Some(dst) = opt.dst { + let f = std::fs::OpenOptions::new() + .create(true) + .write(true) + .truncate(true) + .open(dst) + .unwrap(); + serde_json::to_writer(f, &res).unwrap(); + } } diff --git a/lace/examples/view_geweke.rs b/lace/lace_cc/examples/view_geweke.rs similarity index 74% rename from lace/examples/view_geweke.rs rename to lace/lace_cc/examples/view_geweke.rs index a8e07aea..dbc0a4e1 100644 --- a/lace/examples/view_geweke.rs +++ b/lace/lace_cc/examples/view_geweke.rs @@ -1,18 +1,17 @@ use clap::Parser; -use lace::prelude::*; -use lace_cc::view::ViewGewekeSettings; -use lace_geweke::GewekeTester; use rand::SeedableRng; use rand_xoshiro::Xoshiro256Plus; +use lace_cc::alg::RowAssignAlg; +use lace_cc::feature::FType; +use lace_cc::transition::ViewTransition; +use lace_cc::view::{View, ViewGewekeSettings}; +use lace_geweke::GewekeTester; + #[derive(Parser, Debug)] #[clap(rename_all = "kebab")] struct Opt { - #[clap( - long, - default_value = "gibbs", - value_parser = ["finite_cpu", "gibbs", "slice", "sams"], - )] + #[clap(long, default_value = "gibbs")] pub alg: RowAssignAlg, #[clap(short, long, default_value = "20")] pub nrows: usize, @@ -22,6 +21,10 @@ struct Opt { pub no_view_alpha: bool, #[clap(long)] pub no_priors: bool, + #[clap(long)] + pub pitman_yor: bool, + #[clap(long, short = 'i', default_value = "10000")] + pub niters: usize, } fn main() { @@ -43,13 +46,21 @@ fn main() { .transitions .push(ViewTransition::RowAssignment(opt.alg)); } + if !opt.no_view_alpha { - settings.transitions.push(ViewTransition::Alpha); + settings + 
.transitions + .push(ViewTransition::PriorProcessParams); } + if !opt.no_priors { settings.transitions.push(ViewTransition::FeaturePriors); } + if opt.pitman_yor { + settings = settings.with_pitman_yor_process(); + } + settings.transitions.push(ViewTransition::ComponentParams); settings @@ -57,7 +68,7 @@ fn main() { // Initialize a tester given the settings and run. let mut geweke: GewekeTester = GewekeTester::new(settings); - geweke.run(10_000, Some(5), &mut rng); + geweke.run(opt.niters, Some(5), &mut rng); // Reports the deviation from a perfect correspondence between the // forward and posterior CDFs. The best score is zero, the worst possible diff --git a/lace/lace_cc/src/builders.rs b/lace/lace_cc/src/builders.rs new file mode 100644 index 00000000..af374e29 --- /dev/null +++ b/lace/lace_cc/src/builders.rs @@ -0,0 +1,143 @@ +use rand::SeedableRng; +use rand_xoshiro::Xoshiro256Plus; +use thiserror::Error; + +use lace_stats::assignment::{Assignment, AssignmentError}; +use lace_stats::prior_process::Process; + +/// Constructs `Assignment`s +#[derive(Clone, Debug)] +pub struct AssignmentBuilder { + n: usize, + asgn: Option>, + prior_process: Option, + seed: Option, +} + +#[derive(Debug, Error, PartialEq)] +pub enum BuildAssignmentError { + #[error("alpha is zero")] + AlphaIsZero, + #[error("non-finite alpha: {alpha}")] + AlphaNotFinite { alpha: f64 }, + #[error("assignment vector is empty")] + EmptyAssignmentVec, + #[error("there are {n_cats} categories but {n} data")] + NLessThanNCats { n: usize, n_cats: usize }, + #[error("invalid assignment: {0}")] + AssignmentError(#[from] AssignmentError), +} + +impl AssignmentBuilder { + /// Create a builder for `n`-length assignments + /// + /// # Arguments + /// - n: the number of data/entries in the assignment + pub fn new(n: usize) -> Self { + AssignmentBuilder { + n, + asgn: None, + prior_process: None, + seed: None, + } + } + + /// Initialize the builder from an assignment vector + /// + /// # Note: + /// The validity 
of `asgn` will not be verified until `build` is called. + pub fn from_vec(asgn: Vec) -> Self { + AssignmentBuilder { + n: asgn.len(), + asgn: Some(asgn), + prior_process: None, + seed: None, + } + } + + /// Add a prior on the `Crp` `alpha` parameter + #[must_use] + pub fn with_prior_process(mut self, process: Process) -> Self { + self.prior_process = Some(process); + self + } + + /// Set the RNG seed + #[must_use] + pub fn with_seed(mut self, seed: u64) -> Self { + self.seed = Some(seed); + self + } + + /// Set the RNG seed from another RNG + #[must_use] + pub fn seed_from_rng(mut self, rng: &mut R) -> Self { + self.seed = Some(rng.next_u64()); + self + } + + /// Use a *flat* assignment with one partition + #[must_use] + pub fn flat(mut self) -> Self { + self.asgn = Some(vec![0; self.n]); + self + } + + /// Use an assignment with `n_cats`, evenly populated partitions/categories + pub fn with_n_cats( + mut self, + n_cats: usize, + ) -> Result { + if n_cats > self.n { + Err(BuildAssignmentError::NLessThanNCats { n: self.n, n_cats }) + } else { + let asgn: Vec = (0..self.n).map(|i| i % n_cats).collect(); + self.asgn = Some(asgn); + Ok(self) + } + } + + /// Build the assignment and consume the builder + pub fn build(self) -> Result { + use lace_stats::prior_process::{Dirichlet, PriorProcessT}; + + let mut rng = self.seed.map_or_else( + || Xoshiro256Plus::from_entropy(), + Xoshiro256Plus::seed_from_u64, + ); + + let process = self.prior_process.unwrap_or_else(|| { + Process::Dirichlet(Dirichlet::from_prior( + lace_consts::general_alpha_prior(), + &mut rng, + )) + }); + + let n = self.n; + let asgn = self + .asgn + .unwrap_or_else(|| process.draw_assignment(n, &mut rng).asgn); + + let n_cats: usize = asgn.iter().max().map(|&m| m + 1).unwrap_or(0); + let mut counts: Vec = vec![0; n_cats]; + for z in &asgn { + counts[*z] += 1; + } + + let asgn_out = Assignment { + asgn, + counts, + n_cats, + }; + + if lace_stats::validate_assignment!(asgn_out) { + Ok(asgn_out) + } else { + 
asgn_out + .validate() + .emit_error() + .map_err(BuildAssignmentError::AssignmentError) + .map(|_| asgn_out) + } + } +} diff --git a/lace/lace_cc/src/feature/column.rs b/lace/lace_cc/src/feature/column.rs index da1a6050..0082fdbe 100644 --- a/lace/lace_cc/src/feature/column.rs +++ b/lace/lace_cc/src/feature/column.rs @@ -4,6 +4,7 @@ use std::vec::Drain; use enum_dispatch::enum_dispatch; use lace_data::{Category, FeatureData}; use lace_data::{Container, SparseContainer}; +use lace_stats::assignment::Assignment; use lace_stats::prior::csd::CsdHyper; use lace_stats::prior::nix::NixHyper; use lace_stats::prior::pg::PgHyper; @@ -20,7 +21,6 @@ use serde::{de::DeserializeOwned, Deserialize, Serialize}; use std::sync::OnceLock; use super::{Component, MissingNotAtRandom}; -use crate::assignment::Assignment; use crate::component::ConjugateComponent; use crate::feature::traits::{Feature, FeatureHelper, TranslateDatum}; use crate::feature::FType; @@ -712,18 +712,22 @@ impl QmcEntropy for ColModel { mod tests { use super::*; - use crate::assignment::AssignmentBuilder; use crate::feature::{Column, Feature}; use lace_data::{FeatureData, SparseContainer}; + use lace_stats::prior_process::{Builder, Dirichlet, Process}; use lace_stats::prior::nix::NixHyper; fn gauss_fixture() -> ColModel { let mut rng = rand::thread_rng(); - let asgn = AssignmentBuilder::new(5) - .with_alpha(1.0) + let asgn = Builder::new(5) + .with_process(Process::Dirichlet(Dirichlet { + alpha: 1.0, + alpha_prior: Gamma::default(), + })) .flat() .build() - .unwrap(); + .unwrap() + .asgn; let data_vec: Vec = vec![0.0, 1.0, 2.0, 3.0, 4.0]; let hyper = NixHyper::default(); let data = SparseContainer::from(data_vec); @@ -736,11 +740,15 @@ mod tests { fn categorical_fixture_u8() -> ColModel { let mut rng = rand::thread_rng(); - let asgn = AssignmentBuilder::new(5) - .with_alpha(1.0) + let asgn = Builder::new(5) + .with_process(Process::Dirichlet(Dirichlet { + alpha: 1.0, + alpha_prior: Gamma::default(), + })) .flat() 
.build() - .unwrap(); + .unwrap() + .asgn; let data_vec: Vec = vec![0, 1, 2, 0, 1]; let data = SparseContainer::from(data_vec); let hyper = CsdHyper::vague(3); diff --git a/lace/lace_cc/src/feature/geweke.rs b/lace/lace_cc/src/feature/geweke.rs index 403a1cd6..6db11103 100644 --- a/lace/lace_cc/src/feature/geweke.rs +++ b/lace/lace_cc/src/feature/geweke.rs @@ -3,6 +3,7 @@ use std::collections::BTreeMap; use lace_data::{Container, SparseContainer}; use lace_geweke::{GewekeModel, GewekeResampleData, GewekeSummarize}; +use lace_stats::assignment::Assignment; use lace_stats::prior::csd::CsdHyper; use lace_stats::prior::nix::NixHyper; use lace_stats::prior::pg::PgHyper; @@ -14,7 +15,6 @@ use lace_stats::rv::traits::Rv; use lace_utils::{mean, std}; use rand::Rng; -use crate::assignment::Assignment; use crate::feature::{ColModel, Column, FType, Feature}; use crate::transition::ViewTransition; diff --git a/lace/lace_cc/src/feature/mnar.rs b/lace/lace_cc/src/feature/mnar.rs index b0035630..eb32fb00 100644 --- a/lace/lace_cc/src/feature/mnar.rs +++ b/lace/lace_cc/src/feature/mnar.rs @@ -1,6 +1,6 @@ use super::{ColModel, Column, Component, FType, Feature, FeatureHelper}; -use crate::assignment::Assignment; use lace_data::{Datum, FeatureData, SparseContainer}; +use lace_stats::assignment::Assignment; use lace_stats::rv::dist::{Bernoulli, Beta}; use lace_stats::MixtureType; use rand::Rng; @@ -278,10 +278,11 @@ mod test { present, }; let mut rng = rand::thread_rng(); - let asgn = crate::assignment::AssignmentBuilder::new(n) + let asgn = lace_stats::prior_process::Builder::new(n) .seed_from_rng(&mut rng) .build() - .unwrap(); + .unwrap() + .asgn; col.reassign(&asgn, &mut rng); (col, asgn) } diff --git a/lace/lace_cc/src/feature/traits.rs b/lace/lace_cc/src/feature/traits.rs index 44ac2643..5f4ceea5 100644 --- a/lace/lace_cc/src/feature/traits.rs +++ b/lace/lace_cc/src/feature/traits.rs @@ -2,6 +2,7 @@ use enum_dispatch::enum_dispatch; use lace_data::FeatureData; use 
lace_data::{Datum, SparseContainer}; +use lace_stats::assignment::Assignment; use lace_stats::prior::csd::CsdHyper; use lace_stats::prior::nix::NixHyper; use lace_stats::prior::pg::PgHyper; @@ -13,7 +14,6 @@ use lace_stats::MixtureType; use rand::Rng; use super::Component; -use crate::assignment::Assignment; use crate::feature::{ColModel, Column, FType}; pub trait TranslateDatum @@ -162,8 +162,8 @@ pub(crate) trait FeatureHelper: Feature { #[cfg(test)] mod tests { use super::*; - use crate::assignment::AssignmentBuilder; use approx::*; + use lace_stats::prior_process::Builder as PriorProcessBuilder; use lace_stats::rv::dist::Gaussian; use lace_stats::rv::traits::Rv; @@ -175,7 +175,7 @@ mod tests { let hyper = NixHyper::default(); let prior = NormalInvChiSquared::new_unchecked(0.0, 1.0, 1.0, 1.0); for _ in 0..100 { - let asgn = AssignmentBuilder::new(n_rows).build().unwrap(); + let asgn = PriorProcessBuilder::new(n_rows).build().unwrap().asgn; let xs: Vec = g.sample(n_rows, &mut rng); let data = SparseContainer::from(xs); let mut feature = diff --git a/lace/lace_cc/src/lib.rs b/lace/lace_cc/src/lib.rs index 7133e9a1..6a8ec542 100644 --- a/lace/lace_cc/src/lib.rs +++ b/lace/lace_cc/src/lib.rs @@ -9,12 +9,12 @@ )] pub mod alg; -pub mod assignment; +// pub mod builders; pub mod component; pub mod config; pub mod feature; pub mod massflip; -pub mod misc; +// pub mod misc; pub mod state; pub mod traits; pub mod transition; diff --git a/lace/lace_cc/src/misc.rs b/lace/lace_cc/src/misc.rs deleted file mode 100644 index fb1a85bc..00000000 --- a/lace/lace_cc/src/misc.rs +++ /dev/null @@ -1,135 +0,0 @@ -use lace_stats::rv::dist::Beta; -use lace_stats::rv::misc::pflip; -use lace_stats::rv::traits::Rv; -use rand::Rng; -use serde::{Deserialize, Serialize}; - -const MAX_STICK_BREAKING_ITERS: u16 = 1000; - -#[derive(Debug, Clone, Hash, PartialEq, Eq, Ord, PartialOrd)] -pub struct CrpDraw { - pub asgn: Vec, - pub counts: Vec, - pub n_cats: usize, -} - -/// Draw from Chinese 
Restaraunt Process -pub fn crp_draw(n: usize, alpha: f64, rng: &mut R) -> CrpDraw { - let mut n_cats = 0; - let mut weights: Vec = vec![]; - let mut asgn: Vec = Vec::with_capacity(n); - - for _ in 0..n { - weights.push(alpha); - let k = pflip(&weights, 1, rng)[0]; - asgn.push(k); - - if k == n_cats { - weights[n_cats] = 1.0; - n_cats += 1; - } else { - weights.truncate(n_cats); - weights[k] += 1.0; - } - } - // convert weights to counts, correcting for possible floating point - // errors - let counts: Vec = - weights.iter().map(|w| (w + 0.5) as usize).collect(); - - CrpDraw { - asgn, - counts, - n_cats, - } -} - -/// The stick breaking algorithm has exceeded the max number of iterations. -#[derive(Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub struct TheStickIsDust(u16); - -/// Append new dirchlet weights by stick breaking until the new weight is less -/// than u* -/// -/// **NOTE** This function is only for the slice reassignment kernel. It cuts out all -/// weights that are less that u*, so the sum of the weights will not be 1. -pub fn sb_slice_extend( - mut weights: Vec, - alpha: f64, - u_star: f64, - mut rng: &mut R, -) -> Result, TheStickIsDust> { - let mut b_star = weights.pop().unwrap(); - - // If α is low and we do the dirichlet update w ~ Dir(n_1, ..., n_k, α), - // the final weight will often be zero. In that case, we're done. 
- if b_star <= 1E-16 { - weights.push(b_star); - return Ok(weights); - } - - let beta = Beta::new(1.0, alpha).unwrap(); - - let mut iters: u16 = 0; - loop { - let vk: f64 = beta.draw(&mut rng); - let bk = vk * b_star; - b_star *= 1.0 - vk; - - if bk >= u_star { - weights.push(bk); - } - - if b_star < u_star { - return Ok(weights); - } - - iters += 1; - if iters > MAX_STICK_BREAKING_ITERS { - return Err(TheStickIsDust(MAX_STICK_BREAKING_ITERS)); - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - - const TOL: f64 = 1E-12; - - mod sb_slice { - use super::*; - - #[test] - fn should_return_input_weights_if_alpha_is_zero() { - let mut rng = rand::thread_rng(); - let weights_in: Vec = vec![0.8, 0.2, 0.0]; - let weights_out = - sb_slice_extend(weights_in.clone(), 1.0, 0.2, &mut rng) - .unwrap(); - let good = weights_in - .iter() - .zip(weights_out.iter()) - .all(|(wi, wo)| (wi - wo).abs() < TOL); - assert!(good); - } - - #[test] - fn should_return_error_for_zero_u_star() { - let mut rng = rand::thread_rng(); - let weights_in: Vec = vec![0.8, 0.2]; - let u_star = 0.0; - let res = sb_slice_extend(weights_in, 1.0, u_star, &mut rng); - assert!(res.is_err()); - } - - #[test] - fn smoke() { - let mut rng = rand::thread_rng(); - let weights_in: Vec = vec![0.8, 0.2]; - let u_star = 0.1; - let res = sb_slice_extend(weights_in, 1.0, u_star, &mut rng); - assert!(res.is_ok()); - } - } -} diff --git a/lace/lace_cc/src/state.rs b/lace/lace_cc/src/state.rs index 98dd4d63..baddd2dc 100644 --- a/lace/lace_cc/src/state.rs +++ b/lace/lace_cc/src/state.rs @@ -1,11 +1,17 @@ mod builder; pub use builder::{BuildStateError, Builder}; +use lace_consts::geweke_alpha_prior; use std::convert::TryInto; use std::f64::NEG_INFINITY; use lace_data::{Datum, FeatureData}; -use lace_stats::rv::dist::{Dirichlet, Gamma}; +use lace_stats::assignment::Assignment; +use lace_stats::prior_process::Builder as AssignmentBuilder; +use lace_stats::prior_process::{ + PriorProcess, PriorProcessT, PriorProcessType, 
Process, +}; +use lace_stats::rv::dist::Dirichlet; use lace_stats::rv::misc::ln_pflip; use lace_stats::rv::traits::*; use lace_stats::MixtureType; @@ -17,7 +23,6 @@ use rayon::prelude::*; use serde::{Deserialize, Serialize}; use crate::alg::{ColAssignAlg, RowAssignAlg}; -use crate::assignment::{Assignment, AssignmentBuilder}; use crate::config::StateUpdateConfig; use crate::feature::Component; use crate::feature::{ColModel, FType, Feature}; @@ -36,32 +41,25 @@ pub struct StateDiagnostics { pub logprior: Vec, } +#[derive(Serialize, Deserialize, Clone, Debug, Default)] +pub struct StateScoreComponents { + pub ln_likelihood: f64, + pub ln_prior: f64, + pub ln_state_prior_process: f64, + pub ln_view_prior_process: f64, +} + /// A cross-categorization state #[derive(Serialize, Deserialize, Clone, Debug)] pub struct State { /// The views of columns pub views: Vec, /// The assignment of columns to views - pub asgn: Assignment, + pub prior_process: PriorProcess, /// The weights of each view in the column mixture pub weights: Vec, - /// The prior on the view CRP alphas - pub view_alpha_prior: Gamma, - /// The log likeihood of the data under the current assignment - #[serde(default)] - pub loglike: f64, - /// The log prior likelihood of component parameters under the prior and of - /// feature prior parameters under the hyperprior #[serde(default)] - pub log_prior: f64, - /// The log prior likelihood of the row assignments under CRP and of the CRP - /// alpha param under the hyperprior - #[serde(default)] - pub log_view_alpha_prior: f64, - /// The log prior likelihood of column assignment under CRP and of the state - /// CRP alpha param under the hyperprior - #[serde(default)] - pub log_state_alpha_prior: f64, + pub score: StateScoreComponents, /// The running diagnostics pub diagnostics: StateDiagnostics, } @@ -70,28 +68,28 @@ unsafe impl Send for State {} unsafe impl Sync for State {} impl State { - pub fn new( - views: Vec, - asgn: Assignment, - view_alpha_prior: Gamma, - 
) -> Self { - let weights = asgn.weights(); + pub fn new(views: Vec, prior_process: PriorProcess) -> Self { + let weights = prior_process.weight_vec(false); let mut state = State { views, - asgn, + prior_process, weights, - view_alpha_prior, - loglike: 0.0, - log_prior: 0.0, - log_state_alpha_prior: 0.0, - log_view_alpha_prior: 0.0, + score: StateScoreComponents::default(), diagnostics: StateDiagnostics::default(), }; - state.loglike = state.loglike(); + state.score.ln_likelihood = state.loglike(); state } + pub fn asgn(&self) -> &Assignment { + &self.prior_process.asgn + } + + pub fn asgn_mut(&mut self) -> &mut Assignment { + &mut self.prior_process.asgn + } + /// Create a new `Builder` for generating a new `State`. pub fn builder() -> Builder { Builder::new() @@ -100,46 +98,38 @@ impl State { /// Draw a new `State` from the prior pub fn from_prior( mut ftrs: Vec, - state_alpha_prior: Gamma, - view_alpha_prior: Gamma, + state_process: Process, + view_process: Process, rng: &mut R, ) -> Self { let n_cols = ftrs.len(); let n_rows = ftrs.first().map(|f| f.len()).unwrap_or(0); - let asgn = AssignmentBuilder::new(n_cols) - .with_prior(state_alpha_prior) - .seed_from_rng(rng) - .build() - .unwrap(); - - let mut views: Vec = (0..asgn.n_cats) + let prior_process = + PriorProcess::from_process(state_process, n_cols, rng); + let mut views: Vec = (0..prior_process.asgn.n_cats) .map(|_| { view::Builder::new(n_rows) - .alpha_prior(view_alpha_prior.clone()) + .prior_process(view_process.clone()) .seed_from_rng(rng) .build() }) .collect(); // TODO: Can we parallellize this? 
- for (&v, ftr) in asgn.asgn.iter().zip(ftrs.drain(..)) { + for (&v, ftr) in prior_process.asgn.iter().zip(ftrs.drain(..)) { views[v].init_feature(ftr, rng); } - let weights = asgn.weights(); + let weights = prior_process.weight_vec(false); let mut state = State { views, - asgn, + prior_process, weights, - view_alpha_prior, - loglike: 0.0, - log_prior: 0.0, - log_state_alpha_prior: 0.0, - log_view_alpha_prior: 0.0, + score: StateScoreComponents::default(), diagnostics: StateDiagnostics::default(), }; - state.loglike = state.loglike(); + state.score.ln_likelihood = state.loglike(); state } @@ -154,14 +144,14 @@ impl State { /// Get a reference to the features at `col_ix` #[inline] pub fn feature(&self, col_ix: usize) -> &ColModel { - let view_ix = self.asgn.asgn[col_ix]; + let view_ix = self.asgn().asgn[col_ix]; &self.views[view_ix].ftrs[&col_ix] } /// Get a mutable reference to the features at `col_ix` #[inline] pub fn feature_mut(&mut self, col_ix: usize) -> &mut ColModel { - let view_ix = self.asgn.asgn[col_ix]; + let view_ix = self.asgn().asgn[col_ix]; self.views[view_ix].ftrs.get_mut(&col_ix).unwrap() } @@ -169,7 +159,7 @@ impl State { #[inline] pub fn feature_as_mixture(&self, col_ix: usize) -> MixtureType { let weights = { - let view_ix = self.asgn.asgn[col_ix]; + let view_ix = self.asgn().asgn[col_ix]; self.views[view_ix].weights.clone() }; self.feature(col_ix).to_mixture(weights) @@ -206,7 +196,7 @@ impl State { /// Get the feature type (`FType`) of the column at `col_ix` #[inline] pub fn ftype(&self, col_ix: usize) -> FType { - let view_ix = self.asgn.asgn[col_ix]; + let view_ix = self.asgn().asgn[col_ix]; self.views[view_ix].ftrs[&col_ix].ftype() } @@ -223,23 +213,24 @@ impl State { StateTransition::RowAssignment(alg) => { self.reassign_rows(*alg, rng); } - StateTransition::StateAlpha => { - self.log_state_alpha_prior = self - .asgn - .update_alpha(lace_consts::MH_PRIOR_ITERS, rng); + StateTransition::StatePriorProcessParams => { + // FIXME: Add to 
probability? + self.score.ln_state_prior_process = + self.prior_process.update_params(rng); } - StateTransition::ViewAlphas => { - self.log_view_alpha_prior = self.update_view_alphas(rng); + StateTransition::ViewPriorProcessParams => { + self.score.ln_view_prior_process = + self.update_view_prior_process_params(rng); } StateTransition::FeaturePriors => { - self.log_prior = self.update_feature_priors(rng); + self.score.ln_prior = self.update_feature_priors(rng); } StateTransition::ComponentParams => { self.update_component_params(rng); } } } - self.loglike = self.loglike(); + self.score.ln_likelihood = self.loglike(); } fn reassign_rows( @@ -260,8 +251,11 @@ impl State { } #[inline] - fn update_view_alphas(&mut self, rng: &mut R) -> f64 { - self.views.iter_mut().map(|v| v.update_alpha(rng)).sum() + fn update_view_prior_process_params(&mut self, rng: &mut R) -> f64 { + self.views + .iter_mut() + .map(|v| v.update_prior_process_params(rng)) + .sum() } #[inline] @@ -301,10 +295,10 @@ impl State { } pub fn push_diagnostics(&mut self) { - self.diagnostics.loglike.push(self.loglike); - let log_prior = self.log_prior - + self.log_view_alpha_prior - + self.log_state_alpha_prior; + self.diagnostics.loglike.push(self.score.ln_likelihood); + let log_prior = self.score.ln_prior + + self.score.ln_view_prior_process + + self.score.ln_state_prior_process; self.diagnostics.logprior.push(log_prior); } @@ -312,11 +306,11 @@ impl State { pub fn flatten_cols(&mut self, mut rng: &mut R) { let n_cols = self.n_cols(); let new_asgn_vec = vec![0; n_cols]; - let n_cats = self.asgn.n_cats; + let n_cats = self.asgn().n_cats; let ftrs = { let mut ftrs: Vec = Vec::with_capacity(n_cols); - for (i, &v) in self.asgn.asgn.iter().enumerate() { + for (i, &v) in self.prior_process.asgn.asgn.iter().enumerate() { ftrs.push( self.views[v].remove_feature(i).expect("Feature missing"), ); @@ -340,6 +334,7 @@ impl State { } ColAssignAlg::Gibbs => { self.reassign_cols_gibbs(transitions, rng); + // // FIXME: 
below alg doesn't pass enum tests // self.reassign_cols_gibbs_precomputed(transitions, rng); // NOTE: The oracle functions use the weights to compute probabilities. @@ -347,7 +342,7 @@ impl State { // it does not explicitly update the weights. Non-updated weights means // wrong probabilities. To avoid this, we set the weights by the // partition here. - self.weights = self.asgn.weights(); + self.weights = self.prior_process.weight_vec(false); } ColAssignAlg::Slice => self.reassign_cols_slice(transitions, rng), } @@ -392,10 +387,11 @@ impl State { let p = (k as f64).recip(); ftrs.drain(..).for_each(|mut ftr| { ftr.set_id(self.n_cols()); - self.asgn.push_unassigned(); + self.asgn_mut().push_unassigned(); // insert into random existing view let view_ix = pflip(&vec![p; k], 1, &mut rng)[0]; - self.asgn.reassign(self.n_cols(), view_ix); + let n_cols = self.n_cols(); + self.asgn_mut().reassign(n_cols, view_ix); self.views[view_ix].insert_feature(ftr, &mut rng); }) } @@ -407,43 +403,43 @@ impl State { .for_each(|view| view.assign_unassigned(&mut rng)); } - fn create_tmp_assigns( + fn create_tmp_assign( &self, - m: usize, - counter_start: usize, - draw_alpha: bool, - rng: &mut R, - ) -> (BTreeMap, Vec) { - let mut seeds = Vec::with_capacity(m); - let tmp_asgns = (0..m) - .map(|i| { - let seed: u64 = rng.gen(); - - // assignment for a hypothetical singleton view - let asgn_bldr = AssignmentBuilder::new(self.n_rows()) - .with_prior(self.view_alpha_prior.clone()) - .with_seed(seed); - - // If we do not want to draw a view alpha, take an existing one from the - // first view. This covers the case were we set the view alphas and - // never transitions them, for example if we are doing geweke on a - // subset of transitions. 
- let tmp_asgn = if draw_alpha { - asgn_bldr - } else { - let alpha = self.views[0].asgn.alpha; - asgn_bldr.with_alpha(alpha) - } - .build() - .unwrap(); + draw_process_params: bool, + seed: u64, + ) -> PriorProcess { + // assignment for a hypothetical singleton view + let mut rng = Xoshiro256Plus::seed_from_u64(seed); + let asgn_bldr = + AssignmentBuilder::new(self.n_rows()).with_seed(rng.gen()); + // If we do not want to draw a view process params, take an + // existing process from the first view. This covers the case + // where we set the view process params and never transitions + // them, for example if we are doing geweke on a subset of + // transitions. + let mut process = self.views[0].prior_process.process.clone(); + if draw_process_params { + process.reset_params(&mut rng); + }; + asgn_bldr.with_process(process).build().unwrap() + } - seeds.push(seed); + fn create_tmp_assigns( + &self, + counter_start: usize, + draw_process_params: bool, + seeds: &[u64], + ) -> BTreeMap { + seeds + .iter() + .enumerate() + .map(|(i, &seed)| { + let tmp_asgn = + self.create_tmp_assign(draw_process_params, seed); (i + counter_start, tmp_asgn) }) - .collect(); - - (tmp_asgns, seeds) + .collect() } /// Insert an unassigned feature into the `State` via the `Gibbs` @@ -452,7 +448,7 @@ impl State { pub fn insert_feature( &mut self, ftr: ColModel, - draw_alpha: bool, + update_process_params: bool, rng: &mut R, ) -> f64 { // Number of singleton features. For assigning to a singleton, we have @@ -460,34 +456,42 @@ impl State { // `m` parameter is the number of samples for the integration. let m: usize = 1; // TODO: Should this be a parameter in ColAssignAlg? 
let col_ix = ftr.id(); + let n_views = self.n_views(); - // crp alpha divided by the number of MC samples - let a_part = (self.asgn.alpha / m as f64).ln(); + // singleton weight divided by the number of MC samples + let p_singleton = + self.prior_process.process.ln_singleton_weight(n_views) + - (m as f64).ln(); // score for each view. We will push the singleton view probabilities // later - let mut logps = self.asgn.log_dirvec(false); + let mut logps = self + .asgn() + .counts + .iter() + .map(|&n_k| self.prior_process.process.ln_gibbs_weight(n_k)) + .collect::>(); // maintain a vec that holds just the likelihoods let mut ftr_logps: Vec = Vec::with_capacity(logps.len()); // TODO: might be faster with an iterator? for (ix, view) in self.views.iter().enumerate() { - let lp = ftr.asgn_score(&view.asgn); + let lp = ftr.asgn_score(view.asgn()); ftr_logps.push(lp); logps[ix] += lp; } - let n_views = self.n_views(); - // here we create the monte carlo estimate for the singleton view - let mut tmp_asgns = - self.create_tmp_assigns(m, n_views, draw_alpha, rng).0; + let mut tmp_asgns = { + let seeds: Vec = (0..m).map(|_| rng.gen()).collect(); + self.create_tmp_assigns(n_views, update_process_params, &seeds) + }; tmp_asgns.iter().for_each(|(_, tmp_asgn)| { - let singleton_logp = ftr.asgn_score(tmp_asgn); + let singleton_logp = ftr.asgn_score(&tmp_asgn.asgn); ftr_logps.push(singleton_logp); - logps.push(a_part + singleton_logp); + logps.push(p_singleton + singleton_logp); }); debug_assert_eq!(n_views + m, logps.len()); @@ -501,7 +505,7 @@ impl State { // This will error if v_new is not in the index, and that is a good. // thing. 
let tmp_asgn = tmp_asgns.remove(&v_new).unwrap(); - let new_view = view::Builder::from_assignment(tmp_asgn) + let new_view = view::Builder::from_prior_process(tmp_asgn) .seed_from_rng(rng) .build(); self.views.push(new_view); @@ -511,7 +515,7 @@ impl State { // we max the new view index to n_views let v_new = v_new.min(n_views); - self.asgn.reassign(col_ix, v_new); + self.asgn_mut().reassign(col_ix, v_new); self.views[v_new].insert_feature(ftr, rng); logp_out } @@ -520,11 +524,11 @@ impl State { pub fn reassign_col_gibbs( &mut self, col_ix: usize, - draw_alpha: bool, + update_process_params: bool, rng: &mut R, ) -> f64 { let ftr = self.extract_ftr(col_ix); - self.insert_feature(ftr, draw_alpha, rng) + self.insert_feature(ftr, update_process_params, rng) } /// Reassign all columns using the Gibbs transition. @@ -542,17 +546,18 @@ impl State { if self.n_cols() == 1 { return; } - // The algorithm is not valid if the columns are not scanned in - // random order - let draw_alpha = transitions + + let update_process_params = transitions .iter() - .any(|&t| t == StateTransition::ViewAlphas); + .any(|&t| t == StateTransition::ViewPriorProcessParams); + // The algorithm is not valid if the columns are not scanned in + // random order let mut col_ixs: Vec = (0..self.n_cols()).collect(); col_ixs.shuffle(rng); col_ixs.drain(..).for_each(|col_ix| { - self.reassign_col_gibbs(col_ix, draw_alpha, rng); + self.reassign_col_gibbs(col_ix, update_process_params, rng); }) } @@ -569,9 +574,9 @@ impl State { // Check if we're drawing view alpha. 
If not, we use the user-specified // alpha value for all temporary, singleton assignments - let draw_alpha = transitions + let draw_process_params = transitions .iter() - .any(|&t| t == StateTransition::ViewAlphas); + .any(|&t| t == StateTransition::ViewPriorProcessParams); // determine the number of columns for which to pre-compute transition // probabilities @@ -592,10 +597,13 @@ impl State { n_cols / batch_size + 1 }; + // FIXME: Only works for Dirichlet Process! // The partial alpha required for the singleton columns. Since we have // `m` singletons to try, we have to divide alpha by m so the singleton // proposal as a whole has the correct mass - let a_part = (self.asgn.alpha / m as f64).ln(); + let n_views = self.n_views(); + let a_part = self.prior_process.process.ln_singleton_weight(n_views) + / (m as f64).ln(); for _ in 0..n_batches { // Number of views at the start of the pre-computation @@ -612,41 +620,43 @@ impl State { .map(|(col_ix, mut t_rng)| { // let mut logps = vec![0_f64; n_views]; - let view_ix = self.asgn.asgn[col_ix]; + let view_ix = self.asgn().asgn[col_ix]; let mut logps: Vec = self .views .iter() .map(|view| { // TODO: we can use Feature::score instead of asgn_score // when the view index is this_view_ix - self.feature(col_ix).asgn_score(&view.asgn) + self.feature(col_ix).asgn_score(view.asgn()) }) .collect(); // Always propose new singletons - let (tmp_asgns, tmp_asgn_seeds) = self.create_tmp_assigns( - m, + let tmp_asgn_seeds: Vec = + (0..m).map(|_| t_rng.gen()).collect(); + + let tmp_asgns = self.create_tmp_assigns( self.n_views(), - draw_alpha, - &mut t_rng, + draw_process_params, + &tmp_asgn_seeds, ); let ftr = self.feature(col_ix); // TODO: might be faster with an iterator? 
for asgn in tmp_asgns.values() { - logps.push(ftr.asgn_score(asgn) + a_part); + logps.push(ftr.asgn_score(&asgn.asgn) + a_part); } (col_ix, view_ix, logps, tmp_asgn_seeds) }) - .collect::>(); + .collect::, Vec)>>(); for _ in 0..pre_comps.len() { let (col_ix, this_view_ix, mut logps, seeds) = pre_comps.pop().unwrap(); - let is_singleton = self.asgn.counts[this_view_ix] == 1; + let is_singleton = self.asgn().counts[this_view_ix] == 1; let n_views = self.n_views(); logps.iter_mut().take(n_views).enumerate().for_each( @@ -654,7 +664,7 @@ impl State { // add the CRP component to the log likelihood. We must // remove the contribution to the counts of the current // column. - let ct = self.asgn.counts[k] as f64; + let ct = self.asgn().counts[k] as f64; let ln_ct = if k == this_view_ix { // Note that if ct == 1 this is a singleton in which // case the CRP component will be log(0), which @@ -670,20 +680,6 @@ impl State { }, ); - // // New views have appeared since we pre-computed - // let logp_views = logps.len() - seeds.len(); - // if n_views > logp_views { - // let ftr = self.feature(col_ix); - // for view_ix in logp_views..n_views { - // let asgn = &self.views[view_ix].asgn; - // let ln_counts = (self.asgn.counts[view_ix] as f64).ln(); - // let logp = ftr.asgn_score(asgn) + ln_counts; - - // // insert the new logps right before the singleton logps - // logps.insert(view_ix, logp); - // } - // } - let mut v_new = ln_pflip(&logps, 1, false, rng)[0]; if v_new != this_view_ix { @@ -691,22 +687,15 @@ impl State { // Moved to a singleton let seed_ix = v_new - n_views; let seed = seeds[seed_ix]; - let asgn_builder = - AssignmentBuilder::new(self.n_rows()) - .with_prior(self.view_alpha_prior.clone()) - .with_seed(seed); - let tmp_asgn = if draw_alpha { - asgn_builder - } else { - asgn_builder.with_alpha(self.asgn.alpha) - } - .build() - .unwrap(); + let prior_process = + self.create_tmp_assign(draw_process_params, seed); + + let new_view = + 
view::Builder::from_prior_process(prior_process) + .seed_from_rng(&mut rng) + .build(); - let new_view = view::Builder::from_assignment(tmp_asgn) - .seed_from_rng(&mut rng) - .build(); self.views.push(new_view); v_new = n_views; @@ -715,7 +704,7 @@ impl State { pre_comps.iter_mut().for_each( |(col_ix, _, ref mut logps, _)| { let logp = self.feature(*col_ix).asgn_score( - &self.views.last().unwrap().asgn, + self.views.last().unwrap().asgn(), ); logps.insert(n_views, logp); }, @@ -749,7 +738,7 @@ impl State { // some reason, Engine::insert_data requires the column to be // rebuilt... let ftr = self.extract_ftr(col_ix); - self.asgn.reassign(col_ix, v_new); + self.asgn_mut().reassign(col_ix, v_new); self.views[v_new].insert_feature(ftr, rng); } } @@ -769,16 +758,16 @@ impl State { let draw_alpha = transitions .iter() - .any(|&t| t == StateTransition::ViewAlphas); + .any(|&t| t == StateTransition::ViewPriorProcessParams); self.resample_weights(true, rng); self.append_empty_view(draw_alpha, rng); let log_weights: Vec = self.weights.iter().map(|w| w.ln()).collect(); - let n_cats = self.asgn.n_cats + 1; + let n_cats = self.asgn().n_cats + 1; let mut ftrs: Vec = Vec::with_capacity(n_cols); - for (i, &v) in self.asgn.asgn.iter().enumerate() { + for (i, &v) in self.prior_process.asgn.asgn.iter().enumerate() { ftrs.push( self.views[v].remove_feature(i).expect("Feature missing"), ); @@ -792,7 +781,7 @@ impl State { .iter() .enumerate() .map(|(v, view)| { - ftr.asgn_score(&view.asgn) + log_weights[v] + ftr.asgn_score(view.asgn()) + log_weights[v] }) .collect::>() }) @@ -813,8 +802,6 @@ impl State { transitions: &[StateTransition], rng: &mut R, ) { - use crate::misc::sb_slice_extend; - if self.n_cols() == 1 { return; } @@ -824,13 +811,14 @@ impl State { let n_cols = self.n_cols(); let weights: Vec = { - let dirvec = self.asgn.dirvec(true); + let dirvec = self.prior_process.weight_vec_unnormed(true); + // FIXME: this only works for Dirichlet process! 
let dir = Dirichlet::new(dirvec).unwrap(); dir.draw(rng) }; let us: Vec = self - .asgn + .asgn() .asgn .iter() .map(|&zi| { @@ -845,14 +833,16 @@ impl State { .fold(1.0, |umin, &ui| if ui < umin { ui } else { umin }); // Variable shadowing - let weights = - sb_slice_extend(weights, self.asgn.alpha, u_star, rng).unwrap(); + let weights = self + .prior_process + .process + .slice_sb_extend(weights, u_star, rng); let n_new_views = weights.len() - self.weights.len(); let n_views = weights.len(); let mut ftrs: Vec = Vec::with_capacity(n_cols); - for (i, &v) in self.asgn.asgn.iter().enumerate() { + for (i, &v) in self.prior_process.asgn.iter().enumerate() { ftrs.push( self.views[v].remove_feature(i).expect("Feature missing"), ); @@ -860,7 +850,7 @@ impl State { let draw_alpha = transitions .iter() - .any(|&t| t == StateTransition::ViewAlphas); + .any(|&t| t == StateTransition::ViewPriorProcessParams); for _ in 0..n_new_views { self.append_empty_view(draw_alpha, rng); } @@ -876,7 +866,7 @@ impl State { .zip(weights.iter()) .map(|(view, w)| { if w >= ui { - ftr.asgn_score(&view.asgn) + ftr.asgn_score(view.asgn()) } else { NEG_INFINITY } @@ -906,7 +896,7 @@ impl State { #[inline] pub fn datum(&self, row_ix: usize, col_ix: usize) -> Datum { - let view_ix = self.asgn.asgn[col_ix]; + let view_ix = self.asgn().asgn[col_ix]; self.views[view_ix].datum(row_ix, col_ix).unwrap() } @@ -915,7 +905,8 @@ impl State { add_empty_component: bool, rng: &mut R, ) { - let dirvec = self.asgn.dirvec(add_empty_component); + // FIXME: this only works for Dirichlet! 
+ let dirvec = self.prior_process.weight_vec(add_empty_component); let dir = Dirichlet::new(dirvec).unwrap(); self.weights = dir.draw(rng) } @@ -938,11 +929,12 @@ impl State { } } - self.asgn + self.asgn_mut() .set_asgn(new_asgn_vec) .expect("new_asgn_vec is invalid"); - for (ftr, &v) in ftrs.drain(..).zip(self.asgn.asgn.iter()) { + for (ftr, &v) in ftrs.drain(..).zip(self.prior_process.asgn.asgn.iter()) + { self.views[v].insert_feature(ftr, rng) } } @@ -950,20 +942,20 @@ impl State { /// Extract a feature from its view, unassign it, and drop the view if it /// is a singleton. fn extract_ftr(&mut self, ix: usize) -> ColModel { - let v = self.asgn.asgn[ix]; - let ct = self.asgn.counts[v]; + let v = self.asgn().asgn[ix]; + let ct = self.asgn().counts[v]; let ftr = self.views[v].remove_feature(ix).unwrap(); if ct == 1 { self.drop_view(v); } - self.asgn.unassign(ix); + self.asgn_mut().unassign(ix); ftr } pub fn component(&self, row_ix: usize, col_ix: usize) -> Component { - let view_ix = self.asgn.asgn[col_ix]; + let view_ix = self.asgn().asgn[col_ix]; let view = &self.views[view_ix]; - let k = view.asgn.asgn[row_ix]; + let k = view.asgn().asgn[row_ix]; view.ftrs[&col_ix].component(k) } @@ -976,23 +968,20 @@ impl State { fn append_empty_view( &mut self, - draw_alpha: bool, // draw the view CRP alpha from the prior + draw_process_params: bool, rng: &mut R, ) { - let asgn_builder = AssignmentBuilder::new(self.n_rows()) - .with_prior(self.view_alpha_prior.clone()); + let asgn_bldr = + AssignmentBuilder::new(self.n_rows()).with_seed(rng.gen()); - let asgn_builder = if draw_alpha { - asgn_builder - } else { - // The alphas should all be the same, so just take one from another view - let alpha = self.views[0].asgn.alpha; - asgn_builder.with_alpha(alpha) + let mut process = self.views[0].prior_process.process.clone(); + if draw_process_params { + process.reset_params(rng); }; - let asgn = asgn_builder.seed_from_rng(rng).build().unwrap(); + let prior_process = 
asgn_bldr.with_process(process).build().unwrap(); - let view = view::Builder::from_assignment(asgn) + let view = view::Builder::from_prior_process(prior_process) .seed_from_rng(rng) .build(); @@ -1001,7 +990,7 @@ impl State { #[inline] pub fn impute_bounds(&self, col_ix: usize) -> Option<(f64, f64)> { - let view_ix = self.asgn.asgn[col_ix]; + let view_ix = self.asgn().asgn[col_ix]; self.views[view_ix].ftrs[&col_ix].impute_bounds() } @@ -1021,7 +1010,7 @@ impl State { row_ix: usize, col_ix: usize, ) -> Option { - let view_ix = self.asgn.asgn[col_ix]; + let view_ix = self.asgn().asgn[col_ix]; self.views[view_ix].remove_datum(row_ix, col_ix) } @@ -1029,7 +1018,7 @@ impl State { if x.is_missing() { self.remove_datum(row_ix, col_ix); } else { - let view_ix = self.asgn.asgn[col_ix]; + let view_ix = self.asgn().asgn[col_ix]; self.views[view_ix].insert_datum(row_ix, col_ix, x); } } @@ -1047,11 +1036,11 @@ impl State { // Delete a column from the table pub fn del_col(&mut self, ix: usize, rng: &mut R) { - let zi = self.asgn.asgn[ix]; - let is_singleton = self.asgn.counts[zi] == 1; + let zi = self.asgn().asgn[ix]; + let is_singleton = self.asgn().counts[zi] == 1; - self.asgn.unassign(ix); - self.asgn.asgn.remove(ix); + self.asgn_mut().unassign(ix); + self.asgn_mut().asgn.remove(ix); if is_singleton { self.views.remove(zi); @@ -1063,7 +1052,7 @@ impl State { // self.n_cols counts the number of features in views, so this should be // accurate after the remove step above for i in ix..self.n_cols() { - let zi = self.asgn.asgn[i]; + let zi = self.asgn().asgn[i]; let mut ftr = self.views[zi].remove_feature(i + 1).unwrap(); ftr.set_id(i); self.views[zi].ftrs.insert(ftr.id(), ftr); @@ -1106,8 +1095,8 @@ impl State { } pub fn col_weights(&self, col_ix: usize) -> Vec { - let view_ix = self.asgn.asgn[col_ix]; - self.views[view_ix].asgn.weights() + let view_ix = self.asgn().asgn[col_ix]; + self.views[view_ix].prior_process.weight_vec(false) } } @@ -1127,10 +1116,32 @@ pub struct 
StateGewekeSettings { pub cm_types: Vec, /// Which transitions to do pub transitions: Vec, + /// Which prior process to use for the State assignment + pub state_process_type: PriorProcessType, + /// Which prior process to use for the View assignment + pub view_process_type: PriorProcessType, } impl StateGewekeSettings { - pub fn new(n_rows: usize, cm_types: Vec) -> Self { + pub fn new( + n_rows: usize, + cm_types: Vec, + state_process_type: PriorProcessType, + view_process_type: PriorProcessType, + ) -> Self { + use crate::transition::DEFAULT_STATE_TRANSITIONS; + + StateGewekeSettings { + n_cols: cm_types.len(), + n_rows, + cm_types, + transitions: DEFAULT_STATE_TRANSITIONS.into(), + state_process_type, + view_process_type, + } + } + + pub fn new_dirichlet_process(n_rows: usize, cm_types: Vec) -> Self { use crate::transition::DEFAULT_STATE_TRANSITIONS; StateGewekeSettings { @@ -1138,6 +1149,8 @@ impl StateGewekeSettings { n_rows, cm_types, transitions: DEFAULT_STATE_TRANSITIONS.into(), + state_process_type: PriorProcessType::Dirichlet, + view_process_type: PriorProcessType::Dirichlet, } } @@ -1153,10 +1166,10 @@ impl StateGewekeSettings { .any(|t| matches!(t, StateTransition::RowAssignment(_))) } - pub fn do_alpha_transition(&self) -> bool { + pub fn do_process_params_transition(&self) -> bool { self.transitions .iter() - .any(|t| matches!(t, StateTransition::StateAlpha)) + .any(|t| matches!(t, StateTransition::StatePriorProcessParams)) } } @@ -1179,6 +1192,7 @@ impl GewekeResampleData for State { .iter() .filter_map(|&st| st.try_into().ok()) .collect(), + process_type: s.view_process_type, }; for view in &mut self.views { view.geweke_resample_data(Some(&view_settings), &mut rng); @@ -1243,16 +1257,20 @@ impl GewekeSummarize for State { .iter() .filter_map(|&st| st.try_into().ok()) .collect(), + process_type: settings.view_process_type, }; GewekeStateSummary { n_views: if settings.do_col_asgn_transition() { - Some(self.asgn.n_cats) + Some(self.asgn().n_cats) } else 
{ None }, - alpha: if settings.do_alpha_transition() { - Some(self.asgn.alpha) + alpha: if settings.do_process_params_transition() { + Some(match self.prior_process.process { + Process::Dirichlet(ref inner) => inner.alpha, + Process::PitmanYor(ref inner) => inner.alpha, + }) } else { None }, @@ -1277,6 +1295,8 @@ impl GewekeModel for State { settings: &StateGewekeSettings, mut rng: &mut impl Rng, ) -> Self { + use lace_stats::prior_process::Dirichlet as PDirichlet; + let has_transition = |t: StateTransition, s: &StateGewekeSettings| { s.transitions.iter().any(|&ti| ti == t) }; @@ -1285,11 +1305,11 @@ impl GewekeModel for State { let do_ftr_prior_transition = has_transition(StateTransition::FeaturePriors, settings); - let do_state_alpha_transition = - has_transition(StateTransition::StateAlpha, settings); + let do_state_process_transition = + has_transition(StateTransition::StatePriorProcessParams, settings); - let do_view_alphas_transition = - has_transition(StateTransition::ViewAlphas, settings); + let do_view_process_transition = + has_transition(StateTransition::ViewPriorProcessParams, settings); let do_col_asgn_transition = settings.do_col_asgn_transition(); let do_row_asgn_transition = settings.do_row_asgn_transition(); @@ -1303,64 +1323,75 @@ impl GewekeModel for State { let n_cols = ftrs.len(); - let asgn_bldr = if do_col_asgn_transition { - AssignmentBuilder::new(n_cols) - } else { - AssignmentBuilder::new(n_cols).flat() - } - .seed_from_rng(&mut rng) - .with_geweke_prior(); + let state_prior_process = { + let process = if do_state_process_transition { + Process::Dirichlet(PDirichlet::from_prior( + geweke_alpha_prior(), + &mut rng, + )) + } else { + Process::Dirichlet(PDirichlet { + alpha_prior: geweke_alpha_prior(), + alpha: 1.0, + }) + }; - let asgn = if do_state_alpha_transition { - asgn_bldr.build().unwrap() - } else { - asgn_bldr.with_alpha(1.0).build().unwrap() + if do_col_asgn_transition { + AssignmentBuilder::new(n_cols) + } else { + 
AssignmentBuilder::new(n_cols).flat() + } + .with_process(process.clone()) + .seed_from_rng(&mut rng) + .build() + .unwrap() }; let view_asgn_bldr = if do_row_asgn_transition { - if do_view_alphas_transition { - AssignmentBuilder::new(settings.n_rows) - } else { - AssignmentBuilder::new(settings.n_rows).with_alpha(1.0) - } - } else if do_view_alphas_transition { - AssignmentBuilder::new(settings.n_rows).flat() - } else { AssignmentBuilder::new(settings.n_rows) - .flat() - .with_alpha(1.0) - } - .with_geweke_prior(); + } else { + AssignmentBuilder::new(settings.n_rows).flat() + }; - let mut views: Vec = (0..asgn.n_cats) + let mut views: Vec = (0..state_prior_process.asgn.n_cats) .map(|_| { + // may need to redraw the process params from the prior many + // times, so Process construction must be a generating function + let process = if do_view_process_transition { + Process::Dirichlet(PDirichlet::from_prior( + geweke_alpha_prior(), + &mut rng, + )) + } else { + Process::Dirichlet(PDirichlet { + alpha_prior: geweke_alpha_prior(), + alpha: 1.0, + }) + }; + let asgn = view_asgn_bldr .clone() .seed_from_rng(&mut rng) + .with_process(process.clone()) .build() .unwrap(); - view::Builder::from_assignment(asgn) + view::Builder::from_prior_process(asgn) .seed_from_rng(&mut rng) .build() }) .collect(); - for (&v, ftr) in asgn.asgn.iter().zip(ftrs.drain(..)) { + for (&v, ftr) in + state_prior_process.asgn.asgn.iter().zip(ftrs.drain(..)) + { views[v].geweke_init_feature(ftr, &mut rng); } - let view_alpha_prior = views[0].asgn.prior.clone(); - - let weights = asgn.weights(); State { views, - asgn, - weights, - view_alpha_prior, - loglike: 0.0, - log_prior: 0.0, - log_state_alpha_prior: 0.0, - log_view_alpha_prior: 0.0, + weights: state_prior_process.weight_vec(false), + prior_process: state_prior_process, + score: StateScoreComponents::default(), diagnostics: StateDiagnostics::default(), } } @@ -1385,7 +1416,6 @@ mod test { use super::*; use crate::state::Builder; - use 
approx::*; use lace_codebook::ColType; #[test] @@ -1403,7 +1433,7 @@ mod test { .build() .expect("Failed to build state"); - assert_eq!(state.asgn.asgn, vec![0, 0, 1, 1]); + assert_eq!(state.asgn().asgn, vec![0, 0, 1, 1]); let ftr = state.extract_ftr(1); @@ -1411,9 +1441,9 @@ mod test { assert_eq!(state.views[0].ftrs.len(), 1); assert_eq!(state.views[1].ftrs.len(), 2); - assert_eq!(state.asgn.asgn, vec![0, usize::max_value(), 1, 1]); - assert_eq!(state.asgn.counts, vec![1, 2]); - assert_eq!(state.asgn.n_cats, 2); + assert_eq!(state.asgn().asgn, vec![0, usize::max_value(), 1, 1]); + assert_eq!(state.asgn().counts, vec![1, 2]); + assert_eq!(state.asgn().n_cats, 2); assert_eq!(ftr.id(), 1); } @@ -1433,16 +1463,16 @@ mod test { .build() .expect("Failed to build state"); - assert_eq!(state.asgn.asgn, vec![0, 1, 1]); + assert_eq!(state.asgn().asgn, vec![0, 1, 1]); let ftr = state.extract_ftr(0); assert_eq!(state.n_views(), 1); assert_eq!(state.views[0].ftrs.len(), 2); - assert_eq!(state.asgn.asgn, vec![usize::max_value(), 0, 0]); - assert_eq!(state.asgn.counts, vec![2]); - assert_eq!(state.asgn.n_cats, 1); + assert_eq!(state.asgn().asgn, vec![usize::max_value(), 0, 0]); + assert_eq!(state.asgn().counts, vec![2]); + assert_eq!(state.asgn().n_cats, 1); assert_eq!(ftr.id(), 0); } @@ -1521,37 +1551,34 @@ mod test { for _ in 0..n_runs { let state = State::geweke_from_prior(settings, &mut rng); - // 1. Check the assignment prior - assert_relative_eq!(state.asgn.prior.shape(), 3.0, epsilon = 1E-12); - assert_relative_eq!(state.asgn.prior.rate(), 3.0, epsilon = 1E-12); // Column assignment is not flat - if state.asgn.asgn.iter().any(|&zi| zi != 0) { + if state.asgn().asgn.iter().any(|&zi| zi != 0) { cols_always_flat = false; } - if !basically_one(state.asgn.alpha) { + let alpha = match state.prior_process.process { + Process::Dirichlet(ref inner) => inner.alpha, + Process::PitmanYor(ref inner) => inner.alpha, + }; + + if !basically_one(alpha) { state_alpha_1 = false; } + // 2. 
Check the column priors for view in state.views.iter() { // Check the view assignment priors - assert_relative_eq!( - view.asgn.prior.shape(), - 3.0, - epsilon = 1E-12 - ); - assert_relative_eq!( - view.asgn.prior.rate(), - 3.0, - epsilon = 1E-12 - ); // Check the view assignments aren't flat - if view.asgn.asgn.iter().any(|&zi| zi != 0) { + if view.asgn().asgn.iter().any(|&zi| zi != 0) { rows_always_flat = false; } + let view_alpha = match view.prior_process.process { + Process::Dirichlet(ref inner) => inner.alpha, + Process::PitmanYor(ref inner) => inner.alpha, + }; - if !basically_one(view.asgn.alpha) { + if !basically_one(view_alpha) { view_alphas_1 = false; } } @@ -1567,8 +1594,10 @@ mod test { #[test] fn geweke_from_prior_all_transitions() { - let settings = - StateGewekeSettings::new(50, vec![FType::Continuous; 40]); + let settings = StateGewekeSettings::new_dirichlet_process( + 50, + vec![FType::Continuous; 40], + ); let mut rng = rand::thread_rng(); let result = test_asgn_flatness(&settings, 10, &mut rng); assert!(!result.rows_always_flat); @@ -1585,10 +1614,12 @@ mod test { cm_types: vec![FType::Continuous; 20], transitions: vec![ StateTransition::ColumnAssignment(ColAssignAlg::FiniteCpu), - StateTransition::StateAlpha, - StateTransition::ViewAlphas, + StateTransition::StatePriorProcessParams, + StateTransition::ViewPriorProcessParams, StateTransition::FeaturePriors, ], + state_process_type: PriorProcessType::Dirichlet, + view_process_type: PriorProcessType::Dirichlet, }; let mut rng = rand::thread_rng(); let result = test_asgn_flatness(&settings, 100, &mut rng); @@ -1606,10 +1637,12 @@ mod test { cm_types: vec![FType::Continuous; 20], transitions: vec![ StateTransition::RowAssignment(RowAssignAlg::FiniteCpu), - StateTransition::StateAlpha, - StateTransition::ViewAlphas, + StateTransition::StatePriorProcessParams, + StateTransition::ViewPriorProcessParams, StateTransition::FeaturePriors, ], + state_process_type: PriorProcessType::Dirichlet, + 
view_process_type: PriorProcessType::Dirichlet, }; let mut rng = rand::thread_rng(); let result = test_asgn_flatness(&settings, 100, &mut rng); @@ -1626,10 +1659,12 @@ mod test { n_rows: 50, cm_types: vec![FType::Continuous; 20], transitions: vec![ - StateTransition::StateAlpha, - StateTransition::ViewAlphas, + StateTransition::StatePriorProcessParams, + StateTransition::ViewPriorProcessParams, StateTransition::FeaturePriors, ], + state_process_type: PriorProcessType::Dirichlet, + view_process_type: PriorProcessType::Dirichlet, }; let mut rng = rand::thread_rng(); let result = test_asgn_flatness(&settings, 100, &mut rng); @@ -1650,6 +1685,8 @@ mod test { StateTransition::RowAssignment(RowAssignAlg::FiniteCpu), StateTransition::FeaturePriors, ], + state_process_type: PriorProcessType::Dirichlet, + view_process_type: PriorProcessType::Dirichlet, }; let mut rng = rand::thread_rng(); let result = test_asgn_flatness(&settings, 100, &mut rng); @@ -1705,6 +1742,6 @@ mod test { assert_eq!(state.n_views(), 1); assert_eq!(state.n_cols(), 20); - assert!(state.asgn.asgn.iter().all(|&z| z == 0)) + assert!(state.asgn().asgn.iter().all(|&z| z == 0)) } } diff --git a/lace/lace_cc/src/state/builder.rs b/lace/lace_cc/src/state/builder.rs index 5be95e05..bfb8c2cb 100644 --- a/lace/lace_cc/src/state/builder.rs +++ b/lace/lace_cc/src/state/builder.rs @@ -3,6 +3,8 @@ use lace_data::SparseContainer; use lace_stats::prior::csd::CsdHyper; use lace_stats::prior::nix::NixHyper; use lace_stats::prior::pg::PgHyper; +use lace_stats::prior_process::Builder as AssignmentBuilder; +use lace_stats::prior_process::Process; use lace_stats::rv::dist::{ Categorical, Gamma, Gaussian, NormalInvChiSquared, Poisson, }; @@ -11,7 +13,6 @@ use rand::SeedableRng; use rand_xoshiro::Xoshiro256Plus; use thiserror::Error; -use crate::assignment::AssignmentBuilder; use crate::feature::{ColModel, Column, Feature}; use crate::state::State; @@ -24,6 +25,7 @@ pub struct Builder { pub col_configs: Option>, pub ftrs: 
Option>, pub seed: Option, + pub prior_process: Option, } #[derive(Debug, Error, PartialEq)] @@ -106,6 +108,12 @@ impl Builder { self } + #[must_use] + pub fn prior_process(mut self, process: Process) -> Self { + self.prior_process = Some(process); + self + } + /// Build the `State` pub fn build(self) -> Result { let mut rng = match self.seed { @@ -168,13 +176,15 @@ impl Builder { col_asgn.append(&mut vec![view_ix; to_drain]); col_counts.push(to_drain); let ftrs_view = ftrs.drain(0..to_drain).collect(); - let asgn = AssignmentBuilder::new(n_rows) + + let prior_process = AssignmentBuilder::new(n_rows) .with_n_cats(n_cats) .unwrap() .seed_from_rng(&mut rng) .build() .unwrap(); - crate::view::Builder::from_assignment(asgn) + + crate::view::Builder::from_prior_process(prior_process) .features(ftrs_view) .seed_from_rng(&mut rng) .build() @@ -183,12 +193,22 @@ impl Builder { assert_eq!(ftrs.len(), 0); - let asgn = AssignmentBuilder::from_vec(col_asgn) + let process = self.prior_process.unwrap_or_else(|| { + Process::Dirichlet( + lace_stats::prior_process::Dirichlet::from_prior( + lace_consts::state_alpha_prior(), + &mut rng, + ), + ) + }); + + let process = AssignmentBuilder::from_vec(col_asgn) .seed_from_rng(&mut rng) + .with_process(process) .build() .unwrap(); - let alpha_prior: Gamma = lace_consts::state_alpha_prior(); - Ok(State::new(views, asgn, alpha_prior)) + + Ok(State::new(views, process)) } } @@ -313,15 +333,15 @@ mod tests { .expect("Failed to build state") }; - assert_eq!(state_1.asgn.asgn, state_2.asgn.asgn); + assert_eq!(state_1.asgn().asgn, state_2.asgn().asgn); for (view_1, view_2) in state_1.views.iter().zip(state_2.views.iter()) { - assert_eq!(view_1.asgn.asgn, view_2.asgn.asgn); + assert_eq!(view_1.asgn().asgn, view_2.asgn().asgn); } } #[test] - fn n_rows_overriden_by_features() { + fn n_rows_overridden_by_features() { let n_cols = 5; let col_models = { let state = Builder::new() diff --git a/lace/lace_cc/src/transition.rs 
b/lace/lace_cc/src/transition.rs index bfb5a65e..719dc4b6 100644 --- a/lace/lace_cc/src/transition.rs +++ b/lace/lace_cc/src/transition.rs @@ -7,9 +7,9 @@ use crate::ParseError; pub const DEFAULT_STATE_TRANSITIONS: [StateTransition; 5] = [ StateTransition::ColumnAssignment(ColAssignAlg::Slice), - StateTransition::StateAlpha, + StateTransition::StatePriorProcessParams, StateTransition::RowAssignment(RowAssignAlg::Slice), - StateTransition::ViewAlphas, + StateTransition::ViewPriorProcessParams, StateTransition::FeaturePriors, ]; @@ -19,7 +19,7 @@ pub enum ViewTransition { /// Reassign rows to categories RowAssignment(RowAssignAlg), /// Update the alpha (discount) parameters on the CRP - Alpha, + PriorProcessParams, /// Update the feature (column) prior parameters FeaturePriors, /// Update the parameters in the feature components. This is usually done @@ -39,11 +39,11 @@ pub enum StateTransition { #[serde(rename = "row_assignment")] RowAssignment(RowAssignAlg), /// Update the alpha (discount) parameter on the column-to-views CRP - #[serde(rename = "state_alpha")] - StateAlpha, + #[serde(rename = "state_prior_process_params")] + StatePriorProcessParams, /// Update the alpha (discount) parameters on the row-to-categories CRP - #[serde(rename = "view_alphas")] - ViewAlphas, + #[serde(rename = "view_prior_process_params")] + ViewPriorProcessParams, /// Update the feature (column) prior parameters #[serde(rename = "feature_priors")] FeaturePriors, @@ -60,7 +60,9 @@ impl TryFrom for ViewTransition { fn try_from(st: StateTransition) -> Result { match st { - StateTransition::ViewAlphas => Ok(ViewTransition::Alpha), + StateTransition::ViewPriorProcessParams => { + Ok(ViewTransition::PriorProcessParams) + } StateTransition::RowAssignment(alg) => { Ok(ViewTransition::RowAssignment(alg)) } diff --git a/lace/lace_cc/src/view.rs b/lace/lace_cc/src/view.rs index 6ba9f346..232c9a16 100644 --- a/lace/lace_cc/src/view.rs +++ b/lace/lace_cc/src/view.rs @@ -3,7 +3,12 @@ use 
std::f64::NEG_INFINITY; use lace_data::{Datum, FeatureData}; use lace_geweke::{GewekeModel, GewekeResampleData, GewekeSummarize}; -use lace_stats::rv::dist::{Dirichlet, Gamma}; +use lace_stats::assignment::Assignment; +use lace_stats::prior_process::Builder as AssignmentBuilder; +use lace_stats::prior_process::{ + PriorProcess, PriorProcessT, PriorProcessType, Process, +}; +use lace_stats::rv::dist::Dirichlet; use lace_stats::rv::misc::ln_pflip; use lace_stats::rv::traits::Rv; use lace_utils::{logaddexp, unused_components, Matrix, Shape}; @@ -13,7 +18,6 @@ use serde::{Deserialize, Serialize}; // use crate::cc::feature::geweke::{gen_geweke_col_models, ColumnGewekeSettings}; use crate::alg::RowAssignAlg; -use crate::assignment::{Assignment, AssignmentBuilder}; use crate::feature::geweke::GewekeColumnSummary; use crate::feature::geweke::{gen_geweke_col_models, ColumnGewekeSettings}; use crate::feature::{ColModel, FType, Feature}; @@ -30,7 +34,7 @@ pub struct View { /// A Map of features indexed by the feature ID pub ftrs: BTreeMap, /// The assignment of rows to categories - pub asgn: Assignment, + pub prior_process: PriorProcess, /// The weights of each category pub weights: Vec, } @@ -38,7 +42,7 @@ pub struct View { /// Builds a `View` pub struct Builder { n_rows: usize, - alpha_prior: Option, + process: Option, asgn: Option, ftrs: Option>, seed: Option, @@ -50,7 +54,7 @@ impl Builder { Builder { n_rows, asgn: None, - alpha_prior: None, + process: None, ftrs: None, seed: None, } @@ -63,7 +67,17 @@ impl Builder { Builder { n_rows: asgn.len(), asgn: Some(asgn), - alpha_prior: None, // is ignored in asgn set + process: None, // is ignored in asgn set + ftrs: None, + seed: None, + } + } + + pub fn from_prior_process(prior_process: PriorProcess) -> Self { + Builder { + n_rows: prior_process.asgn.len(), + asgn: Some(prior_process.asgn), + process: Some(prior_process.process), ftrs: None, seed: None, } @@ -71,13 +85,9 @@ impl Builder { /// Put a custom `Gamma` prior on the 
CRP alpha #[must_use] - pub fn alpha_prior(mut self, alpha_prior: Gamma) -> Self { - if self.asgn.is_some() { - panic!("Cannot add alpha_prior once Assignment added"); - } else { - self.alpha_prior = Some(alpha_prior); - self - } + pub fn prior_process(mut self, process: Process) -> Self { + self.process = Some(process); + self } /// Add features to the `View` @@ -103,51 +113,58 @@ impl Builder { /// Build the `View` and consume the builder pub fn build(self) -> View { + use lace_consts::general_alpha_prior; + use lace_stats::prior_process::Dirichlet; + let mut rng = match self.seed { Some(seed) => Xoshiro256Plus::seed_from_u64(seed), None => Xoshiro256Plus::from_entropy(), }; + let process = self.process.unwrap_or_else(|| { + Process::Dirichlet(Dirichlet::from_prior( + general_alpha_prior(), + &mut rng, + )) + }); + let asgn = match self.asgn { Some(asgn) => asgn, - None => { - if self.alpha_prior.is_none() { - AssignmentBuilder::new(self.n_rows) - .seed_from_rng(&mut rng) - .build() - .unwrap() - } else { - AssignmentBuilder::new(self.n_rows) - .with_prior(self.alpha_prior.unwrap()) - .seed_from_rng(&mut rng) - .build() - .unwrap() - } - } + None => process.draw_assignment(self.n_rows, &mut rng), }; - let weights = asgn.weights(); + let prior_process = PriorProcess { process, asgn }; + + let weights = prior_process.weight_vec(false); let mut ftr_tree = BTreeMap::new(); if let Some(mut ftrs) = self.ftrs { for mut ftr in ftrs.drain(..) 
{ - ftr.reassign(&asgn, &mut rng); + ftr.reassign(&prior_process.asgn, &mut rng); ftr_tree.insert(ftr.id(), ftr); } } View { ftrs: ftr_tree, - asgn, + prior_process, weights, } } } impl View { + pub fn asgn(&self) -> &Assignment { + &self.prior_process.asgn + } + + pub fn asgn_mut(&mut self) -> &mut Assignment { + &mut self.prior_process.asgn + } + /// The number of rows in the `View` #[inline] pub fn n_rows(&self) -> usize { - self.asgn.asgn.len() + self.asgn().len() } /// The number of columns in the `View` @@ -171,19 +188,13 @@ impl View { /// The number of categories #[inline] pub fn n_cats(&self) -> usize { - self.asgn.n_cats - } - - /// The current value of the CPR alpha parameter - #[inline] - pub fn alpha(&self) -> f64 { - self.asgn.alpha + self.asgn().n_cats } // Extend the columns by a number of cells, increasing the total number of // rows. The added entries will be empty. pub fn extend_cols(&mut self, n_rows: usize) { - (0..n_rows).for_each(|_| self.asgn.push_unassigned()); + (0..n_rows).for_each(|_| self.asgn_mut().push_unassigned()); self.ftrs.values_mut().for_each(|ftr| { (0..n_rows).for_each(|_| ftr.append_datum(Datum::Missing)) }) @@ -195,7 +206,7 @@ impl View { row_ix: usize, col_ix: usize, ) -> Option { - let k = self.asgn.asgn[row_ix]; + let k = self.asgn().asgn[row_ix]; let is_assigned = k != usize::max_value(); if is_assigned { @@ -212,7 +223,7 @@ impl View { return; } - let k = self.asgn.asgn[row_ix]; + let k = self.asgn().asgn[row_ix]; let is_assigned = k != usize::max_value(); let ftr = self.ftrs.get_mut(&col_ix).unwrap(); @@ -267,8 +278,8 @@ impl View { ) { for transition in transitions { match transition { - ViewTransition::Alpha => { - self.update_alpha(&mut rng); + ViewTransition::PriorProcessParams => { + self.update_prior_process_params(&mut rng); } ViewTransition::RowAssignment(alg) => { self.reassign(*alg, &mut rng); @@ -287,7 +298,7 @@ impl View { pub fn default_transitions() -> Vec { vec![ 
ViewTransition::RowAssignment(RowAssignAlg::FiniteCpu), - ViewTransition::Alpha, + ViewTransition::PriorProcessParams, ViewTransition::FeaturePriors, ] } @@ -365,13 +376,13 @@ impl View { // it does not explicitly update the weights. Non-updated weights means // wrong probabilities. To avoid this, we set the weights by the // partition here. - self.weights = self.asgn.weights(); - debug_assert!(self.asgn.validate().is_valid()); + self.weights = self.prior_process.weight_vec(false); + debug_assert!(self.asgn().validate().is_valid()); } /// Use the finite approximation (on the CPU) to reassign the rows pub fn reassign_rows_finite_cpu(&mut self, mut rng: &mut impl Rng) { - let n_cats = self.asgn.n_cats; + let n_cats = self.n_cats(); let n_rows = self.n_rows(); self.resample_weights(true, &mut rng); @@ -392,17 +403,17 @@ impl View { /// Use the improved slice algorithm to reassign the rows pub fn reassign_rows_slice(&mut self, mut rng: &mut impl Rng) { - use crate::misc::sb_slice_extend; self.resample_weights(false, &mut rng); let weights: Vec = { - let dirvec = self.asgn.dirvec(true); + // FIXME: only works for dirichlet + let dirvec = self.prior_process.weight_vec_unnormed(true); let dir = Dirichlet::new(dirvec).unwrap(); dir.draw(&mut rng) }; let us: Vec = self - .asgn + .asgn() .asgn .iter() .map(|&zi| { @@ -416,9 +427,10 @@ impl View { us.iter() .fold(1.0, |umin, &ui| if ui < umin { ui } else { umin }); - let weights = - sb_slice_extend(weights, self.asgn.alpha, u_star, &mut rng) - .unwrap(); + let weights = self + .prior_process + .process + .slice_sb_extend(weights, u_star, &mut rng); let n_new_cats = weights.len() - self.weights.len(); let n_cats = weights.len(); @@ -461,7 +473,14 @@ impl View { add_empty_component: bool, mut rng: &mut impl Rng, ) { - let dirvec = self.asgn.dirvec(add_empty_component); + let dirvec = + self.prior_process.weight_vec_unnormed(add_empty_component); + + if dirvec.iter().any(|&p| p < 0.0) { + eprintln!("{:?}", dirvec); + 
eprintln!("{:?}\n", self.prior_process.process); + } + let dir = Dirichlet::new(dirvec).unwrap(); self.weights = dir.draw(&mut rng) } @@ -475,8 +494,8 @@ impl View { let i = ixs[0]; let j = ixs[1]; - let zi = self.asgn.asgn[i]; - let zj = self.asgn.asgn[j]; + let zi = self.asgn().asgn[i]; + let zj = self.asgn().asgn[j]; if zi < zj { (i, j, zi, zj) @@ -491,14 +510,15 @@ impl View { assert!(zi < zj); self.sams_merge(i, j, rng); } - debug_assert!(self.asgn.validate().is_valid()); + debug_assert!(self.asgn().validate().is_valid()); } /// MCMC update on the CPR alpha parameter #[inline] - pub fn update_alpha(&mut self, mut rng: &mut impl Rng) -> f64 { - self.asgn - .update_alpha(lace_consts::MH_PRIOR_ITERS, &mut rng) + pub fn update_prior_process_params(&mut self, rng: &mut impl Rng) -> f64 { + self.prior_process.update_params(rng); + // FIXME: should be the new likelihood + 0.0 } /// Insert a new `Feature` into the `View`, but draw the feature @@ -511,8 +531,8 @@ impl View { "Feature {} already in view", id ); - ftr.init_components(self.asgn.n_cats, &mut rng); - ftr.reassign(&self.asgn, &mut rng); + ftr.init_components(self.asgn().n_cats, &mut rng); + ftr.reassign(self.asgn(), &mut rng); self.ftrs.insert(id, ftr); } @@ -530,7 +550,7 @@ impl View { "Feature {} already in view", id ); - ftr.geweke_init(&self.asgn, rng); + ftr.geweke_init(self.asgn(), rng); self.ftrs.insert(id, ftr); } @@ -547,7 +567,7 @@ impl View { "Feature {} already in view", id ); - ftr.reassign(&self.asgn, &mut rng); + ftr.reassign(self.asgn(), &mut rng); self.ftrs.insert(id, ftr); } @@ -570,7 +590,7 @@ impl View { // assignment to preserve canonical order. 
(0..n).for_each(|_| { self.remove_row(ix); - self.asgn.asgn.remove(ix); + self.asgn_mut().asgn.remove(ix); }); // remove data from features @@ -596,7 +616,7 @@ impl View { #[inline] pub fn refresh_suffstats(&mut self, mut rng: &mut impl Rng) { for ftr in self.ftrs.values_mut() { - ftr.reassign(&self.asgn, &mut rng); + ftr.reassign(&self.prior_process.asgn, &mut rng); } } @@ -615,7 +635,7 @@ impl View { // problem is that I can't iterate on self.asgn then call // self.reinsert_row inside the for_each closure let mut unassigned_rows: Vec = self - .asgn + .asgn() .iter() .enumerate() .filter_map(|(row_ix, &z)| { @@ -640,10 +660,10 @@ impl View { // Remove the row for the purposes of MCMC without deleting its data. #[inline] fn remove_row(&mut self, row_ix: usize) { - let k = self.asgn.asgn[row_ix]; - let is_singleton = self.asgn.counts[k] == 1; + let k = self.asgn().asgn[row_ix]; + let is_singleton = self.asgn().counts[k] == 1; self.forget_row(row_ix, k); - self.asgn.unassign(row_ix); + self.asgn_mut().unassign(row_ix); if is_singleton { self.drop_component(k); @@ -660,25 +680,31 @@ impl View { #[inline] fn reinsert_row(&mut self, row_ix: usize, mut rng: &mut impl Rng) { - let k_new = if self.asgn.n_cats == 0 { + let k_new = if self.asgn().n_cats == 0 { // If empty, assign to category zero debug_assert!(self.ftrs.values().all(|f| f.k() == 0)); self.append_empty_component(&mut rng); 0 } else { // If not empty, do a Gibbs step - let mut logps: Vec = Vec::with_capacity(self.asgn.n_cats + 1); - self.asgn.counts.iter().enumerate().for_each(|(k, &ct)| { - logps.push( - (ct as f64).ln() + self.predictive_score_at(row_ix, k), - ); + let mut logps: Vec = + Vec::with_capacity(self.asgn().n_cats + 1); + + self.asgn().counts.iter().enumerate().for_each(|(k, &ct)| { + let w = self.prior_process.process.ln_gibbs_weight(ct); + logps.push(w + self.predictive_score_at(row_ix, k)); }); - logps.push(self.asgn.alpha.ln() + self.singleton_score(row_ix)); + logps.push( + self.prior_process 
+ .process + .ln_singleton_weight(self.n_cats()) + + self.singleton_score(row_ix), + ); let k_new = ln_pflip(&logps, 1, false, &mut rng)[0]; - if k_new == self.asgn.n_cats { + if k_new == self.n_cats() { self.append_empty_component(&mut rng); } @@ -686,7 +712,7 @@ impl View { }; self.observe_row(row_ix, k_new); - self.asgn.reassign(row_ix, k_new); + self.asgn_mut().reassign(row_ix, k_new); } #[inline] @@ -724,20 +750,20 @@ impl View { } } - self.asgn + self.asgn_mut() .set_asgn(new_asgn_vec) .expect("new asgn is invalid"); self.resample_weights(false, &mut rng); for ftr in self.ftrs.values_mut() { - ftr.reassign(&self.asgn, &mut rng) + ftr.reassign(&self.prior_process.asgn, &mut rng) } } fn set_asgn(&mut self, asgn: Assignment, rng: &mut R) { - self.asgn = asgn; + self.prior_process.asgn = asgn; self.resample_weights(false, rng); for ftr in self.ftrs.values_mut() { - ftr.reassign(&self.asgn, rng) + ftr.reassign(&self.prior_process.asgn, rng) } } @@ -768,7 +794,7 @@ impl View { if calc_reverse { // Get the indices of the columns assigned to the clusters that // were split - self.asgn + self.asgn() .asgn .iter() .enumerate() @@ -785,7 +811,7 @@ impl View { } else { // Get the indices of the columns assigned to the cluster to split let mut row_ixs: Vec = self - .asgn + .asgn() .asgn .iter() .enumerate() @@ -798,17 +824,16 @@ impl View { } fn sams_merge(&mut self, i: usize, j: usize, rng: &mut R) { - use crate::assignment::lcrp; use std::cmp::Ordering; - let zi = self.asgn.asgn[i]; - let zj = self.asgn.asgn[j]; + let zi = self.asgn().asgn[i]; + let zj = self.asgn().asgn[j]; let (logp_spt, logq_spt, ..) 
= self.propose_split(i, j, true, rng); let asgn = { let zs = self - .asgn + .asgn() .asgn .iter() .map(|&z| match z.cmp(&zj) { @@ -819,11 +844,11 @@ impl View { .collect(); AssignmentBuilder::from_vec(zs) - .with_prior(self.asgn.prior.clone()) - .with_alpha(self.asgn.alpha) + .with_process(self.prior_process.process.clone()) .seed_from_rng(rng) .build() .unwrap() + .asgn }; self.append_empty_component(rng); @@ -833,8 +858,8 @@ impl View { } }); - let logp_mrg = self.logm(self.n_cats()) - + lcrp(asgn.len(), &asgn.counts, asgn.alpha); + let logp_mrg = + self.logm(self.n_cats()) + self.prior_process.ln_f_partition(&asgn); self.drop_component(self.n_cats()); @@ -844,12 +869,11 @@ impl View { } fn sams_split(&mut self, i: usize, j: usize, rng: &mut R) { - use crate::assignment::lcrp; - - let zi = self.asgn.asgn[i]; + let zi = self.asgn().asgn[i]; - let logp_mrg = self.logm(zi) - + lcrp(self.asgn.len(), &self.asgn.counts, self.asgn.alpha); + // FIXME: only works for CRP + let logp_mrg = + self.logm(zi) + self.prior_process.ln_f_partition(self.asgn()); let (logp_spt, logq_spt, asgn_opt) = self.propose_split(i, j, false, rng); @@ -868,15 +892,13 @@ impl View { calc_reverse: bool, rng: &mut R, ) -> (f64, f64, Option) { - use crate::assignment::lcrp; - - let zi = self.asgn.asgn[i]; - let zj = self.asgn.asgn[j]; + let zi = self.asgn().asgn[i]; + let zj = self.asgn().asgn[j]; self.append_empty_component(rng); self.append_empty_component(rng); - let zi_tmp = self.asgn.n_cats; + let zi_tmp = self.n_cats(); let zj_tmp = zi_tmp + 1; self.force_observe_row(i, zi_tmp); @@ -885,7 +907,7 @@ impl View { let mut tmp_z: Vec = { // mark everything assigned to the split cluster as unassigned (-1) let mut zs: Vec = self - .asgn + .asgn() .iter() .map(|&z| if z == zi { std::usize::MAX } else { z }) .collect(); @@ -909,7 +931,7 @@ impl View { let lognorm = logaddexp(logp_zi, logp_zj); let assign_to_zi = if calc_reverse { - self.asgn.asgn[ix] == zi + self.asgn().asgn[ix] == zi } else { 
rng.gen::().ln() < logp_zi - lognorm }; @@ -930,7 +952,7 @@ impl View { let mut logp = self.logm(zi_tmp) + self.logm(zj_tmp); let asgn = if calc_reverse { - logp += lcrp(self.asgn.len(), &self.asgn.counts, self.asgn.alpha); + logp += self.prior_process.ln_f_partition(self.asgn()); None } else { tmp_z.iter_mut().for_each(|z| { @@ -941,14 +963,15 @@ impl View { } }); + // FIXME: create (draw) new process outside to carry forward alpha let asgn = AssignmentBuilder::from_vec(tmp_z) - .with_prior(self.asgn.prior.clone()) - .with_alpha(self.asgn.alpha) + .with_process(self.prior_process.process.clone()) .seed_from_rng(rng) .build() - .unwrap(); + .unwrap() + .asgn; - logp += lcrp(asgn.len(), &asgn.counts, asgn.alpha); + logp += self.prior_process.ln_f_partition(&asgn); Some(asgn) }; @@ -1002,6 +1025,8 @@ pub struct ViewGewekeSettings { pub cm_types: Vec, /// Which transitions to run pub transitions: Vec, + /// Which prior process to use + pub process_type: PriorProcessType, } impl ViewGewekeSettings { @@ -1015,44 +1040,87 @@ impl ViewGewekeSettings { // parameter updates explicitly (they marginalize over the component // parameters) and the data resample relies on the component // parameters. 
+ process_type: PriorProcessType::Dirichlet, transitions: vec![ ViewTransition::RowAssignment(RowAssignAlg::Slice), ViewTransition::FeaturePriors, ViewTransition::ComponentParams, - ViewTransition::Alpha, + ViewTransition::PriorProcessParams, ], } } + pub fn with_pitman_yor_process(mut self) -> Self { + self.process_type = PriorProcessType::PitmanYor; + self + } + + pub fn with_dirichlet_process(mut self) -> Self { + self.process_type = PriorProcessType::Dirichlet; + self + } + pub fn do_row_asgn_transition(&self) -> bool { self.transitions .iter() .any(|t| matches!(t, ViewTransition::RowAssignment(_))) } - pub fn do_alpha_transition(&self) -> bool { + pub fn do_process_params_transition(&self) -> bool { self.transitions .iter() - .any(|t| matches!(t, ViewTransition::Alpha)) + .any(|t| matches!(t, ViewTransition::PriorProcessParams)) } } -fn view_geweke_asgn( +fn view_geweke_asgn( n_rows: usize, - do_alpha_transition: bool, + do_process_params_transition: bool, do_row_asgn_transition: bool, -) -> AssignmentBuilder { - let mut bldr = AssignmentBuilder::new(n_rows).with_geweke_prior(); + process_type: PriorProcessType, + rng: &mut R, +) -> (AssignmentBuilder, Process) { + use lace_consts::geweke_alpha_prior; + let process = match process_type { + PriorProcessType::Dirichlet => { + use lace_stats::prior_process::Dirichlet; + let inner = if do_process_params_transition { + Dirichlet::from_prior(geweke_alpha_prior(), rng) + } else { + Dirichlet { + alpha: 1.0, + alpha_prior: geweke_alpha_prior(), + } + }; + Process::Dirichlet(inner) + } + PriorProcessType::PitmanYor => { + use lace_stats::prior_process::PitmanYor; + use lace_stats::rv::dist::Beta; + let inner = if do_process_params_transition { + PitmanYor::from_prior( + geweke_alpha_prior(), + Beta::jeffreys(), + rng, + ) + } else { + PitmanYor { + alpha: 1.0, + d: 0.2, + alpha_prior: geweke_alpha_prior(), + d_prior: Beta::jeffreys(), + } + }; + Process::PitmanYor(inner) + } + }; + let mut bldr = 
AssignmentBuilder::new(n_rows).with_process(process.clone()); if !do_row_asgn_transition { bldr = bldr.flat(); } - if !do_alpha_transition { - bldr = bldr.with_alpha(1.0); - } - - bldr + (bldr, process) } impl GewekeModel for View { @@ -1065,14 +1133,15 @@ impl GewekeModel for View { .iter() .any(|&t| t == ViewTransition::FeaturePriors); - let asgn = view_geweke_asgn( + // FIXME: Redundant! asgn_builder builds a PriorProcess + let (asgn_builder, process) = view_geweke_asgn( settings.n_rows, - settings.do_alpha_transition(), + settings.do_process_params_transition(), settings.do_row_asgn_transition(), - ) - .seed_from_rng(&mut rng) - .build() - .unwrap(); + settings.process_type, + rng, + ); + let asgn = asgn_builder.seed_from_rng(&mut rng).build().unwrap(); // this function sets up dummy features that we can properly populate with // Feature.geweke_init in the next loop @@ -1087,15 +1156,20 @@ impl GewekeModel for View { .drain(..) .enumerate() .map(|(id, mut ftr)| { - ftr.geweke_init(&asgn, &mut rng); + ftr.geweke_init(&asgn.asgn, &mut rng); (id, ftr) }) .collect(); + let prior_process = PriorProcess { + process, + asgn: asgn.asgn, + }; + View { ftrs, - weights: asgn.weights(), - asgn, + weights: prior_process.weight_vec(false), + prior_process, } } @@ -1116,8 +1190,10 @@ impl GewekeResampleData for View { rng: &mut impl Rng, ) { let s = settings.unwrap(); - let col_settings = - ColumnGewekeSettings::new(self.asgn.clone(), s.transitions.clone()); + let col_settings = ColumnGewekeSettings::new( + self.asgn().clone(), + s.transitions.clone(), + ); for ftr in self.ftrs.values_mut() { ftr.geweke_resample_data(Some(&col_settings), rng); } @@ -1168,7 +1244,7 @@ impl GewekeSummarize for View { fn geweke_summarize(&self, settings: &ViewGewekeSettings) -> Self::Summary { let col_settings = ColumnGewekeSettings::new( - self.asgn.clone(), + self.asgn().clone(), settings.transitions.clone(), ); @@ -1178,8 +1254,11 @@ impl GewekeSummarize for View { } else { None }, - alpha: 
if settings.do_alpha_transition() { - Some(self.asgn.alpha) + alpha: if settings.do_process_params_transition() { + Some(match self.prior_process.process { + Process::Dirichlet(ref inner) => inner.alpha, + Process::PitmanYor(ref inner) => inner.alpha, + }) } else { None }, @@ -1282,7 +1361,7 @@ mod tests { gen_gauss_view(1000, &mut rng) }; - assert_eq!(view_1.asgn.asgn, view_2.asgn.asgn); + assert_eq!(view_1.asgn().asgn, view_2.asgn().asgn); } #[test] @@ -1294,9 +1373,9 @@ mod tests { view.extend_cols(2); - assert_eq!(view.asgn.asgn.len(), 12); - assert_eq!(view.asgn.asgn[10], usize::max_value()); - assert_eq!(view.asgn.asgn[11], usize::max_value()); + assert_eq!(view.asgn().asgn.len(), 12); + assert_eq!(view.asgn().asgn[10], usize::max_value()); + assert_eq!(view.asgn().asgn[11], usize::max_value()); for ftr in view.ftrs.values() { assert_eq!(ftr.len(), 12); @@ -1314,13 +1393,13 @@ mod tests { let components_start = extract_components(&view); - let view_ix_start = view.asgn.asgn[2]; + let view_ix_start = view.asgn().asgn[2]; let component_start = components_start[3][view_ix_start].clone(); view.insert_datum(2, 3, Datum::Continuous(20.22)); let components_end = extract_components(&view); - let view_ix_end = view.asgn.asgn[2]; + let view_ix_end = view.asgn().asgn[2]; let component_end = components_end[3][view_ix_end].clone(); assert_ne!(components_start, components_end); diff --git a/lace/lace_cc/tests/enum.rs b/lace/lace_cc/tests/enum.rs index 551be9da..e8a4b00f 100644 --- a/lace/lace_cc/tests/enum.rs +++ b/lace/lace_cc/tests/enum.rs @@ -4,10 +4,20 @@ use enum_test::*; mod partition { use super::*; - use lace_cc::misc::crp_draw; + use lace_stats::assignment::Assignment; + use lace_stats::prior_process::{Dirichlet, PriorProcessT}; + use lace_stats::rv::dist::Gamma; use rand::rngs::StdRng; use rand::SeedableRng; + fn crp_draw(n: usize, alpha: f64, rng: &mut R) -> Assignment { + let process = Dirichlet { + alpha_prior: Gamma::default(), // doesn't matter here, + 
alpha, + }; + process.draw_assignment(n, rng) + } + #[test] fn empty_partition() { let mut rng = StdRng::seed_from_u64(0xABCD); diff --git a/lace/lace_cc/tests/enum_state.rs b/lace/lace_cc/tests/enum_state.rs index 493a934e..1909ef36 100644 --- a/lace/lace_cc/tests/enum_state.rs +++ b/lace/lace_cc/tests/enum_state.rs @@ -15,13 +15,16 @@ use lace_stats::rv::misc::logsumexp; use rand::Rng; use lace_cc::alg::{ColAssignAlg, RowAssignAlg}; -use lace_cc::assignment::lcrp; -use lace_cc::assignment::AssignmentBuilder; use lace_cc::config::StateUpdateConfig; use lace_cc::feature::{ColModel, FType, Feature}; use lace_cc::state::State; use lace_cc::transition::StateTransition; use lace_cc::view::{Builder, View}; +use lace_stats::prior_process::Builder as PriorProcessBuilder; +use lace_stats::prior_process::{ + Dirichlet, PitmanYor, PriorProcessT, PriorProcessType, Process, +}; +use lace_stats::rv::dist::{Beta, Gamma}; use enum_test::{ build_features, normalize_assignment, partition_to_ix, Partition, @@ -84,25 +87,24 @@ fn state_from_partition( mut features: Vec, mut rng: &mut R, ) -> State { - let asgn = AssignmentBuilder::from_vec(partition.col_partition.clone()) - .with_alpha(1.0) - .seed_from_rng(&mut rng) - .build() - .unwrap(); - let mut views: Vec = partition .row_partitions .iter() .map(|zr| { + let process = Process::Dirichlet(Dirichlet { + alpha: 1.0, + alpha_prior: Gamma::default(), + }); + // NOTE: We don't need seed control here because both alpha and the // assignment are set, but I'm setting the seed anyway in case the // assignment builder internals change - let asgn = AssignmentBuilder::from_vec(zr.clone()) - .with_alpha(1.0) + let view_prior_process = PriorProcessBuilder::from_vec(zr.clone()) + .with_process(process.clone()) .seed_from_rng(&mut rng) .build() .unwrap(); - Builder::from_assignment(asgn) + Builder::from_prior_process(view_prior_process) .seed_from_rng(&mut rng) .build() }) @@ -114,43 +116,76 @@ fn state_from_partition( .zip(features.drain(..)) 
.for_each(|(&zi, ftr)| views[zi].insert_feature(ftr, &mut rng)); - State::new(views, asgn, lace_consts::state_alpha_prior()) + let state_prior_process = + PriorProcessBuilder::from_vec(partition.col_partition.clone()) + .with_process(Process::Dirichlet(Dirichlet { + alpha: 1.0, + alpha_prior: Gamma::default(), + })) + .seed_from_rng(&mut rng) + .build() + .unwrap(); + + State::new(views, state_prior_process) +} + +fn emit_process(proc_type: PriorProcessType) -> Process { + match proc_type { + PriorProcessType::Dirichlet => Process::Dirichlet(Dirichlet { + alpha: 1.0, + alpha_prior: Gamma::default(), + }), + PriorProcessType::PitmanYor => Process::PitmanYor(PitmanYor { + alpha: 1.2, + d: 0.2, + alpha_prior: Gamma::default(), + d_prior: Beta::jeffreys(), + }), + } } /// Generates a random start state from the prior, with default values chosen for the /// feature priors, and all CRP alphas set to 1.0. fn gen_start_state( mut features: Vec, + proc_type: PriorProcessType, mut rng: &mut R, ) -> State { let n_cols = features.len(); let n_rows = features[0].len(); - let asgn = AssignmentBuilder::new(n_cols) - .with_alpha(1.0) + + let process = emit_process(proc_type); + + let state_prior_process = PriorProcessBuilder::new(n_cols) + .with_process(process) .seed_from_rng(&mut rng) .build() .unwrap(); - let mut views: Vec = (0..asgn.n_cats) + let mut views: Vec = (0..state_prior_process.asgn.n_cats) .map(|_| { - let asgn = AssignmentBuilder::new(n_rows) - .with_alpha(1.0) + let view_prior_process = PriorProcessBuilder::new(n_rows) + .with_process(state_prior_process.process.clone()) .seed_from_rng(&mut rng) .build() .unwrap(); - Builder::from_assignment(asgn).build() + Builder::from_prior_process(view_prior_process).build() }) .collect(); - asgn.iter() + state_prior_process + .asgn + .iter() .zip(features.drain(..)) .for_each(|(&zi, ftr)| views[zi].insert_feature(ftr, &mut rng)); - State::new(views, asgn, lace_consts::state_alpha_prior()) + State::new(views, 
state_prior_process) } fn calc_state_ln_posterior( features: Vec, + state_process: &Process, + view_process: &Process, mut rng: &mut R, ) -> HashMap { let n_cols = features.len(); @@ -162,9 +197,9 @@ fn calc_state_ln_posterior( .iter() .for_each(|part| { let state = state_from_partition(part, features.clone(), &mut rng); - let mut score = lcrp(state.n_cols(), &state.asgn.counts, 1.0); + let mut score = state_process.ln_f_partition(state.asgn()); for view in state.views { - score += lcrp(view.n_rows(), &view.asgn.counts, 1.0); + score += view_process.ln_f_partition(view.asgn()); for ftr in view.ftrs.values() { score += ftr.score(); } @@ -184,13 +219,13 @@ fn calc_state_ln_posterior( /// Extract the index from a State fn extract_state_index(state: &State) -> StateIndex { - let normed = normalize_assignment(state.asgn.asgn.clone()); + let normed = normalize_assignment(state.asgn().asgn.clone()); let col_ix: u64 = partition_to_ix(&normed); let row_ixs: Vec = state .views .iter() .map(|v| { - let zn = normalize_assignment(v.asgn.asgn.clone()); + let zn = normalize_assignment(v.asgn().asgn.clone()); partition_to_ix(&zn) }) .collect(); @@ -216,6 +251,7 @@ pub fn state_enum_test( row_alg: RowAssignAlg, col_alg: ColAssignAlg, ftype: FType, + proc_type: PriorProcessType, mut rng: &mut R, ) -> f64 { let features = build_features(n_rows, n_cols, ftype, &mut rng); @@ -232,31 +268,20 @@ pub fn state_enum_test( let inc: f64 = ((n_runs * n_iters) as f64).recip(); for _ in 0..n_runs { - let mut state = gen_start_state(features.clone(), &mut rng); - - // alphas should start out at 1.0 - assert!((state.asgn.alpha - 1.0).abs() < 1E-16); - assert!(state - .views - .iter() - .all(|v| (v.asgn.alpha - 1.0).abs() < 1E-16)); + let mut state = gen_start_state(features.clone(), proc_type, &mut rng); for _ in 0..n_iters { state.update(update_config.clone(), &mut rng); - // all alphas should be 1.0 - assert!((state.asgn.alpha - 1.0).abs() < 1E-16); - assert!(state - .views - .iter() - .all(|v| 
(v.asgn.alpha - 1.0).abs() < 1E-16)); - let ix = extract_state_index(&state); *est_posterior.entry(ix).or_insert(0.0) += inc; } } - let posterior = calc_state_ln_posterior(features, &mut rng); + let process = emit_process(proc_type); + + let posterior = + calc_state_ln_posterior(features, &process, &process, &mut rng); assert!(!est_posterior.keys().any(|k| !posterior.contains_key(k))); @@ -293,23 +318,50 @@ mod tests { // TODO: could remove $test name by using mods macro_rules! state_enum_test { ($test_name: ident, $ftype: ident, $row_alg: ident, $col_alg: ident) => { - #[test] - fn $test_name() { - fn test_fn() -> bool { - let mut rng = rand::thread_rng(); - let err = state_enum_test( - 3, - 3, - 1, - 10_000, - RowAssignAlg::$row_alg, - ColAssignAlg::$col_alg, - FType::$ftype, - &mut rng, - ); - err < 0.01 + mod $test_name { + use super::*; + + #[test] + fn dirichlet() { + fn test_fn() -> bool { + let mut rng = rand::thread_rng(); + let err = state_enum_test( + 3, + 3, + 1, + 10_000, + RowAssignAlg::$row_alg, + ColAssignAlg::$col_alg, + FType::$ftype, + PriorProcessType::Dirichlet, + &mut rng, + ); + eprintln!("err: {err}"); + err < 0.01 + } + assert!(flaky_test_passes(N_TRIES, test_fn)); + } + + #[test] + fn pitman_yor() { + fn test_fn() -> bool { + let mut rng = rand::thread_rng(); + let err = state_enum_test( + 3, + 3, + 1, + 10_000, + RowAssignAlg::$row_alg, + ColAssignAlg::$col_alg, + FType::$ftype, + PriorProcessType::PitmanYor, + &mut rng, + ); + eprintln!("err: {err}"); + err < 0.01 + } + assert!(flaky_test_passes(N_TRIES, test_fn)); } - assert!(flaky_test_passes(N_TRIES, test_fn)); } }; ($(($fn_name: ident, $ftype: ident, $row_alg: ident, $col_alg: ident)),+) => { @@ -331,9 +383,11 @@ mod tests { #[test] fn ln_posterior_length() { + let process = emit_process(PriorProcessType::Dirichlet); let mut rng = rand::thread_rng(); let ftrs = build_features(3, 3, FType::Continuous, &mut rng); - let posterior = calc_state_ln_posterior(ftrs, &mut rng); + let 
posterior = + calc_state_ln_posterior(ftrs, &process, &process, &mut rng); assert_eq!(posterior.len(), 205) } diff --git a/lace/lace_cc/tests/state.rs b/lace/lace_cc/tests/state.rs index c3cbe852..386384f4 100644 --- a/lace/lace_cc/tests/state.rs +++ b/lace/lace_cc/tests/state.rs @@ -4,6 +4,7 @@ use lace_cc::feature::{ColModel, Column}; use lace_cc::state::State; use lace_data::{FeatureData, SparseContainer}; use lace_stats::prior::nix::NixHyper; +use lace_stats::prior_process::{Dirichlet, Process}; use lace_stats::rv::dist::{Gamma, Gaussian, NormalInvChiSquared}; use lace_stats::rv::traits::Rv; use rand::Rng; @@ -19,6 +20,13 @@ fn gen_col(id: usize, n: usize, mut rng: &mut R) -> ColModel { ColModel::Continuous(ftr) } +fn default_process(rng: &mut R) -> Process { + Process::Dirichlet(Dirichlet::from_prior( + Gamma::new(1.0, 1.0).unwrap(), + rng, + )) +} + fn gen_all_gauss_state( n_rows: usize, n_cols: usize, @@ -28,10 +36,11 @@ fn gen_all_gauss_state( for i in 0..n_cols { ftrs.push(gen_col(i, n_rows, &mut rng)); } + State::from_prior( ftrs, - Gamma::new(1.0, 1.0).unwrap(), - Gamma::new(1.0, 1.0).unwrap(), + default_process(rng), + default_process(rng), &mut rng, ) } @@ -186,9 +195,9 @@ fn two_part_runner( n_iters: 50, transitions: vec![ StateTransition::ColumnAssignment(first_algs.1), - StateTransition::StateAlpha, + StateTransition::StatePriorProcessParams, StateTransition::RowAssignment(first_algs.0), - StateTransition::ViewAlphas, + StateTransition::ViewPriorProcessParams, StateTransition::FeaturePriors, ], }; @@ -199,9 +208,9 @@ fn two_part_runner( n_iters: 50, transitions: vec![ StateTransition::ColumnAssignment(second_algs.1), - StateTransition::StateAlpha, + StateTransition::StatePriorProcessParams, StateTransition::RowAssignment(second_algs.0), - StateTransition::ViewAlphas, + StateTransition::ViewPriorProcessParams, StateTransition::FeaturePriors, ], }; diff --git a/lace/lace_cc/tests/view.rs b/lace/lace_cc/tests/view.rs index 9acb0f0c..44c40e6f 100644 --- 
a/lace/lace_cc/tests/view.rs +++ b/lace/lace_cc/tests/view.rs @@ -48,7 +48,7 @@ fn finite_reassign_direct_call() { let mut view = gen_gauss_view(10, &mut rng); view.reassign_rows_finite_cpu(&mut rng); - assert!(view.asgn.validate().is_valid()); + assert!(view.asgn().validate().is_valid()); } #[test] @@ -57,7 +57,7 @@ fn finite_reassign_from_reassign() { let mut view = gen_gauss_view(10, &mut rng); view.reassign(RowAssignAlg::FiniteCpu, &mut rng); - assert!(view.asgn.validate().is_valid()); + assert!(view.asgn().validate().is_valid()); } #[test] diff --git a/lace/lace_cc/tests/view_enum.rs b/lace/lace_cc/tests/view_enum.rs index 115ae333..ff773c56 100644 --- a/lace/lace_cc/tests/view_enum.rs +++ b/lace/lace_cc/tests/view_enum.rs @@ -5,20 +5,45 @@ mod enum_test; use std::collections::BTreeMap; -use lace_stats::rv::misc::logsumexp; use rand::Rng; use enum_test::{ build_features, normalize_assignment, partition_to_ix, Partition, }; use lace_cc::alg::RowAssignAlg; -use lace_cc::assignment::{lcrp, AssignmentBuilder}; use lace_cc::feature::{ColModel, FType, Feature}; use lace_cc::transition::ViewTransition; use lace_cc::view::{Builder, View}; +use lace_stats::prior_process::Builder as PriorProcessBuilder; +use lace_stats::prior_process::{Dirichlet, PitmanYor, Process}; +use lace_stats::rv::dist::{Beta, Gamma}; +use lace_stats::rv::misc::logsumexp; const N_TRIES: u32 = 5; +#[derive(Clone, Copy, Debug)] +pub enum ProcessType { + Dirichlet, + PitmanYor, +} + +impl From for Process { + fn from(proc: ProcessType) -> Self { + match proc { + ProcessType::Dirichlet => Process::Dirichlet(Dirichlet { + alpha: 1.0, + alpha_prior: Gamma::default(), + }), + ProcessType::PitmanYor => Process::PitmanYor(PitmanYor { + alpha: 1.2, + d: 0.2, + alpha_prior: Gamma::default(), + d_prior: Beta::jeffreys(), + }), + } + } +} + /// Compute the posterior of all assignments of the features under CRP(alpha) /// /// NOTE: The rng is required, for calling AssignmentBuilder.build, but nothing @@ -26,7 
+51,7 @@ const N_TRIES: u32 = 5; #[allow(clippy::ptr_arg)] fn calc_partition_ln_posterior( features: &Vec, - alpha: f64, + proc_type: ProcessType, mut rng: &mut R, ) -> BTreeMap { let n = features[0].len(); @@ -38,15 +63,16 @@ fn calc_partition_ln_posterior( // NOTE: We don't need seed control here because both alpha and the // assignment are set, but I'm setting the seed anyway in case the // assignment builder internals change - let asgn = AssignmentBuilder::from_vec(z) - .with_alpha(alpha) + let prior_process = PriorProcessBuilder::from_vec(z) + .with_process(proc_type.into()) .seed_from_rng(&mut rng) .build() .unwrap(); - let ln_pz = lcrp(n, &asgn.counts, alpha); + // let ln_pz = lcrp(n, &prior_process.asgn.counts, alpha); + let ln_pz = prior_process.ln_f_partition(&prior_process.asgn); - let view: View = Builder::from_assignment(asgn) + let view: View = Builder::from_prior_process(prior_process) .features(features.clone()) .seed_from_rng(&mut rng) .build(); @@ -77,10 +103,12 @@ pub fn view_enum_test( n_iters: usize, ftype: FType, row_alg: RowAssignAlg, + proc_type: ProcessType, ) -> f64 { let mut rng = rand::thread_rng(); let features = build_features(n_rows, n_cols, ftype, &mut rng); - let ln_posterior = calc_partition_ln_posterior(&features, 1.0, &mut rng); + let ln_posterior = + calc_partition_ln_posterior(&features, proc_type, &mut rng); let posterior = norm_posterior(&ln_posterior); let transitions: Vec = vec![ @@ -92,13 +120,13 @@ pub fn view_enum_test( let inc: f64 = ((n_runs * n_iters) as f64).recip(); for _ in 0..n_runs { - let asgn = AssignmentBuilder::new(n_rows) - .with_alpha(1.0) + let prior_process = PriorProcessBuilder::new(n_rows) + .with_process(proc_type.into()) .seed_from_rng(&mut rng) .build() .unwrap(); - let mut view = Builder::from_assignment(asgn) + let mut view = Builder::from_prior_process(prior_process) .features(features.clone()) .seed_from_rng(&mut rng) .build(); @@ -106,11 +134,11 @@ pub fn view_enum_test( for _ in 0..n_iters { 
view.update(10, &transitions, &mut rng); - let normed = normalize_assignment(view.asgn.asgn.clone()); + let normed = normalize_assignment(view.asgn().asgn.clone()); let ix = partition_to_ix(&normed); if !posterior.contains_key(&ix) { - panic!("invalid index!\n{:?}\n{:?}", view.asgn.asgn, normed); + panic!("invalid index!\n{:?}\n{:?}", view.asgn().asgn, normed); } *est_posterior.entry(ix).or_insert(0.0) += inc; @@ -144,7 +172,7 @@ where // TODO: could remove $test name by using mods macro_rules! view_enum_test { - ($ftype: ident, $row_alg: ident) => { + ($ftype: ident, $proctype: ident, $row_alg: ident) => { #[test] fn $row_alg() { fn test_fn() -> bool { @@ -154,32 +182,57 @@ macro_rules! view_enum_test { 1, 5_000, FType::$ftype, - RowAssignAlg::$row_alg + RowAssignAlg::$row_alg, + ProcessType::$proctype, ); + eprintln!("err: {}", err); err < 0.01 } assert!(flaky_test_passes(N_TRIES, test_fn)); } }; - ($ftype: ident, [$($row_alg: ident),+]) => { + ($modname: ident, $ftype: ident, $proctype: ident, [$($row_alg: ident),+]) => { #[allow(non_snake_case)] - mod $ftype { + mod $modname { use super::*; $( - view_enum_test!($ftype, $row_alg); + view_enum_test!($ftype, $proctype, $row_alg); )+ } }; - ($(($ftype: ident, $row_algs: tt)),+) => { + ($(($modname: ident, $ftype: ident, $proctype: ident, $row_algs: tt)),+) => { $( - view_enum_test!($ftype, $row_algs); + view_enum_test!($modname, $ftype, $proctype, $row_algs); )+ }; } view_enum_test!( - (Continuous, [Gibbs, Slice, Sams]), - (Categorical, [Gibbs, Slice, Sams]), - (Count, [Gibbs, Slice, Sams]) + ( + ve_continuous_dp, + Continuous, + Dirichlet, + [Gibbs, Slice, Sams] + ), + ( + ve_continuous_pyp, + Continuous, + PitmanYor, + [Gibbs, Slice, Sams] + ), + ( + ve_categorical_dp, + Categorical, + Dirichlet, + [Gibbs, Slice, Sams] + ), + ( + ve_categorical_pyp, + Categorical, + PitmanYor, + [Gibbs, Slice, Sams] + ), + (ve_count_dp, Count, Dirichlet, [Gibbs, Slice, Sams]), + (ve_count_pyp, Count, PitmanYor, [Gibbs, Slice, 
Sams]) ); diff --git a/lace/lace_codebook/Cargo.toml b/lace/lace_codebook/Cargo.toml index 824a12ba..a279569a 100644 --- a/lace/lace_codebook/Cargo.toml +++ b/lace/lace_codebook/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "lace_codebook" -version = "0.6.0" +version = "0.7.0" authors = ["Promised.ai"] edition = "2021" license = "BUSL-1.1" @@ -10,7 +10,7 @@ description = "Contains the Lace codebook specification as well as utilities for [dependencies] lace_consts = { path = "../lace_consts", version = "0.2.1" } -lace_stats = { path = "../lace_stats", version = "0.3.0" } +lace_stats = { path = "../lace_stats", version = "0.4.0" } lace_utils = { path = "../lace_utils", version = "0.3.0" } lace_data = { path = "../lace_data", version = "0.3.0" } serde = { version = "1", features = ["derive"] } @@ -19,9 +19,7 @@ thiserror = "1.0.11" polars = { version = "0.36", default_features=false, features=["csv", "dtype-i8", "dtype-i16", "dtype-u8", "dtype-u16"] } [features] -# default = ["formats"] formats = ["polars/json", "polars/ipc", "polars/decompress", "polars/parquet"] -# formats = [] [dev-dependencies] tempfile = "3.3.0" diff --git a/lace/lace_codebook/examples/bench_csv.rs b/lace/lace_codebook/examples/bench_csv.rs index 274f7afd..e62ffb88 100644 --- a/lace/lace_codebook/examples/bench_csv.rs +++ b/lace/lace_codebook/examples/bench_csv.rs @@ -19,7 +19,7 @@ fn main() { // println!("t_old: {}s", t_old.as_secs_f64()); let now = Instant::now(); - let _codebook = csv_new(path, None, None, true); + let _codebook = csv_new(path, None, None, None, true); let t_new = now.elapsed(); println!("t_new: {}s", t_new.as_secs_f64()); diff --git a/lace/lace_codebook/src/codebook.rs b/lace/lace_codebook/src/codebook.rs index 11f2c798..d851fb1c 100644 --- a/lace/lace_codebook/src/codebook.rs +++ b/lace/lace_codebook/src/codebook.rs @@ -7,7 +7,9 @@ use crate::ValueMap; use lace_stats::prior::csd::CsdHyper; use lace_stats::prior::nix::NixHyper; use lace_stats::prior::pg::PgHyper; -use 
lace_stats::rv::dist::{Gamma, NormalInvChiSquared, SymmetricDirichlet}; +use lace_stats::rv::dist::{ + Beta, Gamma, NormalInvChiSquared, SymmetricDirichlet, +}; use polars::prelude::DataFrame; use serde::{Deserialize, Serialize}; use std::collections::HashMap; @@ -343,6 +345,37 @@ impl TryFrom> for ColMetadataList { } } +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] +#[serde(rename_all = "snake_case")] +pub enum PriorProcess { + Dirichlet { alpha_prior: Gamma }, + PitmanYor { alpha_prior: Gamma, d_prior: Beta }, +} + +impl std::fmt::Display for PriorProcess { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + PriorProcess::Dirichlet { alpha_prior } => { + write!(f, "DP(α ~ {})", alpha_prior) + } + PriorProcess::PitmanYor { + alpha_prior, + d_prior, + } => { + write!(f, "PYP(α ~ {}, d ~ {})", alpha_prior, d_prior) + } + } + } +} + +impl Default for PriorProcess { + fn default() -> Self { + Self::Dirichlet { + alpha_prior: lace_consts::general_alpha_prior(), + } + } +} + /// Codebook object for storing information about the dataset #[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] #[serde(deny_unknown_fields)] @@ -350,9 +383,9 @@ pub struct Codebook { /// The name of the table pub table_name: String, /// Prior on State CRP alpha parameter - pub state_alpha_prior: Option, + pub state_prior_process: Option, /// Prior on View CRP alpha parameters - pub view_alpha_prior: Option, + pub view_prior_process: Option, /// The metadata for each column indexed by name pub col_metadata: ColMetadataList, /// Optional misc comments @@ -372,8 +405,8 @@ impl Codebook { Codebook { table_name, col_metadata, - view_alpha_prior: None, - state_alpha_prior: None, + view_prior_process: None, + state_prior_process: None, comments: None, row_names: RowNameList::new(), } @@ -385,15 +418,23 @@ impl Codebook { /// - df: the dataframe /// - cat_cutoff: the maximum value an integer column can take on before it /// is considered Count type 
instead of Categorical - /// - alpha_prior_opt: Optional Gamma prior on the column and row CRP alpha + /// - state_prior_process: The prior process on the column partition + /// - view_prior_process: The prior process on the row partitions /// - no_hypers: if `true` do not do prior parameter inference pub fn from_df( df: &DataFrame, cat_cutoff: Option, - alpha_prior_opt: Option, + state_prior_process: Option, + view_prior_process: Option, no_hypers: bool, ) -> Result { - df_to_codebook(df, cat_cutoff, alpha_prior_opt, no_hypers) + df_to_codebook( + df, + cat_cutoff, + state_prior_process, + view_prior_process, + no_hypers, + ) } pub fn from_yaml>(path: P) -> io::Result { @@ -761,8 +802,8 @@ mod tests { !Categorical k: 2 value_map: !u8 2 - state_alpha_prior: ~ - view_alpha_prior: ~ + state_prior_process: ~ + view_prior_process: ~ comments: ~ row_names: - one @@ -794,8 +835,8 @@ mod tests { coltype: !Categorical k: 2 - state_alpha_prior: ~ - view_alpha_prior: ~ + state_prior_process: ~ + view_prior_process: ~ comments: ~ row_names: - one @@ -810,8 +851,8 @@ mod tests { fn serialize_metadata_list() { let codebook = Codebook { table_name: "my-table".into(), - state_alpha_prior: None, - view_alpha_prior: None, + state_prior_process: None, + view_prior_process: None, comments: None, row_names: RowNameList::new(), col_metadata: ColMetadataList::try_from(vec![ @@ -854,8 +895,8 @@ mod tests { let raw = indoc!( r#" table_name: my-table - state_alpha_prior: null - view_alpha_prior: null + state_prior_process: null + view_prior_process: null col_metadata: - name: one coltype: !Continuous @@ -891,8 +932,8 @@ mod tests { fn serialize_then_deserialize() { let codebook = Codebook { table_name: "my-table".into(), - state_alpha_prior: None, - view_alpha_prior: None, + state_prior_process: None, + view_prior_process: None, comments: None, row_names: RowNameList::new(), col_metadata: ColMetadataList::try_from(vec![ diff --git a/lace/lace_codebook/src/data.rs 
b/lace/lace_codebook/src/data.rs index e525d543..3c545df1 100644 --- a/lace/lace_codebook/src/data.rs +++ b/lace/lace_codebook/src/data.rs @@ -1,11 +1,12 @@ +use crate::codebook::PriorProcess; use crate::error::{CodebookError, ReadError}; use crate::{ Codebook, ColMetadata, ColMetadataList, ColType, RowNameList, ValueMap, }; + use lace_stats::prior::csd::CsdHyper; use lace_stats::prior::nix::NixHyper; use lace_stats::prior::pg::PgHyper; -use lace_stats::rv::dist::Gamma; use polars::prelude::{CsvReader, DataFrame, DataType, SerReader, Series}; use std::convert::TryFrom; use std::path::Path; @@ -428,7 +429,8 @@ fn rownames_from_index(id_srs: &Series) -> Result { pub fn df_to_codebook( df: &DataFrame, cat_cutoff: Option, - alpha_prior_opt: Option, + state_prior_process: Option, + view_prior_process: Option, no_hypers: bool, ) -> Result { let (col_metadata, row_names) = { @@ -457,13 +459,10 @@ pub fn df_to_codebook( (col_metadata, row_names) }; - let alpha_prior = - alpha_prior_opt.unwrap_or_else(lace_consts::general_alpha_prior); - Ok(Codebook { table_name: "my_table".into(), - state_alpha_prior: Some(alpha_prior.clone()), - view_alpha_prior: Some(alpha_prior), + state_prior_process, + view_prior_process, col_metadata, row_names, comments: None, @@ -481,11 +480,18 @@ pub fn read_csv>(path: P) -> Result { pub fn codebook_from_csv>( path: P, cat_cutoff: Option, - alpha_prior_opt: Option, + state_prior_process: Option, + view_prior_process: Option, no_hypers: bool, ) -> Result { let df = read_csv(path).unwrap(); - df_to_codebook(&df, cat_cutoff, alpha_prior_opt, no_hypers) + df_to_codebook( + &df, + cat_cutoff, + state_prior_process, + view_prior_process, + no_hypers, + ) } #[cfg(test)] @@ -521,7 +527,7 @@ mod test { let file = write_to_tempfile(&data); let codebook = - codebook_from_csv(file.path(), None, None, false).unwrap(); + codebook_from_csv(file.path(), None, None, None, false).unwrap(); assert_eq!(codebook.col_metadata.len(), 5); 
assert_eq!(codebook.row_names.len(), 5); diff --git a/lace/lace_codebook/src/formats.rs b/lace/lace_codebook/src/formats.rs index 5cb80b85..9f0e8150 100644 --- a/lace/lace_codebook/src/formats.rs +++ b/lace/lace_codebook/src/formats.rs @@ -1,5 +1,4 @@ use crate::ReadError; -use lace_stats::rv::dist::Gamma; use polars::prelude::{ CsvReader, DataFrame, IpcReader, JsonFormat, JsonReader, ParquetReader, SerReader, @@ -54,14 +53,16 @@ macro_rules! codebook_from_fn { pub fn $fn_name>( path: P, cat_cutoff: Option, - alpha_prior_opt: Option, + state_prior_process: Option<$crate::codebook::PriorProcess>, + view_prior_process: Option<$crate::codebook::PriorProcess>, no_hypers: bool, ) -> Result<$crate::codebook::Codebook, $crate::error::CodebookError> { let df = $reader(path).unwrap(); $crate::data::df_to_codebook( &df, cat_cutoff, - alpha_prior_opt, + state_prior_process, + view_prior_process, no_hypers, ) } diff --git a/lace/lace_codebook/src/lib.rs b/lace/lace_codebook/src/lib.rs index c420b815..54d04969 100644 --- a/lace/lace_codebook/src/lib.rs +++ b/lace/lace_codebook/src/lib.rs @@ -17,14 +17,19 @@ //! let codebook_str = indoc!(" //! --- //! table_name: two column dataset -//! state_alpha_prior: -//! !Gamma -//! shape: 1.0 -//! rate: 1.0 -//! view_alpha_prior: -//! !Gamma -//! shape: 1.0 -//! rate: 1.0 +//! state_prior_process: +//! !dirichlet +//! alpha_prior: +//! shape: 1.0 +//! rate: 1.0 +//! view_prior_process: +//! !pitman_yor +//! alpha_prior: +//! shape: 1.0 +//! rate: 1.0 +//! d_prior: +//! alpha: 1.0 +//! beta: 2.0 //! col_metadata: //! - name: col_1 //! 
notes: first column with all fields filled in diff --git a/lace/lace_geweke/Cargo.toml b/lace/lace_geweke/Cargo.toml index 34920f4e..db782e19 100644 --- a/lace/lace_geweke/Cargo.toml +++ b/lace/lace_geweke/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "lace_geweke" -version = "0.3.0" +version = "0.4.0" authors = ["Promised AI"] edition = "2021" license = "BUSL-1.1" @@ -9,7 +9,7 @@ repository = "https://github.com/promised-ai/lace" description = "Geweke tester for Lace" [dependencies] -lace_stats = { path = "../lace_stats", version = "0.3.0" } +lace_stats = { path = "../lace_stats", version = "0.4.0" } lace_utils = { path = "../lace_utils", version = "0.3.0" } serde = { version = "1", features = ["derive"] } serde_yaml = "0.9.4" diff --git a/lace/lace_metadata/Cargo.toml b/lace/lace_metadata/Cargo.toml index 8662dd51..7d8d2296 100644 --- a/lace/lace_metadata/Cargo.toml +++ b/lace/lace_metadata/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "lace_metadata" -version = "0.6.0" +version = "0.7.0" authors = ["Promised AI"] edition = "2021" license = "BUSL-1.1" @@ -9,10 +9,10 @@ repository = "https://github.com/promised-ai/lace" description = "Archive of the metadata (savefile) formats for Lace. In charge of versioning and conversion." 
[dependencies] -lace_stats = { path = "../lace_stats", version = "0.3.0" } +lace_stats = { path = "../lace_stats", version = "0.4.0" } lace_data = { path = "../lace_data", version = "0.3.0" } -lace_codebook = { path = "../lace_codebook", version = "0.6.0" } -lace_cc = { path = "../lace_cc", version = "0.6.0" } +lace_codebook = { path = "../lace_codebook", version = "0.7.0" } +lace_cc = { path = "../lace_cc", version = "0.7.0" } serde = { version = "1", features = ["derive"] } serde_yaml = "0.9.4" serde_json = "1" diff --git a/lace/lace_metadata/resources/test/metadata/v0/codebook.yaml b/lace/lace_metadata/resources/test/metadata/v0/codebook.yaml new file mode 100644 index 00000000..95511be3 --- /dev/null +++ b/lace/lace_metadata/resources/test/metadata/v0/codebook.yaml @@ -0,0 +1,316 @@ +table_name: my_data +state_alpha_prior: + shape: 2.0 + rate: 1.0 +view_alpha_prior: + shape: 1.0 + rate: 2.0 +col_metadata: +- name: cat + coltype: !Categorical + k: 3 + hyper: + pr_alpha: + shape: 1.0 + scale: 1.0 + value_map: !string + 0: blue + 1: green + 2: red + prior: null + notes: null + missing_not_at_random: false +- name: cts + coltype: !Continuous + hyper: + pr_m: + mu: -0.12554951392089267 + sigma: 1.0060712163273107 + pr_k: + shape: 1.0 + rate: 1.0 + pr_v: + shape: 5.298317366548036 + scale: 5.298317366548036 + pr_s2: + shape: 5.298317366548036 + scale: 1.0121792923223145 + prior: null + notes: null + missing_not_at_random: false +- name: count + coltype: !Count + hyper: + pr_shape: + shape: 1.2463818344393327 + rate: 1.0 + pr_rate: + shape: 67.03804379304536 + scale: 1.0 + prior: null + notes: null + missing_not_at_random: false +- name: cat_msng + coltype: !Categorical + k: 3 + hyper: + pr_alpha: + shape: 1.0 + scale: 1.0 + value_map: !string + 0: blue + 1: green + 2: red + prior: null + notes: null + missing_not_at_random: false +- name: cts_msng + coltype: !Continuous + hyper: + pr_m: + mu: -0.1339303540874964 + sigma: 1.018971505873358 + pr_k: + shape: 1.0 + rate: 
1.0 + pr_v: + shape: 4.941642422609304 + scale: 4.941642422609304 + pr_s2: + shape: 4.941642422609304 + scale: 1.038302929781819 + prior: null + notes: null + missing_not_at_random: true +- name: count_msng + coltype: !Continuous + hyper: + pr_m: + mu: 87.98571428571428 + sigma: 67.47147394209178 + pr_k: + shape: 1.0 + rate: 1.0 + pr_v: + shape: 4.941642422609304 + scale: 4.941642422609304 + pr_s2: + shape: 4.941642422609304 + scale: 4552.399795918369 + prior: null + notes: null + missing_not_at_random: true +- name: count_msgn + coltype: !Count + hyper: + pr_shape: + shape: 1.148060403429 + rate: 1.0 + pr_rate: + shape: 76.72567926097102 + scale: 1.0 + prior: null + notes: null + missing_not_at_random: true +comments: null +row_names: +- '0' +- '1' +- '2' +- '3' +- '4' +- '5' +- '6' +- '7' +- '8' +- '9' +- '10' +- '11' +- '12' +- '13' +- '14' +- '15' +- '16' +- '17' +- '18' +- '19' +- '20' +- '21' +- '22' +- '23' +- '24' +- '25' +- '26' +- '27' +- '28' +- '29' +- '30' +- '31' +- '32' +- '33' +- '34' +- '35' +- '36' +- '37' +- '38' +- '39' +- '40' +- '41' +- '42' +- '43' +- '44' +- '45' +- '46' +- '47' +- '48' +- '49' +- '50' +- '51' +- '52' +- '53' +- '54' +- '55' +- '56' +- '57' +- '58' +- '59' +- '60' +- '61' +- '62' +- '63' +- '64' +- '65' +- '66' +- '67' +- '68' +- '69' +- '70' +- '71' +- '72' +- '73' +- '74' +- '75' +- '76' +- '77' +- '78' +- '79' +- '80' +- '81' +- '82' +- '83' +- '84' +- '85' +- '86' +- '87' +- '88' +- '89' +- '90' +- '91' +- '92' +- '93' +- '94' +- '95' +- '96' +- '97' +- '98' +- '99' +- '100' +- '101' +- '102' +- '103' +- '104' +- '105' +- '106' +- '107' +- '108' +- '109' +- '110' +- '111' +- '112' +- '113' +- '114' +- '115' +- '116' +- '117' +- '118' +- '119' +- '120' +- '121' +- '122' +- '123' +- '124' +- '125' +- '126' +- '127' +- '128' +- '129' +- '130' +- '131' +- '132' +- '133' +- '134' +- '135' +- '136' +- '137' +- '138' +- '139' +- '140' +- '141' +- '142' +- '143' +- '144' +- '145' +- '146' +- '147' +- '148' +- '149' +- '150' +- 
'151' +- '152' +- '153' +- '154' +- '155' +- '156' +- '157' +- '158' +- '159' +- '160' +- '161' +- '162' +- '163' +- '164' +- '165' +- '166' +- '167' +- '168' +- '169' +- '170' +- '171' +- '172' +- '173' +- '174' +- '175' +- '176' +- '177' +- '178' +- '179' +- '180' +- '181' +- '182' +- '183' +- '184' +- '185' +- '186' +- '187' +- '188' +- '189' +- '190' +- '191' +- '192' +- '193' +- '194' +- '195' +- '196' +- '197' +- '198' +- '199' diff --git a/lace/lace_metadata/resources/test/metadata/v0/data.parquet b/lace/lace_metadata/resources/test/metadata/v0/data.parquet new file mode 100644 index 00000000..1a376798 Binary files /dev/null and b/lace/lace_metadata/resources/test/metadata/v0/data.parquet differ diff --git a/lace/lace_metadata/resources/test/metadata/v0/metadata.lace/0.diagnostics.csv b/lace/lace_metadata/resources/test/metadata/v0/metadata.lace/0.diagnostics.csv new file mode 100644 index 00000000..467c5235 --- /dev/null +++ b/lace/lace_metadata/resources/test/metadata/v0/metadata.lace/0.diagnostics.csv @@ -0,0 +1,2 @@ +loglike,logprior +-10925.41459790704,-248.52943029952922 diff --git a/lace/lace_metadata/resources/test/metadata/v0/metadata.lace/0.state b/lace/lace_metadata/resources/test/metadata/v0/metadata.lace/0.state new file mode 100644 index 00000000..41a9459e Binary files /dev/null and b/lace/lace_metadata/resources/test/metadata/v0/metadata.lace/0.state differ diff --git a/lace/lace_metadata/resources/test/metadata/v0/metadata.lace/1.diagnostics.csv b/lace/lace_metadata/resources/test/metadata/v0/metadata.lace/1.diagnostics.csv new file mode 100644 index 00000000..a3e36b48 --- /dev/null +++ b/lace/lace_metadata/resources/test/metadata/v0/metadata.lace/1.diagnostics.csv @@ -0,0 +1,2 @@ +loglike,logprior +-11192.073888273855,-125.03911560348585 diff --git a/lace/lace_metadata/resources/test/metadata/v0/metadata.lace/1.state b/lace/lace_metadata/resources/test/metadata/v0/metadata.lace/1.state new file mode 100644 index 00000000..259fad55 Binary 
files /dev/null and b/lace/lace_metadata/resources/test/metadata/v0/metadata.lace/1.state differ diff --git a/lace/lace_metadata/resources/test/metadata/v0/metadata.lace/2.diagnostics.csv b/lace/lace_metadata/resources/test/metadata/v0/metadata.lace/2.diagnostics.csv new file mode 100644 index 00000000..d7776299 --- /dev/null +++ b/lace/lace_metadata/resources/test/metadata/v0/metadata.lace/2.diagnostics.csv @@ -0,0 +1,2 @@ +loglike,logprior +-5369.307571819231,-402.8832169852534 diff --git a/lace/lace_metadata/resources/test/metadata/v0/metadata.lace/2.state b/lace/lace_metadata/resources/test/metadata/v0/metadata.lace/2.state new file mode 100644 index 00000000..48e2599e Binary files /dev/null and b/lace/lace_metadata/resources/test/metadata/v0/metadata.lace/2.state differ diff --git a/lace/lace_metadata/resources/test/metadata/v0/metadata.lace/3.diagnostics.csv b/lace/lace_metadata/resources/test/metadata/v0/metadata.lace/3.diagnostics.csv new file mode 100644 index 00000000..2daa298a --- /dev/null +++ b/lace/lace_metadata/resources/test/metadata/v0/metadata.lace/3.diagnostics.csv @@ -0,0 +1,2 @@ +loglike,logprior +-11498.565830803074,-91.28149757903498 diff --git a/lace/lace_metadata/resources/test/metadata/v0/metadata.lace/3.state b/lace/lace_metadata/resources/test/metadata/v0/metadata.lace/3.state new file mode 100644 index 00000000..ef566538 Binary files /dev/null and b/lace/lace_metadata/resources/test/metadata/v0/metadata.lace/3.state differ diff --git a/lace/lace_metadata/resources/test/metadata/v0/metadata.lace/config.yaml b/lace/lace_metadata/resources/test/metadata/v0/metadata.lace/config.yaml new file mode 100644 index 00000000..f0154028 --- /dev/null +++ b/lace/lace_metadata/resources/test/metadata/v0/metadata.lace/config.yaml @@ -0,0 +1,2 @@ +metadata_version: 0 +serialized_type: bincode diff --git a/lace/lace_metadata/resources/test/metadata/v0/metadata.lace/lace.codebook b/lace/lace_metadata/resources/test/metadata/v0/metadata.lace/lace.codebook 
new file mode 100644 index 00000000..2b13f8c4 Binary files /dev/null and b/lace/lace_metadata/resources/test/metadata/v0/metadata.lace/lace.codebook differ diff --git a/lace/lace_metadata/resources/test/metadata/v0/metadata.lace/lace.data b/lace/lace_metadata/resources/test/metadata/v0/metadata.lace/lace.data new file mode 100644 index 00000000..056161e0 Binary files /dev/null and b/lace/lace_metadata/resources/test/metadata/v0/metadata.lace/lace.data differ diff --git a/lace/lace_metadata/resources/test/metadata/v0/metadata.lace/rng.yaml b/lace/lace_metadata/resources/test/metadata/v0/metadata.lace/rng.yaml new file mode 100644 index 00000000..c951a030 --- /dev/null +++ b/lace/lace_metadata/resources/test/metadata/v0/metadata.lace/rng.yaml @@ -0,0 +1,5 @@ +s: +- 6739545338414267512 +- 6183947708318209450 +- 3453482548197079121 +- 4896452037606409197 diff --git a/lace/lace_metadata/resources/test/metadata/v1/codebook.yaml b/lace/lace_metadata/resources/test/metadata/v1/codebook.yaml new file mode 100644 index 00000000..8207fb5b --- /dev/null +++ b/lace/lace_metadata/resources/test/metadata/v1/codebook.yaml @@ -0,0 +1,312 @@ +table_name: my_table +state_prior_process: null +view_prior_process: null +col_metadata: +- name: cat + coltype: !Categorical + k: 3 + hyper: + pr_alpha: + shape: 1.0 + scale: 1.0 + value_map: !string + 0: blue + 1: green + 2: red + prior: null + notes: null + missing_not_at_random: false +- name: cts + coltype: !Continuous + hyper: + pr_m: + mu: -0.12554951392089267 + sigma: 1.0060712163273107 + pr_k: + shape: 1.0 + rate: 1.0 + pr_v: + shape: 5.298317366548036 + scale: 5.298317366548036 + pr_s2: + shape: 5.298317366548036 + scale: 1.0121792923223145 + prior: null + notes: null + missing_not_at_random: false +- name: count + coltype: !Count + hyper: + pr_shape: + shape: 1.2463818344393327 + rate: 1.0 + pr_rate: + shape: 67.03804379304536 + scale: 1.0 + prior: null + notes: null + missing_not_at_random: false +- name: cat_msng + coltype: 
!Categorical + k: 3 + hyper: + pr_alpha: + shape: 1.0 + scale: 1.0 + value_map: !string + 0: blue + 1: green + 2: red + prior: null + notes: null + missing_not_at_random: false +- name: cts_msng + coltype: !Continuous + hyper: + pr_m: + mu: -0.1339303540874964 + sigma: 1.018971505873358 + pr_k: + shape: 1.0 + rate: 1.0 + pr_v: + shape: 4.941642422609304 + scale: 4.941642422609304 + pr_s2: + shape: 4.941642422609304 + scale: 1.038302929781819 + prior: null + notes: null + missing_not_at_random: false +- name: count_msng + coltype: !Continuous + hyper: + pr_m: + mu: 87.98571428571428 + sigma: 67.47147394209178 + pr_k: + shape: 1.0 + rate: 1.0 + pr_v: + shape: 4.941642422609304 + scale: 4.941642422609304 + pr_s2: + shape: 4.941642422609304 + scale: 4552.399795918369 + prior: null + notes: null + missing_not_at_random: false +- name: count_msgn + coltype: !Count + hyper: + pr_shape: + shape: 1.148060403429 + rate: 1.0 + pr_rate: + shape: 76.72567926097102 + scale: 1.0 + prior: null + notes: null + missing_not_at_random: false +comments: null +row_names: +- '0' +- '1' +- '2' +- '3' +- '4' +- '5' +- '6' +- '7' +- '8' +- '9' +- '10' +- '11' +- '12' +- '13' +- '14' +- '15' +- '16' +- '17' +- '18' +- '19' +- '20' +- '21' +- '22' +- '23' +- '24' +- '25' +- '26' +- '27' +- '28' +- '29' +- '30' +- '31' +- '32' +- '33' +- '34' +- '35' +- '36' +- '37' +- '38' +- '39' +- '40' +- '41' +- '42' +- '43' +- '44' +- '45' +- '46' +- '47' +- '48' +- '49' +- '50' +- '51' +- '52' +- '53' +- '54' +- '55' +- '56' +- '57' +- '58' +- '59' +- '60' +- '61' +- '62' +- '63' +- '64' +- '65' +- '66' +- '67' +- '68' +- '69' +- '70' +- '71' +- '72' +- '73' +- '74' +- '75' +- '76' +- '77' +- '78' +- '79' +- '80' +- '81' +- '82' +- '83' +- '84' +- '85' +- '86' +- '87' +- '88' +- '89' +- '90' +- '91' +- '92' +- '93' +- '94' +- '95' +- '96' +- '97' +- '98' +- '99' +- '100' +- '101' +- '102' +- '103' +- '104' +- '105' +- '106' +- '107' +- '108' +- '109' +- '110' +- '111' +- '112' +- '113' +- '114' +- '115' 
+- '116' +- '117' +- '118' +- '119' +- '120' +- '121' +- '122' +- '123' +- '124' +- '125' +- '126' +- '127' +- '128' +- '129' +- '130' +- '131' +- '132' +- '133' +- '134' +- '135' +- '136' +- '137' +- '138' +- '139' +- '140' +- '141' +- '142' +- '143' +- '144' +- '145' +- '146' +- '147' +- '148' +- '149' +- '150' +- '151' +- '152' +- '153' +- '154' +- '155' +- '156' +- '157' +- '158' +- '159' +- '160' +- '161' +- '162' +- '163' +- '164' +- '165' +- '166' +- '167' +- '168' +- '169' +- '170' +- '171' +- '172' +- '173' +- '174' +- '175' +- '176' +- '177' +- '178' +- '179' +- '180' +- '181' +- '182' +- '183' +- '184' +- '185' +- '186' +- '187' +- '188' +- '189' +- '190' +- '191' +- '192' +- '193' +- '194' +- '195' +- '196' +- '197' +- '198' +- '199' diff --git a/lace/lace_metadata/resources/test/metadata/v1/metadata.lace/0.diagnostics.csv b/lace/lace_metadata/resources/test/metadata/v1/metadata.lace/0.diagnostics.csv new file mode 100644 index 00000000..8af4f6ab --- /dev/null +++ b/lace/lace_metadata/resources/test/metadata/v1/metadata.lace/0.diagnostics.csv @@ -0,0 +1,11 @@ +loglike,logprior +-9322.573577333967,0.8011617909578268 +-8091.773286295056,17.57728505848032 +-7102.691003107509,-20.450860835926722 +-6109.692165910345,6.399352506589128 +-5544.004421332932,-3.0875989795577103 +-5092.780889446918,7.681785880761241 +-4705.777292758443,-12.024459034238777 +-4382.645397384758,-4.177482580888658 +-4037.8726667513747,-11.29922787781795 +-3994.6107845240213,-14.339492405667933 diff --git a/lace/lace_metadata/resources/test/metadata/v1/metadata.lace/0.state b/lace/lace_metadata/resources/test/metadata/v1/metadata.lace/0.state new file mode 100644 index 00000000..01b8b70d Binary files /dev/null and b/lace/lace_metadata/resources/test/metadata/v1/metadata.lace/0.state differ diff --git a/lace/lace_metadata/resources/test/metadata/v1/metadata.lace/1.diagnostics.csv b/lace/lace_metadata/resources/test/metadata/v1/metadata.lace/1.diagnostics.csv new file mode 100644 index 
00000000..393387f6 --- /dev/null +++ b/lace/lace_metadata/resources/test/metadata/v1/metadata.lace/1.diagnostics.csv @@ -0,0 +1,11 @@ +loglike,logprior +-9609.92370618558,-13.14814572031007 +-8678.653104315887,-11.71974328539811 +-7778.630766283191,-11.168982762963822 +-4914.483172879051,-24.279491361061424 +-4384.694911219802,-2.399070504063549 +-4358.782617401552,-3.786398860100305 +-4232.529895337348,-8.448527259869834 +-4096.343492636015,8.544125037466497 +-3917.593292727098,10.223971445645432 +-3732.6427709200702,-6.678249358526918 diff --git a/lace/lace_metadata/resources/test/metadata/v1/metadata.lace/1.state b/lace/lace_metadata/resources/test/metadata/v1/metadata.lace/1.state new file mode 100644 index 00000000..70222a46 Binary files /dev/null and b/lace/lace_metadata/resources/test/metadata/v1/metadata.lace/1.state differ diff --git a/lace/lace_metadata/resources/test/metadata/v1/metadata.lace/2.diagnostics.csv b/lace/lace_metadata/resources/test/metadata/v1/metadata.lace/2.diagnostics.csv new file mode 100644 index 00000000..cb605f0a --- /dev/null +++ b/lace/lace_metadata/resources/test/metadata/v1/metadata.lace/2.diagnostics.csv @@ -0,0 +1,11 @@ +loglike,logprior +-7820.733122404915,1.8584824986981152 +-5631.531862829009,-18.035523971478796 +-4829.045061186832,-12.881169652516501 +-4275.600507658869,-13.90060813571345 +-3985.183487089194,-17.950836613327276 +-3726.706477619133,-20.18027368130369 +-3558.9431204799303,-15.093427209793461 +-3524.300832642899,-16.792937576289212 +-3517.030236367397,-18.33508257674344 +-3485.1648002590455,-15.8390617875546 diff --git a/lace/lace_metadata/resources/test/metadata/v1/metadata.lace/2.state b/lace/lace_metadata/resources/test/metadata/v1/metadata.lace/2.state new file mode 100644 index 00000000..4302ede2 Binary files /dev/null and b/lace/lace_metadata/resources/test/metadata/v1/metadata.lace/2.state differ diff --git a/lace/lace_metadata/resources/test/metadata/v1/metadata.lace/3.diagnostics.csv 
b/lace/lace_metadata/resources/test/metadata/v1/metadata.lace/3.diagnostics.csv new file mode 100644 index 00000000..3be23ffc --- /dev/null +++ b/lace/lace_metadata/resources/test/metadata/v1/metadata.lace/3.diagnostics.csv @@ -0,0 +1,11 @@ +loglike,logprior +-10887.659117268042,4.4677284438580696 +-10748.404837614287,-2.3130437622533906 +-10735.24505083381,-7.558075372957182 +-10702.25510289215,-8.764376465644666 +-10705.612794045579,-3.8903977043777562 +-10666.769325387486,-10.779826270298251 +-10318.86014024385,-1.4776990027086931 +-10124.650167386344,-0.9534088286827629 +-9915.202418761974,-17.308900887721407 +-9824.615186461757,-12.127615173204081 diff --git a/lace/lace_metadata/resources/test/metadata/v1/metadata.lace/3.state b/lace/lace_metadata/resources/test/metadata/v1/metadata.lace/3.state new file mode 100644 index 00000000..4148a53d Binary files /dev/null and b/lace/lace_metadata/resources/test/metadata/v1/metadata.lace/3.state differ diff --git a/lace/lace_metadata/resources/test/metadata/v1/metadata.lace/config.yaml b/lace/lace_metadata/resources/test/metadata/v1/metadata.lace/config.yaml new file mode 100644 index 00000000..bcff6727 --- /dev/null +++ b/lace/lace_metadata/resources/test/metadata/v1/metadata.lace/config.yaml @@ -0,0 +1,2 @@ +metadata_version: 1 +serialized_type: bincode diff --git a/lace/lace_metadata/resources/test/metadata/v1/metadata.lace/lace.codebook b/lace/lace_metadata/resources/test/metadata/v1/metadata.lace/lace.codebook new file mode 100644 index 00000000..502285b3 Binary files /dev/null and b/lace/lace_metadata/resources/test/metadata/v1/metadata.lace/lace.codebook differ diff --git a/lace/lace_metadata/resources/test/metadata/v1/metadata.lace/lace.data b/lace/lace_metadata/resources/test/metadata/v1/metadata.lace/lace.data new file mode 100644 index 00000000..056161e0 Binary files /dev/null and b/lace/lace_metadata/resources/test/metadata/v1/metadata.lace/lace.data differ diff --git 
a/lace/lace_metadata/resources/test/metadata/v1/metadata.lace/rng.yaml b/lace/lace_metadata/resources/test/metadata/v1/metadata.lace/rng.yaml new file mode 100644 index 00000000..f866d400 --- /dev/null +++ b/lace/lace_metadata/resources/test/metadata/v1/metadata.lace/rng.yaml @@ -0,0 +1,5 @@ +s: +- 6673150548032836391 +- 8660570340892567824 +- 6313710878850318851 +- 13781755498804928295 diff --git a/lace/lace_metadata/src/config.rs b/lace/lace_metadata/src/config.rs index 18254a70..6bba2345 100644 --- a/lace/lace_metadata/src/config.rs +++ b/lace/lace_metadata/src/config.rs @@ -28,7 +28,7 @@ impl FromStr for SerializedType { impl Default for SerializedType { fn default() -> Self { - Self::Yaml + Self::Bincode } } diff --git a/lace/lace_metadata/src/latest.rs b/lace/lace_metadata/src/latest.rs index 2fed4fb2..a00bcb61 100644 --- a/lace/lace_metadata/src/latest.rs +++ b/lace/lace_metadata/src/latest.rs @@ -1,28 +1,18 @@ use std::collections::BTreeMap; -use lace_cc::assignment::Assignment; -use lace_cc::component::ConjugateComponent; -use lace_cc::feature::{ColModel, Column, MissingNotAtRandom}; -use lace_cc::state::{State, StateDiagnostics}; -use lace_cc::traits::{LaceDatum, LaceLikelihood, LacePrior, LaceStat}; +use lace_cc::feature::{ColModel, MissingNotAtRandom}; +use lace_cc::state::{State, StateDiagnostics, StateScoreComponents}; use lace_cc::view::View; -use lace_data::{FeatureData, SparseContainer}; -use lace_stats::prior::csd::CsdHyper; -use lace_stats::prior::nix::NixHyper; -use lace_stats::prior::pg::PgHyper; -use lace_stats::rv::dist::{ - Bernoulli, Beta, Categorical, Gamma, Gaussian, Mixture, - NormalInvChiSquared, Poisson, SymmetricDirichlet, -}; -use lace_stats::MixtureType; +use lace_stats::assignment::Assignment; +use lace_stats::prior_process::{PriorProcess, Process}; use rand_xoshiro::Xoshiro256Plus; -use serde::{de::DeserializeOwned, Deserialize, Serialize}; -use std::sync::OnceLock; +use serde::{Deserialize, Serialize}; +use crate::versions::v1; 
use crate::{impl_metadata_version, to_from_newtype, MetadataVersion}; -pub const METADATA_VERSION: i32 = 0; +pub const METADATA_VERSION: i32 = 1; #[derive(Serialize, Deserialize, Debug)] #[serde(deny_unknown_fields)] @@ -35,26 +25,10 @@ pub struct Metadata { pub states: Vec, pub state_ids: Vec, pub codebook: Codebook, - pub data: Option, + pub data: Option, pub rng: Option, } -#[derive(Serialize, Deserialize, Debug)] -#[serde(deny_unknown_fields)] -pub struct DataStore(BTreeMap); - -impl From for DataStore { - fn from(data: lace_data::DataStore) -> Self { - Self(data.0) - } -} - -impl From for lace_data::DataStore { - fn from(data: DataStore) -> Self { - Self(data.0) - } -} - #[derive(Serialize, Deserialize, Debug)] pub struct DatalessStateAndDiagnostics { pub state: DatalessState, @@ -65,16 +39,9 @@ pub struct DatalessStateAndDiagnostics { #[serde(deny_unknown_fields)] pub struct DatalessState { pub views: Vec, - pub asgn: Assignment, + pub prior_process: PriorProcess, pub weights: Vec, - pub view_alpha_prior: Gamma, - pub loglike: f64, - #[serde(default)] - pub log_prior: f64, - #[serde(default)] - pub log_view_alpha_prior: f64, - #[serde(default)] - pub log_state_alpha_prior: f64, + pub score: StateScoreComponents, } /// Marks a state as having no data in its columns @@ -85,13 +52,9 @@ impl From for DatalessStateAndDiagnostics { Self { state: DatalessState { views: state.views.drain(..).map(|view| view.into()).collect(), - asgn: state.asgn, + prior_process: state.prior_process, weights: state.weights, - view_alpha_prior: state.view_alpha_prior, - loglike: state.loglike, - log_prior: state.log_prior, - log_view_alpha_prior: state.log_view_alpha_prior, - log_state_alpha_prior: state.log_state_alpha_prior, + score: state.score, }, diagnostics: state.diagnostics, } @@ -113,21 +76,24 @@ impl From for EmptyState { .map(|id| { let dl_ftr = dl_view.ftrs.remove(&id).unwrap(); let cm: ColModel = match dl_ftr { - DatalessColModel::Continuous(cm) => { - let ecm: 
EmptyColumn<_, _, _, _> = cm.into(); + v1::DatalessColModel::Continuous(cm) => { + let ecm: v1::EmptyColumn<_, _, _, _> = + cm.into(); ColModel::Continuous(ecm.0) } - DatalessColModel::Categorical(cm) => { - let ecm: EmptyColumn<_, _, _, _> = cm.into(); + v1::DatalessColModel::Categorical(cm) => { + let ecm: v1::EmptyColumn<_, _, _, _> = + cm.into(); ColModel::Categorical(ecm.0) } - DatalessColModel::Count(cm) => { - let ecm: EmptyColumn<_, _, _, _> = cm.into(); + v1::DatalessColModel::Count(cm) => { + let ecm: v1::EmptyColumn<_, _, _, _> = + cm.into(); ColModel::Count(ecm.0) } - DatalessColModel::MissingNotAtRandom(mnar) => { + v1::DatalessColModel::MissingNotAtRandom(mnar) => { let fx: ColModel = (*mnar.fx).into(); - let missing: EmptyColumn<_, _, _, _> = + let missing: v1::EmptyColumn<_, _, _, _> = mnar.missing.into(); ColModel::MissingNotAtRandom( MissingNotAtRandom { @@ -142,7 +108,7 @@ impl From for EmptyState { .collect(); View { - asgn: dl_view.asgn, + prior_process: dl_view.prior_process, weights: dl_view.weights, ftrs, } @@ -151,13 +117,9 @@ impl From for EmptyState { EmptyState(State { views, - asgn: dl_state.state.asgn, + prior_process: dl_state.state.prior_process, weights: dl_state.state.weights, - view_alpha_prior: dl_state.state.view_alpha_prior, - loglike: dl_state.state.loglike, - log_prior: dl_state.state.log_prior, - log_view_alpha_prior: dl_state.state.log_view_alpha_prior, - log_state_alpha_prior: dl_state.state.log_state_alpha_prior, + score: dl_state.state.score, diagnostics: dl_state.diagnostics, }) } @@ -166,8 +128,8 @@ impl From for EmptyState { #[derive(Serialize, Deserialize, Debug)] #[serde(deny_unknown_fields)] pub struct DatalessView { - pub ftrs: BTreeMap, - pub asgn: Assignment, + pub ftrs: BTreeMap, + pub prior_process: PriorProcess, pub weights: Vec, } @@ -180,166 +142,107 @@ impl From for DatalessView { .map(|k| (*k, view.ftrs.remove(k).unwrap().into())) .collect() }, - asgn: view.asgn, + prior_process: view.prior_process, 
weights: view.weights, } } } -#[derive(Serialize, Deserialize, Debug)] -#[serde(deny_unknown_fields)] -pub enum DatalessColModel { - Continuous(DatalessColumn), - Categorical(DatalessColumn), - Count(DatalessColumn), - MissingNotAtRandom(DatalessMissingNotAtRandom), -} - -impl From for DatalessColModel { - fn from(col_model: ColModel) -> DatalessColModel { - match col_model { - ColModel::Categorical(col) => { - DatalessColModel::Categorical(col.into()) - } - ColModel::Continuous(col) => { - DatalessColModel::Continuous(col.into()) - } - ColModel::Count(col) => DatalessColModel::Count(col.into()), - ColModel::MissingNotAtRandom(mnar) => { - DatalessColModel::MissingNotAtRandom( - DatalessMissingNotAtRandom { - fx: Box::new((*mnar.fx).into()), - missing: mnar.present.into(), - }, - ) - } +impl From for PriorProcess { + fn from(asgn: v1::Assignment) -> Self { + Self { + asgn: Assignment { + asgn: asgn.asgn, + counts: asgn.counts, + n_cats: asgn.n_cats, + }, + process: Process::Dirichlet(lace_stats::prior_process::Dirichlet { + alpha: asgn.alpha, + alpha_prior: asgn.prior, + }), } } } -impl From for ColModel { - fn from(col_model: DatalessColModel) -> Self { - match col_model { - DatalessColModel::Continuous(cm) => { - let empty_col: EmptyColumn<_, _, _, _> = cm.into(); - Self::Continuous(empty_col.0) - } - DatalessColModel::Count(cm) => { - let empty_col: EmptyColumn<_, _, _, _> = cm.into(); - Self::Count(empty_col.0) - } - DatalessColModel::Categorical(cm) => { - let empty_col: EmptyColumn<_, _, _, _> = cm.into(); - Self::Categorical(empty_col.0) - } - _ => unimplemented!(), +impl From for DatalessView { + fn from(view: v1::DatalessView) -> Self { + Self { + ftrs: view.ftrs, + prior_process: view.asgn.into(), + weights: view.weights, } } } -#[derive(Serialize, Deserialize, Debug)] -#[serde(deny_unknown_fields)] -pub struct DatalessMissingNotAtRandom { - fx: Box, - missing: DatalessColumn, -} - -#[derive(Serialize, Deserialize, Debug)] -#[serde(deny_unknown_fields)] 
-pub struct DatalessColumn -where - X: LaceDatum, - Fx: LaceLikelihood, - Pr: LacePrior, - H: Serialize + DeserializeOwned, - MixtureType: From>, - Fx::Stat: LaceStat, - Pr::LnMCache: Clone + std::fmt::Debug, - Pr::LnPpCache: Send + Sync + Clone + std::fmt::Debug, -{ - pub id: usize, - #[serde(bound(deserialize = "X: serde::de::DeserializeOwned"))] - pub components: Vec>, - #[serde(bound(deserialize = "Pr: serde::de::DeserializeOwned"))] - pub prior: Pr, - #[serde(bound(deserialize = "H: serde::de::DeserializeOwned"))] - pub hyper: H, - #[serde(default)] - pub ignore_hyper: bool, +impl From for DatalessState { + fn from(mut state: v1::DatalessState) -> Self { + Self { + views: state.views.drain(..).map(|view| view.into()).collect(), + prior_process: state.asgn.into(), + weights: state.weights, + score: StateScoreComponents { + ln_likelihood: state.loglike, + ln_prior: state.log_prior, + ln_state_prior_process: state.log_state_alpha_prior, + ln_view_prior_process: state.log_view_alpha_prior, + }, + } + } } -macro_rules! 
col2dataless { - ($x:ty, $fx:ty, $pr:ty, $h:ty) => { - impl From> - for DatalessColumn<$x, $fx, $pr, $h> - { - fn from(col: Column<$x, $fx, $pr, $h>) -> Self { - DatalessColumn { - id: col.id, - components: col.components, - prior: col.prior, - hyper: col.hyper, - ignore_hyper: col.ignore_hyper, - } - } +impl From for DatalessStateAndDiagnostics { + fn from(state_and_diag: v1::DatalessStateAndDiagnostics) -> Self { + Self { + state: state_and_diag.state.into(), + diagnostics: state_and_diag.diagnostics, } - }; + } } -col2dataless!(f64, Gaussian, NormalInvChiSquared, NixHyper); -col2dataless!(u8, Categorical, SymmetricDirichlet, CsdHyper); -col2dataless!(u32, Poisson, Gamma, PgHyper); -col2dataless!(bool, Bernoulli, Beta, ()); - -struct EmptyColumn(Column) -where - X: LaceDatum, - Fx: LaceLikelihood, - Fx::Stat: LaceStat, - Pr: LacePrior, - H: Serialize + DeserializeOwned, - Pr::LnMCache: Clone + std::fmt::Debug, - Pr::LnPpCache: Send + Sync + Clone + std::fmt::Debug, - MixtureType: From>; +impl From for Codebook { + fn from(codebook: v1::Codebook) -> Self { + Self(lace_codebook::Codebook { + table_name: codebook.table_name, + state_prior_process: codebook.state_alpha_prior.map( + |alpha_prior| lace_codebook::PriorProcess::Dirichlet { + alpha_prior, + }, + ), + view_prior_process: codebook.view_alpha_prior.map(|alpha_prior| { + lace_codebook::PriorProcess::Dirichlet { alpha_prior } + }), + col_metadata: codebook.col_metadata, + comments: codebook.comments, + row_names: codebook.row_names, + }) + } +} -macro_rules! 
dataless2col { - ($x:ty, $fx:ty, $pr:ty, $h:ty) => { - impl From> - for EmptyColumn<$x, $fx, $pr, $h> - { - fn from( - col_dl: DatalessColumn<$x, $fx, $pr, $h>, - ) -> EmptyColumn<$x, $fx, $pr, $h> { - EmptyColumn(Column { - id: col_dl.id, - components: col_dl.components, - prior: col_dl.prior, - hyper: col_dl.hyper, - data: SparseContainer::default(), - ln_m_cache: OnceLock::new(), - ignore_hyper: col_dl.ignore_hyper, - }) - } +impl From for Metadata { + fn from(mut metadata: v1::Metadata) -> Self { + Self { + states: metadata + .states + .drain(..) + .map(|state| state.into()) + .collect(), + state_ids: metadata.state_ids, + codebook: metadata.codebook.into(), + data: metadata.data, + rng: metadata.rng, } - }; + } } -dataless2col!(f64, Gaussian, NormalInvChiSquared, NixHyper); -dataless2col!(u8, Categorical, SymmetricDirichlet, CsdHyper); -dataless2col!(u32, Poisson, Gamma, PgHyper); -dataless2col!(bool, Bernoulli, Beta, ()); - impl_metadata_version!(Metadata, METADATA_VERSION); impl_metadata_version!(Codebook, METADATA_VERSION); -impl_metadata_version!(DatalessColModel, METADATA_VERSION); impl_metadata_version!(DatalessView, METADATA_VERSION); impl_metadata_version!(DatalessState, METADATA_VERSION); -impl_metadata_version!(DataStore, METADATA_VERSION); // Create the loaders module for latest crate::loaders!( DatalessStateAndDiagnostics, - DataStore, + v1::DataStore, Codebook, rand_xoshiro::Xoshiro256Plus ); diff --git a/lace/lace_metadata/src/lib.rs b/lace/lace_metadata/src/lib.rs index bbe00195..f1b0d0da 100644 --- a/lace/lace_metadata/src/lib.rs +++ b/lace/lace_metadata/src/lib.rs @@ -13,6 +13,7 @@ mod config; mod error; pub mod latest; mod utils; +pub mod versions; pub use utils::{deserialize_file, save_state, serialize_obj}; @@ -58,7 +59,7 @@ macro_rules! 
to_from_newtype { } /// creates a bunch of helper functions in a `load` module that load the -/// metadata components and create and `Meatadata` object of the appropriate +/// metadata components and create and `Metadata` object of the appropriate /// version. #[macro_export] macro_rules! loaders { @@ -132,7 +133,7 @@ macro_rules! loaders { load(codebook_path, file_config.serialized_type) } - pub(crate) fn load_meatadata>( + pub(crate) fn load_metadata>( path: P, file_config: &$crate::config::FileConfig, ) -> Result { @@ -202,7 +203,9 @@ pub fn load_metadata>( let md_version = file_config.metadata_version; match md_version { - 0 => crate::latest::load::load_meatadata(path, &file_config), + 0 => crate::versions::v1::load::load_metadata(path, &file_config) + .map(|metadata| metadata.into()), + 1 => crate::latest::load::load_metadata(path, &file_config), requested => Err(Error::UnsupportedMetadataVersion { requested, max_supported: crate::latest::METADATA_VERSION, diff --git a/lace/lace_metadata/src/utils.rs b/lace/lace_metadata/src/utils.rs index 18ae56a7..00b1439b 100644 --- a/lace/lace_metadata/src/utils.rs +++ b/lace/lace_metadata/src/utils.rs @@ -9,11 +9,12 @@ use log::info; use rand_xoshiro::Xoshiro256Plus; use serde::{Deserialize, Serialize}; +use crate::latest::Codebook; use crate::latest::DatalessStateAndDiagnostics; -use crate::latest::{Codebook, DataStore}; +use crate::versions::v1::DataStore; use crate::{Error, FileConfig, SerializedType}; -fn extenson_from_path>(path: &P) -> Result<&str, Error> { +fn extension_from_path>(path: &P) -> Result<&str, Error> { path.as_ref() .extension() .and_then(|s| s.to_str()) @@ -28,7 +29,7 @@ fn extenson_from_path>(path: &P) -> Result<&str, Error> { fn serialized_type_from_path>( path: &P, ) -> Result { - let ext = extenson_from_path(path)?; + let ext = extension_from_path(path)?; SerializedType::from_str(ext) } diff --git a/lace/lace_metadata/src/versions/mod.rs b/lace/lace_metadata/src/versions/mod.rs new file mode 
100644 index 00000000..a3a6d96c --- /dev/null +++ b/lace/lace_metadata/src/versions/mod.rs @@ -0,0 +1 @@ +pub mod v1; diff --git a/lace/lace_metadata/src/versions/v1.rs b/lace/lace_metadata/src/versions/v1.rs new file mode 100644 index 00000000..2ece9d22 --- /dev/null +++ b/lace/lace_metadata/src/versions/v1.rs @@ -0,0 +1,365 @@ +use std::collections::BTreeMap; + +use lace_cc::component::ConjugateComponent; +use lace_cc::feature::{ColModel, Column}; +use lace_cc::state::StateDiagnostics; +use lace_cc::traits::{LaceDatum, LaceLikelihood, LacePrior, LaceStat}; +use lace_data::{FeatureData, SparseContainer}; +use lace_stats::prior::csd::CsdHyper; +use lace_stats::prior::nix::NixHyper; +use lace_stats::prior::pg::PgHyper; +use lace_stats::rv::dist::{ + Bernoulli, Beta, Categorical, Gamma, Gaussian, Mixture, + NormalInvChiSquared, Poisson, SymmetricDirichlet, +}; +use lace_stats::MixtureType; + +use rand_xoshiro::Xoshiro256Plus; +use serde::{de::DeserializeOwned, Deserialize, Serialize}; +use std::sync::OnceLock; + +use crate::{impl_metadata_version, MetadataVersion}; + +pub const METADATA_VERSION: i32 = 0; + +// #[derive(Serialize, Deserialize, Debug)] +// #[serde(deny_unknown_fields)] +// pub struct Codebook(pub lace_codebook::Codebook); + +#[derive(Serialize, Deserialize, Debug)] +#[serde(deny_unknown_fields)] +pub struct Assignment { + pub alpha: f64, + pub asgn: Vec, + pub counts: Vec, + pub n_cats: usize, + pub prior: Gamma, +} + +/// Codebook object for storing information about the dataset +#[derive(Serialize, Deserialize, Debug)] +#[serde(deny_unknown_fields)] +pub struct Codebook { + pub table_name: String, + pub state_alpha_prior: Option, + pub view_alpha_prior: Option, + pub col_metadata: lace_codebook::ColMetadataList, + pub comments: Option, + pub row_names: lace_codebook::RowNameList, +} + +// to_from_newtype!(lace_codebook::Codebook, Codebook); + +#[derive(Debug, Serialize, Deserialize)] +pub struct Metadata { + pub states: Vec, + pub state_ids: Vec, + 
pub codebook: Codebook, + pub data: Option, + pub rng: Option, +} + +#[derive(Serialize, Deserialize, Debug)] +#[serde(deny_unknown_fields)] +pub struct DataStore(BTreeMap); + +impl From for DataStore { + fn from(data: lace_data::DataStore) -> Self { + Self(data.0) + } +} + +impl From for lace_data::DataStore { + fn from(data: DataStore) -> Self { + Self(data.0) + } +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct DatalessStateAndDiagnostics { + pub state: DatalessState, + pub diagnostics: StateDiagnostics, +} + +#[derive(Serialize, Deserialize, Debug)] +#[serde(deny_unknown_fields)] +pub struct DatalessState { + pub views: Vec, + pub asgn: Assignment, + pub weights: Vec, + pub view_alpha_prior: Gamma, + pub loglike: f64, + #[serde(default)] + pub log_prior: f64, + #[serde(default)] + pub log_view_alpha_prior: f64, + #[serde(default)] + pub log_state_alpha_prior: f64, +} + +// /// Marks a state as having no data in its columns +// pub struct EmptyState(pub State); + +// impl From for DatalessStateAndDiagnostics { +// fn from(mut state: lace_cc::state::State) -> Self { +// Self { +// state: DatalessState { +// views: state.views.drain(..).map(|view| view.into()).collect(), +// asgn: state.asgn, +// weights: state.weights, +// view_alpha_prior: state.view_alpha_prior, +// loglike: state.loglike, +// log_prior: state.log_prior, +// log_view_alpha_prior: state.log_view_alpha_prior, +// log_state_alpha_prior: state.log_state_alpha_prior, +// }, +// diagnostics: state.diagnostics, +// } +// } +// } + +// impl From for EmptyState { +// fn from(mut dl_state: DatalessStateAndDiagnostics) -> EmptyState { +// let views = dl_state +// .state +// .views +// .drain(..) +// .map(|mut dl_view| { +// let mut ftr_ids: Vec = +// dl_view.ftrs.keys().copied().collect(); + +// let ftrs: BTreeMap = ftr_ids +// .drain(..) 
+// .map(|id| { +// let dl_ftr = dl_view.ftrs.remove(&id).unwrap(); +// let cm: ColModel = match dl_ftr { +// DatalessColModel::Continuous(cm) => { +// let ecm: EmptyColumn<_, _, _, _> = cm.into(); +// ColModel::Continuous(ecm.0) +// } +// DatalessColModel::Categorical(cm) => { +// let ecm: EmptyColumn<_, _, _, _> = cm.into(); +// ColModel::Categorical(ecm.0) +// } +// DatalessColModel::Count(cm) => { +// let ecm: EmptyColumn<_, _, _, _> = cm.into(); +// ColModel::Count(ecm.0) +// } +// DatalessColModel::MissingNotAtRandom(mnar) => { +// let fx: ColModel = (*mnar.fx).into(); +// let missing: EmptyColumn<_, _, _, _> = +// mnar.missing.into(); +// ColModel::MissingNotAtRandom( +// MissingNotAtRandom { +// fx: Box::new(fx), +// present: missing.0, +// }, +// ) +// } +// }; +// (id, cm) +// }) +// .collect(); + +// View { +// asgn: dl_view.asgn, +// weights: dl_view.weights, +// ftrs, +// } +// }) +// .collect(); + +// EmptyState(State { +// views, +// asgn: dl_state.state.asgn, +// weights: dl_state.state.weights, +// view_alpha_prior: dl_state.state.view_alpha_prior, +// loglike: dl_state.state.loglike, +// log_prior: dl_state.state.log_prior, +// log_view_alpha_prior: dl_state.state.log_view_alpha_prior, +// log_state_alpha_prior: dl_state.state.log_state_alpha_prior, +// diagnostics: dl_state.diagnostics, +// }) +// } +// } + +#[derive(Serialize, Deserialize, Debug)] +#[serde(deny_unknown_fields)] +pub struct DatalessView { + pub ftrs: BTreeMap, + pub asgn: Assignment, + pub weights: Vec, +} + +// impl From for DatalessView { +// fn from(mut view: View) -> DatalessView { +// DatalessView { +// ftrs: { +// let keys: Vec = view.ftrs.keys().cloned().collect(); +// keys.iter() +// .map(|k| (*k, view.ftrs.remove(k).unwrap().into())) +// .collect() +// }, +// asgn: view.asgn, +// weights: view.weights, +// } +// } +// } + +#[derive(Serialize, Deserialize, Debug)] +#[serde(deny_unknown_fields)] +pub enum DatalessColModel { + Continuous(DatalessColumn), + 
Categorical(DatalessColumn), + Count(DatalessColumn), + MissingNotAtRandom(DatalessMissingNotAtRandom), +} + +impl From for DatalessColModel { + fn from(col_model: ColModel) -> DatalessColModel { + match col_model { + ColModel::Categorical(col) => { + DatalessColModel::Categorical(col.into()) + } + ColModel::Continuous(col) => { + DatalessColModel::Continuous(col.into()) + } + ColModel::Count(col) => DatalessColModel::Count(col.into()), + ColModel::MissingNotAtRandom(mnar) => { + DatalessColModel::MissingNotAtRandom( + DatalessMissingNotAtRandom { + fx: Box::new((*mnar.fx).into()), + missing: mnar.present.into(), + }, + ) + } + } + } +} + +impl From for ColModel { + fn from(col_model: DatalessColModel) -> Self { + match col_model { + DatalessColModel::Continuous(cm) => { + let empty_col: EmptyColumn<_, _, _, _> = cm.into(); + Self::Continuous(empty_col.0) + } + DatalessColModel::Count(cm) => { + let empty_col: EmptyColumn<_, _, _, _> = cm.into(); + Self::Count(empty_col.0) + } + DatalessColModel::Categorical(cm) => { + let empty_col: EmptyColumn<_, _, _, _> = cm.into(); + Self::Categorical(empty_col.0) + } + _ => unimplemented!(), + } + } +} + +#[derive(Serialize, Deserialize, Debug)] +#[serde(deny_unknown_fields)] +pub struct DatalessMissingNotAtRandom { + pub fx: Box, + pub missing: DatalessColumn, +} + +#[derive(Serialize, Deserialize, Debug)] +#[serde(deny_unknown_fields)] +pub struct DatalessColumn +where + X: LaceDatum, + Fx: LaceLikelihood, + Pr: LacePrior, + H: Serialize + DeserializeOwned, + MixtureType: From>, + Fx::Stat: LaceStat, + Pr::LnMCache: Clone + std::fmt::Debug, + Pr::LnPpCache: Send + Sync + Clone + std::fmt::Debug, +{ + pub id: usize, + #[serde(bound(deserialize = "X: serde::de::DeserializeOwned"))] + pub components: Vec>, + #[serde(bound(deserialize = "Pr: serde::de::DeserializeOwned"))] + pub prior: Pr, + #[serde(bound(deserialize = "H: serde::de::DeserializeOwned"))] + pub hyper: H, + #[serde(default)] + pub ignore_hyper: bool, +} + 
+macro_rules! col2dataless { + ($x:ty, $fx:ty, $pr:ty, $h:ty) => { + impl From> + for DatalessColumn<$x, $fx, $pr, $h> + { + fn from(col: Column<$x, $fx, $pr, $h>) -> Self { + DatalessColumn { + id: col.id, + components: col.components, + prior: col.prior, + hyper: col.hyper, + ignore_hyper: col.ignore_hyper, + } + } + } + }; +} + +col2dataless!(f64, Gaussian, NormalInvChiSquared, NixHyper); +col2dataless!(u8, Categorical, SymmetricDirichlet, CsdHyper); +col2dataless!(u32, Poisson, Gamma, PgHyper); +col2dataless!(bool, Bernoulli, Beta, ()); + +pub struct EmptyColumn(pub Column) +where + X: LaceDatum, + Fx: LaceLikelihood, + Fx::Stat: LaceStat, + Pr: LacePrior, + H: Serialize + DeserializeOwned, + Pr::LnMCache: Clone + std::fmt::Debug, + Pr::LnPpCache: Send + Sync + Clone + std::fmt::Debug, + MixtureType: From>; + +macro_rules! dataless2col { + ($x:ty, $fx:ty, $pr:ty, $h:ty) => { + impl From> + for EmptyColumn<$x, $fx, $pr, $h> + { + fn from( + col_dl: DatalessColumn<$x, $fx, $pr, $h>, + ) -> EmptyColumn<$x, $fx, $pr, $h> { + EmptyColumn(Column { + id: col_dl.id, + components: col_dl.components, + prior: col_dl.prior, + hyper: col_dl.hyper, + data: SparseContainer::default(), + ln_m_cache: OnceLock::new(), + ignore_hyper: col_dl.ignore_hyper, + }) + } + } + }; +} + +dataless2col!(f64, Gaussian, NormalInvChiSquared, NixHyper); +dataless2col!(u8, Categorical, SymmetricDirichlet, CsdHyper); +dataless2col!(u32, Poisson, Gamma, PgHyper); +dataless2col!(bool, Bernoulli, Beta, ()); + +impl_metadata_version!(Metadata, METADATA_VERSION); +impl_metadata_version!(Codebook, METADATA_VERSION); +impl_metadata_version!(DatalessColModel, METADATA_VERSION); +impl_metadata_version!(DatalessView, METADATA_VERSION); +impl_metadata_version!(DatalessState, METADATA_VERSION); +impl_metadata_version!(DataStore, METADATA_VERSION); + +// Create the loaders module for latest +crate::loaders!( + DatalessStateAndDiagnostics, + DataStore, + Codebook, + rand_xoshiro::Xoshiro256Plus +); diff --git 
a/lace/lace_metadata/tests/convert.rs b/lace/lace_metadata/tests/convert.rs new file mode 100644 index 00000000..73bb3f61 --- /dev/null +++ b/lace/lace_metadata/tests/convert.rs @@ -0,0 +1,21 @@ +use std::path::PathBuf; + +#[test] +fn read_v0() { + let path = PathBuf::from("resources") + .join("test") + .join("metadata") + .join("v0") + .join("metadata.lace"); + let _metadata = lace_metadata::load_metadata(path); +} + +#[test] +fn read_v1() { + let path = PathBuf::from("resources") + .join("test") + .join("metadata") + .join("v1") + .join("metadata.lace"); + let _metadata = lace_metadata::load_metadata(path); +} diff --git a/lace/lace_stats/Cargo.toml b/lace/lace_stats/Cargo.toml index c402aba6..d5b441e0 100644 --- a/lace/lace_stats/Cargo.toml +++ b/lace/lace_stats/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "lace_stats" -version = "0.3.0" +version = "0.4.0" rust-version = "1.62.0" authors = ["Promised AI"] edition = "2021" diff --git a/lace/lace_cc/src/assignment.rs b/lace/lace_stats/src/assignment.rs similarity index 61% rename from lace/lace_cc/src/assignment.rs rename to lace/lace_stats/src/assignment.rs index d4cfdc1c..01a8e4ab 100644 --- a/lace/lace_cc/src/assignment.rs +++ b/lace/lace_stats/src/assignment.rs @@ -1,15 +1,9 @@ //! Data structures for assignments of items to components (partitions) -use lace_stats::mh::mh_prior; -use lace_stats::rv::dist::Gamma; -use lace_stats::rv::traits::Rv; -use rand::SeedableRng; -use rand_xoshiro::Xoshiro256Plus; use serde::{Deserialize, Serialize}; use thiserror::Error; -use crate::misc::crp_draw; - /// Validates assignments if the `LACE_NOCHECK` is not set to `"1"`. +#[macro_export] macro_rules! validate_assignment { ($asgn:expr) => {{ let validate_asgn: bool = match option_env!("LACE_NOCHECK") { @@ -28,8 +22,6 @@ macro_rules! 
validate_assignment { #[allow(dead_code)] #[derive(Serialize, Deserialize, PartialEq, Debug, Clone)] pub struct Assignment { - /// The `Crp` discount parameter - pub alpha: f64, /// The assignment vector. `asgn[i]` is the partition index of the /// ith datum. pub asgn: Vec, @@ -37,8 +29,6 @@ pub struct Assignment { pub counts: Vec, /// The number of partitions/categories pub n_cats: usize, - /// The prior on `alpha` - pub prior: Gamma, } /// The possible ways an assignment can go wrong with incorrect bookkeeping @@ -208,166 +198,15 @@ impl AssignmentDiagnostics { } } -/// Constructs `Assignment`s -#[derive(Clone, Debug)] -pub struct AssignmentBuilder { - n: usize, - asgn: Option>, - alpha: Option, - prior: Option, - seed: Option, -} - -#[derive(Debug, Error, PartialEq)] -pub enum BuildAssignmentError { - #[error("alpha is zero")] - AlphaIsZero, - #[error("non-finite alpha: {alpha}")] - AlphaNotFinite { alpha: f64 }, - #[error("assignment vector is empty")] - EmptyAssignmentVec, - #[error("there are {n_cats} categories but {n} data")] - NLessThanNCats { n: usize, n_cats: usize }, - #[error("invalid assignment: {0}")] - AssignmentError(#[from] AssignmentError), -} - -impl AssignmentBuilder { - /// Create a builder for `n`-length assignments - /// - /// # Arguments - /// - n: the number of data/entries in the assignment - pub fn new(n: usize) -> Self { - AssignmentBuilder { - n, - asgn: None, - prior: None, - alpha: None, - seed: None, - } - } - - /// Initialize the builder from an assignment vector - /// - /// # Note: - /// The validity of `asgn` will not be verified until `build` is called. 
- pub fn from_vec(asgn: Vec) -> Self { - AssignmentBuilder { - n: asgn.len(), - asgn: Some(asgn), - prior: None, - alpha: None, - seed: None, - } - } - - /// Add a prior on the `Crp` `alpha` parameter - #[must_use] - pub fn with_prior(mut self, prior: Gamma) -> Self { - self.prior = Some(prior); - self - } - - /// Use the Geweke `Crp` `alpha` prior - #[must_use] - pub fn with_geweke_prior(mut self) -> Self { - self.prior = Some(lace_consts::geweke_alpha_prior()); - self - } - - /// Set the `Crp` `alpha` to a specific value - #[must_use] - pub fn with_alpha(mut self, alpha: f64) -> Self { - self.alpha = Some(alpha); - self - } - - /// Set the RNG seed - #[must_use] - pub fn with_seed(mut self, seed: u64) -> Self { - self.seed = Some(seed); - self - } - - /// Set the RNG seed from another RNG - #[must_use] - pub fn seed_from_rng(mut self, rng: &mut R) -> Self { - self.seed = Some(rng.next_u64()); - self - } - - /// Use a *flat* assignment with one partition - #[must_use] - pub fn flat(mut self) -> Self { - self.asgn = Some(vec![0; self.n]); - self - } - - /// Use an assignment with `n_cats`, evenly populated partitions/categories - pub fn with_n_cats( - mut self, - n_cats: usize, - ) -> Result { - if n_cats > self.n { - Err(BuildAssignmentError::NLessThanNCats { n: self.n, n_cats }) - } else { - let asgn: Vec = (0..self.n).map(|i| i % n_cats).collect(); - self.asgn = Some(asgn); - Ok(self) - } - } - - /// Build the assignment and consume the builder - pub fn build(self) -> Result { - let prior = self.prior.unwrap_or_else(lace_consts::general_alpha_prior); - - let mut rng_opt = if self.alpha.is_none() || self.asgn.is_none() { - let rng = match self.seed { - Some(seed) => Xoshiro256Plus::seed_from_u64(seed), - None => Xoshiro256Plus::from_entropy(), - }; - Some(rng) - } else { - None - }; - - let alpha = match self.alpha { - Some(alpha) => alpha, - None => prior.draw(&mut rng_opt.as_mut().unwrap()), - }; - - let n = self.n; - let asgn = self.asgn.unwrap_or_else(|| { - 
crp_draw(n, alpha, &mut rng_opt.as_mut().unwrap()).asgn - }); - - let n_cats: usize = asgn.iter().max().map(|&m| m + 1).unwrap_or(0); - let mut counts: Vec = vec![0; n_cats]; - for z in &asgn { - counts[*z] += 1; - } - - let asgn_out = Assignment { - alpha, - asgn, - counts, - n_cats, - prior, - }; - - if validate_assignment!(asgn_out) { - Ok(asgn_out) - } else { - asgn_out - .validate() - .emit_error() - .map_err(BuildAssignmentError::AssignmentError) - .map(|_| asgn_out) +impl Assignment { + pub fn empty() -> Self { + Self { + asgn: Vec::new(), + counts: Vec::new(), + n_cats: 0, } } -} -impl Assignment { /// Replace the assignment vector pub fn set_asgn( &mut self, @@ -407,52 +246,6 @@ impl Assignment { self.len() == 0 } - /// Returns the Dirichlet posterior - /// - /// # Arguments - /// - /// - append_alpha: if `true` append `alpha` to the end of the vector. This - /// is used primarily for the `FiniteCpu` assignment kernel. - /// - /// # Example - /// - /// ```rust - /// # use lace_cc::assignment::AssignmentBuilder; - /// let assignment = AssignmentBuilder::from_vec(vec![0, 0, 1, 2]) - /// .with_alpha(0.5) - /// .build() - /// .unwrap(); - /// - /// assert_eq!(assignment.asgn, vec![0, 0, 1, 2]); - /// assert_eq!(assignment.counts, vec![2, 1, 1]); - /// assert_eq!(assignment.dirvec(false), vec![2.0, 1.0, 1.0]); - /// assert_eq!(assignment.dirvec(true), vec![2.0, 1.0, 1.0, 0.5]); - /// ``` - pub fn dirvec(&self, append_alpha: bool) -> Vec { - let mut dv: Vec = self.counts.iter().map(|&x| x as f64).collect(); - if append_alpha { - dv.push(self.alpha); - } - dv - } - - /// Returns the log of the Dirichlet posterior - /// - /// # Arguments - /// - /// - append_alpha: if `true` append `alpha` to the end of the vector. This - /// is used primarily for the `FiniteCpu` assignment kernel. 
- pub fn log_dirvec(&self, append_alpha: bool) -> Vec { - let mut dv: Vec = - self.counts.iter().map(|&x| (x as f64).ln()).collect(); - - if append_alpha { - dv.push(self.alpha.ln()); - } - - dv - } - /// Mark the entry at ix as unassigned. Will remove the entry's contribution /// to `n_cats` and `counts`, and will mark `asgn[ix]` with the unassigned /// designator.. @@ -503,14 +296,15 @@ impl Assignment { /// Append a new, unassigned entry to th end of the assignment /// - /// # Eample + /// # Example /// /// ``` - /// # use lace_cc::assignment::AssignmentBuilder; + /// # use lace_stats::prior_process::Builder; /// - /// let mut assignment = AssignmentBuilder::from_vec(vec![0, 0, 1]) + /// let mut assignment = Builder::from_vec(vec![0, 0, 1]) /// .build() - /// .unwrap(); + /// .unwrap() + /// .asgn; /// /// assert_eq!(assignment.asgn, vec![0, 0, 1]); /// @@ -522,50 +316,6 @@ impl Assignment { self.asgn.push(usize::max_value()) } - /// Returns the proportion of data assigned to each partition/category - /// - /// # Example - /// - /// ```rust - /// # use lace_cc::assignment::AssignmentBuilder; - /// let mut rng = rand::thread_rng(); - /// let assignment = AssignmentBuilder::from_vec(vec![0, 0, 1, 2]) - /// .build() - /// .unwrap(); - /// - /// assert_eq!(assignment.asgn, vec![0, 0, 1, 2]); - /// assert_eq!(assignment.counts, vec![2, 1, 1]); - /// assert_eq!(assignment.weights(), vec![0.5, 0.25, 0.25]); - /// ``` - pub fn weights(&self) -> Vec { - let z: f64 = self.len() as f64; - self.dirvec(false).iter().map(|&w| w / z).collect() - } - - /// The log of the weights - pub fn log_weights(&self) -> Vec { - self.weights().iter().map(|w| w.ln()).collect() - } - - /// Posterior update of `alpha` given the prior and the current assignment - /// vector - pub fn update_alpha( - &mut self, - n_iter: usize, - rng: &mut R, - ) -> f64 { - // TODO: Should we use a different method to draw CRP alpha that can - // extend outside of the bulk of the prior's mass? 
- let cts = &self.counts; - let n: usize = self.len(); - let loglike = |alpha: &f64| lcrp(n, cts, *alpha); - let prior_ref = &self.prior; - let prior_draw = |rng: &mut R| prior_ref.draw(rng); - let mh_result = mh_prior(self.alpha, loglike, prior_draw, n_iter, rng); - self.alpha = mh_result.x; - mh_result.score_x - } - /// Validates the assignment pub fn validate(&self) -> AssignmentDiagnostics { AssignmentDiagnostics::new(self) @@ -582,20 +332,45 @@ pub fn lcrp(n: usize, cts: &[usize], alpha: f64) -> f64 { gsum + k.mul_add(alpha.ln(), cpnt_2) } +fn ln_py_bracket(x: f64, m: usize, alpha: f64) -> f64 { + if m == 0 { + return 0.0; + } + (1..=m) + .map(|m_i| (m_i as f64 - 1.0).mul_add(alpha, x).ln()) + .sum::() +} + +/// Formula from: +/// Pitman, Jim. "Exchangeable and partially exchangeable random partitions." +/// Probability theory and related fields 102.2 (1995): 145-158. +/// https://www.stat.berkeley.edu/~aldous/206-Exch/Papers/pitman95a.pdf +pub fn lpyp(cts: &[usize], alpha: f64, d: f64) -> f64 { + let k = cts.len(); + let n = cts.iter().copied().sum::(); + let term_a = ln_py_bracket(alpha + d, k - 1, d); + let term_b = ln_py_bracket(alpha + 1.0, n - 1, 1.0); + let term_c = cts + .iter() + .map(|&ct_i| ln_py_bracket(1.0 - d, ct_i - 1, 1.0)) + .sum::(); + term_a - term_b + term_c +} + #[cfg(test)] mod tests { use super::*; + use crate::prior_process::Builder as AssignmentBuilder; + use crate::prior_process::{Dirichlet, Process}; + use crate::rv::dist::Gamma; use approx::*; - use lace_stats::rv::dist::Gamma; #[test] fn zero_count_fails_validation() { let asgn = Assignment { - alpha: 1.0, asgn: vec![0, 0, 0, 0], counts: vec![0, 4], n_cats: 1, - prior: Gamma::new(1.0, 1.0).unwrap(), }; let diagnostic = asgn.validate(); @@ -614,11 +389,9 @@ mod tests { #[test] fn bad_counts_fails_validation() { let asgn = Assignment { - alpha: 1.0, asgn: vec![1, 1, 0, 0], counts: vec![2, 3], n_cats: 2, - prior: Gamma::new(1.0, 1.0).unwrap(), }; let diagnostic = asgn.validate(); @@ 
-637,11 +410,9 @@ mod tests { #[test] fn low_n_cats_fails_validation() { let asgn = Assignment { - alpha: 1.0, asgn: vec![1, 1, 0, 0], counts: vec![2, 2], n_cats: 1, - prior: Gamma::new(1.0, 1.0).unwrap(), }; let diagnostic = asgn.validate(); @@ -660,11 +431,9 @@ mod tests { #[test] fn high_n_cats_fails_validation() { let asgn = Assignment { - alpha: 1.0, asgn: vec![1, 1, 0, 0], counts: vec![2, 2], n_cats: 3, - prior: Gamma::new(1.0, 1.0).unwrap(), }; let diagnostic = asgn.validate(); @@ -683,11 +452,9 @@ mod tests { #[test] fn no_zero_cat_fails_validation() { let asgn = Assignment { - alpha: 1.0, asgn: vec![1, 1, 2, 2], counts: vec![2, 2], n_cats: 2, - prior: Gamma::new(1.0, 1.0).unwrap(), }; let diagnostic = asgn.validate(); @@ -709,29 +476,34 @@ mod tests { // do the test 100 times because it's random for _ in 0..100 { - let asgn = AssignmentBuilder::new(n).build().unwrap(); + let asgn = AssignmentBuilder::new(n).build().unwrap().asgn; assert!(asgn.validate().is_valid()); } } #[test] - fn from_prior_should_have_valid_alpha_and_proper_length() { + fn from_prior_process_should_have_valid_alpha_and_proper_length() { let n: usize = 50; + let mut rng = rand::thread_rng(); + let process = Process::Dirichlet(Dirichlet::from_prior( + Gamma::new(1.0, 1.0).unwrap(), + &mut rng, + )); let asgn = AssignmentBuilder::new(n) - .with_prior(Gamma::new(1.0, 1.0).unwrap()) + .with_process(process) .build() - .unwrap(); + .unwrap() + .asgn; assert!(!asgn.is_empty()); assert_eq!(asgn.len(), n); assert!(asgn.validate().is_valid()); - assert!(asgn.alpha > 0.0); } #[test] fn flat_partition_validation() { let n: usize = 50; - let asgn = AssignmentBuilder::new(n).flat().build().unwrap(); + let asgn = AssignmentBuilder::new(n).flat().build().unwrap().asgn; assert_eq!(asgn.n_cats, 1); assert_eq!(asgn.counts.len(), 1); @@ -742,7 +514,7 @@ mod tests { #[test] fn from_vec() { let z = vec![0, 1, 2, 0, 1, 0]; - let asgn = AssignmentBuilder::from_vec(z).build().unwrap(); + let asgn = 
AssignmentBuilder::from_vec(z).build().unwrap().asgn; assert_eq!(asgn.n_cats, 3); assert_eq!(asgn.counts[0], 3); assert_eq!(asgn.counts[1], 2); @@ -755,7 +527,8 @@ mod tests { .with_n_cats(5) .expect("Whoops!") .build() - .unwrap(); + .unwrap() + .asgn; assert!(asgn.validate().is_valid()); assert_eq!(asgn.n_cats, 5); assert_eq!(asgn.counts[0], 20); @@ -771,7 +544,8 @@ mod tests { .with_n_cats(5) .expect("Whoops!") .build() - .unwrap(); + .unwrap() + .asgn; assert!(asgn.validate().is_valid()); assert_eq!(asgn.n_cats, 5); assert_eq!(asgn.counts[0], 21); @@ -781,79 +555,6 @@ mod tests { assert_eq!(asgn.counts[4], 20); } - #[test] - fn dirvec_with_alpha_1() { - let asgn = AssignmentBuilder::from_vec(vec![0, 1, 2, 0, 1, 0]) - .with_alpha(1.0) - .build() - .unwrap(); - let dv = asgn.dirvec(false); - - assert_eq!(dv.len(), 3); - assert_relative_eq!(dv[0], 3.0, epsilon = 10E-10); - assert_relative_eq!(dv[1], 2.0, epsilon = 10E-10); - assert_relative_eq!(dv[2], 1.0, epsilon = 10E-10); - } - - #[test] - fn dirvec_with_alpha_15() { - let asgn = AssignmentBuilder::from_vec(vec![0, 1, 2, 0, 1, 0]) - .with_alpha(1.5) - .build() - .unwrap(); - let dv = asgn.dirvec(true); - - assert_eq!(dv.len(), 4); - assert_relative_eq!(dv[0], 3.0, epsilon = 10E-10); - assert_relative_eq!(dv[1], 2.0, epsilon = 10E-10); - assert_relative_eq!(dv[2], 1.0, epsilon = 10E-10); - assert_relative_eq!(dv[3], 1.5, epsilon = 10E-10); - } - - #[test] - fn log_dirvec_with_alpha_1() { - let asgn = AssignmentBuilder::from_vec(vec![0, 1, 2, 0, 1, 0]) - .with_alpha(1.0) - .build() - .unwrap(); - let ldv = asgn.log_dirvec(false); - - assert_eq!(ldv.len(), 3); - assert_relative_eq!(ldv[0], 3.0_f64.ln(), epsilon = 10E-10); - assert_relative_eq!(ldv[1], 2.0_f64.ln(), epsilon = 10E-10); - assert_relative_eq!(ldv[2], 1.0_f64.ln(), epsilon = 10E-10); - } - - #[test] - fn log_dirvec_with_alpha_15() { - let asgn = AssignmentBuilder::from_vec(vec![0, 1, 2, 0, 1, 0]) - .with_alpha(1.5) - .build() - .unwrap(); - - let ldv = 
asgn.log_dirvec(true); - - assert_eq!(ldv.len(), 4); - assert_relative_eq!(ldv[0], 3.0_f64.ln(), epsilon = 10E-10); - assert_relative_eq!(ldv[1], 2.0_f64.ln(), epsilon = 10E-10); - assert_relative_eq!(ldv[2], 1.0_f64.ln(), epsilon = 10E-10); - assert_relative_eq!(ldv[3], 1.5_f64.ln(), epsilon = 10E-10); - } - - #[test] - fn weights() { - let asgn = AssignmentBuilder::from_vec(vec![0, 1, 2, 0, 1, 0]) - .with_alpha(1.0) - .build() - .unwrap(); - let weights = asgn.weights(); - - assert_eq!(weights.len(), 3); - assert_relative_eq!(weights[0], 3.0 / 6.0, epsilon = 10E-10); - assert_relative_eq!(weights[1], 2.0 / 6.0, epsilon = 10E-10); - assert_relative_eq!(weights[2], 1.0 / 6.0, epsilon = 10E-10); - } - #[test] fn lcrp_all_ones() { let lcrp_1 = lcrp(4, &[1, 1, 1, 1], 1.0); @@ -866,7 +567,7 @@ mod tests { #[test] fn unassign_non_singleton() { let z: Vec = vec![0, 1, 1, 1, 2, 2]; - let mut asgn = AssignmentBuilder::from_vec(z).build().unwrap(); + let mut asgn = AssignmentBuilder::from_vec(z).build().unwrap().asgn; assert_eq!(asgn.n_cats, 3); assert_eq!(asgn.counts, vec![1, 3, 2]); @@ -881,7 +582,7 @@ mod tests { #[test] fn unassign_singleton_low() { let z: Vec = vec![0, 1, 1, 1, 2, 2]; - let mut asgn = AssignmentBuilder::from_vec(z).build().unwrap(); + let mut asgn = AssignmentBuilder::from_vec(z).build().unwrap().asgn; assert_eq!(asgn.n_cats, 3); assert_eq!(asgn.counts, vec![1, 3, 2]); @@ -896,7 +597,7 @@ mod tests { #[test] fn unassign_singleton_high() { let z: Vec = vec![0, 0, 1, 1, 1, 2]; - let mut asgn = AssignmentBuilder::from_vec(z).build().unwrap(); + let mut asgn = AssignmentBuilder::from_vec(z).build().unwrap().asgn; assert_eq!(asgn.n_cats, 3); assert_eq!(asgn.counts, vec![2, 3, 1]); @@ -911,7 +612,7 @@ mod tests { #[test] fn unassign_singleton_middle() { let z: Vec = vec![0, 0, 1, 2, 2, 2]; - let mut asgn = AssignmentBuilder::from_vec(z).build().unwrap(); + let mut asgn = AssignmentBuilder::from_vec(z).build().unwrap().asgn; assert_eq!(asgn.n_cats, 3); 
assert_eq!(asgn.counts, vec![2, 1, 3]); @@ -926,7 +627,7 @@ mod tests { #[test] fn reassign_to_existing_cat() { let z: Vec = vec![0, 1, 1, 1, 2, 2]; - let mut asgn = AssignmentBuilder::from_vec(z).build().unwrap(); + let mut asgn = AssignmentBuilder::from_vec(z).build().unwrap().asgn; assert_eq!(asgn.n_cats, 3); assert_eq!(asgn.counts, vec![1, 3, 2]); @@ -947,7 +648,7 @@ mod tests { #[test] fn reassign_to_new_cat() { let z: Vec = vec![0, 1, 1, 1, 2, 2]; - let mut asgn = AssignmentBuilder::from_vec(z).build().unwrap(); + let mut asgn = AssignmentBuilder::from_vec(z).build().unwrap().asgn; assert_eq!(asgn.n_cats, 3); assert_eq!(asgn.counts, vec![1, 3, 2]); @@ -966,41 +667,43 @@ mod tests { } #[test] - fn dirvec_with_unassigned_entry() { - let z: Vec = vec![0, 1, 1, 1, 2, 2]; - let mut asgn = AssignmentBuilder::from_vec(z) - .with_alpha(1.0) + fn manual_seed_control_works() { + let n = 100; + let asgn_1 = AssignmentBuilder::new(n) + .with_seed(17_834_795) .build() - .unwrap(); - - asgn.unassign(5); - - let dv = asgn.dirvec(false); + .unwrap() + .asgn; - assert_eq!(dv.len(), 3); - assert_relative_eq!(dv[0], 1.0, epsilon = 10e-10); - assert_relative_eq!(dv[1], 3.0, epsilon = 10e-10); - assert_relative_eq!(dv[2], 1.0, epsilon = 10e-10); - } + let asgn_2 = AssignmentBuilder::new(n) + .with_seed(17_834_795) + .build() + .unwrap() + .asgn; - #[test] - fn manual_seed_control_works() { - let asgn_1 = AssignmentBuilder::new(25).with_seed(17_834_795).build(); - let asgn_2 = AssignmentBuilder::new(25).with_seed(17_834_795).build(); - let asgn_3 = AssignmentBuilder::new(25).build(); + let asgn_3 = AssignmentBuilder::new(n).build().unwrap().asgn; assert_eq!(asgn_1, asgn_2); assert_ne!(asgn_1, asgn_3); } #[test] fn from_rng_seed_control_works() { - let mut rng_1 = Xoshiro256Plus::seed_from_u64(17_834_795); - let mut rng_2 = Xoshiro256Plus::seed_from_u64(17_834_795); - let asgn_1 = - AssignmentBuilder::new(25).seed_from_rng(&mut rng_1).build(); - let asgn_2 = - 
AssignmentBuilder::new(25).seed_from_rng(&mut rng_2).build(); - let asgn_3 = AssignmentBuilder::new(25).build(); + use rand::rngs::SmallRng; + use rand::SeedableRng; + + let mut rng_1 = SmallRng::seed_from_u64(17_834_795); + let mut rng_2 = SmallRng::seed_from_u64(17_834_795); + let asgn_1 = AssignmentBuilder::new(50) + .seed_from_rng(&mut rng_1) + .build() + .unwrap() + .asgn; + let asgn_2 = AssignmentBuilder::new(50) + .seed_from_rng(&mut rng_2) + .build() + .unwrap() + .asgn; + let asgn_3 = AssignmentBuilder::new(50).build().unwrap().asgn; assert_eq!(asgn_1, asgn_2); assert_ne!(asgn_1, asgn_3); } diff --git a/lace/lace_stats/src/lib.rs b/lace/lace_stats/src/lib.rs index 5d101ce9..74cf0bd2 100644 --- a/lace/lace_stats/src/lib.rs +++ b/lace/lace_stats/src/lib.rs @@ -8,6 +8,8 @@ clippy::option_option, clippy::implicit_clone )] + +pub mod assignment; mod cdf; mod chi_square; pub mod dist; @@ -19,6 +21,7 @@ pub mod mh; mod mixture_type; mod perm; pub mod prior; +pub mod prior_process; pub mod seq; mod simplex; pub mod uncertainty; diff --git a/lace/lace_stats/src/prior_process.rs b/lace/lace_stats/src/prior_process.rs new file mode 100644 index 00000000..15fb1f7c --- /dev/null +++ b/lace/lace_stats/src/prior_process.rs @@ -0,0 +1,754 @@ +use lace_consts::rv::{misc::pflip, traits::Rv}; +use rand::Rng; +use serde::{Deserialize, Serialize}; +use thiserror::Error; + +use crate::assignment::{Assignment, AssignmentError}; +use crate::rv::dist::{Beta, Gamma}; + +const MAX_STICK_BREAKING_ITERS: u16 = 10_000; + +pub trait PriorProcessT { + fn ln_gibbs_weight(&self, n_k: usize) -> f64; + + fn ln_singleton_weight(&self, n_cats: usize) -> f64; + + fn weight_vec( + &self, + asgn: &Assignment, + normed: bool, + append_new: bool, + ) -> Vec; + + fn slice_sb_extend( + &self, + weights: Vec, + u_star: f64, + rng: &mut R, + ) -> Vec; + + fn draw_assignment(&self, n: usize, rng: &mut R) -> Assignment; + + fn update_params(&mut self, asgn: &Assignment, rng: &mut R) -> f64; + + fn 
reset_params(&mut self, rng: &mut R); + + fn ln_f_partition(&self, asgn: &Assignment) -> f64; +} + +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] +#[serde(rename = "snake_case")] +pub struct Dirichlet { + pub alpha: f64, + pub alpha_prior: Gamma, +} + +impl Dirichlet { + pub fn from_prior(alpha_prior: Gamma, rng: &mut R) -> Self { + Self { + alpha: alpha_prior.draw(rng), + alpha_prior, + } + } +} + +impl std::fmt::Display for Dirichlet { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Dirichlet(α({}) ~ {})", self.alpha, self.alpha_prior) + } +} + +impl PriorProcessT for Dirichlet { + fn ln_gibbs_weight(&self, n_k: usize) -> f64 { + (n_k as f64).ln() + } + + fn ln_singleton_weight(&self, _n_cats: usize) -> f64 { + self.alpha.ln() + } + + fn weight_vec( + &self, + asgn: &Assignment, + normed: bool, + append_new: bool, + ) -> Vec { + let mut weights: Vec = + asgn.counts.iter().map(|&ct| ct as f64).collect(); + + let z = if append_new { + weights.push(self.alpha); + asgn.len() as f64 + self.alpha + } else { + asgn.len() as f64 + }; + + if normed { + weights.iter_mut().for_each(|ct| *ct /= z); + } + + weights + } + + fn slice_sb_extend( + &self, + weights: Vec, + u_star: f64, + rng: &mut R, + ) -> Vec { + sb_slice_extend(weights, self.alpha, 0.0, u_star, rng).unwrap() + } + + fn draw_assignment(&self, n: usize, rng: &mut R) -> Assignment { + if n == 0 { + return Assignment::empty(); + } + let mut counts = vec![1]; + let mut ps = vec![1.0, self.alpha]; + let mut zs = vec![0; n]; + + for z in zs.iter_mut().take(n).skip(1) { + let zi = pflip(&ps, 1, rng)[0]; + *z = zi; + if zi < counts.len() { + ps[zi] += 1.0; + counts[zi] += 1; + } else { + ps[zi] = 1.0; + ps.push(self.alpha); + counts.push(1); + }; + } + + Assignment { + asgn: zs, + n_cats: counts.len(), + counts, + } + } + + fn update_params(&mut self, asgn: &Assignment, rng: &mut R) -> f64 { + // TODO: Should we use a different method to draw CRP alpha that can + // 
extend outside of the bulk of the prior's mass? + let cts = &asgn.counts; + let n: usize = asgn.len(); + let loglike = |alpha: &f64| crate::assignment::lcrp(n, cts, *alpha); + let prior_ref = &self.alpha_prior; + let prior_draw = |rng: &mut R| prior_ref.draw(rng); + let mh_result = + crate::mh::mh_prior(self.alpha, loglike, prior_draw, 100, rng); + self.alpha = mh_result.x; + mh_result.score_x + } + + fn reset_params(&mut self, rng: &mut R) { + self.alpha = self.alpha_prior.draw(rng); + } + + fn ln_f_partition(&self, asgn: &Assignment) -> f64 { + crate::assignment::lcrp(asgn.len(), &asgn.counts, self.alpha) + } +} + +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] +#[serde(rename = "snake_case")] +pub struct PitmanYor { + pub alpha: f64, + pub d: f64, + pub alpha_prior: Gamma, + pub d_prior: Beta, +} + +impl std::fmt::Display for PitmanYor { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "Pitman-Yor(α({}) ~ {}, d({}) ~ {})", + self.alpha, self.alpha_prior, self.d, self.d_prior + ) + } +} + +impl PitmanYor { + pub fn from_prior( + alpha_prior: Gamma, + d_prior: Beta, + rng: &mut R, + ) -> Self { + Self { + alpha: alpha_prior.draw(rng), + d: d_prior.draw(rng), + alpha_prior, + d_prior, + } + } +} + +impl PriorProcessT for PitmanYor { + fn ln_gibbs_weight(&self, n_k: usize) -> f64 { + (n_k as f64 - self.d).ln() + } + + fn ln_singleton_weight(&self, n_cats: usize) -> f64 { + self.d.mul_add(n_cats as f64, self.alpha).ln() + } + + fn weight_vec( + &self, + asgn: &Assignment, + normed: bool, + append_new: bool, + ) -> Vec { + let mut weights: Vec = + asgn.counts.iter().map(|&ct| ct as f64 - self.d).collect(); + + let z = if append_new { + weights.push(self.d.mul_add(asgn.n_cats as f64, self.alpha)); + asgn.len() as f64 + self.alpha + } else { + asgn.len() as f64 + }; + + if normed { + weights.iter_mut().for_each(|ct| *ct /= z); + } + + weights + } + + fn slice_sb_extend( + &self, + weights: Vec, + u_star: f64, + rng: 
&mut R, + ) -> Vec { + sb_slice_extend(weights, self.alpha, self.d, u_star, rng).unwrap() + } + + fn draw_assignment(&self, n: usize, rng: &mut R) -> Assignment { + if n == 0 { + return Assignment::empty(); + } + + let mut counts = vec![1]; + let mut ps = vec![1.0 - self.d, self.alpha + self.d]; + let mut zs = vec![0; n]; + + for z in zs.iter_mut().take(n).skip(1) { + let zi = pflip(&ps, 1, rng)[0]; + *z = zi; + if zi < counts.len() { + ps[zi] += 1.0; + counts[zi] += 1; + } else { + ps[zi] = 1.0 - self.d; + counts.push(1); + ps.push(self.d.mul_add(counts.len() as f64, self.alpha)); + }; + } + + Assignment { + asgn: zs, + n_cats: counts.len(), + counts, + } + } + + fn update_params(&mut self, asgn: &Assignment, rng: &mut R) -> f64 { + let cts = &asgn.counts; + // TODO: this is not the best way to do this. + let ln_f_alpha = { + let loglike = + |alpha: &f64| crate::assignment::lpyp(cts, *alpha, self.d); + let prior_ref = &self.alpha_prior; + let prior_draw = |rng: &mut R| prior_ref.draw(rng); + let mh_result = + crate::mh::mh_prior(self.alpha, loglike, prior_draw, 100, rng); + self.alpha = mh_result.x; + mh_result.score_x + }; + + let ln_f_d = { + let loglike = + |d: &f64| crate::assignment::lpyp(cts, self.alpha, *d); + let prior_ref = &self.d_prior; + let prior_draw = |rng: &mut R| prior_ref.draw(rng); + let mh_result = + crate::mh::mh_prior(self.d, loglike, prior_draw, 100, rng); + self.d = mh_result.x; + mh_result.score_x + }; + + ln_f_alpha + ln_f_d + } + + fn reset_params(&mut self, rng: &mut R) { + self.alpha = self.alpha_prior.draw(rng); + self.d = self.d_prior.draw(rng); + } + + fn ln_f_partition(&self, asgn: &Assignment) -> f64 { + crate::assignment::lpyp(&asgn.counts, self.alpha, self.d) + } +} + +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] +#[serde(rename_all = "snake_case")] +pub enum Process { + Dirichlet(Dirichlet), + PitmanYor(PitmanYor), +} + +impl PriorProcessT for Process { + fn ln_gibbs_weight(&self, n_k: usize) -> f64 { + match self 
{ + Self::Dirichlet(proc) => proc.ln_gibbs_weight(n_k), + Self::PitmanYor(proc) => proc.ln_gibbs_weight(n_k), + } + } + + fn ln_singleton_weight(&self, n_cats: usize) -> f64 { + match self { + Self::Dirichlet(proc) => proc.ln_singleton_weight(n_cats), + Self::PitmanYor(proc) => proc.ln_singleton_weight(n_cats), + } + } + + fn weight_vec( + &self, + asgn: &Assignment, + normed: bool, + append_new: bool, + ) -> Vec { + match self { + Self::Dirichlet(proc) => proc.weight_vec(asgn, normed, append_new), + Self::PitmanYor(proc) => proc.weight_vec(asgn, normed, append_new), + } + } + + fn slice_sb_extend( + &self, + weights: Vec, + u_star: f64, + rng: &mut R, + ) -> Vec { + match self { + Self::Dirichlet(proc) => proc.slice_sb_extend(weights, u_star, rng), + Self::PitmanYor(proc) => proc.slice_sb_extend(weights, u_star, rng), + } + } + + fn draw_assignment(&self, n: usize, rng: &mut R) -> Assignment { + match self { + Self::Dirichlet(proc) => proc.draw_assignment(n, rng), + Self::PitmanYor(proc) => proc.draw_assignment(n, rng), + } + } + + fn update_params(&mut self, asgn: &Assignment, rng: &mut R) -> f64 { + match self { + Self::Dirichlet(proc) => proc.update_params(asgn, rng), + Self::PitmanYor(proc) => proc.update_params(asgn, rng), + } + } + + fn reset_params(&mut self, rng: &mut R) { + match self { + Self::Dirichlet(proc) => proc.reset_params(rng), + Self::PitmanYor(proc) => proc.reset_params(rng), + } + } + + fn ln_f_partition(&self, asgn: &Assignment) -> f64 { + match self { + Self::Dirichlet(proc) => proc.ln_f_partition(asgn), + Self::PitmanYor(proc) => proc.ln_f_partition(asgn), + } + } +} + +impl Default for Process { + fn default() -> Self { + Self::Dirichlet(Dirichlet { + alpha: 1.0, + alpha_prior: lace_consts::general_alpha_prior(), + }) + } +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct PriorProcess { + pub process: Process, + pub asgn: Assignment, +} + +impl PriorProcess { + pub fn from_process( + process: Process, + n: usize, + rng: &mut 
R, + ) -> Self { + let asgn = process.draw_assignment(n, rng); + Self { process, asgn } + } + + pub fn weight_vec(&self, append_new: bool) -> Vec { + self.process.weight_vec(&self.asgn, true, append_new) + } + + pub fn weight_vec_unnormed(&self, append_new: bool) -> Vec { + self.process.weight_vec(&self.asgn, false, append_new) + } + + pub fn update_params(&mut self, rng: &mut R) -> f64 { + self.process.update_params(&self.asgn, rng) + } + + pub fn ln_f_partition(&self, asgn: &Assignment) -> f64 { + self.process.ln_f_partition(asgn) + } +} + +#[derive(Clone, Copy, Debug, Serialize, Deserialize)] +pub enum PriorProcessType { + Dirichlet, + PitmanYor, +} + +/// The stick breaking algorithm has exceeded the max number of iterations. +#[derive(Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct TheStickIsDust(u16); + +// Append new dirchlet weights by stick breaking until the new weight is less +// than u* +// +// **NOTE** This function is only for the slice reassignment kernel. It cuts out +// all weights that are less that u*, so the sum of the weights will not be 1. +fn sb_slice_extend( + mut weights: Vec, + alpha: f64, + d: f64, + u_star: f64, + mut rng: &mut R, +) -> Result, TheStickIsDust> { + let mut b_star = weights.pop().unwrap(); + + // If α is low and we do the dirichlet update w ~ Dir(n_1, ..., n_k, α), + // the final weight will often be zero. In that case, we're done. 
+ if b_star <= 1E-16 { + weights.push(b_star); + return Ok(weights); + } + + let mut beta = Beta::new(1.0 + d, alpha).unwrap(); + + let mut iters: u16 = 0; + loop { + if d > 0.0 { + let n_cats = weights.len() as f64; + beta.set_beta(d.mul_add(n_cats, alpha)).unwrap(); + } + + let vk: f64 = beta.draw(&mut rng); + let bk = vk * b_star; + b_star *= 1.0 - vk; + + if bk >= u_star { + weights.push(bk); + } + + if b_star < u_star { + return Ok(weights); + } + + iters += 1; + if iters > MAX_STICK_BREAKING_ITERS { + // return Err(TheStickIsDust(MAX_STICK_BREAKING_ITERS)); + eprintln!( + "The stick is dust, n_cats: {}, u*: {}", + weights.len(), + u_star + ); + return Ok(weights); + } + } +} + +/// Constructs a PriorProcess +#[derive(Clone, Debug)] +pub struct Builder { + n: usize, + asgn: Option>, + process: Option, + seed: Option, +} + +#[derive(Debug, Error, PartialEq)] +pub enum BuildPriorProcessError { + #[error("assignment vector is empty")] + EmptyAssignmentVec, + #[error("there are {n_cats} categories but {n} data")] + NLessThanNCats { n: usize, n_cats: usize }, + #[error("invalid assignment: {0}")] + AssignmentError(#[from] AssignmentError), +} + +impl Builder { + /// Create a builder for `n`-length assignments + /// + /// # Arguments + /// - n: the number of data/entries in the assignment + pub fn new(n: usize) -> Self { + Self { + n, + asgn: None, + process: None, + seed: None, + } + } + + /// Initialize the builder from an assignment vector + /// + /// # Note: + /// The validity of `asgn` will not be verified until `build` is called. 
+ pub fn from_vec(asgn: Vec) -> Self { + Self { + n: asgn.len(), + asgn: Some(asgn), + process: None, + seed: None, + } + } + + /// Select the process type + #[must_use] + pub fn with_process(mut self, process: Process) -> Self { + self.process = Some(process); + self + } + + /// Set the RNG seed + #[must_use] + pub fn with_seed(mut self, seed: u64) -> Self { + self.seed = Some(seed); + self + } + + /// Set the RNG seed from another RNG + #[must_use] + pub fn seed_from_rng(mut self, rng: &mut R) -> Self { + self.seed = Some(rng.next_u64()); + self + } + + /// Use a *flat* assignment with one partition + #[must_use] + pub fn flat(mut self) -> Self { + self.asgn = Some(vec![0; self.n]); + self + } + + /// Use an assignment with `n_cats`, evenly populated partitions/categories + pub fn with_n_cats( + mut self, + n_cats: usize, + ) -> Result { + if n_cats > self.n { + Err(BuildPriorProcessError::NLessThanNCats { n: self.n, n_cats }) + } else { + let asgn: Vec = (0..self.n).map(|i| i % n_cats).collect(); + self.asgn = Some(asgn); + Ok(self) + } + } + + /// Build the assignment and consume the builder + pub fn build(self) -> Result { + use rand::rngs::StdRng; + use rand::SeedableRng; + + let mut rng = self + .seed + .map_or_else(StdRng::from_entropy, StdRng::seed_from_u64); + + let process = self.process.unwrap_or_else(|| { + Process::Dirichlet(Dirichlet::from_prior( + lace_consts::general_alpha_prior(), + &mut rng, + )) + }); + + let n = self.n; + let asgn = self + .asgn + .unwrap_or_else(|| process.draw_assignment(n, &mut rng).asgn); + + let n_cats: usize = asgn.iter().max().map(|&m| m + 1).unwrap_or(0); + let mut counts: Vec = vec![0; n_cats]; + for z in &asgn { + counts[*z] += 1; + } + + let asgn = Assignment { + asgn, + counts, + n_cats, + }; + + if crate::validate_assignment!(asgn) { + Ok(PriorProcess { process, asgn }) + } else { + asgn.validate() + .emit_error() + .map_err(BuildPriorProcessError::AssignmentError)?; + Ok(PriorProcess { process, asgn }) + } + } +} 
+ +#[cfg(test)] +mod tests { + use super::*; + use approx::*; + + const TOL: f64 = 1E-12; + + mod sb_slice { + use super::*; + + #[test] + fn should_return_input_weights_if_alpha_is_zero() { + let mut rng = rand::thread_rng(); + let weights_in: Vec = vec![0.8, 0.2, 0.0]; + let weights_out = + sb_slice_extend(weights_in.clone(), 1.0, 0.0, 0.2, &mut rng) + .unwrap(); + let good = weights_in + .iter() + .zip(weights_out.iter()) + .all(|(wi, wo)| (wi - wo).abs() < TOL); + assert!(good); + } + + #[test] + fn smoke() { + let mut rng = rand::thread_rng(); + let weights_in: Vec = vec![0.8, 0.2]; + let u_star = 0.1; + let res = sb_slice_extend(weights_in, 1.0, 0.0, u_star, &mut rng); + assert!(res.is_ok()); + } + } + + mod build { + use super::*; + + fn dir_process(alpha: f64) -> Process { + let inner = Dirichlet { + alpha, + alpha_prior: Gamma::default(), + }; + Process::Dirichlet(inner) + } + + #[test] + fn dirvec_with_alpha_1() { + let proc = Builder::from_vec(vec![0, 1, 2, 0, 1, 0]) + .with_process(dir_process(1.0)) + .build() + .unwrap(); + let dv = proc.weight_vec_unnormed(false); + + assert_eq!(dv.len(), 3); + assert_relative_eq!(dv[0], 3.0, epsilon = 10E-10); + assert_relative_eq!(dv[1], 2.0, epsilon = 10E-10); + assert_relative_eq!(dv[2], 1.0, epsilon = 10E-10); + } + + #[test] + fn dirvec_with_alpha_15() { + let proc = Builder::from_vec(vec![0, 1, 2, 0, 1, 0]) + .with_process(dir_process(1.5)) + .build() + .unwrap(); + let dv = proc.weight_vec_unnormed(true); + + assert_eq!(dv.len(), 4); + assert_relative_eq!(dv[0], 3.0, epsilon = 10E-10); + assert_relative_eq!(dv[1], 2.0, epsilon = 10E-10); + assert_relative_eq!(dv[2], 1.0, epsilon = 10E-10); + assert_relative_eq!(dv[3], 1.5, epsilon = 10E-10); + } + + #[test] + fn log_dirvec_with_alpha_1() { + let proc = Builder::from_vec(vec![0, 1, 2, 0, 1, 0]) + .with_process(dir_process(1.0)) + .build() + .unwrap(); + + let ldv = (0..3) + .map(|k| proc.process.ln_gibbs_weight(proc.asgn.counts[k])) + .collect::>(); + + 
assert_eq!(ldv.len(), 3); + assert_relative_eq!(ldv[0], 3.0_f64.ln(), epsilon = 10E-10); + assert_relative_eq!(ldv[1], 2.0_f64.ln(), epsilon = 10E-10); + assert_relative_eq!(ldv[2], 1.0_f64.ln(), epsilon = 10E-10); + } + + #[test] + fn log_dirvec_with_alpha_15() { + let proc = Builder::from_vec(vec![0, 1, 2, 0, 1, 0]) + .with_process(dir_process(1.5)) + .build() + .unwrap(); + + let ldv = (0..3) + .map(|k| proc.process.ln_gibbs_weight(proc.asgn.counts[k])) + .chain(std::iter::once_with(|| { + proc.process.ln_singleton_weight(3) + })) + .collect::>(); + + assert_eq!(ldv.len(), 4); + assert_relative_eq!(ldv[0], 3.0_f64.ln(), epsilon = 10E-10); + assert_relative_eq!(ldv[1], 2.0_f64.ln(), epsilon = 10E-10); + assert_relative_eq!(ldv[2], 1.0_f64.ln(), epsilon = 10E-10); + assert_relative_eq!(ldv[3], 1.5_f64.ln(), epsilon = 10E-10); + } + + #[test] + fn weights() { + let proc = Builder::from_vec(vec![0, 1, 2, 0, 1, 0]) + .with_process(dir_process(1.0)) + .build() + .unwrap(); + + let weights = proc.weight_vec(false); + + assert_eq!(weights.len(), 3); + assert_relative_eq!(weights[0], 3.0 / 6.0, epsilon = 10E-10); + assert_relative_eq!(weights[1], 2.0 / 6.0, epsilon = 10E-10); + assert_relative_eq!(weights[2], 1.0 / 6.0, epsilon = 10E-10); + } + + #[test] + fn dirvec_with_unassigned_entry() { + let z: Vec = vec![0, 1, 1, 1, 2, 2]; + let mut proc = Builder::from_vec(z) + .with_process(dir_process(1.0)) + .build() + .unwrap(); + + proc.asgn.unassign(5); + + let dv = proc.weight_vec_unnormed(false); + + assert_eq!(dv.len(), 3); + assert_relative_eq!(dv[0], 1.0, epsilon = 10e-10); + assert_relative_eq!(dv[1], 3.0, epsilon = 10e-10); + assert_relative_eq!(dv[2], 1.0, epsilon = 10e-10); + } + } +} diff --git a/lace/resources/datasets/animals/codebook.yaml b/lace/resources/datasets/animals/codebook.yaml index 9d856418..2c1463b3 100644 --- a/lace/resources/datasets/animals/codebook.yaml +++ b/lace/resources/datasets/animals/codebook.yaml @@ -1,10 +1,12 @@ -table_name: my_table 
-state_alpha_prior: - shape: 1.0 - rate: 1.0 -view_alpha_prior: - shape: 1.0 - rate: 1.0 +table_name: animals +state_prior_process: !dirichlet + alpha_prior: + shape: 1.0 + rate: 1.0 +view_prior_process: !dirichlet + alpha_prior: + shape: 1.0 + rate: 1.0 col_metadata: - name: black coltype: !Categorical diff --git a/lace/resources/datasets/satellites/codebook.yaml b/lace/resources/datasets/satellites/codebook.yaml index 7e750b04..5d7c2390 100644 --- a/lace/resources/datasets/satellites/codebook.yaml +++ b/lace/resources/datasets/satellites/codebook.yaml @@ -1,10 +1,12 @@ -table_name: my_table -state_alpha_prior: - shape: 1.0 - rate: 1.0 -view_alpha_prior: - shape: 1.0 - rate: 1.0 +table_name: satellites +state_prior_process: !dirichlet + alpha_prior: + shape: 1.0 + rate: 1.0 +view_prior_process: !dirichlet + alpha_prior: + shape: 1.0 + rate: 1.0 col_metadata: - name: Country_of_Operator coltype: !Categorical diff --git a/lace/resources/test/entropy/entropy-state-1.yaml b/lace/resources/test/entropy/entropy-state-1.yaml index 86f1aa06..fef3b3ea 100644 --- a/lace/resources/test/entropy/entropy-state-1.yaml +++ b/lace/resources/test/entropy/entropy-state-1.yaml @@ -1,16 +1,14 @@ --- -loglike: 0.0 +score: + ln_likelihood: 0.0 + ln_prior: 0.0 + ln_state_prior_process: 0.0 + ln_view_prior_process: 0.0 diagnostics: loglike: - 0.0 - nviews: - - 2 - state_alpha: - - 1.0 -view_alpha_prior: - !Gamma - shape: 1.0 - rate: 1.0 + logprior: + - 0.0 views: - ftrs: 0: @@ -90,19 +88,20 @@ views: scale: 1.0 data: n: 0 - asgn: - alpha: 1 + prior_process: asgn: - - 0 - - 0 - - 1 - - 1 - counts: - - 2 - - 2 - n_cats: 2 - prior: - !Gamma + asgn: + - 0 + - 0 + - 1 + - 1 + counts: + - 2 + - 2 + n_cats: 2 + process: !dirichlet + alpha: 1 + alpha_prior: shape: 1.0 rate: 1.0 weights: @@ -179,38 +178,41 @@ views: scale: 1.0 data: n: 0 - asgn: - alpha: 1 + prior_process: asgn: - - 0 - - 1 - - 1 - - 1 - counts: - - 1 - - 3 - n_cats: 2 - prior: - !Gamma + asgn: + - 0 + - 1 + - 1 + - 1 + counts: + 
- 1 + - 3 + n_cats: 2 + process: !dirichlet + alpha: 1 + alpha_prior: shape: 1.0 rate: 1.0 weights: - 0.25 - 0.75 alpha: 1 -asgn: - alpha: 1 +prior_process: asgn: - - 0 - - 1 - - 0 - - 1 - counts: - - 2 - - 2 - n_cats: 2 - prior: - !Gamma + alpha: 1 + asgn: + - 0 + - 1 + - 0 + - 1 + counts: + - 2 + - 2 + n_cats: 2 + process: !dirichlet + alpha: 1 + alpha_prior: shape: 1.0 rate: 1.0 weights: diff --git a/lace/resources/test/entropy/entropy-state-2.yaml b/lace/resources/test/entropy/entropy-state-2.yaml index 657de681..4d0cad9f 100644 --- a/lace/resources/test/entropy/entropy-state-2.yaml +++ b/lace/resources/test/entropy/entropy-state-2.yaml @@ -1,15 +1,14 @@ --- +score: + ln_likelihood: 0.0 + ln_prior: 0.0 + ln_state_prior_process: 0.0 + ln_view_prior_process: 0.0 diagnostics: loglike: - 0.0 - nviews: - - 2 - state_alpha: - - 1.0 -view_alpha_prior: - !Gamma - shape: 1.0 - rate: 1.0 + logprior: + - 0.0 views: - ftrs: 0: @@ -88,19 +87,20 @@ views: scale: 1.0 data: n: 0 - asgn: - alpha: 1 + prior_process: asgn: - - 0 - - 0 - - 1 - - 1 - counts: - - 2 - - 2 - n_cats: 2 - prior: - !Gamma + asgn: + - 0 + - 0 + - 1 + - 1 + counts: + - 2 + - 2 + n_cats: 2 + process: !dirichlet + alpha: 1 + alpha_prior: shape: 1.0 rate: 1.0 weights: @@ -178,38 +178,40 @@ views: scale: 1.0 data: n: 0 - asgn: - alpha: 1 + prior_process: asgn: - - 0 - - 1 - - 1 - - 1 - counts: - - 1 - - 3 - n_cats: 2 - prior: - !Gamma + asgn: + - 0 + - 1 + - 1 + - 1 + counts: + - 1 + - 3 + n_cats: 2 + process: !dirichlet + alpha: 1 + alpha_prior: shape: 1.0 rate: 1.0 weights: - 0.25 - 0.75 alpha: 1 -asgn: - alpha: 1 +prior_process: asgn: - - 0 - - 0 - - 1 - - 1 - counts: - - 2 - - 2 - n_cats: 2 - prior: - !Gamma + asgn: + - 0 + - 0 + - 1 + - 1 + counts: + - 2 + - 2 + n_cats: 2 + process: !dirichlet + alpha: 1 + alpha_prior: shape: 1.0 rate: 1.0 weights: diff --git a/lace/resources/test/single-categorical.yaml b/lace/resources/test/single-categorical.yaml index 6369eaf8..beb105d8 100644 --- 
a/lace/resources/test/single-categorical.yaml +++ b/lace/resources/test/single-categorical.yaml @@ -1,16 +1,14 @@ --- -loglike: 0.0 +score: + ln_likelihood: 0.0 + ln_prior: 0.0 + ln_state_prior_process: 0.0 + ln_view_prior_process: 0.0 diagnostics: loglike: - 0.0 - nviews: - - 2 - state_alpha: - - 1.0 -view_alpha_prior: - !Gamma - shape: 1.0 - rate: 1.0 + logprior: + - 0.0 views: - ftrs: 0: @@ -52,35 +50,38 @@ views: scale: 1.0 data: n: 0 - asgn: - alpha: 1 + prior_process: asgn: - - 0 - - 0 - - 1 - - 1 - counts: - - 2 - - 2 - n_cats: 2 - prior: - !Gamma + asgn: + - 0 + - 0 + - 1 + - 1 + counts: + - 2 + - 2 + n_cats: 2 + process: !dirichlet + alpha: 1.0 + alpha_prior: shape: 1.0 rate: 1.0 weights: - 0.5 - 0.5 alpha: 1 -asgn: - alpha: 1 +prior_process: asgn: - - 0 - counts: - - 1 - n_cats: 1 - prior: - !Gamma - shape: 1.0 - rate: 1.0 + asgn: + - 0 + counts: + - 1 + n_cats: 1 + process: + !dirichlet + alpha: 1.0 + alpha_prior: + shape: 1.0 + rate: 1.0 weights: - 1.0 diff --git a/lace/resources/test/single-continuous.yaml b/lace/resources/test/single-continuous.yaml index c4e780ee..8e47ce54 100644 --- a/lace/resources/test/single-continuous.yaml +++ b/lace/resources/test/single-continuous.yaml @@ -1,16 +1,14 @@ --- -loglike: 0.0 +score: + ln_likelihood: 0.0 + ln_prior: 0.0 + ln_state_prior_process: 0.0 + ln_view_prior_process: 0.0 diagnostics: loglike: - 0.0 - nviews: - - 2 - state_alpha: - - 1.0 -view_alpha_prior: - !Gamma - shape: 1.0 - rate: 1.0 + logprior: + - 0.0 views: - ftrs: 0: @@ -51,35 +49,39 @@ views: scale: 1.0 data: n: 0 - asgn: - alpha: 1 + prior_process: + process: + !dirichlet + alpha: 1.0 + alpha_prior: + shape: 1.0 + rate: 1.0 asgn: - - 0 - - 0 - - 1 - - 1 - counts: - - 2 - - 2 - n_cats: 2 - prior: - !Gamma - shape: 1.0 - rate: 1.0 + asgn: + - 0 + - 0 + - 1 + - 1 + counts: + - 2 + - 2 + n_cats: 2 weights: - 0.5 - 0.5 alpha: 1 -asgn: - alpha: 1 +prior_process: asgn: - - 0 - counts: - - 1 - n_cats: 1 - prior: - !Gamma - shape: 1.0 - rate: 1.0 + asgn: + 
- 0 + counts: + - 1 + n_cats: 1 + process: + !dirichlet + alpha: 1.0 + alpha_prior: + shape: 1.0 + rate: 1.0 weights: - 1.0 diff --git a/lace/resources/test/single-count.yaml b/lace/resources/test/single-count.yaml index d3d0f622..3a77b274 100644 --- a/lace/resources/test/single-count.yaml +++ b/lace/resources/test/single-count.yaml @@ -1,16 +1,14 @@ --- -loglike: 0.0 +score: + ln_likelihood: 0.0 + ln_prior: 0.0 + ln_state_prior_process: 0.0 + ln_view_prior_process: 0.0 diagnostics: - loglike: - - 0.0 - nviews: - - 2 - state_alpha: - - 1.0 -view_alpha_prior: - !Gamma - shape: 1.0 - rate: 1.0 + loglike: + - 0.0 + logprior: + - 0.0 views: - ftrs: 0: @@ -41,34 +39,35 @@ views: scale: 2.0 data: n: 0 - asgn: - alpha: 1 + prior_process: asgn: - - 0 - - 0 - - 1 - - 1 - counts: - - 2 - - 2 - n_cats: 2 - prior: - !Gamma + asgn: + - 0 + - 0 + - 1 + - 1 + counts: + - 2 + - 2 + n_cats: 2 + process: !dirichlet + alpha: 1.0 + alpha_prior: shape: 1.0 rate: 1.0 weights: - 0.5 - 0.5 - alpha: 1 -asgn: - alpha: 1 +prior_process: asgn: - - 0 - counts: - - 1 - n_cats: 1 - prior: - !Gamma + asgn: + - 0 + counts: + - 1 + n_cats: 1 + process: !dirichlet + alpha: 1 + alpha_prior: shape: 1.0 rate: 1.0 weights: diff --git a/lace/resources/test/small/small-state-1.yaml b/lace/resources/test/small/small-state-1.yaml index 6cc3d5ba..ed16459f 100644 --- a/lace/resources/test/small/small-state-1.yaml +++ b/lace/resources/test/small/small-state-1.yaml @@ -1,16 +1,14 @@ --- -loglike: 0.0 +score: + ln_likelihood: 0.0 + ln_prior: 0.0 + ln_state_prior_process: 0.0 + ln_view_prior_process: 0.0 diagnostics: loglike: - 0.0 - nviews: - - 2 - state_alpha: - - 1.0 -view_alpha_prior: - !Gamma - shape: 1.0 - rate: 1.0 + logprior: + - 0.0 views: - ftrs: 0: @@ -89,19 +87,20 @@ views: scale: 1.0 data: n: 0 - asgn: - alpha: 1 + prior_process: asgn: - - 0 - - 0 - - 1 - - 1 - counts: - - 2 - - 2 - n_cats: 2 - prior: - !Gamma + asgn: + - 0 + - 0 + - 1 + - 1 + counts: + - 2 + - 2 + n_cats: 2 + process: !dirichlet + 
alpha: 1 + alpha_prior: shape: 1.0 rate: 1.0 weights: @@ -147,37 +146,38 @@ views: scale: 1.0 data: n: 0 - asgn: - alpha: 1 + prior_process: asgn: - - 0 - - 1 - - 1 - - 1 - counts: - - 1 - - 3 - n_cats: 2 - prior: - !Gamma + asgn: + - 0 + - 1 + - 1 + - 1 + counts: + - 1 + - 3 + n_cats: 2 + process: !dirichlet + alpha: 1.0 + alpha_prior: shape: 1.0 rate: 1.0 weights: - 0.25 - 0.75 - alpha: 1 -asgn: - alpha: 1 +prior_process: asgn: - - 0 - - 1 - - 0 - counts: - - 2 - - 1 - n_cats: 2 - prior: - !Gamma + asgn: + - 0 + - 1 + - 0 + counts: + - 2 + - 1 + n_cats: 2 + process: !dirichlet + alpha: 1 + alpha_prior: shape: 1.0 rate: 1.0 weights: diff --git a/lace/resources/test/small/small-state-2.yaml b/lace/resources/test/small/small-state-2.yaml index d349c322..bbb3ac3b 100644 --- a/lace/resources/test/small/small-state-2.yaml +++ b/lace/resources/test/small/small-state-2.yaml @@ -1,16 +1,14 @@ --- -loglike: 0.0 +score: + ln_likelihood: 0.0 + ln_prior: 0.0 + ln_state_prior_process: 0.0 + ln_view_prior_process: 0.0 diagnostics: - loglike: - - 0.0 - nviews: - - 2 - state_alpha: - - 1.0 -view_alpha_prior: - !Gamma - shape: 1.0 - rate: 1.0 + loglike: + - 0.0 + logprior: + - 0.0 views: - ftrs: 0: @@ -58,27 +56,27 @@ views: scale: 1.0 data: n: 0 - asgn: - alpha: 1 + prior_process: asgn: - - 0 - - 1 - - 1 - - 2 - counts: - - 1 - - 2 - - 1 - n_cats: 3 - prior: - !Gamma + asgn: + - 0 + - 1 + - 1 + - 2 + counts: + - 1 + - 2 + - 1 + n_cats: 3 + process: !dirichlet + alpha: 1 + alpha_prior: shape: 1.0 rate: 1.0 weights: - 0.25 - 0.5 - 0.25 - alpha: 1 - ftrs: 1: !Continuous @@ -156,37 +154,38 @@ views: scale: 1.0 data: n: 0 - asgn: - alpha: 1 + prior_process: asgn: - - 0 - - 0 - - 1 - - 1 - counts: - - 2 - - 2 - n_cats: 2 - prior: - !Gamma + asgn: + - 0 + - 0 + - 1 + - 1 + counts: + - 2 + - 2 + n_cats: 2 + process: !dirichlet + alpha: 1 + alpha_prior: shape: 1.0 rate: 1.0 weights: - 0.5 - 0.5 - alpha: 1 -asgn: - alpha: 1 +prior_process: asgn: - - 0 - - 1 - - 1 - counts: - - 1 - - 2 - 
n_cats: 2 - prior: - !Gamma + asgn: + - 0 + - 1 + - 1 + counts: + - 1 + - 2 + n_cats: 2 + process: !dirichlet + alpha: 1 + alpha_prior: shape: 1.0 rate: 1.0 weights: diff --git a/lace/resources/test/small/small-state-3.yaml b/lace/resources/test/small/small-state-3.yaml index 78e61b22..290c9f9c 100644 --- a/lace/resources/test/small/small-state-3.yaml +++ b/lace/resources/test/small/small-state-3.yaml @@ -1,16 +1,14 @@ --- -loglike: 0.0 +score: + ln_likelihood: 0.0 + ln_prior: 0.0 + ln_state_prior_process: 0.0 + ln_view_prior_process: 0.0 diagnostics: - loglike: - - 0.0 - nviews: - - 2 - state_alpha: - - 1.0 -view_alpha_prior: - !Gamma - shape: 1.0 - rate: 1.0 + loglike: + - 0.0 + logprior: + - 0.0 views: - ftrs: 0: @@ -127,38 +125,39 @@ views: scale: 1.0 data: n: 0 - asgn: - alpha: 1 + prior_process: asgn: - - 0 - - 1 - - 1 - - 1 - counts: - - 1 - - 3 - n_cats: 2 - prior: - !Gamma + alpha: 1 + asgn: + - 0 + - 1 + - 1 + - 1 + counts: + - 1 + - 3 + n_cats: 2 + process: !dirichlet + alpha: 1 + alpha_prior: shape: 1.0 rate: 1.0 weights: - 0.25 - 0.75 - alpha: 1 -asgn: - alpha: 1 +prior_process: asgn: - - 0 - - 0 - - 0 - counts: - - 3 - n_cats: 1 - prior: - !Gamma + asgn: + - 0 + - 0 + - 0 + counts: + - 3 + n_cats: 1 + process: !dirichlet + alpha: 1 + alpha_prior: shape: 1.0 rate: 1.0 weights: - 1.0 -alpha: 1 diff --git a/lace/resources/test/spread-out-continuous-modes.yaml b/lace/resources/test/spread-out-continuous-modes.yaml index 3ca34fec..f7bd75f6 100644 --- a/lace/resources/test/spread-out-continuous-modes.yaml +++ b/lace/resources/test/spread-out-continuous-modes.yaml @@ -1,16 +1,14 @@ --- -loglike: 0.0 +score: + ln_likelihood: 0.0 + ln_prior: 0.0 + ln_state_prior_process: 0.0 + ln_view_prior_process: 0.0 diagnostics: - loglike: - - 0.0 - nviews: - - 1 - state_alpha: - - 1.0 -view_alpha_prior: - !Gamma - shape: 1.0 - rate: 1.0 + loglike: + - 0.0 + logprior: + - 0.0 views: - ftrs: 0: @@ -58,20 +56,21 @@ views: scale: 1.0 data: n: 0 - asgn: - alpha: 1 + 
prior_process: asgn: - - 0 - - 0 - - 1 - - 2 - counts: - - 2 - - 1 - - 1 - n_cats: 3 - prior: - !Gamma + asgn: + - 0 + - 0 + - 1 + - 2 + counts: + - 2 + - 1 + - 1 + n_cats: 3 + process: !dirichlet + alpha: 1 + alpha_prior: shape: 1.0 rate: 1.0 weights: @@ -79,15 +78,16 @@ views: - 0.2 - 0.3 alpha: 1 -asgn: - alpha: 1 +prior_process: asgn: - - 0 - counts: - - 1 - n_cats: 1 - prior: - !Gamma + asgn: + - 0 + counts: + - 1 + n_cats: 1 + process: !dirichlet + alpha: 1 + alpha_prior: shape: 1.0 rate: 1.0 weights: diff --git a/lace/src/bencher.rs b/lace/src/bencher.rs index 5b6b956b..cd63e8d1 100644 --- a/lace/src/bencher.rs +++ b/lace/src/bencher.rs @@ -62,6 +62,26 @@ pub enum GenerateStateError { BuildState(#[from] BuildStateError), } +fn emit_prior_process( + prior_process: crate::codebook::PriorProcess, + rng: &mut R, +) -> lace_stats::prior_process::Process { + use lace_stats::prior_process::{Dirichlet, PitmanYor, Process}; + match prior_process { + crate::codebook::PriorProcess::Dirichlet { alpha_prior } => { + let inner = Dirichlet::from_prior(alpha_prior, rng); + Process::Dirichlet(inner) + } + crate::codebook::PriorProcess::PitmanYor { + alpha_prior, + d_prior, + } => { + let inner = PitmanYor::from_prior(alpha_prior, d_prior, rng); + Process::PitmanYor(inner) + } + } +} + impl BencherSetup { fn gen_state( &mut self, @@ -74,15 +94,22 @@ impl BencherSetup { } => crate::codebook::data::read_csv(path) .map_err(GenerateStateError::Read) .and_then(|df| { - let state_alpha_prior = codebook - .state_alpha_prior - .clone() - .unwrap_or_else(lace_consts::state_alpha_prior); - - let view_alpha_prior = codebook - .view_alpha_prior - .clone() - .unwrap_or_else(lace_consts::view_alpha_prior); + let state_prior_process = { + let prior_process = codebook + .state_prior_process + .clone() + .unwrap_or_default(); + emit_prior_process(prior_process, rng) + }; + + let view_prior_process = { + let prior_process = codebook + .view_prior_process + .clone() + .unwrap_or_default(); + 
emit_prior_process(prior_process, rng) + }; + let mut codebook_tmp = Box::::default(); // swap codebook into something we can take ownership of @@ -94,8 +121,8 @@ impl BencherSetup { State::from_prior( features, - state_alpha_prior, - view_alpha_prior, + state_prior_process, + view_prior_process, &mut rng, ) }) diff --git a/lace/src/data/data_source.rs b/lace/src/data/data_source.rs index c85538f2..15afa3e4 100644 --- a/lace/src/data/data_source.rs +++ b/lace/src/data/data_source.rs @@ -148,19 +148,19 @@ impl DataSource { use crate::codebook::{data, formats}; let codebook = match &self { DataSource::Ipc(path) => { - formats::codebook_from_ipc(path, None, None, false) + formats::codebook_from_ipc(path, None, None, None, false) } DataSource::Csv(path) => { - formats::codebook_from_csv(path, None, None, false) + formats::codebook_from_csv(path, None, None, None, false) } DataSource::Json(path) => { - formats::codebook_from_json(path, None, None, false) + formats::codebook_from_json(path, None, None, None, false) } DataSource::Parquet(path) => { - formats::codebook_from_parquet(path, None, None, false) + formats::codebook_from_parquet(path, None, None, None, false) } DataSource::Polars(df) => { - data::df_to_codebook(df, None, None, false) + data::df_to_codebook(df, None, None, None, false) } DataSource::Empty => Ok(Codebook::default()), }?; @@ -175,7 +175,7 @@ impl DataSource { use crate::codebook::data; let codebook = match &self { DataSource::Polars(df) => { - data::df_to_codebook(df, None, None, false) + data::df_to_codebook(df, None, None, None, false) } DataSource::Empty => Ok(Codebook::default()), }?; diff --git a/lace/src/interface/engine/builder.rs b/lace/src/interface/engine/builder.rs index f8a8141e..3efa80c4 100644 --- a/lace/src/interface/engine/builder.rs +++ b/lace/src/interface/engine/builder.rs @@ -193,11 +193,11 @@ mod tests { for (state_1, state_2) in engine_1.states.iter().zip(engine_2.states.iter()) { - assert_eq!(&state_1.asgn, &state_2.asgn); + 
assert_eq!(state_1.asgn(), state_2.asgn()); for (view_1, view_2) in state_1.views.iter().zip(state_2.views.iter()) { - assert_eq!(&view_1.asgn, &view_2.asgn); + assert_eq!(view_1.asgn(), view_2.asgn()); } } @@ -208,11 +208,11 @@ mod tests { for (state_1, state_2) in engine_1.states.iter().zip(engine_2.states.iter()) { - assert_eq!(&state_1.asgn, &state_2.asgn); + assert_eq!(state_1.asgn(), state_2.asgn()); for (view_1, view_2) in state_1.views.iter().zip(state_2.views.iter()) { - assert_eq!(&view_1.asgn, &view_2.asgn); + assert_eq!(view_1.asgn(), view_2.asgn()); } } } diff --git a/lace/src/interface/engine/mod.rs b/lace/src/interface/engine/mod.rs index 015620bd..890fa199 100644 --- a/lace/src/interface/engine/mod.rs +++ b/lace/src/interface/engine/mod.rs @@ -75,7 +75,7 @@ impl HasData for Engine { #[inline] fn summarize_feature(&self, ix: usize) -> SummaryStatistics { let state = &self.states[0]; - let view_ix = state.asgn.asgn[ix]; + let view_ix = state.asgn().asgn[ix]; // XXX: Cloning the data could be very slow state.views[view_ix].ftrs[&ix].clone_data().summarize() } @@ -124,6 +124,26 @@ fn col_models_from_data_src( crate::data::df_to_col_models(codebook, df, rng) } +fn emit_prior_process( + prior_process: crate::codebook::PriorProcess, + rng: &mut R, +) -> lace_stats::prior_process::Process { + use lace_stats::prior_process::{Dirichlet, PitmanYor, Process}; + match prior_process { + crate::codebook::PriorProcess::Dirichlet { alpha_prior } => { + let inner = Dirichlet::from_prior(alpha_prior, rng); + Process::Dirichlet(inner) + } + crate::codebook::PriorProcess::PitmanYor { + alpha_prior, + d_prior, + } => { + let inner = PitmanYor::from_prior(alpha_prior, d_prior, rng); + Process::PitmanYor(inner) + } + } +} + /// Maintains and samples states impl Engine { /// Create a new engine @@ -154,23 +174,23 @@ impl Engine { col_models_from_data_src(codebook, data_source, &mut rng) .map_err(NewEngineError::DataParseError)?; - let state_alpha_prior = codebook - 
.state_alpha_prior - .clone() - .unwrap_or_else(lace_consts::state_alpha_prior); + let state_prior_process = emit_prior_process( + codebook.state_prior_process.clone().unwrap_or_default(), + &mut rng, + ); - let view_alpha_prior = codebook - .view_alpha_prior - .clone() - .unwrap_or_else(lace_consts::view_alpha_prior); + let view_prior_process = emit_prior_process( + codebook.view_prior_process.clone().unwrap_or_default(), + &mut rng, + ); let states: Vec = (0..n_states) .map(|_| { let features = col_models.clone(); State::from_prior( features, - state_alpha_prior.clone(), - view_alpha_prior.clone(), + state_prior_process.clone(), + view_prior_process.clone(), &mut rng, ) }) @@ -837,7 +857,7 @@ impl Engine { Ok(()) } - /// Run the Gibbs reassignment kernel on a specific column and row withing + /// Run the Gibbs reassignment kernel on a specific column and row within /// a view. Used when the user would like to focus more updating on /// specific regions of the table. /// @@ -867,7 +887,7 @@ impl Engine { .for_each(|(state, mut trng)| { state.reassign_col_gibbs(col_ix, true, &mut trng); let view = { - let view_ix = state.asgn.asgn[col_ix]; + let view_ix = state.asgn().asgn[col_ix]; &mut state.views[view_ix] }; @@ -875,8 +895,8 @@ impl Engine { // Make sure the view weights are correct so oracle functions // reflect the update correctly. 
- view.weights = view.asgn.weights(); - debug_assert!(view.asgn.validate().is_valid()); + view.weights = view.prior_process.weight_vec(false); + debug_assert!(view.asgn().validate().is_valid()); }); } diff --git a/lace/src/interface/engine/update_handler.rs b/lace/src/interface/engine/update_handler.rs index 7f709d6e..6d34fab6 100644 --- a/lace/src/interface/engine/update_handler.rs +++ b/lace/src/interface/engine/update_handler.rs @@ -418,7 +418,10 @@ impl UpdateHandler for ProgressBar { sender .lock() .unwrap() - .send((state_id, state.log_prior + state.loglike)) + .send(( + state_id, + state.score.ln_prior + state.score.ln_likelihood, + )) .unwrap(); } } diff --git a/lace/src/interface/oracle/mod.rs b/lace/src/interface/oracle/mod.rs index 042e66ef..82daec43 100644 --- a/lace/src/interface/oracle/mod.rs +++ b/lace/src/interface/oracle/mod.rs @@ -220,8 +220,8 @@ mod tests { fn dummy_codebook_from_state(state: &State) -> Codebook { Codebook { table_name: "my_table".into(), - state_alpha_prior: None, - view_alpha_prior: None, + state_prior_process: None, + view_prior_process: None, col_metadata: (0..state.n_cols()) .map(|ix| { let ftr = state.feature(ix); @@ -647,7 +647,7 @@ mod tests { // draw from appropriate component from that view let mut xs: Vec = Vec::with_capacity(col_ixs.len()); col_ixs.iter().for_each(|col_ix| { - let view_ix = state.asgn.asgn[*col_ix]; + let view_ix = state.asgn().asgn[*col_ix]; let k = cpnt_ixs[&view_ix]; let x = state.views[view_ix].ftrs[col_ix].draw(k, &mut rng); diff --git a/lace/src/interface/oracle/traits.rs b/lace/src/interface/oracle/traits.rs index 11e5772b..6650bff8 100644 --- a/lace/src/interface/oracle/traits.rs +++ b/lace/src/interface/oracle/traits.rs @@ -110,7 +110,7 @@ pub trait OracleT: CanOracle { fn ftype(&self, col_ix: Ix) -> Result { let col_ix = col_ix.col_ix(self.codebook())?; let state = &self.states()[0]; - let view_ix = state.asgn.asgn[col_ix]; + let view_ix = state.asgn().asgn[col_ix]; 
Ok(state.views[view_ix].ftrs[&col_ix].ftype()) } @@ -199,7 +199,7 @@ pub trait OracleT: CanOracle { Ok(1.0) } else { let depprob = self.states().iter().fold(0.0, |acc, state| { - if state.asgn.asgn[col_a] == state.asgn.asgn[col_b] { + if state.asgn().asgn[col_a] == state.asgn().asgn[col_b] { acc + 1.0 } else { acc @@ -351,7 +351,7 @@ pub trait OracleT: CanOracle { let rowsim = self.states().iter().fold(0.0, |acc, state| { let view_ixs: Vec = match wrt.as_ref() { Some(col_ixs) => { - let asgn = &state.asgn.asgn; + let asgn = &state.asgn().asgn; let viewset: BTreeSet = col_ixs.iter().map(|&col_ix| asgn[col_ix]).collect(); viewset.iter().copied().collect() @@ -375,7 +375,7 @@ pub trait OracleT: CanOracle { acc + view_ixs.iter().enumerate().fold( 0.0, |sim, (ix, &view_ix)| { - let asgn = &state.views[view_ix].asgn.asgn; + let asgn = &state.views[view_ix].asgn().asgn; if asgn[row_a] == asgn[row_b] { sim + col_counts.as_ref().map_or(1.0, |cts| cts[ix]) } else { @@ -494,7 +494,7 @@ pub trait OracleT: CanOracle { let compliment = self.states().iter().fold(0.0, |acc, state| { let view_ixs: Vec = match wrt.as_ref() { Some(col_ixs) => { - let asgn = &state.asgn.asgn; + let asgn = &state.asgn().asgn; let viewset: BTreeSet = col_ixs.iter().map(|&col_ix| asgn[col_ix]).collect(); viewset.iter().copied().collect() @@ -503,7 +503,7 @@ pub trait OracleT: CanOracle { }; acc + view_ixs.iter().fold(0.0, |novelty, &view_ix| { - let asgn = &state.views[view_ix].asgn; + let asgn = &state.views[view_ix].asgn(); let z = asgn.asgn[row_ix]; novelty + (asgn.counts[z] as f64) / nf }) / (view_ixs.len() as f64) @@ -1636,8 +1636,8 @@ pub trait OracleT: CanOracle { let state = &self.states()[state_ix]; // Draw from the propoer component in the feature - let view_ix = state.asgn.asgn[col_ix]; - let cpnt_ix = state.views[view_ix].asgn.asgn[row_ix]; + let view_ix = state.asgn().asgn[col_ix]; + let cpnt_ix = state.views[view_ix].asgn().asgn[row_ix]; let ftr = state.feature(col_ix); let x = 
ftr.draw(cpnt_ix, &mut rng); utils::post_process_datum(x, col_ix, self.codebook()) @@ -1817,13 +1817,11 @@ pub trait OracleT: CanOracle { /// # use lace::OracleT; /// # use lace_data::{Datum, Category}; /// # let oracle = Example::Satellites.oracle().unwrap(); - /// let (imp, _) = oracle.impute( + /// let (imp, unc): (Datum, Option) = oracle.impute( /// "X-Sat", /// "longitude_radians_of_geo", /// true, /// ).unwrap(); - /// - /// assert!((imp.to_f64_opt().unwrap() - 0.18514237733859296).abs() < 1e-10); /// ``` fn impute( &self, @@ -1956,7 +1954,7 @@ pub trait OracleT: CanOracle { /// ``` /// /// Note that the uncertainty when the prediction is missing is the - /// uncertainty only off the missing prediction. For example, the + /// uncertainty only of the missing prediction. For example, the /// `longitude_radians_of_geo` value is only present for geosynchronous /// satellites, which have an orbital period of around 1440 minutes. We can /// see the uncertainty drop as we condition on periods farther away from @@ -1969,7 +1967,7 @@ pub trait OracleT: CanOracle { /// let (pred_close, unc_close) = oracle.predict( /// "longitude_radians_of_geo", /// &Given::Conditions(vec![ - /// ("Period_minutes", Datum::Continuous(1200.0)) + /// ("Period_minutes", Datum::Continuous(1400.0)) /// ]), /// true, /// None, @@ -1987,7 +1985,6 @@ pub trait OracleT: CanOracle { /// ).unwrap(); /// /// assert_eq!(pred_far, Datum::Missing); - /// dbg!(&unc_far, &unc_close); /// assert!(unc_far.unwrap() < unc_close.unwrap()); /// ``` fn predict( @@ -2105,7 +2102,7 @@ pub trait OracleT: CanOracle { let mut mixture_types: Vec = states .iter() .map(|state| { - let view_ix = state.asgn.asgn[col_ix]; + let view_ix = state.asgn().asgn[col_ix]; let weights = &utils::given_weights(&[state], &[col_ix], &given)[0]; @@ -2323,8 +2320,8 @@ pub trait OracleT: CanOracle { .iter() .map(|&ix| { let state = &self.states()[ix]; - let view_ix = state.asgn.asgn[col_ix]; - let k = 
state.views[view_ix].asgn.asgn[row_ix]; + let view_ix = state.asgn().asgn[col_ix]; + let k = state.views[view_ix].asgn().asgn[row_ix]; state.views[view_ix].ftrs[&col_ix].cpnt_logp(&x, k) }) .collect(); diff --git a/lace/src/interface/oracle/utils.rs b/lace/src/interface/oracle/utils.rs index d1ef1609..0a39c528 100644 --- a/lace/src/interface/oracle/utils.rs +++ b/lace/src/interface/oracle/utils.rs @@ -193,7 +193,7 @@ impl<'s, R: rand::Rng> Iterator for Simulator<'s, R> { .col_ixs .iter() .map(|col_ix| { - let view_ix = state.asgn.asgn[*col_ix]; + let view_ix = state.asgn().asgn[*col_ix]; let k = cpnt_ixs[&view_ix]; state.views[view_ix].ftrs[col_ix].draw(k, &mut rng) }) @@ -423,7 +423,7 @@ pub fn single_state_weights( let mut view_weights: BTreeMap> = BTreeMap::new(); col_ixs .iter() - .map(|&ix| state.asgn.asgn[ix]) + .map(|&ix| state.asgn().asgn[ix]) .for_each(|view_ix| { view_weights .entry(view_ix) @@ -442,7 +442,7 @@ pub fn single_state_exp_weights( let mut view_weights: BTreeMap> = BTreeMap::new(); col_ixs .iter() - .map(|&ix| state.asgn.asgn[ix]) + .map(|&ix| state.asgn().asgn[ix]) .for_each(|view_ix| { view_weights.entry(view_ix).or_insert_with(|| { single_view_exp_weights(state, view_ix, given) @@ -464,7 +464,8 @@ fn single_view_weights( match given { Given::Conditions(ref conditions) => { for &(col_ix, ref datum) in conditions { - let in_target_view = state.asgn.asgn[col_ix] == target_view_ix; + let in_target_view = + state.asgn().asgn[col_ix] == target_view_ix; if in_target_view { view.ftrs[&col_ix].accum_weights( datum, @@ -493,7 +494,7 @@ fn single_view_exp_weights( match given { Given::Conditions(ref conditions) => { conditions.iter().for_each(|(ix, datum)| { - let in_target_view = state.asgn.asgn[*ix] == target_view_ix; + let in_target_view = state.asgn().asgn[*ix] == target_view_ix; if in_target_view { view.ftrs[ix].accum_exp_weights(datum, &mut weights); } @@ -588,7 +589,7 @@ fn single_val_logp( col_ixs .iter() .zip(val) - .map(|(col_ix, datum)| 
(col_ix, state.asgn.asgn[*col_ix], datum)) + .map(|(col_ix, datum)| (col_ix, state.asgn().asgn[*col_ix], datum)) .for_each(|(col_ix, view_ix, datum)| { state.views[view_ix].ftrs[col_ix].accum_weights( datum, @@ -852,7 +853,7 @@ macro_rules! dep_ind_col_mixtures { _ => panic!("Unexpected MixtureType"), }; - if state.asgn.asgn[$col_a] == state.asgn.asgn[$col_b] { + if state.asgn().asgn[$col_a] == state.asgn().asgn[$col_b] { weight += 1.0; mms_dep.push(mm); } else { @@ -1339,7 +1340,7 @@ pub fn continuous_predict( let mixtures = states .iter() .map(|state| { - let view_ix = state.asgn.asgn[col_ix]; + let view_ix = state.asgn().asgn[col_ix]; // NOTE: There is a slight speedup from using given_exp_weights, // but at the cost of panics when there is a large number of // conditions in the given: underflow causes all the weights to @@ -1506,7 +1507,7 @@ macro_rules! predunc_arm { let mix_models: Vec> = $states .iter() .map(|state| { - let view_ix = state.asgn.asgn[$col_ix]; + let view_ix = state.asgn().asgn[$col_ix]; let weights = single_view_weights(&state, view_ix, $given_opt); let mut mixture: Mixture<$cpnt_type> = @@ -1534,7 +1535,7 @@ pub fn predict_uncertainty( states_ixs_opt: Option<&[usize]>, ) -> f64 { let ftype = { - let view_ix = states[0].asgn.asgn[col_ix]; + let view_ix = states[0].asgn().asgn[col_ix]; states[0].views[view_ix].ftrs[&col_ix].ftype() }; let states = select_states(states, states_ixs_opt); @@ -1562,7 +1563,7 @@ pub(crate) fn mnar_uncertainty( .. }) => { // get the index of the view to which this column is assigned - let view_ix = state.asgn.asgn[col_ix]; + let view_ix = state.asgn().asgn[col_ix]; // Get the weights from the view using the given let weights = { let mut weights = @@ -1623,9 +1624,9 @@ macro_rules! 
impunc_arm { let n_states = $states.len(); let mixtures = (0..n_states) .map(|state_ix| { - let view_ix = $states[state_ix].asgn.asgn[$col_ix]; + let view_ix = $states[state_ix].asgn().asgn[$col_ix]; let view = &$states[state_ix].views[view_ix]; - let k = view.asgn.asgn[$row_ix]; + let k = view.asgn().asgn[$row_ix]; match &view.ftrs[&$col_ix] { ColModel::$variant(ref ftr) => ftr.components[k].fx.clone(), ColModel::MissingNotAtRandom( diff --git a/lace/src/prelude.rs b/lace/src/prelude.rs index e6ea31b2..e01bf563 100644 --- a/lace/src/prelude.rs +++ b/lace/src/prelude.rs @@ -12,7 +12,6 @@ pub use crate::data::DataSource; pub use lace_cc::{ alg::{ColAssignAlg, RowAssignAlg}, - assignment::AssignmentBuilder, config::StateUpdateConfig, feature::{Column, FType}, state::State, @@ -23,6 +22,7 @@ pub use lace_codebook::{ Codebook, CodebookError, ColMetadata, ColMetadataList, ColType, }; pub use lace_metadata::SerializedType; +pub use lace_stats::assignment::Assignment; pub use lace_stats::prior::{csd::CsdHyper, nix::NixHyper, pg::PgHyper}; pub use lace_stats::rv; pub use lace_utils as utils; diff --git a/lace/tests/engine.rs b/lace/tests/engine.rs index d287a881..77bf97cc 100644 --- a/lace/tests/engine.rs +++ b/lace/tests/engine.rs @@ -55,7 +55,7 @@ fn loaded_engine_should_have_same_rng_state() { engine_2.run(5).unwrap(); for (s1, s2) in engine_1.states.iter().zip(engine_2.states.iter()) { - assert_eq!(s1.asgn.asgn, s2.asgn.asgn); + assert_eq!(s1.asgn().asgn, s2.asgn().asgn); } } @@ -152,7 +152,7 @@ fn run_engine_after_flatten_cols_smoke_test() { engine.run(1).unwrap(); } -mod contructor { +mod constructor { use super::*; use lace::error::{DataParseError, NewEngineError}; use lace_codebook::{ColMetadata, ColType}; @@ -229,7 +229,7 @@ fn engine_build_without_flat_col_is_not_flat() { let path = animals_data_path(); let df = lace_codebook::data::read_csv(path).unwrap(); let engine = EngineBuilder::new(DataSource::Polars(df)) - .with_nstates(8) + .with_nstates(32) .build() 
.unwrap(); assert!(engine.states.iter().any(|state| state.n_views() > 1)); @@ -1336,9 +1336,9 @@ mod insert_data { n_iters: 10, transitions: vec![ StateTransition::ColumnAssignment(ColAssignAlg::Gibbs), - StateTransition::StateAlpha, + StateTransition::StatePriorProcessParams, StateTransition::RowAssignment(RowAssignAlg::Gibbs), - StateTransition::ViewAlphas, + StateTransition::ViewPriorProcessParams, StateTransition::FeaturePriors, ], ..Default::default() @@ -1432,9 +1432,9 @@ mod insert_data { n_iters: 10, transitions: vec![ StateTransition::ColumnAssignment(ColAssignAlg::Gibbs), - StateTransition::StateAlpha, + StateTransition::StatePriorProcessParams, StateTransition::RowAssignment(RowAssignAlg::Gibbs), - StateTransition::ViewAlphas, + StateTransition::ViewPriorProcessParams, StateTransition::FeaturePriors, ], ..Default::default() @@ -1754,8 +1754,8 @@ mod insert_data { EngineUpdateConfig { n_iters: 2, transitions: vec![ - StateTransition::StateAlpha, - StateTransition::ViewAlphas, + StateTransition::StatePriorProcessParams, + StateTransition::ViewPriorProcessParams, StateTransition::ComponentParams, StateTransition::FeaturePriors, StateTransition::RowAssignment( @@ -2147,9 +2147,9 @@ mod insert_data { n_iters: 2, transitions: vec![ StateTransition::ColumnAssignment($col_kernel), - StateTransition::StateAlpha, + StateTransition::StatePriorProcessParams, StateTransition::RowAssignment($row_kernel), - StateTransition::ViewAlphas, + StateTransition::ViewPriorProcessParams, StateTransition::FeaturePriors, ], ..Default::default() @@ -2630,19 +2630,19 @@ mod prior_in_codebook { use lace_cc::feature::ColModel; use lace_codebook::{Codebook, ColMetadata, ColMetadataList, ColType}; use lace_stats::prior::nix::NixHyper; - use lace_stats::rv::dist::{Gamma, NormalInvChiSquared}; + use lace_stats::rv::dist::NormalInvChiSquared; use lace_stats::rv::traits::Rv; use std::convert::TryInto; use std::io::Write; - // Generate a two-column codebook ('x' and 'y'). 
The x column will alyways + // Generate a two-column codebook ('x' and 'y'). The x column will always // have a hyper for the x column, but will have a prior defined if set_prior // is true. The y column will have neither a prior or hyper defined. fn gen_codebook(n_rows: usize, set_prior: bool) -> Codebook { Codebook { table_name: String::from("table"), - state_alpha_prior: Some(Gamma::default()), - view_alpha_prior: Some(Gamma::default()), + state_prior_process: None, + view_prior_process: None, col_metadata: { let mut col_metadata = ColMetadataList::new(vec![]).unwrap(); col_metadata @@ -2691,12 +2691,12 @@ mod prior_in_codebook { " --- table_name: table - state_alpha_prior: - !Gamma + state_prior_process: !dirichlet + alpha_prior: shape: 1.0 rate: 1.0 - view_alpha_prior: - !Gamma + view_prior_process: !dirichlet + alpha_prior: shape: 1.0 rate: 1.0 col_metadata: diff --git a/lace/tests/feature.rs b/lace/tests/feature.rs index cf62ddc4..fa38a6da 100644 --- a/lace/tests/feature.rs +++ b/lace/tests/feature.rs @@ -3,14 +3,15 @@ extern crate approx; use std::f64::consts::LN_2; -use lace_cc::assignment::{Assignment, AssignmentBuilder}; use lace_cc::component::ConjugateComponent; use lace_cc::feature::{Column, Feature}; use lace_data::SparseContainer; +use lace_stats::assignment::Assignment; use lace_stats::prior::csd::CsdHyper; use lace_stats::prior::nix::NixHyper; +use lace_stats::prior_process::Builder as AssignmentBuilder; use lace_stats::rv::dist::{ - Categorical, Gamma, Gaussian, NormalInvChiSquared, SymmetricDirichlet, + Categorical, Gaussian, NormalInvChiSquared, SymmetricDirichlet, }; use lace_stats::rv::traits::Rv; use rand::Rng; @@ -67,7 +68,7 @@ fn three_component_column() -> GaussCol { #[test] fn feature_with_flat_assign_should_have_one_component() { let mut rng = rand::thread_rng(); - let asgn = AssignmentBuilder::new(5).flat().build().unwrap(); + let asgn = AssignmentBuilder::new(5).flat().build().unwrap().asgn; let col = gauss_fixture(&mut rng, &asgn); 
@@ -78,7 +79,7 @@ fn feature_with_flat_assign_should_have_one_component() { fn feature_with_random_assign_should_have_k_component() { let mut rng = rand::thread_rng(); for _ in 0..50 { - let asgn = AssignmentBuilder::new(5).build().unwrap(); + let asgn = AssignmentBuilder::new(5).build().unwrap().asgn; let col = gauss_fixture(&mut rng, &asgn); assert_eq!(col.components.len(), asgn.n_cats); @@ -90,7 +91,7 @@ fn feature_with_random_assign_should_have_k_component() { #[test] fn append_empty_component_appends_one() { let mut rng = rand::thread_rng(); - let asgn = AssignmentBuilder::new(5).flat().build().unwrap(); + let asgn = AssignmentBuilder::new(5).flat().build().unwrap().asgn; let mut col = gauss_fixture(&mut rng, &asgn); assert_eq!(col.components.len(), 1); @@ -103,13 +104,11 @@ fn append_empty_component_appends_one() { #[test] fn reassign_to_more_components() { let mut rng = rand::thread_rng(); - let asgn_a = AssignmentBuilder::new(5).flat().build().unwrap(); + let asgn_a = AssignmentBuilder::new(5).flat().build().unwrap().asgn; let asgn_b = Assignment { - alpha: 1.0, asgn: vec![0, 0, 0, 1, 1], counts: vec![3, 2], n_cats: 2, - prior: Gamma::new(1.0, 1.0).unwrap(), }; let mut col = gauss_fixture(&mut rng, &asgn_a); @@ -237,13 +236,11 @@ fn gauss_accum_scores_2_cats_no_missing() { #[test] fn asgn_score_under_asgn_gaussian_magnitude() { let mut rng = rand::thread_rng(); - let asgn_a = AssignmentBuilder::new(5).flat().build().unwrap(); + let asgn_a = AssignmentBuilder::new(5).flat().build().unwrap().asgn; let asgn_b = Assignment { - alpha: 1.0, asgn: vec![0, 0, 0, 1, 1], counts: vec![3, 2], n_cats: 2, - prior: Gamma::new(1.0, 1.0).unwrap(), }; let col = gauss_fixture(&mut rng, &asgn_a); @@ -339,13 +336,11 @@ fn cat_u8_accum_scores_2_cats_no_missing() { #[test] fn asgn_score_under_asgn_cat_u8_magnitude() { let mut rng = rand::thread_rng(); - let asgn_a = AssignmentBuilder::new(5).flat().build().unwrap(); + let asgn_a = 
AssignmentBuilder::new(5).flat().build().unwrap().asgn; let asgn_b = Assignment { - alpha: 1.0, asgn: vec![0, 1, 1, 0, 1], counts: vec![2, 3], n_cats: 2, - prior: Gamma::new(1.0, 1.0).unwrap(), }; let col = categorical_fixture_u8(&mut rng, &asgn_a); @@ -364,9 +359,9 @@ fn asgn_score_under_asgn_cat_u8_magnitude() { // Gaussian // -------- #[test] -fn update_componet_params_should_draw_different_values_for_gaussian() { +fn update_component_params_should_draw_different_values_for_gaussian() { let mut rng = rand::thread_rng(); - let asgn = AssignmentBuilder::new(5).flat().build().unwrap(); + let asgn = AssignmentBuilder::new(5).flat().build().unwrap().asgn; let mut col = gauss_fixture(&mut rng, &asgn); let cpnt_a = col.components[0].clone(); @@ -390,7 +385,7 @@ fn asgn_score_should_be_the_same_as_score_given_current_asgn() { let mut col = Column::new(0, data, prior, hyper.clone()); - let asgn = AssignmentBuilder::new(n).flat().build().unwrap(); + let asgn = AssignmentBuilder::new(n).flat().build().unwrap().asgn; let asgn_score = col.asgn_score(&asgn); col.reassign(&asgn, &mut rng); diff --git a/lace/tests/oracle.rs b/lace/tests/oracle.rs index 1a1954f1..b024057b 100644 --- a/lace/tests/oracle.rs +++ b/lace/tests/oracle.rs @@ -35,12 +35,13 @@ fn gen_all_gauss_state( for i in 0..n_cols { ftrs.push(gen_col(i, n_rows, &mut rng)); } - State::from_prior( - ftrs, - Gamma::new(1.0, 1.0).unwrap(), - Gamma::new(1.0, 1.0).unwrap(), - &mut rng, - ) + let process = lace_stats::prior_process::Process::Dirichlet( + lace_stats::prior_process::Dirichlet::from_prior( + Gamma::default(), + &mut rng, + ), + ); + State::from_prior(ftrs, process.clone(), process, &mut rng) } fn load_states>(filenames: Vec

) -> Vec { @@ -63,8 +64,8 @@ fn dummy_codebook_from_state(state: &State) -> Codebook { Codebook { table_name: "my_table".into(), - state_alpha_prior: None, - view_alpha_prior: None, + state_prior_process: None, + view_prior_process: None, col_metadata: (0..state.n_cols()) .map(|ix| { let ftr = state.feature(ix); diff --git a/lace/tests/streaming_insert.rs b/lace/tests/streaming_insert.rs index 458bd86a..85e61299 100644 --- a/lace/tests/streaming_insert.rs +++ b/lace/tests/streaming_insert.rs @@ -6,7 +6,6 @@ use lace_codebook::{Codebook, ColMetadata, ColType}; use lace_data::Datum; use lace_stats::prior::nix::NixHyper; -use lace_stats::rv::dist::Gamma; use rand::{Rng, SeedableRng}; use rand_xoshiro::Xoshiro256Plus; @@ -63,8 +62,8 @@ fn gen_engine() -> Engine { let codebook = Codebook { table_name: "table".into(), - state_alpha_prior: Some(Gamma::default()), - view_alpha_prior: Some(Gamma::default()), + state_prior_process: None, + view_prior_process: None, col_metadata: (0..14) .map(|i| ColMetadata { name: format!("{}", i), diff --git a/lace/tests/workflow.rs b/lace/tests/workflow.rs index af7d5e2d..6764d633 100644 --- a/lace/tests/workflow.rs +++ b/lace/tests/workflow.rs @@ -35,7 +35,8 @@ fn default_csv_workflow() { let file = datafile(); // default codebook - let codebook = codebook_from_csv(file.path(), None, None, false).unwrap(); + let codebook = + codebook_from_csv(file.path(), None, None, None, false).unwrap(); let rng = rand_xoshiro::Xoshiro256Plus::from_entropy(); let mut engine = Engine::new( 4, @@ -56,7 +57,7 @@ fn satellites_csv_workflow() { // default codebook let codebook = - codebook_from_csv(path.as_path(), None, None, false).unwrap(); + codebook_from_csv(path.as_path(), None, None, None, false).unwrap(); let mut engine: Engine = EngineBuilder::new(DataSource::Csv(path)) .codebook(codebook) diff --git a/pylace/Cargo.lock b/pylace/Cargo.lock index 6cc3342e..1b6d316c 100644 --- a/pylace/Cargo.lock +++ b/pylace/Cargo.lock @@ -4,9 +4,9 @@ version = 3 
[[package]] name = "ahash" -version = "0.8.7" +version = "0.8.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77c3a9648d43b9cd48db467b3f87fdd6e146bcc88ab0180006cef2179fe11d01" +checksum = "e89da841a80418a9b391ebaea17f5c112ffaaa96f621d2c285b5174da76b9011" dependencies = [ "cfg-if", "getrandom", @@ -17,18 +17,18 @@ dependencies = [ [[package]] name = "aho-corasick" -version = "1.1.2" +version = "1.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b2969dcb958b36655471fc61f7e416fa76033bdd4bfed0678d8fee1e2d07a1f0" +checksum = "8e60d3430d3a69478ad0993f19238d2df97c507009a52b3c10addcd7f6bcb916" dependencies = [ "memchr", ] [[package]] name = "allocator-api2" -version = "0.2.16" +version = "0.2.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0942ffc6dcaadf03badf6e6a2d0228460359d5e34b57ccdc720b7382dfbd5ec5" +checksum = "5c6cb57a04249c6480766f7f7cef5467412af1490f8d1e243141daddada3264f" [[package]] name = "android-tzdata" @@ -56,9 +56,9 @@ dependencies = [ [[package]] name = "argminmax" -version = "0.6.1" +version = "0.6.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "202108b46429b765ef483f8a24d5c46f48c14acfdacc086dd4ab6dddf6bcdbd2" +checksum = "52424b59d69d69d5056d508b260553afd91c57e21849579cd1f50ee8b8b88eaa" dependencies = [ "num-traits", ] @@ -96,9 +96,9 @@ checksum = "9ae037714f313c1353189ead58ef9eec30a8e8dc101b2622d461418fd59e28a9" [[package]] name = "autocfg" -version = "1.1.0" +version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" +checksum = "f1fdabc7756949593fe60f30ec81974b613357de856987752631dea1e3394c80" [[package]] name = "bincode" @@ -111,56 +111,51 @@ dependencies = [ [[package]] name = "bitflags" -version = "1.3.2" +version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" - -[[package]] -name = "bitflags" -version = "2.4.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ed570934406eb16438a4e976b1b4500774099c13b8cb96eec99f620f05090ddf" +checksum = "cf4b9d6a944f767f8e5e0db018570623c85f3d925ac718db4e06d0187adb21c1" [[package]] name = "bumpalo" -version = "3.14.0" +version = "3.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7f30e7476521f6f8af1a1c4c0b8cc94f0bee37d91763d0ca2665f299b6cd8aec" +checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c" [[package]] name = "bytemuck" -version = "1.14.0" +version = "1.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "374d28ec25809ee0e23827c2ab573d729e293f281dfe393500e7ad618baa61c6" +checksum = "5d6d68c57235a3a081186990eca2867354726650f42f7516ca50c28d6281fd15" dependencies = [ "bytemuck_derive", ] [[package]] name = "bytemuck_derive" -version = "1.5.0" +version = "1.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "965ab7eb5f8f97d2a083c799f3a1b994fc397b2fe2da5d1da1626ce15a39f2b1" +checksum = "4da9a32f3fed317401fa3c862968128267c3106685286e15d5aaa3d7389c2f60" dependencies = [ "proc-macro2", "quote", - "syn 2.0.48", + "syn 2.0.60", ] [[package]] name = "bytes" -version = "1.5.0" +version = "1.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a2bd12c1caf447e69cd4528f47f94d203fd2582878ecb9e9465484c4148a8223" +checksum = "514de17de45fdb8dc022b1a7975556c53c86f9f0aa5f534b98977b171857c2c9" [[package]] name = "cc" -version = "1.0.83" +version = "1.0.96" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f1174fb0b6ec23863f8b971027804a42614e347eafb0a95bf0b12cdae21fc4d0" +checksum = "065a29261d53ba54260972629f9ca6bffa69bac13cd1fed61420f7fa68b9f8bd" dependencies = [ "jobserver", "libc", + "once_cell", ] [[package]] @@ -171,25 
+166,25 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "chrono" -version = "0.4.31" +version = "0.4.38" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7f2c685bad3eb3d45a01354cedb7d5faa66194d1d58ba6e267a8de788f79db38" +checksum = "a21f936df1771bf62b77f047b726c4625ff2e8aa607c01ec06e5a05bd8463401" dependencies = [ "android-tzdata", "iana-time-zone", "num-traits", - "windows-targets 0.48.5", + "windows-targets 0.52.5", ] [[package]] name = "comfy-table" -version = "7.1.0" +version = "7.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7c64043d6c7b7a4c58e39e7efccfdea7b93d885a795d0c054a69dbbf4dd52686" +checksum = "b34115915337defe99b2aff5c2ce6771e5fbc4079f4b506301f5cf394c8452f7" dependencies = [ "crossterm", "strum", - "strum_macros", + "strum_macros 0.26.2", "unicode-width", ] @@ -214,9 +209,9 @@ checksum = "06ea2b9bc92be3c2baa9334a323ebca2d6f074ff852cd1d7b11064035cd3868f" [[package]] name = "crossbeam-channel" -version = "0.5.11" +version = "0.5.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "176dc175b78f56c0f321911d9c8eb2b77a78a4860b9c19db83835fea1a46649b" +checksum = "ab3db02a9c5b5121e1e42fbdb1aeb65f5e02624cc58c43f2884c6ccac0b82f95" dependencies = [ "crossbeam-utils", ] @@ -261,7 +256,7 @@ version = "0.27.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f476fe445d41c9e991fd07515a6f463074b782242ccf4a5b7b1d1012e70824df" dependencies = [ - "bitflags 2.4.2", + "bitflags", "crossterm_winapi", "libc", "parking_lot", @@ -306,15 +301,15 @@ checksum = "fea41bba32d969b513997752735605054bc0dfa92b4c56bf1189f2e174be7a10" [[package]] name = "dyn-clone" -version = "1.0.16" +version = "1.0.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "545b22097d44f8a9581187cdf93de7a71e4722bf51200cfaba810865b49a495d" +checksum = 
"0d6ef0072f8a535281e4876be788938b528e9a1d43900b82c2569af7da799125" [[package]] name = "either" -version = "1.9.0" +version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a26ae43d7bcc3b814de94796a5e736d4029efb0ee900c12e2d54c993ad1a1e07" +checksum = "a47c1c47d2f5964e29c61246e81db715514cd532db6b5116a25ea3c03d6780a2" [[package]] name = "encode_unicode" @@ -324,14 +319,14 @@ checksum = "a357d28ed41a50f9c765dbfe56cbc04a64e53e5fc58ba79fbc34c10ef3df831f" [[package]] name = "enum_dispatch" -version = "0.3.12" +version = "0.3.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f33313078bb8d4d05a2733a94ac4c2d8a0df9a2b84424ebf4f33bfc224a890e" +checksum = "aa18ce2bc66555b3218614519ac839ddb759a7d6720732f979ef8d13be147ecd" dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.48", + "syn 2.0.60", ] [[package]] @@ -360,9 +355,9 @@ checksum = "ee1b05cbd864bcaecbd3455d6d967862d446e4ebfc3c2e5e5b9841e53cba6673" [[package]] name = "getrandom" -version = "0.2.12" +version = "0.2.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "190092ea657667030ac6a35e305e62fc4dd69fd98ac98631e5d3a2b1575a12b5" +checksum = "94b22e06ecb0110981051723910cbf0b5f5e09a2062dd7663334ee79a9d1286c" dependencies = [ "cfg-if", "js-sys", @@ -388,9 +383,9 @@ dependencies = [ [[package]] name = "hashbrown" -version = "0.14.3" +version = "0.14.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "290f1a1d9242c78d09ce40a5e87e7554ee637af1351968159f4952f028f75604" +checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" dependencies = [ "ahash", "allocator-api2", @@ -405,9 +400,9 @@ checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" [[package]] name = "hermit-abi" -version = "0.3.3" +version = "0.3.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"d77f7ec81a6d05a3abb01ab6eb7590f6083d08449fe5a1c8b1e620283546ccb7" +checksum = "d231dfb89cfffdbc30e7fc41579ed6066ad03abda9e567ccafae602b97ec5024" [[package]] name = "hex" @@ -426,9 +421,9 @@ dependencies = [ [[package]] name = "iana-time-zone" -version = "0.1.59" +version = "0.1.60" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b6a67363e2aa4443928ce15e57ebae94fd8949958fd1223c4cfc0cd473ad7539" +checksum = "e7ffbb5a1b541ea2561f8c41c087286cc091e21e556a4f09a8f6cbf17b69b141" dependencies = [ "android_system_properties", "core-foundation-sys", @@ -449,20 +444,20 @@ dependencies = [ [[package]] name = "indexmap" -version = "2.1.0" +version = "2.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d530e1a18b1cb4c484e6e34556a0d948706958449fca0cab753d649f2bce3d1f" +checksum = "168fb715dda47215e360912c096649d23d58bf392ac62f73919e831745e40f26" dependencies = [ "equivalent", - "hashbrown 0.14.3", + "hashbrown 0.14.5", "serde", ] [[package]] name = "indicatif" -version = "0.17.7" +version = "0.17.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fb28741c9db9a713d93deb3bb9515c20788cef5815265bee4980e87bde7e0f25" +checksum = "763a5a8f45087d6bcea4222e7b72c291a054edf80e4ef6efd2a4979878c7bea3" dependencies = [ "console", "instant", @@ -473,9 +468,9 @@ dependencies = [ [[package]] name = "indoc" -version = "2.0.4" +version = "2.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e186cfbae8084e513daff4240b4797e342f988cecda4fb6c939150f96315fd8" +checksum = "b248f5224d1d606005e02c97f5aa4e88eeb230488bcc03bc9ca4d7991399f2b5" [[package]] name = "instant" @@ -488,40 +483,40 @@ dependencies = [ [[package]] name = "itertools" -version = "0.12.0" +version = "0.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "25db6b064527c5d482d0423354fcd07a89a2dfe07b67892e62411946db7f07b0" +checksum = 
"ba291022dbbd398a455acf126c1e341954079855bc60dfdda641363bd6922569" dependencies = [ "either", ] [[package]] name = "itoa" -version = "1.0.10" +version = "1.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b1a46d1a171d865aa5f83f92695765caa047a9b4cbae2cbf37dbd613a793fd4c" +checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b" [[package]] name = "jobserver" -version = "0.1.27" +version = "0.1.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8c37f63953c4c63420ed5fd3d6d398c719489b9f872b9fa683262f8edd363c7d" +checksum = "d2b099aaa34a9751c5bf0878add70444e1ed2dd73f347be99003d4577277de6e" dependencies = [ "libc", ] [[package]] name = "js-sys" -version = "0.3.67" +version = "0.3.69" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9a1d36f1235bc969acba30b7f5990b864423a6068a10f7c90ae8f0112e3a59d1" +checksum = "29c15563dc2726973df627357ce0c9ddddbea194836909d655df6a75d2cf296d" dependencies = [ "wasm-bindgen", ] [[package]] name = "lace" -version = "0.7.0" +version = "0.8.0" dependencies = [ "dirs", "indexmap", @@ -549,7 +544,7 @@ dependencies = [ [[package]] name = "lace_cc" -version = "0.6.0" +version = "0.7.0" dependencies = [ "enum_dispatch", "itertools", @@ -569,7 +564,7 @@ dependencies = [ [[package]] name = "lace_codebook" -version = "0.6.0" +version = "0.7.0" dependencies = [ "lace_consts", "lace_data", @@ -599,7 +594,7 @@ dependencies = [ [[package]] name = "lace_geweke" -version = "0.3.0" +version = "0.4.0" dependencies = [ "indicatif", "lace_stats", @@ -611,7 +606,7 @@ dependencies = [ [[package]] name = "lace_metadata" -version = "0.6.0" +version = "0.7.0" dependencies = [ "bincode", "hex", @@ -630,7 +625,7 @@ dependencies = [ [[package]] name = "lace_stats" -version = "0.3.0" +version = "0.4.0" dependencies = [ "itertools", "lace_consts", @@ -658,9 +653,9 @@ checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" [[package]] 
name = "libc" -version = "0.2.152" +version = "0.2.154" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "13e3bf6590cbc649f4d1a3eefc9d5d6eb746f5200ffb04e5e142700b8faa56e7" +checksum = "ae743338b92ff9146ce83992f766a31066a91a8c84a45e0e9f21e7cf6de6d346" [[package]] name = "libm" @@ -670,20 +665,19 @@ checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058" [[package]] name = "libredox" -version = "0.0.1" +version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "85c833ca1e66078851dba29046874e38f08b2c883700aa29a03ddd3b23814ee8" +checksum = "c0ff37bd590ca25063e35af745c343cb7a0271906fb7b37e4813e8f79f00268d" dependencies = [ - "bitflags 2.4.2", + "bitflags", "libc", - "redox_syscall", ] [[package]] name = "lock_api" -version = "0.4.11" +version = "0.4.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3c168f8615b12bc01f9c17e2eb0cc07dcae1940121185446edc3744920e8ef45" +checksum = "07af8b9cdd281b7915f413fa73f29ebd5d55d0d3f0155584dade1ff18cea1b17" dependencies = [ "autocfg", "scopeguard", @@ -691,9 +685,9 @@ dependencies = [ [[package]] name = "log" -version = "0.4.20" +version = "0.4.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b5e6163cb8c49088c2c36f57875e58ccd8c87c7427f7fbd50ea6710b2f3f2e8f" +checksum = "90ed8c1e510134f979dbc4f070f87d4313098b704861a105fe34231c70a3901c" [[package]] name = "lru" @@ -745,9 +739,9 @@ dependencies = [ [[package]] name = "memchr" -version = "2.7.1" +version = "2.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "523dc4f511e55ab87b694dc30d0f820d60906ef06413f93d4d7a1385599cc149" +checksum = "6c8640c5d730cb13ebd907d8d04b52f55ac9a2eec55b440c8892f40d56c76c1d" [[package]] name = "memmap2" @@ -760,18 +754,18 @@ dependencies = [ [[package]] name = "memoffset" -version = "0.9.0" +version = "0.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum 
= "5a634b1c61a95585bd15607c6ab0c4e5b226e695ff2800ba0cdccddf208c406c" +checksum = "488016bfae457b036d996092f6cb448677611ce4449e970ceaf42695203f218a" dependencies = [ "autocfg", ] [[package]] name = "multiversion" -version = "0.7.3" +version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b2c7b9d7fe61760ce5ea19532ead98541f6b4c495d87247aff9826445cf6872a" +checksum = "c4851161a11d3ad0bf9402d90ffc3967bf231768bfd7aeb61755ad06dbf1a142" dependencies = [ "multiversion-macros", "target-features", @@ -779,9 +773,9 @@ dependencies = [ [[package]] name = "multiversion-macros" -version = "0.7.3" +version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26a83d8500ed06d68877e9de1dde76c1dbb83885dcdbda4ef44ccbc3fbda2ac8" +checksum = "79a74ddee9e0c27d2578323c13905793e91622148f138ba29738f9dddb835e90" dependencies = [ "proc-macro2", "quote", @@ -791,9 +785,9 @@ dependencies = [ [[package]] name = "nalgebra" -version = "0.32.3" +version = "0.32.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "307ed9b18cc2423f29e83f84fd23a8e73628727990181f18641a8b5dc2ab1caa" +checksum = "3ea4908d4f23254adda3daa60ffef0f1ac7b8c3e9a864cf3cc154b251908a2ef" dependencies = [ "approx", "matrixmultiply", @@ -837,9 +831,9 @@ dependencies = [ [[package]] name = "num" -version = "0.4.1" +version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b05180d69e3da0e530ba2a1dae5110317e49e3b7f3d41be227dc5f92e49ee7af" +checksum = "3135b08af27d103b0a51f2ae0f8632117b7b185ccf931445affa8df530576a41" dependencies = [ "num-bigint", "num-complex", @@ -862,9 +856,9 @@ dependencies = [ [[package]] name = "num-complex" -version = "0.4.4" +version = "0.4.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1ba157ca0885411de85d6ca030ba7e2a83a28636056c7c699b07c8b6f7383214" +checksum = "23c6602fda94a57c990fe0df199a035d83576b496aa29f4e634a8ac6004e68a6" dependencies 
= [ "num-traits", "serde", @@ -872,19 +866,18 @@ dependencies = [ [[package]] name = "num-integer" -version = "0.1.45" +version = "0.1.46" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "225d3389fb3509a24c93f5c29eb6bde2586b98d9f016636dff58d7c6f7569cd9" +checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" dependencies = [ - "autocfg", "num-traits", ] [[package]] name = "num-iter" -version = "0.1.43" +version = "0.1.44" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7d03e6c028c5dc5cac6e2dec0efda81fc887605bb3d884578bb6d6bf7514e252" +checksum = "d869c01cc0c455284163fd0092f1f93835385ccab5a98a0dcc497b2f8bf055a9" dependencies = [ "autocfg", "num-integer", @@ -905,9 +898,9 @@ dependencies = [ [[package]] name = "num-traits" -version = "0.2.17" +version = "0.2.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "39e3200413f237f41ab11ad6d161bc7239c84dcb631773ccd7de3dfe4b5c267c" +checksum = "da0df0e5185db44f69b44f26786fe401b6c293d1907744beaa7fa62b2e5a517a" dependencies = [ "autocfg", "libm", @@ -949,9 +942,9 @@ checksum = "efa535d5117d3661134dbf1719b6f0ffe06f2375843b13935db186cd094105eb" [[package]] name = "parking_lot" -version = "0.12.1" +version = "0.12.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3742b2c103b9f06bc9fff0a37ff4912935851bee6d36f3c02bcc755bcfec228f" +checksum = "7e4af0ca4f6caed20e900d564c242b8e5d4903fdacf31d3daf527b66fe6f42fb" dependencies = [ "lock_api", "parking_lot_core", @@ -959,15 +952,15 @@ dependencies = [ [[package]] name = "parking_lot_core" -version = "0.9.9" +version = "0.9.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4c42a9226546d68acdd9c0a280d17ce19bfe27a46bf68784e4066115788d008e" +checksum = "1e401f977ab385c9e4e3ab30627d6f26d00e2c73eef317493c4ec6d468726cf8" dependencies = [ "cfg-if", "libc", "redox_syscall", "smallvec", - "windows-targets 0.48.5", + 
"windows-targets 0.52.5", ] [[package]] @@ -1008,9 +1001,9 @@ dependencies = [ [[package]] name = "pkg-config" -version = "0.3.28" +version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "69d3587f8a9e599cc7ec2c00e331f71c4e69a5f9a4b8a6efd5b07466b9736f9a" +checksum = "d231b230927b5e4ad203db57bbcbee2802f6bce620b1e4a9024a07d94e2907ec" [[package]] name = "planus" @@ -1054,7 +1047,7 @@ dependencies = [ "fast-float", "foreign_vec", "getrandom", - "hashbrown 0.14.3", + "hashbrown 0.14.5", "itoa", "lz4", "multiversion", @@ -1089,12 +1082,12 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d0f5efe734b6cbe5f97ea769be8360df5324fade396f1f3f5ad7fe9360ca4a23" dependencies = [ "ahash", - "bitflags 2.4.2", + "bitflags", "bytemuck", "chrono", "comfy-table", "either", - "hashbrown 0.14.3", + "hashbrown 0.14.5", "indexmap", "num-traits", "once_cell", @@ -1162,7 +1155,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9d7105b40905bb38e8fc4a7fd736594b7491baa12fad3ac492969ca221a1b5d5" dependencies = [ "ahash", - "bitflags 2.4.2", + "bitflags", "glob", "once_cell", "polars-arrow", @@ -1188,7 +1181,7 @@ dependencies = [ "argminmax", "bytemuck", "either", - "hashbrown 0.14.3", + "hashbrown 0.14.5", "indexmap", "memchr", "num-traits", @@ -1212,7 +1205,7 @@ dependencies = [ "crossbeam-channel", "crossbeam-queue", "enum_dispatch", - "hashbrown 0.14.3", + "hashbrown 0.14.5", "num-traits", "polars-arrow", "polars-compute", @@ -1246,7 +1239,7 @@ dependencies = [ "rayon", "regex", "smartstring", - "strum_macros", + "strum_macros 0.25.3", "version_check", ] @@ -1305,7 +1298,7 @@ checksum = "b174ca4a77ad47d7b91a0460aaae65bbf874c8bfbaaa5308675dadef3976bbda" dependencies = [ "ahash", "bytemuck", - "hashbrown 0.14.3", + "hashbrown 0.14.5", "indexmap", "num-traits", "once_cell", @@ -1330,22 +1323,22 @@ checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" [[package]] name = 
"proc-macro2" -version = "1.0.76" +version = "1.0.81" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "95fc56cda0b5c3325f5fbbd7ff9fda9e02bb00bb3dac51252d2f1bfa1cb8cc8c" +checksum = "3d1597b0c024618f09a9c3b8655b7e430397a36d23fdafec26d6965e9eec3eba" dependencies = [ "unicode-ident", ] [[package]] name = "puruspe" -version = "0.2.0" +version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fe7765e19fb2ba6fd4373b8d90399f5321683ea7c11b598c6bbaa3a72e9c83b8" +checksum = "06a1eed715f625eaa95fba5e049dcf7bc06fa396d6d2e55015b3764e234dfd3f" [[package]] name = "pylace" -version = "0.7.1" +version = "0.8.0" dependencies = [ "bincode", "lace", @@ -1362,15 +1355,16 @@ dependencies = [ [[package]] name = "pyo3" -version = "0.20.2" +version = "0.21.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9a89dc7a5850d0e983be1ec2a463a171d20990487c3cfcd68b5363f1ee3d6fe0" +checksum = "a5e00b96a521718e08e03b1a622f01c8a8deb50719335de3f60b3b3950f069d8" dependencies = [ "cfg-if", "indoc", "libc", "memoffset", "parking_lot", + "portable-atomic", "pyo3-build-config", "pyo3-ffi", "pyo3-macros", @@ -1379,9 +1373,9 @@ dependencies = [ [[package]] name = "pyo3-build-config" -version = "0.20.2" +version = "0.21.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "07426f0d8fe5a601f26293f300afd1a7b1ed5e78b2a705870c5f30893c5163be" +checksum = "7883df5835fafdad87c0d888b266c8ec0f4c9ca48a5bed6bbb592e8dedee1b50" dependencies = [ "once_cell", "target-lexicon", @@ -1389,9 +1383,9 @@ dependencies = [ [[package]] name = "pyo3-ffi" -version = "0.20.2" +version = "0.21.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dbb7dec17e17766b46bca4f1a4215a85006b4c2ecde122076c562dd058da6cf1" +checksum = "01be5843dc60b916ab4dad1dca6d20b9b4e6ddc8e15f50c47fe6d85f1fb97403" dependencies = [ "libc", "pyo3-build-config", @@ -1399,33 +1393,34 @@ dependencies = [ [[package]] 
name = "pyo3-macros" -version = "0.20.2" +version = "0.21.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "05f738b4e40d50b5711957f142878cfa0f28e054aa0ebdfc3fd137a843f74ed3" +checksum = "77b34069fc0682e11b31dbd10321cbf94808394c56fd996796ce45217dfac53c" dependencies = [ "proc-macro2", "pyo3-macros-backend", "quote", - "syn 2.0.48", + "syn 2.0.60", ] [[package]] name = "pyo3-macros-backend" -version = "0.20.2" +version = "0.21.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0fc910d4851847827daf9d6cdd4a823fbdaab5b8818325c5e97a86da79e8881f" +checksum = "08260721f32db5e1a5beae69a55553f56b99bd0e1c3e6e0a5e8851a9d0f5a85c" dependencies = [ "heck", "proc-macro2", + "pyo3-build-config", "quote", - "syn 2.0.48", + "syn 2.0.60", ] [[package]] name = "quote" -version = "1.0.35" +version = "1.0.36" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "291ec9ab5efd934aaf503a6466c5d5251535d108ee747472c3977cc5acc868ef" +checksum = "0fa76aaf39101c457836aec0ce2316dbdc3ab723cdda1c6bd4e6ad4208acaca7" dependencies = [ "proc-macro2", ] @@ -1490,9 +1485,9 @@ checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" [[package]] name = "rayon" -version = "1.8.0" +version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c27db03db7734835b3f53954b534c91069375ce6ccaa2e065441e07d9b6cdb1" +checksum = "b418a60154510ca1a002a752ca9714984e21e4241e804d32555251faf8b78ffa" dependencies = [ "either", "rayon-core", @@ -1500,9 +1495,9 @@ dependencies = [ [[package]] name = "rayon-core" -version = "1.12.0" +version = "1.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5ce3fb6ad83f861aac485e76e1985cd109d9a3713802152be56c3b1f0e0658ed" +checksum = "1465873a3dfdaa8ae7cb14b4383657caab0b3e8a0aa9ae8e04b044854c8dfce2" dependencies = [ "crossbeam-deque", "crossbeam-utils", @@ -1510,18 +1505,18 @@ dependencies = [ [[package]] name = 
"redox_syscall" -version = "0.4.1" +version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4722d768eff46b75989dd134e5c353f0d6296e5aaa3132e776cbdb56be7731aa" +checksum = "469052894dcb553421e483e4209ee581a45100d31b4018de03e5a7ad86374a7e" dependencies = [ - "bitflags 1.3.2", + "bitflags", ] [[package]] name = "redox_users" -version = "0.4.4" +version = "0.4.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a18479200779601e498ada4e8c1e1f50e3ee19deb0259c25825a98b5603b2cb4" +checksum = "bd283d9651eeda4b2a83a43c1c91b266c40fd76ecd39a50a8c630ae69dc72891" dependencies = [ "getrandom", "libredox", @@ -1530,9 +1525,9 @@ dependencies = [ [[package]] name = "regex" -version = "1.10.2" +version = "1.10.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "380b951a9c5e80ddfd6136919eef32310721aa4aacd4889a8d39124b026ab343" +checksum = "c117dbdfde9c8308975b6a18d71f3f385c89461f7b3fb054288ecf2a2058ba4c" dependencies = [ "aho-corasick", "memchr", @@ -1542,9 +1537,9 @@ dependencies = [ [[package]] name = "regex-automata" -version = "0.4.3" +version = "0.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f804c7828047e88b2d32e2d7fe5a105da8ee3264f01902f796c8e067dc2483f" +checksum = "86b83b8b9847f9bf95ef68afb0b8e6cdb80f498442f5179a29fad448fcc1eaea" dependencies = [ "aho-corasick", "memchr", @@ -1553,26 +1548,27 @@ dependencies = [ [[package]] name = "regex-syntax" -version = "0.8.2" +version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c08c74e62047bb2de4ff487b251e4a92e24f48745648451635cec7d591162d9f" +checksum = "adad44e29e4c806119491a7f06f03de4d1af22c3a680dd47f1e6e179439d1f56" [[package]] name = "rustversion" -version = "1.0.14" +version = "1.0.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7ffc183a10b4478d04cbbbfc96d0873219d962dd5accaff2ffbd4ceb7df837f4" +checksum = 
"80af6f9131f277a45a3fba6ce8e2258037bb0477a67e610d3c1fe046ab31de47" [[package]] name = "rv" -version = "0.16.3" +version = "0.16.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "35f602941aca67593b30eea71a0b372e50e3ad63e7aa6b98b2ea18ff74ba9cf8" +checksum = "c07e0a3b756794c7ea2f05d93760ffb946ff4f94b255d92444d94c19fd71f4ab" dependencies = [ "doc-comment", "lru", "nalgebra", "num", + "num-traits", "peroxide", "rand", "rand_distr", @@ -1582,9 +1578,9 @@ dependencies = [ [[package]] name = "ryu" -version = "1.0.16" +version = "1.0.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f98d2aa92eebf49b69786be48e4477826b256916e84a57ff2a4f21923b48eb4c" +checksum = "e86697c916019a8588c99b5fac3cead74ec0b4b819707a682fd4d23fa0ce1ba1" [[package]] name = "safe_arch" @@ -1603,29 +1599,29 @@ checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" [[package]] name = "serde" -version = "1.0.196" +version = "1.0.199" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "870026e60fa08c69f064aa766c10f10b1d62db9ccd4d0abb206472bee0ce3b32" +checksum = "0c9f6e76df036c77cd94996771fb40db98187f096dd0b9af39c6c6e452ba966a" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.196" +version = "1.0.199" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "33c85360c95e7d137454dc81d9a4ed2b8efd8fbe19cee57357b32b9771fccb67" +checksum = "11bd257a6541e141e42ca6d24ae26f7714887b47e89aa739099104c7e4d3b7fc" dependencies = [ "proc-macro2", "quote", - "syn 2.0.48", + "syn 2.0.60", ] [[package]] name = "serde_json" -version = "1.0.111" +version = "1.0.116" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "176e46fa42316f18edd598015a5166857fc835ec732f5215eac6b7bdbf0a84f4" +checksum = "3e17db7126d17feb94eb3fad46bf1a96b034e8aacbc2e775fe81505f8b0b2813" dependencies = [ "itoa", "ryu", @@ -1634,9 +1630,9 @@ dependencies = [ 
[[package]] name = "serde_yaml" -version = "0.9.30" +version = "0.9.34+deprecated" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b1bf28c79a99f70ee1f1d83d10c875d2e70618417fda01ad1785e027579d9d38" +checksum = "6a8b1a1a2ebf674015cc02edccce75287f1a0130d394307b36743c2f5d504b47" dependencies = [ "indexmap", "itoa", @@ -1666,9 +1662,9 @@ checksum = "f27f6278552951f1f2b8cf9da965d10969b2efdea95a6ec47987ab46edfe263a" [[package]] name = "smallvec" -version = "1.12.0" +version = "1.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2593d31f82ead8df961d8bd23a64c2ccf2eb5dd34b0a34bfb4dd54011c72009e" +checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" [[package]] name = "smartstring" @@ -1719,9 +1715,9 @@ checksum = "fe895eb47f22e2ddd4dabc02bce419d2e643c8e3b585c78158b349195bc24d82" [[package]] name = "strum" -version = "0.25.0" +version = "0.26.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "290d54ea6f91c969195bdbcd7442c8c2a2ba87da8bf60a7ee86a235d4bc1e125" +checksum = "5d8cec3501a5194c432b2b7976db6b7d10ec95c253208b45f83f7136aa985e29" [[package]] name = "strum_macros" @@ -1733,7 +1729,20 @@ dependencies = [ "proc-macro2", "quote", "rustversion", - "syn 2.0.48", + "syn 2.0.60", +] + +[[package]] +name = "strum_macros" +version = "0.26.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c6cf59daf282c0a494ba14fd21610a0325f9f90ec9d1231dea26bcb1d696c946" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "rustversion", + "syn 2.0.60", ] [[package]] @@ -1749,9 +1758,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.48" +version = "2.0.60" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0f3531638e407dfc0814761abb7c00a5b54992b849452a0646b7f65c9f770f3f" +checksum = "909518bc7b1c9b779f1bbf07f2929d35af9f0f37e47c6e9ef7f9dddc1e1821f3" dependencies = [ "proc-macro2", "quote", @@ -1760,9 +1769,9 
@@ dependencies = [ [[package]] name = "sysinfo" -version = "0.30.5" +version = "0.30.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1fb4f3438c8f6389c864e61221cbc97e9bca98b4daf39a5beb7bea660f528bb2" +checksum = "87341a165d73787554941cd5ef55ad728011566fe714e987d1b976c15dbc3a83" dependencies = [ "cfg-if", "core-foundation-sys", @@ -1774,34 +1783,34 @@ dependencies = [ [[package]] name = "target-features" -version = "0.1.5" +version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cfb5fa503293557c5158bd215fdc225695e567a77e453f5d4452a50a193969bd" +checksum = "c1bbb9f3c5c463a01705937a24fdabc5047929ac764b2d5b9cf681c1f5041ed5" [[package]] name = "target-lexicon" -version = "0.12.13" +version = "0.12.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "69758bda2e78f098e4ccb393021a0963bb3442eac05f135c30f61b7370bbafae" +checksum = "e1fc403891a21bcfb7c37834ba66a547a8f402146eba7265b5a6d88059c9ff2f" [[package]] name = "thiserror" -version = "1.0.56" +version = "1.0.59" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d54378c645627613241d077a3a79db965db602882668f9136ac42af9ecb730ad" +checksum = "f0126ad08bff79f29fc3ae6a55cc72352056dfff61e3ff8bb7129476d44b23aa" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.56" +version = "1.0.59" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fa0faa943b50f3db30a20aa7e265dbc66076993efed8463e8de414e5d06d3471" +checksum = "d1cd413b5d558b4c5bf3680e324a6fa5014e7b7c067a51e69dbdf47eb7148b66" dependencies = [ "proc-macro2", "quote", - "syn 2.0.48", + "syn 2.0.60", ] [[package]] @@ -1827,9 +1836,9 @@ checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" [[package]] name = "unicode-width" -version = "0.1.11" +version = "0.1.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"e51733f11c9c4f72aa0c160008246859e340b00807569a0da0e7a1079b27ba85" +checksum = "68f5e5f3158ecfd4b8ff6fe086db7c8467a2dfdac97fe420f2b7c4aa97af66d6" [[package]] name = "unindent" @@ -1839,9 +1848,9 @@ checksum = "c7de7d73e1754487cb58364ee906a499937a0dfabd86bcb980fa99ec8c8fa2ce" [[package]] name = "unsafe-libyaml" -version = "0.2.10" +version = "0.2.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ab4c90930b95a82d00dc9e9ac071b4991924390d46cbd0dfe566148667605e4b" +checksum = "673aac59facbab8a9007c7f6108d11f63b603f7cabff99fabf650fea5c32b861" [[package]] name = "version_check" @@ -1857,9 +1866,9 @@ checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" [[package]] name = "wasm-bindgen" -version = "0.2.90" +version = "0.2.92" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b1223296a201415c7fad14792dbefaace9bd52b62d33453ade1c5b5f07555406" +checksum = "4be2531df63900aeb2bca0daaaddec08491ee64ceecbee5076636a3b026795a8" dependencies = [ "cfg-if", "wasm-bindgen-macro", @@ -1867,24 +1876,24 @@ dependencies = [ [[package]] name = "wasm-bindgen-backend" -version = "0.2.90" +version = "0.2.92" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fcdc935b63408d58a32f8cc9738a0bffd8f05cc7c002086c6ef20b7312ad9dcd" +checksum = "614d787b966d3989fa7bb98a654e369c762374fd3213d212cfc0251257e747da" dependencies = [ "bumpalo", "log", "once_cell", "proc-macro2", "quote", - "syn 2.0.48", + "syn 2.0.60", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-macro" -version = "0.2.90" +version = "0.2.92" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3e4c238561b2d428924c49815533a8b9121c664599558a5d9ec51f8a1740a999" +checksum = "a1f8823de937b71b9460c0c34e25f3da88250760bec0ebac694b49997550d726" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -1892,28 +1901,28 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.90" 
+version = "0.2.92" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bae1abb6806dc1ad9e560ed242107c0f6c84335f1749dd4e8ddb012ebd5e25a7" +checksum = "e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.48", + "syn 2.0.60", "wasm-bindgen-backend", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-shared" -version = "0.2.90" +version = "0.2.92" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4d91413b1c31d7539ba5ef2451af3f0b833a005eb27a631cec32bc0635a8602b" +checksum = "af190c94f2773fdb3729c55b007a722abb5384da03bc0986df4c289bf5567e96" [[package]] name = "wide" -version = "0.7.13" +version = "0.7.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c68938b57b33da363195412cfc5fc37c9ed49aa9cfe2156fde64b8d2c9498242" +checksum = "0f0e39d2c603fdc0504b12b458cf1f34e0b937ed2f4f2dc20796e3e86f34e11f" dependencies = [ "bytemuck", "safe_arch", @@ -1948,7 +1957,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e48a53791691ab099e5e2ad123536d0fff50652600abaf43bbf952894110d0be" dependencies = [ "windows-core", - "windows-targets 0.52.0", + "windows-targets 0.52.5", ] [[package]] @@ -1957,7 +1966,7 @@ version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "33ab640c8d7e35bf8ba19b884ba838ceb4fba93a4e8c65a9059d08afcfc683d9" dependencies = [ - "windows-targets 0.52.0", + "windows-targets 0.52.5", ] [[package]] @@ -1975,7 +1984,7 @@ version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" dependencies = [ - "windows-targets 0.52.0", + "windows-targets 0.52.5", ] [[package]] @@ -1995,17 +2004,18 @@ dependencies = [ [[package]] name = "windows-targets" -version = "0.52.0" +version = "0.52.5" source = "registry+https://github.com/rust-lang/crates.io-index" 
-checksum = "8a18201040b24831fbb9e4eb208f8892e1f50a37feb53cc7ff887feb8f50e7cd" +checksum = "6f0713a46559409d202e70e28227288446bf7841d3211583a4b53e3f6d96e7eb" dependencies = [ - "windows_aarch64_gnullvm 0.52.0", - "windows_aarch64_msvc 0.52.0", - "windows_i686_gnu 0.52.0", - "windows_i686_msvc 0.52.0", - "windows_x86_64_gnu 0.52.0", - "windows_x86_64_gnullvm 0.52.0", - "windows_x86_64_msvc 0.52.0", + "windows_aarch64_gnullvm 0.52.5", + "windows_aarch64_msvc 0.52.5", + "windows_i686_gnu 0.52.5", + "windows_i686_gnullvm", + "windows_i686_msvc 0.52.5", + "windows_x86_64_gnu 0.52.5", + "windows_x86_64_gnullvm 0.52.5", + "windows_x86_64_msvc 0.52.5", ] [[package]] @@ -2016,9 +2026,9 @@ checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" [[package]] name = "windows_aarch64_gnullvm" -version = "0.52.0" +version = "0.52.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cb7764e35d4db8a7921e09562a0304bf2f93e0a51bfccee0bd0bb0b666b015ea" +checksum = "7088eed71e8b8dda258ecc8bac5fb1153c5cffaf2578fc8ff5d61e23578d3263" [[package]] name = "windows_aarch64_msvc" @@ -2028,9 +2038,9 @@ checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" [[package]] name = "windows_aarch64_msvc" -version = "0.52.0" +version = "0.52.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bbaa0368d4f1d2aaefc55b6fcfee13f41544ddf36801e793edbbfd7d7df075ef" +checksum = "9985fd1504e250c615ca5f281c3f7a6da76213ebd5ccc9561496568a2752afb6" [[package]] name = "windows_i686_gnu" @@ -2040,9 +2050,15 @@ checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" [[package]] name = "windows_i686_gnu" -version = "0.52.0" +version = "0.52.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a28637cb1fa3560a16915793afb20081aba2c92ee8af57b4d5f28e4b3e7df313" +checksum = "88ba073cf16d5372720ec942a8ccbf61626074c6d4dd2e745299726ce8b89670" + +[[package]] +name = 
"windows_i686_gnullvm" +version = "0.52.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87f4261229030a858f36b459e748ae97545d6f1ec60e5e0d6a3d32e0dc232ee9" [[package]] name = "windows_i686_msvc" @@ -2052,9 +2068,9 @@ checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" [[package]] name = "windows_i686_msvc" -version = "0.52.0" +version = "0.52.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ffe5e8e31046ce6230cc7215707b816e339ff4d4d67c65dffa206fd0f7aa7b9a" +checksum = "db3c2bf3d13d5b658be73463284eaf12830ac9a26a90c717b7f771dfe97487bf" [[package]] name = "windows_x86_64_gnu" @@ -2064,9 +2080,9 @@ checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" [[package]] name = "windows_x86_64_gnu" -version = "0.52.0" +version = "0.52.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d6fa32db2bc4a2f5abeacf2b69f7992cd09dca97498da74a151a3132c26befd" +checksum = "4e4246f76bdeff09eb48875a0fd3e2af6aada79d409d33011886d3e1581517d9" [[package]] name = "windows_x86_64_gnullvm" @@ -2076,9 +2092,9 @@ checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" [[package]] name = "windows_x86_64_gnullvm" -version = "0.52.0" +version = "0.52.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1a657e1e9d3f514745a572a6846d3c7aa7dbe1658c056ed9c3344c4109a6949e" +checksum = "852298e482cd67c356ddd9570386e2862b5673c85bd5f88df9ab6802b334c596" [[package]] name = "windows_x86_64_msvc" @@ -2088,15 +2104,15 @@ checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" [[package]] name = "windows_x86_64_msvc" -version = "0.52.0" +version = "0.52.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dff9641d1cd4be8d1a070daf9e3773c5f67e78b4d9d42263020c057706765c04" +checksum = "bec47e5bfd1bff0eeaf6d8b485cc1074891a197ab4225d504cb7a1ab88b02bf0" [[package]] name = 
"xxhash-rust" -version = "0.8.8" +version = "0.8.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "53be06678ed9e83edb1745eb72efc0bbcd7b5c3c35711a860906aed827a13d61" +checksum = "927da81e25be1e1a2901d59b81b37dd2efd1fc9c9345a55007f09bf5a2d3ee03" [[package]] name = "zerocopy" @@ -2115,32 +2131,32 @@ checksum = "9ce1b18ccd8e73a9321186f97e46f9f04b778851177567b1975109d26a08d2a6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.48", + "syn 2.0.60", ] [[package]] name = "zstd" -version = "0.13.0" +version = "0.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bffb3309596d527cfcba7dfc6ed6052f1d39dfbd7c867aa2e865e4a449c10110" +checksum = "2d789b1514203a1120ad2429eae43a7bd32b90976a7bb8a05f7ec02fa88cc23a" dependencies = [ "zstd-safe", ] [[package]] name = "zstd-safe" -version = "7.0.0" +version = "7.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "43747c7422e2924c11144d5229878b98180ef8b06cca4ab5af37afc8a8d8ea3e" +checksum = "1cd99b45c6bc03a018c8b8a86025678c87e55526064e38f9df301989dce7ec0a" dependencies = [ "zstd-sys", ] [[package]] name = "zstd-sys" -version = "2.0.9+zstd.1.5.5" +version = "2.0.10+zstd.1.5.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9e16efa8a874a0481a574084d34cc26fdb3b99627480f785888deb6386506656" +checksum = "c253a4914af5bafc8fa8c86ee400827e83cf6ec01195ec1f1ed8441bf00d65aa" dependencies = [ "cc", "pkg-config", diff --git a/pylace/Cargo.toml b/pylace/Cargo.toml index 45a86ba7..5876af55 100644 --- a/pylace/Cargo.toml +++ b/pylace/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pylace" -version = "0.7.1" +version = "0.8.0" edition = "2021" license = "BUSL-1.1" @@ -9,11 +9,11 @@ name = "lace" crate-type = ["cdylib"] [dependencies] -lace = { path = "../lace", version="0.7.0" } +lace = { path = "../lace", version="0.8.0" } lace_utils = { path = "../lace/lace_utils", version="0.3.0" } rand = "0.8.5" rand_xoshiro = "0.6.0" -pyo3 = { 
version = "0.20", features = ["extension-module"] } +pyo3 = { version = "0.21", features = ["extension-module"] } serde_json = "1.0.91" serde_yaml = "0.9.17" polars = "0.36" diff --git a/pylace/lace/__init__.py b/pylace/lace/__init__.py index 71b74339..3c82c642 100644 --- a/pylace/lace/__init__.py +++ b/pylace/lace/__init__.py @@ -14,6 +14,7 @@ ContinuousPrior, CountHyper, CountPrior, + PriorProcess, RowKernel, StateTransition, ValueMap, @@ -36,6 +37,7 @@ "CountHyper", "CountPrior", "ValueMap", + "PriorProcess", ] __version__ = metadata.version("pylace") diff --git a/pylace/lace/analysis.py b/pylace/lace/analysis.py index a07d6858..68ee352c 100644 --- a/pylace/lace/analysis.py +++ b/pylace/lace/analysis.py @@ -268,15 +268,15 @@ def held_out_neglogp( │ --- ┆ --- ┆ --- │ │ list[str] ┆ f64 ┆ i64 │ ╞═════════════════════════╪═════════════════════╪═══════════╡ - │ null ┆ 7.808063 ┆ 0 │ - │ ["Apogee_km"] ┆ 5.082683 ┆ 1 │ - │ ["Eccentricity"] ┆ 2.931816 ┆ 2 │ - │ ["Launch_Vehicle"] ┆ 2.931816 ┆ 3 │ + │ null ┆ 7.115493 ┆ 0 │ + │ ["Apogee_km"] ┆ 4.484848 ┆ 1 │ + │ ["Eccentricity"] ┆ 3.022424 ┆ 2 │ + │ ["Date_of_Launch"] ┆ 3.022424 ┆ 3 │ │ … ┆ … ┆ … │ - │ ["Power_watts"] ┆ 2.932103 ┆ 15 │ - │ ["Inclination_radians"] ┆ 2.933732 ┆ 16 │ - │ ["Users"] ┆ 2.940667 ┆ 17 │ - │ ["Perigee_km"] ┆ 3.956759 ┆ 18 │ + │ ["Launch_Site"] ┆ 3.022426 ┆ 15 │ + │ ["Power_watts"] ┆ 3.022582 ┆ 16 │ + │ ["Inclination_radians"] ┆ 3.024748 ┆ 17 │ + │ ["Perigee_km"] ┆ 4.025416 ┆ 18 │ └─────────────────────────┴─────────────────────┴───────────┘ If we don't want to use the greedy search, we can enumerate, but we need to @@ -297,15 +297,15 @@ def held_out_neglogp( │ --- ┆ --- ┆ --- │ │ list[str] ┆ f64 ┆ i64 │ ╞═══════════════════════════════════╪═════════════════════╪═══════════╡ - │ null ┆ 7.853468 ┆ 0 │ - │ ["Apogee_km"] ┆ 5.106627 ┆ 1 │ - │ ["Apogee_km", "Eccentricity"] ┆ 2.951662 ┆ 2 │ - │ ["Apogee_km", "Country_of_Operat… ┆ 2.951254 ┆ 3 │ - │ ["Apogee_km", "Country_of_Operat… ┆ 2.952801 ┆ 4 │ - │ 
["Apogee_km", "Country_of_Contra… ┆ 2.956224 ┆ 5 │ - │ ["Apogee_km", "Country_of_Contra… ┆ 2.96479 ┆ 6 │ - │ ["Apogee_km", "Country_of_Contra… ┆ 2.992173 ┆ 7 │ - │ ["Apogee_km", "Class_of_Orbit", … ┆ 3.956759 ┆ 8 │ + │ null ┆ 7.187543 ┆ 0 │ + │ ["Apogee_km"] ┆ 4.502691 ┆ 1 │ + │ ["Apogee_km", "Eccentricity"] ┆ 3.033792 ┆ 2 │ + │ ["Apogee_km", "Country_of_Operat… ┆ 3.033296 ┆ 3 │ + │ ["Apogee_km", "Country_of_Contra… ┆ 3.035064 ┆ 4 │ + │ ["Apogee_km", "Country_of_Contra… ┆ 3.037117 ┆ 5 │ + │ ["Apogee_km", "Country_of_Contra… ┆ 3.046293 ┆ 6 │ + │ ["Apogee_km", "Country_of_Contra… ┆ 3.076149 ┆ 7 │ + │ ["Apogee_km", "Class_of_Orbit", … ┆ 4.025416 ┆ 8 │ └───────────────────────────────────┴─────────────────────┴───────────┘ """ @@ -382,14 +382,14 @@ def held_out_inconsistency( │ --- ┆ --- ┆ --- │ │ list[str] ┆ f64 ┆ i64 │ ╞═════════════════════════╪═══════════════════════════╪═══════════╡ - │ null ┆ 1.973348 ┆ 0 │ - │ ["Apogee_km"] ┆ 1.284557 ┆ 1 │ - │ ["Eccentricity"] ┆ 0.740964 ┆ 2 │ - │ ["Launch_Vehicle"] ┆ 0.740964 ┆ 3 │ + │ null ┆ 1.767642 ┆ 0 │ + │ ["Apogee_km"] ┆ 1.114133 ┆ 1 │ + │ ["Eccentricity"] ┆ 0.750835 ┆ 2 │ + │ ["Date_of_Launch"] ┆ 0.750835 ┆ 3 │ │ … ┆ … ┆ … │ - │ ["Power_watts"] ┆ 0.741036 ┆ 15 │ - │ ["Inclination_radians"] ┆ 0.741448 ┆ 16 │ - │ ["Users"] ┆ 0.743201 ┆ 17 │ + │ ["Launch_Site"] ┆ 0.750836 ┆ 15 │ + │ ["Power_watts"] ┆ 0.750874 ┆ 16 │ + │ ["Inclination_radians"] ┆ 0.751413 ┆ 17 │ │ ["Perigee_km"] ┆ 1.0 ┆ 18 │ └─────────────────────────┴───────────────────────────┴───────────┘ @@ -411,14 +411,14 @@ def held_out_inconsistency( │ --- ┆ --- ┆ --- │ │ list[str] ┆ f64 ┆ i64 │ ╞═══════════════════════════════════╪═══════════════════════════╪═══════════╡ - │ null ┆ 1.984823 ┆ 0 │ - │ ["Apogee_km"] ┆ 1.290609 ┆ 1 │ - │ ["Apogee_km", "Eccentricity"] ┆ 0.74598 ┆ 2 │ - │ ["Apogee_km", "Country_of_Operat… ┆ 0.745877 ┆ 3 │ - │ ["Apogee_km", "Country_of_Operat… ┆ 0.746268 ┆ 4 │ - │ ["Apogee_km", "Country_of_Contra… ┆ 0.747133 ┆ 5 │ - │ ["Apogee_km", 
"Country_of_Contra… ┆ 0.749297 ┆ 6 │ - │ ["Apogee_km", "Country_of_Contra… ┆ 0.756218 ┆ 7 │ + │ null ┆ 1.785541 ┆ 0 │ + │ ["Apogee_km"] ┆ 1.118565 ┆ 1 │ + │ ["Apogee_km", "Eccentricity"] ┆ 0.753659 ┆ 2 │ + │ ["Apogee_km", "Country_of_Operat… ┆ 0.753536 ┆ 3 │ + │ ["Apogee_km", "Country_of_Contra… ┆ 0.753975 ┆ 4 │ + │ ["Apogee_km", "Country_of_Contra… ┆ 0.754485 ┆ 5 │ + │ ["Apogee_km", "Country_of_Contra… ┆ 0.756765 ┆ 6 │ + │ ["Apogee_km", "Country_of_Contra… ┆ 0.764182 ┆ 7 │ │ ["Apogee_km", "Class_of_Orbit", … ┆ 1.0 ┆ 8 │ └───────────────────────────────────┴───────────────────────────┴───────────┘ @@ -488,21 +488,21 @@ def held_out_uncertainty( ... quiet=True, ... ) # doctest: +NORMALIZE_WHITESPACE shape: (19, 3) - ┌──────────────────────────────────┬─────────────────────────┬───────────┐ - │ feature_rmed ┆ HoldOutFunc.Uncertainty ┆ keys_rmed │ - │ --- ┆ --- ┆ --- │ - │ list[str] ┆ f64 ┆ i64 │ - ╞══════════════════════════════════╪═════════════════════════╪═══════════╡ - │ null ┆ 0.43212 ┆ 0 │ - │ ["Perigee_km"] ┆ 0.43212 ┆ 1 │ - │ ["Class_of_Orbit"] ┆ 0.43212 ┆ 2 │ - │ ["Source_Used_for_Orbital_Data"] ┆ 0.431921 ┆ 3 │ - │ … ┆ … ┆ … │ - │ ["Country_of_Operator"] ┆ 0.054156 ┆ 15 │ - │ ["Country_of_Contractor"] ┆ 0.06069 ┆ 16 │ - │ ["Dry_Mass_kg"] ┆ 0.139502 ┆ 17 │ - │ ["Inclination_radians"] ┆ 0.089026 ┆ 18 │ - └──────────────────────────────────┴─────────────────────────┴───────────┘ + ┌───────────────────────────┬─────────────────────────┬───────────┐ + │ feature_rmed ┆ HoldOutFunc.Uncertainty ┆ keys_rmed │ + │ --- ┆ --- ┆ --- │ + │ list[str] ┆ f64 ┆ i64 │ + ╞═══════════════════════════╪═════════════════════════╪═══════════╡ + │ null ┆ 0.505795 ┆ 0 │ + │ ["Purpose"] ┆ 0.505794 ┆ 1 │ + │ ["Launch_Mass_kg"] ┆ 0.499515 ┆ 2 │ + │ ["Country_of_Contractor"] ┆ 0.497596 ┆ 3 │ + │ … ┆ … ┆ … │ + │ ["Expected_Lifetime"] ┆ 0.252419 ┆ 15 │ + │ ["Launch_Vehicle"] ┆ 0.225609 ┆ 16 │ + │ ["Users"] ┆ 0.19823 ┆ 17 │ + │ ["Country_of_Operator"] ┆ 0.185145 ┆ 18 │ + 
└───────────────────────────┴─────────────────────────┴───────────┘ If we don't want to use the greedy search, we can enumerate, but we need to be mindful that the number of conditions we must enumerate over is 2^n @@ -522,15 +522,15 @@ def held_out_uncertainty( │ --- ┆ --- ┆ --- │ │ list[str] ┆ f64 ┆ i64 │ ╞═══════════════════════════════════╪═════════════════════════╪═══════════╡ - │ null ┆ 0.445501 ┆ 0 │ - │ ["Expected_Lifetime"] ┆ 0.437647 ┆ 1 │ - │ ["Apogee_km", "Eccentricity"] ┆ 0.05561 ┆ 2 │ - │ ["Apogee_km", "Country_of_Operat… ┆ 0.055283 ┆ 3 │ - │ ["Apogee_km", "Country_of_Operat… ┆ 0.056185 ┆ 4 │ - │ ["Apogee_km", "Country_of_Operat… ┆ 0.057624 ┆ 5 │ - │ ["Apogee_km", "Country_of_Contra… ┆ 0.0595 ┆ 6 │ - │ ["Apogee_km", "Country_of_Contra… ┆ 0.077359 ┆ 7 │ - │ ["Apogee_km", "Class_of_Orbit", … ┆ 0.089026 ┆ 8 │ + │ null ┆ 0.515391 ┆ 0 │ + │ ["Class_of_Orbit"] ┆ 0.484085 ┆ 1 │ + │ ["Apogee_km", "Eccentricity"] ┆ 0.260645 ┆ 2 │ + │ ["Apogee_km", "Class_of_Orbit", … ┆ 0.251961 ┆ 3 │ + │ ["Apogee_km", "Class_of_Orbit", … ┆ 0.247123 ┆ 4 │ + │ ["Apogee_km", "Class_of_Orbit", … ┆ 0.220715 ┆ 5 │ + │ ["Apogee_km", "Class_of_Orbit", … ┆ 0.211055 ┆ 6 │ + │ ["Apogee_km", "Class_of_Orbit", … ┆ 0.1979 ┆ 7 │ + │ ["Apogee_km", "Class_of_Orbit", … ┆ 0.185145 ┆ 8 │ └───────────────────────────────────┴─────────────────────────┴───────────┘ """ @@ -656,7 +656,7 @@ def attributable_inconsistency( ... quiet=True, ... ) # doctest: +NORMALIZE_WHITESPACE >>> frac - 0.2930260843667006 + 0.2702093046733929 """ @@ -730,7 +730,7 @@ def attributable_neglogp( ... quiet=True, ... ) # doctest: +NORMALIZE_WHITESPACE >>> frac - 0.29302608436670047 + 0.2702093046733929 """ @@ -801,7 +801,7 @@ def attributable_uncertainty( ... quiet=True, ... 
) # doctest: +NORMALIZE_WHITESPACE >>> frac - 0.1814171785207335 + 0.17905287760659047 """ @@ -950,15 +950,15 @@ def explain_prediction( │ --- ┆ --- │ │ str ┆ f64 │ ╞══════════════════════════════╪═════════════╡ - │ Country_of_Operator ┆ 2.4617e-16 │ - │ Users ┆ -2.1412e-15 │ - │ Purpose ┆ -8.0193e-15 │ - │ Class_of_Orbit ┆ -2.2727e-15 │ + │ Country_of_Operator ┆ 3.9980e-15 │ + │ Users ┆ -3.4701e-13 │ + │ Purpose ┆ -5.3209e-15 │ + │ Class_of_Orbit ┆ -1.8481e-15 │ │ … ┆ … │ - │ Launch_Site ┆ -5.8214e-16 │ - │ Launch_Vehicle ┆ -9.6101e-16 │ - │ Source_Used_for_Orbital_Data ┆ -9.1997e-15 │ - │ Inclination_radians ┆ -1.5407e-15 │ + │ Launch_Site ┆ -4.2856e-13 │ + │ Launch_Vehicle ┆ -8.2878e-14 │ + │ Source_Used_for_Orbital_Data ┆ 1.7684e-14 │ + │ Inclination_radians ┆ -2.6242e-13 │ └──────────────────────────────┴─────────────┘ Get the importances using the 'ablative-dist' method, which measures how @@ -977,15 +977,15 @@ def explain_prediction( │ --- ┆ --- │ │ str ┆ f64 │ ╞══════════════════════════════╪═══════════╡ - │ Country_of_Operator ┆ -0.000109 │ - │ Users ┆ 0.081289 │ - │ Purpose ┆ 0.18938 │ - │ Class_of_Orbit ┆ 0.000119 │ + │ Country_of_Operator ┆ -0.012699 │ + │ Users ┆ 0.003983 │ + │ Purpose ┆ -0.042624 │ + │ Class_of_Orbit ┆ -0.00122 │ │ … ┆ … │ - │ Launch_Site ┆ 0.003411 │ - │ Launch_Vehicle ┆ -0.018817 │ - │ Source_Used_for_Orbital_Data ┆ 0.001454 │ - │ Inclination_radians ┆ 0.057333 │ + │ Launch_Site ┆ -0.011698 │ + │ Launch_Vehicle ┆ -0.09602 │ + │ Source_Used_for_Orbital_Data ┆ -0.027222 │ + │ Inclination_radians ┆ 0.012758 │ └──────────────────────────────┴───────────┘ """ diff --git a/pylace/lace/codebook.py b/pylace/lace/codebook.py index 5c9d8f43..f6583512 100644 --- a/pylace/lace/codebook.py +++ b/pylace/lace/codebook.py @@ -280,14 +280,14 @@ def rename(self, name: str): >>> codebook = Animals().codebook >>> codebook # doctest: +NORMALIZE_WHITESPACE Codebook 'my_table' - state_alpha_prior: G(α: 1, β: 1) - view_alpha_prior: G(α: 1, β: 1) + 
state_prior_process: DP(α ~ G(α: 1, β: 1)) + view_prior_process: DP(α ~ G(α: 1, β: 1)) columns: 85 rows: 50 >>> codebook.rename("Dennis") Codebook 'Dennis' - state_alpha_prior: G(α: 1, β: 1) - view_alpha_prior: G(α: 1, β: 1) + state_prior_process: DP(α ~ G(α: 1, β: 1)) + view_prior_process: DP(α ~ G(α: 1, β: 1)) columns: 85 rows: 50 @@ -296,74 +296,68 @@ def rename(self, name: str): codebook.codebook.rename(name) return codebook - def set_state_alpha_prior(self, shape: float = 1.0, rate: float = 1.0): + def set_state_prior_process(self, prior_process: _lc.PriorProcess): """ - Return a copy of the codebook with a new state CRP alpha prior. + Return a copy of the codebook with a new state PriorProcess. Parameters ---------- - shape: float, optional - The shape of the Gamma distribution prior on alpha: a positive - floating point value in (0, Inf). Default is 1. - rate: float, optional - The rate of the Gamma distribution prior on alpha: a positive - floating point value in (0, Inf). Default is 1. 
+ prior_process: core.PriorProcess Examples -------- >>> from lace.examples import Animals + >>> from lace import PriorProcess >>> codebook = Animals().codebook >>> codebook # doctest: +NORMALIZE_WHITESPACE Codebook 'my_table' - state_alpha_prior: G(α: 1, β: 1) - view_alpha_prior: G(α: 1, β: 1) + state_prior_process: DP(α ~ G(α: 1, β: 1)) + view_prior_process: DP(α ~ G(α: 1, β: 1)) columns: 85 rows: 50 - >>> codebook.set_state_alpha_prior(2.0, 3.1) + >>> process = PriorProcess.pitman_yor(1.0, 2.0, 0.5, 0.5) + >>> codebook.set_state_prior_process(process) Codebook 'my_table' - state_alpha_prior: G(α: 2, β: 3.1) - view_alpha_prior: G(α: 1, β: 1) + state_prior_process: PYP(α ~ G(α: 1, β: 2), d ~ Beta(α: 0.5, β: 0.5)) + view_prior_process: DP(α ~ G(α: 1, β: 1)) columns: 85 rows: 50 """ codebook = copy.copy(self) - codebook.codebook.set_state_alpha_prior(shape, rate) + codebook.codebook.set_state_prior_process(prior_process) return codebook - def set_view_alpha_prior(self, shape: float = 1.0, rate: float = 1.0): + def set_view_prior_process(self, prior_process: _lc.PriorProcess): """ - Return a copy of the codebook with a new view CRP alpha prior. + Return a copy of the codebook with a new view PriorProcess. Parameters ---------- - shape: float, optional - The shape of the Gamma distribution prior on alpha: a positive - floating point value in (0, Inf). Default is 1. - rate: float, optional - The rate of the Gamma distribution prior on alpha: a positive - floating point value in (0, Inf). Default is 1. 
+ prior_process: core.PriorProcess Examples -------- >>> from lace.examples import Animals + >>> from lace import PriorProcess >>> codebook = Animals().codebook >>> codebook # doctest: +NORMALIZE_WHITESPACE Codebook 'my_table' - state_alpha_prior: G(α: 1, β: 1) - view_alpha_prior: G(α: 1, β: 1) + state_prior_process: DP(α ~ G(α: 1, β: 1)) + view_prior_process: DP(α ~ G(α: 1, β: 1)) columns: 85 rows: 50 - >>> codebook.set_view_alpha_prior(2.0, 3.1) + >>> process = PriorProcess.pitman_yor(1.0, 2.0, 0.5, 0.5) + >>> codebook.set_view_prior_process(process) Codebook 'my_table' - state_alpha_prior: G(α: 1, β: 1) - view_alpha_prior: G(α: 2, β: 3.1) + state_prior_process: DP(α ~ G(α: 1, β: 1)) + view_prior_process: PYP(α ~ G(α: 1, β: 2), d ~ Beta(α: 0.5, β: 0.5)) columns: 85 rows: 50 """ codebook = copy.copy(self) - codebook.codebook.set_view_alpha_prior(shape, rate) + codebook.codebook.set_view_prior_process(prior_process) return codebook def append_column_metadata(self, col_metadata: List[_lc.ColumnMetadata]): diff --git a/pylace/lace/engine.py b/pylace/lace/engine.py index 80922705..aed646c3 100644 --- a/pylace/lace/engine.py +++ b/pylace/lace/engine.py @@ -2,7 +2,7 @@ import itertools as it from os import PathLike -from typing import TYPE_CHECKING, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Dict, List, Optional, Union, Set import pandas as pd import plotly.express as px @@ -208,7 +208,7 @@ def seed(self, rng_seed: int): │ --- ┆ --- ┆ --- ┆ --- │ │ u32 ┆ u32 ┆ u32 ┆ u32 │ ╞═══════╪══════╪══════════╪═══════╡ - │ 34 ┆ 49 ┆ 20 ┆ 49 │ + │ 34 ┆ 48 ┆ 26 ┆ 35 │ └───────┴──────┴──────────┴───────┘ If we set the seed, we get the same data. 
@@ -410,7 +410,7 @@ def flatten_columns(self): >>> engine.column_assignment(0) [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0] >>> engine.column_assignment(1) - [0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 0, 0, 2, 2, 2, 2, 0] + [1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] >>> engine.flatten_columns() >>> engine.column_assignment(0) [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] @@ -470,7 +470,7 @@ def row_assignments(self, state_ix: int): >>> from lace.examples import Animals >>> animals = Animals() >>> animals.row_assignments(0)[1][11] - 3 + 1 """ return self.engine.row_assignments(state_ix) @@ -503,18 +503,18 @@ def feature_params( >>> gauss_params = sats.feature_params("Period_minutes") >>> g = gauss_params[1][0] >>> g - Gaussian(mu=97.38792034245135, sigma=8.864698646528195) + Gaussian(mu=2216.995855497483, sigma=2809.7999447423026) >>> g.mu - 97.38792034245135 + 2216.995855497483 Get categorical weights from the Satellites dataset >>> cat_params = sats.feature_params("Class_of_Orbit", state_ixs=[2]) >>> c = cat_params[2][0] >>> c - Categorical_4(weights=[0.7196355242414928, ..., 0.12915471912497747]) + Categorical_4(weights=[0.23464953242044007, ..., 0.04544555912284563]) >>> c.weights # doctest: +ELLIPSIS - [0.7196355242414928, ..., 0.12915471912497747] + [0.23464953242044007, ..., 0.04544555912284563] You can also select columns by integer index @@ -522,7 +522,7 @@ def feature_params( 'Class_of_Orbit' >>> params = sats.feature_params(3) >>> params[0][1] - Categorical_4(weights=[0.0016550494108113051, ..., 0.000028906241993218738]) + Categorical_4(weights=[0.0010264756471345055, ..., 0.9963828657821785]) """ if state_ixs is None: @@ -571,17 +571,16 @@ def diagnostics(self, name: str = "score"): │ --- ┆ --- ┆ --- ┆ --- │ │ f64 ┆ f64 ┆ f64 ┆ f64 │ ╞══════════════╪══════════════╪══════════════╪══════════════╡ - │ -2882.424453 ┆ -2809.0876 ┆ -2638.714156 ┆ -2604.137622 │ - │ -2695.299327 ┆ -2666.497867 ┆ -2608.185358 ┆ 
-2576.545684 │ - │ -2642.539971 ┆ -2532.638368 ┆ -2576.463401 ┆ -2568.516617 │ - │ -2488.369418 ┆ -2513.134161 ┆ -2549.299382 ┆ -2554.131179 │ + │ -2533.503142 ┆ -2531.11451 ┆ -2488.379725 ┆ -2527.653495 │ + │ -2510.144546 ┆ -2519.318755 ┆ -2449.46579 ┆ -2529.866394 │ + │ -2494.957427 ┆ -2527.118066 ┆ -2417.423267 ┆ -2518.054613 │ + │ -2517.055318 ┆ -2534.235993 ┆ -2413.22879 ┆ -2523.029661 │ │ … ┆ … ┆ … ┆ … │ - │ -1972.005746 ┆ -2122.788121 ┆ -1965.921104 ┆ -1969.328651 │ - │ -1966.516529 ┆ -2117.398333 ┆ -1993.351756 ┆ -1986.589833 │ - │ -1969.400394 ┆ -2147.941128 ┆ -1968.697139 ┆ -1988.805311 │ - │ -1920.217666 ┆ -2081.368421 ┆ -1909.655836 ┆ -1920.432849 │ + │ -1763.593686 ┆ -1601.3273 ┆ -1873.277623 ┆ -1767.766707 │ + │ -1724.87438 ┆ -1648.269934 ┆ -1906.093392 ┆ -1809.921707 │ + │ -1776.739292 ┆ -1670.216919 ┆ -1898.314835 ┆ -1756.702674 │ + │ -1733.91896 ┆ -1665.882412 ┆ -1900.749398 ┆ -1750.687124 │ └──────────────┴──────────────┴──────────────┴──────────────┘ - """ diag = self.engine.diagnostics() @@ -627,11 +626,11 @@ def edit_cell(self, row: Union[str, int], col: Union[str, int], value): │ --- ┆ --- ┆ --- │ │ str ┆ u8 ┆ f64 │ ╞════════════╪════════╪═══════════╡ - │ pig ┆ 1 ┆ 1.565845 │ - │ rhinoceros ┆ 1 ┆ 1.094639 │ - │ buffalo ┆ 1 ┆ 1.094639 │ - │ chihuahua ┆ 1 ┆ 0.802085 │ - │ chimpanzee ┆ 1 ┆ 0.723817 │ + │ pig ┆ 1 ┆ 1.574539 │ + │ buffalo ┆ 1 ┆ 1.240631 │ + │ rhinoceros ┆ 1 ┆ 1.076105 │ + │ collie ┆ 0 ┆ 0.72471 │ + │ chimpanzee ┆ 1 ┆ 0.697159 │ └────────────┴────────┴───────────┘ >>> # change pig to not fierce >>> animals.edit_cell('pig', 'fierce', 0) @@ -644,11 +643,11 @@ def edit_cell(self, row: Union[str, int], col: Union[str, int], value): │ --- ┆ --- ┆ --- │ │ str ┆ u8 ┆ f64 │ ╞════════════╪════════╪═══════════╡ - │ rhinoceros ┆ 1 ┆ 1.094639 │ - │ buffalo ┆ 1 ┆ 1.094639 │ - │ chihuahua ┆ 1 ┆ 0.802085 │ - │ chimpanzee ┆ 1 ┆ 0.723817 │ - │ dalmatian ┆ 0 ┆ 0.594919 │ + │ buffalo ┆ 1 ┆ 1.240631 │ + │ rhinoceros ┆ 1 ┆ 1.076105 │ + │ collie ┆ 0 ┆ 
0.72471 │ + │ chimpanzee ┆ 1 ┆ 0.697159 │ + │ chihuahua ┆ 1 ┆ 0.614058 │ └────────────┴────────┴───────────┘ Set a value to missing @@ -662,7 +661,7 @@ def edit_cell(self, row: Union[str, int], col: Union[str, int], value): │ --- ┆ --- ┆ --- │ │ str ┆ u8 ┆ f64 │ ╞═══════╪════════╪═════════════╡ - │ pig ┆ 0 ┆ 0.07593 │ + │ pig ┆ 0 ┆ 0.094179 │ └───────┴────────┴─────────────┘ """ @@ -1030,7 +1029,7 @@ def update( ... timeout=30, ... transitions=[ ... StateTransition.row_assignment(RowKernel.slice()), - ... StateTransition.view_alphas(), + ... StateTransition.view_prior_process_params(), ... ], ... ) @@ -1088,14 +1087,14 @@ def entropy(self, cols, n_mc_samples: int = 1000): >>> from lace.examples import Animals >>> animals = Animals() >>> animals.entropy(["slow"]) - 0.6755931727528786 + 0.6812321322736966 >>> animals.entropy(["water"]) - 0.49836129824622094 + 0.46626932307630625 Joint entropy >>> animals.entropy(["swims", "fast"]) - 0.9552642751735604 + 0.9367950081783651 We can use entropies to compute mutual information, I(X, Y) = H(X) + H(Y) - H(X, Y). @@ -1108,7 +1107,7 @@ def entropy(self, cols, n_mc_samples: int = 1000): >>> h_fast = animals.entropy(["fast"]) >>> h_swims_and_fast = animals.entropy(["swims", "fast"]) >>> h_swims + h_fast - h_swims_and_fast - 3.510013543328583e-05 + 7.03684751313105e-06 But swimming and having flippers are mutually predictive, so we should see more mutual information. 
@@ -1116,7 +1115,7 @@ def entropy(self, cols, n_mc_samples: int = 1000): >>> h_flippers = animals.entropy(["flippers"]) >>> h_swims_and_flippers = animals.entropy(["swims", "flippers"]) >>> h_swims + h_flippers - h_swims_and_flippers - 0.19361180218629537 + 0.18686797893023643 """ return self.engine.entropy(cols, n_mc_samples) @@ -1178,9 +1177,9 @@ def logp( shape: (3,) Series: 'logp' [f64] [ - 0.523575 - 0.06601 - 0.380453 + 0.515602 + 0.06607 + 0.38637 ] Conditioning using ``given`` @@ -1192,9 +1191,9 @@ def logp( shape: (3,) Series: 'logp' [f64] [ - 0.000349 - 0.000756 - 0.998017 + 0.000975 + 0.018733 + 0.972718 ] Ask about the likelihood of values belonging to multiple features @@ -1209,9 +1208,9 @@ def logp( shape: (3,) Series: 'logp' [f64] [ - 0.000306 - 0.000008 - 0.016546 + 0.000353 + 0.000006 + 0.015253 ] An example of the scaled variant: @@ -1223,9 +1222,9 @@ def logp( shape: (3,) Series: 'logp_scaled' [f64] [ - 0.137554 - 0.167357 - 0.577699 + 0.260898 + 0.133143 + 0.592816 ] For columns which we explicitly model missing-not-at-random data, we can @@ -1234,13 +1233,13 @@ def logp( >>> from math import exp >>> no_long_geo = pl.Series("longitude_radians_of_geo", [None]) >>> exp(engine.logp(no_long_geo)) - 0.631030460838865 + 0.626977387513902 The probability of a value missing (not-at-random) changes depending on the conditions. 
>>> exp(engine.logp(no_long_geo, given={"Class_of_Orbit": "GEO"})) - 0.048855132811982976 + 0.07779133514786091 And we can condition on missingness @@ -1251,9 +1250,9 @@ def logp( shape: (3,) Series: 'logp' [f64] [ - 0.827158 - 0.099435 - 0.029606 + 0.818785 + 0.090779 + 0.04799 ] Plot the marginal distribution of `Period_minutes` for each state @@ -1325,15 +1324,15 @@ def inconsistency(self, values, given=None): │ --- ┆ --- │ │ str ┆ f64 │ ╞════════════════╪═══════════════╡ - │ collie ┆ 0.816517 │ - │ beaver ┆ 0.809991 │ - │ rabbit ┆ 0.785911 │ - │ polar+bear ┆ 0.783775 │ + │ beaver ┆ 0.830524 │ + │ collie ┆ 0.826842 │ + │ rabbit ┆ 0.80862 │ + │ skunk ┆ 0.801401 │ │ … ┆ … │ - │ killer+whale ┆ 0.513013 │ - │ blue+whale ┆ 0.503965 │ - │ dolphin ┆ 0.480259 │ - │ humpback+whale ┆ 0.434979 │ + │ walrus ┆ 0.535375 │ + │ blue+whale ┆ 0.48628 │ + │ killer+whale ┆ 0.466145 │ + │ humpback+whale ┆ 0.433302 │ └────────────────┴───────────────┘ Find satellites with inconsistent orbital periods @@ -1371,15 +1370,15 @@ def inconsistency(self, values, given=None): │ --- ┆ --- ┆ --- │ │ str ┆ f64 ┆ f64 │ ╞═══════════════════════════════════╪═══════════════╪════════════════╡ - │ Intelsat 903 ┆ 1.973348 ┆ 1436.16 │ - │ TianLian 2 (TL-1-02, CTDRS) ┆ 1.3645 ┆ 1436.1 │ - │ QZS-1 (Quazi-Zenith Satellite Sy… ┆ 1.364247 ┆ 1436.0 │ - │ Compass G-8 (Beidou IGSO-3) ┆ 1.364093 ┆ 1435.93 │ + │ Intelsat 903 ┆ 1.767642 ┆ 1436.16 │ + │ Mercury 2 (Advanced Vortex 2, US… ┆ 1.649006 ┆ 1436.12 │ + │ INSAT 4CR (Indian National Satel… ┆ 1.648992 ┆ 1436.11 │ + │ QZS-1 (Quazi-Zenith Satellite Sy… ┆ 1.64879 ┆ 1436.0 │ │ … ┆ … ┆ … │ - │ Navstar GPS II-24 (Navstar SVN 3… ┆ 0.646141 ┆ 716.69 │ - │ Navstar GPS IIR-10 (Navstar SVN … ┆ 0.646027 ┆ 716.47 │ - │ Navstar GPS IIR-M-6 (Navstar SVN… ┆ 0.645991 ┆ 716.4 │ - │ BSAT-3B ┆ 0.625282 ┆ 1365.61 │ + │ Glonass 723 (Glonass 37-3, Cosmo… ┆ 0.646552 ┆ 680.75 │ + │ Glonass 721 (Glonass 37-1, Cosmo… ┆ 0.646474 ┆ 680.91 │ + │ Glonass 730 (Glonass 41-1, Cosmo… ┆ 
0.646183 ┆ 681.53 │ + │ Wind (International Solar-Terres… ┆ 0.526911 ┆ 19700.45 │ └───────────────────────────────────┴───────────────┴────────────────┘ It looks like Intelsat 903 is the most inconsistent by a good amount. @@ -1476,14 +1475,13 @@ def surprisal( │ --- ┆ --- ┆ --- │ │ str ┆ f64 ┆ f64 │ ╞═══════════════════════════════════╪═══════════════════╪═══════════╡ - │ International Space Station (ISS… ┆ 30.0 ┆ 7.02499 │ - │ Landsat 7 ┆ 15.0 ┆ 4.869031 │ - │ Milstar DFS-5 (USA 164, Milstar … ┆ 0.0 ┆ 4.74869 │ - │ Optus B3 ┆ 0.5 ┆ 4.653549 │ - │ SDS III-3 (Satellite Data System… ┆ 0.5 ┆ 4.558333 │ + │ International Space Station (ISS… ┆ 30.0 ┆ 11.423102 │ + │ Milstar DFS-5 (USA 164, Milstar … ┆ 0.0 ┆ 6.661427 │ + │ DSP 21 (USA 159) (Defense Suppor… ┆ 0.5 ┆ 6.366436 │ + │ DSP 22 (USA 176) (Defense Suppor… ┆ 0.5 ┆ 6.366436 │ + │ Intelsat 701 ┆ 0.5 ┆ 6.366436 │ └───────────────────────────────────┴───────────────────┴───────────┘ - Compute the surprisal for specific cells >>> engine.surprisal( @@ -1495,8 +1493,8 @@ def surprisal( │ --- ┆ --- ┆ --- │ │ str ┆ f64 ┆ f64 │ ╞══════════════╪═══════════════════╪═══════════╡ - │ Landsat 7 ┆ 15.0 ┆ 4.869031 │ - │ Intelsat 701 ┆ 0.5 ┆ 4.533067 │ + │ Landsat 7 ┆ 15.0 ┆ 4.588265 │ + │ Intelsat 701 ┆ 0.5 ┆ 6.366436 │ └──────────────┴───────────────────┴───────────┘ Compute the surprisal of specific values in specific cells @@ -1512,8 +1510,8 @@ def surprisal( │ --- ┆ --- ┆ --- │ │ str ┆ f64 ┆ f64 │ ╞══════════════╪═══════════════════╪═══════════╡ - │ Landsat 7 ┆ 10.0 ┆ 3.037384 │ - │ Intelsat 701 ┆ 10.0 ┆ 2.559729 │ + │ Landsat 7 ┆ 10.0 ┆ 2.984587 │ + │ Intelsat 701 ┆ 10.0 ┆ 2.52041 │ └──────────────┴───────────────────┴───────────┘ Compute the surprisal of multiple values in a single cell @@ -1526,10 +1524,10 @@ def surprisal( shape: (4,) Series: 'surprisal' [f64] [ - 3.126282 - 2.938583 - 2.24969 - 3.037384 + 3.225658 + 3.036696 + 2.273096 + 2.984587 ] Surprisal will be different under different_states @@ -1546,8 +1544,8 @@ 
def surprisal( │ --- ┆ --- ┆ --- │ │ str ┆ f64 ┆ f64 │ ╞══════════════╪═══════════════════╪═══════════╡ - │ Landsat 7 ┆ 10.0 ┆ 2.743636 │ - │ Intelsat 701 ┆ 10.0 ┆ 2.587096 │ + │ Landsat 7 ┆ 10.0 ┆ 3.431414 │ + │ Intelsat 701 ┆ 10.0 ┆ 2.609992 │ └──────────────┴───────────────────┴───────────┘ """ @@ -1596,11 +1594,11 @@ def simulate( │ --- ┆ --- │ │ str ┆ f64 │ ╞════════════════╪════════════════╡ - │ MEO ┆ 2807.568333 │ - │ GEO ┆ 1421.333515 │ - │ LEO ┆ 92.435621 │ - │ GEO ┆ 1435.7067 │ - │ LEO ┆ 84.896787 │ + │ LEO ┆ 140.214617 │ + │ MEO ┆ 707.76105 │ + │ MEO ┆ 649.888366 │ + │ LEO ┆ 109.460389 │ + │ GEO ┆ 1309.460359 │ └────────────────┴────────────────┘ Simulate a pair of columns conditioned on another @@ -1616,11 +1614,11 @@ def simulate( │ --- ┆ --- │ │ str ┆ f64 │ ╞════════════════╪════════════════╡ - │ GEO ┆ 1439.041087 │ - │ GEO ┆ 1426.020318 │ - │ GEO ┆ 1430.553113 │ - │ GEO ┆ 1451.192889 │ - │ GEO ┆ 1431.855712 │ + │ LEO ┆ 97.079974 │ + │ GEO ┆ -45.703234 │ + │ LEO ┆ 114.135217 │ + │ LEO ┆ 103.676199 │ + │ GEO ┆ 1434.897091 │ └────────────────┴────────────────┘ Simulate missing values for columns that are missing not-at-random @@ -1632,11 +1630,11 @@ def simulate( │ --- │ │ f64 │ ╞══════════════════════════╡ + │ -2.719645 │ + │ -0.154891 │ │ null │ │ null │ - │ null │ - │ null │ - │ null │ + │ 0.712423 │ └──────────────────────────┘ >>> engine.simulate( ... 
["longitude_radians_of_geo"], @@ -1649,11 +1647,11 @@ def simulate( │ --- │ │ f64 │ ╞══════════════════════════╡ - │ 0.396442 │ - │ 0.794023 │ - │ 0.643669 │ - │ -0.005531 │ - │ 1.827976 │ + │ 0.850506 │ + │ 0.666353 │ + │ 0.682146 │ + │ 0.221179 │ + │ 2.621126 │ └──────────────────────────┘ If we simulate using ``given`` conditions, we can include the @@ -1671,11 +1669,11 @@ def simulate( │ --- ┆ --- ┆ --- │ │ f64 ┆ str ┆ str │ ╞════════════════╪════════════════╪════════════════╡ - │ 1436.038447 ┆ Communications ┆ GEO │ - │ 1447.908161 ┆ Communications ┆ GEO │ - │ 1452.635331 ┆ Communications ┆ GEO │ - │ 1443.983013 ┆ Communications ┆ GEO │ - │ 1437.544045 ┆ Communications ┆ GEO │ + │ 1426.679095 ┆ Communications ┆ GEO │ + │ 54.08657 ┆ Communications ┆ GEO │ + │ 1433.563215 ┆ Communications ┆ GEO │ + │ 1436.388876 ┆ Communications ┆ GEO │ + │ 1434.298969 ┆ Communications ┆ GEO │ └────────────────┴────────────────┴────────────────┘ """ @@ -1719,13 +1717,12 @@ def draw(self, row: Union[int, str], col: Union[int, str], n: int = 1): shape: (5,) Series: 'Period_minutes' [f64] [ - 110.076567 - 108.096406 - 102.34334 - 90.175641 - 94.512276 + 125.0209 + 173.739372 + 103.887763 + 115.319662 + 98.08124 ] - """ srs = self.engine.draw(row, col, n) return utils.return_srs(srs) @@ -1771,18 +1768,18 @@ def predict( >>> from lace.examples import Animals >>> animals = Animals() >>> animals.predict("swims") - (0, 0.04384630488890182) + (0, 0.03782005724890601) Predict whether an animal swims given that it has flippers >>> animals.predict("swims", given={"flippers": 1}) - (1, 0.09588592928237495) + (1, 0.08920133574559677) Let's confuse lace and see what happens to its uncertainty. 
Let's predict whether an non-water animal with flippers swims >>> animals.predict("swims", given={"flippers": 1, "water": 0}) - (0, 0.36077426258767503) + (0, 0.23777388425463844) If you want to save time and you do not care about quantifying your epistemic uncertainty, you don't have to compute uncertainty. @@ -1828,19 +1825,19 @@ def variability( >>> from lace.examples import Satellites >>> sats = Satellites() >>> sats.variability("Period_minutes") - 691324.3941953736 + 709857.0508301815 Compute the variance of Period_minutes for geosynchronous satellite >>> sats.variability("Period_minutes", given={"Class_of_Orbit": "GEO"}) - 136818.61181890886 + 148682.45531411088 Compute the entropy of Class_of_orbit >>> sats.variability("Class_of_Orbit") - 0.9362550555890782 + 0.9571321355529944 >>> sats.variability("Class_of_Orbit", given={"Period_minutes": 1440.0}) - 0.01569677151657056 + 0.1455965989424529 """ return self.engine.variability(target, given, state_ixs) @@ -1909,15 +1906,15 @@ def impute( │ --- ┆ --- ┆ --- │ │ str ┆ str ┆ f64 │ ╞═══════════════════════════════════╪═════════════════╪═════════════╡ - │ AAUSat-3 ┆ Sun-Synchronous ┆ 0.186415 │ - │ ABS-1 (LMI-1, Lockheed Martin-In… ┆ Sun-Synchronous ┆ 0.360331 │ - │ ABS-1A (Koreasat 2, Mugunghwa 2,… ┆ Sun-Synchronous ┆ 0.425853 │ - │ ABS-2i (MBSat, Mobile Broadcasti… ┆ Sun-Synchronous ┆ 0.360331 │ + │ AAUSat-3 ┆ Sun-Synchronous ┆ 0.190897 │ + │ ABS-1 (LMI-1, Lockheed Martin-In… ┆ Sun-Synchronous ┆ 0.422782 │ + │ ABS-1A (Koreasat 2, Mugunghwa 2,… ┆ Sun-Synchronous ┆ 0.422782 │ + │ ABS-2i (MBSat, Mobile Broadcasti… ┆ Sun-Synchronous ┆ 0.422782 │ │ … ┆ … ┆ … │ - │ Zhongxing 20A ┆ Sun-Synchronous ┆ 0.360331 │ - │ Zhongxing 22A (Chinastar 22A) ┆ Sun-Synchronous ┆ 0.404823 │ - │ Zhongxing 2A (Chinasat 2A) ┆ Sun-Synchronous ┆ 0.360331 │ - │ Zhongxing 9 (Chinasat 9, Chinast… ┆ Sun-Synchronous ┆ 0.360331 │ + │ Zhongxing 20A ┆ Sun-Synchronous ┆ 0.422782 │ + │ Zhongxing 22A (Chinastar 22A) ┆ Sun-Synchronous ┆ 0.422782 │ + │ 
Zhongxing 2A (Chinasat 2A) ┆ Sun-Synchronous ┆ 0.422782 │ + │ Zhongxing 9 (Chinasat 9, Chinast… ┆ Sun-Synchronous ┆ 0.422782 │ └───────────────────────────────────┴─────────────────┴─────────────┘ Impute a defined set of rows @@ -1929,8 +1926,8 @@ def impute( │ --- ┆ --- ┆ --- │ │ str ┆ str ┆ f64 │ ╞═══════════════╪════════════════════════╪═════════════╡ - │ AAUSat-3 ┆ Technology Development ┆ 0.238355 │ - │ Zhongxing 20A ┆ Communications ┆ 0.129248 │ + │ AAUSat-3 ┆ Technology Development ┆ 0.236857 │ + │ Zhongxing 20A ┆ Communications ┆ 0.142772 │ └───────────────┴────────────────────────┴─────────────┘ Uncertainty is optional @@ -2088,12 +2085,12 @@ def mi( >>> from lace.examples import Animals >>> engine = Animals() >>> engine.mi([("swims", "flippers")]) - 0.27197816458827445 + 0.2785114781561444 You can select different normalizations of mutual information >>> engine.mi([("swims", "flippers")], mi_type="unnormed") - 0.19361180218629537 + 0.18686797893023643 Multiple pairs as inputs gets you a polars ``Series`` @@ -2106,8 +2103,8 @@ def mi( shape: (2,) Series: 'mi' [f64] [ - 0.271978 - 0.005378 + 0.278511 + 0.012031 ] """ @@ -2163,25 +2160,25 @@ def rowsim( >>> from lace.examples import Animals >>> animals = Animals() >>> animals.rowsim([("beaver", "polar+bear")]) - 0.6059523809523808 + 0.5305059523809523 What about if we weight similarity by columns and not the standard views? >>> animals.rowsim([("beaver", "polar+bear")], col_weighted=True) - 0.5698529411764706 + 0.5095588235294117 Not much change. How similar are they with respect to how we model their swimming? >>> animals.rowsim([("beaver", "polar+bear")], wrt=["swims"]) - 0.875 + 1.0 Very similar. But will all animals that swim be highly similar with respect to their swimming? >>> animals.rowsim([("otter", "polar+bear")], wrt=["swims"]) - 0.375 + 0.3125 Lace predicts an otter's swimming for different reasons than a polar bear's. 
@@ -2199,8 +2196,8 @@ def rowsim( shape: (2,) Series: 'rowsim' [f64] [ - 0.629315 - 0.772545 + 0.712798 + 0.841518 ] """ @@ -2243,13 +2240,13 @@ def pairwise_fn(self, fn_name, indices: Optional[list] = None, **kwargs): │ str ┆ str ┆ f64 │ ╞═══════╪═══════╪══════════╡ │ wolf ┆ wolf ┆ 1.0 │ - │ wolf ┆ rat ┆ 0.71689 │ - │ wolf ┆ otter ┆ 0.492262 │ - │ rat ┆ wolf ┆ 0.71689 │ + │ wolf ┆ rat ┆ 0.801339 │ + │ wolf ┆ otter ┆ 0.422619 │ + │ rat ┆ wolf ┆ 0.801339 │ │ rat ┆ rat ┆ 1.0 │ - │ rat ┆ otter ┆ 0.613095 │ - │ otter ┆ wolf ┆ 0.492262 │ - │ otter ┆ rat ┆ 0.613095 │ + │ rat ┆ otter ┆ 0.572173 │ + │ otter ┆ wolf ┆ 0.422619 │ + │ otter ┆ rat ┆ 0.572173 │ │ otter ┆ otter ┆ 1.0 │ └───────┴───────┴──────────┘ @@ -2267,13 +2264,13 @@ def pairwise_fn(self, fn_name, indices: Optional[list] = None, **kwargs): │ str ┆ str ┆ f64 │ ╞═══════╪═══════╪══════════╡ │ wolf ┆ wolf ┆ 1.0 │ - │ wolf ┆ rat ┆ 0.642647 │ - │ wolf ┆ otter ┆ 0.302206 │ - │ rat ┆ wolf ┆ 0.642647 │ + │ wolf ┆ rat ┆ 0.804412 │ + │ wolf ┆ otter ┆ 0.323529 │ + │ rat ┆ wolf ┆ 0.804412 │ │ rat ┆ rat ┆ 1.0 │ - │ rat ┆ otter ┆ 0.491176 │ - │ otter ┆ wolf ┆ 0.302206 │ - │ otter ┆ rat ┆ 0.491176 │ + │ rat ┆ otter ┆ 0.469853 │ + │ otter ┆ wolf ┆ 0.323529 │ + │ otter ┆ rat ┆ 0.469853 │ │ otter ┆ otter ┆ 1.0 │ └───────┴───────┴──────────┘ @@ -2288,13 +2285,13 @@ def pairwise_fn(self, fn_name, indices: Optional[list] = None, **kwargs): │ str ┆ str ┆ f64 │ ╞══════════╪══════════════╪══════════╡ │ antelope ┆ antelope ┆ 1.0 │ - │ antelope ┆ grizzly+bear ┆ 0.464137 │ - │ antelope ┆ killer+whale ┆ 0.479613 │ - │ antelope ┆ beaver ┆ 0.438467 │ + │ antelope ┆ grizzly+bear ┆ 0.457589 │ + │ antelope ┆ killer+whale ┆ 0.469494 │ + │ antelope ┆ beaver ┆ 0.332589 │ │ … ┆ … ┆ … │ - │ dolphin ┆ walrus ┆ 0.724702 │ - │ dolphin ┆ raccoon ┆ 0.340923 │ - │ dolphin ┆ cow ┆ 0.482887 │ + │ dolphin ┆ walrus ┆ 0.799851 │ + │ dolphin ┆ raccoon ┆ 0.236607 │ + │ dolphin ┆ cow ┆ 0.441964 │ │ dolphin ┆ dolphin ┆ 1.0 │ 
└──────────┴──────────────┴──────────┘ @@ -2390,6 +2387,41 @@ def clustermap( else: return ClusterMap(df, linkage) + def remove_rows( + self, + indices: Union[pd.Series, List[str], pd.Series, Set[str]], + ) -> pl.DataFrame: + """ + Remove rows from the table. + + Parameters + ---------- + indices: Union[pd.Series, List[str], pd.Series, Set[str]] + Rows to remove from the Engine, specified by index or id name. + + Example + ------- + Remove crab and squid from the animals example engine. + + >>> from lace.examples import Animals + >>> engine = Animals() + >>> n_rows = engine.n_rows + >>> removed = engine.remove_rows(["cow", "wolf"]) + >>> n_rows == engine.n_rows + 1 + True + >>> removed["index"] # doctest: +NORMALIZE_WHITESPACE + ┌────────┐ + │ index │ + │ --- │ + │ str │ + ╞════════╡ + │ cow │ + │ wolf │ + └────────┘ + + """ + return self.engine.remove_rows(indices) + class _TqdmUpdateHandler: def __init__(self): diff --git a/pylace/lace/resources/datasets/animals/codebook.yaml b/pylace/lace/resources/datasets/animals/codebook.yaml index 9d856418..120c53da 100644 --- a/pylace/lace/resources/datasets/animals/codebook.yaml +++ b/pylace/lace/resources/datasets/animals/codebook.yaml @@ -1,10 +1,12 @@ table_name: my_table -state_alpha_prior: - shape: 1.0 - rate: 1.0 -view_alpha_prior: - shape: 1.0 - rate: 1.0 +state_prior_process: !dirichlet + alpha_prior: + shape: 1.0 + rate: 1.0 +view_prior_process: !dirichlet + alpha_prior: + shape: 1.0 + rate: 1.0 col_metadata: - name: black coltype: !Categorical diff --git a/pylace/lace/resources/datasets/satellites/codebook.yaml b/pylace/lace/resources/datasets/satellites/codebook.yaml index 766042da..1349907f 100644 --- a/pylace/lace/resources/datasets/satellites/codebook.yaml +++ b/pylace/lace/resources/datasets/satellites/codebook.yaml @@ -1,10 +1,12 @@ table_name: my_data -state_alpha_prior: !Gamma - shape: 1.0 - rate: 1.0 -view_alpha_prior: - shape: 1.0 - rate: 1.0 +state_prior_process: !dirichlet + alpha_prior: + shape: 1.0 
+ rate: 1.0 +view_prior_process: !dirichlet + alpha_prior: + shape: 1.0 + rate: 1.0 col_metadata: - name: Country_of_Operator coltype: !Categorical diff --git a/pylace/lace/utils.py b/pylace/lace/utils.py index b8b0cf0b..4d8ec10d 100644 --- a/pylace/lace/utils.py +++ b/pylace/lace/utils.py @@ -128,36 +128,36 @@ def infer_column_metadata( _COMMON_TRANSITIONS = { "sams": [ - StateTransition.view_alphas(), + StateTransition.view_prior_process_params(), StateTransition.row_assignment(RowKernel.sams()), - StateTransition.view_alphas(), + StateTransition.view_prior_process_params(), StateTransition.row_assignment(RowKernel.sams()), - StateTransition.view_alphas(), + StateTransition.view_prior_process_params(), StateTransition.row_assignment(RowKernel.slice()), StateTransition.component_parameters(), StateTransition.column_assignment(ColumnKernel.gibbs()), - StateTransition.state_alpha(), + StateTransition.state_prior_process_params(), StateTransition.feature_priors(), ], "flat": [ - StateTransition.view_alphas(), + StateTransition.view_prior_process_params(), StateTransition.row_assignment(RowKernel.sams()), - StateTransition.view_alphas(), + StateTransition.view_prior_process_params(), StateTransition.row_assignment(RowKernel.sams()), - StateTransition.view_alphas(), + StateTransition.view_prior_process_params(), StateTransition.component_parameters(), StateTransition.row_assignment(RowKernel.slice()), StateTransition.component_parameters(), - StateTransition.view_alphas(), + StateTransition.view_prior_process_params(), StateTransition.feature_priors(), ], "fast": [ - StateTransition.view_alphas(), + StateTransition.view_prior_process_params(), StateTransition.row_assignment(RowKernel.slice()), StateTransition.component_parameters(), StateTransition.feature_priors(), StateTransition.column_assignment(ColumnKernel.slice()), - StateTransition.state_alpha(), + StateTransition.state_prior_process_params(), ], } diff --git a/pylace/pyproject.toml b/pylace/pyproject.toml 
index b67f36e2..fc161712 100644 --- a/pylace/pyproject.toml +++ b/pylace/pyproject.toml @@ -1,10 +1,10 @@ [build-system] -requires = ["maturin>=0.13,<0.14"] +requires = ["maturin>=1.0,<2"] build-backend = "maturin" [project] name = "pylace" -version = "0.7.1" +version = "0.8.0" description = "A probabalistic programming ML tool for science" requires-python = ">=3.8" classifiers = [ @@ -131,3 +131,7 @@ strict = true [tool.black] line-length = 80 + + +[tool.maturin] +module-name = "lace.core" diff --git a/pylace/requirements-dev.txt b/pylace/requirements-dev.txt index e5f4356a..aacb02b0 100644 --- a/pylace/requirements-dev.txt +++ b/pylace/requirements-dev.txt @@ -3,8 +3,8 @@ # Dependencies # Tooling -hypothesis==6.65.2 -maturin==0.14.10 -pytest==7.2.0 -pytest-cov==4.0.0 -pytest-xdist==3.1.0 +hypothesis==6.100.2 +maturin==1.5.1 +pytest==8.2.0 +pytest-cov==5.0.0 +pytest-xdist==3.6.1 diff --git a/pylace/src/df.rs b/pylace/src/df.rs index b7d05c00..d7200ac6 100644 --- a/pylace/src/df.rs +++ b/pylace/src/df.rs @@ -4,10 +4,10 @@ use polars::series::Series; use polars_arrow::ffi; use pyo3::exceptions::{PyException, PyIOError, PyValueError}; use pyo3::ffi::Py_uintptr_t; -use pyo3::types::PyModule; +use pyo3::types::{PyAnyMethods, PyModule}; use pyo3::{ - create_exception, FromPyObject, IntoPy, PyAny, PyErr, PyObject, PyResult, - Python, ToPyObject, + create_exception, Bound, FromPyObject, IntoPy, PyAny, PyErr, PyObject, + PyResult, Python, ToPyObject, }; #[derive(Debug)] @@ -146,7 +146,7 @@ impl<'a> FromPyObject<'a> for PyDataFrame { pub(crate) fn to_py_array( array: ArrayRef, py: Python, - pyarrow: &PyModule, + pyarrow: &Bound, ) -> PyResult { let schema = Box::new(ffi::export_field_to_c(&ArrowField::new( "", @@ -173,10 +173,11 @@ impl IntoPy for PySeries { let s = self.0.rechunk(); let name = s.name(); let arr = s.to_arrow(0); - let pyarrow = py.import("pyarrow").expect("pyarrow not installed"); - let polars = py.import("polars").expect("polars not installed"); + let 
pyarrow = + py.import_bound("pyarrow").expect("pyarrow not installed"); + let polars = py.import_bound("polars").expect("polars not installed"); - let arg = to_py_array(arr, py, pyarrow).unwrap(); + let arg = to_py_array(arr, py, &pyarrow).unwrap(); let s = polars.call_method1("from_arrow", (arg,)).unwrap(); let s = s.call_method1("rename", (name,)).unwrap(); s.to_object(py) @@ -194,7 +195,7 @@ impl IntoPy for PyDataFrame { .map(|s| PySeries(s.clone()).into_py(py)) .collect::>(); - let polars = py.import("polars").expect("polars not installed"); + let polars = py.import_bound("polars").expect("polars not installed"); let df_object = polars.call_method1("DataFrame", (pyseries,)).unwrap(); df_object.into_py(py) } diff --git a/pylace/src/lib.rs b/pylace/src/lib.rs index 3b010c6e..c0c6624c 100644 --- a/pylace/src/lib.rs +++ b/pylace/src/lib.rs @@ -13,7 +13,7 @@ use df::{DataFrameLike, PyDataFrame, PySeries}; use lace::data::DataSource; use lace::metadata::SerializedType; use lace::prelude::ColMetadataList; -use lace::{EngineUpdateConfig, FType, HasStates, OracleT}; +use lace::{Datum, EngineUpdateConfig, FType, HasStates, OracleT, TableIndex}; use polars::prelude::{DataFrame, NamedFrom, Series}; use pyo3::create_exception; use pyo3::exceptions::{PyIndexError, PyRuntimeError, PyValueError}; @@ -114,7 +114,7 @@ impl CoreEngine { /// Load a Engine from metadata #[classmethod] - fn load(_cls: &PyType, path: PathBuf) -> PyResult { + fn load(_cls: &Bound, path: PathBuf) -> PyResult { let (engine, rng) = { let mut engine = lace::Engine::load(path) .map_err(|e| EngineLoadError::new_err(e.to_string()))?; @@ -143,7 +143,7 @@ impl CoreEngine { } /// Return a copy of the engine - fn __deepcopy__(&self, _memo: &PyDict) -> Self { + fn __deepcopy__(&self, _memo: &Bound) -> Self { self.clone() } @@ -238,7 +238,7 @@ impl CoreEngine { .collect() } - fn ftype(&self, col: &PyAny) -> PyResult { + fn ftype(&self, col: &Bound) -> PyResult { let col_ix = utils::value_to_name(col, 
&self.col_indexer)?; self.engine .ftype(col_ix) @@ -259,7 +259,7 @@ impl CoreEngine { ); Err(PyErr::new::(msg)) } else { - Ok(self.engine.states[state_ix].asgn.asgn.clone()) + Ok(self.engine.states[state_ix].asgn().asgn.clone()) } } @@ -275,7 +275,7 @@ impl CoreEngine { let asgns = self.engine.states[state_ix] .views .iter() - .map(|view| view.asgn.asgn.clone()) + .map(|view| view.asgn().asgn.clone()) .collect(); Ok(asgns) } @@ -284,9 +284,9 @@ impl CoreEngine { fn feature_params<'p>( &self, py: Python<'p>, - col: &PyAny, + col: &Bound, state_ix: usize, - ) -> PyResult<&'p PyAny> { + ) -> PyResult> { use component::ComponentParams; let col_ix = utils::value_to_index(col, &self.col_indexer)?; @@ -302,16 +302,16 @@ impl CoreEngine { let mixture = self.engine.states[state_ix].feature_as_mixture(col_ix); match ComponentParams::from(mixture) { ComponentParams::Bernoulli(params) => { - Ok(params.into_py(py).into_ref(py)) + Ok(params.into_py(py).into_bound(py)) } ComponentParams::Categorical(params) => { - Ok(params.into_py(py).into_ref(py)) + Ok(params.into_py(py).into_bound(py)) } ComponentParams::Gaussian(params) => { - Ok(params.into_py(py).into_ref(py)) + Ok(params.into_py(py).into_bound(py)) } ComponentParams::Poisson(params) => { - Ok(params.into_py(py).into_ref(py)) + Ok(params.into_py(py).into_bound(py)) } } } @@ -361,7 +361,7 @@ impl CoreEngine { /// array([0.125, 0. 
]) /// >>> engine.depprob([('swims', 'flippers'), ('swims', 'fast')]) /// array([0.875, 0.25 ]) - fn depprob(&self, col_pairs: &PyList) -> PyResult { + fn depprob(&self, col_pairs: &Bound) -> PyResult { let pairs = list_to_pairs(col_pairs, &self.col_indexer)?; self.engine .depprob_pw(&pairs) @@ -373,11 +373,11 @@ impl CoreEngine { #[pyo3(signature=(col_pairs, n_mc_samples=1000, mi_type="iqr"))] fn mi( &self, - col_pairs: &PyList, + col_pairs: &Bound, n_mc_samples: usize, mi_type: &str, ) -> PyResult { - let pairs = list_to_pairs(col_pairs, &self.col_indexer)?; + let pairs = list_to_pairs(&col_pairs, &self.col_indexer)?; let mi_type = utils::str_to_mitype(mi_type)?; self.engine .mi_pw(&pairs, n_mc_samples, mi_type) @@ -385,7 +385,7 @@ impl CoreEngine { .map(|xs| PySeries(Series::new("mi", xs))) } - /// Row similarlity + /// Row similarity /// /// Parameters /// ---------- @@ -440,8 +440,8 @@ impl CoreEngine { #[pyo3(signature=(row_pairs, wrt=None, col_weighted=false))] fn rowsim( &self, - row_pairs: &PyList, - wrt: Option<&PyAny>, + row_pairs: &Bound, + wrt: Option<&Bound>, col_weighted: bool, ) -> PyResult { let variant = if col_weighted { @@ -465,8 +465,8 @@ impl CoreEngine { fn pairwise_fn( &self, fn_name: &str, - pairs: &PyList, - fn_kwargs: Option<&PyDict>, + pairs: &Bound, + fn_kwargs: Option<&Bound>, ) -> PyResult { match fn_name { "depprob" => self.depprob(pairs).map(|xs| (xs, &self.col_indexer)), @@ -481,9 +481,9 @@ impl CoreEngine { "rowsim" => { let args = fn_kwargs.map_or_else( || Ok(utils::RowsimArgs::default()), - utils::rowsim_args_from_dict, + |dict| utils::rowsim_args_from_dict(dict), )?; - self.rowsim(pairs, args.wrt, args.col_weighted) + self.rowsim(pairs, args.wrt.as_ref(), args.col_weighted) .map(|xs| (xs, &self.row_indexer)) } _ => Err(PyErr::new::(format!( @@ -544,8 +544,8 @@ impl CoreEngine { #[pyo3(signature = (cols, given=None, n=1))] fn simulate( &mut self, - cols: &PyAny, - given: Option<&PyDict>, + cols: &Bound, + given: Option<&Bound>, 
n: usize, ) -> PyResult { let col_ixs = pyany_to_indices(cols, &self.col_indexer)?; @@ -583,8 +583,8 @@ impl CoreEngine { #[pyo3(signature = (row, col, n=1))] fn draw( &mut self, - row: &PyAny, - col: &PyAny, + row: &Bound, + col: &Bound, n: usize, ) -> PyResult { let row_ix = utils::value_to_index(row, &self.row_indexer)?; @@ -636,8 +636,8 @@ impl CoreEngine { /// ``` fn logp( &self, - values: &PyAny, - given: Option<&PyDict>, + values: &Bound, + given: Option<&Bound>, state_ixs: Option>, ) -> PyResult { let df_vals = @@ -664,8 +664,8 @@ impl CoreEngine { fn logp_scaled( &self, - values: &PyAny, - given: Option<&PyDict>, + values: &Bound, + given: Option<&Bound>, state_ixs: Option>, ) -> PyResult { let df_vals = @@ -693,9 +693,9 @@ impl CoreEngine { #[pyo3(signature=(col, rows=None, values=None, state_ixs=None))] fn surprisal( &self, - col: &PyAny, - rows: Option<&PyAny>, - values: Option<&PyAny>, + col: &Bound, + rows: Option<&Bound>, + values: Option<&Bound>, state_ixs: Option>, ) -> PyResult { let col_ix = utils::value_to_index(col, &self.col_indexer)?; @@ -817,7 +817,11 @@ impl CoreEngine { } #[pyo3(signature=(row, wrt=None))] - fn novelty(&self, row: &PyAny, wrt: Option<&PyAny>) -> PyResult { + fn novelty( + &self, + row: &Bound, + wrt: Option<&Bound>, + ) -> PyResult { let row_ix = utils::value_to_index(row, &self.row_indexer)?; let wrt = wrt .map(|cols| utils::pyany_to_indices(cols, &self.col_indexer)) @@ -828,7 +832,11 @@ impl CoreEngine { } #[pyo3(signature=(cols, n_mc_samples=1000))] - fn entropy(&self, cols: &PyAny, n_mc_samples: usize) -> PyResult { + fn entropy( + &self, + cols: &Bound, + n_mc_samples: usize, + ) -> PyResult { let col_ixs = utils::pyany_to_indices(cols, &self.col_indexer)?; self.engine .entropy(&col_ixs, n_mc_samples) @@ -855,8 +863,8 @@ impl CoreEngine { #[pyo3(signature=(col, rows=None, with_uncertainty=true))] fn impute( &mut self, - col: &PyAny, - rows: Option<&PyAny>, + col: &Bound, + rows: Option<&Bound>, with_uncertainty: bool, 
) -> PyResult { use lace::cc::feature::Feature; @@ -938,8 +946,8 @@ impl CoreEngine { #[pyo3(signature=(target, given=None, state_ixs=None, with_uncertainty=true))] fn predict( &self, - target: &PyAny, - given: Option<&PyDict>, + target: &Bound, + given: Option<&Bound>, state_ixs: Option>, with_uncertainty: bool, ) -> PyResult> { @@ -972,8 +980,8 @@ impl CoreEngine { #[pyo3(signature=(target, given=None, state_ixs=None))] fn variability( &self, - target: &PyAny, - given: Option<&PyDict>, + target: &Bound, + given: Option<&Bound>, state_ixs: Option>, ) -> PyResult { let col_ix = value_to_index(target, &self.col_indexer)?; @@ -1123,7 +1131,7 @@ impl CoreEngine { /// ... ) /// >>> /// >>> engine.append_rows(row) - fn append_rows(&mut self, rows: &PyAny) -> PyResult<()> { + fn append_rows(&mut self, rows: &Bound) -> PyResult<()> { let df_vals = pandas_to_insert_values( rows, &self.col_indexer, @@ -1179,10 +1187,71 @@ impl CoreEngine { Ok(()) } + /// Remove Rows at the given indices. + /// + /// Example + /// ------- + /// + /// >>> import lace + /// >>> engine = lace.Engine('animals.rp') + /// >>> n_rows = engine.shape[0] + /// >>> removed = engine.remove_rows(["wolf", "ox"]) + /// >>> removed["index"].to_list() + /// ["wolf", "ox"] + /// >>> n_rows - 2 == engine.shape[0] + /// True + fn remove_rows(&mut self, rows: &Bound) -> PyResult { + let remove: Vec = rows.extract()?; + + let row_idxs: Vec = remove + .iter() + .map(|row_name| { + self.engine.codebook.row_index(row_name).ok_or_else(|| { + PyIndexError::new_err(format!( + "{row_name} is not a valid row index" + )) + }) + }) + .collect::>>()?; + + let mut df = polars::frame::DataFrame::empty(); + let index = polars::series::Series::new("index", remove); + df.with_column(index).map_err(to_pyerr)?; + + for col_ix in 0..self.engine.n_cols() { + let values = row_idxs + .iter() + .map(|&row_ix| { + self.engine.datum(row_ix, col_ix).map_err(to_pyerr) + }) + .collect::>>()?; + + let ftype = 
self.engine.ftype(col_ix).map_err(to_pyerr)?; + let srs = utils::vec_to_srs( + values, + col_ix, + ftype, + &self.engine.codebook, + )?; + df.with_column(srs.0).map_err(to_pyerr)?; + } + + self.engine + .remove_data( + row_idxs + .into_iter() + .map(|idx| TableIndex::Row(idx)) + .collect::>>(), + ) + .map_err(to_pyerr)?; + + Ok(PyDataFrame(df)) + } + /// Append new columns to the Engine fn append_columns( &mut self, - cols: &PyAny, + cols: &Bound, mut metadata: Vec, ) -> PyResult<()> { let suppl_types = Some( @@ -1255,7 +1324,7 @@ impl CoreEngine { } /// Delete a given column from the ``Engine`` - fn del_column(&mut self, col: &PyAny) -> PyResult<()> { + fn del_column(&mut self, col: &Bound) -> PyResult<()> { let col_ix = utils::value_to_index(col, &self.col_indexer)?; self.col_indexer.drop_by_ix(col_ix)?; self.engine.del_column(col_ix).map_err(to_pyerr) @@ -1273,9 +1342,9 @@ impl CoreEngine { /// The new value at the cell fn edit_cell( &mut self, - row: &PyAny, - col: &PyAny, - value: &PyAny, + row: &Bound, + col: &Bound, + value: &Bound, ) -> PyResult<()> { let row_ix = utils::value_to_index(row, &self.row_indexer)?; let col_ix = utils::value_to_index(col, &self.col_indexer)?; @@ -1299,7 +1368,7 @@ impl CoreEngine { fn categorical_support( &self, - col: &PyAny, + col: &Bound, ) -> PyResult>> { use lace::codebook::ValueMap as Vm; let col_ix = utils::value_to_index(col, &self.col_indexer)?; @@ -1336,7 +1405,7 @@ impl CoreEngine { } pub fn __getstate__(&self, py: Python) -> PyResult { - Ok(PyBytes::new( + Ok(PyBytes::new_bound( py, &bincode::serialize(&self).map_err(|e| { PyValueError::new_err(format!( @@ -1373,7 +1442,8 @@ create_exception!(lace, EngineUpdateError, pyo3::exceptions::PyException); /// A Python module implemented in Rust. 
#[pymodule] -fn core(py: Python, m: &PyModule) -> PyResult<()> { +#[pyo3(name = "core")] +fn core(py: Python, m: &Bound) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; @@ -1388,10 +1458,14 @@ fn core(py: Python, m: &PyModule) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; m.add_class::()?; m.add_function(wrap_pyfunction!(infer_srs_metadata, m)?)?; m.add_function(wrap_pyfunction!(metadata::codebook_from_df, m)?)?; - m.add("EngineLoadError", py.get_type::())?; - m.add("EngineUpdateError", py.get_type::())?; + m.add("EngineLoadError", py.get_type_bound::())?; + m.add( + "EngineUpdateError", + py.get_type_bound::(), + )?; Ok(()) } diff --git a/pylace/src/metadata.rs b/pylace/src/metadata.rs index 9277930c..27aa53dd 100644 --- a/pylace/src/metadata.rs +++ b/pylace/src/metadata.rs @@ -5,7 +5,7 @@ use lace::stats::prior::csd::CsdHyper; use lace::stats::prior::nix::NixHyper; use lace::stats::prior::pg::PgHyper; use lace::stats::rv::dist::{ - Gamma, Gaussian, InvGamma, NormalInvChiSquared, SymmetricDirichlet, + Beta, Gamma, Gaussian, InvGamma, NormalInvChiSquared, SymmetricDirichlet, }; use polars::prelude::DataFrame; use pyo3::exceptions::{PyIOError, PyIndexError}; @@ -278,7 +278,7 @@ impl ValueMap { /// Create a map of ``k`` unsigned integers #[classmethod] #[pyo3(signature = (k))] - pub fn int(_cls: &PyType, k: usize) -> Self { + pub fn int(_cls: &Bound, k: usize) -> Self { Self(lace::codebook::ValueMap::U8(k)) } @@ -295,7 +295,7 @@ impl ValueMap { /// The strings are not unique #[classmethod] #[pyo3(signature = (values))] - pub fn string(_cls: &PyType, values: Vec) -> PyResult { + pub fn string(_cls: &Bound, values: Vec) -> PyResult { lace::codebook::ValueMap::try_from(values) .map_err(PyValueError::new_err) .map(Self) @@ -303,7 +303,7 @@ impl ValueMap { /// Create a map from boolean #[classmethod] - pub fn bool(_cls: &PyType) -> Self { + pub fn bool(_cls: &Bound) -> Self { 
Self(lace::codebook::ValueMap::Bool) } @@ -349,7 +349,7 @@ impl ColumnMetadata { #[classmethod] #[pyo3(signature = (name, prior=None, hyper=None))] pub fn continuous( - _cls: &PyType, + _cls: &Bound, name: String, prior: Option, hyper: Option, @@ -386,7 +386,7 @@ impl ColumnMetadata { #[classmethod] #[pyo3(signature = (name, k, value_map=None, prior=None, hyper=None))] pub fn categorical( - _cls: &PyType, + _cls: &Bound, name: String, k: usize, value_map: Option, @@ -423,7 +423,7 @@ impl ColumnMetadata { #[classmethod] #[pyo3(signature = (name, prior=None, hyper=None))] pub fn count( - _cls: &PyType, + _cls: &Bound, name: String, prior: Option, hyper: Option, @@ -485,7 +485,8 @@ enum CodebookMethod { Path(PathBuf), Inferred { cat_cutoff: Option, - alpha_prior_opt: Option, + state_prior_process: Option, + view_prior_process: Option, no_hypers: bool, }, Codebook(Codebook), @@ -495,12 +496,47 @@ impl Default for CodebookMethod { fn default() -> Self { Self::Inferred { cat_cutoff: None, - alpha_prior_opt: None, + state_prior_process: None, + view_prior_process: None, no_hypers: false, } } } +#[pyclass] +#[derive(Clone, Debug)] +pub struct PriorProcess(lace::codebook::PriorProcess); + +#[pymethods] +impl PriorProcess { + #[classmethod] + #[pyo3(signature=(alpha_shape=1.0, alpha_rate=1.0, d_a=0.5, d_b=0.5))] + pub fn pitman_yor( + _cls: &Bound, + alpha_shape: f64, + alpha_rate: f64, + d_a: f64, + d_b: f64, + ) -> Self { + PriorProcess(lace::codebook::PriorProcess::PitmanYor { + alpha_prior: Gamma::new(alpha_shape, alpha_rate).unwrap(), + d_prior: Beta::new(d_a, d_b).unwrap(), + }) + } + + #[classmethod] + #[pyo3(signature=(alpha_shape=1.0, alpha_rate=1.0))] + pub fn dirichlet( + _cls: &Bound, + alpha_shape: f64, + alpha_rate: f64, + ) -> Self { + PriorProcess(lace::codebook::PriorProcess::Dirichlet { + alpha_prior: Gamma::new(alpha_shape, alpha_rate).unwrap(), + }) + } +} + #[pyclass] #[derive(Clone, Debug)] pub struct Codebook(pub(crate) lace::codebook::Codebook); @@ 
-518,25 +554,19 @@ impl Codebook { self.0.table_name = table_name; } - #[pyo3(signature=(shape=1.0, rate=1.0))] - pub fn set_state_alpha_prior( + pub fn set_state_prior_process( &mut self, - shape: f64, - rate: f64, + process: PriorProcess, ) -> PyResult<()> { - let gamma = Gamma::new(shape, rate).map_err(to_pyerr)?; - self.0.state_alpha_prior = Some(gamma); + self.0.state_prior_process = Some(process.0); Ok(()) } - #[pyo3(signature=(shape=1.0, rate=1.0))] - pub fn set_view_alpha_prior( + pub fn set_view_prior_process( &mut self, - shape: f64, - rate: f64, + process: PriorProcess, ) -> PyResult<()> { - let gamma = Gamma::new(shape, rate).map_err(to_pyerr)?; - self.0.view_alpha_prior = Some(gamma); + self.0.view_prior_process = Some(process.0); Ok(()) } @@ -583,19 +613,19 @@ impl Codebook { pub fn __repr__(&self) -> String { format!( "Codebook '{}'\ - \n state_alpha_prior: {}\ - \n view_alpha_prior: {}\ + \n state_prior_process: {}\ + \n view_prior_process: {}\ \n columns: {}\ \n rows: {}", self.0.table_name, self.0 - .state_alpha_prior + .state_prior_process .clone() - .map_or_else(|| String::from("None"), |g| format!("{g}")), + .map_or_else(|| String::from("None"), |p| format!("{}", p)), self.0 - .view_alpha_prior + .view_prior_process .clone() - .map_or_else(|| String::from("None"), |g| format!("{g}")), + .map_or_else(|| String::from("None"), |p| format!("{}", p)), self.0.col_metadata.len(), self.0.row_names.len() ) @@ -642,6 +672,15 @@ impl Codebook { Ok(()) } } + + /// Create a new codebook with the same columns but row indices from another dataframe. 
+ fn with_index(&self, new_index: Vec) -> PyResult { + let row_names = RowNameList::try_from(new_index).map_err(to_pyerr)?; + Ok(Self(lace::codebook::Codebook { + row_names, + ..self.0.clone() + })) + } } #[pyfunction] @@ -654,7 +693,8 @@ pub fn codebook_from_df( CodebookBuilder { method: CodebookMethod::Inferred { cat_cutoff, - alpha_prior_opt: None, + state_prior_process: None, + view_prior_process: None, no_hypers, }, } @@ -672,36 +712,33 @@ pub struct CodebookBuilder { impl CodebookBuilder { #[classmethod] /// Load a Codebook from a path. - fn load(_cls: &PyType, path: PathBuf) -> Self { + fn load(_cls: &Bound, path: PathBuf) -> Self { Self { method: CodebookMethod::Path(path), } } #[classmethod] - #[pyo3(signature = (cat_cutoff=None, alpha_prior_shape_rate=None, use_hypers=true))] + #[pyo3(signature = (cat_cutoff=None, state_prior_process=None, view_prior_process=None, use_hypers=true))] fn infer( - _cls: &PyType, + _cls: &Bound, cat_cutoff: Option, - alpha_prior_shape_rate: Option<(f64, f64)>, + state_prior_process: Option, + view_prior_process: Option, use_hypers: bool, ) -> PyResult { - let alpha_prior_opt = alpha_prior_shape_rate - .map(|(shape, rate)| Gamma::new(shape, rate)) - .transpose() - .map_err(|e| PyValueError::new_err(e.to_string()))?; - Ok(Self { method: CodebookMethod::Inferred { cat_cutoff, - alpha_prior_opt, + state_prior_process, + view_prior_process, no_hypers: !use_hypers, }, }) } #[classmethod] - fn codebook(_cls: &PyType, codebook: Codebook) -> Self { + fn codebook(_cls: &Bound, codebook: Codebook) -> Self { Self { method: CodebookMethod::Codebook(codebook), } @@ -730,14 +767,21 @@ impl CodebookBuilder { } CodebookMethod::Inferred { cat_cutoff, - alpha_prior_opt, + state_prior_process, + view_prior_process, + no_hypers, + } => df_to_codebook( + df, + cat_cutoff, + state_prior_process.map(|p| p.0), + view_prior_process.map(|p| p.0), no_hypers, - } => df_to_codebook(df, cat_cutoff, alpha_prior_opt, no_hypers) - .map_err(|e| { - 
PyValueError::new_err(format!( - "Failed to infer the Codebook. Error: {e}" - )) - }), + ) + .map_err(|e| { + PyValueError::new_err(format!( + "Failed to infer the Codebook. Error: {e}" + )) + }), CodebookMethod::Codebook(codebook) => Ok(codebook.0), } } @@ -745,7 +789,7 @@ impl CodebookBuilder { fn __repr__(&self) -> String { match &self.method { CodebookMethod::Path(path) => format!("", path.display()), - CodebookMethod::Inferred { cat_cutoff, alpha_prior_opt, no_hypers } => format!("CodebookBuilder Inferred(cat_cutoff={cat_cutoff:?}, alpha_prior_opt={alpha_prior_opt:?}, use_hypers={})", !no_hypers), + CodebookMethod::Inferred { cat_cutoff, state_prior_process, view_prior_process, no_hypers } => format!("CodebookBuilder Inferred(cat_cutoff={cat_cutoff:?}, state_prior_process={state_prior_process:?}, view_prior_process={view_prior_process:?}, use_hypers={})", !no_hypers), CodebookMethod::Codebook(_) => String::from("Codebook (fully specified)"), } } diff --git a/pylace/src/transition.rs b/pylace/src/transition.rs index 7978db47..f45524fc 100644 --- a/pylace/src/transition.rs +++ b/pylace/src/transition.rs @@ -56,7 +56,7 @@ impl RowKernel { } } -/// A particular state transition withing the Markov chain +/// A particular state transition within the Markov chain #[pyclass] #[derive(Clone, Copy)] pub(crate) struct StateTransition(lace::cc::transition::StateTransition); @@ -79,18 +79,18 @@ impl StateTransition { )) } - /// The state alpha (controls the assignment of columns to views) - /// transition. + /// The state prior process parameters (controls the assignment of + /// columns to views) transition. #[staticmethod] - fn state_alpha() -> Self { - Self(lace::cc::transition::StateTransition::StateAlpha) + fn state_prior_process_params() -> Self { + Self(lace::cc::transition::StateTransition::StatePriorProcessParams) } - /// The view alpha (controls the assignment of rows to categories within - /// each view) transition. 
+ /// The view prior process parameters (controls the assignment of rows to + /// categories within each view) transition. #[staticmethod] - fn view_alphas() -> Self { - Self(lace::cc::transition::StateTransition::ViewAlphas) + fn view_prior_process_params() -> Self { + Self(lace::cc::transition::StateTransition::ViewPriorProcessParams) } /// Re-sample the feature prior parameters @@ -149,8 +149,8 @@ impl std::fmt::Display for StateTransition { } St::FeaturePriors => write!(f, "FeaturePriors"), St::ComponentParams => write!(f, "ComponentParams"), - St::StateAlpha => write!(f, "StateAlpha"), - St::ViewAlphas => write!(f, "ViewAlphas"), + St::StatePriorProcessParams => write!(f, "StatePriorProcessParams"), + St::ViewPriorProcessParams => write!(f, "ViewPriorProcessParams"), } } } diff --git a/pylace/src/update_handler.rs b/pylace/src/update_handler.rs index 023e18e3..643834ea 100644 --- a/pylace/src/update_handler.rs +++ b/pylace/src/update_handler.rs @@ -4,7 +4,7 @@ use std::sync::{Arc, Mutex}; use lace::cc::state::State; use lace::update_handler::UpdateHandler; use lace::EngineUpdateConfig; -use pyo3::{pyclass, IntoPy, Py, PyAny}; +use pyo3::{prelude::PyDictMethods, pyclass, IntoPy, Py, PyAny}; /// Python version of `EngineUpdateConfig`. #[derive(Clone, Debug)] @@ -36,7 +36,7 @@ impl PyUpdateHandler { macro_rules! pydict { ($py: expr, $($key:tt : $val:expr),* $(,)?) => {{ - let map = pyo3::types::PyDict::new($py); + let map = pyo3::types::PyDict::new_bound($py); $( let _ = map.set_item($key, $val.into_py($py)) .expect("Should be able to set item in PyDict"); @@ -59,7 +59,7 @@ macro_rules! call_pyhandler_noret { ); handler - .call_method(py, $func_name, (), kwargs.into()) + .call_method_bound(py, $func_name, (), Some(&kwargs)) .expect("Expected python call_method to return successfully"); }) }}; @@ -79,7 +79,7 @@ macro_rules! 
call_pyhandler_ret { ); handler - .call_method(py, $func_name, (), kwargs.into()) + .call_method_bound(py, $func_name, (), Some(&kwargs)) .expect("Expected python call_method to return successfully") .extract(py) .expect("Failed to extract expected type") diff --git a/pylace/src/utils.rs b/pylace/src/utils.rs index 033dcd94..a7848640 100644 --- a/pylace/src/utils.rs +++ b/pylace/src/utils.rs @@ -184,7 +184,7 @@ pub(crate) struct MiArgs { #[derive(Default)] pub(crate) struct RowsimArgs<'a> { - pub(crate) wrt: Option<&'a PyAny>, + pub(crate) wrt: Option>, pub(crate) col_weighted: bool, } @@ -205,7 +205,7 @@ pub(crate) fn coltype_to_ftype(col_type: &ColType) -> FType { } } -pub(crate) fn mi_args_from_dict(dict: &PyDict) -> PyResult { +pub(crate) fn mi_args_from_dict(dict: &Bound) -> PyResult { let n_mc_samples: Option = dict .get_item("n_mc_samples")? .map(|any| any.extract::()) @@ -222,13 +222,15 @@ pub(crate) fn mi_args_from_dict(dict: &PyDict) -> PyResult { }) } -pub(crate) fn rowsim_args_from_dict(dict: &PyDict) -> PyResult { +pub(crate) fn rowsim_args_from_dict<'a>( + dict: &'a Bound<'a, PyDict>, +) -> PyResult> { let col_weighted: Option = dict .get_item("col_weighted")? 
.map(|any| any.extract::()) .transpose()?; - let wrt: Option<&PyAny> = dict.get_item("wrt")?; + let wrt: Option> = dict.get_item("wrt")?; Ok(RowsimArgs { wrt, @@ -427,7 +429,7 @@ impl Indexer { } pub(crate) fn pairs_list_iter<'a>( - pairs: &'a PyList, + pairs: &'a Bound<'a, PyList>, indexer: &'a Indexer, ) -> impl Iterator> + 'a { pairs.iter().map(|item| { @@ -438,28 +440,36 @@ pub(crate) fn pairs_list_iter<'a>( "A pair consists of two items", )) } else { - value_to_index(&ixs[0], indexer).and_then(|a| { - value_to_index(&ixs[1], indexer).map(|b| (a, b)) - }) + ixs.get_item(0) + .and_then(|a| value_to_index(&a, indexer)) + .and_then(|a| { + ixs.get_item(1) + .and_then(|b| value_to_index(&b, indexer)) + .map(|b| (a, b)) + }) } }) .unwrap_or_else(|_| { - let ixs: &PyTuple = item.downcast()?; + let ixs: &Bound = item.downcast()?; if ixs.len() != 2 { Err(PyErr::new::( "A pair consists of two items", )) } else { - value_to_index(&ixs[0], indexer).and_then(|a| { - value_to_index(&ixs[1], indexer).map(|b| (a, b)) - }) + ixs.get_item(0) + .and_then(|a| value_to_index(&a, indexer)) + .and_then(|a| { + ixs.get_item(1) + .and_then(|b| value_to_index(&b, indexer)) + .map(|b| (a, b)) + }) } }) }) } pub(crate) fn list_to_pairs<'a>( - pairs: &'a PyList, + pairs: &'a Bound, indexer: &'a Indexer, ) -> PyResult> { pairs_list_iter(pairs, indexer).collect() @@ -501,11 +511,12 @@ pub(crate) fn datum_to_value(datum: Datum) -> PyResult> { }) } -fn pyany_to_category(val: &PyAny) -> PyResult { +fn pyany_to_category(val: &Bound) -> PyResult { use lace::Category; - let name = val.get_type().name()?; + let ty = val.get_type(); + let name = ty.name()?; - match name { + match name.as_ref() { "int" => { let x = val.downcast::()?.extract::()?; Ok(Category::U8(x)) @@ -518,7 +529,7 @@ fn pyany_to_category(val: &PyAny) -> PyResult { let x = val.downcast::()?.extract::()?; Ok(Category::String(x)) } - "int64" | "int32" | "int16" | "int8" => { + "numpy.int64" | "numpy.int32" | "numpy.int16" | 
"numpy.int8" => { let x = val.call_method("__int__", (), None)?.extract::()?; Ok(Category::U8(x)) } @@ -528,7 +539,10 @@ fn pyany_to_category(val: &PyAny) -> PyResult { } } -pub(crate) fn value_to_datum(val: &PyAny, ftype: FType) -> PyResult { +pub(crate) fn value_to_datum( + val: &Bound, + ftype: FType, +) -> PyResult { if val.is_none() { return Ok(Datum::Missing); } @@ -558,7 +572,7 @@ pub(crate) fn value_to_datum(val: &PyAny, ftype: FType) -> PyResult { } pub(crate) fn value_to_name( - val: &PyAny, + val: &Bound, indexer: &Indexer, ) -> PyResult { val.extract::().or_else(|_| { @@ -572,7 +586,7 @@ pub(crate) fn value_to_name( } pub(crate) fn value_to_index( - val: &PyAny, + val: &Bound, indexer: &Indexer, ) -> PyResult { val.extract::().or_else(|_| { @@ -588,16 +602,16 @@ pub(crate) fn value_to_index( } pub(crate) fn pyany_to_indices( - cols: &PyAny, + cols: &Bound, indexer: &Indexer, ) -> PyResult> { cols.iter()? - .map(|res| res.and_then(|val| value_to_index(val, indexer))) + .map(|res| res.and_then(|val| value_to_index(&val, indexer))) .collect() } pub(crate) fn dict_to_given( - dict_opt: Option<&PyDict>, + dict_opt: Option<&Bound>, engine: &lace::Engine, indexer: &Indexer, ) -> PyResult> { @@ -608,9 +622,9 @@ pub(crate) fn dict_to_given( let conditions = dict .iter() .map(|(key, value)| { - value_to_index(key, indexer).and_then(|ix| { + value_to_index(&key, indexer).and_then(|ix| { value_to_datum( - value, + &value, engine.ftype(ix).expect( "Index from indexer ought to be valid.", ), @@ -625,7 +639,7 @@ pub(crate) fn dict_to_given( } } -pub(crate) fn srs_to_strings(srs: &PyAny) -> PyResult> { +pub(crate) fn srs_to_strings(srs: &Bound) -> PyResult> { let list: &PyList = srs.call_method0("to_list")?.extract()?; list.iter() @@ -661,23 +675,23 @@ pub(crate) fn parts_to_insert_values( } pub(crate) fn pyany_to_data( - data: &PyAny, + data: &Bound, ftype: FType, ) -> PyResult> { data.iter()? 
- .map(|res| res.and_then(|val| value_to_datum(val, ftype))) + .map(|res| res.and_then(|val| value_to_datum(&val, ftype))) .collect() } fn process_row_dict( - row_dict: &PyDict, + row_dict: &Bound, _col_indexer: &Indexer, engine: &lace::Engine, suppl_types: Option<&HashMap>, ) -> PyResult> { let mut row_data: Vec = Vec::with_capacity(row_dict.len()); for (name_any, value_any) in row_dict { - let col_name: &PyString = name_any.downcast()?; + let col_name: &Bound = name_any.downcast()?; let col_name = col_name.to_str()?; let ftype = engine .codebook @@ -692,21 +706,21 @@ fn process_row_dict( ))) })?; - row_data.push(value_to_datum(value_any, ftype)?); + row_data.push(value_to_datum(&value_any, ftype)?); } Ok(row_data) } // Works on list of dicts fn values_to_data( - data: &PyList, + data: &Bound, col_indexer: &Indexer, engine: &lace::Engine, suppl_types: Option<&HashMap>, ) -> PyResult>> { data.iter() .map(|row_any| { - let row_dict: &PyDict = row_any.downcast()?; + let row_dict: &Bound = row_any.downcast()?; process_row_dict(row_dict, col_indexer, engine, suppl_types) }) .collect() @@ -723,7 +737,7 @@ pub(crate) struct DataFrameComponents { // FIXME: pass the 'py' in so that we can handle errors better. The // `Python::with_gil` thing makes using `?` a pain. 
fn df_to_values( - df: &PyAny, + df: &Bound, col_indexer: &Indexer, engine: &lace::Engine, suppl_types: Option<&HashMap>, @@ -734,12 +748,12 @@ fn df_to_values( if columns.get_type().name()?.contains("Index") { // Is a Pandas dataframe let index = df.getattr("index")?; - let row_names = srs_to_strings(index).ok(); + let row_names = srs_to_strings(&index).ok(); let cols = columns.call_method0("tolist")?.to_object(py); - let kwargs = PyDict::new(py); + let kwargs = PyDict::new_bound(py); kwargs.set_item("orient", "records")?; - let data = df.call_method("to_dict", (), Some(kwargs))?; + let data = df.call_method("to_dict", (), Some(&kwargs))?; (cols, data, row_names) } else { // Is a Polars dataframe @@ -749,11 +763,11 @@ fn df_to_values( // Find all the index columns let mut index_col_names = list .iter() - .map(|s| s.extract::<&str>()) + .map(|s| s.extract::()) .map(|s| { s.map(|s| { - if is_index_col(s) { - Some(String::from(s)) + if is_index_col(&s) { + Some(s) } else { None } @@ -784,7 +798,7 @@ fn df_to_values( // Get the indices from the index if it exists let row_names = df.get_item(index_name) - .and_then(srs_to_strings) + .and_then(|srs: pyo3::Bound<'_, pyo3::PyAny>| srs_to_strings(&srs)) .map_err(|err| { PyValueError::new_err(format!( "Indices in index '{index_name}' are not strings: {err}")) @@ -794,7 +808,7 @@ fn df_to_values( (df, Some(row_names)) } else { - (df, None) + (df.clone(), None) }; let data = df.call_method0("to_dicts")?; @@ -802,19 +816,19 @@ fn df_to_values( } }; - let data: &PyList = data.extract()?; - let columns: &PyList = columns.extract(py)?; + let data: Bound = data.extract()?; + let columns: Bound = columns.extract(py)?; // will return nothing if there are unknown column names let col_ixs = columns .iter() - .map(|col_name| value_to_index(col_name, col_indexer)) + .map(|col_name| value_to_index(&col_name, col_indexer)) .collect::, _>>() .ok(); let col_names = columns .iter() .map(|name| name.extract()) .collect::, _>>()?; - let values 
= values_to_data(data, col_indexer, engine, suppl_types)?; + let values = values_to_data(&data, col_indexer, engine, suppl_types)?; Ok(DataFrameComponents { col_ixs, @@ -826,35 +840,36 @@ fn df_to_values( } fn srs_to_column_values( - srs: &PyAny, + srs: &Bound, indexer: &Indexer, engine: &lace::Engine, suppl_types: Option<&HashMap>, ) -> PyResult { let data = srs.call_method0("to_frame")?; - df_to_values(data, indexer, engine, suppl_types) + df_to_values(&data, indexer, engine, suppl_types) } fn srs_to_row_values( - srs: &PyAny, + srs: &Bound, indexer: &Indexer, engine: &lace::Engine, suppl_types: Option<&HashMap>, ) -> PyResult { let data = srs.call_method0("to_frame")?.call_method0("transpose")?; - df_to_values(data, indexer, engine, suppl_types) + df_to_values(&data, indexer, engine, suppl_types) } pub(crate) fn pandas_to_logp_values( - xs: &PyAny, + xs: &Bound, indexer: &Indexer, engine: &lace::Engine, ) -> PyResult { - let type_name = xs.get_type().name()?; + let ty = xs.get_type(); + let type_name = ty.name()?; - match type_name { + match type_name.as_ref() { "DataFrame" => df_to_values(xs, indexer, engine, None), "Series" => srs_to_column_values(xs, indexer, engine, None), t => Err(PyErr::new::(format!( @@ -864,14 +879,15 @@ pub(crate) fn pandas_to_logp_values( } pub(crate) fn pandas_to_insert_values( - xs: &PyAny, + xs: &Bound, col_indexer: &Indexer, engine: &lace::Engine, suppl_types: Option<&HashMap>, ) -> PyResult { - let type_name = xs.get_type().name()?; + let ty = xs.get_type(); + let type_name = ty.name()?; - match type_name { + match type_name.as_ref() { "DataFrame" => df_to_values(xs, col_indexer, engine, suppl_types), "Series" => srs_to_row_values(xs, col_indexer, engine, suppl_types), t => Err(PyErr::new::(format!( diff --git a/pylace/tests/example_test.py b/pylace/tests/example_test.py index 8a87c765..a11765d9 100644 --- a/pylace/tests/example_test.py +++ b/pylace/tests/example_test.py @@ -13,4 +13,4 @@ def test_animals(): assert engine.shape == 
(50, 85) swim_cat, swim_unc = engine.predict("swims") assert swim_cat == 0 - assert_almost_equal(swim_unc, 0.04384630, 6) + assert_almost_equal(swim_unc, 0.03782005724890601, 6) diff --git a/pylace/tests/test_codebook.py b/pylace/tests/test_codebook.py index c8dd9a09..f354f17b 100644 --- a/pylace/tests/test_codebook.py +++ b/pylace/tests/test_codebook.py @@ -3,6 +3,7 @@ import polars as pl import lace +from lace.examples import Animals def test_engine_from_polars_with_codebook_smoke(): @@ -65,3 +66,10 @@ def test_engine_with_boolean_string_columns(): assert str(engine.codebook.column_metadata["b"].value_map) == str( lace.ValueMap.bool() ) + + +def test_with_index(): + codebook = Animals().engine.codebook + new_codebook = codebook.with_index(["a", "b", "c"]) + + assert new_codebook.shape[0] == 3 diff --git a/pylace/tests/test_docs.py b/pylace/tests/test_docs.py index 2a5b4d83..e7e153c6 100644 --- a/pylace/tests/test_docs.py +++ b/pylace/tests/test_docs.py @@ -32,6 +32,7 @@ def runtest(mod): if __name__ == "__main__": runtest("lace.engine") runtest("lace.analysis") + runtest("lace.codebook") if not NOPLOT: runtest("lace.plot") diff --git a/pylace/tests/test_engine.py b/pylace/tests/test_engine.py index 55d17119..b55604eb 100644 --- a/pylace/tests/test_engine.py +++ b/pylace/tests/test_engine.py @@ -11,3 +11,13 @@ def test_deep_copy(): assert a.columns == b.columns assert a.shape == b.shape + + +def test_remove_rows(): + from lace.examples import Animals + + engine = Animals() + n_rows = engine.n_rows + removed = engine.remove_rows(["cow", "wolf"]) + assert n_rows == engine.n_rows + 2 + assert removed["index"].len() == 2