Add truncated normal distribution for torch distributions #2970

3 changes: 3 additions & 0 deletions src/gluonts/torch/distributions/__init__.py
@@ -36,6 +36,7 @@
SplicedBinnedParetoOutput,
)
from .studentT import StudentTOutput
from .truncated_normal import TruncatedNormal, TruncatedNormalOutput

__all__ = [
"AffineTransformed",
@@ -62,4 +63,6 @@
"SplicedBinnedPareto",
"SplicedBinnedParetoOutput",
"StudentTOutput",
"TruncatedNormal",
"TruncatedNormalOutput",
]
292 changes: 292 additions & 0 deletions src/gluonts/torch/distributions/truncated_normal.py
@@ -0,0 +1,292 @@
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

import math
from numbers import Number
from typing import Dict, Optional, Tuple, Union

import torch
from torch.distributions import constraints
from torch.distributions.utils import broadcast_all

import torch.nn.functional as F
from .distribution_output import DistributionOutput
from gluonts.core.component import validated
from torch.distributions import Distribution

CONST_SQRT_2 = math.sqrt(2)
CONST_INV_SQRT_2PI = 1 / math.sqrt(2 * math.pi)
CONST_INV_SQRT_2 = 1 / math.sqrt(2)
CONST_LOG_INV_SQRT_2PI = math.log(CONST_INV_SQRT_2PI)
CONST_LOG_SQRT_2PI_E = 0.5 * math.log(2 * math.pi * math.e)


class TruncatedNormal(Distribution):
"""Implements a Truncated Normal distribution with location scaling.

Location scaling prevents the location to be "too far" from 0, which ultimately
leads to numerically unstable samples and poor gradient computation (e.g. gradient explosion).
In practice, the location is computed according to

.. math::
loc = tanh(loc / upscale) * upscale.

This behaviour can be disabled by switching off the tanh_loc parameter (see below).

Parameters
----------
loc:
normal distribution location parameter
scale:
normal distribution sigma parameter (square root of the variance)
min:
minimum value of the distribution. Default = -1.0
max:
maximum value of the distribution. Default = 1.0
upscale:
scaling factor. Default = 5.0
tanh_loc:
if ``True``, the above formula is used for
the location scaling, otherwise the raw value is kept.
Default is ``False``

References
----------
- https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf

Notes
-----
This implementation is largely based on:
- https://github.com/pytorch/rl/blob/main/torchrl/modules/distributions/truncated_normal.py
- https://github.com/toshas/torch_truncnorm
"""

arg_constraints = {
"loc": constraints.real,
"scale": constraints.greater_than(1e-6),
}

has_rsample = True
eps = 1e-6

def __init__(
self,
loc: torch.Tensor,
scale: torch.Tensor,
min: Union[torch.Tensor, float] = -1.0,
max: Union[torch.Tensor, float] = 1.0,
upscale: Union[torch.Tensor, float] = 5.0,
tanh_loc: bool = False,
):
scale = scale.clamp_min(self.eps)
if tanh_loc:
loc = (loc / upscale).tanh() * upscale
loc = loc + (max - min) / 2 + min

self.min = min
self.max = max
self.upscale = upscale
self.loc, self.scale, a, b = broadcast_all(
loc, scale, self.min, self.max
)
self._non_std_a = a
self._non_std_b = b
self.a = (a - self.loc) / self.scale
self.b = (b - self.loc) / self.scale

if isinstance(a, Number) and isinstance(b, Number):
batch_shape = torch.Size()
else:
batch_shape = self.a.size()

super(TruncatedNormal, self).__init__(batch_shape)
if self.a.dtype != self.b.dtype:
raise ValueError("Truncation bounds types are different")
if any(
(self.a >= self.b)
.view(
-1,
)
.tolist()
):
raise ValueError("Incorrect truncation range")

eps = self.eps
self._dtype_min_gt_0 = eps
self._dtype_max_lt_1 = 1 - eps
self._little_phi_a = self._little_phi(self.a)
self._little_phi_b = self._little_phi(self.b)
self._big_phi_a = self._big_phi(self.a)
self._big_phi_b = self._big_phi(self.b)
self._Z = (self._big_phi_b - self._big_phi_a).clamp(eps, 1 - eps)
self._log_Z = self._Z.log()
little_phi_coeff_a = torch.nan_to_num(self.a, nan=math.nan)
little_phi_coeff_b = torch.nan_to_num(self.b, nan=math.nan)
self._lpbb_m_lpaa_d_Z = (
self._little_phi_b * little_phi_coeff_b
- self._little_phi_a * little_phi_coeff_a
) / self._Z
self._mean = -(self._little_phi_b - self._little_phi_a) / self._Z
self._variance = (
1
- self._lpbb_m_lpaa_d_Z
- ((self._little_phi_b - self._little_phi_a) / self._Z) ** 2
)
self._entropy = (
CONST_LOG_SQRT_2PI_E + self._log_Z - 0.5 * self._lpbb_m_lpaa_d_Z
)
self._log_scale = self.scale.log()
self._mean_non_std = self._mean * self.scale + self.loc
self._variance_non_std = self._variance * self.scale**2
self._entropy_non_std = self._entropy + self._log_scale

@constraints.dependent_property
def support(self):
return constraints.interval(self.a, self.b)

@property
def mean(self):
return self._mean_non_std

@property
def variance(self):
return self._variance_non_std

@property
def entropy(self):
return self._entropy_non_std

@staticmethod
def _little_phi(x):
return (-(x**2) * 0.5).exp() * CONST_INV_SQRT_2PI

def _big_phi(self, x):
phi = 0.5 * (1 + (x * CONST_INV_SQRT_2).erf())
return phi.clamp(self.eps, 1 - self.eps)

@staticmethod
def _inv_big_phi(x):
return CONST_SQRT_2 * (2 * x - 1).erfinv()

def cdf_truncated_standard_normal(self, value):
return ((self._big_phi(value) - self._big_phi_a) / self._Z).clamp(0, 1)

def icdf_truncated_standard_normal(self, value):
y = self._big_phi_a + value * self._Z
y = y.clamp(self.eps, 1 - self.eps)
return self._inv_big_phi(y)

def log_prob_truncated_standard_normal(self, value):
return CONST_LOG_INV_SQRT_2PI - self._log_Z - (value**2) * 0.5

def _to_std_rv(self, value):
return (value - self.loc) / self.scale

def _from_std_rv(self, value):
return value * self.scale + self.loc

def cdf(self, value):
return self.cdf_truncated_standard_normal(self._to_std_rv(value))

def icdf(self, value):
sample = self._from_std_rv(self.icdf_truncated_standard_normal(value))

# clamp data but keep gradients
sample_clip = torch.stack(
[sample.detach(), self._non_std_a.detach().expand_as(sample)], 0
).max(0)[0]
sample_clip = torch.stack(
[sample_clip, self._non_std_b.detach().expand_as(sample)], 0
).min(0)[0]
sample.data.copy_(sample_clip)
return sample

def log_prob(self, value):
a = self._non_std_a + self._dtype_min_gt_0
a = a.expand_as(value)
b = self._non_std_b - self._dtype_min_gt_0
b = b.expand_as(value)
value = torch.min(torch.stack([value, b], -1), dim=-1)[0]
value = torch.max(torch.stack([value, a], -1), dim=-1)[0]
value = self._to_std_rv(value)
return self.log_prob_truncated_standard_normal(value) - self._log_scale

def rsample(self, sample_shape=None):
if sample_shape is None:
sample_shape = torch.Size([])
shape = self._extended_shape(sample_shape)
p = torch.empty(shape, device=self.a.device).uniform_(
self._dtype_min_gt_0, self._dtype_max_lt_1
)
return self.icdf(p)


class TruncatedNormalOutput(DistributionOutput):
distr_cls: type = TruncatedNormal

@validated()
def __init__(
self,
min: float = -1.0,
max: float = 1.0,
upscale: float = 5.0,
tanh_loc: bool = False,
) -> None:
assert min < max, "max must be strictly greater than min"

super().__init__(self)

self.min = min
self.max = max
self.upscale = upscale
self.tanh_loc = tanh_loc
self.args_dim: Dict[str, int] = {
"loc": 1,
"scale": 1,
}

@classmethod
def domain_map(
cls,
loc: torch.Tensor,
scale: torch.Tensor,
):
scale = F.softplus(scale)

return (
loc.squeeze(-1),
scale.squeeze(-1),
)

# Overrides the parent class method: the constant float and boolean
# parameters are passed to the distribution alongside the tensor arguments.
def distribution(
self,
distr_args,
loc: Optional[torch.Tensor] = None,
scale: Optional[torch.Tensor] = None,
) -> Distribution:
(loc, scale) = distr_args

return TruncatedNormal(
loc=loc,
scale=scale,
upscale=self.upscale,
min=self.min,
max=self.max,
tanh_loc=self.tanh_loc,
)

@property
def event_shape(self) -> Tuple:
return ()
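
A minimal usage sketch of the new distribution, for reference while reviewing (the toy tensors below are illustrative only, not part of this PR); it exercises the tanh location rescaling and the reparametrized sampling path:

import torch
from gluonts.torch.distributions import TruncatedNormal

# Toy parameters; the second location is deliberately large to show the effect of tanh_loc.
loc = torch.tensor([0.3, 12.0])
scale = torch.tensor([0.5, 2.0])

# With tanh_loc=True the location is squashed via tanh(loc / upscale) * upscale
# and re-centered on the [min, max] interval before the truncation bounds are applied.
d = TruncatedNormal(loc, scale, min=-1.0, max=1.0, upscale=5.0, tanh_loc=True)

samples = d.rsample(torch.Size([1000]))  # differentiable samples drawn through the inverse CDF
assert ((samples >= -1.0) & (samples <= 1.0)).all()
log_p = d.log_prob(samples)  # finite log-densities inside the truncation interval
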
52 changes: 52 additions & 0 deletions test/torch/distribution/test_truncated_normal.py
@@ -0,0 +1,52 @@
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

import pytest
import torch
from gluonts.torch.distributions import TruncatedNormal

# Mostly taken from https://github.com/pytorch/rl/blob/main/test/test_distributions.py#L127


@pytest.mark.parametrize(
"min", [-torch.ones(3), -1, 3 * torch.tensor([-1.0, -2.0, -0.5]), -0.1]
)
@pytest.mark.parametrize(
"max", [torch.ones(3), 1, 3 * torch.tensor([1.0, 2.0, 0.5]), 0.1]
)
@pytest.mark.parametrize(
"vecs",
[
(torch.tensor([0.1, 10.0, 5.0]), torch.tensor([0.1, 10.0, 5.0])),
(torch.zeros(7, 3), torch.ones(7, 3)),
],
)
@pytest.mark.parametrize(
"upscale", [torch.ones(3), 1, 3 * torch.tensor([1.0, 2.0, 0.5]), 3]
)
@pytest.mark.parametrize("shape", [torch.Size([]), torch.Size([3, 4])])
def test_truncnormal(min, max, vecs, upscale, shape):
torch.manual_seed(0)
d = TruncatedNormal(
*vecs,
upscale=upscale,
min=min,
max=max,
)
for _ in range(100):
a = d.rsample(shape)
assert a.shape[: len(shape)] == shape
assert (a >= d.min).all()
assert (a <= d.max).all()
lp = d.log_prob(a)
assert torch.isfinite(lp).all()
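
Beyond the parametrized test above, a hedged sketch of how the new output class could be wired into a GluonTS torch estimator; the DeepAREstimator arguments used here (freq, prediction_length, distr_output, trainer_kwargs) follow the existing gluonts.torch API, while the assumption that the target has been rescaled into [-1, 1] is specific to this example, not a requirement of the PR:

from gluonts.torch.distributions import TruncatedNormalOutput
from gluonts.torch.model.deepar import DeepAREstimator

# The truncated support must roughly cover the target values; here the target
# series is assumed to have been normalized into [-1, 1] beforehand.
estimator = DeepAREstimator(
    freq="H",
    prediction_length=24,
    distr_output=TruncatedNormalOutput(min=-1.0, max=1.0, tanh_loc=True),
    trainer_kwargs={"max_epochs": 5},
)
# predictor = estimator.train(training_data)  # training_data: any GluonTS training dataset
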