diff --git a/core/src/solve.jl b/core/src/solve.jl index 5bfa79d10..7105dccb7 100644 --- a/core/src/solve.jl +++ b/core/src/solve.jl @@ -346,8 +346,7 @@ function formulate!(manning_resistance::ManningResistance, p::Parameters)::Nothi h_a = get_level(p, basin_a_id) h_b = get_level(p, basin_b_id) - bottom_a = basin_bottom(basin, basin_a_id) - bottom_b = basin_bottom(basin, basin_b_id) + bottom_a, bottom_b = basin_bottoms(basin, basin_a_id, basin_b_id, id) slope = profile_slope[i] width = profile_width[i] n = manning_n[i] diff --git a/core/src/utils.jl b/core/src/utils.jl index 9dc9f149e..255311c33 100644 --- a/core/src/utils.jl +++ b/core/src/utils.jl @@ -214,13 +214,33 @@ function basin_bottom_index(basin::Basin, i::Int)::Float64 return first(itp.u) end -"Return the bottom elevation of the basin with index i" -function basin_bottom(basin::Basin, node_id::Int)::Float64 +"Return the bottom elevation of the basin with index i, or nothing if it doesn't exist" +function basin_bottom(basin::Basin, node_id::Int)::Union{Float64, Nothing} basin = Dictionary(basin.node_id, basin.level) hasindex, token = gettoken(basin, node_id) - @assert hasindex "node_id $node_id not a Basin" - # get level(storage) interpolation function - itp = gettokenvalue(basin, token) - # and return the first level in the underlying table, which represents the bottom - return first(itp.u) + return if hasindex + # get level(storage) interpolation function + itp = gettokenvalue(basin, token) + # and return the first level in the underlying table, which represents the bottom + first(itp.u) + else + nothing + end +end + +"Get the bottom on both ends of a node. If only one has a bottom, use that for both." +function basin_bottoms( + basin::Basin, + basin_a_id::Int, + basin_b_id::Int, + id::Int, +)::Tuple{Float64, Float64} + bottom_a = basin_bottom(basin, basin_a_id) + bottom_b = basin_bottom(basin, basin_b_id) + if isnothing(bottom_a) && isnothing(bottom_b) + error(lazy"No bottom defined on either side of $id") + end + bottom_a = something(bottom_a, bottom_b) + bottom_b = something(bottom_b, bottom_a) + return bottom_a, bottom_b end diff --git a/core/test/basin.jl b/core/test/basin.jl index 254be16be..ca2494d11 100644 --- a/core/test/basin.jl +++ b/core/test/basin.jl @@ -3,7 +3,6 @@ using Ribasim import BasicModelInterface as BMI using SciMLBase - @testset "trivial model" begin toml_path = normpath(@__DIR__, "../../data/trivial/trivial.toml") @test ispath(toml_path) diff --git a/core/test/utils.jl b/core/test/utils.jl index f3bf5fdec..a2824f5b8 100644 --- a/core/test/utils.jl +++ b/core/test/utils.jl @@ -1,6 +1,8 @@ using Ribasim using Dictionaries: Indices using Test +using DataInterpolations: LinearInterpolation +using StructArrays: StructVector @testset "id_index" begin ids = Indices([2, 4, 6]) @@ -14,3 +16,37 @@ end @test Ribasim.profile_storage([6.0, 7.0, 9.0], [0.0, 1000.0, 1000.0]) == [0.0, 500.0, 2500.0] end + +@testset "bottom" begin + basin = Ribasim.Basin( + Indices([5, 7]), + [2.0, 3.0], + [2.0, 3.0], + [2.0, 3.0], + [2.0, 3.0], + [2.0, 3.0], + [ # area + LinearInterpolation([1.0, 1.0], [0.0, 1.0]), + LinearInterpolation([1.0, 1.0], [0.0, 1.0]), + ], + [ # level + LinearInterpolation([0.0, 1.0], [0.0, 1.0]), + LinearInterpolation([4.0, 3.0], [0.0, 1.0]), + ], + StructVector{Ribasim.BasinForcingV1}(undef, 0), + ) + + @test Ribasim.basin_bottom_index(basin, 2) === 4.0 + @test Ribasim.basin_bottom(basin, 5) === 0.0 + @test Ribasim.basin_bottom(basin, 7) === 4.0 + @test Ribasim.basin_bottom(basin, 6) === nothing + @test Ribasim.basin_bottoms(basin, 5, 7, 6) === (0.0, 4.0) + @test Ribasim.basin_bottoms(basin, 5, 0, 6) === (0.0, 0.0) + @test Ribasim.basin_bottoms(basin, 0, 7, 6) === (4.0, 4.0) + @test_throws "No bottom defined on either side of 6" Ribasim.basin_bottoms( + basin, + 0, + 1, + 6, + ) +end