Skip to content

Commit

Permalink
update var cov fill
Browse files Browse the repository at this point in the history
  • Loading branch information
paulf81 committed Dec 2, 2024
1 parent 8aaaf54 commit bb64b7d
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 23 deletions.
4 changes: 2 additions & 2 deletions flasc/analysis/expected_power_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
_add_wd_ws_bins,
_bin_and_group_dataframe_expected_power,
_compute_covariance,
_fill_cov_null,
_fill_cov_with_var,
_null_and_sync_covariance,
_synchronize_mean_power_cov_nulls,
_synchronize_nulls,
Expand Down Expand Up @@ -323,7 +323,7 @@ def _total_uplift_expected_power_with_standard_error(

# If filling missing covariance terms, do it now
if fill_cov_with_var:
df_cov = _fill_cov_null(df_cov, test_cols=test_cols)
df_cov = _fill_cov_with_var(df_cov, test_cols=test_cols)

# If only using the variance, zero out the covariance terms
if variance_only:
Expand Down
35 changes: 21 additions & 14 deletions flasc/analysis/expected_power_analysis_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,20 +281,22 @@ def _null_and_sync_covariance(
return df_cov


def _fill_cov_null(
def _fill_cov_with_var(
df_cov: pl.DataFrame,
test_cols: List[str],
fill_all: bool = True,
) -> pl.DataFrame:
"""Fill null values in covariance according to strategy.
"""Fill covariance terms with the product of the square root of the variances.
Fill the null values in the covariance matrix with the product
of the square root of the variances of the corresponding test columns. Set
the number of points to the minimum of the number of points for the two
corresponding test columns.
Fill the null (or all) values in the covariance matrix with the product
of the square root of the variances of the corresponding test columns.
Leave the number of points as is (the number of shared points between the two test columns).
Args:
df_cov (pl.DataFrame): A polars dataframe with the covariance matrix
test_cols (List[str]): A list of column names to calculate the covariance of
fill_all (bool): If True, fill all values of cov, regardless of whether or not missing/Null
Returns:
pl.DataFrame: A polars dataframe with the null values filled according to the strategy.
Expand All @@ -315,14 +317,18 @@ def _fill_cov_null(
n_col = f"count_{t1}_{t2}"
var_1_col = f"cov_{t1}_{t1}"
var_2_col = f"cov_{t2}_{t2}"
n_1_col = f"count_{t1}_{t1}"
n_2_col = f"count_{t2}_{t2}"
# n_1_col = f"count_{t1}_{t1}"
# n_2_col = f"count_{t2}_{t2}"

# For the rows where cov_col is null, fill the cov_col with the product of the square
# root of the variances of the two test columns and the n_col with the minimum of the
# number of points for the two test columns
df_cov = df_cov.with_columns(null_map=pl.col(cov_col).is_null())

# If fill_all is true, set null_map True for all rows
if fill_all:
df_cov = df_cov.with_columns(pl.lit(True).alias("null_map"))

with pl.Config(tbl_cols=-1):
print(cov_col)
print(df_cov)
Expand All @@ -334,12 +340,13 @@ def _fill_cov_null(
.alias(cov_col)
)

df_cov = df_cov.with_columns(
pl.when(pl.col("null_map"))
.then(pl.min_horizontal(n_1_col, n_2_col))
.otherwise(pl.col(n_col))
.alias(n_col)
)
# Leave n_col as is (number of joint points)
# df_cov = df_cov.with_columns(
# pl.when(pl.col("null_map"))
# .then(pl.min_horizontal(n_1_col, n_2_col))
# .otherwise(pl.col(n_col))
# .alias(n_col)
# )

# For any rows where n_col is 0 or null, set n_col and cov_col to null
df_cov = df_cov.with_columns(
Expand Down
49 changes: 42 additions & 7 deletions tests/expected_power_analysis_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
_add_wd_ws_bins,
_bin_and_group_dataframe_expected_power,
_compute_covariance,
_fill_cov_null,
_fill_cov_with_var,
_get_num_points_pair,
_null_and_sync_covariance,
_synchronize_mean_power_cov_nulls,
Expand Down Expand Up @@ -393,7 +393,7 @@ def test_cov_against_var():
)


def test_fill_cov_null():
def test_fill_cov_with_var_dont_fill_all():
"""Test the fill_cov_null function."""
test_df = pl.DataFrame(
{
Expand Down Expand Up @@ -421,17 +421,52 @@ def test_fill_cov_null():
"cov_pow_001_pow_000": [1, 1],
"cov_pow_001_pow_001": [4, 4],
"count_pow_000_pow_000": [1, 2],
"count_pow_000_pow_001": [3, 2], # Note updated value
"count_pow_000_pow_001": [3, 4], # Note values not updated here
"count_pow_001_pow_000": [5, 6],
"count_pow_001_pow_001": [7, 8],
}
)

filled_df = _fill_cov_null(test_df, test_cols=["pow_000", "pow_001"])
filled_df = _fill_cov_with_var(test_df, test_cols=["pow_000", "pow_001"], fill_all=False)

# with pl.Config(tbl_cols=-1):
# print(expected_df)
# print(filled_df)
assert_frame_equal(filled_df, expected_df, check_row_order=False, check_dtype=False)


def test_fill_cov_with_var_fill_all():
"""Test the fill_cov_null function."""
test_df = pl.DataFrame(
{
"wd_bin": [0, 1],
"ws_bin": [0, 0],
"df_name": ["baseline"] * 2,
"cov_pow_000_pow_000": [4, 4],
"cov_pow_000_pow_001": [1, None],
"cov_pow_001_pow_000": [1, 1],
"cov_pow_001_pow_001": [4, 9],
"count_pow_000_pow_000": [1, 2],
"count_pow_000_pow_001": [3, 4],
"count_pow_001_pow_000": [5, 6],
"count_pow_001_pow_001": [7, 8],
}
)

expected_df = pl.DataFrame(
{
"wd_bin": [0, 1],
"ws_bin": [0, 0],
"df_name": ["baseline"] * 2,
"cov_pow_000_pow_000": [4, 4],
"cov_pow_000_pow_001": [4, 6], # Note filled values
"cov_pow_001_pow_000": [4, 6], # Note filled values
"cov_pow_001_pow_001": [4, 9],
"count_pow_000_pow_000": [1, 2],
"count_pow_000_pow_001": [3, 4], # Note values not updated here
"count_pow_001_pow_000": [5, 6], # Note values not updated here
"count_pow_001_pow_001": [7, 8],
}
)

filled_df = _fill_cov_with_var(test_df, test_cols=["pow_000", "pow_001"], fill_all=True)

assert_frame_equal(filled_df, expected_df, check_row_order=False, check_dtype=False)

Expand Down

0 comments on commit bb64b7d

Please sign in to comment.