Skip to content

Commit

Permalink
Merge pull request #258 from stan-dev/factor
Browse files Browse the repository at this point in the history
Factor rvars
  • Loading branch information
paul-buerkner authored Dec 20, 2022
2 parents 336e31f + 4d4e4a3 commit ef688e3
Show file tree
Hide file tree
Showing 60 changed files with 3,398 additions and 259 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/rcmdcheck.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@ jobs:
- {os: macOS-latest, r: 'devel', suggested_check: TRUE}
- {os: macOS-latest, r: 'release', suggested_check: TRUE}
- {os: windows-latest, r: 'release', suggested_check: TRUE}
- {os: windows-latest, r: '3.6', suggested_check: FALSE}
- {os: ubuntu-18.04, r: 'devel', rspm: "https://packagemanager.rstudio.com/cran/__linux__/bionic/latest", http-user-agent: "R/4.0.0 (ubuntu-18.04) R (4.0.0 x86_64-pc-linux-gnu x86_64 linux-gnu) on GitHub Actions" , suggested_check: TRUE}
- {os: ubuntu-18.04, r: 'release', rspm: "https://packagemanager.rstudio.com/cran/__linux__/bionic/latest", suggested_check: TRUE}
- {os: ubuntu-18.04, r: 'oldrel', rspm: "https://packagemanager.rstudio.com/cran/__linux__/bionic/latest", suggested_check: FALSE}
- {os: ubuntu-18.04, r: '3.6', rspm: "https://packagemanager.rstudio.com/cran/__linux__/bionic/latest", suggested_check: FALSE}

env:
R_REMOTES_NO_ERRORS_FROM_WARNINGS: true
Expand Down
6 changes: 3 additions & 3 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,10 @@ Imports:
methods,
abind,
checkmate,
rlang (>= 0.4.7),
rlang (>= 1.0.6),
stats,
tibble (>= 3.0.0),
vctrs,
vctrs (>= 0.5.0),
tensorA,
pillar,
distributional,
Expand All @@ -53,5 +53,5 @@ LazyData: false
URL: https://mc-stan.org/posterior/, https://discourse.mc-stan.org/
BugReports: https://github.com/stan-dev/posterior/issues
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.2.1
RoxygenNote: 7.2.2
VignetteBuilder: knitr
84 changes: 83 additions & 1 deletion NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ S3method("[[",rvar)
S3method("[[<-",rvar)
S3method("dim<-",rvar)
S3method("dimnames<-",rvar)
S3method("levels<-",rvar)
S3method("names<-",rvar)
S3method("variables<-",draws_array)
S3method("variables<-",draws_df)
Expand All @@ -23,7 +24,10 @@ S3method(.subset_draws,draws_list)
S3method(.subset_draws,draws_matrix)
S3method(.subset_draws,draws_rvars)
S3method(Math,rvar)
S3method(Math,rvar_factor)
S3method(Ops,rvar)
S3method(Ops,rvar_factor)
S3method(Ops,rvar_ordered)
S3method(Pr,default)
S3method(Pr,logical)
S3method(Pr,rvar)
Expand All @@ -32,6 +36,7 @@ S3method(all,equal.rvar)
S3method(all,rvar)
S3method(any,rvar)
S3method(anyDuplicated,rvar)
S3method(anyDuplicated,rvar_factor)
S3method(anyNA,rvar)
S3method(aperm,rvar)
S3method(as.data.frame,rvar)
Expand Down Expand Up @@ -91,9 +96,13 @@ S3method(bind_draws,draws_list)
S3method(bind_draws,draws_matrix)
S3method(bind_draws,draws_rvars)
S3method(bind_draws,list)
S3method(broadcast_and_bind_rvars,rvar)
S3method(broadcast_and_bind_rvars,rvar_factor)
S3method(c,rvar)
S3method(cbind,rvar)
S3method(cdf,rvar)
S3method(cdf,rvar_factor)
S3method(cdf,rvar_ordered)
S3method(chain_ids,"NULL")
S3method(chain_ids,draws_array)
S3method(chain_ids,draws_df)
Expand All @@ -102,8 +111,11 @@ S3method(chain_ids,draws_matrix)
S3method(chain_ids,draws_rvars)
S3method(chol,rvar)
S3method(density,rvar)
S3method(density,rvar_factor)
S3method(dim,rvar)
S3method(dimnames,rvar)
S3method(dissent,default)
S3method(dissent,rvar)
S3method(draw_ids,"NULL")
S3method(draw_ids,draws_array)
S3method(draw_ids,draws_df)
Expand All @@ -112,6 +124,9 @@ S3method(draw_ids,draws_matrix)
S3method(draw_ids,draws_rvars)
S3method(draw_ids,rvar)
S3method(duplicated,rvar)
S3method(duplicated,rvar_factor)
S3method(entropy,default)
S3method(entropy,rvar)
S3method(ess_basic,default)
S3method(ess_basic,rvar)
S3method(ess_bulk,default)
Expand All @@ -132,6 +147,9 @@ S3method(extract_variable_matrix,draws)
S3method(extract_variable_matrix,draws_rvars)
S3method(format,rvar)
S3method(format_glimpse,rvar)
S3method(get_rvar_class,default)
S3method(get_rvar_class,factor)
S3method(get_rvar_class,ordered)
S3method(is.array,rvar)
S3method(is.finite,rvar)
S3method(is.infinite,rvar)
Expand All @@ -149,6 +167,9 @@ S3method(length,rvar)
S3method(levels,rvar)
S3method(mad,default)
S3method(mad,rvar)
S3method(mad,rvar_ordered)
S3method(match,default)
S3method(match,rvar)
S3method(max,rvar)
S3method(mcse_mean,default)
S3method(mcse_mean,rvar)
Expand All @@ -165,6 +186,8 @@ S3method(merge_chains,draws_matrix)
S3method(merge_chains,draws_rvars)
S3method(merge_chains,rvar)
S3method(min,rvar)
S3method(modal_category,default)
S3method(modal_category,rvar)
S3method(mutate_variables,draws_array)
S3method(mutate_variables,draws_df)
S3method(mutate_variables,draws_list)
Expand Down Expand Up @@ -209,6 +232,8 @@ S3method(print,draws_rvars)
S3method(print,rvar)
S3method(prod,rvar)
S3method(quantile,rvar)
S3method(quantile,rvar_factor)
S3method(quantile,rvar_ordered)
S3method(quantile2,default)
S3method(quantile2,rvar)
S3method(range,rvar)
Expand Down Expand Up @@ -262,6 +287,7 @@ S3method(t,rvar)
S3method(thin_draws,draws)
S3method(thin_draws,rvar)
S3method(unique,rvar)
S3method(unique,rvar_factor)
S3method(var,default)
S3method(var,rvar)
S3method(variables,"NULL")
Expand All @@ -274,33 +300,79 @@ S3method(variance,draws_array)
S3method(variance,draws_matrix)
S3method(variance,rvar)
S3method(vec_cast,character.rvar)
S3method(vec_cast,character.rvar_factor)
S3method(vec_cast,character.rvar_ordered)
S3method(vec_cast,distribution.rvar)
S3method(vec_cast,rvar.character)
S3method(vec_cast,rvar.distribution)
S3method(vec_cast,rvar.double)
S3method(vec_cast,rvar.factor)
S3method(vec_cast,rvar.integer)
S3method(vec_cast,rvar.logical)
S3method(vec_cast,rvar.ordered)
S3method(vec_cast,rvar.rvar)
S3method(vec_cast,rvar.rvar_factor)
S3method(vec_cast,rvar.rvar_ordered)
S3method(vec_cast,rvar_factor.character)
S3method(vec_cast,rvar_factor.double)
S3method(vec_cast,rvar_factor.factor)
S3method(vec_cast,rvar_factor.integer)
S3method(vec_cast,rvar_factor.logical)
S3method(vec_cast,rvar_factor.ordered)
S3method(vec_cast,rvar_factor.rvar)
S3method(vec_cast,rvar_factor.rvar_factor)
S3method(vec_cast,rvar_factor.rvar_ordered)
S3method(vec_cast,rvar_ordered.character)
S3method(vec_cast,rvar_ordered.double)
S3method(vec_cast,rvar_ordered.factor)
S3method(vec_cast,rvar_ordered.integer)
S3method(vec_cast,rvar_ordered.logical)
S3method(vec_cast,rvar_ordered.ordered)
S3method(vec_cast,rvar_ordered.rvar)
S3method(vec_cast,rvar_ordered.rvar_factor)
S3method(vec_cast,rvar_ordered.rvar_ordered)
S3method(vec_proxy,rvar)
S3method(vec_ptype,rvar)
S3method(vec_ptype,rvar_factor)
S3method(vec_ptype,rvar_ordered)
S3method(vec_ptype2,character.rvar_factor)
S3method(vec_ptype2,character.rvar_ordered)
S3method(vec_ptype2,distribution.rvar)
S3method(vec_ptype2,double.rvar)
S3method(vec_ptype2,factor.rvar_factor)
S3method(vec_ptype2,factor.rvar_ordered)
S3method(vec_ptype2,integer.rvar)
S3method(vec_ptype2,logical.rvar)
S3method(vec_ptype2,ordered.rvar_factor)
S3method(vec_ptype2,ordered.rvar_ordered)
S3method(vec_ptype2,rvar.distribution)
S3method(vec_ptype2,rvar.double)
S3method(vec_ptype2,rvar.integer)
S3method(vec_ptype2,rvar.logical)
S3method(vec_ptype2,rvar.rvar)
S3method(vec_ptype2,rvar_factor.character)
S3method(vec_ptype2,rvar_factor.factor)
S3method(vec_ptype2,rvar_factor.ordered)
S3method(vec_ptype2,rvar_factor.rvar_factor)
S3method(vec_ptype2,rvar_ordered.character)
S3method(vec_ptype2,rvar_ordered.factor)
S3method(vec_ptype2,rvar_ordered.ordered)
S3method(vec_ptype2,rvar_ordered.rvar_ordered)
S3method(vec_ptype_abbr,rvar)
S3method(vec_ptype_abbr,rvar_factor)
S3method(vec_ptype_abbr,rvar_ordered)
S3method(vec_ptype_full,rvar)
S3method(vec_restore,rvar)
S3method(vec_restore,rvar_factor)
S3method(vec_restore,rvar_ordered)
S3method(weight_draws,draws_array)
S3method(weight_draws,draws_df)
S3method(weight_draws,draws_list)
S3method(weight_draws,draws_matrix)
S3method(weight_draws,draws_rvars)
S3method(weights,draws)
export("%**%")
export("%in%")
export("draws_of<-")
export("variables<-")
export(E)
Expand All @@ -312,6 +384,9 @@ export(as_draws_list)
export(as_draws_matrix)
export(as_draws_rvars)
export(as_rvar)
export(as_rvar_factor)
export(as_rvar_numeric)
export(as_rvar_ordered)
export(autocorrelation)
export(autocovariance)
export(bind_draws)
Expand All @@ -321,6 +396,7 @@ export(default_convergence_measures)
export(default_mcse_measures)
export(default_summary_measures)
export(diag)
export(dissent)
export(draw_ids)
export(draws_array)
export(draws_df)
Expand All @@ -329,6 +405,7 @@ export(draws_matrix)
export(draws_of)
export(draws_rvars)
export(drop)
export(entropy)
export(ess_basic)
export(ess_bulk)
export(ess_mean)
Expand All @@ -347,13 +424,17 @@ export(is_draws_list)
export(is_draws_matrix)
export(is_draws_rvars)
export(is_rvar)
export(is_rvar_factor)
export(is_rvar_ordered)
export(iteration_ids)
export(mad)
export(match)
export(mcse_mean)
export(mcse_median)
export(mcse_quantile)
export(mcse_sd)
export(merge_chains)
export(modal_category)
export(mutate_variables)
export(nchains)
export(ndraws)
Expand All @@ -375,6 +456,7 @@ export(rvar)
export(rvar_all)
export(rvar_any)
export(rvar_apply)
export(rvar_factor)
export(rvar_is_finite)
export(rvar_is_infinite)
export(rvar_is_na)
Expand All @@ -384,6 +466,7 @@ export(rvar_max)
export(rvar_mean)
export(rvar_median)
export(rvar_min)
export(rvar_ordered)
export(rvar_prod)
export(rvar_quantile)
export(rvar_range)
Expand Down Expand Up @@ -438,7 +521,6 @@ importFrom(utils,lsf.str)
importFrom(utils,str)
importFrom(vctrs,new_vctr)
importFrom(vctrs,vec_cast)
importFrom(vctrs,vec_chop)
importFrom(vctrs,vec_proxy)
importFrom(vctrs,vec_ptype)
importFrom(vctrs,vec_ptype2)
Expand Down
18 changes: 18 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,26 @@

### Enhancements

* Added new `rvar_factor()` and `rvar_ordered()` subtypes of `rvar()` that work
analogously to `factor()` and `ordered()` (#149). See the new section on
`rvar_factor`s in `vignette("rvar")`.
* The `draws_df()`, `draws_list()`, and `draws_rvars()` formats now support
discrete variables stored as `factor`s / `ordered`s (or `rvar_factor`s /
`rvar_ordered`s). If converted to formats that do not support discrete
variables with named levels (`draws_matrix()` and `draws_array()`),
factor-like variables are converted to `numeric`s.
* Made `match()` and `%in%` generic and added support for `rvar`s to both
functions.
* Added `modal_category()`, `entropy()`, and `dissent()` functions for
summarizing discrete draws.
* Allow lists of draws objects to be passed as the first argument to
`bind_draws()` (#253).
* `print.rvar()` and `format.rvar()` now default to a smaller number of
significant digits in more cases, including when printing in data frames.
This is controlled by the new `"posterior.digits"` option (see
`help("posterior-package")`).
* Implemented faster `vec_proxy.rvar()` and `vec_restore.rvar()`, improving
performance of `rvar`s in `tibble`s (and anywhere else `vctrs` is used).


# posterior 1.3.1
Expand Down
26 changes: 26 additions & 0 deletions R/as_draws.R
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,32 @@ check_draws_object <- function(x) {
x
}

#' Check that all variables in an object are numeric, optionally converting
#' non-numeric variables to numeric with a warning. Used when converting to
#' draws_array or draws_matrix formats (which don't support non-numeric
#' variables).
#' @param x A draws_df, draws_list, or draws_rvars object
#' @param to Name of the target format; only used as the prefix of the
#'   warning message.
#' @param is_non_numeric function that checks if a variable is non-numeric.
#'   It is applied to each element of `x` and must return a single logical.
#' @param convert convert non-numeric variables to numeric? If `FALSE`, the
#'   warning is still issued but `x` is returned unmodified.
#' @return `x`, with the flagged variables passed through `as.numeric()` when
#'   `convert = TRUE`; otherwise `x` unchanged.
#' @noRd
check_variables_are_numeric <- function(
  x, to = "draws_array",
  is_non_numeric = function(x_i) !is.numeric(x_i) && !is.logical(x_i),
  convert = TRUE
) {

  non_numeric_cols <- vapply(x, is_non_numeric, logical(1))
  if (!any(non_numeric_cols)) {
    # nothing to warn about or convert; returning early avoids a no-op
    # two-index assignment, which fails on inputs that are plain lists
    return(x)
  }

  warning_no_call(
    to,
    " does not support non-numeric variables (e.g., factors). Converting non-numeric variables to numeric."
  )
  if (convert) {
    # NOTE(review): the two-index assignment presumably requires a
    # data-frame-like x (e.g. draws_df) -- confirm before calling with
    # convert = TRUE on list-like inputs
    x[, non_numeric_cols] <- lapply(unclass(x)[non_numeric_cols], as.numeric)
  }
  x
}

# define default variable names
# use the 'unique' naming strategy of tibble
# @param nvariables number of variables
Expand Down
7 changes: 7 additions & 0 deletions R/as_draws_array.R
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ as_draws_array.draws_df <- function(x, ...) {
}
iterations <- iteration_ids(x)
chains <- chain_ids(x)
x <- check_variables_are_numeric(x, to = "draws_array")
out <- vector("list", length(chains))
for (i in seq_along(out)) {
if (length(chains) == 1) {
Expand Down Expand Up @@ -92,6 +93,12 @@ as_draws_array.draws_rvars <- function(x, ...) {
return(empty_draws_array(variables(x)))
}

x <- check_variables_are_numeric(
x, to = "draws_array", is_non_numeric = is_rvar_factor, convert = FALSE
)

# cbind discards class information when applied to vectors, which converts
# the underlying factors to numeric
draws <- do.call(cbind, lapply(seq_along(x), function(i) {
# flatten each rvar so it only has two dimensions: draws and variables
# this also collapses indices into variable names in the format "var[i,j,k,...]"
Expand Down
Loading

0 comments on commit ef688e3

Please sign in to comment.