Skip to content

Commit

Permalink
feat: support updating individual problems with BatchedInterface
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Apr 5, 2024
1 parent db5038a commit 59b6700
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 21 deletions.
87 changes: 66 additions & 21 deletions src/batched_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -189,28 +189,73 @@ function setu(bi::BatchedInterface)
numprobs = length(bi.system_to_symbol_subset)
probnames = [Symbol(:prob, i) for i in 1:numprobs]

Check warning on line 190 in src/batched_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/batched_interface.jl#L188-L190

Added lines #L188 - L190 were not covered by tests

fnbody = quote end
for (sys_idx, subset) in enumerate(bi.system_to_symbol_subset)
probname = probnames[sys_idx]
for (idx_in_subset, idx_in_union) in enumerate(subset)
idx = bi.system_to_symbol_indexes[sys_idx][idx_in_subset]
isstate = bi.system_to_isstate[sys_idx][idx_in_subset]
setter = isstate ? set_state! : set_parameter!
push!(fnbody.args, :($setter($probname, vals[$idx_in_union], $idx)))
full_update_fnexpr = let fnbody = quote end
for (sys_idx, subset) in enumerate(bi.system_to_symbol_subset)
probname = probnames[sys_idx]
for (idx_in_subset, idx_in_union) in enumerate(subset)
idx = bi.system_to_symbol_indexes[sys_idx][idx_in_subset]
isstate = bi.system_to_isstate[sys_idx][idx_in_subset]
setter = isstate ? set_state! : set_parameter!
push!(fnbody.args, :($setter($probname, vals[$idx_in_union], $idx)))
end

Check warning on line 200 in src/batched_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/batched_interface.jl#L192-L200

Added lines #L192 - L200 were not covered by tests
# also run hook
if !all(bi.system_to_isstate[sys_idx])
paramidxs = [bi.system_to_symbol_indexes[sys_idx][idx_in_subset]

Check warning on line 203 in src/batched_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/batched_interface.jl#L202-L203

Added lines #L202 - L203 were not covered by tests
for idx_in_subset in 1:length(subset)
if !bi.system_to_isstate[sys_idx][idx_in_subset]]
push!(fnbody.args, :($finalize_parameters_hook!($probname, $paramidxs)))

Check warning on line 206 in src/batched_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/batched_interface.jl#L206

Added line #L206 was not covered by tests
end
end
# also run hook
if !all(bi.system_to_isstate[sys_idx])
paramidxs = [bi.system_to_symbol_indexes[sys_idx][idx_in_subset]
for idx_in_subset in 1:length(subset)
if !bi.system_to_isstate[sys_idx][idx_in_subset]]
push!(fnbody.args, :($finalize_parameters_hook!($probname, $paramidxs)))
push!(fnbody.args, :(return vals))
Expr(

Check warning on line 210 in src/batched_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/batched_interface.jl#L208-L210

Added lines #L208 - L210 were not covered by tests
:function,
Expr(:tuple, probnames..., :vals),
fnbody
)
end

partial_update_fnexpr = let fnbody = quote end
curfnbody = fnbody
for (sys_idx, subset) in enumerate(bi.system_to_symbol_subset)
newcurfnbody = if sys_idx == 1
Expr(:if, :(idx == $sys_idx))

Check warning on line 221 in src/batched_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/batched_interface.jl#L217-L221

Added lines #L217 - L221 were not covered by tests
else
Expr(:elseif, :(idx == $sys_idx))

Check warning on line 223 in src/batched_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/batched_interface.jl#L223

Added line #L223 was not covered by tests
end
push!(curfnbody.args, newcurfnbody)
curfnbody = newcurfnbody

Check warning on line 226 in src/batched_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/batched_interface.jl#L225-L226

Added lines #L225 - L226 were not covered by tests

ifbody = quote end
push!(curfnbody.args, ifbody)

Check warning on line 229 in src/batched_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/batched_interface.jl#L228-L229

Added lines #L228 - L229 were not covered by tests

probname = :prob
for (idx_in_subset, idx_in_union) in enumerate(subset)
idx = bi.system_to_symbol_indexes[sys_idx][idx_in_subset]
isstate = bi.system_to_isstate[sys_idx][idx_in_subset]
setter = isstate ? set_state! : set_parameter!
push!(ifbody.args, :($setter($probname, vals[$idx_in_union], $idx)))
end

Check warning on line 237 in src/batched_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/batched_interface.jl#L231-L237

Added lines #L231 - L237 were not covered by tests
# also run hook
if !all(bi.system_to_isstate[sys_idx])
paramidxs = [bi.system_to_symbol_indexes[sys_idx][idx_in_subset]

Check warning on line 240 in src/batched_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/batched_interface.jl#L239-L240

Added lines #L239 - L240 were not covered by tests
for idx_in_subset in 1:length(subset)
if !bi.system_to_isstate[sys_idx][idx_in_subset]]
push!(ifbody.args, :($finalize_parameters_hook!($probname, $paramidxs)))

Check warning on line 243 in src/batched_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/batched_interface.jl#L243

Added line #L243 was not covered by tests
end
end
push!(curfnbody.args, :(error("Invalid problem index $idx")))
push!(fnbody.args, :(return nothing))
Expr(

Check warning on line 248 in src/batched_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/batched_interface.jl#L245-L248

Added lines #L245 - L248 were not covered by tests
:function,
Expr(:tuple, :prob, :idx, :vals),
fnbody
)
end
return let full_update = @RuntimeGeneratedFunction(full_update_fnexpr),
partial_update = @RuntimeGeneratedFunction(partial_update_fnexpr)

Check warning on line 255 in src/batched_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/batched_interface.jl#L254-L255

Added lines #L254 - L255 were not covered by tests

setter!(args...) = full_update(args...)
setter!(prob, idx::Int, vals::AbstractVector) = partial_update(prob, idx, vals)
setter!

Check warning on line 259 in src/batched_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/batched_interface.jl#L257-L259

Added lines #L257 - L259 were not covered by tests
end
push!(fnbody.args, :(return vals))
fnexpr = Expr(
:function,
Expr(:tuple, probnames..., :vals),
fnbody
)
return @RuntimeGeneratedFunction(fnexpr)
end
7 changes: 7 additions & 0 deletions test/batched_interface_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,10 @@ setter!(probs..., buf)
@test state_values(probs[3]) == [500.0, 100.0, 9.0]
# Similarly for :f
@test parameter_values(probs[3]) == [70.0, 80.0, 0.9]

buf ./= 100
setter!(probs[1], 1, buf)
@test state_values(probs[1]) == [1.0, 2.0, 3.0]
@test parameter_values(probs[1]) == [0.1, 0.2, 0.3]

@test_throws ErrorException setter!(probs[1], 4, buf)

0 comments on commit 59b6700

Please sign in to comment.