diff --git a/.gitignore b/.gitignore index f6d617b8..053a7b90 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,6 @@ lightning_logs data -__pycache__ +**__pycache__** .idea results tb_profile @@ -14,7 +14,6 @@ models_wSPDE.py merge_ose_osse* main_wSPDE.slurm checkpoints -__pycache__ tmp outputs .hydra @@ -28,5 +27,8 @@ multirun.yaml icassp_code archive_dash icassp_code_bis -notebooks +notebooks/* tags + +# exceptions +!notebooks/visualize_exp.ipynb diff --git a/notebooks/visualize_exp.ipynb b/notebooks/visualize_exp.ipynb new file mode 100644 index 00000000..42c2edf9 --- /dev/null +++ b/notebooks/visualize_exp.ipynb @@ -0,0 +1,237 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Notebook for loading hydra config and check dataloaders\n", + "\n", + "### Uses hydra.compose API" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4ba282cd-b730-4056-80e8-9a289004d87a", + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2\n", + "%matplotlib inline" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "44fe8011-a5ce-4903-8cd0-152e5f894234", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import sys\n", + "from pprint import pprint\n", + "import numpy as np\n", + "import torch\n", + "import xarray as xr\n", + "import pytorch_lightning as pl\n", + "from pytorch_lightning import seed_everything\n", + "import matplotlib.pyplot as plt\n", + "\n", + "import hydra\n", + "from hydra import compose, initialize\n", + "from hydra.utils import instantiate, get_class\n", + "from omegaconf import OmegaConf\n", + "\n", + "sys.path.append('..')\n", + "from main import FourDVarNetRunner\n", + "from hydra_main import FourDVarNetHydraRunner" + ] + }, + { + "cell_type": "markdown", + "id": "8d3c282e-6a31-4a3f-b404-30af01af2cca", + "metadata": {}, + "source": [ + "## Choose xp" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "12adfd32-97ce-4b1a-ac03-a26924034116", + "metadata": {}, + "outputs": [], + "source": [ + "config_path = \"../hydra_config\"\n", + "\n", + "pprint(os.listdir(os.path.join(config_path, \"xp\")))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f69cdf1e-d3e5-46d2-acd0-549e53794097", + "metadata": {}, + "outputs": [], + "source": [ + "xp = \"sla_glorys\"\n", + "entrypoint = \"train\"\n", + "training = \"glorys\"\n", + "file_paths = \"hal\"" + ] + }, + { + "cell_type": "markdown", + "id": "8f18c59c-7dff-4d79-b21b-ee41a37140b1", + "metadata": {}, + "source": [ + "## Load experiment config" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "34fd56a3-4e54-4f5f-82c1-d67baf9324a3", + "metadata": {}, + "outputs": [], + "source": [ + "with initialize(config_path=config_path):\n", + " cfg = compose(\n", + " config_name=\"main\",\n", + " overrides=[f\"xp={xp}\", f\"entrypoint={entrypoint}\", f\"training={training}\", f\"file_paths={file_paths}\"])\n", + " print(OmegaConf.to_yaml(cfg))" + ] + }, + { + "cell_type": "markdown", + "id": "26021f76-1698-4c66-ab9b-e1ac57c12391", + "metadata": {}, + "source": [ + "## Reproduce hydra_main.py" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f83f0be0-6c92-4106-96d9-2ef6341033e8", + "metadata": {}, + "outputs": [], + "source": [ + "seed_everything(seed=cfg.get('seed', None))\n", + "\n", + "dm = instantiate(cfg.datamodule)\n", + "dm.setup()\n", + "\n", + "lit_mod_cls = get_class(cfg.lit_mod_cls)\n", + "\n", + "runner = FourDVarNetHydraRunner(cfg.params, dm, lit_mod_cls)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4fd2667d-21d5-4ac0-8966-8a12d1a62ba1", + "metadata": {}, + "outputs": [], + "source": [ + "train_dl = dm.train_dataloader()\n", + "val_dl = dm.val_dataloader()\n", + "test_dl = dm.test_dataloader()\n", + "\n", + "print(len(train_dl), len(val_dl), len(test_dl))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c2342d4f-c035-4f25-b3b1-a569be8a835f", + "metadata": {}, + "outputs": [], + "source": [ + "for batch in train_dl:\n", + " \n", + " targets_OI, inputs_Mask, inputs_obs, targets_GT = batch\n", + " break \n", + " \n", + "targets_OI, inputs_Mask, inputs_obs, targets_GT = (\n", + " targets_OI.cpu().numpy(), \n", + " inputs_Mask.cpu().numpy(),\n", + " inputs_obs.cpu().numpy(),\n", + " targets_GT.cpu().numpy()\n", + ")\n", + "\n", + "print('mean obs : ', inputs_obs[inputs_obs != 0].mean())\n", + "print('std obs : ', inputs_obs[inputs_obs != 0].std())\n", + "print('min obs : ', inputs_obs[inputs_obs != 0].min())\n", + "print('max obs : ', inputs_obs[inputs_obs != 0].max())\n", + "\n", + "print('NaNs obs : ', np.isnan(inputs_obs).sum()) \n", + "print('---')\n", + "print('mean oi : ', targets_OI[targets_OI != 0].mean())\n", + "print('std oi : ', targets_OI[targets_OI != 0].std())\n", + "print('min oi : ', targets_OI[targets_OI != 0].min())\n", + "print('max oi : ', targets_OI[targets_OI != 0].max())\n", + "print('NaNs oi : ', np.isnan(targets_OI).sum()) " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "69f157c3-fbfb-4b72-9e52-b23e69d432c9", + "metadata": {}, + "outputs": [], + "source": [ + "n_times = int(inputs_obs.shape[1])\n", + "\n", + "fig, ax = plt.subplots(4, n_times, figsize=(16,16))\n", + "\n", + "for i in range(n_times):\n", + " ax[0,i].imshow(inputs_obs[0,i])\n", + " ax[0,i].set_title(f\"Input obs time {i}\")\n", + "\n", + " ax[1,i].imshow(inputs_Mask[0,i])\n", + " ax[1,i].set_title(f\"Input mask time {i}\")\n", + "\n", + " ax[2,i].imshow(targets_OI[0,i], vmin=-2, vmax=2)\n", + " ax[2,i].set_title(f\"Target OI time {i}\")\n", + "\n", + " ax[3,i].imshow(targets_GT[0,i], vmin=-2, vmax=2)\n", + " ax[3,i].set_title(f\"Target GT time {i}\")\n", + "\n", + "plt.subplots_adjust()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dbf67279-930b-4aff-a751-51a74e42476c", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "interpreter": { + "hash": "6f737a54a835b1e39b48fca7537be7d279a88101ad848d0634d9b32bee2f5e72" + }, + "kernelspec": { + "display_name": "Python 3.8.10 64-bit ('datalab': pyenv)", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.10" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}