diff --git a/examples/example-5.5-oc-jax.ipynb b/examples/example-5.5-oc-jax.ipynb
new file mode 100644
index 00000000..f0ac5017
--- /dev/null
+++ b/examples/example-5.5-oc-jax.ipynb
@@ -0,0 +1,447 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import jax\n",
+    "import jax.numpy as jnp\n",
+    "import numpy as np\n",
+    "from neurolib.models.jax.wc import WCModel\n",
+    "from neurolib.models.jax.wc.timeIntegration import timeIntegration_args, timeIntegration_elementwise\n",
+    "from neurolib.optimize.autodiff.wc_optimizer import args_names\n",
+    "\n",
+    "import logging"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "model = WCModel()\n",
+    "\n",
+    "model.params.duration = 203\n",
+    "model.params.sigma_ou = 0"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def update_control_with_limit(N, dim_in, T, control, step, gradient, u_max):\n",
+    "    # Simplified update: a plain gradient step. N, dim_in, T and u_max are\n",
+    "    # kept for signature compatibility but are unused, i.e. no maximum\n",
+    "    # control strength is enforced here.\n",
+    "    return control + step * gradient"
+   ]
+  },
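+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Hypothetical variant (not used below): a minimal sketch of how the update\n",
+    "# could respect a maximum control strength u_max via element-wise clipping.\n",
+    "# The exact limiting behaviour of neurolib's numpy implementation may differ.\n",
+    "def update_control_with_limit_clipped(N, dim_in, T, control, step, gradient, u_max):\n",
+    "    updated = control + step * gradient\n",
+    "    if u_max is not None and u_max > 0:\n",
+    "        updated = jnp.clip(updated, -u_max, u_max)\n",
+    "    return updated"
+   ]
+  },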
+ " :rtype: float\n", + " \"\"\"\n", + " accuracy_cost = self.accuracy_cost(exc)\n", + " control_strength_cost = self.control_strength_cost(control)\n", + " return accuracy_cost + control_strength_cost\n", + " \n", + " def optimize_deterministic(self, n_max_iterations):\n", + " \"\"\"Compute the optimal control signal for noise averaging method 0 (deterministic, M=1).\n", + "\n", + " :param n_max_iterations: maximum number of iterations of gradient descent\n", + " :type n_max_iterations: int\n", + " \"\"\"\n", + "\n", + " # (I) forward simulation\n", + " t, exc, inh, exc_ou, inh_ou = self.simulate(self.control) # yields x(t)\n", + "\n", + " cost = self.compute_total_cost(self.control, exc)\n", + " print(f\"Cost in iteration 0: %s\" % (cost))\n", + " if len(self.cost_history) == 0: # add only if control model has not yet been optimized\n", + " self.cost_history.append(cost)\n", + "\n", + " for i in range(1, n_max_iterations + 1):\n", + " self.gradient = self.compute_gradient(self.control)\n", + "\n", + " self.step_size(-self.gradient)\n", + " t, exc, inh, exc_ou, inh_ou = self.simulate(self.control)\n", + "\n", + " cost = self.compute_total_cost(self.control, exc)\n", + " if i in self.print_array:\n", + " print(f\"Cost in iteration %s: %s\" % (i, cost))\n", + " self.cost_history.append(cost)\n", + "\n", + " if self.zero_step_encountered:\n", + " print(f\"Converged in iteration %s with cost %s\" % (i, cost))\n", + " break\n", + "\n", + " print(f\"Final cost : %s\" % (cost))\n", + "\n", + " def step_size(self, cost_gradient):\n", + " \"\"\"Adaptively choose a step size for control update.\n", + "\n", + " :param cost_gradient: N x V x T gradient of the total cost wrt. control.\n", + " :type cost_gradient: np.ndarray\n", + "\n", + " :return: Step size that got multiplied with the 'cost_gradient'.\n", + " :rtype: float\n", + " \"\"\"\n", + " if self.M > 1:\n", + " noisy = True\n", + " else:\n", + " noisy = False\n", + "\n", + " t, exc, inh, exc_ou, inh_ou = self.simulate(self.control)\n", + " if noisy:\n", + " cost0 = self.compute_cost_noisy(self.M)\n", + " else:\n", + " cost0 = (\n", + " self.compute_total_cost(self.control, exc)\n", + " ) # Current cost without updating the control according to the \"cost_gradient\".\n", + "\n", + " step = self.step # Load step size of last optimization-iteration as initial guess.\n", + "\n", + " control0 = self.control # Memorize unchanged control throughout step-size computation.\n", + "\n", + " while True: # Reduce the step size, if numerical instability occurs in the forward-simulation.\n", + " # inplace updating of models control bc. 
+  {
+   "cell_type": "code",
+   "execution_count": 70,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "args_values = timeIntegration_args(model.params)\n",
+    "\n",
+    "args = dict(zip(args_names, args_values))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 71,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "ones_target = jnp.ones_like(args['exc_ext'], dtype=float)"
+   ]
+  },
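+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Sanity check (illustrative): the control and the target share the shape of\n",
+    "# the external input 'exc_ext' (nodes x time steps).\n",
+    "print(args['exc_ext'].shape, ones_target.shape)"
+   ]
+  },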
+  {
+   "cell_type": "code",
+   "execution_count": 72,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "oc_wc = OcWc(model, ones_target)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 73,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Cost in iteration 0: 99.194\n",
+      "Final cost: 25.754984\n"
+     ]
+    }
+   ],
+   "source": [
+    "oc_wc.optimize_deterministic(10)"
+   ]
+  },
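+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Illustrative check of the optimization result: simulate with the optimized\n",
+    "# control and compare the excitatory activity to the target.\n",
+    "t, exc, inh, exc_ou, inh_ou = oc_wc.simulate(oc_wc.control)\n",
+    "print(\"mean absolute deviation from target:\", jnp.mean(jnp.abs(exc - oc_wc.target)))"
+   ]
+  },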
+  {
+   "cell_type": "code",
+   "execution_count": 12,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Write the optimized control back into a local copy of the integration arguments.\n",
+    "args_local = args.copy()\n",
+    "args_local.update(dict(zip(['exc_ext'], [oc_wc.control])))"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "neurolib_jax",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.12.7"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}