Skip to content

Commit

Permalink
improve d2 sim estimates plot
Browse files Browse the repository at this point in the history
  • Loading branch information
OnnoKampman committed Apr 21, 2024
1 parent a242ed4 commit 90af94d
Showing 1 changed file with 44 additions and 20 deletions.
64 changes: 44 additions & 20 deletions benchmarks/fmri/sim/plotters/plot_TVFC_estimates.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,15 @@
from helpers.synthetic_covariance_structures import get_ground_truth_covariance_structure, get_ylim


def _plot_d2_all_covariance_structures(
config_dict: dict, signal_to_noise_ratio: float,
connectivity_metric: str, time_series_noise_type: str, data_split: str,
i_trial: int,
figures_savedir: str = None
def plot_d2_all_covariance_structures(
config_dict: dict,
signal_to_noise_ratio: float,
connectivity_metric: str,
time_series_noise_type: str,
data_split: str,
i_trial: int,
figsize: tuple[float] = (4, 7),
figures_savedir: str = None,
) -> None:
"""
Plots bivariate correlation edge for all synthetic covariance structures considered.
Expand All @@ -28,21 +32,24 @@ def _plot_d2_all_covariance_structures(
----------
:param config_dict:
:param signal_to_noise_ratio:
:param figure_filename: note that .eps files do not render transparency plots.
:param figure_filename:
Note that .eps files do not render transparency plots.
:param connectivity_metric:
:param time_series_noise_type:
:param data_split:
:param i_trial:
:param figsize:
:param figures_savedir:
"""
sns.set(style="whitegrid", font_scale=1.5)
sns.set(style="whitegrid")

n_covs_types = len(config_dict['plot-covs-types'])

fig, ax = plt.subplots(
nrows=n_covs_types, ncols=1,
nrows=n_covs_types,
ncols=1,
sharex=True,
figsize=(8, 14)
figsize=figsize,
)
for i_covs_type, covs_type in enumerate(config_dict['plot-covs-types']):

Expand All @@ -52,12 +59,15 @@ def _plot_d2_all_covariance_structures(
config_dict['data-dir'], time_series_noise_type,
f'trial_{i_trial:03d}', f'{covs_type:s}_covariance.csv'
)
x, y = load_data(data_file, verbose=False) # (N, 1), (N, D)
x, y = load_data(
data_file,
verbose=False,
) # (N, 1), (N, D)
ground_truth_covariance_structure = get_ground_truth_covariance_structure(
covs_type=covs_type,
n_samples=len(x),
signal_to_noise_ratio=signal_to_noise_ratio,
data_set_name=config_dict['data-set-name']
data_set_name=config_dict['data-set-name'],
)

# Plot ground truth.
Expand All @@ -68,7 +78,9 @@ def _plot_d2_all_covariance_structures(
label='Ground\nTruth'
)
for i_model_name, model_name in enumerate(config_dict['plot-models']):

plot_color = get_palette(config_dict['plot-models'])[i_model_name]

plot_method_tvfc_estimates(
config_dict=config_dict,
model_name=model_name,
Expand All @@ -82,7 +94,7 @@ def _plot_d2_all_covariance_structures(
i_time_series=0,
j_time_series=1,
plot_color=plot_color,
ax=ax[i_covs_type]
ax=ax[i_covs_type],
)

ax[i_covs_type].set_xlim(config_dict['plot-data-xlim'])
Expand All @@ -91,8 +103,10 @@ def _plot_d2_all_covariance_structures(

if i_covs_type == 0:
ax[i_covs_type].legend(
bbox_to_anchor=(1.01, 1.0), frameon=True,
title='TVFC\nestimator', alignment='left'
bbox_to_anchor=(1.01, 1.0),
frameon=True,
title='TVFC\nestimator',
alignment='left',
)

# plt.legend(frameon=True, title='cohort')
Expand All @@ -114,10 +128,20 @@ def _plot_d2_all_covariance_structures(


def _plot_d2_tvfc_estimates_single_covariance_structure(
config_dict: dict,
x_train_locations: np.array, y_train_locations: np.array, ground_truth_covariance_structure: np.array,
figure_filename: str, connectivity_metric: str, time_series_noise_type, data_split: str, i_trial: int, covs_type,
markersize=3.6, bbox_to_anchor=(1.19, 1.0), legend_fontsize=12, figure_savedir: str = None
config_dict: dict,
x_train_locations: np.array,
y_train_locations: np.array,
ground_truth_covariance_structure: np.array,
figure_filename: str,
connectivity_metric: str,
time_series_noise_type,
data_split: str,
i_trial: int,
covs_type,
markersize=3.6,
bbox_to_anchor=(1.19, 1.0),
legend_fontsize=12,
figure_savedir: str = None,
) -> None:
"""
Plots bivariate pair of time series and the predicted covariance structure in one figure.
Expand Down Expand Up @@ -336,7 +360,7 @@ def _plot_d4_tvfc_estimates(

data_set_name = sys.argv[1] # 'd2', 'd3d', or 'd{%d}s'
data_split = sys.argv[2] # 'all', or 'LEOO'
experiment_data = sys.argv[3] # e.g. 'N0200_T0100'
experiment_data = sys.argv[3] # 'Nxxxx_Txxxx'
metric = sys.argv[4] # 'correlation', or 'covariance'

cfg = get_config_dict(
Expand All @@ -353,7 +377,7 @@ def _plot_d4_tvfc_estimates(

# Plot figure for all (bivariate) synthetic covariance structures jointly.
if data_set_name == 'd2': # TODO: also plot this for any sparse case
_plot_d2_all_covariance_structures(
plot_d2_all_covariance_structures(
config_dict=cfg,
signal_to_noise_ratio=SNR,
connectivity_metric=metric,
Expand Down

0 comments on commit 90af94d

Please sign in to comment.