Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Basin / subgrid_time table #1975

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions core/src/Ribasim.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ using PreallocationTools: LazyBufferCache
# basin profiles and TabulatedRatingCurve. See also the node
# references in the docs.
using DataInterpolations:
ConstantInterpolation,
LinearInterpolation,
LinearInterpolationIntInv,
invert_integral,
Expand Down
22 changes: 19 additions & 3 deletions core/src/callback.jl
Original file line number Diff line number Diff line change
Expand Up @@ -671,12 +671,28 @@ function apply_parameter_update!(parameter_update)::Nothing
end

function update_subgrid_level!(integrator)::Nothing
(; p) = integrator
(; p, t) = integrator
du = get_du(integrator)
basin_level = p.basin.current_properties.current_level[parent(du)]
subgrid = integrator.p.subgrid
for (i, (index, interp)) in enumerate(zip(subgrid.basin_index, subgrid.interpolations))
subgrid.level[i] = interp(basin_level[index])

# First update the all the subgrids with static h(h) relations
for (level_index, basin_index, hh_itp) in zip(
subgrid.level_index_static,
subgrid.basin_index_static,
subgrid.interpolations_static,
)
subgrid.level[level_index] = hh_itp(basin_level[basin_index])
end
# Then update the subgrids with dynamic h(h) relations
for (level_index, basin_index, lookup) in zip(
subgrid.level_index_time,
subgrid.basin_index_time,
subgrid.current_interpolation_index,
)
itp_index = lookup(t)
hh_itp = subgrid.interpolations_time[itp_index]
subgrid.level[level_index] = hh_itp(basin_level[basin_index])
end
end

Expand Down
27 changes: 24 additions & 3 deletions core/src/parameter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ end

Base.to_index(id::NodeID) = Int(id.value)

"LinearInterpolation from a Float64 to a Float64"
const ScalarInterpolation = LinearInterpolation{
Vector{Float64},
Vector{Float64},
Expand All @@ -114,6 +115,10 @@ const ScalarInterpolation = LinearInterpolation{
(1,),
}

"ConstantInterpolation from a Float64 to an Int, used to look up indices"
visr marked this conversation as resolved.
Show resolved Hide resolved
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
"ConstantInterpolation from a Float64 to an Int, used to look up indices"
"ConstantInterpolation from a Float64 to an Int, used to look up indices over time"

const IndexLookup =
ConstantInterpolation{Vector{Int64}, Vector{Float64}, Vector{Float64}, Int64, (1,)}

set_zero!(v) = v .= zero(eltype(v))
const Cache = LazyBufferCache{Returns{Int}, typeof(set_zero!)}

Expand Down Expand Up @@ -876,10 +881,26 @@ end

"Subgrid linearly interpolates basin levels."
@kwdef struct Subgrid
subgrid_id::Vector{Int32}
basin_index::Vector{Int32}
interpolations::Vector{ScalarInterpolation}
# level of each subgrid (static and dynamic) ordered by subgrid_id
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# level of each subgrid (static and dynamic) ordered by subgrid_id
# current level of each subgrid (static and dynamic) ordered by subgrid_id

level::Vector{Float64}
# static
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# static
# Static part
# Static subgrid ids

subgrid_id_static::Vector{Int32}
# index into the basin.current_level vector for each static subgrid_id
basin_index_static::Vector{Int}
# index into the subgrid.level vector for each static subgrid_id
level_index_static::Vector{Int}
# per subgrid one relation
interpolations_static::Vector{ScalarInterpolation}
# dynamic
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# dynamic
# Dynamic part
# Dynamic subgrid ids

subgrid_id_time::Vector{Int32}
# index into the basin.current_level vector for each dynamic subgrid_id
basin_index_time::Vector{Int}
# index into the subgrid.level vector for each dynamic subgrid_id
level_index_time::Vector{Int}
# per subgrid n relations, n being the number of timesteps for that subgrid
interpolations_time::Vector{ScalarInterpolation}
# per subgrid 1 lookup from t to an index in interpolations_time
current_interpolation_index::Vector{IndexLookup}
end

"""
Expand Down
168 changes: 140 additions & 28 deletions core/src/read.jl
Original file line number Diff line number Diff line change
Expand Up @@ -187,11 +187,22 @@ function parse_static_and_time(
return out, !errors
end

"""
Validate the split of node IDs between static and time tables.
visr marked this conversation as resolved.
Show resolved Hide resolved
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
Validate the split of node IDs between static and time tables.
Retrieve and validate the split of node IDs between static and time tables.


For node types that can have a part of the parameters defined statically and a part dynamically,
this checks if each ID is defined exactly once in either table.

The `is_complete` argument allows disabling the check that all Node IDs of type `node_type`
are either in the `static` or `time` table.
This is not required for Subgrid since not all Basins need to have subgrids.
"""
function static_and_time_node_ids(
db::DB,
static::StructVector,
time::StructVector,
node_type::String,
node_type::String;
is_complete::Bool = true,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is undocumented (and unclear to me). I get that there's a state before/after, but not what the difference(s) are.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mentioned it at the callsite, but also added docs here.

)::Tuple{Set{NodeID}, Set{NodeID}, Vector{NodeID}, Bool}
ids = get_ids(db, node_type)
idx = searchsortedfirst.(Ref(ids), static.node_id)
Expand All @@ -205,7 +216,7 @@ function static_and_time_node_ids(
errors = true
@error "$node_type cannot be in both static and time tables, found these node IDs in both: $doubles."
end
if !issetequal(node_ids, union(static_node_ids, time_node_ids))
if is_complete && !issetequal(node_ids, union(static_node_ids, time_node_ids))
errors = true
@error "$node_type node IDs don't match."
end
Expand Down Expand Up @@ -1226,45 +1237,146 @@ function FlowDemand(db::DB, config::Config)::FlowDemand
)
end

function push_lookup!(
current_interpolation_index::Vector{IndexLookup},
lookup_index::Vector{Int},
lookup_time::Vector{Float64},
)
index_lookup = ConstantInterpolation(
lookup_index,
lookup_time;
extrapolate = true,
cache_parameters = true,
)
push!(current_interpolation_index, index_lookup)
end

function Subgrid(db::DB, config::Config, basin::Basin)::Subgrid
node_to_basin = Dict(node_id => index for (index, node_id) in enumerate(basin.node_id))
tables = load_structvector(db, config, BasinSubgridV1)
time = load_structvector(db, config, BasinSubgridTimeV1)
static = load_structvector(db, config, BasinSubgridV1)

subgrid_ids = Int32[]
basin_index = Int32[]
interpolations = ScalarInterpolation[]
has_error = false
for group in IterTools.groupby(row -> row.subgrid_id, tables)
# Since not all Basins need to have subgrids, don't enforce completeness.
_, _, _, valid =
static_and_time_node_ids(db, static, time, "Basin"; is_complete = false)
if !valid
error("Problems encountered when parsing Subgrid static and time node IDs.")
end

node_to_basin = Dict{Int32, Int}(
Int32(node_id) => index for (index, node_id) in enumerate(basin.node_id)
)
subgrid_id_static = Int32[]
basin_index_static = Int[]
interpolations_static = ScalarInterpolation[]

for group in IterTools.groupby(row -> row.subgrid_id, static)
subgrid_id = first(getproperty.(group, :subgrid_id))
node_id = NodeID(NodeType.Basin, first(getproperty.(group, :node_id)), db)
node_id = first(getproperty.(group, :node_id))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we away with NodeID here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It simplified the code, otherwise we needed to strip out the integers later on.

basin_level = getproperty.(group, :basin_level)
subgrid_level = getproperty.(group, :subgrid_level)

is_valid =
valid_subgrid(subgrid_id, node_id, node_to_basin, basin_level, subgrid_level)
!is_valid && error("Invalid Basin / subgrid table.")

# Ensure it doesn't extrapolate before the first value.
pushfirst!(subgrid_level, first(subgrid_level))
pushfirst!(basin_level, nextfloat(-Inf))
hh_itp = LinearInterpolation(
subgrid_level,
basin_level;
extrapolate = true,
cache_parameters = true,
)
push!(subgrid_id_static, subgrid_id)
push!(basin_index_static, node_to_basin[node_id])
push!(interpolations_static, hh_itp)
end

if is_valid
# Ensure it doesn't extrapolate before the first value.
pushfirst!(subgrid_level, first(subgrid_level))
pushfirst!(basin_level, nextfloat(-Inf))
new_interp = LinearInterpolation(
subgrid_level,
basin_level;
extrapolate = true,
cache_parameters = true,
)
push!(subgrid_ids, subgrid_id)
push!(basin_index, node_to_basin[node_id])
push!(interpolations, new_interp)
else
has_error = true
subgrid_id_time = Int32[]
basin_index_time = Int[]
interpolations_time = ScalarInterpolation[]
current_interpolation_index = IndexLookup[]

# Push the first subgrid_id and basin_index
if length(time) > 0
push!(subgrid_id_time, first(time.subgrid_id))
push!(basin_index_time, node_to_basin[first(time.node_id)])
end

# Initialize index_lookup contents
lookup_time = Float64[]
lookup_index = Int[]

interpolation_index = 0
for group in IterTools.groupby(row -> (row.subgrid_id, row.time), time)
interpolation_index += 1
subgrid_id = first(getproperty.(group, :subgrid_id))
time_group = seconds_since(first(getproperty.(group, :time)), config.starttime)
node_id = first(getproperty.(group, :node_id))
basin_level = getproperty.(group, :basin_level)
subgrid_level = getproperty.(group, :subgrid_level)

is_valid =
valid_subgrid(subgrid_id, node_id, node_to_basin, basin_level, subgrid_level)
!is_valid && error("Invalid Basin / subgrid_time table.")

# Ensure it doesn't extrapolate before the first value.
pushfirst!(subgrid_level, first(subgrid_level))
pushfirst!(basin_level, nextfloat(-Inf))
hh_itp = LinearInterpolation(
subgrid_level,
basin_level;
extrapolate = true,
cache_parameters = true,
)
# # These should only be pushed when the subgrid_id has changed
if subgrid_id_time[end] != subgrid_id
# Push the completed index_lookup of the previous subgrid_id
push_lookup!(current_interpolation_index, lookup_index, lookup_time)
# Push the new subgrid_id and basin_index
push!(subgrid_id_time, subgrid_id)
push!(basin_index_time, node_to_basin[node_id])
# Start new index_lookup contents
lookup_time = Float64[]
lookup_index = Int[]
end
push!(lookup_index, interpolation_index)
push!(lookup_time, time_group)
push!(interpolations_time, hh_itp)
end

# Push completed IndexLookup of the last group
if interpolation_index > 0
push_lookup!(current_interpolation_index, lookup_index, lookup_time)
end

has_error && error("Invalid Basin / subgrid table.")
level = fill(NaN, length(subgrid_ids))
has_error && error("Invalid Basin / subgrid_time table.")
level = fill(NaN, length(subgrid_id_static) + length(subgrid_id_time))

return Subgrid(; subgrid_id = subgrid_ids, basin_index, interpolations, level)
# Find the level indices
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This static part can be moved up?

level_index_static = zeros(Int, length(subgrid_id_static))
level_index_time = zeros(Int, length(subgrid_id_time))
subgrid_ids = sort(vcat(subgrid_id_static, subgrid_id_time))
for (i, subgrid_id) in enumerate(subgrid_id_static)
level_index_static[i] = findsorted(subgrid_ids, subgrid_id)
end
for (i, subgrid_id) in enumerate(subgrid_id_time)
level_index_time[i] = findsorted(subgrid_ids, subgrid_id)
end

return Subgrid(;
level,
subgrid_id_static,
basin_index_static,
level_index_static,
interpolations_static,
subgrid_id_time,
basin_index_time,
level_index_time,
interpolations_time,
current_interpolation_index,
)
end

function Allocation(db::DB, config::Config, graph::MetaGraph)::Allocation
Expand Down
13 changes: 13 additions & 0 deletions core/src/schema.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
@schema "ribasim.basin.profile" BasinProfile
@schema "ribasim.basin.state" BasinState
@schema "ribasim.basin.subgrid" BasinSubgrid
@schema "ribasim.basin.subgridtime" BasinSubgridTime
@schema "ribasim.basin.concentration" BasinConcentration
@schema "ribasim.basin.concentrationexternal" BasinConcentrationExternal
@schema "ribasim.basin.concentrationstate" BasinConcentrationState
Expand Down Expand Up @@ -58,9 +59,13 @@ function nodetype(
type_string = string(T)
elements = split(type_string, '.'; limit = 3)
last_element = last(elements)
# Special case last elements that need an underscore
if startswith(last_element, "concentration") && length(last_element) > 13
elements[end] = "concentration_$(last_element[14:end])"
end
if last_element == "subgridtime"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

elseif?

elements[end] = "subgrid_time"
end
if isnode(sv)
n = elements[2]
k = Symbol(elements[3])
Expand Down Expand Up @@ -150,6 +155,14 @@ end
subgrid_level::Float64
end

@version BasinSubgridTimeV1 begin
subgrid_id::Int32
node_id::Int32
time::DateTime
basin_level::Float64
subgrid_level::Float64
end

@version LevelBoundaryStaticV1 begin
node_id::Int32
active::Union{Missing, Bool}
Expand Down
4 changes: 2 additions & 2 deletions core/src/validation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -321,8 +321,8 @@ Validate the entries for a single subgrid element.
"""
function valid_subgrid(
subgrid_id::Int32,
node_id::NodeID,
node_to_basin::Dict{NodeID, Int},
node_id::Int32,
node_to_basin::Dict{Int32, Int},
basin_level::Vector{Float64},
subgrid_level::Vector{Float64},
)::Bool
Expand Down
11 changes: 7 additions & 4 deletions core/src/write.jl
Original file line number Diff line number Diff line change
Expand Up @@ -376,11 +376,14 @@ function subgrid_level_table(
(; t, saveval) = saved.subgrid_level
subgrid = integrator.p.subgrid

nelem = length(subgrid.subgrid_id)
nelem = length(subgrid.level)
ntsteps = length(t)

time = repeat(datetime_since.(t, config.starttime); inner = nelem)
subgrid_id = repeat(subgrid.subgrid_id; outer = ntsteps)
subgrid_id = repeat(
sort(vcat(subgrid.subgrid_id_static, subgrid.subgrid_id_time));
outer = ntsteps,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does time already account for the multiple possible timestates (i.e. duplicates already exist)?

)
subgrid_level = FlatVector(saveval)
return (; time, subgrid_id, subgrid_level)
end
Expand Down Expand Up @@ -412,9 +415,9 @@ function write_arrow(
mkpath(dirname(path))
try
Arrow.write(path, table; compress, metadata)
catch
catch e
@error "Failed to write results, file may be locked." path
error("Failed to write results.")
rethrow(e)
end
return nothing
end
Expand Down
Loading
Loading