From 564f25907d8ccedd4002d99a68ec85dd83f9c571 Mon Sep 17 00:00:00 2001 From: Martijn Visser Date: Tue, 19 Mar 2024 12:49:28 +0100 Subject: [PATCH] Don't save states on callbacks (#1281) Closes #1213, in the sense that I looked into the issue and it does not occur on the last release. I did notice that `basin.arrow` had multiple entries for the same timestamp on a transient run with daily forcing. The `PresetTimeCallback` updating the forcing caused these extra saves. The default `save_positions=(true,true)` is needed when modifying the state, but we don't do that. I checked all callbacks for this behavior. This does mean that a few tests had to be modified because we don't add an extra save if we hit a DiscreteControl condition, but have to wait until the next state save. https://docs.sciml.ai/DiffEqDocs/stable/features/callback_functions/#SciMLBase.DiscreteCallback --- core/src/callback.jl | 12 +++++++++--- core/test/control_test.jl | 20 ++++++++++---------- core/test/run_models_test.jl | 1 + docs/python/examples.ipynb | 21 ++++++++------------- 4 files changed, 28 insertions(+), 26 deletions(-) diff --git a/core/src/callback.jl b/core/src/callback.jl index 0a29e8208..9f3db5c1a 100644 --- a/core/src/callback.jl +++ b/core/src/callback.jl @@ -37,14 +37,18 @@ function create_callbacks( callbacks = SciMLBase.DECallback[] tstops = get_tstops(basin.time.time, starttime) - basin_cb = PresetTimeCallback(tstops, update_basin) + basin_cb = PresetTimeCallback(tstops, update_basin; save_positions = (false, false)) push!(callbacks, basin_cb) integrating_flows_cb = FunctionCallingCallback(integrate_flows!; func_start = false) push!(callbacks, integrating_flows_cb) tstops = get_tstops(tabulated_rating_curve.time.time, starttime) - tabulated_rating_curve_cb = PresetTimeCallback(tstops, update_tabulated_rating_curve!) + tabulated_rating_curve_cb = PresetTimeCallback( + tstops, + update_tabulated_rating_curve!; + save_positions = (false, false), + ) push!(callbacks, tabulated_rating_curve_cb) if config.allocation.use_allocation @@ -52,6 +56,7 @@ function create_callbacks( update_allocation!, config.allocation.timestep; initial_affect = false, + save_positions = (false, false), ) push!(callbacks, allocation_cb) end @@ -88,7 +93,8 @@ function create_callbacks( discrete_control_condition, discrete_control_affect_upcrossing!, discrete_control_affect_downcrossing!, - n_conditions, + n_conditions; + save_positions = (false, false), ) push!(callbacks, discrete_control_cb) end diff --git a/core/test/control_test.jl b/core/test/control_test.jl index ae572e45e..fdb2e80f8 100644 --- a/core/test/control_test.jl +++ b/core/test/control_test.jl @@ -36,12 +36,12 @@ # Control times t_1 = discrete_control.record.time[3] - t_1_index = findfirst(t .≈ t_1) - @test level[1, t_1_index] ≈ discrete_control.greater_than[1] + t_1_index = findfirst(>=(t_1), t) + @test level[1, t_1_index] <= discrete_control.greater_than[1] t_2 = discrete_control.record.time[4] - t_2_index = findfirst(t .≈ t_2) - @test level[2, t_2_index] ≈ discrete_control.greater_than[2] + t_2_index = findfirst(>=(t_2), t) + @test level[2, t_2_index] >= discrete_control.greater_than[2] flow = get_tmp(graph[].flow, 0) @test all(iszero, flow) @@ -170,14 +170,14 @@ end level_min = greater_than[1] setpoint = greater_than[2] - t_1_none_index = findfirst(t .≈ t_none_1) - t_in_index = findfirst(t .≈ t_in) - t_2_none_index = findfirst(t .≈ t_none_2) + t_1_none_index = findfirst(>=(t_none_1), t) + t_in_index = findfirst(>=(t_in), t) + t_2_none_index = findfirst(>=(t_none_2), t) @test record.control_state == ["out", "none", "in", "none"] - @test level[t_1_none_index] ≈ setpoint - @test level[t_in_index] ≈ level_min - @test level[t_2_none_index] ≈ setpoint + @test level[t_1_none_index] <= setpoint + @test level[t_in_index] >= level_min + @test level[t_2_none_index] <= setpoint end @testitem "Set PID target with DiscreteControl" begin diff --git a/core/test/run_models_test.jl b/core/test/run_models_test.jl index faeb6fff6..8f285c810 100644 --- a/core/test/run_models_test.jl +++ b/core/test/run_models_test.jl @@ -234,6 +234,7 @@ end model = Ribasim.run(toml_path) @test model isa Ribasim.Model @test successful_retcode(model) + @test allunique(Ribasim.tsaves(model)) @test length(model.integrator.p.basin.precipitation) == 4 @test model.integrator.sol.u[end] ≈ Float32[472.02444, 472.02252, 367.6387, 1427.981] skip = Sys.isapple() diff --git a/docs/python/examples.ipynb b/docs/python/examples.ipynb index 1d45a3175..d2583c8db 100644 --- a/docs/python/examples.ipynb +++ b/docs/python/examples.ipynb @@ -712,8 +712,12 @@ ")\n", "\n", "y_min, y_max = ax.get_ybound()\n", - "ax.fill_between(df_control.time[:2], 2 * [y_min], 2 * [y_max], alpha=0.2, color=\"C0\")\n", - "ax.fill_between(df_control.time[2:4], 2 * [y_min], 2 * [y_max], alpha=0.2, color=\"C0\")\n", + "ax.fill_between(\n", + " df_control.time[:2].to_numpy(), 2 * [y_min], 2 * [y_max], alpha=0.2, color=\"C0\"\n", + ")\n", + "ax.fill_between(\n", + " df_control.time[2:4].to_numpy(), 2 * [y_min], 2 * [y_max], alpha=0.2, color=\"C0\"\n", + ")\n", "\n", "ax.set_xticks(\n", " date2num(df_control.time).tolist(),\n", @@ -1634,23 +1638,14 @@ "df_basin_wide = df_basin.pivot_table(\n", " index=\"time\", columns=\"node_id\", values=[\"storage\", \"level\"]\n", ")\n", - "ax = df_basin_wide[\"level\"].plot()\n", - "where_allocation = (\n", - " df_basin_wide.index - df_basin_wide.index[0]\n", - ").total_seconds() % model.allocation.timestep == 0\n", - "where_allocation[0] = False\n", - "df_basin_wide[where_allocation][\"level\"].plot(\n", - " style=\"o\",\n", - " ax=ax,\n", - ")\n", - "ax.set_ylabel(\"level [m]\")" + "ax = df_basin_wide[\"level\"].plot(ylabel=\"level [m]\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "In the plot above, the line denotes the level of Basin #2 over time and the dots denote the times at which allocation optimization was run, with intervals of $\\Delta t_{\\text{alloc}}$.\n", + "In the plot above, the line denotes the level of Basin #2 over time.\n", "The Basin level is a piecewise linear function of time, with several stages explained below.\n", "\n", "Constants:\n",