diff --git a/R/bart.R b/R/bart.R index 96815850..340e53ae 100644 --- a/R/bart.R +++ b/R/bart.R @@ -58,6 +58,7 @@ #' - `sigma2_leaf_scale` Scale parameter in the `IG(sigma2_leaf_shape, sigma2_leaf_scale)` leaf node parameter variance model. Calibrated internally as `0.5/num_trees` if not set here. #' - `keep_vars` Vector of variable names or column indices denoting variables that should be included in the forest. Default: `NULL`. #' - `drop_vars` Vector of variable names or column indices denoting variables that should be excluded from the forest. Default: `NULL`. If both `drop_vars` and `keep_vars` are set, `drop_vars` will be ignored. +#' - `probit_outcome_model` Whether or not the outcome should be modeled as explicitly binary via a probit link. If `TRUE`, `y` must only contain the values `0` and `1`. Default: `FALSE`. #' #' @param variance_forest_params (Optional) A list of variance forest model parameters, each of which has a default value processed internally, so this argument list is optional. #' @@ -125,7 +126,8 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train min_samples_leaf = 5, max_depth = 10, sample_sigma2_leaf = TRUE, sigma2_leaf_init = NULL, sigma2_leaf_shape = 3, sigma2_leaf_scale = NULL, - keep_vars = NULL, drop_vars = NULL + keep_vars = NULL, drop_vars = NULL, + probit_outcome_model = FALSE ) mean_forest_params_updated <- preprocessParams( mean_forest_params_default, mean_forest_params @@ -173,6 +175,7 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train b_leaf <- mean_forest_params_updated$sigma2_leaf_scale keep_vars_mean <- mean_forest_params_updated$keep_vars drop_vars_mean <- mean_forest_params_updated$drop_vars + probit_outcome_model <- mean_forest_params_updated$probit_outcome_model # 3. 
Variance forest parameters num_trees_variance <- variance_forest_params_updated$num_trees @@ -462,50 +465,118 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train # Determine whether a test set is provided has_test = !is.null(X_test) + + # Preliminary runtime checks for probit link + if (!include_mean_forest) { + probit_outcome_model <- FALSE + } + if (probit_outcome_model) { + if (!(length(unique(y_train)) == 2)) { + stop("You specified a probit outcome model, but supplied an outcome with more than 2 unique values") + } + unique_outcomes <- sort(unique(y_train)) + if (!(all(unique_outcomes == c(0,1)))) { + stop("You specified a probit outcome model, but supplied an outcome with 2 unique values other than 0 and 1") + } + if (include_variance_forest) { + stop("We do not support heteroskedasticity with a probit link") + } + if (sample_sigma_global) { + warning("Global error variance will not be sampled with a probit link as it is fixed at 1") + sample_sigma_global <- F + } + } - # Standardize outcome separately for test and train - if (standardize) { - y_bar_train <- mean(y_train) - y_std_train <- sd(y_train) - } else { - y_bar_train <- 0 + # Handle standardization, prior calibration, and initialization of forest + # differently for binary and continuous outcomes + if (probit_outcome_model) { + # Compute a probit-scale offset and fix scale to 1 + y_bar_train <- qnorm(mean(y_train)) y_std_train <- 1 - } - resid_train <- (y_train-y_bar_train)/y_std_train - - # Compute initial value of root nodes in mean forest - init_val_mean <- mean(resid_train) - # Calibrate priors for sigma^2 and tau - if (is.null(sigma2_init)) sigma2_init <- 1.0*var(resid_train) - if (is.null(variance_forest_init)) variance_forest_init <- 1.0*var(resid_train) - if (is.null(b_leaf)) b_leaf <- var(resid_train)/(2*num_trees_mean) - if (has_basis) { - if (ncol(leaf_basis_train) > 1) { - if (is.null(sigma_leaf_init)) sigma_leaf_init <- 
diag(var(resid_train)/(num_trees_mean), ncol(leaf_basis_train)) - if (!is.matrix(sigma_leaf_init)) { - current_leaf_scale <- as.matrix(diag(sigma_leaf_init, ncol(leaf_basis_train))) + # Set a pseudo outcome by subtracting mean(y_train) from y_train + resid_train <- y_train - mean(y_train) + + # Set initial values of root nodes to 0.0 (in probit scale) + init_val_mean <- 0.0 + + # Calibrate priors for sigma^2 and tau + # Set sigma2_init to 1, ignoring default provided + sigma2_init <- 1.0 + # Skip variance_forest_init, since variance forests are not supported with probit link + b_leaf <- 1/(num_trees_mean) + if (has_basis) { + if (ncol(leaf_basis_train) > 1) { + if (is.null(sigma_leaf_init)) sigma_leaf_init <- diag(2/(num_trees_mean), ncol(leaf_basis_train)) + if (!is.matrix(sigma_leaf_init)) { + current_leaf_scale <- as.matrix(diag(sigma_leaf_init, ncol(leaf_basis_train))) + } else { + current_leaf_scale <- sigma_leaf_init + } } else { - current_leaf_scale <- sigma_leaf_init + if (is.null(sigma_leaf_init)) sigma_leaf_init <- as.matrix(2/(num_trees_mean)) + if (!is.matrix(sigma_leaf_init)) { + current_leaf_scale <- as.matrix(diag(sigma_leaf_init, 1)) + } else { + current_leaf_scale <- sigma_leaf_init + } } } else { - if (is.null(sigma_leaf_init)) sigma_leaf_init <- as.matrix(var(resid_train)/(num_trees_mean)) + if (is.null(sigma_leaf_init)) sigma_leaf_init <- as.matrix(2/(num_trees_mean)) if (!is.matrix(sigma_leaf_init)) { current_leaf_scale <- as.matrix(diag(sigma_leaf_init, 1)) } else { current_leaf_scale <- sigma_leaf_init } } + current_sigma2 <- sigma2_init } else { - if (is.null(sigma_leaf_init)) sigma_leaf_init <- as.matrix(var(resid_train)/(num_trees_mean)) - if (!is.matrix(sigma_leaf_init)) { - current_leaf_scale <- as.matrix(diag(sigma_leaf_init, 1)) + # Only standardize if user requested + if (standardize) { + y_bar_train <- mean(y_train) + y_std_train <- sd(y_train) + } else { + y_bar_train <- 0 + y_std_train <- 1 + } + + # Compute residual value + 
resid_train <- (y_train-y_bar_train)/y_std_train + + # Compute initial value of root nodes in mean forest + init_val_mean <- mean(resid_train) + + # Calibrate priors for sigma^2 and tau + if (is.null(sigma2_init)) sigma2_init <- 1.0*var(resid_train) + if (is.null(variance_forest_init)) variance_forest_init <- 1.0*var(resid_train) + if (is.null(b_leaf)) b_leaf <- var(resid_train)/(2*num_trees_mean) + if (has_basis) { + if (ncol(leaf_basis_train) > 1) { + if (is.null(sigma_leaf_init)) sigma_leaf_init <- diag(var(resid_train)/(num_trees_mean), ncol(leaf_basis_train)) + if (!is.matrix(sigma_leaf_init)) { + current_leaf_scale <- as.matrix(diag(sigma_leaf_init, ncol(leaf_basis_train))) + } else { + current_leaf_scale <- sigma_leaf_init + } + } else { + if (is.null(sigma_leaf_init)) sigma_leaf_init <- as.matrix(var(resid_train)/(num_trees_mean)) + if (!is.matrix(sigma_leaf_init)) { + current_leaf_scale <- as.matrix(diag(sigma_leaf_init, 1)) + } else { + current_leaf_scale <- sigma_leaf_init + } + } } else { - current_leaf_scale <- sigma_leaf_init + if (is.null(sigma_leaf_init)) sigma_leaf_init <- as.matrix(var(resid_train)/(num_trees_mean)) + if (!is.matrix(sigma_leaf_init)) { + current_leaf_scale <- as.matrix(diag(sigma_leaf_init, 1)) + } else { + current_leaf_scale <- sigma_leaf_init + } } + current_sigma2 <- sigma2_init } - current_sigma2 <- sigma2_init - + # Determine leaf model type if (!has_basis) leaf_model_mean_forest <- 0 else if (ncol(leaf_basis_train) == 1) leaf_model_mean_forest <- 1 @@ -634,7 +705,6 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train # Initialize the leaves of each tree in the variance forest if (include_variance_forest) { active_forest_variance$prepare_for_sampler(forest_dataset_train, outcome_train, forest_model_variance, leaf_model_variance_forest, variance_forest_init) - } # Run GFR (warm start) if specified @@ -652,6 +722,21 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train 
} if (include_mean_forest) { + if (probit_outcome_model) { + # Sample latent probit variable, z | - + forest_pred <- active_forest_mean$predict(forest_dataset_train) + y_bar_train + mu0 <- forest_pred[y_train == 0] + mu1 <- forest_pred[y_train == 1] + u0 <- runif(sum(y_train == 0), 0, pnorm(0 - mu0)) + u1 <- runif(sum(y_train == 1), pnorm(0 - mu1), 1) + resid_train[y_train==0] <- mu0 + qnorm(u0) + resid_train[y_train==1] <- mu1 + qnorm(u1) + + # Update outcome + outcome_train$update_data(resid_train - forest_pred) + } + + # Sample mean forest forest_model_mean$sample_one_iteration( forest_dataset = forest_dataset_train, residual = outcome_train, forest_samples = forest_samples_mean, active_forest = active_forest_mean, rng = rng, forest_model_config = forest_model_config_mean, @@ -791,6 +876,20 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train } if (include_mean_forest) { + if (probit_outcome_model) { + # Sample latent probit variable, z | - + forest_pred <- active_forest_mean$predict(forest_dataset_train) + y_bar_train + mu0 <- forest_pred[y_train == 0] + mu1 <- forest_pred[y_train == 1] + u0 <- runif(sum(y_train == 0), 0, pnorm(0 - mu0)) + u1 <- runif(sum(y_train == 1), pnorm(0 - mu1), 1) + resid_train[y_train==0] <- mu0 + qnorm(u0) + resid_train[y_train==1] <- mu1 + qnorm(u1) + + # Update outcome + outcome_train$update_data(resid_train - forest_pred) + } + forest_model_mean$sample_one_iteration( forest_dataset = forest_dataset_train, residual = outcome_train, forest_samples = forest_samples_mean, active_forest = active_forest_mean, rng = rng, forest_model_config = forest_model_config_mean, @@ -915,7 +1014,8 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train "sample_sigma_global" = sample_sigma_global, "sample_sigma_leaf" = sample_sigma_leaf, "include_mean_forest" = include_mean_forest, - "include_variance_forest" = include_variance_forest + "include_variance_forest" = include_variance_forest, + 
"probit_outcome_model" = probit_outcome_model ) result <- list( "model_params" = model_params, @@ -1257,6 +1357,7 @@ saveBARTModelToJson <- function(object){ jsonobj$add_scalar("num_chains", object$model_params$num_chains) jsonobj$add_scalar("keep_every", object$model_params$keep_every) jsonobj$add_boolean("requires_basis", object$model_params$requires_basis) + jsonobj$add_boolean("probit_outcome_model", object$model_params$probit_outcome_model) if (object$model_params$sample_sigma_global) { jsonobj$add_vector("sigma2_global_samples", object$sigma2_global_samples, "parameters") } @@ -1448,6 +1549,8 @@ createBARTModelFromJson <- function(json_object){ model_params[["num_chains"]] <- json_object$get_scalar("num_chains") model_params[["keep_every"]] <- json_object$get_scalar("keep_every") model_params[["requires_basis"]] <- json_object$get_boolean("requires_basis") + model_params[["probit_outcome_model"]] <- json_object$get_boolean("probit_outcome_model") + output[["model_params"]] <- model_params # Unpack sampled parameters @@ -1650,6 +1753,7 @@ createBARTModelFromCombinedJson <- function(json_object_list){ model_params[["num_covariates"]] <- json_object_default$get_scalar("num_covariates") model_params[["num_basis"]] <- json_object_default$get_scalar("num_basis") model_params[["requires_basis"]] <- json_object_default$get_boolean("requires_basis") + model_params[["probit_outcome_model"]] <- json_object_default$get_boolean("probit_outcome_model") model_params[["num_chains"]] <- json_object_default$get_scalar("num_chains") model_params[["keep_every"]] <- json_object_default$get_scalar("keep_every") @@ -1805,6 +1909,7 @@ createBARTModelFromCombinedJsonString <- function(json_string_list){ model_params[["num_chains"]] <- json_object_default$get_scalar("num_chains") model_params[["keep_every"]] <- json_object_default$get_scalar("keep_every") model_params[["requires_basis"]] <- json_object_default$get_boolean("requires_basis") + model_params[["probit_outcome_model"]] <- 
json_object_default$get_boolean("probit_outcome_model") # Combine values that are sample-specific for (i in 1:length(json_object_list)) { diff --git a/demo/debug/classification.py b/demo/debug/classification.py new file mode 100644 index 00000000..4d303289 --- /dev/null +++ b/demo/debug/classification.py @@ -0,0 +1,59 @@ +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import seaborn as sns +from sklearn.model_selection import train_test_split +from sklearn.metrics import roc_curve, auc + +from stochtree import BARTModel + +# RNG +rng = np.random.default_rng() + +# Generate covariates +n = 1000 +p_X = 10 +X = rng.uniform(0, 1, (n, p_X)) + + +# Define the outcome mean function +def outcome_mean(X): + return np.where( + (X[:, 0] >= 0.0) & (X[:, 0] < 0.25), + -7.5 * X[:, 1], + np.where( + (X[:, 0] >= 0.25) & (X[:, 0] < 0.5), + -2.5 * X[:, 1], + np.where((X[:, 0] >= 0.5) & (X[:, 0] < 0.75), 2.5 * X[:, 1], 7.5 * X[:, 1]), + ), + ) + + +# Generate outcome +epsilon = rng.normal(0, 1, n) +z = outcome_mean(X) + epsilon +y = np.where(z >= 0, 1, 0) + +# Test-train split +sample_inds = np.arange(n) +train_inds, test_inds = train_test_split(sample_inds, test_size=0.5) +X_train = X[train_inds, :] +X_test = X[test_inds, :] +z_train = z[train_inds] +z_test = z[test_inds] +y_train = y[train_inds] +y_test = y[test_inds] + +# Fit Probit BART +bart_model = BARTModel() +general_params = {"num_chains": 1} +mean_forest_params = {"probit_outcome_model": True} +bart_model.sample( + X_train=X_train, + y_train=y_train, + X_test=X_test, + num_gfr=10, + num_mcmc=100, + general_params=general_params, + mean_forest_params=mean_forest_params +) diff --git a/demo/notebooks/supervised_learning_classification.ipynb b/demo/notebooks/supervised_learning_classification.ipynb new file mode 100644 index 00000000..e88b1b7b --- /dev/null +++ b/demo/notebooks/supervised_learning_classification.ipynb @@ -0,0 +1,221 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, 
+ "source": [ + "# Supervised Learning (Classification)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Load necessary libraries" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import pandas as pd\n", + "import seaborn as sns\n", + "from sklearn.model_selection import train_test_split\n", + "from sklearn.metrics import roc_curve, auc\n", + "\n", + "from stochtree import BARTModel" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Generate sample data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# RNG\n", + "rng = np.random.default_rng()\n", + "\n", + "# Generate covariates\n", + "n = 1000\n", + "p_X = 10\n", + "X = rng.uniform(0, 1, (n, p_X))\n", + "\n", + "\n", + "# Define the outcome mean function\n", + "def outcome_mean(X):\n", + " return np.where(\n", + " (X[:, 0] >= 0.0) & (X[:, 0] < 0.25),\n", + " -7.5 * X[:, 1],\n", + " np.where(\n", + " (X[:, 0] >= 0.25) & (X[:, 0] < 0.5),\n", + " -2.5 * X[:, 1],\n", + " np.where((X[:, 0] >= 0.5) & (X[:, 0] < 0.75), 2.5 * X[:, 1], 7.5 * X[:, 1]),\n", + " ),\n", + " )\n", + "\n", + "\n", + "# Generate outcome\n", + "epsilon = rng.normal(0, 1, n)\n", + "z = outcome_mean(X) + epsilon\n", + "y = np.where(z >= 0, 1, 0)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Test-train split" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sample_inds = np.arange(n)\n", + "train_inds, test_inds = train_test_split(sample_inds, test_size=0.5)\n", + "X_train = X[train_inds, :]\n", + "X_test = X[test_inds, :]\n", + "z_train = z[train_inds]\n", + "z_test = z[test_inds]\n", + "y_train = y[train_inds]\n", + "y_test = y[test_inds]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": 
[ + "Run BART" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "num_gfr = 10\n", + "num_mcmc = 100\n", + "bart_model = BARTModel()\n", + "general_params = {\"num_chains\": 1}\n", + "mean_forest_params = {\"probit_outcome_model\": True}\n", + "bart_model.sample(\n", + " X_train=X_train,\n", + " y_train=y_train,\n", + " X_test=X_test,\n", + " num_gfr=num_gfr,\n", + " num_mcmc=num_mcmc,\n", + " general_params=general_params,\n", + " mean_forest_params=mean_forest_params\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Since we've simulated this data, we can compare the true latent continuous outcome variable to the probit-scale predictions for a test set." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plt.scatter(x=np.mean(bart_model.y_hat_test,axis=1), y=z_test)\n", + "plt.axline((0, 0), slope=1, color=\"black\", linestyle=(0, (3, 3)))\n", + "plt.xlabel(\"Predicted\")\n", + "plt.ylabel(\"Actual\")\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "On non-simulated datasets, the first thing we would evaluate is the prediction accuracy." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "preds_test = np.mean(bart_model.y_hat_test,axis=1) > 0\n", + "print(f\"Test set accuracy: {np.mean(y_test == preds_test):.3f}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can also compute the [ROC curve](https://en.wikipedia.org/wiki/Receiver_operating_characteristic) for every posterior sample, as well as the ROC of the posterior mean." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "num_gfr = 10\n", + "num_mcmc = 100\n", + "fpr_list = list()\n", + "tpr_list = list()\n", + "threshold_list = list()\n", + "for i in range(num_mcmc):\n", + " fpr, tpr, thresholds = roc_curve(y_test, bart_model.y_hat_test[:,i], pos_label=1)\n", + " fpr_list.append(fpr)\n", + " tpr_list.append(tpr)\n", + " threshold_list.append(thresholds)\n", + "probit_preds_test_mean = np.mean(bart_model.y_hat_test,axis=1)\n", + "fpr_mean, tpr_mean, thresholds_mean = roc_curve(y_test, probit_preds_test_mean, pos_label=1)\n", + "for i in range(num_mcmc):\n", + " plt.plot(fpr_list[i], tpr_list[i], color = 'blue', linestyle='solid', linewidth = 0.9)\n", + "plt.plot(fpr_mean, tpr_mean, color = 'black', linestyle='dashed', linewidth = 1.75)\n", + "plt.axline((0, 0), slope=1, color=\"red\", linestyle='dashed', linewidth=1.5)\n", + "plt.xlabel(\"False Positive Rate\")\n", + "plt.ylabel(\"True Positive Rate\")\n", + "plt.xlim(0, 1)\n", + "plt.ylim(0, 1)\n", + "plt.show()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.9" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/stochtree/bart.py b/stochtree/bart.py index dff3362f..fa012319 100644 --- a/stochtree/bart.py +++ b/stochtree/bart.py @@ -9,6 +9,7 @@ import numpy as np import pandas as pd +from scipy.stats import norm from .config import ForestModelConfig, GlobalModelConfig from .data import Dataset, Residual @@ -145,6 +146,7 @@ def sample( * `sigma2_leaf_scale` (`float`): Scale parameter in the `IG(sigma2_leaf_shape, sigma2_leaf_scale)` leaf node parameter variance model. 
Calibrated internally as `0.5/num_trees` if not set here. * `keep_vars` (`list` or `np.array`): Vector of variable names or column indices denoting variables that should be included in the mean forest. Defaults to `None`. * `drop_vars` (`list` or `np.array`): Vector of variable names or column indices denoting variables that should be excluded from the mean forest. Defaults to `None`. If both `drop_vars` and `keep_vars` are set, `drop_vars` will be ignored. + * `probit_outcome_model` (`bool`): Whether or not the outcome should be modeled as explicitly binary via a probit link. If `True`, `y` must only contain the values `0` and `1`. Default: `False`. variance_forest_params : dict, optional Dictionary of variance forest model parameters, each of which has a default value processed internally, so this argument is optional. @@ -180,7 +182,7 @@ def sample( "sigma2_global_shape": 0, "sigma2_global_scale": 0, "variable_weights": None, - "random_seed": -1, + "random_seed": None, "keep_burnin": False, "keep_gfr": False, "keep_every": 1, @@ -203,6 +205,7 @@ def sample( "sigma2_leaf_scale": None, "keep_vars": None, "drop_vars": None, + "probit_outcome_model": False, } mean_forest_params_updated = _preprocess_params( mean_forest_params_default, mean_forest_params @@ -253,6 +256,7 @@ def sample( b_leaf = mean_forest_params_updated["sigma2_leaf_scale"] keep_vars_mean = mean_forest_params_updated["keep_vars"] drop_vars_mean = mean_forest_params_updated["drop_vars"] + self.probit_outcome_model = mean_forest_params_updated["probit_outcome_model"] # 3. 
Variance forest parameters num_trees_variance = variance_forest_params_updated["num_trees"] @@ -710,25 +714,42 @@ def sample( [variable_subset_variance.count(i) == 0 for i in original_var_indices] ] = 0 - # Scale outcome - if self.standardize: - self.y_bar = np.squeeze(np.mean(y_train)) - self.y_std = np.squeeze(np.std(y_train)) - else: - self.y_bar = 0 - self.y_std = 1 - resid_train = (y_train - self.y_bar) / self.y_std - - # Calibrate priors for global sigma^2 and sigma_leaf (don't use regression initializer for warm-start or XBART) - if not sigma2_init: - sigma2_init = 1.0 * np.var(resid_train) - if not variance_forest_leaf_init: - variance_forest_leaf_init = 0.6 * np.var(resid_train) - current_sigma2 = sigma2_init - self.sigma2_init = sigma2_init - if self.include_mean_forest: + # Preliminary runtime checks for probit link + if not self.include_mean_forest: + self.probit_outcome_model = False + if self.probit_outcome_model: + if np.unique(y_train).size != 2: + raise ValueError("You specified a probit outcome model, but supplied an outcome with more than 2 unique values") + unique_outcomes = np.squeeze(np.unique(y_train)) + if not np.array_equal(unique_outcomes, [0,1]): + raise ValueError("You specified a probit outcome model, but supplied an outcome with 2 unique values other than 0 and 1") + if self.include_variance_forest: + raise ValueError("We do not support heteroskedasticity with a probit link") + if sample_sigma_global: + warnings.warn("Global error variance will not be sampled with a probit link as it is fixed at 1") + sample_sigma_global = False + + # Handle standardization, prior calibration, and initialization of forest + # differently for binary and continuous outcomes + if self.probit_outcome_model: + # Compute a probit-scale offset and fix scale to 1 + self.y_bar = norm.ppf(np.squeeze(np.mean(y_train))) + self.y_std = 1.0 + + # Set a pseudo outcome by subtracting mean(y_train) from y_train + resid_train = y_train - np.squeeze(np.mean(y_train)) + + 
# Set initial values of root nodes to 0.0 (in probit scale) + init_val_mean = 0.0 + + # Calibrate priors for sigma^2 and tau + # Set sigma2_init to 1, ignoring default provided + sigma2_init = 1.0 + current_sigma2 = sigma2_init + self.sigma2_init = sigma2_init + # Skip variance_forest_init, since variance forests are not supported with probit link b_leaf = ( - np.squeeze(np.var(resid_train)) / num_trees_mean + 1.0 / num_trees_mean if b_leaf is None else b_leaf ) @@ -737,7 +758,7 @@ def sample( current_leaf_scale = np.zeros((self.num_basis, self.num_basis), dtype=float) np.fill_diagonal( current_leaf_scale, - np.squeeze(np.var(resid_train)) / num_trees_mean, + 2.0 / num_trees_mean, ) elif isinstance(sigma_leaf, float): current_leaf_scale = np.zeros((self.num_basis, self.num_basis), dtype=float) @@ -763,7 +784,7 @@ def sample( else: if sigma_leaf is None: current_leaf_scale = np.array( - [[np.squeeze(np.var(resid_train)) / num_trees_mean]] + [[2.0 / num_trees_mean]] ) elif isinstance(sigma_leaf, float): current_leaf_scale = np.array([[sigma_leaf]]) @@ -786,17 +807,98 @@ def sample( "sigma_leaf must be either a scalar or a 2d numpy array" ) else: - current_leaf_scale = np.array([[1.0]]) - if self.include_variance_forest: - if not a_forest: - a_forest = num_trees_variance / a_0**2 + 0.5 - if not b_forest: - b_forest = num_trees_variance / a_0**2 - else: - if not a_forest: - a_forest = 1.0 - if not b_forest: - b_forest = 1.0 + # Standardize if requested + if self.standardize: + self.y_bar = np.squeeze(np.mean(y_train)) + self.y_std = np.squeeze(np.std(y_train)) + else: + self.y_bar = 0 + self.y_std = 1 + + # Compute residual value + resid_train = (y_train - self.y_bar) / self.y_std + + # Compute initial value of root nodes in mean forest + init_val_mean = np.squeeze(np.mean(resid_train)) + + # Calibrate priors for global sigma^2 and sigma_leaf + if not sigma2_init: + sigma2_init = 1.0 * np.var(resid_train) + if not variance_forest_leaf_init: + variance_forest_leaf_init 
= 0.6 * np.var(resid_train) + current_sigma2 = sigma2_init + self.sigma2_init = sigma2_init + if self.include_mean_forest: + b_leaf = ( + np.squeeze(np.var(resid_train)) / num_trees_mean + if b_leaf is None + else b_leaf + ) + if self.has_basis: + if sigma_leaf is None: + current_leaf_scale = np.zeros((self.num_basis, self.num_basis), dtype=float) + np.fill_diagonal( + current_leaf_scale, + np.squeeze(np.var(resid_train)) / num_trees_mean, + ) + elif isinstance(sigma_leaf, float): + current_leaf_scale = np.zeros((self.num_basis, self.num_basis), dtype=float) + np.fill_diagonal(current_leaf_scale, sigma_leaf) + elif isinstance(sigma_leaf, np.ndarray): + if sigma_leaf.ndim != 2: + raise ValueError( + "sigma_leaf must be a 2d symmetric numpy array if provided in matrix form" + ) + if sigma_leaf.shape[0] != sigma_leaf.shape[1]: + raise ValueError( + "sigma_leaf must be a 2d symmetric numpy array if provided in matrix form" + ) + if sigma_leaf.shape[0] != self.num_basis: + raise ValueError( + "sigma_leaf must be a 2d symmetric numpy array with its dimensionality matching the basis dimension" + ) + current_leaf_scale = sigma_leaf + else: + raise ValueError( + "sigma_leaf must be either a scalar or a 2d symmetric numpy array" + ) + else: + if sigma_leaf is None: + current_leaf_scale = np.array( + [[np.squeeze(np.var(resid_train)) / num_trees_mean]] + ) + elif isinstance(sigma_leaf, float): + current_leaf_scale = np.array([[sigma_leaf]]) + elif isinstance(sigma_leaf, np.ndarray): + if sigma_leaf.ndim != 2: + raise ValueError( + "sigma_leaf must be a 2d symmetric numpy array if provided in matrix form" + ) + if sigma_leaf.shape[0] != sigma_leaf.shape[1]: + raise ValueError( + "sigma_leaf must be a 2d symmetric numpy array if provided in matrix form" + ) + if sigma_leaf.shape[0] != 1: + raise ValueError( + "sigma_leaf must be a 1x1 numpy array for this leaf model" + ) + current_leaf_scale = sigma_leaf + else: + raise ValueError( + "sigma_leaf must be either a scalar or a 2d 
numpy array" + ) + else: + current_leaf_scale = np.array([[1.0]]) + if self.include_variance_forest: + if not a_forest: + a_forest = num_trees_variance / a_0**2 + 0.5 + if not b_forest: + b_forest = num_trees_variance / a_0**2 + else: + if not a_forest: + a_forest = 1.0 + if not b_forest: + b_forest = 1.0 # Runtime checks on RFX group ids self.has_rfx = False @@ -894,11 +996,13 @@ def sample( # Residual residual_train = Residual(resid_train) - # C++ random number generator + # C++ and Numpy random number generator if random_seed is None: cpp_rng = RNG(-1) + self.rng = np.random.default_rng() else: cpp_rng = RNG(random_seed) + self.rng = np.random.default_rng(random_seed) # Set variance leaf model type (currently only one option) leaf_model_variance_forest = 3 @@ -1018,8 +1122,32 @@ def sample( keep_sample = True if keep_sample: sample_counter += 1 - # Sample the mean forest if self.include_mean_forest: + if self.probit_outcome_model: + # Sample latent probit variable z | - + forest_pred = active_forest_mean.predict(forest_dataset_train) + mu0 = forest_pred[y_train[:,0] == 0] + mu1 = forest_pred[y_train[:,0] == 1] + n0 = np.sum(y_train[:,0] == 0) + n1 = np.sum(y_train[:,0] == 1) + u0 = self.rng.uniform( + low=0.0, + high=norm.cdf(0 - mu0), + size=n0, + ) + u1 = self.rng.uniform( + low=norm.cdf(0 - mu1), + high=1.0, + size=n1, + ) + resid_train[y_train[:,0] == 0,0] = mu0 + norm.ppf(u0) + resid_train[y_train[:,0] == 1,0] = mu1 + norm.ppf(u1) + + # Update outcome + new_outcome = np.squeeze(resid_train) - forest_pred + residual_train.update_data(new_outcome) + + # Sample the mean forest forest_sampler_mean.sample_one_iteration( self.forest_container_mean, active_forest_mean, @@ -1183,8 +1311,33 @@ def sample( keep_sample = False if keep_sample: sample_counter += 1 - # Sample the mean forest + if self.include_mean_forest: + if self.probit_outcome_model: + # Sample latent probit variable z | - + forest_pred = active_forest_mean.predict(forest_dataset_train) + mu0 = 
forest_pred[y_train[:,0] == 0] + mu1 = forest_pred[y_train[:,0] == 1] + n0 = np.sum(y_train[:,0] == 0) + n1 = np.sum(y_train[:,0] == 1) + u0 = self.rng.uniform( + low=0.0, + high=norm.cdf(0 - mu0), + size=n0, + ) + u1 = self.rng.uniform( + low=norm.cdf(0 - mu1), + high=1.0, + size=n1, + ) + resid_train[y_train[:,0] == 0,0] = mu0 + norm.ppf(u0) + resid_train[y_train[:,0] == 1,0] = mu1 + norm.ppf(u1) + + # Update outcome + new_outcome = np.squeeze(resid_train) - forest_pred + residual_train.update_data(new_outcome) + + # Sample the mean forest forest_sampler_mean.sample_one_iteration( self.forest_container_mean, active_forest_mean, @@ -1686,6 +1839,7 @@ def to_json(self) -> str: bart_json.add_integer("num_samples", self.num_samples) bart_json.add_integer("num_basis", self.num_basis) bart_json.add_boolean("requires_basis", self.has_basis) + bart_json.add_boolean("probit_outcome_model", self.probit_outcome_model) # Add parameter samples if self.sample_sigma_global: @@ -1757,6 +1911,7 @@ def from_json(self, json_string: str) -> None: self.num_samples = bart_json.get_integer("num_samples") self.num_basis = bart_json.get_integer("num_basis") self.has_basis = bart_json.get_boolean("requires_basis") + self.probit_outcome_model = bart_json.get_boolean("probit_outcome_model") # Unpack parameter samples if self.sample_sigma_global: @@ -1865,6 +2020,7 @@ def from_json_string_list(self, json_string_list: list[str]) -> None: self.num_samples = json_object_default.get_integer("num_samples") self.num_basis = json_object_default.get_integer("num_basis") self.has_basis = json_object_default.get_boolean("requires_basis") + self.probit_outcome_model = json_object_default.get_boolean("probit_outcome_model") # Unpack parameter samples if self.sample_sigma_global: diff --git a/stochtree/bcf.py b/stochtree/bcf.py index 49699f70..4ad3a841 100644 --- a/stochtree/bcf.py +++ b/stochtree/bcf.py @@ -72,7 +72,6 @@ class BCFModel: def __init__(self) -> None: # Internal flag for whether the 
sample() method has been run self.sampled = False - self.rng = np.random.default_rng() def sample( self, @@ -216,7 +215,7 @@ def sample( "adaptive_coding": True, "control_coding_init": -0.5, "treated_coding_init": 0.5, - "random_seed": -1, + "random_seed": None, "keep_burnin": False, "keep_gfr": False, "keep_every": 1, @@ -1384,11 +1383,13 @@ def sample( # Residual residual_train = Residual(resid_train) - # C++ random number generator + # C++ and numpy random number generator if random_seed is None: cpp_rng = RNG(-1) + self.rng = np.random.default_rng() else: cpp_rng = RNG(random_seed) + self.rng = np.random.default_rng(random_seed) # Sampling data structures global_model_config = GlobalModelConfig(global_error_variance=current_sigma2) diff --git a/stochtree/data.py b/stochtree/data.py index a29e80f5..8cbe76e0 100644 --- a/stochtree/data.py +++ b/stochtree/data.py @@ -175,4 +175,4 @@ def update_data(self, new_vector: np.array) -> None: Univariate numpy array of new residual values. """ n = new_vector.size - self.residual_cpp.UpdateData(new_vector, n) + self.residual_cpp.ReplaceData(new_vector, n) diff --git a/vignettes/BayesianSupervisedLearning.Rmd b/vignettes/BayesianSupervisedLearning.Rmd index 2b9337c3..6bfbf85e 100644 --- a/vignettes/BayesianSupervisedLearning.Rmd +++ b/vignettes/BayesianSupervisedLearning.Rmd @@ -327,4 +327,188 @@ plot(rowMeans(bart_model_root$y_hat_test), y_test, abline(0,1,col="red",lty=2,lwd=2.5) ``` +# Demo 4: Partitioned Linear Model with Probit Outcome Model + +## Simulation + +Here, we generate data from a simple partitioned linear model. 
+
+```{r}
+# Simulate a latent outcome z from a partitioned linear model; y = 1 when z > 0
+n <- 1000
+p_x <- 100
+X <- matrix(runif(n * p_x), ncol = p_x)
+f_X <- (
+  ((0 <= X[, 1]) & (0.25 > X[, 1])) * (-7.5 * X[, 2]) +
+  ((0.25 <= X[, 1]) & (0.5 > X[, 1])) * (-2.5 * X[, 2]) +
+  ((0.5 <= X[, 1]) & (0.75 > X[, 1])) * (2.5 * X[, 2]) +
+  ((0.75 <= X[, 1]) & (1 > X[, 1])) * (7.5 * X[, 2])
+)
+z <- f_X + rnorm(n, 0, 1)
+y <- (z > 0) * 1.0
+
+# Split data into test and train sets (z is kept only to evaluate the latent fit)
+test_set_pct <- 0.5
+n_test <- round(test_set_pct * n)
+n_train <- n - n_test
+test_inds <- sort(sample(seq_len(n), n_test, replace = FALSE))
+train_inds <- setdiff(seq_len(n), test_inds)
+X_test <- as.data.frame(X[test_inds, ])
+X_train <- as.data.frame(X[train_inds, ])
+z_test <- z[test_inds]
+z_train <- z[train_inds]
+y_test <- y[test_inds]
+y_train <- y[train_inds]
+```
+
+## Sampling and Analysis
+
+### Warmstart
+
+We first sample from an ensemble model of $y \mid X$ using "warm-start"
+initialization samples (@he2023stochastic). This is the default in
+`stochtree`.
+
+```{r}
+num_gfr <- 10  # number of warm-start initialization samples
+num_burnin <- 0
+num_mcmc <- 100
+num_samples <- num_gfr + num_burnin + num_mcmc
+general_params <- list(sample_sigma2_global = FALSE)  # probit link fixes the error variance at 1
+mean_forest_params <- list(sample_sigma2_leaf = FALSE, num_trees = 100,
+                           probit_outcome_model = TRUE)
+bart_model_warmstart <- stochtree::bart(
+  X_train = X_train, y_train = y_train, X_test = X_test,
+  num_gfr = num_gfr, num_burnin = num_burnin, num_mcmc = num_mcmc,
+  general_params = general_params, mean_forest_params = mean_forest_params
+)
+```
+
+Since we've simulated this data, we can compare the true latent continuous outcome variable to the probit-scale predictions for a test set.
+
+```{r}
+plot(rowMeans(bart_model_warmstart$y_hat_test), z_test,
+     pch = 16, cex = 0.75, xlab = "pred", ylab = "actual")
+abline(0, 1, col = "red", lty = 2, lwd = 2.5)
+```
+
+On non-simulated datasets, the first thing we would evaluate is the prediction accuracy.
+
+```{r}
+preds_test <- rowMeans(bart_model_warmstart$y_hat_test) > 0  # 0 on the probit scale = probability 0.5
+mean(preds_test == y_test)
+```
+
+We can also compute the [ROC curve](https://en.wikipedia.org/wiki/Receiver_operating_characteristic) for every posterior sample, as well as the ROC of the posterior mean.
+
+```{r}
+num_thresholds <- 1000
+thresholds <- seq(0.001, 0.999, length.out = num_thresholds)
+tpr_mean <- rep(NA, num_thresholds)
+fpr_mean <- rep(NA, num_thresholds)
+tpr_samples <- matrix(NA, num_thresholds, num_mcmc)
+fpr_samples <- matrix(NA, num_thresholds, num_mcmc)
+yhat_samples <- bart_model_warmstart$y_hat_test
+yhat_mean <- rowMeans(yhat_samples)
+n_positive <- sum(y_test)  # class totals do not depend on the threshold
+n_negative <- sum(y_test == 0)
+for (i in seq_len(num_thresholds)) {
+  is_above_threshold_samples <- yhat_samples > qnorm(thresholds[i])  # probability cutoff mapped to probit scale
+  is_above_threshold_mean <- yhat_mean > qnorm(thresholds[i])
+  y_above_threshold_mean <- y_test[is_above_threshold_mean]
+  tpr_mean[i] <- sum(y_above_threshold_mean) / n_positive
+  fpr_mean[i] <- sum(y_above_threshold_mean == 0) / n_negative
+  for (j in seq_len(num_mcmc)) {
+    y_above_threshold <- y_test[is_above_threshold_samples[, j]]
+    tpr_samples[i, j] <- sum(y_above_threshold) / n_positive
+    fpr_samples[i, j] <- sum(y_above_threshold == 0) / n_negative
+  }
+}
+
+for (i in seq_len(num_mcmc)) {
+  if (i == 1) {
+    plot(fpr_samples[, i], tpr_samples[, i], type = "l", col = "blue", lwd = 1, lty = 1,
+         xlab = "False positive rate", ylab = "True positive rate")
+  } else {
+    lines(fpr_samples[, i], tpr_samples[, i], col = "blue", lwd = 1, lty = 1)
+  }
+}
+lines(fpr_mean, tpr_mean, col = "black", lwd = 3, lty = 3)
+```
+
+Note that the nonlinearity of the ROC function means that the ROC curve of the posterior mean lies above most of the individual posterior sample ROC curves (which would not be the case if we had simply taken the mean of the ROC curves).
+
+### BART MCMC without Warmstart
+
+Next, we sample from this ensemble model without any warm-start initialization.
+
+```{r}
+num_gfr <- 0  # no warm-start initialization samples
+num_burnin <- 100
+num_mcmc <- 100
+num_samples <- num_gfr + num_burnin + num_mcmc
+general_params <- list(sample_sigma2_global = FALSE)  # probit link fixes the error variance at 1
+mean_forest_params <- list(sample_sigma2_leaf = FALSE, num_trees = 100,
+                           probit_outcome_model = TRUE)
+bart_model_root <- stochtree::bart(
+  X_train = X_train, y_train = y_train, X_test = X_test,
+  num_gfr = num_gfr, num_burnin = num_burnin, num_mcmc = num_mcmc,
+  general_params = general_params, mean_forest_params = mean_forest_params
+)
+```
+
+Since we've simulated this data, we can compare the true latent continuous outcome variable to the probit-scale predictions for a test set.
+
+```{r}
+plot(rowMeans(bart_model_root$y_hat_test), z_test,
+     pch = 16, cex = 0.75, xlab = "pred", ylab = "actual")
+abline(0, 1, col = "red", lty = 2, lwd = 2.5)
+```
+
+On non-simulated datasets, the first thing we would evaluate is the prediction accuracy.
+
+```{r}
+preds_test <- rowMeans(bart_model_root$y_hat_test) > 0  # 0 on the probit scale = probability 0.5
+mean(preds_test == y_test)
+```
+
+We can also compute the [ROC curve](https://en.wikipedia.org/wiki/Receiver_operating_characteristic) for every posterior sample, as well as the ROC of the posterior mean.
+
+```{r}
+num_thresholds <- 1000
+thresholds <- seq(0.001, 0.999, length.out = num_thresholds)
+tpr_mean <- rep(NA, num_thresholds)
+fpr_mean <- rep(NA, num_thresholds)
+tpr_samples <- matrix(NA, num_thresholds, num_mcmc)
+fpr_samples <- matrix(NA, num_thresholds, num_mcmc)
+yhat_samples <- bart_model_root$y_hat_test
+yhat_mean <- rowMeans(yhat_samples)
+n_positive <- sum(y_test)  # class totals do not depend on the threshold
+n_negative <- sum(y_test == 0)
+for (i in seq_len(num_thresholds)) {
+  is_above_threshold_samples <- yhat_samples > qnorm(thresholds[i])  # probability cutoff mapped to probit scale
+  is_above_threshold_mean <- yhat_mean > qnorm(thresholds[i])
+  y_above_threshold_mean <- y_test[is_above_threshold_mean]
+  tpr_mean[i] <- sum(y_above_threshold_mean) / n_positive
+  fpr_mean[i] <- sum(y_above_threshold_mean == 0) / n_negative
+  for (j in seq_len(num_mcmc)) {
+    y_above_threshold <- y_test[is_above_threshold_samples[, j]]
+    tpr_samples[i, j] <- sum(y_above_threshold) / n_positive
+    fpr_samples[i, j] <- sum(y_above_threshold == 0) / n_negative
+  }
+}
+
+for (i in seq_len(num_mcmc)) {
+  if (i == 1) {
+    plot(fpr_samples[, i], tpr_samples[, i], type = "l", col = "blue", lwd = 1, lty = 1,
+         xlab = "False positive rate", ylab = "True positive rate")
+  } else {
+    lines(fpr_samples[, i], tpr_samples[, i], col = "blue", lwd = 1, lty = 1)
+  }
+}
+lines(fpr_mean, tpr_mean, col = "black", lwd = 3, lty = 3)
+```
+
+Note that the nonlinearity of the ROC function means that the ROC curve of the posterior mean lies above most of the individual posterior sample ROC curves (which would not be the case if we had simply taken the mean of the ROC curves).
+
 # References