Commit

feat: Refactor module structure, improve distribution handling, and enhance sampling capabilities

Too much to track. Here is a gpt-4o-mini summary of `git diff`:

- Module Structure Update:
  - Moved `Dist` import from `simple_einet.data` to `simple_einet.dist` to improve module organization and clarity of distribution-related functionalities.

- Main Script Updates (`main.py`):
  - Integrated `DataType`, `Dist`, and `PiecewiseLinear` into the imports to accommodate new distribution functionalities.
  - Enhanced the training logic within the `train()` function to manage the caching mechanism for piecewise linear distributions. When `args.dist` is set to `Dist.PIECEWISE_LINEAR`, the caching configuration allows for more efficient sampling of outputs.
  - Modified the output generation section to incorporate the new cache parameters: `cache_leaf` and `cache_index`, allowing for conditional caching based on the specified distribution type.
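
  The caching idea can be illustrated with a minimal, self-contained sketch (the names `LeafCache` and `get_or_compute` are illustrative, not the repo's actual `cache_leaf`/`cache_index` API): repeated sampling calls with the same cache index reuse the leaf parameters computed on the first call instead of recomputing them.

  ```python
  from dataclasses import dataclass, field

  @dataclass
  class LeafCache:
      """Toy cache keyed by an index, reused across repeated sample calls."""
      store: dict = field(default_factory=dict)

      def get_or_compute(self, cache_index, compute):
          # Only invoke the (potentially expensive) computation on a cache miss.
          if cache_index not in self.store:
              self.store[cache_index] = compute()
          return self.store[cache_index]

  cache = LeafCache()
  calls = []

  def expensive_params():
      calls.append(1)  # track how often the computation actually runs
      return [0.1, 0.4, 0.5]

  p1 = cache.get_or_compute(0, expensive_params)
  p2 = cache.get_or_compute(0, expensive_params)  # served from cache
  assert p1 == p2 and len(calls) == 1
  ```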

- Distribution Enhancements:
  - Updated `args.py` to streamline the handling of distributions, ensuring that settings for piecewise linear distributions are correctly configured.
  - Added a `ConditioningNetwork` class to `abstract_layers.py`, which provides a neural network structure for conditioning inputs based on parameterized layers.
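
  A minimal sketch of what such a conditioning network might look like (the hidden size and the `num_params` output dimension are assumptions here, not the repo's actual `ConditioningNetwork`): a small MLP that maps a conditioning input to a vector of leaf parameters.

  ```python
  import torch
  import torch.nn as nn

  class ConditioningNetwork(nn.Module):
      """Sketch: map a conditioning input to per-leaf distribution parameters."""

      def __init__(self, in_features: int, num_params: int):
          super().__init__()
          self.net = nn.Sequential(
              nn.Linear(in_features, 32),  # hidden width is an arbitrary choice
              nn.ReLU(),
              nn.Linear(32, num_params),
          )

      def forward(self, x: torch.Tensor) -> torch.Tensor:
          return self.net(x)

  net = ConditioningNetwork(in_features=4, num_params=8)
  out = net(torch.randn(2, 4))
  assert out.shape == (2, 8)  # one parameter vector per batch element
  ```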

- New Piecewise Linear Distribution:
  - Introduced `PiecewiseLinear` class to the `distributions` layer, allowing for a piecewise linear distribution to be specified for leaf nodes in models (affected files include `piecewise_linear.py` and respective imports).
  - Implemented methods for handling piecewise linear distribution parameters, including initialization and sampling.
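
  The core idea of a piecewise linear leaf can be sketched with a toy 1-D density defined by knot positions and knot densities; everything below (class name, trapezoid-based normalization) is illustrative, not the repo's `PiecewiseLinear` implementation.

  ```python
  import bisect

  class PiecewiseLinearToy:
      """Toy 1-D piecewise linear density over [xs[0], xs[-1]] given by knots."""

      def __init__(self, xs, ys):
          # xs: sorted knot positions; ys: unnormalized densities at the knots.
          # Normalize via trapezoid areas so the density integrates to 1.
          areas = [(xs[i + 1] - xs[i]) * (ys[i] + ys[i + 1]) / 2
                   for i in range(len(xs) - 1)]
          total = sum(areas)
          self.xs = xs
          self.ys = [y / total for y in ys]

      def pdf(self, x):
          if x < self.xs[0] or x > self.xs[-1]:
              return 0.0
          # Locate the segment containing x and interpolate linearly.
          i = min(max(bisect.bisect_right(self.xs, x) - 1, 0), len(self.xs) - 2)
          t = (x - self.xs[i]) / (self.xs[i + 1] - self.xs[i])
          return self.ys[i] * (1 - t) + self.ys[i + 1] * t

  dist = PiecewiseLinearToy([0.0, 1.0, 2.0], [0.0, 1.0, 0.0])
  assert abs(dist.pdf(1.0) - 1.0) < 1e-9  # triangle density peaks at x=1
  ```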

- Data Input Handling:
  - Extended `data.py` to support additional datasets and updated `get_data_shape()` and `get_data_num_classes()` accordingly, so the number of classes is resolved per dataset, including a revised implementation for the 'mnist-bin' dataset.
  - Refactored the data loading process to include new preprocessing functions that standardize and normalize data before passing to the model.
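
  A typical standardization step of the kind described might look like the sketch below; the summary does not show the actual transforms in `data.py`, so the function name and the per-feature zero-mean/unit-variance choice are assumptions.

  ```python
  import numpy as np

  def standardize(x: np.ndarray, eps: float = 1e-8) -> np.ndarray:
      """Zero-mean, unit-variance scaling per feature (column)."""
      mu = x.mean(axis=0)
      sigma = x.std(axis=0)
      return (x - mu) / (sigma + eps)  # eps guards against constant features

  x = np.array([[1.0, 10.0], [3.0, 30.0]])
  z = standardize(x)
  assert np.allclose(z.mean(axis=0), 0.0)  # each feature is now centered
  ```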

- Sampling Improvements:
  - Updated the `SamplingContext` to include a `return_leaf_params` boolean flag which allows for returning parameters of the leaf distributions instead of actual samples. This provides more flexibility during sampling, particularly useful for monitoring and debugging.
  - Modified the `sample()` method across different distributions (including those in `multidistribution.py`, `normal.py`, `bernoulli.py`, etc.) to support the new `return_leaf_params` functionality.
  - Added logic for handling leaf parameters in both differentiable and non-differentiable contexts to accommodate complex sampling requirements.
  - Added detailed assertions and logging around sampling methods to ensure shape integrity and functional behavior.
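
  The `return_leaf_params` flag can be sketched as follows; the real `SamplingContext` carries many more fields, and the toy `sample_normal` below stands in for the per-distribution `sample()` methods, so treat the names as illustrative.

  ```python
  import random
  from dataclasses import dataclass

  @dataclass
  class SamplingContext:
      """Minimal sketch of a sampling context with the new flag."""
      num_samples: int = 1
      return_leaf_params: bool = False

  def sample_normal(mu, sigma, ctx: SamplingContext):
      if ctx.return_leaf_params:
          # Hand back the leaf's parameters instead of drawing samples,
          # e.g. for monitoring or debugging.
          return (mu, sigma)
      return [random.gauss(mu, sigma) for _ in range(ctx.num_samples)]

  params = sample_normal(0.0, 1.0, SamplingContext(return_leaf_params=True))
  assert params == (0.0, 1.0)
  ```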

- Testing Enhancements:
  - Expanded unit tests in `test_einet.py` to validate the new configurations for model structure and layer types. Tests verify the functionality of both existing and new distribution types, including their respective sampling shapes and behaviors under different configurations.
  - Adjusted test cases to validate against realistic scenarios that utilize the newly introduced piecewise linear distribution behavior, ensuring robust coverage for edge cases.
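
  In the spirit of such shape tests, a minimal standalone check might look like this (the helper name and the concrete sizes are hypothetical, not taken from `test_einet.py`):

  ```python
  def check_sample_shape(samples, num_samples, num_features):
      """Assert that a nested-list batch of samples has the expected shape."""
      assert len(samples) == num_samples
      assert all(len(s) == num_features for s in samples)

  samples = [[0.0] * 5 for _ in range(3)]  # 3 samples, 5 features each
  check_sample_shape(samples, 3, 5)
  ```
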
braun-steven committed Oct 22, 2024
1 parent 0166706 commit ae73c4d
Showing 26 changed files with 1,791 additions and 651 deletions.
2 changes: 1 addition & 1 deletion args.py

```diff
@@ -2,7 +2,7 @@
 import os
 import pathlib
 
-from simple_einet.data import Dist
+from simple_einet.dist import Dist
 
 
 def parse_args():
```
253 changes: 0 additions & 253 deletions benchmark/benchmark.md

This file was deleted.

1 change: 0 additions & 1 deletion exp_utils.py

```diff
@@ -29,7 +29,6 @@
 import torch
 from torch.backends import cudnn as cudnn
 from torch.nn.parallel.distributed import DistributedDataParallel
-from torch.utils.tensorboard import SummaryWriter
 from torchvision.transforms import ToTensor
```