Skip to content

Commit

Permalink
Merge pull request #87 from HideakiImamura/fix-pre-trained-model-url
Browse files Browse the repository at this point in the history
Fix pre-trained model URL in PFNs4BO sampler
  • Loading branch information
y0z authored Jul 2, 2024
2 parents 3d55b07 + 39f2331 commit ee52694
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
4 changes: 2 additions & 2 deletions package/samplers/pfns4bo/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,12 @@ The benchmark consists of 70 problems. The 8 problems are from HPO tabular bench

## Others

The default prior argument is ``"hebo"``. This trains the PFNs model in the init of the sampler. If you want to use a pre-trained model, you can download the model checkpoint from the following link: https://github.com/automl/PFNs/blob/main/models_diff/prior_diff_real_checkpoint_n_0_epoch_42.cpkt and load it using the following code:
The default prior argument is ``"hebo"``. This trains the PFNs model in the init of the sampler. If you want to use a pre-trained model, you can download the model checkpoint from the following link: https://github.com/automl/PFNs4BO/tree/main/pfns4bo/final_models and load it using the following code:

```python
import torch

model = torch.load("PATH/TO/prior_diff_real_checkpoint_n_0_epoch_42.cpkt")
model = torch.load("PATH/TO/MODEL.pt")
sampler = PFNs4BOSampler(prior=model)
```

Expand Down
4 changes: 2 additions & 2 deletions package/samplers/pfns4bo/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,13 +148,13 @@ class PFNs4BOSampler(BaseSampler):
The default prior argument is ``"hebo"``. This trains the PFNs model in the
init of the sampler. If you want to use a pre-trained model, you can download
the model checkpoint from the following link:
https://github.com/automl/PFNs/blob/main/models_diff/prior_diff_real_checkpoint_n_0_epoch_42.cpkt
https://github.com/automl/PFNs4BO/tree/main/pfns4bo/final_models
and load it using the following code:
.. code-block:: python
import torch
model = torch.load("PATH/TO/prior_diff_real_checkpoint_n_0_epoch_42.cpkt")
model = torch.load("PATH/TO/MODEL.pt")
sampler = PFNs4BOSampler(prior=model)
.. note::
Expand Down

0 comments on commit ee52694

Please sign in to comment.