Skip to content

Commit

Permalink
Fix shape issue.
Browse files Browse the repository at this point in the history
  • Loading branch information
kklein committed Aug 9, 2024
1 parent b87fefa commit 486d4f8
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions metalearners/rlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,10 +565,16 @@ def predict_conditional_average_outcomes(
# TODO: Consider whether the readability vs efficiency trade-off should be dealth with differently here.
# One could use a matrix/tensor operation.
for treatment_variant in range(1, self.n_variants):
for outcome_channel in range(0, cate_estimates.shape[2]):
control_outcomes[:, outcome_channel] -= (
if (n_outputs := cate_estimates.shape[2]) > 1:
for outcome_channel in range(0, n_outputs):
control_outcomes[:, outcome_channel] -= (
propensity_estimates[:, treatment_variant]
* cate_estimates[:, treatment_variant - 1, outcome_channel]
)
else:
control_outcomes -= (
propensity_estimates[:, treatment_variant]
* cate_estimates[:, treatment_variant - 1, outcome_channel]
* cate_estimates[:, treatment_variant - 1, 0]
)

conditional_average_outcomes_list.append(control_outcomes)
Expand Down

0 comments on commit 486d4f8

Please sign in to comment.