Skip to content

Commit

Permalink
Calculate correlation in good voxels
Browse files Browse the repository at this point in the history
  • Loading branch information
JulioAPeraza committed Jul 12, 2024
1 parent ea16fe3 commit 539aa6f
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 35 deletions.
50 changes: 30 additions & 20 deletions nimare/meta/ibma.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,22 +130,21 @@ def _preprocess_input(self, dataset):
# in the stimator if self.aggressive_mask is True.
self.inputs_[name] = temp_arr

# Regardless of the masking strategy, we need to determine the good voxels here
# to mask the bad ones later when calculating the correlation matrix.
nonzero_voxels_bool = np.all(temp_arr != 0, axis=0)
nonnan_voxels_bool = np.all(~np.isnan(temp_arr), axis=0)
good_voxels_bool = np.logical_and(nonzero_voxels_bool, nonnan_voxels_bool)

if "aggressive_mask" not in self.inputs_.keys():
self.inputs_["aggressive_mask"] = good_voxels_bool
if self.aggressive_mask:
# Determine the good voxels here
nonzero_voxels_bool = np.all(temp_arr != 0, axis=0)
nonnan_voxels_bool = np.all(~np.isnan(temp_arr), axis=0)
good_voxels_bool = np.logical_and(nonzero_voxels_bool, nonnan_voxels_bool)

if "aggressive_mask" not in self.inputs_.keys():
self.inputs_["aggressive_mask"] = good_voxels_bool
else:
# Remove any voxels that are bad in any image-based inputs
self.inputs_["aggressive_mask"] = np.logical_or(
self.inputs_["aggressive_mask"],
good_voxels_bool,
)
else:
# Remove any voxels that are bad in any image-based inputs
self.inputs_["aggressive_mask"] = np.logical_or(
self.inputs_["aggressive_mask"],
good_voxels_bool,
)

if not self.aggressive_mask:
data_bags = zip(*_apply_liberal_mask(temp_arr))

keys = ["values", "voxel_mask", "study_mask"]
Expand Down Expand Up @@ -398,13 +397,24 @@ def _preprocess_input(self, dataset):
self.inputs_["contrast_names"] = np.array([label_to_int[label] for label in labels])
self.inputs_["num_contrasts"] = np.array([label_counts[label] for label in labels])

if self.inputs_["contrast_names"].size != np.unique(self.inputs_["contrast_names"]).size:
n_studies = len(self.inputs_["id"])
if n_studies != np.unique(self.inputs_["contrast_names"]).size:
# If all studies are not unique, we will need to correct for multiple contrasts
# Calculate correlation matrix on valid voxels
self.inputs_["corr_matrix"] = np.corrcoef(
self.inputs_["z_maps"][:, self.inputs_["aggressive_mask"]],
rowvar=True,
)
if self.aggressive_mask:
voxel_mask = self.inputs_["aggressive_mask"]
self.inputs_["corr_matrix"] = np.corrcoef(
self.inputs_["z_maps"][:, voxel_mask],
rowvar=True,
)
else:
self.inputs_["corr_matrix"] = np.zeros((n_studies, n_studies), dtype=float)
for bag in self.inputs_["data_bags"]["z_maps"]:
study_bag = bag["study_mask"]
self.inputs_["corr_matrix"][np.ix_(study_bag, study_bag)] = np.corrcoef(
bag["values"],
rowvar=True,
)

def _generate_description(self):
description = (
Expand Down
37 changes: 22 additions & 15 deletions nimare/reports/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,22 +517,29 @@ def __init__(

# Compute similarity matrix
if self.results.estimator.inputs_["corr_matrix"] is None:
aggressive_mask = self.results.estimator.inputs_["aggressive_mask"]
corr = np.corrcoef(
maps_arr[:, aggressive_mask],
rowvar=True,
)
similarity_table = pd.DataFrame(
index=ids_,
columns=ids_,
data=corr,
)
n_studies = len(ids_)
if self.results.estimator.aggressive_mask:
voxel_mask = self.results.estimator.inputs_["aggressive_mask"]
corr = np.corrcoef(
self.results.estimator.inputs_["z_maps"][:, voxel_mask],
rowvar=True,
)
else:
corr = np.zeros((n_studies, n_studies), dtype=float)
for bag in self.results.estimator.inputs_["data_bags"]["z_maps"]:
study_bag = bag["study_mask"]
corr[np.ix_(study_bag, study_bag)] = np.corrcoef(
bag["values"],
rowvar=True,
)
else:
similarity_table = pd.DataFrame(
index=ids_,
columns=ids_,
data=self.inputs_["corr_matrix"],
)
corr = self.inputs_["corr_matrix"]

similarity_table = pd.DataFrame(
index=ids_,
columns=ids_,
data=corr,
)

plot_heatmap(
similarity_table,
Expand Down

0 comments on commit 539aa6f

Please sign in to comment.