From 1c7715e4d9a16df880532055d42786f0d526d534 Mon Sep 17 00:00:00 2001 From: Bart de Koning Date: Wed, 13 Nov 2024 14:25:01 +0100 Subject: [PATCH] Configurable interpolation setup --- Manifest.toml | 4 ++-- core/src/Ribasim.jl | 1 + core/src/config.jl | 29 ++++++++++++++++++++++++----- core/src/parameter.jl | 14 ++++++++------ core/src/read.jl | 17 ++++++++++++----- core/src/util.jl | 9 +++++---- core/test/docs.toml | 5 +++++ docs/reference/usage.qmd | 8 ++++++++ python/ribasim/ribasim/config.py | 13 +++++++++++++ python/ribasim/ribasim/model.py | 2 ++ 10 files changed, 80 insertions(+), 22 deletions(-) diff --git a/Manifest.toml b/Manifest.toml index b647634fa..33d45fcf3 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -461,9 +461,9 @@ version = "1.7.0" [[deps.DataInterpolations]] deps = ["FindFirstFunctions", "ForwardDiff", "LinearAlgebra", "PrettyTables", "RecipesBase", "Reexport"] -git-tree-sha1 = "3d81cd1fcba530122a5d6c725aa53521d869816a" +git-tree-sha1 = "78d06458ec13b53b3b0016daebe53f832d42ff44" uuid = "82cc6244-b520-54b8-b5a6-8a565e85f1d0" -version = "6.5.2" +version = "6.6.0" [deps.DataInterpolations.extensions] DataInterpolationsChainRulesCoreExt = "ChainRulesCore" diff --git a/core/src/Ribasim.jl b/core/src/Ribasim.jl index b3656f4c1..b6b7cf617 100644 --- a/core/src/Ribasim.jl +++ b/core/src/Ribasim.jl @@ -67,6 +67,7 @@ using PreallocationTools: LazyBufferCache using DataInterpolations: LinearInterpolation, LinearInterpolationIntInv, + PCHIPInterpolation, invert_integral, derivative, integral, diff --git a/core/src/config.jl b/core/src/config.jl index 1fc0012e2..767cec01c 100644 --- a/core/src/config.jl +++ b/core/src/config.jl @@ -9,6 +9,7 @@ Ribasim.config is a submodule mainly to avoid name clashes between the configura module config using Configurations: Configurations, @option, from_toml, @type_alias +using DataInterpolations: LinearInterpolation, PCHIPInterpolation, CubicHermiteSpline using DataStructures: DefaultDict using Dates: DateTime using Logging: LogLevel, Debug, Info, Warn, Error @@ -24,13 +25,14 @@ using OrdinaryDiffEqRosenbrock: Rosenbrock23, Rodas4P, Rodas5P export Config, Solver, Results, Logging, Toml export algorithm, camel_case, - snake_case, - input_path, + convert_dt, + convert_saveat, database_path, + input_path, + interpolation_method, + nodetypes, results_path, - convert_saveat, - convert_dt, - nodetypes + snake_case const schemas = getfield.( @@ -134,9 +136,14 @@ end use_allocation::Bool = false end +@option struct Interpolation <: TableOption + tabulated_rating_curve::String = "LinearInterpolation" +end + @option struct Experimental <: TableOption concentration::Bool = false end + # For logging enabled experimental features function Base.iterate(exp::Experimental, state = 0) state >= nfields(exp) && return @@ -156,6 +163,7 @@ end results_dir::String allocation::Allocation = Allocation() solver::Solver = Solver() + interpolation::Interpolation = Interpolation() logging::Logging = Logging() results::Results = Results() experimental::Experimental = Experimental() @@ -264,6 +272,17 @@ const algorithms = Dict{String, Type}( "Euler" => Euler, ) +# PCHIPInterpolation is only a function, creates a CubicHermiteSpline +const interpolation_methods = + Dict{String, @NamedTuple{type::Type, constructor::Union{Function, Type}}}( + "LinearInterpolation" => + (type = LinearInterpolation, constructor = LinearInterpolation), + "PCHIPInterpolation" => + (type = CubicHermiteSpline, constructor = PCHIPInterpolation), + ) + +interpolation_method(method) = get(interpolation_methods, method, nothing) + """ Check whether the given function has a method that accepts the given kwarg. Note that it is possible that methods exist that accept :a and :b individually, diff --git a/core/src/parameter.jl b/core/src/parameter.jl index 97bc1cab3..b0f44f489 100644 --- a/core/src/parameter.jl +++ b/core/src/parameter.jl @@ -405,8 +405,10 @@ interpolation in between. Relation can be updated in time, which is done by movi the `time` field into the `tables`, which is done in the `update_tabulated_rating_curve` callback. -Type parameter C indicates the content backing the StructVector, which can be a NamedTuple -of Vectors or Arrow Primitives, and is added to avoid type instabilities. +Type parameters to avoid type instabilities: +- C indicates the content backing the StructVector, which can be a NamedTuple +of Vectors or Arrow Primitives; +- I indicates the configured type of the interpolation objects for the Q(h) relationships node_id: node ID of the TabulatedRatingCurve node inflow_edge: incoming flow edge metadata @@ -419,13 +421,13 @@ table: The current Q(h) relationships time: The time table used for updating the tables control_mapping: dictionary from (node_id, control_state) to Q(h) and/or active state """ -@kwdef struct TabulatedRatingCurve{C} <: AbstractParameterNode +@kwdef struct TabulatedRatingCurve{C, I} <: AbstractParameterNode node_id::Vector{NodeID} inflow_edge::Vector{EdgeMetadata} outflow_edge::Vector{EdgeMetadata} active::Vector{Bool} max_downstream_level::Vector{Float64} = fill(Inf, length(node_id)) - table::Vector{ScalarInterpolation} + table::Vector{I} time::StructVector{TabulatedRatingCurveTimeV1, C, Int} control_mapping::Dict{Tuple{NodeID, String}, ControlStateUpdate} end @@ -880,14 +882,14 @@ const ModelGraph = MetaGraph{ Float64, } -@kwdef struct Parameters{C1, C2, C3, C4, C5, C6, C7, C8, C9, V} +@kwdef struct Parameters{C1, C2, C3, C4, C5, C6, C7, C8, C9, V, I_TRC} starttime::DateTime graph::ModelGraph allocation::Allocation basin::Basin{C1, C2, V} linear_resistance::LinearResistance manning_resistance::ManningResistance - tabulated_rating_curve::TabulatedRatingCurve{C3} + tabulated_rating_curve::TabulatedRatingCurve{C3, I_TRC} level_boundary::LevelBoundary{C4} flow_boundary::FlowBoundary{C5} pump::Pump diff --git a/core/src/read.jl b/core/src/read.jl index 749554c3b..bb8695046 100644 --- a/core/src/read.jl +++ b/core/src/read.jl @@ -295,7 +295,13 @@ function TabulatedRatingCurve( ) end - interpolations = ScalarInterpolation[] + interpolation_type = interpolation_method(config.interpolation.tabulated_rating_curve) + if isnothing(interpolation_type) + error( + "Unsupported interpolation type $(config.interpolation.tabulated_rating_curve) for tabulated_rating_curve.", + ) + end + interpolations = interpolation_type.type[] control_mapping = Dict{Tuple{NodeID, String}, ControlStateUpdate}() active = Bool[] max_downstream_level = Float64[] @@ -312,7 +318,7 @@ function TabulatedRatingCurve( node_id, ) static_id = view(static, rows) - local is_active, interpolation + local is_active, interpolation, max_level # coalesce control_state to nothing to avoid boolean groupby logic on missing for group in IterTools.groupby(row -> coalesce(row.control_state, nothing), static_id) @@ -326,9 +332,10 @@ function TabulatedRatingCurve( errors = true end interpolation = try - qh_interpolation(table, rowrange) + qh_interpolation(table, rowrange, interpolation_type) catch - LinearInterpolation(Float64[], Float64[]) + errors = true + interpolation_type(Float64[], Float64[]) end if !ismissing(control_state) control_mapping[( @@ -355,7 +362,7 @@ function TabulatedRatingCurve( if !valid_tabulated_rating_curve(node_id, pre_table, rowrange) errors = true end - interpolation = qh_interpolation(pre_table, rowrange) + interpolation = qh_interpolation(pre_table, rowrange, interpolation_type) max_level = coalesce(pre_table.max_downstream_level[rowrange][begin], Inf) push!(interpolations, interpolation) push!(active, true) diff --git a/core/src/util.jl b/core/src/util.jl index d12498c9b..01247977c 100644 --- a/core/src/util.jl +++ b/core/src/util.jl @@ -121,20 +121,21 @@ end """ From a table with columns node_id, flow_rate (Q) and level (h), -create a ScalarInterpolation from level to flow rate for a given node_id. +create an interpolation of given type from level to flow rate for a given node_id. """ function qh_interpolation( table::StructVector, rowrange::UnitRange{Int}, -)::ScalarInterpolation + interpolation_type, +)::AbstractInterpolation level = table.level[rowrange] flow_rate = table.flow_rate[rowrange] # Ensure that that Q stays 0 below the first level pushfirst!(level, first(level) - 1) - pushfirst!(flow_rate, first(flow_rate)) + pushfirst!(flow_rate, 0) - return LinearInterpolation( + return interpolation_type.constructor( flow_rate, level; extrapolate = true, diff --git a/core/test/docs.toml b/core/test/docs.toml index ac155cebf..880fd4e9b 100644 --- a/core/test/docs.toml +++ b/core/test/docs.toml @@ -40,6 +40,11 @@ sparse = true # optional, default true autodiff = false # optional, default false evaporate_mass = true # optional, default true to simulate a correct mass balance +[interpolation] +# Defines which interpolation types are used for which data +# TODO: add supported interpolation types +tabulated_rating_curve = "LinearInterpolation" + [logging] # defines the logging level of Ribasim verbosity = "info" # optional, default "info", can otherwise be "debug", "warn" or "error" diff --git a/docs/reference/usage.qmd b/docs/reference/usage.qmd index 63eeedd99..565608c62 100644 --- a/docs/reference/usage.qmd +++ b/docs/reference/usage.qmd @@ -80,6 +80,14 @@ entry | type | description ----------------- | ------ | ----------- verbosity | String | Verbosity level: debug, info, warn, or error. +## Interpolation settings + +The following can be set in the configuration in the `[interpolation]` section. The supported interpolation types are: ... + +entry | type | description +---------------------- | ------ | ----------- +tabulated_rating_curve | String | one of the supported interpolation types + ## Experimental features ::: {.callout-important} diff --git a/python/ribasim/ribasim/config.py b/python/ribasim/ribasim/config.py index 78a59e85e..9d731dfd0 100644 --- a/python/ribasim/ribasim/config.py +++ b/python/ribasim/ribasim/config.py @@ -144,6 +144,19 @@ class Logging(ChildModel): verbosity: Verbosity = Verbosity.info +class Interpolation(ChildModel): + """ + Defines the interpolation types used in the core + + Attributes + ---------- + tabulated_rating_curve : str + The interpolation type used for Q(h) relationships + """ + + tabulated_rating_curve: str = "LinearInterpolation" + + class Experimental(ChildModel): """ Defines experimental features. diff --git a/python/ribasim/ribasim/model.py b/python/ribasim/ribasim/model.py index d7274a67c..a44fa313c 100644 --- a/python/ribasim/ribasim/model.py +++ b/python/ribasim/ribasim/model.py @@ -29,6 +29,7 @@ Experimental, FlowBoundary, FlowDemand, + Interpolation, LevelBoundary, LevelDemand, LinearResistance, @@ -82,6 +83,7 @@ class Model(FileModel): results_dir: Path = Field(default=Path("results")) logging: Logging = Field(default_factory=Logging) + interpolation: Interpolation = Field(default_factory=Interpolation) solver: Solver = Field(default_factory=Solver) results: Results = Field(default_factory=Results)