
Fix gamma mixture #222

Status: Closed · wants to merge 9 commits
Merge master into fix_gamma_mixture
  • Loading branch information
albertpod committed Feb 13, 2023
commit 15148ff87aa3de9d46f0972612cf03b46b4adf58
10 changes: 8 additions & 2 deletions .JuliaFormatter.toml
@@ -1,11 +1,15 @@
style = "blue"
indent = 4
margin = 120
margin = 180
always_for_in = true
whitespace_typedefs = true
whitespace_ops_in_indices = true
remove_extra_newlines = true
import_to_using = false
pipe_to_function_call = false
short_to_long_function_def = false
long_to_short_function_def = false
always_use_return = false
whitespace_in_kwargs = true
annotate_untyped_fields_with_any = false
format_docstrings = false
@@ -17,4 +17,6 @@ align_struct_field = true
align_conditional = true
align_pair_arrow = true
align_matrix = false
join_lines_based_on_source = true
join_lines_based_on_source = false
separate_kwargs_with_semicolon = false
surround_whereop_typeparameters = true
10 changes: 5 additions & 5 deletions Makefile
@@ -3,13 +3,13 @@ SHELL = /bin/bash

.PHONY: lint format

lint_init:
julia --project=scripts/ -e 'using Pkg; Pkg.instantiate();'
scripts_init:
julia --project=scripts/ -e 'using Pkg; Pkg.instantiate(); Pkg.update(); Pkg.precompile();'

lint: lint_init ## Code formating check
lint: scripts_init ## Code formating check
julia --project=scripts/ scripts/format.jl

format: lint_init ## Code formating run
format: scripts_init ## Code formating run
julia --project=scripts/ scripts/format.jl --overwrite

.PHONY: benchmark
@@ -23,7 +23,7 @@ benchmark: benchmark_init ## Runs simple benchmark
.PHONY: docs

doc_init:
julia --project=docs -e 'ENV["PYTHON"]=""; using Pkg; Pkg.develop(PackageSpec(path=pwd())); Pkg.instantiate(); Pkg.build("PyPlot"); using PyPlot;'
julia --project=docs -e 'ENV["PYTHON"]=""; using Pkg; Pkg.develop(PackageSpec(path=pwd())); Pkg.instantiate();'

docs: doc_init ## Generate documentation
julia --project=docs/ docs/make.jl
23 changes: 13 additions & 10 deletions Project.toml
@@ -1,7 +1,7 @@
name = "ReactiveMP"
uuid = "a194aa59-28ba-4574-a09c-4a745416d6e3"
authors = ["Dmitry Bagaev <d.v.bagaev@tue.nl>", "Albert Podusenko <a.podusenko@tue.nl>", "Bart van Erp <b.v.erp@tue.nl>", "Ismail Senoz <i.senoz@tue.nl>"]
version = "2.4.1"
version = "3.5.0"

[deps]
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
@@ -17,35 +17,37 @@ LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
PositiveFactorizations = "85a6dd25-e78a-55b7-8502-1745935b8125"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
Rocket = "df971d30-c9d6-4b37-b8ff-e965b2cb3a40"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
TinyHugeNumbers = "783c9a47-75a3-44ac-a16b-f1ab7b3acf04"
TupleTools = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6"
Unrolled = "9602ed7d-8fef-5bc8-8597-8f21381861e8"

[compat]
DataStructures = "0.17, 0.18"
Distributions = "0.24, 0.25"
DomainIntegrals = "0.3.2"
DomainSets = "0.5.2"
FastGaussQuadrature = "0.4"
DomainIntegrals = "0.3.2, 0.4"
DomainSets = "0.5.2, 0.6"
FastGaussQuadrature = "0.4, 0.5"
ForwardDiff = "0.10"
HCubature = "1.0.0"
LazyArrays = "0.21, 0.22"
LoopVectorization = "0.12"
MacroTools = "0.5"
Optim = "1.0.0"
PositiveFactorizations = "0.2"
ProgressMeter = "1.0.0"
Rocket = "1.4.0"
Requires = "1"
Rocket = "1.6.0"
SpecialFunctions = "1.4, 2"
StaticArrays = "1.2"
StatsBase = "0.33"
StatsFuns = "0.9, 1"
TinyHugeNumbers = "1.0.0"
TupleTools = "1.2.0"
Unrolled = "0.1.3"
julia = "1.6.0"
@@ -56,17 +58,18 @@ BenchmarkCI = "20533458-34a3-403d-a444-e18f38190b5b"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
Coverage = "a2441757-f6aa-5fb2-8edb-039e3f45d037"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
GraphPPL = "b3f8163a-e979-4e85-b43e-1f63d8c8b42c"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
PkgBenchmark = "32113eaa-f34f-5b0d-bd6c-c81e245fc73d"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TestSetExtensions = "98d24dd4-01ad-11ea-1b02-c9a08f80db04"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Test", "Pkg", "Logging", "InteractiveUtils", "TestSetExtensions", "Coverage", "Dates", "Distributed", "Documenter", "GraphPPL", "Plots", "BenchmarkCI", "BenchmarkTools", "PkgBenchmark", "Aqua", "StableRNGs"]
test = ["Test", "Pkg", "Logging", "InteractiveUtils", "TestSetExtensions", "Coverage", "Dates", "Distributed", "Documenter", "BenchmarkCI", "BenchmarkTools", "PkgBenchmark", "Aqua", "StableRNGs", "Flux", "Zygote", "DiffResults"]
185 changes: 5 additions & 180 deletions README.md
@@ -22,189 +22,14 @@
[pkgeval-img]: https://juliaci.github.io/NanosoldierReports/pkgeval_badges/R/ReactiveMP.svg
[pkgeval-url]: https://juliaci.github.io/NanosoldierReports/pkgeval_badges/R/ReactiveMP.html

ReactiveMP.jl is a Julia package for automatic Bayesian inference on a factor graph with reactive message passing.
# Reactive message passing engine

Given a probabilistic model, ReactiveMP allows for efficient message-passing-based Bayesian inference. It uses the model structure to generate an algorithm that consists of a sequence of local computations on a Forney-style factor graph (FFG) representation of the model.
ReactiveMP.jl is a Julia package that provides an efficient reactive-message-passing-based Bayesian inference engine on a factor graph. The package is part of a bigger, user-friendly ecosystem for automatic Bayesian inference called [RxInfer](https://github.com/biaslab/RxInfer.jl). While ReactiveMP.jl exports only the inference engine, RxInfer provides convenient tools for specifying models and inference constraints, as well as routines for running efficient inference on both static and real-time datasets.

The current version supports belief propagation (sum-product message passing) and variational message passing (both mean-field and structured VMP) and is aimed at running inference in conjugate state-space models.
# Examples

ReactiveMP.jl has been designed with a focus on efficiency, scalability and maximum performance for running inference on conjugate state-space models with message passing. Below is a benchmark comparison between ReactiveMP.jl and [Turing.jl](https://github.com/TuringLang/Turing.jl) on a linear multivariate Gaussian state-space model. It is worth noting that this model contains many conjugate prior and likelihood pairings that lead to analytically computable Bayesian posteriors. For these types of models, ReactiveMP.jl takes advantage of the conjugate pairings and easily beats general-purpose probabilistic programming packages in terms of computational load, speed, memory and accuracy. On the other hand, sampling-based packages like [Turing.jl](https://github.com/TuringLang/Turing.jl) are generic Bayesian inference solutions and are capable of running inference for a broader set of models.

The code is available in the [benchmark folder](https://github.com/biaslab/ReactiveMP.jl/tree/master/benchmark):

Turing comparison | Scalability performance
:-------------------------:|:-------------------------:
![](benchmark/notebooks/plots/lgssm_comparison.svg?raw=true&sanitize=true) | ![](benchmark/notebooks/plots/lgssm_scaling.svg?raw=true&sanitize=true)

# Overview

See the videos below from JuliaCon 2021 and BIASlab seminar for a quick introduction to ReactiveMP.

JuliaCon 2021 presentation | ReactiveMP.jl API tutorial
:-------------------------:|:-------------------------:
[![JuliaCon 2021 ReactiveMP.jl presentation](https://img.youtube.com/vi/twhTsKsXa_8/0.jpg)](https://www.youtube.com/watch?v=twhTsKsXa_8) | [![ReactiveMP.jl API tutorial](https://img.youtube.com/vi/YwjddthBKnM/0.jpg)](https://www.youtube.com/watch?v=YwjddthBKnM)



# Installation

Install ReactiveMP through the Julia package manager:

```
] add ReactiveMP
```

Optionally, use `] test ReactiveMP` to validate the installation by running the test suite.

# Getting Started

There are demos available to get you started in the `demo/` folder. Comparative benchmarks are available in the `benchmarks/` folder.

### Coin flip simulation

Here we show a simple example of how to use ReactiveMP.jl for Bayesian inference problems. In this example we want to estimate the bias of a coin, in the form of a probability distribution, in a coin flip simulation.

Let's start by creating a dataset. For simplicity, in this example we will use a static, pre-generated dataset. Each sample can be thought of as the outcome of a single flip, which is either heads or tails (1 or 0). We will assume that our virtual coin is biased and lands heads up on 75% of the trials (on average).

First, let's set up our environment by importing all the needed packages:

```julia
using Rocket, GraphPPL, ReactiveMP, Distributions, Random
```

Next, let's define our dataset:

```julia
n = 500 # Number of coin flips
p = 0.75 # Bias of a coin

distribution = Bernoulli(p)
dataset      = float.(rand(distribution, n))
```

### Model specification

In a Bayesian setting, the next step is to specify our probabilistic model. This amounts to specifying the joint probability of the random variables of the system.

#### Likelihood
We will assume that the outcome of each coin flip is governed by the Bernoulli distribution, i.e.

<p align="center">
<img src="https://render.githubusercontent.com/render/math?math=y_i%20\sim%20\mathrm{Bernoulli}(\theta)">
</p>

where <img src="https://render.githubusercontent.com/render/math?math=y_i%20=%201"> represents "heads", <img src="https://render.githubusercontent.com/render/math?math=y_i%20=%200"> represents "tails". The underlying probability of the coin landing heads up for a single coin flip is <img src="https://render.githubusercontent.com/render/math?math=\theta%20\in%20[0,1]">.
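Written out, this likelihood is just the Bernoulli probability mass function. A one-line plain-Julia sketch, shown only to make the formula concrete (the `bernoulli_pmf` helper is hypothetical and not part of any package):

```julia
# Bernoulli pmf: P(y | θ) = θ^y * (1 - θ)^(1 - y) for y ∈ {0, 1}
bernoulli_pmf(y, θ) = θ^y * (1 - θ)^(1 - y)

bernoulli_pmf(1, 0.75)  # probability of "heads" given θ = 0.75
bernoulli_pmf(0, 0.75)  # probability of "tails" given θ = 0.75
```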

#### Prior
We will choose the conjugate prior of the Bernoulli likelihood function defined above, namely the beta distribution, i.e.

<p align="center">
<img src="https://render.githubusercontent.com/render/math?math=\theta%20\sim%20Beta(a,%20b)">
</p>

where ``a`` and ``b`` are the hyperparameters that encode our prior beliefs about the possible values of ``θ``. We will assign values to the hyperparameters in a later step.

#### Joint probability
The joint probability is given by the multiplication of the likelihood and the prior, i.e.

<p align="center">
<img src="https://render.githubusercontent.com/render/math?math=P(y_{1:N},%20\theta)%20=%20P(\theta)%20\prod_{i=1}^N%20P(y_i%20|%20\theta).">
</p>
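Because the beta prior is conjugate to the Bernoulli likelihood, this joint admits a closed-form posterior: observing `h` heads and `t` tails turns a `Beta(a, b)` prior into a `Beta(a + h, b + t)` posterior. A minimal plain-Julia sketch of this update, independent of ReactiveMP and shown only to illustrate the math (the `beta_bernoulli_posterior` helper is hypothetical):

```julia
# Conjugate Beta-Bernoulli update: prior Beta(a, b), observations y ∈ {0, 1}.
# The posterior is Beta(a + sum(y), b + length(y) - sum(y)).
function beta_bernoulli_posterior(a, b, y)
    heads = sum(y)
    tails = length(y) - heads
    return (a + heads, b + tails)
end

# Example with the Beta(2, 7) prior used below and a toy dataset
# of 4 heads out of 6 flips: the posterior is Beta(6, 9)
a_post, b_post = beta_bernoulli_posterior(2.0, 7.0, [1, 1, 0, 1, 1, 0])
posterior_mean = a_post / (a_post + b_post)  # mean of Beta(a, b) is a / (a + b)
```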

Now let's see how to specify this model using GraphPPL's package syntax.

```julia

# GraphPPL.jl exports the `@model` macro for model specification
# It accepts a regular Julia function and builds an FFG under the hood
@model function coin_model(n)

    # `datavar` creates data 'inputs' in our model
    # We will pass data to these inputs later on
    # In this example we create a sequence of inputs that accept Float64
    y = datavar(Float64, n)

    # We endow the θ parameter of our model with a prior
    θ ~ Beta(2.0, 7.0)

    # We assume that the outcome of each coin flip
    # is governed by the Bernoulli distribution
    for i in 1:n
        y[i] ~ Bernoulli(θ)
    end

    # We return references to our data inputs and the θ parameter
    # We will use these references later on, during the inference step
    return y, θ
end

```

As you can see, `GraphPPL` offers a model specification syntax that closely resembles the mathematical equations defined above. We use the `datavar` function to create "clamped" variables that will take specific values later. The `θ ~ Beta(2.0, 7.0)` expression creates the random variable `θ` and assigns it as an output of a `Beta` node in the corresponding FFG.

### Inference specification

Once we have defined our model, the next step is to use the `ReactiveMP` API to infer the quantities of interest. To do this, we can use the generic `inference` function from `ReactiveMP.jl`, which supports static datasets.

```julia
result = inference(
    model = Model(coin_model, length(dataset)),
    data  = (y = dataset,)
)
```

For advanced use cases, an inference procedure can also be specified manually. The `ReactiveMP` API is flexible in terms of inference specification and is compatible with both real-time inference processing and static datasets. In most cases for static datasets, as in our example, it consists of the same basic building blocks:

1. Return the variables of interest from the model specification
2. Subscribe to the posterior marginal updates of the variables of interest
3. Pass data to the model
4. Unsubscribe

Here is an example of such an inference procedure:

```julia
function custom_inference(data)
    n = length(data)

    # The `coin_model` function from the `@model` macro returns a reference to
    # the model object and the same output as in the `return` statement
    # of the original function specification
    model, (y, θ) = coin_model(n)

    # Reference for the future posterior marginal
    mθ = nothing

    # The `getmarginal` function returns an observable of
    # future posterior marginal updates
    # We use the `Rocket.jl` API to subscribe to this observable
    # As soon as a posterior marginal update is available, we save it in `mθ`
    subscription = subscribe!(getmarginal(θ), (m) -> mθ = m)

    # The `update!` function passes data to our data inputs
    update!(y, data)

    # It is always good practice to unsubscribe and to
    # free the resources held by the subscription
    unsubscribe!(subscription)

    # Return the resulting posterior marginal
    return mθ
end
```

### Inference execution

Now that everything is ready, we call our `custom_inference` function to get the posterior marginal distribution over the `θ` parameter of the model.

```julia
θestimated = custom_inference(dataset)
```
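Since the model is fully conjugate, the estimate can be cross-checked against the closed-form posterior `Beta(2 + heads, 7 + tails)`. A plain-Julia sanity check under that assumption (it simulates its own dataset with `rand` instead of `Distributions.jl` and is not part of the ReactiveMP API):

```julia
# Closed-form cross-check: with a Beta(2, 7) prior, the exact posterior after
# observing the flips is Beta(2 + heads, 7 + tails). For n = 500 flips its
# mean should land close to the true bias p = 0.75.
dataset = Float64.(rand(500) .< 0.75)   # simulate 500 biased coin flips
heads   = sum(dataset)
a, b    = 2.0 + heads, 7.0 + (length(dataset) - heads)
exact_posterior_mean = a / (a + b)
```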

![Coin Flip](docs/src/assets/img/coin-flip.svg?raw=true&sanitize=true "ReactiveMP.jl Benchmark")

# Where to go next?
A set of [demos](https://github.com/biaslab/ReactiveMP.jl/tree/master/demo) is available in the `ReactiveMP` repository, demonstrating the more advanced features of the package. Alternatively, you can head to the [documentation][docs-stable-url], which provides more detailed information on how to use `ReactiveMP` and `GraphPPL` to specify probabilistic models.
Tutorials and examples are available in the [RxInfer documentation](https://biaslab.github.io/RxInfer.jl/stable/).

# License

MIT License Copyright (c) 2021-2022 BIASlab
MIT License Copyright (c) 2021-2023 BIASlab