diff --git a/core/src/model.jl b/core/src/model.jl index 4d63fa5c3..41d881233 100644 --- a/core/src/model.jl +++ b/core/src/model.jl @@ -35,57 +35,14 @@ function Model(config_path::AbstractString)::Model return Model(config) end -function initialize_state(p::Parameters, state::StructVector)::ComponentVector - (; basin, pid_control, graph, allocation, user_demand) = p - - storage = get_storages_from_levels(basin, state.level) - - # Synchronize level with storage - set_current_basin_properties!(basin, storage) - - # Integrals for PID control - integral = zeros(length(pid_control.node_id)) - - # Flows over edges - n_flows = length(graph[].flow_dict) - flow_integrated = zeros(n_flows) - - # Basin forcings - n_basins = length(basin.node_id) - precipitation_integrated = zeros(n_basins) - evaporation_integrated = zeros(n_basins) - drainage_integrated = zeros(n_basins) - infiltration_integrated = zeros(n_basins) - - precipitation_bmi = zeros(n_basins) - evaporation_bmi = zeros(n_basins) - drainage_bmi = zeros(n_basins) - infiltration_bmi = zeros(n_basins) - - # Flows for allocation - n_allocation_input_flows = length(allocation.flow_dict) - flow_allocation_input = zeros(n_allocation_input_flows) - - # Realized user demand - n_user_demands = length(user_demand.node_id) - realized_user_demand_bmi = zeros(n_user_demands) - - # NOTE: This is the source of truth for the state component names - return ComponentVector{Float64}(; - storage, - integral, - flow_integrated, - precipitation_integrated, - evaporation_integrated, - drainage_integrated, - infiltration_integrated, - precipitation_bmi, - evaporation_bmi, - drainage_bmi, - infiltration_bmi, - flow_allocation_input, - realized_user_demand_bmi, +function initialize_state(db::DB, config::Config, basin::Basin)::ComponentVector + n_states = get_n_states(db, config) + u0 = ComponentVector{Float64}( + NamedTuple{keys(n_states)}([zeros(n) for n in values(n_states)]), ) + state = load_structvector(db, config, BasinStateV1) + u0.storage = get_storages_from_levels(basin, state.level) + return u0 end function Model(config::Config)::Model @@ -103,7 +60,7 @@ function Model(config::Config)::Model # All data from the database that we need during runtime is copied into memory, # so we can directly close it again. db = SQLite.DB(db_path) - local parameters, state, n, tstops + local parameters, u0, tstops try parameters = Parameters(db, config) @@ -145,9 +102,9 @@ function Model(config::Config)::Model push!(tstops, get_tstops(time_schema.time, config.starttime)) end - # use state - state = load_structvector(db, config, BasinStateV1) - n = length(get_ids(db, "Basin")) + # initial state + u0 = initialize_state(db, config, parameters.basin) + @assert length(u0.flow_allocation_input) == length(parameters.allocation.flow_dict) "Unexpected number of flows to integrate for allocation input." sql = "SELECT node_id FROM Node ORDER BY node_id" node_id = only(execute(columntable, db, sql)) @@ -162,8 +119,8 @@ function Model(config::Config)::Model end @debug "Read database into memory." - u0 = initialize_state(parameters, state) - @assert length(u0.storage) == n "Basin / state length differs from number of Basins" + # Synchronize level with storage + set_current_basin_properties!(parameters.basin, u0.storage) # for Float32 this method allows max ~1000 year simulations without accuracy issues t_end = seconds_since(config.endtime, config.starttime) @@ -176,7 +133,8 @@ function Model(config::Config)::Model tstops = sort(unique(vcat(tstops...))) adaptive, dt = convert_dt(config.solver.dt) - jac_prototype = config.solver.sparse ? get_jac_prototype(parameters) : nothing + jac_prototype = + config.solver.sparse ? get_jac_prototype(parameters, length(u0)) : nothing RHS = ODEFunction(water_balance!; jac_prototype) @timeit_debug to "Setup ODEProblem" begin diff --git a/core/src/read.jl b/core/src/read.jl index a982b4539..b751f99e2 100644 --- a/core/src/read.jl +++ b/core/src/read.jl @@ -1029,7 +1029,7 @@ function get_chunk_sizes(config::Config, n_states::Int)::Vector{Int} end function Parameters(db::DB, config::Config)::Parameters - n_states = get_n_states(db) + n_states = sum(get_n_states(db, config)) chunk_sizes = get_chunk_sizes(config, n_states) graph = create_graph(db, config, chunk_sizes) allocation = Allocation(db, config, graph) diff --git a/core/src/sparsity.jl b/core/src/sparsity.jl index 1c92fd387..f59cbf7ba 100644 --- a/core/src/sparsity.jl +++ b/core/src/sparsity.jl @@ -10,10 +10,8 @@ Note: the name 'prototype' does not mean this code is a prototype, it comes from the naming convention of this sparsity structure in the differentialequations.jl docs. """ -function get_jac_prototype(p::Parameters)::SparseMatrixCSC{Float64, Int64} +function get_jac_prototype(p::Parameters, n_states::Int)::SparseMatrixCSC{Float64, Int64} (; basin, pid_control, graph) = p - - n_states = get_n_states(p) jac_prototype = spzeros(n_states, n_states) update_jac_prototype!(jac_prototype, p) diff --git a/core/src/util.jl b/core/src/util.jl index 12dfc8496..89991834c 100644 --- a/core/src/util.jl +++ b/core/src/util.jl @@ -739,11 +739,6 @@ has_fractional_flow_outneighbors(graph::MetaGraph, node_id::NodeID)::Bool = any( internalnorm(u::ComponentVector, t) = OrdinaryDiffEq.ODE_DEFAULT_NORM(u.storage, t) internalnorm(u::Number, t) = OrdinaryDiffEq.ODE_DEFAULT_NORM(u, t) -function get_n_node(db::DB, type::String)::Int - result = execute(columntable, db, "SELECT COUNT(*) From Node WHERE node_type = '$type'") - return only(only(result)) -end - function get_n_flows(db::DB)::Int result = execute(columntable, db, "SELECT COUNT(*) FROM Edge WHERE edge_type = 'flow'") return only(only(result)) @@ -755,7 +750,7 @@ function get_n_allocation_flow_inputs(db::DB)::Int execute( columntable, db, - "SELECT COUNT(*) From Edge where 'subnetwork_id' != 0", + "SELECT COUNT(*) From Edge where subnetwork_id IS NOT NULL", ), ), ) @@ -764,28 +759,36 @@ function get_n_allocation_flow_inputs(db::DB)::Int execute( columntable, db, - "SELECT COUNT(*) FROM Edge WHERE from_node_type = 'level_demand'", + "SELECT COUNT(*) FROM Edge WHERE from_node_type = 'LevelDemand'", ), ), ) return n_sources + n_level_demands end -function get_n_states(db::DB)::Int - return 9 * get_n_node(db, "Basin") + - get_n_node(db, "PidControl") + - get_n_flows(db) + - get_n_allocation_flow_inputs(db) + - get_n_node(db, "UserDemand") -end - -function get_n_states(p::Parameters)::Int - (; basin, pid_control, graph, allocation, user_demand) = p - return 9 * length(basin.node_id) + - length(pid_control.node_id) + - length(graph[].flow_dict) + - length(allocation.flow_dict) + - length(user_demand.node_id) +function get_n_states(db::DB, config::Config)::NamedTuple + n_basins = length(get_ids(db, "Basin")) + n_pid_controls = length(get_ids(db, "PidControl")) + n_user_demands = length(get_ids(db, "UserDemand")) + n_flows = get_n_flows(db) + n_allocation_flow_inputs = + config.allocation.use_allocation ? get_n_allocation_flow_inputs(db) : 0 + # NOTE: This is the source of truth for the state component names + return (; + storage = n_basins, + integral = n_pid_controls, + flow_integrated = n_flows, + precipitation_integrated = n_basins, + evaporation_integrated = n_basins, + drainage_integrated = n_basins, + infiltration_integrated = n_basins, + precipitation_bmi = n_basins, + evaporation_bmi = n_basins, + drainage_bmi = n_basins, + infiltration_bmi = n_basins, + flow_allocation_input = n_allocation_flow_inputs, + realized_user_demand_bmi = n_user_demands, + ) end function forcings_integrated(u::ComponentVector) diff --git a/core/test/utils_test.jl b/core/test/utils_test.jl index 43c0b0a83..cf43f6c89 100644 --- a/core/test/utils_test.jl +++ b/core/test/utils_test.jl @@ -182,8 +182,10 @@ end db = SQLite.DB(db_path) p = Ribasim.Parameters(db, cfg) - jac_prototype = Ribasim.get_jac_prototype(p) + n_states = sum(Ribasim.get_n_states(db, cfg)) + close(db) + jac_prototype = Ribasim.get_jac_prototype(p, n_states) @test jac_prototype.m == 4 @test jac_prototype.n == 4 @test jac_prototype.colptr == [1, 3, 5, 8, 11] @@ -197,8 +199,10 @@ end db = SQLite.DB(db_path) p = Ribasim.Parameters(db, cfg) - jac_prototype = Ribasim.get_jac_prototype(p) + n_states = sum(Ribasim.get_n_states(db, config)) + close(db) + jac_prototype = Ribasim.get_jac_prototype(p, n_states) @test jac_prototype.m == 3 @test jac_prototype.n == 3 @test jac_prototype.colptr == [1, 4, 5, 6]