diff --git a/.gitignore b/.gitignore index 42a1f8b..9b83832 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ +lightning_logs/ examples/lightning_logs/ tags diff --git a/examples/doublependulum_emlp.ipynb b/examples/doublependulum_emlp.ipynb new file mode 100644 index 0000000..00442f1 --- /dev/null +++ b/examples/doublependulum_emlp.ipynb @@ -0,0 +1,283 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "f4818f33-b331-419c-8dd5-3dfa60e1da39", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "import sys\n", + "\n", + "sys.path.append(\"../\")\n", + "\n", + "import torch\n", + "import torch.utils as utils\n", + "\n", + "import pytorch_lightning as pl\n", + "\n", + "from torchemlp.groups import SO, O, S, Z\n", + "from torchemlp.nn.equivariant import EMLP\n", + "from torchemlp.nn.runners import (\n", + " FuncMSERegressionLightning,\n", + " DynamicsL2RegressionLightning,\n", + ")\n", + "from torchemlp.nn.utils import Standardize\n", + "from torchemlp.datasets import DoublePendulum\n", + "\n", + "# torch.set_default_dtype(torch.float32)" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "08d616c9-bb21-4c1c-b835-c2b5533fc8d8", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "TRAINING_SET_SIZE = 5_000\n", + "VALIDATION_SET_SIZE = 1_000\n", + "TEST_SET_SIZE = 1_000\n", + "\n", + "DT = 0.1\n", + "T = 30.0\n", + "\n", + "BATCH_SIZE = 500\n", + "\n", + "N_EPOCHS = 5 # min(int(900_000 / TRAINING_SET_SIZE), 1000)\n", + "# N_EPOCHS = 100 # min(int(900_000 / TRAINING_SET_SIZE), 1000)\n", + "\n", + "N_CHANNELS = 384\n", + "N_LAYERS = 3\n", + "\n", + "DL_WORKERS = 0" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "b7627c87-9a7b-40e5-92ae-f8885c05ef29", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "'Input type: 4V⁰+4V, output type: 8V⁰'" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dataset = DoublePendulum(\n", + " TRAINING_SET_SIZE + VALIDATION_SET_SIZE + TEST_SET_SIZE,\n", + " DT,\n", + " T,\n", + ")\n", + "f\"Input type: {dataset.repin(dataset.G)}, output type: {dataset.repout(dataset.G)}\"" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "055ccfb1-c8d6-4d1c-a319-06844a43b269", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([7000, 12])" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dataset.X[0].shape" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "56f15ae9-6cca-4b34-9f93-ca879d9992c2", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "split_data = utils.data.random_split(\n", + " dataset, [TRAINING_SET_SIZE, VALIDATION_SET_SIZE, TEST_SET_SIZE]\n", + ")\n", + "\n", + "train_loader = utils.data.DataLoader(\n", + " split_data[0], batch_size=BATCH_SIZE, num_workers=DL_WORKERS, shuffle=True\n", + ")\n", + "val_loader = utils.data.DataLoader(\n", + " split_data[1], batch_size=BATCH_SIZE, num_workers=DL_WORKERS\n", + ")\n", + "test_loader = utils.data.DataLoader(\n", + " split_data[2], batch_size=BATCH_SIZE, num_workers=DL_WORKERS\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "e81cd4b4-3822-466a-9d1b-7879979e2b3b", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "model = Standardize(\n", + " EMLP(dataset.repin, dataset.repout, dataset.G, N_CHANNELS, N_LAYERS), dataset.stats\n", 
+ ").cuda()\n", + "plmodel = DynamicsL2RegressionLightning(model)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "0dbc7048-557d-4af0-87d8-a37d671551d1", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "GPU available: True (cuda), used: True\n", + "TPU available: False, using: 0 TPU cores\n", + "IPU available: False, using: 0 IPUs\n", + "HPU available: False, using: 0 HPUs\n", + "/home/ryan/.conda/envs/l2e10/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:67: UserWarning: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default\n", + " warning_cache.warn(\n", + "You are using a CUDA device ('NVIDIA GeForce RTX 3080 Ti') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision\n", + "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", + "\n", + " | Name | Type | Params\n", + "--------------------------------------\n", + "0 | model | Standardize | 570 K \n", + "--------------------------------------\n", + "570 K Trainable params\n", + "0 Non-trainable params\n", + "570 K Total params\n", + "2.280 Total estimated model params size (MB)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "a77f3c0512504242981bc2212ccef9ac", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Sanity Checking: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/ryan/.conda/envs/l2e10/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning: The dataloader, val_dataloader 0, does not have many workers which may be a bottleneck. 
Consider increasing the value of the `num_workers` argument` (try 20 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n", + " rank_zero_warn(\n" + ] + }, + { + "ename": "RuntimeError", + "evalue": "mat1 and mat2 shapes cannot be multiplied (500x12 and 16x432)", + "output_type": "error", + "traceback": [
+ "---------------------------------------------------------------------------",
+ "RuntimeError                              Traceback (most recent call last)",
+ "Cell In[7], line 4\n----> 4 trainer.fit(plmodel, train_loader, val_loader)\n      5 trainer.test(plmodel, test_loader)\n",
+ "[... pytorch_lightning frames elided: Trainer.fit -> _call_and_handle_interrupt -> Trainer._fit_impl -> Trainer._run -> Trainer._run_stage -> Trainer._run_train -> Trainer._run_sanity_check -> Loop.run -> EvaluationLoop.advance -> EvaluationEpochLoop._evaluation_step -> Trainer._call_strategy_hook -> Strategy.validation_step ...]",
+ "File /mnt/SATA4TB/tse/ryan/torchemlp/examples/../torchemlp/nn/runners.py:38, in RegressionLightning.validation_step(self, batch, batch_idx)\n---> 38 loss = self.batch2loss(batch)\n",
+ "File /mnt/SATA4TB/tse/ryan/torchemlp/examples/../torchemlp/nn/runners.py:75, in DynamicsL2RegressionLightning.batch2loss(self, batch)\n---> 75 zs_pred = self.odeint_fn(\n     76     lambda _, y0: self.model(y0),\n     77     z0,\n     78     ts[0, ...],\n     80 )\n",
+ "[... torchdiffeq frames (odeint -> AdaptiveStepsizeODESolver.integrate -> RKAdaptiveStepsizeODESolver._before_integrate -> _PerturbFunc.forward) and torch.nn.Module._call_impl / Sequential.forward frames elided ...]",
+ "File /mnt/SATA4TB/tse/ryan/torchemlp/examples/../torchemlp/nn/utils.py:72, in Standardize.forward(self, x)\n---> 72 y = self.model(normalized_in)\n",
+ "File /mnt/SATA4TB/tse/ryan/torchemlp/examples/../torchemlp/nn/equivariant.py:209, in EMLP.forward(self, x)\n--> 209 return self.network(x)\n",
+ "File /mnt/SATA4TB/tse/ryan/torchemlp/examples/../torchemlp/nn/equivariant.py:161, in EMLPBlock.forward(self, x)\n--> 161 lin = self.linear(x)\n",
+ "File /mnt/SATA4TB/tse/ryan/torchemlp/examples/../torchemlp/nn/equivariant.py:42, in EquivariantLinear.forward(self, x)\n     40 W = self.Pw @ self.w\n     41 b = self.Pb @ self.b\n---> 42 return x @ W.T + b\n",
+ "RuntimeError: mat1 and mat2 shapes cannot be multiplied (500x12 and 
16x432)" + ] + } + ], + "source": [ + "trainer = pl.Trainer(\n", + " limit_train_batches=BATCH_SIZE, max_epochs=N_EPOCHS, accelerator=\"gpu\"\n", + ")\n", + "trainer.fit(plmodel, train_loader, val_loader)\n", + "trainer.test(plmodel, test_loader)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8bb2dbd8-6b85-4b44-9654-0530687f7822", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.9" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/doublependulum_emlp.py b/examples/doublependulum_emlp.py new file mode 100644 index 0000000..980f5b0 --- /dev/null +++ b/examples/doublependulum_emlp.py @@ -0,0 +1,79 @@ +import sys + +sys.path.append("../") + +import torch +import torch.utils as utils + +import pytorch_lightning as pl + +from torchemlp.utils import DEFAULT_DEVICE, DEFAULT_DEVICE_STR +from torchemlp.reps import Scalar +from torchemlp.nn.utils import AutonomousWrapper +from torchemlp.nn.contdepth import Hamiltonian +from torchemlp.nn.equivariant import EMLP +from torchemlp.nn.runners import DynamicsL2RegressionLightning +from torchemlp.nn.utils import Standardize +from torchemlp.datasets import DoublePendulum + +torch.set_default_dtype(torch.float32) + +TRAINING_SET_SIZE = 5_000 +VALIDATION_SET_SIZE = 1_000 +TEST_SET_SIZE = 1_000 + +DT = 0.1 +T = 30.0 + +BATCH_SIZE = 500 + +N_EPOCHS = 5 # min(int(900_000 / TRAINING_SET_SIZE), 1000) +# N_EPOCHS = 100 # min(int(900_000 / TRAINING_SET_SIZE), 1000) + +# N_CHANNELS = 384 +N_CHANNELS = 3 +N_LAYERS = 3 + +DL_WORKERS = 0 + +dataset = DoublePendulum( + TRAINING_SET_SIZE + VALIDATION_SET_SIZE + TEST_SET_SIZE, + DT, + T, +) + +print(f"Loaded dataset.") +print(f"Repin: {dataset.repin}") +print(f"Repout: {dataset.repout}") + +split_data = utils.data.random_split( + dataset, [TRAINING_SET_SIZE, VALIDATION_SET_SIZE, TEST_SET_SIZE] +) + +train_loader = utils.data.DataLoader( + split_data[0], batch_size=BATCH_SIZE, num_workers=DL_WORKERS, shuffle=True +) +val_loader = utils.data.DataLoader( + split_data[1], batch_size=BATCH_SIZE, num_workers=DL_WORKERS +) +test_loader = utils.data.DataLoader( + split_data[2], batch_size=BATCH_SIZE, num_workers=DL_WORKERS +) + +model = Hamiltonian( + AutonomousWrapper( + Standardize( + EMLP(dataset.repin, Scalar, dataset.G, N_CHANNELS, N_LAYERS), dataset.stats + ) + ) +).to(DEFAULT_DEVICE) + +plmodel = DynamicsL2RegressionLightning(model) + +trainer = pl.Trainer( + limit_train_batches=BATCH_SIZE, + max_epochs=N_EPOCHS, + accelerator=DEFAULT_DEVICE_STR, +) +trainer.fit(plmodel, train_loader, val_loader) +trainer.test(plmodel, test_loader) diff --git a/examples/doublependulum_mlp.py b/examples/doublependulum_mlp.py new file mode 100644 index 0000000..014375d --- /dev/null +++ b/examples/doublependulum_mlp.py @@ -0,0 +1,80 @@ +import sys + +sys.path.append("../") + +import torch +import torch.utils as utils +from torch.profiler import profile, record_function, ProfilerActivity + +import pytorch_lightning as pl + +from torchemlp.utils import DEFAULT_DEVICE, DEFAULT_DEVICE_STR +from torchemlp.nn.utils import MLP, AutonomousWrapper +from torchemlp.nn.contdepth import Hamiltonian +from torchemlp.nn.runners import 
DynamicsL2RegressionLightning +from torchemlp.nn.utils import Standardize +from torchemlp.datasets import DoublePendulum + +# torch.set_default_dtype(torch.float32) + +TRAINING_SET_SIZE = 5_000 +VALIDATION_SET_SIZE = 1_000 +TEST_SET_SIZE = 1_000 + +DT = 0.1 +T = 30.0 + +BATCH_SIZE = 500 + +N_EPOCHS = 1 # min(int(900_000 / TRAINING_SET_SIZE), 1000) +# N_EPOCHS = 100 # min(int(900_000 / TRAINING_SET_SIZE), 1000) + +N_CHANNELS = 384 +N_LAYERS = 3 + +DL_WORKERS = 0 + +dataset = DoublePendulum( + TRAINING_SET_SIZE + VALIDATION_SET_SIZE + TEST_SET_SIZE, + DT, + T, +) + +print(f"Loaded dataset.") +print(f"Repin: {dataset.repin}") +print(f"Repout: {dataset.repout}") + +split_data = utils.data.random_split( + dataset, [TRAINING_SET_SIZE, VALIDATION_SET_SIZE, TEST_SET_SIZE] +) + +train_loader = utils.data.DataLoader( + split_data[0], batch_size=BATCH_SIZE, num_workers=DL_WORKERS, shuffle=True +) +val_loader = utils.data.DataLoader( + split_data[1], batch_size=BATCH_SIZE, num_workers=DL_WORKERS +) +test_loader = utils.data.DataLoader( + split_data[2], batch_size=BATCH_SIZE, num_workers=DL_WORKERS +) + +model = Hamiltonian( + AutonomousWrapper(Standardize(MLP(12, 1, 2, 2), dataset.stats)) +).to(DEFAULT_DEVICE) + +plmodel = DynamicsL2RegressionLightning(model) + +trainer = pl.Trainer( + limit_train_batches=BATCH_SIZE, + max_epochs=N_EPOCHS, + accelerator=DEFAULT_DEVICE_STR, +) + +with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True, profile_memory=True, use_cuda=True) as prof: + with record_function("model_train"): + trainer.fit(plmodel, train_loader, val_loader) + +print(prof.key_averages().table(sort_by="self_cpu_memory_usage", row_limit=10)) +prof.export_chrome_trace("trace.json") + +trainer.test(plmodel, test_loader) diff --git a/examples/experiment.py b/examples/experiment.py new file mode 100644 index 0000000..6f7cc3b --- /dev/null +++ b/examples/experiment.py @@ -0,0 +1,14 @@ +import sys + +sys.path.append("../") + +import torch +import torch.utils as utils + +import pytorch_lightning as pl + +from torchemlp.groups import SO, O, S, Z +from torchemlp.nn.equivariant import EMLP +from torchemlp.nn.runners import RegressionLightning +from torchemlp.nn.utils import Standardize +from torchemlp.datasets import O5Synthetic diff --git a/examples/o5_emlp.ipynb b/examples/o5_emlp.ipynb index 40fcf01..8dcfb70 100644 --- a/examples/o5_emlp.ipynb +++ b/examples/o5_emlp.ipynb @@ -20,7 +20,8 @@ "\n", "from torchemlp.groups import SO, O, S, Z\n", "from torchemlp.nn.equivariant import EMLP\n", - "from torchemlp.nn.utils import RegressionLightning, Standardize\n", + "from torchemlp.nn.runners import FuncMSERegressionLightning\n", + "from torchemlp.nn.utils import Standardize\n", "from torchemlp.datasets import O5Synthetic" ] }, @@ -39,7 +40,8 @@ "\n", "BATCH_SIZE = 500\n", "\n", - "N_EPOCHS = 100 # min(int(900_000 / TRAINING_SET_SIZE), 1000)\n", + "N_EPOCHS = 5 # min(int(900_000 / TRAINING_SET_SIZE), 1000)\n", + "# N_EPOCHS = 100 # min(int(900_000 / TRAINING_SET_SIZE), 1000)\n", "\n", "N_CHANNELS = 384\n", "N_LAYERS = 3\n", @@ -110,7 +112,7 @@ "model = Standardize(\n", " EMLP(dataset.repin, dataset.repout, dataset.G, N_CHANNELS, N_LAYERS), dataset.stats\n", ").cuda()\n", - "plmodel = RegressionLightning(model)" + "plmodel = FuncMSERegressionLightning(model)" ] }, { @@ -129,8 +131,6 @@ "TPU available: False, using: 0 TPU cores\n", "IPU available: False, using: 0 IPUs\n", "HPU available: False, using: 0 HPUs\n", - 
"/home/eli/anaconda3/envs/l2e10/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:67: UserWarning: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default\n", - " warning_cache.warn(\n", "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", "\n", " | Name | Type | Params\n", @@ -146,7 +146,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "73ee9907d4b64e29af5283759b5baafd", + "model_id": "", "version_major": 2, "version_minor": 0 }, @@ -161,46 +161,157 @@ "name": "stderr", "output_type": "stream", "text": [ - "/home/eli/anaconda3/envs/l2e10/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning: The dataloader, val_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 8 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n", + "/home/rytse/mambaforge/envs/l2e10/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning: The dataloader, val_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 12 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n", + " rank_zero_warn(\n", + "/home/rytse/Documents/research/torchemlp/examples/../torchemlp/reps/reps_base.py:1381: UserWarning: operator() sees varying value in profiling, ignoring and this should be handled by GUARD logic (Triggered internally at ../third_party/nvfuser/csrc/parser.cpp:3777.)\n", + " return lazy_proj_jit(\n", + "/home/rytse/Documents/research/torchemlp/examples/../torchemlp/reps/reps_base.py:1381: UserWarning: operator() profile_node %119 : int[] = prim::profile_ivalue[profile_failed=\"varying profile values\"](%117)\n", + " does not have profile information (Triggered internally at ../third_party/nvfuser/csrc/graph_fuser.cpp:104.)\n", + " return lazy_proj_jit(\n", + "/home/rytse/mambaforge/envs/l2e10/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 12 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n", + " rank_zero_warn(\n", + "/home/rytse/mambaforge/envs/l2e10/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:1609: PossibleUserWarning: The number of training batches (10) is smaller than the logging interval Trainer(log_every_n_steps=50). 
@@ -129,8 +131,6 @@
     "TPU available: False, using: 0 TPU cores\n",
     "IPU available: False, using: 0 IPUs\n",
     "HPU available: False, using: 0 HPUs\n",
-    "/home/eli/anaconda3/envs/l2e10/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:67: UserWarning: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default\n",
-    "  warning_cache.warn(\n",
     "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
     "\n",
     "  | Name | Type | Params\n",
@@ -146,7 +146,7 @@
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
-      "model_id": "73ee9907d4b64e29af5283759b5baafd",
+      "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
@@ -161,46 +161,157 @@
     "name": "stderr",
     "output_type": "stream",
     "text": [
-     "/home/eli/anaconda3/envs/l2e10/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning: The dataloader, val_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 8 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n",
+     "/home/rytse/mambaforge/envs/l2e10/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning: The dataloader, val_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 12 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n",
+     "  rank_zero_warn(\n",
+     "/home/rytse/Documents/research/torchemlp/examples/../torchemlp/reps/reps_base.py:1381: UserWarning: operator() sees varying value in profiling, ignoring and this should be handled by GUARD logic (Triggered internally at ../third_party/nvfuser/csrc/parser.cpp:3777.)\n",
+     "  return lazy_proj_jit(\n",
+     "/home/rytse/Documents/research/torchemlp/examples/../torchemlp/reps/reps_base.py:1381: UserWarning: operator() profile_node %119 : int[] = prim::profile_ivalue[profile_failed=\"varying profile values\"](%117)\n",
+     "  does not have profile information (Triggered internally at ../third_party/nvfuser/csrc/graph_fuser.cpp:104.)\n",
+     "  return lazy_proj_jit(\n",
+     "/home/rytse/mambaforge/envs/l2e10/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 12 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n",
+     "  rank_zero_warn(\n",
+     "/home/rytse/mambaforge/envs/l2e10/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:1609: PossibleUserWarning: The number of training batches (10) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.\n",
     "  rank_zero_warn(\n"
    ]
   },
   {
-    "ename": "RuntimeError",
-    "evalue": "The following operation failed in the TorchScript interpreter.\nTraceback of TorchScript (most recent call last):\n  File \"torchemlp/reps/reps_base.py\", line 1398, in helper\n    bilinear_params = params[i:i_end].reshape(W_mult, n)\nRuntimeError: shape '[8200, 5]' is invalid for input of size 0\n",
-    "output_type": "error",
-    "traceback": [
-     [... removed cell output: the full Trainer.fit() traceback through the
-      sanity-check validation_step -> Standardize.forward -> EMLP.forward ->
-      EMLPBlock.forward -> EquivariantBiLinear.forward ->
-      reps_base.bilinear_weights/lazy_proj_fast, ending in the TorchScript
-      RuntimeError quoted above ...]
-    ]
+    "data": {
+     "application/vnd.jupyter.widget-view+json": {
+      "model_id": "dab305829cb4490886a0dac0e975dd3c",
+      "version_major": 2,
+      "version_minor": 0
+     },
+     "text/plain": [
+      "Training: 0it [00:00, ?it/s]"
+     ]
+    },
+    "metadata": {},
+    "output_type": "display_data"
   },
+   [... five further "Validation: 0it [00:00, ?it/s]" progress-widget outputs, identical apart from an empty model_id ...]
+   {
+    "name": "stderr",
+    "output_type": "stream",
+    "text": [
+     "`Trainer.fit` stopped: `max_epochs=5` reached.\n",
+     "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
+     "/home/rytse/mambaforge/envs/l2e10/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning: The dataloader, test_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 12 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n",
+     "  rank_zero_warn(\n"
+    ]
+   },
+   {
+    "data": {
+     "application/vnd.jupyter.widget-view+json": {
+      "model_id": "9a6545ba2c464333aa7ecb5e130c1f7d",
+      "version_major": 2,
+      "version_minor": 0
+     },
+     "text/plain": [
+      "Testing: 0it [00:00, ?it/s]"
+     ]
+    },
+    "metadata": {},
+    "output_type": "display_data"
+   },
+   {
+    "data": {
+     "text/html": [
+       "<pre>┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
+       "┃        Test metric               DataLoader 0        ┃\n",
+       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
+       "│         test_loss             35.857906341552734     │\n",
+       "└───────────────────────────┴───────────────────────────┘\n",
+       "
\n" + ], + "text/plain": [ + "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1m Test metric \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m DataLoader 0 \u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│\u001b[36m \u001b[0m\u001b[36m test_loss \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 35.857906341552734 \u001b[0m\u001b[35m \u001b[0m│\n", + "└───────────────────────────┴───────────────────────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "[{'test_loss': 35.857906341552734}]" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ @@ -236,7 +347,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.9" + "version": "3.10.11" } }, "nbformat": 4, diff --git a/torchemlp/datasets.py b/torchemlp/datasets.py index 332d312..617ce13 100644 --- a/torchemlp/datasets.py +++ b/torchemlp/datasets.py @@ -1,3 +1,4 @@ +from typing import Callable from abc import ABC, abstractmethod import torch @@ -5,9 +6,12 @@ import functorch -from torchemlp.groups import Group, O, SO +from torchemlp.groups import Z, Group, O, SO, O2eR3 from torchemlp.reps import Rep, Vector, Scalar, T -from torchemlp.utils import DEFAULT_DEVICE +from torchemlp.utils import DEFAULT_DEVICE, unpack_hsys +from torchemlp.nn.contdepth import hamiltonian_dynamics + +from torchdiffeq import odeint class GroupAugmentation(nn.Module): @@ -46,8 +50,8 @@ def __init__( G: Group, repin: Rep, repout: Rep, - X: torch.Tensor, - Y: torch.Tensor, + X: torch.Tensor | tuple[torch.Tensor, ...], + Y: torch.Tensor | tuple[torch.Tensor, ...], stats: list[float], ): self.G = G @@ -58,11 +62,22 @@ def __init__( self.Y = Y self.stats = stats - def __getitem__(self, i): + def __getitem__( + self, i + ) -> ( + tuple[torch.Tensor, torch.Tensor] + | tuple[tuple[torch.Tensor, ...], torch.Tensor] + | tuple[torch.Tensor, tuple[torch.Tensor, ...]] + | tuple[tuple[torch.Tensor, ...], tuple[torch.Tensor]] + ): return (self.X[i], self.Y[i]) - def __len__(self): - return self.X.shape[0] + def __len__(self) -> int: + match self.X: + case torch.Tensor(): + return self.X.shape[0] + case tuple(): + return self.X[0].shape[0] def default_aug(self, model): return GroupAugmentation(model, self.repin, self.repout, self.G) @@ -172,3 +187,102 @@ def __init__(self, N=1024, device: torch.device = DEFAULT_DEVICE): stats = [X.mean(0), X.std(dim=0), Y.mean(dim=0), Y.std(dim=0)] super().__init__(d, G, repin, repout, X, Y, stats) + + +class DynamicsDataset(Dataset): + def __init__( + self, + dynamics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], + n_traj: int, + dt: float, + T: float, + dim: int, + G: Group, + repin: Rep, + repout: Rep, + device: torch.device = DEFAULT_DEVICE, + ): + self.dynamics = dynamics + self.device = device + + self.zs, self.ts = self.gen_trajs(n_traj, dt, T) + + X = (self.zs[:, 0, ...], self.ts) + Y = self.zs + + stats = [0.0, 1.0, 0.0, 1.0] + + super().__init__(dim, G, repin, repout, X, Y, stats) + + @abstractmethod + def sample_ics(self, n_traj: int): + raise NotImplementedError + + def gen_trajs( + self, n_traj: int, dt: float, T: float + ) -> tuple[torch.Tensor, torch.Tensor]: + z0 = self.sample_ics(n_traj) + ts = torch.arange(0.0, T, dt, device=self.device) + sol = torch.swapaxes( + odeint( + self.dynamics, + z0, + ts, + options={"dtype": torch.float32}, + ), + 0, + 1, + 
diff --git a/torchemlp/datasets.py b/torchemlp/datasets.py
index 332d312..617ce13 100644
--- a/torchemlp/datasets.py
+++ b/torchemlp/datasets.py
@@ -1,3 +1,4 @@
+from typing import Callable
 from abc import ABC, abstractmethod
 
 import torch
@@ -5,9 +6,12 @@
 import functorch
 
-from torchemlp.groups import Group, O, SO
+from torchemlp.groups import Z, Group, O, SO, O2eR3
 from torchemlp.reps import Rep, Vector, Scalar, T
-from torchemlp.utils import DEFAULT_DEVICE
+from torchemlp.utils import DEFAULT_DEVICE, unpack_hsys
+from torchemlp.nn.contdepth import hamiltonian_dynamics
+
+from torchdiffeq import odeint
 
 
 class GroupAugmentation(nn.Module):
@@ -46,8 +50,8 @@ def __init__(
         G: Group,
         repin: Rep,
         repout: Rep,
-        X: torch.Tensor,
-        Y: torch.Tensor,
+        X: torch.Tensor | tuple[torch.Tensor, ...],
+        Y: torch.Tensor | tuple[torch.Tensor, ...],
         stats: list[float],
     ):
         self.G = G
@@ -58,11 +62,22 @@ def __init__(
         self.Y = Y
         self.stats = stats
 
-    def __getitem__(self, i):
+    def __getitem__(
+        self, i
+    ) -> (
+        tuple[torch.Tensor, torch.Tensor]
+        | tuple[tuple[torch.Tensor, ...], torch.Tensor]
+        | tuple[torch.Tensor, tuple[torch.Tensor, ...]]
+        | tuple[tuple[torch.Tensor, ...], tuple[torch.Tensor, ...]]
+    ):
         return (self.X[i], self.Y[i])
 
-    def __len__(self):
-        return self.X.shape[0]
+    def __len__(self) -> int:
+        match self.X:
+            case torch.Tensor():
+                return self.X.shape[0]
+            case tuple():
+                return self.X[0].shape[0]
 
     def default_aug(self, model):
         return GroupAugmentation(model, self.repin, self.repout, self.G)
@@ -172,3 +187,102 @@
     stats = [X.mean(0), X.std(dim=0), Y.mean(dim=0), Y.std(dim=0)]
 
     super().__init__(d, G, repin, repout, X, Y, stats)
+
+
+class DynamicsDataset(Dataset):
+    def __init__(
+        self,
+        dynamics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
+        n_traj: int,
+        dt: float,
+        T: float,
+        dim: int,
+        G: Group,
+        repin: Rep,
+        repout: Rep,
+        device: torch.device = DEFAULT_DEVICE,
+    ):
+        self.dynamics = dynamics
+        self.device = device
+
+        self.zs, self.ts = self.gen_trajs(n_traj, dt, T)
+
+        X = (self.zs[:, 0, ...], self.ts)
+        Y = self.zs
+
+        stats = [0.0, 1.0, 0.0, 1.0]
+
+        super().__init__(dim, G, repin, repout, X, Y, stats)
+
+    @abstractmethod
+    def sample_ics(self, n_traj: int):
+        raise NotImplementedError
+
+    def gen_trajs(
+        self, n_traj: int, dt: float, T: float
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        z0 = self.sample_ics(n_traj)
+        ts = torch.arange(0.0, T, dt, device=self.device)
+        sol = torch.swapaxes(
+            odeint(
+                self.dynamics,
+                z0,
+                ts,
+                options={"dtype": torch.float32},
+            ),
+            0,
+            1,
+        ).detach()
+        return sol, ts
+
+    def __getitem__(self, i: int) -> tuple[tuple[torch.Tensor, ...], torch.Tensor]:
+        return (self.zs[i, 0, ...], self.ts), self.zs[i, ...]
+
+
+class DoublePendulum(DynamicsDataset):
+    def __init__(
+        self,
+        n_traj: int,
+        dt: float,
+        dur: float,
+        device: torch.device = DEFAULT_DEVICE,
+    ):
+        dynamics = lambda t, z: hamiltonian_dynamics(self.H, z, t)
+        G = O2eR3()
+        # base_rep = Vector(G)
+        repin = 4 * T(1)
+        repout = 4 * T(1)
+        # repout = T(0)
+        # repin = 4 * (base_rep**1 + base_rep.T**0)
+        # repout = 4 * (base_rep**0 + base_rep.T**0)
+        dim = 12
+        super().__init__(dynamics, n_traj, dt, dur, dim, G, repin, repout, device)
+
+    def H(self, _: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
+        # Two masses on springs in 3D under gravity; all constants set to 1
+        g, m1, m2, k1, k2, l1, l2 = 1, 1, 1, 1, 1, 1, 1
+
+        q, p = unpack_hsys(z)
+        q1, q2 = unpack_hsys(q)
+        p1, p2 = unpack_hsys(p)
+
+        ke = 1.0 / (2.0 * m1) * torch.square(torch.norm(p1, dim=-1)) + 1.0 / (
+            2.0 * m2
+        ) * torch.square(torch.norm(p2, dim=-1))
+        pe = (
+            k1 / 2.0 * torch.square(torch.norm(q1, dim=-1) - l1)
+            + k2 * torch.square(torch.norm(q1 - q2, dim=-1) - l2)
+            + m1 * g * q1[..., 2]
+            + m2 * g * q2[..., 2]
+        )
+
+        return ke + pe
+
+    def sample_ics(self, bs: int) -> torch.Tensor:
+        x1 = torch.tensor([0.0, 0.0, -1.5], device=self.device) + 0.2 * torch.randn(
+            bs, 3, device=self.device
+        )
+        x2 = torch.tensor([0.0, 0.0, -3.0], device=self.device) + 0.2 * torch.randn(
+            bs, 3, device=self.device
+        )
+        p = 0.4 * torch.randn(bs, 6, device=self.device)
+        return torch.cat([x1, x2, p], dim=1)
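One DynamicsDataset item is a whole trajectory, ((z0, ts), zs), rather than an (x, y) pair. A quick shape check, assuming the double-pendulum constants used elsewhere in this PR (dt=0.1, T=30.0 gives a 300-step time grid over the 12-dimensional state):

    from torchemlp.datasets import DoublePendulum

    ds = DoublePendulum(16, 0.1, 30.0)     # 16 trajectories
    (z0, ts), zs = ds[0]
    print(z0.shape)   # torch.Size([12])      initial (q, p) state
    print(ts.shape)   # torch.Size([300])     shared time grid
    print(zs.shape)   # torch.Size([300, 12]) full trajectory
    print(len(ds))    # 16, via the tuple branch of Dataset.__len__
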
diff --git a/torchemlp/nn/contdepth.py b/torchemlp/nn/contdepth.py
new file mode 100644
index 0000000..9cc11fb
--- /dev/null
+++ b/torchemlp/nn/contdepth.py
@@ -0,0 +1,50 @@
+from typing import Callable
+
+import torch
+import torch.nn as nn
+from torch.autograd import grad
+import torch.autograd.functional as F
+
+from torchemlp.nn.equivariant import EMLP
+
+
+class Hamiltonian(nn.Module):
+    def __init__(self, H: nn.Module):
+        super().__init__()
+        self.H = H
+
+    def forward(self, t: torch.Tensor, z: torch.Tensor):
+        return hamiltonian_dynamics(self.H, z, t)
+
+
+def hamiltonian_dynamics(
+    H: nn.Module,
+    z: torch.Tensor,
+    t: torch.Tensor,
+) -> torch.Tensor:
+    """
+    Given an nn.Module representing the Hamiltonian of a system, compute the
+    dynamics function z' = J∇H(z).
+    """
+    d = z.shape[-1] // 2
+
+    z.requires_grad_(True)
+    t.requires_grad_(True)
+
+    # Symplectic form J = [[0, I], [-I, 0]]
+    J = torch.zeros((2 * d, 2 * d)).to(z)
+    J[:d, d:] = torch.eye(d)
+    J[d:, :d] = -torch.eye(d)
+    J = J.requires_grad_(True)
+
+    with torch.enable_grad():
+        H_val = H(t, z)
+
+    dHdz = grad(
+        H_val,
+        z,
+        grad_outputs=torch.ones_like(H_val),
+        create_graph=True,
+        retain_graph=True,
+    )[0]
+
+    # Batched z' = J∇H(z): for row-vector states, (J @ dHdzᵀ)ᵀ = dHdz @ Jᵀ
+    return torch.matmul(dHdz, J.t())
diff --git a/torchemlp/nn/equivariant.py b/torchemlp/nn/equivariant.py
index 4b95f6a..a49881e 100644
--- a/torchemlp/nn/equivariant.py
+++ b/torchemlp/nn/equivariant.py
@@ -324,5 +324,5 @@ def __init__(self, rep_in: Rep, rep_out: Rep, group: Group, ch=384, num_layers=3
         reps = [self.rep_in] + middle_layers
         self.network = torch.nn.Sequential(
             *[EMLPBlock(rin, rout) for rin, rout in zip(reps, reps[1:])],
-            torch.nn.Linear(reps[-1], self.rep_out)
+            torch.nn.Linear(reps[-1], self.rep_out),
         )
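A quick sanity check for hamiltonian_dynamics: for the 1-DOF harmonic oscillator H(q, p) = (q² + p²)/2, the symplectic gradient is (q', p') = (p, −q). A minimal sketch (the SHO module is a stand-in for illustration, not part of this PR):

    import torch
    from torchemlp.nn.contdepth import hamiltonian_dynamics

    class SHO(torch.nn.Module):
        # H(q, p) = (q^2 + p^2) / 2 for a single degree of freedom
        def forward(self, t, z):
            return 0.5 * (z ** 2).sum(dim=-1, keepdim=True)

    z = torch.tensor([[1.0, 0.0]])            # q = 1, p = 0
    t = torch.zeros(1)
    dz = hamiltonian_dynamics(SHO(), z, t)
    print(dz)                                 # expected [[0., -1.]], i.e. q' = p, p' = -q
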
+ """ + + def __init__(self, model: nn.Module): + super(AutonomousWrapper, self).__init__() self.model = model - self.lr = lr - self.weight_decay = weight_decay - - def training_step(self, batch, batch_idx): - x, y = batch - y_pred = self.model(x) - loss = F.mse_loss(y_pred, y) - self.log("train_loss", loss) - - return loss - - def test_step(self, batch, batch_idx): - x, y = batch - y_pred = self.model(x) - loss = F.mse_loss(y_pred, y) - self.log("test_loss", loss.item()) - - def validation_step(self, batch, batch_idx): - x, y = batch - y_pred = self.model(x) - loss = F.mse_loss(y_pred, y) - self.log("val_loss", loss) - - def configure_optimizers(self): - optimizer = optim.Adam( - self.parameters(), lr=self.lr, weight_decay=self.weight_decay - ) - return optimizer + + def forward(self, _: torch.Tensor, z: torch.Tensor) -> torch.Tensor: + return self.model(z) diff --git a/torchemlp/reps/reps_base.py b/torchemlp/reps/reps_base.py index 2fc9d26..efe4e5c 100644 --- a/torchemlp/reps/reps_base.py +++ b/torchemlp/reps/reps_base.py @@ -807,7 +807,9 @@ def rep_permutation( """ size_cumsums = [] for repsizes in repsizes_all: - padded_sizes = torch.tensor([0] + [size for size in repsizes], device=device) + padded_sizes = torch.tensor( + [0] + [size for size in repsizes], dtype=torch.int32, device=device + ) size_cumsums.append(torch.cumsum(padded_sizes, 0)) permutation = torch.zeros( [cumsum[-1] for cumsum in size_cumsums], dtype=torch.int, device=device @@ -1404,7 +1406,6 @@ def lazy_proj_jit( REP_SIZES: list[int], W_MULT: list[int], ): - W = torch.zeros((bs, W_alloc_size), device=device) for i, ((i_start, i_end), (w_start, w_end), n, rep_size, w_mult) in enumerate( zip(I_REGIONS, W_REGIONS, N, REP_SIZES, W_MULT) diff --git a/torchemlp/utils.py b/torchemlp/utils.py index a3d5de8..573e614 100644 --- a/torchemlp/utils.py +++ b/torchemlp/utils.py @@ -1,9 +1,11 @@ import matplotlib.pyplot as plt import torch +import torch.backends -DEFAULT_DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") +DEFAULT_DEVICE_STR = "cuda" if torch.cuda.is_available() else "cpu" +DEFAULT_DEVICE = torch.device(DEFAULT_DEVICE_STR) def merge_torch_types(dtype1, dtype2, device: torch.device = DEFAULT_DEVICE): @@ -70,3 +72,14 @@ def binom(n, k): n -= 1 return int(b) + + +def unpack_hsys(z: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + D = z.shape[-1] + assert D % 2 == 0 + d = D // 2 + + q = z[..., :d] + p = z[..., d:] + + return q, p