From 90af94db1a332cce948882261e481c895ee8d03a Mon Sep 17 00:00:00 2001 From: Onno Kampman Date: Sun, 21 Apr 2024 22:56:32 +0800 Subject: [PATCH] improve d2 sim estimates plot --- .../fmri/sim/plotters/plot_TVFC_estimates.py | 64 +++++++++++++------ 1 file changed, 44 insertions(+), 20 deletions(-) diff --git a/benchmarks/fmri/sim/plotters/plot_TVFC_estimates.py b/benchmarks/fmri/sim/plotters/plot_TVFC_estimates.py index a1389e65..f6d048c1 100644 --- a/benchmarks/fmri/sim/plotters/plot_TVFC_estimates.py +++ b/benchmarks/fmri/sim/plotters/plot_TVFC_estimates.py @@ -15,11 +15,15 @@ from helpers.synthetic_covariance_structures import get_ground_truth_covariance_structure, get_ylim -def _plot_d2_all_covariance_structures( - config_dict: dict, signal_to_noise_ratio: float, - connectivity_metric: str, time_series_noise_type: str, data_split: str, - i_trial: int, - figures_savedir: str = None +def plot_d2_all_covariance_structures( + config_dict: dict, + signal_to_noise_ratio: float, + connectivity_metric: str, + time_series_noise_type: str, + data_split: str, + i_trial: int, + figsize: tuple[float] = (4, 7), + figures_savedir: str = None, ) -> None: """ Plots bivariate correlation edge for all synthetic covariance structures considered. @@ -28,21 +32,24 @@ def _plot_d2_all_covariance_structures( ---------- :param config_dict: :param signal_to_noise_ratio: - :param figure_filename: note that .eps files do not render transparency plots. + :param figure_filename: + Note that .eps files do not render transparency plots. :param connectivity_metric: :param time_series_noise_type: :param data_split: :param i_trial: + :param figsize: :param figures_savedir: """ - sns.set(style="whitegrid", font_scale=1.5) + sns.set(style="whitegrid") n_covs_types = len(config_dict['plot-covs-types']) fig, ax = plt.subplots( - nrows=n_covs_types, ncols=1, + nrows=n_covs_types, + ncols=1, sharex=True, - figsize=(8, 14) + figsize=figsize, ) for i_covs_type, covs_type in enumerate(config_dict['plot-covs-types']): @@ -52,12 +59,15 @@ def _plot_d2_all_covariance_structures( config_dict['data-dir'], time_series_noise_type, f'trial_{i_trial:03d}', f'{covs_type:s}_covariance.csv' ) - x, y = load_data(data_file, verbose=False) # (N, 1), (N, D) + x, y = load_data( + data_file, + verbose=False, + ) # (N, 1), (N, D) ground_truth_covariance_structure = get_ground_truth_covariance_structure( covs_type=covs_type, n_samples=len(x), signal_to_noise_ratio=signal_to_noise_ratio, - data_set_name=config_dict['data-set-name'] + data_set_name=config_dict['data-set-name'], ) # Plot ground truth. @@ -68,7 +78,9 @@ def _plot_d2_all_covariance_structures( label='Ground\nTruth' ) for i_model_name, model_name in enumerate(config_dict['plot-models']): + plot_color = get_palette(config_dict['plot-models'])[i_model_name] + plot_method_tvfc_estimates( config_dict=config_dict, model_name=model_name, @@ -82,7 +94,7 @@ def _plot_d2_all_covariance_structures( i_time_series=0, j_time_series=1, plot_color=plot_color, - ax=ax[i_covs_type] + ax=ax[i_covs_type], ) ax[i_covs_type].set_xlim(config_dict['plot-data-xlim']) @@ -91,8 +103,10 @@ def _plot_d2_all_covariance_structures( if i_covs_type == 0: ax[i_covs_type].legend( - bbox_to_anchor=(1.01, 1.0), frameon=True, - title='TVFC\nestimator', alignment='left' + bbox_to_anchor=(1.01, 1.0), + frameon=True, + title='TVFC\nestimator', + alignment='left', ) # plt.legend(frameon=True, title='cohort') @@ -114,10 +128,20 @@ def _plot_d2_all_covariance_structures( def _plot_d2_tvfc_estimates_single_covariance_structure( - config_dict: dict, - x_train_locations: np.array, y_train_locations: np.array, ground_truth_covariance_structure: np.array, - figure_filename: str, connectivity_metric: str, time_series_noise_type, data_split: str, i_trial: int, covs_type, - markersize=3.6, bbox_to_anchor=(1.19, 1.0), legend_fontsize=12, figure_savedir: str = None + config_dict: dict, + x_train_locations: np.array, + y_train_locations: np.array, + ground_truth_covariance_structure: np.array, + figure_filename: str, + connectivity_metric: str, + time_series_noise_type, + data_split: str, + i_trial: int, + covs_type, + markersize=3.6, + bbox_to_anchor=(1.19, 1.0), + legend_fontsize=12, + figure_savedir: str = None, ) -> None: """ Plots bivariate pair of time series and the predicted covariance structure in one figure. @@ -336,7 +360,7 @@ def _plot_d4_tvfc_estimates( data_set_name = sys.argv[1] # 'd2', 'd3d', or 'd{%d}s' data_split = sys.argv[2] # 'all', or 'LEOO' - experiment_data = sys.argv[3] # e.g. 'N0200_T0100' + experiment_data = sys.argv[3] # 'Nxxxx_Txxxx' metric = sys.argv[4] # 'correlation', or 'covariance' cfg = get_config_dict( @@ -353,7 +377,7 @@ def _plot_d4_tvfc_estimates( # Plot figure for all (bivariate) synthetic covariance structures jointly. if data_set_name == 'd2': # TODO: also plot this for any sparse case - _plot_d2_all_covariance_structures( + plot_d2_all_covariance_structures( config_dict=cfg, signal_to_noise_ratio=SNR, connectivity_metric=metric,