From 8a0034b834aa86a18d49be311f0493c73127578e Mon Sep 17 00:00:00 2001 From: Pedro Eduardo Mercado Lopez Date: Thu, 17 Aug 2023 18:27:39 +0200 Subject: [PATCH 01/15] add truncated_normal_distribution --- .../torch/distributions/truncated_normal.py | 212 ++++++++++++++++++ 1 file changed, 212 insertions(+) create mode 100644 src/gluonts/torch/distributions/truncated_normal.py diff --git a/src/gluonts/torch/distributions/truncated_normal.py b/src/gluonts/torch/distributions/truncated_normal.py new file mode 100644 index 0000000000..53eb65c9f6 --- /dev/null +++ b/src/gluonts/torch/distributions/truncated_normal.py @@ -0,0 +1,212 @@ +# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the "license" file accompanying this file. This file is distributed +# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. + +# mainly based from https://github.com/pytorch/rl/blob/main/torchrl/modules/distributions/continuous.py#L166 + +from numbers import Number +from typing import Dict, Optional, Tuple, Union + +import torch +from torch import distributions as D +from torch.distributions import constraints + +import torch.nn.functional as F +from .distribution_output import DistributionOutput +from gluonts.core.component import validated +from torch.distributions import Distribution +from .utils.truncated_normal import TruncatedNormal as _TruncatedNormal + +# speeds up distribution construction +D.Distribution.set_default_validate_args(False) + + +class TruncatedNormal(D.Independent): + # class TruncatedNormal(Distribution): + """Implements a Truncated Normal distribution with location scaling. + + Location scaling prevents the location to be "too far" from 0, which ultimately + leads to numerically unstable samples and poor gradient computation (e.g. gradient explosion). + In practice, the location is computed according to + + .. math:: + loc = tanh(loc / upscale) * upscale. + + This behaviour can be disabled by switching off the tanh_loc parameter (see below). + + + Args: + loc (torch.Tensor): normal distribution location parameter + scale (torch.Tensor): normal distribution sigma parameter (squared root of variance) + upscale (torch.Tensor or number, optional): 'a' scaling factor in the formula: + + .. math:: + loc = tanh(loc / upscale) * upscale. + + Default is 5.0 + + min (torch.Tensor or number, optional): minimum value of the distribution. Default = -1.0; + max (torch.Tensor or number, optional): maximum value of the distribution. Default = 1.0; + tanh_loc (bool, optional): if ``True``, the above formula is used for + the location scaling, otherwise the raw value is kept. 
+ Default is ``False``; + """ + + num_params: int = 2 + + arg_constraints = { + "loc": constraints.real, + "scale": constraints.greater_than(1e-6), + } + + def __init__( + self, + loc: torch.Tensor, + scale: torch.Tensor, + upscale: Union[torch.Tensor, float] = 5.0, + min: Union[torch.Tensor, float] = -1.0, + max: Union[torch.Tensor, float] = 1.0, + tanh_loc: bool = False, + ): + err_msg = ( + "TanhNormal max values must be strictly greater than min values" + ) + if isinstance(max, torch.Tensor) or isinstance(min, torch.Tensor): + if not (max > min).all(): + raise RuntimeError(err_msg) + elif isinstance(max, Number) and isinstance(min, Number): + if not max > min: + raise RuntimeError(err_msg) + else: + if not all(max > min): + raise RuntimeError(err_msg) + + if isinstance(max, torch.Tensor): + self.non_trivial_max = (max != 1.0).any() + else: + self.non_trivial_max = max != 1.0 + + if isinstance(min, torch.Tensor): + self.non_trivial_min = (min != -1.0).any() + else: + self.non_trivial_min = min != -1.0 + self.tanh_loc = tanh_loc + + self.device = loc.device + self.upscale = ( + upscale + if not isinstance(upscale, torch.Tensor) + else upscale.to(self.device) + ) + + if isinstance(max, torch.Tensor): + max = max.to(self.device) + else: + max = torch.tensor(max, device=self.device) + if isinstance(min, torch.Tensor): + min = min.to(self.device) + else: + min = torch.tensor(min, device=self.device) + self.min = min + self.max = max + self.update(loc, scale) + + def update(self, loc: torch.Tensor, scale: torch.Tensor) -> None: + if self.tanh_loc: + loc = (loc / self.upscale).tanh() * self.upscale + if self.non_trivial_max or self.non_trivial_min: + loc = loc + (self.max - self.min) / 2 + self.min + self.loc = loc + self.scale = scale + + base_dist = _TruncatedNormal( + loc, scale, self.min.expand_as(loc), self.max.expand_as(scale) + ) + super().__init__(base_dist, 1, validate_args=False) + + @property + def mode(self): + m = self.base_dist.loc + a = self.base_dist._non_std_a + self.base_dist._dtype_min_gt_0 + b = self.base_dist._non_std_b - self.base_dist._dtype_min_gt_0 + m = torch.min(torch.stack([m, b], -1), dim=-1)[0] + return torch.max(torch.stack([m, a], -1), dim=-1)[0] + + def log_prob(self, value, **kwargs): + a = self.base_dist._non_std_a + self.base_dist._dtype_min_gt_0 + a = a.expand_as(value) + b = self.base_dist._non_std_b - self.base_dist._dtype_min_gt_0 + b = b.expand_as(value) + value = torch.min(torch.stack([value, b], -1), dim=-1)[0] + value = torch.max(torch.stack([value, a], -1), dim=-1)[0] + return self.base_dist.log_prob( + value + ) # original: return super().log_prob(value, **kwargs) + + +class TruncatedNormalOutput(DistributionOutput): + distr_cls: type = TruncatedNormal + + @validated() + def __init__( + self, + min: float, + max: float, + upscale: float = 5.0, + tanh_loc: bool = False, + ) -> None: + super().__init__(self) + + self.min = min + self.max = max + self.upscale = upscale + self.tanh_loc = tanh_loc + self.args_dim: Dict[str, int] = { + "loc": 1, + "scale": 1, + } + + # @classmethod + def domain_map( + self, + loc: torch.Tensor, + scale: torch.Tensor, + ): + scale = F.softplus(scale) + + return ( + loc.squeeze(-1), + scale.squeeze(-1), + ) + + # Overwrites the parent class method: We pass constant float and + # boolean parameters across tensors + def distribution( + self, + distr_args, + loc: Optional[torch.Tensor] = None, + scale: Optional[torch.Tensor] = None, + ) -> Distribution: + (loc, scale) = distr_args + + return TruncatedNormal( + loc=loc, + 
scale=scale, + upscale=self.upscale, + min=self.min, + max=self.max, + tanh_loc=self.tanh_loc, + ) + + @property + def event_shape(self) -> Tuple: + return () From 6066c67216fe59de2e838609e1f30ab7c4cab710 Mon Sep 17 00:00:00 2001 From: Pedro Eduardo Mercado Lopez Date: Thu, 17 Aug 2023 18:28:34 +0200 Subject: [PATCH 02/15] add utils --- .../distributions/utils/truncated_normal.py | 227 ++++++++++++++++++ 1 file changed, 227 insertions(+) create mode 100644 src/gluonts/torch/distributions/utils/truncated_normal.py diff --git a/src/gluonts/torch/distributions/utils/truncated_normal.py b/src/gluonts/torch/distributions/utils/truncated_normal.py new file mode 100644 index 0000000000..080daf3d66 --- /dev/null +++ b/src/gluonts/torch/distributions/utils/truncated_normal.py @@ -0,0 +1,227 @@ +# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the "license" file accompanying this file. This file is distributed +# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. + +# from https://github.com/pytorch/rl/blob/main/torchrl/modules/distributions/truncated_normal.py + +# from https://github.com/toshas/torch_truncnorm + +import math +from numbers import Number + +import torch +from torch.distributions import constraints, Distribution +from torch.distributions.utils import broadcast_all + +CONST_SQRT_2 = math.sqrt(2) +CONST_INV_SQRT_2PI = 1 / math.sqrt(2 * math.pi) +CONST_INV_SQRT_2 = 1 / math.sqrt(2) +CONST_LOG_INV_SQRT_2PI = math.log(CONST_INV_SQRT_2PI) +CONST_LOG_SQRT_2PI_E = 0.5 * math.log(2 * math.pi * math.e) + + +class TruncatedStandardNormal(Distribution): + """Truncated Standard Normal distribution. 
+ + Source: https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + """ + + arg_constraints = { + "a": constraints.real, + "b": constraints.real, + } + has_rsample = True + eps = 1e-6 + + def __init__(self, a, b, validate_args=None): + self.a, self.b = broadcast_all(a, b) + if isinstance(a, Number) and isinstance(b, Number): + batch_shape = torch.Size() + else: + batch_shape = self.a.size() + super(TruncatedStandardNormal, self).__init__( + batch_shape, validate_args=validate_args + ) + if self.a.dtype != self.b.dtype: + raise ValueError("Truncation bounds types are different") + if any( + (self.a >= self.b) + .view( + -1, + ) + .tolist() + ): + raise ValueError("Incorrect truncation range") + eps = self.eps + self._dtype_min_gt_0 = eps + self._dtype_max_lt_1 = 1 - eps + self._little_phi_a = self._little_phi(self.a) + self._little_phi_b = self._little_phi(self.b) + self._big_phi_a = self._big_phi(self.a) + self._big_phi_b = self._big_phi(self.b) + self._Z = (self._big_phi_b - self._big_phi_a).clamp(eps, 1 - eps) + self._log_Z = self._Z.log() + little_phi_coeff_a = torch.nan_to_num(self.a, nan=math.nan) + little_phi_coeff_b = torch.nan_to_num(self.b, nan=math.nan) + self._lpbb_m_lpaa_d_Z = ( + self._little_phi_b * little_phi_coeff_b + - self._little_phi_a * little_phi_coeff_a + ) / self._Z + self._mean = -(self._little_phi_b - self._little_phi_a) / self._Z + self._variance = ( + 1 + - self._lpbb_m_lpaa_d_Z + - ((self._little_phi_b - self._little_phi_a) / self._Z) ** 2 + ) + self._entropy = ( + CONST_LOG_SQRT_2PI_E + self._log_Z - 0.5 * self._lpbb_m_lpaa_d_Z + ) + + @constraints.dependent_property + def support(self): + return constraints.interval(self.a, self.b) + + @property + def mean(self): + return self._mean + + @property + def variance(self): + return self._variance + + @property + def entropy(self): + return self._entropy + + @property + def auc(self): + return self._Z + + @staticmethod + def _little_phi(x): + return (-(x**2) * 0.5).exp() * CONST_INV_SQRT_2PI + + def _big_phi(self, x): + phi = 0.5 * (1 + (x * CONST_INV_SQRT_2).erf()) + return phi.clamp(self.eps, 1 - self.eps) + + @staticmethod + def _inv_big_phi(x): + return CONST_SQRT_2 * (2 * x - 1).erfinv() + + def cdf(self, value): + if self._validate_args: + self._validate_sample(value) + return ((self._big_phi(value) - self._big_phi_a) / self._Z).clamp(0, 1) + + def icdf(self, value): + y = self._big_phi_a + value * self._Z + y = y.clamp(self.eps, 1 - self.eps) + return self._inv_big_phi(y) + + def log_prob(self, value): + if self._validate_args: + self._validate_sample(value) + return CONST_LOG_INV_SQRT_2PI - self._log_Z - (value**2) * 0.5 + + def rsample(self, sample_shape=None): + if sample_shape is None: + sample_shape = torch.Size([]) + shape = self._extended_shape(sample_shape) + p = torch.empty(shape, device=self.a.device).uniform_( + self._dtype_min_gt_0, self._dtype_max_lt_1 + ) + return self.icdf(p) + + +class TruncatedNormal(TruncatedStandardNormal): + """Truncated Normal distribution. 
+ + https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + """ + + has_rsample = True + + def __init__(self, loc, scale, a, b, validate_args=None): + scale = scale.clamp_min(self.eps) + self.loc, self.scale, a, b = broadcast_all(loc, scale, a, b) + self._non_std_a = a + self._non_std_b = b + a = (a - self.loc) / self.scale + b = (b - self.loc) / self.scale + super(TruncatedNormal, self).__init__( + a, b, validate_args=validate_args + ) + self._log_scale = self.scale.log() + self._mean = self._mean * self.scale + self.loc + self._variance = self._variance * self.scale**2 + self._entropy += self._log_scale + + def _to_std_rv(self, value): + return (value - self.loc) / self.scale + + def _from_std_rv(self, value): + return value * self.scale + self.loc + + def cdf(self, value): + return super(TruncatedNormal, self).cdf(self._to_std_rv(value)) + + def icdf(self, value): + sample = self._from_std_rv(super().icdf(value)) + + # clamp data but keep gradients + sample_clip = torch.stack( + [sample.detach(), self._non_std_a.detach().expand_as(sample)], 0 + ).max(0)[0] + sample_clip = torch.stack( + [sample_clip, self._non_std_b.detach().expand_as(sample)], 0 + ).min(0)[0] + sample.data.copy_(sample_clip) + return sample + + def log_prob(self, value): + value = self._to_std_rv(value) + return super(TruncatedNormal, self).log_prob(value) - self._log_scale + + +# class TruncatedNormalOutput(DistributionOutput): +# distr_cls: type = TruncatedNormal +# +# @validated() +# def __init__(self, a: float, b: float) -> None: +# super().__init__(self) +# +# self.a = a +# self.b = b +# self.args_dim: Dict[str, int] = {"loc": 1, "scale": 1, "a": 1, "b": 1} +# +# # @classmethod +# def domain_map( +# self, +# loc: torch.Tensor, +# scale: torch.Tensor, +# a: torch.Tensor, +# b: torch.Tensor, +# ): +# scale = F.softplus(scale) +# a = self.a * torch.ones_like(a) +# b = self.b * torch.ones_like(b) +# return ( +# loc.squeeze(-1), +# scale.squeeze(-1), +# a.squeeze(axis=-1), +# b.squeeze(axis=-1), +# ) +# +# @property +# def event_shape(self) -> Tuple: +# return () From 4c0fcca8b4f167ab84b4ac678b5113b7b70bebe4 Mon Sep 17 00:00:00 2001 From: Pedro Eduardo Mercado Lopez Date: Thu, 17 Aug 2023 18:30:16 +0200 Subject: [PATCH 03/15] add tests --- .../distribution/test_truncated_normal.py | 52 +++++++++++++++++++ 1 file changed, 52 insertions(+) create mode 100644 test/torch/distribution/test_truncated_normal.py diff --git a/test/torch/distribution/test_truncated_normal.py b/test/torch/distribution/test_truncated_normal.py new file mode 100644 index 0000000000..c1582af0de --- /dev/null +++ b/test/torch/distribution/test_truncated_normal.py @@ -0,0 +1,52 @@ +# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the "license" file accompanying this file. This file is distributed +# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. 
+ +import pytest +import torch +from gluonts.torch.distributions import TruncatedNormal + +# Mostly taken from https://github.com/pytorch/rl/blob/main/test/test_distributions.py#L127 + + +@pytest.mark.parametrize( + "min", [-torch.ones(3), -1, 3 * torch.tensor([-1.0, -2.0, -0.5]), -0.1] +) +@pytest.mark.parametrize( + "max", [torch.ones(3), 1, 3 * torch.tensor([1.0, 2.0, 0.5]), 0.1] +) +@pytest.mark.parametrize( + "vecs", + [ + (torch.tensor([0.1, 10.0, 5.0]), torch.tensor([0.1, 10.0, 5.0])), + (torch.zeros(7, 3), torch.ones(7, 3)), + ], +) +@pytest.mark.parametrize( + "upscale", [torch.ones(3), 1, 3 * torch.tensor([1.0, 2.0, 0.5]), 3] +) +@pytest.mark.parametrize("shape", [torch.Size([]), torch.Size([3, 4])]) +def test_truncnormal(min, max, vecs, upscale, shape): + torch.manual_seed(0) + d = TruncatedNormal( + *vecs, + upscale=upscale, + min=min, + max=max, + ) + for _ in range(100): + a = d.rsample(shape) + assert a.shape[: len(shape)] == shape + assert (a >= d.min).all() + assert (a <= d.max).all() + lp = d.log_prob(a) + assert torch.isfinite(lp).all() From 360ccf6e9e76221d08b8ddb687065ecf595e4b3d Mon Sep 17 00:00:00 2001 From: Pedro Eduardo Mercado Lopez Date: Thu, 17 Aug 2023 18:30:30 +0200 Subject: [PATCH 04/15] update __init__ --- src/gluonts/torch/distributions/__init__.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/gluonts/torch/distributions/__init__.py b/src/gluonts/torch/distributions/__init__.py index 18a27245c6..e279e023df 100644 --- a/src/gluonts/torch/distributions/__init__.py +++ b/src/gluonts/torch/distributions/__init__.py @@ -36,6 +36,7 @@ SplicedBinnedParetoOutput, ) from .studentT import StudentTOutput +from .truncated_normal import TruncatedNormal, TruncatedNormalOutput __all__ = [ "AffineTransformed", @@ -62,4 +63,6 @@ "SplicedBinnedPareto", "SplicedBinnedParetoOutput", "StudentTOutput", + "TruncatedNormal", + "TruncatedNormalOutput", ] From 9cea7699087ddd65b3ac919e10640e2286efe536 Mon Sep 17 00:00:00 2001 From: Pedro Eduardo Mercado Lopez Date: Thu, 17 Aug 2023 18:31:18 +0200 Subject: [PATCH 05/15] clean up --- src/gluonts/torch/distributions/truncated_normal.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/gluonts/torch/distributions/truncated_normal.py b/src/gluonts/torch/distributions/truncated_normal.py index 53eb65c9f6..c07e66f610 100644 --- a/src/gluonts/torch/distributions/truncated_normal.py +++ b/src/gluonts/torch/distributions/truncated_normal.py @@ -31,7 +31,6 @@ class TruncatedNormal(D.Independent): - # class TruncatedNormal(Distribution): """Implements a Truncated Normal distribution with location scaling. 
Location scaling prevents the location to be "too far" from 0, which ultimately

From 03c6ceac8945f457adff5c5ed43479e2857ec005 Mon Sep 17 00:00:00 2001
From: Pedro Eduardo Mercado Lopez
Date: Thu, 17 Aug 2023 18:32:00 +0200
Subject: [PATCH 06/15] clean up

---
 .../distributions/utils/truncated_normal.py   | 34 -------------------
 1 file changed, 34 deletions(-)

diff --git a/src/gluonts/torch/distributions/utils/truncated_normal.py b/src/gluonts/torch/distributions/utils/truncated_normal.py
index 080daf3d66..55edcbf10f 100644
--- a/src/gluonts/torch/distributions/utils/truncated_normal.py
+++ b/src/gluonts/torch/distributions/utils/truncated_normal.py
@@ -191,37 +191,3 @@ def icdf(self, value):
     def log_prob(self, value):
         value = self._to_std_rv(value)
         return super(TruncatedNormal, self).log_prob(value) - self._log_scale
-
-
-# class TruncatedNormalOutput(DistributionOutput):
-#     distr_cls: type = TruncatedNormal
-#
-#     @validated()
-#     def __init__(self, a: float, b: float) -> None:
-#         super().__init__(self)
-#
-#         self.a = a
-#         self.b = b
-#         self.args_dim: Dict[str, int] = {"loc": 1, "scale": 1, "a": 1, "b": 1}
-#
-#     # @classmethod
-#     def domain_map(
-#         self,
-#         loc: torch.Tensor,
-#         scale: torch.Tensor,
-#         a: torch.Tensor,
-#         b: torch.Tensor,
-#     ):
-#         scale = F.softplus(scale)
-#         a = self.a * torch.ones_like(a)
-#         b = self.b * torch.ones_like(b)
-#         return (
-#             loc.squeeze(-1),
-#             scale.squeeze(-1),
-#             a.squeeze(axis=-1),
-#             b.squeeze(axis=-1),
-#         )
-#
-#     @property
-#     def event_shape(self) -> Tuple:
-#         return ()

From 6c249850d5661ca39fa3e950b2819356c8c22e1e Mon Sep 17 00:00:00 2001
From: Pedro Eduardo Mercado Lopez
Date: Fri, 18 Aug 2023 15:04:41 +0200
Subject: [PATCH 07/15] simplify classes

---
 .../distributions/utils/truncated_normal.py   | 97 ++++++++-----------
 1 file changed, 40 insertions(+), 57 deletions(-)

diff --git a/src/gluonts/torch/distributions/utils/truncated_normal.py b/src/gluonts/torch/distributions/utils/truncated_normal.py
index 55edcbf10f..158ddc1763 100644
--- a/src/gluonts/torch/distributions/utils/truncated_normal.py
+++ b/src/gluonts/torch/distributions/utils/truncated_normal.py
@@ -11,9 +11,9 @@
 # express or implied. See the License for the specific language governing
 # permissions and limitations under the License.
 
-# from https://github.com/pytorch/rl/blob/main/torchrl/modules/distributions/truncated_normal.py
-
-# from https://github.com/toshas/torch_truncnorm
+# The implementation is strongly inspired from:
+# - https://github.com/pytorch/rl/blob/main/torchrl/modules/distributions/truncated_normal.py
+# - https://github.com/toshas/torch_truncnorm
 
 import math
 from numbers import Number
@@ -27,12 +27,13 @@
 CONST_INV_SQRT_2 = 1 / math.sqrt(2)
 CONST_LOG_INV_SQRT_2PI = math.log(CONST_INV_SQRT_2PI)
 CONST_LOG_SQRT_2PI_E = 0.5 * math.log(2 * math.pi * math.e)
+torch.manual_seed(0)
 
 
-class TruncatedStandardNormal(Distribution):
-    """Truncated Standard Normal distribution.
+class TruncatedNormal(Distribution):
+    """Truncated Normal distribution.
- Source: https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf """ arg_constraints = { @@ -42,15 +43,20 @@ class TruncatedStandardNormal(Distribution): has_rsample = True eps = 1e-6 - def __init__(self, a, b, validate_args=None): - self.a, self.b = broadcast_all(a, b) + def __init__(self, loc, scale, a, b): + scale = scale.clamp_min(self.eps) + self.loc, self.scale, a, b = broadcast_all(loc, scale, a, b) + self._non_std_a = a + self._non_std_b = b + self.a = (a - self.loc) / self.scale + self.b = (b - self.loc) / self.scale + if isinstance(a, Number) and isinstance(b, Number): batch_shape = torch.Size() else: batch_shape = self.a.size() - super(TruncatedStandardNormal, self).__init__( - batch_shape, validate_args=validate_args - ) + + super(TruncatedNormal, self).__init__(batch_shape) if self.a.dtype != self.b.dtype: raise ValueError("Truncation bounds types are different") if any( @@ -61,6 +67,7 @@ def __init__(self, a, b, validate_args=None): .tolist() ): raise ValueError("Incorrect truncation range") + eps = self.eps self._dtype_min_gt_0 = eps self._dtype_max_lt_1 = 1 - eps @@ -85,6 +92,10 @@ def __init__(self, a, b, validate_args=None): self._entropy = ( CONST_LOG_SQRT_2PI_E + self._log_Z - 0.5 * self._lpbb_m_lpaa_d_Z ) + self._log_scale = self.scale.log() + self._mean_non_std = self._mean * self.scale + self.loc + self._variance_non_std = self._variance * self.scale**2 + self._entropy_non_std = self._entropy + self._log_scale @constraints.dependent_property def support(self): @@ -92,19 +103,15 @@ def support(self): @property def mean(self): - return self._mean + return self._mean_non_std @property def variance(self): - return self._variance + return self._variance_non_std @property def entropy(self): - return self._entropy - - @property - def auc(self): - return self._Z + return self._entropy_non_std @staticmethod def _little_phi(x): @@ -118,54 +125,21 @@ def _big_phi(self, x): def _inv_big_phi(x): return CONST_SQRT_2 * (2 * x - 1).erfinv() - def cdf(self, value): + def cdf_truncated_standard_normal(self, value): if self._validate_args: self._validate_sample(value) return ((self._big_phi(value) - self._big_phi_a) / self._Z).clamp(0, 1) - def icdf(self, value): + def icdf_truncated_standard_normal(self, value): y = self._big_phi_a + value * self._Z y = y.clamp(self.eps, 1 - self.eps) return self._inv_big_phi(y) - def log_prob(self, value): + def log_prob_truncated_standard_normal(self, value): if self._validate_args: self._validate_sample(value) return CONST_LOG_INV_SQRT_2PI - self._log_Z - (value**2) * 0.5 - def rsample(self, sample_shape=None): - if sample_shape is None: - sample_shape = torch.Size([]) - shape = self._extended_shape(sample_shape) - p = torch.empty(shape, device=self.a.device).uniform_( - self._dtype_min_gt_0, self._dtype_max_lt_1 - ) - return self.icdf(p) - - -class TruncatedNormal(TruncatedStandardNormal): - """Truncated Normal distribution. 
- - https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf - """ - - has_rsample = True - - def __init__(self, loc, scale, a, b, validate_args=None): - scale = scale.clamp_min(self.eps) - self.loc, self.scale, a, b = broadcast_all(loc, scale, a, b) - self._non_std_a = a - self._non_std_b = b - a = (a - self.loc) / self.scale - b = (b - self.loc) / self.scale - super(TruncatedNormal, self).__init__( - a, b, validate_args=validate_args - ) - self._log_scale = self.scale.log() - self._mean = self._mean * self.scale + self.loc - self._variance = self._variance * self.scale**2 - self._entropy += self._log_scale - def _to_std_rv(self, value): return (value - self.loc) / self.scale @@ -173,10 +147,10 @@ def _from_std_rv(self, value): return value * self.scale + self.loc def cdf(self, value): - return super(TruncatedNormal, self).cdf(self._to_std_rv(value)) + return self.cdf_truncated_standard_normal(self._to_std_rv(value)) def icdf(self, value): - sample = self._from_std_rv(super().icdf(value)) + sample = self._from_std_rv(self.icdf_truncated_standard_normal(value)) # clamp data but keep gradients sample_clip = torch.stack( @@ -190,4 +164,13 @@ def icdf(self, value): def log_prob(self, value): value = self._to_std_rv(value) - return super(TruncatedNormal, self).log_prob(value) - self._log_scale + return self.log_prob_truncated_standard_normal(value) - self._log_scale + + def rsample(self, sample_shape=None): + if sample_shape is None: + sample_shape = torch.Size([]) + shape = self._extended_shape(sample_shape) + p = torch.empty(shape, device=self.a.device).uniform_( + self._dtype_min_gt_0, self._dtype_max_lt_1 + ) + return self.icdf(p) From 959d5034df4f36f15fdb080e9fda2ae7f6c981f4 Mon Sep 17 00:00:00 2001 From: Pedro Eduardo Mercado Lopez Date: Fri, 18 Aug 2023 17:42:06 +0200 Subject: [PATCH 08/15] simplify classes --- .../torch/distributions/truncated_normal.py | 239 ++++++++++++------ .../distributions/utils/truncated_normal.py | 176 ------------- 2 files changed, 159 insertions(+), 256 deletions(-) delete mode 100644 src/gluonts/torch/distributions/utils/truncated_normal.py diff --git a/src/gluonts/torch/distributions/truncated_normal.py b/src/gluonts/torch/distributions/truncated_normal.py index c07e66f610..2f50567143 100644 --- a/src/gluonts/torch/distributions/truncated_normal.py +++ b/src/gluonts/torch/distributions/truncated_normal.py @@ -11,26 +11,34 @@ # express or implied. See the License for the specific language governing # permissions and limitations under the License. 
-# mainly based from https://github.com/pytorch/rl/blob/main/torchrl/modules/distributions/continuous.py#L166 +# The implementation is strongly inspired from: +# - https://github.com/pytorch/rl/blob/main/torchrl/modules/distributions/truncated_normal.py +# - https://github.com/toshas/torch_truncnorm +import math from numbers import Number from typing import Dict, Optional, Tuple, Union import torch -from torch import distributions as D from torch.distributions import constraints +from torch.distributions.utils import broadcast_all import torch.nn.functional as F from .distribution_output import DistributionOutput from gluonts.core.component import validated from torch.distributions import Distribution -from .utils.truncated_normal import TruncatedNormal as _TruncatedNormal -# speeds up distribution construction -D.Distribution.set_default_validate_args(False) +CONST_SQRT_2 = math.sqrt(2) +CONST_INV_SQRT_2PI = 1 / math.sqrt(2 * math.pi) +CONST_INV_SQRT_2 = 1 / math.sqrt(2) +CONST_LOG_INV_SQRT_2PI = math.log(CONST_INV_SQRT_2PI) +CONST_LOG_SQRT_2PI_E = 0.5 * math.log(2 * math.pi * math.e) +torch.manual_seed(0) -class TruncatedNormal(D.Independent): +class TruncatedNormal(Distribution): + """Truncated Normal distribution.""" + """Implements a Truncated Normal distribution with location scaling. Location scaling prevents the location to be "too far" from 0, which ultimately @@ -58,98 +66,166 @@ class TruncatedNormal(D.Independent): tanh_loc (bool, optional): if ``True``, the above formula is used for the location scaling, otherwise the raw value is kept. Default is ``False``; + + References: + - https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + + This implementation is strongly based on: + - https://github.com/pytorch/rl/blob/main/torchrl/modules/distributions/truncated_normal.py + - https://github.com/toshas/torch_truncnorm """ - num_params: int = 2 - arg_constraints = { - "loc": constraints.real, - "scale": constraints.greater_than(1e-6), + "a": constraints.real, + "b": constraints.real, } + has_rsample = True + eps = 1e-6 def __init__( self, loc: torch.Tensor, scale: torch.Tensor, + min: Union[torch.Tensor, float], + max: Union[torch.Tensor, float], upscale: Union[torch.Tensor, float] = 5.0, - min: Union[torch.Tensor, float] = -1.0, - max: Union[torch.Tensor, float] = 1.0, tanh_loc: bool = False, ): - err_msg = ( - "TanhNormal max values must be strictly greater than min values" - ) - if isinstance(max, torch.Tensor) or isinstance(min, torch.Tensor): - if not (max > min).all(): - raise RuntimeError(err_msg) - elif isinstance(max, Number) and isinstance(min, Number): - if not max > min: - raise RuntimeError(err_msg) - else: - if not all(max > min): - raise RuntimeError(err_msg) - - if isinstance(max, torch.Tensor): - self.non_trivial_max = (max != 1.0).any() - else: - self.non_trivial_max = max != 1.0 - if isinstance(min, torch.Tensor): - self.non_trivial_min = (min != -1.0).any() - else: - self.non_trivial_min = min != -1.0 - self.tanh_loc = tanh_loc + scale = scale.clamp_min(self.eps) + if tanh_loc: + loc = (loc / upscale).tanh() * upscale + loc = loc + (max - min) / 2 + min - self.device = loc.device - self.upscale = ( - upscale - if not isinstance(upscale, torch.Tensor) - else upscale.to(self.device) + self.min = min + self.max = max + self.loc, self.scale, a, b = broadcast_all( + loc, scale, self.min, self.max ) + self._non_std_a = a + self._non_std_b = b + self.a = (a - self.loc) / self.scale + self.b = (b - self.loc) / self.scale - if isinstance(max, 
torch.Tensor): - max = max.to(self.device) - else: - max = torch.tensor(max, device=self.device) - if isinstance(min, torch.Tensor): - min = min.to(self.device) + if isinstance(a, Number) and isinstance(b, Number): + batch_shape = torch.Size() else: - min = torch.tensor(min, device=self.device) - self.min = min - self.max = max - self.update(loc, scale) - - def update(self, loc: torch.Tensor, scale: torch.Tensor) -> None: - if self.tanh_loc: - loc = (loc / self.upscale).tanh() * self.upscale - if self.non_trivial_max or self.non_trivial_min: - loc = loc + (self.max - self.min) / 2 + self.min - self.loc = loc - self.scale = scale - - base_dist = _TruncatedNormal( - loc, scale, self.min.expand_as(loc), self.max.expand_as(scale) + batch_shape = self.a.size() + + super(TruncatedNormal, self).__init__(batch_shape) + if self.a.dtype != self.b.dtype: + raise ValueError("Truncation bounds types are different") + if any( + (self.a >= self.b) + .view( + -1, + ) + .tolist() + ): + raise ValueError("Incorrect truncation range") + + eps = self.eps + self._dtype_min_gt_0 = eps + self._dtype_max_lt_1 = 1 - eps + self._little_phi_a = self._little_phi(self.a) + self._little_phi_b = self._little_phi(self.b) + self._big_phi_a = self._big_phi(self.a) + self._big_phi_b = self._big_phi(self.b) + self._Z = (self._big_phi_b - self._big_phi_a).clamp(eps, 1 - eps) + self._log_Z = self._Z.log() + little_phi_coeff_a = torch.nan_to_num(self.a, nan=math.nan) + little_phi_coeff_b = torch.nan_to_num(self.b, nan=math.nan) + self._lpbb_m_lpaa_d_Z = ( + self._little_phi_b * little_phi_coeff_b + - self._little_phi_a * little_phi_coeff_a + ) / self._Z + self._mean = -(self._little_phi_b - self._little_phi_a) / self._Z + self._variance = ( + 1 + - self._lpbb_m_lpaa_d_Z + - ((self._little_phi_b - self._little_phi_a) / self._Z) ** 2 + ) + self._entropy = ( + CONST_LOG_SQRT_2PI_E + self._log_Z - 0.5 * self._lpbb_m_lpaa_d_Z ) - super().__init__(base_dist, 1, validate_args=False) + self._log_scale = self.scale.log() + self._mean_non_std = self._mean * self.scale + self.loc + self._variance_non_std = self._variance * self.scale**2 + self._entropy_non_std = self._entropy + self._log_scale + + @constraints.dependent_property + def support(self): + return constraints.interval(self.a, self.b) @property - def mode(self): - m = self.base_dist.loc - a = self.base_dist._non_std_a + self.base_dist._dtype_min_gt_0 - b = self.base_dist._non_std_b - self.base_dist._dtype_min_gt_0 - m = torch.min(torch.stack([m, b], -1), dim=-1)[0] - return torch.max(torch.stack([m, a], -1), dim=-1)[0] - - def log_prob(self, value, **kwargs): - a = self.base_dist._non_std_a + self.base_dist._dtype_min_gt_0 - a = a.expand_as(value) - b = self.base_dist._non_std_b - self.base_dist._dtype_min_gt_0 - b = b.expand_as(value) - value = torch.min(torch.stack([value, b], -1), dim=-1)[0] - value = torch.max(torch.stack([value, a], -1), dim=-1)[0] - return self.base_dist.log_prob( - value - ) # original: return super().log_prob(value, **kwargs) + def mean(self): + return self._mean_non_std + + @property + def variance(self): + return self._variance_non_std + + @property + def entropy(self): + return self._entropy_non_std + + @staticmethod + def _little_phi(x): + return (-(x**2) * 0.5).exp() * CONST_INV_SQRT_2PI + + def _big_phi(self, x): + phi = 0.5 * (1 + (x * CONST_INV_SQRT_2).erf()) + return phi.clamp(self.eps, 1 - self.eps) + + @staticmethod + def _inv_big_phi(x): + return CONST_SQRT_2 * (2 * x - 1).erfinv() + + def cdf_truncated_standard_normal(self, value): + return 
((self._big_phi(value) - self._big_phi_a) / self._Z).clamp(0, 1) + + def icdf_truncated_standard_normal(self, value): + y = self._big_phi_a + value * self._Z + y = y.clamp(self.eps, 1 - self.eps) + return self._inv_big_phi(y) + + def log_prob_truncated_standard_normal(self, value): + return CONST_LOG_INV_SQRT_2PI - self._log_Z - (value**2) * 0.5 + + def _to_std_rv(self, value): + return (value - self.loc) / self.scale + + def _from_std_rv(self, value): + return value * self.scale + self.loc + + def cdf(self, value): + return self.cdf_truncated_standard_normal(self._to_std_rv(value)) + + def icdf(self, value): + sample = self._from_std_rv(self.icdf_truncated_standard_normal(value)) + + # clamp data but keep gradients + sample_clip = torch.stack( + [sample.detach(), self._non_std_a.detach().expand_as(sample)], 0 + ).max(0)[0] + sample_clip = torch.stack( + [sample_clip, self._non_std_b.detach().expand_as(sample)], 0 + ).min(0)[0] + sample.data.copy_(sample_clip) + return sample + + def log_prob(self, value): + value = self._to_std_rv(value) + return self.log_prob_truncated_standard_normal(value) - self._log_scale + + def rsample(self, sample_shape=None): + if sample_shape is None: + sample_shape = torch.Size([]) + shape = self._extended_shape(sample_shape) + p = torch.empty(shape, device=self.a.device).uniform_( + self._dtype_min_gt_0, self._dtype_max_lt_1 + ) + return self.icdf(p) class TruncatedNormalOutput(DistributionOutput): @@ -163,6 +239,9 @@ def __init__( upscale: float = 5.0, tanh_loc: bool = False, ) -> None: + + assert min < max, "max must be strictly greater than min" + super().__init__(self) self.min = min @@ -174,9 +253,9 @@ def __init__( "scale": 1, } - # @classmethod + @classmethod def domain_map( - self, + cls, loc: torch.Tensor, scale: torch.Tensor, ): diff --git a/src/gluonts/torch/distributions/utils/truncated_normal.py b/src/gluonts/torch/distributions/utils/truncated_normal.py deleted file mode 100644 index 158ddc1763..0000000000 --- a/src/gluonts/torch/distributions/utils/truncated_normal.py +++ /dev/null @@ -1,176 +0,0 @@ -# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). -# You may not use this file except in compliance with the License. -# A copy of the License is located at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# or in the "license" file accompanying this file. This file is distributed -# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either -# express or implied. See the License for the specific language governing -# permissions and limitations under the License. - -# The implementation is strongly inspired from: -# - https://github.com/pytorch/rl/blob/main/torchrl/modules/distributions/truncated_normal.py -# - https://github.com/toshas/torch_truncnorm - -import math -from numbers import Number - -import torch -from torch.distributions import constraints, Distribution -from torch.distributions.utils import broadcast_all - -CONST_SQRT_2 = math.sqrt(2) -CONST_INV_SQRT_2PI = 1 / math.sqrt(2 * math.pi) -CONST_INV_SQRT_2 = 1 / math.sqrt(2) -CONST_LOG_INV_SQRT_2PI = math.log(CONST_INV_SQRT_2PI) -CONST_LOG_SQRT_2PI_E = 0.5 * math.log(2 * math.pi * math.e) -torch.manual_seed(0) - - -class TruncatedNormal(Distribution): - """Truncated Normal distribution. 
- - https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf - """ - - arg_constraints = { - "a": constraints.real, - "b": constraints.real, - } - has_rsample = True - eps = 1e-6 - - def __init__(self, loc, scale, a, b): - scale = scale.clamp_min(self.eps) - self.loc, self.scale, a, b = broadcast_all(loc, scale, a, b) - self._non_std_a = a - self._non_std_b = b - self.a = (a - self.loc) / self.scale - self.b = (b - self.loc) / self.scale - - if isinstance(a, Number) and isinstance(b, Number): - batch_shape = torch.Size() - else: - batch_shape = self.a.size() - - super(TruncatedNormal, self).__init__(batch_shape) - if self.a.dtype != self.b.dtype: - raise ValueError("Truncation bounds types are different") - if any( - (self.a >= self.b) - .view( - -1, - ) - .tolist() - ): - raise ValueError("Incorrect truncation range") - - eps = self.eps - self._dtype_min_gt_0 = eps - self._dtype_max_lt_1 = 1 - eps - self._little_phi_a = self._little_phi(self.a) - self._little_phi_b = self._little_phi(self.b) - self._big_phi_a = self._big_phi(self.a) - self._big_phi_b = self._big_phi(self.b) - self._Z = (self._big_phi_b - self._big_phi_a).clamp(eps, 1 - eps) - self._log_Z = self._Z.log() - little_phi_coeff_a = torch.nan_to_num(self.a, nan=math.nan) - little_phi_coeff_b = torch.nan_to_num(self.b, nan=math.nan) - self._lpbb_m_lpaa_d_Z = ( - self._little_phi_b * little_phi_coeff_b - - self._little_phi_a * little_phi_coeff_a - ) / self._Z - self._mean = -(self._little_phi_b - self._little_phi_a) / self._Z - self._variance = ( - 1 - - self._lpbb_m_lpaa_d_Z - - ((self._little_phi_b - self._little_phi_a) / self._Z) ** 2 - ) - self._entropy = ( - CONST_LOG_SQRT_2PI_E + self._log_Z - 0.5 * self._lpbb_m_lpaa_d_Z - ) - self._log_scale = self.scale.log() - self._mean_non_std = self._mean * self.scale + self.loc - self._variance_non_std = self._variance * self.scale**2 - self._entropy_non_std = self._entropy + self._log_scale - - @constraints.dependent_property - def support(self): - return constraints.interval(self.a, self.b) - - @property - def mean(self): - return self._mean_non_std - - @property - def variance(self): - return self._variance_non_std - - @property - def entropy(self): - return self._entropy_non_std - - @staticmethod - def _little_phi(x): - return (-(x**2) * 0.5).exp() * CONST_INV_SQRT_2PI - - def _big_phi(self, x): - phi = 0.5 * (1 + (x * CONST_INV_SQRT_2).erf()) - return phi.clamp(self.eps, 1 - self.eps) - - @staticmethod - def _inv_big_phi(x): - return CONST_SQRT_2 * (2 * x - 1).erfinv() - - def cdf_truncated_standard_normal(self, value): - if self._validate_args: - self._validate_sample(value) - return ((self._big_phi(value) - self._big_phi_a) / self._Z).clamp(0, 1) - - def icdf_truncated_standard_normal(self, value): - y = self._big_phi_a + value * self._Z - y = y.clamp(self.eps, 1 - self.eps) - return self._inv_big_phi(y) - - def log_prob_truncated_standard_normal(self, value): - if self._validate_args: - self._validate_sample(value) - return CONST_LOG_INV_SQRT_2PI - self._log_Z - (value**2) * 0.5 - - def _to_std_rv(self, value): - return (value - self.loc) / self.scale - - def _from_std_rv(self, value): - return value * self.scale + self.loc - - def cdf(self, value): - return self.cdf_truncated_standard_normal(self._to_std_rv(value)) - - def icdf(self, value): - sample = self._from_std_rv(self.icdf_truncated_standard_normal(value)) - - # clamp data but keep gradients - sample_clip = torch.stack( - [sample.detach(), self._non_std_a.detach().expand_as(sample)], 0 - ).max(0)[0] - 
sample_clip = torch.stack(
-            [sample_clip, self._non_std_b.detach().expand_as(sample)], 0
-        ).min(0)[0]
-        sample.data.copy_(sample_clip)
-        return sample
-
-    def log_prob(self, value):
-        value = self._to_std_rv(value)
-        return self.log_prob_truncated_standard_normal(value) - self._log_scale
-
-    def rsample(self, sample_shape=None):
-        if sample_shape is None:
-            sample_shape = torch.Size([])
-        shape = self._extended_shape(sample_shape)
-        p = torch.empty(shape, device=self.a.device).uniform_(
-            self._dtype_min_gt_0, self._dtype_max_lt_1
-        )
-        return self.icdf(p)

From b93586eb3884df31181cca6ec5ea260dbf4d89bb Mon Sep 17 00:00:00 2001
From: Pedro Eduardo Mercado Lopez
Date: Fri, 18 Aug 2023 17:43:03 +0200
Subject: [PATCH 09/15] making black happy

---
 src/gluonts/torch/distributions/truncated_normal.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/src/gluonts/torch/distributions/truncated_normal.py b/src/gluonts/torch/distributions/truncated_normal.py
index 2f50567143..3bdc3ad215 100644
--- a/src/gluonts/torch/distributions/truncated_normal.py
+++ b/src/gluonts/torch/distributions/truncated_normal.py
@@ -33,7 +33,6 @@
 CONST_INV_SQRT_2 = 1 / math.sqrt(2)
 CONST_LOG_INV_SQRT_2PI = math.log(CONST_INV_SQRT_2PI)
 CONST_LOG_SQRT_2PI_E = 0.5 * math.log(2 * math.pi * math.e)
-torch.manual_seed(0)
 
 
 class TruncatedNormal(Distribution):
@@ -91,7 +90,6 @@ def __init__(
         upscale: Union[torch.Tensor, float] = 5.0,
         tanh_loc: bool = False,
     ):
-
         scale = scale.clamp_min(self.eps)
         if tanh_loc:
             loc = (loc / upscale).tanh() * upscale
@@ -239,7 +237,6 @@ def __init__(
         upscale: float = 5.0,
         tanh_loc: bool = False,
     ) -> None:
-
         assert min < max, "max must be strictly greater than min"
 
         super().__init__(self)

From ae9c46bbf7ad73337b0a4b23ac01dfa73337 Mon Sep 17 00:00:00 2001
From: Pedro Eduardo Mercado Lopez
Date: Fri, 18 Aug 2023 18:12:59 +0200
Subject: [PATCH 10/15] update constraints

---
 src/gluonts/torch/distributions/truncated_normal.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/src/gluonts/torch/distributions/truncated_normal.py b/src/gluonts/torch/distributions/truncated_normal.py
index 3bdc3ad215..87b31fef78 100644
--- a/src/gluonts/torch/distributions/truncated_normal.py
+++ b/src/gluonts/torch/distributions/truncated_normal.py
@@ -75,9 +75,10 @@ class TruncatedNormal(Distribution):
     """
 
     arg_constraints = {
-        "a": constraints.real,
-        "b": constraints.real,
+        "loc": constraints.real,
+        "scale": constraints.greater_than(1e-6),
     }
+
     has_rsample = True
     eps = 1e-6
 
@@ -97,6 +98,7 @@ def __init__(
 
         self.min = min
         self.max = max
+        self.upscale = upscale
         self.loc, self.scale, a, b = broadcast_all(
             loc, scale, self.min, self.max
         )

From e28724210840996f53518e914ac41ea36af5fee3 Mon Sep 17 00:00:00 2001
From: Pedro Eduardo Mercado Lopez
Date: Mon, 21 Aug 2023 10:45:13 +0200
Subject: [PATCH 11/15] add clamping to computation of log_prob

---
 src/gluonts/torch/distributions/truncated_normal.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/src/gluonts/torch/distributions/truncated_normal.py b/src/gluonts/torch/distributions/truncated_normal.py
index 87b31fef78..fc09d128c0 100644
--- a/src/gluonts/torch/distributions/truncated_normal.py
+++ b/src/gluonts/torch/distributions/truncated_normal.py
@@ -215,6 +215,12 @@ def icdf(self, value):
         return sample
 
     def log_prob(self, value):
+        a = self._non_std_a + self._dtype_min_gt_0
+        a = a.expand_as(value)
+        b = self._non_std_b - self._dtype_min_gt_0
+        b = b.expand_as(value)
+        value = torch.min(torch.stack([value, b], -1), dim=-1)[0]
+        value = 
torch.max(torch.stack([value, a], -1), dim=-1)[0] value = self._to_std_rv(value) return self.log_prob_truncated_standard_normal(value) - self._log_scale From 0e654312836dd32ea557d547e6e8c7d6128de002 Mon Sep 17 00:00:00 2001 From: Pedro Eduardo Mercado Lopez Date: Tue, 22 Aug 2023 10:51:26 +0200 Subject: [PATCH 12/15] add default values to lower/upper bounds --- src/gluonts/torch/distributions/truncated_normal.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/gluonts/torch/distributions/truncated_normal.py b/src/gluonts/torch/distributions/truncated_normal.py index fc09d128c0..44d8b58b9d 100644 --- a/src/gluonts/torch/distributions/truncated_normal.py +++ b/src/gluonts/torch/distributions/truncated_normal.py @@ -86,8 +86,8 @@ def __init__( self, loc: torch.Tensor, scale: torch.Tensor, - min: Union[torch.Tensor, float], - max: Union[torch.Tensor, float], + min: Union[torch.Tensor, float] = -1.0, + max: Union[torch.Tensor, float] = 1.0, upscale: Union[torch.Tensor, float] = 5.0, tanh_loc: bool = False, ): @@ -240,8 +240,8 @@ class TruncatedNormalOutput(DistributionOutput): @validated() def __init__( self, - min: float, - max: float, + min: float = -1.0, + max: float = 1.0, upscale: float = 5.0, tanh_loc: bool = False, ) -> None: From c10f145db2bd62065931a760ec03ff043aa64f8c Mon Sep 17 00:00:00 2001 From: Pedro Eduardo Mercado Lopez Date: Tue, 22 Aug 2023 14:14:04 +0200 Subject: [PATCH 13/15] fix doc-string --- .../torch/distributions/truncated_normal.py | 46 +++++++++---------- 1 file changed, 21 insertions(+), 25 deletions(-) diff --git a/src/gluonts/torch/distributions/truncated_normal.py b/src/gluonts/torch/distributions/truncated_normal.py index 44d8b58b9d..5509e4336d 100644 --- a/src/gluonts/torch/distributions/truncated_normal.py +++ b/src/gluonts/torch/distributions/truncated_normal.py @@ -11,10 +11,6 @@ # express or implied. See the License for the specific language governing # permissions and limitations under the License. -# The implementation is strongly inspired from: -# - https://github.com/pytorch/rl/blob/main/torchrl/modules/distributions/truncated_normal.py -# - https://github.com/toshas/torch_truncnorm - import math from numbers import Number from typing import Dict, Optional, Tuple, Union @@ -36,8 +32,6 @@ class TruncatedNormal(Distribution): - """Truncated Normal distribution.""" - """Implements a Truncated Normal distribution with location scaling. Location scaling prevents the location to be "too far" from 0, which ultimately @@ -49,26 +43,28 @@ class TruncatedNormal(Distribution): This behaviour can be disabled by switching off the tanh_loc parameter (see below). - - Args: - loc (torch.Tensor): normal distribution location parameter - scale (torch.Tensor): normal distribution sigma parameter (squared root of variance) - upscale (torch.Tensor or number, optional): 'a' scaling factor in the formula: - - .. math:: - loc = tanh(loc / upscale) * upscale. - - Default is 5.0 - - min (torch.Tensor or number, optional): minimum value of the distribution. Default = -1.0; - max (torch.Tensor or number, optional): maximum value of the distribution. Default = 1.0; - tanh_loc (bool, optional): if ``True``, the above formula is used for - the location scaling, otherwise the raw value is kept. 
- Default is ``False``; - - References: + Parameters + ---------- + loc (torch.Tensor): + normal distribution location parameter + scale (torch.Tensor): + normal distribution sigma parameter (squared root of variance) + min (torch.Tensor or number, optional): + minimum value of the distribution. Default = -1.0 + max (torch.Tensor or number, optional): + maximum value of the distribution. Default = 1.0 + upscale (torch.Tensor or number, optional): + scaling factor. Default = 5.0 + tanh_loc (bool, optional): if ``True``, the above formula is used for + the location scaling, otherwise the raw value is kept. + Default is ``False`` + + References + ---------- - https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf - + + Notes + ----- This implementation is strongly based on: - https://github.com/pytorch/rl/blob/main/torchrl/modules/distributions/truncated_normal.py - https://github.com/toshas/torch_truncnorm From c072d929eb6ec0ccad2f815a21b8f2ec83a1d8e5 Mon Sep 17 00:00:00 2001 From: Pedro Eduardo Mercado Lopez Date: Tue, 22 Aug 2023 16:21:47 +0200 Subject: [PATCH 14/15] remove types from docstring --- src/gluonts/torch/distributions/truncated_normal.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/gluonts/torch/distributions/truncated_normal.py b/src/gluonts/torch/distributions/truncated_normal.py index 5509e4336d..3a65d1ef48 100644 --- a/src/gluonts/torch/distributions/truncated_normal.py +++ b/src/gluonts/torch/distributions/truncated_normal.py @@ -45,17 +45,17 @@ class TruncatedNormal(Distribution): Parameters ---------- - loc (torch.Tensor): + loc: normal distribution location parameter - scale (torch.Tensor): + scale: normal distribution sigma parameter (squared root of variance) - min (torch.Tensor or number, optional): + min: minimum value of the distribution. Default = -1.0 - max (torch.Tensor or number, optional): + max: maximum value of the distribution. Default = 1.0 - upscale (torch.Tensor or number, optional): + upscale: scaling factor. Default = 5.0 - tanh_loc (bool, optional): if ``True``, the above formula is used for + tanh_loc: if ``True``, the above formula is used for the location scaling, otherwise the raw value is kept. Default is ``False`` From c6655987ed699b30931a12c083063627cdead179 Mon Sep 17 00:00:00 2001 From: Pedro Mercado <34275963+melopeo@users.noreply.github.com> Date: Tue, 22 Aug 2023 16:28:36 +0200 Subject: [PATCH 15/15] Fix docstring. Co-authored-by: Lorenzo Stella --- src/gluonts/torch/distributions/truncated_normal.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/gluonts/torch/distributions/truncated_normal.py b/src/gluonts/torch/distributions/truncated_normal.py index 3a65d1ef48..a3ff655b1f 100644 --- a/src/gluonts/torch/distributions/truncated_normal.py +++ b/src/gluonts/torch/distributions/truncated_normal.py @@ -55,7 +55,8 @@ class TruncatedNormal(Distribution): maximum value of the distribution. Default = 1.0 upscale: scaling factor. Default = 5.0 - tanh_loc: if ``True``, the above formula is used for + tanh_loc: + if ``True``, the above formula is used for the location scaling, otherwise the raw value is kept. Default is ``False``
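
Taken together, the series leaves `TruncatedNormal` and `TruncatedNormalOutput`
importable from `gluonts.torch.distributions` (PATCH 04). The sketch below shows
how the final API (after PATCH 15) is meant to be driven. It is illustrative
only: the tensor shapes and the [-2, 2] truncation bounds are assumptions made
for the example, not values taken from the patches.

    # Minimal usage sketch of the API after PATCH 15; shapes and the
    # [-2, 2] bounds are illustrative assumptions.
    import torch

    from gluonts.torch.distributions import (
        TruncatedNormal,
        TruncatedNormalOutput,
    )

    # Direct construction: unit normals truncated to [-2, 2].
    d = TruncatedNormal(
        loc=torch.zeros(3), scale=torch.ones(3), min=-2.0, max=2.0
    )
    samples = d.rsample(torch.Size([100]))  # reparameterized, shape (100, 3)
    assert (samples >= -2.0).all() and (samples <= 2.0).all()
    log_p = d.log_prob(samples)  # inputs are clamped into the support first

    # Through the DistributionOutput head, as a network projection would use
    # it: domain_map constrains the raw scale with softplus and squeezes the
    # parameter dimension, and `distribution` binds the constant
    # min/max/upscale/tanh_loc settings.
    head = TruncatedNormalOutput(min=-2.0, max=2.0)
    raw_loc, raw_scale = torch.randn(8, 1), torch.randn(8, 1)
    distr = head.distribution(head.domain_map(raw_loc, raw_scale))

Note that `distribution` deliberately overrides the parent class method here:
the truncation bounds are constant floats rather than projected tensors, so
they cannot go through `domain_map` and are injected at construction time
instead.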