Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Basin / subgrid_time table #1975

Merged
merged 16 commits into from
Dec 20, 2024
1 change: 1 addition & 0 deletions core/src/Ribasim.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ using PreallocationTools: LazyBufferCache
# basin profiles and TabulatedRatingCurve. See also the node
# references in the docs.
using DataInterpolations:
ConstantInterpolation,
LinearInterpolation,
LinearInterpolationIntInv,
invert_integral,
Expand Down
22 changes: 19 additions & 3 deletions core/src/callback.jl
Original file line number Diff line number Diff line change
Expand Up @@ -671,12 +671,28 @@ function apply_parameter_update!(parameter_update)::Nothing
end

function update_subgrid_level!(integrator)::Nothing
(; p) = integrator
(; p, t) = integrator
du = get_du(integrator)
basin_level = p.basin.current_properties.current_level[parent(du)]
subgrid = integrator.p.subgrid
for (i, (index, interp)) in enumerate(zip(subgrid.basin_index, subgrid.interpolations))
subgrid.level[i] = interp(basin_level[index])

# First update the all the subgrids with static h(h) relations
for (level_index, basin_index, hh_itp) in zip(
subgrid.level_index_static,
subgrid.basin_index_static,
subgrid.interpolations_static,
)
subgrid.level[level_index] = hh_itp(basin_level[basin_index])
end
# Then update the subgrids with dynamic h(h) relations
for (level_index, basin_index, lookup) in zip(
subgrid.level_index_time,
subgrid.basin_index_time,
subgrid.current_interpolation_index,
)
itp_index = lookup(t)
hh_itp = subgrid.interpolations_time[itp_index]
subgrid.level[level_index] = hh_itp(basin_level[basin_index])
end
end

Expand Down
31 changes: 28 additions & 3 deletions core/src/parameter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ end

Base.to_index(id::NodeID) = Int(id.value)

"LinearInterpolation from a Float64 to a Float64"
const ScalarInterpolation = LinearInterpolation{
Vector{Float64},
Vector{Float64},
Expand All @@ -105,6 +106,10 @@ const ScalarInterpolation = LinearInterpolation{
(1,),
}

"ConstantInterpolation from a Float64 to an Int, used to look up indices over time"
const IndexLookup =
ConstantInterpolation{Vector{Int64}, Vector{Float64}, Vector{Float64}, Int64, (1,)}

set_zero!(v) = v .= zero(eltype(v))
const Cache = LazyBufferCache{Returns{Int}, typeof(set_zero!)}

Expand Down Expand Up @@ -867,10 +872,30 @@ end

"Subgrid linearly interpolates basin levels."
@kwdef struct Subgrid
subgrid_id::Vector{Int32}
basin_index::Vector{Int32}
interpolations::Vector{ScalarInterpolation}
# current level of each subgrid (static and dynamic) ordered by subgrid_id
level::Vector{Float64}

# Static part
# Static subgrid ids
subgrid_id_static::Vector{Int32}
# index into the basin.current_level vector for each static subgrid_id
basin_index_static::Vector{Int}
# index into the subgrid.level vector for each static subgrid_id
level_index_static::Vector{Int}
# per subgrid one relation
interpolations_static::Vector{ScalarInterpolation}

# Dynamic part
# Dynamic subgrid ids
subgrid_id_time::Vector{Int32}
# index into the basin.current_level vector for each dynamic subgrid_id
basin_index_time::Vector{Int}
# index into the subgrid.level vector for each dynamic subgrid_id
level_index_time::Vector{Int}
# per subgrid n relations, n being the number of timesteps for that subgrid
interpolations_time::Vector{ScalarInterpolation}
# per subgrid 1 lookup from t to an index in interpolations_time
current_interpolation_index::Vector{IndexLookup}
end

"""
Expand Down
167 changes: 139 additions & 28 deletions core/src/read.jl
Original file line number Diff line number Diff line change
Expand Up @@ -187,11 +187,22 @@ function parse_static_and_time(
return out, !errors
end

"""
Retrieve and validate the split of node IDs between static and time tables.

For node types that can have a part of the parameters defined statically and a part dynamically,
this checks if each ID is defined exactly once in either table.

The `is_complete` argument allows disabling the check that all Node IDs of type `node_type`
are either in the `static` or `time` table.
This is not required for Subgrid since not all Basins need to have subgrids.
"""
function static_and_time_node_ids(
db::DB,
static::StructVector,
time::StructVector,
node_type::NodeType.T,
node_type::NodeType.T;
is_complete::Bool = true,
visr marked this conversation as resolved.
Show resolved Hide resolved
)::Tuple{Set{NodeID}, Set{NodeID}, Vector{NodeID}, Bool}
node_ids = get_node_ids(db, node_type)
ids = Int32.(node_ids)
Expand All @@ -205,7 +216,7 @@ function static_and_time_node_ids(
errors = true
@error "$node_type cannot be in both static and time tables, found these node IDs in both: $doubles."
end
if !issetequal(node_ids, union(static_node_ids, time_node_ids))
if is_complete && !issetequal(node_ids, union(static_node_ids, time_node_ids))
errors = true
@error "$node_type node IDs don't match."
end
Expand Down Expand Up @@ -1227,46 +1238,146 @@ function FlowDemand(db::DB, config::Config)::FlowDemand
)
end

function push_lookup!(
current_interpolation_index::Vector{IndexLookup},
lookup_index::Vector{Int},
lookup_time::Vector{Float64},
)
index_lookup = ConstantInterpolation(
lookup_index,
lookup_time;
extrapolate = true,
cache_parameters = true,
)
push!(current_interpolation_index, index_lookup)
end

function Subgrid(db::DB, config::Config, basin::Basin)::Subgrid
node_to_basin = Dict(node_id => index for (index, node_id) in enumerate(basin.node_id))
tables = load_structvector(db, config, BasinSubgridV1)
time = load_structvector(db, config, BasinSubgridTimeV1)
static = load_structvector(db, config, BasinSubgridV1)
node_table = get_node_ids(db, NodeType.Basin)

subgrid_ids = Int32[]
basin_index = Int32[]
interpolations = ScalarInterpolation[]
has_error = false
for group in IterTools.groupby(row -> row.subgrid_id, tables)
# Since not all Basins need to have subgrids, don't enforce completeness.
_, _, _, valid =
static_and_time_node_ids(db, static, time, "Basin"; is_complete = false)
if !valid
error("Problems encountered when parsing Subgrid static and time node IDs.")
end

node_to_basin = Dict{Int32, Int}(
Int32(node_id) => index for (index, node_id) in enumerate(basin.node_id)
)
subgrid_id_static = Int32[]
basin_index_static = Int[]
interpolations_static = ScalarInterpolation[]

for group in IterTools.groupby(row -> row.subgrid_id, static)
subgrid_id = first(getproperty.(group, :subgrid_id))
node_id = NodeID(NodeType.Basin, first(getproperty.(group, :node_id)), node_table)
node_id = first(getproperty.(group, :node_id))
visr marked this conversation as resolved.
Show resolved Hide resolved
basin_level = getproperty.(group, :basin_level)
subgrid_level = getproperty.(group, :subgrid_level)

is_valid =
valid_subgrid(subgrid_id, node_id, node_to_basin, basin_level, subgrid_level)
!is_valid && error("Invalid Basin / subgrid table.")

# Ensure it doesn't extrapolate before the first value.
pushfirst!(subgrid_level, first(subgrid_level))
pushfirst!(basin_level, nextfloat(-Inf))
hh_itp = LinearInterpolation(
subgrid_level,
basin_level;
extrapolate = true,
cache_parameters = true,
)
push!(subgrid_id_static, subgrid_id)
push!(basin_index_static, node_to_basin[node_id])
push!(interpolations_static, hh_itp)
end

if is_valid
# Ensure it doesn't extrapolate before the first value.
pushfirst!(subgrid_level, first(subgrid_level))
pushfirst!(basin_level, nextfloat(-Inf))
new_interp = LinearInterpolation(
subgrid_level,
basin_level;
extrapolate = true,
cache_parameters = true,
)
push!(subgrid_ids, subgrid_id)
push!(basin_index, node_to_basin[node_id])
push!(interpolations, new_interp)
else
has_error = true
subgrid_id_time = Int32[]
basin_index_time = Int[]
interpolations_time = ScalarInterpolation[]
current_interpolation_index = IndexLookup[]

# Push the first subgrid_id and basin_index
if length(time) > 0
push!(subgrid_id_time, first(time.subgrid_id))
push!(basin_index_time, node_to_basin[first(time.node_id)])
end

# Initialize index_lookup contents
lookup_time = Float64[]
lookup_index = Int[]

interpolation_index = 0
for group in IterTools.groupby(row -> (row.subgrid_id, row.time), time)
evetion marked this conversation as resolved.
Show resolved Hide resolved
interpolation_index += 1
subgrid_id = first(getproperty.(group, :subgrid_id))
time_group = seconds_since(first(getproperty.(group, :time)), config.starttime)
node_id = first(getproperty.(group, :node_id))
basin_level = getproperty.(group, :basin_level)
subgrid_level = getproperty.(group, :subgrid_level)

is_valid =
valid_subgrid(subgrid_id, node_id, node_to_basin, basin_level, subgrid_level)
!is_valid && error("Invalid Basin / subgrid_time table.")

# Ensure it doesn't extrapolate before the first value.
pushfirst!(subgrid_level, first(subgrid_level))
pushfirst!(basin_level, nextfloat(-Inf))
hh_itp = LinearInterpolation(
subgrid_level,
basin_level;
extrapolate = true,
cache_parameters = true,
)
# # These should only be pushed when the subgrid_id has changed
if subgrid_id_time[end] != subgrid_id
# Push the completed index_lookup of the previous subgrid_id
push_lookup!(current_interpolation_index, lookup_index, lookup_time)
# Push the new subgrid_id and basin_index
push!(subgrid_id_time, subgrid_id)
push!(basin_index_time, node_to_basin[node_id])
# Start new index_lookup contents
lookup_time = Float64[]
lookup_index = Int[]
end
push!(lookup_index, interpolation_index)
push!(lookup_time, time_group)
push!(interpolations_time, hh_itp)
end

# Push completed IndexLookup of the last group
if interpolation_index > 0
push_lookup!(current_interpolation_index, lookup_index, lookup_time)
end

has_error && error("Invalid Basin / subgrid table.")
level = fill(NaN, length(subgrid_ids))
level = fill(NaN, length(subgrid_id_static) + length(subgrid_id_time))

return Subgrid(; subgrid_id = subgrid_ids, basin_index, interpolations, level)
# Find the level indices
evetion marked this conversation as resolved.
Show resolved Hide resolved
level_index_static = zeros(Int, length(subgrid_id_static))
level_index_time = zeros(Int, length(subgrid_id_time))
subgrid_ids = sort(vcat(subgrid_id_static, subgrid_id_time))
for (i, subgrid_id) in enumerate(subgrid_id_static)
level_index_static[i] = findsorted(subgrid_ids, subgrid_id)
end
for (i, subgrid_id) in enumerate(subgrid_id_time)
level_index_time[i] = findsorted(subgrid_ids, subgrid_id)
end

return Subgrid(;
level,
subgrid_id_static,
basin_index_static,
level_index_static,
interpolations_static,
subgrid_id_time,
basin_index_time,
level_index_time,
interpolations_time,
current_interpolation_index,
)
end

function Allocation(db::DB, config::Config, graph::MetaGraph)::Allocation
Expand Down
12 changes: 12 additions & 0 deletions core/src/schema.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
@schema "ribasim.basin.profile" BasinProfile
@schema "ribasim.basin.state" BasinState
@schema "ribasim.basin.subgrid" BasinSubgrid
@schema "ribasim.basin.subgridtime" BasinSubgridTime
@schema "ribasim.basin.concentration" BasinConcentration
@schema "ribasim.basin.concentrationexternal" BasinConcentrationExternal
@schema "ribasim.basin.concentrationstate" BasinConcentrationState
Expand Down Expand Up @@ -58,8 +59,11 @@ function nodetype(
type_string = string(T)
elements = split(type_string, '.'; limit = 3)
last_element = last(elements)
# Special case last elements that need an underscore
if startswith(last_element, "concentration") && length(last_element) > 13
elements[end] = "concentration_$(last_element[14:end])"
elseif last_element == "subgridtime"
elements[end] = "subgrid_time"
end
if isnode(sv)
n = elements[2]
Expand Down Expand Up @@ -150,6 +154,14 @@ end
subgrid_level::Float64
end

@version BasinSubgridTimeV1 begin
subgrid_id::Int32
node_id::Int32
time::DateTime
basin_level::Float64
subgrid_level::Float64
end

@version LevelBoundaryStaticV1 begin
node_id::Int32
active::Union{Missing, Bool}
Expand Down
4 changes: 2 additions & 2 deletions core/src/validation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -321,8 +321,8 @@ Validate the entries for a single subgrid element.
"""
function valid_subgrid(
subgrid_id::Int32,
node_id::NodeID,
node_to_basin::Dict{NodeID, Int},
node_id::Int32,
node_to_basin::Dict{Int32, Int},
basin_level::Vector{Float64},
subgrid_level::Vector{Float64},
)::Bool
Expand Down
11 changes: 7 additions & 4 deletions core/src/write.jl
Original file line number Diff line number Diff line change
Expand Up @@ -376,11 +376,14 @@ function subgrid_level_table(
(; t, saveval) = saved.subgrid_level
subgrid = integrator.p.subgrid

nelem = length(subgrid.subgrid_id)
nelem = length(subgrid.level)
ntsteps = length(t)

time = repeat(datetime_since.(t, config.starttime); inner = nelem)
subgrid_id = repeat(subgrid.subgrid_id; outer = ntsteps)
subgrid_id = repeat(
sort(vcat(subgrid.subgrid_id_static, subgrid.subgrid_id_time));
outer = ntsteps,
evetion marked this conversation as resolved.
Show resolved Hide resolved
)
subgrid_level = FlatVector(saveval)
return (; time, subgrid_id, subgrid_level)
end
Expand Down Expand Up @@ -412,9 +415,9 @@ function write_arrow(
mkpath(dirname(path))
try
Arrow.write(path, table; compress, metadata)
catch
catch e
@error "Failed to write results, file may be locked." path
error("Failed to write results.")
rethrow(e)
end
return nothing
end
Expand Down
Loading
Loading