Skip to content

Commit

Permalink
Keep subgrid.level vector in order of subgrid_id
Browse files Browse the repository at this point in the history
  • Loading branch information
visr committed Dec 16, 2024
1 parent 24b8fff commit 1a86893
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 32 deletions.
19 changes: 12 additions & 7 deletions core/src/callback.jl
Original file line number Diff line number Diff line change
Expand Up @@ -676,18 +676,23 @@ function update_subgrid_level!(integrator)::Nothing
basin_level = p.basin.current_properties.current_level[parent(du)]
subgrid = integrator.p.subgrid

i = 0
# First update the all the subgrids with static h(h) relations
for (index, hh_itp) in zip(subgrid.basin_index, subgrid.interpolations)
i += 1
subgrid.level[i] = hh_itp(basin_level[index])
for (level_index, basin_index, hh_itp) in zip(
subgrid.level_index_static,
subgrid.basin_index_static,
subgrid.interpolations_static,
)
subgrid.level[level_index] = hh_itp(basin_level[basin_index])
end
# Then update the subgrids with dynamic h(h) relations
for (index, lookup) in zip(subgrid.basin_index_time, current_interpolation_index)
i += 1
for (level_index, basin_index, lookup) in zip(
subgrid.level_index_time,
subgrid.basin_index_time,
subgrid.current_interpolation_index,
)
itp_index = lookup(t)
hh_itp = subgrid.interpolations_time[itp_index]
subgrid.level[i] = hh_itp(basin_level[index])
subgrid.level[level_index] = hh_itp(basin_level[basin_index])
end
end

Expand Down
16 changes: 11 additions & 5 deletions core/src/parameter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -881,16 +881,22 @@ end

"Subgrid linearly interpolates basin levels."
@kwdef struct Subgrid
# cache the current level for static subgrids followed by dynamic subgrids
# level of each subgrid (static and dynamic) ordered by subgrid_id
level::Vector{Float64}
# static
subgrid_id::Vector{Int32}
basin_index::Vector{Int32}
subgrid_id_static::Vector{Int32}
# index into the basin.current_level vector for each static subgrid_id
basin_index_static::Vector{Int}
# index into the subgrid.level vector for each static subgrid_id
level_index_static::Vector{Int}
# per subgrid one relation
interpolations::Vector{ScalarInterpolation}
interpolations_static::Vector{ScalarInterpolation}
# dynamic
subgrid_id_time::Vector{Int32}
basin_index_time::Vector{Int32}
# index into the basin.current_level vector for each dynamic subgrid_id
basin_index_time::Vector{Int}
# index into the subgrid.level vector for each dynamic subgrid_id
level_index_time::Vector{Int}
# per subgrid n relations, n being the number of timesteps for that subgrid
interpolations_time::Vector{ScalarInterpolation}
# per subgrid 1 lookup from t to an index in interpolations_time
Expand Down
41 changes: 27 additions & 14 deletions core/src/read.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1258,9 +1258,9 @@ function Subgrid(db::DB, config::Config, basin::Basin)::Subgrid
node_to_basin = Dict{Int32, Int}(
Int32(node_id) => index for (index, node_id) in enumerate(basin.node_id)
)
subgrid_ids = Int32[]
basin_index = Int32[]
interpolations = ScalarInterpolation[]
subgrid_id_static = Int32[]
basin_index_static = Int[]
interpolations_static = ScalarInterpolation[]
has_error = false

for group in IterTools.groupby(row -> row.subgrid_id, static)
Expand All @@ -1276,15 +1276,15 @@ function Subgrid(db::DB, config::Config, basin::Basin)::Subgrid
# Ensure it doesn't extrapolate before the first value.
pushfirst!(subgrid_level, first(subgrid_level))
pushfirst!(basin_level, nextfloat(-Inf))
new_interp = LinearInterpolation(
hh_itp = LinearInterpolation(
subgrid_level,
basin_level;
extrapolate = true,
cache_parameters = true,
)
push!(subgrid_ids, subgrid_id)
push!(basin_index, node_to_basin[node_id])
push!(interpolations, new_interp)
push!(subgrid_id_static, subgrid_id)
push!(basin_index_static, node_to_basin[node_id])
push!(interpolations_static, hh_itp)
else
has_error = true
end
Expand All @@ -1293,7 +1293,7 @@ function Subgrid(db::DB, config::Config, basin::Basin)::Subgrid
has_error && error("Invalid Basin / subgrid table.")

subgrid_id_time = Int32[first(time.subgrid_id)]
basin_index_time = Int32[node_to_basin[first(time.node_id)]]
basin_index_time = Int[node_to_basin[first(time.node_id)]]
interpolations_time = ScalarInterpolation[]
current_interpolation_index = IndexLookup[]

Expand All @@ -1317,7 +1317,7 @@ function Subgrid(db::DB, config::Config, basin::Basin)::Subgrid
# Ensure it doesn't extrapolate before the first value.
pushfirst!(subgrid_level, first(subgrid_level))
pushfirst!(basin_level, nextfloat(-Inf))
new_interp = LinearInterpolation(
hh_itp = LinearInterpolation(
subgrid_level,
basin_level;
extrapolate = true,
Expand All @@ -1336,7 +1336,7 @@ function Subgrid(db::DB, config::Config, basin::Basin)::Subgrid
end
push!(lookup_index, interpolation_index)
push!(lookup_time, time_group)
push!(interpolations_time, new_interp)
push!(interpolations_time, hh_itp)
else
has_error = true
end
Expand All @@ -1348,15 +1348,28 @@ function Subgrid(db::DB, config::Config, basin::Basin)::Subgrid
end

has_error && error("Invalid Basin / subgrid_time table.")
level = fill(NaN, length(subgrid_ids) + length(subgrid_id_time))
level = fill(NaN, length(subgrid_id_static) + length(subgrid_id_time))

# Find the level indices
level_index_static = zeros(Int, length(subgrid_id_static))
level_index_time = zeros(Int, length(subgrid_id_time))
subgrid_ids = sort(vcat(subgrid_id_static, subgrid_id_time))
for (i, subgrid_id) in enumerate(subgrid_id_static)
level_index_static[i] = findsorted(subgrid_ids, subgrid_id)
end
for (i, subgrid_id) in enumerate(subgrid_id_time)
level_index_time[i] = findsorted(subgrid_ids, subgrid_id)
end

return Subgrid(;
level,
subgrid_id = subgrid_ids,
basin_index,
interpolations,
subgrid_id_static,
basin_index_static,
level_index_static,
interpolations_static,
subgrid_id_time,
basin_index_time,
level_index_time,
interpolations_time,
current_interpolation_index,
)
Expand Down
11 changes: 7 additions & 4 deletions core/src/write.jl
Original file line number Diff line number Diff line change
Expand Up @@ -376,11 +376,14 @@ function subgrid_level_table(
(; t, saveval) = saved.subgrid_level
subgrid = integrator.p.subgrid

nelem = length(subgrid.subgrid_id)
nelem = length(subgrid.level)
ntsteps = length(t)

time = repeat(datetime_since.(t, config.starttime); inner = nelem)
subgrid_id = repeat(subgrid.subgrid_id; outer = ntsteps)
subgrid_id = repeat(
sort(vcat(subgrid.subgrid_id_static, subgrid.subgrid_id_time));
outer = ntsteps,
)
subgrid_level = FlatVector(saveval)
return (; time, subgrid_id, subgrid_level)
end
Expand Down Expand Up @@ -412,9 +415,9 @@ function write_arrow(
mkpath(dirname(path))
try
Arrow.write(path, table; compress, metadata)
catch
catch e
@error "Failed to write results, file may be locked." path
error("Failed to write results.")
rethrow(e)
end
return nothing
end
Expand Down
9 changes: 7 additions & 2 deletions python/ribasim_testmodels/ribasim_testmodels/two_basin.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Any

from ribasim.config import Node
from ribasim.config import Node, Results
from ribasim.input_base import TableModel
from ribasim.model import Model
from ribasim.nodes import basin, flow_boundary, tabulated_rating_curve
Expand All @@ -18,7 +18,12 @@ def two_basin_model() -> Model:
infiltrates in the left basin, and exfiltrates in the right basin.
The right basin fills up and discharges over the rating curve.
"""
model = Model(starttime="2020-01-01", endtime="2021-01-01", crs="EPSG:28992")
model = Model(
starttime="2020-01-01",
endtime="2021-01-01",
crs="EPSG:28992",
results=Results(subgrid=True),
)

model.flow_boundary.add(
Node(1, Point(0, 0)), [flow_boundary.Static(flow_rate=[1e-2])]
Expand Down

0 comments on commit 1a86893

Please sign in to comment.