From d8a3c21594f77182655d241a8632bbe772f67e3e Mon Sep 17 00:00:00 2001 From: clementchadebec <47564971+clementchadebec@users.noreply.github.com> Date: Mon, 4 Jul 2022 19:37:25 +0200 Subject: [PATCH] Integration with HuggingFace Hub (#28) * vaegan * remove test_vae * add auto config * start hf hub integration * model reloading refacto * add hf hub to AutoModel * add nf to AutoConfig * remove typo * add wandb tutorial * add hf hub tutorial * update gitignore * Update README * Update README * Update README * Update README * remove hf hub and wandb from coverage compute * update demo * update tests with AutoModel * remove wandb and hf hub from coverage compute * update README * black and isort formatting * update demo * add install command to demo * add finish to wandb callback * add test AutoModel for NF * add install comands * typo correction * black * fix typo * fix logging in load_from_hf_hub * add ModelOutput to AutoModel * switch from dill to cloudpickle * replace pickle by pickle5 * add pickle security to load_from_hf_hub * fix pickling outside of __main__ scope * black & isort * add env config * update test with env saving * update setup * isort & black * add check to hf_load * fix typo * prepare release 0.0.2 * fix typo --- .coveragerc | 3 + .gitignore | 4 +- README.md | 85 +++- .../notebooks/hf_hub_models_sharing.ipynb | 300 ++++++++++++ .../making_your_own_autoencoder.ipynb | 463 ++---------------- .../wandb_experiment_monitoring.ipynb | 191 ++++++++ requirements.txt | 3 +- setup.py | 16 +- .../adversarial_ae/adversarial_ae_model.py | 132 ++++- src/pythae/models/ae/ae_model.py | 15 - src/pythae/models/auto_model/__init__.py | 1 + src/pythae/models/auto_model/auto_config.py | 169 +++++++ src/pythae/models/auto_model/auto_model.py | 334 ++++++++++++- src/pythae/models/base/base_config.py | 5 + src/pythae/models/base/base_model.py | 232 ++++++++- src/pythae/models/base/base_utils.py | 15 +- .../models/beta_tc_vae/beta_tc_vae_model.py | 15 - src/pythae/models/beta_vae/beta_vae_model.py | 15 - .../disentangled_beta_vae_model.py | 15 - .../models/factor_vae/factor_vae_model.py | 54 -- src/pythae/models/hvae/hvae_model.py | 15 - src/pythae/models/info_vae/info_vae_model.py | 15 - src/pythae/models/iwae/iwae_model.py | 15 - .../models/msssim_vae/msssim_vae_model.py | 15 - .../normalizing_flows/base/base_nf_model.py | 32 +- .../models/normalizing_flows/iaf/iaf_model.py | 15 - .../normalizing_flows/made/made_model.py | 15 - .../models/normalizing_flows/maf/maf_model.py | 15 - .../pixelcnn/pixelcnn_model.py | 15 - .../planar_flow/planar_flow_model.py | 15 - .../radial_flow/radial_flow_model.py | 15 - src/pythae/models/rae_gp/rae_gp_model.py | 15 - src/pythae/models/rae_l2/rae_l2_model.py | 15 - src/pythae/models/rhvae/rhvae_model.py | 136 ++++- src/pythae/models/svae/svae_model.py | 15 - src/pythae/models/vae/vae_model.py | 15 - src/pythae/models/vae_gan/vae_gan_model.py | 132 ++++- src/pythae/models/vae_iaf/vae_iaf_model.py | 15 - .../models/vae_lin_nf/vae_lin_nf_model.py | 15 - src/pythae/models/vamp/vamp_model.py | 16 - src/pythae/models/vq_vae/vq_vae_model.py | 15 - src/pythae/models/wae_mmd/wae_mmd_model.py | 15 - src/pythae/trainers/training_callbacks.py | 7 +- tests/data/custom_architectures.py | 12 +- tests/test_AE.py | 14 +- tests/test_Adversarial_AE.py | 9 +- tests/test_BetaTCVAE.py | 14 +- tests/test_BetaVAE.py | 14 +- tests/test_DisentangledBetaVAE.py | 14 +- tests/test_FactorVAE.py | 9 +- tests/test_HVAE.py | 14 +- tests/test_IAF.py | 17 +- tests/test_IWAE.py | 14 +- tests/test_MADE.py | 376 +------------- tests/test_MAF.py | 17 +- tests/test_MSSSIMVAE.py | 14 +- tests/test_PixelCNN.py | 17 +- tests/test_RHVAE.py | 9 +- tests/test_SVAE.py | 14 +- tests/test_VAE.py | 14 +- tests/test_VAEGAN.py | 9 +- tests/test_VAE_IAF.py | 14 +- tests/test_VAE_LinFlow.py | 14 +- tests/test_VAMP.py | 14 +- tests/test_VQVAE.py | 14 +- tests/test_WAE_MMD.py | 14 +- tests/test_baseAE.py | 4 +- tests/test_info_vae_mmd.py | 14 +- tests/test_planar_flow.py | 17 +- tests/test_radial_flow.py | 17 +- tests/test_rae_gp.py | 14 +- tests/test_rae_l2.py | 14 +- 72 files changed, 1970 insertions(+), 1441 deletions(-) create mode 100644 .coveragerc create mode 100644 examples/notebooks/hf_hub_models_sharing.ipynb create mode 100644 examples/notebooks/wandb_experiment_monitoring.ipynb create mode 100644 src/pythae/models/auto_model/auto_config.py diff --git a/.coveragerc b/.coveragerc new file mode 100644 index 00000000..641c99ce --- /dev/null +++ b/.coveragerc @@ -0,0 +1,3 @@ +[report] +exclude_lines = + # pragma: no cover \ No newline at end of file diff --git a/.gitignore b/.gitignore index ef433bd2..632838d1 100644 --- a/.gitignore +++ b/.gitignore @@ -33,7 +33,7 @@ examples/notebooks/dummy_output_dir/ examples/notebooks/models_training/dummy_output_dir/ logs/ downloads_data/ -*wandb* +*wandb *.slurm configs examples/scripts/generation_jz.py @@ -48,6 +48,8 @@ results.ipynb examples/scripts/artifacts/* examples/scripts/plots/* examples/notebooks/my_model_with_custom_archi/ +examples/notebooks/my_model/ +examples/net.py diff --git a/README.md b/README.md index cc5596b3..6675306f 100644 --- a/README.md +++ b/README.md @@ -31,10 +31,10 @@ # pythae -This library implements some of the most common (Variational) Autoencoder models. In particular it +This library implements some of the most common (Variational) Autoencoder models under a unified implementation. In particular, it provides the possibility to perform benchmark experiments and comparisons by training the models with the same autoencoding neural network architecture. The feature *make your own autoencoder* -allows you to train any of these models with your own data and own Encoder and Decoder neural networks. +allows you to train any of these models with your own data and own Encoder and Decoder neural networks. It integrates an experiment monitoring tool [wandb](https://wandb.ai/) ๐Ÿงช and allows model sharing and loading from the [HuggingFace Hub](https://huggingface.co/models) ๐Ÿค— in a few lines of code. # Installation @@ -244,7 +244,7 @@ The samplers can be used with any model as long as it is suited. For instance, a ``` -## Define you own Autoencoder architecture +## Define you own Autoencoder architecture Pythae provides you the possibility to define your own neural networks within the VAE models. For instance, say you want to train a Wassertstein AE with a specific encoder and decoder, you can do the following: @@ -315,18 +315,93 @@ You can also find predefined neural network architectures for the most common da ``` Replace *mnist* by cifar or celeba to access to other neural nets. -## Getting your hands on the code +## Sharing your models with the HuggingFace Hub ๐Ÿค— +Pythae also allows you to share your models on the [HuggingFace Hub](https://huggingface.co/models). To do so you need: +- a valid HuggingFace account +- the package `huggingface_hub` installed in your virtual env. If not you can install it with +``` +$ python -m pip install huggingface_hub +``` +- to be logged in to your HuggingFace account using +``` +$ huggingface-cli login +``` + +### Uploading a model to the Hub +Any pythae model can be easily uploaded using the method `push_to_hf_hub` +```python +>>> my_vae_model.push_to_hf_hub(hf_hub_path="your_hf_username/your_hf_hub_repo") +``` +**Note:** If `your_hf_hub_repo` already exists and is not empty, files will be overridden. In case, +the repo `your_hf_hub_repo` does not exist, a folder having the same name will be created. + +### Downloading models from the Hub +Equivalently, you can download or reload any Pythae's model directly from the Hub using the method `load_from_hf_hub` +```python +>>> from pythae.models import AutoModel +>>> my_downloaded_vae = AutoModel.load_from_hf_hub(hf_hub_path="path_to_hf_repo") +``` + +## Monitoring your experiments with **Wandb** ๐Ÿงช +Pythae also integrates the experiement tracking tool [wandb](https://wandb.ai/) allowing users to store their configs, monitor their trainings and compare runs through a graphic interface. To be able use this feature you will need: +- a valid wand account +- the package `wandb` installed in your virtual env. If not you can install it with +``` +$ pip install wandb +``` +- to be logged in to your wandb account using +``` +$ wandb login +``` + +### Creating a `WandbCallback` +Launching an experiment monitoring with `wandb` in pythae is pretty simple. The only thing a user needs to do is create a `WandbCallback` instance... + +```python +>>> # Create you callback +>>> from pythae.trainers.training_callbacks import WandbCallback +>>> callbacks = [] # the TrainingPipeline expects a list of callbacks +>>> wandb_cb = WandbCallback() # Build the callback +>>> # SetUp the callback +>>> wandb_cb.setup( +... training_config=your_training_config, # training config +... model_config=your_model_config, # model config +... project_name="your_wandb_project", # specify your wandb project +... entity_name="your_wandb_entity", # specify your wandb entity +... ) +>>> callbacks.append(wandb_cb) # Add it to the callbacks list +``` +...and then pass it to the `TrainingPipeline`. +```python +>>> pipeline = TrainingPipeline( +... training_config=config, +... model=model +... ) +>>> pipeline( +... train_data=train_dataset, +... eval_data=eval_dataset, +... callbacks=callbacks # pass the callbacks to the TrainingPipeline and you are done! +... ) +>>> # You can log to https://wandb.ai/your_wandb_entity/your_wandb_project to monitor your training +``` +See a detailes tutorial + +## Getting your hands on the code To help you to understand the way pythae works and how you can train your models with this library we also provide tutorials: - [making_your_own_autoencoder.ipynb](https://github.com/clementchadebec/benchmark_VAE/tree/main/examples/notebooks) shows you how to pass your own networks to the models implemented in pythae [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/clementchadebec/benchmark_VAE/blob/main/examples/notebooks/making_your_own_autoencoder.ipynb) +- [hf_hub_models_sharing.ipynb](https://github.com/clementchadebec/benchmark_VAE/tree/main/examples/notebooks) shows you how to upload and download models for the HuggingFace Hub [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/clementchadebec/benchmark_VAE/blob/main/examples/notebooks/hf_hub_models_sharing.ipynb) + +- [wandb_experiment_monitoring.ipynb](https://github.com/clementchadebec/benchmark_VAE/tree/main/examples/notebooks) shows you how to monitor you experiments using `wandb` [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/clementchadebec/benchmark_VAE/blob/main/examples/notebooks/wandb_experiment_monitoring.ipynb) + - [models_training](https://github.com/clementchadebec/benchmark_VAE/tree/main/examples/notebooks/models_training) folder provides notebooks showing how to train each implemented model and how to sample from it using `pythae.samplers`. - [scripts](https://github.com/clementchadebec/benchmark_VAE/tree/main/examples/scripts) folder provides in particular an example of a training script to train the models on benchmark data sets (mnist, cifar10, celeba ...) -## Dealing with issues +## Dealing with issues ๐Ÿ› ๏ธ If you are experiencing any issues while running the code or request new features/models to be implemented please [open an issue on github](https://github.com/clementchadebec/benchmark_VAE/issues). diff --git a/examples/notebooks/hf_hub_models_sharing.ipynb b/examples/notebooks/hf_hub_models_sharing.ipynb new file mode 100644 index 00000000..aa2f1dac --- /dev/null +++ b/examples/notebooks/hf_hub_models_sharing.ipynb @@ -0,0 +1,300 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Tutorial - Hugging Face Hub model sharing ๐Ÿค—\n", + "\n", + "In this notebook, we will see how to share your models with the community using the integrated Hugging Face Hub" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Install the library\n", + "%pip install pythae" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Train your Pythae model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import torchvision.datasets as datasets\n", + "\n", + "%load_ext autoreload\n", + "%autoreload 2" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "mnist_trainset = datasets.MNIST(root='../data', train=True, download=True, transform=None)\n", + "\n", + "train_dataset = mnist_trainset.data[:-10000].reshape(-1, 1, 28, 28) / 255.\n", + "eval_dataset = mnist_trainset.data[-10000:].reshape(-1, 1, 28, 28) / 255." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "\n", + "from pythae.models import BetaVAE, BetaVAEConfig\n", + "from pythae.trainers import BaseTrainerConfig\n", + "from pythae.pipelines.training import TrainingPipeline\n", + "from pythae.models.nn.benchmarks.mnist import Encoder_ResNet_VAE_MNIST, Decoder_ResNet_AE_MNIST" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "config = BaseTrainerConfig(\n", + " output_dir='my_model',\n", + " learning_rate=1e-4,\n", + " batch_size=100,\n", + " num_epochs=1, # Change this to train the model a bit more\n", + ")\n", + "\n", + "\n", + "model_config = BetaVAEConfig(\n", + " input_dim=(1, 28, 28),\n", + " latent_dim=16,\n", + " beta=2.\n", + "\n", + ")\n", + "\n", + "model = BetaVAE(\n", + " model_config=model_config,\n", + " encoder=Encoder_ResNet_VAE_MNIST(model_config), \n", + " decoder=Decoder_ResNet_AE_MNIST(model_config) \n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "pipeline = TrainingPipeline(\n", + " training_config=config,\n", + " model=model\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "pipeline(\n", + " train_data=train_dataset,\n", + " eval_data=eval_dataset\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Reload your trained model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "from pythae.models import AutoModel" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "last_training = sorted(os.listdir('my_model'))[-1]\n", + "trained_model = AutoModel.load_from_folder(os.path.join('my_model', last_training, 'final_model'))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Now let's share your model to the community through the Hugging Face hub! ๐Ÿค—" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To be able to access this feature you will need:\n", + "- a valid *username* from your Hugging Face account.\n", + "- the `huggingface_hub` package installed in your virtual env. You can install it by running (`$ python -m pip install huggingface_hub`)\n", + "- to be logged in to your hugginface account by running (`$ huggingface-cli login`)\n", + "\n", + "**note**: If the repo you specified is not empty, its content will be overidden. If the repo does not exist it will be created automatically under the name that was specified." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Before pushing or loading a model from the Hub you may need to run the following\n", + "# !python -m pip install huggingface_hub\n", + "# !huggingface-cli login" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Save the model to the Hub by specifying your username and the name of the repo in which you want to save your model\n", + "trained_model.push_to_hf_hub(\"your_hf_username/my_beta_vae\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "trained_model_from_hf = AutoModel.load_from_hf_hub(\"your_hf_username/my_beta_vae\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# check that everything went well\n", + "assert all(\n", + " [\n", + " torch.equal(trained_model.state_dict()[key], trained_model_from_hf.state_dict()[key])\n", + " for key in model.state_dict().keys()\n", + " ]\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Use your model to do whatever you want" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from pythae.samplers import NormalSampler" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# create normal sampler\n", + "normal_samper = NormalSampler(\n", + " model=trained_model_from_hf\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# sample\n", + "gen_data = normal_samper.sample(\n", + " num_samples=25\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# show results with normal sampler\n", + "fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(10, 10))\n", + "\n", + "for i in range(5):\n", + " for j in range(5):\n", + " axes[i][j].imshow(gen_data[i*5 +j].cpu().squeeze(0), cmap='gray')\n", + " axes[i][j].axis('off')\n", + "plt.tight_layout(pad=0.)" + ] + } + ], + "metadata": { + "interpreter": { + "hash": "3efa06c4da850a09a4898b773c7e91b0da3286dbbffa369a8099a14a8fa43098" + }, + "kernelspec": { + "display_name": "Python 3.8.11 64-bit ('pythae_dev': conda)", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.13" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/examples/notebooks/making_your_own_autoencoder.ipynb b/examples/notebooks/making_your_own_autoencoder.ipynb index f7271acb..6e8d2231 100644 --- a/examples/notebooks/making_your_own_autoencoder.ipynb +++ b/examples/notebooks/making_your_own_autoencoder.ipynb @@ -11,7 +11,7 @@ }, { "cell_type": "code", - "execution_count": 34, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -21,18 +21,9 @@ }, { "cell_type": "code", - "execution_count": 35, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "The autoreload extension is already loaded. To reload it, use:\n", - " %reload_ext autoreload\n" - ] - } - ], + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "import torch\n", "import torchvision.datasets as datasets\n", @@ -54,33 +45,20 @@ }, { "cell_type": "code", - "execution_count": 48, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ "mnist_trainset = datasets.MNIST(root='../data', train=True, download=True, transform=None)\n", - "n_samples = 10000\n", + "n_samples = 100\n", "dataset = mnist_trainset.data.reshape(-1, 1, 28, 28)[:n_samples] / 255." ] }, { "cell_type": "code", - "execution_count": 49, - "metadata": {}, - "outputs": [ - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAswAAACOCAYAAAAl3l5UAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAq2klEQVR4nO3dZ4BT1dbG8T9iRRSxd8UCqNgr6FUUxIIdBRtYsAv2LooKKsWOqCgqXsB2FVRUxILYe7sW7L2AWLBj5f1w3+fsk0xymJLkJJnn92VCcmZmz+bkJFl77bWazJ49GzMzMzMzy22utAdgZmZmZlbO/IbZzMzMzCyB3zCbmZmZmSXwG2YzMzMzswR+w2xmZmZmlsBvmM3MzMzMEsyd9GCTJk0aZc252bNnN5nTMZ6bZJ6f/Dw3+Xlu8vPc5Oe5Seb5yc9zk5/nJpMjzGZmZmZmCfyG2czMzMwsgd8wm5mZmZkl8BtmMzMzM7MEfsNsZmZmZpbAb5jNzMzMzBL4DbOZmZmZWYLEOsxW3jbccEMA+vTpE93Xq1cvAP79738DMGzYMABefvnlEo/OzKz6XH755QAcc8wxALzxxhsA7LTTTgB88skn6QzMzGp45JFHAGjS5H+llbfZZpt6/yxHmM3MzMzMEpRNhLlp06YAtGjRIufj8Shqs2bNAGjTpg0ARx99NAAXXXQRAPvss0907KxZswAYNGgQAOeee24hh52K9dZbD4CHHnoIgIUXXjh6bPbs/zXm6dmzJwC77LILAIsttlgJR1h5OnXqBMDYsWOj+7baaisA3nnnnVTGlJZ+/foB4bky11z/+1zdsWPH6JjHHnus5OOy9C200EIANG/ePLqva9euACyxxBIAXHLJJQD8/vvvJR5d8ay88srR7f333x+Af/75B4A11lgDgLZt2wKNM8LcunXr6PY888wDwJZbbgnAVVddBYT5qo27774bgL333ju6748//mjwONOmuenQoQMAF1xwAQCbb755amOqRpdeeml0W3OtVfeGcITZzMzMzCxBSSLMK664IgDzzjsvEN7xb7HFFtExiyyyCADdunWr9c/9/PPPAbjiiisA2H333QH46aefomNee+01oDoiYptssgkAd955JxCi8YoqQ/jb9WlckeXNNtsMyMxlTvsTuyIQGuP48eNTG8vGG28MwAsvvJDaGNJ24IEHAnDqqacCNSNC8fPMGgdFVnVOtG/fHoB27drl/Z5lllkGCDm+1WDGjBnR7ccffxwIq3eN0VprrQWEa8Zee+0VPaYVqWWXXRYI15G6XD80t9dcc01033HHHQfAjz/+WL9BlwG9Zj/66KMATJs2DYCll146499WP8okOOKII6L7/vzzTyDkMjeEI8xmZmZmZgmKFmFWni3A5MmTgfz5yXWlT6zKtfz555+BkH/61VdfRcd+//33QOXloSpPG2CDDTYAYMyYMUCI4OTy3nvvATBkyBAAbr31VgCeeuopIMwZwIUXXljAEdedcmJXX311IJ0Is6IhrVq1AmCllVaKHtOu2sZCf/v888+f8khKY9NNNwVCTqpy1hU9izvppJMA+PLLL4GwOqbn5HPPPVfcwZaAcnAhRPP2228/ABZYYAEgPCc+++yz6FitaimXt3v37kDIXX377beLOOrS+OWXX6LbjTFHOZteO3bcccei/h5VfQK4/vrrgfBaVg0UWXaEuTC0kq5ccYAnn3wSgNtvv73BP98RZjMzMzOzBH7DbGZmZmaWoGgpGZ9++ml0+9tvvwXqlpKhJc6ZM2cCsPXWW0ePabPa6NGjGzrMsjVixIjodrxM3pwofUNln7TZUekP66yzToFG2HBabnvmmWdSG4PSWw499FAgLLFDdSwlz0nnzp2j23379s14TH+/GjJMnz69dAMrkh49ekS31YBi8cUXB0K6wZQpU4BQJg1g6NChGT9Hx+qYePmrSqHr8eDBg4HMuVH5uGxK+dpuu+2i+7T8qfNF86mv1UCb0gHWXXfd9AZSJlTSNFdKxtdffw2EFAqlveUqK6cCAEqHamwaW9pfEhUBOPPMM4Hwvue7776b4/fqWG1G/uCDD6LHlE5XCI4wm5mZmZklKFqEOf6p4OSTTwZCpOqVV14BQjm4uFdffRWAbbfdFgibLeIbcY499tjCD7hMqN21mgFAzU+hihpPmDABCA1bIGxK0hxr06PaQZbTJ1pFHtI0cuTIjH8rglbttGntxhtvjO7LXgFSVLWSNznNPff/LnEbbbQRANddd130mDbWqkzYgAEDgLBJZL755ouO1YaRLl26ZPz8F198sRjDLgmV4TzkkEPmeKwiNrouxzf9rbbaakUYXXmJb8JWmdRsKk0ZX5mq5OdOkquvvhqAu+66q8ZjKuNVmw1sarql9uIqRSfxn1/Jz7V8VGqvsWy0TnLttdcCoQjAmmuuCYTrcZIzzjgDCCVqtWIMobRwIaT/jsXMzMzMrIyVpHGJPiWqvJzKEMVzwXr37g2EaGm8jA/Am2++Gd0+7LDDijbWtNSm3fXEiROBkK+jvK94qThFTFVoX5+ulD8Wj1wr3znezKQUlEe91FJLlfT35pIdVdX8V7sDDjgAqBnRgZDDW4hWomlTybjslQQI/9fK3c1uiBDP6c2OLKtp0k033VS4wZZYvNlEto8//hgIjXzUuCQeWRaVk6tmWrkDGDVqFADnnHNOxjH6t/bdAFx55ZVFHlk6/vrrLyD3+VAXyoVv2bJlzsf1PIPqarWeTStgzz77bMojSc+vv/4K1C3qrvdNKomq9znFitg7wmxmZmZmlqAkEWbJjuD88MMPNY5R7sltt90G5N5ZW01at24NhDxvRTy/+eab6Bg1YlE0S41a7rvvvoyvtaEGBAAnnngiEJoTlIp2VsfHUmqKbqthiXzxxRdpDKdkVLng4IMPBjKfX4qMDRw4sOTjKjTlIyu3TVELNdOAsDKTr9WudmvnorbP8ZbJlUbXWq3YPfjgg9Fj77//PhAqHiQph5WiUtK5lR1httqJV5TROZjvteDss88uyZhKRZF5vffR6/2qq66a2pjSpOcSwNprrw3A1KlTgeTc4wUXXBAIK1/aY6AI/R133FH4weIIs5mZmZlZIr9hNjMzMzNLUNKUjGzxJS2VU9NGNjVUiC8TVot4uSptclSagjZEqqkHhHI6hU5hyFceqdjatGmT8e/4hs5S0bxrOfndd98FwvxXm5VXXhmAO++8M+8xw4YNA+DRRx8txZAKLr58q1QMNTmaNGkSEJbwAH777beM79dGEW3wiz8/VI5R6Sp33313QceeBm1ka2hqQfv27QswmsqT1JDDAqX8nXbaaUBmGUI1vcmm8rIqUVctlPb2xBNPAKHUbmOzwgorAJnl35Su0qdPHyA53e2SSy4BwsZlXcs233zzwg82xhFmMzMzM7MEqUaY46Xj9ElDJc7UYEDRrnjR8uHDhwNhI0+lWX/99aPb2a1Fd911VyA0J2kMVLqq0FSab/vtt4/uU5mx7DJh2nwQLwlVTTQH2a3RH3nkkei2WkVXGrUtPuqoo6L7dG1QZHm33XbL+/2KeI0dOxYIq11x2kQyZMiQBo+3kmhzozbZ5KLNOvL0008D6ba8LwVFliv1daghtGLVs2dPIKwI56ImSUnzpI23ikLff//9QM1VIKtsal09fvx4IGxCh7DCme+9T7zF9YEHHpjx2Pnnn1/IYeblCLOZmZmZWYJUI8xxar2qTw5q2atPsPoKIdqhxgoqu1YplH8DITdSn6qKFVku53y7RRdddI7HqMmN5ise0Vh++eUBmHfeeYGQM6e/OR6leO6554BQBF+tk1966aX6/wFlTJHVQYMGZdyvdqNqYAK5yzxWAv2/x6MVogjpkksuCcBBBx0UPbbLLrsAIerRvHlzIETC4hGxMWPGADUbKlUDlWRSK1qA/v37AzVXwJKuI8oj1Bz//fffhR+spUrPlXvuuQco3D4Y5fSqPXJjo5bO1USvrRBWdq+//nog93VEeyFOP/10ILxP0vuDeKMlvQ/Qe8ARI0YU/g/IwRFmMzMzM7MEZRNhFuW2vPfee0D4lNGpU6fomAsuuAAI7RCVv1LuTSe0I1btHCFEsfSJvVhy5dtpJ3KpKeKrsVxzzTVAqGyQi3Jv9clSO2ohtNR86623ALjhhhuAkPcej9pPnz4dCC1XVXnk7bffrvffU26UXwj5q2J8+OGHQJiPSqZKGPFd1UsssQQAH330EZCcP6nIqPIol1lmGSCzedCECRMKOOJ0qTKB9lLoHNHfDeE5qrlRPrJy4RWVjlNEaY899gBCTrz+f6x66Dqsr0lqs7qp18YddtgBgIkTJzZ0iBVFq13VJN6gZuTIkUC4DutcUIMkCO3B9VX7uZZbbjkg8/qka70acJWKI8xmZmZmZgnKLsIsb7zxBgDdu3cHYOedd44eU37z4YcfDsDqq68OwLbbblvKIdaZopnKuYTQelatwAtFtZ6za6xOnjw5uq1coVJTNYNPPvkEgA4dOszxez799FMA7rrrLiC0z4TQDrM21AZYEUhFWqtJvNZwvqhOdk5zJVNlk3gljHvvvRcI+W/aIxGvnzxq1CgAvvvuOwBuvfVWIEQy9O9qEL/mKEo8bty4jGPOPffc6LauE0899RQQ5lH3K5c1Ts+pCy+8EKj5nIWwd6Aa5IucbrnlltHtK6+8sqRjKja9Lnfs2BEIuamqRgMwa9asxJ/Ru3fv6Hbfvn0LPMLKoOpf1ViHuUePHkB4nwahnrau1fvuuy8A33//fXTMxRdfDIReHIo0axUjvkqo/SqfffYZEM5HXeeLxRFmMzMzM7MEfsNsZmZmZpagbFMyRCH80aNHR/cpgVybTLQEprD8lClTSja+htISZaFK4ykVo1+/fgCcfPLJQNjkpmUPgJ9//rkgv7O+Bg8eXPLfGd88CsmtoiuNNpNmN2WJU0rCO++8U4ohlZRKBkJID6gNXT+0FKgl9mpI19EGv3i6ha4Jog1WahwA4bqreVQjCTUpiW/kUzMXpWlos44awTz88MPRsXrOx5diIb0NyA2Rr3GJNj1CKNWnDcnVQul09WkYEU8TbKwpGUpXEj1PVcgAwhxXGqXKxv/GgQMHAplpGtl0LqhEnMrM5aI0DaW2FDsVQxxhNjMzMzNLULYRZpUR23PPPQHYeOONo8fiBbEhfHp//PHHSzS6wilEObl4mTpFj5R4r4hit27dGvx7qpHKGFaDBx98EICWLVvWeEwbI7NbilrYjJsdMazkTX9NmzYFQsv3eFtZNV9RG2L9nfG28Npwo01rKkGncp9HHnlkdKyiPGpFr028aiAUL5n10EMPZYxTm3ZatWpV1z8xdSqHqYhaLtpkfNxxx5ViSBVhu+22S3sIqYuXRYUQMdUKcSXTe474pmI9z5NoI1/2huJ99tkHCBtO47RyXiqOMJuZmZmZJSibCHObNm0A6NOnDxDywJZeeum836PWq8r/Lce2z3G5ir2rFNaxxx5b5593/PHHA3DWWWdF97Vo0QII+YO9evWq11it8qi9aq7nwVVXXQWkn7dejuIlsaqFIpuKLKu5D4SIqFYkNttsMyCzbbgaSCj6ft555wEhBzFXxEiNXx544IGMr4oQQSgnJbqGVaJqanaUi/Jq43siVFZQjW3qQueXGto0ZorC6hxq27YtkLkSofKrlaYu/796vwKh9bVWqpSXfPvttxdwdA3jCLOZmZmZWYJUIsyKGscjD4osx9v65qOWx9qhW+y20oWi3Mj4rmrNxRVXXAGEts7ffvttdIwiQD179gRg3XXXBWD55ZcHMnejKlqmiKLlpih/69atgbo1Pyk3ivqpkUIuTz/9dKmGU3GqMafy7LPPzvi3cpoh7HNQtYLVVlst78/RMWpGolW9urjlllty3q50qiqi3f2rrrpqjWO0cqhjS7WbvyG22GILAM4880wgsyGYcs1rk5OqZjc77rgjAJdccgmQu626ItZzanpSbbTKo/bPJ5xwQprDKbl4FF37ItTMbZtttkllTEkcYTYzMzMzS1CSCPNSSy0FhJqU2nmtvJ0kqq06dOjQ6D7l/5R7znJtKPKjT1qqZqF8QAitv7Mpaqhd6lAzsmS5KcqfFJUtd6qO0rlzZyA8H+I1cocPHw7A9OnTSzu4CrLKKqukPYSCmzZtGhDqKMd332uFSlRjOV5lSO2sP/74Y6B+keXG4s033wRyn0eV+Bql1+dc7c9POeUUAH766ac5/hxFpjfYYAOgZr1qCD0Trr76aiDztawx0dzEr93VTPWmDznkkOg+zcG1114LlL4CRm1U7rsFMzMzM7MS8BtmMzMzM7MEBU/JUKK/2htCWDquzdKn0gzUwlmb2OpTxqbcPPPMMwC88MIL0X3xhiwQNgEqjSVOGwHVaKA+pegsk9pvjho1Kt2B1MMiiywC1Cy9+MUXX0S34w0rLLcnnngCCOk5lbiMnk3tvlW2UsviEDbVaIOx2lQ3luXgQtMS8s4775zySIov3rCmrnTeTZgwIbpPr2GNbbNfNpVSU1t5qK6mWtnUwCjeCnzMmDEA9O/fP5Ux1YYjzGZmZmZmCRocYd50002BUKpok002AUKZlCTxYvoqq3bBBRcAoX1rNVESu5qyQGgi0K9fv7zfp0Lg2hjx/vvvF2uIjUa8eYw1bmq5qrbPWgmLlwmbMWNG6QfWANqUNXr06IyvVnhvvfUWAFOnTgVgjTXWSHM4DXbggQcCoVzeAQccUKfvV+k8vb5rBUeR+Fwtjhur7t27A/D7778D4RyqdiqFOmDAgOg+FXMoZ44wm5mZmZklaJKr1Ev0YJMm+R/8f4MGDQJChDkXfQK/9957Afjrr7+AkKcMMHPmzDmPtkRmz549x/BjbeamGtVmbqB850fRE+VvXnfddUCI9DdUKc8d5S7fdtttQGg48NFHH0XHJDWlKLVyf17p3Bg5ciQAjz32WPSYom26lhVauc9Nmjw3+RXzeqxShHpeAAwcOBCAli1bAqH8oHJSIUQKVdowTeV+7mg/klYldtlll+ixTz75pKi/u9znJk355sYRZjMzMzOzBA2OMFcjf/LKr9IjzMXmcye/cp8b7VS//fbbgdAQBmDcuHEAHHTQQUDh91iU+9ykyXOTn6/HyXzu5Oe5yc8RZjMzMzOzevAbZjMzMzOzBE7JyMFLFfl5CTCZz538KmVulJpx/vnnR/epYcM666wDFH7zX6XMTRo8N/n5epzM505+npv8nJJhZmZmZlYPjjDn4E9e+TmikcznTn6em/w8N/l5bvLz9TiZz538PDf5OcJsZmZmZlYPiRFmMzMzM7PGzhFmMzMzM7MEfsNsZmZmZpbAb5jNzMzMzBL4DbOZmZmZWQK/YTYzMzMzSzB30oOuwZef5yaZ5yc/z01+npv8PDf5eW6SeX7y89zk57nJ5AizmZmZmVkCv2E2MzMzM0uQmJJhZtWtdevW0e0HHngAgKZNmwKw0korpTImMzOzcuMIs5mZmZlZAr9hNjMzMzNL4JQMs0Zo2LBhAPTo0SO6b9FFFwXg3nvvTWVMZmaN2SqrrBLdvvDCCwHYfffdAVhnnXUAePvtt0s/MAMcYTYzMzMzS5RqhHnNNdeMbu+0004AHHbYYQC88MILALzyyis1vu+yyy4D4I8//ijyCM2qw1JLLQXAuHHjANhss80AmD07lNl84403AOjdu3eJR2dm1nh16NABCBuvAWbMmAHA8OHDAZg+fXrpB2YZHGE2MzMzM0vQJB5hqvFgkbq8HH744QBcdNFF0X3Nmzev9fdvs802ADz66KOFHdj/cwec/ArRWUr/18qfnTVrFgAbbrghAAsttFB07H777QfAlClTAPjiiy/m+LunTZsGwN133w3Aiy++WJshF0S5nTsqG6fn2o477qgxAHDaaadFx2qeqvl5pb8b4JZbbgHCnGjF6/PPPy/mEHIqh7kpV+UwNz179oxud+nSBYD11lsPgDZt2mQc++yzz0a3d955ZwB++OGHooyr2jr9LbjggkC43i+77LLRY5tvvjkAH3/8ca1/XjmcO0m6du0KwB133AHANddcEz125plnAvDrr78W5XeX+9ykyZ3+zMzMzMzqIZUIs3bjT506NbpvySWXrPX3z5w5EwgRygcffLBwg8OfvJIUIqIxZMgQAE466aQCjSq3f/75B4C33noruk9RRX2tS7SiNsrt3FGu8pNPPpk9BgD233//6D7NSbGUw9w0a9Ysuv3OO+8AsNxyywFh/8TIkSOLOYScymFuylUac7P44osD4VxQpBjC68/TTz+d8T0dO3YEQpQUQkWD+H6dQqqkCLOixUsssUSNx77//nsAtt56awBuvPFGIDxHATbZZBMAfvrpp1r/znJ9Xq222moAvPbaawA88cQTQFjtgvD6VSzlOjflwBFmMzMzM7N6SKVKxnfffQdA//79o/suvvhiIESAPv30UwBWXHHFGt+/yCKLALD99tsDhY8wVyO1OV5ggQWi+/bZZx8AjjzyyIxj77vvvuj2QQcdVPCx7LHHHomPf/vtt9Ht//73v3P8eYpCKJdQ58f6668PQLt27aJjzz///IyfW+gIczmIt7u++eabgczcXQj/B8rzbizi+YDvvfceECLMuSJflunEE08EYN55543uW2ONNYCw30AUXV1rrbVKNLrCUbWClVdeGQirYgBDhw4FwuuYtG3bFoDnn38+uk/PxbPPPhuA8847rzgDTln8GnvMMccA4TVHNBe5XtMHDRoEhEi8rlfxPSvxc65SzT///EBYuXj99dcB6N69O1D8qHIlUAaCMgjOOOMMIDOfXfr16weEmtXF5gizmZmZmVkCv2E2MzMzM0uQyqa/XF599VUA1l13XSA0UYgv9WRbddVVAfjwww8LOpZqSIbv3LkzEJbelX7RokWL6Jh8//fvvvtudFvLrbHvafAmE/2/aYku/vsgc9n8q6++qs2vy6CydFruyrUEeN111wGhxGGhlMO5M2DAgOj26aefDsDEiRMBOOKII4DalecrtHKYm7hu3boB8J///AeAMWPGANCrV69SDSFSbnOz1VZbAeH6q3+rTW92ik8uWl5+//33o/vqs/mtlHOz7bbbAiEl4/bbbwfC9bM24mkXWjL+5JNPAGjVqlUhhhkpl01/SsMAuPTSS3Me8/vvvwPh+QahRGz2crvOr/hzUc/Puii355XSefr06QPA6quvDriUJYQN6jp/tMkz6T2qjB49GihcCqk3/ZmZmZmZ1UOqrbHjBg4cCIRi3SoKn6QaNgEUQrwM1tprrw3AxhtvnPPYeEmesWPHAqENucqKqZFIsXzwwQcZXwtNbdZzRZYV5VCEuZqozFX8uaNNjccffzyQTmS5XMU3Z0HYeHPqqacC9VvdqATLLLMMEJ7vq6yySo1jtBKlEmmK+L300ksAbLDBBnP8PXPNNVfGz6gEc8/9v5dERcVvvfXWOv8MNaGAEGHWZq+FF14YgB9//LFB4ywX55xzDgAnn3xyjcduuukmILR4VvMk/RvCtWrSpElAKOenY+JzWanmm2++6LbKeKoxSxqR5XKi/28Ir8la1dY5cNdddwGZG9S18rDXXnsBITqt94R//PFHUcbrCLOZmZmZWYKyiTDrk6QaLKhUnCKmuSgqveeeexZ5dOVlscUWA0IplYMPPjh6TKWOFAlSuR7lhP/222/RsSrdV+n0qfKKK64AknNQ27dvD4Sc+Wqw6667ArDpppsCmTlfyhcs9qpBJVP0VOfRLrvsAsCIESNSG1OhaU8DhEjOCiusUOvvV+7xN998A2RGhpR/qmYTyy+/fMb3xhsHlTu1hVdJyvq0JdYqVtxSSy0FwL777gtktkCuZFo9iJcrVb62VouzV2rUtANCyTCVdPzll1+AELmuhuvWKaecEt1u3rw5EOamsYtHjRVZ1nu/eBOXbCoJquuarjn6GWoIU2iOMJuZmZmZJSibCLOK3qtKRlJ1DMlu99tYnHXWWQD07t0bgGHDhkWP6ZPrzz//XPqBlZBaqAL07NkTgAMPPDDjmD///BPI3MGtZgrVQA1a/vWvf+U9Ri1na5Mrd+yxxwI1I4/FbmGetuxd2NW4NyIe5coXWY5HRpXH/eyzzwKZLYohs7mQzpvsyLLy5/X8rASFiGjGqza9+eabQGjeoqoI1UIrw2oiBmE1QqubRx11FBDy4i+55JLo2K5duwJhZVSNpa6++upiDrukunTpEt1+6qmnAHj55ZfTGk5Zia94S32aaWlPgFbAisURZjMzMzOzBH7DbGZmZmaWIJWUjLZt2wIwfvz46D5tBFBZn9q45557CjuwMtKsWbPotpZHtbR53HHHAWGDikryQHVskkiiYubaGADQtGnTnMdqqT2+ufHvv/8u4uhKS3/LhhtuCIQyXmoYAfD444/n/F6VmYvr27cvACuttFLG/SeeeCKQueTu8nSVQcvBKruUi54f8dQJLR3XRnYqhmhptdjLpOVGqWAAf/31V4ojKT5tnlbqDoSUDDUlUTMYNaTIVe7z3HPPBTLTCyvdFltsAWQ+95KKGAB07Ngxuq2yakrrqUbxBki6rTRClWJUo7N4yqVe86ZNmwaExkLFfl1yhNnMzMzMLEEqEWaV/oi3Ca1LZFkUJVNkrJqo4D2ECLPatCq6Wu3R5FzUXCJfVDlOm7fuu+++6L4XX3wRgAkTJgBhlUNl9yqJ2hVr058iy/GIenZ0T40C9D0qoRan0k7aKNimTRsgs4nA3nvvDYQSUlaetDoQX7ESNbpRdK82UeWWLVsCmZu8ttxyy5w/9/7776/HiCtfvFGFomQSbxxVDbRRNFcjFpUbvPPOO4EQQYxvsr3++uuB0JyimqhJydSpU6P7Pvroo4xjFDW9+OKLgfD8gjC32nQ9fPjwoo01LdoMC+G8OOGEE4Bw7VI0OU6vP6VubOMIs5mZmZlZglQizIrqxUsdDR48GKj5iTyJWrxWo9NPPz26rU9epWpdXc7GjRsHhFUKCG3A480U8tloo40yvvbv3x+Ayy67DIAhQ4ZEx3799dcNH3CBLbTQQtHt+AoNwJdffgnA6NGjo/vU4rd169ZAaGGrZifxCLRWLhTtUBmoyZMnZ/y72uSKfFWLa6+9Fsh8bvzwww9AaKKhPMDaOOKIIwAYMGBAjceUa6lVoLr83Gqy8sorR7e1OiMPPPBA3u/T/5FKq6rJkpoPQc3yfuWiLitN8ZUHtcv+7LPPCj6mtKmhmJ5nEKLGWv3U68/hhx8OZO5HUuMONQT64IMPgORzqNLEy1PqtU2vzdnX5XgTobSaITnCbGZmZmaWINXGJWplDKHVoZoxiHKbr7zyyui+hRdeuPiDS9nzzz8f3dYnLs2Bin0/9NBDpR9YypQfqYL3EHZdK0KjNrR77LEHkNk6PL4rF0JlCeVNxfOlOnXqBGRWnUibdl5D2HUuanl83nnnRfdpLhTJUdRCuZTKi4eQK6fmCmrfq2MfeeSR6Nhqyl2uxsiyKH9UX+tr5513BuDss8+u8ZgqQeh8aWyRZeUsq1pIhw4d8h6rOXrppZcA2GCDDaLHFl10USA0ltHzLt5KOrs5U9q0lyTePCn7GivaS6JzqVopL1fvXXJVStH/u6LFuXJxb7vtNiBc87XqXE0R5ngOs6qJ6Hmkv1+0ugyOMJuZmZmZlaUmSdGVJk2apB560afVc845J7pPUQ7l9CgSWKio1+zZs3N/RM4cV0HmZtNNNwXglVdeAeCPP/4AQrQBQmtntcRW22t9bynbPddmbqA8zh1R23UIFVVUzznJaaedBmTmNc9Jsc8dVUyB0EZWclWaUeUDnSui58xjjz0W3adP+Nkt55Xf3dAW2aV8XtWGInnZ1w21XY/PTbGV29xkU83vXK8Xan2sfOlCK+XcLLDAAgAsueSSQIgExmvpqr6waN9NPFqWj+YxV6v6UaNGASESq/0FajGeS9rXY+VXazUvif6uXJV5iiWN55WurVoBVl1qCK/VytdVLnM8lzebvv/1118HalchqjbK9ZrTrl07AF577TUgXHPi8/juu+8WdQz55sYRZjMzMzOzBH7DbGZmZmaWINVNf7WhJYtcm03UgrRS2h2rDN69994b3acNa2rCMmbMGAC+++676Bht9lNKRvPmzYHMtA3Lb+zYsdFtbSR4+OGHgZpNF+Lim23KRXxTrNKV1IJY1JwEQokrHati8Eo3ULk5gJtvvjnnsUrJaCyU6mVwwQUXALnbrkspU1cKSekX8XQ/bUhr27btHL9fzTq0OU+bu3KlRo0cORIIm/5efvnleo46XWpGctBBBwHQrVs3IDNVR3+bltR1rNJcGptc7Zrr0sAmV/pONVP78KRrTlocYTYzMzMzS1D2EeaBAwfmfUxtNSvlE5g+ecfL4mkTlyLLuRx77LEZ/1Z0tBLbOadNUSCVdUqKMBd7Y0FDKaqTtHFXn851zDrrrAOE9tnxRkFq26oSUWpwYY2PVvbWX399oOZ5FL8mqSRopVE75m233Ta6T40ltEFNz4n4Ko6O0WY8vf5oQ1d81ebDDz8EQtlKbdiuVNrQFi9dCdCvX7/otlZEd9ttNyBEmNMqBVZqWqHLV16vrrbaaiug+tqq56OyubrmTJkyBQgFEdLkCLOZmZmZWYKCR5gXW2wxILRzhNDSWV9rQ/m+hx12WN5j4oWsK4EatcQ/jeu+eBMXyIzaqJGEyl+pgLly6KqZzoNDDz0UCFGceMONulBJHrWfzRYvMv/ss8/W63cUUzzSld3mWqWv4jnM8VbaAL169QJC9CPeGlu5nLly7hoTNaJobJo1axbd3n///YHM6CuEa3h8X0A55RjWRZcuXYAQRYZQHu3VV1+d4/crV3nw4MEALLfccgB8/fXX0TFqE17JkeWOHTtGt7Nfp1QiTqueAEsvvTRQc99RUnm8alKblb/amGeeeYDQjn706NENG1iZ076B3r17AzBjxgwArr76aqA8zh9HmM3MzMzMEhQ8wqxPoPH2l8rp+vLLL4EQwXr//fejY9SSWMeecsopQO422BdffHHGz6sUF154IRCqe0DIEezcuXPGsS1btoxuK59OjSPi81aNFKGA0AZUO2fj81Jbag8NIZcwu/mATJ06Nbqd3cCjHMTPnV9//RUIkUE1KalNZCNXa+yJEycWbJyVTO3Dhw0blvJISkOrEGqtDrDnnntmHKMqPspPrdSocpyeJzNnzozuq82+EOX9q2lH165dgZDbvPfee0fHVmo1jLj4KkOLFi2AUBlFFZ8UDQXYaaedMo7VapYihtVOudpfffUVEFZrIERL84nPo45VpaMDDjigkMMsCzpHACZNmgSElRrt78rVNjwtjjCbmZmZmSUoeIRZUZlWrVpF97Vv3x4Iux2VixLfNaud+dk5l4oCxNs/9+/fH4BZs2YVcOSlc9FFF6U9hLIWr/uryLLovHrnnXei+7SrVlRfVasUiipDzfNL0Q9FXNWGvFypugfAPvvsA4S/L55rmO2mm24CQntVtWKv1Bq6hTB9+nQA3nzzTaB2rY2rkSI62VFlCDWps3NXq4Gq4MRz/tXeW3txVEtY1S4g7B1o06YNAM899xwARx55JFC7/OdKEl9NyM7PVURUFTEALr/8cgC+//57INSgnlN0tVoosqwa5loRj9MegFVWWQUIe2rOOOOM6Bi9v1GufXy/SbUYMmRIdFvXIe2TyDVvaXOE2czMzMwsgd8wm5mZmZklaJK0QahJkyb1rosSD6drk9pVV11V55+jFtFaIiuF2bNnz7HieEPmppLVZm6gYfOjEnIAI0aMyHmMUgqgZoMNbSTQhsokKve0++67A/DII4/UbbBZfO7kV65z88ILLwBh47E2MqlkVimkMTcq46QW6GowASFdYYcddgBCScs0FHtuBgwYEN3Wxmq15c3lnnvuAULjLG1MTkMprsfxa/AhhxwChI1Y2lCtlMo4pWlMmDChvr+6wcrhmnP00UdHt4cOHQrULF2plMB46pOathWrYUeac6MiB/EyqUr90SbJ+GOllm9uHGE2MzMzM0tQtAhznD5NZW+oikcAtYFJFDVU+a9Slucph0+l5aoUEQ2V0YGwcSJeqqkh1JhEGwvvvPNOIGzcaSifO/mV69yonJoK5mtzcr7Sg8WQxtxo41GPHj1qPNa3b1+gPDZqlet5Uw5KcT0+7rjjotvZG7G0aVorwQDDhw8HYNCgQUDNTdml5HMnvzTmRq/t2ryuEo0QIsvjx48v5K+sF0eYzczMzMzqoeBl5XJRQXfl7+Sy7777lmIoVgHiLTCVV6m8QUX9lGMJNXNN4yUIASZPnlzjsWor/WT1d/755wPQrl07oP5t1yuFyudlN4VSSTXIfM5Y46aSlADzzjsvAGeddRYAL774IhCuzwCXXnppCUdnlUClXrVfQvuMtMIL5RFZnhNHmM3MzMzMEpQkh7nSOO8pv1LkzFUynzv5eW7yK+XcDB48GAjRHlXAUEtwyGwMlDafN/n5epzM505+pZwbNfW58sorAXj66aeBUC0DQiZCOXAOs5mZmZlZPfgNs5mZmZlZAqdk5OBlnPy8BJjM505+npv8Sjk3nTp1AmDSpEkAdOvWDUi3UUASnzf5+XqczOdOfsWem0022SS6rc19N9xwAxBKeX7++ef1/fFF5ZQMMzMzM7N6cIQ5B38qzc8RjWQ+d/Lz3OTnucnPc5Ofr8fJfO7k57nJzxFmMzMzM7N6SIwwm5mZmZk1do4wm5mZmZkl8BtmMzMzM7MEfsNsZmZmZpbAb5jNzMzMzBL4DbOZmZmZWQK/YTYzMzMzS/B/h5JwsOgltrAAAAAASUVORK5CYII=", - "text/plain": [ - "
" - ] - }, - "metadata": { - "needs_background": "light" - }, - "output_type": "display_data" - } - ], + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "fig, axes = plt.subplots(2, 10, figsize=(10, 2))\n", "for i in range(2):\n", @@ -107,7 +85,7 @@ }, { "cell_type": "code", - "execution_count": 50, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -124,7 +102,7 @@ }, { "cell_type": "code", - "execution_count": 51, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -201,12 +179,12 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Define a model configuration (in which the latent will be stated). Here, we use the RHVAE model." + "### Define a model configuration (in which the latent will be stated). Here, we use the VAE model." ] }, { "cell_type": "code", - "execution_count": 52, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -227,7 +205,7 @@ }, { "cell_type": "code", - "execution_count": 53, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -239,12 +217,12 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Last but not least. Build you RHVAE model by passing the ``encoder`` and ``decoder`` arguments" + "### Last but not least. Build you VAE model by passing the ``encoder`` and ``decoder`` arguments" ] }, { "cell_type": "code", - "execution_count": 54, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -266,52 +244,9 @@ }, { "cell_type": "code", - "execution_count": 55, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "VAE(\n", - " (decoder): Decoder_Conv_AE_MNIST(\n", - " (fc): Linear(in_features=16, out_features=16384, bias=True)\n", - " (deconv_layers): Sequential(\n", - " (0): ConvTranspose2d(1024, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n", - " (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (2): ReLU()\n", - " (3): ConvTranspose2d(512, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))\n", - " (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (5): ReLU()\n", - " (6): ConvTranspose2d(256, 1, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))\n", - " (7): Sigmoid()\n", - " )\n", - " )\n", - " (encoder): Encoder_Conv_VAE_MNIST(\n", - " (conv_layers): Sequential(\n", - " (0): Conv2d(1, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", - " (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (2): ReLU()\n", - " (3): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", - " (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (5): ReLU()\n", - " (6): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", - " (7): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (8): ReLU()\n", - " (9): Conv2d(512, 1024, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", - " (10): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (11): ReLU()\n", - " )\n", - " (embedding): Linear(in_features=1024, out_features=16, bias=True)\n", - " (log_var): Linear(in_features=1024, out_features=16, bias=True)\n", - " )\n", - ")" - ] - }, - "execution_count": 55, - "metadata": {}, - "output_type": "execute_result" - } - ], + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "model" ] @@ -332,7 +267,7 @@ }, { "cell_type": "code", - "execution_count": 56, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -349,7 +284,7 @@ }, { "cell_type": "code", - "execution_count": 57, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -363,7 +298,7 @@ }, { "cell_type": "code", - "execution_count": 58, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -381,264 +316,10 @@ }, { "cell_type": "code", - "execution_count": 59, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Preprocessing train data...\n", - "Using Base Trainer\n", - "\n", - "! No eval dataset provided ! -> keeping best model on train.\n", - "\n", - "Model passed sanity check !\n", - "\n", - "Created my_model_with_custom_archi/VAE_training_2022-06-14_09-45-29. \n", - "Training config, checkpoints and final model will be saved here.\n", - "\n", - "Successfully launched training !\n", - "\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "ea8f78b4359448bb810efa15fe454214", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Training of epoch 1/10: 0%| | 0/50 [00:00" - ] - }, - "metadata": { - "needs_background": "light" - }, - "output_type": "display_data" - } - ], + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "\n", @@ -804,7 +423,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.12" + "version": "3.8.13" }, "orig_nbformat": 4 }, diff --git a/examples/notebooks/wandb_experiment_monitoring.ipynb b/examples/notebooks/wandb_experiment_monitoring.ipynb new file mode 100644 index 00000000..57f33521 --- /dev/null +++ b/examples/notebooks/wandb_experiment_monitoring.ipynb @@ -0,0 +1,191 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Tutorial - Wandb experiements monitoring\n", + "\n", + "In this notebook, we will see how to smonitor your experiments using the integrated **wandb** callbacks." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Install the library\n", + "%pip install pythae" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Train your Pythae model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import torchvision.datasets as datasets\n", + "\n", + "%load_ext autoreload\n", + "%autoreload 2" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "mnist_trainset = datasets.MNIST(root='../data', train=True, download=True, transform=None)\n", + "\n", + "train_dataset = mnist_trainset.data[:-10000].reshape(-1, 1, 28, 28) / 255.\n", + "eval_dataset = mnist_trainset.data[-10000:].reshape(-1, 1, 28, 28) / 255." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from pythae.models import BetaVAE, BetaVAEConfig\n", + "from pythae.trainers import BaseTrainerConfig\n", + "from pythae.pipelines.training import TrainingPipeline\n", + "from pythae.models.nn.benchmarks.mnist import Encoder_ResNet_VAE_MNIST, Decoder_ResNet_AE_MNIST" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "training_config = BaseTrainerConfig(\n", + " output_dir='my_model',\n", + " learning_rate=1e-4,\n", + " batch_size=100,\n", + " num_epochs=10, # Change this to train the model a bit more\n", + ")\n", + "\n", + "\n", + "model_config = BetaVAEConfig(\n", + " input_dim=(1, 28, 28),\n", + " latent_dim=16,\n", + " beta=2.\n", + "\n", + ")\n", + "\n", + "model = BetaVAE(\n", + " model_config=model_config,\n", + " encoder=Encoder_ResNet_VAE_MNIST(model_config), \n", + " decoder=Decoder_ResNet_AE_MNIST(model_config) \n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Before lauching the pipeline, you will need to build your `WandbCallback`\n", + "\n", + "To be able to access this feature you will need:\n", + "- a valid wandb acccount\n", + "- the `wandb` package installed in your virtual env. You can install it by running (`pip install wandb`)\n", + "- to be logged in by running (`$ wandb login`)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Before being allowed to monitor your experiments you may need to run the following\n", + "# !pip install wandb\n", + "# !wandb login" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Create you callback\n", + "from pythae.trainers.training_callbacks import WandbCallback\n", + "\n", + "callbacks = [] # the TrainingPipeline expects a list of callbacks\n", + "\n", + "wandb_cb = WandbCallback() # Build the callback \n", + "\n", + "# SetUp the callback \n", + "wandb_cb.setup(\n", + " training_config=training_config, # training config\n", + " model_config=model_config, # model config\n", + " project_name=\"test\", # specify your wandb project\n", + " entity_name=\"benchmark_team\", # specify your wandb entity\n", + ")\n", + "\n", + "callbacks.append(wandb_cb) # Add it to the callbacks list" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "pipeline = TrainingPipeline(\n", + " training_config=training_config,\n", + " model=model\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "pipeline(\n", + " train_data=train_dataset,\n", + " eval_data=eval_dataset,\n", + " callbacks=callbacks # pass the callbacks to the TrainingPipeline and you are done!\n", + ")\n", + "# You can log to https://wandb.ai/your_wandb_entity/your_wandb_project to monitor your training" + ] + } + ], + "metadata": { + "interpreter": { + "hash": "3efa06c4da850a09a4898b773c7e91b0da3286dbbffa369a8099a14a8fa43098" + }, + "kernelspec": { + "display_name": "Python 3.8.11 64-bit ('pythae_dev': conda)", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.13" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/requirements.txt b/requirements.txt index 3e352b94..fde1f54c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -dill>=0.3.3 +cloudpickle>=2.1.0 imageio numpy>=1.19 pydantic>=1.8.2 @@ -8,3 +8,4 @@ torch>=1.10.1 tqdm typing_extensions dataclasses>=0.6 +pickle5 \ No newline at end of file diff --git a/setup.py b/setup.py index 520d8116..3b4dd82b 100644 --- a/setup.py +++ b/setup.py @@ -1,20 +1,11 @@ -import pathlib - -import pkg_resources from setuptools import find_packages, setup -with pathlib.Path("requirements.txt").open() as requirements_txt: - install_requires = [ - str(requirement) - for requirement in pkg_resources.parse_requirements(requirements_txt) - ] - with open("README.md", "r", encoding="utf-8") as fh: long_description = fh.read() setup( name="pythae", - version="0.0.1", + version="0.0.2", author="Clement Chadebec (HekA team INRIA)", author_email="clement.chadebec@inria.fr", description="Unifying Generative Autoencoders in Python", @@ -34,7 +25,7 @@ package_dir={"": "src"}, packages=find_packages(where="src"), install_requires=[ - "dill>=0.3.3", + "cloudpickle>=2.1.0", "imageio", "numpy>=1.19", "pydantic>=1.8.2", @@ -43,7 +34,8 @@ "torch>=1.10.1", "tqdm", "typing_extensions", - "dataclasses>=0.6" + "dataclasses>=0.6", + "pickle5" ], python_requires=">=3.6", ) diff --git a/src/pythae/models/adversarial_ae/adversarial_ae_model.py b/src/pythae/models/adversarial_ae/adversarial_ae_model.py index f6961565..1cdfbc3e 100644 --- a/src/pythae/models/adversarial_ae/adversarial_ae_model.py +++ b/src/pythae/models/adversarial_ae/adversarial_ae_model.py @@ -1,19 +1,27 @@ +import inspect +import logging import os +import warnings from copy import deepcopy from typing import Optional -import dill +import cloudpickle import torch import torch.nn.functional as F from ...customexception import BadInheritanceError from ...data.datasets import BaseDataset -from ..base.base_utils import CPU_Unpickler, ModelOutput +from ..base.base_utils import CPU_Unpickler, ModelOutput, hf_hub_is_available from ..nn import BaseDecoder, BaseDiscriminator, BaseEncoder from ..nn.default_architectures import Discriminator_MLP from ..vae import VAE from .adversarial_ae_config import Adversarial_AE_Config +logger = logging.getLogger(__name__) +console = logging.StreamHandler() +logger.addHandler(console) +logger.setLevel(logging.INFO) + class Adversarial_AE(VAE): """Adversarial Autoencoder model. @@ -208,29 +216,18 @@ def save(self, dir_path: str): if not self.model_config.uses_default_discriminator: with open(os.path.join(model_path, "discriminator.pkl"), "wb") as fp: - dill.dump(self.discriminator, fp) + cloudpickle.register_pickle_by_value( + inspect.getmodule(self.discriminator) + ) + cloudpickle.dump(self.discriminator, fp) torch.save(model_dict, os.path.join(model_path, "model.pt")) - @classmethod - def _load_model_config_from_folder(cls, dir_path): - file_list = os.listdir(dir_path) - - if "model_config.json" not in file_list: - raise FileNotFoundError( - f"Missing model config file ('model_config.json') in" - f"{dir_path}... Cannot perform model building." - ) - - path_to_model_config = os.path.join(dir_path, "model_config.json") - model_config = Adversarial_AE_Config.from_json_file(path_to_model_config) - - return model_config - @classmethod def _load_custom_discriminator_from_folder(cls, dir_path): file_list = os.listdir(dir_path) + cls._check_python_version_from_folder(dir_path=dir_path) if "discriminator.pkl" not in file_list: raise FileNotFoundError( @@ -291,3 +288,102 @@ def load_from_folder(cls, dir_path): model.load_state_dict(model_weights) return model + + @classmethod + def load_from_hf_hub( + cls, hf_hub_path: str, allow_pickle: bool = False + ): # pragma: no cover + """Class method to be used to load a pretrained model from the Hugging Face hub + + Args: + hf_hub_path (str): The path where the model should have been be saved on the + hugginface hub. + + .. note:: + This function requires the folder to contain: + + - | a ``model_config.json`` and a ``model.pt`` if no custom architectures were provided + + **or** + + - | a ``model_config.json``, a ``model.pt`` and a ``encoder.pkl`` (resp. + ``decoder.pkl`` and ``discriminator``) if a custom encoder (resp. decoder and/or + discriminator) was provided + """ + + if not hf_hub_is_available(): + raise ModuleNotFoundError( + "`huggingface_hub` package must be installed to load models from the HF hub. " + "Run `python -m pip install huggingface_hub` and log in to your account with " + "`huggingface-cli login`." + ) + + else: + from huggingface_hub import hf_hub_download + + logger.info(f"Downloading {cls.__name__} files for rebuilding...") + + config_path = hf_hub_download(repo_id=hf_hub_path, filename="model_config.json") + dir_path = os.path.dirname(config_path) + + _ = hf_hub_download(repo_id=hf_hub_path, filename="model.pt") + + model_config = cls._load_model_config_from_folder(dir_path) + + if ( + cls.__name__ + "Config" != model_config.name + and cls.__name__ + "_Config" != model_config.name + ): + warnings.warn( + f"You are trying to load a " + f"`{ cls.__name__}` while a " + f"`{model_config.name}` is given." + ) + + model_weights = cls._load_model_weights_from_folder(dir_path) + + if ( + not model_config.uses_default_encoder + or not model_config.uses_default_decoder + or not model_config.uses_default_discriminator + ) and not allow_pickle: + warnings.warn( + "You are about to download pickled files from the HF hub that may have " + "been created by a third party and so could potentially harm your computer. If you " + "are sure that you want to download them set `allow_pickle=true`." + ) + + else: + + if not model_config.uses_default_encoder: + _ = hf_hub_download(repo_id=hf_hub_path, filename="encoder.pkl") + encoder = cls._load_custom_encoder_from_folder(dir_path) + + else: + encoder = None + + if not model_config.uses_default_decoder: + _ = hf_hub_download(repo_id=hf_hub_path, filename="decoder.pkl") + decoder = cls._load_custom_decoder_from_folder(dir_path) + + else: + decoder = None + + if not model_config.uses_default_discriminator: + _ = hf_hub_download(repo_id=hf_hub_path, filename="discriminator.pkl") + discriminator = cls._load_custom_discriminator_from_folder(dir_path) + + else: + discriminator = None + + logger.info(f"Successfully downloaded {cls.__name__} model!") + + model = cls( + model_config, + encoder=encoder, + decoder=decoder, + discriminator=discriminator, + ) + model.load_state_dict(model_weights) + + return model diff --git a/src/pythae/models/ae/ae_model.py b/src/pythae/models/ae/ae_model.py index 19ae6d80..6be833c6 100644 --- a/src/pythae/models/ae/ae_model.py +++ b/src/pythae/models/ae/ae_model.py @@ -88,18 +88,3 @@ def loss_function(self, recon_x, x): recon_x.reshape(x.shape[0], -1), x.reshape(x.shape[0], -1), reduction="none" ).sum(dim=-1) return MSE.mean(dim=0) - - @classmethod - def _load_model_config_from_folder(cls, dir_path): - file_list = os.listdir(dir_path) - - if "model_config.json" not in file_list: - raise FileNotFoundError( - f"Missing model config file ('model_config.json') in" - f"{dir_path}... Cannot perform model building." - ) - - path_to_model_config = os.path.join(dir_path, "model_config.json") - model_config = AEConfig.from_json_file(path_to_model_config) - - return model_config diff --git a/src/pythae/models/auto_model/__init__.py b/src/pythae/models/auto_model/__init__.py index df4d0635..1db22166 100644 --- a/src/pythae/models/auto_model/__init__.py +++ b/src/pythae/models/auto_model/__init__.py @@ -7,4 +7,5 @@ >>> model = AutoModel.load_from_folder(dir_path='path/to/my_model') """ +from .auto_config import AutoConfig from .auto_model import AutoModel diff --git a/src/pythae/models/auto_model/auto_config.py b/src/pythae/models/auto_model/auto_config.py new file mode 100644 index 00000000..00ba1592 --- /dev/null +++ b/src/pythae/models/auto_model/auto_config.py @@ -0,0 +1,169 @@ +from pydantic.dataclasses import dataclass + +from pythae.config import BaseConfig + + +@dataclass +class AutoConfig(BaseConfig): + @classmethod + def from_json_file(cls, json_path): + """Creates a :class:`~pythae.config.BaseAEConfig` instance from a JSON config file. It + builds automatically the correct config for any `pythae.models`. + + Args: + json_path (str): The path to the json file containing all the parameters + + Returns: + :class:`BaseAEConfig`: The created instance + """ + + config_dict = cls._dict_from_json(json_path) + config_name = config_dict.pop("name") + + if config_name == "BaseAEConfig": + from ..base import BaseAEConfig + + model_config = BaseAEConfig.from_json_file(json_path) + + elif config_name == "Adversarial_AE_Config": + from ..adversarial_ae import Adversarial_AE_Config + + model_config = Adversarial_AE_Config.from_json_file(json_path) + + elif config_name == "AEConfig": + from ..ae import AEConfig + + model_config = AEConfig.from_json_file(json_path) + + elif config_name == "BetaTCVAEConfig": + from ..beta_tc_vae import BetaTCVAEConfig + + model_config = BetaTCVAEConfig.from_json_file(json_path) + + elif config_name == "BetaVAEConfig": + from ..beta_vae import BetaVAEConfig + + model_config = BetaVAEConfig.from_json_file(json_path) + + elif config_name == "DisentangledBetaVAEConfig": + from ..disentangled_beta_vae import DisentangledBetaVAEConfig + + model_config = DisentangledBetaVAEConfig.from_json_file(json_path) + + elif config_name == "FactorVAEConfig": + from ..factor_vae import FactorVAEConfig + + model_config = FactorVAEConfig.from_json_file(json_path) + + elif config_name == "HVAEConfig": + from ..hvae import HVAEConfig + + model_config = HVAEConfig.from_json_file(json_path) + + elif config_name == "INFOVAE_MMD_Config": + from ..info_vae import INFOVAE_MMD_Config + + model_config = INFOVAE_MMD_Config.from_json_file(json_path) + + elif config_name == "IWAEConfig": + from ..iwae import IWAEConfig + + model_config = IWAEConfig.from_json_file(json_path) + + elif config_name == "MSSSIM_VAEConfig": + from ..msssim_vae import MSSSIM_VAEConfig + + model_config = MSSSIM_VAEConfig.from_json_file(json_path) + + elif config_name == "RAE_GP_Config": + from ..rae_gp import RAE_GP_Config + + model_config = RAE_GP_Config.from_json_file(json_path) + + elif config_name == "RAE_L2_Config": + from ..rae_l2 import RAE_L2_Config + + model_config = RAE_L2_Config.from_json_file(json_path) + + elif config_name == "RHVAEConfig": + from ..rhvae import RHVAEConfig + + model_config = RHVAEConfig.from_json_file(json_path) + + elif config_name == "SVAEConfig": + from ..svae import SVAEConfig + + model_config = SVAEConfig.from_json_file(json_path) + + elif config_name == "VAEConfig": + from ..vae import VAEConfig + + model_config = VAEConfig.from_json_file(json_path) + + elif config_name == "VAEGANConfig": + from ..vae_gan import VAEGANConfig + + model_config = VAEGANConfig.from_json_file(json_path) + + elif config_name == "VAE_IAF_Config": + from ..vae_iaf import VAE_IAF_Config + + model_config = VAE_IAF_Config.from_json_file(json_path) + + elif config_name == "VAE_LinNF_Config": + from ..vae_lin_nf import VAE_LinNF_Config + + model_config = VAE_LinNF_Config.from_json_file(json_path) + + elif config_name == "VAMPConfig": + from ..vamp import VAMPConfig + + model_config = VAMPConfig.from_json_file(json_path) + + elif config_name == "VQVAEConfig": + from ..vq_vae import VQVAEConfig + + model_config = VQVAEConfig.from_json_file(json_path) + + elif config_name == "WAE_MMD_Config": + from ..wae_mmd import WAE_MMD_Config + + model_config = WAE_MMD_Config.from_json_file(json_path) + + elif config_name == "MAFConfig": + from ..normalizing_flows import MAFConfig + + model_config = MAFConfig.from_json_file(json_path) + + elif config_name == "IAFConfig": + from ..normalizing_flows import IAFConfig + + model_config = IAFConfig.from_json_file(json_path) + + elif config_name == "PlanarFlowConfig": + from ..normalizing_flows import PlanarFlowConfig + + model_config = PlanarFlowConfig.from_json_file(json_path) + + elif config_name == "RadialFlowConfig": + from ..normalizing_flows import RadialFlowConfig + + model_config = RadialFlowConfig.from_json_file(json_path) + + elif config_name == "MADEConfig": + from ..normalizing_flows import MADEConfig + + model_config = MADEConfig.from_json_file(json_path) + + elif config_name == "PixelCNNConfig": + from ..normalizing_flows import PixelCNNConfig + + model_config = PixelCNNConfig.from_json_file(json_path) + + else: + raise NameError( + "Cannot reload automatically the model configuration... " + f"The model name in the `model_config.json may be corrupted. Got `{config_name}`" + ) + + return model_config diff --git a/src/pythae/models/auto_model/auto_model.py b/src/pythae/models/auto_model/auto_model.py index 92f78c84..93024af1 100644 --- a/src/pythae/models/auto_model/auto_model.py +++ b/src/pythae/models/auto_model/auto_model.py @@ -1,30 +1,15 @@ import json +import logging import os import torch.nn as nn -from ..adversarial_ae import Adversarial_AE -from ..ae import AE -from ..beta_tc_vae import BetaTCVAE -from ..beta_vae import BetaVAE -from ..disentangled_beta_vae import DisentangledBetaVAE -from ..factor_vae import FactorVAE -from ..hvae import HVAE -from ..info_vae import INFOVAE_MMD -from ..iwae import IWAE -from ..msssim_vae import MSSSIM_VAE -from ..normalizing_flows import IAF, MAF -from ..rae_gp import RAE_GP -from ..rae_l2 import RAE_L2 -from ..rhvae import RHVAE -from ..svae import SVAE -from ..vae import VAE -from ..vae_gan import VAEGAN -from ..vae_iaf import VAE_IAF -from ..vae_lin_nf import VAE_LinNF -from ..vamp import VAMP -from ..vq_vae import VQVAE -from ..wae_mmd import WAE_MMD +from ..base.base_utils import hf_hub_is_available + +logger = logging.getLogger(__name__) +console = logging.StreamHandler() +logger.addHandler(console) +logger.setLevel(logging.INFO) class AutoModel(nn.Module): @@ -32,7 +17,7 @@ def __init__(self) -> None: super().__init__() @classmethod - def load_from_folder(cls, dir_path): + def load_from_folder(cls, dir_path: str): """Class method to be used to load the model from a specific folder Args: @@ -53,74 +38,377 @@ def load_from_folder(cls, dir_path): model_name = json.load(f)["name"] if model_name == "Adversarial_AE_Config": + from ..adversarial_ae import Adversarial_AE + model = Adversarial_AE.load_from_folder(dir_path=dir_path) elif model_name == "AEConfig": + from ..ae import AE + model = AE.load_from_folder(dir_path=dir_path) elif model_name == "BetaTCVAEConfig": + from ..beta_tc_vae import BetaTCVAE + model = BetaTCVAE.load_from_folder(dir_path=dir_path) elif model_name == "BetaVAEConfig": + from ..beta_vae import BetaVAE + model = BetaVAE.load_from_folder(dir_path=dir_path) elif model_name == "DisentangledBetaVAEConfig": + from ..disentangled_beta_vae import DisentangledBetaVAE + model = DisentangledBetaVAE.load_from_folder(dir_path=dir_path) elif model_name == "FactorVAEConfig": + from ..factor_vae import FactorVAE + model = FactorVAE.load_from_folder(dir_path=dir_path) elif model_name == "HVAEConfig": + from ..hvae import HVAE + model = HVAE.load_from_folder(dir_path=dir_path) elif model_name == "INFOVAE_MMD_Config": + from ..info_vae import INFOVAE_MMD + model = INFOVAE_MMD.load_from_folder(dir_path=dir_path) elif model_name == "IWAEConfig": + from ..iwae import IWAE + model = IWAE.load_from_folder(dir_path=dir_path) elif model_name == "MSSSIM_VAEConfig": + from ..msssim_vae import MSSSIM_VAE + model = MSSSIM_VAE.load_from_folder(dir_path=dir_path) elif model_name == "RAE_GP_Config": + from ..rae_gp import RAE_GP + model = RAE_GP.load_from_folder(dir_path=dir_path) elif model_name == "RAE_L2_Config": + from ..rae_l2 import RAE_L2 + model = RAE_L2.load_from_folder(dir_path=dir_path) elif model_name == "RHVAEConfig": + from ..rhvae import RHVAE + model = RHVAE.load_from_folder(dir_path=dir_path) elif model_name == "SVAEConfig": + from ..svae import SVAE + model = SVAE.load_from_folder(dir_path=dir_path) elif model_name == "VAEConfig": + from ..vae import VAE + model = VAE.load_from_folder(dir_path=dir_path) elif model_name == "VAEGANConfig": + from ..vae_gan import VAEGAN + model = VAEGAN.load_from_folder(dir_path=dir_path) elif model_name == "VAE_IAF_Config": + from ..vae_iaf import VAE_IAF + model = VAE_IAF.load_from_folder(dir_path=dir_path) elif model_name == "VAE_LinNF_Config": + from ..vae_lin_nf import VAE_LinNF + model = VAE_LinNF.load_from_folder(dir_path=dir_path) elif model_name == "VAMPConfig": + from ..vamp import VAMP + model = VAMP.load_from_folder(dir_path=dir_path) elif model_name == "VQVAEConfig": + from ..vq_vae import VQVAE + model = VQVAE.load_from_folder(dir_path=dir_path) elif model_name == "WAE_MMD_Config": + from ..wae_mmd import WAE_MMD + model = WAE_MMD.load_from_folder(dir_path=dir_path) elif model_name == "MAFConfig": + from ..normalizing_flows import MAF + model = MAF.load_from_folder(dir_path=dir_path) elif model_name == "IAFConfig": + from ..normalizing_flows import IAF + model = IAF.load_from_folder(dir_path=dir_path) + elif model_name == "PlanarFlowConfig": + from ..normalizing_flows import PlanarFlow + + model = PlanarFlow.load_from_folder(dir_path=dir_path) + + elif model_name == "RadialFlowConfig": + from ..normalizing_flows import RadialFlow + + model = RadialFlow.load_from_folder(dir_path=dir_path) + + elif model_name == "MADEConfig": + from ..normalizing_flows import MADE + + model = MADE.load_from_folder(dir_path=dir_path) + + elif model_name == "PixelCNNConfig": + from ..normalizing_flows import PixelCNN + + model = PixelCNN.load_from_folder(dir_path=dir_path) + + else: + raise NameError( + "Cannot reload automatically the model... " + f"The model name in the `model_config.json may be corrupted. Got {model_name}" + ) + + return model + + @classmethod + def load_from_hf_hub( + cls, hf_hub_path: str, allow_pickle: bool = False + ): # pragma: no cover + """Class method to be used to load a automaticaly a pretrained model from the Hugging Face + hub + + Args: + hf_hub_path (str): The path where the model should have been be saved on the + hugginface hub. + + .. note:: + This function requires the folder to contain: + + - | a ``model_config.json`` and a ``model.pt`` if no custom architectures were provided + + **or** + + - | a ``model_config.json``, a ``model.pt`` and a ``encoder.pkl`` (resp. + ``decoder.pkl``) if a custom encoder (resp. decoder) was provided + """ + + if not hf_hub_is_available(): + raise ModuleNotFoundError( + "`huggingface_hub` package must be installed to load models from the HF hub. " + "Run `python -m pip install huggingface_hub` and log in to your account with " + "`huggingface-cli login`." + ) + + else: + from huggingface_hub import hf_hub_download + + logger.info(f"Downloading config file ...") + + config_path = hf_hub_download(repo_id=hf_hub_path, filename="model_config.json") + dir_path = os.path.dirname(config_path) + + with open(os.path.join(dir_path, "model_config.json")) as f: + model_name = json.load(f)["name"] + + if model_name == "Adversarial_AE_Config": + from ..adversarial_ae import Adversarial_AE + + model = Adversarial_AE.load_from_hf_hub( + hf_hub_path=hf_hub_path, allow_pickle=allow_pickle + ) + + elif model_name == "AEConfig": + from ..ae import AE + + model = AE.load_from_hf_hub( + hf_hub_path=hf_hub_path, allow_pickle=allow_pickle + ) + + elif model_name == "BetaTCVAEConfig": + from ..beta_tc_vae import BetaTCVAE + + model = BetaTCVAE.load_from_hf_hub( + hf_hub_path=hf_hub_path, allow_pickle=allow_pickle + ) + + elif model_name == "BetaVAEConfig": + from ..beta_vae import BetaVAE + + model = BetaVAE.load_from_hf_hub( + hf_hub_path=hf_hub_path, allow_pickle=allow_pickle + ) + + elif model_name == "DisentangledBetaVAEConfig": + from ..disentangled_beta_vae import DisentangledBetaVAE + + model = DisentangledBetaVAE.load_from_hf_hub( + hf_hub_path=hf_hub_path, allow_pickle=allow_pickle + ) + + elif model_name == "FactorVAEConfig": + from ..factor_vae import FactorVAE + + model = FactorVAE.load_from_hf_hub( + hf_hub_path=hf_hub_path, allow_pickle=allow_pickle + ) + + elif model_name == "HVAEConfig": + from ..hvae import HVAE + + model = HVAE.load_from_hf_hub( + hf_hub_path=hf_hub_path, allow_pickle=allow_pickle + ) + + elif model_name == "INFOVAE_MMD_Config": + from ..info_vae import INFOVAE_MMD + + model = INFOVAE_MMD.load_from_hf_hub( + hf_hub_path=hf_hub_path, allow_pickle=allow_pickle + ) + + elif model_name == "IWAEConfig": + from ..iwae import IWAE + + model = IWAE.load_from_hf_hub( + hf_hub_path=hf_hub_path, allow_pickle=allow_pickle + ) + + elif model_name == "MSSSIM_VAEConfig": + from ..msssim_vae import MSSSIM_VAE + + model = MSSSIM_VAE.load_from_hf_hub( + hf_hub_path=hf_hub_path, allow_pickle=allow_pickle + ) + + elif model_name == "RAE_GP_Config": + from ..rae_gp import RAE_GP + + model = RAE_GP.load_from_hf_hub( + hf_hub_path=hf_hub_path, allow_pickle=allow_pickle + ) + + elif model_name == "RAE_L2_Config": + from ..rae_l2 import RAE_L2 + + model = RAE_L2.load_from_hf_hub( + hf_hub_path=hf_hub_path, allow_pickle=allow_pickle + ) + + elif model_name == "RHVAEConfig": + from ..rhvae import RHVAE + + model = RHVAE.load_from_hf_hub( + hf_hub_path=hf_hub_path, allow_pickle=allow_pickle + ) + + elif model_name == "SVAEConfig": + from ..svae import SVAE + + model = SVAE.load_from_hf_hub( + hf_hub_path=hf_hub_path, allow_pickle=allow_pickle + ) + + elif model_name == "VAEConfig": + from ..vae import VAE + + model = VAE.load_from_hf_hub( + hf_hub_path=hf_hub_path, allow_pickle=allow_pickle + ) + + elif model_name == "VAEGANConfig": + from ..vae_gan import VAEGAN + + model = VAEGAN.load_from_hf_hub( + hf_hub_path=hf_hub_path, allow_pickle=allow_pickle + ) + + elif model_name == "VAE_IAF_Config": + from ..vae_iaf import VAE_IAF + + model = VAE_IAF.load_from_hf_hub( + hf_hub_path=hf_hub_path, allow_pickle=allow_pickle + ) + + elif model_name == "VAE_LinNF_Config": + from ..vae_lin_nf import VAE_LinNF + + model = VAE_LinNF.load_from_hf_hub( + hf_hub_path=hf_hub_path, allow_pickle=allow_pickle + ) + + elif model_name == "VAMPConfig": + from ..vamp import VAMP + + model = VAMP.load_from_hf_hub( + hf_hub_path=hf_hub_path, allow_pickle=allow_pickle + ) + + elif model_name == "VQVAEConfig": + from ..vq_vae import VQVAE + + model = VQVAE.load_from_hf_hub( + hf_hub_path=hf_hub_path, allow_pickle=allow_pickle + ) + + elif model_name == "WAE_MMD_Config": + from ..wae_mmd import WAE_MMD + + model = WAE_MMD.load_from_hf_hub( + hf_hub_path=hf_hub_path, allow_pickle=allow_pickle + ) + + elif model_name == "MAFConfig": + from ..normalizing_flows import MAF + + model = MAF.load_from_hf_hub( + hf_hub_path=hf_hub_path, allow_pickle=allow_pickle + ) + + elif model_name == "IAFConfig": + from ..normalizing_flows import IAF + + model = IAF.load_from_hf_hub( + hf_hub_path=hf_hub_path, allow_pickle=allow_pickle + ) + + elif model_name == "PlanarFlowConfig": + from ..normalizing_flows import PlanarFlow + + model = PlanarFlow.load_from_hf_hub( + hf_hub_path=hf_hub_path, allow_pickle=allow_pickle + ) + + elif model_name == "RadialFlowConfig": + from ..normalizing_flows import RadialFlow + + model = RadialFlow.load_from_hf_hub( + hf_hub_path=hf_hub_path, allow_pickle=allow_pickle + ) + + elif model_name == "MADEConfig": + from ..normalizing_flows import MADE + + model = MADE.load_from_hf_hub( + hf_hub_path=hf_hub_path, allow_pickle=allow_pickle + ) + + elif model_name == "PixelCNNConfig": + from ..normalizing_flows import PixelCNN + + model = PixelCNN.load_from_hf_hub( + hf_hub_path=hf_hub_path, allow_pickle=allow_pickle + ) + else: raise NameError( "Cannot reload automatically the model... " diff --git a/src/pythae/models/base/base_config.py b/src/pythae/models/base/base_config.py index 9e1373d4..f5d4edbd 100644 --- a/src/pythae/models/base/base_config.py +++ b/src/pythae/models/base/base_config.py @@ -19,3 +19,8 @@ class BaseAEConfig(BaseConfig): latent_dim: int = 10 uses_default_encoder: bool = True uses_default_decoder: bool = True + + +@dataclass +class EnvironmentConfig(BaseConfig): + python_version: str = "3.8" diff --git a/src/pythae/models/base/base_model.py b/src/pythae/models/base/base_model.py index 92d7b6a2..6fbe1346 100644 --- a/src/pythae/models/base/base_model.py +++ b/src/pythae/models/base/base_model.py @@ -1,17 +1,30 @@ +import inspect +import logging import os +import shutil +import sys +import tempfile +import warnings from copy import deepcopy +from http.cookiejar import LoadError from typing import Optional -import dill +import cloudpickle import torch import torch.nn as nn from ...customexception import BadInheritanceError from ...data.datasets import BaseDataset +from ..auto_model import AutoConfig from ..nn import BaseDecoder, BaseEncoder from ..nn.default_architectures import Decoder_AE_MLP -from .base_config import BaseAEConfig -from .base_utils import CPU_Unpickler, ModelOutput +from .base_config import BaseAEConfig, EnvironmentConfig +from .base_utils import CPU_Unpickler, ModelOutput, hf_hub_is_available + +logger = logging.getLogger(__name__) +console = logging.StreamHandler() +logger.addHandler(console) +logger.setLevel(logging.INFO) class BaseAE(nn.Module): @@ -95,7 +108,7 @@ def update(self): By default, it does nothing. """ - def save(self, dir_path): + def save(self, dir_path: str): """Method to save the model at a specific location. It saves, the model weights as a ``models.pt`` file along with the model config as a ``model_config.json`` file. If the model to save used custom encoder (resp. decoder) provided by the user, these are also @@ -106,29 +119,110 @@ def save(self, dir_path): path does not exist a folder will be created at the provided location. """ - model_path = dir_path - + env_spec = EnvironmentConfig( + python_version=f"{sys.version_info[0]}.{sys.version_info[1]}" + ) model_dict = {"model_state_dict": deepcopy(self.state_dict())} - if not os.path.exists(model_path): + if not os.path.exists(dir_path): try: - os.makedirs(model_path) + os.makedirs(dir_path) except FileNotFoundError as e: raise e - self.model_config.save_json(model_path, "model_config") + env_spec.save_json(dir_path, "environment") + self.model_config.save_json(dir_path, "model_config") # only save .pkl if custom architecture provided if not self.model_config.uses_default_encoder: - with open(os.path.join(model_path, "encoder.pkl"), "wb") as fp: - dill.dump(self.encoder, fp) + with open(os.path.join(dir_path, "encoder.pkl"), "wb") as fp: + cloudpickle.register_pickle_by_value(inspect.getmodule(self.encoder)) + cloudpickle.dump(self.encoder, fp) if not self.model_config.uses_default_decoder: - with open(os.path.join(model_path, "decoder.pkl"), "wb") as fp: - dill.dump(self.decoder, fp) + with open(os.path.join(dir_path, "decoder.pkl"), "wb") as fp: + cloudpickle.register_pickle_by_value(inspect.getmodule(self.decoder)) + cloudpickle.dump(self.decoder, fp) + + torch.save(model_dict, os.path.join(dir_path, "model.pt")) + + def push_to_hf_hub(self, hf_hub_path: str): # pragma: no cover + """Method allowing to save your model directly on the huggung face hub. + You will need to have the `huggingface_hub` package installed and a valid Hugging Face + account. You can install the package using + + .. code-block:: bash + + python -m pip install huggingface_hub + + end then login using + + .. code-block:: bash + + huggingface-cli login + + Args: + hf_hub_path (str): path to your repo on the Hugging Face hub. + """ + if not hf_hub_is_available(): + raise ModuleNotFoundError( + "`huggingface_hub` package must be installed to push your model to the HF hub. " + "Run `python -m pip install huggingface_hub` and log in to your account with " + "`huggingface-cli login`." + ) + + else: + from huggingface_hub import CommitOperationAdd, HfApi + + logger.info( + f"Uploading {self.model_name} model to {hf_hub_path} repo in HF hub..." + ) + + tempdir = tempfile.mkdtemp() + + self.save(tempdir) + + model_files = os.listdir(tempdir) + + api = HfApi() + hf_operations = [] - torch.save(model_dict, os.path.join(model_path, "model.pt")) + for file in model_files: + hf_operations.append( + CommitOperationAdd( + path_in_repo=file, + path_or_fileobj=f"{str(os.path.join(tempdir, file))}", + ) + ) + + try: + api.create_commit( + commit_message=f"Uploading {self.model_name} in {hf_hub_path}", + repo_id=hf_hub_path, + operations=hf_operations, + ) + logger.info( + f"Successfully uploaded {self.model_name} to {hf_hub_path} repo in HF hub!" + ) + + except: + from huggingface_hub import create_repo + + repo_name = os.path.basename(os.path.normpath(hf_hub_path)) + logger.info( + f"Creating {repo_name} in the HF hub since it does not exist..." + ) + create_repo(repo_id=repo_name) + logger.info(f"Successfully created {repo_name} in the HF hub!") + + api.create_commit( + commit_message=f"Uploading {self.model_name} in {hf_hub_path}", + repo_id=hf_hub_path, + operations=hf_operations, + ) + + shutil.rmtree(tempdir) @classmethod def _load_model_config_from_folder(cls, dir_path): @@ -141,7 +235,7 @@ def _load_model_config_from_folder(cls, dir_path): ) path_to_model_config = os.path.join(dir_path, "model_config.json") - model_config = BaseAEConfig.from_json_file(path_to_model_config) + model_config = AutoConfig.from_json_file(path_to_model_config) return model_config @@ -179,6 +273,8 @@ def _load_model_weights_from_folder(cls, dir_path): def _load_custom_encoder_from_folder(cls, dir_path): file_list = os.listdir(dir_path) + cls._check_python_version_from_folder(dir_path=dir_path) + if "encoder.pkl" not in file_list: raise FileNotFoundError( f"Missing encoder pkl file ('encoder.pkl') in" @@ -196,6 +292,7 @@ def _load_custom_encoder_from_folder(cls, dir_path): def _load_custom_decoder_from_folder(cls, dir_path): file_list = os.listdir(dir_path) + cls._check_python_version_from_folder(dir_path=dir_path) if "decoder.pkl" not in file_list: raise FileNotFoundError( @@ -248,6 +345,90 @@ def load_from_folder(cls, dir_path): return model + @classmethod + def load_from_hf_hub(cls, hf_hub_path: str, allow_pickle=False): # pragma: no cover + """Class method to be used to load a pretrained model from the Hugging Face hub + + Args: + hf_hub_path (str): The path where the model should have been be saved on the + hugginface hub. + + .. note:: + This function requires the folder to contain: + + - | a ``model_config.json`` and a ``model.pt`` if no custom architectures were provided + + **or** + + - | a ``model_config.json``, a ``model.pt`` and a ``encoder.pkl`` (resp. + ``decoder.pkl``) if a custom encoder (resp. decoder) was provided + """ + + if not hf_hub_is_available(): + raise ModuleNotFoundError( + "`huggingface_hub` package must be installed to load models from the HF hub. " + "Run `python -m pip install huggingface_hub` and log in to your account with " + "`huggingface-cli login`." + ) + + else: + from huggingface_hub import hf_hub_download + + logger.info(f"Downloading {cls.__name__} files for rebuilding...") + + _ = hf_hub_download(repo_id=hf_hub_path, filename="environment.json") + config_path = hf_hub_download(repo_id=hf_hub_path, filename="model_config.json") + dir_path = os.path.dirname(config_path) + + _ = hf_hub_download(repo_id=hf_hub_path, filename="model.pt") + + model_config = cls._load_model_config_from_folder(dir_path) + + if ( + cls.__name__ + "Config" != model_config.name + and cls.__name__ + "_Config" != model_config.name + ): + warnings.warn( + f"You are trying to load a " + f"`{ cls.__name__}` while a " + f"`{model_config.name}` is given." + ) + + model_weights = cls._load_model_weights_from_folder(dir_path) + + if ( + not model_config.uses_default_encoder + or not model_config.uses_default_decoder + ) and not allow_pickle: + warnings.warn( + "You are about to download pickled files from the HF hub that may have " + "been created by a third party and so could potentially harm your computer. If you " + "are sure that you want to download them set `allow_pickle=true`." + ) + + else: + + if not model_config.uses_default_encoder: + _ = hf_hub_download(repo_id=hf_hub_path, filename="encoder.pkl") + encoder = cls._load_custom_encoder_from_folder(dir_path) + + else: + encoder = None + + if not model_config.uses_default_decoder: + _ = hf_hub_download(repo_id=hf_hub_path, filename="decoder.pkl") + decoder = cls._load_custom_decoder_from_folder(dir_path) + + else: + decoder = None + + logger.info(f"Successfully downloaded {cls.__name__} model!") + + model = cls(model_config, encoder=encoder, decoder=decoder) + model.load_state_dict(model_weights) + + return model + def set_encoder(self, encoder: BaseEncoder) -> None: """Set the encoder of the model""" if not issubclass(type(encoder), BaseEncoder): @@ -269,3 +450,24 @@ def set_decoder(self, decoder: BaseDecoder) -> None: ) ) self.decoder = decoder + + @classmethod + def _check_python_version_from_folder(cls, dir_path: str): + if "environment.json" in os.listdir(dir_path): + env_spec = EnvironmentConfig.from_json_file( + os.path.join(dir_path, "environment.json") + ) + python_version = env_spec.python_version + python_version_minor = python_version.split(".")[1] + + if python_version_minor == "7" and sys.version_info[1] > 7: + raise LoadError( + "Trying to reload a model saved with python3.7 with python3.8+. " + "Please create a virtual env with python 3.7 to reload this model." + ) + + elif int(python_version_minor) >= 8 and sys.version_info[1] == 7: + raise LoadError( + "Trying to reload a model saved with python3.8+ with python3.7. " + "Please create a virtual env with python 3.8+ to reload this model." + ) diff --git a/src/pythae/models/base/base_utils.py b/src/pythae/models/base/base_utils.py index 6b13d5f5..a53165b1 100644 --- a/src/pythae/models/base/base_utils.py +++ b/src/pythae/models/base/base_utils.py @@ -1,10 +1,21 @@ +import importlib import io +import logging from collections import OrderedDict from typing import Any, Tuple -import dill +import pickle5 as pickle import torch +logger = logging.getLogger(__name__) +console = logging.StreamHandler() +logger.addHandler(console) +logger.setLevel(logging.INFO) + + +def hf_hub_is_available(): + return importlib.util.find_spec("huggingface_hub") is not None + class ModelOutput(OrderedDict): """Base ModelOutput class fixing the output type from the models. This class is inspired from @@ -32,7 +43,7 @@ def to_tuple(self) -> Tuple[Any]: return tuple(self[k] for k in self.keys()) -class CPU_Unpickler(dill.Unpickler): +class CPU_Unpickler(pickle.Unpickler): def find_class(self, module, name): if module == "torch.storage" and name == "_load_from_bytes": return lambda b: torch.load(io.BytesIO(b), map_location="cpu") diff --git a/src/pythae/models/beta_tc_vae/beta_tc_vae_model.py b/src/pythae/models/beta_tc_vae/beta_tc_vae_model.py index a171416f..99576c28 100644 --- a/src/pythae/models/beta_tc_vae/beta_tc_vae_model.py +++ b/src/pythae/models/beta_tc_vae/beta_tc_vae_model.py @@ -195,18 +195,3 @@ def _log_importance_weight_matrix(self, batch_size, dataset_size): W.view(-1)[1 :: M + 1] = strat_weight W[M - 1, 0] = strat_weight return W.log() - - @classmethod - def _load_model_config_from_folder(cls, dir_path): - file_list = os.listdir(dir_path) - - if "model_config.json" not in file_list: - raise FileNotFoundError( - f"Missing model config file ('model_config.json') in" - f"{dir_path}... Cannot perform model building." - ) - - path_to_model_config = os.path.join(dir_path, "model_config.json") - model_config = BetaTCVAEConfig.from_json_file(path_to_model_config) - - return model_config diff --git a/src/pythae/models/beta_vae/beta_vae_model.py b/src/pythae/models/beta_vae/beta_vae_model.py index f4d88993..d05f3bd4 100644 --- a/src/pythae/models/beta_vae/beta_vae_model.py +++ b/src/pythae/models/beta_vae/beta_vae_model.py @@ -111,18 +111,3 @@ def _sample_gauss(self, mu, std): # Sample N(0, I) eps = torch.randn_like(std) return mu + eps * std, eps - - @classmethod - def _load_model_config_from_folder(cls, dir_path): - file_list = os.listdir(dir_path) - - if "model_config.json" not in file_list: - raise FileNotFoundError( - f"Missing model config file ('model_config.json') in" - f"{dir_path}... Cannot perform model building." - ) - - path_to_model_config = os.path.join(dir_path, "model_config.json") - model_config = BetaVAEConfig.from_json_file(path_to_model_config) - - return model_config diff --git a/src/pythae/models/disentangled_beta_vae/disentangled_beta_vae_model.py b/src/pythae/models/disentangled_beta_vae/disentangled_beta_vae_model.py index cb15a486..87bb2f00 100644 --- a/src/pythae/models/disentangled_beta_vae/disentangled_beta_vae_model.py +++ b/src/pythae/models/disentangled_beta_vae/disentangled_beta_vae_model.py @@ -121,18 +121,3 @@ def _sample_gauss(self, mu, std): # Sample N(0, I) eps = torch.randn_like(std) return mu + eps * std, eps - - @classmethod - def _load_model_config_from_folder(cls, dir_path): - file_list = os.listdir(dir_path) - - if "model_config.json" not in file_list: - raise FileNotFoundError( - f"Missing model config file ('model_config.json') in" - f"{dir_path}... Cannot perform model building." - ) - - path_to_model_config = os.path.join(dir_path, "model_config.json") - model_config = DisentangledBetaVAEConfig.from_json_file(path_to_model_config) - - return model_config diff --git a/src/pythae/models/factor_vae/factor_vae_model.py b/src/pythae/models/factor_vae/factor_vae_model.py index 51dd5b3f..159df2ca 100644 --- a/src/pythae/models/factor_vae/factor_vae_model.py +++ b/src/pythae/models/factor_vae/factor_vae_model.py @@ -184,57 +184,3 @@ def _permute_dims(self, z): permuted[:, i] = z[perms, i] return permuted - - @classmethod - def _load_model_config_from_folder(cls, dir_path): - file_list = os.listdir(dir_path) - - if "model_config.json" not in file_list: - raise FileNotFoundError( - f"Missing model config file ('model_config.json') in" - f"{dir_path}... Cannot perform model building." - ) - - path_to_model_config = os.path.join(dir_path, "model_config.json") - model_config = FactorVAEConfig.from_json_file(path_to_model_config) - - return model_config - - @classmethod - def load_from_folder(cls, dir_path): - """Class method to be used to load the model from a specific folder - - Args: - dir_path (str): The path where the model should have been be saved. - - .. note:: - This function requires the folder to contain: - - - | a ``model_config.json`` and a ``model.pt`` if no custom architectures were provided - - **or** - - - | a ``model_config.json``, a ``model.pt`` and a ``encoder.pkl`` (resp. - ``decoder.pkl``) if a custom encoder (resp. decoder) was provided - - """ - - model_config = cls._load_model_config_from_folder(dir_path) - model_weights = cls._load_model_weights_from_folder(dir_path) - - if not model_config.uses_default_encoder: - encoder = cls._load_custom_encoder_from_folder(dir_path) - - else: - encoder = None - - if not model_config.uses_default_decoder: - decoder = cls._load_custom_decoder_from_folder(dir_path) - - else: - decoder = None - - model = cls(model_config, encoder=encoder, decoder=decoder) - model.load_state_dict(model_weights) - - return model diff --git a/src/pythae/models/hvae/hvae_model.py b/src/pythae/models/hvae/hvae_model.py index 9ce95a52..db932874 100644 --- a/src/pythae/models/hvae/hvae_model.py +++ b/src/pythae/models/hvae/hvae_model.py @@ -323,18 +323,3 @@ def get_nll(self, data, n_samples=1, batch_size=100): log_p.append((torch.logsumexp(log_p_x, 0) - np.log(len(log_p_x))).item()) return np.mean(log_p) - - @classmethod - def _load_model_config_from_folder(cls, dir_path): - file_list = os.listdir(dir_path) - - if "model_config.json" not in file_list: - raise FileNotFoundError( - f"Missing model config file ('model_config.json') in" - f"{dir_path}... Cannot perform model building." - ) - - path_to_model_config = os.path.join(dir_path, "model_config.json") - model_config = HVAEConfig.from_json_file(path_to_model_config) - - return model_config diff --git a/src/pythae/models/info_vae/info_vae_model.py b/src/pythae/models/info_vae/info_vae_model.py index 874c7a61..95599b05 100644 --- a/src/pythae/models/info_vae/info_vae_model.py +++ b/src/pythae/models/info_vae/info_vae_model.py @@ -164,18 +164,3 @@ def rbf_kernel(self, z1, z2): k = torch.exp(-torch.norm(z1.unsqueeze(1) - z2.unsqueeze(0), dim=-1) ** 2 / C) return k - - @classmethod - def _load_model_config_from_folder(cls, dir_path): - file_list = os.listdir(dir_path) - - if "model_config.json" not in file_list: - raise FileNotFoundError( - f"Missing model config file ('model_config.json') in" - f"{dir_path}... Cannot perform model building." - ) - - path_to_model_config = os.path.join(dir_path, "model_config.json") - model_config = INFOVAE_MMD_Config.from_json_file(path_to_model_config) - - return model_config diff --git a/src/pythae/models/iwae/iwae_model.py b/src/pythae/models/iwae/iwae_model.py index 28271142..be4510f6 100644 --- a/src/pythae/models/iwae/iwae_model.py +++ b/src/pythae/models/iwae/iwae_model.py @@ -222,18 +222,3 @@ def get_nll(self, data, n_samples=1, batch_size=100): log_p.append((torch.logsumexp(log_p_x, 0) - np.log(len(log_p_x))).item()) return np.mean(log_p) - - @classmethod - def _load_model_config_from_folder(cls, dir_path): - file_list = os.listdir(dir_path) - - if "model_config.json" not in file_list: - raise FileNotFoundError( - f"Missing model config file ('model_config.json') in" - f"{dir_path}... Cannot perform model building." - ) - - path_to_model_config = os.path.join(dir_path, "model_config.json") - model_config = IWAEConfig.from_json_file(path_to_model_config) - - return model_config diff --git a/src/pythae/models/msssim_vae/msssim_vae_model.py b/src/pythae/models/msssim_vae/msssim_vae_model.py index e628802a..686fcb8a 100644 --- a/src/pythae/models/msssim_vae/msssim_vae_model.py +++ b/src/pythae/models/msssim_vae/msssim_vae_model.py @@ -98,18 +98,3 @@ def _sample_gauss(self, mu, std): # Sample N(0, I) eps = torch.randn_like(std) return mu + eps * std, eps - - @classmethod - def _load_model_config_from_folder(cls, dir_path): - file_list = os.listdir(dir_path) - - if "model_config.json" not in file_list: - raise FileNotFoundError( - f"Missing model config file ('model_config.json') in" - f"{dir_path}... Cannot perform model building." - ) - - path_to_model_config = os.path.join(dir_path, "model_config.json") - model_config = MSSSIM_VAEConfig.from_json_file(path_to_model_config) - - return model_config diff --git a/src/pythae/models/normalizing_flows/base/base_nf_model.py b/src/pythae/models/normalizing_flows/base/base_nf_model.py index 6dea2472..6f64063c 100644 --- a/src/pythae/models/normalizing_flows/base/base_nf_model.py +++ b/src/pythae/models/normalizing_flows/base/base_nf_model.py @@ -1,4 +1,5 @@ import os +import sys from copy import deepcopy import numpy as np @@ -6,6 +7,8 @@ import torch.nn as nn from ....data.datasets import BaseDataset +from ...auto_model import AutoConfig +from ...base.base_config import EnvironmentConfig from ...base.base_utils import ModelOutput from .base_nf_config import BaseNFConfig @@ -76,20 +79,37 @@ def save(self, dir_path): path does not exist a folder will be created at the provided location. """ - model_path = dir_path - + env_spec = EnvironmentConfig( + python_version=f"{sys.version_info[0]}.{sys.version_info[1]}" + ) model_dict = {"model_state_dict": deepcopy(self.state_dict())} - if not os.path.exists(model_path): + if not os.path.exists(dir_path): try: - os.makedirs(model_path) + os.makedirs(dir_path) except (FileNotFoundError, TypeError) as e: raise e - self.model_config.save_json(model_path, "model_config") + env_spec.save_json(dir_path, "environment") + self.model_config.save_json(dir_path, "model_config") + + torch.save(model_dict, os.path.join(dir_path, "model.pt")) + + @classmethod + def _load_model_config_from_folder(cls, dir_path): + file_list = os.listdir(dir_path) + + if "model_config.json" not in file_list: + raise FileNotFoundError( + f"Missing model config file ('model_config.json') in" + f"{dir_path}... Cannot perform model building." + ) + + path_to_model_config = os.path.join(dir_path, "model_config.json") + model_config = AutoConfig.from_json_file(path_to_model_config) - torch.save(model_dict, os.path.join(model_path, "model.pt")) + return model_config @classmethod def _load_model_weights_from_folder(cls, dir_path): diff --git a/src/pythae/models/normalizing_flows/iaf/iaf_model.py b/src/pythae/models/normalizing_flows/iaf/iaf_model.py index 43bd6b98..6366a652 100644 --- a/src/pythae/models/normalizing_flows/iaf/iaf_model.py +++ b/src/pythae/models/normalizing_flows/iaf/iaf_model.py @@ -105,18 +105,3 @@ def inverse(self, y: torch.Tensor, **kwargs) -> ModelOutput: sum_log_abs_det_jac += layer_out.log_abs_det_jac return ModelOutput(out=y, log_abs_det_jac=sum_log_abs_det_jac) - - @classmethod - def _load_model_config_from_folder(cls, dir_path): - file_list = os.listdir(dir_path) - - if "model_config.json" not in file_list: - raise FileNotFoundError( - f"Missing model config file ('model_config.json') in" - f"{dir_path}... Cannot perform model building." - ) - - path_to_model_config = os.path.join(dir_path, "model_config.json") - model_config = IAFConfig.from_json_file(path_to_model_config) - - return model_config diff --git a/src/pythae/models/normalizing_flows/made/made_model.py b/src/pythae/models/normalizing_flows/made/made_model.py index 3f395e9a..191e703d 100644 --- a/src/pythae/models/normalizing_flows/made/made_model.py +++ b/src/pythae/models/normalizing_flows/made/made_model.py @@ -110,18 +110,3 @@ def forward(self, x: torch.tensor, **kwargs) -> ModelOutput: log_var = net_output[:, self.input_dim :] return ModelOutput(mu=mu, log_var=log_var) - - @classmethod - def _load_model_config_from_folder(cls, dir_path): - file_list = os.listdir(dir_path) - - if "model_config.json" not in file_list: - raise FileNotFoundError( - f"Missing model config file ('model_config.json') in" - f"{dir_path}... Cannot perform model building." - ) - - path_to_model_config = os.path.join(dir_path, "model_config.json") - model_config = MADEConfig.from_json_file(path_to_model_config) - - return model_config diff --git a/src/pythae/models/normalizing_flows/maf/maf_model.py b/src/pythae/models/normalizing_flows/maf/maf_model.py index 766d353d..1b2f11f2 100644 --- a/src/pythae/models/normalizing_flows/maf/maf_model.py +++ b/src/pythae/models/normalizing_flows/maf/maf_model.py @@ -105,18 +105,3 @@ def inverse(self, y: torch.Tensor, **kwargs) -> ModelOutput: sum_log_abs_det_jac += layer_out.log_abs_det_jac return ModelOutput(out=y, log_abs_det_jac=sum_log_abs_det_jac) - - @classmethod - def _load_model_config_from_folder(cls, dir_path): - file_list = os.listdir(dir_path) - - if "model_config.json" not in file_list: - raise FileNotFoundError( - f"Missing model config file ('model_config.json') in" - f"{dir_path}... Cannot perform model building." - ) - - path_to_model_config = os.path.join(dir_path, "model_config.json") - model_config = MAFConfig.from_json_file(path_to_model_config) - - return model_config diff --git a/src/pythae/models/normalizing_flows/pixelcnn/pixelcnn_model.py b/src/pythae/models/normalizing_flows/pixelcnn/pixelcnn_model.py index ee7bb49f..849f0cc1 100644 --- a/src/pythae/models/normalizing_flows/pixelcnn/pixelcnn_model.py +++ b/src/pythae/models/normalizing_flows/pixelcnn/pixelcnn_model.py @@ -91,18 +91,3 @@ def forward(self, inputs: BaseDataset, **kwargs) -> ModelOutput: loss = F.cross_entropy(out, x.long()) return ModelOutput(out=out, loss=loss) - - @classmethod - def _load_model_config_from_folder(cls, dir_path): - file_list = os.listdir(dir_path) - - if "model_config.json" not in file_list: - raise FileNotFoundError( - f"Missing model config file ('model_config.json') in" - f"{dir_path}... Cannot perform model building." - ) - - path_to_model_config = os.path.join(dir_path, "model_config.json") - model_config = PixelCNNConfig.from_json_file(path_to_model_config) - - return model_config diff --git a/src/pythae/models/normalizing_flows/planar_flow/planar_flow_model.py b/src/pythae/models/normalizing_flows/planar_flow/planar_flow_model.py index 35e11d36..4ef15611 100644 --- a/src/pythae/models/normalizing_flows/planar_flow/planar_flow_model.py +++ b/src/pythae/models/normalizing_flows/planar_flow/planar_flow_model.py @@ -58,18 +58,3 @@ def forward(self, x: torch.Tensor, **kwargs) -> ModelOutput: output = ModelOutput(out=f, log_abs_det_jac=log_det) return output - - @classmethod - def _load_model_config_from_folder(cls, dir_path): - file_list = os.listdir(dir_path) - - if "model_config.json" not in file_list: - raise FileNotFoundError( - f"Missing model config file ('model_config.json') in" - f"{dir_path}... Cannot perform model building." - ) - - path_to_model_config = os.path.join(dir_path, "model_config.json") - model_config = PlanarFlowConfig.from_json_file(path_to_model_config) - - return model_config diff --git a/src/pythae/models/normalizing_flows/radial_flow/radial_flow_model.py b/src/pythae/models/normalizing_flows/radial_flow/radial_flow_model.py index a57fa2f2..4f2b9e5a 100644 --- a/src/pythae/models/normalizing_flows/radial_flow/radial_flow_model.py +++ b/src/pythae/models/normalizing_flows/radial_flow/radial_flow_model.py @@ -52,18 +52,3 @@ def forward(self, x: torch.Tensor, **kwargs) -> ModelOutput: output = ModelOutput(out=f, log_abs_det_jac=log_det.squeeze()) return output - - @classmethod - def _load_model_config_from_folder(cls, dir_path): - file_list = os.listdir(dir_path) - - if "model_config.json" not in file_list: - raise FileNotFoundError( - f"Missing model config file ('model_config.json') in" - f"{dir_path}... Cannot perform model building." - ) - - path_to_model_config = os.path.join(dir_path, "model_config.json") - model_config = RadialFlowConfig.from_json_file(path_to_model_config) - - return model_config diff --git a/src/pythae/models/rae_gp/rae_gp_model.py b/src/pythae/models/rae_gp/rae_gp_model.py index c80d9f75..9cec7d28 100644 --- a/src/pythae/models/rae_gp/rae_gp_model.py +++ b/src/pythae/models/rae_gp/rae_gp_model.py @@ -105,18 +105,3 @@ def _compute_gp(self, recon_x, x): )[0].reshape(recon_x.shape[0], -1) return grads.norm(dim=-1) ** 2 - - @classmethod - def _load_model_config_from_folder(cls, dir_path): - file_list = os.listdir(dir_path) - - if "model_config.json" not in file_list: - raise FileNotFoundError( - f"Missing model config file ('model_config.json') in" - f"{dir_path}... Cannot perform model building." - ) - - path_to_model_config = os.path.join(dir_path, "model_config.json") - model_config = RAE_GP_Config.from_json_file(path_to_model_config) - - return model_config diff --git a/src/pythae/models/rae_l2/rae_l2_model.py b/src/pythae/models/rae_l2/rae_l2_model.py index 44cf2930..8dab4518 100644 --- a/src/pythae/models/rae_l2/rae_l2_model.py +++ b/src/pythae/models/rae_l2/rae_l2_model.py @@ -86,18 +86,3 @@ def loss_function(self, recon_x, x, z): (recon_loss).mean(dim=0), (embedding_loss).mean(dim=0), ) - - @classmethod - def _load_model_config_from_folder(cls, dir_path): - file_list = os.listdir(dir_path) - - if "model_config.json" not in file_list: - raise FileNotFoundError( - f"Missing model config file ('model_config.json') in" - f"{dir_path}... Cannot perform model building." - ) - - path_to_model_config = os.path.join(dir_path, "model_config.json") - model_config = RAE_L2_Config.from_json_file(path_to_model_config) - - return model_config diff --git a/src/pythae/models/rhvae/rhvae_model.py b/src/pythae/models/rhvae/rhvae_model.py index 2bc3a667..9b3cbfad 100644 --- a/src/pythae/models/rhvae/rhvae_model.py +++ b/src/pythae/models/rhvae/rhvae_model.py @@ -1,9 +1,12 @@ +import inspect +import logging import os +import warnings from collections import deque from copy import deepcopy from typing import Optional -import dill +import cloudpickle import numpy as np import torch import torch.nn as nn @@ -12,13 +15,18 @@ from ...customexception import BadInheritanceError from ...data.datasets import BaseDataset -from ..base.base_utils import CPU_Unpickler, ModelOutput +from ..base.base_utils import CPU_Unpickler, ModelOutput, hf_hub_is_available from ..nn import BaseDecoder, BaseEncoder, BaseMetric from ..nn.default_architectures import Metric_MLP from ..vae import VAE from .rhvae_config import RHVAEConfig from .rhvae_utils import create_inverse_metric, create_metric +logger = logging.getLogger(__name__) +console = logging.StreamHandler() +logger.addHandler(console) +logger.setLevel(logging.INFO) + class RHVAE(VAE): r""" @@ -599,29 +607,16 @@ def save(self, dir_path: str): if not self.model_config.uses_default_metric: with open(os.path.join(model_path, "metric.pkl"), "wb") as fp: - dill.dump(self.metric, fp) + cloudpickle.register_pickle_by_value(inspect.getmodule(self.metric)) + cloudpickle.dump(self.metric, fp) torch.save(model_dict, os.path.join(model_path, "model.pt")) - @classmethod - def _load_model_config_from_folder(cls, dir_path): - file_list = os.listdir(dir_path) - - if "model_config.json" not in file_list: - raise FileNotFoundError( - f"Missing model config file ('model_config.json') in" - f"{dir_path}... Cannot perform model building." - ) - - path_to_model_config = os.path.join(dir_path, "model_config.json") - model_config = RHVAEConfig.from_json_file(path_to_model_config) - - return model_config - @classmethod def _load_custom_metric_from_folder(cls, dir_path): file_list = os.listdir(dir_path) + cls._check_python_version_from_folder(dir_path=dir_path) if "metric.pkl" not in file_list: raise FileNotFoundError( @@ -717,3 +712,108 @@ def load_from_folder(cls, dir_path): model.load_state_dict(model_weights) return model + + @classmethod + def load_from_hf_hub( + cls, hf_hub_path: str, allow_pickle: bool = False + ): # pragma: no cover + """Class method to be used to load a pretrained model from the Hugging Face hub + + Args: + hf_hub_path (str): The path where the model should have been be saved on the + hugginface hub. + + .. note:: + This function requires the folder to contain: + + - | a ``model_config.json`` and a ``model.pt`` if no custom architectures were provided + + **or** + + - | a ``model_config.json``, a ``model.pt`` and a ``encoder.pkl`` (resp. + ``decoder.pkl`` and ``metric.pkl``) if a custom encoder (resp. decoder and/or + metric) was provided + """ + + if not hf_hub_is_available(): + raise ModuleNotFoundError( + "`huggingface_hub` package must be installed to load models from the HF hub. " + "Run `python -m pip install huggingface_hub` and log in to your account with " + "`huggingface-cli login`." + ) + + else: + from huggingface_hub import hf_hub_download + + logger.info(f"Downloading {cls.__name__} files for rebuilding...") + + config_path = hf_hub_download(repo_id=hf_hub_path, filename="model_config.json") + dir_path = os.path.dirname(config_path) + + _ = hf_hub_download(repo_id=hf_hub_path, filename="model.pt") + + model_config = cls._load_model_config_from_folder(dir_path) + + if ( + cls.__name__ + "Config" != model_config.name + and cls.__name__ + "_Config" != model_config.name + ): + warnings.warn( + f"You are trying to load a " + f"`{ cls.__name__}` while a " + f"`{model_config.name}` is given." + ) + + model_weights = cls._load_model_weights_from_folder(dir_path) + + if ( + not model_config.uses_default_encoder + or not model_config.uses_default_decoder + or not model_config.uses_default_metric + ) and not allow_pickle: + warnings.warn( + "You are about to download pickled files from the HF hub that may have " + "been created by a third party and so could potentially harm your computer. If you " + "are sure that you want to download them set `allow_pickle=true`." + ) + + else: + + if not model_config.uses_default_encoder: + _ = hf_hub_download(repo_id=hf_hub_path, filename="encoder.pkl") + encoder = cls._load_custom_encoder_from_folder(dir_path) + + else: + encoder = None + + if not model_config.uses_default_decoder: + _ = hf_hub_download(repo_id=hf_hub_path, filename="decoder.pkl") + decoder = cls._load_custom_decoder_from_folder(dir_path) + + else: + decoder = None + + if not model_config.uses_default_metric: + _ = hf_hub_download(repo_id=hf_hub_path, filename="metric.pkl") + metric = cls._load_custom_metric_from_folder(dir_path) + + else: + metric = None + + logger.info(f"Successfully downloaded {cls.__name__} model!") + + model = cls(model_config, encoder=encoder, decoder=decoder, metric=metric) + + metric_M, metric_centroids = cls._load_metric_matrices_and_centroids( + dir_path + ) + + model.M_tens = metric_M + model.centroids_tens = metric_centroids + + model.G = create_metric(model) + model.G_inv = create_inverse_metric(model) + + model.load_state_dict(model_weights) + + return model diff --git a/src/pythae/models/svae/svae_model.py b/src/pythae/models/svae/svae_model.py index fff4f909..a045ea20 100644 --- a/src/pythae/models/svae/svae_model.py +++ b/src/pythae/models/svae/svae_model.py @@ -300,18 +300,3 @@ def get_nll(self, data, n_samples=1, batch_size=100): print(f"Current nll at {i}: {np.mean(log_p)}") return np.mean(log_p) - - @classmethod - def _load_model_config_from_folder(cls, dir_path): - file_list = os.listdir(dir_path) - - if "model_config.json" not in file_list: - raise FileNotFoundError( - f"Missing model config file ('model_config.json') in" - f"{dir_path}... Cannot perform model building." - ) - - path_to_model_config = os.path.join(dir_path, "model_config.json") - model_config = SVAEConfig.from_json_file(path_to_model_config) - - return model_config diff --git a/src/pythae/models/vae/vae_model.py b/src/pythae/models/vae/vae_model.py index f1b19818..e89874ab 100644 --- a/src/pythae/models/vae/vae_model.py +++ b/src/pythae/models/vae/vae_model.py @@ -195,18 +195,3 @@ def get_nll(self, data, n_samples=1, batch_size=100): log_p.append((torch.logsumexp(log_p_x, 0) - np.log(len(log_p_x))).item()) return np.mean(log_p) - - @classmethod - def _load_model_config_from_folder(cls, dir_path): - file_list = os.listdir(dir_path) - - if "model_config.json" not in file_list: - raise FileNotFoundError( - f"Missing model config file ('model_config.json') in" - f"{dir_path}... Cannot perform model building." - ) - - path_to_model_config = os.path.join(dir_path, "model_config.json") - model_config = VAEConfig.from_json_file(path_to_model_config) - - return model_config diff --git a/src/pythae/models/vae_gan/vae_gan_model.py b/src/pythae/models/vae_gan/vae_gan_model.py index c52fd45b..23e46cdc 100644 --- a/src/pythae/models/vae_gan/vae_gan_model.py +++ b/src/pythae/models/vae_gan/vae_gan_model.py @@ -1,19 +1,27 @@ +import inspect +import logging import os +import warnings from copy import deepcopy from typing import Optional -import dill +import cloudpickle import torch import torch.nn.functional as F from ...customexception import BadInheritanceError from ...data.datasets import BaseDataset -from ..base.base_utils import CPU_Unpickler, ModelOutput +from ..base.base_utils import CPU_Unpickler, ModelOutput, hf_hub_is_available from ..nn import BaseDecoder, BaseDiscriminator, BaseEncoder from ..nn.default_architectures import Discriminator_MLP from ..vae import VAE from .vae_gan_config import VAEGANConfig +logger = logging.getLogger(__name__) +console = logging.StreamHandler() +logger.addHandler(console) +logger.setLevel(logging.INFO) + class VAEGAN(VAE): """Variational Autoencoder using Adversarial reconstruction loss model. @@ -276,29 +284,18 @@ def save(self, dir_path: str): if not self.model_config.uses_default_discriminator: with open(os.path.join(model_path, "discriminator.pkl"), "wb") as fp: - dill.dump(self.discriminator, fp) + cloudpickle.register_pickle_by_value( + inspect.getmodule(self.discriminator) + ) + cloudpickle.dump(self.discriminator, fp) torch.save(model_dict, os.path.join(model_path, "model.pt")) - @classmethod - def _load_model_config_from_folder(cls, dir_path): - file_list = os.listdir(dir_path) - - if "model_config.json" not in file_list: - raise FileNotFoundError( - f"Missing model config file ('model_config.json') in" - f"{dir_path}... Cannot perform model building." - ) - - path_to_model_config = os.path.join(dir_path, "model_config.json") - model_config = VAEGANConfig.from_json_file(path_to_model_config) - - return model_config - @classmethod def _load_custom_discriminator_from_folder(cls, dir_path): file_list = os.listdir(dir_path) + cls._check_python_version_from_folder(dir_path=dir_path) if "discriminator.pkl" not in file_list: raise FileNotFoundError( @@ -359,3 +356,102 @@ def load_from_folder(cls, dir_path): model.load_state_dict(model_weights) return model + + @classmethod + def load_from_hf_hub( + cls, hf_hub_path: str, allow_pickle: bool = False + ): # pragma: no cover + """Class method to be used to load a pretrained model from the Hugging Face hub + + Args: + hf_hub_path (str): The path where the model should have been be saved on the + hugginface hub. + + .. note:: + This function requires the folder to contain: + + - | a ``model_config.json`` and a ``model.pt`` if no custom architectures were provided + + **or** + + - | a ``model_config.json``, a ``model.pt`` and a ``encoder.pkl`` (resp. + ``decoder.pkl`` and ``discriminator``) if a custom encoder (resp. decoder and/or + discriminator) was provided + """ + + if not hf_hub_is_available(): + raise ModuleNotFoundError( + "`huggingface_hub` package must be installed to load models from the HF hub. " + "Run `python -m pip install huggingface_hub` and log in to your account with " + "`huggingface-cli login`." + ) + + else: + from huggingface_hub import hf_hub_download + + logger.info(f"Downloading {cls.__name__} files for rebuilding...") + + config_path = hf_hub_download(repo_id=hf_hub_path, filename="model_config.json") + dir_path = os.path.dirname(config_path) + + _ = hf_hub_download(repo_id=hf_hub_path, filename="model.pt") + + model_config = cls._load_model_config_from_folder(dir_path) + + if ( + cls.__name__ + "Config" != model_config.name + and cls.__name__ + "_Config" != model_config.name + ): + warnings.warn( + f"You are trying to load a " + f"`{ cls.__name__}` while a " + f"`{model_config.name}` is given." + ) + + model_weights = cls._load_model_weights_from_folder(dir_path) + + if ( + not model_config.uses_default_encoder + or not model_config.uses_default_decoder + or not model_config.uses_default_discriminator + ) and not allow_pickle: + warnings.warn( + "You are about to download pickled files from the HF hub that may have " + "been created by a third party and so could potentially harm your computer. If you " + "are sure that you want to download them set `allow_pickle=true`." + ) + + else: + + if not model_config.uses_default_encoder: + _ = hf_hub_download(repo_id=hf_hub_path, filename="encoder.pkl") + encoder = cls._load_custom_encoder_from_folder(dir_path) + + else: + encoder = None + + if not model_config.uses_default_decoder: + _ = hf_hub_download(repo_id=hf_hub_path, filename="decoder.pkl") + decoder = cls._load_custom_decoder_from_folder(dir_path) + + else: + decoder = None + + if not model_config.uses_default_discriminator: + _ = hf_hub_download(repo_id=hf_hub_path, filename="discriminator.pkl") + discriminator = cls._load_custom_discriminator_from_folder(dir_path) + + else: + discriminator = None + + logger.info(f"Successfully downloaded {cls.__name__} model!") + + model = cls( + model_config, + encoder=encoder, + decoder=decoder, + discriminator=discriminator, + ) + model.load_state_dict(model_weights) + + return model diff --git a/src/pythae/models/vae_iaf/vae_iaf_model.py b/src/pythae/models/vae_iaf/vae_iaf_model.py index 1ba79449..37c646c8 100644 --- a/src/pythae/models/vae_iaf/vae_iaf_model.py +++ b/src/pythae/models/vae_iaf/vae_iaf_model.py @@ -221,18 +221,3 @@ def get_nll(self, data, n_samples=1, batch_size=100): log_p.append((torch.logsumexp(log_p_x, 0) - np.log(len(log_p_x))).item()) return np.mean(log_p) - - @classmethod - def _load_model_config_from_folder(cls, dir_path): - file_list = os.listdir(dir_path) - - if "model_config.json" not in file_list: - raise FileNotFoundError( - f"Missing model config file ('model_config.json') in" - f"{dir_path}... Cannot perform model building." - ) - - path_to_model_config = os.path.join(dir_path, "model_config.json") - model_config = VAE_IAF_Config.from_json_file(path_to_model_config) - - return model_config diff --git a/src/pythae/models/vae_lin_nf/vae_lin_nf_model.py b/src/pythae/models/vae_lin_nf/vae_lin_nf_model.py index 726a060c..e4cd5164 100644 --- a/src/pythae/models/vae_lin_nf/vae_lin_nf_model.py +++ b/src/pythae/models/vae_lin_nf/vae_lin_nf_model.py @@ -235,18 +235,3 @@ def get_nll(self, data, n_samples=1, batch_size=100): log_p.append((torch.logsumexp(log_p_x, 0) - np.log(len(log_p_x))).item()) return np.mean(log_p) - - @classmethod - def _load_model_config_from_folder(cls, dir_path): - file_list = os.listdir(dir_path) - - if "model_config.json" not in file_list: - raise FileNotFoundError( - f"Missing model config file ('model_config.json') in" - f"{dir_path}... Cannot perform model building." - ) - - path_to_model_config = os.path.join(dir_path, "model_config.json") - model_config = VAE_LinNF_Config.from_json_file(path_to_model_config) - - return model_config diff --git a/src/pythae/models/vamp/vamp_model.py b/src/pythae/models/vamp/vamp_model.py index 8503b60d..5757c2b9 100644 --- a/src/pythae/models/vamp/vamp_model.py +++ b/src/pythae/models/vamp/vamp_model.py @@ -1,4 +1,3 @@ -import os from typing import Optional import numpy as np @@ -249,18 +248,3 @@ def get_nll(self, data, n_samples=1, batch_size=100): if i % 1000 == 0: print(f"Current nll at {i}: {np.mean(log_p)}") return np.mean(log_p) - - @classmethod - def _load_model_config_from_folder(cls, dir_path): - file_list = os.listdir(dir_path) - - if "model_config.json" not in file_list: - raise FileNotFoundError( - f"Missing model config file ('model_config.json') in" - f"{dir_path}... Cannot perform model building." - ) - - path_to_model_config = os.path.join(dir_path, "model_config.json") - model_config = VAMPConfig.from_json_file(path_to_model_config) - - return model_config diff --git a/src/pythae/models/vq_vae/vq_vae_model.py b/src/pythae/models/vq_vae/vq_vae_model.py index b7e51108..b9fd16c8 100644 --- a/src/pythae/models/vq_vae/vq_vae_model.py +++ b/src/pythae/models/vq_vae/vq_vae_model.py @@ -139,18 +139,3 @@ def _sample_gauss(self, mu, std): # Sample N(0, I) eps = torch.randn_like(std) return mu + eps * std, eps - - @classmethod - def _load_model_config_from_folder(cls, dir_path): - file_list = os.listdir(dir_path) - - if "model_config.json" not in file_list: - raise FileNotFoundError( - f"Missing model config file ('model_config.json') in" - f"{dir_path}... Cannot perform model building." - ) - - path_to_model_config = os.path.join(dir_path, "model_config.json") - model_config = VQVAEConfig.from_json_file(path_to_model_config) - - return model_config diff --git a/src/pythae/models/wae_mmd/wae_mmd_model.py b/src/pythae/models/wae_mmd/wae_mmd_model.py index e92df5a8..6184ade9 100644 --- a/src/pythae/models/wae_mmd/wae_mmd_model.py +++ b/src/pythae/models/wae_mmd/wae_mmd_model.py @@ -126,18 +126,3 @@ def rbf_kernel(self, z1, z2): k = torch.exp(-torch.norm(z1.unsqueeze(1) - z2.unsqueeze(0), dim=-1) ** 2 / C) return k - - @classmethod - def _load_model_config_from_folder(cls, dir_path): - file_list = os.listdir(dir_path) - - if "model_config.json" not in file_list: - raise FileNotFoundError( - f"Missing model config file ('model_config.json') in" - f"{dir_path}... Cannot perform model building." - ) - - path_to_model_config = os.path.join(dir_path, "model_config.json") - model_config = WAE_MMD_Config.from_json_file(path_to_model_config) - - return model_config diff --git a/src/pythae/trainers/training_callbacks.py b/src/pythae/trainers/training_callbacks.py index b97cc2bb..b9c7d99c 100644 --- a/src/pythae/trainers/training_callbacks.py +++ b/src/pythae/trainers/training_callbacks.py @@ -252,7 +252,7 @@ def on_epoch_end(self, training_config, **kwags): self.eval_progress_bar.close() -class WandbCallback(TrainingCallback): +class WandbCallback(TrainingCallback): # pragma: no cover def __init__(self): if not wandb_is_available(): raise ModuleNotFoundError( @@ -273,7 +273,7 @@ def setup(self, training_config, **kwargs): training_config_dict = training_config.to_dict() - self._wandb.init(project=project_name, entity=entity_name) + self.run = self._wandb.init(project=project_name, entity=entity_name) if model_config is not None: model_config_dict = model_config.to_dict() @@ -350,3 +350,6 @@ def on_prediction_step(self, training_config, **kwargs): val_table = self._wandb.Table(data=data_to_log, columns=column_names) self._wandb.log({"my_val_table": val_table}) + + def on_train_end(self, training_config: BaseTrainerConfig, **kwargs): + self.run.finish() diff --git a/tests/data/custom_architectures.py b/tests/data/custom_architectures.py index 44b05b4c..0b086b1e 100644 --- a/tests/data/custom_architectures.py +++ b/tests/data/custom_architectures.py @@ -6,6 +6,14 @@ from typing import List from pythae.models.nn import * from pythae.models.base.base_utils import ModelOutput +import torch.nn as nn + +class Layer(nn.Module): + def __init__(self) -> None: + nn.Module.__init__(self) + + def forward(self, x): + return x class Encoder_AE_Conv(BaseEncoder): @@ -357,7 +365,7 @@ def __init__(self, args: dict): self.input_dim = args.input_dim self.latent_dim = args.latent_dim - self.layers = nn.Sequential(nn.Linear(np.prod(args.input_dim), 10), nn.ReLU()) + self.layers = nn.Sequential(nn.Linear(np.prod(args.input_dim), 10), nn.ReLU(), Layer()) self.mu = nn.Linear(10, self.latent_dim) def forward(self, x): @@ -382,7 +390,7 @@ def __init__(self, args: dict): self.input_dim = args.input_dim self.latent_dim = args.latent_dim - self.layers = nn.Sequential(nn.Linear(np.prod(args.input_dim), 10), nn.ReLU()) + self.layers = nn.Sequential(nn.Linear(np.prod(args.input_dim), 10), nn.ReLU(), Layer()) self.mu = nn.Linear(10, self.latent_dim) self.std = nn.Linear(10, self.latent_dim) diff --git a/tests/test_AE.py b/tests/test_AE.py index e0af599d..5d79c5c3 100644 --- a/tests/test_AE.py +++ b/tests/test_AE.py @@ -118,7 +118,7 @@ def test_default_model_saving(self, tmpdir, model_configs): model.save(dir_path=dir_path) - assert set(os.listdir(dir_path)) == set(["model_config.json", "model.pt"]) + assert set(os.listdir(dir_path)) == set(["model_config.json", "model.pt", "environment.json"]) # reload model model_rec = AutoModel.load_from_folder(dir_path) @@ -145,7 +145,7 @@ def test_custom_encoder_model_saving(self, tmpdir, model_configs, custom_encoder model.save(dir_path=dir_path) assert set(os.listdir(dir_path)) == set( - ["model_config.json", "model.pt", "encoder.pkl"] + ["model_config.json", "model.pt", "encoder.pkl", "environment.json"] ) # reload model @@ -173,7 +173,7 @@ def test_custom_decoder_model_saving(self, tmpdir, model_configs, custom_decoder model.save(dir_path=dir_path) assert set(os.listdir(dir_path)) == set( - ["model_config.json", "model.pt", "decoder.pkl"] + ["model_config.json", "model.pt", "decoder.pkl", "environment.json"] ) # reload model @@ -203,7 +203,13 @@ def test_full_custom_model_saving( model.save(dir_path=dir_path) assert set(os.listdir(dir_path)) == set( - ["model_config.json", "model.pt", "encoder.pkl", "decoder.pkl"] + [ + "model_config.json", + "model.pt", + "encoder.pkl", + "decoder.pkl", + "environment.json" + ] ) # reload model diff --git a/tests/test_Adversarial_AE.py b/tests/test_Adversarial_AE.py index ab4f212e..0acbc8fa 100644 --- a/tests/test_Adversarial_AE.py +++ b/tests/test_Adversarial_AE.py @@ -178,7 +178,7 @@ def test_default_model_saving(self, tmpdir, model_configs): model.save(dir_path=dir_path) - assert set(os.listdir(dir_path)) == set(["model_config.json", "model.pt"]) + assert set(os.listdir(dir_path)) == set(["model_config.json", "model.pt", "environment.json"]) # reload model model_rec = AutoModel.load_from_folder(dir_path) @@ -205,7 +205,7 @@ def test_custom_encoder_model_saving(self, tmpdir, model_configs, custom_encoder model.save(dir_path=dir_path) assert set(os.listdir(dir_path)) == set( - ["model_config.json", "model.pt", "encoder.pkl"] + ["model_config.json", "model.pt", "encoder.pkl", "environment.json"] ) # reload model @@ -233,7 +233,7 @@ def test_custom_decoder_model_saving(self, tmpdir, model_configs, custom_decoder model.save(dir_path=dir_path) assert set(os.listdir(dir_path)) == set( - ["model_config.json", "model.pt", "decoder.pkl"] + ["model_config.json", "model.pt", "decoder.pkl", "environment.json"] ) # reload model @@ -263,7 +263,7 @@ def test_custom_discriminator_model_saving( model.save(dir_path=dir_path) assert set(os.listdir(dir_path)) == set( - ["model_config.json", "model.pt", "discriminator.pkl"] + ["model_config.json", "model.pt", "discriminator.pkl", "environment.json"] ) # reload model @@ -309,6 +309,7 @@ def test_full_custom_model_saving( "encoder.pkl", "decoder.pkl", "discriminator.pkl", + "environment.json" ] ) diff --git a/tests/test_BetaTCVAE.py b/tests/test_BetaTCVAE.py index 0a795c24..2e3806f9 100644 --- a/tests/test_BetaTCVAE.py +++ b/tests/test_BetaTCVAE.py @@ -126,7 +126,7 @@ def test_default_model_saving(self, tmpdir, model_configs): model.save(dir_path=dir_path) - assert set(os.listdir(dir_path)) == set(["model_config.json", "model.pt"]) + assert set(os.listdir(dir_path)) == set(["model_config.json", "model.pt", "environment.json"]) # reload model model_rec = AutoModel.load_from_folder(dir_path) @@ -153,7 +153,7 @@ def test_custom_encoder_model_saving(self, tmpdir, model_configs, custom_encoder model.save(dir_path=dir_path) assert set(os.listdir(dir_path)) == set( - ["model_config.json", "model.pt", "encoder.pkl"] + ["model_config.json", "model.pt", "encoder.pkl", "environment.json"] ) # reload model @@ -181,7 +181,7 @@ def test_custom_decoder_model_saving(self, tmpdir, model_configs, custom_decoder model.save(dir_path=dir_path) assert set(os.listdir(dir_path)) == set( - ["model_config.json", "model.pt", "decoder.pkl"] + ["model_config.json", "model.pt", "decoder.pkl", "environment.json"] ) # reload model @@ -211,7 +211,13 @@ def test_full_custom_model_saving( model.save(dir_path=dir_path) assert set(os.listdir(dir_path)) == set( - ["model_config.json", "model.pt", "encoder.pkl", "decoder.pkl"] + [ + "model_config.json", + "model.pt", + "encoder.pkl", + "decoder.pkl", + "environment.json" + ] ) # reload model diff --git a/tests/test_BetaVAE.py b/tests/test_BetaVAE.py index cfc6a939..dca0be15 100644 --- a/tests/test_BetaVAE.py +++ b/tests/test_BetaVAE.py @@ -118,7 +118,7 @@ def test_default_model_saving(self, tmpdir, model_configs): model.save(dir_path=dir_path) - assert set(os.listdir(dir_path)) == set(["model_config.json", "model.pt"]) + assert set(os.listdir(dir_path)) == set(["model_config.json", "model.pt", "environment.json"]) # reload model model_rec = AutoModel.load_from_folder(dir_path) @@ -145,7 +145,7 @@ def test_custom_encoder_model_saving(self, tmpdir, model_configs, custom_encoder model.save(dir_path=dir_path) assert set(os.listdir(dir_path)) == set( - ["model_config.json", "model.pt", "encoder.pkl"] + ["model_config.json", "model.pt", "encoder.pkl", "environment.json"] ) # reload model @@ -173,7 +173,7 @@ def test_custom_decoder_model_saving(self, tmpdir, model_configs, custom_decoder model.save(dir_path=dir_path) assert set(os.listdir(dir_path)) == set( - ["model_config.json", "model.pt", "decoder.pkl"] + ["model_config.json", "model.pt", "decoder.pkl", "environment.json"] ) # reload model @@ -203,7 +203,13 @@ def test_full_custom_model_saving( model.save(dir_path=dir_path) assert set(os.listdir(dir_path)) == set( - ["model_config.json", "model.pt", "encoder.pkl", "decoder.pkl"] + [ + "model_config.json", + "model.pt", + "encoder.pkl", + "decoder.pkl", + "environment.json" + ] ) # reload model diff --git a/tests/test_DisentangledBetaVAE.py b/tests/test_DisentangledBetaVAE.py index 5192e565..606a9ac3 100644 --- a/tests/test_DisentangledBetaVAE.py +++ b/tests/test_DisentangledBetaVAE.py @@ -133,7 +133,7 @@ def test_default_model_saving(self, tmpdir, model_configs): model.save(dir_path=dir_path) - assert set(os.listdir(dir_path)) == set(["model_config.json", "model.pt"]) + assert set(os.listdir(dir_path)) == set(["model_config.json", "model.pt", "environment.json"]) # reload model model_rec = AutoModel.load_from_folder(dir_path) @@ -160,7 +160,7 @@ def test_custom_encoder_model_saving(self, tmpdir, model_configs, custom_encoder model.save(dir_path=dir_path) assert set(os.listdir(dir_path)) == set( - ["model_config.json", "model.pt", "encoder.pkl"] + ["model_config.json", "model.pt", "encoder.pkl", "environment.json"] ) # reload model @@ -188,7 +188,7 @@ def test_custom_decoder_model_saving(self, tmpdir, model_configs, custom_decoder model.save(dir_path=dir_path) assert set(os.listdir(dir_path)) == set( - ["model_config.json", "model.pt", "decoder.pkl"] + ["model_config.json", "model.pt", "decoder.pkl", "environment.json"] ) # reload model @@ -220,7 +220,13 @@ def test_full_custom_model_saving( model.save(dir_path=dir_path) assert set(os.listdir(dir_path)) == set( - ["model_config.json", "model.pt", "encoder.pkl", "decoder.pkl"] + [ + "model_config.json", + "model.pt", + "encoder.pkl", + "decoder.pkl", + "environment.json" + ] ) # reload model diff --git a/tests/test_FactorVAE.py b/tests/test_FactorVAE.py index d50e144f..94e02980 100644 --- a/tests/test_FactorVAE.py +++ b/tests/test_FactorVAE.py @@ -127,7 +127,7 @@ def test_default_model_saving(self, tmpdir, model_configs): model.save(dir_path=dir_path) - assert set(os.listdir(dir_path)) == set(["model_config.json", "model.pt"]) + assert set(os.listdir(dir_path)) == set(["model_config.json", "model.pt", "environment.json"]) # reload model model_rec = AutoModel.load_from_folder(dir_path) @@ -154,7 +154,7 @@ def test_custom_encoder_model_saving(self, tmpdir, model_configs, custom_encoder model.save(dir_path=dir_path) assert set(os.listdir(dir_path)) == set( - ["model_config.json", "model.pt", "encoder.pkl"] + ["model_config.json", "model.pt", "encoder.pkl", "environment.json"] ) # reload model @@ -182,7 +182,7 @@ def test_custom_decoder_model_saving(self, tmpdir, model_configs, custom_decoder model.save(dir_path=dir_path) assert set(os.listdir(dir_path)) == set( - ["model_config.json", "model.pt", "decoder.pkl"] + ["model_config.json", "model.pt", "decoder.pkl", "environment.json"] ) # reload model @@ -225,7 +225,8 @@ def test_full_custom_model_saving( "model_config.json", "model.pt", "encoder.pkl", - "decoder.pkl" + "decoder.pkl", + "environment.json" ] ) diff --git a/tests/test_HVAE.py b/tests/test_HVAE.py index 73544153..bb7d7282 100644 --- a/tests/test_HVAE.py +++ b/tests/test_HVAE.py @@ -124,7 +124,7 @@ def test_default_model_saving(self, tmpdir, model_configs): model.save(dir_path=dir_path) - assert set(os.listdir(dir_path)) == set(["model_config.json", "model.pt"]) + assert set(os.listdir(dir_path)) == set(["model_config.json", "model.pt", "environment.json"]) # reload model model_rec = AutoModel.load_from_folder(dir_path) @@ -151,7 +151,7 @@ def test_custom_encoder_model_saving(self, tmpdir, model_configs, custom_encoder model.save(dir_path=dir_path) assert set(os.listdir(dir_path)) == set( - ["model_config.json", "model.pt", "encoder.pkl"] + ["model_config.json", "model.pt", "encoder.pkl", "environment.json"] ) # reload model @@ -179,7 +179,7 @@ def test_custom_decoder_model_saving(self, tmpdir, model_configs, custom_decoder model.save(dir_path=dir_path) assert set(os.listdir(dir_path)) == set( - ["model_config.json", "model.pt", "decoder.pkl"] + ["model_config.json", "model.pt", "decoder.pkl", "environment.json"] ) # reload model @@ -209,7 +209,13 @@ def test_full_custom_model_saving( model.save(dir_path=dir_path) assert set(os.listdir(dir_path)) == set( - ["model_config.json", "model.pt", "encoder.pkl", "decoder.pkl"] + [ + "model_config.json", + "model.pt", + "encoder.pkl", + "decoder.pkl", + "environment.json" + ] ) # reload model diff --git a/tests/test_IAF.py b/tests/test_IAF.py index efeb1f18..0116f458 100644 --- a/tests/test_IAF.py +++ b/tests/test_IAF.py @@ -11,6 +11,7 @@ from pythae.models.normalizing_flows import IAF, IAFConfig from pythae.models.normalizing_flows import NFModel from pythae.data.datasets import BaseDataset +from pythae.models import AutoModel from pythae.trainers import BaseTrainer, BaseTrainerConfig @@ -79,10 +80,10 @@ def test_default_model_saving(self, tmpdir, model_configs): model.save(dir_path=dir_path) - assert set(os.listdir(dir_path)) == set(["model_config.json", "model.pt"]) + assert set(os.listdir(dir_path)) == set(["model_config.json", "model.pt", "environment.json"]) # reload model - model_rec = IAF.load_from_folder(dir_path) + model_rec = AutoModel.load_from_folder(dir_path) # check configs are the same assert model_rec.model_config.__dict__ == model.model_config.__dict__ @@ -112,18 +113,18 @@ def test_raises_missing_files(self, tmpdir, model_configs): # check raises model.pt is missing with pytest.raises(FileNotFoundError): - model_rec = IAF.load_from_folder(dir_path) + model_rec = AutoModel.load_from_folder(dir_path) torch.save({"wrong_key": 0.0}, os.path.join(dir_path, "model.pt")) # check raises wrong key in model.pt with pytest.raises(KeyError): - model_rec = IAF.load_from_folder(dir_path) + model_rec = AutoModel.load_from_folder(dir_path) os.remove(os.path.join(dir_path, "model_config.json")) # check raises model_config.json is missing with pytest.raises(FileNotFoundError): - model_rec = IAF.load_from_folder(dir_path) + model_rec = AutoModel.load_from_folder(dir_path) class Test_Model_forward: @@ -350,7 +351,7 @@ def test_checkpoint_saving( ) # check reload full model - model_rec = IAF.load_from_folder(os.path.join(checkpoint_dir)) + model_rec = AutoModel.load_from_folder(os.path.join(checkpoint_dir)) assert all( [ @@ -466,7 +467,7 @@ def test_final_model_saving( ) # check reload full model - model_rec = IAF.load_from_folder(os.path.join(final_dir)) + model_rec = AutoModel.load_from_folder(os.path.join(final_dir)) assert all( [ @@ -513,7 +514,7 @@ def test_iaf_training_pipeline( ) # check reload full model - model_rec = IAF.load_from_folder(os.path.join(final_dir)) + model_rec = AutoModel.load_from_folder(os.path.join(final_dir)) assert all( [ diff --git a/tests/test_IWAE.py b/tests/test_IWAE.py index e852590b..2cf9654f 100644 --- a/tests/test_IWAE.py +++ b/tests/test_IWAE.py @@ -121,7 +121,7 @@ def test_default_model_saving(self, tmpdir, model_configs): model.save(dir_path=dir_path) - assert set(os.listdir(dir_path)) == set(["model_config.json", "model.pt"]) + assert set(os.listdir(dir_path)) == set(["model_config.json", "model.pt", "environment.json"]) # reload model model_rec = AutoModel.load_from_folder(dir_path) @@ -148,7 +148,7 @@ def test_custom_encoder_model_saving(self, tmpdir, model_configs, custom_encoder model.save(dir_path=dir_path) assert set(os.listdir(dir_path)) == set( - ["model_config.json", "model.pt", "encoder.pkl"] + ["model_config.json", "model.pt", "encoder.pkl", "environment.json"] ) # reload model @@ -176,7 +176,7 @@ def test_custom_decoder_model_saving(self, tmpdir, model_configs, custom_decoder model.save(dir_path=dir_path) assert set(os.listdir(dir_path)) == set( - ["model_config.json", "model.pt", "decoder.pkl"] + ["model_config.json", "model.pt", "decoder.pkl", "environment.json"] ) # reload model @@ -206,7 +206,13 @@ def test_full_custom_model_saving( model.save(dir_path=dir_path) assert set(os.listdir(dir_path)) == set( - ["model_config.json", "model.pt", "encoder.pkl", "decoder.pkl"] + [ + "model_config.json", + "model.pt", + "encoder.pkl", + "decoder.pkl", + "environment.json" + ] ) # reload model diff --git a/tests/test_MADE.py b/tests/test_MADE.py index e5cc5250..e5634bbb 100644 --- a/tests/test_MADE.py +++ b/tests/test_MADE.py @@ -11,6 +11,7 @@ from pythae.models.base.base_utils import ModelOutput from pythae.models.normalizing_flows import MADE, MADEConfig from pythae.models.normalizing_flows import NFModel +from pythae.models import AutoModel from pythae.trainers import BaseTrainer, BaseTrainerConfig from pythae.pipelines import TrainingPipeline @@ -72,10 +73,10 @@ def test_default_model_saving(self, tmpdir, model_configs): model.save(dir_path=dir_path) - assert set(os.listdir(dir_path)) == set(["model_config.json", "model.pt"]) + assert set(os.listdir(dir_path)) == set(["model_config.json", "model.pt", "environment.json"]) # reload model - model_rec = MADE.load_from_folder(dir_path) + model_rec = AutoModel.load_from_folder(dir_path) # check configs are the same assert model_rec.model_config.__dict__ == model.model_config.__dict__ @@ -104,18 +105,18 @@ def test_raises_missing_files(self, tmpdir, model_configs): # check raises model.pt is missing with pytest.raises(FileNotFoundError): - model_rec = MADE.load_from_folder(dir_path) + model_rec = AutoModel.load_from_folder(dir_path) torch.save({"wrong_key": 0.0}, os.path.join(dir_path, "model.pt")) # check raises wrong key in model.pt with pytest.raises(KeyError): - model_rec = MADE.load_from_folder(dir_path) + model_rec = AutoModel.load_from_folder(dir_path) os.remove(os.path.join(dir_path, "model_config.json")) # check raises model_config.json is missing with pytest.raises(FileNotFoundError): - model_rec = MADE.load_from_folder(dir_path) + model_rec = AutoModel.load_from_folder(dir_path) class Test_Model_forward: @@ -145,368 +146,3 @@ def test_model_train_output(self, made, demo_data): assert out.log_var.shape[0] == demo_data["data"].shape[0] assert out.mu.shape[1:] == np.prod(made.model_config.output_dim) assert out.log_var.shape[1:] == np.prod(made.model_config.output_dim) - - -# @pytest.mark.slow -# class Test_MADE_Training: -# @pytest.fixture -# def train_dataset(self): -# return torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample")) -# -# @pytest.fixture( -# params=[BaseTrainerConfig(num_epochs=3, steps_saving=2, learning_rate=1e-5)] -# ) -# def training_configs(self, tmpdir, request): -# tmpdir.mkdir("dummy_folder") -# dir_path = os.path.join(tmpdir, "dummy_folder") -# request.param.output_dir = dir_path -# return request.param -# -# @pytest.fixture( -# params=[ -# MADEConfig(input_dim=(784,), output_dim=(784,), degrees_ordering='random'), -# ] -# ) -# def model_configs(self, request): -# return request.param -# -# @pytest.fixture -# def made(self, model_configs): -# model = MADE(model_configs) -# return model -# -# @pytest.fixture() -# def prior(self, model_configs, request): -# -# device = 'cuda' if torch.cuda.is_available() else 'cpu' -# -# return torch.distributions.MultivariateNormal( -# torch.zeros(np.prod(model_configs.input_dim)).to(device), -# torch.eye(np.prod(model_configs.input_dim)).to(device) -# ) -# -# @pytest.fixture(params=[Adam]) -# def optimizers(self, request, made, training_configs): -# if request.param is not None: -# optimizer = request.param( -# made.parameters(), lr=training_configs.learning_rate -# ) -# -# else: -# optimizer = None -# -# return optimizer -# -# def test_made_train_step(self, made, prior, train_dataset, training_configs, optimizers): -# -# nf_model = NFModel(prior=prior, flow=made) -# -# trainer = BaseTrainer( -# model=nf_model, -# train_dataset=train_dataset, -# training_config=training_configs, -# optimizer=optimizers, -# ) -# -# start_model_state_dict = deepcopy(trainer.model.state_dict()) -# -# step_1_loss = trainer.train_step(epoch=1) -# -# step_1_model_state_dict = deepcopy(trainer.model.state_dict()) -# -# # check that weights were updated -# assert not all( -# [ -# torch.equal(start_model_state_dict[key], step_1_model_state_dict[key]) -# for key in start_model_state_dict.keys() -# ] -# ) -# -# def test_made_eval_step(self, made, prior, train_dataset, training_configs, optimizers): -# -# nf_model = NFModel(prior=prior, flow=made) -# -# trainer = BaseTrainer( -# model=nf_model, -# train_dataset=train_dataset, -# eval_dataset=train_dataset, -# training_config=training_configs, -# optimizer=optimizers, -# ) -# -# start_model_state_dict = deepcopy(trainer.model.state_dict()) -# -# step_1_loss = trainer.eval_step(epoch=1) -# -# step_1_model_state_dict = deepcopy(trainer.model.state_dict()) -# -# # check that weights were updated -# assert all( -# [ -# torch.equal(start_model_state_dict[key], step_1_model_state_dict[key]) -# for key in start_model_state_dict.keys() -# ] -# ) -# -# def test_made_main_train_loop( -# self, made, prior, train_dataset, training_configs, optimizers): -# -# nf_model = NFModel(prior=prior, flow=made) -# -# trainer = BaseTrainer( -# model=nf_model, -# train_dataset=train_dataset, -# training_config=training_configs, -# optimizer=optimizers, -# ) -# -# start_model_state_dict = deepcopy(trainer.model.state_dict()) -# -# trainer.train() -# -# step_1_model_state_dict = deepcopy(trainer.model.state_dict()) -# -# # check that weights were updated -# assert not all( -# [ -# torch.equal(start_model_state_dict[key], step_1_model_state_dict[key]) -# for key in start_model_state_dict.keys() -# ] -# ) -# -# def test_checkpoint_saving( -# self, tmpdir, made, prior, train_dataset, training_configs, optimizers -# ): -# -# dir_path = training_configs.output_dir -# -# nf_model = NFModel(prior=prior, flow=made) -# -# trainer = BaseTrainer( -# model=nf_model, -# train_dataset=train_dataset, -# training_config=training_configs, -# optimizer=optimizers, -# ) -# -# # Make a training step -# step_1_loss = trainer.train_step(epoch=1) -# -# model = deepcopy(trainer.model.flow) -# optimizer = deepcopy(trainer.optimizer) -# -# trainer.save_checkpoint(dir_path=dir_path, epoch=0, model=model) -# -# checkpoint_dir = os.path.join(dir_path, "checkpoint_epoch_0") -# -# assert os.path.isdir(checkpoint_dir) -# -# files_list = os.listdir(checkpoint_dir) -# -# assert set(["model.pt", "optimizer.pt", "training_config.json"]).issubset( -# set(files_list) -# ) -# -# -# model_rec_state_dict = torch.load(os.path.join(checkpoint_dir, "model.pt"))[ -# "model_state_dict" -# ] -# -# assert all( -# [ -# torch.equal( -# model_rec_state_dict[key].cpu(), model.state_dict()[key].cpu() -# ) -# for key in model.state_dict().keys() -# ] -# ) -# -# # check reload full model -# model_rec = MADE.load_from_folder(os.path.join(checkpoint_dir)) -# -# assert all( -# [ -# torch.equal( -# model_rec.state_dict()[key].cpu(), model.state_dict()[key].cpu() -# ) -# for key in model.state_dict().keys() -# ] -# ) -# -# optim_rec_state_dict = torch.load(os.path.join(checkpoint_dir, "optimizer.pt")) -# -# assert all( -# [ -# dict_rec == dict_optimizer -# for (dict_rec, dict_optimizer) in zip( -# optim_rec_state_dict["param_groups"], -# optimizer.state_dict()["param_groups"], -# ) -# ] -# ) -# -# assert all( -# [ -# dict_rec == dict_optimizer -# for (dict_rec, dict_optimizer) in zip( -# optim_rec_state_dict["state"], optimizer.state_dict()["state"] -# ) -# ] -# ) -# -# def test_checkpoint_saving_during_training( -# self, tmpdir, made, prior, train_dataset, training_configs, optimizers -# ): -# # -# target_saving_epoch = training_configs.steps_saving -# -# dir_path = training_configs.output_dir -# -# nf_model = NFModel(prior=prior, flow=made) -# -# trainer = BaseTrainer( -# model=nf_model, -# train_dataset=train_dataset, -# training_config=training_configs, -# optimizer=optimizers, -# ) -# -# model = deepcopy(trainer.model.flow) -# -# trainer.train() -# -# training_dir = os.path.join( -# dir_path, f"MADE_training_{trainer._training_signature}" -# ) -# assert os.path.isdir(training_dir) -# -# checkpoint_dir = os.path.join( -# training_dir, f"checkpoint_epoch_{target_saving_epoch}" -# ) -# -# assert os.path.isdir(checkpoint_dir) -# -# files_list = os.listdir(checkpoint_dir) -# -# # check files -# assert set(["model.pt", "optimizer.pt", "training_config.json"]).issubset( -# set(files_list) -# ) -# -# model_rec_state_dict = torch.load(os.path.join(checkpoint_dir, "model.pt"))[ -# "model_state_dict" -# ] -# -# assert not all( -# [ -# torch.equal(model_rec_state_dict[key], model.state_dict()[key]) -# for key in model.state_dict().keys() -# ] -# ) -# -# def test_final_model_saving( -# self, tmpdir, made, prior, train_dataset, training_configs, optimizers -# ): -# -# dir_path = training_configs.output_dir -# -# nf_model = NFModel(prior=prior, flow=made) -# -# trainer = BaseTrainer( -# model=nf_model, -# train_dataset=train_dataset, -# training_config=training_configs, -# optimizer=optimizers, -# ) -# -# trainer.train() -# -# model = deepcopy(trainer._best_model.flow) -# -# training_dir = os.path.join( -# dir_path, f"MADE_training_{trainer._training_signature}" -# ) -# assert os.path.isdir(training_dir) -# -# final_dir = os.path.join(training_dir, f"final_model") -# assert os.path.isdir(final_dir) -# -# files_list = os.listdir(final_dir) -# -# assert set(["model.pt", "model_config.json", "training_config.json"]).issubset( -# set(files_list) -# ) -# -# -# # check reload full model -# model_rec = MADE.load_from_folder(os.path.join(final_dir)) -# -# assert all( -# [ -# torch.equal( -# model_rec.state_dict()[key].cpu(), model.state_dict()[key].cpu() -# ) -# for key in model.state_dict().keys() -# ] -# ) -# -## def test_made_training_pipeline(self, tmpdir, made, train_dataset, training_configs): -## -# dir_path = training_configs.output_dir -# -# # build pipeline -# pipeline = TrainingPipeline(model=made, training_config=training_configs) -# -# assert pipeline.training_config.__dict__ == training_configs.__dict__ -# -# # Launch Pipeline -# pipeline( -# train_data=train_dataset.data, # gives tensor to pipeline -# eval_data=train_dataset.data, # gives tensor to pipeline -# ) -# -# model = deepcopy(pipeline.trainer._best_model) -# -# training_dir = os.path.join( -# dir_path, f"MADE_training_{pipeline.trainer._training_signature}" -# ) -# assert os.path.isdir(training_dir) -# -# final_dir = os.path.join(training_dir, f"final_model") -# assert os.path.isdir(final_dir) -# -# files_list = os.listdir(final_dir) -# -# assert set(["model.pt", "model_config.json", "training_config.json"]).issubset( -# set(files_list) -# ) -# -# # check pickled custom decoder -# if not made.model_config.uses_default_decoder: -# assert "decoder.pkl" in files_list -# -# else: -# assert not "decoder.pkl" in files_list -# -# # check pickled custom encoder -# if not made.model_config.uses_default_encoder: -# assert "encoder.pkl" in files_list -# -# else: -# assert not "encoder.pkl" in files_list -# -# # check reload full model -# model_rec = MADE.load_from_folder(os.path.join(final_dir)) -# -# assert all( -# [ -# torch.equal( -# model_rec.state_dict()[key].cpu(), model.state_dict()[key].cpu() -# ) -# for key in model.state_dict().keys() -# ] -# ) -# -# assert type(model_rec.encoder.cpu()) == type(model.encoder.cpu()) -# assert type(model_rec.decoder.cpu()) == type(model.decoder.cpu()) -# diff --git a/tests/test_MAF.py b/tests/test_MAF.py index 51d29f80..720dd671 100644 --- a/tests/test_MAF.py +++ b/tests/test_MAF.py @@ -10,6 +10,7 @@ from pythae.models.base.base_utils import ModelOutput from pythae.models.normalizing_flows import MAF, MAFConfig from pythae.models.normalizing_flows import NFModel +from pythae.models import AutoModel from pythae.trainers import BaseTrainer, BaseTrainerConfig @@ -71,10 +72,10 @@ def test_default_model_saving(self, tmpdir, model_configs): model.save(dir_path=dir_path) - assert set(os.listdir(dir_path)) == set(["model_config.json", "model.pt"]) + assert set(os.listdir(dir_path)) == set(["model_config.json", "model.pt", "environment.json"]) # reload model - model_rec = MAF.load_from_folder(dir_path) + model_rec = AutoModel.load_from_folder(dir_path) # check configs are the same assert model_rec.model_config.__dict__ == model.model_config.__dict__ @@ -104,18 +105,18 @@ def test_raises_missing_files(self, tmpdir, model_configs): # check raises model.pt is missing with pytest.raises(FileNotFoundError): - model_rec = MAF.load_from_folder(dir_path) + model_rec = AutoModel.load_from_folder(dir_path) torch.save({"wrong_key": 0.0}, os.path.join(dir_path, "model.pt")) # check raises wrong key in model.pt with pytest.raises(KeyError): - model_rec = MAF.load_from_folder(dir_path) + model_rec = AutoModel.load_from_folder(dir_path) os.remove(os.path.join(dir_path, "model_config.json")) # check raises model_config.json is missing with pytest.raises(FileNotFoundError): - model_rec = MAF.load_from_folder(dir_path) + model_rec = AutoModel.load_from_folder(dir_path) class Test_Model_forward: @@ -342,7 +343,7 @@ def test_checkpoint_saving( ) # check reload full model - model_rec = MAF.load_from_folder(os.path.join(checkpoint_dir)) + model_rec = AutoModel.load_from_folder(os.path.join(checkpoint_dir)) assert all( [ @@ -458,7 +459,7 @@ def test_final_model_saving( ) # check reload full model - model_rec = MAF.load_from_folder(os.path.join(final_dir)) + model_rec = AutoModel.load_from_folder(os.path.join(final_dir)) assert all( [ @@ -505,7 +506,7 @@ def test_maf_training_pipeline( ) # check reload full model - model_rec = MAF.load_from_folder(os.path.join(final_dir)) + model_rec = AutoModel.load_from_folder(os.path.join(final_dir)) assert all( [ diff --git a/tests/test_MSSSIMVAE.py b/tests/test_MSSSIMVAE.py index c8768e03..8e8534ac 100644 --- a/tests/test_MSSSIMVAE.py +++ b/tests/test_MSSSIMVAE.py @@ -122,7 +122,7 @@ def test_default_model_saving(self, tmpdir, model_configs): model.save(dir_path=dir_path) - assert set(os.listdir(dir_path)) == set(["model_config.json", "model.pt"]) + assert set(os.listdir(dir_path)) == set(["model_config.json", "model.pt", "environment.json"]) # reload model model_rec = AutoModel.load_from_folder(dir_path) @@ -149,7 +149,7 @@ def test_custom_encoder_model_saving(self, tmpdir, model_configs, custom_encoder model.save(dir_path=dir_path) assert set(os.listdir(dir_path)) == set( - ["model_config.json", "model.pt", "encoder.pkl"] + ["model_config.json", "model.pt", "encoder.pkl", "environment.json"] ) # reload model @@ -177,7 +177,7 @@ def test_custom_decoder_model_saving(self, tmpdir, model_configs, custom_decoder model.save(dir_path=dir_path) assert set(os.listdir(dir_path)) == set( - ["model_config.json", "model.pt", "decoder.pkl"] + ["model_config.json", "model.pt", "decoder.pkl", "environment.json"] ) # reload model @@ -209,7 +209,13 @@ def test_full_custom_model_saving( model.save(dir_path=dir_path) assert set(os.listdir(dir_path)) == set( - ["model_config.json", "model.pt", "encoder.pkl", "decoder.pkl"] + [ + "model_config.json", + "model.pt", + "encoder.pkl", + "decoder.pkl", + "environment.json" + ] ) # reload model diff --git a/tests/test_PixelCNN.py b/tests/test_PixelCNN.py index dcdad34f..23ac9e45 100644 --- a/tests/test_PixelCNN.py +++ b/tests/test_PixelCNN.py @@ -8,6 +8,7 @@ from pythae.models.base.base_utils import ModelOutput from pythae.models.normalizing_flows import PixelCNN, PixelCNNConfig +from pythae.models import AutoModel from pythae.trainers import BaseTrainer, BaseTrainerConfig from pythae.pipelines import TrainingPipeline @@ -71,10 +72,10 @@ def test_default_model_saving(self, tmpdir, model_configs): model.save(dir_path=dir_path) - assert set(os.listdir(dir_path)) == set(["model_config.json", "model.pt"]) + assert set(os.listdir(dir_path)) == set(["model_config.json", "model.pt", "environment.json"]) # reload model - model_rec = PixelCNN.load_from_folder(dir_path) + model_rec = AutoModel.load_from_folder(dir_path) # check configs are the same assert model_rec.model_config.__dict__ == model.model_config.__dict__ @@ -104,18 +105,18 @@ def test_raises_missing_files(self, tmpdir, model_configs): # check raises model.pt is missing with pytest.raises(FileNotFoundError): - model_rec = PixelCNN.load_from_folder(dir_path) + model_rec = AutoModel.load_from_folder(dir_path) torch.save({"wrong_key": 0.0}, os.path.join(dir_path, "model.pt")) # check raises wrong key in model.pt with pytest.raises(KeyError): - model_rec = PixelCNN.load_from_folder(dir_path) + model_rec = AutoModel.load_from_folder(dir_path) os.remove(os.path.join(dir_path, "model_config.json")) # check raises model_config.json is missing with pytest.raises(FileNotFoundError): - model_rec = PixelCNN.load_from_folder(dir_path) + model_rec = AutoModel.load_from_folder(dir_path) class Test_Model_forward: @@ -311,7 +312,7 @@ def test_checkpoint_saving( ) # check reload full model - model_rec = PixelCNN.load_from_folder(os.path.join(checkpoint_dir)) + model_rec = AutoModel.load_from_folder(os.path.join(checkpoint_dir)) assert all( [ @@ -423,7 +424,7 @@ def test_final_model_saving( ) # check reload full model - model_rec = PixelCNN.load_from_folder(os.path.join(final_dir)) + model_rec = AutoModel.load_from_folder(os.path.join(final_dir)) assert all( [ @@ -468,7 +469,7 @@ def test_pixelcnn_training_pipeline( ) # check reload full model - model_rec = PixelCNN.load_from_folder(os.path.join(final_dir)) + model_rec = AutoModel.load_from_folder(os.path.join(final_dir)) assert all( [ diff --git a/tests/test_RHVAE.py b/tests/test_RHVAE.py index 4b45d960..3acb0521 100644 --- a/tests/test_RHVAE.py +++ b/tests/test_RHVAE.py @@ -141,7 +141,7 @@ def test_default_model_saving(self, tmpdir, model_configs): model.save(dir_path=dir_path) - assert set(os.listdir(dir_path)) == set(["model_config.json", "model.pt"]) + assert set(os.listdir(dir_path)) == set(["model_config.json", "model.pt", "environment.json"]) # reload model model_rec = AutoModel.load_from_folder(dir_path) @@ -174,7 +174,7 @@ def test_custom_encoder_model_saving(self, tmpdir, model_configs, custom_encoder model.save(dir_path=dir_path) assert set(os.listdir(dir_path)) == set( - ["model_config.json", "model.pt", "encoder.pkl"] + ["model_config.json", "model.pt", "encoder.pkl", "environment.json"] ) # reload model @@ -208,7 +208,7 @@ def test_custom_decoder_model_saving(self, tmpdir, model_configs, custom_decoder model.save(dir_path=dir_path) assert set(os.listdir(dir_path)) == set( - ["model_config.json", "model.pt", "decoder.pkl"] + ["model_config.json", "model.pt", "decoder.pkl", "environment.json"] ) # reload model @@ -242,7 +242,7 @@ def test_custom_metric_model_saving(self, tmpdir, model_configs, custom_metric): model.save(dir_path=dir_path) assert set(os.listdir(dir_path)) == set( - ["model_config.json", "model.pt", "metric.pkl"] + ["model_config.json", "model.pt", "metric.pkl", "environment.json"] ) # reload model @@ -289,6 +289,7 @@ def test_full_custom_model_saving( "encoder.pkl", "decoder.pkl", "metric.pkl", + "environment.json" ] ) diff --git a/tests/test_SVAE.py b/tests/test_SVAE.py index a92c503a..2aeee030 100644 --- a/tests/test_SVAE.py +++ b/tests/test_SVAE.py @@ -116,7 +116,7 @@ def test_default_model_saving(self, tmpdir, model_configs): model.save(dir_path=dir_path) - assert set(os.listdir(dir_path)) == set(["model_config.json", "model.pt"]) + assert set(os.listdir(dir_path)) == set(["model_config.json", "model.pt", "environment.json"]) # reload model model_rec = AutoModel.load_from_folder(dir_path) @@ -143,7 +143,7 @@ def test_custom_encoder_model_saving(self, tmpdir, model_configs, custom_encoder model.save(dir_path=dir_path) assert set(os.listdir(dir_path)) == set( - ["model_config.json", "model.pt", "encoder.pkl"] + ["model_config.json", "model.pt", "encoder.pkl", "environment.json"] ) # reload model @@ -171,7 +171,7 @@ def test_custom_decoder_model_saving(self, tmpdir, model_configs, custom_decoder model.save(dir_path=dir_path) assert set(os.listdir(dir_path)) == set( - ["model_config.json", "model.pt", "decoder.pkl"] + ["model_config.json", "model.pt", "decoder.pkl", "environment.json"] ) # reload model @@ -201,7 +201,13 @@ def test_full_custom_model_saving( model.save(dir_path=dir_path) assert set(os.listdir(dir_path)) == set( - ["model_config.json", "model.pt", "encoder.pkl", "decoder.pkl"] + [ + "model_config.json", + "model.pt", + "encoder.pkl", + "decoder.pkl", + "environment.json" + ] ) # reload model diff --git a/tests/test_VAE.py b/tests/test_VAE.py index 61a0affd..904dc8d2 100644 --- a/tests/test_VAE.py +++ b/tests/test_VAE.py @@ -116,7 +116,7 @@ def test_default_model_saving(self, tmpdir, model_configs): model.save(dir_path=dir_path) - assert set(os.listdir(dir_path)) == set(["model_config.json", "model.pt"]) + assert set(os.listdir(dir_path)) == set(["model_config.json", "model.pt", "environment.json"]) # reload model model_rec = AutoModel.load_from_folder(dir_path) @@ -143,7 +143,7 @@ def test_custom_encoder_model_saving(self, tmpdir, model_configs, custom_encoder model.save(dir_path=dir_path) assert set(os.listdir(dir_path)) == set( - ["model_config.json", "model.pt", "encoder.pkl"] + ["model_config.json", "model.pt", "encoder.pkl", "environment.json"] ) # reload model @@ -171,7 +171,7 @@ def test_custom_decoder_model_saving(self, tmpdir, model_configs, custom_decoder model.save(dir_path=dir_path) assert set(os.listdir(dir_path)) == set( - ["model_config.json", "model.pt", "decoder.pkl"] + ["model_config.json", "model.pt", "decoder.pkl", "environment.json"] ) # reload model @@ -201,7 +201,13 @@ def test_full_custom_model_saving( model.save(dir_path=dir_path) assert set(os.listdir(dir_path)) == set( - ["model_config.json", "model.pt", "encoder.pkl", "decoder.pkl"] + [ + "model_config.json", + "model.pt", + "encoder.pkl", + "decoder.pkl", + "environment.json" + ] ) # reload model diff --git a/tests/test_VAEGAN.py b/tests/test_VAEGAN.py index df0100d9..8f6c6634 100644 --- a/tests/test_VAEGAN.py +++ b/tests/test_VAEGAN.py @@ -171,7 +171,7 @@ def test_default_model_saving(self, tmpdir, model_configs): model.save(dir_path=dir_path) - assert set(os.listdir(dir_path)) == set(["model_config.json", "model.pt"]) + assert set(os.listdir(dir_path)) == set(["model_config.json", "model.pt", "environment.json"]) # reload model model_rec = AutoModel.load_from_folder(dir_path) @@ -198,7 +198,7 @@ def test_custom_encoder_model_saving(self, tmpdir, model_configs, custom_encoder model.save(dir_path=dir_path) assert set(os.listdir(dir_path)) == set( - ["model_config.json", "model.pt", "encoder.pkl"] + ["model_config.json", "model.pt", "encoder.pkl", "environment.json"] ) # reload model @@ -226,7 +226,7 @@ def test_custom_decoder_model_saving(self, tmpdir, model_configs, custom_decoder model.save(dir_path=dir_path) assert set(os.listdir(dir_path)) == set( - ["model_config.json", "model.pt", "decoder.pkl"] + ["model_config.json", "model.pt", "decoder.pkl", "environment.json"] ) # reload model @@ -256,7 +256,7 @@ def test_custom_discriminator_model_saving( model.save(dir_path=dir_path) assert set(os.listdir(dir_path)) == set( - ["model_config.json", "model.pt", "discriminator.pkl"] + ["model_config.json", "model.pt", "discriminator.pkl", "environment.json"] ) # reload model @@ -302,6 +302,7 @@ def test_full_custom_model_saving( "encoder.pkl", "decoder.pkl", "discriminator.pkl", + "environment.json" ] ) diff --git a/tests/test_VAE_IAF.py b/tests/test_VAE_IAF.py index 0188af71..e7a673a7 100644 --- a/tests/test_VAE_IAF.py +++ b/tests/test_VAE_IAF.py @@ -122,7 +122,7 @@ def test_default_model_saving(self, tmpdir, model_configs): model.save(dir_path=dir_path) - assert set(os.listdir(dir_path)) == set(["model_config.json", "model.pt"]) + assert set(os.listdir(dir_path)) == set(["model_config.json", "model.pt", "environment.json"]) # reload model model_rec = AutoModel.load_from_folder(dir_path) @@ -149,7 +149,7 @@ def test_custom_encoder_model_saving(self, tmpdir, model_configs, custom_encoder model.save(dir_path=dir_path) assert set(os.listdir(dir_path)) == set( - ["model_config.json", "model.pt", "encoder.pkl"] + ["model_config.json", "model.pt", "encoder.pkl", "environment.json"] ) # reload model @@ -177,7 +177,7 @@ def test_custom_decoder_model_saving(self, tmpdir, model_configs, custom_decoder model.save(dir_path=dir_path) assert set(os.listdir(dir_path)) == set( - ["model_config.json", "model.pt", "decoder.pkl"] + ["model_config.json", "model.pt", "decoder.pkl", "environment.json"] ) # reload model @@ -207,7 +207,13 @@ def test_full_custom_model_saving( model.save(dir_path=dir_path) assert set(os.listdir(dir_path)) == set( - ["model_config.json", "model.pt", "encoder.pkl", "decoder.pkl"] + [ + "model_config.json", + "model.pt", + "encoder.pkl", + "decoder.pkl", + "environment.json" + ] ) # reload model diff --git a/tests/test_VAE_LinFlow.py b/tests/test_VAE_LinFlow.py index afd81bc1..48abb79d 100644 --- a/tests/test_VAE_LinFlow.py +++ b/tests/test_VAE_LinFlow.py @@ -136,7 +136,7 @@ def test_default_model_saving(self, tmpdir, model_configs): model.save(dir_path=dir_path) - assert set(os.listdir(dir_path)) == set(["model_config.json", "model.pt"]) + assert set(os.listdir(dir_path)) == set(["model_config.json", "model.pt", "environment.json"]) # reload model model_rec = AutoModel.load_from_folder(dir_path) @@ -163,7 +163,7 @@ def test_custom_encoder_model_saving(self, tmpdir, model_configs, custom_encoder model.save(dir_path=dir_path) assert set(os.listdir(dir_path)) == set( - ["model_config.json", "model.pt", "encoder.pkl"] + ["model_config.json", "model.pt", "encoder.pkl", "environment.json"] ) # reload model @@ -191,7 +191,7 @@ def test_custom_decoder_model_saving(self, tmpdir, model_configs, custom_decoder model.save(dir_path=dir_path) assert set(os.listdir(dir_path)) == set( - ["model_config.json", "model.pt", "decoder.pkl"] + ["model_config.json", "model.pt", "decoder.pkl", "environment.json"] ) # reload model @@ -221,7 +221,13 @@ def test_full_custom_model_saving( model.save(dir_path=dir_path) assert set(os.listdir(dir_path)) == set( - ["model_config.json", "model.pt", "encoder.pkl", "decoder.pkl"] + [ + "model_config.json", + "model.pt", + "encoder.pkl", + "decoder.pkl", + "environment.json" + ] ) # reload model diff --git a/tests/test_VAMP.py b/tests/test_VAMP.py index 7e7e609c..224906d4 100644 --- a/tests/test_VAMP.py +++ b/tests/test_VAMP.py @@ -120,7 +120,7 @@ def test_default_model_saving(self, tmpdir, model_configs): model.save(dir_path=dir_path) - assert set(os.listdir(dir_path)) == set(["model_config.json", "model.pt"]) + assert set(os.listdir(dir_path)) == set(["model_config.json", "model.pt", "environment.json"]) # reload model model_rec = AutoModel.load_from_folder(dir_path) @@ -147,7 +147,7 @@ def test_custom_encoder_model_saving(self, tmpdir, model_configs, custom_encoder model.save(dir_path=dir_path) assert set(os.listdir(dir_path)) == set( - ["model_config.json", "model.pt", "encoder.pkl"] + ["model_config.json", "model.pt", "encoder.pkl", "environment.json"] ) # reload model @@ -175,7 +175,7 @@ def test_custom_decoder_model_saving(self, tmpdir, model_configs, custom_decoder model.save(dir_path=dir_path) assert set(os.listdir(dir_path)) == set( - ["model_config.json", "model.pt", "decoder.pkl"] + ["model_config.json", "model.pt", "decoder.pkl", "environment.json"] ) # reload model @@ -205,7 +205,13 @@ def test_full_custom_model_saving( model.save(dir_path=dir_path) assert set(os.listdir(dir_path)) == set( - ["model_config.json", "model.pt", "encoder.pkl", "decoder.pkl"] + [ + "model_config.json", + "model.pt", + "encoder.pkl", + "decoder.pkl", + "environment.json" + ] ) # reload model diff --git a/tests/test_VQVAE.py b/tests/test_VQVAE.py index 06a249fe..c5f8ebd7 100644 --- a/tests/test_VQVAE.py +++ b/tests/test_VQVAE.py @@ -135,7 +135,7 @@ def test_default_model_saving(self, tmpdir, model_configs): model.save(dir_path=dir_path) - assert set(os.listdir(dir_path)) == set(["model_config.json", "model.pt"]) + assert set(os.listdir(dir_path)) == set(["model_config.json", "model.pt", "environment.json"]) # reload model model_rec = AutoModel.load_from_folder(dir_path) @@ -162,7 +162,7 @@ def test_custom_encoder_model_saving(self, tmpdir, model_configs, custom_encoder model.save(dir_path=dir_path) assert set(os.listdir(dir_path)) == set( - ["model_config.json", "model.pt", "encoder.pkl"] + ["model_config.json", "model.pt", "encoder.pkl", "environment.json"] ) # reload model @@ -190,7 +190,7 @@ def test_custom_decoder_model_saving(self, tmpdir, model_configs, custom_decoder model.save(dir_path=dir_path) assert set(os.listdir(dir_path)) == set( - ["model_config.json", "model.pt", "decoder.pkl"] + ["model_config.json", "model.pt", "decoder.pkl", "environment.json"] ) # reload model @@ -220,7 +220,13 @@ def test_full_custom_model_saving( model.save(dir_path=dir_path) assert set(os.listdir(dir_path)) == set( - ["model_config.json", "model.pt", "encoder.pkl", "decoder.pkl"] + [ + "model_config.json", + "model.pt", + "encoder.pkl", + "decoder.pkl", + "environment.json" + ] ) # reload model diff --git a/tests/test_WAE_MMD.py b/tests/test_WAE_MMD.py index 30a9f446..e561a196 100644 --- a/tests/test_WAE_MMD.py +++ b/tests/test_WAE_MMD.py @@ -118,7 +118,7 @@ def test_default_model_saving(self, tmpdir, model_configs): model.save(dir_path=dir_path) - assert set(os.listdir(dir_path)) == set(["model_config.json", "model.pt"]) + assert set(os.listdir(dir_path)) == set(["model_config.json", "model.pt", "environment.json"]) # reload model model_rec = AutoModel.load_from_folder(dir_path) @@ -145,7 +145,7 @@ def test_custom_encoder_model_saving(self, tmpdir, model_configs, custom_encoder model.save(dir_path=dir_path) assert set(os.listdir(dir_path)) == set( - ["model_config.json", "model.pt", "encoder.pkl"] + ["model_config.json", "model.pt", "encoder.pkl", "environment.json"] ) # reload model @@ -173,7 +173,7 @@ def test_custom_decoder_model_saving(self, tmpdir, model_configs, custom_decoder model.save(dir_path=dir_path) assert set(os.listdir(dir_path)) == set( - ["model_config.json", "model.pt", "decoder.pkl"] + ["model_config.json", "model.pt", "decoder.pkl", "environment.json"] ) # reload model @@ -203,7 +203,13 @@ def test_full_custom_model_saving( model.save(dir_path=dir_path) assert set(os.listdir(dir_path)) == set( - ["model_config.json", "model.pt", "encoder.pkl", "decoder.pkl"] + [ + "model_config.json", + "model.pt", + "encoder.pkl", + "decoder.pkl", + "environment.json" + ] ) # reload model diff --git a/tests/test_baseAE.py b/tests/test_baseAE.py index 89c2beef..c02c90c1 100644 --- a/tests/test_baseAE.py +++ b/tests/test_baseAE.py @@ -85,7 +85,7 @@ def test_default_model_saving(self, tmpdir, model_config_with_input_dim): model.save(dir_path=dir_path) - assert set(os.listdir(dir_path)) == set(["model_config.json", "model.pt"]) + assert set(os.listdir(dir_path)) == set(["model_config.json", "model.pt", "environment.json"]) # reload model model_rec = BaseAE.load_from_folder(dir_path) @@ -112,7 +112,7 @@ def test_custom_decoder_model_saving( model.save(dir_path=dir_path) assert set(os.listdir(dir_path)) == set( - ["model_config.json", "model.pt", "decoder.pkl"] + ["model_config.json", "model.pt", "decoder.pkl", "environment.json"] ) # reload model diff --git a/tests/test_info_vae_mmd.py b/tests/test_info_vae_mmd.py index 09156cf9..2ebb5ee9 100644 --- a/tests/test_info_vae_mmd.py +++ b/tests/test_info_vae_mmd.py @@ -125,7 +125,7 @@ def test_default_model_saving(self, tmpdir, model_configs): model.save(dir_path=dir_path) - assert set(os.listdir(dir_path)) == set(["model_config.json", "model.pt"]) + assert set(os.listdir(dir_path)) == set(["model_config.json", "model.pt", "environment.json"]) # reload model model_rec = INFOVAE_MMD.load_from_folder(dir_path) @@ -152,7 +152,7 @@ def test_custom_encoder_model_saving(self, tmpdir, model_configs, custom_encoder model.save(dir_path=dir_path) assert set(os.listdir(dir_path)) == set( - ["model_config.json", "model.pt", "encoder.pkl"] + ["model_config.json", "model.pt", "encoder.pkl", "environment.json"] ) # reload model @@ -180,7 +180,7 @@ def test_custom_decoder_model_saving(self, tmpdir, model_configs, custom_decoder model.save(dir_path=dir_path) assert set(os.listdir(dir_path)) == set( - ["model_config.json", "model.pt", "decoder.pkl"] + ["model_config.json", "model.pt", "decoder.pkl", "environment.json"] ) # reload model @@ -212,7 +212,13 @@ def test_full_custom_model_saving( model.save(dir_path=dir_path) assert set(os.listdir(dir_path)) == set( - ["model_config.json", "model.pt", "encoder.pkl", "decoder.pkl"] + [ + "model_config.json", + "model.pt", + "encoder.pkl", + "decoder.pkl", + "environment.json" + ] ) # reload model diff --git a/tests/test_planar_flow.py b/tests/test_planar_flow.py index 70d675c0..f4f726d7 100644 --- a/tests/test_planar_flow.py +++ b/tests/test_planar_flow.py @@ -11,6 +11,7 @@ from pythae.models.normalizing_flows import PlanarFlow, PlanarFlowConfig from pythae.models.normalizing_flows import NFModel from pythae.data.datasets import BaseDataset +from pythae.models import AutoModel from pythae.trainers import BaseTrainer, BaseTrainerConfig @@ -78,10 +79,10 @@ def test_default_model_saving(self, tmpdir, model_configs): model.save(dir_path=dir_path) - assert set(os.listdir(dir_path)) == set(["model_config.json", "model.pt"]) + assert set(os.listdir(dir_path)) == set(["model_config.json", "model.pt", "environment.json"]) # reload model - model_rec = PlanarFlow.load_from_folder(dir_path) + model_rec = AutoModel.load_from_folder(dir_path) # check configs are the same assert model_rec.model_config.__dict__ == model.model_config.__dict__ @@ -111,18 +112,18 @@ def test_raises_missing_files(self, tmpdir, model_configs): # check raises model.pt is missing with pytest.raises(FileNotFoundError): - model_rec = PlanarFlow.load_from_folder(dir_path) + model_rec = AutoModel.load_from_folder(dir_path) torch.save({"wrong_key": 0.0}, os.path.join(dir_path, "model.pt")) # check raises wrong key in model.pt with pytest.raises(KeyError): - model_rec = PlanarFlow.load_from_folder(dir_path) + model_rec = AutoModel.load_from_folder(dir_path) os.remove(os.path.join(dir_path, "model_config.json")) # check raises model_config.json is missing with pytest.raises(FileNotFoundError): - model_rec = PlanarFlow.load_from_folder(dir_path) + model_rec = AutoModel.load_from_folder(dir_path) class Test_Model_forward: @@ -335,7 +336,7 @@ def test_checkpoint_saving( ) # check reload full model - model_rec = PlanarFlow.load_from_folder(os.path.join(checkpoint_dir)) + model_rec = AutoModel.load_from_folder(os.path.join(checkpoint_dir)) assert all( [ @@ -451,7 +452,7 @@ def test_final_model_saving( ) # check reload full model - model_rec = PlanarFlow.load_from_folder(os.path.join(final_dir)) + model_rec = AutoModel.load_from_folder(os.path.join(final_dir)) assert all( [ @@ -498,7 +499,7 @@ def test_planar_flow_training_pipeline( ) # check reload full model - model_rec = PlanarFlow.load_from_folder(os.path.join(final_dir)) + model_rec = AutoModel.load_from_folder(os.path.join(final_dir)) assert all( [ diff --git a/tests/test_radial_flow.py b/tests/test_radial_flow.py index a93759f5..23c420ba 100644 --- a/tests/test_radial_flow.py +++ b/tests/test_radial_flow.py @@ -11,6 +11,7 @@ from pythae.models.normalizing_flows import RadialFlow, RadialFlowConfig from pythae.models.normalizing_flows import NFModel from pythae.data.datasets import BaseDataset +from pythae.models import AutoModel from pythae.trainers import BaseTrainer, BaseTrainerConfig @@ -72,10 +73,10 @@ def test_default_model_saving(self, tmpdir, model_configs): model.save(dir_path=dir_path) - assert set(os.listdir(dir_path)) == set(["model_config.json", "model.pt"]) + assert set(os.listdir(dir_path)) == set(["model_config.json", "model.pt", "environment.json"]) # reload model - model_rec = RadialFlow.load_from_folder(dir_path) + model_rec = AutoModel.load_from_folder(dir_path) # check configs are the same assert model_rec.model_config.__dict__ == model.model_config.__dict__ @@ -105,18 +106,18 @@ def test_raises_missing_files(self, tmpdir, model_configs): # check raises model.pt is missing with pytest.raises(FileNotFoundError): - model_rec = RadialFlow.load_from_folder(dir_path) + model_rec = AutoModel.load_from_folder(dir_path) torch.save({"wrong_key": 0.0}, os.path.join(dir_path, "model.pt")) # check raises wrong key in model.pt with pytest.raises(KeyError): - model_rec = RadialFlow.load_from_folder(dir_path) + model_rec = AutoModel.load_from_folder(dir_path) os.remove(os.path.join(dir_path, "model_config.json")) # check raises model_config.json is missing with pytest.raises(FileNotFoundError): - model_rec = RadialFlow.load_from_folder(dir_path) + model_rec = AutoModel.load_from_folder(dir_path) class Test_Model_forward: @@ -329,7 +330,7 @@ def test_checkpoint_saving( ) # check reload full model - model_rec = RadialFlow.load_from_folder(os.path.join(checkpoint_dir)) + model_rec = AutoModel.load_from_folder(os.path.join(checkpoint_dir)) assert all( [ @@ -445,7 +446,7 @@ def test_final_model_saving( ) # check reload full model - model_rec = RadialFlow.load_from_folder(os.path.join(final_dir)) + model_rec = AutoModel.load_from_folder(os.path.join(final_dir)) assert all( [ @@ -492,7 +493,7 @@ def test_radial_flow_training_pipeline( ) # check reload full model - model_rec = RadialFlow.load_from_folder(os.path.join(final_dir)) + model_rec = AutoModel.load_from_folder(os.path.join(final_dir)) assert all( [ diff --git a/tests/test_rae_gp.py b/tests/test_rae_gp.py index a8b41421..57b831b9 100644 --- a/tests/test_rae_gp.py +++ b/tests/test_rae_gp.py @@ -118,7 +118,7 @@ def test_default_model_saving(self, tmpdir, model_configs): model.save(dir_path=dir_path) - assert set(os.listdir(dir_path)) == set(["model_config.json", "model.pt"]) + assert set(os.listdir(dir_path)) == set(["model_config.json", "model.pt", "environment.json"]) # reload model model_rec = AutoModel.load_from_folder(dir_path) @@ -145,7 +145,7 @@ def test_custom_encoder_model_saving(self, tmpdir, model_configs, custom_encoder model.save(dir_path=dir_path) assert set(os.listdir(dir_path)) == set( - ["model_config.json", "model.pt", "encoder.pkl"] + ["model_config.json", "model.pt", "encoder.pkl", "environment.json"] ) # reload model @@ -173,7 +173,7 @@ def test_custom_decoder_model_saving(self, tmpdir, model_configs, custom_decoder model.save(dir_path=dir_path) assert set(os.listdir(dir_path)) == set( - ["model_config.json", "model.pt", "decoder.pkl"] + ["model_config.json", "model.pt", "decoder.pkl", "environment.json"] ) # reload model @@ -203,7 +203,13 @@ def test_full_custom_model_saving( model.save(dir_path=dir_path) assert set(os.listdir(dir_path)) == set( - ["model_config.json", "model.pt", "encoder.pkl", "decoder.pkl"] + [ + "model_config.json", + "model.pt", + "encoder.pkl", + "decoder.pkl", + "environment.json" + ] ) # reload model diff --git a/tests/test_rae_l2.py b/tests/test_rae_l2.py index 3e4b5dbb..4b148237 100644 --- a/tests/test_rae_l2.py +++ b/tests/test_rae_l2.py @@ -122,7 +122,7 @@ def test_default_model_saving(self, tmpdir, model_configs): model.save(dir_path=dir_path) - assert set(os.listdir(dir_path)) == set(["model_config.json", "model.pt"]) + assert set(os.listdir(dir_path)) == set(["model_config.json", "model.pt", "environment.json"]) # reload model model_rec = AutoModel.load_from_folder(dir_path) @@ -149,7 +149,7 @@ def test_custom_encoder_model_saving(self, tmpdir, model_configs, custom_encoder model.save(dir_path=dir_path) assert set(os.listdir(dir_path)) == set( - ["model_config.json", "model.pt", "encoder.pkl"] + ["model_config.json", "model.pt", "encoder.pkl", "environment.json"] ) # reload model @@ -177,7 +177,7 @@ def test_custom_decoder_model_saving(self, tmpdir, model_configs, custom_decoder model.save(dir_path=dir_path) assert set(os.listdir(dir_path)) == set( - ["model_config.json", "model.pt", "decoder.pkl"] + ["model_config.json", "model.pt", "decoder.pkl", "environment.json"] ) # reload model @@ -207,7 +207,13 @@ def test_full_custom_model_saving( model.save(dir_path=dir_path) assert set(os.listdir(dir_path)) == set( - ["model_config.json", "model.pt", "encoder.pkl", "decoder.pkl"] + [ + "model_config.json", + "model.pt", + "encoder.pkl", + "decoder.pkl", + "environment.json" + ] ) # reload model