Skip to content

Commit 1d0146e

Browse files
authored
Merge pull request #60 from ihmeuw-msca/bugfix/ensemble-infer-shape
Add `_infer_shape` function to the ensemble class
2 parents d37c5cf + 8f82962 commit 1d0146e

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

src/mrtool/core/model.py

+3
Original file line numberDiff line numberDiff line change
@@ -523,6 +523,7 @@ def __init__(
523523

524524
self.weights = np.ones(self.num_sub_models) / self.num_sub_models
525525

526+
def _infer_shape(self) -> None:
526527
# inherent the dimension variable
527528
self.num_x_vars = self.sub_models[0].num_x_vars
528529
self.num_z_vars = self.sub_models[0].num_z_vars
@@ -542,6 +543,8 @@ def fit_model(
542543
for sub_model in self.sub_models:
543544
sub_model.fit_model(**fit_options)
544545

546+
self._infer_shape()
547+
545548
self.score_model(
546549
scores_weights=scores_weights, slopes=slopes, quantiles=quantiles
547550
)

0 commit comments

Comments
 (0)