diff --git a/src/ml_downscaling_emulator/score_sde_pytorch/run_lib.py b/src/ml_downscaling_emulator/score_sde_pytorch/run_lib.py index 9283fc3a..8a254ad9 100644 --- a/src/ml_downscaling_emulator/score_sde_pytorch/run_lib.py +++ b/src/ml_downscaling_emulator/score_sde_pytorch/run_lib.py @@ -28,7 +28,7 @@ from codetiming import Timer import logging # Keep the import below for registering all model definitions -from .models import det_cunet # cunet, cncsnpp +from .models import det_cunet, cunet, cncsnpp from . import losses from .models.location_params import LocationParams from . import sampling