diff --git a/imagen_pytorch/imagen_pytorch.py b/imagen_pytorch/imagen_pytorch.py index d2a5e5a..3ef637c 100644 --- a/imagen_pytorch/imagen_pytorch.py +++ b/imagen_pytorch/imagen_pytorch.py @@ -207,8 +207,8 @@ def __init__(self, *, noise_schedule, timesteps = 1000): def get_times(self, batch_size, noise_level, *, device): return torch.full((batch_size,), noise_level, device = device, dtype = torch.float32) - def sample_random_times(self, batch_size, max_thres = 0.999, *, device): - return torch.zeros((batch_size,), device = device).float().uniform_(0, max_thres) + def sample_random_times(self, batch_size, *, device): + return torch.zeros((batch_size,), device = device).float().uniform_(0, 1) def get_condition(self, times): return maybe(self.log_snr)(times) diff --git a/imagen_pytorch/version.py b/imagen_pytorch/version.py index 1c19d78..2ed8108 100644 --- a/imagen_pytorch/version.py +++ b/imagen_pytorch/version.py @@ -1 +1 @@ -__version__ = '1.15.0' +__version__ = '1.15.1' diff --git a/setup.py b/setup.py index f6b46e4..c6a28eb 100644 --- a/setup.py +++ b/setup.py @@ -28,6 +28,7 @@ install_requires=[ 'accelerate', 'click', + 'datasets', 'einops>=0.4', 'einops-exts', 'ema-pytorch>=0.0.3',