diff --git a/docs/examples/example_estimating_ates.ipynb b/docs/examples/example_estimating_ates.ipynb index 72011e5..0eda732 100644 --- a/docs/examples/example_estimating_ates.ipynb +++ b/docs/examples/example_estimating_ates.ipynb @@ -310,7 +310,7 @@ " np.c_[\n", " naive_est,\n", " linreg_est,\n", - " metalearners_est.flatten(),\n", + " np.hstack(metalearners_est),\n", " doubleml_est,\n", " econml_est,\n", "], index = ['est', 'se'],\n", diff --git a/metalearners/drlearner.py b/metalearners/drlearner.py index 1bfc2ab..7bff6fb 100644 --- a/metalearners/drlearner.py +++ b/metalearners/drlearner.py @@ -344,11 +344,11 @@ def average_treatment_effect( y: Vector, w: Vector, is_oos: bool, - ) -> np.ndarray: + ) -> tuple[np.ndarray, np.ndarray]: """Compute Average Treatment Effect (ATE) for each treatment variant using the Augmented IPW estimator (Robins et al 1994). Does not require fitting a second- - stage treatment model: it uses the pseudo-outcome alone and computes the average - and SE. Can be used following the + stage treatment model: it uses the pseudo-outcome alone and computes the point + estimate and standard error. Can be used following the :meth:`~metalearners.drlearner.DRLearner.fit_all_nuisance` method. Args: @@ -358,7 +358,8 @@ def average_treatment_effect( is_oos (bool): indicator whether data is out of sample Returns: - np.ndarray: Treatment effect and standard error for each treatment variant. + np.ndarray: Treatment effect for each treatment variant. + np.ndarray: Standard error for each treatment variant. """ if not self._nuisance_models_fit: raise ValueError( @@ -375,7 +376,7 @@ def average_treatment_effect( ) treatment_effect = gamma_matrix.mean(axis=0) standard_error = gamma_matrix.std(axis=0) / np.sqrt(len(X)) - return np.c_[treatment_effect, standard_error] + return treatment_effect, standard_error def _pseudo_outcome( self, diff --git a/tests/test_drlearner.py b/tests/test_drlearner.py index 14298e2..9cfeb7d 100644 --- a/tests/test_drlearner.py +++ b/tests/test_drlearner.py @@ -135,7 +135,7 @@ def test_drlearner_onnx( np.testing.assert_allclose(ml.predict(X, True, "overall"), pred_onnx, atol=5e-4) -def test_treatment_effect( +def test_average_treatment_effect( numerical_experiment_dataset_continuous_outcome_binary_treatment_linear_te, ): X, _, W, Y, _, tau = ( @@ -150,5 +150,5 @@ def test_treatment_effect( n_folds=2, ) ml.fit_all_nuisance(X, Y, W) - est = ml.average_treatment_effect(X, Y, W, is_oos=False) - np.testing.assert_almost_equal(est[:, 0], tau.mean(), decimal=1) + ate_estimate, _ = ml.average_treatment_effect(X, Y, W, is_oos=False) + np.testing.assert_almost_equal(ate_estimate, tau.mean(), decimal=1)