From bb64b7ddf9d48eb1baba8d801b8736b430264a91 Mon Sep 17 00:00:00 2001 From: Paul Date: Mon, 2 Dec 2024 13:52:03 -0700 Subject: [PATCH] update var cov fill --- flasc/analysis/expected_power_analysis.py | 4 +- .../expected_power_analysis_utilities.py | 35 +++++++------ tests/expected_power_analysis_test.py | 49 ++++++++++++++++--- 3 files changed, 65 insertions(+), 23 deletions(-) diff --git a/flasc/analysis/expected_power_analysis.py b/flasc/analysis/expected_power_analysis.py index ae621dc3..8ab86b96 100644 --- a/flasc/analysis/expected_power_analysis.py +++ b/flasc/analysis/expected_power_analysis.py @@ -15,7 +15,7 @@ _add_wd_ws_bins, _bin_and_group_dataframe_expected_power, _compute_covariance, - _fill_cov_null, + _fill_cov_with_var, _null_and_sync_covariance, _synchronize_mean_power_cov_nulls, _synchronize_nulls, @@ -323,7 +323,7 @@ def _total_uplift_expected_power_with_standard_error( # If filling missing covariance terms, do it now if fill_cov_with_var: - df_cov = _fill_cov_null(df_cov, test_cols=test_cols) + df_cov = _fill_cov_with_var(df_cov, test_cols=test_cols) # If only using the variance, zero out the covariance terms if variance_only: diff --git a/flasc/analysis/expected_power_analysis_utilities.py b/flasc/analysis/expected_power_analysis_utilities.py index fe9d7cd7..9621f08a 100644 --- a/flasc/analysis/expected_power_analysis_utilities.py +++ b/flasc/analysis/expected_power_analysis_utilities.py @@ -281,20 +281,22 @@ def _null_and_sync_covariance( return df_cov -def _fill_cov_null( +def _fill_cov_with_var( df_cov: pl.DataFrame, test_cols: List[str], + fill_all: bool = True, ) -> pl.DataFrame: - """Fill null values in covariance according to strategy. + """Fill covariance terms with the product of the square root of the variances. - Fill the null values in the covariance matrix with the product - of the square root of the variances of the corresponding test columns. 
Set - the number of points to the minimum of the number of points for the two - corresponding test columns. + Fill the null (or all) values in the covariance matrix with the product + of the square root of the variances of the corresponding test columns. + + Leave the number of points as is (the number of shared points between the two test columns). Args: df_cov (pl.DataFrame): A polars dataframe with the covariance matrix test_cols (List[str]): A list of column names to calculate the covariance of + fill_all (bool): If True, fill all values of cov, regardless of whether or not missing/Null Returns: pl.DataFrame: A polars dataframe with the null values filled according to the strategy. @@ -315,14 +317,18 @@ def _fill_cov_null( n_col = f"count_{t1}_{t2}" var_1_col = f"cov_{t1}_{t1}" var_2_col = f"cov_{t2}_{t2}" - n_1_col = f"count_{t1}_{t1}" - n_2_col = f"count_{t2}_{t2}" + # n_1_col = f"count_{t1}_{t1}" + # n_2_col = f"count_{t2}_{t2}" # For the rows where cov_col is null, fill the cov_col with the product of the square # root of the variances of the two test columns and the n_col with the minimum of the # number of points for the two test columns df_cov = df_cov.with_columns(null_map=pl.col(cov_col).is_null()) + # If fill_all is true, set null_map True for all rows + if fill_all: + df_cov = df_cov.with_columns(pl.lit(True).alias("null_map")) + with pl.Config(tbl_cols=-1): print(cov_col) print(df_cov) @@ -334,12 +340,13 @@ def _fill_cov_null( .alias(cov_col) ) - df_cov = df_cov.with_columns( - pl.when(pl.col("null_map")) - .then(pl.min_horizontal(n_1_col, n_2_col)) - .otherwise(pl.col(n_col)) - .alias(n_col) - ) + # Leave n_col as is (number of joint points) + # df_cov = df_cov.with_columns( + # pl.when(pl.col("null_map")) + # .then(pl.min_horizontal(n_1_col, n_2_col)) + # .otherwise(pl.col(n_col)) + # .alias(n_col) + # ) # For any rows where n_col is 0 or null, set n_col and cov_col to null df_cov = df_cov.with_columns( diff --git 
a/tests/expected_power_analysis_test.py b/tests/expected_power_analysis_test.py index 2e163c32..80a44c6d 100644 --- a/tests/expected_power_analysis_test.py +++ b/tests/expected_power_analysis_test.py @@ -14,7 +14,7 @@ _add_wd_ws_bins, _bin_and_group_dataframe_expected_power, _compute_covariance, - _fill_cov_null, + _fill_cov_with_var, _get_num_points_pair, _null_and_sync_covariance, _synchronize_mean_power_cov_nulls, @@ -393,7 +393,7 @@ def test_cov_against_var(): ) -def test_fill_cov_null(): +def test_fill_cov_with_var_dont_fill_all(): """Test the fill_cov_null function.""" test_df = pl.DataFrame( { @@ -421,17 +421,52 @@ def test_fill_cov_null(): "cov_pow_001_pow_000": [1, 1], "cov_pow_001_pow_001": [4, 4], "count_pow_000_pow_000": [1, 2], - "count_pow_000_pow_001": [3, 2], # Note updated value + "count_pow_000_pow_001": [3, 4], # Note values not updated here "count_pow_001_pow_000": [5, 6], "count_pow_001_pow_001": [7, 8], } ) - filled_df = _fill_cov_null(test_df, test_cols=["pow_000", "pow_001"]) + filled_df = _fill_cov_with_var(test_df, test_cols=["pow_000", "pow_001"], fill_all=False) - # with pl.Config(tbl_cols=-1): - # print(expected_df) - # print(filled_df) + assert_frame_equal(filled_df, expected_df, check_row_order=False, check_dtype=False) + + +def test_fill_cov_with_var_fill_all(): + """Test the _fill_cov_with_var function with fill_all=True.""" + test_df = pl.DataFrame( + { + "wd_bin": [0, 1], + "ws_bin": [0, 0], + "df_name": ["baseline"] * 2, + "cov_pow_000_pow_000": [4, 4], + "cov_pow_000_pow_001": [1, None], + "cov_pow_001_pow_000": [1, 1], + "cov_pow_001_pow_001": [4, 9], + "count_pow_000_pow_000": [1, 2], + "count_pow_000_pow_001": [3, 4], + "count_pow_001_pow_000": [5, 6], + "count_pow_001_pow_001": [7, 8], + } + ) + + expected_df = pl.DataFrame( + { + "wd_bin": [0, 1], + "ws_bin": [0, 0], + "df_name": ["baseline"] * 2, + "cov_pow_000_pow_000": [4, 4], + "cov_pow_000_pow_001": [4, 6], # Note filled values + "cov_pow_001_pow_000": [4, 6], # Note filled values + 
"cov_pow_001_pow_001": [4, 9], + "count_pow_000_pow_000": [1, 2], + "count_pow_000_pow_001": [3, 4], # Note values not updated here + "count_pow_001_pow_000": [5, 6], # Note values not updated here + "count_pow_001_pow_001": [7, 8], + } + ) + + filled_df = _fill_cov_with_var(test_df, test_cols=["pow_000", "pow_001"], fill_all=True) assert_frame_equal(filled_df, expected_df, check_row_order=False, check_dtype=False)