-
Notifications
You must be signed in to change notification settings - Fork 19
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Basic rewrite of the package 2023 edition #45
Closed
Closed
Changes from 100 commits
Commits
Show all changes
147 commits
Select commit
Hold shift + click to select a range
b49cf3e
refactor ADVI, change gradient operation interface
Red-Portal 88e0b79
remove unused file, remove unused dependency
Red-Portal c2fb3f8
fix ADVI elbo computation more efficiently
Red-Portal 83161fd
fix missing entropy regularization term
Red-Portal efa8106
add LogDensityProblem interface
Red-Portal 4ae2fbf
refactor use bijectors directly instead of transformed distributions
Red-Portal 2bf2a42
Merge branch 'master' of https://github.com/TuringLang/AdvancedVI.jl …
Red-Portal 1cadb51
fix type restrictions
Red-Portal 3474e8d
remove unused file
Red-Portal 03a2767
fix use of with_logabsdet_jacobian
Red-Portal 09c44fb
restructure project; move the main VI routine to its own file
Red-Portal b7407ce
remove redundant import
Red-Portal 4040149
restructure project into more modular objective estimators
Red-Portal 2a4514e
migrate to AbstractDifferentiation
Red-Portal 93a16d8
add location scale pre-packaged variational family, add functors
Red-Portal 2b6e9eb
Revert "migrate to AbstractDifferentiation"
Red-Portal 1bfec36
fix use optimized MvNormal specialization, add logpdf for Loc.Scale.
Red-Portal 1003606
remove dead code
Red-Portal 60a9987
fix location-scale logpdf
Red-Portal cd84f02
add sticking-the-landing (STL) estimator
Red-Portal 768641b
migrate to Optimisers.jl
Red-Portal ca02fa3
remove execution time measurement (replace later with somethin else)
Red-Portal a48377f
fix use multiple dispatch for deciding whether to stop entropy grad.
Red-Portal 0b40ccf
add termination decision, callback arguments
Red-Portal 21db3fb
add Base.show to modules
Red-Portal 25c51b4
add interface calling `restructure`, rename rebuild -> restructure
Red-Portal fc20046
add estimator state interface, add control variate interface to ADVI
Red-Portal 6faa807
fix `show(advi)` to show control variate
Red-Portal 7095d27
fix simplify `show(advi.control_variate)`
Red-Portal 9169ae2
fix type piracy by wrapping location-scale bijected distribution
Red-Portal 3db7301
remove old AdvancedVI custom optimizers
Red-Portal e6a082a
fix Location Scale to not depend on Bijectors
Red-Portal a034ebd
fix RNG namespace
Red-Portal e19abd3
fix location scale logpdf bug
Red-Portal 680c186
add Accessors dependency
Red-Portal 6c3efa8
Merge branch 'master' of https://github.com/TuringLang/AdvancedVI.jl …
Red-Portal 4c6cabf
add location scale, autodiff tests
Red-Portal 06db2f0
add Accessors import statement
Red-Portal 12de2bd
remove optimiser tests
Red-Portal bbb2cc6
refactor slightly generalize the distribution tests for the future
Red-Portal 1974846
migrate to SimpleUnPack, migrate to ADTypes
Red-Portal 19c62c8
rename vi.jl to optimize.jl
Red-Portal 63da51d
fix estimate_gradient to use adtypes
Red-Portal 65ab473
add exact inference tests
Red-Portal 3e5a452
remove Turing dependency in tests
Red-Portal 3117cec
remove unused projection
Red-Portal b1ca9cf
remove redundant `ADVIEnergy` object (now baked into `ADVI`)
Red-Portal fcbb729
add more tests, fix rng seed for tests
Red-Portal 0f6f6a4
add more tests, fix seed for tests
Red-Portal f5f5863
fix non-determinism bug
Red-Portal ade0d10
fix test hyperparameters so that tests pass, minor cleanups
Red-Portal 0caf7a9
fix minor reorganization
Red-Portal 5658cbf
add missing files
Red-Portal c712a97
fix add missing file, rename adbackend argument
Red-Portal bee839d
fix errors
Red-Portal 913911e
rename test suite
Red-Portal d50cabb
refactor renamed arguments for ADVI to be shorter
Red-Portal b134f70
fix compile error in advi test
Red-Portal a6ba379
add initial doc
Red-Portal 619b1c0
remove unused epsilon argument in location scale
Red-Portal f1c02f0
add project file for documenter
Red-Portal b0f259a
refactor STL gradient calculation to use multiple dispatch
Red-Portal b72c258
fix type bugs, relax test threshold for the exact inference tests
Red-Portal a8df9eb
refactor derivative utils to match NormalizingFlows.jl with extras
Red-Portal e8db6a7
add documentation, refactor optimize
Red-Portal 65a2b37
fix bug missing extension
Red-Portal 1a02051
remove tracker from tests
Red-Portal d8b5ea5
remove export for internal derivative utils
Red-Portal 818bc2c
fix test errors, old interface
Red-Portal 215abf3
fix wrong derivative interface, add documentation
Red-Portal 88ad768
update documentation
Red-Portal e66935b
add doc build CI
Red-Portal 9f1c647
remove convergence criterion for now
Red-Portal c8b3ee3
remove outdated export
Red-Portal afda1a1
update documentation
Red-Portal 0d37ace
update documentation
Red-Portal b8b113d
update documentation
Red-Portal b78e713
fix type error in test
Red-Portal a0564b5
remove default ADType argument
Red-Portal 3795d1e
update README
Red-Portal 28a35bc
update make getting started example actually run Julia
Red-Portal 620b38e
fix remove Float32 tests for inference tests
Red-Portal fa53398
update version
Red-Portal e909f41
add documentation publishing url
Red-Portal 43f5b75
fix wrong uuid for ForwardDiff
Red-Portal 468d5ca
Update CI.yml
yebai c07a511
refactor use `sum` and `mean` instead of abusing `mapreduce`
Red-Portal 8256df1
Merge branch 'rewriting_advancedvi' of github.com:Red-Portal/Advanced…
Red-Portal 13a8a44
remove tests for `FullMonteCarlo`
Red-Portal aadf8d3
add tests for the `optimize` interface
Red-Portal 8c4e13d
fix turn off Zygote tests for now
Red-Portal 0b708e6
remove unused function
Red-Portal be61acd
refactor change bijector field name, simplify STL estimator
Red-Portal fb519a5
update documentation
Red-Portal 8682fd9
update STL documentation
Red-Portal 9a16ee1
update STL documentation
Red-Portal fc74afa
update location scale documentation
Red-Portal 4be30a1
fix README
Red-Portal c58309d
fix math in README
Red-Portal 5b5bd3e
add gradient to arguments of callback!, remove `gradient_norm` info
Red-Portal 967021d
fix math in README.md
Red-Portal 4dab522
fix type constraint in `ZygoteExt`
Red-Portal 8ab2f19
fix import of `Random`
Red-Portal 83dec9f
refactor `__init__()`
Red-Portal a3e563c
fix type constraint in definition of `value_and_gradient!`
Red-Portal 5553bb9
refactor `ZygoteExt`; use `only` instead of `first`
Red-Portal 79b4557
refactor type constraint in `ReverseDiffExt`
Red-Portal 656b44b
refactor remove outdated debug mode macro
Red-Portal c794063
fix remove outdated DEBUG mechanism
Red-Portal 0c5cc1c
fix LaTeX in README: `operatorname` is currently broken
Red-Portal 29d7d27
remove `SimpleUnPack` dependency
Red-Portal 75eef44
fix LaTeX in docs and README
Red-Portal 40574f4
add warning about forward-mode AD when using `LocationScale`
Red-Portal 8738256
fix documentation
Red-Portal 8173744
fix remove reamining use of `@unpack`
Red-Portal e0548ae
Revert "remove `SimpleUnPack` dependency"
Red-Portal 6ab95a0
Revert "fix remove reamining use of `@unpack`"
Red-Portal f0ec242
fix documentation for `optimize`
Red-Portal 1d4c1b6
add specializations of `Optimise.destructure` for mean-field
Red-Portal 231835f
add test for `Optimisers.destructure` specializations
Red-Portal ea2d426
add specialization of `rand` for meanfield resulting in faster AD
Red-Portal 3033d75
add argument checks for `VIMeanFieldGaussian`, `VIFullRankGaussian`
Red-Portal 0cc36c0
update documentation
Red-Portal b7d3471
fix type instability, bug in argument check in `LocationScale`
Red-Portal df50e83
add missing import bug
Red-Portal ae3e9b0
refactor test, fix type bug in tests for `LocationScale`
Red-Portal e4002cf
add missing compat entries
Red-Portal 8c82569
fix missing package import in test
Red-Portal c2e7517
add additional tests for sampling `LocationScale`
Red-Portal 3a6f8bf
fix bug in batch in-place `rand!` for `LocationScale`
Red-Portal b78ef4b
fix bug in inference test initialization
Red-Portal a1f7e98
add missing file
Red-Portal 8b783ec
fix remove use of for 1.6
Red-Portal 12cd9f2
refactor adjust inference test hyperparameters to be more robust
Red-Portal 837c729
refactor `optimize` to return `obj_state`, add warm start kwargs
Red-Portal 95629a5
refactor make tests more robust, reduce amount of tests
Red-Portal 0b4b865
fix remove a cholesky in test model
Red-Portal b49f4eb
fix compat bounds, remove unused package
Red-Portal 947a070
bump compat for ADTypes 0.2
Red-Portal a9b3f48
fix broken LaTeX in README
Red-Portal 54826eb
remove redundant use of PDMats in docs
Red-Portal 1d1c8ff
fix use `Cholesky` signature supported in 1.6
Red-Portal a0de2cf
fix remove redundant cholesky operation in test
Red-Portal f593a67
add `mean`, `var`, `cov` to `LocationScale`
Red-Portal ff32ac6
refactor `optimize` warm-starting interface, add `objargs` argument
Red-Portal bc5cfd3
update documentation for `optimize`
Red-Portal de4284e
fix CUDA-compatibility bugs
Red-Portal File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,37 +1,64 @@ | ||
name = "AdvancedVI" | ||
uuid = "b5ca4192-6429-45e5-a2d9-87aec30a685c" | ||
version = "0.2.4" | ||
version = "0.3.0" | ||
|
||
[deps] | ||
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" | ||
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" | ||
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" | ||
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" | ||
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" | ||
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" | ||
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" | ||
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" | ||
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" | ||
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" | ||
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" | ||
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" | ||
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" | ||
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" | ||
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" | ||
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" | ||
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" | ||
Requires = "ae029012-a4dd-5104-9daa-d747884805df" | ||
SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a" | ||
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" | ||
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" | ||
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" | ||
|
||
[weakdeps] | ||
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" | ||
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" | ||
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" | ||
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" | ||
|
||
[extensions] | ||
AdvancedVIEnzymeExt = "Enzyme" | ||
AdvancedVIForwardDiffExt = "ForwardDiff" | ||
AdvancedVIReverseDiffExt = "ReverseDiff" | ||
AdvancedVIZygoteExt = "Zygote" | ||
|
||
[compat] | ||
ADTypes = "0.1" | ||
Bijectors = "0.11, 0.12, 0.13" | ||
ChainRules = "1.53.0" | ||
DiffResults = "1" | ||
Distributions = "0.21, 0.22, 0.23, 0.24, 0.25" | ||
DistributionsAD = "0.2, 0.3, 0.4, 0.5, 0.6" | ||
DocStringExtensions = "0.8, 0.9" | ||
ForwardDiff = "0.10.3" | ||
ForwardDiff = "0.10.25" | ||
LogDensityProblems = "2.1.1" | ||
Optimisers = "0.2.16" | ||
ProgressMeter = "1.0.0" | ||
Requires = "0.5, 1.0" | ||
ReverseDiff = "1.14" | ||
StatsBase = "0.32, 0.33, 0.34" | ||
StatsFuns = "0.8, 0.9, 1" | ||
Tracker = "0.2.3" | ||
julia = "1.6" | ||
|
||
[extras] | ||
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" | ||
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" | ||
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" | ||
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" | ||
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" | ||
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" | ||
|
||
[targets] | ||
test = ["Pkg", "Test"] |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It seems a compat entry is missing? Can you make sure that every dependency (regular + weak) has a compat entry?