From a95b334249c063530887c1e4153092c61643de6c Mon Sep 17 00:00:00 2001 From: Deepak CH Date: Fri, 2 Aug 2024 07:54:38 +0000 Subject: [PATCH] Fix Minibatch alignment in Bayesian Neural Network example --- .../bayesian_neural_network_advi.ipynb | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/examples/variational_inference/bayesian_neural_network_advi.ipynb b/examples/variational_inference/bayesian_neural_network_advi.ipynb index 731de639a..bc46b320a 100644 --- a/examples/variational_inference/bayesian_neural_network_advi.ipynb +++ b/examples/variational_inference/bayesian_neural_network_advi.ipynb @@ -190,7 +190,7 @@ }, "outputs": [], "source": [ - "def construct_nn(ann_input, ann_output):\n", + "def construct_nn():\n", " n_hidden = 5\n", "\n", " # Initialize random weights between each layer\n", @@ -204,9 +204,14 @@ " \"train_cols\": np.arange(X_train.shape[1]),\n", " \"obs_id\": np.arange(X_train.shape[0]),\n", " }\n", + " \n", " with pm.Model(coords=coords) as neural_network:\n", - " ann_input = pm.Data(\"ann_input\", X_train, dims=(\"obs_id\", \"train_cols\"))\n", - " ann_output = pm.Data(\"ann_output\", Y_train, dims=\"obs_id\")\n", + " # Define minibatch variables\n", + " minibatch_x, minibatch_y = pm.Minibatch(X_train, Y_train, batch_size=50)\n", + " \n", + " # Define data variables using minibatches\n", + " ann_input = pm.Data(\"ann_input\", minibatch_x, mutable=True, dims=(\"obs_id\", \"train_cols\"))\n", + " ann_output = pm.Data(\"ann_output\", minibatch_y, mutable=True, dims=\"obs_id\")\n", "\n", " # Weights from input to hidden layer\n", " weights_in_1 = pm.Normal(\n", @@ -231,13 +236,13 @@ " \"out\",\n", " act_out,\n", " observed=ann_output,\n", - " total_size=Y_train.shape[0], # IMPORTANT for minibatches\n", + " total_size=X_train.shape[0], # IMPORTANT for minibatches\n", " dims=\"obs_id\",\n", " )\n", " return neural_network\n", "\n", - "\n", - "neural_network = construct_nn(X_train, Y_train)" + "# Create the neural network model\n", + "neural_network = construct_nn()\n" ] }, {