From 7a11445b3b1f0c607221ac5b56aacac74e0a35da Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Francesc=20Mart=C3=AD=20Escofet?= <154450563+FrancescMartiEscofetQC@users.noreply.github.com> Date: Fri, 14 Jun 2024 12:58:48 +0200 Subject: [PATCH] Switch `strict` meaning in `validate_number_positive` --- metalearners/_utils.py | 19 +++++++++++++------ metalearners/cross_fit_estimator.py | 2 +- tests/test_cross_fit_estimator.py | 4 +++- 3 files changed, 17 insertions(+), 8 deletions(-) diff --git a/metalearners/_utils.py b/metalearners/_utils.py index 00eecfb..0aca691 100644 --- a/metalearners/_utils.py +++ b/metalearners/_utils.py @@ -1,7 +1,6 @@ # # Copyright (c) QuantCo 2024-2024 # # SPDX-License-Identifier: BSD-3-Clause -import operator from collections.abc import Callable from inspect import signature from operator import le, lt @@ -66,14 +65,22 @@ def validate_all_vectors_same_index(*args: Vector) -> None: def validate_number_positive( - value: int | float, name: str, strict: bool = False + value: int | float, name: str, strict: bool = True ) -> None: + """Validates that a number is positive. + + If ``strict = True`` then it validates that the number is strictly positive. + """ if strict: - comparison = operator.lt + if value <= 0: + raise ValueError( + f"{name} was expected to be strictly positive but was {value}." + ) else: - comparison = operator.le - if comparison(value, 0): - raise ValueError(f"{name} was expected to be positive but was {value}.") + if value < 0: + raise ValueError( + f"{name} was expected to be positive or zero but was {value}." + ) def check_propensity_score( diff --git a/metalearners/cross_fit_estimator.py b/metalearners/cross_fit_estimator.py index e26d898..9765aa7 100644 --- a/metalearners/cross_fit_estimator.py +++ b/metalearners/cross_fit_estimator.py @@ -56,7 +56,7 @@ def _validate_data_match_prior_split( ) -> None: """Validate whether the previous test_indices and the passed data are based on the same number of observations.""" - validate_number_positive(n_observations, "n_observations", strict=False) + validate_number_positive(n_observations, "n_observations", strict=True) if test_indices is None: return expected_n_observations = sum(len(x) for x in test_indices) diff --git a/tests/test_cross_fit_estimator.py b/tests/test_cross_fit_estimator.py index 8e34b00..bb102c5 100644 --- a/tests/test_cross_fit_estimator.py +++ b/tests/test_cross_fit_estimator.py @@ -223,7 +223,9 @@ def test_crossfitestimator_n_folds_1(rng, sample_size): ) def test_validate_data_match(n_observations, test_indices, success): if n_observations < 1: - with pytest.raises(ValueError, match="was expected to be positive"): + with pytest.raises( + ValueError, match=r"was expected to be (strictly )?positive" + ): _validate_data_match_prior_split(n_observations, test_indices) return if success: