Skip to content

Commit 295ea52

Browse files
committed
add ref_cov parsing
1 parent 8a718b8 commit 295ea52

File tree

2 files changed

+69
-13
lines changed

2 files changed

+69
-13
lines changed

src/mrtool/core/cov_model.py

+30-7
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
Covariates model for `mrtool`.
77
"""
88

9+
import warnings
10+
911
import numpy as np
1012
import pandas as pd
1113
import xspline
@@ -972,31 +974,52 @@ def __init__(
972974
raise ValueError("alt_cov should be a single column.")
973975
if len(self.ref_cov) > 1:
974976
raise ValueError("ref_cov should be nothing or a single column.")
977+
if len(self.ref_cov) == 1 and self.ref_cat is None:
978+
warnings.warn(
979+
"ref_cat is not provided for a comparison covmodel, it will be "
980+
"inferenced as the most common categories when attaching data."
981+
)
982+
if len(self.ref_cov) == 0 and self.ref_cat is not None:
983+
raise ValueError(
984+
"Cannot set ref_cat when this is not a comparison model."
985+
)
975986

976987
self.cats: pd.Series
977988

978989
def attach_data(self, data: MRData) -> None:
979990
"""Attach data and parse the categories. Number of variables will be
980-
determined here and priors will be processed here as well.
991+
determined here and priors will be processed. If ref_cat was not set
992+
beforehand and this is a comparison model, ref_cat will be inferred as the
993+
most common category.
981994
982995
"""
983996
alt_cov = data.get_covs(self.alt_cov)
984997
ref_cov = data.get_covs(self.ref_cov)
985-
self.cats = pd.Series(
986-
np.unique(np.hstack([alt_cov, ref_cov])),
987-
name="cats",
998+
unique_cats, counts = np.unique(
999+
np.hstack([alt_cov, ref_cov]), return_counts=True
9881000
)
1001+
self.cats = pd.Series(unique_cats, name="cats")
9891002
self._process_priors()
9901003

1004+
if len(self.ref_cov) == 1:
1005+
if self.ref_cat is None:
1006+
self.ref_cat = unique_cats[counts.argmax()]
1007+
if self.ref_cat not in unique_cats:
1008+
raise ValueError(
1009+
f"ref_cat {self.ref_cat} is not in the categories."
1010+
)
1011+
9911012
def has_data(self) -> bool:
9921013
"""Return whether the data has been attached and the categories have been parsed."""
9931014
return hasattr(self, "cats")
9941015

9951016
def encode(self, x: NDArray) -> NDArray:
9961017
"""Encode the provided categories into dummy variables."""
997-
col = pd.merge(pd.Series(x, name="cats"), self.cats.reset_index())[
998-
"index"
999-
]
1018+
col = pd.merge(
1019+
pd.Series(x, name="cats"), self.cats.reset_index(), how="left"
1020+
)["index"]
1021+
if np.isnan(col).any():
1022+
raise ValueError("Categories not found")
10001023
mat = np.zeros((len(x), self.num_x_vars))
10011024
mat[range(len(x)), col] = 1.0
10021025
return mat

tests/test_cat_covmodel.py

+39-6
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def data():
2929

3030

3131
def test_init():
32-
covmodel = CatCovModel(alt_cov="alt_cat", ref_cov="ref_cat")
32+
covmodel = CatCovModel(alt_cov="alt_cat", ref_cov="ref_cat", ref_cat="A")
3333
assert covmodel.alt_cov == ["alt_cat"]
3434
assert covmodel.ref_cov == ["ref_cat"]
3535

@@ -41,26 +41,51 @@ def test_init():
4141
CatCovModel(alt_cov=["a", "b"])
4242

4343
with pytest.raises(ValueError):
44-
CatCovModel(alt_cov="a", ref_cov=["a", "b"])
44+
CatCovModel(alt_cov="a", ref_cov=["a", "b"], ref_cat="A")
4545

4646

4747
def test_attach_data(data):
48-
covmodel = CatCovModel(alt_cov="alt_cat", ref_cov="ref_cat")
48+
covmodel = CatCovModel(alt_cov="alt_cat", ref_cov="ref_cat", ref_cat="A")
4949
assert not hasattr(covmodel, "cats")
5050
covmodel.attach_data(data)
5151
assert covmodel.cats.to_list() == ["A", "B", "C", "D"]
5252

5353

54+
def test_ref_cov(data):
55+
with pytest.raises(ValueError):
56+
covmodel = CatCovModel(
57+
alt_cov="alt_cat", ref_cov="ref_cat", ref_cat="E"
58+
)
59+
covmodel.attach_data(data)
60+
61+
with pytest.raises(ValueError):
62+
covmodel = CatCovModel(alt_cov="alt_cat", ref_cat="A")
63+
64+
covmodel = CatCovModel(alt_cov="alt_cat")
65+
covmodel.attach_data(data)
66+
assert covmodel.ref_cat is None
67+
68+
with pytest.warns():
69+
covmodel = CatCovModel(alt_cov="alt_cat", ref_cov="ref_cat")
70+
assert covmodel.ref_cat is None
71+
covmodel.attach_data(data)
72+
assert covmodel.ref_cat == "A"
73+
74+
covmodel = CatCovModel(alt_cov="alt_cat", ref_cov="ref_cat", ref_cat="B")
75+
covmodel.attach_data(data)
76+
assert covmodel.ref_cat == "B"
77+
78+
5479
def test_has_data(data):
55-
covmodel = CatCovModel(alt_cov="alt_cat", ref_cov="ref_cat")
80+
covmodel = CatCovModel(alt_cov="alt_cat", ref_cov="ref_cat", ref_cat="A")
5681
assert not covmodel.has_data()
5782

5883
covmodel.attach_data(data)
5984
assert covmodel.has_data()
6085

6186

6287
def test_encode(data):
63-
covmodel = CatCovModel(alt_cov="alt_cat", ref_cov="ref_cat")
88+
covmodel = CatCovModel(alt_cov="alt_cat", ref_cov="ref_cat", ref_cat="A")
6489
covmodel.attach_data(data)
6590

6691
mat = covmodel.encode(["A", "B", "C", "C"])
@@ -77,8 +102,16 @@ def test_encode(data):
77102
assert np.allclose(mat, true_mat)
78103

79104

105+
def test_encode_fail(data):
106+
covmodel = CatCovModel(alt_cov="alt_cat", ref_cov="ref_cat", ref_cat="A")
107+
covmodel.attach_data(data)
108+
109+
with pytest.raises(ValueError):
110+
covmodel.encode(["A", "B", "C", "E"])
111+
112+
80113
def test_create_design_mat(data):
81-
covmodel = CatCovModel(alt_cov="alt_cat", ref_cov="ref_cat")
114+
covmodel = CatCovModel(alt_cov="alt_cat", ref_cov="ref_cat", ref_cat="A")
82115
covmodel.attach_data(data)
83116

84117
alt_mat, ref_mat = covmodel.create_design_mat(data)

0 commit comments

Comments
 (0)