From 39f23313e61446fa00592b594df4a240a75429ac Mon Sep 17 00:00:00 2001 From: mamu Date: Tue, 2 Jul 2024 16:02:07 +0900 Subject: [PATCH] Fix pre-trained model URL in PFNs4BO sampler --- package/samplers/pfns4bo/README.md | 4 ++-- package/samplers/pfns4bo/sampler.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/package/samplers/pfns4bo/README.md b/package/samplers/pfns4bo/README.md index e8dd6ebc..91e78131 100644 --- a/package/samplers/pfns4bo/README.md +++ b/package/samplers/pfns4bo/README.md @@ -51,12 +51,12 @@ The benchmark consists of 70 problems. The 8 problems are from HPO tabular bench ## Others -The default prior argument is ``"hebo"``. This trains the PFNs model in the init of the sampler. If you want to use a pre-trained model, you can download the model checkpoint from the following link: https://github.com/automl/PFNs/blob/main/models_diff/prior_diff_real_checkpoint_n_0_epoch_42.cpkt and load it using the following code: +The default prior argument is ``"hebo"``. This trains the PFNs model in the init of the sampler. If you want to use a pre-trained model, you can download the model checkpoint from the following link: https://github.com/automl/PFNs4BO/tree/main/pfns4bo/final_models and load it using the following code: ```python import torch -model = torch.load("PATH/TO/prior_diff_real_checkpoint_n_0_epoch_42.cpkt") +model = torch.load("PATH/TO/MODEL.pt") sampler = PFNs4BOSampler(prior=model) ``` diff --git a/package/samplers/pfns4bo/sampler.py b/package/samplers/pfns4bo/sampler.py index 2b03f8f9..5e11d62b 100644 --- a/package/samplers/pfns4bo/sampler.py +++ b/package/samplers/pfns4bo/sampler.py @@ -148,13 +148,13 @@ class PFNs4BOSampler(BaseSampler): The default prior argument is ``"hebo"``. This trains the PFNs model in the init of the sampler. If you want to use a pre-trained model, you can download the model checkpoint from the following link: - https://github.com/automl/PFNs/blob/main/models_diff/prior_diff_real_checkpoint_n_0_epoch_42.cpkt + https://github.com/automl/PFNs4BO/tree/main/pfns4bo/final_models and load it using the following code: .. code-block:: python import torch - model = torch.load("PATH/TO/prior_diff_real_checkpoint_n_0_epoch_42.cpkt") + model = torch.load("PATH/TO/MODEL.pt") sampler = PFNs4BOSampler(prior=model) .. note::