Skip to content

Commit

Permalink
Bring back custom sparse arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
SouthEndMusic committed Oct 23, 2023
1 parent e68e0e9 commit ad77b79
Show file tree
Hide file tree
Showing 5 changed files with 85 additions and 16 deletions.
2 changes: 1 addition & 1 deletion core/src/Ribasim.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ using Logging: current_logger, min_enabled_level, with_logger
using LoggingExtras: EarlyFilteredLogger, LevelOverrideLogger
using OrdinaryDiffEq
using OrdinaryDiffEq: OrdinaryDiffEqRosenbrockAdaptiveAlgorithm
using PreallocationTools: DiffCache, FixedSizeDiffCache, get_tmp
using PreallocationTools: DiffCache, get_tmp
using SciMLBase
using SparseArrays
using SQLite: SQLite, DB, Query, esc_id
Expand Down
2 changes: 1 addition & 1 deletion core/src/bmi.jl
Original file line number Diff line number Diff line change
Expand Up @@ -474,7 +474,7 @@ end

"Copy the current flow to the SavedValues"
function save_flow(u, t, integrator)
copy(nonzeros(get_tmp(integrator.p.connectivity.flow, 0)))
copy(nonzeros(get_tmp_sparse(integrator.p.connectivity.flow, 0)))
end

"Load updates from 'Basin / time' into the parameters"
Expand Down
23 changes: 19 additions & 4 deletions core/src/create.jl
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,19 @@ end
const nonconservative_nodetypes =
Set{String}(["Basin", "LevelBoundary", "FlowBoundary", "Terminal", "User"])

"""
Get the chunk sizes for DiffCache; differentiation w.r.t. u
and t (the latter only if a Rosenbrock algorithm is used).
"""
function get_chunk_sizes(config::Config, chunk_size::Int)::Vector{Int}
chunk_sizes = [chunk_size]
if Ribasim.config.algorithms[config.solver.algorithm] <:
OrdinaryDiffEqRosenbrockAdaptiveAlgorithm
push!(chunk_sizes, 1)

Check warning on line 212 in core/src/create.jl

View check run for this annotation

Codecov / codecov/patch

core/src/create.jl#L212

Added line #L212 was not covered by tests
end
return chunk_sizes
end

"""
If the tuple of variables contains a Dual variable, return the first one.
Otherwise return the last variable.
Expand Down Expand Up @@ -239,8 +252,9 @@ function Connectivity(db::DB, config::Config, chunk_size::Int)::Connectivity
flow .= 0.0

if config.solver.autodiff
# FixedSizeDiffCache performs better for sparse matrix
flow = FixedSizeDiffCache(flow, chunk_size)
chunk_sizes = get_chunk_sizes(config, chunk_size)
flowd = DiffCache(flow.nzval, chunk_sizes)
flow = SparseMatrixCSC_DiffCache(flow.m, flow.n, flow.colptr, flow.rowval, flowd)
end

# TODO: Create allocation models from input here
Expand Down Expand Up @@ -514,8 +528,9 @@ function Basin(db::DB, config::Config, chunk_size::Int)::Basin
current_area = zeros(n)

if config.solver.autodiff
current_level = DiffCache(current_level, chunk_size)
current_area = DiffCache(current_area, chunk_size)
chunk_sizes = get_chunk_sizes(config, chunk_size)
current_level = DiffCache(current_level, chunk_sizes)
current_area = DiffCache(current_area, chunk_sizes)
end

precipitation = fill(NaN, length(node_id))
Expand Down
2 changes: 1 addition & 1 deletion core/src/io.jl
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ function flow_table(model::Model)::NamedTuple
(; t, saveval) = saved_flow
(; connectivity) = integrator.p

I, J, _ = findnz(get_tmp(connectivity.flow, 0))
I, J, _ = findnz(get_tmp_sparse(connectivity.flow, 0))
# self-loops have no edge ID
unique_edge_ids = [get(connectivity.edge_ids_flow, ij, missing) for ij in zip(I, J)]
nflow = length(I)
Expand Down
72 changes: 63 additions & 9 deletions core/src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,60 @@ const ScalarInterpolation =
const VectorInterpolation =
LinearInterpolation{Vector{Vector{Float64}}, Vector{Float64}, true, Vector{Float64}}

"""
This struct is analogous to SparseArrays.SparseMatrixCSC, with the only
difference that nzval is a DiffCache instead of a vector.
"""
struct SparseMatrixCSC_DiffCache{Ti <: Integer, D <: DiffCache}
m::Int # Number of rows
n::Int # Number of columns
colptr::Vector{Ti} # Column j is in colptr[j]:(colptr[j+1]-1)
rowval::Vector{Ti} # Row indices of stored values
nzval::D # DiffCache
end

"""
This struct is analogous to SparseArrays.SparseMatrixCSC, with the only
difference that nzval is either a vector or a ReinterpretArray (of Dual numbers).
SparseMatrixCSC_DiffCache and SparseMatrixCSC_cache are used to have a multi-level DiffCache of a sparse matrix
whose cache (accessed via get_tmp_sparse) supports sparse indexing.
Previously FixedSizeDiffCache was used for the sparse matrix Connectivity.flow, but FixedSizeDiffCache
does not support multi-level Duals.
"""
struct SparseMatrixCSC_cache{Ti <: Integer, D <: Union{Vector, Base.ReinterpretArray}} <:
SparseArrays.AbstractSparseMatrixCSC{eltype(D), Ti}
m::Int # Number of rows
n::Int # Number of columns
colptr::Vector{Ti} # Column j is in colptr[j]:(colptr[j+1]-1)
rowval::Vector{Ti} # Row indices of stored values
nzval::D # cache
end

const SparseCache = Union{SparseMatrixCSC_DiffCache, SparseMatrixCSC_cache}

SparseArrays.size(matrix::SparseCache) = (matrix.m, matrix.n)
SparseArrays.getcolptr(matrix::SparseCache) = matrix.colptr
SparseArrays.rowvals(matrix::SparseCache) = matrix.rowval
SparseArrays.nonzeros(matrix::SparseCache) = matrix.nzval

"""
Access the cache of the SparseMatrixCSC_DiffCache, and return it in a SparseMatrixCSC_cache
so that it can be accessed with sparse indexing.
"""
function get_tmp_sparse(
matrix::Union{SparseMatrixCSC_DiffCache, SparseMatrixCSC},
u,
)::SparseMatrixCSC_cache
return SparseMatrixCSC_cache(
matrix.m,
matrix.n,
matrix.colptr,
matrix.rowval,
get_tmp(matrix.nzval, u),
)
end

"""
Store information for a subnetwork used for allocation.
Expand Down Expand Up @@ -836,7 +890,7 @@ function formulate_flow!(
(; node_id, active, resistance) = linear_resistance

diffvar = get_diffvar((t, storage))
flow = get_tmp(flow, diffvar)
flow = get_tmp_sparse(flow, diffvar)

for (i, id) in enumerate(node_id)
basin_a_id = only(inneighbors(graph_flow, id))
Expand Down Expand Up @@ -868,7 +922,7 @@ function formulate_flow!(
(; graph_flow, flow) = connectivity
(; node_id, active, tables) = tabulated_rating_curve
diffvar = get_diffvar((t, storage))
flow = get_tmp(flow, diffvar)
flow = get_tmp_sparse(flow, diffvar)

for (i, id) in enumerate(node_id)
upstream_basin_id = only(inneighbors(graph_flow, id))
Expand Down Expand Up @@ -999,7 +1053,7 @@ function formulate_flow!(
(; graph_flow, flow) = connectivity
(; node_id, fraction) = fractional_flow
diffvar = get_diffvar((t, storage))
flow = get_tmp(flow, diffvar)
flow = get_tmp_sparse(flow, diffvar)

for (i, id) in enumerate(node_id)
downstream_id = only(outneighbors(graph_flow, id))
Expand All @@ -1019,7 +1073,7 @@ function formulate_flow!(
(; graph_flow, flow) = connectivity
(; node_id) = terminal
diffvar = get_diffvar((t, storage))
flow = get_tmp(flow, diffvar)
flow = get_tmp_sparse(flow, diffvar)

for id in node_id
for upstream_id in inneighbors(graph_flow, id)
Expand All @@ -1040,7 +1094,7 @@ function formulate_flow!(
(; graph_flow, flow) = connectivity
(; node_id) = level_boundary
diffvar = get_diffvar((t, storage))
flow = get_tmp(flow, diffvar)
flow = get_tmp_sparse(flow, diffvar)

for id in node_id
for in_id in inneighbors(graph_flow, id)
Expand All @@ -1065,7 +1119,7 @@ function formulate_flow!(
(; graph_flow, flow) = connectivity
(; node_id, active, flow_rate) = flow_boundary
diffvar = get_diffvar((t, storage))
flow = get_tmp(flow, diffvar)
flow = get_tmp_sparse(flow, diffvar)

for (i, id) in enumerate(node_id)
# Requirement: edge points away from the flow boundary
Expand Down Expand Up @@ -1093,7 +1147,7 @@ function formulate_flow!(
(; graph_flow, flow) = connectivity
(; node_id, active, flow_rate, is_pid_controlled) = pump
diffvar = get_diffvar((t, storage))
flow = get_tmp(flow, diffvar)
flow = get_tmp_sparse(flow, diffvar)
flow_rate = get_tmp(flow_rate, diffvar)
for (id, isactive, rate, pid_controlled) in
zip(node_id, active, flow_rate, is_pid_controlled)
Expand Down Expand Up @@ -1234,9 +1288,9 @@ function water_balance!(
du .= 0.0

diffvar = get_diffvar((t, storage))
flow = get_tmp(connectivity.flow, diffvar)
flow = get_tmp_sparse(connectivity.flow, diffvar)
# use parent to avoid materializing the ReinterpretArray from FixedSizeDiffCache
parent(flow) .= 0.0
flow.nzval .= 0.0

# Ensures current_* vectors are current
set_current_basin_properties!(basin, storage, t)
Expand Down

0 comments on commit ad77b79

Please sign in to comment.