Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add an inflation factor to correct for multiple contrasts in Stouffer's combination test #117

Merged
merged 13 commits into from
Apr 9, 2024
76 changes: 68 additions & 8 deletions pymare/estimators/combination.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,17 +111,77 @@ class StoufferCombinationTest(CombinationTest):
"""

# Maps Dataset attributes onto fit() args; see BaseEstimator for details.
# Keys (`z`, `w`) are the fit() parameter names; the values are presumably the
# Dataset attribute each one is read from -- TODO confirm in BaseEstimator.
_dataset_attr_map = {"z": "y", "w": "v"}

def fit(self, z, w=None):
    """Fit the estimator to z-values, optionally with weights.

    Parameters
    ----------
    z : array-like
        Z-values to combine.
    w : array-like, optional
        Weights forwarded unchanged to the parent ``fit``. Default is None.

    Returns
    -------
    Whatever the parent class's ``fit`` returns.
    """
    # Pure delegation: the parent implementation does all the work.
    fitted = super().fit(z, w=w)
    return fitted

def p_value(self, z, w=None):
# Keys (`z`, `w`, `g`) are the fit() parameter names; the values are presumably
# the Dataset attributes they are read from -- TODO confirm in BaseEstimator.
# NOTE(review): weights `w` now come from `n` and group labels `g` from `v`;
# verify that callers building a Dataset populate those fields accordingly.
_dataset_attr_map = {"z": "y", "w": "n", "g": "v"}

def _inflation_term(self, z, w, g):
    """Calculate the variance inflation term summed over all groups.

    This term is used to adjust the variance of the combined z-score when
    multiple samples come from the same study (group). For each group it adds
    ``2 * sum_{i<j} w_i * w_j * r_ij``, where ``r_ij`` is the correlation
    between samples ``i`` and ``j`` of that group, estimated across features.

    Parameters
    ----------
    z : :obj:`numpy.ndarray` of shape (n, d)
        Array of z-values.
    w : :obj:`numpy.ndarray` of shape (n, d)
        Array of weights. Only the first column is used.
    g : :obj:`numpy.ndarray` of shape (n, d)
        Array of group labels. Only the first column is used.

    Returns
    -------
    sigma : float
        The variance inflation term, i.e. the sum of the off-diagonal
        elements of the weighted within-group correlation matrices.
        It is 0 when no group has more than one sample.
    """
    # Only center if the samples are not all the same, to prevent division by zero
    # when calculating the correlation matrix.
    # This centering is problematic for N=2
    all_samples_same = np.all(np.equal(z, z[0]), axis=0).all()
    z = z if all_samples_same else z - z.mean(0)

    # Use the value from one feature, as all features have the same groups and weights
    groups = g[:, 0]
    weights = w[:, 0]

    sigma = 0
    for group in np.unique(groups):
        group_indices = np.where(groups == group)[0]

        # For groups with only one sample the contribution to the summand is 0
        n_samples = len(group_indices)
        if n_samples < 2:
            continue

        group_z = z[group_indices]
        # Restrict the weights to this group's rows: the pair indices below are
        # positions *within the group*, not positions in the full arrays.
        group_weights = weights[group_indices]

        # Within-group correlation matrix (rows are samples); keep only the
        # strict upper triangle and double it to account for the lower triangle.
        corr = np.corrcoef(group_z, rowvar=True)
        upper_indices = np.triu_indices(n_samples, k=1)
        non_diag_corr = corr[upper_indices]
        w_i = group_weights[upper_indices[0]]
        w_j = group_weights[upper_indices[1]]

        sigma += (2 * w_i * w_j * non_diag_corr).sum()

    return sigma

def fit(self, z, w=None, g=None):
    """Fit the estimator to z-values, optionally with weights and groups.

    Parameters
    ----------
    z : array-like
        Z-values to combine.
    w : array-like, optional
        Weights forwarded unchanged to the parent ``fit``. Default is None.
    g : array-like, optional
        Group labels forwarded unchanged to the parent ``fit``. Default is None.

    Returns
    -------
    Whatever the parent class's ``fit`` returns.
    """
    # Pure delegation: the parent implementation does all the work.
    fitted = super().fit(z, w=w, g=g)
    return fitted

def p_value(self, z, w=None, g=None):
    """Calculate p-values.

    Parameters
    ----------
    z : :obj:`numpy.ndarray`
        Array of z-values; samples are combined along the first axis.
    w : :obj:`numpy.ndarray`, optional
        Array of weights with the same shape as ``z``. When None, every
        sample gets a weight of one.
    g : :obj:`numpy.ndarray`, optional
        Array of group labels. When given, the variance of the combined
        statistic is inflated by the within-group correlation term; when
        None, samples are treated as independent.

    Returns
    -------
    :obj:`numpy.ndarray`
        Upper-tail (survival function) p-values of the combined z-scores.
    """
    if w is None:
        w = np.ones_like(z)

    # Calculate the variance inflation term, sum of non-diagonal elements of sigma.
    sigma = self._inflation_term(z, w, g) if g is not None else 0

    # The sum of diagonal elements of sigma is given by (w**2).sum(0).
    variance = (w**2).sum(0) + sigma

    cz = (z * w).sum(0) / np.sqrt(variance)
    return ss.norm.sf(cz)


Expand Down
37 changes: 37 additions & 0 deletions pymare/tests/test_combination_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,40 @@ def test_combination_test_from_dataset(Cls, data, mode, expected):
results = est.summary()
z = ss.norm.isf(results.p)
assert np.allclose(z, expected, atol=1e-5)


def test_stouffer_adjusted():
    """Test StoufferCombinationTest with weights and groups."""
    # Case 1: explicit weights and group labels.
    z_maps = np.array(
        [
            [2.1, 0.7, -0.2, 4.1, 3.8],
            [1.1, 0.2, 0.4, 1.3, 1.5],
            [-0.6, -1.6, -2.3, -0.8, -4.0],
            [2.5, 1.7, 2.1, 2.3, 2.5],
            [3.1, 2.7, 3.1, 3.3, 3.5],
            [3.6, 3.2, 3.6, 3.8, 4.0],
        ]
    )
    n_features = z_maps.shape[1]
    # Per-sample values are repeated across every feature column.
    weights = np.tile(np.array([4, 3, 4, 10, 15, 10]), (n_features, 1)).T
    groups = np.tile(np.array([0, 0, 1, 2, 2, 2]), (n_features, 1)).T

    estimator = StoufferCombinationTest("directed")
    params = estimator.fit(z=z_maps, w=weights, g=groups).params_
    z_observed = ss.norm.isf(params["p"])

    # NOTE(review): these reference values appear to be regression pins of the
    # estimator's current output rather than independently derived numbers;
    # re-derive them if the inflation-term computation ever changes.
    z_reference = np.array([5.00088912, 3.70356943, 4.05465924, 5.4633001, 5.18927878])
    assert np.allclose(z_observed, z_reference, atol=1e-5)

    # Case 2: groups but no weights -- a limiting case.
    # All maps are identical and share one group, so every pairwise
    # correlation is one.
    n_maps_l1 = 5
    common_sample = np.array([2.1, 0.7, -0.2])
    data_l1 = np.tile(common_sample, (n_maps_l1, 1))
    groups_l1 = np.tile(np.zeros(n_maps_l1, dtype=int), (data_l1.shape[1], 1)).T

    estimator_l1 = StoufferCombinationTest("directed")
    params_l1 = estimator_l1.fit(z=data_l1, g=groups_l1).params_
    z_l1 = ss.norm.isf(params_l1["p"])

    # With unit weights and unit correlations the inflation term is the
    # number of ordered off-diagonal pairs, n * (n - 1).
    sigma_l1 = n_maps_l1 * (n_maps_l1 - 1)
    z_expected_l1 = n_maps_l1 * common_sample / np.sqrt(n_maps_l1 + sigma_l1)
    assert np.allclose(z_l1, z_expected_l1, atol=1e-5)
Loading