-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added plot and change_plotting_library methods directly to the Model …
…class (wrapping lower-level model ones); changing to plotly works but throws many warnings. Added Laplace Approximation of Leave-One-Out Cross-Validation error (LA_LOO) utility function for possible future work in balancing it with complexity scores.
- Loading branch information
Showing
6 changed files
with
98 additions
and
38 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,37 +1,77 @@ | ||
import pickle | ||
import pprint | ||
import numpy as np | ||
import scipy.io as sio | ||
from matplotlib import pyplot as plt | ||
from datetime import datetime | ||
|
||
from GPy_ABCD.Models.modelSearch import * | ||
from GPy_ABCD import config as global_flags | ||
from testConsistency import save_one_run | ||
|
||
|
||
# np.seterr(all='raise') # Raise exceptions instead of RuntimeWarnings. The exceptions can then be caught by the debugger | ||
datasets = ['01-airline', '02-solar', '03-mauna', '04-wheat', '05-temperature', '06-internet', '07-call-centre', '08-radio', '09-gas-production', '10-sulphuric', '11-unemployment', '12-births', '13-wages'] | ||
# Only 1, 2, 10 and 11 have published analyses, and their identified formulae are (deciphered from component descriptions): | ||
# 1: LIN + PER * LIN + SE + WN * LIN | ||
# Default Rules: (PER + C + LIN * (PER + C)) * (PER + PER + WN) * (C + LIN) | ||
# 2: C + CW_1643_1716(PER + SE + RQ + WN * LIN + WN * LIN, C + WN) | ||
# Default Rules: C + (PER + C) * (PER + PER + PER + WN) | ||
# 10: PER + SE + CP_64(PER + WN, CW_69_77(SE, SE) + CP_90(C + WN, WN)) | ||
# Default Rules: (PER + C) * (PER + LIN + PER * LIN) | ||
# 11: SE + PER + SE + SE + WN | ||
# Default Rules: (PER + PER + PER + C) * (C + LIN) | ||
|
||
|
||
if __name__ == '__main__': | ||
retrieve_instead = False | ||
|
||
datasets_to_test = [1, 2]#, 10, 11] | ||
|
||
def run_for_dataset_number(dataset_id): | ||
dataset_name = datasets[dataset_id - 1] | ||
data = sio.loadmat(f'./Data/{dataset_name}.mat') | ||
# print(data.keys()) | ||
|
||
X = data['X'] | ||
Y = data['y'] | ||
|
||
args_to_save = {'start_kernels': start_kernels['Default'], 'p_rules': production_rules['Default'], 'utility_function': BIC, | ||
'rounds': 5, 'beam': 2, 'restarts': 10, 'model_list_fitter': fit_mods_parallel_processes, 'optimiser': GPy_optimisers[0], 'verbose': True} | ||
best_mods, all_mods, all_exprs, expanded, not_expanded = explore_model_space(X, Y, **args_to_save) | ||
|
||
# np.seterr(all='raise') # Raise exceptions instead of RuntimeWarnings. The exceptions can then be caught by the debugger | ||
datasets = ['01-airline', '02-solar', '03-mauna', '04-wheat', '05-temperature', '06-internet', '07-call-centre', '08-radio', '09-gas-production', '10-sulphuric', '11-unemployment', '12-births', '13-wages'] | ||
dataset_name = datasets[1-1] | ||
# for mod_depth in all_mods: print(', '.join([str(mod.kernel_expression) for mod in mod_depth]) + f'\n{len(mod_depth)}') | ||
# | ||
# print() | ||
# | ||
# from matplotlib import pyplot as plt | ||
# for bm in best_mods[:3]: model_printout(bm) | ||
# plt.show() | ||
|
||
data = sio.loadmat(f'./Data/{dataset_name}.mat') | ||
# print(data.keys()) | ||
with open(f'./Pickles/{dataset_name}_{datetime.now().strftime("%d-%m-%Y_%H-%M-%S")}', 'wb') as f: | ||
pickle.dump({'dataset_name': dataset_name, 'best_mods': best_mods[:10], | ||
'str_of_args': pprint.pformat(args_to_save, width = 40, compact = True), | ||
'global_flags': { | ||
'__INCLUDE_SE_KERNEL': global_flags.__INCLUDE_SE_KERNEL, | ||
'__USE_LIN_KERNEL_HORIZONTAL_OFFSET': global_flags.__USE_LIN_KERNEL_HORIZONTAL_OFFSET, | ||
'__USE_NON_PURELY_PERIODIC_PER_KERNEL': global_flags.__USE_NON_PURELY_PERIODIC_PER_KERNEL, | ||
'__FIX_SIGMOIDAL_KERNELS_SLOPE': global_flags.__FIX_SIGMOIDAL_KERNELS_SLOPE, | ||
'__USE_INDEPENDENT_SIDES_CHANGEWINDOW_KERNEL': global_flags.__USE_INDEPENDENT_SIDES_CHANGEWINDOW_KERNEL | ||
} }, f) | ||
|
||
X = data['X'] | ||
Y = data['y'] | ||
# save_one_run(dataset_name, 'UNKNOWN', best_mods, all_mods, all_exprs) | ||
|
||
sorted_models, tested_models, tested_k_exprs, expanded, not_expanded = explore_model_space(X, Y, start_kernels = standard_start_kernels, p_rules = production_rules_all, | ||
restarts = 3, utility_function = 'BIC', rounds = 3, buffer = 2, dynamic_buffer = True, verbose = True, parallel = True) | ||
|
||
for mod_depth in tested_models: print(', '.join([str(mod.kernel_expression) for mod in mod_depth]) + f'\n{len(mod_depth)}') | ||
## ACTUAL EXECUTION ## | ||
|
||
from matplotlib import pyplot as plt | ||
for bm in sorted_models[:3]: | ||
print(bm.kernel_expression) | ||
print(bm.model.kern) | ||
print(bm.model.log_likelihood()) | ||
print(bm.cached_utility_function) | ||
bm.model.plot() | ||
print(bm.interpret()) | ||
if not retrieve_instead: | ||
for id in datasets_to_test: run_for_dataset_number(id) | ||
else: | ||
file_names = ['01-airline_19-12-2020_19-22-21', '02-solar_19-12-2020_22-27-53', '10-sulphuric_20-12-2020_04-01-08', '11-unemployment_20-12-2020_12-04-19'] | ||
with open(f'./Pickles/{file_names[0]}', 'rb') as f: IMPORTED = pickle.load(f) | ||
# print(IMPORTED) | ||
|
||
plt.show() | ||
for bm in IMPORTED['best_mods'][:3]: model_printout(bm) | ||
plt.show() | ||
|
||
save_one_run(dataset_name, 'UNKNOWN', sorted_models, tested_models, tested_k_exprs) | ||
|