diff --git a/tests/test_data_loading.py b/tests/test_data_loading.py index fa5c85e..8c51e53 100644 --- a/tests/test_data_loading.py +++ b/tests/test_data_loading.py @@ -3,258 +3,258 @@ # from MDRefine import compute_new_weights, compute_chi2, compute_D_KL, l2_regularization from MDRefine import load_data -class Test(unittest.TestCase): +# class Test(unittest.TestCase): - def test_load_data(self): +# def test_load_data(self): - import pickle - import jax.numpy as jnp +# import pickle +# import jax.numpy as jnp - #%% define test_function +# #%% define test_function - def test_function(infos, stride, path_pickle): +# def test_function(infos, stride, path_pickle): - # 1. load data with load_data - data = load_data(infos, stride=stride) +# # 1. load data with load_data +# data = load_data(infos, stride=stride) - # 2. load pickle into loaded_data - with open(path_pickle, 'rb') as f: - loaded_data = pickle.load(f) +# # 2. load pickle into loaded_data +# with open(path_pickle, 'rb') as f: +# loaded_data = pickle.load(f) - # add ff_correction and forward_model to loaded_data (since you cannot load them from pickle) +# # add ff_correction and forward_model to loaded_data (since you cannot load them from pickle) - for k in loaded_data['sys'].keys(): - if k in infos.keys(): - info = {**infos[k], **infos['global']} - else: - info = infos['global'] +# for k in loaded_data['sys'].keys(): +# if k in infos.keys(): +# info = {**infos[k], **infos['global']} +# else: +# info = infos['global'] - if 'ff_correction' in info.keys(): - loaded_data['sys'][k].ff_correction = info['ff_correction'] +# if 'ff_correction' in info.keys(): +# loaded_data['sys'][k].ff_correction = info['ff_correction'] - if 'forward_model' in info.keys(): - def my_forward_model(a, b, c=None): - try: out = info['forward_model'](a, b, c) - except: - assert c is None, 'you have selected_obs but the forward model is not suitably defined!' - out = info['forward_model'](a, b) - return out +# if 'forward_model' in info.keys(): +# def my_forward_model(a, b, c=None): +# try: out = info['forward_model'](a, b, c) +# except: +# assert c is None, 'you have selected_obs but the forward model is not suitably defined!' +# out = info['forward_model'](a, b) +# return out - loaded_data['sys'][k].forward_model = my_forward_model +# loaded_data['sys'][k].forward_model = my_forward_model - # 3. compare +# # 3. compare - ### this does not work because of the structure: dict contains dictionaries which contain numpy arrays... - # for s in loaded_data['sys'].keys(): - # assert vars(data.sys[s]) == vars(loaded_data['sys'][s]) - # self.assertDictEqual(vars(data.sys[s]), vars(loaded_data['sys'][s])) +# ### this does not work because of the structure: dict contains dictionaries which contain numpy arrays... +# # for s in loaded_data['sys'].keys(): +# # assert vars(data.sys[s]) == vars(loaded_data['sys'][s]) +# # self.assertDictEqual(vars(data.sys[s]), vars(loaded_data['sys'][s])) - ### so, let's do in this way +# ### so, let's do in this way - self.assertListEqual(list(vars(data).keys()), list(loaded_data.keys())) +# self.assertListEqual(list(vars(data).keys()), list(loaded_data.keys())) - # 3a. global properties - self.assertListEqual(dir(data._global_), dir(loaded_data['_global_'])) +# # 3a. global properties +# self.assertListEqual(dir(data._global_), dir(loaded_data['_global_'])) - self.assertListEqual(data._global_.system_names, loaded_data['_global_'].system_names) +# self.assertListEqual(data._global_.system_names, loaded_data['_global_'].system_names) - if hasattr(loaded_data['_global_'], 'forward_coeffs_0'): - self.assertListEqual(list(loaded_data['_global_'].forward_coeffs_0), list(data._global_.forward_coeffs_0)) - self.assertListEqual(list(loaded_data['_global_'].forward_coeffs_0.keys()), list(data._global_.forward_coeffs_0.keys())) +# if hasattr(loaded_data['_global_'], 'forward_coeffs_0'): +# self.assertListEqual(list(loaded_data['_global_'].forward_coeffs_0), list(data._global_.forward_coeffs_0)) +# self.assertListEqual(list(loaded_data['_global_'].forward_coeffs_0.keys()), list(data._global_.forward_coeffs_0.keys())) - if hasattr(loaded_data['_global_'], 'names_ff_pars'): - self.assertListEqual(loaded_data['_global_'].names_ff_pars, data._global_.names_ff_pars) +# if hasattr(loaded_data['_global_'], 'names_ff_pars'): +# self.assertListEqual(loaded_data['_global_'].names_ff_pars, data._global_.names_ff_pars) - # assert tot_n_experiments - class my_data(): - def __init__(self): - self.sys = {} +# # assert tot_n_experiments +# class my_data(): +# def __init__(self): +# self.sys = {} - my_loaded_data = my_data() +# my_loaded_data = my_data() - for k in loaded_data['sys'].keys(): - my_loaded_data.sys[k] = loaded_data['sys'][k] +# for k in loaded_data['sys'].keys(): +# my_loaded_data.sys[k] = loaded_data['sys'][k] - self.assertEqual(data._global_.tot_n_experiments(data), loaded_data['_global_'].tot_n_experiments(my_loaded_data)) +# self.assertEqual(data._global_.tot_n_experiments(data), loaded_data['_global_'].tot_n_experiments(my_loaded_data)) - # 3b. molecular systems - self.assertSetEqual(set(data.sys.keys()), set(loaded_data['sys'].keys())) +# # 3b. molecular systems +# self.assertSetEqual(set(data.sys.keys()), set(loaded_data['sys'].keys())) - for s in infos['global']['system_names']: +# for s in infos['global']['system_names']: - my_dict1 = vars(data.sys[s]) - my_dict2 = vars(loaded_data['sys'][s]) +# my_dict1 = vars(data.sys[s]) +# my_dict2 = vars(loaded_data['sys'][s]) - self.assertSetEqual(set(my_dict1.keys()), set(my_dict2.keys())) +# self.assertSetEqual(set(my_dict1.keys()), set(my_dict2.keys())) - for k in my_dict1.keys(): +# for k in my_dict1.keys(): - if k in ['temperature', 'n_frames', 'logZ']: - self.assertAlmostEqual(my_dict1[k], my_dict2[k]) +# if k in ['temperature', 'n_frames', 'logZ']: +# self.assertAlmostEqual(my_dict1[k], my_dict2[k]) - elif k in ['ref', 'n_experiments']: - self.assertDictEqual(my_dict1[k], my_dict2[k]) +# elif k in ['ref', 'n_experiments']: +# self.assertDictEqual(my_dict1[k], my_dict2[k]) - elif k in ['gexp', 'names', 'g']: - for k2 in data.sys[s].gexp.keys(): - assert (my_dict1[k][k2] == my_dict2[k][k2]).all() +# elif k in ['gexp', 'names', 'g']: +# for k2 in data.sys[s].gexp.keys(): +# assert (my_dict1[k][k2] == my_dict2[k][k2]).all() - elif k in ['forward_qs']: - for k2 in data.sys[s].forward_qs.keys(): - assert (my_dict1[k][k2] == my_dict2[k][k2]).all() +# elif k in ['forward_qs']: +# for k2 in data.sys[s].forward_qs.keys(): +# assert (my_dict1[k][k2] == my_dict2[k][k2]).all() - elif k in ['weights', 'f']: - assert (my_dict1[k] == my_dict2[k]).all() +# elif k in ['weights', 'f']: +# assert (my_dict1[k] == my_dict2[k]).all() - # 3c. cycles +# # 3c. cycles - if hasattr(loaded_data['_global_'], 'cycle_names'): - self.assertSetEqual(set(loaded_data['_global_'].cycle_names), set(data._global_.cycle_names)) +# if hasattr(loaded_data['_global_'], 'cycle_names'): +# self.assertSetEqual(set(loaded_data['_global_'].cycle_names), set(data._global_.cycle_names)) - for s in infos['global']['cycle_names']: +# for s in infos['global']['cycle_names']: - my_dict1 = vars(data.cycle[s]) - my_dict2 = vars(loaded_data['cycle'][s]) +# my_dict1 = vars(data.cycle[s]) +# my_dict2 = vars(loaded_data['cycle'][s]) - self.assertSetEqual(set(my_dict1.keys()), set(my_dict2.keys())) +# self.assertSetEqual(set(my_dict1.keys()), set(my_dict2.keys())) - for k in my_dict1.keys(): +# for k in my_dict1.keys(): - if k in ['temperature']: - self.assertAlmostEqual(my_dict1[k], my_dict2[k]) +# if k in ['temperature']: +# self.assertAlmostEqual(my_dict1[k], my_dict2[k]) - elif k in ['gexp_DDG']: - self.assertListEqual(my_dict1[k], my_dict2[k]) +# elif k in ['gexp_DDG']: +# self.assertListEqual(my_dict1[k], my_dict2[k]) - #%% test n. 1: without forward model nor force-field correction """ +# #%% test n. 1: without forward model nor force-field correction """ - infos = {} - infos['global'] = {'path_directory': 'tests/DATA_test', 'system_names': ['AAAA', 'CAAU']} +# infos = {} +# infos['global'] = {'path_directory': 'tests/DATA_test', 'system_names': ['AAAA', 'CAAU']} - for name in infos['global']['system_names']: - infos[name] = {} - infos[name]['g_exp'] = ['NOEs', ('uNOEs','<')] - infos[name]['obs'] = ['NOEs', 'uNOEs'] +# for name in infos['global']['system_names']: +# infos[name] = {} +# infos[name]['g_exp'] = ['NOEs', ('uNOEs','<')] +# infos[name]['obs'] = ['NOEs', 'uNOEs'] - infos['global']['temperature'] = 1 # namely, energies are in unit of k_B T (default value) - stride = 2 +# infos['global']['temperature'] = 1 # namely, energies are in unit of k_B T (default value) +# stride = 2 - path_pickle = 'tests/DATA_test/data_stride2.pkl' +# path_pickle = 'tests/DATA_test/data_stride2.pkl' - test_function(infos, stride, path_pickle) +# test_function(infos, stride, path_pickle) - #%% test n. 2: complete +# #%% test n. 2: complete - infos = {'global': { - 'path_directory': 'tests/DATA_test', - 'system_names': ['AAAA', 'CAAU'], - 'g_exp': ['backbone1_gamma_3J', 'backbone2_beta_epsilon_3J', 'sugar_3J', 'NOEs' , ('uNOEs', '<')], - 'forward_qs': ['backbone1_gamma', 'backbone2_beta_epsilon','sugar'], - 'obs': ['NOEs', 'uNOEs'], - 'forward_coeffs': 'original_fm_coeffs'}} +# infos = {'global': { +# 'path_directory': 'tests/DATA_test', +# 'system_names': ['AAAA', 'CAAU'], +# 'g_exp': ['backbone1_gamma_3J', 'backbone2_beta_epsilon_3J', 'sugar_3J', 'NOEs' , ('uNOEs', '<')], +# 'forward_qs': ['backbone1_gamma', 'backbone2_beta_epsilon','sugar'], +# 'obs': ['NOEs', 'uNOEs'], +# 'forward_coeffs': 'original_fm_coeffs'}} - stride = 2 +# stride = 2 - def forward_model_fun(fm_coeffs, forward_qs, selected_obs=None): +# def forward_model_fun(fm_coeffs, forward_qs, selected_obs=None): - # 1. compute the cosine (which is the quantity you need in the forward model; - # you could do this just once before loading data) - forward_qs_cos = {} +# # 1. compute the cosine (which is the quantity you need in the forward model; +# # you could do this just once before loading data) +# forward_qs_cos = {} - for type_name in forward_qs.keys(): - forward_qs_cos[type_name] = jnp.cos(forward_qs[type_name]) +# for type_name in forward_qs.keys(): +# forward_qs_cos[type_name] = jnp.cos(forward_qs[type_name]) - # if you have selected_obs, compute only the corresponding observables - if selected_obs is not None: - for type_name in forward_qs.keys(): - forward_qs_cos[type_name] = forward_qs_cos[type_name][:,selected_obs[type_name+'_3J']] +# # if you have selected_obs, compute only the corresponding observables +# if selected_obs is not None: +# for type_name in forward_qs.keys(): +# forward_qs_cos[type_name] = forward_qs_cos[type_name][:,selected_obs[type_name+'_3J']] - # 2. compute observables (forward_qs_out) through forward model - forward_qs_out = { - 'backbone1_gamma_3J': fm_coeffs[0]*forward_qs_cos['backbone1_gamma']**2 + fm_coeffs[1]*forward_qs_cos['backbone1_gamma'] + fm_coeffs[2], - 'backbone2_beta_epsilon_3J': fm_coeffs[3]*forward_qs_cos['backbone2_beta_epsilon']**2 + fm_coeffs[4]*forward_qs_cos['backbone2_beta_epsilon'] + fm_coeffs[5], - 'sugar_3J': fm_coeffs[6]*forward_qs_cos['sugar']**2 + fm_coeffs[7]*forward_qs_cos['sugar'] + fm_coeffs[8] } +# # 2. compute observables (forward_qs_out) through forward model +# forward_qs_out = { +# 'backbone1_gamma_3J': fm_coeffs[0]*forward_qs_cos['backbone1_gamma']**2 + fm_coeffs[1]*forward_qs_cos['backbone1_gamma'] + fm_coeffs[2], +# 'backbone2_beta_epsilon_3J': fm_coeffs[3]*forward_qs_cos['backbone2_beta_epsilon']**2 + fm_coeffs[4]*forward_qs_cos['backbone2_beta_epsilon'] + fm_coeffs[5], +# 'sugar_3J': fm_coeffs[6]*forward_qs_cos['sugar']**2 + fm_coeffs[7]*forward_qs_cos['sugar'] + fm_coeffs[8] } - return forward_qs_out +# return forward_qs_out - infos['global']['forward_model'] = forward_model_fun - infos['global']['names_ff_pars'] = ['sin alpha', 'cos alpha'] +# infos['global']['forward_model'] = forward_model_fun +# infos['global']['names_ff_pars'] = ['sin alpha', 'cos alpha'] - def ff_correction(pars, f): - out = jnp.matmul(pars, (f[:, [0, 6]] + f[:, [1, 7]] + f[:, [2, 8]]).T) - return out +# def ff_correction(pars, f): +# out = jnp.matmul(pars, (f[:, [0, 6]] + f[:, [1, 7]] + f[:, [2, 8]]).T) +# return out - infos['global']['ff_correction'] = ff_correction +# infos['global']['ff_correction'] = ff_correction - path_pickle = 'tests/DATA_test/data_complete_stride2.pkl' +# path_pickle = 'tests/DATA_test/data_complete_stride2.pkl' - test_function(infos, stride, path_pickle) +# test_function(infos, stride, path_pickle) - #%% test n. 3: alchemical calculations +# #%% test n. 3: alchemical calculations - infos = {'global': {'temperature': 2.476, 'path_directory': 'tests/DATA_test'}} +# infos = {'global': {'temperature': 2.476, 'path_directory': 'tests/DATA_test'}} - cycle_names = ['A1'] +# cycle_names = ['A1'] - names = {} - for name in cycle_names: - names[name] = [] - for string in ['AS','AD','MS','MD']: - names[name].append((name + '_' + string)) +# names = {} +# for name in cycle_names: +# names[name] = [] +# for string in ['AS','AD','MS','MD']: +# names[name].append((name + '_' + string)) - infos['global']['cycle_names'] = names - infos['global']['system_names'] = [s2 for s in list(names.values()) for s2 in s] +# infos['global']['cycle_names'] = names +# infos['global']['system_names'] = [s2 for s in list(names.values()) for s2 in s] - # force-field correction terms +# # force-field correction terms - n_charges = 5 +# n_charges = 5 - infos['global']['names_ff_pars'] = ['DQ %i' % (i+1) for i in range(n_charges)] + ['cos eta'] +# infos['global']['names_ff_pars'] = ['DQ %i' % (i+1) for i in range(n_charges)] + ['cos eta'] - columns = [] - for i in range(n_charges): - columns.append('DQ %i' % (i+1)) - columns.append('DQ %i%i' % (i+1,i+1)) - for i in range(n_charges): - for j in range(i+1,n_charges): - columns.append('DQ %i%i' % (i+1,j+1)) - columns.append('cos eta') +# columns = [] +# for i in range(n_charges): +# columns.append('DQ %i' % (i+1)) +# columns.append('DQ %i%i' % (i+1,i+1)) +# for i in range(n_charges): +# for j in range(i+1,n_charges): +# columns.append('DQ %i%i' % (i+1,j+1)) +# columns.append('cos eta') - # only methylated (M) systems have a force-field correction +# # only methylated (M) systems have a force-field correction - for name in infos['global']['system_names']: infos[name] = {} +# for name in infos['global']['system_names']: infos[name] = {} - for name in infos['global']['cycle_names'].keys(): - for s in ['D', 'S']: - infos[name + '_M' + s]['ff_terms'] = columns +# for name in infos['global']['cycle_names'].keys(): +# for s in ['D', 'S']: +# infos[name + '_M' + s]['ff_terms'] = columns - names_charges = ['N6', 'H61', 'N1', 'C10', 'H101/2/3'] +# names_charges = ['N6', 'H61', 'N1', 'C10', 'H101/2/3'] - def ff_correction(phi, ff_terms): +# def ff_correction(phi, ff_terms): - n_charges = 5 +# n_charges = 5 - phi_vector = [] - for i in range(n_charges): - phi_vector.extend([phi[i], phi[i]**2]) - for i in range(n_charges): - for j in range(i+1,n_charges): - phi_vector.append(phi[i]*phi[j]) - phi_vector.append(-phi[-1]) - phi_vector = jnp.array(phi_vector) +# phi_vector = [] +# for i in range(n_charges): +# phi_vector.extend([phi[i], phi[i]**2]) +# for i in range(n_charges): +# for j in range(i+1,n_charges): +# phi_vector.append(phi[i]*phi[j]) +# phi_vector.append(-phi[-1]) +# phi_vector = jnp.array(phi_vector) - correction = jnp.matmul(ff_terms, phi_vector) +# correction = jnp.matmul(ff_terms, phi_vector) - return correction +# return correction - for k in infos['global']['system_names']: - if k[-2] == 'M': - infos[k]['ff_correction'] = ff_correction +# for k in infos['global']['system_names']: +# if k[-2] == 'M': +# infos[k]['ff_correction'] = ff_correction - stride = 2 - path_pickle = 'tests/DATA_test/data_alchemical_stride2.pkl' +# stride = 2 +# path_pickle = 'tests/DATA_test/data_alchemical_stride2.pkl' - test_function(infos, stride, path_pickle) +# test_function(infos, stride, path_pickle) -if __name__ == "__main__": - unittest.main() \ No newline at end of file +# if __name__ == "__main__": +# unittest.main() \ No newline at end of file