Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Configurable interpolation types #1932

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -461,9 +461,9 @@ version = "1.7.0"

[[deps.DataInterpolations]]
deps = ["FindFirstFunctions", "ForwardDiff", "LinearAlgebra", "PrettyTables", "RecipesBase", "Reexport"]
git-tree-sha1 = "3d81cd1fcba530122a5d6c725aa53521d869816a"
git-tree-sha1 = "78d06458ec13b53b3b0016daebe53f832d42ff44"
uuid = "82cc6244-b520-54b8-b5a6-8a565e85f1d0"
version = "6.5.2"
version = "6.6.0"

[deps.DataInterpolations.extensions]
DataInterpolationsChainRulesCoreExt = "ChainRulesCore"
Expand Down
1 change: 1 addition & 0 deletions core/src/Ribasim.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ using PreallocationTools: LazyBufferCache
using DataInterpolations:
LinearInterpolation,
LinearInterpolationIntInv,
PCHIPInterpolation,
invert_integral,
derivative,
integral,
Expand Down
29 changes: 24 additions & 5 deletions core/src/config.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ Ribasim.config is a submodule mainly to avoid name clashes between the configura
module config

using Configurations: Configurations, @option, from_toml, @type_alias
using DataInterpolations: LinearInterpolation, PCHIPInterpolation, CubicHermiteSpline
using DataStructures: DefaultDict
using Dates: DateTime
using Logging: LogLevel, Debug, Info, Warn, Error
Expand All @@ -24,13 +25,14 @@ using OrdinaryDiffEqRosenbrock: Rosenbrock23, Rodas4P, Rodas5P
export Config, Solver, Results, Logging, Toml
export algorithm,
camel_case,
snake_case,
input_path,
convert_dt,
convert_saveat,
database_path,
input_path,
interpolation_method,
nodetypes,
results_path,
convert_saveat,
convert_dt,
nodetypes
snake_case

const schemas =
getfield.(
Expand Down Expand Up @@ -134,9 +136,14 @@ end
use_allocation::Bool = false
end

@option struct Interpolation <: TableOption
tabulated_rating_curve::String = "LinearInterpolation"
end

@option struct Experimental <: TableOption
concentration::Bool = false
end

# For logging enabled experimental features
function Base.iterate(exp::Experimental, state = 0)
state >= nfields(exp) && return
Expand All @@ -156,6 +163,7 @@ end
results_dir::String
allocation::Allocation = Allocation()
solver::Solver = Solver()
interpolation::Interpolation = Interpolation()
logging::Logging = Logging()
results::Results = Results()
experimental::Experimental = Experimental()
Expand Down Expand Up @@ -264,6 +272,17 @@ const algorithms = Dict{String, Type}(
"Euler" => Euler,
)

# PCHIPInterpolation is only a function, creates a CubicHermiteSpline
const interpolation_methods =
Dict{String, @NamedTuple{type::Type, constructor::Union{Function, Type}}}(
"LinearInterpolation" =>
(type = LinearInterpolation, constructor = LinearInterpolation),
"PCHIPInterpolation" =>
(type = CubicHermiteSpline, constructor = PCHIPInterpolation),
)

interpolation_method(method) = get(interpolation_methods, method, nothing)

"""
Check whether the given function has a method that accepts the given kwarg.
Note that it is possible that methods exist that accept :a and :b individually,
Expand Down
14 changes: 8 additions & 6 deletions core/src/parameter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -405,8 +405,10 @@ interpolation in between. Relation can be updated in time, which is done by movi
the `time` field into the `tables`, which is done in the `update_tabulated_rating_curve`
callback.

Type parameter C indicates the content backing the StructVector, which can be a NamedTuple
of Vectors or Arrow Primitives, and is added to avoid type instabilities.
Type parameters to avoid type instabilities:
- C indicates the content backing the StructVector, which can be a NamedTuple
of Vectors or Arrow Primitives;
- I indicates the configured type of the interpolation objects for the Q(h) relationships

node_id: node ID of the TabulatedRatingCurve node
inflow_edge: incoming flow edge metadata
Expand All @@ -419,13 +421,13 @@ table: The current Q(h) relationships
time: The time table used for updating the tables
control_mapping: dictionary from (node_id, control_state) to Q(h) and/or active state
"""
@kwdef struct TabulatedRatingCurve{C} <: AbstractParameterNode
@kwdef struct TabulatedRatingCurve{C, I} <: AbstractParameterNode
node_id::Vector{NodeID}
inflow_edge::Vector{EdgeMetadata}
outflow_edge::Vector{EdgeMetadata}
active::Vector{Bool}
max_downstream_level::Vector{Float64} = fill(Inf, length(node_id))
table::Vector{ScalarInterpolation}
table::Vector{I}
time::StructVector{TabulatedRatingCurveTimeV1, C, Int}
control_mapping::Dict{Tuple{NodeID, String}, ControlStateUpdate}
end
Expand Down Expand Up @@ -880,14 +882,14 @@ const ModelGraph = MetaGraph{
Float64,
}

@kwdef struct Parameters{C1, C2, C3, C4, C5, C6, C7, C8, C9, V}
@kwdef struct Parameters{C1, C2, C3, C4, C5, C6, C7, C8, C9, V, I_TRC}
starttime::DateTime
graph::ModelGraph
allocation::Allocation
basin::Basin{C1, C2, V}
linear_resistance::LinearResistance
manning_resistance::ManningResistance
tabulated_rating_curve::TabulatedRatingCurve{C3}
tabulated_rating_curve::TabulatedRatingCurve{C3, I_TRC}
level_boundary::LevelBoundary{C4}
flow_boundary::FlowBoundary{C5}
pump::Pump
Expand Down
17 changes: 12 additions & 5 deletions core/src/read.jl
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,13 @@ function TabulatedRatingCurve(
)
end

interpolations = ScalarInterpolation[]
interpolation_type = interpolation_method(config.interpolation.tabulated_rating_curve)
if isnothing(interpolation_type)
error(
"Unsupported interpolation type $(config.interpolation.tabulated_rating_curve) for tabulated_rating_curve.",
)
end
interpolations = interpolation_type.type[]
control_mapping = Dict{Tuple{NodeID, String}, ControlStateUpdate}()
active = Bool[]
max_downstream_level = Float64[]
Expand All @@ -312,7 +318,7 @@ function TabulatedRatingCurve(
node_id,
)
static_id = view(static, rows)
local is_active, interpolation
local is_active, interpolation, max_level
# coalesce control_state to nothing to avoid boolean groupby logic on missing
for group in
IterTools.groupby(row -> coalesce(row.control_state, nothing), static_id)
Expand All @@ -326,9 +332,10 @@ function TabulatedRatingCurve(
errors = true
end
interpolation = try
qh_interpolation(table, rowrange)
qh_interpolation(table, rowrange, interpolation_type)
catch
LinearInterpolation(Float64[], Float64[])
errors = true
interpolation_type(Float64[], Float64[])
end
if !ismissing(control_state)
control_mapping[(
Expand All @@ -355,7 +362,7 @@ function TabulatedRatingCurve(
if !valid_tabulated_rating_curve(node_id, pre_table, rowrange)
errors = true
end
interpolation = qh_interpolation(pre_table, rowrange)
interpolation = qh_interpolation(pre_table, rowrange, interpolation_type)
max_level = coalesce(pre_table.max_downstream_level[rowrange][begin], Inf)
push!(interpolations, interpolation)
push!(active, true)
Expand Down
9 changes: 5 additions & 4 deletions core/src/util.jl
Original file line number Diff line number Diff line change
Expand Up @@ -121,20 +121,21 @@ end

"""
From a table with columns node_id, flow_rate (Q) and level (h),
create a ScalarInterpolation from level to flow rate for a given node_id.
create an interpolation of given type from level to flow rate for a given node_id.
"""
function qh_interpolation(
table::StructVector,
rowrange::UnitRange{Int},
)::ScalarInterpolation
interpolation_type,
)::AbstractInterpolation
level = table.level[rowrange]
flow_rate = table.flow_rate[rowrange]

# Ensure that that Q stays 0 below the first level
pushfirst!(level, first(level) - 1)
pushfirst!(flow_rate, first(flow_rate))
pushfirst!(flow_rate, 0)

return LinearInterpolation(
return interpolation_type.constructor(
flow_rate,
level;
extrapolate = true,
Expand Down
5 changes: 5 additions & 0 deletions core/test/docs.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,11 @@ sparse = true # optional, default true
autodiff = false # optional, default false
evaporate_mass = true # optional, default true to simulate a correct mass balance

[interpolation]
# Defines which interpolation types are used for which data
# TODO: add supported interpolation types
tabulated_rating_curve = "LinearInterpolation"

[logging]
# defines the logging level of Ribasim
verbosity = "info" # optional, default "info", can otherwise be "debug", "warn" or "error"
Expand Down
8 changes: 8 additions & 0 deletions docs/reference/usage.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,14 @@ entry | type | description
----------------- | ------ | -----------
verbosity | String | Verbosity level: debug, info, warn, or error.

## Interpolation settings

The following can be set in the configuration in the `[interpolation]` section. The supported interpolation types are: ...

entry | type | description
---------------------- | ------ | -----------
tabulated_rating_curve | String | one of the supported interpolation types

## Experimental features

::: {.callout-important}
Expand Down
13 changes: 13 additions & 0 deletions python/ribasim/ribasim/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,19 @@ class Logging(ChildModel):
verbosity: Verbosity = Verbosity.info


class Interpolation(ChildModel):
"""
Defines the interpolation types used in the core

Attributes
----------
tabulated_rating_curve : str
The interpolation type used for Q(h) relationships
"""

tabulated_rating_curve: str = "LinearInterpolation"


class Experimental(ChildModel):
"""
Defines experimental features.
Expand Down
2 changes: 2 additions & 0 deletions python/ribasim/ribasim/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
Experimental,
FlowBoundary,
FlowDemand,
Interpolation,
LevelBoundary,
LevelDemand,
LinearResistance,
Expand Down Expand Up @@ -82,6 +83,7 @@ class Model(FileModel):
results_dir: Path = Field(default=Path("results"))

logging: Logging = Field(default_factory=Logging)
interpolation: Interpolation = Field(default_factory=Interpolation)
solver: Solver = Field(default_factory=Solver)
results: Results = Field(default_factory=Results)

Expand Down
Loading