Skip to content

Commit

Permalink
No custom sparse matrix structs
Browse files Browse the repository at this point in the history
  • Loading branch information
SouthEndMusic committed Oct 23, 2023
1 parent 2a04d32 commit 7ff8b82
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 85 deletions.
2 changes: 1 addition & 1 deletion core/src/bmi.jl
Original file line number Diff line number Diff line change
Expand Up @@ -474,7 +474,7 @@ end

"Copy the current flow to the SavedValues"
function save_flow(u, t, integrator)
copy(nonzeros(get_tmp_sparse(integrator.p.connectivity.flow, 0)))
copy(nonzeros(get_tmp(integrator.p.connectivity.flow, 0)))
end

"Load updates from 'Basin / time' into the parameters"
Expand Down
23 changes: 4 additions & 19 deletions core/src/create.jl
Original file line number Diff line number Diff line change
Expand Up @@ -201,19 +201,6 @@ end
const nonconservative_nodetypes =
Set{String}(["Basin", "LevelBoundary", "FlowBoundary", "Terminal", "User"])

"""
Get the chunk sizes for DiffCache; differentiation w.r.t. u
and t (the latter only if a Rosenbrock algorithm is used).
"""
function get_chunk_sizes(config::Config, chunk_size::Int)::Vector{Int}
chunk_sizes = [chunk_size]
if Ribasim.config.algorithms[config.solver.algorithm] <:
OrdinaryDiffEqRosenbrockAdaptiveAlgorithm
push!(chunk_sizes, 1)
end
return chunk_sizes
end

"""
If the tuple of variables contains a Dual variable, return the first one.
Otherwise return the last variable.
Expand Down Expand Up @@ -252,9 +239,8 @@ function Connectivity(db::DB, config::Config, chunk_size::Int)::Connectivity
flow .= 0.0

if config.solver.autodiff
chunk_sizes = get_chunk_sizes(config, chunk_size)
flowd = DiffCache(flow.nzval, chunk_sizes)
flow = SparseMatrixCSC_DiffCache(flow.m, flow.n, flow.colptr, flow.rowval, flowd)
# FixedSizeDiffCache performs better for sparse matrix
flow = FixedSizeDiffCache(flow, chunk_size)
end

# TODO: Create allocation models from input here
Expand Down Expand Up @@ -526,9 +512,8 @@ function Basin(db::DB, config::Config, chunk_size::Int)::Basin
current_area = zeros(n)

if config.solver.autodiff
chunk_sizes = get_chunk_sizes(config, chunk_size)
current_level = DiffCache(current_level, chunk_sizes)
current_area = DiffCache(current_area, chunk_sizes)
current_level = DiffCache(current_level, chunk_size)
current_area = DiffCache(current_area, chunk_size)
end

precipitation = fill(NaN, length(node_id))
Expand Down
2 changes: 1 addition & 1 deletion core/src/io.jl
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ function flow_table(model::Model)::NamedTuple
(; t, saveval) = saved_flow
(; connectivity) = integrator.p

I, J, _ = findnz(get_tmp_sparse(connectivity.flow, 0))
I, J, _ = findnz(get_tmp(connectivity.flow, 0))
# self-loops have no edge ID
unique_edge_ids = [get(connectivity.edge_ids_flow, ij, missing) for ij in zip(I, J)]
nflow = length(I)
Expand Down
75 changes: 11 additions & 64 deletions core/src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,60 +4,6 @@ const ScalarInterpolation =
const VectorInterpolation =
LinearInterpolation{Vector{Vector{Float64}}, Vector{Float64}, true, Vector{Float64}}

"""
This struct is analogous to SparseArrays.SparseMatrixCSC, with the only
difference that nzval is a DiffCache instead of a vector.
"""
struct SparseMatrixCSC_DiffCache{Ti <: Integer, D <: DiffCache}
m::Int # Number of rows
n::Int # Number of columns
colptr::Vector{Ti} # Column j is in colptr[j]:(colptr[j+1]-1)
rowval::Vector{Ti} # Row indices of stored values
nzval::D # DiffCache
end

"""
This struct is analogous to SparseArrays.SparseMatrixCSC, with the only
difference that nzval is either a vector or a ReinterpretArray (of Dual numbers).
SparseMatrixCSC_DiffCache and SparseMatrixCSC_cache are used to have a multi-level DiffCache of a sparse matrix
whose cache (accessed via get_tmp_sparse) supports sparse indexing.
Previously FixedSizeDiffCache was used for the sparse matrix Connectivity.flow, but FixedSizeDiffCache
does not support multi-level Duals.
"""
struct SparseMatrixCSC_cache{Ti <: Integer, D <: Union{Vector, Base.ReinterpretArray}} <:
SparseArrays.AbstractSparseMatrixCSC{eltype(D), Ti}
m::Int # Number of rows
n::Int # Number of columns
colptr::Vector{Ti} # Column j is in colptr[j]:(colptr[j+1]-1)
rowval::Vector{Ti} # Row indices of stored values
nzval::D # cache
end

const SparseCache = Union{SparseMatrixCSC_DiffCache, SparseMatrixCSC_cache}

SparseArrays.size(matrix::SparseCache) = (matrix.m, matrix.n)
SparseArrays.getcolptr(matrix::SparseCache) = matrix.colptr
SparseArrays.rowvals(matrix::SparseCache) = matrix.rowval
SparseArrays.nonzeros(matrix::SparseCache) = matrix.nzval

"""
Access the cache of the SparseMatrixCSC_DiffCache, and return it in a SparseMatrixCSC_cache
so that it can be accessed with sparse indexing.
"""
function get_tmp_sparse(
matrix::Union{SparseMatrixCSC_DiffCache, SparseMatrixCSC},
u,
)::SparseMatrixCSC_cache
return SparseMatrixCSC_cache(
matrix.m,
matrix.n,
matrix.colptr,
matrix.rowval,
get_tmp(matrix.nzval, u),
)
end

"""
Store information for a subnetwork used for allocation.
Expand Down Expand Up @@ -884,7 +830,7 @@ function formulate_flow!(
(; node_id, active, resistance) = linear_resistance

diffvar = get_diffvar((t, storage))
flow = get_tmp_sparse(flow, diffvar)
flow = get_tmp(flow, diffvar)

for (i, id) in enumerate(node_id)
basin_a_id = only(inneighbors(graph_flow, id))
Expand Down Expand Up @@ -916,7 +862,7 @@ function formulate_flow!(
(; graph_flow, flow) = connectivity
(; node_id, active, tables) = tabulated_rating_curve
diffvar = get_diffvar((t, storage))
flow = get_tmp_sparse(flow, diffvar)
flow = get_tmp(flow, diffvar)

for (i, id) in enumerate(node_id)
upstream_basin_id = only(inneighbors(graph_flow, id))
Expand Down Expand Up @@ -989,7 +935,7 @@ function formulate_flow!(
(; node_id, active, length, manning_n, profile_width, profile_slope) =
manning_resistance
diffvar = get_diffvar((storage, t))
flow = get_tmp_sparse(flow, diffvar)
flow = get_tmp(flow, diffvar)

for (i, id) in enumerate(node_id)
basin_a_id = only(inneighbors(graph_flow, id))
Expand Down Expand Up @@ -1047,7 +993,7 @@ function formulate_flow!(
(; graph_flow, flow) = connectivity
(; node_id, fraction) = fractional_flow
diffvar = get_diffvar((t, storage))
flow = get_tmp_sparse(flow, diffvar)
flow = get_tmp(flow, diffvar)

for (i, id) in enumerate(node_id)
downstream_id = only(outneighbors(graph_flow, id))
Expand All @@ -1067,7 +1013,7 @@ function formulate_flow!(
(; graph_flow, flow) = connectivity
(; node_id) = terminal
diffvar = get_diffvar((t, storage))
flow = get_tmp_sparse(flow, diffvar)
flow = get_tmp(flow, diffvar)

for id in node_id
for upstream_id in inneighbors(graph_flow, id)
Expand All @@ -1088,7 +1034,7 @@ function formulate_flow!(
(; graph_flow, flow) = connectivity
(; node_id) = level_boundary
diffvar = get_diffvar((t, storage))
flow = get_tmp_sparse(flow, diffvar)
flow = get_tmp(flow, diffvar)

for id in node_id
for in_id in inneighbors(graph_flow, id)
Expand All @@ -1113,7 +1059,7 @@ function formulate_flow!(
(; graph_flow, flow) = connectivity
(; node_id, active, flow_rate) = flow_boundary
diffvar = get_diffvar((t, storage))
flow = get_tmp_sparse(flow, diffvar)
flow = get_tmp(flow, diffvar)

for (i, id) in enumerate(node_id)
# Requirement: edge points away from the flow boundary
Expand Down Expand Up @@ -1141,7 +1087,7 @@ function formulate_flow!(
(; graph_flow, flow) = connectivity
(; node_id, active, flow_rate, is_pid_controlled) = pump
diffvar = get_diffvar((t, storage))
flow = get_tmp_sparse(flow, diffvar)
flow = get_tmp(flow, diffvar)
flow_rate = get_tmp(flow_rate, diffvar)
for (id, isactive, rate, pid_controlled) in
zip(node_id, active, flow_rate, is_pid_controlled)
Expand Down Expand Up @@ -1281,8 +1227,9 @@ function water_balance!(
du .= 0.0

diffvar = get_diffvar((t, storage))
flow = get_tmp_sparse(connectivity.flow, diffvar)
parent(flow.nzval) .= 0.0
flow = get_tmp(connectivity.flow, diffvar)
# use parent to avoid materializing the ReinterpretArray from FixedSizeDiffCache
parent(flow) .= 0.0

# Ensures current_* vectors are current
set_current_basin_properties!(basin, storage)
Expand Down

0 comments on commit 7ff8b82

Please sign in to comment.