Skip to content

Commit

Permalink
Add: support for predict_proba for estimators that support it
Browse files Browse the repository at this point in the history
  • Loading branch information
arnauddhaene committed May 30, 2023
1 parent b0936ac commit 46dbbda
Showing 1 changed file with 24 additions and 6 deletions.
30 changes: 24 additions & 6 deletions econml/metalearners/_metalearners.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def fit(self, Y, T, *, X, inference=None):
self.models[ind].fit(X[T == ind], Y[T == ind])

def const_marginal_effect(self, X):
"""Calculate the constant marignal treatment effect on a vector of features for each sample.
"""Calculate the constant marginal treatment effect on a vector of features for each sample.
Parameters
----------
Expand All @@ -127,7 +127,11 @@ def const_marginal_effect(self, X):
X = check_array(X)
taus = []
for ind in range(self._d_t[0]):
taus.append(self.models[ind + 1].predict(X) - self.models[0].predict(X))
if hasattr(self.models[ind + 1], 'predict_proba'):
taus.append(self.models[ind + 1].predict_proba(X)[:, 1] - self.models[0].predict_proba(X)[:, 1])
else:
taus.append(self.models[ind + 1].predict(X) - self.models[0].predict(X))

taus = np.column_stack(taus).reshape((-1,) + self._d_t + self._d_y) # shape as of m*d_t*d_y
if self._d_y:
taus = transpose(taus, (0, 2, 1)) # shape as of m*d_y*d_t
Expand Down Expand Up @@ -242,7 +246,12 @@ def const_marginal_effect(self, X=None):
X = check_array(X)
Xs, Ts = broadcast_unit_treatments(X, self._d_t[0] + 1)
feat_arr = np.concatenate((Xs, Ts), axis=1)
prediction = self.overall_model.predict(feat_arr).reshape((-1, self._d_t[0] + 1,) + self._d_y)

if hasattr(self.overall_model, 'predict_proba'):
prediction = self.overall_model.predict_proba(feat_arr)[:, 1].reshape((-1, self._d_t[0] + 1,) + self._d_y)
else:
prediction = self.overall_model.predict(feat_arr).reshape((-1, self._d_t[0] + 1,) + self._d_y)

if self._d_y:
prediction = transpose(prediction, (0, 2, 1))
taus = (prediction - np.repeat(prediction[:, :, 0], self._d_t[0] + 1).reshape(prediction.shape))[:, :, 1:]
Expand Down Expand Up @@ -393,8 +402,14 @@ def const_marginal_effect(self, X):
taus = []
for ind in range(self._d_t[0]):
propensity_scores = self.propensity_models[ind].predict_proba(X)[:, 1:]
tau_hat = propensity_scores * self.cate_controls_models[ind].predict(X).reshape(m, -1) \
+ (1 - propensity_scores) * self.cate_treated_models[ind].predict(X).reshape(m, -1)

if hasattr(self.cate_controls_models[ind], 'predict_proba'):
tau_hat = propensity_scores * self.cate_controls_models[ind].predict_proba(X)[:, 1].reshape(m, -1) \
+ (1 - propensity_scores) * self.cate_treated_models[ind].predict_proba(X)[:, 1].reshape(m, -1)
else:
tau_hat = propensity_scores * self.cate_controls_models[ind].predict(X).reshape(m, -1) \
+ (1 - propensity_scores) * self.cate_treated_models[ind].predict(X).reshape(m, -1)

taus.append(tau_hat)
taus = np.column_stack(taus).reshape((-1,) + self._d_t + self._d_y) # shape as of m*d_t*d_y
if self._d_y:
Expand Down Expand Up @@ -549,7 +564,10 @@ def const_marginal_effect(self, X):
X = check_array(X)
taus = []
for model in self.final_models:
taus.append(model.predict(X))
if hasattr(model, 'predict_proba'):
taus.append(model.predict_proba(X)[:, 1])
else:
taus.append(model.predict(X))
taus = np.column_stack(taus).reshape((-1,) + self._d_t + self._d_y) # shape as of m*d_t*d_y
if self._d_y:
taus = transpose(taus, (0, 2, 1)) # shape as of m*d_y*d_t
Expand Down

0 comments on commit 46dbbda

Please sign in to comment.