Skip to content

Commit

Permalink
Bugfixes
Browse files Browse the repository at this point in the history
  • Loading branch information
SouthEndMusic committed Oct 25, 2023
1 parent d9eb41a commit 87e11e7
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 27 deletions.
46 changes: 20 additions & 26 deletions core/src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ const VectorInterpolation =
This struct is analogous to SparseArrays.SparseMatrixCSC, with the only
difference that nzval is a DiffCache instead of a vector.
"""
struct SparseMatrixCSC_DiffCache{Ti <: Integer, D <: DiffCache}
struct SparseMatrixCSC_DiffCache{D <: DiffCache, Ti <: Integer}
m::Int # Number of rows
n::Int # Number of columns
colptr::Vector{Ti} # Column j is in colptr[j]:(colptr[j+1]-1)
Expand Down Expand Up @@ -49,16 +49,10 @@ function get_tmp_sparse(
matrix::Union{SparseMatrixCSC_DiffCache, SparseMatrixCSC},
u,
)::Union{SparseMatrixCSC_cache, SparseMatrixCSC}
tmp = get_tmp(matrix.nzval, u)
return_type = tmp isa Vector ? SparseMatrixCSC : SparseMatrixCSC_cache

return return_type(
matrix.m,
matrix.n,
matrix.colptr,
matrix.rowval,
get_tmp(matrix.nzval, u),
)
nzval = get_tmp(matrix.nzval, u)
return_type = nzval isa Base.ReinterpretArray ? SparseMatrixCSC_cache : SparseMatrixCSC

return return_type(matrix.m, matrix.n, matrix.colptr, matrix.rowval, nzval)
end

"""
Expand Down Expand Up @@ -597,7 +591,7 @@ function set_current_basin_properties!(
t::Number,
)::Nothing
(; current_level, current_area) = basin
diffvar = get_diffvar((t, storage))
diffvar = get_diffvar(t, storage)
current_level = get_tmp(current_level, diffvar)
current_area = get_tmp(current_area, diffvar)

Expand All @@ -622,7 +616,7 @@ function formulate_basins!(
t::Number,
)::Nothing
(; node_id, current_level, current_area) = basin
diffvar = get_diffvar((t, storage))
diffvar = get_diffvar(t, storage)
current_level = get_tmp(current_level, diffvar)
current_area = get_tmp(current_area, diffvar)

Expand Down Expand Up @@ -651,7 +645,7 @@ end
function set_error!(pid_control::PidControl, p::Parameters, u::ComponentVector, t::Number)
(; basin) = p
(; listen_node_id, target, error) = pid_control
diffvar = get_diffvar((t, u))
diffvar = get_diffvar(t, u)
error = get_tmp(error, diffvar)
current_level = get_tmp(basin.current_level, diffvar)

Expand Down Expand Up @@ -680,7 +674,7 @@ function continuous_control!(
(; node_id, active, target, pid_params, listen_node_id, error) = pid_control
(; current_area) = basin

diffvar = get_diffvar((t, u))
diffvar = get_diffvar(t, u)
current_area = get_tmp(current_area, diffvar)
flow = get_tmp_sparse(flow, diffvar)
outlet_flow_rate = get_tmp(outlet.flow_rate, diffvar)
Expand Down Expand Up @@ -840,7 +834,7 @@ function formulate_flow!(
(; graph_flow, flow) = connectivity
(; node_id, allocated, demand, active, return_factor, min_level) = user

diffvar = get_diffvar((t, storage))
diffvar = get_diffvar(t, storage)
flow = get_tmp_sparse(flow, diffvar)

for (i, id) in enumerate(node_id)
Expand Down Expand Up @@ -892,7 +886,7 @@ function formulate_flow!(
(; graph_flow, flow) = connectivity
(; node_id, active, resistance) = linear_resistance

diffvar = get_diffvar((t, storage))
diffvar = get_diffvar(t, storage)
flow = get_tmp_sparse(flow, diffvar)

for (i, id) in enumerate(node_id)
Expand Down Expand Up @@ -924,7 +918,7 @@ function formulate_flow!(
(; basin, connectivity) = p
(; graph_flow, flow) = connectivity
(; node_id, active, tables) = tabulated_rating_curve
diffvar = get_diffvar((t, storage))
diffvar = get_diffvar(t, storage)
flow = get_tmp_sparse(flow, diffvar)

for (i, id) in enumerate(node_id)
Expand Down Expand Up @@ -997,7 +991,7 @@ function formulate_flow!(
(; graph_flow, flow) = connectivity
(; node_id, active, length, manning_n, profile_width, profile_slope) =
manning_resistance
diffvar = get_diffvar((t, storage))
diffvar = get_diffvar(t, storage)
flow = get_tmp_sparse(flow, diffvar)

for (i, id) in enumerate(node_id)
Expand Down Expand Up @@ -1055,7 +1049,7 @@ function formulate_flow!(
(; connectivity) = p
(; graph_flow, flow) = connectivity
(; node_id, fraction) = fractional_flow
diffvar = get_diffvar((t, storage))
diffvar = get_diffvar(t, storage)
flow = get_tmp_sparse(flow, diffvar)

for (i, id) in enumerate(node_id)
Expand All @@ -1075,7 +1069,7 @@ function formulate_flow!(
(; connectivity) = p
(; graph_flow, flow) = connectivity
(; node_id) = terminal
diffvar = get_diffvar((t, storage))
diffvar = get_diffvar(t, storage)
flow = get_tmp_sparse(flow, diffvar)

for id in node_id
Expand All @@ -1096,7 +1090,7 @@ function formulate_flow!(
(; connectivity) = p
(; graph_flow, flow) = connectivity
(; node_id) = level_boundary
diffvar = get_diffvar((t, storage))
diffvar = get_diffvar(t, storage)
flow = get_tmp_sparse(flow, diffvar)

for id in node_id
Expand All @@ -1121,7 +1115,7 @@ function formulate_flow!(
(; connectivity) = p
(; graph_flow, flow) = connectivity
(; node_id, active, flow_rate) = flow_boundary
diffvar = get_diffvar((t, storage))
diffvar = get_diffvar(t, storage)
flow = get_tmp_sparse(flow, diffvar)

for (i, id) in enumerate(node_id)
Expand Down Expand Up @@ -1149,7 +1143,7 @@ function formulate_flow!(
(; connectivity, basin) = p
(; graph_flow, flow) = connectivity
(; node_id, active, flow_rate, is_pid_controlled) = pump
diffvar = get_diffvar((t, storage))
diffvar = get_diffvar(t, storage)
flow = get_tmp_sparse(flow, diffvar)
flow_rate = get_tmp(flow_rate, diffvar)
for (id, isactive, rate, pid_controlled) in
Expand Down Expand Up @@ -1185,7 +1179,7 @@ function formulate_flow!(
(; connectivity, basin) = p
(; graph_flow, flow) = connectivity
(; node_id, active, flow_rate, is_pid_controlled, min_crest_level) = outlet
diffvar = get_diffvar((t, storage))
diffvar = get_diffvar(t, storage)
flow = get_tmp_sparse(flow, diffvar)
flow_rate = get_tmp(flow_rate, diffvar)
for (i, id) in enumerate(node_id)
Expand Down Expand Up @@ -1290,7 +1284,7 @@ function water_balance!(

du .= 0.0

diffvar = get_diffvar((t, storage))
diffvar = get_diffvar(t, storage)
flow = get_tmp_sparse(connectivity.flow, diffvar)
# use parent to avoid materializing the ReinterpretArray from FixedSizeDiffCache
nonzeros(flow) .= 0.0
Expand Down
2 changes: 1 addition & 1 deletion core/test/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ end
@test M_float[3, 2] == 1.0
@test M_float isa SparseArrays.SparseMatrixCSC
@test nonzeros(M_float) == [1.0]
@test SparseArrays.size(M_cache) == (5, 5)
@test SparseArrays.size(M_dual) == (5, 5)
@test M_dual isa Ribasim.SparseMatrixCSC_cache
@test M_dual[1, 1] == Dual(0.0, 0.0, 0.0)
end

0 comments on commit 87e11e7

Please sign in to comment.