Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add slurm workers for calibration end-to-end test #3461

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .buildkite/pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -569,6 +569,11 @@ steps:
command: julia --project=calibration/test calibration/test/interface.jl
- label: "end to end test"
command: julia --project=calibration/test calibration/test/e2e_test.jl
agents:
slurm_ntasks: 12
slurm_cpus_per_task: 1
slurm_mem: 96GB
slurm_time: "00:10:00"
artifact_paths: "calibration_end_to_end_test/*"
soft_fail: true

Expand Down
31 changes: 0 additions & 31 deletions .github/workflows/calibration_test.yml

This file was deleted.

3 changes: 2 additions & 1 deletion calibration/test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@ ClimaAtmos = "b2c96348-7fb7-4fe0-8da9-78d88439e717"
ClimaCalibrate = "4347a170-ebd6-470c-89d3-5c705c0cacc2"
ClimaComms = "3a4d1b5c-c61d-41fd-a00a-5873ba7a1b0d"
ClimaUtilities = "b3f4f4ca-9299-4f7f-bd9b-81e1242a7513"
ClusterManagers = "34f1f09b-3a8b-5176-ab39-66d58a4d544e"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
EnsembleKalmanProcesses = "aa8a2aa5-91d8-4396-bcef-d4f2ec43552d"
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
YAML = "ddb6d928-2868-570f-bddf-ab3f9cf99eb6"

[compat]
ClimaCalibrate = "0.0.2 - 0.0.4"
ClimaCalibrate = "0.0.5"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This file misses a newline

290 changes: 149 additions & 141 deletions calibration/test/e2e_test.jl
Original file line number Diff line number Diff line change
@@ -1,67 +1,11 @@
#= End-to-end test
Runs a perfect model calibration, calibrating on the parameter `astronomical_unit`
with top-of-atmosphere radiative shortwave flux in the loss function.

The calibration is run twice, once on the backend obtained via `get_backend()`
and once on the `JuliaBackend`. The output of each calibration is tested individually
and compared to ensure reproducibility.
=#
using Distributed, ClusterManagers
import ClimaCalibrate as CAL
import ClimaAtmos as CA
import ClimaAnalysis: SimDir, get, slice, average_xy
import CairoMakie
import JLD2
import LinearAlgebra: I
import EnsembleKalmanProcesses as EKP
import Statistics: var, mean
using Test

using Dates

# Debug plots
function scatter_plot(eki::EKP.EnsembleKalmanProcess)
f = CairoMakie.Figure(resolution = (800, 600))
ax = CairoMakie.Axis(
f[1, 1],
ylabel = "Parameter Value",
xlabel = "Top of atmosphere radiative SW flux",
)

g = vec.(EKP.get_g(eki; return_array = true))
params = vec.((EKP.get_ϕ(prior, eki)))

for (gg, uu) in zip(g, params)
CairoMakie.scatter!(ax, gg, uu)
end

CairoMakie.hlines!(ax, [astronomical_unit], linestyle = :dash)
CairoMakie.vlines!(ax, observations, linestyle = :dash)

output = joinpath(output_dir, "scatter.png")
CairoMakie.save(output, f)
return output
end

function param_versus_iter_plot(eki::EKP.EnsembleKalmanProcess)
f = CairoMakie.Figure(resolution = (800, 600))
ax = CairoMakie.Axis(
f[1, 1],
ylabel = "Parameter Value",
xlabel = "Iteration",
)
params = EKP.get_ϕ(prior, eki)
for (i, param) in enumerate(params)
CairoMakie.scatter!(ax, fill(i, length(param)), vec(param))
end

CairoMakie.hlines!(ax, [astronomical_unit]; color = :red, linestyle = :dash)

output = joinpath(output_dir, "param_vs_iter.png")
CairoMakie.save(output, f)
return output
end

# Observation map
function CAL.observation_map(iteration)
single_member_dims = (1,)
G_ensemble = Array{Float64}(undef, single_member_dims..., ensemble_size)
Expand All @@ -86,113 +30,177 @@ function process_member_data(simdir::SimDir)
return slice(average_xy(rsut); time = 30).data
end

# EKI test
function minimal_eki_test(eki)
params = EKP.get_ϕ(prior, eki)
spread = map(var, params)
# TODO: These functions will be moved to ClimaCalibrate
function run_iteration(config, iter, worker_pool)
# Create a channel to collect results
results = Channel{Any}(config.ensemble_size)
@sync begin
for m in 1:(config.ensemble_size)
@async begin
# Get a worker from the pool
worker = take!(worker_pool)
try
@info "Running member $m on worker $worker"
# Run the model and put result in channel
model_config = CAL.set_up_forward_model(m, iter, config)
result = remotecall_fetch(
CAL.run_forward_model,
worker,
model_config,
)
put!(results, (m, result))
catch e
@error "Error running member $m" exception = e
put!(results, (m, e))
finally
# Always return worker to pool
put!(worker_pool, worker)
end
end
end
end

# Spread should be heavily decreased as particles have converged
@test last(spread) / first(spread) < 0.1
# Parameter should be close to true value
@test mean(last(params)) ≈ astronomical_unit rtol = 0.02
# Collect all results
ensemble_results = Dict{Int, Any}()
for _ in 1:(config.ensemble_size)
m, result = take!(results)
if result isa Exception
@error "Member $m failed" error = result
else
ensemble_results[m] = result
end
end
results = values(ensemble_results)
all(isa.(results, Exception)) &&
error("Full ensemble for iter $iter failed")

# Process results
G_ensemble = CAL.observation_map(iter)
CAL.save_G_ensemble(config, iter, G_ensemble)
CAL.update_ensemble(config, iter)
iter_path = CAL.path_to_iteration(config.output_dir, iter)
return JLD2.load_object(joinpath(iter_path, "eki_file.jld2"))
end

# Script:

if !(@isdefined backend)
backend = CAL.get_backend()
end
# Check that the wait time for the last hour does not exceed 20 minutes.
# This test schedules many slurm jobs and will be prohibitively slow if the cluster is busy
if backend <: CAL.HPCBackend
wait_times = readchomp(
`sacct --allocations -u esmbuild --starttime now-1hour -o Submit,Start -n`,
)
wait_times = split(wait_times, '\n', keepempty = false)
# Filter jobs that have not been submitted and started
filter!(x -> !(contains(x, "Unknown") || contains(x, "None")), wait_times)

mean_wait_time_in_mins =
mapreduce(+, wait_times; init = 0) do line
t1_str, t2_str = split(line)
t1 = DateTime(t1_str, dateformat"yyyy-mm-ddTHH:MM:SS")
t2 = DateTime(t2_str, dateformat"yyyy-mm-ddTHH:MM:SS")
Dates.value(t2 - t1) / 1000 / 60
end / length(wait_times)

@show mean_wait_time_in_mins

if mean_wait_time_in_mins > 10
@warn """Average wait time for esmbuild is $(round(mean_wait_time_in_mins, digits=2)) minutes. \
Cluster is too busy to run this test, exiting"""
exit()
function calibrate(config, worker_pool)
CAL.initialize(config)
for iter in 0:(config.n_iterations)
(; time) = @timed run_iteration(config, iter, worker_pool)
@info "Iteration $iter time: $time"
end
end

# Paths and setup
const experiment_dir = joinpath(pkgdir(CA), "calibration", "test")
const model_interface =
joinpath(pkgdir(CA), "calibration", "model_interface.jl")
const output_dir = "calibration_end_to_end_test"
include(model_interface)
ensemble_size = 15
addprocs(
SlurmManager(10),
t = "00:20:00",
cpus_per_task = 1,
exeflags = "--project=$(Base.active_project())",
)
worker_pool = WorkerPool(workers())

@everywhere println(string(myid()))

@everywhere begin
import ClimaCalibrate as CAL
import ClimaAtmos as CA

experiment_dir = joinpath(pkgdir(CA), "calibration", "test")
model_interface = joinpath(pkgdir(CA), "calibration", "model_interface.jl")
output_dir = "calibration_end_to_end_test"
include(model_interface)
ensemble_size = 50
obs_path = joinpath(experiment_dir, "observations.jld2")
end

# Generate observations
obs_path = joinpath(experiment_dir, "observations.jld2")
# Generate observations if needed
if !isfile(obs_path)
import JLD2
@info "Generating observations"
config = CA.AtmosConfig(joinpath(experiment_dir, "model_config.yml"))
simulation = CA.get_simulation(config)
atmos_config = CA.AtmosConfig(joinpath(experiment_dir, "model_config.yml"))
simulation = CA.get_simulation(atmos_config)
CA.solve_atmos!(simulation)
observations = Vector{Float64}(undef, 1)
observations .= process_member_data(SimDir(simulation.output_dir))
JLD2.save_object(obs_path, observations)
end

# Initialize experiment data
astronomical_unit = 149_597_870_000
observations = JLD2.load_object(obs_path)
noise = 0.1 * I
n_iterations = 4
prior = CAL.get_prior(joinpath(experiment_dir, "prior.toml"))
experiment_config = CAL.ExperimentConfig(;
n_iterations,
ensemble_size,
observations,
noise,
output_dir,
prior,
)

@info "Running calibration E2E test" backend
if backend <: CAL.HPCBackend
test_eki = CAL.calibrate(
backend,
experiment_config;
hpc_kwargs = CAL.kwargs(time = 15),
model_interface,
verbose = true,
@everywhere begin
import JLD2
import LinearAlgebra: I
astronomical_unit = 149_597_870_000
observations = JLD2.load_object(obs_path)
noise = 0.1 * I
n_iterations = 10
prior = CAL.get_prior(joinpath(experiment_dir, "prior.toml"))

config = CAL.ExperimentConfig(;
n_iterations,
ensemble_size,
observations,
noise,
output_dir,
prior,
)
else
test_eki = CAL.calibrate(backend, experiment_config)
end

scatter_plot(test_eki)
param_versus_iter_plot(test_eki)
calibrate(config, worker_pool)

@testset "Test Calibration on $backend" begin
minimal_eki_test(test_eki)
end
import EnsembleKalmanProcesses as EKP
import Statistics: var, mean
using Test
import CairoMakie

function scatter_plot(eki::EKP.EnsembleKalmanProcess)
f = CairoMakie.Figure(resolution = (800, 600))
ax = CairoMakie.Axis(
f[1, 1],
ylabel = "Parameter Value",
xlabel = "Top of atmosphere radiative SW flux",
)

g = vec.(EKP.get_g(eki; return_array = true))
params = vec.((EKP.get_ϕ(prior, eki)))

for (gg, uu) in zip(g, params)
CairoMakie.scatter!(ax, gg, uu)
end

# Run calibration
julia_eki = CAL.calibrate(CAL.JuliaBackend, experiment_config)
CairoMakie.hlines!(ax, [astronomical_unit], linestyle = :dash)
CairoMakie.vlines!(ax, observations, linestyle = :dash)

@testset "Julia-only comparison calibration" begin
minimal_eki_test(julia_eki)
output = joinpath(output_dir, "scatter.png")
CairoMakie.save(output, f)
return output
end

@testset "Compare $backend output to JuliaBackend" begin
for (uu, slurm_uu) in zip(EKP.get_u(julia_eki), EKP.get_u(test_eki))
@test uu ≈ slurm_uu rtol = 0.02
function param_versus_iter_plot(eki::EKP.EnsembleKalmanProcess)
f = CairoMakie.Figure(resolution = (800, 600))
ax = CairoMakie.Axis(
f[1, 1],
ylabel = "Parameter Value",
xlabel = "Iteration",
)
params = EKP.get_ϕ(prior, eki)
for (i, param) in enumerate(params)
CairoMakie.scatter!(ax, fill(i, length(param)), vec(param))
end

CairoMakie.hlines!(ax, [astronomical_unit]; color = :red, linestyle = :dash)

output = joinpath(output_dir, "param_vs_iter.png")
CairoMakie.save(output, f)
return output
end

eki = JLD2.load_object(joinpath(output_dir, "iteration_011", "eki_file.jld2"))
scatter_plot(eki)
param_versus_iter_plot(eki)

params = EKP.get_ϕ(prior, eki)
spread = map(var, params)

# Spread should be heavily decreased as particles have converged
@test last(spread) / first(spread) < 0.1
# Parameter should be close to true value
@test mean(last(params)) ≈ astronomical_unit rtol = 0.02
Loading