diff --git a/openfasoc/MLoptimization/eval.py b/openfasoc/MLoptimization/eval.py index 1d0b794dc..1a2164837 100644 --- a/openfasoc/MLoptimization/eval.py +++ b/openfasoc/MLoptimization/eval.py @@ -16,7 +16,7 @@ def unlookup(norm_spec, goal_spec): spec = -1*np.multiply((norm_spec+1), goal_spec)/(norm_spec-1) return spec -def evaluate_model(): +def evaluate_model(checkpoint_dir: str = "./last_checkpoint"): specs = yaml.safe_load(Path('newnew_eval_3.yaml').read_text()) #training set up @@ -47,7 +47,7 @@ def evaluate_model(): args = parser.parse_args() env = Envir(env_config=env_config) - agent = PPO.from_checkpoint("./last_checkpoint") + agent = PPO.from_checkpoint(checkpoint_dir) norm_spec_ref = env.global_g spec_num = len(env.specs)