diff --git a/mnist_from_scratch.ipynb b/mnist_from_scratch.ipynb index a07fe17..df182fa 100644 --- a/mnist_from_scratch.ipynb +++ b/mnist_from_scratch.ipynb @@ -325,7 +325,7 @@ " x_relu = np.maximum(x_l1, 0)\n", " x_l2 = x_relu.dot(l2)\n", " x_lsm = x_l2 - logsumexp(x_l2).reshape((-1, 1))\n", - " x_loss = (-out * x_lsm).mean(axis=1)\n", + " x_loss = -x_lsm[np.arange(x_lsm.shape[0]), y]\n", "\n", " # training in numpy (super hard!)\n", " # backward pass\n", @@ -545,7 +545,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.6" + "version": "3.8.5" } }, "nbformat": 4,