Skip to content

Commit

Permalink
build based on 8cc46c8
Browse files Browse the repository at this point in the history
  • Loading branch information
Documenter.jl committed Sep 9, 2024
1 parent 19961af commit 72b0354
Show file tree
Hide file tree
Showing 50 changed files with 13,259 additions and 3 deletions.
2 changes: 1 addition & 1 deletion stable
2 changes: 1 addition & 1 deletion v0.1
1 change: 1 addition & 0 deletions v0.1.7/.documenter-siteinfo.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"documenter":{"julia_version":"1.10.5","generation_timestamp":"2024-09-09T13:59:40","documenter_version":"1.7.0"}}
21 changes: 21 additions & 0 deletions v0.1.7/LICENSE
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2023 Patrick Altmeyer

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
128 changes: 128 additions & 0 deletions v0.1.7/_intro.qmd
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
*Joint Energy Models in Julia.*

[![Stable](https://img.shields.io/badge/docs-stable-blue.svg)](https://juliatrustworthyai.github.io/JointEnergyModels.jl/stable)
[![Dev](https://img.shields.io/badge/docs-dev-blue.svg)](https://juliatrustworthyai.github.io/JointEnergyModels.jl/dev)
[![Build Status](https://github.com/juliatrustworthyai/JointEnergyModels.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/juliatrustworthyai/JointEnergyModels.jl/actions/workflows/CI.yml?query=branch%3Amain)
[![Coverage](https://codecov.io/gh/juliatrustworthyai/JointEnergyModels.jl/branch/main/graph/badge.svg)](https://codecov.io/gh/juliatrustworthyai/JointEnergyModels.jl)
[![Code Style: Blue](https://img.shields.io/badge/code%20style-blue-4495d1.svg)](https://github.com/invenia/BlueStyle)
[![License](https://img.shields.io/github/license/juliatrustworthyai/JointEnergyModels.jl)](LICENSE)
[![Package Downloads](https://img.shields.io/badge/dynamic/json?url=http%3A%2F%2Fjuliapkgstats.com%2Fapi%2Fv1%2Fmonthly_downloads%2FJointEnergyModels&query=total_requests&suffix=%2Fmonth&label=Downloads)](http://juliapkgstats.com/pkg/JointEnergyModels)
[![Aqua QA](https://raw.githubusercontent.com/JuliaTesting/Aqua.jl/master/badge.svg)](https://github.com/JuliaTesting/Aqua.jl)

```{julia}
#| echo: false
include("$(pwd())/docs/setup_docs.jl")
eval(setup_docs)
```

`JointEnergyModels.jl` is a package for training Joint Energy Models in Julia. Joint Energy Models (JEM) are hybrid models that learn to discriminate between classes $y$ and generate input data $x$. They were introduced in @grathwohl2020your, which provides the foundation for the methodologies implemented in this package.

## 🔁 Status

This package is still in its infancy and the API is subject to change. Currently, the package can be used to train JEMs for classification. It is also possible to train pure Energy-Based Models (EBMs) for the generative task only. The package is compatible with `Flux.jl`. Work on compatibility with `MLJ.jl` (through `MLJFlux.jl`) is currently under way.

We welcome contributions and feedback at this early stage. To install the development version of the package you can run the following command:

```{.julia}
using Pkg
Pkg.add(url="https://github.com/juliatrustworthyai/JointEnergyModels.jl")
```

## 🔍 Usage Example

```{=commonmark}
!!! warning "Breaking Changes Anticipated"
To facilitate the interface to MLJFlux, this package currently overloads private methods. We are still deliberating
```

Below we first generate some synthetic data:

```{julia}
#| output: true
nobs=2000
X, y = make_circles(nobs, noise=0.1, factor=0.5)
Xplot = Float32.(permutedims(matrix(X)))
X = table(permutedims(Xplot))
plt = scatter(Xplot[1,:], Xplot[2,:], group=y, label="")
batch_size = Int(round(nobs/10))
```

The `MLJFlux` compatible classifier can be instantiated as follows:

```{julia}
𝒟x = Normal()
𝒟y = Categorical(ones(2) ./ 2)
sampler = ConditionalSampler(𝒟x, 𝒟y, input_size=size(Xplot)[1:end-1], batch_size=batch_size)
clf = JointEnergyClassifier(
sampler;
builder=MLJFlux.MLP(hidden=(32, 32, 32,), σ=Flux.relu),
batch_size=batch_size,
finaliser=x -> x,
loss=Flux.Losses.logitcrossentropy,
)
```

It uses the `MLJFlux` package to build the model:

```{julia}
#| output: true
println(typeof(clf) <: MLJFlux.MLJFluxModel)
```

The model can be wrapped in data and trained using the `fit!` function:

```{julia}
mach = machine(clf, X, y)
fit!(mach, verbosity=1)
```

The results are visualised below. The model has learned to discriminate between the two classes (as indicated by the contours) and to generate samples from each class (as indicated by the stars).

```{julia}
#| echo: false
jem = mach.model.jem
batch_size = mach.model.batch_size
X = Float32.(permutedims(matrix(X)))
y_labels = Int.(y.refs)
y = Flux.onehotbatch(y.refs, sort(unique(y_labels)))
```

```{julia}
#| output: true
#| echo: false
if typeof(jem.sampler) <: ConditionalSampler
plts = []
for target in 1:size(y,1)
X̂ = generate_conditional_samples(jem, batch_size, target; niter=1000)
ex = extrema(hcat(X,X̂), dims=2)
xlims = ex[1]
ylims = ex[2]
x1 = range(1.0f0.*xlims...,length=100)
x2 = range(1.0f0.*ylims...,length=100)
plt = contour(
x1, x2, (x, y) -> softmax(jem([x, y]))[target],
fill=true, alpha=0.5, title="Target: $target", cbar=true,
xlims=xlims,
ylims=ylims,
)
scatter!(X[1,:], X[2,:], color=vec(y_labels), group=vec(y_labels), alpha=0.5)
scatter!(
X̂[1,:], X̂[2,:],
color=repeat([target], size(X̂,2)),
group=repeat([target], size(X̂,2)),
shape=:star5, ms=10
)
push!(plts, plt)
end
plt = plot(plts..., layout=(1, size(y,1)), size=(size(y,1)*500, 400))
display(plt)
end
```

## 🎓 References
6 changes: 6 additions & 0 deletions v0.1.7/_metadata.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
format:
commonmark:
variant: -raw_html+tex_math_dollars
wrap: none
mermaid-format: png
bibliography: ../../bib.bib
Loading

0 comments on commit 72b0354

Please sign in to comment.