Skip to content

Commit

Permalink
Return tuple of arrays instead of concatenated arrays.
Browse files Browse the repository at this point in the history
  • Loading branch information
kklein committed Aug 7, 2024
1 parent fc767a2 commit 3c76626
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 9 deletions.
2 changes: 1 addition & 1 deletion docs/examples/example_estimating_ates.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@
" np.c_[\n",
" naive_est,\n",
" linreg_est,\n",
" metalearners_est.flatten(),\n",
" np.hstack(metalearners_est),\n",
" doubleml_est,\n",
" econml_est,\n",
"], index = ['est', 'se'],\n",
Expand Down
11 changes: 6 additions & 5 deletions metalearners/drlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,11 +344,11 @@ def average_treatment_effect(
y: Vector,
w: Vector,
is_oos: bool,
) -> np.ndarray:
) -> tuple[np.ndarray, np.ndarray]:
"""Compute Average Treatment Effect (ATE) for each treatment variant using the
Augmented IPW estimator (Robins et al 1994). Does not require fitting a second-
stage treatment model: it uses the pseudo-outcome alone and computes the average
and SE. Can be used following the
stage treatment model: it uses the pseudo-outcome alone and computes the point
estimate and standard error. Can be used following the
:meth:`~metalearners.drlearner.DRLearner.fit_all_nuisance` method.
Args:
Expand All @@ -358,7 +358,8 @@ def average_treatment_effect(
is_oos (bool): indicator whether data is out of sample
Returns:
np.ndarray: Treatment effect and standard error for each treatment variant.
np.ndarray: Treatment effect for each treatment variant.
np.ndarray: Standard error for each treatment variant.
"""
if not self._nuisance_models_fit:
raise ValueError(
Expand All @@ -375,7 +376,7 @@ def average_treatment_effect(
)
treatment_effect = gamma_matrix.mean(axis=0)
standard_error = gamma_matrix.std(axis=0) / np.sqrt(len(X))
return np.c_[treatment_effect, standard_error]
return treatment_effect, standard_error

def _pseudo_outcome(
self,
Expand Down
6 changes: 3 additions & 3 deletions tests/test_drlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def test_drlearner_onnx(
np.testing.assert_allclose(ml.predict(X, True, "overall"), pred_onnx, atol=5e-4)


def test_treatment_effect(
def test_average_treatment_effect(
numerical_experiment_dataset_continuous_outcome_binary_treatment_linear_te,
):
X, _, W, Y, _, tau = (
Expand All @@ -150,5 +150,5 @@ def test_treatment_effect(
n_folds=2,
)
ml.fit_all_nuisance(X, Y, W)
est = ml.average_treatment_effect(X, Y, W, is_oos=False)
np.testing.assert_almost_equal(est[:, 0], tau.mean(), decimal=1)
ate_estimate, _ = ml.average_treatment_effect(X, Y, W, is_oos=False)
np.testing.assert_almost_equal(ate_estimate, tau.mean(), decimal=1)

0 comments on commit 3c76626

Please sign in to comment.