Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/anyval to datum #197

Open
wants to merge 2 commits into
base: chore/reorganize-2024-04-03
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions lace/lace_codebook/src/codebook.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use lace_stats::rv::dist::{
};
use polars::prelude::DataFrame;
use serde::{Deserialize, Serialize};
use std::borrow::Borrow;
use std::collections::HashMap;
use std::convert::TryFrom;
use std::fs::File;
Expand Down Expand Up @@ -463,9 +464,12 @@ impl Codebook {
output
}

pub fn col_metadata(&self, col: String) -> Option<&ColMetadata> {
pub fn col_metadata<T>(&self, col: T) -> Option<&ColMetadata>
where
T: Borrow<str>,
{
// self.col_metadata.get(&col)
self.col_metadata.iter().find(|md| md.name == col)
self.col_metadata.iter().find(|md| md.name == *col.borrow())
}

/// Get the number of columns
Expand Down
272 changes: 272 additions & 0 deletions lace/lace_codebook/src/data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,288 @@ use crate::{
};

use lace_consts::rv::prelude::UnitPowerLaw;
use lace_data::{Category, Datum};
use lace_stats::prior::csd::CsdHyper;
use lace_stats::prior::nix::NixHyper;
use lace_stats::prior::pg::PgHyper;
use lace_stats::prior::sbd::SbdHyper;
use polars::datatypes::AnyValue;
use polars::prelude::{CsvReader, DataFrame, DataType, SerReader, Series};
use std::convert::TryFrom;
use std::path::Path;
use thiserror::Error;

pub const DEFAULT_CAT_CUTOFF: u8 = 20;

/// An Error from converting a Polar's AnyValue to a Datum
#[derive(Debug, Error)]
pub enum ConversionError {
#[error("The given value `{0}` is not an existing category.")]
ValueNotACategory(String),
#[error("The category is not indexed by a string.")]
CategoryNotIndexedByString,
#[error("The type `{0}` is not supported by this conversion.")]
UnsupportedType(String),
#[error("The category count is out of the existing bounds.")]
CountOutOfBounds,
#[error("The index is out of the existing bounds.")]
IndexOutOfBounds,
}

pub trait AnyValueDatumExt: Sized {
/// Convert a AnyValue to a Datum with a specific coltype.
///
/// * `coltype` - Column type to convert datum into.
/// * `drop_out_of_category` - Set output to `Datum::Missing` if given value is not contained
/// in the existing category.
fn to_datum(
self,
coltype: &ColType,
drop_out_of_category: bool,
) -> Result<Datum, ConversionError>;
}

macro_rules! int_to_category {
($x: expr, $value_map: expr, $out_to_missing: expr) => {{
let idx: u8 = $x.try_into().map_err(|_| {
ConversionError::ValueNotACategory(format!("{}", $x))
})?;

if let ValueMap::U8(size) = $value_map {
if (idx as usize) < *size {
Ok(Datum::Categorical(Category::U8(idx)))
} else if $out_to_missing {
Ok(Datum::Missing)
} else {
Err(ConversionError::ValueNotACategory(idx.to_string()))
}
} else {
Err(ConversionError::ValueNotACategory(idx.to_string()))
}
}};
}

impl<'a> AnyValueDatumExt for AnyValue<'a> {
fn to_datum(
self,
coltype: &ColType,
drop_out_of_category: bool,
) -> Result<Datum, ConversionError> {
match (self, coltype) {
(AnyValue::Null, _) => Ok(Datum::Missing),
(AnyValue::String(s), ColType::Categorical { value_map, .. }) => {
if let ValueMap::String(cat_map) = value_map {
if let Some(_cat_idx) = cat_map.ix(s) {
Ok(Datum::Categorical(lace_data::Category::String(
s.to_string(),
)))
} else {
if drop_out_of_category {
Ok(Datum::Missing)
} else {
Err(ConversionError::ValueNotACategory(
s.to_string(),
))
}
}
} else {
Err(ConversionError::CategoryNotIndexedByString)
}
}

(
AnyValue::StringOwned(s),
ColType::Categorical { value_map, .. },
) => {
if let ValueMap::String(cat_map) = value_map {
if let Some(_cat_idx) = cat_map.ix(&s.to_string()) {
Ok(Datum::Categorical(lace_data::Category::String(
s.to_string(),
)))
} else {
if drop_out_of_category {
Ok(Datum::Missing)
} else {
Err(ConversionError::ValueNotACategory(
s.to_string(),
))
}
}
} else {
Err(ConversionError::CategoryNotIndexedByString)
}
}

(AnyValue::Boolean(b), ColType::Categorical { value_map, .. }) => {
if let ValueMap::Bool = value_map {
Ok(Datum::Binary(b))
} else {
Err(ConversionError::ValueNotACategory(b.to_string()))
}
}
(AnyValue::UInt8(x), ColType::Categorical { value_map, .. }) => {
int_to_category!(x, value_map, drop_out_of_category)
}
(AnyValue::UInt8(x), ColType::Count { .. }) => Ok(Datum::Count(
x.try_into()
.map_err(|_| ConversionError::CountOutOfBounds)?,
)),
(AnyValue::UInt8(x), ColType::StickBreakingDiscrete { .. }) => {
Ok(Datum::Index(x.into()))
}
(AnyValue::UInt16(x), ColType::Categorical { value_map, .. }) => {
int_to_category!(x, value_map, drop_out_of_category)
}
(AnyValue::UInt16(x), ColType::Count { .. }) => Ok(Datum::Count(
x.try_into()
.map_err(|_| ConversionError::CountOutOfBounds)?,
)),
(AnyValue::UInt16(x), ColType::StickBreakingDiscrete { .. }) => {
Ok(Datum::Index(x.into()))
}
(AnyValue::UInt32(x), ColType::Categorical { value_map, .. }) => {
int_to_category!(x, value_map, drop_out_of_category)
}
(AnyValue::UInt32(x), ColType::Count { .. }) => Ok(Datum::Count(
x.try_into()
.map_err(|_| ConversionError::CountOutOfBounds)?,
)),
(AnyValue::UInt32(x), ColType::StickBreakingDiscrete { .. }) => {
Ok(Datum::Index(
x.try_into()
.map_err(|_| ConversionError::IndexOutOfBounds)?,
))
}
(AnyValue::UInt64(x), ColType::Categorical { value_map, .. }) => {
int_to_category!(x, value_map, drop_out_of_category)
}
(AnyValue::UInt64(x), ColType::Count { .. }) => Ok(Datum::Count(
x.try_into()
.map_err(|_| ConversionError::CountOutOfBounds)?,
)),
(AnyValue::UInt64(x), ColType::StickBreakingDiscrete { .. }) => {
Ok(Datum::Index(
x.try_into()
.map_err(|_| ConversionError::IndexOutOfBounds)?,
))
}

(AnyValue::Int8(x), ColType::Categorical { value_map, .. }) => {
int_to_category!(x, value_map, drop_out_of_category)
}
(AnyValue::Int8(x), ColType::Count { .. }) => Ok(Datum::Count(
x.try_into()
.map_err(|_| ConversionError::CountOutOfBounds)?,
)),
(AnyValue::Int8(x), ColType::StickBreakingDiscrete { .. }) => {
Ok(Datum::Index(
x.try_into()
.map_err(|_| ConversionError::IndexOutOfBounds)?,
))
}
(AnyValue::Int16(x), ColType::Categorical { value_map, .. }) => {
int_to_category!(x, value_map, drop_out_of_category)
}
(AnyValue::Int16(x), ColType::Count { .. }) => Ok(Datum::Count(
x.try_into()
.map_err(|_| ConversionError::CountOutOfBounds)?,
)),
(AnyValue::Int16(x), ColType::StickBreakingDiscrete { .. }) => {
Ok(Datum::Index(
x.try_into()
.map_err(|_| ConversionError::IndexOutOfBounds)?,
))
}
(AnyValue::Int32(x), ColType::Categorical { value_map, .. }) => {
int_to_category!(x, value_map, drop_out_of_category)
}
(AnyValue::Int32(x), ColType::Count { .. }) => Ok(Datum::Count(
x.try_into()
.map_err(|_| ConversionError::CountOutOfBounds)?,
)),
(AnyValue::Int32(x), ColType::StickBreakingDiscrete { .. }) => {
Ok(Datum::Index(
x.try_into()
.map_err(|_| ConversionError::IndexOutOfBounds)?,
))
}
(AnyValue::Int64(x), ColType::Categorical { value_map, .. }) => {
int_to_category!(x, value_map, drop_out_of_category)
}
(AnyValue::Int64(x), ColType::Count { .. }) => Ok(Datum::Count(
x.try_into()
.map_err(|_| ConversionError::CountOutOfBounds)?,
)),
(AnyValue::Int64(x), ColType::StickBreakingDiscrete { .. }) => {
Ok(Datum::Index(
x.try_into()
.map_err(|_| ConversionError::IndexOutOfBounds)?,
))
}

(AnyValue::Float32(x), ColType::Continuous { .. }) => {
Ok(Datum::Continuous(x.into()))
}
(AnyValue::Float64(x), ColType::Continuous { .. }) => {
Ok(Datum::Continuous(x))
}

(av, _) => Err(ConversionError::UnsupportedType(av.to_string())),
}
}
}

/// Series to collection of `Datum` conversion helper.
pub trait SeriesDatumExt {
/// Convert a `polars::Series` to a `Vec<Datum>` with a specific coltype.
///
/// * `coltype` - Column type to convert datum into.
/// * `drop_out_of_category` - Set output to `Datum::Missing` if given value is not contained
/// in the existing category.
fn to_datum_vec(
self,
col_type: &ColType,
drop_out_of_category: bool,
) -> Result<Vec<Datum>, ConversionError>;

/// Convert a `polars::Series` to an iterator of `Datum`s.
///
/// * `coltype` - Column type to convert datum into.
/// * `drop_out_of_category` - Set output to `Datum::Missing` if given value is not contained
/// in the existing category.
fn as_datum_iter(
&self,
col_type: &ColType,
drop_out_of_category: bool,
) -> impl Iterator<Item = Result<Datum, ConversionError>>;
}

impl SeriesDatumExt for Series {
fn to_datum_vec(
self,
col_type: &ColType,
drop_out_of_category: bool,
) -> Result<Vec<Datum>, ConversionError> {
// XXX: Rechunk is only required because of a polar's design oddity, remove this if polars
// fixes it.
let arr = self.rechunk();
arr.iter()
.map(|x: AnyValue| x.to_datum(col_type, drop_out_of_category))
.collect::<Result<Vec<Datum>, _>>()
}

fn as_datum_iter(
&self,
col_type: &ColType,
drop_out_of_category: bool,
) -> impl Iterator<Item = Result<Datum, ConversionError>> {
// XXX: Rechunk is only required because of a polar's design oddity, remove this if polars
// fixes it.
self.iter()
.map(move |x: AnyValue| x.to_datum(col_type, drop_out_of_category))
}
}

#[macro_export]
macro_rules! series_to_opt_vec {
($srs: ident, $X: ty) => {{
Expand Down
17 changes: 13 additions & 4 deletions lace/lace_codebook/src/value_map.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use lace_data::Category;
use serde::{Deserialize, Serialize};
use std::borrow::Borrow;
use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet};
use std::hash::Hash;
use thiserror::Error;
Expand Down Expand Up @@ -30,16 +31,24 @@ where
self.to_cat.is_empty()
}

pub fn ix(&self, cat: &T) -> Option<usize> {
self.to_ix.get(cat).cloned()
pub fn ix<Q>(&self, cat: &Q) -> Option<usize>
where
T: Borrow<Q>,
Q: Hash + Eq + ?Sized,
{
self.to_ix.get(cat.borrow()).copied()
}

pub fn category(&self, ix: usize) -> T {
self.to_cat[ix].clone()
}

pub fn contains_cat(&self, cat: &T) -> bool {
self.to_ix.contains_key(cat)
pub fn contains_cat<Q>(&self, cat: &Q) -> bool
where
T: Borrow<Q>,
Q: Hash + Eq + ?Sized,
{
self.to_ix.contains_key(cat.borrow())
}

pub(crate) fn add(&mut self, value: T) {
Expand Down
Loading