diff --git a/test/restart.jl b/test/restart.jl index 54a6ecba17..1dd11824ae 100644 --- a/test/restart.jl +++ b/test/restart.jl @@ -9,6 +9,7 @@ import ClimaCore.Spaces: AbstractSpace import ClimaComms pkgversion(ClimaComms) >= v"0.6" && ClimaComms.@import_required_backends import Logging +import NCDatasets using Test import Random @@ -36,9 +37,7 @@ ClimaComms.init(comms_ctx) # # For this reason, we don't use Test but just print to screen the differences. # However, we still have to return an exit code with failure in case of the -# comparison fails. So, we have this global `SUCCESS` bool that is updated by -# the result of tests. -const SUCCESS::Base.RefValue{Bool} = Ref(true) +# comparison fails. """ _error(arr1::AbstractArray, arr2::AbstractArray; ABS_TOL = 100eps(eltype(arr1))) @@ -130,6 +129,17 @@ function _compare(v1::T, v2::T; name, ignore) where {T <: Number} return print_maybe(v1 === v2, "$name differs: $v1 vs $v2") end +# We ignore NCDatasets. They contain a lot of state-ful information +function _compare( + pass, + v1::T, + v2::T; + name, + ignore, +) where {T <: NCDatasets.NCDataset} + return pass +end + function _compare( v1::T, v2::T; @@ -167,6 +177,118 @@ end # Disable all the @info statements that are produced when creating a simulation Logging.disable_logging(Logging.Info) + +""" + test_restart(test_dict; job_id, comms_ctx, more_ignore = Symbol[]) + +Test if the restarts are consistent for a simulation defined by the `test_dict` config. + +`more_ignore` is a Vector of Symbols that identifies config-specific keys that +have to be ignored when reading a simulation. +""" +function test_restart(test_dict; job_id, comms_ctx, more_ignore = Symbol[]) + println("job_id = $(job_id)") + + local_success = Ref(true) + + config = CA.AtmosConfig(test_dict; job_id, comms_ctx) + + simulation = CA.get_simulation(config) + CA.solve_atmos!(simulation) + + # Check re-importing the same state + restart_dir = simulation.output_dir + @test isfile(joinpath(restart_dir), "day0.3.hdf5") + + # Reset random seed for RRTMGP + Random.seed!(1234) + + ClimaComms.iamroot(comms_ctx) && println(" just reading data") + config_should_be_same = CA.AtmosConfig( + merge(test_dict, Dict("detect_restart_file" => true)); + job_id, + comms_ctx, + ) + + simulation_restarted = CA.get_simulation(config_should_be_same) + + local_success[] &= compare( + simulation.integrator.u, + simulation_restarted.integrator.u; + name = "integrator.u", + ) + local_success[] &= compare( + axes(simulation.integrator.u.c), + axes(simulation_restarted.integrator.u.c); + name = "space", + ) + local_success[] &= compare( + simulation.integrator.p, + simulation_restarted.integrator.p; + name = "integrator.p", + ignore = Set([ + :ghost_buffer, + :hyperdiffusion_ghost_buffer, + :scratch, + :output_dir, + :ghost_buffer, + # Computed in tendencies (which are not computed in this case) + :hyperdiff, + :precipitation, + # rc is some CUDA/CuArray internal object that we don't care about + :rc, + # DataHandlers contains caches, so they are stateful + :data_handler, + # Config-specific + more_ignore..., + ]), + ) + + # Check re-importing from previous state and advancing one step + ClimaComms.iamroot(comms_ctx) && println(" reading and simulating") + # Reset random seed for RRTMGP + Random.seed!(1234) + + restart_file = joinpath(simulation.output_dir, "day0.2.hdf5") + @test isfile(joinpath(restart_dir), "day0.2.hdf5") + # Restart from specific file + config2 = CA.AtmosConfig( + merge(test_dict, Dict("restart_file" => restart_file)); + job_id, + comms_ctx, + ) + + simulation_restarted2 = CA.get_simulation(config2) + CA.fill_with_nans!(simulation_restarted2.integrator.p) + + CA.solve_atmos!(simulation_restarted2) + local_success[] &= compare( + simulation.integrator.u, + simulation_restarted2.integrator.u; + name = "integrator.u", + ) + local_success[] &= compare( + simulation.integrator.p, + simulation_restarted2.integrator.p; + name = "integrator.p", + ignore = Set([ + :scratch, + :output_dir, + :ghost_buffer, + :hyperdiffusion_ghost_buffer, + :data_handler, + :rc, + ]), + ) + + return local_success[] +end + +# Let's prepare the test_dicts. TESTING is a Vector of NamedTuples, each element +# has a test_dict, a job_id, and a more_ignore + +TESTING = Any[] + if comms_ctx isa ClimaComms.SingletonCommsContext configurations = ["sphere", "box", "column"] else @@ -200,9 +322,6 @@ for configuration in configurations end end - println( - "config = $configuration $moisture $precip $topography $radiation", - ) # The `enable_bubble` case is broken for ClimaCore < 0.14.6, so we # hard-code this to be always false for those versions bubble = pkgversion(ClimaCore) > v"0.14.5" @@ -211,9 +330,8 @@ for configuration in configurations output_loc = ClimaComms.iamroot(comms_ctx) ? mktempdir(pwd()) : "" output_loc = ClimaComms.bcast(comms_ctx, output_loc) - ClimaComms.barrier(comms_ctx) - job_id = "restart" + job_id = "$(configuration)_$(moisture)_$(precip)_$(topography)_$(radiation)" test_dict = Dict( "test_dycore_consistency" => true, # We will add NaNs to the cache, just to make sure "check_nan_every" => 3, @@ -240,103 +358,17 @@ for configuration in configurations ) more_ignore = Symbol[] - config = CA.AtmosConfig(test_dict; job_id, comms_ctx) - - simulation = CA.get_simulation(config) - CA.solve_atmos!(simulation) - - # Check re-importing the same state - restart_dir = simulation.output_dir - @test isfile(joinpath(restart_dir), "day0.3.hdf5") - - # Reset random seed for RRTMGP - Random.seed!(1234) - - println(" just reading data") if turbconv_mode == "prognostic_edmf" more_ignore = [:ᶠnh_pressure₃ʲs] end - - config_should_be_same = CA.AtmosConfig( - merge(test_dict, Dict("detect_restart_file" => true)); - job_id, - comms_ctx, - ) - - simulation_restarted = - CA.get_simulation(config_should_be_same) - - SUCCESS[] &= compare( - simulation.integrator.u, - simulation_restarted.integrator.u; - name = "integrator.u", - ) - SUCCESS[] &= compare( - axes(simulation.integrator.u.c), - axes(simulation_restarted.integrator.u.c); - name = "space", - ) - SUCCESS[] &= compare( - simulation.integrator.p, - simulation_restarted.integrator.p; - name = "integrator.p", - ignore = Set([ - :ghost_buffer, - :hyperdiffusion_ghost_buffer, - :scratch, - :output_dir, - :ghost_buffer, - # Computed in tendencies (which are not computed in this case) - :hyperdiff, - :precipitation, - # rc is some CUDA/CuArray internal object that we don't care about - :rc, - # Config-specific - more_ignore..., - ]), - ) - - # Check re-importing from previous state and advancing one step - println(" reading and simulating") - # Reset random seed for RRTMGP - Random.seed!(1234) - - restart_file = - joinpath(simulation.output_dir, "day0.2.hdf5") - @test isfile(joinpath(restart_dir), "day0.2.hdf5") - # Restart from specific file - config2 = CA.AtmosConfig( - merge(test_dict, Dict("restart_file" => restart_file)); - job_id, - comms_ctx, - ) - - simulation_restarted2 = CA.get_simulation(config2) - CA.fill_with_nans!(simulation_restarted2.integrator.p) - - CA.solve_atmos!(simulation_restarted2) - SUCCESS[] &= compare( - simulation.integrator.u, - simulation_restarted2.integrator.u; - name = "integrator.u", - ) - SUCCESS[] &= compare( - simulation.integrator.p, - simulation_restarted2.integrator.p; - name = "integrator.p", - ignore = Set([ - :scratch, - :output_dir, - :ghost_buffer, - :hyperdiffusion_ghost_buffer, - :rc, - ]), - ) + push!(TESTING, (; test_dict, job_id, more_ignore)) end end end end end -# Ensure that we have the correct exit code -@test SUCCESS[] +@test all( + @time test_restart(t.test_dict; comms_ctx, t.job_id, t.more_ignore) for + t in TESTING +)