As/emulate sample with nn hmc #2

Open · wants to merge 5 commits into base: main
6 changes: 6 additions & 0 deletions emulate_sample/Project.toml
@@ -0,0 +1,6 @@
[deps]
AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
GLMakie = "e9467ef8-e4e7-5192-8a1a-b1aee30e663a"
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
ProgressBars = "49802e3a-d2f1-5c88-81d8-b72133a6f568"
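# Note: emulate_sample_catke.jl also loads the Random and Statistics standard libraries,
# which may need to be listed under [deps] as well.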
149 changes: 149 additions & 0 deletions emulate_sample/emulate_sample_catke.jl
@@ -0,0 +1,149 @@
using JLD2, Enzyme, ProgressBars, Random, Statistics, AdvancedHMC, GLMakie

Random.seed!(1234)
tic = time()
include("simple_networks.jl")
include("hmc_interface.jl")
include("optimization_utils.jl")

data_directory = ""
data_file = "catke_parameters.jld2"

jlfile = jldopen(data_directory * data_file, "r")
θ = jlfile["parameters"]
y = jlfile["objectives"]
close(jlfile)

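# Normalize: center the parameters by their mean and scale by two standard deviations;
# shift the objectives by their minimum and scale by their range.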
θ̄ = mean(θ)
θ̃ = std(θ)
ymax = maximum(y)
ymin = minimum(y)
yshift = ymin # ymin / 2 #
Δy = ymax - ymin # 2 * std(y) #
θr = (reshape(θ, (size(θ)[1] * size(θ)[2], size(θ)[3])) .- θ̄ ) ./ (2θ̃)
yr = (reshape(y, (size(y)[1] * size(y)[2])) .- yshift ) ./ Δy
M = size(θr)[1]
Mᴾ = size(θr)[2]

# Define Network
Nθ = size(θr, 2)
Nθᴴ = Nθ ÷ 2
W1 = randn(Nθᴴ, Nθ)
b1 = randn(Nθᴴ)
W2 = randn(1, Nθᴴ)
b2 = randn(1)
W3 = randn(1, Nθ)
b3 = randn(1)

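# One hidden layer with a swish activation plus a linear by-pass; predictions are
# non-negative by construction (see simple_networks.jl).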
network = OneLayerNetworkWithLinearByPass(W1, b1, W2, b2, W3, b3)
dnetwork = deepcopy(network)
smoothed_network = deepcopy(network)

## Emulate
adam = Adam(network)
batchsize = 100
loss_list = Float64[]
test_loss_list = Float64[]
epochs = 100
network_parameters = copy(parameters(network))
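# Train on the odd-indexed samples and monitor the loss on the held-out even-indexed samples.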
for i in ProgressBar(1:epochs)
shuffled_list = chunk_list(shuffle(1:2:M), batchsize)
shuffled_test_list = chunk_list(shuffle(2:2:M), batchsize)
loss_value = 0.0
N = length(shuffled_list)
# Batched Gradient Descent and Loss Evaluation
for permuted_list in ProgressBar(shuffled_list)
θbatch = [θr[x, :] for x in permuted_list]
ybatch = yr[permuted_list]
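# Zero the gradient shadow, differentiate the batch loss with Enzyme reverse mode,
# then take an Adam step.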
zero!(dnetwork)
autodiff(Enzyme.Reverse, loss, Active, DuplicatedNoNeed(network, dnetwork), Const(θbatch), Const(ybatch))
update!(adam, network, dnetwork)
loss_value += loss(network, θbatch, ybatch) / N
end
push!(loss_list, loss_value)
# Test Loss
loss_value = 0.0
N = length(shuffled_test_list)
for permuted_list in shuffled_test_list
θbatch = [θr[x, :] for x in permuted_list]
ybatch = yr[permuted_list]
loss_value += loss(network, θbatch, ybatch) / N
end
push!(test_loss_list, loss_value)
# Exponentially weighted moving average of the network parameters (kept in smoothed_network)
m = 0.9
network_parameters .= m * network_parameters + (1-m) * parameters(network)
set_parameters!(smoothed_network, network_parameters)
end

loss_fig = Figure()
ax = Axis(loss_fig[1, 1]; title = "Log10 Loss", xlabel = "Epoch", ylabel = "Loss")
scatter!(ax, log10.(loss_list); color = :blue, label = "Training Loss")
scatter!(ax, log10.(test_loss_list); color = :red, label = "Test Loss")
axislegend(ax, position = :rt)
display(loss_fig)

## Sample
# Define logp and ∇logp and regularizer

initial_θ = copy(θr[argmin(yr), :])
mintheta = minimum(θr, dims = 1)[:]
maxtheta = maximum(θr, dims = 1)[:]
reg = Regularizer([mintheta, maxtheta, initial_θ])

function (regularizer::Regularizer)(x)
# Reject samples outside the parameter bounds; otherwise apply an L¹ penalty that
# pulls samples toward the initial guess, normalized by the width of each bound.
if any(x .≤ regularizer.parameters[1]) || any(x .> regularizer.parameters[2])
return -Inf
else
return -sum(abs.(x - regularizer.parameters[3]) ./ (regularizer.parameters[2] - regularizer.parameters[1]))
end
end

scale = 10 * Δy # minimum(yr)
regularization_scale = 0.001/2 * scale
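# scale acts like an inverse temperature on the emulated loss; regularization_scale
# sets the relative weight of the penalty that pulls samples toward the initial guess.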

U = LogDensity(network, reg, scale, regularization_scale)
∇U = GradientLogDensity(U)

# HMC
D = size(θr, 2)
n_samples = 10000
n_adapts = 1000

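# Diagonal mass matrix, multinomial no-U-turn sampler, and Stan-style adaptation
# targeting an acceptance rate of 0.8.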
metric = DiagEuclideanMetric(D)
hamiltonian = Hamiltonian(metric, GaussianKinetic(), U, ∇U)

initial_ϵ = find_good_stepsize(hamiltonian, initial_θ)
integrator = Leapfrog(initial_ϵ)

kernel = HMCKernel(Trajectory{MultinomialTS}(integrator, GeneralisedNoUTurn()))
adaptor = StanHMCAdaptor(MassMatrixAdaptor(metric), StepSizeAdaptor(0.8, integrator))

samples, stats = sample(hamiltonian, kernel, initial_θ, n_samples, adaptor, n_adapts; progress=true)

toc = time()
println("Elapsed time: $((toc - tic)/60) minutes")

# Plot
fig = Figure()
Mp = 5
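# One marginal histogram and kernel density per parameter, Mp panels per row,
# with samples mapped back to physical units.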
for i in 1:23
ii = (i-1)÷Mp + 1
jj = (i-1)%Mp + 1
ax = Axis(fig[ii, jj]; title = "Parameter $i")
v1 = ([sample[i] for sample in samples] .* 2θ̃) .+ θ̄
hist!(ax, v1; bins = 50, strokewidth = 0, color = :blue, normalization = :pdf)
xlims!(ax, -0.1, (reg.parameters[2][i]* 2θ̃ + θ̄) * 1.1)
density!(ax, v1; color = (:red, 0.1), strokewidth = 3, strokecolor = :red)
end
display(fig)

imin = argmax([stat.log_density for stat in stats]) # sample with the highest log density (lowest emulated loss)
imax = argmin([stat.log_density for stat in stats]) # sample with the lowest log density
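# Interactive checks: the emulator value at the best sample, and the shift (in physical
# units) of the posterior mean and of the best sample relative to the initial guess.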
network(samples[imin])
θ₀ = (initial_θ .* 2θ̃) .+ θ̄
((mean(samples) .* 2θ̃) .+ θ̄) - θ₀
((samples[imin] .* 2θ̃) .+ θ̄) - θ₀
42 changes: 42 additions & 0 deletions emulate_sample/hmc_interface.jl
@@ -0,0 +1,42 @@
struct LogDensity{N, S, M}
logp::N
regularization::S
scale::M
scale_regularization::M
end

# The network emulates a loss (a potential), so the log density is its negative, times a scale.
# Note: the regularization term should be non-positive (zero or negative).
function (logp::LogDensity{T})(θ) where T <: SimpleNetwork
return -logp.logp(θ)[1] * logp.scale + logp.regularization(θ) * logp.scale_regularization
end

function LogDensity(network::SimpleNetwork)
regularization(x) = 0.0
return LogDensity(network, regularization, 1.0, 1.0)
end

function LogDensity(network::SimpleNetwork, scale)
regularization(x) = 0.0
return LogDensity(network, regularization, scale, 1.0)
end

struct GradientLogDensity{N}
logp::N
dθ::Vector{Float64}
end

function GradientLogDensity(logp::LogDensity{S}) where S <: SimpleNetwork
dθ = zeros(size(logp.logp.W1, 2))
return GradientLogDensity(logp, dθ)
end

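# Reverse-mode gradient of the log density: Enzyme accumulates into the preallocated
# shadow buffer dθ, and the call returns the (value, gradient) pair expected by AdvancedHMC.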
function (∇logp::GradientLogDensity)(θ)
∇logp.dθ .= 0.0
autodiff(Enzyme.Reverse, Const(∇logp.logp), Active, DuplicatedNoNeed(θ, ∇logp.dθ))
return (∇logp.logp(θ), copy(∇logp.dθ))
end

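# Thin wrapper around (lower bounds, upper bounds, initial guess); the call overload is
# defined in emulate_sample_catke.jl.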
struct Regularizer{F}
parameters::F
end
85 changes: 85 additions & 0 deletions emulate_sample/optimization_utils.jl
@@ -0,0 +1,85 @@
function loss(network::SimpleNetwork, x, y)
ŷ = similar(y)
for i in eachindex(ŷ)
ŷ[i] = predict(network, x[i])[1]
end
return mean((y .- ŷ) .^ 2)
end

function chunk_list(list, n)
return [list[i:min(i+n-1, length(list))] for i in 1:n:length(list)]
end

struct Adam{S, T, I}
struct_copies::S # preallocated buffers for the gradient and moment estimates
parameters::T    # hyperparameters (α, β₁, β₂, ϵ)
t::I             # iteration counter (a length-1 vector so it can be mutated)
end

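# Flatten all weights and biases into a single vector (set_parameters! below is the inverse),
# so that Adam can treat the network as one parameter vector.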
function parameters(network::SimpleNetwork)
network_parameters = []
for names in propertynames(network)
push!(network_parameters, getproperty(network, names)[:])
end
param_lengths = [length(params) for params in network_parameters]
parameter_list = zeros(sum(param_lengths))
start = 1
for i in 1:length(param_lengths)
parameter_list[start:start+param_lengths[i]-1] .= network_parameters[i]
start += param_lengths[i]
end
return parameter_list
end

function set_parameters!(network::SimpleNetwork, parameters_list)
param_lengths = Int64[]
for names in propertynames(network)
push!(param_lengths, length(getproperty(network, names)[:]))
end
start = 1
for (i, names) in enumerate(propertynames(network))
getproperty(network, names)[:] .= parameters_list[start:start+param_lengths[i]-1]
start = start + param_lengths[i]
end
return nothing
end

function Adam(network::SimpleNetwork; α=0.001, β₁=0.9, β₂=0.999, ϵ=1e-8)
parameters_list = (; α, β₁, β₂, ϵ)
network_parameters = parameters(network)
t = [1.0]
θ = zero(network_parameters)
gₜ = zero(network_parameters)
m₀ = zero(network_parameters)
mₜ = zero(network_parameters)
v₀ = zero(network_parameters)
vₜ = zero(network_parameters)
v̂ₜ = zero(network_parameters)
m̂ₜ = zero(network_parameters)
struct_copies = (; θ, gₜ, m₀, mₜ, v₀, vₜ, v̂ₜ, m̂ₜ)
return Adam(struct_copies, parameters_list, t)
end


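# One Adam step: read the flattened gradient from dnetwork, update the biased first and
# second moment estimates, apply bias correction, and write the new parameters back.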
function update!(adam::Adam, network::SimpleNetwork, dnetwork::SimpleNetwork)
# unpack
(; α, β₁, β₂, ϵ) = adam.parameters
(; θ, gₜ, m₀, mₜ, v₀, vₜ, v̂ₜ, m̂ₜ) = adam.struct_copies
t = adam.t[1]
# get gradient
θ .= parameters(network)
gₜ .= parameters(dnetwork)
# update
@. mₜ = β₁ * m₀ + (1 - β₁) * gₜ
@. vₜ = β₂ * v₀ + (1 - β₂) * (gₜ .^2)
@. m̂ₜ = mₜ / (1 - β₁^t)
@. v̂ₜ = vₜ / (1 - β₂^t)
@. θ = θ - α * m̂ₜ / (sqrt(v̂ₜ) + ϵ)
# update parameters
m₀ .= mₜ
v₀ .= vₜ
adam.t[1] += 1
set_parameters!(network, θ)
return nothing
end
67 changes: 67 additions & 0 deletions emulate_sample/simple_networks.jl
@@ -0,0 +1,67 @@
using Enzyme

abstract type SimpleNetwork end

struct OneLayerNetwork{M, V} <: SimpleNetwork
W1::M
b1::V
W2::M
b2::V
end

struct OneLayerNetworkWithLinearByPass{M,V} <: SimpleNetwork
W1::M
b1::V
W2::M
b2::V
W3::M
b3::V
end

function zero!(dnetwork::OneLayerNetworkWithLinearByPass)
dnetwork.W1 .= 0.0
dnetwork.b1 .= 0.0
dnetwork.W2 .= 0.0
dnetwork.b2 .= 0.0
dnetwork.W3 .= 0.0
dnetwork.b3 .= 0.0
return nothing
end

function zero!(dnetwork::OneLayerNetwork)
dnetwork.W1 .= 0.0
dnetwork.b1 .= 0.0
dnetwork.W2 .= 0.0
dnetwork.b2 .= 0.0
return nothing
end

function update!(network::SimpleNetwork, dnetwork::SimpleNetwork, η)
network.W1 .-= η .* dnetwork.W1
network.b1 .-= η .* dnetwork.b1
network.W2 .-= η .* dnetwork.W2
network.b2 .-= η .* dnetwork.b2
return nothing
end

swish(x) = x / (1 + exp(-x))
activation_function(x) = swish(x) # tanh(x) #

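# abs(...) keeps the emulator output non-negative, matching the shifted objectives
# (which lie in [0, 1]) that it is trained on.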
function predict(network::OneLayerNetwork, x)
return abs.(network.W2 * activation_function.(network.W1 * x .+ network.b1) .+ network.b2)
end

function predict(network::OneLayerNetworkWithLinearByPass, x)
y1 = network.W1 * x .+ network.b1
y2 = network.W2 * activation_function.(y1) .+ network.b2
y3 = network.W3 * x .+ network.b3
return abs.(y3) .+ abs.(y2)
end

function predict(network::OneLayerNetwork, x, activation::Function)
return abs.(network.W2 * activation.(network.W1 * x .+ network.b1) .+ network.b2)
end

function (network::SimpleNetwork)(x)
return predict(network, x)
end