From 83fd246101674d68507aa9cf315982e3f24b12b8 Mon Sep 17 00:00:00 2001 From: Mike Schmidt Date: Thu, 23 May 2024 15:29:14 -0600 Subject: [PATCH 1/2] feat(rust): Made codebook and valuemap accessors generic borrows. --- lace/lace_codebook/src/codebook.rs | 8 ++++++-- lace/lace_codebook/src/value_map.rs | 17 +++++++++++++---- 2 files changed, 19 insertions(+), 6 deletions(-) diff --git a/lace/lace_codebook/src/codebook.rs b/lace/lace_codebook/src/codebook.rs index a438ff46..c74d9db2 100644 --- a/lace/lace_codebook/src/codebook.rs +++ b/lace/lace_codebook/src/codebook.rs @@ -14,6 +14,7 @@ use lace_stats::rv::dist::{ }; use polars::prelude::DataFrame; use serde::{Deserialize, Serialize}; +use std::borrow::Borrow; use std::collections::HashMap; use std::convert::TryFrom; use std::fs::File; @@ -463,9 +464,12 @@ impl Codebook { output } - pub fn col_metadata(&self, col: String) -> Option<&ColMetadata> { + pub fn col_metadata(&self, col: T) -> Option<&ColMetadata> + where + T: Borrow, + { // self.col_metadata.get(&col) - self.col_metadata.iter().find(|md| md.name == col) + self.col_metadata.iter().find(|md| md.name == *col.borrow()) } /// Get the number of columns diff --git a/lace/lace_codebook/src/value_map.rs b/lace/lace_codebook/src/value_map.rs index a79e28f1..c82d70ca 100644 --- a/lace/lace_codebook/src/value_map.rs +++ b/lace/lace_codebook/src/value_map.rs @@ -1,5 +1,6 @@ use lace_data::Category; use serde::{Deserialize, Serialize}; +use std::borrow::Borrow; use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet}; use std::hash::Hash; use thiserror::Error; @@ -30,16 +31,24 @@ where self.to_cat.is_empty() } - pub fn ix(&self, cat: &T) -> Option { - self.to_ix.get(cat).cloned() + pub fn ix(&self, cat: &Q) -> Option + where + T: Borrow, + Q: Hash + Eq + ?Sized, + { + self.to_ix.get(cat.borrow()).copied() } pub fn category(&self, ix: usize) -> T { self.to_cat[ix].clone() } - pub fn contains_cat(&self, cat: &T) -> bool { - self.to_ix.contains_key(cat) + pub fn contains_cat(&self, cat: &Q) -> bool + where + T: Borrow, + Q: Hash + Eq + ?Sized, + { + self.to_ix.contains_key(cat.borrow()) } pub(crate) fn add(&mut self, value: T) { From 62178a0d67c02695bf946bcf882cc8212bfce1f6 Mon Sep 17 00:00:00 2001 From: Mike Schmidt Date: Thu, 23 May 2024 15:33:04 -0600 Subject: [PATCH 2/2] feat(rust): Added conversion traits for AnyValues and Series to Datums --- lace/lace_codebook/src/data.rs | 272 +++++++++++++++++++++++++++++++++ 1 file changed, 272 insertions(+) diff --git a/lace/lace_codebook/src/data.rs b/lace/lace_codebook/src/data.rs index f33c8cb0..2afee550 100644 --- a/lace/lace_codebook/src/data.rs +++ b/lace/lace_codebook/src/data.rs @@ -5,16 +5,288 @@ use crate::{ }; use lace_consts::rv::prelude::UnitPowerLaw; +use lace_data::{Category, Datum}; use lace_stats::prior::csd::CsdHyper; use lace_stats::prior::nix::NixHyper; use lace_stats::prior::pg::PgHyper; use lace_stats::prior::sbd::SbdHyper; +use polars::datatypes::AnyValue; use polars::prelude::{CsvReader, DataFrame, DataType, SerReader, Series}; use std::convert::TryFrom; use std::path::Path; +use thiserror::Error; pub const DEFAULT_CAT_CUTOFF: u8 = 20; +/// An Error from converting a Polar's AnyValue to a Datum +#[derive(Debug, Error)] +pub enum ConversionError { + #[error("The given value `{0}` is not an existing category.")] + ValueNotACategory(String), + #[error("The category is not indexed by a string.")] + CategoryNotIndexedByString, + #[error("The type `{0}` is not supported by this conversion.")] + UnsupportedType(String), + #[error("The category count is out of the existing bounds.")] + CountOutOfBounds, + #[error("The index is out of the existing bounds.")] + IndexOutOfBounds, +} + +pub trait AnyValueDatumExt: Sized { + /// Convert a AnyValue to a Datum with a specific coltype. + /// + /// * `coltype` - Column type to convert datum into. + /// * `drop_out_of_category` - Set output to `Datum::Missing` if given value is not contained + /// in the existing category. + fn to_datum( + self, + coltype: &ColType, + drop_out_of_category: bool, + ) -> Result; +} + +macro_rules! int_to_category { + ($x: expr, $value_map: expr, $out_to_missing: expr) => {{ + let idx: u8 = $x.try_into().map_err(|_| { + ConversionError::ValueNotACategory(format!("{}", $x)) + })?; + + if let ValueMap::U8(size) = $value_map { + if (idx as usize) < *size { + Ok(Datum::Categorical(Category::U8(idx))) + } else if $out_to_missing { + Ok(Datum::Missing) + } else { + Err(ConversionError::ValueNotACategory(idx.to_string())) + } + } else { + Err(ConversionError::ValueNotACategory(idx.to_string())) + } + }}; +} + +impl<'a> AnyValueDatumExt for AnyValue<'a> { + fn to_datum( + self, + coltype: &ColType, + drop_out_of_category: bool, + ) -> Result { + match (self, coltype) { + (AnyValue::Null, _) => Ok(Datum::Missing), + (AnyValue::String(s), ColType::Categorical { value_map, .. }) => { + if let ValueMap::String(cat_map) = value_map { + if let Some(_cat_idx) = cat_map.ix(s) { + Ok(Datum::Categorical(lace_data::Category::String( + s.to_string(), + ))) + } else { + if drop_out_of_category { + Ok(Datum::Missing) + } else { + Err(ConversionError::ValueNotACategory( + s.to_string(), + )) + } + } + } else { + Err(ConversionError::CategoryNotIndexedByString) + } + } + + ( + AnyValue::StringOwned(s), + ColType::Categorical { value_map, .. }, + ) => { + if let ValueMap::String(cat_map) = value_map { + if let Some(_cat_idx) = cat_map.ix(&s.to_string()) { + Ok(Datum::Categorical(lace_data::Category::String( + s.to_string(), + ))) + } else { + if drop_out_of_category { + Ok(Datum::Missing) + } else { + Err(ConversionError::ValueNotACategory( + s.to_string(), + )) + } + } + } else { + Err(ConversionError::CategoryNotIndexedByString) + } + } + + (AnyValue::Boolean(b), ColType::Categorical { value_map, .. }) => { + if let ValueMap::Bool = value_map { + Ok(Datum::Binary(b)) + } else { + Err(ConversionError::ValueNotACategory(b.to_string())) + } + } + (AnyValue::UInt8(x), ColType::Categorical { value_map, .. }) => { + int_to_category!(x, value_map, drop_out_of_category) + } + (AnyValue::UInt8(x), ColType::Count { .. }) => Ok(Datum::Count( + x.try_into() + .map_err(|_| ConversionError::CountOutOfBounds)?, + )), + (AnyValue::UInt8(x), ColType::StickBreakingDiscrete { .. }) => { + Ok(Datum::Index(x.into())) + } + (AnyValue::UInt16(x), ColType::Categorical { value_map, .. }) => { + int_to_category!(x, value_map, drop_out_of_category) + } + (AnyValue::UInt16(x), ColType::Count { .. }) => Ok(Datum::Count( + x.try_into() + .map_err(|_| ConversionError::CountOutOfBounds)?, + )), + (AnyValue::UInt16(x), ColType::StickBreakingDiscrete { .. }) => { + Ok(Datum::Index(x.into())) + } + (AnyValue::UInt32(x), ColType::Categorical { value_map, .. }) => { + int_to_category!(x, value_map, drop_out_of_category) + } + (AnyValue::UInt32(x), ColType::Count { .. }) => Ok(Datum::Count( + x.try_into() + .map_err(|_| ConversionError::CountOutOfBounds)?, + )), + (AnyValue::UInt32(x), ColType::StickBreakingDiscrete { .. }) => { + Ok(Datum::Index( + x.try_into() + .map_err(|_| ConversionError::IndexOutOfBounds)?, + )) + } + (AnyValue::UInt64(x), ColType::Categorical { value_map, .. }) => { + int_to_category!(x, value_map, drop_out_of_category) + } + (AnyValue::UInt64(x), ColType::Count { .. }) => Ok(Datum::Count( + x.try_into() + .map_err(|_| ConversionError::CountOutOfBounds)?, + )), + (AnyValue::UInt64(x), ColType::StickBreakingDiscrete { .. }) => { + Ok(Datum::Index( + x.try_into() + .map_err(|_| ConversionError::IndexOutOfBounds)?, + )) + } + + (AnyValue::Int8(x), ColType::Categorical { value_map, .. }) => { + int_to_category!(x, value_map, drop_out_of_category) + } + (AnyValue::Int8(x), ColType::Count { .. }) => Ok(Datum::Count( + x.try_into() + .map_err(|_| ConversionError::CountOutOfBounds)?, + )), + (AnyValue::Int8(x), ColType::StickBreakingDiscrete { .. }) => { + Ok(Datum::Index( + x.try_into() + .map_err(|_| ConversionError::IndexOutOfBounds)?, + )) + } + (AnyValue::Int16(x), ColType::Categorical { value_map, .. }) => { + int_to_category!(x, value_map, drop_out_of_category) + } + (AnyValue::Int16(x), ColType::Count { .. }) => Ok(Datum::Count( + x.try_into() + .map_err(|_| ConversionError::CountOutOfBounds)?, + )), + (AnyValue::Int16(x), ColType::StickBreakingDiscrete { .. }) => { + Ok(Datum::Index( + x.try_into() + .map_err(|_| ConversionError::IndexOutOfBounds)?, + )) + } + (AnyValue::Int32(x), ColType::Categorical { value_map, .. }) => { + int_to_category!(x, value_map, drop_out_of_category) + } + (AnyValue::Int32(x), ColType::Count { .. }) => Ok(Datum::Count( + x.try_into() + .map_err(|_| ConversionError::CountOutOfBounds)?, + )), + (AnyValue::Int32(x), ColType::StickBreakingDiscrete { .. }) => { + Ok(Datum::Index( + x.try_into() + .map_err(|_| ConversionError::IndexOutOfBounds)?, + )) + } + (AnyValue::Int64(x), ColType::Categorical { value_map, .. }) => { + int_to_category!(x, value_map, drop_out_of_category) + } + (AnyValue::Int64(x), ColType::Count { .. }) => Ok(Datum::Count( + x.try_into() + .map_err(|_| ConversionError::CountOutOfBounds)?, + )), + (AnyValue::Int64(x), ColType::StickBreakingDiscrete { .. }) => { + Ok(Datum::Index( + x.try_into() + .map_err(|_| ConversionError::IndexOutOfBounds)?, + )) + } + + (AnyValue::Float32(x), ColType::Continuous { .. }) => { + Ok(Datum::Continuous(x.into())) + } + (AnyValue::Float64(x), ColType::Continuous { .. }) => { + Ok(Datum::Continuous(x)) + } + + (av, _) => Err(ConversionError::UnsupportedType(av.to_string())), + } + } +} + +/// Series to collection of `Datum` conversion helper. +pub trait SeriesDatumExt { + /// Convert a `polars::Series` to a `Vec` with a specific coltype. + /// + /// * `coltype` - Column type to convert datum into. + /// * `drop_out_of_category` - Set output to `Datum::Missing` if given value is not contained + /// in the existing category. + fn to_datum_vec( + self, + col_type: &ColType, + drop_out_of_category: bool, + ) -> Result, ConversionError>; + + /// Convert a `polars::Series` to an iterator of `Datum`s. + /// + /// * `coltype` - Column type to convert datum into. + /// * `drop_out_of_category` - Set output to `Datum::Missing` if given value is not contained + /// in the existing category. + fn as_datum_iter( + &self, + col_type: &ColType, + drop_out_of_category: bool, + ) -> impl Iterator>; +} + +impl SeriesDatumExt for Series { + fn to_datum_vec( + self, + col_type: &ColType, + drop_out_of_category: bool, + ) -> Result, ConversionError> { + // XXX: Rechunk is only required because of a polar's design oddity, remove this if polars + // fixes it. + let arr = self.rechunk(); + arr.iter() + .map(|x: AnyValue| x.to_datum(col_type, drop_out_of_category)) + .collect::, _>>() + } + + fn as_datum_iter( + &self, + col_type: &ColType, + drop_out_of_category: bool, + ) -> impl Iterator> { + // XXX: Rechunk is only required because of a polar's design oddity, remove this if polars + // fixes it. + self.iter() + .map(move |x: AnyValue| x.to_datum(col_type, drop_out_of_category)) + } +} + #[macro_export] macro_rules! series_to_opt_vec { ($srs: ident, $X: ty) => {{