From 8cbe9106458df39b01ba0313db6251f814d9f01b Mon Sep 17 00:00:00 2001 From: nefrathenrici Date: Mon, 2 Dec 2024 10:26:07 -0800 Subject: [PATCH] Add slurm workers for calibration end-to-end test --- .buildkite/pipeline.yml | 5 + .github/workflows/calibration_test.yml | 31 --- calibration/test/Project.toml | 3 +- calibration/test/e2e_test.jl | 290 +++++++++++++------------ 4 files changed, 156 insertions(+), 173 deletions(-) delete mode 100644 .github/workflows/calibration_test.yml diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index f399ec8e78..f71fa9571f 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -569,6 +569,11 @@ steps: command: julia --project=calibration/test calibration/test/interface.jl - label: "end to end test" command: julia --project=calibration/test calibration/test/e2e_test.jl + agents: + slurm_ntasks: 12 + slurm_cpus_per_task: 1 + slurm_mem: 96GB + slurm_time: "00:10:00" artifact_paths: "calibration_end_to_end_test/*" soft_fail: true diff --git a/.github/workflows/calibration_test.yml b/.github/workflows/calibration_test.yml deleted file mode 100644 index 2afece5204..0000000000 --- a/.github/workflows/calibration_test.yml +++ /dev/null @@ -1,31 +0,0 @@ -name: Calibration -on: - push: - tags: '*' - pull_request: - merge_group: - -concurrency: - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: true - -# Needed to allow julia-actions/cache to delete old caches that it has created -permissions: - actions: write - contents: read - -jobs: - test: - name: ClimaCalibrate E2E Test - runs-on: ubuntu-latest - timeout-minutes: 30 - steps: - - uses: julia-actions/cache@v2 - - uses: julia-actions/julia-buildpkg@v1 - - uses: actions/checkout@v4 - - uses: julia-actions/setup-julia@v2 - with: - version: '1.10' - - run: | - julia --project=calibration/test -e 'using Pkg; Pkg.develop(;path="."); Pkg.instantiate(;verbose=true)' - julia --project=calibration/test calibration/test/e2e_test.jl diff --git a/calibration/test/Project.toml b/calibration/test/Project.toml index f76461534a..105b019df7 100644 --- a/calibration/test/Project.toml +++ b/calibration/test/Project.toml @@ -5,6 +5,7 @@ ClimaAtmos = "b2c96348-7fb7-4fe0-8da9-78d88439e717" ClimaCalibrate = "4347a170-ebd6-470c-89d3-5c705c0cacc2" ClimaComms = "3a4d1b5c-c61d-41fd-a00a-5873ba7a1b0d" ClimaUtilities = "b3f4f4ca-9299-4f7f-bd9b-81e1242a7513" +ClusterManagers = "34f1f09b-3a8b-5176-ab39-66d58a4d544e" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" EnsembleKalmanProcesses = "aa8a2aa5-91d8-4396-bcef-d4f2ec43552d" JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819" @@ -12,4 +13,4 @@ MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195" YAML = "ddb6d928-2868-570f-bddf-ab3f9cf99eb6" [compat] -ClimaCalibrate = "0.0.2 - 0.0.4" +ClimaCalibrate = "0.0.5" \ No newline at end of file diff --git a/calibration/test/e2e_test.jl b/calibration/test/e2e_test.jl index 7ea3223e39..52bc7aa71b 100644 --- a/calibration/test/e2e_test.jl +++ b/calibration/test/e2e_test.jl @@ -1,67 +1,11 @@ #= End-to-end test Runs a perfect model calibration, calibrating on the parameter `astronomical_unit` with top-of-atmosphere radiative shortwave flux in the loss function. - -The calibration is run twice, once on the backend obtained via `get_backend()` -and once on the `JuliaBackend`. The output of each calibration is tested individually -and compared to ensure reproducibility. =# +using Distributed, ClusterManagers import ClimaCalibrate as CAL -import ClimaAtmos as CA import ClimaAnalysis: SimDir, get, slice, average_xy -import CairoMakie -import JLD2 -import LinearAlgebra: I -import EnsembleKalmanProcesses as EKP -import Statistics: var, mean -using Test - -using Dates - -# Debug plots -function scatter_plot(eki::EKP.EnsembleKalmanProcess) - f = CairoMakie.Figure(resolution = (800, 600)) - ax = CairoMakie.Axis( - f[1, 1], - ylabel = "Parameter Value", - xlabel = "Top of atmosphere radiative SW flux", - ) - - g = vec.(EKP.get_g(eki; return_array = true)) - params = vec.((EKP.get_ϕ(prior, eki))) - - for (gg, uu) in zip(g, params) - CairoMakie.scatter!(ax, gg, uu) - end - - CairoMakie.hlines!(ax, [astronomical_unit], linestyle = :dash) - CairoMakie.vlines!(ax, observations, linestyle = :dash) - - output = joinpath(output_dir, "scatter.png") - CairoMakie.save(output, f) - return output -end - -function param_versus_iter_plot(eki::EKP.EnsembleKalmanProcess) - f = CairoMakie.Figure(resolution = (800, 600)) - ax = CairoMakie.Axis( - f[1, 1], - ylabel = "Parameter Value", - xlabel = "Iteration", - ) - params = EKP.get_ϕ(prior, eki) - for (i, param) in enumerate(params) - CairoMakie.scatter!(ax, fill(i, length(param)), vec(param)) - end - CairoMakie.hlines!(ax, [astronomical_unit]; color = :red, linestyle = :dash) - - output = joinpath(output_dir, "param_vs_iter.png") - CairoMakie.save(output, f) - return output -end - -# Observation map function CAL.observation_map(iteration) single_member_dims = (1,) G_ensemble = Array{Float64}(undef, single_member_dims..., ensemble_size) @@ -86,63 +30,94 @@ function process_member_data(simdir::SimDir) return slice(average_xy(rsut); time = 30).data end -# EKI test -function minimal_eki_test(eki) - params = EKP.get_ϕ(prior, eki) - spread = map(var, params) +# TODO: These functions will be moved to ClimaCalibrate +function run_iteration(config, iter, worker_pool) + # Create a channel to collect results + results = Channel{Any}(config.ensemble_size) + @sync begin + for m in 1:(config.ensemble_size) + @async begin + # Get a worker from the pool + worker = take!(worker_pool) + try + @info "Running member $m on worker $worker" + # Run the model and put result in channel + model_config = CAL.set_up_forward_model(m, iter, config) + result = remotecall_fetch( + CAL.run_forward_model, + worker, + model_config, + ) + put!(results, (m, result)) + catch e + @error "Error running member $m" exception = e + put!(results, (m, e)) + finally + # Always return worker to pool + put!(worker_pool, worker) + end + end + end + end - # Spread should be heavily decreased as particles have converged - @test last(spread) / first(spread) < 0.1 - # Parameter should be close to true value - @test mean(last(params)) ≈ astronomical_unit rtol = 0.02 + # Collect all results + ensemble_results = Dict{Int, Any}() + for _ in 1:(config.ensemble_size) + m, result = take!(results) + if result isa Exception + @error "Member $m failed" error = result + else + ensemble_results[m] = result + end + end + results = values(ensemble_results) + all(isa.(results, Exception)) && + error("Full ensemble for iter $iter failed") + + # Process results + G_ensemble = CAL.observation_map(iter) + CAL.save_G_ensemble(config, iter, G_ensemble) + CAL.update_ensemble(config, iter) + iter_path = CAL.path_to_iteration(config.output_dir, iter) + return JLD2.load_object(joinpath(iter_path, "eki_file.jld2")) end -# Script: - -if !(@isdefined backend) - backend = CAL.get_backend() -end -# Check that the wait time for the last hour does not exceed 20 minutes. -# This test schedules many slurm jobs and will be prohibitively slow if the cluster is busy -if backend <: CAL.HPCBackend - wait_times = readchomp( - `sacct --allocations -u esmbuild --starttime now-1hour -o Submit,Start -n`, - ) - wait_times = split(wait_times, '\n', keepempty = false) - # Filter jobs that have not been submitted and started - filter!(x -> !(contains(x, "Unknown") || contains(x, "None")), wait_times) - - mean_wait_time_in_mins = - mapreduce(+, wait_times; init = 0) do line - t1_str, t2_str = split(line) - t1 = DateTime(t1_str, dateformat"yyyy-mm-ddTHH:MM:SS") - t2 = DateTime(t2_str, dateformat"yyyy-mm-ddTHH:MM:SS") - Dates.value(t2 - t1) / 1000 / 60 - end / length(wait_times) - - @show mean_wait_time_in_mins - - if mean_wait_time_in_mins > 10 - @warn """Average wait time for esmbuild is $(round(mean_wait_time_in_mins, digits=2)) minutes. \ - Cluster is too busy to run this test, exiting""" - exit() +function calibrate(config, worker_pool) + CAL.initialize(config) + for iter in 0:(config.n_iterations) + (; time) = @timed run_iteration(config, iter, worker_pool) + @info "Iteration $iter time: $time" end end -# Paths and setup -const experiment_dir = joinpath(pkgdir(CA), "calibration", "test") -const model_interface = - joinpath(pkgdir(CA), "calibration", "model_interface.jl") -const output_dir = "calibration_end_to_end_test" -include(model_interface) -ensemble_size = 15 +addprocs( + SlurmManager(10), + t = "00:20:00", + cpus_per_task = 1, + exeflags = "--project=$(Base.active_project())", +) +worker_pool = WorkerPool(workers()) + +@everywhere println(string(myid())) + +@everywhere begin + import ClimaCalibrate as CAL + import ClimaAtmos as CA + + experiment_dir = joinpath(pkgdir(CA), "calibration", "test") + model_interface = joinpath(pkgdir(CA), "calibration", "model_interface.jl") + output_dir = "calibration_end_to_end_test" + include(model_interface) + ensemble_size = 50 + obs_path = joinpath(experiment_dir, "observations.jld2") +end -# Generate observations -obs_path = joinpath(experiment_dir, "observations.jld2") +# Generate observations if needed if !isfile(obs_path) + import JLD2 @info "Generating observations" - config = CA.AtmosConfig(joinpath(experiment_dir, "model_config.yml")) - simulation = CA.get_simulation(config) + atmos_config = CA.AtmosConfig(joinpath(experiment_dir, "model_config.yml")) + simulation = CA.get_simulation(atmos_config) CA.solve_atmos!(simulation) observations = Vector{Float64}(undef, 1) observations .= process_member_data(SimDir(simulation.output_dir)) @@ -150,49 +125,82 @@ if !isfile(obs_path) end # Initialize experiment data -astronomical_unit = 149_597_870_000 -observations = JLD2.load_object(obs_path) -noise = 0.1 * I -n_iterations = 4 -prior = CAL.get_prior(joinpath(experiment_dir, "prior.toml")) -experiment_config = CAL.ExperimentConfig(; - n_iterations, - ensemble_size, - observations, - noise, - output_dir, - prior, -) - -@info "Running calibration E2E test" backend -if backend <: CAL.HPCBackend - test_eki = CAL.calibrate( - backend, - experiment_config; - hpc_kwargs = CAL.kwargs(time = 15), - model_interface, - verbose = true, +@everywhere begin + import JLD2 + import LinearAlgebra: I + astronomical_unit = 149_597_870_000 + observations = JLD2.load_object(obs_path) + noise = 0.1 * I + n_iterations = 10 + prior = CAL.get_prior(joinpath(experiment_dir, "prior.toml")) + + config = CAL.ExperimentConfig(; + n_iterations, + ensemble_size, + observations, + noise, + output_dir, + prior, ) -else - test_eki = CAL.calibrate(backend, experiment_config) end -scatter_plot(test_eki) -param_versus_iter_plot(test_eki) +calibrate(config, worker_pool) -@testset "Test Calibration on $backend" begin - minimal_eki_test(test_eki) -end +import EnsembleKalmanProcesses as EKP +import Statistics: var, mean +using Test +import CairoMakie + +function scatter_plot(eki::EKP.EnsembleKalmanProcess) + f = CairoMakie.Figure(resolution = (800, 600)) + ax = CairoMakie.Axis( + f[1, 1], + ylabel = "Parameter Value", + xlabel = "Top of atmosphere radiative SW flux", + ) + + g = vec.(EKP.get_g(eki; return_array = true)) + params = vec.((EKP.get_ϕ(prior, eki))) + + for (gg, uu) in zip(g, params) + CairoMakie.scatter!(ax, gg, uu) + end -# Run calibration -julia_eki = CAL.calibrate(CAL.JuliaBackend, experiment_config) + CairoMakie.hlines!(ax, [astronomical_unit], linestyle = :dash) + CairoMakie.vlines!(ax, observations, linestyle = :dash) -@testset "Julia-only comparison calibration" begin - minimal_eki_test(julia_eki) + output = joinpath(output_dir, "scatter.png") + CairoMakie.save(output, f) + return output end -@testset "Compare $backend output to JuliaBackend" begin - for (uu, slurm_uu) in zip(EKP.get_u(julia_eki), EKP.get_u(test_eki)) - @test uu ≈ slurm_uu rtol = 0.02 +function param_versus_iter_plot(eki::EKP.EnsembleKalmanProcess) + f = CairoMakie.Figure(resolution = (800, 600)) + ax = CairoMakie.Axis( + f[1, 1], + ylabel = "Parameter Value", + xlabel = "Iteration", + ) + params = EKP.get_ϕ(prior, eki) + for (i, param) in enumerate(params) + CairoMakie.scatter!(ax, fill(i, length(param)), vec(param)) end + + CairoMakie.hlines!(ax, [astronomical_unit]; color = :red, linestyle = :dash) + + output = joinpath(output_dir, "param_vs_iter.png") + CairoMakie.save(output, f) + return output end + +eki = JLD2.load_object(joinpath(output_dir, "iteration_011", "eki_file.jld2")) +scatter_plot(eki) +param_versus_iter_plot(eki) + +params = EKP.get_ϕ(prior, eki) +spread = map(var, params) + +# Spread should be heavily decreased as particles have converged +@test last(spread) / first(spread) < 0.1 +# Parameter should be close to true value +@test mean(last(params)) ≈ astronomical_unit rtol = 0.02