diff --git a/args.py b/args.py
index c6eaf06..c888e44 100644
--- a/args.py
+++ b/args.py
@@ -161,6 +161,7 @@ def parse_args():
     parser.add_argument("--log-weights", action="store_true", help="use log weights")
     parser.add_argument("--num-devices", type=int, default=1, help="number of devices")
+    parser.add_argument("--structure", default="top-down", choices=["bottom-up", "top-down"], help="structure of the network")
 
     # Parse args
     args = parser.parse_args()
diff --git a/conf/config.yaml b/conf/config.yaml
index 3b20b81..4a90d2a 100644
--- a/conf/config.yaml
+++ b/conf/config.yaml
@@ -16,6 +16,21 @@ hydra:
     file:
       filename: ${hydra.runtime.output_dir}/${hydra.job.name}.log # Fixed in hydra-colorlog version 1.2.1
 
+einet:
+  S: 10
+  I: 10
+  D: 3
+  R: 1
+  layer_type: "linsum"
+  structure: "top-down"
+
+convpc:
+  channels: [8, 16, 16, 16]
+  order: "sum-prod"
+  structure: "top-down"
+  kernel_size: 2
+
+
 # Default set of configurations.
 data_dir: "${oc.env:DATA_DIR}/"
 
@@ -36,15 +51,9 @@ log_interval: 10
 classification: False
 device: "cuda"
 debug: False
-S: 10
-I: 10
-D: 3
-R: 1
 gpu: 0
 epochs: 10
 load_and_eval: False
-layer_type: "linsum"
-dist: "normal"
 precision: "bf16-mixed"
 group_tag: ???
 tag: ???
@@ -54,6 +63,8 @@ profiler: ???
 dataset: ???
 num_classes: 10
 init_leaf_data: False
-einet_mixture: False
+mixture: False
 torch_compile: False
 multivariate_cardinality: 2
+dist: "normal"
+model: "einet" # Can be one of "einet" or "convpc"
diff --git a/main.py b/main.py
index 9766e8f..7e2f357 100644
--- a/main.py
+++ b/main.py
@@ -23,7 +23,7 @@
 from simple_einet.layers.distributions.binomial import Binomial
 from simple_einet.layers.distributions.normal import RatNormal, Normal
 from simple_einet.einet import Einet, EinetConfig
-from simple_einet.einet_mixture import EinetMixture
+from simple_einet.mixture import Mixture
 
 import lightning as L
 
@@ -42,7 +42,7 @@ def log_likelihoods(outputs, targets=None):
     return lls
 
 
-def train(args, model: Union[Einet, EinetMixture], device, train_loader, optimizer, epoch):
+def train(args, model: Union[Einet, Mixture], device, train_loader, optimizer, epoch):
     model.train()
 
     pbar = tqdm.tqdm(train_loader)
@@ -62,15 +62,8 @@ def train(args, model: Union[Einet, EinetMixture], device, train_loader, optimiz
 
         optimizer.zero_grad()
 
-        if args.dist == Dist.PIECEWISE_LINEAR:
-            cache_leaf = True
-            cache_index = batch_idx
-        else:
-            cache_leaf = False
-            cache_index = None
-
         # Generate outputs
-        outputs = model(data, cache_leaf=cache_leaf, cache_index=cache_index)
+        outputs = model(data)
 
         if args.classification:
             model.posterior(data)
@@ -191,6 +184,7 @@ def test(model, device, loader, tag):
         leaf_kwargs=leaf_kwargs,
         layer_type=args.layer,
         dropout=0.0,
+        structure=args.structure,
     )
 
     fabric = L.Fabric(accelerator=args.device, devices=args.num_devices, precision="16-mixed")
@@ -227,67 +221,6 @@ def test(model, device, loader, tag):
 
     train_loader, val_loader, test_loader = fabric.setup_dataloaders(train_loader, val_loader, test_loader)
 
-    if args.dist == Dist.PIECEWISE_LINEAR:
-        # Initialize the piecewise linear function
-        # Collect data
-        batches = []
-        count = 0
-        for data, _ in train_loader:
-            batches.append(data)
-            count += data.shape[0]
-            if count > 10000:
-                break
-        data_init_pwl = torch.cat(batches, dim=0)
-
-        # Prepare data
-        data_init_pwl = preprocess(
-            data_init_pwl,
-            n_bits,
-            n_bins,
-            dequantize=True,
-            has_gauss_dist=has_gauss_dist,
-        )
-
-        data_init_pwl = data_init_pwl.view(data_init_pwl.shape[0], data_init_pwl.shape[1], num_features)
-
-        domains = [Domain.discrete_range(min=0, max=255)] * num_features
-        with torch.no_grad():
-            model.leaf.base_leaf.initialize(data_init_pwl, domains=domains)
-
-        # Use mixture weights obtained in leaf initialization and set these to the first linsum layer weights
-        model.layers[0].logits.data[:] = model.leaf.base_leaf.mixture_weights.permute(1, 0).view(1, config.num_leaves, 1, config.num_repetitions).log()
-
-        # Visualize a couple of pixel distributions and their piecewise linear functions
-        # Select 20 random pixels
-        pixels = list(range(64))[::3]
-        # pixels = [36, 766, 720, 588, 759, 403, 664, 428, 25, 686, 673, 638, 44, 147, 610, 470, 540, 179, 698, 420]
-
-        d = model.leaf.base_leaf._get_base_distribution()
-        log_probs = d.log_prob(data_init_pwl)
-
-        xs = d.xs
-        ys = d.ys
-
-        for pixel in pixels:
-            # Get data subset
-            # xs_pixel = xs[pixel][0][0][0].squeeze()
-            # ys_pixel = ys[pixel][0][0][0].squeeze()
-            xs_pixel = xs[0][0][pixel][0].squeeze().cpu()
-            ys_pixel = ys[0][0][pixel][0].squeeze().cpu()
-
-            # Plot pixel distribution with pixel value as x and logprob as y values
-            import matplotlib.pyplot as plt
-
-            plt.figure(figsize=(12, 6))
-            plt.plot(xs_pixel, ys_pixel, label="PWL")
-
-            # Plot histogram of pixel values
-            plt.hist(data_init_pwl[:, :, pixel].flatten().cpu().numpy(), bins=100, density=True, alpha=0.5, label="Data")
-            plt.xlabel("Pixel Value")
-            plt.ylabel("Density")
-            plt.legend()
-            plt.savefig(os.path.join(result_dir, f"pwl-{pixel}.png"), dpi=300)
-            plt.close()
 
     if args.train:
         for epoch in range(1, args.epochs + 1):
diff --git a/main_pl.py b/main_pl.py
index 845eb7b..abc709c 100644
--- a/main_pl.py
+++ b/main_pl.py
@@ -58,8 +58,6 @@ def main(cfg: DictConfig):
     logger.info("\n" + OmegaConf.to_yaml(cfg, resolve=True))
     logger.info("Run dir: " + run_dir)
 
-    seed_everything(cfg.seed, workers=True)
-
     if not cfg.wandb:
         os.environ["WANDB_MODE"] = "offline"
 
@@ -87,6 +85,7 @@ def main(cfg: DictConfig):
         num_workers=min(cfg.num_workers, os.cpu_count()),
         loop=False,
         normalize=normalize,
+        seed=cfg.seed,
     )
 
     # Create callbacks
@@ -120,7 +119,7 @@ def main(cfg: DictConfig):
         # model = torch.compile(model)
         raise NotImplementedError("Torch compilation not yet supported with einsum.")
 
-    if cfg.einet_mixture:
+    if cfg.mixture:
         # If we chose a mixture of einets, we need to initialize the mixture weights
         logger.info("Initializing Einet mixture weights")
         model.spn.initialize(dataloader=train_loader, device=devices[0])
diff --git a/models_pl.py b/models_pl.py
index 806567c..8b58c86 100644
--- a/models_pl.py
+++ b/models_pl.py
@@ -10,16 +10,17 @@
 from rtpt import RTPT
 from torch import nn
 
+from simple_einet.conv_pc import ConvPcConfig, ConvPc
 from simple_einet.data import get_data_shape
 from simple_einet.dist import Dist, get_distribution
 from simple_einet.einet import EinetConfig, Einet
-from simple_einet.einet_mixture import EinetMixture
+from simple_einet.mixture import Mixture
 
 # Translate the dataloader index to the dataset name
 DATALOADER_ID_TO_SET_NAME = {0: "train", 1: "val", 2: "test"}
 
 
-def make_einet(cfg, num_classes: int = 1) -> EinetMixture | Einet:
+def make_einet(cfg, num_classes: int = 1) -> Mixture | Einet:
     """
     Make an Einet model based off the given arguments.
@@ -38,22 +39,55 @@ def make_einet(cfg, num_classes: int = 1) -> EinetMixture | Einet:
     config = EinetConfig(
         num_features=image_shape.num_pixels,
         num_channels=image_shape.channels,
-        depth=cfg.D,
-        num_sums=cfg.S,
-        num_leaves=cfg.I,
-        num_repetitions=cfg.R,
+        depth=cfg.einet.D,
+        num_sums=cfg.einet.S,
+        num_leaves=cfg.einet.I,
+        num_repetitions=cfg.einet.R,
         num_classes=num_classes,
         leaf_kwargs=leaf_kwargs,
         leaf_type=leaf_type,
         dropout=cfg.dropout,
-        layer_type=cfg.layer_type,
+        layer_type=cfg.einet.layer_type,
+        structure=cfg.einet.structure,
     )
-    if cfg.einet_mixture:
-        return EinetMixture(n_components=num_classes, einet_config=config)
+    if cfg.mixture:
+        return Mixture(n_components=num_classes, config=config)
     else:
         return Einet(config)
 
 
+def make_convpc(cfg, num_classes: int = 1) -> Mixture | ConvPc:
+    """
+    Make ConvPc model based off the given arguments.
+
+    Args:
+        cfg: Arguments parsed from argparse.
+        num_classes: Number of classes to model.
+
+    Returns:
+        ConvPc model.
+    """
+
+    image_shape = get_data_shape(cfg.dataset)
+    # leaf_kwargs, leaf_type = {"total_count": 255}, Binomial
+    leaf_kwargs, leaf_type = get_distribution(dist=cfg.dist, cfg=cfg)
+
+    config = ConvPcConfig(
+        channels=cfg.convpc.channels,
+        num_channels=image_shape.channels,
+        num_classes=num_classes,
+        leaf_kwargs=leaf_kwargs,
+        leaf_type=leaf_type,
+        structure=cfg.convpc.structure,
+        order=cfg.convpc.order,
+        kernel_size=cfg.convpc.kernel_size,
+    )
+    if cfg.mixture:
+        return Mixture(n_components=num_classes, config=config, data_shape=image_shape)
+    else:
+        return ConvPc(config=config, data_shape=image_shape)
+
+
 class LitModel(pl.LightningModule, ABC):
     """
     LightningModule for training a model using PyTorch Lightning.
@@ -123,7 +157,13 @@ class SpnGenerative(LitModel):
     def __init__(self, cfg: DictConfig, steps_per_epoch: int):
         super().__init__(cfg=cfg, name="gen", steps_per_epoch=steps_per_epoch)
-        self.spn = make_einet(cfg)
+        if cfg.model == "einet":
+            self.spn = make_einet(cfg, num_classes=cfg.num_classes)
+        elif cfg.model == "convpc":
+            self.spn = make_convpc(cfg, num_classes=cfg.num_classes)
+        else:
+            raise ValueError(f"Unknown model {cfg.model}")
+
     def training_step(self, train_batch, batch_idx):
         data, labels = train_batch
@@ -209,7 +249,12 @@ def __init__(self, cfg: DictConfig, steps_per_epoch: int):
         super().__init__(cfg, name="disc", steps_per_epoch=steps_per_epoch)
 
         # Construct SPN
-        self.spn = make_einet(cfg, num_classes=10)
+        if cfg.model == "einet":
+            self.spn = make_einet(cfg, num_classes=cfg.num_classes)
+        elif cfg.model == "convpc":
+            self.spn = make_convpc(cfg, num_classes=cfg.num_classes)
+        else:
+            raise ValueError(f"Unknown model {cfg.model}")
 
         # Define loss function
         self.criterion = nn.NLLLoss()
diff --git a/simple_einet/data.py b/simple_einet/data.py
index 5966a92..eac31dd 100644
--- a/simple_einet/data.py
+++ b/simple_einet/data.py
@@ -96,7 +96,7 @@ def get_data_shape(dataset_name: str) -> Shape:
         "cifar": (3, 32, 32),
         "svhn": (3, 32, 32),
         "svhn-extra": (3, 32, 32),
-        "celeba": (3, 64, 64),
+        "celeba": (3, 128, 128),
         "celeba-small": (3, 64, 64),
         "celeba-tiny": (3, 32, 32),
         "lsun": (3, 32, 32),
diff --git a/simple_einet/einet.py b/simple_einet/einet.py
index 399f213..b26470e 100644
--- a/simple_einet/einet.py
+++ b/simple_einet/einet.py
@@ -37,7 +37,7 @@ class EinetConfig:
     leaf_type: Type = None # Type of the leaf base class (Normal, Bernoulli, etc)
     leaf_kwargs: Dict[str, Any] = field(default_factory=dict) # Parameters for the leaf base class
     layer_type: str = "linsum" # Indicates the intermediate layer type: linsum or einsum
-    structure: str = "original" # Structure of the Einet: original or bottom_up
+    structure: str = "top-down" # Structure of the Einet: top-down or bottom-up
 
     def assert_valid(self):
         """Check whether the configuration is valid."""
@@ -58,9 +58,9 @@ def assert_valid(self):
             "einsum",
         ], f"Invalid layer type {self.layer_type}. Must be 'linsum' or 'einsum'."
         assert self.structure in [
-            "original",
-            "bottom_up",
-        ], f"Invalid structure type {self.structure}. Must be 'original' or 'bottom_up'."
+            "top-down",
+            "bottom-up",
+        ], f"Invalid structure type {self.structure}. Must be 'top-down' or 'bottom-up'."
 
         assert isinstance(self.leaf_type, type) and issubclass(
             self.leaf_type, AbstractLeaf
@@ -72,7 +72,7 @@ def assert_valid(self):
         else:
             cardinality = 1
 
-        if self.structure == "bottom_up":
+        if self.structure == "bottom-up":
             assert self.layer_type == "linsum", "Bottom-up structure only supports LinsumLayer due to handling of padding (not implemented for einsumlayer yet)."
 
         # Get minimum number of features present at the lowest layer (num_features is the actual input dimension,
@@ -104,12 +104,12 @@ def __init__(self, config: EinetConfig):
         self.config = config
 
         # Construct the architecture
-        if self.config.structure == "original":
-            self._build_structure_original()
-        elif self.config.structure == "bottom_up":
+        if self.config.structure == "top-down":
+            self._build_structure_top_down()
+        elif self.config.structure == "bottom-up":
             self._build_structure_bottom_up()
         else:
-            raise ValueError(f"Invalid structure type {self.config.structure}. Must be 'original' or 'bottom_up'.")
+            raise ValueError(f"Invalid structure type {self.config.structure}. Must be 'top-down' or 'bottom-up'.")
 
         # Leaf cache
         self._leaf_cache = {}
@@ -235,9 +235,9 @@ def posterior(self, x) -> torch.Tensor:
 
         return posterior(ll_x_g_y, self.config.num_classes)
 
-    def _build_structure_original(self):
+    def _build_structure_top_down(self):
         """Construct the internal architecture of the Einet."""
-        # Build the SPN bottom up:
+        # Build the SPN top down:
         # Definition from RAT Paper
         # Leaf Region: Create I leaf nodes
         # Root Region: Create C sum nodes
@@ -473,7 +473,7 @@ def _build_structure_bottom_up(self):
         )
 
     def _build_input_distribution_bottom_up(self) -> AbstractLeaf:
-        """Construct the input distribution layer. This constructs a direct leaf and not a FactorizedLeaf since the bottom_up approach does not factorize."""
+        """Construct the input distribution layer. This constructs a direct leaf and not a FactorizedLeaf since the bottom-up approach does not factorize."""
         # Cardinality is the size of the region in the last partitions
         return self.config.leaf_type(
             num_features=self.config.num_features,
diff --git a/tests/test_einet.py b/tests/test_einet.py
index 5df9174..469841e 100644
--- a/tests/test_einet.py
+++ b/tests/test_einet.py
@@ -35,7 +35,7 @@ def setUp(self) -> None:
         self.leaf_type = Binomial
         self.leaf_kwargs = {"total_count": 255}
 
-    @parameterized.expand(product([False, True], [1, 3], [1, 4], ["original", "bottom_up"], ["linsum"]))
+    @parameterized.expand(product([False, True], [1, 3], [1, 4], ["top-down", "bottom-up"], ["linsum"]))
     def test_sampling_shapes(self, differentiable: bool, num_classes: int, num_repetitions: int, structure: str, layer_type: str):
         model = self.make_einet(num_classes=num_classes, num_repetitions=num_repetitions, structure=structure, layer_type=layer_type)
         N = 2
@@ -49,7 +49,7 @@ def test_sampling_shapes(self, differentiable: bool, num_classes: int, num_repet
         samples = model.sample(evidence=evidence, is_differentiable=differentiable)
         self.assertEqual(samples.shape, (N, self.num_channels, self.num_features))
 
-    @parameterized.expand(product([False, True], [1, 3], [1, 4], ["original", "bottom_up"], ["linsum"]))
+    @parameterized.expand(product([False, True], [1, 3], [1, 4], ["top-down", "bottom-up"], ["linsum"]))
     def test_mpe_shapes(self, differentiable: bool, num_classes: int, num_repetitions: int, structure: str, layer_type: str):
         model = self.make_einet(num_classes=num_classes, num_repetitions=num_repetitions, structure=structure, layer_type=layer_type)
         N = 2
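Usage note (not part of the diff): after this change the model family and its structure are selected through the new config nodes, e.g. via a Hydra-style override such as "python main_pl.py dataset=cifar model=convpc convpc.structure=bottom-up" (the override syntax and dataset value are assumptions for illustration). Below is a minimal sketch of the renamed Einet API: S/I/D/R mirror the einet defaults in conf/config.yaml, while the input size, the Normal leaf, and the empty leaf_kwargs are illustrative assumptions rather than repository defaults.

# Sketch only (assumed values): build an Einet with the renamed structure switch.
import torch

from simple_einet.einet import Einet, EinetConfig
from simple_einet.layers.distributions.normal import Normal

config = EinetConfig(
    num_features=64,        # e.g. a flattened 8x8 input (illustrative)
    num_channels=1,
    depth=3,                # einet.D
    num_sums=10,            # einet.S
    num_leaves=10,          # einet.I
    num_repetitions=1,      # einet.R
    num_classes=1,
    leaf_type=Normal,
    leaf_kwargs={},
    layer_type="linsum",    # "bottom-up" structures require "linsum" (see assert_valid)
    dropout=0.0,
    structure="top-down",   # formerly "original"; the alternative is "bottom-up"
)

model = Einet(config)
x = torch.randn(2, 1, 64)   # (batch, channels, num_features), as in tests/test_einet.py
log_likelihoods = model(x)  # forward pass, as used in main.py's train loop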