Skip to content

Commit

Permalink
Merge branch 'main' into allocation_beyond_max_flow
Browse files Browse the repository at this point in the history
  • Loading branch information
SouthEndMusic committed Oct 13, 2023
2 parents 4884eb1 + bab694a commit 76c8db6
Show file tree
Hide file tree
Showing 23 changed files with 286 additions and 119 deletions.
20 changes: 11 additions & 9 deletions .github/workflows/python_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ concurrency:
cancel-in-progress: true
jobs:
test:
name: Python ${{ matrix.python_version }} - ${{ matrix.os }} - ${{ matrix.arch }}
name: Python ${{ matrix.python-version }} - ${{ matrix.os }} - ${{ matrix.arch }}
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
Expand All @@ -20,24 +20,26 @@ jobs:
- ubuntu-latest
- macOS-latest
- windows-latest
python_version:
- "3.9"
python-version:
- "3.10"
- "3.11"
arch:
- x86
steps:
- uses: actions/checkout@v4

- uses: prefix-dev/[email protected]
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
pixi-version: "latest"
cache: true
- name: Prepare pixi
run: pixi run install-without-pre-commit
python-version: "${{ matrix.python-version }}"

- name: Install test dependencies
run: |
pip install --editable "python/ribasim_testmodels"
pip install --editable "python/ribasim[tests]"
- name: Run tests
run: pixi run test-ribasim-python-cov
run: pytest --numprocesses=auto --cov=ribasim --cov-report=xml python/ribasim/tests

- name: Upload coverage to Codecov
uses: codecov/codecov-action@v3
4 changes: 3 additions & 1 deletion core/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ CodecZstd = "0.7,0.8"
ComponentArrays = "0.13.14, 0.14, 0.15"
Configurations = "0.17"
DBInterface = "2.4"
DataFrames = "1.4"
DataInterpolations = "3.7, 4"
DataStructures = "0.18"
Dictionaries = "0.3.25"
Expand All @@ -68,6 +69,7 @@ julia = "1.9"
[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
IOCapture = "b5f81e59-6552-4d32-b1f0-c071b021bf89"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
Expand All @@ -78,4 +80,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TestReports = "dcd651b4-b50a-5b6b-8f22-87e9f253a252"

[targets]
test = ["Aqua", "CSV", "Documenter", "IOCapture", "Logging", "SafeTestsets", "TerminalLoggers", "Test", "TestReports", "TOML"]
test = ["Aqua", "CSV", "DataFrames", "Documenter", "IOCapture", "Logging", "SafeTestsets", "TerminalLoggers", "Test", "TestReports", "TOML"]
43 changes: 26 additions & 17 deletions core/src/bmi.jl
Original file line number Diff line number Diff line change
Expand Up @@ -159,10 +159,25 @@ end
Write all output to the configured output files.
"""
function BMI.finalize(model::Model)::Model
compress = get_compressor(model.config.output)
write_basin_output(model, compress)
write_flow_output(model, compress)
write_discrete_control_output(model, compress)
(; config) = model
(; output) = model.config
compress = get_compressor(output)

# basin
table = basin_table(model)
path = output_path(config, output.basin)
write_arrow(path, table, compress)

# flow
table = flow_table(model)
path = output_path(config, output.flow)
write_arrow(path, table, compress)

# discrete control
table = discrete_control_table(model)
path = output_path(config, output.control)
write_arrow(path, table, compress)

@debug "Wrote output."
return model
end
Expand Down Expand Up @@ -198,20 +213,20 @@ function create_callbacks(
saveat,
)::Tuple{CallbackSet, SavedValues{Float64, Vector{Float64}}}
(; starttime, basin, tabulated_rating_curve, discrete_control) = parameters
callbacks = SciMLBase.DECallback[]

tstops = get_tstops(basin.time.time, starttime)
basin_cb = PresetTimeCallback(tstops, update_basin)
push!(callbacks, basin_cb)

tstops = get_tstops(tabulated_rating_curve.time.time, starttime)
tabulated_rating_curve_cb = PresetTimeCallback(tstops, update_tabulated_rating_curve!)
push!(callbacks, tabulated_rating_curve_cb)

# add a single time step's contribution to the water balance step's totals
# trackwb_cb = FunctionCallingCallback(track_waterbalance!)
# flows: save the flows over time, as a Vector of the nonzeros(flow)

# save the flows over time, as a Vector of the nonzeros(flow)
saved_flow = SavedValues(Float64, Vector{Float64})

save_flow_cb = SavingCallback(save_flow, saved_flow; saveat, save_start = false)
push!(callbacks, save_flow_cb)

n_conditions = length(discrete_control.node_id)
if n_conditions > 0
Expand All @@ -221,15 +236,9 @@ function create_callbacks(
discrete_control_affect_downcrossing!,
n_conditions,
)
callback = CallbackSet(
save_flow_cb,
basin_cb,
tabulated_rating_curve_cb,
discrete_control_cb,
)
else
callback = CallbackSet(save_flow_cb, basin_cb, tabulated_rating_curve_cb)
push!(callbacks, discrete_control_cb)
end
callback = CallbackSet(callbacks...)

return callback, saved_flow
end
Expand Down
43 changes: 29 additions & 14 deletions core/src/create.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ function parse_static_and_time(
vals_out = []

node_ids = get_ids(db, nodetype)
node_names = get_names(db, nodetype)
n_nodes = length(node_ids)

# Initialize the vectors for the output
Expand Down Expand Up @@ -91,7 +92,7 @@ function parse_static_and_time(
t_end = seconds_since(config.endtime, config.starttime)
trivial_timespan = [nextfloat(-Inf), prevfloat(Inf)]

for (node_idx, node_id) in enumerate(node_ids)
for (node_idx, (node_id, node_name)) in enumerate(zip(node_ids, node_names))
if node_id in static_node_ids
# The interval of rows of the static table that have the current node_id
rows = searchsorted(static.node_id, node_id)
Expand Down Expand Up @@ -153,7 +154,7 @@ function parse_static_and_time(
)
if !is_valid
errors = true
@error "A $parameter_name time series for $nodetype node #$node_id has repeated times, this can not be interpolated."
@error "A $parameter_name time series for $nodetype node $(repr(node_name)) (#$node_id) has repeated times, this can not be interpolated."
end
else
# Activity of transient nodes is assumed to be true
Expand All @@ -167,7 +168,7 @@ function parse_static_and_time(
getfield(out, parameter_name)[node_idx] = val
end
else
@error "$nodetype node #$node_id data not in any table."
@error "$nodetype node $(repr(node_name)) (#$node_id) data not in any table."
errors = true
end
end
Expand All @@ -179,10 +180,11 @@ function static_and_time_node_ids(
static::StructVector,
time::StructVector,
node_type::String,
)::Tuple{Set{Int}, Set{Int}, Vector{Int}, Bool}
)::Tuple{Set{Int}, Set{Int}, Vector{Int}, Vector{String}, Bool}
static_node_ids = Set(static.node_id)
time_node_ids = Set(time.node_id)
node_ids = get_ids(db, node_type)
node_names = get_names(db, node_type)
doubles = intersect(static_node_ids, time_node_ids)
errors = false
if !isempty(doubles)
Expand All @@ -193,9 +195,12 @@ function static_and_time_node_ids(
errors = true
@error "$node_type node IDs don't match."
end
return static_node_ids, time_node_ids, node_ids, !errors
return static_node_ids, time_node_ids, node_ids, node_names, !errors
end

const nonconservative_nodetypes =
Set{String}(["Basin", "LevelBoundary", "FlowBoundary", "Terminal", "User"])

function Connectivity(db::DB, config::Config, chunk_size::Int)::Connectivity
if !valid_edge_types(db)
error("Invalid edge types found.")
Expand All @@ -208,6 +213,16 @@ function Connectivity(db::DB, config::Config, chunk_size::Int)::Connectivity
edge_ids_flow_inv = Dictionary(values(edge_ids_flow), keys(edge_ids_flow))

flow = adjacency_matrix(graph_flow, Float64)
# Add a self-loop, i.e. an entry on the diagonal, for all non-conservative node types.
# This is used to store the gain (positive) or loss (negative) for the water balance.
# Note that this only affects the sparsity structure.
# We want to do it here to avoid changing that during the simulation and keeping it predictable,
# e.g. if we wouldn't do this, inactive nodes can appear if control turns them on during runtime.
for (i, nodetype) in enumerate(get_nodetypes(db))
if nodetype in nonconservative_nodetypes
flow[i, i] = 1.0
end
end
flow .= 0.0

if config.solver.autodiff
Expand Down Expand Up @@ -253,7 +268,7 @@ function TabulatedRatingCurve(db::DB, config::Config)::TabulatedRatingCurve
static = load_structvector(db, config, TabulatedRatingCurveStaticV1)
time = load_structvector(db, config, TabulatedRatingCurveTimeV1)

static_node_ids, time_node_ids, node_ids, valid =
static_node_ids, time_node_ids, node_ids, node_names, valid =
static_and_time_node_ids(db, static, time, "TabulatedRatingCurve")

if !valid
Expand All @@ -267,7 +282,7 @@ function TabulatedRatingCurve(db::DB, config::Config)::TabulatedRatingCurve
active = BitVector()
errors = false

for node_id in node_ids
for (node_id, node_name) in zip(node_ids, node_names)
if node_id in static_node_ids
# Loop over all static rating curves (groups) with this node_id.
# If it has a control_state add it to control_mapping.
Expand Down Expand Up @@ -298,11 +313,11 @@ function TabulatedRatingCurve(db::DB, config::Config)::TabulatedRatingCurve
push!(interpolations, interpolation)
push!(active, true)
else
@error "TabulatedRatingCurve node #$node_id data not in any table."
@error "TabulatedRatingCurve node $(repr(node_name)) (#$node_id) data not in any table."
errors = true
end
if !is_valid
@error "A Q(h) relationship for TabulatedRatingCurve #$node_id from the $source table has repeated levels, this can not be interpolated."
@error "A Q(h) relationship for TabulatedRatingCurve $(repr(node_name)) (#$node_id) from the $source table has repeated levels, this can not be interpolated."
errors = true
end
end
Expand Down Expand Up @@ -353,7 +368,7 @@ function LevelBoundary(db::DB, config::Config)::LevelBoundary
static = load_structvector(db, config, LevelBoundaryStaticV1)
time = load_structvector(db, config, LevelBoundaryTimeV1)

static_node_ids, time_node_ids, node_ids, valid =
static_node_ids, time_node_ids, node_ids, node_names, valid =
static_and_time_node_ids(db, static, time, "LevelBoundary")

if !valid
Expand Down Expand Up @@ -381,7 +396,7 @@ function FlowBoundary(db::DB, config::Config)::FlowBoundary
static = load_structvector(db, config, FlowBoundaryStaticV1)
time = load_structvector(db, config, FlowBoundaryTimeV1)

static_node_ids, time_node_ids, node_ids, valid =
static_node_ids, time_node_ids, node_ids, node_names, valid =
static_and_time_node_ids(db, static, time, "FlowBoundary")

if !valid
Expand All @@ -401,7 +416,7 @@ function FlowBoundary(db::DB, config::Config)::FlowBoundary
for itp in parsed_parameters.flow_rate
if any(itp.u .< 0.0)
@error(
"Currently negative flow rates are not supported, found some for dynamic flow boundary #$node_id."
"Currently negative flow rates are not supported, found some in dynamic flow boundary."
)
valid = false
end
Expand Down Expand Up @@ -568,7 +583,7 @@ function PidControl(db::DB, config::Config, chunk_size::Int)::PidControl
static = load_structvector(db, config, PidControlStaticV1)
time = load_structvector(db, config, PidControlTimeV1)

static_node_ids, time_node_ids, node_ids, valid =
static_node_ids, time_node_ids, node_ids, node_names, valid =
static_and_time_node_ids(db, static, time, "PidControl")

if !valid
Expand Down Expand Up @@ -630,7 +645,7 @@ function User(db::DB, config::Config)::User
static = load_structvector(db, config, UserStaticV1)
time = load_structvector(db, config, UserTimeV1)

static_node_ids, time_node_ids, node_ids, valid =
static_node_ids, time_node_ids, node_ids, node_names, valid =
static_and_time_node_ids(db, static, time, "User")

if !valid
Expand Down
Loading

0 comments on commit 76c8db6

Please sign in to comment.