Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Symbolic Jacobian sparsity #1606

Merged
merged 35 commits into from
Aug 19, 2024
Merged
Show file tree
Hide file tree
Changes from 34 commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
84a9f3c
Symbolic Jacobian sparsity
visr Jul 3, 2024
547aef4
Merge branch 'main' into symbolic-jacobian-sparsity
SouthEndMusic Jul 24, 2024
b5861cf
Non-branching reduction factor
SouthEndMusic Jul 24, 2024
101c397
Merge branch 'main' into symbolic-jacobian-sparsity
SouthEndMusic Jul 29, 2024
0966f53
Manifest update
SouthEndMusic Jul 29, 2024
a7ad503
Upgrade to DataInterpolations 6
SouthEndMusic Jul 29, 2024
f454301
POC!
SouthEndMusic Jul 30, 2024
34bd75c
Support more node types
SouthEndMusic Jul 30, 2024
2002b70
Merge branch 'main' into symbolic-jacobian-sparsity
SouthEndMusic Jul 30, 2024
1cee9e1
Fix many tests
SouthEndMusic Jul 30, 2024
7869163
pass all test (!)
SouthEndMusic Jul 31, 2024
f7cf558
Merge branch 'main' into symbolic-jacobian-sparsity
SouthEndMusic Jul 31, 2024
a5e990a
Docs fix
SouthEndMusic Jul 31, 2024
2676c22
Make sure all nodes are active when detecting Jacobian sparsity
SouthEndMusic Jul 31, 2024
e0a5519
Use LazyBufferCache initialization feature
SouthEndMusic Jul 31, 2024
aab4085
Some comments adressed (work done yesterday)
SouthEndMusic Aug 2, 2024
81420e5
Merge branch 'main' into symbolic-jacobian-sparsity
SouthEndMusic Aug 5, 2024
2c39729
Pass tests
SouthEndMusic Aug 5, 2024
4664f01
Merge branch 'main' into symbolic-jacobian-sparsity
SouthEndMusic Aug 6, 2024
f589be2
Merge branch 'main' into symbolic-jacobian-sparsity
SouthEndMusic Aug 6, 2024
8b6b97f
Merge branch 'main' into symbolic-jacobian-sparsity
SouthEndMusic Aug 6, 2024
2be2b05
Use tracerLocalSparsityDetector
SouthEndMusic Aug 7, 2024
e008204
Merge branch 'main' into symbolic-jacobian-sparsity
SouthEndMusic Aug 7, 2024
7d5e65e
Update SparseConnectivityTracer version
SouthEndMusic Aug 8, 2024
f02891b
Merge branch 'main' into symbolic-jacobian-sparsity
SouthEndMusic Aug 13, 2024
390e00c
Update dependency versions
SouthEndMusic Aug 13, 2024
69896f1
Add custom overloads
SouthEndMusic Aug 16, 2024
9ae49fb
Merge branch 'main' into symbolic-jacobian-sparsity
SouthEndMusic Aug 16, 2024
20e7fe1
Use the safe TracerSparsityDetector, revert reduction factor changes
SouthEndMusic Aug 16, 2024
0df1dcc
Mute type piracy error
SouthEndMusic Aug 16, 2024
956f77e
nit
SouthEndMusic Aug 16, 2024
b8916ae
Cleanup manifest
SouthEndMusic Aug 16, 2024
062a01c
Use latest release of SparseConnectivityTracer
SouthEndMusic Aug 19, 2024
4064042
Use latest re;ease of DataInterpolations
SouthEndMusic Aug 19, 2024
87f702e
Merge branch 'main' into symbolic-jacobian-sparsity
SouthEndMusic Aug 19, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 20 additions & 10 deletions Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

julia_version = "1.10.4"
manifest_format = "2.0"
project_hash = "ced48539db7c5827a1e795e5349b519fb599876e"
project_hash = "f7a79f84bac25727dfc97c214bb3f25b81baff47"

[[deps.ADTypes]]
git-tree-sha1 = "6778bcc27496dae5723ff37ee30af451db8b35fe"
Expand Down Expand Up @@ -339,10 +339,10 @@ uuid = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
version = "1.6.1"

[[deps.DataInterpolations]]
deps = ["FindFirstFunctions", "ForwardDiff", "LinearAlgebra", "PrettyTables", "ReadOnlyArrays", "RecipesBase", "Reexport"]
git-tree-sha1 = "3ba1e37d1315439539e3d8950dbc7042771c8978"
deps = ["FindFirstFunctions", "ForwardDiff", "LinearAlgebra", "PrettyTables", "RecipesBase", "Reexport"]
git-tree-sha1 = "9cc1cf079b42b5b6392c6b1df4bfc3e2a852b597"
uuid = "82cc6244-b520-54b8-b5a6-8a565e85f1d0"
version = "5.3.1"
version = "6.1.0"

[deps.DataInterpolations.extensions]
DataInterpolationsChainRulesCoreExt = "ChainRulesCore"
Expand All @@ -355,6 +355,7 @@ version = "5.3.1"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
RegularizationTools = "29dad682-9a27-4bc3-9c72-016788665182"
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[[deps.DataStructures]]
deps = ["Compat", "InteractiveUtils", "OrderedCollections"]
Expand Down Expand Up @@ -1247,11 +1248,6 @@ git-tree-sha1 = "f4a49b06ae830ff83a8743964ed08a805f5bab20"
uuid = "817f1d60-ba6b-4fd5-9520-3cf149f6a823"
version = "1.25.0"

[[deps.ReadOnlyArrays]]
git-tree-sha1 = "e6f7ddf48cf141cb312b078ca21cb2d29d0dc11d"
uuid = "988b38a3-91fc-5605-94a2-ee2116b3bd83"
version = "0.2.0"

[[deps.RecipesBase]]
deps = ["PrecompileTools"]
git-tree-sha1 = "5c3d09cc4f31f5fc6af001c250bf1278733100ff"
Expand Down Expand Up @@ -1312,7 +1308,7 @@ uuid = "295af30f-e4ad-537b-8983-00126c2a3abe"
version = "3.5.18"

[[deps.Ribasim]]
deps = ["Accessors", "Arrow", "BasicModelInterface", "CodecZstd", "ComponentArrays", "Configurations", "DBInterface", "DataInterpolations", "DataStructures", "Dates", "DiffEqCallbacks", "EnumX", "FiniteDiff", "ForwardDiff", "Graphs", "HiGHS", "IterTools", "JuMP", "Legolas", "LinearSolve", "Logging", "LoggingExtras", "MetaGraphsNext", "OrdinaryDiffEq", "PreallocationTools", "ReadOnlyArrays", "SQLite", "SciMLBase", "SparseArrays", "StructArrays", "Tables", "TerminalLoggers", "TimerOutputs", "TranscodingStreams"]
deps = ["Accessors", "Arrow", "BasicModelInterface", "CodecZstd", "ComponentArrays", "Configurations", "DBInterface", "DataInterpolations", "DataStructures", "Dates", "DiffEqCallbacks", "EnumX", "FiniteDiff", "ForwardDiff", "Graphs", "HiGHS", "IterTools", "JuMP", "Legolas", "LinearSolve", "Logging", "LoggingExtras", "MetaGraphsNext", "OrdinaryDiffEq", "PreallocationTools", "SQLite", "SciMLBase", "SparseArrays", "SparseConnectivityTracer", "StructArrays", "Tables", "TerminalLoggers", "TimerOutputs", "TranscodingStreams"]
path = "core"
uuid = "aac5e3d9-0b8f-4d4f-8241-b1a7a9632635"
version = "2024.10.0"
Expand Down Expand Up @@ -1468,6 +1464,20 @@ deps = ["Libdl", "LinearAlgebra", "Random", "Serialization", "SuiteSparse_jll"]
uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
version = "1.10.0"

[[deps.SparseConnectivityTracer]]
deps = ["ADTypes", "Compat", "DocStringExtensions", "FillArrays", "LinearAlgebra", "Random", "Requires", "SparseArrays"]
git-tree-sha1 = "73319ffcb025f603513c32878bac5bded9a1975f"
uuid = "9f842d2f-2579-4b1d-911e-f412cf18a3f5"
version = "0.6.1"

[deps.SparseConnectivityTracer.extensions]
SparseConnectivityTracerNNlibExt = "NNlib"
SparseConnectivityTracerSpecialFunctionsExt = "SpecialFunctions"

[deps.SparseConnectivityTracer.weakdeps]
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"

[[deps.SparseDiffTools]]
deps = ["ADTypes", "Adapt", "ArrayInterface", "Compat", "DataStructures", "FiniteDiff", "ForwardDiff", "Graphs", "LinearAlgebra", "PackageExtensionCompat", "Random", "Reexport", "SciMLOperators", "Setfield", "SparseArrays", "StaticArrayInterface", "StaticArrays", "Tricks", "UnPack", "VertexSafeGraphs"]
git-tree-sha1 = "469f51f8c4741ce944be2c0b65423b518b1405b0"
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,12 @@ OteraEngine = "b2d7f28f-acd6-4007-8b26-bc27716e5513"
PackageCompiler = "9b87118b-4619-50d2-8e1e-99f35a4d4d9d"
PreallocationTools = "d236fae5-4411-538c-8e31-a6e3d9e00b46"
ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823"
ReadOnlyArrays = "988b38a3-91fc-5605-94a2-ee2116b3bd83"
Revise = "295af30f-e4ad-537b-8983-00126c2a3abe"
Ribasim = "aac5e3d9-0b8f-4d4f-8241-b1a7a9632635"
SQLite = "0aa819cd-b072-5ff4-a722-6bc24af294d9"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5"
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
TOML = "fa267f1f-6049-4f14-aa54-33bafae1ed76"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
Expand Down
6 changes: 3 additions & 3 deletions core/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,10 @@ LoggingExtras = "e6f89c97-d47a-5376-807f-9c37f3926c36"
MetaGraphsNext = "fa8bd995-216d-47f1-8a91-f3b68fbeb377"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
PreallocationTools = "d236fae5-4411-538c-8e31-a6e3d9e00b46"
ReadOnlyArrays = "988b38a3-91fc-5605-94a2-ee2116b3bd83"
SQLite = "0aa819cd-b072-5ff4-a722-6bc24af294d9"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5"
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
TerminalLoggers = "5d786b92-1e48-4d6f-9151-6b4477ca9bed"
Expand All @@ -58,7 +58,7 @@ ComponentArrays = "0.13, 0.14, 0.15"
Configurations = "0.17"
DBInterface = "2.4"
DataFrames = "1.4"
DataInterpolations = "=5.3.1"
DataInterpolations = "6"
DataStructures = "0.18"
Dates = "<0.0.1, 1"
DiffEqCallbacks = "3.6"
Expand All @@ -79,10 +79,10 @@ MetaGraphsNext = "0.6, 0.7"
OrdinaryDiffEq = "6.7"
PreallocationTools = "0.4"
ReTestItems = "1.20"
ReadOnlyArrays = "0.2"
SQLite = "1.5.1"
SciMLBase = "2.36"
SparseArrays = "<0.0.1, 1"
SparseConnectivityTracer = "0.6.1"
StructArrays = "0.6.13"
TOML = "<0.0.1, 1"
Tables = "1"
Expand Down
16 changes: 9 additions & 7 deletions core/src/Ribasim.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,14 @@ import TranscodingStreams
using Accessors: @set
using Arrow: Arrow, Table
using CodecZstd: ZstdCompressor
using ComponentArrays: ComponentVector
using ComponentArrays: ComponentVector, Axis
using DataInterpolations:
LinearInterpolation, LinearInterpolationIntInv, invert_integral, derivative, integral
LinearInterpolation,
LinearInterpolationIntInv,
invert_integral,
derivative,
integral,
AbstractInterpolation
using Dates: Dates, DateTime, Millisecond, @dateformat_str
using DBInterface: execute
using DiffEqCallbacks:
Expand All @@ -50,8 +55,7 @@ using MetaGraphsNext:
outneighbor_labels,
inneighbor_labels
using OrdinaryDiffEq: OrdinaryDiffEq, OrdinaryDiffEqRosenbrockAdaptiveAlgorithm, get_du
using PreallocationTools: DiffCache, get_tmp
using ReadOnlyArrays: ReadOnlyVector
using PreallocationTools: LazyBufferCache
using SciMLBase:
init,
solve!,
Expand All @@ -66,13 +70,12 @@ using SciMLBase:
ODESolution,
VectorContinuousCallback,
get_proposed_dt
using SparseArrays: SparseMatrixCSC, spzeros
using SQLite: SQLite, DB, Query, esc_id
using StructArrays: StructVector
using Tables: Tables, AbstractRow, columntable
using TerminalLoggers: TerminalLogger
using TimerOutputs: TimerOutputs, TimerOutput, @timeit_debug

using SparseConnectivityTracer: TracerSparsityDetector, jacobian_sparsity, GradientTracer
export libribasim

const to = TimerOutput()
Expand All @@ -88,7 +91,6 @@ include("logging.jl")
include("allocation_init.jl")
include("allocation_optim.jl")
include("util.jl")
include("sparsity.jl")
include("graph.jl")
include("model.jl")
include("read.jl")
Expand Down
4 changes: 1 addition & 3 deletions core/src/allocation_optim.jl
Original file line number Diff line number Diff line change
Expand Up @@ -306,12 +306,10 @@ function get_basin_data(
u::ComponentVector,
node_id::NodeID,
)
(; graph, basin, allocation) = p
(; vertical_flux) = basin
(; graph, allocation) = p
(; Δt_allocation) = allocation_model
(; mean_input_flows) = allocation
@assert node_id.type == NodeType.Basin
vertical_flux = get_tmp(vertical_flux, 0)
influx = mean_input_flows[(node_id, node_id)][]
storage_basin = u.storage[node_id.idx]
control_inneighbors = inneighbor_labels_type(graph, node_id, EdgeType.control)
Expand Down
2 changes: 1 addition & 1 deletion core/src/bmi.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ function BMI.get_value_ptr(model::Model, name::AbstractString)::AbstractVector{F
if name == "basin.storage"
model.integrator.u.storage
elseif name == "basin.level"
get_tmp(model.integrator.p.basin.current_level, 0)
model.integrator.p.basin.current_level[parent(model.integrator.u)]
elseif name == "basin.infiltration"
model.integrator.p.basin.vertical_flux_from_input.infiltration
elseif name == "basin.drainage"
Expand Down
39 changes: 26 additions & 13 deletions core/src/callback.jl
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,9 @@ function integrate_flows!(u, t, integrator)::Nothing
(; flow, flow_dict, flow_prev, flow_integrated) = graph[]
(; vertical_flux, vertical_flux_prev, vertical_flux_integrated, vertical_flux_bmi) =
basin
flow = get_tmp(flow, 0)
vertical_flux = get_tmp(vertical_flux, 0)
du = get_du(integrator)
flow = flow[parent(du)]
vertical_flux = vertical_flux[parent(du)]
if !isempty(flow_prev) && isnan(flow_prev[1])
# If flow_prev is not populated yet
copyto!(flow_prev, flow)
Expand All @@ -136,7 +137,9 @@ function integrate_flows!(u, t, integrator)::Nothing
for (edge, value) in allocation.mean_realized_flows
if edge[1] !== edge[2]
value +=
0.5 * (get_flow(graph, edge..., 0) + get_flow_prev(graph, edge..., 0)) * dt
0.5 *
(get_flow(graph, edge..., du) + get_flow_prev(graph, edge..., du)) *
dt
allocation.mean_realized_flows[edge] = value
end
end
Expand All @@ -157,7 +160,9 @@ function integrate_flows!(u, t, integrator)::Nothing
# Horizontal flows
allocation.mean_input_flows[edge] =
value +
0.5 * (get_flow(graph, edge..., 0) + get_flow_prev(graph, edge..., 0)) * dt
0.5 *
(get_flow(graph, edge..., du) + get_flow_prev(graph, edge..., du)) *
dt
end
end
copyto!(flow_prev, flow)
Expand Down Expand Up @@ -236,6 +241,7 @@ function apply_discrete_control!(u, t, integrator)::Nothing
(; p) = integrator
(; discrete_control) = p
(; node_id) = discrete_control
du = get_du(integrator)

# Loop over the discrete control nodes to determine their truth state
# and detect possible control state changes
Expand All @@ -254,7 +260,7 @@ function apply_discrete_control!(u, t, integrator)::Nothing

# Loop over the variables listened to by this discrete control node
for compound_variable in compound_variables
value = compound_variable_value(compound_variable, p, u, t)
value = compound_variable_value(compound_variable, p, du, t)

# The thresholds the value of this variable is being compared with
greater_thans = compound_variable.greater_than
Expand Down Expand Up @@ -334,12 +340,12 @@ end
Get a value for a condition. Currently supports getting levels from basins and flows
from flow boundaries.
"""
function get_value(subvariable::NamedTuple, p::Parameters, u::AbstractVector, t::Float64)
function get_value(subvariable::NamedTuple, p::Parameters, du::AbstractVector, t::Float64)
(; flow_boundary, level_boundary, basin) = p
(; listen_node_id, look_ahead, variable, variable_ref) = subvariable

if !iszero(variable_ref.idx)
return get_value(variable_ref, u)
return get_value(variable_ref, du)
end

if variable == "level"
Expand Down Expand Up @@ -368,10 +374,10 @@ function get_value(subvariable::NamedTuple, p::Parameters, u::AbstractVector, t:
return value
end

function compound_variable_value(compound_variable::CompoundVariable, p, u, t)
value = zero(eltype(u))
function compound_variable_value(compound_variable::CompoundVariable, p, du, t)
value = zero(eltype(du))
for subvariable in compound_variable.subvariables
value += subvariable.weight * get_value(subvariable, p, u, t)
value += subvariable.weight * get_value(subvariable, p, du, t)
end
return value
end
Expand Down Expand Up @@ -412,7 +418,9 @@ function apply_parameter_update!(parameter_update)::Nothing
end

function update_subgrid_level!(integrator)::Nothing
basin_level = get_tmp(integrator.p.basin.current_level, 0)
(; p) = integrator
du = get_du(integrator)
basin_level = p.basin.current_level[parent(du)]
subgrid = integrator.p.subgrid
for (i, (index, interp)) in enumerate(zip(subgrid.basin_index, subgrid.interpolations))
subgrid.level[i] = interp(basin_level[index])
Expand All @@ -432,7 +440,7 @@ function update_basin!(integrator)::Nothing
(; storage) = u
(; node_id, time, vertical_flux_from_input, vertical_flux, vertical_flux_prev) = basin
t = datetime_since(integrator.t, integrator.p.starttime)
vertical_flux = get_tmp(vertical_flux, integrator.u)
vertical_flux = vertical_flux[parent(u)]

rows = searchsorted(time.time, t)
timeblock = view(time, rows)
Expand Down Expand Up @@ -530,7 +538,12 @@ function update_tabulated_rating_curve!(integrator)::Nothing
level = [row.level for row in group]
flow_rate = [row.flow_rate for row in group]
i = searchsortedfirst(node_id, NodeID(NodeType.TabulatedRatingCurve, id, 0))
table[i] = LinearInterpolation(flow_rate, level; extrapolate = true)
table[i] = LinearInterpolation(
flow_rate,
level;
extrapolate = true,
cache_parameters = true,
)
end
return nothing
end
Expand Down
Loading