diff --git a/examples/gridsearch_crossval.py b/examples/gridsearch_crossval.py
index 48634c0..d7b4793 100644
--- a/examples/gridsearch_crossval.py
+++ b/examples/gridsearch_crossval.py
@@ -13,9 +13,12 @@
 from sklearn.preprocessing import MinMaxScaler
 
 from qualitylib.tools import import_python_file, read_dataset
+from qualitylib.feature_extractor import get_fex
 from qualitylib.runner import Runner
 from qualitylib.cross_validate import random_cross_validation
 
+from feature_extractors.ssim_fex import SsimFeatureExtractor  # Import feature extractor(s) to make visible to get_fex
+
 np.random.seed(0)
 
 
@@ -53,7 +56,7 @@ def print_agg_stats(stats: Dict[str, Any]) -> None:
     print('Stat,Median,LoCI,HiCI,Std')  # Not using spaces makes parsing text output as csv easier
     for stat_key in stats[max_param_key][0]:
         key_stats = np.array([stat[stat_key] for stat in stats[max_param_key]])
-        print(f'{stat_key},{np.median(key_stats)},{np.percentile(key_stats, lo_ci)},{np.percentile(key_stats, hi_ci)},{np.std(key_stats)}')
+        print(f'{stat_key},{np.median(key_stats):.4f},{np.percentile(key_stats, lo_ci):.4f},{np.percentile(key_stats, hi_ci):.4f},{np.std(key_stats):.4f}')
 
 
 def dict_to_str(d: Dict[Any, Any]) -> str:
@@ -68,7 +71,6 @@ def get_parser() -> argparse.ArgumentParser:
     parser.add_argument('--dataset', help='Path to dataset file for which to extract features', type=str)
     parser.add_argument('--fex_name', help='Name of feature extractor', type=str)
     parser.add_argument('--fex_version', help='Version of feature extractor', type=str, default=None)
-    parser.add_argument('--feat_file', help='Path to csv containing features', type=str, required=True)
     parser.add_argument('--feat_names_file', help='Path to file containing feature sets', type=str, default=None)
     parser.add_argument('--regressor', help='Regressor to use', type=str, default='RandomForest')
     parser.add_argument('--splits', help='Number of parallel processes', type=int, default=100)
@@ -85,7 +87,8 @@ def main() -> None:
     dataset = import_python_file(args.dataset)
     assets = read_dataset(dataset, shuffle=True)
 
-    runner = Runner(args.fex_name, args.fex_version, processes=args.processes, use_cache=True)  # Reads from stored results if available, else stores results.
+    FexClass = get_fex(args.fex_name, args.fex_version)
+    runner = Runner(FexClass, processes=args.processes, use_cache=True)  # Reads from stored results if available, else stores results.
 
     if args.feat_names_file is not None:
         mod = import_python_file(args.feat_names_file)
@@ -108,7 +111,7 @@ def main() -> None:
 
     res_dict = {}
     for key in feat_names_dict:
-        results = runner(assets, return_results=True, feat_names=np.array(feat_names_dict[key]))  # Extract features and return only specified features for cross-validation.
+        results = runner(assets, return_results=True, feat_names=np.array(feat_names_dict[key]) if feat_names_dict[key] is not None else None)  # Extract features and return only specified features for cross-validation.
         start_time = time.time()
         temp_res_dict = {}
         for model_param_dict in model_params:
@@ -116,7 +119,7 @@ def main() -> None:
             temp_res_dict[dict_to_str(model_param_dict)] = agg_stats['stats']  # Metrics computed from each split.
             print(f'Tested params: {model_param_dict}. Time elapsed {((time.time() - start_time)/60):.2f} minutes.')
         print(f'Results - {key}')
-        print_agg_stats(res_dict)
+        print_agg_stats(temp_res_dict)
         res_dict[key] = temp_res_dict
 
     with open(args.out_file, 'wb') as out_file:
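
Summary of the change: the `Runner` is now constructed from a class resolved via `get_fex(args.fex_name, args.fex_version)` rather than from a name/version pair (extractor modules such as `feature_extractors.ssim_fex` must be imported so `get_fex` can see them), the obsolete `--feat_file` flag is dropped, `feat_names=None` is forwarded unchanged instead of being wrapped in `np.array`, and `print_agg_stats` now receives the per-feature-set `temp_res_dict` instead of the still-accumulating `res_dict`. The `feat_names` guard matters because `np.array(None)` produces a 0-d object array rather than the None sentinel the runner presumably treats as "return all features". A minimal, self-contained sketch of that guard follows; the helper name `as_feat_name_array` and the feature names are illustrative, not part of qualitylib:

```python
import numpy as np

def as_feat_name_array(feat_names):
    # Mirror the guard added in this diff: pass None (the assumed
    # "use all features" sentinel) through untouched, since np.array(None)
    # would yield a 0-d object array instead of None.
    return np.array(feat_names) if feat_names is not None else None

print(as_feat_name_array(['ssim_mean', 'ssim_std']))  # ['ssim_mean' 'ssim_std']
print(as_feat_name_array(None))                       # None
```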