From 7235df90a17fcf7f7aa18bb96df3d3c91a7bd53a Mon Sep 17 00:00:00 2001 From: lezcano Date: Fri, 22 Oct 2021 09:26:30 +0100 Subject: [PATCH 01/10] Updated readme and setup --- README.rst | 5 +++-- setup.py | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/README.rst b/README.rst index cec0326c..bb70f701 100644 --- a/README.rst +++ b/README.rst @@ -149,7 +149,7 @@ If one wants to use a parametrized tensor in different places in their model, or Of course, this ``with`` statement may be used simply inside the forward function where the parametrized layer is used several times. -These ideas fall in the context of parametrized optimization, where one wraps a tensor ``X`` with a function ``f``, and rather than using ``X``, uses ``f(X)``. Particular examples of this idea are pruning, weight normalization, and spectral normalization among others. This repository implements a framework to approach this kind of problems. The framework is currently `PR #33344`_ in PyTorch. All the functionality of this PR is located in `geotorch/parametrize.py`_. +These ideas fall in the context of parametrized optimization, where one wraps a tensor ``X`` with a function ``f``, and rather than using ``X``, uses ``f(X)``. Particular examples of this idea are pruning, weight normalization, and spectral normalization among others. This repository implements a framework to approach this kind of problems. This framework was accepted to core PyTorch 1.8. It can be found under `torch.nn.utils.parametrize`_ and `torch.nn.utils.parametrizations`_. As every space in GeoTorch is, at its core, a map from a flat space into a manifold, the tools implemented here also serve as a building block in normalizing flows. Using a factorized space such as |low|_ it is direct to compute the determinant of the transformation it defines, as we have direct access to the singular values of the layer. @@ -219,7 +219,8 @@ Please cite the following work if you found GeoTorch useful. This paper exposes :alt: License .. _here: https://github.com/Lezcano/geotorch/blob/master/examples/copying_problem.py#L16 -.. _PR #33344: https://github.com/pytorch/pytorch/pull/33344 +.. _torch.nn.utils.parametrize: https://pytorch.org/docs/stable/generated/torch.nn.utils.parametrize.register_parametrization.html +.. _torch.nn.utils.parametrizations: https://pytorch.org/docs/stable/generated/torch.nn.utils.parametrizations.orthogonal.html .. _geotorch/parametrize.py: https://github.com/Lezcano/geotorch/blob/master/geotorch/parametrize.py .. _examples/sequential_mnist.py: https://github.com/Lezcano/geotorch/blob/master/examples/sequential_mnist.py .. 
_examples/copying_problem.py: https://github.com/Lezcano/geotorch/blob/master/examples/copying_problem.py diff --git a/setup.py b/setup.py index 87cc1181..26165245 100644 --- a/setup.py +++ b/setup.py @@ -50,6 +50,6 @@ keywords=["Constrained Optimization", "Optimization on Manifolds", "Pytorch"], packages=find_packages(), python_requires=">=3.5", - install_requires=["torch>=1.5"], + install_requires=["torch>=1.8"], extras_require={"dev": DEV_REQUIRES, "test": TEST_REQUIRES}, ) From acdf7670ac17830f2163c2ef22a2798b06ed67a3 Mon Sep 17 00:00:00 2001 From: lezcano Date: Sun, 14 Nov 2021 19:35:24 +0000 Subject: [PATCH 02/10] Add support to parametrizations with multiple inputs --- .github/workflows/build.yml | 2 +- README.rst | 2 +- geotorch/almostorthogonal.py | 27 +- geotorch/constraints.py | 13 +- geotorch/fixedrank.py | 20 +- geotorch/lowrank.py | 117 +---- geotorch/parametrize.py | 881 +++++++++++++++++++++++------------ geotorch/pssdfixedrank.py | 28 +- geotorch/so.py | 24 +- geotorch/sphere.py | 6 +- geotorch/stiefel.py | 8 +- geotorch/symmetric.py | 102 +--- geotorch/utils.py | 7 +- test/test_integration.py | 77 +-- test/test_orthogonal.py | 5 +- 15 files changed, 684 insertions(+), 635 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 8023f4e7..7abf216d 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -33,7 +33,7 @@ jobs: if: ${{ matrix.os == 'windows-latest' }} run: | python -m pip install --upgrade pip - pip install torch===1.7.0+cpu -f https://download.pytorch.org/whl/torch_stable.html + pip install torch===1.10.0+cpu -f https://download.pytorch.org/whl/torch_stable.html pip install .[dev] - name: Lint with flake8 diff --git a/README.rst b/README.rst index bb70f701..67d9694c 100644 --- a/README.rst +++ b/README.rst @@ -135,7 +135,7 @@ You may try GeoTorch installing it with pip install git+https://github.com/Lezcano/geotorch/ -GeoTorch is tested in Linux, Mac, and Windows environments for Python >= 3.6. +GeoTorch is tested in Linux, Mac, and Windows environments for Python >= 3.6 and supports PyTorch >= 1.9 Sharing Weights, Parametrizations, and Normalizing Flows -------------------------------------------------------- diff --git a/geotorch/almostorthogonal.py b/geotorch/almostorthogonal.py index 1391ef5e..80d978a8 100644 --- a/geotorch/almostorthogonal.py +++ b/geotorch/almostorthogonal.py @@ -107,7 +107,7 @@ def in_manifold_singular_values(self, S, eps=1e-5): and ((S - 1.0).abs() <= lam).all().item() ) - def sample(self, distribution="uniform", init_=None, factorized=True): + def sample(self, distribution="uniform", init_=None): r""" Returns a randomly sampled orthogonal matrix according to the specified ``distribution``. The options are: @@ -142,30 +142,9 @@ def sample(self, distribution="uniform", init_=None, factorized=True): to some distribution. See `torch.init `_. Default: :math:`\operatorname{Uniform}(-\pi, \pi)` - factorized (bool): Optional. Return an SVD decomposition of the - sampled matrix as a tuple :math:`(U, \Sigma, V)`. - Using ``factorized=True`` is more efficient when the result is - used to initialize a parametrized tensor. 
- Default: ``True`` """ - with torch.no_grad(): - device = self[0].base.device - dtype = self[0].base.dtype - # Sample U and set S = 1, V = Id - U = self[0].sample(distribution=distribution, init_=init_) - S = torch.ones( - *(self.tensorial_size + (self.n,)), device=device, dtype=dtype - ) - V = torch.eye(self.n, device=device, dtype=dtype) - if len(self.tensorial_size) > 0: - V = V.repeat(*(self.tensorial_size + (1, 1))) - - if factorized: - return U, S, V - else: - Vt = V.transpose(-2, -1) - # Multiply the three of them, S as a diagonal matrix - return U @ (S.unsqueeze(-1).expand_as(Vt) * Vt) + # Sample an orthogonal matrix as U and return it + return self[0].sample(distribution=distribution, init_=init_) def extra_repr(self): return _extra_repr( diff --git a/geotorch/constraints.py b/geotorch/constraints.py index eb3fd78b..fee0939c 100644 --- a/geotorch/constraints.py +++ b/geotorch/constraints.py @@ -19,15 +19,16 @@ def _register_manifold(module, tensor_name, cls, *args): tensor = getattr(module, tensor_name) M = cls(tensor.size(), *args).to(device=tensor.device, dtype=tensor.dtype) - P.register_parametrization(module, tensor_name, M) # Initialize without checking in manifold X = M.sample() - param_list = module.parametrizations[tensor_name] - with torch.no_grad(): - for m in reversed(param_list): - X = m.right_inverse(X, check_in_manifold=False) - param_list.original.copy_(X) + if not P.is_parametrized(module, tensor_name): + with torch.no_grad(): + tensor.copy_(X) + else: + setattr(module, tensor_name, X) + + P.register_parametrization(module, tensor_name, M) return module diff --git a/geotorch/fixedrank.py b/geotorch/fixedrank.py index af81f82f..2b52f62f 100644 --- a/geotorch/fixedrank.py +++ b/geotorch/fixedrank.py @@ -89,7 +89,7 @@ def in_manifold_singular_values(self, S, eps=1e-5): infty_norm = D.abs().max(dim=-1).values return (infty_norm > eps).all().item() - def sample(self, init_=torch.nn.init.xavier_normal_, factorized=True, eps=5e-6): + def sample(self, init_=torch.nn.init.xavier_normal_, eps=5e-6): r""" Returns a randomly sampled matrix on the manifold by sampling a matrix according to ``init_`` and projecting it onto the manifold. @@ -110,11 +110,6 @@ def sample(self, init_=torch.nn.init.xavier_normal_, factorized=True, eps=5e-6): in place according to some distribution. See `torch.init `_. Default: ``torch.nn.init.xavier_normal_`` - factorized (bool): Optional. Return an SVD decomposition of the - sampled matrix as a tuple :math:`(U, \Sigma, V)`. - Using ``factorized=True`` is more efficient when the result is - used to initialize a parametrized tensor. - Default: ``True`` eps (float): Optional. Minimum singular value of the sampled matrix. 
Default: ``5e-6`` """ @@ -122,12 +117,7 @@ def sample(self, init_=torch.nn.init.xavier_normal_, factorized=True, eps=5e-6): with torch.no_grad(): # S >= 0, as given by torch.linalg.eigvalsh() S[S < eps] = eps - if factorized: - return U, S, V - else: - Vt = V.transpose(-2, -1) - # Multiply the three of them, S as a diagonal matrix - X = U @ (S.unsqueeze(-1).expand_as(Vt) * Vt) - if self.transposed: - X = X.transpose(-2, -1) - return X + X = (U * S) @ V.transpose(-2, -1) + if self.transposed: + X = X.transpose(-2, -1) + return X diff --git a/geotorch/lowrank.py b/geotorch/lowrank.py index 1ac51619..b5459444 100644 --- a/geotorch/lowrank.py +++ b/geotorch/lowrank.py @@ -1,13 +1,4 @@ import torch -from functools import partial - -try: - from torch.linalg import svd - - svd = partial(svd, full_matrices=False) -except ImportError: - from torch import svd - from .product import ProductManifold from .stiefel import Stiefel @@ -66,9 +57,7 @@ def frame(self, X): return U, S, V def submersion(self, U, S, V): - Vt = V.transpose(-2, -1) - # Multiply the three of them, S as a diagonal matrix - return U @ (S.unsqueeze(-1).expand_as(Vt) * Vt) + return (U * S) @ V.transpose(-2, -1) @transpose def forward(self, X): @@ -80,24 +69,19 @@ def frame_inv(self, X1, X2, X3): with torch.no_grad(): # X1 is lower-triangular # X2 is a vector - # X3 is upper-triangular + # X3 is lower-triangular size = self.tensorial_size + (self.n, self.k) - ret = torch.zeros(*size, dtype=X1.dtype, device=X1.device) + ret = torch.zeros(size, dtype=X1.dtype, device=X1.device) ret[..., : self.rank] += X1 ret[..., : self.rank, : self.rank] += torch.diag_embed(X2) - ret[..., : self.rank, :] += X3.transpose(-2, -1) + ret.transpose(-2, -1)[..., : self.rank] += X3 return ret def submersion_inv(self, X, check_in_manifold=True): - if isinstance(X, torch.Tensor): - U, S, V = svd(X) - if check_in_manifold and not self.in_manifold_singular_values(S): - raise InManifoldError(X, self) - else: - # We assume that we got he U S V factorized in a tuple / list - U, S, V = X - if check_in_manifold and not self.in_manifold_tuple(U, S, V): - raise InManifoldError(X, self) + U, S, Vt = torch.linalg.svd(X, full_matrices=False) + V = Vt.transpose(-2, -1) + if check_in_manifold and not self.in_manifold_singular_values(S): + raise InManifoldError(X, self) return U[..., : self.rank], S[..., : self.rank], V[..., : self.rank] @transpose @@ -126,74 +110,24 @@ def in_manifold_singular_values(self, S, eps=1e-5): infty_norm_err = D.abs().max(dim=-1).values return (infty_norm_err < eps).all() - def in_manifold_tuple(self, U, S, V, eps=1e-5): - return ( - self.in_manifold_singular_values(S, eps) - and self[0].in_manifold(U) - and self[1].in_manifold(S) - and self[2].in_manifold(V) - ) - def in_manifold(self, X, eps=1e-5): r""" - Checks that a matrix is in the manifold. The matrix may be given - factorized in a `3`-tuple :math:`(U, \Sigma, V)` of a matrix, vector, - and matrix representing an SVD of the matrix. - - - For tensors with more than 2 dimensions the first dimensions are - treated as batch dimensions. + Checks that a given matrix is in the manifold. Args: - X (torch.Tensor or tuple): The matrix to be checked or a tuple containing - :math:`(U, \Sigma, V)` as returned by ``torch.linalg.svd`` or - ``self.sample(factorized=True)``. + X (torch.Tensor or tuple): The input matrix or matrices of shape ``(*, n, k)``. eps (float): Optional. 
Threshold at which the singular values are considered to be zero Default: ``1e-5`` """ - if isinstance(X, tuple): - if len(X) == 3: - return self.in_manifold_tuple(X[0], X[1], X[2]) - else: - return False - else: - if X.size(-1) > X.size(-2): - X = X.transpose(-2, -1) - if X.size() != self.tensorial_size + (self.n, self.k): - return False - try: - S = torch.linalg.svdvals(X) - except AttributeError: - S = svd(X).S - return self.in_manifold_singular_values(S, eps) - - def project(self, X, factorized=True): - r""" - Project a matrix onto the manifold. - - If ``factorized==True``, it returns a tuple containing the SVD decomposition of - the matrix. - - Args: - X (torch.Tensor): Matrix to be projected onto the manifold - factorized (bool): Optional. Return an SVD decomposition of the - sampled matrix as a tuple :math:`(U, \Sigma, V)`. - Using ``factorized=True`` is more efficient when the result is - used to initialize a parametrized tensor. - Default: ``True`` - """ - U, S, V = svd(X) - U, S, V = U[..., : self.rank], S[..., : self.rank], V[..., : self.rank] - if factorized: - return U, S, V - else: - X = self.submersion(U, S, V) - if self.transposed: - X = X.transpose(-2, -1) - return X - - def sample(self, init_=torch.nn.init.xavier_normal_, factorized=True): + if X.size(-1) > X.size(-2): + X = X.transpose(-2, -1) + if X.size() != self.tensorial_size + (self.n, self.k): + return False + S = torch.linalg.svdvals(X) + return self.in_manifold_singular_values(S, eps) + + def sample(self, init_=torch.nn.init.xavier_normal_, factorized=False): r""" Returns a randomly sampled matrix on the manifold by sampling a matrix according to ``init_`` and projecting it onto the manifold. @@ -211,11 +145,6 @@ def sample(self, init_=torch.nn.init.xavier_normal_, factorized=True): in place according to some distribution. See `torch.init `_. Default: ``torch.nn.init.xavier_normal_`` - factorized (bool): Optional. Return an SVD decomposition of the - sampled matrix as a tuple :math:`(U, \Sigma, V)`. - Using ``factorized=True`` is more efficient when the result is - used to initialize a parametrized tensor. 
- Default: ``True`` """ with torch.no_grad(): device = self[0].base.device @@ -224,14 +153,12 @@ def sample(self, init_=torch.nn.init.xavier_normal_, factorized=True): *(self.tensorial_size + (self.n, self.k)), device=device, dtype=dtype ) init_(X) - U, S, V = svd(X) - U, S, V = U[..., : self.rank], S[..., : self.rank], V[..., : self.rank] + U, S, Vt = torch.linalg.svd(X, full_matrices=False) + U, S, Vt = U[..., : self.rank], S[..., : self.rank], Vt[..., : self.rank, :] if factorized: - return U, S, V + return U, S, Vt.transpose(-2, -1) else: - Vt = V.transpose(-2, -1) - # Multiply the three of them, S as a diagonal matrix - X = U @ (S.unsqueeze(-1).expand_as(Vt) * Vt) + X = (U * S) @ Vt if self.transposed: X = X.transpose(-2, -1) return X diff --git a/geotorch/parametrize.py b/geotorch/parametrize.py index f8c63ef5..2bd1a877 100644 --- a/geotorch/parametrize.py +++ b/geotorch/parametrize.py @@ -1,110 +1,293 @@ import torch -from torch.nn.modules.container import ModuleList, ModuleDict, Module -from torch.nn.parameter import Parameter -from torch import Tensor -from typing import Union, Optional, Iterable, Dict, Tuple -from contextlib import contextmanager +if int(torch.__version__.split(".")[1]) >= 10: + from torch.nn.utils.parametrize import * +else: + from torch.nn.modules.container import ModuleList, ModuleDict, Module + from torch.nn.parameter import Parameter + from torch import Tensor -_cache_enabled = 0 -_cache: Dict[Tuple[int, str], Optional[Tensor]] = {} + import collections + from contextlib import contextmanager + from typing import Union, Optional, Dict, Tuple, Sequence + _cache_enabled = 0 + _cache: Dict[Tuple[int, str], Optional[Tensor]] = {} -@contextmanager -def cached(): - r"""Context manager that enables the caching system within parametrizations - registered with :func:`register_parametrization`. + @contextmanager + def cached(): + r"""Context manager that enables the caching system within parametrizations + registered with :func:`register_parametrization`. - The value of the parametrized objects is computed and cached the first time - they are required when this context manager is active. The cached values are - discarded when leaving the context manager. + The value of the parametrized objects is computed and cached the first time + they are required when this context manager is active. The cached values are + discarded when leaving the context manager. - This is useful when using a parametrized parameter more than once in the forward pass. - An example of this is when parametrizing the recurrent kernel of an RNN or when - sharing weights. + This is useful when using a parametrized parameter more than once in the forward pass. + An example of this is when parametrizing the recurrent kernel of an RNN or when + sharing weights. - The simplest way to activate the cache is by wrapping the forward pass of the neural network + The simplest way to activate the cache is by wrapping the forward pass of the neural network - .. code-block:: python + .. code-block:: python - import torch.nn.utils.parametrize as P - ... - with P.cached(): - output = model(inputs) + import torch.nn.utils.parametrize as P + ... + with P.cached(): + output = model(inputs) - in training and evaluation. One may also wrap the parts of the modules that use - several times the parametrized tensors. For example, the loop of an RNN with a - parametrized recurrent kernel: + in training and evaluation. One may also wrap the parts of the modules that use + several times the parametrized tensors. 
For example, the loop of an RNN with a + parametrized recurrent kernel: - .. code-block:: python - - with P.cached(): - for x in xs: - out_rnn = self.rnn_cell(x, out_rnn) - """ - global _cache - global _cache_enabled - _cache_enabled += 1 - try: - yield - finally: - _cache_enabled -= 1 - if not _cache_enabled: - _cache = {} - - -class ParametrizationList(ModuleList): - r"""A sequential container that holds and manages the ``original`` parameter or buffer of - a parametrized :class:`torch.nn.Module`. It is the type of - ``module.parametrizations[tensor_name]`` when ``module[tensor_name]`` has been parametrized - with :func:`register_parametrization`. - - .. note :: - This class is used internally by :func:`register_parametrization`. It is documented - here for completeness. It should not be instantiated by the user. + def set_original_(self, value: Tensor) -> None: + r"""This method is called when assigning to a parametrized tensor. - Args: - modules (iterable): an iterable of modules representing the parametrizations - original (Parameter or Tensor): parameter or buffer that is parametrized - """ - original: Tensor - - def __init__( - self, modules: Iterable[Module], original: Union[Tensor, Parameter] - ) -> None: - super().__init__(modules) - if isinstance(original, Parameter): - self.register_parameter("original", original) + with P.cached(): + for x in xs: + out_rnn = self.rnn_cell(x, out_rnn) + """ + global _cache + global _cache_enabled + _cache_enabled += 1 + try: + yield + finally: + _cache_enabled -= 1 + if not _cache_enabled: + _cache = {} + + def _register_parameter_or_buffer(module, name, X): + if isinstance(X, Parameter): + module.register_parameter(name, X) else: - self.register_buffer("original", original) + module.register_buffer(name, X) - def set_original_(self, value: Tensor) -> None: - r"""This method is called when assigning to a parametrized tensor. + class ParametrizationList(ModuleList): + r"""A sequential container that holds and manages the ``original`` or ``original0``, ``original1``, ... + parameters or buffers of a parametrized :class:`torch.nn.Module`. - It calls the methods ``right_inverse`` (see :func:`register_parametrization`) - of the parametrizations in the inverse order that they have been registered. - Then, it assigns the result to ``self.original``. + It is the type of ``module.parametrizations[tensor_name]`` when ``module[tensor_name]`` + has been parametrized with :func:`register_parametrization`. - Args: - value (Tensor): Value to which initialize the module + If the first registered parmetrization has a ``right_inverse`` that returns one tensor or + does not have a ``right_inverse`` (in which case we assume that ``right_inverse`` is the identity), + it will hold the tensor under the name ``original``. + If it has a ``right_inverse`` that returns more than one tensor, these will be registered as + ``original0``, ``original1``, ... - Raises: - RuntimeError: if any of the parametrizations do not implement a ``right_inverse`` method + .. warning:: + This class is used internally by :func:`register_parametrization`. It is documented + here for completeness. It shall not be instantiated by the user. + + Args: + modules (sequence): sequence of modules representing the parametrizations + original (Parameter or Tensor): parameter or buffer that is parametrized + unsafe (bool): a boolean flag that denotes whether the parametrization + may change the dtype and shape of the tensor. 
Default: `False` + Warning: the parametrization is not checked for consistency upon registration. + Enable this flag at your own risk. """ - with torch.no_grad(): - # See https://github.com/pytorch/pytorch/issues/53103 - for module in reversed(self): # type: ignore - if hasattr(module, "right_inverse"): - value = module.right_inverse(value) - else: - raise RuntimeError( - "The parametrization '{}' does not implement a 'right_inverse' method. " - "Assigning to a parametrized tensor is only possible when all the parametrizations " - "implement a 'right_inverse' method.".format( - module.__class__.__name__ + original: Tensor + unsafe: bool + + def __init__( + self, + modules: Sequence[Module], + original: Union[Tensor, Parameter], + unsafe: bool = False, + ) -> None: + # We require this because we need to treat differently the first parametrization + # This should never throw, unless this class is used from the outside + if len(modules) == 0: + raise ValueError("ParametrizationList requires one or more modules.") + + super().__init__(modules) + self.unsafe = unsafe + + # In plain words: + # module.weight must keep its dtype and shape. + # Furthermore, if there is no right_inverse or the right_inverse returns a tensor, + # this should be of the same dtype as the original tensor + # + # We check that the following invariants hold: + # X = module.weight + # Y = param.right_inverse(X) + # assert isinstance(Y, Tensor) or + # (isinstance(Y, collections.abc.Sequence) and all(isinstance(t, Tensor) for t in Y)) + # Z = param(Y) if isisntance(Y, Tensor) else param(*Y) + # # Consistency checks + # assert X.dtype == Z.dtype and X.shape == Z.shape + # # If it has one input, this allows to be able to use set_ to be able to + # # move data to/from the original tensor without changing its id (which is what the + # # optimiser uses to track parameters) + # if isinstance(Y, Tensor) + # assert X.dtype == Y.dtype + # Below we use original = X, new = Y + + original_shape = original.shape + original_dtype = original.dtype + + # Compute new + with torch.no_grad(): + new = original + for module in reversed(self): # type: ignore[call-overload] + if hasattr(module, "right_inverse"): + try: + new = module.right_inverse(new) + except NotImplementedError: + pass + # else, or if it throws, we assume that right_inverse is the identity + + if not isinstance(new, Tensor) and not isinstance( + new, collections.abc.Sequence + ): + raise ValueError( + "'right_inverse' must return a Tensor or a Sequence of tensors (list, tuple...). " + f"Got {type(new).__name__}" + ) + + # Set the number of original tensors + self.is_tensor = isinstance(new, Tensor) + self.ntensors = 1 if self.is_tensor else len(new) + + # Register the tensor(s) + if self.is_tensor: + if original.dtype != new.dtype: + raise ValueError( + "When `right_inverse` outputs one tensor, it may not change the dtype.\n" + f"original.dtype: {original.dtype}\n" + f"right_inverse(original).dtype: {new.dtype}" + ) + # Set the original to original so that the user does not need to re-register the parameter + # manually in the optimiser + with torch.no_grad(): + original.set_(new) # type: ignore[call-overload] + _register_parameter_or_buffer(self, "original", original) + else: + for i, originali in enumerate(new): + if not isinstance(originali, Tensor): + raise ValueError( + "'right_inverse' must return a Tensor or a Sequence of tensors " + "(list, tuple...). " + f"Got element {i} of the sequence with type {type(originali).__name__}." 
) + + # If the original tensor was a Parameter that required grad, we expect the user to + # add the new parameters to the optimizer after registering the parametrization + # (this is documented) + if isinstance(original, Parameter): + originali = Parameter(originali) + originali.requires_grad_(original.requires_grad) + _register_parameter_or_buffer(self, f"original{i}", originali) + + if not self.unsafe: + # Consistency checks: + # Since f : A -> B, right_inverse : B -> A, Z and original should live in B + # Z = forward(right_inverse(original)) + Z = self() + if not isinstance(Z, Tensor): + raise ValueError( + f"A parametrization must return a tensor. Got {type(Z).__name__}." + ) + if Z.dtype != original_dtype: + raise ValueError( + "Registering a parametrization may not change the dtype of the tensor, unless `unsafe` flag is enabled.\n" + f"unparametrized dtype: {original_dtype}\n" + f"parametrized dtype: {Z.dtype}" + ) + if Z.shape != original_shape: + raise ValueError( + "Registering a parametrization may not change the shape of the tensor, unless `unsafe` flag is enabled.\n" + f"unparametrized shape: {original_shape}\n" + f"parametrized shape: {Z.shape}" ) - self.original.copy_(value) + + def right_inverse(self, value: Tensor) -> None: + r"""Calls the methods ``right_inverse`` (see :func:`register_parametrization`) + of the parametrizations in the inverse order they were registered in. + Then, it stores the result in ``self.original`` if ``right_inverse`` outputs one tensor + or in ``self.original0``, ``self.original1``, ... if it outputs several. + + Args: + value (Tensor): Value to which initialize the module + """ + # All the exceptions in this function should almost never throw. + # They could throw if, for example, right_inverse function returns a different + # dtype when given a different input, which should most likely be caused by a + # bug in the user's code + + with torch.no_grad(): + # See https://github.com/pytorch/pytorch/issues/53103 + for module in reversed(self): # type: ignore[call-overload] + if hasattr(module, "right_inverse"): + value = module.right_inverse(value) + else: + raise RuntimeError( + f"parametrization {type(module).__name__} does not implement " + "right_inverse." + ) + if self.is_tensor: + # These exceptions should only throw when a right_inverse function does not + # return the same dtype for every input, which should most likely be caused by a bug + if not isinstance(value, Tensor): + raise ValueError( + f"`right_inverse` should return a tensor. Got {type(value).__name__}" + ) + if value.dtype != self.original.dtype: + raise ValueError( + f"The tensor returned by `right_inverse` has dtype {value.dtype} " + f"while `original` has dtype {self.original.dtype}" + ) + # We know that the result is going to have the same dtype + self.original.set_(value) # type: ignore[call-overload] + else: + if not isinstance(value, collections.abc.Sequence): + raise ValueError( + "'right_inverse' must return a sequence of tensors. " + f"Got {type(value).__name__}." + ) + if len(value) != self.ntensors: + raise ValueError( + "'right_inverse' must return a sequence of tensors of length " + f"{self.ntensors}. Got a sequence of lenght {len(value)}." + ) + for i, tensor in enumerate(value): + original_i = getattr(self, f"original{i}") + if not isinstance(tensor, Tensor): + raise ValueError( + f"`right_inverse` must return a sequence of tensors. 
" + f"Got element {i} of type {type(tensor).__name__}" + ) + if original_i.dtype != tensor.dtype: + raise ValueError( + f"Tensor {i} returned by `right_inverse` has dtype {tensor.dtype} " + f"while `original{i}` has dtype {original_i.dtype}" + ) + original_i.set_(tensor) + + def forward(self) -> Tensor: + # Unpack the originals for the first parametrization + if self.is_tensor: + x = self[0](self.original) + else: + originals = ( + getattr(self, f"original{i}") for i in range(self.ntensors) + ) + x = self[0](*originals) + # It's not possible to call self[1:] here, so we have to be a bit more cryptic + # Also we want to skip all non-integer keys + curr_idx = 1 + while hasattr(self, str(curr_idx)): + x = self[curr_idx](x) + curr_idx += 1 + return x + + def _inject_new_class(module: Module) -> None: + r"""Sets up a module to be parametrized. + + This works by substituting the class of the module by a class + that extends it to be able to inject a property def forward(self) -> Tensor: x = self.original @@ -118,227 +301,283 @@ def forward(self) -> Tensor: ) return x - -def _inject_new_class(module: Module) -> None: - r"""Sets up the parametrization mechanism used by parametrizations. - - This works by substituting the class of the module by a class - that extends it to be able to inject a property - - Args: - module (nn.Module): module into which to inject the property - """ - cls = module.__class__ - - def getstate(self): - raise RuntimeError( - "Serialization of parametrized modules is only " - "supported through state_dict(). See:\n" - "https://pytorch.org/tutorials/beginner/saving_loading_models.html" - "#saving-loading-a-general-checkpoint-for-inference-and-or-resuming-training" + param_cls = type( + f"Parametrized{cls.__name__}", + (cls,), + { + "__getstate__": getstate, + }, ) - param_cls = type( - "Parametrized{}".format(cls.__name__), - (cls,), - { - "__getstate__": getstate, - }, - ) + module.__class__ = param_cls - module.__class__ = param_cls + def _inject_property(module: Module, tensor_name: str) -> None: + r"""Injects a property into module[tensor_name]. + It assumes that the class in the module has already been modified from its + original one using _inject_new_class and that the tensor under :attr:`tensor_name` + has already been moved out -def _inject_property(module: Module, tensor_name: str) -> None: - r"""Injects a property into module[tensor_name]. - - It assumes that the class in the module has already been modified from its - original one using _inject_new_class and that the tensor under :attr:`tensor_name` - has already been moved out - - Args: - module (nn.Module): module into which to inject the property - tensor_name (str): name of the name of the property to create - """ - # We check the precondition. - # This should never fire if register_parametrization is correctly implemented - assert not hasattr(module, tensor_name) - - def get_parametrized(self) -> Tensor: - global _cache + Args: + module (nn.Module): module into which to inject the property + tensor_name (str): name of the name of the property to create + """ + # We check the precondition. 
+ # This should never fire if register_parametrization is correctly implemented + assert not hasattr(module, tensor_name) - parametrization = self.parametrizations[tensor_name] - if _cache_enabled: + @torch.jit.unused + def get_cached_parametrization(parametrization) -> Tensor: + global _cache key = (id(module), tensor_name) tensor = _cache.get(key) if tensor is None: tensor = parametrization() _cache[key] = tensor return tensor - else: - # If caching is not active, this function just evaluates the parametrization - return parametrization() - def set_original(self, value: Tensor) -> None: - self.parametrizations[tensor_name].set_original_(value) + def get_parametrized(self) -> Tensor: + parametrization = self.parametrizations[tensor_name] + if _cache_enabled: + if torch.jit.is_scripting(): + # Scripting + raise RuntimeError( + "Caching is not implemented for scripting. " + "Either disable caching or avoid scripting." + ) + elif torch._C._get_tracing_state() is not None: + # Tracing + raise RuntimeError( + "Cannot trace a model while caching parametrizations." + ) + else: + return get_cached_parametrization(parametrization) + else: + # If caching is not active, this function just evaluates the parametrization + return parametrization() + + def set_original(self, value: Tensor) -> None: + self.parametrizations[tensor_name].right_inverse(value) + + setattr(module.__class__, tensor_name, property(get_parametrized, set_original)) + + def register_parametrization( + module: Module, + tensor_name: str, + parametrization: Module, + *, + unsafe: bool = False, + ) -> Module: + r"""Adds a parametrization to a tensor in a module. + + Assume that ``tensor_name="weight"`` for simplicity. When accessing ``module.weight``, + the module will return the parametrized version ``parametrization(module.weight)``. + If the original tensor requires a gradient, the backward pass will differentiate + through :attr:`parametrization`, and the optimizer will update the tensor accordingly. - setattr(module.__class__, tensor_name, property(get_parametrized, set_original)) +def _inject_new_class(module: Module) -> None: + r"""Sets up the parametrization mechanism used by parametrizations. + The list of parametrizations on the tensor ``weight`` will be accessible under + ``module.parametrizations.weight``. -def register_parametrization( - module: Module, tensor_name: str, parametrization: Module -) -> Module: - r"""Adds a parametrization to a tensor in a module. + Args: + module (nn.Module): module into which to inject the property + """ + cls = module.__class__ - Assume that ``tensor_name="weight"`` for simplicity. When accessing ``module.weight``, - the module will return the parametrized version ``parametrization(module.weight)``. - If the original tensor requires a gradient, the backward pass will differentiate - through the :attr:`parametrization`, and the optimizer will update the tensor accordingly. + Parametrizations may be concatenated by registering several parametrizations + on the same attribute. - The first time that a module registers a parametrization, this function will add an attribute - ``parametrizations`` to the module of type :class:`~ParametrizationList`. + The training mode of a registered parametrization is updated on registration + to match the training mode of the host module - The list of parametrizations on a tensor will be accessible under - ``module.parametrizations.weight``. 
+ Parametrized parameters and buffers have an inbuilt caching system that can be activated + using the context manager :func:`cached`. - The original tensor will be accessible under - ``module.parametrizations.weight.original``. + param_cls = type( + "Parametrized{}".format(cls.__name__), + (cls,), + { + "__getstate__": getstate, + }, + ) - Parametrizations may be concatenated by registering several parametrizations - on the same attribute. + module.__class__ = param_cls - Parametrized parameters and buffers have an inbuilt caching system that can be activated - using the context manager :func:`cached`. + def right_inverse(self, X: Tensor) -> Union[Tensor, Sequence[Tensor]] - A :attr:`parametrization` may optionally implement a method with signature + This method is called on the unparametrized tensor when the first parametrization + is registered to compute the initial value of the original tensor. + If this method is not implemented, the original tensor will be just the unparametrized tensor. - .. code-block:: python + If all the parametrizations registered on a tensor implement `right_inverse` it is possible + to initialize a parametrized tensor by assigning to it, as shown in the example below. - def right_inverse(self, X: Tensor) -> Tensor + It is possible for the first parametrization to depend on several inputs. + This may be implemented returning a tuple of tensors from ``right_inverse`` + (see the example implementation of a ``RankOne`` parametrization below). - If :attr:`parametrization` implements this method, it will be possible to assign - to the parametrized tensor. This may be used to initialize the tensor, as shown in the example. + In this case, the unconstrained tensors are also located under ``module.parametrizations.weight`` + with names ``original0``, ``original1``,... - In most situations, ``right_inverse`` will be a function such that - ``forward(right_inverse(X)) == X`` (see - `right inverse `_). - Sometimes, when the parametrization is not surjective, it may be reasonable - to relax this, as shown in the example below. + .. 
note:: - Args: - module (nn.Module): module on which to register the parametrization - tensor_name (str): name of the parameter or buffer on which to register - the parametrization - parametrization (nn.Module): the parametrization to register - - Returns: - Module: module - - Raises: - ValueError: if the module does not have a parameter or a buffer named :attr:`tensor_name` - - Examples: - >>> import torch - >>> import torch.nn.utils.parametrize as P - >>> - >>> class Symmetric(torch.nn.Module): - >>> def forward(self, X): - >>> return X.triu() + X.triu(1).T # Return a symmetric matrix - >>> - >>> def right_inverse(self, A): - >>> return A.triu() - >>> - >>> m = torch.nn.Linear(5, 5) - >>> P.register_parametrization(m, "weight", Symmetric()) - >>> print(torch.allclose(m.weight, m.weight.T)) # m.weight is now symmetric - True - >>> A = torch.rand(5, 5) - >>> A = A + A.T # A is now symmetric - >>> m.weight = A # Initialize the weight to be the symmetric matrix A - >>> print(torch.allclose(m.weight, A)) - True - """ - if is_parametrized(module, tensor_name): - # Just add the new parametrization to the parametrization list - module.parametrizations[tensor_name].append(parametrization) # type: ignore - elif tensor_name in module._buffers or tensor_name in module._parameters: - # Set the parametrization mechanism - # Fetch the original buffer or parameter - original = getattr(module, tensor_name) - # Delete the previous parameter or buffer - delattr(module, tensor_name) - # If this is the first parametrization registered on the module, - # we prepare the module to inject the property - if not is_parametrized(module): - # Change the class - _inject_new_class(module) - # Inject the a ``ModuleDict`` into the instance under module.parametrizations - module.parametrizations = ModuleDict() - # Add a property into the class - _inject_property(module, tensor_name) - # Add a ParametrizationList - module.parametrizations[tensor_name] = ParametrizationList( # type: ignore - [parametrization], original - ) - else: - raise ValueError( - "Module '{}' does not have a parameter, a buffer, or a " - "parametrized element with name '{}'".format(module, tensor_name) - ) - return module + If unsafe=False (default) both the forward and right_inverse methods will be called + once to perform a number of consistency checks. + If unsafe=True, then right_inverse will be called if the tensor is not parametrized, + and nothing will be called otherwise. + .. note:: -def is_parametrized(module: Module, tensor_name: Optional[str] = None) -> bool: - r"""Returns ``True`` if module has an active parametrization. + In most situations, ``right_inverse`` will be a function such that + ``forward(right_inverse(X)) == X`` (see + `right inverse `_). + Sometimes, when the parametrization is not surjective, it may be reasonable + to relax this. - If the argument :attr:`tensor_name` is specified, returns ``True`` if - ``module[tensor_name]`` is parametrized. + .. 
warning:: - Args: - module (nn.Module): module to query - name (str, optional): attribute in the module to query - Default: ``None`` - """ - parametrizations = getattr(module, "parametrizations", None) - if parametrizations is None or not isinstance(parametrizations, ModuleDict): - return False - if tensor_name is None: - # Check that there is at least one parametrized buffer or Parameter - return len(parametrizations) > 0 - else: - return tensor_name in parametrizations - - -def remove_parametrizations( - module: Module, tensor_name: str, leave_parametrized: bool = True -) -> Module: - r"""Removes the parametrizations on a tensor in a module. - - - If ``leave_parametrized=True``, ``module[tensor_name]`` will be set to - its current output. In this case, the parametrization shall not change the ``dtype`` - of the tensor. - - If ``leave_parametrized=False``, ``module[tensor_name]`` will be set to - the unparametrised tensor in ``module.parametrizations[tensor_name].original``. + If a parametrization depends on several inputs, :func:`~register_parametrization` + will register a number of new parameters. If such parametrization is registered + after the optimizer is created, these new parameters will need to be added manually + to the optimizer. See :meth:`torch.Optimizer.add_param_group`. - Args: - module (nn.Module): module from which remove the parametrization - tensor_name (str): name of the parametrization to be removed - leave_parametrized (bool, optional): leave the attribute :attr:`tensor_name` parametrized. - Default: ``True`` - - Returns: - Module: module - - Raises: - ValueError: if ``module[tensor_name]`` is not parametrized - ValueError: if ``leave_parametrized=True`` and the parametrization changes the size or dtype - of the tensor - """ + Args: + module (nn.Module): module on which to register the parametrization + tensor_name (str): name of the parameter or buffer on which to register + the parametrization + parametrization (nn.Module): the parametrization to register + Keyword args: + unsafe (bool): a boolean flag that denotes whether the parametrization + may change the dtype and shape of the tensor. Default: `False` + Warning: the parametrization is not checked for consistency upon registration. + Enable this flag at your own risk. 
+ + Raises: + ValueError: if the module does not have a parameter or a buffer named :attr:`tensor_name` + + Examples: + >>> import torch + >>> import torch.nn as nn + >>> import torch.nn.utils.parametrize as P + >>> + >>> class Symmetric(nn.Module): + >>> def forward(self, X): + >>> return X.triu() + X.triu(1).T # Return a symmetric matrix + >>> + >>> def right_inverse(self, A): + >>> return A.triu() + >>> + >>> m = nn.Linear(5, 5) + >>> P.register_parametrization(m, "weight", Symmetric()) + >>> print(torch.allclose(m.weight, m.weight.T)) # m.weight is now symmetric + True + >>> A = torch.rand(5, 5) + >>> A = A + A.T # A is now symmetric + >>> m.weight = A # Initialize the weight to be the symmetric matrix A + >>> print(torch.allclose(m.weight, A)) + True + + >>> class RankOne(nn.Module): + >>> def forward(self, x, y): + >>> # Form a rank 1 matrix multiplying two vectors + >>> return x.unsqueeze(-1) @ y.unsqueeze(-2) + >>> + >>> def right_inverse(self, Z): + >>> # Project Z onto the rank 1 matrices + >>> U, S, Vh = torch.linalg.svd(Z, full_matrices=False) + >>> # Return rescaled singular vectors + >>> s0_sqrt = S[0].sqrt().unsqueeze(-1) + >>> return U[..., :, 0] * s0_sqrt, Vh[..., 0, :] * s0_sqrt + >>> + >>> linear_rank_one = P.register_parametrization(nn.Linear(4, 4), "weight", RankOne()) + >>> print(torch.linalg.matrix_rank(linear_rank_one.weight).item()) + 1 - if not is_parametrized(module, tensor_name): - raise ValueError( - "Module {} does not have a parametrization on {}".format( - module, tensor_name + """ + parametrization.train(module.training) + if is_parametrized(module, tensor_name): + # Correctness checks. + # If A is the space of tensors with shape and dtype equal to module.weight + # we check that parametrization.forward and parametrization.right_inverse are + # functions from A to A + if not unsafe: + Y = getattr(module, tensor_name) + X = parametrization(Y) + if not isinstance(X, Tensor): + raise ValueError( + f"A parametrization must return a tensor. Got {type(X).__name__}." + ) + if X.dtype != Y.dtype: + raise ValueError( + "Registering a parametrization may not change the dtype of the tensor, unless the `unsafe` flag is enabled.\n" + f"module.{tensor_name}.dtype: {Y.dtype}\n" + f"parametrization(module.{tensor_name}).dtype: {X.dtype}" + ) + if X.shape != Y.shape: + raise ValueError( + "Registering a parametrization may not change the shape of the tensor, unless the `unsafe` flag is enabled.\n" + f"module.{tensor_name}.shape: {Y.shape}\n" + f"parametrization(module.{tensor_name}).shape: {X.shape}" + ) + if hasattr(parametrization, "right_inverse"): + try: + Z = parametrization.right_inverse(X) # type: ignore[operator] + except NotImplementedError: + pass + else: + if not isinstance(Z, Tensor): + raise ValueError( + f"parametrization.right_inverse must return a tensor. 
Got: {type(Z).__name__}" + ) + if Z.dtype != Y.dtype: + raise ValueError( + "The tensor returned by parametrization.right_inverse must have the same dtype " + f"as module.{tensor_name}, unless the `unsafe` flag is enabled.\n" + f"module.{tensor_name}.dtype: {Y.dtype}\n" + f"returned dtype: {Z.dtype}" + ) + if Z.shape != Y.shape: + raise ValueError( + "The tensor returned by parametrization.right_inverse must have the same shape " + f"as module.{tensor_name}, unless the `unsafe` flag is enabled.\n" + f"module.{tensor_name}.shape: {Y.shape}\n" + f"returned shape: {Z.shape}" + ) + # else right_inverse is assumed to be the identity + + # add the new parametrization to the parametrization list + assert isinstance(module.parametrizations, ModuleDict) # Make mypy happy + module.parametrizations[tensor_name].append(parametrization) + # If unsafe was True in previous parametrization, keep it enabled + module.parametrizations[tensor_name].unsafe |= unsafe # type: ignore[index, union-attr] + elif tensor_name in module._buffers or tensor_name in module._parameters: + # Set the parametrization mechanism + # Fetch the original buffer or parameter + original = getattr(module, tensor_name) + # We create this early to check for possible errors + parametrizations = ParametrizationList( + [parametrization], original, unsafe=unsafe + ) + # Delete the previous parameter or buffer + delattr(module, tensor_name) + # If this is the first parametrization registered on the module, + # we prepare the module to inject the property + if not is_parametrized(module): + # Change the class + _inject_new_class(module) + # Inject a ``ModuleDict`` into the instance under module.parametrizations + module.parametrizations = ModuleDict() + # Add a property into the class + _inject_property(module, tensor_name) + # Add a ParametrizationList + assert isinstance(module.parametrizations, ModuleDict) # Make mypy happy + module.parametrizations[tensor_name] = parametrizations + else: + raise ValueError( + f"Module '{module}' does not have a parameter, a buffer, or a " + f"parametrized element with name '{tensor_name}'" ) ) @@ -353,27 +592,79 @@ def remove_parametrizations( with torch.no_grad(): original.set_(t) else: + return tensor_name in parametrizations + + def remove_parametrizations( + module: Module, tensor_name: str, leave_parametrized: bool = True + ) -> Module: + r"""Removes the parametrizations on a tensor in a module. + + - If ``leave_parametrized=True``, ``module[tensor_name]`` will be set to + its current output. In this case, the parametrization shall not change the ``dtype`` + of the tensor. + - If ``leave_parametrized=False``, ``module[tensor_name]`` will be set to + the unparametrised tensor in ``module.parametrizations[tensor_name].original``. + This is only possible when the parametrization depends on just one tensor. + + Args: + module (nn.Module): module from which remove the parametrization + tensor_name (str): name of the parametrization to be removed + leave_parametrized (bool, optional): leave the attribute :attr:`tensor_name` parametrized. + Default: ``True`` + + Returns: + Module: module + + Raises: + ValueError: if ``module[tensor_name]`` is not parametrized + ValueError: if ``leave_parametrized=False`` and the parametrization depends on several tensors + """ + + if not is_parametrized(module, tensor_name): raise ValueError( - "The parametrization changes the dtype of the tensor from {} to {}. 
" - "It is not supported to leave the tensor parametrized (`leave_parametrized=True`) " - "in this case.".format(original.dtype, t.dtype) + f"Module {module} does not have a parametrization on {tensor_name}" ) - # Delete the property that manages the parametrization - delattr(module.__class__, tensor_name) - # Delete the ParametrizationList - del module.parametrizations[tensor_name] # type: ignore - - # Restore the parameter / buffer into the main class - if isinstance(original, Parameter): - module.register_parameter(tensor_name, original) - else: - module.register_buffer(tensor_name, original) - - # Roll back the parametrized class if no other buffer or parameter - # is currently parametrized in this class - if not is_parametrized(module): - delattr(module, "parametrizations") - # Restore class - orig_cls = module.__class__.__bases__[0] - module.__class__ = orig_cls - return module + + # Fetch the original tensor + assert isinstance(module.parametrizations, ModuleDict) # Make mypy happy + parametrizations = module.parametrizations[tensor_name] + if parametrizations.is_tensor: + original = parametrizations.original + if leave_parametrized: + with torch.no_grad(): + t = getattr(module, tensor_name) + # We know they have the same dtype because we have checked this when registering the + # parametrizations. As such, we can use set_ + # We do this so that the parameter does not to change the id() + # This way the user does not need to update the optimizer + with torch.no_grad(): + original.set_(t) + else: + if leave_parametrized: + # We cannot use no_grad because we need to know whether one or more + # original tensors required grad + t = getattr(module, tensor_name) + # We'll have to trust the user to add it to the optimizer + original = Parameter(t) if t.requires_grad else t + else: + raise ValueError( + "Cannot leave unparametrized (`leave_parametrized=False`) a tensor " + "that is parametrized in terms of a sequence of tensors." + ) + + # Delete the property that manages the parametrization + delattr(module.__class__, tensor_name) + # Delete the ParametrizationList + del module.parametrizations[tensor_name] + + # Restore the parameter / buffer into the main class + _register_parameter_or_buffer(module, tensor_name, original) + + # Roll back the parametrized class if no other buffer or parameter + # is currently parametrized in this class + if not is_parametrized(module): + delattr(module, "parametrizations") + # Restore class + orig_cls = module.__class__.__bases__[0] + module.__class__ = orig_cls + return module diff --git a/geotorch/pssdfixedrank.py b/geotorch/pssdfixedrank.py index ef6ed7a6..c6edc094 100644 --- a/geotorch/pssdfixedrank.py +++ b/geotorch/pssdfixedrank.py @@ -51,23 +51,20 @@ def parse_f(f): def in_manifold_eigen(self, L, eps=1e-6): r""" - Checks that an ordered vector of eigenvalues values is in the manifold. - - For tensors with more than 1 dimension the first dimensions are - treated as batch dimensions. + Checks that an ascending ordered vector of eigenvalues is in the manifold. Args: - L (torch.Tensor): Vector of eigenvalues + L (torch.Tensor): Vector of eigenvalues of shape `(*, rank)` eps (float): Optional. 
Threshold at which the eigenvalues are considered to be zero Default: ``1e-6`` """ return ( super().in_manifold_eigen(L, eps) - and (L[..., : self.rank] >= eps).all().item() + and (L[..., -self.rank :] >= eps).all().item() ) - def sample(self, init_=torch.nn.init.xavier_normal_, factorized=True, eps=5e-6): + def sample(self, init_=torch.nn.init.xavier_normal_, eps=5e-6): r""" Returns a randomly sampled matrix on the manifold as @@ -92,22 +89,11 @@ def sample(self, init_=torch.nn.init.xavier_normal_, factorized=True, eps=5e-6): to some distribution. See `torch.init `_. Default: ``torch.nn.init.xavier_normal_`` - factorized (bool): Optional. Return the tuple :math:`(\Lambda, Q)` with an - eigenvalue decomposition of the sampled matrix. This can also be used - to initialize the layer. - Default: ``True`` eps (float): Optional. Minimum eigenvalue of the sampled matrix. Default: ``5e-6`` """ L, Q = super().sample(factorized=True, init_=init_) with torch.no_grad(): - # S >= 0, as given by torch.linalg.eigvalsh() - small = L < eps - L[small] = eps - if factorized: - return L, Q - else: - # Project onto the manifold - Qt = Q.transpose(-2, -1) - # Multiply the three of them as Q\LambdaQ^T - return Q @ (L.unsqueeze(-1).expand_as(Qt) * Qt) + # L >= 0, as given by torch.linalg.eigvalsh() + L[L < eps] = eps + return (Q * L) @ Q.transpose(-2, -1) diff --git a/geotorch/so.py b/geotorch/so.py index 7262e09c..20d6c6de 100644 --- a/geotorch/so.py +++ b/geotorch/so.py @@ -3,27 +3,15 @@ from torch import nn try: - from torch.linalg import qr + from torch.linalg import matrix_exp as expm except ImportError: - from torch import qr + from torch import matrix_exp as expm from .utils import _extra_repr from .skew import Skew - -try: - from torch import matrix_exp as expm -except ImportError: - from .linalg.expm import expm from .exceptions import NonSquareError, VectorError, InManifoldError -def solve(A, B): - try: - return torch.linalg.solve(A, B) - except AttributeError: - return torch.solve(B, A).solution - - def _has_orthonormal_columns(X, eps): k = X.size(-1) Id = torch.eye(k, dtype=X.dtype, device=X.device) @@ -34,7 +22,7 @@ def cayley_map(X): # compute (I+X/2)(I-X/2)^{-1} n = X.size(-1) Id = torch.eye(n, dtype=X.dtype, device=X.device) - return solve(Id.add(X, alpha=-0.5), Id.add(X, alpha=-0.5)) + return torch.linalg.solve(Id.add(X, alpha=-0.5), Id.add(X, alpha=0.5)) class SO(nn.Module): @@ -194,11 +182,11 @@ def uniform_init_(tensor): x = torch.empty_like(tensor).normal_(0, 1) if transpose: x.transpose_(-2, -1) - q, r = qr(x) + q, r = torch.linalg.qr(x) # Make uniform (diag r >= 0) d = r.diagonal(dim1=-2, dim2=-1).sign() - q *= d.unsqueeze(-2).expand_as(q) + q *= d if transpose: q.transpose_(-2, -1) @@ -207,7 +195,7 @@ def uniform_init_(tensor): if n == k: mask = (torch.det(q) > 0.0).float() mask[mask == 0.0] = -1.0 - mask = mask.unsqueeze(-1).unsqueeze(-1).expand_as(q) + mask = mask.unsqueeze(-1).unsqueeze(-1) q[..., 0] *= mask[..., 0] tensor.copy_(q) return tensor diff --git a/geotorch/sphere.py b/geotorch/sphere.py index 0ec48eb5..acb5ffd6 100644 --- a/geotorch/sphere.py +++ b/geotorch/sphere.py @@ -89,8 +89,7 @@ def in_manifold(self, x, eps=1e-5): Args: X (torch.Tensor): The vector to be checked. eps (float): Optional. Threshold at which the norm is considered - to be equal to `1`. - Default: ``1e-5`` + to be equal to ``1``. Default: ``1e-5`` """ return _in_sphere(x, self.radius, eps) @@ -162,8 +161,7 @@ def in_manifold(self, x, eps=1e-5): Args: X (torch.Tensor): The vector to be checked. 
eps (float): Optional. Threshold at which the norm is considered - to be equal to `1`. - Default: ``1e-5`` + to be equal to ``1``. Default: ``1e-5`` """ return _in_sphere(x, self.radius, eps) diff --git a/geotorch/stiefel.py b/geotorch/stiefel.py index a13cf367..3b8cb302 100644 --- a/geotorch/stiefel.py +++ b/geotorch/stiefel.py @@ -1,11 +1,5 @@ import torch -try: - from torch.linalg import qr -except ImportError: - from torch import qr - - from .utils import transpose, _extra_repr from .so import SO, _has_orthonormal_columns @@ -66,7 +60,7 @@ def right_inverse(self, X, check_in_manifold=True): for _ in range(2): N = N - X @ (X.transpose(-2, -1) @ N) # And make it an orthonormal base of the image - N = qr(N).Q + N = torch.linalg.qr(N).Q X = torch.cat([X, N], dim=-1) return super().right_inverse(X, check_in_manifold=False)[..., : self.k] diff --git a/geotorch/symmetric.py b/geotorch/symmetric.py index da07ec12..d7a4bc01 100644 --- a/geotorch/symmetric.py +++ b/geotorch/symmetric.py @@ -1,18 +1,5 @@ import torch from torch import nn -from functools import partial - -try: - from torch.linalg import eigh - from torch.linalg import eigvalsh -except ImportError: - from torch import symeig - - eigh = partial(symeig, eigenvectors=True) - - def eigvalsh(X): - return symeig(X, eigenvectors=False).eigenvalues - from .product import ProductManifold from .stiefel import Stiefel @@ -27,14 +14,6 @@ def eigvalsh(X): from .utils import _extra_repr -def _decreasing_eigh(X, eigenvectors): - if eigenvectors: - L, Q = eigh(X) - return L.flip(-1), Q.flip(-1) - else: - return eigvalsh(X).flip(-1) - - class Symmetric(nn.Module): def __init__(self, lower=True): r""" @@ -141,9 +120,7 @@ def frame(self, X): def submersion(self, Q, L): L = self.f(L) - Qt = Q.transpose(-2, -1) - # Multiply the three of them as Q\LambdaQ^T - return Q @ (L.unsqueeze(-1).expand_as(Qt) * Qt) + return (Q * L) @ Q.transpose(-2, -1) def forward(self, X): X = self.frame(X) @@ -159,22 +136,15 @@ def frame_inv(self, X1, X2): return ret def submersion_inv(self, X, check_in_manifold=True): - if isinstance(X, torch.Tensor): - with torch.no_grad(): - L, Q = _decreasing_eigh(X, eigenvectors=True) - if check_in_manifold and not self.in_manifold_eigen(L): - raise InManifoldError(X, self) - else: - # We assume that we got the L, Q factorized in a tuple / list - L, Q = X - if check_in_manifold and not self.in_manifold_tuple(L, Q): - raise InManifoldError(X, self) + with torch.no_grad(): + L, Q = torch.linalg.eigh(X) + if check_in_manifold and not self.in_manifold_eigen(L): + raise InManifoldError(X, self) if self.inv is None: raise InverseError(self) with torch.no_grad(): - Q = Q[..., : self.rank] - L = L[..., : self.rank] - # Multiply the three of them as Q\LambdaQ^T + Q = Q[..., -self.rank :] + L = L[..., -self.rank :] L = self.inv(L) return L, Q @@ -185,13 +155,10 @@ def right_inverse(self, X, check_in_manifold=True): def in_manifold_eigen(self, L, eps=1e-6): r""" - Checks that an ordered vector of eigenvalues values is in the manifold. - - For tensors with more than 1 dimension the first dimensions are - treated as batch dimensions. + Checks that an ascending ordered vector of eigenvalues is in the manifold. Args: - L (torch.Tensor): Vector of eigenvalues + L (torch.Tensor): Vector of eigenvalues of shape `(*, rank)` eps (float): Optional. 
Threshold at which the eigenvalues are considered to be zero Default: ``1e-6`` @@ -200,49 +167,28 @@ def in_manifold_eigen(self, L, eps=1e-6): return False if L.size(-1) > self.rank: # We compute the \infty-norm of the remaining dimension - D = L[..., self.rank :] + D = L[..., : -self.rank] infty_norm_err = D.abs().max(dim=-1).values if (infty_norm_err > 5.0 * eps).any(): return False - return (L[..., : self.rank] >= -eps).all().item() - - def in_manifold_tuple(self, L, Q, eps=1e-6): - return ( - self.in_manifold_eigen(L, eps) - and self[0].in_manifold(Q) - and self[1].in_manifold(L) - ) + return (L[..., -self.rank :] >= -eps).all().item() def in_manifold(self, X, eps=1e-6): r""" - Checks that a matrix is in the manifold. The matrix may be given factorized - as a pair :math:`(\Lambda, Q)` with :math:`\Lambda` a vector of eigenvalues - and :math:`Q` a matrix of eigenvectors. - - For tensors with more than 2 dimensions the first dimensions are - treated as batch dimensions. + Checks that a matrix is in the manifold. Args: - X (torch.Tensor or tuple): The matrix to be checked or a tuple - ``(eigenvectors, eigenvalues)`` as returned by ``torch.linalg.eigh`` - or ``self.sample(factorized=True)``. + X (torch.Tensor): The matrix or batch of matrices of shape ``(*, n, n)`` to check. eps (float): Optional. Threshold at which the singular values are - considered to be zero - Default: ``1e-6`` + considered to be zero. Default: ``1e-6`` """ - if isinstance(X, tuple): - if len(X) == 2: - return self.in_manifold_tuple(X[0], X[1]) - else: - return False size = self.tensorial_size + (self.n, self.n) if X.size() != size or not Symmetric.in_manifold(X, eps): return False - - L = _decreasing_eigh(X, eigenvectors=False) + L = torch.linalg.eigvalsh(X) return self.in_manifold_eigen(L, eps) - def sample(self, init_=torch.nn.init.xavier_normal_, factorized=True): + def sample(self, init_=torch.nn.init.xavier_normal_, factorized=False): r""" Returns a randomly sampled matrix on the manifold as @@ -267,11 +213,6 @@ def sample(self, init_=torch.nn.init.xavier_normal_, factorized=True): to some distribution. See `torch.init `_. Default: ``torch.nn.init.xavier_normal_`` - factorized (bool): Optional. Return an eigenvalue decomposition of the - sampled matrix as a tuple :math:`(\Lambda, Q)`. - Using ``factorized=True`` is more efficient when the result is - used to initialize a parametrized tensor. 
- Default: ``True`` """ with torch.no_grad(): device = self[0].base.device @@ -281,16 +222,13 @@ def sample(self, init_=torch.nn.init.xavier_normal_, factorized=True): ) init_(X) X = X @ X.transpose(-2, -1) - L, Q = _decreasing_eigh(X, eigenvectors=True) - L = L[..., : self.rank] - Q = Q[..., : self.rank] + L, Q = torch.linalg.eigh(X) + L = L[..., -self.rank :] + Q = Q[..., -self.rank :] if factorized: return L, Q else: - # Project onto the manifold - Qt = Q.transpose(-2, -1) - # Multiply the three of them as Q\LambdaQ^T - return Q @ (L.unsqueeze(-1).expand_as(Qt) * Qt) + return (Q * L) @ Q.transpose(-2, -1) def extra_repr(self): return _extra_repr( diff --git a/geotorch/utils.py b/geotorch/utils.py index 17d2cbef..91f6c5fb 100644 --- a/geotorch/utils.py +++ b/geotorch/utils.py @@ -8,11 +8,8 @@ def update_base(layer, tensor_name): def transpose(fun): def new_fun(self, X, *args, **kwargs): - # It might happen that we get at tuple inside ``right_inverse`` - # In that case we do nothing - if isinstance(X, torch.Tensor): - if self.transposed: - X = X.transpose(-2, -1) + if self.transposed: + X = X.transpose(-2, -1) X = fun(self, X, *args, **kwargs) if self.transposed: X = X.transpose(-2, -1) diff --git a/test/test_integration.py b/test/test_integration.py index 9ce6f27c..4a1bccfa 100644 --- a/test/test_integration.py +++ b/test/test_integration.py @@ -94,7 +94,7 @@ def test_orthogonal(self): self._test_manifolds( [Stiefel, Grassmannian, geotorch.orthogonal, geotorch.grassmannian], dicts_product(distribution=["uniform", "torus"]), - [{}], + dicts_product(triv=["expm", "cayley"]), self.devices(), self.sizes(square=False), ) @@ -102,7 +102,7 @@ def test_orthogonal(self): def test_rank(self): self._test_manifolds( [LowRank, FixedRank, geotorch.low_rank, geotorch.fixed_rank], - dicts_product(factorized=[True, False]), + [{}], dicts_product(rank=self.ranks()), self.devices(), self.sizes(square=False), @@ -118,7 +118,7 @@ def test_psd_and_glp(self): geotorch.positive_semidefinite, geotorch.invertible, ], - dicts_product(factorized=[True, False]), + [{}], [{}], self.devices(), self.sizes(square=True), @@ -132,7 +132,7 @@ def test_pssd_rank(self): geotorch.positive_semidefinite_low_rank, geotorch.positive_semidefinite_fixed_rank, ], - dicts_product(factorized=[True, False]), + [{}], dicts_product(rank=self.ranks()), self.devices(), self.sizes(square=True), @@ -141,7 +141,7 @@ def test_pssd_rank(self): def test_almost_orthogonal(self): self._test_manifolds( [AlmostOrthogonal, geotorch.almost_orthogonal], - dicts_product(factorized=[True, False], distribution=["uniform", "torus"]), + dicts_product(distribution=["uniform", "torus"]), dicts_product(lam=self.lambdas(), f=list(AlmostOrthogonal.fs.keys())), self.devices(), self.sizes(square=True), @@ -171,63 +171,29 @@ def _test_manifolds( ) def _test_manifold(self, M, args_sample, args_constr, device, size, initialize): - # Test Linear - layer = nn.Linear(*size) - input_ = torch.rand(3, size[0]).to(device) - old_size = layer.weight.size() - # Somewhat dirty but will do - if isinstance(M, types.FunctionType): - M(layer, "weight", **args_constr) - else: - P.register_parametrization( - layer, "weight", M(size=layer.weight.size(), **args_constr) - ) - layer = layer.to(device) - # Check that it does not change the size of the layer - self.assertEqual(old_size, layer.weight.size(), msg=f"{layer}") - self._test_training(layer, args_sample, input_, initialize) - - # Just for the smaller ones, for the large ones this is just too expensive + inputs = [torch.rand(3, 
size[0], device=device)] + layers = [nn.Linear(*size, device=device)] + # Just test on convolution for small layers, otherwise it takes too long if min(size) < 100: - # Test Convolutionar (tensorial) - layer = nn.Conv2d(5, 4, size) - input_ = torch.rand(6, 5, size[0] + 7, size[1] + 3).to(device) + inputs.append(torch.rand(6, 5, size[0] + 7, size[1] + 3, device=device)) + layers.append(nn.Conv2d(5, 4, size, device=device)) + + for input_, layer in zip(inputs, layers): old_size = layer.weight.size() # Somewhat dirty but will do if isinstance(M, types.FunctionType): M(layer, "weight", **args_constr) else: - P.register_parametrization( - layer, "weight", M(size=layer.weight.size(), **args_constr) - ) - layer = layer.to(device) + # initialize the weight first (annoying) + M_ = M(size=layer.weight.size(), **args_constr).to(device) + X = M_.sample(**args_sample) + with torch.no_grad(): + layer.weight.copy_(X) + P.register_parametrization(layer, "weight", M_) # Check that it does not change the size of the layer self.assertEqual(old_size, layer.weight.size(), msg=f"{layer}") self._test_training(layer, args_sample, input_, initialize) - def matrix_from_factor_svd(self, U, S, V): - Vt = V.transpose(-2, -1) - # Multiply the three of them, S as a diagonal matrix - return U @ (S.unsqueeze(-1).expand_as(Vt) * Vt) - - def matrix_from_factor_eigen(self, L, Q): - Qt = Q.transpose(-2, -1) - # Multiply the three of them as Q\LambdaQ^T - return Q @ (L.unsqueeze(-1).expand_as(Qt) * Qt) - - def matrix_from_factor(self, X, M): - transpose = hasattr(M, "transposed") and M.transposed - if not isinstance(X, tuple): - return X - elif len(X) == 2: - X = self.matrix_from_factor_eigen(X[0], X[1]) - else: - X = self.matrix_from_factor_svd(X[0], X[1], X[2]) - if transpose: - return X.transpose(-2, -1) - else: - return X - def _test_training(self, layer, args_sample, input_, initialize): msg = f"{layer}\n{args_sample}" M = layer.parametrizations.weight[0] @@ -238,15 +204,12 @@ def _test_training(self, layer, args_sample, input_, initialize): layer.weight = X with P.cached(): # Compute the product if it is factorized - X_matrix = self.matrix_from_factor(X, M).to(layer.weight.device) # The sampled matrix should not have a gradient - self.assertFalse(X_matrix.requires_grad) + self.assertFalse(X.requires_grad) # Size does not change self.assertEqual(initial_size, layer.weight.size(), msg=msg) # Tha initialisation initialisation is equal to what we passed - self.assertTrue( - torch.allclose(layer.weight, X_matrix, atol=1e-5), msg=msg - ) + self.assertTrue(torch.allclose(layer.weight, X, atol=1e-5), msg=msg) # Take a couple SGD steps optim = torch.optim.SGD(layer.parameters(), lr=1e-3) diff --git a/test/test_orthogonal.py b/test/test_orthogonal.py index 8bb5eedf..c9e3d18a 100644 --- a/test/test_orthogonal.py +++ b/test/test_orthogonal.py @@ -22,10 +22,7 @@ def _test_constructor(self, cls): SO(size=(3, 3), triv="wrong") # Try a custom trivialization (it should break in the forward) - try: - cls(size=(3, 3), triv=lambda: 3) - except ValueError: - self.fail("{} raised ValueError unexpectedly!".format(cls)) + cls(size=(3, 3), triv=lambda: 3) # Try to instantiate it in a vector rather than a matrix with self.assertRaises(VectorError): From b33ea989b0ecc5b75ea1b0f9e8f0f79c1b94df28 Mon Sep 17 00:00:00 2001 From: lezcano Date: Sun, 14 Nov 2021 19:42:55 +0000 Subject: [PATCH 03/10] Fix merge --- geotorch/parametrize.py | 71 ++++++++++++++++++++--------------------- 1 file changed, 35 insertions(+), 36 deletions(-) diff --git 
a/geotorch/parametrize.py b/geotorch/parametrize.py index 2bd1a877..a9e26726 100644 --- a/geotorch/parametrize.py +++ b/geotorch/parametrize.py @@ -40,8 +40,7 @@ def cached(): several times the parametrized tensors. For example, the loop of an RNN with a parametrized recurrent kernel: - def set_original_(self, value: Tensor) -> None: - r"""This method is called when assigning to a parametrized tensor. + .. code-block:: python with P.cached(): for x in xs: @@ -289,17 +288,18 @@ def _inject_new_class(module: Module) -> None: This works by substituting the class of the module by a class that extends it to be able to inject a property - def forward(self) -> Tensor: - x = self.original - for module in self: - x = module(x) - if x.size() != self.original.size(): + Args: + module (nn.Module): module into which to inject the property + """ + cls = module.__class__ + + def getstate(self): raise RuntimeError( - "The parametrization may not change the size of the parametrized tensor. " - "Size of original tensor: {} " - "Size of parametrized tensor: {}".format(self.original.size(), x.size()) + "Serialization of parametrized modules is only " + "supported through state_dict(). See:\n" + "https://pytorch.org/tutorials/beginner/saving_loading_models.html" + "#saving-loading-a-general-checkpoint-for-inference-and-or-resuming-training" ) - return x param_cls = type( f"Parametrized{cls.__name__}", @@ -375,16 +375,14 @@ def register_parametrization( If the original tensor requires a gradient, the backward pass will differentiate through :attr:`parametrization`, and the optimizer will update the tensor accordingly. -def _inject_new_class(module: Module) -> None: - r"""Sets up the parametrization mechanism used by parametrizations. + The first time that a module registers a parametrization, this function will add an attribute + ``parametrizations`` to the module of type :class:`~ParametrizationList`. The list of parametrizations on the tensor ``weight`` will be accessible under ``module.parametrizations.weight``. - Args: - module (nn.Module): module into which to inject the property - """ - cls = module.__class__ + The original tensor will be accessible under + ``module.parametrizations.weight.original``. Parametrizations may be concatenated by registering several parametrizations on the same attribute. @@ -395,15 +393,9 @@ def _inject_new_class(module: Module) -> None: Parametrized parameters and buffers have an inbuilt caching system that can be activated using the context manager :func:`cached`. - param_cls = type( - "Parametrized{}".format(cls.__name__), - (cls,), - { - "__getstate__": getstate, - }, - ) + A :attr:`parametrization` may optionally implement a method with signature - module.__class__ = param_cls + .. code-block:: python def right_inverse(self, X: Tensor) -> Union[Tensor, Sequence[Tensor]] @@ -579,18 +571,25 @@ def right_inverse(self, X: Tensor) -> Union[Tensor, Sequence[Tensor]] f"Module '{module}' does not have a parameter, a buffer, or a " f"parametrized element with name '{tensor_name}'" ) - ) + return module - # Fetch the original tensor - original = module.parametrizations[tensor_name].original # type: ignore - if leave_parametrized: - t = getattr(module, tensor_name) - # If they have the same dtype, we reuse the original tensor. 
- # We do this so that the parameter does not to change the id() - # This way the user does not need to update the optimizer - if t.dtype == original.dtype: - with torch.no_grad(): - original.set_(t) + def is_parametrized(module: Module, tensor_name: Optional[str] = None) -> bool: + r"""Returns ``True`` if module has an active parametrization. + + If the argument :attr:`tensor_name` is specified, returns ``True`` if + ``module[tensor_name]`` is parametrized. + + Args: + module (nn.Module): module to query + name (str, optional): attribute in the module to query + Default: ``None`` + """ + parametrizations = getattr(module, "parametrizations", None) + if parametrizations is None or not isinstance(parametrizations, ModuleDict): + return False + if tensor_name is None: + # Check that there is at least one parametrized buffer or Parameter + return len(parametrizations) > 0 else: return tensor_name in parametrizations From 98a5197b77f2586c59bd490c40d3ee80b49f0017 Mon Sep 17 00:00:00 2001 From: lezcano Date: Sun, 14 Nov 2021 20:37:22 +0000 Subject: [PATCH 04/10] Few fixes --- geotorch/fixedrank.py | 2 +- geotorch/lowrank.py | 4 ++-- geotorch/so.py | 2 +- geotorch/symmetric.py | 4 ++-- setup.py | 2 +- test/test_integration.py | 28 +++++++++++++--------------- 6 files changed, 20 insertions(+), 22 deletions(-) diff --git a/geotorch/fixedrank.py b/geotorch/fixedrank.py index 2b52f62f..51fc3e17 100644 --- a/geotorch/fixedrank.py +++ b/geotorch/fixedrank.py @@ -117,7 +117,7 @@ def sample(self, init_=torch.nn.init.xavier_normal_, eps=5e-6): with torch.no_grad(): # S >= 0, as given by torch.linalg.eigvalsh() S[S < eps] = eps - X = (U * S) @ V.transpose(-2, -1) + X = (U * S.unsqueeze(-2)) @ V.transpose(-2, -1) if self.transposed: X = X.transpose(-2, -1) return X diff --git a/geotorch/lowrank.py b/geotorch/lowrank.py index b5459444..fde12af6 100644 --- a/geotorch/lowrank.py +++ b/geotorch/lowrank.py @@ -57,7 +57,7 @@ def frame(self, X): return U, S, V def submersion(self, U, S, V): - return (U * S) @ V.transpose(-2, -1) + return (U * S.unsqueeze(-2)) @ V.transpose(-2, -1) @transpose def forward(self, X): @@ -158,7 +158,7 @@ def sample(self, init_=torch.nn.init.xavier_normal_, factorized=False): if factorized: return U, S, Vt.transpose(-2, -1) else: - X = (U * S) @ Vt + X = (U * S.unsqueeze(-2)) @ Vt if self.transposed: X = X.transpose(-2, -1) return X diff --git a/geotorch/so.py b/geotorch/so.py index 20d6c6de..3dc6e5e3 100644 --- a/geotorch/so.py +++ b/geotorch/so.py @@ -186,7 +186,7 @@ def uniform_init_(tensor): # Make uniform (diag r >= 0) d = r.diagonal(dim1=-2, dim2=-1).sign() - q *= d + q *= d.unsqueeze(-2) if transpose: q.transpose_(-2, -1) diff --git a/geotorch/symmetric.py b/geotorch/symmetric.py index d7a4bc01..b873ee93 100644 --- a/geotorch/symmetric.py +++ b/geotorch/symmetric.py @@ -120,7 +120,7 @@ def frame(self, X): def submersion(self, Q, L): L = self.f(L) - return (Q * L) @ Q.transpose(-2, -1) + return (Q * L.unsqueeze(-2)) @ Q.transpose(-2, -1) def forward(self, X): X = self.frame(X) @@ -228,7 +228,7 @@ def sample(self, init_=torch.nn.init.xavier_normal_, factorized=False): if factorized: return L, Q else: - return (Q * L) @ Q.transpose(-2, -1) + return (Q * L.unsqueeze(-2)) @ Q.transpose(-2, -1) def extra_repr(self): return _extra_repr( diff --git a/setup.py b/setup.py index 26165245..276df470 100644 --- a/setup.py +++ b/setup.py @@ -50,6 +50,6 @@ keywords=["Constrained Optimization", "Optimization on Manifolds", "Pytorch"], packages=find_packages(), python_requires=">=3.5", - 
install_requires=["torch>=1.8"], + install_requires=["torch>=1.9"], extras_require={"dev": DEV_REQUIRES, "test": TEST_REQUIRES}, ) diff --git a/test/test_integration.py b/test/test_integration.py index 4a1bccfa..6af0291a 100644 --- a/test/test_integration.py +++ b/test/test_integration.py @@ -36,21 +36,19 @@ def dicts_product(**kwargs): class TestIntegration(TestCase): def sizes(self, square): - sizes = [] - if not torch.cuda.is_available(): - sizes = [(i, i) for i in range(1, 11)] - if not square: - sizes.extend( - [ - (i, j) - for i, j in itertools.product(range(1, 5), range(1, 5)) - if i != j - ] - ) - sizes.extend( - [(1, 7), (2, 7), (1, 8), (2, 8), (7, 1), (7, 2), (8, 1), (8, 2)] - ) - else: + sizes = [(i, i) for i in range(1, 11)] + if not square: + sizes.extend( + [ + (i, j) + for i, j in itertools.product(range(1, 5), range(1, 5)) + if i != j + ] + ) + sizes.extend( + [(1, 7), (2, 7), (1, 8), (2, 8), (7, 1), (7, 2), (8, 1), (8, 2)] + ) + if torch.cuda.is_available(): sizes.extend([(256, 256), (512, 512)]) if not square: sizes.extend([(256, 128), (128, 512), (1024, 512)]) From 650c48c395c401c576c7501ee48e103b10613c20 Mon Sep 17 00:00:00 2001 From: lezcano Date: Sun, 14 Nov 2021 20:52:36 +0000 Subject: [PATCH 05/10] Fix --- geotorch/pssdfixedrank.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/geotorch/pssdfixedrank.py b/geotorch/pssdfixedrank.py index c6edc094..e6dac445 100644 --- a/geotorch/pssdfixedrank.py +++ b/geotorch/pssdfixedrank.py @@ -96,4 +96,4 @@ def sample(self, init_=torch.nn.init.xavier_normal_, eps=5e-6): with torch.no_grad(): # L >= 0, as given by torch.linalg.eigvalsh() L[L < eps] = eps - return (Q * L) @ Q.transpose(-2, -1) + return (Q * L.unsqueeze(-2)) @ Q.transpose(-2, -1) From 045af1735beed43a61feee5b29993454f1ba632a Mon Sep 17 00:00:00 2001 From: lezcano Date: Sun, 14 Nov 2021 20:57:33 +0000 Subject: [PATCH 06/10] Remove expm --- geotorch/linalg/__init__.py | 0 geotorch/linalg/expm.py | 328 ------------------------------------ 2 files changed, 328 deletions(-) delete mode 100644 geotorch/linalg/__init__.py delete mode 100644 geotorch/linalg/expm.py diff --git a/geotorch/linalg/__init__.py b/geotorch/linalg/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/geotorch/linalg/expm.py b/geotorch/linalg/expm.py deleted file mode 100644 index f384aaa9..00000000 --- a/geotorch/linalg/expm.py +++ /dev/null @@ -1,328 +0,0 @@ -import torch -import math - -degs = [1, 2, 4, 8, 12, 18] - -thetas_dict = { - "single": [ - 1.192092800768788e-07, # m_vals = 1 - 5.978858893805233e-04, # m_vals = 2 - # 1.123386473528671e-02, - 5.116619363445086e-02, # m_vals = 4 - # 1.308487164599470e-01, - # 2.495289322846698e-01, - # 4.014582423510481e-01, - 5.800524627688768e-01, # m_vals = 8 - # 7.795113374358031e-01, - # 9.951840790004457e-01, - # 1.223479542424143e+00, - 1.461661507209034e00, # m_vals = 12 - # 1.707648529608701e+00, - # 1.959850585959898e+00, - # 2.217044394974720e+00, - # 2.478280877521971e+00, - # 2.742817112698780e+00, - 3.010066362817634e00, - ], # m_vals = 18 - "double": [ - 2.220446049250313e-16, # m_vals = 1 - 2.580956802971767e-08, # m_vals = 2 - # 1.386347866119121e-05, - 3.397168839976962e-04, # m_vals = 4 - # 2.400876357887274e-03, - # 9.065656407595102e-03, - # 2.384455532500274e-02, - 4.991228871115323e-02, # m_vals = 8 - # 8.957760203223343e-02, - # 1.441829761614378e-01, - # 2.142358068451711e-01, - 2.996158913811580e-01, # m_vals = 12 - # 3.997775336316795e-01, - # 5.139146936124294e-01, - # 6.410835233041199e-01, 
- # 7.802874256626574e-01, - # 9.305328460786568e-01, - 1.090863719290036e00, - ], # m_vals = 18 -} - -coefs = { - 12: [ - [ - -1.86023205146205530824e-02, - -5.00702322573317714499e-03, - -5.73420122960522249400e-01, - -1.33399693943892061476e-01, - ], - [ - 4.6, - 9.92875103538486847299e-01, - -1.32445561052799642976e-01, - 1.72990000000000000000e-03, - ], - [ - 2.11693118299809440730e-01, - 1.58224384715726723583e-01, - 1.65635169436727403003e-01, - 1.07862779315792429308e-02, - ], - [ - 0.0, - -1.31810610138301836924e-01, - -2.02785554058925905629e-02, - -6.75951846863086323186e-03, - ], - ], - 18: [ - [ - 0.0, - -1.00365581030144618291e-01, - -8.02924648241156932449e-03, - -8.92138498045729985177e-04, - 0.0, - ], - [ - 0.0, - 3.97849749499645077844e-01, - 1.36783778460411720168e00, - 4.98289622525382669416e-01, - -6.37898194594723280150e-04, - ], - [ - -1.09676396052962061844e01, - 1.68015813878906206114e00, - 5.71779846478865511061e-02, - -6.98210122488052056106e-03, - 3.34975017086070470649e-05, - ], - [ - -9.04316832390810593223e-02, - -6.76404519071381882256e-02, - 6.75961301770459654925e-02, - 2.95552570429315521194e-02, - -1.39180257516060693404e-05, - ], - [ - 0.0, - 0.0, - -9.23364619367118555360e-02, - -1.69364939002081722752e-02, - -1.40086798182036094347e-05, - ], - ], -} - - -def matrix_power_two_batch(A, k): - orig_size = A.size() - A, k = A.flatten(0, -3), k.flatten() - ksorted, idx = torch.sort(k) - # Abusing bincount... - count = torch.bincount(ksorted) - nonzero = torch.nonzero(count, as_tuple=False) - A = torch.matrix_power(A, 2 ** ksorted[0]) - last = ksorted[0] - processed = count[nonzero[0]] - for exp in nonzero[1:]: - new, last = exp - last, exp - A[idx[processed:]] = torch.matrix_power(A[idx[processed:]], 2 ** new.item()) - processed += count[exp] - return A.reshape(orig_size) - - -def expm_taylor(A): - if A.ndimension() < 2 or A.size(-2) != A.size(-1): - raise ValueError("Expected a square matrix or a batch of square matrices") - - if A.ndimension() == 2: - # Just one matrix - - # Trivial case - if A.size() == (1, 1): - return torch.exp(A) - - if A.element_size() > 4: - thetas = thetas_dict["double"] - else: - thetas = thetas_dict["single"] - - normA = torch.max(torch.sum(torch.abs(A), axis=0)).item() - - # No scale-square needed - # This could be done marginally faster if iterated in reverse - for deg, theta in zip(degs, thetas): - if normA <= theta: - return taylor_approx(A, deg) - - # Scale square - s = int(math.ceil(math.log2(normA) - math.log2(thetas[-1]))) - A = A * (2 ** -s) - X = taylor_approx(A, degs[-1]) - return torch.matrix_power(X, 2 ** s) - else: - # Batching - - # Trivial case - if A.size()[-2:] == (1, 1): - return torch.exp(A) - - if A.element_size() > 4: - thetas = thetas_dict["double"] - else: - thetas = thetas_dict["single"] - - normA = torch.max(torch.sum(torch.abs(A), axis=-2), axis=-1).values - - # Handle trivial case - if (normA == 0.0).all(): - Id = torch.eye(A.size(-2), A.size(-1), dtype=A.dtype, device=A.device) - return Id.expand_as(A) - - # Handle small normA - more = normA > thetas[-1] - s = normA.new_zeros(normA.size(), dtype=torch.long) - s[more] = torch.ceil(torch.log2(normA[more]) - math.log2(thetas[-1])).long() - - # A = A * 2**(-s) - A = torch.pow(0.5, s.float()).unsqueeze_(-1).unsqueeze_(-1).expand_as(A) * A - X = taylor_approx(A, degs[-1]) - return matrix_power_two_batch(X, s) - - -def taylor1(Id, A): - return Id + A - - -def taylor2(Id, A, A2): - return Id + A + 0.5 * A2 - - -def taylor4(Id, A, A2): - return Id + A + A2 @ 
(0.5 * Id + A / 6.0 + A2 / 24.0) - - -def taylor8(Id, A, A2): - # Minor: Precompute - SQRT = math.sqrt(177.0) - x3 = 2.0 / 3.0 - a1 = (1.0 + SQRT) * x3 - x1 = a1 / 88.0 - x2 = a1 / 352.0 - c0 = (-271.0 + 29.0 * SQRT) / (315.0 * x3) - c1 = (11.0 * (-1.0 + SQRT)) / (1260.0 * x3) - c2 = (11.0 * (-9.0 + SQRT)) / (5040.0 * x3) - c4 = (89.0 - SQRT) / (5040.0 * x3 * x3) - y2 = ((857.0 - 58.0 * SQRT)) / 630.0 - # Matrix products - A4 = A2 @ (x1 * A + x2 * A2) - A8 = (x3 * A2 + A4) @ (c0 * Id + c1 * A + c2 * A2 + c4 * A4) - return Id + A + y2 * A2 + A8 - - -def taylor12(Id, A, A2, A3): - b = torch.tensor(coefs[12], dtype=A.dtype, device=A.device) - # We implement the following allowing for batches - # q31 = a01*Id+a11*A+a21*A2+a31*A3 - # q32 = a02*Id+a12*A+a22*A2+a32*A3 - # q33 = a03*Id+a13*A+a23*A2+a33*A3 - # q34 = a04*Id+a14*A+a24*A2+a34*A3 - # Matrix products - # q61 = q33 + q34 @ q34 - # return (q31 + (q32 + q61) @ q61) - - q = torch.stack([Id, A, A2, A3], dim=-3).unsqueeze_(-4) - len_batch = A.ndimension() - 2 - # Expand first dimension to perform pointwise multiplication - q_size = [-1 for _ in range(len_batch)] + [4, -1, -1, -1] - q = q.expand(*q_size) - b = b.unsqueeze_(-1).unsqueeze_(-1).expand_as(q) - q = (b * q).sum(dim=-3) - if A.ndimension() > 2: - # Indexing the third to last dimension, because otherwise we - # would have to prepend as many 1's as the batch shape for the - # previous expand_as to work - qaux = q[..., 2, :, :] + q[..., 3, :, :] @ q[..., 3, :, :] - return q[..., 0, :, :] + (q[..., 1, :, :] + qaux) @ qaux - else: - qaux = q[2] + q[3] @ q[3] - return q[0] + (q[1] + qaux) @ qaux - - -def taylor18(Id, A, A2, A3, A6): - b = torch.tensor(coefs[18], dtype=A.dtype, device=A.device) - # We implement the following allowing for batches - # q31 = a01*Id + a11*A + a21*A2 + a31*A3 - # q61 = b01*Id + b11*A + b21*A2 + b31*A3 + b61*A6 - # q62 = b02*Id + b12*A + b22*A2 + b32*A3 + b62*A6 - # q63 = b03*Id + b13*A + b23*A2 + b33*A3 + b63*A6 - # q64 = b04*Id + b14*A + b24*A2 + b34*A3 + b64*A6 - # q91 = q31 @ q64 + q63 - # return q61 + (q62 + q91) @ q91 - q = torch.stack([Id, A, A2, A3, A6], dim=-3).unsqueeze_(-4) - len_batch = A.ndimension() - 2 - q_size = [-1 for _ in range(len_batch)] + [5, -1, -1, -1] - q = q.expand(*q_size) - b = b.unsqueeze_(-1).unsqueeze_(-1).expand_as(q) - q = (b * q).sum(dim=-3) - if A.ndimension() > 2: - # Indexing the third to last dimension, because otherwise we - # would have to prepend as many 1's as the batch shape for the - # previous expand_as to work - qaux = q[..., 0, :, :] @ q[..., 4, :, :] + q[..., 3, :, :] - return q[..., 1, :, :] + (q[..., 2, :, :] + qaux) @ qaux - else: - qaux = q[0] @ q[4] + q[3] - return q[1] + (q[2] + qaux) @ qaux - - -def taylor_approx(A, deg): - Id = torch.eye(A.size(-2), A.size(-1), dtype=A.dtype, device=A.device) - if A.ndimension() > 2: - Id = Id.expand_as(A) - - As = [Id, A] - if deg >= 2: - # A2 - As.append(A @ A) - if deg >= 12: - # A3 - As.append(A @ As[2]) - if deg == 18: - # A6 - As.append(As[3] @ As[3]) - - # Switch-case - return {1: taylor1, 2: taylor2, 4: taylor4, 8: taylor8, 12: taylor12, 18: taylor18}[ - deg - ](*As) - - -# Coverage does not catch these two being used as they are executed from the C++ backend -def differential(A, E, f): # pragma: no cover - n = A.size(-1) - size_M = list(A.size()[:-2]) + [2 * n, 2 * n] - M = A.new_zeros(size_M) - M[..., :n, :n] = A - M[..., n:, n:] = A - M[..., :n, n:] = E - return f(M)[..., :n, n:] - - -class expm_taylor_class(torch.autograd.Function): - @staticmethod - def 
forward(ctx, A): - ctx.save_for_backward(A) - return expm_taylor(A) - - @staticmethod - def backward(ctx, G): # pragma: no cover - (A,) = ctx.saved_tensors - # Handle typical case separately as (dexp)_0 = Id - if (A == 0).all(): - return G - else: - return differential(A.transpose(-2, -1), G, expm_taylor) - - -def expm(X): - return expm_taylor_class.apply(X) From 627116182dc424d35fb0c4287aad6cdc3076cfcf Mon Sep 17 00:00:00 2001 From: lezcano Date: Sun, 14 Nov 2021 21:01:06 +0000 Subject: [PATCH 07/10] Remove test_linalg.py --- test/test_linalg.py | 138 -------------------------------------------- 1 file changed, 138 deletions(-) delete mode 100644 test/test_linalg.py diff --git a/test/test_linalg.py b/test/test_linalg.py deleted file mode 100644 index 144a801c..00000000 --- a/test/test_linalg.py +++ /dev/null @@ -1,138 +0,0 @@ -import math -from unittest import TestCase -import torch - -from geotorch.linalg.expm import expm, taylor_approx - - -class TestLinalg(TestCase): - def taylor(self, X, deg): - assert X.size(-1) == X.size(-2) - n = X.size(-1) - Id = torch.eye(n, n, dtype=X.dtype, device=X.device) - acc = Id - last = Id - for i in range(1, deg + 1): - last = last @ X / float(i) - acc = acc + last - return acc - - def scale_square(self, X): - """ - Scale-squaring trick - """ - norm = X.norm() - if norm < 0.5: - return self.taylor(X, 12) - - k = int(math.ceil(math.log2(float(norm)))) + 2 - X = X * (2 ** -k) - E = self.taylor(X, 18) - for _ in range(k): - E = torch.mm(E, E) - return E - - def assertIsCloseSquare(self, X, Y, places=4): - self.assertEqual(X.ndim, 2) - self.assertEqual(X.size(0), X.size(1)) - self.assertAlmostEqual(torch.dist(X, Y).item(), 0.0, places=places) - - def compare_f(self, f_batching, f_simple, allows_batches, dtype, gradients=False): - # Test expm without batching - for _ in range(8): - A = torch.rand(10, 10, dtype=dtype) - if gradients: - G = torch.rand(10, 10, dtype=dtype) - B1 = f_batching(A, G) - B2 = f_simple(A, G) - self.assertIsCloseSquare(B1, B2, places=2) - else: - B1 = f_batching(A) - B2 = f_simple(A) - self.assertIsCloseSquare(B1, B2, places=3) - - # Test batching - for _ in range(3): - len_shape = torch.randint(1, 4, (1,)) - shape_batch = torch.randint(1, 5, size=(len_shape,)) - shape = list(shape_batch) + [8, 8] - A = torch.rand(*shape, dtype=dtype) - if gradients: - G = torch.rand(*shape, dtype=dtype) - B1 = f_batching(A, G) - else: - B1 = f_batching(A) - if allows_batches: - B2 = f_simple(A) - self.assertEqual(B1.size(), A.size()) - self.assertEqual(B2.size(), A.size()) - - # sample a few coordinates and evaluate the equality - # of those elements in the batch - for _ in range(3): - coords = [ - torch.randint(low=0, high=s, size=(1,)).item() - for s in shape_batch - ] - coords = coords + [...] - self.assertIsCloseSquare(B1[coords], B2[coords], places=3) - else: - # sample a few coordinates and evaluate the equality - # of those elements in the batch - for _ in range(3): - coords = [ - torch.randint(low=0, high=s, size=(1,)).item() - for s in shape_batch - ] - coords = coords + [...] 
- if gradients: - self.assertIsCloseSquare( - B1[coords], f_simple(A[coords], G[coords]), places=2 - ) - else: - self.assertIsCloseSquare( - B1[coords], f_simple(A[coords]), places=3 - ) - - def test_expm(self): - with torch.random.fork_rng(devices=range(torch.cuda.device_count())): - torch.random.manual_seed(8888) - # Test different Taylor approximations - degs = [1, 2, 4, 8, 12, 18] - for deg in degs: - for dtype in [torch.float, torch.double]: - self.compare_f( - lambda X: taylor_approx(X, deg), - lambda X: self.taylor(X, deg), - allows_batches=True, - dtype=dtype, - ) - - # Test the main function - for dtype in [torch.float, torch.double]: - self.compare_f( - expm, self.scale_square, allows_batches=False, dtype=dtype - ) - - # Test the gradients - def diff(f): - def wrap(A, G): - A.requires_grad_() - return torch.autograd.grad([f(A)], [A], [G])[0] - - return wrap - - for dtype in [torch.float, torch.double]: - self.compare_f( - diff(expm), - diff(self.scale_square), - allows_batches=False, - dtype=dtype, - gradients=True, - ) - - def test_errors(self): - with self.assertRaises(ValueError): - expm(torch.empty(3, 4)) - with self.assertRaises(ValueError): - expm(torch.empty(1, 4)) From 6ec2f94802d90902674e69d6e9fe44effbf41004 Mon Sep 17 00:00:00 2001 From: lezcano Date: Sun, 14 Nov 2021 21:06:28 +0000 Subject: [PATCH 08/10] Unsafe registrations and update README --- README.rst | 2 +- geotorch/constraints.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/README.rst b/README.rst index 67d9694c..80e972ad 100644 --- a/README.rst +++ b/README.rst @@ -149,7 +149,7 @@ If one wants to use a parametrized tensor in different places in their model, or Of course, this ``with`` statement may be used simply inside the forward function where the parametrized layer is used several times. -These ideas fall in the context of parametrized optimization, where one wraps a tensor ``X`` with a function ``f``, and rather than using ``X``, uses ``f(X)``. Particular examples of this idea are pruning, weight normalization, and spectral normalization among others. This repository implements a framework to approach this kind of problems. This framework was accepted to core PyTorch 1.8. It can be found under `torch.nn.utils.parametrize`_ and `torch.nn.utils.parametrizations`_. +These ideas fall in the context of parametrized optimization, where one wraps a tensor ``X`` with a function ``f``, and rather than using ``X``, uses ``f(X)``. Particular examples of this idea are pruning, weight normalization, and spectral normalization among others. This repository implements a framework to approach this kind of problems. This framework was accepted to core PyTorch 1.8. It can be found under `torch.nn.utils.parametrize`_ and `torch.nn.utils.parametrizations`_. When using PyTorch 1.10 or higher, these functions are used, and the user can interact with the parametrizations in GeoTorch using the functions in PyTorch. As every space in GeoTorch is, at its core, a map from a flat space into a manifold, the tools implemented here also serve as a building block in normalizing flows. Using a factorized space such as |low|_ it is direct to compute the determinant of the transformation it defines, as we have direct access to the singular values of the layer. 
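The interaction with the PyTorch-native parametrization utilities mentioned in the README paragraph above can be sketched as follows. This is an illustrative aside, not part of the patch; it assumes the patched GeoTorch running on PyTorch >= 1.10, so that the registered constraints are visible to ``torch.nn.utils.parametrize``:

.. code-block:: python

    import torch
    import torch.nn as nn
    import torch.nn.utils.parametrize as parametrize
    import geotorch

    layer = nn.Linear(64, 64)
    geotorch.orthogonal(layer, "weight")   # registered as a regular parametrization

    # Standard PyTorch queries and utilities now apply to the constrained weight.
    assert parametrize.is_parametrized(layer, "weight")
    with parametrize.cached():             # reuse the computed weight across several calls
        x = torch.randn(8, 64)
        y = layer(x) + layer(x)
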
diff --git a/geotorch/constraints.py b/geotorch/constraints.py index fee0939c..16efe44c 100644 --- a/geotorch/constraints.py +++ b/geotorch/constraints.py @@ -28,7 +28,7 @@ def _register_manifold(module, tensor_name, cls, *args): else: setattr(module, tensor_name, X) - P.register_parametrization(module, tensor_name, M) + P.register_parametrization(module, tensor_name, M, unsafe=True) return module From 8eb7089c67938aa65b3133135115da49f752a9ae Mon Sep 17 00:00:00 2001 From: lezcano Date: Sun, 14 Nov 2021 21:43:28 +0000 Subject: [PATCH 09/10] minor --- geotorch/so.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/geotorch/so.py b/geotorch/so.py index 3dc6e5e3..37bfc47a 100644 --- a/geotorch/so.py +++ b/geotorch/so.py @@ -193,10 +193,10 @@ def uniform_init_(tensor): # Make them have positive determinant by multiplying the # first column by -1 (does not change the measure) if n == k: - mask = (torch.det(q) > 0.0).float() - mask[mask == 0.0] = -1.0 - mask = mask.unsqueeze(-1).unsqueeze(-1) - q[..., 0] *= mask[..., 0] + mask = (torch.det(q) >= 0.0).float() + mask[mask == 0.] = -1. + mask = mask.unsqueeze(-1) + q[..., 0] *= mask tensor.copy_(q) return tensor From 754144e58cd1dae6013e12af620118b13bf2fe50 Mon Sep 17 00:00:00 2001 From: lezcano Date: Sun, 14 Nov 2021 21:45:53 +0000 Subject: [PATCH 10/10] black --- geotorch/so.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/geotorch/so.py b/geotorch/so.py index 37bfc47a..93b14ba7 100644 --- a/geotorch/so.py +++ b/geotorch/so.py @@ -194,7 +194,7 @@ def uniform_init_(tensor): # first column by -1 (does not change the measure) if n == k: mask = (torch.det(q) >= 0.0).float() - mask[mask == 0.] = -1. + mask[mask == 0.0] = -1.0 mask = mask.unsqueeze(-1) q[..., 0] *= mask tensor.copy_(q)
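Taken together, these patches drop the factorized outputs of ``sample()`` and initialize layers by assigning a plain matrix to the parametrized weight. A minimal sketch of that flow, mirroring the updated integration tests (illustrative only, not part of any patch in this series; the layer size and the ``rank`` value are arbitrary):

.. code-block:: python

    import torch
    import torch.nn as nn
    import geotorch

    layer = nn.Linear(5, 5)
    geotorch.positive_semidefinite_fixed_rank(layer, "weight", rank=3)

    M = layer.parametrizations.weight[0]   # the manifold parametrizing the weight
    X = M.sample()                         # now a plain (5, 5) matrix, not an (L, Q) pair
    layer.weight = X                       # assignment goes through right_inverse
    assert torch.allclose(layer.weight, X, atol=1e-5)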