diff --git a/main_pwl.py b/main_pwl.py new file mode 100644 index 0000000..5a3d4e7 --- /dev/null +++ b/main_pwl.py @@ -0,0 +1,243 @@ +#!/usr/bin/env python3 + + +import numpy as np +import tqdm +from simple_einet.dist import DataType, Domain +from simple_einet.layers.distributions.piecewise_linear import PiecewiseLinear +from torch.distributions import Binomial +import matplotlib.pyplot as plt + +import torch + +from simple_einet.einet import Einet, EinetConfig + +BINS = 100 + +def make_dataset(num_features_continuous, num_features_discrete, num_clusters, num_samples): + # Collect data and data domains + data = [] + domains = [] + + # Construct continuous features + for i in range(num_features_continuous): + domains.append(Domain.continuous_inf_support()) + feat_i = [] + + # Create a multimodal feature + for j in range(num_clusters): + feat_i.append(torch.randn(num_samples) + j * 3 * torch.rand(1) + 3 * j) + + data.append(torch.cat(feat_i)) + + # Construct discrete features + for i in range(num_features_discrete): + domains.append(Domain.discrete_range(0, BINS)) + feat_i = [] + + # Create a multimodal feature + for j in range(num_clusters): + feat_i.append(Binomial(total_count=BINS, probs=torch.rand(1)).sample((num_samples,)).view(-1)) + data.append(torch.cat(feat_i)) + + data = torch.stack(data, dim=1) + data = data.view(data.shape[0], 1, num_features_continuous + num_features_discrete) + data = data[torch.randperm(data.shape[0])] + return data, domains + + +if __name__ == "__main__": + torch.manual_seed(0) + + ################### + # Hyperparameters # + ################### + + epochs = 3 + batch_size = 128 + depth = 2 + num_sums = 20 + num_leaves = 10 + num_repetitions = 10 + lr = 0.01 + + num_features = 4 + + ############### + # Einet Setup # + ############### + + config = EinetConfig( + num_features=num_features, + num_channels=1, + depth=depth, + num_sums=num_sums, + num_leaves=num_leaves, + num_repetitions=num_repetitions, + num_classes=1, + leaf_type=PiecewiseLinear, + # leaf_kwargs={"alpha": 0.05}, + layer_type="linsum", + dropout=0.0, + ) + + model = Einet(config) + print(model) + print("Number of parameters:", sum(p.numel() for p in model.parameters() if p.requires_grad)) + + + + ############## + # Data Setup # + ############## + + # Simulate data + data, domains = make_dataset( + num_features_continuous=num_features // 2, + num_features_discrete=num_features // 2, + num_clusters=4, + num_samples=1000, + ) + + ######################################## + # PiecewiseLinear Layer Initialization # + ######################################## + + model.leaf.base_leaf.initialize(data, domains=domains) + + # Init. 
first linsum layer weights to be the log of the mixture weights from the kmeans result in the PWL init phase + model.layers[0].logits.data[:] = ( + model.leaf.base_leaf.mixture_weights.permute(1, 0).view(1, config.num_leaves, 1, config.num_repetitions).log() + ) + + + ################ + # Optimization # + ################ + # Optimize Einet parameters (weights and leaf params) + optimizer = torch.optim.Adam(model.parameters(), lr=lr) + scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[int(epochs * 0.5), int(epochs * 0.75)], gamma=0.1, verbose=True) + + model.train() + for epoch in range(1, epochs + 1): + # Since we don't have a train dataloader, we will loop over the data manually + iter = range(0, len(data), batch_size) + pbar = tqdm.tqdm(iter, desc="Train Epoch: {}".format(epoch)) + for batch_idx in pbar: + optimizer.zero_grad() + + # Select batch + data_batch = data[batch_idx : batch_idx + batch_size] + + # Generate outputs + outputs = model(data_batch, cache_index=batch_idx) + + # Compute loss + loss = -1 * outputs.mean() + + # Compute gradients + loss.backward() + + # Update weights + optimizer.step() + + # Logging + if batch_idx % 10 == 0: + pbar.set_description( + "Train Epoch: {} [{}/{}] Loss: {:.2f}".format( + epoch, + batch_idx, + len(data), + loss.item(), + ) + ) + scheduler.step() + + + model.eval() + + ################# + # Visualization # + ################# + + # Generate samples + samples = model.sample(10000) + + # Plot results + fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(12, 10)) + for i, ax in enumerate([ax1, ax2, ax3, ax4]): + # Get data subset + if domains[i].data_type == DataType.DISCRETE: + rng = (0, BINS + 1) + bins = BINS + 1 + width = 1 + else: + rng = None + bins = 100 + width = (samples[:, :, i].max() - samples[:, :, i].min()) / bins + + # Plot histogram of data + hist = torch.histogram(samples[:, :, i], bins=bins, density=True, range=rng) + bin_edges = hist.bin_edges + density = hist.hist + if domains[i].data_type == DataType.DISCRETE: + bin_edges -= 0.5 + + # Center bars on value (e.g. bar for value 0 should have its center at value 0) + bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2 + ax.bar(bin_centers, density, width=width * 0.8, alpha=0.5, label="Samples") + + + if domains[i].data_type == DataType.DISCRETE: + rng = (0, BINS + 1) + bins = BINS + 1 + width = 1 + else: + rng = None + bins = 100 + width = (data[:, :, i].max() - data[:, :, i].min()) / bins + + # Plot histogram of data + hist = torch.histogram(data[:, :, i], bins=bins, density=True, range=rng) + bin_edges = hist.bin_edges + density = hist.hist + if domains[i].data_type == DataType.DISCRETE: + bin_edges -= 0.5 + + # Center bars on value (e.g. 
bar for value 0 should have its center at value 0) + bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2 + ax.bar(bin_centers, density, width=width * 0.8, alpha=0.5, label="Data") + + # Plot PWL logprobs + dummy = torch.full((bin_centers.shape[0], data.shape[1], data.shape[2]), np.nan) + dummy[:, 0, i] = bin_centers + with torch.no_grad(): + log_probs = model(dummy) + probs = log_probs.exp().squeeze(-1).numpy() + ax.plot(bin_centers, probs, linewidth=2, label="PWL Density") + + + # MPE + mpe = model.mpe() + dummy = torch.full((mpe.shape[0], data.shape[1], data.shape[2]), np.nan) + dummy[:, 0, i] = mpe[:, 0, i] + with torch.no_grad(): + mpe_prob = model(dummy).exp().detach() + ax.plot(mpe.squeeze()[i], mpe_prob.squeeze(), "rx", markersize=13, label="PWL MPE") + + ax.set_xlabel("Feature Value") + ax.set_ylabel("Density") + + ax.set_title(f"Feature {i} ({str(domains[i].data_type)})") + ax.legend() + + plt.tight_layout() + plt.savefig(f"/tmp/pwl.png", dpi=300) + + # Conditional sampling example + data_subset = data[:10] + data[:, :, :2] = np.nan + samples_cond = model.sample(evidence=data_subset) + + # Conditional MPE example + mpe_cond = model.mpe(evidence=data_subset) diff --git a/simple_einet/dist.py b/simple_einet/dist.py new file mode 100644 index 0000000..9181db8 --- /dev/null +++ b/simple_einet/dist.py @@ -0,0 +1,104 @@ +from enum import Enum +import numpy as np + + +class Dist(str, Enum): + """Enum for the distribution of the data.""" + + NORMAL = "normal" + MULTIVARIATE_NORMAL = "multivariate_normal" + NORMAL_RAT = "normal_rat" + BINOMIAL = "binomial" + CATEGORICAL = "categorical" + PIECEWISE_LINEAR = "piecewise_linear" + + +class DataType(str, Enum): + """Enum for the type of the data.""" + + CONTINUOUS = "continuous" + DISCRETE = "discrete" + +class Domain: + def __init__(self, values=None, min=None, max=None, data_type=None): + self.values = values + self.min = min + self.max = max + self.data_type = data_type + + @staticmethod + def discrete_bins(values): + return Domain(min=min(values), max=max(values), values=values, data_type=DataType.DISCRETE) + + @staticmethod + def discrete_range(min, max): + return Domain(min=min, max=max, values=list(np.arange(min, max+1)), data_type=DataType.DISCRETE) + + @staticmethod + def continuous_range(min, max): + return Domain(min=min, max=max, data_type=DataType.CONTINUOUS) + + @staticmethod + def continuous_inf_support(): + return Domain(min=np.NINF, max=np.inf, data_type=DataType.CONTINUOUS) + + + +def get_data_type_from_dist(dist: Dist) -> DataType: + """ + Returns the data type based on the distribution. + + Args: + dist: The distribution. + + Returns: + DataType: CONTINUOUS or DISCRETE based on the distribution. + """ + if dist in {Dist.NORMAL, Dist.NORMAL_RAT, Dist.MULTIVARIATE_NORMAL, Dist.PIECEWISE_LINEAR}: + return DataType.CONTINUOUS + elif dist in {Dist.BINOMIAL, Dist.CATEGORICAL}: + return DataType.DISCRETE + else: + raise ValueError(f"Unknown distribution ({dist}).") + + +def get_distribution(dist: Dist, cfg): + """ + Get the distribution for the leaves. + + Args: + dist: The distribution to use. + + Returns: + leaf_type: The type of the leaves. + leaf_kwargs: The kwargs for the leaves. + + """ + # Import the locally to circumvent circular imports. 
+ from simple_einet.layers.distributions.binomial import Binomial + from simple_einet.layers.distributions.categorical import Categorical + from simple_einet.layers.distributions.multivariate_normal import MultivariateNormal + from simple_einet.layers.distributions.normal import Normal, RatNormal + from simple_einet.layers.distributions.piecewise_linear import PiecewiseLinear + + if dist == Dist.NORMAL: + leaf_type = Normal + leaf_kwargs = {} + elif dist == Dist.NORMAL_RAT: + leaf_type = RatNormal + leaf_kwargs = {"min_sigma": cfg.min_sigma, "max_sigma": cfg.max_sigma} + elif dist == Dist.BINOMIAL: + leaf_type = Binomial + leaf_kwargs = {"total_count": 2**cfg.n_bits - 1} + elif dist == Dist.CATEGORICAL: + leaf_type = Categorical + leaf_kwargs = {"num_bins": 2**cfg.n_bits - 1} + elif dist == Dist.MULTIVARIATE_NORMAL: + leaf_type = MultivariateNormal + leaf_kwargs = {"cardinality": cfg.multivariate_cardinality} + elif dist == Dist.PIECEWISE_LINEAR: + leaf_type = PiecewiseLinear + leaf_kwargs = {} + else: + raise ValueError(f"Unknown distribution ({dist}).") + return leaf_kwargs, leaf_type diff --git a/simple_einet/histogram.py b/simple_einet/histogram.py new file mode 100644 index 0000000..362bb00 --- /dev/null +++ b/simple_einet/histogram.py @@ -0,0 +1,296 @@ +""" +Translates numpys histogram "auto" bin size estimation to PyTorch. +""" + + +import numpy as np +import torch +import operator + +# NumPy implementations +def _hist_bin_fd_np(x): + iqr = np.subtract(*np.percentile(x, [75, 25])) + return 2.0 * iqr * x.size ** (-1.0 / 3.0) + +def _hist_bin_sturges_np(x): + return np.ptp(x) / (np.log2(x.size) + 1.0) + +def _hist_bin_auto_np(x): + fd_bw = _hist_bin_fd_np(x) + sturges_bw = _hist_bin_sturges_np(x) + return min(fd_bw, sturges_bw) if fd_bw else sturges_bw + +# PyTorch implementations +def _hist_bin_fd_torch(x): + iqr = torch.quantile(x, 0.75) - torch.quantile(x, 0.25) + return 2.0 * iqr * x.size(0) ** (-1.0 / 3.0) + +def _hist_bin_sturges_torch(x): + return _ptp_torch(x) / (torch.log2(torch.tensor(x.size(0)).float()) + 1.0) + +def _ptp_torch(x): + return x.max() - x.min() + +def _hist_bin_auto_torch(x): + fd_bw = _hist_bin_fd_torch(x) + sturges_bw = _hist_bin_sturges_torch(x) + return min(fd_bw, sturges_bw) if fd_bw > 0 else sturges_bw + + +def _get_bin_edges_torch(a, range=None, weights=None): + """ + Computes the bins used internally by `histogram` in PyTorch. + + Parameters + ---------- + a : 1D Tensor + Ravelled data array. + range : tuple + Lower and upper range of the bins. + weights : Tensor, optional + Ravelled weights array, or None. + + Returns + ------- + bin_edges : Tensor + Array of bin edges. + uniform_bins : tuple + The lower bound, upper bound, and number of bins for uniform binning. + """ + # Assume bins is "auto" as per the user's request. 
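+    # The "auto" rule here mirrors NumPy's heuristic: take the smaller of the
+    # Freedman-Diaconis width (2 * IQR * n^(-1/3)) and the Sturges width
+    # (range / (log2(n) + 1)), falling back to Sturges when the FD width is zero,
+    # then derive the bin count as ceil((last_edge - first_edge) / width).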
+ if weights is not None: + raise TypeError("Automated bin estimation is not supported for weighted data") + + first_edge, last_edge = _get_outer_edges_torch(a, range) + + # Filter the array based on the range if necessary + if range is not None: + a = a[(a >= first_edge) & (a <= last_edge)] + + # If the input tensor is empty after filtering, use 1 bin + if a.numel() == 0: + n_equal_bins = 1 + else: + # Calculate the bin width using the Freedman-Diaconis estimator + width = _hist_bin_auto_torch(a) + if width > 0: + n_equal_bins = int(torch.ceil((last_edge - first_edge) / width).item()) + else: + # If width is zero, fall back to 1 bin + n_equal_bins = 1 + + # Compute bin edges + bin_edges = torch.linspace( + first_edge, last_edge, n_equal_bins + 1, dtype=torch.float32 + ) + + return bin_edges, (first_edge, last_edge, n_equal_bins) + +def _get_outer_edges_torch(a, range): + """ + Determine the outer bin edges from either the data or the given range. + """ + if range is not None: + first_edge, last_edge = range + if first_edge > last_edge: + raise ValueError("max must be larger than min in range parameter.") + if not (torch.isfinite(torch.tensor(first_edge)) and torch.isfinite(torch.tensor(last_edge))): + raise ValueError(f"Supplied range [{first_edge}, {last_edge}] is not finite.") + elif a.numel() == 0: + # Handle empty tensor case + first_edge, last_edge = 0.0, 1.0 + else: + first_edge, last_edge = a.min().item(), a.max().item() + if not (torch.isfinite(torch.tensor(first_edge)) and torch.isfinite(torch.tensor(last_edge))): + raise ValueError(f"Autodetected range [{first_edge}, {last_edge}] is not finite.") + + # Expand if the range is empty to avoid divide-by-zero errors + if first_edge == last_edge: + first_edge -= 0.5 + last_edge += 0.5 + + return first_edge, last_edge + +def _get_bin_edges_np(a, bins, range=None, weights=None): + """ + Computes the bins used internally by `histogram`. + + Parameters + ========== + a : ndarray + Ravelled data array + bins, range + Forwarded arguments from `histogram`. + weights : ndarray, optional + Ravelled weights array, or None + + Returns + ======= + bin_edges : ndarray + Array of bin edges + uniform_bins : (Number, Number, int): + The upper bound, lowerbound, and number of bins, used in the optimized + implementation of `histogram` that works on uniform bins. + """ + # parse the overloaded bins argument + n_equal_bins = None + bin_edges = None + + if isinstance(bins, str): + bin_name = bins + # if `bins` is a string for an automatic method, + # this will replace it with the number of bins calculated + if weights is not None: + raise TypeError("Automated estimation of the number of " + "bins is not supported for weighted data") + + first_edge, last_edge = _get_outer_edges_np(a, range) + + # truncate the range if needed + if range is not None: + keep = (a >= first_edge) + keep &= (a <= last_edge) + if not np.logical_and.reduce(keep): + a = a[keep] + + if a.size == 0: + n_equal_bins = 1 + else: + # Do not call selectors on empty arrays + width = _hist_bin_auto_np(a) + if width: + n_equal_bins = int(np.ceil(_unsigned_subtract(last_edge, first_edge) / width)) + else: + # Width can be zero for some estimators, e.g. FD when + # the IQR of the data is zero. 
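+                    # In that degenerate case, fall back to a single bin spanning
+                    # the whole (first_edge, last_edge) range.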
+ n_equal_bins = 1 + + elif np.ndim(bins) == 0: + try: + n_equal_bins = operator.index(bins) + except TypeError as e: + raise TypeError( + '`bins` must be an integer, a string, or an array') from e + if n_equal_bins < 1: + raise ValueError('`bins` must be positive, when an integer') + + first_edge, last_edge = _get_outer_edges_np(a, range) + + elif np.ndim(bins) == 1: + bin_edges = np.asarray(bins) + if np.any(bin_edges[:-1] > bin_edges[1:]): + raise ValueError( + '`bins` must increase monotonically, when an array') + + else: + raise ValueError('`bins` must be 1d, when an array') + + if n_equal_bins is not None: + # gh-10322 means that type resolution rules are dependent on array + # shapes. To avoid this causing problems, we pick a type now and stick + # with it throughout. + bin_type = np.result_type(first_edge, last_edge, a) + if np.issubdtype(bin_type, np.integer): + bin_type = np.result_type(bin_type, float) + + # bin edges must be computed + bin_edges = np.linspace( + first_edge, last_edge, n_equal_bins + 1, + endpoint=True, dtype=bin_type) + return bin_edges, (first_edge, last_edge, n_equal_bins) + else: + return bin_edges, None + +def _get_outer_edges_np(a, range): + """ + Determine the outer bin edges to use, from either the data or the range + argument + """ + if range is not None: + first_edge, last_edge = range + if first_edge > last_edge: + raise ValueError( + 'max must be larger than min in range parameter.') + if not (np.isfinite(first_edge) and np.isfinite(last_edge)): + raise ValueError( + "supplied range of [{}, {}] is not finite".format(first_edge, last_edge)) + elif a.size == 0: + # handle empty arrays. Can't determine range, so use 0-1. + first_edge, last_edge = 0, 1 + else: + first_edge, last_edge = a.min(), a.max() + if not (np.isfinite(first_edge) and np.isfinite(last_edge)): + raise ValueError( + "autodetected range of [{}, {}] is not finite".format(first_edge, last_edge)) + + # expand empty range to avoid divide by zero + if first_edge == last_edge: + first_edge = first_edge - 0.5 + last_edge = last_edge + 0.5 + + return first_edge, last_edge + +def _unsigned_subtract(a, b): + """ + Subtract two values where a >= b, and produce an unsigned result + + This is needed when finding the difference between the upper and lower + bound of an int16 histogram + """ + # coerce to a single type + signed_to_unsigned = { + np.byte: np.ubyte, + np.short: np.ushort, + np.intc: np.uintc, + np.int_: np.uint, + np.longlong: np.ulonglong + } + dt = np.result_type(a, b) + try: + unsigned_dt = signed_to_unsigned[dt.type] + except KeyError: + return np.subtract(a, b, dtype=dt) + else: + # we know the inputs are integers, and we are deliberately casting + # signed to unsigned. The input may be negative python integers so + # ensure we pass in arrays with the initial dtype (related to NEP 50). 
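+        # casting='unsafe' lets the signed inputs produce an unsigned output dtype;
+        # this is safe here because a >= b, so the difference is always representable
+        # in the unsigned type without wrap-around.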
+ return np.subtract(np.asarray(a, dtype=dt), np.asarray(b, dtype=dt), + casting='unsafe', dtype=unsigned_dt) + + +if "__main__" == __name__: + # Generate random data + data_np = np.random.randn(1000) # Numpy data + data_torch = torch.tensor(data_np, dtype=torch.float32) # Convert to torch + + # Compare results for _hist_bin_fd + fd_np = _hist_bin_fd_np(data_np) + fd_torch = _hist_bin_fd_torch(data_torch).item() + + # Compare results for _hist_bin_sturges + sturges_np = _hist_bin_sturges_np(data_np) + sturges_torch = _hist_bin_sturges_torch(data_torch).item() + + # Compare results for _hist_bin_auto + auto_np = _hist_bin_auto_np(data_np) + auto_torch = _hist_bin_auto_torch(data_torch).item() + + # Print comparisons + print(f"Freedman-Diaconis (Numpy): {fd_np}") + print(f"Freedman-Diaconis (Torch): {fd_torch}") + + print(f"Sturges (Numpy): {sturges_np}") + print(f"Sturges (Torch): {sturges_torch}") + + print(f"Auto (Numpy): {auto_np}") + print(f"Auto (Torch): {auto_torch}") + + # Call the function and print the results + bin_edges, (first_edge, last_edge, n_bins) = _get_bin_edges_torch(data_torch) + + print(f"Bin edges torch: {bin_edges}") + print(f"Range torch: ({first_edge}, {last_edge}), Number of bins: {n_bins}") + + bin_edges, (first_edge, last_edge, n_bins) = _get_bin_edges_np(data_np, bins="auto") + print(f"Bin edges numpy: {bin_edges}") + print(f"Range numpy: ({first_edge}, {last_edge}), Number of bins: {n_bins}") diff --git a/simple_einet/layers/distributions/piecewise_linear.py b/simple_einet/layers/distributions/piecewise_linear.py new file mode 100644 index 0000000..6aac33a --- /dev/null +++ b/simple_einet/layers/distributions/piecewise_linear.py @@ -0,0 +1,744 @@ +from typing import List, Tuple + +import tqdm +import itertools +from collections import defaultdict +from pytorch_lightning.strategies import deepspeed +import torch +from torch import nn + +from simple_einet.dist import DataType, Domain +from simple_einet.layers.distributions.abstract_leaf import AbstractLeaf +from simple_einet.sampling_utils import SamplingContext +from simple_einet.type_checks import check_valid +from simple_einet.histogram import _get_bin_edges_torch + + +import logging + +from fast_pytorch_kmeans import KMeans + +logger = logging.getLogger(__name__) + + +def pairwise(iterable): + "s -> (s0,s1), (s1,s2), (s2, s3), ..." + a, b = itertools.tee(iterable) + next(b, None) + return zip(a, b) + + +class PiecewiseLinear(AbstractLeaf): + """ + Piecewise linear leaf implementation. + + First constructs a histogram from the data and then approximates the histogram with a piecewise linear function. + """ + + def __init__( + self, + num_features: int, + num_channels: int, + num_leaves: int, + num_repetitions: int, + alpha: float = 0.0, + ): + """ + Initializes a piecewise linear distribution with the given parameters. + + Args: + num_features (int): The number of features in the input tensor. + num_channels (int): The number of channels in the input tensor. + num_leaves (int): The number of leaves in the tree structure. + num_repetitions (int): The number of repetitions of the tree structure. + alpha (float): The alpha parameter for optional laplace smoothing. 
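+                When alpha > 0, bin densities are recomputed as
+                (count + alpha) / (n_samples + n_bins * alpha) before normalisation.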
+ """ + super().__init__(num_features, num_channels, num_leaves, num_repetitions) + self.alpha = check_valid(alpha, expected_type=float, lower_bound=0.0) + self.xs = None + self.ys = None + self.is_initialized = False # Flag to check if the distribution has been initialized + + def initialize(self, data: torch.Tensor, domains: List[Domain]): + """ + Initializes the piecewise linear distribution with the given data. + + This includes the following steps: + + 1. Cluster data into num_leaves clusters, num_repetition times, such that for each leaf representation and repetition there is a piecewise linear function. + 2. For each cluster, construct a histogram of the data. + 3. Approximate the histogram with a piecewise linear function. + + Args: + data (torch.Tensor): The data to initialize the distribution with. + domains (List[Domain]): The domains of the features. + """ + + logger.info(f"Initializing piecewise linear distribution with data of shape {data.shape}.") + + assert data.shape[1] == self.num_channels + assert data.shape[2] == self.num_features + + self.domains = domains + + # Parameters + xs = [] # [R, L, F, C] + ys = [] + + self.mixture_weights = torch.zeros(self.num_repetitions, self.num_leaves, device=data.device) + + + for i_repetition in tqdm.tqdm(range(self.num_repetitions), desc="Initializing PiecewiseLinear Leaf Layer"): + # Repeat this for every repetition + xs_leaves = [] + ys_leaves = [] + + # Cluster data into num_leaves clusters + kmeans = KMeans(n_clusters=self.num_leaves, mode="euclidean", verbose=0, init_method="random") + kmeans.fit(data.view(data.shape[0], -1).float()) + + predictions = kmeans.predict(data.view(data.shape[0], -1).float()) + counts = torch.bincount(predictions) + self.mixture_weights[i_repetition] = counts / counts.sum() + + # Get cluster assigments for each datapoint + cluster_idxs = kmeans.max_sim(a=data.view(data.shape[0], -1).float(), b=kmeans.centroids)[1] + for cluster_idx in range(self.num_leaves): + + # Select data for this cluster + mask = cluster_idxs == cluster_idx + cluster_data = data[mask] + + xs_features = [] + ys_features = [] + for i_feature in range(self.num_features): + xs_channels = [] + ys_channels = [] + + for i_channel in range(self.num_channels): + + # Select relevant data + data_subset = cluster_data[:, i_channel, i_feature].view(cluster_data.shape[0], 1).float() + + # Construct histogram + if self.domains[i_feature].data_type == DataType.DISCRETE: + # Edges are the discrete values + mids = torch.tensor(self.domains[i_feature].values, device=data.device).float() + + # Add a break at the end + breaks = torch.cat([mids, torch.tensor([mids[-1] + 1], device=mids.device)]) + + if data_subset.shape[0] == 0: + # If there is no data in this cluster, set the density to uniform + densities = torch.ones(len(mids), device=data.device) / len(mids) + else: + # Compute counts + densities = torch.histogram(data_subset.cpu(), bins=breaks.cpu(), density=True).hist.to(data.device) + + + elif self.domains[i_feature].data_type == DataType.CONTINUOUS: + # Find histogram bins using numpys "auto" logic + bins, _ = _get_bin_edges_torch(data_subset) + + # Construct histogram + densities = torch.histogram(data_subset.cpu(), bins=bins.cpu(), density=True).hist.to(data.device) + breaks = bins + mids = ((breaks + torch.roll(breaks, shifts=-1, dims=0)) / 2)[:-1] + else: + raise ValueError(f"Unknown data type: {domains[i_feature]}") + + # Apply optional laplace smoothing + if self.alpha > 0: + n_samples = data_subset.shape[0] + n_bins = len(breaks) - 1 + 
counts = densities * n_samples + densities = (counts + self.alpha) / (n_samples + n_bins * self.alpha) + + assert len(densities) + 1 == len(breaks) + + # Add tail breaks to start and end + if self.domains[i_feature].data_type == DataType.DISCRETE: + tail_width = 1 + x = [b for b in breaks[:-1]] + x = [x[0] - tail_width] + x + [x[-1] + tail_width] + elif self.domains[i_feature].data_type == DataType.CONTINUOUS: + EPS = 1e-8 + x = ( + [breaks[0] - EPS] + + [b0 + (b1 - b0) / 2 for (b0, b1) in pairwise(breaks)] + + [breaks[-1] + EPS] + ) + else: + raise ValueError(f"Unknown data type: {domains[i_feature].data_type}") + + # Add density 0 at start an end tail break + y = [0.0] + [d for d in densities] + [0.0] + + # Check that shapes still match + assert len(densities) == len(breaks) - 1 + assert len(x) == len(y), (len(x), len(y)) + + # Construct tensors + x = torch.tensor(x, device=data.device) #, requires_grad=True) + y = torch.tensor(y, device=data.device) #, requires_grad=True) + + # Compute AUC using the trapeziod rule + auc = torch.trapezoid(x=x, y=y) + + # Normalize y to sum to 1 using AUC + y = y / auc + + # Store + xs_channels.append(x) + ys_channels.append(y) + + + # Store + xs_features.append(xs_channels) + ys_features.append(ys_channels) + + xs_leaves.append(xs_features) + ys_leaves.append(ys_features) + + # Store + xs.append(xs_leaves) + ys.append(ys_leaves) + + # Check shapes + assert len(xs) == len(ys) == self.num_repetitions + assert len(xs[0]) == len(ys[0]) == self.num_leaves + assert len(xs[0][0]) == len(ys[0][0]) == self.num_features + assert len(xs[0][0][0]) == len(ys[0][0][0]) == self.num_channels + + # self.mixture_weights = torch.zeros(self.num_features, self.num_channels, self.num_repetitions, self.num_leaves, device=data.device) + # for i_feature in range(self.num_features): + # xs_channel = [] + # ys_channel = [] + # for i_channel in range(self.num_channels): + + # # Select relevant data + # data_subset = data[:, i_channel, i_feature].view(data.shape[0], 1).float() + + # # Repeat this for every repetition + # xs_repetition = [] + # ys_repetition = [] + # for i_repetition in range(self.num_repetitions): + + # # Cluster data into num_leaves clusters + # kmeans = KMeans(n_clusters=self.num_leaves, mode="euclidean", verbose=0, init_method="kmeans++") + # kmeans.fit(data_subset) + + # predictions = kmeans.predict(data_subset.view(data_subset.shape[0], -1).float()) + # counts = torch.bincount(predictions) + # self.mixture_weights[i_feature, i_channel, i_repetition] = counts / counts.sum() + + # # Get cluster assigments for each datapoint + # cluster_idxs = kmeans.max_sim(a=data_subset, b=kmeans.centroids)[1] + + # xs_leaves = [] + # ys_leaves = [] + # for cluster_idx in range(self.num_leaves): + + # # Select data for this cluster + # mask = cluster_idxs == cluster_idx + # cluster_data = data_subset[mask] + + # # Construct histogram + # if self.domains[i_feature].data_type == DataType.DISCRETE: + # # Edges are the discrete values + # mids = torch.tensor(self.domains[i_feature].values, device=data.device).float() + + # # Add a break at the end + # breaks = torch.cat([mids, torch.tensor([mids[-1] + 1])]) + + # if cluster_data.shape[0] == 0: + # # If there is no data in this cluster, set the density to uniform + # densities = torch.ones(len(mids), device=data.device) / len(mids) + # else: + # # Compute counts + # densities = torch.histogram(cluster_data, bins=breaks, density=True).hist + + + # elif self.domains[i_feature].data_type == DataType.CONTINUOUS: + # # Find histogram bins 
using numpys "auto" logic + # bins, _ = _get_bin_edges_torch(cluster_data) + + # # Construct histogram + # densities = torch.histogram(cluster_data, bins=bins, density=True).hist + # breaks = bins + # mids = ((breaks + torch.roll(breaks, shifts=-1, dims=0)) / 2)[:-1] + # else: + # raise ValueError(f"Unknown data type: {domains[i_feature]}") + + # # Apply optional laplace smoothing + # if self.alpha > 0: + # n_samples = cluster_data.shape[0] + # n_bins = len(breaks) - 1 + # counts = densities * n_samples + # alpha_abs = n_samples * self.alpha + # densities = (counts + alpha_abs) / (n_samples + n_bins * alpha_abs) + + # assert len(densities) + 1 == len(breaks) + + # # Add tail breaks to start and end + # if self.domains[i_feature].data_type == DataType.DISCRETE: + # tail_width = 1 + # x = [b for b in breaks[:-1]] + # x = [x[0] - tail_width] + x + [x[-1] + tail_width] + # elif self.domains[i_feature].data_type == DataType.CONTINUOUS: + # EPS = 1e-8 + # x = ( + # [breaks[0] - EPS] + # + [b0 + (b1 - b0) / 2 for (b0, b1) in pairwise(breaks)] + # + [breaks[-1] + EPS] + # ) + # else: + # raise ValueError(f"Unknown data type: {domains[i_feature].data_type}") + + # # Add density 0 at start an end tail break + # y = [0.0] + [d for d in densities] + [0.0] + + # # Check that shapes still match + # assert len(densities) == len(breaks) - 1 + # assert len(x) == len(y), (len(x), len(y)) + + # # Construct tensors + # x = torch.tensor(x, device=data.device) #, requires_grad=True) + # y = torch.tensor(y, device=data.device) #, requires_grad=True) + + # # Compute AUC using the trapeziod rule + # auc = torch.trapezoid(x=x, y=y) + + # # Normalize y to sum to 1 using AUC + # y = y / auc + + # # Store + # xs_leaves.append(x) + # ys_leaves.append(y) + + # xs_repetition.append(xs_leaves) + # ys_repetition.append(ys_leaves) + + # # Store + # xs_channel.append(xs_repetition) + # ys_channel.append(ys_repetition) + + # # Store + # xs.append(xs_channel) + # ys.append(ys_channel) + + # # Check shapes + # assert len(xs) == len(ys) == self.num_features + # assert len(xs[0]) == len(ys[0]) == self.num_channels + # assert len(xs[0][0]) == len(ys[0][0]) == self.num_repetitions + # assert len(xs[0][0][0]) == len(ys[0][0][0]) == self.num_leaves + + + # Store + self.xs = xs + self.ys = ys + self.is_initialized = True + + def reset(self): + self.is_initialized = False + self.xs = None + self.ys = None + + def _get_base_distribution(self, ctx: SamplingContext = None) -> "PiecewiseLinearDist": + # Use custom normal instead of PyTorch distribution + if not self.is_initialized: + raise ValueError("PiecewiseLinear leaf layer has not been initialized yet. Call initialize(...) first to estimate to correct piecewise linear functions upfront.") + return PiecewiseLinearDist(self.xs, self.ys, domains=self.domains) + + def get_params(self): + # Get params cannot be called on PiecewiseLinearDist, since it does not have any params + raise NotImplementedError + + +def interp( + x: torch.Tensor, xp: torch.Tensor, fp: torch.Tensor, dim: int = -1, extrapolate: str = "constant" +) -> torch.Tensor: + """One-dimensional linear interpolation between monotonically increasing sample + points, with extrapolation beyond sample points. + + Source: https://github.com/pytorch/pytorch/issues/50334#issuecomment-2304751532 + + Returns the one-dimensional piecewise linear interpolant to a function with + given discrete data points :math:`(xp, fp)`, evaluated at :math:`x`. + + Args: + x: The :math:`x`-coordinates at which to evaluate the interpolated + values. 
+ xp: The :math:`x`-coordinates of the data points, must be increasing. + fp: The :math:`y`-coordinates of the data points, same shape as `xp`. + dim: Dimension across which to interpolate. + extrapolate: How to handle values outside the range of `xp`. Options are: + - 'linear': Extrapolate linearly beyond range of xp values. + - 'constant': Use the boundary value of `fp` for `x` values outside `xp`. + + Returns: + The interpolated values, same size as `x`. + """ + # Move the interpolation dimension to the last axis + x = x.movedim(dim, -1) + xp = xp.movedim(dim, -1) + fp = fp.movedim(dim, -1) + + m = torch.diff(fp) / torch.diff(xp) # slope + b = fp[..., :-1] - m * xp[..., :-1] # offset + indices = torch.searchsorted(xp, x, right=False) + + if extrapolate == "constant": + # Pad m and b to get constant values outside of xp range + m = torch.cat([torch.zeros_like(m)[..., :1], m, torch.zeros_like(m)[..., :1]], dim=-1) + b = torch.cat([fp[..., :1], b, fp[..., -1:]], dim=-1) + else: # extrapolate == 'linear' + indices = torch.clamp(indices - 1, 0, m.shape[-1] - 1) + + values = m.gather(-1, indices) * x + b.gather(-1, indices) + + values = values.clamp(min=0.0) + + return values.movedim(-1, dim) + + +class PiecewiseLinearDist: + def __init__(self, xs, ys, domains: list[DataType]): + self.xs = xs + self.ys = ys + # self.num_features = len(xs) + # self.num_channels = len(xs[0]) + # self.num_repetitions = len(xs[0][0]) + # self.num_leaves = len(xs[0][0][0]) + + self.num_repetitions = len(xs) + self.num_leaves = len(xs[0]) + self.num_features = len(xs[0][0]) + self.num_channels = len(xs[0][0][0]) + self.domains = domains + + def _compute_cdf(self, xs, ys): + """Compute the CDF for the given piecewise linear function.""" + # Compute the integral over each interval using the trapezoid rule + intervals = torch.diff(xs) + trapezoids = 0.5 * intervals * (ys[:-1] + ys[1:]) # Partial areas + + # Cumulative sum to build the CDF + cdf = torch.cat([torch.zeros(1, device=xs.device), torch.cumsum(trapezoids, dim=0)]) + + # Normalize the CDF to ensure it goes from 0 to 1 + cdf = cdf / cdf[-1] + + return cdf + + def sample(self, sample_shape: torch.Size) -> torch.Tensor: + """Sample from the piecewise linear distribution.""" + samples = torch.empty( + (sample_shape[0], self.num_channels, self.num_features, self.num_leaves, self.num_repetitions), + device=self.xs[0][0][0][0].device, + ) + + for i_feature in range(self.num_features): + for i_channel in range(self.num_channels): + for i_repetition in range(self.num_repetitions): + for i_leaf in range(self.num_leaves): + # xs_i = self.xs[i_feature][i_channel][i_repetition][i_leaf] + # ys_i = self.ys[i_feature][i_channel][i_repetition][i_leaf] + xs_i = self.xs[i_repetition][i_leaf][i_feature][i_channel] + ys_i = self.ys[i_repetition][i_leaf][i_feature][i_channel] + + if self.domains[i_feature].data_type == DataType.DISCRETE: + # Sample from a categorical distribution + ys_i_wo_tails = ys_i[1:-1] # Cut off the tail breaks + dist = torch.distributions.Categorical(probs=ys_i_wo_tails) + samples[..., i_channel, i_feature, i_leaf, i_repetition] = dist.sample(sample_shape) + elif self.domains[i_feature].data_type == DataType.CONTINUOUS: + # Compute the CDF for this piecewise function + cdf = self._compute_cdf(xs_i, ys_i) + + # Sample from a uniform distribution + u = torch.rand(sample_shape, device=xs_i.device) + + # Find the corresponding segment using searchsorted + indices = torch.searchsorted(cdf, u, right=True) + + # Clamp indices to be within valid range + indices = 
torch.clamp(indices, 1, len(xs_i) - 1) + + # Perform linear interpolation to get the sample value + x0, x1 = xs_i[indices - 1], xs_i[indices] + cdf0, cdf1 = cdf[indices - 1], cdf[indices] + slope = (x1 - x0) / (cdf1 - cdf0 + 1e-8) # Avoid division by zero + + # Compute the sampled value + samples[..., i_channel, i_feature, i_leaf, i_repetition] = x0 + slope * (u - cdf0) + else: + raise ValueError(f"Unknown data type: {self.domains[i_feature].data_type}") + + samples = samples.unsqueeze( + 1 + ) # Insert "empty" second dimension since all other distributions are implemented this way and the distribution sampling logic expects this + + return samples + + def mpe(self, num_samples: int) -> torch.Tensor: + """Compute the most probable explanation (MPE) by taking the mode of the distribution.""" + modes = torch.empty( + (num_samples, self.num_channels, self.num_features, self.num_leaves, self.num_repetitions), + device=self.xs[0][0][0][0].device, + ) + + for i_feature in range(self.num_features): + for i_channel in range(self.num_channels): + for i_repetition in range(self.num_repetitions): + for i_leaf in range(self.num_leaves): + # xs_i = self.xs[i_feature][i_channel][i_repetition][i_leaf] + # ys_i = self.ys[i_feature][i_channel][i_repetition][i_leaf] + xs_i = self.xs[i_repetition][i_leaf][i_feature][i_channel] + ys_i = self.ys[i_repetition][i_leaf][i_feature][i_channel] + + # Find the mode (the x value with the highest PDF value) + max_idx = torch.argmax(ys_i) + mode_value = xs_i[max_idx] + + # Store the mode value + modes[:, i_channel, i_feature, i_leaf, i_repetition] = mode_value + + return modes + + def log_prob(self, x: torch.Tensor): + # Initialize probs with ones of the same shape as obs + probs = torch.zeros(list(x.shape[0:3]) + [self.num_leaves, self.num_repetitions], device=x.device) + if x.dim() == 5: + x = x.squeeze(-1).squeeze(-1) + + # Perform linear interpolation (equivalent to np.interp) + for i_feature in range(self.num_features): + for i_channel in range(self.num_channels): + for i_repetition in range(self.num_repetitions): + for i_leaf in range(self.num_leaves): + # xs_i = self.xs[i_feature][i_channel][i_repetition][i_leaf] + # ys_i = self.ys[i_feature][i_channel][i_repetition][i_leaf] + xs_i = self.xs[i_repetition][i_leaf][i_feature][i_channel] + ys_i = self.ys[i_repetition][i_leaf][i_feature][i_channel] + ivalues = interp(x[:, i_channel, i_feature], xs_i, ys_i) + probs[:, i_channel, i_feature, i_leaf, i_repetition] = ivalues + + # Return the logarithm of probabilities + logprobs = torch.log(probs) + logprobs[logprobs == float("-inf")] = -300.0 + return logprobs + + def get_params(self): + raise NotImplementedError + + +if __name__ == "__main__": + # # Test the piecewise linear distribution + # data = torch.randn(1000, 3, 30) + # data_types = [DataType.CONTINUOUS] * 30 + # pl = PiecewiseLinear(num_features=30, num_channels=3, num_leaves=7, num_repetitions=3) + # pl.initialize(data, data_types) + + # # Test the piecewise linear distribution + # d = pl._get_base_distribution() + # ll = d.log_prob(data) + # samples = d.sample((10,)) + # mpes = d.mpe(10) + # print(ll.shape) + # print(samples.shape) + # print(mpes.shape) + + # from simple_einet.einet import Einet, EinetConfig + + # # Create an Einet + # einet = Einet( + # EinetConfig(depth=3, num_features=8, num_channels=1, num_leaves=3, num_repetitions=4, leaf_type=PiecewiseLinear) + # ) + + # # Create some data + # data = torch.randn(1000, 1, 8) + + # einet.leaf.base_leaf.initialize(data, [DataType.CONTINUOUS] * 8) + + # # 
Test the piecewise linear distribution + # einet.sample(num_samples=10) + # einet.mpe() + + # exit(0) + import seaborn as sns + + sns.set() + sns.set_style("whitegrid") + + import torch + import numpy as np + import matplotlib.pyplot as plt + from torch.distributions import Normal, Uniform + + # Create a multimodal 1D dataset + def create_multimodal_dataset(n_samples=100000): + # Mix of three Gaussian distributions and one Uniform distribution + dist1 = Normal(loc=-5, scale=1) + dist2 = Normal(loc=0, scale=0.5) + dist3 = Normal(loc=5, scale=1.5) + dist4 = Uniform(low=-2, high=2) + + # Generate samples + samples1 = dist1.sample((int(n_samples * 0.3),)) + samples2 = dist2.sample((int(n_samples * 0.2),)) + samples3 = dist3.sample((int(n_samples * 0.3),)) + samples4 = dist4.sample((int(n_samples * 0.2),)) + + # Combine samples + all_samples = torch.cat([samples1, samples2, samples3, samples4]) + + # Shuffle the samples + return all_samples[torch.randperm(all_samples.size(0))] + + # Create the dataset + data = create_multimodal_dataset().unsqueeze(1).unsqueeze(1) # Shape: (10000, 1, 1) + + # Initialize PiecewiseLinear + num_features = 1 + num_channels = 1 + num_leaves = 1 # You mentioned this was increased for flexibility, but it's still 1 here + num_repetitions = 1 + + pl = PiecewiseLinear( + num_features=num_features, num_channels=num_channels, num_leaves=num_leaves, num_repetitions=num_repetitions + ) + pl.initialize(data, [Domain.continuous_inf_support()]) + + # Get the base distribution + d = pl._get_base_distribution() + + # Calculate log probabilities for a range of values + x_range = torch.linspace(-10, 10, 100000).unsqueeze(1).unsqueeze(1) + log_probs = d.log_prob(x_range) + + # Generate samples from the PWL distribution + pwl_samples = d.sample((100000,)).squeeze() + + # Plot the results + plt.figure(figsize=(12, 6)) + + # Plot histogram of the original data + plt.hist(data.squeeze().numpy(), bins=100, density=True, alpha=0.5, label="Original Data") + + # Plot histogram of the PWL samples + plt.hist(pwl_samples.numpy(), bins=100, density=True, alpha=0.5, label="PWL Samples") + + # Plot the PWL log probability (exponentiated for density) + plt.plot(x_range.squeeze().numpy(), torch.exp(log_probs).squeeze().numpy(), "r-", linewidth=2, label="PWL Density") + + pwl_mpe_x = d.mpe(1) + # Plot MPE of the distribution at the y position of the pwl_mpe value + pwl_mpe_y = d.log_prob(pwl_mpe_x).exp() + plt.plot(pwl_mpe_x.squeeze(), pwl_mpe_y.squeeze(), "rx", markersize=13, label="PWL MPE") + + plt.title("Multimodal Data, PWL Distribution, and PWL Samples") + plt.xlabel("Value") + plt.ylabel("Density") + plt.legend() + plt.grid(True, alpha=0.3) + plt.savefig("/tmp/continuous_pwl.png", dpi=300) + + # Print some statistics + print(f"Log probability shape: {d.log_prob(data).shape}") + print(f"Sample shape: {d.sample((10,)).shape}") + print(f"MPE shape: {d.mpe(10).shape}") + + import torch + import numpy as np + import matplotlib.pyplot as plt + from torch.distributions import Categorical + + # Create a multimodal discrete dataset + def create_discrete_multimodal_dataset(n_samples=100000): + # Define probabilities for a multimodal discrete distribution + probs = torch.tensor([1, 2.8, 6, 3, 0.5, 0.7, 2, 3.5, 5, 7, 8, 4, 3, 2, 2, 1, 0.5]) + probs = probs / probs.sum() + num_categories = len(probs) + + # Create a Categorical distribution + dist = Categorical(probs) + + # Generate samples + samples = dist.sample((n_samples,)) + + return samples, num_categories + + # Create the dataset + data, 
num_categories = create_discrete_multimodal_dataset() + data = data.unsqueeze(1).unsqueeze(1) # Shape: (10000, 1, 1) + + # Initialize PiecewiseLinear + num_features = 1 + num_channels = 1 + num_leaves = 1 # You mentioned this should be set to the number of categories, but it's still 1 here + num_repetitions = 1 + + pl = PiecewiseLinear( + num_features=num_features, num_channels=num_channels, num_leaves=num_leaves, num_repetitions=num_repetitions + ) + pl.initialize(data, [Domain.discrete_range(min=0, max=num_categories)]) + + # Get the base distribution + d = pl._get_base_distribution() + + # Calculate probabilities for a range of values (including fractional values) + x_range = torch.linspace(-0.5, num_categories - 0.5, 100000).unsqueeze(1).unsqueeze(1) + log_probs = d.log_prob(x_range) + probs = torch.exp(log_probs) + + # Generate samples from the PWL distribution + pwl_samples = d.sample((1000000,)).squeeze() + + # Plot the results + plt.figure(figsize=(12, 6)) + + # Plot histogram of the original data + plt.hist( + data.squeeze().numpy(), bins=np.arange(num_categories + 1) - 0.5, density=True, alpha=0.5, label="Original Data" + ) + + # Plot histogram of the PWL samples + plt.hist( + pwl_samples.numpy(), bins=np.arange(num_categories + 1) - 0.5, density=True, alpha=0.5, label="PWL Samples" + ) + + # Plot the PWL probabilities as a line + plt.plot(x_range.squeeze().numpy(), probs.squeeze().numpy(), "r-", linewidth=2, label="PWL Distribution") + + # Plot MPE of the distribution at the y position of the pwl_mpe value + pwl_mpe_x = d.mpe(1) + pwl_mpe_y = d.log_prob(pwl_mpe_x).exp() + plt.plot(pwl_mpe_x.squeeze(), pwl_mpe_y.squeeze(), "rx", markersize=13, label="PWL MPE") + + plt.title("Discrete Multimodal Data, PWL Distribution, and PWL Samples") + plt.xlabel("Value") + plt.ylabel("Probability") + plt.legend() + plt.grid(True, alpha=0.3) + plt.xticks(range(num_categories)) + plt.xlim(-0.5, num_categories - 0.5) + plt.savefig("/tmp/discrete_pwl.png", dpi=300) + + # Print some statistics + print(f"Log probability shape: {d.log_prob(data).shape}") + print(f"Sample shape: {d.sample((10,)).shape}") + print(f"MPE shape: {d.mpe(10).shape}") + + # Calculate and print the actual probabilities + actual_probs = torch.bincount(data.squeeze().long(), minlength=num_categories).float() / len(data) + print("\nActual probabilities:") + print(actual_probs) + + print("\nPWL probabilities at integer points:") + pwl_probs_at_integers = torch.exp(d.log_prob(torch.arange(num_categories).float().unsqueeze(1).unsqueeze(1))) + print(pwl_probs_at_integers.squeeze()) + + # Calculate KL divergence + kl_div = torch.sum(actual_probs * torch.log(actual_probs / pwl_probs_at_integers.squeeze())) + print(f"\nKL Divergence: {kl_div.item():.4f}")
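+    # Extra sanity check: with unit-width bins and zero-density tail breaks, the
+    # trapezoid normalisation makes the PWL values at the integer support points sum to ~1.
+    print(f"\nSum of PWL probabilities at integer points: {pwl_probs_at_integers.squeeze().sum().item():.4f}")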