Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bugfix/boolean as category #190

Merged
merged 12 commits into from
Mar 1, 2024
33 changes: 29 additions & 4 deletions .github/workflows/rust-build-test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -59,28 +59,47 @@ jobs:
with:
workspaces: . -> lace/target

- name: Run rustfmt
- name: Run rustfmt (Lace)
working-directory: lace
run: cargo fmt --all -- --check

- name: Run clippy
- name: Run rustfmt (CLI)
working-directory: cli
run: cargo fmt --all -- --check

- name: Run clippy (Lace)
working-directory: lace
env:
RUSTFLAGS: -C debuginfo=0
run: |
cargo clippy --all-features

- name: Run clippy (CLI)
working-directory: cli
env:
RUSTFLAGS: -C debuginfo=0
run: |
cargo clippy --all-features

- name: Install audit
run: cargo install cargo-audit

- name: Run audit
- name: Run audit (Lace)
working-directory: lace
# Note: Both `polars` and `arrow2` trigger this security violation
# due to their reliance on `chrono`, which is the ultimate source of the violation.
# As of 2/21/23, no version of `chrono` has been published that fixes the issue,
# and thus neither `polars` nor `arrow2` can pass `audit` checks.
run: cargo audit -f Cargo.lock --ignore RUSTSEC-2020-0071

- name: Run audit (CLI)
working-directory: cli
# Note: Both `polars` and `arrow2` trigger this security violation
# due to their reliance on `chrono`, which is the ultimate source of the violation.
# As of 2/21/23, no version of `chrono` has been published that fixes the issue,
# and thus neither `polars` nor `arrow2` can pass `audit` checks.
run: cargo audit -f Cargo.lock --ignore RUSTSEC-2020-0071

test:
runs-on: ${{ matrix.os }}
needs: ["lint", "features"]
Expand Down Expand Up @@ -109,7 +128,13 @@ jobs:
RUSTFLAGS: -C debuginfo=0
run: cargo run -- regen-examples

- name: Run tests
- name: Run Lace tests
env:
RUSTFLAGS: -C debuginfo=0
run: cargo test --all-features

- name: Run CLI tests
working-directory: cli
env:
RUSTFLAGS: -C debuginfo=0
run: cargo test --all-features
Expand Down
19 changes: 10 additions & 9 deletions cli/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion cli/src/opt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,10 @@ impl std::str::FromStr for Transition {
// into an enum-derived ArgGroup.
#[derive(Parser, Debug)]
pub struct RunArgs {
/// Directory to save Lace data in. If it does not exist, `run` will create it
#[clap(name = "LACEFILE_OUT")]
pub output: PathBuf,
/// Optinal path to codebook
/// Optional path to codebook
#[clap(long = "codebook", short = 'c')]
pub codebook: Option<PathBuf>,
/// Path to .csv data source. May be compressed.
Expand Down
37 changes: 37 additions & 0 deletions cli/tests/cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1008,4 +1008,41 @@ mod codebook {

Ok(())
}

/// Regression test: `lace run` must succeed on a CSV whose only data column
/// holds the strings `True` / `False`.
///
/// Writes a temporary CSV with an `ID` column and a `data` column, runs the
/// CLI against it with `-n 100`, and asserts that the process exits
/// successfully. On failure the captured stderr is included in the panic
/// message.
#[test]
fn bool_data() -> Result<(), Box<dyn std::error::Error>> {
    let output_dir = tempfile::Builder::new().tempdir()?;
    let mut data_file = tempfile::NamedTempFile::new()?;

    {
        // Write a CSV with 5 "True" and 5 "False" values.
        // Inclusive ranges are required: `1..5` would only yield 4 rows.
        let f = data_file.as_file_mut();
        writeln!(f, "ID,data")?;
        for id in 1..=5 {
            writeln!(f, "{},True", id)?;
        }
        for id in 6..=10 {
            writeln!(f, "{},False", id)?;
        }
    }

    // No codebook is supplied, so the column type of `data` has to be
    // determined from the CSV contents by the `run` command itself.
    let output_default = Command::new(LACE_CMD)
        .arg("run")
        .arg("--csv")
        .arg(data_file.path())
        .arg("-n")
        .arg("100")
        .arg(output_dir.path())
        .output()
        .expect("Failed to execute process");

    assert!(
        output_default.status.success(),
        "Process exited with error :\n{}",
        String::from_utf8(output_default.stderr)?
    );

    Ok(())
}
}
4 changes: 2 additions & 2 deletions lace/lace_cc/src/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ impl State {
rng: &mut R,
) -> Self {
let n_cols = ftrs.len();
let n_rows = ftrs.get(0).map(|f| f.len()).unwrap_or(0);
let n_rows = ftrs.first().map(|f| f.len()).unwrap_or(0);
let asgn = AssignmentBuilder::new(n_cols)
.with_prior(state_alpha_prior)
.seed_from_rng(rng)
Expand Down Expand Up @@ -178,7 +178,7 @@ impl State {
/// Get the number of rows
#[inline]
pub fn n_rows(&self) -> usize {
self.views.get(0).map(|v| v.n_rows()).unwrap_or(0)
self.views.first().map(|v| v.n_rows()).unwrap_or(0)
}

/// Get the number of columns
Expand Down
34 changes: 28 additions & 6 deletions lace/lace_codebook/src/data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -288,19 +288,41 @@ fn uint_categorical_coltype(
k: usize,
no_hypers: bool,
) -> Result<ColType, CodebookError> {
let (hyper, prior) = hyper_and_prior_for_categorical(no_hypers, k);

Ok(ColType::Categorical {
k,
hyper,
prior,
value_map: ValueMap::U8(k),
})
}

/// Build the `(hyper, prior)` pair for a categorical column with `k`
/// categories.
///
/// Exactly one element of the returned tuple is `Some`:
///
/// - `no_hypers == true`: no hyper-prior, and a fixed prior obtained from
///   `SymmetricDirichlet::jeffreys(k)`.
/// - `no_hypers == false`: a `CsdHyper::new(1.0, 1.0)` hyper-prior and no
///   fixed prior.
///
/// # Panics
///
/// Panics when `no_hypers` is `true` and `SymmetricDirichlet::jeffreys(k)`
/// fails, because its result is unwrapped.
fn hyper_and_prior_for_categorical(
    no_hypers: bool,
    k: usize,
) -> (
    Option<CsdHyper>,
    Option<lace_stats::rv::prelude::SymmetricDirichlet>,
) {
    use lace_stats::rv::dist::SymmetricDirichlet;

    // With hyper-priors enabled the prior itself is left unset.
    if !no_hypers {
        return (Some(CsdHyper::new(1.0, 1.0)), None);
    }

    let jeffreys_prior = SymmetricDirichlet::jeffreys(k).unwrap();
    (None, Some(jeffreys_prior))
}

fn bool_categorical_coltype(no_hypers: bool) -> Result<ColType, CodebookError> {
let (hyper, prior) = hyper_and_prior_for_categorical(no_hypers, 2);

Ok(ColType::Categorical {
k,
k: 2,
hyper,
prior,
value_map: ValueMap::U8(k),
value_map: ValueMap::Bool,
})
}

Expand Down Expand Up @@ -359,7 +381,7 @@ pub fn series_to_colmd(
let name = String::from(srs.name());
let dtype = srs.dtype();
let coltype = match dtype {
DataType::Boolean => uint_categorical_coltype(2, no_hypers),
DataType::Boolean => bool_categorical_coltype(no_hypers),
DataType::UInt8 => uint_coltype(srs, cat_cutoff, no_hypers),
DataType::UInt16 => uint_coltype(srs, cat_cutoff, no_hypers),
DataType::UInt32 => uint_coltype(srs, cat_cutoff, no_hypers),
Expand Down Expand Up @@ -708,19 +730,19 @@ mod test {
count_or_continuous!(count_or_cts_i64_neg_small, -1_i64, 10, false);

#[test]
fn bool_data_is_categorical() {
fn bool_data_is_bool() {
let srs = Series::new(
"A",
(0..100).map(|x| x % 2 == 1).collect::<Vec<bool>>(),
);
let colmd = series_to_colmd(&srs, None, true).unwrap();
match colmd.coltype {
ColType::Categorical {
value_map: ValueMap::U8(2),
value_map: ValueMap::Bool,
..
} => (),
ColType::Categorical { value_map, .. } => {
panic!("value map should be U8(2), was: {:?}", value_map)
panic!("value map should be Bool, was: {:?}", value_map)
}
_ => panic!("wrong coltype"),
}
Expand Down
7 changes: 7 additions & 0 deletions lace/src/data/init.rs
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,13 @@ fn categorical_col_model<R: rand::Rng>(
(ValueMap::U8(_), dt) if is_categorical_int_dtype(dt) => {
crate::codebook::data::series_to_opt_vec!(srs, u8)
}
(ValueMap::Bool, DataType::Boolean) => srs
.bool()?
.into_iter()
.map(|maybe_bool| {
maybe_bool.map(|b| ValueMap::Bool.ix(&b.into()).unwrap() as u8)
})
.collect(),
_ => {
return Err(CodebookError::UnsupportedDataType {
col_name: srs.name().to_owned(),
Expand Down
2 changes: 1 addition & 1 deletion lace/src/interface/oracle/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ impl Oracle {
/// Convert an `Engine` into an `Oracle`
pub fn from_engine(engine: Engine) -> Self {
let data = {
let data_map = engine.states.get(0).unwrap().clone_data();
let data_map = engine.states.first().unwrap().clone_data();
DataStore::new(data_map)
};

Expand Down
2 changes: 1 addition & 1 deletion pylace/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

26 changes: 26 additions & 0 deletions pylace/tests/test_codebook.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,29 @@ def test_engine_from_pandas_with_codebook_smoke():
engine = lace.Engine.from_df(df, codebook=codebook, n_states=3)
assert engine.shape == (14, 2)
assert engine.columns == ["x", "y"]


def test_engine_with_boolean_string_columns():
    """A polars Boolean column keeps a boolean value map in both the codebook and the engine."""
    n_rows = 14
    columns = {
        "ID": list(range(n_rows)),
        "x": np.random.randn(n_rows),
        "b": np.random.choice([True, False], size=n_rows),
    }
    frame = pl.DataFrame(columns)

    # Sanity check: the random booleans were ingested as a Boolean column.
    bool_position = frame.get_column_index("b")
    assert frame.dtypes[bool_position] == pl.Boolean

    def bool_value_map_repr(book):
        # Value maps are compared through their string representation.
        return str(book.column_metadata["b"].value_map)

    inferred = lace.Codebook.from_df("test", frame)
    assert inferred.shape == (n_rows, 2)
    assert bool_value_map_repr(inferred) == str(lace.ValueMap.bool())

    fitted = lace.Engine.from_df(frame, codebook=inferred, n_states=3)
    assert fitted.shape == (n_rows, 2)
    assert fitted.columns == ["x", "b"]
    assert bool_value_map_repr(fitted.codebook) == str(lace.ValueMap.bool())
Loading