Adding Probit link to BART and BCF #164

Open · wants to merge 7 commits into main
167 changes: 136 additions & 31 deletions R/bart.R
@@ -58,6 +58,7 @@
#' - `sigma2_leaf_scale` Scale parameter in the `IG(sigma2_leaf_shape, sigma2_leaf_scale)` leaf node parameter variance model. Calibrated internally as `0.5/num_trees` if not set here.
#' - `keep_vars` Vector of variable names or column indices denoting variables that should be included in the forest. Default: `NULL`.
#' - `drop_vars` Vector of variable names or column indices denoting variables that should be excluded from the forest. Default: `NULL`. If both `drop_vars` and `keep_vars` are set, `drop_vars` will be ignored.
#' - `probit_outcome_model` Whether or not the outcome should be modeled as explicitly binary via a probit link. If `TRUE`, `y` must only contain the values `0` and `1`. Default: `FALSE`.
#'
#' @param variance_forest_params (Optional) A list of variance forest model parameters, each of which has a default value processed internally, so this argument list is optional.
#'
@@ -125,7 +126,8 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
min_samples_leaf = 5, max_depth = 10,
sample_sigma2_leaf = TRUE, sigma2_leaf_init = NULL,
sigma2_leaf_shape = 3, sigma2_leaf_scale = NULL,
keep_vars = NULL, drop_vars = NULL
keep_vars = NULL, drop_vars = NULL,
probit_outcome_model = FALSE
)
mean_forest_params_updated <- preprocessParams(
mean_forest_params_default, mean_forest_params
@@ -173,6 +175,7 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
b_leaf <- mean_forest_params_updated$sigma2_leaf_scale
keep_vars_mean <- mean_forest_params_updated$keep_vars
drop_vars_mean <- mean_forest_params_updated$drop_vars
probit_outcome_model <- mean_forest_params_updated$probit_outcome_model

# 3. Variance forest parameters
num_trees_variance <- variance_forest_params_updated$num_trees
@@ -462,50 +465,118 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train

# Determine whether a test set is provided
has_test = !is.null(X_test)

# Preliminary runtime checks for probit link
if (!include_mean_forest) {
probit_outcome_model <- FALSE
}
if (probit_outcome_model) {
if (length(unique(y_train)) != 2) {
stop("You specified a probit outcome model, but supplied an outcome that does not have exactly 2 unique values")
}
unique_outcomes <- sort(unique(y_train))
if (!all(unique_outcomes == c(0, 1))) {
stop("You specified a probit outcome model, but the 2 unique outcome values are not 0 and 1")
}
if (include_variance_forest) {
stop("We do not support heteroskedasticity with a probit link")
}
if (sample_sigma_global) {
warning("Global error variance will not be sampled with a probit link as it is fixed at 1")
sample_sigma_global <- F
}
}

# Handle standardization, prior calibration, and initialization of forest
# differently for binary and continuous outcomes
if (probit_outcome_model) {
    # Compute a probit-scale offset and fix the scale to 1
    y_bar_train <- qnorm(mean(y_train))
    y_std_train <- 1

    # Set a pseudo outcome by subtracting mean(y_train) from y_train
    resid_train <- y_train - mean(y_train)

    # Set initial values of root nodes to 0.0 (on the probit scale)
    init_val_mean <- 0.0

    # Calibrate priors for sigma^2 and tau
    # Set sigma2_init to 1, ignoring any default provided
    sigma2_init <- 1.0
    # Skip variance_forest_init, since variance forests are not supported with a probit link
    b_leaf <- 1/(num_trees_mean)
    if (has_basis) {
        if (ncol(leaf_basis_train) > 1) {
            if (is.null(sigma_leaf_init)) sigma_leaf_init <- diag(2/(num_trees_mean), ncol(leaf_basis_train))
            if (!is.matrix(sigma_leaf_init)) {
                current_leaf_scale <- as.matrix(diag(sigma_leaf_init, ncol(leaf_basis_train)))
            } else {
                current_leaf_scale <- sigma_leaf_init
            }
        } else {
            if (is.null(sigma_leaf_init)) sigma_leaf_init <- as.matrix(2/(num_trees_mean))
            if (!is.matrix(sigma_leaf_init)) {
                current_leaf_scale <- as.matrix(diag(sigma_leaf_init, 1))
            } else {
                current_leaf_scale <- sigma_leaf_init
            }
        }
    } else {
        if (is.null(sigma_leaf_init)) sigma_leaf_init <- as.matrix(2/(num_trees_mean))
        if (!is.matrix(sigma_leaf_init)) {
            current_leaf_scale <- as.matrix(diag(sigma_leaf_init, 1))
        } else {
            current_leaf_scale <- sigma_leaf_init
        }
    }
    current_sigma2 <- sigma2_init
} else {
    # Only standardize if the user requested it
    if (standardize) {
        y_bar_train <- mean(y_train)
        y_std_train <- sd(y_train)
    } else {
        y_bar_train <- 0
        y_std_train <- 1
    }

    # Compute the standardized residual
    resid_train <- (y_train - y_bar_train)/y_std_train

    # Compute initial value of root nodes in the mean forest
    init_val_mean <- mean(resid_train)

    # Calibrate priors for sigma^2 and tau
    if (is.null(sigma2_init)) sigma2_init <- 1.0*var(resid_train)
    if (is.null(variance_forest_init)) variance_forest_init <- 1.0*var(resid_train)
    if (is.null(b_leaf)) b_leaf <- var(resid_train)/(2*num_trees_mean)
    if (has_basis) {
        if (ncol(leaf_basis_train) > 1) {
            if (is.null(sigma_leaf_init)) sigma_leaf_init <- diag(var(resid_train)/(num_trees_mean), ncol(leaf_basis_train))
            if (!is.matrix(sigma_leaf_init)) {
                current_leaf_scale <- as.matrix(diag(sigma_leaf_init, ncol(leaf_basis_train)))
            } else {
                current_leaf_scale <- sigma_leaf_init
            }
        } else {
            if (is.null(sigma_leaf_init)) sigma_leaf_init <- as.matrix(var(resid_train)/(num_trees_mean))
            if (!is.matrix(sigma_leaf_init)) {
                current_leaf_scale <- as.matrix(diag(sigma_leaf_init, 1))
            } else {
                current_leaf_scale <- sigma_leaf_init
            }
        }
    } else {
        if (is.null(sigma_leaf_init)) sigma_leaf_init <- as.matrix(var(resid_train)/(num_trees_mean))
        if (!is.matrix(sigma_leaf_init)) {
            current_leaf_scale <- as.matrix(diag(sigma_leaf_init, 1))
        } else {
            current_leaf_scale <- sigma_leaf_init
        }
    }
    current_sigma2 <- sigma2_init
}
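# Calibration rationale: fixing sigma^2 = 1 matches the unit-variance latent errors of
# the probit model, and a leaf variance scale on the order of 1/num_trees_mean keeps the
# prior variance of the sum of num_trees_mean leaf means near 1, i.e., on the same scale
# as the latent outcome.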


# Determine leaf model type
if (!has_basis) leaf_model_mean_forest <- 0
else if (ncol(leaf_basis_train) == 1) leaf_model_mean_forest <- 1
@@ -634,7 +705,6 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
# Initialize the leaves of each tree in the variance forest
if (include_variance_forest) {
active_forest_variance$prepare_for_sampler(forest_dataset_train, outcome_train, forest_model_variance, leaf_model_variance_forest, variance_forest_init)
}

# Run GFR (warm start) if specified
@@ -652,6 +722,21 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
}

if (include_mean_forest) {
if (probit_outcome_model) {
# Sample latent probit variable, z | -
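# Full conditional (Albert and Chib 1993): z | y, mu ~ N(mu, 1), truncated to
# (-Inf, 0) when y = 0 and to (0, Inf) when y = 1. The draws below use the
# inverse-CDF method: sample u uniformly on the truncated range of the normal
# CDF, then map back through qnorm.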
forest_pred <- active_forest_mean$predict(forest_dataset_train) + y_bar_train
mu0 <- forest_pred[y_train == 0]
mu1 <- forest_pred[y_train == 1]
u0 <- runif(sum(y_train == 0), 0, pnorm(0 - mu0))
u1 <- runif(sum(y_train == 1), pnorm(0 - mu1), 1)
resid_train[y_train == 0] <- mu0 + qnorm(u0)
resid_train[y_train == 1] <- mu1 + qnorm(u1)

# Update outcome
outcome_train$update_data(resid_train - forest_pred)
}

# Sample mean forest
forest_model_mean$sample_one_iteration(
forest_dataset = forest_dataset_train, residual = outcome_train, forest_samples = forest_samples_mean,
active_forest = active_forest_mean, rng = rng, forest_model_config = forest_model_config_mean,
@@ -791,6 +876,20 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
}

if (include_mean_forest) {
if (probit_outcome_model) {
# Sample latent probit variable, z | -
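# Same inverse-CDF truncated-normal draw as in the grow-from-root loop above
# (see Albert and Chib 1993)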
forest_pred <- active_forest_mean$predict(forest_dataset_train) + y_bar_train
mu0 <- forest_pred[y_train == 0]
mu1 <- forest_pred[y_train == 1]
u0 <- runif(sum(y_train == 0), 0, pnorm(0 - mu0))
u1 <- runif(sum(y_train == 1), pnorm(0 - mu1), 1)
resid_train[y_train == 0] <- mu0 + qnorm(u0)
resid_train[y_train == 1] <- mu1 + qnorm(u1)

# Update outcome
outcome_train$update_data(resid_train - forest_pred)
}

forest_model_mean$sample_one_iteration(
forest_dataset = forest_dataset_train, residual = outcome_train, forest_samples = forest_samples_mean,
active_forest = active_forest_mean, rng = rng, forest_model_config = forest_model_config_mean,
@@ -915,7 +1014,8 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
"sample_sigma_global" = sample_sigma_global,
"sample_sigma_leaf" = sample_sigma_leaf,
"include_mean_forest" = include_mean_forest,
"include_variance_forest" = include_variance_forest
"include_variance_forest" = include_variance_forest,
"probit_outcome_model" = probit_outcome_model
)
result <- list(
"model_params" = model_params,
@@ -1257,6 +1357,7 @@ saveBARTModelToJson <- function(object){
jsonobj$add_scalar("num_chains", object$model_params$num_chains)
jsonobj$add_scalar("keep_every", object$model_params$keep_every)
jsonobj$add_boolean("requires_basis", object$model_params$requires_basis)
jsonobj$add_boolean("probit_outcome_model", object$model_params$probit_outcome_model)
if (object$model_params$sample_sigma_global) {
jsonobj$add_vector("sigma2_global_samples", object$sigma2_global_samples, "parameters")
}
@@ -1448,6 +1549,8 @@ createBARTModelFromJson <- function(json_object){
model_params[["num_chains"]] <- json_object$get_scalar("num_chains")
model_params[["keep_every"]] <- json_object$get_scalar("keep_every")
model_params[["requires_basis"]] <- json_object$get_boolean("requires_basis")
model_params[["probit_outcome_model"]] <- json_object$get_boolean("probit_outcome_model")

output[["model_params"]] <- model_params

# Unpack sampled parameters
@@ -1650,6 +1753,7 @@ createBARTModelFromCombinedJson <- function(json_object_list){
model_params[["num_covariates"]] <- json_object_default$get_scalar("num_covariates")
model_params[["num_basis"]] <- json_object_default$get_scalar("num_basis")
model_params[["requires_basis"]] <- json_object_default$get_boolean("requires_basis")
model_params[["probit_outcome_model"]] <- json_object_default$get_boolean("probit_outcome_model")
model_params[["num_chains"]] <- json_object_default$get_scalar("num_chains")
model_params[["keep_every"]] <- json_object_default$get_scalar("keep_every")

@@ -1805,6 +1909,7 @@ createBARTModelFromCombinedJsonString <- function(json_string_list){
model_params[["num_chains"]] <- json_object_default$get_scalar("num_chains")
model_params[["keep_every"]] <- json_object_default$get_scalar("keep_every")
model_params[["requires_basis"]] <- json_object_default$get_boolean("requires_basis")
model_params[["probit_outcome_model"]] <- json_object_default$get_boolean("probit_outcome_model")

# Combine values that are sample-specific
for (i in 1:length(json_object_list)) {
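For reference, a minimal R usage sketch of the new option, analogous to the Python demo below. The simulated data are hypothetical, and the num_gfr/num_mcmc arguments are assumed to match the existing bart() sampler interface:

# Simulate a binary outcome from a latent probit model
n <- 500
X <- matrix(runif(n * 5), ncol = 5)
z <- 3 * (X[, 1] - 0.5) + rnorm(n)
y <- as.numeric(z > 0)

# Fit BART with the probit outcome model enabled
bart_fit <- bart(
    X_train = X, y_train = y,
    num_gfr = 10, num_mcmc = 100,
    mean_forest_params = list(probit_outcome_model = TRUE)
)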
59 changes: 59 additions & 0 deletions demo/debug/classification.py
@@ -0,0 +1,59 @@
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_curve, auc

from stochtree import BARTModel

# RNG
rng = np.random.default_rng()

# Generate covariates
n = 1000
p_X = 10
X = rng.uniform(0, 1, (n, p_X))


# Define the outcome mean function
def outcome_mean(X):
return np.where(
(X[:, 0] >= 0.0) & (X[:, 0] < 0.25),
-7.5 * X[:, 1],
np.where(
(X[:, 0] >= 0.25) & (X[:, 0] < 0.5),
-2.5 * X[:, 1],
np.where((X[:, 0] >= 0.5) & (X[:, 0] < 0.75), 2.5 * X[:, 1], 7.5 * X[:, 1]),
),
)


# Generate outcome
epsilon = rng.normal(0, 1, n)
z = outcome_mean(X) + epsilon
y = np.where(z >= 0, 1, 0)

# Test-train split
sample_inds = np.arange(n)
train_inds, test_inds = train_test_split(sample_inds, test_size=0.5)
X_train = X[train_inds, :]
X_test = X[test_inds, :]
z_train = z[train_inds]
z_test = z[test_inds]
y_train = y[train_inds]
y_test = y[test_inds]

# Fit Probit BART
bart_model = BARTModel()
general_params = {"num_chains": 1}
mean_forest_params = {"probit_outcome_model": True}
bart_model.sample(
X_train=X_train,
y_train=y_train,
X_test=X_test,
num_gfr=10,
num_mcmc=100,
general_params=general_params,
mean_forest_params=mean_forest_params
)
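The demo ends after sampling; class probabilities follow by applying the normal CDF to the probit-scale predictions and averaging over posterior draws. A hedged R sketch, continuing from the bart_fit example above (the y_hat_train element name is an assumption about the returned sample matrix):

# Pr(y = 1 | x), averaged over retained MCMC draws
p_hat <- rowMeans(pnorm(bart_fit$y_hat_train))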