Skip to content

Commit

Permalink
Fix sparsity tests
Browse files Browse the repository at this point in the history
  • Loading branch information
SouthEndMusic committed May 6, 2024
1 parent df56308 commit f9632d6
Show file tree
Hide file tree
Showing 2 changed files with 188 additions and 38 deletions.
78 changes: 70 additions & 8 deletions core/src/sparsity.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,30 @@ function get_jac_prototype(
p::Parameters,
u::ComponentVector,
)::SparseMatrixCSC{Float64, Int64}
(; basin, pid_control, graph) = p
(; basin, pid_control, graph, user_demand, allocation) = p
n_states = length(u)
axis = only(getfield(u, :axes))
jac_prototype = ComponentMatrix(spzeros(n_states, n_states), (axis, axis))

update_jac_prototype!(jac_prototype, p)
# Storages depending on storages
update_jac_prototype!(jac_prototype, basin, graph)

# PID control states depending on storages (and the other way around)
update_jac_prototype!(jac_prototype, pid_control, basin, graph)

# Flows depending on storages
update_jac_prototype!(jac_prototype, p)

# Evaporation and infiltration depending on storages
update_jac_prototype!(jac_prototype, basin)

# Allocation input flows depending on storages
# Note, this is copied from the update_jac_prototype!(jac_prototype, p)
# result so the other is important
update_jac_prototype!(jac_prototype, allocation, graph)

# UserDemand inflows depending on storages
update_jac_prototype!(jac_prototype, user_demand, graph, basin)
return jac_prototype.data
end

Expand Down Expand Up @@ -123,17 +139,63 @@ function update_jac_prototype!(jac_prototype::ComponentMatrix, p::Parameters)::N
end

"""
Allocation flow inputs depending on storages
(get from above)
Add nonzeros for evaporation and infiltration depending on storages
"""
function update_jac_prototype!()::Nothing
function update_jac_prototype!(jac_prototype::ComponentMatrix, basin::Basin)::Nothing
jac_prototype_evaporation = @view jac_prototype[:storage, :evaporation_integrated]
jac_prototype_infiltration = @view jac_prototype[:storage, :infiltration_integrated]
jac_prototype_evaporation_bmi = @view jac_prototype[:storage, :evaporation_integrated]
jac_prototype_infiltration_bmi = @view jac_prototype[:storage, :infiltration_integrated]
for (i, id) in enumerate(basin.node_id)
jac_prototype_evaporation[i, i] = 1.0
jac_prototype_infiltration[i, i] = 1.0
jac_prototype_evaporation_bmi[i, i] = 1.0
jac_prototype_infiltration_bmi[i, i] = 1.0
end
return nothing
end

"""
Add nonzeros for allocation input flows depending on storages.
"""
function update_jac_prototype!(
jac_prototype::ComponentMatrix,
allocation::Allocation,
graph::MetaGraph,
)::Nothing
(; input_flow_dict) = allocation
(; flow_dict) = graph[]
jac_prototype_storage_flow = @view jac_prototype[:storage, :flow_integrated]
jac_prototype_storage_flow_allocation =
@view jac_prototype[:storage, :flow_allocation_input]
# Copy which storages a flow depends on
for (edge, i) in input_flow_dict
# A self-loop indicates basin forcing
if edge[1] == edge[2]
continue
end
jac_prototype_storage_flow_allocation[:, i] =
jac_prototype_storage_flow[:, flow_dict[edge]]
end
return nothing
end

"""
Realized user demands depending on storages
(get from above)
Add nonzeros for UserDemand intake flows depending on storages.
"""
function update_jac_prototype!()::Nothing
function update_jac_prototype!(
jac_prototype::ComponentMatrix,
user_demand::UserDemand,
graph::MetaGraph,
basin::Basin,
)::Nothing
jac_prototype_storage_realized_demand =
@view jac_prototype[:storage, :realized_user_demand_bmi]
for (user_demand_idx, node_id) in enumerate(user_demand.node_id)
basin_node_id = inflow_id(graph, node_id)
has_index, i = id_index(basin.node_id, basin_node_id)
@assert has_index "UserDemand inflow node is not a basin."
jac_prototype_storage_realized_demand[i, user_demand_idx] = 1.0
end
return nothing
end
148 changes: 118 additions & 30 deletions core/test/utils_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -176,38 +176,126 @@ end
import SQLite

toml_path = normpath(@__DIR__, "../../generated_testmodels/basic/ribasim.toml")

cfg = Ribasim.Config(toml_path)
db_path = Ribasim.input_path(cfg, cfg.database)
db = SQLite.DB(db_path)

p = Ribasim.Parameters(db, cfg)
n_states = sum(Ribasim.get_n_states(db, cfg))
close(db)

jac_prototype = Ribasim.get_jac_prototype(p, n_states)
@test jac_prototype.m == 4
@test jac_prototype.n == 4
@test jac_prototype.colptr == [1, 3, 5, 8, 11]
@test jac_prototype.rowval == [1, 2, 1, 2, 2, 3, 4, 2, 3, 4]
@test jac_prototype.nzval == ones(10)
model = Ribasim.Model(toml_path)
(; p, u) = model.integrator

jac_prototype = Ribasim.get_jac_prototype(p, u)
@test jac_prototype.m == 53
@test jac_prototype.n == 53
@test jac_prototype.colptr == [
1,
3,
5,
8,
11,
13,
15,
16,
17,
18,
19,
21,
22,
24,
25,
26,
27,
28,
29,
30,
31,
32,
32,
32,
32,
32,
33,
34,
35,
36,
36,
36,
36,
36,
37,
38,
39,
40,
40,
40,
40,
40,
40,
40,
40,
40,
40,
40,
40,
40,
40,
40,
40,
40,
]
@test jac_prototype.rowval == [
1,
2,
1,
2,
2,
3,
4,
2,
3,
4,
1,
2,
1,
2,
2,
2,
2,
2,
3,
4,
2,
3,
4,
4,
2,
2,
2,
2,
3,
1,
4,
1,
2,
3,
4,
1,
2,
3,
4,
]
@test jac_prototype.nzval == ones(39)
# States do not depend on non-storage states
@test sum(jac_prototype[(length(p.basin.node_id) + 1):end, :]) == 0

toml_path = normpath(@__DIR__, "../../generated_testmodels/pid_control/ribasim.toml")

cfg = Ribasim.Config(toml_path)
db_path = Ribasim.input_path(cfg, cfg.database)
db = SQLite.DB(db_path)

p = Ribasim.Parameters(db, cfg)
n_states = sum(Ribasim.get_n_states(db, cfg))
close(db)

jac_prototype = Ribasim.get_jac_prototype(p, n_states)
@test jac_prototype.m == 3
@test jac_prototype.n == 3
@test jac_prototype.colptr == [1, 4, 5, 6]
@test jac_prototype.rowval == [1, 2, 3, 1, 1]
@test jac_prototype.nzval == ones(5)
model = Ribasim.Model(toml_path)
(; p, u) = model.integrator

jac_prototype = Ribasim.get_jac_prototype(p, u)
@test jac_prototype.m == 16
@test jac_prototype.n == 16
@test jac_prototype.colptr ==
[1, 4, 5, 6, 7, 8, 9, 10, 11, 11, 12, 12, 13, 13, 13, 13, 13]
@test jac_prototype.rowval == [1, 2, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1]
@test jac_prototype.nzval == ones(12)
# Some states depend on non-storage states (PID integral term)
@test sum(jac_prototype[(length(p.basin.node_id) + 1):end, :]) == 2
end

@testitem "FlatVector" begin
Expand Down

0 comments on commit f9632d6

Please sign in to comment.