-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #22 from JuliaTrustworthyAI/21-issue-with-readme
updated
- Loading branch information
Showing
39 changed files
with
21,813 additions
and
20,493 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -21,7 +21,7 @@ jobs: | |
- '1.7' | ||
- '1.8' | ||
- '1.9' | ||
- 'nightly' | ||
- '1.10' | ||
os: | ||
- ubuntu-latest | ||
arch: | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
4,054 changes: 2,027 additions & 2,027 deletions
4,054
README_files/figure-commonmark/cell-3-output-1.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
9,183 changes: 4,608 additions & 4,575 deletions
9,183
README_files/figure-commonmark/cell-8-output-1.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
11 changes: 11 additions & 0 deletions
11
_freeze/docs/src/explanation/samplers/execute-results/md.json
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
{ | ||
"hash": "577028b23128dbdcef11147ab874511c", | ||
"result": { | ||
"engine": "jupyter", | ||
"markdown": "---\ntitle: MNIST\n---\n\n\n\n\n\n\n::: {.cell execution_count=2}\n``` {.julia .cell-code}\nusing TaijaData: load_mnist\nusing CounterfactualExplanations.Models: load_mnist_mlp, load_mnist_ensemble\n```\n:::\n\n\n::: {.cell execution_count=3}\n``` {.julia .cell-code}\nX, y = load_mnist()\nK = size(y, 1)\nD = size(X, 1)\nmlp = load_mnist_mlp().model\nf_mlp(x) = mlp(x)\nens = load_mnist_ensemble().model\nf_ens(x) = sum(map(mlp -> mlp(x), ens))/length(ens)\nbatch_size = 100\n```\n:::\n\n\n## Sampling\n\n::: {.cell execution_count=4}\n``` {.julia .cell-code}\n𝒟x = Uniform(0,1)\n𝒟y = Categorical(ones(K) ./ K)\nsampler = UnconditionalSampler(𝒟x; input_size=(D,))\nconditional_sampler = ConditionalSampler(𝒟x, 𝒟y; input_size=(D,))\nopt = ImproperSGLD()\nn_iter = 256\n```\n:::\n\n\n### Conditional Draws\n\n::: {.cell execution_count=5}\n``` {.julia .cell-code}\n_w = 1500\nplts = []\nneach = 10\nfor i in 1:10\n x = conditional_sampler(f_mlp, opt; niter=n_iter, y=i, n_samples=neach)\n plts_i = []\n for j in 1:size(x,2)\n xj = reshape(x[:,j], (28,28))\n plts_i = [plts_i..., heatmap(rotl90(xj), axis=nothing, cb=false)]\n end\n plt = plot(plts_i..., size=(_w,0.10*_w), layout=(1,10))\n plts = [plts..., plt]\nend\nplot(plts..., size=(_w,_w), layout=(10,1))\n```\n:::\n\n\n::: {.cell execution_count=6}\n``` {.julia .cell-code}\n_w = 1500\nplts = []\nneach = 10\nfor i in 1:10\n x = conditional_sampler(f_ens, opt; niter=n_iter, y=i, n_samples=neach)\n plts_i = []\n for j in 1:size(x,2)\n xj = reshape(x[:,j], (28,28))\n plts_i = [plts_i..., heatmap(rotl90(xj), axis=nothing, cb=false)]\n end\n plt = plot(plts_i..., size=(_w,0.10*_w), layout=(1,10))\n plts = [plts..., plt]\nend\nplot(plts..., size=(_w,_w), layout=(10,1))\n```\n:::\n\n\n### Unconditional Draws\n\n::: {.cell execution_count=7}\n``` {.julia .cell-code}\n_w = 1500\nplts = []\nneach = 10\nfor i in 1:10\n x = sampler(f_mlp, opt; niter=n_iter, n_samples=neach)\n plts_i = []\n for j in 1:size(x,2)\n xj = reshape(x[:,j], (28,28))\n plts_i = [plts_i..., heatmap(rotl90(xj), axis=nothing, cb=false)]\n end\n plt = plot(plts_i..., size=(_w,0.10*_w), layout=(1,10))\n plts = [plts..., plt]\nend\nplot(plts..., size=(_w,_w), layout=(10,1))\n```\n:::\n\n\n::: {.cell execution_count=8}\n``` {.julia .cell-code}\n_w = 1500\nplts = []\nneach = 10\nfor i in 1:10\n x = sampler(f_ens, opt; niter=n_iter, n_samples=neach)\n plts_i = []\n for j in 1:size(x,2)\n xj = reshape(x[:,j], (28,28))\n plts_i = [plts_i..., heatmap(rotl90(xj), axis=nothing, cb=false)]\n end\n plt = plot(plts_i..., size=(_w,0.10*_w), layout=(1,10))\n plts = [plts..., plt]\nend\nplot(plts..., size=(_w,_w), layout=(10,1))\n```\n:::\n\n\n", | ||
"supporting": [ | ||
"samplers_files" | ||
], | ||
"filters": [] | ||
} | ||
} |
11 changes: 11 additions & 0 deletions
11
_freeze/docs/src/how_to_guides/mlj_flux/execute-results/md.json
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
{ | ||
"hash": "155039eb05e14711352a223bba059fae", | ||
"result": { | ||
"engine": "jupyter", | ||
"markdown": "---\ntitle: Compatibility with `MLJFlux`\n---\n\n\n\n\n\n\n## Synthetic Data\n\n::: {.cell execution_count=2}\n``` {.julia .cell-code}\nnobs=2000\nX, y = make_circles(nobs, noise=0.1, factor=0.5)\nXplot = Float32.(permutedims(matrix(X)))\nX = table(permutedims(Xplot))\ndisplay(scatter(Xplot[1,:], Xplot[2,:], group=y, label=\"\"))\nbatch_size = Int(round(nobs/10))\n```\n:::\n\n\n::: {.cell execution_count=3}\n``` {.julia .cell-code}\nsampler = ConditionalSampler(X, y, batch_size=batch_size)\nclf = JointEnergyClassifier(\n sampler;\n builder=MLJFlux.MLP(hidden=(32, 32, 32,), σ=Flux.relu),\n batch_size=batch_size,\n finaliser=Flux.softmax,\n loss=Flux.Losses.crossentropy,\n jem_training_params=(α=[1.0,1.0,1e-1],verbosity=10,),\n)\n```\n:::\n\n\n::: {.cell execution_count=4}\n``` {.julia .cell-code}\nprintln(typeof(clf) <: MLJFlux.MLJFluxModel)\n```\n:::\n\n\n::: {.cell execution_count=5}\n``` {.julia .cell-code}\nmach = machine(clf, X, y)\nfit!(mach)\n```\n:::\n\n\n::: {.cell execution_count=6}\n``` {.julia .cell-code}\njem = mach.model.jem\nbatch_size = mach.model.batch_size\nX = Float32.(permutedims(matrix(X)))\ny_labels = Int.(y.refs)\ny = Flux.onehotbatch(y.refs, sort(unique(y_labels)))\n```\n:::\n\n\n::: {.cell execution_count=7}\n``` {.julia .cell-code}\nif typeof(jem.sampler) <: ConditionalSampler\n \n plts = []\n for target in 1:size(y,1)\n X̂ = generate_conditional_samples(jem, batch_size, target; niter=1000) \n ex = extrema(hcat(X,X̂), dims=2)\n xlims = ex[1]\n ylims = ex[2]\n x1 = range(1.0f0.*xlims...,length=100)\n x2 = range(1.0f0.*ylims...,length=100)\n plt = contour(\n x1, x2, (x, y) -> softmax(jem([x, y]))[target], \n fill=true, alpha=0.5, title=\"Target: $target\", cbar=false,\n xlims=xlims,\n ylims=ylims,\n )\n scatter!(X[1,:], X[2,:], color=vec(y_labels), group=vec(y_labels), alpha=0.5)\n scatter!(\n X̂[1,:], X̂[2,:], \n color=repeat([target], size(X̂,2)), \n group=repeat([target], size(X̂,2)), \n shape=:star5, ms=10\n )\n push!(plts, plt)\n end\n plt = plot(plts..., layout=(1, size(y,1)), size=(size(y,1)*400, 400))\n display(plt)\nend\n```\n:::\n\n\n## MNIST \n\n::: {.cell execution_count=8}\n``` {.julia .cell-code}\n# Data:\nnobs = 1000\nn_digits = 28\nXtrain, ytrain, _, _, _, _ = load_mnist_data(nobs=nobs, n_digits=n_digits)\nXtrain = table(permutedims(MLUtils.flatten(Xtrain)))\nytrain = coerce(Flux.onecold(ytrain, 0:9), Multiclass)\n\n# Hyperparameters:\nD = n_digits^2 \nK = 10 \nM = 32\nlr = 1e-3 \nnum_epochs = 500\nmax_patience = 5 \nbatch_size = Int(round(nobs/10))\nα = [1.0,1.0,1e-2]\n```\n:::\n\n\n::: {.cell execution_count=9}\n``` {.julia .cell-code}\nactivation = Flux.swish\nbuilder = MLJFlux.MLP(hidden=(M,M,M,), σ=activation)\n```\n:::\n\n\n::: {.cell execution_count=10}\n``` {.julia .cell-code}\n# We initialize the full model\n𝒟x = Uniform(-1,1)\n𝒟y = Categorical(ones(K) ./ K)\nsampler = ConditionalSampler(𝒟x, 𝒟y, input_size=(D,), batch_size=10)\nclf = JointEnergyClassifier(\n sampler;\n builder=builder,\n batch_size=batch_size,\n finaliser=Flux.softmax,\n loss=Flux.Losses.crossentropy,\n jem_training_params=(α=α,verbosity=10,),\n sampling_steps=20,\n optimiser=Flux.Optimise.Adam(lr),\n)\n```\n:::\n\n\n::: {.cell execution_count=11}\n``` {.julia .cell-code}\nmach = machine(clf, Xtrain, ytrain)\nfit!(mach)\n```\n:::\n\n\n::: {.cell execution_count=12}\n``` {.julia .cell-code}\njem = mach.model.jem\nn_iter = 1000\n_w = 1500\nplts = []\nneach = 10\nfor i in 1:10\n x = jem.sampler(jem.chain, jem.sampling_rule; niter=n_iter, n_samples=neach, y=i)\n plts_i = []\n for j in 1:size(x, 2)\n xj = x[:,j]\n xj = reshape(xj, (n_digits, n_digits))\n plts_i = [plts_i..., heatmap(rotl90(xj), axis=nothing, cb=false)]\n end\n plt = plot(plts_i..., size=(_w,0.10*_w), layout=(1,10))\n plts = [plts..., plt]\nend\nplot(plts..., size=(_w,_w), layout=(10,1))\n```\n:::\n\n\n", | ||
"supporting": [ | ||
"mlj_flux_files" | ||
], | ||
"filters": [] | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.