Skip to content

Commit

Permalink
Bumped Runner call to updated version
Browse files — browse the repository at this point in the history
  • Loading branch information
abhinaukumar committed Oct 31, 2023
1 parent 070a450 commit b0cd08c
Showing 1 changed file with 8 additions and 5 deletions.
13 changes: 8 additions & 5 deletions examples/gridsearch_crossval.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -13,9 +13,12 @@
from sklearn.preprocessing import MinMaxScaler

from qualitylib.tools import import_python_file, read_dataset
from qualitylib.feature_extractor import get_fex
from qualitylib.runner import Runner
from qualitylib.cross_validate import random_cross_validation

from feature_extractors.ssim_fex import SsimFeatureExtractor # Import feature extractor(s) to make visible to get_fex

np.random.seed(0)


Expand Down Expand Up @@ -53,7 +56,7 @@ def print_agg_stats(stats: Dict[str, Any]) -> None:
print('Stat,Median,LoCI,HiCI,Std') # Not using spaces makes parsing text output as csv easier
for stat_key in stats[max_param_key][0]:
key_stats = np.array([stat[stat_key] for stat in stats[max_param_key]])
print(f'{stat_key},{np.median(key_stats)},{np.percentile(key_stats, lo_ci)},{np.percentile(key_stats, hi_ci)},{np.std(key_stats)}')
print(f'{stat_key},{np.median(key_stats):.4f},{np.percentile(key_stats, lo_ci):.4f},{np.percentile(key_stats, hi_ci):.4f},{np.std(key_stats):.4f}')


def dict_to_str(d: Dict[Any, Any]) -> str:
Expand All @@ -68,7 +71,6 @@ def get_parser() -> argparse.ArgumentParser:
parser.add_argument('--dataset', help='Path to dataset file for which to extract features', type=str)
parser.add_argument('--fex_name', help='Name of feature extractor', type=str)
parser.add_argument('--fex_version', help='Version of feature extractor', type=str, default=None)
parser.add_argument('--feat_file', help='Path to csv containing features', type=str, required=True)
parser.add_argument('--feat_names_file', help='Path to file containing feature sets', type=str, default=None)
parser.add_argument('--regressor', help='Regressor to use', type=str, default='RandomForest')
parser.add_argument('--splits', help='Number of parallel processes', type=int, default=100)
Expand All @@ -85,7 +87,8 @@ def main() -> None:

dataset = import_python_file(args.dataset)
assets = read_dataset(dataset, shuffle=True)
runner = Runner(args.fex_name, args.fex_version, processes=args.processes, use_cache=True) # Reads from stored results if available, else stores results.
FexClass = get_fex(args.fex_name, args.fex_version)
runner = Runner(FexClass, processes=args.processes, use_cache=True) # Reads from stored results if available, else stores results.

if args.feat_names_file is not None:
mod = import_python_file(args.feat_names_file)
Expand All @@ -108,15 +111,15 @@ def main() -> None:

res_dict = {}
for key in feat_names_dict:
results = runner(assets, return_results=True, feat_names=np.array(feat_names_dict[key])) # Extract features and return only specified features for cross-validation.
results = runner(assets, return_results=True, feat_names=np.array(feat_names_dict[key]) if feat_names_dict[key] is not None else None) # Extract features and return only specified features for cross-validation.
start_time = time.time()
temp_res_dict = {}
for model_param_dict in model_params:
agg_stats = random_cross_validation(partial(ModelClass, **model_param_dict), results, splits=args.splits, test_fraction=0.2, processes=args.processes)
temp_res_dict[dict_to_str(model_param_dict)] = agg_stats['stats'] # Metrics computed from each split.
print(f'Tested params: {model_param_dict}. Time elapsed {((time.time() - start_time)/60):.2f} minutes.')
print(f'Results - {key}')
print_agg_stats(res_dict)
print_agg_stats(temp_res_dict)
res_dict[key] = temp_res_dict

with open(args.out_file, 'wb') as out_file:
Expand Down

0 comments on commit b0cd08c

Please sign in to comment.