diff --git a/simple_einet/data.py b/simple_einet/data.py index 8cfc60c..de2ec01 100644 --- a/simple_einet/data.py +++ b/simple_einet/data.py @@ -115,6 +115,7 @@ def get_data_shape(dataset_name: str) -> Shape: "flowers": (3, 32, 32), "tiny-imagenet": (3, 32, 32), "lfw": (3, 32, 32), + "digits": (1, 8, 8), }[dataset_name] )