Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add categorical covariate model #57

Merged
merged 12 commits into from
Oct 7, 2024
12 changes: 11 additions & 1 deletion src/mrtool/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,25 @@
"""

from .core import utils
from .core.cov_model import CovModel, LinearCovModel, LogCovModel
from .core.cov_model import (
CatCovModel,
CovModel,
LinearCatCovModel,
LinearCovModel,
LogCatCovModel,
LogCovModel,
)
from .core.data import MRData
from .core.model import MRBRT, MRBeRT
from .cov_selection.covfinder import CovFinder

__all__ = [
"MRData",
"CatCovModel",
"CovModel",
"LinearCatCovModel",
"LinearCovModel",
"LogCatCovModel",
"LogCovModel",
"MRBRT",
"MRBeRT",
Expand Down
241 changes: 237 additions & 4 deletions src/mrtool/core/cov_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,12 @@
Covariates model for `mrtool`.
"""

import itertools
import warnings
from typing import Callable

import numpy as np
import pandas as pd
import xspline
from numpy.typing import NDArray

Expand Down Expand Up @@ -451,7 +456,7 @@ def create_spline(

Returns
-------
xspline.XSpline
XSpline
The spline object.

"""
Expand Down Expand Up @@ -535,7 +540,7 @@ def create_design_mat(self, data) -> tuple[NDArray, NDArray]:

Returns
-------
tuple[numpy.ndarray, numpy.ndarray]
tuple[NDArray, NDArray]
Return the design matrix for linear cov or spline.

"""
Expand Down Expand Up @@ -832,7 +837,7 @@ def create_z_mat(self, data):

Returns
-------
numpy.ndarray
NDArray
Design matrix for random effects.

"""
Expand Down Expand Up @@ -884,7 +889,7 @@ def create_z_mat(self, data):

Returns
-------
numpy.ndarray
NDArray
Design matrix for random effects.

"""
Expand Down Expand Up @@ -929,3 +934,231 @@ def num_constraints(self):
@property
def num_z_vars(self):
return int(self.use_re)


class CatCovModel(CovModel):
    """Categorical covariate model.

    Each distinct value found in ``alt_cov`` (and ``ref_cov`` when given) is
    treated as one category and receives its own fixed-effect variable. The
    categories are only known once data has been attached through
    :meth:`attach_data`; before that ``num_x_vars`` is 0.

    Parameters
    ----------
    alt_cov
        Name of the single column holding the alternative categories.
    name
        Forwarded unchanged to the parent constructor.
    ref_cov
        Optional name of the single column holding the reference categories.
        When supplied the model is treated as a comparison model.
    ref_cat
        Reference category. Only allowed for comparison models. When omitted
        for a comparison model it is picked as the most frequent category at
        the time data is attached.
    use_re
        Forwarded to the parent constructor; also switches random effects
        on/off in ``num_z_vars`` and ``create_z_mat``.
    use_re_intercept
        When True a single intercept random effect is used, otherwise every
        category gets its own random effect.
    prior_order
        Optional iterable of ordered category sequences. Every sequence is
        split into adjacent ``(left, right)`` pairs; each pair later becomes
        one linear constraint row in ``create_constraint_mat``.
    prior_beta_gaussian, prior_beta_uniform, prior_beta_laplace,
    prior_gamma_gaussian, prior_gamma_uniform, prior_gamma_laplace
        Forwarded unchanged to the parent constructor.

    Raises
    ------
    ValueError
        If ``alt_cov`` is not exactly one column, if ``ref_cov`` has more
        than one column, or if ``ref_cat`` is given without ``ref_cov``.

    """

    def __init__(
        self,
        alt_cov,
        name=None,
        ref_cov=None,
        ref_cat=None,
        use_re=False,
        use_re_intercept=True,
        prior_order=None,
        prior_beta_gaussian=None,
        prior_beta_uniform=None,
        prior_beta_laplace=None,
        prior_gamma_gaussian=None,
        prior_gamma_uniform=None,
        prior_gamma_laplace=None,
    ) -> None:
        # These attributes are set before calling the parent constructor,
        # matching the original statement order.
        self.ref_cat = ref_cat
        self.use_re_intercept = use_re_intercept
        if prior_order is not None:
            # Split every ordered sequence into adjacent pairs, drop the
            # duplicates and keep the pairs in sorted order.
            pairs = set()
            for sequence in prior_order:
                pairs.update(zip(sequence, sequence[1:]))
            prior_order = sorted(pairs)
        self.prior_order = prior_order
        super().__init__(
            alt_cov=alt_cov,
            name=name,
            ref_cov=ref_cov,
            use_re=use_re,
            prior_beta_gaussian=prior_beta_gaussian,
            prior_beta_uniform=prior_beta_uniform,
            prior_beta_laplace=prior_beta_laplace,
            prior_gamma_gaussian=prior_gamma_gaussian,
            prior_gamma_uniform=prior_gamma_uniform,
            prior_gamma_laplace=prior_gamma_laplace,
        )

        if len(self.alt_cov) != 1:
            raise ValueError("alt_cov should be a single column.")
        if len(self.ref_cov) > 1:
            raise ValueError("ref_cov should be nothing or a single column.")
        if len(self.ref_cov) == 1 and self.ref_cat is None:
            warnings.warn(
                "ref_cat is not provided for a comparison covmodel, it will be "
                "inferenced as the most common categories when attaching data."
            )
        if len(self.ref_cov) == 0 and self.ref_cat is not None:
            raise ValueError(
                "Cannot set ref_cat when this is not a comparison model."
            )

        # Annotation only, on purpose: the attribute must NOT exist until
        # attach_data runs, because has_data/num_x_vars test for it with
        # hasattr.
        self.cats: pd.Series

    def attach_data(self, data: MRData) -> None:
        """Attach data and parse the categories.

        The distinct values of the alternative and reference columns become
        the categories, which fixes the number of variables, and the priors
        are processed again afterwards. For a comparison model without a
        ``ref_cat``, ``ref_cat`` is inferred as the most common category.
        When a reference category is in effect, its beta uniform prior (and
        its gamma uniform prior when per-category random effects are used)
        is reset to 0.

        Raises
        ------
        ValueError
            If ``ref_cat`` or any category named in ``prior_order`` does not
            appear in the data.

        """
        alt_cov = data.get_covs(self.alt_cov)
        ref_cov = data.get_covs(self.ref_cov)
        unique_cats, counts = np.unique(
            np.hstack([alt_cov, ref_cov]), return_counts=True
        )
        self.cats = pd.Series(unique_cats, name="cats")
        # num_x_vars changed with the line above, so the priors are processed
        # again (``_process_priors`` is provided outside this class).
        self._process_priors()

        if len(self.ref_cov) == 1:
            if self.ref_cat is None:
                self.ref_cat = unique_cats[counts.argmax()]
            if self.ref_cat not in unique_cats:
                raise ValueError(
                    f"ref_cat {self.ref_cat} is not in the categories."
                )

        if self.ref_cat is not None:
            ref_index = dict(zip(self.cats, self.cats.index))[self.ref_cat]
            ref_beta_uprior = self.prior_beta_uniform[:, ref_index]
            # Only warn when the current bounds are neither all-infinite nor
            # already (0, 0).
            if not (
                np.isinf(ref_beta_uprior).all()
                or np.allclose(ref_beta_uprior, 0.0)
            ):
                warnings.warn(
                    f"Reset ref_cat beta uniform prior from {ref_beta_uprior} to (0, 0)"
                )
            self.prior_beta_uniform[:, ref_index] = 0.0
            if self.use_re and (not self.use_re_intercept):
                ref_gamma_uprior = self.prior_gamma_uniform[:, ref_index]
                # For gamma only the second entry is checked for infinity.
                if not (
                    np.isinf(ref_gamma_uprior[1]).all()
                    or np.allclose(ref_gamma_uprior, 0.0)
                ):
                    warnings.warn(
                        f"Reset ref_cat gamma uniform prior from {ref_gamma_uprior} to (0, 0)"
                    )
                self.prior_gamma_uniform[:, ref_index] = 0.0

        if self.prior_order is not None:
            for cat in set(itertools.chain.from_iterable(self.prior_order)):
                if cat not in unique_cats:
                    raise ValueError(
                        f"Order prior category {cat} is not in the categories."
                    )

    def has_data(self) -> bool:
        """Return whether data has been attached and categories parsed."""
        return hasattr(self, "cats")

    def encode(self, x: NDArray) -> NDArray:
        """Encode the provided categories into dummy variables.

        Returns a ``(len(x), num_x_vars)`` matrix with a single 1.0 per row,
        placed in the column of the matching category.

        Raises
        ------
        ValueError
            If any value in ``x`` is not one of the parsed categories.

        """
        # The left merge looks up, for every value in x, the position of the
        # matching category; unmatched values come back as NaN.
        col = pd.merge(
            pd.Series(x, name="cats"), self.cats.reset_index(), how="left"
        )["index"]
        if np.isnan(col).any():
            raise ValueError("Categories not found")
        mat = np.zeros((len(x), self.num_x_vars))
        mat[range(len(x)), col] = 1.0
        return mat

    def create_design_mat(self, data: MRData) -> tuple[NDArray, NDArray]:
        """Create design matrices for alternative and reference categories.

        The reference matrix has zero columns when no reference covariate
        values are available.

        """
        alt_cov = data.get_covs(self.alt_cov).ravel()
        ref_cov = data.get_covs(self.ref_cov).ravel()

        alt_mat = self.encode(alt_cov)
        if ref_cov.size == 0:
            ref_mat = np.empty((len(alt_cov), 0))
        else:
            ref_mat = self.encode(ref_cov)
        return alt_mat, ref_mat

    def create_constraint_mat(self) -> tuple[NDArray, NDArray]:
        """Extend the parent's constraints with one row per order-prior pair.

        For each ``(alt_cat, ref_cat)`` pair in ``prior_order`` the row
        ``encode(alt_cat) - encode(ref_cat)`` is appended to the constraint
        matrix and the column ``(-inf, 0)`` is appended to the constraint
        values.

        """
        c_mat, c_val = super().create_constraint_mat()
        if not self.prior_order:
            return c_mat, c_val

        # NOTE(review): (-inf, 0) presumably reads as lower/upper bounds on
        # the constrained combination - confirm against the parent class.
        c_val = np.hstack(
            [
                c_val,
                np.repeat(
                    np.array([[-np.inf], [0.0]]), len(self.prior_order), axis=1
                ),
            ]
        )

        mats = [
            self.encode([alt_cat]) - self.encode([ref_cat])
            for alt_cat, ref_cat in self.prior_order
        ]
        c_mat = np.vstack([c_mat] + mats)
        return c_mat, c_val

    @property
    def num_x_vars(self) -> int:
        """Number of fixed effects.

        Returns 0 if data is not attached, otherwise the number of
        categories.

        """
        if not hasattr(self, "cats"):
            return 0
        return len(self.cats)

    @property
    def num_z_vars(self) -> int:
        """Number of random effects.

        0 when random effects are disabled, 1 when a single intercept random
        effect is used (``use_re_intercept``), otherwise one per category.

        """
        if not self.use_re:
            return 0
        if self.use_re_intercept:
            return 1
        return self.num_x_vars

    @property
    def num_constraints(self) -> int:
        """Parent's constraint count plus one per order-prior pair."""
        num = super().num_constraints
        if self.prior_order:
            num += len(self.prior_order)
        return num

    def create_z_mat(self, data: MRData) -> NDArray:
        """Create the random-effects design matrix.

        Returns an ``(num_obs, 0)`` matrix without random effects, a column
        of ones for an intercept random effect, and otherwise the categorical
        design matrix (alternative minus reference when a reference matrix is
        non-empty).

        """
        if not self.use_re:
            return np.empty((data.num_obs, 0))

        if self.use_re_intercept:
            alt_mat = np.ones((data.num_obs, 1))
            ref_mat = np.empty((data.num_obs, 0))
        else:
            alt_mat, ref_mat = self.create_design_mat(data)

        z_mat = alt_mat if ref_mat.size == 0 else alt_mat - ref_mat
        return z_mat


class LinearCatCovModel(CatCovModel):
    """Categorical covariate model whose fixed-effect function is built
    with ``utils.mat_to_fun``.
    """

    def create_x_fun(self, data: MRData) -> Callable:
        """Return the fixed-effect function for ``data``.

        The alternative and reference design matrices are produced by
        ``create_design_mat`` and handed to ``utils.mat_to_fun``.

        """
        design_alt, design_ref = self.create_design_mat(data)
        x_fun = utils.mat_to_fun(design_alt, ref_mat=design_ref)
        return x_fun


class LogCatCovModel(CatCovModel):
    """Categorical covariate model whose fixed-effect function is built
    with ``utils.mat_to_log_fun``.
    """

    def attach_data(self, data: MRData) -> None:
        """Attach data, then raise the beta uniform prior to a floor.

        The floor is ``1e-6`` when there is no reference category and
        ``-1 + 1e-6`` when there is one. Every entry of
        ``prior_beta_uniform`` below the floor is lifted to it.

        """
        super().attach_data(data)

        # Positivity constraint for every category; the offset value is
        # currently hard-coded.
        offset = 1e-6
        if self.ref_cat is None:
            shift = 0.0
        else:
            shift = 1.0
        floor = -shift + offset

        self.prior_beta_uniform = np.maximum(floor, self.prior_beta_uniform)

    def create_x_fun(self, data: MRData) -> Callable:
        """Return the fixed-effect function for ``data``.

        ``add_one`` is passed as True exactly when a reference category is
        set.

        """
        design_alt, design_ref = self.create_design_mat(data)
        has_reference = self.ref_cat is not None
        return utils.mat_to_log_fun(
            design_alt, ref_mat=design_ref, add_one=has_reference
        )
16 changes: 9 additions & 7 deletions src/mrtool/core/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def _check_attr_type(self):
assert isinstance(self.covs, dict)
for cov in self.covs.values():
assert isinstance(cov, np.ndarray)
assert is_numeric_array(cov)
# assert is_numeric_array(cov)

def _get_cov_scales(self):
"""Compute the covariate scale."""
Expand All @@ -103,6 +103,7 @@ def _get_cov_scales(self):
self.cov_scales = {
cov_name: np.max(np.abs(cov))
for cov_name, cov in self.covs.items()
if is_numeric_array(cov)
}
zero_covs = [
cov_name
Expand Down Expand Up @@ -159,12 +160,13 @@ def _remove_nan_in_covs(self):
if not self.is_empty():
index = np.full(self.num_obs, False)
for cov_name, cov in self.covs.items():
cov_index = np.isnan(cov)
if cov_index.any():
warnings.warn(
f"There are {cov_index.sum()} nans in covaraite {cov_name}."
)
index = index | cov_index
if is_numeric_array(cov):
cov_index = np.isnan(cov)
if cov_index.any():
warnings.warn(
f"There are {cov_index.sum()} nans in covaraite {cov_name}."
)
index = index | cov_index
self._remove_data(index)

def _remove_data(self, index: NDArray):
Expand Down
18 changes: 10 additions & 8 deletions src/mrtool/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,15 @@ def __init__(
self.cov_names.extend(cov_model.covs)
self.num_covs = len(self.cov_names)

# place holder for the limetr objective
self.lt: LimeTr
self.beta_soln: NDArray
self.gamma_soln: NDArray
self.u_soln: NDArray
self.w_soln: NDArray
self.re_soln: NDArray

def _infer_shape(self) -> None:
# add random effects
if not any([cov_model.use_re for cov_model in self.cov_models]):
self.cov_models[0].use_re = True
Expand Down Expand Up @@ -83,14 +92,6 @@ def __init__(
[cov_model.num_regularizations for cov_model in self.cov_models]
)

# place holder for the limetr objective
self.lt: LimeTr
self.beta_soln: NDArray
self.gamma_soln: NDArray
self.u_soln: NDArray
self.w_soln: NDArray
self.re_soln: NDArray

def attach_data(self, data=None):
"""Attach data to cov_model."""
data = self.data if data is None else data
Expand Down Expand Up @@ -239,6 +240,7 @@ def fit_model(self, **fit_options):
"""
if not all([cov_model.has_data() for cov_model in self.cov_models]):
self.attach_data()
self._infer_shape()

# dimensions
n = self.data.study_sizes
Expand Down
2 changes: 2 additions & 0 deletions src/mrtool/cov_selection/covfinder.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,8 @@ def create_model(
model = MRBRT(
self.data, cov_models=cov_models, inlier_pct=self.inlier_pct
)
model.attach_data()
model._infer_shape()
return model

def fit_gaussian_model(self, covs: list[str]) -> MRBRT:
Expand Down
Loading
Loading