diff --git a/emulate_sample/Project.toml b/emulate_sample/Project.toml
new file mode 100644
index 0000000..c5fa476
--- /dev/null
+++ b/emulate_sample/Project.toml
@@ -0,0 +1,6 @@
+[deps]
+AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
+Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
+GLMakie = "e9467ef8-e4e7-5192-8a1a-b1aee30e663a"
+JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
+ProgressBars = "49802e3a-d2f1-5c88-81d8-b72133a6f568"
diff --git a/emulate_sample/emulate_sample_catke.jl b/emulate_sample/emulate_sample_catke.jl
new file mode 100644
index 0000000..cc14532
--- /dev/null
+++ b/emulate_sample/emulate_sample_catke.jl
@@ -0,0 +1,149 @@
+using JLD2, Enzyme, ProgressBars, Random, Statistics, AdvancedHMC, GLMakie
+
+Random.seed!(1234)
+tic = time()
+include("simple_networks.jl")
+include("hmc_interface.jl")
+include("optimization_utils.jl")
+
+data_directory = ""
+data_file = "catke_parameters.jld2"
+
+jlfile = jldopen(data_directory * data_file, "r")
+θ = jlfile["parameters"]
+y = jlfile["objectives"]
+close(jlfile)
+
+θ̄ = mean(θ)
+θ̃ = std(θ)
+ymax = maximum(y)
+ymin = minimum(y)
+yshift = ymin # ymin / 2 #
+Δy = ymax - ymin # 2 * std(y) #
+θr = (reshape(θ, (size(θ)[1] * size(θ)[2], size(θ)[3])) .- θ̄) ./ (2θ̃)
+yr = (reshape(y, (size(y)[1] * size(y)[2])) .- yshift) ./ Δy
+M = size(θr)[1]
+Mᴾ = size(θr)[2]
+
+# Define Network
+Nθ = size(θr, 2)
+Nθᴴ = Nθ ÷ 2
+W1 = randn(Nθᴴ, Nθ)
+b1 = randn(Nθᴴ)
+W2 = randn(1, Nθᴴ)
+b2 = randn(1)
+W3 = randn(1, Nθ)
+b3 = randn(1)
+
+network = OneLayerNetworkWithLinearByPass(W1, b1, W2, b2, W3, b3)
+dnetwork = deepcopy(network)
+smoothed_network = deepcopy(network)
+
+## Emulate
+adam = Adam(network)
+batchsize = 100
+loss_list = Float64[]
+test_loss_list = Float64[]
+epochs = 100
+network_parameters = copy(parameters(network))
+for i in ProgressBar(1:epochs)
+    shuffled_list = chunk_list(shuffle(1:2:M), batchsize)
+    shuffled_test_list = chunk_list(shuffle(2:2:M), batchsize)
+    loss_value = 0.0
+    N = length(shuffled_list)
+    # Batched Gradient Descent and Loss Evaluation
+    for permuted_list in ProgressBar(shuffled_list)
+        θbatch = [θr[x, :] for x in permuted_list]
+        ybatch = yr[permuted_list]
+        zero!(dnetwork)
+        autodiff(Enzyme.Reverse, loss, Active, DuplicatedNoNeed(network, dnetwork), Const(θbatch), Const(ybatch))
+        update!(adam, network, dnetwork)
+        loss_value += loss(network, θbatch, ybatch) / N
+    end
+    push!(loss_list, loss_value)
+    # Test Loss
+    loss_value = 0.0
+    N = length(shuffled_test_list)
+    for permuted_list in shuffled_test_list
+        θbatch = [θr[x, :] for x in permuted_list]
+        ybatch = yr[permuted_list]
+        loss_value += loss(network, θbatch, ybatch) / N
+    end
+    push!(test_loss_list, loss_value)
+    # Weighted Averaging of Network
+    m = 0.9
+    network_parameters .= m * network_parameters + (1-m) * parameters(network)
+    set_parameters!(smoothed_network, network_parameters)
+end
+
+loss_fig = Figure()
+ax = Axis(loss_fig[1, 1]; title = "Log10 Loss", xlabel = "Epoch", ylabel = "Loss")
+scatter!(ax, log10.(loss_list); color = :blue, label = "Training Loss")
+scatter!(ax, log10.(test_loss_list); color = :red, label = "Test Loss")
+axislegend(ax, position = :rt)
+display(loss_fig)
+
+## Sample
+# Define logp and ∇logp and regularizer
+
+initial_θ = copy(θr[argmin(yr), :])
+mintheta = minimum(θr, dims = 1)[:]
+maxtheta = maximum(θr, dims = 1)[:]
+reg = Regularizer([mintheta, maxtheta, initial_θ])
+
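+# The Regularizer functor below acts as a log-prior for the sampler: it returns
+# -Inf outside the box spanned by the training samples and otherwise applies an
+# L1 penalty pulling samples toward `initial_θ`, the best training parameters.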
+function (regularizer::Regularizer)(x)
+    if any(x .≤ regularizer.parameters[1])
+        return -Inf
+    elseif any(x .> regularizer.parameters[2])
+        return -Inf
+    else
+        return -sum(abs.(x - regularizer.parameters[3]) ./ (regularizer.parameters[2] - regularizer.parameters[1]))
+    end
+    return 0.0
+end
+
+scale = 10 * Δy # minimum(yr)
+regularization_scale = 0.001/2 * scale
+
+U = LogDensity(network, reg, scale, regularization_scale)
+∇U = GradientLogDensity(U)
+
+# HMC
+D = size(θr, 2)
+n_samples = 10000
+n_adapts = 1000
+
+metric = DiagEuclideanMetric(D)
+hamiltonian = Hamiltonian(metric, GaussianKinetic(), U, ∇U)
+
+initial_ϵ = find_good_stepsize(hamiltonian, initial_θ)
+integrator = Leapfrog(initial_ϵ)
+
+kernel = HMCKernel(Trajectory{MultinomialTS}(integrator, GeneralisedNoUTurn()))
+adaptor = StanHMCAdaptor(MassMatrixAdaptor(metric), StepSizeAdaptor(0.8, integrator))
+
+samples, stats = sample(hamiltonian, kernel, initial_θ, n_samples, adaptor, n_adapts; progress=true)
+
+toc = time()
+println("Elapsed time: $((toc - tic)/60) minutes")
+
+# Plot
+fig = Figure()
+Mp = 5
+for i in 1:23
+    ii = (i-1)÷Mp + 1
+    jj = (i-1)%Mp + 1
+    ax = Axis(fig[ii, jj]; title = "Parameter $i")
+    v1 = ([sample[i] for sample in samples] .* 2θ̃) .+ θ̄
+    hist!(ax, v1; bins = 50, strokewidth = 0, color = :blue, normalization = :pdf)
+    xlims!(ax, -0.1, (reg.parameters[2][i] * 2θ̃ + θ̄) * 1.1)
+    density!(ax, v1; color = (:red, 0.1), strokewidth = 3, strokecolor = :red)
+end
+display(fig)
+
+imin = argmax([stat.log_density for stat in stats])
+imax = argmin([stat.log_density for stat in stats])
+network(samples[imin])
+θ₀ = (initial_θ .* 2θ̃) .+ θ̄
+((mean(samples) .* 2θ̃) .+ θ̄) - θ₀
+((samples[imin] .* 2θ̃) .+ θ̄) - θ₀
\ No newline at end of file
diff --git a/emulate_sample/hmc_interface.jl b/emulate_sample/hmc_interface.jl
new file mode 100644
index 0000000..ca7f1e7
--- /dev/null
+++ b/emulate_sample/hmc_interface.jl
@@ -0,0 +1,42 @@
+struct LogDensity{N, S, M}
+    logp::N
+    regularization::S
+    scale::M
+    scale_regularization::M
+end
+
+# Negative sign if the network represents the potential function
+# Note: regularization should be negative semi-definite
+function (logp::LogDensity{T})(θ) where T <: SimpleNetwork
+    return -logp.logp(θ)[1] * logp.scale + logp.regularization(θ) * logp.scale_regularization
+end
+
+function LogDensity(network::SimpleNetwork)
+    regularization(x) = 0.0
+    return LogDensity(network, regularization, 1.0, 1.0)
+end
+
+function LogDensity(network::SimpleNetwork, scale)
+    regularization(x) = 0.0
+    return LogDensity(network, regularization, scale, 1.0)
+end
+
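+# GradientLogDensity pairs a LogDensity with a preallocated gradient buffer.
+# Its functor differentiates the log density with Enzyme reverse mode and
+# returns a (value, gradient) tuple, the form consumed by the
+# Hamiltonian(metric, GaussianKinetic(), U, ∇U) call in emulate_sample_catke.jl.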
+struct GradientLogDensity{N}
+    logp::N
+    dθ::Vector{Float64}
+end
+
+function GradientLogDensity(logp::LogDensity{S}) where S <: SimpleNetwork
+    dθ = zeros(size(logp.logp.W1, 2))
+    return GradientLogDensity(logp, dθ)
+end
+
+function (∇logp::GradientLogDensity)(θ)
+    ∇logp.dθ .= 0.0
+    autodiff(Enzyme.Reverse, Const(∇logp.logp), Active, DuplicatedNoNeed(θ, ∇logp.dθ))
+    return (∇logp.logp(θ), copy(∇logp.dθ))
+end
+
+struct Regularizer{F}
+    parameters::F
+end
diff --git a/emulate_sample/optimization_utils.jl b/emulate_sample/optimization_utils.jl
new file mode 100644
index 0000000..7b44851
--- /dev/null
+++ b/emulate_sample/optimization_utils.jl
@@ -0,0 +1,85 @@
+function loss(network::SimpleNetwork, x, y)
+    ŷ = similar(y)
+    for i in eachindex(ŷ)
+        ŷ[i] = predict(network, x[i])[1]
+    end
+    return mean((y .- ŷ) .^ 2)
+end
+
+function chunk_list(list, n)
+    return [list[i:min(i+n-1, length(list))] for i in 1:n:length(list)]
+end
+
+struct Adam{S, T, I}
+    struct_copies::S
+    parameters::T
+    t::I
+end
+
+function parameters(network::SimpleNetwork)
+    network_parameters = []
+    for names in propertynames(network)
+        push!(network_parameters, getproperty(network, names)[:])
+    end
+    param_lengths = [length(params) for params in network_parameters]
+    parameter_list = zeros(sum(param_lengths))
+    start = 1
+    for i in 1:length(param_lengths)
+        parameter_list[start:start+param_lengths[i]-1] .= network_parameters[i]
+        start += param_lengths[i]
+    end
+    return parameter_list
+end
+
+function set_parameters!(network::SimpleNetwork, parameters_list)
+    param_lengths = Int64[]
+    for names in propertynames(network)
+        push!(param_lengths, length(getproperty(network, names)[:]))
+    end
+    start = 1
+    for (i, names) in enumerate(propertynames(network))
+        getproperty(network, names)[:] .= parameters_list[start:start+param_lengths[i]-1]
+        start = start + param_lengths[i]
+    end
+    return nothing
+end
+
+function Adam(network::SimpleNetwork; α=0.001, β₁=0.9, β₂=0.999, ϵ=1e-8)
+    parameters_list = (; α, β₁, β₂, ϵ)
+    network_parameters = parameters(network)
+    t = [1.0]
+    θ = deepcopy(network_parameters) .* 0.0
+    gₜ = deepcopy(network_parameters) .* 0.0
+    m₀ = deepcopy(network_parameters) .* 0.0
+    mₜ = deepcopy(network_parameters) .* 0.0
+    v₀ = deepcopy(network_parameters) .* 0.0
+    vₜ = deepcopy(network_parameters) .* 0.0
+    v̂ₜ = deepcopy(network_parameters) .* 0.0
+    m̂ₜ = deepcopy(network_parameters) .* 0.0
+    struct_copies = (; θ, gₜ, m₀, mₜ, v₀, vₜ, v̂ₜ, m̂ₜ)
+    return Adam(struct_copies, parameters_list, t)
+end
+
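+# Standard Adam update with bias-corrected first and second moment estimates:
+# gradients are read from `dnetwork`, and the updated parameter vector is
+# written back into `network` through `set_parameters!`.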
+function update!(adam::Adam, network::SimpleNetwork, dnetwork::SimpleNetwork)
+    # unpack
+    (; α, β₁, β₂, ϵ) = adam.parameters
+    (; θ, gₜ, m₀, mₜ, v₀, vₜ, v̂ₜ, m̂ₜ) = adam.struct_copies
+    t = adam.t[1]
+    # get gradient
+    θ .= parameters(network)
+    gₜ .= parameters(dnetwork)
+    # update
+    @. mₜ = β₁ * m₀ + (1 - β₁) * gₜ
+    @. vₜ = β₂ * v₀ + (1 - β₂) * (gₜ .^2)
+    @. m̂ₜ = mₜ / (1 - β₁^t)
+    @. v̂ₜ = vₜ / (1 - β₂^t)
+    @. θ = θ - α * m̂ₜ / (sqrt(v̂ₜ) + ϵ)
+    # store moments for the next step and write the parameters back
+    m₀ .= mₜ
+    v₀ .= vₜ
+    adam.t[1] += 1
+    set_parameters!(network, θ)
+    return nothing
+end
\ No newline at end of file
diff --git a/emulate_sample/simple_networks.jl b/emulate_sample/simple_networks.jl
new file mode 100644
index 0000000..8b921f5
--- /dev/null
+++ b/emulate_sample/simple_networks.jl
@@ -0,0 +1,67 @@
+using Enzyme
+
+abstract type SimpleNetwork end
+
+struct OneLayerNetwork{M, V} <: SimpleNetwork
+    W1::M
+    b1::V
+    W2::M
+    b2::V
+end
+
+struct OneLayerNetworkWithLinearByPass{M,V} <: SimpleNetwork
+    W1::M
+    b1::V
+    W2::M
+    b2::V
+    W3::M
+    b3::V
+end
+
+function zero!(dnetwork::OneLayerNetworkWithLinearByPass)
+    dnetwork.W1 .= 0.0
+    dnetwork.b1 .= 0.0
+    dnetwork.W2 .= 0.0
+    dnetwork.b2 .= 0.0
+    dnetwork.W3 .= 0.0
+    dnetwork.b3 .= 0.0
+    return nothing
+end
+
+function zero!(dnetwork::OneLayerNetwork)
+    dnetwork.W1 .= 0.0
+    dnetwork.b1 .= 0.0
+    dnetwork.W2 .= 0.0
+    dnetwork.b2 .= 0.0
+    return nothing
+end
+
+function update!(network::SimpleNetwork, dnetwork::SimpleNetwork, η)
+    network.W1 .-= η .* dnetwork.W1
+    network.b1 .-= η .* dnetwork.b1
+    network.W2 .-= η .* dnetwork.W2
+    network.b2 .-= η .* dnetwork.b2
+    return nothing
+end
+
+swish(x) = x / (1 + exp(-x))
+activation_function(x) = swish(x) # tanh(x) #
+
+function predict(network::OneLayerNetwork, x)
+    return abs.(network.W2 * activation_function.(network.W1 * x .+ network.b1) .+ network.b2)
+end
+
+function predict(network::OneLayerNetworkWithLinearByPass, x)
+    y1 = network.W1 * x .+ network.b1
+    y2 = network.W2 * activation_function.(y1) .+ network.b2
+    y3 = network.W3 * x .+ network.b3
+    return abs.(y3) .+ abs.(y2)
+end
+
+function predict(network::OneLayerNetwork, x, activation::Function)
+    return abs.(network.W2 * activation.(network.W1 * x .+ network.b1) .+ network.b2)
+end
+
+function (network::SimpleNetwork)(x)
+    return predict(network, x)
+end
\ No newline at end of file