diff --git a/scripts/run-expreiment.py b/scripts/run-expreiment.py index 21e88ac..cfe53f1 100644 --- a/scripts/run-expreiment.py +++ b/scripts/run-expreiment.py @@ -5,7 +5,8 @@ from src.experiments.config_parser import read_config -Pathable = T.Union[str, os.PathLike] # In principle one can cast it to os.path.Path +Pathable = T.Union[str, os.PathLike] # In principle one can cast it to os.path.Path + @click.command() @click.option("--report_dir", default="./reports", help="Report file") @@ -13,7 +14,7 @@ @click.option("--storage_path", default=None, help="Prefix for config items") def run(report_dir: Pathable, config: Pathable, storage_path: Pathable): """Loads an experiment from config file conducts the experiment it. - + Args: report_dir (str): Directory to save report to. config (str): Path to config file. The report is expected to be specified in .yaml format with @@ -22,22 +23,15 @@ def run(report_dir: Pathable, config: Pathable, storage_path: Pathable): storage_path (str): Path to Ray storage directory. Defaults to None. """ sepline = "\n" + ("-" * 80) + "\n" + ("-" * 80) + "\n" - print( - f"{sepline}Parsing config file:{sepline}" - ) + print(f"{sepline}Parsing config file:{sepline}") config = os.path.abspath(config) experiment = read_config(config) - print( - f"{sepline}Done.{sepline}" - ) - print( - f"{sepline}Conducting experiment{sepline}" - ) + print(f"{sepline}Done.{sepline}") + print(f"{sepline}Conducting experiment{sepline}") # Conduct experiment experiment.conduct(report_dir, storage_path=storage_path) - print( - f"{sepline}Done.{sepline}" - ) - + print(f"{sepline}Done.{sepline}") + + if __name__ == "__main__": - run() \ No newline at end of file + run() diff --git a/src/experiments/base.py b/src/experiments/base.py index f0d4f85..be5e0ae 100644 --- a/src/experiments/base.py +++ b/src/experiments/base.py @@ -3,19 +3,19 @@ class Experiment(object): - """Base class for experiments. - """ + """Base class for experiments.""" + def __init__(self, name, *args, **kwargs): super().__init__() self.name = name - - @classmethod + + @classmethod def _init_rec(cls, cfg): if isinstance(cfg, dict): if "experiment" in cfg: experiment_type = cfg["experiment"]["experiment_type"] params = cls._init_rec(cfg["experiment"]["experiment_params"]) - + return experiment_type(**params) else: return {k: cls._init_rec(v) for k, v in cfg.items()} @@ -23,42 +23,48 @@ def _init_rec(cls, cfg): return [cls._init_rec(v) for v in cfg] else: return cfg - + @classmethod def from_dict(cls, config: T.Dict[str, T.Any]) -> "Experiment": if "experiment" not in config: - raise ValueError("Invalid config file. The config file needs to contain an experiment field.") + raise ValueError( + "Invalid config file. The config file needs to contain an experiment field." + ) return cls._init_rec(config) - - def conduct(self, report_dir: os.PathLike, storage_path: os.PathLike = None) -> None: - """Conducts the experiment and saves the results to the report directory. The method is expected to store all results in report_dir. - """ + + def conduct( + self, report_dir: os.PathLike, storage_path: os.PathLike = None + ) -> None: + """Conducts the experiment and saves the results to the report directory. The method is expected to store all results in report_dir.""" raise NotImplementedError - + + class ExperimentCollection(Experiment): - """ Implements an experiment that consists of several jointly conducted but independent experiments. 
- """ + """Implements an experiment that consists of several jointly conducted but independent experiments.""" + def __init__(self, experiments: T.Iterable[Experiment], *args, **kwargs) -> None: """ The function initializes an object with a list of experiments based on a given configuration. - + :param experiments: The "experiments" parameter is an iterable object that contains a list of experiments. Each experiment is represented by a configuration object :type experiments: Iterable *args """ super().__init__(*args, **kwargs) self.experiments = experiments - + @classmethod def from_dict(cls, config: T.Dict[str, T.Any]) -> "ExperimentCollection": config = deepcopy(config) for i, exp_cfg in enumerate(config["experiment_params"]["experiments"]): - config["experiment_params"]["experiments"][i] = Experiment.from_dict(exp_cfg) - + config["experiment_params"]["experiments"][i] = Experiment.from_dict( + exp_cfg + ) + return Experiment.from_dict(config) - - + def conduct(self, report_dir: os.PathLike, storage_path: os.PathLike = None): for i, exp in enumerate(self.experiments): - exp.conduct(os.path.join(report_dir, f"{i}_{exp.name}"), storage_path=storage_path) - + exp.conduct( + os.path.join(report_dir, f"{i}_{exp.name}"), storage_path=storage_path + ) diff --git a/src/experiments/config_parser.py b/src/experiments/config_parser.py index ece093e..39fede5 100644 --- a/src/experiments/config_parser.py +++ b/src/experiments/config_parser.py @@ -11,10 +11,10 @@ def unfold_raw_config(d: Dict[str, Any]): - """Unfolds an ordered DAG given as a dictionary into a tree given as dictionary. + """Unfolds an ordered DAG given as a dictionary into a tree given as dictionary. That means that unfold_dict(d) is bisimilar to d but no two distinct key paths in the resulting dictionary reference the same object - + :param d: The dictionary to unfold """ du = dict() @@ -25,26 +25,27 @@ def unfold_raw_config(d: Dict[str, Any]): du[k] = [unfold_raw_config(x) for x in v] else: du[k] = deepcopy(v) - + return du + def push_overwrites(item: Any, attributes: Dict[str, Any]) -> Any: """Pushes the overwrites in the given dictionary to the given item. - + If the item already specifies an overwrite, it is updated. - If the item is a dictionary, an overwrite specification for the dictionary is created. + If the item is a dictionary, an overwrite specification for the dictionary is created. If the item is a list, the overwrites are pushed each element and the processed list is returned. Otherwise, item is overwritten by attributen - + :param item: The item to push the overwrites to. :param overwrites: The overwrites to push. """ try: if "__exact__" in attributes: - return deepcopy(attributes["__exact__"]) + return deepcopy(attributes["__exact__"]) except: pass - + if isinstance(item, dict): if "__overwrites__" not in item: result = deepcopy(attributes) @@ -52,74 +53,73 @@ def push_overwrites(item: Any, attributes: Dict[str, Any]) -> Any: else: result = item result["__overwrites__"].update(attributes) - elif isinstance(item, list): + elif isinstance(item, list): result = [push_overwrites(x, attributes) for x in item] else: result = deepcopy(attributes) - + return result - + def apply_overwrite(d: Dict[str, Any], recurse: bool = True): """Applies the "__overwrites__" keyword sematic to a unfolded raw config dictionary and returns the result. Except for the special semantics that applies to dictionaries and lists (see below), - all keys $k$ that are present in in the "__overwrites__" dictionary $o$ are overwritten in $d$ by $o[k]$. 
- + all keys $k$ that are present in in the "__overwrites__" dictionary $o$ are overwritten in $d$ by $o[k]$. + ** Dict/List overwrites **: - - If $d[k]$ is a dictionary, then $d[k]$ must be a dictionary and overwrites of $o[k]$ are recursively + - If $d[k]$ is a dictionary, then $d[k]$ must be a dictionary and overwrites of $o[k]$ are recursively to the corresponding $d[k]$, i.e. lower-lever overwrites are specified/updated (see notes on recursion). The only exception is if $o[k]$ contains the special key "__exact__" with value True. In this case $d[k$]$ is replaced by $o[k]["__exact__"]$. - - If $o[k]$ is a list, then $o[k]$ is pushed to all list elements. - + - If $o[k]$ is a list, then $o[k]$ is pushed to all list elements. + ** Recursion **: - If recursion is enabled, overwrites are are fully expanded in infix order where nested overwrites - (see behavior on dict/list overwrites) are pushed (and overwrite) to the next level, i.e. + If recursion is enabled, overwrites are are fully expanded in infix order where nested overwrites + (see behavior on dict/list overwrites) are pushed (and overwrite) to the next level, i.e. higher level overwrites lower level. Else, only the top-level overwrites are applied. - + ** Note **: Applying this function to a non-unfolded dictionary can result in unexpected behavior due to side side-effects. - + :param d: The unfolded raw config dictionary. - :pram recurse: If True, the overwrites are applied recursively. Defaults to True. + :pram recurse: If True, the overwrites are applied recursively. Defaults to True. Can be useful for efficient combination of this method with other parsing methods. """ - + # Apply top-level overwrite if "__overwrites__" in d: - overwritten_attr = d + overwritten_attr = d d = overwritten_attr.pop("__overwrites__") - + for k, v in overwritten_attr.items(): if k not in d: d[k] = v else: d[k] = push_overwrites(d[k], v) - - + if recurse: for k, v in d.items(): if isinstance(v, dict): d[k] = apply_overwrite(v) elif isinstance(v, list): d[k] = [apply_overwrite(x) for x in v] - + return d - - + + def read_config(yaml_path: Union[str, Path]) -> dict: """Loads a yaml file and returns the corresponding dictionary. Besides the standard yaml syntax, the function also supports the following additional functionality: - + Special keys: - __class__: The value of this key is interpreted as the class name of the object. + __class__: The value of this key is interpreted as the class name of the object. The class is imported and stored in the result dictionary under the key . Example: entry in yaml: __class__model: laplace_flows.flows.NiceFlow) entry in result: model: __import__("laplace_flows.flows.NiceFlow") - __tune__: The value of this key is interpreted as a dictionary that contains the - configuration for the hyperparameter optimization using tune sample methods. + __tune__: The value of this key is interpreted as a dictionary that contains the + configuration for the hyperparameter optimization using tune sample methods. the directive is evaluated and the result in the result dictionary under the key . Example: entry in yaml: __tune__lr: loguniform(1e-4, 1e-1) @@ -127,15 +127,16 @@ def read_config(yaml_path: Union[str, Path]) -> dict: :param yaml_path: Path to the yaml file. 
""" - + with open(yaml_path, "r") as f: config = yaml.load(f, Loader=yaml.FullLoader) - - config = unfold_raw_config(config) + + config = unfold_raw_config(config) config = parse_raw_config(config) return config - + + def parse_raw_config(d: dict): """Parses an unfolded raw config dictionary and returns the corresponding dictionary. Parsing includes the following steps: @@ -143,16 +144,16 @@ def parse_raw_config(d: dict): - The "__object__" key is interpreted as a class name and the corresponding class is imported. - The "__eval__" key is evaluated. - The "__class__" key is interpreted as a class name and the corresponding class is imported. - + :param d: The raw config dictionary. """ if isinstance(d, dict): d = apply_overwrite(d, recurse=False) - + # Depth-first recursion for k, v in d.items(): d[k] = parse_raw_config(v) - + if "__object__" in d: module, cls = d["__object__"].rsplit(".", 1) C = getattr(import_module(module), cls) @@ -164,12 +165,10 @@ def parse_raw_config(d: dict): module, cls = d["__class__"].rsplit(".", 1) C = getattr(import_module(module), cls) return C - else: + else: return d elif isinstance(d, list): result = [parse_raw_config(x) for x in d] return result else: return d - - diff --git a/src/experiments/datasets.py b/src/experiments/datasets.py index 9fe08a8..6b4407a 100644 --- a/src/experiments/datasets.py +++ b/src/experiments/datasets.py @@ -7,8 +7,7 @@ import pandas as pd import torch import torchvision.transforms as transforms -from sklearn.datasets import (make_blobs, make_checkerboard, make_circles, - make_moons) +from sklearn.datasets import make_blobs, make_checkerboard, make_circles, make_moons from torch import Tensor from torchvision.datasets import MNIST, FashionMNIST @@ -17,87 +16,99 @@ class DequantizedDataset(torch.utils.data.Dataset): """ A dataset that dequantizes the data by adding uniform noise to each pixel. 
""" - def __init__(self, dataset: T.Union[os.PathLike, torch.utils.data.Dataset, np.ndarray], num_bits: int = 8): - if isinstance(dataset, torch.utils.data.Dataset) or isinstance(dataset, np.ndarray): + + def __init__( + self, + dataset: T.Union[os.PathLike, torch.utils.data.Dataset, np.ndarray], + num_bits: int = 8, + ): + if isinstance(dataset, torch.utils.data.Dataset) or isinstance( + dataset, np.ndarray + ): self.dataset = dataset else: self.dataset = pd.read_csv(dataset).values - + self.num_bits = num_bits - self.num_levels = 2 ** num_bits - self.transform = transforms.Compose([ - transforms.Lambda(lambda x: x / self.num_levels), - transforms.Lambda(lambda x: x + torch.rand_like(x) / self.num_levels) - ]) + self.num_levels = 2**num_bits + self.transform = transforms.Compose( + [ + transforms.Lambda(lambda x: x / self.num_levels), + transforms.Lambda(lambda x: x + torch.rand_like(x) / self.num_levels), + ] + ) def __getitem__(self, index: int): - x, y = self.dataset[index] x = Tensor(self.transform(x)) return x, y def __len__(self): return len(self.dataset) - + + class DataSplit: def __init__(*agrs, **kwargs): raise NotImplementedError - + @abstractmethod def get_train(self) -> torch.utils.data.Dataset: raise NotImplementedError - + @abstractmethod def get_test(self) -> torch.utils.data.Dataset: raise NotImplementedError - + @abstractmethod def get_val(self) -> torch.utils.data.Dataset: raise NotImplementedError - + + class SimpleSplit(DataSplit): - """ + """ Split of dataset """ + def __init__( self, train: torch.utils.data.Dataset, test: torch.utils.data.Dataset, - val: torch.utils.data.Dataset + val: torch.utils.data.Dataset, ): - """ Create split of dataset - + """Create split of dataset + Args: train (torch.utils.data.Dataset): training set test (torch.utils.data.Dataset): test set - val (torch.utils.data.Dataset): validation set + val (torch.utils.data.Dataset): validation set """ self.train = train self.test = test self.val = val - + def get_train(self) -> torch.utils.data.Dataset: return self.train - + def get_test(self) -> torch.utils.data.Dataset: return self.test - + def get_val(self) -> torch.utils.data.Dataset: return self.val - - - + + GENERATORS = { "make_moons": make_moons, "make_blobs": make_blobs, - "make_checkerboard": make_checkerboard, - "make_circles": make_circles + "make_checkerboard": make_checkerboard, + "make_circles": make_circles, } + class SyntheticDataset(torch.utils.data.Dataset): """ Dataset from generator function - """ + """ + def __init__( self, generator: T.Union[T.Callable[..., np.ndarray], str], @@ -106,7 +117,7 @@ def __init__( **kwargs ): """Create dataset from generator function - + Args: generator (function): generator function params: ]dict]: parameters for generator function @@ -114,21 +125,23 @@ def __init__( super().__init__(*args, **kwargs) if isinstance(generator, str): generator = GENERATORS[generator] - + self.dataset = generator(**params)[0] - + def __getitem__(self, index: int): x = self.dataset[index] x = Tensor(x) return x, torch.zeros_like(x) - + def __len__(self): return len(self.dataset) + class SyntheticSplit(SimpleSplit): - """ + """ Split of synthetic dataset """ + def __init__( self, generator: T.Union[T.Callable[..., np.ndarray], str], @@ -139,23 +152,25 @@ def __init__( **kwargs ): """Create dataset from generator function - + Args: generator (function): generator function params: ]dict]: parameters for generator function """ if isinstance(generator, str): generator = GENERATORS[generator] - + train = 
SyntheticDataset(generator, params_train) test = SyntheticDataset(generator, params_test) val = SyntheticDataset(generator, params_val) super().__init__(train=train, test=test, val=val, *args, **kwargs) - + + class FlattenedDataset(torch.utils.data.Dataset): """ A dataset that flattens the data. """ + def __init__(self, dataset: torch.utils.data.Dataset): self.dataset = dataset @@ -166,33 +181,53 @@ def __getitem__(self, index: int): def __len__(self): return len(self.dataset) - + + class FashionMnistDequantized(DequantizedDataset): - def __init__(self, dataloc: os.PathLike = None, train: bool = True, label: T.Optional[int] = None): - rel_path = "FashionMNIST/raw/train-images-idx3-ubyte" if train else "FashionMNIST/raw/t10k-images-idx3-ubyte" + def __init__( + self, + dataloc: os.PathLike = None, + train: bool = True, + label: T.Optional[int] = None, + ): + rel_path = ( + "FashionMNIST/raw/train-images-idx3-ubyte" + if train + else "FashionMNIST/raw/t10k-images-idx3-ubyte" + ) path = os.path.join(dataloc, rel_path) if not os.path.exists(path): FashionMNIST(path, train=train, download=True) # TODO: remove hardcoding of 3x3 downsampling, vectorizing dataset = idx2numpy.convert_from_file(path)[:, ::3, ::3] - dataset = dataset.reshape(dataset.shape[0], -1) + dataset = dataset.reshape(dataset.shape[0], -1) if label is not None: - rel_path = "FashionMNIST/raw/train-labels-idx1-ubyte" if train else "FashionMNIST/raw/t10k-labels-idx1-ubyte" + rel_path = ( + "FashionMNIST/raw/train-labels-idx1-ubyte" + if train + else "FashionMNIST/raw/t10k-labels-idx1-ubyte" + ) path = os.path.join(dataloc, rel_path) labels = idx2numpy.convert_from_file(path) - dataset = dataset[labels == label] + dataset = dataset[labels == label] super().__init__(dataset, num_bits=8) - - def __getitem__(self, index: int): + def __getitem__(self, index: int): x = Tensor(self.dataset[index].copy()) x = self.transform(x) return x, 0 + class MnistDequantized(DequantizedDataset): - def __init__(self, dataloc: os.PathLike = None, train: bool = True, digit: T.Optional[int] = None, flatten=True): + def __init__( + self, + dataloc: os.PathLike = None, + train: bool = True, + digit: T.Optional[int] = None, + flatten=True, + ): if train: - rel_path = "MNIST/raw/train-images-idx3-ubyte" + rel_path = "MNIST/raw/train-images-idx3-ubyte" else: rel_path = "MNIST/raw/t10k-images-idx3-ubyte" path = os.path.join(dataloc, rel_path) @@ -202,74 +237,94 @@ def __init__(self, dataloc: os.PathLike = None, train: bool = True, digit: T.Opt # TODO: remove hardcoding of 3x3 downsampling dataset = idx2numpy.convert_from_file(path)[:, ::3, ::3] if flatten: - dataset = dataset.reshape(dataset.shape[0], -1) + dataset = dataset.reshape(dataset.shape[0], -1) if digit is not None: if train: - rel_path = "MNIST/raw/train-labels-idx1-ubyte" + rel_path = "MNIST/raw/train-labels-idx1-ubyte" else: rel_path = "MNIST/raw/t10k-labels-idx1-ubyte" path = os.path.join(dataloc, rel_path) labels = idx2numpy.convert_from_file(path) - dataset = dataset[labels == digit] + dataset = dataset[labels == digit] super().__init__(dataset, num_bits=8) - - def __getitem__(self, index: int): + def __getitem__(self, index: int): x = Tensor(self.dataset[index].copy()) x = self.transform(x) return x, 0 + class DataSplitFromCSV(DataSplit): def __init__(self, train: os.PathLike, test: os.PathLike, val: os.PathLike): self.train = train self.test = test self.val = val - + def get_train(self) -> torch.utils.data.Dataset: return pd.read_csv(self.train).values - + def get_test(self) -> 
torch.utils.data.Dataset: return pd.read_csv(self.test).values - + def get_val(self) -> torch.utils.data.Dataset: return pd.read_csv(self.val).values - + + class FashionMnistSplit(DataSplit): - def __init__(self, dataloc: os.PathLike = None, val_split: float = .1, label: T.Optional[int] = None): + def __init__( + self, + dataloc: os.PathLike = None, + val_split: float = 0.1, + label: T.Optional[int] = None, + ): if dataloc is None: dataloc = os.path.join(os.getcwd(), "data") self.dataloc = dataloc self.train = FashionMnistDequantized(self.dataloc, train=True, label=label) shuffle = torch.randperm(len(self.train)) - self.val = torch.utils.data.Subset(self.train, shuffle[:int(len(self.train) * val_split)]) - self.train = torch.utils.data.Subset(self.train, shuffle[int(len(self.train) * val_split):]) + self.val = torch.utils.data.Subset( + self.train, shuffle[: int(len(self.train) * val_split)] + ) + self.train = torch.utils.data.Subset( + self.train, shuffle[int(len(self.train) * val_split) :] + ) self.test = FashionMnistDequantized(self.dataloc, train=False, label=label) - + def get_train(self) -> torch.utils.data.Dataset: return self.train - + def get_test(self) -> torch.utils.data.Dataset: return self.test - + def get_val(self) -> torch.utils.data.Dataset: return self.val - + + class MnistSplit(DataSplit): - def __init__(self, dataloc: os.PathLike = None, val_split: float = .1, digit: T.Optional[int] = None): + def __init__( + self, + dataloc: os.PathLike = None, + val_split: float = 0.1, + digit: T.Optional[int] = None, + ): if dataloc is None: dataloc = os.path.join(os.getcwd(), "data") self.dataloc = dataloc self.train = MnistDequantized(self.dataloc, train=True, digit=digit) shuffle = torch.randperm(len(self.train)) - self.val = torch.utils.data.Subset(self.train, shuffle[:int(len(self.train) * val_split)]) - self.train = torch.utils.data.Subset(self.train, shuffle[int(len(self.train) * val_split):]) + self.val = torch.utils.data.Subset( + self.train, shuffle[: int(len(self.train) * val_split)] + ) + self.train = torch.utils.data.Subset( + self.train, shuffle[int(len(self.train) * val_split) :] + ) self.test = MnistDequantized(self.dataloc, train=False, digit=digit) - + def get_train(self) -> torch.utils.data.Dataset: return self.train - + def get_test(self) -> torch.utils.data.Dataset: return self.test - + def get_val(self) -> torch.utils.data.Dataset: - return self.val \ No newline at end of file + return self.val diff --git a/src/experiments/hyperopt.py b/src/experiments/hyperopt.py index 9eda47f..6ec467c 100644 --- a/src/experiments/hyperopt.py +++ b/src/experiments/hyperopt.py @@ -19,13 +19,25 @@ from src.veriflow.networks import AdditiveAffineNN from src.veriflow.transforms import ScaleTransform -HyperParams = Literal["train", "test", "coupling_layers", "coupling_nn_layers", "split_dim", "epochs", "iters", "batch_size", - "optim", "optim_params", "base_dist"] +HyperParams = Literal[ + "train", + "test", + "coupling_layers", + "coupling_nn_layers", + "split_dim", + "epochs", + "iters", + "batch_size", + "optim", + "optim_params", + "base_dist", +] BaseDisbributions = Literal["Laplace", "Normal"] + class HyperoptExperiment(Experiment): """Hyperparameter optimization experiment.""" - + def __init__( self, trial_config: Dict[str, Any], @@ -35,10 +47,10 @@ def __init__( scheduler: tune.schedulers.FIFOScheduler, tuner_params: T.Dict[str, T.Any], *args, - **kwargs + **kwargs, ) -> None: """Initialize hyperparameter optimization experiment. 
- + Args: trial_config (Dict[str, Any]): trial configuration num_hyperopt_samples (int): number of hyperparameter optimization samples @@ -66,12 +78,12 @@ def _trial(cls, config: T.Dict[str, T.Any], device: torch.device = "cpu"): if device is None: if torch.backends.mps.is_available(): device = torch.device("mps") - #torch.mps.empty_cache() + # torch.mps.empty_cache() elif torch.cuda.is_available(): device = torch.device("cuda") else: device = torch.device("cpu") - + dataset = config["dataset"] data_train = dataset.get_train() data_test = dataset.get_test() @@ -86,13 +98,11 @@ def _trial(cls, config: T.Dict[str, T.Any], device: torch.device = "cpu"): base_dist = torch.distributions.Normal(zeros, ones) else: raise ValueError("Unknown base distribution") - + config["model_cfg"]["params"]["base_distribution"] = base_dist - - flow = config["model_cfg"]["type"]( - **(config["model_cfg"]["params"]) - ) - + + flow = config["model_cfg"]["type"](**(config["model_cfg"]["params"])) + flow.to(device) best_loss = float("inf") @@ -103,33 +113,41 @@ def _trial(cls, config: T.Dict[str, T.Any], device: torch.device = "cpu"): config["optim_cfg"]["optimizer"], config["optim_cfg"]["params"], batch_size=config["batch_size"], - device=device, + device=device, ) val_loss = 0 for i in range(0, len(data_val), config["batch_size"]): - j = min([len(data_test), i+config["batch_size"]]) + j = min([len(data_test), i + config["batch_size"]]) val_loss += float(-flow.log_prob(data_val[i:j][0].to(device)).sum()) val_loss /= len(data_val) - session.report({"test_loss": "?", "train_loss": train_loss, "val_loss": val_loss}, checkpoint=None) + session.report( + {"test_loss": "?", "train_loss": train_loss, "val_loss": val_loss}, + checkpoint=None, + ) if val_loss < best_loss: strikes = 0 best_loss = val_loss torch.save(flow.state_dict(), "./checkpoint.pt") test_loss = 0 for i in range(0, len(data_test), config["batch_size"]): - j = min([len(data_test), i+config["batch_size"]]) - test_loss += float(-flow.log_prob(data_test[i:j][0].to(device)).sum()) + j = min([len(data_test), i + config["batch_size"]]) + test_loss += float( + -flow.log_prob(data_test[i:j][0].to(device)).sum() + ) test_loss /= len(data_test) else: strikes += 1 if strikes >= config["patience"]: break # torch.mps.empty_cache() - - return {"test_loss_best": test_loss, "val_loss_best": best_loss, "val_loss": val_loss} + return { + "test_loss_best": test_loss, + "val_loss_best": best_loss, + "val_loss": val_loss, + } def conduct(self, report_dir: os.PathLike, storage_path: os.PathLike = None): """Run hyperparameter optimization experiment. @@ -138,47 +156,55 @@ def conduct(self, report_dir: os.PathLike, storage_path: os.PathLike = None): report_dir (os.PathLike): report directory storage_path (os.PathLike, optional): Ray logging path. Defaults to None. 
""" - home = os.path.expanduser( '~' ) - + home = os.path.expanduser("~") + if storage_path is not None: tuner_config = {"run_config": RunConfig(storage_path=storage_path)} else: storage_path = os.path.expanduser("~/ray_results") tuner_config = {} - + exptime = str(datetime.now()) - + tuner = tune.Tuner( tune.with_resources( tune.with_parameters(HyperoptExperiment._trial), - resources={"cpu": self.cpus_per_trial, "gpu": self.gpus_per_trial} + resources={"cpu": self.cpus_per_trial, "gpu": self.gpus_per_trial}, ), tune_config=tune.TuneConfig( scheduler=self.scheduler, num_samples=self.num_hyperopt_samples, - **(self.tuner_params) + **(self.tuner_params), ), param_space=self.trial_config, - **(tuner_config) + **(tuner_config), ) results = tuner.fit() - + # TODO: hacky way to dertmine the last experiment - exppath = storage_path + ["/" + f for f in sorted(os.listdir(storage_path)) if f.startswith("_trial")][-1] - report_file = os.path.join(report_dir, f"report_{self.name}_" + exptime + ".csv") + exppath = ( + storage_path + + [ + "/" + f + for f in sorted(os.listdir(storage_path)) + if f.startswith("_trial") + ][-1] + ) + report_file = os.path.join( + report_dir, f"report_{self.name}_" + exptime + ".csv" + ) self._build_report(exppath, report_file=report_file) - #best_result = results.get_best_result("val_loss", "min") + # best_result = results.get_best_result("val_loss", "min") - #print("Best trial config: {}".format(best_result.config)) - #print("Best trial final validation loss: {}".format( + # print("Best trial config: {}".format(best_result.config)) + # print("Best trial final validation loss: {}".format( # best_result.metrics["val_loss"])) - - #test_best_model(best_result) + # test_best_model(best_result) def _build_report(self, expdir: str, report_file: str, config_prefix: str = ""): """Builds a report of the hyperopt experiment. - + :param expdir: The expdir parameter is the path to the experiment directory (ray results folder). :type expdir: str :param report_file: The report_file parameter is the path to the report file. 
@@ -191,21 +217,26 @@ def _build_report(self, expdir: str, report_file: str, config_prefix: str = ""): if os.path.isdir(expdir + "/" + d): try: with open(expdir + "/" + d + "/result.json", "r") as f: - result = json.loads("{\"test_" + f.read().split("{\"test_")[-1]) + result = json.loads('{"test_' + f.read().split('{"test_')[-1]) except: print(f"error at {expdir + '/' + d}") continue - + config = result["config"] for k in config.keys(): - result[config_prefix + k] = config[k] if not isinstance(config[k], Iterable) else str(config[k]) + result[config_prefix + k] = ( + config[k] + if not isinstance(config[k], Iterable) + else str(config[k]) + ) result.pop("config") if report is None: report = pd.DataFrame(result, index=[0]) else: - report = pd.concat([report, pd.DataFrame(result, index=[0])], ignore_index=True) - + report = pd.concat( + [report, pd.DataFrame(result, index=[0])], ignore_index=True + ) + os.makedirs(os.path.dirname(report_file), exist_ok=True) report.to_csv(report_file, index=False) - diff --git a/src/veriflow/flows.py b/src/veriflow/flows.py index 08f8108..97d56a1 100644 --- a/src/veriflow/flows.py +++ b/src/veriflow/flows.py @@ -2,7 +2,7 @@ from torch.utils.data import Dataset -import pyro +import pyro from pyro import distributions as dist from pyro.distributions.transforms import SoftplusTransform from pyro.nn import DenseNN @@ -14,16 +14,24 @@ from sklearn.datasets import load_digits from tqdm import tqdm -from src.veriflow.transforms import ScaleTransform, MaskedCoupling, Permute, LUTransform, LeakyReLUTransform +from src.veriflow.transforms import ( + ScaleTransform, + MaskedCoupling, + Permute, + LUTransform, + LeakyReLUTransform, + BaseTransform, +) from src.veriflow.networks import AdditiveAffineNN, ConvNet2D class Flow(torch.nn.Module): """Base implementation of a flow model""" + # Export mode determines whether the log_prob or the sample function is exported to onnx export_modes = Literal["log_prob", "sample"] export = "log_prob" - + def forward(self, x: torch.Tensor): """Dummy implementation of forward method for onnx export. The self.export attribute determines whether the log_prob or the sample function is exported to onnx""" @@ -33,36 +41,41 @@ def forward(self, x: torch.Tensor): return self.sample() else: raise ValueError(f"Unknown export mode {self.export}") - + def __init__(self, base_distribution, layers, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.layers = layers - self.trainable_layers = torch.nn.ModuleList([l for l in layers if isinstance(l, torch.nn.Module)]) + self.trainable_layers = torch.nn.ModuleList( + [l for l in layers if isinstance(l, torch.nn.Module)] + ) self.base_distribution = base_distribution self.transform = dist.TransformedDistribution(base_distribution, layers) def fit( - self, - data_train: Dataset, - optim: torch.optim.Optimizer = torch.optim.Adam, - optim_params: Dict[str, Any] = None, - batch_size: int = 32, - shuffe: bool = True, - gradient_clip: float = None, - device: torch.device = None, - jitter: float = 1e-4, - ) -> float: + self, + data_train: Dataset, + optim: torch.optim.Optimizer = torch.optim.Adam, + optim_params: Dict[str, Any] = None, + batch_size: int = 32, + shuffe: bool = True, + gradient_clip: float = None, + device: torch.device = None, + jitter: float = 1e-4, + ) -> float: """ - Wrapper function for the fitting procedure. Allows basic configuration of the optimizer and other fitting parameters. + Wrapper function for the fitting procedure. 
Allows basic configuration of the optimizer and other + fitting parameters. - @param data_train: training data. - @param batch_size: number of samples per optimization step. - @param optim: optimizer class. - @param optimizer_params: optimizer parameter dictionary. + Args: + data_train: training data. + batch_size: number of samples per optimization step. + optim: optimizer class. + optimizer_params: optimizer parameter dictionary. - @returns loss curve (negative log-likelihood). + Returns: + Loss curve (negative log-likelihood). """ if device is None: if torch.backends.mps.is_available(): @@ -71,9 +84,9 @@ def fit( device = torch.device("cuda") else: device = torch.device("cpu") - + model = self.to(device) - + if optim_params is not None: optim = optim(model.trainable_layers.parameters(), **optim_params) else: @@ -86,7 +99,6 @@ def fit( perm = np.random.choice(N, N, replace=False) data_train = data_train[perm] - for idx in range(0, N, batch_size): idx_end = min(idx + batch_size, N) try: @@ -104,76 +116,90 @@ def fit( optim.step() while not self.is_feasible(): self.add_jitter(jitter) - + model.transform.clear_cache() - + return sum(losses) / len(losses) - + def to_onnx(self, path: str, export_mode: export_modes = "log_prob") -> None: """Saves the model as onnx file - - :param path: path to save the model. - :param export_mode: export mode. Can be "log_prob" or "sample". + + Args: + path: path to save the model. + export_mode: export mode. Can be "log_prob" or "sample". """ self.export = export_mode dummy_input = self.base_distribution.sample() torch.onnx.export(self, dummy_input, path, verbose=True) - - def log_prob(self, x: torch.Tensor): + + def log_prob(self, x: torch.Tensor) -> torch.Tensor: """Returns the models log-densities for the given samples - @param x: sample tensor. + Args: + x: sample tensor. """ return self.transform.log_prob(x) - - def sample(self, sample_shape: Iterable[int] = None): + + def sample(self, sample_shape: Iterable[int] = None) -> torch.Tensor: """Returns n_sample samples from the distribution - - @param n_sample: sample shape. + + Args: + n_sample: sample shape. 
""" if sample_shape is None: sample_shape = [1] return self.transform.sample(sample_shape) - - def to(self, device): + + def to(self, device) -> None: """Moves the model to the given device""" self.device = device - #self.layers = torch.nn.ModuleList([l.to(device) for l in self.layers]) - self.trainable_layers = torch.nn.ModuleList([l.to(device) for l in self.trainable_layers]) + # self.layers = torch.nn.ModuleList([l.to(device) for l in self.layers]) + self.trainable_layers = torch.nn.ModuleList( + [l.to(device) for l in self.trainable_layers] + ) return super().to(device) - - def is_feasible(self): - return True - - def add_jitter(self, jitter): - pass -Permutation = Literal["random", "half", "LU"] + def is_feasible(self) -> bool: + """Checks is the model parameters meet all constraints""" + return all( + [l.is_feasible() for l in self.layers if isinstance(l, BaseTransform)] + ) + + def add_jitter(self, jitter: float = 1e-6) -> None: + """Adds jitter to meet non-zero constraints""" + for l in self.layers: + if isinstance(l, BaseTransform) and not l.is_feasible(): + l.add_jitter(jitter) + class NiceFlow(Flow): + Permutation = Literal["random", "half", "LU"] + """Implementation of the NICE flow architecture by using fully connected coupling layers""" + def __init__( - self, - base_distribution: dist.Distribution, - coupling_layers: int, - coupling_nn_layers: List[int], - split_dim: int, - scale_every_coupling=False, - nonlinearity: Optional[torch.nn.Module] = None, - permutation: Permutation = "random", - *args, - **kwargs - ) -> None: + self, + base_distribution: dist.Distribution, + coupling_layers: int, + coupling_nn_layers: List[int], + split_dim: int, + scale_every_coupling=False, + nonlinearity: Optional[torch.nn.Module] = None, + permutation: Permutation = "random", + *args, + **kwargs, + ) -> None: """Initialization - @param base_distribution: base distribution, - @param coupling_layers: number of coupling layers. All coupling layers share the same architecture but not the same weights. - @param coupling_nn_layers: number of neurons in the hidden layers of the dense neural network that computes the coupling loc parameter. - @param split_dim: split dimension for the coupling. - @param scale_every_coupling: if True, a scale transform is applied after every coupling layer. Otherwise, a single scale transform is applied after all coupling layers. - @param nonlinearity: nonlinearity of the coupling network. - @param permutation: permutation type. Can be "random" or "half". + Args: + base_distribution: base distribution, + coupling_layers: number of coupling layers. All coupling layers share the same architecture but not the same weights. + coupling_nn_layers: number of neurons in the hidden layers of the dense neural network that computes the coupling loc parameter. + split_dim: split dimension for the coupling. + scale_every_coupling: if True, a scale transform is applied after every coupling layer. Otherwise, a single scale transform is applied after all coupling layers. + nonlinearity: nonlinearity of the coupling network. + permutation: permutation type. Can be "random" or "half". 
""" input_dim = base_distribution.sample().shape[0] self.input_dim = input_dim @@ -188,54 +214,62 @@ def __init__( layers = [] for i in range(coupling_layers): layers.append(self._get_permutation(permutation, i)) - layers.append(AffineCoupling(split_dim, AdditiveAffineNN(split_dim, coupling_nn_layers, rdim, nonlinearity=nonlinearity))) + layers.append( + AffineCoupling( + split_dim, + AdditiveAffineNN( + split_dim, coupling_nn_layers, rdim, nonlinearity=nonlinearity + ), + ) + ) if scale_every_coupling: layers.append(ScaleTransform(input_dim)) - + if not scale_every_coupling: layers.append(ScaleTransform(input_dim)) super().__init__(base_distribution, layers, *args, **kwargs) - + def _get_permutation(self, permtype: Permutation, i=0): - """Returns a permutation layer - """ + """Returns a permutation layer""" if permtype == "random": return Permute(torch.randperm(self.input_dim, dtype=torch.long)) elif permtype == "half": - if i % 2 == 0: # every 2nd pixel + if i % 2 == 0: # every 2nd pixel perm = torch.arange(self.input_dim, dtype=torch.long) perm = perm.reshape(-1, 2).moveaxis(0, 1).reshape(-1) - elif i % 2 == 1: # interchange conditioning variables and output variables + elif i % 2 == 1: # interchange conditioning variables and output variables perm = torch.arange(self.input_dim, dtype=torch.long) perm = perm.reshape(2, -1).flip(0).reshape(-1) - else: # random permutation + else: # random permutation perm = torch.randperm(self.input_dim, dtype=torch.long) return Permute(perm) elif permtype == "LU": - return LUTransform(self.input_dim) + return LUTransform(self.input_dim) else: raise ValueError(f"Unknown permutation type {permtype}") + class NiceMaskedConvFlow(Flow): - """Implementation of the NICE flow architecture using fully connected coupling layers and a checkerboard permutation""" + """Implementation of the NICE flow architecture using fully connected coupling layers + and a checkerboard permutation""" + def __init__( - self, - base_distribution: dist.Distribution, - coupling_layers: int, - conv_layers: int, - kernel_size: int, - nonlinearity: Optional[torch.nn.Module] = None, - c_hidden: int = 32, - rescale_hidden: Union[int, Tuple[int]] = 4, - *args, - **kwargs - ) -> None: + self, + base_distribution: dist.Distribution, + coupling_layers: int, + conv_layers: int, + kernel_size: int, + nonlinearity: Optional[torch.nn.Module] = None, + c_hidden: int = 32, + rescale_hidden: Union[int, Tuple[int]] = 4, + *args, + **kwargs, + ) -> None: """Initialization Args: - base_distribution: base distribution, coupling_layers: number of coupling layers. All coupling layers share the same architecture but not the same weights. coupling_nn_layers: number of hidden convolutional layers of the network that computes the coupling loc parameter. @@ -244,7 +278,7 @@ def __init__( c_hidden: number of hidden channels of the convolutional layers. rescale_hidden: rescaling of hight and width for the hidden layers. 
""" - + self.coupling_layers = coupling_layers if nonlinearity is None: @@ -252,65 +286,73 @@ def __init__( c, h, w = base_distribution.sample().shape mask = NiceMaskedConvFlow.create_checkerboard_mask(h, w) - + layers = [] self.masks = [] for i in range(coupling_layers): layers.append( MaskedCoupling( - mask, + mask, ConvNet2D( mask.shape[0], - num_layers=conv_layers, - nonlinearity=nonlinearity, - kernel_size=kernel_size, + num_layers=conv_layers, + nonlinearity=nonlinearity, + kernel_size=kernel_size, c_hidden=c_hidden, - rescale_hidden=rescale_hidden - ) + rescale_hidden=rescale_hidden, + ), ) ) self.masks.append(mask) mask = 1 - mask - + layers.append(ScaleTransform(mask.shape)) super().__init__(base_distribution, layers, *args, **kwargs) - + @classmethod - def create_checkerboard_mask(cls, h: int, w: int, invert: bool=False) -> torch.Tensor: + def create_checkerboard_mask( + cls, h: int, w: int, invert: bool = False + ) -> torch.Tensor: """Creates a checkerboard mask of size $(h,w)$. - :param h (_type_): height - :param w (_type_): width - :param invert (bool, optional): If True, inverts the mask. Defaults to False. - :returns: Checkerboard mask of height $h$ and width $w$. + Args: + h (_type_): height + w (_type_): width + invert (bool, optional): If True, inverts the mask. Defaults to False. + Returns: + Checkerboard mask of height $h$ and width $w$. """ x, y = torch.arange(h, dtype=torch.int32), torch.arange(w, dtype=torch.int32) - xx, yy = torch.meshgrid(x, y, indexing='ij') + xx, yy = torch.meshgrid(x, y, indexing="ij") mask = torch.fmod(xx + yy, 2) mask = mask.to(torch.float32).view(1, 1, h, w) if invert: mask = 1 - mask return mask - + class LUFlow(Flow): - """Implementation of the NICE flow architecture using fully connected coupling layers and a checkerboard permutation""" + """Implementation of the NICE flow architecture using fully connected coupling layers + and a checkerboard permutation + """ + def __init__( - self, - base_distribution: dist.Distribution, - n_layers: int, - nonlinearity: Optional[torch.distributions.Transform] = None, - *args, - **kwargs - ) -> None: + self, + base_distribution: dist.Distribution, + n_layers: int, + nonlinearity: Optional[torch.distributions.Transform] = None, + *args, + **kwargs, + ) -> None: """Initialization - :param: base_distribution: base distribution, - :param: n_layers: number of LU-layers. - :param: nonlinearity: nonlinearity of the convolutional layers. + Args: + base_distribution: base distribution, + n_layers: number of LU-layers. + nonlinearity: nonlinearity of the convolutional layers. 
""" - + self.n_layers = n_layers if nonlinearity is None: @@ -327,13 +369,12 @@ def __init__( layers.append(nonlinearity()) super().__init__(base_distribution, layers, *args, **kwargs) - + def is_feasible(self): - """ Checks if all LU layers are feasible """ + """Checks if all LU layers are feasible""" return all([l.is_feasible() for l in self.layers if isinstance(l, LUTransform)]) - - def add_jitter(self, jitter): - for layer in self.layers : + + def add_jitter(self, jitter: float = 1e-6) -> None: + for layer in self.layers: if isinstance(layer, LUTransform) and not layer.is_feasible(): layer.add_jitter(jitter) - diff --git a/src/veriflow/linalg.py b/src/veriflow/linalg.py new file mode 100644 index 0000000..a79c866 --- /dev/null +++ b/src/veriflow/linalg.py @@ -0,0 +1,50 @@ +from typing import Optional + +import torch + + +def solve_triangular(M: torch.Tensor, y: torch.Tensor, pivot: Optional[int]=None) -> torch.Tensor: + """ Re-implementation of torch solve_triangular. Since Onnx export of the native method is currently not supported, + we implement it with basic torch operations. + + Args: + M: triangular matrix. May be upper or lower triangular + y: input vector. + pivot: If given, determines wether to treat $M$ as a lower or upper triangular matrix. Note that in this case + there is no check wether $M$ is actually lower or upper triangular, respectively. It therefore + speeds up computation but should be used with care. + Returns: + (torch.Tensor): Solution of the system $Mx=y$ + """ + + dim = M.size(0) + if dim == 1: + return y / M + + if pivot is None: + # Determine orientation of Matrix + if all([M[i, j] == 0. for i in range(dim) for j in range(i+1, dim)]): + pivot = 0 + elif all([M[i, j] == 0. for i in range(dim) for j in range(0, i)]): + pivot = -1 + else: + raise ValueError("M needs to be triangular.") + elif pivot not in [0, -1]: + raise ValueError("pivot needs to be either None, 0, or -1.") + + + x = torch.zeros_like(y) + x[pivot] = y[pivot] / M[pivot, pivot] + + y_next = (y - x[pivot] * M[:, pivot]) + if pivot == 0: + y_next = y_next[1:] + M_next = M[1:, 1:] + x[1:] = solve_triangular(M_next, y_next, pivot=pivot) + else: + y_next = y_next[:-1] + M_next = M[:-1, :-1] + x[:-1] = solve_triangular(M_next, y_next, pivot=pivot) + + return x + \ No newline at end of file diff --git a/src/veriflow/networks.py b/src/veriflow/networks.py index 6fdde30..0670f4b 100644 --- a/src/veriflow/networks.py +++ b/src/veriflow/networks.py @@ -7,16 +7,25 @@ class AdditiveAffineNN(torch.nn.Module): - """Provides a dense NN that computes loc and log_scale parameter for an affine transform that is purely additive, i.e. the log_scale component - always returns the 0 vector. + """Provides a dense NN that computes loc and log_scale parameter for an affine transform that is purely additive, i.e. the log_scale component + always returns the 0 vector. 
""" - def __init__(self, input_dim: int, hidden_dims: List[int], output_dim: int, nonlinearity: Optional[torch.nn.Module] = None): + + def __init__( + self, + input_dim: int, + hidden_dims: List[int], + output_dim: int, + nonlinearity: Optional[torch.nn.Module] = None, + ): super().__init__() if nonlinearity is None: nonlinearity = torch.nn.ReLU() - self.loc_fnc = DenseNN(input_dim, hidden_dims, [output_dim], nonlinearity=nonlinearity) - + self.loc_fnc = DenseNN( + input_dim, hidden_dims, [output_dim], nonlinearity=nonlinearity + ) + def forward(self, x: torch.Tensor) -> torch.Tensor: loc = self.loc_fnc(x) log_scale = torch.zeros_like(loc) @@ -24,13 +33,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class LayerNormChannels(nn.Module): - def __init__(self, c_in, eps=1e-5): """ This module applies layer norm across channels in an image. - Inputs: - c_in - Number of channels of the input - eps - Small constant to stabilize std + Args: + c_in: Number of channels of the input + eps: Small constant to stabilize std """ super().__init__() self.gamma = nn.Parameter(torch.ones(1, c_in, 1, 1)) @@ -46,43 +54,59 @@ def forward(self, x): class GatedConv(nn.Module): - def __init__(self, c_in, c_hidden): """ This module applies a two-layer convolutional ResNet block with input gate - Inputs: - c_in - Number of channels of the input - c_hidden - Number of hidden dimensions we want to model (usually similar to c_in) + Args: + c_in: Number of channels of the input + c_hidden: Number of hidden dimensions we want to model (usually similar to c_in) """ super().__init__() self.net = nn.Sequential( - nn.Conv2d(2*c_in, c_hidden, kernel_size=3, padding=1), - nn.Conv2d(2*c_hidden, 2*c_in, kernel_size=1) + nn.Conv2d(2 * c_in, c_hidden, kernel_size=3, padding=1), + nn.Conv2d(2 * c_hidden, 2 * c_in, kernel_size=1), ) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forwards method + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: network output. + """ out = self.net(x) val, gate = out.chunk(2, dim=1) return x + val * torch.sigmoid(gate) class ConvNet2D(nn.Module): - - def __init__(self, c_in, c_hidden=3, rescale_hidden: int = 2, c_out=-1, num_layers=3, nonlinearity=nn.ReLU(), kernel_size=3, padding=None): + def __init__( + self, + c_in: int, + c_hidden: int = 3, + rescale_hidden: int = 2, + c_out: int = -1, + num_layers: int = 3, + nonlinearity: any = nn.ReLU(), + kernel_size: int = 3, + padding: int = None, + ): """ Module that summarizes the previous blocks to a full convolutional neural network. - Inputs: - c_in - Number of input channels - c_hidden - Number of hidden dimensions to use within the network - rescale_hidden - Factor by which to rescale hight and width the hidden before and after the hidden layers. - c_out - Number of output channels. If -1, 2 times the input channels are used (affine coupling) - num_layers - Number of gated ResNet blocks to apply + Args: + c_in: Number of input channels + c_hidden: Number of hidden dimensions to use within the network + rescale_hidden: Factor by which to rescale hight and width the hidden before and after the hidden layers. + c_out: Number of output channels. 
If -1, 2 times the input channels are used (affine coupling) + num_layers: Number of gated ResNet blocks to apply """ super().__init__() - + if padding is None: padding = kernel_size // 2 - + self.nonlinearity = nonlinearity c_out = c_out if c_out > 0 else c_in layers = [] @@ -91,15 +115,14 @@ def __init__(self, c_in, c_hidden=3, rescale_hidden: int = 2, c_out=-1, num_laye ] if rescale_hidden != 1: layers += [nn.MaxPool2d(rescale_hidden)] - + for layer_index in range(num_layers): layers += [ nn.Conv2d(c_hidden, c_hidden, kernel_size=kernel_size, padding=padding), nonlinearity, LayerNormChannels(c_hidden), ] - - + # compute padding and output padding for rescaling via transposed convolutions if rescale_hidden != 1: diff = rescale_hidden - kernel_size @@ -109,24 +132,29 @@ def __init__(self, c_in, c_hidden=3, rescale_hidden: int = 2, c_out=-1, num_laye else: outpad = diff pad = 0 - + layers += [ nn.ConvTranspose2d( c_hidden, c_hidden, kernel_size=kernel_size, - stride=rescale_hidden, + stride=rescale_hidden, output_padding=outpad, - padding=pad + padding=pad, ), - nonlinearity - ] - - layers += [ - nn.Conv2d(c_hidden, c_out, kernel_size=kernel_size, padding=padding) + nonlinearity, ] + + layers += [nn.Conv2d(c_hidden, c_out, kernel_size=kernel_size, padding=padding)] self.nn = nn.Sequential(*layers) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ Forwards method + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: network output. + """ return self.nn(x) - \ No newline at end of file diff --git a/src/veriflow/transforms.py b/src/veriflow/transforms.py index 7df3e72..dadaeda 100644 --- a/src/veriflow/transforms.py +++ b/src/veriflow/transforms.py @@ -1,4 +1,5 @@ import math +from abc import abstractmethod from typing import List import numpy as np @@ -18,62 +19,119 @@ from torch.nn import init from tqdm import tqdm +from src.veriflow.linalg import solve_triangular -class ScaleTransform(dist.TransformModule): + +class BaseTransform(dist.TransformModule): + """Base class for transforms. Implemented as a thin layer on top of pyro's TransformModule. The baseTransform + provides additional methods for checking and constraints on the parameters of the transform. + """ + + def __init__(self, *args, **kwargs): + super(dist.TransformModule, self).__init__(*args, **kwargs) + + @abstractmethod + def is_feasible(self) -> bool: + """Checks if the layer is feasible.""" + return True + + @abstractmethod + def jitter(self, jitter: float = 1e-6) -> None: + """Adds jitter to the layer. This is useful to ensure that the transformation is invertible.""" + pass + + +class ScaleTransform(BaseTransform): """Implementation of a bijective scale transform. Applies a transform $y = \mathrm{diag}(\mathbf{scale})x$, where scale is a learnable parameter of dimension $\mathbf{dim}$ - + *Note:* The implementation does not enforce the non-zero constraint of the diagonal elements of $\mathbf{U}$ during training. See :func:`add_jitter` and :func:`is_feasible` for a way to ensure that the transformation is invertible. 
""" + def __init__(self, dim: torch.Tensor, *args, **kwargs): + """ Initializes the scale transform.""" super().__init__(*args, **kwargs) self.dim = dim self.scale = torch.nn.Parameter(torch.empty(dim)) self.init_params() - + self.bijective = True self.domain = dist.constraints.real_vector self.codomain = dist.constraints.real_vector - + def init_params(self): """initialization of the parameters""" dim = self.dim bound = 1 / math.sqrt(dim) if dim > 0 else 0 init.uniform_(self.scale, -bound, bound) - + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ Computes the affine transform $\mathbf{scale}x$ + + Args: + x (torch.Tensor): input tensor + + Returns: + torch.Tensor: transformed tensor $\mathbf{scale}x$ + """ return x * self.scale def backward(self, x: torch.Tensor) -> torch.Tensor: - return x / self.scale - + """ Computes the inverse transform $\mathbf{scale}^{-1}x$ + + + Args: + x (torch.Tensor): input tensor + Returns: + torch.Tensor: transformed tensor $\mathbf{scale}^{-1}x$ + """ + return x / self.scale + def _call(self, x: torch.Tensor) -> torch.Tensor: + """ Alias for :func:`forward`""" return self.forward(x) - + def _inverse(self, x: torch.Tensor) -> torch.Tensor: + """ Alias for :func:`backward`""" return self.backward(x) - + def log_abs_det_jacobian(self, x: torch.Tensor, y: torch.Tensor) -> float: + """ Computes the log absolute determinant of the Jacobian of the transform $\mathbf{scale}x$. + + Args: + x (torch.Tensor): input tensor + + Returns: + float: log absolute determinant of the Jacobian of the transform $\mathbf{scale}x$ + """ return self.scale.abs().log().sum() def sign(self) -> int: + """ Computes the sign of the determinant of the Jacobian of the transform $\mathbf{scale}x$.""" return 1 if (self.scale < 0).int().sum() % 2 == 0 else -1 - + def is_feasible(self) -> bool: """Checks if the layer is feasible, i.e. if the diagonal elements of $\mathbf{U}$ are all positive""" return (self.scale != 0).all() - + def add_jitter(self, jitter: float = 1e-6) -> None: - """Adds jitter to the diagonal elements of $\mathbf{U}$. This is useful to ensure that the transformation is invertible.""" + """Adds jitter to the diagonal elements of $\mathbf{U}$.""" perturbation = torch.randn(self.dim, device=self.U_raw.device) * jitter - self.U_raw = self.scale + perturbation + self.U_raw = self.scale + perturbation + -class Permute(pyro.distributions.TransformModule): +class Permute(BaseTransform): """Permutation transform.""" + bijective = True volume_preserving = True - def __init__(self, permutation, *, dim=-1, cache_size=1): + def __init__(self, permutation, *, dim=-1, cache_size=1) -> None: + """ Initializes the permutation transform. 
+ + Args: + permutation (torch.Tensor): permutation vector + """ super().__init__(cache_size=cache_size) if dim >= 0: @@ -84,21 +142,24 @@ def __init__(self, permutation, *, dim=-1, cache_size=1): @constraints.dependent_property(is_discrete=False) def domain(self): + """ Returns the domain of the transform.""" return constraints.independent(constraints.real, -self.dim) @constraints.dependent_property(is_discrete=False) def codomain(self): + """ Returns the codomain of the transform.""" return constraints.independent(constraints.real, -self.dim) @lazy_property def inv_permutation(self): + """ Returns the inverse permutation.""" result = torch.empty_like(self.permutation, dtype=torch.long) result[self.permutation] = torch.arange( self.permutation.size(0), dtype=torch.long, device=self.permutation.device ) return result.to(self.permutation.device) - def _call(self, x): + def _call(self, x: torch.Tensor): """ :param x: the input into the bijection :type x: torch.Tensor @@ -110,7 +171,7 @@ def _call(self, x): return x.index_select(self.dim, self.permutation) - def _inverse(self, y): + def _inverse(self, y: torch.Tensor): """ :param y: the output of the bijection :type y: torch.Tensor @@ -119,11 +180,11 @@ def _inverse(self, y): """ return y.index_select(self.dim, self.inv_permutation) - def log_abs_det_jacobian(self, x, y): + def log_abs_det_jacobian(self, x: torch.Tensor, y: torch.Tensor): """ - Calculates the elementwise determinant of the log Jacobian, i.e. + Calculates the element-wise determinant of the log Jacobian, i.e. log(abs([dy_0/dx_0, ..., dy_{N-1}/dx_{N-1}])). Note that this type of - transform is not autoregressive, so the log Jacobian is not the sum of the + transform is not auto-regressive, so the log Jacobian is not the sum of the previous expression. However, it turns out it's always 0 (since the determinant is -1 or +1), and so returning a vector of zeros works. """ @@ -132,135 +193,185 @@ def log_abs_det_jacobian(self, x, y): x.size()[: -self.event_dim], dtype=x.dtype, layout=x.layout, device=x.device ) - - def with_cache(self, cache_size=1): + def with_cache(self, cache_size: int = 1): + """ Returns a new :class:`Permute` instance with a given cache size.""" if self._cache_size == cache_size: return self return Permute(self.permutation, cache_size=cache_size) - - -class LUTransform(dist.TransformModule): - """Implementation of a linear bijection transform. Applies a transform $y = \mathbf{L}\mathbf{U}x$, where $\mathbf{L}$ is a + + +class LUTransform(BaseTransform): + """Implementation of a linear bijection transform. Applies a transform $y = (\mathbf{L}\mathbf{U})^{-1}x$, where $\mathbf{L}$ is a lower triangular matrix with unit diagonal and $\mathbf{U}$ is an upper triangular matrix. Bijectivity is guaranteed by requiring that the diagonal elements of $\mathbf{U}$ are positive and the diagonal elements of $\mathbf{L}$ are all $1$. - + *Note:* The implementation does not enforce the non-zero constraint of the diagonal elements of $\mathbf{U}$ during training. See :func:`add_jitter` and :func:`is_feasible` for a way to ensure that the transformation is invertible. """ + bijective = True volume_preserving = False domain = dist.constraints.real_vector codomain = dist.constraints.real_vector - + def __init__(self, dim: int, *args, **kwargs): + """ Initializes the LU transform. 
+ + Args: + dim (int): dimension of the input and output + """ super().__init__(*args, **kwargs) self.L_raw = torch.nn.Parameter(torch.empty(dim, dim)) self.U_raw = torch.nn.Parameter(torch.empty(dim, dim)) self.bias = torch.nn.Parameter(torch.empty(dim)) self.dim = dim - + self.init_params() - + self.input_shape = dim - - self.L_mask = torch.tril(torch.ones(dim, dim), diagonal=1) + + self.L_mask = torch.tril(torch.ones(dim, dim), diagonal=-1) self.U_mask = torch.triu(torch.ones(dim, dim), diagonal=0) - + self.L_raw.register_hook(lambda grad: grad * self.L_mask) self.U_raw.register_hook(lambda grad: grad * self.U_mask) - + def init_params(self): - """Parameter initialization + """Parameter initialization Adopted from pytorch's Linear layer parameter initialization. """ - - init.kaiming_uniform_(self.L_raw, nonlinearity='relu') + init.kaiming_uniform_(self.L_raw, nonlinearity="relu") with torch.no_grad(): - self.L_raw.copy_(self.L_raw.tril(diagonal=1).fill_diagonal_(1)) - - init.kaiming_uniform_(self.U_raw, nonlinearity='relu') + self.L_raw.copy_(self.L_raw.tril(diagonal=-1).fill_diagonal_(1)) + + init.kaiming_uniform_(self.U_raw, nonlinearity="relu") with torch.no_grad(): self.U_raw.copy_(self.U_raw.triu()) - + if self.bias is not None: fan_in = self.dim bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 init.uniform_(self.bias, -bound, bound) - + def forward(self, x: torch.Tensor) -> torch.Tensor: - """Computes the affine transform $(LU)x + \mathrm{bias}$ - + """Computes the affine transform $y = (LU)^{-1}x + \mathrm{bias}$. + The value $y$ is computed by solving the linear equation system + \begin{align*} + Ly_0 &= x + LU\textrm{bias} \\ + Uy &= y_0 + \end{align*} + :param x: input tensor :type x: torch.Tensor :return: transformed tensor $(LU)x + \mathrm{bias}$ """ - return F.linear(x, self.weight, self.bias) - + x0 = x + torch.functional.F.linear(self.bias, self.inv_weight) + y0 = solve_triangular(self.L, x0) + return solve_triangular(self.U, y0) def backward(self, y: torch.Tensor) -> torch.Tensor: - """Computes the inverse transform $(LU)^{-1}(y - \mathrm{bias})$ - + """Computes the inverse transform $(LU)(y - \mathrm{bias})$ + :param y: input tensor :type y: torch.Tensor :return: transformed tensor $(LU)^{-1}(y - \mathrm{bias})$""" return torch.functional.F.linear(y - self.bias, self.inv_weight) - + @property - def L(self): + def L(self) -> torch.Tensor: """The lower triangular matrix $\mathbf{L}$ of the layers LU decomposition""" - return self.L_raw.tril().fill_diagonal_(1) - + return self.L_raw.tril(-1) + torch.eye(self.dim) + @property - def U(self): + def U(self) -> torch.Tensor: """The upper triangular matrix $\mathbf{U}$ of the layers LU decomposition""" return self.U_raw.triu() - + @property - def inv_weight(self): + def inv_weight(self) -> torch.Tensor: """Inverse weight matrix of the affine transform""" return LA.matmul(self.L, self.U) - - @property - def weight(self): - """Weight matrix of the affine transform""" - return LA.inv(LA.matmul(self.L, self.U)) - - + def _call(self, x: torch.Tensor) -> torch.Tensor: + """ Alias for :func:`forward`""" return self.forward(x) - + def _inverse(self, y: torch.Tensor) -> torch.Tensor: + """ Alias for :func:`backward`""" return self.backward(y) - + def log_abs_det_jacobian(self, x: torch.Tensor, y: torch.Tensor) -> float: - return LA.slogdet(self.weight)[1] - + """ Computes the log absolute determinant of the Jacobian of the transform $(LU)x + \mathrm{bias}$. 
+
+        Args:
+            x (torch.Tensor): input tensor
+            y (torch.Tensor): transformed tensor
+
+        Returns:
+            float: log absolute determinant of the Jacobian of the transform $(LU)x + \mathrm{bias}$
+        """
+        # det(L) = 1, so the log-determinant reduces to the diagonal of U.
+        return self.U.diag().abs().log().sum()
+
     def sign(self) -> int:
-        return LA.slogdet(self.weight)[0]
-
+        """ Computes the sign of the determinant of the Jacobian of the transform $(LU)x + \mathrm{bias}$.
+
+        Returns:
+            int: sign of the determinant of the Jacobian of the transform $(LU)x + \mathrm{bias}$
+        """
+        return self.L.diag().prod().sign() * self.U.diag().prod().sign()
+
     def to(self, device) -> None:
+        """ Moves the layer to a given device.
+
+        Args:
+            device (torch.device): target device
+        """
         self.L_raw = self.L_raw.to(device)
         self.U_raw = self.U_raw.to(device)
         self.bias = self.bias.to(device)
         return super().to(device)
-
+
     def is_feasible(self) -> bool:
         """Checks if the layer is feasible, i.e. if the diagonal elements of $\mathbf{U}$ are all positive"""
         return (self.U_raw.diag() != 0).all()
-
+
     def add_jitter(self, jitter: float = 1e-6) -> None:
-        """Adds jitter to the diagonal elements of $\mathbf{U}$. This is useful to ensure that the transformation is invertible."""
+        """Adds jitter to the diagonal elements of $\mathbf{U}$. This is useful to ensure that the transformation
+        is invertible.
+
+        Args:
+            jitter (float, optional): jitter strength. Defaults to 1e-6.
+        """
         perturbation = torch.randn(self.dim, device=self.U_raw.device) * jitter
         with torch.no_grad():
-            self.U_raw.copy_(self.U_raw + perturbation * torch.eye(self.dim, device=self.U_raw.device))
+            self.U_raw.copy_(
+                self.U_raw
+                + perturbation * torch.eye(self.dim, device=self.U_raw.device)
+            )
+
 
-class MaskedCoupling(dist.TransformModule):
-    """Implementation of a masked coupling layer. The layer is defined by a mask that specifies which dimensions are passed through unchanged and which are transformed.
-    The layer is defined by a bijective function $y = \mathrm{mask} \odot x + (1 - \mathrm{mask}) \odot (x + \mathrm{transform}(x))$, where $\mathrm{mask}$ is a binary mask,
+
+class MaskedCoupling(BaseTransform):
+    """Implementation of a masked coupling layer. The layer is defined by a mask that specifies which dimensions are passed through unchanged and which are transformed.
+    The layer is defined by a bijective function $y = \mathrm{mask} \odot x + (1 - \mathrm{mask}) \odot (x + \mathrm{transform}(x))$, where $\mathrm{mask}$ is a binary mask,
     $\mathrm{transform}$ is a bijective function, and $\odot$ denotes element-wise multiplication.
     """
-    def __init__(self, mask: torch.Tensor, conditioner: torch.nn.Module, *args, **kwargs):
+
+    def __init__(
+        self, mask: torch.Tensor, conditioner: torch.nn.Module, *args, **kwargs
+    ) -> None:
+        """ Initializes the masked coupling layer.
+ + Args: + mask (torch.Tensor): binary mask + conditioner (torch.nn.Module): bijective function $\mathrm{transform}$ + """ super().__init__(*args, **kwargs) self.mask = mask self.conditioner = conditioner @@ -268,58 +379,127 @@ def __init__(self, mask: torch.Tensor, conditioner: torch.nn.Module, *args, **kw self.bijective = True self.domain = dist.constraints.real_vector self.codomain = dist.constraints.real_vector - + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ Computes the affine transform + $\mathrm{mask} \odot x + (1 - \mathrm{mask}) \odot (x + \mathrm{transform}(x))$ + + Args: + x (torch.Tensor): input tensor + """ x_masked = x * self.mask x_transformed = x + (1 - self.mask) * self.conditioner(x_masked) return x_transformed def backward(self, y: torch.Tensor) -> torch.Tensor: + """ Computes the inverse transform + + Args: + y (torch.Tensor): input tensor + + Returns: + torch.Tensor: transformed tensor + """ y_masked = y * self.mask y_transformed = y - (1 - self.mask) * self.conditioner(y_masked) - return y_transformed - + return y_transformed + def _call(self, x: torch.Tensor) -> torch.Tensor: + """ Alias for :func:`forward`""" return self.forward(x) - + def _inverse(self, y: torch.Tensor) -> torch.Tensor: + """ Alias for :func:`backward`""" return self.backward(y) - + def log_abs_det_jacobian(self, x: torch.Tensor, y: torch.Tensor) -> float: + """ Computes the log absolute determinant of the Jacobian of the transform + + Args: + x (torch.Tensor): input tensor + y (torch.Tensor): output tensor + + Returns: + float: log absolute determinant of the Jacobian of the transform + """ x_masked = x * self.mask - return 0. - + return 0.0 + def sign(self) -> int: - return 1. - + """ Computes the sign of the determinant of the Jacobian of the transform + + Args: + x (torch.Tensor): input tensor + + Returns: + int: sign of the determinant of the Jacobian of the transform + """ + return 1.0 + def to(self, device): + """ Moves the layer to a given device + + Args: + device (torch.device): target device + """ self.mask = self.mask.to(device) return super().to(device) -class LeakyReLUTransform(dist.TransformModule): +class LeakyReLUTransform(BaseTransform): bijective = True domain = dist.constraints.real codomain = dist.constraints.real sign = 1 - - def __init__(self, alpha: float = .01, *args, **kwargs): + + def __init__(self, alpha: float = 0.01, *args, **kwargs) -> None: + """ Initializes the LeakyReLU transform. + + Args: + alpha (float, optional): slope of the negative part of the function. Defaults to 0.01. 
+        """
         if alpha == 0:
             raise ValueError("alpha must be positive")
         super().__init__(*args, **kwargs)
         self.alpha = alpha
-
-
+
     def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """ Computes the LeakyReLU transform
+
+        Args:
+            x (torch.Tensor): input tensor
+
+        Returns:
+            torch.Tensor: transformed tensor
+        """
         return F.leaky_relu(x, negative_slope=self.alpha)
 
     def backward(self, y: torch.Tensor) -> torch.Tensor:
-        return F.leaky_relu(y, negative_slope=1/self.alpha)
-
+        """ Computes the inverse transform
+
+        Args:
+            y (torch.Tensor): input tensor
+
+        Returns:
+            torch.Tensor: transformed tensor
+        """
+        return F.leaky_relu(y, negative_slope=1 / self.alpha)
+
     def _call(self, x: torch.Tensor) -> torch.Tensor:
+        """ Alias for :func:`forward`"""
         return self.forward(x)
-
+
     def _inverse(self, y: torch.Tensor) -> torch.Tensor:
+        """ Alias for :func:`backward`"""
         return self.backward(y)
-
-    def log_abs_det_jacobian(self, x: torch.Tensor, y: torch.Tensor) -> float:
-        return ((x <= 0).float() * math.log(self.alpha)).sum()
\ No newline at end of file
+
+    def log_abs_det_jacobian(self, x: torch.Tensor, y: torch.Tensor) -> float:
+        """ Computes the log absolute determinant of the Jacobian of the transform
+
+        Args:
+            x (torch.Tensor): input tensor
+            y (torch.Tensor): output tensor
+
+        Returns:
+            float: log absolute determinant of the Jacobian of the transform
+        """
+        # y / x is alpha where x < 0 and 1 elsewhere; use torch.log, which works element-wise on tensors.
+        return torch.log(y / x).sum()
diff --git a/tests/flows_test.py b/tests/flows_test.py
index 4a26c3a..4c14312 100644
--- a/tests/flows_test.py
+++ b/tests/flows_test.py
@@ -8,21 +8,12 @@ def test_mnist():
     report_dir = "./reports"
     storage_path = None
     sepline = "\n" + ("-" * 80) + "\n" + ("-" * 80) + "\n"
-    print(
-        f"{sepline}Parsing config file:{sepline}"
-    )
+    print(f"{sepline}Parsing config file:{sepline}")
     config = os.path.abspath("./tests/mnist.yaml")
     experiment = read_config(config)
-    print(
-        f"{sepline}Done.{sepline}"
-    )
-    print(
-        f"{sepline}Conducting experiment{sepline}"
-    )
+    print(f"{sepline}Done.{sepline}")
+    print(f"{sepline}Conducting experiment{sepline}")
     # Conduct experiment
     experiment.conduct(report_dir, storage_path=storage_path)
-    print(
-        f"{sepline}Done.{sepline}"
-    )
+    print(f"{sepline}Done.{sepline}")
     assert True
-
diff --git a/tests/onnx_test.py b/tests/onnx_test.py
index 9fc0a75..0589b20 100644
--- a/tests/onnx_test.py
+++ b/tests/onnx_test.py
@@ -7,12 +7,7 @@ def test_onnx():
     loc = torch.zeros(2)
     scale = torch.ones(2)
-    model = NiceFlow(
-        Normal(loc, scale),
-        2,
-        [10, 10],
-        split_dim=1,
-        permutation= "half"
-    )
+    model = NiceFlow(Normal(loc, scale), 2, [10, 10], split_dim=1, permutation="LU")
     model.to_onnx("log_prob.onnx")
-    model.to_onnx("sample.onnx", export_mode="sample")
\ No newline at end of file
+    model.to_onnx("sample.onnx", export_mode="sample")
+
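Not part of the patch: a minimal, self-contained sanity check (plain torch only, no imports from this repository) for the two determinant computations touched above. It verifies that log|det(LU)| reduces to the sum of log|U_ii| when L has a unit diagonal, and that the leaky-ReLU log_abs_det_jacobian expression used in the patched method matches the closed form (#negative entries) * log(alpha).

import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# LU part: with a unit-diagonal L, |det(L @ U)| = prod(|U_ii|).
dim = 4
L = torch.eye(dim) + torch.randn(dim, dim).tril(-1)                    # unit lower-triangular
U = torch.randn(dim, dim).triu(1) + torch.diag(torch.rand(dim) + 0.5)  # nonzero diagonal
assert torch.allclose(
    torch.linalg.slogdet(L @ U).logabsdet,
    U.diag().abs().log().sum(),
    atol=1e-5,
)

# LeakyReLU part: log|det J| = (#negative entries) * log(alpha).
alpha = 0.01
x = torch.randn(8)                        # exact zeros are the only degenerate inputs
y = F.leaky_relu(x, negative_slope=alpha)
expected = (x < 0).float().sum() * math.log(alpha)
computed = torch.log(y / x).sum()         # expression used by the patched method
assert torch.allclose(computed, expected, atol=1e-6)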