Skip to content

Commit

Permalink
Move to DataInterpolations v5 (#1464)
Browse files Browse the repository at this point in the history
This was breaking for us since it removed a type alias that we defined
in our const alias ScalarInterpolation and VectorInterpolation.

This also tightens some return types to those concrete aliases, and
changes a LevelDemand field type from the abstract
`Vector{LinearInterpolation}` to the concrete
`Vector{ScalarInterpolation}`, which should help the compiler.

Closes #1461.
  • Loading branch information
visr authored May 14, 2024
1 parent 61b3a17 commit db1d1ea
Show file tree
Hide file tree
Showing 6 changed files with 23 additions and 27 deletions.
4 changes: 2 additions & 2 deletions Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -328,9 +328,9 @@ version = "1.6.1"

[[deps.DataInterpolations]]
deps = ["FindFirstFunctions", "ForwardDiff", "LinearAlgebra", "PrettyTables", "RecipesBase", "Reexport"]
git-tree-sha1 = "b24ab19ead284c563c9e494899d2464a7e95face"
git-tree-sha1 = "b580ef00ec248aeb137b4ef3a4f751a567d35556"
uuid = "82cc6244-b520-54b8-b5a6-8a565e85f1d0"
version = "4.8.0"
version = "5.0.0"

[deps.DataInterpolations.extensions]
DataInterpolationsChainRulesCoreExt = "ChainRulesCore"
Expand Down
18 changes: 9 additions & 9 deletions core/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,14 @@ Aqua = "0.8"
Arrow = "2.3"
BasicModelInterface = "0.1"
CSV = "0.10"
CodecZstd = "0.7,0.8"
ComponentArrays = "0.13,0.14,0.15"
CodecZstd = "0.7, 0.8"
ComponentArrays = "0.13, 0.14, 0.15"
Configurations = "0.17"
DBInterface = "2.4"
DataFrames = "1.4"
DataInterpolations = "4.4"
DataInterpolations = "5"
DataStructures = "0.18"
Dates = "<0.0.1,1"
Dates = "<0.0.1, 1"
Dictionaries = "0.3.25, 0.4"
DiffEqCallbacks = "3.6"
EnumX = "1.0"
Expand All @@ -66,22 +66,22 @@ IterTools = "1.4"
JuMP = "1.15"
Legolas = "0.5"
LinearSolve = "2.24"
Logging = "<0.0.1,1"
Logging = "<0.0.1, 1"
LoggingExtras = "1"
MetaGraphsNext = "0.6, 0.7"
OrdinaryDiffEq = "6.7"
PreallocationTools = "0.4"
ReTestItems = "1.20"
SQLite = "1.5.1"
SciMLBase = "2.36"
SparseArrays = "<0.0.1,1"
SparseArrays = "<0.0.1, 1"
StructArrays = "0.6.13"
TOML = "<0.0.1,1"
TOML = "<0.0.1, 1"
Tables = "1"
TerminalLoggers = "0.1.7"
Test = "<0.0.1,1"
Test = "<0.0.1, 1"
TimerOutputs = "0.5"
TranscodingStreams = "0.9,0.10"
TranscodingStreams = "0.9, 0.10"
julia = "1.10"

[extras]
Expand Down
9 changes: 4 additions & 5 deletions core/src/parameter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,9 @@ end

Base.to_index(id::NodeID) = Int(id.value)

const ScalarInterpolation =
LinearInterpolation{Vector{Float64}, Vector{Float64}, true, Float64}
const ScalarInterpolation = LinearInterpolation{Vector{Float64}, Vector{Float64}, Float64}
const VectorInterpolation =
LinearInterpolation{Vector{Vector{Float64}}, Vector{Float64}, true, Vector{Float64}}
LinearInterpolation{Vector{Vector{Float64}}, Vector{Float64}, Vector{Float64}}

"""
Store information for a subnetwork used for allocation.
Expand Down Expand Up @@ -617,8 +616,8 @@ priority: If in a shortage state, the priority of the demand of the connected ba
"""
struct LevelDemand <: AbstractParameterNode
node_id::Vector{NodeID}
min_level::Vector{LinearInterpolation}
max_level::Vector{LinearInterpolation}
min_level::Vector{ScalarInterpolation}
max_level::Vector{ScalarInterpolation}
priority::Vector{Int32}
end

Expand Down
8 changes: 4 additions & 4 deletions core/src/util.jl
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ function get_scalar_interpolation(
node_id::NodeID,
param::Symbol;
default_value::Float64 = 0.0,
)::Tuple{LinearInterpolation, Bool}
)::Tuple{ScalarInterpolation, Bool}
nodetype = node_id.type
rows = searchsorted(NodeID.(nodetype, time.node_id), node_id)
parameter = getfield.(time, param)[rows]
Expand Down Expand Up @@ -246,18 +246,18 @@ end
function qh_interpolation(
level::AbstractVector,
flow_rate::AbstractVector,
)::Tuple{LinearInterpolation, Bool}
)::Tuple{ScalarInterpolation, Bool}
return LinearInterpolation(flow_rate, level; extrapolate = true), allunique(level)
end

"""
From a table with columns node_id, flow_rate (Q) and level (h),
create a LinearInterpolation from level to flow rate for a given node_id.
create a ScalarInterpolation from level to flow rate for a given node_id.
"""
function qh_interpolation(
node_id::NodeID,
table::StructVector,
)::Tuple{LinearInterpolation, Bool}
)::Tuple{ScalarInterpolation, Bool}
nodetype = node_id.type
rowrange = findlastgroup(node_id, NodeID.(nodetype, table.node_id))
@assert !isempty(rowrange) "timeseries starts after model start time"
Expand Down
4 changes: 1 addition & 3 deletions core/src/validation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -384,9 +384,7 @@ end

function valid_demand(
node_id::Vector{NodeID},
demand_itp::Vector{
Vector{LinearInterpolation{Vector{Float64}, Vector{Float64}, true, Float64}},
},
demand_itp::Vector{Vector{ScalarInterpolation}},
priorities::Vector{Int32},
)::Bool
errors = false
Expand Down
7 changes: 3 additions & 4 deletions core/test/validation_test.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
@testitem "Basin profile validation" begin
using Dictionaries: Indices
using Ribasim: NodeID, valid_profiles, qh_interpolation
using DataInterpolations: LinearInterpolation
using Ribasim: NodeID, valid_profiles, qh_interpolation, ScalarInterpolation
using Logging

node_id = Indices([NodeID(:Basin, 1)])
Expand All @@ -27,10 +26,10 @@

itp, valid = qh_interpolation([0.0, 0.0], [1.0, 2.0])
@test !valid
@test itp isa LinearInterpolation
@test itp isa ScalarInterpolation
itp, valid = qh_interpolation([0.0, 0.1], [1.0, 2.0])
@test valid
@test itp isa LinearInterpolation
@test itp isa ScalarInterpolation
end

@testitem "Q(h) validation" begin
Expand Down

0 comments on commit db1d1ea

Please sign in to comment.