Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add remake_buffer #62

Merged
merged 1 commit into from
Mar 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,27 @@ authors = ["Aayush Sabharwal <[email protected]> and contributors"]
version = "0.3.11"

[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
RuntimeGeneratedFunctions = "7e49a35a-f44a-4d26-94aa-eba1b4ca6b47"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"

[compat]
Aqua = "0.8"
ArrayInterface = "7.9"
MacroTools = "0.5.13"
RuntimeGeneratedFunctions = "0.5"
SafeTestsets = "0.0.1"
StaticArrays = "1.9"
StaticArraysCore = "1.4"
Test = "1"
julia = "1.10"

[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Aqua", "Test", "SafeTestsets"]
test = ["Aqua", "Test", "SafeTestsets", "StaticArrays"]
6 changes: 6 additions & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,12 @@ getu
setu
```

## Container objects

```@docs
remake_buffer
```

### Parameter timeseries

If a solution object saves a timeseries of parameter values that are updated during the
Expand Down
6 changes: 6 additions & 0 deletions docs/src/complete_sii.md
Original file line number Diff line number Diff line change
Expand Up @@ -452,3 +452,9 @@ idxs = @show rand(Bool, 10) # boolean mask for indexing
sol.ps[:a, idxs]
```

## Custom containers

A custom container object (such as `ModelingToolkit.MTKParameters`) should implement
[`remake_buffer`](@ref) to allow creating a new buffer with updated values, possibly
with different types. This is already implemented for `AbstractArray`s (including static
arrays).
10 changes: 10 additions & 0 deletions docs/src/usage.md
Original file line number Diff line number Diff line change
Expand Up @@ -182,3 +182,13 @@ parameter_values(prob)
on other problem/solution instances can be the key to achieving good performance. Note
that this caching is allowed only when the symbolic system is unchanged (it's fine for
the numerical values to have changed, but not the underlying equations).

## Re-creating a buffer

To re-create a buffer (of unknowns or parameters) use [`remake_buffer`](@ref). This allows
changing the type of values in the buffer (for example, changing the value of a parameter
from `Float64` to `Float32`).

```@example Usage
remake_buffer(sys, prob.p, Dict(σ => 1f0, ρ => 2f0, β => 3f0))
```
6 changes: 6 additions & 0 deletions src/SymbolicIndexingInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@ module SymbolicIndexingInterface

import MacroTools
using RuntimeGeneratedFunctions
import StaticArraysCore: MArray, similar_type
import ArrayInterface

RuntimeGeneratedFunctions.init(@__MODULE__)

export ScalarSymbolic, ArraySymbolic, NotSymbolic, symbolic_type, hasname, getname,
Expand All @@ -28,4 +31,7 @@ include("state_indexing.jl")

export ParameterIndexingProxy
include("parameter_indexing_proxy.jl")

export remake_buffer
include("remake.jl")
end
32 changes: 32 additions & 0 deletions src/remake.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
"""
remake_buffer(sys, oldbuffer, vals::Dict)

Return a copy of the buffer `oldbuffer` with values from `vals`. The keys of `vals`
are symbolic variables whose index in the buffer is determined using `sys`. The types of
values in `vals` may not match the types of values stored at the corresponding indexes in
the buffer, in which case the type of the buffer should be promoted accordingly. In
general, this method should attempt to preserve the types of values stored in `vals` as
much as possible. The returned buffer should be of the same type (ignoring type-parameters)
as `oldbuffer`.

This method is already implemented for
`remake_buffer(sys, oldbuffer::AbstractArray, vals::Dict)` and supports static arrays
as well.
"""
function remake_buffer(sys, oldbuffer::AbstractArray, vals::Dict)

Check warning on line 16 in src/remake.jl

View check run for this annotation

Codecov / codecov/patch

src/remake.jl#L16

Added line #L16 was not covered by tests
# similar when used with an `MArray` and nonconcrete eltype returns a
# SizedArray. `similar_type` still returns an `MArray`
if ArrayInterface.ismutable(oldbuffer) && !isa(oldbuffer, MArray)
elT = Union{}
for val in values(vals)
elT = Union{elT, typeof(val)}
end

Check warning on line 23 in src/remake.jl

View check run for this annotation

Codecov / codecov/patch

src/remake.jl#L19-L23

Added lines #L19 - L23 were not covered by tests

newbuffer = similar(oldbuffer, elT)
setu(sys, keys(vals))(newbuffer, values(vals))

Check warning on line 26 in src/remake.jl

View check run for this annotation

Codecov / codecov/patch

src/remake.jl#L25-L26

Added lines #L25 - L26 were not covered by tests
else
mutbuffer = remake_buffer(sys, collect(oldbuffer), vals)
newbuffer = similar_type(oldbuffer, eltype(mutbuffer))(mutbuffer)

Check warning on line 29 in src/remake.jl

View check run for this annotation

Codecov / codecov/patch

src/remake.jl#L28-L29

Added lines #L28 - L29 were not covered by tests
end
return newbuffer

Check warning on line 31 in src/remake.jl

View check run for this annotation

Codecov / codecov/patch

src/remake.jl#L31

Added line #L31 was not covered by tests
end
40 changes: 40 additions & 0 deletions test/remake_test.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
using SymbolicIndexingInterface
using StaticArrays

sys = SymbolCache([:x, :y, :z], [:a, :b, :c], :t)

for (buf, newbuf, newvals) in [
# standard operation
([1.0, 2.0, 3.0], [2.0, 3.0, 4.0],
Dict(:x => 2.0, :y => 3.0, :z => 4.0))
# type "demotion"
([1.0, 2.0, 3.0], [2, 3, 4],
Dict(:x => 2, :y => 3, :z => 4))
# type promotion
([1, 2, 3], [2.0, 3.0, 4.0],
Dict(:x => 2.0, :y => 3.0, :z => 4.0))
# union
([1, 2, 3], Union{Int, Float64}[2, 3.0, 4.0],
Dict(:x => 2, :y => 3.0, :z => 4.0))
# standard operation
([1.0, 2.0, 3.0], [2.0, 3.0, 4.0],
Dict(:a => 2.0, :b => 3.0, :c => 4.0))
# type "demotion"
([1.0, 2.0, 3.0], [2, 3, 4],
Dict(:a => 2, :b => 3, :c => 4))
# type promotion
([1, 2, 3], [2.0, 3.0, 4.0],
Dict(:a => 2.0, :b => 3.0, :c => 4.0))
# union
([1, 2, 3], Union{Int, Float64}[2, 3.0, 4.0],
Dict(:a => 2, :b => 3.0, :c => 4.0))]
for arrType in [Vector, SVector{3}, MVector{3}, SizedVector{3}]
buf = arrType(buf)
newbuf = arrType(newbuf)
_newbuf = remake_buffer(sys, buf, newvals)

@test _newbuf != buf # should not alias
@test newbuf == _newbuf # test values
@test typeof(newbuf) == typeof(_newbuf) # ensure appropriate type
end
end
3 changes: 3 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,6 @@ end
@safetestset "State indexing test" begin
@time include("state_indexing_test.jl")
end
@safetestset "Remake test" begin
@time include("remake_test.jl")
end
Loading