Skip to content

Commit baed7c1

Browse files
committed
add majority part of the cat covmodel
1 parent 0942ac9 commit baed7c1

File tree

3 files changed

+229
-11
lines changed

3 files changed

+229
-11
lines changed

src/mrtool/core/cov_model.py

+112-4
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
"""
88

99
import numpy as np
10+
import pandas as pd
1011
import xspline
1112
from numpy.typing import NDArray
1213

@@ -451,7 +452,7 @@ def create_spline(
451452
452453
Returns
453454
-------
454-
xspline.XSpline
455+
XSpline
455456
The spline object.
456457
457458
"""
@@ -535,7 +536,7 @@ def create_design_mat(self, data) -> tuple[NDArray, NDArray]:
535536
536537
Returns
537538
-------
538-
tuple[numpy.ndarray, numpy.ndarray]
539+
tuple[NDArray, NDArray]
539540
Return the design matrix for linear cov or spline.
540541
541542
"""
@@ -832,7 +833,7 @@ def create_z_mat(self, data):
832833
833834
Returns
834835
-------
835-
numpy.ndarray
836+
NDArray
836837
Design matrix for random effects.
837838
838839
"""
@@ -884,7 +885,7 @@ def create_z_mat(self, data):
884885
885886
Returns
886887
-------
887-
numpy.ndarray
888+
NDArray
888889
Design matrix for random effects.
889890
890891
"""
@@ -929,3 +930,110 @@ def num_constraints(self):
929930
@property
930931
def num_z_vars(self):
931932
return int(self.use_re)
933+
934+
935+
class CatCovModel(CovModel):
936+
"""Categorical covariate model.
937+
938+
TODO: Add order prior.
939+
"""
940+
941+
def __init__(
942+
self,
943+
alt_cov,
944+
name=None,
945+
ref_cov=None,
946+
ref_cat=None,
947+
use_re=False,
948+
prior_beta_gaussian=None,
949+
prior_beta_uniform=None,
950+
prior_beta_laplace=None,
951+
prior_gamma_gaussian=None,
952+
prior_gamma_uniform=None,
953+
prior_gamma_laplace=None,
954+
) -> None:
955+
super().__init__(
956+
alt_cov=alt_cov,
957+
name=name,
958+
ref_cov=ref_cov,
959+
use_re=use_re,
960+
prior_beta_gaussian=prior_beta_gaussian,
961+
prior_beta_uniform=prior_beta_uniform,
962+
prior_beta_laplace=prior_beta_laplace,
963+
prior_gamma_gaussian=prior_gamma_gaussian,
964+
prior_gamma_uniform=prior_gamma_uniform,
965+
prior_gamma_laplace=prior_gamma_laplace,
966+
)
967+
self.ref_cat = ref_cat
968+
if len(self.alt_cov) != 1:
969+
raise ValueError("alt_cov should be a single column.")
970+
if len(self.ref_cov) > 1:
971+
raise ValueError("ref_cov should be nothing or a single column.")
972+
973+
self.cats: pd.Series
974+
975+
def attach_data(self, data: MRData) -> None:
976+
"""Attach data and parse the categories. Number of variables will be
977+
determined here and priors will be processed here as well.
978+
979+
"""
980+
alt_cov = data.get_covs(self.alt_cov)
981+
ref_cov = data.get_covs(self.ref_cov)
982+
self.cats = pd.Series(
983+
np.unique(np.hstack([alt_cov, ref_cov])),
984+
name="cats",
985+
)
986+
self._process_priors()
987+
988+
def has_data(self) -> bool:
989+
"""Return if the data has been attached and categories has been parsed."""
990+
return hasattr(self, "cats")
991+
992+
def encode(self, x: NDArray) -> NDArray:
993+
"""Encode the provided categories into dummy variables."""
994+
col = pd.merge(pd.Series(x, name="cats"), self.cats.reset_index())[
995+
"index"
996+
]
997+
mat = np.zeros((len(x), self.num_x_vars))
998+
mat[range(len(x)), col] = 1.0
999+
return mat
1000+
1001+
def create_design_mat(self, data: MRData) -> tuple[NDArray, NDArray]:
1002+
"""Create design matrix for alternative and reference categories."""
1003+
alt_cov = data.get_covs(self.alt_cov).ravel()
1004+
ref_cov = data.get_covs(self.ref_cov).ravel()
1005+
1006+
alt_mat = self.encode(alt_cov)
1007+
if ref_cov.size == 0:
1008+
ref_mat = np.zeros((len(alt_cov), self.num_x_vars))
1009+
else:
1010+
ref_mat = self.encode(ref_cov)
1011+
return alt_mat, ref_mat
1012+
1013+
def create_constraint_mat(self) -> tuple[NDArray, NDArray]:
1014+
"""TODO: Create constraint matrix from order priors."""
1015+
return np.empty((0, self.num_x_vars)), np.empty((2, 0))
1016+
1017+
@property
1018+
def num_x_vars(self) -> int:
1019+
"""Number of the fixed effects. Returns 0 if data is not attached
1020+
otherwise it will return the number of categories.
1021+
1022+
"""
1023+
if not hasattr(self, "cats"):
1024+
return 0
1025+
return len(self.cats)
1026+
1027+
@property
1028+
def num_z_vars(self) -> int:
1029+
"""Number of the random effects. Currently it is the same with the
1030+
number of the fixed effects, but this is to be discussed.
1031+
TODO: Overwrite the number of random effects.
1032+
1033+
"""
1034+
return self.num_x_vars
1035+
1036+
@property
1037+
def num_constraints(self) -> int:
1038+
"""TODO: Overwrite the number of constraints."""
1039+
return 0

src/mrtool/core/data.py

+9-7
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def _check_attr_type(self):
9191
assert isinstance(self.covs, dict)
9292
for cov in self.covs.values():
9393
assert isinstance(cov, np.ndarray)
94-
assert is_numeric_array(cov)
94+
# assert is_numeric_array(cov)
9595

9696
def _get_cov_scales(self):
9797
"""Compute the covariate scale."""
@@ -103,6 +103,7 @@ def _get_cov_scales(self):
103103
self.cov_scales = {
104104
cov_name: np.max(np.abs(cov))
105105
for cov_name, cov in self.covs.items()
106+
if is_numeric_array(cov)
106107
}
107108
zero_covs = [
108109
cov_name
@@ -159,12 +160,13 @@ def _remove_nan_in_covs(self):
159160
if not self.is_empty():
160161
index = np.full(self.num_obs, False)
161162
for cov_name, cov in self.covs.items():
162-
cov_index = np.isnan(cov)
163-
if cov_index.any():
164-
warnings.warn(
165-
f"There are {cov_index.sum()} nans in covaraite {cov_name}."
166-
)
167-
index = index | cov_index
163+
if is_numeric_array(cov):
164+
cov_index = np.isnan(cov)
165+
if cov_index.any():
166+
warnings.warn(
167+
f"There are {cov_index.sum()} nans in covaraite {cov_name}."
168+
)
169+
index = index | cov_index
168170
self._remove_data(index)
169171

170172
def _remove_data(self, index: NDArray):

tests/test_cat_covmodel.py

+108
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
import numpy as np
2+
import pandas as pd
3+
import pytest
4+
5+
from mrtool.core.cov_model import CatCovModel
6+
from mrtool.core.data import MRData
7+
8+
9+
@pytest.fixture
10+
def data():
11+
df = pd.DataFrame(
12+
dict(
13+
obs=[0, 1, 0, 1],
14+
obs_se=[0.1, 0.1, 0.1, 0.1],
15+
alt_cat=["A", "A", "B", "C"],
16+
ref_cat=["A", "B", "B", "D"],
17+
study_id=[1, 1, 2, 2],
18+
)
19+
)
20+
data = MRData()
21+
data.load_df(
22+
df,
23+
col_obs="obs",
24+
col_obs_se="obs_se",
25+
col_covs=["alt_cat", "ref_cat"],
26+
col_study_id="study_id",
27+
)
28+
return data
29+
30+
31+
def test_init():
32+
covmodel = CatCovModel(alt_cov="alt_cat", ref_cov="ref_cat")
33+
assert covmodel.alt_cov == ["alt_cat"]
34+
assert covmodel.ref_cov == ["ref_cat"]
35+
36+
covmodel = CatCovModel(alt_cov="alt_cat")
37+
assert covmodel.alt_cov == ["alt_cat"]
38+
assert covmodel.ref_cov == []
39+
40+
with pytest.raises(ValueError):
41+
CatCovModel(alt_cov=["a", "b"])
42+
43+
with pytest.raises(ValueError):
44+
CatCovModel(alt_cov="a", ref_cov=["a", "b"])
45+
46+
47+
def test_attach_data(data):
48+
covmodel = CatCovModel(alt_cov="alt_cat", ref_cov="ref_cat")
49+
assert not hasattr(covmodel, "cats")
50+
covmodel.attach_data(data)
51+
assert covmodel.cats.to_list() == ["A", "B", "C", "D"]
52+
53+
54+
def test_has_data(data):
55+
covmodel = CatCovModel(alt_cov="alt_cat", ref_cov="ref_cat")
56+
assert not covmodel.has_data()
57+
58+
covmodel.attach_data(data)
59+
assert covmodel.has_data()
60+
61+
62+
def test_encode(data):
63+
covmodel = CatCovModel(alt_cov="alt_cat", ref_cov="ref_cat")
64+
covmodel.attach_data(data)
65+
66+
mat = covmodel.encode(["A", "B", "C", "C"])
67+
true_mat = np.array(
68+
[
69+
[
70+
[1.0, 0.0, 0.0, 0.0],
71+
[0.0, 1.0, 0.0, 0.0],
72+
[0.0, 0.0, 1.0, 0.0],
73+
[0.0, 0.0, 1.0, 0.0],
74+
]
75+
]
76+
)
77+
assert np.allclose(mat, true_mat)
78+
79+
80+
def test_create_design_mat(data):
81+
covmodel = CatCovModel(alt_cov="alt_cat", ref_cov="ref_cat")
82+
covmodel.attach_data(data)
83+
84+
alt_mat, ref_mat = covmodel.create_design_mat(data)
85+
86+
assert np.allclose(
87+
alt_mat,
88+
np.array(
89+
[
90+
[1.0, 0.0, 0.0, 0.0],
91+
[1.0, 0.0, 0.0, 0.0],
92+
[0.0, 1.0, 0.0, 0.0],
93+
[0.0, 0.0, 1.0, 0.0],
94+
]
95+
),
96+
)
97+
98+
assert np.allclose(
99+
ref_mat,
100+
np.array(
101+
[
102+
[1.0, 0.0, 0.0, 0.0],
103+
[0.0, 1.0, 0.0, 0.0],
104+
[0.0, 1.0, 0.0, 0.0],
105+
[0.0, 0.0, 0.0, 1.0],
106+
]
107+
),
108+
)

0 commit comments

Comments
 (0)