Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement neural network parameterization of salty turbulent mixing in the upper ocean #3819

Draft
wants to merge 119 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 57 commits
Commits
Show all changes
119 commits
Select commit Hold shift + click to select a range
13cd632
working version on CPU
xkykai Jun 24, 2024
bb70e5c
a working model for GPU!
xkykai Jun 24, 2024
e845dd5
it actually works on the gpu now!
xkykai Jun 25, 2024
b6552ab
build scaling from tuple values
xkykai Jun 25, 2024
c9a9c65
Update XinKaiLocalVerticalDiffusivity with 2 Pr values
xkykai Jun 26, 2024
9a07b82
fix function construct_scaling to construct_zeromeanunitvariance_scaling
xkykai Jun 26, 2024
74c8b61
update using KernelParameters
xkykai Jun 26, 2024
6f97c3f
fix nn closure bug
xkykai Jun 26, 2024
65f15e4
test script for nn closure
xkykai Jun 26, 2024
cb601ce
fix N2 average, tracer diffusivity expression for local diffusivity c…
xkykai Jun 27, 2024
dbe5759
2Pr version of local physical closure
xkykai Jun 28, 2024
e9ac435
nonlocal physical closure of vertical diffusivity
xkykai Jun 28, 2024
5e548da
2Pr nonlocal physical closure
xkykai Jun 28, 2024
772c235
working GPU version of NN closure with scaling and correction
xkykai Jun 28, 2024
9c6873f
validation script for oceananigans NN implementation
xkykai Jul 1, 2024
f6d508a
close file and record video of validation
xkykai Jul 1, 2024
1faa5a6
Update Project.toml with new package dependencies
xkykai Jul 2, 2024
930cc7e
Update NN closure model to use a larger neural network
xkykai Jul 2, 2024
fd95fc6
Remove unused import of StaticArrays
xkykai Jul 2, 2024
72aebc1
Add ComponentArrays dependency
xkykai Jul 2, 2024
c2d336e
add total_size and KernelParameters dependency
xkykai Jul 2, 2024
d4a1156
Coarsen LES data for NN closure model
xkykai Jul 17, 2024
ae7b563
rename file and compare LES with NN closure
xkykai Jul 17, 2024
4d118c2
run LES for hald sinusoid cooling
xkykai Jul 18, 2024
fc081cc
run 3D simulation with limited extent
xkykai Aug 21, 2024
a78109e
add sponge at bottom
xkykai Aug 21, 2024
ca8d78c
fix metres to meters
xkykai Aug 21, 2024
78e7a67
fix temperature flux
xkykai Aug 21, 2024
eaf2183
Calculate average velocities and tracers in 3D model LES simulation
xkykai Aug 21, 2024
7e144ff
reduce size of model
xkykai Aug 21, 2024
ba7f58d
run double gyre with physical closure
xkykai Aug 29, 2024
fcb466a
rename file
xkykai Aug 29, 2024
b63e7d6
fix S forcing
xkykai Aug 29, 2024
67e011c
use RiBasedVerticalDiffusivity as closure
xkykai Aug 29, 2024
55d635a
fix sign errors in tracer fluxes
xkykai Aug 29, 2024
8319ba8
fix boundary conditions, initialize trivial model first
xkykai Aug 29, 2024
229feeb
update boundary conditions, initial state
xkykai Aug 29, 2024
a853275
fix tyop_buoyancy_flux bug
xkykai Aug 29, 2024
477c6be
fix initial conditions, run for 10 years, set up checkpointing
xkykai Aug 29, 2024
232b67d
run double gyre with new physical closure
xkykai Aug 30, 2024
148a582
plotting barotropic streamfunction
xkykai Sep 4, 2024
e9f2f0e
fix plot units
xkykai Sep 4, 2024
7f17cf2
Merge branch 'main' into xk/embed-nn
xkykai Sep 9, 2024
ed8ec7d
run double gyre with CATKE
xkykai Sep 9, 2024
f4081bf
Merge branch 'main' into xk/embed-nn
xkykai Sep 9, 2024
893e529
add TKE tracer for CATKE
xkykai Sep 9, 2024
ba5491d
fix closure to use only CATKE
xkykai Sep 10, 2024
21a8356
update CATKE configuration
xkykai Sep 17, 2024
de190d2
use older dependencies for compatibility with neural networks
xkykai Sep 17, 2024
7e7a04f
add fields to be calculated for CATKE
xkykai Sep 19, 2024
1c31868
local diffusivity for 2step calibration
xkykai Sep 19, 2024
3609d6c
new NN closure with nof and base boundary layer criteria
xkykai Sep 19, 2024
8eaaaae
fix bug in NN closure implementation
xkykai Sep 20, 2024
6535be5
using Grids.total_size
xkykai Sep 23, 2024
812df88
add using KernelParameters
xkykai Sep 23, 2024
eb50cca
update NN model
xkykai Sep 24, 2024
b2cd2fa
use BBL integral metric to compute base of boundary layer
xkykai Sep 24, 2024
d3fa8b8
Update xin_kai_vertical_diffusivity.jl
xkykai Oct 4, 2024
2448003
remove type piracy TEOS10.s
xkykai Oct 6, 2024
cb5f864
NN closure using BBL zone below nonbabkground kappa
xkykai Oct 6, 2024
f443296
run double gyre with NDE BBLkappazonelast41
xkykai Oct 6, 2024
2f3eed9
change initialized state to 8day restoration forcing
xkykai Oct 6, 2024
987f1ed
run double gyre withj baseclosure initialized
xkykai Oct 6, 2024
c4351f7
NN closure for augmenting flux in a zone below MLD
xkykai Oct 6, 2024
19f71d3
8 day relaxation double gyre for baseclosure
xkykai Oct 8, 2024
3c4102a
uncomment CairoMakie
xkykai Oct 8, 2024
b122e7b
NN closure and CATKE with 8 day restoration and warm flush
xkykai Oct 8, 2024
e0bd65a
fix initial temperature issue
xkykai Oct 8, 2024
8dcb6d3
add zonal average calculations to double gyre simulation
xkykai Oct 9, 2024
132188a
run double gyre with seasonal forcing
xkykai Oct 10, 2024
57c3ac5
increase simulation run time
xkykai Oct 16, 2024
d5a320e
wall restoration to maintain strratification
xkykai Oct 17, 2024
0cd618a
baseclosure doublegyre with wallrestoration
xkykai Oct 17, 2024
4abf7e0
change filenames
xkykai Oct 17, 2024
7cf7090
using wider zone for NN closure
xkykai Oct 17, 2024
70dbe2f
fix variable error
xkykai Oct 17, 2024
fd57bf8
update NN configuration
xkykai Oct 18, 2024
b582447
NDE double gyre script for seasonal forcing and wall restoration
xkykai Oct 18, 2024
6235bc4
run NNclosure with Ri nof BBLkappazonelast55
xkykai Oct 20, 2024
3609c21
updated NN model with no Ri
xkykai Oct 22, 2024
823e8c7
run double gyre with NNclosure with no Ri input kappazonelast55
xkykai Oct 22, 2024
893b7bc
add fluxes calculations and zonal average
xkykai Oct 24, 2024
ea77669
Merge branch 'main' into xk/embed-nn
xkykai Oct 24, 2024
3428f5d
Merge branch 'main' into xk/embed-nn
xkykai Oct 24, 2024
4a5a3e1
using centered second order in z instead of WENO
xkykai Oct 24, 2024
0f6bf66
remove background vertical scalar diffusivity
xkykai Oct 24, 2024
8385229
diffusivity fields indexing for base closure
xkykai Oct 24, 2024
45eef0a
change temperature restoration to 30days
xkykai Oct 30, 2024
e29bfd9
change vertical advection scheme to WENO5
xkykai Oct 30, 2024
0fba127
update neural network to new weights
xkykai Oct 30, 2024
73afa68
fix file name change
xkykai Oct 30, 2024
1fab19f
run double gyre with centered second order and wall restoration
xkykai Oct 31, 2024
49d901f
fix dynamic function invocation that doesn't affect the baseclosure s…
xkykai Oct 31, 2024
534d21d
add y resolution variable
xkykai Oct 31, 2024
2551c38
run NN closure recording xz slices
xkykai Nov 1, 2024
ff272c4
recording and plotting xz yz slices and fluxes
xkykai Nov 1, 2024
93bf0bc
change default z advection scheme to weno 5
xkykai Nov 1, 2024
d36677f
run with linear ramp seasonal forcing
xkykai Nov 6, 2024
ce68cc7
fix T_seasonal
xkykai Nov 7, 2024
e1603ca
fix temperature restoration
xkykai Nov 8, 2024
a3ab871
implement different NNclosure with Ri zone
xkykai Nov 21, 2024
51acb9d
run double gyre on updated closure
xkykai Nov 21, 2024
33f2ea8
new NN closure witth Rifirstzone
xkykai Nov 22, 2024
36cecb6
fix NN closure
xkykai Nov 22, 2024
946d710
fix NN_closure
xkykai Nov 22, 2024
0c55c44
updated base closure with new calibration
xkykai Nov 26, 2024
dada2c5
run base closure and CATKE with mode waters
xkykai Nov 26, 2024
1d0ef1f
fix advection scheme and file name
xkykai Nov 26, 2024
4898be4
new NN closure including Ri
xkykai Nov 26, 2024
3e84aaf
actual new closure including Ri
xkykai Nov 26, 2024
88f7fa4
run double gyre with new NN closure
xkykai Nov 26, 2024
030c890
run CATKE for 10800days
xkykai Nov 26, 2024
f4f7383
Merge branch 'main' into xk/embed-nn
xkykai Nov 28, 2024
35dbad3
flush mode water after equilibration
xkykai Dec 2, 2024
88e3cf1
fix filename
xkykai Dec 2, 2024
5b34027
Revert "fix filename"
xkykai Dec 3, 2024
8e458e0
Revert "flush mode water after equilibration"
xkykai Dec 3, 2024
34f7437
Revert "Merge branch 'main' into xk/embed-nn"
xkykai Dec 3, 2024
ebf6639
fix script changes for modewater after equilibration
xkykai Dec 3, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
225 changes: 225 additions & 0 deletions 2D_model_LES_sin.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,225 @@
#using Pkg
using Oceananigans
using Printf
using Statistics

using Oceananigans
using Oceananigans.Units
using Oceananigans.OutputReaders: FieldTimeSeries
using Oceananigans.Grids: xnode, ynode, znode
using SeawaterPolynomials
using SeawaterPolynomials:TEOS10
import SeawaterPolynomials.TEOS10: s, ΔS, Sₐᵤ
s(Sᴬ::Number) = Sᴬ + ΔS >= 0 ? √((Sᴬ + ΔS) / Sₐᵤ) : NaN
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is an interesting idea. Just need to use non-short-circuiting logic and be mindful of number type and it could potentially go in SeawaterPolynomials

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was writing this in order to address the issue of

ERROR: DomainError with -2.0:
sqrt was called with a negative real argument but will only return a complex result if called with a complex argument. Try sqrt(Complex(x)).

when I was training my model, in order to continue the training with NaN.
Wouldn't doing non-short-circuiting logic mean that we are forced into a DomainError if Sᴬ + ΔS becomes negative?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, we could clip to avoid it, for example:

function s(Sᴬ)
    not_positive = Sᴬ + ΔS >= 0
    d = max(zero(Sᴬ), Sᴬ + ΔS)
    return ifelse(not_positive, NaN, (d / Sₐᵤ))
end

but after reflecting on this a bit more, I think it's better to receive a DomainError from this point, than to get a NaN and not know why.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree. I believe I started doing this to throw NaNs when using EKI, which was necessary to continue training, but apart from training purposes there shouldn't be a need to do this

Copy link
Member

@glwagner glwagner Oct 4, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ahh no but I see your point now. It's part of the issue of failure handling with EKI.

The thing is, we really do want to support training / automatic calibration and it shouldn't require hacks like this, I feel this really impacts reproducibility and understandability (technically this is type piracy...)

Maybe clipping+NaN should be an option of the equation of state then.

Copy link
Member

@glwagner glwagner Oct 4, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Another option would be to wrap run!(simulation) inside a try/catch so that, if a simulation errors, we can replace the output with NaN / mark the simulation as failed for the purpose of estimating parameters. This would work too right? That might be simpler (simply capturing errors after they occur) than trying to prevent any errors from occurring, which is possible in this particular case because we own SeawaterPolynomials but is not generally a solution.

Copy link
Collaborator Author

@xkykai xkykai Oct 4, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suspect what one could do is do clipping + throw NaN with a warning message that this is due to the model having salinity values that are beyond the regime where TEOS-10 is correct. Or perhaps throw a warning whenever the temperature and salinity values are outside of the reasonable regime of TEOS10 (between 0-40 °C and 0-42 psu)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can't warn from inside a kernel though.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah, so try-catch it is then

using Glob

# Architecture
model_architecture = GPU()

# number of grid points
Ny = 20000
Nz = 256

const Ly = 40kilometers
const Lz = 512meters

grid = RectilinearGrid(model_architecture,
topology = (Flat, Bounded, Bounded),
size = (Ny, Nz),
halo = (5, 5),
y = (0, Ly),
z = (-Lz, 0))

@info "Built a grid: $grid."

#####
##### Boundary conditions
#####
const dTdz = 0.014
const dSdz = 0.0021

const T_surface = 20.0
const S_surface = 36.6
const max_temperature_flux = 3e-4

FILE_DIR = "./LES/NN_2D_channel_sin_cooling_$(max_temperature_flux)_LES"
mkpath(FILE_DIR)

@inline function temperature_flux(y, t)
return max_temperature_flux * sin(π * y / Ly)
end

T_bcs = FieldBoundaryConditions(top=FluxBoundaryCondition(temperature_flux))

#####
##### Coriolis
#####

const f₀ = 8e-5
coriolis = FPlane(f=f₀)

#####
##### Forcing and initial condition
#####
T_initial(y, z) = dTdz * z + T_surface
S_initial(y, z) = dSdz * z + S_surface

#####
##### Model building
#####

@info "Building a model..."

model = NonhydrostaticModel(; grid = grid,
advection = WENO(order=9),
coriolis = coriolis,
buoyancy = SeawaterBuoyancy(equation_of_state=TEOS10.TEOS10EquationOfState()),
tracers = (:T, :S),
timestepper = :RungeKutta3,
closure = nothing,
boundary_conditions = (; T=T_bcs))

@info "Built $model."

#####
##### Initial conditions
#####

# resting initial condition
noise(z) = rand() * exp(z / 8)

T_initial_noisy(y, z) = T_initial(y, z) + 1e-6 * noise(z)
S_initial_noisy(y, z) = S_initial(y, z) + 1e-6 * noise(z)

set!(model, T=T_initial_noisy, S=S_initial_noisy)
#####
##### Simulation building
#####
simulation = Simulation(model, Δt = 0.1, stop_time = 30days)

# add timestep wizard callback
wizard = TimeStepWizard(cfl=0.6, max_change=1.05, max_Δt=20minutes)
simulation.callbacks[:wizard] = Callback(wizard, IterationInterval(10))

# add progress callback
wall_clock = [time_ns()]

function print_progress(sim)
@printf("[%05.2f%%] i: %d, t: %s, wall time: %s, max(u): %6.3e, max(v): %6.3e, max(T): %6.3e, max(S): %6.3e, next Δt: %s\n",
100 * (sim.model.clock.time / sim.stop_time),
sim.model.clock.iteration,
prettytime(sim.model.clock.time),
prettytime(1e-9 * (time_ns() - wall_clock[1])),
maximum(sim.model.velocities.u),
maximum(sim.model.velocities.v),
maximum(sim.model.tracers.T),
maximum(sim.model.tracers.S),
prettytime(sim.Δt))

wall_clock[1] = time_ns()

return nothing
end

simulation.callbacks[:print_progress] = Callback(print_progress, IterationInterval(1000))

#####
##### Diagnostics
#####

u, w = model.velocities.u, model.velocities.w
v = @at (Center, Center, Center) model.velocities.v
T, S = model.tracers.T, model.tracers.S

outputs = (; u, v, w, T, S)

#####
##### Build checkpointer and output writer
#####
simulation.output_writers[:jld2] = JLD2OutputWriter(model, outputs,
filename = "$(FILE_DIR)/instantaneous_fields.jld2",
schedule = TimeInterval(1hour))

simulation.output_writers[:checkpointer] = Checkpointer(model,
schedule = TimeInterval(1day),
prefix = "$(FILE_DIR)/checkpointer",
overwrite_existing = true)

@info "Running the simulation..."

try
files = readdir(FILE_DIR)
checkpoint_files = files[occursin.("checkpointer_iteration", files)]
if !isempty(checkpoint_files)
checkpoint_iters = parse.(Int, [filename[findfirst("iteration", filename)[end]+1:findfirst(".jld2", filename)[1]-1] for filename in checkpoint_files])
pickup_iter = maximum(checkpoint_iters)
run!(simulation, pickup="$(FILE_DIR)/checkpointer_iteration$(pickup_iter).jld2")
else
run!(simulation)
end
catch err
@info "run! threw an error! The error message is"
showerror(stdout, err)
end

checkpointers = glob("$(FILE_DIR)/checkpointer_iteration*.jld2")
if !isempty(checkpointers)
rm.(checkpointers)
end

# #####
# ##### Visualization
# #####
#%%
using CairoMakie


u_data = FieldTimeSeries("./NN_2D_channel_sin_cooling_$(max_temperature_flux)_LES.jld2", "u", backend=OnDisk())
v_data = FieldTimeSeries("./NN_2D_channel_sin_cooling_$(max_temperature_flux)_LES.jld2", "v", backend=OnDisk())
T_data = FieldTimeSeries("./NN_2D_channel_sin_cooling_$(max_temperature_flux)_LES.jld2", "T", backend=OnDisk())
S_data = FieldTimeSeries("./NN_2D_channel_sin_cooling_$(max_temperature_flux)_LES.jld2", "S", backend=OnDisk())

yC = ynodes(T_data.grid, Center())
yF = ynodes(T_data.grid, Face())

zC = znodes(T_data.grid, Center())
zF = znodes(T_data.grid, Face())

Nt = length(T_data.times)
#%%
fig = Figure(size = (1500, 900))
axu = CairoMakie.Axis(fig[1, 1], xlabel = "y (m)", ylabel = "z (m)", title = "u")
axv = CairoMakie.Axis(fig[1, 3], xlabel = "y (m)", ylabel = "z (m)", title = "v")
axT = CairoMakie.Axis(fig[2, 1], xlabel = "y (m)", ylabel = "z (m)", title = "Temperature")
axS = CairoMakie.Axis(fig[2, 3], xlabel = "y (m)", ylabel = "z (m)", title = "Salinity")
n = Obeservable(1)

uₙ = @lift interior(u_data[$n], 1, :, :)
vₙ = @lift interior(v_data[$n], 1, :, :)
Tₙ = @lift interior(T_data[$n], 1, :, :)
Sₙ = @lift interior(S_data[$n], 1, :, :)

ulim = @lift (-maximum([maximum(abs, $uₙ), 1e-16]), maximum([maximum(abs, $uₙ), 1e-16]))
vlim = @lift (-maximum([maximum(abs, $vₙ), 1e-16]), maximum([maximum(abs, $vₙ), 1e-16]))
Tlim = (minimum(interior(T_data[1])), maximum(interior(T_data[1])))
Slim = (minimum(interior(S_data[1])), maximum(interior(S_data[1])))

title_str = @lift "Time: $(round(T_data.times[$n] / 86400, digits=2)) days"
Label(fig[0, :], title_str, tellwidth = false)

hu = heatmap!(axu, yC, zC, uₙ, colormap=:RdBu_9, colorrange=ulim)
hv = heatmap!(axv, yC, zC, vₙ, colormap=:RdBu_9, colorrange=vlim)
hT = heatmap!(axT, yC, zC, Tₙ, colorrange=Tlim)
hS = heatmap!(axS, yC, zC, Sₙ, colorrange=Slim)

Colorbar(fig[1, 2], hu, label = "(m/s)")
Colorbar(fig[1, 4], hv, label = "(m/s)")
Colorbar(fig[2, 2], hT, label = "(°C)")
Colorbar(fig[2, 4], hS, label = "(psu)")

CairoMakie.record(fig, "$(FILE_DIR)/2D_sin_cooling_$(max_temperature_flux)_30days.mp4", 1:Nt, framerate=15) do nn
n[] = nn
end

# display(fig)
#%%
Loading