Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Basin / subgrid_time table #1975

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions core/src/Ribasim.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ using PreallocationTools: LazyBufferCache
# basin profiles and TabulatedRatingCurve. See also the node
# references in the docs.
using DataInterpolations:
ConstantInterpolation,
LinearInterpolation,
LinearInterpolationIntInv,
invert_integral,
Expand Down
22 changes: 19 additions & 3 deletions core/src/callback.jl
Original file line number Diff line number Diff line change
Expand Up @@ -671,12 +671,28 @@ function apply_parameter_update!(parameter_update)::Nothing
end

function update_subgrid_level!(integrator)::Nothing
(; p) = integrator
(; p, t) = integrator
du = get_du(integrator)
basin_level = p.basin.current_properties.current_level[parent(du)]
subgrid = integrator.p.subgrid
for (i, (index, interp)) in enumerate(zip(subgrid.basin_index, subgrid.interpolations))
subgrid.level[i] = interp(basin_level[index])

# First update the all the subgrids with static h(h) relations
for (level_index, basin_index, hh_itp) in zip(
subgrid.level_index_static,
subgrid.basin_index_static,
subgrid.interpolations_static,
)
subgrid.level[level_index] = hh_itp(basin_level[basin_index])
end
# Then update the subgrids with dynamic h(h) relations
for (level_index, basin_index, lookup) in zip(
subgrid.level_index_time,
subgrid.basin_index_time,
subgrid.current_interpolation_index,
)
itp_index = lookup(t)
hh_itp = subgrid.interpolations_time[itp_index]
subgrid.level[level_index] = hh_itp(basin_level[basin_index])
end
end

Expand Down
27 changes: 24 additions & 3 deletions core/src/parameter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ end

Base.to_index(id::NodeID) = Int(id.value)

"LinearInterpolation from a Float64 to a Float64"
const ScalarInterpolation = LinearInterpolation{
Vector{Float64},
Vector{Float64},
Expand All @@ -114,6 +115,10 @@ const ScalarInterpolation = LinearInterpolation{
(1,),
}

"ConstantInterpolation from a Float64 to an Int, used to look up indices"
visr marked this conversation as resolved.
Show resolved Hide resolved
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
"ConstantInterpolation from a Float64 to an Int, used to look up indices"
"ConstantInterpolation from a Float64 to an Int, used to look up indices over time"

const IndexLookup =
ConstantInterpolation{Vector{Int64}, Vector{Float64}, Vector{Float64}, Int64, (1,)}

set_zero!(v) = v .= zero(eltype(v))
const Cache = LazyBufferCache{Returns{Int}, typeof(set_zero!)}

Expand Down Expand Up @@ -876,10 +881,26 @@ end

"Subgrid linearly interpolates basin levels."
@kwdef struct Subgrid
subgrid_id::Vector{Int32}
basin_index::Vector{Int32}
interpolations::Vector{ScalarInterpolation}
# level of each subgrid (static and dynamic) ordered by subgrid_id
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# level of each subgrid (static and dynamic) ordered by subgrid_id
# current level of each subgrid (static and dynamic) ordered by subgrid_id

level::Vector{Float64}
# static
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# static
# Static part
# Static subgrid ids

subgrid_id_static::Vector{Int32}
# index into the basin.current_level vector for each static subgrid_id
basin_index_static::Vector{Int}
# index into the subgrid.level vector for each static subgrid_id
level_index_static::Vector{Int}
# per subgrid one relation
interpolations_static::Vector{ScalarInterpolation}
# dynamic
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# dynamic
# Dynamic part
# Dynamic subgrid ids

subgrid_id_time::Vector{Int32}
# index into the basin.current_level vector for each dynamic subgrid_id
basin_index_time::Vector{Int}
# index into the subgrid.level vector for each dynamic subgrid_id
level_index_time::Vector{Int}
# per subgrid n relations, n being the number of timesteps for that subgrid
interpolations_time::Vector{ScalarInterpolation}
# per subgrid 1 lookup from t to an index in interpolations_time
current_interpolation_index::Vector{IndexLookup}
end

"""
Expand Down
146 changes: 131 additions & 15 deletions core/src/read.jl
Original file line number Diff line number Diff line change
Expand Up @@ -187,11 +187,18 @@
return out, !errors
end

"""
Validate the split of node IDs between static and time tables.
visr marked this conversation as resolved.
Show resolved Hide resolved
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
Validate the split of node IDs between static and time tables.
Retrieve and validate the split of node IDs between static and time tables.


For node types that can have a part of the parameters defined statically and a part dynamically,
this checks if each ID is defined exactly once in either table.
"""
function static_and_time_node_ids(
db::DB,
static::StructVector,
time::StructVector,
node_type::String,
node_type::String;
is_complete::Bool = true,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is undocumented (and unclear to me). I get that there's a state before/after, but not what the difference(s) are.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mentioned it at the callsite, but also added docs here.

)::Tuple{Set{NodeID}, Set{NodeID}, Vector{NodeID}, Bool}
ids = get_ids(db, node_type)
idx = searchsortedfirst.(Ref(ids), static.node_id)
Expand All @@ -205,7 +212,7 @@
errors = true
@error "$node_type cannot be in both static and time tables, found these node IDs in both: $doubles."
end
if !issetequal(node_ids, union(static_node_ids, time_node_ids))
if is_complete && !issetequal(node_ids, union(static_node_ids, time_node_ids))
errors = true
@error "$node_type node IDs don't match."
end
Expand Down Expand Up @@ -1226,17 +1233,42 @@
)
end

function push_lookup!(
current_interpolation_index::Vector{IndexLookup},
lookup_index::Vector{Int},
lookup_time::Vector{Float64},
)
index_lookup = ConstantInterpolation(
lookup_index,
lookup_time;
extrapolate = true,
cache_parameters = true,
)
push!(current_interpolation_index, index_lookup)
end

function Subgrid(db::DB, config::Config, basin::Basin)::Subgrid
node_to_basin = Dict(node_id => index for (index, node_id) in enumerate(basin.node_id))
tables = load_structvector(db, config, BasinSubgridV1)
time = load_structvector(db, config, BasinSubgridTimeV1)
static = load_structvector(db, config, BasinSubgridV1)

subgrid_ids = Int32[]
basin_index = Int32[]
interpolations = ScalarInterpolation[]
# Since not all Basins need to have subgrids, don't enforce completeness.
_, _, _, valid =
static_and_time_node_ids(db, static, time, "Basin"; is_complete = false)
if !valid
error("Problems encountered when parsing Subgrid static and time node IDs.")

Check warning on line 1258 in core/src/read.jl

View check run for this annotation

Codecov / codecov/patch

core/src/read.jl#L1258

Added line #L1258 was not covered by tests
end

node_to_basin = Dict{Int32, Int}(
Int32(node_id) => index for (index, node_id) in enumerate(basin.node_id)
)
subgrid_id_static = Int32[]
basin_index_static = Int[]
interpolations_static = ScalarInterpolation[]
has_error = false
for group in IterTools.groupby(row -> row.subgrid_id, tables)

for group in IterTools.groupby(row -> row.subgrid_id, static)
subgrid_id = first(getproperty.(group, :subgrid_id))
node_id = NodeID(NodeType.Basin, first(getproperty.(group, :node_id)), db)
node_id = first(getproperty.(group, :node_id))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we away with NodeID here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It simplified the code, otherwise we needed to strip out the integers later on.

basin_level = getproperty.(group, :basin_level)
subgrid_level = getproperty.(group, :subgrid_level)

Expand All @@ -1247,24 +1279,108 @@
# Ensure it doesn't extrapolate before the first value.
pushfirst!(subgrid_level, first(subgrid_level))
pushfirst!(basin_level, nextfloat(-Inf))
new_interp = LinearInterpolation(
hh_itp = LinearInterpolation(
subgrid_level,
basin_level;
extrapolate = true,
cache_parameters = true,
)
push!(subgrid_ids, subgrid_id)
push!(basin_index, node_to_basin[node_id])
push!(interpolations, new_interp)
push!(subgrid_id_static, subgrid_id)
push!(basin_index_static, node_to_basin[node_id])
push!(interpolations_static, hh_itp)
else
has_error = true
end
end

has_error && error("Invalid Basin / subgrid table.")
level = fill(NaN, length(subgrid_ids))
subgrid_id_time = Int32[]
basin_index_time = Int[]
interpolations_time = ScalarInterpolation[]
current_interpolation_index = IndexLookup[]

# Push the first subgrid_id and basin_index
if length(time) > 0
push!(subgrid_id_time, first(time.subgrid_id))
push!(basin_index_time, node_to_basin[first(time.node_id)])
end

# Initialize index_lookup contents
lookup_time = Float64[]
lookup_index = Int[]

interpolation_index = 0
for group in IterTools.groupby(row -> (row.subgrid_id, row.time), time)
interpolation_index += 1
subgrid_id = first(getproperty.(group, :subgrid_id))
time_group = seconds_since(first(getproperty.(group, :time)), config.starttime)
node_id = first(getproperty.(group, :node_id))
basin_level = getproperty.(group, :basin_level)
subgrid_level = getproperty.(group, :subgrid_level)

is_valid =
valid_subgrid(subgrid_id, node_id, node_to_basin, basin_level, subgrid_level)

return Subgrid(; subgrid_id = subgrid_ids, basin_index, interpolations, level)
if is_valid
visr marked this conversation as resolved.
Show resolved Hide resolved
# Ensure it doesn't extrapolate before the first value.
pushfirst!(subgrid_level, first(subgrid_level))
pushfirst!(basin_level, nextfloat(-Inf))
hh_itp = LinearInterpolation(
subgrid_level,
basin_level;
extrapolate = true,
cache_parameters = true,
)
# # These should only be pushed when the subgrid_id has changed
if subgrid_id_time[end] != subgrid_id
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmmm, not sure if I get it, but if we group by subgrid_id, this will always be true?

# Push the completed index_lookup of the previous subgrid_id
push_lookup!(current_interpolation_index, lookup_index, lookup_time)

Check warning on line 1337 in core/src/read.jl

View check run for this annotation

Codecov / codecov/patch

core/src/read.jl#L1337

Added line #L1337 was not covered by tests
# Push the new subgrid_id and basin_index
push!(subgrid_id_time, subgrid_id)
push!(basin_index_time, node_to_basin[node_id])

Check warning on line 1340 in core/src/read.jl

View check run for this annotation

Codecov / codecov/patch

core/src/read.jl#L1339-L1340

Added lines #L1339 - L1340 were not covered by tests
# Start new index_lookup contents
lookup_time = Float64[]
lookup_index = Int[]

Check warning on line 1343 in core/src/read.jl

View check run for this annotation

Codecov / codecov/patch

core/src/read.jl#L1342-L1343

Added lines #L1342 - L1343 were not covered by tests
end
push!(lookup_index, interpolation_index)
push!(lookup_time, time_group)
push!(interpolations_time, hh_itp)
else
has_error = true

Check warning on line 1349 in core/src/read.jl

View check run for this annotation

Codecov / codecov/patch

core/src/read.jl#L1349

Added line #L1349 was not covered by tests
end
end

# Push completed IndexLookup of the last group
if interpolation_index > 0
push_lookup!(current_interpolation_index, lookup_index, lookup_time)
end

has_error && error("Invalid Basin / subgrid_time table.")
level = fill(NaN, length(subgrid_id_static) + length(subgrid_id_time))

# Find the level indices
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This static part can be moved up?

level_index_static = zeros(Int, length(subgrid_id_static))
level_index_time = zeros(Int, length(subgrid_id_time))
subgrid_ids = sort(vcat(subgrid_id_static, subgrid_id_time))
for (i, subgrid_id) in enumerate(subgrid_id_static)
level_index_static[i] = findsorted(subgrid_ids, subgrid_id)
end
for (i, subgrid_id) in enumerate(subgrid_id_time)
level_index_time[i] = findsorted(subgrid_ids, subgrid_id)
end

return Subgrid(;
level,
subgrid_id_static,
basin_index_static,
level_index_static,
interpolations_static,
subgrid_id_time,
basin_index_time,
level_index_time,
interpolations_time,
current_interpolation_index,
)
end

function Allocation(db::DB, config::Config, graph::MetaGraph)::Allocation
Expand Down
13 changes: 13 additions & 0 deletions core/src/schema.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
@schema "ribasim.basin.profile" BasinProfile
@schema "ribasim.basin.state" BasinState
@schema "ribasim.basin.subgrid" BasinSubgrid
@schema "ribasim.basin.subgridtime" BasinSubgridTime
@schema "ribasim.basin.concentration" BasinConcentration
@schema "ribasim.basin.concentrationexternal" BasinConcentrationExternal
@schema "ribasim.basin.concentrationstate" BasinConcentrationState
Expand Down Expand Up @@ -58,9 +59,13 @@ function nodetype(
type_string = string(T)
elements = split(type_string, '.'; limit = 3)
last_element = last(elements)
# Special case last elements that need an underscore
if startswith(last_element, "concentration") && length(last_element) > 13
elements[end] = "concentration_$(last_element[14:end])"
end
if last_element == "subgridtime"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

elseif?

elements[end] = "subgrid_time"
end
if isnode(sv)
n = elements[2]
k = Symbol(elements[3])
Expand Down Expand Up @@ -150,6 +155,14 @@ end
subgrid_level::Float64
end

@version BasinSubgridTimeV1 begin
subgrid_id::Int32
node_id::Int32
time::DateTime
basin_level::Float64
subgrid_level::Float64
end

@version LevelBoundaryStaticV1 begin
node_id::Int32
active::Union{Missing, Bool}
Expand Down
4 changes: 2 additions & 2 deletions core/src/validation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -321,8 +321,8 @@ Validate the entries for a single subgrid element.
"""
function valid_subgrid(
subgrid_id::Int32,
node_id::NodeID,
node_to_basin::Dict{NodeID, Int},
node_id::Int32,
node_to_basin::Dict{Int32, Int},
basin_level::Vector{Float64},
subgrid_level::Vector{Float64},
)::Bool
Expand Down
11 changes: 7 additions & 4 deletions core/src/write.jl
Original file line number Diff line number Diff line change
Expand Up @@ -376,11 +376,14 @@
(; t, saveval) = saved.subgrid_level
subgrid = integrator.p.subgrid

nelem = length(subgrid.subgrid_id)
nelem = length(subgrid.level)
ntsteps = length(t)

time = repeat(datetime_since.(t, config.starttime); inner = nelem)
subgrid_id = repeat(subgrid.subgrid_id; outer = ntsteps)
subgrid_id = repeat(
sort(vcat(subgrid.subgrid_id_static, subgrid.subgrid_id_time));
outer = ntsteps,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does time already account for the multiple possible timestates (i.e. duplicates already exist)?

)
subgrid_level = FlatVector(saveval)
return (; time, subgrid_id, subgrid_level)
end
Expand Down Expand Up @@ -412,9 +415,9 @@
mkpath(dirname(path))
try
Arrow.write(path, table; compress, metadata)
catch
catch e
@error "Failed to write results, file may be locked." path
error("Failed to write results.")
rethrow(e)

Check warning on line 420 in core/src/write.jl

View check run for this annotation

Codecov / codecov/patch

core/src/write.jl#L420

Added line #L420 was not covered by tests
end
return nothing
end
Expand Down
26 changes: 26 additions & 0 deletions core/test/run_models_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -583,3 +583,29 @@ end
@test all(isapprox.(Δinf[1:2:end], 25.0; atol = 1e-10))
@test all(Δinf[2:2:end] .== 0.0)
end

@testitem "two_basin" begin
using DataFrames: DataFrame, nrow
using Dates: DateTime
import BasicModelInterface as BMI

toml_path = normpath(@__DIR__, "../../generated_testmodels/two_basin/ribasim.toml")
model = Ribasim.run(toml_path)
df = DataFrame(Ribasim.subgrid_level_table(model))

ntime = 367
@test nrow(df) == ntime * 2
@test df.subgrid_id == repeat(1:2; outer = ntime)
@test extrema(df.time) == (DateTime(2020), DateTime(2021))
@test all(df.subgrid_level[1:2] .== 0.01)

# After a month the h(h) of subgrid_id 2 increases by a meter
i_change = searchsortedfirst(df.time, DateTime(2020, 2))
@test df.subgrid_level[i_change + 1] - df.subgrid_level[i_change - 1] ≈ 1.0f0

# Besides the 1 meter shift the h(h) relations are 1:1
basin_level = copy(BMI.get_value_ptr(model, "basin.level"))
basin_level[2] += 1
@test basin_level ≈ df.subgrid_level[(end - 1):end]
@test basin_level ≈ model.integrator.p.subgrid.level
end
Loading
Loading