From d3eed32867e4b75c1d0e35669459dc39e33c1b89 Mon Sep 17 00:00:00 2001 From: takuseno Date: Thu, 19 Dec 2019 16:40:14 +0900 Subject: [PATCH] Implement nnabla.experimental.distributions --- python/setup.py | 1 + .../experimental/distributions/__init__.py | 3 + .../distributions/distribution.py | 94 ++++++++++++ .../distributions/multivariate_normal.py | 138 ++++++++++++++++++ .../experimental/distributions/normal.py | 131 +++++++++++++++++ .../experimental/distributions/uniform.py | 135 +++++++++++++++++ .../distributions/distribution_test_util.py | 107 ++++++++++++++ .../distributions/test_multivariate_normal.py | 44 ++++++ .../experimental/distributions/test_normal.py | 23 +++ .../distributions/test_uniform.py | 23 +++ 10 files changed, 699 insertions(+) create mode 100644 python/src/nnabla/experimental/distributions/__init__.py create mode 100644 python/src/nnabla/experimental/distributions/distribution.py create mode 100644 python/src/nnabla/experimental/distributions/multivariate_normal.py create mode 100644 python/src/nnabla/experimental/distributions/normal.py create mode 100644 python/src/nnabla/experimental/distributions/uniform.py create mode 100644 python/test/experimental/distributions/distribution_test_util.py create mode 100644 python/test/experimental/distributions/test_multivariate_normal.py create mode 100644 python/test/experimental/distributions/test_normal.py create mode 100644 python/test/experimental/distributions/test_uniform.py diff --git a/python/setup.py b/python/setup.py index 4ca8d77c0..4f43e1c85 100644 --- a/python/setup.py +++ b/python/setup.py @@ -245,6 +245,7 @@ def extopts(library_name, library_dir): 'nnabla.experimental.graph_converters', 'nnabla.experimental.parametric_function_class', 'nnabla.experimental.trainers', + 'nnabla.experimental.distributions', 'nnabla.models', 'nnabla.models.imagenet', 'nnabla.models.object_detection', diff --git a/python/src/nnabla/experimental/distributions/__init__.py b/python/src/nnabla/experimental/distributions/__init__.py new file mode 100644 index 000000000..d5145cc11 --- /dev/null +++ b/python/src/nnabla/experimental/distributions/__init__.py @@ -0,0 +1,3 @@ +from nnabla.experimental.distributions.uniform import Uniform +from nnabla.experimental.distributions.normal import Normal +from nnabla.experimental.distributions.multivariate_normal import MultivariateNormal diff --git a/python/src/nnabla/experimental/distributions/distribution.py b/python/src/nnabla/experimental/distributions/distribution.py new file mode 100644 index 000000000..311b03757 --- /dev/null +++ b/python/src/nnabla/experimental/distributions/distribution.py @@ -0,0 +1,94 @@ +# Copyright (c) 2017 Sony Corporation. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import nnabla.functions as F + + +class Distribution(object): + """Distribution base class for distribution classes. + """ + + def entropy(self): + """Get entropy of distribution. + + Returns: + :class:`~nnabla.Variable`: N-D array. + + """ + raise NotImplementedError + + def mean(self): + """Get mean of distribution. + + Returns: + :class:`~nnabla.Variable`: N-D array. + + """ + raise NotImplementedError + + def stddev(self): + """Get standard deviation of distribution. + + Returns: + :class:`~nnabla.Variable`: N-D array. + + """ + raise NotImplementedError + + def variance(self): + """Get variance of distribution. + + Returns: + :class:`~nnabla.Variable`: N-D array. + + """ + raise NotImplementedError + + def prob(self, x): + """Get probability of sampled `x` from distribution. + + Args: + x (~nnabla.Variable): N-D array. + + Returns: + :class:`~nnabla.Variable`: N-D array. + + """ + raise NotImplementedError + + def sample(self, shape): + """Sample points from distribution. + + Args: + shape (:obj:`tuple`): Shape of sampled points. + + Returns: + :class:`~nnabla.Variable`: N-D array. + + """ + raise NotImplementedError + + def sample_n(self, n): + """Sample points from distribution :math:`n` times. + + Args: + n (int): The number of sampling points. + + Returns: + :class:`~nnabla.Variable`: N-D array. + + """ + samples = [self.sample() for _ in range(n)] + return F.stack(*samples, axis=1) diff --git a/python/src/nnabla/experimental/distributions/multivariate_normal.py b/python/src/nnabla/experimental/distributions/multivariate_normal.py new file mode 100644 index 000000000..6b8fbc153 --- /dev/null +++ b/python/src/nnabla/experimental/distributions/multivariate_normal.py @@ -0,0 +1,138 @@ +# Copyright (c) 2017 Sony Corporation. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import numpy as np +import nnabla.functions as F + +from .distribution import Distribution + + +class MultivariateNormal(Distribution): + """Multivariate normal distribution. + + Multivariate normal distribution defined as follows: + + .. math:: + + p(x | \mu, \Sigma) = \frac{1}{\sqrt{(2 \pi)^k \det(\Sigma)}} + \exp(-\frac{1}{2}(x - \mu)^T \Sigma^(-1) (x - \mu)) + + where :math:`k` is a rank of `\Sigma`. + + Args: + loc (~nnabla.Variable or numpy.ndarray): N-D array of :math:`\mu` in + definition. + scale (~nnabla.Variable or numpy.ndarray): N-D array of diagonal + entries of :math:`L` such that covariance matrix + :math:`\Sigma = L L^T`. + + """ + + def __init__(self, loc, scale): + assert loc.shape == scale.shape,\ + 'For now, loc and scale must have same shape.' + if isinstance(loc, np.ndarray): + loc = nn.Variable.from_numpy_array(loc) + loc.persistent = True + if isinstance(scale, np.ndarray): + scale = nn.Variable.from_numpy_array(scale) + scale.persistent = True + self.loc = loc + self.scale = scale + + def mean(self): + """Get mean of multivariate normal distribution. + + Returns: + :class:`~nnabla.Variable`: N-D array identical to :math:`\mu`. + + """ + # to avoid no parent error + return F.identity(self.loc) + + def variance(self): + """Get covariance matrix of multivariate normal distribution. + + .. math:: + + \Sigma = L L^T + + Returns: + :class:`~nnabla.Variable`: N-D array. + + """ + diag = self._diag_scale() + return F.batch_matmul(diag, diag, False, True) + + def prob(self, x): + """Get probability of `x` in multivariate normal distribution. + + .. math:: + + p(x | \mu, \Sigma) = \frac{1}{\sqrt{(2 \pi)^k \det(\Sigma)}} + \exp(-\frac{1}{2}(x - \mu)^T \Sigma^(-1) (x - \mu)) + + Args: + x (~nn.Variable): N-D array. + + Returns: + :class:`~nnabla.Variable`: N-D array. + + """ + k = self.loc.shape[1] + z = 1.0 / ((2 * np.pi) ** k * F.batch_det(self._diag_scale())) ** 0.5 + + diff = F.reshape(x - self.mean(), self.loc.shape + (1,), False) + inv = F.batch_inv(self._diag_scale()) + y = F.batch_matmul(diff, inv, True, False) + norm = F.reshape(F.batch_matmul(y, diff, False, False), (-1,), False) + return z * F.exp(-0.5 * norm) + + def entropy(self): + """Get entropy of multivariate normal distribution. + + .. math:: + + S = \frac{1}{2} \ln \det(2 \pi e \Sigma) + + Returns: + :class:`~nnabla.Variable`: N-D array. + + """ + det = F.batch_det(2.0 * np.pi * np.e * self._diag_scale()) + return 0.5 * F.log(det) + + def _diag_scale(self): + return F.matrix_diag(self.scale) + + def sample(self, shape=None): + """Sample points from multivariate normal distribution. + + .. math:: + + x \sim N(\mu, \Sigma) + + Args: + shape (:obj:`tuple`): Shape of sampled points. If this is omitted, + the returned shape is identical to :math:`\mu`. + + Returns: + :class:`~nnabla.Variable`: N-D array. + + """ + if shape is None: + shape = self.loc.shape + eps = F.randn(mu=0.0, sigma=1.0, shape=shape) + return self.mean() + self.scale * eps diff --git a/python/src/nnabla/experimental/distributions/normal.py b/python/src/nnabla/experimental/distributions/normal.py new file mode 100644 index 000000000..63661972c --- /dev/null +++ b/python/src/nnabla/experimental/distributions/normal.py @@ -0,0 +1,131 @@ +# Copyright (c) 2017 Sony Corporation. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import numpy as np +import nnabla.functions as F + +from .distribution import Distribution + + +class Normal(Distribution): + """Normal distribution. + + Normal distribution defined as follows: + + .. math:: + + p(x | \mu, \sigma^2) = \frac{1}{\sqrt{2 \pi \sigma^2}} + \exp(-\frac{(x - \mu)^2}{2\sigma^2}) + + Args: + loc (~nnabla.Variable or numpy.ndarray): N-D array of :math:`\mu` in + definition. + scale (~nnabla.Variable or numpy.ndarray): N-D array of :math:`\sigma` + in definition. + + """ + + def __init__(self, loc, scale): + assert loc.shape == scale.shape,\ + 'For now, loc and scale must have same shape.' + if isinstance(loc, np.ndarray): + loc = nn.Variable.from_numpy_array(loc) + loc.persistent = True + if isinstance(scale, np.ndarray): + scale = nn.Variable.from_numpy_array(scale) + scale.persistent = True + self.loc = loc + self.scale = scale + + def mean(self): + """Get mean of normal distribution. + + Returns: + :class:`~nnabla.Variable`: N-D array identical to :math:`\mu`. + + """ + # to avoid no parent error + return F.identity(self.loc) + + def stddev(self): + """Get standard deviation of normal distribution. + + Returns: + :class:`~nnabla.Variable`: N-D array identical to :math:`\sigma`. + + """ + # to avoid no parent error + return F.identity(self.scale) + + def variance(self): + """Get variance of normal distribution. + + Returns: + :class:`~nnabla.Variable`: N-D array defined as :math:`\sigma^2`. + + """ + return self.stddev() ** 2 + + def prob(self, x): + """Get probability of :math:`x` in normal distribution. + + .. math:: + + p(x | \mu, \sigma^2) = \frac{1}{\sqrt{2 \pi \sigma^2}} + \exp(-\frac{(x - \mu)^2}{2\sigma^2}) + + Args: + x (~nnabla.Variable): N-D array. + + Returns: + :class:`~nnabla.Variable`: N-D array. + + """ + z = 1.0 / (2 * np.pi * self.variance()) ** 0.5 + return z * F.exp(-0.5 * ((x - self.mean()) ** 2) / self.variance()) + + def entropy(self): + """Get entropy of normal distribution. + + .. math:: + + S = \frac{1}{2}\log(2 \pi e \sigma^2) + + Returns: + :class:`~nnabla.Variable`: N-D array. + + """ + return F.log(self.stddev()) + 0.5 * np.log(2.0 * np.pi * np.e) + + def sample(self, shape=None): + """Sample points from normal distribution. + + .. math:: + + x \sim N(\mu, \sigma^2) + + Args: + shape (:obj:`tuple`): Shape of sampled points. If this is omitted, + the returned shape is identical to + :math:`\mu` and :math:`\sigma`. + + Returns: + :class:`~nnabla.Variable`: N-D array. + + """ + if shape is None: + shape = self.loc.shape + eps = F.randn(mu=0.0, sigma=1.0, shape=shape) + return self.mean() + self.stddev() * eps diff --git a/python/src/nnabla/experimental/distributions/uniform.py b/python/src/nnabla/experimental/distributions/uniform.py new file mode 100644 index 000000000..2dc3d121e --- /dev/null +++ b/python/src/nnabla/experimental/distributions/uniform.py @@ -0,0 +1,135 @@ +# Copyright (c) 2017 Sony Corporation. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import numpy as np +import nnabla.functions as F + +from .distribution import Distribution + + +class Uniform(Distribution): + """Uniform distribution. + + Uniform distribution defined as :math:`x \sim U(low, high)`. + Values are uniformly sampled between :math:`low` and :math:`high`. + + Args: + low (~nnabla.Variable or numpy.ndarray): N-D array of :math:`low` in + definition. + high (~nnabla.Variable or numpy.ndarray): N-D arraya of :math:`high` in + definition. + + """ + + def __init__(self, low, high): + assert low.shape == high.shape,\ + 'For now, low and high must have same shape.' + if isinstance(low, np.ndarray): + low = nn.Variable.from_numpy_array(low) + low.persistent = True + if isinstance(high, np.ndarray): + high = nn.Variable.from_numpy_array(high) + high.persistent = True + self.low = low + self.high = high + + def mean(self): + """Get mean of uniform distribution. + + .. math:: + + \mu = low + \frac{high - low}{2} + + Returns: + :class:`~nnabla.Variable`: N-D array. + + """ + return self.low + (self.high - self.low) / 2.0 + + def stddev(self): + """Get standard deviation of uniform distribution. + + .. math:: + + \sigma = \frac{high - low}{\sqrt{12}} + + Returns: + :class:`~nnabla.Variable`: N-D array. + + """ + return (self.high - self.low) / np.sqrt(12.0) + + def variance(self): + """Get variance of uniform distribution. + + Returns: + :class:`~nnabla.Variable`: N-D array defined as :math:`\sigma^2`. + + """ + return self.stddev() ** 2 + + def prob(self, x): + """Get probability of :math:`x` in uniform distribution. + + .. math:: + + p(x | low, high) = \begin{cases} + \frac{1}{high - low} & (x \geq low \text{and} x \leq high) \\ + 0 & (otherwise) + \end{cases} + + Args: + x (~nnabla.Variable): N-D array. + + Returns: + :class:`~nnabla.Variable`: N-D array. + + """ + return 1.0 / (self.high - self.low) * F.less(self.low, x) \ + * F.greater(self.high, x) + + def entropy(self): + """Get entropy of uniform distribution. + + .. math:: + + S = \ln(high - low) + + Returns: + :class:`~nnabla.Variable`: N-D array. + + """ + return F.log(self.high - self.low) + + def sample(self, shape=None): + """Sample points from uniform distribution. + + .. math:: + + x \sim U(low, high) + + Args: + shape (:obj:`tuple`): Shape of sampled points. If this is omitted, + the returned shape is identical to + :math:`high` and :math:`low`. + + Returns: + :class:`~nnabla.Variable`: N-D array. + + """ + if shape is None: + shape = self.high.shape + eps = F.rand(low=0.0, high=1.0, shape=shape) + return self.low + (self.high - self.low) * eps diff --git a/python/test/experimental/distributions/distribution_test_util.py b/python/test/experimental/distributions/distribution_test_util.py new file mode 100644 index 000000000..3062fe4e3 --- /dev/null +++ b/python/test/experimental/distributions/distribution_test_util.py @@ -0,0 +1,107 @@ +import nnabla as nn +import nnabla.functions as F +import numpy as np +import scipy + +from nnabla.testing import assert_allclose + + +CHECK_COLUMNS = ['mean', 'stddev', 'variance', 'entropy'] +SCIPY_FUNCS = ['mean', 'std', 'var', 'entropy'] + + +def distribution_test_util(dist_fn, + scipy_fn, + param_fn, + skip_columns=[], + sample_shape=(10000, 10), + ref_sample_fn=None, + ref_columns={}, + ref_prob_fn=None): + # check mean and standard deviation of sampled values + _check_sample(dist_fn, scipy_fn, param_fn, ref_sample_fn) + + # check each parameters + _check_columns(dist_fn, scipy_fn, param_fn, skip_columns, sample_shape, + ref_columns) + + # check probability density function + _check_prob(dist_fn, scipy_fn, param_fn, sample_shape, ref_sample_fn, + ref_prob_fn) + + +def _check_sample(dist_fn, scipy_fn, param_fn, ref_sample_fn): + params = param_fn(shape=(10000, 10)) + + dist = dist_fn(*params) + + # nnabla sample + sample = dist.sample() + sample.forward(clear_buffer=True) + + if ref_sample_fn is None: + # scipy sample + scipy_dist = scipy_fn(*params) + ref_sample = scipy_dist.rvs(size=(10000, 10)) + else: + ref_sample = ref_sample_fn(*params, shape=(10000, 10)) + + assert_allclose(sample.d.mean(), ref_sample.mean(), atol=3e-2, rtol=3e-2) + assert_allclose(sample.d.std(), ref_sample.std(), atol=3e-2, rtol=3e-2) + + # nnabla sample_n + sample_n = dist.sample_n(2) + sample_n.forward(clear_buffer=True) + assert sample_n.d.shape == (10000, 2, 10) + + +def _check_prob(dist_fn, + scipy_fn, + param_fn, + sample_shape, + ref_sample_fn, + ref_prob_fn): + params = param_fn(shape=sample_shape) + dist = dist_fn(*params) + + if ref_sample_fn is None: + scipy_dist = scipy_fn(*params) + sample = scipy_dist.rvs(size=sample_shape) + else: + sample = ref_sample_fn(*params, shape=sample_shape) + + prob = dist.prob(nn.Variable.from_numpy_array(sample)) + prob.forward(clear_buffer=True) + + if ref_prob_fn is None: + scipy_dist = scipy_fn(*params) + ref_prob = scipy_dist.pdf(sample) + else: + ref_prob = ref_prob_fn(*params, sample=sample, shape=sample_shape) + + assert_allclose(prob.d, ref_prob, atol=3e-2, rtol=3e-2) + + +def _check_columns(dist_fn, + scipy_fn, + param_fn, + skip_columns, + sample_shape, + ref_columns): + param = param_fn(shape=sample_shape) + dist = dist_fn(*param) + scipy_dist = scipy_fn(*param) + + for i, column in enumerate(CHECK_COLUMNS): + if column in skip_columns: + continue + + v = getattr(dist, column)() + v.forward() + + if column in ref_columns: + ref_v = ref_columns[column](*param, shape=sample_shape) + else: + ref_v = getattr(scipy_dist, SCIPY_FUNCS[i])() + + assert_allclose(v.d, ref_v, atol=3e-2, rtol=3e-2) diff --git a/python/test/experimental/distributions/test_multivariate_normal.py b/python/test/experimental/distributions/test_multivariate_normal.py new file mode 100644 index 000000000..a3c9f7623 --- /dev/null +++ b/python/test/experimental/distributions/test_multivariate_normal.py @@ -0,0 +1,44 @@ +import numpy as np +import nnabla as nn + +from scipy import stats +from nnabla.experimental.distributions import MultivariateNormal +from distribution_test_util import distribution_test_util + + +def test_multivariate_normal(): + def param_fn(shape): + loc = np.random.random(shape) + scale = np.random.random(shape) + 1e-5 + return loc, scale + + def dist_fn(loc, scale): + loc = nn.Variable.from_numpy_array(loc) + scale = nn.Variable.from_numpy_array(scale) + return MultivariateNormal(loc=loc, scale=scale) + + def scipy_fn(loc, scale): + return stats.multivariate_normal(np.reshape(loc, (-1,)), + np.reshape(scale, (-1,))) + + def ref_sample_fn(loc, scale, shape): + return np.random.normal(loc, scale, size=shape) + + def ref_entropy(loc, scale, shape): + entropy = np.zeros(shape[0]) + for i, (l, s) in enumerate(zip(loc, scale)): + entropy[i] = stats.multivariate_normal.entropy(l, s) + return entropy + + def ref_prob_fn(loc, scale, sample, shape): + probs = np.zeros(shape[0]) + for i, (l, s) in enumerate(zip(loc, scale)): + probs[i] = stats.multivariate_normal.pdf(sample[i], l, s) + return probs + + # due to memory error at scipy with large dimension, use small sample size + distribution_test_util(dist_fn, scipy_fn, param_fn, + skip_columns=['mean', 'stddev', 'variance'], + sample_shape=(100, 10), ref_sample_fn=ref_sample_fn, + ref_columns={'entropy': ref_entropy}, + ref_prob_fn=ref_prob_fn) diff --git a/python/test/experimental/distributions/test_normal.py b/python/test/experimental/distributions/test_normal.py new file mode 100644 index 000000000..0513810a5 --- /dev/null +++ b/python/test/experimental/distributions/test_normal.py @@ -0,0 +1,23 @@ +import nnabla as nn +import numpy as np + +from scipy import stats +from nnabla.experimental.distributions import Normal +from distribution_test_util import distribution_test_util + + +def test_normal(): + def param_fn(shape): + loc = np.random.random(shape) + scale = np.random.random(shape) + 1e-5 + return loc, scale + + def dist_fn(loc, scale): + loc = nn.Variable.from_numpy_array(loc) + scale = nn.Variable.from_numpy_array(scale) + return Normal(loc=loc, scale=scale) + + def scipy_fn(loc, scale): + return stats.norm(loc, scale) + + distribution_test_util(dist_fn, scipy_fn, param_fn) diff --git a/python/test/experimental/distributions/test_uniform.py b/python/test/experimental/distributions/test_uniform.py new file mode 100644 index 000000000..832209386 --- /dev/null +++ b/python/test/experimental/distributions/test_uniform.py @@ -0,0 +1,23 @@ +import nnabla as nn +import numpy as np + +from scipy import stats +from nnabla.experimental.distributions import Uniform +from distribution_test_util import distribution_test_util + + +def test_uniform(): + def param_fn(shape): + low = np.random.random(shape) + high = low + np.random.random(shape) + return low, high + + def dist_fn(low, high): + low = nn.Variable.from_numpy_array(low) + high = nn.Variable.from_numpy_array(high) + return Uniform(low=low, high=high) + + def scipy_fn(low, high): + return stats.uniform(low, high - low) + + distribution_test_util(dist_fn, scipy_fn, param_fn)