diff --git a/core/src/callback.jl b/core/src/callback.jl index 4e2440210..d332d7a00 100644 --- a/core/src/callback.jl +++ b/core/src/callback.jl @@ -676,18 +676,23 @@ function update_subgrid_level!(integrator)::Nothing basin_level = p.basin.current_properties.current_level[parent(du)] subgrid = integrator.p.subgrid - i = 0 # First update the all the subgrids with static h(h) relations - for (index, hh_itp) in zip(subgrid.basin_index, subgrid.interpolations) - i += 1 - subgrid.level[i] = hh_itp(basin_level[index]) + for (level_index, basin_index, hh_itp) in zip( + subgrid.level_index_static, + subgrid.basin_index_static, + subgrid.interpolations_static, + ) + subgrid.level[level_index] = hh_itp(basin_level[basin_index]) end # Then update the subgrids with dynamic h(h) relations - for (index, lookup) in zip(subgrid.basin_index_time, current_interpolation_index) - i += 1 + for (level_index, basin_index, lookup) in zip( + subgrid.level_index_time, + subgrid.basin_index_time, + subgrid.current_interpolation_index, + ) itp_index = lookup(t) hh_itp = subgrid.interpolations_time[itp_index] - subgrid.level[i] = hh_itp(basin_level[index]) + subgrid.level[level_index] = hh_itp(basin_level[basin_index]) end end diff --git a/core/src/parameter.jl b/core/src/parameter.jl index a72b5796e..2c4b895d9 100644 --- a/core/src/parameter.jl +++ b/core/src/parameter.jl @@ -881,16 +881,22 @@ end "Subgrid linearly interpolates basin levels." @kwdef struct Subgrid - # cache the current level for static subgrids followed by dynamic subgrids + # level of each subgrid (static and dynamic) ordered by subgrid_id level::Vector{Float64} # static - subgrid_id::Vector{Int32} - basin_index::Vector{Int32} + subgrid_id_static::Vector{Int32} + # index into the basin.current_level vector for each static subgrid_id + basin_index_static::Vector{Int} + # index into the subgrid.level vector for each static subgrid_id + level_index_static::Vector{Int} # per subgrid one relation - interpolations::Vector{ScalarInterpolation} + interpolations_static::Vector{ScalarInterpolation} # dynamic subgrid_id_time::Vector{Int32} - basin_index_time::Vector{Int32} + # index into the basin.current_level vector for each dynamic subgrid_id + basin_index_time::Vector{Int} + # index into the subgrid.level vector for each dynamic subgrid_id + level_index_time::Vector{Int} # per subgrid n relations, n being the number of timesteps for that subgrid interpolations_time::Vector{ScalarInterpolation} # per subgrid 1 lookup from t to an index in interpolations_time diff --git a/core/src/read.jl b/core/src/read.jl index f9d5b2fe2..7e068a906 100644 --- a/core/src/read.jl +++ b/core/src/read.jl @@ -1258,9 +1258,9 @@ function Subgrid(db::DB, config::Config, basin::Basin)::Subgrid node_to_basin = Dict{Int32, Int}( Int32(node_id) => index for (index, node_id) in enumerate(basin.node_id) ) - subgrid_ids = Int32[] - basin_index = Int32[] - interpolations = ScalarInterpolation[] + subgrid_id_static = Int32[] + basin_index_static = Int[] + interpolations_static = ScalarInterpolation[] has_error = false for group in IterTools.groupby(row -> row.subgrid_id, static) @@ -1276,15 +1276,15 @@ function Subgrid(db::DB, config::Config, basin::Basin)::Subgrid # Ensure it doesn't extrapolate before the first value. pushfirst!(subgrid_level, first(subgrid_level)) pushfirst!(basin_level, nextfloat(-Inf)) - new_interp = LinearInterpolation( + hh_itp = LinearInterpolation( subgrid_level, basin_level; extrapolate = true, cache_parameters = true, ) - push!(subgrid_ids, subgrid_id) - push!(basin_index, node_to_basin[node_id]) - push!(interpolations, new_interp) + push!(subgrid_id_static, subgrid_id) + push!(basin_index_static, node_to_basin[node_id]) + push!(interpolations_static, hh_itp) else has_error = true end @@ -1293,7 +1293,7 @@ function Subgrid(db::DB, config::Config, basin::Basin)::Subgrid has_error && error("Invalid Basin / subgrid table.") subgrid_id_time = Int32[first(time.subgrid_id)] - basin_index_time = Int32[node_to_basin[first(time.node_id)]] + basin_index_time = Int[node_to_basin[first(time.node_id)]] interpolations_time = ScalarInterpolation[] current_interpolation_index = IndexLookup[] @@ -1317,7 +1317,7 @@ function Subgrid(db::DB, config::Config, basin::Basin)::Subgrid # Ensure it doesn't extrapolate before the first value. pushfirst!(subgrid_level, first(subgrid_level)) pushfirst!(basin_level, nextfloat(-Inf)) - new_interp = LinearInterpolation( + hh_itp = LinearInterpolation( subgrid_level, basin_level; extrapolate = true, @@ -1336,7 +1336,7 @@ function Subgrid(db::DB, config::Config, basin::Basin)::Subgrid end push!(lookup_index, interpolation_index) push!(lookup_time, time_group) - push!(interpolations_time, new_interp) + push!(interpolations_time, hh_itp) else has_error = true end @@ -1348,15 +1348,28 @@ function Subgrid(db::DB, config::Config, basin::Basin)::Subgrid end has_error && error("Invalid Basin / subgrid_time table.") - level = fill(NaN, length(subgrid_ids) + length(subgrid_id_time)) + level = fill(NaN, length(subgrid_id_static) + length(subgrid_id_time)) + + # Find the level indices + level_index_static = zeros(Int, length(subgrid_id_static)) + level_index_time = zeros(Int, length(subgrid_id_time)) + subgrid_ids = sort(vcat(subgrid_id_static, subgrid_id_time)) + for (i, subgrid_id) in enumerate(subgrid_id_static) + level_index_static[i] = findsorted(subgrid_ids, subgrid_id) + end + for (i, subgrid_id) in enumerate(subgrid_id_time) + level_index_time[i] = findsorted(subgrid_ids, subgrid_id) + end return Subgrid(; level, - subgrid_id = subgrid_ids, - basin_index, - interpolations, + subgrid_id_static, + basin_index_static, + level_index_static, + interpolations_static, subgrid_id_time, basin_index_time, + level_index_time, interpolations_time, current_interpolation_index, ) diff --git a/core/src/write.jl b/core/src/write.jl index b81e5d01d..4ccc91312 100644 --- a/core/src/write.jl +++ b/core/src/write.jl @@ -376,11 +376,14 @@ function subgrid_level_table( (; t, saveval) = saved.subgrid_level subgrid = integrator.p.subgrid - nelem = length(subgrid.subgrid_id) + nelem = length(subgrid.level) ntsteps = length(t) time = repeat(datetime_since.(t, config.starttime); inner = nelem) - subgrid_id = repeat(subgrid.subgrid_id; outer = ntsteps) + subgrid_id = repeat( + sort(vcat(subgrid.subgrid_id_static, subgrid.subgrid_id_time)); + outer = ntsteps, + ) subgrid_level = FlatVector(saveval) return (; time, subgrid_id, subgrid_level) end @@ -412,9 +415,9 @@ function write_arrow( mkpath(dirname(path)) try Arrow.write(path, table; compress, metadata) - catch + catch e @error "Failed to write results, file may be locked." path - error("Failed to write results.") + rethrow(e) end return nothing end diff --git a/python/ribasim_testmodels/ribasim_testmodels/two_basin.py b/python/ribasim_testmodels/ribasim_testmodels/two_basin.py index acc52c4aa..0aecdc47e 100644 --- a/python/ribasim_testmodels/ribasim_testmodels/two_basin.py +++ b/python/ribasim_testmodels/ribasim_testmodels/two_basin.py @@ -1,6 +1,6 @@ from typing import Any -from ribasim.config import Node +from ribasim.config import Node, Results from ribasim.input_base import TableModel from ribasim.model import Model from ribasim.nodes import basin, flow_boundary, tabulated_rating_curve @@ -18,7 +18,12 @@ def two_basin_model() -> Model: infiltrates in the left basin, and exfiltrates in the right basin. The right basin fills up and discharges over the rating curve. """ - model = Model(starttime="2020-01-01", endtime="2021-01-01", crs="EPSG:28992") + model = Model( + starttime="2020-01-01", + endtime="2021-01-01", + crs="EPSG:28992", + results=Results(subgrid=True), + ) model.flow_boundary.add( Node(1, Point(0, 0)), [flow_boundary.Static(flow_rate=[1e-2])]