diff --git a/dsprites_reloading_example.ipynb b/dsprites_reloading_example.ipynb index a015e9a..36ccd20 100644 --- a/dsprites_reloading_example.ipynb +++ b/dsprites_reloading_example.ipynb @@ -139,7 +139,7 @@ ], "source": [ "# Load dataset\n", - "dataset_zip = np.load('dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz')\n", + "dataset_zip = np.load('dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz', allow_pickle=True, encoding='latin1')\n", "\n", "print('Keys in the dataset:', dataset_zip.keys())\n", "imgs = dataset_zip['imgs']\n",