#25 Add support for parametrizations from PyTorch 1.10 core
lezcano authored Nov 14, 2021
2 parents ca4eec1 + 754144e commit 280f8dd
Showing 19 changed files with 753 additions and 1,172 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build.yml
@@ -33,7 +33,7 @@ jobs:
       if: ${{ matrix.os == 'windows-latest' }}
       run: |
         python -m pip install --upgrade pip
-        pip install torch===1.7.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
+        pip install torch===1.10.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
         pip install .[dev]
     - name: Lint with flake8
7 changes: 4 additions & 3 deletions README.rst
@@ -135,7 +135,7 @@ You may try GeoTorch installing it with
     pip install git+https://github.com/Lezcano/geotorch/
-GeoTorch is tested in Linux, Mac, and Windows environments for Python >= 3.6.
+GeoTorch is tested in Linux, Mac, and Windows environments for Python >= 3.6 and supports PyTorch >= 1.9

 Sharing Weights, Parametrizations, and Normalizing Flows
 --------------------------------------------------------
@@ -149,7 +149,7 @@ If one wants to use a parametrized tensor in different places in their model, or
 Of course, this ``with`` statement may be used simply inside the forward function where the parametrized layer is used several times.

-These ideas fall in the context of parametrized optimization, where one wraps a tensor ``X`` with a function ``f``, and rather than using ``X``, uses ``f(X)``. Particular examples of this idea are pruning, weight normalization, and spectral normalization among others. This repository implements a framework to approach this kind of problems. The framework is currently `PR #33344`_ in PyTorch. All the functionality of this PR is located in `geotorch/parametrize.py`_.
+These ideas fall in the context of parametrized optimization, where one wraps a tensor ``X`` with a function ``f``, and rather than using ``X``, uses ``f(X)``. Particular examples of this idea are pruning, weight normalization, and spectral normalization among others. This repository implements a framework to approach this kind of problems. This framework was accepted to core PyTorch 1.8. It can be found under `torch.nn.utils.parametrize`_ and `torch.nn.utils.parametrizations`_. When using PyTorch 1.10 or higher, these functions are used, and the user can interact with the parametrizations in GeoTorch using the functions in PyTorch.

 As every space in GeoTorch is, at its core, a map from a flat space into a manifold, the tools implemented here also serve as a building block in normalizing flows. Using a factorized space such as |low|_ it is direct to compute the determinant of the transformation it defines, as we have direct access to the singular values of the layer.

@@ -219,7 +219,8 @@ Please cite the following work if you found GeoTorch useful. This paper exposes
     :alt: License

 .. _here: https://github.com/Lezcano/geotorch/blob/master/examples/copying_problem.py#L16
-.. _PR #33344: https://github.com/pytorch/pytorch/pull/33344
+.. _torch.nn.utils.parametrize: https://pytorch.org/docs/stable/generated/torch.nn.utils.parametrize.register_parametrization.html
+.. _torch.nn.utils.parametrizations: https://pytorch.org/docs/stable/generated/torch.nn.utils.parametrizations.orthogonal.html
 .. _geotorch/parametrize.py: https://github.com/Lezcano/geotorch/blob/master/geotorch/parametrize.py
 .. _examples/sequential_mnist.py: https://github.com/Lezcano/geotorch/blob/master/examples/sequential_mnist.py
 .. _examples/copying_problem.py: https://github.com/Lezcano/geotorch/blob/master/examples/copying_problem.py
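
The README text in this diff describes parametrized optimization: wrap a tensor ``X`` with a function ``f`` and use ``f(X)`` in its place. A minimal sketch of that idea with the core PyTorch functions the README now points to might look as follows. The ``Symmetric`` class is a hypothetical example that is not part of GeoTorch, and the ``with`` block assumes the README's ``with`` statement refers to the caching context manager in ``torch.nn.utils.parametrize``.

```python
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize


class Symmetric(nn.Module):
    """Hypothetical parametrization f(X) = triu(X) + triu(X, 1)^T."""

    def forward(self, X):
        # Mirror the upper triangle onto the lower one.
        return X.triu() + X.triu(1).transpose(-1, -2)


layer = nn.Linear(3, 3)
parametrize.register_parametrization(layer, "weight", Symmetric())

# layer.weight now evaluates f(X); the unconstrained X is stored here.
X = layer.parametrizations.weight.original
W = layer.weight

# Without caching, f is re-evaluated on every access to layer.weight.
# Inside the context manager the value is computed once and reused.
with parametrize.cached():
    W1 = layer.weight
    W2 = layer.weight
same_object = W1 is W2
```

The optimizer only ever sees ``X``, while the model only ever sees ``f(X)``, which is why the constraint survives every gradient step.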
27 changes: 3 additions & 24 deletions geotorch/almostorthogonal.py
@@ -107,7 +107,7 @@ def in_manifold_singular_values(self, S, eps=1e-5):
             and ((S - 1.0).abs() <= lam).all().item()
         )

-    def sample(self, distribution="uniform", init_=None, factorized=True):
+    def sample(self, distribution="uniform", init_=None):
         r"""
         Returns a randomly sampled orthogonal matrix according to the specified
         ``distribution``. The options are:
@@ -142,30 +142,9 @@ def sample(self, distribution="uniform", init_=None, factorized=True):
                     to some distribution. See
                     `torch.init <https://pytorch.org/docs/stable/nn.init.html>`_.
                     Default: :math:`\operatorname{Uniform}(-\pi, \pi)`
-            factorized (bool): Optional. Return an SVD decomposition of the
-                    sampled matrix as a tuple :math:`(U, \Sigma, V)`.
-                    Using ``factorized=True`` is more efficient when the result is
-                    used to initialize a parametrized tensor.
-                    Default: ``True``
         """
-        with torch.no_grad():
-            device = self[0].base.device
-            dtype = self[0].base.dtype
-            # Sample U and set S = 1, V = Id
-            U = self[0].sample(distribution=distribution, init_=init_)
-            S = torch.ones(
-                *(self.tensorial_size + (self.n,)), device=device, dtype=dtype
-            )
-            V = torch.eye(self.n, device=device, dtype=dtype)
-            if len(self.tensorial_size) > 0:
-                V = V.repeat(*(self.tensorial_size + (1, 1)))
-
-            if factorized:
-                return U, S, V
-            else:
-                Vt = V.transpose(-2, -1)
-                # Multiply the three of them, S as a diagonal matrix
-                return U @ (S.unsqueeze(-1).expand_as(Vt) * Vt)
+        # Sample an orthogonal matrix as U and return it
+        return self[0].sample(distribution=distribution, init_=init_)

     def extra_repr(self):
         return _extra_repr(
13 changes: 7 additions & 6 deletions geotorch/constraints.py
@@ -19,15 +19,16 @@
 def _register_manifold(module, tensor_name, cls, *args):
     tensor = getattr(module, tensor_name)
     M = cls(tensor.size(), *args).to(device=tensor.device, dtype=tensor.dtype)
-    P.register_parametrization(module, tensor_name, M)

     # Initialize without checking in manifold
     X = M.sample()
-    param_list = module.parametrizations[tensor_name]
-    with torch.no_grad():
-        for m in reversed(param_list):
-            X = m.right_inverse(X, check_in_manifold=False)
-        param_list.original.copy_(X)
+    if not P.is_parametrized(module, tensor_name):
+        with torch.no_grad():
+            tensor.copy_(X)
+    else:
+        setattr(module, tensor_name, X)
+
+    P.register_parametrization(module, tensor_name, M, unsafe=True)

     return module
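
The hunk above moves initialization ahead of registration: the sampled point is written into the tensor first, and ``register_parametrization`` is then called with ``unsafe=True``. The sketch below illustrates why that ordering can work, assuming PyTorch >= 1.10, where registration passes the current tensor through ``right_inverse`` when the parametrization defines one. The ``Symmetric`` class and its ``sample`` method are hypothetical stand-ins for a GeoTorch manifold, not the library's code.

```python
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as P


class Symmetric(nn.Module):
    """Hypothetical stand-in for a GeoTorch manifold."""

    def forward(self, X):
        return X.triu() + X.triu(1).transpose(-1, -2)

    def right_inverse(self, A):
        # Any X with forward(X) == A works; the upper triangle is one choice.
        return A.triu()

    def sample(self):
        # Hypothetical sampler: a random symmetric 3x3 matrix.
        B = torch.randn(3, 3)
        return B + B.transpose(-1, -2)


layer = nn.Linear(3, 3)
M = Symmetric()
X = M.sample()

# Write the sample into the plain parameter first ...
with torch.no_grad():
    layer.weight.copy_(X)

# ... then register. Registration maps the current tensor through
# right_inverse, so the parametrized weight starts at the sampled X.
P.register_parametrization(layer, "weight", M, unsafe=True)
```

When the tensor is already parametrized, the diff uses ``setattr`` instead, which in core PyTorch routes the assigned value through the ``right_inverse`` of the parametrizations already registered.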

20 changes: 5 additions & 15 deletions geotorch/fixedrank.py
@@ -89,7 +89,7 @@ def in_manifold_singular_values(self, S, eps=1e-5):
         infty_norm = D.abs().max(dim=-1).values
         return (infty_norm > eps).all().item()

-    def sample(self, init_=torch.nn.init.xavier_normal_, factorized=True, eps=5e-6):
+    def sample(self, init_=torch.nn.init.xavier_normal_, eps=5e-6):
         r"""
         Returns a randomly sampled matrix on the manifold by sampling a matrix according
         to ``init_`` and projecting it onto the manifold.
@@ -110,24 +110,14 @@ def sample(self, init_=torch.nn.init.xavier_normal_, factorized=True, eps=5e-6):
                     in place according to some distribution. See
                     `torch.init <https://pytorch.org/docs/stable/nn.init.html>`_.
                     Default: ``torch.nn.init.xavier_normal_``
-            factorized (bool): Optional. Return an SVD decomposition of the
-                    sampled matrix as a tuple :math:`(U, \Sigma, V)`.
-                    Using ``factorized=True`` is more efficient when the result is
-                    used to initialize a parametrized tensor.
-                    Default: ``True``
             eps (float): Optional. Minimum singular value of the sampled matrix.
                     Default: ``5e-6``
         """
         U, S, V = super().sample(factorized=True, init_=init_)
         with torch.no_grad():
             # S >= 0, as given by torch.linalg.eigvalsh()
             S[S < eps] = eps
-            if factorized:
-                return U, S, V
-            else:
-                Vt = V.transpose(-2, -1)
-                # Multiply the three of them, S as a diagonal matrix
-                X = U @ (S.unsqueeze(-1).expand_as(Vt) * Vt)
-                if self.transposed:
-                    X = X.transpose(-2, -1)
-                return X
+            X = (U * S.unsqueeze(-2)) @ V.transpose(-2, -1)
+            if self.transposed:
+                X = X.transpose(-2, -1)
+            return X
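
The added line rebuilds the matrix from its factors as ``(U * S.unsqueeze(-2)) @ V.transpose(-2, -1)``, which scales the columns of ``U`` by the singular values instead of materializing a diagonal matrix. The following sketch checks that expression against the explicit product ``U diag(S) V^T`` on a small batch. The tensor shapes and the use of ``torch.linalg.svd`` to obtain the factors are assumptions made for illustration, not what ``sample`` does internally.

```python
import torch

torch.manual_seed(0)

# A batch of two 4x3 matrices and their reduced SVD factors.
A = torch.randn(2, 4, 3)
U, S, Vh = torch.linalg.svd(A, full_matrices=False)
V = Vh.transpose(-2, -1)

# Expression from the diff: broadcasting S over rows scales column j of U
# by the j-th singular value.
X_fast = (U * S.unsqueeze(-2)) @ V.transpose(-2, -1)

# Reference: build the diagonal matrix explicitly.
X_ref = U @ torch.diag_embed(S) @ V.transpose(-2, -1)
```

Both forms give the same product; the broadcast version skips building and multiplying by the diagonal matrix, which matters for batched inputs.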
Empty file removed geotorch/linalg/__init__.py
Empty file.