Skip to content

Commit

Permalink
Taal, SaveSolutionCallback
Browse files Browse the repository at this point in the history
  • Loading branch information
ranocha committed Sep 22, 2020
1 parent 18bae6b commit 3e5f4ab
Show file tree
Hide file tree
Showing 12 changed files with 183 additions and 40 deletions.
15 changes: 7 additions & 8 deletions examples/2d/parameters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,18 +35,17 @@ alive_callback = AliveCallback(analysis_interval=analysis_interval)
analysis_callback = AnalysisCallback(semi, analysis_interval=analysis_interval,
extra_analysis_integrals=(entropy, energy_total))

stepsize_callback = StepsizeCallback(cfl=1.6)
save_solution = SaveSolutionCallback(solution_interval=100,
save_initial_solution=true,
save_final_solution=true,
solution_variables=:primitive)

# TODO: Taal, IO
# # save_initial_solution = false
# solution_interval = 100
# solution_variables = "primitive"
# restart_interval = 10

# TODO: Taal, restart
# restart = true
# restart_filename = "out/restart_000100.h5"
callbacks = CallbackSet(stepsize_callback, analysis_callback, alive_callback)
stepsize_callback = StepsizeCallback(cfl=1.6)

callbacks = CallbackSet(stepsize_callback, analysis_callback, save_solution, alive_callback)


sol = solve(ode, CarpenterKennedy2N54(williamson_condition=false), dt=stepsize_callback(ode),
Expand Down
15 changes: 7 additions & 8 deletions examples/2d/parameters_ec.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,16 +34,15 @@ analysis_callback = AnalysisCallback(semi, analysis_interval=analysis_interval)

stepsize_callback = StepsizeCallback(cfl=1.0)

# TODO: Taal, IO
# # save_initial_solution = false
# solution_interval = 100
# solution_variables = "primitive"
# restart_interval = 10
save_solution = SaveSolutionCallback(solution_interval=100,
save_initial_solution=true,
save_final_solution=true,
solution_variables=:primitive)

# TODO: Taal, restart
# restart = true
# restart_filename = "out/restart_000100.h5"
callbacks = CallbackSet(stepsize_callback, analysis_callback, alive_callback)
# restart_interval = 10

callbacks = CallbackSet(stepsize_callback, analysis_callback, save_solution, alive_callback)


sol = solve(ode, CarpenterKennedy2N54(williamson_condition=false), dt=stepsize_callback(ode),
Expand Down
15 changes: 9 additions & 6 deletions examples/2d/parameters_mortar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,17 @@ alive_callback = AliveCallback(analysis_interval=analysis_interval)
analysis_callback = AnalysisCallback(semi, analysis_interval=analysis_interval,
extra_analysis_integrals=(entropy,))

stepsize_callback = StepsizeCallback(cfl=2.0)
save_solution = SaveSolutionCallback(solution_interval=100,
save_initial_solution=true,
save_final_solution=true,
solution_variables=:primitive)

# TODO: Taal, IO
# # save_initial_solution = false
# solution_interval = 100
# solution_variables = "primitive"
# TODO: Taal, restart
# restart_interval = 10
callbacks = CallbackSet(stepsize_callback, analysis_callback, alive_callback)

stepsize_callback = StepsizeCallback(cfl=2.0)

callbacks = CallbackSet(stepsize_callback, analysis_callback, save_solution, alive_callback)


sol = solve(ode, CarpenterKennedy2N54(williamson_condition=false), dt=stepsize_callback(ode),
Expand Down
2 changes: 1 addition & 1 deletion src/Trixi.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ export DGSEM,

export Semidiscretization, semidiscretize, compute_coefficients

export AliveCallback, AnalysisCallback, StepsizeCallback
export AliveCallback, AnalysisCallback, SaveSolutionCallback, StepsizeCallback

export entropy, energy_total

Expand Down
1 change: 1 addition & 0 deletions src/callbacks/callbacks.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@

include("alive.jl")
include("analysis.jl")
include("save_solution.jl")
include("stepsize.jl")
76 changes: 76 additions & 0 deletions src/callbacks/save_solution.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@

# TODO: Taal, implement, save AMR indicator values
# TODO: Taal, refactor, allow saving arbitrary functions of the conservative variables
mutable struct SaveSolutionCallback
save_initial_solution::Bool
save_final_solution::Bool
output_directory::String
solution_variables::Symbol
end


function Base.show(io::IO, cb::DiscreteCallback{Condition,Affect!}) where {Condition, Affect!<:SaveSolutionCallback}
stepsize_callback = cb.affect!
print(io, "SaveSolutionCallback")
end
# TODO: Taal bikeshedding, implement a method with more information and the signature
# function Base.show(io::IO, ::MIME"text/plain", cb::DiscreteCallback{Condition,Affect!}) where {Condition, Affect!<:StepsizeCallback}
# end


function SaveSolutionCallback(; solution_interval=0,
save_initial_solution=true,
save_final_solution=true,
output_directory="out",
solution_variables=:primitive)
condition = (u, t, integrator) -> solution_interval > 0 && ((integrator.iter % solution_interval == 0) || (save_final_solution && t == integrator.sol.prob.tspan[2]))

solution_callback = SaveSolutionCallback(save_initial_solution, save_final_solution,
output_directory, solution_variables)

DiscreteCallback(condition, solution_callback,
save_positions=(false,false),
initialize=initialize!)
end


function initialize!(cb::DiscreteCallback{Condition,Affect!}, u, t, integrator) where {Condition, Affect!<:SaveSolutionCallback}
reset_timer!(timer())
solution_callback = cb.affect!

mkpath(solution_callback.output_directory)

semi = integrator.p
@unpack mesh = semi
mesh.unsaved_changes = true

if solution_callback.save_initial_solution
solution_callback(integrator)
end

return nothing
end


function (solution_callback::SaveSolutionCallback)(integrator)
@unpack u, t, dt, iter = integrator
semi = integrator.p
@unpack mesh, equations, solver, cache = semi

@timeit_debug timer() "I/O" begin
if mesh.unsaved_changes
mesh.current_filename = save_mesh_file(mesh, solution_callback.output_directory, iter)
mesh.unsaved_changes = false
end

save_solution_file(u, t, dt, iter, mesh, equations, solver, cache, solution_callback)
end

return nothing
end


# function save_mesh_file(mesh::TreeMesh, output_directory, timestep=-1) in io/io.jl
#

include("save_solution_dg.jl")
66 changes: 66 additions & 0 deletions src/callbacks/save_solution_dg.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@

function save_solution_file(u, time, dt, timestep,
mesh, equations, dg::DG, cache,
solution_callback)
@unpack output_directory, solution_variables = solution_callback

# Filename without extension based on current time step
filename = joinpath(output_directory, @sprintf("solution_%06d.h5", timestep))

# Convert time and time step size to floats
time = convert(Float64, time)
dt = convert(Float64, dt)

# Open file (clobber existing content)
h5open(filename, "w") do file
# Add context information as attributes
attrs(file)["ndims"] = ndims(mesh)
attrs(file)["equations"] = get_name(equations)
attrs(file)["polydeg"] = polydeg(dg)
attrs(file)["n_vars"] = nvariables(equations)
attrs(file)["n_elements"] = nelements(dg, cache)
attrs(file)["mesh_file"] = splitdir(mesh.current_filename)[2]
attrs(file)["time"] = time
attrs(file)["dt"] = dt
attrs(file)["timestep"] = timestep

# Convert to primitive variables if requested
if solution_variables === :conservative
data = u
varnames = varnames_cons(equations)
elseif solution_variables === :primitive
# Reinterpret the solution array as an array of conservative variables,
# compute the primitive variables via broadcasting, and reinterpret the
# result as a plain array of floating point numbers
data = Array(reinterpret(eltype(u),
cons2prim.(reinterpret(SVector{nvariables(equations),eltype(u)}, u),
Ref(equations))))
varnames = varnames_prim(equations)
else
error("Unknown solution_variables $solution_variables")
end

# Store each variable of the solution
for v in eachvariable(equations)
# Convert to 1D array
file["variables_$v"] = vec(data[v, .., :])

# Add variable name as attribute
var = file["variables_$v"]
attrs(var)["name"] = varnames[v]
end

# TODO: Taal implement, save element variables
# Store element variables
# for (v, (key, element_variables)) in enumerate(cache.element_variables)
# # Add to file
# file["element_variables_$v"] = element_variables

# # Add variable name as attribute
# var = file["element_variables_$v"]
# attrs(var)["name"] = string(key)
# end
end

return filename
end
2 changes: 1 addition & 1 deletion src/callbacks/stepsize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ end
@unpack mesh, equations, solver, cache = semi
@unpack cfl_number = stepsize_callback

@timeit_debug timer() "calc_dt" dt = cfl_number * max_dt(u, t, mesh, have_constant_speed(equations), equations, solver, cache)
@timeit_debug timer() "calculate dt" dt = cfl_number * max_dt(u, t, mesh, have_constant_speed(equations), equations, solver, cache)
set_proposed_dt!(integrator, dt)
integrator.opts.dtmax = dt
integrator.dtcache = dt
Expand Down
11 changes: 5 additions & 6 deletions src/io/io.jl
Original file line number Diff line number Diff line change
Expand Up @@ -165,21 +165,20 @@ end


# Save current mesh with some context information as an HDF5 file.
function save_mesh_file(mesh::TreeMesh, timestep=-1)
function save_mesh_file(mesh::TreeMesh, output_directory, timestep=0)
# Create output directory (if it does not exist)
output_directory = parameter("output_directory", "out")
mkpath(output_directory)

# Determine file name based on existence of meaningful time step
if timestep >= 0
filename = joinpath(output_directory, @sprintf("mesh_%06d", timestep))
if timestep > 0
filename = joinpath(output_directory, @sprintf("mesh_%06d.h5", timestep))
else
filename = joinpath(output_directory, "mesh")
filename = joinpath(output_directory, "mesh.h5")
end

# Create output directory (if it does not exist)
# Open file (clobber existing content)
h5open(filename * ".h5", "w") do file
h5open(filename, "w") do file
# Add context information as attributes
n_cells = length(mesh.tree)
attrs(file)["ndims"] = ndims(mesh)
Expand Down
10 changes: 5 additions & 5 deletions src/run.jl
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ function init_simulation()
else
print("Creating mesh... ")
@timeit timer() "mesh creation" mesh = generate_mesh()
mesh.current_filename = save_mesh_file(mesh)
mesh.current_filename = save_mesh_file(mesh, parameter("output_directory", "out"))
mesh.unsaved_changes = false
println("done")
end
Expand Down Expand Up @@ -146,7 +146,7 @@ function init_simulation()
end

# Save mesh file
mesh.current_filename = save_mesh_file(mesh)
mesh.current_filename = save_mesh_file(mesh, parameter("output_directory", "out"))
mesh.unsaved_changes = false
end
end
Expand Down Expand Up @@ -284,7 +284,7 @@ function run_simulation(mesh, solver, time_parameters, time_integration_function
first_loop_iteration = true
@timeit timer() "main loop" while !finalstep
# Calculate time step size
@timeit timer() "calc_dt" dt = calc_dt(solver, cfl)
@timeit timer() "calculate dt" dt = calc_dt(solver, cfl)

# Abort if time step size is NaN
if isnan(dt)
Expand Down Expand Up @@ -360,7 +360,7 @@ function run_simulation(mesh, solver, time_parameters, time_integration_function

# If mesh has changed, write a new mesh file name
if mesh.unsaved_changes
mesh.current_filename = save_mesh_file(mesh, step)
mesh.current_filename = save_mesh_file(mesh, parameter("output_directory", "out"), step)
mesh.unsaved_changes = false
end

Expand All @@ -377,7 +377,7 @@ function run_simulation(mesh, solver, time_parameters, time_integration_function
@timeit timer() "I/O" begin
# If mesh has changed, write a new mesh file
if mesh.unsaved_changes
mesh.current_filename = save_mesh_file(mesh, step)
mesh.current_filename = save_mesh_file(mesh, parameter("output_directory", "out"), step)
mesh.unsaved_changes = false
end

Expand Down
8 changes: 4 additions & 4 deletions src/run_euler_gravity.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ function init_simulation_euler_gravity()
begin
print("Creating mesh... ")
@timeit timer() "mesh creation" mesh = generate_mesh()
mesh.current_filename = save_mesh_file(mesh)
mesh.current_filename = save_mesh_file(mesh, parameter("output_directory", "out"))
mesh.unsaved_changes = false
println("done")
end
Expand Down Expand Up @@ -71,7 +71,7 @@ function init_simulation_euler_gravity()
end

# Save mesh file
mesh.current_filename = save_mesh_file(mesh)
mesh.current_filename = save_mesh_file(mesh, parameter("output_directory", "out"))
mesh.unsaved_changes = false
end
end
Expand Down Expand Up @@ -229,7 +229,7 @@ function run_simulation_euler_gravity(mesh, solvers, time_parameters, time_integ
first_loop_iteration = true
@timeit timer() "main loop" while !finalstep
# Calculate time step size
@timeit timer() "calc_dt" dt = calc_dt(solver, cfl)
@timeit timer() "calculate dt" dt = calc_dt(solver, cfl)

# Abort if time step size is NaN
if isnan(dt)
Expand Down Expand Up @@ -311,7 +311,7 @@ function run_simulation_euler_gravity(mesh, solvers, time_parameters, time_integ

# If mesh has changed, write a new mesh file name
if mesh.unsaved_changes
mesh.current_filename = save_mesh_file(mesh, step)
mesh.current_filename = save_mesh_file(mesh, parameter("output_directory", "out"), step)
mesh.unsaved_changes = false
end

Expand Down
2 changes: 1 addition & 1 deletion src/timedisc/timedisc_euler_gravity.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ function update_gravity!(solver, u_euler, gravity_parameters)
# Iterate gravity solver until convergence or maximum number of iterations are reached
while !finalstep
# Calculate time step size
@timeit timer() "calc_dt" dt = calc_dt(solver, cfl_gravity)
@timeit timer() "calculate dt" dt = calc_dt(solver, cfl_gravity)

# Evolve solution by one pseudo-time step
timestep_gravity(solver, time, dt, u_euler, gravity_parameters)
Expand Down

0 comments on commit 3e5f4ab

Please sign in to comment.