Skip to content

Commit

Permalink
minor edits
Browse files Browse the repository at this point in the history
  • Loading branch information
OnnoKampman committed Apr 22, 2024
1 parent 78fd31d commit 4b251e7
Show file tree
Hide file tree
Showing 11 changed files with 38 additions and 1,839 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
f'{covs_type:s}_covariance.csv'
)
if not os.path.exists(data_filepath):
logging.warning(f"File '{data_filepath:s}' not found.")
logging.warning(f"Data file '{data_filepath:s}' not found.")
if covs_type == 'boxcar':
data_filepath = os.path.join(
cfg['data-dir'], noise_type, f'trial_{i_trial:03d}',
Expand All @@ -48,6 +48,8 @@
if not os.path.exists(data_filepath):
logging.warning(f"File '{data_filepath:s}' not found.")
continue
else:
continue
x, y = load_data(
data_filepath,
verbose=False,
Expand Down
64 changes: 35 additions & 29 deletions helpers/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,11 +241,11 @@ def get_tvfc_estimates(

repetition_time = None

wp_joint_model_savedir = os.path.join(
wp_model_savedir = os.path.join(
config_dict['experiments-basedir'], noise_type, data_split,
f'trial_{i_trial:03d}', model_name
)
wp_joint_model_filename = f'{covs_type:s}.json'
wp_model_filename = f'{covs_type:s}.json'

tvfc_estimates_savedir = os.path.join(
config_dict['experiments-basedir'], noise_type, f'trial_{i_trial:03d}',
Expand All @@ -256,10 +256,10 @@ def get_tvfc_estimates(
)

# Fix renaming issue.
if not os.path.exists(os.path.join(wp_joint_model_savedir, wp_joint_model_filename)):
logging.warning(f"Model file {os.path.join(wp_joint_model_savedir, wp_joint_model_filename):s} not found.")
if not os.path.exists(os.path.join(wp_model_savedir, wp_model_filename)):
logging.warning(f"WP model file {os.path.join(wp_model_savedir, wp_model_filename):s} not found.")
if covs_type == 'boxcar':
wp_joint_model_filename = 'checkerboard.json'
wp_model_filename = 'checkerboard.json'
if not os.path.exists(tvfc_estimates_filepath):
if covs_type == 'boxcar':
tvfc_estimates_filepath = os.path.join(
Expand All @@ -270,11 +270,11 @@ def get_tvfc_estimates(
assert experiment_dimensionality is not None
assert scan_id is not None
assert subject is not None
wp_joint_model_savedir = os.path.join(
wp_model_savedir = os.path.join(
config_dict['experiments-basedir'], 'saved_models', f'scan_{scan_id:d}',
data_split, experiment_dimensionality, model_name
)
wp_joint_model_filename = f'{subject:d}.json'
wp_model_filename = f'{subject:d}.json'
repetition_time = config_dict['repetition-time']
# tvfc_estimates_filepath = os.path.join(
# config_dict['experiments-basedir'], 'TVFC_estimates', f'scan_{scan_id:d}',
Expand All @@ -283,27 +283,27 @@ def get_tvfc_estimates(
case 'rockland':
pp_pipeline = 'custom_fsl_pipeline'
assert subject is not None
wp_joint_model_savedir = os.path.join(
wp_model_savedir = os.path.join(
config_dict['experiments-basedir'], pp_pipeline, 'saved_models', data_split, model_name
)
wp_joint_model_filename = f"{subject.removesuffix('.csv'):s}.json"
wp_model_filename = f"{subject.removesuffix('.csv'):s}.json"
repetition_time = config_dict['repetition-time']
case _:
logging.error(f"Dataset '{data_set_name:s}' not recognized.")
return

match model_name:
case 'VWP' | 'VWP_joint':
if os.path.exists(os.path.join(wp_joint_model_savedir, wp_joint_model_filename)):
if os.path.exists(os.path.join(wp_model_savedir, wp_model_filename)):
k = gpflow.kernels.Matern52()
m = VariationalWishartProcess(
x_train, y_train,
nu=n_time_series,
kernel=k
)
m.load_from_params_dict(
savedir=wp_joint_model_savedir,
model_name=wp_joint_model_filename,
savedir=wp_model_savedir,
model_name=wp_model_filename,
)
if metric == 'correlation':
all_covs_means, _ = m.predict_corr(x_train) # Tensor of shape (N, D, D), _
Expand All @@ -312,10 +312,10 @@ def get_tvfc_estimates(
estimated_tvfc = all_covs_means.numpy() # (N, D, D)
del m
else:
logging.warning(f"VWP model not found in '{wp_joint_model_savedir:s}'.")
logging.warning(f"VWP model not found in '{wp_model_savedir:s}'.")
return
case 'SVWP' | 'SVWP_joint':
if os.path.exists(os.path.join(wp_joint_model_savedir, wp_joint_model_filename)):
if os.path.exists(os.path.join(wp_model_savedir, wp_model_filename)):
k = gpflow.kernels.Matern52()
m = SparseVariationalWishartProcess(
D=n_time_series,
Expand All @@ -325,8 +325,8 @@ def get_tvfc_estimates(
verbose=False
)
m.load_from_params_dict(
savedir=wp_joint_model_savedir,
model_name=wp_joint_model_filename,
savedir=wp_model_savedir,
model_name=wp_model_filename,
)
if metric == 'correlation':
all_covs_means, _ = m.predict_corr(x_train) # Tensor of shape (N, D, D), _
Expand All @@ -335,7 +335,7 @@ def get_tvfc_estimates(
estimated_tvfc = all_covs_means.numpy() # (N, D, D)
del m
else:
logging.warning(f"SVWP model not found in '{wp_joint_model_savedir:s}'.")
logging.warning(f"SVWP model not found in '{wp_model_savedir:s}'.")
return
case 'DCC' | 'DCC_joint' | 'DCC_bivariate_loop' | 'GO' | 'GO_joint' | 'GO_bivariate_loop':
if os.path.exists(tvfc_estimates_filepath):
Expand Down Expand Up @@ -405,20 +405,26 @@ def get_test_location_estimated_covariance_structure(
subject: str,
data_split: str,
y_train_locations: np.array = None,
scan_id: int = None, experiment_dimensionality: str = None,
scan_id: int = None,
experiment_dimensionality: str = None,
noise_type: str = None,
i_trial: int = None,
covs_type: str = None,
connectivity_metric: str = 'covariance',
):
) -> np.array:
"""
Estimates test location covariance structure.
For all methods except the Wishart process models, we load the pre-saved train locations predictions.
Parameters
----------
experiment_dimensionality: str
'multivariate' or 'bivariate'.
"""
data_set_name = config_dict['data-set-name']
match data_set_name:
case 'd2' | 'd3d' | 'd3s' | 'd4s' | 'd6s' | 'd9s' | 'd15s': # TODO: generalize this
model_savedir = os.path.join(
wp_model_savedir = os.path.join(
config_dict['experiments-basedir'], noise_type, data_split,
f'trial_{i_trial:03d}', model_name
)
Expand All @@ -433,8 +439,8 @@ def get_test_location_estimated_covariance_structure(
)

# Fix renaming issue.
if not os.path.exists(os.path.join(model_savedir, wp_model_filename)):
logging.warning(f"Model file {os.path.join(model_savedir, wp_model_filename):s} not found.")
if not os.path.exists(os.path.join(wp_model_savedir, wp_model_filename)):
logging.warning(f"WP model file {os.path.join(wp_model_savedir, wp_model_filename):s} not found.")
if covs_type == 'boxcar':
wp_model_filename = 'checkerboard.json'
if not os.path.exists(tvfc_estimates_filepath):
Expand All @@ -446,7 +452,7 @@ def get_test_location_estimated_covariance_structure(
case 'HCP_PTN1200_recon2':
assert scan_id is not None
assert experiment_dimensionality is not None
model_savedir = os.path.join(
wp_model_savedir = os.path.join(
config_dict['experiments-basedir'], 'saved_models', f'scan_{scan_id:d}',
data_split, experiment_dimensionality, model_name
)
Expand All @@ -458,7 +464,7 @@ def get_test_location_estimated_covariance_structure(
)
case 'rockland':
pp_pipeline = 'custom_fsl_pipeline'
model_savedir = os.path.join(
wp_model_savedir = os.path.join(
config_dict['experiments-basedir'], pp_pipeline, 'saved_models', data_split, model_name
)
wp_model_filename = f"{subject.removesuffix('.csv'):s}.json"
Expand All @@ -472,25 +478,25 @@ def get_test_location_estimated_covariance_structure(

match model_name:
case 'VWP' | 'VWP_joint': # we do not have to interpolate linearly here
wp_model_filepath = os.path.join(model_savedir, wp_model_filename)
wp_model_filepath = os.path.join(wp_model_savedir, wp_model_filename)
if not os.path.exists(wp_model_filepath):
raise FileNotFoundError(f"Could not load model '{wp_model_filepath:s}'.")
raise FileNotFoundError(f"Could not load WP model '{wp_model_filepath:s}'.")
k = gpflow.kernels.Matern52()
m = VariationalWishartProcess(
x_observed=x_train_locations,
y_observed=y_train_locations,
nu=n_time_series,
kernel=k
)
m.load_from_params_dict(model_savedir, wp_model_filename)
m.load_from_params_dict(wp_model_savedir, wp_model_filename)
if connectivity_metric == 'correlation':
all_covs_means, _ = m.predict_corr(x_test_locations) # Tensor of shape (N, D, D), _
else:
all_covs_means, _ = m.predict_cov(x_test_locations) # Tensor of shape (N, D, D), _
test_locations_predicted_covariance_structure = all_covs_means.numpy() # (N, D, D)
del m
case 'SVWP' | 'SVWP_joint': # we do not have to interpolate linearly here
wp_model_filepath = os.path.join(model_savedir, wp_model_filename)
wp_model_filepath = os.path.join(wp_model_savedir, wp_model_filename)
if not os.path.exists(wp_model_filepath):
raise FileNotFoundError(f"Could not load model '{wp_model_filepath:s}'.")
k = gpflow.kernels.Matern52()
Expand All @@ -502,7 +508,7 @@ def get_test_location_estimated_covariance_structure(
verbose=False
)
m.load_from_params_dict(
savedir=model_savedir,
savedir=wp_model_savedir,
model_name=wp_model_filename,
)
test_locations_predicted_covariance_structure, _ = m.predict_cov(x_new=x_test_locations) # Tensor of shape (N_test, D, D)
Expand Down
Loading

0 comments on commit 4b251e7

Please sign in to comment.