Skip to content

Commit ebc0d75

Browse files
committed
add constraints
1 parent 05da0d1 commit ebc0d75

File tree

2 files changed

+102
-7
lines changed

2 files changed

+102
-7
lines changed

src/mrtool/core/cov_model.py

+28-6
Original file line numberDiff line numberDiff line change
@@ -1026,7 +1026,7 @@ def attach_data(self, data: MRData) -> None:
10261026
f"Reset ref_cat beta uniform prior from {ref_beta_uprior} to (0, 0)"
10271027
)
10281028
self.prior_beta_uniform[:, ref_index] = 0.0
1029-
if not self.use_re_intercept:
1029+
if self.use_re and (not self.use_re_intercept):
10301030
ref_gamma_uprior = self.prior_gamma_uniform[:, ref_index]
10311031
if not (
10321032
np.isinf(ref_gamma_uprior[1]).all()
@@ -1068,14 +1068,32 @@ def create_design_mat(self, data: MRData) -> tuple[NDArray, NDArray]:
10681068

10691069
alt_mat = self.encode(alt_cov)
10701070
if ref_cov.size == 0:
1071-
ref_mat = np.zeros((len(alt_cov), self.num_x_vars))
1071+
ref_mat = np.empty((len(alt_cov), 0))
10721072
else:
10731073
ref_mat = self.encode(ref_cov)
10741074
return alt_mat, ref_mat
10751075

10761076
def create_constraint_mat(self) -> tuple[NDArray, NDArray]:
1077-
"""TODO: Create constraint matrix from order priors."""
1078-
return np.empty((0, self.num_x_vars)), np.empty((2, 0))
1077+
c_mat, c_val = super().create_constraint_mat()
1078+
if not self.prior_order:
1079+
return c_mat, c_val
1080+
1081+
c_val = np.hstack(
1082+
[
1083+
c_val,
1084+
np.repeat(
1085+
np.array([[-np.inf], [0.0]]), len(self.prior_order), axis=1
1086+
),
1087+
]
1088+
)
1089+
1090+
mats = []
1091+
for alt_cat, ref_cat in self.prior_order:
1092+
alt_mat = self.encode([alt_cat])
1093+
ref_mat = self.encode([ref_cat])
1094+
mats.append(alt_mat - ref_mat)
1095+
c_mat = np.vstack([c_mat] + mats)
1096+
return c_mat, c_val
10791097

10801098
@property
10811099
def num_x_vars(self) -> int:
@@ -1094,11 +1112,15 @@ def num_z_vars(self) -> int:
10941112
each category will have its own random effect.
10951113
10961114
"""
1115+
if not self.use_re:
1116+
return 0
10971117
if self.use_re_intercept:
10981118
return 1
10991119
return self.num_x_vars
11001120

11011121
@property
11021122
def num_constraints(self) -> int:
1103-
"""TODO: Overwrite the number of constraints."""
1104-
return 0
1123+
num = super().num_constraints
1124+
if self.prior_order:
1125+
num += len(self.prior_order)
1126+
return num

tests/test_cat_covmodel.py

+74-1
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,9 @@ def test_ref_cov(data):
6666
assert covmodel.ref_cat is None
6767
assert np.isinf(covmodel.prior_beta_uniform).all()
6868

69-
covmodel = CatCovModel(alt_cov="alt_cat", use_re_intercept=False)
69+
covmodel = CatCovModel(
70+
alt_cov="alt_cat", use_re=True, use_re_intercept=False
71+
)
7072
covmodel.attach_data(data)
7173
assert covmodel.ref_cat is None
7274
assert np.isinf(covmodel.prior_beta_uniform).all()
@@ -83,6 +85,7 @@ def test_ref_cov(data):
8385
alt_cov="alt_cat",
8486
ref_cov="ref_cat",
8587
ref_cat="B",
88+
use_re=True,
8689
use_re_intercept=False,
8790
)
8891
covmodel.attach_data(data)
@@ -188,3 +191,73 @@ def test_order_prior(data):
188191
prior_order=[["A", "B"], ["B", "C", "E"]],
189192
)
190193
covmodel.attach_data(data)
194+
195+
196+
def test_num_x_vars(data):
197+
covmodel = CatCovModel(alt_cov="alt_cat", ref_cov="ref_cat", ref_cat="A")
198+
assert covmodel.num_x_vars == 0
199+
covmodel.attach_data(data)
200+
assert covmodel.num_x_vars == 4
201+
202+
203+
def test_num_z_vars(data):
204+
covmodel = CatCovModel(alt_cov="alt_cat", ref_cov="ref_cat", ref_cat="A")
205+
assert covmodel.num_z_vars == 0
206+
207+
covmodel = CatCovModel(
208+
alt_cov="alt_cat", ref_cov="ref_cat", ref_cat="A", use_re=True
209+
)
210+
assert covmodel.num_z_vars == 1
211+
212+
covmodel = CatCovModel(
213+
alt_cov="alt_cat",
214+
ref_cov="ref_cat",
215+
ref_cat="A",
216+
use_re=True,
217+
use_re_intercept=False,
218+
)
219+
assert covmodel.num_z_vars == 0
220+
covmodel.attach_data(data)
221+
assert covmodel.num_z_vars == 4
222+
223+
224+
def test_num_constraints(data):
225+
covmodel = CatCovModel(alt_cov="alt_cat", ref_cov="ref_cat", ref_cat="A")
226+
assert covmodel.num_constraints == 0
227+
covmodel = CatCovModel(
228+
alt_cov="alt_cat",
229+
ref_cov="ref_cat",
230+
ref_cat="A",
231+
prior_order=[["A", "B", "C"]],
232+
)
233+
assert covmodel.num_constraints == 2
234+
235+
236+
def test_create_constraint_mat(data):
237+
covmodel = CatCovModel(
238+
alt_cov="alt_cat",
239+
ref_cov="ref_cat",
240+
ref_cat="A",
241+
prior_order=[["A", "B", "C"]],
242+
)
243+
covmodel.attach_data(data)
244+
c_mat, c_val = covmodel.create_constraint_mat()
245+
assert np.allclose(
246+
c_mat,
247+
np.array(
248+
[
249+
[1.0, -1.0, 0.0, 0.0],
250+
[0.0, 1.0, -1.0, 0.0],
251+
]
252+
),
253+
)
254+
255+
assert np.allclose(
256+
c_val,
257+
np.array(
258+
[
259+
[-np.inf, -np.inf],
260+
[0.0, 0.0],
261+
]
262+
),
263+
)

0 commit comments

Comments
 (0)