From 0e8e0375fc2afb22b72cedf4ee7619b21dad220f Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 5 Apr 2024 15:56:16 +0530 Subject: [PATCH] feat: support updating individual problems with BatchedInterface --- src/batched_interface.jl | 87 ++++++++++++++++++++++++++-------- test/batched_interface_test.jl | 7 +++ 2 files changed, 73 insertions(+), 21 deletions(-) diff --git a/src/batched_interface.jl b/src/batched_interface.jl index d6e8dbb8..3be5c718 100644 --- a/src/batched_interface.jl +++ b/src/batched_interface.jl @@ -189,28 +189,73 @@ function setu(bi::BatchedInterface) numprobs = length(bi.system_to_symbol_subset) probnames = [Symbol(:prob, i) for i in 1:numprobs] - fnbody = quote end - for (sys_idx, subset) in enumerate(bi.system_to_symbol_subset) - probname = probnames[sys_idx] - for (idx_in_subset, idx_in_union) in enumerate(subset) - idx = bi.system_to_symbol_indexes[sys_idx][idx_in_subset] - isstate = bi.system_to_isstate[sys_idx][idx_in_subset] - setter = isstate ? set_state! : set_parameter! - push!(fnbody.args, :($setter($probname, vals[$idx_in_union], $idx))) + full_update_fnexpr = let fnbody = quote end + for (sys_idx, subset) in enumerate(bi.system_to_symbol_subset) + probname = probnames[sys_idx] + for (idx_in_subset, idx_in_union) in enumerate(subset) + idx = bi.system_to_symbol_indexes[sys_idx][idx_in_subset] + isstate = bi.system_to_isstate[sys_idx][idx_in_subset] + setter = isstate ? set_state! : set_parameter! + push!(fnbody.args, :($setter($probname, vals[$idx_in_union], $idx))) + end + # also run hook + if !all(bi.system_to_isstate[sys_idx]) + paramidxs = [bi.system_to_symbol_indexes[sys_idx][idx_in_subset] + for idx_in_subset in 1:length(subset) + if !bi.system_to_isstate[sys_idx][idx_in_subset]] + push!(fnbody.args, :($finalize_parameters_hook!($probname, $paramidxs))) + end end - # also run hook - if !all(bi.system_to_isstate[sys_idx]) - paramidxs = [bi.system_to_symbol_indexes[sys_idx][idx_in_subset] - for idx_in_subset in 1:length(subset) - if !bi.system_to_isstate[sys_idx][idx_in_subset]] - push!(fnbody.args, :($finalize_parameters_hook!($probname, $paramidxs))) + push!(fnbody.args, :(return vals)) + Expr( + :function, + Expr(:tuple, probnames..., :vals), + fnbody + ) + end + + partial_update_fnexpr = let fnbody = quote end + curfnbody = fnbody + for (sys_idx, subset) in enumerate(bi.system_to_symbol_subset) + newcurfnbody = if sys_idx == 1 + Expr(:if, :(idx == $sys_idx)) + else + Expr(:elseif, :(idx == $sys_idx)) + end + push!(curfnbody.args, newcurfnbody) + curfnbody = newcurfnbody + + ifbody = quote end + push!(curfnbody.args, ifbody) + + probname = :prob + for (idx_in_subset, idx_in_union) in enumerate(subset) + idx = bi.system_to_symbol_indexes[sys_idx][idx_in_subset] + isstate = bi.system_to_isstate[sys_idx][idx_in_subset] + setter = isstate ? set_state! : set_parameter! + push!(ifbody.args, :($setter($probname, vals[$idx_in_union], $idx))) + end + # also run hook + if !all(bi.system_to_isstate[sys_idx]) + paramidxs = [bi.system_to_symbol_indexes[sys_idx][idx_in_subset] + for idx_in_subset in 1:length(subset) + if !bi.system_to_isstate[sys_idx][idx_in_subset]] + push!(ifbody.args, :($finalize_parameters_hook!($probname, $paramidxs))) + end end + push!(curfnbody.args, :(error("Invalid problem index $idx"))) + push!(fnbody.args, :(return nothing)) + Expr( + :function, + Expr(:tuple, :prob, :idx, :vals), + fnbody + ) + end + return let full_update = @RuntimeGeneratedFunction(full_update_fnexpr), + partial_update = @RuntimeGeneratedFunction(partial_update_fnexpr) + + setter!(args...) = full_update(args...) + setter!(prob, idx::Int, vals::AbstractVector) = partial_update(prob, idx, vals) + setter! end - push!(fnbody.args, :(return vals)) - fnexpr = Expr( - :function, - Expr(:tuple, probnames..., :vals), - fnbody - ) - return @RuntimeGeneratedFunction(fnexpr) end diff --git a/test/batched_interface_test.jl b/test/batched_interface_test.jl index 9cf5cd74..3e622cdb 100644 --- a/test/batched_interface_test.jl +++ b/test/batched_interface_test.jl @@ -47,3 +47,10 @@ setter!(probs..., buf) @test state_values(probs[3]) == [500.0, 100.0, 9.0] # Similarly for :f @test parameter_values(probs[3]) == [70.0, 80.0, 0.9] + +buf ./= 100 +setter!(probs[1], 1, buf) +@test state_values(probs[1]) == [1.0, 2.0, 3.0] +@test parameter_values(probs[1]) == [0.1, 0.2, 0.3] + +@test_throws ErrorException setter!(probs[1], 4, buf)