From f2a01ebfc26b3e08aa2c0f7cd94acf62be12296d Mon Sep 17 00:00:00 2001 From: The swirl_dynamics Authors Date: Fri, 23 Feb 2024 15:57:28 -0800 Subject: [PATCH] Code update PiperOrigin-RevId: 609860237 --- .../projects/evolve_smoothly/README.md | 2 +- .../weno_nn/colabs/init_network.ipynb | 163 ++++++++++++++++++ 2 files changed, 164 insertions(+), 1 deletion(-) create mode 100644 swirl_dynamics/projects/weno_nn/colabs/init_network.ipynb diff --git a/swirl_dynamics/projects/evolve_smoothly/README.md b/swirl_dynamics/projects/evolve_smoothly/README.md index 959e846..9ab4065 100644 --- a/swirl_dynamics/projects/evolve_smoothly/README.md +++ b/swirl_dynamics/projects/evolve_smoothly/README.md @@ -2,7 +2,7 @@ ## Abstract -We present a data-driven, space-time continuous framework to learn surrogate models for complex physical systems described by advection-dominated partial differential equations. Those systems have slow-decaying Kolmogorov $n$-width that hinders standard methods, including reduced order modeling, from producing high-fidelity simulations at low cost. In this work, we construct hypernetwork-based latent dynamical models directly on the parameter space of a compact representation network. We leverage the expressive power of the network and a specially designed consistency-inducing regularization to obtain latent trajectories that are both low-dimensional and smooth. These properties render our surrogate models highly efficient at inference time. We show the efficacy of our framework by learning models that generate accurate multi-step rollout predictions at much faster inference speed compared to competitors, for several challenging examples. +We present a data-driven, space-time continuous framework to learn surrogate models for complex physical systems described by advection-dominated partial differential equations. Those systems have slow-decaying Kolmogorov $$n$$-width that hinders standard methods, including reduced order modeling, from producing high-fidelity simulations at low cost. In this work, we construct hypernetwork-based latent dynamical models directly on the parameter space of a compact representation network. We leverage the expressive power of the network and a specially designed consistency-inducing regularization to obtain latent trajectories that are both low-dimensional and smooth. These properties render our surrogate models highly efficient at inference time. We show the efficacy of our framework by learning models that generate accurate multi-step rollout predictions at much faster inference speed compared to competitors, for several challenging examples. ## Datasets diff --git a/swirl_dynamics/projects/weno_nn/colabs/init_network.ipynb b/swirl_dynamics/projects/weno_nn/colabs/init_network.ipynb new file mode 100644 index 0000000..c41da74 --- /dev/null +++ b/swirl_dynamics/projects/weno_nn/colabs/init_network.ipynb @@ -0,0 +1,163 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "OYIaUeqU_GoF" + }, + "outputs": [], + "source": [ + "!pip install git+https://github.com/google-research/swirl-dynamics.git@main" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "t-c9iNBAy75u" + }, + "outputs": [], + "source": [ + "#@title Imports\n", + "from typing import Any, Literal, Optional\n", + "import flax.linen as nn\n", + "import functools\n", + "import jax\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "jax.config.update(\"jax_enable_x64\", True)\n", + "from swirl_dynamics.data import hdf5_utils\n", + "from swirl_dynamics.lib.networks import rational_networks\n", + "from swirl_dynamics.projects.weno_nn import weno\n", + "from swirl_dynamics.projects.weno_nn import weno_nn\n", + "from swirl_dynamics.projects.weno_nn import utils" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "kF-DZ55uzdp6" + }, + "outputs": [], + "source": [ + "flax_model_main_folder = '../model_weights/'\n", + "xid=94741459\n", + "model_num=113\n", + "filename = flax_model_main_folder+f'/xid_{xid}_model_{model_num}.hdf5'\n", + "network_vars = hdf5_utils.read_all_arrays_as_dict(filename)\n", + "mlp_model = weno_nn.OmegaNN(\n", + " features=tuple(network_vars['config']['features'].astype(int)),\n", + " features_fun=utils.get_feature_func(network_vars['config']['features_fun'].decode()),\n", + " act_fun=utils.get_act_func(network_vars['config']['act_fun'].decode()),\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "dLBi5s2GbwOz" + }, + "outputs": [], + "source": [ + "x=np.linspace(0.0, 1.0, 101)\n", + "u=np.sin(np.pi*x)\n", + "# Stack neighbor information for [u_{i-1}, u_{i}, u_{i+1}].\n", + "u_nb=np.stack([u[:-2], u[1:-1], u[2:]], axis=1)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "tC_fW3IicqHV" + }, + "outputs": [], + "source": [ + "# Individual functions are written over scalar inputs of\n", + "# [u_{i-1}, u_{i}, u_{i+1}]. Hence we vectorize over the axis for u_nb above.\n", + "# Function to perform interpolation:\n", + "weno_interp_func_vmap = jax.vmap(weno.weno_interpolation, in_axes=(0, None))\n", + "model_apply_func = functools.partial(mlp_model.apply, test=True)\n", + "# Function to calculate WENO-weights:\n", + "weno_nn_wt_func_vmap = jax.vmap(model_apply_func, in_axes=(None,0))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "lK3lUqPmd1hp" + }, + "outputs": [], + "source": [ + "# Estimate WENO-weights on the negative side:\n", + "wt_neg = weno_nn_wt_func_vmap({\"params\": network_vars[\"params\"]}, u_nb)\n", + "# Perform WENO interpolation on both positive and negative sides.\n", + "u_interp = weno_interp_func_vmap(\n", + " u_nb,\n", + " lambda x, params: model_apply_func({\"params\": network_vars[\"params\"]}, x),\n", + ")\n", + "# Unstack the positive and negative side interpolations.\n", + "u_interp_pos = u_interp[:, 0]\n", + "u_interp_neg = u_interp[:, 1]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "2U3qVwNSeGE0" + }, + "outputs": [], + "source": [ + "plt.figure()\n", + "plt.plot(x[1:-1], wt_neg[:,0]); plt.ylim([0.0,1.0]);\n", + "plt.plot(x[1:-1], wt_neg[:,1]); plt.ylim([0.0,1.0]);\n", + "plt.xlabel('X'); plt.ylabel('WENO Weights');" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "cXdHHtYVe5Jy" + }, + "outputs": [], + "source": [ + "x_half = x+(x[1]-x[0])*0.5\n", + "plt.figure()\n", + "plt.plot(x_half[1:-1], u_interp_pos, '-.b', label='Pos');\n", + "plt.plot(x_half[1:-1], u_interp_neg, '-.g', label='Neg');\n", + "plt.plot(x, u, '--r', label='Cell');\n", + "plt.xlabel('X'); plt.ylabel('WENO Weight'); plt.legend();" + ] + } + ], + "metadata": { + "colab": { + "last_runtime": { + "build_target": "//research/simulation/tools:notebook", + "kind": "shared" + }, + "provenance": [ + { + "file_id": "12h2QSHNj9FUYRqR6axIbK9dFZoQ29-IH", + "timestamp": 1707952544896 + } + ] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +}