Skip to content

Commit

Permalink
Symbolic Jacobian sparsity (#1606)
Browse files Browse the repository at this point in the history
This doesn't work yet, but I quickly tried this out yesterday and thought I
might as well put it up so we can decide whether to pursue this approach.
I think the main advantage is that we don't have to maintain our own
code for determining the Jacobian sparsity, code which can easily get out of
sync with the model equations. It does add some dependencies, however.

The error I get is with DataInterpolations not supporting
`::SymbolicUtils.BasicSymbolic{Real}`, which looks like this closed
issue: SciML/DataInterpolations.jl#168.
DataInterpolations has an [extension on
Symbolics](https://github.com/SciML/DataInterpolations.jl/blob/master/ext/DataInterpolationsSymbolicsExt.jl)
for this purpose, so perhaps we just need to create a minimal example
and file an issue.

```
ERROR: MethodError: no method matching (::DataInterpolations.LinearInterpolation{Vector{…}, Vector{…}, Float64})(::SymbolicUtils.BasicSymbolic{Real})

Closest candidates are:
  (::DataInterpolations.AbstractInterpolation)(::Symbolics.Num)
   @ DataInterpolationsSymbolicsExt C:\Users\visser_mn\.julia\packages\DataInterpolations\NoIUa\ext\DataInterpolationsSymbolicsExt.jl:15
  (::DataInterpolations.AbstractInterpolation)(::Number)
   @ DataInterpolations C:\Users\visser_mn\.julia\packages\DataInterpolations\NoIUa\src\DataInterpolations.jl:22
  (::DataInterpolations.AbstractInterpolation)(::Number, ::Integer)
   @ DataInterpolations C:\Users\visser_mn\.julia\packages\DataInterpolations\NoIUa\src\DataInterpolations.jl:23
  ...

Stacktrace:
  [1] get_area_and_level(basin::Ribasim.Basin{…}, state_idx::Int64, storage::Symbolics.Num)
    @ Ribasim d:\repo\ribasim\Ribasim\core\src\util.jl:57
  [2] set_current_basin_properties!(basin::Ribasim.Basin{…}, storage::SubArray{…})
    @ Ribasim d:\repo\ribasim\Ribasim\core\src\solve.jl:43
  [3] water_balance!(du::ComponentArrays.ComponentVector{…}, u::ComponentArrays.ComponentVector{…}, p::Ribasim.Parameters{…}, t::Float64)
```

It could be that there are a bunch more issues of this kind once this is
resolved.

If it starts to work but we decide not to adopt it, we could still use it to
test our manually determined Jacobian sparsity against.

---------

Co-authored-by: Bart de Koning <[email protected]>
  • Loading branch information
visr and SouthEndMusic authored Aug 19, 2024
1 parent e8e235a commit 9bdb262
Show file tree
Hide file tree
Showing 24 changed files with 456 additions and 484 deletions.
30 changes: 20 additions & 10 deletions Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

julia_version = "1.10.4"
manifest_format = "2.0"
project_hash = "ced48539db7c5827a1e795e5349b519fb599876e"
project_hash = "f7a79f84bac25727dfc97c214bb3f25b81baff47"

[[deps.ADTypes]]
git-tree-sha1 = "6778bcc27496dae5723ff37ee30af451db8b35fe"
Expand Down Expand Up @@ -339,10 +339,10 @@ uuid = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
version = "1.6.1"

[[deps.DataInterpolations]]
deps = ["FindFirstFunctions", "ForwardDiff", "LinearAlgebra", "PrettyTables", "ReadOnlyArrays", "RecipesBase", "Reexport"]
git-tree-sha1 = "3ba1e37d1315439539e3d8950dbc7042771c8978"
deps = ["FindFirstFunctions", "ForwardDiff", "LinearAlgebra", "PrettyTables", "RecipesBase", "Reexport"]
git-tree-sha1 = "9cc1cf079b42b5b6392c6b1df4bfc3e2a852b597"
uuid = "82cc6244-b520-54b8-b5a6-8a565e85f1d0"
version = "5.3.1"
version = "6.1.0"

[deps.DataInterpolations.extensions]
DataInterpolationsChainRulesCoreExt = "ChainRulesCore"
Expand All @@ -355,6 +355,7 @@ version = "5.3.1"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
RegularizationTools = "29dad682-9a27-4bc3-9c72-016788665182"
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[[deps.DataStructures]]
deps = ["Compat", "InteractiveUtils", "OrderedCollections"]
Expand Down Expand Up @@ -1247,11 +1248,6 @@ git-tree-sha1 = "f4a49b06ae830ff83a8743964ed08a805f5bab20"
uuid = "817f1d60-ba6b-4fd5-9520-3cf149f6a823"
version = "1.25.0"

[[deps.ReadOnlyArrays]]
git-tree-sha1 = "e6f7ddf48cf141cb312b078ca21cb2d29d0dc11d"
uuid = "988b38a3-91fc-5605-94a2-ee2116b3bd83"
version = "0.2.0"

[[deps.RecipesBase]]
deps = ["PrecompileTools"]
git-tree-sha1 = "5c3d09cc4f31f5fc6af001c250bf1278733100ff"
Expand Down Expand Up @@ -1312,7 +1308,7 @@ uuid = "295af30f-e4ad-537b-8983-00126c2a3abe"
version = "3.5.18"

[[deps.Ribasim]]
deps = ["Accessors", "Arrow", "BasicModelInterface", "CodecZstd", "ComponentArrays", "Configurations", "DBInterface", "DataInterpolations", "DataStructures", "Dates", "DiffEqCallbacks", "EnumX", "FiniteDiff", "ForwardDiff", "Graphs", "HiGHS", "IterTools", "JuMP", "Legolas", "LinearSolve", "Logging", "LoggingExtras", "MetaGraphsNext", "OrdinaryDiffEq", "PreallocationTools", "ReadOnlyArrays", "SQLite", "SciMLBase", "SparseArrays", "StructArrays", "Tables", "TerminalLoggers", "TimerOutputs", "TranscodingStreams"]
deps = ["Accessors", "Arrow", "BasicModelInterface", "CodecZstd", "ComponentArrays", "Configurations", "DBInterface", "DataInterpolations", "DataStructures", "Dates", "DiffEqCallbacks", "EnumX", "FiniteDiff", "ForwardDiff", "Graphs", "HiGHS", "IterTools", "JuMP", "Legolas", "LinearSolve", "Logging", "LoggingExtras", "MetaGraphsNext", "OrdinaryDiffEq", "PreallocationTools", "SQLite", "SciMLBase", "SparseArrays", "SparseConnectivityTracer", "StructArrays", "Tables", "TerminalLoggers", "TimerOutputs", "TranscodingStreams"]
path = "core"
uuid = "aac5e3d9-0b8f-4d4f-8241-b1a7a9632635"
version = "2024.10.0"
Expand Down Expand Up @@ -1468,6 +1464,20 @@ deps = ["Libdl", "LinearAlgebra", "Random", "Serialization", "SuiteSparse_jll"]
uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
version = "1.10.0"

[[deps.SparseConnectivityTracer]]
deps = ["ADTypes", "Compat", "DocStringExtensions", "FillArrays", "LinearAlgebra", "Random", "Requires", "SparseArrays"]
git-tree-sha1 = "73319ffcb025f603513c32878bac5bded9a1975f"
uuid = "9f842d2f-2579-4b1d-911e-f412cf18a3f5"
version = "0.6.1"

[deps.SparseConnectivityTracer.extensions]
SparseConnectivityTracerNNlibExt = "NNlib"
SparseConnectivityTracerSpecialFunctionsExt = "SpecialFunctions"

[deps.SparseConnectivityTracer.weakdeps]
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"

[[deps.SparseDiffTools]]
deps = ["ADTypes", "Adapt", "ArrayInterface", "Compat", "DataStructures", "FiniteDiff", "ForwardDiff", "Graphs", "LinearAlgebra", "PackageExtensionCompat", "Random", "Reexport", "SciMLOperators", "Setfield", "SparseArrays", "StaticArrayInterface", "StaticArrays", "Tricks", "UnPack", "VertexSafeGraphs"]
git-tree-sha1 = "469f51f8c4741ce944be2c0b65423b518b1405b0"
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,12 @@ OteraEngine = "b2d7f28f-acd6-4007-8b26-bc27716e5513"
PackageCompiler = "9b87118b-4619-50d2-8e1e-99f35a4d4d9d"
PreallocationTools = "d236fae5-4411-538c-8e31-a6e3d9e00b46"
ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823"
ReadOnlyArrays = "988b38a3-91fc-5605-94a2-ee2116b3bd83"
Revise = "295af30f-e4ad-537b-8983-00126c2a3abe"
Ribasim = "aac5e3d9-0b8f-4d4f-8241-b1a7a9632635"
SQLite = "0aa819cd-b072-5ff4-a722-6bc24af294d9"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5"
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
TOML = "fa267f1f-6049-4f14-aa54-33bafae1ed76"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
Expand Down
6 changes: 3 additions & 3 deletions core/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,10 @@ LoggingExtras = "e6f89c97-d47a-5376-807f-9c37f3926c36"
MetaGraphsNext = "fa8bd995-216d-47f1-8a91-f3b68fbeb377"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
PreallocationTools = "d236fae5-4411-538c-8e31-a6e3d9e00b46"
ReadOnlyArrays = "988b38a3-91fc-5605-94a2-ee2116b3bd83"
SQLite = "0aa819cd-b072-5ff4-a722-6bc24af294d9"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5"
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
TerminalLoggers = "5d786b92-1e48-4d6f-9151-6b4477ca9bed"
Expand All @@ -58,7 +58,7 @@ ComponentArrays = "0.13, 0.14, 0.15"
Configurations = "0.17"
DBInterface = "2.4"
DataFrames = "1.4"
DataInterpolations = "=5.3.1"
DataInterpolations = "6"
DataStructures = "0.18"
Dates = "<0.0.1, 1"
DiffEqCallbacks = "3.6"
Expand All @@ -79,10 +79,10 @@ MetaGraphsNext = "0.6, 0.7"
OrdinaryDiffEq = "6.7"
PreallocationTools = "0.4"
ReTestItems = "1.20"
ReadOnlyArrays = "0.2"
SQLite = "1.5.1"
SciMLBase = "2.36"
SparseArrays = "<0.0.1, 1"
SparseConnectivityTracer = "0.6.1"
StructArrays = "0.6.13"
TOML = "<0.0.1, 1"
Tables = "1"
Expand Down
16 changes: 9 additions & 7 deletions core/src/Ribasim.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,14 @@ import TranscodingStreams
using Accessors: @set
using Arrow: Arrow, Table
using CodecZstd: ZstdCompressor
using ComponentArrays: ComponentVector
using ComponentArrays: ComponentVector, Axis
using DataInterpolations:
LinearInterpolation, LinearInterpolationIntInv, invert_integral, derivative, integral
LinearInterpolation,
LinearInterpolationIntInv,
invert_integral,
derivative,
integral,
AbstractInterpolation
using Dates: Dates, DateTime, Millisecond, @dateformat_str
using DBInterface: execute
using DiffEqCallbacks:
Expand All @@ -50,8 +55,7 @@ using MetaGraphsNext:
outneighbor_labels,
inneighbor_labels
using OrdinaryDiffEq: OrdinaryDiffEq, OrdinaryDiffEqRosenbrockAdaptiveAlgorithm, get_du
using PreallocationTools: DiffCache, get_tmp
using ReadOnlyArrays: ReadOnlyVector
using PreallocationTools: LazyBufferCache
using SciMLBase:
init,
solve!,
Expand All @@ -66,13 +70,12 @@ using SciMLBase:
ODESolution,
VectorContinuousCallback,
get_proposed_dt
using SparseArrays: SparseMatrixCSC, spzeros
using SQLite: SQLite, DB, Query, esc_id
using StructArrays: StructVector
using Tables: Tables, AbstractRow, columntable
using TerminalLoggers: TerminalLogger
using TimerOutputs: TimerOutputs, TimerOutput, @timeit_debug

using SparseConnectivityTracer: TracerSparsityDetector, jacobian_sparsity, GradientTracer
export libribasim

const to = TimerOutput()
Expand All @@ -88,7 +91,6 @@ include("logging.jl")
include("allocation_init.jl")
include("allocation_optim.jl")
include("util.jl")
include("sparsity.jl")
include("graph.jl")
include("model.jl")
include("read.jl")
Expand Down
4 changes: 1 addition & 3 deletions core/src/allocation_optim.jl
Original file line number Diff line number Diff line change
Expand Up @@ -306,12 +306,10 @@ function get_basin_data(
u::ComponentVector,
node_id::NodeID,
)
(; graph, basin, allocation) = p
(; vertical_flux) = basin
(; graph, allocation) = p
(; Δt_allocation) = allocation_model
(; mean_input_flows) = allocation
@assert node_id.type == NodeType.Basin
vertical_flux = get_tmp(vertical_flux, 0)
influx = mean_input_flows[(node_id, node_id)][]
storage_basin = u.storage[node_id.idx]
control_inneighbors = inneighbor_labels_type(graph, node_id, EdgeType.control)
Expand Down
2 changes: 1 addition & 1 deletion core/src/bmi.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ function BMI.get_value_ptr(model::Model, name::AbstractString)::AbstractVector{F
if name == "basin.storage"
model.integrator.u.storage
elseif name == "basin.level"
get_tmp(model.integrator.p.basin.current_level, 0)
model.integrator.p.basin.current_level[parent(model.integrator.u)]
elseif name == "basin.infiltration"
model.integrator.p.basin.vertical_flux_from_input.infiltration
elseif name == "basin.drainage"
Expand Down
39 changes: 26 additions & 13 deletions core/src/callback.jl
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,9 @@ function integrate_flows!(u, t, integrator)::Nothing
(; flow, flow_dict, flow_prev, flow_integrated) = graph[]
(; vertical_flux, vertical_flux_prev, vertical_flux_integrated, vertical_flux_bmi) =
basin
flow = get_tmp(flow, 0)
vertical_flux = get_tmp(vertical_flux, 0)
du = get_du(integrator)
flow = flow[parent(du)]
vertical_flux = vertical_flux[parent(du)]
if !isempty(flow_prev) && isnan(flow_prev[1])
# If flow_prev is not populated yet
copyto!(flow_prev, flow)
Expand All @@ -136,7 +137,9 @@ function integrate_flows!(u, t, integrator)::Nothing
for (edge, value) in allocation.mean_realized_flows
if edge[1] !== edge[2]
value +=
0.5 * (get_flow(graph, edge..., 0) + get_flow_prev(graph, edge..., 0)) * dt
0.5 *
(get_flow(graph, edge..., du) + get_flow_prev(graph, edge..., du)) *
dt
allocation.mean_realized_flows[edge] = value
end
end
Expand All @@ -157,7 +160,9 @@ function integrate_flows!(u, t, integrator)::Nothing
# Horizontal flows
allocation.mean_input_flows[edge] =
value +
0.5 * (get_flow(graph, edge..., 0) + get_flow_prev(graph, edge..., 0)) * dt
0.5 *
(get_flow(graph, edge..., du) + get_flow_prev(graph, edge..., du)) *
dt
end
end
copyto!(flow_prev, flow)
Expand Down Expand Up @@ -236,6 +241,7 @@ function apply_discrete_control!(u, t, integrator)::Nothing
(; p) = integrator
(; discrete_control) = p
(; node_id) = discrete_control
du = get_du(integrator)

# Loop over the discrete control nodes to determine their truth state
# and detect possible control state changes
Expand All @@ -254,7 +260,7 @@ function apply_discrete_control!(u, t, integrator)::Nothing

# Loop over the variables listened to by this discrete control node
for compound_variable in compound_variables
value = compound_variable_value(compound_variable, p, u, t)
value = compound_variable_value(compound_variable, p, du, t)

# The thresholds the value of this variable is being compared with
greater_thans = compound_variable.greater_than
Expand Down Expand Up @@ -334,12 +340,12 @@ end
Get a value for a condition. Currently supports getting levels from basins and flows
from flow boundaries.
"""
function get_value(subvariable::NamedTuple, p::Parameters, u::AbstractVector, t::Float64)
function get_value(subvariable::NamedTuple, p::Parameters, du::AbstractVector, t::Float64)
(; flow_boundary, level_boundary, basin) = p
(; listen_node_id, look_ahead, variable, variable_ref) = subvariable

if !iszero(variable_ref.idx)
return get_value(variable_ref, u)
return get_value(variable_ref, du)
end

if variable == "level"
Expand Down Expand Up @@ -368,10 +374,10 @@ function get_value(subvariable::NamedTuple, p::Parameters, u::AbstractVector, t:
return value
end

function compound_variable_value(compound_variable::CompoundVariable, p, u, t)
value = zero(eltype(u))
function compound_variable_value(compound_variable::CompoundVariable, p, du, t)
value = zero(eltype(du))
for subvariable in compound_variable.subvariables
value += subvariable.weight * get_value(subvariable, p, u, t)
value += subvariable.weight * get_value(subvariable, p, du, t)
end
return value
end
Expand Down Expand Up @@ -412,7 +418,9 @@ function apply_parameter_update!(parameter_update)::Nothing
end

function update_subgrid_level!(integrator)::Nothing
basin_level = get_tmp(integrator.p.basin.current_level, 0)
(; p) = integrator
du = get_du(integrator)
basin_level = p.basin.current_level[parent(du)]
subgrid = integrator.p.subgrid
for (i, (index, interp)) in enumerate(zip(subgrid.basin_index, subgrid.interpolations))
subgrid.level[i] = interp(basin_level[index])
Expand All @@ -432,7 +440,7 @@ function update_basin!(integrator)::Nothing
(; storage) = u
(; node_id, time, vertical_flux_from_input, vertical_flux, vertical_flux_prev) = basin
t = datetime_since(integrator.t, integrator.p.starttime)
vertical_flux = get_tmp(vertical_flux, integrator.u)
vertical_flux = vertical_flux[parent(u)]

rows = searchsorted(time.time, t)
timeblock = view(time, rows)
Expand Down Expand Up @@ -530,7 +538,12 @@ function update_tabulated_rating_curve!(integrator)::Nothing
level = [row.level for row in group]
flow_rate = [row.flow_rate for row in group]
i = searchsortedfirst(node_id, NodeID(NodeType.TabulatedRatingCurve, id, 0))
table[i] = LinearInterpolation(flow_rate, level; extrapolate = true)
table[i] = LinearInterpolation(
flow_rate,
level;
extrapolate = true,
cache_parameters = true,
)
end
return nothing
end
Expand Down
Loading

0 comments on commit 9bdb262

Please sign in to comment.