Skip to content

Commit

Permalink
Single source of truth for the number of states
Browse files Browse the repository at this point in the history
  • Loading branch information
SouthEndMusic committed May 1, 2024
1 parent c07b9b0 commit 38b0482
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 85 deletions.
72 changes: 15 additions & 57 deletions core/src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,57 +35,14 @@ function Model(config_path::AbstractString)::Model
return Model(config)
end

function initialize_state(p::Parameters, state::StructVector)::ComponentVector
(; basin, pid_control, graph, allocation, user_demand) = p

storage = get_storages_from_levels(basin, state.level)

# Synchronize level with storage
set_current_basin_properties!(basin, storage)

# Integrals for PID control
integral = zeros(length(pid_control.node_id))

# Flows over edges
n_flows = length(graph[].flow_dict)
flow_integrated = zeros(n_flows)

# Basin forcings
n_basins = length(basin.node_id)
precipitation_integrated = zeros(n_basins)
evaporation_integrated = zeros(n_basins)
drainage_integrated = zeros(n_basins)
infiltration_integrated = zeros(n_basins)

precipitation_bmi = zeros(n_basins)
evaporation_bmi = zeros(n_basins)
drainage_bmi = zeros(n_basins)
infiltration_bmi = zeros(n_basins)

# Flows for allocation
n_allocation_input_flows = length(allocation.flow_dict)
flow_allocation_input = zeros(n_allocation_input_flows)

# Realized user demand
n_user_demands = length(user_demand.node_id)
realized_user_demand_bmi = zeros(n_user_demands)

# NOTE: This is the source of truth for the state component names
return ComponentVector{Float64}(;
storage,
integral,
flow_integrated,
precipitation_integrated,
evaporation_integrated,
drainage_integrated,
infiltration_integrated,
precipitation_bmi,
evaporation_bmi,
drainage_bmi,
infiltration_bmi,
flow_allocation_input,
realized_user_demand_bmi,
function initialize_state(db::DB, config::Config, basin::Basin)::ComponentVector
n_states = get_n_states(db, config)
u0 = ComponentVector{Float64}(
NamedTuple{keys(n_states)}([zeros(n) for n in values(n_states)]),
)
state = load_structvector(db, config, BasinStateV1)
u0.storage = get_storages_from_levels(basin, state.level)
return u0
end

function Model(config::Config)::Model
Expand All @@ -103,7 +60,7 @@ function Model(config::Config)::Model
# All data from the database that we need during runtime is copied into memory,
# so we can directly close it again.
db = SQLite.DB(db_path)
local parameters, state, n, tstops
local parameters, u0, tstops
try
parameters = Parameters(db, config)

Expand Down Expand Up @@ -145,9 +102,9 @@ function Model(config::Config)::Model
push!(tstops, get_tstops(time_schema.time, config.starttime))
end

# use state
state = load_structvector(db, config, BasinStateV1)
n = length(get_ids(db, "Basin"))
# initial state
u0 = initialize_state(db, config, parameters.basin)
@assert length(u0.flow_allocation_input) == length(parameters.allocation.flow_dict) "Unexpected number of flows to integrate for allocation input."

sql = "SELECT node_id FROM Node ORDER BY node_id"
node_id = only(execute(columntable, db, sql))
Expand All @@ -162,8 +119,8 @@ function Model(config::Config)::Model
end
@debug "Read database into memory."

u0 = initialize_state(parameters, state)
@assert length(u0.storage) == n "Basin / state length differs from number of Basins"
# Synchronize level with storage
set_current_basin_properties!(parameters.basin, u0.storage)

# for Float32 this method allows max ~1000 year simulations without accuracy issues
t_end = seconds_since(config.endtime, config.starttime)
Expand All @@ -176,7 +133,8 @@ function Model(config::Config)::Model
tstops = sort(unique(vcat(tstops...)))
adaptive, dt = convert_dt(config.solver.dt)

jac_prototype = config.solver.sparse ? get_jac_prototype(parameters) : nothing
jac_prototype =
config.solver.sparse ? get_jac_prototype(parameters, length(u0)) : nothing
RHS = ODEFunction(water_balance!; jac_prototype)

@timeit_debug to "Setup ODEProblem" begin
Expand Down
2 changes: 1 addition & 1 deletion core/src/read.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1029,7 +1029,7 @@ function get_chunk_sizes(config::Config, n_states::Int)::Vector{Int}
end

function Parameters(db::DB, config::Config)::Parameters
n_states = get_n_states(db)
n_states = sum(get_n_states(db, config))
chunk_sizes = get_chunk_sizes(config, n_states)
graph = create_graph(db, config, chunk_sizes)
allocation = Allocation(db, config, graph)
Expand Down
4 changes: 1 addition & 3 deletions core/src/sparsity.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,8 @@ Note: the name 'prototype' does not mean this code is a prototype, it comes
from the naming convention of this sparsity structure in the
differentialequations.jl docs.
"""
function get_jac_prototype(p::Parameters)::SparseMatrixCSC{Float64, Int64}
function get_jac_prototype(p::Parameters, n_states::Int)::SparseMatrixCSC{Float64, Int64}
(; basin, pid_control, graph) = p

n_states = get_n_states(p)
jac_prototype = spzeros(n_states, n_states)

update_jac_prototype!(jac_prototype, p)
Expand Down
47 changes: 25 additions & 22 deletions core/src/util.jl
Original file line number Diff line number Diff line change
Expand Up @@ -739,11 +739,6 @@ has_fractional_flow_outneighbors(graph::MetaGraph, node_id::NodeID)::Bool = any(
internalnorm(u::ComponentVector, t) = OrdinaryDiffEq.ODE_DEFAULT_NORM(u.storage, t)
internalnorm(u::Number, t) = OrdinaryDiffEq.ODE_DEFAULT_NORM(u, t)

function get_n_node(db::DB, type::String)::Int
result = execute(columntable, db, "SELECT COUNT(*) From Node WHERE node_type = '$type'")
return only(only(result))
end

function get_n_flows(db::DB)::Int
result = execute(columntable, db, "SELECT COUNT(*) FROM Edge WHERE edge_type = 'flow'")
return only(only(result))
Expand All @@ -755,7 +750,7 @@ function get_n_allocation_flow_inputs(db::DB)::Int
execute(
columntable,
db,
"SELECT COUNT(*) From Edge where 'subnetwork_id' != 0",
"SELECT COUNT(*) From Edge where subnetwork_id IS NOT NULL",
),
),
)
Expand All @@ -764,28 +759,36 @@ function get_n_allocation_flow_inputs(db::DB)::Int
execute(
columntable,
db,
"SELECT COUNT(*) FROM Edge WHERE from_node_type = 'level_demand'",
"SELECT COUNT(*) FROM Edge WHERE from_node_type = 'LevelDemand'",
),
),
)
return n_sources + n_level_demands
end

function get_n_states(db::DB)::Int
return 9 * get_n_node(db, "Basin") +
get_n_node(db, "PidControl") +
get_n_flows(db) +
get_n_allocation_flow_inputs(db) +
get_n_node(db, "UserDemand")
end

function get_n_states(p::Parameters)::Int
(; basin, pid_control, graph, allocation, user_demand) = p
return 9 * length(basin.node_id) +
length(pid_control.node_id) +
length(graph[].flow_dict) +
length(allocation.flow_dict) +
length(user_demand.node_id)
function get_n_states(db::DB, config::Config)::NamedTuple
n_basins = length(get_ids(db, "Basin"))
n_pid_controls = length(get_ids(db, "PidControl"))
n_user_demands = length(get_ids(db, "UserDemand"))
n_flows = get_n_flows(db)
n_allocation_flow_inputs =
config.allocation.use_allocation ? get_n_allocation_flow_inputs(db) : 0
# NOTE: This is the source of truth for the state component names
return (;
storage = n_basins,
integral = n_pid_controls,
flow_integrated = n_flows,
precipitation_integrated = n_basins,
evaporation_integrated = n_basins,
drainage_integrated = n_basins,
infiltration_integrated = n_basins,
precipitation_bmi = n_basins,
evaporation_bmi = n_basins,
drainage_bmi = n_basins,
infiltration_bmi = n_basins,
flow_allocation_input = n_allocation_flow_inputs,
realized_user_demand_bmi = n_user_demands,
)
end

function forcings_integrated(u::ComponentVector)
Expand Down
8 changes: 6 additions & 2 deletions core/test/utils_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -182,8 +182,10 @@ end
db = SQLite.DB(db_path)

p = Ribasim.Parameters(db, cfg)
jac_prototype = Ribasim.get_jac_prototype(p)
n_states = sum(Ribasim.get_n_states(db, cfg))
close(db)

jac_prototype = Ribasim.get_jac_prototype(p, n_states)
@test jac_prototype.m == 4
@test jac_prototype.n == 4
@test jac_prototype.colptr == [1, 3, 5, 8, 11]
Expand All @@ -197,8 +199,10 @@ end
db = SQLite.DB(db_path)

p = Ribasim.Parameters(db, cfg)
jac_prototype = Ribasim.get_jac_prototype(p)
n_states = sum(Ribasim.get_n_states(db, config))
close(db)

jac_prototype = Ribasim.get_jac_prototype(p, n_states)
@test jac_prototype.m == 3
@test jac_prototype.n == 3
@test jac_prototype.colptr == [1, 4, 5, 6]
Expand Down

0 comments on commit 38b0482

Please sign in to comment.