feat: Refactor module structure, improve distribution handling, and enhance sampling capabilities

Too much to track. Here is a gpt-4o-mini summary of `git diff`:

- Module structure update:
  - Moved the `Dist` import from `simple_einet.data` to `simple_einet.dist` to improve module organization and the clarity of distribution-related functionality.
- Main script updates (`main.py`):
  - Added `DataType`, `Dist`, and `PiecewiseLinear` to the imports to accommodate the new distribution functionality.
  - Extended the training logic in `train()` to manage the caching mechanism for piecewise linear distributions: when `args.dist` is set to `Dist.PIECEWISE_LINEAR`, the caching configuration allows more efficient sampling of outputs.
  - Modified the output-generation section to use the new cache parameters `cache_leaf` and `cache_index`, enabling conditional caching based on the specified distribution type.
- Distribution enhancements:
  - Updated `args.py` to streamline distribution handling, ensuring that settings for piecewise linear distributions are configured correctly.
  - Added a `ConditioningNetwork` class to `abstract_layers.py`, providing a neural network structure for conditioning inputs through parameterized layers.
- New piecewise linear distribution:
  - Introduced a `PiecewiseLinear` class in the distributions layer, allowing a piecewise linear distribution to be specified for leaf nodes in models (affected files include `piecewise_linear.py` and the respective imports).
  - Implemented methods for handling piecewise linear distribution parameters, including initialization and sampling.
- Data input handling:
  - Extended `data.py` with additional datasets and improved `get_data_shape()` and `get_data_num_classes()` to support them; the latter now identifies the number of classes for each dataset, including a changed implementation for the `mnist-bin` dataset.
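The piecewise linear leaf summarized above can be illustrated with a toy, self-contained sketch. The class name, the trapezoid normalization, and the inverse-CDF sampling below are illustrative assumptions, not the library's actual `PiecewiseLinear` implementation:

```python
import bisect
import random


class PiecewiseLinearSketch:
    """Toy 1-D piecewise linear density defined by breakpoints (xs, ys).

    Hypothetical stand-in for the `PiecewiseLinear` leaf described in the
    commit; the real class is parameterized per leaf node.
    """

    def __init__(self, xs, ys):
        assert len(xs) == len(ys) >= 2
        # Normalize so the trapezoids between breakpoints integrate to 1.
        area = sum((xs[i + 1] - xs[i]) * (ys[i] + ys[i + 1]) / 2
                   for i in range(len(xs) - 1))
        self.xs = list(xs)
        self.ys = [y / area for y in ys]
        # Cumulative segment areas, used for inverse-CDF sampling.
        self.cum = [0.0]
        for i in range(len(xs) - 1):
            seg = (xs[i + 1] - xs[i]) * (self.ys[i] + self.ys[i + 1]) / 2
            self.cum.append(self.cum[-1] + seg)

    def pdf(self, x):
        if x < self.xs[0] or x > self.xs[-1]:
            return 0.0
        i = min(max(bisect.bisect_right(self.xs, x) - 1, 0), len(self.xs) - 2)
        t = (x - self.xs[i]) / (self.xs[i + 1] - self.xs[i])
        return self.ys[i] + t * (self.ys[i + 1] - self.ys[i])

    def sample(self, rng=random):
        u = rng.random()
        i = min(max(bisect.bisect_right(self.cum, u) - 1, 0), len(self.xs) - 2)
        # Solve for t in [0, 1] such that the partial trapezoid area in
        # segment i equals u - cum[i]:  (y1-y0)/2 * t^2 + y0 * t = rem/dx.
        rem = u - self.cum[i]
        dx = self.xs[i + 1] - self.xs[i]
        y0, y1 = self.ys[i], self.ys[i + 1]
        a, b, c = (y1 - y0) / 2, y0, -rem / dx
        if abs(a) < 1e-12:
            t = -c / b  # density is flat on this segment
        else:
            t = (-b + (b * b - 4 * a * c) ** 0.5) / (2 * a)
        return self.xs[i] + t * dx
```

For instance, `PiecewiseLinearSketch([0, 1, 2], [0, 1, 0])` is a triangular density peaking at 1, and `sample()` always returns values inside `[0, 2]`.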
- Refactored the data loading process to include new preprocessing functions that standardize and normalize data before it is passed to the model.
- Sampling improvements:
  - Updated `SamplingContext` with a `return_leaf_params` boolean flag that returns the parameters of the leaf distributions instead of actual samples. This adds flexibility during sampling and is particularly useful for monitoring and debugging.
  - Modified the `sample()` method across the distributions (including those in `multidistribution.py`, `normal.py`, `bernoulli.py`, etc.) to support the new `return_leaf_params` functionality.
  - Added logic for handling leaf parameters in both differentiable and non-differentiable contexts to accommodate complex sampling requirements.
  - Added detailed assertions and logging around the sampling methods to ensure shape integrity and correct behavior.
- Testing enhancements:
  - Expanded the unit tests in `test_einet.py` to validate the new configurations for model structure and layer types. The tests verify both existing and new distribution types, including their sampling shapes and behavior under different configurations.
  - Adjusted test cases to cover realistic scenarios that exercise the newly introduced piecewise linear distribution, ensuring robust coverage of edge cases.
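The `return_leaf_params` idea can be sketched minimally as below. The class names and the single-Normal-leaf shape are hypothetical simplifications, not the library's actual `SamplingContext` or leaf API:

```python
import random
from dataclasses import dataclass


@dataclass
class SamplingContextSketch:
    """Simplified stand-in for a sampling context; only the new flag
    from the commit summary is modeled here."""
    num_samples: int = 1
    return_leaf_params: bool = False  # return (mu, sigma) instead of draws


class NormalLeafSketch:
    """Toy Normal leaf illustrating the two sampling modes."""

    def __init__(self, mu: float, sigma: float):
        self.mu, self.sigma = mu, sigma

    def sample(self, ctx: SamplingContextSketch):
        if ctx.return_leaf_params:
            # Hand back the distribution parameters themselves,
            # e.g. for monitoring or debugging the leaf layer.
            return [(self.mu, self.sigma)] * ctx.num_samples
        # Default mode: draw actual samples from the leaf distribution.
        return [random.gauss(self.mu, self.sigma)
                for _ in range(ctx.num_samples)]
```

Flipping the flag swaps concrete draws for the `(mu, sigma)` pairs that would have generated them, which is the flexibility the summary describes.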