diff --git a/src/gluonts/torch/distributions/__init__.py b/src/gluonts/torch/distributions/__init__.py
index 18a27245c6..e279e023df 100644
--- a/src/gluonts/torch/distributions/__init__.py
+++ b/src/gluonts/torch/distributions/__init__.py
@@ -36,6 +36,7 @@
     SplicedBinnedParetoOutput,
 )
 from .studentT import StudentTOutput
+from .truncated_normal import TruncatedNormal, TruncatedNormalOutput
 
 __all__ = [
     "AffineTransformed",
@@ -62,4 +63,6 @@
     "SplicedBinnedPareto",
     "SplicedBinnedParetoOutput",
     "StudentTOutput",
+    "TruncatedNormal",
+    "TruncatedNormalOutput",
 ]
diff --git a/src/gluonts/torch/distributions/truncated_normal.py b/src/gluonts/torch/distributions/truncated_normal.py
new file mode 100644
index 0000000000..a3ff655b1f
--- /dev/null
+++ b/src/gluonts/torch/distributions/truncated_normal.py
@@ -0,0 +1,292 @@
+# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License").
+# You may not use this file except in compliance with the License.
+# A copy of the License is located at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# or in the "license" file accompanying this file. This file is distributed
+# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
+# express or implied. See the License for the specific language governing
+# permissions and limitations under the License.
+
+import math
+from numbers import Number
+from typing import Dict, Optional, Tuple, Union
+
+import torch
+from torch.distributions import constraints
+from torch.distributions.utils import broadcast_all
+
+import torch.nn.functional as F
+from .distribution_output import DistributionOutput
+from gluonts.core.component import validated
+from torch.distributions import Distribution
+
+CONST_SQRT_2 = math.sqrt(2)
+CONST_INV_SQRT_2PI = 1 / math.sqrt(2 * math.pi)
+CONST_INV_SQRT_2 = 1 / math.sqrt(2)
+CONST_LOG_INV_SQRT_2PI = math.log(CONST_INV_SQRT_2PI)
+CONST_LOG_SQRT_2PI_E = 0.5 * math.log(2 * math.pi * math.e)
+
+
+class TruncatedNormal(Distribution):
+    """Implements a Truncated Normal distribution with location scaling.
+
+    Location scaling prevents the location from being "too far" from 0,
+    which would otherwise lead to numerically unstable samples and poor
+    gradient computation (e.g. gradient explosion). In practice, the
+    location is computed according to
+
+    .. math::
+        loc = tanh(loc / upscale) * upscale.
+
+    This behaviour can be disabled by switching off the ``tanh_loc``
+    parameter (see below).
+
+    Parameters
+    ----------
+    loc:
+        normal distribution location parameter
+    scale:
+        normal distribution sigma parameter (square root of the variance)
+    min:
+        minimum value of the distribution. Default = -1.0
+    max:
+        maximum value of the distribution. Default = 1.0
+    upscale:
+        scaling factor. Default = 5.0
+    tanh_loc:
+        if ``True``, the above formula is used for the location scaling,
+        otherwise the raw value is kept.
+        Default is ``False``.
+
+    References
+    ----------
+    - https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
+
+    Notes
+    -----
+    This implementation is strongly based on:
+    - https://github.com/pytorch/rl/blob/main/torchrl/modules/distributions/truncated_normal.py
+    - https://github.com/toshas/torch_truncnorm
+    """
+
+    arg_constraints = {
+        "loc": constraints.real,
+        "scale": constraints.greater_than(1e-6),
+    }
+
+    has_rsample = True
+    eps = 1e-6
+
+    def __init__(
+        self,
+        loc: torch.Tensor,
+        scale: torch.Tensor,
+        min: Union[torch.Tensor, float] = -1.0,
+        max: Union[torch.Tensor, float] = 1.0,
+        upscale: Union[torch.Tensor, float] = 5.0,
+        tanh_loc: bool = False,
+    ):
+        scale = scale.clamp_min(self.eps)
+        if tanh_loc:
+            loc = (loc / upscale).tanh() * upscale
+        loc = loc + (max - min) / 2 + min
+
+        self.min = min
+        self.max = max
+        self.upscale = upscale
+        self.loc, self.scale, a, b = broadcast_all(
+            loc, scale, self.min, self.max
+        )
+        self._non_std_a = a
+        self._non_std_b = b
+        self.a = (a - self.loc) / self.scale
+        self.b = (b - self.loc) / self.scale
+
+        if isinstance(a, Number) and isinstance(b, Number):
+            batch_shape = torch.Size()
+        else:
+            batch_shape = self.a.size()
+
+        super(TruncatedNormal, self).__init__(batch_shape)
+        if self.a.dtype != self.b.dtype:
+            raise ValueError("Truncation bounds types are different")
+        if any(
+            (self.a >= self.b)
+            .view(
+                -1,
+            )
+            .tolist()
+        ):
+            raise ValueError("Incorrect truncation range")
+
+        eps = self.eps
+        self._dtype_min_gt_0 = eps
+        self._dtype_max_lt_1 = 1 - eps
+        self._little_phi_a = self._little_phi(self.a)
+        self._little_phi_b = self._little_phi(self.b)
+        self._big_phi_a = self._big_phi(self.a)
+        self._big_phi_b = self._big_phi(self.b)
+        self._Z = (self._big_phi_b - self._big_phi_a).clamp(eps, 1 - eps)
+        self._log_Z = self._Z.log()
+        little_phi_coeff_a = torch.nan_to_num(self.a, nan=math.nan)
+        little_phi_coeff_b = torch.nan_to_num(self.b, nan=math.nan)
+        self._lpbb_m_lpaa_d_Z = (
+            self._little_phi_b * little_phi_coeff_b
+            - self._little_phi_a * little_phi_coeff_a
+        ) / self._Z
+        self._mean = -(self._little_phi_b - self._little_phi_a) / self._Z
+        self._variance = (
+            1
+            - self._lpbb_m_lpaa_d_Z
+            - ((self._little_phi_b - self._little_phi_a) / self._Z) ** 2
+        )
+        self._entropy = (
+            CONST_LOG_SQRT_2PI_E + self._log_Z - 0.5 * self._lpbb_m_lpaa_d_Z
+        )
+        self._log_scale = self.scale.log()
+        self._mean_non_std = self._mean * self.scale + self.loc
+        self._variance_non_std = self._variance * self.scale**2
+        self._entropy_non_std = self._entropy + self._log_scale
+
+    @constraints.dependent_property
+    def support(self):
+        return constraints.interval(self.a, self.b)
+
+    @property
+    def mean(self):
+        return self._mean_non_std
+
+    @property
+    def variance(self):
+        return self._variance_non_std
+
+    @property
+    def entropy(self):
+        return self._entropy_non_std
+
+    @staticmethod
+    def _little_phi(x):
+        return (-(x**2) * 0.5).exp() * CONST_INV_SQRT_2PI
+
+    def _big_phi(self, x):
+        phi = 0.5 * (1 + (x * CONST_INV_SQRT_2).erf())
+        return phi.clamp(self.eps, 1 - self.eps)
+
+    @staticmethod
+    def _inv_big_phi(x):
+        return CONST_SQRT_2 * (2 * x - 1).erfinv()
+
+    def cdf_truncated_standard_normal(self, value):
+        return ((self._big_phi(value) - self._big_phi_a) / self._Z).clamp(0, 1)
+
+    def icdf_truncated_standard_normal(self, value):
+        y = self._big_phi_a + value * self._Z
+        y = y.clamp(self.eps, 1 - self.eps)
+        return self._inv_big_phi(y)
+
+    def log_prob_truncated_standard_normal(self, value):
+        return CONST_LOG_INV_SQRT_2PI - self._log_Z - (value**2) * 0.5
+
+    def _to_std_rv(self, value):
+        return (value - self.loc) / self.scale
+
+    def _from_std_rv(self, value):
+        return value * self.scale + self.loc
+
+    def cdf(self, value):
+        return self.cdf_truncated_standard_normal(self._to_std_rv(value))
+
+    def icdf(self, value):
+        sample = self._from_std_rv(self.icdf_truncated_standard_normal(value))
+
+        # clamp data but keep gradients
+        sample_clip = torch.stack(
+            [sample.detach(), self._non_std_a.detach().expand_as(sample)], 0
+        ).max(0)[0]
+        sample_clip = torch.stack(
+            [sample_clip, self._non_std_b.detach().expand_as(sample)], 0
+        ).min(0)[0]
+        sample.data.copy_(sample_clip)
+        return sample
+
+    def log_prob(self, value):
+        a = self._non_std_a + self._dtype_min_gt_0
+        a = a.expand_as(value)
+        b = self._non_std_b - self._dtype_min_gt_0
+        b = b.expand_as(value)
+        value = torch.min(torch.stack([value, b], -1), dim=-1)[0]
+        value = torch.max(torch.stack([value, a], -1), dim=-1)[0]
+        value = self._to_std_rv(value)
+        return self.log_prob_truncated_standard_normal(value) - self._log_scale
+
+    def rsample(self, sample_shape=None):
+        if sample_shape is None:
+            sample_shape = torch.Size([])
+        shape = self._extended_shape(sample_shape)
+        p = torch.empty(shape, device=self.a.device).uniform_(
+            self._dtype_min_gt_0, self._dtype_max_lt_1
+        )
+        return self.icdf(p)
+
+
+class TruncatedNormalOutput(DistributionOutput):
+    distr_cls: type = TruncatedNormal
+
+    @validated()
+    def __init__(
+        self,
+        min: float = -1.0,
+        max: float = 1.0,
+        upscale: float = 5.0,
+        tanh_loc: bool = False,
+    ) -> None:
+        assert min < max, "max must be strictly greater than min"
+
+        super().__init__(self)
+
+        self.min = min
+        self.max = max
+        self.upscale = upscale
+        self.tanh_loc = tanh_loc
+        self.args_dim: Dict[str, int] = {
+            "loc": 1,
+            "scale": 1,
+        }
+
+    @classmethod
+    def domain_map(
+        cls,
+        loc: torch.Tensor,
+        scale: torch.Tensor,
+    ):
+        scale = F.softplus(scale)
+
+        return (
+            loc.squeeze(-1),
+            scale.squeeze(-1),
+        )
+
+    # Overwrites the parent class method: constant float and boolean
+    # parameters are passed along with the tensor-valued distribution
+    # arguments.
+    def distribution(
+        self,
+        distr_args,
+        loc: Optional[torch.Tensor] = None,
+        scale: Optional[torch.Tensor] = None,
+    ) -> Distribution:
+        (loc, scale) = distr_args
+
+        return TruncatedNormal(
+            loc=loc,
+            scale=scale,
+            upscale=self.upscale,
+            min=self.min,
+            max=self.max,
+            tanh_loc=self.tanh_loc,
+        )
+
+    @property
+    def event_shape(self) -> Tuple:
+        return ()
diff --git a/test/torch/distribution/test_truncated_normal.py b/test/torch/distribution/test_truncated_normal.py
new file mode 100644
index 0000000000..c1582af0de
--- /dev/null
+++ b/test/torch/distribution/test_truncated_normal.py
@@ -0,0 +1,52 @@
+# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License").
+# You may not use this file except in compliance with the License.
+# A copy of the License is located at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# or in the "license" file accompanying this file. This file is distributed
+# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
+# express or implied. See the License for the specific language governing
+# permissions and limitations under the License.
+
+import pytest
+import torch
+from gluonts.torch.distributions import TruncatedNormal
+
+# Mostly taken from https://github.com/pytorch/rl/blob/main/test/test_distributions.py#L127
+
+
+@pytest.mark.parametrize(
+    "min", [-torch.ones(3), -1, 3 * torch.tensor([-1.0, -2.0, -0.5]), -0.1]
+)
+@pytest.mark.parametrize(
+    "max", [torch.ones(3), 1, 3 * torch.tensor([1.0, 2.0, 0.5]), 0.1]
+)
+@pytest.mark.parametrize(
+    "vecs",
+    [
+        (torch.tensor([0.1, 10.0, 5.0]), torch.tensor([0.1, 10.0, 5.0])),
+        (torch.zeros(7, 3), torch.ones(7, 3)),
+    ],
+)
+@pytest.mark.parametrize(
+    "upscale", [torch.ones(3), 1, 3 * torch.tensor([1.0, 2.0, 0.5]), 3]
+)
+@pytest.mark.parametrize("shape", [torch.Size([]), torch.Size([3, 4])])
+def test_truncnormal(min, max, vecs, upscale, shape):
+    torch.manual_seed(0)
+    d = TruncatedNormal(
+        *vecs,
+        upscale=upscale,
+        min=min,
+        max=max,
+    )
+    for _ in range(100):
+        a = d.rsample(shape)
+        assert a.shape[: len(shape)] == shape
+        assert (a >= d.min).all()
+        assert (a <= d.max).all()
+        lp = d.log_prob(a)
+        assert torch.isfinite(lp).all()
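Reviewer note, not part of the diff: a minimal usage sketch of the new classes. It assumes the standard DistributionOutput.get_args_proj interface from gluonts.torch; the feature size of 32 and the (batch, time, hidden) tensor shapes below are made up for illustration.

import torch

from gluonts.torch.distributions import TruncatedNormal, TruncatedNormalOutput

# Direct use: a batch of three components truncated to [-1, 1].
loc = torch.tensor([0.0, 0.5, -0.5])
scale = torch.tensor([0.1, 0.2, 0.3])
dist = TruncatedNormal(loc=loc, scale=scale, min=-1.0, max=1.0)
samples = dist.rsample(torch.Size([100]))  # shape (100, 3), all values in [-1, 1]
log_probs = dist.log_prob(samples)         # finite for values inside the bounds

# With tanh_loc=True the raw location is first squashed as
# loc -> tanh(loc / upscale) * upscale, keeping it within (-upscale, upscale).
dist_tanh = TruncatedNormal(
    loc=loc, scale=scale, min=-1.0, max=1.0, tanh_loc=True
)

# Through the DistributionOutput interface, the way an estimator would use it.
output = TruncatedNormalOutput(min=-1.0, max=1.0)
args_proj = output.get_args_proj(32)       # projects network features to (loc, scale)
features = torch.randn(8, 24, 32)          # hypothetical (batch, time, hidden) features
distr = output.distribution(args_proj(features))  # TruncatedNormal, batch shape (8, 24)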