From 31d5da3f567ef8a9aeafd45cf98549aed81ac7bb Mon Sep 17 00:00:00 2001 From: Felix Benning Date: Thu, 16 May 2024 20:09:39 +0200 Subject: [PATCH] multiple runs of covariance fit --- benchmarking/classification/__main__.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/benchmarking/classification/__main__.py b/benchmarking/classification/__main__.py index e094481..a329b45 100644 --- a/benchmarking/classification/__main__.py +++ b/benchmarking/classification/__main__.py @@ -114,6 +114,18 @@ def main(problem_name, opt): cache=f"""cache/{problem["dataset"].__name__}/{problem["model"].__name__}/covariance_cache.csv""", ) + if opt == "cov": + for run in range(20): + sq_exp_cov_model = covariance.SquaredExponential() + sq_exp_cov_model.auto_fit( + model_factory=problem["model"], + loss=problem["loss"], + data=data.data_train, + tol=problem['tol'], + cache=f"""cache/{problem["dataset"].__name__}/{problem["model"].__name__}_run={run}/covariance_cache.csv""", + ) + + if opt == "RFD": for seed in range(20): problem["seed"] = seed