diff --git a/deel/lip/compute_layer_sv.py b/deel/lip/compute_layer_sv.py
new file mode 100644
index 00000000..15d8cf80
--- /dev/null
+++ b/deel/lip/compute_layer_sv.py
@@ -0,0 +1,196 @@
+"""Compute the largest and lowest singular values of a layer or network.
+
+The singular values are computed using the SVD decomposition of the weight matrix.
+For convolutional layers, the equivalent matrix is computed and the SVD is applied on
+it.
+
+The `compute_layer_sv()` function is the main function to compute the singular values
+of a given layer. It supports by default several kinds of layers (Conv2D, Dense, Add,
+BatchNormalization, ReLU, Activation, and deel-lip layers). For other layers, the user
+can provide a supplementary_type2sv dictionary linking a new layer type with a
+user-defined function to compute the singular values.
+
+The function `compute_model_sv()` computes the singular values of all layers in a
+model. It returns a dictionary indicating for each layer name a tuple (min sv, max sv).
+"""
+
+import numpy as np
+import tensorflow as tf
+
+from .layers import Condensable, GroupSort, MaxMin
+from .layers.unconstrained import PadConv2D
+
+
+def _compute_sv_dense(layer, input_sizes=None):
+    """Compute max and min singular values for a Dense layer.
+
+    The singular values are computed using the SVD decomposition of the weight matrix.
+
+    Args:
+        layer (tf.keras.Layer): the Dense layer.
+        input_sizes (tuple, optional): unused here.
+
+    Returns:
+        tuple: min and max singular values
+    """
+    weights = layer.get_weights()[0]
+    svd = np.linalg.svd(weights, compute_uv=False)
+    return (np.min(svd), np.max(svd))
+
+
+def _generate_conv_matrix(layer, input_sizes):
+    """Generate the equivalent matrix of a convolutional layer.
+
+    The convolutional layer is converted to a dense layer by computing the equivalent
+    matrix. The equivalent matrix is computed by applying the convolutional layer on a
+    Dirac input.
+
+    Args:
+        layer (tf.keras.Layer): the convolutional layer to convert to dense.
+        input_sizes (tuple): the input shape of the layer (with batch dimension as
+            first element).
+
+    Returns:
+        np.array: the equivalent matrix of the convolutional layer.
+    """
+    single_layer_model = tf.keras.models.Sequential(
+        [tf.keras.layers.Input(input_sizes[1:]), layer]
+    )
+    dirac_inp = np.zeros((input_sizes[2],) + input_sizes[1:])  # Line by line generation
+    in_size = input_sizes[1] * input_sizes[2]
+    channel_in = input_sizes[-1]
+    w_eqmatrix = None
+    start_index = 0
+    for ch in range(channel_in):
+        for ii in range(input_sizes[1]):
+            dirac_inp[:, ii, :, ch] = np.eye(input_sizes[2])
+            out_pred = single_layer_model(dirac_inp)
+            if w_eqmatrix is None:
+                w_eqmatrix = np.zeros(
+                    (in_size * channel_in, np.prod(out_pred.shape[1:]))
+                )
+            w_eqmatrix[start_index : (start_index + input_sizes[2]), :] = tf.reshape(
+                out_pred, (input_sizes[2], -1)
+            )
+            dirac_inp = 0.0 * dirac_inp
+            start_index += input_sizes[2]
+    return w_eqmatrix
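+
+
+def _check_equivalent_matrix_example():
+    """Hedged, illustrative sketch (hypothetical helper, not used by the library).
+
+    It shows why the Dirac construction works: a bias-free convolution is linear,
+    so flattening an input channel-major (the row ordering used in
+    _generate_conv_matrix) and multiplying by the equivalent matrix reproduces the
+    flattened convolution output. The layer and shapes below are arbitrary.
+    """
+    conv = tf.keras.layers.Conv2D(2, (3, 3), padding="same", use_bias=False)
+    w_eq = _generate_conv_matrix(conv, (1, 5, 5, 1))
+    x = np.random.rand(1, 5, 5, 1).astype("float32")
+    # flatten x with index order (channel, row, column) to match the matrix rows
+    x_flat = x.transpose(0, 3, 1, 2).reshape(1, -1)
+    np.testing.assert_allclose(
+        x_flat @ w_eq, conv(x).numpy().reshape(1, -1), rtol=1e-4, atol=1e-5
+    )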
+
+
+def _compute_sv_conv2d_layer(layer, input_sizes):
+    """Compute max and min singular values for any convolutional layer.
+
+    The convolutional layer is converted to a dense layer by computing the equivalent
+    matrix. The equivalent matrix is computed by applying the convolutional layer on a
+    Dirac input. The singular values are then computed using the SVD decomposition of
+    the equivalent matrix.
+
+    Args:
+        layer (tf.keras.Layer): the convolutional layer.
+        input_sizes (tuple): the input shape of the layer (with batch dimension as
+            first element).
+
+    Returns:
+        tuple: min and max singular values
+    """
+    w_eqmatrix = _generate_conv_matrix(layer, input_sizes)
+    svd = np.linalg.svd(w_eqmatrix, compute_uv=False)
+    return (np.min(svd), np.max(svd))
+
+
+def _compute_sv_activation(layer, input_sizes=None):
+    """Compute min and max gradient norm for an activation.
+
+    Warning: for non-linear functions, these are bounds on the gradient norm, not
+    singular values.
+    """
+    if isinstance(layer, tf.keras.layers.Activation):
+        function2SV = {tf.keras.activations.relu: (0, 1)}
+        if layer.activation in function2SV.keys():
+            return function2SV[layer.activation]
+        else:
+            return (None, None)
+    layer2SV = {
+        tf.keras.layers.ReLU: (0, 1),
+        GroupSort: (1, 1),
+        MaxMin: (1, 1),
+    }
+    if type(layer) in layer2SV.keys():
+        return layer2SV[type(layer)]
+    else:
+        return (None, None)
+
+
+def _compute_sv_add(layer, input_sizes):
+    """Compute min and max singular values of an Add layer."""
+    assert isinstance(input_sizes, list)
+    return (len(input_sizes) * 1.0, len(input_sizes) * 1.0)
+
+
+def _compute_sv_bn(layer, input_sizes=None):
+    """Compute min and max singular values of a BatchNormalization layer."""
+    values = np.abs(
+        layer.gamma.numpy() / np.sqrt(layer.moving_variance.numpy() + layer.epsilon)
+    )
+    return (np.min(values), np.max(values))
+
+
+def compute_layer_sv(layer, supplementary_type2sv={}):
+    """
+    Compute the largest and lowest singular values (or upper and lower bounds)
+    of a given layer.
+
+    For Condensable layers, condense() and vanilla_export() are applied to the layer
+    to get the unconstrained weights.
+    Supports by default several kinds of layers (Conv2D, Dense, Add,
+    BatchNormalization, ReLU, Activation, and deel-lip layers).
+
+    Args:
+        layer (tf.keras.layers.Layer): a single tf.keras.layer
+        supplementary_type2sv (dict, optional): a dictionary linking a new layer type
+            with a user-defined function to compute the singular values. Defaults
+            to {}.
+
+    Returns:
+        tuple: a 2-tuple with lowest and largest singular values.
+    """
+    default_type2sv = {
+        tf.keras.layers.Conv2D: _compute_sv_conv2d_layer,
+        tf.keras.layers.Conv2DTranspose: _compute_sv_conv2d_layer,
+        PadConv2D: _compute_sv_conv2d_layer,
+        tf.keras.layers.Dense: _compute_sv_dense,
+        tf.keras.layers.ReLU: _compute_sv_activation,
+        tf.keras.layers.Activation: _compute_sv_activation,
+        GroupSort: _compute_sv_activation,
+        MaxMin: _compute_sv_activation,
+        tf.keras.layers.Add: _compute_sv_add,
+        tf.keras.layers.BatchNormalization: _compute_sv_bn,
+    }
+    input_shape = layer.input_shape
+    if isinstance(layer, Condensable):
+        layer.condense()
+        layer = layer.vanilla_export()
+    if type(layer) in default_type2sv.keys():
+        return default_type2sv[type(layer)](layer, input_shape)
+    elif type(layer) in supplementary_type2sv.keys():
+        return supplementary_type2sv[type(layer)](layer, input_shape)
+    else:
+        return (None, None)
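+
+
+def _supplementary_type2sv_example():
+    """Hedged usage sketch (hypothetical helper, not part of the API).
+
+    It shows the expected signature of a user-defined entry in
+    supplementary_type2sv: a function taking (layer, input_sizes) and returning a
+    (min sv, max sv) tuple. ScaleLayer and scale_sv are made up for illustration.
+    """
+
+    class ScaleLayer(tf.keras.layers.Layer):
+        def call(self, x):
+            return 0.5 * x
+
+    def scale_sv(layer, input_sizes=None):
+        # multiplying by 0.5 has all singular values equal to 0.5
+        return (0.5, 0.5)
+
+    model = tf.keras.models.Sequential([tf.keras.layers.Input((4,)), ScaleLayer()])
+    return compute_layer_sv(
+        model.layers[0], supplementary_type2sv={ScaleLayer: scale_sv}
+    )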
+
+
+def compute_model_sv(model, supplementary_type2sv={}):
+    """Compute the largest and lowest singular values of all layers in a model.
+
+    Args:
+        model (tf.keras.Model): a tf.keras Model or Sequential.
+        supplementary_type2sv (dict, optional): a dictionary linking a new layer type
+            with a user-defined function to compute the min and max singular values.
+
+    Returns:
+        dict: A dictionary indicating for each layer name a tuple (min sv, max sv).
+    """
+    list_sv = []
+    for layer in model.layers:
+        if isinstance(layer, tf.keras.Model):
+            list_sv.append((layer.name, (None, None)))
+            list_sv += list(compute_model_sv(layer, supplementary_type2sv).items())
+        else:
+            list_sv.append((layer.name, compute_layer_sv(layer, supplementary_type2sv)))
+    return dict(list_sv)
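+
+
+if __name__ == "__main__":  # pragma: no cover
+    # Hedged usage sketch: per-layer singular values of a small, arbitrary model
+    # (the architecture is chosen for illustration only).
+    demo_model = tf.keras.models.Sequential(
+        [
+            tf.keras.layers.Input((8, 8, 1)),
+            tf.keras.layers.Conv2D(2, (3, 3), use_bias=False),
+            tf.keras.layers.Flatten(),
+            tf.keras.layers.Dense(4),
+        ]
+    )
+    for name, (sv_min, sv_max) in compute_model_sv(demo_model).items():
+        print(name, sv_min, sv_max)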
+ + """ + if issubclass(layer_type, Sequential): + model = layer_type(**layer_params) + model.set_klip_factor(k) + return model + a = Input(shape=input_shape) + if issubclass(layer_type, LipschitzLayer): + layer_params["k_coef_lip"] = k + layer = layer_type(**layer_params) + assert isinstance(layer, Layer) + b = layer(a) + return Model(inputs=a, outputs=b) + + +class LipschitzLayersSVTest(unittest.TestCase): + def train_compute_and_verifySV( + self, + layer_type: type, + layer_params: dict, + batch_size: int, + steps_per_epoch: int, + epochs: int, + input_shape: tuple, + k_lip_model: float, + k_lip_data: float, + **kwargs + ): + """ + Create a model, train compute and verify SVs. + + Args: + layer_type: + layer_params: + batch_size: + steps_per_epoch: + epochs: + input_shape: + k_lip_model: + k_lip_data: + **kwargs: + + Returns: + """ + flag_test_SVmin = True + if "test_SVmin" in kwargs.keys(): + flag_test_SVmin = kwargs["test_SVmin"] + if "k_lip_tolerance_factor" not in kwargs.keys(): + kwargs["k_lip_tolerance_factor"] = 1.02 + # clear session to avoid side effects from previous train + K.clear_session() + np.random.seed(42) + tf.random.set_seed(1234) + # create the keras model, defin opt, and compile it + model = generate_k_lip_model(layer_type, layer_params, input_shape, k_lip_model) + print(model.summary()) + + optimizer = Adam(lr=0.001) + model.compile( + optimizer=optimizer, loss="mean_squared_error", metrics=[metrics.mse] + ) + # create the synthetic data generator + output_shape = model.compute_output_shape((batch_size,) + input_shape)[1:] + kernel = build_kernel(input_shape, output_shape, k_lip_data) + # define logging features + logdir = os.path.join("logs", "lip_layers", "%s" % layer_type.__name__) + hparams = dict( + layer_type=layer_type.__name__, + batch_size=batch_size, + steps_per_epoch=steps_per_epoch, + epochs=epochs, + k_lip_data=k_lip_data, + k_lip_model=k_lip_model, + ) + callback_list = [hp.KerasCallback(logdir, hparams)] + if kwargs["callbacks"] is not None: + callback_list = callback_list + kwargs["callbacks"] + # train model + model.fit( + linear_generator(batch_size, input_shape, kernel), + steps_per_epoch=steps_per_epoch, + epochs=epochs, + verbose=1, + callbacks=callback_list, + ) + + file_writer = tf.summary.create_file_writer(os.path.join(logdir, "metrics")) + file_writer.set_as_default() + for ll in model.layers: + print(ll.name) + SVmin, SVmax = compute_layer_sv(ll) + # log metrics + if SVmin is not None: + tf.summary.text("Layer name", ll.name, step=epochs) + tf.summary.scalar("SVmin_estim", SVmin, step=epochs) + tf.summary.scalar("SVmax_estim", SVmax, step=epochs) + self.assertLess( + SVmax, + k_lip_model * kwargs["k_lip_tolerance_factor"], + msg=" the maximum singular value of the layer " + + ll.name + + " must be lower than the specified boundary", # noqa: E501 + ) + self.assertLessEqual( + SVmin, + SVmax, + msg=" the minimum singular value of the layer " + + ll.name + + " must be lower than the maximum value", # noqa: E501 + ) + if flag_test_SVmin: + self.assertGreater( + SVmin, + k_lip_model * (2.0 - kwargs["k_lip_tolerance_factor"]), + msg=" the minimum singular value of the layer " + + ll.name + + " must be greater than the specified boundary", # noqa: E501 + ) + return + + def _apply_tests_bank(self, tests_bank): + for test_params in tests_bank: + pp.pprint(test_params) + self.train_compute_and_verifySV(**test_params) + + def test_spectral_dense(self): + self._apply_tests_bank( + [ + dict( + layer_type=SpectralDense, + layer_params={ + "units": 4, + 
"use_bias": False, + }, + batch_size=1000, + steps_per_epoch=125, + epochs=5, + input_shape=(4,), + k_lip_data=1.0, + k_lip_model=1.0, + callbacks=[], + ), + dict( + layer_type=SpectralDense, + layer_params={ + "units": 4, + }, + batch_size=1000, + steps_per_epoch=125, + epochs=5, + input_shape=(4,), + k_lip_data=1.0, + k_lip_model=5.0, + callbacks=[], + ), + ] + ) + + def test_frobenius_dense(self): + self._apply_tests_bank( + [ + dict( + layer_type=FrobeniusDense, + layer_params={"units": 1}, + batch_size=1000, + steps_per_epoch=125, + epochs=5, + input_shape=(4,), + k_lip_data=1.0, + k_lip_model=1.0, + test_SVmin=False, + callbacks=[], + ), + dict( + layer_type=FrobeniusDense, + layer_params={"units": 1}, + batch_size=1000, + steps_per_epoch=125, + epochs=5, + input_shape=(4,), + k_lip_data=1.0, + k_lip_model=5.0, + test_SVmin=False, + callbacks=[], + ), + ] + ) + + def test_orthRegul_dense(self): + """ + Tests for a standard Dense layer, for result comparison. + """ + self._apply_tests_bank( + [ + dict( + layer_type=Dense, + layer_params={ + "units": 6, + "kernel_regularizer": OrthDenseRegularizer(1000.0), + }, + batch_size=1000, + steps_per_epoch=125, + epochs=10, + input_shape=(4,), + k_lip_data=1.0, + k_lip_model=1.0, + callbacks=[], + ), + ] + ) + + def test_spectralconv2d(self): + self._apply_tests_bank( + [ + dict( + layer_type=SpectralConv2D, + layer_params={ + "filters": 2, + "kernel_size": (3, 3), + "use_bias": False, + }, + batch_size=100, + steps_per_epoch=125, + epochs=5, + input_shape=(5, 5, 1), + k_lip_data=1.0, + k_lip_model=1.0, + k_lip_tolerance_factor=1.02, + test_SVmin=False, + callbacks=[], + ), + dict( + layer_type=SpectralConv2D, + layer_params={"filters": 2, "kernel_size": (3, 3)}, + batch_size=100, + steps_per_epoch=125, + epochs=5, + input_shape=(5, 5, 1), + k_lip_data=1.0, + k_lip_model=5.0, + k_lip_tolerance_factor=1.02, + test_SVmin=False, + callbacks=[], + ), + dict( + layer_type=SpectralConv2D, + layer_params={ + "filters": 2, + "kernel_size": (3, 3), + "use_bias": False, + }, + batch_size=100, + steps_per_epoch=125, + epochs=5, + input_shape=(5, 5, 3), # case conv_first=False + k_lip_data=1.0, + k_lip_model=1.0, + k_lip_tolerance_factor=1.02, + test_SVmin=False, + callbacks=[], + ), + dict( + layer_type=SpectralConv2D, + layer_params={ + "filters": 5, + "kernel_size": (3, 3), + "use_bias": False, + "strides": 2, + }, + batch_size=100, + steps_per_epoch=125, + epochs=5, + input_shape=(10, 10, 1), + k_lip_data=1.0, + k_lip_model=1.0, + k_lip_tolerance_factor=1.02, + test_SVmin=False, + callbacks=[], + ), + dict( + layer_type=SpectralConv2D, + layer_params={ + "filters": 3, # case conv_first=False + "kernel_size": (3, 3), + "use_bias": False, + "strides": 2, + }, + batch_size=100, + steps_per_epoch=125, + epochs=5, + input_shape=(10, 10, 1), + k_lip_data=1.0, + k_lip_model=1.0, + k_lip_tolerance_factor=1.02, + test_SVmin=False, + callbacks=[], + ), + ] + ) + + def test_frobeniusconv2d(self): + # tests only checks that lip cons is enforced + self._apply_tests_bank( + [ + dict( + layer_type=FrobeniusConv2D, + layer_params={"filters": 2, "kernel_size": (3, 3)}, + batch_size=100, + steps_per_epoch=125, + epochs=5, + input_shape=(5, 5, 1), + k_lip_data=1.0, + k_lip_model=1.0, + k_lip_tolerance_factor=1.1, # Frobenius seems less precise on SVs + test_SVmin=False, + callbacks=[], + ), + dict( + layer_type=FrobeniusConv2D, + layer_params={"filters": 2, "kernel_size": (3, 3)}, + batch_size=100, + steps_per_epoch=125, + epochs=5, + input_shape=(5, 5, 1), + 
+
+    # def test_orthoconv2d(self):
+    #     # this test only checks that the Lipschitz constraint is enforced
+    #     self._apply_tests_bank(
+    #         [
+    #             dict(
+    #                 layer_type=OrthoConv2D,
+    #                 layer_params={
+    #                     "filters": 2,
+    #                     "kernel_size": (3, 3),
+    #                     "regul_lorth": 1000.0,
+    #                 },
+    #                 batch_size=1000,
+    #                 steps_per_epoch=125,
+    #                 epochs=10,
+    #                 input_shape=(5, 5, 1),
+    #                 k_lip_data=1.0,
+    #                 k_lip_model=1.0,
+    #                 k_lip_tolerance_factor=1.1,
+    #                 callbacks=[],
+    #             ),
+    #             dict(
+    #                 layer_type=OrthoConv2D,
+    #                 layer_params={
+    #                     "filters": 2,
+    #                     "kernel_size": (3, 3),
+    #                     "regul_lorth": 1000.0,
+    #                 },
+    #                 batch_size=1000,
+    #                 steps_per_epoch=125,
+    #                 epochs=10,
+    #                 input_shape=(5, 5, 1),
+    #                 k_lip_data=1.0,
+    #                 k_lip_model=5.0,
+    #                 k_lip_tolerance_factor=1.1,
+    #                 callbacks=[],
+    #             ),
+    #             dict(
+    #                 layer_type=OrthoConv2D,
+    #                 layer_params={
+    #                     "filters": 6,
+    #                     "kernel_size": (3, 3),
+    #                     "regul_lorth": 1000.0,
+    #                     "strides": 2,
+    #                 },
+    #                 batch_size=1000,
+    #                 steps_per_epoch=125,
+    #                 epochs=10,
+    #                 input_shape=(10, 10, 1),
+    #                 k_lip_data=1.0,
+    #                 k_lip_model=1.0,
+    #                 k_lip_tolerance_factor=1.1,
+    #                 callbacks=[],
+    #             ),
+    #         ]
+    #     )
+
+
+if __name__ == "__main__":
+    unittest.main()