diff --git a/metalearners/drlearner.py b/metalearners/drlearner.py index 944b33a..1d23908 100644 --- a/metalearners/drlearner.py +++ b/metalearners/drlearner.py @@ -268,6 +268,7 @@ def _pseudo_outcome( is_oos: bool, oos_method: OosMethod = OVERALL, epsilon: float = _EPSILON, + adaptive_clipping: bool = False, ) -> np.ndarray: """Compute the DR-Learner pseudo outcome.""" validate_valid_treatment_variant_not_control(treatment_variant, self.n_variants) @@ -317,4 +318,12 @@ def _pseudo_outcome( - y0_estimate ) + if adaptive_clipping: + t_pseudo_outcome = y1_estimate - y0_estimate + pseudo_outcome = np.where( + propensity_estimates.min(axis=1) < epsilon, + t_pseudo_outcome, + pseudo_outcome, + ) + return pseudo_outcome