A library for constrained optimization and manifold optimization for deep learning in PyTorch
GeoTorch provides a simple way to perform constrained optimization and optimization on manifolds in PyTorch. It is compatible out of the box with any optimizer, layer, and model implemented in PyTorch without any boilerplate in the training code. Just state the constraints when you construct the model and you are ready to go!
```python
import torch
import torch.nn as nn
import geotorch

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        # One line suffices: Instantiate a linear layer with orthonormal columns
        self.linear = nn.Linear(64, 128)
        geotorch.orthogonal(self.linear, "weight")

        # Works with tensors: Instantiate a CNN with kernels of rank 1
        self.cnn = nn.Conv2d(16, 32, 3)
        geotorch.low_rank(self.cnn, "weight", rank=1)

        # Weights are initialized to a random value when you put the constraints, but
        # you may re-initialize them to a different value by assigning to them
        self.linear.weight = torch.eye(128, 64)
        # And that's all you need to do. The rest is regular PyTorch code

    def forward(self, x, img):
        # self.linear is orthogonal and every 3x3 kernel in self.cnn is of rank 1
        # (illustrative forward pass: the two layers act on separate inputs)
        return self.linear(x), self.cnn(img)

# Use the model as you would normally do. Everything just works
device = "cuda" if torch.cuda.is_available() else "cpu"
model = Model().to(device)

# Use your optimizer of choice. Any optimizer works out of the box with any parametrization
optim = torch.optim.Adam(model.parameters(), lr=1e-3)  # example learning rate
```
The following constraints are implemented and may be used as in the example above:
- `geotorch.symmetric`. Symmetric matrices
- `geotorch.skew`. Skew-symmetric matrices
- `geotorch.sphere`. Vectors of norm `1`
- `geotorch.orthogonal`. Matrices with orthonormal columns
- `geotorch.grassmannian`. Subspaces of a fixed dimension, represented by matrices with orthonormal columns
- `geotorch.almost_orthogonal(λ)`. Matrices with singular values in the interval `[1-λ, 1+λ]`
- `geotorch.invertible`. Invertible matrices with positive determinant
- `geotorch.sln`. Matrices of determinant equal to `1`
- `geotorch.low_rank(r)`. Matrices of rank at most `r`
- `geotorch.fixed_rank(r)`. Matrices of rank `r`
- `geotorch.positive_definite`. Positive definite matrices
- `geotorch.positive_semidefinite`. Positive semidefinite matrices
- `geotorch.positive_semidefinite_low_rank(r)`. Positive semidefinite matrices of rank at most `r`
- `geotorch.positive_semidefinite_fixed_rank(r)`. Positive semidefinite matrices of rank `r`
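To give an idea of how a constraint can be enforced by construction, the sketch below maps an unconstrained square matrix X to an orthogonal one: it takes the skew-symmetric part A = X - Xᵀ and applies the Cayley map Q = (I - A)⁻¹(I + A). This is a plain-Python illustration under our own assumptions; all helper names are hypothetical, and GeoTorch's `orthogonal` may use a different map and operates on PyTorch tensors.

```python
# Plain-Python sketch (hypothetical helpers): unconstrained X -> orthogonal Q
def matmul(A, B):
    # Product of two matrices stored as lists of rows
    return [[sum(A[i][k] * B[k][j] for k in range(len(B))) for j in range(len(B[0]))]
            for i in range(len(A))]

def transpose(A):
    return [list(row) for row in zip(*A)]

def solve(M, B):
    # Return M^{-1} B using Gauss-Jordan elimination with partial pivoting
    n = len(M)
    aug = [list(M[i]) + list(B[i]) for i in range(n)]
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(aug[r][col]))
        aug[col], aug[pivot] = aug[pivot], aug[col]
        p = aug[col][col]
        aug[col] = [v / p for v in aug[col]]
        for r in range(n):
            if r != col:
                factor = aug[r][col]
                aug[r] = [a - factor * b for a, b in zip(aug[r], aug[col])]
    return [row[n:] for row in aug]

def cayley_orthogonal(X):
    n = len(X)
    # Skew-symmetric part of X: A^T = -A
    A = [[X[i][j] - X[j][i] for j in range(n)] for i in range(n)]
    eye = [[1.0 if i == j else 0.0 for j in range(n)] for i in range(n)]
    i_plus_a = [[eye[i][j] + A[i][j] for j in range(n)] for i in range(n)]
    i_minus_a = [[eye[i][j] - A[i][j] for j in range(n)] for i in range(n)]
    # Cayley map Q = (I - A)^{-1} (I + A); I - A is always invertible for skew A
    return solve(i_minus_a, i_plus_a)

X = [[0.3, -1.2, 0.5], [0.7, 0.1, -0.4], [2.0, 0.0, 1.1]]
Q = cayley_orthogonal(X)
QtQ = matmul(transpose(Q), Q)   # equals the identity up to rounding
```

The Cayley map always produces matrices of determinant 1, so on its own it only reaches part of the orthogonal matrices; it is shown here because it is short and exact, not because it is what the library does.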
Each of these constraints has some extra parameters that can be used to tailor its behavior to the problem at hand. For more on this, see the documentation.
These constraints are a frontend for the families of spaces listed below.
Each constraint in GeoTorch is implemented as a manifold. Using these classes directly gives the user more flexibility over the options chosen for each parametrization. All of them support Riemannian Gradient Descent, but they also support optimization via any other PyTorch optimizer.
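Riemannian Gradient Descent can be sketched on the simplest of these spaces, the sphere: project the Euclidean gradient onto the tangent space at the current point, take a step, and map the result back onto the manifold. The version below uses renormalization as the retraction, which is our simplifying assumption (GeoTorch may use a different map), and the linear objective f(x) = ⟨c, x⟩ with a made-up vector c is chosen because its minimizer on the unit sphere is known to be -c/‖c‖.

```python
import math

def normalize(v):
    n = math.sqrt(sum(x * x for x in v))
    return [x / n for x in v]

def riemannian_gd_sphere(grad, x0, lr=0.1, steps=500):
    x = normalize(x0)
    for _ in range(steps):
        g = grad(x)
        # Project the Euclidean gradient onto the tangent space at x: g - <g, x> x
        dot = sum(gi * xi for gi, xi in zip(g, x))
        rgrad = [gi - dot * xi for gi, xi in zip(g, x)]
        # Step along -rgrad, then retract back onto the sphere by renormalizing
        x = normalize([xi - lr * ri for xi, ri in zip(x, rgrad)])
    return x

c = [1.0, -2.0, 2.0]   # hypothetical cost vector with ||c|| = 3
# The Euclidean gradient of f(x) = <c, x> is the constant vector c
x_star = riemannian_gd_sphere(lambda x: c, [1.0, 1.0, 1.0])
print([round(v, 4) for v in x_star])
```

Every iterate stays exactly on the sphere, which is the property that distinguishes this scheme from plain gradient descent followed by an occasional projection.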
GeoTorch currently supports the following spaces:
- `Rn(n)`: Rⁿ. Unrestricted optimization
- `Sym(n)`: Vector space of symmetric matrices
- `Skew(n)`: Vector space of skew-symmetric matrices
- `Sphere(n)`: Sphere in Rⁿ. { x ∈ Rⁿ | ||x|| = 1 } ⊂ Rⁿ
- `SO(n)`: Manifold of n×n orthogonal matrices with determinant 1
- `St(n,k)`: Manifold of n×k matrices with orthonormal columns
- `AlmostOrthogonal(n,k,λ)`: Manifold of n×k matrices with singular values in the interval [1-λ, 1+λ]
- `Gr(n,k)`: Manifold of k-dimensional subspaces in Rⁿ
- `GLp(n)`: Manifold of invertible n×n matrices with positive determinant
- `SL(n)`: Manifold of n×n matrices with determinant equal to 1
- `LowRank(n,k,r)`: Variety of n×k matrices of rank r or less
- `FixedRank(n,k,r)`: Manifold of n×k matrices of rank r
- `PSD(n)`: Cone of n×n symmetric positive definite matrices
- `PSSD(n)`: Cone of n×n symmetric positive semi-definite matrices
- `PSSDLowRank(n,r)`: Variety of n×n symmetric positive semi-definite matrices of rank r or less
- `PSSDFixedRank(n,r)`: Manifold of n×n symmetric positive semi-definite matrices of rank r
- `ProductManifold(M₁, ..., Mₖ)`: Product of manifolds M₁ × ... × Mₖ
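One common way to land in a space such as PSSDLowRank(n,r) is the factorization S = Y Yᵀ with Y of shape n×r: S is symmetric, satisfies xᵀSx = ‖Yᵀx‖² ≥ 0, and has rank at most r. The plain-Python sketch below checks the first two properties numerically. It illustrates the idea only; the helper `pssd_low_rank` is a hypothetical name, not part of GeoTorch's API.

```python
import random

def pssd_low_rank(Y):
    # S[i][j] = <Y[i], Y[j]>, i.e. S = Y Y^T for Y of shape n x r
    n = len(Y)
    return [[sum(a * b for a, b in zip(Y[i], Y[j])) for j in range(n)] for i in range(n)]

random.seed(0)
n, r = 4, 2
Y = [[random.gauss(0.0, 1.0) for _ in range(r)] for _ in range(n)]
S = pssd_low_rank(Y)

# Symmetry, and non-negativity of the quadratic form x^T S x on random vectors
is_symmetric = all(abs(S[i][j] - S[j][i]) < 1e-12 for i in range(n) for j in range(n))
quad_forms = []
for _ in range(100):
    x = [random.gauss(0.0, 1.0) for _ in range(n)]
    quad_forms.append(sum(x[i] * S[i][j] * x[j] for i in range(n) for j in range(n)))
print(is_symmetric, min(quad_forms) >= -1e-12)
```

Optimizing over Y instead of S turns a constrained problem into an unconstrained one, at the price of Y being determined only up to an r×r orthogonal factor.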
Every space of dimension (n, k) can be applied to tensors of shape (*, n, k), so we also get efficient parallel implementations of product spaces such as

- `ObliqueManifold(n,k)`: Matrices with unit-length columns, Sⁿ⁻¹ × ⋯ × Sⁿ⁻¹ (k factors)
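Because the constraint acts independently on each column, a point of ObliqueManifold(n,k) can be obtained from any n×k matrix by normalizing its columns, which amounts to applying the sphere map k times in parallel. A minimal plain-Python sketch follows; the function name `to_oblique` is hypothetical, and it assumes no column is identically zero.

```python
import math

def to_oblique(X):
    # X is an n x k matrix (list of n rows); rescale each of its k columns to unit norm
    n, k = len(X), len(X[0])
    norms = [math.sqrt(sum(X[i][j] ** 2 for i in range(n))) for j in range(k)]
    return [[X[i][j] / norms[j] for j in range(k)] for i in range(n)]

X = [[3.0, 1.0], [4.0, -1.0], [0.0, 1.0]]
Y = to_oblique(X)
col_norms = [math.sqrt(sum(Y[i][j] ** 2 for i in range(3))) for j in range(2)]
print(col_norms)
```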
The files in examples/copying_problem.py and examples/sequential_mnist.py serve as tutorials to see how to handle the initialization and usage of GeoTorch in some real code. They also show how to implement Riemannian Gradient Descent and some other tricks. For an introduction to how the library is actually implemented, see the Jupyter Notebook examples/parametrisations.ipynb.
You may try GeoTorch by installing it with

```bash
pip install git+https://github.com/Lezcano/geotorch/
```
GeoTorch is tested on Linux, Mac, and Windows environments for Python >= 3.6 and supports PyTorch >= 1.9.
If one wants to use a parametrized tensor in different places in their model, or uses one parametrized layer many times (for example, in an RNN), it is recommended to wrap the forward pass as follows, so that each parametrization is not computed many times:
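```python
import geotorch

# `model` and `input_` are assumed to be defined as in the example above
with geotorch.parametrize.cached():
    logits = model(input_)
```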
Of course, this `with` statement may also be used inside the forward function itself, around the code where the parametrized layer is used several times.
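The benefit of caching can be illustrated with a small stand-alone sketch: inside a context manager, the parametrized value is computed once and reused; outside of it, it is recomputed on every access. This mimics the behavior described above but is not GeoTorch's implementation; `CachedParametrization` and `cached` are hypothetical names.

```python
import contextlib

class CachedParametrization:
    """Hypothetical wrapper that returns f(X), optionally from a cache."""

    def __init__(self, f, X):
        self.f, self.X = f, X
        self.calls = 0          # how many times f has actually been evaluated
        self._caching = False
        self._cache = None

    @property
    def value(self):
        if self._caching and self._cache is not None:
            return self._cache
        self.calls += 1
        out = self.f(self.X)
        if self._caching:
            self._cache = out
        return out

@contextlib.contextmanager
def cached(param):
    # Enable caching for the duration of the block, then drop the cached value
    param._caching = True
    try:
        yield
    finally:
        param._caching = False
        param._cache = None

p = CachedParametrization(lambda x: [2 * v for v in x], [1.0, 2.0])
for _ in range(5):          # e.g. an RNN using the same weight at 5 time steps
    _ = p.value
print(p.calls)              # 5
with cached(p):
    for _ in range(5):
        _ = p.value
print(p.calls)              # 6: only one extra evaluation inside the block
```

Dropping the cache when the block exits matters: after an optimizer step the underlying X changes, so a stale f(X) must not survive into the next forward pass.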
These ideas fall in the context of parametrized optimization, where one wraps a tensor X with a function f and, rather than using X, uses f(X). Particular examples of this idea are pruning, weight normalization, and spectral normalization, among others. This repository implements a framework to approach this kind of problem. This framework was accepted into core PyTorch 1.8 and can be found under `torch.nn.utils.parametrize` and `torch.nn.utils.parametrizations`. When using PyTorch 1.10 or higher, GeoTorch uses the native PyTorch functions internally, so the user can interact with the parametrizations in GeoTorch through the PyTorch functions.
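As a minimal illustration of optimizing through a parametrization, the sketch below constrains a scalar weight to be positive by writing w = f(x) = exp(x) and runs plain gradient descent on the unconstrained x, using the chain rule dL/dx = L'(f(x)) · f'(x). The loss L(w) = (w - 3)² and all names are made up for the example; in PyTorch, autograd would compute this derivative automatically.

```python
import math

def f(x):
    # Parametrization: maps any real x to a positive weight
    return math.exp(x)

def loss_grad_w(w):
    # Derivative of the example loss L(w) = (w - 3)^2 with respect to w
    return 2.0 * (w - 3.0)

x = 0.0      # unconstrained parameter; the "model" only ever sees f(x)
lr = 0.05
for _ in range(500):
    w = f(x)
    grad_x = loss_grad_w(w) * w   # chain rule, using f'(x) = exp(x) = w
    x -= lr * grad_x
print(f(x))
```

The optimizer never has to know about the constraint w > 0: it updates x freely, and positivity holds by construction at every step.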
As every space in GeoTorch is, at its core, a map from a flat space into a manifold, the tools implemented here also serve as a building block in normalizing flows. Using a factorized space such as LowRank(n,k,r), it is straightforward to compute the determinant of the transformation it defines (in absolute value, the product of its singular values), as we have direct access to the singular values of the layer.
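To see why the factorized form helps, take a square, full-rank layer written as W = U diag(s) Vᵀ with orthogonal U and V; then |det W| = s₁⋯sₙ and log|det W| = Σ log sᵢ, so no determinant has to be computed. The 2×2 check below builds W from two rotations and made-up singular values and compares the result with the determinant computed directly. It is only a numerical sanity check of this identity, not GeoTorch code.

```python
import math

def rotation(theta):
    # 2x2 rotation matrix: orthogonal with determinant 1
    return [[math.cos(theta), -math.sin(theta)], [math.sin(theta), math.cos(theta)]]

def matmul(A, B):
    return [[sum(A[i][k] * B[k][j] for k in range(len(B))) for j in range(len(B[0]))]
            for i in range(len(A))]

U, V = rotation(0.7), rotation(-1.3)   # orthogonal factors
s = [2.0, 1.5]                         # hypothetical singular values
S = [[s[0], 0.0], [0.0, s[1]]]
Vt = [list(row) for row in zip(*V)]
W = matmul(matmul(U, S), Vt)           # W = U diag(s) V^T

det_direct = W[0][0] * W[1][1] - W[0][1] * W[1][0]
log_abs_det = sum(math.log(si) for si in s)   # the quantity a normalizing flow needs
print(abs(det_direct), math.exp(log_abs_det))
```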
Please cite the following work if you found GeoTorch useful. This paper presents a simplified mathematical explanation of part of the inner workings of GeoTorch.
@inproceedings{lezcano2019trivializations,
title = {Trivializations for gradient-based optimization on manifolds},
author = {Lezcano-Casado, Mario},
booktitle={Advances in Neural Information Processing Systems, NeurIPS},
pages = {9154--9164},
year = {2019},
}