From 3ea9a384ceeb89a3fbe5163af93a4699b0c04fb7 Mon Sep 17 00:00:00 2001
From: Andre Souza
Date: Wed, 27 Nov 2024 18:50:03 -0500
Subject: [PATCH 1/5] emulate sample

---
 emulate_sample/Project.toml            |  6 ++
 emulate_sample/emulate_sample_catke.jl | 89 ++++++++++++++++++++++++++
 emulate_sample/hmc_interface.jl        | 38 +++++++++++
 emulate_sample/optimization_utils.jl   | 85 ++++++++++++++++++++++++
 emulate_sample/simple_networks.jl      | 41 ++++++++++++
 5 files changed, 259 insertions(+)
 create mode 100644 emulate_sample/Project.toml
 create mode 100644 emulate_sample/emulate_sample_catke.jl
 create mode 100644 emulate_sample/hmc_interface.jl
 create mode 100644 emulate_sample/optimization_utils.jl
 create mode 100644 emulate_sample/simple_networks.jl

diff --git a/emulate_sample/Project.toml b/emulate_sample/Project.toml
new file mode 100644
index 0000000..c5fa476
--- /dev/null
+++ b/emulate_sample/Project.toml
@@ -0,0 +1,6 @@
+[deps]
+AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
+Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
+GLMakie = "e9467ef8-e4e7-5192-8a1a-b1aee30e663a"
+JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
+ProgressBars = "49802e3a-d2f1-5c88-81d8-b72133a6f568"
diff --git a/emulate_sample/emulate_sample_catke.jl b/emulate_sample/emulate_sample_catke.jl
new file mode 100644
index 0000000..13b5052
--- /dev/null
+++ b/emulate_sample/emulate_sample_catke.jl
@@ -0,0 +1,89 @@
+using JLD2, Enzyme, ProgressBars, Random, Statistics, AdvancedHMC, GLMakie
+
+Random.seed!(1234)
+tic = time()
+include("simple_networks.jl")
+include("hmc_interface.jl")
+include("optimization_utils.jl")
+
+data_directory = "/Users/andresouza/Desktop/Repositories/GenericAnalysis.jl/"
+data_file = "catke_parameters.jld2"
+
+jlfile = jldopen(data_directory * data_file, "r")
+θ = jlfile["parameters"]
+y = jlfile["objectives"]
+close(jlfile)
+
+θr = reshape(θ, (size(θ)[1] * size(θ)[2], size(θ)[3]))
+yr = reshape(y, (size(y)[1] * size(y)[2]))
+M = size(θr)[1]
+Mᴾ = size(θr)[2]
+
+# Define Network
+Nθ = size(θr, 2)
+Nθᴴ = Nθ * 10
+W1 = randn(Nθᴴ, Nθ)
+b1 = randn(Nθᴴ)
+W2 = randn(1, Nθᴴ)
+b2 = randn(1)
+
+network = OneLayerNetwork(W1, b1, W2, b2)
+dnetwork = deepcopy(network)
+
+## Emulate
+# Optimize with Gradient Descent and Learning rate 1e-5
+batchsize = 10
+loss_list = Float64[]
+epochs = 10
+for i in ProgressBar(1:epochs)
+    shuffled_list = chunk_list(shuffle(1:M), batchsize)
+    loss_value = 0.0
+    N = length(shuffled_list)
+    for permuted_list in ProgressBar(shuffled_list)
+        θbatch = [θr[x, :] for x in permuted_list]
+        ybatch = yr[permuted_list]
+        zero!(dnetwork)
+        autodiff(Enzyme.Reverse, loss, Active, DuplicatedNoNeed(network, dnetwork), Const(θbatch), Const(ybatch))
+        update!(network, dnetwork, 1e-5)
+        loss_value += loss(network, θbatch, ybatch) / N
+    end
+    push!(loss_list, loss_value)
+end
+
+## Sample
+# HMC
+scale = 1e3
+U = LogDensity(network, scale)
+∇U = GradientLogDensity(U)
+
+D = size(θr, 2)
+initial_θ = copy(θr[end, :])
+n_samples = 10000
+n_adapts = 1000
+
+metric = DiagEuclideanMetric(D)
+hamiltonian = Hamiltonian(metric, GaussianKinetic(), U, ∇U)
+
+initial_ϵ = find_good_stepsize(hamiltonian, initial_θ)
+integrator = Leapfrog(initial_ϵ)
+
+kernel = HMCKernel(Trajectory{MultinomialTS}(integrator, GeneralisedNoUTurn()))
+adaptor = StanHMCAdaptor(MassMatrixAdaptor(metric), StepSizeAdaptor(0.8, integrator))
+
+samples, stats = sample(hamiltonian, kernel, initial_θ, n_samples, adaptor, n_adapts; progress=true)
+
+toc = time()
+println("Elapsed time: $((toc - tic)/60) minutes")
+
+# Plot
+fig = Figure()
+Mp = 5
+for i in 1:23
+    ii = (i-1)÷Mp + 1
+    jj = (i-1)%Mp + 1
+    ax = Axis(fig[ii, jj]; title = "Parameter $i")
+    v1 = [sample[i] for sample in samples]
+    hist!(ax, v1; bins = 50, strokewidth = 0, color = :blue, normalization = :pdf)
+    density!(ax, v1; color = (:red, 0.1), strokewidth = 3, strokecolor = :red)
+end
+display(fig)
\ No newline at end of file
diff --git a/emulate_sample/hmc_interface.jl b/emulate_sample/hmc_interface.jl
new file mode 100644
index 0000000..af8ff09
--- /dev/null
+++ b/emulate_sample/hmc_interface.jl
@@ -0,0 +1,38 @@
+struct LogDensity{N, S, M}
+    logp::N
+    regularization::S
+    scale::M
+    scale_regularization::M
+end
+
+# Negative sign if the network represents the potential function
+# Note: regularization should be negative semi-definite
+function (logp::LogDensity{T})(θ) where T <: SimpleNetwork
+    return -logp.logp(θ)[1] * logp.scale + logp.regularization(θ) * logp.scale_regularization
+end
+
+function LogDensity(network::SimpleNetwork)
+    regularization(x) = 0.0
+    return LogDensity(network, regularization, 1.0, 1.0)
+end
+
+function LogDensity(network::SimpleNetwork, scale)
+    regularization(x) = 0.0
+    return LogDensity(network, regularization, scale, 1.0)
+end
+
+struct GradientLogDensity{N}
+    logp::N
+    dθ::Vector{Float64}
+end
+
+function GradientLogDensity(logp::LogDensity{S}) where S <: SimpleNetwork
+    dθ = zeros(size(logp.logp.W1, 2))
+    return GradientLogDensity(logp, dθ)
+end
+
+function (∇logp::GradientLogDensity)(θ)
+    ∇logp.dθ .= 0.0
+    autodiff(Enzyme.Reverse, Const(∇logp.logp), Active, DuplicatedNoNeed(θ, ∇logp.dθ))
+    return (∇logp.logp(θ), copy(∇logp.dθ))
+end
\ No newline at end of file
diff --git a/emulate_sample/optimization_utils.jl b/emulate_sample/optimization_utils.jl
new file mode 100644
index 0000000..01095a8
--- /dev/null
+++ b/emulate_sample/optimization_utils.jl
@@ -0,0 +1,85 @@
+function loss(network::SimpleNetwork, x, y)
+    ŷ = similar(y)
+    for i in eachindex(ŷ)
+        ŷ[i] = predict(network, x[i])[1]
+    end
+    return mean((y .- ŷ) .^ 2)
+end
+
+function chunk_list(list, n)
+    return [list[i:i+n-1] for i in 1:n:length(list)]
+end
+
+struct Adam{S, T, I}
+    struct_copies::S
+    parameters::T
+    t::I
+end
+
+function parameters(network::SimpleNetwork)
+    network_parameters = []
+    for names in propertynames(network)
+        push!(network_parameters, getproperty(network, names)[:])
+    end
+    param_lengths = [length(params) for params in network_parameters]
+    parameter_list = zeros(sum(param_lengths))
+    start = 1
+    for i in 1:length(param_lengths)
+        parameter_list[start:start+param_lengths[i]-1] .= network_parameters[i]
+        start += param_lengths[i]
+    end
+    return parameter_list
+end
+
+function set_parameters!(network::SimpleNetwork, parameters_list)
+    param_lengths = Int64[]
+    for names in propertynames(network)
+        push!(param_lengths, length(getproperty(network, names)[:]))
+    end
+    start = 1
+    for (i, names) in enumerate(propertynames(network))
+        getproperty(network, names)[:] .= parameters_list[start:start+param_lengths[i]-1]
+    end
+    return nothing
+end
+
+function Adam(network::SimpleNetwork)
+    parameters_list = (; α = 0.001, β₁ = 0.9, β₂ = 0.999, ϵ = 1e-8)
+    network_parameters = parameters(network)
+    t = [1.0]
+    θ = deepcopy(network_parameters) .* 0.0
+    gₜ = deepcopy(network_parameters) .* 0.0
+    m₀ = deepcopy(network_parameters) .* 0.0
+    mₜ = deepcopy(network_parameters) .* 0.0
+    v₀ = deepcopy(network_parameters) .* 0.0
+    vₜ = deepcopy(network_parameters) .* 0.0
+    v̂ₜ = deepcopy(network_parameters) .* 0.0
+    m̂ₜ = deepcopy(network_parameters) .* 0.0
+    struct_copies = (; θ, gₜ, m₀, mₜ, v₀, vₜ, v̂ₜ, m̂ₜ)
+    return Adam(struct_copies, parameters_list, t)
+end
+
+function update!(adam::Adam, network::SimpleNetwork, dnetwork::SimpleNetwork)
+    # unpack
+    (; α, β₁, β₂, ϵ) = adam.parameters
+    t = adam.t[1]
+    (; θ, gₜ, m₀, mₜ, v₀, vₜ, v̂ₜ, m̂ₜ) = adam.struct_copies
+    t = adam.t[1]
+    # get gradient
+    θ .= parameters(network)
+    gₜ .= parameters(dnetwork)
+    # update
+    @. m₀ = β₁ * m₀
+    @. mₜ = m₀ + (1 - β₁) * gₜ
+    @. v₀ = β₂ * v₀
+    @. vₜ = v₀ + (1 - β₂) * (gₜ .^2)
+    @. v̂ₜ = vₜ / (1 - β₂^t)
+    @. m̂ₜ = mₜ / (1 - β₁^t)
+    @. θ = θ - α * m̂ₜ / (sqrt(v̂ₜ) + ϵ)
+    # update parameters
+    m₀ .= mₜ
+    v₀ .= vₜ
+    adam.t[1] += 1
+    set_parameters!(network, θ)
+    return nothing
+end
diff --git a/emulate_sample/simple_networks.jl b/emulate_sample/simple_networks.jl
new file mode 100644
index 0000000..be9bda0
--- /dev/null
+++ b/emulate_sample/simple_networks.jl
@@ -0,0 +1,41 @@
+using Enzyme
+
+abstract type SimpleNetwork end
+
+struct OneLayerNetwork{M, V} <: SimpleNetwork
+    W1::M
+    b1::V
+    W2::M
+    b2::V
+end
+
+function zero!(dnetwork::OneLayerNetwork)
+    dnetwork.W1 .= 0.0
+    dnetwork.b1 .= 0.0
+    dnetwork.W2 .= 0.0
+    dnetwork.b2 .= 0.0
+    return nothing
+end
+
+function update!(network::SimpleNetwork, dnetwork::SimpleNetwork, η)
+    network.W1 .-= η .* dnetwork.W1
+    network.b1 .-= η .* dnetwork.b1
+    network.W2 .-= η .* dnetwork.W2
+    network.b2 .-= η .* dnetwork.b2
+    return nothing
+end
+
+swish(x) = x / (1 + exp(-x))
+activation_function(x) = tanh(x)
+
+function predict(network::OneLayerNetwork, x)
+    return abs.(network.W2 * activation_function.(network.W1 * x .+ network.b1) .+ network.b2)
+end
+
+function predict(network::OneLayerNetwork, x, activation::Function)
+    return abs.(network.W2 * activation.(network.W1 * x .+ network.b1) .+ network.b2)
+end
+
+function (network::SimpleNetwork)(x)
+    return predict(network, x)
+end
\ No newline at end of file

From f4f69605bda506cfd7fdc245fd2dd0ef61aedecb Mon Sep 17 00:00:00 2001
From: Andre Souza
Date: Wed, 27 Nov 2024 18:50:30 -0500
Subject: [PATCH 2/5] empty data directory

---
 emulate_sample/emulate_sample_catke.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/emulate_sample/emulate_sample_catke.jl b/emulate_sample/emulate_sample_catke.jl
index 13b5052..636e382 100644
--- a/emulate_sample/emulate_sample_catke.jl
+++ b/emulate_sample/emulate_sample_catke.jl
@@ -6,7 +6,7 @@ include("simple_networks.jl")
 include("hmc_interface.jl")
 include("optimization_utils.jl")
 
-data_directory = "/Users/andresouza/Desktop/Repositories/GenericAnalysis.jl/"
+data_directory = ""
 data_file = "catke_parameters.jld2"
 
 jlfile = jldopen(data_directory * data_file, "r")

From 192f01dfef18841a2e9076739341fb719f72fe7f Mon Sep 17 00:00:00 2001
From: Andre Souza
Date: Wed, 27 Nov 2024 19:01:45 -0500
Subject: [PATCH 3/5] add regularization

---
 emulate_sample/emulate_sample_catke.jl | 19 ++++++++++++++++---
 1 file changed, 16 insertions(+), 3 deletions(-)

diff --git a/emulate_sample/emulate_sample_catke.jl b/emulate_sample/emulate_sample_catke.jl
index 636e382..b539cae 100644
--- a/emulate_sample/emulate_sample_catke.jl
+++ b/emulate_sample/emulate_sample_catke.jl
@@ -52,13 +52,26 @@ end
 
 ## Sample
 # HMC
-scale = 1e3
-U = LogDensity(network, scale)
+scale = 1.0e2
+
+function regularize(x)
+    if any(x .≤ 0.0)
+        return -Inf
+    elseif any(x .> 8.0)
+        return -Inf
+    else
+        return 0.0
+    end
+    return
+end
+regularization_scale = 1.0
+
+U = LogDensity(network, regularize, scale, regularization_scale)
 ∇U = GradientLogDensity(U)
 
 D = size(θr, 2)
 initial_θ = copy(θr[end, :])
-n_samples = 10000
+n_samples = 1000
 n_adapts = 1000
 
 metric = DiagEuclideanMetric(D)

From ac14f4054f3c361cebb51cbd7dff93b21508 Mon Sep 17 00:00:00 2001
From: Andre Souza
Date: Fri, 29 Nov 2024 17:53:47 -0500
Subject: [PATCH 4/5] updates

---
 emulate_sample/emulate_sample_catke.jl | 90 +++++++++++++++++++-------
 emulate_sample/hmc_interface.jl        |  6 +-
 emulate_sample/optimization_utils.jl   | 20 +++---
 emulate_sample/simple_networks.jl      | 28 +++++++-
 4 files changed, 110 insertions(+), 34 deletions(-)

diff --git a/emulate_sample/emulate_sample_catke.jl b/emulate_sample/emulate_sample_catke.jl
index b539cae..85a15e3 100644
--- a/emulate_sample/emulate_sample_catke.jl
+++ b/emulate_sample/emulate_sample_catke.jl
@@ -14,64 +14,103 @@ jlfile = jldopen(data_directory * data_file, "r")
 y = jlfile["objectives"]
 close(jlfile)
 
-θr = reshape(θ, (size(θ)[1] * size(θ)[2], size(θ)[3]))
-yr = reshape(y, (size(y)[1] * size(y)[2]))
+θ̄ = mean(θ)
+θ̃ = std(θ)
+ymax = maximum(y)
+ymin = minimum(y)
+yshift = ymin # ymin / 2 #
+Δy = ymax - ymin # 2 * std(y) #
+θr = (reshape(θ, (size(θ)[1] * size(θ)[2], size(θ)[3])) .- θ̄ ) ./ (2θ̃)
+yr = (reshape(y, (size(y)[1] * size(y)[2])) .- yshift ) ./ Δy
 M = size(θr)[1]
 Mᴾ = size(θr)[2]
 
 # Define Network
 Nθ = size(θr, 2)
-Nθᴴ = Nθ * 10
+Nθᴴ = Nθ ÷ 20
 W1 = randn(Nθᴴ, Nθ)
 b1 = randn(Nθᴴ)
 W2 = randn(1, Nθᴴ)
 b2 = randn(1)
+W3 = randn(1, Nθ)
+b3 = randn(1)
 
-network = OneLayerNetwork(W1, b1, W2, b2)
+network = OneLayerNetworkWithLinearByPass(W1, b1, W2, b2, W3, b3)
 dnetwork = deepcopy(network)
+smoothed_network = deepcopy(network)
 
 ## Emulate
-# Optimize with Gradient Descent and Learning rate 1e-5
-batchsize = 10
+adam = Adam(network)
+batchsize = 100
 loss_list = Float64[]
-epochs = 10
+test_loss_list = Float64[]
+epochs = 20
+network_parameters = copy(parameters(network))
 for i in ProgressBar(1:epochs)
-    shuffled_list = chunk_list(shuffle(1:M), batchsize)
+    shuffled_list = chunk_list(shuffle(1:2:M), batchsize)
+    shuffled_test_list = chunk_list(shuffle(2:2:M), batchsize)
     loss_value = 0.0
     N = length(shuffled_list)
+    # Batched Gradient Descent and Loss Evaluation
     for permuted_list in ProgressBar(shuffled_list)
         θbatch = [θr[x, :] for x in permuted_list]
         ybatch = yr[permuted_list]
         zero!(dnetwork)
         autodiff(Enzyme.Reverse, loss, Active, DuplicatedNoNeed(network, dnetwork), Const(θbatch), Const(ybatch))
-        update!(network, dnetwork, 1e-5)
+        update!(adam, network, dnetwork)
         loss_value += loss(network, θbatch, ybatch) / N
     end
     push!(loss_list, loss_value)
+    # Test Loss
+    loss_value = 0.0
+    N = length(shuffled_test_list)
+    for permuted_list in shuffled_test_list
+        θbatch = [θr[x, :] for x in permuted_list]
+        ybatch = yr[permuted_list]
+        loss_value += loss(network, θbatch, ybatch) / N
+    end
+    push!(test_loss_list, loss_value)
+    # Weighted Averaging of Network
+    m = 0.9
+    network_parameters .= m * network_parameters + (1-m) * parameters(network)
+    set_parameters!(smoothed_network, network_parameters)
 end
 
+loss_fig = Figure()
+ax = Axis(loss_fig[1, 1]; title = "Log10 Loss", xlabel = "Epoch", ylabel = "Loss")
+scatter!(ax, log10.(loss_list); color = :blue, label = "Training Loss")
+scatter!(ax, log10.(test_loss_list); color = :red, label = "Test Loss")
+axislegend(ax, position = :rt)
+display(loss_fig)
+
 ## Sample
-# HMC
-scale = 1.0e2
+# Define logp and ∇logp and regularizer
+
+initial_θ = copy(θr[argmin(yr), :])
+mintheta = minimum(θr, dims = 1)[:]
+maxtheta = maximum(θr, dims = 1)[:]
+reg = Regularizer([mintheta, maxtheta, initial_θ])
 
-function regularize(x)
-    if any(x .≤ 0.0)
+function (regularizer::Regularizer)(x)
+    if any(x .≤ regularizer.parameters[1])
         return -Inf
-    elseif any(x .> 8.0)
+    elseif any(x .> regularizer.parameters[2])
         return -Inf
     else
-        return 0.0
+        return -sum(abs.(x - regularizer.parameters[3]))
     end
-    return
+    return 0.0
 end
-regularization_scale = 1.0
+
+scale = 10 * Δy # minimum(yr)
+regularization_scale = 0.001 * scale
 
-U = LogDensity(network, regularize, scale, regularization_scale)
+U = LogDensity(network, reg, scale, regularization_scale)
 ∇U = GradientLogDensity(U)
 
+# HMC
 D = size(θr, 2)
-initial_θ = copy(θr[end, :])
-n_samples = 1000
+n_samples = 10000
 n_adapts = 1000
 
 metric = DiagEuclideanMetric(D)
@@ -95,8 +134,15 @@ for i in 1:23
     ii = (i-1)÷Mp + 1
     jj = (i-1)%Mp + 1
     ax = Axis(fig[ii, jj]; title = "Parameter $i")
-    v1 = [sample[i] for sample in samples]
+    v1 = ([sample[i] for sample in samples] .* 2θ̃) .+ θ̄
     hist!(ax, v1; bins = 50, strokewidth = 0, color = :blue, normalization = :pdf)
+    xlims!(ax, -0.1, (reg.parameters[2][i]* 2θ̃ + θ̄) * 1.1)
     density!(ax, v1; color = (:red, 0.1), strokewidth = 3, strokecolor = :red)
 end
-display(fig)
\ No newline at end of file
+display(fig)
+
+imin = argmax([stat.log_density for stat in stats])
+imax = argmin([stat.log_density for stat in stats])
+network(samples[imin])
+mean(samples) - initial_θ
+samples[imin] - initial_θ
\ No newline at end of file
diff --git a/emulate_sample/hmc_interface.jl b/emulate_sample/hmc_interface.jl
index af8ff09..ca7f1e7 100644
--- a/emulate_sample/hmc_interface.jl
+++ b/emulate_sample/hmc_interface.jl
@@ -35,4 +35,8 @@ function (∇logp::GradientLogDensity)(θ)
     ∇logp.dθ .= 0.0
     autodiff(Enzyme.Reverse, Const(∇logp.logp), Active, DuplicatedNoNeed(θ, ∇logp.dθ))
     return (∇logp.logp(θ), copy(∇logp.dθ))
-end
\ No newline at end of file
+end
+
+struct Regularizer{F}
+    parameters::F
+end
diff --git a/emulate_sample/optimization_utils.jl b/emulate_sample/optimization_utils.jl
index 01095a8..7b44851 100644
--- a/emulate_sample/optimization_utils.jl
+++ b/emulate_sample/optimization_utils.jl
@@ -7,7 +7,7 @@ function loss(network::SimpleNetwork, x, y)
 end
 
 function chunk_list(list, n)
-    return [list[i:i+n-1] for i in 1:n:length(list)]
+    return [list[i:min(i+n-1, length(list))] for i in 1:n:length(list)]
 end
 
 struct Adam{S, T, I}
@@ -39,12 +39,13 @@ function set_parameters!(network::SimpleNetwork, parameters_list)
     start = 1
     for (i, names) in enumerate(propertynames(network))
         getproperty(network, names)[:] .= parameters_list[start:start+param_lengths[i]-1]
+        start = start + param_lengths[i]
     end
     return nothing
 end
 
-function Adam(network::SimpleNetwork)
-    parameters_list = (; α = 0.001, β₁ = 0.9, β₂ = 0.999, ϵ = 1e-8)
+function Adam(network::SimpleNetwork; α=0.001, β₁=0.9, β₂=0.999, ϵ=1e-8)
+    parameters_list = (; α, β₁, β₂, ϵ)
     network_parameters = parameters(network)
     t = [1.0]
     θ = deepcopy(network_parameters) .* 0.0
@@ -59,6 +60,7 @@ function Adam(network::SimpleNetwork)
     return Adam(struct_copies, parameters_list, t)
 end
+
 
 function update!(adam::Adam, network::SimpleNetwork, dnetwork::SimpleNetwork)
     # unpack
     (; α, β₁, β₂, ϵ) = adam.parameters
@@ -69,17 +71,15 @@ function update!(adam::Adam, network::SimpleNetwork, dnetwork::SimpleNetwork)
     θ .= parameters(network)
     gₜ .= parameters(dnetwork)
     # update
-    @. m₀ = β₁ * m₀
-    @. mₜ = m₀ + (1 - β₁) * gₜ
-    @. v₀ = β₂ * v₀
-    @. vₜ = v₀ + (1 - β₂) * (gₜ .^2)
-    @. v̂ₜ = vₜ / (1 - β₂^t)
+    @. mₜ = β₁ * m₀ + (1 - β₁) * gₜ
+    @. vₜ = β₂ * v₀ + (1 - β₂) * (gₜ .^2)
     @. m̂ₜ = mₜ / (1 - β₁^t)
-    @. θ = θ - α * m̂ₜ / (sqrt(v̂ₜ) + ϵ)
+    @. v̂ₜ = vₜ / (1 - β₂^t)
+    @. θ = θ - α * m̂ₜ / (sqrt(v̂ₜ) + ϵ)
     # update parameters
     m₀ .= mₜ
     v₀ .= vₜ
     adam.t[1] += 1
     set_parameters!(network, θ)
     return nothing
-end
+end
\ No newline at end of file
diff --git a/emulate_sample/simple_networks.jl b/emulate_sample/simple_networks.jl
index be9bda0..8b921f5 100644
--- a/emulate_sample/simple_networks.jl
+++ b/emulate_sample/simple_networks.jl
@@ -9,6 +9,25 @@ struct OneLayerNetwork{M, V} <: SimpleNetwork
     b2::V
 end
 
+struct OneLayerNetworkWithLinearByPass{M,V} <: SimpleNetwork
+    W1::M
+    b1::V
+    W2::M
+    b2::V
+    W3::M
+    b3::V
+end
+
+function zero!(dnetwork::OneLayerNetworkWithLinearByPass)
+    dnetwork.W1 .= 0.0
+    dnetwork.b1 .= 0.0
+    dnetwork.W2 .= 0.0
+    dnetwork.b2 .= 0.0
+    dnetwork.W3 .= 0.0
+    dnetwork.b3 .= 0.0
+    return nothing
+end
+
 function zero!(dnetwork::OneLayerNetwork)
     dnetwork.W1 .= 0.0
     dnetwork.b1 .= 0.0
@@ -26,12 +45,19 @@ function update!(network::SimpleNetwork, dnetwork::SimpleNetwork, η)
 end
 
 swish(x) = x / (1 + exp(-x))
-activation_function(x) = tanh(x)
+activation_function(x) = swish(x) # tanh(x) #
 
 function predict(network::OneLayerNetwork, x)
     return abs.(network.W2 * activation_function.(network.W1 * x .+ network.b1) .+ network.b2)
 end
 
+function predict(network::OneLayerNetworkWithLinearByPass, x)
+    y1 = network.W1 * x .+ network.b1
+    y2 = network.W2 * activation_function.(y1) .+ network.b2
+    y3 = network.W3 * x .+ network.b3
+    return abs.(y3) .+ abs.(y2)
+end
+
 function predict(network::OneLayerNetwork, x, activation::Function)
     return abs.(network.W2 * activation.(network.W1 * x .+ network.b1) .+ network.b2)
 end

From 591ca919d69f4284c7f3b31448ce13733232371b Mon Sep 17 00:00:00 2001
From: Andre Souza
Date: Mon, 2 Dec 2024 12:46:50 -0500
Subject: [PATCH 5/5] small update to regularization

---
 emulate_sample/emulate_sample_catke.jl | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/emulate_sample/emulate_sample_catke.jl b/emulate_sample/emulate_sample_catke.jl
index 85a15e3..cc14532 100644
--- a/emulate_sample/emulate_sample_catke.jl
+++ b/emulate_sample/emulate_sample_catke.jl
@@ -27,7 +27,7 @@ Mᴾ = size(θr)[2]
 
 # Define Network
 Nθ = size(θr, 2)
-Nθᴴ = Nθ ÷ 20
+Nθᴴ = Nθ ÷ 2
 W1 = randn(Nθᴴ, Nθ)
 b1 = randn(Nθᴴ)
 W2 = randn(1, Nθᴴ)
@@ -44,7 +44,7 @@ adam = Adam(network)
 batchsize = 100
 loss_list = Float64[]
 test_loss_list = Float64[]
-epochs = 20
+epochs = 100
 network_parameters = copy(parameters(network))
 for i in ProgressBar(1:epochs)
     shuffled_list = chunk_list(shuffle(1:2:M), batchsize)
@@ -97,13 +97,13 @@ function (regularizer::Regularizer)(x)
     elseif any(x .> regularizer.parameters[2])
         return -Inf
     else
-        return -sum(abs.(x - regularizer.parameters[3]))
+        return -sum(abs.(x - regularizer.parameters[3]) ./ (regularizer.parameters[2] - regularizer.parameters[1]))
     end
     return 0.0
 end
 
 scale = 10 * Δy # minimum(yr)
-regularization_scale = 0.001 * scale
+regularization_scale = 0.001/2 * scale
 
 U = LogDensity(network, reg, scale, regularization_scale)
 ∇U = GradientLogDensity(U)
@@ -144,5 +144,6 @@
 imin = argmax([stat.log_density for stat in stats])
 imax = argmin([stat.log_density for stat in stats])
 network(samples[imin])
-mean(samples) - initial_θ
-samples[imin] - initial_θ
\ No newline at end of file
+θ₀ = (initial_θ .* 2θ̃) .+ θ̄
+((mean(samples) .* 2θ̃) .+ θ̄) - θ₀
+((samples[imin] .* 2θ̃) .+ θ̄) - θ₀
\ No newline at end of file